arrow_ipc/
reader.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Arrow IPC File and Stream Readers
19//!
20//! # Notes
21//!
22//! The [`FileReader`] and [`StreamReader`] have similar interfaces,
23//! however the [`FileReader`] expects a reader that supports [`Seek`]ing
24//!
25//! [`Seek`]: std::io::Seek
26
27mod stream;
28pub use stream::*;
29
30use arrow_select::concat;
31
32use flatbuffers::{VectorIter, VerifierOptions};
33use std::collections::{HashMap, VecDeque};
34use std::fmt;
35use std::io::{BufReader, Read, Seek, SeekFrom};
36use std::sync::Arc;
37
38use arrow_array::*;
39use arrow_buffer::{
40    ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, ScalarBuffer,
41};
42use arrow_data::{ArrayData, ArrayDataBuilder, UnsafeFlag};
43use arrow_schema::*;
44
45use crate::compression::CompressionCodec;
46use crate::gen::Message::{self};
47use crate::{Block, FieldNode, MetadataVersion, CONTINUATION_MARKER};
48use DataType::*;
49
50/// Read a buffer based on offset and length
51/// From <https://github.com/apache/arrow/blob/6a936c4ff5007045e86f65f1a6b6c3c955ad5103/format/Message.fbs#L58>
52/// Each constituent buffer is first compressed with the indicated
53/// compressor, and then written with the uncompressed length in the first 8
54/// bytes as a 64-bit little-endian signed integer followed by the compressed
55/// buffer bytes (and then padding as required by the protocol). The
56/// uncompressed length may be set to -1 to indicate that the data that
57/// follows is not compressed, which can be useful for cases where
58/// compression does not yield appreciable savings.
59fn read_buffer(
60    buf: &crate::Buffer,
61    a_data: &Buffer,
62    compression_codec: Option<CompressionCodec>,
63) -> Result<Buffer, ArrowError> {
64    let start_offset = buf.offset() as usize;
65    let buf_data = a_data.slice_with_length(start_offset, buf.length() as usize);
66    // corner case: empty buffer
67    match (buf_data.is_empty(), compression_codec) {
68        (true, _) | (_, None) => Ok(buf_data),
69        (false, Some(decompressor)) => decompressor.decompress_to_buffer(&buf_data),
70    }
71}
72impl RecordBatchDecoder<'_> {
73    /// Coordinates reading arrays based on data types.
74    ///
75    /// `variadic_counts` encodes the number of buffers to read for variadic types (e.g., Utf8View, BinaryView)
76    /// When encounter such types, we pop from the front of the queue to get the number of buffers to read.
77    ///
78    /// Notes:
79    /// * In the IPC format, null buffers are always set, but may be empty. We discard them if an array has 0 nulls
80    /// * Numeric values inside list arrays are often stored as 64-bit values regardless of their data type size.
81    ///   We thus:
82    ///     - check if the bit width of non-64-bit numbers is 64, and
83    ///     - read the buffer as 64-bit (signed integer or float), and
84    ///     - cast the 64-bit array to the appropriate data type
85    fn create_array(
86        &mut self,
87        field: &Field,
88        variadic_counts: &mut VecDeque<i64>,
89    ) -> Result<ArrayRef, ArrowError> {
90        let data_type = field.data_type();
91        match data_type {
92            Utf8 | Binary | LargeBinary | LargeUtf8 => {
93                let field_node = self.next_node(field)?;
94                let buffers = [
95                    self.next_buffer()?,
96                    self.next_buffer()?,
97                    self.next_buffer()?,
98                ];
99                self.create_primitive_array(field_node, data_type, &buffers)
100            }
101            BinaryView | Utf8View => {
102                let count = variadic_counts
103                    .pop_front()
104                    .ok_or(ArrowError::IpcError(format!(
105                        "Missing variadic count for {data_type} column"
106                    )))?;
107                let count = count + 2; // view and null buffer.
108                let buffers = (0..count)
109                    .map(|_| self.next_buffer())
110                    .collect::<Result<Vec<_>, _>>()?;
111                let field_node = self.next_node(field)?;
112                self.create_primitive_array(field_node, data_type, &buffers)
113            }
114            FixedSizeBinary(_) => {
115                let field_node = self.next_node(field)?;
116                let buffers = [self.next_buffer()?, self.next_buffer()?];
117                self.create_primitive_array(field_node, data_type, &buffers)
118            }
119            List(ref list_field) | LargeList(ref list_field) | Map(ref list_field, _) => {
120                let list_node = self.next_node(field)?;
121                let list_buffers = [self.next_buffer()?, self.next_buffer()?];
122                let values = self.create_array(list_field, variadic_counts)?;
123                self.create_list_array(list_node, data_type, &list_buffers, values)
124            }
125            FixedSizeList(ref list_field, _) => {
126                let list_node = self.next_node(field)?;
127                let list_buffers = [self.next_buffer()?];
128                let values = self.create_array(list_field, variadic_counts)?;
129                self.create_list_array(list_node, data_type, &list_buffers, values)
130            }
131            Struct(struct_fields) => {
132                let struct_node = self.next_node(field)?;
133                let null_buffer = self.next_buffer()?;
134
135                // read the arrays for each field
136                let mut struct_arrays = vec![];
137                // TODO investigate whether just knowing the number of buffers could
138                // still work
139                for struct_field in struct_fields {
140                    let child = self.create_array(struct_field, variadic_counts)?;
141                    struct_arrays.push(child);
142                }
143                self.create_struct_array(struct_node, null_buffer, struct_fields, struct_arrays)
144            }
145            RunEndEncoded(run_ends_field, values_field) => {
146                let run_node = self.next_node(field)?;
147                let run_ends = self.create_array(run_ends_field, variadic_counts)?;
148                let values = self.create_array(values_field, variadic_counts)?;
149
150                let run_array_length = run_node.length() as usize;
151                let builder = ArrayData::builder(data_type.clone())
152                    .len(run_array_length)
153                    .offset(0)
154                    .add_child_data(run_ends.into_data())
155                    .add_child_data(values.into_data())
156                    .null_count(run_node.null_count() as usize);
157
158                self.create_array_from_builder(builder)
159            }
160            // Create dictionary array from RecordBatch
161            Dictionary(_, _) => {
162                let index_node = self.next_node(field)?;
163                let index_buffers = [self.next_buffer()?, self.next_buffer()?];
164
165                #[allow(deprecated)]
166                let dict_id = field.dict_id().ok_or_else(|| {
167                    ArrowError::ParseError(format!("Field {field} does not have dict id"))
168                })?;
169
170                let value_array = self.dictionaries_by_id.get(&dict_id).ok_or_else(|| {
171                    ArrowError::ParseError(format!(
172                        "Cannot find a dictionary batch with dict id: {dict_id}"
173                    ))
174                })?;
175
176                self.create_dictionary_array(
177                    index_node,
178                    data_type,
179                    &index_buffers,
180                    value_array.clone(),
181                )
182            }
183            Union(fields, mode) => {
184                let union_node = self.next_node(field)?;
185                let len = union_node.length() as usize;
186
187                // In V4, union types has validity bitmap
188                // In V5 and later, union types have no validity bitmap
189                if self.version < MetadataVersion::V5 {
190                    self.next_buffer()?;
191                }
192
193                let type_ids: ScalarBuffer<i8> =
194                    self.next_buffer()?.slice_with_length(0, len).into();
195
196                let value_offsets = match mode {
197                    UnionMode::Dense => {
198                        let offsets: ScalarBuffer<i32> =
199                            self.next_buffer()?.slice_with_length(0, len * 4).into();
200                        Some(offsets)
201                    }
202                    UnionMode::Sparse => None,
203                };
204
205                let mut children = Vec::with_capacity(fields.len());
206
207                for (_id, field) in fields.iter() {
208                    let child = self.create_array(field, variadic_counts)?;
209                    children.push(child);
210                }
211
212                let array = if self.skip_validation.get() {
213                    // safety: flag can only be set via unsafe code
214                    unsafe {
215                        UnionArray::new_unchecked(fields.clone(), type_ids, value_offsets, children)
216                    }
217                } else {
218                    UnionArray::try_new(fields.clone(), type_ids, value_offsets, children)?
219                };
220                Ok(Arc::new(array))
221            }
222            Null => {
223                let node = self.next_node(field)?;
224                let length = node.length();
225                let null_count = node.null_count();
226
227                if length != null_count {
228                    return Err(ArrowError::SchemaError(format!(
229                        "Field {field} of NullArray has unequal null_count {null_count} and len {length}"
230                    )));
231                }
232
233                let builder = ArrayData::builder(data_type.clone())
234                    .len(length as usize)
235                    .offset(0);
236                self.create_array_from_builder(builder)
237            }
238            _ => {
239                let field_node = self.next_node(field)?;
240                let buffers = [self.next_buffer()?, self.next_buffer()?];
241                self.create_primitive_array(field_node, data_type, &buffers)
242            }
243        }
244    }
245
246    /// Reads the correct number of buffers based on data type and null_count, and creates a
247    /// primitive array ref
248    fn create_primitive_array(
249        &self,
250        field_node: &FieldNode,
251        data_type: &DataType,
252        buffers: &[Buffer],
253    ) -> Result<ArrayRef, ArrowError> {
254        let length = field_node.length() as usize;
255        let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
256        let mut builder = match data_type {
257            Utf8 | Binary | LargeBinary | LargeUtf8 => {
258                // read 3 buffers: null buffer (optional), offsets buffer and data buffer
259                ArrayData::builder(data_type.clone())
260                    .len(length)
261                    .buffers(buffers[1..3].to_vec())
262                    .null_bit_buffer(null_buffer)
263            }
264            BinaryView | Utf8View => ArrayData::builder(data_type.clone())
265                .len(length)
266                .buffers(buffers[1..].to_vec())
267                .null_bit_buffer(null_buffer),
268            _ if data_type.is_primitive() || matches!(data_type, Boolean | FixedSizeBinary(_)) => {
269                // read 2 buffers: null buffer (optional) and data buffer
270                ArrayData::builder(data_type.clone())
271                    .len(length)
272                    .add_buffer(buffers[1].clone())
273                    .null_bit_buffer(null_buffer)
274            }
275            t => unreachable!("Data type {:?} either unsupported or not primitive", t),
276        };
277
278        builder = builder.null_count(field_node.null_count() as usize);
279
280        self.create_array_from_builder(builder)
281    }
282
283    /// Update the ArrayDataBuilder based on settings in this decoder
284    fn create_array_from_builder(&self, builder: ArrayDataBuilder) -> Result<ArrayRef, ArrowError> {
285        let mut builder = builder.align_buffers(!self.require_alignment);
286        if self.skip_validation.get() {
287            // SAFETY: flag can only be set via unsafe code
288            unsafe { builder = builder.skip_validation(true) }
289        };
290        Ok(make_array(builder.build()?))
291    }
292
293    /// Reads the correct number of buffers based on list type and null_count, and creates a
294    /// list array ref
295    fn create_list_array(
296        &self,
297        field_node: &FieldNode,
298        data_type: &DataType,
299        buffers: &[Buffer],
300        child_array: ArrayRef,
301    ) -> Result<ArrayRef, ArrowError> {
302        let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
303        let length = field_node.length() as usize;
304        let child_data = child_array.into_data();
305        let mut builder = match data_type {
306            List(_) | LargeList(_) | Map(_, _) => ArrayData::builder(data_type.clone())
307                .len(length)
308                .add_buffer(buffers[1].clone())
309                .add_child_data(child_data)
310                .null_bit_buffer(null_buffer),
311
312            FixedSizeList(_, _) => ArrayData::builder(data_type.clone())
313                .len(length)
314                .add_child_data(child_data)
315                .null_bit_buffer(null_buffer),
316
317            _ => unreachable!("Cannot create list or map array from {:?}", data_type),
318        };
319
320        builder = builder.null_count(field_node.null_count() as usize);
321
322        self.create_array_from_builder(builder)
323    }
324
325    fn create_struct_array(
326        &self,
327        struct_node: &FieldNode,
328        null_buffer: Buffer,
329        struct_fields: &Fields,
330        struct_arrays: Vec<ArrayRef>,
331    ) -> Result<ArrayRef, ArrowError> {
332        let null_count = struct_node.null_count() as usize;
333        let len = struct_node.length() as usize;
334        let skip_validation = self.skip_validation.get();
335
336        let nulls = if null_count > 0 {
337            let validity_buffer = BooleanBuffer::new(null_buffer, 0, len);
338            let null_buffer = if skip_validation {
339                // safety: flag can only be set via unsafe code
340                unsafe { NullBuffer::new_unchecked(validity_buffer, null_count) }
341            } else {
342                let null_buffer = NullBuffer::new(validity_buffer);
343
344                if null_buffer.null_count() != null_count {
345                    return Err(ArrowError::InvalidArgumentError(format!(
346                        "null_count value ({}) doesn't match actual number of nulls in array ({})",
347                        null_count,
348                        null_buffer.null_count()
349                    )));
350                }
351
352                null_buffer
353            };
354
355            Some(null_buffer)
356        } else {
357            None
358        };
359        if struct_arrays.is_empty() {
360            // `StructArray::from` can't infer the correct row count
361            // if we have zero fields
362            return Ok(Arc::new(StructArray::new_empty_fields(len, nulls)));
363        }
364
365        let struct_array = if skip_validation {
366            // safety: flag can only be set via unsafe code
367            unsafe { StructArray::new_unchecked(struct_fields.clone(), struct_arrays, nulls) }
368        } else {
369            StructArray::try_new(struct_fields.clone(), struct_arrays, nulls)?
370        };
371
372        Ok(Arc::new(struct_array))
373    }
374
375    /// Reads the correct number of buffers based on list type and null_count, and creates a
376    /// list array ref
377    fn create_dictionary_array(
378        &self,
379        field_node: &FieldNode,
380        data_type: &DataType,
381        buffers: &[Buffer],
382        value_array: ArrayRef,
383    ) -> Result<ArrayRef, ArrowError> {
384        if let Dictionary(_, _) = *data_type {
385            let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
386            let builder = ArrayData::builder(data_type.clone())
387                .len(field_node.length() as usize)
388                .add_buffer(buffers[1].clone())
389                .add_child_data(value_array.into_data())
390                .null_bit_buffer(null_buffer)
391                .null_count(field_node.null_count() as usize);
392            self.create_array_from_builder(builder)
393        } else {
394            unreachable!("Cannot create dictionary array from {:?}", data_type)
395        }
396    }
397}
398
399/// State for decoding Arrow arrays from an [IPC RecordBatch] structure to
400/// [`RecordBatch`]
401///
402/// [IPC RecordBatch]: crate::RecordBatch
403///
404pub struct RecordBatchDecoder<'a> {
405    /// The flatbuffers encoded record batch
406    batch: crate::RecordBatch<'a>,
407    /// The output schema
408    schema: SchemaRef,
409    /// Decoded dictionaries indexed by dictionary id
410    dictionaries_by_id: &'a HashMap<i64, ArrayRef>,
411    /// Optional compression codec
412    compression: Option<CompressionCodec>,
413    /// The format version
414    version: MetadataVersion,
415    /// The raw data buffer
416    data: &'a Buffer,
417    /// The fields comprising this array
418    nodes: VectorIter<'a, FieldNode>,
419    /// The buffers comprising this array
420    buffers: VectorIter<'a, crate::Buffer>,
421    /// Projection (subset of columns) to read, if any
422    /// See [`RecordBatchDecoder::with_projection`] for details
423    projection: Option<&'a [usize]>,
424    /// Are buffers required to already be aligned? See
425    /// [`RecordBatchDecoder::with_require_alignment`] for details
426    require_alignment: bool,
427    /// Should validation be skipped when reading data? Defaults to false.
428    ///
429    /// See [`FileDecoder::with_skip_validation`] for details.
430    skip_validation: UnsafeFlag,
431}
432
433impl<'a> RecordBatchDecoder<'a> {
434    /// Create a reader for decoding arrays from an encoded [`RecordBatch`]
435    fn try_new(
436        buf: &'a Buffer,
437        batch: crate::RecordBatch<'a>,
438        schema: SchemaRef,
439        dictionaries_by_id: &'a HashMap<i64, ArrayRef>,
440        metadata: &'a MetadataVersion,
441    ) -> Result<Self, ArrowError> {
442        let buffers = batch.buffers().ok_or_else(|| {
443            ArrowError::IpcError("Unable to get buffers from IPC RecordBatch".to_string())
444        })?;
445        let field_nodes = batch.nodes().ok_or_else(|| {
446            ArrowError::IpcError("Unable to get field nodes from IPC RecordBatch".to_string())
447        })?;
448
449        let batch_compression = batch.compression();
450        let compression = batch_compression
451            .map(|batch_compression| batch_compression.codec().try_into())
452            .transpose()?;
453
454        Ok(Self {
455            batch,
456            schema,
457            dictionaries_by_id,
458            compression,
459            version: *metadata,
460            data: buf,
461            nodes: field_nodes.iter(),
462            buffers: buffers.iter(),
463            projection: None,
464            require_alignment: false,
465            skip_validation: UnsafeFlag::new(),
466        })
467    }
468
469    /// Set the projection (default: None)
470    ///
471    /// If set, the projection is the list  of column indices
472    /// that will be read
473    pub fn with_projection(mut self, projection: Option<&'a [usize]>) -> Self {
474        self.projection = projection;
475        self
476    }
477
478    /// Set require_alignment (default: false)
479    ///
480    /// If true, buffers must be aligned appropriately or error will
481    /// result. If false, buffers will be copied to aligned buffers
482    /// if necessary.
483    pub fn with_require_alignment(mut self, require_alignment: bool) -> Self {
484        self.require_alignment = require_alignment;
485        self
486    }
487
488    /// Specifies if validation should be skipped when reading data (defaults to `false`)
489    ///
490    /// Note this API is somewhat "funky" as it allows the caller to skip validation
491    /// without having to use `unsafe` code. If this is ever made public
492    /// it should be made clearer that this is a potentially unsafe by
493    /// using an `unsafe` function that takes a boolean flag.
494    ///
495    /// # Safety
496    ///
497    /// Relies on the caller only passing a flag with `true` value if they are
498    /// certain that the data is valid
499    pub(crate) fn with_skip_validation(mut self, skip_validation: UnsafeFlag) -> Self {
500        self.skip_validation = skip_validation;
501        self
502    }
503
504    /// Read the record batch, consuming the reader
505    fn read_record_batch(mut self) -> Result<RecordBatch, ArrowError> {
506        let mut variadic_counts: VecDeque<i64> = self
507            .batch
508            .variadicBufferCounts()
509            .into_iter()
510            .flatten()
511            .collect();
512
513        let options = RecordBatchOptions::new().with_row_count(Some(self.batch.length() as usize));
514
515        let schema = Arc::clone(&self.schema);
516        if let Some(projection) = self.projection {
517            let mut arrays = vec![];
518            // project fields
519            for (idx, field) in schema.fields().iter().enumerate() {
520                // Create array for projected field
521                if let Some(proj_idx) = projection.iter().position(|p| p == &idx) {
522                    let child = self.create_array(field, &mut variadic_counts)?;
523                    arrays.push((proj_idx, child));
524                } else {
525                    self.skip_field(field, &mut variadic_counts)?;
526                }
527            }
528
529            arrays.sort_by_key(|t| t.0);
530
531            let schema = Arc::new(schema.project(projection)?);
532            let columns = arrays.into_iter().map(|t| t.1).collect::<Vec<_>>();
533
534            if self.skip_validation.get() {
535                // Safety: setting `skip_validation` requires `unsafe`, user assures data is valid
536                unsafe {
537                    Ok(RecordBatch::new_unchecked(
538                        schema,
539                        columns,
540                        self.batch.length() as usize,
541                    ))
542                }
543            } else {
544                assert!(variadic_counts.is_empty());
545                RecordBatch::try_new_with_options(schema, columns, &options)
546            }
547        } else {
548            let mut children = vec![];
549            // keep track of index as lists require more than one node
550            for field in schema.fields() {
551                let child = self.create_array(field, &mut variadic_counts)?;
552                children.push(child);
553            }
554
555            if self.skip_validation.get() {
556                // Safety: setting `skip_validation` requires `unsafe`, user assures data is valid
557                unsafe {
558                    Ok(RecordBatch::new_unchecked(
559                        schema,
560                        children,
561                        self.batch.length() as usize,
562                    ))
563                }
564            } else {
565                assert!(variadic_counts.is_empty());
566                RecordBatch::try_new_with_options(schema, children, &options)
567            }
568        }
569    }
570
571    fn next_buffer(&mut self) -> Result<Buffer, ArrowError> {
572        read_buffer(self.buffers.next().unwrap(), self.data, self.compression)
573    }
574
575    fn skip_buffer(&mut self) {
576        self.buffers.next().unwrap();
577    }
578
579    fn next_node(&mut self, field: &Field) -> Result<&'a FieldNode, ArrowError> {
580        self.nodes.next().ok_or_else(|| {
581            ArrowError::SchemaError(format!(
582                "Invalid data for schema. {field} refers to node not found in schema",
583            ))
584        })
585    }
586
587    fn skip_field(
588        &mut self,
589        field: &Field,
590        variadic_count: &mut VecDeque<i64>,
591    ) -> Result<(), ArrowError> {
592        self.next_node(field)?;
593
594        match field.data_type() {
595            Utf8 | Binary | LargeBinary | LargeUtf8 => {
596                for _ in 0..3 {
597                    self.skip_buffer()
598                }
599            }
600            Utf8View | BinaryView => {
601                let count = variadic_count
602                    .pop_front()
603                    .ok_or(ArrowError::IpcError(format!(
604                        "Missing variadic count for {} column",
605                        field.data_type()
606                    )))?;
607                let count = count + 2; // view and null buffer.
608                for _i in 0..count {
609                    self.skip_buffer()
610                }
611            }
612            FixedSizeBinary(_) => {
613                self.skip_buffer();
614                self.skip_buffer();
615            }
616            List(list_field) | LargeList(list_field) | Map(list_field, _) => {
617                self.skip_buffer();
618                self.skip_buffer();
619                self.skip_field(list_field, variadic_count)?;
620            }
621            FixedSizeList(list_field, _) => {
622                self.skip_buffer();
623                self.skip_field(list_field, variadic_count)?;
624            }
625            Struct(struct_fields) => {
626                self.skip_buffer();
627
628                // skip for each field
629                for struct_field in struct_fields {
630                    self.skip_field(struct_field, variadic_count)?
631                }
632            }
633            RunEndEncoded(run_ends_field, values_field) => {
634                self.skip_field(run_ends_field, variadic_count)?;
635                self.skip_field(values_field, variadic_count)?;
636            }
637            Dictionary(_, _) => {
638                self.skip_buffer(); // Nulls
639                self.skip_buffer(); // Indices
640            }
641            Union(fields, mode) => {
642                self.skip_buffer(); // Nulls
643
644                match mode {
645                    UnionMode::Dense => self.skip_buffer(),
646                    UnionMode::Sparse => {}
647                };
648
649                for (_, field) in fields.iter() {
650                    self.skip_field(field, variadic_count)?
651                }
652            }
653            Null => {} // No buffer increases
654            _ => {
655                self.skip_buffer();
656                self.skip_buffer();
657            }
658        };
659        Ok(())
660    }
661}
662
663/// Creates a record batch from binary data using the `crate::RecordBatch` indexes and the `Schema`.
664///
665/// If `require_alignment` is true, this function will return an error if any array data in the
666/// input `buf` is not properly aligned.
667/// Under the hood it will use [`arrow_data::ArrayDataBuilder::build`] to construct [`arrow_data::ArrayData`].
668///
669/// If `require_alignment` is false, this function will automatically allocate a new aligned buffer
670/// and copy over the data if any array data in the input `buf` is not properly aligned.
671/// (Properly aligned array data will remain zero-copy.)
672/// Under the hood it will use [`arrow_data::ArrayDataBuilder::build_aligned`] to construct [`arrow_data::ArrayData`].
673pub fn read_record_batch(
674    buf: &Buffer,
675    batch: crate::RecordBatch,
676    schema: SchemaRef,
677    dictionaries_by_id: &HashMap<i64, ArrayRef>,
678    projection: Option<&[usize]>,
679    metadata: &MetadataVersion,
680) -> Result<RecordBatch, ArrowError> {
681    RecordBatchDecoder::try_new(buf, batch, schema, dictionaries_by_id, metadata)?
682        .with_projection(projection)
683        .with_require_alignment(false)
684        .read_record_batch()
685}
686
687/// Read the dictionary from the buffer and provided metadata,
688/// updating the `dictionaries_by_id` with the resulting dictionary
689pub fn read_dictionary(
690    buf: &Buffer,
691    batch: crate::DictionaryBatch,
692    schema: &Schema,
693    dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
694    metadata: &MetadataVersion,
695) -> Result<(), ArrowError> {
696    read_dictionary_impl(
697        buf,
698        batch,
699        schema,
700        dictionaries_by_id,
701        metadata,
702        false,
703        UnsafeFlag::new(),
704    )
705}
706
707fn read_dictionary_impl(
708    buf: &Buffer,
709    batch: crate::DictionaryBatch,
710    schema: &Schema,
711    dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
712    metadata: &MetadataVersion,
713    require_alignment: bool,
714    skip_validation: UnsafeFlag,
715) -> Result<(), ArrowError> {
716    let id = batch.id();
717
718    let dictionary_values = get_dictionary_values(
719        buf,
720        batch,
721        schema,
722        dictionaries_by_id,
723        metadata,
724        require_alignment,
725        skip_validation,
726    )?;
727
728    update_dictionaries(dictionaries_by_id, batch.isDelta(), id, dictionary_values)?;
729
730    Ok(())
731}
732
733/// Updates the `dictionaries_by_id` with the provided dictionary values and id.
734///
735/// # Errors
736/// - If `is_delta` is true and there is no existing dictionary for the given
737///   `dict_id`
738/// - If `is_delta` is true and the concatenation of the existing and new
739///   dictionary fails. This usually signals a type mismatch between the old and
740///   new values.
741fn update_dictionaries(
742    dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
743    is_delta: bool,
744    dict_id: i64,
745    dict_values: ArrayRef,
746) -> Result<(), ArrowError> {
747    if !is_delta {
748        // We don't currently record the isOrdered field. This could be general
749        // attributes of arrays.
750        // Add (possibly multiple) array refs to the dictionaries array.
751        dictionaries_by_id.insert(dict_id, dict_values.clone());
752        return Ok(());
753    }
754
755    let existing = dictionaries_by_id.get(&dict_id).ok_or_else(|| {
756        ArrowError::InvalidArgumentError(format!(
757            "No existing dictionary for delta dictionary with id '{dict_id}'"
758        ))
759    })?;
760
761    let combined = concat::concat(&[existing, &dict_values]).map_err(|e| {
762        ArrowError::InvalidArgumentError(format!("Failed to concat delta dictionary: {e}"))
763    })?;
764
765    dictionaries_by_id.insert(dict_id, combined);
766
767    Ok(())
768}
769
770/// Given a dictionary batch IPC message/body along with the full state of a
771/// stream including schema, dictionary cache, metadata, and other flags, this
772/// function will parse the buffer into an array of dictionary values.
773fn get_dictionary_values(
774    buf: &Buffer,
775    batch: crate::DictionaryBatch,
776    schema: &Schema,
777    dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
778    metadata: &MetadataVersion,
779    require_alignment: bool,
780    skip_validation: UnsafeFlag,
781) -> Result<ArrayRef, ArrowError> {
782    let id = batch.id();
783    #[allow(deprecated)]
784    let fields_using_this_dictionary = schema.fields_with_dict_id(id);
785    let first_field = fields_using_this_dictionary.first().ok_or_else(|| {
786        ArrowError::InvalidArgumentError(format!("dictionary id {id} not found in schema"))
787    })?;
788
789    // As the dictionary batch does not contain the type of the
790    // values array, we need to retrieve this from the schema.
791    // Get an array representing this dictionary's values.
792    let dictionary_values: ArrayRef = match first_field.data_type() {
793        DataType::Dictionary(_, ref value_type) => {
794            // Make a fake schema for the dictionary batch.
795            let value = value_type.as_ref().clone();
796            let schema = Schema::new(vec![Field::new("", value, true)]);
797            // Read a single column
798            let record_batch = RecordBatchDecoder::try_new(
799                buf,
800                batch.data().unwrap(),
801                Arc::new(schema),
802                dictionaries_by_id,
803                metadata,
804            )?
805            .with_require_alignment(require_alignment)
806            .with_skip_validation(skip_validation)
807            .read_record_batch()?;
808
809            Some(record_batch.column(0).clone())
810        }
811        _ => None,
812    }
813    .ok_or_else(|| {
814        ArrowError::InvalidArgumentError(format!("dictionary id {id} not found in schema"))
815    })?;
816
817    Ok(dictionary_values)
818}
819
820/// Read the data for a given block
821fn read_block<R: Read + Seek>(mut reader: R, block: &Block) -> Result<Buffer, ArrowError> {
822    reader.seek(SeekFrom::Start(block.offset() as u64))?;
823    let body_len = block.bodyLength().to_usize().unwrap();
824    let metadata_len = block.metaDataLength().to_usize().unwrap();
825    let total_len = body_len.checked_add(metadata_len).unwrap();
826
827    let mut buf = MutableBuffer::from_len_zeroed(total_len);
828    reader.read_exact(&mut buf)?;
829    Ok(buf.into())
830}
831
832/// Parse an encapsulated message
833///
834/// <https://arrow.apache.org/docs/format/Columnar.html#encapsulated-message-format>
835fn parse_message(buf: &[u8]) -> Result<Message::Message<'_>, ArrowError> {
836    let buf = match buf[..4] == CONTINUATION_MARKER {
837        true => &buf[8..],
838        false => &buf[4..],
839    };
840    crate::root_as_message(buf)
841        .map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}")))
842}
843
844/// Read the footer length from the last 10 bytes of an Arrow IPC file
845///
846/// Expects a 4 byte footer length followed by `b"ARROW1"`
847pub fn read_footer_length(buf: [u8; 10]) -> Result<usize, ArrowError> {
848    if buf[4..] != super::ARROW_MAGIC {
849        return Err(ArrowError::ParseError(
850            "Arrow file does not contain correct footer".to_string(),
851        ));
852    }
853
854    // read footer length
855    let footer_len = i32::from_le_bytes(buf[..4].try_into().unwrap());
856    footer_len
857        .try_into()
858        .map_err(|_| ArrowError::ParseError(format!("Invalid footer length: {footer_len}")))
859}
860
861/// A low-level, push-based interface for reading an IPC file
862///
863/// For a higher-level interface see [`FileReader`]
864///
865/// For an example of using this API with `mmap` see the [`zero_copy_ipc`] example.
866///
867/// [`zero_copy_ipc`]: https://github.com/apache/arrow-rs/blob/main/arrow/examples/zero_copy_ipc.rs
868///
869/// ```
870/// # use std::sync::Arc;
871/// # use arrow_array::*;
872/// # use arrow_array::types::Int32Type;
873/// # use arrow_buffer::Buffer;
874/// # use arrow_ipc::convert::fb_to_schema;
875/// # use arrow_ipc::reader::{FileDecoder, read_footer_length};
876/// # use arrow_ipc::root_as_footer;
877/// # use arrow_ipc::writer::FileWriter;
878/// // Write an IPC file
879///
880/// let batch = RecordBatch::try_from_iter([
881///     ("a", Arc::new(Int32Array::from(vec![1, 2, 3])) as _),
882///     ("b", Arc::new(Int32Array::from(vec![1, 2, 3])) as _),
883///     ("c", Arc::new(DictionaryArray::<Int32Type>::from_iter(["hello", "hello", "world"])) as _),
884/// ]).unwrap();
885///
886/// let schema = batch.schema();
887///
888/// let mut out = Vec::with_capacity(1024);
889/// let mut writer = FileWriter::try_new(&mut out, schema.as_ref()).unwrap();
890/// writer.write(&batch).unwrap();
891/// writer.finish().unwrap();
892///
893/// drop(writer);
894///
895/// // Read IPC file
896///
897/// let buffer = Buffer::from_vec(out);
898/// let trailer_start = buffer.len() - 10;
899/// let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap();
900/// let footer = root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap();
901///
902/// let back = fb_to_schema(footer.schema().unwrap());
903/// assert_eq!(&back, schema.as_ref());
904///
905/// let mut decoder = FileDecoder::new(schema, footer.version());
906///
907/// // Read dictionaries
908/// for block in footer.dictionaries().iter().flatten() {
909///     let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
910///     let data = buffer.slice_with_length(block.offset() as _, block_len);
911///     decoder.read_dictionary(&block, &data).unwrap();
912/// }
913///
914/// // Read record batch
915/// let batches = footer.recordBatches().unwrap();
916/// assert_eq!(batches.len(), 1); // Only wrote a single batch
917///
918/// let block = batches.get(0);
919/// let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
920/// let data = buffer.slice_with_length(block.offset() as _, block_len);
921/// let back = decoder.read_record_batch(block, &data).unwrap().unwrap();
922///
923/// assert_eq!(batch, back);
924/// ```
925#[derive(Debug)]
926pub struct FileDecoder {
927    schema: SchemaRef,
928    dictionaries: HashMap<i64, ArrayRef>,
929    version: MetadataVersion,
930    projection: Option<Vec<usize>>,
931    require_alignment: bool,
932    skip_validation: UnsafeFlag,
933}
934
935impl FileDecoder {
936    /// Create a new [`FileDecoder`] with the given schema and version
937    pub fn new(schema: SchemaRef, version: MetadataVersion) -> Self {
938        Self {
939            schema,
940            version,
941            dictionaries: Default::default(),
942            projection: None,
943            require_alignment: false,
944            skip_validation: UnsafeFlag::new(),
945        }
946    }
947
948    /// Specify a projection
949    pub fn with_projection(mut self, projection: Vec<usize>) -> Self {
950        self.projection = Some(projection);
951        self
952    }
953
954    /// Specifies if the array data in input buffers is required to be properly aligned.
955    ///
956    /// If `require_alignment` is true, this decoder will return an error if any array data in the
957    /// input `buf` is not properly aligned.
958    /// Under the hood it will use [`arrow_data::ArrayDataBuilder::build`] to construct
959    /// [`arrow_data::ArrayData`].
960    ///
961    /// If `require_alignment` is false (the default), this decoder will automatically allocate a
962    /// new aligned buffer and copy over the data if any array data in the input `buf` is not
963    /// properly aligned. (Properly aligned array data will remain zero-copy.)
964    /// Under the hood it will use [`arrow_data::ArrayDataBuilder::build_aligned`] to construct
965    /// [`arrow_data::ArrayData`].
966    pub fn with_require_alignment(mut self, require_alignment: bool) -> Self {
967        self.require_alignment = require_alignment;
968        self
969    }
970
971    /// Specifies if validation should be skipped when reading data (defaults to `false`)
972    ///
973    /// # Safety
974    ///
975    /// This flag must only be set to `true` when you trust the input data and are sure the data you are
976    /// reading is a valid Arrow IPC file, otherwise undefined behavior may
977    /// result.
978    ///
979    /// For example, some programs may wish to trust reading IPC files written
980    /// by the same process that created the files.
981    pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self {
982        self.skip_validation.set(skip_validation);
983        self
984    }
985
986    fn read_message<'a>(&self, buf: &'a [u8]) -> Result<Message::Message<'a>, ArrowError> {
987        let message = parse_message(buf)?;
988
989        // some old test data's footer metadata is not set, so we account for that
990        if self.version != MetadataVersion::V1 && message.version() != self.version {
991            return Err(ArrowError::IpcError(
992                "Could not read IPC message as metadata versions mismatch".to_string(),
993            ));
994        }
995        Ok(message)
996    }
997
998    /// Read the dictionary with the given block and data buffer
999    pub fn read_dictionary(&mut self, block: &Block, buf: &Buffer) -> Result<(), ArrowError> {
1000        let message = self.read_message(buf)?;
1001        match message.header_type() {
1002            crate::MessageHeader::DictionaryBatch => {
1003                let batch = message.header_as_dictionary_batch().unwrap();
1004                read_dictionary_impl(
1005                    &buf.slice(block.metaDataLength() as _),
1006                    batch,
1007                    &self.schema,
1008                    &mut self.dictionaries,
1009                    &message.version(),
1010                    self.require_alignment,
1011                    self.skip_validation.clone(),
1012                )
1013            }
1014            t => Err(ArrowError::ParseError(format!(
1015                "Expecting DictionaryBatch in dictionary blocks, found {t:?}."
1016            ))),
1017        }
1018    }
1019
1020    /// Read the RecordBatch with the given block and data buffer
1021    pub fn read_record_batch(
1022        &self,
1023        block: &Block,
1024        buf: &Buffer,
1025    ) -> Result<Option<RecordBatch>, ArrowError> {
1026        let message = self.read_message(buf)?;
1027        match message.header_type() {
1028            crate::MessageHeader::Schema => Err(ArrowError::IpcError(
1029                "Not expecting a schema when messages are read".to_string(),
1030            )),
1031            crate::MessageHeader::RecordBatch => {
1032                let batch = message.header_as_record_batch().ok_or_else(|| {
1033                    ArrowError::IpcError("Unable to read IPC message as record batch".to_string())
1034                })?;
1035                // read the block that makes up the record batch into a buffer
1036                RecordBatchDecoder::try_new(
1037                    &buf.slice(block.metaDataLength() as _),
1038                    batch,
1039                    self.schema.clone(),
1040                    &self.dictionaries,
1041                    &message.version(),
1042                )?
1043                .with_projection(self.projection.as_deref())
1044                .with_require_alignment(self.require_alignment)
1045                .with_skip_validation(self.skip_validation.clone())
1046                .read_record_batch()
1047                .map(Some)
1048            }
1049            crate::MessageHeader::NONE => Ok(None),
1050            t => Err(ArrowError::InvalidArgumentError(format!(
1051                "Reading types other than record batches not yet supported, unable to read {t:?}"
1052            ))),
1053        }
1054    }
1055}
1056
1057/// Build an Arrow [`FileReader`] with custom options.
1058#[derive(Debug)]
1059pub struct FileReaderBuilder {
1060    /// Optional projection for which columns to load (zero-based column indices)
1061    projection: Option<Vec<usize>>,
1062    /// Passed through to construct [`VerifierOptions`]
1063    max_footer_fb_tables: usize,
1064    /// Passed through to construct [`VerifierOptions`]
1065    max_footer_fb_depth: usize,
1066}
1067
1068impl Default for FileReaderBuilder {
1069    fn default() -> Self {
1070        let verifier_options = VerifierOptions::default();
1071        Self {
1072            max_footer_fb_tables: verifier_options.max_tables,
1073            max_footer_fb_depth: verifier_options.max_depth,
1074            projection: None,
1075        }
1076    }
1077}
1078
1079impl FileReaderBuilder {
1080    /// Options for creating a new [`FileReader`].
1081    ///
1082    /// To convert a builder into a reader, call [`FileReaderBuilder::build`].
1083    pub fn new() -> Self {
1084        Self::default()
1085    }
1086
1087    /// Optional projection for which columns to load (zero-based column indices).
1088    pub fn with_projection(mut self, projection: Vec<usize>) -> Self {
1089        self.projection = Some(projection);
1090        self
1091    }
1092
1093    /// Flatbuffers option for parsing the footer. Controls the max number of fields and
1094    /// metadata key-value pairs that can be parsed from the schema of the footer.
1095    ///
1096    /// By default this is set to `1_000_000` which roughly translates to a schema with
1097    /// no metadata key-value pairs but 499,999 fields.
1098    ///
1099    /// This default limit is enforced to protect against malicious files with a massive
1100    /// amount of flatbuffer tables which could cause a denial of service attack.
1101    ///
1102    /// If you need to ingest a trusted file with a massive number of fields and/or
1103    /// metadata key-value pairs and are facing the error `"Unable to get root as
1104    /// footer: TooManyTables"` then increase this parameter as necessary.
1105    pub fn with_max_footer_fb_tables(mut self, max_footer_fb_tables: usize) -> Self {
1106        self.max_footer_fb_tables = max_footer_fb_tables;
1107        self
1108    }
1109
1110    /// Flatbuffers option for parsing the footer. Controls the max depth for schemas with
1111    /// nested fields parsed from the footer.
1112    ///
1113    /// By default this is set to `64` which roughly translates to a schema with
1114    /// a field nested 60 levels down through other struct fields.
1115    ///
1116    /// This default limit is enforced to protect against malicious files with a extremely
1117    /// deep flatbuffer structure which could cause a denial of service attack.
1118    ///
1119    /// If you need to ingest a trusted file with a deeply nested field and are facing the
1120    /// error `"Unable to get root as footer: DepthLimitReached"` then increase this
1121    /// parameter as necessary.
1122    pub fn with_max_footer_fb_depth(mut self, max_footer_fb_depth: usize) -> Self {
1123        self.max_footer_fb_depth = max_footer_fb_depth;
1124        self
1125    }
1126
1127    /// Build [`FileReader`] with given reader.
1128    pub fn build<R: Read + Seek>(self, mut reader: R) -> Result<FileReader<R>, ArrowError> {
1129        // Space for ARROW_MAGIC (6 bytes) and length (4 bytes)
1130        let mut buffer = [0; 10];
1131        reader.seek(SeekFrom::End(-10))?;
1132        reader.read_exact(&mut buffer)?;
1133
1134        let footer_len = read_footer_length(buffer)?;
1135
1136        // read footer
1137        let mut footer_data = vec![0; footer_len];
1138        reader.seek(SeekFrom::End(-10 - footer_len as i64))?;
1139        reader.read_exact(&mut footer_data)?;
1140
1141        let verifier_options = VerifierOptions {
1142            max_tables: self.max_footer_fb_tables,
1143            max_depth: self.max_footer_fb_depth,
1144            ..Default::default()
1145        };
1146        let footer = crate::root_as_footer_with_opts(&verifier_options, &footer_data[..]).map_err(
1147            |err| ArrowError::ParseError(format!("Unable to get root as footer: {err:?}")),
1148        )?;
1149
1150        let blocks = footer.recordBatches().ok_or_else(|| {
1151            ArrowError::ParseError("Unable to get record batches from IPC Footer".to_string())
1152        })?;
1153
1154        let total_blocks = blocks.len();
1155
1156        let ipc_schema = footer.schema().unwrap();
1157        if !ipc_schema.endianness().equals_to_target_endianness() {
1158            return Err(ArrowError::IpcError(
1159                "the endianness of the source system does not match the endianness of the target system.".to_owned()
1160            ));
1161        }
1162
1163        let schema = crate::convert::fb_to_schema(ipc_schema);
1164
1165        let mut custom_metadata = HashMap::new();
1166        if let Some(fb_custom_metadata) = footer.custom_metadata() {
1167            for kv in fb_custom_metadata.into_iter() {
1168                custom_metadata.insert(
1169                    kv.key().unwrap().to_string(),
1170                    kv.value().unwrap().to_string(),
1171                );
1172            }
1173        }
1174
1175        let mut decoder = FileDecoder::new(Arc::new(schema), footer.version());
1176        if let Some(projection) = self.projection {
1177            decoder = decoder.with_projection(projection)
1178        }
1179
1180        // Create an array of optional dictionary value arrays, one per field.
1181        if let Some(dictionaries) = footer.dictionaries() {
1182            for block in dictionaries {
1183                let buf = read_block(&mut reader, block)?;
1184                decoder.read_dictionary(block, &buf)?;
1185            }
1186        }
1187
1188        Ok(FileReader {
1189            reader,
1190            blocks: blocks.iter().copied().collect(),
1191            current_block: 0,
1192            total_blocks,
1193            decoder,
1194            custom_metadata,
1195        })
1196    }
1197}
1198
1199/// Arrow File Reader
1200///
1201/// Reads Arrow [`RecordBatch`]es from bytes in the [IPC File Format],
1202/// providing random access to the record batches.
1203///
1204/// # See Also
1205///
1206/// * [`Self::set_index`] for random access
1207/// * [`StreamReader`] for reading streaming data
1208///
1209/// # Example: Reading from a `File`
1210/// ```
1211/// # use std::io::Cursor;
1212/// use arrow_array::record_batch;
1213/// # use arrow_ipc::reader::FileReader;
1214/// # use arrow_ipc::writer::FileWriter;
1215/// # let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
1216/// # let mut file = vec![]; // mimic a stream for the example
1217/// # {
1218/// #  let mut writer = FileWriter::try_new(&mut file, &batch.schema()).unwrap();
1219/// #  writer.write(&batch).unwrap();
1220/// #  writer.write(&batch).unwrap();
1221/// #  writer.finish().unwrap();
1222/// # }
1223/// # let mut file = Cursor::new(&file);
1224/// let projection = None; // read all columns
1225/// let mut reader = FileReader::try_new(&mut file, projection).unwrap();
1226/// // Position the reader to the second batch
1227/// reader.set_index(1).unwrap();
1228/// // read batches from the reader using the Iterator trait
1229/// let mut num_rows = 0;
1230/// for batch in reader {
1231///    let batch = batch.unwrap();
1232///    num_rows += batch.num_rows();
1233/// }
1234/// assert_eq!(num_rows, 3);
1235/// ```
1236/// # Example: Reading from `mmap`ed file
1237///
1238/// For an example creating Arrays without copying using  memory mapped (`mmap`)
1239/// files see the [`zero_copy_ipc`] example.
1240///
1241/// [IPC File Format]: https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format
1242/// [`zero_copy_ipc`]: https://github.com/apache/arrow-rs/blob/main/arrow/examples/zero_copy_ipc.rs
1243pub struct FileReader<R> {
1244    /// File reader that supports reading and seeking
1245    reader: R,
1246
1247    /// The decoder
1248    decoder: FileDecoder,
1249
1250    /// The blocks in the file
1251    ///
1252    /// A block indicates the regions in the file to read to get data
1253    blocks: Vec<Block>,
1254
1255    /// A counter to keep track of the current block that should be read
1256    current_block: usize,
1257
1258    /// The total number of blocks, which may contain record batches and other types
1259    total_blocks: usize,
1260
1261    /// User defined metadata
1262    custom_metadata: HashMap<String, String>,
1263}
1264
1265impl<R> fmt::Debug for FileReader<R> {
1266    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
1267        f.debug_struct("FileReader<R>")
1268            .field("decoder", &self.decoder)
1269            .field("blocks", &self.blocks)
1270            .field("current_block", &self.current_block)
1271            .field("total_blocks", &self.total_blocks)
1272            .finish_non_exhaustive()
1273    }
1274}
1275
1276impl<R: Read + Seek> FileReader<BufReader<R>> {
1277    /// Try to create a new file reader with the reader wrapped in a BufReader.
1278    ///
1279    /// See [`FileReader::try_new`] for an unbuffered version.
1280    pub fn try_new_buffered(reader: R, projection: Option<Vec<usize>>) -> Result<Self, ArrowError> {
1281        Self::try_new(BufReader::new(reader), projection)
1282    }
1283}
1284
1285impl<R: Read + Seek> FileReader<R> {
1286    /// Try to create a new file reader.
1287    ///
1288    /// There is no internal buffering. If buffered reads are needed you likely want to use
1289    /// [`FileReader::try_new_buffered`] instead.    
1290    ///
1291    /// # Errors
1292    ///
1293    /// An ['Err'](Result::Err) may be returned if:
1294    /// - the file does not meet the Arrow Format footer requirements, or
1295    /// - file endianness does not match the target endianness.
1296    pub fn try_new(reader: R, projection: Option<Vec<usize>>) -> Result<Self, ArrowError> {
1297        let builder = FileReaderBuilder {
1298            projection,
1299            ..Default::default()
1300        };
1301        builder.build(reader)
1302    }
1303
1304    /// Return user defined customized metadata
1305    pub fn custom_metadata(&self) -> &HashMap<String, String> {
1306        &self.custom_metadata
1307    }
1308
1309    /// Return the number of batches in the file
1310    pub fn num_batches(&self) -> usize {
1311        self.total_blocks
1312    }
1313
1314    /// Return the schema of the file
1315    pub fn schema(&self) -> SchemaRef {
1316        self.decoder.schema.clone()
1317    }
1318
1319    /// See to a specific [`RecordBatch`]
1320    ///
1321    /// Sets the current block to the index, allowing random reads
1322    pub fn set_index(&mut self, index: usize) -> Result<(), ArrowError> {
1323        if index >= self.total_blocks {
1324            Err(ArrowError::InvalidArgumentError(format!(
1325                "Cannot set batch to index {} from {} total batches",
1326                index, self.total_blocks
1327            )))
1328        } else {
1329            self.current_block = index;
1330            Ok(())
1331        }
1332    }
1333
1334    fn maybe_next(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
1335        let block = &self.blocks[self.current_block];
1336        self.current_block += 1;
1337
1338        // read length
1339        let buffer = read_block(&mut self.reader, block)?;
1340        self.decoder.read_record_batch(block, &buffer)
1341    }
1342
1343    /// Gets a reference to the underlying reader.
1344    ///
1345    /// It is inadvisable to directly read from the underlying reader.
1346    pub fn get_ref(&self) -> &R {
1347        &self.reader
1348    }
1349
1350    /// Gets a mutable reference to the underlying reader.
1351    ///
1352    /// It is inadvisable to directly read from the underlying reader.
1353    pub fn get_mut(&mut self) -> &mut R {
1354        &mut self.reader
1355    }
1356
1357    /// Specifies if validation should be skipped when reading data (defaults to `false`)
1358    ///
1359    /// # Safety
1360    ///
1361    /// See [`FileDecoder::with_skip_validation`]
1362    pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self {
1363        self.decoder = self.decoder.with_skip_validation(skip_validation);
1364        self
1365    }
1366}
1367
1368impl<R: Read + Seek> Iterator for FileReader<R> {
1369    type Item = Result<RecordBatch, ArrowError>;
1370
1371    fn next(&mut self) -> Option<Self::Item> {
1372        // get current block
1373        if self.current_block < self.total_blocks {
1374            self.maybe_next().transpose()
1375        } else {
1376            None
1377        }
1378    }
1379}
1380
1381impl<R: Read + Seek> RecordBatchReader for FileReader<R> {
1382    fn schema(&self) -> SchemaRef {
1383        self.schema()
1384    }
1385}
1386
1387/// Arrow Stream Reader
1388///
1389/// Reads Arrow [`RecordBatch`]es from bytes in the [IPC Streaming Format].
1390///
1391/// # See Also
1392///
1393/// * [`FileReader`] for random access.
1394///
1395/// # Example
1396/// ```
1397/// # use arrow_array::record_batch;
1398/// # use arrow_ipc::reader::StreamReader;
1399/// # use arrow_ipc::writer::StreamWriter;
1400/// # let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
1401/// # let mut stream = vec![]; // mimic a stream for the example
1402/// # {
1403/// #  let mut writer = StreamWriter::try_new(&mut stream, &batch.schema()).unwrap();
1404/// #  writer.write(&batch).unwrap();
1405/// #  writer.finish().unwrap();
1406/// # }
1407/// # let stream = stream.as_slice();
1408/// let projection = None; // read all columns
1409/// let mut reader = StreamReader::try_new(stream, projection).unwrap();
1410/// // read batches from the reader using the Iterator trait
1411/// let mut num_rows = 0;
1412/// for batch in reader {
1413///    let batch = batch.unwrap();
1414///    num_rows += batch.num_rows();
1415/// }
1416/// assert_eq!(num_rows, 3);
1417/// ```
1418///
1419/// [IPC Streaming Format]: https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format
1420pub struct StreamReader<R> {
1421    /// Stream reader
1422    reader: MessageReader<R>,
1423
1424    /// The schema that is read from the stream's first message
1425    schema: SchemaRef,
1426
1427    /// Optional dictionaries for each schema field.
1428    ///
1429    /// Dictionaries may be appended to in the streaming format.
1430    dictionaries_by_id: HashMap<i64, ArrayRef>,
1431
1432    /// An indicator of whether the stream is complete.
1433    ///
1434    /// This value is set to `true` the first time the reader's `next()` returns `None`.
1435    finished: bool,
1436
1437    /// Optional projection
1438    projection: Option<(Vec<usize>, Schema)>,
1439
1440    /// Should validation be skipped when reading data? Defaults to false.
1441    ///
1442    /// See [`FileDecoder::with_skip_validation`] for details.
1443    skip_validation: UnsafeFlag,
1444}
1445
1446impl<R> fmt::Debug for StreamReader<R> {
1447    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::result::Result<(), fmt::Error> {
1448        f.debug_struct("StreamReader<R>")
1449            .field("reader", &"R")
1450            .field("schema", &self.schema)
1451            .field("dictionaries_by_id", &self.dictionaries_by_id)
1452            .field("finished", &self.finished)
1453            .field("projection", &self.projection)
1454            .finish()
1455    }
1456}
1457
1458impl<R: Read> StreamReader<BufReader<R>> {
1459    /// Try to create a new stream reader with the reader wrapped in a BufReader.
1460    ///
1461    /// See [`StreamReader::try_new`] for an unbuffered version.
1462    pub fn try_new_buffered(reader: R, projection: Option<Vec<usize>>) -> Result<Self, ArrowError> {
1463        Self::try_new(BufReader::new(reader), projection)
1464    }
1465}
1466
1467impl<R: Read> StreamReader<R> {
1468    /// Try to create a new stream reader.
1469    ///
1470    /// To check if the reader is done, use [`is_finished(self)`](StreamReader::is_finished).
1471    ///
1472    /// There is no internal buffering. If buffered reads are needed you likely want to use
1473    /// [`StreamReader::try_new_buffered`] instead.
1474    ///
1475    /// # Errors
1476    ///
1477    /// An ['Err'](Result::Err) may be returned if the reader does not encounter a schema
1478    /// as the first message in the stream.
1479    pub fn try_new(
1480        reader: R,
1481        projection: Option<Vec<usize>>,
1482    ) -> Result<StreamReader<R>, ArrowError> {
1483        let mut msg_reader = MessageReader::new(reader);
1484        let message = msg_reader.maybe_next()?;
1485        let Some((message, _)) = message else {
1486            return Err(ArrowError::IpcError(
1487                "Expected schema message, found empty stream.".to_string(),
1488            ));
1489        };
1490
1491        if message.header_type() != Message::MessageHeader::Schema {
1492            return Err(ArrowError::IpcError(format!(
1493                "Expected a schema as the first message in the stream, got: {:?}",
1494                message.header_type()
1495            )));
1496        }
1497
1498        let schema = message.header_as_schema().ok_or_else(|| {
1499            ArrowError::ParseError("Failed to parse schema from message header".to_string())
1500        })?;
1501        let schema = crate::convert::fb_to_schema(schema);
1502
1503        // Create an array of optional dictionary value arrays, one per field.
1504        let dictionaries_by_id = HashMap::new();
1505
1506        let projection = match projection {
1507            Some(projection_indices) => {
1508                let schema = schema.project(&projection_indices)?;
1509                Some((projection_indices, schema))
1510            }
1511            _ => None,
1512        };
1513
1514        Ok(Self {
1515            reader: msg_reader,
1516            schema: Arc::new(schema),
1517            finished: false,
1518            dictionaries_by_id,
1519            projection,
1520            skip_validation: UnsafeFlag::new(),
1521        })
1522    }
1523
1524    /// Deprecated, use [`StreamReader::try_new`] instead.
1525    #[deprecated(since = "53.0.0", note = "use `try_new` instead")]
1526    pub fn try_new_unbuffered(
1527        reader: R,
1528        projection: Option<Vec<usize>>,
1529    ) -> Result<Self, ArrowError> {
1530        Self::try_new(reader, projection)
1531    }
1532
1533    /// Return the schema of the stream
1534    pub fn schema(&self) -> SchemaRef {
1535        self.schema.clone()
1536    }
1537
1538    /// Check if the stream is finished
1539    pub fn is_finished(&self) -> bool {
1540        self.finished
1541    }
1542
1543    fn maybe_next(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
1544        if self.finished {
1545            return Ok(None);
1546        }
1547
1548        // Read messages until we get a record batch or end of stream
1549        loop {
1550            let message = self.next_ipc_message()?;
1551            let Some(message) = message else {
1552                // If the message is None, we have reached the end of the stream.
1553                self.finished = true;
1554                return Ok(None);
1555            };
1556
1557            match message {
1558                IpcMessage::Schema(_) => {
1559                    return Err(ArrowError::IpcError(
1560                        "Expected a record batch, but found a schema".to_string(),
1561                    ));
1562                }
1563                IpcMessage::RecordBatch(record_batch) => {
1564                    return Ok(Some(record_batch));
1565                }
1566                IpcMessage::DictionaryBatch { .. } => {
1567                    continue;
1568                }
1569            };
1570        }
1571    }
1572
1573    /// Reads and fully parses the next IPC message from the stream. Whereas
1574    /// [`Self::maybe_next`] is a higher level method focused on reading
1575    /// `RecordBatch`es, this method returns the individual fully parsed IPC
1576    /// messages from the underlying stream.
1577    ///
1578    /// This is useful primarily for testing reader/writer behaviors as it
1579    /// allows a full view into the messages that have been written to a stream.
1580    pub(crate) fn next_ipc_message(&mut self) -> Result<Option<IpcMessage>, ArrowError> {
1581        let message = self.reader.maybe_next()?;
1582        let Some((message, body)) = message else {
1583            // If the message is None, we have reached the end of the stream.
1584            return Ok(None);
1585        };
1586
1587        let ipc_message = match message.header_type() {
1588            Message::MessageHeader::Schema => {
1589                let schema = message.header_as_schema().ok_or_else(|| {
1590                    ArrowError::ParseError("Failed to parse schema from message header".to_string())
1591                })?;
1592                let arrow_schema = crate::convert::fb_to_schema(schema);
1593                IpcMessage::Schema(arrow_schema)
1594            }
1595            Message::MessageHeader::RecordBatch => {
1596                let batch = message.header_as_record_batch().ok_or_else(|| {
1597                    ArrowError::IpcError("Unable to read IPC message as record batch".to_string())
1598                })?;
1599
1600                let version = message.version();
1601                let schema = self.schema.clone();
1602                let record_batch = RecordBatchDecoder::try_new(
1603                    &body.into(),
1604                    batch,
1605                    schema,
1606                    &self.dictionaries_by_id,
1607                    &version,
1608                )?
1609                .with_projection(self.projection.as_ref().map(|x| x.0.as_ref()))
1610                .with_require_alignment(false)
1611                .with_skip_validation(self.skip_validation.clone())
1612                .read_record_batch()?;
1613                IpcMessage::RecordBatch(record_batch)
1614            }
1615            Message::MessageHeader::DictionaryBatch => {
1616                let dict = message.header_as_dictionary_batch().ok_or_else(|| {
1617                    ArrowError::ParseError(
1618                        "Failed to parse dictionary batch from message header".to_string(),
1619                    )
1620                })?;
1621
1622                let version = message.version();
1623                let dict_values = get_dictionary_values(
1624                    &body.into(),
1625                    dict,
1626                    &self.schema,
1627                    &mut self.dictionaries_by_id,
1628                    &version,
1629                    false,
1630                    self.skip_validation.clone(),
1631                )?;
1632
1633                update_dictionaries(
1634                    &mut self.dictionaries_by_id,
1635                    dict.isDelta(),
1636                    dict.id(),
1637                    dict_values.clone(),
1638                )?;
1639
1640                IpcMessage::DictionaryBatch {
1641                    id: dict.id(),
1642                    is_delta: (dict.isDelta()),
1643                    values: (dict_values),
1644                }
1645            }
1646            x => {
1647                return Err(ArrowError::ParseError(format!(
1648                    "Unsupported message header type in IPC stream: '{x:?}'"
1649                )));
1650            }
1651        };
1652
1653        Ok(Some(ipc_message))
1654    }
1655
1656    /// Gets a reference to the underlying reader.
1657    ///
1658    /// It is inadvisable to directly read from the underlying reader.
1659    pub fn get_ref(&self) -> &R {
1660        self.reader.inner()
1661    }
1662
1663    /// Gets a mutable reference to the underlying reader.
1664    ///
1665    /// It is inadvisable to directly read from the underlying reader.
1666    pub fn get_mut(&mut self) -> &mut R {
1667        self.reader.inner_mut()
1668    }
1669
1670    /// Specifies if validation should be skipped when reading data (defaults to `false`)
1671    ///
1672    /// # Safety
1673    ///
1674    /// See [`FileDecoder::with_skip_validation`]
1675    pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self {
1676        self.skip_validation.set(skip_validation);
1677        self
1678    }
1679}
1680
1681impl<R: Read> Iterator for StreamReader<R> {
1682    type Item = Result<RecordBatch, ArrowError>;
1683
1684    fn next(&mut self) -> Option<Self::Item> {
1685        self.maybe_next().transpose()
1686    }
1687}
1688
1689impl<R: Read> RecordBatchReader for StreamReader<R> {
1690    fn schema(&self) -> SchemaRef {
1691        self.schema.clone()
1692    }
1693}
1694
1695/// Representation of a fully parsed IpcMessage from the underlying stream.
1696/// Parsing this kind of message is done by higher level constructs such as
1697/// [`StreamReader`], because fully interpreting the messages into a record
1698/// batch or dictionary batch requires access to stream state such as schema
1699/// and the full dictionary cache.
1700#[derive(Debug)]
1701#[allow(dead_code)]
1702pub(crate) enum IpcMessage {
1703    Schema(arrow_schema::Schema),
1704    RecordBatch(RecordBatch),
1705    DictionaryBatch {
1706        id: i64,
1707        is_delta: bool,
1708        values: ArrayRef,
1709    },
1710}
1711
1712/// A low-level construct that reads [`Message::Message`]s from a reader while
1713/// re-using a buffer for metadata. This is composed into [`StreamReader`].
1714struct MessageReader<R> {
1715    reader: R,
1716    buf: Vec<u8>,
1717}
1718
1719impl<R: Read> MessageReader<R> {
1720    fn new(reader: R) -> Self {
1721        Self {
1722            reader,
1723            buf: Vec::new(),
1724        }
1725    }
1726
1727    /// Reads the entire next message from the underlying reader which includes
1728    /// the metadata length, the metadata, and the body.
1729    ///
1730    /// # Returns
1731    /// - `Ok(None)` if the the reader signals the end of stream with EOF on
1732    ///   the first read
1733    /// - `Err(_)` if the reader returns an error other than on the first
1734    ///   read, or if the metadata length is invalid
1735    /// - `Ok(Some(_))` with the Message and buffer containiner the
1736    ///   body bytes otherwise.
1737    fn maybe_next(&mut self) -> Result<Option<(Message::Message<'_>, MutableBuffer)>, ArrowError> {
1738        let meta_len = self.read_meta_len()?;
1739        let Some(meta_len) = meta_len else {
1740            return Ok(None);
1741        };
1742
1743        self.buf.resize(meta_len, 0);
1744        self.reader.read_exact(&mut self.buf)?;
1745
1746        let message = crate::root_as_message(self.buf.as_slice()).map_err(|err| {
1747            ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
1748        })?;
1749
1750        let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize);
1751        self.reader.read_exact(&mut buf)?;
1752
1753        Ok(Some((message, buf)))
1754    }
1755
1756    /// Get a mutable reference to the underlying reader.
1757    fn inner_mut(&mut self) -> &mut R {
1758        &mut self.reader
1759    }
1760
1761    /// Get an immutable reference to the underlying reader.
1762    fn inner(&self) -> &R {
1763        &self.reader
1764    }
1765
1766    /// Read the metadata length for the next message from the underlying stream.
1767    ///
1768    /// # Returns
1769    /// - `Ok(None)` if the the reader signals the end of stream with EOF on
1770    ///   the first read
1771    /// - `Err(_)` if the reader returns an error other than on the first
1772    ///   read, or if the metadata length is less than 0.
1773    /// - `Ok(Some(_))` with the length otherwise.
1774    pub fn read_meta_len(&mut self) -> Result<Option<usize>, ArrowError> {
1775        let mut meta_len: [u8; 4] = [0; 4];
1776        match self.reader.read_exact(&mut meta_len) {
1777            Ok(_) => {}
1778            Err(e) => {
1779                return if e.kind() == std::io::ErrorKind::UnexpectedEof {
1780                    // Handle EOF without the "0xFFFFFFFF 0x00000000"
1781                    // valid according to:
1782                    // https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format
1783                    Ok(None)
1784                } else {
1785                    Err(ArrowError::from(e))
1786                };
1787            }
1788        };
1789
1790        let meta_len = {
1791            // If a continuation marker is encountered, skip over it and read
1792            // the size from the next four bytes.
1793            if meta_len == CONTINUATION_MARKER {
1794                self.reader.read_exact(&mut meta_len)?;
1795            }
1796
1797            i32::from_le_bytes(meta_len)
1798        };
1799
1800        if meta_len == 0 {
1801            return Ok(None);
1802        }
1803
1804        let meta_len = usize::try_from(meta_len)
1805            .map_err(|_| ArrowError::ParseError(format!("Invalid metadata length: {meta_len}")))?;
1806
1807        Ok(Some(meta_len))
1808    }
1809}
1810
1811#[cfg(test)]
1812mod tests {
1813    use std::io::Cursor;
1814
1815    use crate::convert::fb_to_schema;
1816    use crate::writer::{
1817        unslice_run_array, write_message, DictionaryTracker, IpcDataGenerator, IpcWriteOptions,
1818    };
1819
1820    use super::*;
1821
1822    use crate::{root_as_footer, root_as_message, size_prefixed_root_as_message};
1823    use arrow_array::builder::{PrimitiveRunBuilder, UnionBuilder};
1824    use arrow_array::types::*;
1825    use arrow_buffer::{NullBuffer, OffsetBuffer};
1826    use arrow_data::ArrayDataBuilder;
1827
1828    fn create_test_projection_schema() -> Schema {
1829        // define field types
1830        let list_data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true)));
1831
1832        let fixed_size_list_data_type =
1833            DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Int32, false)), 3);
1834
1835        let union_fields = UnionFields::new(
1836            vec![0, 1],
1837            vec![
1838                Field::new("a", DataType::Int32, false),
1839                Field::new("b", DataType::Float64, false),
1840            ],
1841        );
1842
1843        let union_data_type = DataType::Union(union_fields, UnionMode::Dense);
1844
1845        let struct_fields = Fields::from(vec![
1846            Field::new("id", DataType::Int32, false),
1847            Field::new_list("list", Field::new_list_field(DataType::Int8, true), false),
1848        ]);
1849        let struct_data_type = DataType::Struct(struct_fields);
1850
1851        let run_encoded_data_type = DataType::RunEndEncoded(
1852            Arc::new(Field::new("run_ends", DataType::Int16, false)),
1853            Arc::new(Field::new("values", DataType::Int32, true)),
1854        );
1855
1856        // define schema
1857        Schema::new(vec![
1858            Field::new("f0", DataType::UInt32, false),
1859            Field::new("f1", DataType::Utf8, false),
1860            Field::new("f2", DataType::Boolean, false),
1861            Field::new("f3", union_data_type, true),
1862            Field::new("f4", DataType::Null, true),
1863            Field::new("f5", DataType::Float64, true),
1864            Field::new("f6", list_data_type, false),
1865            Field::new("f7", DataType::FixedSizeBinary(3), true),
1866            Field::new("f8", fixed_size_list_data_type, false),
1867            Field::new("f9", struct_data_type, false),
1868            Field::new("f10", run_encoded_data_type, false),
1869            Field::new("f11", DataType::Boolean, false),
1870            Field::new_dictionary("f12", DataType::Int8, DataType::Utf8, false),
1871            Field::new("f13", DataType::Utf8, false),
1872        ])
1873    }
1874
1875    fn create_test_projection_batch_data(schema: &Schema) -> RecordBatch {
1876        // set test data for each column
1877        let array0 = UInt32Array::from(vec![1, 2, 3]);
1878        let array1 = StringArray::from(vec!["foo", "bar", "baz"]);
1879        let array2 = BooleanArray::from(vec![true, false, true]);
1880
1881        let mut union_builder = UnionBuilder::new_dense();
1882        union_builder.append::<Int32Type>("a", 1).unwrap();
1883        union_builder.append::<Float64Type>("b", 10.1).unwrap();
1884        union_builder.append_null::<Float64Type>("b").unwrap();
1885        let array3 = union_builder.build().unwrap();
1886
1887        let array4 = NullArray::new(3);
1888        let array5 = Float64Array::from(vec![Some(1.1), None, Some(3.3)]);
1889        let array6_values = vec![
1890            Some(vec![Some(10), Some(10), Some(10)]),
1891            Some(vec![Some(20), Some(20), Some(20)]),
1892            Some(vec![Some(30), Some(30)]),
1893        ];
1894        let array6 = ListArray::from_iter_primitive::<Int32Type, _, _>(array6_values);
1895        let array7_values = vec![vec![11, 12, 13], vec![22, 23, 24], vec![33, 34, 35]];
1896        let array7 = FixedSizeBinaryArray::try_from_iter(array7_values.into_iter()).unwrap();
1897
1898        let array8_values = ArrayData::builder(DataType::Int32)
1899            .len(9)
1900            .add_buffer(Buffer::from_slice_ref([40, 41, 42, 43, 44, 45, 46, 47, 48]))
1901            .build()
1902            .unwrap();
1903        let array8_data = ArrayData::builder(schema.field(8).data_type().clone())
1904            .len(3)
1905            .add_child_data(array8_values)
1906            .build()
1907            .unwrap();
1908        let array8 = FixedSizeListArray::from(array8_data);
1909
1910        let array9_id: ArrayRef = Arc::new(Int32Array::from(vec![1001, 1002, 1003]));
1911        let array9_list: ArrayRef =
1912            Arc::new(ListArray::from_iter_primitive::<Int8Type, _, _>(vec![
1913                Some(vec![Some(-10)]),
1914                Some(vec![Some(-20), Some(-20), Some(-20)]),
1915                Some(vec![Some(-30)]),
1916            ]));
1917        let array9 = ArrayDataBuilder::new(schema.field(9).data_type().clone())
1918            .add_child_data(array9_id.into_data())
1919            .add_child_data(array9_list.into_data())
1920            .len(3)
1921            .build()
1922            .unwrap();
1923        let array9: ArrayRef = Arc::new(StructArray::from(array9));
1924
1925        let array10_input = vec![Some(1_i32), None, None];
1926        let mut array10_builder = PrimitiveRunBuilder::<Int16Type, Int32Type>::new();
1927        array10_builder.extend(array10_input);
1928        let array10 = array10_builder.finish();
1929
1930        let array11 = BooleanArray::from(vec![false, false, true]);
1931
1932        let array12_values = StringArray::from(vec!["x", "yy", "zzz"]);
1933        let array12_keys = Int8Array::from_iter_values([1, 1, 2]);
1934        let array12 = DictionaryArray::new(array12_keys, Arc::new(array12_values));
1935
1936        let array13 = StringArray::from(vec!["a", "bb", "ccc"]);
1937
1938        // create record batch
1939        RecordBatch::try_new(
1940            Arc::new(schema.clone()),
1941            vec![
1942                Arc::new(array0),
1943                Arc::new(array1),
1944                Arc::new(array2),
1945                Arc::new(array3),
1946                Arc::new(array4),
1947                Arc::new(array5),
1948                Arc::new(array6),
1949                Arc::new(array7),
1950                Arc::new(array8),
1951                Arc::new(array9),
1952                Arc::new(array10),
1953                Arc::new(array11),
1954                Arc::new(array12),
1955                Arc::new(array13),
1956            ],
1957        )
1958        .unwrap()
1959    }
1960
1961    #[test]
1962    fn test_negative_meta_len_start_stream() {
1963        let bytes = i32::to_le_bytes(-1);
1964        let mut buf = vec![];
1965        buf.extend(CONTINUATION_MARKER);
1966        buf.extend(bytes);
1967
1968        let reader_err = StreamReader::try_new(Cursor::new(buf), None).err();
1969        assert!(reader_err.is_some());
1970        assert_eq!(
1971            reader_err.unwrap().to_string(),
1972            "Parser error: Invalid metadata length: -1"
1973        );
1974    }
1975
1976    #[test]
1977    fn test_negative_meta_len_mid_stream() {
1978        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1979        let mut buf = Vec::new();
1980        {
1981            let mut writer = crate::writer::StreamWriter::try_new(&mut buf, &schema).unwrap();
1982            let batch =
1983                RecordBatch::try_new(Arc::new(schema), vec![Arc::new(Int32Array::from(vec![1]))])
1984                    .unwrap();
1985            writer.write(&batch).unwrap();
1986        }
1987
1988        let bytes = i32::to_le_bytes(-1);
1989        buf.extend(CONTINUATION_MARKER);
1990        buf.extend(bytes);
1991
1992        let mut reader = StreamReader::try_new(Cursor::new(buf), None).unwrap();
1993        // Read the valid value
1994        assert!(reader.maybe_next().is_ok());
1995        // Read the invalid meta len
1996        let batch_err = reader.maybe_next().err();
1997        assert!(batch_err.is_some());
1998        assert_eq!(
1999            batch_err.unwrap().to_string(),
2000            "Parser error: Invalid metadata length: -1"
2001        );
2002    }
2003
2004    #[test]
2005    fn test_projection_array_values() {
2006        // define schema
2007        let schema = create_test_projection_schema();
2008
2009        // create record batch with test data
2010        let batch = create_test_projection_batch_data(&schema);
2011
2012        // write record batch in IPC format
2013        let mut buf = Vec::new();
2014        {
2015            let mut writer = crate::writer::FileWriter::try_new(&mut buf, &schema).unwrap();
2016            writer.write(&batch).unwrap();
2017            writer.finish().unwrap();
2018        }
2019
2020        // read record batch with projection
2021        for index in 0..12 {
2022            let projection = vec![index];
2023            let reader = FileReader::try_new(std::io::Cursor::new(buf.clone()), Some(projection));
2024            let read_batch = reader.unwrap().next().unwrap().unwrap();
2025            let projected_column = read_batch.column(0);
2026            let expected_column = batch.column(index);
2027
2028            // check the projected column equals the expected column
2029            assert_eq!(projected_column.as_ref(), expected_column.as_ref());
2030        }
2031
2032        {
2033            // read record batch with reversed projection
2034            let reader =
2035                FileReader::try_new(std::io::Cursor::new(buf.clone()), Some(vec![3, 2, 1]));
2036            let read_batch = reader.unwrap().next().unwrap().unwrap();
2037            let expected_batch = batch.project(&[3, 2, 1]).unwrap();
2038            assert_eq!(read_batch, expected_batch);
2039        }
2040    }
2041
2042    #[test]
2043    fn test_arrow_single_float_row() {
2044        let schema = Schema::new(vec![
2045            Field::new("a", DataType::Float32, false),
2046            Field::new("b", DataType::Float32, false),
2047            Field::new("c", DataType::Int32, false),
2048            Field::new("d", DataType::Int32, false),
2049        ]);
2050        let arrays = vec![
2051            Arc::new(Float32Array::from(vec![1.23])) as ArrayRef,
2052            Arc::new(Float32Array::from(vec![-6.50])) as ArrayRef,
2053            Arc::new(Int32Array::from(vec![2])) as ArrayRef,
2054            Arc::new(Int32Array::from(vec![1])) as ArrayRef,
2055        ];
2056        let batch = RecordBatch::try_new(Arc::new(schema.clone()), arrays).unwrap();
2057        // create stream writer
2058        let mut file = tempfile::tempfile().unwrap();
2059        let mut stream_writer = crate::writer::StreamWriter::try_new(&mut file, &schema).unwrap();
2060        stream_writer.write(&batch).unwrap();
2061        stream_writer.finish().unwrap();
2062
2063        drop(stream_writer);
2064
2065        file.rewind().unwrap();
2066
2067        // read stream back
2068        let reader = StreamReader::try_new(&mut file, None).unwrap();
2069
2070        reader.for_each(|batch| {
2071            let batch = batch.unwrap();
2072            assert!(
2073                batch
2074                    .column(0)
2075                    .as_any()
2076                    .downcast_ref::<Float32Array>()
2077                    .unwrap()
2078                    .value(0)
2079                    != 0.0
2080            );
2081            assert!(
2082                batch
2083                    .column(1)
2084                    .as_any()
2085                    .downcast_ref::<Float32Array>()
2086                    .unwrap()
2087                    .value(0)
2088                    != 0.0
2089            );
2090        });
2091
2092        file.rewind().unwrap();
2093
2094        // Read with projection
2095        let reader = StreamReader::try_new(file, Some(vec![0, 3])).unwrap();
2096
2097        reader.for_each(|batch| {
2098            let batch = batch.unwrap();
2099            assert_eq!(batch.schema().fields().len(), 2);
2100            assert_eq!(batch.schema().fields()[0].data_type(), &DataType::Float32);
2101            assert_eq!(batch.schema().fields()[1].data_type(), &DataType::Int32);
2102        });
2103    }
2104
2105    /// Write the record batch to an in-memory buffer in IPC File format
2106    fn write_ipc(rb: &RecordBatch) -> Vec<u8> {
2107        let mut buf = Vec::new();
2108        let mut writer = crate::writer::FileWriter::try_new(&mut buf, rb.schema_ref()).unwrap();
2109        writer.write(rb).unwrap();
2110        writer.finish().unwrap();
2111        buf
2112    }
2113
2114    /// Return the first record batch read from the IPC File buffer
2115    fn read_ipc(buf: &[u8]) -> Result<RecordBatch, ArrowError> {
2116        let mut reader = FileReader::try_new(std::io::Cursor::new(buf), None)?;
2117        reader.next().unwrap()
2118    }
2119
2120    /// Return the first record batch read from the IPC File buffer, disabling
2121    /// validation
2122    fn read_ipc_skip_validation(buf: &[u8]) -> Result<RecordBatch, ArrowError> {
2123        let mut reader = unsafe {
2124            FileReader::try_new(std::io::Cursor::new(buf), None)?.with_skip_validation(true)
2125        };
2126        reader.next().unwrap()
2127    }
2128
2129    fn roundtrip_ipc(rb: &RecordBatch) -> RecordBatch {
2130        let buf = write_ipc(rb);
2131        read_ipc(&buf).unwrap()
2132    }
2133
2134    /// Return the first record batch read from the IPC File buffer
2135    /// using the FileDecoder API
2136    fn read_ipc_with_decoder(buf: Vec<u8>) -> Result<RecordBatch, ArrowError> {
2137        read_ipc_with_decoder_inner(buf, false)
2138    }
2139
2140    /// Return the first record batch read from the IPC File buffer
2141    /// using the FileDecoder API, disabling validation
2142    fn read_ipc_with_decoder_skip_validation(buf: Vec<u8>) -> Result<RecordBatch, ArrowError> {
2143        read_ipc_with_decoder_inner(buf, true)
2144    }
2145
2146    fn read_ipc_with_decoder_inner(
2147        buf: Vec<u8>,
2148        skip_validation: bool,
2149    ) -> Result<RecordBatch, ArrowError> {
2150        let buffer = Buffer::from_vec(buf);
2151        let trailer_start = buffer.len() - 10;
2152        let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap())?;
2153        let footer = root_as_footer(&buffer[trailer_start - footer_len..trailer_start])
2154            .map_err(|e| ArrowError::InvalidArgumentError(format!("Invalid footer: {e}")))?;
2155
2156        let schema = fb_to_schema(footer.schema().unwrap());
2157
2158        let mut decoder = unsafe {
2159            FileDecoder::new(Arc::new(schema), footer.version())
2160                .with_skip_validation(skip_validation)
2161        };
2162        // Read dictionaries
2163        for block in footer.dictionaries().iter().flatten() {
2164            let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
2165            let data = buffer.slice_with_length(block.offset() as _, block_len);
2166            decoder.read_dictionary(block, &data)?
2167        }
2168
2169        // Read record batch
2170        let batches = footer.recordBatches().unwrap();
2171        assert_eq!(batches.len(), 1); // Only wrote a single batch
2172
2173        let block = batches.get(0);
2174        let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
2175        let data = buffer.slice_with_length(block.offset() as _, block_len);
2176        Ok(decoder.read_record_batch(block, &data)?.unwrap())
2177    }
2178
2179    /// Write the record batch to an in-memory buffer in IPC Stream format
2180    fn write_stream(rb: &RecordBatch) -> Vec<u8> {
2181        let mut buf = Vec::new();
2182        let mut writer = crate::writer::StreamWriter::try_new(&mut buf, rb.schema_ref()).unwrap();
2183        writer.write(rb).unwrap();
2184        writer.finish().unwrap();
2185        buf
2186    }
2187
2188    /// Return the first record batch read from the IPC Stream buffer
2189    fn read_stream(buf: &[u8]) -> Result<RecordBatch, ArrowError> {
2190        let mut reader = StreamReader::try_new(std::io::Cursor::new(buf), None)?;
2191        reader.next().unwrap()
2192    }
2193
2194    /// Return the first record batch read from the IPC Stream buffer,
2195    /// disabling validation
2196    fn read_stream_skip_validation(buf: &[u8]) -> Result<RecordBatch, ArrowError> {
2197        let mut reader = unsafe {
2198            StreamReader::try_new(std::io::Cursor::new(buf), None)?.with_skip_validation(true)
2199        };
2200        reader.next().unwrap()
2201    }
2202
2203    fn roundtrip_ipc_stream(rb: &RecordBatch) -> RecordBatch {
2204        let buf = write_stream(rb);
2205        read_stream(&buf).unwrap()
2206    }
2207
2208    #[test]
2209    fn test_roundtrip_with_custom_metadata() {
2210        let schema = Schema::new(vec![Field::new("dummy", DataType::Float64, false)]);
2211        let mut buf = Vec::new();
2212        let mut writer = crate::writer::FileWriter::try_new(&mut buf, &schema).unwrap();
2213        let mut test_metadata = HashMap::new();
2214        test_metadata.insert("abc".to_string(), "abc".to_string());
2215        test_metadata.insert("def".to_string(), "def".to_string());
2216        for (k, v) in &test_metadata {
2217            writer.write_metadata(k, v);
2218        }
2219        writer.finish().unwrap();
2220        drop(writer);
2221
2222        let reader = crate::reader::FileReader::try_new(std::io::Cursor::new(buf), None).unwrap();
2223        assert_eq!(reader.custom_metadata(), &test_metadata);
2224    }
2225
2226    #[test]
2227    fn test_roundtrip_nested_dict() {
2228        let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();
2229
2230        let array = Arc::new(inner) as ArrayRef;
2231
2232        let dctfield = Arc::new(Field::new("dict", array.data_type().clone(), false));
2233
2234        let s = StructArray::from(vec![(dctfield, array)]);
2235        let struct_array = Arc::new(s) as ArrayRef;
2236
2237        let schema = Arc::new(Schema::new(vec![Field::new(
2238            "struct",
2239            struct_array.data_type().clone(),
2240            false,
2241        )]));
2242
2243        let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap();
2244
2245        assert_eq!(batch, roundtrip_ipc(&batch));
2246    }
2247
2248    #[test]
2249    fn test_roundtrip_nested_dict_no_preserve_dict_id() {
2250        let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();
2251
2252        let array = Arc::new(inner) as ArrayRef;
2253
2254        let dctfield = Arc::new(Field::new("dict", array.data_type().clone(), false));
2255
2256        let s = StructArray::from(vec![(dctfield, array)]);
2257        let struct_array = Arc::new(s) as ArrayRef;
2258
2259        let schema = Arc::new(Schema::new(vec![Field::new(
2260            "struct",
2261            struct_array.data_type().clone(),
2262            false,
2263        )]));
2264
2265        let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap();
2266
2267        let mut buf = Vec::new();
2268        let mut writer = crate::writer::FileWriter::try_new_with_options(
2269            &mut buf,
2270            batch.schema_ref(),
2271            IpcWriteOptions::default(),
2272        )
2273        .unwrap();
2274        writer.write(&batch).unwrap();
2275        writer.finish().unwrap();
2276        drop(writer);
2277
2278        let mut reader = FileReader::try_new(std::io::Cursor::new(buf), None).unwrap();
2279
2280        assert_eq!(batch, reader.next().unwrap().unwrap());
2281    }
2282
2283    fn check_union_with_builder(mut builder: UnionBuilder) {
2284        builder.append::<Int32Type>("a", 1).unwrap();
2285        builder.append_null::<Int32Type>("a").unwrap();
2286        builder.append::<Float64Type>("c", 3.0).unwrap();
2287        builder.append::<Int32Type>("a", 4).unwrap();
2288        builder.append::<Int64Type>("d", 11).unwrap();
2289        let union = builder.build().unwrap();
2290
2291        let schema = Arc::new(Schema::new(vec![Field::new(
2292            "union",
2293            union.data_type().clone(),
2294            false,
2295        )]));
2296
2297        let union_array = Arc::new(union) as ArrayRef;
2298
2299        let rb = RecordBatch::try_new(schema, vec![union_array]).unwrap();
2300        let rb2 = roundtrip_ipc(&rb);
2301        // TODO: equality not yet implemented for union, so we check that the length of the array is
2302        // the same and that all of the buffers are the same instead.
2303        assert_eq!(rb.schema(), rb2.schema());
2304        assert_eq!(rb.num_columns(), rb2.num_columns());
2305        assert_eq!(rb.num_rows(), rb2.num_rows());
2306        let union1 = rb.column(0);
2307        let union2 = rb2.column(0);
2308
2309        assert_eq!(union1, union2);
2310    }
2311
2312    #[test]
2313    fn test_roundtrip_dense_union() {
2314        check_union_with_builder(UnionBuilder::new_dense());
2315    }
2316
2317    #[test]
2318    fn test_roundtrip_sparse_union() {
2319        check_union_with_builder(UnionBuilder::new_sparse());
2320    }
2321
2322    #[test]
2323    fn test_roundtrip_struct_empty_fields() {
2324        let nulls = NullBuffer::from(&[true, true, false]);
2325        let rb = RecordBatch::try_from_iter([(
2326            "",
2327            Arc::new(StructArray::new_empty_fields(nulls.len(), Some(nulls))) as _,
2328        )])
2329        .unwrap();
2330        let rb2 = roundtrip_ipc(&rb);
2331        assert_eq!(rb, rb2);
2332    }
2333
2334    #[test]
2335    fn test_roundtrip_stream_run_array_sliced() {
2336        let run_array_1: Int32RunArray = vec!["a", "a", "a", "b", "b", "c", "c", "c"]
2337            .into_iter()
2338            .collect();
2339        let run_array_1_sliced = run_array_1.slice(2, 5);
2340
2341        let run_array_2_inupt = vec![Some(1_i32), None, None, Some(2), Some(2)];
2342        let mut run_array_2_builder = PrimitiveRunBuilder::<Int16Type, Int32Type>::new();
2343        run_array_2_builder.extend(run_array_2_inupt);
2344        let run_array_2 = run_array_2_builder.finish();
2345
2346        let schema = Arc::new(Schema::new(vec![
2347            Field::new(
2348                "run_array_1_sliced",
2349                run_array_1_sliced.data_type().clone(),
2350                false,
2351            ),
2352            Field::new("run_array_2", run_array_2.data_type().clone(), false),
2353        ]));
2354        let input_batch = RecordBatch::try_new(
2355            schema,
2356            vec![Arc::new(run_array_1_sliced.clone()), Arc::new(run_array_2)],
2357        )
2358        .unwrap();
2359        let output_batch = roundtrip_ipc_stream(&input_batch);
2360
2361        // As partial comparison not yet supported for run arrays, the sliced run array
2362        // has to be unsliced before comparing with the output. the second run array
2363        // can be compared as such.
2364        assert_eq!(input_batch.column(1), output_batch.column(1));
2365
2366        let run_array_1_unsliced = unslice_run_array(run_array_1_sliced.into_data()).unwrap();
2367        assert_eq!(run_array_1_unsliced, output_batch.column(0).into_data());
2368    }
2369
2370    #[test]
2371    fn test_roundtrip_stream_nested_dict() {
2372        let xs = vec!["AA", "BB", "AA", "CC", "BB"];
2373        let dict = Arc::new(
2374            xs.clone()
2375                .into_iter()
2376                .collect::<DictionaryArray<Int8Type>>(),
2377        );
2378        let string_array: ArrayRef = Arc::new(StringArray::from(xs.clone()));
2379        let struct_array = StructArray::from(vec![
2380            (
2381                Arc::new(Field::new("f2.1", DataType::Utf8, false)),
2382                string_array,
2383            ),
2384            (
2385                Arc::new(Field::new("f2.2_struct", dict.data_type().clone(), false)),
2386                dict.clone() as ArrayRef,
2387            ),
2388        ]);
2389        let schema = Arc::new(Schema::new(vec![
2390            Field::new("f1_string", DataType::Utf8, false),
2391            Field::new("f2_struct", struct_array.data_type().clone(), false),
2392        ]));
2393        let input_batch = RecordBatch::try_new(
2394            schema,
2395            vec![
2396                Arc::new(StringArray::from(xs.clone())),
2397                Arc::new(struct_array),
2398            ],
2399        )
2400        .unwrap();
2401        let output_batch = roundtrip_ipc_stream(&input_batch);
2402        assert_eq!(input_batch, output_batch);
2403    }
2404
2405    #[test]
2406    fn test_roundtrip_stream_nested_dict_of_map_of_dict() {
2407        let values = StringArray::from(vec![Some("a"), None, Some("b"), Some("c")]);
2408        let values = Arc::new(values) as ArrayRef;
2409        let value_dict_keys = Int8Array::from_iter_values([0, 1, 1, 2, 3, 1]);
2410        let value_dict_array = DictionaryArray::new(value_dict_keys, values.clone());
2411
2412        let key_dict_keys = Int8Array::from_iter_values([0, 0, 2, 1, 1, 3]);
2413        let key_dict_array = DictionaryArray::new(key_dict_keys, values);
2414
2415        #[allow(deprecated)]
2416        let keys_field = Arc::new(Field::new_dict(
2417            "keys",
2418            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
2419            true, // It is technically not legal for this field to be null.
2420            1,
2421            false,
2422        ));
2423        #[allow(deprecated)]
2424        let values_field = Arc::new(Field::new_dict(
2425            "values",
2426            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
2427            true,
2428            2,
2429            false,
2430        ));
2431        let entry_struct = StructArray::from(vec![
2432            (keys_field, make_array(key_dict_array.into_data())),
2433            (values_field, make_array(value_dict_array.into_data())),
2434        ]);
2435        let map_data_type = DataType::Map(
2436            Arc::new(Field::new(
2437                "entries",
2438                entry_struct.data_type().clone(),
2439                false,
2440            )),
2441            false,
2442        );
2443
2444        let entry_offsets = Buffer::from_slice_ref([0, 2, 4, 6]);
2445        let map_data = ArrayData::builder(map_data_type)
2446            .len(3)
2447            .add_buffer(entry_offsets)
2448            .add_child_data(entry_struct.into_data())
2449            .build()
2450            .unwrap();
2451        let map_array = MapArray::from(map_data);
2452
2453        let dict_keys = Int8Array::from_iter_values([0, 1, 1, 2, 2, 1]);
2454        let dict_dict_array = DictionaryArray::new(dict_keys, Arc::new(map_array));
2455
2456        let schema = Arc::new(Schema::new(vec![Field::new(
2457            "f1",
2458            dict_dict_array.data_type().clone(),
2459            false,
2460        )]));
2461        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap();
2462        let output_batch = roundtrip_ipc_stream(&input_batch);
2463        assert_eq!(input_batch, output_batch);
2464    }
2465
2466    fn test_roundtrip_stream_dict_of_list_of_dict_impl<
2467        OffsetSize: OffsetSizeTrait,
2468        U: ArrowNativeType,
2469    >(
2470        list_data_type: DataType,
2471        offsets: &[U; 5],
2472    ) {
2473        let values = StringArray::from(vec![Some("a"), None, Some("c"), None]);
2474        let keys = Int8Array::from_iter_values([0, 0, 1, 2, 0, 1, 3]);
2475        let dict_array = DictionaryArray::new(keys, Arc::new(values));
2476        let dict_data = dict_array.to_data();
2477
2478        let value_offsets = Buffer::from_slice_ref(offsets);
2479
2480        let list_data = ArrayData::builder(list_data_type)
2481            .len(4)
2482            .add_buffer(value_offsets)
2483            .add_child_data(dict_data)
2484            .build()
2485            .unwrap();
2486        let list_array = GenericListArray::<OffsetSize>::from(list_data);
2487
2488        let keys_for_dict = Int8Array::from_iter_values([0, 3, 0, 1, 1, 2, 0, 1, 3]);
2489        let dict_dict_array = DictionaryArray::new(keys_for_dict, Arc::new(list_array));
2490
2491        let schema = Arc::new(Schema::new(vec![Field::new(
2492            "f1",
2493            dict_dict_array.data_type().clone(),
2494            false,
2495        )]));
2496        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap();
2497        let output_batch = roundtrip_ipc_stream(&input_batch);
2498        assert_eq!(input_batch, output_batch);
2499    }
2500
2501    #[test]
2502    fn test_roundtrip_stream_dict_of_list_of_dict() {
2503        // list
2504        #[allow(deprecated)]
2505        let list_data_type = DataType::List(Arc::new(Field::new_dict(
2506            "item",
2507            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
2508            true,
2509            1,
2510            false,
2511        )));
2512        let offsets: &[i32; 5] = &[0, 2, 4, 4, 6];
2513        test_roundtrip_stream_dict_of_list_of_dict_impl::<i32, i32>(list_data_type, offsets);
2514
2515        // large list
2516        #[allow(deprecated)]
2517        let list_data_type = DataType::LargeList(Arc::new(Field::new_dict(
2518            "item",
2519            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
2520            true,
2521            1,
2522            false,
2523        )));
2524        let offsets: &[i64; 5] = &[0, 2, 4, 4, 7];
2525        test_roundtrip_stream_dict_of_list_of_dict_impl::<i64, i64>(list_data_type, offsets);
2526    }
2527
2528    #[test]
2529    fn test_roundtrip_stream_dict_of_fixed_size_list_of_dict() {
2530        let values = StringArray::from(vec![Some("a"), None, Some("c"), None]);
2531        let keys = Int8Array::from_iter_values([0, 0, 1, 2, 0, 1, 3, 1, 2]);
2532        let dict_array = DictionaryArray::new(keys, Arc::new(values));
2533        let dict_data = dict_array.into_data();
2534
2535        #[allow(deprecated)]
2536        let list_data_type = DataType::FixedSizeList(
2537            Arc::new(Field::new_dict(
2538                "item",
2539                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
2540                true,
2541                1,
2542                false,
2543            )),
2544            3,
2545        );
2546        let list_data = ArrayData::builder(list_data_type)
2547            .len(3)
2548            .add_child_data(dict_data)
2549            .build()
2550            .unwrap();
2551        let list_array = FixedSizeListArray::from(list_data);
2552
2553        let keys_for_dict = Int8Array::from_iter_values([0, 1, 0, 1, 1, 2, 0, 1, 2]);
2554        let dict_dict_array = DictionaryArray::new(keys_for_dict, Arc::new(list_array));
2555
2556        let schema = Arc::new(Schema::new(vec![Field::new(
2557            "f1",
2558            dict_dict_array.data_type().clone(),
2559            false,
2560        )]));
2561        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap();
2562        let output_batch = roundtrip_ipc_stream(&input_batch);
2563        assert_eq!(input_batch, output_batch);
2564    }
2565
2566    const LONG_TEST_STRING: &str =
2567        "This is a long string to make sure binary view array handles it";
2568
2569    #[test]
2570    fn test_roundtrip_view_types() {
2571        let schema = Schema::new(vec![
2572            Field::new("field_1", DataType::BinaryView, true),
2573            Field::new("field_2", DataType::Utf8, true),
2574            Field::new("field_3", DataType::Utf8View, true),
2575        ]);
2576        let bin_values: Vec<Option<&[u8]>> = vec![
2577            Some(b"foo"),
2578            None,
2579            Some(b"bar"),
2580            Some(LONG_TEST_STRING.as_bytes()),
2581        ];
2582        let utf8_values: Vec<Option<&str>> =
2583            vec![Some("foo"), None, Some("bar"), Some(LONG_TEST_STRING)];
2584        let bin_view_array = BinaryViewArray::from_iter(bin_values);
2585        let utf8_array = StringArray::from_iter(utf8_values.iter());
2586        let utf8_view_array = StringViewArray::from_iter(utf8_values);
2587        let record_batch = RecordBatch::try_new(
2588            Arc::new(schema.clone()),
2589            vec![
2590                Arc::new(bin_view_array),
2591                Arc::new(utf8_array),
2592                Arc::new(utf8_view_array),
2593            ],
2594        )
2595        .unwrap();
2596
2597        assert_eq!(record_batch, roundtrip_ipc(&record_batch));
2598        assert_eq!(record_batch, roundtrip_ipc_stream(&record_batch));
2599
2600        let sliced_batch = record_batch.slice(1, 2);
2601        assert_eq!(sliced_batch, roundtrip_ipc(&sliced_batch));
2602        assert_eq!(sliced_batch, roundtrip_ipc_stream(&sliced_batch));
2603    }
2604
2605    #[test]
2606    fn test_roundtrip_view_types_nested_dict() {
2607        let bin_values: Vec<Option<&[u8]>> = vec![
2608            Some(b"foo"),
2609            None,
2610            Some(b"bar"),
2611            Some(LONG_TEST_STRING.as_bytes()),
2612            Some(b"field"),
2613        ];
2614        let utf8_values: Vec<Option<&str>> = vec![
2615            Some("foo"),
2616            None,
2617            Some("bar"),
2618            Some(LONG_TEST_STRING),
2619            Some("field"),
2620        ];
2621        let bin_view_array = Arc::new(BinaryViewArray::from_iter(bin_values));
2622        let utf8_view_array = Arc::new(StringViewArray::from_iter(utf8_values));
2623
2624        let key_dict_keys = Int8Array::from_iter_values([0, 0, 1, 2, 0, 1, 3]);
2625        let key_dict_array = DictionaryArray::new(key_dict_keys, utf8_view_array.clone());
2626        #[allow(deprecated)]
2627        let keys_field = Arc::new(Field::new_dict(
2628            "keys",
2629            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8View)),
2630            true,
2631            1,
2632            false,
2633        ));
2634
2635        let value_dict_keys = Int8Array::from_iter_values([0, 3, 0, 1, 2, 0, 1]);
2636        let value_dict_array = DictionaryArray::new(value_dict_keys, bin_view_array);
2637        #[allow(deprecated)]
2638        let values_field = Arc::new(Field::new_dict(
2639            "values",
2640            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::BinaryView)),
2641            true,
2642            2,
2643            false,
2644        ));
2645        let entry_struct = StructArray::from(vec![
2646            (keys_field, make_array(key_dict_array.into_data())),
2647            (values_field, make_array(value_dict_array.into_data())),
2648        ]);
2649
2650        let map_data_type = DataType::Map(
2651            Arc::new(Field::new(
2652                "entries",
2653                entry_struct.data_type().clone(),
2654                false,
2655            )),
2656            false,
2657        );
2658        let entry_offsets = Buffer::from_slice_ref([0, 2, 4, 7]);
2659        let map_data = ArrayData::builder(map_data_type)
2660            .len(3)
2661            .add_buffer(entry_offsets)
2662            .add_child_data(entry_struct.into_data())
2663            .build()
2664            .unwrap();
2665        let map_array = MapArray::from(map_data);
2666
2667        let dict_keys = Int8Array::from_iter_values([0, 1, 0, 1, 1, 2, 0, 1, 2]);
2668        let dict_dict_array = DictionaryArray::new(dict_keys, Arc::new(map_array));
2669        let schema = Arc::new(Schema::new(vec![Field::new(
2670            "f1",
2671            dict_dict_array.data_type().clone(),
2672            false,
2673        )]));
2674        let batch = RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap();
2675        assert_eq!(batch, roundtrip_ipc(&batch));
2676        assert_eq!(batch, roundtrip_ipc_stream(&batch));
2677
2678        let sliced_batch = batch.slice(1, 2);
2679        assert_eq!(sliced_batch, roundtrip_ipc(&sliced_batch));
2680        assert_eq!(sliced_batch, roundtrip_ipc_stream(&sliced_batch));
2681    }
2682
2683    #[test]
2684    fn test_no_columns_batch() {
2685        let schema = Arc::new(Schema::empty());
2686        let options = RecordBatchOptions::new()
2687            .with_match_field_names(true)
2688            .with_row_count(Some(10));
2689        let input_batch = RecordBatch::try_new_with_options(schema, vec![], &options).unwrap();
2690        let output_batch = roundtrip_ipc_stream(&input_batch);
2691        assert_eq!(input_batch, output_batch);
2692    }
2693
2694    #[test]
2695    fn test_unaligned() {
2696        let batch = RecordBatch::try_from_iter(vec![(
2697            "i32",
2698            Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as _,
2699        )])
2700        .unwrap();
2701
2702        let gen = IpcDataGenerator {};
2703        let mut dict_tracker = DictionaryTracker::new(false);
2704        let (_, encoded) = gen
2705            .encoded_batch(&batch, &mut dict_tracker, &Default::default())
2706            .unwrap();
2707
2708        let message = root_as_message(&encoded.ipc_message).unwrap();
2709
2710        // Construct an unaligned buffer
2711        let mut buffer = MutableBuffer::with_capacity(encoded.arrow_data.len() + 1);
2712        buffer.push(0_u8);
2713        buffer.extend_from_slice(&encoded.arrow_data);
2714        let b = Buffer::from(buffer).slice(1);
2715        assert_ne!(b.as_ptr().align_offset(8), 0);
2716
2717        let ipc_batch = message.header_as_record_batch().unwrap();
2718        let roundtrip = RecordBatchDecoder::try_new(
2719            &b,
2720            ipc_batch,
2721            batch.schema(),
2722            &Default::default(),
2723            &message.version(),
2724        )
2725        .unwrap()
2726        .with_require_alignment(false)
2727        .read_record_batch()
2728        .unwrap();
2729        assert_eq!(batch, roundtrip);
2730    }
2731
2732    #[test]
2733    fn test_unaligned_throws_error_with_require_alignment() {
2734        let batch = RecordBatch::try_from_iter(vec![(
2735            "i32",
2736            Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as _,
2737        )])
2738        .unwrap();
2739
2740        let gen = IpcDataGenerator {};
2741        let mut dict_tracker = DictionaryTracker::new(false);
2742        let (_, encoded) = gen
2743            .encoded_batch(&batch, &mut dict_tracker, &Default::default())
2744            .unwrap();
2745
2746        let message = root_as_message(&encoded.ipc_message).unwrap();
2747
2748        // Construct an unaligned buffer
2749        let mut buffer = MutableBuffer::with_capacity(encoded.arrow_data.len() + 1);
2750        buffer.push(0_u8);
2751        buffer.extend_from_slice(&encoded.arrow_data);
2752        let b = Buffer::from(buffer).slice(1);
2753        assert_ne!(b.as_ptr().align_offset(8), 0);
2754
2755        let ipc_batch = message.header_as_record_batch().unwrap();
2756        let result = RecordBatchDecoder::try_new(
2757            &b,
2758            ipc_batch,
2759            batch.schema(),
2760            &Default::default(),
2761            &message.version(),
2762        )
2763        .unwrap()
2764        .with_require_alignment(true)
2765        .read_record_batch();
2766
2767        let error = result.unwrap_err();
2768        assert_eq!(
2769            error.to_string(),
2770            "Invalid argument error: Misaligned buffers[0] in array of type Int32, \
2771             offset from expected alignment of 4 by 1"
2772        );
2773    }
2774
2775    #[test]
2776    fn test_file_with_massive_column_count() {
2777        // 499_999 is upper limit for default settings (1_000_000)
2778        let limit = 600_000;
2779
2780        let fields = (0..limit)
2781            .map(|i| Field::new(format!("{i}"), DataType::Boolean, false))
2782            .collect::<Vec<_>>();
2783        let schema = Arc::new(Schema::new(fields));
2784        let batch = RecordBatch::new_empty(schema);
2785
2786        let mut buf = Vec::new();
2787        let mut writer = crate::writer::FileWriter::try_new(&mut buf, batch.schema_ref()).unwrap();
2788        writer.write(&batch).unwrap();
2789        writer.finish().unwrap();
2790        drop(writer);
2791
2792        let mut reader = FileReaderBuilder::new()
2793            .with_max_footer_fb_tables(1_500_000)
2794            .build(std::io::Cursor::new(buf))
2795            .unwrap();
2796        let roundtrip_batch = reader.next().unwrap().unwrap();
2797
2798        assert_eq!(batch, roundtrip_batch);
2799    }
2800
2801    #[test]
2802    fn test_file_with_deeply_nested_columns() {
2803        // 60 is upper limit for default settings (64)
2804        let limit = 61;
2805
2806        let fields = (0..limit).fold(
2807            vec![Field::new("leaf", DataType::Boolean, false)],
2808            |field, index| vec![Field::new_struct(format!("{index}"), field, false)],
2809        );
2810        let schema = Arc::new(Schema::new(fields));
2811        let batch = RecordBatch::new_empty(schema);
2812
2813        let mut buf = Vec::new();
2814        let mut writer = crate::writer::FileWriter::try_new(&mut buf, batch.schema_ref()).unwrap();
2815        writer.write(&batch).unwrap();
2816        writer.finish().unwrap();
2817        drop(writer);
2818
2819        let mut reader = FileReaderBuilder::new()
2820            .with_max_footer_fb_depth(65)
2821            .build(std::io::Cursor::new(buf))
2822            .unwrap();
2823        let roundtrip_batch = reader.next().unwrap().unwrap();
2824
2825        assert_eq!(batch, roundtrip_batch);
2826    }
2827
2828    #[test]
2829    fn test_invalid_struct_array_ipc_read_errors() {
2830        let a_field = Field::new("a", DataType::Int32, false);
2831        let b_field = Field::new("b", DataType::Int32, false);
2832        let struct_fields = Fields::from(vec![a_field.clone(), b_field.clone()]);
2833
2834        let a_array_data = ArrayData::builder(a_field.data_type().clone())
2835            .len(4)
2836            .add_buffer(Buffer::from_slice_ref([1, 2, 3, 4]))
2837            .build()
2838            .unwrap();
2839        let b_array_data = ArrayData::builder(b_field.data_type().clone())
2840            .len(3)
2841            .add_buffer(Buffer::from_slice_ref([5, 6, 7]))
2842            .build()
2843            .unwrap();
2844
2845        let invalid_struct_arr = unsafe {
2846            StructArray::new_unchecked(
2847                struct_fields,
2848                vec![make_array(a_array_data), make_array(b_array_data)],
2849                None,
2850            )
2851        };
2852
2853        expect_ipc_validation_error(
2854            Arc::new(invalid_struct_arr),
2855            "Invalid argument error: Incorrect array length for StructArray field \"b\", expected 4 got 3",
2856        );
2857    }
2858
2859    #[test]
2860    fn test_invalid_nested_array_ipc_read_errors() {
2861        // one of the nested arrays has invalid data
2862        let a_field = Field::new("a", DataType::Int32, false);
2863        let b_field = Field::new("b", DataType::Utf8, false);
2864
2865        let schema = Arc::new(Schema::new(vec![Field::new_struct(
2866            "s",
2867            vec![a_field.clone(), b_field.clone()],
2868            false,
2869        )]));
2870
2871        let a_array_data = ArrayData::builder(a_field.data_type().clone())
2872            .len(4)
2873            .add_buffer(Buffer::from_slice_ref([1, 2, 3, 4]))
2874            .build()
2875            .unwrap();
2876        // invalid nested child array -- length is correct, but has invalid utf8 data
2877        let b_array_data = {
2878            let valid: &[u8] = b"   ";
2879            let mut invalid = vec![];
2880            invalid.extend_from_slice(b"ValidString");
2881            invalid.extend_from_slice(INVALID_UTF8_FIRST_CHAR);
2882            let binary_array =
2883                BinaryArray::from_iter(vec![None, Some(valid), None, Some(&invalid)]);
2884            let array = unsafe {
2885                StringArray::new_unchecked(
2886                    binary_array.offsets().clone(),
2887                    binary_array.values().clone(),
2888                    binary_array.nulls().cloned(),
2889                )
2890            };
2891            array.into_data()
2892        };
2893        let struct_data_type = schema.field(0).data_type();
2894
2895        let invalid_struct_arr = unsafe {
2896            make_array(
2897                ArrayData::builder(struct_data_type.clone())
2898                    .len(4)
2899                    .add_child_data(a_array_data)
2900                    .add_child_data(b_array_data)
2901                    .build_unchecked(),
2902            )
2903        };
2904        expect_ipc_validation_error(
2905            Arc::new(invalid_struct_arr),
2906            "Invalid argument error: Invalid UTF8 sequence at string index 3 (3..18): invalid utf-8 sequence of 1 bytes from index 11",
2907        );
2908    }
2909
2910    #[test]
2911    fn test_same_dict_id_without_preserve() {
2912        let batch = RecordBatch::try_new(
2913            Arc::new(Schema::new(
2914                ["a", "b"]
2915                    .iter()
2916                    .map(|name| {
2917                        #[allow(deprecated)]
2918                        Field::new_dict(
2919                            name.to_string(),
2920                            DataType::Dictionary(
2921                                Box::new(DataType::Int32),
2922                                Box::new(DataType::Utf8),
2923                            ),
2924                            true,
2925                            0,
2926                            false,
2927                        )
2928                    })
2929                    .collect::<Vec<Field>>(),
2930            )),
2931            vec![
2932                Arc::new(
2933                    vec![Some("c"), Some("d")]
2934                        .into_iter()
2935                        .collect::<DictionaryArray<Int32Type>>(),
2936                ) as ArrayRef,
2937                Arc::new(
2938                    vec![Some("e"), Some("f")]
2939                        .into_iter()
2940                        .collect::<DictionaryArray<Int32Type>>(),
2941                ) as ArrayRef,
2942            ],
2943        )
2944        .expect("Failed to create RecordBatch");
2945
2946        // serialize the record batch as an IPC stream
2947        let mut buf = vec![];
2948        {
2949            let mut writer = crate::writer::StreamWriter::try_new_with_options(
2950                &mut buf,
2951                batch.schema().as_ref(),
2952                crate::writer::IpcWriteOptions::default(),
2953            )
2954            .expect("Failed to create StreamWriter");
2955            writer.write(&batch).expect("Failed to write RecordBatch");
2956            writer.finish().expect("Failed to finish StreamWriter");
2957        }
2958
2959        StreamReader::try_new(std::io::Cursor::new(buf), None)
2960            .expect("Failed to create StreamReader")
2961            .for_each(|decoded_batch| {
2962                assert_eq!(decoded_batch.expect("Failed to read RecordBatch"), batch);
2963            });
2964    }
2965
2966    #[test]
2967    fn test_validation_of_invalid_list_array() {
2968        // ListArray with invalid offsets
2969        let array = unsafe {
2970            let values = Int32Array::from(vec![1, 2, 3]);
2971            let bad_offsets = ScalarBuffer::<i32>::from(vec![0, 2, 4, 2]); // offsets can't go backwards
2972            let offsets = OffsetBuffer::new_unchecked(bad_offsets); // INVALID array created
2973            let field = Field::new_list_field(DataType::Int32, true);
2974            let nulls = None;
2975            ListArray::new(Arc::new(field), offsets, Arc::new(values), nulls)
2976        };
2977
2978        expect_ipc_validation_error(
2979            Arc::new(array),
2980            "Invalid argument error: Offset invariant failure: offset at position 2 out of bounds: 4 > 2"
2981        );
2982    }
2983
2984    #[test]
2985    fn test_validation_of_invalid_string_array() {
2986        let valid: &[u8] = b"   ";
2987        let mut invalid = vec![];
2988        invalid.extend_from_slice(b"ThisStringIsCertainlyLongerThan12Bytes");
2989        invalid.extend_from_slice(INVALID_UTF8_FIRST_CHAR);
2990        let binary_array = BinaryArray::from_iter(vec![None, Some(valid), None, Some(&invalid)]);
2991        // data is not valid utf8 we can not construct a correct StringArray
2992        // safely, so purposely create an invalid StringArray
2993        let array = unsafe {
2994            StringArray::new_unchecked(
2995                binary_array.offsets().clone(),
2996                binary_array.values().clone(),
2997                binary_array.nulls().cloned(),
2998            )
2999        };
3000        expect_ipc_validation_error(
3001            Arc::new(array),
3002            "Invalid argument error: Invalid UTF8 sequence at string index 3 (3..45): invalid utf-8 sequence of 1 bytes from index 38"
3003        );
3004    }
3005
3006    #[test]
3007    fn test_validation_of_invalid_string_view_array() {
3008        let valid: &[u8] = b"   ";
3009        let mut invalid = vec![];
3010        invalid.extend_from_slice(b"ThisStringIsCertainlyLongerThan12Bytes");
3011        invalid.extend_from_slice(INVALID_UTF8_FIRST_CHAR);
3012        let binary_view_array =
3013            BinaryViewArray::from_iter(vec![None, Some(valid), None, Some(&invalid)]);
3014        // data is not valid utf8 we can not construct a correct StringArray
3015        // safely, so purposely create an invalid StringArray
3016        let array = unsafe {
3017            StringViewArray::new_unchecked(
3018                binary_view_array.views().clone(),
3019                binary_view_array.data_buffers().to_vec(),
3020                binary_view_array.nulls().cloned(),
3021            )
3022        };
3023        expect_ipc_validation_error(
3024            Arc::new(array),
3025            "Invalid argument error: Encountered non-UTF-8 data at index 3: invalid utf-8 sequence of 1 bytes from index 38"
3026        );
3027    }
3028
3029    /// return an invalid dictionary array (key is larger than values)
3030    /// ListArray with invalid offsets
3031    #[test]
3032    fn test_validation_of_invalid_dictionary_array() {
3033        let array = unsafe {
3034            let values = StringArray::from_iter_values(["a", "b", "c"]);
3035            let keys = Int32Array::from(vec![1, 200]); // keys are not valid for values
3036            DictionaryArray::new_unchecked(keys, Arc::new(values))
3037        };
3038
3039        expect_ipc_validation_error(
3040            Arc::new(array),
3041            "Invalid argument error: Value at position 1 out of bounds: 200 (should be in [0, 2])",
3042        );
3043    }
3044
3045    #[test]
3046    fn test_validation_of_invalid_union_array() {
3047        let array = unsafe {
3048            let fields = UnionFields::new(
3049                vec![1, 3], // typeids : type id 2 is not valid
3050                vec![
3051                    Field::new("a", DataType::Int32, false),
3052                    Field::new("b", DataType::Utf8, false),
3053                ],
3054            );
3055            let type_ids = ScalarBuffer::from(vec![1i8, 2, 3]); // 2 is invalid
3056            let offsets = None;
3057            let children: Vec<ArrayRef> = vec![
3058                Arc::new(Int32Array::from(vec![10, 20, 30])),
3059                Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("c")])),
3060            ];
3061
3062            UnionArray::new_unchecked(fields, type_ids, offsets, children)
3063        };
3064
3065        expect_ipc_validation_error(
3066            Arc::new(array),
3067            "Invalid argument error: Type Ids values must match one of the field type ids",
3068        );
3069    }
3070
3071    /// Invalid Utf-8 sequence in the first character
3072    /// <https://stackoverflow.com/questions/1301402/example-invalid-utf8-string>
3073    const INVALID_UTF8_FIRST_CHAR: &[u8] = &[0xa0, 0xa1, 0x20, 0x20];
3074
3075    /// Expect an error when reading the record batch using IPC or IPC Streams
3076    fn expect_ipc_validation_error(array: ArrayRef, expected_err: &str) {
3077        let rb = RecordBatch::try_from_iter([("a", array)]).unwrap();
3078
3079        // IPC Stream format
3080        let buf = write_stream(&rb); // write is ok
3081        read_stream_skip_validation(&buf).unwrap();
3082        let err = read_stream(&buf).unwrap_err();
3083        assert_eq!(err.to_string(), expected_err);
3084
3085        // IPC File format
3086        let buf = write_ipc(&rb); // write is ok
3087        read_ipc_skip_validation(&buf).unwrap();
3088        let err = read_ipc(&buf).unwrap_err();
3089        assert_eq!(err.to_string(), expected_err);
3090
3091        // IPC Format with FileDecoder
3092        read_ipc_with_decoder_skip_validation(buf.clone()).unwrap();
3093        let err = read_ipc_with_decoder(buf).unwrap_err();
3094        assert_eq!(err.to_string(), expected_err);
3095    }
3096
3097    #[test]
3098    fn test_roundtrip_schema() {
3099        let schema = Schema::new(vec![
3100            Field::new(
3101                "a",
3102                DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
3103                false,
3104            ),
3105            Field::new(
3106                "b",
3107                DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
3108                false,
3109            ),
3110        ]);
3111
3112        let options = IpcWriteOptions::default();
3113        let data_gen = IpcDataGenerator::default();
3114        let mut dict_tracker = DictionaryTracker::new(false);
3115        let encoded_data =
3116            data_gen.schema_to_bytes_with_dictionary_tracker(&schema, &mut dict_tracker, &options);
3117        let mut schema_bytes = vec![];
3118        write_message(&mut schema_bytes, encoded_data, &options).expect("write_message");
3119
3120        let begin_offset: usize = if schema_bytes[0..4].eq(&CONTINUATION_MARKER) {
3121            4
3122        } else {
3123            0
3124        };
3125
3126        size_prefixed_root_as_message(&schema_bytes[begin_offset..])
3127            .expect_err("size_prefixed_root_as_message");
3128
3129        let msg = parse_message(&schema_bytes).expect("parse_message");
3130        let ipc_schema = msg.header_as_schema().expect("header_as_schema");
3131        let new_schema = fb_to_schema(ipc_schema);
3132
3133        assert_eq!(schema, new_schema);
3134    }
3135
3136    #[test]
3137    fn test_negative_meta_len() {
3138        let bytes = i32::to_le_bytes(-1);
3139        let mut buf = vec![];
3140        buf.extend(CONTINUATION_MARKER);
3141        buf.extend(bytes);
3142
3143        let reader = StreamReader::try_new(Cursor::new(buf), None);
3144        assert!(reader.is_err());
3145    }
3146}