1mod stream;
28
29pub use stream::*;
30
31use flatbuffers::{VectorIter, VerifierOptions};
32use std::collections::{HashMap, VecDeque};
33use std::fmt;
34use std::io::{BufReader, Read, Seek, SeekFrom};
35use std::sync::Arc;
36
37use arrow_array::*;
38use arrow_buffer::{ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, ScalarBuffer};
39use arrow_data::{ArrayData, ArrayDataBuilder, UnsafeFlag};
40use arrow_schema::*;
41
42use crate::compression::CompressionCodec;
43use crate::{Block, FieldNode, Message, MetadataVersion, CONTINUATION_MARKER};
44use DataType::*;
45
46fn read_buffer(
56 buf: &crate::Buffer,
57 a_data: &Buffer,
58 compression_codec: Option<CompressionCodec>,
59) -> Result<Buffer, ArrowError> {
60 let start_offset = buf.offset() as usize;
61 let buf_data = a_data.slice_with_length(start_offset, buf.length() as usize);
62 match (buf_data.is_empty(), compression_codec) {
64 (true, _) | (_, None) => Ok(buf_data),
65 (false, Some(decompressor)) => decompressor.decompress_to_buffer(&buf_data),
66 }
67}
impl RecordBatchDecoder<'_> {
    /// Build an [`ArrayRef`] for `field`, consuming the decoder's next field
    /// node(s) and buffer(s) in the exact order the IPC writer emitted them.
    ///
    /// `variadic_counts` holds, front-to-back, the number of variadic data
    /// buffers for each view-typed (`Utf8View` / `BinaryView`) column in the
    /// batch; one entry is popped per view column encountered.
    fn create_array(
        &mut self,
        field: &Field,
        variadic_counts: &mut VecDeque<i64>,
    ) -> Result<ArrayRef, ArrowError> {
        let data_type = field.data_type();
        match data_type {
            Utf8 | Binary | LargeBinary | LargeUtf8 => {
                let field_node = self.next_node(field)?;
                // Three buffers: validity, offsets, values.
                let buffers = [
                    self.next_buffer()?,
                    self.next_buffer()?,
                    self.next_buffer()?,
                ];
                self.create_primitive_array(field_node, data_type, &buffers)
            }
            BinaryView | Utf8View => {
                let count = variadic_counts
                    .pop_front()
                    .ok_or(ArrowError::IpcError(format!(
                        "Missing variadic count for {data_type} column"
                    )))?;
                // The variadic count covers only the data buffers; add the
                // two fixed buffers (validity and views).
                let count = count + 2;
                let buffers = (0..count)
                    .map(|_| self.next_buffer())
                    .collect::<Result<Vec<_>, _>>()?;
                let field_node = self.next_node(field)?;
                self.create_primitive_array(field_node, data_type, &buffers)
            }
            FixedSizeBinary(_) => {
                let field_node = self.next_node(field)?;
                // Two buffers: validity and values (no offsets needed).
                let buffers = [self.next_buffer()?, self.next_buffer()?];
                self.create_primitive_array(field_node, data_type, &buffers)
            }
            List(ref list_field) | LargeList(ref list_field) | Map(ref list_field, _) => {
                let list_node = self.next_node(field)?;
                // Two buffers: validity and offsets; the child values follow.
                let list_buffers = [self.next_buffer()?, self.next_buffer()?];
                let values = self.create_array(list_field, variadic_counts)?;
                self.create_list_array(list_node, data_type, &list_buffers, values)
            }
            FixedSizeList(ref list_field, _) => {
                let list_node = self.next_node(field)?;
                // Only a validity buffer; offsets are implied by the fixed size.
                let list_buffers = [self.next_buffer()?];
                let values = self.create_array(list_field, variadic_counts)?;
                self.create_list_array(list_node, data_type, &list_buffers, values)
            }
            Struct(struct_fields) => {
                let struct_node = self.next_node(field)?;
                let null_buffer = self.next_buffer()?;

                // Decode all children in schema order before assembling.
                let mut struct_arrays = vec![];
                for struct_field in struct_fields {
                    let child = self.create_array(struct_field, variadic_counts)?;
                    struct_arrays.push(child);
                }
                self.create_struct_array(struct_node, null_buffer, struct_fields, struct_arrays)
            }
            RunEndEncoded(run_ends_field, values_field) => {
                let run_node = self.next_node(field)?;
                let run_ends = self.create_array(run_ends_field, variadic_counts)?;
                let values = self.create_array(values_field, variadic_counts)?;

                // Run-end encoded arrays own no buffers of their own; both
                // children carry the data.
                let run_array_length = run_node.length() as usize;
                let builder = ArrayData::builder(data_type.clone())
                    .len(run_array_length)
                    .offset(0)
                    .add_child_data(run_ends.into_data())
                    .add_child_data(values.into_data());
                self.create_array_from_builder(builder)
            }
            Dictionary(_, _) => {
                let index_node = self.next_node(field)?;
                // These buffers describe only the keys; the values were read
                // earlier from a dictionary batch and are looked up by id.
                let index_buffers = [self.next_buffer()?, self.next_buffer()?];

                #[allow(deprecated)]
                let dict_id = field.dict_id().ok_or_else(|| {
                    ArrowError::ParseError(format!("Field {field} does not have dict id"))
                })?;

                let value_array = self.dictionaries_by_id.get(&dict_id).ok_or_else(|| {
                    ArrowError::ParseError(format!(
                        "Cannot find a dictionary batch with dict id: {dict_id}"
                    ))
                })?;

                self.create_dictionary_array(
                    index_node,
                    data_type,
                    &index_buffers,
                    value_array.clone(),
                )
            }
            Union(fields, mode) => {
                let union_node = self.next_node(field)?;
                let len = union_node.length() as usize;

                // Before V5, unions carried a validity buffer ahead of the
                // type ids; consume and discard it.
                if self.version < MetadataVersion::V5 {
                    self.next_buffer()?;
                }

                let type_ids: ScalarBuffer<i8> =
                    self.next_buffer()?.slice_with_length(0, len).into();

                // Dense unions carry an additional i32 offsets buffer.
                let value_offsets = match mode {
                    UnionMode::Dense => {
                        let offsets: ScalarBuffer<i32> =
                            self.next_buffer()?.slice_with_length(0, len * 4).into();
                        Some(offsets)
                    }
                    UnionMode::Sparse => None,
                };

                let mut children = Vec::with_capacity(fields.len());

                for (_id, field) in fields.iter() {
                    let child = self.create_array(field, variadic_counts)?;
                    children.push(child);
                }

                let array = if self.skip_validation.get() {
                    // SAFETY: the skip_validation flag can only be set via
                    // unsafe code, whose caller vouches for the data.
                    unsafe {
                        UnionArray::new_unchecked(fields.clone(), type_ids, value_offsets, children)
                    }
                } else {
                    UnionArray::try_new(fields.clone(), type_ids, value_offsets, children)?
                };
                Ok(Arc::new(array))
            }
            Null => {
                let node = self.next_node(field)?;
                let length = node.length();
                let null_count = node.null_count();

                // A NullArray is all-null by definition; anything else in the
                // field node indicates corrupt metadata.
                if length != null_count {
                    return Err(ArrowError::SchemaError(format!(
                        "Field {field} of NullArray has unequal null_count {null_count} and len {length}"
                    )));
                }

                let builder = ArrayData::builder(data_type.clone())
                    .len(length as usize)
                    .offset(0);
                self.create_array_from_builder(builder)
            }
            _ => {
                // All remaining (fixed-width, primitive-like) types:
                // validity + single data buffer.
                let field_node = self.next_node(field)?;
                let buffers = [self.next_buffer()?, self.next_buffer()?];
                self.create_primitive_array(field_node, data_type, &buffers)
            }
        }
    }

    /// Build an array for a primitive-like layout from `buffers`.
    /// `buffers[0]` is always the validity bitmap.
    fn create_primitive_array(
        &self,
        field_node: &FieldNode,
        data_type: &DataType,
        buffers: &[Buffer],
    ) -> Result<ArrayRef, ArrowError> {
        let length = field_node.length() as usize;
        // Only attach a validity bitmap when the node reports nulls.
        let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
        let builder = match data_type {
            Utf8 | Binary | LargeBinary | LargeUtf8 => {
                // Variable-length binary: offsets + values buffers.
                ArrayData::builder(data_type.clone())
                    .len(length)
                    .buffers(buffers[1..3].to_vec())
                    .null_bit_buffer(null_buffer)
            }
            // View types: views buffer plus a variable number of data buffers.
            BinaryView | Utf8View => ArrayData::builder(data_type.clone())
                .len(length)
                .buffers(buffers[1..].to_vec())
                .null_bit_buffer(null_buffer),
            _ if data_type.is_primitive() || matches!(data_type, Boolean | FixedSizeBinary(_)) => {
                // Fixed-width types: a single data buffer.
                ArrayData::builder(data_type.clone())
                    .len(length)
                    .add_buffer(buffers[1].clone())
                    .null_bit_buffer(null_buffer)
            }
            t => unreachable!("Data type {:?} either unsupported or not primitive", t),
        };

        self.create_array_from_builder(builder)
    }

    /// Finalize an [`ArrayDataBuilder`], applying the decoder's alignment and
    /// validation settings.
    fn create_array_from_builder(&self, builder: ArrayDataBuilder) -> Result<ArrayRef, ArrowError> {
        // Unless strict alignment is required, allow the builder to re-align
        // buffers as needed.
        let mut builder = builder.align_buffers(!self.require_alignment);
        if self.skip_validation.get() {
            // SAFETY: the skip_validation flag can only be set via unsafe
            // code, whose caller vouches for the data.
            unsafe { builder = builder.skip_validation(true) }
        };
        Ok(make_array(builder.build()?))
    }

    /// Assemble a list-like array (`List`, `LargeList`, `Map`,
    /// `FixedSizeList`) from its buffers and already-decoded child array.
    fn create_list_array(
        &self,
        field_node: &FieldNode,
        data_type: &DataType,
        buffers: &[Buffer],
        child_array: ArrayRef,
    ) -> Result<ArrayRef, ArrowError> {
        let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
        let length = field_node.length() as usize;
        let child_data = child_array.into_data();
        let builder = match data_type {
            // Offset-based lists carry an offsets buffer.
            List(_) | LargeList(_) | Map(_, _) => ArrayData::builder(data_type.clone())
                .len(length)
                .add_buffer(buffers[1].clone())
                .add_child_data(child_data)
                .null_bit_buffer(null_buffer),

            // Fixed-size lists need no offsets buffer.
            FixedSizeList(_, _) => ArrayData::builder(data_type.clone())
                .len(length)
                .add_child_data(child_data)
                .null_bit_buffer(null_buffer),

            _ => unreachable!("Cannot create list or map array from {:?}", data_type),
        };

        self.create_array_from_builder(builder)
    }

    /// Assemble a [`StructArray`] from its validity buffer and
    /// already-decoded children.
    fn create_struct_array(
        &self,
        struct_node: &FieldNode,
        null_buffer: Buffer,
        struct_fields: &Fields,
        struct_arrays: Vec<ArrayRef>,
    ) -> Result<ArrayRef, ArrowError> {
        let null_count = struct_node.null_count() as usize;
        let len = struct_node.length() as usize;

        let nulls = (null_count > 0).then(|| BooleanBuffer::new(null_buffer, 0, len).into());
        if struct_arrays.is_empty() {
            // With zero children the length cannot be inferred from the
            // arrays, so build the empty-fields struct explicitly.
            return Ok(Arc::new(StructArray::new_empty_fields(len, nulls)));
        }

        let struct_array = if self.skip_validation.get() {
            // SAFETY: the skip_validation flag can only be set via unsafe
            // code, whose caller vouches for the data.
            unsafe { StructArray::new_unchecked(struct_fields.clone(), struct_arrays, nulls) }
        } else {
            StructArray::try_new(struct_fields.clone(), struct_arrays, nulls)?
        };

        Ok(Arc::new(struct_array))
    }

    /// Build a dictionary array from the key buffers plus the
    /// previously-registered `value_array` for this dictionary id.
    fn create_dictionary_array(
        &self,
        field_node: &FieldNode,
        data_type: &DataType,
        buffers: &[Buffer],
        value_array: ArrayRef,
    ) -> Result<ArrayRef, ArrowError> {
        if let Dictionary(_, _) = *data_type {
            let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
            let builder = ArrayData::builder(data_type.clone())
                .len(field_node.length() as usize)
                .add_buffer(buffers[1].clone())
                .add_child_data(value_array.into_data())
                .null_bit_buffer(null_buffer);
            self.create_array_from_builder(builder)
        } else {
            unreachable!("Cannot create dictionary array from {:?}", data_type)
        }
    }
}
364
/// State for decoding a single IPC record batch message into a
/// [`RecordBatch`], tracking position within the message's flattened field
/// node and buffer lists.
struct RecordBatchDecoder<'a> {
    /// The flatbuffers-encoded record batch metadata.
    batch: crate::RecordBatch<'a>,
    /// The schema of the output batch.
    schema: SchemaRef,
    /// Previously-decoded dictionary value arrays, keyed by dictionary id.
    dictionaries_by_id: &'a HashMap<i64, ArrayRef>,
    /// Optional compression applied to the batch's buffers.
    compression: Option<CompressionCodec>,
    /// IPC metadata version of the enclosing message.
    version: MetadataVersion,
    /// The message body: the concatenated array data.
    data: &'a Buffer,
    /// Cursor over the batch's field nodes (consumed in emission order).
    nodes: VectorIter<'a, FieldNode>,
    /// Cursor over the batch's buffer descriptors (consumed in emission order).
    buffers: VectorIter<'a, crate::Buffer>,
    /// Optional projection: indices into `schema` to decode.
    projection: Option<&'a [usize]>,
    /// If true, error on misaligned buffers instead of fixing them up.
    require_alignment: bool,
    /// If set (only possible via unsafe code), skip validating decoded data.
    skip_validation: UnsafeFlag,
}
397
impl<'a> RecordBatchDecoder<'a> {
    /// Create a new decoder over `batch`, whose array data lives in `buf`.
    ///
    /// Fails if the flatbuffer message lacks buffers or field nodes, or if
    /// its compression codec cannot be converted.
    fn try_new(
        buf: &'a Buffer,
        batch: crate::RecordBatch<'a>,
        schema: SchemaRef,
        dictionaries_by_id: &'a HashMap<i64, ArrayRef>,
        metadata: &'a MetadataVersion,
    ) -> Result<Self, ArrowError> {
        let buffers = batch.buffers().ok_or_else(|| {
            ArrowError::IpcError("Unable to get buffers from IPC RecordBatch".to_string())
        })?;
        let field_nodes = batch.nodes().ok_or_else(|| {
            ArrowError::IpcError("Unable to get field nodes from IPC RecordBatch".to_string())
        })?;

        let batch_compression = batch.compression();
        let compression = batch_compression
            .map(|batch_compression| batch_compression.codec().try_into())
            .transpose()?;

        Ok(Self {
            batch,
            schema,
            dictionaries_by_id,
            compression,
            version: *metadata,
            data: buf,
            nodes: field_nodes.iter(),
            buffers: buffers.iter(),
            projection: None,
            require_alignment: false,
            skip_validation: UnsafeFlag::new(),
        })
    }

