arrow_array/
record_batch.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! A two-dimensional batch of column-oriented data with a defined
19//! [schema](arrow_schema::Schema).
20
21use crate::cast::AsArray;
22use crate::{new_empty_array, Array, ArrayRef, StructArray};
23use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef};
24use std::ops::Index;
25use std::sync::Arc;
26
/// Trait for types that can read `RecordBatch`'s.
///
/// A reader is an [`Iterator`] whose items are `Result<RecordBatch, ArrowError>`,
/// combined with a [`schema`](RecordBatchReader::schema) accessor.
///
/// To create from an iterator, see [RecordBatchIterator].
pub trait RecordBatchReader: Iterator<Item = Result<RecordBatch, ArrowError>> {
    /// Returns the schema of this `RecordBatchReader`.
    ///
    /// Implementations of this trait should guarantee that every `RecordBatch` returned by
    /// this reader has the same schema as the one returned from this method.
    fn schema(&self) -> SchemaRef;
}
37
38impl<R: RecordBatchReader + ?Sized> RecordBatchReader for Box<R> {
39    fn schema(&self) -> SchemaRef {
40        self.as_ref().schema()
41    }
42}
43
/// Trait for types that can write `RecordBatch`'s.
pub trait RecordBatchWriter {
    /// Write a single batch to the writer.
    ///
    /// Returns an [`ArrowError`] if the batch could not be written.
    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError>;

