mod stream;

pub use stream::*;

use flatbuffers::{VectorIter, VerifierOptions};
use std::collections::{HashMap, VecDeque};
use std::fmt;
use std::io::{BufReader, Read, Seek, SeekFrom};
use std::sync::Arc;

use arrow_array::*;
use arrow_buffer::{ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, ScalarBuffer};
use arrow_data::{ArrayData, ArrayDataBuilder, UnsafeFlag};
use arrow_schema::*;

use crate::compression::CompressionCodec;
use crate::{Block, FieldNode, Message, MetadataVersion, CONTINUATION_MARKER};
use DataType::*;

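/// Reads a single buffer from the flattened IPC message body `a_data`, using
/// the offset and length recorded in the buffer descriptor `buf`, and
/// decompresses it with `compression_codec` if one is set.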
fn read_buffer(
    buf: &crate::Buffer,
    a_data: &Buffer,
    compression_codec: Option<CompressionCodec>,
) -> Result<Buffer, ArrowError> {
    let start_offset = buf.offset() as usize;
    let buf_data = a_data.slice_with_length(start_offset, buf.length() as usize);
    match (buf_data.is_empty(), compression_codec) {
        // Empty buffers, and buffers of uncompressed messages, are used as-is
        (true, _) | (_, None) => Ok(buf_data),
        (false, Some(decompressor)) => decompressor.decompress_to_buffer(&buf_data),
    }
}

impl RecordBatchDecoder<'_> {
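    /// Coordinates reading arrays based on data types.
    ///
    /// `variadic_counts` yields the number of variadic data buffers for each
    /// view-type column (`BinaryView`/`Utf8View`), in schema order. Nested
    /// types are decoded recursively, consuming field nodes and buffers in
    /// the order they were written.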
    fn create_array(
        &mut self,
        field: &Field,
        variadic_counts: &mut VecDeque<i64>,
    ) -> Result<ArrayRef, ArrowError> {
        let data_type = field.data_type();
        match data_type {
            Utf8 | Binary | LargeBinary | LargeUtf8 => {
                let field_node = self.next_node(field)?;
                let buffers = [
                    self.next_buffer()?,
                    self.next_buffer()?,
                    self.next_buffer()?,
                ];
                self.create_primitive_array(field_node, data_type, &buffers)
            }
            BinaryView | Utf8View => {
                let count = variadic_counts
                    .pop_front()
                    .ok_or(ArrowError::IpcError(format!(
                        "Missing variadic count for {data_type} column"
                    )))?;
                // Plus two for the validity buffer and the views buffer
                let count = count + 2;
                let buffers = (0..count)
                    .map(|_| self.next_buffer())
                    .collect::<Result<Vec<_>, _>>()?;
                let field_node = self.next_node(field)?;
                self.create_primitive_array(field_node, data_type, &buffers)
            }
            FixedSizeBinary(_) => {
                let field_node = self.next_node(field)?;
                let buffers = [self.next_buffer()?, self.next_buffer()?];
                self.create_primitive_array(field_node, data_type, &buffers)
            }
            List(ref list_field) | LargeList(ref list_field) | Map(ref list_field, _) => {
                let list_node = self.next_node(field)?;
                let list_buffers = [self.next_buffer()?, self.next_buffer()?];
                let values = self.create_array(list_field, variadic_counts)?;
                self.create_list_array(list_node, data_type, &list_buffers, values)
            }
            FixedSizeList(ref list_field, _) => {
                let list_node = self.next_node(field)?;
                let list_buffers = [self.next_buffer()?];
                let values = self.create_array(list_field, variadic_counts)?;
                self.create_list_array(list_node, data_type, &list_buffers, values)
            }
            Struct(struct_fields) => {
                let struct_node = self.next_node(field)?;
                let null_buffer = self.next_buffer()?;

                // Read the child array for each field
                let mut struct_arrays = vec![];
                for struct_field in struct_fields {
                    let child = self.create_array(struct_field, variadic_counts)?;
                    struct_arrays.push(child);
                }
                self.create_struct_array(struct_node, null_buffer, struct_fields, struct_arrays)
            }
            RunEndEncoded(run_ends_field, values_field) => {
                let run_node = self.next_node(field)?;
                let run_ends = self.create_array(run_ends_field, variadic_counts)?;
                let values = self.create_array(values_field, variadic_counts)?;

                let run_array_length = run_node.length() as usize;
                let builder = ArrayData::builder(data_type.clone())
                    .len(run_array_length)
                    .offset(0)
                    .add_child_data(run_ends.into_data())
                    .add_child_data(values.into_data());
                self.create_array_from_builder(builder)
            }
            Dictionary(_, _) => {
                let index_node = self.next_node(field)?;
                let index_buffers = [self.next_buffer()?, self.next_buffer()?];

                #[allow(deprecated)]
                let dict_id = field.dict_id().ok_or_else(|| {
                    ArrowError::ParseError(format!("Field {field} does not have dict id"))
                })?;

                let value_array = self.dictionaries_by_id.get(&dict_id).ok_or_else(|| {
                    ArrowError::ParseError(format!(
                        "Cannot find a dictionary batch with dict id: {dict_id}"
                    ))
                })?;

                self.create_dictionary_array(
                    index_node,
                    data_type,
                    &index_buffers,
                    value_array.clone(),
                )
            }
            Union(fields, mode) => {
                let union_node = self.next_node(field)?;
                let len = union_node.length() as usize;

                // In V4, union types carry a validity bitmap that must be
                // skipped; in V5 and later they do not
                if self.version < MetadataVersion::V5 {
                    self.next_buffer()?;
                }

                let type_ids: ScalarBuffer<i8> =
                    self.next_buffer()?.slice_with_length(0, len).into();

                let value_offsets = match mode {
                    UnionMode::Dense => {
                        let offsets: ScalarBuffer<i32> =
                            self.next_buffer()?.slice_with_length(0, len * 4).into();
                        Some(offsets)
                    }
                    UnionMode::Sparse => None,
                };

                let mut children = Vec::with_capacity(fields.len());

                for (_id, field) in fields.iter() {
                    let child = self.create_array(field, variadic_counts)?;
                    children.push(child);
                }

                let array = if self.skip_validation.get() {
                    // SAFETY: the skip_validation flag can only be set via
                    // `unsafe` APIs, so the caller has vouched for the data
                    unsafe {
                        UnionArray::new_unchecked(fields.clone(), type_ids, value_offsets, children)
                    }
                } else {
                    UnionArray::try_new(fields.clone(), type_ids, value_offsets, children)?
                };
                Ok(Arc::new(array))
            }
            Null => {
                let node = self.next_node(field)?;
                let length = node.length();
                let null_count = node.null_count();

                if length != null_count {
                    return Err(ArrowError::SchemaError(format!(
                        "Field {field} of NullArray has unequal null_count {null_count} and len {length}"
                    )));
                }

                let builder = ArrayData::builder(data_type.clone())
                    .len(length as usize)
                    .offset(0);
                self.create_array_from_builder(builder)
            }
            _ => {
                let field_node = self.next_node(field)?;
                let buffers = [self.next_buffer()?, self.next_buffer()?];
                self.create_primitive_array(field_node, data_type, &buffers)
            }
        }
    }

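    /// Reads an array backed by a fixed set of buffers: validity plus offsets
    /// and data for variable-width types, validity plus values for the rest.
    /// `buffers[0]` is always the validity bitmap.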
    fn create_primitive_array(
        &self,
        field_node: &FieldNode,
        data_type: &DataType,
        buffers: &[Buffer],
    ) -> Result<ArrayRef, ArrowError> {
        let length = field_node.length() as usize;
        let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
        let builder = match data_type {
            Utf8 | Binary | LargeBinary | LargeUtf8 => {
                // Variable-width types: offsets buffer followed by a data buffer
                ArrayData::builder(data_type.clone())
                    .len(length)
                    .buffers(buffers[1..3].to_vec())
                    .null_bit_buffer(null_buffer)
            }
            BinaryView | Utf8View => ArrayData::builder(data_type.clone())
                .len(length)
                .buffers(buffers[1..].to_vec())
                .null_bit_buffer(null_buffer),
            _ if data_type.is_primitive() || matches!(data_type, Boolean | FixedSizeBinary(_)) => {
                // Fixed-width types: a single values buffer
                ArrayData::builder(data_type.clone())
                    .len(length)
                    .add_buffer(buffers[1].clone())
                    .null_bit_buffer(null_buffer)
            }
            t => unreachable!("Data type {:?} either unsupported or not primitive", t),
        };

        self.create_array_from_builder(builder)
    }

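    /// Builds an [`ArrayRef`] from an [`ArrayDataBuilder`], applying the
    /// decoder's buffer-alignment and validation settings.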
    fn create_array_from_builder(&self, builder: ArrayDataBuilder) -> Result<ArrayRef, ArrowError> {
        let mut builder = builder.align_buffers(!self.require_alignment);
        if self.skip_validation.get() {
            // SAFETY: the skip_validation flag can only be set via `unsafe`
            // APIs, which place the burden of ensuring the data is valid on
            // the caller
            unsafe { builder = builder.skip_validation(true) };
        }
        Ok(make_array(builder.build()?))
    }

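    /// Reads the list buffers and the already-decoded `child_array` into a
    /// list-like array. For `List`, `LargeList` and `Map`, `buffers[1]` holds
    /// the offsets; `FixedSizeList` carries no offsets buffer.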
    fn create_list_array(
        &self,
        field_node: &FieldNode,
        data_type: &DataType,
        buffers: &[Buffer],
        child_array: ArrayRef,
    ) -> Result<ArrayRef, ArrowError> {
        let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
        let length = field_node.length() as usize;
        let child_data = child_array.into_data();
        let builder = match data_type {
            List(_) | LargeList(_) | Map(_, _) => ArrayData::builder(data_type.clone())
                .len(length)
                .add_buffer(buffers[1].clone())
                .add_child_data(child_data)
                .null_bit_buffer(null_buffer),

            FixedSizeList(_, _) => ArrayData::builder(data_type.clone())
                .len(length)
                .add_child_data(child_data)
                .null_bit_buffer(null_buffer),

            _ => unreachable!("Cannot create list or map array from {:?}", data_type),
        };

        self.create_array_from_builder(builder)
    }

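    /// Assembles a [`StructArray`] from its decoded child arrays and the
    /// struct-level validity buffer.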
    fn create_struct_array(
        &self,
        struct_node: &FieldNode,
        null_buffer: Buffer,
        struct_fields: &Fields,
        struct_arrays: Vec<ArrayRef>,
    ) -> Result<ArrayRef, ArrowError> {
        let null_count = struct_node.null_count() as usize;
        let len = struct_node.length() as usize;

        let nulls = (null_count > 0).then(|| BooleanBuffer::new(null_buffer, 0, len).into());
        if struct_arrays.is_empty() {
            // A struct with no fields has no children to infer its length
            // from, so it needs an explicit length
            return Ok(Arc::new(StructArray::new_empty_fields(len, nulls)));
        }

        let struct_array = if self.skip_validation.get() {
            // SAFETY: the skip_validation flag can only be set via `unsafe`
            // APIs, so the caller has vouched for the data
            unsafe { StructArray::new_unchecked(struct_fields.clone(), struct_arrays, nulls) }
        } else {
            StructArray::try_new(struct_fields.clone(), struct_arrays, nulls)?
        };

        Ok(Arc::new(struct_array))
    }

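    /// Reads the dictionary keys from `buffers` and combines them with the
    /// dictionary values previously registered for this field's dictionary id.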
    fn create_dictionary_array(
        &self,
        field_node: &FieldNode,
        data_type: &DataType,
        buffers: &[Buffer],
        value_array: ArrayRef,
    ) -> Result<ArrayRef, ArrowError> {
        if let Dictionary(_, _) = *data_type {
            let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
            let builder = ArrayData::builder(data_type.clone())
                .len(field_node.length() as usize)
                .add_buffer(buffers[1].clone())
                .add_child_data(value_array.into_data())
                .null_bit_buffer(null_buffer);
            self.create_array_from_builder(builder)
        } else {
            unreachable!("Cannot create dictionary array from {:?}", data_type)
        }
    }
}

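/// Decodes an IPC [`crate::RecordBatch`] message into a [`RecordBatch`],
/// tracking cursors over the message's field nodes and buffers as the arrays
/// are read.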
struct RecordBatchDecoder<'a> {
    /// The flatbuffers-encoded record batch message to decode
    batch: crate::RecordBatch<'a>,
    /// The schema of the output batch
    schema: SchemaRef,
    /// Previously decoded dictionary values, keyed by dictionary id
    dictionaries_by_id: &'a HashMap<i64, ArrayRef>,
    /// Optional codec used to decompress buffers
    compression: Option<CompressionCodec>,
    /// The IPC metadata version of the message
    version: MetadataVersion,
    /// The message body containing the buffers
    data: &'a Buffer,
    /// Cursor over the field nodes of the message
    nodes: VectorIter<'a, FieldNode>,
    /// Cursor over the buffer descriptors of the message
    buffers: VectorIter<'a, crate::Buffer>,
    /// Optional projection: the indices of the columns to read, in order
    projection: Option<&'a [usize]>,
    /// If true, buffers must already be suitably aligned; if false, unaligned
    /// buffers are copied into new, aligned allocations
    require_alignment: bool,
    /// If set (only possible via `unsafe` APIs), data validation is skipped
    skip_validation: UnsafeFlag,
}

impl<'a> RecordBatchDecoder<'a> {
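    /// Creates a decoder for the given `batch` message, reading buffer bytes
    /// from `buf` and resolving dictionary columns via `dictionaries_by_id`.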
    fn try_new(
        buf: &'a Buffer,
        batch: crate::RecordBatch<'a>,
        schema: SchemaRef,
        dictionaries_by_id: &'a HashMap<i64, ArrayRef>,
        metadata: &'a MetadataVersion,
    ) -> Result<Self, ArrowError> {
        let buffers = batch.buffers().ok_or_else(|| {
            ArrowError::IpcError("Unable to get buffers from IPC RecordBatch".to_string())
        })?;
        let field_nodes = batch.nodes().ok_or_else(|| {
            ArrowError::IpcError("Unable to get field nodes from IPC RecordBatch".to_string())
        })?;

        let batch_compression = batch.compression();
        let compression = batch_compression
            .map(|batch_compression| batch_compression.codec().try_into())
            .transpose()?;

        Ok(Self {
            batch,
            schema,
            dictionaries_by_id,
            compression,
            version: *metadata,
            data: buf,
            nodes: field_nodes.iter(),
            buffers: buffers.iter(),
            projection: None,
            require_alignment: false,
            skip_validation: UnsafeFlag::new(),
        })
    }

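    /// Set the projection (default: None)
    ///
    /// If set, only the columns at the given indices are decoded; all other
    /// fields are skipped.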
    pub fn with_projection(mut self, projection: Option<&'a [usize]>) -> Self {
        self.projection = projection;
        self
    }

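    /// Set require_alignment (default: false)
    ///
    /// If true, an error is returned when a buffer is not suitably aligned
    /// for its type; if false, such buffers are copied into new, aligned
    /// allocations.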
    pub fn with_require_alignment(mut self, require_alignment: bool) -> Self {
        self.require_alignment = require_alignment;
        self
    }

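    /// Specifies if validation should be skipped when reading data
    /// (defaults to false)
    ///
    /// The flag can only be set to true via `unsafe` APIs, which place the
    /// burden of ensuring the data is valid on the caller.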
    pub(crate) fn with_skip_validation(mut self, skip_validation: UnsafeFlag) -> Self {
        self.skip_validation = skip_validation;
        self
    }

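    /// Reads the record batch, consuming the decoder. Applies the projection,
    /// if any, and validates the resulting batch unless validation is skipped.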
    fn read_record_batch(mut self) -> Result<RecordBatch, ArrowError> {
        let mut variadic_counts: VecDeque<i64> = self
            .batch
            .variadicBufferCounts()
            .into_iter()
            .flatten()
            .collect();

        let options = RecordBatchOptions::new().with_row_count(Some(self.batch.length() as usize));

        let schema = Arc::clone(&self.schema);
        if let Some(projection) = self.projection {
            let mut arrays = vec![];
            for (idx, field) in schema.fields().iter().enumerate() {
                // Decode the field if it is part of the projection,
                // otherwise skip its nodes and buffers
                if let Some(proj_idx) = projection.iter().position(|p| p == &idx) {
                    let child = self.create_array(field, &mut variadic_counts)?;
                    arrays.push((proj_idx, child));
                } else {
                    self.skip_field(field, &mut variadic_counts)?;
                }
            }

            arrays.sort_by_key(|t| t.0);

            let schema = Arc::new(schema.project(projection)?);
            let columns = arrays.into_iter().map(|t| t.1).collect::<Vec<_>>();

            if self.skip_validation.get() {
                // SAFETY: the skip_validation flag can only be set via
                // `unsafe` APIs, so the caller has vouched for the data
                unsafe {
                    Ok(RecordBatch::new_unchecked(
                        schema,
                        columns,
                        self.batch.length() as usize,
                    ))
                }
            } else {
                assert!(variadic_counts.is_empty());
                RecordBatch::try_new_with_options(schema, columns, &options)
            }
        } else {
            let mut children = vec![];
            for field in schema.fields() {
                let child = self.create_array(field, &mut variadic_counts)?;
                children.push(child);
            }

            if self.skip_validation.get() {
                // SAFETY: the skip_validation flag can only be set via
                // `unsafe` APIs, so the caller has vouched for the data
                unsafe {
                    Ok(RecordBatch::new_unchecked(
                        schema,
                        children,
                        self.batch.length() as usize,
                    ))
                }
            } else {
                assert!(variadic_counts.is_empty());
                RecordBatch::try_new_with_options(schema, children, &options)
            }
        }
    }

    /// Reads the next buffer descriptor and materializes (decompressing if
    /// needed) the corresponding bytes from the message body.
    fn next_buffer(&mut self) -> Result<Buffer, ArrowError> {
        read_buffer(self.buffers.next().unwrap(), self.data, self.compression)
    }

    fn skip_buffer(&mut self) {
        self.buffers.next().unwrap();
    }

    fn next_node(&mut self, field: &Field) -> Result<&'a FieldNode, ArrowError> {
        self.nodes.next().ok_or_else(|| {
            ArrowError::SchemaError(format!(
                "Invalid data for schema. {field} refers to node not found in schema",
            ))
        })
    }

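    /// Advances the node and buffer cursors past `field` without decoding it.
    /// Used for columns excluded by a projection; the buffers skipped here
    /// must mirror those consumed by `create_array`.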
    fn skip_field(
        &mut self,
        field: &Field,
        variadic_count: &mut VecDeque<i64>,
    ) -> Result<(), ArrowError> {
        self.next_node(field)?;

        match field.data_type() {
            Utf8 | Binary | LargeBinary | LargeUtf8 => {
                for _ in 0..3 {
                    self.skip_buffer()
                }
            }
            Utf8View | BinaryView => {
                let count = variadic_count
                    .pop_front()
                    .ok_or(ArrowError::IpcError(format!(
                        "Missing variadic count for {} column",
                        field.data_type()
                    )))?;
                // Plus two for the validity buffer and the views buffer
                let count = count + 2;
                for _i in 0..count {
                    self.skip_buffer()
                }
            }
            FixedSizeBinary(_) => {
                self.skip_buffer();
                self.skip_buffer();
            }
            List(list_field) | LargeList(list_field) | Map(list_field, _) => {
                self.skip_buffer();
                self.skip_buffer();
                self.skip_field(list_field, variadic_count)?;
            }
            FixedSizeList(list_field, _) => {
                self.skip_buffer();
                self.skip_field(list_field, variadic_count)?;
            }
            Struct(struct_fields) => {
                self.skip_buffer();

                for struct_field in struct_fields {
                    self.skip_field(struct_field, variadic_count)?
                }
            }
            RunEndEncoded(run_ends_field, values_field) => {
                self.skip_field(run_ends_field, variadic_count)?;
                self.skip_field(values_field, variadic_count)?;
            }
            Dictionary(_, _) => {
                self.skip_buffer(); // validity
                self.skip_buffer(); // keys
            }
            Union(fields, mode) => {
                self.skip_buffer();
                match mode {
                    UnionMode::Dense => self.skip_buffer(),
                    UnionMode::Sparse => {}
                };

                for (_, field) in fields.iter() {
                    self.skip_field(field, variadic_count)?
                }
            }
            Null => {} // Null arrays have no buffers to skip
            _ => {
                self.skip_buffer();
                self.skip_buffer();
            }
        };
        Ok(())
    }
}

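/// Creates a record batch from binary data using the `crate::RecordBatch`
/// indexes and the `Schema`.
///
/// A convenience wrapper around [`RecordBatchDecoder`] that does not require
/// buffer alignment and does not skip validation.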
pub fn read_record_batch(
    buf: &Buffer,
    batch: crate::RecordBatch,
    schema: SchemaRef,
    dictionaries_by_id: &HashMap<i64, ArrayRef>,
    projection: Option<&[usize]>,
    metadata: &MetadataVersion,
) -> Result<RecordBatch, ArrowError> {
    RecordBatchDecoder::try_new(buf, batch, schema, dictionaries_by_id, metadata)?
        .with_projection(projection)
        .with_require_alignment(false)
        .read_record_batch()
}

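/// Reads the dictionary batch from `buf` and registers the decoded values in
/// `dictionaries_by_id`, so that subsequent record batches can resolve their
/// dictionary-encoded columns.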
pub fn read_dictionary(
    buf: &Buffer,
    batch: crate::DictionaryBatch,
    schema: &Schema,
    dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
    metadata: &MetadataVersion,
) -> Result<(), ArrowError> {
    read_dictionary_impl(
        buf,
        batch,
        schema,
        dictionaries_by_id,
        metadata,
        false,
        UnsafeFlag::new(),
    )
}

fn read_dictionary_impl(
    buf: &Buffer,
    batch: crate::DictionaryBatch,
    schema: &Schema,
    dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
    metadata: &MetadataVersion,
    require_alignment: bool,
    skip_validation: UnsafeFlag,
) -> Result<(), ArrowError> {
    if batch.isDelta() {
        return Err(ArrowError::InvalidArgumentError(
            "delta dictionary batches not supported".to_string(),
        ));
    }

    let id = batch.id();
    #[allow(deprecated)]
    let fields_using_this_dictionary = schema.fields_with_dict_id(id);
    let first_field = fields_using_this_dictionary.first().ok_or_else(|| {
        ArrowError::InvalidArgumentError(format!("dictionary id {id} not found in schema"))
    })?;

    // The dictionary batch only carries the values; their type must be
    // recovered from the schema of the field(s) that reference this id
    let dictionary_values: ArrayRef = match first_field.data_type() {
        DataType::Dictionary(_, ref value_type) => {
            // Make a fake schema for the dictionary batch
            let value = value_type.as_ref().clone();
            let schema = Schema::new(vec![Field::new("", value, true)]);
            // Read a single-column batch holding the dictionary values
            let record_batch = RecordBatchDecoder::try_new(
                buf,
                batch.data().unwrap(),
                Arc::new(schema),
                dictionaries_by_id,
                metadata,
            )?
            .with_require_alignment(require_alignment)
            .with_skip_validation(skip_validation)
            .read_record_batch()?;

            Some(record_batch.column(0).clone())
        }
        _ => None,
    }
    .ok_or_else(|| {
        ArrowError::InvalidArgumentError(format!("dictionary id {id} not found in schema"))
    })?;

    // Add (possibly replacing) the dictionary values to the registry
    dictionaries_by_id.insert(id, dictionary_values.clone());

    Ok(())
}

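/// Reads a block of data (flatbuffers metadata followed by the message body)
/// from `reader`, as described by a [`Block`] from the file footer.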
fn read_block<R: Read + Seek>(mut reader: R, block: &Block) -> Result<Buffer, ArrowError> {
    reader.seek(SeekFrom::Start(block.offset() as u64))?;
    let body_len = block.bodyLength().to_usize().unwrap();
    let metadata_len = block.metaDataLength().to_usize().unwrap();
    let total_len = body_len.checked_add(metadata_len).unwrap();

    let mut buf = MutableBuffer::from_len_zeroed(total_len);
    reader.read_exact(&mut buf)?;
    Ok(buf.into())
}

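/// Parses an encapsulated IPC [`Message`] from `buf`, skipping the 4-byte
/// length prefix and the optional continuation marker that precedes it.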
fn parse_message(buf: &[u8]) -> Result<Message<'_>, ArrowError> {
    let buf = match buf[..4] == CONTINUATION_MARKER {
        true => &buf[8..],
        false => &buf[4..],
    };
    crate::root_as_message(buf)
        .map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}")))
}

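/// Reads the footer length from the last 10 bytes of an Arrow IPC file
/// (a little-endian `i32` length followed by the 6-byte magic), returning an
/// error if the magic is missing or the length is negative.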
pub fn read_footer_length(buf: [u8; 10]) -> Result<usize, ArrowError> {
    if buf[4..] != super::ARROW_MAGIC {
        return Err(ArrowError::ParseError(
            "Arrow file does not contain correct footer".to_string(),
        ));
    }

    let footer_len = i32::from_le_bytes(buf[..4].try_into().unwrap());
    footer_len
        .try_into()
        .map_err(|_| ArrowError::ParseError(format!("Invalid footer length: {footer_len}")))
}

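/// A low-level interface for reading [`RecordBatch`] data from an Arrow IPC
/// file one [`Block`] at a time. The caller locates the dictionary and record
/// batch blocks (normally via the file footer) and feeds their bytes to
/// [`Self::read_dictionary`] and [`Self::read_record_batch`].
///
/// A sketch of driving the decoder by hand, mirroring
/// `read_ipc_with_decoder_inner` in the tests below (`data.arrow` is a
/// placeholder path):
///
/// ```no_run
/// # use std::sync::Arc;
/// # use arrow_buffer::Buffer;
/// # use arrow_ipc::convert::fb_to_schema;
/// # use arrow_ipc::reader::{read_footer_length, FileDecoder};
/// # use arrow_ipc::root_as_footer;
/// let buffer = Buffer::from_vec(std::fs::read("data.arrow").unwrap());
/// // The footer flatbuffer sits just before the 10-byte trailer
/// let trailer_start = buffer.len() - 10;
/// let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap();
/// let footer = root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap();
/// let schema = fb_to_schema(footer.schema().unwrap());
/// let mut decoder = FileDecoder::new(Arc::new(schema), footer.version());
/// // Dictionaries must be replayed before the record batches that use them
/// for block in footer.dictionaries().iter().flatten() {
///     let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
///     let data = buffer.slice_with_length(block.offset() as _, block_len);
///     decoder.read_dictionary(block, &data).unwrap();
/// }
/// for block in footer.recordBatches().unwrap() {
///     let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
///     let data = buffer.slice_with_length(block.offset() as _, block_len);
///     let batch = decoder.read_record_batch(block, &data).unwrap();
///     println!("{batch:?}");
/// }
/// ```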
#[derive(Debug)]
pub struct FileDecoder {
    schema: SchemaRef,
    dictionaries: HashMap<i64, ArrayRef>,
    version: MetadataVersion,
    projection: Option<Vec<usize>>,
    require_alignment: bool,
    skip_validation: UnsafeFlag,
}

impl FileDecoder {
    /// Create a new [`FileDecoder`] with the given schema and metadata version
    pub fn new(schema: SchemaRef, version: MetadataVersion) -> Self {
        Self {
            schema,
            version,
            dictionaries: Default::default(),
            projection: None,
            require_alignment: false,
            skip_validation: UnsafeFlag::new(),
        }
    }

    /// Specify a projection: only the columns at these indices are decoded
    pub fn with_projection(mut self, projection: Vec<usize>) -> Self {
        self.projection = Some(projection);
        self
    }

    /// Specifies whether buffers must already be aligned (default: false);
    /// when false, unaligned buffers are copied into aligned allocations
    pub fn with_require_alignment(mut self, require_alignment: bool) -> Self {
        self.require_alignment = require_alignment;
        self
    }

    /// Specifies if validation should be skipped when reading data
    /// (defaults to false)
    ///
    /// # Safety
    ///
    /// The caller must ensure the data is valid Arrow IPC data; skipping
    /// validation on untrusted input can lead to undefined behavior.
    pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self {
        self.skip_validation.set(skip_validation);
        self
    }

    fn read_message<'a>(&self, buf: &'a [u8]) -> Result<Message<'a>, ArrowError> {
        let message = parse_message(buf)?;

        // Don't enforce a version match for V1 metadata
        if self.version != MetadataVersion::V1 && message.version() != self.version {
            return Err(ArrowError::IpcError(
                "Could not read IPC message as metadata versions mismatch".to_string(),
            ));
        }
        Ok(message)
    }

    /// Read the dictionary batch pointed to by `block` from `buf`, updating
    /// the decoder's dictionary registry
    pub fn read_dictionary(&mut self, block: &Block, buf: &Buffer) -> Result<(), ArrowError> {
        let message = self.read_message(buf)?;
        match message.header_type() {
            crate::MessageHeader::DictionaryBatch => {
                let batch = message.header_as_dictionary_batch().unwrap();
                read_dictionary_impl(
                    &buf.slice(block.metaDataLength() as _),
                    batch,
                    &self.schema,
                    &mut self.dictionaries,
                    &message.version(),
                    self.require_alignment,
                    self.skip_validation.clone(),
                )
            }
            t => Err(ArrowError::ParseError(format!(
                "Expecting DictionaryBatch in dictionary blocks, found {t:?}."
            ))),
        }
    }

    /// Read the record batch pointed to by `block` from `buf`, returning
    /// `None` for end-of-stream (`MessageHeader::NONE`) messages
    pub fn read_record_batch(
        &self,
        block: &Block,
        buf: &Buffer,
    ) -> Result<Option<RecordBatch>, ArrowError> {
        let message = self.read_message(buf)?;
        match message.header_type() {
            crate::MessageHeader::Schema => Err(ArrowError::IpcError(
                "Not expecting a schema when messages are read".to_string(),
            )),
            crate::MessageHeader::RecordBatch => {
                let batch = message.header_as_record_batch().ok_or_else(|| {
                    ArrowError::IpcError("Unable to read IPC message as record batch".to_string())
                })?;
                // The message body starts after the flatbuffers metadata
                RecordBatchDecoder::try_new(
                    &buf.slice(block.metaDataLength() as _),
                    batch,
                    self.schema.clone(),
                    &self.dictionaries,
                    &message.version(),
                )?
                .with_projection(self.projection.as_deref())
                .with_require_alignment(self.require_alignment)
                .with_skip_validation(self.skip_validation.clone())
                .read_record_batch()
                .map(Some)
            }
            crate::MessageHeader::NONE => Ok(None),
            t => Err(ArrowError::InvalidArgumentError(format!(
                "Reading types other than record batches not yet supported, unable to read {t:?}"
            ))),
        }
    }
}

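/// Build an Arrow [`FileReader`] with custom options, such as a column
/// projection or relaxed flatbuffers verification limits for files with very
/// large footers.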
#[derive(Debug)]
pub struct FileReaderBuilder {
    /// Optional projection for which columns to load (zero-based column indices)
    projection: Option<Vec<usize>>,
    /// Passed through to the flatbuffers verifier as `max_tables` when reading the footer
    max_footer_fb_tables: usize,
    /// Passed through to the flatbuffers verifier as `max_depth` when reading the footer
    max_footer_fb_depth: usize,
}

impl Default for FileReaderBuilder {
    fn default() -> Self {
        let verifier_options = VerifierOptions::default();
        Self {
            max_footer_fb_tables: verifier_options.max_tables,
            max_footer_fb_depth: verifier_options.max_depth,
            projection: None,
        }
    }
}

impl FileReaderBuilder {
    /// Options for creating a new [`FileReader`].
    ///
    /// To convert a builder into a reader, call [`FileReaderBuilder::build`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Optional projection for which columns to load (zero-based column indices)
    pub fn with_projection(mut self, projection: Vec<usize>) -> Self {
        self.projection = Some(projection);
        self
    }

    /// Flatbuffers option for parsing the footer: the maximum number of
    /// tables allowed. Raise this limit for files whose footers are rejected
    /// because they contain a very large number of fields or blocks.
    pub fn with_max_footer_fb_tables(mut self, max_footer_fb_tables: usize) -> Self {
        self.max_footer_fb_tables = max_footer_fb_tables;
        self
    }

    /// Flatbuffers option for parsing the footer: the maximum depth of nested
    /// tables allowed. Raise this limit for files with deeply nested schemas.
    pub fn with_max_footer_fb_depth(mut self, max_footer_fb_depth: usize) -> Self {
        self.max_footer_fb_depth = max_footer_fb_depth;
        self
    }

    /// Build a [`FileReader`] from the given reader
    pub fn build<R: Read + Seek>(self, mut reader: R) -> Result<FileReader<R>, ArrowError> {
        // Space for the footer length (4 bytes) plus the magic (6 bytes)
        let mut buffer = [0; 10];
        reader.seek(SeekFrom::End(-10))?;
        reader.read_exact(&mut buffer)?;

        let footer_len = read_footer_length(buffer)?;

        // Read the footer flatbuffer, which sits just before the trailer
        let mut footer_data = vec![0; footer_len];
        reader.seek(SeekFrom::End(-10 - footer_len as i64))?;
        reader.read_exact(&mut footer_data)?;

        let verifier_options = VerifierOptions {
            max_tables: self.max_footer_fb_tables,
            max_depth: self.max_footer_fb_depth,
            ..Default::default()
        };
        let footer = crate::root_as_footer_with_opts(&verifier_options, &footer_data[..]).map_err(
            |err| ArrowError::ParseError(format!("Unable to get root as footer: {err:?}")),
        )?;

        let blocks = footer.recordBatches().ok_or_else(|| {
            ArrowError::ParseError("Unable to get record batches from IPC Footer".to_string())
        })?;

        let total_blocks = blocks.len();

        let ipc_schema = footer.schema().unwrap();
        if !ipc_schema.endianness().equals_to_target_endianness() {
            return Err(ArrowError::IpcError(
                "the endianness of the source system does not match the endianness of the target system.".to_owned()
            ));
        }

        let schema = crate::convert::fb_to_schema(ipc_schema);

        let mut custom_metadata = HashMap::new();
        if let Some(fb_custom_metadata) = footer.custom_metadata() {
            for kv in fb_custom_metadata.into_iter() {
                custom_metadata.insert(
                    kv.key().unwrap().to_string(),
                    kv.value().unwrap().to_string(),
                );
            }
        }

        let mut decoder = FileDecoder::new(Arc::new(schema), footer.version());
        if let Some(projection) = self.projection {
            decoder = decoder.with_projection(projection)
        }

        // Read the dictionary blocks up front so record batches can resolve them
        if let Some(dictionaries) = footer.dictionaries() {
            for block in dictionaries {
                let buf = read_block(&mut reader, block)?;
                decoder.read_dictionary(block, &buf)?;
            }
        }

        Ok(FileReader {
            reader,
            blocks: blocks.iter().copied().collect(),
            current_block: 0,
            total_blocks,
            decoder,
            custom_metadata,
        })
    }
}

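/// Arrow File Reader: iterates over the [`RecordBatch`]es of an Arrow IPC
/// file, decoding dictionaries and applying an optional projection.
///
/// A minimal usage sketch (`data.arrow` is a placeholder path for a complete
/// Arrow IPC file):
///
/// ```no_run
/// # use std::io::Cursor;
/// # use arrow_ipc::reader::FileReader;
/// let buf: Vec<u8> = std::fs::read("data.arrow").unwrap();
/// let reader = FileReader::try_new(Cursor::new(buf), None).unwrap();
/// for batch in reader {
///     println!("{} rows", batch.unwrap().num_rows());
/// }
/// ```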
pub struct FileReader<R> {
    /// The underlying file-like reader
    reader: R,

    /// The decoder holding the schema, dictionaries, and projection
    decoder: FileDecoder,

    /// The blocks in the file, copied from the footer
    ///
    /// A block indicates the offset and length of each record batch
    blocks: Vec<Block>,

    /// A counter to keep track of the current block that should be read
    current_block: usize,

    /// The total number of blocks in the file
    total_blocks: usize,

    /// User-defined metadata from the footer
    custom_metadata: HashMap<String, String>,
}

impl<R> fmt::Debug for FileReader<R> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        f.debug_struct("FileReader<R>")
            .field("decoder", &self.decoder)
            .field("blocks", &self.blocks)
            .field("current_block", &self.current_block)
            .field("total_blocks", &self.total_blocks)
            .finish_non_exhaustive()
    }
}

impl<R: Read + Seek> FileReader<BufReader<R>> {
    /// Try to create a new file reader with the reader wrapped in a
    /// [`BufReader`] to reduce the number of underlying read calls
    pub fn try_new_buffered(reader: R, projection: Option<Vec<usize>>) -> Result<Self, ArrowError> {
        Self::try_new(BufReader::new(reader), projection)
    }
}

impl<R: Read + Seek> FileReader<R> {
    /// Try to create a new file reader.
    ///
    /// There is no internal buffering: if buffered reads are needed, wrap the
    /// reader in a [`BufReader`] or use [`Self::try_new_buffered`].
    pub fn try_new(reader: R, projection: Option<Vec<usize>>) -> Result<Self, ArrowError> {
        let builder = FileReaderBuilder {
            projection,
            ..Default::default()
        };
        builder.build(reader)
    }

    /// Return user-defined metadata read from the file footer
    pub fn custom_metadata(&self) -> &HashMap<String, String> {
        &self.custom_metadata
    }

    /// Return the number of batches in the file
    pub fn num_batches(&self) -> usize {
        self.total_blocks
    }

    /// Return the schema of the file
    pub fn schema(&self) -> SchemaRef {
        self.decoder.schema.clone()
    }

    /// Set the current block to the given index, allowing random reads
    pub fn set_index(&mut self, index: usize) -> Result<(), ArrowError> {
        if index >= self.total_blocks {
            Err(ArrowError::InvalidArgumentError(format!(
                "Cannot set batch to index {} from {} total batches",
                index, self.total_blocks
            )))
        } else {
            self.current_block = index;
            Ok(())
        }
    }

    fn maybe_next(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        let block = &self.blocks[self.current_block];
        self.current_block += 1;

        // Read the bytes that make up the record batch, then decode them
        let buffer = read_block(&mut self.reader, block)?;
        self.decoder.read_record_batch(block, &buffer)
    }

    /// Gets a reference to the underlying reader.
    ///
    /// It is inadvisable to directly read from the underlying reader.
    pub fn get_ref(&self) -> &R {
        &self.reader
    }

    /// Gets a mutable reference to the underlying reader.
    ///
    /// It is inadvisable to directly read from the underlying reader.
    pub fn get_mut(&mut self) -> &mut R {
        &mut self.reader
    }

    /// Specifies if validation should be skipped when reading data
    /// (defaults to false)
    ///
    /// # Safety
    ///
    /// The caller must ensure the file's contents are valid Arrow IPC data;
    /// skipping validation on untrusted input can lead to undefined behavior.
    pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self {
        self.decoder = self.decoder.with_skip_validation(skip_validation);
        self
    }
}

impl<R: Read + Seek> Iterator for FileReader<R> {
    type Item = Result<RecordBatch, ArrowError>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.current_block < self.total_blocks {
            self.maybe_next().transpose()
        } else {
            None
        }
    }
}

impl<R: Read + Seek> RecordBatchReader for FileReader<R> {
    fn schema(&self) -> SchemaRef {
        self.schema()
    }
}

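/// Arrow Stream Reader: iterates over the [`RecordBatch`]es of an Arrow IPC
/// stream, reading the schema message first and resolving dictionary batches
/// as they arrive.
///
/// A minimal usage sketch (`data.arrows` is a placeholder path for a complete
/// Arrow IPC stream):
///
/// ```no_run
/// # use std::io::Cursor;
/// # use arrow_ipc::reader::StreamReader;
/// let buf: Vec<u8> = std::fs::read("data.arrows").unwrap();
/// let reader = StreamReader::try_new(Cursor::new(buf), None).unwrap();
/// for batch in reader {
///     println!("{} rows", batch.unwrap().num_rows());
/// }
/// ```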
pub struct StreamReader<R> {
    /// Stream reader
    reader: R,

    /// The schema that is read from the stream's first message
    schema: SchemaRef,

    /// Dictionary values decoded from dictionary batches, keyed by
    /// dictionary id; may be updated between batches
    dictionaries_by_id: HashMap<i64, ArrayRef>,

    /// An indicator of whether the stream is complete.
    ///
    /// This value is set to `true` the first time the reader's `next()` returns `None`.
    finished: bool,

    /// Optional projection and projected schema
    projection: Option<(Vec<usize>, Schema)>,

    /// Should validation be skipped when reading data? Defaults to false.
    ///
    /// The flag can only be set via `unsafe` APIs.
    skip_validation: UnsafeFlag,
}

impl<R> fmt::Debug for StreamReader<R> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::result::Result<(), fmt::Error> {
        f.debug_struct("StreamReader<R>")
            .field("reader", &"R")
            .field("schema", &self.schema)
            .field("dictionaries_by_id", &self.dictionaries_by_id)
            .field("finished", &self.finished)
            .field("projection", &self.projection)
            .finish()
    }
}

impl<R: Read> StreamReader<BufReader<R>> {
    /// Try to create a new stream reader with the reader wrapped in a
    /// [`BufReader`] to reduce the number of underlying read calls
    pub fn try_new_buffered(reader: R, projection: Option<Vec<usize>>) -> Result<Self, ArrowError> {
        Self::try_new(BufReader::new(reader), projection)
    }
}

impl<R: Read> StreamReader<R> {
    /// Try to create a new stream reader.
    ///
    /// There is no internal buffering: if buffered reads are needed, wrap the
    /// reader in a [`BufReader`] or use [`Self::try_new_buffered`]. To check
    /// whether the stream is done, use [`Self::is_finished`].
    pub fn try_new(
        mut reader: R,
        projection: Option<Vec<usize>>,
    ) -> Result<StreamReader<R>, ArrowError> {
        // Determine metadata length
        let mut meta_size: [u8; 4] = [0; 4];
        reader.read_exact(&mut meta_size)?;
        let meta_len = {
            // If a continuation marker is present, the size comes next
            if meta_size == CONTINUATION_MARKER {
                reader.read_exact(&mut meta_size)?;
            }
            i32::from_le_bytes(meta_size)
        };

        let meta_len = usize::try_from(meta_len)
            .map_err(|_| ArrowError::ParseError(format!("Invalid metadata length: {meta_len}")))?;
        let mut meta_buffer = vec![0; meta_len];
        reader.read_exact(&mut meta_buffer)?;

        let message = crate::root_as_message(meta_buffer.as_slice()).map_err(|err| {
            ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
        })?;
        // A stream always begins with a schema message
        let ipc_schema: crate::Schema = message.header_as_schema().ok_or_else(|| {
            ArrowError::ParseError("Unable to read IPC message as schema".to_string())
        })?;
        let schema = crate::convert::fb_to_schema(ipc_schema);

        // Dictionaries are decoded lazily as dictionary batches arrive
        let dictionaries_by_id = HashMap::new();

        let projection = match projection {
            Some(projection_indices) => {
                let schema = schema.project(&projection_indices)?;
                Some((projection_indices, schema))
            }
            _ => None,
        };
        Ok(Self {
            reader,
            schema: Arc::new(schema),
            finished: false,
            dictionaries_by_id,
            projection,
            skip_validation: UnsafeFlag::new(),
        })
    }

    /// Deprecated, use [`StreamReader::try_new`] instead
    #[deprecated(since = "53.0.0", note = "use `try_new` instead")]
    pub fn try_new_unbuffered(
        reader: R,
        projection: Option<Vec<usize>>,
    ) -> Result<Self, ArrowError> {
        Self::try_new(reader, projection)
    }

    /// Return the schema of the stream
    pub fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    /// Check if the stream is finished
    pub fn is_finished(&self) -> bool {
        self.finished
    }

    fn maybe_next(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        if self.finished {
            return Ok(None);
        }
        // Determine metadata length
        let mut meta_size: [u8; 4] = [0; 4];

        match self.reader.read_exact(&mut meta_size) {
            Ok(()) => (),
            Err(e) => {
                return if e.kind() == std::io::ErrorKind::UnexpectedEof {
                    // Valid streams may omit the trailing end-of-stream
                    // marker, so a clean EOF here means the stream is finished
                    self.finished = true;
                    Ok(None)
                } else {
                    Err(ArrowError::from(e))
                };
            }
        }

        let meta_len = {
            // If a continuation marker is present, the size comes next
            if meta_size == CONTINUATION_MARKER {
                self.reader.read_exact(&mut meta_size)?;
            }
            i32::from_le_bytes(meta_size)
        };

        let meta_len = usize::try_from(meta_len)
            .map_err(|_| ArrowError::ParseError(format!("Invalid metadata length: {meta_len}")))?;

        if meta_len == 0 {
            // An end-of-stream marker: the stream is complete
            self.finished = true;
            return Ok(None);
        }

        let mut meta_buffer = vec![0; meta_len];
        self.reader.read_exact(&mut meta_buffer)?;

        let message = crate::root_as_message(meta_buffer.as_slice()).map_err(|err| {
            ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
        })?;

        match message.header_type() {
            crate::MessageHeader::Schema => Err(ArrowError::IpcError(
                "Not expecting a schema when messages are read".to_string(),
            )),
            crate::MessageHeader::RecordBatch => {
                let batch = message.header_as_record_batch().ok_or_else(|| {
                    ArrowError::IpcError("Unable to read IPC message as record batch".to_string())
                })?;
                // Read the message body that makes up the record batch
                let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize);
                self.reader.read_exact(&mut buf)?;

                RecordBatchDecoder::try_new(
                    &buf.into(),
                    batch,
                    self.schema(),
                    &self.dictionaries_by_id,
                    &message.version(),
                )?
                .with_projection(self.projection.as_ref().map(|x| x.0.as_ref()))
                .with_require_alignment(false)
                .with_skip_validation(self.skip_validation.clone())
                .read_record_batch()
                .map(Some)
            }
            crate::MessageHeader::DictionaryBatch => {
                let batch = message.header_as_dictionary_batch().ok_or_else(|| {
                    ArrowError::IpcError(
                        "Unable to read IPC message as dictionary batch".to_string(),
                    )
                })?;
                // Read the message body that makes up the dictionary batch
                let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize);
                self.reader.read_exact(&mut buf)?;

                read_dictionary_impl(
                    &buf.into(),
                    batch,
                    &self.schema,
                    &mut self.dictionaries_by_id,
                    &message.version(),
                    false,
                    self.skip_validation.clone(),
                )?;

                // A dictionary batch is not a record batch, so read the next message
                self.maybe_next()
            }
            crate::MessageHeader::NONE => Ok(None),
            t => Err(ArrowError::InvalidArgumentError(format!(
                "Reading types other than record batches not yet supported, unable to read {t:?} "
            ))),
        }
    }

    /// Gets a reference to the underlying reader.
    ///
    /// It is inadvisable to directly read from the underlying reader.
    pub fn get_ref(&self) -> &R {
        &self.reader
    }

    /// Gets a mutable reference to the underlying reader.
    ///
    /// It is inadvisable to directly read from the underlying reader.
    pub fn get_mut(&mut self) -> &mut R {
        &mut self.reader
    }

    /// Specifies if validation should be skipped when reading data
    /// (defaults to false)
    ///
    /// # Safety
    ///
    /// The caller must ensure the stream's contents are valid Arrow IPC data;
    /// skipping validation on untrusted input can lead to undefined behavior.
    pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self {
        self.skip_validation.set(skip_validation);
        self
    }
}

impl<R: Read> Iterator for StreamReader<R> {
    type Item = Result<RecordBatch, ArrowError>;

    fn next(&mut self) -> Option<Self::Item> {
        self.maybe_next().transpose()
    }
}

impl<R: Read> RecordBatchReader for StreamReader<R> {
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }
}

#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use crate::convert::fb_to_schema;
    use crate::writer::{
        unslice_run_array, write_message, DictionaryTracker, IpcDataGenerator, IpcWriteOptions,
    };

    use super::*;

    use crate::{root_as_footer, root_as_message, size_prefixed_root_as_message};
    use arrow_array::builder::{PrimitiveRunBuilder, UnionBuilder};
    use arrow_array::types::*;
    use arrow_buffer::{NullBuffer, OffsetBuffer};
    use arrow_data::ArrayDataBuilder;

    fn create_test_projection_schema() -> Schema {
        let list_data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true)));

        let fixed_size_list_data_type =
            DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Int32, false)), 3);

        let union_fields = UnionFields::new(
            vec![0, 1],
            vec![
                Field::new("a", DataType::Int32, false),
                Field::new("b", DataType::Float64, false),
            ],
        );

        let union_data_type = DataType::Union(union_fields, UnionMode::Dense);

        let struct_fields = Fields::from(vec![
            Field::new("id", DataType::Int32, false),
            Field::new_list("list", Field::new_list_field(DataType::Int8, true), false),
        ]);
        let struct_data_type = DataType::Struct(struct_fields);

        let run_encoded_data_type = DataType::RunEndEncoded(
            Arc::new(Field::new("run_ends", DataType::Int16, false)),
            Arc::new(Field::new("values", DataType::Int32, true)),
        );

        Schema::new(vec![
            Field::new("f0", DataType::UInt32, false),
            Field::new("f1", DataType::Utf8, false),
            Field::new("f2", DataType::Boolean, false),
            Field::new("f3", union_data_type, true),
            Field::new("f4", DataType::Null, true),
            Field::new("f5", DataType::Float64, true),
            Field::new("f6", list_data_type, false),
            Field::new("f7", DataType::FixedSizeBinary(3), true),
            Field::new("f8", fixed_size_list_data_type, false),
            Field::new("f9", struct_data_type, false),
            Field::new("f10", run_encoded_data_type, false),
            Field::new("f11", DataType::Boolean, false),
            Field::new_dictionary("f12", DataType::Int8, DataType::Utf8, false),
            Field::new("f13", DataType::Utf8, false),
        ])
    }

    fn create_test_projection_batch_data(schema: &Schema) -> RecordBatch {
        let array0 = UInt32Array::from(vec![1, 2, 3]);
        let array1 = StringArray::from(vec!["foo", "bar", "baz"]);
        let array2 = BooleanArray::from(vec![true, false, true]);

        let mut union_builder = UnionBuilder::new_dense();
        union_builder.append::<Int32Type>("a", 1).unwrap();
        union_builder.append::<Float64Type>("b", 10.1).unwrap();
        union_builder.append_null::<Float64Type>("b").unwrap();
        let array3 = union_builder.build().unwrap();

        let array4 = NullArray::new(3);
        let array5 = Float64Array::from(vec![Some(1.1), None, Some(3.3)]);
        let array6_values = vec![
            Some(vec![Some(10), Some(10), Some(10)]),
            Some(vec![Some(20), Some(20), Some(20)]),
            Some(vec![Some(30), Some(30)]),
        ];
        let array6 = ListArray::from_iter_primitive::<Int32Type, _, _>(array6_values);
        let array7_values = vec![vec![11, 12, 13], vec![22, 23, 24], vec![33, 34, 35]];
        let array7 = FixedSizeBinaryArray::try_from_iter(array7_values.into_iter()).unwrap();

        let array8_values = ArrayData::builder(DataType::Int32)
            .len(9)
            .add_buffer(Buffer::from_slice_ref([40, 41, 42, 43, 44, 45, 46, 47, 48]))
            .build()
            .unwrap();
        let array8_data = ArrayData::builder(schema.field(8).data_type().clone())
            .len(3)
            .add_child_data(array8_values)
            .build()
            .unwrap();
        let array8 = FixedSizeListArray::from(array8_data);

        let array9_id: ArrayRef = Arc::new(Int32Array::from(vec![1001, 1002, 1003]));
        let array9_list: ArrayRef =
            Arc::new(ListArray::from_iter_primitive::<Int8Type, _, _>(vec![
                Some(vec![Some(-10)]),
                Some(vec![Some(-20), Some(-20), Some(-20)]),
                Some(vec![Some(-30)]),
            ]));
        let array9 = ArrayDataBuilder::new(schema.field(9).data_type().clone())
            .add_child_data(array9_id.into_data())
            .add_child_data(array9_list.into_data())
            .len(3)
            .build()
            .unwrap();
        let array9: ArrayRef = Arc::new(StructArray::from(array9));

        let array10_input = vec![Some(1_i32), None, None];
        let mut array10_builder = PrimitiveRunBuilder::<Int16Type, Int32Type>::new();
        array10_builder.extend(array10_input);
        let array10 = array10_builder.finish();

        let array11 = BooleanArray::from(vec![false, false, true]);

        let array12_values = StringArray::from(vec!["x", "yy", "zzz"]);
        let array12_keys = Int8Array::from_iter_values([1, 1, 2]);
        let array12 = DictionaryArray::new(array12_keys, Arc::new(array12_values));

        let array13 = StringArray::from(vec!["a", "bb", "ccc"]);

        RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(array0),
                Arc::new(array1),
                Arc::new(array2),
                Arc::new(array3),
                Arc::new(array4),
                Arc::new(array5),
                Arc::new(array6),
                Arc::new(array7),
                Arc::new(array8),
                Arc::new(array9),
                Arc::new(array10),
                Arc::new(array11),
                Arc::new(array12),
                Arc::new(array13),
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_negative_meta_len_start_stream() {
        let bytes = i32::to_le_bytes(-1);
        let mut buf = vec![];
        buf.extend(CONTINUATION_MARKER);
        buf.extend(bytes);

        let reader_err = StreamReader::try_new(Cursor::new(buf), None).err();
        assert!(reader_err.is_some());
        assert_eq!(
            reader_err.unwrap().to_string(),
            "Parser error: Invalid metadata length: -1"
        );
    }

    #[test]
    fn test_negative_meta_len_mid_stream() {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let mut buf = Vec::new();
        {
            let mut writer = crate::writer::StreamWriter::try_new(&mut buf, &schema).unwrap();
            let batch =
                RecordBatch::try_new(Arc::new(schema), vec![Arc::new(Int32Array::from(vec![1]))])
                    .unwrap();
            writer.write(&batch).unwrap();
        }

        let bytes = i32::to_le_bytes(-1);
        buf.extend(CONTINUATION_MARKER);
        buf.extend(bytes);

        let mut reader = StreamReader::try_new(Cursor::new(buf), None).unwrap();
        // The valid batch reads fine
        assert!(reader.maybe_next().is_ok());
        // The corrupted metadata length is rejected
        let batch_err = reader.maybe_next().err();
        assert!(batch_err.is_some());
        assert_eq!(
            batch_err.unwrap().to_string(),
            "Parser error: Invalid metadata length: -1"
        );
    }

    #[test]
    fn test_projection_array_values() {
        // define schema
        let schema = create_test_projection_schema();

        // create record batch with test data
        let batch = create_test_projection_batch_data(&schema);

        // write record batch in IPC format
        let mut buf = Vec::new();
        {
            let mut writer = crate::writer::FileWriter::try_new(&mut buf, &schema).unwrap();
            writer.write(&batch).unwrap();
            writer.finish().unwrap();
        }

        // read back with a single-column projection and check each column
        for index in 0..12 {
            let projection = vec![index];
            let reader = FileReader::try_new(std::io::Cursor::new(buf.clone()), Some(projection));
            let read_batch = reader.unwrap().next().unwrap().unwrap();
            let projected_column = read_batch.column(0);
            let expected_column = batch.column(index);

            assert_eq!(projected_column.as_ref(), expected_column.as_ref());
        }

        {
            // read record batch with a reordered projection
            let reader =
                FileReader::try_new(std::io::Cursor::new(buf.clone()), Some(vec![3, 2, 1]));
            let read_batch = reader.unwrap().next().unwrap().unwrap();
            let expected_batch = batch.project(&[3, 2, 1]).unwrap();
            assert_eq!(read_batch, expected_batch);
        }
    }

    #[test]
    fn test_arrow_single_float_row() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Float32, false),
            Field::new("b", DataType::Float32, false),
            Field::new("c", DataType::Int32, false),
            Field::new("d", DataType::Int32, false),
        ]);
        let arrays = vec![
            Arc::new(Float32Array::from(vec![1.23])) as ArrayRef,
            Arc::new(Float32Array::from(vec![-6.50])) as ArrayRef,
            Arc::new(Int32Array::from(vec![2])) as ArrayRef,
            Arc::new(Int32Array::from(vec![1])) as ArrayRef,
        ];
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), arrays).unwrap();
        let mut file = tempfile::tempfile().unwrap();
        let mut stream_writer = crate::writer::StreamWriter::try_new(&mut file, &schema).unwrap();
        stream_writer.write(&batch).unwrap();
        stream_writer.finish().unwrap();

        drop(stream_writer);

        file.rewind().unwrap();

        let reader = StreamReader::try_new(&mut file, None).unwrap();

        reader.for_each(|batch| {
            let batch = batch.unwrap();
            assert!(
                batch
                    .column(0)
                    .as_any()
                    .downcast_ref::<Float32Array>()
                    .unwrap()
                    .value(0)
                    != 0.0
            );
            assert!(
                batch
                    .column(1)
                    .as_any()
                    .downcast_ref::<Float32Array>()
                    .unwrap()
                    .value(0)
                    != 0.0
            );
        });

        file.rewind().unwrap();

        let reader = StreamReader::try_new(file, Some(vec![0, 3])).unwrap();

        reader.for_each(|batch| {
            let batch = batch.unwrap();
            assert_eq!(batch.schema().fields().len(), 2);
            assert_eq!(batch.schema().fields()[0].data_type(), &DataType::Float32);
            assert_eq!(batch.schema().fields()[1].data_type(), &DataType::Int32);
        });
    }

    /// Write the record batch to an in-memory buffer in IPC File format
    fn write_ipc(rb: &RecordBatch) -> Vec<u8> {
        let mut buf = Vec::new();
        let mut writer = crate::writer::FileWriter::try_new(&mut buf, rb.schema_ref()).unwrap();
        writer.write(rb).unwrap();
        writer.finish().unwrap();
        buf
    }

    /// Return the first record batch read from the IPC File buffer
    fn read_ipc(buf: &[u8]) -> Result<RecordBatch, ArrowError> {
        let mut reader = FileReader::try_new(std::io::Cursor::new(buf), None)?;
        reader.next().unwrap()
    }

    /// Return the first record batch read from the IPC File buffer,
    /// with validation skipped
    fn read_ipc_skip_validation(buf: &[u8]) -> Result<RecordBatch, ArrowError> {
        let mut reader = unsafe {
            FileReader::try_new(std::io::Cursor::new(buf), None)?.with_skip_validation(true)
        };
        reader.next().unwrap()
    }

    fn roundtrip_ipc(rb: &RecordBatch) -> RecordBatch {
        let buf = write_ipc(rb);
        read_ipc(&buf).unwrap()
    }

    /// Return the first record batch read from the IPC File buffer
    /// using the low-level [`FileDecoder`] API
    fn read_ipc_with_decoder(buf: Vec<u8>) -> Result<RecordBatch, ArrowError> {
        read_ipc_with_decoder_inner(buf, false)
    }

    /// Return the first record batch read from the IPC File buffer
    /// using the low-level [`FileDecoder`] API, with validation skipped
    fn read_ipc_with_decoder_skip_validation(buf: Vec<u8>) -> Result<RecordBatch, ArrowError> {
        read_ipc_with_decoder_inner(buf, true)
    }

    fn read_ipc_with_decoder_inner(
        buf: Vec<u8>,
        skip_validation: bool,
    ) -> Result<RecordBatch, ArrowError> {
        let buffer = Buffer::from_vec(buf);
        let trailer_start = buffer.len() - 10;
        let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap())?;
        let footer = root_as_footer(&buffer[trailer_start - footer_len..trailer_start])
            .map_err(|e| ArrowError::InvalidArgumentError(format!("Invalid footer: {e}")))?;

        let schema = fb_to_schema(footer.schema().unwrap());

        let mut decoder = unsafe {
            FileDecoder::new(Arc::new(schema), footer.version())
                .with_skip_validation(skip_validation)
        };
        // Read dictionaries
        for block in footer.dictionaries().iter().flatten() {
            let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
            let data = buffer.slice_with_length(block.offset() as _, block_len);
            decoder.read_dictionary(block, &data)?
        }

        // Read the single record batch
        let batches = footer.recordBatches().unwrap();
        // Only one batch was written
        assert_eq!(batches.len(), 1);

        let block = batches.get(0);
        let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
        let data = buffer.slice_with_length(block.offset() as _, block_len);
        Ok(decoder.read_record_batch(block, &data)?.unwrap())
    }

    /// Write the record batch to an in-memory buffer in IPC Stream format
    fn write_stream(rb: &RecordBatch) -> Vec<u8> {
        let mut buf = Vec::new();
        let mut writer = crate::writer::StreamWriter::try_new(&mut buf, rb.schema_ref()).unwrap();
        writer.write(rb).unwrap();
        writer.finish().unwrap();
        buf
    }

    /// Return the first record batch read from the IPC Stream buffer
    fn read_stream(buf: &[u8]) -> Result<RecordBatch, ArrowError> {
        let mut reader = StreamReader::try_new(std::io::Cursor::new(buf), None)?;
        reader.next().unwrap()
    }

    /// Return the first record batch read from the IPC Stream buffer,
    /// with validation skipped
    fn read_stream_skip_validation(buf: &[u8]) -> Result<RecordBatch, ArrowError> {
        let mut reader = unsafe {
            StreamReader::try_new(std::io::Cursor::new(buf), None)?.with_skip_validation(true)
        };
        reader.next().unwrap()
    }

    fn roundtrip_ipc_stream(rb: &RecordBatch) -> RecordBatch {
        let buf = write_stream(rb);
        read_stream(&buf).unwrap()
    }

    #[test]
    fn test_roundtrip_with_custom_metadata() {
        let schema = Schema::new(vec![Field::new("dummy", DataType::Float64, false)]);
        let mut buf = Vec::new();
        let mut writer = crate::writer::FileWriter::try_new(&mut buf, &schema).unwrap();
        let mut test_metadata = HashMap::new();
        test_metadata.insert("abc".to_string(), "abc".to_string());
        test_metadata.insert("def".to_string(), "def".to_string());
        for (k, v) in &test_metadata {
            writer.write_metadata(k, v);
        }
        writer.finish().unwrap();
        drop(writer);

        let reader = crate::reader::FileReader::try_new(std::io::Cursor::new(buf), None).unwrap();
        assert_eq!(reader.custom_metadata(), &test_metadata);
    }

    #[test]
    fn test_roundtrip_nested_dict() {
        let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();

        let array = Arc::new(inner) as ArrayRef;

        let dctfield = Arc::new(Field::new("dict", array.data_type().clone(), false));

        let s = StructArray::from(vec![(dctfield, array)]);
        let struct_array = Arc::new(s) as ArrayRef;

        let schema = Arc::new(Schema::new(vec![Field::new(
            "struct",
            struct_array.data_type().clone(),
            false,
        )]));

        let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap();

        assert_eq!(batch, roundtrip_ipc(&batch));
    }

    #[test]
    fn test_roundtrip_nested_dict_no_preserve_dict_id() {
        let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();

        let array = Arc::new(inner) as ArrayRef;

        let dctfield = Arc::new(Field::new("dict", array.data_type().clone(), false));

        let s = StructArray::from(vec![(dctfield, array)]);
        let struct_array = Arc::new(s) as ArrayRef;

        let schema = Arc::new(Schema::new(vec![Field::new(
            "struct",
            struct_array.data_type().clone(),
            false,
        )]));

        let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap();

        let mut buf = Vec::new();
        let mut writer = crate::writer::FileWriter::try_new_with_options(
            &mut buf,
            batch.schema_ref(),
            IpcWriteOptions::default(),
        )
        .unwrap();
        writer.write(&batch).unwrap();
        writer.finish().unwrap();
        drop(writer);

        let mut reader = FileReader::try_new(std::io::Cursor::new(buf), None).unwrap();

        assert_eq!(batch, reader.next().unwrap().unwrap());
    }

    fn check_union_with_builder(mut builder: UnionBuilder) {
        builder.append::<Int32Type>("a", 1).unwrap();
        builder.append_null::<Int32Type>("a").unwrap();
        builder.append::<Float64Type>("c", 3.0).unwrap();
        builder.append::<Int32Type>("a", 4).unwrap();
        builder.append::<Int64Type>("d", 11).unwrap();
        let union = builder.build().unwrap();

        let schema = Arc::new(Schema::new(vec![Field::new(
            "union",
            union.data_type().clone(),
            false,
        )]));

        let union_array = Arc::new(union) as ArrayRef;

        let rb = RecordBatch::try_new(schema, vec![union_array]).unwrap();
        let rb2 = roundtrip_ipc(&rb);

        assert_eq!(rb.schema(), rb2.schema());
        assert_eq!(rb.num_columns(), rb2.num_columns());
        assert_eq!(rb.num_rows(), rb2.num_rows());
        let union1 = rb.column(0);
        let union2 = rb2.column(0);

        assert_eq!(union1, union2);
    }

    #[test]
    fn test_roundtrip_dense_union() {
        check_union_with_builder(UnionBuilder::new_dense());
    }

    #[test]
    fn test_roundtrip_sparse_union() {
        check_union_with_builder(UnionBuilder::new_sparse());
    }

    #[test]
    fn test_roundtrip_struct_empty_fields() {
        let nulls = NullBuffer::from(&[true, true, false]);
        let rb = RecordBatch::try_from_iter([(
            "",
            Arc::new(StructArray::new_empty_fields(nulls.len(), Some(nulls))) as _,
        )])
        .unwrap();
        let rb2 = roundtrip_ipc(&rb);
        assert_eq!(rb, rb2);
    }

    #[test]
    fn test_roundtrip_stream_run_array_sliced() {
        let run_array_1: Int32RunArray = vec!["a", "a", "a", "b", "b", "c", "c", "c"]
            .into_iter()
            .collect();
        let run_array_1_sliced = run_array_1.slice(2, 5);

        let run_array_2_input = vec![Some(1_i32), None, None, Some(2), Some(2)];
        let mut run_array_2_builder = PrimitiveRunBuilder::<Int16Type, Int32Type>::new();
        run_array_2_builder.extend(run_array_2_input);
        let run_array_2 = run_array_2_builder.finish();

        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "run_array_1_sliced",
                run_array_1_sliced.data_type().clone(),
                false,
            ),
            Field::new("run_array_2", run_array_2.data_type().clone(), false),
        ]));
        let input_batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(run_array_1_sliced.clone()), Arc::new(run_array_2)],
        )
        .unwrap();
        let output_batch = roundtrip_ipc_stream(&input_batch);

        // The unsliced run array round-trips unchanged
        assert_eq!(input_batch.column(1), output_batch.column(1));

        // The writer unslices sliced run arrays before writing, so compare
        // the output against the unsliced input
        let run_array_1_unsliced = unslice_run_array(run_array_1_sliced.into_data()).unwrap();
        assert_eq!(run_array_1_unsliced, output_batch.column(0).into_data());
    }

    #[test]
    fn test_roundtrip_stream_nested_dict() {
        let xs = vec!["AA", "BB", "AA", "CC", "BB"];
        let dict = Arc::new(
            xs.clone()
                .into_iter()
                .collect::<DictionaryArray<Int8Type>>(),
        );
        let string_array: ArrayRef = Arc::new(StringArray::from(xs.clone()));
        let struct_array = StructArray::from(vec![
            (
                Arc::new(Field::new("f2.1", DataType::Utf8, false)),
                string_array,
            ),
            (
                Arc::new(Field::new("f2.2_struct", dict.data_type().clone(), false)),
                dict.clone() as ArrayRef,
            ),
        ]);
        let schema = Arc::new(Schema::new(vec![
            Field::new("f1_string", DataType::Utf8, false),
            Field::new("f2_struct", struct_array.data_type().clone(), false),
        ]));
        let input_batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(xs.clone())),
                Arc::new(struct_array),
            ],
        )
        .unwrap();
        let output_batch = roundtrip_ipc_stream(&input_batch);
        assert_eq!(input_batch, output_batch);
    }

    #[test]
    fn test_roundtrip_stream_nested_dict_of_map_of_dict() {
        let values = StringArray::from(vec![Some("a"), None, Some("b"), Some("c")]);
        let values = Arc::new(values) as ArrayRef;
        let value_dict_keys = Int8Array::from_iter_values([0, 1, 1, 2, 3, 1]);
        let value_dict_array = DictionaryArray::new(value_dict_keys, values.clone());

        let key_dict_keys = Int8Array::from_iter_values([0, 0, 2, 1, 1, 3]);
        let key_dict_array = DictionaryArray::new(key_dict_keys, values);

        #[allow(deprecated)]
        let keys_field = Arc::new(Field::new_dict(
            "keys",
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
            true, // It is technically not legal for this field to be null.
            1,
            false,
        ));
        #[allow(deprecated)]
        let values_field = Arc::new(Field::new_dict(
            "values",
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
            true,
            2,
            false,
        ));
        let entry_struct = StructArray::from(vec![
            (keys_field, make_array(key_dict_array.into_data())),
            (values_field, make_array(value_dict_array.into_data())),
        ]);
        let map_data_type = DataType::Map(
            Arc::new(Field::new(
                "entries",
                entry_struct.data_type().clone(),
                false,
            )),
            false,
        );

        let entry_offsets = Buffer::from_slice_ref([0, 2, 4, 6]);
        let map_data = ArrayData::builder(map_data_type)
            .len(3)
            .add_buffer(entry_offsets)
            .add_child_data(entry_struct.into_data())
            .build()
            .unwrap();
        let map_array = MapArray::from(map_data);

        let dict_keys = Int8Array::from_iter_values([0, 1, 1, 2, 2, 1]);
        let dict_dict_array = DictionaryArray::new(dict_keys, Arc::new(map_array));

        let schema = Arc::new(Schema::new(vec![Field::new(
            "f1",
            dict_dict_array.data_type().clone(),
            false,
        )]));
        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap();
        let output_batch = roundtrip_ipc_stream(&input_batch);
        assert_eq!(input_batch, output_batch);
    }

    fn test_roundtrip_stream_dict_of_list_of_dict_impl<
        OffsetSize: OffsetSizeTrait,
        U: ArrowNativeType,
    >(
        list_data_type: DataType,
        offsets: &[U; 5],
    ) {
        let values = StringArray::from(vec![Some("a"), None, Some("c"), None]);
        let keys = Int8Array::from_iter_values([0, 0, 1, 2, 0, 1, 3]);
        let dict_array = DictionaryArray::new(keys, Arc::new(values));
        let dict_data = dict_array.to_data();

        let value_offsets = Buffer::from_slice_ref(offsets);

        let list_data = ArrayData::builder(list_data_type)
            .len(4)
            .add_buffer(value_offsets)
            .add_child_data(dict_data)
            .build()
            .unwrap();
        let list_array = GenericListArray::<OffsetSize>::from(list_data);

        let keys_for_dict = Int8Array::from_iter_values([0, 3, 0, 1, 1, 2, 0, 1, 3]);
        let dict_dict_array = DictionaryArray::new(keys_for_dict, Arc::new(list_array));

        let schema = Arc::new(Schema::new(vec![Field::new(
            "f1",
            dict_dict_array.data_type().clone(),
            false,
        )]));
        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap();
        let output_batch = roundtrip_ipc_stream(&input_batch);
        assert_eq!(input_batch, output_batch);
    }

    #[test]
    fn test_roundtrip_stream_dict_of_list_of_dict() {
        // List<Dictionary<Int8, Utf8>>
        #[allow(deprecated)]
        let list_data_type = DataType::List(Arc::new(Field::new_dict(
            "item",
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
            true,
            1,
            false,
        )));
        let offsets: &[i32; 5] = &[0, 2, 4, 4, 6];
        test_roundtrip_stream_dict_of_list_of_dict_impl::<i32, i32>(list_data_type, offsets);

        // LargeList<Dictionary<Int8, Utf8>>
        #[allow(deprecated)]
        let list_data_type = DataType::LargeList(Arc::new(Field::new_dict(
            "item",
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
            true,
            1,
            false,
        )));
        let offsets: &[i64; 5] = &[0, 2, 4, 4, 7];
        test_roundtrip_stream_dict_of_list_of_dict_impl::<i64, i64>(list_data_type, offsets);
    }

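    /// Round-trips a dictionary of `FixedSizeList` of dictionary-encoded
    /// strings through the IPC stream format.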
    #[test]
    fn test_roundtrip_stream_dict_of_fixed_size_list_of_dict() {
        let values = StringArray::from(vec![Some("a"), None, Some("c"), None]);
        let keys = Int8Array::from_iter_values([0, 0, 1, 2, 0, 1, 3, 1, 2]);
        let dict_array = DictionaryArray::new(keys, Arc::new(values));
        let dict_data = dict_array.into_data();

        #[allow(deprecated)]
        let list_data_type = DataType::FixedSizeList(
            Arc::new(Field::new_dict(
                "item",
                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
                true,
                1,
                false,
            )),
            3,
        );
        let list_data = ArrayData::builder(list_data_type)
            .len(3)
            .add_child_data(dict_data)
            .build()
            .unwrap();
        let list_array = FixedSizeListArray::from(list_data);

        let keys_for_dict = Int8Array::from_iter_values([0, 1, 0, 1, 1, 2, 0, 1, 2]);
        let dict_dict_array = DictionaryArray::new(keys_for_dict, Arc::new(list_array));

        let schema = Arc::new(Schema::new(vec![Field::new(
            "f1",
            dict_dict_array.data_type().clone(),
            false,
        )]));
        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap();
        let output_batch = roundtrip_ipc_stream(&input_batch);
        assert_eq!(input_batch, output_batch);
    }

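    /// Longer than 12 bytes, so view arrays must store it out-of-line in a
    /// data buffer rather than inlining it in the view.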
    const LONG_TEST_STRING: &str =
        "This is a long string to make sure binary view array handles it";

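    /// Round-trips `BinaryView` and `Utf8View` columns (plus a plain `Utf8`
    /// column) through both the IPC file and stream formats, including slices.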
    #[test]
    fn test_roundtrip_view_types() {
        let schema = Schema::new(vec![
            Field::new("field_1", DataType::BinaryView, true),
            Field::new("field_2", DataType::Utf8, true),
            Field::new("field_3", DataType::Utf8View, true),
        ]);
        let bin_values: Vec<Option<&[u8]>> = vec![
            Some(b"foo"),
            None,
            Some(b"bar"),
            Some(LONG_TEST_STRING.as_bytes()),
        ];
        let utf8_values: Vec<Option<&str>> =
            vec![Some("foo"), None, Some("bar"), Some(LONG_TEST_STRING)];
        let bin_view_array = BinaryViewArray::from_iter(bin_values);
        let utf8_array = StringArray::from_iter(utf8_values.iter());
        let utf8_view_array = StringViewArray::from_iter(utf8_values);
        let record_batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(bin_view_array),
                Arc::new(utf8_array),
                Arc::new(utf8_view_array),
            ],
        )
        .unwrap();

        assert_eq!(record_batch, roundtrip_ipc(&record_batch));
        assert_eq!(record_batch, roundtrip_ipc_stream(&record_batch));

        let sliced_batch = record_batch.slice(1, 2);
        assert_eq!(sliced_batch, roundtrip_ipc(&sliced_batch));
        assert_eq!(sliced_batch, roundtrip_ipc_stream(&sliced_batch));
    }

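    /// Round-trips view types nested inside dictionaries and a `Map` through
    /// both IPC formats, including slices.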
    #[test]
    fn test_roundtrip_view_types_nested_dict() {
        let bin_values: Vec<Option<&[u8]>> = vec![
            Some(b"foo"),
            None,
            Some(b"bar"),
            Some(LONG_TEST_STRING.as_bytes()),
            Some(b"field"),
        ];
        let utf8_values: Vec<Option<&str>> = vec![
            Some("foo"),
            None,
            Some("bar"),
            Some(LONG_TEST_STRING),
            Some("field"),
        ];
        let bin_view_array = Arc::new(BinaryViewArray::from_iter(bin_values));
        let utf8_view_array = Arc::new(StringViewArray::from_iter(utf8_values));

        let key_dict_keys = Int8Array::from_iter_values([0, 0, 1, 2, 0, 1, 3]);
        let key_dict_array = DictionaryArray::new(key_dict_keys, utf8_view_array.clone());
        #[allow(deprecated)]
        let keys_field = Arc::new(Field::new_dict(
            "keys",
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8View)),
            true,
            1,
            false,
        ));

        let value_dict_keys = Int8Array::from_iter_values([0, 3, 0, 1, 2, 0, 1]);
        let value_dict_array = DictionaryArray::new(value_dict_keys, bin_view_array);
        #[allow(deprecated)]
        let values_field = Arc::new(Field::new_dict(
            "values",
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::BinaryView)),
            true,
            2,
            false,
        ));
        let entry_struct = StructArray::from(vec![
            (keys_field, make_array(key_dict_array.into_data())),
            (values_field, make_array(value_dict_array.into_data())),
        ]);

        let map_data_type = DataType::Map(
            Arc::new(Field::new(
                "entries",
                entry_struct.data_type().clone(),
                false,
            )),
            false,
        );
        let entry_offsets = Buffer::from_slice_ref([0, 2, 4, 7]);
        let map_data = ArrayData::builder(map_data_type)
            .len(3)
            .add_buffer(entry_offsets)
            .add_child_data(entry_struct.into_data())
            .build()
            .unwrap();
        let map_array = MapArray::from(map_data);

        let dict_keys = Int8Array::from_iter_values([0, 1, 0, 1, 1, 2, 0, 1, 2]);
        let dict_dict_array = DictionaryArray::new(dict_keys, Arc::new(map_array));
        let schema = Arc::new(Schema::new(vec![Field::new(
            "f1",
            dict_dict_array.data_type().clone(),
            false,
        )]));
        let batch = RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap();
        assert_eq!(batch, roundtrip_ipc(&batch));
        assert_eq!(batch, roundtrip_ipc_stream(&batch));

        let sliced_batch = batch.slice(1, 2);
        assert_eq!(sliced_batch, roundtrip_ipc(&sliced_batch));
        assert_eq!(sliced_batch, roundtrip_ipc_stream(&sliced_batch));
    }

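    /// A batch with a row count but no columns must round-trip through the
    /// IPC stream format.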
    #[test]
    fn test_no_columns_batch() {
        let schema = Arc::new(Schema::empty());
        let options = RecordBatchOptions::new()
            .with_match_field_names(true)
            .with_row_count(Some(10));
        let input_batch = RecordBatch::try_new_with_options(schema, vec![], &options).unwrap();
        let output_batch = roundtrip_ipc_stream(&input_batch);
        assert_eq!(input_batch, output_batch);
    }

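    /// Decoding a misaligned buffer succeeds when alignment is not required
    /// via `with_require_alignment(false)`.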
    #[test]
    fn test_unaligned() {
        let batch = RecordBatch::try_from_iter(vec![(
            "i32",
            Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as _,
        )])
        .unwrap();

        let gen = IpcDataGenerator {};
        let mut dict_tracker = DictionaryTracker::new(false);
        let (_, encoded) = gen
            .encoded_batch(&batch, &mut dict_tracker, &Default::default())
            .unwrap();

        let message = root_as_message(&encoded.ipc_message).unwrap();

        // Prepend a single byte and slice it off again so the batch data is
        // no longer 8-byte aligned
        let mut buffer = MutableBuffer::with_capacity(encoded.arrow_data.len() + 1);
        buffer.push(0_u8);
        buffer.extend_from_slice(&encoded.arrow_data);
        let b = Buffer::from(buffer).slice(1);
        assert_ne!(b.as_ptr().align_offset(8), 0);

        let ipc_batch = message.header_as_record_batch().unwrap();
        let roundtrip = RecordBatchDecoder::try_new(
            &b,
            ipc_batch,
            batch.schema(),
            &Default::default(),
            &message.version(),
        )
        .unwrap()
        .with_require_alignment(false)
        .read_record_batch()
        .unwrap();
        assert_eq!(batch, roundtrip);
    }

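    /// The same misaligned buffer must be rejected when
    /// `with_require_alignment(true)` is set.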
    #[test]
    fn test_unaligned_throws_error_with_require_alignment() {
        let batch = RecordBatch::try_from_iter(vec![(
            "i32",
            Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as _,
        )])
        .unwrap();

        let gen = IpcDataGenerator {};
        let mut dict_tracker = DictionaryTracker::new(false);
        let (_, encoded) = gen
            .encoded_batch(&batch, &mut dict_tracker, &Default::default())
            .unwrap();

        let message = root_as_message(&encoded.ipc_message).unwrap();

        // Prepend a single byte and slice it off again so the batch data is
        // no longer 8-byte aligned
        let mut buffer = MutableBuffer::with_capacity(encoded.arrow_data.len() + 1);
        buffer.push(0_u8);
        buffer.extend_from_slice(&encoded.arrow_data);
        let b = Buffer::from(buffer).slice(1);
        assert_ne!(b.as_ptr().align_offset(8), 0);

        let ipc_batch = message.header_as_record_batch().unwrap();
        let result = RecordBatchDecoder::try_new(
            &b,
            ipc_batch,
            batch.schema(),
            &Default::default(),
            &message.version(),
        )
        .unwrap()
        .with_require_alignment(true)
        .read_record_batch();

        let error = result.unwrap_err();
        assert_eq!(
            error.to_string(),
            "Invalid argument error: Misaligned buffers[0] in array of type Int32, \
             offset from expected alignment of 4 by 1"
        );
    }

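    /// Reading a footer with 600,000 columns requires raising the flatbuffer
    /// verifier's table limit via `FileReaderBuilder::with_max_footer_fb_tables`.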
    #[test]
    fn test_file_with_massive_column_count() {
        // 600,000 fields is more than the default flatbuffer table
        // verification limit allows when reading the footer
        let limit = 600_000;

        let fields = (0..limit)
            .map(|i| Field::new(format!("{i}"), DataType::Boolean, false))
            .collect::<Vec<_>>();
        let schema = Arc::new(Schema::new(fields));
        let batch = RecordBatch::new_empty(schema);

        let mut buf = Vec::new();
        let mut writer = crate::writer::FileWriter::try_new(&mut buf, batch.schema_ref()).unwrap();
        writer.write(&batch).unwrap();
        writer.finish().unwrap();
        drop(writer);

        let mut reader = FileReaderBuilder::new()
            .with_max_footer_fb_tables(1_500_000)
            .build(std::io::Cursor::new(buf))
            .unwrap();
        let roundtrip_batch = reader.next().unwrap().unwrap();

        assert_eq!(batch, roundtrip_batch);
    }

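    /// Schemas nested deeper than the default flatbuffer verification depth
    /// need `FileReaderBuilder::with_max_footer_fb_depth`.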
    #[test]
    fn test_file_with_deeply_nested_columns() {
        // 61 levels of struct nesting is deeper than the default flatbuffer
        // verification depth limit allows
        let limit = 61;

        let fields = (0..limit).fold(
            vec![Field::new("leaf", DataType::Boolean, false)],
            |field, index| vec![Field::new_struct(format!("{index}"), field, false)],
        );
        let schema = Arc::new(Schema::new(fields));
        let batch = RecordBatch::new_empty(schema);

        let mut buf = Vec::new();
        let mut writer = crate::writer::FileWriter::try_new(&mut buf, batch.schema_ref()).unwrap();
        writer.write(&batch).unwrap();
        writer.finish().unwrap();
        drop(writer);

        let mut reader = FileReaderBuilder::new()
            .with_max_footer_fb_depth(65)
            .build(std::io::Cursor::new(buf))
            .unwrap();
        let roundtrip_batch = reader.next().unwrap().unwrap();

        assert_eq!(batch, roundtrip_batch);
    }

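    /// A struct array whose children have mismatched lengths must fail
    /// validation when read back over IPC.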
    #[test]
    fn test_invalid_struct_array_ipc_read_errors() {
        let a_field = Field::new("a", DataType::Int32, false);
        let b_field = Field::new("b", DataType::Int32, false);
        let struct_fields = Fields::from(vec![a_field.clone(), b_field.clone()]);

        // Child "a" has 4 values but child "b" only has 3
        let a_array_data = ArrayData::builder(a_field.data_type().clone())
            .len(4)
            .add_buffer(Buffer::from_slice_ref([1, 2, 3, 4]))
            .build()
            .unwrap();
        let b_array_data = ArrayData::builder(b_field.data_type().clone())
            .len(3)
            .add_buffer(Buffer::from_slice_ref([5, 6, 7]))
            .build()
            .unwrap();

        let invalid_struct_arr = unsafe {
            StructArray::new_unchecked(
                struct_fields,
                vec![make_array(a_array_data), make_array(b_array_data)],
                None,
            )
        };

        expect_ipc_validation_error(
            Arc::new(invalid_struct_arr),
            "Invalid argument error: Incorrect array length for StructArray field \"b\", expected 4 got 3",
        );
    }

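    /// Invalid UTF-8 in a string child nested inside a struct must be caught
    /// by validation when read back over IPC.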
    #[test]
    fn test_invalid_nested_array_ipc_read_errors() {
        let a_field = Field::new("a", DataType::Int32, false);
        let b_field = Field::new("b", DataType::Utf8, false);

        let schema = Arc::new(Schema::new(vec![Field::new_struct(
            "s",
            vec![a_field.clone(), b_field.clone()],
            false,
        )]));

        let a_array_data = ArrayData::builder(a_field.data_type().clone())
            .len(4)
            .add_buffer(Buffer::from_slice_ref([1, 2, 3, 4]))
            .build()
            .unwrap();
        // A StringArray whose last value contains invalid UTF-8
        let b_array_data = {
            let valid: &[u8] = b" ";
            let mut invalid = vec![];
            invalid.extend_from_slice(b"ValidString");
            invalid.extend_from_slice(INVALID_UTF8_FIRST_CHAR);
            let binary_array =
                BinaryArray::from_iter(vec![None, Some(valid), None, Some(&invalid)]);
            let array = unsafe {
                StringArray::new_unchecked(
                    binary_array.offsets().clone(),
                    binary_array.values().clone(),
                    binary_array.nulls().cloned(),
                )
            };
            array.into_data()
        };
        let struct_data_type = schema.field(0).data_type();

        let invalid_struct_arr = unsafe {
            make_array(
                ArrayData::builder(struct_data_type.clone())
                    .len(4)
                    .add_child_data(a_array_data)
                    .add_child_data(b_array_data)
                    .build_unchecked(),
            )
        };
        expect_ipc_validation_error(
            Arc::new(invalid_struct_arr),
            "Invalid argument error: Invalid UTF8 sequence at string index 3 (3..18): invalid utf-8 sequence of 1 bytes from index 11",
        );
    }

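    /// Two fields declaring the same dictionary id can still round-trip
    /// through the stream format when dictionary ids are not preserved.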
    #[test]
    fn test_same_dict_id_without_preserve() {
        let batch = RecordBatch::try_new(
            Arc::new(Schema::new(
                ["a", "b"]
                    .iter()
                    .map(|name| {
                        #[allow(deprecated)]
                        Field::new_dict(
                            name.to_string(),
                            DataType::Dictionary(
                                Box::new(DataType::Int32),
                                Box::new(DataType::Utf8),
                            ),
                            true,
                            0,
                            false,
                        )
                    })
                    .collect::<Vec<Field>>(),
            )),
            vec![
                Arc::new(
                    vec![Some("c"), Some("d")]
                        .into_iter()
                        .collect::<DictionaryArray<Int32Type>>(),
                ) as ArrayRef,
                Arc::new(
                    vec![Some("e"), Some("f")]
                        .into_iter()
                        .collect::<DictionaryArray<Int32Type>>(),
                ) as ArrayRef,
            ],
        )
        .expect("Failed to create RecordBatch");

        // Write the batch to an IPC stream, then read it back
        let mut buf = vec![];
        {
            let mut writer = crate::writer::StreamWriter::try_new_with_options(
                &mut buf,
                batch.schema().as_ref(),
                crate::writer::IpcWriteOptions::default(),
            )
            .expect("Failed to create StreamWriter");
            writer.write(&batch).expect("Failed to write RecordBatch");
            writer.finish().expect("Failed to finish StreamWriter");
        }

        StreamReader::try_new(std::io::Cursor::new(buf), None)
            .expect("Failed to create StreamReader")
            .for_each(|decoded_batch| {
                assert_eq!(decoded_batch.expect("Failed to read RecordBatch"), batch);
            });
    }

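    /// A `ListArray` whose offsets decrease must fail validation when read
    /// back over IPC.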
    #[test]
    fn test_validation_of_invalid_list_array() {
        let array = unsafe {
            let values = Int32Array::from(vec![1, 2, 3]);
            // Offsets must be monotonically increasing; the final 2 is invalid
            let bad_offsets = ScalarBuffer::<i32>::from(vec![0, 2, 4, 2]);
            let offsets = OffsetBuffer::new_unchecked(bad_offsets);
            let field = Field::new_list_field(DataType::Int32, true);
            let nulls = None;
            ListArray::new(Arc::new(field), offsets, Arc::new(values), nulls)
        };

        expect_ipc_validation_error(
            Arc::new(array),
            "Invalid argument error: Offset invariant failure: offset at position 2 out of bounds: 4 > 2"
        );
    }

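    /// A `StringArray` containing invalid UTF-8 must fail validation when
    /// read back over IPC.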
    #[test]
    fn test_validation_of_invalid_string_array() {
        let valid: &[u8] = b" ";
        let mut invalid = vec![];
        invalid.extend_from_slice(b"ThisStringIsCertainlyLongerThan12Bytes");
        invalid.extend_from_slice(INVALID_UTF8_FIRST_CHAR);
        let binary_array = BinaryArray::from_iter(vec![None, Some(valid), None, Some(&invalid)]);
        // Construct a StringArray over bytes that are not valid UTF-8
        let array = unsafe {
            StringArray::new_unchecked(
                binary_array.offsets().clone(),
                binary_array.values().clone(),
                binary_array.nulls().cloned(),
            )
        };
        expect_ipc_validation_error(
            Arc::new(array),
            "Invalid argument error: Invalid UTF8 sequence at string index 3 (3..45): invalid utf-8 sequence of 1 bytes from index 38"
        );
    }

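    /// A `StringViewArray` containing invalid UTF-8 must fail validation when
    /// read back over IPC; the value is longer than 12 bytes, so it lives in
    /// a data buffer rather than inline in the view.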
    #[test]
    fn test_validation_of_invalid_string_view_array() {
        let valid: &[u8] = b" ";
        let mut invalid = vec![];
        invalid.extend_from_slice(b"ThisStringIsCertainlyLongerThan12Bytes");
        invalid.extend_from_slice(INVALID_UTF8_FIRST_CHAR);
        let binary_view_array =
            BinaryViewArray::from_iter(vec![None, Some(valid), None, Some(&invalid)]);
        // Construct a StringViewArray over bytes that are not valid UTF-8
        let array = unsafe {
            StringViewArray::new_unchecked(
                binary_view_array.views().clone(),
                binary_view_array.data_buffers().to_vec(),
                binary_view_array.nulls().cloned(),
            )
        };
        expect_ipc_validation_error(
            Arc::new(array),
            "Invalid argument error: Encountered non-UTF-8 data at index 3: invalid utf-8 sequence of 1 bytes from index 38"
        );
    }

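    /// A dictionary whose keys point past the end of the values must fail
    /// validation when read back over IPC.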
    #[test]
    fn test_validation_of_invalid_dictionary_array() {
        let array = unsafe {
            let values = StringArray::from_iter_values(["a", "b", "c"]);
            // Key 200 is out of bounds for a dictionary with 3 values
            let keys = Int32Array::from(vec![1, 200]);
            DictionaryArray::new_unchecked(keys, Arc::new(values))
        };

        expect_ipc_validation_error(
            Arc::new(array),
            "Invalid argument error: Value at position 1 out of bounds: 200 (should be in [0, 2])",
        );
    }

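    /// A union array with a type id that matches none of its fields must fail
    /// validation when read back over IPC.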
    #[test]
    fn test_validation_of_invalid_union_array() {
        let array = unsafe {
            let fields = UnionFields::new(
                vec![1, 3], // field type ids
                vec![
                    Field::new("a", DataType::Int32, false),
                    Field::new("b", DataType::Utf8, false),
                ],
            );
            // Type id 2 matches neither declared field (1 or 3)
            let type_ids = ScalarBuffer::from(vec![1i8, 2, 3]);
            let offsets = None;
            let children: Vec<ArrayRef> = vec![
                Arc::new(Int32Array::from(vec![10, 20, 30])),
                Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("c")])),
            ];

            UnionArray::new_unchecked(fields, type_ids, offsets, children)
        };

        expect_ipc_validation_error(
            Arc::new(array),
            "Invalid argument error: Type Ids values must match one of the field type ids",
        );
    }

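    /// Bytes beginning an invalid UTF-8 sequence (0xa0 is a continuation byte
    /// and cannot start a character).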
    const INVALID_UTF8_FIRST_CHAR: &[u8] = &[0xa0, 0xa1, 0x20, 0x20];

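    /// Writes `array` out and asserts that reading it back fails validation
    /// with `expected_err` in the stream format, the file format, and via the
    /// `FileDecoder` API, while succeeding when validation is skipped.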
    fn expect_ipc_validation_error(array: ArrayRef, expected_err: &str) {
        let rb = RecordBatch::try_from_iter([("a", array)]).unwrap();

        // IPC Stream format: writing succeeds, validated read fails
        let buf = write_stream(&rb);
        read_stream_skip_validation(&buf).unwrap();
        let err = read_stream(&buf).unwrap_err();
        assert_eq!(err.to_string(), expected_err);

        // IPC File format: writing succeeds, validated read fails
        let buf = write_ipc(&rb);
        read_ipc_skip_validation(&buf).unwrap();
        let err = read_ipc(&buf).unwrap_err();
        assert_eq!(err.to_string(), expected_err);

        // IPC File format via FileDecoder
        read_ipc_with_decoder_skip_validation(buf.clone()).unwrap();
        let err = read_ipc_with_decoder(buf).unwrap_err();
        assert_eq!(err.to_string(), expected_err);
    }

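    /// A schema with dictionary fields round-trips through its flatbuffer
    /// message encoding.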
    #[test]
    fn test_roundtrip_schema() {
        let schema = Schema::new(vec![
            Field::new(
                "a",
                DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
                false,
            ),
            Field::new(
                "b",
                DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
                false,
            ),
        ]);

        let options = IpcWriteOptions::default();
        let data_gen = IpcDataGenerator::default();
        let mut dict_tracker = DictionaryTracker::new(false);
        let encoded_data =
            data_gen.schema_to_bytes_with_dictionary_tracker(&schema, &mut dict_tracker, &options);
        let mut schema_bytes = vec![];
        write_message(&mut schema_bytes, encoded_data, &options).expect("write_message");

        // Skip the continuation marker if present
        let begin_offset: usize = if schema_bytes[0..4].eq(&CONTINUATION_MARKER) {
            4
        } else {
            0
        };

        size_prefixed_root_as_message(&schema_bytes[begin_offset..])
            .expect_err("size_prefixed_root_as_message");

        let msg = parse_message(&schema_bytes).expect("parse_message");
        let ipc_schema = msg.header_as_schema().expect("header_as_schema");
        let new_schema = fb_to_schema(ipc_schema);

        assert_eq!(schema, new_schema);
    }
}