// File: arrow_ipc/writer.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Arrow IPC File and Stream Writers
19//!
20//! # Notes
21//!
22//! [`FileWriter`] and [`StreamWriter`] have similar interfaces,
//! however the [`FileWriter`] expects a writer that supports [`Seek`]ing
24//!
25//! [`Seek`]: std::io::Seek
26
27use std::cmp::min;
28use std::collections::HashMap;
29use std::io::{BufWriter, Write};
30use std::mem::size_of;
31use std::sync::Arc;
32
33use flatbuffers::FlatBufferBuilder;
34
35use arrow_array::builder::BufferBuilder;
36use arrow_array::cast::*;
37use arrow_array::types::{Int16Type, Int32Type, Int64Type, RunEndIndexType};
38use arrow_array::*;
39use arrow_buffer::bit_util;
40use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer};
41use arrow_data::{ArrayData, ArrayDataBuilder, BufferSpec, layout};
42use arrow_schema::*;
43
44use crate::CONTINUATION_MARKER;
45use crate::compression::CompressionCodec;
46pub use crate::compression::CompressionContext;
47use crate::convert::IpcSchemaEncoder;
48
/// IPC write options used to control the behaviour of the [`IpcDataGenerator`]
#[derive(Debug, Clone)]
pub struct IpcWriteOptions {
    /// Write padding after memory buffers to this multiple of bytes.
    /// Must be 8, 16, 32, or 64 - defaults to 64.
    alignment: u8,
    /// The legacy format is for releases before 0.15.0, and uses metadata V4
    write_legacy_ipc_format: bool,
    /// The metadata version to write. The Rust IPC writer supports V4+
    ///
    /// *Default versions per crate*
    ///
    /// When creating the default IpcWriteOptions, the following metadata versions are used:
    ///
    /// version 2.0.0: V4, with legacy format enabled
    /// version 4.0.0: V5
    metadata_version: crate::MetadataVersion,
    /// Compression, if desired. Will result in a runtime error
    /// if the corresponding feature is not enabled
    batch_compression_type: Option<crate::CompressionType>,
    /// How to handle updating dictionaries in IPC messages
    ///
    /// See [`DictionaryHandling`] for the available strategies.
    dictionary_handling: DictionaryHandling,
}
72
73impl IpcWriteOptions {
74    /// Configures compression when writing IPC files.
75    ///
76    /// Will result in a runtime error if the corresponding feature
77    /// is not enabled
78    pub fn try_with_compression(
79        mut self,
80        batch_compression_type: Option<crate::CompressionType>,
81    ) -> Result<Self, ArrowError> {
82        self.batch_compression_type = batch_compression_type;
83
84        if self.batch_compression_type.is_some()
85            && self.metadata_version < crate::MetadataVersion::V5
86        {
87            return Err(ArrowError::InvalidArgumentError(
88                "Compression only supported in metadata v5 and above".to_string(),
89            ));
90        }
91        Ok(self)
92    }
93    /// Try to create IpcWriteOptions, checking for incompatible settings
94    pub fn try_new(
95        alignment: usize,
96        write_legacy_ipc_format: bool,
97        metadata_version: crate::MetadataVersion,
98    ) -> Result<Self, ArrowError> {
99        let is_alignment_valid =
100            alignment == 8 || alignment == 16 || alignment == 32 || alignment == 64;
101        if !is_alignment_valid {
102            return Err(ArrowError::InvalidArgumentError(
103                "Alignment should be 8, 16, 32, or 64.".to_string(),
104            ));
105        }
106        let alignment: u8 = u8::try_from(alignment).expect("range already checked");
107        match metadata_version {
108            crate::MetadataVersion::V1
109            | crate::MetadataVersion::V2
110            | crate::MetadataVersion::V3 => Err(ArrowError::InvalidArgumentError(
111                "Writing IPC metadata version 3 and lower not supported".to_string(),
112            )),
113            #[allow(deprecated)]
114            crate::MetadataVersion::V4 => Ok(Self {
115                alignment,
116                write_legacy_ipc_format,
117                metadata_version,
118                batch_compression_type: None,
119                dictionary_handling: DictionaryHandling::default(),
120            }),
121            crate::MetadataVersion::V5 => {
122                if write_legacy_ipc_format {
123                    Err(ArrowError::InvalidArgumentError(
124                        "Legacy IPC format only supported on metadata version 4".to_string(),
125                    ))
126                } else {
127                    Ok(Self {
128                        alignment,
129                        write_legacy_ipc_format,
130                        metadata_version,
131                        batch_compression_type: None,
132                        dictionary_handling: DictionaryHandling::default(),
133                    })
134                }
135            }
136            z => Err(ArrowError::InvalidArgumentError(format!(
137                "Unsupported crate::MetadataVersion {z:?}"
138            ))),
139        }
140    }
141
142    /// Configure how dictionaries are handled in IPC messages
143    pub fn with_dictionary_handling(mut self, dictionary_handling: DictionaryHandling) -> Self {
144        self.dictionary_handling = dictionary_handling;
145        self
146    }
147}
148
149impl Default for IpcWriteOptions {
150    fn default() -> Self {
151        Self {
152            alignment: 64,
153            write_legacy_ipc_format: false,
154            metadata_version: crate::MetadataVersion::V5,
155            batch_compression_type: None,
156            dictionary_handling: DictionaryHandling::default(),
157        }
158    }
159}
160
#[derive(Debug, Default)]
/// Handles low level details of encoding [`Array`] and [`Schema`] into the
/// [Arrow IPC Format].
///
/// # Example
/// ```
/// # fn run() {
/// # use std::sync::Arc;
/// # use arrow_array::UInt64Array;
/// # use arrow_array::RecordBatch;
/// # use arrow_ipc::writer::{CompressionContext, DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
///
/// // Create a record batch
/// let batch = RecordBatch::try_from_iter(vec![
///  ("col2", Arc::new(UInt64Array::from_iter([10, 23, 33])) as _)
/// ]).unwrap();
///
/// // Error if dictionary ids are replaced.
/// let error_on_replacement = true;
/// let options = IpcWriteOptions::default();
/// let mut dictionary_tracker = DictionaryTracker::new(error_on_replacement);
///
/// let mut compression_context = CompressionContext::default();
///
/// // encode the batch into zero or more encoded dictionaries
/// // and the data for the actual array.
/// let data_gen = IpcDataGenerator::default();
/// let (encoded_dictionaries, encoded_message) = data_gen
///   .encode(&batch, &mut dictionary_tracker, &options, &mut compression_context)
///   .unwrap();
/// # }
/// ```
///
/// [Arrow IPC Format]: https://arrow.apache.org/docs/format/Columnar.html#serialization-and-interprocess-communication-ipc
pub struct IpcDataGenerator {}
196
impl IpcDataGenerator {
    /// Converts a schema to an IPC message along with `dictionary_tracker`
    /// and returns it encoded inside [EncodedData] as a flatbuffer.
    pub fn schema_to_bytes_with_dictionary_tracker(
        &self,
        schema: &Schema,
        dictionary_tracker: &mut DictionaryTracker,
        write_options: &IpcWriteOptions,
    ) -> EncodedData {
        let mut fbb = FlatBufferBuilder::new();
        let schema = {
            // Encode the schema as a flatbuffer, registering dictionary fields
            // with the tracker along the way.
            let fb = IpcSchemaEncoder::new()
                .with_dictionary_tracker(dictionary_tracker)
                .schema_to_fb_offset(&mut fbb, schema);
            fb.as_union_value()
        };

        let mut message = crate::MessageBuilder::new(&mut fbb);
        message.add_version(write_options.metadata_version);
        message.add_header_type(crate::MessageHeader::Schema);
        // A schema message carries no body bytes, only the flatbuffer header.
        message.add_bodyLength(0);
        message.add_header(schema);
        // TODO: custom metadata
        let data = message.finish();
        fbb.finish(data, None);

        let data = fbb.finished_data();
        EncodedData {
            ipc_message: data.to_vec(),
            arrow_data: vec![],
        }
    }

    /// Recursively encodes dictionaries found in the children of `column` for
    /// the nested types visible below (structs, run-end encoded arrays, lists,
    /// list views, fixed-size lists, maps, unions). Dictionary-typed columns
    /// themselves are handled by [`Self::encode_dictionaries`], which calls
    /// back into this method for the dictionary's values.
    fn _encode_dictionaries<I: Iterator<Item = i64>>(
        &self,
        column: &ArrayRef,
        encoded_dictionaries: &mut Vec<EncodedData>,
        dictionary_tracker: &mut DictionaryTracker,
        write_options: &IpcWriteOptions,
        dict_id: &mut I,
        compression_context: &mut CompressionContext,
    ) -> Result<(), ArrowError> {
        match column.data_type() {
            DataType::Struct(fields) => {
                let s = as_struct_array(column);
                // Recurse into every struct child alongside its field metadata.
                for (field, column) in fields.iter().zip(s.columns()) {
                    self.encode_dictionaries(
                        field,
                        column,
                        encoded_dictionaries,
                        dictionary_tracker,
                        write_options,
                        dict_id,
                        compression_context,
                    )?;
                }
            }
            DataType::RunEndEncoded(_, values) => {
                let data = column.to_data();
                if data.child_data().len() != 2 {
                    return Err(ArrowError::InvalidArgumentError(format!(
                        "The run encoded array should have exactly two child arrays. Found {}",
                        data.child_data().len()
                    )));
                }
                // The run_ends array is not expected to be dictionary encoded. Hence encode dictionaries
                // only for values array.
                let values_array = make_array(data.child_data()[1].clone());
                self.encode_dictionaries(
                    values,
                    &values_array,
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                    compression_context,
                )?;
            }
            DataType::List(field) => {
                let list = as_list_array(column);
                self.encode_dictionaries(
                    field,
                    list.values(),
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                    compression_context,
                )?;
            }
            DataType::LargeList(field) => {
                let list = as_large_list_array(column);
                self.encode_dictionaries(
                    field,
                    list.values(),
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                    compression_context,
                )?;
            }
            DataType::ListView(field) => {
                let list = column.as_list_view::<i32>();
                self.encode_dictionaries(
                    field,
                    list.values(),
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                    compression_context,
                )?;
            }
            DataType::LargeListView(field) => {
                let list = column.as_list_view::<i64>();
                self.encode_dictionaries(
                    field,
                    list.values(),
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                    compression_context,
                )?;
            }
            DataType::FixedSizeList(field, _) => {
                let list = column
                    .as_any()
                    .downcast_ref::<FixedSizeListArray>()
                    .expect("Unable to downcast to fixed size list array");
                self.encode_dictionaries(
                    field,
                    list.values(),
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                    compression_context,
                )?;
            }
            DataType::Map(field, _) => {
                let map_array = as_map_array(column);

                // A map's entries field must be a two-field struct (key, value).
                let (keys, values) = match field.data_type() {
                    DataType::Struct(fields) if fields.len() == 2 => (&fields[0], &fields[1]),
                    _ => panic!("Incorrect field data type {:?}", field.data_type()),
                };

                // keys
                self.encode_dictionaries(
                    keys,
                    map_array.keys(),
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                    compression_context,
                )?;

                // values
                self.encode_dictionaries(
                    values,
                    map_array.values(),
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                    compression_context,
                )?;
            }
            DataType::Union(fields, _) => {
                let union = as_union_array(column);
                for (type_id, field) in fields.iter() {
                    let column = union.child(type_id);
                    self.encode_dictionaries(
                        field,
                        column,
                        encoded_dictionaries,
                        dictionary_tracker,
                        write_options,
                        dict_id,
                        compression_context,
                    )?;
                }
            }
            // Non-nested, non-dictionary types carry no dictionaries to encode.
            _ => (),
        }

        Ok(())
    }

    /// Encodes any dictionaries in `column` (including those nested inside
    /// children) into `encoded_dictionaries`, consuming ids from `dict_id_seq`
    /// in depth-first order. Non-dictionary columns are delegated to
    /// [`Self::_encode_dictionaries`].
    #[allow(clippy::too_many_arguments)]
    fn encode_dictionaries<I: Iterator<Item = i64>>(
        &self,
        field: &Field,
        column: &ArrayRef,
        encoded_dictionaries: &mut Vec<EncodedData>,
        dictionary_tracker: &mut DictionaryTracker,
        write_options: &IpcWriteOptions,
        dict_id_seq: &mut I,
        compression_context: &mut CompressionContext,
    ) -> Result<(), ArrowError> {
        match column.data_type() {
            DataType::Dictionary(_key_type, _value_type) => {
                let dict_data = column.to_data();
                let dict_values = &dict_data.child_data()[0];

                let values = make_array(dict_data.child_data()[0].clone());

                // Encode dictionaries nested inside this dictionary's values first.
                self._encode_dictionaries(
                    &values,
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id_seq,
                    compression_context,
                )?;

                // It's important to only take the dict_id at this point, because the dict ID
                // sequence is assigned depth-first, so we need to first encode children and have
                // them take their assigned dict IDs before we take the dict ID for this field.
                let dict_id = dict_id_seq.next().ok_or_else(|| {
                    ArrowError::IpcError(format!(
                        "no dict id for field {:?}: field.data_type={:?}, column.data_type={:?}",
                        field.name(),
                        field.data_type(),
                        column.data_type()
                    ))
                })?;

                match dictionary_tracker.insert_column(
                    dict_id,
                    column,
                    write_options.dictionary_handling,
                )? {
                    // Dictionary unchanged: nothing to emit.
                    DictionaryUpdate::None => {}
                    // New or replaced: emit the full dictionary (isDelta = false).
                    DictionaryUpdate::New | DictionaryUpdate::Replaced => {
                        encoded_dictionaries.push(self.dictionary_batch_to_bytes(
                            dict_id,
                            dict_values,
                            write_options,
                            false,
                            compression_context,
                        )?);
                    }
                    // Delta: emit only the new values (isDelta = true).
                    DictionaryUpdate::Delta(data) => {
                        encoded_dictionaries.push(self.dictionary_batch_to_bytes(
                            dict_id,
                            &data,
                            write_options,
                            true,
                            compression_context,
                        )?);
                    }
                }
            }
            _ => self._encode_dictionaries(
                column,
                encoded_dictionaries,
                dictionary_tracker,
                write_options,
                dict_id_seq,
                compression_context,
            )?,
        }

        Ok(())
    }

    /// Encodes a batch to a number of [EncodedData] items (dictionary batches + the record batch).
    /// The [DictionaryTracker] keeps track of dictionaries with new `dict_id`s  (so they are only sent once)
    /// Make sure the [DictionaryTracker] is initialized at the start of the stream.
    pub fn encode(
        &self,
        batch: &RecordBatch,
        dictionary_tracker: &mut DictionaryTracker,
        write_options: &IpcWriteOptions,
        compression_context: &mut CompressionContext,
    ) -> Result<(Vec<EncodedData>, EncodedData), ArrowError> {
        let schema = batch.schema();
        let mut encoded_dictionaries = Vec::with_capacity(schema.flattened_fields().len());

        // Iterate over the dict id sequence assigned during schema encoding.
        let mut dict_id = dictionary_tracker.dict_ids.clone().into_iter();

        for (i, field) in schema.fields().iter().enumerate() {
            let column = batch.column(i);
            self.encode_dictionaries(
                field,
                column,
                &mut encoded_dictionaries,
                dictionary_tracker,
                write_options,
                &mut dict_id,
                compression_context,
            )?;
        }

        let encoded_message =
            self.record_batch_to_bytes(batch, write_options, compression_context)?;
        Ok((encoded_dictionaries, encoded_message))
    }

    /// Encodes a batch to a number of [EncodedData] items (dictionary batches + the record batch).
    /// The [DictionaryTracker] keeps track of dictionaries with new `dict_id`s  (so they are only sent once)
    /// Make sure the [DictionaryTracker] is initialized at the start of the stream.
    #[deprecated(since = "57.0.0", note = "Use `encode` instead")]
    pub fn encoded_batch(
        &self,
        batch: &RecordBatch,
        dictionary_tracker: &mut DictionaryTracker,
        write_options: &IpcWriteOptions,
    ) -> Result<(Vec<EncodedData>, EncodedData), ArrowError> {
        // Delegates to `encode` with a fresh default CompressionContext.
        self.encode(
            batch,
            dictionary_tracker,
            write_options,
            &mut Default::default(),
        )
    }

    /// Write a `RecordBatch` into two sets of bytes, one for the header (crate::Message) and the
    /// other for the batch's data
    fn record_batch_to_bytes(
        &self,
        batch: &RecordBatch,
        write_options: &IpcWriteOptions,
        compression_context: &mut CompressionContext,
    ) -> Result<EncodedData, ArrowError> {
        let mut fbb = FlatBufferBuilder::new();

        let mut nodes: Vec<crate::FieldNode> = vec![];
        let mut buffers: Vec<crate::Buffer> = vec![];
        let mut arrow_data: Vec<u8> = vec![];
        let mut offset = 0;

        // get the type of compression
        let batch_compression_type = write_options.batch_compression_type;

        let compression = batch_compression_type.map(|batch_compression_type| {
            let mut c = crate::BodyCompressionBuilder::new(&mut fbb);
            c.add_method(crate::BodyCompressionMethod::BUFFER);
            c.add_codec(batch_compression_type);
            c.finish()
        });

        let compression_codec: Option<CompressionCodec> =
            batch_compression_type.map(TryInto::try_into).transpose()?;

        let mut variadic_buffer_counts = vec![];

        // Serialize every column's buffers into `arrow_data`, tracking the
        // running body offset for the flatbuffer buffer descriptors.
        for array in batch.columns() {
            let array_data = array.to_data();
            offset = write_array_data(
                &array_data,
                &mut buffers,
                &mut arrow_data,
                &mut nodes,
                offset,
                array.len(),
                array.null_count(),
                compression_codec,
                compression_context,
                write_options,
            )?;

            append_variadic_buffer_counts(&mut variadic_buffer_counts, &array_data);
        }
        // pad the tail of body data
        let len = arrow_data.len();
        let pad_len = pad_to_alignment(write_options.alignment, len);
        arrow_data.extend_from_slice(&PADDING[..pad_len]);

        // write data
        let buffers = fbb.create_vector(&buffers);
        let nodes = fbb.create_vector(&nodes);
        let variadic_buffer = if variadic_buffer_counts.is_empty() {
            None
        } else {
            Some(fbb.create_vector(&variadic_buffer_counts))
        };

        let root = {
            let mut batch_builder = crate::RecordBatchBuilder::new(&mut fbb);
            batch_builder.add_length(batch.num_rows() as i64);
            batch_builder.add_nodes(nodes);
            batch_builder.add_buffers(buffers);
            if let Some(c) = compression {
                batch_builder.add_compression(c);
            }

            if let Some(v) = variadic_buffer {
                batch_builder.add_variadicBufferCounts(v);
            }
            let b = batch_builder.finish();
            b.as_union_value()
        };
        // create a crate::Message
        let mut message = crate::MessageBuilder::new(&mut fbb);
        message.add_version(write_options.metadata_version);
        message.add_header_type(crate::MessageHeader::RecordBatch);
        message.add_bodyLength(arrow_data.len() as i64);
        message.add_header(root);
        let root = message.finish();
        fbb.finish(root, None);
        let finished_data = fbb.finished_data();

        Ok(EncodedData {
            ipc_message: finished_data.to_vec(),
            arrow_data,
        })
    }