    /// Write footer or termination data, then mark the writer as done.
    ///
    /// Takes the writer by value, so it cannot be used again after closing.
    fn close(self) -> Result<(), ArrowError>;
}
52
/// Creates an array from a literal slice of values,
/// suitable for rapid testing and development.
///
/// The macro expands to an `Arc` wrapping the concrete array type that
/// corresponds to the given type name.
///
/// Example:
///
/// ```rust
///
/// use arrow_array::create_array;
///
/// let array = create_array!(Int32, [1, 2, 3, 4, 5]);
/// let array = create_array!(Utf8, [Some("a"), Some("b"), None, Some("e")]);
/// ```
/// Support for limited data types is available. The macro will return a compile error if an unsupported data type is used.
/// Presently supported data types are:
/// - `Boolean`, `Null`
/// - `Decimal32`, `Decimal64`, `Decimal128`, `Decimal256`
/// - `Float16`, `Float32`, `Float64`
/// - `Int8`, `Int16`, `Int32`, `Int64`
/// - `UInt8`, `UInt16`, `UInt32`, `UInt64`
/// - `IntervalDayTime`, `IntervalYearMonth`
/// - `Second`, `Millisecond`, `Microsecond`, `Nanosecond`
/// - `Second32`, `Millisecond32`, `Microsecond64`, `Nanosecond64`
/// - `DurationSecond`, `DurationMillisecond`, `DurationMicrosecond`, `DurationNanosecond`
/// - `TimestampSecond`, `TimestampMillisecond`, `TimestampMicrosecond`, `TimestampNanosecond`
/// - `Utf8`, `Utf8View`, `LargeUtf8`, `Binary`, `LargeBinary`
///
/// `Null` takes the length of the array instead of a list of values
/// (`create_array!(Null, 3)`).
#[macro_export]
macro_rules! create_array {
    // `@from` is used for those types that have a common method `<type>::from`
    (@from Boolean) => { $crate::BooleanArray };
    (@from Int8) => { $crate::Int8Array };
    (@from Int16) => { $crate::Int16Array };
    (@from Int32) => { $crate::Int32Array };
    (@from Int64) => { $crate::Int64Array };
    (@from UInt8) => { $crate::UInt8Array };
    (@from UInt16) => { $crate::UInt16Array };
    (@from UInt32) => { $crate::UInt32Array };
    (@from UInt64) => { $crate::UInt64Array };
    (@from Float16) => { $crate::Float16Array };
    (@from Float32) => { $crate::Float32Array };
    (@from Float64) => { $crate::Float64Array };
    (@from Utf8) => { $crate::StringArray };
    (@from Utf8View) => { $crate::StringViewArray };
    (@from LargeUtf8) => { $crate::LargeStringArray };
    (@from IntervalDayTime) => { $crate::IntervalDayTimeArray };
    (@from IntervalYearMonth) => { $crate::IntervalYearMonthArray };
    (@from Second) => { $crate::TimestampSecondArray };
    (@from Millisecond) => { $crate::TimestampMillisecondArray };
    (@from Microsecond) => { $crate::TimestampMicrosecondArray };
    (@from Nanosecond) => { $crate::TimestampNanosecondArray };
    (@from Second32) => { $crate::Time32SecondArray };
    (@from Millisecond32) => { $crate::Time32MillisecondArray };
    (@from Microsecond64) => { $crate::Time64MicrosecondArray };
    // The array alias is `Time64NanosecondArray`; only the macro keyword carries the `64` suffix
    (@from Nanosecond64) => { $crate::Time64NanosecondArray };
    (@from DurationSecond) => { $crate::DurationSecondArray };
    (@from DurationMillisecond) => { $crate::DurationMillisecondArray };
    (@from DurationMicrosecond) => { $crate::DurationMicrosecondArray };
    (@from DurationNanosecond) => { $crate::DurationNanosecondArray };
    (@from Decimal32) => { $crate::Decimal32Array };
    (@from Decimal64) => { $crate::Decimal64Array };
    (@from Decimal128) => { $crate::Decimal128Array };
    (@from Decimal256) => { $crate::Decimal256Array };
    (@from TimestampSecond) => { $crate::TimestampSecondArray };
    (@from TimestampMillisecond) => { $crate::TimestampMillisecondArray };
    (@from TimestampMicrosecond) => { $crate::TimestampMicrosecondArray };
    (@from TimestampNanosecond) => { $crate::TimestampNanosecondArray };

    // Any other type name is rejected at compile time
    (@from $ty: ident) => {
        compile_error!(concat!("Unsupported data type: ", stringify!($ty)))
    };

    // `Null` arrays carry no values, only a length
    (Null, $size: expr) => {
        std::sync::Arc::new($crate::NullArray::new($size))
    };

    // Binary arrays are built with `from_vec` rather than `From`
    (Binary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::BinaryArray::from_vec(vec![$($values),*]))
    };

    (LargeBinary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::LargeBinaryArray::from_vec(vec![$($values),*]))
    };

    // Everything else goes through `<array type>::from(Vec<_>)`
    ($ty: tt, [$($values: expr),*]) => {
        std::sync::Arc::new(<$crate::create_array!(@from $ty)>::from(vec![$($values),*]))
    };
}
139
/// Creates a record batch from literal slice of values, suitable for rapid
/// testing and development.
///
/// Each column is given as a `(name, type, [values])` tuple. Every field is
/// declared nullable, and the macro evaluates to the `Result` returned by
/// `RecordBatch::try_new`.
///
/// The expansion refers to `arrow_schema` by path, so that crate must be
/// available at the call site.
///
/// Example:
///
/// ```rust
/// use arrow_array::record_batch;
/// use arrow_schema;
///
/// let batch = record_batch!(
///     ("a", Int32, [1, 2, 3]),
///     ("b", Float64, [Some(4.0), None, Some(5.0)]),
///     ("c", Utf8, ["alpha", "beta", "gamma"])
/// );
/// ```
/// Due to limitation of [`create_array!`] macro, support for limited data types is available.
#[macro_export]
macro_rules! record_batch {
    ($(($name: expr, $type: ident, [$($values: expr),*])),*) => {
        {
            let fields = vec![
                $(
                    arrow_schema::Field::new($name, arrow_schema::DataType::$type, true),
                )*
            ];

            // The column `vec!` stays in argument position so that each concrete
            // `Arc<…Array>` coerces to the expected array reference type.
            $crate::RecordBatch::try_new(
                std::sync::Arc::new(arrow_schema::Schema::new(fields)),
                vec![$(
                    $crate::create_array!($type, [$($values),*]),
                )*]
            )
        }
    }
}
177
178/// A two-dimensional batch of column-oriented data with a defined
179/// [schema](arrow_schema::Schema).
180///
181/// A `RecordBatch` is a two-dimensional dataset of a number of
182/// contiguous arrays, each the same length.
183/// A record batch has a schema which must match its arrays’
184/// datatypes.
185///
186/// Record batches are a convenient unit of work for various
187/// serialization and computation functions, possibly incremental.
188///
189/// Use the [`record_batch!`] macro to create a [`RecordBatch`] from
190/// literal slice of values, useful for rapid prototyping and testing.
191///
192/// Example:
193/// ```rust
194/// use arrow_array::record_batch;
195/// let batch = record_batch!(
196///     ("a", Int32, [1, 2, 3]),
197///     ("b", Float64, [Some(4.0), None, Some(5.0)]),
198///     ("c", Utf8, ["alpha", "beta", "gamma"])
199/// );
200/// ```
201#[derive(Clone, Debug, PartialEq)]
202pub struct RecordBatch {
203    schema: SchemaRef,
204    columns: Vec<Arc<dyn Array>>,
205
206    /// The number of rows in this RecordBatch
207    ///
208    /// This is stored separately from the columns to handle the case of no columns
209    row_count: usize,
210}
211
212impl RecordBatch {
213    /// Creates a `RecordBatch` from a schema and columns.
214    ///
215    /// Expects the following:
216    ///
217    ///  * `!columns.is_empty()`
218    ///  * `schema.fields.len() == columns.len()`
219    ///  * `schema.fields[i].data_type() == columns[i].data_type()`
220    ///  * `columns[i].len() == columns[j].len()`
221    ///
222    /// If the conditions are not met, an error is returned.
223    ///
224    /// # Example
225    ///
226    /// ```
227    /// # use std::sync::Arc;
228    /// # use arrow_array::{Int32Array, RecordBatch};
229    /// # use arrow_schema::{DataType, Field, Schema};
230    ///
231    /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
232    /// let schema = Schema::new(vec![
233    ///     Field::new("id", DataType::Int32, false)
234    /// ]);
235    ///
236    /// let batch = RecordBatch::try_new(
237    ///     Arc::new(schema),
238    ///     vec![Arc::new(id_array)]
239    /// ).unwrap();
240    /// ```
241    pub fn try_new(schema: SchemaRef, columns: Vec<ArrayRef>) -> Result<Self, ArrowError> {
242        let options = RecordBatchOptions::new();
243        Self::try_new_impl(schema, columns, &options)
244    }
245
    /// Creates a `RecordBatch` from a schema and columns, without validation.
    ///
    /// The arguments are stored as given; none of the checks performed by the
    /// checked constructors are run.
    ///
    /// See [`Self::try_new`] for the checked version.
    ///
    /// # Safety
    ///
    /// Expects the following:
    ///
    ///  * `schema.fields.len() == columns.len()`
    ///  * `schema.fields[i].data_type() == columns[i].data_type()`
    ///  * `columns[i].len() == row_count`
    ///
    /// Note: if the schema does not match the underlying data exactly, it can lead to undefined
    /// behavior, for example, via conversion to a `StructArray`, which in turn could lead
    /// to incorrect access.
    pub unsafe fn new_unchecked(
        schema: SchemaRef,
        columns: Vec<Arc<dyn Array>>,
        row_count: usize,
    ) -> Self {
        Self {
            schema,
            columns,
            row_count,
        }
    }
272
    /// Creates a `RecordBatch` from a schema and columns, with additional options,
    /// such as whether to strictly validate field names.
    ///
    /// An explicit `options.row_count` makes it possible to build a batch that
    /// has no columns.
    ///
    /// See [`RecordBatch::try_new`] for the expected conditions.
    ///
    /// # Errors
    ///
    /// Returns an error if validation of `schema` and `columns` against `options` fails.
    pub fn try_new_with_options(
        schema: SchemaRef,
        columns: Vec<ArrayRef>,
        options: &RecordBatchOptions,
    ) -> Result<Self, ArrowError> {
        Self::try_new_impl(schema, columns, options)
    }
284
285    /// Creates a new empty [`RecordBatch`].
286    pub fn new_empty(schema: SchemaRef) -> Self {
287        let columns = schema
288            .fields()
289            .iter()
290            .map(|field| new_empty_array(field.data_type()))
291            .collect();
292
293        RecordBatch {
294            schema,
295            columns,
296            row_count: 0,
297        }
298    }
299
300    /// Validate the schema and columns using [`RecordBatchOptions`]. Returns an error
301    /// if any validation check fails, otherwise returns the created [`Self`]
302    fn try_new_impl(
303        schema: SchemaRef,
304        columns: Vec<ArrayRef>,
305        options: &RecordBatchOptions,
306    ) -> Result<Self, ArrowError> {
307        // check that number of fields in schema match column length
308        if schema.fields().len() != columns.len() {
309            return Err(ArrowError::InvalidArgumentError(format!(
310                "number of columns({}) must match number of fields({}) in schema",
311                columns.len(),
312                schema.fields().len(),
313            )));
314        }
315
316        let row_count = options
317            .row_count
318            .or_else(|| columns.first().map(|col| col.len()))
319            .ok_or_else(|| {
320                ArrowError::InvalidArgumentError(
321                    "must either specify a row count or at least one column".to_string(),
322                )
323            })?;
324
325        for (c, f) in columns.iter().zip(&schema.fields) {
326            if !f.is_nullable() && c.null_count() > 0 {
327                return Err(ArrowError::InvalidArgumentError(format!(
328                    "Column '{}' is declared as non-nullable but contains null values",
329                    f.name()
330                )));
331            }
332        }
333
334        // check that all columns have the same row count
335        if columns.iter().any(|c| c.len() != row_count) {
336            let err = match options.row_count {
337                Some(_) => "all columns in a record batch must have the specified row count",
338                None => "all columns in a record batch must have the same length",
339            };
340            return Err(ArrowError::InvalidArgumentError(err.to_string()));
341        }
342
343        // function for comparing column type and field type
344        // return true if 2 types are not matched
345        let type_not_match = if options.match_field_names {
346            |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| col_type != field_type
347        } else {
348            |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| {
349                !col_type.equals_datatype(field_type)
350            }
351        };
352
353        // check that all columns match the schema
354        let not_match = columns
355            .iter()
356            .zip(schema.fields().iter())
357            .map(|(col, field)| (col.data_type(), field.data_type()))
358            .enumerate()
359            .find(type_not_match);
360
361        if let Some((i, (col_type, field_type))) = not_match {
362            return Err(ArrowError::InvalidArgumentError(format!(
363                "column types must match schema types, expected {field_type:?} but found {col_type:?} at column index {i}")));
364        }
365
366        Ok(RecordBatch {
367            schema,
368            columns,
369            row_count,
370        })
371    }
372
373    /// Return the schema, columns and row count of this [`RecordBatch`]
374    pub fn into_parts(self) -> (SchemaRef, Vec<ArrayRef>, usize) {
375        (self.schema, self.columns, self.row_count)
376    }
377
378    /// Override the schema of this [`RecordBatch`]
379    ///
380    /// Returns an error if `schema` is not a superset of the current schema
381    /// as determined by [`Schema::contains`]
382    ///
383    /// See also [`Self::schema_metadata_mut`].
384    pub fn with_schema(self, schema: SchemaRef) -> Result<Self, ArrowError> {
385        if !schema.contains(self.schema.as_ref()) {
386            return Err(ArrowError::SchemaError(format!(
387                "target schema is not superset of current schema target={schema} current={}",
388                self.schema
389            )));
390        }
391
392        Ok(Self {
393            schema,
394            columns: self.columns,
395            row_count: self.row_count,
396        })
397    }
398
399    /// Returns the [`Schema`] of the record batch.
400    pub fn schema(&self) -> SchemaRef {
401        self.schema.clone()
402    }
403
    /// Returns a reference to the [`Schema`] of the record batch.
    ///
    /// Unlike [`Self::schema`], this borrows the stored `SchemaRef` rather than cloning it.
    pub fn schema_ref(&self) -> &SchemaRef {
        &self.schema
    }
408
409    /// Mutable access to the metadata of the schema.
410    ///
411    /// This allows you to modify [`Schema::metadata`] of [`Self::schema`] in a convenient and fast way.
412    ///
413    /// Note this will clone the entire underlying `Schema` object if it is currently shared
414    ///
415    /// # Example
416    /// ```
417    /// # use std::sync::Arc;
418    /// # use arrow_array::{record_batch, RecordBatch};
419    /// let mut batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
420    /// // Initially, the metadata is empty
421    /// assert!(batch.schema().metadata().get("key").is_none());
422    /// // Insert a key-value pair into the metadata
423    /// batch.schema_metadata_mut().insert("key".into(), "value".into());
424    /// assert_eq!(batch.schema().metadata().get("key"), Some(&String::from("value")));
425    /// ```    
426    pub fn schema_metadata_mut(&mut self) -> &mut std::collections::HashMap<String, String> {
427        let schema = Arc::make_mut(&mut self.schema);
428        &mut schema.metadata
429    }
430
431    /// Projects the schema onto the specified columns
432    pub fn project(&self, indices: &[usize]) -> Result<RecordBatch, ArrowError> {
433        let projected_schema = self.schema.project(indices)?;
434        let batch_fields = indices
435            .iter()
436            .map(|f| {
437                self.columns.get(*f).cloned().ok_or_else(|| {
438                    ArrowError::SchemaError(format!(
439                        "project index {} out of bounds, max field {}",
440                        f,
441                        self.columns.len()
442                    ))
443                })
444            })
445            .collect::<Result<Vec<_>, _>>()?;
446
447        RecordBatch::try_new_with_options(
448            SchemaRef::new(projected_schema),
449            batch_fields,
450            &RecordBatchOptions {
451                match_field_names: true,
452                row_count: Some(self.row_count),
453            },
454        )
455    }
456
457    /// Normalize a semi-structured [`RecordBatch`] into a flat table.
458    ///
459    /// Nested [`Field`]s will generate names separated by `separator`, up to a depth of `max_level`
460    /// (unlimited if `None`).
461    ///
462    /// e.g. given a [`RecordBatch`] with schema:
463    ///
464    /// ```text
465    ///     "foo": StructArray<"bar": Utf8>
466    /// ```
467    ///
468    /// A separator of `"."` would generate a batch with the schema:
469    ///
470    /// ```text
471    ///     "foo.bar": Utf8
472    /// ```
473    ///
474    /// Note that giving a depth of `Some(0)` to `max_level` is the same as passing in `None`;
475    /// it will be treated as unlimited.
476    ///
477    /// # Example
478    ///
479    /// ```
480    /// # use std::sync::Arc;
481    /// # use arrow_array::{ArrayRef, Int64Array, StringArray, StructArray, RecordBatch};
482    /// # use arrow_schema::{DataType, Field, Fields, Schema};
483    /// #
484    /// let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
485    /// let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
486    ///
487    /// let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
488    /// let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
489    ///
490    /// let a = Arc::new(StructArray::from(vec![
491    ///     (animals_field.clone(), Arc::new(animals.clone()) as ArrayRef),
492    ///     (n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef),
493    /// ]));
494    ///
495    /// let schema = Schema::new(vec![
496    ///     Field::new(
497    ///         "a",
498    ///         DataType::Struct(Fields::from(vec![animals_field, n_legs_field])),
499    ///         false,
500    ///     )
501    /// ]);
502    ///
503    /// let normalized = RecordBatch::try_new(Arc::new(schema), vec![a])
504    ///     .expect("valid conversion")
505    ///     .normalize(".", None)
506    ///     .expect("valid normalization");
507    ///
508    /// let expected = RecordBatch::try_from_iter_with_nullable(vec![
509    ///     ("a.animals", animals.clone(), true),
510    ///     ("a.n_legs", n_legs.clone(), true),
511    /// ])
512    /// .expect("valid conversion");
513    ///
514    /// assert_eq!(expected, normalized);
515    /// ```
516    pub fn normalize(&self, separator: &str, max_level: Option<usize>) -> Result<Self, ArrowError> {
517        let max_level = match max_level.unwrap_or(usize::MAX) {
518            0 => usize::MAX,
519            val => val,
520        };
521        let mut stack: Vec<(usize, &ArrayRef, Vec<&str>, &FieldRef)> = self
522            .columns
523            .iter()
524            .zip(self.schema.fields())
525            .rev()
526            .map(|(c, f)| {
527                let name_vec: Vec<&str> = vec![f.name()];
528                (0, c, name_vec, f)
529            })
530            .collect();
531        let mut columns: Vec<ArrayRef> = Vec::new();
532        let mut fields: Vec<FieldRef> = Vec::new();
533
534        while let Some((depth, c, name, field_ref)) = stack.pop() {
535            match field_ref.data_type() {
536                DataType::Struct(ff) if depth < max_level => {
537                    // Need to zip these in reverse to maintain original order
538                    for (cff, fff) in c.as_struct().columns().iter().zip(ff.into_iter()).rev() {
539                        let mut name = name.clone();
540                        name.push(separator);
541                        name.push(fff.name());
542                        stack.push((depth + 1, cff, name, fff))
543                    }
544                }
545                _ => {
546                    let updated_field = Field::new(
547                        name.concat(),
548                        field_ref.data_type().clone(),
549                        field_ref.is_nullable(),
550                    );
551                    columns.push(c.clone());
552                    fields.push(Arc::new(updated_field));
553                }
554            }
555        }
556        RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
557    }
558
    /// Returns the number of columns in the record batch.
    ///
    /// This is the length of the stored column list.
    ///
    /// # Example
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow_array::{Int32Array, RecordBatch};
    /// # use arrow_schema::{DataType, Field, Schema};
    ///
    /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
    /// let schema = Schema::new(vec![
    ///     Field::new("id", DataType::Int32, false)
    /// ]);
    ///
    /// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)]).unwrap();
    ///
    /// assert_eq!(batch.num_columns(), 1);
    /// ```
    pub fn num_columns(&self) -> usize {
        self.columns.len()
    }
580
    /// Returns the number of rows in the record batch, which is also the length of each column.
    ///
    /// The value comes from the stored row count rather than from a column, so it
    /// is available even when the batch has no columns.
    ///
    /// # Example
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow_array::{Int32Array, RecordBatch};
    /// # use arrow_schema::{DataType, Field, Schema};
    ///
    /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
    /// let schema = Schema::new(vec![
    ///     Field::new("id", DataType::Int32, false)
    /// ]);
    ///
    /// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)]).unwrap();
    ///
    /// assert_eq!(batch.num_rows(), 5);
    /// ```
    pub fn num_rows(&self) -> usize {
        self.row_count
    }
602
    /// Get a reference to a column's array by index.
    ///
    /// See [`Self::column_by_name`] to look a column up by its field name.
    ///
    /// # Panics
    ///
    /// Panics if `index` is outside of `0..num_columns`.
    pub fn column(&self, index: usize) -> &ArrayRef {
        &self.columns[index]
    }
611
612    /// Get a reference to a column's array by name.
613    pub fn column_by_name(&self, name: &str) -> Option<&ArrayRef> {
614        self.schema()
615            .column_with_name(name)
616            .map(|(index, _)| &self.columns[index])
617    }
618
619    /// Get a reference to all columns in the record batch.
620    pub fn columns(&self) -> &[ArrayRef] {
621        &self.columns[..]
622    }
623
    /// Remove column by index and return it.
    ///
    /// The schema is rebuilt without the corresponding field, so schema and
    /// columns stay aligned.
    ///
    /// Return the `ArrayRef` if the column is removed.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of bounds.
    ///
    /// # Example
    ///
    /// ```
    /// use std::sync::Arc;
    /// use arrow_array::{BooleanArray, Int32Array, RecordBatch};
    /// use arrow_schema::{DataType, Field, Schema};
    /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
    /// let bool_array = BooleanArray::from(vec![true, false, false, true, true]);
    /// let schema = Schema::new(vec![
    ///     Field::new("id", DataType::Int32, false),
    ///     Field::new("bool", DataType::Boolean, false),
    /// ]);
    ///
    /// let mut batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array), Arc::new(bool_array)]).unwrap();
    ///
    /// let removed_column = batch.remove_column(0);
    /// assert_eq!(removed_column.as_any().downcast_ref::<Int32Array>().unwrap(), &Int32Array::from(vec![1, 2, 3, 4, 5]));
    /// assert_eq!(batch.num_columns(), 1);
    /// ```
    pub fn remove_column(&mut self, index: usize) -> ArrayRef {
        // Rebuild the schema from the current one, minus the field at `index`
        let mut builder = SchemaBuilder::from(self.schema.as_ref());
        builder.remove(index);
        self.schema = Arc::new(builder.finish());
        self.columns.remove(index)
    }
657
658    /// Return a new RecordBatch where each column is sliced
659    /// according to `offset` and `length`
660    ///
661    /// # Panics
662    ///
663    /// Panics if `offset` with `length` is greater than column length.
664    pub fn slice(&self, offset: usize, length: usize) -> RecordBatch {
665        assert!((offset + length) <= self.num_rows());
666
667        let columns = self
668            .columns()
669            .iter()
670            .map(|column| column.slice(offset, length))
671            .collect();
672
673        Self {
674            schema: self.schema.clone(),
675            columns,
676            row_count: length,
677        }
678    }
679
680    /// Create a `RecordBatch` from an iterable list of pairs of the
681    /// form `(field_name, array)`, with the same requirements on
682    /// fields and arrays as [`RecordBatch::try_new`]. This method is
683    /// often used to create a single `RecordBatch` from arrays,
684    /// e.g. for testing.
685    ///
686    /// The resulting schema is marked as nullable for each column if
687    /// the array for that column is has any nulls. To explicitly
688    /// specify nullibility, use [`RecordBatch::try_from_iter_with_nullable`]
689    ///
690    /// Example:
691    /// ```
692    /// # use std::sync::Arc;
693    /// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray};
694    ///
695    /// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
696    /// let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"]));
697    ///
698    /// let record_batch = RecordBatch::try_from_iter(vec![
699    ///   ("a", a),
700    ///   ("b", b),
701    /// ]);
702    /// ```
703    /// Another way to quickly create a [`RecordBatch`] is to use the [`record_batch!`] macro,
704    /// which is particularly helpful for rapid prototyping and testing.
705    ///
706    /// Example:
707    ///
708    /// ```rust
709    /// use arrow_array::record_batch;
710    /// let batch = record_batch!(
711    ///     ("a", Int32, [1, 2, 3]),
712    ///     ("b", Float64, [Some(4.0), None, Some(5.0)]),
713    ///     ("c", Utf8, ["alpha", "beta", "gamma"])
714    /// );
715    /// ```
716    pub fn try_from_iter<I, F>(value: I) -> Result<Self, ArrowError>
717    where
718        I: IntoIterator<Item = (F, ArrayRef)>,
719        F: AsRef<str>,
720    {
721        // TODO: implement `TryFrom` trait, once
722        // https://github.com/rust-lang/rust/issues/50133 is no longer an
723        // issue
724        let iter = value.into_iter().map(|(field_name, array)| {
725            let nullable = array.null_count() > 0;
726            (field_name, array, nullable)
727        });
728
729        Self::try_from_iter_with_nullable(iter)
730    }
731
732    /// Create a `RecordBatch` from an iterable list of tuples of the
733    /// form `(field_name, array, nullable)`, with the same requirements on
734    /// fields and arrays as [`RecordBatch::try_new`]. This method is often
735    /// used to create a single `RecordBatch` from arrays, e.g. for
736    /// testing.
737    ///
738    /// Example:
739    /// ```
740    /// # use std::sync::Arc;
741    /// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray};
742    ///
743    /// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
744    /// let b: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), Some("b")]));
745    ///
746    /// // Note neither `a` nor `b` has any actual nulls, but we mark
747    /// // b an nullable
748    /// let record_batch = RecordBatch::try_from_iter_with_nullable(vec![
749    ///   ("a", a, false),
750    ///   ("b", b, true),
751    /// ]);
752    /// ```
753    pub fn try_from_iter_with_nullable<I, F>(value: I) -> Result<Self, ArrowError>
754    where
755        I: IntoIterator<Item = (F, ArrayRef, bool)>,
756        F: AsRef<str>,
757    {
758        let iter = value.into_iter();
759        let capacity = iter.size_hint().0;
760        let mut schema = SchemaBuilder::with_capacity(capacity);
761        let mut columns = Vec::with_capacity(capacity);
762
763        for (field_name, array, nullable) in iter {
764            let field_name = field_name.as_ref();
765            schema.push(Field::new(field_name, array.data_type().clone(), nullable));
766            columns.push(array);
767        }
768
769        let schema = Arc::new(schema.finish());
770        RecordBatch::try_new(schema, columns)
771    }
772
773    /// Returns the total number of bytes of memory occupied physically by this batch.
774    ///
775    /// Note that this does not always correspond to the exact memory usage of a
776    /// `RecordBatch` (might overestimate), since multiple columns can share the same
777    /// buffers or slices thereof, the memory used by the shared buffers might be
778    /// counted multiple times.
779    pub fn get_array_memory_size(&self) -> usize {
780        self.columns()
781            .iter()
782            .map(|array| array.get_array_memory_size())
783            .sum()
784    }
785}
786
/// Options that control the behaviour used when creating a [`RecordBatch`].
///
/// The struct is `#[non_exhaustive]`; see [`RecordBatchOptions::new`] and the
/// `with_*` methods for building a value.
#[derive(Debug)]
#[non_exhaustive]
pub struct RecordBatchOptions {
    /// Match field names of structs and lists. If set to `true`, the names must match.
    pub match_field_names: bool,

