1use crate::cast::AsArray;
22use crate::{new_empty_array, Array, ArrayRef, StructArray};
23use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef};
24use std::ops::Index;
25use std::sync::Arc;
26
/// Trait for types that can read `RecordBatch`'s.
///
/// A reader is an [`Iterator`] yielding `Result<RecordBatch, ArrowError>` items that
/// additionally reports a schema through [`RecordBatchReader::schema`].
pub trait RecordBatchReader: Iterator<Item = Result<RecordBatch, ArrowError>> {
    /// Returns the schema reported by this reader.
    ///
    /// NOTE(review): presumably every batch yielded by the iterator is expected to
    /// conform to this schema; nothing in this trait enforces it — confirm with
    /// the implementors.
    fn schema(&self) -> SchemaRef;
}
37
38impl<R: RecordBatchReader + ?Sized> RecordBatchReader for Box<R> {
39 fn schema(&self) -> SchemaRef {
40 self.as_ref().schema()
41 }
42}
43
/// Trait for types that can write `RecordBatch`'s.
pub trait RecordBatchWriter {
    /// Writes a single batch to the writer.
    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError>;

    /// Closes the writer, consuming it so that no further batches can be written.
    fn close(self) -> Result<(), ArrowError>;
}
52
/// Creates an array wrapped in an `Arc` from a data-type name and a list of values.
///
/// Supported invocation forms:
///
/// * `create_array!(Null, size)` builds a `NullArray` of `size` elements.
/// * `create_array!(Binary, [v, ...])` / `create_array!(LargeBinary, [v, ...])` build
///   binary arrays through `from_vec`.
/// * `create_array!(Type, [v, ...])` builds the array type associated with `Type`
///   through its `From<Vec<_>>` implementation.
///
/// The internal `@from` arms map a data-type name to the concrete array type; an
/// unknown name is rejected at compile time with an "Unsupported data type" error.
#[macro_export]
macro_rules! create_array {
    // `@from` arms: data-type name -> concrete array type.
    (@from Boolean) => { $crate::BooleanArray };
    (@from Int8) => { $crate::Int8Array };
    (@from Int16) => { $crate::Int16Array };
    (@from Int32) => { $crate::Int32Array };
    (@from Int64) => { $crate::Int64Array };
    (@from UInt8) => { $crate::UInt8Array };
    (@from UInt16) => { $crate::UInt16Array };
    (@from UInt32) => { $crate::UInt32Array };
    (@from UInt64) => { $crate::UInt64Array };
    (@from Float16) => { $crate::Float16Array };
    (@from Float32) => { $crate::Float32Array };
    (@from Float64) => { $crate::Float64Array };
    (@from Utf8) => { $crate::StringArray };
    (@from Utf8View) => { $crate::StringViewArray };
    (@from LargeUtf8) => { $crate::LargeStringArray };
    (@from IntervalDayTime) => { $crate::IntervalDayTimeArray };
    (@from IntervalYearMonth) => { $crate::IntervalYearMonthArray };
    (@from Second) => { $crate::TimestampSecondArray };
    (@from Millisecond) => { $crate::TimestampMillisecondArray };
    (@from Microsecond) => { $crate::TimestampMicrosecondArray };
    (@from Nanosecond) => { $crate::TimestampNanosecondArray };
    (@from Second32) => { $crate::Time32SecondArray };
    (@from Millisecond32) => { $crate::Time32MillisecondArray };
    (@from Microsecond64) => { $crate::Time64MicrosecondArray };
    // Fixed: this arm previously named the non-existent `Time64Nanosecond64Array`,
    // so any use of `Nanosecond64` failed to compile.
    (@from Nanosecond64) => { $crate::Time64NanosecondArray };
    (@from DurationSecond) => { $crate::DurationSecondArray };
    (@from DurationMillisecond) => { $crate::DurationMillisecondArray };
    (@from DurationMicrosecond) => { $crate::DurationMicrosecondArray };
    (@from DurationNanosecond) => { $crate::DurationNanosecondArray };
    (@from Decimal128) => { $crate::Decimal128Array };
    (@from Decimal256) => { $crate::Decimal256Array };
    (@from TimestampSecond) => { $crate::TimestampSecondArray };
    (@from TimestampMillisecond) => { $crate::TimestampMillisecondArray };
    (@from TimestampMicrosecond) => { $crate::TimestampMicrosecondArray };
    (@from TimestampNanosecond) => { $crate::TimestampNanosecondArray };

    // Fallback: any other identifier is reported as a compile-time error.
    (@from $ty: ident) => {
        compile_error!(concat!("Unsupported data type: ", stringify!($ty)))
    };

    // Null arrays take a length rather than a list of values.
    (Null, $size: expr) => {
        std::sync::Arc::new($crate::NullArray::new($size))
    };

    // Binary arrays are built with `from_vec` instead of `From<Vec<_>>`.
    (Binary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::BinaryArray::from_vec(vec![$($values),*]))
    };

    (LargeBinary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::LargeBinaryArray::from_vec(vec![$($values),*]))
    };

    // Generic case: resolve the array type via `@from`, then convert from a `Vec`.
    ($ty: tt, [$($values: expr),*]) => {
        std::sync::Arc::new(<$crate::create_array!(@from $ty)>::from(vec![$($values),*]))
    };
}
137
/// Creates a record batch from a list of `(name, type, [values...])` column tuples.
///
/// Every column is declared as a nullable field of `arrow_schema::DataType::<type>`
/// and its array is built with `create_array!`. The macro evaluates to the result of
/// `RecordBatch::try_new`, i.e. a `Result`.
///
/// A trailing comma is accepted both after the last column tuple and after the last
/// value of a column.
///
/// Note that the expansion refers to `arrow_schema` by name, so that crate must be
/// nameable at the call site.
#[macro_export]
macro_rules! record_batch {
    ($(($name: expr, $type: ident, [$($values: expr),* $(,)?])),* $(,)?) => {
        {
            // One nullable field per column tuple, in declaration order.
            let schema = std::sync::Arc::new(arrow_schema::Schema::new(vec![
                $(
                    arrow_schema::Field::new($name, arrow_schema::DataType::$type, true),
                )*
            ]));

            // The block evaluates directly to the `Result` returned by `try_new`.
            $crate::RecordBatch::try_new(
                schema,
                vec![$(
                    $crate::create_array!($type, [$($values),*]),
                )*]
            )
        }
    }
}
175
176#[derive(Clone, Debug, PartialEq)]
200pub struct RecordBatch {
201 schema: SchemaRef,
202 columns: Vec<Arc<dyn Array>>,
203
204 row_count: usize,
208}
209
210impl RecordBatch {
211 pub fn try_new(schema: SchemaRef, columns: Vec<ArrayRef>) -> Result<Self, ArrowError> {
240 let options = RecordBatchOptions::new();
241 Self::try_new_impl(schema, columns, &options)
242 }
243
244 pub unsafe fn new_unchecked(
260 schema: SchemaRef,
261 columns: Vec<Arc<dyn Array>>,
262 row_count: usize,
263 ) -> Self {
264 Self {
265 schema,
266 columns,
267 row_count,
268 }
269 }
270
271 pub fn try_new_with_options(
276 schema: SchemaRef,
277 columns: Vec<ArrayRef>,
278 options: &RecordBatchOptions,
279 ) -> Result<Self, ArrowError> {
280 Self::try_new_impl(schema, columns, options)
281 }
282
283 pub fn new_empty(schema: SchemaRef) -> Self {
285 let columns = schema
286 .fields()
287 .iter()
288 .map(|field| new_empty_array(field.data_type()))
289 .collect();
290
291 RecordBatch {
292 schema,
293 columns,
294 row_count: 0,
295 }
296 }
297
298 fn try_new_impl(
301 schema: SchemaRef,
302 columns: Vec<ArrayRef>,
303 options: &RecordBatchOptions,
304 ) -> Result<Self, ArrowError> {
305 if schema.fields().len() != columns.len() {
307 return Err(ArrowError::InvalidArgumentError(format!(
308 "number of columns({}) must match number of fields({}) in schema",
309 columns.len(),
310 schema.fields().len(),
311 )));
312 }
313
314 let row_count = options
315 .row_count
316 .or_else(|| columns.first().map(|col| col.len()))
317 .ok_or_else(|| {
318 ArrowError::InvalidArgumentError(
319 "must either specify a row count or at least one column".to_string(),
320 )
321 })?;
322
323 for (c, f) in columns.iter().zip(&schema.fields) {
324 if !f.is_nullable() && c.null_count() > 0 {
325 return Err(ArrowError::InvalidArgumentError(format!(
326 "Column '{}' is declared as non-nullable but contains null values",
327 f.name()
328 )));
329 }
330 }
331
332 if columns.iter().any(|c| c.len() != row_count) {
334 let err = match options.row_count {
335 Some(_) => "all columns in a record batch must have the specified row count",
336 None => "all columns in a record batch must have the same length",
337 };
338 return Err(ArrowError::InvalidArgumentError(err.to_string()));
339 }
340
341 let type_not_match = if options.match_field_names {
344 |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| col_type != field_type
345 } else {
346 |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| {
347 !col_type.equals_datatype(field_type)
348 }
349 };
350
351 let not_match = columns
353 .iter()
354 .zip(schema.fields().iter())
355 .map(|(col, field)| (col.data_type(), field.data_type()))
356 .enumerate()
357 .find(type_not_match);
358
359 if let Some((i, (col_type, field_type))) = not_match {
360 return Err(ArrowError::InvalidArgumentError(format!(
361 "column types must match schema types, expected {field_type:?} but found {col_type:?} at column index {i}")));
362 }
363
364 Ok(RecordBatch {
365 schema,
366 columns,
367 row_count,
368 })
369 }
370
371 pub fn into_parts(self) -> (SchemaRef, Vec<ArrayRef>, usize) {
373 (self.schema, self.columns, self.row_count)
374 }
375
376 pub fn with_schema(self, schema: SchemaRef) -> Result<Self, ArrowError> {
383 if !schema.contains(self.schema.as_ref()) {
384 return Err(ArrowError::SchemaError(format!(
385 "target schema is not superset of current schema target={schema} current={}",
386 self.schema
387 )));
388 }
389
390 Ok(Self {
391 schema,
392 columns: self.columns,
393 row_count: self.row_count,
394 })
395 }
396
397 pub fn schema(&self) -> SchemaRef {
399 self.schema.clone()
400 }
401
402 pub fn schema_ref(&self) -> &SchemaRef {
404 &self.schema
405 }
406
407 pub fn schema_metadata_mut(&mut self) -> &mut std::collections::HashMap<String, String> {
425 let schema = Arc::make_mut(&mut self.schema);
426 &mut schema.metadata
427 }
428
429 pub fn project(&self, indices: &[usize]) -> Result<RecordBatch, ArrowError> {
431 let projected_schema = self.schema.project(indices)?;
432 let batch_fields = indices
433 .iter()
434 .map(|f| {
435 self.columns.get(*f).cloned().ok_or_else(|| {
436 ArrowError::SchemaError(format!(
437 "project index {} out of bounds, max field {}",
438 f,
439 self.columns.len()
440 ))
441 })
442 })
443 .collect::<Result<Vec<_>, _>>()?;
444
445 RecordBatch::try_new_with_options(
446 SchemaRef::new(projected_schema),
447 batch_fields,
448 &RecordBatchOptions {
449 match_field_names: true,
450 row_count: Some(self.row_count),
451 },
452 )
453 }
454
455 pub fn normalize(&self, separator: &str, max_level: Option<usize>) -> Result<Self, ArrowError> {
515 let max_level = match max_level.unwrap_or(usize::MAX) {
516 0 => usize::MAX,
517 val => val,
518 };
519 let mut stack: Vec<(usize, &ArrayRef, Vec<&str>, &FieldRef)> = self
520 .columns
521 .iter()
522 .zip(self.schema.fields())
523 .rev()
524 .map(|(c, f)| {
525 let name_vec: Vec<&str> = vec![f.name()];
526 (0, c, name_vec, f)
527 })
528 .collect();
529 let mut columns: Vec<ArrayRef> = Vec::new();
530 let mut fields: Vec<FieldRef> = Vec::new();
531
532 while let Some((depth, c, name, field_ref)) = stack.pop() {
533 match field_ref.data_type() {
534 DataType::Struct(ff) if depth < max_level => {
535 for (cff, fff) in c.as_struct().columns().iter().zip(ff.into_iter()).rev() {
537 let mut name = name.clone();
538 name.push(separator);
539 name.push(fff.name());
540 stack.push((depth + 1, cff, name, fff))
541 }
542 }
543 _ => {
544 let updated_field = Field::new(
545 name.concat(),
546 field_ref.data_type().clone(),
547 field_ref.is_nullable(),
548 );
549 columns.push(c.clone());
550 fields.push(Arc::new(updated_field));
551 }
552 }
553 }
554 RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
555 }
556
557 pub fn num_columns(&self) -> usize {
576 self.columns.len()
577 }
578
579 pub fn num_rows(&self) -> usize {
598 self.row_count
599 }
600
601 pub fn column(&self, index: usize) -> &ArrayRef {
607 &self.columns[index]
608 }
609
610 pub fn column_by_name(&self, name: &str) -> Option<&ArrayRef> {
612 self.schema()
613 .column_with_name(name)
614 .map(|(index, _)| &self.columns[index])
615 }
616
617 pub fn columns(&self) -> &[ArrayRef] {
619 &self.columns[..]
620 }
621
622 pub fn remove_column(&mut self, index: usize) -> ArrayRef {
650 let mut builder = SchemaBuilder::from(self.schema.as_ref());
651 builder.remove(index);
652 self.schema = Arc::new(builder.finish());
653 self.columns.remove(index)
654 }
655
656 pub fn slice(&self, offset: usize, length: usize) -> RecordBatch {
663 assert!((offset + length) <= self.num_rows());
664
665 let columns = self
666 .columns()
667 .iter()
668 .map(|column| column.slice(offset, length))
669 .collect();
670
671 Self {
672 schema: self.schema.clone(),
673 columns,
674 row_count: length,
675 }
676 }
677
678 pub fn try_from_iter<I, F>(value: I) -> Result<Self, ArrowError>
715 where
716 I: IntoIterator<Item = (F, ArrayRef)>,
717 F: AsRef<str>,
718 {
719 let iter = value.into_iter().map(|(field_name, array)| {
723 let nullable = array.null_count() > 0;
724 (field_name, array, nullable)
725 });
726
727 Self::try_from_iter_with_nullable(iter)
728 }
729
730 pub fn try_from_iter_with_nullable<I, F>(value: I) -> Result<Self, ArrowError>
752 where
753 I: IntoIterator<Item = (F, ArrayRef, bool)>,
754 F: AsRef<str>,
755 {
756 let iter = value.into_iter();
757 let capacity = iter.size_hint().0;
758 let mut schema = SchemaBuilder::with_capacity(capacity);
759 let mut columns = Vec::with_capacity(capacity);
760
761 for (field_name, array, nullable) in iter {
762 let field_name = field_name.as_ref();
763 schema.push(Field::new(field_name, array.data_type().clone(), nullable));
764 columns.push(array);
765 }
766
767 let schema = Arc::new(schema.finish());
768 RecordBatch::try_new(schema, columns)
769 }
770
771 pub fn get_array_memory_size(&self) -> usize {
778 self.columns()
779 .iter()
780 .map(|array| array.get_array_memory_size())
781 .sum()
782 }
783}
784
/// Options that control how a `RecordBatch` is validated on construction.
///
/// The type is plain data (`bool` and `Option<usize>`), so it derives `Clone`,
/// `PartialEq` and `Eq` to let callers duplicate and compare option sets.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct RecordBatchOptions {
    /// When `true`, column and field data types are compared with `!=`; when
    /// `false`, they are compared with `DataType::equals_datatype`.
    pub match_field_names: bool,

    /// Explicit row count; when `None`, the length of the first column is used.
    pub row_count: Option<usize>,
}
795
796impl RecordBatchOptions {
797 pub fn new() -> Self {
799 Self {
800 match_field_names: true,
801 row_count: None,
802 }
803 }
804 pub fn with_row_count(mut self, row_count: Option<usize>) -> Self {
806 self.row_count = row_count;
807 self
808 }
809 pub fn with_match_field_names(mut self, match_field_names: bool) -> Self {
811 self.match_field_names = match_field_names;
812 self
813 }
814}
815impl Default for RecordBatchOptions {
816 fn default() -> Self {
817 Self::new()
818 }
819}
820impl From<StructArray> for RecordBatch {
821 fn from(value: StructArray) -> Self {
822 let row_count = value.len();
823 let (fields, columns, nulls) = value.into_parts();
824 assert_eq!(
825 nulls.map(|n| n.null_count()).unwrap_or_default(),
826 0,
827 "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"
828 );
829
830 RecordBatch {
831 schema: Arc::new(Schema::new(fields)),
832 row_count,
833 columns,
834 }
835 }
836}
837
838impl From<&StructArray> for RecordBatch {
839 fn from(struct_array: &StructArray) -> Self {
840 struct_array.clone().into()
841 }
842}
843
844impl Index<&str> for RecordBatch {
845 type Output = ArrayRef;
846
847 fn index(&self, name: &str) -> &Self::Output {
853 self.column_by_name(name).unwrap()
854 }
855}
856
857pub struct RecordBatchIterator<I>
883where
884 I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
885{
886 inner: I::IntoIter,
887 inner_schema: SchemaRef,
888}
889
890impl<I> RecordBatchIterator<I>
891where
892 I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
893{
894 pub fn new(iter: I, schema: SchemaRef) -> Self {
898 Self {
899 inner: iter.into_iter(),
900 inner_schema: schema,
901 }
902 }
903}
904
905impl<I> Iterator for RecordBatchIterator<I>
906where
907 I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
908{
909 type Item = I::Item;
910
911 fn next(&mut self) -> Option<Self::Item> {
912 self.inner.next()
913 }
914
915 fn size_hint(&self) -> (usize, Option<usize>) {
916 self.inner.size_hint()
917 }
918}
919
920impl<I> RecordBatchReader for RecordBatchIterator<I>
921where
922 I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
923{
924 fn schema(&self) -> SchemaRef {
925 self.inner_schema.clone()
926 }
927}
928
929#[cfg(test)]
930mod tests {
931 use super::*;
932 use crate::{
933 BooleanArray, Int32Array, Int64Array, Int8Array, ListArray, StringArray, StringViewArray,
934 };
935 use arrow_buffer::{Buffer, ToByteSlice};
936 use arrow_data::{ArrayData, ArrayDataBuilder};
937 use arrow_schema::Fields;
938 use std::collections::HashMap;
939
940 #[test]
941 fn create_record_batch() {
942 let schema = Schema::new(vec![
943 Field::new("a", DataType::Int32, false),
944 Field::new("b", DataType::Utf8, false),
945 ]);
946
947 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
948 let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);
949
950 let record_batch =
951 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
952 check_batch(record_batch, 5)
953 }
954
955 #[test]
956 fn create_string_view_record_batch() {
957 let schema = Schema::new(vec![
958 Field::new("a", DataType::Int32, false),
959 Field::new("b", DataType::Utf8View, false),
960 ]);
961
962 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
963 let b = StringViewArray::from(vec!["a", "b", "c", "d", "e"]);
964
965 let record_batch =
966 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
967
968 assert_eq!(5, record_batch.num_rows());
969 assert_eq!(2, record_batch.num_columns());
970 assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
971 assert_eq!(
972 &DataType::Utf8View,
973 record_batch.schema().field(1).data_type()
974 );
975 assert_eq!(5, record_batch.column(0).len());
976 assert_eq!(5, record_batch.column(1).len());
977 }
978
979 #[test]
980 fn byte_size_should_not_regress() {
981 let schema = Schema::new(vec![
982 Field::new("a", DataType::Int32, false),
983 Field::new("b", DataType::Utf8, false),
984 ]);
985
986 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
987 let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);
988
989 let record_batch =
990 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
991 assert_eq!(record_batch.get_array_memory_size(), 364);
992 }
993
994 fn check_batch(record_batch: RecordBatch, num_rows: usize) {
995 assert_eq!(num_rows, record_batch.num_rows());
996 assert_eq!(2, record_batch.num_columns());
997 assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
998 assert_eq!(&DataType::Utf8, record_batch.schema().field(1).data_type());
999 assert_eq!(num_rows, record_batch.column(0).len());
1000 assert_eq!(num_rows, record_batch.column(1).len());
1001 }
1002
1003 #[test]
1004 #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
1005 fn create_record_batch_slice() {
1006 let schema = Schema::new(vec![
1007 Field::new("a", DataType::Int32, false),
1008 Field::new("b", DataType::Utf8, false),
1009 ]);
1010 let expected_schema = schema.clone();
1011
1012 let a = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
1013 let b = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "h", "i"]);
1014
1015 let record_batch =
1016 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
1017
1018 let offset = 2;
1019 let length = 5;
1020 let record_batch_slice = record_batch.slice(offset, length);
1021
1022 assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
1023 check_batch(record_batch_slice, 5);
1024
1025 let offset = 2;
1026 let length = 0;
1027 let record_batch_slice = record_batch.slice(offset, length);
1028
1029 assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
1030 check_batch(record_batch_slice, 0);
1031
1032 let offset = 2;
1033 let length = 10;
1034 let _record_batch_slice = record_batch.slice(offset, length);
1035 }
1036
1037 #[test]
1038 #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
1039 fn create_record_batch_slice_empty_batch() {
1040 let schema = Schema::empty();
1041
1042 let record_batch = RecordBatch::new_empty(Arc::new(schema));
1043
1044 let offset = 0;
1045 let length = 0;
1046 let record_batch_slice = record_batch.slice(offset, length);
1047 assert_eq!(0, record_batch_slice.schema().fields().len());
1048
1049 let offset = 1;
1050 let length = 2;
1051 let _record_batch_slice = record_batch.slice(offset, length);
1052 }
1053
1054 #[test]
1055 fn create_record_batch_try_from_iter() {
1056 let a: ArrayRef = Arc::new(Int32Array::from(vec![
1057 Some(1),
1058 Some(2),
1059 None,
1060 Some(4),
1061 Some(5),
1062 ]));
1063 let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
1064
1065 let record_batch =
1066 RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).expect("valid conversion");
1067
1068 let expected_schema = Schema::new(vec![
1069 Field::new("a", DataType::Int32, true),
1070 Field::new("b", DataType::Utf8, false),
1071 ]);
1072 assert_eq!(record_batch.schema().as_ref(), &expected_schema);
1073 check_batch(record_batch, 5);
1074 }
1075
1076 #[test]
1077 fn create_record_batch_try_from_iter_with_nullable() {
1078 let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
1079 let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
1080
1081 let record_batch =
1083 RecordBatch::try_from_iter_with_nullable(vec![("a", a, false), ("b", b, true)])
1084 .expect("valid conversion");
1085
1086 let expected_schema = Schema::new(vec![
1087 Field::new("a", DataType::Int32, false),
1088 Field::new("b", DataType::Utf8, true),
1089 ]);
1090 assert_eq!(record_batch.schema().as_ref(), &expected_schema);
1091 check_batch(record_batch, 5);
1092 }
1093
1094 #[test]
1095 fn create_record_batch_schema_mismatch() {
1096 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1097
1098 let a = Int64Array::from(vec![1, 2, 3, 4, 5]);
1099
1100 let err = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap_err();
1101 assert_eq!(err.to_string(), "Invalid argument error: column types must match schema types, expected Int32 but found Int64 at column index 0");
1102 }
1103
1104 #[test]
1105 fn create_record_batch_field_name_mismatch() {
1106 let fields = vec![
1107 Field::new("a1", DataType::Int32, false),
1108 Field::new_list("a2", Field::new_list_field(DataType::Int8, false), false),
1109 ];
1110 let schema = Arc::new(Schema::new(vec![Field::new_struct("a", fields, true)]));
1111
1112 let a1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
1113 let a2_child = Int8Array::from(vec![1, 2, 3, 4]);
1114 let a2 = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new(
1115 "array",
1116 DataType::Int8,
1117 false,
1118 ))))
1119 .add_child_data(a2_child.into_data())
1120 .len(2)
1121 .add_buffer(Buffer::from([0i32, 3, 4].to_byte_slice()))
1122 .build()
1123 .unwrap();
1124 let a2: ArrayRef = Arc::new(ListArray::from(a2));
1125 let a = ArrayDataBuilder::new(DataType::Struct(Fields::from(vec![
1126 Field::new("aa1", DataType::Int32, false),
1127 Field::new("a2", a2.data_type().clone(), false),
1128 ])))
1129 .add_child_data(a1.into_data())
1130 .add_child_data(a2.into_data())
1131 .len(2)
1132 .build()
1133 .unwrap();
1134 let a: ArrayRef = Arc::new(StructArray::from(a));
1135
1136 let batch = RecordBatch::try_new(schema.clone(), vec![a.clone()]);
1138 assert!(batch.is_err());
1139
1140 let options = RecordBatchOptions {
1142 match_field_names: false,
1143 row_count: None,
1144 };
1145 let batch = RecordBatch::try_new_with_options(schema, vec![a], &options);
1146 assert!(batch.is_ok());
1147 }
1148
1149 #[test]
1150 fn create_record_batch_record_mismatch() {
1151 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1152
1153 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1154 let b = Int32Array::from(vec![1, 2, 3, 4, 5]);
1155
1156 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]);
1157 assert!(batch.is_err());
1158 }
1159
1160 #[test]
1161 fn create_record_batch_from_struct_array() {
1162 let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
1163 let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
1164 let struct_array = StructArray::from(vec![
1165 (
1166 Arc::new(Field::new("b", DataType::Boolean, false)),
1167 boolean.clone() as ArrayRef,
1168 ),
1169 (
1170 Arc::new(Field::new("c", DataType::Int32, false)),
1171 int.clone() as ArrayRef,
1172 ),
1173 ]);
1174
1175 let batch = RecordBatch::from(&struct_array);
1176 assert_eq!(2, batch.num_columns());
1177 assert_eq!(4, batch.num_rows());
1178 assert_eq!(
1179 struct_array.data_type(),
1180 &DataType::Struct(batch.schema().fields().clone())
1181 );
1182 assert_eq!(batch.column(0).as_ref(), boolean.as_ref());
1183 assert_eq!(batch.column(1).as_ref(), int.as_ref());
1184 }
1185
1186 #[test]
1187 fn record_batch_equality() {
1188 let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1189 let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1190 let schema1 = Schema::new(vec![
1191 Field::new("id", DataType::Int32, false),
1192 Field::new("val", DataType::Int32, false),
1193 ]);
1194
1195 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1196 let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1197 let schema2 = Schema::new(vec![
1198 Field::new("id", DataType::Int32, false),
1199 Field::new("val", DataType::Int32, false),
1200 ]);
1201
1202 let batch1 = RecordBatch::try_new(
1203 Arc::new(schema1),
1204 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1205 )
1206 .unwrap();
1207
1208 let batch2 = RecordBatch::try_new(
1209 Arc::new(schema2),
1210 vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1211 )
1212 .unwrap();
1213
1214 assert_eq!(batch1, batch2);
1215 }
1216
1217 #[test]
1219 fn record_batch_index_access() {
1220 let id_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
1221 let val_arr = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
1222 let schema1 = Schema::new(vec![
1223 Field::new("id", DataType::Int32, false),
1224 Field::new("val", DataType::Int32, false),
1225 ]);
1226 let record_batch =
1227 RecordBatch::try_new(Arc::new(schema1), vec![id_arr.clone(), val_arr.clone()]).unwrap();
1228
1229 assert_eq!(record_batch["id"].as_ref(), id_arr.as_ref());
1230 assert_eq!(record_batch["val"].as_ref(), val_arr.as_ref());
1231 }
1232
1233 #[test]
1234 fn record_batch_vals_ne() {
1235 let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1236 let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1237 let schema1 = Schema::new(vec![
1238 Field::new("id", DataType::Int32, false),
1239 Field::new("val", DataType::Int32, false),
1240 ]);
1241
1242 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1243 let val_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1244 let schema2 = Schema::new(vec![
1245 Field::new("id", DataType::Int32, false),
1246 Field::new("val", DataType::Int32, false),
1247 ]);
1248
1249 let batch1 = RecordBatch::try_new(
1250 Arc::new(schema1),
1251 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1252 )
1253 .unwrap();
1254
1255 let batch2 = RecordBatch::try_new(
1256 Arc::new(schema2),
1257 vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1258 )
1259 .unwrap();
1260
1261 assert_ne!(batch1, batch2);
1262 }
1263
1264 #[test]
1265 fn record_batch_column_names_ne() {
1266 let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1267 let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1268 let schema1 = Schema::new(vec![
1269 Field::new("id", DataType::Int32, false),
1270 Field::new("val", DataType::Int32, false),
1271 ]);
1272
1273 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1274 let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1275 let schema2 = Schema::new(vec![
1276 Field::new("id", DataType::Int32, false),
1277 Field::new("num", DataType::Int32, false),
1278 ]);
1279
1280 let batch1 = RecordBatch::try_new(
1281 Arc::new(schema1),
1282 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1283 )
1284 .unwrap();
1285
1286 let batch2 = RecordBatch::try_new(
1287 Arc::new(schema2),
1288 vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1289 )
1290 .unwrap();
1291
1292 assert_ne!(batch1, batch2);
1293 }
1294
1295 #[test]
1296 fn record_batch_column_number_ne() {
1297 let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1298 let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1299 let schema1 = Schema::new(vec![
1300 Field::new("id", DataType::Int32, false),
1301 Field::new("val", DataType::Int32, false),
1302 ]);
1303
1304 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1305 let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1306 let num_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1307 let schema2 = Schema::new(vec![
1308 Field::new("id", DataType::Int32, false),
1309 Field::new("val", DataType::Int32, false),
1310 Field::new("num", DataType::Int32, false),
1311 ]);
1312
1313 let batch1 = RecordBatch::try_new(
1314 Arc::new(schema1),
1315 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1316 )
1317 .unwrap();
1318
1319 let batch2 = RecordBatch::try_new(
1320 Arc::new(schema2),
1321 vec![Arc::new(id_arr2), Arc::new(val_arr2), Arc::new(num_arr2)],
1322 )
1323 .unwrap();
1324
1325 assert_ne!(batch1, batch2);
1326 }
1327
1328 #[test]
1329 fn record_batch_row_count_ne() {
1330 let id_arr1 = Int32Array::from(vec![1, 2, 3]);
1331 let val_arr1 = Int32Array::from(vec![5, 6, 7]);
1332 let schema1 = Schema::new(vec![
1333 Field::new("id", DataType::Int32, false),
1334 Field::new("val", DataType::Int32, false),
1335 ]);
1336
1337 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1338 let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1339 let schema2 = Schema::new(vec![
1340 Field::new("id", DataType::Int32, false),
1341 Field::new("num", DataType::Int32, false),
1342 ]);
1343
1344 let batch1 = RecordBatch::try_new(
1345 Arc::new(schema1),
1346 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1347 )
1348 .unwrap();
1349
1350 let batch2 = RecordBatch::try_new(
1351 Arc::new(schema2),
1352 vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1353 )
1354 .unwrap();
1355
1356 assert_ne!(batch1, batch2);
1357 }
1358
1359 #[test]
1360 fn normalize_simple() {
1361 let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
1362 let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
1363 let year: ArrayRef = Arc::new(Int64Array::from(vec![None, Some(2022)]));
1364
1365 let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
1366 let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
1367 let year_field = Arc::new(Field::new("year", DataType::Int64, true));
1368
1369 let a = Arc::new(StructArray::from(vec![
1370 (animals_field.clone(), Arc::new(animals.clone()) as ArrayRef),
1371 (n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef),
1372 (year_field.clone(), Arc::new(year.clone()) as ArrayRef),
1373 ]));
1374
1375 let month = Arc::new(Int64Array::from(vec![Some(4), Some(6)]));
1376
1377 let schema = Schema::new(vec![
1378 Field::new(
1379 "a",
1380 DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
1381 false,
1382 ),
1383 Field::new("month", DataType::Int64, true),
1384 ]);
1385
1386 let normalized =
1387 RecordBatch::try_new(Arc::new(schema.clone()), vec![a.clone(), month.clone()])
1388 .expect("valid conversion")
1389 .normalize(".", Some(0))
1390 .expect("valid normalization");
1391
1392 let expected = RecordBatch::try_from_iter_with_nullable(vec![
1393 ("a.animals", animals.clone(), true),
1394 ("a.n_legs", n_legs.clone(), true),
1395 ("a.year", year.clone(), true),
1396 ("month", month.clone(), true),
1397 ])
1398 .expect("valid conversion");
1399
1400 assert_eq!(expected, normalized);
1401
1402 let normalized = RecordBatch::try_new(Arc::new(schema), vec![a, month.clone()])
1404 .expect("valid conversion")
1405 .normalize(".", None)
1406 .expect("valid normalization");
1407
1408 assert_eq!(expected, normalized);
1409 }
1410
1411 #[test]
1412 fn normalize_nested() {
1413 let a = Arc::new(Field::new("a", DataType::Int64, true));
1415 let b = Arc::new(Field::new("b", DataType::Int64, false));
1416 let c = Arc::new(Field::new("c", DataType::Int64, true));
1417
1418 let one = Arc::new(Field::new(
1419 "1",
1420 DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
1421 false,
1422 ));
1423 let two = Arc::new(Field::new(
1424 "2",
1425 DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
1426 true,
1427 ));
1428
1429 let exclamation = Arc::new(Field::new(
1430 "!",
1431 DataType::Struct(Fields::from(vec![one.clone(), two.clone()])),
1432 false,
1433 ));
1434
1435 let schema = Schema::new(vec![exclamation.clone()]);
1436
1437 let a_field = Int64Array::from(vec![Some(0), Some(1)]);
1439 let b_field = Int64Array::from(vec![Some(2), Some(3)]);
1440 let c_field = Int64Array::from(vec![None, Some(4)]);
1441
1442 let one_field = StructArray::from(vec![
1443 (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1444 (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1445 (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1446 ]);
1447 let two_field = StructArray::from(vec![
1448 (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1449 (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1450 (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1451 ]);
1452
1453 let exclamation_field = Arc::new(StructArray::from(vec![
1454 (one.clone(), Arc::new(one_field) as ArrayRef),
1455 (two.clone(), Arc::new(two_field) as ArrayRef),
1456 ]));
1457
1458 let normalized =
1460 RecordBatch::try_new(Arc::new(schema.clone()), vec![exclamation_field.clone()])
1461 .expect("valid conversion")
1462 .normalize(".", Some(1))
1463 .expect("valid normalization");
1464
1465 let expected = RecordBatch::try_from_iter_with_nullable(vec![
1466 (
1467 "!.1",
1468 Arc::new(StructArray::from(vec![
1469 (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1470 (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1471 (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1472 ])) as ArrayRef,
1473 false,
1474 ),
1475 (
1476 "!.2",
1477 Arc::new(StructArray::from(vec![
1478 (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1479 (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1480 (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1481 ])) as ArrayRef,
1482 true,
1483 ),
1484 ])
1485 .expect("valid conversion");
1486
1487 assert_eq!(expected, normalized);
1488
1489 let normalized = RecordBatch::try_new(Arc::new(schema), vec![exclamation_field])
1491 .expect("valid conversion")
1492 .normalize(".", None)
1493 .expect("valid normalization");
1494
1495 let expected = RecordBatch::try_from_iter_with_nullable(vec![
1496 ("!.1.a", Arc::new(a_field.clone()) as ArrayRef, true),
1497 ("!.1.b", Arc::new(b_field.clone()) as ArrayRef, false),
1498 ("!.1.c", Arc::new(c_field.clone()) as ArrayRef, true),
1499 ("!.2.a", Arc::new(a_field.clone()) as ArrayRef, true),
1500 ("!.2.b", Arc::new(b_field.clone()) as ArrayRef, false),
1501 ("!.2.c", Arc::new(c_field.clone()) as ArrayRef, true),
1502 ])
1503 .expect("valid conversion");
1504
1505 assert_eq!(expected, normalized);
1506 }
1507
/// Normalizing an empty batch must produce the same batch as creating an
/// empty batch directly from the normalized schema.
#[test]
fn normalize_empty() {
    // Children of the struct column "a"; each of them is declared nullable.
    let struct_children = Fields::from(vec![
        Arc::new(Field::new("animals", DataType::Utf8, true)),
        Arc::new(Field::new("n_legs", DataType::Int64, true)),
        Arc::new(Field::new("year", DataType::Int64, true)),
    ]);
    let nested_column = Field::new("a", DataType::Struct(struct_children), false);
    let flat_column = Field::new("month", DataType::Int64, true);
    let schema = Schema::new(vec![nested_column, flat_column]);

    // Route 1: build the empty batch first, then normalize the batch.
    let empty_batch = RecordBatch::new_empty(Arc::new(schema.clone()));
    let normalized = empty_batch
        .normalize(".", Some(0))
        .expect("valid normalization");

    // Route 2: normalize the schema first, then build the empty batch from it.
    let normalized_schema = schema.normalize(".", Some(0)).expect("valid normalization");
    let expected = RecordBatch::new_empty(Arc::new(normalized_schema));

    assert_eq!(expected, normalized);
}
1533
/// Projecting indices 0 and 2 of a three-column batch keeps columns "a" and
/// "c", in that order.
#[test]
fn project() {
    let ints: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
    let first_strings: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
    let second_strings: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));

    let all_columns = vec![
        ("a", ints.clone()),
        ("b", first_strings),
        ("c", second_strings.clone()),
    ];
    let record_batch = RecordBatch::try_from_iter(all_columns).expect("valid conversion");

    // The expected batch is assembled from the very same arrays, minus "b".
    let kept_columns = vec![("a", ints), ("c", second_strings)];
    let expected = RecordBatch::try_from_iter(kept_columns).expect("valid conversion");

    let projected = record_batch.project(&[0, 2]).unwrap();
    assert_eq!(expected, projected);
}
1549
/// Projecting onto an empty index list yields a batch with an empty schema
/// and no columns that still compares equal to one built with a row count
/// of three.
#[test]
fn project_empty() {
    let strings: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));
    let record_batch =
        RecordBatch::try_from_iter(vec![("c", strings)]).expect("valid conversion");

    // The expected batch has no columns, so its length is carried by the
    // explicit row count in the options.
    let options = RecordBatchOptions::new()
        .with_match_field_names(true)
        .with_row_count(Some(3));
    let expected =
        RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &options)
            .expect("valid conversion");

    let projected = record_batch.project(&[]).unwrap();
    assert_eq!(expected, projected);
}
1569
/// Exercises batches that have no columns: construction without a row count
/// fails, construction with one succeeds, and slicing changes the reported
/// number of rows.
#[test]
fn test_no_column_record_batch() {
    let empty_schema = Arc::new(Schema::empty());

    // Neither columns nor a row count are given, so construction is rejected.
    let message = RecordBatch::try_new(empty_schema.clone(), vec![])
        .unwrap_err()
        .to_string();
    assert!(message.contains("must either specify a row count or at least one column"));

    // Passing the row count through the options lets the same inputs succeed.
    let with_row_count = RecordBatchOptions::new().with_row_count(Some(10));
    let ten_rows =
        RecordBatch::try_new_with_options(empty_schema.clone(), vec![], &with_row_count).unwrap();
    assert_eq!(ten_rows.num_rows(), 10);

    let five_rows = ten_rows.slice(2, 5);
    assert_eq!(five_rows.num_rows(), 5);

    let zero_rows = ten_rows.slice(5, 0);
    assert_eq!(zero_rows.num_rows(), 0);

    // The two slices differ only in length, and the zero-length slice matches
    // a freshly created empty batch.
    assert_ne!(five_rows, zero_rows);
    assert_eq!(zero_rows, RecordBatch::new_empty(empty_schema));
}
1593
/// Building a batch whose non-nullable column holds a null must fail with a
/// message that names the offending column.
#[test]
fn test_nulls_in_non_nullable_field() {
    let non_nullable = Field::new("a", DataType::Int32, false);
    let schema = Arc::new(Schema::new(vec![non_nullable]));
    let values_with_null = Int32Array::from(vec![Some(1), None]);

    let maybe_batch = RecordBatch::try_new(schema, vec![Arc::new(values_with_null)]);

    let message = maybe_batch.err().unwrap().to_string();
    assert_eq!(
        "Invalid argument error: Column 'a' is declared as non-nullable but contains null values",
        message
    );
}
/// The builder-style setters on `RecordBatchOptions` store the values they
/// are given.
#[test]
fn test_record_batch_options() {
    let configured = RecordBatchOptions::new()
        .with_match_field_names(false)
        .with_row_count(Some(20));

    assert!(!configured.match_field_names);

    let row_count = configured.row_count.unwrap();
    assert_eq!(row_count, 20);
}
1611
/// Converting a `StructArray` built from `ArrayData::new_null` into a
/// `RecordBatch` is expected to panic with the message checked below.
#[test]
#[should_panic(expected = "Cannot convert nullable StructArray to RecordBatch")]
fn test_from_struct() {
    // Note that the child field itself is declared non-nullable.
    let child = Field::new("foo", DataType::Int32, false);
    let struct_type = DataType::Struct(vec![child].into());

    let null_data = ArrayData::new_null(&struct_type, 2);
    let struct_array = StructArray::from(null_data);

    let _ = RecordBatch::from(struct_array);
}
1622
/// Checks which schema substitutions `with_schema` accepts: adding
/// nullability or metadata succeeds, removing either one fails.
#[test]
fn test_with_schema() {
    let required_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let nullable_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));

    let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let batch = RecordBatch::try_new(required_schema.clone(), vec![values]).unwrap();

    // Relaxing the required field to nullable is accepted.
    let batch = batch.with_schema(nullable_schema.clone()).unwrap();

    // Tightening the field back to required is rejected.
    batch.clone().with_schema(required_schema).unwrap_err();

    // A schema that differs only by added metadata is accepted.
    let metadata = HashMap::from([("foo".to_string(), "bar".to_string())]);
    let metadata_schema = (*nullable_schema).clone().with_metadata(metadata);
    let batch = batch.with_schema(Arc::new(metadata_schema)).unwrap();

    // Going back to the schema without that metadata is rejected.
    batch.with_schema(nullable_schema).unwrap_err();
}
1652
/// A boxed `dyn RecordBatchReader + Send` can be handed to a function that is
/// generic over `RecordBatchReader`.
#[test]
fn test_boxed_reader() {
    // Generic consumer; returns the lower bound of the reader's size hint.
    fn get_size<R: RecordBatchReader>(reader: R) -> usize {
        let (lower_bound, _) = reader.size_hint();
        lower_bound
    }

    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

    // An iterator over no batches at all, erased behind a trait object.
    let boxed: Box<dyn RecordBatchReader + Send> =
        Box::new(RecordBatchIterator::new(std::iter::empty(), schema));

    assert_eq!(get_size(boxed), 0);
}
1670
/// Removing a column from a batch must leave the schema-level metadata in
/// place.
#[test]
fn test_remove_column_maintains_schema_metadata() {
    let schema = Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("bool", DataType::Boolean, false),
    ])
    .with_metadata(HashMap::from([("foo".to_string(), "bar".to_string())]));

    let columns: Vec<ArrayRef> = vec![
        Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])),
        Arc::new(BooleanArray::from(vec![true, false, false, true, true])),
    ];
    let mut batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap();

    // Drop the first column ("id"); only the schema is inspected afterwards.
    let _removed_column = batch.remove_column(0);

    let schema_after = batch.schema();
    let metadata = schema_after.metadata();
    assert_eq!(metadata.len(), 1);
    assert_eq!(metadata.get("foo").unwrap().as_str(), "bar");
}
1697}