Skip to main content

arrow_ipc/
writer.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Arrow IPC File and Stream Writers
19//!
20//! # Notes
21//!
22//! [`FileWriter`] and [`StreamWriter`] have similar interfaces,
23//! however the [`FileWriter`] expects a writer that supports [`Seek`]ing
24//!
25//! [`Seek`]: std::io::Seek
26
27use std::cmp::min;
28use std::collections::HashMap;
29use std::io::{BufWriter, Write};
30use std::mem::size_of;
31use std::sync::Arc;
32
33use flatbuffers::FlatBufferBuilder;
34
35use arrow_array::builder::BufferBuilder;
36use arrow_array::cast::*;
37use arrow_array::types::{Int16Type, Int32Type, Int64Type, RunEndIndexType};
38use arrow_array::*;
39use arrow_buffer::bit_util;
40use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer};
41use arrow_data::{ArrayData, ArrayDataBuilder, BufferSpec, layout};
42use arrow_schema::*;
43
44use crate::CONTINUATION_MARKER;
45use crate::compression::CompressionCodec;
46pub use crate::compression::CompressionContext;
47use crate::convert::IpcSchemaEncoder;
48
/// IPC write options used to control the behaviour of the [`IpcDataGenerator`]
#[derive(Debug, Clone)]
pub struct IpcWriteOptions {
    /// Write padding after memory buffers to this multiple of bytes.
    /// Must be 8, 16, 32, or 64 - defaults to 64.
    alignment: u8,
    /// The legacy format is for releases before 0.15.0, and uses metadata V4
    write_legacy_ipc_format: bool,
    /// The metadata version to write. The Rust IPC writer supports V4+
    ///
    /// *Default versions per crate*
    ///
    /// When creating the default IpcWriteOptions, the following metadata versions are used:
    ///
    /// version 2.0.0: V4, with legacy format enabled
    /// version 4.0.0: V5
    metadata_version: crate::MetadataVersion,
    /// Compression, if desired. Will result in a runtime error
    /// if the corresponding feature is not enabled
    ///
    /// Requires metadata version V5 or above; see [`Self::try_with_compression`].
    batch_compression_type: Option<crate::CompressionType>,
    /// How to handle updating dictionaries in IPC messages
    ///
    /// See [`DictionaryHandling`] for the available strategies.
    dictionary_handling: DictionaryHandling,
}
72
73impl IpcWriteOptions {
74    /// Configures compression when writing IPC files.
75    ///
76    /// Will result in a runtime error if the corresponding feature
77    /// is not enabled
78    pub fn try_with_compression(
79        mut self,
80        batch_compression_type: Option<crate::CompressionType>,
81    ) -> Result<Self, ArrowError> {
82        self.batch_compression_type = batch_compression_type;
83
84        if self.batch_compression_type.is_some()
85            && self.metadata_version < crate::MetadataVersion::V5
86        {
87            return Err(ArrowError::InvalidArgumentError(
88                "Compression only supported in metadata v5 and above".to_string(),
89            ));
90        }
91        Ok(self)
92    }
93    /// Try to create IpcWriteOptions, checking for incompatible settings
94    pub fn try_new(
95        alignment: usize,
96        write_legacy_ipc_format: bool,
97        metadata_version: crate::MetadataVersion,
98    ) -> Result<Self, ArrowError> {
99        let is_alignment_valid =
100            alignment == 8 || alignment == 16 || alignment == 32 || alignment == 64;
101        if !is_alignment_valid {
102            return Err(ArrowError::InvalidArgumentError(
103                "Alignment should be 8, 16, 32, or 64.".to_string(),
104            ));
105        }
106        let alignment: u8 = u8::try_from(alignment).expect("range already checked");
107        match metadata_version {
108            crate::MetadataVersion::V1
109            | crate::MetadataVersion::V2
110            | crate::MetadataVersion::V3 => Err(ArrowError::InvalidArgumentError(
111                "Writing IPC metadata version 3 and lower not supported".to_string(),
112            )),
113            #[allow(deprecated)]
114            crate::MetadataVersion::V4 => Ok(Self {
115                alignment,
116                write_legacy_ipc_format,
117                metadata_version,
118                batch_compression_type: None,
119                dictionary_handling: DictionaryHandling::default(),
120            }),
121            crate::MetadataVersion::V5 => {
122                if write_legacy_ipc_format {
123                    Err(ArrowError::InvalidArgumentError(
124                        "Legacy IPC format only supported on metadata version 4".to_string(),
125                    ))
126                } else {
127                    Ok(Self {
128                        alignment,
129                        write_legacy_ipc_format,
130                        metadata_version,
131                        batch_compression_type: None,
132                        dictionary_handling: DictionaryHandling::default(),
133                    })
134                }
135            }
136            z => Err(ArrowError::InvalidArgumentError(format!(
137                "Unsupported crate::MetadataVersion {z:?}"
138            ))),
139        }
140    }
141
142    /// Configure how dictionaries are handled in IPC messages
143    pub fn with_dictionary_handling(mut self, dictionary_handling: DictionaryHandling) -> Self {
144        self.dictionary_handling = dictionary_handling;
145        self
146    }
147}
148
impl Default for IpcWriteOptions {
    /// Returns the default options: 64-byte buffer alignment, the modern
    /// (non-legacy) IPC format, metadata V5, no compression, and the default
    /// dictionary handling strategy.
    fn default() -> Self {
        Self {
            alignment: 64,
            write_legacy_ipc_format: false,
            metadata_version: crate::MetadataVersion::V5,
            batch_compression_type: None,
            dictionary_handling: DictionaryHandling::default(),
        }
    }
}
160
#[derive(Debug, Default)]
/// Handles low level details of encoding [`Array`] and [`Schema`] into the
/// [Arrow IPC Format].
///
/// # Example
/// ```
/// # fn run() {
/// # use std::sync::Arc;
/// # use arrow_array::UInt64Array;
/// # use arrow_array::RecordBatch;
/// # use arrow_ipc::writer::{CompressionContext, DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
///
/// // Create a record batch
/// let batch = RecordBatch::try_from_iter(vec![
///  ("col2", Arc::new(UInt64Array::from_iter([10, 23, 33])) as _)
/// ]).unwrap();
///
/// // Error if dictionary ids are replaced.
/// let error_on_replacement = true;
/// let options = IpcWriteOptions::default();
/// let mut dictionary_tracker = DictionaryTracker::new(error_on_replacement);
///
/// let mut compression_context = CompressionContext::default();
///
/// // encode the batch into zero or more encoded dictionaries
/// // and the data for the actual array.
/// let data_gen = IpcDataGenerator::default();
/// let (encoded_dictionaries, encoded_message) = data_gen
///   .encode(&batch, &mut dictionary_tracker, &options, &mut compression_context)
///   .unwrap();
/// # }
/// ```
///
/// [Arrow IPC Format]: https://arrow.apache.org/docs/format/Columnar.html#serialization-and-interprocess-communication-ipc
pub struct IpcDataGenerator {}
196
197impl IpcDataGenerator {
    /// Converts a schema to an IPC message along with `dictionary_tracker`
    /// and returns it encoded inside [EncodedData] as a flatbuffer.
    pub fn schema_to_bytes_with_dictionary_tracker(
        &self,
        schema: &Schema,
        dictionary_tracker: &mut DictionaryTracker,
        write_options: &IpcWriteOptions,
    ) -> EncodedData {
        let mut fbb = FlatBufferBuilder::new();
        // Encode the schema into the union value that becomes the message header.
        let schema = {
            let fb = IpcSchemaEncoder::new()
                .with_dictionary_tracker(dictionary_tracker)
                .schema_to_fb_offset(&mut fbb, schema);
            fb.as_union_value()
        };

        let mut message = crate::MessageBuilder::new(&mut fbb);
        message.add_version(write_options.metadata_version);
        message.add_header_type(crate::MessageHeader::Schema);
        // Schema messages carry no body bytes, only the flatbuffer header.
        message.add_bodyLength(0);
        message.add_header(schema);
        // TODO: custom metadata
        let data = message.finish();
        fbb.finish(data, None);

        let data = fbb.finished_data();
        EncodedData {
            ipc_message: data.to_vec(),
            // No body accompanies a schema message.
            arrow_data: vec![],
        }
    }
229
    /// Recursively encodes the dictionaries found in `column`'s children.
    ///
    /// Companion to [`Self::encode_dictionaries`]: this walks the container
    /// types (struct, run-end encoded, lists, map, union) and recurses into
    /// each child array with its matching field, so nested dictionaries are
    /// discovered and encoded in depth-first order.
    fn _encode_dictionaries<I: Iterator<Item = i64>>(
        &self,
        column: &ArrayRef,
        encoded_dictionaries: &mut Vec<EncodedData>,
        dictionary_tracker: &mut DictionaryTracker,
        write_options: &IpcWriteOptions,
        dict_id: &mut I,
        compression_context: &mut CompressionContext,
    ) -> Result<(), ArrowError> {
        match column.data_type() {
            DataType::Struct(fields) => {
                let s = as_struct_array(column);
                // Visit each struct member together with its field metadata.
                for (field, column) in fields.iter().zip(s.columns()) {
                    self.encode_dictionaries(
                        field,
                        column,
                        encoded_dictionaries,
                        dictionary_tracker,
                        write_options,
                        dict_id,
                        compression_context,
                    )?;
                }
            }
            DataType::RunEndEncoded(_, values) => {
                let data = column.to_data();
                if data.child_data().len() != 2 {
                    return Err(ArrowError::InvalidArgumentError(format!(
                        "The run encoded array should have exactly two child arrays. Found {}",
                        data.child_data().len()
                    )));
                }
                // The run_ends array is not expected to be dictionary encoded. Hence encode dictionaries
                // only for values array.
                let values_array = make_array(data.child_data()[1].clone());
                self.encode_dictionaries(
                    values,
                    &values_array,
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                    compression_context,
                )?;
            }
            DataType::List(field) => {
                let list = as_list_array(column);
                self.encode_dictionaries(
                    field,
                    list.values(),
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                    compression_context,
                )?;
            }
            DataType::LargeList(field) => {
                let list = as_large_list_array(column);
                self.encode_dictionaries(
                    field,
                    list.values(),
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                    compression_context,
                )?;
            }
            DataType::ListView(field) => {
                let list = column.as_list_view::<i32>();
                self.encode_dictionaries(
                    field,
                    list.values(),
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                    compression_context,
                )?;
            }
            DataType::LargeListView(field) => {
                let list = column.as_list_view::<i64>();
                self.encode_dictionaries(
                    field,
                    list.values(),
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                    compression_context,
                )?;
            }
            DataType::FixedSizeList(field, _) => {
                let list = column
                    .as_any()
                    .downcast_ref::<FixedSizeListArray>()
                    .expect("Unable to downcast to fixed size list array");
                self.encode_dictionaries(
                    field,
                    list.values(),
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                    compression_context,
                )?;
            }
            DataType::Map(field, _) => {
                let map_array = as_map_array(column);

                // A map's entries field must be a two-field struct: (keys, values).
                let (keys, values) = match field.data_type() {
                    DataType::Struct(fields) if fields.len() == 2 => (&fields[0], &fields[1]),
                    _ => panic!("Incorrect field data type {:?}", field.data_type()),
                };

                // keys
                self.encode_dictionaries(
                    keys,
                    map_array.keys(),
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                    compression_context,
                )?;

                // values
                self.encode_dictionaries(
                    values,
                    map_array.values(),
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                    compression_context,
                )?;
            }
            DataType::Union(fields, _) => {
                let union = as_union_array(column);
                for (type_id, field) in fields.iter() {
                    let column = union.child(type_id);
                    self.encode_dictionaries(
                        field,
                        column,
                        encoded_dictionaries,
                        dictionary_tracker,
                        write_options,
                        dict_id,
                        compression_context,
                    )?;
                }
            }
            // Leaf types have no nested dictionaries to encode here; the
            // Dictionary case itself is handled by the caller.
            _ => (),
        }

        Ok(())
    }
388
    /// Encodes the dictionary for `column` (if it is dictionary-encoded),
    /// after first encoding the dictionaries of any nested children.
    ///
    /// Dictionary ids are taken from `dict_id_seq`, which is consumed
    /// depth-first so children claim their ids before this field does.
    /// Whether a full or delta dictionary batch is emitted is decided by the
    /// [`DictionaryTracker`] based on `write_options.dictionary_handling`.
    #[allow(clippy::too_many_arguments)]
    fn encode_dictionaries<I: Iterator<Item = i64>>(
        &self,
        field: &Field,
        column: &ArrayRef,
        encoded_dictionaries: &mut Vec<EncodedData>,
        dictionary_tracker: &mut DictionaryTracker,
        write_options: &IpcWriteOptions,
        dict_id_seq: &mut I,
        compression_context: &mut CompressionContext,
    ) -> Result<(), ArrowError> {
        match column.data_type() {
            DataType::Dictionary(_key_type, _value_type) => {
                let dict_data = column.to_data();
                let dict_values = &dict_data.child_data()[0];

                let values = make_array(dict_data.child_data()[0].clone());

                // Recurse first: the dictionary's values may themselves
                // contain nested dictionaries.
                self._encode_dictionaries(
                    &values,
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id_seq,
                    compression_context,
                )?;

                // It's important to only take the dict_id at this point, because the dict ID
                // sequence is assigned depth-first, so we need to first encode children and have
                // them take their assigned dict IDs before we take the dict ID for this field.
                let dict_id = dict_id_seq.next().ok_or_else(|| {
                    ArrowError::IpcError(format!(
                        "no dict id for field {:?}: field.data_type={:?}, column.data_type={:?}",
                        field.name(),
                        field.data_type(),
                        column.data_type()
                    ))
                })?;

                match dictionary_tracker.insert_column(
                    dict_id,
                    column,
                    write_options.dictionary_handling,
                )? {
                    // Dictionary already sent and unchanged: emit nothing.
                    DictionaryUpdate::None => {}
                    DictionaryUpdate::New | DictionaryUpdate::Replaced => {
                        encoded_dictionaries.push(self.dictionary_batch_to_bytes(
                            dict_id,
                            dict_values,
                            write_options,
                            false,
                            compression_context,
                        )?);
                    }
                    DictionaryUpdate::Delta(data) => {
                        // Only the new values are sent, flagged as a delta.
                        encoded_dictionaries.push(self.dictionary_batch_to_bytes(
                            dict_id,
                            &data,
                            write_options,
                            true,
                            compression_context,
                        )?);
                    }
                }
            }
            // Non-dictionary columns may still contain nested dictionaries.
            _ => self._encode_dictionaries(
                column,
                encoded_dictionaries,
                dictionary_tracker,
                write_options,
                dict_id_seq,
                compression_context,
            )?,
        }

        Ok(())
    }
466
467    /// Encodes a batch to a number of [EncodedData] items (dictionary batches + the record batch).
468    /// The [DictionaryTracker] keeps track of dictionaries with new `dict_id`s  (so they are only sent once)
469    /// Make sure the [DictionaryTracker] is initialized at the start of the stream.
470    pub fn encode(
471        &self,
472        batch: &RecordBatch,
473        dictionary_tracker: &mut DictionaryTracker,
474        write_options: &IpcWriteOptions,
475        compression_context: &mut CompressionContext,
476    ) -> Result<(Vec<EncodedData>, EncodedData), ArrowError> {
477        let schema = batch.schema();
478        let mut encoded_dictionaries = Vec::with_capacity(schema.flattened_fields().len());
479
480        let mut dict_id = dictionary_tracker.dict_ids.clone().into_iter();
481
482        for (i, field) in schema.fields().iter().enumerate() {
483            let column = batch.column(i);
484            self.encode_dictionaries(
485                field,
486                column,
487                &mut encoded_dictionaries,
488                dictionary_tracker,
489                write_options,
490                &mut dict_id,
491                compression_context,
492            )?;
493        }
494
495        let encoded_message =
496            self.record_batch_to_bytes(batch, write_options, compression_context)?;
497        Ok((encoded_dictionaries, encoded_message))
498    }
499
500    /// Encodes a batch to a number of [EncodedData] items (dictionary batches + the record batch).
501    /// The [DictionaryTracker] keeps track of dictionaries with new `dict_id`s  (so they are only sent once)
502    /// Make sure the [DictionaryTracker] is initialized at the start of the stream.
503    #[deprecated(since = "57.0.0", note = "Use `encode` instead")]
504    pub fn encoded_batch(
505        &self,
506        batch: &RecordBatch,
507        dictionary_tracker: &mut DictionaryTracker,
508        write_options: &IpcWriteOptions,
509    ) -> Result<(Vec<EncodedData>, EncodedData), ArrowError> {
510        self.encode(
511            batch,
512            dictionary_tracker,
513            write_options,
514            &mut Default::default(),
515        )
516    }
517
    /// Write a `RecordBatch` into two sets of bytes, one for the header (crate::Message) and the
    /// other for the batch's data
    fn record_batch_to_bytes(
        &self,
        batch: &RecordBatch,
        write_options: &IpcWriteOptions,
        compression_context: &mut CompressionContext,
    ) -> Result<EncodedData, ArrowError> {
        let mut fbb = FlatBufferBuilder::new();

        let mut nodes: Vec<crate::FieldNode> = vec![];
        let mut buffers: Vec<crate::Buffer> = vec![];
        let mut arrow_data: Vec<u8> = vec![];
        let mut offset = 0;

        // get the type of compression
        let batch_compression_type = write_options.batch_compression_type;

        // Optional BodyCompression flatbuffer node describing the codec used.
        let compression = batch_compression_type.map(|batch_compression_type| {
            let mut c = crate::BodyCompressionBuilder::new(&mut fbb);
            c.add_method(crate::BodyCompressionMethod::BUFFER);
            c.add_codec(batch_compression_type);
            c.finish()
        });

        // Errors if the corresponding compression feature is not enabled.
        let compression_codec: Option<CompressionCodec> =
            batch_compression_type.map(TryInto::try_into).transpose()?;

        let mut variadic_buffer_counts = vec![];

        // Flatten every column into (nodes, buffers, body bytes); `offset`
        // tracks the running position of each buffer within the body.
        for array in batch.columns() {
            let array_data = array.to_data();
            offset = write_array_data(
                &array_data,
                &mut buffers,
                &mut arrow_data,
                &mut nodes,
                offset,
                array.len(),
                array.null_count(),
                compression_codec,
                compression_context,
                write_options,
            )?;

            append_variadic_buffer_counts(&mut variadic_buffer_counts, &array_data);
        }
        // pad the tail of body data
        let len = arrow_data.len();
        let pad_len = pad_to_alignment(write_options.alignment, len);
        arrow_data.extend_from_slice(&PADDING[..pad_len]);

        // write data
        let buffers = fbb.create_vector(&buffers);
        let nodes = fbb.create_vector(&nodes);
        let variadic_buffer = if variadic_buffer_counts.is_empty() {
            None
        } else {
            Some(fbb.create_vector(&variadic_buffer_counts))
        };

        let root = {
            let mut batch_builder = crate::RecordBatchBuilder::new(&mut fbb);
            batch_builder.add_length(batch.num_rows() as i64);
            batch_builder.add_nodes(nodes);
            batch_builder.add_buffers(buffers);
            if let Some(c) = compression {
                batch_builder.add_compression(c);
            }

            if let Some(v) = variadic_buffer {
                batch_builder.add_variadicBufferCounts(v);
            }
            let b = batch_builder.finish();
            b.as_union_value()
        };
        // create an crate::Message
        let mut message = crate::MessageBuilder::new(&mut fbb);
        message.add_version(write_options.metadata_version);
        message.add_header_type(crate::MessageHeader::RecordBatch);
        message.add_bodyLength(arrow_data.len() as i64);
        message.add_header(root);
        let root = message.finish();
        fbb.finish(root, None);
        let finished_data = fbb.finished_data();

        Ok(EncodedData {
            ipc_message: finished_data.to_vec(),
            arrow_data,
        })
    }
609
610    /// Write dictionary values into two sets of bytes, one for the header (crate::Message) and the
611    /// other for the data
612    fn dictionary_batch_to_bytes(
613        &self,
614        dict_id: i64,
615        array_data: &ArrayData,
616        write_options: &IpcWriteOptions,
617        is_delta: bool,
618        compression_context: &mut CompressionContext,
619    ) -> Result<EncodedData, ArrowError> {
620        let mut fbb = FlatBufferBuilder::new();
621
622        let mut nodes: Vec<crate::FieldNode> = vec![];
623        let mut buffers: Vec<crate::Buffer> = vec![];
624        let mut arrow_data: Vec<u8> = vec![];
625
626        // get the type of compression
627        let batch_compression_type = write_options.batch_compression_type;
628
629        let compression = batch_compression_type.map(|batch_compression_type| {
630            let mut c = crate::BodyCompressionBuilder::new(&mut fbb);
631            c.add_method(crate::BodyCompressionMethod::BUFFER);
632            c.add_codec(batch_compression_type);
633            c.finish()
634        });
635
636        let compression_codec: Option<CompressionCodec> = batch_compression_type
637            .map(|batch_compression_type| batch_compression_type.try_into())
638            .transpose()?;
639
640        write_array_data(
641            array_data,
642            &mut buffers,
643            &mut arrow_data,
644            &mut nodes,
645            0,
646            array_data.len(),
647            array_data.null_count(),
648            compression_codec,
649            compression_context,
650            write_options,
651        )?;
652
653        let mut variadic_buffer_counts = vec![];
654        append_variadic_buffer_counts(&mut variadic_buffer_counts, array_data);
655
656        // pad the tail of body data
657        let len = arrow_data.len();
658        let pad_len = pad_to_alignment(write_options.alignment, len);
659        arrow_data.extend_from_slice(&PADDING[..pad_len]);
660
661        // write data
662        let buffers = fbb.create_vector(&buffers);
663        let nodes = fbb.create_vector(&nodes);
664        let variadic_buffer = if variadic_buffer_counts.is_empty() {
665            None
666        } else {
667            Some(fbb.create_vector(&variadic_buffer_counts))
668        };
669
670        let root = {
671            let mut batch_builder = crate::RecordBatchBuilder::new(&mut fbb);
672            batch_builder.add_length(array_data.len() as i64);
673            batch_builder.add_nodes(nodes);
674            batch_builder.add_buffers(buffers);
675            if let Some(c) = compression {
676                batch_builder.add_compression(c);
677            }
678            if let Some(v) = variadic_buffer {
679                batch_builder.add_variadicBufferCounts(v);
680            }
681            batch_builder.finish()
682        };
683
684        let root = {
685            let mut batch_builder = crate::DictionaryBatchBuilder::new(&mut fbb);
686            batch_builder.add_id(dict_id);
687            batch_builder.add_data(root);
688            batch_builder.add_isDelta(is_delta);
689            batch_builder.finish().as_union_value()
690        };
691
692        let root = {
693            let mut message_builder = crate::MessageBuilder::new(&mut fbb);
694            message_builder.add_version(write_options.metadata_version);
695            message_builder.add_header_type(crate::MessageHeader::DictionaryBatch);
696            message_builder.add_bodyLength(arrow_data.len() as i64);
697            message_builder.add_header(root);
698            message_builder.finish()
699        };
700
701        fbb.finish(root, None);
702        let finished_data = fbb.finished_data();
703
704        Ok(EncodedData {
705            ipc_message: finished_data.to_vec(),
706            arrow_data,
707        })
708    }
709}
710
711fn append_variadic_buffer_counts(counts: &mut Vec<i64>, array: &ArrayData) {
712    match array.data_type() {
713        DataType::BinaryView | DataType::Utf8View => {
714            // The spec documents the counts only includes the variadic buffers, not the view/null buffers.
715            // https://arrow.apache.org/docs/format/Columnar.html#variadic-buffers
716            counts.push(array.buffers().len() as i64 - 1);
717        }
718        DataType::Dictionary(_, _) => {
719            // Do nothing
720            // Dictionary types are handled in `encode_dictionaries`.
721        }
722        _ => {
723            for child in array.child_data() {
724                append_variadic_buffer_counts(counts, child)
725            }
726        }
727    }
728}
729
/// Converts a (possibly sliced) run-end encoded array into an equivalent
/// `ArrayData` with zero offset, dispatching on the run-ends index type.
///
/// Returns an error if `arr` is not a `RunEndEncoded` array.
pub(crate) fn unslice_run_array(arr: ArrayData) -> Result<ArrayData, ArrowError> {
    match arr.data_type() {
        DataType::RunEndEncoded(k, _) => match k.data_type() {
            DataType::Int16 => {
                Ok(into_zero_offset_run_array(RunArray::<Int16Type>::from(arr))?.into_data())
            }
            DataType::Int32 => {
                Ok(into_zero_offset_run_array(RunArray::<Int32Type>::from(arr))?.into_data())
            }
            DataType::Int64 => {
                Ok(into_zero_offset_run_array(RunArray::<Int64Type>::from(arr))?.into_data())
            }
            // Run-end index types are limited to Int16/Int32/Int64.
            d => unreachable!("Unexpected data type {d}"),
        },
        d => Err(ArrowError::InvalidArgumentError(format!(
            "The given array is not a run array. Data type of given array: {d}"
        ))),
    }
}
749
// Returns a `RunArray` with zero offset and length matching the last value
// in run_ends array.
//
// If the input already starts at offset zero and its run ends exactly cover
// its length, it is returned unchanged; otherwise new run-ends and values are
// built so the result is a self-contained, unsliced run array.
fn into_zero_offset_run_array<R: RunEndIndexType>(
    run_array: RunArray<R>,
) -> Result<RunArray<R>, ArrowError> {
    let run_ends = run_array.run_ends();
    // Fast path: already unsliced.
    if run_ends.offset() == 0 && run_ends.max_value() == run_ends.len() {
        return Ok(run_array);
    }

    // The physical index of original run_ends array from which the `ArrayData`is sliced.
    let start_physical_index = run_ends.get_start_physical_index();

    // The physical index of original run_ends array until which the `ArrayData`is sliced.
    let end_physical_index = run_ends.get_end_physical_index();

    let physical_length = end_physical_index - start_physical_index + 1;

    // build new run_ends array by subtracting offset from run ends.
    let offset = R::Native::usize_as(run_ends.offset());
    let mut builder = BufferBuilder::<R::Native>::new(physical_length);
    for run_end_value in &run_ends.values()[start_physical_index..end_physical_index] {
        builder.append(run_end_value.sub_wrapping(offset));
    }
    // The final run end is the logical length of the sliced array.
    builder.append(R::Native::from_usize(run_array.len()).unwrap());
    let new_run_ends = unsafe {
        // Safety:
        // The function builds a valid run_ends array and hence need not be validated.
        ArrayDataBuilder::new(R::DATA_TYPE)
            .len(physical_length)
            .add_buffer(builder.finish())
            .build_unchecked()
    };

    // build new values by slicing physical indices.
    let new_values = run_array
        .values()
        .slice(start_physical_index, physical_length)
        .into_data();

    let builder = ArrayDataBuilder::new(run_array.data_type().clone())
        .len(run_array.len())
        .add_child_data(new_run_ends)
        .add_child_data(new_values);
    let array_data = unsafe {
        // Safety:
        //  This function builds a valid run array and hence can skip validation.
        builder.build_unchecked()
    };
    Ok(array_data.into())
}
801
/// Controls how dictionaries are handled in Arrow IPC messages
///
/// Set via `IpcWriteOptions` (see the delta-dictionary example on
/// [`StreamWriter`]) and consulted by [`DictionaryTracker::insert_column`]
/// when a dictionary changes between batches.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DictionaryHandling {
    /// Send the entire dictionary every time it is encountered (default)
    #[default]
    Resend,
    /// Send only new dictionary values since the last batch (delta encoding)
    ///
    /// When a dictionary is first encountered, the entire dictionary is sent.
    /// For subsequent batches, only values that are new (not previously sent)
    /// are transmitted with the `isDelta` flag set to true.
    Delta,
}
815
/// Describes what kind of update took place after a call to
/// [`DictionaryTracker::insert_column`].
#[derive(Debug, Clone)]
pub enum DictionaryUpdate {
    /// No dictionary was written, the dictionary was identical to what was already
    /// in the tracker.
    None,
    /// No dictionary was present in the tracker
    New,
    /// Dictionary was replaced with the new data
    Replaced,
    /// Dictionary was updated, ArrayData is the delta between old and new,
    /// i.e. only the values appended since the previously written dictionary
    Delta(ArrayData),
}
829
/// Keeps track of dictionaries that have been written, to avoid emitting the same dictionary
/// multiple times.
///
/// Can optionally error if an update to an existing dictionary is attempted, which
/// isn't allowed in the `FileWriter`.
#[derive(Debug)]
pub struct DictionaryTracker {
    // NOTE: When adding fields, update the clear() method accordingly.
    // Most recently written dictionary data, keyed by dictionary ID.
    written: HashMap<i64, ArrayData>,
    // Dictionary IDs handed out so far, in the order they were assigned.
    dict_ids: Vec<i64>,
    // Whether replacing an already-written dictionary is an error.
    error_on_replacement: bool,
}
842
843impl DictionaryTracker {
844    /// Create a new [`DictionaryTracker`].
845    ///
846    /// If `error_on_replacement`
847    /// is true, an error will be generated if an update to an
848    /// existing dictionary is attempted.
849    pub fn new(error_on_replacement: bool) -> Self {
850        #[allow(deprecated)]
851        Self {
852            written: HashMap::new(),
853            dict_ids: Vec::new(),
854            error_on_replacement,
855        }
856    }
857
858    /// Record and return the next dictionary ID.
859    pub fn next_dict_id(&mut self) -> i64 {
860        let next = self
861            .dict_ids
862            .last()
863            .copied()
864            .map(|i| i + 1)
865            .unwrap_or_default();
866
867        self.dict_ids.push(next);
868        next
869    }
870
871    /// Return the sequence of dictionary IDs in the order they should be observed while
872    /// traversing the schema
873    pub fn dict_id(&mut self) -> &[i64] {
874        &self.dict_ids
875    }
876
877    /// Keep track of the dictionary with the given ID and values. Behavior:
878    ///
879    /// * If this ID has been written already and has the same data, return `Ok(false)` to indicate
880    ///   that the dictionary was not actually inserted (because it's already been seen).
881    /// * If this ID has been written already but with different data, and this tracker is
882    ///   configured to return an error, return an error.
883    /// * If the tracker has not been configured to error on replacement or this dictionary
884    ///   has never been seen before, return `Ok(true)` to indicate that the dictionary was just
885    ///   inserted.
886    #[deprecated(since = "56.1.0", note = "Use `insert_column` instead")]
887    pub fn insert(&mut self, dict_id: i64, column: &ArrayRef) -> Result<bool, ArrowError> {
888        let dict_data = column.to_data();
889        let dict_values = &dict_data.child_data()[0];
890
891        // If a dictionary with this id was already emitted, check if it was the same.
892        if let Some(last) = self.written.get(&dict_id) {
893            if ArrayData::ptr_eq(&last.child_data()[0], dict_values) {
894                // Same dictionary values => no need to emit it again
895                return Ok(false);
896            }
897            if self.error_on_replacement {
898                // If error on replacement perform a logical comparison
899                if last.child_data()[0] == *dict_values {
900                    // Same dictionary values => no need to emit it again
901                    return Ok(false);
902                }
903                return Err(ArrowError::InvalidArgumentError(
904                    "Dictionary replacement detected when writing IPC file format. \
905                     Arrow IPC files only support a single dictionary for a given field \
906                     across all batches."
907                        .to_string(),
908                ));
909            }
910        }
911
912        self.written.insert(dict_id, dict_data);
913        Ok(true)
914    }
915
916    /// Keep track of the dictionary with the given ID and values. The return
917    /// value indicates what, if any, update to the internal map took place
918    /// and how it should be interpreted based on the `dict_handling` parameter.
919    ///
920    /// # Returns
921    ///
922    /// * `Ok(Dictionary::New)` - If the dictionary was not previously written
923    /// * `Ok(Dictionary::Replaced)` - If the dictionary was previously written
924    ///   with completely different data, or if the data is a delta of the existing,
925    ///   but with `dict_handling` set to `DictionaryHandling::Resend`
926    /// * `Ok(Dictionary::Delta)` - If the dictionary was previously written, but
927    ///   the new data is a delta of the old and the `dict_handling` is set to
928    ///   `DictionaryHandling::Delta`
929    /// * `Err(e)` - If the dictionary was previously written with different data,
930    ///   and `error_on_replacement` is set to `true`.
931    pub fn insert_column(
932        &mut self,
933        dict_id: i64,
934        column: &ArrayRef,
935        dict_handling: DictionaryHandling,
936    ) -> Result<DictionaryUpdate, ArrowError> {
937        let new_data = column.to_data();
938        let new_values = &new_data.child_data()[0];
939
940        // If there is no existing dictionary with this ID, we always insert
941        let Some(old) = self.written.get(&dict_id) else {
942            self.written.insert(dict_id, new_data);
943            return Ok(DictionaryUpdate::New);
944        };
945
946        // Fast path - If the array data points to the same buffer as the
947        // existing then they're the same.
948        let old_values = &old.child_data()[0];
949        if ArrayData::ptr_eq(old_values, new_values) {
950            return Ok(DictionaryUpdate::None);
951        }
952
953        // Slow path - Compare the dictionaries value by value
954        let comparison = compare_dictionaries(old_values, new_values);
955        if matches!(comparison, DictionaryComparison::Equal) {
956            return Ok(DictionaryUpdate::None);
957        }
958
959        const REPLACEMENT_ERROR: &str = "Dictionary replacement detected when writing IPC file format. \
960                 Arrow IPC files only support a single dictionary for a given field \
961                 across all batches.";
962
963        match comparison {
964            DictionaryComparison::NotEqual => {
965                if self.error_on_replacement {
966                    return Err(ArrowError::InvalidArgumentError(
967                        REPLACEMENT_ERROR.to_string(),
968                    ));
969                }
970
971                self.written.insert(dict_id, new_data);
972                Ok(DictionaryUpdate::Replaced)
973            }
974            DictionaryComparison::Delta => match dict_handling {
975                DictionaryHandling::Resend => {
976                    if self.error_on_replacement {
977                        return Err(ArrowError::InvalidArgumentError(
978                            REPLACEMENT_ERROR.to_string(),
979                        ));
980                    }
981
982                    self.written.insert(dict_id, new_data);
983                    Ok(DictionaryUpdate::Replaced)
984                }
985                DictionaryHandling::Delta => {
986                    let delta =
987                        new_values.slice(old_values.len(), new_values.len() - old_values.len());
988                    self.written.insert(dict_id, new_data);
989                    Ok(DictionaryUpdate::Delta(delta))
990                }
991            },
992            DictionaryComparison::Equal => unreachable!("Already checked equal case"),
993        }
994    }
995
996    /// Clears the state of the dictionary tracker.
997    ///
998    /// This allows the dictionary tracker to be reused for a new IPC stream while avoiding the
999    /// allocation cost of creating a new instance. This method should not be called if
1000    /// the dictionary tracker will be used to continue writing to an existing IPC stream.
1001    pub fn clear(&mut self) {
1002        self.dict_ids.clear();
1003        self.written.clear();
1004    }
1005}
1006
/// Describes how two dictionary arrays compare to each other.
///
/// Produced by `compare_dictionaries` and used by
/// [`DictionaryTracker::insert_column`] to decide whether a dictionary must
/// be resent, sent as a delta, or skipped entirely.
#[derive(Debug, Clone)]
enum DictionaryComparison {
    /// Neither a delta, nor an exact match
    NotEqual,
    /// Exact element-wise match
    Equal,
    /// The two arrays are dictionary deltas of each other, meaning the first
    /// is a prefix of the second.
    Delta,
}
1018
1019// Compares two dictionaries and returns a [`DictionaryComparison`].
1020fn compare_dictionaries(old: &ArrayData, new: &ArrayData) -> DictionaryComparison {
1021    // Check for exact match
1022    let existing_len = old.len();
1023    let new_len = new.len();
1024    if existing_len == new_len {
1025        if *old == *new {
1026            return DictionaryComparison::Equal;
1027        } else {
1028            return DictionaryComparison::NotEqual;
1029        }
1030    }
1031
1032    // Can't be a delta if the new is shorter than the existing
1033    if new_len < existing_len {
1034        return DictionaryComparison::NotEqual;
1035    }
1036
1037    // Check for delta
1038    if new.slice(0, existing_len) == *old {
1039        return DictionaryComparison::Delta;
1040    }
1041
1042    DictionaryComparison::NotEqual
1043}
1044
/// Arrow File Writer
///
/// Writes Arrow [`RecordBatch`]es in the [IPC File Format].
///
/// # See Also
///
/// * [`StreamWriter`] for writing IPC Streams
///
/// # Example
/// ```
/// # use arrow_array::record_batch;
/// # use arrow_ipc::writer::FileWriter;
/// # let mut file = vec![]; // mimic a file for the example
/// let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
/// // create a new writer, the schema must be known in advance
/// let mut writer = FileWriter::try_new(&mut file, &batch.schema()).unwrap();
/// // write each batch to the underlying writer
/// writer.write(&batch).unwrap();
/// // When all batches are written, call finish to flush all buffers
/// writer.finish().unwrap();
/// ```
/// [IPC File Format]: https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format
pub struct FileWriter<W> {
    /// The object to write to
    writer: W,
    /// IPC write options
    write_options: IpcWriteOptions,
    /// A reference to the schema, used in validating record batches
    schema: SchemaRef,
    /// Current byte offset into the file (total bytes written so far);
    /// used as the `offset` of each footer block
    block_offsets: usize,
    /// Dictionary blocks that will be written as part of the IPC footer
    dictionary_blocks: Vec<crate::Block>,
    /// Record blocks that will be written as part of the IPC footer
    record_blocks: Vec<crate::Block>,
    /// Whether the writer footer has been written, and the writer is finished
    finished: bool,
    /// Keeps track of dictionaries that have been written
    dictionary_tracker: DictionaryTracker,
    /// User level customized metadata
    custom_metadata: HashMap<String, String>,

