1use crate::cast::AsArray;
22use crate::{Array, ArrayRef, StructArray, new_empty_array};
23use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef};
24use std::ops::Index;
25use std::sync::Arc;
26
/// Trait for types that can read `RecordBatch`'s, yielding them as a
/// fallible iterator alongside a known [`SchemaRef`].
pub trait RecordBatchReader: Iterator<Item = Result<RecordBatch, ArrowError>> {
    /// Returns the schema of the record batches yielded by this reader.
    fn schema(&self) -> SchemaRef;
}
37
38impl<R: RecordBatchReader + ?Sized> RecordBatchReader for Box<R> {
39 fn schema(&self) -> SchemaRef {
40 self.as_ref().schema()
41 }
42}
43
/// Trait for types that can write `RecordBatch`'s.
pub trait RecordBatchWriter {
    /// Writes a single batch to the writer.
    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError>;

    /// Writes any footer or termination data and finalizes the writer.
    /// Consumes `self`, so the writer cannot be used afterwards.
    fn close(self) -> Result<(), ArrowError>;
}
52
/// Creates a typed array from a `DataType`-like identifier and a list of
/// values, e.g. `create_array!(Int32, [1, 2, 3])`.
///
/// Returns an `Arc`-wrapped concrete array. `Null` takes a length instead
/// of a value list; `Binary`/`LargeBinary` take byte-slice values.
#[macro_export]
macro_rules! create_array {
    // `@from` arms map a type identifier to its concrete array type.
    (@from Boolean) => { $crate::BooleanArray };
    (@from Int8) => { $crate::Int8Array };
    (@from Int16) => { $crate::Int16Array };
    (@from Int32) => { $crate::Int32Array };
    (@from Int64) => { $crate::Int64Array };
    (@from UInt8) => { $crate::UInt8Array };
    (@from UInt16) => { $crate::UInt16Array };
    (@from UInt32) => { $crate::UInt32Array };
    (@from UInt64) => { $crate::UInt64Array };
    (@from Float16) => { $crate::Float16Array };
    (@from Float32) => { $crate::Float32Array };
    (@from Float64) => { $crate::Float64Array };
    (@from Utf8) => { $crate::StringArray };
    (@from Utf8View) => { $crate::StringViewArray };
    (@from LargeUtf8) => { $crate::LargeStringArray };
    (@from IntervalDayTime) => { $crate::IntervalDayTimeArray };
    (@from IntervalYearMonth) => { $crate::IntervalYearMonthArray };
    (@from Second) => { $crate::TimestampSecondArray };
    (@from Millisecond) => { $crate::TimestampMillisecondArray };
    (@from Microsecond) => { $crate::TimestampMicrosecondArray };
    (@from Nanosecond) => { $crate::TimestampNanosecondArray };
    (@from Second32) => { $crate::Time32SecondArray };
    (@from Millisecond32) => { $crate::Time32MillisecondArray };
    (@from Microsecond64) => { $crate::Time64MicrosecondArray };
    // BUG FIX: previously expanded to `Time64Nanosecond64Array`, which
    // does not exist; the concrete type is `Time64NanosecondArray`.
    (@from Nanosecond64) => { $crate::Time64NanosecondArray };
    (@from DurationSecond) => { $crate::DurationSecondArray };
    (@from DurationMillisecond) => { $crate::DurationMillisecondArray };
    (@from DurationMicrosecond) => { $crate::DurationMicrosecondArray };
    (@from DurationNanosecond) => { $crate::DurationNanosecondArray };
    (@from Decimal32) => { $crate::Decimal32Array };
    (@from Decimal64) => { $crate::Decimal64Array };
    (@from Decimal128) => { $crate::Decimal128Array };
    (@from Decimal256) => { $crate::Decimal256Array };
    (@from TimestampSecond) => { $crate::TimestampSecondArray };
    (@from TimestampMillisecond) => { $crate::TimestampMillisecondArray };
    (@from TimestampMicrosecond) => { $crate::TimestampMicrosecondArray };
    (@from TimestampNanosecond) => { $crate::TimestampNanosecondArray };

    // Any other identifier is a compile-time error with a helpful message.
    (@from $ty: ident) => {
        compile_error!(concat!("Unsupported data type: ", stringify!($ty)))
    };

    // `Null` arrays only carry a length, not values.
    (Null, $size: expr) => {
        std::sync::Arc::new($crate::NullArray::new($size))
    };

    // Binary arrays are built from byte slices via `from_vec`.
    (Binary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::BinaryArray::from_vec(vec![$($values),*]))
    };

    (LargeBinary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::LargeBinaryArray::from_vec(vec![$($values),*]))
    };

    // General case: resolve the array type via `@from` and build it with
    // that type's `From<Vec<_>>` implementation.
    ($ty: tt, [$($values: expr),*]) => {
        std::sync::Arc::new(<$crate::create_array!(@from $ty)>::from(vec![$($values),*]))
    };
}
139
/// Creates a [`RecordBatch`] from `(name, type, [values...])` triples,
/// e.g. `record_batch!(("a", Int32, [1, 2, 3]))`.
///
/// Evaluates to `Result<RecordBatch, ArrowError>`; columns are built with
/// [`create_array!`] and every field is declared nullable.
#[macro_export]
macro_rules! record_batch {
    ($(($name: expr, $type: ident, [$($values: expr),*])),*) => {
        {
            // One nullable field per `(name, type, values)` triple.
            let schema = std::sync::Arc::new(arrow_schema::Schema::new(vec![
                $(
                    arrow_schema::Field::new($name, arrow_schema::DataType::$type, true),
                )*
            ]));

            // Build each column and let `try_new` validate the result.
            let batch = $crate::RecordBatch::try_new(
                schema,
                vec![$(
                    $crate::create_array!($type, [$($values),*]),
                )*]
            );

            batch
        }
    }
}
177
/// A two-dimensional batch of column-oriented data: a set of equal-length
/// [`Array`] columns described by a [`SchemaRef`], plus an explicit row
/// count so that batches with zero columns remain well-defined.
#[derive(Clone, Debug, PartialEq)]
pub struct RecordBatch {
    // Schema describing the columns; shared via `Arc`.
    schema: SchemaRef,
    // The columns; invariant: each has exactly `row_count` rows.
    columns: Vec<Arc<dyn Array>>,

