arrow_pyarrow/lib.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Pass Arrow objects from and to PyArrow, using Arrow's
//! [C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html)
//! and [pyo3](https://docs.rs/pyo3/latest/pyo3/).
//!
//! For the underlying implementation, see the [ffi] module.
//!
//! One can use these to write Python functions that take and return PyArrow
//! objects, with automatic conversion to the corresponding arrow-rs types.
//!
//! ```ignore
//! #[pyfunction]
//! fn double_array(array: PyArrowType<ArrayData>) -> PyResult<PyArrowType<ArrayData>> {
//!     let array = array.0; // Extract from PyArrowType wrapper
//!     let array: Arc<dyn Array> = make_array(array); // Convert ArrayData to ArrayRef
//!     let array: &Int32Array = array.as_any().downcast_ref()
//!         .ok_or_else(|| PyValueError::new_err("expected int32 array"))?;
//!     let array: Int32Array = array.iter().map(|x| x.map(|x| x * 2)).collect();
//!     Ok(PyArrowType(array.into_data()))
//! }
//! ```
//!
//! | pyarrow type                | arrow-rs type                                                      |
//! |-----------------------------|--------------------------------------------------------------------|
//! | `pyarrow.DataType`          | [DataType]                                                         |
//! | `pyarrow.Field`             | [Field]                                                            |
//! | `pyarrow.Schema`            | [Schema]                                                           |
//! | `pyarrow.Array`             | [ArrayData]                                                        |
//! | `pyarrow.RecordBatch`       | [RecordBatch]                                                      |
//! | `pyarrow.RecordBatchReader` | [ArrowArrayStreamReader] / `Box<dyn RecordBatchReader + Send>` (1) |
//! | `pyarrow.Table`             | [Table] (2)                                                        |
//!
//! (1) `pyarrow.RecordBatchReader` can be imported as [ArrowArrayStreamReader]. Either
//! [ArrowArrayStreamReader] or `Box<dyn RecordBatchReader + Send>` can be exported
//! as `pyarrow.RecordBatchReader`. (`Box<dyn RecordBatchReader + Send>` is typically
//! easier to create.)
//!
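//! For example, a minimal sketch of exporting a single [RecordBatch] as a
//! `pyarrow.RecordBatchReader` (the function name is illustrative):
//!
//! ```ignore
//! #[pyfunction]
//! fn batch_to_reader(
//!     batch: PyArrowType<RecordBatch>,
//! ) -> PyResult<PyArrowType<Box<dyn RecordBatchReader + Send>>> {
//!     let batch = batch.0;
//!     let schema = batch.schema();
//!     // Wrap the single batch in a one-shot reader; any `RecordBatchReader` works here.
//!     let reader = RecordBatchIterator::new([Ok(batch)], schema);
//!     let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
//!     Ok(PyArrowType(reader))
//! }
//! ```
//!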
//! (2) Although arrow-rs offers [Table], a convenience wrapper for [pyarrow.Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table)
//! that internally holds `Vec<RecordBatch>`, it is meant primarily for use cases where you already
//! have a `Vec<RecordBatch>` on the Rust side and want to export it in bulk as a `pyarrow.Table`.
//! In general, streaming approaches are preferable to handling data in bulk.
//! For example, a `pyarrow.Table` (or any other object that implements the Arrow PyCapsule
//! stream interface) can be imported into Rust through `PyArrowType<ArrowArrayStreamReader>` instead of
//! eagerly reading it into `Vec<RecordBatch>`.
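//!
//! For example, a minimal sketch of consuming a `pyarrow.Table` (or any other object
//! exporting an Arrow stream) batch by batch (the function name is illustrative):
//!
//! ```ignore
//! #[pyfunction]
//! fn count_rows(reader: PyArrowType<ArrowArrayStreamReader>) -> PyResult<usize> {
//!     let mut rows = 0;
//!     // The reader yields `Result<RecordBatch, ArrowError>` items.
//!     for batch in reader.0 {
//!         rows += batch.map_err(|err| PyValueError::new_err(err.to_string()))?.num_rows();
//!     }
//!     Ok(rows)
//! }
//! ```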

use std::convert::{From, TryFrom};
use std::ffi::CStr;
use std::ptr::{addr_of, addr_of_mut};
use std::sync::Arc;

use arrow_array::ffi;
use arrow_array::ffi::{FFI_ArrowArray, FFI_ArrowSchema};
use arrow_array::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream};
use arrow_array::{
    RecordBatch, RecordBatchIterator, RecordBatchOptions, RecordBatchReader, StructArray,
    make_array,
};
use arrow_data::ArrayData;
use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::ffi::Py_uintptr_t;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple};
use pyo3::{import_exception, intern};

import_exception!(pyarrow, ArrowException);
/// Represents an exception raised by PyArrow.
pub type PyArrowException = ArrowException;

const ARROW_ARRAY_STREAM_CAPSULE_NAME: &CStr = c"arrow_array_stream";
const ARROW_SCHEMA_CAPSULE_NAME: &CStr = c"arrow_schema";
const ARROW_ARRAY_CAPSULE_NAME: &CStr = c"arrow_array";

fn to_py_err(err: ArrowError) -> PyErr {
    PyArrowException::new_err(err.to_string())
}

/// Trait for converting Python objects to arrow-rs types.
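///
/// For example, a minimal sketch of converting a `pyarrow.DataType` received as a bound
/// Python object (the function name is illustrative):
///
/// ```ignore
/// #[pyfunction]
/// fn datatype_to_string(ob: &Bound<'_, PyAny>) -> PyResult<String> {
///     let dtype = DataType::from_pyarrow_bound(ob)?;
///     Ok(dtype.to_string())
/// }
/// ```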
pub trait FromPyArrow: Sized {
    /// Convert a Python object to an arrow-rs type.
    ///
    /// Takes a GIL-bound value from Python and returns a result with the arrow-rs type.
    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self>;
}

/// Create a new PyArrow object from an arrow-rs type.
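///
/// For example, a minimal sketch of building a [Field] on the Rust side and handing it to
/// Python as a `pyarrow.Field` (the function name is illustrative):
///
/// ```ignore
/// #[pyfunction]
/// fn make_field<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
///     let field = Field::new("x", DataType::Float64, /*nullable=*/ true);
///     field.to_pyarrow(py) // returns a `pyarrow.Field`
/// }
/// ```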
pub trait ToPyArrow {
    /// Convert the implemented type into a Python object without consuming it.
    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>>;
}

/// Convert an arrow-rs type into a PyArrow object.
pub trait IntoPyArrow {
    /// Convert the implemented type into a Python object while consuming it.
    fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>>;
}

impl<T: ToPyArrow> IntoPyArrow for T {
    fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        self.to_pyarrow(py)
    }
}

fn validate_class(expected: &str, value: &Bound<PyAny>) -> PyResult<()> {
    let pyarrow = PyModule::import(value.py(), "pyarrow")?;
    let class = pyarrow.getattr(expected)?;
    if !value.is_instance(&class)? {
        let expected_module = class.getattr("__module__")?.extract::<PyBackedStr>()?;
        let expected_name = class.getattr("__name__")?.extract::<PyBackedStr>()?;
        let found_class = value.get_type();
        let found_module = found_class
            .getattr("__module__")?
            .extract::<PyBackedStr>()?;
        let found_name = found_class.getattr("__name__")?.extract::<PyBackedStr>()?;
        return Err(PyTypeError::new_err(format!(
            "Expected instance of {expected_module}.{expected_name}, got {found_module}.{found_name}",
        )));
    }
    Ok(())
}

fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
    let capsule_name = capsule.name()?;
    if capsule_name.is_none() {
        return Err(PyValueError::new_err(
            "Expected PyCapsule to have a name set.",
        ));
    }

    let capsule_name = unsafe { capsule_name.unwrap().as_cstr().to_str()? };
    if capsule_name != name {
        return Err(PyValueError::new_err(format!(
            "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'",
        )));
    }

    Ok(())
}

impl FromPyArrow for DataType {
    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
        // Newer versions of PyArrow as well as other libraries with Arrow data implement this
        // method, so prefer it over _export_to_c.
        // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
        if value.hasattr("__arrow_c_schema__")? {
            let capsule = value.getattr("__arrow_c_schema__")?.call0()?;
            let capsule = capsule.cast::<PyCapsule>()?;
            validate_pycapsule(capsule, "arrow_schema")?;

            let schema_ptr = capsule
                .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
                .cast::<FFI_ArrowSchema>();
            unsafe {
                let dtype = DataType::try_from(schema_ptr.as_ref()).map_err(to_py_err)?;
                return Ok(dtype);
            }
        }

        validate_class("DataType", value)?;

        let c_schema = FFI_ArrowSchema::empty();
        let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
        value.call_method1("_export_to_c", (c_schema_ptr as Py_uintptr_t,))?;
        let dtype = DataType::try_from(&c_schema).map_err(to_py_err)?;
        Ok(dtype)
    }
}

impl ToPyArrow for DataType {
    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?;
        let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
        let module = py.import("pyarrow")?;
        let class = module.getattr("DataType")?;
        let dtype = class.call_method1("_import_from_c", (c_schema_ptr as Py_uintptr_t,))?;
        Ok(dtype)
    }
}

impl FromPyArrow for Field {
    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
        // Newer versions of PyArrow as well as other libraries with Arrow data implement this
        // method, so prefer it over _export_to_c.
        // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
        if value.hasattr("__arrow_c_schema__")? {
            let capsule = value.getattr("__arrow_c_schema__")?.call0()?;
            let capsule = capsule.cast::<PyCapsule>()?;
            validate_pycapsule(capsule, "arrow_schema")?;

            let schema_ptr = capsule
                .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
                .cast::<FFI_ArrowSchema>();
            unsafe {
                let field = Field::try_from(schema_ptr.as_ref()).map_err(to_py_err)?;
                return Ok(field);
            }
        }

        validate_class("Field", value)?;

        let c_schema = FFI_ArrowSchema::empty();
        let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
        value.call_method1("_export_to_c", (c_schema_ptr as Py_uintptr_t,))?;
        let field = Field::try_from(&c_schema).map_err(to_py_err)?;
        Ok(field)
    }
}

impl ToPyArrow for Field {
    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?;
        let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
        let module = py.import("pyarrow")?;
        let class = module.getattr("Field")?;
        let field = class.call_method1("_import_from_c", (c_schema_ptr as Py_uintptr_t,))?;
        Ok(field)
    }
}

impl FromPyArrow for Schema {
    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
        // Newer versions of PyArrow as well as other libraries with Arrow data implement this
        // method, so prefer it over _export_to_c.
        // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
        if value.hasattr("__arrow_c_schema__")? {
            let capsule = value.getattr("__arrow_c_schema__")?.call0()?;
            let capsule = capsule.cast::<PyCapsule>()?;
            validate_pycapsule(capsule, "arrow_schema")?;

            let schema_ptr = capsule
                .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
                .cast::<FFI_ArrowSchema>();
            unsafe {
                let schema = Schema::try_from(schema_ptr.as_ref()).map_err(to_py_err)?;
                return Ok(schema);
            }
        }

        validate_class("Schema", value)?;

        let c_schema = FFI_ArrowSchema::empty();
        let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
        value.call_method1("_export_to_c", (c_schema_ptr as Py_uintptr_t,))?;
        let schema = Schema::try_from(&c_schema).map_err(to_py_err)?;
        Ok(schema)
    }
}

impl ToPyArrow for Schema {
    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?;
        let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
        let module = py.import("pyarrow")?;
        let class = module.getattr("Schema")?;
        let schema = class.call_method1("_import_from_c", (c_schema_ptr as Py_uintptr_t,))?;
        Ok(schema)
    }
}

impl FromPyArrow for ArrayData {
    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
        // Newer versions of PyArrow as well as other libraries with Arrow data implement this
        // method, so prefer it over _export_to_c.
        // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
        if value.hasattr("__arrow_c_array__")? {
            let tuple = value.getattr("__arrow_c_array__")?.call0()?;

            if !tuple.is_instance_of::<PyTuple>() {
                return Err(PyTypeError::new_err(
                    "Expected __arrow_c_array__ to return a tuple.",
                ));
            }

            let schema_capsule = tuple.get_item(0)?;
            let schema_capsule = schema_capsule.cast::<PyCapsule>()?;
            let array_capsule = tuple.get_item(1)?;
            let array_capsule = array_capsule.cast::<PyCapsule>()?;

            validate_pycapsule(schema_capsule, "arrow_schema")?;
            validate_pycapsule(array_capsule, "arrow_array")?;

            let schema_ptr = schema_capsule
                .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
                .cast::<FFI_ArrowSchema>();
            let array = unsafe {
                FFI_ArrowArray::from_raw(
                    array_capsule
                        .pointer_checked(Some(ARROW_ARRAY_CAPSULE_NAME))?
                        .cast::<FFI_ArrowArray>()
                        .as_ptr(),
                )
            };
            return unsafe { ffi::from_ffi(array, schema_ptr.as_ref()) }.map_err(to_py_err);
        }

        validate_class("Array", value)?;

        // prepare a pointer to receive the Array struct
        let mut array = FFI_ArrowArray::empty();
        let mut schema = FFI_ArrowSchema::empty();

        // make the conversion through PyArrow's private API
        // this changes the pointer's memory and is thus unsafe.
        // In particular, `_export_to_c` can go out of bounds
        value.call_method1(
            "_export_to_c",
            (
                addr_of_mut!(array) as Py_uintptr_t,
                addr_of_mut!(schema) as Py_uintptr_t,
            ),
        )?;

        unsafe { ffi::from_ffi(array, &schema) }.map_err(to_py_err)
    }
}

impl ToPyArrow for ArrayData {
    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        let array = FFI_ArrowArray::new(self);
        let schema = FFI_ArrowSchema::try_from(self.data_type()).map_err(to_py_err)?;

        let module = py.import("pyarrow")?;
        let class = module.getattr("Array")?;
        let array = class.call_method1(
            "_import_from_c",
            (
                addr_of!(array) as Py_uintptr_t,
                addr_of!(schema) as Py_uintptr_t,
            ),
        )?;
        Ok(array)
    }
}

impl<T: FromPyArrow> FromPyArrow for Vec<T> {
    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
        let list = value.cast::<PyList>()?;
        list.iter().map(|x| T::from_pyarrow_bound(&x)).collect()
    }
}

impl<T: ToPyArrow> ToPyArrow for Vec<T> {
    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        let values = self
            .iter()
            .map(|v| v.to_pyarrow(py))
            .collect::<PyResult<Vec<_>>>()?;
        Ok(PyList::new(py, values)?.into_any())
    }
}

impl FromPyArrow for RecordBatch {
    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
        // Newer versions of PyArrow as well as other libraries with Arrow data implement this
        // method, so prefer it over _export_to_c.
        // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html

        if value.hasattr("__arrow_c_array__")? {
            let tuple = value.getattr("__arrow_c_array__")?.call0()?;

            if !tuple.is_instance_of::<PyTuple>() {
                return Err(PyTypeError::new_err(
                    "Expected __arrow_c_array__ to return a tuple.",
                ));
            }

            let schema_capsule = tuple.get_item(0)?;
            let schema_capsule = schema_capsule.cast::<PyCapsule>()?;
            let array_capsule = tuple.get_item(1)?;
            let array_capsule = array_capsule.cast::<PyCapsule>()?;

            validate_pycapsule(schema_capsule, "arrow_schema")?;
            validate_pycapsule(array_capsule, "arrow_array")?;

            let schema_ptr = schema_capsule
                .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
                .cast::<FFI_ArrowSchema>();
            let array_ptr = array_capsule
                .pointer_checked(Some(ARROW_ARRAY_CAPSULE_NAME))?
                .cast::<FFI_ArrowArray>();
            let ffi_array = unsafe { FFI_ArrowArray::from_raw(array_ptr.as_ptr()) };
            let mut array_data =
                unsafe { ffi::from_ffi(ffi_array, schema_ptr.as_ref()) }.map_err(to_py_err)?;
            if !matches!(array_data.data_type(), DataType::Struct(_)) {
                return Err(PyTypeError::new_err(
                    "Expected Struct type from __arrow_c_array__.",
                ));
            }
            let options = RecordBatchOptions::default().with_row_count(Some(array_data.len()));
            // Ensure data is aligned (by potentially copying the buffers).
            // This is needed because some python code (for example the
            // python flight client) produces unaligned buffers
            // See https://github.com/apache/arrow/issues/43552 for details
            array_data.align_buffers();
            let array = StructArray::from(array_data);
            // StructArray does not embed metadata from schema. We need to override
            // the output schema with the schema from the capsule.
            let schema =
                unsafe { Arc::new(Schema::try_from(schema_ptr.as_ref()).map_err(to_py_err)?) };
            let (_fields, columns, nulls) = array.into_parts();
            assert_eq!(
                nulls.map(|n| n.null_count()).unwrap_or_default(),
                0,
                "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"
            );
            return RecordBatch::try_new_with_options(schema, columns, &options).map_err(to_py_err);
        }

        validate_class("RecordBatch", value)?;
        // TODO(kszucs): implement the FFI conversions in arrow-rs for RecordBatches
        let schema = value.getattr("schema")?;
        let schema = Arc::new(Schema::from_pyarrow_bound(&schema)?);

        let arrays = value.getattr("columns")?;
        let arrays = arrays
            .cast::<PyList>()?
            .iter()
            .map(|a| Ok(make_array(ArrayData::from_pyarrow_bound(&a)?)))
            .collect::<PyResult<_>>()?;

        let row_count = value
            .getattr("num_rows")
            .ok()
            .and_then(|x| x.extract().ok());
        let options = RecordBatchOptions::default().with_row_count(row_count);

        let batch =
            RecordBatch::try_new_with_options(schema, arrays, &options).map_err(to_py_err)?;
        Ok(batch)
    }
}

impl ToPyArrow for RecordBatch {
    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        // Workaround apache/arrow#37669 by returning RecordBatchIterator
        let reader = RecordBatchIterator::new(vec![Ok(self.clone())], self.schema());
        let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
        let py_reader = reader.into_pyarrow(py)?;
        py_reader.call_method0("read_next_batch")
    }
}

/// Supports conversion from `pyarrow.RecordBatchReader` to [ArrowArrayStreamReader].
impl FromPyArrow for ArrowArrayStreamReader {
    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
        // Newer versions of PyArrow as well as other libraries with Arrow data implement this
        // method, so prefer it over _export_to_c.
        // See https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
        if value.hasattr("__arrow_c_stream__")? {
            let capsule = value.getattr("__arrow_c_stream__")?.call0()?;
            let capsule = capsule.cast::<PyCapsule>()?;
            validate_pycapsule(capsule, "arrow_array_stream")?;

            let stream = unsafe {
                FFI_ArrowArrayStream::from_raw(
                    capsule
                        .pointer_checked(Some(ARROW_ARRAY_STREAM_CAPSULE_NAME))?
                        .cast::<FFI_ArrowArrayStream>()
                        .as_ptr(),
                )
            };

            let stream_reader = ArrowArrayStreamReader::try_new(stream)
                .map_err(|err| PyValueError::new_err(err.to_string()))?;

            return Ok(stream_reader);
        }

        validate_class("RecordBatchReader", value)?;

        // prepare a pointer to receive the stream struct
        let mut stream = FFI_ArrowArrayStream::empty();
        let stream_ptr = &mut stream as *mut FFI_ArrowArrayStream;

        // make the conversion through PyArrow's private API
        // this changes the pointer's memory and is thus unsafe.
        // In particular, `_export_to_c` can go out of bounds
        let args = PyTuple::new(value.py(), [stream_ptr as Py_uintptr_t])?;
        value.call_method1("_export_to_c", args)?;

        let stream_reader = ArrowArrayStreamReader::try_new(stream)
            .map_err(|err| PyValueError::new_err(err.to_string()))?;

        Ok(stream_reader)
    }
}

/// Convert a [`RecordBatchReader`] into a `pyarrow.RecordBatchReader`.
impl IntoPyArrow for Box<dyn RecordBatchReader + Send> {
    // We can't implement `ToPyArrow` for `T: RecordBatchReader + Send` because
    // there is already a blanket implementation for `T: ToPyArrow`.
    fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        let mut stream = FFI_ArrowArrayStream::new(self);

        let stream_ptr = (&mut stream) as *mut FFI_ArrowArrayStream;
        let module = py.import("pyarrow")?;
        let class = module.getattr("RecordBatchReader")?;
        let args = PyTuple::new(py, [stream_ptr as Py_uintptr_t])?;
        let reader = class.call_method1("_import_from_c", args)?;

        Ok(reader)
    }
}

/// Convert an [`ArrowArrayStreamReader`] into a `pyarrow.RecordBatchReader`.
impl IntoPyArrow for ArrowArrayStreamReader {
    fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        let boxed: Box<dyn RecordBatchReader + Send> = Box::new(self);
        boxed.into_pyarrow(py)
    }
}

/// A convenience wrapper around `Vec<RecordBatch>` that simplifies conversion from
/// and to `pyarrow.Table`.
///
/// Use it when you want to consume a `pyarrow.Table` in bulk (although, since `pyarrow.Table`
/// implements the Arrow PyCapsule stream interface, you could also consume it as a
/// `PyArrowType<ArrowArrayStreamReader>` instead) or, more importantly, when you want to
/// export a `Vec<RecordBatch>` from the Rust side as a `pyarrow.Table`.
///
/// ```ignore
/// #[pyfunction]
/// fn return_table(...) -> PyResult<PyArrowType<Table>> {
///     let batches: Vec<RecordBatch> = /* ... */;
///     let schema: SchemaRef = /* ... */;
///     let table = Table::try_new(batches, schema)
///         .map_err(|err| PyArrowException::new_err(err.to_string()))?;
///     Ok(PyArrowType(table))
/// }
/// ```
#[derive(Clone)]
pub struct Table {
    record_batches: Vec<RecordBatch>,
    schema: SchemaRef,
}

impl Table {
    /// Create a new [Table] from record batches that all share `schema`.
    ///
    /// Returns an error if any record batch has a different schema.
    pub fn try_new(
        record_batches: Vec<RecordBatch>,
        schema: SchemaRef,
    ) -> Result<Self, ArrowError> {
        for record_batch in &record_batches {
            if schema != record_batch.schema() {
                return Err(ArrowError::SchemaError(format!(
                    "All record batches must have the same schema. \
                         Expected schema: {:?}, got schema: {:?}",
                    schema,
                    record_batch.schema()
                )));
            }
        }
        Ok(Self {
            record_batches,
            schema,
        })
    }

    /// Returns the record batches backing this [Table].
    pub fn record_batches(&self) -> &[RecordBatch] {
        &self.record_batches
    }

    /// Returns the schema shared by all record batches.
    pub fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    /// Consumes this [Table], returning its record batches and schema.
    pub fn into_inner(self) -> (Vec<RecordBatch>, SchemaRef) {
        (self.record_batches, self.schema)
    }
}

impl TryFrom<Box<dyn RecordBatchReader>> for Table {
    type Error = ArrowError;

    fn try_from(value: Box<dyn RecordBatchReader>) -> Result<Self, ArrowError> {
        let schema = value.schema();
        let batches = value.collect::<Result<Vec<_>, _>>()?;
        Self::try_new(batches, schema)
    }
}

/// Convert a `pyarrow.Table` (or any other ArrowArrayStream-compliant object) into a [`Table`].
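///
/// For example, a minimal sketch of consuming a `pyarrow.Table` in bulk (the function name
/// is illustrative):
///
/// ```ignore
/// #[pyfunction]
/// fn total_rows(table: PyArrowType<Table>) -> usize {
///     table.0.record_batches().iter().map(|batch| batch.num_rows()).sum()
/// }
/// ```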
impl FromPyArrow for Table {
    fn from_pyarrow_bound(ob: &Bound<PyAny>) -> PyResult<Self> {
        let reader: Box<dyn RecordBatchReader> =
            Box::new(ArrowArrayStreamReader::from_pyarrow_bound(ob)?);
        Self::try_from(reader).map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))
    }
}

/// Convert a [`Table`] into a `pyarrow.Table`.
impl IntoPyArrow for Table {
    fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        let module = py.import(intern!(py, "pyarrow"))?;
        let class = module.getattr(intern!(py, "Table"))?;

        let py_batches = PyList::new(py, self.record_batches.into_iter().map(PyArrowType))?;
        let py_schema = PyArrowType(Arc::unwrap_or_clone(self.schema));

        let kwargs = PyDict::new(py);
        kwargs.set_item("schema", py_schema)?;

        let table = class.call_method("from_batches", (py_batches,), Some(&kwargs))?;

        Ok(table)
    }
}

/// A newtype wrapper for types implementing [`FromPyArrow`] or [`IntoPyArrow`].
///
/// When wrapped around a type `T: FromPyArrow`, it
/// implements [`FromPyObject`] for the PyArrow objects. When wrapped around a
/// `T: IntoPyArrow`, it implements [`IntoPyObject`] for the wrapped type.
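///
/// For example, a minimal sketch of a `#[pyfunction]` that accepts and returns PyArrow
/// objects through this wrapper (the function name is illustrative):
///
/// ```ignore
/// #[pyfunction]
/// fn roundtrip_schema(schema: PyArrowType<Schema>) -> PyArrowType<Schema> {
///     let schema: Schema = schema.0; // unwrap into the arrow-rs type
///     PyArrowType(schema)            // wrap again so pyo3 converts it back to pyarrow
/// }
/// ```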
#[derive(Debug)]
pub struct PyArrowType<T>(pub T);

impl<T: FromPyArrow> FromPyObject<'_, '_> for PyArrowType<T> {
    type Error = PyErr;

    fn extract(value: Borrowed<'_, '_, PyAny>) -> PyResult<Self> {
        Ok(Self(T::from_pyarrow_bound(&value)?))
    }
}

impl<'py, T: IntoPyArrow> IntoPyObject<'py> for PyArrowType<T> {
    type Target = PyAny;

    type Output = Bound<'py, Self::Target>;

    type Error = PyErr;

    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, PyErr> {
        self.0.into_pyarrow(py)
    }
}

impl<T> From<T> for PyArrowType<T> {
    fn from(s: T) -> Self {
        Self(s)
    }
}