    /// Encodes schemas, dictionaries and record batches into IPC messages
    data_gen: IpcDataGenerator,

    /// Compression state passed to the data generator for each batch
    compression_context: CompressionContext,
}
1091
1092impl<W: Write> FileWriter<BufWriter<W>> {
1093    /// Try to create a new file writer with the writer wrapped in a BufWriter.
1094    ///
1095    /// See [`FileWriter::try_new`] for an unbuffered version.
1096    pub fn try_new_buffered(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
1097        Self::try_new(BufWriter::new(writer), schema)
1098    }
1099}
1100
impl<W: Write> FileWriter<W> {
    /// Try to create a new writer, with the schema written as part of the header
    ///
    /// Note the created writer is not buffered. See [`FileWriter::try_new_buffered`] for details.
    ///
    /// # Errors
    ///
    /// An [`Err`](Result::Err) may be returned if writing the header to the writer fails.
    pub fn try_new(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
        let write_options = IpcWriteOptions::default();
        Self::try_new_with_options(writer, schema, write_options)
    }

    /// Try to create a new writer with [`IpcWriteOptions`]
    ///
    /// Note the created writer is not buffered. See [`FileWriter::try_new_buffered`] for details.
    ///
    /// # Errors
    ///
    /// An [`Err`](Result::Err) may be returned if writing the header to the writer fails.
    pub fn try_new_with_options(
        mut writer: W,
        schema: &Schema,
        write_options: IpcWriteOptions,
    ) -> Result<Self, ArrowError> {
        let data_gen = IpcDataGenerator::default();
        // write magic to header aligned on alignment boundary
        let pad_len = pad_to_alignment(write_options.alignment, super::ARROW_MAGIC.len());
        let header_size = super::ARROW_MAGIC.len() + pad_len;
        writer.write_all(&super::ARROW_MAGIC)?;
        writer.write_all(&PADDING[..pad_len])?;
        // write the schema, set the written bytes to the schema + header
        // The file format does not permit dictionary replacement, so the
        // tracker is configured to error if one is attempted.
        let mut dictionary_tracker = DictionaryTracker::new(true);
        let encoded_message = data_gen.schema_to_bytes_with_dictionary_tracker(
            schema,
            &mut dictionary_tracker,
            &write_options,
        );
        let (meta, data) = write_message(&mut writer, encoded_message, &write_options)?;
        Ok(Self {
            writer,
            write_options,
            schema: Arc::new(schema.clone()),
            // subsequent blocks start after the magic, padding and schema message
            block_offsets: meta + data + header_size,
            dictionary_blocks: vec![],
            record_blocks: vec![],
            finished: false,
            dictionary_tracker,
            custom_metadata: HashMap::new(),
            data_gen,
            compression_context: CompressionContext::default(),
        })
    }

