// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Pass Arrow objects from and to PyArrow, using Arrow's
//! [C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html)
//! and [pyo3](https://docs.rs/pyo3/latest/pyo3/).
//!
//! For underlying implementation, see the [ffi] module.
//!
//! One can use these to write Python functions that take and return PyArrow
//! objects, with automatic conversion to corresponding arrow-rs types.
//!
//! ```ignore
//! #[pyfunction]
//! fn double_array(array: PyArrowType<ArrayData>) -> PyResult<PyArrowType<ArrayData>> {
//!     let array = array.0; // Extract from PyArrowType wrapper
//!     let array: Arc<dyn Array> = make_array(array); // Convert ArrayData to ArrayRef
//!     let array: &Int32Array = array.as_any().downcast_ref()
//!         .ok_or_else(|| PyValueError::new_err("expected int32 array"))?;
//!     let array: Int32Array = array.iter().map(|x| x.map(|x| x * 2)).collect();
//!     Ok(PyArrowType(array.into_data()))
//! }
//! ```
//!
//! | pyarrow type                | arrow-rs type                                                      |
//! |-----------------------------|--------------------------------------------------------------------|
//! | `pyarrow.DataType`          | [DataType]                                                         |
//! | `pyarrow.Field`             | [Field]                                                            |
//! | `pyarrow.Schema`            | [Schema]                                                           |
//! | `pyarrow.Array`             | [ArrayData]                                                        |
//! | `pyarrow.RecordBatch`       | [RecordBatch]                                                      |
//! | `pyarrow.RecordBatchReader` | [ArrowArrayStreamReader] / `Box<dyn RecordBatchReader + Send>` (1) |
//! | `pyarrow.Table`             | [Table] (2)                                                        |
//!
//! (1) `pyarrow.RecordBatchReader` can be imported as [ArrowArrayStreamReader]. Either
//! [ArrowArrayStreamReader] or `Box<dyn RecordBatchReader + Send>` can be exported
//! as `pyarrow.RecordBatchReader`. (`Box<dyn RecordBatchReader + Send>` is typically
//! easier to create.)
//!
//! (2) Although arrow-rs offers [Table], a convenience wrapper for [pyarrow.Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table)
//! that internally holds `Vec<RecordBatch>`, it is meant primarily for use cases where you already
//! have `Vec<RecordBatch>` on the Rust side and want to export that in bulk as a `pyarrow.Table`.
//! In general, it is recommended to use streaming approaches instead of dealing with data in bulk.
//! For example, a `pyarrow.Table` (or any other object that implements the ArrayStream PyCapsule
//! interface) can be imported to Rust through `PyArrowType<ArrowArrayStreamReader>` instead of
//! forcing eager reading into `Vec<RecordBatch>`.

use std::convert::{From, TryFrom};
use std::ffi::CStr;
use std::ptr::{addr_of, addr_of_mut};
use std::sync::Arc;

use arrow_array::ffi;
use arrow_array::ffi::{FFI_ArrowArray, FFI_ArrowSchema};
use arrow_array::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream};
use arrow_array::{
    RecordBatch, RecordBatchIterator, RecordBatchOptions, RecordBatchReader, StructArray,
    make_array,
};
use arrow_data::ArrayData;
use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::ffi::Py_uintptr_t;
use pyo3::import_exception;
use pyo3::prelude::*;
use pyo3::sync::PyOnceLock;
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};

import_exception!(pyarrow, ArrowException);
/// Represents an exception raised by PyArrow.
pub type PyArrowException = ArrowException;

// Capsule names mandated by the Arrow PyCapsule interface; see
// https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
const ARROW_ARRAY_STREAM_CAPSULE_NAME: &CStr = c"arrow_array_stream";
const ARROW_SCHEMA_CAPSULE_NAME: &CStr = c"arrow_schema";
const ARROW_ARRAY_CAPSULE_NAME: &CStr = c"arrow_array";