    // Stored explicitly rather than derived from `columns`, since a batch
    // may legitimately have no columns.
    row_count: usize,
}
211
212impl RecordBatch {
213 pub fn try_new(schema: SchemaRef, columns: Vec<ArrayRef>) -> Result<Self, ArrowError> {
242 let options = RecordBatchOptions::new();
243 Self::try_new_impl(schema, columns, &options)
244 }
245
    /// Creates a `RecordBatch` without validating the inputs.
    ///
    /// # Safety
    ///
    /// The caller must uphold the invariants normally checked by
    /// [`RecordBatch::try_new`]: every array in `columns` must have exactly
    /// `row_count` rows and a data type matching the corresponding field in
    /// `schema` (including nullability).
    pub unsafe fn new_unchecked(
        schema: SchemaRef,
        columns: Vec<Arc<dyn Array>>,
        row_count: usize,
    ) -> Self {
        Self {
            schema,
            columns,
            row_count,
        }
    }
272
    /// Creates a `RecordBatch` from a schema and columns, with
    /// [`RecordBatchOptions`] controlling validation (explicit row count,
    /// whether nested field names must match).
    pub fn try_new_with_options(
        schema: SchemaRef,
        columns: Vec<ArrayRef>,
        options: &RecordBatchOptions,
    ) -> Result<Self, ArrowError> {
        Self::try_new_impl(schema, columns, options)
    }
284
285 pub fn new_empty(schema: SchemaRef) -> Self {
287 let columns = schema
288 .fields()
289 .iter()
290 .map(|field| new_empty_array(field.data_type()))
291 .collect();
292
293 RecordBatch {
294 schema,
295 columns,
296 row_count: 0,
297 }
298 }
299
    /// Validates a schema/columns combination and constructs the batch.
    ///
    /// Checks, in order: the column count matches the field count; a row
    /// count is available (explicit in `options` or taken from the first
    /// column); non-nullable fields contain no nulls; every column has the
    /// row count; and each column's data type matches its field (strictly,
    /// or ignoring nested field names when `options.match_field_names` is
    /// `false`).
    fn try_new_impl(
        schema: SchemaRef,
        columns: Vec<ArrayRef>,
        options: &RecordBatchOptions,
    ) -> Result<Self, ArrowError> {
        // Column count must agree with the schema's field count.
        if schema.fields().len() != columns.len() {
            return Err(ArrowError::InvalidArgumentError(format!(
                "number of columns({}) must match number of fields({}) in schema",
                columns.len(),
                schema.fields().len(),
            )));
        }

        // An explicit row count wins; otherwise infer it from the first
        // column. A zero-column batch must specify one explicitly.
        let row_count = options
            .row_count
            .or_else(|| columns.first().map(|col| col.len()))
            .ok_or_else(|| {
                ArrowError::InvalidArgumentError(
                    "must either specify a row count or at least one column".to_string(),
                )
            })?;

        // Reject nulls in columns whose field is declared non-nullable.
        for (c, f) in columns.iter().zip(&schema.fields) {
            if !f.is_nullable() && c.null_count() > 0 {
                return Err(ArrowError::InvalidArgumentError(format!(
                    "Column '{}' is declared as non-nullable but contains null values",
                    f.name()
                )));
            }
        }

        // Every column must have exactly `row_count` rows; the error message
        // differs depending on whether the count was explicit or inferred.
        if columns.iter().any(|c| c.len() != row_count) {
            let err = match options.row_count {
                Some(_) => "all columns in a record batch must have the specified row count",
                None => "all columns in a record batch must have the same length",
            };
            return Err(ArrowError::InvalidArgumentError(err.to_string()));
        }

        // Pick the type comparison: strict equality, or structural equality
        // that ignores nested field names.
        let type_not_match = if options.match_field_names {
            |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| col_type != field_type
        } else {
            |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| {
                !col_type.equals_datatype(field_type)
            }
        };

        // Find the first column whose type disagrees with its field.
        let not_match = columns
            .iter()
            .zip(schema.fields().iter())
            .map(|(col, field)| (col.data_type(), field.data_type()))
            .enumerate()
            .find(type_not_match);

        if let Some((i, (col_type, field_type))) = not_match {
            return Err(ArrowError::InvalidArgumentError(format!(
                "column types must match schema types, expected {field_type} but found {col_type} at column index {i}"
            )));
        }

        Ok(RecordBatch {
            schema,
            columns,
            row_count,
        })
    }
373
374 pub fn into_parts(self) -> (SchemaRef, Vec<ArrayRef>, usize) {
376 (self.schema, self.columns, self.row_count)
377 }
378
379 pub fn with_schema(self, schema: SchemaRef) -> Result<Self, ArrowError> {
386 if !schema.contains(self.schema.as_ref()) {
387 return Err(ArrowError::SchemaError(format!(
388 "target schema is not superset of current schema target={schema} current={}",
389 self.schema
390 )));
391 }
392
393 Ok(Self {
394 schema,
395 columns: self.columns,
396 row_count: self.row_count,
397 })
398 }
399
400 pub fn schema(&self) -> SchemaRef {
402 self.schema.clone()
403 }
404
    /// Returns a borrowed reference to the schema, avoiding the `Arc` clone
    /// performed by [`RecordBatch::schema`].
    pub fn schema_ref(&self) -> &SchemaRef {
        &self.schema
    }
409
    /// Returns a mutable reference to the schema's metadata map.
    ///
    /// Uses [`Arc::make_mut`]: if the schema `Arc` is shared with other
    /// batches, the underlying `Schema` is cloned first (copy-on-write), so
    /// mutations here never affect other holders of the schema.
    pub fn schema_metadata_mut(&mut self) -> &mut std::collections::HashMap<String, String> {
        let schema = Arc::make_mut(&mut self.schema);
        &mut schema.metadata
    }
431
    /// Projects the schema and columns onto the specified column `indices`,
    /// in the given order.
    ///
    /// Returns a [`ArrowError::SchemaError`] if any index is out of bounds.
    pub fn project(&self, indices: &[usize]) -> Result<RecordBatch, ArrowError> {
        let projected_schema = self.schema.project(indices)?;
        let batch_fields = indices
            .iter()
            .map(|f| {
                // Checked lookup so a bad index yields an error, not a panic.
                self.columns.get(*f).cloned().ok_or_else(|| {
                    ArrowError::SchemaError(format!(
                        "project index {} out of bounds, max field {}",
                        f,
                        self.columns.len()
                    ))
                })
            })
            .collect::<Result<Vec<_>, _>>()?;

        // SAFETY: the selected columns come from `self`, which upholds the
        // RecordBatch invariants, and `projected_schema` selects the matching
        // fields in the same order, so lengths and types still agree.
        unsafe {
            Ok(RecordBatch::new_unchecked(
                SchemaRef::new(projected_schema),
                batch_fields,
                self.row_count,
            ))
        }
    }
459
    /// Flattens nested [`DataType::Struct`] columns into top-level columns,
    /// joining nested field names with `separator` (e.g. `"a.b"`).
    ///
    /// `max_level` limits how many levels of nesting are flattened;
    /// `None` or `Some(0)` both mean "flatten all levels".
    pub fn normalize(&self, separator: &str, max_level: Option<usize>) -> Result<Self, ArrowError> {
        // 0 is treated the same as `None`: no depth limit.
        let max_level = match max_level.unwrap_or(usize::MAX) {
            0 => usize::MAX,
            val => val,
        };
        // Depth-first traversal. Each stack entry is (depth, column,
        // accumulated name parts, field). Entries are pushed in reverse so
        // popping emits columns in their original left-to-right order.
        let mut stack: Vec<(usize, &ArrayRef, Vec<&str>, &FieldRef)> = self
            .columns
            .iter()
            .zip(self.schema.fields())
            .rev()
            .map(|(c, f)| {
                let name_vec: Vec<&str> = vec![f.name()];
                (0, c, name_vec, f)
            })
            .collect();
        let mut columns: Vec<ArrayRef> = Vec::new();
        let mut fields: Vec<FieldRef> = Vec::new();

        while let Some((depth, c, name, field_ref)) = stack.pop() {
            match field_ref.data_type() {
                // Struct within the depth limit: expand its children,
                // extending the accumulated name with the separator.
                DataType::Struct(ff) if depth < max_level => {
                    for (cff, fff) in c.as_struct().columns().iter().zip(ff.into_iter()).rev() {
                        let mut name = name.clone();
                        name.push(separator);
                        name.push(fff.name());
                        stack.push((depth + 1, cff, name, fff))
                    }
                }
                // Leaf column (or depth limit reached): emit it at the top
                // level under its joined name.
                _ => {
                    let updated_field = Field::new(
                        name.concat(),
                        field_ref.data_type().clone(),
                        field_ref.is_nullable(),
                    );
                    columns.push(c.clone());
                    fields.push(Arc::new(updated_field));
                }
            }
        }
        // Re-validate through `try_new` rather than constructing directly.
        RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
    }
561
    /// Returns the number of columns in the record batch.
    pub fn num_columns(&self) -> usize {
        self.columns.len()
    }
583
    /// Returns the number of rows in each column of the record batch.
    pub fn num_rows(&self) -> usize {
        self.row_count
    }
605
    /// Gets a reference to the column at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of bounds.
    pub fn column(&self, index: usize) -> &ArrayRef {
        &self.columns[index]
    }
614
615 pub fn column_by_name(&self, name: &str) -> Option<&ArrayRef> {
617 self.schema()
618 .column_with_name(name)
619 .map(|(index, _)| &self.columns[index])
620 }
621
622 pub fn columns(&self) -> &[ArrayRef] {
624 &self.columns[..]
625 }
626
    /// Removes the column at `index` and returns it, also removing the
    /// corresponding field from the schema.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of bounds.
    pub fn remove_column(&mut self, index: usize) -> ArrayRef {
        // Rebuild the schema without the removed field; `SchemaBuilder`
        // carries over the remaining fields and schema metadata.
        let mut builder = SchemaBuilder::from(self.schema.as_ref());
        builder.remove(index);
        self.schema = Arc::new(builder.finish());
        self.columns.remove(index)
    }
660
    /// Returns a new `RecordBatch` of `length` rows starting at `offset`,
    /// sharing the underlying array data (zero-copy slicing).
    ///
    /// # Panics
    ///
    /// Panics if `offset + length` is greater than `self.num_rows()`.
    pub fn slice(&self, offset: usize, length: usize) -> RecordBatch {
        assert!((offset + length) <= self.num_rows());

        let columns = self
            .columns()
            .iter()
            .map(|column| column.slice(offset, length))
            .collect();

        Self {
            schema: self.schema.clone(),
            columns,
            // The slice's row count is exactly `length`.
            row_count: length,
        }
    }
682
    /// Creates a `RecordBatch` from an iterator of `(field_name, array)`
    /// pairs, deriving each field's data type from its array.
    ///
    /// Nullability is inferred from the data: a field is marked nullable
    /// only if its array currently contains at least one null, so an
    /// all-valid array produces a non-nullable field. Use
    /// [`RecordBatch::try_from_iter_with_nullable`] to set nullability
    /// explicitly.
    pub fn try_from_iter<I, F>(value: I) -> Result<Self, ArrowError>
    where
        I: IntoIterator<Item = (F, ArrayRef)>,
        F: AsRef<str>,
    {
        let iter = value.into_iter().map(|(field_name, array)| {
            // Nullable iff the array actually holds a null right now.
            let nullable = array.null_count() > 0;
            (field_name, array, nullable)
        });

        Self::try_from_iter_with_nullable(iter)
    }
734
735 pub fn try_from_iter_with_nullable<I, F>(value: I) -> Result<Self, ArrowError>
757 where
758 I: IntoIterator<Item = (F, ArrayRef, bool)>,
759 F: AsRef<str>,
760 {
761 let iter = value.into_iter();
762 let capacity = iter.size_hint().0;
763 let mut schema = SchemaBuilder::with_capacity(capacity);
764 let mut columns = Vec::with_capacity(capacity);
765
766 for (field_name, array, nullable) in iter {
767 let field_name = field_name.as_ref();
768 schema.push(Field::new(field_name, array.data_type().clone(), nullable));
769 columns.push(array);
770 }
771
772 let schema = Arc::new(schema.finish());
773 RecordBatch::try_new(schema, columns)
774 }
775
776 pub fn get_array_memory_size(&self) -> usize {
783 self.columns()
784 .iter()
785 .map(|array| array.get_array_memory_size())
786 .sum()
787 }
788}
789
/// Options that control how [`RecordBatch::try_new_with_options`] validates
/// its inputs.
#[derive(Debug)]
#[non_exhaustive]
pub struct RecordBatchOptions {
    /// When `true` (the default), column types must equal their schema
    /// field types exactly, including nested field names; when `false`,
    /// only structural (name-ignoring) equality is enforced.
    pub match_field_names: bool,

