arrow_ipc/
writer.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Arrow IPC File and Stream Writers
19//!
20//! # Notes
21//!
22//! [`FileWriter`] and [`StreamWriter`] have similar interfaces; however, [`FileWriter`] writes the IPC file
23//! format, whose trailing footer requires a [`Seek`]able reader, while [`StreamWriter`] writes the stream format, which can be read sequentially.
24//!
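//! # Example
//!
//! A minimal sketch of writing a single batch with a [`StreamWriter`] backed by an
//! in-memory `Vec<u8>`:
//!
//! ```
//! # use arrow_array::record_batch;
//! # use arrow_ipc::writer::StreamWriter;
//! let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
//! let mut writer = StreamWriter::try_new(Vec::<u8>::new(), &batch.schema()).unwrap();
//! writer.write(&batch).unwrap();
//! // `into_inner` finishes the stream and returns the underlying writer
//! let bytes = writer.into_inner().unwrap();
//! assert!(!bytes.is_empty());
//! ```
//!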
25//! [`Seek`]: std::io::Seek
26
27use std::cmp::min;
28use std::collections::HashMap;
29use std::io::{BufWriter, Write};
30use std::mem::size_of;
31use std::sync::Arc;
32
33use flatbuffers::FlatBufferBuilder;
34
35use arrow_array::builder::BufferBuilder;
36use arrow_array::cast::*;
37use arrow_array::types::{Int16Type, Int32Type, Int64Type, RunEndIndexType};
38use arrow_array::*;
39use arrow_buffer::bit_util;
40use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer};
41use arrow_data::{ArrayData, ArrayDataBuilder, BufferSpec, layout};
42use arrow_schema::*;
43
44use crate::CONTINUATION_MARKER;
45use crate::compression::CompressionCodec;
46pub use crate::compression::CompressionContext;
47use crate::convert::IpcSchemaEncoder;
48
49/// IPC write options used to control the behaviour of the [`IpcDataGenerator`]
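///
/// # Example
///
/// A brief sketch of constructing non-default options; alignment, metadata version and
/// dictionary handling are configured through the builder-style methods below:
///
/// ```
/// # use arrow_ipc::MetadataVersion;
/// # use arrow_ipc::writer::{DictionaryHandling, IpcWriteOptions};
/// // 8-byte buffer alignment, non-legacy format, metadata V5,
/// // and delta dictionaries instead of resending whole dictionaries
/// let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5)
///     .unwrap()
///     .with_dictionary_handling(DictionaryHandling::Delta);
/// ```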
50#[derive(Debug, Clone)]
51pub struct IpcWriteOptions {
52    /// Write padding after memory buffers to this multiple of bytes.
53    /// Must be 8, 16, 32, or 64 - defaults to 64.
54    alignment: u8,
55    /// The legacy format is for releases before 0.15.0, and uses metadata V4
56    write_legacy_ipc_format: bool,
57    /// The metadata version to write. The Rust IPC writer supports V4+
58    ///
59    /// *Default versions per crate*
60    ///
61    /// When creating the default IpcWriteOptions, the following metadata versions are used:
62    ///
63    /// version 2.0.0: V4, with legacy format enabled
64    /// version 4.0.0: V5
65    metadata_version: crate::MetadataVersion,
66    /// Compression, if desired. Will result in a runtime error
67    /// if the corresponding feature is not enabled
68    batch_compression_type: Option<crate::CompressionType>,
69    /// How to handle updating dictionaries in IPC messages
70    dictionary_handling: DictionaryHandling,
71}
72
73impl IpcWriteOptions {
74    /// Configures compression when writing IPC files.
75    ///
76    /// Will result in a runtime error if the corresponding feature
77    /// is not enabled
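    ///
    /// For example (a sketch; actually compressing batches additionally requires the
    /// corresponding compression cargo feature to be enabled at write time):
    ///
    /// ```
    /// # use arrow_ipc::CompressionType;
    /// # use arrow_ipc::writer::IpcWriteOptions;
    /// let options = IpcWriteOptions::default()
    ///     .try_with_compression(Some(CompressionType::LZ4_FRAME))
    ///     .unwrap();
    /// ```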
78    pub fn try_with_compression(
79        mut self,
80        batch_compression_type: Option<crate::CompressionType>,
81    ) -> Result<Self, ArrowError> {
82        self.batch_compression_type = batch_compression_type;
83
84        if self.batch_compression_type.is_some()
85            && self.metadata_version < crate::MetadataVersion::V5
86        {
87            return Err(ArrowError::InvalidArgumentError(
88                "Compression only supported in metadata v5 and above".to_string(),
89            ));
90        }
91        Ok(self)
92    }
93    /// Try to create IpcWriteOptions, checking for incompatible settings
94    pub fn try_new(
95        alignment: usize,
96        write_legacy_ipc_format: bool,
97        metadata_version: crate::MetadataVersion,
98    ) -> Result<Self, ArrowError> {
99        let is_alignment_valid =
100            alignment == 8 || alignment == 16 || alignment == 32 || alignment == 64;
101        if !is_alignment_valid {
102            return Err(ArrowError::InvalidArgumentError(
103                "Alignment should be 8, 16, 32, or 64.".to_string(),
104            ));
105        }
106        let alignment: u8 = u8::try_from(alignment).expect("range already checked");
107        match metadata_version {
108            crate::MetadataVersion::V1
109            | crate::MetadataVersion::V2
110            | crate::MetadataVersion::V3 => Err(ArrowError::InvalidArgumentError(
111                "Writing IPC metadata version 3 and lower not supported".to_string(),
112            )),
113            #[allow(deprecated)]
114            crate::MetadataVersion::V4 => Ok(Self {
115                alignment,
116                write_legacy_ipc_format,
117                metadata_version,
118                batch_compression_type: None,
119                dictionary_handling: DictionaryHandling::default(),
120            }),
121            crate::MetadataVersion::V5 => {
122                if write_legacy_ipc_format {
123                    Err(ArrowError::InvalidArgumentError(
124                        "Legacy IPC format only supported on metadata version 4".to_string(),
125                    ))
126                } else {
127                    Ok(Self {
128                        alignment,
129                        write_legacy_ipc_format,
130                        metadata_version,
131                        batch_compression_type: None,
132                        dictionary_handling: DictionaryHandling::default(),
133                    })
134                }
135            }
136            z => Err(ArrowError::InvalidArgumentError(format!(
137                "Unsupported crate::MetadataVersion {z:?}"
138            ))),
139        }
140    }
141
142    /// Configure how dictionaries are handled in IPC messages
143    pub fn with_dictionary_handling(mut self, dictionary_handling: DictionaryHandling) -> Self {
144        self.dictionary_handling = dictionary_handling;
145        self
146    }
147}
148
149impl Default for IpcWriteOptions {
150    fn default() -> Self {
151        Self {
152            alignment: 64,
153            write_legacy_ipc_format: false,
154            metadata_version: crate::MetadataVersion::V5,
155            batch_compression_type: None,
156            dictionary_handling: DictionaryHandling::default(),
157        }
158    }
159}
160
161#[derive(Debug, Default)]
162/// Handles low level details of encoding [`Array`] and [`Schema`] into the
163/// [Arrow IPC Format].
164///
165/// # Example
166/// ```
167/// # fn run() {
168/// # use std::sync::Arc;
169/// # use arrow_array::UInt64Array;
170/// # use arrow_array::RecordBatch;
171/// # use arrow_ipc::writer::{CompressionContext, DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
172///
173/// // Create a record batch
174/// let batch = RecordBatch::try_from_iter(vec![
175///  ("col2", Arc::new(UInt64Array::from_iter([10, 23, 33])) as _)
176/// ]).unwrap();
177///
178/// // Error if dictionary values are replaced.
179/// let error_on_replacement = true;
180/// let options = IpcWriteOptions::default();
181/// let mut dictionary_tracker = DictionaryTracker::new(error_on_replacement);
182///
183/// let mut compression_context = CompressionContext::default();
184///
185/// // encode the batch into zero or more encoded dictionaries
186/// // and the data for the actual array.
187/// let data_gen = IpcDataGenerator::default();
188/// let (encoded_dictionaries, encoded_message) = data_gen
189///   .encode(&batch, &mut dictionary_tracker, &options, &mut compression_context)
190///   .unwrap();
191/// # }
192/// ```
193///
194/// [Arrow IPC Format]: https://arrow.apache.org/docs/format/Columnar.html#serialization-and-interprocess-communication-ipc
195pub struct IpcDataGenerator {}
196
197impl IpcDataGenerator {
198    /// Converts a schema to an IPC message along with `dictionary_tracker`
199    /// and returns it encoded inside [EncodedData] as a flatbuffer.
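    ///
    /// A minimal sketch of encoding a schema message:
    ///
    /// ```
    /// # use arrow_ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
    /// # use arrow_schema::{DataType, Field, Schema};
    /// let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    /// let data_gen = IpcDataGenerator::default();
    /// let mut tracker = DictionaryTracker::new(true);
    /// let encoded = data_gen.schema_to_bytes_with_dictionary_tracker(
    ///     &schema,
    ///     &mut tracker,
    ///     &IpcWriteOptions::default(),
    /// );
    /// // `encoded` now holds the flatbuffer-encoded Schema message;
    /// // schema messages carry no body, so the accompanying arrow data is empty.
    /// ```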
200    pub fn schema_to_bytes_with_dictionary_tracker(
201        &self,
202        schema: &Schema,
203        dictionary_tracker: &mut DictionaryTracker,
204        write_options: &IpcWriteOptions,
205    ) -> EncodedData {
206        let mut fbb = FlatBufferBuilder::new();
207        let schema = {
208            let fb = IpcSchemaEncoder::new()
209                .with_dictionary_tracker(dictionary_tracker)
210                .schema_to_fb_offset(&mut fbb, schema);
211            fb.as_union_value()
212        };
213
214        let mut message = crate::MessageBuilder::new(&mut fbb);
215        message.add_version(write_options.metadata_version);
216        message.add_header_type(crate::MessageHeader::Schema);
217        message.add_bodyLength(0);
218        message.add_header(schema);
219        // TODO: custom metadata
220        let data = message.finish();
221        fbb.finish(data, None);
222
223        let data = fbb.finished_data();
224        EncodedData {
225            ipc_message: data.to_vec(),
226            arrow_data: vec![],
227        }
228    }
229
230    fn _encode_dictionaries<I: Iterator<Item = i64>>(
231        &self,
232        column: &ArrayRef,
233        encoded_dictionaries: &mut Vec<EncodedData>,
234        dictionary_tracker: &mut DictionaryTracker,
235        write_options: &IpcWriteOptions,
236        dict_id: &mut I,
237        compression_context: &mut CompressionContext,
238    ) -> Result<(), ArrowError> {
239        match column.data_type() {
240            DataType::Struct(fields) => {
241                let s = as_struct_array(column);
242                for (field, column) in fields.iter().zip(s.columns()) {
243                    self.encode_dictionaries(
244                        field,
245                        column,
246                        encoded_dictionaries,
247                        dictionary_tracker,
248                        write_options,
249                        dict_id,
250                        compression_context,
251                    )?;
252                }
253            }
254            DataType::RunEndEncoded(_, values) => {
255                let data = column.to_data();
256                if data.child_data().len() != 2 {
257                    return Err(ArrowError::InvalidArgumentError(format!(
258                        "The run encoded array should have exactly two child arrays. Found {}",
259                        data.child_data().len()
260                    )));
261                }
262                // The run_ends array is not expected to be dictionary encoded. Hence encode dictionaries
263                // only for values array.
264                let values_array = make_array(data.child_data()[1].clone());
265                self.encode_dictionaries(
266                    values,
267                    &values_array,
268                    encoded_dictionaries,
269                    dictionary_tracker,
270                    write_options,
271                    dict_id,
272                    compression_context,
273                )?;
274            }
275            DataType::List(field) => {
276                let list = as_list_array(column);
277                self.encode_dictionaries(
278                    field,
279                    list.values(),
280                    encoded_dictionaries,
281                    dictionary_tracker,
282                    write_options,
283                    dict_id,
284                    compression_context,
285                )?;
286            }
287            DataType::LargeList(field) => {
288                let list = as_large_list_array(column);
289                self.encode_dictionaries(
290                    field,
291                    list.values(),
292                    encoded_dictionaries,
293                    dictionary_tracker,
294                    write_options,
295                    dict_id,
296                    compression_context,
297                )?;
298            }
299            DataType::FixedSizeList(field, _) => {
300                let list = column
301                    .as_any()
302                    .downcast_ref::<FixedSizeListArray>()
303                    .expect("Unable to downcast to fixed size list array");
304                self.encode_dictionaries(
305                    field,
306                    list.values(),
307                    encoded_dictionaries,
308                    dictionary_tracker,
309                    write_options,
310                    dict_id,
311                    compression_context,
312                )?;
313            }
314            DataType::Map(field, _) => {
315                let map_array = as_map_array(column);
316
317                let (keys, values) = match field.data_type() {
318                    DataType::Struct(fields) if fields.len() == 2 => (&fields[0], &fields[1]),
319                    _ => panic!("Incorrect field data type {:?}", field.data_type()),
320                };
321
322                // keys
323                self.encode_dictionaries(
324                    keys,
325                    map_array.keys(),
326                    encoded_dictionaries,
327                    dictionary_tracker,
328                    write_options,
329                    dict_id,
330                    compression_context,
331                )?;
332
333                // values
334                self.encode_dictionaries(
335                    values,
336                    map_array.values(),
337                    encoded_dictionaries,
338                    dictionary_tracker,
339                    write_options,
340                    dict_id,
341                    compression_context,
342                )?;
343            }
344            DataType::Union(fields, _) => {
345                let union = as_union_array(column);
346                for (type_id, field) in fields.iter() {
347                    let column = union.child(type_id);
348                    self.encode_dictionaries(
349                        field,
350                        column,
351                        encoded_dictionaries,
352                        dictionary_tracker,
353                        write_options,
354                        dict_id,
355                        compression_context,
356                    )?;
357                }
358            }
359            _ => (),
360        }
361
362        Ok(())
363    }
364
365    #[allow(clippy::too_many_arguments)]
366    fn encode_dictionaries<I: Iterator<Item = i64>>(
367        &self,
368        field: &Field,
369        column: &ArrayRef,
370        encoded_dictionaries: &mut Vec<EncodedData>,
371        dictionary_tracker: &mut DictionaryTracker,
372        write_options: &IpcWriteOptions,
373        dict_id_seq: &mut I,
374        compression_context: &mut CompressionContext,
375    ) -> Result<(), ArrowError> {
376        match column.data_type() {
377            DataType::Dictionary(_key_type, _value_type) => {
378                let dict_data = column.to_data();
379                let dict_values = &dict_data.child_data()[0];
380
381                let values = make_array(dict_data.child_data()[0].clone());
382
383                self._encode_dictionaries(
384                    &values,
385                    encoded_dictionaries,
386                    dictionary_tracker,
387                    write_options,
388                    dict_id_seq,
389                    compression_context,
390                )?;
391
392                // It's important to only take the dict_id at this point, because the dict ID
393                // sequence is assigned depth-first, so we need to first encode children and have
394                // them take their assigned dict IDs before we take the dict ID for this field.
395                let dict_id = dict_id_seq.next().ok_or_else(|| {
396                    ArrowError::IpcError(format!("no dict id for field {}", field.name()))
397                })?;
398
399                match dictionary_tracker.insert_column(
400                    dict_id,
401                    column,
402                    write_options.dictionary_handling,
403                )? {
404                    DictionaryUpdate::None => {}
405                    DictionaryUpdate::New | DictionaryUpdate::Replaced => {
406                        encoded_dictionaries.push(self.dictionary_batch_to_bytes(
407                            dict_id,
408                            dict_values,
409                            write_options,
410                            false,
411                            compression_context,
412                        )?);
413                    }
414                    DictionaryUpdate::Delta(data) => {
415                        encoded_dictionaries.push(self.dictionary_batch_to_bytes(
416                            dict_id,
417                            &data,
418                            write_options,
419                            true,
420                            compression_context,
421                        )?);
422                    }
423                }
424            }
425            _ => self._encode_dictionaries(
426                column,
427                encoded_dictionaries,
428                dictionary_tracker,
429                write_options,
430                dict_id_seq,
431                compression_context,
432            )?,
433        }
434
435        Ok(())
436    }
437
438    /// Encodes a batch to a number of [EncodedData] items (dictionary batches + the record batch).
439    /// The [DictionaryTracker] keeps track of dictionaries with new `dict_id`s (so they are only sent once).
440    /// Make sure the [DictionaryTracker] is initialized at the start of the stream.
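    ///
    /// For a batch without dictionary columns no dictionary messages are produced;
    /// a minimal sketch:
    ///
    /// ```
    /// # use arrow_array::record_batch;
    /// # use arrow_ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
    /// let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
    /// let data_gen = IpcDataGenerator::default();
    /// let mut tracker = DictionaryTracker::new(false);
    /// let (dictionaries, message) = data_gen
    ///     .encode(
    ///         &batch,
    ///         &mut tracker,
    ///         &IpcWriteOptions::default(),
    ///         &mut Default::default(), // compression context
    ///     )
    ///     .unwrap();
    /// assert!(dictionaries.is_empty());
    /// ```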
441    pub fn encode(
442        &self,
443        batch: &RecordBatch,
444        dictionary_tracker: &mut DictionaryTracker,
445        write_options: &IpcWriteOptions,
446        compression_context: &mut CompressionContext,
447    ) -> Result<(Vec<EncodedData>, EncodedData), ArrowError> {
448        let schema = batch.schema();
449        let mut encoded_dictionaries = Vec::with_capacity(schema.flattened_fields().len());
450
451        let mut dict_id = dictionary_tracker.dict_ids.clone().into_iter();
452
453        for (i, field) in schema.fields().iter().enumerate() {
454            let column = batch.column(i);
455            self.encode_dictionaries(
456                field,
457                column,
458                &mut encoded_dictionaries,
459                dictionary_tracker,
460                write_options,
461                &mut dict_id,
462                compression_context,
463            )?;
464        }
465
466        let encoded_message =
467            self.record_batch_to_bytes(batch, write_options, compression_context)?;
468        Ok((encoded_dictionaries, encoded_message))
469    }
470
471    /// Encodes a batch to a number of [EncodedData] items (dictionary batches + the record batch).
472    /// The [DictionaryTracker] keeps track of dictionaries with new `dict_id`s (so they are only sent once).
473    /// Make sure the [DictionaryTracker] is initialized at the start of the stream.
474    #[deprecated(since = "57.0.0", note = "Use `encode` instead")]
475    pub fn encoded_batch(
476        &self,
477        batch: &RecordBatch,
478        dictionary_tracker: &mut DictionaryTracker,
479        write_options: &IpcWriteOptions,
480    ) -> Result<(Vec<EncodedData>, EncodedData), ArrowError> {
481        self.encode(
482            batch,
483            dictionary_tracker,
484            write_options,
485            &mut Default::default(),
486        )
487    }
488
489    /// Write a `RecordBatch` into two sets of bytes, one for the header (crate::Message) and the
490    /// other for the batch's data
491    fn record_batch_to_bytes(
492        &self,
493        batch: &RecordBatch,
494        write_options: &IpcWriteOptions,
495        compression_context: &mut CompressionContext,
496    ) -> Result<EncodedData, ArrowError> {
497        let mut fbb = FlatBufferBuilder::new();
498
499        let mut nodes: Vec<crate::FieldNode> = vec![];
500        let mut buffers: Vec<crate::Buffer> = vec![];
501        let mut arrow_data: Vec<u8> = vec![];
502        let mut offset = 0;
503
504        // get the type of compression
505        let batch_compression_type = write_options.batch_compression_type;
506
507        let compression = batch_compression_type.map(|batch_compression_type| {
508            let mut c = crate::BodyCompressionBuilder::new(&mut fbb);
509            c.add_method(crate::BodyCompressionMethod::BUFFER);
510            c.add_codec(batch_compression_type);
511            c.finish()
512        });
513
514        let compression_codec: Option<CompressionCodec> =
515            batch_compression_type.map(TryInto::try_into).transpose()?;
516
517        let mut variadic_buffer_counts = vec![];
518
519        for array in batch.columns() {
520            let array_data = array.to_data();
521            offset = write_array_data(
522                &array_data,
523                &mut buffers,
524                &mut arrow_data,
525                &mut nodes,
526                offset,
527                array.len(),
528                array.null_count(),
529                compression_codec,
530                compression_context,
531                write_options,
532            )?;
533
534            append_variadic_buffer_counts(&mut variadic_buffer_counts, &array_data);
535        }
536        // pad the tail of body data
537        let len = arrow_data.len();
538        let pad_len = pad_to_alignment(write_options.alignment, len);
539        arrow_data.extend_from_slice(&PADDING[..pad_len]);
540
541        // write data
542        let buffers = fbb.create_vector(&buffers);
543        let nodes = fbb.create_vector(&nodes);
544        let variadic_buffer = if variadic_buffer_counts.is_empty() {
545            None
546        } else {
547            Some(fbb.create_vector(&variadic_buffer_counts))
548        };
549
550        let root = {
551            let mut batch_builder = crate::RecordBatchBuilder::new(&mut fbb);
552            batch_builder.add_length(batch.num_rows() as i64);
553            batch_builder.add_nodes(nodes);
554            batch_builder.add_buffers(buffers);
555            if let Some(c) = compression {
556                batch_builder.add_compression(c);
557            }
558
559            if let Some(v) = variadic_buffer {
560                batch_builder.add_variadicBufferCounts(v);
561            }
562            let b = batch_builder.finish();
563            b.as_union_value()
564        };
565        // create an crate::Message
566        let mut message = crate::MessageBuilder::new(&mut fbb);
567        message.add_version(write_options.metadata_version);
568        message.add_header_type(crate::MessageHeader::RecordBatch);
569        message.add_bodyLength(arrow_data.len() as i64);
570        message.add_header(root);
571        let root = message.finish();
572        fbb.finish(root, None);
573        let finished_data = fbb.finished_data();
574
575        Ok(EncodedData {
576            ipc_message: finished_data.to_vec(),
577            arrow_data,
578        })
579    }
580
581    /// Write dictionary values into two sets of bytes, one for the header (crate::Message) and the
582    /// other for the data
583    fn dictionary_batch_to_bytes(
584        &self,
585        dict_id: i64,
586        array_data: &ArrayData,
587        write_options: &IpcWriteOptions,
588        is_delta: bool,
589        compression_context: &mut CompressionContext,
590    ) -> Result<EncodedData, ArrowError> {
591        let mut fbb = FlatBufferBuilder::new();
592
593        let mut nodes: Vec<crate::FieldNode> = vec![];
594        let mut buffers: Vec<crate::Buffer> = vec![];
595        let mut arrow_data: Vec<u8> = vec![];
596
597        // get the type of compression
598        let batch_compression_type = write_options.batch_compression_type;
599
600        let compression = batch_compression_type.map(|batch_compression_type| {
601            let mut c = crate::BodyCompressionBuilder::new(&mut fbb);
602            c.add_method(crate::BodyCompressionMethod::BUFFER);
603            c.add_codec(batch_compression_type);
604            c.finish()
605        });
606
607        let compression_codec: Option<CompressionCodec> = batch_compression_type
608            .map(|batch_compression_type| batch_compression_type.try_into())
609            .transpose()?;
610
611        write_array_data(
612            array_data,
613            &mut buffers,
614            &mut arrow_data,
615            &mut nodes,
616            0,
617            array_data.len(),
618            array_data.null_count(),
619            compression_codec,
620            compression_context,
621            write_options,
622        )?;
623
624        let mut variadic_buffer_counts = vec![];
625        append_variadic_buffer_counts(&mut variadic_buffer_counts, array_data);
626
627        // pad the tail of body data
628        let len = arrow_data.len();
629        let pad_len = pad_to_alignment(write_options.alignment, len);
630        arrow_data.extend_from_slice(&PADDING[..pad_len]);
631
632        // write data
633        let buffers = fbb.create_vector(&buffers);
634        let nodes = fbb.create_vector(&nodes);
635        let variadic_buffer = if variadic_buffer_counts.is_empty() {
636            None
637        } else {
638            Some(fbb.create_vector(&variadic_buffer_counts))
639        };
640
641        let root = {
642            let mut batch_builder = crate::RecordBatchBuilder::new(&mut fbb);
643            batch_builder.add_length(array_data.len() as i64);
644            batch_builder.add_nodes(nodes);
645            batch_builder.add_buffers(buffers);
646            if let Some(c) = compression {
647                batch_builder.add_compression(c);
648            }
649            if let Some(v) = variadic_buffer {
650                batch_builder.add_variadicBufferCounts(v);
651            }
652            batch_builder.finish()
653        };
654
655        let root = {
656            let mut batch_builder = crate::DictionaryBatchBuilder::new(&mut fbb);
657            batch_builder.add_id(dict_id);
658            batch_builder.add_data(root);
659            batch_builder.add_isDelta(is_delta);
660            batch_builder.finish().as_union_value()
661        };
662
663        let root = {
664            let mut message_builder = crate::MessageBuilder::new(&mut fbb);
665            message_builder.add_version(write_options.metadata_version);
666            message_builder.add_header_type(crate::MessageHeader::DictionaryBatch);
667            message_builder.add_bodyLength(arrow_data.len() as i64);
668            message_builder.add_header(root);
669            message_builder.finish()
670        };
671
672        fbb.finish(root, None);
673        let finished_data = fbb.finished_data();
674
675        Ok(EncodedData {
676            ipc_message: finished_data.to_vec(),
677            arrow_data,
678        })
679    }
680}
681
682fn append_variadic_buffer_counts(counts: &mut Vec<i64>, array: &ArrayData) {
683    match array.data_type() {
684        DataType::BinaryView | DataType::Utf8View => {
685            // The spec documents that the counts only include the variadic buffers, not the view/null buffers.
686            // https://arrow.apache.org/docs/format/Columnar.html#variadic-buffers
687            counts.push(array.buffers().len() as i64 - 1);
688        }
689        DataType::Dictionary(_, _) => {
690            // Do nothing
691            // Dictionary types are handled in `encode_dictionaries`.
692        }
693        _ => {
694            for child in array.child_data() {
695                append_variadic_buffer_counts(counts, child)
696            }
697        }
698    }
699}
700
701pub(crate) fn unslice_run_array(arr: ArrayData) -> Result<ArrayData, ArrowError> {
702    match arr.data_type() {
703        DataType::RunEndEncoded(k, _) => match k.data_type() {
704            DataType::Int16 => {
705                Ok(into_zero_offset_run_array(RunArray::<Int16Type>::from(arr))?.into_data())
706            }
707            DataType::Int32 => {
708                Ok(into_zero_offset_run_array(RunArray::<Int32Type>::from(arr))?.into_data())
709            }
710            DataType::Int64 => {
711                Ok(into_zero_offset_run_array(RunArray::<Int64Type>::from(arr))?.into_data())
712            }
713            d => unreachable!("Unexpected data type {d}"),
714        },
715        d => Err(ArrowError::InvalidArgumentError(format!(
716            "The given array is not a run array. Data type of given array: {d}"
717        ))),
718    }
719}
720
721// Returns a `RunArray` with zero offset and length matching the last value
722// in run_ends array.
723fn into_zero_offset_run_array<R: RunEndIndexType>(
724    run_array: RunArray<R>,
725) -> Result<RunArray<R>, ArrowError> {
726    let run_ends = run_array.run_ends();
727    if run_ends.offset() == 0 && run_ends.max_value() == run_ends.len() {
728        return Ok(run_array);
729    }
730
731    // The physical index of the original run_ends array from which the `ArrayData` is sliced.
732    let start_physical_index = run_ends.get_start_physical_index();
733
734    // The physical index of the original run_ends array until which the `ArrayData` is sliced.
735    let end_physical_index = run_ends.get_end_physical_index();
736
737    let physical_length = end_physical_index - start_physical_index + 1;
738
739    // build new run_ends array by subtracting offset from run ends.
740    let offset = R::Native::usize_as(run_ends.offset());
741    let mut builder = BufferBuilder::<R::Native>::new(physical_length);
742    for run_end_value in &run_ends.values()[start_physical_index..end_physical_index] {
743        builder.append(run_end_value.sub_wrapping(offset));
744    }
745    builder.append(R::Native::from_usize(run_array.len()).unwrap());
746    let new_run_ends = unsafe {
747        // Safety:
748        // The function builds a valid run_ends array and hence need not be validated.
749        ArrayDataBuilder::new(R::DATA_TYPE)
750            .len(physical_length)
751            .add_buffer(builder.finish())
752            .build_unchecked()
753    };
754
755    // build new values by slicing physical indices.
756    let new_values = run_array
757        .values()
758        .slice(start_physical_index, physical_length)
759        .into_data();
760
761    let builder = ArrayDataBuilder::new(run_array.data_type().clone())
762        .len(run_array.len())
763        .add_child_data(new_run_ends)
764        .add_child_data(new_values);
765    let array_data = unsafe {
766        // Safety:
767        //  This function builds a valid run array and hence can skip validation.
768        builder.build_unchecked()
769    };
770    Ok(array_data.into())
771}
772
773/// Controls how dictionaries are handled in Arrow IPC messages
774#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
775pub enum DictionaryHandling {
776    /// Send the entire dictionary every time it is encountered (default)
777    #[default]
778    Resend,
779    /// Send only new dictionary values since the last batch (delta encoding)
780    ///
781    /// When a dictionary is first encountered, the entire dictionary is sent.
782    /// For subsequent batches, only values that are new (not previously sent)
783    /// are transmitted with the `isDelta` flag set to true.
784    Delta,
785}
786
787/// Describes what kind of update took place after a call to [`DictionaryTracker::insert_column`].
788#[derive(Debug, Clone)]
789pub enum DictionaryUpdate {
790    /// No dictionary was written, the dictionary was identical to what was already
791    /// in the tracker.
792    None,
793    /// No dictionary was present in the tracker
794    New,
795    /// Dictionary was replaced with the new data
796    Replaced,
797    /// Dictionary was updated, ArrayData is the delta between old and new
798    Delta(ArrayData),
799}
800
801/// Keeps track of dictionaries that have been written, to avoid emitting the same dictionary
802/// multiple times.
803///
804/// Can optionally error if an update to an existing dictionary is attempted, which
805/// isn't allowed in the `FileWriter`.
806#[derive(Debug)]
807pub struct DictionaryTracker {
808    written: HashMap<i64, ArrayData>,
809    dict_ids: Vec<i64>,
810    error_on_replacement: bool,
811}
812
813impl DictionaryTracker {
814    /// Create a new [`DictionaryTracker`].
815    ///
816    /// If `error_on_replacement`
817    /// is true, an error will be generated if an update to an
818    /// existing dictionary is attempted.
819    pub fn new(error_on_replacement: bool) -> Self {
820        #[allow(deprecated)]
821        Self {
822            written: HashMap::new(),
823            dict_ids: Vec::new(),
824            error_on_replacement,
825        }
826    }
827
828    /// Record and return the next dictionary ID.
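    ///
    /// IDs are handed out sequentially starting at zero; a short sketch:
    ///
    /// ```
    /// # use arrow_ipc::writer::DictionaryTracker;
    /// let mut tracker = DictionaryTracker::new(false);
    /// assert_eq!(tracker.next_dict_id(), 0);
    /// assert_eq!(tracker.next_dict_id(), 1);
    /// ```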
829    pub fn next_dict_id(&mut self) -> i64 {
830        let next = self
831            .dict_ids
832            .last()
833            .copied()
834            .map(|i| i + 1)
835            .unwrap_or_default();
836
837        self.dict_ids.push(next);
838        next
839    }
840
841    /// Return the sequence of dictionary IDs in the order they should be observed while
842    /// traversing the schema
843    pub fn dict_id(&mut self) -> &[i64] {
844        &self.dict_ids
845    }
846
847    /// Keep track of the dictionary with the given ID and values. Behavior:
848    ///
849    /// * If this ID has been written already and has the same data, return `Ok(false)` to indicate
850    ///   that the dictionary was not actually inserted (because it's already been seen).
851    /// * If this ID has been written already but with different data, and this tracker is
852    ///   configured to return an error, return an error.
853    /// * If the tracker has not been configured to error on replacement or this dictionary
854    ///   has never been seen before, return `Ok(true)` to indicate that the dictionary was just
855    ///   inserted.
856    #[deprecated(since = "56.1.0", note = "Use `insert_column` instead")]
857    pub fn insert(&mut self, dict_id: i64, column: &ArrayRef) -> Result<bool, ArrowError> {
858        let dict_data = column.to_data();
859        let dict_values = &dict_data.child_data()[0];
860
861        // If a dictionary with this id was already emitted, check if it was the same.
862        if let Some(last) = self.written.get(&dict_id) {
863            if ArrayData::ptr_eq(&last.child_data()[0], dict_values) {
864                // Same dictionary values => no need to emit it again
865                return Ok(false);
866            }
867            if self.error_on_replacement {
868                // If error on replacement perform a logical comparison
869                if last.child_data()[0] == *dict_values {
870                    // Same dictionary values => no need to emit it again
871                    return Ok(false);
872                }
873                return Err(ArrowError::InvalidArgumentError(
874                    "Dictionary replacement detected when writing IPC file format. \
875                     Arrow IPC files only support a single dictionary for a given field \
876                     across all batches."
877                        .to_string(),
878                ));
879            }
880        }
881
882        self.written.insert(dict_id, dict_data);
883        Ok(true)
884    }
885
886    /// Keep track of the dictionary with the given ID and values. The return
887    /// value indicates what, if any, update to the internal map took place
888    /// and how it should be interpreted based on the `dict_handling` parameter.
889    ///
890    /// # Returns
891    ///
892    /// * `Ok(DictionaryUpdate::New)` - If the dictionary was not previously written
893    /// * `Ok(DictionaryUpdate::Replaced)` - If the dictionary was previously written
894    ///   with completely different data, or if the data is a delta of the existing,
895    ///   but with `dict_handling` set to `DictionaryHandling::Resend`
896    /// * `Ok(DictionaryUpdate::Delta)` - If the dictionary was previously written, but
897    ///   the new data is a delta of the old and the `dict_handling` is set to
898    ///   `DictionaryHandling::Delta`
899    /// * `Err(e)` - If the dictionary was previously written with different data,
900    ///   and `error_on_replacement` is set to `true`.
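    ///
    /// # Example
    ///
    /// A minimal sketch of delta detection, assuming the second dictionary extends the
    /// first (values `["a", "b"]` followed by `["a", "b", "c"]`):
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow_array::{types::Int32Type, ArrayRef, DictionaryArray};
    /// # use arrow_ipc::writer::{DictionaryHandling, DictionaryTracker, DictionaryUpdate};
    /// let mut tracker = DictionaryTracker::new(false);
    /// let first: ArrayRef = Arc::new(
    ///     vec!["a", "b"].into_iter().collect::<DictionaryArray<Int32Type>>(),
    /// );
    /// let second: ArrayRef = Arc::new(
    ///     vec!["a", "b", "c"].into_iter().collect::<DictionaryArray<Int32Type>>(),
    /// );
    ///
    /// // First insertion registers the dictionary values
    /// let update = tracker.insert_column(0, &first, DictionaryHandling::Delta).unwrap();
    /// assert!(matches!(update, DictionaryUpdate::New));
    ///
    /// // The second dictionary is a prefix extension, so only "c" needs to be sent
    /// let update = tracker.insert_column(0, &second, DictionaryHandling::Delta).unwrap();
    /// assert!(matches!(update, DictionaryUpdate::Delta(_)));
    /// ```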
901    pub fn insert_column(
902        &mut self,
903        dict_id: i64,
904        column: &ArrayRef,
905        dict_handling: DictionaryHandling,
906    ) -> Result<DictionaryUpdate, ArrowError> {
907        let new_data = column.to_data();
908        let new_values = &new_data.child_data()[0];
909
910        // If there is no existing dictionary with this ID, we always insert
911        let Some(old) = self.written.get(&dict_id) else {
912            self.written.insert(dict_id, new_data);
913            return Ok(DictionaryUpdate::New);
914        };
915
916        // Fast path - If the array data points to the same buffer as the
917        // existing then they're the same.
918        let old_values = &old.child_data()[0];
919        if ArrayData::ptr_eq(old_values, new_values) {
920            return Ok(DictionaryUpdate::None);
921        }
922
923        // Slow path - Compare the dictionaries value by value
924        let comparison = compare_dictionaries(old_values, new_values);
925        if matches!(comparison, DictionaryComparison::Equal) {
926            return Ok(DictionaryUpdate::None);
927        }
928
929        const REPLACEMENT_ERROR: &str = "Dictionary replacement detected when writing IPC file format. \
930                 Arrow IPC files only support a single dictionary for a given field \
931                 across all batches.";
932
933        match comparison {
934            DictionaryComparison::NotEqual => {
935                if self.error_on_replacement {
936                    return Err(ArrowError::InvalidArgumentError(
937                        REPLACEMENT_ERROR.to_string(),
938                    ));
939                }
940
941                self.written.insert(dict_id, new_data);
942                Ok(DictionaryUpdate::Replaced)
943            }
944            DictionaryComparison::Delta => match dict_handling {
945                DictionaryHandling::Resend => {
946                    if self.error_on_replacement {
947                        return Err(ArrowError::InvalidArgumentError(
948                            REPLACEMENT_ERROR.to_string(),
949                        ));
950                    }
951
952                    self.written.insert(dict_id, new_data);
953                    Ok(DictionaryUpdate::Replaced)
954                }
955                DictionaryHandling::Delta => {
956                    let delta =
957                        new_values.slice(old_values.len(), new_values.len() - old_values.len());
958                    self.written.insert(dict_id, new_data);
959                    Ok(DictionaryUpdate::Delta(delta))
960                }
961            },
962            DictionaryComparison::Equal => unreachable!("Already checked equal case"),
963        }
964    }
965}
966
967/// Describes how two dictionary arrays compare to each other.
968#[derive(Debug, Clone)]
969enum DictionaryComparison {
970    /// Neither a delta, nor an exact match
971    NotEqual,
972    /// Exact element-wise match
973    Equal,
974    /// The two arrays are dictionary deltas of each other, meaning the first
975    /// is a prefix of the second.
976    Delta,
977}
978
979// Compares two dictionaries and returns a [`DictionaryComparison`].
980fn compare_dictionaries(old: &ArrayData, new: &ArrayData) -> DictionaryComparison {
981    // Check for exact match
982    let existing_len = old.len();
983    let new_len = new.len();
984    if existing_len == new_len {
985        if *old == *new {
986            return DictionaryComparison::Equal;
987        } else {
988            return DictionaryComparison::NotEqual;
989        }
990    }
991
992    // Can't be a delta if the new is shorter than the existing
993    if new_len < existing_len {
994        return DictionaryComparison::NotEqual;
995    }
996
997    // Check for delta
998    if new.slice(0, existing_len) == *old {
999        return DictionaryComparison::Delta;
1000    }
1001
1002    DictionaryComparison::NotEqual
1003}
1004
1005/// Arrow File Writer
1006///
1007/// Writes Arrow [`RecordBatch`]es in the [IPC File Format].
1008///
1009/// # See Also
1010///
1011/// * [`StreamWriter`] for writing IPC Streams
1012///
1013/// # Example
1014/// ```
1015/// # use arrow_array::record_batch;
1016/// # use arrow_ipc::writer::FileWriter;
1017/// # let mut file = vec![]; // mimic a file for the example
1018/// let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
1019/// // create a new writer, the schema must be known in advance
1020/// let mut writer = FileWriter::try_new(&mut file, &batch.schema()).unwrap();
1021/// // write each batch to the underlying writer
1022/// writer.write(&batch).unwrap();
1023/// // When all batches are written, call finish to flush all buffers
1024/// writer.finish().unwrap();
1025/// ```
1026/// [IPC File Format]: https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format
1027pub struct FileWriter<W> {
1028    /// The object to write to
1029    writer: W,
1030    /// IPC write options
1031    write_options: IpcWriteOptions,
1032    /// A reference to the schema, used in validating record batches
1033    schema: SchemaRef,
1034    /// The number of bytes between each block of bytes, as an offset for random access
1035    block_offsets: usize,
1036    /// Dictionary blocks that will be written as part of the IPC footer
1037    dictionary_blocks: Vec<crate::Block>,
1038    /// Record blocks that will be written as part of the IPC footer
1039    record_blocks: Vec<crate::Block>,
1040    /// Whether the writer footer has been written, and the writer is finished
1041    finished: bool,
1042    /// Keeps track of dictionaries that have been written
1043    dictionary_tracker: DictionaryTracker,
1044    /// User level customized metadata
1045    custom_metadata: HashMap<String, String>,
1046
1047    data_gen: IpcDataGenerator,
1048
1049    compression_context: CompressionContext,
1050}
1051
1052impl<W: Write> FileWriter<BufWriter<W>> {
1053    /// Try to create a new file writer with the writer wrapped in a BufWriter.
1054    ///
1055    /// See [`FileWriter::try_new`] for an unbuffered version.
1056    pub fn try_new_buffered(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
1057        Self::try_new(BufWriter::new(writer), schema)
1058    }
1059}
1060
1061impl<W: Write> FileWriter<W> {
1062    /// Try to create a new writer, with the schema written as part of the header
1063    ///
1064    /// Note the created writer is not buffered. See [`FileWriter::try_new_buffered`] for details.
1065    ///
1066    /// # Errors
1067    ///
1068    /// An [`Err`](Result::Err) may be returned if writing the header to the writer fails.
1069    pub fn try_new(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
1070        let write_options = IpcWriteOptions::default();
1071        Self::try_new_with_options(writer, schema, write_options)
1072    }
1073
1074    /// Try to create a new writer with IpcWriteOptions
1075    ///
1076    /// Note the created writer is not buffered. See [`FileWriter::try_new_buffered`] for details.
1077    ///
1078    /// # Errors
1079    ///
1080    /// An [`Err`](Result::Err) may be returned if writing the header to the writer fails.
1081    pub fn try_new_with_options(
1082        mut writer: W,
1083        schema: &Schema,
1084        write_options: IpcWriteOptions,
1085    ) -> Result<Self, ArrowError> {
1086        let data_gen = IpcDataGenerator::default();
1087        // write magic to header aligned on alignment boundary
1088        let pad_len = pad_to_alignment(write_options.alignment, super::ARROW_MAGIC.len());
1089        let header_size = super::ARROW_MAGIC.len() + pad_len;
1090        writer.write_all(&super::ARROW_MAGIC)?;
1091        writer.write_all(&PADDING[..pad_len])?;
1092        // write the schema, set the written bytes to the schema + header
1093        let mut dictionary_tracker = DictionaryTracker::new(true);
1094        let encoded_message = data_gen.schema_to_bytes_with_dictionary_tracker(
1095            schema,
1096            &mut dictionary_tracker,
1097            &write_options,
1098        );
1099        let (meta, data) = write_message(&mut writer, encoded_message, &write_options)?;
1100        Ok(Self {
1101            writer,
1102            write_options,
1103            schema: Arc::new(schema.clone()),
1104            block_offsets: meta + data + header_size,
1105            dictionary_blocks: vec![],
1106            record_blocks: vec![],
1107            finished: false,
1108            dictionary_tracker,
1109            custom_metadata: HashMap::new(),
1110            data_gen,
1111            compression_context: CompressionContext::default(),
1112        })
1113    }
1114
1115    /// Adds a key-value pair to the [FileWriter]'s custom metadata
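    ///
    /// The metadata is written into the file footer when [`FileWriter::finish`] is called;
    /// a short sketch:
    ///
    /// ```
    /// # use arrow_array::record_batch;
    /// # use arrow_ipc::writer::FileWriter;
    /// # let mut file = vec![]; // mimic a file for the example
    /// let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
    /// let mut writer = FileWriter::try_new(&mut file, &batch.schema()).unwrap();
    /// writer.write_metadata("created_by", "arrow_ipc example");
    /// writer.write(&batch).unwrap();
    /// writer.finish().unwrap();
    /// ```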
1116    pub fn write_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
1117        self.custom_metadata.insert(key.into(), value.into());
1118    }
1119
1120    /// Write a record batch to the file
1121    pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
1122        if self.finished {
1123            return Err(ArrowError::IpcError(
1124                "Cannot write record batch to file writer as it is closed".to_string(),
1125            ));
1126        }
1127
1128        let (encoded_dictionaries, encoded_message) = self.data_gen.encode(
1129            batch,
1130            &mut self.dictionary_tracker,
1131            &self.write_options,
1132            &mut self.compression_context,
1133        )?;
1134
1135        for encoded_dictionary in encoded_dictionaries {
1136            let (meta, data) =
1137                write_message(&mut self.writer, encoded_dictionary, &self.write_options)?;
1138
1139            let block = crate::Block::new(self.block_offsets as i64, meta as i32, data as i64);
1140            self.dictionary_blocks.push(block);
1141            self.block_offsets += meta + data;
1142        }
1143
1144        let (meta, data) = write_message(&mut self.writer, encoded_message, &self.write_options)?;
1145
1146        // add a record block for the footer
1147        let block = crate::Block::new(
1148            self.block_offsets as i64,
1149            meta as i32, // TODO: is this still applicable?
1150            data as i64,
1151        );
1152        self.record_blocks.push(block);
1153        self.block_offsets += meta + data;
1154        Ok(())
1155    }
1156
1157    /// Write footer and closing tag, then mark the writer as done
1158    pub fn finish(&mut self) -> Result<(), ArrowError> {
1159        if self.finished {
1160            return Err(ArrowError::IpcError(
1161                "Cannot write footer to file writer as it is closed".to_string(),
1162            ));
1163        }
1164
1165        // write EOS
1166        write_continuation(&mut self.writer, &self.write_options, 0)?;
1167
1168        let mut fbb = FlatBufferBuilder::new();
1169        let dictionaries = fbb.create_vector(&self.dictionary_blocks);
1170        let record_batches = fbb.create_vector(&self.record_blocks);
1171        let mut dictionary_tracker = DictionaryTracker::new(true);
1172        let schema = IpcSchemaEncoder::new()
1173            .with_dictionary_tracker(&mut dictionary_tracker)
1174            .schema_to_fb_offset(&mut fbb, &self.schema);
1175        let fb_custom_metadata = (!self.custom_metadata.is_empty())
1176            .then(|| crate::convert::metadata_to_fb(&mut fbb, &self.custom_metadata));
1177
1178        let root = {
1179            let mut footer_builder = crate::FooterBuilder::new(&mut fbb);
1180            footer_builder.add_version(self.write_options.metadata_version);
1181            footer_builder.add_schema(schema);
1182            footer_builder.add_dictionaries(dictionaries);
1183            footer_builder.add_recordBatches(record_batches);
1184            if let Some(fb_custom_metadata) = fb_custom_metadata {
1185                footer_builder.add_custom_metadata(fb_custom_metadata);
1186            }
1187            footer_builder.finish()
1188        };
1189        fbb.finish(root, None);
1190        let footer_data = fbb.finished_data();
1191        self.writer.write_all(footer_data)?;
1192        self.writer
1193            .write_all(&(footer_data.len() as i32).to_le_bytes())?;
1194        self.writer.write_all(&super::ARROW_MAGIC)?;
1195        self.writer.flush()?;
1196        self.finished = true;
1197
1198        Ok(())
1199    }
1200
1201    /// Returns the arrow [`SchemaRef`] for this arrow file.
1202    pub fn schema(&self) -> &SchemaRef {
1203        &self.schema
1204    }
1205
1206    /// Gets a reference to the underlying writer.
1207    pub fn get_ref(&self) -> &W {
1208        &self.writer
1209    }
1210
1211    /// Gets a mutable reference to the underlying writer.
1212    ///
1213    /// It is inadvisable to directly write to the underlying writer.
1214    pub fn get_mut(&mut self) -> &mut W {
1215        &mut self.writer
1216    }
1217
1218    /// Flush the underlying writer.
1219    ///
1220    /// Both the BufWriter and the underlying writer are flushed.
1221    pub fn flush(&mut self) -> Result<(), ArrowError> {
1222        self.writer.flush()?;
1223        Ok(())
1224    }
1225
1226    /// Unwraps the underlying writer.
1227    ///
1228    /// The writer is flushed and the FileWriter is finished before returning.
1229    ///
1230    /// # Errors
1231    ///
1232    /// An [`Err`](Result::Err) may be returned if an error occurs while finishing the FileWriter
1233    /// or while flushing the writer.
1234    pub fn into_inner(mut self) -> Result<W, ArrowError> {
1235        if !self.finished {
1236            // `finish` flushes the writer.
1237            self.finish()?;
1238        }
1239        Ok(self.writer)
1240    }
1241}
1242
1243impl<W: Write> RecordBatchWriter for FileWriter<W> {
1244    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
1245        self.write(batch)
1246    }
1247
1248    fn close(mut self) -> Result<(), ArrowError> {
1249        self.finish()
1250    }
1251}
1252
1253/// Arrow Stream Writer
1254///
1255/// Writes Arrow [`RecordBatch`]es to bytes using the [IPC Streaming Format].
1256///
1257/// # See Also
1258///
1259/// * [`FileWriter`] for writing IPC Files
1260///
1261/// # Example - Basic usage
1262/// ```
1263/// # use arrow_array::record_batch;
1264/// # use arrow_ipc::writer::StreamWriter;
1265/// # let mut stream = vec![]; // mimic a stream for the example
1266/// let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
1267/// // create a new writer, the schema must be known in advance
1268/// let mut writer = StreamWriter::try_new(&mut stream, &batch.schema()).unwrap();
1269/// // write each batch to the underlying stream
1270/// writer.write(&batch).unwrap();
1271/// // When all batches are written, call finish to flush all buffers
1272/// writer.finish().unwrap();
1273/// ```
1274/// # Example - Efficient delta dictionaries
1275/// ```
1276/// # use arrow_array::record_batch;
1277/// # use arrow_ipc::writer::{StreamWriter, IpcWriteOptions};
1278/// # use arrow_ipc::writer::DictionaryHandling;
1279/// # use arrow_schema::{DataType, Field, Schema, SchemaRef};
1280/// # use arrow_array::{
1281/// #    builder::StringDictionaryBuilder, types::Int32Type, Array, ArrayRef, DictionaryArray,
1282/// #    RecordBatch, StringArray,
1283/// # };
1284/// # use std::sync::Arc;
1285///
1286/// let schema = Arc::new(Schema::new(vec![Field::new(
1287///    "col1",
1288///    DataType::Dictionary(Box::from(DataType::Int32), Box::from(DataType::Utf8)),
1289///    true,
1290/// )]));
1291///
1292/// let mut builder = StringDictionaryBuilder::<arrow_array::types::Int32Type>::new();
1293///
1294/// // `finish_preserve_values` will keep the dictionary values along with their
1295/// // key assignments so that they can be re-used in the next batch.
1296/// builder.append("a").unwrap();
1297/// builder.append("b").unwrap();
1298/// let array1 = builder.finish_preserve_values();
1299/// let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(array1) as ArrayRef]).unwrap();
1300///
1301/// // In this batch, 'a' will have the same dictionary key as 'a' in the previous batch,
1302/// // and 'd' will take the next available key.
1303/// builder.append("a").unwrap();
1304/// builder.append("d").unwrap();
1305/// let array2 = builder.finish_preserve_values();
1306/// let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(array2) as ArrayRef]).unwrap();
1307///
1308/// let mut stream = vec![];
1309/// // You must set `.with_dictionary_handling(DictionaryHandling::Delta)` to
1310/// // enable delta dictionaries in the writer
1311/// let options = IpcWriteOptions::default().with_dictionary_handling(DictionaryHandling::Delta);
1312/// let mut writer = StreamWriter::try_new_with_options(&mut stream, &schema, options).unwrap();
1313///
1314/// // When writing the first batch, a dictionary message with 'a' and 'b' will be written
1315/// // prior to the record batch.
1316/// writer.write(&batch1).unwrap();
1317/// // With the second batch only a delta dictionary with 'd' will be written
1318/// // prior to the record batch. This is only possible with `finish_preserve_values`.
1319/// // Without it, 'a' and 'd' in this batch would have different keys than the
1320/// // first batch and so we'd have to send a replacement dictionary with new keys
1321/// // for both.
1322/// writer.write(&batch2).unwrap();
1323/// writer.finish().unwrap();
1324/// ```
1325/// [IPC Streaming Format]: https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format
1326pub struct StreamWriter<W> {
1327    /// The object to write to
1328    writer: W,
1329    /// IPC write options
1330    write_options: IpcWriteOptions,
1331    /// Whether the writer footer has been written, and the writer is finished
1332    finished: bool,
1333    /// Keeps track of dictionaries that have been written
1334    dictionary_tracker: DictionaryTracker,
1335
1336    data_gen: IpcDataGenerator,
1337
1338    compression_context: CompressionContext,
1339}
1340
1341impl<W: Write> StreamWriter<BufWriter<W>> {
1342    /// Try to create a new stream writer with the writer wrapped in a BufWriter.
1343    ///
1344    /// See [`StreamWriter::try_new`] for an unbuffered version.
1345    pub fn try_new_buffered(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
1346        Self::try_new(BufWriter::new(writer), schema)
1347    }
1348}
1349
1350impl<W: Write> StreamWriter<W> {
1351    /// Try to create a new writer, with the schema written as part of the header.
1352    ///
1353    /// Note that there is no internal buffering. See also [`StreamWriter::try_new_buffered`].
1354    ///
1355    /// # Errors
1356    ///
1357    /// An [`Err`](Result::Err) may be returned if writing the header to the writer fails.
1358    pub fn try_new(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
1359        let write_options = IpcWriteOptions::default();
1360        Self::try_new_with_options(writer, schema, write_options)
1361    }
1362
1363    /// Try to create a new writer with [`IpcWriteOptions`].
1364    ///
1365    /// # Errors
1366    ///
1367    /// An [`Err`](Result::Err) may be returned if writing the header to the writer fails.
1368    pub fn try_new_with_options(
1369        mut writer: W,
1370        schema: &Schema,
1371        write_options: IpcWriteOptions,
1372    ) -> Result<Self, ArrowError> {
1373        let data_gen = IpcDataGenerator::default();
1374        let mut dictionary_tracker = DictionaryTracker::new(false);
1375
1376        // write the schema, set the written bytes to the schema
1377        let encoded_message = data_gen.schema_to_bytes_with_dictionary_tracker(
1378            schema,
1379            &mut dictionary_tracker,
1380            &write_options,
1381        );
1382        write_message(&mut writer, encoded_message, &write_options)?;
1383        Ok(Self {
1384            writer,
1385            write_options,
1386            finished: false,
1387            dictionary_tracker,
1388            data_gen,
1389            compression_context: CompressionContext::default(),
1390        })
1391    }
1392
1393    /// Write a record batch to the stream
1394    pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
1395        if self.finished {
1396            return Err(ArrowError::IpcError(
1397                "Cannot write record batch to stream writer as it is closed".to_string(),
1398            ));
1399        }
1400
1401        let (encoded_dictionaries, encoded_message) = self
1402            .data_gen
1403            .encode(
1404                batch,
1405                &mut self.dictionary_tracker,
1406                &self.write_options,
1407                &mut self.compression_context,
1408            )
1409            .expect("StreamWriter is configured to not error on dictionary replacement");
1410
1411        for encoded_dictionary in encoded_dictionaries {
1412            write_message(&mut self.writer, encoded_dictionary, &self.write_options)?;
1413        }
1414
1415        write_message(&mut self.writer, encoded_message, &self.write_options)?;
1416        Ok(())
1417    }
1418
1419    /// Write continuation bytes, and mark the stream as done
1420    pub fn finish(&mut self) -> Result<(), ArrowError> {
1421        if self.finished {
1422            return Err(ArrowError::IpcError(
1423                "Cannot write footer to stream writer as it is closed".to_string(),
1424            ));
1425        }
1426
1427        write_continuation(&mut self.writer, &self.write_options, 0)?;
1428
1429        self.finished = true;
1430
1431        Ok(())
1432    }
1433
1434    /// Gets a reference to the underlying writer.
1435    pub fn get_ref(&self) -> &W {
1436        &self.writer
1437    }
1438
1439    /// Gets a mutable reference to the underlying writer.
1440    ///
1441    /// It is inadvisable to directly write to the underlying writer.
1442    pub fn get_mut(&mut self) -> &mut W {
1443        &mut self.writer
1444    }
1445
1446    /// Flush the underlying writer.
1447    ///
1448    /// If the writer is a [`BufWriter`], both the buffer and the underlying writer are flushed.
1449    pub fn flush(&mut self) -> Result<(), ArrowError> {
1450        self.writer.flush()?;
1451        Ok(())
1452    }
1453
1454    /// Unwraps the underlying writer.
1455    ///
1456    /// The writer is flushed and the StreamWriter is finished before returning.
1457    ///
1458    /// # Errors
1459    ///
1460    /// An [`Err`](Result::Err) may be returned if an error occurs while finishing the StreamWriter
1461    /// or while flushing the writer.
1462    ///
1463    /// # Example
1464    ///
1465    /// ```
1466    /// # use arrow_ipc::writer::{StreamWriter, IpcWriteOptions};
1467    /// # use arrow_ipc::MetadataVersion;
1468    /// # use arrow_schema::{ArrowError, Schema};
1469    /// # fn main() -> Result<(), ArrowError> {
1470    /// // The result we expect from an empty schema
1471    /// let expected = vec![
1472    ///     255, 255, 255, 255,  48,   0,   0,   0,
1473    ///      16,   0,   0,   0,   0,   0,  10,   0,
1474    ///      12,   0,  10,   0,   9,   0,   4,   0,
1475    ///      10,   0,   0,   0,  16,   0,   0,   0,
1476    ///       0,   1,   4,   0,   8,   0,   8,   0,
1477    ///       0,   0,   4,   0,   8,   0,   0,   0,
1478    ///       4,   0,   0,   0,   0,   0,   0,   0,
1479    ///     255, 255, 255, 255,   0,   0,   0,   0
1480    /// ];
1481    ///
1482    /// let schema = Schema::empty();
1483    /// let buffer: Vec<u8> = Vec::new();
1484    /// let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5)?;
1485    /// let stream_writer = StreamWriter::try_new_with_options(buffer, &schema, options)?;
1486    ///
1487    /// assert_eq!(stream_writer.into_inner()?, expected);
1488    /// # Ok(())
1489    /// # }
1490    /// ```
1491    pub fn into_inner(mut self) -> Result<W, ArrowError> {
1492        if !self.finished {
1493            // `finish` flushes.
1494            self.finish()?;
1495        }
1496        Ok(self.writer)
1497    }
1498}
1499
1500impl<W: Write> RecordBatchWriter for StreamWriter<W> {
1501    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
1502        self.write(batch)
1503    }
1504
1505    fn close(mut self) -> Result<(), ArrowError> {
1506        self.finish()
1507    }
1508}
1509
1510/// Stores the encoded data, which is a crate::Message, and optional Arrow data
1511pub struct EncodedData {
1512    /// An encoded crate::Message
1513    pub ipc_message: Vec<u8>,
1514    /// Arrow buffers to be written; should be an empty vec for schema messages
1515    pub arrow_data: Vec<u8>,
1516}
1517/// Write a message's IPC data and buffers, returning metadata and buffer data lengths written
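///
/// # Example
///
/// A minimal sketch, framing an (empty) encoded message with the default write options;
/// the returned tuple is `(metadata_len, body_len)`.
///
/// ```
/// # use arrow_ipc::writer::{write_message, EncodedData, IpcWriteOptions};
/// # fn main() -> Result<(), arrow_schema::ArrowError> {
/// let encoded = EncodedData { ipc_message: vec![], arrow_data: vec![] };
/// let mut output: Vec<u8> = vec![];
/// let (metadata_len, body_len) = write_message(&mut output, encoded, &IpcWriteOptions::default())?;
/// // the metadata (length prefix + flatbuffer + padding) is a multiple of the default 64-byte alignment
/// assert_eq!(metadata_len % 64, 0);
/// assert_eq!(body_len, 0);
/// # Ok(())
/// # }
/// ```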
1518pub fn write_message<W: Write>(
1519    mut writer: W,
1520    encoded: EncodedData,
1521    write_options: &IpcWriteOptions,
1522) -> Result<(usize, usize), ArrowError> {
1523    let arrow_data_len = encoded.arrow_data.len();
1524    if arrow_data_len % usize::from(write_options.alignment) != 0 {
1525        return Err(ArrowError::MemoryError(
1526            "Arrow data not aligned".to_string(),
1527        ));
1528    }
1529
1530    let a = usize::from(write_options.alignment - 1);
1531    let buffer = encoded.ipc_message;
1532    let flatbuf_size = buffer.len();
1533    let prefix_size = if write_options.write_legacy_ipc_format {
1534        4
1535    } else {
1536        8
1537    };
1538    let aligned_size = (flatbuf_size + prefix_size + a) & !a;
1539    let padding_bytes = aligned_size - flatbuf_size - prefix_size;
1540
1541    write_continuation(
1542        &mut writer,
1543        write_options,
1544        (aligned_size - prefix_size) as i32,
1545    )?;
1546
1547    // write the flatbuf
1548    if flatbuf_size > 0 {
1549        writer.write_all(&buffer)?;
1550    }
1551    // write padding
1552    writer.write_all(&PADDING[..padding_bytes])?;
1553
1554    // write arrow data
1555    let body_len = if arrow_data_len > 0 {
1556        write_body_buffers(&mut writer, &encoded.arrow_data, write_options.alignment)?
1557    } else {
1558        0
1559    };
1560
1561    Ok((aligned_size, body_len))
1562}
1563
1564fn write_body_buffers<W: Write>(
1565    mut writer: W,
1566    data: &[u8],
1567    alignment: u8,
1568) -> Result<usize, ArrowError> {
1569    let len = data.len();
1570    let pad_len = pad_to_alignment(alignment, len);
1571    let total_len = len + pad_len;
1572
1573    // write body buffer
1574    writer.write_all(data)?;
1575    if pad_len > 0 {
1576        writer.write_all(&PADDING[..pad_len])?;
1577    }
1578
1579    writer.flush()?;
1580    Ok(total_len)
1581}
1582
1583/// Write the prefix that precedes each encapsulated IPC message: the continuation marker
1584/// (unless the legacy V4 format is used) followed by the message length
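///
/// For the non-legacy format the prefix is, as a sketch:
///
/// ```text
/// <continuation marker: 0xFFFFFFFF> <message length: i32, little-endian>
/// ```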
1585fn write_continuation<W: Write>(
1586    mut writer: W,
1587    write_options: &IpcWriteOptions,
1588    total_len: i32,
1589) -> Result<usize, ArrowError> {
1590    let mut written = 8;
1591
1592    // the version of the writer determines whether continuation markers should be added
1593    match write_options.metadata_version {
1594        crate::MetadataVersion::V1 | crate::MetadataVersion::V2 | crate::MetadataVersion::V3 => {
1595            unreachable!("Options with metadata versions V1, V2, or V3 cannot be created")
1596        }
1597        crate::MetadataVersion::V4 => {
1598            if !write_options.write_legacy_ipc_format {
1599                // v0.15.0 format
1600                writer.write_all(&CONTINUATION_MARKER)?;
1601                written = 4;
1602            }
1603            writer.write_all(&total_len.to_le_bytes()[..])?;
1604        }
1605        crate::MetadataVersion::V5 => {
1606            // write continuation marker and message length
1607            writer.write_all(&CONTINUATION_MARKER)?;
1608            writer.write_all(&total_len.to_le_bytes()[..])?;
1609        }
1610        z => panic!("Unsupported crate::MetadataVersion {z:?}"),
1611    };
1612
1613    writer.flush()?;
1614
1615    Ok(written)
1616}
1617
1618/// In V4, null types have no validity bitmap.
1619/// In V5 and later, null, union, and run-end encoded
1620/// types have no validity bitmap.
1621fn has_validity_bitmap(data_type: &DataType, write_options: &IpcWriteOptions) -> bool {
1622    if write_options.metadata_version < crate::MetadataVersion::V5 {
1623        !matches!(data_type, DataType::Null)
1624    } else {
1625        !matches!(
1626            data_type,
1627            DataType::Null | DataType::Union(_, _) | DataType::RunEndEncoded(_, _)
1628        )
1629    }
1630}
1631
1632/// Whether the buffer needs to be truncated before writing (the array is sliced, or is shorter than its backing buffer)
1633#[inline]
1634fn buffer_need_truncate(
1635    array_offset: usize,
1636    buffer: &Buffer,
1637    spec: &BufferSpec,
1638    min_length: usize,
1639) -> bool {
1640    spec != &BufferSpec::AlwaysNull && (array_offset != 0 || min_length < buffer.len())
1641}
1642
1643/// Returns the byte width for a `BufferSpec::FixedWidth` buffer spec, and 0 for any other spec.
1644#[inline]
1645fn get_buffer_element_width(spec: &BufferSpec) -> usize {
1646    match spec {
1647        BufferSpec::FixedWidth { byte_width, .. } => *byte_width,
1648        _ => 0,
1649    }
1650}
1651
1652/// Common functionality for re-encoding offsets. Returns the new offsets as well as
1653/// original start offset and length for use in slicing child data.
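///
/// For example (a sketch): offsets `[3, 5, 9]` for a sliced array are rebased to `[0, 2, 6]`,
/// and `(start_offset, len) = (3, 6)` is returned so the caller can slice the values buffer or
/// child data accordingly.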
1654fn reencode_offsets<O: OffsetSizeTrait>(
1655    offsets: &Buffer,
1656    data: &ArrayData,
1657) -> (Buffer, usize, usize) {
1658    let offsets_slice: &[O] = offsets.typed_data::<O>();
1659    let offset_slice = &offsets_slice[data.offset()..data.offset() + data.len() + 1];
1660
1661    let start_offset = offset_slice.first().unwrap();
1662    let end_offset = offset_slice.last().unwrap();
1663
1664    let offsets = match start_offset.as_usize() {
1665        0 => {
1666            let size = size_of::<O>();
1667            offsets.slice_with_length(data.offset() * size, (data.len() + 1) * size)
1668        }
1669        _ => offset_slice.iter().map(|x| *x - *start_offset).collect(),
1670    };
1671
1672    let start_offset = start_offset.as_usize();
1673    let end_offset = end_offset.as_usize();
1674
1675    (offsets, start_offset, end_offset - start_offset)
1676}
1677
1678/// Returns the values and offsets [`Buffer`] for a ByteArray with offset type `O`
1679///
1680/// In particular, this handles re-encoding the offsets if they don't start at `0`,
1681/// slicing the values buffer as appropriate. This helps reduce the encoded
1682/// size of sliced arrays, as values that have been sliced away are not encoded
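///
/// For example (a sketch): a `Utf8` array over the values `["a", "bb", "ccc"]` that has been
/// sliced to the last two elements is written with rebased offsets `[0, 2, 5]` and only the
/// bytes `"bbccc"` from the values buffer.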
1683fn get_byte_array_buffers<O: OffsetSizeTrait>(data: &ArrayData) -> (Buffer, Buffer) {
1684    if data.is_empty() {
1685        return (MutableBuffer::new(0).into(), MutableBuffer::new(0).into());
1686    }
1687
1688    let (offsets, original_start_offset, len) = reencode_offsets::<O>(&data.buffers()[0], data);
1689    let values = data.buffers()[1].slice_with_length(original_start_offset, len);
1690    (offsets, values)
1691}
1692
1693/// Similar logic as [`get_byte_array_buffers()`] but slices the child array instead
1694/// of a values buffer.
1695fn get_list_array_buffers<O: OffsetSizeTrait>(data: &ArrayData) -> (Buffer, ArrayData) {
1696    if data.is_empty() {
1697        return (
1698            MutableBuffer::new(0).into(),
1699            data.child_data()[0].slice(0, 0),
1700        );
1701    }
1702
1703    let (offsets, original_start_offset, len) = reencode_offsets::<O>(&data.buffers()[0], data);
1704    let child_data = data.child_data()[0].slice(original_start_offset, len);
1705    (offsets, child_data)
1706}
1707
1708/// Write array data to a vector of bytes
1709#[allow(clippy::too_many_arguments)]
1710fn write_array_data(
1711    array_data: &ArrayData,
1712    buffers: &mut Vec<crate::Buffer>,
1713    arrow_data: &mut Vec<u8>,
1714    nodes: &mut Vec<crate::FieldNode>,
1715    offset: i64,
1716    num_rows: usize,
1717    null_count: usize,
1718    compression_codec: Option<CompressionCodec>,
1719    compression_context: &mut CompressionContext,
1720    write_options: &IpcWriteOptions,
1721) -> Result<i64, ArrowError> {
1722    let mut offset = offset;
1723    if !matches!(array_data.data_type(), DataType::Null) {
1724        nodes.push(crate::FieldNode::new(num_rows as i64, null_count as i64));
1725    } else {
1726        // NullArray's null_count equals its len, but the `null_count` passed in comes from ArrayData,
1727        // where null_count is always 0.
1728        nodes.push(crate::FieldNode::new(num_rows as i64, num_rows as i64));
1729    }
1730    if has_validity_bitmap(array_data.data_type(), write_options) {
1731        // write null buffer if exists
1732        let null_buffer = match array_data.nulls() {
1733            None => {
1734                // create a buffer and fill it with valid bits
1735                let num_bytes = bit_util::ceil(num_rows, 8);
1736                let buffer = MutableBuffer::new(num_bytes);
1737                let buffer = buffer.with_bitset(num_bytes, true);
1738                buffer.into()
1739            }
1740            Some(buffer) => buffer.inner().sliced(),
1741        };
1742
1743        offset = write_buffer(
1744            null_buffer.as_slice(),
1745            buffers,
1746            arrow_data,
1747            offset,
1748            compression_codec,
1749            compression_context,
1750            write_options.alignment,
1751        )?;
1752    }
1753
1754    let data_type = array_data.data_type();
1755    if matches!(data_type, DataType::Binary | DataType::Utf8) {
1756        let (offsets, values) = get_byte_array_buffers::<i32>(array_data);
1757        for buffer in [offsets, values] {
1758            offset = write_buffer(
1759                buffer.as_slice(),
1760                buffers,
1761                arrow_data,
1762                offset,
1763                compression_codec,
1764                compression_context,
1765                write_options.alignment,
1766            )?;
1767        }
1768    } else if matches!(data_type, DataType::BinaryView | DataType::Utf8View) {
1769        // Slicing the views buffer is safe and easy,
1770        // but pruning unneeded data buffers is much more nuanced since it's complicated to prove that no views reference the pruned buffers
1771        //
1772        // The current implementation just serializes the raw arrays as given and does not try to optimize anything.
1773        // If users want to "compact" the arrays prior to sending them over IPC,
1774        // they should consider the gc API suggested in #5513
1775        for buffer in array_data.buffers() {
1776            offset = write_buffer(
1777                buffer.as_slice(),
1778                buffers,
1779                arrow_data,
1780                offset,
1781                compression_codec,
1782                compression_context,
1783                write_options.alignment,
1784            )?;
1785        }
1786    } else if matches!(data_type, DataType::LargeBinary | DataType::LargeUtf8) {
1787        let (offsets, values) = get_byte_array_buffers::<i64>(array_data);
1788        for buffer in [offsets, values] {
1789            offset = write_buffer(
1790                buffer.as_slice(),
1791                buffers,
1792                arrow_data,
1793                offset,
1794                compression_codec,
1795                compression_context,
1796                write_options.alignment,
1797            )?;
1798        }
1799    } else if DataType::is_numeric(data_type)
1800        || DataType::is_temporal(data_type)
1801        || matches!(
1802            array_data.data_type(),
1803            DataType::FixedSizeBinary(_) | DataType::Dictionary(_, _)
1804        )
1805    {
1806        // Truncate values
1807        assert_eq!(array_data.buffers().len(), 1);
1808
1809        let buffer = &array_data.buffers()[0];
1810        let layout = layout(data_type);
1811        let spec = &layout.buffers[0];
1812
1813        let byte_width = get_buffer_element_width(spec);
1814        let min_length = array_data.len() * byte_width;
1815        let buffer_slice = if buffer_need_truncate(array_data.offset(), buffer, spec, min_length) {
1816            let byte_offset = array_data.offset() * byte_width;
1817            let buffer_length = min(min_length, buffer.len() - byte_offset);
1818            &buffer.as_slice()[byte_offset..(byte_offset + buffer_length)]
1819        } else {
1820            buffer.as_slice()
1821        };
1822        offset = write_buffer(
1823            buffer_slice,
1824            buffers,
1825            arrow_data,
1826            offset,
1827            compression_codec,
1828            compression_context,
1829            write_options.alignment,
1830        )?;
1831    } else if matches!(data_type, DataType::Boolean) {
1832        // Bools are special because the payload (= 1 bit) is smaller than the physical container elements (= bytes).
1833        // The array data may not start at the physical boundary of the underlying buffer, so we need to shift bits around.
1834        assert_eq!(array_data.buffers().len(), 1);
1835
1836        let buffer = &array_data.buffers()[0];
1837        let buffer = buffer.bit_slice(array_data.offset(), array_data.len());
1838        offset = write_buffer(
1839            &buffer,
1840            buffers,
1841            arrow_data,
1842            offset,
1843            compression_codec,
1844            compression_context,
1845            write_options.alignment,
1846        )?;
1847    } else if matches!(
1848        data_type,
1849        DataType::List(_) | DataType::LargeList(_) | DataType::Map(_, _)
1850    ) {
1851        assert_eq!(array_data.buffers().len(), 1);
1852        assert_eq!(array_data.child_data().len(), 1);
1853
1854        // Truncate offsets and the child data to avoid writing unnecessary data
1855        let (offsets, sliced_child_data) = match data_type {
1856            DataType::List(_) => get_list_array_buffers::<i32>(array_data),
1857            DataType::Map(_, _) => get_list_array_buffers::<i32>(array_data),
1858            DataType::LargeList(_) => get_list_array_buffers::<i64>(array_data),
1859            _ => unreachable!(),
1860        };
1861        offset = write_buffer(
1862            offsets.as_slice(),
1863            buffers,
1864            arrow_data,
1865            offset,
1866            compression_codec,
1867            compression_context,
1868            write_options.alignment,
1869        )?;
1870        offset = write_array_data(
1871            &sliced_child_data,
1872            buffers,
1873            arrow_data,
1874            nodes,
1875            offset,
1876            sliced_child_data.len(),
1877            sliced_child_data.null_count(),
1878            compression_codec,
1879            compression_context,
1880            write_options,
1881        )?;
1882        return Ok(offset);
1883    } else if let DataType::FixedSizeList(_, fixed_size) = data_type {
1884        assert_eq!(array_data.child_data().len(), 1);
1885        let fixed_size = *fixed_size as usize;
1886
1887        let child_offset = array_data.offset() * fixed_size;
1888        let child_length = array_data.len() * fixed_size;
1889        let child_data = array_data.child_data()[0].slice(child_offset, child_length);
1890
1891        offset = write_array_data(
1892            &child_data,
1893            buffers,
1894            arrow_data,
1895            nodes,
1896            offset,
1897            child_data.len(),
1898            child_data.null_count(),
1899            compression_codec,
1900            compression_context,
1901            write_options,
1902        )?;
1903        return Ok(offset);
1904    } else {
1905        for buffer in array_data.buffers() {
1906            offset = write_buffer(
1907                buffer,
1908                buffers,
1909                arrow_data,
1910                offset,
1911                compression_codec,
1912                compression_context,
1913                write_options.alignment,
1914            )?;
1915        }
1916    }
1917
1918    match array_data.data_type() {
1919        DataType::Dictionary(_, _) => {}
1920        DataType::RunEndEncoded(_, _) => {
1921            // unslice the run encoded array.
1922            let arr = unslice_run_array(array_data.clone())?;
1923            // recursively write out nested structures
1924            for data_ref in arr.child_data() {
1925                // write the nested data (e.g list data)
1926                offset = write_array_data(
1927                    data_ref,
1928                    buffers,
1929                    arrow_data,
1930                    nodes,
1931                    offset,
1932                    data_ref.len(),
1933                    data_ref.null_count(),
1934                    compression_codec,
1935                    compression_context,
1936                    write_options,
1937                )?;
1938            }
1939        }
1940        _ => {
1941            // recursively write out nested structures
1942            for data_ref in array_data.child_data() {
1943                // write the nested data (e.g list data)
1944                offset = write_array_data(
1945                    data_ref,
1946                    buffers,
1947                    arrow_data,
1948                    nodes,
1949                    offset,
1950                    data_ref.len(),
1951                    data_ref.null_count(),
1952                    compression_codec,
1953                    compression_context,
1954                    write_options,
1955                )?;
1956            }
1957        }
1958    }
1959    Ok(offset)
1960}
1961
1962/// Writes a buffer into `arrow_data`, a vector of bytes, and adds its
1963/// [`crate::Buffer`] to `buffers`. Returns the new offset in `arrow_data`
1964///
1965///
1966/// From <https://github.com/apache/arrow/blob/6a936c4ff5007045e86f65f1a6b6c3c955ad5103/format/Message.fbs#L58>
1967/// Each constituent buffer is first compressed with the indicated
1968/// compressor, and then written with the uncompressed length in the first 8
1969/// bytes as a 64-bit little-endian signed integer followed by the compressed
1970/// buffer bytes (and then padding as required by the protocol). The
1971/// uncompressed length may be set to -1 to indicate that the data that
1972/// follows is not compressed, which can be useful for cases where
1973/// compression does not yield appreciable savings.
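///
/// With a compression codec configured, each buffer in `arrow_data` is therefore framed as
/// (a sketch):
///
/// ```text
/// <uncompressed length: i64 little-endian, or -1 if stored uncompressed> <buffer bytes> <padding>
/// ```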
1974fn write_buffer(
1975    buffer: &[u8],                    // input
1976    buffers: &mut Vec<crate::Buffer>, // output buffer descriptors
1977    arrow_data: &mut Vec<u8>,         // output stream
1978    offset: i64,                      // current output stream offset
1979    compression_codec: Option<CompressionCodec>,
1980    compression_context: &mut CompressionContext,
1981    alignment: u8,
1982) -> Result<i64, ArrowError> {
1983    let len: i64 = match compression_codec {
1984        Some(compressor) => compressor.compress_to_vec(buffer, arrow_data, compression_context)?,
1985        None => {
1986            arrow_data.extend_from_slice(buffer);
1987            buffer.len()
1988        }
1989    }
1990    .try_into()
1991    .map_err(|e| {
1992        ArrowError::InvalidArgumentError(format!("Could not convert compressed size to i64: {e}"))
1993    })?;
1994
1995    // make new index entry
1996    buffers.push(crate::Buffer::new(offset, len));
1997    // padding and make offset aligned
1998    let pad_len = pad_to_alignment(alignment, len as usize);
1999    arrow_data.extend_from_slice(&PADDING[..pad_len]);
2000
2001    Ok(offset + len + (pad_len as i64))
2002}
2003
2004const PADDING: [u8; 64] = [0; 64];
2005
2006/// Calculate an alignment boundary and return the number of bytes needed to pad to the alignment boundary
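///
/// For example, `pad_to_alignment(8, 5)` returns `3` and `pad_to_alignment(8, 8)` returns `0`.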
2007#[inline]
2008fn pad_to_alignment(alignment: u8, len: usize) -> usize {
2009    let a = usize::from(alignment - 1);
2010    ((len + a) & !a) - len
2011}
2012
2013#[cfg(test)]
2014mod tests {
2015    use std::hash::Hasher;
2016    use std::io::Cursor;
2017    use std::io::Seek;
2018
2019    use arrow_array::builder::FixedSizeListBuilder;
2020    use arrow_array::builder::Float32Builder;
2021    use arrow_array::builder::Int64Builder;
2022    use arrow_array::builder::MapBuilder;
2023    use arrow_array::builder::UnionBuilder;
2024    use arrow_array::builder::{GenericListBuilder, ListBuilder, StringBuilder};
2025    use arrow_array::builder::{PrimitiveRunBuilder, UInt32Builder};
2026    use arrow_array::types::*;
2027    use arrow_buffer::ScalarBuffer;
2028
2029    use crate::MetadataVersion;
2030    use crate::convert::fb_to_schema;
2031    use crate::reader::*;
2032    use crate::root_as_footer;
2033
2034    use super::*;
2035
2036    fn serialize_file(rb: &RecordBatch) -> Vec<u8> {
2037        let mut writer = FileWriter::try_new(vec![], rb.schema_ref()).unwrap();
2038        writer.write(rb).unwrap();
2039        writer.finish().unwrap();
2040        writer.into_inner().unwrap()
2041    }
2042
2043    fn deserialize_file(bytes: Vec<u8>) -> RecordBatch {
2044        let mut reader = FileReader::try_new(Cursor::new(bytes), None).unwrap();
2045        reader.next().unwrap().unwrap()
2046    }
2047
2048    fn serialize_stream(record: &RecordBatch) -> Vec<u8> {
2049        // Use 8-byte alignment so that the various `truncate_*` tests can be compactly written,
2050        // without needing to construct a giant array to spill over the 64-byte default alignment
2051        // boundary.
2052        const IPC_ALIGNMENT: usize = 8;
2053
2054        let mut stream_writer = StreamWriter::try_new_with_options(
2055            vec![],
2056            record.schema_ref(),
2057            IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(),
2058        )
2059        .unwrap();
2060        stream_writer.write(record).unwrap();
2061        stream_writer.finish().unwrap();
2062        stream_writer.into_inner().unwrap()
2063    }
2064
2065    fn deserialize_stream(bytes: Vec<u8>) -> RecordBatch {
2066        let mut stream_reader = StreamReader::try_new(Cursor::new(bytes), None).unwrap();
2067        stream_reader.next().unwrap().unwrap()
2068    }
2069
2070    #[test]
2071    #[cfg(feature = "lz4")]
2072    fn test_write_empty_record_batch_lz4_compression() {
2073        let schema = Schema::new(vec![Field::new("field1", DataType::Int32, true)]);
2074        let values: Vec<Option<i32>> = vec![];
2075        let array = Int32Array::from(values);
2076        let record_batch =
2077            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]).unwrap();
2078
2079        let mut file = tempfile::tempfile().unwrap();
2080
2081        {
2082            let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5)
2083                .unwrap()
2084                .try_with_compression(Some(crate::CompressionType::LZ4_FRAME))
2085                .unwrap();
2086
2087            let mut writer =
2088                FileWriter::try_new_with_options(&mut file, &schema, write_option).unwrap();
2089            writer.write(&record_batch).unwrap();
2090            writer.finish().unwrap();
2091        }
2092        file.rewind().unwrap();
2093        {
2094            // read file
2095            let reader = FileReader::try_new(file, None).unwrap();
2096            for read_batch in reader {
2097                read_batch
2098                    .unwrap()
2099                    .columns()
2100                    .iter()
2101                    .zip(record_batch.columns())
2102                    .for_each(|(a, b)| {
2103                        assert_eq!(a.data_type(), b.data_type());
2104                        assert_eq!(a.len(), b.len());
2105                        assert_eq!(a.null_count(), b.null_count());
2106                    });
2107            }
2108        }
2109    }
2110
2111    #[test]
2112    #[cfg(feature = "lz4")]
2113    fn test_write_file_with_lz4_compression() {
2114        let schema = Schema::new(vec![Field::new("field1", DataType::Int32, true)]);
2115        let values: Vec<Option<i32>> = vec![Some(12), Some(1)];
2116        let array = Int32Array::from(values);
2117        let record_batch =
2118            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]).unwrap();
2119
2120        let mut file = tempfile::tempfile().unwrap();
2121        {
2122            let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5)
2123                .unwrap()
2124                .try_with_compression(Some(crate::CompressionType::LZ4_FRAME))
2125                .unwrap();
2126
2127            let mut writer =
2128                FileWriter::try_new_with_options(&mut file, &schema, write_option).unwrap();
2129            writer.write(&record_batch).unwrap();
2130            writer.finish().unwrap();
2131        }
2132        file.rewind().unwrap();
2133        {
2134            // read file
2135            let reader = FileReader::try_new(file, None).unwrap();
2136            for read_batch in reader {
2137                read_batch
2138                    .unwrap()
2139                    .columns()
2140                    .iter()
2141                    .zip(record_batch.columns())
2142                    .for_each(|(a, b)| {
2143                        assert_eq!(a.data_type(), b.data_type());
2144                        assert_eq!(a.len(), b.len());
2145                        assert_eq!(a.null_count(), b.null_count());
2146                    });
2147            }
2148        }
2149    }
2150
2151    #[test]
2152    #[cfg(feature = "zstd")]
2153    fn test_write_file_with_zstd_compression() {
2154        let schema = Schema::new(vec![Field::new("field1", DataType::Int32, true)]);
2155        let values: Vec<Option<i32>> = vec![Some(12), Some(1)];
2156        let array = Int32Array::from(values);
2157        let record_batch =
2158            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]).unwrap();
2159        let mut file = tempfile::tempfile().unwrap();
2160        {
2161            let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5)
2162                .unwrap()
2163                .try_with_compression(Some(crate::CompressionType::ZSTD))
2164                .unwrap();
2165
2166            let mut writer =
2167                FileWriter::try_new_with_options(&mut file, &schema, write_option).unwrap();
2168            writer.write(&record_batch).unwrap();
2169            writer.finish().unwrap();
2170        }
2171        file.rewind().unwrap();
2172        {
2173            // read file
2174            let reader = FileReader::try_new(file, None).unwrap();
2175            for read_batch in reader {
2176                read_batch
2177                    .unwrap()
2178                    .columns()
2179                    .iter()
2180                    .zip(record_batch.columns())
2181                    .for_each(|(a, b)| {
2182                        assert_eq!(a.data_type(), b.data_type());
2183                        assert_eq!(a.len(), b.len());
2184                        assert_eq!(a.null_count(), b.null_count());
2185                    });
2186            }
2187        }
2188    }
2189
2190    #[test]
2191    fn test_write_file() {
2192        let schema = Schema::new(vec![Field::new("field1", DataType::UInt32, true)]);
2193        let values: Vec<Option<u32>> = vec![
2194            Some(999),
2195            None,
2196            Some(235),
2197            Some(123),
2198            None,
2199            None,
2200            None,
2201            None,
2202            None,
2203        ];
2204        let array1 = UInt32Array::from(values);
2205        let batch =
2206            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array1) as ArrayRef])
2207                .unwrap();
2208        let mut file = tempfile::tempfile().unwrap();
2209        {
2210            let mut writer = FileWriter::try_new(&mut file, &schema).unwrap();
2211
2212            writer.write(&batch).unwrap();
2213            writer.finish().unwrap();
2214        }
2215        file.rewind().unwrap();
2216
2217        {
2218            let mut reader = FileReader::try_new(file, None).unwrap();
2219            while let Some(Ok(read_batch)) = reader.next() {
2220                read_batch
2221                    .columns()
2222                    .iter()
2223                    .zip(batch.columns())
2224                    .for_each(|(a, b)| {
2225                        assert_eq!(a.data_type(), b.data_type());
2226                        assert_eq!(a.len(), b.len());
2227                        assert_eq!(a.null_count(), b.null_count());
2228                    });
2229            }
2230        }
2231    }
2232
2233    fn write_null_file(options: IpcWriteOptions) {
2234        let schema = Schema::new(vec![
2235            Field::new("nulls", DataType::Null, true),
2236            Field::new("int32s", DataType::Int32, false),
2237            Field::new("nulls2", DataType::Null, true),
2238            Field::new("f64s", DataType::Float64, false),
2239        ]);
2240        let array1 = NullArray::new(32);
2241        let array2 = Int32Array::from(vec![1; 32]);
2242        let array3 = NullArray::new(32);
2243        let array4 = Float64Array::from(vec![f64::NAN; 32]);
2244        let batch = RecordBatch::try_new(
2245            Arc::new(schema.clone()),
2246            vec![
2247                Arc::new(array1) as ArrayRef,
2248                Arc::new(array2) as ArrayRef,
2249                Arc::new(array3) as ArrayRef,
2250                Arc::new(array4) as ArrayRef,
2251            ],
2252        )
2253        .unwrap();
2254        let mut file = tempfile::tempfile().unwrap();
2255        {
2256            let mut writer = FileWriter::try_new_with_options(&mut file, &schema, options).unwrap();
2257
2258            writer.write(&batch).unwrap();
2259            writer.finish().unwrap();
2260        }
2261
2262        file.rewind().unwrap();
2263
2264        {
2265            let reader = FileReader::try_new(file, None).unwrap();
2266            reader.for_each(|maybe_batch| {
2267                maybe_batch
2268                    .unwrap()
2269                    .columns()
2270                    .iter()
2271                    .zip(batch.columns())
2272                    .for_each(|(a, b)| {
2273                        assert_eq!(a.data_type(), b.data_type());
2274                        assert_eq!(a.len(), b.len());
2275                        assert_eq!(a.null_count(), b.null_count());
2276                    });
2277            });
2278        }
2279    }
2280    #[test]
2281    fn test_write_null_file_v4() {
2282        write_null_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V4).unwrap());
2283        write_null_file(IpcWriteOptions::try_new(8, true, MetadataVersion::V4).unwrap());
2284        write_null_file(IpcWriteOptions::try_new(64, false, MetadataVersion::V4).unwrap());
2285        write_null_file(IpcWriteOptions::try_new(64, true, MetadataVersion::V4).unwrap());
2286    }
2287
2288    #[test]
2289    fn test_write_null_file_v5() {
2290        write_null_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap());
2291        write_null_file(IpcWriteOptions::try_new(64, false, MetadataVersion::V5).unwrap());
2292    }
2293
2294    #[test]
2295    fn track_union_nested_dict() {
2296        let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();
2297
2298        let array = Arc::new(inner) as ArrayRef;
2299
2300        // Dict field with id 0
2301        #[allow(deprecated)]
2302        let dctfield = Field::new_dict("dict", array.data_type().clone(), false, 0, false);
2303        let union_fields = [(0, Arc::new(dctfield))].into_iter().collect();
2304
2305        let types = [0, 0, 0].into_iter().collect::<ScalarBuffer<i8>>();
2306        let offsets = [0, 1, 2].into_iter().collect::<ScalarBuffer<i32>>();
2307
2308        let union = UnionArray::try_new(union_fields, types, Some(offsets), vec![array]).unwrap();
2309
2310        let schema = Arc::new(Schema::new(vec![Field::new(
2311            "union",
2312            union.data_type().clone(),
2313            false,
2314        )]));
2315
2316        let r#gen = IpcDataGenerator::default();
2317        let mut dict_tracker = DictionaryTracker::new(false);
2318        r#gen.schema_to_bytes_with_dictionary_tracker(
2319            &schema,
2320            &mut dict_tracker,
2321            &IpcWriteOptions::default(),
2322        );
2323
2324        let batch = RecordBatch::try_new(schema, vec![Arc::new(union)]).unwrap();
2325
2326        r#gen
2327            .encode(
2328                &batch,
2329                &mut dict_tracker,
2330                &Default::default(),
2331                &mut Default::default(),
2332            )
2333            .unwrap();
2334
2335        // The encoder will assign dict IDs itself to ensure uniqueness and ignore the dict ID in the schema,
2336        // so we expect the dict to be keyed to 0
2337        assert!(dict_tracker.written.contains_key(&0));
2338    }
2339
2340    #[test]
2341    fn track_struct_nested_dict() {
2342        let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();
2343
2344        let array = Arc::new(inner) as ArrayRef;
2345
2346        // Dict field with id 2
2347        #[allow(deprecated)]
2348        let dctfield = Arc::new(Field::new_dict(
2349            "dict",
2350            array.data_type().clone(),
2351            false,
2352            2,
2353            false,
2354        ));
2355
2356        let s = StructArray::from(vec![(dctfield, array)]);
2357        let struct_array = Arc::new(s) as ArrayRef;
2358
2359        let schema = Arc::new(Schema::new(vec![Field::new(
2360            "struct",
2361            struct_array.data_type().clone(),
2362            false,
2363        )]));
2364
2365        let r#gen = IpcDataGenerator::default();
2366        let mut dict_tracker = DictionaryTracker::new(false);
2367        r#gen.schema_to_bytes_with_dictionary_tracker(
2368            &schema,
2369            &mut dict_tracker,
2370            &IpcWriteOptions::default(),
2371        );
2372
2373        let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap();
2374
2375        r#gen
2376            .encode(
2377                &batch,
2378                &mut dict_tracker,
2379                &Default::default(),
2380                &mut Default::default(),
2381            )
2382            .unwrap();
2383
2384        assert!(dict_tracker.written.contains_key(&0));
2385    }
2386
2387    fn write_union_file(options: IpcWriteOptions) {
2388        let schema = Schema::new(vec![Field::new_union(
2389            "union",
2390            vec![0, 1],
2391            vec![
2392                Field::new("a", DataType::Int32, false),
2393                Field::new("c", DataType::Float64, false),
2394            ],
2395            UnionMode::Sparse,
2396        )]);
2397        let mut builder = UnionBuilder::with_capacity_sparse(5);
2398        builder.append::<Int32Type>("a", 1).unwrap();
2399        builder.append_null::<Int32Type>("a").unwrap();
2400        builder.append::<Float64Type>("c", 3.0).unwrap();
2401        builder.append_null::<Float64Type>("c").unwrap();
2402        builder.append::<Int32Type>("a", 4).unwrap();
2403        let union = builder.build().unwrap();
2404
2405        let batch =
2406            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(union) as ArrayRef])
2407                .unwrap();
2408
2409        let mut file = tempfile::tempfile().unwrap();
2410        {
2411            let mut writer = FileWriter::try_new_with_options(&mut file, &schema, options).unwrap();
2412
2413            writer.write(&batch).unwrap();
2414            writer.finish().unwrap();
2415        }
2416        file.rewind().unwrap();
2417
2418        {
2419            let reader = FileReader::try_new(file, None).unwrap();
2420            reader.for_each(|maybe_batch| {
2421                maybe_batch
2422                    .unwrap()
2423                    .columns()
2424                    .iter()
2425                    .zip(batch.columns())
2426                    .for_each(|(a, b)| {
2427                        assert_eq!(a.data_type(), b.data_type());
2428                        assert_eq!(a.len(), b.len());
2429                        assert_eq!(a.null_count(), b.null_count());
2430                    });
2431            });
2432        }
2433    }
2434
2435    #[test]
2436    fn test_write_union_file_v4_v5() {
2437        write_union_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V4).unwrap());
2438        write_union_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap());
2439    }
2440
2441    #[test]
2442    fn test_write_view_types() {
2443        const LONG_TEST_STRING: &str =
2444            "This is a long string to make sure binary view array handles it";
2445        let schema = Schema::new(vec![
2446            Field::new("field1", DataType::BinaryView, true),
2447            Field::new("field2", DataType::Utf8View, true),
2448        ]);
2449        let values: Vec<Option<&[u8]>> = vec![
2450            Some(b"foo"),
2451            Some(b"bar"),
2452            Some(LONG_TEST_STRING.as_bytes()),
2453        ];
2454        let binary_array = BinaryViewArray::from_iter(values);
2455        let utf8_array =
2456            StringViewArray::from_iter(vec![Some("foo"), Some("bar"), Some(LONG_TEST_STRING)]);
2457        let record_batch = RecordBatch::try_new(
2458            Arc::new(schema.clone()),
2459            vec![Arc::new(binary_array), Arc::new(utf8_array)],
2460        )
2461        .unwrap();
2462
2463        let mut file = tempfile::tempfile().unwrap();
2464        {
2465            let mut writer = FileWriter::try_new(&mut file, &schema).unwrap();
2466            writer.write(&record_batch).unwrap();
2467            writer.finish().unwrap();
2468        }
2469        file.rewind().unwrap();
2470        {
2471            let mut reader = FileReader::try_new(&file, None).unwrap();
2472            let read_batch = reader.next().unwrap().unwrap();
2473            read_batch
2474                .columns()
2475                .iter()
2476                .zip(record_batch.columns())
2477                .for_each(|(a, b)| {
2478                    assert_eq!(a, b);
2479                });
2480        }
2481        file.rewind().unwrap();
2482        {
2483            let mut reader = FileReader::try_new(&file, Some(vec![0])).unwrap();
2484            let read_batch = reader.next().unwrap().unwrap();
2485            assert_eq!(read_batch.num_columns(), 1);
2486            let read_array = read_batch.column(0);
2487            let write_array = record_batch.column(0);
2488            assert_eq!(read_array, write_array);
2489        }
2490    }
2491
2492    #[test]
2493    fn truncate_ipc_record_batch() {
2494        fn create_batch(rows: usize) -> RecordBatch {
2495            let schema = Schema::new(vec![
2496                Field::new("a", DataType::Int32, false),
2497                Field::new("b", DataType::Utf8, false),
2498            ]);
2499
2500            let a = Int32Array::from_iter_values(0..rows as i32);
2501            let b = StringArray::from_iter_values((0..rows).map(|i| i.to_string()));
2502
2503            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap()
2504        }
2505
2506        let big_record_batch = create_batch(65536);
2507
2508        let length = 5;
2509        let small_record_batch = create_batch(length);
2510
2511        let offset = 2;
2512        let record_batch_slice = big_record_batch.slice(offset, length);
2513        assert!(
2514            serialize_stream(&big_record_batch).len() > serialize_stream(&small_record_batch).len()
2515        );
2516        assert_eq!(
2517            serialize_stream(&small_record_batch).len(),
2518            serialize_stream(&record_batch_slice).len()
2519        );
2520
2521        assert_eq!(
2522            deserialize_stream(serialize_stream(&record_batch_slice)),
2523            record_batch_slice
2524        );
2525    }
2526
2527    #[test]
2528    fn truncate_ipc_record_batch_with_nulls() {
2529        fn create_batch() -> RecordBatch {
2530            let schema = Schema::new(vec![
2531                Field::new("a", DataType::Int32, true),
2532                Field::new("b", DataType::Utf8, true),
2533            ]);
2534
2535            let a = Int32Array::from(vec![Some(1), None, Some(1), None, Some(1)]);
2536            let b = StringArray::from(vec![None, Some("a"), Some("a"), None, Some("a")]);
2537
2538            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap()
2539        }
2540
2541        let record_batch = create_batch();
2542        let record_batch_slice = record_batch.slice(1, 2);
2543        let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice));
2544
2545        assert!(
2546            serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len()
2547        );
2548
2549        assert!(deserialized_batch.column(0).is_null(0));
2550        assert!(deserialized_batch.column(0).is_valid(1));
2551        assert!(deserialized_batch.column(1).is_valid(0));
2552        assert!(deserialized_batch.column(1).is_valid(1));
2553
2554        assert_eq!(record_batch_slice, deserialized_batch);
2555    }
2556
2557    #[test]
2558    fn truncate_ipc_dictionary_array() {
2559        fn create_batch() -> RecordBatch {
2560            let values: StringArray = [Some("foo"), Some("bar"), Some("baz")]
2561                .into_iter()
2562                .collect();
2563            let keys: Int32Array = [Some(0), Some(2), None, Some(1)].into_iter().collect();
2564
2565            let array = DictionaryArray::new(keys, Arc::new(values));
2566
2567            let schema = Schema::new(vec![Field::new("dict", array.data_type().clone(), true)]);
2568
2569            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap()
2570        }
2571
2572        let record_batch = create_batch();
2573        let record_batch_slice = record_batch.slice(1, 2);
2574        let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice));
2575
2576        assert!(
2577            serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len()
2578        );
2579
2580        assert!(deserialized_batch.column(0).is_valid(0));
2581        assert!(deserialized_batch.column(0).is_null(1));
2582
2583        assert_eq!(record_batch_slice, deserialized_batch);
2584    }
2585
2586    #[test]
2587    fn truncate_ipc_struct_array() {
2588        fn create_batch() -> RecordBatch {
2589            let strings: StringArray = [Some("foo"), None, Some("bar"), Some("baz")]
2590                .into_iter()
2591                .collect();
2592            let ints: Int32Array = [Some(0), Some(2), None, Some(1)].into_iter().collect();
2593
2594            let struct_array = StructArray::from(vec![
2595                (
2596                    Arc::new(Field::new("s", DataType::Utf8, true)),
2597                    Arc::new(strings) as ArrayRef,
2598                ),
2599                (
2600                    Arc::new(Field::new("c", DataType::Int32, true)),
2601                    Arc::new(ints) as ArrayRef,
2602                ),
2603            ]);
2604
2605            let schema = Schema::new(vec![Field::new(
2606                "struct_array",
2607                struct_array.data_type().clone(),
2608                true,
2609            )]);
2610
2611            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)]).unwrap()
2612        }
2613
2614        let record_batch = create_batch();
2615        let record_batch_slice = record_batch.slice(1, 2);
2616        let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice));
2617
2618        assert!(
2619            serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len()
2620        );
2621
2622        let structs = deserialized_batch
2623            .column(0)
2624            .as_any()
2625            .downcast_ref::<StructArray>()
2626            .unwrap();
2627
2628        assert!(structs.column(0).is_null(0));
2629        assert!(structs.column(0).is_valid(1));
2630        assert!(structs.column(1).is_valid(0));
2631        assert!(structs.column(1).is_null(1));
2632        assert_eq!(record_batch_slice, deserialized_batch);
2633    }
2634
2635    #[test]
2636    fn truncate_ipc_string_array_with_all_empty_string() {
2637        fn create_batch() -> RecordBatch {
2638            let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
2639            let a = StringArray::from(vec![Some(""), Some(""), Some(""), Some(""), Some("")]);
2640            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap()
2641        }
2642
2643        let record_batch = create_batch();
2644        let record_batch_slice = record_batch.slice(0, 1);
2645        let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice));
2646
2647        assert!(
2648            serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len()
2649        );
2650        assert_eq!(record_batch_slice, deserialized_batch);
2651    }
2652
2653    #[test]
2654    fn test_stream_writer_writes_array_slice() {
2655        let array = UInt32Array::from(vec![Some(1), Some(2), Some(3)]);
2656        assert_eq!(
2657            vec![Some(1), Some(2), Some(3)],
2658            array.iter().collect::<Vec<_>>()
2659        );
2660
2661        let sliced = array.slice(1, 2);
2662        assert_eq!(vec![Some(2), Some(3)], sliced.iter().collect::<Vec<_>>());
2663
2664        let batch = RecordBatch::try_new(
2665            Arc::new(Schema::new(vec![Field::new("a", DataType::UInt32, true)])),
2666            vec![Arc::new(sliced)],
2667        )
2668        .expect("new batch");
2669
2670        let mut writer = StreamWriter::try_new(vec![], batch.schema_ref()).expect("new writer");
2671        writer.write(&batch).expect("write");
2672        let outbuf = writer.into_inner().expect("inner");
2673
2674        let mut reader = StreamReader::try_new(&outbuf[..], None).expect("new reader");
2675        let read_batch = reader.next().unwrap().expect("read batch");
2676
2677        let read_array: &UInt32Array = read_batch.column(0).as_primitive();
2678        assert_eq!(
2679            vec![Some(2), Some(3)],
2680            read_array.iter().collect::<Vec<_>>()
2681        );
2682    }
2683
2684    #[test]
2685    fn test_large_slice_uint32() {
2686        ensure_roundtrip(Arc::new(UInt32Array::from_iter(
2687            (0..8000).map(|i| if i % 2 == 0 { Some(i) } else { None }),
2688        )));
2689    }
2690
2691    #[test]
2692    fn test_large_slice_string() {
2693        let strings: Vec<_> = (0..8000)
2694            .map(|i| {
2695                if i % 2 == 0 {
2696                    Some(format!("value{i}"))
2697                } else {
2698                    None
2699                }
2700            })
2701            .collect();
2702
2703        ensure_roundtrip(Arc::new(StringArray::from(strings)));
2704    }
2705
2706    #[test]
2707    fn test_large_slice_string_list() {
2708        let mut ls = ListBuilder::new(StringBuilder::new());
2709
2710        let mut s = String::new();
2711        for row_number in 0..8000 {
2712            if row_number % 2 == 0 {
2713                for list_element in 0..1000 {
2714                    s.clear();
2715                    use std::fmt::Write;
2716                    write!(&mut s, "value{row_number}-{list_element}").unwrap();
2717                    ls.values().append_value(&s);
2718                }
2719                ls.append(true)
2720            } else {
2721                ls.append(false); // null
2722            }
2723        }
2724
2725        ensure_roundtrip(Arc::new(ls.finish()));
2726    }
2727
2728    #[test]
2729    fn test_large_slice_string_list_of_lists() {
2730        // The reason for this special test is to verify reencode_offsets, which looks at both
2731        // the starting offset and the data offset. So we need a dataset where the starting_offset
2732        // is zero but the data offset is not.
2733        let mut ls = ListBuilder::new(ListBuilder::new(StringBuilder::new()));
2734
2735        for _ in 0..4000 {
2736            ls.values().append(true);
2737            ls.append(true)
2738        }
2739
2740        let mut s = String::new();
2741        for row_number in 0..4000 {
2742            if row_number % 2 == 0 {
2743                for list_element in 0..1000 {
2744                    s.clear();
2745                    use std::fmt::Write;
2746                    write!(&mut s, "value{row_number}-{list_element}").unwrap();
2747                    ls.values().values().append_value(&s);
2748                }
2749                ls.values().append(true);
2750                ls.append(true)
2751            } else {
2752                ls.append(false); // null
2753            }
2754        }
2755
2756        ensure_roundtrip(Arc::new(ls.finish()));
2757    }
2758
2759    /// Read/write a record batch to a File and Stream and ensure it is the same at the output
2760    fn ensure_roundtrip(array: ArrayRef) {
2761        let num_rows = array.len();
2762        let orig_batch = RecordBatch::try_from_iter(vec![("a", array)]).unwrap();
2763        // take off the first element
2764        let sliced_batch = orig_batch.slice(1, num_rows - 1);
2765
2766        let schema = orig_batch.schema();
2767        let stream_data = {
2768            let mut writer = StreamWriter::try_new(vec![], &schema).unwrap();
2769            writer.write(&sliced_batch).unwrap();
2770            writer.into_inner().unwrap()
2771        };
2772        let read_batch = {
2773            let projection = None;
2774            let mut reader = StreamReader::try_new(Cursor::new(stream_data), projection).unwrap();
2775            reader
2776                .next()
2777                .expect("expect no errors reading batch")
2778                .expect("expect batch")
2779        };
2780        assert_eq!(sliced_batch, read_batch);
2781
2782        let file_data = {
2783            let mut writer = FileWriter::try_new_buffered(vec![], &schema).unwrap();
2784            writer.write(&sliced_batch).unwrap();
2785            writer.into_inner().unwrap().into_inner().unwrap()
2786        };
2787        let read_batch = {
2788            let projection = None;
2789            let mut reader = FileReader::try_new(Cursor::new(file_data), projection).unwrap();
2790            reader
2791                .next()
2792                .expect("expect no errors reading batch")
2793                .expect("expect batch")
2794        };
2795        assert_eq!(sliced_batch, read_batch);
2796
2797        // TODO test file writer/reader
2798    }
2799
2800    #[test]
2801    fn encode_bools_slice() {
2802        // Test case for https://github.com/apache/arrow-rs/issues/3496
2803        assert_bool_roundtrip([true, false], 1, 1);
2804
2805        // slice somewhere in the middle
2806        assert_bool_roundtrip(
2807            [
2808                true, false, true, true, false, false, true, true, true, false, false, false, true,
2809                true, true, true, false, false, false, false, true, true, true, true, true, false,
2810                false, false, false, false,
2811            ],
2812            13,
2813            17,
2814        );
2815
2816        // start at byte boundary, end in the middle
2817        assert_bool_roundtrip(
2818            [
2819                true, false, true, true, false, false, true, true, true, false, false, false,
2820            ],
2821            8,
2822            2,
2823        );
2824
2825        // start and stop at byte boundaries
2826        assert_bool_roundtrip(
2827            [
2828                true, false, true, true, false, false, true, true, true, false, false, false, true,
2829                true, true, true, true, false, false, false, false, false,
2830            ],
2831            8,
2832            8,
2833        );
2834    }
2835
2836    fn assert_bool_roundtrip<const N: usize>(bools: [bool; N], offset: usize, length: usize) {
2837        let val_bool_field = Field::new("val", DataType::Boolean, false);
2838
2839        let schema = Arc::new(Schema::new(vec![val_bool_field]));
2840
2841        let bools = BooleanArray::from(bools.to_vec());
2842
2843        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(bools)]).unwrap();
2844        let batch = batch.slice(offset, length);
2845
2846        let data = serialize_stream(&batch);
2847        let batch2 = deserialize_stream(data);
2848        assert_eq!(batch, batch2);
2849    }
2850
2851    #[test]
2852    fn test_run_array_unslice() {
2853        let total_len = 80;
2854        let vals: Vec<Option<i32>> = vec![Some(1), None, Some(2), Some(3), Some(4), None, Some(5)];
2855        let repeats: Vec<usize> = vec![3, 4, 1, 2];
2856        let mut input_array: Vec<Option<i32>> = Vec::with_capacity(total_len);
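        // 32 iterations cycling through repeats [3, 4, 1, 2] produce 8 * (3 + 4 + 1 + 2) = 80
        // values, matching total_len.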
2857        for ix in 0_usize..32 {
2858            let repeat: usize = repeats[ix % repeats.len()];
2859            let val: Option<i32> = vals[ix % vals.len()];
2860            input_array.resize(input_array.len() + repeat, val);
2861        }
2862
2863        // Encode the input_array to run array
2864        let mut builder =
2865            PrimitiveRunBuilder::<Int16Type, Int32Type>::with_capacity(input_array.len());
2866        builder.extend(input_array.iter().copied());
2867        let run_array = builder.finish();
2868
2869        // test for all slice lengths.
2870        for slice_len in 1..=total_len {
2871            // test for offset = 0, slice length = slice_len
2872            let sliced_run_array: RunArray<Int16Type> =
2873                run_array.slice(0, slice_len).into_data().into();
2874
2875            // Create unsliced run array.
2876            let unsliced_run_array = into_zero_offset_run_array(sliced_run_array).unwrap();
2877            let typed = unsliced_run_array
2878                .downcast::<PrimitiveArray<Int32Type>>()
2879                .unwrap();
2880            let expected: Vec<Option<i32>> = input_array.iter().take(slice_len).copied().collect();
2881            let actual: Vec<Option<i32>> = typed.into_iter().collect();
2882            assert_eq!(expected, actual);
2883
2884            // test for offset = total_len - slice_len, length = slice_len
2885            let sliced_run_array: RunArray<Int16Type> = run_array
2886                .slice(total_len - slice_len, slice_len)
2887                .into_data()
2888                .into();
2889
2890            // Create unsliced run array.
2891            let unsliced_run_array = into_zero_offset_run_array(sliced_run_array).unwrap();
2892            let typed = unsliced_run_array
2893                .downcast::<PrimitiveArray<Int32Type>>()
2894                .unwrap();
2895            let expected: Vec<Option<i32>> = input_array
2896                .iter()
2897                .skip(total_len - slice_len)
2898                .copied()
2899                .collect();
2900            let actual: Vec<Option<i32>> = typed.into_iter().collect();
2901            assert_eq!(expected, actual);
2902        }
2903    }
2904
2905    fn generate_list_data<O: OffsetSizeTrait>() -> GenericListArray<O> {
2906        let mut ls = GenericListBuilder::<O, _>::new(UInt32Builder::new());
2907
2908        for i in 0..100_000 {
2909            for value in [i, i, i] {
2910                ls.values().append_value(value);
2911            }
2912            ls.append(true)
2913        }
2914
2915        ls.finish()
2916    }
2917
2918    fn generate_nested_list_data<O: OffsetSizeTrait>() -> GenericListArray<O> {
2919        let mut ls =
2920            GenericListBuilder::<O, _>::new(GenericListBuilder::<O, _>::new(UInt32Builder::new()));
2921
2922        for _i in 0..10_000 {
2923            for j in 0..10 {
2924                for value in [j, j, j, j] {
2925                    ls.values().values().append_value(value);
2926                }
2927                ls.values().append(true)
2928            }
2929            ls.append(true);
2930        }
2931
2932        ls.finish()
2933    }
2934
2935    fn generate_nested_list_data_starting_at_zero<O: OffsetSizeTrait>() -> GenericListArray<O> {
2936        let mut ls =
2937            GenericListBuilder::<O, _>::new(GenericListBuilder::<O, _>::new(UInt32Builder::new()));
2938
2939        for _i in 0..999 {
2940            ls.values().append(true);
2941            ls.append(true);
2942        }
2943
2944        for j in 0..10 {
2945            for value in [j, j, j, j] {
2946                ls.values().values().append_value(value);
2947            }
2948            ls.values().append(true)
2949        }
2950        ls.append(true);
2951
2952        for i in 0..9_000 {
2953            for j in 0..10 {
2954                for value in [i + j, i + j, i + j, i + j] {
2955                    ls.values().values().append_value(value);
2956                }
2957                ls.values().append(true)
2958            }
2959            ls.append(true);
2960        }
2961
2962        ls.finish()
2963    }
2964
2965    fn generate_map_array_data() -> MapArray {
2966        let keys_builder = UInt32Builder::new();
2967        let values_builder = UInt32Builder::new();
2968
2969        let mut builder = MapBuilder::new(None, keys_builder, values_builder);
2970
2971        for i in 0..100_000 {
2972            for _j in 0..3 {
2973                builder.keys().append_value(i);
2974                builder.values().append_value(i * 2);
2975            }
2976            builder.append(true).unwrap();
2977        }
2978
2979        builder.finish()
2980    }
2981
2982    #[test]
2983    fn reencode_offsets_when_first_offset_is_not_zero() {
2984        let original_list = generate_list_data::<i32>();
2985        let original_data = original_list.into_data();
2986        let slice_data = original_data.slice(75, 7);
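        // generate_list_data appends 3 values per list entry, so slicing lists [75, 82) covers
        // value offsets [225, 246): original_start = 75 * 3 = 225, length = 7 * 3 = 21, and the
        // new offsets are re-based to start at 0.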
2987        let (new_offsets, original_start, length) =
2988            reencode_offsets::<i32>(&slice_data.buffers()[0], &slice_data);
2989        assert_eq!(
2990            vec![0, 3, 6, 9, 12, 15, 18, 21],
2991            new_offsets.typed_data::<i32>()
2992        );
2993        assert_eq!(225, original_start);
2994        assert_eq!(21, length);
2995    }
2996
2997    #[test]
2998    fn reencode_offsets_when_first_offset_is_zero() {
2999        let mut ls = GenericListBuilder::<i32, _>::new(UInt32Builder::new());
3000        // ls = [[], [35, 42]]
3001        ls.append(true);
3002        ls.values().append_value(35);
3003        ls.values().append_value(42);
3004        ls.append(true);
3005        let original_list = ls.finish();
3006        let original_data = original_list.into_data();
3007
3008        let slice_data = original_data.slice(1, 1);
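        // The slice covers only the second list; its offsets window is [0, 2] and already starts
        // at zero, so original_start stays 0 and length is 2.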
3009        let (new_offsets, original_start, length) =
3010            reencode_offsets::<i32>(&slice_data.buffers()[0], &slice_data);
3011        assert_eq!(vec![0, 2], new_offsets.typed_data::<i32>());
3012        assert_eq!(0, original_start);
3013        assert_eq!(2, length);
3014    }
3015
3016    /// Ensure that round-tripping both the full and sliced versions yields batches equal to the original input.
3017    /// Also ensure the serialized sliced version is significantly smaller than the serialized full version.
3018    fn roundtrip_ensure_sliced_smaller(in_batch: RecordBatch, expected_size_factor: usize) {
3019        // test both full and sliced versions
3020        let in_sliced = in_batch.slice(999, 1);
3021
3022        let bytes_batch = serialize_file(&in_batch);
3023        let bytes_sliced = serialize_file(&in_sliced);
3024
3025        // serializing 1 row should be significantly smaller than serializing 100,000
3026        assert!(bytes_sliced.len() < (bytes_batch.len() / expected_size_factor));
3027
3028        // ensure both are still valid and equal to originals
3029        let out_batch = deserialize_file(bytes_batch);
3030        assert_eq!(in_batch, out_batch);
3031
3032        let out_sliced = deserialize_file(bytes_sliced);
3033        assert_eq!(in_sliced, out_sliced);
3034    }
3035
3036    #[test]
3037    fn encode_lists() {
3038        let val_inner = Field::new_list_field(DataType::UInt32, true);
3039        let val_list_field = Field::new("val", DataType::List(Arc::new(val_inner)), false);
3040        let schema = Arc::new(Schema::new(vec![val_list_field]));
3041
3042        let values = Arc::new(generate_list_data::<i32>());
3043
3044        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
3045        roundtrip_ensure_sliced_smaller(in_batch, 1000);
3046    }
3047
3048    #[test]
3049    fn encode_empty_list() {
3050        let val_inner = Field::new_list_field(DataType::UInt32, true);
3051        let val_list_field = Field::new("val", DataType::List(Arc::new(val_inner)), false);
3052        let schema = Arc::new(Schema::new(vec![val_list_field]));
3053
3054        let values = Arc::new(generate_list_data::<i32>());
3055
3056        let in_batch = RecordBatch::try_new(schema, vec![values])
3057            .unwrap()
3058            .slice(999, 0);
3059        let out_batch = deserialize_file(serialize_file(&in_batch));
3060        assert_eq!(in_batch, out_batch);
3061    }
3062
3063    #[test]
3064    fn encode_large_lists() {
3065        let val_inner = Field::new_list_field(DataType::UInt32, true);
3066        let val_list_field = Field::new("val", DataType::LargeList(Arc::new(val_inner)), false);
3067        let schema = Arc::new(Schema::new(vec![val_list_field]));
3068
3069        let values = Arc::new(generate_list_data::<i64>());
3070
3071        // ensure round-tripped full and sliced versions equal the original input, and that the
3072        // serialized sliced version is significantly smaller than the serialized full version
3073        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
3074        roundtrip_ensure_sliced_smaller(in_batch, 1000);
3075    }
3076
3077    #[test]
3078    fn encode_nested_lists() {
3079        let inner_int = Arc::new(Field::new_list_field(DataType::UInt32, true));
3080        let inner_list_field = Arc::new(Field::new_list_field(DataType::List(inner_int), true));
3081        let list_field = Field::new("val", DataType::List(inner_list_field), true);
3082        let schema = Arc::new(Schema::new(vec![list_field]));
3083
3084        let values = Arc::new(generate_nested_list_data::<i32>());
3085
3086        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
3087        roundtrip_ensure_sliced_smaller(in_batch, 1000);
3088    }
3089
3090    #[test]
3091    fn encode_nested_lists_starting_at_zero() {
3092        let inner_int = Arc::new(Field::new("item", DataType::UInt32, true));
3093        let inner_list_field = Arc::new(Field::new("item", DataType::List(inner_int), true));
3094        let list_field = Field::new("val", DataType::List(inner_list_field), true);
3095        let schema = Arc::new(Schema::new(vec![list_field]));
3096
3097        let values = Arc::new(generate_nested_list_data_starting_at_zero::<i32>());
3098
3099        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
3100        roundtrip_ensure_sliced_smaller(in_batch, 1);
3101    }
3102
3103    #[test]
3104    fn encode_map_array() {
3105        let keys = Arc::new(Field::new("keys", DataType::UInt32, false));
3106        let values = Arc::new(Field::new("values", DataType::UInt32, true));
3107        let map_field = Field::new_map("map", "entries", keys, values, false, true);
3108        let schema = Arc::new(Schema::new(vec![map_field]));
3109
3110        let values = Arc::new(generate_map_array_data());
3111
3112        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
3113        roundtrip_ensure_sliced_smaller(in_batch, 1000);
3114    }
3115
3116    #[test]
3117    fn test_decimal128_alignment16_is_sufficient() {
3118        const IPC_ALIGNMENT: usize = 16;
3119
3120        // Test a bunch of different dimensions to ensure alignment is never an issue.
3121        // For example, if we only test `num_cols = 1` then even with alignment 8 this
3122        // test would _happen_ to pass, even though for different dimensions like
3123        // `num_cols = 2` it would fail.
3124        for num_cols in [1, 2, 3, 17, 50, 73, 99] {
3125            let num_rows = (num_cols * 7 + 11) % 100; // Deterministic swizzle
3126
3127            let mut fields = Vec::new();
3128            let mut arrays = Vec::new();
3129            for i in 0..num_cols {
3130                let field = Field::new(format!("col_{i}"), DataType::Decimal128(38, 10), true);
3131                let array = Decimal128Array::from(vec![num_cols as i128; num_rows]);
3132                fields.push(field);
3133                arrays.push(Arc::new(array) as Arc<dyn Array>);
3134            }
3135            let schema = Schema::new(fields);
3136            let batch = RecordBatch::try_new(Arc::new(schema), arrays).unwrap();
3137
3138            let mut writer = FileWriter::try_new_with_options(
3139                Vec::new(),
3140                batch.schema_ref(),
3141                IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(),
3142            )
3143            .unwrap();
3144            writer.write(&batch).unwrap();
3145            writer.finish().unwrap();
3146
3147            let out: Vec<u8> = writer.into_inner().unwrap();
3148
3149            let buffer = Buffer::from_vec(out);
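            // An Arrow IPC file ends with the footer flatbuffer, a 4-byte footer length, and the
            // 6-byte "ARROW1" magic, so the trailer occupies the last 10 bytes.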
3150            let trailer_start = buffer.len() - 10;
3151            let footer_len =
3152                read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap();
3153            let footer =
3154                root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap();
3155
3156            let schema = fb_to_schema(footer.schema().unwrap());
3157
3158            // Importantly we set `require_alignment`, checking that 16-byte alignment is sufficient
3159            // for `read_record_batch` later on to read the data in a zero-copy manner.
3160            let decoder =
3161                FileDecoder::new(Arc::new(schema), footer.version()).with_require_alignment(true);
3162
3163            let batches = footer.recordBatches().unwrap();
3164
3165            let block = batches.get(0);
3166            let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
3167            let data = buffer.slice_with_length(block.offset() as _, block_len);
3168
3169            let batch2 = decoder.read_record_batch(block, &data).unwrap().unwrap();
3170
3171            assert_eq!(batch, batch2);
3172        }
3173    }
3174
3175    #[test]
3176    fn test_decimal128_alignment8_is_unaligned() {
3177        const IPC_ALIGNMENT: usize = 8;
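        // Decimal128 values are 16 bytes wide, so zero-copy reads require their buffers to be
        // 16-byte aligned. With 8-byte IPC padding the second column's data buffer can end up
        // 8 bytes off that alignment, which `read_record_batch` rejects when `require_alignment`
        // is set (see the asserted error below).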
3178
3179        let num_cols = 2;
3180        let num_rows = 1;
3181
3182        let mut fields = Vec::new();
3183        let mut arrays = Vec::new();
3184        for i in 0..num_cols {
3185            let field = Field::new(format!("col_{i}"), DataType::Decimal128(38, 10), true);
3186            let array = Decimal128Array::from(vec![num_cols as i128; num_rows]);
3187            fields.push(field);
3188            arrays.push(Arc::new(array) as Arc<dyn Array>);
3189        }
3190        let schema = Schema::new(fields);
3191        let batch = RecordBatch::try_new(Arc::new(schema), arrays).unwrap();
3192
3193        let mut writer = FileWriter::try_new_with_options(
3194            Vec::new(),
3195            batch.schema_ref(),
3196            IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(),
3197        )
3198        .unwrap();
3199        writer.write(&batch).unwrap();
3200        writer.finish().unwrap();
3201
3202        let out: Vec<u8> = writer.into_inner().unwrap();
3203
3204        let buffer = Buffer::from_vec(out);
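        // Trailer layout: 4-byte footer length followed by the 6-byte "ARROW1" magic (10 bytes total).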
3205        let trailer_start = buffer.len() - 10;
3206        let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap();
3207        let footer = root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap();
3208        let schema = fb_to_schema(footer.schema().unwrap());
3209
3210        // Importantly we set `require_alignment`, otherwise the error later is suppressed due to copying
3211        // to an aligned buffer in `ArrayDataBuilder.build_aligned`.
3212        let decoder =
3213            FileDecoder::new(Arc::new(schema), footer.version()).with_require_alignment(true);
3214
3215        let batches = footer.recordBatches().unwrap();
3216
3217        let block = batches.get(0);
3218        let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
3219        let data = buffer.slice_with_length(block.offset() as _, block_len);
3220
3221        let result = decoder.read_record_batch(block, &data);
3222
3223        let error = result.unwrap_err();
3224        assert_eq!(
3225            error.to_string(),
3226            "Invalid argument error: Misaligned buffers[0] in array of type Decimal128(38, 10), \
3227             offset from expected alignment of 16 by 8"
3228        );
3229    }
3230
3231    #[test]
3232    fn test_flush() {
3233        // We write a schema which is small enough to fit into a buffer and not get flushed,
3234        // and then force the write with .flush().
3235        let num_cols = 2;
3236        let mut fields = Vec::new();
3237        let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap();
3238        for i in 0..num_cols {
3239            let field = Field::new(format!("col_{i}"), DataType::Decimal128(38, 10), true);
3240            fields.push(field);
3241        }
3242        let schema = Schema::new(fields);
3243        let inner_stream_writer = BufWriter::with_capacity(1024, Vec::new());
3244        let inner_file_writer = BufWriter::with_capacity(1024, Vec::new());
3245        let mut stream_writer =
3246            StreamWriter::try_new_with_options(inner_stream_writer, &schema, options.clone())
3247                .unwrap();
3248        let mut file_writer =
3249            FileWriter::try_new_with_options(inner_file_writer, &schema, options).unwrap();
3250
3251        let stream_bytes_written_on_new = stream_writer.get_ref().get_ref().len();
3252        let file_bytes_written_on_new = file_writer.get_ref().get_ref().len();
3253        stream_writer.flush().unwrap();
3254        file_writer.flush().unwrap();
3255        let stream_bytes_written_on_flush = stream_writer.get_ref().get_ref().len();
3256        let file_bytes_written_on_flush = file_writer.get_ref().get_ref().len();
3257        let stream_out = stream_writer.into_inner().unwrap().into_inner().unwrap();
3258        // Finishing a stream writes the continuation bytes in MetadataVersion::V5 (4 bytes)
3259        // and then a length of 0 (4 bytes) for a total of 8 bytes.
3260        // Everything before that should have been flushed in the .flush() call.
3261        let expected_stream_flushed_bytes = stream_out.len() - 8;
3262        // A file write is the same as the stream write except for the leading magic string
3263        // ARROW1 plus padding, which is 8 bytes.
3264        let expected_file_flushed_bytes = expected_stream_flushed_bytes + 8;
3265
3266        assert!(
3267            stream_bytes_written_on_new < stream_bytes_written_on_flush,
3268            "this test makes no sense if flush is not actually required"
3269        );
3270        assert!(
3271            file_bytes_written_on_new < file_bytes_written_on_flush,
3272            "this test makes no sense if flush is not actually required"
3273        );
3274        assert_eq!(stream_bytes_written_on_flush, expected_stream_flushed_bytes);
3275        assert_eq!(file_bytes_written_on_flush, expected_file_flushed_bytes);
3276    }
3277
3278    #[test]
3279    fn test_roundtrip_list_of_fixed_list() -> Result<(), ArrowError> {
3280        let l1_type =
3281            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, false)), 3);
3282        let l2_type = DataType::List(Arc::new(Field::new("item", l1_type.clone(), false)));
3283
3284        let l0_builder = Float32Builder::new();
3285        let l1_builder = FixedSizeListBuilder::new(l0_builder, 3).with_field(Arc::new(Field::new(
3286            "item",
3287            DataType::Float32,
3288            false,
3289        )));
3290        let mut l2_builder =
3291            ListBuilder::new(l1_builder).with_field(Arc::new(Field::new("item", l1_type, false)));
3292
3293        for point in [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] {
3294            l2_builder.values().values().append_value(point[0]);
3295            l2_builder.values().values().append_value(point[1]);
3296            l2_builder.values().values().append_value(point[2]);
3297
3298            l2_builder.values().append(true);
3299        }
3300        l2_builder.append(true);
3301
3302        let point = [10., 11., 12.];
3303        l2_builder.values().values().append_value(point[0]);
3304        l2_builder.values().values().append_value(point[1]);
3305        l2_builder.values().values().append_value(point[2]);
3306
3307        l2_builder.values().append(true);
3308        l2_builder.append(true);
3309
3310        let array = Arc::new(l2_builder.finish()) as ArrayRef;
3311
3312        let schema = Arc::new(Schema::new_with_metadata(
3313            vec![Field::new("points", l2_type, false)],
3314            HashMap::default(),
3315        ));
3316
3317        // Test a variety of combinations that include zero and non-zero offsets
3318        // and slices that cover part of the array or the remainder of it
3319        test_slices(&array, &schema, 0, 1)?;
3320        test_slices(&array, &schema, 0, 2)?;
3321        test_slices(&array, &schema, 1, 1)?;
3322
3323        Ok(())
3324    }
3325
3326    #[test]
3327    fn test_roundtrip_list_of_fixed_list_w_nulls() -> Result<(), ArrowError> {
3328        let l0_builder = Float32Builder::new();
3329        let l1_builder = FixedSizeListBuilder::new(l0_builder, 3);
3330        let mut l2_builder = ListBuilder::new(l1_builder);
3331
3332        for point in [
3333            [Some(1.0), Some(2.0), None],
3334            [Some(4.0), Some(5.0), Some(6.0)],
3335            [None, Some(8.0), Some(9.0)],
3336        ] {
3337            for p in point {
3338                match p {
3339                    Some(p) => l2_builder.values().values().append_value(p),
3340                    None => l2_builder.values().values().append_null(),
3341                }
3342            }
3343
3344            l2_builder.values().append(true);
3345        }
3346        l2_builder.append(true);
3347
3348        let point = [Some(10.), None, None];
3349        for p in point {
3350            match p {
3351                Some(p) => l2_builder.values().values().append_value(p),
3352                None => l2_builder.values().values().append_null(),
3353            }
3354        }
3355
3356        l2_builder.values().append(true);
3357        l2_builder.append(true);
3358
3359        let array = Arc::new(l2_builder.finish()) as ArrayRef;
3360
3361        let schema = Arc::new(Schema::new_with_metadata(
3362            vec![Field::new(
3363                "points",
3364                DataType::List(Arc::new(Field::new(
3365                    "item",
3366                    DataType::FixedSizeList(
3367                        Arc::new(Field::new("item", DataType::Float32, true)),
3368                        3,
3369                    ),
3370                    true,
3371                ))),
3372                true,
3373            )],
3374            HashMap::default(),
3375        ));
3376
3377        // Test a variety of combinations that include zero and non-zero offsets
3378        // and slices that cover part of the array or the remainder of it
3379        test_slices(&array, &schema, 0, 1)?;
3380        test_slices(&array, &schema, 0, 2)?;
3381        test_slices(&array, &schema, 1, 1)?;
3382
3383        Ok(())
3384    }
3385
3386    fn test_slices(
3387        parent_array: &ArrayRef,
3388        schema: &SchemaRef,
3389        offset: usize,
3390        length: usize,
3391    ) -> Result<(), ArrowError> {
3392        let subarray = parent_array.slice(offset, length);
3393        let original_batch = RecordBatch::try_new(schema.clone(), vec![subarray])?;
3394
3395        let mut bytes = Vec::new();
3396        let mut writer = StreamWriter::try_new(&mut bytes, schema)?;
3397        writer.write(&original_batch)?;
3398        writer.finish()?;
3399
3400        let mut cursor = std::io::Cursor::new(bytes);
3401        let mut reader = StreamReader::try_new(&mut cursor, None)?;
3402        let returned_batch = reader.next().unwrap()?;
3403
3404        assert_eq!(original_batch, returned_batch);
3405
3406        Ok(())
3407    }
3408
3409    #[test]
3410    fn test_roundtrip_fixed_list() -> Result<(), ArrowError> {
3411        let int_builder = Int64Builder::new();
3412        let mut fixed_list_builder = FixedSizeListBuilder::new(int_builder, 3)
3413            .with_field(Arc::new(Field::new("item", DataType::Int64, false)));
3414
3415        for point in [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]] {
3416            fixed_list_builder.values().append_value(point[0]);
3417            fixed_list_builder.values().append_value(point[1]);
3418            fixed_list_builder.values().append_value(point[2]);
3419
3420            fixed_list_builder.append(true);
3421        }
3422
3423        let array = Arc::new(fixed_list_builder.finish()) as ArrayRef;
3424
3425        let schema = Arc::new(Schema::new_with_metadata(
3426            vec![Field::new(
3427                "points",
3428                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int64, false)), 3),
3429                false,
3430            )],
3431            HashMap::default(),
3432        ));
3433
3434        // Test a variety of combinations that include zero and non-zero offsets
3435        // and slices that cover part of the array or the remainder of it
3436        test_slices(&array, &schema, 0, 4)?;
3437        test_slices(&array, &schema, 0, 2)?;
3438        test_slices(&array, &schema, 1, 3)?;
3439        test_slices(&array, &schema, 2, 1)?;
3440
3441        Ok(())
3442    }
3443
3444    #[test]
3445    fn test_roundtrip_fixed_list_w_nulls() -> Result<(), ArrowError> {
3446        let int_builder = Int64Builder::new();
3447        let mut fixed_list_builder = FixedSizeListBuilder::new(int_builder, 3);
3448
3449        for point in [
3450            [Some(1), Some(2), None],
3451            [Some(4), Some(5), Some(6)],
3452            [None, Some(8), Some(9)],
3453            [Some(10), None, None],
3454        ] {
3455            for p in point {
3456                match p {
3457                    Some(p) => fixed_list_builder.values().append_value(p),
3458                    None => fixed_list_builder.values().append_null(),
3459                }
3460            }
3461
3462            fixed_list_builder.append(true);
3463        }
3464
3465        let array = Arc::new(fixed_list_builder.finish()) as ArrayRef;
3466
3467        let schema = Arc::new(Schema::new_with_metadata(
3468            vec![Field::new(
3469                "points",
3470                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int64, true)), 3),
3471                true,
3472            )],
3473            HashMap::default(),
3474        ));
3475
3476        // Test a variety of combinations that include zero and non-zero offsets
3477        // and slices that cover part of the array or the remainder of it
3478        test_slices(&array, &schema, 0, 4)?;
3479        test_slices(&array, &schema, 0, 2)?;
3480        test_slices(&array, &schema, 1, 3)?;
3481        test_slices(&array, &schema, 2, 1)?;
3482
3483        Ok(())
3484    }
3485
3486    #[test]
3487    fn test_metadata_encoding_ordering() {
3488        fn create_hash() -> u64 {
3489            let metadata: HashMap<String, String> = [
3490                ("a", "1"), //
3491                ("b", "2"), //
3492                ("c", "3"), //
3493                ("d", "4"), //
3494                ("e", "5"), //
3495            ]
3496            .into_iter()
3497            .map(|(k, v)| (k.to_owned(), v.to_owned()))
3498            .collect();
3499
3500            // Set metadata on both the schema and a field within it.
3501            let schema = Arc::new(
3502                Schema::new(vec![
3503                    Field::new("a", DataType::Int64, true).with_metadata(metadata.clone()),
3504                ])
3505                .with_metadata(metadata)
3506                .clone(),
3507            );
3508            let batch = RecordBatch::new_empty(schema.clone());
3509
3510            let mut bytes = Vec::new();
3511            let mut w = StreamWriter::try_new(&mut bytes, batch.schema_ref()).unwrap();
3512            w.write(&batch).unwrap();
3513            w.finish().unwrap();
3514
3515            let mut h = std::hash::DefaultHasher::new();
3516            h.write(&bytes);
3517            h.finish()
3518        }
3519
3520        let expected = create_hash();
3521
3522        // Since there is randomness in the HashMap and we cannot specify our
3523        // own Hasher for the implementation used for metadata, run the above
3524        // code 20x and verify it does not change. This is not perfect but it
3525        // should be good enough.
3526        let all_passed = (0..20).all(|_| create_hash() == expected);
3527        assert!(all_passed);
3528    }
3529}