91fn to_py_err(err: ArrowError) -> PyErr {
92    PyArrowException::new_err(err.to_string())
93}
94
/// Trait for converting Python objects to arrow-rs types.
pub trait FromPyArrow: Sized {
    /// Convert a Python object to an arrow-rs type.
    ///
    /// Takes a GIL-bound value from Python and returns a result with the arrow-rs type.
    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self>;
}

/// Create a new PyArrow object from a arrow-rs type.
pub trait ToPyArrow {
    /// Convert the implemented type into a Python object without consuming it.
    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>>;
}

/// Convert an arrow-rs type into a PyArrow object.
pub trait IntoPyArrow {
    /// Convert the implemented type into a Python object while consuming it.
    fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>>;
}

// Blanket impl: any type with a non-consuming conversion gets the consuming one
// for free by delegating to `to_pyarrow`.
impl<T: ToPyArrow> IntoPyArrow for T {
    fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        self.to_pyarrow(py)
    }
}

121fn validate_class(expected: &Bound<PyType>, value: &Bound<PyAny>) -> PyResult<()> {
122    if !value.is_instance(expected)? {
123        let expected_module = expected.getattr("__module__")?;
124        let expected_name = expected.getattr("__name__")?;
125        let found_class = value.get_type();
126        let found_module = found_class.getattr("__module__")?;
127        let found_name = found_class.getattr("__name__")?;
128        return Err(PyTypeError::new_err(format!(
129            "Expected instance of {expected_module}.{expected_name}, got {found_module}.{found_name}",
130        )));
131    }
132    Ok(())
133}
134
135fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
136    let capsule_name = capsule.name()?;
137    if capsule_name.is_none() {
138        return Err(PyValueError::new_err(
139            "Expected schema PyCapsule to have name set.",
140        ));
141    }
142
143    let capsule_name = unsafe { capsule_name.unwrap().as_cstr().to_str()? };
144    if capsule_name != name {
145        return Err(PyValueError::new_err(format!(
146            "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'",
147        )));
148    }
149
150    Ok(())
151}
152
impl FromPyArrow for DataType {
    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
        // Newer versions of PyArrow as well as other libraries with Arrow data implement this
        // method, so prefer it over _export_to_c.
        // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
        if value.hasattr("__arrow_c_schema__")? {
            let capsule = value.getattr("__arrow_c_schema__")?.call0()?;
            let capsule = capsule.cast::<PyCapsule>()?;
            validate_pycapsule(capsule, "arrow_schema")?;

            let schema_ptr = capsule
                .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
                .cast::<FFI_ArrowSchema>();
            // SAFETY: the capsule name was validated above, so the pointer refers to
            // an FFI_ArrowSchema kept alive by the capsule for the duration of this call.
            unsafe {
                let dtype = DataType::try_from(schema_ptr.as_ref()).map_err(to_py_err)?;
                return Ok(dtype);
            }
        }

        // Fall back to PyArrow's private `_export_to_c` API for older versions.
        validate_class(data_type_class(value.py())?, value)?;

        let c_schema = FFI_ArrowSchema::empty();
        let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
        // PyArrow fills `c_schema` in place through the raw address.
        value.call_method1("_export_to_c", (c_schema_ptr as Py_uintptr_t,))?;
        let dtype = DataType::try_from(&c_schema).map_err(to_py_err)?;
        Ok(dtype)
    }
}

182impl ToPyArrow for DataType {
183    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
184        let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?;
185        let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
186        let dtype =
187            data_type_class(py)?.call_method1("_import_from_c", (c_schema_ptr as Py_uintptr_t,))?;
188        Ok(dtype)
189    }
190}
191
impl FromPyArrow for Field {
    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
        // Newer versions of PyArrow as well as other libraries with Arrow data implement this
        // method, so prefer it over _export_to_c.
        // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
        if value.hasattr("__arrow_c_schema__")? {
            let capsule = value.getattr("__arrow_c_schema__")?.call0()?;
            let capsule = capsule.cast::<PyCapsule>()?;
            validate_pycapsule(capsule, "arrow_schema")?;

            let schema_ptr = capsule
                .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
                .cast::<FFI_ArrowSchema>();
            // SAFETY: the capsule name was validated above, so the pointer refers to
            // an FFI_ArrowSchema kept alive by the capsule for the duration of this call.
            unsafe {
                let field = Field::try_from(schema_ptr.as_ref()).map_err(to_py_err)?;
                return Ok(field);
            }
        }

