1use crate::cast::AsArray;
22use crate::{Array, ArrayRef, StructArray, new_empty_array};
23use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef};
24use std::ops::Index;
25use std::sync::Arc;
26
/// Trait for types that can read [`RecordBatch`]es.
///
/// A reader is an iterator of `Result<RecordBatch, ArrowError>` that also
/// exposes a schema.
pub trait RecordBatchReader: Iterator<Item = Result<RecordBatch, ArrowError>> {
    /// Returns the schema of this reader.
    ///
    /// NOTE(review): implementations presumably return the schema shared by
    /// every batch they yield — confirm against individual implementors.
    fn schema(&self) -> SchemaRef;
}
37
38impl<R: RecordBatchReader + ?Sized> RecordBatchReader for Box<R> {
39 fn schema(&self) -> SchemaRef {
40 self.as_ref().schema()
41 }
42}
43
/// Trait for types that can write [`RecordBatch`]es.
pub trait RecordBatchWriter {
    /// Writes a single batch, returning an error if the write fails.
    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError>;

    /// Consumes the writer (it takes `self` by value), so no further batches
    /// can be written afterwards.
    ///
    /// NOTE(review): implementations presumably flush/finalize any pending
    /// output here — confirm per implementor.
    fn close(self) -> Result<(), ArrowError>;
}
52
/// Creates an `Arc`-wrapped typed array from a literal list of values,
/// intended for quick testing and development.
///
/// * `create_array!(Null, n)` builds a `NullArray` of length `n`.
/// * `create_array!(Binary, [..])` and `create_array!(LargeBinary, [..])`
///   build the array with `from_vec`.
/// * `create_array!(Type, [..])` builds the array type mapped from `Type`
///   (for example `Int32` -> `Int32Array`, `Utf8` -> `StringArray`) via its
///   `From<Vec<_>>` implementation.
///
/// An unsupported type name produces a `compile_error!`. A trailing comma
/// after the last value is accepted.
#[macro_export]
macro_rules! create_array {
    // Internal rules: map a type name to the concrete array type.
    (@from Boolean) => { $crate::BooleanArray };
    (@from Int8) => { $crate::Int8Array };
    (@from Int16) => { $crate::Int16Array };
    (@from Int32) => { $crate::Int32Array };
    (@from Int64) => { $crate::Int64Array };
    (@from UInt8) => { $crate::UInt8Array };
    (@from UInt16) => { $crate::UInt16Array };
    (@from UInt32) => { $crate::UInt32Array };
    (@from UInt64) => { $crate::UInt64Array };
    (@from Float16) => { $crate::Float16Array };
    (@from Float32) => { $crate::Float32Array };
    (@from Float64) => { $crate::Float64Array };
    (@from Utf8) => { $crate::StringArray };
    (@from Utf8View) => { $crate::StringViewArray };
    (@from LargeUtf8) => { $crate::LargeStringArray };
    (@from IntervalDayTime) => { $crate::IntervalDayTimeArray };
    (@from IntervalYearMonth) => { $crate::IntervalYearMonthArray };
    (@from Second) => { $crate::TimestampSecondArray };
    (@from Millisecond) => { $crate::TimestampMillisecondArray };
    (@from Microsecond) => { $crate::TimestampMicrosecondArray };
    (@from Nanosecond) => { $crate::TimestampNanosecondArray };
    (@from Second32) => { $crate::Time32SecondArray };
    (@from Millisecond32) => { $crate::Time32MillisecondArray };
    (@from Microsecond64) => { $crate::Time64MicrosecondArray };
    // Was `Time64Nanosecond64Array`, a type that does not exist.
    (@from Nanosecond64) => { $crate::Time64NanosecondArray };
    (@from DurationSecond) => { $crate::DurationSecondArray };
    (@from DurationMillisecond) => { $crate::DurationMillisecondArray };
    (@from DurationMicrosecond) => { $crate::DurationMicrosecondArray };
    (@from DurationNanosecond) => { $crate::DurationNanosecondArray };
    (@from Decimal32) => { $crate::Decimal32Array };
    (@from Decimal64) => { $crate::Decimal64Array };
    (@from Decimal128) => { $crate::Decimal128Array };
    (@from Decimal256) => { $crate::Decimal256Array };
    (@from TimestampSecond) => { $crate::TimestampSecondArray };
    (@from TimestampMillisecond) => { $crate::TimestampMillisecondArray };
    (@from TimestampMicrosecond) => { $crate::TimestampMicrosecondArray };
    (@from TimestampNanosecond) => { $crate::TimestampNanosecondArray };

    // Any other type name is rejected at compile time.
    (@from $ty: ident) => {
        compile_error!(concat!("Unsupported data type: ", stringify!($ty)))
    };

    (Null, $size: expr) => {
        std::sync::Arc::new($crate::NullArray::new($size))
    };

    (Binary, [$($values: expr),* $(,)?]) => {
        std::sync::Arc::new($crate::BinaryArray::from_vec(vec![$($values),*]))
    };

    (LargeBinary, [$($values: expr),* $(,)?]) => {
        std::sync::Arc::new($crate::LargeBinaryArray::from_vec(vec![$($values),*]))
    };

    ($ty: tt, [$($values: expr),* $(,)?]) => {
        std::sync::Arc::new(<$crate::create_array!(@from $ty)>::from(vec![$($values),*]))
    };
}
139
/// Creates a [`RecordBatch`] from literal column definitions, intended for
/// quick testing and development.
///
/// Each column is written as `(name, DataType variant, [values...])`. Every
/// field is created as nullable, and each column array is built with
/// `create_array!`. The macro evaluates to the `Result` returned by
/// [`RecordBatch::try_new`].
///
/// NOTE(review): the expansion names `arrow_schema` by path, so that crate must
/// be nameable at the call site.
#[macro_export]
macro_rules! record_batch {
    ($(($name: expr, $type: ident, [$($values: expr),*])),*) => {{
        let fields = vec![
            $(arrow_schema::Field::new($name, arrow_schema::DataType::$type, true)),*
        ];
        let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
        $crate::RecordBatch::try_new(
            schema,
            vec![$($crate::create_array!($type, [$($values),*])),*],
        )
    }};
}
177
/// A two-dimensional batch of column-oriented data with a defined [`Schema`].
///
/// Batches are normally built with [`RecordBatch::try_new`] or
/// [`RecordBatch::try_new_with_options`], which validate the columns against
/// the schema. The `unsafe` [`RecordBatch::new_unchecked`] skips that validation.
#[derive(Clone, Debug, PartialEq)]
pub struct RecordBatch {
    schema: SchemaRef,
    columns: Vec<Arc<dyn Array>>,