    /// Write dictionary values into two sets of bytes, one for the header (crate::Message) and the
    /// other for the data
    fn dictionary_batch_to_bytes(
        &self,
        dict_id: i64,
        array_data: &ArrayData,
        write_options: &IpcWriteOptions,
        is_delta: bool,
        compression_context: &mut CompressionContext,
    ) -> Result<EncodedData, ArrowError> {
        let mut fbb = FlatBufferBuilder::new();

        let mut nodes: Vec<crate::FieldNode> = vec![];
        let mut buffers: Vec<crate::Buffer> = vec![];
        let mut arrow_data: Vec<u8> = vec![];

        // get the type of compression
        let batch_compression_type = write_options.batch_compression_type;

        let compression = batch_compression_type.map(|batch_compression_type| {
            let mut c = crate::BodyCompressionBuilder::new(&mut fbb);
            c.add_method(crate::BodyCompressionMethod::BUFFER);
            c.add_codec(batch_compression_type);
            c.finish()
        });

        let compression_codec: Option<CompressionCodec> = batch_compression_type
            .map(|batch_compression_type| batch_compression_type.try_into())
            .transpose()?;

        write_array_data(
            array_data,
            &mut buffers,
            &mut arrow_data,
            &mut nodes,
            0,
            array_data.len(),
            array_data.null_count(),
            compression_codec,
            compression_context,
            write_options,
        )?;

        let mut variadic_buffer_counts = vec![];
        append_variadic_buffer_counts(&mut variadic_buffer_counts, array_data);

        // pad the tail of body data
        let len = arrow_data.len();
        let pad_len = pad_to_alignment(write_options.alignment, len);
        arrow_data.extend_from_slice(&PADDING[..pad_len]);

        // write data
        let buffers = fbb.create_vector(&buffers);
        let nodes = fbb.create_vector(&nodes);
        let variadic_buffer = if variadic_buffer_counts.is_empty() {
            None
        } else {
            Some(fbb.create_vector(&variadic_buffer_counts))
        };

        // The dictionary values are serialized as a RecordBatch wrapped in a
        // DictionaryBatch message.
        let root = {
            let mut batch_builder = crate::RecordBatchBuilder::new(&mut fbb);
            batch_builder.add_length(array_data.len() as i64);
            batch_builder.add_nodes(nodes);
            batch_builder.add_buffers(buffers);
            if let Some(c) = compression {
                batch_builder.add_compression(c);
            }
            if let Some(v) = variadic_buffer {
                batch_builder.add_variadicBufferCounts(v);
            }
            batch_builder.finish()
        };

        let root = {
            let mut batch_builder = crate::DictionaryBatchBuilder::new(&mut fbb);
            batch_builder.add_id(dict_id);
            batch_builder.add_data(root);
            batch_builder.add_isDelta(is_delta);
            batch_builder.finish().as_union_value()
        };

        let root = {
            let mut message_builder = crate::MessageBuilder::new(&mut fbb);
            message_builder.add_version(write_options.metadata_version);
            message_builder.add_header_type(crate::MessageHeader::DictionaryBatch);
            message_builder.add_bodyLength(arrow_data.len() as i64);
            message_builder.add_header(root);
            message_builder.finish()
        };

        fbb.finish(root, None);
        let finished_data = fbb.finished_data();

        Ok(EncodedData {
            ipc_message: finished_data.to_vec(),
            arrow_data,
        })
    }
}
710
711fn append_variadic_buffer_counts(counts: &mut Vec<i64>, array: &ArrayData) {
712    match array.data_type() {
713        DataType::BinaryView | DataType::Utf8View => {
714            // The spec documents the counts only includes the variadic buffers, not the view/null buffers.
715            // https://arrow.apache.org/docs/format/Columnar.html#variadic-buffers
716            counts.push(array.buffers().len() as i64 - 1);
717        }
718        DataType::Dictionary(_, _) => {
719            // Do nothing
720            // Dictionary types are handled in `encode_dictionaries`.
721        }
722        _ => {
723            for child in array.child_data() {
724                append_variadic_buffer_counts(counts, child)
725            }
726        }
727    }
728}
729
730pub(crate) fn unslice_run_array(arr: ArrayData) -> Result<ArrayData, ArrowError> {
731    match arr.data_type() {
732        DataType::RunEndEncoded(k, _) => match k.data_type() {
733            DataType::Int16 => {
734                Ok(into_zero_offset_run_array(RunArray::<Int16Type>::from(arr))?.into_data())
735            }
736            DataType::Int32 => {
737                Ok(into_zero_offset_run_array(RunArray::<Int32Type>::from(arr))?.into_data())
738            }
739            DataType::Int64 => {
740                Ok(into_zero_offset_run_array(RunArray::<Int64Type>::from(arr))?.into_data())
741            }
742            d => unreachable!("Unexpected data type {d}"),
743        },
744        d => Err(ArrowError::InvalidArgumentError(format!(
745            "The given array is not a run array. Data type of given array: {d}"
746        ))),
747    }
748}
749
/// Returns a `RunArray` with zero offset and length matching the last value
/// in run_ends array.
fn into_zero_offset_run_array<R: RunEndIndexType>(
    run_array: RunArray<R>,
) -> Result<RunArray<R>, ArrowError> {
    let run_ends = run_array.run_ends();
    // Fast path: the array is not sliced, return it unchanged.
    if run_ends.offset() == 0 && run_ends.max_value() == run_ends.len() {
        return Ok(run_array);
    }

    // The physical index of original run_ends array from which the `ArrayData` is sliced.
    let start_physical_index = run_ends.get_start_physical_index();

    // The physical index of original run_ends array until which the `ArrayData` is sliced.
    let end_physical_index = run_ends.get_end_physical_index();

    // Number of physical runs covered by the slice, inclusive of the run at
    // `end_physical_index`.
    let physical_length = end_physical_index - start_physical_index + 1;

    // build new run_ends array by subtracting offset from run ends.
    let offset = R::Native::usize_as(run_ends.offset());
    let mut builder = BufferBuilder::<R::Native>::new(physical_length);
    for run_end_value in &run_ends.values()[start_physical_index..end_physical_index] {
        builder.append(run_end_value.sub_wrapping(offset));
    }
    // The final run end is the logical length of the sliced array, replacing
    // the (possibly larger) original value at `end_physical_index`.
    builder.append(R::Native::from_usize(run_array.len()).unwrap());
    let new_run_ends = unsafe {
        // Safety:
        // The function builds a valid run_ends array and hence need not be validated.
        ArrayDataBuilder::new(R::DATA_TYPE)
            .len(physical_length)
            .add_buffer(builder.finish())
            .build_unchecked()
    };

    // build new values by slicing physical indices.
    let new_values = run_array
        .values()
        .slice(start_physical_index, physical_length)
        .into_data();

    let builder = ArrayDataBuilder::new(run_array.data_type().clone())
        .len(run_array.len())
        .add_child_data(new_run_ends)
        .add_child_data(new_values);
    let array_data = unsafe {
        // Safety:
        //  This function builds a valid run array and hence can skip validation.
        builder.build_unchecked()
    };
    Ok(array_data.into())
}
801
/// Controls how dictionaries are handled in Arrow IPC messages
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DictionaryHandling {
    /// Send the entire dictionary every time it is encountered (default)
    ///
    /// Note that replacement dictionaries are rejected by trackers created
    /// with `error_on_replacement` (as used by the IPC file format, which
    /// supports only a single dictionary per field across all batches).
    #[default]
    Resend,
    /// Send only new dictionary values since the last batch (delta encoding)
    ///
    /// When a dictionary is first encountered, the entire dictionary is sent.
    /// For subsequent batches, only values that are new (not previously sent)
    /// are transmitted with the `isDelta` flag set to true.
    Delta,
}
815
/// Describes what kind of update took place after a call to [`DictionaryTracker::insert`].
#[derive(Debug, Clone)]
pub enum DictionaryUpdate {
    /// No dictionary was written, the dictionary was identical to what was already
    /// in the tracker.
    None,
    /// No dictionary was present in the tracker
    New,
    /// Dictionary was replaced with the new data
    Replaced,
    /// Dictionary was updated; the contained [`ArrayData`] holds only the
    /// newly appended values (the suffix beyond the previously written
    /// dictionary), suitable for writing as an `isDelta` dictionary message.
    Delta(ArrayData),
}
829
/// Keeps track of dictionaries that have been written, to avoid emitting the same dictionary
/// multiple times.
///
/// Can optionally error if an update to an existing dictionary is attempted, which
/// isn't allowed in the `FileWriter`.
#[derive(Debug)]
pub struct DictionaryTracker {
    // NOTE: When adding fields, update the clear() method accordingly.
    /// Dictionary `ArrayData` most recently written, keyed by dictionary ID.
    written: HashMap<i64, ArrayData>,
    /// Dictionary IDs handed out so far, in the order they were assigned
    /// (i.e. schema traversal order).
    dict_ids: Vec<i64>,
    /// If true, replacing an already-written dictionary is an error
    /// (required by the IPC file format).
    error_on_replacement: bool,
}
842
impl DictionaryTracker {
    /// Create a new [`DictionaryTracker`].
    ///
    /// If `error_on_replacement`
    /// is true, an error will be generated if an update to an
    /// existing dictionary is attempted.
    pub fn new(error_on_replacement: bool) -> Self {
        #[allow(deprecated)]
        Self {
            written: HashMap::new(),
            dict_ids: Vec::new(),
            error_on_replacement,
        }
    }

    /// Record and return the next dictionary ID.
    ///
    /// IDs are sequential, starting at 0; each call appends the new ID to the
    /// internal list returned by [`Self::dict_id`].
    pub fn next_dict_id(&mut self) -> i64 {
        let next = self
            .dict_ids
            .last()
            .copied()
            .map(|i| i + 1)
            .unwrap_or_default();

        self.dict_ids.push(next);
        next
    }

    /// Return the sequence of dictionary IDs in the order they should be observed while
    /// traversing the schema
    // NOTE(review): despite the singular name, this returns all recorded IDs;
    // the `&mut self` receiver is kept for API compatibility even though the
    // method only reads.
    pub fn dict_id(&mut self) -> &[i64] {
        &self.dict_ids
    }

    /// Keep track of the dictionary with the given ID and values. Behavior:
    ///
    /// * If this ID has been written already and has the same data, return `Ok(false)` to indicate
    ///   that the dictionary was not actually inserted (because it's already been seen).
    /// * If this ID has been written already but with different data, and this tracker is
    ///   configured to return an error, return an error.
    /// * If the tracker has not been configured to error on replacement or this dictionary
    ///   has never been seen before, return `Ok(true)` to indicate that the dictionary was just
    ///   inserted.
    #[deprecated(since = "56.1.0", note = "Use `insert_column` instead")]
    pub fn insert(&mut self, dict_id: i64, column: &ArrayRef) -> Result<bool, ArrowError> {
        let dict_data = column.to_data();
        let dict_values = &dict_data.child_data()[0];

        // If a dictionary with this id was already emitted, check if it was the same.
        if let Some(last) = self.written.get(&dict_id) {
            // Cheap pointer comparison first; avoids an element-wise compare
            // when the caller reuses the same underlying buffers.
            if ArrayData::ptr_eq(&last.child_data()[0], dict_values) {
                // Same dictionary values => no need to emit it again
                return Ok(false);
            }
            if self.error_on_replacement {
                // If error on replacement perform a logical comparison
                if last.child_data()[0] == *dict_values {
                    // Same dictionary values => no need to emit it again
                    return Ok(false);
                }
                return Err(ArrowError::InvalidArgumentError(
                    "Dictionary replacement detected when writing IPC file format. \
                     Arrow IPC files only support a single dictionary for a given field \
                     across all batches."
                        .to_string(),
                ));
            }
        }

        self.written.insert(dict_id, dict_data);
        Ok(true)
    }

    /// Keep track of the dictionary with the given ID and values. The return
    /// value indicates what, if any, update to the internal map took place
    /// and how it should be interpreted based on the `dict_handling` parameter.
    ///
    /// # Returns
    ///
    /// * `Ok(Dictionary::New)` - If the dictionary was not previously written
    /// * `Ok(Dictionary::Replaced)` - If the dictionary was previously written
    ///   with completely different data, or if the data is a delta of the existing,
    ///   but with `dict_handling` set to `DictionaryHandling::Resend`
    /// * `Ok(Dictionary::Delta)` - If the dictionary was previously written, but
    ///   the new data is a delta of the old and the `dict_handling` is set to
    ///   `DictionaryHandling::Delta`
    /// * `Err(e)` - If the dictionary was previously written with different data,
    ///   and `error_on_replacement` is set to `true`.
    pub fn insert_column(
        &mut self,
        dict_id: i64,
        column: &ArrayRef,
        dict_handling: DictionaryHandling,
    ) -> Result<DictionaryUpdate, ArrowError> {
        let new_data = column.to_data();
        let new_values = &new_data.child_data()[0];

        // If there is no existing dictionary with this ID, we always insert
        let Some(old) = self.written.get(&dict_id) else {
            self.written.insert(dict_id, new_data);
            return Ok(DictionaryUpdate::New);
        };

        // Fast path - If the array data points to the same buffer as the
        // existing then they're the same.
        let old_values = &old.child_data()[0];
        if ArrayData::ptr_eq(old_values, new_values) {
            return Ok(DictionaryUpdate::None);
        }

        // Slow path - Compare the dictionaries value by value
        let comparison = compare_dictionaries(old_values, new_values);
        if matches!(comparison, DictionaryComparison::Equal) {
            return Ok(DictionaryUpdate::None);
        }

        const REPLACEMENT_ERROR: &str = "Dictionary replacement detected when writing IPC file format. \
                 Arrow IPC files only support a single dictionary for a given field \
                 across all batches.";

        match comparison {
            DictionaryComparison::NotEqual => {
                // Completely different data: a replacement, which the file
                // format forbids.
                if self.error_on_replacement {
                    return Err(ArrowError::InvalidArgumentError(
                        REPLACEMENT_ERROR.to_string(),
                    ));
                }

                self.written.insert(dict_id, new_data);
                Ok(DictionaryUpdate::Replaced)
            }
            DictionaryComparison::Delta => match dict_handling {
                DictionaryHandling::Resend => {
                    // A delta under Resend semantics is treated as a full
                    // replacement, so it is still an error for file writers.
                    if self.error_on_replacement {
                        return Err(ArrowError::InvalidArgumentError(
                            REPLACEMENT_ERROR.to_string(),
                        ));
                    }

                    self.written.insert(dict_id, new_data);
                    Ok(DictionaryUpdate::Replaced)
                }
                DictionaryHandling::Delta => {
                    // Only the newly appended suffix is emitted; the full new
                    // dictionary is stored so future deltas compare correctly.
                    let delta =
                        new_values.slice(old_values.len(), new_values.len() - old_values.len());
                    self.written.insert(dict_id, new_data);
                    Ok(DictionaryUpdate::Delta(delta))
                }
            },
            DictionaryComparison::Equal => unreachable!("Already checked equal case"),
        }
    }

    /// Clears the state of the dictionary tracker.
    ///
    /// This allows the dictionary tracker to be reused for a new IPC stream while avoiding the
    /// allocation cost of creating a new instance. This method should not be called if
    /// the dictionary tracker will be used to continue writing to an existing IPC stream.
    pub fn clear(&mut self) {
        self.dict_ids.clear();
        self.written.clear();
    }
}
1006
/// Describes how two dictionary arrays compare to each other.
///
/// Produced by [`compare_dictionaries`] and consumed by
/// [`DictionaryTracker::insert_column`] to decide whether to skip, replace,
/// or delta-encode a dictionary.
#[derive(Debug, Clone)]
enum DictionaryComparison {
    /// Neither a delta, nor an exact match
    NotEqual,
    /// Exact element-wise match
    Equal,
    /// The two arrays are dictionary deltas of each other, meaning the first
    /// is a prefix of the second.
    Delta,
}
1018
1019// Compares two dictionaries and returns a [`DictionaryComparison`].
1020fn compare_dictionaries(old: &ArrayData, new: &ArrayData) -> DictionaryComparison {
1021    // Check for exact match
1022    let existing_len = old.len();
1023    let new_len = new.len();
1024    if existing_len == new_len {
1025        if *old == *new {
1026            return DictionaryComparison::Equal;
1027        } else {
1028            return DictionaryComparison::NotEqual;
1029        }
1030    }
1031
1032    // Can't be a delta if the new is shorter than the existing
1033    if new_len < existing_len {
1034        return DictionaryComparison::NotEqual;
1035    }
1036
1037    // Check for delta
1038    if new.slice(0, existing_len) == *old {
1039        return DictionaryComparison::Delta;
1040    }
1041
1042    DictionaryComparison::NotEqual
1043}
1044
/// Arrow File Writer
///
/// Writes Arrow [`RecordBatch`]es in the [IPC File Format].
///
/// # See Also
///
/// * [`StreamWriter`] for writing IPC Streams
///
/// # Example
/// ```
/// # use arrow_array::record_batch;
/// # use arrow_ipc::writer::FileWriter;
/// # let mut file = vec![]; // mimic a file for the example
/// let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
/// // create a new writer, the schema must be known in advance
/// let mut writer = FileWriter::try_new(&mut file, &batch.schema()).unwrap();
/// // write each batch to the underlying writer
/// writer.write(&batch).unwrap();
/// // When all batches are written, call finish to flush all buffers
/// writer.finish().unwrap();
/// ```
/// [IPC File Format]: https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format
pub struct FileWriter<W> {
    /// The object to write to
    writer: W,
    /// IPC write options
    write_options: IpcWriteOptions,
    /// A reference to the schema, used in validating record batches
    schema: SchemaRef,
    /// Running byte offset into the file at which the next message will start;
    /// recorded in the footer `Block`s to enable random access
    block_offsets: usize,
    /// Dictionary blocks that will be written as part of the IPC footer
    dictionary_blocks: Vec<crate::Block>,
    /// Record blocks that will be written as part of the IPC footer
    record_blocks: Vec<crate::Block>,
    /// Whether the writer footer has been written, and the writer is finished
    finished: bool,
    /// Keeps track of dictionaries that have been written
    dictionary_tracker: DictionaryTracker,
    /// User level customized metadata
    custom_metadata: HashMap<String, String>,

    /// Encodes schemas, dictionaries and record batches into IPC messages
    data_gen: IpcDataGenerator,

    /// Reusable state for (optional) body buffer compression
    compression_context: CompressionContext,
}
1091
1092impl<W: Write> FileWriter<BufWriter<W>> {
1093    /// Try to create a new file writer with the writer wrapped in a BufWriter.
1094    ///
1095    /// See [`FileWriter::try_new`] for an unbuffered version.
1096    pub fn try_new_buffered(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
1097        Self::try_new(BufWriter::new(writer), schema)
1098    }
1099}
1100
impl<W: Write> FileWriter<W> {
    /// Try to create a new writer, with the schema written as part of the header
    ///
    /// Note the created writer is not buffered. See [`FileWriter::try_new_buffered`] for details.
    ///
    /// # Errors
    ///
    /// An [`Err`](Result::Err) may be returned if writing the header to the writer fails.
    pub fn try_new(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
        let write_options = IpcWriteOptions::default();
        Self::try_new_with_options(writer, schema, write_options)
    }