    /// Set the projection: indices into the schema of the columns to decode.
    pub fn with_projection(mut self, projection: Option<&'a [usize]>) -> Self {
        self.projection = projection;
        self
    }

    /// When `true`, misaligned buffers produce an error instead of being
    /// fixed up.
    pub fn with_require_alignment(mut self, require_alignment: bool) -> Self {
        self.require_alignment = require_alignment;
        self
    }

    /// Install the (unsafe-set) flag controlling whether validation of
    /// decoded arrays is skipped.
    pub(crate) fn with_skip_validation(mut self, skip_validation: UnsafeFlag) -> Self {
        self.skip_validation = skip_validation;
        self
    }

    /// Decode the whole message into a [`RecordBatch`], consuming the decoder.
    fn read_record_batch(mut self) -> Result<RecordBatch, ArrowError> {
        // One entry per view-typed column, in schema order.
        let mut variadic_counts: VecDeque<i64> = self
            .batch
            .variadicBufferCounts()
            .into_iter()
            .flatten()
            .collect();

        // Preserve the row count even if zero columns end up selected.
        let options = RecordBatchOptions::new().with_row_count(Some(self.batch.length() as usize));

        let schema = Arc::clone(&self.schema);
        if let Some(projection) = self.projection {
            let mut arrays = vec![];
            for (idx, field) in schema.fields().iter().enumerate() {
                // Decode projected columns; columns outside the projection
                // must still be skipped so the node/buffer cursors stay in
                // sync with the message layout.
                if let Some(proj_idx) = projection.iter().position(|p| p == &idx) {
                    let child = self.create_array(field, &mut variadic_counts)?;
                    arrays.push((proj_idx, child));
                } else {
                    self.skip_field(field, &mut variadic_counts)?;
                }
            }

            // Restore the caller-requested projection order.
            arrays.sort_by_key(|t| t.0);

            let schema = Arc::new(schema.project(projection)?);
            let columns = arrays.into_iter().map(|t| t.1).collect::<Vec<_>>();

            if self.skip_validation.get() {
                // SAFETY: the skip_validation flag can only be set via unsafe
                // code, whose caller vouches for the data.
                unsafe {
                    Ok(RecordBatch::new_unchecked(
                        schema,
                        columns,
                        self.batch.length() as usize,
                    ))
                }
            } else {
                // Every variadic count must have been consumed.
                assert!(variadic_counts.is_empty());
                RecordBatch::try_new_with_options(schema, columns, &options)
            }
        } else {
            let mut children = vec![];
            for field in schema.fields() {
                let child = self.create_array(field, &mut variadic_counts)?;
                children.push(child);
            }

            if self.skip_validation.get() {
                // SAFETY: the skip_validation flag can only be set via unsafe
                // code, whose caller vouches for the data.
                unsafe {
                    Ok(RecordBatch::new_unchecked(
                        schema,
                        children,
                        self.batch.length() as usize,
                    ))
                }
            } else {
                // Every variadic count must have been consumed.
                assert!(variadic_counts.is_empty());
                RecordBatch::try_new_with_options(schema, children, &options)
            }
        }
    }

    /// Read (and possibly decompress) the next buffer in emission order.
    fn next_buffer(&mut self) -> Result<Buffer, ArrowError> {
        // Panics if the message declared fewer buffers than the schema needs.
        read_buffer(self.buffers.next().unwrap(), self.data, self.compression)
    }

    /// Advance past the next buffer without materializing it.
    fn skip_buffer(&mut self) {
        // Panics if the message declared fewer buffers than the schema needs.
        self.buffers.next().unwrap();
    }

    /// Fetch the next field node, erroring if the message ran out of nodes.
    fn next_node(&mut self, field: &Field) -> Result<&'a FieldNode, ArrowError> {
        self.nodes.next().ok_or_else(|| {
            ArrowError::SchemaError(format!(
                "Invalid data for schema. {} refers to node not found in schema",
                field
            ))
        })
    }

    /// Skip all nodes and buffers belonging to `field` (used for columns
    /// excluded by a projection). Must mirror the consumption pattern of
    /// [`Self::create_array`] exactly to keep the cursors in sync.
    fn skip_field(
        &mut self,
        field: &Field,
        variadic_count: &mut VecDeque<i64>,
    ) -> Result<(), ArrowError> {
        self.next_node(field)?;

        match field.data_type() {
            Utf8 | Binary | LargeBinary | LargeUtf8 => {
                // Validity, offsets and values buffers.
                for _ in 0..3 {
                    self.skip_buffer()
                }
            }
            Utf8View | BinaryView => {
                let count = variadic_count
                    .pop_front()
                    .ok_or(ArrowError::IpcError(format!(
                        "Missing variadic count for {} column",
                        field.data_type()
                    )))?;
                // Plus the two fixed buffers (validity and views).
                let count = count + 2;
                for _i in 0..count {
                    self.skip_buffer()
                }
            }
            FixedSizeBinary(_) => {
                self.skip_buffer();
                self.skip_buffer();
            }
            List(list_field) | LargeList(list_field) | Map(list_field, _) => {
                self.skip_buffer();
                self.skip_buffer();
                self.skip_field(list_field, variadic_count)?;
            }
            FixedSizeList(list_field, _) => {
                self.skip_buffer();
                self.skip_field(list_field, variadic_count)?;
            }
            Struct(struct_fields) => {
                self.skip_buffer();

                // Recurse into every child in schema order.
                for struct_field in struct_fields {
                    self.skip_field(struct_field, variadic_count)?
                }
            }
            RunEndEncoded(run_ends_field, values_field) => {
                self.skip_field(run_ends_field, variadic_count)?;
                self.skip_field(values_field, variadic_count)?;
            }
            Dictionary(_, _) => {
                self.skip_buffer(); // validity
                self.skip_buffer(); // keys
            }
            Union(fields, mode) => {
                // NOTE(review): `create_array` consumes an extra validity
                // buffer for unions when version < V5, but no corresponding
                // extra skip happens here — confirm projection behavior over
                // pre-V5 union data.
                self.skip_buffer(); // type ids
                match mode {
                    UnionMode::Dense => self.skip_buffer(),
                    UnionMode::Sparse => {}
                };

                for (_, field) in fields.iter() {
                    self.skip_field(field, variadic_count)?
                }
            }
            // Null arrays have no buffers at all.
            Null => {}
            _ => {
                // Primitive-like types: validity + single data buffer.
                self.skip_buffer();
                self.skip_buffer();
            }
        };
        Ok(())
    }
}
628
629pub fn read_record_batch(
640 buf: &Buffer,
641 batch: crate::RecordBatch,
642 schema: SchemaRef,
643 dictionaries_by_id: &HashMap<i64, ArrayRef>,
644 projection: Option<&[usize]>,
645 metadata: &MetadataVersion,
646) -> Result<RecordBatch, ArrowError> {
647 RecordBatchDecoder::try_new(buf, batch, schema, dictionaries_by_id, metadata)?
648 .with_projection(projection)
649 .with_require_alignment(false)
650 .read_record_batch()
651}
652
653pub fn read_dictionary(
656 buf: &Buffer,
657 batch: crate::DictionaryBatch,
658 schema: &Schema,
659 dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
660 metadata: &MetadataVersion,
661) -> Result<(), ArrowError> {
662 read_dictionary_impl(
663 buf,
664 batch,
665 schema,
666 dictionaries_by_id,
667 metadata,
668 false,
669 UnsafeFlag::new(),
670 )
671}
672
673fn read_dictionary_impl(
674 buf: &Buffer,
675 batch: crate::DictionaryBatch,
676 schema: &Schema,
677 dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
678 metadata: &MetadataVersion,
679 require_alignment: bool,
680 skip_validation: UnsafeFlag,
681) -> Result<(), ArrowError> {
682 if batch.isDelta() {
683 return Err(ArrowError::InvalidArgumentError(
684 "delta dictionary batches not supported".to_string(),
685 ));
686 }
687
688 let id = batch.id();
689 #[allow(deprecated)]
690 let fields_using_this_dictionary = schema.fields_with_dict_id(id);
691 let first_field = fields_using_this_dictionary.first().ok_or_else(|| {
692 ArrowError::InvalidArgumentError(format!("dictionary id {id} not found in schema"))
693 })?;
694
695 let dictionary_values: ArrayRef = match first_field.data_type() {
699 DataType::Dictionary(_, ref value_type) => {
700 let value = value_type.as_ref().clone();
702 let schema = Schema::new(vec![Field::new("", value, true)]);
703 let record_batch = RecordBatchDecoder::try_new(
705 buf,
706 batch.data().unwrap(),
707 Arc::new(schema),
708 dictionaries_by_id,
709 metadata,
710 )?
711 .with_require_alignment(require_alignment)
712 .with_skip_validation(skip_validation)
713 .read_record_batch()?;
714
715 Some(record_batch.column(0).clone())
716 }
717 _ => None,
718 }
719 .ok_or_else(|| {
720 ArrowError::InvalidArgumentError(format!("dictionary id {id} not found in schema"))
721 })?;
722
723 dictionaries_by_id.insert(id, dictionary_values.clone());
727
728 Ok(())
729}
730
731fn read_block<R: Read + Seek>(mut reader: R, block: &Block) -> Result<Buffer, ArrowError> {
733 reader.seek(SeekFrom::Start(block.offset() as u64))?;
734 let body_len = block.bodyLength().to_usize().unwrap();
735 let metadata_len = block.metaDataLength().to_usize().unwrap();
736 let total_len = body_len.checked_add(metadata_len).unwrap();
737
738 let mut buf = MutableBuffer::from_len_zeroed(total_len);
739 reader.read_exact(&mut buf)?;
740 Ok(buf.into())
741}
742
743fn parse_message(buf: &[u8]) -> Result<Message, ArrowError> {
747 let buf = match buf[..4] == CONTINUATION_MARKER {
748 true => &buf[8..],
749 false => &buf[4..],
750 };
751 crate::root_as_message(buf)
752 .map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}")))
753}
754
755pub fn read_footer_length(buf: [u8; 10]) -> Result<usize, ArrowError> {
759 if buf[4..] != super::ARROW_MAGIC {
760 return Err(ArrowError::ParseError(
761 "Arrow file does not contain correct footer".to_string(),
762 ));
763 }
764
765 let footer_len = i32::from_le_bytes(buf[..4].try_into().unwrap());
767 footer_len
768 .try_into()
769 .map_err(|_| ArrowError::ParseError(format!("Invalid footer length: {footer_len}")))
770}
771
/// Incrementally decodes the blocks of an Arrow IPC file, resolving
/// dictionary batches as they are read so record batches can reference them.
#[derive(Debug)]
pub struct FileDecoder {
    /// The file's schema.
    schema: SchemaRef,
    /// Dictionary value arrays read so far, keyed by dictionary id.
    dictionaries: HashMap<i64, ArrayRef>,
    /// The IPC metadata version of the file.
    version: MetadataVersion,
    /// Optional projection (column indices) applied to each decoded batch.
    projection: Option<Vec<usize>>,
    /// If true, error on misaligned buffers instead of fixing them up.
    require_alignment: bool,
    /// If set (only possible via unsafe code), skip validating decoded data.
    skip_validation: UnsafeFlag,
}
845
impl FileDecoder {
    /// Create a new [`FileDecoder`] with the given schema and metadata version.
    pub fn new(schema: SchemaRef, version: MetadataVersion) -> Self {
        Self {
            schema,
            version,
            dictionaries: Default::default(),
            projection: None,
            require_alignment: false,
            skip_validation: UnsafeFlag::new(),
        }
    }

