1use crate::cast::AsArray;
22use crate::{new_empty_array, Array, ArrayRef, StructArray};
23use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef};
24use std::ops::Index;
25use std::sync::Arc;
26
/// Trait for types that can read [`RecordBatch`]es.
///
/// An implementor is an [`Iterator`] whose items are
/// `Result<RecordBatch, ArrowError>` and which can additionally report a schema.
pub trait RecordBatchReader: Iterator<Item = Result<RecordBatch, ArrowError>> {
    /// Returns the schema of this `RecordBatchReader`.
    ///
    /// NOTE(review): presumably every batch yielded by the iterator is expected
    /// to conform to this schema; nothing in the trait itself enforces that.
    fn schema(&self) -> SchemaRef;
}
37
38impl<R: RecordBatchReader + ?Sized> RecordBatchReader for Box<R> {
39 fn schema(&self) -> SchemaRef {
40 self.as_ref().schema()
41 }
42}
43
/// Trait for types that can write [`RecordBatch`]es.
pub trait RecordBatchWriter {
    /// Writes a single batch to the writer.
    ///
    /// # Errors
    ///
    /// Returns an [`ArrowError`] if the implementation fails to write the batch.
    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError>;

    /// Consumes the writer, so no further batches can be written afterwards.
    ///
    /// NOTE(review): presumably implementations flush buffered data and write
    /// any trailing footer here; confirm against the concrete writers.
    fn close(self) -> Result<(), ArrowError>;
}
52
/// Creates an array (wrapped in an `Arc`) from a type name and literal values.
///
/// Supported invocation forms:
///
/// * `create_array!(Null, size)` builds a `NullArray` of `size` elements.
/// * `create_array!(Binary, [v, ...])` / `create_array!(LargeBinary, [v, ...])`
///   build binary arrays through `from_vec`.
/// * `create_array!(Type, [v, ...])` builds the array type associated with
///   `Type` through its `From<Vec<_>>` implementation.
///
/// The internal `@from` arms map a type name onto the concrete array type.
/// An unknown type name produces a compile error that names the offending type.
#[macro_export]
macro_rules! create_array {
    // `@from` arms: type name -> concrete array type.
    (@from Boolean) => { $crate::BooleanArray };
    (@from Int8) => { $crate::Int8Array };
    (@from Int16) => { $crate::Int16Array };
    (@from Int32) => { $crate::Int32Array };
    (@from Int64) => { $crate::Int64Array };
    (@from UInt8) => { $crate::UInt8Array };
    (@from UInt16) => { $crate::UInt16Array };
    (@from UInt32) => { $crate::UInt32Array };
    (@from UInt64) => { $crate::UInt64Array };
    (@from Float16) => { $crate::Float16Array };
    (@from Float32) => { $crate::Float32Array };
    (@from Float64) => { $crate::Float64Array };
    (@from Utf8) => { $crate::StringArray };
    (@from Utf8View) => { $crate::StringViewArray };
    (@from LargeUtf8) => { $crate::LargeStringArray };
    (@from IntervalDayTime) => { $crate::IntervalDayTimeArray };
    (@from IntervalYearMonth) => { $crate::IntervalYearMonthArray };
    (@from Second) => { $crate::TimestampSecondArray };
    (@from Millisecond) => { $crate::TimestampMillisecondArray };
    (@from Microsecond) => { $crate::TimestampMicrosecondArray };
    (@from Nanosecond) => { $crate::TimestampNanosecondArray };
    (@from Second32) => { $crate::Time32SecondArray };
    (@from Millisecond32) => { $crate::Time32MillisecondArray };
    (@from Microsecond64) => { $crate::Time64MicrosecondArray };
    // The array type is `Time64NanosecondArray`; the previous spelling
    // (`Time64Nanosecond64Array`) named a type that does not exist.
    (@from Nanosecond64) => { $crate::Time64NanosecondArray };
    (@from DurationSecond) => { $crate::DurationSecondArray };
    (@from DurationMillisecond) => { $crate::DurationMillisecondArray };
    (@from DurationMicrosecond) => { $crate::DurationMicrosecondArray };
    (@from DurationNanosecond) => { $crate::DurationNanosecondArray };
    (@from Decimal32) => { $crate::Decimal32Array };
    (@from Decimal64) => { $crate::Decimal64Array };
    (@from Decimal128) => { $crate::Decimal128Array };
    (@from Decimal256) => { $crate::Decimal256Array };
    (@from TimestampSecond) => { $crate::TimestampSecondArray };
    (@from TimestampMillisecond) => { $crate::TimestampMillisecondArray };
    (@from TimestampMicrosecond) => { $crate::TimestampMicrosecondArray };
    (@from TimestampNanosecond) => { $crate::TimestampNanosecondArray };

    // Fallback: any other identifier is rejected at compile time.
    (@from $ty: ident) => {
        compile_error!(concat!("Unsupported data type: ", stringify!($ty)))
    };

    // Null arrays carry no values, only a length.
    (Null, $size: expr) => {
        std::sync::Arc::new($crate::NullArray::new($size))
    };

    // Binary arrays are built with `from_vec` rather than `From<Vec<_>>`.
    (Binary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::BinaryArray::from_vec(vec![$($values),*]))
    };

    (LargeBinary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::LargeBinaryArray::from_vec(vec![$($values),*]))
    };

    // Generic case: resolve the array type through `@from`, then convert.
    ($ty: tt, [$($values: expr),*]) => {
        std::sync::Arc::new(<$crate::create_array!(@from $ty)>::from(vec![$($values),*]))
    };
}
139
/// Creates a `Result<RecordBatch, ArrowError>` from a list of
/// `(name, Type, [values...])` column descriptions.
///
/// Each column becomes a nullable field named `name` with data type
/// `arrow_schema::DataType::Type`, and its array is built with
/// [`create_array!`]. The expansion refers to `arrow_schema` by path, so that
/// crate has to be nameable at the call site.
#[macro_export]
macro_rules! record_batch {
    ($(($name: expr, $type: ident, [$($values: expr),*])),*) => {
        {
            // One nullable field per column description.
            let schema_ref = std::sync::Arc::new(arrow_schema::Schema::new(vec![
                $(
                    arrow_schema::Field::new($name, arrow_schema::DataType::$type, true),
                )*
            ]));

            // The block evaluates directly to the `try_new` result.
            $crate::RecordBatch::try_new(
                schema_ref,
                vec![$(
                    $crate::create_array!($type, [$($values),*]),
                )*]
            )
        }
    }
}
177
/// A two-dimensional batch of column-oriented data with a defined schema.
///
/// A batch pairs a schema with one array per schema field. The validating
/// constructors check that the number of columns matches the number of
/// fields, that column types match field types, and that all columns have the
/// same number of rows.
#[derive(Clone, Debug, PartialEq)]
pub struct RecordBatch {
    /// Schema describing the columns of this batch.
    schema: SchemaRef,
    /// Column arrays, in the same order as the schema's fields.
    columns: Vec<Arc<dyn Array>>,