        // Fall back to PyArrow's private `_export_to_c` API for older versions.
        validate_class(field_class(value.py())?, value)?;

        let c_schema = FFI_ArrowSchema::empty();
        let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
        // PyArrow fills `c_schema` in place through the raw address.
        value.call_method1("_export_to_c", (c_schema_ptr as Py_uintptr_t,))?;
        let field = Field::try_from(&c_schema).map_err(to_py_err)?;
        Ok(field)
    }
}

221impl ToPyArrow for Field {
222    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
223        let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?;
224        let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
225        let dtype =
226            field_class(py)?.call_method1("_import_from_c", (c_schema_ptr as Py_uintptr_t,))?;
227        Ok(dtype)
228    }
229}
230
impl FromPyArrow for Schema {
    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
        // Newer versions of PyArrow as well as other libraries with Arrow data implement this
        // method, so prefer it over _export_to_c.
        // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
        if value.hasattr("__arrow_c_schema__")? {
            let capsule = value.getattr("__arrow_c_schema__")?.call0()?;
            let capsule = capsule.cast::<PyCapsule>()?;
            validate_pycapsule(capsule, "arrow_schema")?;

            let schema_ptr = capsule
                .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
                .cast::<FFI_ArrowSchema>();
            // SAFETY: the capsule name was validated above, so the pointer refers to
            // an FFI_ArrowSchema kept alive by the capsule for the duration of this call.
            unsafe {
                let schema = Schema::try_from(schema_ptr.as_ref()).map_err(to_py_err)?;
                return Ok(schema);
            }
        }

        // Fall back to PyArrow's private `_export_to_c` API for older versions.
        validate_class(schema_class(value.py())?, value)?;

        let c_schema = FFI_ArrowSchema::empty();
        let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
        // PyArrow fills `c_schema` in place through the raw address.
        value.call_method1("_export_to_c", (c_schema_ptr as Py_uintptr_t,))?;
        let schema = Schema::try_from(&c_schema).map_err(to_py_err)?;
        Ok(schema)
    }
}

260impl ToPyArrow for Schema {
261    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
262        let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?;
263        let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
264        let schema =
265            schema_class(py)?.call_method1("_import_from_c", (c_schema_ptr as Py_uintptr_t,))?;
266        Ok(schema)
267    }
268}
269
impl FromPyArrow for ArrayData {
    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
        // Newer versions of PyArrow as well as other libraries with Arrow data implement this
        // method, so prefer it over _export_to_c.
        // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
        if value.hasattr("__arrow_c_array__")? {
            let tuple = value.getattr("__arrow_c_array__")?.call0()?;

            if !tuple.is_instance_of::<PyTuple>() {
                return Err(PyTypeError::new_err(
                    "Expected __arrow_c_array__ to return a tuple.",
                ));
            }

            // The protocol returns a (schema_capsule, array_capsule) pair.
            let schema_capsule = tuple.get_item(0)?;
            let schema_capsule = schema_capsule.cast::<PyCapsule>()?;
            let array_capsule = tuple.get_item(1)?;
            let array_capsule = array_capsule.cast::<PyCapsule>()?;

            validate_pycapsule(schema_capsule, "arrow_schema")?;
            validate_pycapsule(array_capsule, "arrow_array")?;

            let schema_ptr = schema_capsule
                .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
                .cast::<FFI_ArrowSchema>();
            // SAFETY: capsule names were validated above. `from_raw` takes ownership
            // of the exported array, so its release callback runs when the resulting
            // value is eventually dropped.
            let array = unsafe {
                FFI_ArrowArray::from_raw(
                    array_capsule
                        .pointer_checked(Some(ARROW_ARRAY_CAPSULE_NAME))?
                        .cast::<FFI_ArrowArray>()
                        .as_ptr(),
                )
            };
            // SAFETY: the schema pointer remains valid while the capsule is alive.
            return unsafe { ffi::from_ffi(array, schema_ptr.as_ref()) }.map_err(to_py_err);
        }