    /// Optional explicit row count. Required for a batch with zero
    /// columns; otherwise the count is inferred from the first column.
    pub row_count: Option<usize>,
}
800
801impl RecordBatchOptions {
802 pub fn new() -> Self {
804 Self {
805 match_field_names: true,
806 row_count: None,
807 }
808 }
809 pub fn with_row_count(mut self, row_count: Option<usize>) -> Self {
811 self.row_count = row_count;
812 self
813 }
814 pub fn with_match_field_names(mut self, match_field_names: bool) -> Self {
816 self.match_field_names = match_field_names;
817 self
818 }
819}
impl Default for RecordBatchOptions {
    fn default() -> Self {
        // Same defaults as `RecordBatchOptions::new`.
        Self::new()
    }
}
impl From<StructArray> for RecordBatch {
    /// Converts a [`StructArray`] into a `RecordBatch`, with one column per
    /// struct field.
    ///
    /// # Panics
    ///
    /// Panics if the struct array has any top-level nulls, since a
    /// `RecordBatch` cannot represent them.
    fn from(value: StructArray) -> Self {
        // Capture the length before `into_parts` consumes the array.
        let row_count = value.len();
        let (fields, columns, nulls) = value.into_parts();
        assert_eq!(
            nulls.map(|n| n.null_count()).unwrap_or_default(),
            0,
            "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"
        );

        RecordBatch {
            schema: Arc::new(Schema::new(fields)),
            row_count,
            columns,
        }
    }
}
842
impl From<&StructArray> for RecordBatch {
    /// Converts a borrowed [`StructArray`] by cloning it first; see
    /// `From<StructArray>` for the panic conditions.
    fn from(struct_array: &StructArray) -> Self {
        struct_array.clone().into()
    }
}
848
impl Index<&str> for RecordBatch {
    type Output = ArrayRef;