    /// Number of rows in this batch.
    ///
    /// Stored separately from the columns so that a batch with zero columns
    /// can still carry a row count.
    row_count: usize,
}
211
212impl RecordBatch {
213 pub fn try_new(schema: SchemaRef, columns: Vec<ArrayRef>) -> Result<Self, ArrowError> {
242 let options = RecordBatchOptions::new();
243 Self::try_new_impl(schema, columns, &options)
244 }
245
246 pub unsafe fn new_unchecked(
262 schema: SchemaRef,
263 columns: Vec<Arc<dyn Array>>,
264 row_count: usize,
265 ) -> Self {
266 Self {
267 schema,
268 columns,
269 row_count,
270 }
271 }
272
273 pub fn try_new_with_options(
278 schema: SchemaRef,
279 columns: Vec<ArrayRef>,
280 options: &RecordBatchOptions,
281 ) -> Result<Self, ArrowError> {
282 Self::try_new_impl(schema, columns, options)
283 }
284
285 pub fn new_empty(schema: SchemaRef) -> Self {
287 let columns = schema
288 .fields()
289 .iter()
290 .map(|field| new_empty_array(field.data_type()))
291 .collect();
292
293 RecordBatch {
294 schema,
295 columns,
296 row_count: 0,
297 }
298 }
299
300 fn try_new_impl(
303 schema: SchemaRef,
304 columns: Vec<ArrayRef>,
305 options: &RecordBatchOptions,
306 ) -> Result<Self, ArrowError> {
307 if schema.fields().len() != columns.len() {
309 return Err(ArrowError::InvalidArgumentError(format!(
310 "number of columns({}) must match number of fields({}) in schema",
311 columns.len(),
312 schema.fields().len(),
313 )));
314 }
315
316 let row_count = options
317 .row_count
318 .or_else(|| columns.first().map(|col| col.len()))
319 .ok_or_else(|| {
320 ArrowError::InvalidArgumentError(
321 "must either specify a row count or at least one column".to_string(),
322 )
323 })?;
324
325 for (c, f) in columns.iter().zip(&schema.fields) {
326 if !f.is_nullable() && c.null_count() > 0 {
327 return Err(ArrowError::InvalidArgumentError(format!(
328 "Column '{}' is declared as non-nullable but contains null values",
329 f.name()
330 )));
331 }
332 }
333
334 if columns.iter().any(|c| c.len() != row_count) {
336 let err = match options.row_count {
337 Some(_) => "all columns in a record batch must have the specified row count",
338 None => "all columns in a record batch must have the same length",
339 };
340 return Err(ArrowError::InvalidArgumentError(err.to_string()));
341 }
342
343 let type_not_match = if options.match_field_names {
346 |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| col_type != field_type
347 } else {
348 |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| {
349 !col_type.equals_datatype(field_type)
350 }
351 };
352
353 let not_match = columns
355 .iter()
356 .zip(schema.fields().iter())
357 .map(|(col, field)| (col.data_type(), field.data_type()))
358 .enumerate()
359 .find(type_not_match);
360
361 if let Some((i, (col_type, field_type))) = not_match {
362 return Err(ArrowError::InvalidArgumentError(format!(
363 "column types must match schema types, expected {field_type:?} but found {col_type:?} at column index {i}")));
364 }
365
366 Ok(RecordBatch {
367 schema,
368 columns,
369 row_count,
370 })
371 }
372
373 pub fn into_parts(self) -> (SchemaRef, Vec<ArrayRef>, usize) {
375 (self.schema, self.columns, self.row_count)
376 }
377
378 pub fn with_schema(self, schema: SchemaRef) -> Result<Self, ArrowError> {
385 if !schema.contains(self.schema.as_ref()) {
386 return Err(ArrowError::SchemaError(format!(
387 "target schema is not superset of current schema target={schema} current={}",
388 self.schema
389 )));
390 }
391
392 Ok(Self {
393 schema,
394 columns: self.columns,
395 row_count: self.row_count,
396 })
397 }
398
399 pub fn schema(&self) -> SchemaRef {
401 self.schema.clone()
402 }
403
404 pub fn schema_ref(&self) -> &SchemaRef {
406 &self.schema
407 }
408
409 pub fn schema_metadata_mut(&mut self) -> &mut std::collections::HashMap<String, String> {
427 let schema = Arc::make_mut(&mut self.schema);
428 &mut schema.metadata
429 }
430
431 pub fn project(&self, indices: &[usize]) -> Result<RecordBatch, ArrowError> {
433 let projected_schema = self.schema.project(indices)?;
434 let batch_fields = indices
435 .iter()
436 .map(|f| {
437 self.columns.get(*f).cloned().ok_or_else(|| {
438 ArrowError::SchemaError(format!(
439 "project index {} out of bounds, max field {}",
440 f,
441 self.columns.len()
442 ))
443 })
444 })
445 .collect::<Result<Vec<_>, _>>()?;
446
447 RecordBatch::try_new_with_options(
448 SchemaRef::new(projected_schema),
449 batch_fields,
450 &RecordBatchOptions {
451 match_field_names: true,
452 row_count: Some(self.row_count),
453 },
454 )
455 }
456
457 pub fn normalize(&self, separator: &str, max_level: Option<usize>) -> Result<Self, ArrowError> {
517 let max_level = match max_level.unwrap_or(usize::MAX) {
518 0 => usize::MAX,
519 val => val,
520 };
521 let mut stack: Vec<(usize, &ArrayRef, Vec<&str>, &FieldRef)> = self
522 .columns
523 .iter()
524 .zip(self.schema.fields())
525 .rev()
526 .map(|(c, f)| {
527 let name_vec: Vec<&str> = vec![f.name()];
528 (0, c, name_vec, f)
529 })
530 .collect();
531 let mut columns: Vec<ArrayRef> = Vec::new();
532 let mut fields: Vec<FieldRef> = Vec::new();
533
534 while let Some((depth, c, name, field_ref)) = stack.pop() {
535 match field_ref.data_type() {
536 DataType::Struct(ff) if depth < max_level => {
537 for (cff, fff) in c.as_struct().columns().iter().zip(ff.into_iter()).rev() {
539 let mut name = name.clone();
540 name.push(separator);
541 name.push(fff.name());
542 stack.push((depth + 1, cff, name, fff))
543 }
544 }
545 _ => {
546 let updated_field = Field::new(
547 name.concat(),
548 field_ref.data_type().clone(),
549 field_ref.is_nullable(),
550 );
551 columns.push(c.clone());
552 fields.push(Arc::new(updated_field));
553 }
554 }
555 }
556 RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
557 }
558
559 pub fn num_columns(&self) -> usize {
578 self.columns.len()
579 }
580
581 pub fn num_rows(&self) -> usize {
600 self.row_count
601 }
602
603 pub fn column(&self, index: usize) -> &ArrayRef {
609 &self.columns[index]
610 }
611
612 pub fn column_by_name(&self, name: &str) -> Option<&ArrayRef> {
614 self.schema()
615 .column_with_name(name)
616 .map(|(index, _)| &self.columns[index])
617 }
618
619 pub fn columns(&self) -> &[ArrayRef] {
621 &self.columns[..]
622 }
623
624 pub fn remove_column(&mut self, index: usize) -> ArrayRef {
652 let mut builder = SchemaBuilder::from(self.schema.as_ref());
653 builder.remove(index);
654 self.schema = Arc::new(builder.finish());
655 self.columns.remove(index)
656 }
657
658 pub fn slice(&self, offset: usize, length: usize) -> RecordBatch {
665 assert!((offset + length) <= self.num_rows());
666
667 let columns = self
668 .columns()
669 .iter()
670 .map(|column| column.slice(offset, length))
671 .collect();
672
673 Self {
674 schema: self.schema.clone(),
675 columns,
676 row_count: length,
677 }
678 }
679
680 pub fn try_from_iter<I, F>(value: I) -> Result<Self, ArrowError>
717 where
718 I: IntoIterator<Item = (F, ArrayRef)>,
719 F: AsRef<str>,
720 {
721 let iter = value.into_iter().map(|(field_name, array)| {
725 let nullable = array.null_count() > 0;
726 (field_name, array, nullable)
727 });
728
729 Self::try_from_iter_with_nullable(iter)
730 }
731
732 pub fn try_from_iter_with_nullable<I, F>(value: I) -> Result<Self, ArrowError>
754 where
755 I: IntoIterator<Item = (F, ArrayRef, bool)>,
756 F: AsRef<str>,
757 {
758 let iter = value.into_iter();
759 let capacity = iter.size_hint().0;
760 let mut schema = SchemaBuilder::with_capacity(capacity);
761 let mut columns = Vec::with_capacity(capacity);
762
763 for (field_name, array, nullable) in iter {
764 let field_name = field_name.as_ref();
765 schema.push(Field::new(field_name, array.data_type().clone(), nullable));
766 columns.push(array);
767 }
768
769 let schema = Arc::new(schema.finish());
770 RecordBatch::try_new(schema, columns)
771 }
772
773 pub fn get_array_memory_size(&self) -> usize {
780 self.columns()
781 .iter()
782 .map(|array| array.get_array_memory_size())
783 .sum()
784 }
785}
786
/// Options that control how a [`RecordBatch`] is validated on construction.
#[derive(Debug)]
#[non_exhaustive]
pub struct RecordBatchOptions {
    /// When `true`, column and field data types are compared with `==`;
    /// when `false`, they are compared with `DataType::equals_datatype`.
    pub match_field_names: bool,

