1use crate::cast::AsArray;
22use crate::{new_empty_array, Array, ArrayRef, StructArray};
23use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef};
24use std::ops::Index;
25use std::sync::Arc;
26
/// Trait for types that produce a stream of [`RecordBatch`]es sharing one schema.
///
/// Readers are iterators yielding `Result<RecordBatch, ArrowError>`; to build one
/// from an existing iterator, see `RecordBatchIterator`.
pub trait RecordBatchReader: Iterator<Item = Result<RecordBatch, ArrowError>> {
    /// Returns the schema of the batches produced by this reader.
    ///
    /// NOTE: this trait does not itself verify that yielded batches use this
    /// schema (e.g. `RecordBatchIterator` passes its batches through unchecked);
    /// keeping them consistent is up to the implementation.
    fn schema(&self) -> SchemaRef;
}
37
38impl<R: RecordBatchReader + ?Sized> RecordBatchReader for Box<R> {
39 fn schema(&self) -> SchemaRef {
40 self.as_ref().schema()
41 }
42}
43
/// Trait for types that consume [`RecordBatch`]es, one at a time.
pub trait RecordBatchWriter {
    /// Writes a single batch.
    ///
    /// # Errors
    ///
    /// Returns an [`ArrowError`] if the implementation fails to write the batch.
    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError>;