    /// Try to create a new writer with IpcWriteOptions
    ///
    /// Note the created writer is not buffered. See [`FileWriter::try_new_buffered`] for details.
    ///
    /// # Errors
    ///
    /// An [`Err`](Result::Err) may be returned if writing the header to the writer fails.
    pub fn try_new_with_options(
        mut writer: W,
        schema: &Schema,
        write_options: IpcWriteOptions,
    ) -> Result<Self, ArrowError> {
        let data_gen = IpcDataGenerator::default();
        // write magic to header aligned on alignment boundary
        let pad_len = pad_to_alignment(write_options.alignment, super::ARROW_MAGIC.len());
        let header_size = super::ARROW_MAGIC.len() + pad_len;
        writer.write_all(&super::ARROW_MAGIC)?;
        writer.write_all(&PADDING[..pad_len])?;
        // write the schema, set the written bytes to the schema + header
        // The file format forbids dictionary replacement, hence `new(true)`.
        let mut dictionary_tracker = DictionaryTracker::new(true);
        let encoded_message = data_gen.schema_to_bytes_with_dictionary_tracker(
            schema,
            &mut dictionary_tracker,
            &write_options,
        );
        let (meta, data) = write_message(&mut writer, encoded_message, &write_options)?;
        Ok(Self {
            writer,
            write_options,
            schema: Arc::new(schema.clone()),
            block_offsets: meta + data + header_size,
            dictionary_blocks: vec![],
            record_blocks: vec![],
            finished: false,
            dictionary_tracker,
            custom_metadata: HashMap::new(),
            data_gen,
            compression_context: CompressionContext::default(),
        })
    }

    /// Adds a key-value pair to the [FileWriter]'s custom metadata
    pub fn write_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
        self.custom_metadata.insert(key.into(), value.into());
    }

    /// Write a record batch to the file
    ///
    /// Any dictionaries required by the batch are written first, and a footer
    /// `Block` is recorded for each message so it can be located later.
    pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
        if self.finished {
            return Err(ArrowError::IpcError(
                "Cannot write record batch to file writer as it is closed".to_string(),
            ));
        }

        let (encoded_dictionaries, encoded_message) = self.data_gen.encode(
            batch,
            &mut self.dictionary_tracker,
            &self.write_options,
            &mut self.compression_context,
        )?;

        for encoded_dictionary in encoded_dictionaries {
            let (meta, data) =
                write_message(&mut self.writer, encoded_dictionary, &self.write_options)?;

            let block = crate::Block::new(self.block_offsets as i64, meta as i32, data as i64);
            self.dictionary_blocks.push(block);
            self.block_offsets += meta + data;
        }

        let (meta, data) = write_message(&mut self.writer, encoded_message, &self.write_options)?;

        // add a record block for the footer
        let block = crate::Block::new(
            self.block_offsets as i64,
            meta as i32, // TODO: is this still applicable?
            data as i64,
        );
        self.record_blocks.push(block);
        self.block_offsets += meta + data;
        Ok(())
    }

    /// Write footer and closing tag, then mark the writer as done
    ///
    /// The footer layout is: EOS continuation marker, flatbuffer footer,
    /// footer length (little-endian i32), then the trailing magic bytes.
    pub fn finish(&mut self) -> Result<(), ArrowError> {
        if self.finished {
            return Err(ArrowError::IpcError(
                "Cannot write footer to file writer as it is closed".to_string(),
            ));
        }

        // write EOS
        write_continuation(&mut self.writer, &self.write_options, 0)?;

        let mut fbb = FlatBufferBuilder::new();
        let dictionaries = fbb.create_vector(&self.dictionary_blocks);
        let record_batches = fbb.create_vector(&self.record_blocks);

        // dictionaries are already written, so we can reset dictionary tracker to reuse for schema
        self.dictionary_tracker.clear();
        let schema = IpcSchemaEncoder::new()
            .with_dictionary_tracker(&mut self.dictionary_tracker)
            .schema_to_fb_offset(&mut fbb, &self.schema);
        let fb_custom_metadata = (!self.custom_metadata.is_empty())
            .then(|| crate::convert::metadata_to_fb(&mut fbb, &self.custom_metadata));

        let root = {
            let mut footer_builder = crate::FooterBuilder::new(&mut fbb);
            footer_builder.add_version(self.write_options.metadata_version);
            footer_builder.add_schema(schema);
            footer_builder.add_dictionaries(dictionaries);
            footer_builder.add_recordBatches(record_batches);
            if let Some(fb_custom_metadata) = fb_custom_metadata {
                footer_builder.add_custom_metadata(fb_custom_metadata);
            }
            footer_builder.finish()
        };
        fbb.finish(root, None);
        let footer_data = fbb.finished_data();
        self.writer.write_all(footer_data)?;
        self.writer
            .write_all(&(footer_data.len() as i32).to_le_bytes())?;
        self.writer.write_all(&super::ARROW_MAGIC)?;
        self.writer.flush()?;
        self.finished = true;

        Ok(())
    }

    /// Returns the arrow [`SchemaRef`] for this arrow file.
    pub fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    /// Gets a reference to the underlying writer.
    pub fn get_ref(&self) -> &W {
        &self.writer
    }

    /// Gets a mutable reference to the underlying writer.
    ///
    /// It is inadvisable to directly write to the underlying writer.
    pub fn get_mut(&mut self) -> &mut W {
        &mut self.writer
    }

    /// Flush the underlying writer.
    ///
    /// Both the BufWriter and the underlying writer are flushed.
    pub fn flush(&mut self) -> Result<(), ArrowError> {
        self.writer.flush()?;
        Ok(())
    }

    /// Unwraps the underlying writer.
    ///
    /// The writer is flushed and the FileWriter is finished before returning.
    ///
    /// # Errors
    ///
    /// An [`Err`](Result::Err) may be returned if an error occurs while finishing the FileWriter
    /// or while flushing the writer.
    pub fn into_inner(mut self) -> Result<W, ArrowError> {
        if !self.finished {
            // `finish` flushes the writer.
            self.finish()?;
        }
        Ok(self.writer)
    }
}
1284
impl<W: Write> RecordBatchWriter for FileWriter<W> {
    /// Delegates to [`FileWriter::write`].
    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
        self.write(batch)
    }

    /// Writes the footer and closes the file via [`FileWriter::finish`].
    fn close(mut self) -> Result<(), ArrowError> {
        self.finish()
    }
}
1294
/// Arrow Stream Writer
///
/// Writes Arrow [`RecordBatch`]es to bytes using the [IPC Streaming Format].
///
/// # See Also
///
/// * [`FileWriter`] for writing IPC Files
///
/// # Example - Basic usage
/// ```
/// # use arrow_array::record_batch;
/// # use arrow_ipc::writer::StreamWriter;
/// # let mut stream = vec![]; // mimic a stream for the example
/// let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
/// // create a new writer, the schema must be known in advance
/// let mut writer = StreamWriter::try_new(&mut stream, &batch.schema()).unwrap();
/// // write each batch to the underlying stream
/// writer.write(&batch).unwrap();
/// // When all batches are written, call finish to flush all buffers
/// writer.finish().unwrap();
/// ```
/// # Example - Efficient delta dictionaries
/// ```
/// # use arrow_array::record_batch;
/// # use arrow_ipc::writer::{StreamWriter, IpcWriteOptions};
/// # use arrow_ipc::writer::DictionaryHandling;
/// # use arrow_schema::{DataType, Field, Schema, SchemaRef};
/// # use arrow_array::{
/// #    builder::StringDictionaryBuilder, types::Int32Type, Array, ArrayRef, DictionaryArray,
/// #    RecordBatch, StringArray,
/// # };
/// # use std::sync::Arc;
///
/// let schema = Arc::new(Schema::new(vec![Field::new(
///    "col1",
///    DataType::Dictionary(Box::from(DataType::Int32), Box::from(DataType::Utf8)),
///    true,
/// )]));
///
/// let mut builder = StringDictionaryBuilder::<arrow_array::types::Int32Type>::new();
///
/// // `finish_preserve_values` will keep the dictionary values along with their
/// // key assignments so that they can be re-used in the next batch.
/// builder.append("a").unwrap();
/// builder.append("b").unwrap();
/// let array1 = builder.finish_preserve_values();
/// let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(array1) as ArrayRef]).unwrap();
///
/// // In this batch, 'a' will have the same dictionary key as 'a' in the previous batch,
/// // and 'd' will take the next available key.
/// builder.append("a").unwrap();
/// builder.append("d").unwrap();
/// let array2 = builder.finish_preserve_values();
/// let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(array2) as ArrayRef]).unwrap();
///
/// let mut stream = vec![];
/// // You must set `.with_dictionary_handling(DictionaryHandling::Delta)` to
/// // enable delta dictionaries in the writer
/// let options = IpcWriteOptions::default().with_dictionary_handling(DictionaryHandling::Delta);
/// let mut writer = StreamWriter::try_new_with_options(&mut stream, &schema, options).unwrap();
///
/// // When writing the first batch, a dictionary message with 'a' and 'b' will be written
/// // prior to the record batch.
/// writer.write(&batch1).unwrap();
/// // With the second batch only a delta dictionary with 'd' will be written
/// // prior to the record batch. This is only possible with `finish_preserve_values`.
/// // Without it, 'a' and 'd' in this batch would have different keys than the
/// // first batch and so we'd have to send a replacement dictionary with new keys
/// // for both.
/// writer.write(&batch2).unwrap();
/// writer.finish().unwrap();
/// ```
/// [IPC Streaming Format]: https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format
pub struct StreamWriter<W> {
    /// The object to write to
    writer: W,
    /// IPC write options
    write_options: IpcWriteOptions,
    /// Whether the writer footer has been written, and the writer is finished
    finished: bool,
    /// Keeps track of dictionaries that have been written
    dictionary_tracker: DictionaryTracker,

    /// Encodes schemas, dictionaries and record batches into IPC messages
    data_gen: IpcDataGenerator,

