arrow_array/
record_batch.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! A two-dimensional batch of column-oriented data with a defined
19//! [schema](arrow_schema::Schema).
20
21use crate::cast::AsArray;
22use crate::{new_empty_array, Array, ArrayRef, StructArray};
23use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef};
24use std::ops::Index;
25use std::sync::Arc;
26
/// Trait for types that yield a sequence of [`RecordBatch`]es.
///
/// A reader is an [`Iterator`] whose items are `Result<RecordBatch, ArrowError>`,
/// so each individual batch read can report its own error.
///
/// To create from an iterator, see [RecordBatchIterator].
pub trait RecordBatchReader: Iterator<Item = Result<RecordBatch, ArrowError>> {
    /// Returns the schema of this `RecordBatchReader`.
    ///
    /// Implementations of this trait should guarantee that every `RecordBatch`
    /// returned by this reader has the same schema as returned from this method.
    fn schema(&self) -> SchemaRef;
}
37
38impl<R: RecordBatchReader + ?Sized> RecordBatchReader for Box<R> {
39    fn schema(&self) -> SchemaRef {
40        self.as_ref().schema()
41    }
42}
43
/// Trait for types that can write [`RecordBatch`]es.
pub trait RecordBatchWriter {
    /// Write a single batch to the writer.
    ///
    /// Returns an error if the batch could not be written.
    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError>;