    /// Specify a projection: the indices of the columns to decode.
    pub fn with_projection(mut self, projection: Vec<usize>) -> Self {
        self.projection = Some(projection);
        self
    }

    /// When `true`, misaligned buffers produce an error instead of being
    /// fixed up.
    pub fn with_require_alignment(mut self, require_alignment: bool) -> Self {
        self.require_alignment = require_alignment;
        self
    }

    /// Skip validation of decoded arrays.
    ///
    /// # Safety
    /// The caller must guarantee the IPC data is valid; skipping validation
    /// on invalid data may lead to undefined behavior downstream.
    pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self {
        self.skip_validation.set(skip_validation);
        self
    }

    /// Parse `buf` as an IPC message and check its metadata version against
    /// the file's. V1 files are exempt from the version check.
    fn read_message<'a>(&self, buf: &'a [u8]) -> Result<Message<'a>, ArrowError> {
        let message = parse_message(buf)?;

        if self.version != MetadataVersion::V1 && message.version() != self.version {
            return Err(ArrowError::IpcError(
                "Could not read IPC message as metadata versions mismatch".to_string(),
            ));
        }
        Ok(message)
    }

    /// Read the dictionary batch stored in `buf` (the bytes of `block`) and
    /// register it for later record batches to resolve.
    pub fn read_dictionary(&mut self, block: &Block, buf: &Buffer) -> Result<(), ArrowError> {
        let message = self.read_message(buf)?;
        match message.header_type() {
            crate::MessageHeader::DictionaryBatch => {
                let batch = message.header_as_dictionary_batch().unwrap();
                read_dictionary_impl(
                    // The message body starts after the flatbuffer metadata.
                    &buf.slice(block.metaDataLength() as _),
                    batch,
                    &self.schema,
                    &mut self.dictionaries,
                    &message.version(),
                    self.require_alignment,
                    self.skip_validation.clone(),
                )
            }
            t => Err(ArrowError::ParseError(format!(
                "Expecting DictionaryBatch in dictionary blocks, found {t:?}."
            ))),
        }
    }

    /// Decode the record batch stored in `buf` (the bytes of `block`).
    /// Returns `Ok(None)` when the message carries no header.
    pub fn read_record_batch(
        &self,
        block: &Block,
        buf: &Buffer,
    ) -> Result<Option<RecordBatch>, ArrowError> {
        let message = self.read_message(buf)?;
        match message.header_type() {
            crate::MessageHeader::Schema => Err(ArrowError::IpcError(
                "Not expecting a schema when messages are read".to_string(),
            )),
            crate::MessageHeader::RecordBatch => {
                let batch = message.header_as_record_batch().ok_or_else(|| {
                    ArrowError::IpcError("Unable to read IPC message as record batch".to_string())
                })?;
                RecordBatchDecoder::try_new(
                    // The message body starts after the flatbuffer metadata.
                    &buf.slice(block.metaDataLength() as _),
                    batch,
                    self.schema.clone(),
                    &self.dictionaries,
                    &message.version(),
                )?
                .with_projection(self.projection.as_deref())
                .with_require_alignment(self.require_alignment)
                .with_skip_validation(self.skip_validation.clone())
                .read_record_batch()
                .map(Some)
            }
            crate::MessageHeader::NONE => Ok(None),
            t => Err(ArrowError::InvalidArgumentError(format!(
                "Reading types other than record batches not yet supported, unable to read {t:?}"
            ))),
        }
    }
}
967
/// Builder for configuring and constructing a [`FileReader`].
#[derive(Debug)]
pub struct FileReaderBuilder {
    /// Optional projection (column indices) applied to each decoded batch.
    projection: Option<Vec<usize>>,
    /// Flatbuffers verifier limit: maximum number of tables in the footer.
    max_footer_fb_tables: usize,
    /// Flatbuffers verifier limit: maximum nesting depth in the footer.
    max_footer_fb_depth: usize,
}
978
979impl Default for FileReaderBuilder {
980 fn default() -> Self {
981 let verifier_options = VerifierOptions::default();
982 Self {
983 max_footer_fb_tables: verifier_options.max_tables,
984 max_footer_fb_depth: verifier_options.max_depth,
985 projection: None,
986 }
987 }
988}
989
impl FileReaderBuilder {
    /// Create a builder with default options.
    ///
    /// To convert a builder into a reader, call [`FileReaderBuilder::build`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Optional projection: the zero-based indices of the columns to load.
    pub fn with_projection(mut self, projection: Vec<usize>) -> Self {
        self.projection = Some(projection);
        self
    }

    /// Flatbuffers option: maximum number of tables permitted when verifying
    /// the file's footer.
    pub fn with_max_footer_fb_tables(mut self, max_footer_fb_tables: usize) -> Self {
        self.max_footer_fb_tables = max_footer_fb_tables;
        self
    }

    /// Flatbuffers option: maximum table nesting depth permitted when
    /// verifying the file's footer.
    pub fn with_max_footer_fb_depth(mut self, max_footer_fb_depth: usize) -> Self {
        self.max_footer_fb_depth = max_footer_fb_depth;
        self
    }

    /// Build a [`FileReader`] over `reader`: verify the footer, read the
    /// schema and custom metadata, and eagerly load all dictionary blocks.
    pub fn build<R: Read + Seek>(self, mut reader: R) -> Result<FileReader<R>, ArrowError> {
        // The file ends with a 4-byte footer length followed by the 6-byte
        // magic: 10 bytes total.
        let mut buffer = [0; 10];
        reader.seek(SeekFrom::End(-10))?;
        reader.read_exact(&mut buffer)?;

        let footer_len = read_footer_length(buffer)?;

        // The footer flatbuffer sits immediately before the trailer.
        let mut footer_data = vec![0; footer_len];
        reader.seek(SeekFrom::End(-10 - footer_len as i64))?;
        reader.read_exact(&mut footer_data)?;

        let verifier_options = VerifierOptions {
            max_tables: self.max_footer_fb_tables,
            max_depth: self.max_footer_fb_depth,
            ..Default::default()
        };
        let footer = crate::root_as_footer_with_opts(&verifier_options, &footer_data[..]).map_err(
            |err| ArrowError::ParseError(format!("Unable to get root as footer: {err:?}")),
        )?;

        let blocks = footer.recordBatches().ok_or_else(|| {
            ArrowError::ParseError("Unable to get record batches from IPC Footer".to_string())
        })?;

        let total_blocks = blocks.len();

        let ipc_schema = footer.schema().unwrap();
        // Cross-endian files are not supported.
        if !ipc_schema.endianness().equals_to_target_endianness() {
            return Err(ArrowError::IpcError(
                "the endianness of the source system does not match the endianness of the target system.".to_owned()
            ));
        }

        let schema = crate::convert::fb_to_schema(ipc_schema);

        // Copy any key/value metadata stored in the footer.
        let mut custom_metadata = HashMap::new();
        if let Some(fb_custom_metadata) = footer.custom_metadata() {
            for kv in fb_custom_metadata.into_iter() {
                custom_metadata.insert(
                    kv.key().unwrap().to_string(),
                    kv.value().unwrap().to_string(),
                );
            }
        }

        let mut decoder = FileDecoder::new(Arc::new(schema), footer.version());
        if let Some(projection) = self.projection {
            decoder = decoder.with_projection(projection)
        }

        // Load all dictionaries up front so record batches can resolve them.
        if let Some(dictionaries) = footer.dictionaries() {
            for block in dictionaries {
                let buf = read_block(&mut reader, block)?;
                decoder.read_dictionary(block, &buf)?;
            }
        }

        Ok(FileReader {
            reader,
            blocks: blocks.iter().copied().collect(),
            current_block: 0,
            total_blocks,
            decoder,
            custom_metadata,
        })
    }
}
1109
/// Arrow File Reader: iterates over the record batches of an Arrow IPC file.
pub struct FileReader<R> {
    /// The underlying byte source.
    reader: R,

    /// Decoder state (schema, dictionaries, projection, options).
    decoder: FileDecoder,

    /// Positions of the record-batch blocks, taken from the file footer.
    blocks: Vec<Block>,

    /// Index of the next block to read.
    current_block: usize,

    /// Total number of record-batch blocks in the file.
    total_blocks: usize,

    /// User-defined metadata from the file footer.
    custom_metadata: HashMap<String, String>,
}
1175
1176impl<R> fmt::Debug for FileReader<R> {
1177 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
1178 f.debug_struct("FileReader<R>")
1179 .field("decoder", &self.decoder)
1180 .field("blocks", &self.blocks)
1181 .field("current_block", &self.current_block)
1182 .field("total_blocks", &self.total_blocks)
1183 .finish_non_exhaustive()
1184 }
1185}
1186
1187impl<R: Read + Seek> FileReader<BufReader<R>> {
1188 pub fn try_new_buffered(reader: R, projection: Option<Vec<usize>>) -> Result<Self, ArrowError> {
1192 Self::try_new(BufReader::new(reader), projection)
1193 }
1194}
1195
impl<R: Read + Seek> FileReader<R> {
    /// Try to create a new file reader.
    ///
    /// No internal buffering is added; if buffered reads are needed, use
    /// [`FileReader::try_new_buffered`] instead.
    pub fn try_new(reader: R, projection: Option<Vec<usize>>) -> Result<Self, ArrowError> {
        let builder = FileReaderBuilder {
            projection,
            ..Default::default()
        };
        builder.build(reader)
    }