        validate_class(array_class(value.py())?, value)?;

        // prepare a pointer to receive the Array struct
        let mut array = FFI_ArrowArray::empty();
        let mut schema = FFI_ArrowSchema::empty();

        // make the conversion through PyArrow's private API
        // this changes the pointer's memory and is thus unsafe.
        // In particular, `_export_to_c` can go out of bounds
        value.call_method1(
            "_export_to_c",
            (
                addr_of_mut!(array) as Py_uintptr_t,
                addr_of_mut!(schema) as Py_uintptr_t,
            ),
        )?;

        // SAFETY: `array` and `schema` were just populated by PyArrow's export above.
        unsafe { ffi::from_ffi(array, &schema) }.map_err(to_py_err)
    }
}

impl ToPyArrow for ArrayData {
    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        // Export self through the C data interface and let PyArrow import it.
        let array = FFI_ArrowArray::new(self);
        let schema = FFI_ArrowSchema::try_from(self.data_type()).map_err(to_py_err)?;

        // NOTE(review): `_import_from_c` presumably assumes ownership of the exported
        // structs via their release callbacks — confirm against PyArrow's C API docs.
        let array = array_class(py)?.call_method1(
            "_import_from_c",
            (
                addr_of!(array) as Py_uintptr_t,
                addr_of!(schema) as Py_uintptr_t,
            ),
        )?;
        Ok(array)
    }
}

343impl<T: FromPyArrow> FromPyArrow for Vec<T> {
344    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
345        let list = value.cast::<PyList>()?;
346        list.iter().map(|x| T::from_pyarrow_bound(&x)).collect()
347    }
348}
349
350impl<T: ToPyArrow> ToPyArrow for Vec<T> {
351    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
352        let values = self
353            .iter()
354            .map(|v| v.to_pyarrow(py))
355            .collect::<PyResult<Vec<_>>>()?;
356        Ok(PyList::new(py, values)?.into_any())
357    }
358}
359
impl FromPyArrow for RecordBatch {
    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
        // Newer versions of PyArrow as well as other libraries with Arrow data implement this
        // method, so prefer it over _export_to_c.
        // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html

        if value.hasattr("__arrow_c_array__")? {
            let tuple = value.getattr("__arrow_c_array__")?.call0()?;

            if !tuple.is_instance_of::<PyTuple>() {
                return Err(PyTypeError::new_err(
                    "Expected __arrow_c_array__ to return a tuple.",
                ));
            }

            // The protocol returns a (schema_capsule, array_capsule) pair.
            let schema_capsule = tuple.get_item(0)?;
            let schema_capsule = schema_capsule.cast::<PyCapsule>()?;
            let array_capsule = tuple.get_item(1)?;
            let array_capsule = array_capsule.cast::<PyCapsule>()?;

            validate_pycapsule(schema_capsule, "arrow_schema")?;
            validate_pycapsule(array_capsule, "arrow_array")?;

            let schema_ptr = schema_capsule
                .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
                .cast::<FFI_ArrowSchema>();
            let array_ptr = array_capsule
                .pointer_checked(Some(ARROW_ARRAY_CAPSULE_NAME))?
                .cast::<FFI_ArrowArray>();
            // SAFETY: capsule names were validated above. `from_raw` takes ownership
            // of the exported array so its release callback runs on drop.
            let ffi_array = unsafe { FFI_ArrowArray::from_raw(array_ptr.as_ptr()) };
            // SAFETY: the schema pointer remains valid while the capsule is alive.
            let mut array_data =
                unsafe { ffi::from_ffi(ffi_array, schema_ptr.as_ref()) }.map_err(to_py_err)?;
            // A record batch crosses the C interface as a struct array whose
            // children are the batch's columns.
            if !matches!(array_data.data_type(), DataType::Struct(_)) {
                return Err(PyTypeError::new_err(
                    "Expected Struct type from __arrow_c_array.",
                ));
            }
            // Carry the row count explicitly so zero-column batches keep their length.
            let options = RecordBatchOptions::default().with_row_count(Some(array_data.len()));
            // Ensure data is aligned (by potentially copying the buffers).
            // This is needed because some python code (for example the
            // python flight client) produces unaligned buffers
            // See https://github.com/apache/arrow/issues/43552 for details
            array_data.align_buffers();
            let array = StructArray::from(array_data);
            // StructArray does not embed metadata from schema. We need to override
            // the output schema with the schema from the capsule.
            // SAFETY: the schema pointer is still valid here; the capsule has not
            // been dropped yet.
            let schema =
                unsafe { Arc::new(Schema::try_from(schema_ptr.as_ref()).map_err(to_py_err)?) };
            let (_fields, columns, nulls) = array.into_parts();
            assert_eq!(
                nulls.map(|n| n.null_count()).unwrap_or_default(),
                0,
                "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"
            );
            return RecordBatch::try_new_with_options(schema, columns, &options).map_err(to_py_err);
        }