    /// Write footer or termination data, then mark the writer as done.
    ///
    /// Takes `self` by value, so the writer cannot be used after closing.
    fn close(self) -> Result<(), ArrowError>;
}
52
/// Creates an array from a literal slice of values,
/// suitable for rapid testing and development.
///
/// The macro evaluates to a `std::sync::Arc` wrapping the concrete array type
/// selected by the first argument (for example `Arc<Int32Array>` for `Int32`).
///
/// Example:
///
/// ```rust
///
/// use arrow_array::create_array;
///
/// let array = create_array!(Int32, [1, 2, 3, 4, 5]);
/// let array = create_array!(Utf8, [Some("a"), Some("b"), None, Some("e")]);
/// ```
///
/// Call forms:
/// - `create_array!(Null, len)` builds a `NullArray` with `len` rows.
/// - `create_array!(Binary, [..])` and `create_array!(LargeBinary, [..])` pass the
///   values to `from_vec`, which expects byte slices (`&[u8]`).
/// - Every other supported type passes the values to `<ArrayType>::from(vec![..])`.
///
/// Support for limited data types is available. The macro will return a compile error if an unsupported data type is used.
/// Presently supported data types are:
/// - `Boolean`, `Null`
/// - `Decimal128`, `Decimal256`
/// - `Float16`, `Float32`, `Float64`
/// - `Int8`, `Int16`, `Int32`, `Int64`
/// - `UInt8`, `UInt16`, `UInt32`, `UInt64`
/// - `IntervalDayTime`, `IntervalYearMonth`
/// - `Second`, `Millisecond`, `Microsecond`, `Nanosecond` (timestamp arrays)
/// - `Second32`, `Millisecond32` (`Time32` arrays), `Microsecond64`, `Nanosecond64` (`Time64` arrays)
/// - `DurationSecond`, `DurationMillisecond`, `DurationMicrosecond`, `DurationNanosecond`
/// - `TimestampSecond`, `TimestampMillisecond`, `TimestampMicrosecond`, `TimestampNanosecond`
/// - `Utf8`, `Utf8View`, `LargeUtf8`, `Binary`, `LargeBinary`
#[macro_export]
macro_rules! create_array {
    // `@from` is used for those types that have a common method `<type>::from`
    (@from Boolean) => { $crate::BooleanArray };
    (@from Int8) => { $crate::Int8Array };
    (@from Int16) => { $crate::Int16Array };
    (@from Int32) => { $crate::Int32Array };
    (@from Int64) => { $crate::Int64Array };
    (@from UInt8) => { $crate::UInt8Array };
    (@from UInt16) => { $crate::UInt16Array };
    (@from UInt32) => { $crate::UInt32Array };
    (@from UInt64) => { $crate::UInt64Array };
    (@from Float16) => { $crate::Float16Array };
    (@from Float32) => { $crate::Float32Array };
    (@from Float64) => { $crate::Float64Array };
    (@from Utf8) => { $crate::StringArray };
    (@from Utf8View) => { $crate::StringViewArray };
    (@from LargeUtf8) => { $crate::LargeStringArray };
    (@from IntervalDayTime) => { $crate::IntervalDayTimeArray };
    (@from IntervalYearMonth) => { $crate::IntervalYearMonthArray };
    (@from Second) => { $crate::TimestampSecondArray };
    (@from Millisecond) => { $crate::TimestampMillisecondArray };
    (@from Microsecond) => { $crate::TimestampMicrosecondArray };
    (@from Nanosecond) => { $crate::TimestampNanosecondArray };
    (@from Second32) => { $crate::Time32SecondArray };
    (@from Millisecond32) => { $crate::Time32MillisecondArray };
    (@from Microsecond64) => { $crate::Time64MicrosecondArray };
    // The Time64 nanosecond alias is `Time64NanosecondArray`; the previous
    // `Time64Nanosecond64Array` path does not exist, so this arm never compiled.
    (@from Nanosecond64) => { $crate::Time64NanosecondArray };
    (@from DurationSecond) => { $crate::DurationSecondArray };
    (@from DurationMillisecond) => { $crate::DurationMillisecondArray };
    (@from DurationMicrosecond) => { $crate::DurationMicrosecondArray };
    (@from DurationNanosecond) => { $crate::DurationNanosecondArray };
    (@from Decimal128) => { $crate::Decimal128Array };
    (@from Decimal256) => { $crate::Decimal256Array };
    (@from TimestampSecond) => { $crate::TimestampSecondArray };
    (@from TimestampMillisecond) => { $crate::TimestampMillisecondArray };
    (@from TimestampMicrosecond) => { $crate::TimestampMicrosecondArray };
    (@from TimestampNanosecond) => { $crate::TimestampNanosecondArray };

    // Any identifier not listed above is rejected at compile time.
    (@from $ty: ident) => {
        compile_error!(concat!("Unsupported data type: ", stringify!($ty)))
    };

    // `Null` takes a length rather than a list of values.
    (Null, $size: expr) => {
        std::sync::Arc::new($crate::NullArray::new($size))
    };

    // Binary arrays are built with `from_vec` instead of `From<Vec<_>>`.
    (Binary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::BinaryArray::from_vec(vec![$($values),*]))
    };

    (LargeBinary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::LargeBinaryArray::from_vec(vec![$($values),*]))
    };

    ($ty: tt, [$($values: expr),*]) => {
        std::sync::Arc::new(<$crate::create_array!(@from $ty)>::from(vec![$($values),*]))
    };
}
137
/// Creates a record batch from literal values, one `(name, type, [values])`
/// tuple per column; intended for rapid testing and development.
///
/// The expansion calls `RecordBatch::try_new`, so the macro evaluates to a
/// `Result<RecordBatch, ArrowError>`. Every generated field is marked nullable.
///
/// Example:
///
/// ```rust
/// use arrow_array::record_batch;
/// use arrow_schema;
///
/// let batch = record_batch!(
///     ("a", Int32, [1, 2, 3]),
///     ("b", Float64, [Some(4.0), None, Some(5.0)]),
///     ("c", Utf8, ["alpha", "beta", "gamma"])
/// );
/// ```
///
/// Due to limitation of [`create_array!`] macro, support for limited data types is available.
/// Each `type` identifier is used both as `arrow_schema::DataType::<type>` and as a
/// [`create_array!`] type name, so it must be valid in both positions. Because the
/// expansion names `arrow_schema` directly, that crate must be reachable from the
/// call site.
#[macro_export]
macro_rules! record_batch {
    ($(($name: expr, $type: ident, [$($values: expr),*])),*) => {
        {
            // Schema first, then columns: same evaluation order as building
            // them into separate bindings.
            $crate::RecordBatch::try_new(
                std::sync::Arc::new(arrow_schema::Schema::new(vec![
                    $(arrow_schema::Field::new($name, arrow_schema::DataType::$type, true),)*
                ])),
                vec![$($crate::create_array!($type, [$($values),*]),)*],
            )
        }
    };
}
175
176/// A two-dimensional batch of column-oriented data with a defined
177/// [schema](arrow_schema::Schema).
178///
179/// A `RecordBatch` is a two-dimensional dataset of a number of
180/// contiguous arrays, each the same length.
181/// A record batch has a schema which must match its arrays’
182/// datatypes.
183///
184/// Record batches are a convenient unit of work for various
185/// serialization and computation functions, possibly incremental.
186///
187/// Use the [`record_batch!`] macro to create a [`RecordBatch`] from
188/// literal slice of values, useful for rapid prototyping and testing.
189///
190/// Example:
191/// ```rust
192/// use arrow_array::record_batch;
193/// let batch = record_batch!(
194///     ("a", Int32, [1, 2, 3]),
195///     ("b", Float64, [Some(4.0), None, Some(5.0)]),
196///     ("c", Utf8, ["alpha", "beta", "gamma"])
197/// );
198/// ```
199#[derive(Clone, Debug, PartialEq)]
200pub struct RecordBatch {
201    schema: SchemaRef,
202    columns: Vec<Arc<dyn Array>>,
203
204    /// The number of rows in this RecordBatch
205    ///
206    /// This is stored separately from the columns to handle the case of no columns
207    row_count: usize,
208}
209
210impl RecordBatch {
211    /// Creates a `RecordBatch` from a schema and columns.
212    ///
213    /// Expects the following:
214    ///
215    ///  * `!columns.is_empty()`
216    ///  * `schema.fields.len() == columns.len()`
217    ///  * `schema.fields[i].data_type() == columns[i].data_type()`
218    ///  * `columns[i].len() == columns[j].len()`
219    ///
220    /// If the conditions are not met, an error is returned.
221    ///
222    /// # Example
223    ///
224    /// ```
225    /// # use std::sync::Arc;
226    /// # use arrow_array::{Int32Array, RecordBatch};
227    /// # use arrow_schema::{DataType, Field, Schema};
228    ///
229    /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
230    /// let schema = Schema::new(vec![
231    ///     Field::new("id", DataType::Int32, false)
232    /// ]);
233    ///
234    /// let batch = RecordBatch::try_new(
235    ///     Arc::new(schema),
236    ///     vec![Arc::new(id_array)]
237    /// ).unwrap();
238    /// ```
239    pub fn try_new(schema: SchemaRef, columns: Vec<ArrayRef>) -> Result<Self, ArrowError> {
240        let options = RecordBatchOptions::new();
241        Self::try_new_impl(schema, columns, &options)
242    }
243
244    /// Creates a `RecordBatch` from a schema and columns, without validation.
245    ///
246    /// See [`Self::try_new`] for the checked version.
247    ///
248    /// # Safety
249    ///
250    /// Expects the following:
251    ///
252    ///  * `schema.fields.len() == columns.len()`
253    ///  * `schema.fields[i].data_type() == columns[i].data_type()`
254    ///  * `columns[i].len() == row_count`
255    ///
256    /// Note: if the schema does not match the underlying data exactly, it can lead to undefined
257    /// behavior, for example, via conversion to a `StructArray`, which in turn could lead
258    /// to incorrect access.
259    pub unsafe fn new_unchecked(
260        schema: SchemaRef,
261        columns: Vec<Arc<dyn Array>>,
262        row_count: usize,
263    ) -> Self {
264        Self {
265            schema,
266            columns,
267            row_count,
268        }
269    }
270
271    /// Creates a `RecordBatch` from a schema and columns, with additional options,
272    /// such as whether to strictly validate field names.
273    ///
274    /// See [`RecordBatch::try_new`] for the expected conditions.
275    pub fn try_new_with_options(
276        schema: SchemaRef,
277        columns: Vec<ArrayRef>,
278        options: &RecordBatchOptions,
279    ) -> Result<Self, ArrowError> {
280        Self::try_new_impl(schema, columns, options)
281    }
282
283    /// Creates a new empty [`RecordBatch`].
284    pub fn new_empty(schema: SchemaRef) -> Self {
285        let columns = schema
286            .fields()
287            .iter()
288            .map(|field| new_empty_array(field.data_type()))
289            .collect();
290
291        RecordBatch {
292            schema,
293            columns,
294            row_count: 0,
295        }
296    }
297
298    /// Validate the schema and columns using [`RecordBatchOptions`]. Returns an error
299    /// if any validation check fails, otherwise returns the created [`Self`]
300    fn try_new_impl(
301        schema: SchemaRef,
302        columns: Vec<ArrayRef>,
303        options: &RecordBatchOptions,
304    ) -> Result<Self, ArrowError> {
305        // check that number of fields in schema match column length
306        if schema.fields().len() != columns.len() {
307            return Err(ArrowError::InvalidArgumentError(format!(
308                "number of columns({}) must match number of fields({}) in schema",
309                columns.len(),
310                schema.fields().len(),
311            )));
312        }
313
314        let row_count = options
315            .row_count
316            .or_else(|| columns.first().map(|col| col.len()))
317            .ok_or_else(|| {
318                ArrowError::InvalidArgumentError(
319                    "must either specify a row count or at least one column".to_string(),
320                )
321            })?;
322
323        for (c, f) in columns.iter().zip(&schema.fields) {
324            if !f.is_nullable() && c.null_count() > 0 {
325                return Err(ArrowError::InvalidArgumentError(format!(
326                    "Column '{}' is declared as non-nullable but contains null values",
327                    f.name()
328                )));
329            }
330        }
331
332        // check that all columns have the same row count
333        if columns.iter().any(|c| c.len() != row_count) {
334            let err = match options.row_count {
335                Some(_) => "all columns in a record batch must have the specified row count",
336                None => "all columns in a record batch must have the same length",
337            };
338            return Err(ArrowError::InvalidArgumentError(err.to_string()));
339        }
340
341        // function for comparing column type and field type
342        // return true if 2 types are not matched
343        let type_not_match = if options.match_field_names {
344            |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| col_type != field_type
345        } else {
346            |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| {
347                !col_type.equals_datatype(field_type)
348            }
349        };
350
351        // check that all columns match the schema
352        let not_match = columns
353            .iter()
354            .zip(schema.fields().iter())
355            .map(|(col, field)| (col.data_type(), field.data_type()))
356            .enumerate()
357            .find(type_not_match);
358
359        if let Some((i, (col_type, field_type))) = not_match {
360            return Err(ArrowError::InvalidArgumentError(format!(
361                "column types must match schema types, expected {field_type:?} but found {col_type:?} at column index {i}")));
362        }
363
364        Ok(RecordBatch {
365            schema,
366            columns,
367            row_count,
368        })
369    }
370
371    /// Return the schema, columns and row count of this [`RecordBatch`]
372    pub fn into_parts(self) -> (SchemaRef, Vec<ArrayRef>, usize) {
373        (self.schema, self.columns, self.row_count)
374    }
375
376    /// Override the schema of this [`RecordBatch`]
377    ///
378    /// Returns an error if `schema` is not a superset of the current schema
379    /// as determined by [`Schema::contains`]
380    pub fn with_schema(self, schema: SchemaRef) -> Result<Self, ArrowError> {
381        if !schema.contains(self.schema.as_ref()) {
382            return Err(ArrowError::SchemaError(format!(
383                "target schema is not superset of current schema target={schema} current={}",
384                self.schema
385            )));
386        }
387
388        Ok(Self {
389            schema,
390            columns: self.columns,
391            row_count: self.row_count,
392        })
393    }
394
395    /// Returns the [`Schema`] of the record batch.
396    pub fn schema(&self) -> SchemaRef {
397        self.schema.clone()
398    }
399
400    /// Returns a reference to the [`Schema`] of the record batch.
401    pub fn schema_ref(&self) -> &SchemaRef {
402        &self.schema
403    }
404
405    /// Projects the schema onto the specified columns
406    pub fn project(&self, indices: &[usize]) -> Result<RecordBatch, ArrowError> {
407        let projected_schema = self.schema.project(indices)?;
408        let batch_fields = indices
409            .iter()
410            .map(|f| {
411                self.columns.get(*f).cloned().ok_or_else(|| {
412                    ArrowError::SchemaError(format!(
413                        "project index {} out of bounds, max field {}",
414                        f,
415                        self.columns.len()
416                    ))
417                })
418            })
419            .collect::<Result<Vec<_>, _>>()?;
420
421        RecordBatch::try_new_with_options(
422            SchemaRef::new(projected_schema),
423            batch_fields,
424            &RecordBatchOptions {
425                match_field_names: true,
426                row_count: Some(self.row_count),
427            },
428        )
429    }
430
431    /// Normalize a semi-structured [`RecordBatch`] into a flat table.
432    ///
433    /// Nested [`Field`]s will generate names separated by `separator`, up to a depth of `max_level`
434    /// (unlimited if `None`).
435    ///
436    /// e.g. given a [`RecordBatch`] with schema:
437    ///
438    /// ```text
439    ///     "foo": StructArray<"bar": Utf8>
440    /// ```
441    ///
442    /// A separator of `"."` would generate a batch with the schema:
443    ///
444    /// ```text
445    ///     "foo.bar": Utf8
446    /// ```
447    ///
448    /// Note that giving a depth of `Some(0)` to `max_level` is the same as passing in `None`;
449    /// it will be treated as unlimited.
450    ///
451    /// # Example
452    ///
453    /// ```
454    /// # use std::sync::Arc;
455    /// # use arrow_array::{ArrayRef, Int64Array, StringArray, StructArray, RecordBatch};
456    /// # use arrow_schema::{DataType, Field, Fields, Schema};
457    /// #
458    /// let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
459    /// let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
460    ///
461    /// let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
462    /// let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
463    ///
464    /// let a = Arc::new(StructArray::from(vec![
465    ///     (animals_field.clone(), Arc::new(animals.clone()) as ArrayRef),
466    ///     (n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef),
467    /// ]));
468    ///
469    /// let schema = Schema::new(vec![
470    ///     Field::new(
471    ///         "a",
472    ///         DataType::Struct(Fields::from(vec![animals_field, n_legs_field])),
473    ///         false,
474    ///     )
475    /// ]);
476    ///
477    /// let normalized = RecordBatch::try_new(Arc::new(schema), vec![a])
478    ///     .expect("valid conversion")
479    ///     .normalize(".", None)
480    ///     .expect("valid normalization");
481    ///
482    /// let expected = RecordBatch::try_from_iter_with_nullable(vec![
483    ///     ("a.animals", animals.clone(), true),
484    ///     ("a.n_legs", n_legs.clone(), true),
485    /// ])
486    /// .expect("valid conversion");
487    ///
488    /// assert_eq!(expected, normalized);
489    /// ```
490    pub fn normalize(&self, separator: &str, max_level: Option<usize>) -> Result<Self, ArrowError> {
491        let max_level = match max_level.unwrap_or(usize::MAX) {
492            0 => usize::MAX,
493            val => val,
494        };
495        let mut stack: Vec<(usize, &ArrayRef, Vec<&str>, &FieldRef)> = self
496            .columns
497            .iter()
498            .zip(self.schema.fields())
499            .rev()
500            .map(|(c, f)| {
501                let name_vec: Vec<&str> = vec![f.name()];
502                (0, c, name_vec, f)
503            })
504            .collect();
505        let mut columns: Vec<ArrayRef> = Vec::new();
506        let mut fields: Vec<FieldRef> = Vec::new();
507
508        while let Some((depth, c, name, field_ref)) = stack.pop() {
509            match field_ref.data_type() {
510                DataType::Struct(ff) if depth < max_level => {
511                    // Need to zip these in reverse to maintain original order
512                    for (cff, fff) in c.as_struct().columns().iter().zip(ff.into_iter()).rev() {
513                        let mut name = name.clone();
514                        name.push(separator);
515                        name.push(fff.name());
516                        stack.push((depth + 1, cff, name, fff))
517                    }
518                }
519                _ => {
520                    let updated_field = Field::new(
521                        name.concat(),
522                        field_ref.data_type().clone(),
523                        field_ref.is_nullable(),
524                    );
525                    columns.push(c.clone());
526                    fields.push(Arc::new(updated_field));
527                }
528            }
529        }
530        RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
531    }
532
533    /// Returns the number of columns in the record batch.
534    ///
535    /// # Example
536    ///
537    /// ```
538    /// # use std::sync::Arc;
539    /// # use arrow_array::{Int32Array, RecordBatch};
540    /// # use arrow_schema::{DataType, Field, Schema};
541    ///
542    /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
543    /// let schema = Schema::new(vec![
544    ///     Field::new("id", DataType::Int32, false)
545    /// ]);
546    ///
547    /// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)]).unwrap();
548    ///
549    /// assert_eq!(batch.num_columns(), 1);
550    /// ```
551    pub fn num_columns(&self) -> usize {
552        self.columns.len()
553    }
554
555    /// Returns the number of rows in each column.
556    ///
557    /// # Example
558    ///
559    /// ```
560    /// # use std::sync::Arc;
561    /// # use arrow_array::{Int32Array, RecordBatch};
562    /// # use arrow_schema::{DataType, Field, Schema};
563    ///
564    /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
565    /// let schema = Schema::new(vec![
566    ///     Field::new("id", DataType::Int32, false)
567    /// ]);
568    ///
569    /// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)]).unwrap();
570    ///
571    /// assert_eq!(batch.num_rows(), 5);
572    /// ```
573    pub fn num_rows(&self) -> usize {
574        self.row_count
575    }
576
577    /// Get a reference to a column's array by index.
578    ///
579    /// # Panics
580    ///
581    /// Panics if `index` is outside of `0..num_columns`.
582    pub fn column(&self, index: usize) -> &ArrayRef {
583        &self.columns[index]
584    }
585
586    /// Get a reference to a column's array by name.
587    pub fn column_by_name(&self, name: &str) -> Option<&ArrayRef> {
588        self.schema()
589            .column_with_name(name)
590            .map(|(index, _)| &self.columns[index])
591    }
592
593    /// Get a reference to all columns in the record batch.
594    pub fn columns(&self) -> &[ArrayRef] {
595        &self.columns[..]
596    }
597
598    /// Remove column by index and return it.
599    ///
600    /// Return the `ArrayRef` if the column is removed.
601    ///
602    /// # Panics
603    ///
604    /// Panics if `index`` out of bounds.
605    ///
606    /// # Example
607    ///
608    /// ```
609    /// use std::sync::Arc;
610    /// use arrow_array::{BooleanArray, Int32Array, RecordBatch};
611    /// use arrow_schema::{DataType, Field, Schema};
612    /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
613    /// let bool_array = BooleanArray::from(vec![true, false, false, true, true]);
614    /// let schema = Schema::new(vec![
615    ///     Field::new("id", DataType::Int32, false),
616    ///     Field::new("bool", DataType::Boolean, false),
617    /// ]);
618    ///
619    /// let mut batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array), Arc::new(bool_array)]).unwrap();
620    ///
621    /// let removed_column = batch.remove_column(0);
622    /// assert_eq!(removed_column.as_any().downcast_ref::<Int32Array>().unwrap(), &Int32Array::from(vec![1, 2, 3, 4, 5]));
623    /// assert_eq!(batch.num_columns(), 1);
624    /// ```
625    pub fn remove_column(&mut self, index: usize) -> ArrayRef {
626        let mut builder = SchemaBuilder::from(self.schema.as_ref());
627        builder.remove(index);
628        self.schema = Arc::new(builder.finish());
629        self.columns.remove(index)
630    }
631
632    /// Return a new RecordBatch where each column is sliced
633    /// according to `offset` and `length`
634    ///
635    /// # Panics
636    ///
637    /// Panics if `offset` with `length` is greater than column length.
638    pub fn slice(&self, offset: usize, length: usize) -> RecordBatch {
639        assert!((offset + length) <= self.num_rows());
640
641        let columns = self
642            .columns()
643            .iter()
644            .map(|column| column.slice(offset, length))
645            .collect();
646
647        Self {
648            schema: self.schema.clone(),
649            columns,
650            row_count: length,
651        }
652    }
653
654    /// Create a `RecordBatch` from an iterable list of pairs of the
655    /// form `(field_name, array)`, with the same requirements on
656    /// fields and arrays as [`RecordBatch::try_new`]. This method is
657    /// often used to create a single `RecordBatch` from arrays,
658    /// e.g. for testing.
659    ///
660    /// The resulting schema is marked as nullable for each column if
661    /// the array for that column is has any nulls. To explicitly
662    /// specify nullibility, use [`RecordBatch::try_from_iter_with_nullable`]
663    ///
664    /// Example:
665    /// ```
666    /// # use std::sync::Arc;
667    /// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray};
668    ///
669    /// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
670    /// let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"]));
671    ///
672    /// let record_batch = RecordBatch::try_from_iter(vec![
673    ///   ("a", a),
674    ///   ("b", b),
675    /// ]);
676    /// ```
677    /// Another way to quickly create a [`RecordBatch`] is to use the [`record_batch!`] macro,
678    /// which is particularly helpful for rapid prototyping and testing.
679    ///
680    /// Example:
681    ///
682    /// ```rust
683    /// use arrow_array::record_batch;
684    /// let batch = record_batch!(
685    ///     ("a", Int32, [1, 2, 3]),
686    ///     ("b", Float64, [Some(4.0), None, Some(5.0)]),
687    ///     ("c", Utf8, ["alpha", "beta", "gamma"])
688    /// );
689    /// ```
690    pub fn try_from_iter<I, F>(value: I) -> Result<Self, ArrowError>
691    where
692        I: IntoIterator<Item = (F, ArrayRef)>,
693        F: AsRef<str>,
694    {
695        // TODO: implement `TryFrom` trait, once
696        // https://github.com/rust-lang/rust/issues/50133 is no longer an
697        // issue
698        let iter = value.into_iter().map(|(field_name, array)| {
699            let nullable = array.null_count() > 0;
700            (field_name, array, nullable)
701        });
702
703        Self::try_from_iter_with_nullable(iter)
704    }
705
706    /// Create a `RecordBatch` from an iterable list of tuples of the
707    /// form `(field_name, array, nullable)`, with the same requirements on
708    /// fields and arrays as [`RecordBatch::try_new`]. This method is often
709    /// used to create a single `RecordBatch` from arrays, e.g. for
710    /// testing.
711    ///
712    /// Example:
713    /// ```
714    /// # use std::sync::Arc;
715    /// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray};
716    ///
717    /// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
718    /// let b: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), Some("b")]));
719    ///
720    /// // Note neither `a` nor `b` has any actual nulls, but we mark
721    /// // b an nullable
722    /// let record_batch = RecordBatch::try_from_iter_with_nullable(vec![
723    ///   ("a", a, false),
724    ///   ("b", b, true),
725    /// ]);
726    /// ```
727    pub fn try_from_iter_with_nullable<I, F>(value: I) -> Result<Self, ArrowError>
728    where
729        I: IntoIterator<Item = (F, ArrayRef, bool)>,
730        F: AsRef<str>,
731    {
732        let iter = value.into_iter();
733        let capacity = iter.size_hint().0;
734        let mut schema = SchemaBuilder::with_capacity(capacity);
735        let mut columns = Vec::with_capacity(capacity);
736
737        for (field_name, array, nullable) in iter {
738            let field_name = field_name.as_ref();
739            schema.push(Field::new(field_name, array.data_type().clone(), nullable));
740            columns.push(array);
741        }
742
743        let schema = Arc::new(schema.finish());
744        RecordBatch::try_new(schema, columns)
745    }
746
747    /// Returns the total number of bytes of memory occupied physically by this batch.
748    ///
749    /// Note that this does not always correspond to the exact memory usage of a
750    /// `RecordBatch` (might overestimate), since multiple columns can share the same
751    /// buffers or slices thereof, the memory used by the shared buffers might be
752    /// counted multiple times.
753    pub fn get_array_memory_size(&self) -> usize {
754        self.columns()
755            .iter()
756            .map(|array| array.get_array_memory_size())
757            .sum()
758    }
759}
760
/// Options that control the behaviour used when creating a [`RecordBatch`].
///
/// `RecordBatchOptions::new()` (and `Default`) sets `match_field_names: true`
/// and `row_count: None`.
#[derive(Debug)]
#[non_exhaustive]
pub struct RecordBatchOptions {
    /// Match field names of structs and lists. If set to `true`, the names must match.
    ///
    /// When `true`, each column's data type must compare equal (`==`) to its schema
    /// field's data type. When `false`, the looser `DataType::equals_datatype`
    /// comparison is used instead; per this flag's intent, it does not require nested
    /// struct/list field names to match.
    pub match_field_names: bool,