    /// Adds a key-value pair to the [FileWriter]'s custom metadata
    ///
    /// The metadata is written into the file footer when [`FileWriter::finish`]
    /// is called. Inserting an existing key overwrites its previous value.
    pub fn write_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
        self.custom_metadata.insert(key.into(), value.into());
    }

    /// Write a record batch to the file
    ///
    /// Any dictionary messages are written before the batch, and the location
    /// of every message is recorded for the footer.
    ///
    /// # Errors
    ///
    /// Returns an error if the writer is already finished, or if encoding or
    /// writing the batch fails.
    pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
        if self.finished {
            return Err(ArrowError::IpcError(
                "Cannot write record batch to file writer as it is closed".to_string(),
            ));
        }

        let (encoded_dictionaries, encoded_message) = self.data_gen.encode(
            batch,
            &mut self.dictionary_tracker,
            &self.write_options,
            &mut self.compression_context,
        )?;

        // Dictionary messages must precede the record batch that references
        // them; record each location for the footer.
        for encoded_dictionary in encoded_dictionaries {
            let (meta, data) =
                write_message(&mut self.writer, encoded_dictionary, &self.write_options)?;

            let block = crate::Block::new(self.block_offsets as i64, meta as i32, data as i64);
            self.dictionary_blocks.push(block);
            self.block_offsets += meta + data;
        }

        let (meta, data) = write_message(&mut self.writer, encoded_message, &self.write_options)?;

        // add a record block for the footer
        let block = crate::Block::new(
            self.block_offsets as i64,
            meta as i32, // the footer `Block` stores the metadata length as i32; TODO: is this still applicable?
            data as i64,
        );
        self.record_blocks.push(block);
        self.block_offsets += meta + data;
        Ok(())
    }

    /// Write footer and closing tag, then mark the writer as done
    pub fn finish(&mut self) -> Result<(), ArrowError> {
        if self.finished {
            return Err(ArrowError::IpcError(
                "Cannot write footer to file writer as it is closed".to_string(),
            ));
        }

        // write EOS
        write_continuation(&mut self.writer, &self.write_options, 0)?;

        let mut fbb = FlatBufferBuilder::new();
        let dictionaries = fbb.create_vector(&self.dictionary_blocks);
        let record_batches = fbb.create_vector(&self.record_blocks);

        // dictionaries are already written, so we can reset dictionary tracker to reuse for schema
        self.dictionary_tracker.clear();
        let schema = IpcSchemaEncoder::new()
            .with_dictionary_tracker(&mut self.dictionary_tracker)
            .schema_to_fb_offset(&mut fbb, &self.schema);
        let fb_custom_metadata = (!self.custom_metadata.is_empty())
            .then(|| crate::convert::metadata_to_fb(&mut fbb, &self.custom_metadata));

        let root = {
            let mut footer_builder = crate::FooterBuilder::new(&mut fbb);
            footer_builder.add_version(self.write_options.metadata_version);
            footer_builder.add_schema(schema);
            footer_builder.add_dictionaries(dictionaries);
            footer_builder.add_recordBatches(record_batches);
            if let Some(fb_custom_metadata) = fb_custom_metadata {
                footer_builder.add_custom_metadata(fb_custom_metadata);
            }
            footer_builder.finish()
        };
        fbb.finish(root, None);
        let footer_data = fbb.finished_data();
        // footer bytes, then the footer length, then the closing magic
        self.writer.write_all(footer_data)?;
        self.writer
            .write_all(&(footer_data.len() as i32).to_le_bytes())?;
        self.writer.write_all(&super::ARROW_MAGIC)?;
        self.writer.flush()?;
        self.finished = true;

        Ok(())
    }

    /// Returns the arrow [`SchemaRef`] for this arrow file.
    pub fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    /// Gets a reference to the underlying writer.
    pub fn get_ref(&self) -> &W {
        &self.writer
    }

    /// Gets a mutable reference to the underlying writer.
    ///
    /// It is inadvisable to directly write to the underlying writer.
    pub fn get_mut(&mut self) -> &mut W {
        &mut self.writer
    }

    /// Flush the underlying writer.
    ///
    /// Both the BufWriter and the underlying writer are flushed.
    pub fn flush(&mut self) -> Result<(), ArrowError> {
        self.writer.flush()?;
        Ok(())
    }

    /// Unwraps the underlying writer.
    ///
    /// The writer is flushed and the FileWriter is finished before returning.
    ///
    /// # Errors
    ///
    /// An [`Err`](Result::Err) may be returned if an error occurs while finishing the FileWriter
    /// or while flushing the writer.
    pub fn into_inner(mut self) -> Result<W, ArrowError> {
        if !self.finished {
            // `finish` flushes the writer.
            self.finish()?;
        }
        Ok(self.writer)
    }
}
1284
impl<W: Write> RecordBatchWriter for FileWriter<W> {
    /// Delegates to the inherent [`FileWriter::write`].
    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
        self.write(batch)
    }

    /// Writes the file footer (via [`FileWriter::finish`]) and consumes the writer.
    fn close(mut self) -> Result<(), ArrowError> {
        self.finish()
    }
}
1294
/// Arrow Stream Writer
///
/// Writes Arrow [`RecordBatch`]es to bytes using the [IPC Streaming Format].
///
/// # See Also
///
/// * [`FileWriter`] for writing IPC Files
///
/// # Example - Basic usage
/// ```
/// # use arrow_array::record_batch;
/// # use arrow_ipc::writer::StreamWriter;
/// # let mut stream = vec![]; // mimic a stream for the example
/// let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
/// // create a new writer, the schema must be known in advance
/// let mut writer = StreamWriter::try_new(&mut stream, &batch.schema()).unwrap();
/// // write each batch to the underlying stream
/// writer.write(&batch).unwrap();
/// // When all batches are written, call finish to flush all buffers
/// writer.finish().unwrap();
/// ```
/// # Example - Efficient delta dictionaries
/// ```
/// # use arrow_array::record_batch;
/// # use arrow_ipc::writer::{StreamWriter, IpcWriteOptions};
/// # use arrow_ipc::writer::DictionaryHandling;
/// # use arrow_schema::{DataType, Field, Schema, SchemaRef};
/// # use arrow_array::{
/// #    builder::StringDictionaryBuilder, types::Int32Type, Array, ArrayRef, DictionaryArray,
/// #    RecordBatch, StringArray,
/// # };
/// # use std::sync::Arc;
///
/// let schema = Arc::new(Schema::new(vec![Field::new(
///    "col1",
///    DataType::Dictionary(Box::from(DataType::Int32), Box::from(DataType::Utf8)),
///    true,
/// )]));
///
/// let mut builder = StringDictionaryBuilder::<arrow_array::types::Int32Type>::new();
///
/// // `finish_preserve_values` will keep the dictionary values along with their
/// // key assignments so that they can be re-used in the next batch.
/// builder.append("a").unwrap();
/// builder.append("b").unwrap();
/// let array1 = builder.finish_preserve_values();
/// let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(array1) as ArrayRef]).unwrap();
///
/// // In this batch, 'a' will have the same dictionary key as 'a' in the previous batch,
/// // and 'd' will take the next available key.
/// builder.append("a").unwrap();
/// builder.append("d").unwrap();
/// let array2 = builder.finish_preserve_values();
/// let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(array2) as ArrayRef]).unwrap();
///
/// let mut stream = vec![];
/// // You must set `.with_dictionary_handling(DictionaryHandling::Delta)` to
/// // enable delta dictionaries in the writer
/// let options = IpcWriteOptions::default().with_dictionary_handling(DictionaryHandling::Delta);
/// let mut writer = StreamWriter::try_new_with_options(&mut stream, &schema, options).unwrap();
///
/// // When writing the first batch, a dictionary message with 'a' and 'b' will be written
/// // prior to the record batch.
/// writer.write(&batch1).unwrap();
/// // With the second batch only a delta dictionary with 'd' will be written
/// // prior to the record batch. This is only possible with `finish_preserve_values`.
/// // Without it, 'a' and 'd' in this batch would have different keys than the
/// // first batch and so we'd have to send a replacement dictionary with new keys
/// // for both.
/// writer.write(&batch2).unwrap();
/// writer.finish().unwrap();
/// ```
/// [IPC Streaming Format]: https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format
pub struct StreamWriter<W> {
    /// The object to write to
    writer: W,
    /// IPC write options
    write_options: IpcWriteOptions,
    /// Whether the writer footer has been written, and the writer is finished
    finished: bool,
    /// Keeps track of dictionaries that have been written
    dictionary_tracker: DictionaryTracker,

    /// Encodes schemas, dictionaries and record batches into IPC messages
    data_gen: IpcDataGenerator,