    /// Optional row count, useful for specifying a row count for a RecordBatch with no columns
    pub row_count: Option<usize>,
}
797
798impl RecordBatchOptions {
799    /// Creates a new `RecordBatchOptions`
800    pub fn new() -> Self {
801        Self {
802            match_field_names: true,
803            row_count: None,
804        }
805    }
806    /// Sets the row_count of RecordBatchOptions and returns self
807    pub fn with_row_count(mut self, row_count: Option<usize>) -> Self {
808        self.row_count = row_count;
809        self
810    }
811    /// Sets the match_field_names of RecordBatchOptions and returns self
812    pub fn with_match_field_names(mut self, match_field_names: bool) -> Self {
813        self.match_field_names = match_field_names;
814        self
815    }
816}
817impl Default for RecordBatchOptions {
818    fn default() -> Self {
819        Self::new()
820    }
821}
822impl From<StructArray> for RecordBatch {
823    fn from(value: StructArray) -> Self {
824        let row_count = value.len();
825        let (fields, columns, nulls) = value.into_parts();
826        assert_eq!(
827            nulls.map(|n| n.null_count()).unwrap_or_default(),
828            0,
829            "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"
830        );
831
832        RecordBatch {
833            schema: Arc::new(Schema::new(fields)),
834            row_count,
835            columns,
836        }
837    }
838}
839
840impl From<&StructArray> for RecordBatch {
841    fn from(struct_array: &StructArray) -> Self {
842        struct_array.clone().into()
843    }
844}
845
846impl Index<&str> for RecordBatch {
847    type Output = ArrayRef;
848
849    /// Get a reference to a column's array by name.
850    ///
851    /// # Panics
852    ///
853    /// Panics if the name is not in the schema.
854    fn index(&self, name: &str) -> &Self::Output {
855        self.column_by_name(name).unwrap()
856    }
857}
858
/// Generic implementation of [RecordBatchReader] that wraps an iterator.
///
/// The wrapped iterator supplies the batches, while the schema reported through
/// [`RecordBatchReader::schema`] is the one handed to [`RecordBatchIterator::new`].
///
/// # Example
///
/// ```
/// # use std::sync::Arc;
/// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray, RecordBatchIterator, RecordBatchReader};
/// #
/// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
/// let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"]));
///
/// let record_batch = RecordBatch::try_from_iter(vec![
///   ("a", a),
///   ("b", b),
/// ]).unwrap();
///
/// let batches: Vec<RecordBatch> = vec![record_batch.clone(), record_batch.clone()];
///
/// let mut reader = RecordBatchIterator::new(batches.into_iter().map(Ok), record_batch.schema());
///
/// assert_eq!(reader.schema(), record_batch.schema());
/// assert_eq!(reader.next().unwrap().unwrap(), record_batch);
/// # assert_eq!(reader.next().unwrap().unwrap(), record_batch);
/// # assert!(reader.next().is_none());
/// ```
pub struct RecordBatchIterator<I>
where
    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
{
    // Iterator obtained from the value passed to `new`; it produces the batches.
    inner: I::IntoIter,
    // Schema returned by the `RecordBatchReader::schema` implementation.
    inner_schema: SchemaRef,
}
891
892impl<I> RecordBatchIterator<I>
893where
894    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
895{
896    /// Create a new [RecordBatchIterator].
897    ///
898    /// If `iter` is an infallible iterator, use `.map(Ok)`.
899    pub fn new(iter: I, schema: SchemaRef) -> Self {
900        Self {
901            inner: iter.into_iter(),
902            inner_schema: schema,
903        }
904    }
905}
906
907impl<I> Iterator for RecordBatchIterator<I>
908where
909    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
910{
911    type Item = I::Item;
912
913    fn next(&mut self) -> Option<Self::Item> {
914        self.inner.next()
915    }
916
917    fn size_hint(&self) -> (usize, Option<usize>) {
918        self.inner.size_hint()
919    }
920}
921
922impl<I> RecordBatchReader for RecordBatchIterator<I>
923where
924    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
925{
926    fn schema(&self) -> SchemaRef {
927        self.inner_schema.clone()
928    }
929}
930
931#[cfg(test)]
932mod tests {
933    use super::*;
934    use crate::{
935        BooleanArray, Int32Array, Int64Array, Int8Array, ListArray, StringArray, StringViewArray,
936    };
937    use arrow_buffer::{Buffer, ToByteSlice};
938    use arrow_data::{ArrayData, ArrayDataBuilder};
939    use arrow_schema::Fields;
940    use std::collections::HashMap;
941
942    #[test]
943    fn create_record_batch() {
944        let schema = Schema::new(vec![
945            Field::new("a", DataType::Int32, false),
946            Field::new("b", DataType::Utf8, false),
947        ]);
948
949        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
950        let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);
951
952        let record_batch =
953            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
954        check_batch(record_batch, 5)
955    }
956
957    #[test]
958    fn create_string_view_record_batch() {
959        let schema = Schema::new(vec![
960            Field::new("a", DataType::Int32, false),
961            Field::new("b", DataType::Utf8View, false),
962        ]);
963
964        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
965        let b = StringViewArray::from(vec!["a", "b", "c", "d", "e"]);
966
967        let record_batch =
968            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
969
970        assert_eq!(5, record_batch.num_rows());
971        assert_eq!(2, record_batch.num_columns());
972        assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
973        assert_eq!(
974            &DataType::Utf8View,
975            record_batch.schema().field(1).data_type()
976        );
977        assert_eq!(5, record_batch.column(0).len());
978        assert_eq!(5, record_batch.column(1).len());
979    }
980
981    #[test]
982    fn byte_size_should_not_regress() {
983        let schema = Schema::new(vec![
984            Field::new("a", DataType::Int32, false),
985            Field::new("b", DataType::Utf8, false),
986        ]);
987
988        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
989        let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);
990
991        let record_batch =
992            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
993        assert_eq!(record_batch.get_array_memory_size(), 364);
994    }
995
996    fn check_batch(record_batch: RecordBatch, num_rows: usize) {
997        assert_eq!(num_rows, record_batch.num_rows());
998        assert_eq!(2, record_batch.num_columns());
999        assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
1000        assert_eq!(&DataType::Utf8, record_batch.schema().field(1).data_type());
1001        assert_eq!(num_rows, record_batch.column(0).len());
1002        assert_eq!(num_rows, record_batch.column(1).len());
1003    }
1004
1005    #[test]
1006    #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
1007    fn create_record_batch_slice() {
1008        let schema = Schema::new(vec![
1009            Field::new("a", DataType::Int32, false),
1010            Field::new("b", DataType::Utf8, false),
1011        ]);
1012        let expected_schema = schema.clone();
1013
1014        let a = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
1015        let b = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "h", "i"]);
1016
1017        let record_batch =
1018            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
1019
1020        let offset = 2;
1021        let length = 5;
1022        let record_batch_slice = record_batch.slice(offset, length);
1023
1024        assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
1025        check_batch(record_batch_slice, 5);
1026
1027        let offset = 2;
1028        let length = 0;
1029        let record_batch_slice = record_batch.slice(offset, length);
1030
1031        assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
1032        check_batch(record_batch_slice, 0);
1033
1034        let offset = 2;
1035        let length = 10;
1036        let _record_batch_slice = record_batch.slice(offset, length);
1037    }
1038
1039    #[test]
1040    #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
1041    fn create_record_batch_slice_empty_batch() {
1042        let schema = Schema::empty();
1043
1044        let record_batch = RecordBatch::new_empty(Arc::new(schema));
1045
1046        let offset = 0;
1047        let length = 0;
1048        let record_batch_slice = record_batch.slice(offset, length);
1049        assert_eq!(0, record_batch_slice.schema().fields().len());
1050
1051        let offset = 1;
1052        let length = 2;
1053        let _record_batch_slice = record_batch.slice(offset, length);
1054    }
1055
1056    #[test]
1057    fn create_record_batch_try_from_iter() {
1058        let a: ArrayRef = Arc::new(Int32Array::from(vec![
1059            Some(1),
1060            Some(2),
1061            None,
1062            Some(4),
1063            Some(5),
1064        ]));
1065        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
1066
1067        let record_batch =
1068            RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).expect("valid conversion");
1069
1070        let expected_schema = Schema::new(vec![
1071            Field::new("a", DataType::Int32, true),
1072            Field::new("b", DataType::Utf8, false),
1073        ]);
1074        assert_eq!(record_batch.schema().as_ref(), &expected_schema);
1075        check_batch(record_batch, 5);
1076    }
1077
1078    #[test]
1079    fn create_record_batch_try_from_iter_with_nullable() {
1080        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
1081        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
1082
1083        // Note there are no nulls in a or b, but we specify that b is nullable
1084        let record_batch =
1085            RecordBatch::try_from_iter_with_nullable(vec![("a", a, false), ("b", b, true)])
1086                .expect("valid conversion");
1087
1088        let expected_schema = Schema::new(vec![
1089            Field::new("a", DataType::Int32, false),
1090            Field::new("b", DataType::Utf8, true),
1091        ]);
1092        assert_eq!(record_batch.schema().as_ref(), &expected_schema);
1093        check_batch(record_batch, 5);
1094    }
1095
1096    #[test]
1097    fn create_record_batch_schema_mismatch() {
1098        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1099
1100        let a = Int64Array::from(vec![1, 2, 3, 4, 5]);
1101
1102        let err = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap_err();
1103        assert_eq!(err.to_string(), "Invalid argument error: column types must match schema types, expected Int32 but found Int64 at column index 0");
1104    }
1105
    /// Checks that nested field names are only compared when `match_field_names` is
    /// set: the column below has the same nested data types as the schema but uses
    /// different nested field names.
    #[test]
    fn create_record_batch_field_name_mismatch() {
        // Schema side: struct "a" with children "a1" (Int32) and "a2" (list of Int8).
        let fields = vec![
            Field::new("a1", DataType::Int32, false),
            Field::new_list("a2", Field::new_list_field(DataType::Int8, false), false),
        ];
        let schema = Arc::new(Schema::new(vec![Field::new_struct("a", fields, true)]));

        let a1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
        let a2_child = Int8Array::from(vec![1, 2, 3, 4]);
        // Array side: the list's element field is explicitly named "array" here.
        let a2 = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new(
            "array",
            DataType::Int8,
            false,
        ))))
        .add_child_data(a2_child.into_data())
        .len(2)
        // Offsets [0, 3, 4] split the four child values into lists of 3 and 1 elements.
        .add_buffer(Buffer::from([0i32, 3, 4].to_byte_slice()))
        .build()
        .unwrap();
        let a2: ArrayRef = Arc::new(ListArray::from(a2));
        // Array side: the first struct child is named "aa1" rather than the schema's "a1".
        let a = ArrayDataBuilder::new(DataType::Struct(Fields::from(vec![
            Field::new("aa1", DataType::Int32, false),
            Field::new("a2", a2.data_type().clone(), false),
        ])))
        .add_child_data(a1.into_data())
        .add_child_data(a2.into_data())
        .len(2)
        .build()
        .unwrap();
        let a: ArrayRef = Arc::new(StructArray::from(a));

        // creating the batch with field name validation should fail
        let batch = RecordBatch::try_new(schema.clone(), vec![a.clone()]);
        assert!(batch.is_err());

        // creating the batch without field name validation should pass
        let options = RecordBatchOptions {
            match_field_names: false,
            row_count: None,
        };
        let batch = RecordBatch::try_new_with_options(schema, vec![a], &options);
        assert!(batch.is_ok());
    }