    /// Consumes the writer (it takes `self` by value, so no further writes are
    /// possible afterwards). Presumably finalizes/flushes any pending output —
    /// the exact behavior is implementation-defined.
    fn close(self) -> Result<(), ArrowError>;
}
52
/// Creates an array (wrapped in `std::sync::Arc`) from a type name and a list of values.
///
/// The first argument names the array kind (e.g. `Int32`, `Utf8`, `Date32`,
/// `TimestampMillisecond`, `Nanosecond64`). The second is a bracketed list of
/// values that is collected into a `Vec` and passed to the matching array's
/// `From<Vec<_>>` conversion. `Null` takes a length instead of values, and
/// `Binary` / `LargeBinary` are built through `from_vec`. An unknown type name
/// fails at compile time with `Unsupported data type: ...`.
///
/// # Example
///
/// ```rust
/// use arrow_array::{create_array, Array};
///
/// let array = create_array!(Int32, [1, 2, 3, 4, 5]);
/// assert_eq!(array.len(), 5);
/// ```
#[macro_export]
macro_rules! create_array {
    // Internal mapping from a type name to the concrete array type.
    (@from Boolean) => { $crate::BooleanArray };
    (@from Int8) => { $crate::Int8Array };
    (@from Int16) => { $crate::Int16Array };
    (@from Int32) => { $crate::Int32Array };
    (@from Int64) => { $crate::Int64Array };
    (@from UInt8) => { $crate::UInt8Array };
    (@from UInt16) => { $crate::UInt16Array };
    (@from UInt32) => { $crate::UInt32Array };
    (@from UInt64) => { $crate::UInt64Array };
    (@from Float16) => { $crate::Float16Array };
    (@from Float32) => { $crate::Float32Array };
    (@from Float64) => { $crate::Float64Array };
    (@from Utf8) => { $crate::StringArray };
    (@from Utf8View) => { $crate::StringViewArray };
    (@from LargeUtf8) => { $crate::LargeStringArray };
    (@from Date32) => { $crate::Date32Array };
    (@from Date64) => { $crate::Date64Array };
    (@from IntervalDayTime) => { $crate::IntervalDayTimeArray };
    (@from IntervalYearMonth) => { $crate::IntervalYearMonthArray };
    (@from Second) => { $crate::TimestampSecondArray };
    (@from Millisecond) => { $crate::TimestampMillisecondArray };
    (@from Microsecond) => { $crate::TimestampMicrosecondArray };
    (@from Nanosecond) => { $crate::TimestampNanosecondArray };
    (@from Second32) => { $crate::Time32SecondArray };
    (@from Millisecond32) => { $crate::Time32MillisecondArray };
    (@from Microsecond64) => { $crate::Time64MicrosecondArray };
    // Previously named the non-existent `Time64Nanosecond64Array`, which made
    // every `Nanosecond64` invocation a compile error.
    (@from Nanosecond64) => { $crate::Time64NanosecondArray };
    (@from DurationSecond) => { $crate::DurationSecondArray };
    (@from DurationMillisecond) => { $crate::DurationMillisecondArray };
    (@from DurationMicrosecond) => { $crate::DurationMicrosecondArray };
    (@from DurationNanosecond) => { $crate::DurationNanosecondArray };
    (@from Decimal128) => { $crate::Decimal128Array };
    (@from Decimal256) => { $crate::Decimal256Array };
    (@from TimestampSecond) => { $crate::TimestampSecondArray };
    (@from TimestampMillisecond) => { $crate::TimestampMillisecondArray };
    (@from TimestampMicrosecond) => { $crate::TimestampMicrosecondArray };
    (@from TimestampNanosecond) => { $crate::TimestampNanosecondArray };

    // Any other name is rejected at compile time with a readable message.
    (@from $ty: ident) => {
        compile_error!(concat!("Unsupported data type: ", stringify!($ty)))
    };

    // A null array only needs a length.
    (Null, $size: expr) => {
        std::sync::Arc::new($crate::NullArray::new($size))
    };

    // Binary arrays are built from byte slices via `from_vec`.
    (Binary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::BinaryArray::from_vec(vec![$($values),*]))
    };

    (LargeBinary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::LargeBinaryArray::from_vec(vec![$($values),*]))
    };

    // Everything else goes through the mapped array type's `From<Vec<_>>`.
    ($ty: tt, [$($values: expr),*]) => {
        std::sync::Arc::new(<$crate::create_array!(@from $ty)>::from(vec![$($values),*]))
    };
}
137
/// Creates a `RecordBatch` from `(name, type, [values])` column tuples.
///
/// Every generated field is declared nullable. Each `type` must be both an
/// `arrow_schema::DataType` unit variant and a name accepted by
/// `create_array!` (e.g. `Int32`, `Float64`, `Utf8`). The macro evaluates to
/// the `Result<RecordBatch, ArrowError>` returned by `RecordBatch::try_new`.
/// A trailing comma is accepted after the last column and after the last value.
///
/// # Example
///
/// ```rust
/// use arrow_array::record_batch;
///
/// let batch = record_batch!(
///     ("a", Int32, [1, 2, 3]),
///     ("b", Float64, [Some(4.0), None, Some(5.0)]),
///     ("c", Utf8, ["alpha", "beta", "gamma"]),
/// )
/// .unwrap();
/// assert_eq!(batch.num_rows(), 3);
/// ```
#[macro_export]
macro_rules! record_batch {
    ($(($name: expr, $type: ident, [$($values: expr),* $(,)?])),* $(,)?) => {
        {
            let schema = std::sync::Arc::new(arrow_schema::Schema::new(vec![
                $(
                    arrow_schema::Field::new($name, arrow_schema::DataType::$type, true),
                )*
            ]));

            $crate::RecordBatch::try_new(
                schema,
                vec![$(
                    $crate::create_array!($type, [$($values),*]),
                )*]
            )
        }
    };
}
175
176#[derive(Clone, Debug, PartialEq)]
200pub struct RecordBatch {
201 schema: SchemaRef,
202 columns: Vec<Arc<dyn Array>>,
203
204 row_count: usize,
208}
209
210impl RecordBatch {
211 pub fn try_new(schema: SchemaRef, columns: Vec<ArrayRef>) -> Result<Self, ArrowError> {
240 let options = RecordBatchOptions::new();
241 Self::try_new_impl(schema, columns, &options)
242 }
243
244 pub unsafe fn new_unchecked(
260 schema: SchemaRef,
261 columns: Vec<Arc<dyn Array>>,
262 row_count: usize,
263 ) -> Self {
264 Self {
265 schema,
266 columns,
267 row_count,
268 }
269 }
270
271 pub fn try_new_with_options(
276 schema: SchemaRef,
277 columns: Vec<ArrayRef>,
278 options: &RecordBatchOptions,
279 ) -> Result<Self, ArrowError> {
280 Self::try_new_impl(schema, columns, options)
281 }
282
283 pub fn new_empty(schema: SchemaRef) -> Self {
285 let columns = schema
286 .fields()
287 .iter()
288 .map(|field| new_empty_array(field.data_type()))
289 .collect();
290
291 RecordBatch {
292 schema,
293 columns,
294 row_count: 0,
295 }
296 }
297
298 fn try_new_impl(
301 schema: SchemaRef,
302 columns: Vec<ArrayRef>,
303 options: &RecordBatchOptions,
304 ) -> Result<Self, ArrowError> {
305 if schema.fields().len() != columns.len() {
307 return Err(ArrowError::InvalidArgumentError(format!(
308 "number of columns({}) must match number of fields({}) in schema",
309 columns.len(),
310 schema.fields().len(),
311 )));
312 }
313
314 let row_count = options
315 .row_count
316 .or_else(|| columns.first().map(|col| col.len()))
317 .ok_or_else(|| {
318 ArrowError::InvalidArgumentError(
319 "must either specify a row count or at least one column".to_string(),
320 )
321 })?;
322
323 for (c, f) in columns.iter().zip(&schema.fields) {
324 if !f.is_nullable() && c.null_count() > 0 {
325 return Err(ArrowError::InvalidArgumentError(format!(
326 "Column '{}' is declared as non-nullable but contains null values",
327 f.name()
328 )));
329 }
330 }
331
332 if columns.iter().any(|c| c.len() != row_count) {
334 let err = match options.row_count {
335 Some(_) => "all columns in a record batch must have the specified row count",
336 None => "all columns in a record batch must have the same length",
337 };
338 return Err(ArrowError::InvalidArgumentError(err.to_string()));
339 }
340
341 let type_not_match = if options.match_field_names {
344 |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| col_type != field_type
345 } else {
346 |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| {
347 !col_type.equals_datatype(field_type)
348 }
349 };
350
351 let not_match = columns
353 .iter()
354 .zip(schema.fields().iter())
355 .map(|(col, field)| (col.data_type(), field.data_type()))
356 .enumerate()
357 .find(type_not_match);
358
359 if let Some((i, (col_type, field_type))) = not_match {
360 return Err(ArrowError::InvalidArgumentError(format!(
361 "column types must match schema types, expected {field_type:?} but found {col_type:?} at column index {i}")));
362 }
363
364 Ok(RecordBatch {
365 schema,
366 columns,
367 row_count,
368 })
369 }
370
371 pub fn into_parts(self) -> (SchemaRef, Vec<ArrayRef>, usize) {
373 (self.schema, self.columns, self.row_count)
374 }
375
376 pub fn with_schema(self, schema: SchemaRef) -> Result<Self, ArrowError> {
381 if !schema.contains(self.schema.as_ref()) {
382 return Err(ArrowError::SchemaError(format!(
383 "target schema is not superset of current schema target={schema} current={}",
384 self.schema
385 )));
386 }
387
388 Ok(Self {
389 schema,
390 columns: self.columns,
391 row_count: self.row_count,
392 })
393 }
394
395 pub fn schema(&self) -> SchemaRef {
397 self.schema.clone()
398 }
399
400 pub fn schema_ref(&self) -> &SchemaRef {
402 &self.schema
403 }
404
405 pub fn project(&self, indices: &[usize]) -> Result<RecordBatch, ArrowError> {
407 let projected_schema = self.schema.project(indices)?;
408 let batch_fields = indices
409 .iter()
410 .map(|f| {
411 self.columns.get(*f).cloned().ok_or_else(|| {
412 ArrowError::SchemaError(format!(
413 "project index {} out of bounds, max field {}",
414 f,
415 self.columns.len()
416 ))
417 })
418 })
419 .collect::<Result<Vec<_>, _>>()?;
420
421 RecordBatch::try_new_with_options(
422 SchemaRef::new(projected_schema),
423 batch_fields,
424 &RecordBatchOptions {
425 match_field_names: true,
426 row_count: Some(self.row_count),
427 },
428 )
429 }
430
431 pub fn normalize(&self, separator: &str, max_level: Option<usize>) -> Result<Self, ArrowError> {
491 let max_level = match max_level.unwrap_or(usize::MAX) {
492 0 => usize::MAX,
493 val => val,
494 };
495 let mut stack: Vec<(usize, &ArrayRef, Vec<&str>, &FieldRef)> = self
496 .columns
497 .iter()
498 .zip(self.schema.fields())
499 .rev()
500 .map(|(c, f)| {
501 let name_vec: Vec<&str> = vec![f.name()];
502 (0, c, name_vec, f)
503 })
504 .collect();
505 let mut columns: Vec<ArrayRef> = Vec::new();
506 let mut fields: Vec<FieldRef> = Vec::new();
507
508 while let Some((depth, c, name, field_ref)) = stack.pop() {
509 match field_ref.data_type() {
510 DataType::Struct(ff) if depth < max_level => {
511 for (cff, fff) in c.as_struct().columns().iter().zip(ff.into_iter()).rev() {
513 let mut name = name.clone();
514 name.push(separator);
515 name.push(fff.name());
516 stack.push((depth + 1, cff, name, fff))
517 }
518 }
519 _ => {
520 let updated_field = Field::new(
521 name.concat(),
522 field_ref.data_type().clone(),
523 field_ref.is_nullable(),
524 );
525 columns.push(c.clone());
526 fields.push(Arc::new(updated_field));
527 }
528 }
529 }
530 RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
531 }
532
533 pub fn num_columns(&self) -> usize {
552 self.columns.len()
553 }
554
555 pub fn num_rows(&self) -> usize {
574 self.row_count
575 }
576
577 pub fn column(&self, index: usize) -> &ArrayRef {
583 &self.columns[index]
584 }
585
586 pub fn column_by_name(&self, name: &str) -> Option<&ArrayRef> {
588 self.schema()
589 .column_with_name(name)
590 .map(|(index, _)| &self.columns[index])
591 }
592
593 pub fn columns(&self) -> &[ArrayRef] {
595 &self.columns[..]
596 }
597
598 pub fn remove_column(&mut self, index: usize) -> ArrayRef {
626 let mut builder = SchemaBuilder::from(self.schema.as_ref());
627 builder.remove(index);
628 self.schema = Arc::new(builder.finish());
629 self.columns.remove(index)
630 }
631
632 pub fn slice(&self, offset: usize, length: usize) -> RecordBatch {
639 assert!((offset + length) <= self.num_rows());
640
641 let columns = self
642 .columns()
643 .iter()
644 .map(|column| column.slice(offset, length))
645 .collect();
646
647 Self {
648 schema: self.schema.clone(),
649 columns,
650 row_count: length,
651 }
652 }
653
654 pub fn try_from_iter<I, F>(value: I) -> Result<Self, ArrowError>
691 where
692 I: IntoIterator<Item = (F, ArrayRef)>,
693 F: AsRef<str>,
694 {
695 let iter = value.into_iter().map(|(field_name, array)| {
699 let nullable = array.null_count() > 0;
700 (field_name, array, nullable)
701 });
702
703 Self::try_from_iter_with_nullable(iter)
704 }
705
706 pub fn try_from_iter_with_nullable<I, F>(value: I) -> Result<Self, ArrowError>
728 where
729 I: IntoIterator<Item = (F, ArrayRef, bool)>,
730 F: AsRef<str>,
731 {
732 let iter = value.into_iter();
733 let capacity = iter.size_hint().0;
734 let mut schema = SchemaBuilder::with_capacity(capacity);
735 let mut columns = Vec::with_capacity(capacity);
736
737 for (field_name, array, nullable) in iter {
738 let field_name = field_name.as_ref();
739 schema.push(Field::new(field_name, array.data_type().clone(), nullable));
740 columns.push(array);
741 }
742
743 let schema = Arc::new(schema.finish());
744 RecordBatch::try_new(schema, columns)
745 }
746
747 pub fn get_array_memory_size(&self) -> usize {
754 self.columns()
755 .iter()
756 .map(|array| array.get_array_memory_size())
757 .sum()
758 }
759}
760
/// Options controlling the validation performed by `RecordBatch::try_new_with_options`.
#[derive(Debug)]
#[non_exhaustive]
pub struct RecordBatchOptions {
    /// When `true`, column data types must equal the schema's data types
    /// exactly, including nested field names. When `false`, they are compared
    /// with `DataType::equals_datatype`.
    pub match_field_names: bool,

