arrow_array/
record_batch.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! A two-dimensional batch of column-oriented data with a defined
19//! [schema](arrow_schema::Schema).
20
21use crate::cast::AsArray;
22use crate::{Array, ArrayRef, StructArray, new_empty_array};
23use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef};
24use std::ops::Index;
25use std::sync::Arc;
26
/// Trait for types that can read `RecordBatch`'s.
///
/// A reader is an [`Iterator`] yielding `Result<RecordBatch, ArrowError>` items
/// that additionally exposes the schema of the batches it produces.
///
/// To create from an iterator, see [RecordBatchIterator].
pub trait RecordBatchReader: Iterator<Item = Result<RecordBatch, ArrowError>> {
    /// Returns the schema of this `RecordBatchReader`.
    ///
    /// Implementations of this trait should guarantee that every `RecordBatch` returned by
    /// this reader has the same schema as the one returned from this method.
    fn schema(&self) -> SchemaRef;
}
37
38impl<R: RecordBatchReader + ?Sized> RecordBatchReader for Box<R> {
39    fn schema(&self) -> SchemaRef {
40        self.as_ref().schema()
41    }
42}
43
/// Trait for types that can write `RecordBatch`'s.
pub trait RecordBatchWriter {
    /// Write a single batch to the writer.
    ///
    /// Returns an [`ArrowError`] if the batch could not be written.
    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError>;

    /// Write footer or termination data, then mark the writer as done.
    ///
    /// Takes the writer by value, so it cannot be used after it has been closed.
    fn close(self) -> Result<(), ArrowError>;
}
52
/// Creates an array from a literal slice of values,
/// suitable for rapid testing and development.
///
/// The first argument names the data type and selects the concrete array type.
/// The second argument is a bracketed, comma separated list of values that is
/// collected into a `Vec` and handed to that array type's `from` constructor
/// (`from_vec` for `Binary` and `LargeBinary`). `Null` is the exception: it
/// takes the number of elements instead of a list of values.
///
/// The resulting array is wrapped in a [`std::sync::Arc`].
///
/// Example:
///
/// ```rust
///
/// use arrow_array::create_array;
///
/// let array = create_array!(Int32, [1, 2, 3, 4, 5]);
/// let array = create_array!(Utf8, [Some("a"), Some("b"), None, Some("e")]);
/// let array = create_array!(Nanosecond64, [1, 2, 3]);
/// let array = create_array!(Null, 3);
/// ```
/// Support for limited data types is available. The macro will return a compile error if an unsupported data type is used.
/// Presently supported data types are:
/// - `Boolean`, `Null`
/// - `Decimal32`, `Decimal64`, `Decimal128`, `Decimal256`
/// - `Float16`, `Float32`, `Float64`
/// - `Int8`, `Int16`, `Int32`, `Int64`
/// - `UInt8`, `UInt16`, `UInt32`, `UInt64`
/// - `IntervalDayTime`, `IntervalYearMonth`
/// - `Second`, `Millisecond`, `Microsecond`, `Nanosecond`
/// - `Second32`, `Millisecond32`, `Microsecond64`, `Nanosecond64`
/// - `DurationSecond`, `DurationMillisecond`, `DurationMicrosecond`, `DurationNanosecond`
/// - `TimestampSecond`, `TimestampMillisecond`, `TimestampMicrosecond`, `TimestampNanosecond`
/// - `Utf8`, `Utf8View`, `LargeUtf8`, `Binary`, `LargeBinary`
#[macro_export]
macro_rules! create_array {
    // `@from` is used for those types that have a common method `<type>::from`.
    // Each arm maps a type name to the array type that implements it.
    (@from Boolean) => { $crate::BooleanArray };
    (@from Int8) => { $crate::Int8Array };
    (@from Int16) => { $crate::Int16Array };
    (@from Int32) => { $crate::Int32Array };
    (@from Int64) => { $crate::Int64Array };
    (@from UInt8) => { $crate::UInt8Array };
    (@from UInt16) => { $crate::UInt16Array };
    (@from UInt32) => { $crate::UInt32Array };
    (@from UInt64) => { $crate::UInt64Array };
    (@from Float16) => { $crate::Float16Array };
    (@from Float32) => { $crate::Float32Array };
    (@from Float64) => { $crate::Float64Array };
    (@from Utf8) => { $crate::StringArray };
    (@from Utf8View) => { $crate::StringViewArray };
    (@from LargeUtf8) => { $crate::LargeStringArray };
    (@from IntervalDayTime) => { $crate::IntervalDayTimeArray };
    (@from IntervalYearMonth) => { $crate::IntervalYearMonthArray };
    (@from Second) => { $crate::TimestampSecondArray };
    (@from Millisecond) => { $crate::TimestampMillisecondArray };
    (@from Microsecond) => { $crate::TimestampMicrosecondArray };
    (@from Nanosecond) => { $crate::TimestampNanosecondArray };
    (@from Second32) => { $crate::Time32SecondArray };
    (@from Millisecond32) => { $crate::Time32MillisecondArray };
    (@from Microsecond64) => { $crate::Time64MicrosecondArray };
    // The array type is `Time64NanosecondArray`; the `64` suffix belongs to
    // the macro's type name only, not to the array type.
    (@from Nanosecond64) => { $crate::Time64NanosecondArray };
    (@from DurationSecond) => { $crate::DurationSecondArray };
    (@from DurationMillisecond) => { $crate::DurationMillisecondArray };
    (@from DurationMicrosecond) => { $crate::DurationMicrosecondArray };
    (@from DurationNanosecond) => { $crate::DurationNanosecondArray };
    (@from Decimal32) => { $crate::Decimal32Array };
    (@from Decimal64) => { $crate::Decimal64Array };
    (@from Decimal128) => { $crate::Decimal128Array };
    (@from Decimal256) => { $crate::Decimal256Array };
    (@from TimestampSecond) => { $crate::TimestampSecondArray };
    (@from TimestampMillisecond) => { $crate::TimestampMillisecondArray };
    (@from TimestampMicrosecond) => { $crate::TimestampMicrosecondArray };
    (@from TimestampNanosecond) => { $crate::TimestampNanosecondArray };

    // Any other identifier is rejected at compile time with a readable message.
    (@from $ty: ident) => {
        compile_error!(concat!("Unsupported data type: ", stringify!($ty)))
    };

    // `Null` arrays carry no values, only a length.
    (Null, $size: expr) => {
        std::sync::Arc::new($crate::NullArray::new($size))
    };

    // Binary arrays are built with `from_vec` rather than `from`.
    (Binary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::BinaryArray::from_vec(vec![$($values),*]))
    };

    (LargeBinary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::LargeBinaryArray::from_vec(vec![$($values),*]))
    };

    // Everything else resolves its array type through the `@from` arms above.
    ($ty: tt, [$($values: expr),*]) => {
        std::sync::Arc::new(<$crate::create_array!(@from $ty)>::from(vec![$($values),*]))
    };
}
139
/// Creates a record batch from literal slice of values, suitable for rapid
/// testing and development.
///
/// Each argument is a `(name, type, [values])` triple describing one column.
/// Every field of the generated schema is declared nullable, and the macro
/// evaluates to the `Result` returned by `RecordBatch::try_new`.
///
/// The expansion refers to `arrow_schema::...` paths, so the `arrow_schema`
/// crate must be resolvable at the call site.
///
/// Example:
///
/// ```rust
/// use arrow_array::record_batch;
/// use arrow_schema;
///
/// let batch = record_batch!(
///     ("a", Int32, [1, 2, 3]),
///     ("b", Float64, [Some(4.0), None, Some(5.0)]),
///     ("c", Utf8, ["alpha", "beta", "gamma"])
/// );
/// ```
/// Due to limitation of [`create_array!`] macro, support for limited data types is available.
#[macro_export]
macro_rules! record_batch {
    ($(($name: expr, $type: ident, [$($values: expr),*])),*) => {
        {
            // One nullable field per column description, in declaration order.
            let fields = vec![
                $(
                    arrow_schema::Field::new($name, arrow_schema::DataType::$type, true),
                )*
            ];

            // The columns are built after the schema, in the same order as the fields.
            $crate::RecordBatch::try_new(
                std::sync::Arc::new(arrow_schema::Schema::new(fields)),
                vec![$(
                    $crate::create_array!($type, [$($values),*]),
                )*]
            )
        }
    }
}
177
/// A two-dimensional batch of column-oriented data with a defined
/// [schema](arrow_schema::Schema).
///
/// A `RecordBatch` is a two-dimensional dataset of a number of
/// contiguous arrays, each the same length.
/// A record batch has a schema which must match its arrays'
/// datatypes.
///
/// Record batches are a convenient unit of work for various
/// serialization and computation functions, possibly incremental.
///
/// Use the [`record_batch!`] macro to create a [`RecordBatch`] from
/// literal slice of values, useful for rapid prototyping and testing.
///
/// Example:
/// ```rust
/// use arrow_array::record_batch;
/// let batch = record_batch!(
///     ("a", Int32, [1, 2, 3]),
///     ("b", Float64, [Some(4.0), None, Some(5.0)]),
///     ("c", Utf8, ["alpha", "beta", "gamma"])
/// );
/// ```
#[derive(Clone, Debug, PartialEq)]
pub struct RecordBatch {
    /// Schema describing the columns, shared behind a reference-counted pointer
    schema: SchemaRef,
    /// The column arrays, in the same order as the fields of `schema`
    columns: Vec<Arc<dyn Array>>,