    /// Reusable state for (optional) body buffer compression
    compression_context: CompressionContext,
}
1382
1383impl<W: Write> StreamWriter<BufWriter<W>> {
1384    /// Try to create a new stream writer with the writer wrapped in a BufWriter.
1385    ///
1386    /// See [`StreamWriter::try_new`] for an unbuffered version.
1387    pub fn try_new_buffered(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
1388        Self::try_new(BufWriter::new(writer), schema)
1389    }
1390}
1391
1392impl<W: Write> StreamWriter<W> {
1393    /// Try to create a new writer, with the schema written as part of the header.
1394    ///
1395    /// Note that there is no internal buffering. See also [`StreamWriter::try_new_buffered`].
1396    ///
1397    /// # Errors
1398    ///
1399    /// An ['Err'](Result::Err) may be returned if writing the header to the writer fails.
1400    pub fn try_new(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
1401        let write_options = IpcWriteOptions::default();
1402        Self::try_new_with_options(writer, schema, write_options)
1403    }
1404
1405    /// Try to create a new writer with [`IpcWriteOptions`].
1406    ///
1407    /// # Errors
1408    ///
1409    /// An ['Err'](Result::Err) may be returned if writing the header to the writer fails.
1410    pub fn try_new_with_options(
1411        mut writer: W,
1412        schema: &Schema,
1413        write_options: IpcWriteOptions,
1414    ) -> Result<Self, ArrowError> {
1415        let data_gen = IpcDataGenerator::default();
1416        let mut dictionary_tracker = DictionaryTracker::new(false);
1417
1418        // write the schema, set the written bytes to the schema
1419        let encoded_message = data_gen.schema_to_bytes_with_dictionary_tracker(
1420            schema,
1421            &mut dictionary_tracker,
1422            &write_options,
1423        );
1424        write_message(&mut writer, encoded_message, &write_options)?;
1425        Ok(Self {
1426            writer,
1427            write_options,
1428            finished: false,
1429            dictionary_tracker,
1430            data_gen,
1431            compression_context: CompressionContext::default(),
1432        })
1433    }
1434
1435    /// Write a record batch to the stream
1436    pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
1437        if self.finished {
1438            return Err(ArrowError::IpcError(
1439                "Cannot write record batch to stream writer as it is closed".to_string(),
1440            ));
1441        }
1442
1443        let (encoded_dictionaries, encoded_message) = self
1444            .data_gen
1445            .encode(
1446                batch,
1447                &mut self.dictionary_tracker,
1448                &self.write_options,
1449                &mut self.compression_context,
1450            )
1451            .expect("StreamWriter is configured to not error on dictionary replacement");
1452
1453        for encoded_dictionary in encoded_dictionaries {
1454            write_message(&mut self.writer, encoded_dictionary, &self.write_options)?;
1455        }
1456
1457        write_message(&mut self.writer, encoded_message, &self.write_options)?;
1458        Ok(())
1459    }
1460
1461    /// Write continuation bytes, and mark the stream as done
1462    pub fn finish(&mut self) -> Result<(), ArrowError> {
1463        if self.finished {
1464            return Err(ArrowError::IpcError(
1465                "Cannot write footer to stream writer as it is closed".to_string(),
1466            ));
1467        }
1468
1469        write_continuation(&mut self.writer, &self.write_options, 0)?;
1470        self.writer.flush()?;
1471
1472        self.finished = true;
1473
1474        Ok(())
1475    }
1476
1477    /// Gets a reference to the underlying writer.
1478    pub fn get_ref(&self) -> &W {
1479        &self.writer
1480    }
1481
1482    /// Gets a mutable reference to the underlying writer.
1483    ///
1484    /// It is inadvisable to directly write to the underlying writer.
1485    pub fn get_mut(&mut self) -> &mut W {
1486        &mut self.writer
1487    }
1488
1489    /// Flush the underlying writer.
1490    ///
1491    /// Both the BufWriter and the underlying writer are flushed.
1492    pub fn flush(&mut self) -> Result<(), ArrowError> {
1493        self.writer.flush()?;
1494        Ok(())
1495    }
1496
1497    /// Unwraps the the underlying writer.
1498    ///
1499    /// The writer is flushed and the StreamWriter is finished before returning.
1500    ///
1501    /// # Errors
1502    ///
1503    /// An ['Err'](Result::Err) may be returned if an error occurs while finishing the StreamWriter
1504    /// or while flushing the writer.
1505    ///
1506    /// # Example
1507    ///
1508    /// ```
1509    /// # use arrow_ipc::writer::{StreamWriter, IpcWriteOptions};
1510    /// # use arrow_ipc::MetadataVersion;
1511    /// # use arrow_schema::{ArrowError, Schema};
1512    /// # fn main() -> Result<(), ArrowError> {
1513    /// // The result we expect from an empty schema
1514    /// let expected = vec![
1515    ///     255, 255, 255, 255,  48,   0,   0,   0,
1516    ///      16,   0,   0,   0,   0,   0,  10,   0,
1517    ///      12,   0,  10,   0,   9,   0,   4,   0,
1518    ///      10,   0,   0,   0,  16,   0,   0,   0,
1519    ///       0,   1,   4,   0,   8,   0,   8,   0,
1520    ///       0,   0,   4,   0,   8,   0,   0,   0,
1521    ///       4,   0,   0,   0,   0,   0,   0,   0,
1522    ///     255, 255, 255, 255,   0,   0,   0,   0
1523    /// ];
1524    ///
1525    /// let schema = Schema::empty();
1526    /// let buffer: Vec<u8> = Vec::new();
1527    /// let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5)?;
1528    /// let stream_writer = StreamWriter::try_new_with_options(buffer, &schema, options)?;
1529    ///
1530    /// assert_eq!(stream_writer.into_inner()?, expected);
1531    /// # Ok(())
1532    /// # }
1533    /// ```
1534    pub fn into_inner(mut self) -> Result<W, ArrowError> {
1535        if !self.finished {
1536            // `finish` flushes.
1537            self.finish()?;
1538        }
1539        Ok(self.writer)
1540    }
1541}
1542
impl<W: Write> RecordBatchWriter for StreamWriter<W> {
    /// Delegates to [`StreamWriter::write`].
    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
        self.write(batch)
    }

    /// Writes the end-of-stream marker via [`StreamWriter::finish`].
    fn close(mut self) -> Result<(), ArrowError> {
        self.finish()
    }
}
1552
/// Stores the encoded data, which is an crate::Message, and optional Arrow data
pub struct EncodedData {
    /// An encoded crate::Message
    pub ipc_message: Vec<u8>,
    /// Arrow buffers to be written, should be an empty vec for schema messages
    pub arrow_data: Vec<u8>,
}
/// Write a message's IPC data and buffers, returning metadata and buffer data lengths written
///
/// The on-wire layout is: length prefix (and continuation marker for the
/// non-legacy format), the flatbuffer message, padding up to the alignment
/// boundary, then the (already aligned) body buffers.
pub fn write_message<W: Write>(
    mut writer: W,
    encoded: EncodedData,
    write_options: &IpcWriteOptions,
) -> Result<(usize, usize), ArrowError> {
    let arrow_data_len = encoded.arrow_data.len();
    // Body buffers must already be padded to the configured alignment.
    if arrow_data_len % usize::from(write_options.alignment) != 0 {
        return Err(ArrowError::MemoryError(
            "Arrow data not aligned".to_string(),
        ));
    }

    // Alignment mask; assumes `alignment` is a power of two so `& !a`
    // truncates to a multiple of it — TODO confirm this is validated when
    // constructing `IpcWriteOptions`.
    let a = usize::from(write_options.alignment - 1);
    let buffer = encoded.ipc_message;
    let flatbuf_size = buffer.len();
    // Legacy format: 4-byte length prefix only. Current format: 4-byte
    // continuation marker plus 4-byte length.
    let prefix_size = if write_options.write_legacy_ipc_format {
        4
    } else {
        8
    };
    // Round (prefix + flatbuffer) up to the next alignment boundary.
    let aligned_size = (flatbuf_size + prefix_size + a) & !a;
    let padding_bytes = aligned_size - flatbuf_size - prefix_size;

    // The declared length covers the flatbuffer and its padding, but not the prefix.
    write_continuation(
        &mut writer,
        write_options,
        (aligned_size - prefix_size) as i32,
    )?;

    // write the flatbuf
    if flatbuf_size > 0 {
        writer.write_all(&buffer)?;
    }
    // write padding
    writer.write_all(&PADDING[..padding_bytes])?;

    // write arrow data
    let body_len = if arrow_data_len > 0 {
        write_body_buffers(&mut writer, &encoded.arrow_data, write_options.alignment)?
    } else {
        0
    };

    Ok((aligned_size, body_len))
}
1606
1607fn write_body_buffers<W: Write>(
1608    mut writer: W,
1609    data: &[u8],
1610    alignment: u8,
1611) -> Result<usize, ArrowError> {
1612    let len = data.len();
1613    let pad_len = pad_to_alignment(alignment, len);
1614    let total_len = len + pad_len;
1615
1616    // write body buffer
1617    writer.write_all(data)?;
1618    if pad_len > 0 {
1619        writer.write_all(&PADDING[..pad_len])?;
1620    }
1621
1622    Ok(total_len)
1623}
1624
1625/// Write a record batch to the writer, writing the message size before the message
1626/// if the record batch is being written to a stream
1627fn write_continuation<W: Write>(
1628    mut writer: W,
1629    write_options: &IpcWriteOptions,
1630    total_len: i32,
1631) -> Result<usize, ArrowError> {
1632    let mut written = 8;
1633
1634    // the version of the writer determines whether continuation markers should be added
1635    match write_options.metadata_version {
1636        crate::MetadataVersion::V1 | crate::MetadataVersion::V2 | crate::MetadataVersion::V3 => {
1637            unreachable!("Options with the metadata version cannot be created")
1638        }
1639        crate::MetadataVersion::V4 => {
1640            if !write_options.write_legacy_ipc_format {
1641                // v0.15.0 format
1642                writer.write_all(&CONTINUATION_MARKER)?;
1643                written = 4;
1644            }
1645            writer.write_all(&total_len.to_le_bytes()[..])?;
1646        }
1647        crate::MetadataVersion::V5 => {
1648            // write continuation marker and message length
1649            writer.write_all(&CONTINUATION_MARKER)?;
1650            writer.write_all(&total_len.to_le_bytes()[..])?;
1651        }
1652        z => panic!("Unsupported crate::MetadataVersion {z:?}"),
1653    };
1654
1655    Ok(written)
1656}
1657
1658/// In V4, null types have no validity bitmap
1659/// In V5 and later, null and union types have no validity bitmap
1660/// Run end encoded type has no validity bitmap.
1661fn has_validity_bitmap(data_type: &DataType, write_options: &IpcWriteOptions) -> bool {
1662    if write_options.metadata_version < crate::MetadataVersion::V5 {
1663        !matches!(data_type, DataType::Null)
1664    } else {
1665        !matches!(
1666            data_type,
1667            DataType::Null | DataType::Union(_, _) | DataType::RunEndEncoded(_, _)
1668        )
1669    }
1670}
1671
1672/// Whether to truncate the buffer
1673#[inline]
1674fn buffer_need_truncate(
1675    array_offset: usize,
1676    buffer: &Buffer,
1677    spec: &BufferSpec,
1678    min_length: usize,
1679) -> bool {
1680    spec != &BufferSpec::AlwaysNull && (array_offset != 0 || min_length < buffer.len())
1681}
1682
1683/// Returns byte width for a buffer spec. Only for `BufferSpec::FixedWidth`.
1684#[inline]
1685fn get_buffer_element_width(spec: &BufferSpec) -> usize {
1686    match spec {
1687        BufferSpec::FixedWidth { byte_width, .. } => *byte_width,
1688        _ => 0,
1689    }
1690}
1691
1692/// Common functionality for re-encoding offsets. Returns the new offsets as well as
1693/// original start offset and length for use in slicing child data.
1694fn reencode_offsets<O: OffsetSizeTrait>(
1695    offsets: &Buffer,
1696    data: &ArrayData,
1697) -> (Buffer, usize, usize) {
1698    let offsets_slice: &[O] = offsets.typed_data::<O>();
1699    let offset_slice = &offsets_slice[data.offset()..data.offset() + data.len() + 1];
1700
1701    let start_offset = offset_slice.first().unwrap();
1702    let end_offset = offset_slice.last().unwrap();
1703
1704    let offsets = match start_offset.as_usize() {
1705        0 => {
1706            let size = size_of::<O>();
1707            offsets.slice_with_length(data.offset() * size, (data.len() + 1) * size)
1708        }
1709        _ => offset_slice.iter().map(|x| *x - *start_offset).collect(),
1710    };
1711
1712    let start_offset = start_offset.as_usize();
1713    let end_offset = end_offset.as_usize();
1714
1715    (offsets, start_offset, end_offset - start_offset)
1716}
1717
1718/// Returns the values and offsets [`Buffer`] for a ByteArray with offset type `O`
1719///
1720/// In particular, this handles re-encoding the offsets if they don't start at `0`,
1721/// slicing the values buffer as appropriate. This helps reduce the encoded
1722/// size of sliced arrays, as values that have been sliced away are not encoded
1723fn get_byte_array_buffers<O: OffsetSizeTrait>(data: &ArrayData) -> (Buffer, Buffer) {
1724    if data.is_empty() {
1725        return (MutableBuffer::new(0).into(), MutableBuffer::new(0).into());
1726    }
1727
1728    let (offsets, original_start_offset, len) = reencode_offsets::<O>(&data.buffers()[0], data);
1729    let values = data.buffers()[1].slice_with_length(original_start_offset, len);
1730    (offsets, values)
1731}
1732
1733/// Similar logic as [`get_byte_array_buffers()`] but slices the child array instead
1734/// of a values buffer.
1735fn get_list_array_buffers<O: OffsetSizeTrait>(data: &ArrayData) -> (Buffer, ArrayData) {
1736    if data.is_empty() {
1737        return (
1738            MutableBuffer::new(0).into(),
1739            data.child_data()[0].slice(0, 0),
1740        );
1741    }
1742
1743    let (offsets, original_start_offset, len) = reencode_offsets::<O>(&data.buffers()[0], data);
1744    let child_data = data.child_data()[0].slice(original_start_offset, len);
1745    (offsets, child_data)
1746}
1747
1748/// Returns the offsets, sizes, and child data buffers for a ListView array.
1749///
1750/// Unlike List arrays, ListView arrays store both offsets and sizes explicitly,
1751/// and offsets can be non-monotonic. When slicing, we simply pass through the
1752/// offsets and sizes without re-encoding, and do not slice the child data.
1753fn get_list_view_array_buffers<O: OffsetSizeTrait>(
1754    data: &ArrayData,
1755) -> (Buffer, Buffer, ArrayData) {
1756    if data.is_empty() {
1757        return (
1758            MutableBuffer::new(0).into(),
1759            MutableBuffer::new(0).into(),
1760            data.child_data()[0].slice(0, 0),
1761        );
1762    }
1763
1764    let offsets = &data.buffers()[0];
1765    let sizes = &data.buffers()[1];
1766
1767    let element_size = std::mem::size_of::<O>();
1768    let offsets_slice =
1769        offsets.slice_with_length(data.offset() * element_size, data.len() * element_size);
1770    let sizes_slice =
1771        sizes.slice_with_length(data.offset() * element_size, data.len() * element_size);
1772
1773    let child_data = data.child_data()[0].clone();
1774
1775    (offsets_slice, sizes_slice, child_data)
1776}
1777
1778/// Returns the sliced views [`Buffer`] for a BinaryView/Utf8View array.
1779///
1780/// The views buffer is sliced to only include views in the valid range based on
1781/// the array's offset and length. This helps reduce the encoded size of sliced
1782/// arrays
1783///
1784fn get_or_truncate_buffer(array_data: &ArrayData) -> &[u8] {
1785    let buffer = &array_data.buffers()[0];
1786    let layout = layout(array_data.data_type());
1787    let spec = &layout.buffers[0];
1788
1789    let byte_width = get_buffer_element_width(spec);
1790    let min_length = array_data.len() * byte_width;
1791    if buffer_need_truncate(array_data.offset(), buffer, spec, min_length) {
1792        let byte_offset = array_data.offset() * byte_width;
1793        let buffer_length = min(min_length, buffer.len() - byte_offset);
1794        &buffer.as_slice()[byte_offset..(byte_offset + buffer_length)]
1795    } else {
1796        buffer.as_slice()
1797    }
1798}
1799
/// Write array data to a vector of bytes
///
/// Appends one `FieldNode` for `array_data` to `nodes`, a descriptor for each
/// written buffer to `buffers`, and the (optionally compressed) buffer bytes
/// to `arrow_data`, recursing into child data where the type has any.
/// `offset` is the current position in the encoded body; the updated
/// position is returned.
#[allow(clippy::too_many_arguments)]
fn write_array_data(
    array_data: &ArrayData,
    buffers: &mut Vec<crate::Buffer>,
    arrow_data: &mut Vec<u8>,
    nodes: &mut Vec<crate::FieldNode>,
    offset: i64,
    num_rows: usize,
    null_count: usize,
    compression_codec: Option<CompressionCodec>,
    compression_context: &mut CompressionContext,
    write_options: &IpcWriteOptions,
) -> Result<i64, ArrowError> {
    let mut offset = offset;
    if !matches!(array_data.data_type(), DataType::Null) {
        nodes.push(crate::FieldNode::new(num_rows as i64, null_count as i64));
    } else {
        // NullArray's null_count equals to len, but the `null_count` passed in is from ArrayData
        // where null_count is always 0.
        nodes.push(crate::FieldNode::new(num_rows as i64, num_rows as i64));
    }
    if has_validity_bitmap(array_data.data_type(), write_options) {
        // write null buffer if exists
        let null_buffer = match array_data.nulls() {
            None => {
                // create a buffer and fill it with valid bits
                let num_bytes = bit_util::ceil(num_rows, 8);
                let buffer = MutableBuffer::new(num_bytes);
                let buffer = buffer.with_bitset(num_bytes, true);
                buffer.into()
            }
            Some(buffer) => buffer.inner().sliced(),
        };

        offset = write_buffer(
            null_buffer.as_slice(),
            buffers,
            arrow_data,
            offset,
            compression_codec,
            compression_context,
            write_options.alignment,
        )?;
    }

    // Dispatch on the data type: each branch decides how value buffers are
    // truncated/re-encoded and whether child data must be written.
    let data_type = array_data.data_type();
    if matches!(data_type, DataType::Binary | DataType::Utf8) {
        // Re-encode offsets to start at zero and slice values to match.
        let (offsets, values) = get_byte_array_buffers::<i32>(array_data);
        for buffer in [offsets, values] {
            offset = write_buffer(
                buffer.as_slice(),
                buffers,
                arrow_data,
                offset,
                compression_codec,
                compression_context,
                write_options.alignment,
            )?;
        }
    } else if matches!(data_type, DataType::BinaryView | DataType::Utf8View) {
        // Slicing the views buffer is safe and easy,
        // but pruning unneeded data buffers is much more nuanced since it's complicated to prove that no views reference the pruned buffers
        //
        // Current implementation just serialize the raw arrays as given and not try to optimize anything.
        // If users wants to "compact" the arrays prior to sending them over IPC,
        // they should consider the gc API suggested in #5513
        let views = get_or_truncate_buffer(array_data);
        offset = write_buffer(
            views,
            buffers,
            arrow_data,
            offset,
            compression_codec,
            compression_context,
            write_options.alignment,
        )?;

        // The remaining buffers are the variadic data buffers, written as-is.
        for buffer in array_data.buffers().iter().skip(1) {
            offset = write_buffer(
                buffer.as_slice(),
                buffers,
                arrow_data,
                offset,
                compression_codec,
                compression_context,
                write_options.alignment,
            )?;
        }
    } else if matches!(data_type, DataType::LargeBinary | DataType::LargeUtf8) {
        // Same as Binary/Utf8 but with 64-bit offsets.
        let (offsets, values) = get_byte_array_buffers::<i64>(array_data);
        for buffer in [offsets, values] {
            offset = write_buffer(
                buffer.as_slice(),
                buffers,
                arrow_data,
                offset,
                compression_codec,
                compression_context,
                write_options.alignment,
            )?;
        }
    } else if DataType::is_numeric(data_type)
        || DataType::is_temporal(data_type)
        || matches!(
            array_data.data_type(),
            DataType::FixedSizeBinary(_) | DataType::Dictionary(_, _)
        )
    {
        // Truncate values
        assert_eq!(array_data.buffers().len(), 1);

        let buffer = get_or_truncate_buffer(array_data);
        offset = write_buffer(
            buffer,
            buffers,
            arrow_data,
            offset,
            compression_codec,
            compression_context,
            write_options.alignment,
        )?;
    } else if matches!(data_type, DataType::Boolean) {
        // Bools are special because the payload (= 1 bit) is smaller than the physical container elements (= bytes).
        // The array data may not start at the physical boundary of the underlying buffer, so we need to shift bits around.
        assert_eq!(array_data.buffers().len(), 1);

        let buffer = &array_data.buffers()[0];
        let buffer = buffer.bit_slice(array_data.offset(), array_data.len());
        offset = write_buffer(
            &buffer,
            buffers,
            arrow_data,
            offset,
            compression_codec,
            compression_context,
            write_options.alignment,
        )?;
    } else if matches!(
        data_type,
        DataType::List(_) | DataType::LargeList(_) | DataType::Map(_, _)
    ) {
        assert_eq!(array_data.buffers().len(), 1);
        assert_eq!(array_data.child_data().len(), 1);

        // Truncate offsets and the child data to avoid writing unnecessary data
        let (offsets, sliced_child_data) = match data_type {
            DataType::List(_) => get_list_array_buffers::<i32>(array_data),
            DataType::Map(_, _) => get_list_array_buffers::<i32>(array_data),
            DataType::LargeList(_) => get_list_array_buffers::<i64>(array_data),
            _ => unreachable!(),
        };
        offset = write_buffer(
            offsets.as_slice(),
            buffers,
            arrow_data,
            offset,
            compression_codec,
            compression_context,
            write_options.alignment,
        )?;
        offset = write_array_data(
            &sliced_child_data,
            buffers,
            arrow_data,
            nodes,
            offset,
            sliced_child_data.len(),
            sliced_child_data.null_count(),
            compression_codec,
            compression_context,
            write_options,
        )?;
        // Child data already handled above; skip the generic recursion below.
        return Ok(offset);
    } else if matches!(
        data_type,
        DataType::ListView(_) | DataType::LargeListView(_)
    ) {
        assert_eq!(array_data.buffers().len(), 2); // offsets + sizes
        assert_eq!(array_data.child_data().len(), 1);

        let (offsets, sizes, child_data) = match data_type {
            DataType::ListView(_) => get_list_view_array_buffers::<i32>(array_data),
            DataType::LargeListView(_) => get_list_view_array_buffers::<i64>(array_data),
            _ => unreachable!(),
        };

        offset = write_buffer(
            offsets.as_slice(),
            buffers,
            arrow_data,
            offset,
            compression_codec,
            compression_context,
            write_options.alignment,
        )?;

        offset = write_buffer(
            sizes.as_slice(),
            buffers,
            arrow_data,
            offset,
            compression_codec,
            compression_context,
            write_options.alignment,
        )?;

        offset = write_array_data(
            &child_data,
            buffers,
            arrow_data,
            nodes,
            offset,
            child_data.len(),
            child_data.null_count(),
            compression_codec,
            compression_context,
            write_options,
        )?;
        // Child data already handled above; skip the generic recursion below.
        return Ok(offset);
    } else if let DataType::FixedSizeList(_, fixed_size) = data_type {
        assert_eq!(array_data.child_data().len(), 1);
        let fixed_size = *fixed_size as usize;

        // Each parent row maps to exactly `fixed_size` child rows.
        let child_offset = array_data.offset() * fixed_size;
        let child_length = array_data.len() * fixed_size;
        let child_data = array_data.child_data()[0].slice(child_offset, child_length);

        offset = write_array_data(
            &child_data,
            buffers,
            arrow_data,
            nodes,
            offset,
            child_data.len(),
            child_data.null_count(),
            compression_codec,
            compression_context,
            write_options,
        )?;
        // Child data already handled above; skip the generic recursion below.
        return Ok(offset);
    } else {
        // Fallback: write every buffer unmodified (structs, unions, REE, ...).
        for buffer in array_data.buffers() {
            offset = write_buffer(
                buffer,
                buffers,
                arrow_data,
                offset,
                compression_codec,
                compression_context,
                write_options.alignment,
            )?;
        }
    }

    match array_data.data_type() {
        // Dictionary values are intentionally not written here; only the keys
        // buffer (handled above) is part of this message.
        DataType::Dictionary(_, _) => {}
        DataType::RunEndEncoded(_, _) => {
            // unslice the run encoded array.
            let arr = unslice_run_array(array_data.clone())?;
            // recursively write out nested structures
            for data_ref in arr.child_data() {
                // write the nested data (e.g list data)
                offset = write_array_data(
                    data_ref,
                    buffers,
                    arrow_data,
                    nodes,
                    offset,
                    data_ref.len(),
                    data_ref.null_count(),
                    compression_codec,
                    compression_context,
                    write_options,
                )?;
            }
        }
        _ => {
            // recursively write out nested structures
            for data_ref in array_data.child_data() {
                // write the nested data (e.g list data)
                offset = write_array_data(
                    data_ref,
                    buffers,
                    arrow_data,
                    nodes,
                    offset,
                    data_ref.len(),
                    data_ref.null_count(),
                    compression_codec,
                    compression_context,
                    write_options,
                )?;
            }
        }
    }
    Ok(offset)
}
2098
2099/// Write a buffer into `arrow_data`, a vector of bytes, and adds its
2100/// [`crate::Buffer`] to `buffers`. Returns the new offset in `arrow_data`
2101///
2102///
2103/// From <https://github.com/apache/arrow/blob/6a936c4ff5007045e86f65f1a6b6c3c955ad5103/format/Message.fbs#L58>
2104/// Each constituent buffer is first compressed with the indicated
2105/// compressor, and then written with the uncompressed length in the first 8
2106/// bytes as a 64-bit little-endian signed integer followed by the compressed
2107/// buffer bytes (and then padding as required by the protocol). The
2108/// uncompressed length may be set to -1 to indicate that the data that
2109/// follows is not compressed, which can be useful for cases where
2110/// compression does not yield appreciable savings.
2111fn write_buffer(
2112    buffer: &[u8],                    // input
2113    buffers: &mut Vec<crate::Buffer>, // output buffer descriptors
2114    arrow_data: &mut Vec<u8>,         // output stream
2115    offset: i64,                      // current output stream offset
2116    compression_codec: Option<CompressionCodec>,
2117    compression_context: &mut CompressionContext,
2118    alignment: u8,
2119) -> Result<i64, ArrowError> {
2120    let len: i64 = match compression_codec {
2121        Some(compressor) => compressor.compress_to_vec(buffer, arrow_data, compression_context)?,
2122        None => {
2123            arrow_data.extend_from_slice(buffer);
2124            buffer.len()
2125        }
2126    }
2127    .try_into()
2128    .map_err(|e| {
2129        ArrowError::InvalidArgumentError(format!("Could not convert compressed size to i64: {e}"))
2130    })?;
2131
2132    // make new index entry
2133    buffers.push(crate::Buffer::new(offset, len));
2134    // padding and make offset aligned
2135    let pad_len = pad_to_alignment(alignment, len as usize);
2136    arrow_data.extend_from_slice(&PADDING[..pad_len]);
2137
2138    Ok(offset + len + (pad_len as i64))
2139}
2140
2141const PADDING: [u8; 64] = [0; 64];
2142
/// Calculate an alignment boundary and return the number of bytes needed to pad to the alignment boundary
#[inline]
fn pad_to_alignment(alignment: u8, len: usize) -> usize {
    // `alignment` is expected to be a power of two; the padding is the
    // two's-complement negation of `len` masked to the alignment bits.
    let mask = usize::from(alignment) - 1;
    len.wrapping_neg() & mask
}
2149
2150#[cfg(test)]
2151mod tests {
2152    use std::hash::Hasher;
2153    use std::io::Cursor;
2154    use std::io::Seek;
2155
2156    use arrow_array::builder::FixedSizeListBuilder;
2157    use arrow_array::builder::Float32Builder;
2158    use arrow_array::builder::Int64Builder;
2159    use arrow_array::builder::MapBuilder;
2160    use arrow_array::builder::StringViewBuilder;
2161    use arrow_array::builder::UnionBuilder;
2162    use arrow_array::builder::{
2163        GenericListBuilder, GenericListViewBuilder, ListBuilder, StringBuilder,
2164    };
2165    use arrow_array::builder::{PrimitiveRunBuilder, UInt32Builder};
2166    use arrow_array::types::*;
2167    use arrow_buffer::ScalarBuffer;
2168
2169    use crate::MetadataVersion;
2170    use crate::convert::fb_to_schema;
2171    use crate::reader::*;
2172    use crate::root_as_footer;
2173
2174    use super::*;
2175
2176    fn serialize_file(rb: &RecordBatch) -> Vec<u8> {
2177        let mut writer = FileWriter::try_new(vec![], rb.schema_ref()).unwrap();
2178        writer.write(rb).unwrap();
2179        writer.finish().unwrap();
2180        writer.into_inner().unwrap()
2181    }
2182
2183    fn deserialize_file(bytes: Vec<u8>) -> RecordBatch {
2184        let mut reader = FileReader::try_new(Cursor::new(bytes), None).unwrap();
2185        reader.next().unwrap().unwrap()
2186    }
2187
2188    fn serialize_stream(record: &RecordBatch) -> Vec<u8> {
2189        // Use 8-byte alignment so that the various `truncate_*` tests can be compactly written,
2190        // without needing to construct a giant array to spill over the 64-byte default alignment
2191        // boundary.
2192        const IPC_ALIGNMENT: usize = 8;
2193
2194        let mut stream_writer = StreamWriter::try_new_with_options(
2195            vec![],
2196            record.schema_ref(),
2197            IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(),
2198        )
2199        .unwrap();
2200        stream_writer.write(record).unwrap();
2201        stream_writer.finish().unwrap();
2202        stream_writer.into_inner().unwrap()
2203    }
2204
2205    fn deserialize_stream(bytes: Vec<u8>) -> RecordBatch {
2206        let mut stream_reader = StreamReader::try_new(Cursor::new(bytes), None).unwrap();
2207        stream_reader.next().unwrap().unwrap()
2208    }
2209
2210    #[test]
2211    #[cfg(feature = "lz4")]
2212    fn test_write_empty_record_batch_lz4_compression() {
2213        let schema = Schema::new(vec![Field::new("field1", DataType::Int32, true)]);
2214        let values: Vec<Option<i32>> = vec![];
2215        let array = Int32Array::from(values);
2216        let record_batch =
2217            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]).unwrap();
2218
2219        let mut file = tempfile::tempfile().unwrap();
2220
2221        {
2222            let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5)
2223                .unwrap()
2224                .try_with_compression(Some(crate::CompressionType::LZ4_FRAME))
2225                .unwrap();
2226
2227            let mut writer =
2228                FileWriter::try_new_with_options(&mut file, &schema, write_option).unwrap();
2229            writer.write(&record_batch).unwrap();
2230            writer.finish().unwrap();
2231        }
2232        file.rewind().unwrap();
2233        {
2234            // read file
2235            let reader = FileReader::try_new(file, None).unwrap();
2236            for read_batch in reader {
2237                read_batch
2238                    .unwrap()
2239                    .columns()
2240                    .iter()
2241                    .zip(record_batch.columns())
2242                    .for_each(|(a, b)| {
2243                        assert_eq!(a.data_type(), b.data_type());
2244                        assert_eq!(a.len(), b.len());
2245                        assert_eq!(a.null_count(), b.null_count());
2246                    });
2247            }
2248        }
2249    }
2250
2251    #[test]
2252    #[cfg(feature = "lz4")]
2253    fn test_write_file_with_lz4_compression() {
2254        let schema = Schema::new(vec![Field::new("field1", DataType::Int32, true)]);
2255        let values: Vec<Option<i32>> = vec![Some(12), Some(1)];
2256        let array = Int32Array::from(values);
2257        let record_batch =
2258            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]).unwrap();
2259
2260        let mut file = tempfile::tempfile().unwrap();
2261        {
2262            let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5)
2263                .unwrap()
2264                .try_with_compression(Some(crate::CompressionType::LZ4_FRAME))
2265                .unwrap();
2266
2267            let mut writer =
2268                FileWriter::try_new_with_options(&mut file, &schema, write_option).unwrap();
2269            writer.write(&record_batch).unwrap();
2270            writer.finish().unwrap();
2271        }
2272        file.rewind().unwrap();
2273        {
2274            // read file
2275            let reader = FileReader::try_new(file, None).unwrap();
2276            for read_batch in reader {
2277                read_batch
2278                    .unwrap()
2279                    .columns()
2280                    .iter()
2281                    .zip(record_batch.columns())
2282                    .for_each(|(a, b)| {
2283                        assert_eq!(a.data_type(), b.data_type());
2284                        assert_eq!(a.len(), b.len());
2285                        assert_eq!(a.null_count(), b.null_count());
2286                    });
2287            }
2288        }
2289    }
2290
2291    #[test]
2292    #[cfg(feature = "zstd")]
2293    fn test_write_file_with_zstd_compression() {
2294        let schema = Schema::new(vec![Field::new("field1", DataType::Int32, true)]);
2295        let values: Vec<Option<i32>> = vec![Some(12), Some(1)];
2296        let array = Int32Array::from(values);
2297        let record_batch =
2298            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]).unwrap();
2299        let mut file = tempfile::tempfile().unwrap();
2300        {
2301            let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5)
2302                .unwrap()
2303                .try_with_compression(Some(crate::CompressionType::ZSTD))
2304                .unwrap();
2305
2306            let mut writer =
2307                FileWriter::try_new_with_options(&mut file, &schema, write_option).unwrap();
2308            writer.write(&record_batch).unwrap();
2309            writer.finish().unwrap();
2310        }
2311        file.rewind().unwrap();
2312        {
2313            // read file
2314            let reader = FileReader::try_new(file, None).unwrap();
2315            for read_batch in reader {
2316                read_batch
2317                    .unwrap()
2318                    .columns()
2319                    .iter()
2320                    .zip(record_batch.columns())
2321                    .for_each(|(a, b)| {
2322                        assert_eq!(a.data_type(), b.data_type());
2323                        assert_eq!(a.len(), b.len());
2324                        assert_eq!(a.null_count(), b.null_count());
2325                    });
2326            }
2327        }
2328    }
2329
2330    #[test]
2331    fn test_write_file() {
2332        let schema = Schema::new(vec![Field::new("field1", DataType::UInt32, true)]);
2333        let values: Vec<Option<u32>> = vec![
2334            Some(999),
2335            None,
2336            Some(235),
2337            Some(123),
2338            None,
2339            None,
2340            None,
2341            None,
2342            None,
2343        ];
2344        let array1 = UInt32Array::from(values);
2345        let batch =
2346            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array1) as ArrayRef])
2347                .unwrap();
2348        let mut file = tempfile::tempfile().unwrap();
2349        {
2350            let mut writer = FileWriter::try_new(&mut file, &schema).unwrap();
2351
2352            writer.write(&batch).unwrap();
2353            writer.finish().unwrap();
2354        }
2355        file.rewind().unwrap();
2356
2357        {
2358            let mut reader = FileReader::try_new(file, None).unwrap();
2359            while let Some(Ok(read_batch)) = reader.next() {
2360                read_batch
2361                    .columns()
2362                    .iter()
2363                    .zip(batch.columns())
2364                    .for_each(|(a, b)| {
2365                        assert_eq!(a.data_type(), b.data_type());
2366                        assert_eq!(a.len(), b.len());
2367                        assert_eq!(a.null_count(), b.null_count());
2368                    });
2369            }
2370        }
2371    }
2372
2373    fn write_null_file(options: IpcWriteOptions) {
2374        let schema = Schema::new(vec![
2375            Field::new("nulls", DataType::Null, true),
2376            Field::new("int32s", DataType::Int32, false),
2377            Field::new("nulls2", DataType::Null, true),
2378            Field::new("f64s", DataType::Float64, false),
2379        ]);
2380        let array1 = NullArray::new(32);
2381        let array2 = Int32Array::from(vec![1; 32]);
2382        let array3 = NullArray::new(32);
2383        let array4 = Float64Array::from(vec![f64::NAN; 32]);
2384        let batch = RecordBatch::try_new(
2385            Arc::new(schema.clone()),
2386            vec![
2387                Arc::new(array1) as ArrayRef,
2388                Arc::new(array2) as ArrayRef,
2389                Arc::new(array3) as ArrayRef,
2390                Arc::new(array4) as ArrayRef,
2391            ],
2392        )
2393        .unwrap();
2394        let mut file = tempfile::tempfile().unwrap();
2395        {
2396            let mut writer = FileWriter::try_new_with_options(&mut file, &schema, options).unwrap();
2397
2398            writer.write(&batch).unwrap();
2399            writer.finish().unwrap();
2400        }
2401
2402        file.rewind().unwrap();
2403
2404        {
2405            let reader = FileReader::try_new(file, None).unwrap();
2406            reader.for_each(|maybe_batch| {
2407                maybe_batch
2408                    .unwrap()
2409                    .columns()
2410                    .iter()
2411                    .zip(batch.columns())
2412                    .for_each(|(a, b)| {
2413                        assert_eq!(a.data_type(), b.data_type());
2414                        assert_eq!(a.len(), b.len());
2415                        assert_eq!(a.null_count(), b.null_count());
2416                    });
2417            });
2418        }
2419    }
2420    #[test]
2421    fn test_write_null_file_v4() {
2422        write_null_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V4).unwrap());
2423        write_null_file(IpcWriteOptions::try_new(8, true, MetadataVersion::V4).unwrap());
2424        write_null_file(IpcWriteOptions::try_new(64, false, MetadataVersion::V4).unwrap());
2425        write_null_file(IpcWriteOptions::try_new(64, true, MetadataVersion::V4).unwrap());
2426    }
2427
2428    #[test]
2429    fn test_write_null_file_v5() {
2430        write_null_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap());
2431        write_null_file(IpcWriteOptions::try_new(64, false, MetadataVersion::V5).unwrap());
2432    }
2433
    #[test]
    fn track_union_nested_dict() {
        // A dictionary nested inside a union must be recorded by the
        // `DictionaryTracker` when the batch is encoded.
        let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();

        let array = Arc::new(inner) as ArrayRef;

        // Dict field declared with the (deprecated) explicit dict id 0; the
        // encoder assigns its own ids and ignores this value (see assert below).
        #[allow(deprecated)]
        let dctfield = Field::new_dict("dict", array.data_type().clone(), false, 0, false);
        let union_fields = [(0, Arc::new(dctfield))].into_iter().collect();

        let types = [0, 0, 0].into_iter().collect::<ScalarBuffer<i8>>();
        let offsets = [0, 1, 2].into_iter().collect::<ScalarBuffer<i32>>();

        let union = UnionArray::try_new(union_fields, types, Some(offsets), vec![array]).unwrap();

        let schema = Arc::new(Schema::new(vec![Field::new(
            "union",
            union.data_type().clone(),
            false,
        )]));

        let r#gen = IpcDataGenerator::default();
        let mut dict_tracker = DictionaryTracker::new(false);
        r#gen.schema_to_bytes_with_dictionary_tracker(
            &schema,
            &mut dict_tracker,
            &IpcWriteOptions::default(),
        );

        let batch = RecordBatch::try_new(schema, vec![Arc::new(union)]).unwrap();

        r#gen
            .encode(
                &batch,
                &mut dict_tracker,
                &Default::default(),
                &mut Default::default(),
            )
            .unwrap();

        // The encoder will assign dict IDs itself to ensure uniqueness and ignore the dict ID in the schema
        // so we expect the dict will be keyed to 0
        assert!(dict_tracker.written.contains_key(&0));
    }
