1use crate::cast::AsArray;
22use crate::{Array, ArrayRef, StructArray, new_empty_array};
23use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef};
24use std::ops::Index;
25use std::sync::Arc;
26
27pub trait RecordBatchReader: Iterator<Item = Result<RecordBatch, ArrowError>> {
31 fn schema(&self) -> SchemaRef;
36}
37
38impl<R: RecordBatchReader + ?Sized> RecordBatchReader for Box<R> {
39 fn schema(&self) -> SchemaRef {
40 self.as_ref().schema()
41 }
42}
43
44pub trait RecordBatchWriter {
46 fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError>;
48
49 fn close(self) -> Result<(), ArrowError>;
51}
52
53#[macro_export]
79macro_rules! create_array {
80 (@from Boolean) => { $crate::BooleanArray };
82 (@from Int8) => { $crate::Int8Array };
83 (@from Int16) => { $crate::Int16Array };
84 (@from Int32) => { $crate::Int32Array };
85 (@from Int64) => { $crate::Int64Array };
86 (@from UInt8) => { $crate::UInt8Array };
87 (@from UInt16) => { $crate::UInt16Array };
88 (@from UInt32) => { $crate::UInt32Array };
89 (@from UInt64) => { $crate::UInt64Array };
90 (@from Float16) => { $crate::Float16Array };
91 (@from Float32) => { $crate::Float32Array };
92 (@from Float64) => { $crate::Float64Array };
93 (@from Utf8) => { $crate::StringArray };
94 (@from Utf8View) => { $crate::StringViewArray };
95 (@from LargeUtf8) => { $crate::LargeStringArray };
96 (@from IntervalDayTime) => { $crate::IntervalDayTimeArray };
97 (@from IntervalYearMonth) => { $crate::IntervalYearMonthArray };
98 (@from Second) => { $crate::TimestampSecondArray };
99 (@from Millisecond) => { $crate::TimestampMillisecondArray };
100 (@from Microsecond) => { $crate::TimestampMicrosecondArray };
101 (@from Nanosecond) => { $crate::TimestampNanosecondArray };
102 (@from Second32) => { $crate::Time32SecondArray };
103 (@from Millisecond32) => { $crate::Time32MillisecondArray };
104 (@from Microsecond64) => { $crate::Time64MicrosecondArray };
105 (@from Nanosecond64) => { $crate::Time64Nanosecond64Array };
106 (@from DurationSecond) => { $crate::DurationSecondArray };
107 (@from DurationMillisecond) => { $crate::DurationMillisecondArray };
108 (@from DurationMicrosecond) => { $crate::DurationMicrosecondArray };
109 (@from DurationNanosecond) => { $crate::DurationNanosecondArray };
110 (@from Decimal32) => { $crate::Decimal32Array };
111 (@from Decimal64) => { $crate::Decimal64Array };
112 (@from Decimal128) => { $crate::Decimal128Array };
113 (@from Decimal256) => { $crate::Decimal256Array };
114 (@from TimestampSecond) => { $crate::TimestampSecondArray };
115 (@from TimestampMillisecond) => { $crate::TimestampMillisecondArray };
116 (@from TimestampMicrosecond) => { $crate::TimestampMicrosecondArray };
117 (@from TimestampNanosecond) => { $crate::TimestampNanosecondArray };
118
119 (@from $ty: ident) => {
120 compile_error!(concat!("Unsupported data type: ", stringify!($ty)))
121 };
122
123 (Null, $size: expr) => {
124 std::sync::Arc::new($crate::NullArray::new($size))
125 };
126
127 (Binary, [$($values: expr),*]) => {
128 std::sync::Arc::new($crate::BinaryArray::from_vec(vec![$($values),*]))
129 };
130
131 (LargeBinary, [$($values: expr),*]) => {
132 std::sync::Arc::new($crate::LargeBinaryArray::from_vec(vec![$($values),*]))
133 };
134
135 ($ty: tt, [$($values: expr),*]) => {
136 std::sync::Arc::new(<$crate::create_array!(@from $ty)>::from(vec![$($values),*]))
137 };
138
139 (Binary, $values: expr) => {
140 std::sync::Arc::new($crate::BinaryArray::from_vec($values))
141 };
142
143 (LargeBinary, $values: expr) => {
144 std::sync::Arc::new($crate::LargeBinaryArray::from_vec($values))
145 };
146
147 ($ty: tt, $values: expr) => {
148 std::sync::Arc::new(<$crate::create_array!(@from $ty)>::from($values))
149 };
150}
151
152#[macro_export]
181macro_rules! record_batch {
182 ($(($name: expr, $type: ident, $($values: tt)+)),*) => {
183 {
184 let schema = std::sync::Arc::new(arrow_schema::Schema::new(vec![
185 $(
186 arrow_schema::Field::new($name, arrow_schema::DataType::$type, true),
187 )*
188 ]));
189
190 $crate::RecordBatch::try_new(
191 schema,
192 vec![$(
193 $crate::create_array!($type, $($values)+),
194 )*]
195 )
196 }
197 };
198}
199
200#[derive(Clone, Debug, PartialEq)]
224pub struct RecordBatch {
225 schema: SchemaRef,
226 columns: Vec<Arc<dyn Array>>,
227
228 row_count: usize,
232}
233
234impl RecordBatch {
235 pub fn try_new(schema: SchemaRef, columns: Vec<ArrayRef>) -> Result<Self, ArrowError> {
264 let options = RecordBatchOptions::new();
265 Self::try_new_impl(schema, columns, &options)
266 }
267
268 pub unsafe fn new_unchecked(
284 schema: SchemaRef,
285 columns: Vec<Arc<dyn Array>>,
286 row_count: usize,
287 ) -> Self {
288 Self {
289 schema,
290 columns,
291 row_count,
292 }
293 }
294
295 pub fn try_new_with_options(
300 schema: SchemaRef,
301 columns: Vec<ArrayRef>,
302 options: &RecordBatchOptions,
303 ) -> Result<Self, ArrowError> {
304 Self::try_new_impl(schema, columns, options)
305 }
306
307 pub fn new_empty(schema: SchemaRef) -> Self {
309 let columns = schema
310 .fields()
311 .iter()
312 .map(|field| new_empty_array(field.data_type()))
313 .collect();
314
315 RecordBatch {
316 schema,
317 columns,
318 row_count: 0,
319 }
320 }
321
322 fn try_new_impl(
325 schema: SchemaRef,
326 columns: Vec<ArrayRef>,
327 options: &RecordBatchOptions,
328 ) -> Result<Self, ArrowError> {
329 if schema.fields().len() != columns.len() {
331 return Err(ArrowError::InvalidArgumentError(format!(
332 "number of columns({}) must match number of fields({}) in schema",
333 columns.len(),
334 schema.fields().len(),
335 )));
336 }
337
338 let row_count = options
339 .row_count
340 .or_else(|| columns.first().map(|col| col.len()))
341 .ok_or_else(|| {
342 ArrowError::InvalidArgumentError(
343 "must either specify a row count or at least one column".to_string(),
344 )
345 })?;
346
347 for (c, f) in columns.iter().zip(&schema.fields) {
348 if !f.is_nullable() && c.null_count() > 0 {
349 return Err(ArrowError::InvalidArgumentError(format!(
350 "Column '{}' is declared as non-nullable but contains null values",
351 f.name()
352 )));
353 }
354 }
355
356 if columns.iter().any(|c| c.len() != row_count) {
358 let err = match options.row_count {
359 Some(_) => "all columns in a record batch must have the specified row count",
360 None => "all columns in a record batch must have the same length",
361 };
362 return Err(ArrowError::InvalidArgumentError(err.to_string()));
363 }
364
365 let type_not_match = if options.match_field_names {
368 |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| col_type != field_type
369 } else {
370 |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| {
371 !col_type.equals_datatype(field_type)
372 }
373 };
374
375 let not_match = columns
377 .iter()
378 .zip(schema.fields().iter())
379 .map(|(col, field)| (col.data_type(), field.data_type()))
380 .enumerate()
381 .find(type_not_match);
382
383 if let Some((i, (col_type, field_type))) = not_match {
384 return Err(ArrowError::InvalidArgumentError(format!(
385 "column types must match schema types, expected {field_type} but found {col_type} at column index {i}"
386 )));
387 }
388
389 Ok(RecordBatch {
390 schema,
391 columns,
392 row_count,
393 })
394 }
395
396 pub fn into_parts(self) -> (SchemaRef, Vec<ArrayRef>, usize) {
398 (self.schema, self.columns, self.row_count)
399 }
400
401 pub fn with_schema(self, schema: SchemaRef) -> Result<Self, ArrowError> {
408 if !schema.contains(self.schema.as_ref()) {
409 return Err(ArrowError::SchemaError(format!(
410 "target schema is not superset of current schema target={schema} current={}",
411 self.schema
412 )));
413 }
414
415 Ok(Self {
416 schema,
417 columns: self.columns,
418 row_count: self.row_count,
419 })
420 }
421
422 pub fn schema(&self) -> SchemaRef {
424 self.schema.clone()
425 }
426
427 pub fn schema_ref(&self) -> &SchemaRef {
429 &self.schema
430 }
431
432 pub fn schema_metadata_mut(&mut self) -> &mut std::collections::HashMap<String, String> {
450 let schema = Arc::make_mut(&mut self.schema);
451 &mut schema.metadata
452 }
453
454 pub fn project(&self, indices: &[usize]) -> Result<RecordBatch, ArrowError> {
456 let projected_schema = self.schema.project(indices)?;
457 let batch_fields = indices
458 .iter()
459 .map(|f| {
460 self.columns.get(*f).cloned().ok_or_else(|| {
461 ArrowError::SchemaError(format!(
462 "project index {} out of bounds, max field {}",
463 f,
464 self.columns.len()
465 ))
466 })
467 })
468 .collect::<Result<Vec<_>, _>>()?;
469
470 unsafe {
471 Ok(RecordBatch::new_unchecked(
475 SchemaRef::new(projected_schema),
476 batch_fields,
477 self.row_count,
478 ))
479 }
480 }
481
482 pub fn normalize(&self, separator: &str, max_level: Option<usize>) -> Result<Self, ArrowError> {
542 let max_level = match max_level.unwrap_or(usize::MAX) {
543 0 => usize::MAX,
544 val => val,
545 };
546 let mut stack: Vec<(usize, &ArrayRef, Vec<&str>, &FieldRef)> = self
547 .columns
548 .iter()
549 .zip(self.schema.fields())
550 .rev()
551 .map(|(c, f)| {
552 let name_vec: Vec<&str> = vec![f.name()];
553 (0, c, name_vec, f)
554 })
555 .collect();
556 let mut columns: Vec<ArrayRef> = Vec::new();
557 let mut fields: Vec<FieldRef> = Vec::new();
558
559 while let Some((depth, c, name, field_ref)) = stack.pop() {
560 match field_ref.data_type() {
561 DataType::Struct(ff) if depth < max_level => {
562 for (cff, fff) in c.as_struct().columns().iter().zip(ff.into_iter()).rev() {
564 let mut name = name.clone();
565 name.push(separator);
566 name.push(fff.name());
567 stack.push((depth + 1, cff, name, fff))
568 }
569 }
570 _ => {
571 let updated_field = Field::new(
572 name.concat(),
573 field_ref.data_type().clone(),
574 field_ref.is_nullable(),
575 );
576 columns.push(c.clone());
577 fields.push(Arc::new(updated_field));
578 }
579 }
580 }
581 RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
582 }
583
584 pub fn num_columns(&self) -> usize {
603 self.columns.len()
604 }
605
606 pub fn num_rows(&self) -> usize {
625 self.row_count
626 }
627
628 pub fn column(&self, index: usize) -> &ArrayRef {
634 &self.columns[index]
635 }
636
637 pub fn column_by_name(&self, name: &str) -> Option<&ArrayRef> {
639 self.schema()
640 .column_with_name(name)
641 .map(|(index, _)| &self.columns[index])
642 }
643
644 pub fn columns(&self) -> &[ArrayRef] {
646 &self.columns[..]
647 }
648
649 pub fn remove_column(&mut self, index: usize) -> ArrayRef {
677 let mut builder = SchemaBuilder::from(self.schema.as_ref());
678 builder.remove(index);
679 self.schema = Arc::new(builder.finish());
680 self.columns.remove(index)
681 }
682
683 pub fn slice(&self, offset: usize, length: usize) -> RecordBatch {
690 assert!((offset + length) <= self.num_rows());
691
692 let columns = self
693 .columns()
694 .iter()
695 .map(|column| column.slice(offset, length))
696 .collect();
697
698 Self {
699 schema: self.schema.clone(),
700 columns,
701 row_count: length,
702 }
703 }
704
705 pub fn try_from_iter<I, F>(value: I) -> Result<Self, ArrowError>
742 where
743 I: IntoIterator<Item = (F, ArrayRef)>,
744 F: AsRef<str>,
745 {
746 let iter = value.into_iter().map(|(field_name, array)| {
750 let nullable = array.null_count() > 0;
751 (field_name, array, nullable)
752 });
753
754 Self::try_from_iter_with_nullable(iter)
755 }
756
757 pub fn try_from_iter_with_nullable<I, F>(value: I) -> Result<Self, ArrowError>
779 where
780 I: IntoIterator<Item = (F, ArrayRef, bool)>,
781 F: AsRef<str>,
782 {
783 let iter = value.into_iter();
784 let capacity = iter.size_hint().0;
785 let mut schema = SchemaBuilder::with_capacity(capacity);
786 let mut columns = Vec::with_capacity(capacity);
787
788 for (field_name, array, nullable) in iter {
789 let field_name = field_name.as_ref();
790 schema.push(Field::new(field_name, array.data_type().clone(), nullable));
791 columns.push(array);
792 }
793
794 let schema = Arc::new(schema.finish());
795 RecordBatch::try_new(schema, columns)
796 }
797
798 #[cfg(feature = "pool")]
806 pub fn claim(&self, pool: &dyn arrow_buffer::MemoryPool) {
807 for column in self.columns() {
808 column.claim(pool);
809 }
810 }
811
812 pub fn get_array_memory_size(&self) -> usize {
819 self.columns()
820 .iter()
821 .map(|array| array.get_array_memory_size())
822 .sum()
823 }
824}
825
826#[derive(Debug)]
828#[non_exhaustive]
829pub struct RecordBatchOptions {
830 pub match_field_names: bool,
832
833 pub row_count: Option<usize>,
835}
836
837impl RecordBatchOptions {
838 pub fn new() -> Self {
840 Self {
841 match_field_names: true,
842 row_count: None,
843 }
844 }
845 pub fn with_row_count(mut self, row_count: Option<usize>) -> Self {
847 self.row_count = row_count;
848 self
849 }
850 pub fn with_match_field_names(mut self, match_field_names: bool) -> Self {
852 self.match_field_names = match_field_names;
853 self
854 }
855}
856impl Default for RecordBatchOptions {
857 fn default() -> Self {
858 Self::new()
859 }
860}
861impl From<StructArray> for RecordBatch {
862 fn from(value: StructArray) -> Self {
863 let row_count = value.len();
864 let (fields, columns, nulls) = value.into_parts();
865 assert_eq!(
866 nulls.map(|n| n.null_count()).unwrap_or_default(),
867 0,
868 "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"
869 );
870
871 RecordBatch {
872 schema: Arc::new(Schema::new(fields)),
873 row_count,
874 columns,
875 }
876 }
877}
878
879impl From<&StructArray> for RecordBatch {
880 fn from(struct_array: &StructArray) -> Self {
881 struct_array.clone().into()
882 }
883}
884
885impl Index<&str> for RecordBatch {
886 type Output = ArrayRef;
887
888 fn index(&self, name: &str) -> &Self::Output {
894 self.column_by_name(name).unwrap()
895 }
896}
897
898pub struct RecordBatchIterator<I>
924where
925 I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
926{
927 inner: I::IntoIter,
928 inner_schema: SchemaRef,
929}
930
931impl<I> RecordBatchIterator<I>
932where
933 I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
934{
935 pub fn new(iter: I, schema: SchemaRef) -> Self {
939 Self {
940 inner: iter.into_iter(),
941 inner_schema: schema,
942 }
943 }
944}
945
946impl<I> Iterator for RecordBatchIterator<I>
947where
948 I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
949{
950 type Item = I::Item;
951
952 fn next(&mut self) -> Option<Self::Item> {
953 self.inner.next()
954 }
955
956 fn size_hint(&self) -> (usize, Option<usize>) {
957 self.inner.size_hint()
958 }
959}
960
961impl<I> RecordBatchReader for RecordBatchIterator<I>
962where
963 I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
964{
965 fn schema(&self) -> SchemaRef {
966 self.inner_schema.clone()
967 }
968}
969
970#[cfg(test)]
971mod tests {
972 use super::*;
973 use crate::{
974 BooleanArray, Int8Array, Int32Array, Int64Array, ListArray, StringArray, StringViewArray,
975 };
976 use arrow_buffer::{Buffer, ToByteSlice};
977 use arrow_data::{ArrayData, ArrayDataBuilder};
978 use arrow_schema::Fields;
979 use std::collections::HashMap;
980
981 #[test]
982 fn create_record_batch() {
983 let schema = Schema::new(vec![
984 Field::new("a", DataType::Int32, false),
985 Field::new("b", DataType::Utf8, false),
986 ]);
987
988 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
989 let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);
990
991 let record_batch =
992 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
993 check_batch(record_batch, 5)
994 }
995
996 #[test]
997 fn create_string_view_record_batch() {
998 let schema = Schema::new(vec![
999 Field::new("a", DataType::Int32, false),
1000 Field::new("b", DataType::Utf8View, false),
1001 ]);
1002
1003 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1004 let b = StringViewArray::from(vec!["a", "b", "c", "d", "e"]);
1005
1006 let record_batch =
1007 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
1008
1009 assert_eq!(5, record_batch.num_rows());
1010 assert_eq!(2, record_batch.num_columns());
1011 assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
1012 assert_eq!(
1013 &DataType::Utf8View,
1014 record_batch.schema().field(1).data_type()
1015 );
1016 assert_eq!(5, record_batch.column(0).len());
1017 assert_eq!(5, record_batch.column(1).len());
1018 }
1019
1020 #[test]
1021 fn create_binary_record_batch_from_variables() {
1022 let binary_values = vec![b"a".as_slice()];
1023 let large_binary_values = vec![b"xxx".as_slice()];
1024
1025 let record_batch = record_batch!(
1026 ("a", Binary, binary_values),
1027 ("b", LargeBinary, large_binary_values)
1028 )
1029 .unwrap();
1030
1031 assert_eq!(1, record_batch.num_rows());
1032 assert_eq!(2, record_batch.num_columns());
1033 assert_eq!(
1034 &DataType::Binary,
1035 record_batch.schema().field(0).data_type()
1036 );
1037 assert_eq!(
1038 &DataType::LargeBinary,
1039 record_batch.schema().field(1).data_type()
1040 );
1041
1042 let binary = record_batch.column(0).as_binary::<i32>();
1043 assert_eq!(b"a", binary.value(0));
1044
1045 let large_binary = record_batch.column(1).as_binary::<i64>();
1046 assert_eq!(b"xxx", large_binary.value(0));
1047 }
1048
1049 #[test]
1050 fn byte_size_should_not_regress() {
1051 let schema = Schema::new(vec![
1052 Field::new("a", DataType::Int32, false),
1053 Field::new("b", DataType::Utf8, false),
1054 ]);
1055
1056 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1057 let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);
1058
1059 let record_batch =
1060 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
1061 assert_eq!(record_batch.get_array_memory_size(), 364);
1062 }
1063
1064 fn check_batch(record_batch: RecordBatch, num_rows: usize) {
1065 assert_eq!(num_rows, record_batch.num_rows());
1066 assert_eq!(2, record_batch.num_columns());
1067 assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
1068 assert_eq!(&DataType::Utf8, record_batch.schema().field(1).data_type());
1069 assert_eq!(num_rows, record_batch.column(0).len());
1070 assert_eq!(num_rows, record_batch.column(1).len());
1071 }
1072
1073 #[test]
1074 #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
1075 fn create_record_batch_slice() {
1076 let schema = Schema::new(vec![
1077 Field::new("a", DataType::Int32, false),
1078 Field::new("b", DataType::Utf8, false),
1079 ]);
1080 let expected_schema = schema.clone();
1081
1082 let a = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
1083 let b = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "h", "i"]);
1084
1085 let record_batch =
1086 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
1087
1088 let offset = 2;
1089 let length = 5;
1090 let record_batch_slice = record_batch.slice(offset, length);
1091
1092 assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
1093 check_batch(record_batch_slice, 5);
1094
1095 let offset = 2;
1096 let length = 0;
1097 let record_batch_slice = record_batch.slice(offset, length);
1098
1099 assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
1100 check_batch(record_batch_slice, 0);
1101
1102 let offset = 2;
1103 let length = 10;
1104 let _record_batch_slice = record_batch.slice(offset, length);
1105 }
1106
1107 #[test]
1108 #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
1109 fn create_record_batch_slice_empty_batch() {
1110 let schema = Schema::empty();
1111
1112 let record_batch = RecordBatch::new_empty(Arc::new(schema));
1113
1114 let offset = 0;
1115 let length = 0;
1116 let record_batch_slice = record_batch.slice(offset, length);
1117 assert_eq!(0, record_batch_slice.schema().fields().len());
1118
1119 let offset = 1;
1120 let length = 2;
1121 let _record_batch_slice = record_batch.slice(offset, length);
1122 }
1123
1124 #[test]
1125 fn create_record_batch_try_from_iter() {
1126 let a: ArrayRef = Arc::new(Int32Array::from(vec![
1127 Some(1),
1128 Some(2),
1129 None,
1130 Some(4),
1131 Some(5),
1132 ]));
1133 let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
1134
1135 let record_batch =
1136 RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).expect("valid conversion");
1137
1138 let expected_schema = Schema::new(vec![
1139 Field::new("a", DataType::Int32, true),
1140 Field::new("b", DataType::Utf8, false),
1141 ]);
1142 assert_eq!(record_batch.schema().as_ref(), &expected_schema);
1143 check_batch(record_batch, 5);
1144 }
1145
1146 #[test]
1147 fn create_record_batch_try_from_iter_with_nullable() {
1148 let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
1149 let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
1150
1151 let record_batch =
1153 RecordBatch::try_from_iter_with_nullable(vec![("a", a, false), ("b", b, true)])
1154 .expect("valid conversion");
1155
1156 let expected_schema = Schema::new(vec![
1157 Field::new("a", DataType::Int32, false),
1158 Field::new("b", DataType::Utf8, true),
1159 ]);
1160 assert_eq!(record_batch.schema().as_ref(), &expected_schema);
1161 check_batch(record_batch, 5);
1162 }
1163
1164 #[test]
1165 fn create_record_batch_schema_mismatch() {
1166 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1167
1168 let a = Int64Array::from(vec![1, 2, 3, 4, 5]);
1169
1170 let err = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap_err();
1171 assert_eq!(
1172 err.to_string(),
1173 "Invalid argument error: column types must match schema types, expected Int32 but found Int64 at column index 0"
1174 );
1175 }
1176
1177 #[test]
1178 fn create_record_batch_field_name_mismatch() {
1179 let fields = vec![
1180 Field::new("a1", DataType::Int32, false),
1181 Field::new_list("a2", Field::new_list_field(DataType::Int8, false), false),
1182 ];
1183 let schema = Arc::new(Schema::new(vec![Field::new_struct("a", fields, true)]));
1184
1185 let a1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
1186 let a2_child = Int8Array::from(vec![1, 2, 3, 4]);
1187 let a2 = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new(
1188 "array",
1189 DataType::Int8,
1190 false,
1191 ))))
1192 .add_child_data(a2_child.into_data())
1193 .len(2)
1194 .add_buffer(Buffer::from([0i32, 3, 4].to_byte_slice()))
1195 .build()
1196 .unwrap();
1197 let a2: ArrayRef = Arc::new(ListArray::from(a2));
1198 let a = ArrayDataBuilder::new(DataType::Struct(Fields::from(vec![
1199 Field::new("aa1", DataType::Int32, false),
1200 Field::new("a2", a2.data_type().clone(), false),
1201 ])))
1202 .add_child_data(a1.into_data())
1203 .add_child_data(a2.into_data())
1204 .len(2)
1205 .build()
1206 .unwrap();
1207 let a: ArrayRef = Arc::new(StructArray::from(a));
1208
1209 let batch = RecordBatch::try_new(schema.clone(), vec![a.clone()]);
1211 assert!(batch.is_err());
1212
1213 let options = RecordBatchOptions {
1215 match_field_names: false,
1216 row_count: None,
1217 };
1218 let batch = RecordBatch::try_new_with_options(schema, vec![a], &options);
1219 assert!(batch.is_ok());
1220 }
1221
1222 #[test]
1223 fn create_record_batch_record_mismatch() {
1224 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1225
1226 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1227 let b = Int32Array::from(vec![1, 2, 3, 4, 5]);
1228
1229 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]);
1230 assert!(batch.is_err());
1231 }
1232
1233 #[test]
1234 fn create_record_batch_from_struct_array() {
1235 let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
1236 let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
1237 let struct_array = StructArray::from(vec![
1238 (
1239 Arc::new(Field::new("b", DataType::Boolean, false)),
1240 boolean.clone() as ArrayRef,
1241 ),
1242 (
1243 Arc::new(Field::new("c", DataType::Int32, false)),
1244 int.clone() as ArrayRef,
1245 ),
1246 ]);
1247
1248 let batch = RecordBatch::from(&struct_array);
1249 assert_eq!(2, batch.num_columns());
1250 assert_eq!(4, batch.num_rows());
1251 assert_eq!(
1252 struct_array.data_type(),
1253 &DataType::Struct(batch.schema().fields().clone())
1254 );
1255 assert_eq!(batch.column(0).as_ref(), boolean.as_ref());
1256 assert_eq!(batch.column(1).as_ref(), int.as_ref());
1257 }
1258
1259 #[test]
1260 fn record_batch_equality() {
1261 let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1262 let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1263 let schema1 = Schema::new(vec![
1264 Field::new("id", DataType::Int32, false),
1265 Field::new("val", DataType::Int32, false),
1266 ]);
1267
1268 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1269 let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1270 let schema2 = Schema::new(vec![
1271 Field::new("id", DataType::Int32, false),
1272 Field::new("val", DataType::Int32, false),
1273 ]);
1274
1275 let batch1 = RecordBatch::try_new(
1276 Arc::new(schema1),
1277 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1278 )
1279 .unwrap();
1280
1281 let batch2 = RecordBatch::try_new(
1282 Arc::new(schema2),
1283 vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1284 )
1285 .unwrap();
1286
1287 assert_eq!(batch1, batch2);
1288 }
1289
1290 #[test]
1292 fn record_batch_index_access() {
1293 let id_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
1294 let val_arr = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
1295 let schema1 = Schema::new(vec![
1296 Field::new("id", DataType::Int32, false),
1297 Field::new("val", DataType::Int32, false),
1298 ]);
1299 let record_batch =
1300 RecordBatch::try_new(Arc::new(schema1), vec![id_arr.clone(), val_arr.clone()]).unwrap();
1301
1302 assert_eq!(record_batch["id"].as_ref(), id_arr.as_ref());
1303 assert_eq!(record_batch["val"].as_ref(), val_arr.as_ref());
1304 }
1305
1306 #[test]
1307 fn record_batch_vals_ne() {
1308 let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1309 let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1310 let schema1 = Schema::new(vec![
1311 Field::new("id", DataType::Int32, false),
1312 Field::new("val", DataType::Int32, false),
1313 ]);
1314
1315 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1316 let val_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1317 let schema2 = Schema::new(vec![
1318 Field::new("id", DataType::Int32, false),
1319 Field::new("val", DataType::Int32, false),
1320 ]);
1321
1322 let batch1 = RecordBatch::try_new(
1323 Arc::new(schema1),
1324 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1325 )
1326 .unwrap();
1327
1328 let batch2 = RecordBatch::try_new(
1329 Arc::new(schema2),
1330 vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1331 )
1332 .unwrap();
1333
1334 assert_ne!(batch1, batch2);
1335 }
1336
1337 #[test]
1338 fn record_batch_column_names_ne() {
1339 let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1340 let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1341 let schema1 = Schema::new(vec![
1342 Field::new("id", DataType::Int32, false),
1343 Field::new("val", DataType::Int32, false),
1344 ]);
1345
1346 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1347 let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1348 let schema2 = Schema::new(vec![
1349 Field::new("id", DataType::Int32, false),
1350 Field::new("num", DataType::Int32, false),
1351 ]);
1352
1353 let batch1 = RecordBatch::try_new(
1354 Arc::new(schema1),
1355 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1356 )
1357 .unwrap();
1358
1359 let batch2 = RecordBatch::try_new(
1360 Arc::new(schema2),
1361 vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1362 )
1363 .unwrap();
1364
1365 assert_ne!(batch1, batch2);
1366 }
1367
1368 #[test]
1369 fn record_batch_column_number_ne() {
1370 let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1371 let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1372 let schema1 = Schema::new(vec![
1373 Field::new("id", DataType::Int32, false),
1374 Field::new("val", DataType::Int32, false),
1375 ]);
1376
1377 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1378 let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1379 let num_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1380 let schema2 = Schema::new(vec![
1381 Field::new("id", DataType::Int32, false),
1382 Field::new("val", DataType::Int32, false),
1383 Field::new("num", DataType::Int32, false),
1384 ]);
1385
1386 let batch1 = RecordBatch::try_new(
1387 Arc::new(schema1),
1388 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1389 )
1390 .unwrap();
1391
1392 let batch2 = RecordBatch::try_new(
1393 Arc::new(schema2),
1394 vec![Arc::new(id_arr2), Arc::new(val_arr2), Arc::new(num_arr2)],
1395 )
1396 .unwrap();
1397
1398 assert_ne!(batch1, batch2);
1399 }
1400
1401 #[test]
1402 fn record_batch_row_count_ne() {
1403 let id_arr1 = Int32Array::from(vec![1, 2, 3]);
1404 let val_arr1 = Int32Array::from(vec![5, 6, 7]);
1405 let schema1 = Schema::new(vec![
1406 Field::new("id", DataType::Int32, false),
1407 Field::new("val", DataType::Int32, false),
1408 ]);
1409
1410 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1411 let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1412 let schema2 = Schema::new(vec![
1413 Field::new("id", DataType::Int32, false),
1414 Field::new("num", DataType::Int32, false),
1415 ]);
1416
1417 let batch1 = RecordBatch::try_new(
1418 Arc::new(schema1),
1419 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1420 )
1421 .unwrap();
1422
1423 let batch2 = RecordBatch::try_new(
1424 Arc::new(schema2),
1425 vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1426 )
1427 .unwrap();
1428
1429 assert_ne!(batch1, batch2);
1430 }
1431
1432 #[test]
1433 fn normalize_simple() {
1434 let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
1435 let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
1436 let year: ArrayRef = Arc::new(Int64Array::from(vec![None, Some(2022)]));
1437
1438 let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
1439 let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
1440 let year_field = Arc::new(Field::new("year", DataType::Int64, true));
1441
1442 let a = Arc::new(StructArray::from(vec![
1443 (animals_field.clone(), Arc::new(animals.clone()) as ArrayRef),
1444 (n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef),
1445 (year_field.clone(), Arc::new(year.clone()) as ArrayRef),
1446 ]));
1447
1448 let month = Arc::new(Int64Array::from(vec![Some(4), Some(6)]));
1449
1450 let schema = Schema::new(vec![
1451 Field::new(
1452 "a",
1453 DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
1454 false,
1455 ),
1456 Field::new("month", DataType::Int64, true),
1457 ]);
1458
1459 let normalized =
1460 RecordBatch::try_new(Arc::new(schema.clone()), vec![a.clone(), month.clone()])
1461 .expect("valid conversion")
1462 .normalize(".", Some(0))
1463 .expect("valid normalization");
1464
1465 let expected = RecordBatch::try_from_iter_with_nullable(vec![
1466 ("a.animals", animals.clone(), true),
1467 ("a.n_legs", n_legs.clone(), true),
1468 ("a.year", year.clone(), true),
1469 ("month", month.clone(), true),
1470 ])
1471 .expect("valid conversion");
1472
1473 assert_eq!(expected, normalized);
1474
1475 let normalized = RecordBatch::try_new(Arc::new(schema), vec![a, month.clone()])
1477 .expect("valid conversion")
1478 .normalize(".", None)
1479 .expect("valid normalization");
1480
1481 assert_eq!(expected, normalized);
1482 }
1483
1484 #[test]
1485 fn normalize_nested() {
1486 let a = Arc::new(Field::new("a", DataType::Int64, true));
1488 let b = Arc::new(Field::new("b", DataType::Int64, false));
1489 let c = Arc::new(Field::new("c", DataType::Int64, true));
1490
1491 let one = Arc::new(Field::new(
1492 "1",
1493 DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
1494 false,
1495 ));
1496 let two = Arc::new(Field::new(
1497 "2",
1498 DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
1499 true,
1500 ));
1501
1502 let exclamation = Arc::new(Field::new(
1503 "!",
1504 DataType::Struct(Fields::from(vec![one.clone(), two.clone()])),
1505 false,
1506 ));
1507
1508 let schema = Schema::new(vec![exclamation.clone()]);
1509
1510 let a_field = Int64Array::from(vec![Some(0), Some(1)]);
1512 let b_field = Int64Array::from(vec![Some(2), Some(3)]);
1513 let c_field = Int64Array::from(vec![None, Some(4)]);
1514
1515 let one_field = StructArray::from(vec![
1516 (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1517 (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1518 (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1519 ]);
1520 let two_field = StructArray::from(vec![
1521 (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1522 (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1523 (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1524 ]);
1525
1526 let exclamation_field = Arc::new(StructArray::from(vec![
1527 (one.clone(), Arc::new(one_field) as ArrayRef),
1528 (two.clone(), Arc::new(two_field) as ArrayRef),
1529 ]));
1530
1531 let normalized =
1533 RecordBatch::try_new(Arc::new(schema.clone()), vec![exclamation_field.clone()])
1534 .expect("valid conversion")
1535 .normalize(".", Some(1))
1536 .expect("valid normalization");
1537
1538 let expected = RecordBatch::try_from_iter_with_nullable(vec![
1539 (
1540 "!.1",
1541 Arc::new(StructArray::from(vec![
1542 (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1543 (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1544 (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1545 ])) as ArrayRef,
1546 false,
1547 ),
1548 (
1549 "!.2",
1550 Arc::new(StructArray::from(vec![
1551 (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1552 (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1553 (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1554 ])) as ArrayRef,
1555 true,
1556 ),
1557 ])
1558 .expect("valid conversion");
1559
1560 assert_eq!(expected, normalized);
1561
1562 let normalized = RecordBatch::try_new(Arc::new(schema), vec![exclamation_field])
1564 .expect("valid conversion")
1565 .normalize(".", None)
1566 .expect("valid normalization");
1567
1568 let expected = RecordBatch::try_from_iter_with_nullable(vec![
1569 ("!.1.a", Arc::new(a_field.clone()) as ArrayRef, true),
1570 ("!.1.b", Arc::new(b_field.clone()) as ArrayRef, false),
1571 ("!.1.c", Arc::new(c_field.clone()) as ArrayRef, true),
1572 ("!.2.a", Arc::new(a_field.clone()) as ArrayRef, true),
1573 ("!.2.b", Arc::new(b_field.clone()) as ArrayRef, false),
1574 ("!.2.c", Arc::new(c_field.clone()) as ArrayRef, true),
1575 ])
1576 .expect("valid conversion");
1577
1578 assert_eq!(expected, normalized);
1579 }
1580
1581 #[test]
1582 fn normalize_empty() {
1583 let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
1584 let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
1585 let year_field = Arc::new(Field::new("year", DataType::Int64, true));
1586
1587 let schema = Schema::new(vec![
1588 Field::new(
1589 "a",
1590 DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
1591 false,
1592 ),
1593 Field::new("month", DataType::Int64, true),
1594 ]);
1595
1596 let normalized = RecordBatch::new_empty(Arc::new(schema.clone()))
1597 .normalize(".", Some(0))
1598 .expect("valid normalization");
1599
1600 let expected = RecordBatch::new_empty(Arc::new(
1601 schema.normalize(".", Some(0)).expect("valid normalization"),
1602 ));
1603
1604 assert_eq!(expected, normalized);
1605 }
1606
1607 #[test]
1608 fn project() {
1609 let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
1610 let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
1611 let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));
1612
1613 let record_batch =
1614 RecordBatch::try_from_iter(vec![("a", a.clone()), ("b", b.clone()), ("c", c.clone())])
1615 .expect("valid conversion");
1616
1617 let expected =
1618 RecordBatch::try_from_iter(vec![("a", a), ("c", c)]).expect("valid conversion");
1619
1620 assert_eq!(expected, record_batch.project(&[0, 2]).unwrap());
1621 }
1622
1623 #[test]
1624 fn project_empty() {
1625 let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));
1626
1627 let record_batch =
1628 RecordBatch::try_from_iter(vec![("c", c.clone())]).expect("valid conversion");
1629
1630 let expected = RecordBatch::try_new_with_options(
1631 Arc::new(Schema::empty()),
1632 vec![],
1633 &RecordBatchOptions {
1634 match_field_names: true,
1635 row_count: Some(3),
1636 },
1637 )
1638 .expect("valid conversion");
1639
1640 assert_eq!(expected, record_batch.project(&[]).unwrap());
1641 }
1642
1643 #[test]
1644 fn test_no_column_record_batch() {
1645 let schema = Arc::new(Schema::empty());
1646
1647 let err = RecordBatch::try_new(schema.clone(), vec![]).unwrap_err();
1648 assert!(
1649 err.to_string()
1650 .contains("must either specify a row count or at least one column")
1651 );
1652
1653 let options = RecordBatchOptions::new().with_row_count(Some(10));
1654
1655 let ok = RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap();
1656 assert_eq!(ok.num_rows(), 10);
1657
1658 let a = ok.slice(2, 5);
1659 assert_eq!(a.num_rows(), 5);
1660
1661 let b = ok.slice(5, 0);
1662 assert_eq!(b.num_rows(), 0);
1663
1664 assert_ne!(a, b);
1665 assert_eq!(b, RecordBatch::new_empty(schema))
1666 }
1667
1668 #[test]
1669 fn test_nulls_in_non_nullable_field() {
1670 let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
1671 let maybe_batch = RecordBatch::try_new(
1672 schema,
1673 vec![Arc::new(Int32Array::from(vec![Some(1), None]))],
1674 );
1675 assert_eq!(
1676 "Invalid argument error: Column 'a' is declared as non-nullable but contains null values",
1677 format!("{}", maybe_batch.err().unwrap())
1678 );
1679 }
1680 #[test]
1681 fn test_record_batch_options() {
1682 let options = RecordBatchOptions::new()
1683 .with_match_field_names(false)
1684 .with_row_count(Some(20));
1685 assert!(!options.match_field_names);
1686 assert_eq!(options.row_count.unwrap(), 20)
1687 }
1688
1689 #[test]
1690 #[should_panic(expected = "Cannot convert nullable StructArray to RecordBatch")]
1691 fn test_from_struct() {
1692 let s = StructArray::from(ArrayData::new_null(
1693 &DataType::Struct(vec![Field::new("foo", DataType::Int32, false)].into()),
1695 2,
1696 ));
1697 let _ = RecordBatch::from(s);
1698 }
1699
1700 #[test]
1701 fn test_with_schema() {
1702 let required_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1703 let required_schema = Arc::new(required_schema);
1704 let nullable_schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
1705 let nullable_schema = Arc::new(nullable_schema);
1706
1707 let batch = RecordBatch::try_new(
1708 required_schema.clone(),
1709 vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as _],
1710 )
1711 .unwrap();
1712
1713 let batch = batch.with_schema(nullable_schema.clone()).unwrap();
1715
1716 batch.clone().with_schema(required_schema).unwrap_err();
1718
1719 let metadata = vec![("foo".to_string(), "bar".to_string())]
1721 .into_iter()
1722 .collect();
1723 let metadata_schema = nullable_schema.as_ref().clone().with_metadata(metadata);
1724 let batch = batch.with_schema(Arc::new(metadata_schema)).unwrap();
1725
1726 batch.with_schema(nullable_schema).unwrap_err();
1728 }
1729
1730 #[test]
1731 fn test_boxed_reader() {
1732 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1735 let schema = Arc::new(schema);
1736
1737 let reader = RecordBatchIterator::new(std::iter::empty(), schema);
1738 let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
1739
1740 fn get_size(reader: impl RecordBatchReader) -> usize {
1741 reader.size_hint().0
1742 }
1743
1744 let size = get_size(reader);
1745 assert_eq!(size, 0);
1746 }
1747
1748 #[test]
1749 fn test_remove_column_maintains_schema_metadata() {
1750 let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
1751 let bool_array = BooleanArray::from(vec![true, false, false, true, true]);
1752
1753 let mut metadata = HashMap::new();
1754 metadata.insert("foo".to_string(), "bar".to_string());
1755 let schema = Schema::new(vec![
1756 Field::new("id", DataType::Int32, false),
1757 Field::new("bool", DataType::Boolean, false),
1758 ])
1759 .with_metadata(metadata);
1760
1761 let mut batch = RecordBatch::try_new(
1762 Arc::new(schema),
1763 vec![Arc::new(id_array), Arc::new(bool_array)],
1764 )
1765 .unwrap();
1766
1767 let _removed_column = batch.remove_column(0);
1768 assert_eq!(batch.schema().metadata().len(), 1);
1769 assert_eq!(
1770 batch.schema().metadata().get("foo").unwrap().as_str(),
1771 "bar"
1772 );
1773 }
1774}