    /// The number of rows in this RecordBatch
    ///
    /// This is stored separately from the columns to handle the case of no columns
    row_count: usize,
}
211
212impl RecordBatch {
213    /// Creates a `RecordBatch` from a schema and columns.
214    ///
215    /// Expects the following:
216    ///
217    ///  * `!columns.is_empty()`
218    ///  * `schema.fields.len() == columns.len()`
219    ///  * `schema.fields[i].data_type() == columns[i].data_type()`
220    ///  * `columns[i].len() == columns[j].len()`
221    ///
222    /// If the conditions are not met, an error is returned.
223    ///
224    /// # Example
225    ///
226    /// ```
227    /// # use std::sync::Arc;
228    /// # use arrow_array::{Int32Array, RecordBatch};
229    /// # use arrow_schema::{DataType, Field, Schema};
230    ///
231    /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
232    /// let schema = Schema::new(vec![
233    ///     Field::new("id", DataType::Int32, false)
234    /// ]);
235    ///
236    /// let batch = RecordBatch::try_new(
237    ///     Arc::new(schema),
238    ///     vec![Arc::new(id_array)]
239    /// ).unwrap();
240    /// ```
241    pub fn try_new(schema: SchemaRef, columns: Vec<ArrayRef>) -> Result<Self, ArrowError> {
242        let options = RecordBatchOptions::new();
243        Self::try_new_impl(schema, columns, &options)
244    }
245
246    /// Creates a `RecordBatch` from a schema and columns, without validation.
247    ///
248    /// See [`Self::try_new`] for the checked version.
249    ///
250    /// # Safety
251    ///
252    /// Expects the following:
253    ///
254    ///  * `schema.fields.len() == columns.len()`
255    ///  * `schema.fields[i].data_type() == columns[i].data_type()`
256    ///  * `columns[i].len() == row_count`
257    ///
258    /// Note: if the schema does not match the underlying data exactly, it can lead to undefined
259    /// behavior, for example, via conversion to a `StructArray`, which in turn could lead
260    /// to incorrect access.
261    pub unsafe fn new_unchecked(
262        schema: SchemaRef,
263        columns: Vec<Arc<dyn Array>>,
264        row_count: usize,
265    ) -> Self {
266        Self {
267            schema,
268            columns,
269            row_count,
270        }
271    }
272
273    /// Creates a `RecordBatch` from a schema and columns, with additional options,
274    /// such as whether to strictly validate field names.
275    ///
276    /// See [`RecordBatch::try_new`] for the expected conditions.
277    pub fn try_new_with_options(
278        schema: SchemaRef,
279        columns: Vec<ArrayRef>,
280        options: &RecordBatchOptions,
281    ) -> Result<Self, ArrowError> {
282        Self::try_new_impl(schema, columns, options)
283    }
284
285    /// Creates a new empty [`RecordBatch`].
286    pub fn new_empty(schema: SchemaRef) -> Self {
287        let columns = schema
288            .fields()
289            .iter()
290            .map(|field| new_empty_array(field.data_type()))
291            .collect();
292
293        RecordBatch {
294            schema,
295            columns,
296            row_count: 0,
297        }
298    }
299
300    /// Validate the schema and columns using [`RecordBatchOptions`]. Returns an error
301    /// if any validation check fails, otherwise returns the created [`Self`]
302    fn try_new_impl(
303        schema: SchemaRef,
304        columns: Vec<ArrayRef>,
305        options: &RecordBatchOptions,
306    ) -> Result<Self, ArrowError> {
307        // check that number of fields in schema match column length
308        if schema.fields().len() != columns.len() {
309            return Err(ArrowError::InvalidArgumentError(format!(
310                "number of columns({}) must match number of fields({}) in schema",
311                columns.len(),
312                schema.fields().len(),
313            )));
314        }
315
316        let row_count = options
317            .row_count
318            .or_else(|| columns.first().map(|col| col.len()))
319            .ok_or_else(|| {
320                ArrowError::InvalidArgumentError(
321                    "must either specify a row count or at least one column".to_string(),
322                )
323            })?;
324
325        for (c, f) in columns.iter().zip(&schema.fields) {
326            if !f.is_nullable() && c.null_count() > 0 {
327                return Err(ArrowError::InvalidArgumentError(format!(
328                    "Column '{}' is declared as non-nullable but contains null values",
329                    f.name()
330                )));
331            }
332        }
333
334        // check that all columns have the same row count
335        if columns.iter().any(|c| c.len() != row_count) {
336            let err = match options.row_count {
337                Some(_) => "all columns in a record batch must have the specified row count",
338                None => "all columns in a record batch must have the same length",
339            };
340            return Err(ArrowError::InvalidArgumentError(err.to_string()));
341        }
342
343        // function for comparing column type and field type
344        // return true if 2 types are not matched
345        let type_not_match = if options.match_field_names {
346            |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| col_type != field_type
347        } else {
348            |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| {
349                !col_type.equals_datatype(field_type)
350            }
351        };
352
353        // check that all columns match the schema
354        let not_match = columns
355            .iter()
356            .zip(schema.fields().iter())
357            .map(|(col, field)| (col.data_type(), field.data_type()))
358            .enumerate()
359            .find(type_not_match);
360
361        if let Some((i, (col_type, field_type))) = not_match {
362            return Err(ArrowError::InvalidArgumentError(format!(
363                "column types must match schema types, expected {field_type} but found {col_type} at column index {i}"
364            )));
365        }
366
367        Ok(RecordBatch {
368            schema,
369            columns,
370            row_count,
371        })
372    }
373
374    /// Return the schema, columns and row count of this [`RecordBatch`]
375    pub fn into_parts(self) -> (SchemaRef, Vec<ArrayRef>, usize) {
376        (self.schema, self.columns, self.row_count)
377    }
378
379    /// Override the schema of this [`RecordBatch`]
380    ///
381    /// Returns an error if `schema` is not a superset of the current schema
382    /// as determined by [`Schema::contains`]
383    ///
384    /// See also [`Self::schema_metadata_mut`].
385    pub fn with_schema(self, schema: SchemaRef) -> Result<Self, ArrowError> {
386        if !schema.contains(self.schema.as_ref()) {
387            return Err(ArrowError::SchemaError(format!(
388                "target schema is not superset of current schema target={schema} current={}",
389                self.schema
390            )));
391        }
392
393        Ok(Self {
394            schema,
395            columns: self.columns,
396            row_count: self.row_count,
397        })
398    }
399
400    /// Returns the [`Schema`] of the record batch.
401    pub fn schema(&self) -> SchemaRef {
402        self.schema.clone()
403    }
404
    /// Returns a reference to the [`Schema`] of the record batch.
    ///
    /// Unlike [`Self::schema`], this borrows the schema handle rather than cloning it.
    pub fn schema_ref(&self) -> &SchemaRef {
        &self.schema
    }
409
410    /// Mutable access to the metadata of the schema.
411    ///
412    /// This allows you to modify [`Schema::metadata`] of [`Self::schema`] in a convenient and fast way.
413    ///
414    /// Note this will clone the entire underlying `Schema` object if it is currently shared
415    ///
416    /// # Example
417    /// ```
418    /// # use std::sync::Arc;
419    /// # use arrow_array::{record_batch, RecordBatch};
420    /// let mut batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
421    /// // Initially, the metadata is empty
422    /// assert!(batch.schema().metadata().get("key").is_none());
423    /// // Insert a key-value pair into the metadata
424    /// batch.schema_metadata_mut().insert("key".into(), "value".into());
425    /// assert_eq!(batch.schema().metadata().get("key"), Some(&String::from("value")));
426    /// ```
427    pub fn schema_metadata_mut(&mut self) -> &mut std::collections::HashMap<String, String> {
428        let schema = Arc::make_mut(&mut self.schema);
429        &mut schema.metadata
430    }
431
432    /// Projects the schema onto the specified columns
433    pub fn project(&self, indices: &[usize]) -> Result<RecordBatch, ArrowError> {
434        let projected_schema = self.schema.project(indices)?;
435        let batch_fields = indices
436            .iter()
437            .map(|f| {
438                self.columns.get(*f).cloned().ok_or_else(|| {
439                    ArrowError::SchemaError(format!(
440                        "project index {} out of bounds, max field {}",
441                        f,
442                        self.columns.len()
443                    ))
444                })
445            })
446            .collect::<Result<Vec<_>, _>>()?;
447
448        unsafe {
449            // Since we're starting from a valid RecordBatch and project
450            // creates a strict subset of the original, there's no need to
451            // redo the validation checks in `try_new_with_options`.
452            Ok(RecordBatch::new_unchecked(
453                SchemaRef::new(projected_schema),
454                batch_fields,
455                self.row_count,
456            ))
457        }
458    }
459
    /// Normalize a semi-structured [`RecordBatch`] into a flat table.
    ///
    /// Nested [`Field`]s will generate names separated by `separator`, up to a depth of `max_level`
    /// (unlimited if `None`).
    ///
    /// e.g. given a [`RecordBatch`] with schema:
    ///
    /// ```text
    ///     "foo": StructArray<"bar": Utf8>
    /// ```
    ///
    /// A separator of `"."` would generate a batch with the schema:
    ///
    /// ```text
    ///     "foo.bar": Utf8
    /// ```
    ///
    /// Note that giving a depth of `Some(0)` to `max_level` is the same as passing in `None`;
    /// it will be treated as unlimited.
    ///
    /// Only struct columns are expanded; every other column is copied over as is.
    /// The flattened fields keep the data type and nullability of the leaf field.
    ///
    /// # Errors
    ///
    /// Returns an error if the flattened columns do not pass the validation of
    /// [`RecordBatch::try_new`].
    ///
    /// # Example
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow_array::{ArrayRef, Int64Array, StringArray, StructArray, RecordBatch};
    /// # use arrow_schema::{DataType, Field, Fields, Schema};
    /// #
    /// let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
    /// let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
    ///
    /// let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
    /// let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
    ///
    /// let a = Arc::new(StructArray::from(vec![
    ///     (animals_field.clone(), Arc::new(animals.clone()) as ArrayRef),
    ///     (n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef),
    /// ]));
    ///
    /// let schema = Schema::new(vec![
    ///     Field::new(
    ///         "a",
    ///         DataType::Struct(Fields::from(vec![animals_field, n_legs_field])),
    ///         false,
    ///     )
    /// ]);
    ///
    /// let normalized = RecordBatch::try_new(Arc::new(schema), vec![a])
    ///     .expect("valid conversion")
    ///     .normalize(".", None)
    ///     .expect("valid normalization");
    ///
    /// let expected = RecordBatch::try_from_iter_with_nullable(vec![
    ///     ("a.animals", animals.clone(), true),
    ///     ("a.n_legs", n_legs.clone(), true),
    /// ])
    /// .expect("valid conversion");
    ///
    /// assert_eq!(expected, normalized);
    /// ```
    pub fn normalize(&self, separator: &str, max_level: Option<usize>) -> Result<Self, ArrowError> {
        // Both `None` and `Some(0)` mean "no depth limit".
        let max_level = match max_level.unwrap_or(usize::MAX) {
            0 => usize::MAX,
            val => val,
        };
        // Work list of (depth, column, name parts, field). The top-level columns
        // are pushed in reverse so that popping visits them in schema order.
        let mut stack: Vec<(usize, &ArrayRef, Vec<&str>, &FieldRef)> = self
            .columns
            .iter()
            .zip(self.schema.fields())
            .rev()
            .map(|(c, f)| {
                let name_vec: Vec<&str> = vec![f.name()];
                (0, c, name_vec, f)
            })
            .collect();
        // Flattened output, filled in visiting order.
        let mut columns: Vec<ArrayRef> = Vec::new();
        let mut fields: Vec<FieldRef> = Vec::new();

        while let Some((depth, c, name, field_ref)) = stack.pop() {
            match field_ref.data_type() {
                // Structs above the depth limit are replaced by their children.
                DataType::Struct(ff) if depth < max_level => {
                    // Need to zip these in reverse to maintain original order
                    for (cff, fff) in c.as_struct().columns().iter().zip(ff.into_iter()).rev() {
                        // Child name parts = parent parts + separator + child name.
                        let mut name = name.clone();
                        name.push(separator);
                        name.push(fff.name());
                        stack.push((depth + 1, cff, name, fff))
                    }
                }
                // Leaf (or struct at the depth limit): emit it under the joined name.
                _ => {
                    let updated_field = Field::new(
                        name.concat(),
                        field_ref.data_type().clone(),
                        field_ref.is_nullable(),
                    );
                    columns.push(c.clone());
                    fields.push(Arc::new(updated_field));
                }
            }
        }
        // NOTE(review): `try_new` takes the row count from the first column, so a
        // result without any columns is reported as an error rather than keeping
        // `self.row_count` -- confirm this is intended.
        RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
    }
561
    /// Returns the number of columns in the record batch.
    ///
    /// This is the number of column arrays held by the batch.
    ///
    /// # Example
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow_array::{Int32Array, RecordBatch};
    /// # use arrow_schema::{DataType, Field, Schema};
    ///
    /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
    /// let schema = Schema::new(vec![
    ///     Field::new("id", DataType::Int32, false)
    /// ]);
    ///
    /// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)]).unwrap();
    ///
    /// assert_eq!(batch.num_columns(), 1);
    /// ```
    pub fn num_columns(&self) -> usize {
        self.columns.len()
    }
583
    /// Returns the number of rows in each column.
    ///
    /// The value is the row count stored on the batch itself, so it is also
    /// available for a batch without columns.
    ///
    /// # Example
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow_array::{Int32Array, RecordBatch};
    /// # use arrow_schema::{DataType, Field, Schema};
    ///
    /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
    /// let schema = Schema::new(vec![
    ///     Field::new("id", DataType::Int32, false)
    /// ]);
    ///
    /// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)]).unwrap();
    ///
    /// assert_eq!(batch.num_rows(), 5);
    /// ```
    pub fn num_rows(&self) -> usize {
        self.row_count
    }
605
    /// Get a reference to a column's array by index.
    ///
    /// See [`Self::column_by_name`] to look a column up by its field name.
    ///
    /// # Panics
    ///
    /// Panics if `index` is outside of `0..num_columns`.
    pub fn column(&self, index: usize) -> &ArrayRef {
        &self.columns[index]
    }
614
615    /// Get a reference to a column's array by name.
616    pub fn column_by_name(&self, name: &str) -> Option<&ArrayRef> {
617        self.schema()
618            .column_with_name(name)
619            .map(|(index, _)| &self.columns[index])
620    }
621
622    /// Get a reference to all columns in the record batch.
623    pub fn columns(&self) -> &[ArrayRef] {
624        &self.columns[..]
625    }
626
627    /// Remove column by index and return it.
628    ///
629    /// Return the `ArrayRef` if the column is removed.
630    ///
631    /// # Panics
632    ///
633    /// Panics if `index`` out of bounds.
634    ///
635    /// # Example
636    ///
637    /// ```
638    /// use std::sync::Arc;
639    /// use arrow_array::{BooleanArray, Int32Array, RecordBatch};
640    /// use arrow_schema::{DataType, Field, Schema};
641    /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
642    /// let bool_array = BooleanArray::from(vec![true, false, false, true, true]);
643    /// let schema = Schema::new(vec![
644    ///     Field::new("id", DataType::Int32, false),
645    ///     Field::new("bool", DataType::Boolean, false),
646    /// ]);
647    ///
648    /// let mut batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array), Arc::new(bool_array)]).unwrap();
649    ///
650    /// let removed_column = batch.remove_column(0);
651    /// assert_eq!(removed_column.as_any().downcast_ref::<Int32Array>().unwrap(), &Int32Array::from(vec![1, 2, 3, 4, 5]));
652    /// assert_eq!(batch.num_columns(), 1);
653    /// ```
654    pub fn remove_column(&mut self, index: usize) -> ArrayRef {
655        let mut builder = SchemaBuilder::from(self.schema.as_ref());
656        builder.remove(index);
657        self.schema = Arc::new(builder.finish());
658        self.columns.remove(index)
659    }
660
661    /// Return a new RecordBatch where each column is sliced
662    /// according to `offset` and `length`
663    ///
664    /// # Panics
665    ///
666    /// Panics if `offset` with `length` is greater than column length.
667    pub fn slice(&self, offset: usize, length: usize) -> RecordBatch {
668        assert!((offset + length) <= self.num_rows());
669
670        let columns = self
671            .columns()
672            .iter()
673            .map(|column| column.slice(offset, length))
674            .collect();
675
676        Self {
677            schema: self.schema.clone(),
678            columns,
679            row_count: length,
680        }
681    }
682
683    /// Create a `RecordBatch` from an iterable list of pairs of the
684    /// form `(field_name, array)`, with the same requirements on
685    /// fields and arrays as [`RecordBatch::try_new`]. This method is
686    /// often used to create a single `RecordBatch` from arrays,
687    /// e.g. for testing.
688    ///
689    /// The resulting schema is marked as nullable for each column if
690    /// the array for that column is has any nulls. To explicitly
691    /// specify nullibility, use [`RecordBatch::try_from_iter_with_nullable`]
692    ///
693    /// Example:
694    /// ```
695    /// # use std::sync::Arc;
696    /// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray};
697    ///
698    /// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
699    /// let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"]));
700    ///
701    /// let record_batch = RecordBatch::try_from_iter(vec![
702    ///   ("a", a),
703    ///   ("b", b),
704    /// ]);
705    /// ```
706    /// Another way to quickly create a [`RecordBatch`] is to use the [`record_batch!`] macro,
707    /// which is particularly helpful for rapid prototyping and testing.
708    ///
709    /// Example:
710    ///
711    /// ```rust
712    /// use arrow_array::record_batch;
713    /// let batch = record_batch!(
714    ///     ("a", Int32, [1, 2, 3]),
715    ///     ("b", Float64, [Some(4.0), None, Some(5.0)]),
716    ///     ("c", Utf8, ["alpha", "beta", "gamma"])
717    /// );
718    /// ```
719    pub fn try_from_iter<I, F>(value: I) -> Result<Self, ArrowError>
720    where
721        I: IntoIterator<Item = (F, ArrayRef)>,
722        F: AsRef<str>,
723    {
724        // TODO: implement `TryFrom` trait, once
725        // https://github.com/rust-lang/rust/issues/50133 is no longer an
726        // issue
727        let iter = value.into_iter().map(|(field_name, array)| {
728            let nullable = array.null_count() > 0;
729            (field_name, array, nullable)
730        });
731
732        Self::try_from_iter_with_nullable(iter)
733    }
734
735    /// Create a `RecordBatch` from an iterable list of tuples of the
736    /// form `(field_name, array, nullable)`, with the same requirements on
737    /// fields and arrays as [`RecordBatch::try_new`]. This method is often
738    /// used to create a single `RecordBatch` from arrays, e.g. for
739    /// testing.
740    ///
741    /// Example:
742    /// ```
743    /// # use std::sync::Arc;
744    /// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray};
745    ///
746    /// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
747    /// let b: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), Some("b")]));
748    ///
749    /// // Note neither `a` nor `b` has any actual nulls, but we mark
750    /// // b an nullable
751    /// let record_batch = RecordBatch::try_from_iter_with_nullable(vec![
752    ///   ("a", a, false),
753    ///   ("b", b, true),
754    /// ]);
755    /// ```
756    pub fn try_from_iter_with_nullable<I, F>(value: I) -> Result<Self, ArrowError>
757    where
758        I: IntoIterator<Item = (F, ArrayRef, bool)>,
759        F: AsRef<str>,
760    {
761        let iter = value.into_iter();
762        let capacity = iter.size_hint().0;
763        let mut schema = SchemaBuilder::with_capacity(capacity);
764        let mut columns = Vec::with_capacity(capacity);
765
766        for (field_name, array, nullable) in iter {
767            let field_name = field_name.as_ref();
768            schema.push(Field::new(field_name, array.data_type().clone(), nullable));
769            columns.push(array);
770        }
771
772        let schema = Arc::new(schema.finish());
773        RecordBatch::try_new(schema, columns)
774    }
775
776    /// Returns the total number of bytes of memory occupied physically by this batch.
777    ///
778    /// Note that this does not always correspond to the exact memory usage of a
779    /// `RecordBatch` (might overestimate), since multiple columns can share the same
780    /// buffers or slices thereof, the memory used by the shared buffers might be
781    /// counted multiple times.
782    pub fn get_array_memory_size(&self) -> usize {
783        self.columns()
784            .iter()
785            .map(|array| array.get_array_memory_size())
786            .sum()
787    }
788}
789
/// Options that control the behaviour used when creating a [`RecordBatch`].
///
/// Start from [`RecordBatchOptions::new`] (or `Default`) and adjust individual
/// settings with the `with_*` methods.
#[derive(Debug)]
#[non_exhaustive]
pub struct RecordBatchOptions {
    /// Match field names of structs and lists. If set to `true`, the names must match.
    pub match_field_names: bool,