2479
2480    #[test]
2481    fn track_struct_nested_dict() {
2482        let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();
2483
2484        let array = Arc::new(inner) as ArrayRef;
2485
2486        // Dict field with id 2
2487        #[allow(deprecated)]
2488        let dctfield = Arc::new(Field::new_dict(
2489            "dict",
2490            array.data_type().clone(),
2491            false,
2492            2,
2493            false,
2494        ));
2495
2496        let s = StructArray::from(vec![(dctfield, array)]);
2497        let struct_array = Arc::new(s) as ArrayRef;
2498
2499        let schema = Arc::new(Schema::new(vec![Field::new(
2500            "struct",
2501            struct_array.data_type().clone(),
2502            false,
2503        )]));
2504
2505        let r#gen = IpcDataGenerator::default();
2506        let mut dict_tracker = DictionaryTracker::new(false);
2507        r#gen.schema_to_bytes_with_dictionary_tracker(
2508            &schema,
2509            &mut dict_tracker,
2510            &IpcWriteOptions::default(),
2511        );
2512
2513        let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap();
2514
2515        r#gen
2516            .encode(
2517                &batch,
2518                &mut dict_tracker,
2519                &Default::default(),
2520                &mut Default::default(),
2521            )
2522            .unwrap();
2523
2524        assert!(dict_tracker.written.contains_key(&0));
2525    }
2526
2527    fn write_union_file(options: IpcWriteOptions) {
2528        let schema = Schema::new(vec![Field::new_union(
2529            "union",
2530            vec![0, 1],
2531            vec![
2532                Field::new("a", DataType::Int32, false),
2533                Field::new("c", DataType::Float64, false),
2534            ],
2535            UnionMode::Sparse,
2536        )]);
2537        let mut builder = UnionBuilder::with_capacity_sparse(5);
2538        builder.append::<Int32Type>("a", 1).unwrap();
2539        builder.append_null::<Int32Type>("a").unwrap();
2540        builder.append::<Float64Type>("c", 3.0).unwrap();
2541        builder.append_null::<Float64Type>("c").unwrap();
2542        builder.append::<Int32Type>("a", 4).unwrap();
2543        let union = builder.build().unwrap();
2544
2545        let batch =
2546            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(union) as ArrayRef])
2547                .unwrap();
2548
2549        let mut file = tempfile::tempfile().unwrap();
2550        {
2551            let mut writer = FileWriter::try_new_with_options(&mut file, &schema, options).unwrap();
2552
2553            writer.write(&batch).unwrap();
2554            writer.finish().unwrap();
2555        }
2556        file.rewind().unwrap();
2557
2558        {
2559            let reader = FileReader::try_new(file, None).unwrap();
2560            reader.for_each(|maybe_batch| {
2561                maybe_batch
2562                    .unwrap()
2563                    .columns()
2564                    .iter()
2565                    .zip(batch.columns())
2566                    .for_each(|(a, b)| {
2567                        assert_eq!(a.data_type(), b.data_type());
2568                        assert_eq!(a.len(), b.len());
2569                        assert_eq!(a.null_count(), b.null_count());
2570                    });
2571            });
2572        }
2573    }
2574
2575    #[test]
2576    fn test_write_union_file_v4_v5() {
2577        write_union_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V4).unwrap());
2578        write_union_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap());
2579    }
2580
    #[test]
    fn test_write_view_types() {
        // Round-trip BinaryView and Utf8View columns through an IPC file,
        // including a string long enough that view arrays cannot store it
        // inline (presumably forcing an out-of-line data buffer — the Arrow
        // view format details live in the reader/encoder, not here).
        const LONG_TEST_STRING: &str =
            "This is a long string to make sure binary view array handles it";
        let schema = Schema::new(vec![
            Field::new("field1", DataType::BinaryView, true),
            Field::new("field2", DataType::Utf8View, true),
        ]);
        let values: Vec<Option<&[u8]>> = vec![
            Some(b"foo"),
            Some(b"bar"),
            Some(LONG_TEST_STRING.as_bytes()),
        ];
        let binary_array = BinaryViewArray::from_iter(values);
        let utf8_array =
            StringViewArray::from_iter(vec![Some("foo"), Some("bar"), Some(LONG_TEST_STRING)]);
        let record_batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![Arc::new(binary_array), Arc::new(utf8_array)],
        )
        .unwrap();

        let mut file = tempfile::tempfile().unwrap();
        {
            let mut writer = FileWriter::try_new(&mut file, &schema).unwrap();
            writer.write(&record_batch).unwrap();
            writer.finish().unwrap();
        }
        file.rewind().unwrap();
        {
            // Full read: every column must equal what was written.
            let mut reader = FileReader::try_new(&file, None).unwrap();
            let read_batch = reader.next().unwrap().unwrap();
            read_batch
                .columns()
                .iter()
                .zip(record_batch.columns())
                .for_each(|(a, b)| {
                    assert_eq!(a, b);
                });
        }
        file.rewind().unwrap();
        {
            // Projected read: select only the first (BinaryView) column.
            let mut reader = FileReader::try_new(&file, Some(vec![0])).unwrap();
            let read_batch = reader.next().unwrap().unwrap();
            assert_eq!(read_batch.num_columns(), 1);
            let read_array = read_batch.column(0);
            let write_array = record_batch.column(0);
            assert_eq!(read_array, write_array);
        }
    }