    /// Return user-defined metadata from the file footer.
    pub fn custom_metadata(&self) -> &HashMap<String, String> {
        &self.custom_metadata
    }

    /// Return the number of record batches in the file.
    pub fn num_batches(&self) -> usize {
        self.total_blocks
    }

    /// Return the schema of the file.
    pub fn schema(&self) -> SchemaRef {
        self.decoder.schema.clone()
    }

    /// Seek to a specific batch by index; iteration resumes from there.
    pub fn set_index(&mut self, index: usize) -> Result<(), ArrowError> {
        if index >= self.total_blocks {
            Err(ArrowError::InvalidArgumentError(format!(
                "Cannot set batch to index {} from {} total batches",
                index, self.total_blocks
            )))
        } else {
            self.current_block = index;
            Ok(())
        }
    }

    /// Read and decode the next block. The caller must ensure
    /// `current_block < total_blocks` (enforced by the `Iterator` impl).
    fn maybe_next(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        let block = &self.blocks[self.current_block];
        self.current_block += 1;

        // Read the whole block (metadata + body) in one go, then decode it.
        let buffer = read_block(&mut self.reader, block)?;
        self.decoder.read_record_batch(block, &buffer)
    }

    /// Get a reference to the underlying reader.
    ///
    /// Reading from it directly will desynchronize the file reader.
    pub fn get_ref(&self) -> &R {
        &self.reader
    }

    /// Get a mutable reference to the underlying reader.
    ///
    /// Reading from it directly will desynchronize the file reader.
    pub fn get_mut(&mut self) -> &mut R {
        &mut self.reader
    }

    /// Skip validation of decoded arrays.
    ///
    /// # Safety
    /// The caller must guarantee the IPC data is valid; skipping validation
    /// on invalid data may lead to undefined behavior downstream.
    pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self {
        self.decoder = self.decoder.with_skip_validation(skip_validation);
        self
    }
}
1278
1279impl<R: Read + Seek> Iterator for FileReader<R> {
1280 type Item = Result<RecordBatch, ArrowError>;
1281
1282 fn next(&mut self) -> Option<Self::Item> {
1283 if self.current_block < self.total_blocks {
1285 self.maybe_next().transpose()
1286 } else {
1287 None
1288 }
1289 }
1290}
1291
1292impl<R: Read + Seek> RecordBatchReader for FileReader<R> {
1293 fn schema(&self) -> SchemaRef {
1294 self.schema()
1295 }
1296}
1297
/// Arrow Stream Reader: iterates over the record batches of an Arrow IPC
/// stream, resolving dictionary batches as they arrive.
pub struct StreamReader<R> {
    /// The underlying byte source.
    reader: R,

    /// The stream's schema, read from its first message.
    schema: SchemaRef,

    /// Dictionary value arrays read so far, keyed by dictionary id.
    dictionaries_by_id: HashMap<i64, ArrayRef>,

    /// Set once the end-of-stream marker (or EOF) has been seen.
    finished: bool,

    /// Optional projection: the selected column indices and the resulting
    /// projected schema.
    projection: Option<(Vec<usize>, Schema)>,

    /// If set (only possible via unsafe code), skip validating decoded data.
    skip_validation: UnsafeFlag,
}
1356
1357impl<R> fmt::Debug for StreamReader<R> {
1358 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::result::Result<(), fmt::Error> {
1359 f.debug_struct("StreamReader<R>")
1360 .field("reader", &"R")
1361 .field("schema", &self.schema)
1362 .field("dictionaries_by_id", &self.dictionaries_by_id)
1363 .field("finished", &self.finished)
1364 .field("projection", &self.projection)
1365 .finish()
1366 }
1367}
1368
1369impl<R: Read> StreamReader<BufReader<R>> {
1370 pub fn try_new_buffered(reader: R, projection: Option<Vec<usize>>) -> Result<Self, ArrowError> {
1374 Self::try_new(BufReader::new(reader), projection)
1375 }
1376}
1377
1378impl<R: Read> StreamReader<R> {
1379 pub fn try_new(
1391 mut reader: R,
1392 projection: Option<Vec<usize>>,
1393 ) -> Result<StreamReader<R>, ArrowError> {
1394 let mut meta_size: [u8; 4] = [0; 4];
1396 reader.read_exact(&mut meta_size)?;
1397 let meta_len = {
1398 if meta_size == CONTINUATION_MARKER {
1401 reader.read_exact(&mut meta_size)?;
1402 }
1403 i32::from_le_bytes(meta_size)
1404 };
1405
1406 let mut meta_buffer = vec![0; meta_len as usize];
1407 reader.read_exact(&mut meta_buffer)?;
1408
1409 let message = crate::root_as_message(meta_buffer.as_slice()).map_err(|err| {
1410 ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
1411 })?;
1412 let ipc_schema: crate::Schema = message.header_as_schema().ok_or_else(|| {
1414 ArrowError::ParseError("Unable to read IPC message as schema".to_string())
1415 })?;
1416 let schema = crate::convert::fb_to_schema(ipc_schema);
1417
1418 let dictionaries_by_id = HashMap::new();
1420
1421 let projection = match projection {
1422 Some(projection_indices) => {
1423 let schema = schema.project(&projection_indices)?;
1424 Some((projection_indices, schema))
1425 }
1426 _ => None,
1427 };
1428 Ok(Self {
1429 reader,
1430 schema: Arc::new(schema),
1431 finished: false,
1432 dictionaries_by_id,
1433 projection,
1434 skip_validation: UnsafeFlag::new(),
1435 })
1436 }
1437
1438 #[deprecated(since = "53.0.0", note = "use `try_new` instead")]
1440 pub fn try_new_unbuffered(
1441 reader: R,
1442 projection: Option<Vec<usize>>,
1443 ) -> Result<Self, ArrowError> {
1444 Self::try_new(reader, projection)
1445 }
1446
1447 pub fn schema(&self) -> SchemaRef {
1449 self.schema.clone()
1450 }
1451
1452 pub fn is_finished(&self) -> bool {
1454 self.finished
1455 }
1456
1457 fn maybe_next(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
1458 if self.finished {
1459 return Ok(None);
1460 }
1461 let mut meta_size: [u8; 4] = [0; 4];
1463
1464 match self.reader.read_exact(&mut meta_size) {
1465 Ok(()) => (),
1466 Err(e) => {
1467 return if e.kind() == std::io::ErrorKind::UnexpectedEof {
1468 self.finished = true;
1472 Ok(None)
1473 } else {
1474 Err(ArrowError::from(e))
1475 };
1476 }
1477 }
1478
1479 let meta_len = {
1480 if meta_size == CONTINUATION_MARKER {
1483 self.reader.read_exact(&mut meta_size)?;
1484 }
1485 i32::from_le_bytes(meta_size)
1486 };
1487
1488 if meta_len == 0 {
1489 self.finished = true;
1491 return Ok(None);
1492 }
1493
1494 let mut meta_buffer = vec![0; meta_len as usize];
1495 self.reader.read_exact(&mut meta_buffer)?;
1496
1497 let vecs = &meta_buffer.to_vec();
1498 let message = crate::root_as_message(vecs).map_err(|err| {
1499 ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
1500 })?;
1501
1502 match message.header_type() {
1503 crate::MessageHeader::Schema => Err(ArrowError::IpcError(
1504 "Not expecting a schema when messages are read".to_string(),
1505 )),
1506 crate::MessageHeader::RecordBatch => {
1507 let batch = message.header_as_record_batch().ok_or_else(|| {
1508 ArrowError::IpcError("Unable to read IPC message as record batch".to_string())
1509 })?;
1510 let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize);
1512 self.reader.read_exact(&mut buf)?;
1513
1514 RecordBatchDecoder::try_new(
1515 &buf.into(),
1516 batch,
1517 self.schema(),
1518 &self.dictionaries_by_id,
1519 &message.version(),
1520 )?
1521 .with_projection(self.projection.as_ref().map(|x| x.0.as_ref()))
1522 .with_require_alignment(false)
1523 .with_skip_validation(self.skip_validation.clone())
1524 .read_record_batch()
1525 .map(Some)
1526 }
1527 crate::MessageHeader::DictionaryBatch => {
1528 let batch = message.header_as_dictionary_batch().ok_or_else(|| {
1529 ArrowError::IpcError(
1530 "Unable to read IPC message as dictionary batch".to_string(),
1531 )
1532 })?;
1533 let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize);
1535 self.reader.read_exact(&mut buf)?;
1536
1537 read_dictionary_impl(
1538 &buf.into(),
1539 batch,
1540 &self.schema,
1541 &mut self.dictionaries_by_id,
1542 &message.version(),
1543 false,
1544 self.skip_validation.clone(),
1545 )?;
1546
1547 self.maybe_next()
1549 }
1550 crate::MessageHeader::NONE => Ok(None),
1551 t => Err(ArrowError::InvalidArgumentError(format!(
1552 "Reading types other than record batches not yet supported, unable to read {t:?} "
1553 ))),
1554 }
1555 }
1556
    /// Gets a reference to the underlying reader.
    ///
    /// It is inadvisable to directly read from the underlying reader.
    pub fn get_ref(&self) -> &R {
        &self.reader
    }
1563
    /// Gets a mutable reference to the underlying reader.
    ///
    /// It is inadvisable to directly read from the underlying reader:
    /// bytes consumed out-of-band will corrupt the stream position.
    pub fn get_mut(&mut self) -> &mut R {
        &mut self.reader
    }
1570
    /// Specifies whether or not array data in input buffers is validated.
    ///
    /// # Safety
    ///
    /// When `skip_validation` is `true`, the caller must guarantee the
    /// input data is valid Arrow IPC; decoding invalid data without
    /// validation is undefined behavior.
    pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self {
        self.skip_validation.set(skip_validation);
        self
    }
1580}
1581
1582impl<R: Read> Iterator for StreamReader<R> {
1583 type Item = Result<RecordBatch, ArrowError>;
1584
1585 fn next(&mut self) -> Option<Self::Item> {
1586 self.maybe_next().transpose()
1587 }
1588}
1589
1590impl<R: Read> RecordBatchReader for StreamReader<R> {
1591 fn schema(&self) -> SchemaRef {
1592 self.schema.clone()
1593 }
1594}
1595
1596#[cfg(test)]
1597mod tests {
1598 use crate::convert::fb_to_schema;
1599 use crate::writer::{
1600 unslice_run_array, write_message, DictionaryTracker, IpcDataGenerator, IpcWriteOptions,
1601 };
1602
1603 use super::*;
1604
1605 use crate::{root_as_footer, root_as_message, size_prefixed_root_as_message};
1606 use arrow_array::builder::{PrimitiveRunBuilder, UnionBuilder};
1607 use arrow_array::types::*;
1608 use arrow_buffer::{NullBuffer, OffsetBuffer};
1609 use arrow_data::ArrayDataBuilder;
1610
    /// Builds a 14-field schema covering primitives, dense union, list,
    /// fixed-size list/binary, struct, run-end-encoded and dictionary
    /// types, used by the projection round-trip tests below.
    fn create_test_projection_schema() -> Schema {
        // define the nested field types first
        let list_data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true)));

        let fixed_size_list_data_type =
            DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Int32, false)), 3);

        let union_fields = UnionFields::new(
            vec![0, 1],
            vec![
                Field::new("a", DataType::Int32, false),
                Field::new("b", DataType::Float64, false),
            ],
        );

        let union_data_type = DataType::Union(union_fields, UnionMode::Dense);

        let struct_fields = Fields::from(vec![
            Field::new("id", DataType::Int32, false),
            Field::new_list("list", Field::new_list_field(DataType::Int8, true), false),
        ]);
        let struct_data_type = DataType::Struct(struct_fields);

        let run_encoded_data_type = DataType::RunEndEncoded(
            Arc::new(Field::new("run_ends", DataType::Int16, false)),
            Arc::new(Field::new("values", DataType::Int32, true)),
        );

        // assemble the schema
        Schema::new(vec![
            Field::new("f0", DataType::UInt32, false),
            Field::new("f1", DataType::Utf8, false),
            Field::new("f2", DataType::Boolean, false),
            Field::new("f3", union_data_type, true),
            Field::new("f4", DataType::Null, true),
            Field::new("f5", DataType::Float64, true),
            Field::new("f6", list_data_type, false),
            Field::new("f7", DataType::FixedSizeBinary(3), true),
            Field::new("f8", fixed_size_list_data_type, false),
            Field::new("f9", struct_data_type, false),
            Field::new("f10", run_encoded_data_type, false),
            Field::new("f11", DataType::Boolean, false),
            Field::new_dictionary("f12", DataType::Int8, DataType::Utf8, false),
            Field::new("f13", DataType::Utf8, false),
        ])
    }