    /// Compression state passed to the data generator for each batch
    compression_context: CompressionContext,
}
1382
1383impl<W: Write> StreamWriter<BufWriter<W>> {
1384    /// Try to create a new stream writer with the writer wrapped in a BufWriter.
1385    ///
1386    /// See [`StreamWriter::try_new`] for an unbuffered version.
1387    pub fn try_new_buffered(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
1388        Self::try_new(BufWriter::new(writer), schema)
1389    }
1390}
1391
1392impl<W: Write> StreamWriter<W> {
1393    /// Try to create a new writer, with the schema written as part of the header.
1394    ///
1395    /// Note that there is no internal buffering. See also [`StreamWriter::try_new_buffered`].
1396    ///
1397    /// # Errors
1398    ///
1399    /// An ['Err'](Result::Err) may be returned if writing the header to the writer fails.
1400    pub fn try_new(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
1401        let write_options = IpcWriteOptions::default();
1402        Self::try_new_with_options(writer, schema, write_options)
1403    }
1404
1405    /// Try to create a new writer with [`IpcWriteOptions`].
1406    ///
1407    /// # Errors
1408    ///
1409    /// An ['Err'](Result::Err) may be returned if writing the header to the writer fails.
1410    pub fn try_new_with_options(
1411        mut writer: W,
1412        schema: &Schema,
1413        write_options: IpcWriteOptions,
1414    ) -> Result<Self, ArrowError> {
1415        let data_gen = IpcDataGenerator::default();
1416        let mut dictionary_tracker = DictionaryTracker::new(false);
1417
1418        // write the schema, set the written bytes to the schema
1419        let encoded_message = data_gen.schema_to_bytes_with_dictionary_tracker(
1420            schema,
1421            &mut dictionary_tracker,
1422            &write_options,
1423        );
1424        write_message(&mut writer, encoded_message, &write_options)?;
1425        Ok(Self {
1426            writer,
1427            write_options,
1428            finished: false,
1429            dictionary_tracker,
1430            data_gen,
1431            compression_context: CompressionContext::default(),
1432        })
1433    }
1434
1435    /// Write a record batch to the stream
1436    pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
1437        if self.finished {
1438            return Err(ArrowError::IpcError(
1439                "Cannot write record batch to stream writer as it is closed".to_string(),
1440            ));
1441        }
1442
1443        let (encoded_dictionaries, encoded_message) = self
1444            .data_gen
1445            .encode(
1446                batch,
1447                &mut self.dictionary_tracker,
1448                &self.write_options,
1449                &mut self.compression_context,
1450            )
1451            .expect("StreamWriter is configured to not error on dictionary replacement");
1452
1453        for encoded_dictionary in encoded_dictionaries {
1454            write_message(&mut self.writer, encoded_dictionary, &self.write_options)?;
1455        }
1456
1457        write_message(&mut self.writer, encoded_message, &self.write_options)?;
1458        Ok(())
1459    }
1460
1461    /// Write continuation bytes, and mark the stream as done
1462    pub fn finish(&mut self) -> Result<(), ArrowError> {
1463        if self.finished {
1464            return Err(ArrowError::IpcError(
1465                "Cannot write footer to stream writer as it is closed".to_string(),
1466            ));
1467        }
1468
1469        write_continuation(&mut self.writer, &self.write_options, 0)?;
1470
1471        self.finished = true;
1472
1473        Ok(())
1474    }
1475
1476    /// Gets a reference to the underlying writer.
1477    pub fn get_ref(&self) -> &W {
1478        &self.writer
1479    }
1480
1481    /// Gets a mutable reference to the underlying writer.
1482    ///
1483    /// It is inadvisable to directly write to the underlying writer.
1484    pub fn get_mut(&mut self) -> &mut W {
1485        &mut self.writer
1486    }
1487
1488    /// Flush the underlying writer.
1489    ///
1490    /// Both the BufWriter and the underlying writer are flushed.
1491    pub fn flush(&mut self) -> Result<(), ArrowError> {
1492        self.writer.flush()?;
1493        Ok(())
1494    }
1495
1496    /// Unwraps the the underlying writer.
1497    ///
1498    /// The writer is flushed and the StreamWriter is finished before returning.
1499    ///
1500    /// # Errors
1501    ///
1502    /// An ['Err'](Result::Err) may be returned if an error occurs while finishing the StreamWriter
1503    /// or while flushing the writer.
1504    ///
1505    /// # Example
1506    ///
1507    /// ```
1508    /// # use arrow_ipc::writer::{StreamWriter, IpcWriteOptions};
1509    /// # use arrow_ipc::MetadataVersion;
1510    /// # use arrow_schema::{ArrowError, Schema};
1511    /// # fn main() -> Result<(), ArrowError> {
1512    /// // The result we expect from an empty schema
1513    /// let expected = vec![
1514    ///     255, 255, 255, 255,  48,   0,   0,   0,
1515    ///      16,   0,   0,   0,   0,   0,  10,   0,
1516    ///      12,   0,  10,   0,   9,   0,   4,   0,
1517    ///      10,   0,   0,   0,  16,   0,   0,   0,
1518    ///       0,   1,   4,   0,   8,   0,   8,   0,
1519    ///       0,   0,   4,   0,   8,   0,   0,   0,
1520    ///       4,   0,   0,   0,   0,   0,   0,   0,
1521    ///     255, 255, 255, 255,   0,   0,   0,   0
1522    /// ];
1523    ///
1524    /// let schema = Schema::empty();
1525    /// let buffer: Vec<u8> = Vec::new();
1526    /// let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5)?;
1527    /// let stream_writer = StreamWriter::try_new_with_options(buffer, &schema, options)?;
1528    ///
1529    /// assert_eq!(stream_writer.into_inner()?, expected);
1530    /// # Ok(())
1531    /// # }
1532    /// ```
1533    pub fn into_inner(mut self) -> Result<W, ArrowError> {
1534        if !self.finished {
1535            // `finish` flushes.
1536            self.finish()?;
1537        }
1538        Ok(self.writer)
1539    }
1540}
1541
impl<W: Write> RecordBatchWriter for StreamWriter<W> {
    /// Delegates to the inherent [`StreamWriter::write`].
    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
        self.write(batch)
    }

    /// Writes the end-of-stream marker (via [`StreamWriter::finish`]) and consumes the writer.
    fn close(mut self) -> Result<(), ArrowError> {
        self.finish()
    }
}
1551
/// Stores the encoded data, which is a `crate::Message` flatbuffer, and optional Arrow data
pub struct EncodedData {
    /// The flatbuffer-encoded `crate::Message` metadata
    pub ipc_message: Vec<u8>,
    /// Arrow body buffers to be written; should be an empty vec for schema messages
    pub arrow_data: Vec<u8>,
}
/// Write a message's IPC data and buffers, returning metadata and buffer data lengths written
///
/// The layout written is: a length prefix (preceded by a continuation marker
/// unless the legacy IPC format is used), the flatbuffer-encoded message
/// metadata, zero padding up to `write_options.alignment`, and finally the
/// (already aligned) Arrow body buffers.
///
/// Returns `(aligned metadata section length, body length)` in bytes.
pub fn write_message<W: Write>(
    mut writer: W,
    encoded: EncodedData,
    write_options: &IpcWriteOptions,
) -> Result<(usize, usize), ArrowError> {
    let arrow_data_len = encoded.arrow_data.len();
    // The body must already be padded by the encoder; refuse to write a
    // misaligned message rather than corrupt the stream.
    if arrow_data_len % usize::from(write_options.alignment) != 0 {
        return Err(ArrowError::MemoryError(
            "Arrow data not aligned".to_string(),
        ));
    }

    // Alignment mask. NOTE(review): the `& !a` rounding below assumes
    // `alignment` is a power of two — presumably validated when the
    // options are constructed; confirm in `IpcWriteOptions`.
    let a = usize::from(write_options.alignment - 1);
    let buffer = encoded.ipc_message;
    let flatbuf_size = buffer.len();
    // Legacy format: 4-byte length prefix only; otherwise a 4-byte
    // continuation marker plus a 4-byte length (8 bytes total).
    let prefix_size = if write_options.write_legacy_ipc_format {
        4
    } else {
        8
    };
    // Metadata section (prefix + flatbuffer + padding) rounded up to the
    // alignment boundary.
    let aligned_size = (flatbuf_size + prefix_size + a) & !a;
    let padding_bytes = aligned_size - flatbuf_size - prefix_size;

    // The declared length excludes the prefix itself.
    write_continuation(
        &mut writer,
        write_options,
        (aligned_size - prefix_size) as i32,
    )?;

    // write the flatbuf
    if flatbuf_size > 0 {
        writer.write_all(&buffer)?;
    }
    // write padding
    writer.write_all(&PADDING[..padding_bytes])?;

    // write arrow data
    let body_len = if arrow_data_len > 0 {
        write_body_buffers(&mut writer, &encoded.arrow_data, write_options.alignment)?
    } else {
        0
    };

    Ok((aligned_size, body_len))
}
1605
1606fn write_body_buffers<W: Write>(
1607    mut writer: W,
1608    data: &[u8],
1609    alignment: u8,
1610) -> Result<usize, ArrowError> {
1611    let len = data.len();
1612    let pad_len = pad_to_alignment(alignment, len);
1613    let total_len = len + pad_len;
1614
1615    // write body buffer
1616    writer.write_all(data)?;
1617    if pad_len > 0 {
1618        writer.write_all(&PADDING[..pad_len])?;
1619    }
1620
1621    writer.flush()?;
1622    Ok(total_len)
1623}
1624
1625/// Write a record batch to the writer, writing the message size before the message
1626/// if the record batch is being written to a stream
1627fn write_continuation<W: Write>(
1628    mut writer: W,
1629    write_options: &IpcWriteOptions,
1630    total_len: i32,
1631) -> Result<usize, ArrowError> {
1632    let mut written = 8;
1633
1634    // the version of the writer determines whether continuation markers should be added
1635    match write_options.metadata_version {
1636        crate::MetadataVersion::V1 | crate::MetadataVersion::V2 | crate::MetadataVersion::V3 => {
1637            unreachable!("Options with the metadata version cannot be created")
1638        }
1639        crate::MetadataVersion::V4 => {
1640            if !write_options.write_legacy_ipc_format {
1641                // v0.15.0 format
1642                writer.write_all(&CONTINUATION_MARKER)?;
1643                written = 4;
1644            }
1645            writer.write_all(&total_len.to_le_bytes()[..])?;
1646        }
1647        crate::MetadataVersion::V5 => {
1648            // write continuation marker and message length
1649            writer.write_all(&CONTINUATION_MARKER)?;
1650            writer.write_all(&total_len.to_le_bytes()[..])?;
1651        }
1652        z => panic!("Unsupported crate::MetadataVersion {z:?}"),
1653    };
1654
1655    writer.flush()?;
1656
1657    Ok(written)
1658}
1659
1660/// In V4, null types have no validity bitmap
1661/// In V5 and later, null and union types have no validity bitmap
1662/// Run end encoded type has no validity bitmap.
1663fn has_validity_bitmap(data_type: &DataType, write_options: &IpcWriteOptions) -> bool {
1664    if write_options.metadata_version < crate::MetadataVersion::V5 {
1665        !matches!(data_type, DataType::Null)
1666    } else {
1667        !matches!(
1668            data_type,
1669            DataType::Null | DataType::Union(_, _) | DataType::RunEndEncoded(_, _)
1670        )
1671    }
1672}
1673
1674/// Whether to truncate the buffer
1675#[inline]
1676fn buffer_need_truncate(
1677    array_offset: usize,
1678    buffer: &Buffer,
1679    spec: &BufferSpec,
1680    min_length: usize,
1681) -> bool {
1682    spec != &BufferSpec::AlwaysNull && (array_offset != 0 || min_length < buffer.len())
1683}
1684
1685/// Returns byte width for a buffer spec. Only for `BufferSpec::FixedWidth`.
1686#[inline]
1687fn get_buffer_element_width(spec: &BufferSpec) -> usize {
1688    match spec {
1689        BufferSpec::FixedWidth { byte_width, .. } => *byte_width,
1690        _ => 0,
1691    }
1692}
1693
1694/// Common functionality for re-encoding offsets. Returns the new offsets as well as
1695/// original start offset and length for use in slicing child data.
1696fn reencode_offsets<O: OffsetSizeTrait>(
1697    offsets: &Buffer,
1698    data: &ArrayData,
1699) -> (Buffer, usize, usize) {
1700    let offsets_slice: &[O] = offsets.typed_data::<O>();
1701    let offset_slice = &offsets_slice[data.offset()..data.offset() + data.len() + 1];
1702
1703    let start_offset = offset_slice.first().unwrap();
1704    let end_offset = offset_slice.last().unwrap();
1705
1706    let offsets = match start_offset.as_usize() {
1707        0 => {
1708            let size = size_of::<O>();
1709            offsets.slice_with_length(data.offset() * size, (data.len() + 1) * size)
1710        }
1711        _ => offset_slice.iter().map(|x| *x - *start_offset).collect(),
1712    };
1713
1714    let start_offset = start_offset.as_usize();
1715    let end_offset = end_offset.as_usize();
1716
1717    (offsets, start_offset, end_offset - start_offset)
1718}
1719
1720/// Returns the values and offsets [`Buffer`] for a ByteArray with offset type `O`
1721///
1722/// In particular, this handles re-encoding the offsets if they don't start at `0`,
1723/// slicing the values buffer as appropriate. This helps reduce the encoded
1724/// size of sliced arrays, as values that have been sliced away are not encoded
1725fn get_byte_array_buffers<O: OffsetSizeTrait>(data: &ArrayData) -> (Buffer, Buffer) {
1726    if data.is_empty() {
1727        return (MutableBuffer::new(0).into(), MutableBuffer::new(0).into());
1728    }
1729
1730    let (offsets, original_start_offset, len) = reencode_offsets::<O>(&data.buffers()[0], data);
1731    let values = data.buffers()[1].slice_with_length(original_start_offset, len);
1732    (offsets, values)
1733}
1734
1735/// Similar logic as [`get_byte_array_buffers()`] but slices the child array instead
1736/// of a values buffer.
1737fn get_list_array_buffers<O: OffsetSizeTrait>(data: &ArrayData) -> (Buffer, ArrayData) {
1738    if data.is_empty() {
1739        return (
1740            MutableBuffer::new(0).into(),
1741            data.child_data()[0].slice(0, 0),
1742        );
1743    }
1744
1745    let (offsets, original_start_offset, len) = reencode_offsets::<O>(&data.buffers()[0], data);
1746    let child_data = data.child_data()[0].slice(original_start_offset, len);
1747    (offsets, child_data)
1748}
1749
1750/// Returns the offsets, sizes, and child data buffers for a ListView array.
1751///
1752/// Unlike List arrays, ListView arrays store both offsets and sizes explicitly,
1753/// and offsets can be non-monotonic. When slicing, we simply pass through the
1754/// offsets and sizes without re-encoding, and do not slice the child data.
1755fn get_list_view_array_buffers<O: OffsetSizeTrait>(
1756    data: &ArrayData,
1757) -> (Buffer, Buffer, ArrayData) {
1758    if data.is_empty() {
1759        return (
1760            MutableBuffer::new(0).into(),
1761            MutableBuffer::new(0).into(),
1762            data.child_data()[0].slice(0, 0),
1763        );
1764    }
1765
1766    let offsets = &data.buffers()[0];
1767    let sizes = &data.buffers()[1];
1768
1769    let element_size = std::mem::size_of::<O>();
1770    let offsets_slice =
1771        offsets.slice_with_length(data.offset() * element_size, data.len() * element_size);
1772    let sizes_slice =
1773        sizes.slice_with_length(data.offset() * element_size, data.len() * element_size);
1774
1775    let child_data = data.child_data()[0].clone();
1776
1777    (offsets_slice, sizes_slice, child_data)
1778}
1779
1780/// Returns the sliced views [`Buffer`] for a BinaryView/Utf8View array.
1781///
1782/// The views buffer is sliced to only include views in the valid range based on
1783/// the array's offset and length. This helps reduce the encoded size of sliced
1784/// arrays
1785///
1786fn get_or_truncate_buffer(array_data: &ArrayData) -> &[u8] {
1787    let buffer = &array_data.buffers()[0];
1788    let layout = layout(array_data.data_type());
1789    let spec = &layout.buffers[0];
1790
1791    let byte_width = get_buffer_element_width(spec);
1792    let min_length = array_data.len() * byte_width;
1793    if buffer_need_truncate(array_data.offset(), buffer, spec, min_length) {
1794        let byte_offset = array_data.offset() * byte_width;
1795        let buffer_length = min(min_length, buffer.len() - byte_offset);
1796        &buffer.as_slice()[byte_offset..(byte_offset + buffer_length)]
1797    } else {
1798        buffer.as_slice()
1799    }
1800}
1801
/// Write array data to a vector of bytes
///
/// Appends one `crate::FieldNode` for this array to `nodes`, writes the
/// validity bitmap (when the type carries one) and the type-specific value
/// buffers into `arrow_data` (recording each in `buffers`), then recurses
/// into child data. Returns the updated byte offset into `arrow_data`.
///
/// `offset` is the current byte position in `arrow_data`; `num_rows` and
/// `null_count` describe the node being written and may differ from
/// `array_data`'s own values for Null arrays (see below).
#[allow(clippy::too_many_arguments)]
fn write_array_data(
    array_data: &ArrayData,
    buffers: &mut Vec<crate::Buffer>,
    arrow_data: &mut Vec<u8>,
    nodes: &mut Vec<crate::FieldNode>,
    offset: i64,
    num_rows: usize,
    null_count: usize,
    compression_codec: Option<CompressionCodec>,
    compression_context: &mut CompressionContext,
    write_options: &IpcWriteOptions,
) -> Result<i64, ArrowError> {
    let mut offset = offset;
    if !matches!(array_data.data_type(), DataType::Null) {
        nodes.push(crate::FieldNode::new(num_rows as i64, null_count as i64));
    } else {
        // NullArray's null_count equals to len, but the `null_count` passed in is from ArrayData
        // where null_count is always 0.
        nodes.push(crate::FieldNode::new(num_rows as i64, num_rows as i64));
    }
    if has_validity_bitmap(array_data.data_type(), write_options) {
        // write null buffer if exists
        let null_buffer = match array_data.nulls() {
            None => {
                // create a buffer and fill it with valid bits
                let num_bytes = bit_util::ceil(num_rows, 8);
                let buffer = MutableBuffer::new(num_bytes);
                let buffer = buffer.with_bitset(num_bytes, true);
                buffer.into()
            }
            Some(buffer) => buffer.inner().sliced(),
        };

        offset = write_buffer(
            null_buffer.as_slice(),
            buffers,
            arrow_data,
            offset,
            compression_codec,
            compression_context,
            write_options.alignment,
        )?;
    }

    let data_type = array_data.data_type();
    if matches!(data_type, DataType::Binary | DataType::Utf8) {
        // Offsets are re-encoded to start at zero and the values buffer is
        // sliced to match, so sliced-away bytes are not serialized.
        let (offsets, values) = get_byte_array_buffers::<i32>(array_data);
        for buffer in [offsets, values] {
            offset = write_buffer(
                buffer.as_slice(),
                buffers,
                arrow_data,
                offset,
                compression_codec,
                compression_context,
                write_options.alignment,
            )?;
        }
    } else if matches!(data_type, DataType::BinaryView | DataType::Utf8View) {
        // Slicing the views buffer is safe and easy,
        // but pruning unneeded data buffers is much more nuanced since it's complicated to prove that no views reference the pruned buffers
        //
        // Current implementation just serialize the raw arrays as given and not try to optimize anything.
        // If users wants to "compact" the arrays prior to sending them over IPC,
        // they should consider the gc API suggested in #5513
        let views = get_or_truncate_buffer(array_data);
        offset = write_buffer(
            views,
            buffers,
            arrow_data,
            offset,
            compression_codec,
            compression_context,
            write_options.alignment,
        )?;

        // Remaining buffers (the variadic data buffers) are written as given.
        for buffer in array_data.buffers().iter().skip(1) {
            offset = write_buffer(
                buffer.as_slice(),
                buffers,
                arrow_data,
                offset,
                compression_codec,
                compression_context,
                write_options.alignment,
            )?;
        }
    } else if matches!(data_type, DataType::LargeBinary | DataType::LargeUtf8) {
        // Same as Binary/Utf8 but with 64-bit offsets.
        let (offsets, values) = get_byte_array_buffers::<i64>(array_data);
        for buffer in [offsets, values] {
            offset = write_buffer(
                buffer.as_slice(),
                buffers,
                arrow_data,
                offset,
                compression_codec,
                compression_context,
                write_options.alignment,
            )?;
        }
    } else if DataType::is_numeric(data_type)
        || DataType::is_temporal(data_type)
        || matches!(
            array_data.data_type(),
            DataType::FixedSizeBinary(_) | DataType::Dictionary(_, _)
        )
    {
        // Truncate values
        assert_eq!(array_data.buffers().len(), 1);

        let buffer = get_or_truncate_buffer(array_data);
        offset = write_buffer(
            buffer,
            buffers,
            arrow_data,
            offset,
            compression_codec,
            compression_context,
            write_options.alignment,
        )?;
    } else if matches!(data_type, DataType::Boolean) {
        // Bools are special because the payload (= 1 bit) is smaller than the physical container elements (= bytes).
        // The array data may not start at the physical boundary of the underlying buffer, so we need to shift bits around.
        assert_eq!(array_data.buffers().len(), 1);

        let buffer = &array_data.buffers()[0];
        let buffer = buffer.bit_slice(array_data.offset(), array_data.len());
        offset = write_buffer(
            &buffer,
            buffers,
            arrow_data,
            offset,
            compression_codec,
            compression_context,
            write_options.alignment,
        )?;
    } else if matches!(
        data_type,
        DataType::List(_) | DataType::LargeList(_) | DataType::Map(_, _)
    ) {
        assert_eq!(array_data.buffers().len(), 1);
        assert_eq!(array_data.child_data().len(), 1);

        // Truncate offsets and the child data to avoid writing unnecessary data
        let (offsets, sliced_child_data) = match data_type {
            DataType::List(_) => get_list_array_buffers::<i32>(array_data),
            DataType::Map(_, _) => get_list_array_buffers::<i32>(array_data),
            DataType::LargeList(_) => get_list_array_buffers::<i64>(array_data),
            _ => unreachable!(),
        };
        offset = write_buffer(
            offsets.as_slice(),
            buffers,
            arrow_data,
            offset,
            compression_codec,
            compression_context,
            write_options.alignment,
        )?;
        offset = write_array_data(
            &sliced_child_data,
            buffers,
            arrow_data,
            nodes,
            offset,
            sliced_child_data.len(),
            sliced_child_data.null_count(),
            compression_codec,
            compression_context,
            write_options,
        )?;
        // Child data was handled above; skip the generic child recursion below.
        return Ok(offset);
    } else if matches!(
        data_type,
        DataType::ListView(_) | DataType::LargeListView(_)
    ) {
        assert_eq!(array_data.buffers().len(), 2); // offsets + sizes
        assert_eq!(array_data.child_data().len(), 1);

        let (offsets, sizes, child_data) = match data_type {
            DataType::ListView(_) => get_list_view_array_buffers::<i32>(array_data),
            DataType::LargeListView(_) => get_list_view_array_buffers::<i64>(array_data),
            _ => unreachable!(),
        };

        offset = write_buffer(
            offsets.as_slice(),
            buffers,
            arrow_data,
            offset,
            compression_codec,
            compression_context,
            write_options.alignment,
        )?;

        offset = write_buffer(
            sizes.as_slice(),
            buffers,
            arrow_data,
            offset,
            compression_codec,
            compression_context,
            write_options.alignment,
        )?;

        offset = write_array_data(
            &child_data,
            buffers,
            arrow_data,
            nodes,
            offset,
            child_data.len(),
            child_data.null_count(),
            compression_codec,
            compression_context,
            write_options,
        )?;
        // Child data was handled above; skip the generic child recursion below.
        return Ok(offset);
    } else if let DataType::FixedSizeList(_, fixed_size) = data_type {
        assert_eq!(array_data.child_data().len(), 1);
        let fixed_size = *fixed_size as usize;

        // Each logical row maps to `fixed_size` child rows; slice the child
        // accordingly so sliced-away values are not written.
        let child_offset = array_data.offset() * fixed_size;
        let child_length = array_data.len() * fixed_size;
        let child_data = array_data.child_data()[0].slice(child_offset, child_length);

        offset = write_array_data(
            &child_data,
            buffers,
            arrow_data,
            nodes,
            offset,
            child_data.len(),
            child_data.null_count(),
            compression_codec,
            compression_context,
            write_options,
        )?;
        // Child data was handled above; skip the generic child recursion below.
        return Ok(offset);
    } else {
        // Fallback: write all remaining buffers verbatim (e.g. Struct, Union,
        // RunEndEncoded reach this branch).
        for buffer in array_data.buffers() {
            offset = write_buffer(
                buffer,
                buffers,
                arrow_data,
                offset,
                compression_codec,
                compression_context,
                write_options.alignment,
            )?;
        }
    }

    match array_data.data_type() {
        // Dictionary: the values child is intentionally not recursed into here.
        DataType::Dictionary(_, _) => {}
        DataType::RunEndEncoded(_, _) => {
            // unslice the run encoded array.
            let arr = unslice_run_array(array_data.clone())?;
            // recursively write out nested structures
            for data_ref in arr.child_data() {
                // write the nested data (e.g list data)
                offset = write_array_data(
                    data_ref,
                    buffers,
                    arrow_data,
                    nodes,
                    offset,
                    data_ref.len(),
                    data_ref.null_count(),
                    compression_codec,
                    compression_context,
                    write_options,
                )?;
            }
        }
        _ => {
            // recursively write out nested structures
            for data_ref in array_data.child_data() {
                // write the nested data (e.g list data)
                offset = write_array_data(
                    data_ref,
                    buffers,
                    arrow_data,
                    nodes,
                    offset,
                    data_ref.len(),
                    data_ref.null_count(),
                    compression_codec,
                    compression_context,
                    write_options,
                )?;
            }
        }
    }
    Ok(offset)
}
2100
2101/// Write a buffer into `arrow_data`, a vector of bytes, and adds its
2102/// [`crate::Buffer`] to `buffers`. Returns the new offset in `arrow_data`
2103///
2104///
2105/// From <https://github.com/apache/arrow/blob/6a936c4ff5007045e86f65f1a6b6c3c955ad5103/format/Message.fbs#L58>
2106/// Each constituent buffer is first compressed with the indicated
2107/// compressor, and then written with the uncompressed length in the first 8
2108/// bytes as a 64-bit little-endian signed integer followed by the compressed
2109/// buffer bytes (and then padding as required by the protocol). The
2110/// uncompressed length may be set to -1 to indicate that the data that
2111/// follows is not compressed, which can be useful for cases where
2112/// compression does not yield appreciable savings.
2113fn write_buffer(
2114    buffer: &[u8],                    // input
2115    buffers: &mut Vec<crate::Buffer>, // output buffer descriptors
2116    arrow_data: &mut Vec<u8>,         // output stream
2117    offset: i64,                      // current output stream offset
2118    compression_codec: Option<CompressionCodec>,
2119    compression_context: &mut CompressionContext,
2120    alignment: u8,
2121) -> Result<i64, ArrowError> {
2122    let len: i64 = match compression_codec {
2123        Some(compressor) => compressor.compress_to_vec(buffer, arrow_data, compression_context)?,
2124        None => {
2125            arrow_data.extend_from_slice(buffer);
2126            buffer.len()
2127        }
2128    }
2129    .try_into()
2130    .map_err(|e| {
2131        ArrowError::InvalidArgumentError(format!("Could not convert compressed size to i64: {e}"))
2132    })?;
2133
2134    // make new index entry
2135    buffers.push(crate::Buffer::new(offset, len));
2136    // padding and make offset aligned
2137    let pad_len = pad_to_alignment(alignment, len as usize);
2138    arrow_data.extend_from_slice(&PADDING[..pad_len]);
2139
2140    Ok(offset + len + (pad_len as i64))
2141}
2142
/// Zero-filled scratch slab sliced whenever alignment padding must be written.
const PADDING: [u8; 64] = [0; 64];
2144
/// Calculate an alignment boundary and return the number of bytes needed to pad to the alignment boundary
///
/// `alignment` must be a power of two (the mask arithmetic relies on it).
#[inline]
fn pad_to_alignment(alignment: u8, len: usize) -> usize {
    let mask = usize::from(alignment - 1);
    let rounded_up = (len + mask) & !mask;
    rounded_up - len
}
2151
2152#[cfg(test)]
2153mod tests {
2154    use std::hash::Hasher;
2155    use std::io::Cursor;
2156    use std::io::Seek;
2157
2158    use arrow_array::builder::FixedSizeListBuilder;
2159    use arrow_array::builder::Float32Builder;
2160    use arrow_array::builder::Int64Builder;
2161    use arrow_array::builder::MapBuilder;
2162    use arrow_array::builder::StringViewBuilder;
2163    use arrow_array::builder::UnionBuilder;
2164    use arrow_array::builder::{
2165        GenericListBuilder, GenericListViewBuilder, ListBuilder, StringBuilder,
2166    };
2167    use arrow_array::builder::{PrimitiveRunBuilder, UInt32Builder};
2168    use arrow_array::types::*;
2169    use arrow_buffer::ScalarBuffer;
2170
2171    use crate::MetadataVersion;
2172    use crate::convert::fb_to_schema;
2173    use crate::reader::*;
2174    use crate::root_as_footer;
2175
2176    use super::*;
2177
2178    fn serialize_file(rb: &RecordBatch) -> Vec<u8> {
2179        let mut writer = FileWriter::try_new(vec![], rb.schema_ref()).unwrap();
2180        writer.write(rb).unwrap();
2181        writer.finish().unwrap();
2182        writer.into_inner().unwrap()
2183    }
2184
2185    fn deserialize_file(bytes: Vec<u8>) -> RecordBatch {
2186        let mut reader = FileReader::try_new(Cursor::new(bytes), None).unwrap();
2187        reader.next().unwrap().unwrap()
2188    }
2189
2190    fn serialize_stream(record: &RecordBatch) -> Vec<u8> {
2191        // Use 8-byte alignment so that the various `truncate_*` tests can be compactly written,
2192        // without needing to construct a giant array to spill over the 64-byte default alignment
2193        // boundary.
2194        const IPC_ALIGNMENT: usize = 8;
2195
2196        let mut stream_writer = StreamWriter::try_new_with_options(
2197            vec![],
2198            record.schema_ref(),
2199            IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(),
2200        )
2201        .unwrap();
2202        stream_writer.write(record).unwrap();
2203        stream_writer.finish().unwrap();
2204        stream_writer.into_inner().unwrap()
2205    }
2206
2207    fn deserialize_stream(bytes: Vec<u8>) -> RecordBatch {
2208        let mut stream_reader = StreamReader::try_new(Cursor::new(bytes), None).unwrap();
2209        stream_reader.next().unwrap().unwrap()
2210    }
2211
2212    #[test]
2213    #[cfg(feature = "lz4")]
2214    fn test_write_empty_record_batch_lz4_compression() {
2215        let schema = Schema::new(vec![Field::new("field1", DataType::Int32, true)]);
2216        let values: Vec<Option<i32>> = vec![];
2217        let array = Int32Array::from(values);
2218        let record_batch =
2219            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]).unwrap();
2220
2221        let mut file = tempfile::tempfile().unwrap();
2222
2223        {
2224            let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5)
2225                .unwrap()
2226                .try_with_compression(Some(crate::CompressionType::LZ4_FRAME))
2227                .unwrap();
2228
2229            let mut writer =
2230                FileWriter::try_new_with_options(&mut file, &schema, write_option).unwrap();
2231            writer.write(&record_batch).unwrap();
2232            writer.finish().unwrap();
2233        }
2234        file.rewind().unwrap();
2235        {
2236            // read file
2237            let reader = FileReader::try_new(file, None).unwrap();
2238            for read_batch in reader {
2239                read_batch
2240                    .unwrap()
2241                    .columns()
2242                    .iter()
2243                    .zip(record_batch.columns())
2244                    .for_each(|(a, b)| {
2245                        assert_eq!(a.data_type(), b.data_type());
2246                        assert_eq!(a.len(), b.len());
2247                        assert_eq!(a.null_count(), b.null_count());
2248                    });
2249            }
2250        }
2251    }
2252
2253    #[test]
2254    #[cfg(feature = "lz4")]
2255    fn test_write_file_with_lz4_compression() {
2256        let schema = Schema::new(vec![Field::new("field1", DataType::Int32, true)]);
2257        let values: Vec<Option<i32>> = vec![Some(12), Some(1)];
2258        let array = Int32Array::from(values);
2259        let record_batch =
2260            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]).unwrap();
2261
2262        let mut file = tempfile::tempfile().unwrap();
2263        {
2264            let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5)
2265                .unwrap()
2266                .try_with_compression(Some(crate::CompressionType::LZ4_FRAME))
2267                .unwrap();
2268
2269            let mut writer =
2270                FileWriter::try_new_with_options(&mut file, &schema, write_option).unwrap();
2271            writer.write(&record_batch).unwrap();
2272            writer.finish().unwrap();
2273        }
2274        file.rewind().unwrap();
2275        {
2276            // read file
2277            let reader = FileReader::try_new(file, None).unwrap();
2278            for read_batch in reader {
2279                read_batch
2280                    .unwrap()
2281                    .columns()
2282                    .iter()
2283                    .zip(record_batch.columns())
2284                    .for_each(|(a, b)| {
2285                        assert_eq!(a.data_type(), b.data_type());
2286                        assert_eq!(a.len(), b.len());
2287                        assert_eq!(a.null_count(), b.null_count());
2288                    });
2289            }
2290        }
2291    }
2292
2293    #[test]
2294    #[cfg(feature = "zstd")]
2295    fn test_write_file_with_zstd_compression() {
2296        let schema = Schema::new(vec![Field::new("field1", DataType::Int32, true)]);
2297        let values: Vec<Option<i32>> = vec![Some(12), Some(1)];
2298        let array = Int32Array::from(values);
2299        let record_batch =
2300            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]).unwrap();
2301        let mut file = tempfile::tempfile().unwrap();
2302        {
2303            let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5)
2304                .unwrap()
2305                .try_with_compression(Some(crate::CompressionType::ZSTD))
2306                .unwrap();
2307
2308            let mut writer =
2309                FileWriter::try_new_with_options(&mut file, &schema, write_option).unwrap();
2310            writer.write(&record_batch).unwrap();
2311            writer.finish().unwrap();
2312        }
2313        file.rewind().unwrap();
2314        {
2315            // read file
2316            let reader = FileReader::try_new(file, None).unwrap();
2317            for read_batch in reader {
2318                read_batch
2319                    .unwrap()
2320                    .columns()
2321                    .iter()
2322                    .zip(record_batch.columns())
2323                    .for_each(|(a, b)| {
2324                        assert_eq!(a.data_type(), b.data_type());
2325                        assert_eq!(a.len(), b.len());
2326                        assert_eq!(a.null_count(), b.null_count());
2327                    });
2328            }
2329        }
2330    }
2331
2332    #[test]
2333    fn test_write_file() {
2334        let schema = Schema::new(vec![Field::new("field1", DataType::UInt32, true)]);
2335        let values: Vec<Option<u32>> = vec![
2336            Some(999),
2337            None,
2338            Some(235),
2339            Some(123),
2340            None,
2341            None,
2342            None,
2343            None,
2344            None,
2345        ];
2346        let array1 = UInt32Array::from(values);
2347        let batch =
2348            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array1) as ArrayRef])
2349                .unwrap();
2350        let mut file = tempfile::tempfile().unwrap();
2351        {
2352            let mut writer = FileWriter::try_new(&mut file, &schema).unwrap();
2353
2354            writer.write(&batch).unwrap();
2355            writer.finish().unwrap();
2356        }
2357        file.rewind().unwrap();
2358
2359        {
2360            let mut reader = FileReader::try_new(file, None).unwrap();
2361            while let Some(Ok(read_batch)) = reader.next() {
2362                read_batch
2363                    .columns()
2364                    .iter()
2365                    .zip(batch.columns())
2366                    .for_each(|(a, b)| {
2367                        assert_eq!(a.data_type(), b.data_type());
2368                        assert_eq!(a.len(), b.len());
2369                        assert_eq!(a.null_count(), b.null_count());
2370                    });
2371            }
2372        }
2373    }
2374
    /// Round-trips a batch that interleaves `Null` columns with ordinary
    /// columns through an IPC file written with `options`, then verifies each
    /// column's data type, length, and null count are preserved.
    fn write_null_file(options: IpcWriteOptions) {
        let schema = Schema::new(vec![
            Field::new("nulls", DataType::Null, true),
            Field::new("int32s", DataType::Int32, false),
            Field::new("nulls2", DataType::Null, true),
            Field::new("f64s", DataType::Float64, false),
        ]);
        let array1 = NullArray::new(32);
        let array2 = Int32Array::from(vec![1; 32]);
        let array3 = NullArray::new(32);
        // NaN payloads exercise values that are not equal to themselves.
        let array4 = Float64Array::from(vec![f64::NAN; 32]);
        let batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(array1) as ArrayRef,
                Arc::new(array2) as ArrayRef,
                Arc::new(array3) as ArrayRef,
                Arc::new(array4) as ArrayRef,
            ],
        )
        .unwrap();
        let mut file = tempfile::tempfile().unwrap();
        {
            let mut writer = FileWriter::try_new_with_options(&mut file, &schema, options).unwrap();

            writer.write(&batch).unwrap();
            writer.finish().unwrap();
        }

        file.rewind().unwrap();

        {
            let reader = FileReader::try_new(file, None).unwrap();
            reader.for_each(|maybe_batch| {
                maybe_batch
                    .unwrap()
                    .columns()
                    .iter()
                    .zip(batch.columns())
                    .for_each(|(a, b)| {
                        // Compare metadata rather than full values: NaN != NaN,
                        // so equality on the float column would fail.
                        assert_eq!(a.data_type(), b.data_type());
                        assert_eq!(a.len(), b.len());
                        assert_eq!(a.null_count(), b.null_count());
                    });
            });
        }
    }