        validate_class(record_batch_class(value.py())?, value)?;
        // TODO(kszucs): implement the FFI conversions in arrow-rs for RecordBatches
        let schema = value.getattr("schema")?;
        let schema = Arc::new(Schema::from_pyarrow_bound(&schema)?);

        // Convert each column individually through the ArrayData path.
        let arrays = value.getattr("columns")?;
        let arrays = arrays
            .cast::<PyList>()?
            .iter()
            .map(|a| Ok(make_array(ArrayData::from_pyarrow_bound(&a)?)))
            .collect::<PyResult<_>>()?;

        // `num_rows` is best-effort; `None` lets arrow-rs infer it from the columns.
        let row_count = value
            .getattr("num_rows")
            .ok()
            .and_then(|x| x.extract().ok());
        let options = RecordBatchOptions::default().with_row_count(row_count);

        let batch =
            RecordBatch::try_new_with_options(schema, arrays, &options).map_err(to_py_err)?;
        Ok(batch)
    }
}

441impl ToPyArrow for RecordBatch {
442    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
443        // Workaround apache/arrow#37669 by returning RecordBatchIterator
444        let reader = RecordBatchIterator::new(vec![Ok(self.clone())], self.schema());
445        let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
446        let py_reader = reader.into_pyarrow(py)?;
447        py_reader.call_method0("read_next_batch")
448    }
449}
450
/// Supports conversion from `pyarrow.RecordBatchReader` to [ArrowArrayStreamReader].
impl FromPyArrow for ArrowArrayStreamReader {
    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
        // Newer versions of PyArrow as well as other libraries with Arrow data implement this
        // method, so prefer it over _export_to_c.
        // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
        if value.hasattr("__arrow_c_stream__")? {
            let capsule = value.getattr("__arrow_c_stream__")?.call0()?;
            let capsule = capsule.cast::<PyCapsule>()?;
            validate_pycapsule(capsule, "arrow_array_stream")?;

            // SAFETY: the capsule name was validated above. `from_raw` takes ownership
            // of the exported stream, whose release callback runs when the reader is
            // eventually dropped.
            let stream = unsafe {
                FFI_ArrowArrayStream::from_raw(
                    capsule
                        .pointer_checked(Some(ARROW_ARRAY_STREAM_CAPSULE_NAME))?
                        .cast::<FFI_ArrowArrayStream>()
                        .as_ptr(),
                )
            };

            let stream_reader = ArrowArrayStreamReader::try_new(stream)
                .map_err(|err| PyValueError::new_err(err.to_string()))?;

            return Ok(stream_reader);
        }