    /// Gets a reference to a column's array by name, e.g. `batch["col"]`.
    ///
    /// # Panics
    ///
    /// Panics if the name is not in the schema.
    fn index(&self, name: &str) -> &Self::Output {
        self.column_by_name(name).unwrap()
    }
}
861
/// Adapts any iterator of `Result<RecordBatch, ArrowError>` plus a
/// [`SchemaRef`] into a [`RecordBatchReader`].
pub struct RecordBatchIterator<I>
where
    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
{
    // The wrapped batch iterator.
    inner: I::IntoIter,
    // Schema reported by `RecordBatchReader::schema`; not validated
    // against the yielded batches.
    inner_schema: SchemaRef,
}
894
impl<I> RecordBatchIterator<I>
where
    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
{
    /// Creates a new [`RecordBatchIterator`] over `iter` reporting `schema`.
    ///
    /// NOTE(review): the yielded batches are not checked against `schema`;
    /// the caller is responsible for ensuring they match.
    pub fn new(iter: I, schema: SchemaRef) -> Self {
        Self {
            inner: iter.into_iter(),
            inner_schema: schema,
        }
    }
}
909
impl<I> Iterator for RecordBatchIterator<I>
where
    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
{
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        // Forward to the wrapped iterator.
        self.inner.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Preserve the inner iterator's size hint.
        self.inner.size_hint()
    }
}
924
impl<I> RecordBatchReader for RecordBatchIterator<I>
where
    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
{
    fn schema(&self) -> SchemaRef {
        // Report the schema supplied at construction time (cheap Arc clone).
        self.inner_schema.clone()
    }
}
933
934#[cfg(test)]
935mod tests {
936 use super::*;
937 use crate::{
938 BooleanArray, Int8Array, Int32Array, Int64Array, ListArray, StringArray, StringViewArray,
939 };
940 use arrow_buffer::{Buffer, ToByteSlice};
941 use arrow_data::{ArrayData, ArrayDataBuilder};
942 use arrow_schema::Fields;
943 use std::collections::HashMap;
944
945 #[test]
946 fn create_record_batch() {
947 let schema = Schema::new(vec![
948 Field::new("a", DataType::Int32, false),
949 Field::new("b", DataType::Utf8, false),
950 ]);
951
952 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
953 let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);
954
955 let record_batch =
956 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
957 check_batch(record_batch, 5)
958 }
959
960 #[test]
961 fn create_string_view_record_batch() {
962 let schema = Schema::new(vec![
963 Field::new("a", DataType::Int32, false),
964 Field::new("b", DataType::Utf8View, false),
965 ]);
966
967 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
968 let b = StringViewArray::from(vec!["a", "b", "c", "d", "e"]);
969
970 let record_batch =
971 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
972
973 assert_eq!(5, record_batch.num_rows());
974 assert_eq!(2, record_batch.num_columns());
975 assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
976 assert_eq!(
977 &DataType::Utf8View,
978 record_batch.schema().field(1).data_type()
979 );
980 assert_eq!(5, record_batch.column(0).len());
981 assert_eq!(5, record_batch.column(1).len());
982 }
983
984 #[test]
985 fn byte_size_should_not_regress() {
986 let schema = Schema::new(vec![
987 Field::new("a", DataType::Int32, false),
988 Field::new("b", DataType::Utf8, false),
989 ]);
990
991 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
992 let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);
993
994 let record_batch =
995 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
996 assert_eq!(record_batch.get_array_memory_size(), 364);
997 }
998
999 fn check_batch(record_batch: RecordBatch, num_rows: usize) {
1000 assert_eq!(num_rows, record_batch.num_rows());
1001 assert_eq!(2, record_batch.num_columns());
1002 assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
1003 assert_eq!(&DataType::Utf8, record_batch.schema().field(1).data_type());
1004 assert_eq!(num_rows, record_batch.column(0).len());
1005 assert_eq!(num_rows, record_batch.column(1).len());
1006 }
1007
1008 #[test]
1009 #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
1010 fn create_record_batch_slice() {
1011 let schema = Schema::new(vec![
1012 Field::new("a", DataType::Int32, false),
1013 Field::new("b", DataType::Utf8, false),
1014 ]);
1015 let expected_schema = schema.clone();
1016
1017 let a = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
1018 let b = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "h", "i"]);
1019
1020 let record_batch =
1021 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
1022
1023 let offset = 2;
1024 let length = 5;
1025 let record_batch_slice = record_batch.slice(offset, length);
1026
1027 assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
1028 check_batch(record_batch_slice, 5);
1029
1030 let offset = 2;
1031 let length = 0;
1032 let record_batch_slice = record_batch.slice(offset, length);
1033
1034 assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
1035 check_batch(record_batch_slice, 0);
1036
1037 let offset = 2;
1038 let length = 10;
1039 let _record_batch_slice = record_batch.slice(offset, length);
1040 }
1041
1042 #[test]
1043 #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
1044 fn create_record_batch_slice_empty_batch() {
1045 let schema = Schema::empty();
1046
1047 let record_batch = RecordBatch::new_empty(Arc::new(schema));
1048
1049 let offset = 0;
1050 let length = 0;
1051 let record_batch_slice = record_batch.slice(offset, length);
1052 assert_eq!(0, record_batch_slice.schema().fields().len());
1053
1054 let offset = 1;
1055 let length = 2;
1056 let _record_batch_slice = record_batch.slice(offset, length);
1057 }
1058
1059 #[test]
1060 fn create_record_batch_try_from_iter() {
1061 let a: ArrayRef = Arc::new(Int32Array::from(vec![
1062 Some(1),
1063 Some(2),
1064 None,
1065 Some(4),
1066 Some(5),
1067 ]));
1068 let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
1069
1070 let record_batch =
1071 RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).expect("valid conversion");
1072
1073 let expected_schema = Schema::new(vec![
1074 Field::new("a", DataType::Int32, true),
1075 Field::new("b", DataType::Utf8, false),
1076 ]);
1077 assert_eq!(record_batch.schema().as_ref(), &expected_schema);
1078 check_batch(record_batch, 5);
1079 }
1080
1081 #[test]
1082 fn create_record_batch_try_from_iter_with_nullable() {
1083 let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
1084 let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
1085
1086 let record_batch =
1088 RecordBatch::try_from_iter_with_nullable(vec![("a", a, false), ("b", b, true)])
1089 .expect("valid conversion");
1090
1091 let expected_schema = Schema::new(vec![
1092 Field::new("a", DataType::Int32, false),
1093 Field::new("b", DataType::Utf8, true),
1094 ]);
1095 assert_eq!(record_batch.schema().as_ref(), &expected_schema);
1096 check_batch(record_batch, 5);
1097 }
1098
1099 #[test]
1100 fn create_record_batch_schema_mismatch() {
1101 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1102
1103 let a = Int64Array::from(vec![1, 2, 3, 4, 5]);
1104
1105 let err = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap_err();
1106 assert_eq!(
1107 err.to_string(),
1108 "Invalid argument error: column types must match schema types, expected Int32 but found Int64 at column index 0"
1109 );
1110 }
1111
1112 #[test]
1113 fn create_record_batch_field_name_mismatch() {
1114 let fields = vec![
1115 Field::new("a1", DataType::Int32, false),
1116 Field::new_list("a2", Field::new_list_field(DataType::Int8, false), false),
1117 ];
1118 let schema = Arc::new(Schema::new(vec![Field::new_struct("a", fields, true)]));
1119
1120 let a1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
1121 let a2_child = Int8Array::from(vec![1, 2, 3, 4]);
1122 let a2 = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new(
1123 "array",
1124 DataType::Int8,
1125 false,
1126 ))))
1127 .add_child_data(a2_child.into_data())
1128 .len(2)
1129 .add_buffer(Buffer::from([0i32, 3, 4].to_byte_slice()))
1130 .build()
1131 .unwrap();
1132 let a2: ArrayRef = Arc::new(ListArray::from(a2));
1133 let a = ArrayDataBuilder::new(DataType::Struct(Fields::from(vec![
1134 Field::new("aa1", DataType::Int32, false),
1135 Field::new("a2", a2.data_type().clone(), false),
1136 ])))
1137 .add_child_data(a1.into_data())
1138 .add_child_data(a2.into_data())
1139 .len(2)
1140 .build()
1141 .unwrap();
1142 let a: ArrayRef = Arc::new(StructArray::from(a));
1143
1144 let batch = RecordBatch::try_new(schema.clone(), vec![a.clone()]);
1146 assert!(batch.is_err());
1147
1148 let options = RecordBatchOptions {
1150 match_field_names: false,
1151 row_count: None,
1152 };
1153 let batch = RecordBatch::try_new_with_options(schema, vec![a], &options);
1154 assert!(batch.is_ok());
1155 }
1156
1157 #[test]
1158 fn create_record_batch_record_mismatch() {
1159 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1160
1161 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1162 let b = Int32Array::from(vec![1, 2, 3, 4, 5]);
1163
1164 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]);
1165 assert!(batch.is_err());
1166 }
1167
1168 #[test]
1169 fn create_record_batch_from_struct_array() {
1170 let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
1171 let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
1172 let struct_array = StructArray::from(vec![
1173 (
1174 Arc::new(Field::new("b", DataType::Boolean, false)),
1175 boolean.clone() as ArrayRef,
1176 ),
1177 (
1178 Arc::new(Field::new("c", DataType::Int32, false)),
1179 int.clone() as ArrayRef,
1180 ),
1181 ]);
1182
1183 let batch = RecordBatch::from(&struct_array);
1184 assert_eq!(2, batch.num_columns());
1185 assert_eq!(4, batch.num_rows());
1186 assert_eq!(
1187 struct_array.data_type(),
1188 &DataType::Struct(batch.schema().fields().clone())
1189 );
1190 assert_eq!(batch.column(0).as_ref(), boolean.as_ref());
1191 assert_eq!(batch.column(1).as_ref(), int.as_ref());
1192 }
1193
1194 #[test]
1195 fn record_batch_equality() {
1196 let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1197 let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1198 let schema1 = Schema::new(vec![
1199 Field::new("id", DataType::Int32, false),
1200 Field::new("val", DataType::Int32, false),
1201 ]);
1202
1203 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1204 let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1205 let schema2 = Schema::new(vec![
1206 Field::new("id", DataType::Int32, false),
1207 Field::new("val", DataType::Int32, false),
1208 ]);
1209
1210 let batch1 = RecordBatch::try_new(
1211 Arc::new(schema1),
1212 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1213 )
1214 .unwrap();
1215
1216 let batch2 = RecordBatch::try_new(
1217 Arc::new(schema2),
1218 vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1219 )
1220 .unwrap();
1221
1222 assert_eq!(batch1, batch2);
1223 }
1224
1225 #[test]
1227 fn record_batch_index_access() {
1228 let id_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
1229 let val_arr = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
1230 let schema1 = Schema::new(vec![
1231 Field::new("id", DataType::Int32, false),
1232 Field::new("val", DataType::Int32, false),
1233 ]);
1234 let record_batch =
1235 RecordBatch::try_new(Arc::new(schema1), vec![id_arr.clone(), val_arr.clone()]).unwrap();
1236
1237 assert_eq!(record_batch["id"].as_ref(), id_arr.as_ref());
1238 assert_eq!(record_batch["val"].as_ref(), val_arr.as_ref());
1239 }
1240
1241 #[test]
1242 fn record_batch_vals_ne() {
1243 let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1244 let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1245 let schema1 = Schema::new(vec![
1246 Field::new("id", DataType::Int32, false),
1247 Field::new("val", DataType::Int32, false),
1248 ]);
1249
1250 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1251 let val_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1252 let schema2 = Schema::new(vec![
1253 Field::new("id", DataType::Int32, false),
1254 Field::new("val", DataType::Int32, false),
1255 ]);
1256
1257 let batch1 = RecordBatch::try_new(
1258 Arc::new(schema1),
1259 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1260 )
1261 .unwrap();
1262
1263 let batch2 = RecordBatch::try_new(
1264 Arc::new(schema2),
1265 vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1266 )
1267 .unwrap();
1268
1269 assert_ne!(batch1, batch2);
1270 }
1271
1272 #[test]
1273 fn record_batch_column_names_ne() {
1274 let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1275 let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1276 let schema1 = Schema::new(vec![
1277 Field::new("id", DataType::Int32, false),
1278 Field::new("val", DataType::Int32, false),
1279 ]);
1280
1281 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1282 let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1283 let schema2 = Schema::new(vec![
1284 Field::new("id", DataType::Int32, false),
1285 Field::new("num", DataType::Int32, false),
1286 ]);
1287
1288 let batch1 = RecordBatch::try_new(
1289 Arc::new(schema1),
1290 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1291 )
1292 .unwrap();
1293
1294 let batch2 = RecordBatch::try_new(
1295 Arc::new(schema2),
1296 vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1297 )
1298 .unwrap();
1299
1300 assert_ne!(batch1, batch2);
1301 }
1302
1303 #[test]
1304 fn record_batch_column_number_ne() {
1305 let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1306 let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1307 let schema1 = Schema::new(vec![
1308 Field::new("id", DataType::Int32, false),
1309 Field::new("val", DataType::Int32, false),
1310 ]);
1311
1312 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1313 let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1314 let num_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1315 let schema2 = Schema::new(vec![
1316 Field::new("id", DataType::Int32, false),
1317 Field::new("val", DataType::Int32, false),
1318 Field::new("num", DataType::Int32, false),
1319 ]);
1320
1321 let batch1 = RecordBatch::try_new(
1322 Arc::new(schema1),
1323 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1324 )
1325 .unwrap();
1326
1327 let batch2 = RecordBatch::try_new(
1328 Arc::new(schema2),
1329 vec![Arc::new(id_arr2), Arc::new(val_arr2), Arc::new(num_arr2)],
1330 )
1331 .unwrap();
1332
1333 assert_ne!(batch1, batch2);
1334 }
1335
1336 #[test]
1337 fn record_batch_row_count_ne() {
1338 let id_arr1 = Int32Array::from(vec![1, 2, 3]);
1339 let val_arr1 = Int32Array::from(vec![5, 6, 7]);
1340 let schema1 = Schema::new(vec![
1341 Field::new("id", DataType::Int32, false),
1342 Field::new("val", DataType::Int32, false),
1343 ]);
1344
1345 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1346 let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1347 let schema2 = Schema::new(vec![
1348 Field::new("id", DataType::Int32, false),
1349 Field::new("num", DataType::Int32, false),
1350 ]);
1351
1352 let batch1 = RecordBatch::try_new(
1353 Arc::new(schema1),
1354 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1355 )
1356 .unwrap();
1357
1358 let batch2 = RecordBatch::try_new(
1359 Arc::new(schema2),
1360 vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1361 )
1362 .unwrap();
1363
1364 assert_ne!(batch1, batch2);
1365 }
1366
1367 #[test]
1368 fn normalize_simple() {
1369 let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
1370 let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
1371 let year: ArrayRef = Arc::new(Int64Array::from(vec![None, Some(2022)]));
1372
1373 let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
1374 let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
1375 let year_field = Arc::new(Field::new("year", DataType::Int64, true));
1376
1377 let a = Arc::new(StructArray::from(vec![
1378 (animals_field.clone(), Arc::new(animals.clone()) as ArrayRef),
1379 (n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef),
1380 (year_field.clone(), Arc::new(year.clone()) as ArrayRef),
1381 ]));
1382
1383 let month = Arc::new(Int64Array::from(vec![Some(4), Some(6)]));
1384
1385 let schema = Schema::new(vec![
1386 Field::new(
1387 "a",
1388 DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
1389 false,
1390 ),
1391 Field::new("month", DataType::Int64, true),
1392 ]);
1393
1394 let normalized =
1395 RecordBatch::try_new(Arc::new(schema.clone()), vec![a.clone(), month.clone()])
1396 .expect("valid conversion")
1397 .normalize(".", Some(0))
1398 .expect("valid normalization");
1399
1400 let expected = RecordBatch::try_from_iter_with_nullable(vec![
1401 ("a.animals", animals.clone(), true),
1402 ("a.n_legs", n_legs.clone(), true),
1403 ("a.year", year.clone(), true),
1404 ("month", month.clone(), true),
1405 ])
1406 .expect("valid conversion");
1407
1408 assert_eq!(expected, normalized);
1409
1410 let normalized = RecordBatch::try_new(Arc::new(schema), vec![a, month.clone()])
1412 .expect("valid conversion")
1413 .normalize(".", None)
1414 .expect("valid normalization");
1415
1416 assert_eq!(expected, normalized);
1417 }
1418
1419 #[test]
1420 fn normalize_nested() {
1421 let a = Arc::new(Field::new("a", DataType::Int64, true));
1423 let b = Arc::new(Field::new("b", DataType::Int64, false));
1424 let c = Arc::new(Field::new("c", DataType::Int64, true));
1425
1426 let one = Arc::new(Field::new(
1427 "1",
1428 DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
1429 false,
1430 ));
1431 let two = Arc::new(Field::new(
1432 "2",
1433 DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
1434 true,
1435 ));
1436
1437 let exclamation = Arc::new(Field::new(
1438 "!",
1439 DataType::Struct(Fields::from(vec![one.clone(), two.clone()])),
1440 false,
1441 ));
1442
1443 let schema = Schema::new(vec![exclamation.clone()]);
1444
1445 let a_field = Int64Array::from(vec![Some(0), Some(1)]);
1447 let b_field = Int64Array::from(vec![Some(2), Some(3)]);
1448 let c_field = Int64Array::from(vec![None, Some(4)]);
1449
1450 let one_field = StructArray::from(vec![
1451 (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1452 (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1453 (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1454 ]);
1455 let two_field = StructArray::from(vec![
1456 (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1457 (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1458 (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1459 ]);
1460
1461 let exclamation_field = Arc::new(StructArray::from(vec![
1462 (one.clone(), Arc::new(one_field) as ArrayRef),
1463 (two.clone(), Arc::new(two_field) as ArrayRef),
1464 ]));
1465
1466 let normalized =
1468 RecordBatch::try_new(Arc::new(schema.clone()), vec![exclamation_field.clone()])
1469 .expect("valid conversion")
1470 .normalize(".", Some(1))
1471 .expect("valid normalization");
1472
1473 let expected = RecordBatch::try_from_iter_with_nullable(vec![
1474 (
1475 "!.1",
1476 Arc::new(StructArray::from(vec![
1477 (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1478 (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1479 (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1480 ])) as ArrayRef,
1481 false,
1482 ),
1483 (
1484 "!.2",
1485 Arc::new(StructArray::from(vec![
1486 (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1487 (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1488 (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1489 ])) as ArrayRef,
1490 true,
1491 ),
1492 ])
1493 .expect("valid conversion");
1494
1495 assert_eq!(expected, normalized);
1496
1497 let normalized = RecordBatch::try_new(Arc::new(schema), vec![exclamation_field])
1499 .expect("valid conversion")
1500 .normalize(".", None)
1501 .expect("valid normalization");
1502
1503 let expected = RecordBatch::try_from_iter_with_nullable(vec![
1504 ("!.1.a", Arc::new(a_field.clone()) as ArrayRef, true),
1505 ("!.1.b", Arc::new(b_field.clone()) as ArrayRef, false),
1506 ("!.1.c", Arc::new(c_field.clone()) as ArrayRef, true),
1507 ("!.2.a", Arc::new(a_field.clone()) as ArrayRef, true),
1508 ("!.2.b", Arc::new(b_field.clone()) as ArrayRef, false),
1509 ("!.2.c", Arc::new(c_field.clone()) as ArrayRef, true),
1510 ])
1511 .expect("valid conversion");
1512
1513 assert_eq!(expected, normalized);
1514 }
1515
1516 #[test]
1517 fn normalize_empty() {
1518 let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
1519 let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
1520 let year_field = Arc::new(Field::new("year", DataType::Int64, true));
1521
1522 let schema = Schema::new(vec![
1523 Field::new(
1524 "a",
1525 DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
1526 false,
1527 ),
1528 Field::new("month", DataType::Int64, true),
1529 ]);
1530
1531 let normalized = RecordBatch::new_empty(Arc::new(schema.clone()))
1532 .normalize(".", Some(0))
1533 .expect("valid normalization");
1534
1535 let expected = RecordBatch::new_empty(Arc::new(
1536 schema.normalize(".", Some(0)).expect("valid normalization"),
1537 ));
1538
1539 assert_eq!(expected, normalized);
1540 }
1541
1542 #[test]
1543 fn project() {
1544 let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
1545 let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
1546 let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));
1547
1548 let record_batch =
1549 RecordBatch::try_from_iter(vec![("a", a.clone()), ("b", b.clone()), ("c", c.clone())])
1550 .expect("valid conversion");
1551
1552 let expected =
1553 RecordBatch::try_from_iter(vec![("a", a), ("c", c)]).expect("valid conversion");
1554
1555 assert_eq!(expected, record_batch.project(&[0, 2]).unwrap());
1556 }
1557
1558 #[test]
1559 fn project_empty() {
1560 let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));
1561
1562 let record_batch =
1563 RecordBatch::try_from_iter(vec![("c", c.clone())]).expect("valid conversion");
1564
1565 let expected = RecordBatch::try_new_with_options(
1566 Arc::new(Schema::empty()),
1567 vec![],
1568 &RecordBatchOptions {
1569 match_field_names: true,
1570 row_count: Some(3),
1571 },
1572 )
1573 .expect("valid conversion");
1574
1575 assert_eq!(expected, record_batch.project(&[]).unwrap());
1576 }
1577
1578 #[test]
1579 fn test_no_column_record_batch() {
1580 let schema = Arc::new(Schema::empty());
1581
1582 let err = RecordBatch::try_new(schema.clone(), vec![]).unwrap_err();
1583 assert!(
1584 err.to_string()
1585 .contains("must either specify a row count or at least one column")
1586 );
1587
1588 let options = RecordBatchOptions::new().with_row_count(Some(10));
1589
1590 let ok = RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap();
1591 assert_eq!(ok.num_rows(), 10);
1592
1593 let a = ok.slice(2, 5);
1594 assert_eq!(a.num_rows(), 5);
1595
1596 let b = ok.slice(5, 0);
1597 assert_eq!(b.num_rows(), 0);
1598
1599 assert_ne!(a, b);
1600 assert_eq!(b, RecordBatch::new_empty(schema))
1601 }
1602
1603 #[test]
1604 fn test_nulls_in_non_nullable_field() {
1605 let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
1606 let maybe_batch = RecordBatch::try_new(
1607 schema,
1608 vec![Arc::new(Int32Array::from(vec![Some(1), None]))],
1609 );
1610 assert_eq!(
1611 "Invalid argument error: Column 'a' is declared as non-nullable but contains null values",
1612 format!("{}", maybe_batch.err().unwrap())
1613 );
1614 }
1615 #[test]
1616 fn test_record_batch_options() {
1617 let options = RecordBatchOptions::new()
1618 .with_match_field_names(false)
1619 .with_row_count(Some(20));
1620 assert!(!options.match_field_names);
1621 assert_eq!(options.row_count.unwrap(), 20)
1622 }
1623
1624 #[test]
1625 #[should_panic(expected = "Cannot convert nullable StructArray to RecordBatch")]
1626 fn test_from_struct() {
1627 let s = StructArray::from(ArrayData::new_null(
1628 &DataType::Struct(vec![Field::new("foo", DataType::Int32, false)].into()),
1630 2,
1631 ));
1632 let _ = RecordBatch::from(s);
1633 }
1634
1635 #[test]
1636 fn test_with_schema() {
1637 let required_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1638 let required_schema = Arc::new(required_schema);
1639 let nullable_schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
1640 let nullable_schema = Arc::new(nullable_schema);
1641
1642 let batch = RecordBatch::try_new(
1643 required_schema.clone(),
1644 vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as _],
1645 )
1646 .unwrap();
1647
1648 let batch = batch.with_schema(nullable_schema.clone()).unwrap();
1650
1651 batch.clone().with_schema(required_schema).unwrap_err();
1653
1654 let metadata = vec![("foo".to_string(), "bar".to_string())]
1656 .into_iter()
1657 .collect();
1658 let metadata_schema = nullable_schema.as_ref().clone().with_metadata(metadata);
1659 let batch = batch.with_schema(Arc::new(metadata_schema)).unwrap();
1660
1661 batch.with_schema(nullable_schema).unwrap_err();
1663 }
1664
1665 #[test]
1666 fn test_boxed_reader() {
1667 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1670 let schema = Arc::new(schema);
1671
1672 let reader = RecordBatchIterator::new(std::iter::empty(), schema);
1673 let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
1674
1675 fn get_size(reader: impl RecordBatchReader) -> usize {
1676 reader.size_hint().0
1677 }
1678
1679 let size = get_size(reader);
1680 assert_eq!(size, 0);
1681 }
1682
1683 #[test]
1684 fn test_remove_column_maintains_schema_metadata() {
1685 let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
1686 let bool_array = BooleanArray::from(vec![true, false, false, true, true]);
1687
1688 let mut metadata = HashMap::new();
1689 metadata.insert("foo".to_string(), "bar".to_string());
1690 let schema = Schema::new(vec![
1691 Field::new("id", DataType::Int32, false),
1692 Field::new("bool", DataType::Boolean, false),
1693 ])
1694 .with_metadata(metadata);
1695
1696 let mut batch = RecordBatch::try_new(
1697 Arc::new(schema),
1698 vec![Arc::new(id_array), Arc::new(bool_array)],
1699 )
1700 .unwrap();
1701
1702 let _removed_column = batch.remove_column(0);
1703 assert_eq!(batch.schema().metadata().len(), 1);
1704 assert_eq!(
1705 batch.schema().metadata().get("foo").unwrap().as_str(),
1706 "bar"
1707 );
1708 }
1709}