2422    #[test]
2423    fn test_write_null_file_v4() {
2424        write_null_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V4).unwrap());
2425        write_null_file(IpcWriteOptions::try_new(8, true, MetadataVersion::V4).unwrap());
2426        write_null_file(IpcWriteOptions::try_new(64, false, MetadataVersion::V4).unwrap());
2427        write_null_file(IpcWriteOptions::try_new(64, true, MetadataVersion::V4).unwrap());
2428    }
2429
2430    #[test]
2431    fn test_write_null_file_v5() {
2432        write_null_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap());
2433        write_null_file(IpcWriteOptions::try_new(64, false, MetadataVersion::V5).unwrap());
2434    }
2435
    #[test]
    fn track_union_nested_dict() {
        let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();

        let array = Arc::new(inner) as ArrayRef;

        // Dict field declared with id 0; the encoder assigns its own ids, so
        // the declared id is ignored anyway (see the assertion below).
        #[allow(deprecated)]
        let dctfield = Field::new_dict("dict", array.data_type().clone(), false, 0, false);
        let union_fields = [(0, Arc::new(dctfield))].into_iter().collect();

        let types = [0, 0, 0].into_iter().collect::<ScalarBuffer<i8>>();
        let offsets = [0, 1, 2].into_iter().collect::<ScalarBuffer<i32>>();

        let union = UnionArray::try_new(union_fields, types, Some(offsets), vec![array]).unwrap();

        let schema = Arc::new(Schema::new(vec![Field::new(
            "union",
            union.data_type().clone(),
            false,
        )]));

        let r#gen = IpcDataGenerator::default();
        let mut dict_tracker = DictionaryTracker::new(false);
        r#gen.schema_to_bytes_with_dictionary_tracker(
            &schema,
            &mut dict_tracker,
            &IpcWriteOptions::default(),
        );

        let batch = RecordBatch::try_new(schema, vec![Arc::new(union)]).unwrap();

        r#gen
            .encode(
                &batch,
                &mut dict_tracker,
                &Default::default(),
                &mut Default::default(),
            )
            .unwrap();

        // The encoder will assign dict IDs itself to ensure uniqueness and ignore the dict ID in the schema
        // so we expect the dict will be keyed to 0
        assert!(dict_tracker.written.contains_key(&0));
    }
2481
    #[test]
    fn track_struct_nested_dict() {
        // Verify that a dictionary nested inside a struct column is tracked
        // (and thus written) when the batch is encoded.
        let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();

        let array = Arc::new(inner) as ArrayRef;

        // Dict field with id 2
        #[allow(deprecated)]
        let dctfield = Arc::new(Field::new_dict(
            "dict",
            array.data_type().clone(),
            false,
            2,
            false,
        ));

        let s = StructArray::from(vec![(dctfield, array)]);
        let struct_array = Arc::new(s) as ArrayRef;

        let schema = Arc::new(Schema::new(vec![Field::new(
            "struct",
            struct_array.data_type().clone(),
            false,
        )]));

        let r#gen = IpcDataGenerator::default();
        let mut dict_tracker = DictionaryTracker::new(false);
        r#gen.schema_to_bytes_with_dictionary_tracker(
            &schema,
            &mut dict_tracker,
            &IpcWriteOptions::default(),
        );

        let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap();

        r#gen
            .encode(
                &batch,
                &mut dict_tracker,
                &Default::default(),
                &mut Default::default(),
            )
            .unwrap();

        // The encoder assigns dictionary ids itself, ignoring the id (2)
        // declared in the schema, so the nested dictionary is keyed to 0.
        assert!(dict_tracker.written.contains_key(&0));
    }
2528
    /// Round-trips a sparse-union batch through an IPC file written with
    /// `options`, verifying each column's data type, length, and null count
    /// survive the round trip.
    fn write_union_file(options: IpcWriteOptions) {
        let schema = Schema::new(vec![Field::new_union(
            "union",
            vec![0, 1],
            vec![
                Field::new("a", DataType::Int32, false),
                Field::new("c", DataType::Float64, false),
            ],
            UnionMode::Sparse,
        )]);
        // Interleave values and nulls across both union variants.
        let mut builder = UnionBuilder::with_capacity_sparse(5);
        builder.append::<Int32Type>("a", 1).unwrap();
        builder.append_null::<Int32Type>("a").unwrap();
        builder.append::<Float64Type>("c", 3.0).unwrap();
        builder.append_null::<Float64Type>("c").unwrap();
        builder.append::<Int32Type>("a", 4).unwrap();
        let union = builder.build().unwrap();

        let batch =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(union) as ArrayRef])
                .unwrap();

        let mut file = tempfile::tempfile().unwrap();
        {
            let mut writer = FileWriter::try_new_with_options(&mut file, &schema, options).unwrap();

            writer.write(&batch).unwrap();
            writer.finish().unwrap();
        }
        file.rewind().unwrap();

        {
            let reader = FileReader::try_new(file, None).unwrap();
            reader.for_each(|maybe_batch| {
                maybe_batch
                    .unwrap()
                    .columns()
                    .iter()
                    .zip(batch.columns())
                    .for_each(|(a, b)| {
                        assert_eq!(a.data_type(), b.data_type());
                        assert_eq!(a.len(), b.len());
                        assert_eq!(a.null_count(), b.null_count());
                    });
            });
        }
    }
2576
2577    #[test]
2578    fn test_write_union_file_v4_v5() {
2579        write_union_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V4).unwrap());
2580        write_union_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap());
2581    }
2582
    #[test]
    fn test_write_view_types() {
        // Mix short strings with one long string so both view representations
        // (inline data and data-buffer reference) are exercised.
        const LONG_TEST_STRING: &str =
            "This is a long string to make sure binary view array handles it";
        let schema = Schema::new(vec![
            Field::new("field1", DataType::BinaryView, true),
            Field::new("field2", DataType::Utf8View, true),
        ]);
        let values: Vec<Option<&[u8]>> = vec![
            Some(b"foo"),
            Some(b"bar"),
            Some(LONG_TEST_STRING.as_bytes()),
        ];
        let binary_array = BinaryViewArray::from_iter(values);
        let utf8_array =
            StringViewArray::from_iter(vec![Some("foo"), Some("bar"), Some(LONG_TEST_STRING)]);
        let record_batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![Arc::new(binary_array), Arc::new(utf8_array)],
        )
        .unwrap();

        let mut file = tempfile::tempfile().unwrap();
        {
            let mut writer = FileWriter::try_new(&mut file, &schema).unwrap();
            writer.write(&record_batch).unwrap();
            writer.finish().unwrap();
        }
        file.rewind().unwrap();
        // Read everything back and compare column-by-column.
        {
            let mut reader = FileReader::try_new(&file, None).unwrap();
            let read_batch = reader.next().unwrap().unwrap();
            read_batch
                .columns()
                .iter()
                .zip(record_batch.columns())
                .for_each(|(a, b)| {
                    assert_eq!(a, b);
                });
        }
        file.rewind().unwrap();
        // Read again with a projection selecting only the BinaryView column.
        {
            let mut reader = FileReader::try_new(&file, Some(vec![0])).unwrap();
            let read_batch = reader.next().unwrap().unwrap();
            assert_eq!(read_batch.num_columns(), 1);
            let read_array = read_batch.column(0);
            let write_array = record_batch.column(0);
            assert_eq!(read_array, write_array);
        }
    }
2633
    #[test]
    fn truncate_ipc_record_batch() {
        fn create_batch(rows: usize) -> RecordBatch {
            let schema = Schema::new(vec![
                Field::new("a", DataType::Int32, false),
                Field::new("b", DataType::Utf8, false),
            ]);

            let a = Int32Array::from_iter_values(0..rows as i32);
            let b = StringArray::from_iter_values((0..rows).map(|i| i.to_string()));

            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap()
        }

        let big_record_batch = create_batch(65536);

        let length = 5;
        let small_record_batch = create_batch(length);

        let offset = 2;
        let record_batch_slice = big_record_batch.slice(offset, length);
        // Sanity check: the big batch serializes larger than the small one.
        assert!(
            serialize_stream(&big_record_batch).len() > serialize_stream(&small_record_batch).len()
        );
        // A 5-row slice of the big batch must be truncated on write: its
        // serialized size equals that of a natively built 5-row batch.
        assert_eq!(
            serialize_stream(&small_record_batch).len(),
            serialize_stream(&record_batch_slice).len()
        );

        // And the truncated encoding still round-trips losslessly.
        assert_eq!(
            deserialize_stream(serialize_stream(&record_batch_slice)),
            record_batch_slice
        );
    }
2668
    #[test]
    fn truncate_ipc_record_batch_with_nulls() {
        // Slicing must also truncate/re-encode validity bitmaps so the null
        // pattern of the slice is preserved through serialization.
        fn create_batch() -> RecordBatch {
            let schema = Schema::new(vec![
                Field::new("a", DataType::Int32, true),
                Field::new("b", DataType::Utf8, true),
            ]);

            let a = Int32Array::from(vec![Some(1), None, Some(1), None, Some(1)]);
            let b = StringArray::from(vec![None, Some("a"), Some("a"), None, Some("a")]);

            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap()
        }

        let record_batch = create_batch();
        let record_batch_slice = record_batch.slice(1, 2);
        let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice));

        assert!(
            serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len()
        );

        // The slice covers rows 1..=2: column "a" is [None, Some(1)] and
        // column "b" is [Some("a"), Some("a")].
        assert!(deserialized_batch.column(0).is_null(0));
        assert!(deserialized_batch.column(0).is_valid(1));
        assert!(deserialized_batch.column(1).is_valid(0));
        assert!(deserialized_batch.column(1).is_valid(1));

        assert_eq!(record_batch_slice, deserialized_batch);
    }
2698
    #[test]
    fn truncate_ipc_dictionary_array() {
        // Slicing a dictionary array truncates its keys; nulls in the keys
        // must survive serialization of the slice.
        fn create_batch() -> RecordBatch {
            let values: StringArray = [Some("foo"), Some("bar"), Some("baz")]
                .into_iter()
                .collect();
            let keys: Int32Array = [Some(0), Some(2), None, Some(1)].into_iter().collect();

            let array = DictionaryArray::new(keys, Arc::new(values));

            let schema = Schema::new(vec![Field::new("dict", array.data_type().clone(), true)]);

            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap()
        }

        let record_batch = create_batch();
        let record_batch_slice = record_batch.slice(1, 2);
        let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice));

        assert!(
            serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len()
        );

        // The slice covers keys [Some(2), None].
        assert!(deserialized_batch.column(0).is_valid(0));
        assert!(deserialized_batch.column(0).is_null(1));

        assert_eq!(record_batch_slice, deserialized_batch);
    }