        // Fall back to PyArrow's private `_export_to_c` API for older versions.
        validate_class(record_batch_reader_class(value.py())?, value)?;

        // prepare a pointer to receive the stream struct
        let mut stream = FFI_ArrowArrayStream::empty();
        let stream_ptr = &mut stream as *mut FFI_ArrowArrayStream;

        // make the conversion through PyArrow's private API
        // this changes the pointer's memory and is thus unsafe.
        // In particular, `_export_to_c` can go out of bounds
        let args = PyTuple::new(value.py(), [stream_ptr as Py_uintptr_t])?;
        value.call_method1("_export_to_c", args)?;

        let stream_reader = ArrowArrayStreamReader::try_new(stream)
            .map_err(|err| PyValueError::new_err(err.to_string()))?;

        Ok(stream_reader)
    }
}

496/// Convert a [`RecordBatchReader`] into a `pyarrow.RecordBatchReader`.
497impl IntoPyArrow for Box<dyn RecordBatchReader + Send> {
498    // We can't implement `ToPyArrow` for `T: RecordBatchReader + Send` because
499    // there is already a blanket implementation for `T: ToPyArrow`.
500    fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
501        let mut stream = FFI_ArrowArrayStream::new(self);
502
503        let stream_ptr = (&mut stream) as *mut FFI_ArrowArrayStream;
504        let reader = record_batch_reader_class(py)?
505            .call_method1("_import_from_c", (stream_ptr as Py_uintptr_t,))?;
506
507        Ok(reader)
508    }
509}
510
511/// Convert a [`ArrowArrayStreamReader`] into a `pyarrow.RecordBatchReader`.
512impl IntoPyArrow for ArrowArrayStreamReader {
513    fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
514        let boxed: Box<dyn RecordBatchReader + Send> = Box::new(self);
515        boxed.into_pyarrow(py)
516    }
517}
518
/// This is a convenience wrapper around `Vec<RecordBatch>` that tries to simplify conversion from
/// and to `pyarrow.Table`.
///
/// This could be used in circumstances where you either want to consume a `pyarrow.Table` directly
/// (although technically, since `pyarrow.Table` implements the ArrayStreamReader PyCapsule
/// interface, one could also consume a `PyArrowType<ArrowArrayStreamReader>` instead) or, more
/// importantly, where one wants to export a `pyarrow.Table` from a `Vec<RecordBatch>` from the Rust
/// side.
///
/// ```ignore
/// #[pyfunction]
/// fn return_table(...) -> PyResult<PyArrowType<Table>> {
///     let batches: Vec<RecordBatch>;
///     let schema: SchemaRef;
///     PyArrowType(Table::try_new(batches, schema).map_err(|err| err.into_py_err(py))?)
/// }
/// ```
#[derive(Clone)]
pub struct Table {
    // Invariant (enforced by `try_new`): every batch's schema equals `schema`.
    record_batches: Vec<RecordBatch>,
    schema: SchemaRef,
}