1657
    /// Builds a three-row `RecordBatch` whose columns match
    /// `create_test_projection_schema` (one array per field f0..f13).
    fn create_test_projection_batch_data(schema: &Schema) -> RecordBatch {
        // simple primitive columns
        let array0 = UInt32Array::from(vec![1, 2, 3]);
        let array1 = StringArray::from(vec!["foo", "bar", "baz"]);
        let array2 = BooleanArray::from(vec![true, false, true]);

        // dense union with int and float children
        let mut union_builder = UnionBuilder::new_dense();
        union_builder.append::<Int32Type>("a", 1).unwrap();
        union_builder.append::<Float64Type>("b", 10.1).unwrap();
        union_builder.append_null::<Float64Type>("b").unwrap();
        let array3 = union_builder.build().unwrap();

        let array4 = NullArray::new(3);
        let array5 = Float64Array::from(vec![Some(1.1), None, Some(3.3)]);
        let array6_values = vec![
            Some(vec![Some(10), Some(10), Some(10)]),
            Some(vec![Some(20), Some(20), Some(20)]),
            Some(vec![Some(30), Some(30)]),
        ];
        let array6 = ListArray::from_iter_primitive::<Int32Type, _, _>(array6_values);
        let array7_values = vec![vec![11, 12, 13], vec![22, 23, 24], vec![33, 34, 35]];
        let array7 = FixedSizeBinaryArray::try_from_iter(array7_values.into_iter()).unwrap();

        // fixed-size list: 9 child values, 3 lists of length 3
        let array8_values = ArrayData::builder(DataType::Int32)
            .len(9)
            .add_buffer(Buffer::from_slice_ref([40, 41, 42, 43, 44, 45, 46, 47, 48]))
            .build()
            .unwrap();
        let array8_data = ArrayData::builder(schema.field(8).data_type().clone())
            .len(3)
            .add_child_data(array8_values)
            .build()
            .unwrap();
        let array8 = FixedSizeListArray::from(array8_data);

        // struct of { id: Int32, list: List<Int8> }
        let array9_id: ArrayRef = Arc::new(Int32Array::from(vec![1001, 1002, 1003]));
        let array9_list: ArrayRef =
            Arc::new(ListArray::from_iter_primitive::<Int8Type, _, _>(vec![
                Some(vec![Some(-10)]),
                Some(vec![Some(-20), Some(-20), Some(-20)]),
                Some(vec![Some(-30)]),
            ]));
        let array9 = ArrayDataBuilder::new(schema.field(9).data_type().clone())
            .add_child_data(array9_id.into_data())
            .add_child_data(array9_list.into_data())
            .len(3)
            .build()
            .unwrap();
        let array9: ArrayRef = Arc::new(StructArray::from(array9));

        // run-end encoded column
        let array10_input = vec![Some(1_i32), None, None];
        let mut array10_builder = PrimitiveRunBuilder::<Int16Type, Int32Type>::new();
        array10_builder.extend(array10_input);
        let array10 = array10_builder.finish();

        let array11 = BooleanArray::from(vec![false, false, true]);

        // dictionary column
        let array12_values = StringArray::from(vec!["x", "yy", "zzz"]);
        let array12_keys = Int8Array::from_iter_values([1, 1, 2]);
        let array12 = DictionaryArray::new(array12_keys, Arc::new(array12_values));

        let array13 = StringArray::from(vec!["a", "bb", "ccc"]);

        // assemble the batch in schema order
        RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(array0),
                Arc::new(array1),
                Arc::new(array2),
                Arc::new(array3),
                Arc::new(array4),
                Arc::new(array5),
                Arc::new(array6),
                Arc::new(array7),
                Arc::new(array8),
                Arc::new(array9),
                Arc::new(array10),
                Arc::new(array11),
                Arc::new(array12),
                Arc::new(array13),
            ],
        )
        .unwrap()
    }
1743
    #[test]
    fn test_projection_array_values() {
        // define a schema
        let schema = create_test_projection_schema();

        // create a record batch with test data
        let batch = create_test_projection_batch_data(&schema);

        // write the record batch in IPC file format
        let mut buf = Vec::new();
        {
            let mut writer = crate::writer::FileWriter::try_new(&mut buf, &schema).unwrap();
            writer.write(&batch).unwrap();
            writer.finish().unwrap();
        }

        // read back each column individually via a single-column projection
        for index in 0..12 {
            let projection = vec![index];
            let reader = FileReader::try_new(std::io::Cursor::new(buf.clone()), Some(projection));
            let read_batch = reader.unwrap().next().unwrap().unwrap();
            let projected_column = read_batch.column(0);
            let expected_column = batch.column(index);

            // the projected column must equal the original column
            assert_eq!(projected_column.as_ref(), expected_column.as_ref());
        }

        {
            // a multi-column projection may also reorder columns
            let reader =
                FileReader::try_new(std::io::Cursor::new(buf.clone()), Some(vec![3, 2, 1]));
            let read_batch = reader.unwrap().next().unwrap().unwrap();
            let expected_batch = batch.project(&[3, 2, 1]).unwrap();
            assert_eq!(read_batch, expected_batch);
        }
    }
1781
    #[test]
    fn test_arrow_single_float_row() {
        // single-row batch with two floats and two ints
        let schema = Schema::new(vec![
            Field::new("a", DataType::Float32, false),
            Field::new("b", DataType::Float32, false),
            Field::new("c", DataType::Int32, false),
            Field::new("d", DataType::Int32, false),
        ]);
        let arrays = vec![
            Arc::new(Float32Array::from(vec![1.23])) as ArrayRef,
            Arc::new(Float32Array::from(vec![-6.50])) as ArrayRef,
            Arc::new(Int32Array::from(vec![2])) as ArrayRef,
            Arc::new(Int32Array::from(vec![1])) as ArrayRef,
        ];
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), arrays).unwrap();
        // write the stream format to an actual file
        let mut file = tempfile::tempfile().unwrap();
        let mut stream_writer = crate::writer::StreamWriter::try_new(&mut file, &schema).unwrap();
        stream_writer.write(&batch).unwrap();
        stream_writer.finish().unwrap();

        drop(stream_writer);

        file.rewind().unwrap();

        // read the file back without projection
        let reader = StreamReader::try_new(&mut file, None).unwrap();

        reader.for_each(|batch| {
            let batch = batch.unwrap();
            // the float values must survive the round trip (non-zero)
            assert!(
                batch
                    .column(0)
                    .as_any()
                    .downcast_ref::<Float32Array>()
                    .unwrap()
                    .value(0)
                    != 0.0
            );
            assert!(
                batch
                    .column(1)
                    .as_any()
                    .downcast_ref::<Float32Array>()
                    .unwrap()
                    .value(0)
                    != 0.0
            );
        });

        file.rewind().unwrap();

        // read again with a projection selecting columns 0 and 3
        let reader = StreamReader::try_new(file, Some(vec![0, 3])).unwrap();

        reader.for_each(|batch| {
            let batch = batch.unwrap();
            assert_eq!(batch.schema().fields().len(), 2);
            assert_eq!(batch.schema().fields()[0].data_type(), &DataType::Float32);
            assert_eq!(batch.schema().fields()[1].data_type(), &DataType::Int32);
        });
    }
1844
    /// Serializes a batch to an in-memory buffer in IPC *file* format.
    fn write_ipc(rb: &RecordBatch) -> Vec<u8> {
        let mut buf = Vec::new();
        let mut writer = crate::writer::FileWriter::try_new(&mut buf, rb.schema_ref()).unwrap();
        writer.write(rb).unwrap();
        writer.finish().unwrap();
        buf
    }
1853
    /// Reads the first batch from an IPC file-format buffer.
    fn read_ipc(buf: &[u8]) -> Result<RecordBatch, ArrowError> {
        let mut reader = FileReader::try_new(std::io::Cursor::new(buf), None)?;
        reader.next().unwrap()
    }
1859
    /// Like `read_ipc`, but with data validation disabled.
    /// The input buffer is trusted test data, so the `unsafe` is sound.
    fn read_ipc_skip_validation(buf: &[u8]) -> Result<RecordBatch, ArrowError> {
        let mut reader = unsafe {
            FileReader::try_new(std::io::Cursor::new(buf), None)?.with_skip_validation(true)
        };
        reader.next().unwrap()
    }
1868
    /// Writes then re-reads a batch through the IPC file format.
    fn roundtrip_ipc(rb: &RecordBatch) -> RecordBatch {
        let buf = write_ipc(rb);
        read_ipc(&buf).unwrap()
    }
1873
    /// Decodes an IPC file buffer via `FileDecoder` with validation on.
    fn read_ipc_with_decoder(buf: Vec<u8>) -> Result<RecordBatch, ArrowError> {
        read_ipc_with_decoder_inner(buf, false)
    }
1879
    /// Decodes an IPC file buffer via `FileDecoder` with validation off.
    fn read_ipc_with_decoder_skip_validation(buf: Vec<u8>) -> Result<RecordBatch, ArrowError> {
        read_ipc_with_decoder_inner(buf, true)
    }
1885
    /// Decodes an IPC file-format buffer "by hand" using [`FileDecoder`]:
    /// parse the footer, replay dictionary batches, then decode the single
    /// expected record batch.
    fn read_ipc_with_decoder_inner(
        buf: Vec<u8>,
        skip_validation: bool,
    ) -> Result<RecordBatch, ArrowError> {
        let buffer = Buffer::from_vec(buf);
        // Trailer is the last 10 bytes: 4-byte footer length + magic.
        let trailer_start = buffer.len() - 10;
        let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap())?;
        let footer = root_as_footer(&buffer[trailer_start - footer_len..trailer_start])
            .map_err(|e| ArrowError::InvalidArgumentError(format!("Invalid footer: {e}")))?;

        let schema = fb_to_schema(footer.schema().unwrap());

        // `with_skip_validation` is unsafe: it trusts the input buffer.
        // This helper only feeds it test data we just wrote ourselves.
        let mut decoder = unsafe {
            FileDecoder::new(Arc::new(schema), footer.version())
                .with_skip_validation(skip_validation)
        };
        // Replay any dictionary batches before decoding record batches.
        for block in footer.dictionaries().iter().flatten() {
            let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
            let data = buffer.slice_with_length(block.offset() as _, block_len);
            decoder.read_dictionary(block, &data)?
        }

        let batches = footer.recordBatches().unwrap();
        // This helper only supports files containing exactly one batch.
        assert_eq!(batches.len(), 1);
        let block = batches.get(0);
        let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
        let data = buffer.slice_with_length(block.offset() as _, block_len);
        Ok(decoder.read_record_batch(block, &data)?.unwrap())
    }
1918
    /// Serializes a batch to an in-memory buffer in IPC *stream* format.
    fn write_stream(rb: &RecordBatch) -> Vec<u8> {
        let mut buf = Vec::new();
        let mut writer = crate::writer::StreamWriter::try_new(&mut buf, rb.schema_ref()).unwrap();
        writer.write(rb).unwrap();
        writer.finish().unwrap();
        buf
    }
1927
    /// Reads the first batch from an IPC stream-format buffer.
    fn read_stream(buf: &[u8]) -> Result<RecordBatch, ArrowError> {
        let mut reader = StreamReader::try_new(std::io::Cursor::new(buf), None)?;
        reader.next().unwrap()
    }