2727
    #[test]
    fn truncate_ipc_struct_array() {
        // Slicing a struct array must truncate each child array and its
        // validity bitmap consistently.
        fn create_batch() -> RecordBatch {
            let strings: StringArray = [Some("foo"), None, Some("bar"), Some("baz")]
                .into_iter()
                .collect();
            let ints: Int32Array = [Some(0), Some(2), None, Some(1)].into_iter().collect();

            let struct_array = StructArray::from(vec![
                (
                    Arc::new(Field::new("s", DataType::Utf8, true)),
                    Arc::new(strings) as ArrayRef,
                ),
                (
                    Arc::new(Field::new("c", DataType::Int32, true)),
                    Arc::new(ints) as ArrayRef,
                ),
            ]);

            let schema = Schema::new(vec![Field::new(
                "struct_array",
                struct_array.data_type().clone(),
                true,
            )]);

            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)]).unwrap()
        }

        let record_batch = create_batch();
        let record_batch_slice = record_batch.slice(1, 2);
        let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice));

        assert!(
            serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len()
        );

        let structs = deserialized_batch
            .column(0)
            .as_any()
            .downcast_ref::<StructArray>()
            .unwrap();

        // Rows 1..=2: "s" is [None, Some("bar")], "c" is [Some(2), None].
        assert!(structs.column(0).is_null(0));
        assert!(structs.column(0).is_valid(1));
        assert!(structs.column(1).is_valid(0));
        assert!(structs.column(1).is_null(1));
        assert_eq!(record_batch_slice, deserialized_batch);
    }
2776
2777    #[test]
2778    fn truncate_ipc_string_array_with_all_empty_string() {
2779        fn create_batch() -> RecordBatch {
2780            let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
2781            let a = StringArray::from(vec![Some(""), Some(""), Some(""), Some(""), Some("")]);
2782            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap()
2783        }
2784
2785        let record_batch = create_batch();
2786        let record_batch_slice = record_batch.slice(0, 1);
2787        let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice));
2788
2789        assert!(
2790            serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len()
2791        );
2792        assert_eq!(record_batch_slice, deserialized_batch);
2793    }
2794
    #[test]
    fn test_stream_writer_writes_array_slice() {
        // Writing a batch whose column is a sliced array (non-zero offset)
        // must serialize only the sliced window, not the full backing buffer.
        let array = UInt32Array::from(vec![Some(1), Some(2), Some(3)]);
        assert_eq!(
            vec![Some(1), Some(2), Some(3)],
            array.iter().collect::<Vec<_>>()
        );

        let sliced = array.slice(1, 2);
        assert_eq!(vec![Some(2), Some(3)], sliced.iter().collect::<Vec<_>>());

        let batch = RecordBatch::try_new(
            Arc::new(Schema::new(vec![Field::new("a", DataType::UInt32, true)])),
            vec![Arc::new(sliced)],
        )
        .expect("new batch");

        let mut writer = StreamWriter::try_new(vec![], batch.schema_ref()).expect("new writer");
        writer.write(&batch).expect("write");
        let outbuf = writer.into_inner().expect("inner");

        // Reading the stream back must yield exactly the sliced values.
        let mut reader = StreamReader::try_new(&outbuf[..], None).expect("new reader");
        let read_batch = reader.next().unwrap().expect("read batch");

        let read_array: &UInt32Array = read_batch.column(0).as_primitive();
        assert_eq!(
            vec![Some(2), Some(3)],
            read_array.iter().collect::<Vec<_>>()
        );
    }
2825
2826    #[test]
2827    fn test_large_slice_uint32() {
2828        ensure_roundtrip(Arc::new(UInt32Array::from_iter(
2829            (0..8000).map(|i| if i % 2 == 0 { Some(i) } else { None }),
2830        )));
2831    }
2832
2833    #[test]
2834    fn test_large_slice_string() {
2835        let strings: Vec<_> = (0..8000)
2836            .map(|i| {
2837                if i % 2 == 0 {
2838                    Some(format!("value{i}"))
2839                } else {
2840                    None
2841                }
2842            })
2843            .collect();
2844
2845        ensure_roundtrip(Arc::new(StringArray::from(strings)));
2846    }
2847
2848    #[test]
2849    fn test_large_slice_string_list() {
2850        let mut ls = ListBuilder::new(StringBuilder::new());
2851
2852        let mut s = String::new();
2853        for row_number in 0..8000 {
2854            if row_number % 2 == 0 {
2855                for list_element in 0..1000 {
2856                    s.clear();
2857                    use std::fmt::Write;
2858                    write!(&mut s, "value{row_number}-{list_element}").unwrap();
2859                    ls.values().append_value(&s);
2860                }
2861                ls.append(true)
2862            } else {
2863                ls.append(false); // null
2864            }
2865        }
2866
2867        ensure_roundtrip(Arc::new(ls.finish()));
2868    }
2869
    #[test]
    fn test_large_slice_string_list_of_lists() {
        // The reason for the special test is to verify reencode_offsets which looks both at
        // the starting offset and the data offset.  So need a dataset where the starting_offset
        // is zero but the data offset is not.
        let mut ls = ListBuilder::new(ListBuilder::new(StringBuilder::new()));

        // Lead with 4000 rows of empty inner lists: these advance the row
        // offsets without contributing any string data.
        for _ in 0..4000 {
            ls.values().append(true);
            ls.append(true)
        }

        let mut s = String::new();
        for row_number in 0..4000 {
            if row_number % 2 == 0 {
                for list_element in 0..1000 {
                    // Reuse one String buffer instead of allocating per value.
                    s.clear();
                    use std::fmt::Write;
                    write!(&mut s, "value{row_number}-{list_element}").unwrap();
                    ls.values().values().append_value(&s);
                }
                ls.values().append(true);
                ls.append(true)
            } else {
                ls.append(false); // null
            }
        }

        ensure_roundtrip(Arc::new(ls.finish()));
    }
2900
    /// Read/write a record batch to a File and Stream and ensure it is the same at the output
    fn ensure_roundtrip(array: ArrayRef) {
        let num_rows = array.len();
        let orig_batch = RecordBatch::try_from_iter(vec![("a", array)]).unwrap();
        // take off the first element
        let sliced_batch = orig_batch.slice(1, num_rows - 1);

        let schema = orig_batch.schema();
        // Round trip through the stream format.
        let stream_data = {
            let mut writer = StreamWriter::try_new(vec![], &schema).unwrap();
            writer.write(&sliced_batch).unwrap();
            writer.into_inner().unwrap()
        };
        let read_batch = {
            let projection = None;
            let mut reader = StreamReader::try_new(Cursor::new(stream_data), projection).unwrap();
            reader
                .next()
                .expect("expect no errors reading batch")
                .expect("expect batch")
        };
        assert_eq!(sliced_batch, read_batch);

        // Round trip through the file format as well.
        let file_data = {
            let mut writer = FileWriter::try_new_buffered(vec![], &schema).unwrap();
            writer.write(&sliced_batch).unwrap();
            writer.into_inner().unwrap().into_inner().unwrap()
        };
        let read_batch = {
            let projection = None;
            let mut reader = FileReader::try_new(Cursor::new(file_data), projection).unwrap();
            reader
                .next()
                .expect("expect no errors reading batch")
                .expect("expect batch")
        };
        assert_eq!(sliced_batch, read_batch);
    }
2941
    #[test]
    fn encode_bools_slice() {
        // Test case for https://github.com/apache/arrow-rs/issues/3496
        assert_bool_roundtrip([true, false], 1, 1);

        // slice somewhere in the middle
        assert_bool_roundtrip(
            [
                true, false, true, true, false, false, true, true, true, false, false, false, true,
                true, true, true, false, false, false, false, true, true, true, true, true, false,
                false, false, false, false,
            ],
            13,
            17,
        );

        // start at byte boundary, end in the middle
        assert_bool_roundtrip(
            [
                true, false, true, true, false, false, true, true, true, false, false, false,
            ],
            8,
            2,
        );

        // start and stop at byte boundaries
        assert_bool_roundtrip(
            [
                true, false, true, true, false, false, true, true, true, false, false, false, true,
                true, true, true, true, false, false, false, false, false,
            ],
            8,
            8,
        );
    }
2977
2978    fn assert_bool_roundtrip<const N: usize>(bools: [bool; N], offset: usize, length: usize) {
2979        let val_bool_field = Field::new("val", DataType::Boolean, false);
2980
2981        let schema = Arc::new(Schema::new(vec![val_bool_field]));
2982
2983        let bools = BooleanArray::from(bools.to_vec());
2984
2985        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(bools)]).unwrap();
2986        let batch = batch.slice(offset, length);
2987
2988        let data = serialize_stream(&batch);
2989        let batch2 = deserialize_stream(data);
2990        assert_eq!(batch, batch2);
2991    }
2992
    #[test]
    fn test_run_array_unslice() {
        let total_len = 80;
        // Cycle values (including nulls) and repeat counts to build an
        // 80-element logical array: 32 runs with lengths cycling [3, 4, 1, 2]
        // (8 full cycles of 10 elements each).
        let vals: Vec<Option<i32>> = vec![Some(1), None, Some(2), Some(3), Some(4), None, Some(5)];
        let repeats: Vec<usize> = vec![3, 4, 1, 2];
        let mut input_array: Vec<Option<i32>> = Vec::with_capacity(total_len);
        for ix in 0_usize..32 {
            let repeat: usize = repeats[ix % repeats.len()];
            let val: Option<i32> = vals[ix % vals.len()];
            input_array.resize(input_array.len() + repeat, val);
        }

        // Encode the input_array to run array
        let mut builder =
            PrimitiveRunBuilder::<Int16Type, Int32Type>::with_capacity(input_array.len());
        builder.extend(input_array.iter().copied());
        let run_array = builder.finish();

        // test for all slice lengths.
        for slice_len in 1..=total_len {
            // test for offset = 0, slice length = slice_len
            let sliced_run_array: RunArray<Int16Type> =
                run_array.slice(0, slice_len).into_data().into();

            // Create unsliced run array.
            let unsliced_run_array = into_zero_offset_run_array(sliced_run_array).unwrap();
            let typed = unsliced_run_array
                .downcast::<PrimitiveArray<Int32Type>>()
                .unwrap();
            let expected: Vec<Option<i32>> = input_array.iter().take(slice_len).copied().collect();
            let actual: Vec<Option<i32>> = typed.into_iter().collect();
            assert_eq!(expected, actual);

            // test for offset = total_len - slice_len, length = slice_len
            let sliced_run_array: RunArray<Int16Type> = run_array
                .slice(total_len - slice_len, slice_len)
                .into_data()
                .into();

            // Create unsliced run array.
            let unsliced_run_array = into_zero_offset_run_array(sliced_run_array).unwrap();
            let typed = unsliced_run_array
                .downcast::<PrimitiveArray<Int32Type>>()
                .unwrap();
            let expected: Vec<Option<i32>> = input_array
                .iter()
                .skip(total_len - slice_len)
                .copied()
                .collect();
            let actual: Vec<Option<i32>> = typed.into_iter().collect();
            assert_eq!(expected, actual);
        }
    }
3046
3047    fn generate_list_data<O: OffsetSizeTrait>() -> GenericListArray<O> {
3048        let mut ls = GenericListBuilder::<O, _>::new(UInt32Builder::new());
3049
3050        for i in 0..100_000 {
3051            for value in [i, i, i] {
3052                ls.values().append_value(value);
3053            }
3054            ls.append(true)
3055        }
3056
3057        ls.finish()
3058    }
3059
3060    fn generate_utf8view_list_data<O: OffsetSizeTrait>() -> GenericListArray<O> {
3061        let mut ls = GenericListBuilder::<O, _>::new(StringViewBuilder::new());
3062
3063        for i in 0..100_000 {
3064            for value in [
3065                format!("value{}", i),
3066                format!("value{}", i),
3067                format!("value{}", i),
3068            ] {
3069                ls.values().append_value(&value);
3070            }
3071            ls.append(true)
3072        }
3073
3074        ls.finish()
3075    }
3076
3077    fn generate_string_list_data<O: OffsetSizeTrait>() -> GenericListArray<O> {
3078        let mut ls = GenericListBuilder::<O, _>::new(StringBuilder::new());
3079
3080        for i in 0..100_000 {
3081            for value in [
3082                format!("value{}", i),
3083                format!("value{}", i),
3084                format!("value{}", i),
3085            ] {
3086                ls.values().append_value(&value);
3087            }
3088            ls.append(true)
3089        }
3090
3091        ls.finish()
3092    }
3093
3094    fn generate_nested_list_data<O: OffsetSizeTrait>() -> GenericListArray<O> {
3095        let mut ls =
3096            GenericListBuilder::<O, _>::new(GenericListBuilder::<O, _>::new(UInt32Builder::new()));
3097
3098        for _i in 0..10_000 {
3099            for j in 0..10 {
3100                for value in [j, j, j, j] {
3101                    ls.values().values().append_value(value);
3102                }
3103                ls.values().append(true)
3104            }
3105            ls.append(true);
3106        }
3107
3108        ls.finish()
3109    }
3110
    /// Builds a list-of-lists array whose first 999 rows are empty inner
    /// lists, so the outer offsets grow while the value data stays empty —
    /// producing a starting offset of zero with a non-zero data offset later.
    fn generate_nested_list_data_starting_at_zero<O: OffsetSizeTrait>() -> GenericListArray<O> {
        let mut ls =
            GenericListBuilder::<O, _>::new(GenericListBuilder::<O, _>::new(UInt32Builder::new()));

        // 999 rows of empty inner lists: no values appended yet.
        for _i in 0..999 {
            ls.values().append(true);
            ls.append(true);
        }

        // Row 1000 is the first row that actually carries values.
        for j in 0..10 {
            for value in [j, j, j, j] {
                ls.values().values().append_value(value);
            }
            ls.values().append(true)
        }
        ls.append(true);

        // The remaining 9000 rows all carry values.
        for i in 0..9_000 {
            for j in 0..10 {
                for value in [i + j, i + j, i + j, i + j] {
                    ls.values().values().append_value(value);
                }
                ls.values().append(true)
            }
            ls.append(true);
        }

        ls.finish()
    }
3140
3141    fn generate_map_array_data() -> MapArray {
3142        let keys_builder = UInt32Builder::new();
3143        let values_builder = UInt32Builder::new();
3144
3145        let mut builder = MapBuilder::new(None, keys_builder, values_builder);
3146
3147        for i in 0..100_000 {
3148            for _j in 0..3 {
3149                builder.keys().append_value(i);
3150                builder.values().append_value(i * 2);
3151            }
3152            builder.append(true).unwrap();
3153        }
3154
3155        builder.finish()
3156    }
3157
    #[test]
    fn reencode_offsets_when_first_offset_is_not_zero() {
        // Each list row holds 3 values, so slicing 7 rows starting at row 75
        // covers child values [225, 246).
        let original_list = generate_list_data::<i32>();
        let original_data = original_list.into_data();
        let slice_data = original_data.slice(75, 7);
        let (new_offsets, original_start, length) =
            reencode_offsets::<i32>(&slice_data.buffers()[0], &slice_data);
        // Offsets are rebased to start at zero...
        assert_eq!(
            vec![0, 3, 6, 9, 12, 15, 18, 21],
            new_offsets.typed_data::<i32>()
        );
        // ...while the returned window still points at the original child data.
        assert_eq!(225, original_start);
        assert_eq!(21, length);
    }
3172
    #[test]
    fn reencode_offsets_when_first_offset_is_zero() {
        let mut ls = GenericListBuilder::<i32, _>::new(UInt32Builder::new());
        // ls = [[], [35, 42]]
        ls.append(true);
        ls.values().append_value(35);
        ls.values().append_value(42);
        ls.append(true);
        let original_list = ls.finish();
        let original_data = original_list.into_data();

        // Slice to the second row: because the first row is empty, the slice's
        // starting offset is already zero.
        let slice_data = original_data.slice(1, 1);
        let (new_offsets, original_start, length) =
            reencode_offsets::<i32>(&slice_data.buffers()[0], &slice_data);
        assert_eq!(vec![0, 2], new_offsets.typed_data::<i32>());
        assert_eq!(0, original_start);
        assert_eq!(2, length);
    }
3191
    /// Ensure when serde full & sliced versions they are equal to original input.
    /// Also ensure serialized sliced version is significantly smaller than serialized full.
    fn roundtrip_ensure_sliced_smaller(in_batch: RecordBatch, expected_size_factor: usize) {
        // test both full and sliced versions
        // (callers pass batches with far more than 1000 rows, so slicing a
        // single row out of the middle is always in range)
        let in_sliced = in_batch.slice(999, 1);

        let bytes_batch = serialize_file(&in_batch);
        let bytes_sliced = serialize_file(&in_sliced);

        // serializing 1 row should be significantly smaller than serializing 100,000
        assert!(bytes_sliced.len() < (bytes_batch.len() / expected_size_factor));

        // ensure both are still valid and equal to originals
        let out_batch = deserialize_file(bytes_batch);
        assert_eq!(in_batch, out_batch);

        let out_sliced = deserialize_file(bytes_sliced);
        assert_eq!(in_sliced, out_sliced);
    }