    /// The number of rows in this batch.
    ///
    /// Stored separately from the columns so that a batch with no columns can
    /// still report a row count (see `RecordBatchOptions::row_count`).
    row_count: usize,
}
211
212impl RecordBatch {
213 pub fn try_new(schema: SchemaRef, columns: Vec<ArrayRef>) -> Result<Self, ArrowError> {
242 let options = RecordBatchOptions::new();
243 Self::try_new_impl(schema, columns, &options)
244 }
245
246 pub unsafe fn new_unchecked(
262 schema: SchemaRef,
263 columns: Vec<Arc<dyn Array>>,
264 row_count: usize,
265 ) -> Self {
266 Self {
267 schema,
268 columns,
269 row_count,
270 }
271 }
272
273 pub fn try_new_with_options(
278 schema: SchemaRef,
279 columns: Vec<ArrayRef>,
280 options: &RecordBatchOptions,
281 ) -> Result<Self, ArrowError> {
282 Self::try_new_impl(schema, columns, options)
283 }
284
285 pub fn new_empty(schema: SchemaRef) -> Self {
287 let columns = schema
288 .fields()
289 .iter()
290 .map(|field| new_empty_array(field.data_type()))
291 .collect();
292
293 RecordBatch {
294 schema,
295 columns,
296 row_count: 0,
297 }
298 }
299
300 fn try_new_impl(
303 schema: SchemaRef,
304 columns: Vec<ArrayRef>,
305 options: &RecordBatchOptions,
306 ) -> Result<Self, ArrowError> {
307 if schema.fields().len() != columns.len() {
309 return Err(ArrowError::InvalidArgumentError(format!(
310 "number of columns({}) must match number of fields({}) in schema",
311 columns.len(),
312 schema.fields().len(),
313 )));
314 }
315
316 let row_count = options
317 .row_count
318 .or_else(|| columns.first().map(|col| col.len()))
319 .ok_or_else(|| {
320 ArrowError::InvalidArgumentError(
321 "must either specify a row count or at least one column".to_string(),
322 )
323 })?;
324
325 for (c, f) in columns.iter().zip(&schema.fields) {
326 if !f.is_nullable() && c.null_count() > 0 {
327 return Err(ArrowError::InvalidArgumentError(format!(
328 "Column '{}' is declared as non-nullable but contains null values",
329 f.name()
330 )));
331 }
332 }
333
334 if columns.iter().any(|c| c.len() != row_count) {
336 let err = match options.row_count {
337 Some(_) => "all columns in a record batch must have the specified row count",
338 None => "all columns in a record batch must have the same length",
339 };
340 return Err(ArrowError::InvalidArgumentError(err.to_string()));
341 }
342
343 let type_not_match = if options.match_field_names {
346 |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| col_type != field_type
347 } else {
348 |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| {
349 !col_type.equals_datatype(field_type)
350 }
351 };
352
353 let not_match = columns
355 .iter()
356 .zip(schema.fields().iter())
357 .map(|(col, field)| (col.data_type(), field.data_type()))
358 .enumerate()
359 .find(type_not_match);
360
361 if let Some((i, (col_type, field_type))) = not_match {
362 return Err(ArrowError::InvalidArgumentError(format!(
363 "column types must match schema types, expected {field_type} but found {col_type} at column index {i}"
364 )));
365 }
366
367 Ok(RecordBatch {
368 schema,
369 columns,
370 row_count,
371 })
372 }
373
374 pub fn into_parts(self) -> (SchemaRef, Vec<ArrayRef>, usize) {
376 (self.schema, self.columns, self.row_count)
377 }
378
379 pub fn with_schema(self, schema: SchemaRef) -> Result<Self, ArrowError> {
386 if !schema.contains(self.schema.as_ref()) {
387 return Err(ArrowError::SchemaError(format!(
388 "target schema is not superset of current schema target={schema} current={}",
389 self.schema
390 )));
391 }
392
393 Ok(Self {
394 schema,
395 columns: self.columns,
396 row_count: self.row_count,
397 })
398 }
399
400 pub fn schema(&self) -> SchemaRef {
402 self.schema.clone()
403 }
404
405 pub fn schema_ref(&self) -> &SchemaRef {
407 &self.schema
408 }
409
410 pub fn schema_metadata_mut(&mut self) -> &mut std::collections::HashMap<String, String> {
428 let schema = Arc::make_mut(&mut self.schema);
429 &mut schema.metadata
430 }
431
432 pub fn project(&self, indices: &[usize]) -> Result<RecordBatch, ArrowError> {
434 let projected_schema = self.schema.project(indices)?;
435 let batch_fields = indices
436 .iter()
437 .map(|f| {
438 self.columns.get(*f).cloned().ok_or_else(|| {
439 ArrowError::SchemaError(format!(
440 "project index {} out of bounds, max field {}",
441 f,
442 self.columns.len()
443 ))
444 })
445 })
446 .collect::<Result<Vec<_>, _>>()?;
447
448 RecordBatch::try_new_with_options(
449 SchemaRef::new(projected_schema),
450 batch_fields,
451 &RecordBatchOptions {
452 match_field_names: true,
453 row_count: Some(self.row_count),
454 },
455 )
456 }
457
458 pub fn normalize(&self, separator: &str, max_level: Option<usize>) -> Result<Self, ArrowError> {
518 let max_level = match max_level.unwrap_or(usize::MAX) {
519 0 => usize::MAX,
520 val => val,
521 };
522 let mut stack: Vec<(usize, &ArrayRef, Vec<&str>, &FieldRef)> = self
523 .columns
524 .iter()
525 .zip(self.schema.fields())
526 .rev()
527 .map(|(c, f)| {
528 let name_vec: Vec<&str> = vec![f.name()];
529 (0, c, name_vec, f)
530 })
531 .collect();
532 let mut columns: Vec<ArrayRef> = Vec::new();
533 let mut fields: Vec<FieldRef> = Vec::new();
534
535 while let Some((depth, c, name, field_ref)) = stack.pop() {
536 match field_ref.data_type() {
537 DataType::Struct(ff) if depth < max_level => {
538 for (cff, fff) in c.as_struct().columns().iter().zip(ff.into_iter()).rev() {
540 let mut name = name.clone();
541 name.push(separator);
542 name.push(fff.name());
543 stack.push((depth + 1, cff, name, fff))
544 }
545 }
546 _ => {
547 let updated_field = Field::new(
548 name.concat(),
549 field_ref.data_type().clone(),
550 field_ref.is_nullable(),
551 );
552 columns.push(c.clone());
553 fields.push(Arc::new(updated_field));
554 }
555 }
556 }
557 RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
558 }
559
560 pub fn num_columns(&self) -> usize {
579 self.columns.len()
580 }
581
582 pub fn num_rows(&self) -> usize {
601 self.row_count
602 }
603
604 pub fn column(&self, index: usize) -> &ArrayRef {
610 &self.columns[index]
611 }
612
613 pub fn column_by_name(&self, name: &str) -> Option<&ArrayRef> {
615 self.schema()
616 .column_with_name(name)
617 .map(|(index, _)| &self.columns[index])
618 }
619
620 pub fn columns(&self) -> &[ArrayRef] {
622 &self.columns[..]
623 }
624
625 pub fn remove_column(&mut self, index: usize) -> ArrayRef {
653 let mut builder = SchemaBuilder::from(self.schema.as_ref());
654 builder.remove(index);
655 self.schema = Arc::new(builder.finish());
656 self.columns.remove(index)
657 }
658
659 pub fn slice(&self, offset: usize, length: usize) -> RecordBatch {
666 assert!((offset + length) <= self.num_rows());
667
668 let columns = self
669 .columns()
670 .iter()
671 .map(|column| column.slice(offset, length))
672 .collect();
673
674 Self {
675 schema: self.schema.clone(),
676 columns,
677 row_count: length,
678 }
679 }
680
681 pub fn try_from_iter<I, F>(value: I) -> Result<Self, ArrowError>
718 where
719 I: IntoIterator<Item = (F, ArrayRef)>,
720 F: AsRef<str>,
721 {
722 let iter = value.into_iter().map(|(field_name, array)| {
726 let nullable = array.null_count() > 0;
727 (field_name, array, nullable)
728 });
729
730 Self::try_from_iter_with_nullable(iter)
731 }
732
733 pub fn try_from_iter_with_nullable<I, F>(value: I) -> Result<Self, ArrowError>
755 where
756 I: IntoIterator<Item = (F, ArrayRef, bool)>,
757 F: AsRef<str>,
758 {
759 let iter = value.into_iter();
760 let capacity = iter.size_hint().0;
761 let mut schema = SchemaBuilder::with_capacity(capacity);
762 let mut columns = Vec::with_capacity(capacity);
763
764 for (field_name, array, nullable) in iter {
765 let field_name = field_name.as_ref();
766 schema.push(Field::new(field_name, array.data_type().clone(), nullable));
767 columns.push(array);
768 }
769
770 let schema = Arc::new(schema.finish());
771 RecordBatch::try_new(schema, columns)
772 }
773
774 pub fn get_array_memory_size(&self) -> usize {
781 self.columns()
782 .iter()
783 .map(|array| array.get_array_memory_size())
784 .sum()
785 }
786}
787
/// Options that control how a [`RecordBatch`] is validated when created with
/// `RecordBatch::try_new_with_options`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct RecordBatchOptions {
    /// Controls how column types are compared with field types.
    ///
    /// When `true`, types must be identical, including nested field names.
    /// When `false`, they are compared with `DataType::equals_datatype`.
    pub match_field_names: bool,

    /// Explicit row count.
    ///
    /// Required to build a batch with no columns. When set, every column must
    /// have exactly this many rows.
    pub row_count: Option<usize>,
}