1933
    /// Like `read_stream`, but with data validation disabled.
    /// The input buffer is trusted test data, so the `unsafe` is sound.
    fn read_stream_skip_validation(buf: &[u8]) -> Result<RecordBatch, ArrowError> {
        let mut reader = unsafe {
            StreamReader::try_new(std::io::Cursor::new(buf), None)?.with_skip_validation(true)
        };
        reader.next().unwrap()
    }
1942
    /// Writes then re-reads a batch through the IPC stream format.
    fn roundtrip_ipc_stream(rb: &RecordBatch) -> RecordBatch {
        let buf = write_stream(rb);
        read_stream(&buf).unwrap()
    }
1947
    #[test]
    fn test_roundtrip_with_custom_metadata() {
        // custom key/value metadata written via the writer must be
        // readable from the file reader after finish()
        let schema = Schema::new(vec![Field::new("dummy", DataType::Float64, false)]);
        let mut buf = Vec::new();
        let mut writer = crate::writer::FileWriter::try_new(&mut buf, &schema).unwrap();
        let mut test_metadata = HashMap::new();
        test_metadata.insert("abc".to_string(), "abc".to_string());
        test_metadata.insert("def".to_string(), "def".to_string());
        for (k, v) in &test_metadata {
            writer.write_metadata(k, v);
        }
        writer.finish().unwrap();
        drop(writer);

        let reader = crate::reader::FileReader::try_new(std::io::Cursor::new(buf), None).unwrap();
        assert_eq!(reader.custom_metadata(), &test_metadata);
    }
1965
    #[test]
    fn test_roundtrip_nested_dict() {
        // a dictionary array nested inside a struct column
        let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();

        let array = Arc::new(inner) as ArrayRef;

        let dctfield = Arc::new(Field::new("dict", array.data_type().clone(), false));

        let s = StructArray::from(vec![(dctfield, array)]);
        let struct_array = Arc::new(s) as ArrayRef;

        let schema = Arc::new(Schema::new(vec![Field::new(
            "struct",
            struct_array.data_type().clone(),
            false,
        )]));

        let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap();

        assert_eq!(batch, roundtrip_ipc(&batch));
    }
1987
    #[test]
    fn test_roundtrip_nested_dict_no_preserve_dict_id() {
        // same shape as test_roundtrip_nested_dict, but written with
        // `preserve_dict_id` disabled in the writer options
        let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();

        let array = Arc::new(inner) as ArrayRef;

        let dctfield = Arc::new(Field::new("dict", array.data_type().clone(), false));

        let s = StructArray::from(vec![(dctfield, array)]);
        let struct_array = Arc::new(s) as ArrayRef;

        let schema = Arc::new(Schema::new(vec![Field::new(
            "struct",
            struct_array.data_type().clone(),
            false,
        )]));

        let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap();

        let mut buf = Vec::new();
        let mut writer = crate::writer::FileWriter::try_new_with_options(
            &mut buf,
            batch.schema_ref(),
            #[allow(deprecated)]
            IpcWriteOptions::default().with_preserve_dict_id(false),
        )
        .unwrap();
        writer.write(&batch).unwrap();
        writer.finish().unwrap();
        drop(writer);

        let mut reader = FileReader::try_new(std::io::Cursor::new(buf), None).unwrap();

        assert_eq!(batch, reader.next().unwrap().unwrap());
    }
2023
    /// Round-trips a union batch built from the given builder (dense or
    /// sparse) and verifies schema, shape, and column equality.
    fn check_union_with_builder(mut builder: UnionBuilder) {
        builder.append::<Int32Type>("a", 1).unwrap();
        builder.append_null::<Int32Type>("a").unwrap();
        builder.append::<Float64Type>("c", 3.0).unwrap();
        builder.append::<Int32Type>("a", 4).unwrap();
        builder.append::<Int64Type>("d", 11).unwrap();
        let union = builder.build().unwrap();

        let schema = Arc::new(Schema::new(vec![Field::new(
            "union",
            union.data_type().clone(),
            false,
        )]));

        let union_array = Arc::new(union) as ArrayRef;

        let rb = RecordBatch::try_new(schema, vec![union_array]).unwrap();
        let rb2 = roundtrip_ipc(&rb);
        // compare pieces individually for clearer failure messages
        assert_eq!(rb.schema(), rb2.schema());
        assert_eq!(rb.num_columns(), rb2.num_columns());
        assert_eq!(rb.num_rows(), rb2.num_rows());
        let union1 = rb.column(0);
        let union2 = rb2.column(0);

        assert_eq!(union1, union2);
    }
2052
    #[test]
    fn test_roundtrip_dense_union() {
        // dense unions carry an offsets buffer in addition to type ids
        check_union_with_builder(UnionBuilder::new_dense());
    }
2057
    #[test]
    fn test_roundtrip_sparse_union() {
        // sparse unions store full-length child arrays, no offsets buffer
        check_union_with_builder(UnionBuilder::new_sparse());
    }
2062
    #[test]
    fn test_roundtrip_struct_empty_fields() {
        // a struct with zero child fields but a non-empty null buffer
        let nulls = NullBuffer::from(&[true, true, false]);
        let rb = RecordBatch::try_from_iter([(
            "",
            Arc::new(StructArray::new_empty_fields(nulls.len(), Some(nulls))) as _,
        )])
        .unwrap();
        let rb2 = roundtrip_ipc(&rb);
        assert_eq!(rb, rb2);
    }
2074
2075 #[test]
2076 fn test_roundtrip_stream_run_array_sliced() {
2077 let run_array_1: Int32RunArray = vec!["a", "a", "a", "b", "b", "c", "c", "c"]
2078 .into_iter()
2079 .collect();
2080 let run_array_1_sliced = run_array_1.slice(2, 5);
2081
2082 let run_array_2_inupt = vec![Some(1_i32), None, None, Some(2), Some(2)];
2083 let mut run_array_2_builder = PrimitiveRunBuilder::<Int16Type, Int32Type>::new();
2084 run_array_2_builder.extend(run_array_2_inupt);
2085 let run_array_2 = run_array_2_builder.finish();
2086
2087 let schema = Arc::new(Schema::new(vec![
2088 Field::new(
2089 "run_array_1_sliced",
2090 run_array_1_sliced.data_type().clone(),
2091 false,
2092 ),
2093 Field::new("run_array_2", run_array_2.data_type().clone(), false),
2094 ]));
2095 let input_batch = RecordBatch::try_new(
2096 schema,
2097 vec![Arc::new(run_array_1_sliced.clone()), Arc::new(run_array_2)],
2098 )
2099 .unwrap();
2100 let output_batch = roundtrip_ipc_stream(&input_batch);
2101
2102 assert_eq!(input_batch.column(1), output_batch.column(1));
2106
2107 let run_array_1_unsliced = unslice_run_array(run_array_1_sliced.into_data()).unwrap();
2108 assert_eq!(run_array_1_unsliced, output_batch.column(0).into_data());
2109 }
2110
    #[test]
    fn test_roundtrip_stream_nested_dict() {
        // a struct column containing both a plain string child and a
        // dictionary-encoded child
        let xs = vec!["AA", "BB", "AA", "CC", "BB"];
        let dict = Arc::new(
            xs.clone()
                .into_iter()
                .collect::<DictionaryArray<Int8Type>>(),
        );
        let string_array: ArrayRef = Arc::new(StringArray::from(xs.clone()));
        let struct_array = StructArray::from(vec![
            (
                Arc::new(Field::new("f2.1", DataType::Utf8, false)),
                string_array,
            ),
            (
                Arc::new(Field::new("f2.2_struct", dict.data_type().clone(), false)),
                dict.clone() as ArrayRef,
            ),
        ]);
        let schema = Arc::new(Schema::new(vec![
            Field::new("f1_string", DataType::Utf8, false),
            Field::new("f2_struct", struct_array.data_type().clone(), false),
        ]));
        let input_batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(xs.clone())),
                Arc::new(struct_array),
            ],
        )
        .unwrap();
        let output_batch = roundtrip_ipc_stream(&input_batch);
        assert_eq!(input_batch, output_batch);
    }
2145
    #[test]
    fn test_roundtrip_stream_nested_dict_of_map_of_dict() {
        // dictionary(map(dictionary -> dictionary)): both map keys and
        // values are themselves dictionary-encoded
        let values = StringArray::from(vec![Some("a"), None, Some("b"), Some("c")]);
        let values = Arc::new(values) as ArrayRef;
        let value_dict_keys = Int8Array::from_iter_values([0, 1, 1, 2, 3, 1]);
        let value_dict_array = DictionaryArray::new(value_dict_keys, values.clone());

        let key_dict_keys = Int8Array::from_iter_values([0, 0, 2, 1, 1, 3]);
        let key_dict_array = DictionaryArray::new(key_dict_keys, values);

        #[allow(deprecated)]
        let keys_field = Arc::new(Field::new_dict(
            "keys",
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
            true,
            1,
            false,
        ));
        #[allow(deprecated)]
        let values_field = Arc::new(Field::new_dict(
            "values",
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
            true,
            2,
            false,
        ));
        let entry_struct = StructArray::from(vec![
            (keys_field, make_array(key_dict_array.into_data())),
            (values_field, make_array(value_dict_array.into_data())),
        ]);
        let map_data_type = DataType::Map(
            Arc::new(Field::new(
                "entries",
                entry_struct.data_type().clone(),
                false,
            )),
            false,
        );

        // three map entries of two pairs each
        let entry_offsets = Buffer::from_slice_ref([0, 2, 4, 6]);
        let map_data = ArrayData::builder(map_data_type)
            .len(3)
            .add_buffer(entry_offsets)
            .add_child_data(entry_struct.into_data())
            .build()
            .unwrap();
        let map_array = MapArray::from(map_data);

        let dict_keys = Int8Array::from_iter_values([0, 1, 1, 2, 2, 1]);
        let dict_dict_array = DictionaryArray::new(dict_keys, Arc::new(map_array));

        let schema = Arc::new(Schema::new(vec![Field::new(
            "f1",
            dict_dict_array.data_type().clone(),
            false,
        )]));
        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap();
        let output_batch = roundtrip_ipc_stream(&input_batch);
        assert_eq!(input_batch, output_batch);
    }
2206
    /// Round-trips dictionary(list(dictionary)) through the stream format,
    /// generic over the list offset width (`i32` / `i64`).
    fn test_roundtrip_stream_dict_of_list_of_dict_impl<
        OffsetSize: OffsetSizeTrait,
        U: ArrowNativeType,
    >(
        list_data_type: DataType,
        offsets: &[U; 5],
    ) {
        // inner dictionary of strings (with nulls)
        let values = StringArray::from(vec![Some("a"), None, Some("c"), None]);
        let keys = Int8Array::from_iter_values([0, 0, 1, 2, 0, 1, 3]);
        let dict_array = DictionaryArray::new(keys, Arc::new(values));
        let dict_data = dict_array.to_data();

        let value_offsets = Buffer::from_slice_ref(offsets);

        // list wrapping the dictionary child
        let list_data = ArrayData::builder(list_data_type)
            .len(4)
            .add_buffer(value_offsets)
            .add_child_data(dict_data)
            .build()
            .unwrap();
        let list_array = GenericListArray::<OffsetSize>::from(list_data);

        // outer dictionary whose values are the lists
        let keys_for_dict = Int8Array::from_iter_values([0, 3, 0, 1, 1, 2, 0, 1, 3]);
        let dict_dict_array = DictionaryArray::new(keys_for_dict, Arc::new(list_array));

        let schema = Arc::new(Schema::new(vec![Field::new(
            "f1",
            dict_dict_array.data_type().clone(),
            false,
        )]));
        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap();
        let output_batch = roundtrip_ipc_stream(&input_batch);
        assert_eq!(input_batch, output_batch);
    }