3211
3212    #[test]
3213    fn encode_lists() {
3214        let val_inner = Field::new_list_field(DataType::UInt32, true);
3215        let val_list_field = Field::new("val", DataType::List(Arc::new(val_inner)), false);
3216        let schema = Arc::new(Schema::new(vec![val_list_field]));
3217
3218        let values = Arc::new(generate_list_data::<i32>());
3219
3220        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
3221        roundtrip_ensure_sliced_smaller(in_batch, 1000);
3222    }
3223
3224    #[test]
3225    fn encode_empty_list() {
3226        let val_inner = Field::new_list_field(DataType::UInt32, true);
3227        let val_list_field = Field::new("val", DataType::List(Arc::new(val_inner)), false);
3228        let schema = Arc::new(Schema::new(vec![val_list_field]));
3229
3230        let values = Arc::new(generate_list_data::<i32>());
3231
3232        let in_batch = RecordBatch::try_new(schema, vec![values])
3233            .unwrap()
3234            .slice(999, 0);
3235        let out_batch = deserialize_file(serialize_file(&in_batch));
3236        assert_eq!(in_batch, out_batch);
3237    }
3238
3239    #[test]
3240    fn encode_large_lists() {
3241        let val_inner = Field::new_list_field(DataType::UInt32, true);
3242        let val_list_field = Field::new("val", DataType::LargeList(Arc::new(val_inner)), false);
3243        let schema = Arc::new(Schema::new(vec![val_list_field]));
3244
3245        let values = Arc::new(generate_list_data::<i64>());
3246
3247        // ensure when serde full & sliced versions they are equal to original input
3248        // also ensure serialized sliced version is significantly smaller than serialized full
3249        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
3250        roundtrip_ensure_sliced_smaller(in_batch, 1000);
3251    }
3252
3253    #[test]
3254    fn encode_large_lists_non_zero_offset() {
3255        let val_inner = Field::new_list_field(DataType::UInt32, true);
3256        let val_list_field = Field::new("val", DataType::LargeList(Arc::new(val_inner)), false);
3257        let schema = Arc::new(Schema::new(vec![val_list_field]));
3258
3259        let values = Arc::new(generate_list_data::<i64>());
3260
3261        check_sliced_list_array(schema, values);
3262    }
3263
3264    #[test]
3265    fn encode_large_lists_string_non_zero_offset() {
3266        let val_inner = Field::new_list_field(DataType::Utf8, true);
3267        let val_list_field = Field::new("val", DataType::LargeList(Arc::new(val_inner)), false);
3268        let schema = Arc::new(Schema::new(vec![val_list_field]));
3269
3270        let values = Arc::new(generate_string_list_data::<i64>());
3271
3272        check_sliced_list_array(schema, values);
3273    }
3274
3275    #[test]
3276    fn encode_large_list_string_view_non_zero_offset() {
3277        let val_inner = Field::new_list_field(DataType::Utf8View, true);
3278        let val_list_field = Field::new("val", DataType::LargeList(Arc::new(val_inner)), false);
3279        let schema = Arc::new(Schema::new(vec![val_list_field]));
3280
3281        let values = Arc::new(generate_utf8view_list_data::<i64>());
3282
3283        check_sliced_list_array(schema, values);
3284    }
3285
3286    fn check_sliced_list_array(schema: Arc<Schema>, values: Arc<GenericListArray<i64>>) {
3287        for (offset, len) in [(999, 1), (0, 13), (47, 12), (values.len() - 13, 13)] {
3288            let in_batch = RecordBatch::try_new(schema.clone(), vec![values.clone()])
3289                .unwrap()
3290                .slice(offset, len);
3291            let out_batch = deserialize_file(serialize_file(&in_batch));
3292            assert_eq!(in_batch, out_batch);
3293        }
3294    }
3295
3296    #[test]
3297    fn encode_nested_lists() {
3298        let inner_int = Arc::new(Field::new_list_field(DataType::UInt32, true));
3299        let inner_list_field = Arc::new(Field::new_list_field(DataType::List(inner_int), true));
3300        let list_field = Field::new("val", DataType::List(inner_list_field), true);
3301        let schema = Arc::new(Schema::new(vec![list_field]));
3302
3303        let values = Arc::new(generate_nested_list_data::<i32>());
3304
3305        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
3306        roundtrip_ensure_sliced_smaller(in_batch, 1000);
3307    }
3308
3309    #[test]
3310    fn encode_nested_lists_starting_at_zero() {
3311        let inner_int = Arc::new(Field::new("item", DataType::UInt32, true));
3312        let inner_list_field = Arc::new(Field::new("item", DataType::List(inner_int), true));
3313        let list_field = Field::new("val", DataType::List(inner_list_field), true);
3314        let schema = Arc::new(Schema::new(vec![list_field]));
3315
3316        let values = Arc::new(generate_nested_list_data_starting_at_zero::<i32>());
3317
3318        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
3319        roundtrip_ensure_sliced_smaller(in_batch, 1);
3320    }
3321
3322    #[test]
3323    fn encode_map_array() {
3324        let keys = Arc::new(Field::new("keys", DataType::UInt32, false));
3325        let values = Arc::new(Field::new("values", DataType::UInt32, true));
3326        let map_field = Field::new_map("map", "entries", keys, values, false, true);
3327        let schema = Arc::new(Schema::new(vec![map_field]));
3328
3329        let values = Arc::new(generate_map_array_data());
3330
3331        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
3332        roundtrip_ensure_sliced_smaller(in_batch, 1000);
3333    }
3334
3335    fn generate_list_view_data<O: OffsetSizeTrait>() -> GenericListViewArray<O> {
3336        let mut builder = GenericListViewBuilder::<O, _>::new(UInt32Builder::new());
3337
3338        for i in 0u32..100_000 {
3339            if i.is_multiple_of(10_000) {
3340                builder.append(false);
3341                continue;
3342            }
3343            for value in [i, i, i] {
3344                builder.values().append_value(value);
3345            }
3346            builder.append(true);
3347        }
3348
3349        builder.finish()
3350    }
3351
3352    #[test]
3353    fn encode_list_view_arrays() {
3354        let val_inner = Field::new_list_field(DataType::UInt32, true);
3355        let val_field = Field::new("val", DataType::ListView(Arc::new(val_inner)), true);
3356        let schema = Arc::new(Schema::new(vec![val_field]));
3357
3358        let values = Arc::new(generate_list_view_data::<i32>());
3359
3360        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
3361        let out_batch = deserialize_file(serialize_file(&in_batch));
3362        assert_eq!(in_batch, out_batch);
3363    }
3364
3365    #[test]
3366    fn encode_large_list_view_arrays() {
3367        let val_inner = Field::new_list_field(DataType::UInt32, true);
3368        let val_field = Field::new("val", DataType::LargeListView(Arc::new(val_inner)), true);
3369        let schema = Arc::new(Schema::new(vec![val_field]));
3370
3371        let values = Arc::new(generate_list_view_data::<i64>());
3372
3373        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
3374        let out_batch = deserialize_file(serialize_file(&in_batch));
3375        assert_eq!(in_batch, out_batch);
3376    }
3377
3378    #[test]
3379    fn check_sliced_list_view_array() {
3380        let inner = Field::new_list_field(DataType::UInt32, true);
3381        let field = Field::new("val", DataType::ListView(Arc::new(inner)), true);
3382        let schema = Arc::new(Schema::new(vec![field]));
3383        let values = Arc::new(generate_list_view_data::<i32>());
3384
3385        for (offset, len) in [(999, 1), (0, 13), (47, 12), (values.len() - 13, 13)] {
3386            let in_batch = RecordBatch::try_new(schema.clone(), vec![values.clone()])
3387                .unwrap()
3388                .slice(offset, len);
3389            let out_batch = deserialize_file(serialize_file(&in_batch));
3390            assert_eq!(in_batch, out_batch);
3391        }
3392    }
3393
3394    #[test]
3395    fn check_sliced_large_list_view_array() {
3396        let inner = Field::new_list_field(DataType::UInt32, true);
3397        let field = Field::new("val", DataType::LargeListView(Arc::new(inner)), true);
3398        let schema = Arc::new(Schema::new(vec![field]));
3399        let values = Arc::new(generate_list_view_data::<i64>());
3400
3401        for (offset, len) in [(999, 1), (0, 13), (47, 12), (values.len() - 13, 13)] {
3402            let in_batch = RecordBatch::try_new(schema.clone(), vec![values.clone()])
3403                .unwrap()
3404                .slice(offset, len);
3405            let out_batch = deserialize_file(serialize_file(&in_batch));
3406            assert_eq!(in_batch, out_batch);
3407        }
3408    }
3409
3410    fn generate_nested_list_view_data<O: OffsetSizeTrait>() -> GenericListViewArray<O> {
3411        let inner_builder = UInt32Builder::new();
3412        let middle_builder = GenericListViewBuilder::<O, _>::new(inner_builder);
3413        let mut outer_builder = GenericListViewBuilder::<O, _>::new(middle_builder);
3414
3415        for i in 0u32..10_000 {
3416            if i.is_multiple_of(1_000) {
3417                outer_builder.append(false);
3418                continue;
3419            }
3420
3421            for _ in 0..3 {
3422                for value in [i, i + 1, i + 2] {
3423                    outer_builder.values().values().append_value(value);
3424                }
3425                outer_builder.values().append(true);
3426            }
3427            outer_builder.append(true);
3428        }
3429
3430        outer_builder.finish()
3431    }
3432
3433    #[test]
3434    fn encode_nested_list_views() {
3435        let inner_int = Arc::new(Field::new_list_field(DataType::UInt32, true));
3436        let inner_list_field = Arc::new(Field::new_list_field(DataType::ListView(inner_int), true));
3437        let list_field = Field::new("val", DataType::ListView(inner_list_field), true);
3438        let schema = Arc::new(Schema::new(vec![list_field]));
3439
3440        let values = Arc::new(generate_nested_list_view_data::<i32>());
3441
3442        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
3443        let out_batch = deserialize_file(serialize_file(&in_batch));
3444        assert_eq!(in_batch, out_batch);
3445    }
3446
    /// Builds a `GenericListViewArray<OffsetSize>` whose child is a
    /// dictionary array, then round-trips it through both the IPC file and
    /// stream formats, asserting the deserialized batch equals the input.
    ///
    /// `offsets` and `sizes` are the raw list-view offset/size buffers for
    /// the 4 list elements; callers supply them in the matching offset width.
    fn test_roundtrip_list_view_of_dict_impl<OffsetSize: OffsetSizeTrait, U: ArrowNativeType>(
        list_data_type: DataType,
        offsets: &[U; 5],
        sizes: &[U; 4],
    ) {
        // Dictionary child: 7 keys into 4 string values (one of them null).
        let values = StringArray::from(vec![Some("alpha"), None, Some("beta"), Some("gamma")]);
        let keys = Int32Array::from_iter_values([0, 0, 1, 2, 3, 0, 2]);
        let dict_array = DictionaryArray::new(keys, Arc::new(values));
        let dict_data = dict_array.to_data();

        let value_offsets = Buffer::from_slice_ref(offsets);
        let value_sizes = Buffer::from_slice_ref(sizes);

        // List-view layout: separate offsets and sizes buffers over the
        // dictionary child; 4 list elements.
        let list_data = ArrayData::builder(list_data_type)
            .len(4)
            .add_buffer(value_offsets)
            .add_buffer(value_sizes)
            .add_child_data(dict_data)
            .build()
            .unwrap();
        let list_view_array = GenericListViewArray::<OffsetSize>::from(list_data);

        let schema = Arc::new(Schema::new(vec![Field::new(
            "f1",
            list_view_array.data_type().clone(),
            false,
        )]));
        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(list_view_array)]).unwrap();

        // Roundtrip via the IPC file format.
        let output_batch = deserialize_file(serialize_file(&input_batch));
        assert_eq!(input_batch, output_batch);

        // Roundtrip via the IPC stream format.
        let output_batch = deserialize_stream(serialize_stream(&input_batch));
        assert_eq!(input_batch, output_batch);
    }
3482
3483    #[test]
3484    fn test_roundtrip_list_view_of_dict() {
3485        #[allow(deprecated)]
3486        let list_data_type = DataType::ListView(Arc::new(Field::new_dict(
3487            "item",
3488            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
3489            true,
3490            1,
3491            false,
3492        )));
3493        let offsets: &[i32; 5] = &[0, 2, 4, 4, 7];
3494        let sizes: &[i32; 4] = &[2, 2, 0, 3];
3495        test_roundtrip_list_view_of_dict_impl::<i32, i32>(list_data_type, offsets, sizes);
3496    }
3497
3498    #[test]
3499    fn test_roundtrip_large_list_view_of_dict() {
3500        #[allow(deprecated)]
3501        let list_data_type = DataType::LargeListView(Arc::new(Field::new_dict(
3502            "item",
3503            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
3504            true,
3505            2,
3506            false,
3507        )));
3508        let offsets: &[i64; 5] = &[0, 2, 4, 4, 7];
3509        let sizes: &[i64; 4] = &[2, 2, 0, 3];
3510        test_roundtrip_list_view_of_dict_impl::<i64, i64>(list_data_type, offsets, sizes);
3511    }
3512
    #[test]
    fn test_roundtrip_sliced_list_view_of_dict() {
        // A *sliced* ListView over a dictionary child must survive both the
        // IPC file and stream formats.
        #[allow(deprecated)]
        let list_data_type = DataType::ListView(Arc::new(Field::new_dict(
            "item",
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            true,
            3,
            false,
        )));

        // Dictionary child: 12 keys into 4 string values (one of them null).
        let values = StringArray::from(vec![Some("alpha"), None, Some("beta"), Some("gamma")]);
        let keys = Int32Array::from_iter_values([0, 0, 1, 2, 3, 0, 2, 1, 0, 3, 2, 1]);
        let dict_array = DictionaryArray::new(keys, Arc::new(values));
        let dict_data = dict_array.to_data();

        // Raw list-view offsets/sizes buffers for the 6 list elements.
        let offsets: &[i32; 7] = &[0, 2, 4, 4, 7, 9, 12];
        let sizes: &[i32; 6] = &[2, 2, 0, 3, 2, 3];
        let value_offsets = Buffer::from_slice_ref(offsets);
        let value_sizes = Buffer::from_slice_ref(sizes);

        let list_data = ArrayData::builder(list_data_type)
            .len(6)
            .add_buffer(value_offsets)
            .add_buffer(value_sizes)
            .add_child_data(dict_data)
            .build()
            .unwrap();
        let list_view_array = GenericListViewArray::<i32>::from(list_data);

        let schema = Arc::new(Schema::new(vec![Field::new(
            "f1",
            list_view_array.data_type().clone(),
            false,
        )]));
        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(list_view_array)]).unwrap();

        // Take an interior slice (rows 1..=4) so offsets are non-trivial.
        let sliced_batch = input_batch.slice(1, 4);

        // Roundtrip via the IPC file format.
        let output_batch = deserialize_file(serialize_file(&sliced_batch));
        assert_eq!(sliced_batch, output_batch);

        // Roundtrip via the IPC stream format.
        let output_batch = deserialize_stream(serialize_stream(&sliced_batch));
        assert_eq!(sliced_batch, output_batch);
    }
3558
    #[test]
    fn test_roundtrip_dense_union_of_dict() {
        // A dense union with a dictionary-encoded child must roundtrip
        // through both the IPC file and stream formats.
        let values = StringArray::from(vec![Some("alpha"), None, Some("beta"), Some("gamma")]);
        let keys = Int32Array::from_iter_values([0, 0, 1, 2, 3, 0, 2]);
        let dict_array = DictionaryArray::new(keys, Arc::new(values));

        #[allow(deprecated)]
        let dict_field = Arc::new(Field::new_dict(
            "dict",
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            true,
            1,
            false,
        ));
        let int_field = Arc::new(Field::new("int", DataType::Int32, false));
        let union_fields = UnionFields::try_new(vec![0, 1], vec![dict_field, int_field]).unwrap();

        // `types` selects the child (type id) per row; `offsets` index into
        // the selected child (dense union layout).
        let types = ScalarBuffer::from(vec![0i8, 0, 1, 0, 1, 0, 0]);
        let offsets = ScalarBuffer::from(vec![0i32, 1, 0, 2, 1, 3, 4]);

        let int_array = Int32Array::from(vec![100, 200]);

        let union = UnionArray::try_new(
            union_fields.clone(),
            types,
            Some(offsets),
            vec![Arc::new(dict_array), Arc::new(int_array)],
        )
        .unwrap();

        let schema = Arc::new(Schema::new(vec![Field::new(
            "union",
            DataType::Union(union_fields, UnionMode::Dense),
            false,
        )]));
        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(union)]).unwrap();

        // Roundtrip via the IPC file format.
        let output_batch = deserialize_file(serialize_file(&input_batch));
        assert_eq!(input_batch, output_batch);

        // Roundtrip via the IPC stream format.
        let output_batch = deserialize_stream(serialize_stream(&input_batch));
        assert_eq!(input_batch, output_batch);
    }
3602
    #[test]
    fn test_roundtrip_sparse_union_of_dict() {
        // A sparse union with a dictionary-encoded child must roundtrip
        // through both the IPC file and stream formats.
        let values = StringArray::from(vec![Some("alpha"), None, Some("beta"), Some("gamma")]);
        let keys = Int32Array::from_iter_values([0, 0, 1, 2, 3, 0, 2]);
        let dict_array = DictionaryArray::new(keys, Arc::new(values));

        #[allow(deprecated)]
        let dict_field = Arc::new(Field::new_dict(
            "dict",
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            true,
            2,
            false,
        ));
        let int_field = Arc::new(Field::new("int", DataType::Int32, false));
        let union_fields = UnionFields::try_new(vec![0, 1], vec![dict_field, int_field]).unwrap();

        // Sparse union: no offsets buffer; every child has the full length.
        let types = ScalarBuffer::from(vec![0i8, 0, 1, 0, 1, 0, 0]);

        let int_array = Int32Array::from(vec![0, 0, 100, 0, 200, 0, 0]);

        let union = UnionArray::try_new(
            union_fields.clone(),
            types,
            None,
            vec![Arc::new(dict_array), Arc::new(int_array)],
        )
        .unwrap();

        let schema = Arc::new(Schema::new(vec![Field::new(
            "union",
            DataType::Union(union_fields, UnionMode::Sparse),
            false,
        )]));
        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(union)]).unwrap();

        // Roundtrip via the IPC file format.
        let output_batch = deserialize_file(serialize_file(&input_batch));
        assert_eq!(input_batch, output_batch);

        // Roundtrip via the IPC stream format.
        let output_batch = deserialize_stream(serialize_stream(&input_batch));
        assert_eq!(input_batch, output_batch);
    }
3645
    #[test]
    fn test_roundtrip_map_with_dict_keys() {
        // A map whose *keys* are dictionary-encoded must roundtrip through
        // both the IPC file and stream formats.
        //
        // Building a map array is a bit involved. We first build a struct
        // array that has a key and value field and then use that to build
        // the actual map array.
        let key_values = StringArray::from(vec!["key_a", "key_b", "key_c"]);
        let keys = Int32Array::from_iter_values([0, 1, 2, 0, 1, 0]);
        let dict_keys = DictionaryArray::new(keys, Arc::new(key_values));

        let values = Int32Array::from(vec![1, 2, 3, 4, 5, 6]);

        #[allow(deprecated)]
        let entries_field = Arc::new(Field::new(
            "entries",
            DataType::Struct(
                vec![
                    Field::new_dict(
                        "key",
                        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                        false,
                        1,
                        false,
                    ),
                    Field::new("value", DataType::Int32, true),
                ]
                .into(),
            ),
            false,
        ));

        // The entries struct pairs each dictionary-encoded key with a value.
        let entries = StructArray::from(vec![
            (
                Arc::new(Field::new(
                    "key",
                    DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                    false,
                )),
                Arc::new(dict_keys) as ArrayRef,
            ),
            (
                Arc::new(Field::new("value", DataType::Int32, true)),
                Arc::new(values) as ArrayRef,
            ),
        ]);

        // Three map entries of two key/value pairs each.
        let offsets = Buffer::from_slice_ref([0i32, 2, 4, 6]);

        let map_data = ArrayData::builder(DataType::Map(entries_field, false))
            .len(3)
            .add_buffer(offsets)
            .add_child_data(entries.into_data())
            .build()
            .unwrap();
        let map_array = MapArray::from(map_data);

        let schema = Arc::new(Schema::new(vec![Field::new(
            "map",
            map_array.data_type().clone(),
            false,
        )]));
        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(map_array)]).unwrap();

        // Roundtrip via the IPC file format.
        let output_batch = deserialize_file(serialize_file(&input_batch));
        assert_eq!(input_batch, output_batch);

        // Roundtrip via the IPC stream format.
        let output_batch = deserialize_stream(serialize_stream(&input_batch));
        assert_eq!(input_batch, output_batch);
    }
3713
    #[test]
    fn test_roundtrip_map_with_dict_values() {
        // A map whose *values* are dictionary-encoded must roundtrip through
        // both the IPC file and stream formats.
        //
        // Building a map array is a bit involved. We first build a struct
        // array that has a key and value field and then use that to build
        // the actual map array.
        let keys = StringArray::from(vec!["a", "b", "c", "d", "e", "f"]);

        let value_values = StringArray::from(vec!["val_x", "val_y", "val_z"]);
        let value_keys = Int32Array::from_iter_values([0, 1, 2, 0, 1, 0]);
        let dict_values = DictionaryArray::new(value_keys, Arc::new(value_values));

        #[allow(deprecated)]
        let entries_field = Arc::new(Field::new(
            "entries",
            DataType::Struct(
                vec![
                    Field::new("key", DataType::Utf8, false),
                    Field::new_dict(
                        "value",
                        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                        true,
                        2,
                        false,
                    ),
                ]
                .into(),
            ),
            false,
        ));

        // The entries struct pairs each plain string key with a
        // dictionary-encoded value.
        let entries = StructArray::from(vec![
            (
                Arc::new(Field::new("key", DataType::Utf8, false)),
                Arc::new(keys) as ArrayRef,
            ),
            (
                Arc::new(Field::new(
                    "value",
                    DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                    true,
                )),
                Arc::new(dict_values) as ArrayRef,
            ),
        ]);

        // Three map entries of two key/value pairs each.
        let offsets = Buffer::from_slice_ref([0i32, 2, 4, 6]);

        let map_data = ArrayData::builder(DataType::Map(entries_field, false))
            .len(3)
            .add_buffer(offsets)
            .add_child_data(entries.into_data())
            .build()
            .unwrap();
        let map_array = MapArray::from(map_data);

        let schema = Arc::new(Schema::new(vec![Field::new(
            "map",
            map_array.data_type().clone(),
            false,
        )]));
        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(map_array)]).unwrap();

        // Roundtrip via the IPC file format.
        let output_batch = deserialize_file(serialize_file(&input_batch));
        assert_eq!(input_batch, output_batch);

        // Roundtrip via the IPC stream format.
        let output_batch = deserialize_stream(serialize_stream(&input_batch));
        assert_eq!(input_batch, output_batch);
    }