2631
2632    #[test]
2633    fn truncate_ipc_record_batch() {
2634        fn create_batch(rows: usize) -> RecordBatch {
2635            let schema = Schema::new(vec![
2636                Field::new("a", DataType::Int32, false),
2637                Field::new("b", DataType::Utf8, false),
2638            ]);
2639
2640            let a = Int32Array::from_iter_values(0..rows as i32);
2641            let b = StringArray::from_iter_values((0..rows).map(|i| i.to_string()));
2642
2643            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap()
2644        }
2645
2646        let big_record_batch = create_batch(65536);
2647
2648        let length = 5;
2649        let small_record_batch = create_batch(length);
2650
2651        let offset = 2;
2652        let record_batch_slice = big_record_batch.slice(offset, length);
2653        assert!(
2654            serialize_stream(&big_record_batch).len() > serialize_stream(&small_record_batch).len()
2655        );
2656        assert_eq!(
2657            serialize_stream(&small_record_batch).len(),
2658            serialize_stream(&record_batch_slice).len()
2659        );
2660
2661        assert_eq!(
2662            deserialize_stream(serialize_stream(&record_batch_slice)),
2663            record_batch_slice
2664        );
2665    }
2666
2667    #[test]
2668    fn truncate_ipc_record_batch_with_nulls() {
2669        fn create_batch() -> RecordBatch {
2670            let schema = Schema::new(vec![
2671                Field::new("a", DataType::Int32, true),
2672                Field::new("b", DataType::Utf8, true),
2673            ]);
2674
2675            let a = Int32Array::from(vec![Some(1), None, Some(1), None, Some(1)]);
2676            let b = StringArray::from(vec![None, Some("a"), Some("a"), None, Some("a")]);
2677
2678            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap()
2679        }
2680
2681        let record_batch = create_batch();
2682        let record_batch_slice = record_batch.slice(1, 2);
2683        let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice));
2684
2685        assert!(
2686            serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len()
2687        );
2688
2689        assert!(deserialized_batch.column(0).is_null(0));
2690        assert!(deserialized_batch.column(0).is_valid(1));
2691        assert!(deserialized_batch.column(1).is_valid(0));
2692        assert!(deserialized_batch.column(1).is_valid(1));
2693
2694        assert_eq!(record_batch_slice, deserialized_batch);
2695    }
2696
2697    #[test]
2698    fn truncate_ipc_dictionary_array() {
2699        fn create_batch() -> RecordBatch {
2700            let values: StringArray = [Some("foo"), Some("bar"), Some("baz")]
2701                .into_iter()
2702                .collect();
2703            let keys: Int32Array = [Some(0), Some(2), None, Some(1)].into_iter().collect();
2704
2705            let array = DictionaryArray::new(keys, Arc::new(values));
2706
2707            let schema = Schema::new(vec![Field::new("dict", array.data_type().clone(), true)]);
2708
2709            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap()
2710        }
2711
2712        let record_batch = create_batch();
2713        let record_batch_slice = record_batch.slice(1, 2);
2714        let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice));
2715
2716        assert!(
2717            serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len()
2718        );
2719
2720        assert!(deserialized_batch.column(0).is_valid(0));
2721        assert!(deserialized_batch.column(0).is_null(1));
2722
2723        assert_eq!(record_batch_slice, deserialized_batch);
2724    }
2725
2726    #[test]
2727    fn truncate_ipc_struct_array() {
2728        fn create_batch() -> RecordBatch {
2729            let strings: StringArray = [Some("foo"), None, Some("bar"), Some("baz")]
2730                .into_iter()
2731                .collect();
2732            let ints: Int32Array = [Some(0), Some(2), None, Some(1)].into_iter().collect();
2733
2734            let struct_array = StructArray::from(vec![
2735                (
2736                    Arc::new(Field::new("s", DataType::Utf8, true)),
2737                    Arc::new(strings) as ArrayRef,
2738                ),
2739                (
2740                    Arc::new(Field::new("c", DataType::Int32, true)),
2741                    Arc::new(ints) as ArrayRef,
2742                ),
2743            ]);
2744
2745            let schema = Schema::new(vec![Field::new(
2746                "struct_array",
2747                struct_array.data_type().clone(),
2748                true,
2749            )]);
2750
2751            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)]).unwrap()
2752        }
2753
2754        let record_batch = create_batch();
2755        let record_batch_slice = record_batch.slice(1, 2);
2756        let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice));
2757
2758        assert!(
2759            serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len()
2760        );
2761
2762        let structs = deserialized_batch
2763            .column(0)
2764            .as_any()
2765            .downcast_ref::<StructArray>()
2766            .unwrap();
2767
2768        assert!(structs.column(0).is_null(0));
2769        assert!(structs.column(0).is_valid(1));
2770        assert!(structs.column(1).is_valid(0));
2771        assert!(structs.column(1).is_null(1));
2772        assert_eq!(record_batch_slice, deserialized_batch);
2773    }
2774
2775    #[test]
2776    fn truncate_ipc_string_array_with_all_empty_string() {
2777        fn create_batch() -> RecordBatch {
2778            let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
2779            let a = StringArray::from(vec![Some(""), Some(""), Some(""), Some(""), Some("")]);
2780            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap()
2781        }
2782
2783        let record_batch = create_batch();
2784        let record_batch_slice = record_batch.slice(0, 1);
2785        let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice));
2786
2787        assert!(
2788            serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len()
2789        );
2790        assert_eq!(record_batch_slice, deserialized_batch);
2791    }
2792
2793    #[test]
2794    fn test_stream_writer_writes_array_slice() {
2795        let array = UInt32Array::from(vec![Some(1), Some(2), Some(3)]);
2796        assert_eq!(
2797            vec![Some(1), Some(2), Some(3)],
2798            array.iter().collect::<Vec<_>>()
2799        );
2800
2801        let sliced = array.slice(1, 2);
2802        assert_eq!(vec![Some(2), Some(3)], sliced.iter().collect::<Vec<_>>());
2803
2804        let batch = RecordBatch::try_new(
2805            Arc::new(Schema::new(vec![Field::new("a", DataType::UInt32, true)])),
2806            vec![Arc::new(sliced)],
2807        )
2808        .expect("new batch");
2809
2810        let mut writer = StreamWriter::try_new(vec![], batch.schema_ref()).expect("new writer");
2811        writer.write(&batch).expect("write");
2812        let outbuf = writer.into_inner().expect("inner");
2813
2814        let mut reader = StreamReader::try_new(&outbuf[..], None).expect("new reader");
2815        let read_batch = reader.next().unwrap().expect("read batch");
2816
2817        let read_array: &UInt32Array = read_batch.column(0).as_primitive();
2818        assert_eq!(
2819            vec![Some(2), Some(3)],
2820            read_array.iter().collect::<Vec<_>>()
2821        );
2822    }
2823
2824    #[test]
2825    fn test_large_slice_uint32() {
2826        ensure_roundtrip(Arc::new(UInt32Array::from_iter(
2827            (0..8000).map(|i| if i % 2 == 0 { Some(i) } else { None }),
2828        )));
2829    }
2830
2831    #[test]
2832    fn test_large_slice_string() {
2833        let strings: Vec<_> = (0..8000)
2834            .map(|i| {
2835                if i % 2 == 0 {
2836                    Some(format!("value{i}"))
2837                } else {
2838                    None
2839                }
2840            })
2841            .collect();
2842
2843        ensure_roundtrip(Arc::new(StringArray::from(strings)));
2844    }
2845
2846    #[test]
2847    fn test_large_slice_string_list() {
2848        let mut ls = ListBuilder::new(StringBuilder::new());
2849
2850        let mut s = String::new();
2851        for row_number in 0..8000 {
2852            if row_number % 2 == 0 {
2853                for list_element in 0..1000 {
2854                    s.clear();
2855                    use std::fmt::Write;
2856                    write!(&mut s, "value{row_number}-{list_element}").unwrap();
2857                    ls.values().append_value(&s);
2858                }
2859                ls.append(true)
2860            } else {
2861                ls.append(false); // null
2862            }
2863        }
2864
2865        ensure_roundtrip(Arc::new(ls.finish()));
2866    }
2867
    #[test]
    fn test_large_slice_string_list_of_lists() {
        // The reason for the special test is to verify reencode_offsets which looks both at
        // the starting offset and the data offset.  So need a dataset where the starting_offset
        // is zero but the data offset is not.
        let mut ls = ListBuilder::new(ListBuilder::new(StringBuilder::new()));

        // 4000 leading rows whose inner lists are empty: rows accumulate while
        // the string data offsets stay at zero.
        for _ in 0..4000 {
            ls.values().append(true);
            ls.append(true)
        }

        // Then 4000 rows alternating populated / null, reusing one String
        // buffer for the formatted values.
        let mut s = String::new();
        for row_number in 0..4000 {
            if row_number % 2 == 0 {
                for list_element in 0..1000 {
                    s.clear();
                    use std::fmt::Write;
                    write!(&mut s, "value{row_number}-{list_element}").unwrap();
                    ls.values().values().append_value(&s);
                }
                ls.values().append(true);
                ls.append(true)
            } else {
                ls.append(false); // null
            }
        }

        ensure_roundtrip(Arc::new(ls.finish()));
    }
2898
    /// Read/write a record batch to a File and Stream and ensure it is the same at the output
    fn ensure_roundtrip(array: ArrayRef) {
        let num_rows = array.len();
        let orig_batch = RecordBatch::try_from_iter(vec![("a", array)]).unwrap();
        // take off the first element
        let sliced_batch = orig_batch.slice(1, num_rows - 1);

        let schema = orig_batch.schema();
        // Stream round-trip of the sliced batch.
        let stream_data = {
            let mut writer = StreamWriter::try_new(vec![], &schema).unwrap();
            writer.write(&sliced_batch).unwrap();
            writer.into_inner().unwrap()
        };
        let read_batch = {
            let projection = None;
            let mut reader = StreamReader::try_new(Cursor::new(stream_data), projection).unwrap();
            reader
                .next()
                .expect("expect no errors reading batch")
                .expect("expect batch")
        };
        assert_eq!(sliced_batch, read_batch);

        // File round-trip of the same sliced batch.
        let file_data = {
            let mut writer = FileWriter::try_new_buffered(vec![], &schema).unwrap();
            writer.write(&sliced_batch).unwrap();
            writer.into_inner().unwrap().into_inner().unwrap()
        };
        let read_batch = {
            let projection = None;
            let mut reader = FileReader::try_new(Cursor::new(file_data), projection).unwrap();
            reader
                .next()
                .expect("expect no errors reading batch")
                .expect("expect batch")
        };
        assert_eq!(sliced_batch, read_batch);
    }
2939
    #[test]
    fn encode_bools_slice() {
        // Test case for https://github.com/apache/arrow-rs/issues/3496
        assert_bool_roundtrip([true, false], 1, 1);

        // slice somewhere in the middle
        assert_bool_roundtrip(
            [
                true, false, true, true, false, false, true, true, true, false, false, false, true,
                true, true, true, false, false, false, false, true, true, true, true, true, false,
                false, false, false, false,
            ],
            13,
            17,
        );

        // start at byte boundary, end in the middle
        assert_bool_roundtrip(
            [
                true, false, true, true, false, false, true, true, true, false, false, false,
            ],
            8,
            2,
        );

        // start and stop at byte boundary
        assert_bool_roundtrip(
            [
                true, false, true, true, false, false, true, true, true, false, false, false, true,
                true, true, true, true, false, false, false, false, false,
            ],
            8,
            8,
        );
    }
2975
2976    fn assert_bool_roundtrip<const N: usize>(bools: [bool; N], offset: usize, length: usize) {
2977        let val_bool_field = Field::new("val", DataType::Boolean, false);
2978
2979        let schema = Arc::new(Schema::new(vec![val_bool_field]));
2980
2981        let bools = BooleanArray::from(bools.to_vec());
2982
2983        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(bools)]).unwrap();
2984        let batch = batch.slice(offset, length);
2985
2986        let data = serialize_stream(&batch);
2987        let batch2 = deserialize_stream(data);
2988        assert_eq!(batch, batch2);
2989    }
2990
    #[test]
    fn test_run_array_unslice() {
        // Build an 80-element logical array out of 32 runs: run `ix` repeats
        // vals[ix % 7] for repeats[ix % 4] elements (3 + 4 + 1 + 2 = 10 per
        // cycle, 8 cycles = 80), so slices land at every kind of run boundary.
        let total_len = 80;
        let vals: Vec<Option<i32>> = vec![Some(1), None, Some(2), Some(3), Some(4), None, Some(5)];
        let repeats: Vec<usize> = vec![3, 4, 1, 2];
        let mut input_array: Vec<Option<i32>> = Vec::with_capacity(total_len);
        for ix in 0_usize..32 {
            let repeat: usize = repeats[ix % repeats.len()];
            let val: Option<i32> = vals[ix % vals.len()];
            input_array.resize(input_array.len() + repeat, val);
        }

        // Encode the input_array to run array
        let mut builder =
            PrimitiveRunBuilder::<Int16Type, Int32Type>::with_capacity(input_array.len());
        builder.extend(input_array.iter().copied());
        let run_array = builder.finish();

        // test for all slice lengths.
        for slice_len in 1..=total_len {
            // test for offset = 0, slice length = slice_len
            let sliced_run_array: RunArray<Int16Type> =
                run_array.slice(0, slice_len).into_data().into();

            // Create unsliced run array.
            let unsliced_run_array = into_zero_offset_run_array(sliced_run_array).unwrap();
            let typed = unsliced_run_array
                .downcast::<PrimitiveArray<Int32Type>>()
                .unwrap();
            let expected: Vec<Option<i32>> = input_array.iter().take(slice_len).copied().collect();
            let actual: Vec<Option<i32>> = typed.into_iter().collect();
            assert_eq!(expected, actual);

            // test for offset = total_len - slice_len, length = slice_len
            let sliced_run_array: RunArray<Int16Type> = run_array
                .slice(total_len - slice_len, slice_len)
                .into_data()
                .into();

            // Create unsliced run array.
            let unsliced_run_array = into_zero_offset_run_array(sliced_run_array).unwrap();
            let typed = unsliced_run_array
                .downcast::<PrimitiveArray<Int32Type>>()
                .unwrap();
            let expected: Vec<Option<i32>> = input_array
                .iter()
                .skip(total_len - slice_len)
                .copied()
                .collect();
            let actual: Vec<Option<i32>> = typed.into_iter().collect();
            assert_eq!(expected, actual);
        }
    }
3044
3045    fn generate_list_data<O: OffsetSizeTrait>() -> GenericListArray<O> {
3046        let mut ls = GenericListBuilder::<O, _>::new(UInt32Builder::new());
3047
3048        for i in 0..100_000 {
3049            for value in [i, i, i] {
3050                ls.values().append_value(value);
3051            }
3052            ls.append(true)
3053        }
3054
3055        ls.finish()
3056    }
3057
3058    fn generate_utf8view_list_data<O: OffsetSizeTrait>() -> GenericListArray<O> {
3059        let mut ls = GenericListBuilder::<O, _>::new(StringViewBuilder::new());
3060
3061        for i in 0..100_000 {
3062            for value in [
3063                format!("value{}", i),
3064                format!("value{}", i),
3065                format!("value{}", i),
3066            ] {
3067                ls.values().append_value(&value);
3068            }
3069            ls.append(true)
3070        }
3071
3072        ls.finish()
3073    }
3074
3075    fn generate_string_list_data<O: OffsetSizeTrait>() -> GenericListArray<O> {
3076        let mut ls = GenericListBuilder::<O, _>::new(StringBuilder::new());
3077
3078        for i in 0..100_000 {
3079            for value in [
3080                format!("value{}", i),
3081                format!("value{}", i),
3082                format!("value{}", i),
3083            ] {
3084                ls.values().append_value(&value);
3085            }
3086            ls.append(true)
3087        }
3088
3089        ls.finish()
3090    }
3091
3092    fn generate_nested_list_data<O: OffsetSizeTrait>() -> GenericListArray<O> {
3093        let mut ls =
3094            GenericListBuilder::<O, _>::new(GenericListBuilder::<O, _>::new(UInt32Builder::new()));
3095
3096        for _i in 0..10_000 {
3097            for j in 0..10 {
3098                for value in [j, j, j, j] {
3099                    ls.values().values().append_value(value);
3100                }
3101                ls.values().append(true)
3102            }
3103            ls.append(true);
3104        }
3105
3106        ls.finish()
3107    }
3108
    // Like `generate_nested_list_data`, but prefixed with 999 rows of empty
    // inner lists so a later slice can have a starting offset of zero while the
    // inner data offset is non-zero (the case `reencode_offsets` distinguishes).
    fn generate_nested_list_data_starting_at_zero<O: OffsetSizeTrait>() -> GenericListArray<O> {
        let mut ls =
            GenericListBuilder::<O, _>::new(GenericListBuilder::<O, _>::new(UInt32Builder::new()));

        // 999 rows of empty inner lists: no values appended, offsets stay 0.
        for _i in 0..999 {
            ls.values().append(true);
            ls.append(true);
        }

        // One populated row: ten inner lists of [j, j, j, j].
        for j in 0..10 {
            for value in [j, j, j, j] {
                ls.values().values().append_value(value);
            }
            ls.values().append(true)
        }
        ls.append(true);

        // 9,000 more populated rows.
        for i in 0..9_000 {
            for j in 0..10 {
                for value in [i + j, i + j, i + j, i + j] {
                    ls.values().values().append_value(value);
                }
                ls.values().append(true)
            }
            ls.append(true);
        }

        ls.finish()
    }
3138
3139    fn generate_map_array_data() -> MapArray {
3140        let keys_builder = UInt32Builder::new();
3141        let values_builder = UInt32Builder::new();
3142
3143        let mut builder = MapBuilder::new(None, keys_builder, values_builder);
3144
3145        for i in 0..100_000 {
3146            for _j in 0..3 {
3147                builder.keys().append_value(i);
3148                builder.values().append_value(i * 2);
3149            }
3150            builder.append(true).unwrap();
3151        }
3152
3153        builder.finish()
3154    }
3155
    #[test]
    fn reencode_offsets_when_first_offset_is_not_zero() {
        // Each generated list row holds 3 values, so slicing rows [75, 82)
        // starts at value offset 75 * 3 = 225 and spans 7 * 3 = 21 values.
        let original_list = generate_list_data::<i32>();
        let original_data = original_list.into_data();
        let slice_data = original_data.slice(75, 7);
        let (new_offsets, original_start, length) =
            reencode_offsets::<i32>(&slice_data.buffers()[0], &slice_data);
        // Offsets must be rebased so the slice starts at zero.
        assert_eq!(
            vec![0, 3, 6, 9, 12, 15, 18, 21],
            new_offsets.typed_data::<i32>()
        );
        assert_eq!(225, original_start);
        assert_eq!(21, length);
    }
3170
    #[test]
    fn reencode_offsets_when_first_offset_is_zero() {
        let mut ls = GenericListBuilder::<i32, _>::new(UInt32Builder::new());
        // ls = [[], [35, 42]]
        ls.append(true);
        ls.values().append_value(35);
        ls.values().append_value(42);
        ls.append(true);
        let original_list = ls.finish();
        let original_data = original_list.into_data();

        // Slicing off the empty first row leaves a slice whose first offset is
        // already 0, so the value range starts at the beginning of the data.
        let slice_data = original_data.slice(1, 1);
        let (new_offsets, original_start, length) =
            reencode_offsets::<i32>(&slice_data.buffers()[0], &slice_data);
        assert_eq!(vec![0, 2], new_offsets.typed_data::<i32>());
        assert_eq!(0, original_start);
        assert_eq!(2, length);
    }