    /// Explicit row count. When `None`, the row count is taken from the first column.
    pub row_count: Option<usize>,
}

impl RecordBatchOptions {
    /// Returns the default options: strict type matching, no explicit row count.
    pub fn new() -> Self {
        Self {
            row_count: None,
            match_field_names: true,
        }
    }

    /// Builder-style setter for [`RecordBatchOptions::row_count`].
    pub fn with_row_count(self, row_count: Option<usize>) -> Self {
        Self { row_count, ..self }
    }

    /// Builder-style setter for [`RecordBatchOptions::match_field_names`].
    pub fn with_match_field_names(self, match_field_names: bool) -> Self {
        Self {
            match_field_names,
            ..self
        }
    }
}

impl Default for RecordBatchOptions {
    /// Same as [`RecordBatchOptions::new`].
    fn default() -> Self {
        Self::new()
    }
}
796impl From<StructArray> for RecordBatch {
797 fn from(value: StructArray) -> Self {
798 let row_count = value.len();
799 let (fields, columns, nulls) = value.into_parts();
800 assert_eq!(
801 nulls.map(|n| n.null_count()).unwrap_or_default(),
802 0,
803 "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"
804 );
805
806 RecordBatch {
807 schema: Arc::new(Schema::new(fields)),
808 row_count,
809 columns,
810 }
811 }
812}
813
814impl From<&StructArray> for RecordBatch {
815 fn from(struct_array: &StructArray) -> Self {
816 struct_array.clone().into()
817 }
818}
819
820impl Index<&str> for RecordBatch {
821 type Output = ArrayRef;
822
823 fn index(&self, name: &str) -> &Self::Output {
829 self.column_by_name(name).unwrap()
830 }
831}
832
/// A [`RecordBatchReader`] built from any iterator of `Result<RecordBatch, ArrowError>`
/// plus a schema.
///
/// The wrapped iterator's items are passed through unchanged; the stored schema
/// is what `RecordBatchReader::schema` reports, and it is not checked against
/// the batches.
pub struct RecordBatchIterator<I>
where
    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
{
    // Iterator obtained from the `IntoIterator` passed to `new`.
    inner: I::IntoIter,
    // Schema reported to callers of `RecordBatchReader::schema`.
    inner_schema: SchemaRef,
}
865
866impl<I> RecordBatchIterator<I>
867where
868 I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
869{
870 pub fn new(iter: I, schema: SchemaRef) -> Self {
874 Self {
875 inner: iter.into_iter(),
876 inner_schema: schema,
877 }
878 }
879}
880
881impl<I> Iterator for RecordBatchIterator<I>
882where
883 I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
884{
885 type Item = I::Item;
886
887 fn next(&mut self) -> Option<Self::Item> {
888 self.inner.next()
889 }
890
891 fn size_hint(&self) -> (usize, Option<usize>) {
892 self.inner.size_hint()
893 }
894}
895
896impl<I> RecordBatchReader for RecordBatchIterator<I>
897where
898 I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
899{
900 fn schema(&self) -> SchemaRef {
901 self.inner_schema.clone()
902 }
903}
904
905#[cfg(test)]
906mod tests {
907 use super::*;
908 use crate::{
909 BooleanArray, Int32Array, Int64Array, Int8Array, ListArray, StringArray, StringViewArray,
910 };
911 use arrow_buffer::{Buffer, ToByteSlice};
912 use arrow_data::{ArrayData, ArrayDataBuilder};
913 use arrow_schema::Fields;
914 use std::collections::HashMap;
915
916 #[test]
917 fn create_record_batch() {
918 let schema = Schema::new(vec![
919 Field::new("a", DataType::Int32, false),
920 Field::new("b", DataType::Utf8, false),
921 ]);
922
923 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
924 let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);
925
926 let record_batch =
927 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
928 check_batch(record_batch, 5)
929 }
930
931 #[test]
932 fn create_string_view_record_batch() {
933 let schema = Schema::new(vec![
934 Field::new("a", DataType::Int32, false),
935 Field::new("b", DataType::Utf8View, false),
936 ]);
937
938 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
939 let b = StringViewArray::from(vec!["a", "b", "c", "d", "e"]);
940
941 let record_batch =
942 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
943
944 assert_eq!(5, record_batch.num_rows());
945 assert_eq!(2, record_batch.num_columns());
946 assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
947 assert_eq!(
948 &DataType::Utf8View,
949 record_batch.schema().field(1).data_type()
950 );
951 assert_eq!(5, record_batch.column(0).len());
952 assert_eq!(5, record_batch.column(1).len());
953 }
954
955 #[test]
956 fn byte_size_should_not_regress() {
957 let schema = Schema::new(vec![
958 Field::new("a", DataType::Int32, false),
959 Field::new("b", DataType::Utf8, false),
960 ]);
961
962 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
963 let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);
964
965 let record_batch =
966 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
967 assert_eq!(record_batch.get_array_memory_size(), 364);
968 }
969
970 fn check_batch(record_batch: RecordBatch, num_rows: usize) {
971 assert_eq!(num_rows, record_batch.num_rows());
972 assert_eq!(2, record_batch.num_columns());
973 assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
974 assert_eq!(&DataType::Utf8, record_batch.schema().field(1).data_type());
975 assert_eq!(num_rows, record_batch.column(0).len());
976 assert_eq!(num_rows, record_batch.column(1).len());
977 }
978
979 #[test]
980 #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
981 fn create_record_batch_slice() {
982 let schema = Schema::new(vec![
983 Field::new("a", DataType::Int32, false),
984 Field::new("b", DataType::Utf8, false),
985 ]);
986 let expected_schema = schema.clone();
987
988 let a = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
989 let b = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "h", "i"]);
990
991 let record_batch =
992 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
993
994 let offset = 2;
995 let length = 5;
996 let record_batch_slice = record_batch.slice(offset, length);
997
998 assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
999 check_batch(record_batch_slice, 5);
1000
1001 let offset = 2;
1002 let length = 0;
1003 let record_batch_slice = record_batch.slice(offset, length);
1004
1005 assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
1006 check_batch(record_batch_slice, 0);
1007
1008 let offset = 2;
1009 let length = 10;
1010 let _record_batch_slice = record_batch.slice(offset, length);
1011 }
1012
1013 #[test]
1014 #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
1015 fn create_record_batch_slice_empty_batch() {
1016 let schema = Schema::empty();
1017
1018 let record_batch = RecordBatch::new_empty(Arc::new(schema));
1019
1020 let offset = 0;
1021 let length = 0;
1022 let record_batch_slice = record_batch.slice(offset, length);
1023 assert_eq!(0, record_batch_slice.schema().fields().len());
1024
1025 let offset = 1;
1026 let length = 2;
1027 let _record_batch_slice = record_batch.slice(offset, length);
1028 }
1029
1030 #[test]
1031 fn create_record_batch_try_from_iter() {
1032 let a: ArrayRef = Arc::new(Int32Array::from(vec![
1033 Some(1),
1034 Some(2),
1035 None,
1036 Some(4),
1037 Some(5),
1038 ]));
1039 let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
1040
1041 let record_batch =
1042 RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).expect("valid conversion");
1043
1044 let expected_schema = Schema::new(vec![
1045 Field::new("a", DataType::Int32, true),
1046 Field::new("b", DataType::Utf8, false),
1047 ]);
1048 assert_eq!(record_batch.schema().as_ref(), &expected_schema);
1049 check_batch(record_batch, 5);
1050 }
1051
1052 #[test]
1053 fn create_record_batch_try_from_iter_with_nullable() {
1054 let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
1055 let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
1056
1057 let record_batch =
1059 RecordBatch::try_from_iter_with_nullable(vec![("a", a, false), ("b", b, true)])
1060 .expect("valid conversion");
1061
1062 let expected_schema = Schema::new(vec![
1063 Field::new("a", DataType::Int32, false),
1064 Field::new("b", DataType::Utf8, true),
1065 ]);
1066 assert_eq!(record_batch.schema().as_ref(), &expected_schema);
1067 check_batch(record_batch, 5);
1068 }
1069
1070 #[test]
1071 fn create_record_batch_schema_mismatch() {
1072 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1073
1074 let a = Int64Array::from(vec![1, 2, 3, 4, 5]);
1075
1076 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]);
1077 assert!(batch.is_err());
1078 }
1079
1080 #[test]
1081 fn create_record_batch_field_name_mismatch() {
1082 let fields = vec![
1083 Field::new("a1", DataType::Int32, false),
1084 Field::new_list("a2", Field::new_list_field(DataType::Int8, false), false),
1085 ];
1086 let schema = Arc::new(Schema::new(vec![Field::new_struct("a", fields, true)]));
1087
1088 let a1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
1089 let a2_child = Int8Array::from(vec![1, 2, 3, 4]);
1090 let a2 = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new(
1091 "array",
1092 DataType::Int8,
1093 false,
1094 ))))
1095 .add_child_data(a2_child.into_data())
1096 .len(2)
1097 .add_buffer(Buffer::from([0i32, 3, 4].to_byte_slice()))
1098 .build()
1099 .unwrap();
1100 let a2: ArrayRef = Arc::new(ListArray::from(a2));
1101 let a = ArrayDataBuilder::new(DataType::Struct(Fields::from(vec![
1102 Field::new("aa1", DataType::Int32, false),
1103 Field::new("a2", a2.data_type().clone(), false),
1104 ])))
1105 .add_child_data(a1.into_data())
1106 .add_child_data(a2.into_data())
1107 .len(2)
1108 .build()
1109 .unwrap();
1110 let a: ArrayRef = Arc::new(StructArray::from(a));
1111
1112 let batch = RecordBatch::try_new(schema.clone(), vec![a.clone()]);
1114 assert!(batch.is_err());
1115
1116 let options = RecordBatchOptions {
1118 match_field_names: false,
1119 row_count: None,
1120 };
1121 let batch = RecordBatch::try_new_with_options(schema, vec![a], &options);
1122 assert!(batch.is_ok());
1123 }
1124
1125 #[test]
1126 fn create_record_batch_record_mismatch() {
1127 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1128
1129 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1130 let b = Int32Array::from(vec![1, 2, 3, 4, 5]);
1131
1132 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]);
1133 assert!(batch.is_err());
1134 }
1135
1136 #[test]
1137 fn create_record_batch_from_struct_array() {
1138 let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
1139 let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
1140 let struct_array = StructArray::from(vec![
1141 (
1142 Arc::new(Field::new("b", DataType::Boolean, false)),
1143 boolean.clone() as ArrayRef,
1144 ),
1145 (
1146 Arc::new(Field::new("c", DataType::Int32, false)),
1147 int.clone() as ArrayRef,
1148 ),
1149 ]);
1150
1151 let batch = RecordBatch::from(&struct_array);
1152 assert_eq!(2, batch.num_columns());
1153 assert_eq!(4, batch.num_rows());
1154 assert_eq!(
1155 struct_array.data_type(),
1156 &DataType::Struct(batch.schema().fields().clone())
1157 );
1158 assert_eq!(batch.column(0).as_ref(), boolean.as_ref());
1159 assert_eq!(batch.column(1).as_ref(), int.as_ref());
1160 }
1161
1162 #[test]
1163 fn record_batch_equality() {
1164 let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1165 let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1166 let schema1 = Schema::new(vec![
1167 Field::new("id", DataType::Int32, false),
1168 Field::new("val", DataType::Int32, false),
1169 ]);
1170
1171 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1172 let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1173 let schema2 = Schema::new(vec![
1174 Field::new("id", DataType::Int32, false),
1175 Field::new("val", DataType::Int32, false),
1176 ]);
1177
1178 let batch1 = RecordBatch::try_new(
1179 Arc::new(schema1),
1180 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1181 )
1182 .unwrap();
1183
1184 let batch2 = RecordBatch::try_new(
1185 Arc::new(schema2),
1186 vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1187 )
1188 .unwrap();
1189
1190 assert_eq!(batch1, batch2);
1191 }
1192
1193 #[test]
1195 fn record_batch_index_access() {
1196 let id_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
1197 let val_arr = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
1198 let schema1 = Schema::new(vec![
1199 Field::new("id", DataType::Int32, false),
1200 Field::new("val", DataType::Int32, false),
1201 ]);
1202 let record_batch =
1203 RecordBatch::try_new(Arc::new(schema1), vec![id_arr.clone(), val_arr.clone()]).unwrap();
1204
1205 assert_eq!(record_batch["id"].as_ref(), id_arr.as_ref());
1206 assert_eq!(record_batch["val"].as_ref(), val_arr.as_ref());
1207 }
1208
1209 #[test]
1210 fn record_batch_vals_ne() {
1211 let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1212 let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1213 let schema1 = Schema::new(vec![
1214 Field::new("id", DataType::Int32, false),
1215 Field::new("val", DataType::Int32, false),
1216 ]);
1217
1218 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1219 let val_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1220 let schema2 = Schema::new(vec![
1221 Field::new("id", DataType::Int32, false),
1222 Field::new("val", DataType::Int32, false),
1223 ]);
1224
1225 let batch1 = RecordBatch::try_new(
1226 Arc::new(schema1),
1227 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1228 )
1229 .unwrap();
1230
1231 let batch2 = RecordBatch::try_new(
1232 Arc::new(schema2),
1233 vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1234 )
1235 .unwrap();
1236
1237 assert_ne!(batch1, batch2);
1238 }
1239
1240 #[test]
1241 fn record_batch_column_names_ne() {
1242 let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1243 let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1244 let schema1 = Schema::new(vec![
1245 Field::new("id", DataType::Int32, false),
1246 Field::new("val", DataType::Int32, false),
1247 ]);
1248
1249 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1250 let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1251 let schema2 = Schema::new(vec![
1252 Field::new("id", DataType::Int32, false),
1253 Field::new("num", DataType::Int32, false),
1254 ]);
1255
1256 let batch1 = RecordBatch::try_new(
1257 Arc::new(schema1),
1258 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1259 )
1260 .unwrap();
1261
1262 let batch2 = RecordBatch::try_new(
1263 Arc::new(schema2),
1264 vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1265 )
1266 .unwrap();
1267
1268 assert_ne!(batch1, batch2);
1269 }
1270
1271 #[test]
1272 fn record_batch_column_number_ne() {
1273 let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1274 let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1275 let schema1 = Schema::new(vec![
1276 Field::new("id", DataType::Int32, false),
1277 Field::new("val", DataType::Int32, false),
1278 ]);
1279
1280 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1281 let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1282 let num_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1283 let schema2 = Schema::new(vec![
1284 Field::new("id", DataType::Int32, false),
1285 Field::new("val", DataType::Int32, false),
1286 Field::new("num", DataType::Int32, false),
1287 ]);
1288
1289 let batch1 = RecordBatch::try_new(
1290 Arc::new(schema1),
1291 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1292 )
1293 .unwrap();
1294
1295 let batch2 = RecordBatch::try_new(
1296 Arc::new(schema2),
1297 vec![Arc::new(id_arr2), Arc::new(val_arr2), Arc::new(num_arr2)],
1298 )
1299 .unwrap();
1300
1301 assert_ne!(batch1, batch2);
1302 }
1303
1304 #[test]
1305 fn record_batch_row_count_ne() {
1306 let id_arr1 = Int32Array::from(vec![1, 2, 3]);
1307 let val_arr1 = Int32Array::from(vec![5, 6, 7]);
1308 let schema1 = Schema::new(vec![
1309 Field::new("id", DataType::Int32, false),
1310 Field::new("val", DataType::Int32, false),
1311 ]);
1312
1313 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1314 let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1315 let schema2 = Schema::new(vec![
1316 Field::new("id", DataType::Int32, false),
1317 Field::new("num", DataType::Int32, false),
1318 ]);
1319
1320 let batch1 = RecordBatch::try_new(
1321 Arc::new(schema1),
1322 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1323 )
1324 .unwrap();
1325
1326 let batch2 = RecordBatch::try_new(
1327 Arc::new(schema2),
1328 vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1329 )
1330 .unwrap();
1331
1332 assert_ne!(batch1, batch2);
1333 }
1334
1335 #[test]
1336 fn normalize_simple() {
1337 let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
1338 let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
1339 let year: ArrayRef = Arc::new(Int64Array::from(vec![None, Some(2022)]));
1340
1341 let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
1342 let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
1343 let year_field = Arc::new(Field::new("year", DataType::Int64, true));
1344
1345 let a = Arc::new(StructArray::from(vec![
1346 (animals_field.clone(), Arc::new(animals.clone()) as ArrayRef),
1347 (n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef),
1348 (year_field.clone(), Arc::new(year.clone()) as ArrayRef),
1349 ]));
1350
1351 let month = Arc::new(Int64Array::from(vec![Some(4), Some(6)]));
1352
1353 let schema = Schema::new(vec![
1354 Field::new(
1355 "a",
1356 DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
1357 false,
1358 ),
1359 Field::new("month", DataType::Int64, true),
1360 ]);
1361
1362 let normalized =
1363 RecordBatch::try_new(Arc::new(schema.clone()), vec![a.clone(), month.clone()])
1364 .expect("valid conversion")
1365 .normalize(".", Some(0))
1366 .expect("valid normalization");
1367
1368 let expected = RecordBatch::try_from_iter_with_nullable(vec![
1369 ("a.animals", animals.clone(), true),
1370 ("a.n_legs", n_legs.clone(), true),
1371 ("a.year", year.clone(), true),
1372 ("month", month.clone(), true),
1373 ])
1374 .expect("valid conversion");
1375
1376 assert_eq!(expected, normalized);
1377
1378 let normalized = RecordBatch::try_new(Arc::new(schema), vec![a, month.clone()])
1380 .expect("valid conversion")
1381 .normalize(".", None)
1382 .expect("valid normalization");
1383
1384 assert_eq!(expected, normalized);
1385 }
1386
1387 #[test]
1388 fn normalize_nested() {
1389 let a = Arc::new(Field::new("a", DataType::Int64, true));
1391 let b = Arc::new(Field::new("b", DataType::Int64, false));
1392 let c = Arc::new(Field::new("c", DataType::Int64, true));
1393
1394 let one = Arc::new(Field::new(
1395 "1",
1396 DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
1397 false,
1398 ));
1399 let two = Arc::new(Field::new(
1400 "2",
1401 DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
1402 true,
1403 ));
1404
1405 let exclamation = Arc::new(Field::new(
1406 "!",
1407 DataType::Struct(Fields::from(vec![one.clone(), two.clone()])),
1408 false,
1409 ));
1410
1411 let schema = Schema::new(vec![exclamation.clone()]);
1412
1413 let a_field = Int64Array::from(vec![Some(0), Some(1)]);
1415 let b_field = Int64Array::from(vec![Some(2), Some(3)]);
1416 let c_field = Int64Array::from(vec![None, Some(4)]);
1417
1418 let one_field = StructArray::from(vec![
1419 (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1420 (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1421 (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1422 ]);
1423 let two_field = StructArray::from(vec![
1424 (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1425 (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1426 (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1427 ]);
1428
1429 let exclamation_field = Arc::new(StructArray::from(vec![
1430 (one.clone(), Arc::new(one_field) as ArrayRef),
1431 (two.clone(), Arc::new(two_field) as ArrayRef),
1432 ]));
1433
1434 let normalized =
1436 RecordBatch::try_new(Arc::new(schema.clone()), vec![exclamation_field.clone()])
1437 .expect("valid conversion")
1438 .normalize(".", Some(1))
1439 .expect("valid normalization");
1440
1441 let expected = RecordBatch::try_from_iter_with_nullable(vec![
1442 (
1443 "!.1",
1444 Arc::new(StructArray::from(vec![
1445 (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1446 (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1447 (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1448 ])) as ArrayRef,
1449 false,
1450 ),
1451 (
1452 "!.2",
1453 Arc::new(StructArray::from(vec![
1454 (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1455 (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1456 (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1457 ])) as ArrayRef,
1458 true,
1459 ),
1460 ])
1461 .expect("valid conversion");
1462
1463 assert_eq!(expected, normalized);
1464
1465 let normalized = RecordBatch::try_new(Arc::new(schema), vec![exclamation_field])
1467 .expect("valid conversion")
1468 .normalize(".", None)
1469 .expect("valid normalization");
1470
1471 let expected = RecordBatch::try_from_iter_with_nullable(vec![
1472 ("!.1.a", Arc::new(a_field.clone()) as ArrayRef, true),
1473 ("!.1.b", Arc::new(b_field.clone()) as ArrayRef, false),
1474 ("!.1.c", Arc::new(c_field.clone()) as ArrayRef, true),
1475 ("!.2.a", Arc::new(a_field.clone()) as ArrayRef, true),
1476 ("!.2.b", Arc::new(b_field.clone()) as ArrayRef, false),
1477 ("!.2.c", Arc::new(c_field.clone()) as ArrayRef, true),
1478 ])
1479 .expect("valid conversion");
1480
1481 assert_eq!(expected, normalized);
1482 }
1483
1484 #[test]
1485 fn normalize_empty() {
1486 let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
1487 let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
1488 let year_field = Arc::new(Field::new("year", DataType::Int64, true));
1489
1490 let schema = Schema::new(vec![
1491 Field::new(
1492 "a",
1493 DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
1494 false,
1495 ),
1496 Field::new("month", DataType::Int64, true),
1497 ]);
1498
1499 let normalized = RecordBatch::new_empty(Arc::new(schema.clone()))
1500 .normalize(".", Some(0))
1501 .expect("valid normalization");
1502
1503 let expected = RecordBatch::new_empty(Arc::new(
1504 schema.normalize(".", Some(0)).expect("valid normalization"),
1505 ));
1506
1507 assert_eq!(expected, normalized);
1508 }
1509
1510 #[test]
1511 fn project() {
1512 let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
1513 let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
1514 let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));
1515
1516 let record_batch =
1517 RecordBatch::try_from_iter(vec![("a", a.clone()), ("b", b.clone()), ("c", c.clone())])
1518 .expect("valid conversion");
1519
1520 let expected =
1521 RecordBatch::try_from_iter(vec![("a", a), ("c", c)]).expect("valid conversion");
1522
1523 assert_eq!(expected, record_batch.project(&[0, 2]).unwrap());
1524 }
1525
1526 #[test]
1527 fn project_empty() {
1528 let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));
1529
1530 let record_batch =
1531 RecordBatch::try_from_iter(vec![("c", c.clone())]).expect("valid conversion");
1532
1533 let expected = RecordBatch::try_new_with_options(
1534 Arc::new(Schema::empty()),
1535 vec![],
1536 &RecordBatchOptions {
1537 match_field_names: true,
1538 row_count: Some(3),
1539 },
1540 )
1541 .expect("valid conversion");
1542
1543 assert_eq!(expected, record_batch.project(&[]).unwrap());
1544 }
1545
1546 #[test]
1547 fn test_no_column_record_batch() {
1548 let schema = Arc::new(Schema::empty());
1549
1550 let err = RecordBatch::try_new(schema.clone(), vec![]).unwrap_err();
1551 assert!(err
1552 .to_string()
1553 .contains("must either specify a row count or at least one column"));
1554
1555 let options = RecordBatchOptions::new().with_row_count(Some(10));
1556
1557 let ok = RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap();
1558 assert_eq!(ok.num_rows(), 10);
1559
1560 let a = ok.slice(2, 5);
1561 assert_eq!(a.num_rows(), 5);
1562
1563 let b = ok.slice(5, 0);
1564 assert_eq!(b.num_rows(), 0);
1565
1566 assert_ne!(a, b);
1567 assert_eq!(b, RecordBatch::new_empty(schema))
1568 }
1569
1570 #[test]
1571 fn test_nulls_in_non_nullable_field() {
1572 let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
1573 let maybe_batch = RecordBatch::try_new(
1574 schema,
1575 vec![Arc::new(Int32Array::from(vec![Some(1), None]))],
1576 );
1577 assert_eq!("Invalid argument error: Column 'a' is declared as non-nullable but contains null values", format!("{}", maybe_batch.err().unwrap()));
1578 }
1579 #[test]
1580 fn test_record_batch_options() {
1581 let options = RecordBatchOptions::new()
1582 .with_match_field_names(false)
1583 .with_row_count(Some(20));
1584 assert!(!options.match_field_names);
1585 assert_eq!(options.row_count.unwrap(), 20)
1586 }
1587
1588 #[test]
1589 #[should_panic(expected = "Cannot convert nullable StructArray to RecordBatch")]
1590 fn test_from_struct() {
1591 let s = StructArray::from(ArrayData::new_null(
1592 &DataType::Struct(vec![Field::new("foo", DataType::Int32, false)].into()),
1594 2,
1595 ));
1596 let _ = RecordBatch::from(s);
1597 }
1598
1599 #[test]
1600 fn test_with_schema() {
1601 let required_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1602 let required_schema = Arc::new(required_schema);
1603 let nullable_schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
1604 let nullable_schema = Arc::new(nullable_schema);
1605
1606 let batch = RecordBatch::try_new(
1607 required_schema.clone(),
1608 vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as _],
1609 )
1610 .unwrap();
1611
1612 let batch = batch.with_schema(nullable_schema.clone()).unwrap();
1614
1615 batch.clone().with_schema(required_schema).unwrap_err();
1617
1618 let metadata = vec![("foo".to_string(), "bar".to_string())]
1620 .into_iter()
1621 .collect();
1622 let metadata_schema = nullable_schema.as_ref().clone().with_metadata(metadata);
1623 let batch = batch.with_schema(Arc::new(metadata_schema)).unwrap();
1624
1625 batch.with_schema(nullable_schema).unwrap_err();
1627 }
1628
1629 #[test]
1630 fn test_boxed_reader() {
1631 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1634 let schema = Arc::new(schema);
1635
1636 let reader = RecordBatchIterator::new(std::iter::empty(), schema);
1637 let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
1638
1639 fn get_size(reader: impl RecordBatchReader) -> usize {
1640 reader.size_hint().0
1641 }
1642
1643 let size = get_size(reader);
1644 assert_eq!(size, 0);
1645 }
1646
1647 #[test]
1648 fn test_remove_column_maintains_schema_metadata() {
1649 let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
1650 let bool_array = BooleanArray::from(vec![true, false, false, true, true]);
1651
1652 let mut metadata = HashMap::new();
1653 metadata.insert("foo".to_string(), "bar".to_string());
1654 let schema = Schema::new(vec![
1655 Field::new("id", DataType::Int32, false),
1656 Field::new("bool", DataType::Boolean, false),
1657 ])
1658 .with_metadata(metadata);
1659
1660 let mut batch = RecordBatch::try_new(
1661 Arc::new(schema),
1662 vec![Arc::new(id_array), Arc::new(bool_array)],
1663 )
1664 .unwrap();
1665
1666 let _removed_column = batch.remove_column(0);
1667 assert_eq!(batch.schema().metadata().len(), 1);
1668 assert_eq!(
1669 batch.schema().metadata().get("foo").unwrap().as_str(),
1670 "bar"
1671 );
1672 }
1673}