3781
    #[test]
    fn test_decimal128_alignment16_is_sufficient() {
        // Verify that writing with 16-byte alignment is enough for zero-copy
        // reads of Decimal128 (a 16-byte-wide type) via `FileDecoder` with
        // `require_alignment` set.
        const IPC_ALIGNMENT: usize = 16;

        // Test a bunch of different dimensions to ensure alignment is never an issue.
        // For example, if we only test `num_cols = 1` then even with alignment 8 this
        // test would _happen_ to pass, even though for different dimensions like
        // `num_cols = 2` it would fail.
        for num_cols in [1, 2, 3, 17, 50, 73, 99] {
            let num_rows = (num_cols * 7 + 11) % 100; // Deterministic swizzle

            let mut fields = Vec::new();
            let mut arrays = Vec::new();
            for i in 0..num_cols {
                let field = Field::new(format!("col_{i}"), DataType::Decimal128(38, 10), true);
                let array = Decimal128Array::from(vec![num_cols as i128; num_rows]);
                fields.push(field);
                arrays.push(Arc::new(array) as Arc<dyn Array>);
            }
            let schema = Schema::new(fields);
            let batch = RecordBatch::try_new(Arc::new(schema), arrays).unwrap();

            let mut writer = FileWriter::try_new_with_options(
                Vec::new(),
                batch.schema_ref(),
                IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(),
            )
            .unwrap();
            writer.write(&batch).unwrap();
            writer.finish().unwrap();

            let out: Vec<u8> = writer.into_inner().unwrap();

            // Parse the file trailer by hand: the last 10 bytes hold the
            // footer length plus the trailing `ARROW1` magic.
            let buffer = Buffer::from_vec(out);
            let trailer_start = buffer.len() - 10;
            let footer_len =
                read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap();
            let footer =
                root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap();

            let schema = fb_to_schema(footer.schema().unwrap());

            // Importantly we set `require_alignment`, checking that 16-byte alignment is sufficient
            // for `read_record_batch` later on to read the data in a zero-copy manner.
            let decoder =
                FileDecoder::new(Arc::new(schema), footer.version()).with_require_alignment(true);

            let batches = footer.recordBatches().unwrap();

            // Decode the single record batch directly from the footer's block.
            let block = batches.get(0);
            let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
            let data = buffer.slice_with_length(block.offset() as _, block_len);

            let batch2 = decoder.read_record_batch(block, &data).unwrap().unwrap();

            assert_eq!(batch, batch2);
        }
    }
3840
    #[test]
    fn test_decimal128_alignment8_is_unaligned() {
        // Counterpart to `test_decimal128_alignment16_is_sufficient`: with
        // only 8-byte alignment, a strict (`require_alignment`) decoder must
        // reject the misaligned Decimal128 buffer with a specific error.
        const IPC_ALIGNMENT: usize = 8;

        let num_cols = 2;
        let num_rows = 1;

        let mut fields = Vec::new();
        let mut arrays = Vec::new();
        for i in 0..num_cols {
            let field = Field::new(format!("col_{i}"), DataType::Decimal128(38, 10), true);
            let array = Decimal128Array::from(vec![num_cols as i128; num_rows]);
            fields.push(field);
            arrays.push(Arc::new(array) as Arc<dyn Array>);
        }
        let schema = Schema::new(fields);
        let batch = RecordBatch::try_new(Arc::new(schema), arrays).unwrap();

        let mut writer = FileWriter::try_new_with_options(
            Vec::new(),
            batch.schema_ref(),
            IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(),
        )
        .unwrap();
        writer.write(&batch).unwrap();
        writer.finish().unwrap();

        let out: Vec<u8> = writer.into_inner().unwrap();

        // Parse the file trailer by hand: the last 10 bytes hold the footer
        // length plus the trailing `ARROW1` magic.
        let buffer = Buffer::from_vec(out);
        let trailer_start = buffer.len() - 10;
        let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap();
        let footer = root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap();
        let schema = fb_to_schema(footer.schema().unwrap());

        // Importantly we set `require_alignment`, otherwise the error later is suppressed due to copying
        // to an aligned buffer in `ArrayDataBuilder.build_aligned`.
        let decoder =
            FileDecoder::new(Arc::new(schema), footer.version()).with_require_alignment(true);

        let batches = footer.recordBatches().unwrap();

        let block = batches.get(0);
        let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
        let data = buffer.slice_with_length(block.offset() as _, block_len);

        let result = decoder.read_record_batch(block, &data);

        // The strict decoder must fail with exactly this misalignment error.
        let error = result.unwrap_err();
        assert_eq!(
            error.to_string(),
            "Invalid argument error: Misaligned buffers[0] in array of type Decimal128(38, 10), \
             offset from expected alignment of 16 by 8"
        );
    }
3896
    #[test]
    fn test_flush() {
        // We write a schema which is small enough to fit into a buffer and not get flushed,
        // and then force the write with .flush().
        let num_cols = 2;
        let mut fields = Vec::new();
        let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap();
        for i in 0..num_cols {
            let field = Field::new(format!("col_{i}"), DataType::Decimal128(38, 10), true);
            fields.push(field);
        }
        let schema = Schema::new(fields);
        // Inner BufWriters are large enough (1024 bytes) that nothing is
        // written through to the Vec until an explicit flush.
        let inner_stream_writer = BufWriter::with_capacity(1024, Vec::new());
        let inner_file_writer = BufWriter::with_capacity(1024, Vec::new());
        let mut stream_writer =
            StreamWriter::try_new_with_options(inner_stream_writer, &schema, options.clone())
                .unwrap();
        let mut file_writer =
            FileWriter::try_new_with_options(inner_file_writer, &schema, options).unwrap();

        // Snapshot bytes reaching the inner Vec before and after flush().
        let stream_bytes_written_on_new = stream_writer.get_ref().get_ref().len();
        let file_bytes_written_on_new = file_writer.get_ref().get_ref().len();
        stream_writer.flush().unwrap();
        file_writer.flush().unwrap();
        let stream_bytes_written_on_flush = stream_writer.get_ref().get_ref().len();
        let file_bytes_written_on_flush = file_writer.get_ref().get_ref().len();
        let stream_out = stream_writer.into_inner().unwrap().into_inner().unwrap();
        // Finishing a stream writes the continuation bytes in MetadataVersion::V5 (4 bytes)
        // and then a length of 0 (4 bytes) for a total of 8 bytes.
        // Everything before that should have been flushed in the .flush() call.
        let expected_stream_flushed_bytes = stream_out.len() - 8;
        // A file write is the same as the stream write except for the leading magic string
        // ARROW1 plus padding, which is 8 bytes.
        let expected_file_flushed_bytes = expected_stream_flushed_bytes + 8;

        assert!(
            stream_bytes_written_on_new < stream_bytes_written_on_flush,
            "this test makes no sense if flush is not actually required"
        );
        assert!(
            file_bytes_written_on_new < file_bytes_written_on_flush,
            "this test makes no sense if flush is not actually required"
        );
        assert_eq!(stream_bytes_written_on_flush, expected_stream_flushed_bytes);
        assert_eq!(file_bytes_written_on_flush, expected_file_flushed_bytes);
    }
3943
3944    #[test]
3945    fn test_roundtrip_list_of_fixed_list() -> Result<(), ArrowError> {
3946        let l1_type =
3947            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, false)), 3);
3948        let l2_type = DataType::List(Arc::new(Field::new("item", l1_type.clone(), false)));
3949
3950        let l0_builder = Float32Builder::new();
3951        let l1_builder = FixedSizeListBuilder::new(l0_builder, 3).with_field(Arc::new(Field::new(
3952            "item",
3953            DataType::Float32,
3954            false,
3955        )));
3956        let mut l2_builder =
3957            ListBuilder::new(l1_builder).with_field(Arc::new(Field::new("item", l1_type, false)));
3958
3959        for point in [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] {
3960            l2_builder.values().values().append_value(point[0]);
3961            l2_builder.values().values().append_value(point[1]);
3962            l2_builder.values().values().append_value(point[2]);
3963
3964            l2_builder.values().append(true);
3965        }
3966        l2_builder.append(true);
3967
3968        let point = [10., 11., 12.];
3969        l2_builder.values().values().append_value(point[0]);
3970        l2_builder.values().values().append_value(point[1]);
3971        l2_builder.values().values().append_value(point[2]);
3972
3973        l2_builder.values().append(true);
3974        l2_builder.append(true);
3975
3976        let array = Arc::new(l2_builder.finish()) as ArrayRef;
3977
3978        let schema = Arc::new(Schema::new_with_metadata(
3979            vec![Field::new("points", l2_type, false)],
3980            HashMap::default(),
3981        ));
3982
3983        // Test a variety of combinations that include 0 and non-zero offsets
3984        // and also portions or the rest of the array
3985        test_slices(&array, &schema, 0, 1)?;
3986        test_slices(&array, &schema, 0, 2)?;
3987        test_slices(&array, &schema, 1, 1)?;
3988
3989        Ok(())
3990    }
3991
3992    #[test]
3993    fn test_roundtrip_list_of_fixed_list_w_nulls() -> Result<(), ArrowError> {
3994        let l0_builder = Float32Builder::new();
3995        let l1_builder = FixedSizeListBuilder::new(l0_builder, 3);
3996        let mut l2_builder = ListBuilder::new(l1_builder);
3997
3998        for point in [
3999            [Some(1.0), Some(2.0), None],
4000            [Some(4.0), Some(5.0), Some(6.0)],
4001            [None, Some(8.0), Some(9.0)],
4002        ] {
4003            for p in point {
4004                match p {
4005                    Some(p) => l2_builder.values().values().append_value(p),
4006                    None => l2_builder.values().values().append_null(),
4007                }
4008            }
4009
4010            l2_builder.values().append(true);
4011        }
4012        l2_builder.append(true);
4013
4014        let point = [Some(10.), None, None];
4015        for p in point {
4016            match p {
4017                Some(p) => l2_builder.values().values().append_value(p),
4018                None => l2_builder.values().values().append_null(),
4019            }
4020        }
4021
4022        l2_builder.values().append(true);
4023        l2_builder.append(true);
4024
4025        let array = Arc::new(l2_builder.finish()) as ArrayRef;
4026
4027        let schema = Arc::new(Schema::new_with_metadata(
4028            vec![Field::new(
4029                "points",
4030                DataType::List(Arc::new(Field::new(
4031                    "item",
4032                    DataType::FixedSizeList(
4033                        Arc::new(Field::new("item", DataType::Float32, true)),
4034                        3,
4035                    ),
4036                    true,
4037                ))),
4038                true,
4039            )],
4040            HashMap::default(),
4041        ));
4042
4043        // Test a variety of combinations that include 0 and non-zero offsets
4044        // and also portions or the rest of the array
4045        test_slices(&array, &schema, 0, 1)?;
4046        test_slices(&array, &schema, 0, 2)?;
4047        test_slices(&array, &schema, 1, 1)?;
4048
4049        Ok(())
4050    }
4051
4052    fn test_slices(
4053        parent_array: &ArrayRef,
4054        schema: &SchemaRef,
4055        offset: usize,
4056        length: usize,
4057    ) -> Result<(), ArrowError> {
4058        let subarray = parent_array.slice(offset, length);
4059        let original_batch = RecordBatch::try_new(schema.clone(), vec![subarray])?;
4060
4061        let mut bytes = Vec::new();
4062        let mut writer = StreamWriter::try_new(&mut bytes, schema)?;
4063        writer.write(&original_batch)?;
4064        writer.finish()?;
4065
4066        let mut cursor = std::io::Cursor::new(bytes);
4067        let mut reader = StreamReader::try_new(&mut cursor, None)?;
4068        let returned_batch = reader.next().unwrap()?;
4069
4070        assert_eq!(original_batch, returned_batch);
4071
4072        Ok(())
4073    }
4074
4075    #[test]
4076    fn test_roundtrip_fixed_list() -> Result<(), ArrowError> {
4077        let int_builder = Int64Builder::new();
4078        let mut fixed_list_builder = FixedSizeListBuilder::new(int_builder, 3)
4079            .with_field(Arc::new(Field::new("item", DataType::Int64, false)));
4080
4081        for point in [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]] {
4082            fixed_list_builder.values().append_value(point[0]);
4083            fixed_list_builder.values().append_value(point[1]);
4084            fixed_list_builder.values().append_value(point[2]);
4085
4086            fixed_list_builder.append(true);
4087        }
4088
4089        let array = Arc::new(fixed_list_builder.finish()) as ArrayRef;
4090
4091        let schema = Arc::new(Schema::new_with_metadata(
4092            vec![Field::new(
4093                "points",
4094                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int64, false)), 3),
4095                false,
4096            )],
4097            HashMap::default(),
4098        ));
4099
4100        // Test a variety of combinations that include 0 and non-zero offsets
4101        // and also portions or the rest of the array
4102        test_slices(&array, &schema, 0, 4)?;
4103        test_slices(&array, &schema, 0, 2)?;
4104        test_slices(&array, &schema, 1, 3)?;
4105        test_slices(&array, &schema, 2, 1)?;
4106
4107        Ok(())
4108    }
4109
4110    #[test]
4111    fn test_roundtrip_fixed_list_w_nulls() -> Result<(), ArrowError> {
4112        let int_builder = Int64Builder::new();
4113        let mut fixed_list_builder = FixedSizeListBuilder::new(int_builder, 3);
4114
4115        for point in [
4116            [Some(1), Some(2), None],
4117            [Some(4), Some(5), Some(6)],
4118            [None, Some(8), Some(9)],
4119            [Some(10), None, None],
4120        ] {
4121            for p in point {
4122                match p {
4123                    Some(p) => fixed_list_builder.values().append_value(p),
4124                    None => fixed_list_builder.values().append_null(),
4125                }
4126            }
4127
4128            fixed_list_builder.append(true);
4129        }
4130
4131        let array = Arc::new(fixed_list_builder.finish()) as ArrayRef;
4132
4133        let schema = Arc::new(Schema::new_with_metadata(
4134            vec![Field::new(
4135                "points",
4136                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int64, true)), 3),
4137                true,
4138            )],
4139            HashMap::default(),
4140        ));
4141
4142        // Test a variety of combinations that include 0 and non-zero offsets
4143        // and also portions or the rest of the array
4144        test_slices(&array, &schema, 0, 4)?;
4145        test_slices(&array, &schema, 0, 2)?;
4146        test_slices(&array, &schema, 1, 3)?;
4147        test_slices(&array, &schema, 2, 1)?;
4148
4149        Ok(())
4150    }
4151
4152    #[test]
4153    fn test_metadata_encoding_ordering() {
4154        fn create_hash() -> u64 {
4155            let metadata: HashMap<String, String> = [
4156                ("a", "1"), //
4157                ("b", "2"), //
4158                ("c", "3"), //
4159                ("d", "4"), //
4160                ("e", "5"), //
4161            ]
4162            .into_iter()
4163            .map(|(k, v)| (k.to_owned(), v.to_owned()))
4164            .collect();
4165
4166            // Set metadata on both the schema and a field within it.
4167            let schema = Arc::new(
4168                Schema::new(vec![
4169                    Field::new("a", DataType::Int64, true).with_metadata(metadata.clone()),
4170                ])
4171                .with_metadata(metadata)
4172                .clone(),
4173            );
4174            let batch = RecordBatch::new_empty(schema.clone());
4175
4176            let mut bytes = Vec::new();
4177            let mut w = StreamWriter::try_new(&mut bytes, batch.schema_ref()).unwrap();
4178            w.write(&batch).unwrap();
4179            w.finish().unwrap();
4180
4181            let mut h = std::hash::DefaultHasher::new();
4182            h.write(&bytes);
4183            h.finish()
4184        }
4185
4186        let expected = create_hash();
4187
4188        // Since there is randomness in the HashMap and we cannot specify our
4189        // own Hasher for the implementation used for metadata, run the above
4190        // code 20x and verify it does not change. This is not perfect but it
4191        // should be good enough.
4192        let all_passed = (0..20).all(|_| create_hash() == expected);
4193        assert!(all_passed);
4194    }
4195
4196    #[test]
4197    fn test_dictionary_tracker_reset() {
4198        let data_gen = IpcDataGenerator::default();
4199        let mut dictionary_tracker = DictionaryTracker::new(false);
4200        let writer_options = IpcWriteOptions::default();
4201        let mut compression_ctx = CompressionContext::default();
4202
4203        let schema = Arc::new(Schema::new(vec![Field::new(
4204            "a",
4205            DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)),
4206            false,
4207        )]));
4208
4209        let mut write_single_batch_stream =
4210            |batch: RecordBatch, dict_tracker: &mut DictionaryTracker| -> Vec<u8> {
4211                let mut buffer = Vec::new();
4212
4213                // create a new IPC stream:
4214                let stream_header = data_gen.schema_to_bytes_with_dictionary_tracker(
4215                    &schema,
4216                    dict_tracker,
4217                    &writer_options,
4218                );
4219                _ = write_message(&mut buffer, stream_header, &writer_options).unwrap();
4220
4221                let (encoded_dicts, encoded_batch) = data_gen
4222                    .encode(&batch, dict_tracker, &writer_options, &mut compression_ctx)
4223                    .unwrap();
4224                for encoded_dict in encoded_dicts {
4225                    _ = write_message(&mut buffer, encoded_dict, &writer_options).unwrap();
4226                }
4227                _ = write_message(&mut buffer, encoded_batch, &writer_options).unwrap();
4228
4229                buffer
4230            };
4231
4232        let batch1 = RecordBatch::try_new(
4233            schema.clone(),
4234            vec![Arc::new(DictionaryArray::new(
4235                UInt8Array::from_iter_values([0]),
4236                Arc::new(StringArray::from_iter_values(["a"])),
4237            ))],
4238        )
4239        .unwrap();
4240        let buffer = write_single_batch_stream(batch1.clone(), &mut dictionary_tracker);
4241
4242        // ensure we can read the stream back
4243        let mut reader = StreamReader::try_new(Cursor::new(buffer), None).unwrap();
4244        let read_batch = reader.next().unwrap().unwrap();
4245        assert_eq!(read_batch, batch1);
4246
4247        // reset the dictionary tracker so it can be used for next stream
4248        dictionary_tracker.clear();
4249
4250        // now write a 2nd stream and ensure we can also read it:
4251        let batch2 = RecordBatch::try_new(
4252            schema.clone(),
4253            vec![Arc::new(DictionaryArray::new(
4254                UInt8Array::from_iter_values([0]),
4255                Arc::new(StringArray::from_iter_values(["a"])),
4256            ))],
4257        )
4258        .unwrap();
4259        let buffer = write_single_batch_stream(batch2.clone(), &mut dictionary_tracker);
4260        let mut reader = StreamReader::try_new(Cursor::new(buffer), None).unwrap();
4261        let read_batch = reader.next().unwrap().unwrap();
4262        assert_eq!(read_batch, batch2);
4263    }
4264}