3189
3190    /// Ensure when serde full & sliced versions they are equal to original input.
3191    /// Also ensure serialized sliced version is significantly smaller than serialized full.
3192    fn roundtrip_ensure_sliced_smaller(in_batch: RecordBatch, expected_size_factor: usize) {
3193        // test both full and sliced versions
3194        let in_sliced = in_batch.slice(999, 1);
3195
3196        let bytes_batch = serialize_file(&in_batch);
3197        let bytes_sliced = serialize_file(&in_sliced);
3198
3199        // serializing 1 row should be significantly smaller than serializing 100,000
3200        assert!(bytes_sliced.len() < (bytes_batch.len() / expected_size_factor));
3201
3202        // ensure both are still valid and equal to originals
3203        let out_batch = deserialize_file(bytes_batch);
3204        assert_eq!(in_batch, out_batch);
3205
3206        let out_sliced = deserialize_file(bytes_sliced);
3207        assert_eq!(in_sliced, out_sliced);
3208    }
3209
3210    #[test]
3211    fn encode_lists() {
3212        let val_inner = Field::new_list_field(DataType::UInt32, true);
3213        let val_list_field = Field::new("val", DataType::List(Arc::new(val_inner)), false);
3214        let schema = Arc::new(Schema::new(vec![val_list_field]));
3215
3216        let values = Arc::new(generate_list_data::<i32>());
3217
3218        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
3219        roundtrip_ensure_sliced_smaller(in_batch, 1000);
3220    }
3221
3222    #[test]
3223    fn encode_empty_list() {
3224        let val_inner = Field::new_list_field(DataType::UInt32, true);
3225        let val_list_field = Field::new("val", DataType::List(Arc::new(val_inner)), false);
3226        let schema = Arc::new(Schema::new(vec![val_list_field]));
3227
3228        let values = Arc::new(generate_list_data::<i32>());
3229
3230        let in_batch = RecordBatch::try_new(schema, vec![values])
3231            .unwrap()
3232            .slice(999, 0);
3233        let out_batch = deserialize_file(serialize_file(&in_batch));
3234        assert_eq!(in_batch, out_batch);
3235    }
3236
3237    #[test]
3238    fn encode_large_lists() {
3239        let val_inner = Field::new_list_field(DataType::UInt32, true);
3240        let val_list_field = Field::new("val", DataType::LargeList(Arc::new(val_inner)), false);
3241        let schema = Arc::new(Schema::new(vec![val_list_field]));
3242
3243        let values = Arc::new(generate_list_data::<i64>());
3244
3245        // ensure when serde full & sliced versions they are equal to original input
3246        // also ensure serialized sliced version is significantly smaller than serialized full
3247        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
3248        roundtrip_ensure_sliced_smaller(in_batch, 1000);
3249    }
3250
3251    #[test]
3252    fn encode_large_lists_non_zero_offset() {
3253        let val_inner = Field::new_list_field(DataType::UInt32, true);
3254        let val_list_field = Field::new("val", DataType::LargeList(Arc::new(val_inner)), false);
3255        let schema = Arc::new(Schema::new(vec![val_list_field]));
3256
3257        let values = Arc::new(generate_list_data::<i64>());
3258
3259        check_sliced_list_array(schema, values);
3260    }
3261
3262    #[test]
3263    fn encode_large_lists_string_non_zero_offset() {
3264        let val_inner = Field::new_list_field(DataType::Utf8, true);
3265        let val_list_field = Field::new("val", DataType::LargeList(Arc::new(val_inner)), false);
3266        let schema = Arc::new(Schema::new(vec![val_list_field]));
3267
3268        let values = Arc::new(generate_string_list_data::<i64>());
3269
3270        check_sliced_list_array(schema, values);
3271    }
3272
3273    #[test]
3274    fn encode_large_list_string_view_non_zero_offset() {
3275        let val_inner = Field::new_list_field(DataType::Utf8View, true);
3276        let val_list_field = Field::new("val", DataType::LargeList(Arc::new(val_inner)), false);
3277        let schema = Arc::new(Schema::new(vec![val_list_field]));
3278
3279        let values = Arc::new(generate_utf8view_list_data::<i64>());
3280
3281        check_sliced_list_array(schema, values);
3282    }
3283
3284    fn check_sliced_list_array(schema: Arc<Schema>, values: Arc<GenericListArray<i64>>) {
3285        for (offset, len) in [(999, 1), (0, 13), (47, 12), (values.len() - 13, 13)] {
3286            let in_batch = RecordBatch::try_new(schema.clone(), vec![values.clone()])
3287                .unwrap()
3288                .slice(offset, len);
3289            let out_batch = deserialize_file(serialize_file(&in_batch));
3290            assert_eq!(in_batch, out_batch);
3291        }
3292    }
3293
3294    #[test]
3295    fn encode_nested_lists() {
3296        let inner_int = Arc::new(Field::new_list_field(DataType::UInt32, true));
3297        let inner_list_field = Arc::new(Field::new_list_field(DataType::List(inner_int), true));
3298        let list_field = Field::new("val", DataType::List(inner_list_field), true);
3299        let schema = Arc::new(Schema::new(vec![list_field]));
3300
3301        let values = Arc::new(generate_nested_list_data::<i32>());
3302
3303        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
3304        roundtrip_ensure_sliced_smaller(in_batch, 1000);
3305    }
3306
3307    #[test]
3308    fn encode_nested_lists_starting_at_zero() {
3309        let inner_int = Arc::new(Field::new("item", DataType::UInt32, true));
3310        let inner_list_field = Arc::new(Field::new("item", DataType::List(inner_int), true));
3311        let list_field = Field::new("val", DataType::List(inner_list_field), true);
3312        let schema = Arc::new(Schema::new(vec![list_field]));
3313
3314        let values = Arc::new(generate_nested_list_data_starting_at_zero::<i32>());
3315
3316        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
3317        roundtrip_ensure_sliced_smaller(in_batch, 1);
3318    }
3319
3320    #[test]
3321    fn encode_map_array() {
3322        let keys = Arc::new(Field::new("keys", DataType::UInt32, false));
3323        let values = Arc::new(Field::new("values", DataType::UInt32, true));
3324        let map_field = Field::new_map("map", "entries", keys, values, false, true);
3325        let schema = Arc::new(Schema::new(vec![map_field]));
3326
3327        let values = Arc::new(generate_map_array_data());
3328
3329        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
3330        roundtrip_ensure_sliced_smaller(in_batch, 1000);
3331    }
3332
3333    fn generate_list_view_data<O: OffsetSizeTrait>() -> GenericListViewArray<O> {
3334        let mut builder = GenericListViewBuilder::<O, _>::new(UInt32Builder::new());
3335
3336        for i in 0u32..100_000 {
3337            if i.is_multiple_of(10_000) {
3338                builder.append(false);
3339                continue;
3340            }
3341            for value in [i, i, i] {
3342                builder.values().append_value(value);
3343            }
3344            builder.append(true);
3345        }
3346
3347        builder.finish()
3348    }
3349
3350    #[test]
3351    fn encode_list_view_arrays() {
3352        let val_inner = Field::new_list_field(DataType::UInt32, true);
3353        let val_field = Field::new("val", DataType::ListView(Arc::new(val_inner)), true);
3354        let schema = Arc::new(Schema::new(vec![val_field]));
3355
3356        let values = Arc::new(generate_list_view_data::<i32>());
3357
3358        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
3359        let out_batch = deserialize_file(serialize_file(&in_batch));
3360        assert_eq!(in_batch, out_batch);
3361    }
3362
3363    #[test]
3364    fn encode_large_list_view_arrays() {
3365        let val_inner = Field::new_list_field(DataType::UInt32, true);
3366        let val_field = Field::new("val", DataType::LargeListView(Arc::new(val_inner)), true);
3367        let schema = Arc::new(Schema::new(vec![val_field]));
3368
3369        let values = Arc::new(generate_list_view_data::<i64>());
3370
3371        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
3372        let out_batch = deserialize_file(serialize_file(&in_batch));
3373        assert_eq!(in_batch, out_batch);
3374    }
3375
3376    #[test]
3377    fn check_sliced_list_view_array() {
3378        let inner = Field::new_list_field(DataType::UInt32, true);
3379        let field = Field::new("val", DataType::ListView(Arc::new(inner)), true);
3380        let schema = Arc::new(Schema::new(vec![field]));
3381        let values = Arc::new(generate_list_view_data::<i32>());
3382
3383        for (offset, len) in [(999, 1), (0, 13), (47, 12), (values.len() - 13, 13)] {
3384            let in_batch = RecordBatch::try_new(schema.clone(), vec![values.clone()])
3385                .unwrap()
3386                .slice(offset, len);
3387            let out_batch = deserialize_file(serialize_file(&in_batch));
3388            assert_eq!(in_batch, out_batch);
3389        }
3390    }
3391
3392    #[test]
3393    fn check_sliced_large_list_view_array() {
3394        let inner = Field::new_list_field(DataType::UInt32, true);
3395        let field = Field::new("val", DataType::LargeListView(Arc::new(inner)), true);
3396        let schema = Arc::new(Schema::new(vec![field]));
3397        let values = Arc::new(generate_list_view_data::<i64>());
3398
3399        for (offset, len) in [(999, 1), (0, 13), (47, 12), (values.len() - 13, 13)] {
3400            let in_batch = RecordBatch::try_new(schema.clone(), vec![values.clone()])
3401                .unwrap()
3402                .slice(offset, len);
3403            let out_batch = deserialize_file(serialize_file(&in_batch));
3404            assert_eq!(in_batch, out_batch);
3405        }
3406    }
3407
3408    fn generate_nested_list_view_data<O: OffsetSizeTrait>() -> GenericListViewArray<O> {
3409        let inner_builder = UInt32Builder::new();
3410        let middle_builder = GenericListViewBuilder::<O, _>::new(inner_builder);
3411        let mut outer_builder = GenericListViewBuilder::<O, _>::new(middle_builder);
3412
3413        for i in 0u32..10_000 {
3414            if i.is_multiple_of(1_000) {
3415                outer_builder.append(false);
3416                continue;
3417            }
3418
3419            for _ in 0..3 {
3420                for value in [i, i + 1, i + 2] {
3421                    outer_builder.values().values().append_value(value);
3422                }
3423                outer_builder.values().append(true);
3424            }
3425            outer_builder.append(true);
3426        }
3427
3428        outer_builder.finish()
3429    }
3430
3431    #[test]
3432    fn encode_nested_list_views() {
3433        let inner_int = Arc::new(Field::new_list_field(DataType::UInt32, true));
3434        let inner_list_field = Arc::new(Field::new_list_field(DataType::ListView(inner_int), true));
3435        let list_field = Field::new("val", DataType::ListView(inner_list_field), true);
3436        let schema = Arc::new(Schema::new(vec![list_field]));
3437
3438        let values = Arc::new(generate_nested_list_view_data::<i32>());
3439
3440        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
3441        let out_batch = deserialize_file(serialize_file(&in_batch));
3442        assert_eq!(in_batch, out_batch);
3443    }
3444
3445    fn test_roundtrip_list_view_of_dict_impl<OffsetSize: OffsetSizeTrait, U: ArrowNativeType>(
3446        list_data_type: DataType,
3447        offsets: &[U; 5],
3448        sizes: &[U; 4],
3449    ) {
3450        let values = StringArray::from(vec![Some("alpha"), None, Some("beta"), Some("gamma")]);
3451        let keys = Int32Array::from_iter_values([0, 0, 1, 2, 3, 0, 2]);
3452        let dict_array = DictionaryArray::new(keys, Arc::new(values));
3453        let dict_data = dict_array.to_data();
3454
3455        let value_offsets = Buffer::from_slice_ref(offsets);
3456        let value_sizes = Buffer::from_slice_ref(sizes);
3457
3458        let list_data = ArrayData::builder(list_data_type)
3459            .len(4)
3460            .add_buffer(value_offsets)
3461            .add_buffer(value_sizes)
3462            .add_child_data(dict_data)
3463            .build()
3464            .unwrap();
3465        let list_view_array = GenericListViewArray::<OffsetSize>::from(list_data);
3466
3467        let schema = Arc::new(Schema::new(vec![Field::new(
3468            "f1",
3469            list_view_array.data_type().clone(),
3470            false,
3471        )]));
3472        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(list_view_array)]).unwrap();
3473
3474        let output_batch = deserialize_file(serialize_file(&input_batch));
3475        assert_eq!(input_batch, output_batch);
3476
3477        let output_batch = deserialize_stream(serialize_stream(&input_batch));
3478        assert_eq!(input_batch, output_batch);
3479    }
3480
3481    #[test]
3482    fn test_roundtrip_list_view_of_dict() {
3483        #[allow(deprecated)]
3484        let list_data_type = DataType::ListView(Arc::new(Field::new_dict(
3485            "item",
3486            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
3487            true,
3488            1,
3489            false,
3490        )));
3491        let offsets: &[i32; 5] = &[0, 2, 4, 4, 7];
3492        let sizes: &[i32; 4] = &[2, 2, 0, 3];
3493        test_roundtrip_list_view_of_dict_impl::<i32, i32>(list_data_type, offsets, sizes);
3494    }
3495
3496    #[test]
3497    fn test_roundtrip_large_list_view_of_dict() {
3498        #[allow(deprecated)]
3499        let list_data_type = DataType::LargeListView(Arc::new(Field::new_dict(
3500            "item",
3501            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
3502            true,
3503            2,
3504            false,
3505        )));
3506        let offsets: &[i64; 5] = &[0, 2, 4, 4, 7];
3507        let sizes: &[i64; 4] = &[2, 2, 0, 3];
3508        test_roundtrip_list_view_of_dict_impl::<i64, i64>(list_data_type, offsets, sizes);
3509    }
3510
    #[test]
    fn test_roundtrip_sliced_list_view_of_dict() {
        // ListView whose item field is a dictionary-encoded Utf8 (dict id 3).
        #[allow(deprecated)]
        let list_data_type = DataType::ListView(Arc::new(Field::new_dict(
            "item",
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            true,
            3,
            false,
        )));

        // Dictionary child: 12 keys into 4 string values (one of them null).
        let values = StringArray::from(vec![Some("alpha"), None, Some("beta"), Some("gamma")]);
        let keys = Int32Array::from_iter_values([0, 0, 1, 2, 3, 0, 2, 1, 0, 3, 2, 1]);
        let dict_array = DictionaryArray::new(keys, Arc::new(values));
        let dict_data = dict_array.to_data();

        // Raw list-view buffers for 6 lists (one of them empty, size 0).
        let offsets: &[i32; 7] = &[0, 2, 4, 4, 7, 9, 12];
        let sizes: &[i32; 6] = &[2, 2, 0, 3, 2, 3];
        let value_offsets = Buffer::from_slice_ref(offsets);
        let value_sizes = Buffer::from_slice_ref(sizes);

        let list_data = ArrayData::builder(list_data_type)
            .len(6)
            .add_buffer(value_offsets)
            .add_buffer(value_sizes)
            .add_child_data(dict_data)
            .build()
            .unwrap();
        let list_view_array = GenericListViewArray::<i32>::from(list_data);

        let schema = Arc::new(Schema::new(vec![Field::new(
            "f1",
            list_view_array.data_type().clone(),
            false,
        )]));
        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(list_view_array)]).unwrap();

        // Slice away the first and last rows so the serialized batch has a
        // non-zero offset; the slice must still round-trip exactly.
        let sliced_batch = input_batch.slice(1, 4);

        let output_batch = deserialize_file(serialize_file(&sliced_batch));
        assert_eq!(sliced_batch, output_batch);

        let output_batch = deserialize_stream(serialize_stream(&sliced_batch));
        assert_eq!(sliced_batch, output_batch);
    }
3556
    #[test]
    fn test_roundtrip_dense_union_of_dict() {
        // Dictionary child: 7 keys into 4 string values (one of them null).
        let values = StringArray::from(vec![Some("alpha"), None, Some("beta"), Some("gamma")]);
        let keys = Int32Array::from_iter_values([0, 0, 1, 2, 3, 0, 2]);
        let dict_array = DictionaryArray::new(keys, Arc::new(values));

        #[allow(deprecated)]
        let dict_field = Arc::new(Field::new_dict(
            "dict",
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            true,
            1,
            false,
        ));
        let int_field = Arc::new(Field::new("int", DataType::Int32, false));
        // Type id 0 -> dict child, type id 1 -> int child.
        let union_fields = UnionFields::try_new(vec![0, 1], vec![dict_field, int_field]).unwrap();

        // Dense union layout: `types` selects the child per row and `offsets`
        // indexes into that child's values.
        let types = ScalarBuffer::from(vec![0i8, 0, 1, 0, 1, 0, 0]);
        let offsets = ScalarBuffer::from(vec![0i32, 1, 0, 2, 1, 3, 4]);

        let int_array = Int32Array::from(vec![100, 200]);

        let union = UnionArray::try_new(
            union_fields.clone(),
            types,
            Some(offsets),
            vec![Arc::new(dict_array), Arc::new(int_array)],
        )
        .unwrap();

        let schema = Arc::new(Schema::new(vec![Field::new(
            "union",
            DataType::Union(union_fields, UnionMode::Dense),
            false,
        )]));
        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(union)]).unwrap();

        // Round-trip through both the file and stream formats.
        let output_batch = deserialize_file(serialize_file(&input_batch));
        assert_eq!(input_batch, output_batch);

        let output_batch = deserialize_stream(serialize_stream(&input_batch));
        assert_eq!(input_batch, output_batch);
    }
3600
    #[test]
    fn test_roundtrip_sparse_union_of_dict() {
        // Dictionary child: 7 keys into 4 string values (one of them null).
        let values = StringArray::from(vec![Some("alpha"), None, Some("beta"), Some("gamma")]);
        let keys = Int32Array::from_iter_values([0, 0, 1, 2, 3, 0, 2]);
        let dict_array = DictionaryArray::new(keys, Arc::new(values));

        #[allow(deprecated)]
        let dict_field = Arc::new(Field::new_dict(
            "dict",
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            true,
            2,
            false,
        ));
        let int_field = Arc::new(Field::new("int", DataType::Int32, false));
        // Type id 0 -> dict child, type id 1 -> int child.
        let union_fields = UnionFields::try_new(vec![0, 1], vec![dict_field, int_field]).unwrap();

        let types = ScalarBuffer::from(vec![0i8, 0, 1, 0, 1, 0, 0]);

        // Sparse union: no offsets buffer, so every child is union-length (7);
        // positions not selected by `types` carry placeholder zeros.
        let int_array = Int32Array::from(vec![0, 0, 100, 0, 200, 0, 0]);

        let union = UnionArray::try_new(
            union_fields.clone(),
            types,
            None,
            vec![Arc::new(dict_array), Arc::new(int_array)],
        )
        .unwrap();

        let schema = Arc::new(Schema::new(vec![Field::new(
            "union",
            DataType::Union(union_fields, UnionMode::Sparse),
            false,
        )]));
        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(union)]).unwrap();

        // Round-trip through both the file and stream formats.
        let output_batch = deserialize_file(serialize_file(&input_batch));
        assert_eq!(input_batch, output_batch);

        let output_batch = deserialize_stream(serialize_stream(&input_batch));
        assert_eq!(input_batch, output_batch);
    }
3643
    #[test]
    fn test_roundtrip_map_with_dict_keys() {
        // Building a map array is a bit involved. We first build a struct array that has a key and
        // value field and then use that to build the actual map array.
        let key_values = StringArray::from(vec!["key_a", "key_b", "key_c"]);
        let keys = Int32Array::from_iter_values([0, 1, 2, 0, 1, 0]);
        let dict_keys = DictionaryArray::new(keys, Arc::new(key_values));

        let values = Int32Array::from(vec![1, 2, 3, 4, 5, 6]);

        #[allow(deprecated)]
        let entries_field = Arc::new(Field::new(
            "entries",
            DataType::Struct(
                vec![
                    Field::new_dict(
                        "key",
                        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                        false,
                        1,
                        false,
                    ),
                    Field::new("value", DataType::Int32, true),
                ]
                .into(),
            ),
            false,
        ));

        let entries = StructArray::from(vec![
            (
                Arc::new(Field::new(
                    "key",
                    DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                    false,
                )),
                Arc::new(dict_keys) as ArrayRef,
            ),
            (
                Arc::new(Field::new("value", DataType::Int32, true)),
                Arc::new(values) as ArrayRef,
            ),
        ]);

        // Three map entries of two key/value pairs each.
        let offsets = Buffer::from_slice_ref([0i32, 2, 4, 6]);

        let map_data = ArrayData::builder(DataType::Map(entries_field, false))
            .len(3)
            .add_buffer(offsets)
            .add_child_data(entries.into_data())
            .build()
            .unwrap();
        let map_array = MapArray::from(map_data);

        let schema = Arc::new(Schema::new(vec![Field::new(
            "map",
            map_array.data_type().clone(),
            false,
        )]));
        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(map_array)]).unwrap();

        // Round-trip through both the file and stream formats.
        let output_batch = deserialize_file(serialize_file(&input_batch));
        assert_eq!(input_batch, output_batch);

        let output_batch = deserialize_stream(serialize_stream(&input_batch));
        assert_eq!(input_batch, output_batch);
    }
3711
    #[test]
    fn test_roundtrip_map_with_dict_values() {
        // Building a map array is a bit involved. We first build a struct array that has a key and
        // value field and then use that to build the actual map array.
        let keys = StringArray::from(vec!["a", "b", "c", "d", "e", "f"]);

        let value_values = StringArray::from(vec!["val_x", "val_y", "val_z"]);
        let value_keys = Int32Array::from_iter_values([0, 1, 2, 0, 1, 0]);
        let dict_values = DictionaryArray::new(value_keys, Arc::new(value_values));

        #[allow(deprecated)]
        let entries_field = Arc::new(Field::new(
            "entries",
            DataType::Struct(
                vec![
                    Field::new("key", DataType::Utf8, false),
                    Field::new_dict(
                        "value",
                        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                        true,
                        2,
                        false,
                    ),
                ]
                .into(),
            ),
            false,
        ));

        let entries = StructArray::from(vec![
            (
                Arc::new(Field::new("key", DataType::Utf8, false)),
                Arc::new(keys) as ArrayRef,
            ),
            (
                Arc::new(Field::new(
                    "value",
                    DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                    true,
                )),
                Arc::new(dict_values) as ArrayRef,
            ),
        ]);

        // Three map entries of two key/value pairs each.
        let offsets = Buffer::from_slice_ref([0i32, 2, 4, 6]);

        let map_data = ArrayData::builder(DataType::Map(entries_field, false))
            .len(3)
            .add_buffer(offsets)
            .add_child_data(entries.into_data())
            .build()
            .unwrap();
        let map_array = MapArray::from(map_data);

        let schema = Arc::new(Schema::new(vec![Field::new(
            "map",
            map_array.data_type().clone(),
            false,
        )]));
        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(map_array)]).unwrap();

        // Round-trip through both the file and stream formats.
        let output_batch = deserialize_file(serialize_file(&input_batch));
        assert_eq!(input_batch, output_batch);

        let output_batch = deserialize_stream(serialize_stream(&input_batch));
        assert_eq!(input_batch, output_batch);
    }