542impl Table {
543    pub fn try_new(
544        record_batches: Vec<RecordBatch>,
545        schema: SchemaRef,
546    ) -> Result<Self, ArrowError> {
547        for record_batch in &record_batches {
548            if schema != record_batch.schema() {
549                return Err(ArrowError::SchemaError(format!(
550                    "All record batches must have the same schema. \
551                         Expected schema: {:?}, got schema: {:?}",
552                    schema,
553                    record_batch.schema()
554                )));
555            }
556        }
557        Ok(Self {
558            record_batches,
559            schema,
560        })
561    }
562
563    pub fn record_batches(&self) -> &[RecordBatch] {
564        &self.record_batches
565    }
566
567    pub fn schema(&self) -> SchemaRef {
568        self.schema.clone()
569    }
570
571    pub fn into_inner(self) -> (Vec<RecordBatch>, SchemaRef) {
572        (self.record_batches, self.schema)
573    }
574}
575
576impl TryFrom<Box<dyn RecordBatchReader>> for Table {
577    type Error = ArrowError;
578
579    fn try_from(value: Box<dyn RecordBatchReader>) -> Result<Self, ArrowError> {
580        let schema = value.schema();
581        let batches = value.collect::<Result<Vec<_>, _>>()?;
582        Self::try_new(batches, schema)
583    }
584}
585
586/// Convert a `pyarrow.Table` (or any other ArrowArrayStream compliant object) into [`Table`]
587impl FromPyArrow for Table {
588    fn from_pyarrow_bound(ob: &Bound<PyAny>) -> PyResult<Self> {
589        let reader: Box<dyn RecordBatchReader> =
590            Box::new(ArrowArrayStreamReader::from_pyarrow_bound(ob)?);
591        Self::try_from(reader).map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))
592    }
593}
594
595/// Convert a [`Table`] into `pyarrow.Table`.
596impl IntoPyArrow for Table {
597    fn into_pyarrow(self, py: Python) -> PyResult<Bound<PyAny>> {
598        let py_batches = PyList::new(py, self.record_batches.into_iter().map(PyArrowType))?;
599        let py_schema = PyArrowType(Arc::unwrap_or_clone(self.schema));
600
601        let kwargs = PyDict::new(py);
602        kwargs.set_item("schema", py_schema)?;
603
604        let reader = table_class(py)?.call_method("from_batches", (py_batches,), Some(&kwargs))?;
605
606        Ok(reader)
607    }
608}
609
/// Lazily-imported, process-wide cache of the `pyarrow.Array` class.
fn array_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
    static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
    TYPE.import(py, "pyarrow", "Array")
}

/// Lazily-imported, process-wide cache of the `pyarrow.RecordBatch` class.
fn record_batch_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
    static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
    TYPE.import(py, "pyarrow", "RecordBatch")
}

/// Lazily-imported, process-wide cache of the `pyarrow.RecordBatchReader` class.
fn record_batch_reader_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
    static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
    TYPE.import(py, "pyarrow", "RecordBatchReader")
}
/// Lazily-imported, process-wide cache of the `pyarrow.DataType` class.
fn data_type_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
    static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
    TYPE.import(py, "pyarrow", "DataType")
}

/// Lazily-imported, process-wide cache of the `pyarrow.Field` class.
fn field_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
    static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
    TYPE.import(py, "pyarrow", "Field")
}

/// Lazily-imported, process-wide cache of the `pyarrow.Schema` class.
fn schema_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
    static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
    TYPE.import(py, "pyarrow", "Schema")
}

/// Lazily-imported, process-wide cache of the `pyarrow.Table` class.
fn table_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
    static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
    TYPE.import(py, "pyarrow", "Table")
}

/// A newtype wrapper for types implementing [`FromPyArrow`] or [`IntoPyArrow`].
///
/// When wrapped around a type `T: FromPyArrow`, it
/// implements [`FromPyObject`] for the PyArrow objects. When wrapped around a
/// `T: IntoPyArrow`, it implements `IntoPy<PyObject>` for the wrapped type.
#[derive(Debug)]
pub struct PyArrowType<T>(pub T);

652impl<T: FromPyArrow> FromPyObject<'_, '_> for PyArrowType<T> {
653    type Error = PyErr;
654
655    fn extract(value: Borrowed<'_, '_, PyAny>) -> PyResult<Self> {
656        Ok(Self(T::from_pyarrow_bound(&value)?))
657    }
658}
659
// Allow `PyArrowType<T>` to appear directly in #[pyfunction] return position.
impl<'py, T: IntoPyArrow> IntoPyObject<'py> for PyArrowType<T> {
    type Target = PyAny;

    type Output = Bound<'py, Self::Target>;

    type Error = PyErr;

    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, PyErr> {
        // Delegate to the wrapped value's PyArrow conversion.
        self.0.into_pyarrow(py)
    }
}

// Convenience: wrap any value in the newtype with `.into()`.
impl<T> From<T> for PyArrowType<T> {
    fn from(s: T) -> Self {
        Self(s)
    }
}