2241
    #[test]
    fn test_roundtrip_stream_dict_of_list_of_dict() {
        // variant with 32-bit list offsets
        #[allow(deprecated)]
        let list_data_type = DataType::List(Arc::new(Field::new_dict(
            "item",
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
            true,
            1,
            false,
        )));
        let offsets: &[i32; 5] = &[0, 2, 4, 4, 6];
        test_roundtrip_stream_dict_of_list_of_dict_impl::<i32, i32>(list_data_type, offsets);

        // variant with 64-bit (large list) offsets
        #[allow(deprecated)]
        let list_data_type = DataType::LargeList(Arc::new(Field::new_dict(
            "item",
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
            true,
            1,
            false,
        )));
        let offsets: &[i64; 5] = &[0, 2, 4, 4, 7];
        test_roundtrip_stream_dict_of_list_of_dict_impl::<i64, i64>(list_data_type, offsets);
    }
2268
    #[test]
    fn test_roundtrip_stream_dict_of_fixed_size_list_of_dict() {
        // dictionary(fixed_size_list(dictionary)) round trip
        let values = StringArray::from(vec![Some("a"), None, Some("c"), None]);
        let keys = Int8Array::from_iter_values([0, 0, 1, 2, 0, 1, 3, 1, 2]);
        let dict_array = DictionaryArray::new(keys, Arc::new(values));
        let dict_data = dict_array.into_data();

        #[allow(deprecated)]
        let list_data_type = DataType::FixedSizeList(
            Arc::new(Field::new_dict(
                "item",
                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
                true,
                1,
                false,
            )),
            3,
        );
        // 9 child values -> 3 lists of fixed length 3
        let list_data = ArrayData::builder(list_data_type)
            .len(3)
            .add_child_data(dict_data)
            .build()
            .unwrap();
        let list_array = FixedSizeListArray::from(list_data);

        let keys_for_dict = Int8Array::from_iter_values([0, 1, 0, 1, 1, 2, 0, 1, 2]);
        let dict_dict_array = DictionaryArray::new(keys_for_dict, Arc::new(list_array));

        let schema = Arc::new(Schema::new(vec![Field::new(
            "f1",
            dict_dict_array.data_type().clone(),
            false,
        )]));
        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap();
        let output_batch = roundtrip_ipc_stream(&input_batch);
        assert_eq!(input_batch, output_batch);
    }
2306
    /// A deliberately long value used by the view-type round-trip tests
    /// (long enough that view arrays cannot store it inline).
    const LONG_TEST_STRING: &str =
        "This is a long string to make sure binary view array handles it";
2309
    #[test]
    fn test_roundtrip_view_types() {
        // BinaryView and Utf8View columns, with nulls and a value long
        // enough to require an out-of-line data buffer
        let schema = Schema::new(vec![
            Field::new("field_1", DataType::BinaryView, true),
            Field::new("field_2", DataType::Utf8, true),
            Field::new("field_3", DataType::Utf8View, true),
        ]);
        let bin_values: Vec<Option<&[u8]>> = vec![
            Some(b"foo"),
            None,
            Some(b"bar"),
            Some(LONG_TEST_STRING.as_bytes()),
        ];
        let utf8_values: Vec<Option<&str>> =
            vec![Some("foo"), None, Some("bar"), Some(LONG_TEST_STRING)];
        let bin_view_array = BinaryViewArray::from_iter(bin_values);
        let utf8_array = StringArray::from_iter(utf8_values.iter());
        let utf8_view_array = StringViewArray::from_iter(utf8_values);
        let record_batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(bin_view_array),
                Arc::new(utf8_array),
                Arc::new(utf8_view_array),
            ],
        )
        .unwrap();

        assert_eq!(record_batch, roundtrip_ipc(&record_batch));
        assert_eq!(record_batch, roundtrip_ipc_stream(&record_batch));

        // sliced batches must also survive both formats
        let sliced_batch = record_batch.slice(1, 2);
        assert_eq!(sliced_batch, roundtrip_ipc(&sliced_batch));
        assert_eq!(sliced_batch, roundtrip_ipc_stream(&sliced_batch));
    }
2345
    #[test]
    fn test_roundtrip_view_types_nested_dict() {
        // view types as dictionary values, nested inside a map, nested
        // inside an outer dictionary
        let bin_values: Vec<Option<&[u8]>> = vec![
            Some(b"foo"),
            None,
            Some(b"bar"),
            Some(LONG_TEST_STRING.as_bytes()),
            Some(b"field"),
        ];
        let utf8_values: Vec<Option<&str>> = vec![
            Some("foo"),
            None,
            Some("bar"),
            Some(LONG_TEST_STRING),
            Some("field"),
        ];
        let bin_view_array = Arc::new(BinaryViewArray::from_iter(bin_values));
        let utf8_view_array = Arc::new(StringViewArray::from_iter(utf8_values));

        // map keys: dictionary of Utf8View
        let key_dict_keys = Int8Array::from_iter_values([0, 0, 1, 2, 0, 1, 3]);
        let key_dict_array = DictionaryArray::new(key_dict_keys, utf8_view_array.clone());
        #[allow(deprecated)]
        let keys_field = Arc::new(Field::new_dict(
            "keys",
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8View)),
            true,
            1,
            false,
        ));

        // map values: dictionary of BinaryView
        let value_dict_keys = Int8Array::from_iter_values([0, 3, 0, 1, 2, 0, 1]);
        let value_dict_array = DictionaryArray::new(value_dict_keys, bin_view_array);
        #[allow(deprecated)]
        let values_field = Arc::new(Field::new_dict(
            "values",
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::BinaryView)),
            true,
            2,
            false,
        ));
        let entry_struct = StructArray::from(vec![
            (keys_field, make_array(key_dict_array.into_data())),
            (values_field, make_array(value_dict_array.into_data())),
        ]);

        let map_data_type = DataType::Map(
            Arc::new(Field::new(
                "entries",
                entry_struct.data_type().clone(),
                false,
            )),
            false,
        );
        let entry_offsets = Buffer::from_slice_ref([0, 2, 4, 7]);
        let map_data = ArrayData::builder(map_data_type)
            .len(3)
            .add_buffer(entry_offsets)
            .add_child_data(entry_struct.into_data())
            .build()
            .unwrap();
        let map_array = MapArray::from(map_data);

        let dict_keys = Int8Array::from_iter_values([0, 1, 0, 1, 1, 2, 0, 1, 2]);
        let dict_dict_array = DictionaryArray::new(dict_keys, Arc::new(map_array));
        let schema = Arc::new(Schema::new(vec![Field::new(
            "f1",
            dict_dict_array.data_type().clone(),
            false,
        )]));
        let batch = RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap();
        assert_eq!(batch, roundtrip_ipc(&batch));
        assert_eq!(batch, roundtrip_ipc_stream(&batch));

        // sliced batches must also survive both formats
        let sliced_batch = batch.slice(1, 2);
        assert_eq!(sliced_batch, roundtrip_ipc(&sliced_batch));
        assert_eq!(sliced_batch, roundtrip_ipc_stream(&sliced_batch));
    }
2423
    #[test]
    fn test_no_columns_batch() {
        // a batch with zero columns but a non-zero row count must
        // round-trip through the stream format
        let schema = Arc::new(Schema::empty());
        let options = RecordBatchOptions::new()
            .with_match_field_names(true)
            .with_row_count(Some(10));
        let input_batch = RecordBatch::try_new_with_options(schema, vec![], &options).unwrap();
        let output_batch = roundtrip_ipc_stream(&input_batch);
        assert_eq!(input_batch, output_batch);
    }
2434
2435 #[test]
2436 fn test_unaligned() {
2437 let batch = RecordBatch::try_from_iter(vec![(
2438 "i32",
2439 Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as _,
2440 )])
2441 .unwrap();
2442
2443 let gen = IpcDataGenerator {};
2444 #[allow(deprecated)]
2445 let mut dict_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true);
2446 let (_, encoded) = gen
2447 .encoded_batch(&batch, &mut dict_tracker, &Default::default())
2448 .unwrap();
2449
2450 let message = root_as_message(&encoded.ipc_message).unwrap();
2451
2452 let mut buffer = MutableBuffer::with_capacity(encoded.arrow_data.len() + 1);
2454 buffer.push(0_u8);
2455 buffer.extend_from_slice(&encoded.arrow_data);
2456 let b = Buffer::from(buffer).slice(1);
2457 assert_ne!(b.as_ptr().align_offset(8), 0);
2458
2459 let ipc_batch = message.header_as_record_batch().unwrap();
2460 let roundtrip = RecordBatchDecoder::try_new(
2461 &b,
2462 ipc_batch,
2463 batch.schema(),
2464 &Default::default(),
2465 &message.version(),
2466 )
2467 .unwrap()
2468 .with_require_alignment(false)
2469 .read_record_batch()
2470 .unwrap();
2471 assert_eq!(batch, roundtrip);
2472 }
2473
2474 #[test]
2475 fn test_unaligned_throws_error_with_require_alignment() {
2476 let batch = RecordBatch::try_from_iter(vec![(
2477 "i32",
2478 Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as _,
2479 )])
2480 .unwrap();
2481
2482 let gen = IpcDataGenerator {};
2483 #[allow(deprecated)]
2484 let mut dict_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true);
2485 let (_, encoded) = gen
2486 .encoded_batch(&batch, &mut dict_tracker, &Default::default())
2487 .unwrap();
2488
2489 let message = root_as_message(&encoded.ipc_message).unwrap();
2490
2491 let mut buffer = MutableBuffer::with_capacity(encoded.arrow_data.len() + 1);
2493 buffer.push(0_u8);
2494 buffer.extend_from_slice(&encoded.arrow_data);
2495 let b = Buffer::from(buffer).slice(1);
2496 assert_ne!(b.as_ptr().align_offset(8), 0);
2497
2498 let ipc_batch = message.header_as_record_batch().unwrap();
2499 let result = RecordBatchDecoder::try_new(
2500 &b,
2501 ipc_batch,
2502 batch.schema(),
2503 &Default::default(),
2504 &message.version(),
2505 )
2506 .unwrap()
2507 .with_require_alignment(true)
2508 .read_record_batch();
2509
2510 let error = result.unwrap_err();
2511 assert_eq!(
2512 error.to_string(),
2513 "Invalid argument error: Misaligned buffers[0] in array of type Int32, \
2514 offset from expected alignment of 4 by 1"
2515 );
2516 }
2517
    #[test]
    fn test_file_with_massive_column_count() {
        // enough columns that the footer flatbuffer exceeds the default
        // table budget, requiring the reader to be configured below
        let limit = 600_000;

        let fields = (0..limit)
            .map(|i| Field::new(format!("{i}"), DataType::Boolean, false))
            .collect::<Vec<_>>();
        let schema = Arc::new(Schema::new(fields));
        let batch = RecordBatch::new_empty(schema);

        let mut buf = Vec::new();
        let mut writer = crate::writer::FileWriter::try_new(&mut buf, batch.schema_ref()).unwrap();
        writer.write(&batch).unwrap();
        writer.finish().unwrap();
        drop(writer);

        // raise the footer verification table limit so the file can be read
        let mut reader = FileReaderBuilder::new()
            .with_max_footer_fb_tables(1_500_000)
            .build(std::io::Cursor::new(buf))
            .unwrap();
        let roundtrip_batch = reader.next().unwrap().unwrap();

        assert_eq!(batch, roundtrip_batch);
    }