    /// Optional row count, useful for specifying a row count for a RecordBatch with no columns
    pub row_count: Option<usize>,
}

impl RecordBatchOptions {
    /// Creates a new `RecordBatchOptions` with field name matching enabled
    /// and no explicit row count.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the row_count of RecordBatchOptions and returns self
    pub fn with_row_count(self, row_count: Option<usize>) -> Self {
        Self { row_count, ..self }
    }

    /// Sets the match_field_names of RecordBatchOptions and returns self
    pub fn with_match_field_names(self, match_field_names: bool) -> Self {
        Self {
            match_field_names,
            ..self
        }
    }
}

impl Default for RecordBatchOptions {
    /// Returns the same options as [`RecordBatchOptions::new`].
    fn default() -> Self {
        Self {
            match_field_names: true,
            row_count: None,
        }
    }
}
825impl From<StructArray> for RecordBatch {
826    fn from(value: StructArray) -> Self {
827        let row_count = value.len();
828        let (fields, columns, nulls) = value.into_parts();
829        assert_eq!(
830            nulls.map(|n| n.null_count()).unwrap_or_default(),
831            0,
832            "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"
833        );
834
835        RecordBatch {
836            schema: Arc::new(Schema::new(fields)),
837            row_count,
838            columns,
839        }
840    }
841}
842
843impl From<&StructArray> for RecordBatch {
844    fn from(struct_array: &StructArray) -> Self {
845        struct_array.clone().into()
846    }
847}
848
849impl Index<&str> for RecordBatch {
850    type Output = ArrayRef;
851
852    /// Get a reference to a column's array by name.
853    ///
854    /// # Panics
855    ///
856    /// Panics if the name is not in the schema.
857    fn index(&self, name: &str) -> &Self::Output {
858        self.column_by_name(name).unwrap()
859    }
860}
861
862/// Generic implementation of [RecordBatchReader] that wraps an iterator.
863///
864/// # Example
865///
866/// ```
867/// # use std::sync::Arc;
868/// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray, RecordBatchIterator, RecordBatchReader};
869/// #
870/// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
871/// let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"]));
872///
873/// let record_batch = RecordBatch::try_from_iter(vec![
874///   ("a", a),
875///   ("b", b),
876/// ]).unwrap();
877///
878/// let batches: Vec<RecordBatch> = vec![record_batch.clone(), record_batch.clone()];
879///
880/// let mut reader = RecordBatchIterator::new(batches.into_iter().map(Ok), record_batch.schema());
881///
882/// assert_eq!(reader.schema(), record_batch.schema());
883/// assert_eq!(reader.next().unwrap().unwrap(), record_batch);
884/// # assert_eq!(reader.next().unwrap().unwrap(), record_batch);
885/// # assert!(reader.next().is_none());
886/// ```
887pub struct RecordBatchIterator<I>
888where
889    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
890{
891    inner: I::IntoIter,
892    inner_schema: SchemaRef,
893}
894
895impl<I> RecordBatchIterator<I>
896where
897    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
898{
899    /// Create a new [RecordBatchIterator].
900    ///
901    /// If `iter` is an infallible iterator, use `.map(Ok)`.
902    pub fn new(iter: I, schema: SchemaRef) -> Self {
903        Self {
904            inner: iter.into_iter(),
905            inner_schema: schema,
906        }
907    }
908}
909
910impl<I> Iterator for RecordBatchIterator<I>
911where
912    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
913{
914    type Item = I::Item;
915
916    fn next(&mut self) -> Option<Self::Item> {
917        self.inner.next()
918    }
919
920    fn size_hint(&self) -> (usize, Option<usize>) {
921        self.inner.size_hint()
922    }
923}
924
925impl<I> RecordBatchReader for RecordBatchIterator<I>
926where
927    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
928{
929    fn schema(&self) -> SchemaRef {
930        self.inner_schema.clone()
931    }
932}
933
934#[cfg(test)]
935mod tests {
936    use super::*;
937    use crate::{
938        BooleanArray, Int8Array, Int32Array, Int64Array, ListArray, StringArray, StringViewArray,
939    };
940    use arrow_buffer::{Buffer, ToByteSlice};
941    use arrow_data::{ArrayData, ArrayDataBuilder};
942    use arrow_schema::Fields;
943    use std::collections::HashMap;
944
945    #[test]
946    fn create_record_batch() {
947        let schema = Schema::new(vec![
948            Field::new("a", DataType::Int32, false),
949            Field::new("b", DataType::Utf8, false),
950        ]);
951
952        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
953        let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);
954
955        let record_batch =
956            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
957        check_batch(record_batch, 5)
958    }
959
960    #[test]
961    fn create_string_view_record_batch() {
962        let schema = Schema::new(vec![
963            Field::new("a", DataType::Int32, false),
964            Field::new("b", DataType::Utf8View, false),
965        ]);
966
967        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
968        let b = StringViewArray::from(vec!["a", "b", "c", "d", "e"]);
969
970        let record_batch =
971            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
972
973        assert_eq!(5, record_batch.num_rows());
974        assert_eq!(2, record_batch.num_columns());
975        assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
976        assert_eq!(
977            &DataType::Utf8View,
978            record_batch.schema().field(1).data_type()
979        );
980        assert_eq!(5, record_batch.column(0).len());
981        assert_eq!(5, record_batch.column(1).len());
982    }
983
984    #[test]
985    fn byte_size_should_not_regress() {
986        let schema = Schema::new(vec![
987            Field::new("a", DataType::Int32, false),
988            Field::new("b", DataType::Utf8, false),
989        ]);
990
991        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
992        let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);
993
994        let record_batch =
995            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
996        assert_eq!(record_batch.get_array_memory_size(), 364);
997    }
998
999    fn check_batch(record_batch: RecordBatch, num_rows: usize) {
1000        assert_eq!(num_rows, record_batch.num_rows());
1001        assert_eq!(2, record_batch.num_columns());
1002        assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
1003        assert_eq!(&DataType::Utf8, record_batch.schema().field(1).data_type());
1004        assert_eq!(num_rows, record_batch.column(0).len());
1005        assert_eq!(num_rows, record_batch.column(1).len());
1006    }
1007
1008    #[test]
1009    #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
1010    fn create_record_batch_slice() {
1011        let schema = Schema::new(vec![
1012            Field::new("a", DataType::Int32, false),
1013            Field::new("b", DataType::Utf8, false),
1014        ]);
1015        let expected_schema = schema.clone();
1016
1017        let a = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
1018        let b = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "h", "i"]);
1019
1020        let record_batch =
1021            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
1022
1023        let offset = 2;
1024        let length = 5;
1025        let record_batch_slice = record_batch.slice(offset, length);
1026
1027        assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
1028        check_batch(record_batch_slice, 5);
1029
1030        let offset = 2;
1031        let length = 0;
1032        let record_batch_slice = record_batch.slice(offset, length);
1033
1034        assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
1035        check_batch(record_batch_slice, 0);
1036
1037        let offset = 2;
1038        let length = 10;
1039        let _record_batch_slice = record_batch.slice(offset, length);
1040    }
1041
1042    #[test]
1043    #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
1044    fn create_record_batch_slice_empty_batch() {
1045        let schema = Schema::empty();
1046
1047        let record_batch = RecordBatch::new_empty(Arc::new(schema));
1048
1049        let offset = 0;
1050        let length = 0;
1051        let record_batch_slice = record_batch.slice(offset, length);
1052        assert_eq!(0, record_batch_slice.schema().fields().len());
1053
1054        let offset = 1;
1055        let length = 2;
1056        let _record_batch_slice = record_batch.slice(offset, length);
1057    }
1058
1059    #[test]
1060    fn create_record_batch_try_from_iter() {
1061        let a: ArrayRef = Arc::new(Int32Array::from(vec![
1062            Some(1),
1063            Some(2),
1064            None,
1065            Some(4),
1066            Some(5),
1067        ]));
1068        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
1069
1070        let record_batch =
1071            RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).expect("valid conversion");
1072
1073        let expected_schema = Schema::new(vec![
1074            Field::new("a", DataType::Int32, true),
1075            Field::new("b", DataType::Utf8, false),
1076        ]);
1077        assert_eq!(record_batch.schema().as_ref(), &expected_schema);
1078        check_batch(record_batch, 5);
1079    }
1080
1081    #[test]
1082    fn create_record_batch_try_from_iter_with_nullable() {
1083        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
1084        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
1085
1086        // Note there are no nulls in a or b, but we specify that b is nullable
1087        let record_batch =
1088            RecordBatch::try_from_iter_with_nullable(vec![("a", a, false), ("b", b, true)])
1089                .expect("valid conversion");
1090
1091        let expected_schema = Schema::new(vec![
1092            Field::new("a", DataType::Int32, false),
1093            Field::new("b", DataType::Utf8, true),
1094        ]);
1095        assert_eq!(record_batch.schema().as_ref(), &expected_schema);
1096        check_batch(record_batch, 5);
1097    }
1098
1099    #[test]
1100    fn create_record_batch_schema_mismatch() {
1101        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1102
1103        let a = Int64Array::from(vec![1, 2, 3, 4, 5]);
1104
1105        let err = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap_err();
1106        assert_eq!(
1107            err.to_string(),
1108            "Invalid argument error: column types must match schema types, expected Int32 but found Int64 at column index 0"
1109        );
1110    }
1111
1112    #[test]
1113    fn create_record_batch_field_name_mismatch() {
1114        let fields = vec![
1115            Field::new("a1", DataType::Int32, false),
1116            Field::new_list("a2", Field::new_list_field(DataType::Int8, false), false),
1117        ];
1118        let schema = Arc::new(Schema::new(vec![Field::new_struct("a", fields, true)]));
1119
1120        let a1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
1121        let a2_child = Int8Array::from(vec![1, 2, 3, 4]);
1122        let a2 = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new(
1123            "array",
1124            DataType::Int8,
1125            false,
1126        ))))
1127        .add_child_data(a2_child.into_data())
1128        .len(2)
1129        .add_buffer(Buffer::from([0i32, 3, 4].to_byte_slice()))
1130        .build()
1131        .unwrap();
1132        let a2: ArrayRef = Arc::new(ListArray::from(a2));
1133        let a = ArrayDataBuilder::new(DataType::Struct(Fields::from(vec![
1134            Field::new("aa1", DataType::Int32, false),
1135            Field::new("a2", a2.data_type().clone(), false),
1136        ])))
1137        .add_child_data(a1.into_data())
1138        .add_child_data(a2.into_data())
1139        .len(2)
1140        .build()
1141        .unwrap();
1142        let a: ArrayRef = Arc::new(StructArray::from(a));
1143
1144        // creating the batch with field name validation should fail
1145        let batch = RecordBatch::try_new(schema.clone(), vec![a.clone()]);
1146        assert!(batch.is_err());
1147
1148        // creating the batch without field name validation should pass
1149        let options = RecordBatchOptions {
1150            match_field_names: false,
1151            row_count: None,
1152        };
1153        let batch = RecordBatch::try_new_with_options(schema, vec![a], &options);
1154        assert!(batch.is_ok());
1155    }
1156
1157    #[test]
1158    fn create_record_batch_record_mismatch() {
1159        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1160
1161        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1162        let b = Int32Array::from(vec![1, 2, 3, 4, 5]);
1163
1164        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]);
1165        assert!(batch.is_err());
1166    }
1167
1168    #[test]
1169    fn create_record_batch_from_struct_array() {
1170        let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
1171        let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
1172        let struct_array = StructArray::from(vec![
1173            (
1174                Arc::new(Field::new("b", DataType::Boolean, false)),
1175                boolean.clone() as ArrayRef,
1176            ),
1177            (
1178                Arc::new(Field::new("c", DataType::Int32, false)),
1179                int.clone() as ArrayRef,
1180            ),
1181        ]);
1182
1183        let batch = RecordBatch::from(&struct_array);
1184        assert_eq!(2, batch.num_columns());
1185        assert_eq!(4, batch.num_rows());
1186        assert_eq!(
1187            struct_array.data_type(),
1188            &DataType::Struct(batch.schema().fields().clone())
1189        );
1190        assert_eq!(batch.column(0).as_ref(), boolean.as_ref());
1191        assert_eq!(batch.column(1).as_ref(), int.as_ref());
1192    }
1193
1194    #[test]
1195    fn record_batch_equality() {
1196        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1197        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1198        let schema1 = Schema::new(vec![
1199            Field::new("id", DataType::Int32, false),
1200            Field::new("val", DataType::Int32, false),
1201        ]);
1202
1203        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1204        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1205        let schema2 = Schema::new(vec![
1206            Field::new("id", DataType::Int32, false),
1207            Field::new("val", DataType::Int32, false),
1208        ]);
1209
1210        let batch1 = RecordBatch::try_new(
1211            Arc::new(schema1),
1212            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1213        )
1214        .unwrap();
1215
1216        let batch2 = RecordBatch::try_new(
1217            Arc::new(schema2),
1218            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1219        )
1220        .unwrap();
1221
1222        assert_eq!(batch1, batch2);
1223    }
1224
1225    /// validates if the record batch can be accessed using `column_name` as index i.e. `record_batch["column_name"]`
1226    #[test]
1227    fn record_batch_index_access() {
1228        let id_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
1229        let val_arr = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
1230        let schema1 = Schema::new(vec![
1231            Field::new("id", DataType::Int32, false),
1232            Field::new("val", DataType::Int32, false),
1233        ]);
1234        let record_batch =
1235            RecordBatch::try_new(Arc::new(schema1), vec![id_arr.clone(), val_arr.clone()]).unwrap();
1236
1237        assert_eq!(record_batch["id"].as_ref(), id_arr.as_ref());
1238        assert_eq!(record_batch["val"].as_ref(), val_arr.as_ref());
1239    }
1240
1241    #[test]
1242    fn record_batch_vals_ne() {
1243        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1244        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1245        let schema1 = Schema::new(vec![
1246            Field::new("id", DataType::Int32, false),
1247            Field::new("val", DataType::Int32, false),
1248        ]);
1249
1250        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1251        let val_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1252        let schema2 = Schema::new(vec![
1253            Field::new("id", DataType::Int32, false),
1254            Field::new("val", DataType::Int32, false),
1255        ]);
1256
1257        let batch1 = RecordBatch::try_new(
1258            Arc::new(schema1),
1259            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1260        )
1261        .unwrap();
1262
1263        let batch2 = RecordBatch::try_new(
1264            Arc::new(schema2),
1265            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1266        )
1267        .unwrap();
1268
1269        assert_ne!(batch1, batch2);
1270    }
1271
1272    #[test]
1273    fn record_batch_column_names_ne() {
1274        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1275        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1276        let schema1 = Schema::new(vec![
1277            Field::new("id", DataType::Int32, false),
1278            Field::new("val", DataType::Int32, false),
1279        ]);
1280
1281        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1282        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1283        let schema2 = Schema::new(vec![
1284            Field::new("id", DataType::Int32, false),
1285            Field::new("num", DataType::Int32, false),
1286        ]);
1287
1288        let batch1 = RecordBatch::try_new(
1289            Arc::new(schema1),
1290            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1291        )
1292        .unwrap();
1293
1294        let batch2 = RecordBatch::try_new(
1295            Arc::new(schema2),
1296            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1297        )
1298        .unwrap();
1299
1300        assert_ne!(batch1, batch2);
1301    }
1302
1303    #[test]
1304    fn record_batch_column_number_ne() {
1305        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1306        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1307        let schema1 = Schema::new(vec![
1308            Field::new("id", DataType::Int32, false),
1309            Field::new("val", DataType::Int32, false),
1310        ]);
1311
1312        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1313        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1314        let num_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1315        let schema2 = Schema::new(vec![
1316            Field::new("id", DataType::Int32, false),
1317            Field::new("val", DataType::Int32, false),
1318            Field::new("num", DataType::Int32, false),
1319        ]);
1320
1321        let batch1 = RecordBatch::try_new(
1322            Arc::new(schema1),
1323            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1324        )
1325        .unwrap();
1326
1327        let batch2 = RecordBatch::try_new(
1328            Arc::new(schema2),
1329            vec![Arc::new(id_arr2), Arc::new(val_arr2), Arc::new(num_arr2)],
1330        )
1331        .unwrap();
1332
1333        assert_ne!(batch1, batch2);
1334    }
1335
1336    #[test]
1337    fn record_batch_row_count_ne() {
1338        let id_arr1 = Int32Array::from(vec![1, 2, 3]);
1339        let val_arr1 = Int32Array::from(vec![5, 6, 7]);
1340        let schema1 = Schema::new(vec![
1341            Field::new("id", DataType::Int32, false),
1342            Field::new("val", DataType::Int32, false),
1343        ]);
1344
1345        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1346        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1347        let schema2 = Schema::new(vec![
1348            Field::new("id", DataType::Int32, false),
1349            Field::new("num", DataType::Int32, false),
1350        ]);
1351
1352        let batch1 = RecordBatch::try_new(
1353            Arc::new(schema1),
1354            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1355        )
1356        .unwrap();
1357
1358        let batch2 = RecordBatch::try_new(
1359            Arc::new(schema2),
1360            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1361        )
1362        .unwrap();
1363
1364        assert_ne!(batch1, batch2);
1365    }
1366
1367    #[test]
1368    fn normalize_simple() {
1369        let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
1370        let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
1371        let year: ArrayRef = Arc::new(Int64Array::from(vec![None, Some(2022)]));
1372
1373        let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
1374        let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
1375        let year_field = Arc::new(Field::new("year", DataType::Int64, true));
1376
1377        let a = Arc::new(StructArray::from(vec![
1378            (animals_field.clone(), Arc::new(animals.clone()) as ArrayRef),
1379            (n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef),
1380            (year_field.clone(), Arc::new(year.clone()) as ArrayRef),
1381        ]));
1382
1383        let month = Arc::new(Int64Array::from(vec![Some(4), Some(6)]));
1384
1385        let schema = Schema::new(vec![
1386            Field::new(
1387                "a",
1388                DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
1389                false,
1390            ),
1391            Field::new("month", DataType::Int64, true),
1392        ]);
1393
1394        let normalized =
1395            RecordBatch::try_new(Arc::new(schema.clone()), vec![a.clone(), month.clone()])
1396                .expect("valid conversion")
1397                .normalize(".", Some(0))
1398                .expect("valid normalization");
1399
1400        let expected = RecordBatch::try_from_iter_with_nullable(vec![
1401            ("a.animals", animals.clone(), true),
1402            ("a.n_legs", n_legs.clone(), true),
1403            ("a.year", year.clone(), true),
1404            ("month", month.clone(), true),
1405        ])
1406        .expect("valid conversion");
1407
1408        assert_eq!(expected, normalized);
1409
1410        // check 0 and None have the same effect
1411        let normalized = RecordBatch::try_new(Arc::new(schema), vec![a, month.clone()])
1412            .expect("valid conversion")
1413            .normalize(".", None)
1414            .expect("valid normalization");
1415
1416        assert_eq!(expected, normalized);
1417    }
1418
1419    #[test]
1420    fn normalize_nested() {
1421        // Initialize schema
1422        let a = Arc::new(Field::new("a", DataType::Int64, true));
1423        let b = Arc::new(Field::new("b", DataType::Int64, false));
1424        let c = Arc::new(Field::new("c", DataType::Int64, true));
1425
1426        let one = Arc::new(Field::new(
1427            "1",
1428            DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
1429            false,
1430        ));
1431        let two = Arc::new(Field::new(
1432            "2",
1433            DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
1434            true,
1435        ));
1436
1437        let exclamation = Arc::new(Field::new(
1438            "!",
1439            DataType::Struct(Fields::from(vec![one.clone(), two.clone()])),
1440            false,
1441        ));
1442
1443        let schema = Schema::new(vec![exclamation.clone()]);
1444
1445        // Initialize fields
1446        let a_field = Int64Array::from(vec![Some(0), Some(1)]);
1447        let b_field = Int64Array::from(vec![Some(2), Some(3)]);
1448        let c_field = Int64Array::from(vec![None, Some(4)]);
1449
1450        let one_field = StructArray::from(vec![
1451            (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1452            (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1453            (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1454        ]);
1455        let two_field = StructArray::from(vec![
1456            (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1457            (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1458            (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1459        ]);
1460
1461        let exclamation_field = Arc::new(StructArray::from(vec![
1462            (one.clone(), Arc::new(one_field) as ArrayRef),
1463            (two.clone(), Arc::new(two_field) as ArrayRef),
1464        ]));
1465
1466        // Normalize top level
1467        let normalized =
1468            RecordBatch::try_new(Arc::new(schema.clone()), vec![exclamation_field.clone()])
1469                .expect("valid conversion")
1470                .normalize(".", Some(1))
1471                .expect("valid normalization");
1472
1473        let expected = RecordBatch::try_from_iter_with_nullable(vec![
1474            (
1475                "!.1",
1476                Arc::new(StructArray::from(vec![
1477                    (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1478                    (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1479                    (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1480                ])) as ArrayRef,
1481                false,
1482            ),
1483            (
1484                "!.2",
1485                Arc::new(StructArray::from(vec![
1486                    (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1487                    (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1488                    (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1489                ])) as ArrayRef,
1490                true,
1491            ),
1492        ])
1493        .expect("valid conversion");
1494
1495        assert_eq!(expected, normalized);
1496
1497        // Normalize all levels
1498        let normalized = RecordBatch::try_new(Arc::new(schema), vec![exclamation_field])
1499            .expect("valid conversion")
1500            .normalize(".", None)
1501            .expect("valid normalization");
1502
1503        let expected = RecordBatch::try_from_iter_with_nullable(vec![
1504            ("!.1.a", Arc::new(a_field.clone()) as ArrayRef, true),
1505            ("!.1.b", Arc::new(b_field.clone()) as ArrayRef, false),
1506            ("!.1.c", Arc::new(c_field.clone()) as ArrayRef, true),
1507            ("!.2.a", Arc::new(a_field.clone()) as ArrayRef, true),
1508            ("!.2.b", Arc::new(b_field.clone()) as ArrayRef, false),
1509            ("!.2.c", Arc::new(c_field.clone()) as ArrayRef, true),
1510        ])
1511        .expect("valid conversion");
1512
1513        assert_eq!(expected, normalized);
1514    }
1515
1516    #[test]
1517    fn normalize_empty() {
1518        let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
1519        let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
1520        let year_field = Arc::new(Field::new("year", DataType::Int64, true));
1521
1522        let schema = Schema::new(vec![
1523            Field::new(
1524                "a",
1525                DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
1526                false,
1527            ),
1528            Field::new("month", DataType::Int64, true),
1529        ]);
1530
1531        let normalized = RecordBatch::new_empty(Arc::new(schema.clone()))
1532            .normalize(".", Some(0))
1533            .expect("valid normalization");
1534
1535        let expected = RecordBatch::new_empty(Arc::new(
1536            schema.normalize(".", Some(0)).expect("valid normalization"),
1537        ));
1538
1539        assert_eq!(expected, normalized);
1540    }
1541
1542    #[test]
1543    fn project() {
1544        let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
1545        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
1546        let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));
1547
1548        let record_batch =
1549            RecordBatch::try_from_iter(vec![("a", a.clone()), ("b", b.clone()), ("c", c.clone())])
1550                .expect("valid conversion");
1551
1552        let expected =
1553            RecordBatch::try_from_iter(vec![("a", a), ("c", c)]).expect("valid conversion");
1554
1555        assert_eq!(expected, record_batch.project(&[0, 2]).unwrap());
1556    }
1557
1558    #[test]
1559    fn project_empty() {
1560        let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));
1561
1562        let record_batch =
1563            RecordBatch::try_from_iter(vec![("c", c.clone())]).expect("valid conversion");
1564
1565        let expected = RecordBatch::try_new_with_options(
1566            Arc::new(Schema::empty()),
1567            vec![],
1568            &RecordBatchOptions {
1569                match_field_names: true,
1570                row_count: Some(3),
1571            },
1572        )
1573        .expect("valid conversion");
1574
1575        assert_eq!(expected, record_batch.project(&[]).unwrap());
1576    }
1577
1578    #[test]
1579    fn test_no_column_record_batch() {
1580        let schema = Arc::new(Schema::empty());
1581
1582        let err = RecordBatch::try_new(schema.clone(), vec![]).unwrap_err();
1583        assert!(
1584            err.to_string()
1585                .contains("must either specify a row count or at least one column")
1586        );
1587
1588        let options = RecordBatchOptions::new().with_row_count(Some(10));
1589
1590        let ok = RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap();
1591        assert_eq!(ok.num_rows(), 10);
1592
1593        let a = ok.slice(2, 5);
1594        assert_eq!(a.num_rows(), 5);
1595
1596        let b = ok.slice(5, 0);
1597        assert_eq!(b.num_rows(), 0);
1598
1599        assert_ne!(a, b);
1600        assert_eq!(b, RecordBatch::new_empty(schema))
1601    }
1602
1603    #[test]
1604    fn test_nulls_in_non_nullable_field() {
1605        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
1606        let maybe_batch = RecordBatch::try_new(
1607            schema,
1608            vec![Arc::new(Int32Array::from(vec![Some(1), None]))],
1609        );
1610        assert_eq!(
1611            "Invalid argument error: Column 'a' is declared as non-nullable but contains null values",
1612            format!("{}", maybe_batch.err().unwrap())
1613        );
1614    }
1615    #[test]
1616    fn test_record_batch_options() {
1617        let options = RecordBatchOptions::new()
1618            .with_match_field_names(false)
1619            .with_row_count(Some(20));
1620        assert!(!options.match_field_names);
1621        assert_eq!(options.row_count.unwrap(), 20)
1622    }
1623
1624    #[test]
1625    #[should_panic(expected = "Cannot convert nullable StructArray to RecordBatch")]
1626    fn test_from_struct() {
1627        let s = StructArray::from(ArrayData::new_null(
1628            // Note child is not nullable
1629            &DataType::Struct(vec![Field::new("foo", DataType::Int32, false)].into()),
1630            2,
1631        ));
1632        let _ = RecordBatch::from(s);
1633    }
1634
1635    #[test]
1636    fn test_with_schema() {
1637        let required_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1638        let required_schema = Arc::new(required_schema);
1639        let nullable_schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
1640        let nullable_schema = Arc::new(nullable_schema);
1641
1642        let batch = RecordBatch::try_new(
1643            required_schema.clone(),
1644            vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as _],
1645        )
1646        .unwrap();
1647
1648        // Can add nullability
1649        let batch = batch.with_schema(nullable_schema.clone()).unwrap();
1650
1651        // Cannot remove nullability
1652        batch.clone().with_schema(required_schema).unwrap_err();
1653
1654        // Can add metadata
1655        let metadata = vec![("foo".to_string(), "bar".to_string())]
1656            .into_iter()
1657            .collect();
1658        let metadata_schema = nullable_schema.as_ref().clone().with_metadata(metadata);
1659        let batch = batch.with_schema(Arc::new(metadata_schema)).unwrap();
1660
1661        // Cannot remove metadata
1662        batch.with_schema(nullable_schema).unwrap_err();
1663    }
1664
1665    #[test]
1666    fn test_boxed_reader() {
1667        // Make sure we can pass a boxed reader to a function generic over
1668        // RecordBatchReader.
1669        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1670        let schema = Arc::new(schema);
1671
1672        let reader = RecordBatchIterator::new(std::iter::empty(), schema);
1673        let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
1674
1675        fn get_size(reader: impl RecordBatchReader) -> usize {
1676            reader.size_hint().0
1677        }
1678
1679        let size = get_size(reader);
1680        assert_eq!(size, 0);
1681    }
1682
1683    #[test]
1684    fn test_remove_column_maintains_schema_metadata() {
1685        let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
1686        let bool_array = BooleanArray::from(vec![true, false, false, true, true]);
1687
1688        let mut metadata = HashMap::new();
1689        metadata.insert("foo".to_string(), "bar".to_string());
1690        let schema = Schema::new(vec![
1691            Field::new("id", DataType::Int32, false),
1692            Field::new("bool", DataType::Boolean, false),
1693        ])
1694        .with_metadata(metadata);
1695
1696        let mut batch = RecordBatch::try_new(
1697            Arc::new(schema),
1698            vec![Arc::new(id_array), Arc::new(bool_array)],
1699        )
1700        .unwrap();
1701
1702        let _removed_column = batch.remove_column(0);
1703        assert_eq!(batch.schema().metadata().len(), 1);
1704        assert_eq!(
1705            batch.schema().metadata().get("foo").unwrap().as_str(),
1706            "bar"
1707        );
1708    }
1709}