arrow_ipc/reader.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Arrow IPC File and Stream Readers
//!
//! # Notes
//!
//! The [`FileReader`] and [`StreamReader`] have similar interfaces;
//! however, the [`FileReader`] expects a reader that supports [`Seek`]ing.
//!
//! [`Seek`]: std::io::Seek

mod stream;

pub use stream::*;

use flatbuffers::{VectorIter, VerifierOptions};
use std::collections::{HashMap, VecDeque};
use std::fmt;
use std::io::{BufReader, Read, Seek, SeekFrom};
use std::sync::Arc;

use arrow_array::*;
use arrow_buffer::{ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, ScalarBuffer};
use arrow_data::{ArrayData, ArrayDataBuilder, UnsafeFlag};
use arrow_schema::*;

use crate::compression::CompressionCodec;
use crate::{Block, FieldNode, Message, MetadataVersion, CONTINUATION_MARKER};
use DataType::*;

/// Read a buffer based on offset and length
/// From <https://github.com/apache/arrow/blob/6a936c4ff5007045e86f65f1a6b6c3c955ad5103/format/Message.fbs#L58>
/// Each constituent buffer is first compressed with the indicated
/// compressor, and then written with the uncompressed length in the first 8
/// bytes as a 64-bit little-endian signed integer followed by the compressed
/// buffer bytes (and then padding as required by the protocol). The
/// uncompressed length may be set to -1 to indicate that the data that
/// follows is not compressed, which can be useful for cases where
/// compression does not yield appreciable savings.
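///
/// A sketch of the resulting buffer layout when a compression codec is set
/// (widths are illustrative, not taken from a real file):
///
/// ```text
/// +------------------------------+-------------------+---------+
/// | uncompressed length (i64 LE) | compressed bytes  | padding |
/// +------------------------------+-------------------+---------+
/// ```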
fn read_buffer(
    buf: &crate::Buffer,
    a_data: &Buffer,
    compression_codec: Option<CompressionCodec>,
) -> Result<Buffer, ArrowError> {
    let start_offset = buf.offset() as usize;
    let buf_data = a_data.slice_with_length(start_offset, buf.length() as usize);
    // corner case: empty buffer
    match (buf_data.is_empty(), compression_codec) {
        (true, _) | (_, None) => Ok(buf_data),
        (false, Some(decompressor)) => decompressor.decompress_to_buffer(&buf_data),
    }
}

impl RecordBatchDecoder<'_> {
    /// Coordinates reading arrays based on data types.
    ///
    /// `variadic_counts` encodes the number of buffers to read for variadic types
    /// (e.g., Utf8View, BinaryView). When we encounter such a type, we pop from the
    /// front of the queue to get the number of buffers to read.
    ///
    /// Notes:
    /// * In the IPC format, null buffers are always set, but may be empty. We discard them if an array has 0 nulls.
    /// * Numeric values inside list arrays are often stored as 64-bit values regardless of their data type size.
    ///   We thus:
    ///     - check if the bit width of non-64-bit numbers is 64, and
    ///     - read the buffer as 64-bit (signed integer or float), and
    ///     - cast the 64-bit array to the appropriate data type
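    ///
    /// For example (illustrative values only): given fields
    /// `[Utf8View, Int32, BinaryView]` and `variadic_counts = [2, 3]`, the
    /// Utf8View column reads 2 data buffers (plus its view and null buffers)
    /// and the BinaryView column reads 3 (plus its own two).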
    fn create_array(
        &mut self,
        field: &Field,
        variadic_counts: &mut VecDeque<i64>,
    ) -> Result<ArrayRef, ArrowError> {
        let data_type = field.data_type();
        match data_type {
            Utf8 | Binary | LargeBinary | LargeUtf8 => {
                let field_node = self.next_node(field)?;
                let buffers = [
                    self.next_buffer()?,
                    self.next_buffer()?,
                    self.next_buffer()?,
                ];
                self.create_primitive_array(field_node, data_type, &buffers)
            }
            BinaryView | Utf8View => {
                let count = variadic_counts
                    .pop_front()
                    .ok_or(ArrowError::IpcError(format!(
                        "Missing variadic count for {data_type} column"
                    )))?;
                let count = count + 2; // view and null buffer.
                let buffers = (0..count)
                    .map(|_| self.next_buffer())
                    .collect::<Result<Vec<_>, _>>()?;
                let field_node = self.next_node(field)?;
                self.create_primitive_array(field_node, data_type, &buffers)
            }
            FixedSizeBinary(_) => {
                let field_node = self.next_node(field)?;
                let buffers = [self.next_buffer()?, self.next_buffer()?];
                self.create_primitive_array(field_node, data_type, &buffers)
            }
            List(ref list_field) | LargeList(ref list_field) | Map(ref list_field, _) => {
                let list_node = self.next_node(field)?;
                let list_buffers = [self.next_buffer()?, self.next_buffer()?];
                let values = self.create_array(list_field, variadic_counts)?;
                self.create_list_array(list_node, data_type, &list_buffers, values)
            }
            FixedSizeList(ref list_field, _) => {
                let list_node = self.next_node(field)?;
                let list_buffers = [self.next_buffer()?];
                let values = self.create_array(list_field, variadic_counts)?;
                self.create_list_array(list_node, data_type, &list_buffers, values)
            }
            Struct(struct_fields) => {
                let struct_node = self.next_node(field)?;
                let null_buffer = self.next_buffer()?;

                // read the arrays for each field
                let mut struct_arrays = vec![];
                // TODO investigate whether just knowing the number of buffers could
                // still work
                for struct_field in struct_fields {
                    let child = self.create_array(struct_field, variadic_counts)?;
                    struct_arrays.push(child);
                }
                self.create_struct_array(struct_node, null_buffer, struct_fields, struct_arrays)
            }
            RunEndEncoded(run_ends_field, values_field) => {
                let run_node = self.next_node(field)?;
                let run_ends = self.create_array(run_ends_field, variadic_counts)?;
                let values = self.create_array(values_field, variadic_counts)?;

                let run_array_length = run_node.length() as usize;
                let builder = ArrayData::builder(data_type.clone())
                    .len(run_array_length)
                    .offset(0)
                    .add_child_data(run_ends.into_data())
                    .add_child_data(values.into_data());
                self.create_array_from_builder(builder)
            }
            // Create dictionary array from RecordBatch
            Dictionary(_, _) => {
                let index_node = self.next_node(field)?;
                let index_buffers = [self.next_buffer()?, self.next_buffer()?];

                #[allow(deprecated)]
                let dict_id = field.dict_id().ok_or_else(|| {
                    ArrowError::ParseError(format!("Field {field} does not have dict id"))
                })?;

                let value_array = self.dictionaries_by_id.get(&dict_id).ok_or_else(|| {
                    ArrowError::ParseError(format!(
                        "Cannot find a dictionary batch with dict id: {dict_id}"
                    ))
                })?;

                self.create_dictionary_array(
                    index_node,
                    data_type,
                    &index_buffers,
                    value_array.clone(),
                )
            }
            Union(fields, mode) => {
                let union_node = self.next_node(field)?;
                let len = union_node.length() as usize;

                // In V4, union types have a validity bitmap.
                // In V5 and later, union types have no validity bitmap.
                if self.version < MetadataVersion::V5 {
                    self.next_buffer()?;
                }

                let type_ids: ScalarBuffer<i8> =
                    self.next_buffer()?.slice_with_length(0, len).into();

                let value_offsets = match mode {
                    UnionMode::Dense => {
                        let offsets: ScalarBuffer<i32> =
                            self.next_buffer()?.slice_with_length(0, len * 4).into();
                        Some(offsets)
                    }
                    UnionMode::Sparse => None,
                };

                let mut children = Vec::with_capacity(fields.len());

                for (_id, field) in fields.iter() {
                    let child = self.create_array(field, variadic_counts)?;
                    children.push(child);
                }

                let array = if self.skip_validation.get() {
                    // safety: flag can only be set via unsafe code
                    unsafe {
                        UnionArray::new_unchecked(fields.clone(), type_ids, value_offsets, children)
                    }
                } else {
                    UnionArray::try_new(fields.clone(), type_ids, value_offsets, children)?
                };
                Ok(Arc::new(array))
            }
            Null => {
                let node = self.next_node(field)?;
                let length = node.length();
                let null_count = node.null_count();

                if length != null_count {
                    return Err(ArrowError::SchemaError(format!(
                        "Field {field} of NullArray has unequal null_count {null_count} and len {length}"
                    )));
                }

                let builder = ArrayData::builder(data_type.clone())
                    .len(length as usize)
                    .offset(0);
                self.create_array_from_builder(builder)
            }
            _ => {
                let field_node = self.next_node(field)?;
                let buffers = [self.next_buffer()?, self.next_buffer()?];
                self.create_primitive_array(field_node, data_type, &buffers)
            }
        }
    }

    /// Reads the correct number of buffers based on data type and null_count, and creates a
    /// primitive array ref
    fn create_primitive_array(
        &self,
        field_node: &FieldNode,
        data_type: &DataType,
        buffers: &[Buffer],
    ) -> Result<ArrayRef, ArrowError> {
        let length = field_node.length() as usize;
        let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
        let builder = match data_type {
            Utf8 | Binary | LargeBinary | LargeUtf8 => {
                // read 3 buffers: null buffer (optional), offsets buffer and data buffer
                ArrayData::builder(data_type.clone())
                    .len(length)
                    .buffers(buffers[1..3].to_vec())
                    .null_bit_buffer(null_buffer)
            }
            BinaryView | Utf8View => ArrayData::builder(data_type.clone())
                .len(length)
                .buffers(buffers[1..].to_vec())
                .null_bit_buffer(null_buffer),
            _ if data_type.is_primitive() || matches!(data_type, Boolean | FixedSizeBinary(_)) => {
                // read 2 buffers: null buffer (optional) and data buffer
                ArrayData::builder(data_type.clone())
                    .len(length)
                    .add_buffer(buffers[1].clone())
                    .null_bit_buffer(null_buffer)
            }
            t => unreachable!("Data type {:?} either unsupported or not primitive", t),
        };

        self.create_array_from_builder(builder)
    }

    /// Update the ArrayDataBuilder based on settings in this decoder
    fn create_array_from_builder(&self, builder: ArrayDataBuilder) -> Result<ArrayRef, ArrowError> {
        let mut builder = builder.align_buffers(!self.require_alignment);
        if self.skip_validation.get() {
            // SAFETY: flag can only be set via unsafe code
            unsafe { builder = builder.skip_validation(true) }
        };
        Ok(make_array(builder.build()?))
    }

    /// Reads the correct number of buffers based on list type and null_count, and creates a
    /// list array ref
    fn create_list_array(
        &self,
        field_node: &FieldNode,
        data_type: &DataType,
        buffers: &[Buffer],
        child_array: ArrayRef,
    ) -> Result<ArrayRef, ArrowError> {
        let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
        let length = field_node.length() as usize;
        let child_data = child_array.into_data();
        let builder = match data_type {
            List(_) | LargeList(_) | Map(_, _) => ArrayData::builder(data_type.clone())
                .len(length)
                .add_buffer(buffers[1].clone())
                .add_child_data(child_data)
                .null_bit_buffer(null_buffer),

            FixedSizeList(_, _) => ArrayData::builder(data_type.clone())
                .len(length)
                .add_child_data(child_data)
                .null_bit_buffer(null_buffer),

            _ => unreachable!("Cannot create list or map array from {:?}", data_type),
        };

        self.create_array_from_builder(builder)
    }

    fn create_struct_array(
        &self,
        struct_node: &FieldNode,
        null_buffer: Buffer,
        struct_fields: &Fields,
        struct_arrays: Vec<ArrayRef>,
    ) -> Result<ArrayRef, ArrowError> {
        let null_count = struct_node.null_count() as usize;
        let len = struct_node.length() as usize;

        let nulls = (null_count > 0).then(|| BooleanBuffer::new(null_buffer, 0, len).into());
        if struct_arrays.is_empty() {
            // `StructArray::from` can't infer the correct row count
            // if we have zero fields
            return Ok(Arc::new(StructArray::new_empty_fields(len, nulls)));
        }

        let struct_array = if self.skip_validation.get() {
            // safety: flag can only be set via unsafe code
            unsafe { StructArray::new_unchecked(struct_fields.clone(), struct_arrays, nulls) }
        } else {
            StructArray::try_new(struct_fields.clone(), struct_arrays, nulls)?
        };

        Ok(Arc::new(struct_array))
    }

    /// Reads the correct number of buffers based on the dictionary type and null_count,
    /// and creates a dictionary array ref
    fn create_dictionary_array(
        &self,
        field_node: &FieldNode,
        data_type: &DataType,
        buffers: &[Buffer],
        value_array: ArrayRef,
    ) -> Result<ArrayRef, ArrowError> {
        if let Dictionary(_, _) = *data_type {
            let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
            let builder = ArrayData::builder(data_type.clone())
                .len(field_node.length() as usize)
                .add_buffer(buffers[1].clone())
                .add_child_data(value_array.into_data())
                .null_bit_buffer(null_buffer);
            self.create_array_from_builder(builder)
        } else {
            unreachable!("Cannot create dictionary array from {:?}", data_type)
        }
    }
}

/// State for decoding Arrow arrays from an [IPC RecordBatch] structure to
/// [`RecordBatch`]
///
/// [IPC RecordBatch]: crate::RecordBatch
struct RecordBatchDecoder<'a> {
    /// The flatbuffers encoded record batch
    batch: crate::RecordBatch<'a>,
    /// The output schema
    schema: SchemaRef,
    /// Decoded dictionaries indexed by dictionary id
    dictionaries_by_id: &'a HashMap<i64, ArrayRef>,
    /// Optional compression codec
    compression: Option<CompressionCodec>,
    /// The format version
    version: MetadataVersion,
    /// The raw data buffer
    data: &'a Buffer,
    /// The fields comprising this array
    nodes: VectorIter<'a, FieldNode>,
    /// The buffers comprising this array
    buffers: VectorIter<'a, crate::Buffer>,
    /// Projection (subset of columns) to read, if any
    /// See [`RecordBatchDecoder::with_projection`] for details
    projection: Option<&'a [usize]>,
    /// Are buffers required to already be aligned? See
    /// [`RecordBatchDecoder::with_require_alignment`] for details
    require_alignment: bool,
    /// Should validation be skipped when reading data? Defaults to false.
    ///
    /// See [`FileDecoder::with_skip_validation`] for details.
    skip_validation: UnsafeFlag,
}

impl<'a> RecordBatchDecoder<'a> {
    /// Create a reader for decoding arrays from an encoded [`RecordBatch`]
    fn try_new(
        buf: &'a Buffer,
        batch: crate::RecordBatch<'a>,
        schema: SchemaRef,
        dictionaries_by_id: &'a HashMap<i64, ArrayRef>,
        metadata: &'a MetadataVersion,
    ) -> Result<Self, ArrowError> {
        let buffers = batch.buffers().ok_or_else(|| {
            ArrowError::IpcError("Unable to get buffers from IPC RecordBatch".to_string())
        })?;
        let field_nodes = batch.nodes().ok_or_else(|| {
            ArrowError::IpcError("Unable to get field nodes from IPC RecordBatch".to_string())
        })?;

        let batch_compression = batch.compression();
        let compression = batch_compression
            .map(|batch_compression| batch_compression.codec().try_into())
            .transpose()?;

        Ok(Self {
            batch,
            schema,
            dictionaries_by_id,
            compression,
            version: *metadata,
            data: buf,
            nodes: field_nodes.iter(),
            buffers: buffers.iter(),
            projection: None,
            require_alignment: false,
            skip_validation: UnsafeFlag::new(),
        })
    }

    /// Set the projection (default: None)
    ///
    /// If set, the projection is the list of column indices
    /// that will be read
    pub fn with_projection(mut self, projection: Option<&'a [usize]>) -> Self {
        self.projection = projection;
        self
    }

    /// Set require_alignment (default: false)
    ///
    /// If true, buffers must be aligned appropriately or an error
    /// will result. If false, buffers will be copied to aligned buffers
    /// if necessary.
    pub fn with_require_alignment(mut self, require_alignment: bool) -> Self {
        self.require_alignment = require_alignment;
        self
    }

    /// Specifies if validation should be skipped when reading data (defaults to `false`)
    ///
    /// Note this API is somewhat "funky" as it allows the caller to skip validation
    /// without having to use `unsafe` code. If this is ever made public
    /// it should be made clearer that it is potentially unsafe, by
    /// using an `unsafe` function that takes a boolean flag.
    ///
    /// # Safety
    ///
    /// Relies on the caller only passing a flag with `true` value if they are
    /// certain that the data is valid
    pub(crate) fn with_skip_validation(mut self, skip_validation: UnsafeFlag) -> Self {
        self.skip_validation = skip_validation;
        self
    }

    /// Read the record batch, consuming the reader
    fn read_record_batch(mut self) -> Result<RecordBatch, ArrowError> {
        let mut variadic_counts: VecDeque<i64> = self
            .batch
            .variadicBufferCounts()
            .into_iter()
            .flatten()
            .collect();

        let options = RecordBatchOptions::new().with_row_count(Some(self.batch.length() as usize));

        let schema = Arc::clone(&self.schema);
        if let Some(projection) = self.projection {
            let mut arrays = vec![];
            // project fields
            for (idx, field) in schema.fields().iter().enumerate() {
                // Create array for projected field
                if let Some(proj_idx) = projection.iter().position(|p| p == &idx) {
                    let child = self.create_array(field, &mut variadic_counts)?;
                    arrays.push((proj_idx, child));
                } else {
                    self.skip_field(field, &mut variadic_counts)?;
                }
            }

            arrays.sort_by_key(|t| t.0);

            let schema = Arc::new(schema.project(projection)?);
            let columns = arrays.into_iter().map(|t| t.1).collect::<Vec<_>>();

            if self.skip_validation.get() {
                // Safety: setting `skip_validation` requires `unsafe`, user assures data is valid
                unsafe {
                    Ok(RecordBatch::new_unchecked(
                        schema,
                        columns,
                        self.batch.length() as usize,
                    ))
                }
            } else {
                assert!(variadic_counts.is_empty());
                RecordBatch::try_new_with_options(schema, columns, &options)
            }
        } else {
            let mut children = vec![];
            // keep track of index as lists require more than one node
            for field in schema.fields() {
                let child = self.create_array(field, &mut variadic_counts)?;
                children.push(child);
            }

            if self.skip_validation.get() {
                // Safety: setting `skip_validation` requires `unsafe`, user assures data is valid
                unsafe {
                    Ok(RecordBatch::new_unchecked(
                        schema,
                        children,
                        self.batch.length() as usize,
                    ))
                }
            } else {
                assert!(variadic_counts.is_empty());
                RecordBatch::try_new_with_options(schema, children, &options)
            }
        }
    }

    fn next_buffer(&mut self) -> Result<Buffer, ArrowError> {
        read_buffer(self.buffers.next().unwrap(), self.data, self.compression)
    }

    fn skip_buffer(&mut self) {
        self.buffers.next().unwrap();
    }

    fn next_node(&mut self, field: &Field) -> Result<&'a FieldNode, ArrowError> {
        self.nodes.next().ok_or_else(|| {
            ArrowError::SchemaError(format!(
                "Invalid data for schema. {field} refers to node not found in schema",
            ))
        })
    }

    fn skip_field(
        &mut self,
        field: &Field,
        variadic_count: &mut VecDeque<i64>,
    ) -> Result<(), ArrowError> {
        self.next_node(field)?;

        match field.data_type() {
            Utf8 | Binary | LargeBinary | LargeUtf8 => {
                for _ in 0..3 {
                    self.skip_buffer()
                }
            }
            Utf8View | BinaryView => {
                let count = variadic_count
                    .pop_front()
                    .ok_or(ArrowError::IpcError(format!(
                        "Missing variadic count for {} column",
                        field.data_type()
                    )))?;
                let count = count + 2; // view and null buffer.
                for _i in 0..count {
                    self.skip_buffer()
                }
            }
            FixedSizeBinary(_) => {
                self.skip_buffer();
                self.skip_buffer();
            }
            List(list_field) | LargeList(list_field) | Map(list_field, _) => {
                self.skip_buffer();
                self.skip_buffer();
                self.skip_field(list_field, variadic_count)?;
            }
            FixedSizeList(list_field, _) => {
                self.skip_buffer();
                self.skip_field(list_field, variadic_count)?;
            }
            Struct(struct_fields) => {
                self.skip_buffer();

                // skip for each field
                for struct_field in struct_fields {
                    self.skip_field(struct_field, variadic_count)?
                }
            }
            RunEndEncoded(run_ends_field, values_field) => {
                self.skip_field(run_ends_field, variadic_count)?;
                self.skip_field(values_field, variadic_count)?;
            }
            Dictionary(_, _) => {
                self.skip_buffer(); // Nulls
                self.skip_buffer(); // Indices
            }
            Union(fields, mode) => {
                self.skip_buffer(); // Nulls

                match mode {
                    UnionMode::Dense => self.skip_buffer(),
                    UnionMode::Sparse => {}
                };

                for (_, field) in fields.iter() {
                    self.skip_field(field, variadic_count)?
                }
            }
            Null => {} // No buffer increases
            _ => {
                self.skip_buffer();
                self.skip_buffer();
            }
        };
        Ok(())
    }
}

/// Creates a record batch from binary data using the `crate::RecordBatch` indexes and the `Schema`.
///
/// If `require_alignment` is true, this function will return an error if any array data in the
/// input `buf` is not properly aligned.
/// Under the hood it will use [`arrow_data::ArrayDataBuilder::build`] to construct [`arrow_data::ArrayData`].
///
/// If `require_alignment` is false, this function will automatically allocate a new aligned buffer
/// and copy over the data if any array data in the input `buf` is not properly aligned.
/// (Properly aligned array data will remain zero-copy.)
/// Under the hood it will use [`arrow_data::ArrayDataBuilder::build_aligned`] to construct [`arrow_data::ArrayData`].
pub fn read_record_batch(
    buf: &Buffer,
    batch: crate::RecordBatch,
    schema: SchemaRef,
    dictionaries_by_id: &HashMap<i64, ArrayRef>,
    projection: Option<&[usize]>,
    metadata: &MetadataVersion,
) -> Result<RecordBatch, ArrowError> {
    RecordBatchDecoder::try_new(buf, batch, schema, dictionaries_by_id, metadata)?
        .with_projection(projection)
        .with_require_alignment(false)
        .read_record_batch()
}

/// Read the dictionary from the buffer and provided metadata,
/// updating the `dictionaries_by_id` with the resulting dictionary
pub fn read_dictionary(
    buf: &Buffer,
    batch: crate::DictionaryBatch,
    schema: &Schema,
    dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
    metadata: &MetadataVersion,
) -> Result<(), ArrowError> {
    read_dictionary_impl(
        buf,
        batch,
        schema,
        dictionaries_by_id,
        metadata,
        false,
        UnsafeFlag::new(),
    )
}

fn read_dictionary_impl(
    buf: &Buffer,
    batch: crate::DictionaryBatch,
    schema: &Schema,
    dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
    metadata: &MetadataVersion,
    require_alignment: bool,
    skip_validation: UnsafeFlag,
) -> Result<(), ArrowError> {
    if batch.isDelta() {
        return Err(ArrowError::InvalidArgumentError(
            "delta dictionary batches not supported".to_string(),
        ));
    }

    let id = batch.id();
    #[allow(deprecated)]
    let fields_using_this_dictionary = schema.fields_with_dict_id(id);
    let first_field = fields_using_this_dictionary.first().ok_or_else(|| {
        ArrowError::InvalidArgumentError(format!("dictionary id {id} not found in schema"))
    })?;

    // As the dictionary batch does not contain the type of the
    // values array, we need to retrieve this from the schema.
    // Get an array representing this dictionary's values.
    let dictionary_values: ArrayRef = match first_field.data_type() {
        DataType::Dictionary(_, ref value_type) => {
            // Make a fake schema for the dictionary batch.
            let value = value_type.as_ref().clone();
            let schema = Schema::new(vec![Field::new("", value, true)]);
            // Read a single column
            let record_batch = RecordBatchDecoder::try_new(
                buf,
                batch.data().unwrap(),
                Arc::new(schema),
                dictionaries_by_id,
                metadata,
            )?
            .with_require_alignment(require_alignment)
            .with_skip_validation(skip_validation)
            .read_record_batch()?;

            Some(record_batch.column(0).clone())
        }
        _ => None,
    }
    .ok_or_else(|| {
        ArrowError::InvalidArgumentError(format!("dictionary id {id} not found in schema"))
    })?;

    // We don't currently record the isOrdered field. This could become a
    // general attribute of arrays.
    // Add (possibly multiple) array refs to the dictionaries array.
    dictionaries_by_id.insert(id, dictionary_values.clone());

    Ok(())
}

/// Read the data for a given block
fn read_block<R: Read + Seek>(mut reader: R, block: &Block) -> Result<Buffer, ArrowError> {
    reader.seek(SeekFrom::Start(block.offset() as u64))?;
    let body_len = block.bodyLength().to_usize().unwrap();
    let metadata_len = block.metaDataLength().to_usize().unwrap();
    let total_len = body_len.checked_add(metadata_len).unwrap();

    let mut buf = MutableBuffer::from_len_zeroed(total_len);
    reader.read_exact(&mut buf)?;
    Ok(buf.into())
}

/// Parse an encapsulated message
///
/// <https://arrow.apache.org/docs/format/Columnar.html#encapsulated-message-format>
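///
/// A sketch of the expected byte layout (per the spec linked above; the
/// continuation marker is optional):
///
/// ```text
/// <0xFFFFFFFF continuation marker> <i32 LE metadata length> <flatbuffer Message>
/// ```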
fn parse_message(buf: &[u8]) -> Result<Message<'_>, ArrowError> {
    let buf = match buf[..4] == CONTINUATION_MARKER {
        true => &buf[8..],
        false => &buf[4..],
    };
    crate::root_as_message(buf)
        .map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}")))
}

/// Read the footer length from the last 10 bytes of an Arrow IPC file
///
/// Expects a 4 byte footer length followed by `b"ARROW1"`
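///
/// # Example
///
/// A minimal sketch with a hand-built trailer encoding a footer length of 16:
///
/// ```
/// # use arrow_ipc::reader::read_footer_length;
/// let buf = [16, 0, 0, 0, b'A', b'R', b'R', b'O', b'W', b'1'];
/// assert_eq!(read_footer_length(buf).unwrap(), 16);
/// ```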
pub fn read_footer_length(buf: [u8; 10]) -> Result<usize, ArrowError> {
    if buf[4..] != super::ARROW_MAGIC {
        return Err(ArrowError::ParseError(
            "Arrow file does not contain correct footer".to_string(),
        ));
    }

    // read footer length
    let footer_len = i32::from_le_bytes(buf[..4].try_into().unwrap());
    footer_len
        .try_into()
        .map_err(|_| ArrowError::ParseError(format!("Invalid footer length: {footer_len}")))
}

/// A low-level, push-based interface for reading an IPC file
///
/// For a higher-level interface see [`FileReader`]
///
/// For an example of using this API with `mmap` see the [`zero_copy_ipc`] example.
///
/// [`zero_copy_ipc`]: https://github.com/apache/arrow-rs/blob/main/arrow/examples/zero_copy_ipc.rs
///
/// ```
/// # use std::sync::Arc;
/// # use arrow_array::*;
/// # use arrow_array::types::Int32Type;
/// # use arrow_buffer::Buffer;
/// # use arrow_ipc::convert::fb_to_schema;
/// # use arrow_ipc::reader::{FileDecoder, read_footer_length};
/// # use arrow_ipc::root_as_footer;
/// # use arrow_ipc::writer::FileWriter;
/// // Write an IPC file
///
/// let batch = RecordBatch::try_from_iter([
///     ("a", Arc::new(Int32Array::from(vec![1, 2, 3])) as _),
///     ("b", Arc::new(Int32Array::from(vec![1, 2, 3])) as _),
///     ("c", Arc::new(DictionaryArray::<Int32Type>::from_iter(["hello", "hello", "world"])) as _),
/// ]).unwrap();
///
/// let schema = batch.schema();
///
/// let mut out = Vec::with_capacity(1024);
/// let mut writer = FileWriter::try_new(&mut out, schema.as_ref()).unwrap();
/// writer.write(&batch).unwrap();
/// writer.finish().unwrap();
///
/// drop(writer);
///
/// // Read IPC file
///
/// let buffer = Buffer::from_vec(out);
/// let trailer_start = buffer.len() - 10;
/// let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap();
/// let footer = root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap();
///
/// let back = fb_to_schema(footer.schema().unwrap());
/// assert_eq!(&back, schema.as_ref());
///
/// let mut decoder = FileDecoder::new(schema, footer.version());
///
/// // Read dictionaries
/// for block in footer.dictionaries().iter().flatten() {
///     let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
///     let data = buffer.slice_with_length(block.offset() as _, block_len);
///     decoder.read_dictionary(&block, &data).unwrap();
/// }
///
/// // Read record batch
/// let batches = footer.recordBatches().unwrap();
/// assert_eq!(batches.len(), 1); // Only wrote a single batch
///
/// let block = batches.get(0);
/// let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
/// let data = buffer.slice_with_length(block.offset() as _, block_len);
/// let back = decoder.read_record_batch(block, &data).unwrap().unwrap();
///
/// assert_eq!(batch, back);
/// ```
#[derive(Debug)]
pub struct FileDecoder {
    schema: SchemaRef,
    dictionaries: HashMap<i64, ArrayRef>,
    version: MetadataVersion,
    projection: Option<Vec<usize>>,
    require_alignment: bool,
    skip_validation: UnsafeFlag,
}

impl FileDecoder {
    /// Create a new [`FileDecoder`] with the given schema and version
    pub fn new(schema: SchemaRef, version: MetadataVersion) -> Self {
        Self {
            schema,
            version,
            dictionaries: Default::default(),
            projection: None,
            require_alignment: false,
            skip_validation: UnsafeFlag::new(),
        }
    }

    /// Specify a projection
    pub fn with_projection(mut self, projection: Vec<usize>) -> Self {
        self.projection = Some(projection);
        self
    }

    /// Specifies if the array data in input buffers is required to be properly aligned.
    ///
    /// If `require_alignment` is true, this decoder will return an error if any array data in the
    /// input `buf` is not properly aligned.
    /// Under the hood it will use [`arrow_data::ArrayDataBuilder::build`] to construct
    /// [`arrow_data::ArrayData`].
    ///
    /// If `require_alignment` is false (the default), this decoder will automatically allocate a
    /// new aligned buffer and copy over the data if any array data in the input `buf` is not
    /// properly aligned. (Properly aligned array data will remain zero-copy.)
    /// Under the hood it will use [`arrow_data::ArrayDataBuilder::build_aligned`] to construct
    /// [`arrow_data::ArrayData`].
    pub fn with_require_alignment(mut self, require_alignment: bool) -> Self {
        self.require_alignment = require_alignment;
        self
    }

    /// Specifies if validation should be skipped when reading data (defaults to `false`)
    ///
    /// # Safety
    ///
    /// This flag must only be set to `true` when you trust the input data and are sure the data you are
    /// reading is a valid Arrow IPC file, otherwise undefined behavior may
    /// result.
    ///
    /// For example, some programs may wish to trust reading IPC files written
    /// by the same process that created the files.
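    ///
    /// A minimal sketch of opting in (only sound when the input is trusted):
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow_schema::Schema;
    /// # use arrow_ipc::MetadataVersion;
    /// # use arrow_ipc::reader::FileDecoder;
    /// let decoder = FileDecoder::new(Arc::new(Schema::empty()), MetadataVersion::V5);
    /// // SAFETY: we assume the IPC data comes from a trusted writer
    /// let decoder = unsafe { decoder.with_skip_validation(true) };
    /// # drop(decoder);
    /// ```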
    pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self {
        self.skip_validation.set(skip_validation);
        self
    }

    fn read_message<'a>(&self, buf: &'a [u8]) -> Result<Message<'a>, ArrowError> {
        let message = parse_message(buf)?;

        // some old test data's footer metadata is not set, so we account for that
        if self.version != MetadataVersion::V1 && message.version() != self.version {
            return Err(ArrowError::IpcError(
                "Could not read IPC message as metadata versions mismatch".to_string(),
            ));
        }
        Ok(message)
    }

    /// Read the dictionary with the given block and data buffer
    pub fn read_dictionary(&mut self, block: &Block, buf: &Buffer) -> Result<(), ArrowError> {
        let message = self.read_message(buf)?;
        match message.header_type() {
            crate::MessageHeader::DictionaryBatch => {
                let batch = message.header_as_dictionary_batch().unwrap();
                read_dictionary_impl(
                    &buf.slice(block.metaDataLength() as _),
                    batch,
                    &self.schema,
                    &mut self.dictionaries,
                    &message.version(),
                    self.require_alignment,
                    self.skip_validation.clone(),
                )
            }
            t => Err(ArrowError::ParseError(format!(
                "Expecting DictionaryBatch in dictionary blocks, found {t:?}."
            ))),
        }
    }

    /// Read the RecordBatch with the given block and data buffer
    pub fn read_record_batch(
        &self,
        block: &Block,
        buf: &Buffer,
    ) -> Result<Option<RecordBatch>, ArrowError> {
        let message = self.read_message(buf)?;
        match message.header_type() {
            crate::MessageHeader::Schema => Err(ArrowError::IpcError(
                "Not expecting a schema when messages are read".to_string(),
            )),
            crate::MessageHeader::RecordBatch => {
                let batch = message.header_as_record_batch().ok_or_else(|| {
                    ArrowError::IpcError("Unable to read IPC message as record batch".to_string())
                })?;
                // read the block that makes up the record batch into a buffer
                RecordBatchDecoder::try_new(
                    &buf.slice(block.metaDataLength() as _),
                    batch,
                    self.schema.clone(),
                    &self.dictionaries,
                    &message.version(),
                )?
                .with_projection(self.projection.as_deref())
                .with_require_alignment(self.require_alignment)
                .with_skip_validation(self.skip_validation.clone())
                .read_record_batch()
                .map(Some)
            }
            crate::MessageHeader::NONE => Ok(None),
            t => Err(ArrowError::InvalidArgumentError(format!(
                "Reading types other than record batches not yet supported, unable to read {t:?}"
            ))),
        }
    }
}

/// Build an Arrow [`FileReader`] with custom options.
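///
/// # Example
///
/// A minimal sketch; `"file.arrow"` is a hypothetical path to a valid IPC file:
///
/// ```no_run
/// # use std::fs::File;
/// # use arrow_ipc::reader::FileReaderBuilder;
/// let file = File::open("file.arrow").unwrap(); // hypothetical path
/// let reader = FileReaderBuilder::new()
///     .with_projection(vec![0, 2]) // read only columns 0 and 2
///     .with_max_footer_fb_depth(128) // allow deeper nested schemas
///     .build(file)
///     .unwrap();
/// # drop(reader);
/// ```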
#[derive(Debug)]
pub struct FileReaderBuilder {
    /// Optional projection for which columns to load (zero-based column indices)
    projection: Option<Vec<usize>>,
    /// Passed through to construct [`VerifierOptions`]
    max_footer_fb_tables: usize,
    /// Passed through to construct [`VerifierOptions`]
    max_footer_fb_depth: usize,
}

impl Default for FileReaderBuilder {
    fn default() -> Self {
        let verifier_options = VerifierOptions::default();
        Self {
            max_footer_fb_tables: verifier_options.max_tables,
            max_footer_fb_depth: verifier_options.max_depth,
            projection: None,
        }
    }
}

impl FileReaderBuilder {
    /// Options for creating a new [`FileReader`].
    ///
    /// To convert a builder into a reader, call [`FileReaderBuilder::build`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Optional projection for which columns to load (zero-based column indices).
    pub fn with_projection(mut self, projection: Vec<usize>) -> Self {
        self.projection = Some(projection);
        self
    }

    /// Flatbuffers option for parsing the footer. Controls the max number of fields and
    /// metadata key-value pairs that can be parsed from the schema of the footer.
    ///
    /// By default this is set to `1_000_000` which roughly translates to a schema with
    /// no metadata key-value pairs but 499,999 fields.
    ///
    /// This default limit is enforced to protect against malicious files with a massive
    /// number of flatbuffer tables, which could cause a denial of service attack.
    ///
    /// If you need to ingest a trusted file with a massive number of fields and/or
    /// metadata key-value pairs and are facing the error `"Unable to get root as
    /// footer: TooManyTables"` then increase this parameter as necessary.
    pub fn with_max_footer_fb_tables(mut self, max_footer_fb_tables: usize) -> Self {
        self.max_footer_fb_tables = max_footer_fb_tables;
        self
    }

    /// Flatbuffers option for parsing the footer. Controls the max depth for schemas with
    /// nested fields parsed from the footer.
    ///
    /// By default this is set to `64` which roughly translates to a schema with
    /// a field nested 60 levels down through other struct fields.
    ///
    /// This default limit is enforced to protect against malicious files with an extremely
    /// deep flatbuffer structure, which could cause a denial of service attack.
    ///
    /// If you need to ingest a trusted file with a deeply nested field and are facing the
    /// error `"Unable to get root as footer: DepthLimitReached"` then increase this
    /// parameter as necessary.
    pub fn with_max_footer_fb_depth(mut self, max_footer_fb_depth: usize) -> Self {
        self.max_footer_fb_depth = max_footer_fb_depth;
        self
    }

    /// Build [`FileReader`] with given reader.
    pub fn build<R: Read + Seek>(self, mut reader: R) -> Result<FileReader<R>, ArrowError> {
        // Space for ARROW_MAGIC (6 bytes) and length (4 bytes)
        let mut buffer = [0; 10];
        reader.seek(SeekFrom::End(-10))?;
        reader.read_exact(&mut buffer)?;

        let footer_len = read_footer_length(buffer)?;

        // read footer
        let mut footer_data = vec![0; footer_len];
        reader.seek(SeekFrom::End(-10 - footer_len as i64))?;
        reader.read_exact(&mut footer_data)?;

        let verifier_options = VerifierOptions {
            max_tables: self.max_footer_fb_tables,
            max_depth: self.max_footer_fb_depth,
            ..Default::default()
        };
        let footer = crate::root_as_footer_with_opts(&verifier_options, &footer_data[..]).map_err(
            |err| ArrowError::ParseError(format!("Unable to get root as footer: {err:?}")),
        )?;

        let blocks = footer.recordBatches().ok_or_else(|| {
            ArrowError::ParseError("Unable to get record batches from IPC Footer".to_string())
        })?;

        let total_blocks = blocks.len();

        let ipc_schema = footer.schema().unwrap();
        if !ipc_schema.endianness().equals_to_target_endianness() {
            return Err(ArrowError::IpcError(
                "the endianness of the source system does not match the endianness of the target system.".to_owned()
            ));
        }

        let schema = crate::convert::fb_to_schema(ipc_schema);

        let mut custom_metadata = HashMap::new();
        if let Some(fb_custom_metadata) = footer.custom_metadata() {
            for kv in fb_custom_metadata.into_iter() {
                custom_metadata.insert(
                    kv.key().unwrap().to_string(),
                    kv.value().unwrap().to_string(),
                );
            }
        }

        let mut decoder = FileDecoder::new(Arc::new(schema), footer.version());
        if let Some(projection) = self.projection {
            decoder = decoder.with_projection(projection)
        }

        // Create an array of optional dictionary value arrays, one per field.
        if let Some(dictionaries) = footer.dictionaries() {
            for block in dictionaries {
                let buf = read_block(&mut reader, block)?;
                decoder.read_dictionary(block, &buf)?;
            }
        }

        Ok(FileReader {
            reader,
            blocks: blocks.iter().copied().collect(),
            current_block: 0,
            total_blocks,
            decoder,
            custom_metadata,
        })
    }
}

/// Arrow File Reader
///
/// Reads Arrow [`RecordBatch`]es from bytes in the [IPC File Format],
/// providing random access to the record batches.
///
/// # See Also
///
/// * [`Self::set_index`] for random access
/// * [`StreamReader`] for reading streaming data
///
/// # Example: Reading from a `File`
/// ```
/// # use std::io::Cursor;
/// use arrow_array::record_batch;
/// # use arrow_ipc::reader::FileReader;
/// # use arrow_ipc::writer::FileWriter;
/// # let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
/// # let mut file = vec![]; // mimic a stream for the example
/// # {
/// #  let mut writer = FileWriter::try_new(&mut file, &batch.schema()).unwrap();
/// #  writer.write(&batch).unwrap();
/// #  writer.write(&batch).unwrap();
/// #  writer.finish().unwrap();
/// # }
/// # let mut file = Cursor::new(&file);
/// let projection = None; // read all columns
/// let mut reader = FileReader::try_new(&mut file, projection).unwrap();
/// // Position the reader to the second batch
/// reader.set_index(1).unwrap();
/// // read batches from the reader using the Iterator trait
/// let mut num_rows = 0;
/// for batch in reader {
///    let batch = batch.unwrap();
///    num_rows += batch.num_rows();
/// }
/// assert_eq!(num_rows, 3);
/// ```
/// # Example: Reading from `mmap`ed file
///
/// For an example creating Arrays without copying using memory mapped (`mmap`)
/// files see the [`zero_copy_ipc`] example.
///
/// [IPC File Format]: https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format
/// [`zero_copy_ipc`]: https://github.com/apache/arrow-rs/blob/main/arrow/examples/zero_copy_ipc.rs
pub struct FileReader<R> {
    /// File reader that supports reading and seeking
    reader: R,

    /// The decoder
    decoder: FileDecoder,

    /// The blocks in the file
    ///
    /// A block indicates the regions in the file to read to get data
    blocks: Vec<Block>,

    /// A counter to keep track of the current block that should be read
    current_block: usize,

    /// The total number of blocks, which may contain record batches and other types
    total_blocks: usize,

    /// User defined metadata
    custom_metadata: HashMap<String, String>,
}

impl<R> fmt::Debug for FileReader<R> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        f.debug_struct("FileReader<R>")
            .field("decoder", &self.decoder)
            .field("blocks", &self.blocks)
            .field("current_block", &self.current_block)
            .field("total_blocks", &self.total_blocks)
            .finish_non_exhaustive()
    }
}

impl<R: Read + Seek> FileReader<BufReader<R>> {
    /// Try to create a new file reader with the reader wrapped in a BufReader.
    ///
    /// See [`FileReader::try_new`] for an unbuffered version.
    pub fn try_new_buffered(reader: R, projection: Option<Vec<usize>>) -> Result<Self, ArrowError> {
        Self::try_new(BufReader::new(reader), projection)
    }
}

impl<R: Read + Seek> FileReader<R> {
    /// Try to create a new file reader.
    ///
    /// There is no internal buffering. If buffered reads are needed you likely want to use
    /// [`FileReader::try_new_buffered`] instead.
    ///
    /// # Errors
    ///
    /// An [`Err`](Result::Err) may be returned if:
    /// - the file does not meet the Arrow Format footer requirements, or
    /// - the file endianness does not match the target endianness.
    pub fn try_new(reader: R, projection: Option<Vec<usize>>) -> Result<Self, ArrowError> {
        let builder = FileReaderBuilder {
            projection,
            ..Default::default()
        };
        builder.build(reader)
    }

    /// Return user defined customized metadata
    pub fn custom_metadata(&self) -> &HashMap<String, String> {
        &self.custom_metadata
    }

    /// Return the number of batches in the file
    pub fn num_batches(&self) -> usize {
        self.total_blocks
    }

    /// Return the schema of the file
    pub fn schema(&self) -> SchemaRef {
        self.decoder.schema.clone()
    }

    /// Seek to a specific [`RecordBatch`]
    ///
    /// Sets the current block to the index, allowing random reads
    pub fn set_index(&mut self, index: usize) -> Result<(), ArrowError> {
        if index >= self.total_blocks {
            Err(ArrowError::InvalidArgumentError(format!(
                "Cannot set batch to index {} from {} total batches",
                index, self.total_blocks
            )))
        } else {
            self.current_block = index;
            Ok(())
        }
    }

    fn maybe_next(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        let block = &self.blocks[self.current_block];
        self.current_block += 1;

        // read length
        let buffer = read_block(&mut self.reader, block)?;
        self.decoder.read_record_batch(block, &buffer)
    }

    /// Gets a reference to the underlying reader.
    ///
    /// It is inadvisable to directly read from the underlying reader.
    pub fn get_ref(&self) -> &R {
        &self.reader
    }

    /// Gets a mutable reference to the underlying reader.
    ///
    /// It is inadvisable to directly read from the underlying reader.
    pub fn get_mut(&mut self) -> &mut R {
        &mut self.reader
    }

    /// Specifies if validation should be skipped when reading data (defaults to `false`)
    ///
    /// # Safety
    ///
    /// See [`FileDecoder::with_skip_validation`]
    pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self {
        self.decoder = self.decoder.with_skip_validation(skip_validation);
        self
    }
}

impl<R: Read + Seek> Iterator for FileReader<R> {
    type Item = Result<RecordBatch, ArrowError>;

    fn next(&mut self) -> Option<Self::Item> {
        // get current block
        if self.current_block < self.total_blocks {
            self.maybe_next().transpose()
        } else {
            None
        }
    }
}

impl<R: Read + Seek> RecordBatchReader for FileReader<R> {
    fn schema(&self) -> SchemaRef {
        self.schema()
    }
}

/// Arrow Stream Reader
///
/// Reads Arrow [`RecordBatch`]es from bytes in the [IPC Streaming Format].
///
/// # See Also
///
/// * [`FileReader`] for random access.
///
/// # Example
/// ```
/// # use arrow_array::record_batch;
/// # use arrow_ipc::reader::StreamReader;
/// # use arrow_ipc::writer::StreamWriter;
/// # let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
/// # let mut stream = vec![]; // mimic a stream for the example
/// # {
/// #  let mut writer = StreamWriter::try_new(&mut stream, &batch.schema()).unwrap();
/// #  writer.write(&batch).unwrap();
/// #  writer.finish().unwrap();
/// # }
/// # let stream = stream.as_slice();
/// let projection = None; // read all columns
/// let mut reader = StreamReader::try_new(stream, projection).unwrap();
/// // read batches from the reader using the Iterator trait
/// let mut num_rows = 0;
/// for batch in reader {
///    let batch = batch.unwrap();
///    num_rows += batch.num_rows();
/// }
/// assert_eq!(num_rows, 3);
/// ```
///
/// [IPC Streaming Format]: https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format
pub struct StreamReader<R> {
    /// Stream reader
    reader: R,

    /// The schema that is read from the stream's first message
    schema: SchemaRef,

    /// Optional dictionaries for each schema field.
    ///
    /// Dictionaries may be appended to in the streaming format.
    dictionaries_by_id: HashMap<i64, ArrayRef>,

    /// An indicator of whether the stream is complete.
    ///
    /// This value is set to `true` the first time the reader's `next()` returns `None`.
    finished: bool,

    /// Optional projection
    projection: Option<(Vec<usize>, Schema)>,

    /// Should validation be skipped when reading data? Defaults to false.
    ///
    /// See [`FileDecoder::with_skip_validation`] for details.
    skip_validation: UnsafeFlag,
}
1355
1356impl<R> fmt::Debug for StreamReader<R> {
1357    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::result::Result<(), fmt::Error> {
1358        f.debug_struct("StreamReader<R>")
1359            .field("reader", &"R")
1360            .field("schema", &self.schema)
1361            .field("dictionaries_by_id", &self.dictionaries_by_id)
1362            .field("finished", &self.finished)
1363            .field("projection", &self.projection)
1364            .finish()
1365    }
1366}
1367
1368impl<R: Read> StreamReader<BufReader<R>> {
1369    /// Try to create a new stream reader with the reader wrapped in a BufReader.
1370    ///
1371    /// See [`StreamReader::try_new`] for an unbuffered version.
1372    pub fn try_new_buffered(reader: R, projection: Option<Vec<usize>>) -> Result<Self, ArrowError> {
1373        Self::try_new(BufReader::new(reader), projection)
1374    }
1375}
1376
1377impl<R: Read> StreamReader<R> {
1378    /// Try to create a new stream reader.
1379    ///
    /// To check if the reader is done, use [`is_finished`](StreamReader::is_finished).
1381    ///
1382    /// There is no internal buffering. If buffered reads are needed you likely want to use
1383    /// [`StreamReader::try_new_buffered`] instead.
1384    ///
1385    /// # Errors
1386    ///
    /// An [`Err`](Result::Err) may be returned if the reader does not encounter a schema
    /// as the first message in the stream.
1389    pub fn try_new(
1390        mut reader: R,
1391        projection: Option<Vec<usize>>,
1392    ) -> Result<StreamReader<R>, ArrowError> {
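        // The stream is a sequence of "encapsulated messages", each framed as
        //   [0xFFFFFFFF continuation marker][i32 LE metadata length]
        //   [flatbuffer message metadata][padding][message body]
        // where the continuation marker may be absent in streams written with
        // metadata versions older than V5.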
1393        // determine metadata length
1394        let mut meta_size: [u8; 4] = [0; 4];
1395        reader.read_exact(&mut meta_size)?;
1396        let meta_len = {
1397            // If a continuation marker is encountered, skip over it and read
1398            // the size from the next four bytes.
1399            if meta_size == CONTINUATION_MARKER {
1400                reader.read_exact(&mut meta_size)?;
1401            }
1402            i32::from_le_bytes(meta_size)
1403        };
1404
1405        let meta_len = usize::try_from(meta_len)
1406            .map_err(|_| ArrowError::ParseError(format!("Invalid metadata length: {meta_len}")))?;
1407        let mut meta_buffer = vec![0; meta_len];
1408        reader.read_exact(&mut meta_buffer)?;
1409
1410        let message = crate::root_as_message(meta_buffer.as_slice()).map_err(|err| {
1411            ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
1412        })?;
        // the first message header in the stream must be a Schema; read it
1414        let ipc_schema: crate::Schema = message.header_as_schema().ok_or_else(|| {
1415            ArrowError::ParseError("Unable to read IPC message as schema".to_string())
1416        })?;
1417        let schema = crate::convert::fb_to_schema(ipc_schema);
1418
        // Create an empty map from dictionary id to dictionary values;
        // it is populated as DictionaryBatch messages are read.
1420        let dictionaries_by_id = HashMap::new();
1421
1422        let projection = match projection {
1423            Some(projection_indices) => {
1424                let schema = schema.project(&projection_indices)?;
1425                Some((projection_indices, schema))
1426            }
1427            _ => None,
1428        };
1429        Ok(Self {
1430            reader,
1431            schema: Arc::new(schema),
1432            finished: false,
1433            dictionaries_by_id,
1434            projection,
1435            skip_validation: UnsafeFlag::new(),
1436        })
1437    }
1438
1439    /// Deprecated, use [`StreamReader::try_new`] instead.
1440    #[deprecated(since = "53.0.0", note = "use `try_new` instead")]
1441    pub fn try_new_unbuffered(
1442        reader: R,
1443        projection: Option<Vec<usize>>,
1444    ) -> Result<Self, ArrowError> {
1445        Self::try_new(reader, projection)
1446    }
1447
1448    /// Return the schema of the stream
1449    pub fn schema(&self) -> SchemaRef {
1450        self.schema.clone()
1451    }
1452
1453    /// Check if the stream is finished
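    ///
    /// # Example
    ///
    /// A minimal sketch, reusing the writer setup from the [`StreamReader`]
    /// type-level example, showing that the reader reports completion only
    /// after the stream has been drained:
    ///
    /// ```
    /// # use arrow_array::record_batch;
    /// # use arrow_ipc::reader::StreamReader;
    /// # use arrow_ipc::writer::StreamWriter;
    /// # let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
    /// # let mut stream = vec![]; // mimic a stream for the example
    /// # {
    /// #  let mut writer = StreamWriter::try_new(&mut stream, &batch.schema()).unwrap();
    /// #  writer.write(&batch).unwrap();
    /// #  writer.finish().unwrap();
    /// # }
    /// let mut reader = StreamReader::try_new(stream.as_slice(), None).unwrap();
    /// assert!(!reader.is_finished());
    /// // `next()` returns `None` once the end-of-stream marker is read
    /// while let Some(batch) = reader.next() {
    ///     let _batch = batch.unwrap();
    /// }
    /// assert!(reader.is_finished());
    /// ```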
1454    pub fn is_finished(&self) -> bool {
1455        self.finished
1456    }
1457
1458    fn maybe_next(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
1459        if self.finished {
1460            return Ok(None);
1461        }
1462        // determine metadata length
1463        let mut meta_size: [u8; 4] = [0; 4];
1464
1465        match self.reader.read_exact(&mut meta_size) {
1466            Ok(()) => (),
1467            Err(e) => {
1468                return if e.kind() == std::io::ErrorKind::UnexpectedEof {
                    // Handle a stream that ends without the "0xFFFFFFFF 0x00000000"
                    // end-of-stream marker, which is valid according to:
                    // https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format
1472                    self.finished = true;
1473                    Ok(None)
1474                } else {
1475                    Err(ArrowError::from(e))
1476                };
1477            }
1478        }
1479
1480        let meta_len = {
1481            // If a continuation marker is encountered, skip over it and read
1482            // the size from the next four bytes.
1483            if meta_size == CONTINUATION_MARKER {
1484                self.reader.read_exact(&mut meta_size)?;
1485            }
1486            i32::from_le_bytes(meta_size)
1487        };
1488
1489        let meta_len = usize::try_from(meta_len)
1490            .map_err(|_| ArrowError::ParseError(format!("Invalid metadata length: {meta_len}")))?;
1491
1492        if meta_len == 0 {
1493            // the stream has ended, mark the reader as finished
1494            self.finished = true;
1495            return Ok(None);
1496        }
1497
1498        let mut meta_buffer = vec![0; meta_len];
1499        self.reader.read_exact(&mut meta_buffer)?;
1500
        let message = crate::root_as_message(meta_buffer.as_slice()).map_err(|err| {
1503            ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
1504        })?;
1505
1506        match message.header_type() {
1507            crate::MessageHeader::Schema => Err(ArrowError::IpcError(
1508                "Not expecting a schema when messages are read".to_string(),
1509            )),
1510            crate::MessageHeader::RecordBatch => {
1511                let batch = message.header_as_record_batch().ok_or_else(|| {
1512                    ArrowError::IpcError("Unable to read IPC message as record batch".to_string())
1513                })?;
1514                // read the block that makes up the record batch into a buffer
1515                let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize);
1516                self.reader.read_exact(&mut buf)?;
1517
1518                RecordBatchDecoder::try_new(
1519                    &buf.into(),
1520                    batch,
1521                    self.schema(),
1522                    &self.dictionaries_by_id,
1523                    &message.version(),
1524                )?
1525                .with_projection(self.projection.as_ref().map(|x| x.0.as_ref()))
1526                .with_require_alignment(false)
1527                .with_skip_validation(self.skip_validation.clone())
1528                .read_record_batch()
1529                .map(Some)
1530            }
1531            crate::MessageHeader::DictionaryBatch => {
1532                let batch = message.header_as_dictionary_batch().ok_or_else(|| {
1533                    ArrowError::IpcError(
1534                        "Unable to read IPC message as dictionary batch".to_string(),
1535                    )
1536                })?;
1537                // read the block that makes up the dictionary batch into a buffer
1538                let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize);
1539                self.reader.read_exact(&mut buf)?;
1540
1541                read_dictionary_impl(
1542                    &buf.into(),
1543                    batch,
1544                    &self.schema,
1545                    &mut self.dictionaries_by_id,
1546                    &message.version(),
1547                    false,
1548                    self.skip_validation.clone(),
1549                )?;
1550
                // recurse to keep reading messages until we encounter a RecordBatch
1552                self.maybe_next()
1553            }
1554            crate::MessageHeader::NONE => Ok(None),
1555            t => Err(ArrowError::InvalidArgumentError(format!(
1556                "Reading types other than record batches not yet supported, unable to read {t:?} "
1557            ))),
1558        }
1559    }
1560
1561    /// Gets a reference to the underlying reader.
1562    ///
1563    /// It is inadvisable to directly read from the underlying reader.
1564    pub fn get_ref(&self) -> &R {
1565        &self.reader
1566    }
1567
1568    /// Gets a mutable reference to the underlying reader.
1569    ///
1570    /// It is inadvisable to directly read from the underlying reader.
1571    pub fn get_mut(&mut self) -> &mut R {
1572        &mut self.reader
1573    }
1574
1575    /// Specifies if validation should be skipped when reading data (defaults to `false`)
1576    ///
1577    /// # Safety
1578    ///
1579    /// See [`FileDecoder::with_skip_validation`]
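    ///
    /// # Example
    ///
    /// A sketch of opting out of validation for trusted input, reusing the
    /// writer setup from the [`StreamReader`] type-level example:
    ///
    /// ```
    /// # use arrow_array::record_batch;
    /// # use arrow_ipc::reader::StreamReader;
    /// # use arrow_ipc::writer::StreamWriter;
    /// # let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
    /// # let mut stream = vec![]; // mimic a stream for the example
    /// # {
    /// #  let mut writer = StreamWriter::try_new(&mut stream, &batch.schema()).unwrap();
    /// #  writer.write(&batch).unwrap();
    /// #  writer.finish().unwrap();
    /// # }
    /// // SAFETY: the stream was just produced by a trusted writer
    /// let reader = unsafe {
    ///     StreamReader::try_new(stream.as_slice(), None)
    ///         .unwrap()
    ///         .with_skip_validation(true)
    /// };
    /// for batch in reader {
    ///     let _batch = batch.unwrap();
    /// }
    /// ```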
1580    pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self {
1581        self.skip_validation.set(skip_validation);
1582        self
1583    }
1584}
1585
1586impl<R: Read> Iterator for StreamReader<R> {
1587    type Item = Result<RecordBatch, ArrowError>;
1588
1589    fn next(&mut self) -> Option<Self::Item> {
1590        self.maybe_next().transpose()
1591    }
1592}
1593
1594impl<R: Read> RecordBatchReader for StreamReader<R> {
1595    fn schema(&self) -> SchemaRef {
1596        self.schema.clone()
1597    }
1598}
1599
1600#[cfg(test)]
1601mod tests {
1602    use std::io::Cursor;
1603
1604    use crate::convert::fb_to_schema;
1605    use crate::writer::{
1606        unslice_run_array, write_message, DictionaryTracker, IpcDataGenerator, IpcWriteOptions,
1607    };
1608
1609    use super::*;
1610
1611    use crate::{root_as_footer, root_as_message, size_prefixed_root_as_message};
1612    use arrow_array::builder::{PrimitiveRunBuilder, UnionBuilder};
1613    use arrow_array::types::*;
1614    use arrow_buffer::{NullBuffer, OffsetBuffer};
1615    use arrow_data::ArrayDataBuilder;
1616
1617    fn create_test_projection_schema() -> Schema {
1618        // define field types
1619        let list_data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true)));
1620
1621        let fixed_size_list_data_type =
1622            DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Int32, false)), 3);
1623
1624        let union_fields = UnionFields::new(
1625            vec![0, 1],
1626            vec![
1627                Field::new("a", DataType::Int32, false),
1628                Field::new("b", DataType::Float64, false),
1629            ],
1630        );
1631
1632        let union_data_type = DataType::Union(union_fields, UnionMode::Dense);
1633
1634        let struct_fields = Fields::from(vec![
1635            Field::new("id", DataType::Int32, false),
1636            Field::new_list("list", Field::new_list_field(DataType::Int8, true), false),
1637        ]);
1638        let struct_data_type = DataType::Struct(struct_fields);
1639
1640        let run_encoded_data_type = DataType::RunEndEncoded(
1641            Arc::new(Field::new("run_ends", DataType::Int16, false)),
1642            Arc::new(Field::new("values", DataType::Int32, true)),
1643        );
1644
1645        // define schema
1646        Schema::new(vec![
1647            Field::new("f0", DataType::UInt32, false),
1648            Field::new("f1", DataType::Utf8, false),
1649            Field::new("f2", DataType::Boolean, false),
1650            Field::new("f3", union_data_type, true),
1651            Field::new("f4", DataType::Null, true),
1652            Field::new("f5", DataType::Float64, true),
1653            Field::new("f6", list_data_type, false),
1654            Field::new("f7", DataType::FixedSizeBinary(3), true),
1655            Field::new("f8", fixed_size_list_data_type, false),
1656            Field::new("f9", struct_data_type, false),
1657            Field::new("f10", run_encoded_data_type, false),
1658            Field::new("f11", DataType::Boolean, false),
1659            Field::new_dictionary("f12", DataType::Int8, DataType::Utf8, false),
1660            Field::new("f13", DataType::Utf8, false),
1661        ])
1662    }
1663
1664    fn create_test_projection_batch_data(schema: &Schema) -> RecordBatch {
1665        // set test data for each column
1666        let array0 = UInt32Array::from(vec![1, 2, 3]);
1667        let array1 = StringArray::from(vec!["foo", "bar", "baz"]);
1668        let array2 = BooleanArray::from(vec![true, false, true]);
1669
1670        let mut union_builder = UnionBuilder::new_dense();
1671        union_builder.append::<Int32Type>("a", 1).unwrap();
1672        union_builder.append::<Float64Type>("b", 10.1).unwrap();
1673        union_builder.append_null::<Float64Type>("b").unwrap();
1674        let array3 = union_builder.build().unwrap();
1675
1676        let array4 = NullArray::new(3);
1677        let array5 = Float64Array::from(vec![Some(1.1), None, Some(3.3)]);
1678        let array6_values = vec![
1679            Some(vec![Some(10), Some(10), Some(10)]),
1680            Some(vec![Some(20), Some(20), Some(20)]),
1681            Some(vec![Some(30), Some(30)]),
1682        ];
1683        let array6 = ListArray::from_iter_primitive::<Int32Type, _, _>(array6_values);
1684        let array7_values = vec![vec![11, 12, 13], vec![22, 23, 24], vec![33, 34, 35]];
1685        let array7 = FixedSizeBinaryArray::try_from_iter(array7_values.into_iter()).unwrap();
1686
1687        let array8_values = ArrayData::builder(DataType::Int32)
1688            .len(9)
1689            .add_buffer(Buffer::from_slice_ref([40, 41, 42, 43, 44, 45, 46, 47, 48]))
1690            .build()
1691            .unwrap();
1692        let array8_data = ArrayData::builder(schema.field(8).data_type().clone())
1693            .len(3)
1694            .add_child_data(array8_values)
1695            .build()
1696            .unwrap();
1697        let array8 = FixedSizeListArray::from(array8_data);
1698
1699        let array9_id: ArrayRef = Arc::new(Int32Array::from(vec![1001, 1002, 1003]));
1700        let array9_list: ArrayRef =
1701            Arc::new(ListArray::from_iter_primitive::<Int8Type, _, _>(vec![
1702                Some(vec![Some(-10)]),
1703                Some(vec![Some(-20), Some(-20), Some(-20)]),
1704                Some(vec![Some(-30)]),
1705            ]));
1706        let array9 = ArrayDataBuilder::new(schema.field(9).data_type().clone())
1707            .add_child_data(array9_id.into_data())
1708            .add_child_data(array9_list.into_data())
1709            .len(3)
1710            .build()
1711            .unwrap();
1712        let array9: ArrayRef = Arc::new(StructArray::from(array9));
1713
1714        let array10_input = vec![Some(1_i32), None, None];
1715        let mut array10_builder = PrimitiveRunBuilder::<Int16Type, Int32Type>::new();
1716        array10_builder.extend(array10_input);
1717        let array10 = array10_builder.finish();
1718
1719        let array11 = BooleanArray::from(vec![false, false, true]);
1720
1721        let array12_values = StringArray::from(vec!["x", "yy", "zzz"]);
1722        let array12_keys = Int8Array::from_iter_values([1, 1, 2]);
1723        let array12 = DictionaryArray::new(array12_keys, Arc::new(array12_values));
1724
1725        let array13 = StringArray::from(vec!["a", "bb", "ccc"]);
1726
1727        // create record batch
1728        RecordBatch::try_new(
1729            Arc::new(schema.clone()),
1730            vec![
1731                Arc::new(array0),
1732                Arc::new(array1),
1733                Arc::new(array2),
1734                Arc::new(array3),
1735                Arc::new(array4),
1736                Arc::new(array5),
1737                Arc::new(array6),
1738                Arc::new(array7),
1739                Arc::new(array8),
1740                Arc::new(array9),
1741                Arc::new(array10),
1742                Arc::new(array11),
1743                Arc::new(array12),
1744                Arc::new(array13),
1745            ],
1746        )
1747        .unwrap()
1748    }
1749
1750    #[test]
1751    fn test_negative_meta_len_start_stream() {
1752        let bytes = i32::to_le_bytes(-1);
1753        let mut buf = vec![];
1754        buf.extend(CONTINUATION_MARKER);
1755        buf.extend(bytes);
1756
1757        let reader_err = StreamReader::try_new(Cursor::new(buf), None).err();
1758        assert!(reader_err.is_some());
1759        assert_eq!(
1760            reader_err.unwrap().to_string(),
1761            "Parser error: Invalid metadata length: -1"
1762        );
1763    }
1764
1765    #[test]
1766    fn test_negative_meta_len_mid_stream() {
1767        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1768        let mut buf = Vec::new();
1769        {
1770            let mut writer = crate::writer::StreamWriter::try_new(&mut buf, &schema).unwrap();
1771            let batch =
1772                RecordBatch::try_new(Arc::new(schema), vec![Arc::new(Int32Array::from(vec![1]))])
1773                    .unwrap();
1774            writer.write(&batch).unwrap();
1775        }
1776
1777        let bytes = i32::to_le_bytes(-1);
1778        buf.extend(CONTINUATION_MARKER);
1779        buf.extend(bytes);
1780
1781        let mut reader = StreamReader::try_new(Cursor::new(buf), None).unwrap();
1782        // Read the valid value
1783        assert!(reader.maybe_next().is_ok());
1784        // Read the invalid meta len
1785        let batch_err = reader.maybe_next().err();
1786        assert!(batch_err.is_some());
1787        assert_eq!(
1788            batch_err.unwrap().to_string(),
1789            "Parser error: Invalid metadata length: -1"
1790        );
1791    }
1792
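    #[test]
    fn test_stream_ending_without_eos_marker() {
        // Per the IPC streaming format, a stream that simply ends at a message
        // boundary, without the trailing "0xFFFFFFFF 0x00000000" end-of-stream
        // marker, is still valid. This sketch assumes the default (V5) writer
        // terminates the stream with that 8-byte marker and asserts as much
        // before truncating it away.
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(Int32Array::from(vec![1]))])
                .unwrap();
        let mut buf = write_stream(&batch);

        let mut eos = Vec::new();
        eos.extend(CONTINUATION_MARKER);
        eos.extend(0_i32.to_le_bytes());
        assert_eq!(&buf[buf.len() - eos.len()..], eos.as_slice());
        buf.truncate(buf.len() - eos.len());

        let mut reader = StreamReader::try_new(Cursor::new(buf), None).unwrap();
        assert_eq!(reader.next().unwrap().unwrap(), batch);
        assert!(reader.next().is_none());
        assert!(reader.is_finished());
    }
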
1793    #[test]
1794    fn test_projection_array_values() {
1795        // define schema
1796        let schema = create_test_projection_schema();
1797
1798        // create record batch with test data
1799        let batch = create_test_projection_batch_data(&schema);
1800
1801        // write record batch in IPC format
1802        let mut buf = Vec::new();
1803        {
1804            let mut writer = crate::writer::FileWriter::try_new(&mut buf, &schema).unwrap();
1805            writer.write(&batch).unwrap();
1806            writer.finish().unwrap();
1807        }
1808
1809        // read record batch with projection
1810        for index in 0..12 {
1811            let projection = vec![index];
1812            let reader = FileReader::try_new(std::io::Cursor::new(buf.clone()), Some(projection));
1813            let read_batch = reader.unwrap().next().unwrap().unwrap();
1814            let projected_column = read_batch.column(0);
1815            let expected_column = batch.column(index);
1816
1817            // check the projected column equals the expected column
1818            assert_eq!(projected_column.as_ref(), expected_column.as_ref());
1819        }
1820
1821        {
1822            // read record batch with reversed projection
1823            let reader =
1824                FileReader::try_new(std::io::Cursor::new(buf.clone()), Some(vec![3, 2, 1]));
1825            let read_batch = reader.unwrap().next().unwrap().unwrap();
1826            let expected_batch = batch.project(&[3, 2, 1]).unwrap();
1827            assert_eq!(read_batch, expected_batch);
1828        }
1829    }
1830
1831    #[test]
1832    fn test_arrow_single_float_row() {
1833        let schema = Schema::new(vec![
1834            Field::new("a", DataType::Float32, false),
1835            Field::new("b", DataType::Float32, false),
1836            Field::new("c", DataType::Int32, false),
1837            Field::new("d", DataType::Int32, false),
1838        ]);
1839        let arrays = vec![
1840            Arc::new(Float32Array::from(vec![1.23])) as ArrayRef,
1841            Arc::new(Float32Array::from(vec![-6.50])) as ArrayRef,
1842            Arc::new(Int32Array::from(vec![2])) as ArrayRef,
1843            Arc::new(Int32Array::from(vec![1])) as ArrayRef,
1844        ];
1845        let batch = RecordBatch::try_new(Arc::new(schema.clone()), arrays).unwrap();
1846        // create stream writer
1847        let mut file = tempfile::tempfile().unwrap();
1848        let mut stream_writer = crate::writer::StreamWriter::try_new(&mut file, &schema).unwrap();
1849        stream_writer.write(&batch).unwrap();
1850        stream_writer.finish().unwrap();
1851
1852        drop(stream_writer);
1853
1854        file.rewind().unwrap();
1855
1856        // read stream back
1857        let reader = StreamReader::try_new(&mut file, None).unwrap();
1858
1859        reader.for_each(|batch| {
1860            let batch = batch.unwrap();
1861            assert!(
1862                batch
1863                    .column(0)
1864                    .as_any()
1865                    .downcast_ref::<Float32Array>()
1866                    .unwrap()
1867                    .value(0)
1868                    != 0.0
1869            );
1870            assert!(
1871                batch
1872                    .column(1)
1873                    .as_any()
1874                    .downcast_ref::<Float32Array>()
1875                    .unwrap()
1876                    .value(0)
1877                    != 0.0
1878            );
1879        });
1880
1881        file.rewind().unwrap();
1882
1883        // Read with projection
1884        let reader = StreamReader::try_new(file, Some(vec![0, 3])).unwrap();
1885
1886        reader.for_each(|batch| {
1887            let batch = batch.unwrap();
1888            assert_eq!(batch.schema().fields().len(), 2);
1889            assert_eq!(batch.schema().fields()[0].data_type(), &DataType::Float32);
1890            assert_eq!(batch.schema().fields()[1].data_type(), &DataType::Int32);
1891        });
1892    }
1893
1894    /// Write the record batch to an in-memory buffer in IPC File format
1895    fn write_ipc(rb: &RecordBatch) -> Vec<u8> {
1896        let mut buf = Vec::new();
1897        let mut writer = crate::writer::FileWriter::try_new(&mut buf, rb.schema_ref()).unwrap();
1898        writer.write(rb).unwrap();
1899        writer.finish().unwrap();
1900        buf
1901    }
1902
1903    /// Return the first record batch read from the IPC File buffer
1904    fn read_ipc(buf: &[u8]) -> Result<RecordBatch, ArrowError> {
1905        let mut reader = FileReader::try_new(std::io::Cursor::new(buf), None)?;
1906        reader.next().unwrap()
1907    }
1908
1909    /// Return the first record batch read from the IPC File buffer, disabling
1910    /// validation
1911    fn read_ipc_skip_validation(buf: &[u8]) -> Result<RecordBatch, ArrowError> {
1912        let mut reader = unsafe {
1913            FileReader::try_new(std::io::Cursor::new(buf), None)?.with_skip_validation(true)
1914        };
1915        reader.next().unwrap()
1916    }
1917
1918    fn roundtrip_ipc(rb: &RecordBatch) -> RecordBatch {
1919        let buf = write_ipc(rb);
1920        read_ipc(&buf).unwrap()
1921    }
1922
1923    /// Return the first record batch read from the IPC File buffer
1924    /// using the FileDecoder API
1925    fn read_ipc_with_decoder(buf: Vec<u8>) -> Result<RecordBatch, ArrowError> {
1926        read_ipc_with_decoder_inner(buf, false)
1927    }
1928
1929    /// Return the first record batch read from the IPC File buffer
1930    /// using the FileDecoder API, disabling validation
1931    fn read_ipc_with_decoder_skip_validation(buf: Vec<u8>) -> Result<RecordBatch, ArrowError> {
1932        read_ipc_with_decoder_inner(buf, true)
1933    }
1934
1935    fn read_ipc_with_decoder_inner(
1936        buf: Vec<u8>,
1937        skip_validation: bool,
1938    ) -> Result<RecordBatch, ArrowError> {
1939        let buffer = Buffer::from_vec(buf);
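        // the file trailer is the 4-byte little-endian footer length followed
        // by the 6-byte magic bytes "ARROW1"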
1940        let trailer_start = buffer.len() - 10;
1941        let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap())?;
1942        let footer = root_as_footer(&buffer[trailer_start - footer_len..trailer_start])
1943            .map_err(|e| ArrowError::InvalidArgumentError(format!("Invalid footer: {e}")))?;
1944
1945        let schema = fb_to_schema(footer.schema().unwrap());
1946
1947        let mut decoder = unsafe {
1948            FileDecoder::new(Arc::new(schema), footer.version())
1949                .with_skip_validation(skip_validation)
1950        };
1951        // Read dictionaries
1952        for block in footer.dictionaries().iter().flatten() {
1953            let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
1954            let data = buffer.slice_with_length(block.offset() as _, block_len);
1955            decoder.read_dictionary(block, &data)?
1956        }
1957
1958        // Read record batch
1959        let batches = footer.recordBatches().unwrap();
1960        assert_eq!(batches.len(), 1); // Only wrote a single batch
1961
1962        let block = batches.get(0);
1963        let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
1964        let data = buffer.slice_with_length(block.offset() as _, block_len);
1965        Ok(decoder.read_record_batch(block, &data)?.unwrap())
1966    }
1967
1968    /// Write the record batch to an in-memory buffer in IPC Stream format
1969    fn write_stream(rb: &RecordBatch) -> Vec<u8> {
1970        let mut buf = Vec::new();
1971        let mut writer = crate::writer::StreamWriter::try_new(&mut buf, rb.schema_ref()).unwrap();
1972        writer.write(rb).unwrap();
1973        writer.finish().unwrap();
1974        buf
1975    }
1976
1977    /// Return the first record batch read from the IPC Stream buffer
1978    fn read_stream(buf: &[u8]) -> Result<RecordBatch, ArrowError> {
1979        let mut reader = StreamReader::try_new(std::io::Cursor::new(buf), None)?;
1980        reader.next().unwrap()
1981    }
1982
1983    /// Return the first record batch read from the IPC Stream buffer,
1984    /// disabling validation
1985    fn read_stream_skip_validation(buf: &[u8]) -> Result<RecordBatch, ArrowError> {
1986        let mut reader = unsafe {
1987            StreamReader::try_new(std::io::Cursor::new(buf), None)?.with_skip_validation(true)
1988        };
1989        reader.next().unwrap()
1990    }
1991
1992    fn roundtrip_ipc_stream(rb: &RecordBatch) -> RecordBatch {
1993        let buf = write_stream(rb);
1994        read_stream(&buf).unwrap()
1995    }
1996
1997    #[test]
1998    fn test_roundtrip_with_custom_metadata() {
1999        let schema = Schema::new(vec![Field::new("dummy", DataType::Float64, false)]);
2000        let mut buf = Vec::new();
2001        let mut writer = crate::writer::FileWriter::try_new(&mut buf, &schema).unwrap();
2002        let mut test_metadata = HashMap::new();
2003        test_metadata.insert("abc".to_string(), "abc".to_string());
2004        test_metadata.insert("def".to_string(), "def".to_string());
2005        for (k, v) in &test_metadata {
2006            writer.write_metadata(k, v);
2007        }
2008        writer.finish().unwrap();
2009        drop(writer);
2010
2011        let reader = crate::reader::FileReader::try_new(std::io::Cursor::new(buf), None).unwrap();
2012        assert_eq!(reader.custom_metadata(), &test_metadata);
2013    }
2014
2015    #[test]
2016    fn test_roundtrip_nested_dict() {
2017        let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();
2018
2019        let array = Arc::new(inner) as ArrayRef;
2020
2021        let dctfield = Arc::new(Field::new("dict", array.data_type().clone(), false));
2022
2023        let s = StructArray::from(vec![(dctfield, array)]);
2024        let struct_array = Arc::new(s) as ArrayRef;
2025
2026        let schema = Arc::new(Schema::new(vec![Field::new(
2027            "struct",
2028            struct_array.data_type().clone(),
2029            false,
2030        )]));
2031
2032        let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap();
2033
2034        assert_eq!(batch, roundtrip_ipc(&batch));
2035    }
2036
2037    #[test]
2038    fn test_roundtrip_nested_dict_no_preserve_dict_id() {
2039        let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();
2040
2041        let array = Arc::new(inner) as ArrayRef;
2042
2043        let dctfield = Arc::new(Field::new("dict", array.data_type().clone(), false));
2044
2045        let s = StructArray::from(vec![(dctfield, array)]);
2046        let struct_array = Arc::new(s) as ArrayRef;
2047
2048        let schema = Arc::new(Schema::new(vec![Field::new(
2049            "struct",
2050            struct_array.data_type().clone(),
2051            false,
2052        )]));
2053
2054        let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap();
2055
2056        let mut buf = Vec::new();
2057        let mut writer = crate::writer::FileWriter::try_new_with_options(
2058            &mut buf,
2059            batch.schema_ref(),
2060            IpcWriteOptions::default(),
2061        )
2062        .unwrap();
2063        writer.write(&batch).unwrap();
2064        writer.finish().unwrap();
2065        drop(writer);
2066
2067        let mut reader = FileReader::try_new(std::io::Cursor::new(buf), None).unwrap();
2068
2069        assert_eq!(batch, reader.next().unwrap().unwrap());
2070    }
2071
2072    fn check_union_with_builder(mut builder: UnionBuilder) {
2073        builder.append::<Int32Type>("a", 1).unwrap();
2074        builder.append_null::<Int32Type>("a").unwrap();
2075        builder.append::<Float64Type>("c", 3.0).unwrap();
2076        builder.append::<Int32Type>("a", 4).unwrap();
2077        builder.append::<Int64Type>("d", 11).unwrap();
2078        let union = builder.build().unwrap();
2079
2080        let schema = Arc::new(Schema::new(vec![Field::new(
2081            "union",
2082            union.data_type().clone(),
2083            false,
2084        )]));
2085
2086        let union_array = Arc::new(union) as ArrayRef;
2087
2088        let rb = RecordBatch::try_new(schema, vec![union_array]).unwrap();
2089        let rb2 = roundtrip_ipc(&rb);
        // compare the schema, number of columns, and number of rows, then
        // compare the union columns themselves directly
2092        assert_eq!(rb.schema(), rb2.schema());
2093        assert_eq!(rb.num_columns(), rb2.num_columns());
2094        assert_eq!(rb.num_rows(), rb2.num_rows());
2095        let union1 = rb.column(0);
2096        let union2 = rb2.column(0);
2097
2098        assert_eq!(union1, union2);
2099    }
2100
2101    #[test]
2102    fn test_roundtrip_dense_union() {
2103        check_union_with_builder(UnionBuilder::new_dense());
2104    }
2105
2106    #[test]
2107    fn test_roundtrip_sparse_union() {
2108        check_union_with_builder(UnionBuilder::new_sparse());
2109    }
2110
2111    #[test]
2112    fn test_roundtrip_struct_empty_fields() {
2113        let nulls = NullBuffer::from(&[true, true, false]);
2114        let rb = RecordBatch::try_from_iter([(
2115            "",
2116            Arc::new(StructArray::new_empty_fields(nulls.len(), Some(nulls))) as _,
2117        )])
2118        .unwrap();
2119        let rb2 = roundtrip_ipc(&rb);
2120        assert_eq!(rb, rb2);
2121    }
2122
2123    #[test]
2124    fn test_roundtrip_stream_run_array_sliced() {
2125        let run_array_1: Int32RunArray = vec!["a", "a", "a", "b", "b", "c", "c", "c"]
2126            .into_iter()
2127            .collect();
2128        let run_array_1_sliced = run_array_1.slice(2, 5);
2129
        let run_array_2_input = vec![Some(1_i32), None, None, Some(2), Some(2)];
        let mut run_array_2_builder = PrimitiveRunBuilder::<Int16Type, Int32Type>::new();
        run_array_2_builder.extend(run_array_2_input);
2133        let run_array_2 = run_array_2_builder.finish();
2134
2135        let schema = Arc::new(Schema::new(vec![
2136            Field::new(
2137                "run_array_1_sliced",
2138                run_array_1_sliced.data_type().clone(),
2139                false,
2140            ),
2141            Field::new("run_array_2", run_array_2.data_type().clone(), false),
2142        ]));
2143        let input_batch = RecordBatch::try_new(
2144            schema,
2145            vec![Arc::new(run_array_1_sliced.clone()), Arc::new(run_array_2)],
2146        )
2147        .unwrap();
2148        let output_batch = roundtrip_ipc_stream(&input_batch);
2149
        // Since partial comparison is not yet supported for run arrays, the sliced
        // run array has to be unsliced before comparing with the output. The second
        // run array can be compared as-is.
2153        assert_eq!(input_batch.column(1), output_batch.column(1));
2154
2155        let run_array_1_unsliced = unslice_run_array(run_array_1_sliced.into_data()).unwrap();
2156        assert_eq!(run_array_1_unsliced, output_batch.column(0).into_data());
2157    }
2158
2159    #[test]
2160    fn test_roundtrip_stream_nested_dict() {
2161        let xs = vec!["AA", "BB", "AA", "CC", "BB"];
2162        let dict = Arc::new(
2163            xs.clone()
2164                .into_iter()
2165                .collect::<DictionaryArray<Int8Type>>(),
2166        );
2167        let string_array: ArrayRef = Arc::new(StringArray::from(xs.clone()));
2168        let struct_array = StructArray::from(vec![
2169            (
2170                Arc::new(Field::new("f2.1", DataType::Utf8, false)),
2171                string_array,
2172            ),
2173            (
2174                Arc::new(Field::new("f2.2_struct", dict.data_type().clone(), false)),
2175                dict.clone() as ArrayRef,
2176            ),
2177        ]);
2178        let schema = Arc::new(Schema::new(vec![
2179            Field::new("f1_string", DataType::Utf8, false),
2180            Field::new("f2_struct", struct_array.data_type().clone(), false),
2181        ]));
2182        let input_batch = RecordBatch::try_new(
2183            schema,
2184            vec![
2185                Arc::new(StringArray::from(xs.clone())),
2186                Arc::new(struct_array),
2187            ],
2188        )
2189        .unwrap();
2190        let output_batch = roundtrip_ipc_stream(&input_batch);
2191        assert_eq!(input_batch, output_batch);
2192    }
2193
2194    #[test]
2195    fn test_roundtrip_stream_nested_dict_of_map_of_dict() {
2196        let values = StringArray::from(vec![Some("a"), None, Some("b"), Some("c")]);
2197        let values = Arc::new(values) as ArrayRef;
2198        let value_dict_keys = Int8Array::from_iter_values([0, 1, 1, 2, 3, 1]);
2199        let value_dict_array = DictionaryArray::new(value_dict_keys, values.clone());
2200
2201        let key_dict_keys = Int8Array::from_iter_values([0, 0, 2, 1, 1, 3]);
2202        let key_dict_array = DictionaryArray::new(key_dict_keys, values);
2203
2204        #[allow(deprecated)]
2205        let keys_field = Arc::new(Field::new_dict(
2206            "keys",
2207            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
2208            true, // It is technically not legal for this field to be null.
2209            1,
2210            false,
2211        ));
2212        #[allow(deprecated)]
2213        let values_field = Arc::new(Field::new_dict(
2214            "values",
2215            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
2216            true,
2217            2,
2218            false,
2219        ));
2220        let entry_struct = StructArray::from(vec![
2221            (keys_field, make_array(key_dict_array.into_data())),
2222            (values_field, make_array(value_dict_array.into_data())),
2223        ]);
2224        let map_data_type = DataType::Map(
2225            Arc::new(Field::new(
2226                "entries",
2227                entry_struct.data_type().clone(),
2228                false,
2229            )),
2230            false,
2231        );
2232
2233        let entry_offsets = Buffer::from_slice_ref([0, 2, 4, 6]);
2234        let map_data = ArrayData::builder(map_data_type)
2235            .len(3)
2236            .add_buffer(entry_offsets)
2237            .add_child_data(entry_struct.into_data())
2238            .build()
2239            .unwrap();
2240        let map_array = MapArray::from(map_data);
2241
2242        let dict_keys = Int8Array::from_iter_values([0, 1, 1, 2, 2, 1]);
2243        let dict_dict_array = DictionaryArray::new(dict_keys, Arc::new(map_array));
2244
2245        let schema = Arc::new(Schema::new(vec![Field::new(
2246            "f1",
2247            dict_dict_array.data_type().clone(),
2248            false,
2249        )]));
2250        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap();
2251        let output_batch = roundtrip_ipc_stream(&input_batch);
2252        assert_eq!(input_batch, output_batch);
2253    }
2254
2255    fn test_roundtrip_stream_dict_of_list_of_dict_impl<
2256        OffsetSize: OffsetSizeTrait,
2257        U: ArrowNativeType,
2258    >(
2259        list_data_type: DataType,
2260        offsets: &[U; 5],
2261    ) {
2262        let values = StringArray::from(vec![Some("a"), None, Some("c"), None]);
2263        let keys = Int8Array::from_iter_values([0, 0, 1, 2, 0, 1, 3]);
2264        let dict_array = DictionaryArray::new(keys, Arc::new(values));
2265        let dict_data = dict_array.to_data();
2266
2267        let value_offsets = Buffer::from_slice_ref(offsets);
2268
2269        let list_data = ArrayData::builder(list_data_type)
2270            .len(4)
2271            .add_buffer(value_offsets)
2272            .add_child_data(dict_data)
2273            .build()
2274            .unwrap();
2275        let list_array = GenericListArray::<OffsetSize>::from(list_data);
2276
2277        let keys_for_dict = Int8Array::from_iter_values([0, 3, 0, 1, 1, 2, 0, 1, 3]);
2278        let dict_dict_array = DictionaryArray::new(keys_for_dict, Arc::new(list_array));
2279
2280        let schema = Arc::new(Schema::new(vec![Field::new(
2281            "f1",
2282            dict_dict_array.data_type().clone(),
2283            false,
2284        )]));
2285        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap();
2286        let output_batch = roundtrip_ipc_stream(&input_batch);
2287        assert_eq!(input_batch, output_batch);
2288    }
2289
2290    #[test]
2291    fn test_roundtrip_stream_dict_of_list_of_dict() {
2292        // list
2293        #[allow(deprecated)]
2294        let list_data_type = DataType::List(Arc::new(Field::new_dict(
2295            "item",
2296            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
2297            true,
2298            1,
2299            false,
2300        )));
2301        let offsets: &[i32; 5] = &[0, 2, 4, 4, 6];
2302        test_roundtrip_stream_dict_of_list_of_dict_impl::<i32, i32>(list_data_type, offsets);
2303
2304        // large list
2305        #[allow(deprecated)]
2306        let list_data_type = DataType::LargeList(Arc::new(Field::new_dict(
2307            "item",
2308            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
2309            true,
2310            1,
2311            false,
2312        )));
2313        let offsets: &[i64; 5] = &[0, 2, 4, 4, 7];
2314        test_roundtrip_stream_dict_of_list_of_dict_impl::<i64, i64>(list_data_type, offsets);
2315    }
2316
2317    #[test]
2318    fn test_roundtrip_stream_dict_of_fixed_size_list_of_dict() {
2319        let values = StringArray::from(vec![Some("a"), None, Some("c"), None]);
2320        let keys = Int8Array::from_iter_values([0, 0, 1, 2, 0, 1, 3, 1, 2]);
2321        let dict_array = DictionaryArray::new(keys, Arc::new(values));
2322        let dict_data = dict_array.into_data();
2323
2324        #[allow(deprecated)]
2325        let list_data_type = DataType::FixedSizeList(
2326            Arc::new(Field::new_dict(
2327                "item",
2328                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
2329                true,
2330                1,
2331                false,
2332            )),
2333            3,
2334        );
2335        let list_data = ArrayData::builder(list_data_type)
2336            .len(3)
2337            .add_child_data(dict_data)
2338            .build()
2339            .unwrap();
2340        let list_array = FixedSizeListArray::from(list_data);
2341
2342        let keys_for_dict = Int8Array::from_iter_values([0, 1, 0, 1, 1, 2, 0, 1, 2]);
2343        let dict_dict_array = DictionaryArray::new(keys_for_dict, Arc::new(list_array));
2344
2345        let schema = Arc::new(Schema::new(vec![Field::new(
2346            "f1",
2347            dict_dict_array.data_type().clone(),
2348            false,
2349        )]));
2350        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap();
2351        let output_batch = roundtrip_ipc_stream(&input_batch);
2352        assert_eq!(input_batch, output_batch);
2353    }
2354
2355    const LONG_TEST_STRING: &str =
2356        "This is a long string to make sure binary view array handles it";
2357
2358    #[test]
2359    fn test_roundtrip_view_types() {
2360        let schema = Schema::new(vec![
2361            Field::new("field_1", DataType::BinaryView, true),
2362            Field::new("field_2", DataType::Utf8, true),
2363            Field::new("field_3", DataType::Utf8View, true),
2364        ]);
2365        let bin_values: Vec<Option<&[u8]>> = vec![
2366            Some(b"foo"),
2367            None,
2368            Some(b"bar"),
2369            Some(LONG_TEST_STRING.as_bytes()),
2370        ];
2371        let utf8_values: Vec<Option<&str>> =
2372            vec![Some("foo"), None, Some("bar"), Some(LONG_TEST_STRING)];
2373        let bin_view_array = BinaryViewArray::from_iter(bin_values);
2374        let utf8_array = StringArray::from_iter(utf8_values.iter());
2375        let utf8_view_array = StringViewArray::from_iter(utf8_values);
2376        let record_batch = RecordBatch::try_new(
2377            Arc::new(schema.clone()),
2378            vec![
2379                Arc::new(bin_view_array),
2380                Arc::new(utf8_array),
2381                Arc::new(utf8_view_array),
2382            ],
2383        )
2384        .unwrap();
2385
2386        assert_eq!(record_batch, roundtrip_ipc(&record_batch));
2387        assert_eq!(record_batch, roundtrip_ipc_stream(&record_batch));
2388
2389        let sliced_batch = record_batch.slice(1, 2);
2390        assert_eq!(sliced_batch, roundtrip_ipc(&sliced_batch));
2391        assert_eq!(sliced_batch, roundtrip_ipc_stream(&sliced_batch));
2392    }
2393
2394    #[test]
2395    fn test_roundtrip_view_types_nested_dict() {
2396        let bin_values: Vec<Option<&[u8]>> = vec![
2397            Some(b"foo"),
2398            None,
2399            Some(b"bar"),
2400            Some(LONG_TEST_STRING.as_bytes()),
2401            Some(b"field"),
2402        ];
2403        let utf8_values: Vec<Option<&str>> = vec![
2404            Some("foo"),
2405            None,
2406            Some("bar"),
2407            Some(LONG_TEST_STRING),
2408            Some("field"),
2409        ];
2410        let bin_view_array = Arc::new(BinaryViewArray::from_iter(bin_values));
2411        let utf8_view_array = Arc::new(StringViewArray::from_iter(utf8_values));
2412
2413        let key_dict_keys = Int8Array::from_iter_values([0, 0, 1, 2, 0, 1, 3]);
2414        let key_dict_array = DictionaryArray::new(key_dict_keys, utf8_view_array.clone());
2415        #[allow(deprecated)]
2416        let keys_field = Arc::new(Field::new_dict(
2417            "keys",
2418            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8View)),
2419            true,
2420            1,
2421            false,
2422        ));
2423
2424        let value_dict_keys = Int8Array::from_iter_values([0, 3, 0, 1, 2, 0, 1]);
2425        let value_dict_array = DictionaryArray::new(value_dict_keys, bin_view_array);
2426        #[allow(deprecated)]
2427        let values_field = Arc::new(Field::new_dict(
2428            "values",
2429            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::BinaryView)),
2430            true,
2431            2,
2432            false,
2433        ));
2434        let entry_struct = StructArray::from(vec![
2435            (keys_field, make_array(key_dict_array.into_data())),
2436            (values_field, make_array(value_dict_array.into_data())),
2437        ]);
2438
2439        let map_data_type = DataType::Map(
2440            Arc::new(Field::new(
2441                "entries",
2442                entry_struct.data_type().clone(),
2443                false,
2444            )),
2445            false,
2446        );
2447        let entry_offsets = Buffer::from_slice_ref([0, 2, 4, 7]);
2448        let map_data = ArrayData::builder(map_data_type)
2449            .len(3)
2450            .add_buffer(entry_offsets)
2451            .add_child_data(entry_struct.into_data())
2452            .build()
2453            .unwrap();
2454        let map_array = MapArray::from(map_data);
2455
2456        let dict_keys = Int8Array::from_iter_values([0, 1, 0, 1, 1, 2, 0, 1, 2]);
2457        let dict_dict_array = DictionaryArray::new(dict_keys, Arc::new(map_array));
2458        let schema = Arc::new(Schema::new(vec![Field::new(
2459            "f1",
2460            dict_dict_array.data_type().clone(),
2461            false,
2462        )]));
2463        let batch = RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap();
2464        assert_eq!(batch, roundtrip_ipc(&batch));
2465        assert_eq!(batch, roundtrip_ipc_stream(&batch));
2466
2467        let sliced_batch = batch.slice(1, 2);
2468        assert_eq!(sliced_batch, roundtrip_ipc(&sliced_batch));
2469        assert_eq!(sliced_batch, roundtrip_ipc_stream(&sliced_batch));
2470    }
2471
2472    #[test]
2473    fn test_no_columns_batch() {
2474        let schema = Arc::new(Schema::empty());
2475        let options = RecordBatchOptions::new()
2476            .with_match_field_names(true)
2477            .with_row_count(Some(10));
2478        let input_batch = RecordBatch::try_new_with_options(schema, vec![], &options).unwrap();
2479        let output_batch = roundtrip_ipc_stream(&input_batch);
2480        assert_eq!(input_batch, output_batch);
2481    }
2482
2483    #[test]
2484    fn test_unaligned() {
2485        let batch = RecordBatch::try_from_iter(vec![(
2486            "i32",
2487            Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as _,
2488        )])
2489        .unwrap();
2490
2491        let gen = IpcDataGenerator {};
2492        let mut dict_tracker = DictionaryTracker::new(false);
2493        let (_, encoded) = gen
2494            .encoded_batch(&batch, &mut dict_tracker, &Default::default())
2495            .unwrap();
2496
2497        let message = root_as_message(&encoded.ipc_message).unwrap();
2498
2499        // Construct an unaligned buffer
2500        let mut buffer = MutableBuffer::with_capacity(encoded.arrow_data.len() + 1);
2501        buffer.push(0_u8);
2502        buffer.extend_from_slice(&encoded.arrow_data);
2503        let b = Buffer::from(buffer).slice(1);
2504        assert_ne!(b.as_ptr().align_offset(8), 0);
2505
2506        let ipc_batch = message.header_as_record_batch().unwrap();
2507        let roundtrip = RecordBatchDecoder::try_new(
2508            &b,
2509            ipc_batch,
2510            batch.schema(),
2511            &Default::default(),
2512            &message.version(),
2513        )
2514        .unwrap()
2515        .with_require_alignment(false)
2516        .read_record_batch()
2517        .unwrap();
2518        assert_eq!(batch, roundtrip);
2519    }
2520
2521    #[test]
2522    fn test_unaligned_throws_error_with_require_alignment() {
2523        let batch = RecordBatch::try_from_iter(vec![(
2524            "i32",
2525            Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as _,
2526        )])
2527        .unwrap();
2528
2529        let gen = IpcDataGenerator {};
2530        let mut dict_tracker = DictionaryTracker::new(false);
2531        let (_, encoded) = gen
2532            .encoded_batch(&batch, &mut dict_tracker, &Default::default())
2533            .unwrap();
2534
2535        let message = root_as_message(&encoded.ipc_message).unwrap();
2536
2537        // Construct an unaligned buffer
2538        let mut buffer = MutableBuffer::with_capacity(encoded.arrow_data.len() + 1);
2539        buffer.push(0_u8);
2540        buffer.extend_from_slice(&encoded.arrow_data);
2541        let b = Buffer::from(buffer).slice(1);
2542        assert_ne!(b.as_ptr().align_offset(8), 0);
2543
2544        let ipc_batch = message.header_as_record_batch().unwrap();
2545        let result = RecordBatchDecoder::try_new(
2546            &b,
2547            ipc_batch,
2548            batch.schema(),
2549            &Default::default(),
2550            &message.version(),
2551        )
2552        .unwrap()
2553        .with_require_alignment(true)
2554        .read_record_batch();
2555
2556        let error = result.unwrap_err();
2557        assert_eq!(
2558            error.to_string(),
2559            "Invalid argument error: Misaligned buffers[0] in array of type Int32, \
2560             offset from expected alignment of 4 by 1"
2561        );
2562    }
2563
2564    #[test]
2565    fn test_file_with_massive_column_count() {
        // 499_999 columns is the upper limit for the default footer flatbuffer
        // verification settings (1_000_000 tables), so 600_000 requires a higher limit
2567        let limit = 600_000;
2568
2569        let fields = (0..limit)
2570            .map(|i| Field::new(format!("{i}"), DataType::Boolean, false))
2571            .collect::<Vec<_>>();
2572        let schema = Arc::new(Schema::new(fields));
2573        let batch = RecordBatch::new_empty(schema);
2574
2575        let mut buf = Vec::new();
2576        let mut writer = crate::writer::FileWriter::try_new(&mut buf, batch.schema_ref()).unwrap();
2577        writer.write(&batch).unwrap();
2578        writer.finish().unwrap();
2579        drop(writer);
2580
2581        let mut reader = FileReaderBuilder::new()
2582            .with_max_footer_fb_tables(1_500_000)
2583            .build(std::io::Cursor::new(buf))
2584            .unwrap();
2585        let roundtrip_batch = reader.next().unwrap().unwrap();
2586
2587        assert_eq!(batch, roundtrip_batch);
2588    }
2589
2590    #[test]
2591    fn test_file_with_deeply_nested_columns() {
        // 60 levels of nesting is the upper limit for the default footer
        // flatbuffer verification settings (depth 64), so 61 requires a higher limit
2593        let limit = 61;
2594
2595        let fields = (0..limit).fold(
2596            vec![Field::new("leaf", DataType::Boolean, false)],
2597            |field, index| vec![Field::new_struct(format!("{index}"), field, false)],
2598        );
2599        let schema = Arc::new(Schema::new(fields));
2600        let batch = RecordBatch::new_empty(schema);
2601
2602        let mut buf = Vec::new();
2603        let mut writer = crate::writer::FileWriter::try_new(&mut buf, batch.schema_ref()).unwrap();
2604        writer.write(&batch).unwrap();
2605        writer.finish().unwrap();
2606        drop(writer);
2607
2608        let mut reader = FileReaderBuilder::new()
2609            .with_max_footer_fb_depth(65)
2610            .build(std::io::Cursor::new(buf))
2611            .unwrap();
2612        let roundtrip_batch = reader.next().unwrap().unwrap();
2613
2614        assert_eq!(batch, roundtrip_batch);
2615    }
2616
2617    #[test]
2618    fn test_invalid_struct_array_ipc_read_errors() {
2619        let a_field = Field::new("a", DataType::Int32, false);
2620        let b_field = Field::new("b", DataType::Int32, false);
2621        let struct_fields = Fields::from(vec![a_field.clone(), b_field.clone()]);
2622
2623        let a_array_data = ArrayData::builder(a_field.data_type().clone())
2624            .len(4)
2625            .add_buffer(Buffer::from_slice_ref([1, 2, 3, 4]))
2626            .build()
2627            .unwrap();
2628        let b_array_data = ArrayData::builder(b_field.data_type().clone())
2629            .len(3)
2630            .add_buffer(Buffer::from_slice_ref([5, 6, 7]))
2631            .build()
2632            .unwrap();
2633
2634        let invalid_struct_arr = unsafe {
2635            StructArray::new_unchecked(
2636                struct_fields,
2637                vec![make_array(a_array_data), make_array(b_array_data)],
2638                None,
2639            )
2640        };
2641
2642        expect_ipc_validation_error(
2643            Arc::new(invalid_struct_arr),
2644            "Invalid argument error: Incorrect array length for StructArray field \"b\", expected 4 got 3",
2645        );
2646    }
2647
2648    #[test]
2649    fn test_invalid_nested_array_ipc_read_errors() {
2650        // one of the nested arrays has invalid data
2651        let a_field = Field::new("a", DataType::Int32, false);
2652        let b_field = Field::new("b", DataType::Utf8, false);
2653
2654        let schema = Arc::new(Schema::new(vec![Field::new_struct(
2655            "s",
2656            vec![a_field.clone(), b_field.clone()],
2657            false,
2658        )]));
2659
2660        let a_array_data = ArrayData::builder(a_field.data_type().clone())
2661            .len(4)
2662            .add_buffer(Buffer::from_slice_ref([1, 2, 3, 4]))
2663            .build()
2664            .unwrap();
2665        // invalid nested child array -- length is correct, but has invalid utf8 data
2666        let b_array_data = {
2667            let valid: &[u8] = b"   ";
2668            let mut invalid = vec![];
2669            invalid.extend_from_slice(b"ValidString");
2670            invalid.extend_from_slice(INVALID_UTF8_FIRST_CHAR);
2671            let binary_array =
2672                BinaryArray::from_iter(vec![None, Some(valid), None, Some(&invalid)]);
2673            let array = unsafe {
2674                StringArray::new_unchecked(
2675                    binary_array.offsets().clone(),
2676                    binary_array.values().clone(),
2677                    binary_array.nulls().cloned(),
2678                )
2679            };
2680            array.into_data()
2681        };
2682        let struct_data_type = schema.field(0).data_type();
2683
2684        let invalid_struct_arr = unsafe {
2685            make_array(
2686                ArrayData::builder(struct_data_type.clone())
2687                    .len(4)
2688                    .add_child_data(a_array_data)
2689                    .add_child_data(b_array_data)
2690                    .build_unchecked(),
2691            )
2692        };
2693        expect_ipc_validation_error(
2694            Arc::new(invalid_struct_arr),
2695            "Invalid argument error: Invalid UTF8 sequence at string index 3 (3..18): invalid utf-8 sequence of 1 bytes from index 11",
2696        );
2697    }
2698
2699    #[test]
2700    fn test_same_dict_id_without_preserve() {
2701        let batch = RecordBatch::try_new(
2702            Arc::new(Schema::new(
2703                ["a", "b"]
2704                    .iter()
2705                    .map(|name| {
2706                        #[allow(deprecated)]
2707                        Field::new_dict(
2708                            name.to_string(),
2709                            DataType::Dictionary(
2710                                Box::new(DataType::Int32),
2711                                Box::new(DataType::Utf8),
2712                            ),
2713                            true,
2714                            0,
2715                            false,
2716                        )
2717                    })
2718                    .collect::<Vec<Field>>(),
2719            )),
2720            vec![
2721                Arc::new(
2722                    vec![Some("c"), Some("d")]
2723                        .into_iter()
2724                        .collect::<DictionaryArray<Int32Type>>(),
2725                ) as ArrayRef,
2726                Arc::new(
2727                    vec![Some("e"), Some("f")]
2728                        .into_iter()
2729                        .collect::<DictionaryArray<Int32Type>>(),
2730                ) as ArrayRef,
2731            ],
2732        )
2733        .expect("Failed to create RecordBatch");
2734
2735        // serialize the record batch as an IPC stream
2736        let mut buf = vec![];
2737        {
2738            let mut writer = crate::writer::StreamWriter::try_new_with_options(
2739                &mut buf,
2740                batch.schema().as_ref(),
2741                crate::writer::IpcWriteOptions::default(),
2742            )
2743            .expect("Failed to create StreamWriter");
2744            writer.write(&batch).expect("Failed to write RecordBatch");
2745            writer.finish().expect("Failed to finish StreamWriter");
2746        }
2747
2748        StreamReader::try_new(std::io::Cursor::new(buf), None)
2749            .expect("Failed to create StreamReader")
2750            .for_each(|decoded_batch| {
2751                assert_eq!(decoded_batch.expect("Failed to read RecordBatch"), batch);
2752            });
2753    }

    #[test]
    fn test_validation_of_invalid_list_array() {
        // ListArray with invalid offsets
        let array = unsafe {
            let values = Int32Array::from(vec![1, 2, 3]);
            let bad_offsets = ScalarBuffer::<i32>::from(vec![0, 2, 4, 2]); // offsets must not decrease
            let offsets = OffsetBuffer::new_unchecked(bad_offsets); // bypasses validation: INVALID array
            let field = Field::new_list_field(DataType::Int32, true);
            let nulls = None;
            ListArray::new(Arc::new(field), offsets, Arc::new(values), nulls)
        };

        expect_ipc_validation_error(
            Arc::new(array),
            "Invalid argument error: Offset invariant failure: offset at position 2 out of bounds: 4 > 2"
        );
    }
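
    /// Companion sketch added for illustration (an assumption, not part of the
    /// original suite): the checked constructor `OffsetBuffer::new` panics on
    /// the same non-monotonic offsets, which is why the test above must use
    /// `new_unchecked` inside `unsafe` to build the invalid array at all.
    #[test]
    #[should_panic]
    fn test_checked_offset_buffer_rejects_backwards_offsets() {
        let bad_offsets = ScalarBuffer::<i32>::from(vec![0, 2, 4, 2]);
        // `OffsetBuffer::new` checks that offsets are monotonically increasing
        let _ = OffsetBuffer::new(bad_offsets);
    }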

    #[test]
    fn test_validation_of_invalid_string_array() {
        let valid: &[u8] = b"   ";
        let mut invalid = vec![];
        invalid.extend_from_slice(b"ThisStringIsCertainlyLongerThan12Bytes");
        invalid.extend_from_slice(INVALID_UTF8_FIRST_CHAR);
        let binary_array = BinaryArray::from_iter(vec![None, Some(valid), None, Some(&invalid)]);
        // The data is not valid UTF-8, so we cannot safely construct a correct
        // StringArray; purposely create an invalid one instead
        let array = unsafe {
            StringArray::new_unchecked(
                binary_array.offsets().clone(),
                binary_array.values().clone(),
                binary_array.nulls().cloned(),
            )
        };
        expect_ipc_validation_error(
            Arc::new(array),
            "Invalid argument error: Invalid UTF8 sequence at string index 3 (3..45): invalid utf-8 sequence of 1 bytes from index 38"
        );
    }
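
    /// Illustrative sketch (an addition, not part of the original suite):
    /// the same invalid UTF-8 is assumed to be what full `ArrayData`
    /// validation catches when the IPC readers check decoded data, so
    /// validating the array's data directly should fail as well.
    #[test]
    fn test_invalid_string_array_fails_validate_full() {
        let mut invalid = vec![];
        invalid.extend_from_slice(b"ThisStringIsCertainlyLongerThan12Bytes");
        invalid.extend_from_slice(INVALID_UTF8_FIRST_CHAR);
        let binary_array = BinaryArray::from_iter(vec![Some(&invalid)]);
        let array = unsafe {
            StringArray::new_unchecked(
                binary_array.offsets().clone(),
                binary_array.values().clone(),
                binary_array.nulls().cloned(),
            )
        };
        // Full validation re-checks the UTF-8 invariant and should reject the data
        assert!(array.into_data().validate_full().is_err());
    }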

    #[test]
    fn test_validation_of_invalid_string_view_array() {
        let valid: &[u8] = b"   ";
        let mut invalid = vec![];
        invalid.extend_from_slice(b"ThisStringIsCertainlyLongerThan12Bytes");
        invalid.extend_from_slice(INVALID_UTF8_FIRST_CHAR);
        let binary_view_array =
            BinaryViewArray::from_iter(vec![None, Some(valid), None, Some(&invalid)]);
        // The data is not valid UTF-8, so we cannot safely construct a correct
        // StringViewArray; purposely create an invalid one instead
        let array = unsafe {
            StringViewArray::new_unchecked(
                binary_view_array.views().clone(),
                binary_view_array.data_buffers().to_vec(),
                binary_view_array.nulls().cloned(),
            )
        };
        expect_ipc_validation_error(
            Arc::new(array),
            "Invalid argument error: Encountered non-UTF-8 data at index 3: invalid utf-8 sequence of 1 bytes from index 38"
        );
    }

    /// DictionaryArray with an invalid key (out of bounds for the values)
    #[test]
    fn test_validation_of_invalid_dictionary_array() {
        let array = unsafe {
            let values = StringArray::from_iter_values(["a", "b", "c"]);
            let keys = Int32Array::from(vec![1, 200]); // key 200 is not valid for 3 values
            DictionaryArray::new_unchecked(keys, Arc::new(values))
        };

        expect_ipc_validation_error(
            Arc::new(array),
            "Invalid argument error: Value at position 1 out of bounds: 200 (should be in [0, 2])",
        );
    }
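
    /// Companion sketch added for illustration (an assumption, not part of the
    /// original suite): the checked constructor `DictionaryArray::try_new` is
    /// expected to reject the same out-of-bounds key, so the invalid array
    /// above can only be produced via `new_unchecked`.
    #[test]
    fn test_checked_dictionary_constructor_rejects_bad_keys() {
        let values = StringArray::from_iter_values(["a", "b", "c"]);
        let keys = Int32Array::from(vec![1, 200]); // 200 is out of bounds
        assert!(DictionaryArray::<Int32Type>::try_new(keys, Arc::new(values)).is_err());
    }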

    #[test]
    fn test_validation_of_invalid_union_array() {
        let array = unsafe {
            let fields = UnionFields::new(
                vec![1, 3], // declared type ids: 1 and 3 are valid, 2 is not
                vec![
                    Field::new("a", DataType::Int32, false),
                    Field::new("b", DataType::Utf8, false),
                ],
            );
            let type_ids = ScalarBuffer::from(vec![1i8, 2, 3]); // 2 is invalid
            let offsets = None;
            let children: Vec<ArrayRef> = vec![
                Arc::new(Int32Array::from(vec![10, 20, 30])),
                Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("c")])),
            ];

            UnionArray::new_unchecked(fields, type_ids, offsets, children)
        };

        expect_ipc_validation_error(
            Arc::new(array),
            "Invalid argument error: Type Ids values must match one of the field type ids",
        );
    }

    /// Invalid UTF-8 sequence: the first byte (0xa0) is a continuation byte
    /// and therefore cannot start a character.
    /// <https://stackoverflow.com/questions/1301402/example-invalid-utf8-string>
    const INVALID_UTF8_FIRST_CHAR: &[u8] = &[0xa0, 0xa1, 0x20, 0x20];
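
    /// Sanity-check sketch added for illustration (not part of the original
    /// suite): the standard library's UTF-8 validation rejects these bytes too.
    #[test]
    fn test_invalid_utf8_first_char_rejected_by_std() {
        assert!(std::str::from_utf8(INVALID_UTF8_FIRST_CHAR).is_err());
    }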

    /// Expect an error when reading the record batch back with validation
    /// enabled via the IPC stream format, the IPC file format, and the
    /// `FileDecoder` API; reading with validation skipped must succeed.
    fn expect_ipc_validation_error(array: ArrayRef, expected_err: &str) {
        let rb = RecordBatch::try_from_iter([("a", array)]).unwrap();

        // IPC Stream format
        let buf = write_stream(&rb); // write is ok
        read_stream_skip_validation(&buf).unwrap();
        let err = read_stream(&buf).unwrap_err();
        assert_eq!(err.to_string(), expected_err);

        // IPC File format
        let buf = write_ipc(&rb); // write is ok
        read_ipc_skip_validation(&buf).unwrap();
        let err = read_ipc(&buf).unwrap_err();
        assert_eq!(err.to_string(), expected_err);

        // IPC File format read via FileDecoder
        read_ipc_with_decoder_skip_validation(buf.clone()).unwrap();
        let err = read_ipc_with_decoder(buf).unwrap_err();
        assert_eq!(err.to_string(), expected_err);
    }

    #[test]
    fn test_roundtrip_schema() {
        let schema = Schema::new(vec![
            Field::new(
                "a",
                DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
                false,
            ),
            Field::new(
                "b",
                DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
                false,
            ),
        ]);

        let options = IpcWriteOptions::default();
        let data_gen = IpcDataGenerator::default();
        let mut dict_tracker = DictionaryTracker::new(false);
        let encoded_data =
            data_gen.schema_to_bytes_with_dictionary_tracker(&schema, &mut dict_tracker, &options);
        let mut schema_bytes = vec![];
        write_message(&mut schema_bytes, encoded_data, &options).expect("write_message");

        // Skip the continuation marker, if one was written
        let begin_offset: usize = if schema_bytes[0..4].eq(&CONTINUATION_MARKER) {
            4
        } else {
            0
        };

        size_prefixed_root_as_message(&schema_bytes[begin_offset..])
            .expect_err("size_prefixed_root_as_message");

        let msg = parse_message(&schema_bytes).expect("parse_message");
        let ipc_schema = msg.header_as_schema().expect("header_as_schema");
        let new_schema = fb_to_schema(ipc_schema);

        assert_eq!(schema, new_schema);
    }
}