    /// Explicit row count. When `None`, the row count is taken from the first
    /// column, so at least one column is then required.
    pub row_count: Option<usize>,
}
797
798impl RecordBatchOptions {
799 pub fn new() -> Self {
801 Self {
802 match_field_names: true,
803 row_count: None,
804 }
805 }
806 pub fn with_row_count(mut self, row_count: Option<usize>) -> Self {
808 self.row_count = row_count;
809 self
810 }
811 pub fn with_match_field_names(mut self, match_field_names: bool) -> Self {
813 self.match_field_names = match_field_names;
814 self
815 }
816}
817impl Default for RecordBatchOptions {
818 fn default() -> Self {
819 Self::new()
820 }
821}
822impl From<StructArray> for RecordBatch {
823 fn from(value: StructArray) -> Self {
824 let row_count = value.len();
825 let (fields, columns, nulls) = value.into_parts();
826 assert_eq!(
827 nulls.map(|n| n.null_count()).unwrap_or_default(),
828 0,
829 "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"
830 );
831
832 RecordBatch {
833 schema: Arc::new(Schema::new(fields)),
834 row_count,
835 columns,
836 }
837 }
838}
839
840impl From<&StructArray> for RecordBatch {
841 fn from(struct_array: &StructArray) -> Self {
842 struct_array.clone().into()
843 }
844}
845
846impl Index<&str> for RecordBatch {
847 type Output = ArrayRef;
848
849 fn index(&self, name: &str) -> &Self::Output {
855 self.column_by_name(name).unwrap()
856 }
857}
858
/// Adapter that turns any iterable of `Result<RecordBatch, ArrowError>` plus a
/// schema into a [`RecordBatchReader`].
///
/// NOTE(review): the schema is stored as given and is not checked against the
/// batches the iterator yields.
pub struct RecordBatchIterator<I>
where
    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
{
    /// The underlying iterator producing the batches.
    inner: I::IntoIter,
    /// Schema reported through `RecordBatchReader::schema`.
    inner_schema: SchemaRef,
}
891
892impl<I> RecordBatchIterator<I>
893where
894 I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
895{
896 pub fn new(iter: I, schema: SchemaRef) -> Self {
900 Self {
901 inner: iter.into_iter(),
902 inner_schema: schema,
903 }
904 }
905}
906
907impl<I> Iterator for RecordBatchIterator<I>
908where
909 I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
910{
911 type Item = I::Item;
912
913 fn next(&mut self) -> Option<Self::Item> {
914 self.inner.next()
915 }
916
917 fn size_hint(&self) -> (usize, Option<usize>) {
918 self.inner.size_hint()
919 }
920}
921
922impl<I> RecordBatchReader for RecordBatchIterator<I>
923where
924 I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
925{
926 fn schema(&self) -> SchemaRef {
927 self.inner_schema.clone()
928 }
929}
930
931#[cfg(test)]
932mod tests {
933 use super::*;
934 use crate::{
935 BooleanArray, Int32Array, Int64Array, Int8Array, ListArray, StringArray, StringViewArray,
936 };
937 use arrow_buffer::{Buffer, ToByteSlice};
938 use arrow_data::{ArrayData, ArrayDataBuilder};
939 use arrow_schema::Fields;
940 use std::collections::HashMap;
941
942 #[test]
943 fn create_record_batch() {
944 let schema = Schema::new(vec![
945 Field::new("a", DataType::Int32, false),
946 Field::new("b", DataType::Utf8, false),
947 ]);
948
949 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
950 let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);
951
952 let record_batch =
953 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
954 check_batch(record_batch, 5)
955 }
956
957 #[test]
958 fn create_string_view_record_batch() {
959 let schema = Schema::new(vec![
960 Field::new("a", DataType::Int32, false),
961 Field::new("b", DataType::Utf8View, false),
962 ]);
963
964 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
965 let b = StringViewArray::from(vec!["a", "b", "c", "d", "e"]);
966
967 let record_batch =
968 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
969
970 assert_eq!(5, record_batch.num_rows());
971 assert_eq!(2, record_batch.num_columns());
972 assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
973 assert_eq!(
974 &DataType::Utf8View,
975 record_batch.schema().field(1).data_type()
976 );
977 assert_eq!(5, record_batch.column(0).len());
978 assert_eq!(5, record_batch.column(1).len());
979 }
980
981 #[test]
982 fn byte_size_should_not_regress() {
983 let schema = Schema::new(vec![
984 Field::new("a", DataType::Int32, false),
985 Field::new("b", DataType::Utf8, false),
986 ]);
987
988 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
989 let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);
990
991 let record_batch =
992 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
993 assert_eq!(record_batch.get_array_memory_size(), 364);
994 }
995
996 fn check_batch(record_batch: RecordBatch, num_rows: usize) {
997 assert_eq!(num_rows, record_batch.num_rows());
998 assert_eq!(2, record_batch.num_columns());
999 assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
1000 assert_eq!(&DataType::Utf8, record_batch.schema().field(1).data_type());
1001 assert_eq!(num_rows, record_batch.column(0).len());
1002 assert_eq!(num_rows, record_batch.column(1).len());
1003 }
1004
1005 #[test]
1006 #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
1007 fn create_record_batch_slice() {
1008 let schema = Schema::new(vec![
1009 Field::new("a", DataType::Int32, false),
1010 Field::new("b", DataType::Utf8, false),
1011 ]);
1012 let expected_schema = schema.clone();
1013
1014 let a = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
1015 let b = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "h", "i"]);
1016
1017 let record_batch =
1018 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
1019
1020 let offset = 2;
1021 let length = 5;
1022 let record_batch_slice = record_batch.slice(offset, length);
1023
1024 assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
1025 check_batch(record_batch_slice, 5);
1026
1027 let offset = 2;
1028 let length = 0;
1029 let record_batch_slice = record_batch.slice(offset, length);
1030
1031 assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
1032 check_batch(record_batch_slice, 0);
1033
1034 let offset = 2;
1035 let length = 10;
1036 let _record_batch_slice = record_batch.slice(offset, length);
1037 }
1038
1039 #[test]
1040 #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
1041 fn create_record_batch_slice_empty_batch() {
1042 let schema = Schema::empty();
1043
1044 let record_batch = RecordBatch::new_empty(Arc::new(schema));
1045
1046 let offset = 0;
1047 let length = 0;
1048 let record_batch_slice = record_batch.slice(offset, length);
1049 assert_eq!(0, record_batch_slice.schema().fields().len());
1050
1051 let offset = 1;
1052 let length = 2;
1053 let _record_batch_slice = record_batch.slice(offset, length);
1054 }
1055
1056 #[test]
1057 fn create_record_batch_try_from_iter() {
1058 let a: ArrayRef = Arc::new(Int32Array::from(vec![
1059 Some(1),
1060 Some(2),
1061 None,
1062 Some(4),
1063 Some(5),
1064 ]));
1065 let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
1066
1067 let record_batch =
1068 RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).expect("valid conversion");
1069
1070 let expected_schema = Schema::new(vec![
1071 Field::new("a", DataType::Int32, true),
1072 Field::new("b", DataType::Utf8, false),
1073 ]);
1074 assert_eq!(record_batch.schema().as_ref(), &expected_schema);
1075 check_batch(record_batch, 5);
1076 }
1077
1078 #[test]
1079 fn create_record_batch_try_from_iter_with_nullable() {
1080 let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
1081 let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
1082
1083 let record_batch =
1085 RecordBatch::try_from_iter_with_nullable(vec![("a", a, false), ("b", b, true)])
1086 .expect("valid conversion");
1087
1088 let expected_schema = Schema::new(vec![
1089 Field::new("a", DataType::Int32, false),
1090 Field::new("b", DataType::Utf8, true),
1091 ]);
1092 assert_eq!(record_batch.schema().as_ref(), &expected_schema);
1093 check_batch(record_batch, 5);
1094 }
1095
1096 #[test]
1097 fn create_record_batch_schema_mismatch() {
1098 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1099
1100 let a = Int64Array::from(vec![1, 2, 3, 4, 5]);
1101
1102 let err = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap_err();
1103 assert_eq!(err.to_string(), "Invalid argument error: column types must match schema types, expected Int32 but found Int64 at column index 0");
1104 }
1105
    /// Nested field-name differences are rejected by default and accepted once
    /// `match_field_names` is disabled.
    #[test]
    fn create_record_batch_field_name_mismatch() {
        // Schema: a struct "a" with children "a1" (Int32) and "a2" (list of Int8).
        let fields = vec![
            Field::new("a1", DataType::Int32, false),
            Field::new_list("a2", Field::new_list_field(DataType::Int8, false), false),
        ];
        let schema = Arc::new(Schema::new(vec![Field::new_struct("a", fields, true)]));

        let a1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
        let a2_child = Int8Array::from(vec![1, 2, 3, 4]);
        // Hand-built list array whose element field is named "array".
        // NOTE(review): presumably this differs from the element name chosen by
        // `Field::new_list_field` in the schema above; confirm.
        let a2 = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new(
            "array",
            DataType::Int8,
            false,
        ))))
        .add_child_data(a2_child.into_data())
        .len(2)
        // Offsets 0, 3, 4: two lists holding three and one element.
        .add_buffer(Buffer::from([0i32, 3, 4].to_byte_slice()))
        .build()
        .unwrap();
        let a2: ArrayRef = Arc::new(ListArray::from(a2));
        // The struct's first child is named "aa1", while the schema says "a1".
        let a = ArrayDataBuilder::new(DataType::Struct(Fields::from(vec![
            Field::new("aa1", DataType::Int32, false),
            Field::new("a2", a2.data_type().clone(), false),
        ])))
        .add_child_data(a1.into_data())
        .add_child_data(a2.into_data())
        .len(2)
        .build()
        .unwrap();
        let a: ArrayRef = Arc::new(StructArray::from(a));

        // Default options compare data types with `==`, so creation fails.
        let batch = RecordBatch::try_new(schema.clone(), vec![a.clone()]);
        assert!(batch.is_err());

        // With `match_field_names` disabled the same inputs are accepted.
        let options = RecordBatchOptions {
            match_field_names: false,
            row_count: None,
        };
        let batch = RecordBatch::try_new_with_options(schema, vec![a], &options);
        assert!(batch.is_ok());
    }