1150
1151    #[test]
1152    fn create_record_batch_record_mismatch() {
1153        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1154
1155        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1156        let b = Int32Array::from(vec![1, 2, 3, 4, 5]);
1157
1158        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]);
1159        assert!(batch.is_err());
1160    }
1161
1162    #[test]
1163    fn create_record_batch_from_struct_array() {
1164        let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
1165        let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
1166        let struct_array = StructArray::from(vec![
1167            (
1168                Arc::new(Field::new("b", DataType::Boolean, false)),
1169                boolean.clone() as ArrayRef,
1170            ),
1171            (
1172                Arc::new(Field::new("c", DataType::Int32, false)),
1173                int.clone() as ArrayRef,
1174            ),
1175        ]);
1176
1177        let batch = RecordBatch::from(&struct_array);
1178        assert_eq!(2, batch.num_columns());
1179        assert_eq!(4, batch.num_rows());
1180        assert_eq!(
1181            struct_array.data_type(),
1182            &DataType::Struct(batch.schema().fields().clone())
1183        );
1184        assert_eq!(batch.column(0).as_ref(), boolean.as_ref());
1185        assert_eq!(batch.column(1).as_ref(), int.as_ref());
1186    }
1187
1188    #[test]
1189    fn record_batch_equality() {
1190        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1191        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1192        let schema1 = Schema::new(vec![
1193            Field::new("id", DataType::Int32, false),
1194            Field::new("val", DataType::Int32, false),
1195        ]);
1196
1197        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1198        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1199        let schema2 = Schema::new(vec![
1200            Field::new("id", DataType::Int32, false),
1201            Field::new("val", DataType::Int32, false),
1202        ]);
1203
1204        let batch1 = RecordBatch::try_new(
1205            Arc::new(schema1),
1206            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1207        )
1208        .unwrap();
1209
1210        let batch2 = RecordBatch::try_new(
1211            Arc::new(schema2),
1212            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1213        )
1214        .unwrap();
1215
1216        assert_eq!(batch1, batch2);
1217    }
1218
1219    /// validates if the record batch can be accessed using `column_name` as index i.e. `record_batch["column_name"]`
1220    #[test]
1221    fn record_batch_index_access() {
1222        let id_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
1223        let val_arr = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
1224        let schema1 = Schema::new(vec![
1225            Field::new("id", DataType::Int32, false),
1226            Field::new("val", DataType::Int32, false),
1227        ]);
1228        let record_batch =
1229            RecordBatch::try_new(Arc::new(schema1), vec![id_arr.clone(), val_arr.clone()]).unwrap();
1230
1231        assert_eq!(record_batch["id"].as_ref(), id_arr.as_ref());
1232        assert_eq!(record_batch["val"].as_ref(), val_arr.as_ref());
1233    }
1234
1235    #[test]
1236    fn record_batch_vals_ne() {
1237        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1238        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1239        let schema1 = Schema::new(vec![
1240            Field::new("id", DataType::Int32, false),
1241            Field::new("val", DataType::Int32, false),
1242        ]);
1243
1244        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1245        let val_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1246        let schema2 = Schema::new(vec![
1247            Field::new("id", DataType::Int32, false),
1248            Field::new("val", DataType::Int32, false),
1249        ]);
1250
1251        let batch1 = RecordBatch::try_new(
1252            Arc::new(schema1),
1253            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1254        )
1255        .unwrap();
1256
1257        let batch2 = RecordBatch::try_new(
1258            Arc::new(schema2),
1259            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1260        )
1261        .unwrap();
1262
1263        assert_ne!(batch1, batch2);
1264    }
1265
1266    #[test]
1267    fn record_batch_column_names_ne() {
1268        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1269        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1270        let schema1 = Schema::new(vec![
1271            Field::new("id", DataType::Int32, false),
1272            Field::new("val", DataType::Int32, false),
1273        ]);
1274
1275        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1276        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1277        let schema2 = Schema::new(vec![
1278            Field::new("id", DataType::Int32, false),
1279            Field::new("num", DataType::Int32, false),
1280        ]);
1281
1282        let batch1 = RecordBatch::try_new(
1283            Arc::new(schema1),
1284            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1285        )
1286        .unwrap();
1287
1288        let batch2 = RecordBatch::try_new(
1289            Arc::new(schema2),
1290            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1291        )
1292        .unwrap();
1293
1294        assert_ne!(batch1, batch2);
1295    }
1296
1297    #[test]
1298    fn record_batch_column_number_ne() {
1299        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1300        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1301        let schema1 = Schema::new(vec![
1302            Field::new("id", DataType::Int32, false),
1303            Field::new("val", DataType::Int32, false),
1304        ]);
1305
1306        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1307        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1308        let num_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1309        let schema2 = Schema::new(vec![
1310            Field::new("id", DataType::Int32, false),
1311            Field::new("val", DataType::Int32, false),
1312            Field::new("num", DataType::Int32, false),
1313        ]);
1314
1315        let batch1 = RecordBatch::try_new(
1316            Arc::new(schema1),
1317            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1318        )
1319        .unwrap();
1320
1321        let batch2 = RecordBatch::try_new(
1322            Arc::new(schema2),
1323            vec![Arc::new(id_arr2), Arc::new(val_arr2), Arc::new(num_arr2)],
1324        )
1325        .unwrap();
1326
1327        assert_ne!(batch1, batch2);
1328    }
1329
1330    #[test]
1331    fn record_batch_row_count_ne() {
1332        let id_arr1 = Int32Array::from(vec![1, 2, 3]);
1333        let val_arr1 = Int32Array::from(vec![5, 6, 7]);
1334        let schema1 = Schema::new(vec![
1335            Field::new("id", DataType::Int32, false),
1336            Field::new("val", DataType::Int32, false),
1337        ]);
1338
1339        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1340        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1341        let schema2 = Schema::new(vec![
1342            Field::new("id", DataType::Int32, false),
1343            Field::new("num", DataType::Int32, false),
1344        ]);
1345
1346        let batch1 = RecordBatch::try_new(
1347            Arc::new(schema1),
1348            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1349        )
1350        .unwrap();
1351
1352        let batch2 = RecordBatch::try_new(
1353            Arc::new(schema2),
1354            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1355        )
1356        .unwrap();
1357
1358        assert_ne!(batch1, batch2);
1359    }
1360
    /// Normalizes a batch made of one struct column (`a`) and one plain column
    /// (`month`) with "." as separator and expects the struct's children to appear
    /// as top-level columns named `a.animals`, `a.n_legs` and `a.year`.
    #[test]
    fn normalize_simple() {
        let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
        let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
        let year: ArrayRef = Arc::new(Int64Array::from(vec![None, Some(2022)]));

        let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
        let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
        let year_field = Arc::new(Field::new("year", DataType::Int64, true));

        // Struct column "a" holding the three child arrays defined above.
        let a = Arc::new(StructArray::from(vec![
            (animals_field.clone(), Arc::new(animals.clone()) as ArrayRef),
            (n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef),
            (year_field.clone(), Arc::new(year.clone()) as ArrayRef),
        ]));

        let month = Arc::new(Int64Array::from(vec![Some(4), Some(6)]));

        let schema = Schema::new(vec![
            Field::new(
                "a",
                DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
                false,
            ),
            Field::new("month", DataType::Int64, true),
        ]);

        // First pass: normalize with an explicit `max_level` of 0.
        let normalized =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![a.clone(), month.clone()])
                .expect("valid conversion")
                .normalize(".", Some(0))
                .expect("valid normalization");

        // The expected batch is flat: child columns are prefixed with the parent name.
        let expected = RecordBatch::try_from_iter_with_nullable(vec![
            ("a.animals", animals.clone(), true),
            ("a.n_legs", n_legs.clone(), true),
            ("a.year", year.clone(), true),
            ("month", month.clone(), true),
        ])
        .expect("valid conversion");

        assert_eq!(expected, normalized);

        // check 0 and None have the same effect
        let normalized = RecordBatch::try_new(Arc::new(schema), vec![a, month.clone()])
            .expect("valid conversion")
            .normalize(".", None)
            .expect("valid normalization");

        assert_eq!(expected, normalized);
    }