    /// Optional row count, useful for specifying a row count for a RecordBatch with no columns
    ///
    /// When `Some(n)`, every column must have exactly `n` rows. When `None`, the row
    /// count is taken from the first column, and at least one column is required.
    pub row_count: Option<usize>,
}
771
772impl RecordBatchOptions {
773    /// Creates a new `RecordBatchOptions`
774    pub fn new() -> Self {
775        Self {
776            match_field_names: true,
777            row_count: None,
778        }
779    }
780    /// Sets the row_count of RecordBatchOptions and returns self
781    pub fn with_row_count(mut self, row_count: Option<usize>) -> Self {
782        self.row_count = row_count;
783        self
784    }
785    /// Sets the match_field_names of RecordBatchOptions and returns self
786    pub fn with_match_field_names(mut self, match_field_names: bool) -> Self {
787        self.match_field_names = match_field_names;
788        self
789    }
790}
791impl Default for RecordBatchOptions {
792    fn default() -> Self {
793        Self::new()
794    }
795}
796impl From<StructArray> for RecordBatch {
797    fn from(value: StructArray) -> Self {
798        let row_count = value.len();
799        let (fields, columns, nulls) = value.into_parts();
800        assert_eq!(
801            nulls.map(|n| n.null_count()).unwrap_or_default(),
802            0,
803            "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"
804        );
805
806        RecordBatch {
807            schema: Arc::new(Schema::new(fields)),
808            row_count,
809            columns,
810        }
811    }
812}
813
814impl From<&StructArray> for RecordBatch {
815    fn from(struct_array: &StructArray) -> Self {
816        struct_array.clone().into()
817    }
818}
819
820impl Index<&str> for RecordBatch {
821    type Output = ArrayRef;
822
823    /// Get a reference to a column's array by name.
824    ///
825    /// # Panics
826    ///
827    /// Panics if the name is not in the schema.
828    fn index(&self, name: &str) -> &Self::Output {
829        self.column_by_name(name).unwrap()
830    }
831}
832
/// Generic implementation of [RecordBatchReader] that wraps an iterator.
///
/// The schema given to [`RecordBatchIterator::new`] is stored as-is and returned by
/// [`RecordBatchReader::schema`]. It is not checked against the batches the wrapped
/// iterator yields, so callers are responsible for keeping the two consistent.
///
/// # Example
///
/// ```
/// # use std::sync::Arc;
/// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray, RecordBatchIterator, RecordBatchReader};
/// #
/// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
/// let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"]));
///
/// let record_batch = RecordBatch::try_from_iter(vec![
///   ("a", a),
///   ("b", b),
/// ]).unwrap();
///
/// let batches: Vec<RecordBatch> = vec![record_batch.clone(), record_batch.clone()];
///
/// let mut reader = RecordBatchIterator::new(batches.into_iter().map(Ok), record_batch.schema());
///
/// assert_eq!(reader.schema(), record_batch.schema());
/// assert_eq!(reader.next().unwrap().unwrap(), record_batch);
/// # assert_eq!(reader.next().unwrap().unwrap(), record_batch);
/// # assert!(reader.next().is_none());
/// ```
pub struct RecordBatchIterator<I>
where
    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
{
    /// The wrapped iterator; `next` and `size_hint` delegate directly to it.
    inner: I::IntoIter,
    /// Schema reported by [`RecordBatchReader::schema`].
    inner_schema: SchemaRef,
}
865
866impl<I> RecordBatchIterator<I>
867where
868    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
869{
870    /// Create a new [RecordBatchIterator].
871    ///
872    /// If `iter` is an infallible iterator, use `.map(Ok)`.
873    pub fn new(iter: I, schema: SchemaRef) -> Self {
874        Self {
875            inner: iter.into_iter(),
876            inner_schema: schema,
877        }
878    }
879}
880
881impl<I> Iterator for RecordBatchIterator<I>
882where
883    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
884{
885    type Item = I::Item;
886
887    fn next(&mut self) -> Option<Self::Item> {
888        self.inner.next()
889    }
890
891    fn size_hint(&self) -> (usize, Option<usize>) {
892        self.inner.size_hint()
893    }
894}
895
896impl<I> RecordBatchReader for RecordBatchIterator<I>
897where
898    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
899{
900    fn schema(&self) -> SchemaRef {
901        self.inner_schema.clone()
902    }
903}
904
#[cfg(test)]
mod tests {
    // Unit tests for `RecordBatch` construction, slicing, equality, normalization,
    // projection, conversions and the `RecordBatchIterator` reader.
    use super::*;
    use crate::{
        BooleanArray, Int32Array, Int64Array, Int8Array, ListArray, StringArray, StringViewArray,
    };
    use arrow_buffer::{Buffer, ToByteSlice};
    use arrow_data::{ArrayData, ArrayDataBuilder};
    use arrow_schema::Fields;
    use std::collections::HashMap;