1150
1151 #[test]
1152 fn create_record_batch_record_mismatch() {
1153 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1154
1155 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1156 let b = Int32Array::from(vec![1, 2, 3, 4, 5]);
1157
1158 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]);
1159 assert!(batch.is_err());
1160 }
1161
1162 #[test]
1163 fn create_record_batch_from_struct_array() {
1164 let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
1165 let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
1166 let struct_array = StructArray::from(vec![
1167 (
1168 Arc::new(Field::new("b", DataType::Boolean, false)),
1169 boolean.clone() as ArrayRef,
1170 ),
1171 (
1172 Arc::new(Field::new("c", DataType::Int32, false)),
1173 int.clone() as ArrayRef,
1174 ),
1175 ]);
1176
1177 let batch = RecordBatch::from(&struct_array);
1178 assert_eq!(2, batch.num_columns());
1179 assert_eq!(4, batch.num_rows());
1180 assert_eq!(
1181 struct_array.data_type(),
1182 &DataType::Struct(batch.schema().fields().clone())
1183 );
1184 assert_eq!(batch.column(0).as_ref(), boolean.as_ref());
1185 assert_eq!(batch.column(1).as_ref(), int.as_ref());
1186 }
1187
1188 #[test]
1189 fn record_batch_equality() {
1190 let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1191 let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1192 let schema1 = Schema::new(vec![
1193 Field::new("id", DataType::Int32, false),
1194 Field::new("val", DataType::Int32, false),
1195 ]);
1196
1197 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1198 let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1199 let schema2 = Schema::new(vec![
1200 Field::new("id", DataType::Int32, false),
1201 Field::new("val", DataType::Int32, false),
1202 ]);
1203
1204 let batch1 = RecordBatch::try_new(
1205 Arc::new(schema1),
1206 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1207 )
1208 .unwrap();
1209
1210 let batch2 = RecordBatch::try_new(
1211 Arc::new(schema2),
1212 vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1213 )
1214 .unwrap();
1215
1216 assert_eq!(batch1, batch2);
1217 }
1218
1219 #[test]
1221 fn record_batch_index_access() {
1222 let id_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
1223 let val_arr = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
1224 let schema1 = Schema::new(vec![
1225 Field::new("id", DataType::Int32, false),
1226 Field::new("val", DataType::Int32, false),
1227 ]);
1228 let record_batch =
1229 RecordBatch::try_new(Arc::new(schema1), vec![id_arr.clone(), val_arr.clone()]).unwrap();
1230
1231 assert_eq!(record_batch["id"].as_ref(), id_arr.as_ref());
1232 assert_eq!(record_batch["val"].as_ref(), val_arr.as_ref());
1233 }
1234
1235 #[test]
1236 fn record_batch_vals_ne() {
1237 let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1238 let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1239 let schema1 = Schema::new(vec![
1240 Field::new("id", DataType::Int32, false),
1241 Field::new("val", DataType::Int32, false),
1242 ]);
1243
1244 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1245 let val_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1246 let schema2 = Schema::new(vec![
1247 Field::new("id", DataType::Int32, false),
1248 Field::new("val", DataType::Int32, false),
1249 ]);
1250
1251 let batch1 = RecordBatch::try_new(
1252 Arc::new(schema1),
1253 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1254 )
1255 .unwrap();
1256
1257 let batch2 = RecordBatch::try_new(
1258 Arc::new(schema2),
1259 vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1260 )
1261 .unwrap();
1262
1263 assert_ne!(batch1, batch2);
1264 }
1265
1266 #[test]
1267 fn record_batch_column_names_ne() {
1268 let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1269 let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1270 let schema1 = Schema::new(vec![
1271 Field::new("id", DataType::Int32, false),
1272 Field::new("val", DataType::Int32, false),
1273 ]);
1274
1275 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1276 let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1277 let schema2 = Schema::new(vec![
1278 Field::new("id", DataType::Int32, false),
1279 Field::new("num", DataType::Int32, false),
1280 ]);
1281
1282 let batch1 = RecordBatch::try_new(
1283 Arc::new(schema1),
1284 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1285 )
1286 .unwrap();
1287
1288 let batch2 = RecordBatch::try_new(
1289 Arc::new(schema2),
1290 vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1291 )
1292 .unwrap();
1293
1294 assert_ne!(batch1, batch2);
1295 }
1296
1297 #[test]
1298 fn record_batch_column_number_ne() {
1299 let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1300 let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1301 let schema1 = Schema::new(vec![
1302 Field::new("id", DataType::Int32, false),
1303 Field::new("val", DataType::Int32, false),
1304 ]);
1305
1306 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1307 let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1308 let num_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1309 let schema2 = Schema::new(vec![
1310 Field::new("id", DataType::Int32, false),
1311 Field::new("val", DataType::Int32, false),
1312 Field::new("num", DataType::Int32, false),
1313 ]);
1314
1315 let batch1 = RecordBatch::try_new(
1316 Arc::new(schema1),
1317 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1318 )
1319 .unwrap();
1320
1321 let batch2 = RecordBatch::try_new(
1322 Arc::new(schema2),
1323 vec![Arc::new(id_arr2), Arc::new(val_arr2), Arc::new(num_arr2)],
1324 )
1325 .unwrap();
1326
1327 assert_ne!(batch1, batch2);
1328 }
1329
1330 #[test]
1331 fn record_batch_row_count_ne() {
1332 let id_arr1 = Int32Array::from(vec![1, 2, 3]);
1333 let val_arr1 = Int32Array::from(vec![5, 6, 7]);
1334 let schema1 = Schema::new(vec![
1335 Field::new("id", DataType::Int32, false),
1336 Field::new("val", DataType::Int32, false),
1337 ]);
1338
1339 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1340 let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1341 let schema2 = Schema::new(vec![
1342 Field::new("id", DataType::Int32, false),
1343 Field::new("num", DataType::Int32, false),
1344 ]);
1345
1346 let batch1 = RecordBatch::try_new(
1347 Arc::new(schema1),
1348 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1349 )
1350 .unwrap();
1351
1352 let batch2 = RecordBatch::try_new(
1353 Arc::new(schema2),
1354 vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1355 )
1356 .unwrap();
1357
1358 assert_ne!(batch1, batch2);
1359 }
1360
    /// Normalizing a batch with one struct column and one plain column
    /// flattens the struct's children into "a.<child>" columns; `Some(0)` and
    /// `None` give the same result here.
    #[test]
    fn normalize_simple() {
        let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
        let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
        let year: ArrayRef = Arc::new(Int64Array::from(vec![None, Some(2022)]));

        let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
        let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
        let year_field = Arc::new(Field::new("year", DataType::Int64, true));

        // Struct column "a" wrapping the three arrays above.
        let a = Arc::new(StructArray::from(vec![
            (animals_field.clone(), Arc::new(animals.clone()) as ArrayRef),
            (n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef),
            (year_field.clone(), Arc::new(year.clone()) as ArrayRef),
        ]));

        let month = Arc::new(Int64Array::from(vec![Some(4), Some(6)]));

        let schema = Schema::new(vec![
            Field::new(
                "a",
                DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
                false,
            ),
            Field::new("month", DataType::Int64, true),
        ]);

        // `Some(0)` is treated as "no depth limit".
        let normalized =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![a.clone(), month.clone()])
                .expect("valid conversion")
                .normalize(".", Some(0))
                .expect("valid normalization");

        let expected = RecordBatch::try_from_iter_with_nullable(vec![
            ("a.animals", animals.clone(), true),
            ("a.n_legs", n_legs.clone(), true),
            ("a.year", year.clone(), true),
            ("month", month.clone(), true),
        ])
        .expect("valid conversion");

        assert_eq!(expected, normalized);

        // `None` also means "no depth limit", so the result is unchanged.
        let normalized = RecordBatch::try_new(Arc::new(schema), vec![a, month.clone()])
            .expect("valid conversion")
            .normalize(".", None)
            .expect("valid normalization");

        assert_eq!(expected, normalized);
    }