3779
    #[test]
    fn test_decimal128_alignment16_is_sufficient() {
        const IPC_ALIGNMENT: usize = 16;

        // Test a bunch of different dimensions to ensure alignment is never an issue.
        // For example, if we only test `num_cols = 1` then even with alignment 8 this
        // test would _happen_ to pass, even though for different dimensions like
        // `num_cols = 2` it would fail.
        for num_cols in [1, 2, 3, 17, 50, 73, 99] {
            let num_rows = (num_cols * 7 + 11) % 100; // Deterministic swizzle

            let mut fields = Vec::new();
            let mut arrays = Vec::new();
            for i in 0..num_cols {
                let field = Field::new(format!("col_{i}"), DataType::Decimal128(38, 10), true);
                let array = Decimal128Array::from(vec![num_cols as i128; num_rows]);
                fields.push(field);
                arrays.push(Arc::new(array) as Arc<dyn Array>);
            }
            let schema = Schema::new(fields);
            let batch = RecordBatch::try_new(Arc::new(schema), arrays).unwrap();

            let mut writer = FileWriter::try_new_with_options(
                Vec::new(),
                batch.schema_ref(),
                IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(),
            )
            .unwrap();
            writer.write(&batch).unwrap();
            writer.finish().unwrap();

            let out: Vec<u8> = writer.into_inner().unwrap();

            let buffer = Buffer::from_vec(out);
            // The file trailer is the 4-byte footer length followed by the
            // 6-byte "ARROW1" magic, hence the last 10 bytes.
            let trailer_start = buffer.len() - 10;
            let footer_len =
                read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap();
            let footer =
                root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap();

            let schema = fb_to_schema(footer.schema().unwrap());

            // Importantly we set `require_alignment`, checking that 16-byte alignment is sufficient
            // for `read_record_batch` later on to read the data in a zero-copy manner.
            let decoder =
                FileDecoder::new(Arc::new(schema), footer.version()).with_require_alignment(true);

            let batches = footer.recordBatches().unwrap();

            let block = batches.get(0);
            let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
            let data = buffer.slice_with_length(block.offset() as _, block_len);

            let batch2 = decoder.read_record_batch(block, &data).unwrap().unwrap();

            assert_eq!(batch, batch2);
        }
    }
3838
    #[test]
    fn test_decimal128_alignment8_is_unaligned() {
        const IPC_ALIGNMENT: usize = 8;

        // Decimal128 buffers expect 16-byte alignment (see the asserted error
        // below), so writing with 8-byte alignment must fail a strict read.
        let num_cols = 2;
        let num_rows = 1;

        let mut fields = Vec::new();
        let mut arrays = Vec::new();
        for i in 0..num_cols {
            let field = Field::new(format!("col_{i}"), DataType::Decimal128(38, 10), true);
            let array = Decimal128Array::from(vec![num_cols as i128; num_rows]);
            fields.push(field);
            arrays.push(Arc::new(array) as Arc<dyn Array>);
        }
        let schema = Schema::new(fields);
        let batch = RecordBatch::try_new(Arc::new(schema), arrays).unwrap();

        let mut writer = FileWriter::try_new_with_options(
            Vec::new(),
            batch.schema_ref(),
            IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(),
        )
        .unwrap();
        writer.write(&batch).unwrap();
        writer.finish().unwrap();

        let out: Vec<u8> = writer.into_inner().unwrap();

        let buffer = Buffer::from_vec(out);
        // The file trailer is the 4-byte footer length followed by the 6-byte
        // "ARROW1" magic, hence the last 10 bytes.
        let trailer_start = buffer.len() - 10;
        let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap();
        let footer = root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap();
        let schema = fb_to_schema(footer.schema().unwrap());

        // Importantly we set `require_alignment`, otherwise the error later is suppressed due to copying
        // to an aligned buffer in `ArrayDataBuilder.build_aligned`.
        let decoder =
            FileDecoder::new(Arc::new(schema), footer.version()).with_require_alignment(true);

        let batches = footer.recordBatches().unwrap();

        let block = batches.get(0);
        let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
        let data = buffer.slice_with_length(block.offset() as _, block_len);

        let result = decoder.read_record_batch(block, &data);

        let error = result.unwrap_err();
        assert_eq!(
            error.to_string(),
            "Invalid argument error: Misaligned buffers[0] in array of type Decimal128(38, 10), \
             offset from expected alignment of 16 by 8"
        );
    }
3894
    #[test]
    fn test_flush() {
        // We write a schema which is small enough to fit into a buffer and not get flushed,
        // and then force the write with .flush().
        let num_cols = 2;
        let mut fields = Vec::new();
        let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap();
        for i in 0..num_cols {
            let field = Field::new(format!("col_{i}"), DataType::Decimal128(38, 10), true);
            fields.push(field);
        }
        let schema = Schema::new(fields);
        // 1024-byte BufWriters so small writes stay buffered until flushed.
        let inner_stream_writer = BufWriter::with_capacity(1024, Vec::new());
        let inner_file_writer = BufWriter::with_capacity(1024, Vec::new());
        let mut stream_writer =
            StreamWriter::try_new_with_options(inner_stream_writer, &schema, options.clone())
                .unwrap();
        let mut file_writer =
            FileWriter::try_new_with_options(inner_file_writer, &schema, options).unwrap();

        // Snapshot how many bytes reached the inner Vec before and after the
        // explicit flush (get_ref().get_ref() reaches through the BufWriter).
        let stream_bytes_written_on_new = stream_writer.get_ref().get_ref().len();
        let file_bytes_written_on_new = file_writer.get_ref().get_ref().len();
        stream_writer.flush().unwrap();
        file_writer.flush().unwrap();
        let stream_bytes_written_on_flush = stream_writer.get_ref().get_ref().len();
        let file_bytes_written_on_flush = file_writer.get_ref().get_ref().len();
        let stream_out = stream_writer.into_inner().unwrap().into_inner().unwrap();
        // Finishing a stream writes the continuation bytes in MetadataVersion::V5 (4 bytes)
        // and then a length of 0 (4 bytes) for a total of 8 bytes.
        // Everything before that should have been flushed in the .flush() call.
        let expected_stream_flushed_bytes = stream_out.len() - 8;
        // A file write is the same as the stream write except for the leading magic string
        // ARROW1 plus padding, which is 8 bytes.
        let expected_file_flushed_bytes = expected_stream_flushed_bytes + 8;

        assert!(
            stream_bytes_written_on_new < stream_bytes_written_on_flush,
            "this test makes no sense if flush is not actually required"
        );
        assert!(
            file_bytes_written_on_new < file_bytes_written_on_flush,
            "this test makes no sense if flush is not actually required"
        );
        assert_eq!(stream_bytes_written_on_flush, expected_stream_flushed_bytes);
        assert_eq!(file_bytes_written_on_flush, expected_file_flushed_bytes);
    }
3941
3942    #[test]
3943    fn test_roundtrip_list_of_fixed_list() -> Result<(), ArrowError> {
3944        let l1_type =
3945            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, false)), 3);
3946        let l2_type = DataType::List(Arc::new(Field::new("item", l1_type.clone(), false)));
3947
3948        let l0_builder = Float32Builder::new();
3949        let l1_builder = FixedSizeListBuilder::new(l0_builder, 3).with_field(Arc::new(Field::new(
3950            "item",
3951            DataType::Float32,
3952            false,
3953        )));
3954        let mut l2_builder =
3955            ListBuilder::new(l1_builder).with_field(Arc::new(Field::new("item", l1_type, false)));
3956
3957        for point in [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] {
3958            l2_builder.values().values().append_value(point[0]);
3959            l2_builder.values().values().append_value(point[1]);
3960            l2_builder.values().values().append_value(point[2]);
3961
3962            l2_builder.values().append(true);
3963        }
3964        l2_builder.append(true);
3965
3966        let point = [10., 11., 12.];
3967        l2_builder.values().values().append_value(point[0]);
3968        l2_builder.values().values().append_value(point[1]);
3969        l2_builder.values().values().append_value(point[2]);
3970
3971        l2_builder.values().append(true);
3972        l2_builder.append(true);
3973
3974        let array = Arc::new(l2_builder.finish()) as ArrayRef;
3975
3976        let schema = Arc::new(Schema::new_with_metadata(
3977            vec![Field::new("points", l2_type, false)],
3978            HashMap::default(),
3979        ));
3980
3981        // Test a variety of combinations that include 0 and non-zero offsets
3982        // and also portions or the rest of the array
3983        test_slices(&array, &schema, 0, 1)?;
3984        test_slices(&array, &schema, 0, 2)?;
3985        test_slices(&array, &schema, 1, 1)?;
3986
3987        Ok(())
3988    }
3989
3990    #[test]
3991    fn test_roundtrip_list_of_fixed_list_w_nulls() -> Result<(), ArrowError> {
3992        let l0_builder = Float32Builder::new();
3993        let l1_builder = FixedSizeListBuilder::new(l0_builder, 3);
3994        let mut l2_builder = ListBuilder::new(l1_builder);
3995
3996        for point in [
3997            [Some(1.0), Some(2.0), None],
3998            [Some(4.0), Some(5.0), Some(6.0)],
3999            [None, Some(8.0), Some(9.0)],
4000        ] {
4001            for p in point {
4002                match p {
4003                    Some(p) => l2_builder.values().values().append_value(p),
4004                    None => l2_builder.values().values().append_null(),
4005                }
4006            }
4007
4008            l2_builder.values().append(true);
4009        }
4010        l2_builder.append(true);
4011
4012        let point = [Some(10.), None, None];
4013        for p in point {
4014            match p {
4015                Some(p) => l2_builder.values().values().append_value(p),
4016                None => l2_builder.values().values().append_null(),
4017            }
4018        }
4019
4020        l2_builder.values().append(true);
4021        l2_builder.append(true);
4022
4023        let array = Arc::new(l2_builder.finish()) as ArrayRef;
4024
4025        let schema = Arc::new(Schema::new_with_metadata(
4026            vec![Field::new(
4027                "points",
4028                DataType::List(Arc::new(Field::new(
4029                    "item",
4030                    DataType::FixedSizeList(
4031                        Arc::new(Field::new("item", DataType::Float32, true)),
4032                        3,
4033                    ),
4034                    true,
4035                ))),
4036                true,
4037            )],
4038            HashMap::default(),
4039        ));
4040
4041        // Test a variety of combinations that include 0 and non-zero offsets
4042        // and also portions or the rest of the array
4043        test_slices(&array, &schema, 0, 1)?;
4044        test_slices(&array, &schema, 0, 2)?;
4045        test_slices(&array, &schema, 1, 1)?;
4046
4047        Ok(())
4048    }
4049
4050    fn test_slices(
4051        parent_array: &ArrayRef,
4052        schema: &SchemaRef,
4053        offset: usize,
4054        length: usize,
4055    ) -> Result<(), ArrowError> {
4056        let subarray = parent_array.slice(offset, length);
4057        let original_batch = RecordBatch::try_new(schema.clone(), vec![subarray])?;
4058
4059        let mut bytes = Vec::new();
4060        let mut writer = StreamWriter::try_new(&mut bytes, schema)?;
4061        writer.write(&original_batch)?;
4062        writer.finish()?;
4063
4064        let mut cursor = std::io::Cursor::new(bytes);
4065        let mut reader = StreamReader::try_new(&mut cursor, None)?;
4066        let returned_batch = reader.next().unwrap()?;
4067
4068        assert_eq!(original_batch, returned_batch);
4069
4070        Ok(())
4071    }
4072
4073    #[test]
4074    fn test_roundtrip_fixed_list() -> Result<(), ArrowError> {
4075        let int_builder = Int64Builder::new();
4076        let mut fixed_list_builder = FixedSizeListBuilder::new(int_builder, 3)
4077            .with_field(Arc::new(Field::new("item", DataType::Int64, false)));
4078
4079        for point in [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]] {
4080            fixed_list_builder.values().append_value(point[0]);
4081            fixed_list_builder.values().append_value(point[1]);
4082            fixed_list_builder.values().append_value(point[2]);
4083
4084            fixed_list_builder.append(true);
4085        }
4086
4087        let array = Arc::new(fixed_list_builder.finish()) as ArrayRef;
4088
4089        let schema = Arc::new(Schema::new_with_metadata(
4090            vec![Field::new(
4091                "points",
4092                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int64, false)), 3),
4093                false,
4094            )],
4095            HashMap::default(),
4096        ));
4097
4098        // Test a variety of combinations that include 0 and non-zero offsets
4099        // and also portions or the rest of the array
4100        test_slices(&array, &schema, 0, 4)?;
4101        test_slices(&array, &schema, 0, 2)?;
4102        test_slices(&array, &schema, 1, 3)?;
4103        test_slices(&array, &schema, 2, 1)?;
4104
4105        Ok(())
4106    }
4107
4108    #[test]
4109    fn test_roundtrip_fixed_list_w_nulls() -> Result<(), ArrowError> {
4110        let int_builder = Int64Builder::new();
4111        let mut fixed_list_builder = FixedSizeListBuilder::new(int_builder, 3);
4112
4113        for point in [
4114            [Some(1), Some(2), None],
4115            [Some(4), Some(5), Some(6)],
4116            [None, Some(8), Some(9)],
4117            [Some(10), None, None],
4118        ] {
4119            for p in point {
4120                match p {
4121                    Some(p) => fixed_list_builder.values().append_value(p),
4122                    None => fixed_list_builder.values().append_null(),
4123                }
4124            }
4125
4126            fixed_list_builder.append(true);
4127        }
4128
4129        let array = Arc::new(fixed_list_builder.finish()) as ArrayRef;
4130
4131        let schema = Arc::new(Schema::new_with_metadata(
4132            vec![Field::new(
4133                "points",
4134                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int64, true)), 3),
4135                true,
4136            )],
4137            HashMap::default(),
4138        ));
4139
4140        // Test a variety of combinations that include 0 and non-zero offsets
4141        // and also portions or the rest of the array
4142        test_slices(&array, &schema, 0, 4)?;
4143        test_slices(&array, &schema, 0, 2)?;
4144        test_slices(&array, &schema, 1, 3)?;
4145        test_slices(&array, &schema, 2, 1)?;
4146
4147        Ok(())
4148    }
4149
4150    #[test]
4151    fn test_metadata_encoding_ordering() {
4152        fn create_hash() -> u64 {
4153            let metadata: HashMap<String, String> = [
4154                ("a", "1"), //
4155                ("b", "2"), //
4156                ("c", "3"), //
4157                ("d", "4"), //
4158                ("e", "5"), //
4159            ]
4160            .into_iter()
4161            .map(|(k, v)| (k.to_owned(), v.to_owned()))
4162            .collect();
4163
4164            // Set metadata on both the schema and a field within it.
4165            let schema = Arc::new(
4166                Schema::new(vec![
4167                    Field::new("a", DataType::Int64, true).with_metadata(metadata.clone()),
4168                ])
4169                .with_metadata(metadata)
4170                .clone(),
4171            );
4172            let batch = RecordBatch::new_empty(schema.clone());
4173
4174            let mut bytes = Vec::new();
4175            let mut w = StreamWriter::try_new(&mut bytes, batch.schema_ref()).unwrap();
4176            w.write(&batch).unwrap();
4177            w.finish().unwrap();
4178
4179            let mut h = std::hash::DefaultHasher::new();
4180            h.write(&bytes);
4181            h.finish()
4182        }
4183
4184        let expected = create_hash();
4185
4186        // Since there is randomness in the HashMap and we cannot specify our
4187        // own Hasher for the implementation used for metadata, run the above
4188        // code 20x and verify it does not change. This is not perfect but it
4189        // should be good enough.
4190        let all_passed = (0..20).all(|_| create_hash() == expected);
4191        assert!(all_passed);
4192    }
4193
4194    #[test]
4195    fn test_dictionary_tracker_reset() {
4196        let data_gen = IpcDataGenerator::default();
4197        let mut dictionary_tracker = DictionaryTracker::new(false);
4198        let writer_options = IpcWriteOptions::default();
4199        let mut compression_ctx = CompressionContext::default();
4200
4201        let schema = Arc::new(Schema::new(vec![Field::new(
4202            "a",
4203            DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)),
4204            false,
4205        )]));
4206
4207        let mut write_single_batch_stream =
4208            |batch: RecordBatch, dict_tracker: &mut DictionaryTracker| -> Vec<u8> {
4209                let mut buffer = Vec::new();
4210
4211                // create a new IPC stream:
4212                let stream_header = data_gen.schema_to_bytes_with_dictionary_tracker(
4213                    &schema,
4214                    dict_tracker,
4215                    &writer_options,
4216                );
4217                _ = write_message(&mut buffer, stream_header, &writer_options).unwrap();
4218
4219                let (encoded_dicts, encoded_batch) = data_gen
4220                    .encode(&batch, dict_tracker, &writer_options, &mut compression_ctx)
4221                    .unwrap();
4222                for encoded_dict in encoded_dicts {
4223                    _ = write_message(&mut buffer, encoded_dict, &writer_options).unwrap();
4224                }
4225                _ = write_message(&mut buffer, encoded_batch, &writer_options).unwrap();
4226
4227                buffer
4228            };
4229
4230        let batch1 = RecordBatch::try_new(
4231            schema.clone(),
4232            vec![Arc::new(DictionaryArray::new(
4233                UInt8Array::from_iter_values([0]),
4234                Arc::new(StringArray::from_iter_values(["a"])),
4235            ))],
4236        )
4237        .unwrap();
4238        let buffer = write_single_batch_stream(batch1.clone(), &mut dictionary_tracker);
4239
4240        // ensure we can read the stream back
4241        let mut reader = StreamReader::try_new(Cursor::new(buffer), None).unwrap();
4242        let read_batch = reader.next().unwrap().unwrap();
4243        assert_eq!(read_batch, batch1);
4244
4245        // reset the dictionary tracker so it can be used for next stream
4246        dictionary_tracker.clear();
4247
4248        // now write a 2nd stream and ensure we can also read it:
4249        let batch2 = RecordBatch::try_new(
4250            schema.clone(),
4251            vec![Arc::new(DictionaryArray::new(
4252                UInt8Array::from_iter_values([0]),
4253                Arc::new(StringArray::from_iter_values(["a"])),
4254            ))],
4255        )
4256        .unwrap();
4257        let buffer = write_single_batch_stream(batch2.clone(), &mut dictionary_tracker);
4258        let mut reader = StreamReader::try_new(Cursor::new(buffer), None).unwrap();
4259        let read_batch = reader.next().unwrap().unwrap();
4260        assert_eq!(read_batch, batch2);
4261    }
4262}