2543
2544 #[test]
2545 fn test_file_with_deeply_nested_columns() {
2546 let limit = 61;
2548
2549 let fields = (0..limit).fold(
2550 vec![Field::new("leaf", DataType::Boolean, false)],
2551 |field, index| vec![Field::new_struct(format!("{index}"), field, false)],
2552 );
2553 let schema = Arc::new(Schema::new(fields));
2554 let batch = RecordBatch::new_empty(schema);
2555
2556 let mut buf = Vec::new();
2557 let mut writer = crate::writer::FileWriter::try_new(&mut buf, batch.schema_ref()).unwrap();
2558 writer.write(&batch).unwrap();
2559 writer.finish().unwrap();
2560 drop(writer);
2561
2562 let mut reader = FileReaderBuilder::new()
2563 .with_max_footer_fb_depth(65)
2564 .build(std::io::Cursor::new(buf))
2565 .unwrap();
2566 let roundtrip_batch = reader.next().unwrap().unwrap();
2567
2568 assert_eq!(batch, roundtrip_batch);
2569 }
2570
    #[test]
    fn test_invalid_struct_array_ipc_read_errors() {
        // Children of a StructArray must all share the struct's length; here
        // child "a" has 4 values but child "b" only 3.
        let a_field = Field::new("a", DataType::Int32, false);
        let b_field = Field::new("b", DataType::Int32, false);
        let struct_fields = Fields::from(vec![a_field.clone(), b_field.clone()]);

        let a_array_data = ArrayData::builder(a_field.data_type().clone())
            .len(4)
            .add_buffer(Buffer::from_slice_ref([1, 2, 3, 4]))
            .build()
            .unwrap();
        let b_array_data = ArrayData::builder(b_field.data_type().clone())
            .len(3)
            .add_buffer(Buffer::from_slice_ref([5, 6, 7]))
            .build()
            .unwrap();

        // Deliberately bypass construction-time validation so that it is the
        // IPC read path that must detect the length mismatch.
        let invalid_struct_arr = unsafe {
            StructArray::new_unchecked(
                struct_fields,
                vec![make_array(a_array_data), make_array(b_array_data)],
                None,
            )
        };

        expect_ipc_validation_error(
            Arc::new(invalid_struct_arr),
            "Invalid argument error: Incorrect array length for StructArray field \"b\", expected 4 got 3",
        );
    }
2601
    #[test]
    fn test_invalid_nested_array_ipc_read_errors() {
        // Like the struct-length test above, but the invalid (non-UTF-8)
        // column is nested inside a struct, checking that IPC validation
        // recurses into child arrays.
        let a_field = Field::new("a", DataType::Int32, false);
        let b_field = Field::new("b", DataType::Utf8, false);

        let schema = Arc::new(Schema::new(vec![Field::new_struct(
            "s",
            vec![a_field.clone(), b_field.clone()],
            false,
        )]));

        let a_array_data = ArrayData::builder(a_field.data_type().clone())
            .len(4)
            .add_buffer(Buffer::from_slice_ref([1, 2, 3, 4]))
            .build()
            .unwrap();
        // Build a "string" column whose final value ends in invalid UTF-8.
        let b_array_data = {
            let valid: &[u8] = b" ";
            let mut invalid = vec![];
            invalid.extend_from_slice(b"ValidString");
            invalid.extend_from_slice(INVALID_UTF8_FIRST_CHAR);
            let binary_array =
                BinaryArray::from_iter(vec![None, Some(valid), None, Some(&invalid)]);
            // Reinterpret the binary buffers as UTF-8 without validation.
            let array = unsafe {
                StringArray::new_unchecked(
                    binary_array.offsets().clone(),
                    binary_array.values().clone(),
                    binary_array.nulls().cloned(),
                )
            };
            array.into_data()
        };
        let struct_data_type = schema.field(0).data_type();

        // Assemble the struct without validation; the IPC read path must
        // surface the nested UTF-8 error.
        let invalid_struct_arr = unsafe {
            make_array(
                ArrayData::builder(struct_data_type.clone())
                    .len(4)
                    .add_child_data(a_array_data)
                    .add_child_data(b_array_data)
                    .build_unchecked(),
            )
        };
        expect_ipc_validation_error(
            Arc::new(invalid_struct_arr),
            "Invalid argument error: Invalid UTF8 sequence at string index 3 (3..18): invalid utf-8 sequence of 1 bytes from index 11",
        );
    }
2652
    #[test]
    fn test_same_dict_id_without_preserve() {
        // Both dictionary fields are created with dict id 0; with
        // `preserve_dict_id(false)` the writer is responsible for assigning
        // ids, and the stream must still round-trip correctly.
        let batch = RecordBatch::try_new(
            Arc::new(Schema::new(
                ["a", "b"]
                    .iter()
                    .map(|name| {
                        #[allow(deprecated)]
                        Field::new_dict(
                            name.to_string(),
                            DataType::Dictionary(
                                Box::new(DataType::Int32),
                                Box::new(DataType::Utf8),
                            ),
                            true,
                            0,
                            false,
                        )
                    })
                    .collect::<Vec<Field>>(),
            )),
            vec![
                Arc::new(
                    vec![Some("c"), Some("d")]
                        .into_iter()
                        .collect::<DictionaryArray<Int32Type>>(),
                ) as ArrayRef,
                Arc::new(
                    vec![Some("e"), Some("f")]
                        .into_iter()
                        .collect::<DictionaryArray<Int32Type>>(),
                ) as ArrayRef,
            ],
        )
        .expect("Failed to create RecordBatch");

        // Write the batch to an in-memory IPC stream with id preservation off.
        let mut buf = vec![];
        {
            let mut writer = crate::writer::StreamWriter::try_new_with_options(
                &mut buf,
                batch.schema().as_ref(),
                #[allow(deprecated)]
                crate::writer::IpcWriteOptions::default().with_preserve_dict_id(false),
            )
            .expect("Failed to create StreamWriter");
            writer.write(&batch).expect("Failed to write RecordBatch");
            writer.finish().expect("Failed to finish StreamWriter");
        }

        // Every decoded batch must equal the original.
        StreamReader::try_new(std::io::Cursor::new(buf), None)
            .expect("Failed to create StreamReader")
            .for_each(|decoded_batch| {
                assert_eq!(decoded_batch.expect("Failed to read RecordBatch"), batch);
            });
    }
2709
2710 #[test]
2711 fn test_validation_of_invalid_list_array() {
2712 let array = unsafe {
2714 let values = Int32Array::from(vec![1, 2, 3]);
2715 let bad_offsets = ScalarBuffer::<i32>::from(vec![0, 2, 4, 2]); let offsets = OffsetBuffer::new_unchecked(bad_offsets); let field = Field::new_list_field(DataType::Int32, true);
2718 let nulls = None;
2719 ListArray::new(Arc::new(field), offsets, Arc::new(values), nulls)
2720 };
2721
2722 expect_ipc_validation_error(
2723 Arc::new(array),
2724 "Invalid argument error: Offset invariant failure: offset at position 2 out of bounds: 4 > 2"
2725 );
2726 }
2727
2728 #[test]
2729 fn test_validation_of_invalid_string_array() {
2730 let valid: &[u8] = b" ";
2731 let mut invalid = vec![];
2732 invalid.extend_from_slice(b"ThisStringIsCertainlyLongerThan12Bytes");
2733 invalid.extend_from_slice(INVALID_UTF8_FIRST_CHAR);
2734 let binary_array = BinaryArray::from_iter(vec![None, Some(valid), None, Some(&invalid)]);
2735 let array = unsafe {
2738 StringArray::new_unchecked(
2739 binary_array.offsets().clone(),
2740 binary_array.values().clone(),
2741 binary_array.nulls().cloned(),
2742 )
2743 };
2744 expect_ipc_validation_error(
2745 Arc::new(array),
2746 "Invalid argument error: Invalid UTF8 sequence at string index 3 (3..45): invalid utf-8 sequence of 1 bytes from index 38"
2747 );
2748 }
2749
2750 #[test]
2751 fn test_validation_of_invalid_string_view_array() {
2752 let valid: &[u8] = b" ";
2753 let mut invalid = vec![];
2754 invalid.extend_from_slice(b"ThisStringIsCertainlyLongerThan12Bytes");
2755 invalid.extend_from_slice(INVALID_UTF8_FIRST_CHAR);
2756 let binary_view_array =
2757 BinaryViewArray::from_iter(vec![None, Some(valid), None, Some(&invalid)]);
2758 let array = unsafe {
2761 StringViewArray::new_unchecked(
2762 binary_view_array.views().clone(),
2763 binary_view_array.data_buffers().to_vec(),
2764 binary_view_array.nulls().cloned(),
2765 )
2766 };
2767 expect_ipc_validation_error(
2768 Arc::new(array),
2769 "Invalid argument error: Encountered non-UTF-8 data at index 3: invalid utf-8 sequence of 1 bytes from index 38"
2770 );
2771 }
2772
2773 #[test]
2776 fn test_validation_of_invalid_dictionary_array() {
2777 let array = unsafe {
2778 let values = StringArray::from_iter_values(["a", "b", "c"]);
2779 let keys = Int32Array::from(vec![1, 200]); DictionaryArray::new_unchecked(keys, Arc::new(values))
2781 };
2782
2783 expect_ipc_validation_error(
2784 Arc::new(array),
2785 "Invalid argument error: Value at position 1 out of bounds: 200 (should be in [0, 2])",
2786 );
2787 }
2788
2789 #[test]
2790 fn test_validation_of_invalid_union_array() {
2791 let array = unsafe {
2792 let fields = UnionFields::new(
2793 vec![1, 3], vec![
2795 Field::new("a", DataType::Int32, false),
2796 Field::new("b", DataType::Utf8, false),
2797 ],
2798 );
2799 let type_ids = ScalarBuffer::from(vec![1i8, 2, 3]); let offsets = None;
2801 let children: Vec<ArrayRef> = vec![
2802 Arc::new(Int32Array::from(vec![10, 20, 30])),
2803 Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("c")])),
2804 ];
2805
2806 UnionArray::new_unchecked(fields, type_ids, offsets, children)
2807 };
2808
2809 expect_ipc_validation_error(
2810 Arc::new(array),
2811 "Invalid argument error: Type Ids values must match one of the field type ids",
2812 );
2813 }
2814
    // Byte sequence whose first byte (0xa0) is a UTF-8 continuation byte and
    // therefore cannot start a valid character; appended to otherwise-valid
    // string data to fabricate invalid UTF-8 payloads in these tests.
    const INVALID_UTF8_FIRST_CHAR: &[u8] = &[0xa0, 0xa1, 0x20, 0x20];
2818
2819 fn expect_ipc_validation_error(array: ArrayRef, expected_err: &str) {
2821 let rb = RecordBatch::try_from_iter([("a", array)]).unwrap();
2822
2823 let buf = write_stream(&rb); read_stream_skip_validation(&buf).unwrap();
2826 let err = read_stream(&buf).unwrap_err();
2827 assert_eq!(err.to_string(), expected_err);
2828
2829 let buf = write_ipc(&rb); read_ipc_skip_validation(&buf).unwrap();
2832 let err = read_ipc(&buf).unwrap_err();
2833 assert_eq!(err.to_string(), expected_err);
2834
2835 read_ipc_with_decoder_skip_validation(buf.clone()).unwrap();
2837 let err = read_ipc_with_decoder(buf).unwrap_err();
2838 assert_eq!(err.to_string(), expected_err);
2839 }
2840
    #[test]
    fn test_roundtrip_schema() {
        let schema = Schema::new(vec![
            Field::new(
                "a",
                DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
                false,
            ),
            Field::new(
                "b",
                DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
                false,
            ),
        ]);

        // Encode the schema into a framed IPC message.
        let options = IpcWriteOptions::default();
        let data_gen = IpcDataGenerator::default();
        let mut dict_tracker = DictionaryTracker::new(false);
        let encoded_data =
            data_gen.schema_to_bytes_with_dictionary_tracker(&schema, &mut dict_tracker, &options);
        let mut schema_bytes = vec![];
        write_message(&mut schema_bytes, encoded_data, &options).expect("write_message");

        // Skip the IPC continuation marker if the writer emitted one.
        let begin_offset: usize = if schema_bytes[0..4].eq(&CONTINUATION_MARKER) {
            4
        } else {
            0
        };

        // These bytes are expected NOT to parse via the size-prefixed entry
        // point.
        size_prefixed_root_as_message(&schema_bytes[begin_offset..])
            .expect_err("size_prefixed_root_as_message");

        // Decode through the public parse path and compare with the original.
        let msg = parse_message(&schema_bytes).expect("parse_message");
        let ipc_schema = msg.header_as_schema().expect("header_as_schema");
        let new_schema = fb_to_schema(ipc_schema);

        assert_eq!(schema, new_schema);
    }
2879}