1412
    /// Normalizes a batch whose single column `!` is a struct of two structs (`1`
    /// and `2`), each with leaf fields `a`, `b` and `c`. With `Some(1)` the test
    /// expects the two inner structs as columns `!.1` and `!.2`; with `None` it
    /// expects six leaf columns named `!.1.a` through `!.2.c`.
    #[test]
    fn normalize_nested() {
        // Initialize schema
        let a = Arc::new(Field::new("a", DataType::Int64, true));
        let b = Arc::new(Field::new("b", DataType::Int64, false));
        let c = Arc::new(Field::new("c", DataType::Int64, true));

        // The two inner structs share the same children; only "2" is nullable.
        let one = Arc::new(Field::new(
            "1",
            DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
            false,
        ));
        let two = Arc::new(Field::new(
            "2",
            DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
            true,
        ));

        // Outer struct "!" wrapping both inner structs.
        let exclamation = Arc::new(Field::new(
            "!",
            DataType::Struct(Fields::from(vec![one.clone(), two.clone()])),
            false,
        ));

        let schema = Schema::new(vec![exclamation.clone()]);

        // Initialize fields
        let a_field = Int64Array::from(vec![Some(0), Some(1)]);
        let b_field = Int64Array::from(vec![Some(2), Some(3)]);
        let c_field = Int64Array::from(vec![None, Some(4)]);

        let one_field = StructArray::from(vec![
            (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
            (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
            (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
        ]);
        let two_field = StructArray::from(vec![
            (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
            (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
            (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
        ]);

        let exclamation_field = Arc::new(StructArray::from(vec![
            (one.clone(), Arc::new(one_field) as ArrayRef),
            (two.clone(), Arc::new(two_field) as ArrayRef),
        ]));

        // Normalize top level
        let normalized =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![exclamation_field.clone()])
                .expect("valid conversion")
                .normalize(".", Some(1))
                .expect("valid normalization");

        // Expected: the inner structs stay intact and keep their own nullability.
        let expected = RecordBatch::try_from_iter_with_nullable(vec![
            (
                "!.1",
                Arc::new(StructArray::from(vec![
                    (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
                    (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
                    (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
                ])) as ArrayRef,
                false,
            ),
            (
                "!.2",
                Arc::new(StructArray::from(vec![
                    (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
                    (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
                    (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
                ])) as ArrayRef,
                true,
            ),
        ])
        .expect("valid conversion");

        assert_eq!(expected, normalized);

        // Normalize all levels
        let normalized = RecordBatch::try_new(Arc::new(schema), vec![exclamation_field])
            .expect("valid conversion")
            .normalize(".", None)
            .expect("valid normalization");

        // Expected: only leaf columns remain, each keeping the nullability of its leaf field.
        let expected = RecordBatch::try_from_iter_with_nullable(vec![
            ("!.1.a", Arc::new(a_field.clone()) as ArrayRef, true),
            ("!.1.b", Arc::new(b_field.clone()) as ArrayRef, false),
            ("!.1.c", Arc::new(c_field.clone()) as ArrayRef, true),
            ("!.2.a", Arc::new(a_field.clone()) as ArrayRef, true),
            ("!.2.b", Arc::new(b_field.clone()) as ArrayRef, false),
            ("!.2.c", Arc::new(c_field.clone()) as ArrayRef, true),
        ])
        .expect("valid conversion");

        assert_eq!(expected, normalized);
    }
1509
1510    #[test]
1511    fn normalize_empty() {
1512        let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
1513        let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
1514        let year_field = Arc::new(Field::new("year", DataType::Int64, true));
1515
1516        let schema = Schema::new(vec![
1517            Field::new(
1518                "a",
1519                DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
1520                false,
1521            ),
1522            Field::new("month", DataType::Int64, true),
1523        ]);
1524
1525        let normalized = RecordBatch::new_empty(Arc::new(schema.clone()))
1526            .normalize(".", Some(0))
1527            .expect("valid normalization");
1528
1529        let expected = RecordBatch::new_empty(Arc::new(
1530            schema.normalize(".", Some(0)).expect("valid normalization"),
1531        ));
1532
1533        assert_eq!(expected, normalized);
1534    }
1535
1536    #[test]
1537    fn project() {
1538        let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
1539        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
1540        let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));
1541
1542        let record_batch =
1543            RecordBatch::try_from_iter(vec![("a", a.clone()), ("b", b.clone()), ("c", c.clone())])
1544                .expect("valid conversion");
1545
1546        let expected =
1547            RecordBatch::try_from_iter(vec![("a", a), ("c", c)]).expect("valid conversion");
1548
1549        assert_eq!(expected, record_batch.project(&[0, 2]).unwrap());
1550    }
1551
1552    #[test]
1553    fn project_empty() {
1554        let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));
1555
1556        let record_batch =
1557            RecordBatch::try_from_iter(vec![("c", c.clone())]).expect("valid conversion");
1558
1559        let expected = RecordBatch::try_new_with_options(
1560            Arc::new(Schema::empty()),
1561            vec![],
1562            &RecordBatchOptions {
1563                match_field_names: true,
1564                row_count: Some(3),
1565            },
1566        )
1567        .expect("valid conversion");
1568
1569        assert_eq!(expected, record_batch.project(&[]).unwrap());
1570    }
1571
1572    #[test]
1573    fn test_no_column_record_batch() {
1574        let schema = Arc::new(Schema::empty());
1575
1576        let err = RecordBatch::try_new(schema.clone(), vec![]).unwrap_err();
1577        assert!(err
1578            .to_string()
1579            .contains("must either specify a row count or at least one column"));
1580
1581        let options = RecordBatchOptions::new().with_row_count(Some(10));
1582
1583        let ok = RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap();
1584        assert_eq!(ok.num_rows(), 10);
1585
1586        let a = ok.slice(2, 5);
1587        assert_eq!(a.num_rows(), 5);
1588
1589        let b = ok.slice(5, 0);
1590        assert_eq!(b.num_rows(), 0);
1591
1592        assert_ne!(a, b);
1593        assert_eq!(b, RecordBatch::new_empty(schema))
1594    }
1595
1596    #[test]
1597    fn test_nulls_in_non_nullable_field() {
1598        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
1599        let maybe_batch = RecordBatch::try_new(
1600            schema,
1601            vec![Arc::new(Int32Array::from(vec![Some(1), None]))],
1602        );
1603        assert_eq!("Invalid argument error: Column 'a' is declared as non-nullable but contains null values", format!("{}", maybe_batch.err().unwrap()));
1604    }
1605    #[test]
1606    fn test_record_batch_options() {
1607        let options = RecordBatchOptions::new()
1608            .with_match_field_names(false)
1609            .with_row_count(Some(20));
1610        assert!(!options.match_field_names);
1611        assert_eq!(options.row_count.unwrap(), 20)
1612    }
1613
1614    #[test]
1615    #[should_panic(expected = "Cannot convert nullable StructArray to RecordBatch")]
1616    fn test_from_struct() {
1617        let s = StructArray::from(ArrayData::new_null(
1618            // Note child is not nullable
1619            &DataType::Struct(vec![Field::new("foo", DataType::Int32, false)].into()),
1620            2,
1621        ));
1622        let _ = RecordBatch::from(s);
1623    }
1624
1625    #[test]
1626    fn test_with_schema() {
1627        let required_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1628        let required_schema = Arc::new(required_schema);
1629        let nullable_schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
1630        let nullable_schema = Arc::new(nullable_schema);
1631
1632        let batch = RecordBatch::try_new(
1633            required_schema.clone(),
1634            vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as _],
1635        )
1636        .unwrap();
1637
1638        // Can add nullability
1639        let batch = batch.with_schema(nullable_schema.clone()).unwrap();
1640
1641        // Cannot remove nullability
1642        batch.clone().with_schema(required_schema).unwrap_err();
1643
1644        // Can add metadata
1645        let metadata = vec![("foo".to_string(), "bar".to_string())]
1646            .into_iter()
1647            .collect();
1648        let metadata_schema = nullable_schema.as_ref().clone().with_metadata(metadata);
1649        let batch = batch.with_schema(Arc::new(metadata_schema)).unwrap();
1650
1651        // Cannot remove metadata
1652        batch.with_schema(nullable_schema).unwrap_err();
1653    }
1654
1655    #[test]
1656    fn test_boxed_reader() {
1657        // Make sure we can pass a boxed reader to a function generic over
1658        // RecordBatchReader.
1659        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1660        let schema = Arc::new(schema);
1661
1662        let reader = RecordBatchIterator::new(std::iter::empty(), schema);
1663        let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
1664
1665        fn get_size(reader: impl RecordBatchReader) -> usize {
1666            reader.size_hint().0
1667        }
1668
1669        let size = get_size(reader);
1670        assert_eq!(size, 0);
1671    }
1672
1673    #[test]
1674    fn test_remove_column_maintains_schema_metadata() {
1675        let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
1676        let bool_array = BooleanArray::from(vec![true, false, false, true, true]);
1677
1678        let mut metadata = HashMap::new();
1679        metadata.insert("foo".to_string(), "bar".to_string());
1680        let schema = Schema::new(vec![
1681            Field::new("id", DataType::Int32, false),
1682            Field::new("bool", DataType::Boolean, false),
1683        ])
1684        .with_metadata(metadata);
1685
1686        let mut batch = RecordBatch::try_new(
1687            Arc::new(schema),
1688            vec![Arc::new(id_array), Arc::new(bool_array)],
1689        )
1690        .unwrap();
1691
1692        let _removed_column = batch.remove_column(0);
1693        assert_eq!(batch.schema().metadata().len(), 1);
1694        assert_eq!(
1695            batch.schema().metadata().get("foo").unwrap().as_str(),
1696            "bar"
1697        );
1698    }
1699}