    #[test]
    fn create_record_batch() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]);

        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);

        let record_batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
        check_batch(record_batch, 5)
    }

    #[test]
    fn create_string_view_record_batch() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8View, false),
        ]);

        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = StringViewArray::from(vec!["a", "b", "c", "d", "e"]);

        let record_batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();

        assert_eq!(5, record_batch.num_rows());
        assert_eq!(2, record_batch.num_columns());
        assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
        assert_eq!(
            &DataType::Utf8View,
            record_batch.schema().field(1).data_type()
        );
        assert_eq!(5, record_batch.column(0).len());
        assert_eq!(5, record_batch.column(1).len());
    }

    #[test]
    fn byte_size_should_not_regress() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]);

        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);

        let record_batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
        // Pins the current memory footprint so accidental growth is noticed; update this
        // value deliberately if the underlying layout changes.
        assert_eq!(record_batch.get_array_memory_size(), 364);
    }

    /// Asserts that `record_batch` has the `(Int32, Utf8)` two-column shape used by
    /// several tests in this module, with `num_rows` rows in each column.
    fn check_batch(record_batch: RecordBatch, num_rows: usize) {
        assert_eq!(num_rows, record_batch.num_rows());
        assert_eq!(2, record_batch.num_columns());
        assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
        assert_eq!(&DataType::Utf8, record_batch.schema().field(1).data_type());
        assert_eq!(num_rows, record_batch.column(0).len());
        assert_eq!(num_rows, record_batch.column(1).len());
    }

    #[test]
    #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
    fn create_record_batch_slice() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]);
        let expected_schema = schema.clone();

        let a = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
        let b = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "h", "i"]);

        let record_batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();

        let offset = 2;
        let length = 5;
        let record_batch_slice = record_batch.slice(offset, length);

        assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
        check_batch(record_batch_slice, 5);

        let offset = 2;
        let length = 0;
        let record_batch_slice = record_batch.slice(offset, length);

        assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
        check_batch(record_batch_slice, 0);

        // Out of bounds: 2 + 10 > 8 rows, which is the panic this test expects.
        let offset = 2;
        let length = 10;
        let _record_batch_slice = record_batch.slice(offset, length);
    }

    #[test]
    #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
    fn create_record_batch_slice_empty_batch() {
        let schema = Schema::empty();

        let record_batch = RecordBatch::new_empty(Arc::new(schema));

        let offset = 0;
        let length = 0;
        let record_batch_slice = record_batch.slice(offset, length);
        assert_eq!(0, record_batch_slice.schema().fields().len());

        // Any non-empty slice of an empty batch is out of bounds and must panic.
        let offset = 1;
        let length = 2;
        let _record_batch_slice = record_batch.slice(offset, length);
    }

    #[test]
    fn create_record_batch_try_from_iter() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(2),
            None,
            Some(4),
            Some(5),
        ]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));

        let record_batch =
            RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).expect("valid conversion");

        // `a` contains a null and is therefore expected to be nullable; `b` has no nulls.
        let expected_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, false),
        ]);
        assert_eq!(record_batch.schema().as_ref(), &expected_schema);
        check_batch(record_batch, 5);
    }

    #[test]
    fn create_record_batch_try_from_iter_with_nullable() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));

        // Note there are no nulls in a or b, but we specify that b is nullable
        let record_batch =
            RecordBatch::try_from_iter_with_nullable(vec![("a", a, false), ("b", b, true)])
                .expect("valid conversion");

        let expected_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, true),
        ]);
        assert_eq!(record_batch.schema().as_ref(), &expected_schema);
        check_batch(record_batch, 5);
    }

    #[test]
    fn create_record_batch_schema_mismatch() {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        // Int64 data for an Int32 field must be rejected.
        let a = Int64Array::from(vec![1, 2, 3, 4, 5]);

        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]);
        assert!(batch.is_err());
    }

    #[test]
    fn create_record_batch_field_name_mismatch() {
        let fields = vec![
            Field::new("a1", DataType::Int32, false),
            Field::new_list("a2", Field::new_list_field(DataType::Int8, false), false),
        ];
        let schema = Arc::new(Schema::new(vec![Field::new_struct("a", fields, true)]));

        let a1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
        let a2_child = Int8Array::from(vec![1, 2, 3, 4]);
        let a2 = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new(
            "array",
            DataType::Int8,
            false,
        ))))
        .add_child_data(a2_child.into_data())
        .len(2)
        .add_buffer(Buffer::from([0i32, 3, 4].to_byte_slice()))
        .build()
        .unwrap();
        let a2: ArrayRef = Arc::new(ListArray::from(a2));
        // The struct's first child is named `aa1`, while the schema declares `a1`.
        let a = ArrayDataBuilder::new(DataType::Struct(Fields::from(vec![
            Field::new("aa1", DataType::Int32, false),
            Field::new("a2", a2.data_type().clone(), false),
        ])))
        .add_child_data(a1.into_data())
        .add_child_data(a2.into_data())
        .len(2)
        .build()
        .unwrap();
        let a: ArrayRef = Arc::new(StructArray::from(a));

        // creating the batch with field name validation should fail
        let batch = RecordBatch::try_new(schema.clone(), vec![a.clone()]);
        assert!(batch.is_err());

        // creating the batch without field name validation should pass
        let options = RecordBatchOptions {
            match_field_names: false,
            row_count: None,
        };
        let batch = RecordBatch::try_new_with_options(schema, vec![a], &options);
        assert!(batch.is_ok());
    }

    #[test]
    fn create_record_batch_record_mismatch() {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = Int32Array::from(vec![1, 2, 3, 4, 5]);

        // Two columns for a one-field schema must be rejected.
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]);
        assert!(batch.is_err());
    }

    #[test]
    fn create_record_batch_from_struct_array() {
        let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
        let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
        let struct_array = StructArray::from(vec![
            (
                Arc::new(Field::new("b", DataType::Boolean, false)),
                boolean.clone() as ArrayRef,
            ),
            (
                Arc::new(Field::new("c", DataType::Int32, false)),
                int.clone() as ArrayRef,
            ),
        ]);

        let batch = RecordBatch::from(&struct_array);
        assert_eq!(2, batch.num_columns());
        assert_eq!(4, batch.num_rows());
        assert_eq!(
            struct_array.data_type(),
            &DataType::Struct(batch.schema().fields().clone())
        );
        assert_eq!(batch.column(0).as_ref(), boolean.as_ref());
        assert_eq!(batch.column(1).as_ref(), int.as_ref());
    }

    #[test]
    fn record_batch_equality() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
        )
        .unwrap();

        assert_eq!(batch1, batch2);
    }

    /// validates if the record batch can be accessed using `column_name` as index i.e. `record_batch["column_name"]`
    #[test]
    fn record_batch_index_access() {
        let id_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
        let val_arr = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);
        let record_batch =
            RecordBatch::try_new(Arc::new(schema1), vec![id_arr.clone(), val_arr.clone()]).unwrap();

        assert_eq!(record_batch["id"].as_ref(), id_arr.as_ref());
        assert_eq!(record_batch["val"].as_ref(), val_arr.as_ref());
    }

    /// Indexing by a column name that is not in the schema must panic.
    #[test]
    #[should_panic]
    fn record_batch_index_access_missing_column() {
        let id_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
        let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
        let record_batch = RecordBatch::try_new(Arc::new(schema), vec![id_arr]).unwrap();

        let _column = &record_batch["missing"];
    }

    #[test]
    fn record_batch_vals_ne() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
        )
        .unwrap();

        assert_ne!(batch1, batch2);
    }

    #[test]
    fn record_batch_column_names_ne() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("num", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
        )
        .unwrap();

        assert_ne!(batch1, batch2);
    }

    #[test]
    fn record_batch_column_number_ne() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let num_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
            Field::new("num", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2), Arc::new(num_arr2)],
        )
        .unwrap();

        assert_ne!(batch1, batch2);
    }

    #[test]
    fn record_batch_row_count_ne() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        // Identical to `schema1` so that the row count is the only difference between
        // the two batches.
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
        )
        .unwrap();

        assert_eq!(batch1.schema(), batch2.schema());
        assert_ne!(batch1.num_rows(), batch2.num_rows());
        assert_ne!(batch1, batch2);
    }

    #[test]
    fn normalize_simple() {
        let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
        let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
        let year: ArrayRef = Arc::new(Int64Array::from(vec![None, Some(2022)]));

        let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
        let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
        let year_field = Arc::new(Field::new("year", DataType::Int64, true));

        let a = Arc::new(StructArray::from(vec![
            (animals_field.clone(), animals.clone()),
            (n_legs_field.clone(), n_legs.clone()),
            (year_field.clone(), year.clone()),
        ]));

        let month = Arc::new(Int64Array::from(vec![Some(4), Some(6)]));

        let schema = Schema::new(vec![
            Field::new(
                "a",
                DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
                false,
            ),
            Field::new("month", DataType::Int64, true),
        ]);

        let normalized =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![a.clone(), month.clone()])
                .expect("valid conversion")
                .normalize(".", Some(0))
                .expect("valid normalization");

        let expected = RecordBatch::try_from_iter_with_nullable(vec![
            ("a.animals", animals.clone(), true),
            ("a.n_legs", n_legs.clone(), true),
            ("a.year", year.clone(), true),
            ("month", month.clone(), true),
        ])
        .expect("valid conversion");

        assert_eq!(expected, normalized);

        // check 0 and None have the same effect
        let normalized = RecordBatch::try_new(Arc::new(schema), vec![a, month.clone()])
            .expect("valid conversion")
            .normalize(".", None)
            .expect("valid normalization");

        assert_eq!(expected, normalized);
    }

    #[test]
    fn normalize_nested() {
        // Initialize schema
        let a = Arc::new(Field::new("a", DataType::Int64, true));
        let b = Arc::new(Field::new("b", DataType::Int64, false));
        let c = Arc::new(Field::new("c", DataType::Int64, true));

        let one = Arc::new(Field::new(
            "1",
            DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
            false,
        ));
        let two = Arc::new(Field::new(
            "2",
            DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
            true,
        ));

        let exclamation = Arc::new(Field::new(
            "!",
            DataType::Struct(Fields::from(vec![one.clone(), two.clone()])),
            false,
        ));

        let schema = Schema::new(vec![exclamation.clone()]);

        // Initialize fields
        let a_field = Int64Array::from(vec![Some(0), Some(1)]);
        let b_field = Int64Array::from(vec![Some(2), Some(3)]);
        let c_field = Int64Array::from(vec![None, Some(4)]);

        let one_field = StructArray::from(vec![
            (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
            (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
            (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
        ]);
        let two_field = StructArray::from(vec![
            (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
            (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
            (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
        ]);

        let exclamation_field = Arc::new(StructArray::from(vec![
            (one.clone(), Arc::new(one_field) as ArrayRef),
            (two.clone(), Arc::new(two_field) as ArrayRef),
        ]));

        // Normalize top level
        let normalized =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![exclamation_field.clone()])
                .expect("valid conversion")
                .normalize(".", Some(1))
                .expect("valid normalization");

        let expected = RecordBatch::try_from_iter_with_nullable(vec![
            (
                "!.1",
                Arc::new(StructArray::from(vec![
                    (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
                    (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
                    (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
                ])) as ArrayRef,
                false,
            ),
            (
                "!.2",
                Arc::new(StructArray::from(vec![
                    (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
                    (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
                    (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
                ])) as ArrayRef,
                true,
            ),
        ])
        .expect("valid conversion");

        assert_eq!(expected, normalized);

        // Normalize all levels
        let normalized = RecordBatch::try_new(Arc::new(schema), vec![exclamation_field])
            .expect("valid conversion")
            .normalize(".", None)
            .expect("valid normalization");

        let expected = RecordBatch::try_from_iter_with_nullable(vec![
            ("!.1.a", Arc::new(a_field.clone()) as ArrayRef, true),
            ("!.1.b", Arc::new(b_field.clone()) as ArrayRef, false),
            ("!.1.c", Arc::new(c_field.clone()) as ArrayRef, true),
            ("!.2.a", Arc::new(a_field.clone()) as ArrayRef, true),
            ("!.2.b", Arc::new(b_field.clone()) as ArrayRef, false),
            ("!.2.c", Arc::new(c_field.clone()) as ArrayRef, true),
        ])
        .expect("valid conversion");

        assert_eq!(expected, normalized);
    }

    #[test]
    fn normalize_empty() {
        let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
        let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
        let year_field = Arc::new(Field::new("year", DataType::Int64, true));

        let schema = Schema::new(vec![
            Field::new(
                "a",
                DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
                false,
            ),
            Field::new("month", DataType::Int64, true),
        ]);

        let normalized = RecordBatch::new_empty(Arc::new(schema.clone()))
            .normalize(".", Some(0))
            .expect("valid normalization");

        // Normalizing an empty batch must agree with normalizing its schema directly.
        let expected = RecordBatch::new_empty(Arc::new(
            schema.normalize(".", Some(0)).expect("valid normalization"),
        ));

        assert_eq!(expected, normalized);
    }

    #[test]
    fn project() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
        let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));

        let record_batch =
            RecordBatch::try_from_iter(vec![("a", a.clone()), ("b", b.clone()), ("c", c.clone())])
                .expect("valid conversion");

        let expected =
            RecordBatch::try_from_iter(vec![("a", a), ("c", c)]).expect("valid conversion");

        assert_eq!(expected, record_batch.project(&[0, 2]).unwrap());
    }

    #[test]
    fn project_empty() {
        let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));

        let record_batch =
            RecordBatch::try_from_iter(vec![("c", c.clone())]).expect("valid conversion");

        // Projecting no columns must still preserve the original row count.
        let expected = RecordBatch::try_new_with_options(
            Arc::new(Schema::empty()),
            vec![],
            &RecordBatchOptions {
                match_field_names: true,
                row_count: Some(3),
            },
        )
        .expect("valid conversion");

        assert_eq!(expected, record_batch.project(&[]).unwrap());
    }

    #[test]
    fn test_no_column_record_batch() {
        let schema = Arc::new(Schema::empty());

        let err = RecordBatch::try_new(schema.clone(), vec![]).unwrap_err();
        assert!(err
            .to_string()
            .contains("must either specify a row count or at least one column"));

        let options = RecordBatchOptions::new().with_row_count(Some(10));

        let ok = RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap();
        assert_eq!(ok.num_rows(), 10);

        let a = ok.slice(2, 5);
        assert_eq!(a.num_rows(), 5);

        let b = ok.slice(5, 0);
        assert_eq!(b.num_rows(), 0);

        assert_ne!(a, b);
        assert_eq!(b, RecordBatch::new_empty(schema))
    }

    #[test]
    fn test_nulls_in_non_nullable_field() {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let maybe_batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int32Array::from(vec![Some(1), None]))],
        );
        assert_eq!("Invalid argument error: Column 'a' is declared as non-nullable but contains null values", format!("{}", maybe_batch.err().unwrap()));
    }

    #[test]
    fn test_record_batch_options() {
        let options = RecordBatchOptions::new()
            .with_match_field_names(false)
            .with_row_count(Some(20));
        assert!(!options.match_field_names);
        assert_eq!(options.row_count.unwrap(), 20)
    }

    #[test]
    #[should_panic(expected = "Cannot convert nullable StructArray to RecordBatch")]
    fn test_from_struct() {
        // Every top-level slot of this struct array is null, so the conversion must panic.
        let s = StructArray::from(ArrayData::new_null(
            // Note child is not nullable
            &DataType::Struct(vec![Field::new("foo", DataType::Int32, false)].into()),
            2,
        ));
        let _ = RecordBatch::from(s);
    }

    #[test]
    fn test_with_schema() {
        let required_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let required_schema = Arc::new(required_schema);
        let nullable_schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
        let nullable_schema = Arc::new(nullable_schema);

        let batch = RecordBatch::try_new(
            required_schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as _],
        )
        .unwrap();

        // Can add nullability
        let batch = batch.with_schema(nullable_schema.clone()).unwrap();

        // Cannot remove nullability
        batch.clone().with_schema(required_schema).unwrap_err();

        // Can add metadata
        let metadata = vec![("foo".to_string(), "bar".to_string())]
            .into_iter()
            .collect();
        let metadata_schema = nullable_schema.as_ref().clone().with_metadata(metadata);
        let batch = batch.with_schema(Arc::new(metadata_schema)).unwrap();

        // Cannot remove metadata
        batch.with_schema(nullable_schema).unwrap_err();
    }

    #[test]
    fn test_boxed_reader() {
        // Make sure we can pass a boxed reader to a function generic over
        // RecordBatchReader.
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let schema = Arc::new(schema);

        let reader = RecordBatchIterator::new(std::iter::empty(), schema);
        let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);

        fn get_size(reader: impl RecordBatchReader) -> usize {
            reader.size_hint().0
        }

        let size = get_size(reader);
        assert_eq!(size, 0);
    }

    #[test]
    fn test_remove_column_maintains_schema_metadata() {
        let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let bool_array = BooleanArray::from(vec![true, false, false, true, true]);

        let mut metadata = HashMap::new();
        metadata.insert("foo".to_string(), "bar".to_string());
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("bool", DataType::Boolean, false),
        ])
        .with_metadata(metadata);

        let mut batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![Arc::new(id_array), Arc::new(bool_array)],
        )
        .unwrap();

        let _removed_column = batch.remove_column(0);
        assert_eq!(batch.schema().metadata().len(), 1);
        assert_eq!(
            batch.schema().metadata().get("foo").unwrap().as_str(),
            "bar"
        );
    }
}