impl RecordBatchOptions {
    /// Creates options with `match_field_names = true` and no explicit row count.
    pub fn new() -> Self {
        Self {
            match_field_names: true,
            row_count: None,
        }
    }

    /// Sets the explicit row count. `None` means infer it from the first column.
    #[must_use]
    pub fn with_row_count(mut self, row_count: Option<usize>) -> Self {
        self.row_count = row_count;
        self
    }

    /// Sets whether nested field names must match exactly.
    #[must_use]
    pub fn with_match_field_names(mut self, match_field_names: bool) -> Self {
        self.match_field_names = match_field_names;
        self
    }
}

impl Default for RecordBatchOptions {
    /// Same as [`RecordBatchOptions::new`].
    fn default() -> Self {
        Self::new()
    }
}
823impl From<StructArray> for RecordBatch {
824 fn from(value: StructArray) -> Self {
825 let row_count = value.len();
826 let (fields, columns, nulls) = value.into_parts();
827 assert_eq!(
828 nulls.map(|n| n.null_count()).unwrap_or_default(),
829 0,
830 "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"
831 );
832
833 RecordBatch {
834 schema: Arc::new(Schema::new(fields)),
835 row_count,
836 columns,
837 }
838 }
839}
840
841impl From<&StructArray> for RecordBatch {
842 fn from(struct_array: &StructArray) -> Self {
843 struct_array.clone().into()
844 }
845}
846
847impl Index<&str> for RecordBatch {
848 type Output = ArrayRef;
849
850 fn index(&self, name: &str) -> &Self::Output {
856 self.column_by_name(name).unwrap()
857 }
858}
859
/// Generic implementation of [`RecordBatchReader`] that wraps an iterator of
/// `Result<RecordBatch, ArrowError>`.
///
/// The schema is supplied to [`RecordBatchIterator::new`] and returned as-is by
/// `schema()`. Batches yielded by the wrapped iterator are passed through
/// without being checked against it.
pub struct RecordBatchIterator<I>
where
    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
{
    // The wrapped iterator, produced by `IntoIterator::into_iter` in `new`.
    inner: I::IntoIter,
    // The schema reported by `RecordBatchReader::schema`.
    inner_schema: SchemaRef,
}
892
893impl<I> RecordBatchIterator<I>
894where
895 I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
896{
897 pub fn new(iter: I, schema: SchemaRef) -> Self {
901 Self {
902 inner: iter.into_iter(),
903 inner_schema: schema,
904 }
905 }
906}
907
908impl<I> Iterator for RecordBatchIterator<I>
909where
910 I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
911{
912 type Item = I::Item;
913
914 fn next(&mut self) -> Option<Self::Item> {
915 self.inner.next()
916 }
917
918 fn size_hint(&self) -> (usize, Option<usize>) {
919 self.inner.size_hint()
920 }
921}
922
923impl<I> RecordBatchReader for RecordBatchIterator<I>
924where
925 I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
926{
927 fn schema(&self) -> SchemaRef {
928 self.inner_schema.clone()
929 }
930}
931
932#[cfg(test)]
933mod tests {
934 use super::*;
935 use crate::{
936 BooleanArray, Int8Array, Int32Array, Int64Array, ListArray, StringArray, StringViewArray,
937 };
938 use arrow_buffer::{Buffer, ToByteSlice};
939 use arrow_data::{ArrayData, ArrayDataBuilder};
940 use arrow_schema::Fields;
941 use std::collections::HashMap;
942
943 #[test]
944 fn create_record_batch() {
945 let schema = Schema::new(vec![
946 Field::new("a", DataType::Int32, false),
947 Field::new("b", DataType::Utf8, false),
948 ]);
949
950 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
951 let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);
952
953 let record_batch =
954 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
955 check_batch(record_batch, 5)
956 }
957
958 #[test]
959 fn create_string_view_record_batch() {
960 let schema = Schema::new(vec![
961 Field::new("a", DataType::Int32, false),
962 Field::new("b", DataType::Utf8View, false),
963 ]);
964
965 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
966 let b = StringViewArray::from(vec!["a", "b", "c", "d", "e"]);
967
968 let record_batch =
969 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
970
971 assert_eq!(5, record_batch.num_rows());
972 assert_eq!(2, record_batch.num_columns());
973 assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
974 assert_eq!(
975 &DataType::Utf8View,
976 record_batch.schema().field(1).data_type()
977 );
978 assert_eq!(5, record_batch.column(0).len());
979 assert_eq!(5, record_batch.column(1).len());
980 }
981
982 #[test]
983 fn byte_size_should_not_regress() {
984 let schema = Schema::new(vec![
985 Field::new("a", DataType::Int32, false),
986 Field::new("b", DataType::Utf8, false),
987 ]);
988
989 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
990 let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);
991
992 let record_batch =
993 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
994 assert_eq!(record_batch.get_array_memory_size(), 364);
995 }
996
997 fn check_batch(record_batch: RecordBatch, num_rows: usize) {
998 assert_eq!(num_rows, record_batch.num_rows());
999 assert_eq!(2, record_batch.num_columns());
1000 assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
1001 assert_eq!(&DataType::Utf8, record_batch.schema().field(1).data_type());
1002 assert_eq!(num_rows, record_batch.column(0).len());
1003 assert_eq!(num_rows, record_batch.column(1).len());
1004 }
1005
1006 #[test]
1007 #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
1008 fn create_record_batch_slice() {
1009 let schema = Schema::new(vec![
1010 Field::new("a", DataType::Int32, false),
1011 Field::new("b", DataType::Utf8, false),
1012 ]);
1013 let expected_schema = schema.clone();
1014
1015 let a = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
1016 let b = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "h", "i"]);
1017
1018 let record_batch =
1019 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
1020
1021 let offset = 2;
1022 let length = 5;
1023 let record_batch_slice = record_batch.slice(offset, length);
1024
1025 assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
1026 check_batch(record_batch_slice, 5);
1027
1028 let offset = 2;
1029 let length = 0;
1030 let record_batch_slice = record_batch.slice(offset, length);
1031
1032 assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
1033 check_batch(record_batch_slice, 0);
1034
1035 let offset = 2;
1036 let length = 10;
1037 let _record_batch_slice = record_batch.slice(offset, length);
1038 }
1039
1040 #[test]
1041 #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
1042 fn create_record_batch_slice_empty_batch() {
1043 let schema = Schema::empty();
1044
1045 let record_batch = RecordBatch::new_empty(Arc::new(schema));
1046
1047 let offset = 0;
1048 let length = 0;
1049 let record_batch_slice = record_batch.slice(offset, length);
1050 assert_eq!(0, record_batch_slice.schema().fields().len());
1051
1052 let offset = 1;
1053 let length = 2;
1054 let _record_batch_slice = record_batch.slice(offset, length);
1055 }
1056
1057 #[test]
1058 fn create_record_batch_try_from_iter() {
1059 let a: ArrayRef = Arc::new(Int32Array::from(vec![
1060 Some(1),
1061 Some(2),
1062 None,
1063 Some(4),
1064 Some(5),
1065 ]));
1066 let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
1067
1068 let record_batch =
1069 RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).expect("valid conversion");
1070
1071 let expected_schema = Schema::new(vec![
1072 Field::new("a", DataType::Int32, true),
1073 Field::new("b", DataType::Utf8, false),
1074 ]);
1075 assert_eq!(record_batch.schema().as_ref(), &expected_schema);
1076 check_batch(record_batch, 5);
1077 }
1078
1079 #[test]
1080 fn create_record_batch_try_from_iter_with_nullable() {
1081 let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
1082 let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
1083
1084 let record_batch =
1086 RecordBatch::try_from_iter_with_nullable(vec![("a", a, false), ("b", b, true)])
1087 .expect("valid conversion");
1088
1089 let expected_schema = Schema::new(vec![
1090 Field::new("a", DataType::Int32, false),
1091 Field::new("b", DataType::Utf8, true),
1092 ]);
1093 assert_eq!(record_batch.schema().as_ref(), &expected_schema);
1094 check_batch(record_batch, 5);
1095 }
1096
1097 #[test]
1098 fn create_record_batch_schema_mismatch() {
1099 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1100
1101 let a = Int64Array::from(vec![1, 2, 3, 4, 5]);
1102
1103 let err = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap_err();
1104 assert_eq!(
1105 err.to_string(),
1106 "Invalid argument error: column types must match schema types, expected Int32 but found Int64 at column index 0"
1107 );
1108 }
1109
1110 #[test]
1111 fn create_record_batch_field_name_mismatch() {
1112 let fields = vec![
1113 Field::new("a1", DataType::Int32, false),
1114 Field::new_list("a2", Field::new_list_field(DataType::Int8, false), false),
1115 ];
1116 let schema = Arc::new(Schema::new(vec![Field::new_struct("a", fields, true)]));
1117
1118 let a1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
1119 let a2_child = Int8Array::from(vec![1, 2, 3, 4]);
1120 let a2 = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new(
1121 "array",
1122 DataType::Int8,
1123 false,
1124 ))))
1125 .add_child_data(a2_child.into_data())
1126 .len(2)
1127 .add_buffer(Buffer::from([0i32, 3, 4].to_byte_slice()))
1128 .build()
1129 .unwrap();
1130 let a2: ArrayRef = Arc::new(ListArray::from(a2));
1131 let a = ArrayDataBuilder::new(DataType::Struct(Fields::from(vec![
1132 Field::new("aa1", DataType::Int32, false),
1133 Field::new("a2", a2.data_type().clone(), false),
1134 ])))
1135 .add_child_data(a1.into_data())
1136 .add_child_data(a2.into_data())
1137 .len(2)
1138 .build()
1139 .unwrap();
1140 let a: ArrayRef = Arc::new(StructArray::from(a));
1141
1142 let batch = RecordBatch::try_new(schema.clone(), vec![a.clone()]);
1144 assert!(batch.is_err());
1145
1146 let options = RecordBatchOptions {
1148 match_field_names: false,
1149 row_count: None,
1150 };
1151 let batch = RecordBatch::try_new_with_options(schema, vec![a], &options);
1152 assert!(batch.is_ok());
1153 }
1154
1155 #[test]
1156 fn create_record_batch_record_mismatch() {
1157 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1158
1159 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1160 let b = Int32Array::from(vec![1, 2, 3, 4, 5]);
1161
1162 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]);
1163 assert!(batch.is_err());
1164 }
1165
1166 #[test]
1167 fn create_record_batch_from_struct_array() {
1168 let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
1169 let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
1170 let struct_array = StructArray::from(vec![
1171 (
1172 Arc::new(Field::new("b", DataType::Boolean, false)),
1173 boolean.clone() as ArrayRef,
1174 ),
1175 (
1176 Arc::new(Field::new("c", DataType::Int32, false)),
1177 int.clone() as ArrayRef,
1178 ),
1179 ]);
1180
1181 let batch = RecordBatch::from(&struct_array);
1182 assert_eq!(2, batch.num_columns());
1183 assert_eq!(4, batch.num_rows());
1184 assert_eq!(
1185 struct_array.data_type(),
1186 &DataType::Struct(batch.schema().fields().clone())
1187 );
1188 assert_eq!(batch.column(0).as_ref(), boolean.as_ref());
1189 assert_eq!(batch.column(1).as_ref(), int.as_ref());
1190 }
1191
1192 #[test]
1193 fn record_batch_equality() {
1194 let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1195 let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1196 let schema1 = Schema::new(vec![
1197 Field::new("id", DataType::Int32, false),
1198 Field::new("val", DataType::Int32, false),
1199 ]);
1200
1201 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1202 let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1203 let schema2 = Schema::new(vec![
1204 Field::new("id", DataType::Int32, false),
1205 Field::new("val", DataType::Int32, false),
1206 ]);
1207
1208 let batch1 = RecordBatch::try_new(
1209 Arc::new(schema1),
1210 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1211 )
1212 .unwrap();
1213
1214 let batch2 = RecordBatch::try_new(
1215 Arc::new(schema2),
1216 vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1217 )
1218 .unwrap();
1219
1220 assert_eq!(batch1, batch2);
1221 }
1222
1223 #[test]
1225 fn record_batch_index_access() {
1226 let id_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
1227 let val_arr = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
1228 let schema1 = Schema::new(vec![
1229 Field::new("id", DataType::Int32, false),
1230 Field::new("val", DataType::Int32, false),
1231 ]);
1232 let record_batch =
1233 RecordBatch::try_new(Arc::new(schema1), vec![id_arr.clone(), val_arr.clone()]).unwrap();
1234
1235 assert_eq!(record_batch["id"].as_ref(), id_arr.as_ref());
1236 assert_eq!(record_batch["val"].as_ref(), val_arr.as_ref());
1237 }
1238
1239 #[test]
1240 fn record_batch_vals_ne() {
1241 let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1242 let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1243 let schema1 = Schema::new(vec![
1244 Field::new("id", DataType::Int32, false),
1245 Field::new("val", DataType::Int32, false),
1246 ]);
1247
1248 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1249 let val_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1250 let schema2 = Schema::new(vec![
1251 Field::new("id", DataType::Int32, false),
1252 Field::new("val", DataType::Int32, false),
1253 ]);
1254
1255 let batch1 = RecordBatch::try_new(
1256 Arc::new(schema1),
1257 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1258 )
1259 .unwrap();
1260
1261 let batch2 = RecordBatch::try_new(
1262 Arc::new(schema2),
1263 vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1264 )
1265 .unwrap();
1266
1267 assert_ne!(batch1, batch2);
1268 }
1269
1270 #[test]
1271 fn record_batch_column_names_ne() {
1272 let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1273 let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1274 let schema1 = Schema::new(vec![
1275 Field::new("id", DataType::Int32, false),
1276 Field::new("val", DataType::Int32, false),
1277 ]);
1278
1279 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1280 let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1281 let schema2 = Schema::new(vec![
1282 Field::new("id", DataType::Int32, false),
1283 Field::new("num", DataType::Int32, false),
1284 ]);
1285
1286 let batch1 = RecordBatch::try_new(
1287 Arc::new(schema1),
1288 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1289 )
1290 .unwrap();
1291
1292 let batch2 = RecordBatch::try_new(
1293 Arc::new(schema2),
1294 vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1295 )
1296 .unwrap();
1297
1298 assert_ne!(batch1, batch2);
1299 }
1300
1301 #[test]
1302 fn record_batch_column_number_ne() {
1303 let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1304 let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1305 let schema1 = Schema::new(vec![
1306 Field::new("id", DataType::Int32, false),
1307 Field::new("val", DataType::Int32, false),
1308 ]);
1309
1310 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1311 let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1312 let num_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1313 let schema2 = Schema::new(vec![
1314 Field::new("id", DataType::Int32, false),
1315 Field::new("val", DataType::Int32, false),
1316 Field::new("num", DataType::Int32, false),
1317 ]);
1318
1319 let batch1 = RecordBatch::try_new(
1320 Arc::new(schema1),
1321 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1322 )
1323 .unwrap();
1324
1325 let batch2 = RecordBatch::try_new(
1326 Arc::new(schema2),
1327 vec![Arc::new(id_arr2), Arc::new(val_arr2), Arc::new(num_arr2)],
1328 )
1329 .unwrap();
1330
1331 assert_ne!(batch1, batch2);
1332 }
1333
1334 #[test]
1335 fn record_batch_row_count_ne() {
1336 let id_arr1 = Int32Array::from(vec![1, 2, 3]);
1337 let val_arr1 = Int32Array::from(vec![5, 6, 7]);
1338 let schema1 = Schema::new(vec![
1339 Field::new("id", DataType::Int32, false),
1340 Field::new("val", DataType::Int32, false),
1341 ]);
1342
1343 let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1344 let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1345 let schema2 = Schema::new(vec![
1346 Field::new("id", DataType::Int32, false),
1347 Field::new("num", DataType::Int32, false),
1348 ]);
1349
1350 let batch1 = RecordBatch::try_new(
1351 Arc::new(schema1),
1352 vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1353 )
1354 .unwrap();
1355
1356 let batch2 = RecordBatch::try_new(
1357 Arc::new(schema2),
1358 vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1359 )
1360 .unwrap();
1361
1362 assert_ne!(batch1, batch2);
1363 }
1364
1365 #[test]
1366 fn normalize_simple() {
1367 let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
1368 let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
1369 let year: ArrayRef = Arc::new(Int64Array::from(vec![None, Some(2022)]));
1370
1371 let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
1372 let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
1373 let year_field = Arc::new(Field::new("year", DataType::Int64, true));
1374
1375 let a = Arc::new(StructArray::from(vec![
1376 (animals_field.clone(), Arc::new(animals.clone()) as ArrayRef),
1377 (n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef),
1378 (year_field.clone(), Arc::new(year.clone()) as ArrayRef),
1379 ]));
1380
1381 let month = Arc::new(Int64Array::from(vec![Some(4), Some(6)]));
1382
1383 let schema = Schema::new(vec![
1384 Field::new(
1385 "a",
1386 DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
1387 false,
1388 ),
1389 Field::new("month", DataType::Int64, true),
1390 ]);
1391
1392 let normalized =
1393 RecordBatch::try_new(Arc::new(schema.clone()), vec![a.clone(), month.clone()])
1394 .expect("valid conversion")
1395 .normalize(".", Some(0))
1396 .expect("valid normalization");
1397
1398 let expected = RecordBatch::try_from_iter_with_nullable(vec![
1399 ("a.animals", animals.clone(), true),
1400 ("a.n_legs", n_legs.clone(), true),
1401 ("a.year", year.clone(), true),
1402 ("month", month.clone(), true),
1403 ])
1404 .expect("valid conversion");
1405
1406 assert_eq!(expected, normalized);
1407
1408 let normalized = RecordBatch::try_new(Arc::new(schema), vec![a, month.clone()])
1410 .expect("valid conversion")
1411 .normalize(".", None)
1412 .expect("valid normalization");
1413
1414 assert_eq!(expected, normalized);
1415 }
1416
1417 #[test]
1418 fn normalize_nested() {
1419 let a = Arc::new(Field::new("a", DataType::Int64, true));
1421 let b = Arc::new(Field::new("b", DataType::Int64, false));
1422 let c = Arc::new(Field::new("c", DataType::Int64, true));
1423
1424 let one = Arc::new(Field::new(
1425 "1",
1426 DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
1427 false,
1428 ));
1429 let two = Arc::new(Field::new(
1430 "2",
1431 DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
1432 true,
1433 ));
1434
1435 let exclamation = Arc::new(Field::new(
1436 "!",
1437 DataType::Struct(Fields::from(vec![one.clone(), two.clone()])),
1438 false,
1439 ));
1440
1441 let schema = Schema::new(vec![exclamation.clone()]);
1442
1443 let a_field = Int64Array::from(vec![Some(0), Some(1)]);
1445 let b_field = Int64Array::from(vec![Some(2), Some(3)]);
1446 let c_field = Int64Array::from(vec![None, Some(4)]);
1447
1448 let one_field = StructArray::from(vec![
1449 (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1450 (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1451 (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1452 ]);
1453 let two_field = StructArray::from(vec![
1454 (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1455 (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1456 (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1457 ]);
1458
1459 let exclamation_field = Arc::new(StructArray::from(vec![
1460 (one.clone(), Arc::new(one_field) as ArrayRef),
1461 (two.clone(), Arc::new(two_field) as ArrayRef),
1462 ]));
1463
1464 let normalized =
1466 RecordBatch::try_new(Arc::new(schema.clone()), vec![exclamation_field.clone()])
1467 .expect("valid conversion")
1468 .normalize(".", Some(1))
1469 .expect("valid normalization");
1470
1471 let expected = RecordBatch::try_from_iter_with_nullable(vec![
1472 (
1473 "!.1",
1474 Arc::new(StructArray::from(vec![
1475 (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1476 (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1477 (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1478 ])) as ArrayRef,
1479 false,
1480 ),
1481 (
1482 "!.2",
1483 Arc::new(StructArray::from(vec![
1484 (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1485 (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1486 (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1487 ])) as ArrayRef,
1488 true,
1489 ),
1490 ])
1491 .expect("valid conversion");
1492
1493 assert_eq!(expected, normalized);
1494
1495 let normalized = RecordBatch::try_new(Arc::new(schema), vec![exclamation_field])
1497 .expect("valid conversion")
1498 .normalize(".", None)
1499 .expect("valid normalization");
1500
1501 let expected = RecordBatch::try_from_iter_with_nullable(vec![
1502 ("!.1.a", Arc::new(a_field.clone()) as ArrayRef, true),
1503 ("!.1.b", Arc::new(b_field.clone()) as ArrayRef, false),
1504 ("!.1.c", Arc::new(c_field.clone()) as ArrayRef, true),
1505 ("!.2.a", Arc::new(a_field.clone()) as ArrayRef, true),
1506 ("!.2.b", Arc::new(b_field.clone()) as ArrayRef, false),
1507 ("!.2.c", Arc::new(c_field.clone()) as ArrayRef, true),
1508 ])
1509 .expect("valid conversion");
1510
1511 assert_eq!(expected, normalized);
1512 }
1513
1514 #[test]
1515 fn normalize_empty() {
1516 let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
1517 let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
1518 let year_field = Arc::new(Field::new("year", DataType::Int64, true));
1519
1520 let schema = Schema::new(vec![
1521 Field::new(
1522 "a",
1523 DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
1524 false,
1525 ),
1526 Field::new("month", DataType::Int64, true),
1527 ]);
1528
1529 let normalized = RecordBatch::new_empty(Arc::new(schema.clone()))
1530 .normalize(".", Some(0))
1531 .expect("valid normalization");
1532
1533 let expected = RecordBatch::new_empty(Arc::new(
1534 schema.normalize(".", Some(0)).expect("valid normalization"),
1535 ));
1536
1537 assert_eq!(expected, normalized);
1538 }
1539
1540 #[test]
1541 fn project() {
1542 let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
1543 let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
1544 let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));
1545
1546 let record_batch =
1547 RecordBatch::try_from_iter(vec![("a", a.clone()), ("b", b.clone()), ("c", c.clone())])
1548 .expect("valid conversion");
1549
1550 let expected =
1551 RecordBatch::try_from_iter(vec![("a", a), ("c", c)]).expect("valid conversion");
1552
1553 assert_eq!(expected, record_batch.project(&[0, 2]).unwrap());
1554 }
1555
1556 #[test]
1557 fn project_empty() {
1558 let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));
1559
1560 let record_batch =
1561 RecordBatch::try_from_iter(vec![("c", c.clone())]).expect("valid conversion");
1562
1563 let expected = RecordBatch::try_new_with_options(
1564 Arc::new(Schema::empty()),
1565 vec![],
1566 &RecordBatchOptions {
1567 match_field_names: true,
1568 row_count: Some(3),
1569 },
1570 )
1571 .expect("valid conversion");
1572
1573 assert_eq!(expected, record_batch.project(&[]).unwrap());
1574 }
1575
1576 #[test]
1577 fn test_no_column_record_batch() {
1578 let schema = Arc::new(Schema::empty());
1579
1580 let err = RecordBatch::try_new(schema.clone(), vec![]).unwrap_err();
1581 assert!(
1582 err.to_string()
1583 .contains("must either specify a row count or at least one column")
1584 );
1585
1586 let options = RecordBatchOptions::new().with_row_count(Some(10));
1587
1588 let ok = RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap();
1589 assert_eq!(ok.num_rows(), 10);
1590
1591 let a = ok.slice(2, 5);
1592 assert_eq!(a.num_rows(), 5);
1593
1594 let b = ok.slice(5, 0);
1595 assert_eq!(b.num_rows(), 0);
1596
1597 assert_ne!(a, b);
1598 assert_eq!(b, RecordBatch::new_empty(schema))
1599 }
1600
1601 #[test]
1602 fn test_nulls_in_non_nullable_field() {
1603 let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
1604 let maybe_batch = RecordBatch::try_new(
1605 schema,
1606 vec![Arc::new(Int32Array::from(vec![Some(1), None]))],
1607 );
1608 assert_eq!(
1609 "Invalid argument error: Column 'a' is declared as non-nullable but contains null values",
1610 format!("{}", maybe_batch.err().unwrap())
1611 );
1612 }
1613 #[test]
1614 fn test_record_batch_options() {
1615 let options = RecordBatchOptions::new()
1616 .with_match_field_names(false)
1617 .with_row_count(Some(20));
1618 assert!(!options.match_field_names);
1619 assert_eq!(options.row_count.unwrap(), 20)
1620 }
1621
1622 #[test]
1623 #[should_panic(expected = "Cannot convert nullable StructArray to RecordBatch")]
1624 fn test_from_struct() {
1625 let s = StructArray::from(ArrayData::new_null(
1626 &DataType::Struct(vec![Field::new("foo", DataType::Int32, false)].into()),
1628 2,
1629 ));
1630 let _ = RecordBatch::from(s);
1631 }
1632
1633 #[test]
1634 fn test_with_schema() {
1635 let required_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1636 let required_schema = Arc::new(required_schema);
1637 let nullable_schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
1638 let nullable_schema = Arc::new(nullable_schema);
1639
1640 let batch = RecordBatch::try_new(
1641 required_schema.clone(),
1642 vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as _],
1643 )
1644 .unwrap();
1645
1646 let batch = batch.with_schema(nullable_schema.clone()).unwrap();
1648
1649 batch.clone().with_schema(required_schema).unwrap_err();
1651
1652 let metadata = vec![("foo".to_string(), "bar".to_string())]
1654 .into_iter()
1655 .collect();
1656 let metadata_schema = nullable_schema.as_ref().clone().with_metadata(metadata);
1657 let batch = batch.with_schema(Arc::new(metadata_schema)).unwrap();
1658
1659 batch.with_schema(nullable_schema).unwrap_err();
1661 }
1662
1663 #[test]
1664 fn test_boxed_reader() {
1665 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1668 let schema = Arc::new(schema);
1669
1670 let reader = RecordBatchIterator::new(std::iter::empty(), schema);
1671 let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
1672
1673 fn get_size(reader: impl RecordBatchReader) -> usize {
1674 reader.size_hint().0
1675 }
1676
1677 let size = get_size(reader);
1678 assert_eq!(size, 0);
1679 }
1680
1681 #[test]
1682 fn test_remove_column_maintains_schema_metadata() {
1683 let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
1684 let bool_array = BooleanArray::from(vec![true, false, false, true, true]);
1685
1686 let mut metadata = HashMap::new();
1687 metadata.insert("foo".to_string(), "bar".to_string());
1688 let schema = Schema::new(vec![
1689 Field::new("id", DataType::Int32, false),
1690 Field::new("bool", DataType::Boolean, false),
1691 ])
1692 .with_metadata(metadata);
1693
1694 let mut batch = RecordBatch::try_new(
1695 Arc::new(schema),
1696 vec![Arc::new(id_array), Arc::new(bool_array)],
1697 )
1698 .unwrap();
1699
1700 let _removed_column = batch.remove_column(0);
1701 assert_eq!(batch.schema().metadata().len(), 1);
1702 assert_eq!(
1703 batch.schema().metadata().get("foo").unwrap().as_str(),
1704 "bar"
1705 );
1706 }
1707}