1412
    /// Normalizing a doubly nested struct: with `Some(1)` only the outer level
    /// is flattened, with `None` every level is flattened.
    #[test]
    fn normalize_nested() {
        // Leaf fields.
        let a = Arc::new(Field::new("a", DataType::Int64, true));
        let b = Arc::new(Field::new("b", DataType::Int64, false));
        let c = Arc::new(Field::new("c", DataType::Int64, true));

        // Two inner structs "1" and "2", each with leaves a, b and c.
        let one = Arc::new(Field::new(
            "1",
            DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
            false,
        ));
        let two = Arc::new(Field::new(
            "2",
            DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
            true,
        ));

        // Outer struct "!" containing both inner structs.
        let exclamation = Arc::new(Field::new(
            "!",
            DataType::Struct(Fields::from(vec![one.clone(), two.clone()])),
            false,
        ));

        let schema = Schema::new(vec![exclamation.clone()]);

        // Leaf data.
        let a_field = Int64Array::from(vec![Some(0), Some(1)]);
        let b_field = Int64Array::from(vec![Some(2), Some(3)]);
        let c_field = Int64Array::from(vec![None, Some(4)]);

        let one_field = StructArray::from(vec![
            (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
            (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
            (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
        ]);
        let two_field = StructArray::from(vec![
            (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
            (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
            (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
        ]);

        let exclamation_field = Arc::new(StructArray::from(vec![
            (one.clone(), Arc::new(one_field) as ArrayRef),
            (two.clone(), Arc::new(two_field) as ArrayRef),
        ]));

        // One level only: "!" is flattened, "!.1" and "!.2" stay structs.
        let normalized =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![exclamation_field.clone()])
                .expect("valid conversion")
                .normalize(".", Some(1))
                .expect("valid normalization");

        let expected = RecordBatch::try_from_iter_with_nullable(vec![
            (
                "!.1",
                Arc::new(StructArray::from(vec![
                    (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
                    (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
                    (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
                ])) as ArrayRef,
                false,
            ),
            (
                "!.2",
                Arc::new(StructArray::from(vec![
                    (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
                    (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
                    (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
                ])) as ArrayRef,
                true,
            ),
        ])
        .expect("valid conversion");

        assert_eq!(expected, normalized);

        // No depth limit: every leaf becomes a top-level column.
        let normalized = RecordBatch::try_new(Arc::new(schema), vec![exclamation_field])
            .expect("valid conversion")
            .normalize(".", None)
            .expect("valid normalization");

        let expected = RecordBatch::try_from_iter_with_nullable(vec![
            ("!.1.a", Arc::new(a_field.clone()) as ArrayRef, true),
            ("!.1.b", Arc::new(b_field.clone()) as ArrayRef, false),
            ("!.1.c", Arc::new(c_field.clone()) as ArrayRef, true),
            ("!.2.a", Arc::new(a_field.clone()) as ArrayRef, true),
            ("!.2.b", Arc::new(b_field.clone()) as ArrayRef, false),
            ("!.2.c", Arc::new(c_field.clone()) as ArrayRef, true),
        ])
        .expect("valid conversion");

        assert_eq!(expected, normalized);
    }
1509
// Normalizing a batch created with `RecordBatch::new_empty` must give the same
// batch as creating an empty batch from the already-normalized schema.
#[test]
fn normalize_empty() {
    // Child fields of the struct column "a".
    let struct_children = vec![
        Arc::new(Field::new("animals", DataType::Utf8, true)),
        Arc::new(Field::new("n_legs", DataType::Int64, true)),
        Arc::new(Field::new("year", DataType::Int64, true)),
    ];
    let struct_column = Field::new("a", DataType::Struct(Fields::from(struct_children)), false);
    let month_column = Field::new("month", DataType::Int64, true);
    let schema = Schema::new(vec![struct_column, month_column]);

    // Path 1: build the empty batch first, then normalize the batch.
    let normalized = {
        let empty_batch = RecordBatch::new_empty(Arc::new(schema.clone()));
        empty_batch
            .normalize(".", Some(0))
            .expect("valid normalization")
    };

    // Path 2: normalize the schema first, then build the empty batch.
    let normalized_schema = schema.normalize(".", Some(0)).expect("valid normalization");
    let expected = RecordBatch::new_empty(Arc::new(normalized_schema));

    assert_eq!(expected, normalized);
}
1535
// Projecting indices 0 and 2 out of the columns "a", "b", "c" must yield a
// batch equal to one built from just "a" and "c".
#[test]
fn project() {
    let ints: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
    let first_strings: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
    let second_strings: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));

    let all_columns = vec![
        ("a", ints.clone()),
        ("b", first_strings),
        ("c", second_strings.clone()),
    ];
    let record_batch = RecordBatch::try_from_iter(all_columns).expect("valid conversion");

    let kept_columns = vec![("a", ints), ("c", second_strings)];
    let expected = RecordBatch::try_from_iter(kept_columns).expect("valid conversion");

    let projected = record_batch.project(&[0, 2]).unwrap();
    assert_eq!(expected, projected);
}
1551
// Projecting onto an empty index list must compare equal to a column-less
// batch over an empty schema whose options carry a row count of 3.
#[test]
fn project_empty() {
    let strings: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));
    let record_batch =
        RecordBatch::try_from_iter(vec![("c", strings)]).expect("valid conversion");

    // The expected batch has no columns, so its row count is supplied through
    // the options rather than taken from an array.
    let options = RecordBatchOptions::new()
        .with_match_field_names(true)
        .with_row_count(Some(3));
    let expected =
        RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &options)
            .expect("valid conversion");

    let projected = record_batch.project(&[]).unwrap();
    assert_eq!(expected, projected);
}
1571
// Exercises batches that have an empty schema and no columns: construction,
// row counts, slicing and equality.
#[test]
fn test_no_column_record_batch() {
    let empty_schema = Arc::new(Schema::empty());

    // With neither a column nor a row count, construction is rejected.
    let message = RecordBatch::try_new(empty_schema.clone(), vec![])
        .unwrap_err()
        .to_string();
    assert!(message.contains("must either specify a row count or at least one column"));

    // Supplying a row count through the options makes the same call succeed.
    let ten_row_options = RecordBatchOptions::new().with_row_count(Some(10));
    let ten_rows =
        RecordBatch::try_new_with_options(empty_schema.clone(), vec![], &ten_row_options)
            .unwrap();
    assert_eq!(ten_rows.num_rows(), 10);

    // Slices report the requested length even though there are no columns.
    let five_rows = ten_rows.slice(2, 5);
    assert_eq!(five_rows.num_rows(), 5);

    let zero_rows = ten_rows.slice(5, 0);
    assert_eq!(zero_rows.num_rows(), 0);

    // Batches differing only in row count are unequal; a zero-length slice
    // equals a freshly created empty batch.
    assert_ne!(five_rows, zero_rows);
    assert_eq!(zero_rows, RecordBatch::new_empty(empty_schema));
}
1595
// Building a batch whose non-nullable column holds a null must fail with the
// exact error message asserted below.
#[test]
fn test_nulls_in_non_nullable_field() {
    let non_nullable = Field::new("a", DataType::Int32, false);
    let schema = Arc::new(Schema::new(vec![non_nullable]));
    let with_null = Int32Array::from(vec![Some(1), None]);

    let maybe_batch = RecordBatch::try_new(schema, vec![Arc::new(with_null)]);
    let error = maybe_batch.err().unwrap();

    assert_eq!(
        "Invalid argument error: Column 'a' is declared as non-nullable but contains null values",
        error.to_string()
    );
}
// The builder-style setters on `RecordBatchOptions` must store the values
// they are given.
#[test]
fn test_record_batch_options() {
    let RecordBatchOptions {
        match_field_names,
        row_count,
    } = RecordBatchOptions::new()
        .with_match_field_names(false)
        .with_row_count(Some(20));

    assert!(!match_field_names);
    assert_eq!(row_count.unwrap(), 20);
}
1613
// Converting a struct array produced by `ArrayData::new_null` into a
// `RecordBatch` is expected to panic with the message named in `should_panic`.
#[test]
#[should_panic(expected = "Cannot convert nullable StructArray to RecordBatch")]
fn test_from_struct() {
    // The only child field is declared non-nullable.
    let child = Field::new("foo", DataType::Int32, false);
    let struct_type = DataType::Struct(Fields::from(vec![child]));

    // Two-row struct array built through `ArrayData::new_null`.
    let null_struct = StructArray::from(ArrayData::new_null(&struct_type, 2));

    let _ = RecordBatch::from(null_struct);
}
1624
// Checks which schema replacements `RecordBatch::with_schema` accepts and
// which it rejects, for nullability and schema metadata.
#[test]
fn test_with_schema() {
    let required_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let nullable_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));

    let values = Int32Array::from(vec![1, 2, 3]);
    let batch =
        RecordBatch::try_new(required_schema.clone(), vec![Arc::new(values) as _]).unwrap();

    // Non-nullable -> nullable is accepted.
    let batch = batch.with_schema(nullable_schema.clone()).unwrap();

    // Nullable -> non-nullable is rejected.
    batch.clone().with_schema(required_schema).unwrap_err();

    // A schema that additionally carries metadata is accepted.
    let metadata = std::iter::once(("foo".to_string(), "bar".to_string())).collect();
    let metadata_schema = Schema::clone(&nullable_schema).with_metadata(metadata);
    let batch = batch.with_schema(Arc::new(metadata_schema)).unwrap();

    // Going back to the schema without that metadata is rejected.
    batch.with_schema(nullable_schema).unwrap_err();
}
1654
// A `Box<dyn RecordBatchReader + Send>` must itself be usable wherever an
// `impl RecordBatchReader` is required.
#[test]
fn test_boxed_reader() {
    // Helper that accepts any reader by value and reports the lower bound of
    // its iterator size hint.
    fn get_size(reader: impl RecordBatchReader) -> usize {
        let (lower_bound, _) = reader.size_hint();
        lower_bound
    }

    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

    // Reader over an empty iterator, then type-erased behind a box.
    let concrete_reader = RecordBatchIterator::new(std::iter::empty(), schema);
    let boxed_reader: Box<dyn RecordBatchReader + Send> = Box::new(concrete_reader);

    assert_eq!(get_size(boxed_reader), 0);
}
1672
// Removing a column from a batch must leave the schema-level metadata intact.
#[test]
fn test_remove_column_maintains_schema_metadata() {
    let ids = Int32Array::from(vec![1, 2, 3, 4, 5]);
    let flags = BooleanArray::from(vec![true, false, false, true, true]);

    // Schema with a single metadata entry: "foo" -> "bar".
    let metadata = HashMap::from([("foo".to_string(), "bar".to_string())]);
    let fields = vec![
        Field::new("id", DataType::Int32, false),
        Field::new("bool", DataType::Boolean, false),
    ];
    let schema = Schema::new(fields).with_metadata(metadata);

    let mut batch =
        RecordBatch::try_new(Arc::new(schema), vec![Arc::new(ids), Arc::new(flags)]).unwrap();

    let _removed_column = batch.remove_column(0);

    // The metadata entry must still be the only one and keep its value.
    let schema_after = batch.schema();
    let remaining = schema_after.metadata();
    assert_eq!(remaining.len(), 1);
    assert_eq!(remaining.get("foo").unwrap().as_str(), "bar");
}
1699}