1use std::convert::{From, TryFrom};
60use std::ptr::{addr_of, addr_of_mut};
61use std::sync::Arc;
62
63use arrow_array::ffi;
64use arrow_array::ffi::{FFI_ArrowArray, FFI_ArrowSchema};
65use arrow_array::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream};
66use arrow_array::{
67    RecordBatch, RecordBatchIterator, RecordBatchOptions, RecordBatchReader, StructArray,
68    make_array,
69};
70use arrow_data::ArrayData;
71use arrow_schema::{ArrowError, DataType, Field, Schema};
72use pyo3::exceptions::{PyTypeError, PyValueError};
73use pyo3::ffi::Py_uintptr_t;
74use pyo3::import_exception;
75use pyo3::prelude::*;
76use pyo3::pybacked::PyBackedStr;
77use pyo3::types::{PyCapsule, PyList, PyTuple};
78
// Bind the `ArrowException` class from the Python `pyarrow` module so it can be raised from Rust.
import_exception!(pyarrow, ArrowException);
/// Public alias for the imported `pyarrow.ArrowException` Python exception type.
pub type PyArrowException = ArrowException;
82
83fn to_py_err(err: ArrowError) -> PyErr {
84    PyArrowException::new_err(err.to_string())
85}
86
/// Conversion from a Python object into an Arrow-rs value.
pub trait FromPyArrow: Sized {
    /// Converts `value` into `Self`.
    ///
    /// Most implementations in this module accept either an object exposing the
    /// Arrow PyCapsule interface (`__arrow_c_schema__`, `__arrow_c_array__`,
    /// `__arrow_c_stream__`) or an instance of the corresponding `pyarrow` class.
    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self>;
}
94
/// Conversion from a borrowed Arrow-rs value into a Python (`pyarrow`) object.
pub trait ToPyArrow {
    /// Creates a new Python object representing `self`, bound to the GIL token `py`.
    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>>;
}
100
/// Conversion from an owned Arrow-rs value into a Python (`pyarrow`) object.
///
/// Every [`ToPyArrow`] type implements this automatically. It exists separately for
/// values that must be consumed to be exported, such as boxed record batch readers.
pub trait IntoPyArrow {
    /// Consumes `self` and creates a Python object representing it.
    fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>>;
}
106
107impl<T: ToPyArrow> IntoPyArrow for T {
108    fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
109        self.to_pyarrow(py)
110    }
111}
112
113fn validate_class(expected: &str, value: &Bound<PyAny>) -> PyResult<()> {
114    let pyarrow = PyModule::import(value.py(), "pyarrow")?;
115    let class = pyarrow.getattr(expected)?;
116    if !value.is_instance(&class)? {
117        let expected_module = class.getattr("__module__")?.extract::<PyBackedStr>()?;
118        let expected_name = class.getattr("__name__")?.extract::<PyBackedStr>()?;
119        let found_class = value.get_type();
120        let found_module = found_class
121            .getattr("__module__")?
122            .extract::<PyBackedStr>()?;
123        let found_name = found_class.getattr("__name__")?.extract::<PyBackedStr>()?;
124        return Err(PyTypeError::new_err(format!(
125            "Expected instance of {expected_module}.{expected_name}, got {found_module}.{found_name}",
126        )));
127    }
128    Ok(())
129}
130
131fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
132    let capsule_name = capsule.name()?;
133    if capsule_name.is_none() {
134        return Err(PyValueError::new_err(
135            "Expected schema PyCapsule to have name set.",
136        ));
137    }
138
139    let capsule_name = capsule_name.unwrap().to_str()?;
140    if capsule_name != name {
141        return Err(PyValueError::new_err(format!(
142            "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'",
143        )));
144    }
145
146    Ok(())
147}
148
149impl FromPyArrow for DataType {
150    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
151        if value.hasattr("__arrow_c_schema__")? {
155            let capsule = value.getattr("__arrow_c_schema__")?.call0()?;
156            let capsule = capsule.downcast::<PyCapsule>()?;
157            validate_pycapsule(capsule, "arrow_schema")?;
158
159            let schema_ptr = unsafe { capsule.reference::<FFI_ArrowSchema>() };
160            let dtype = DataType::try_from(schema_ptr).map_err(to_py_err)?;
161            return Ok(dtype);
162        }
163
164        validate_class("DataType", value)?;
165
166        let c_schema = FFI_ArrowSchema::empty();
167        let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
168        value.call_method1("_export_to_c", (c_schema_ptr as Py_uintptr_t,))?;
169        let dtype = DataType::try_from(&c_schema).map_err(to_py_err)?;
170        Ok(dtype)
171    }
172}
173
174impl ToPyArrow for DataType {
175    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
176        let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?;
177        let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
178        let module = py.import("pyarrow")?;
179        let class = module.getattr("DataType")?;
180        let dtype = class.call_method1("_import_from_c", (c_schema_ptr as Py_uintptr_t,))?;
181        Ok(dtype)
182    }
183}
184
185impl FromPyArrow for Field {
186    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
187        if value.hasattr("__arrow_c_schema__")? {
191            let capsule = value.getattr("__arrow_c_schema__")?.call0()?;
192            let capsule = capsule.downcast::<PyCapsule>()?;
193            validate_pycapsule(capsule, "arrow_schema")?;
194
195            let schema_ptr = unsafe { capsule.reference::<FFI_ArrowSchema>() };
196            let field = Field::try_from(schema_ptr).map_err(to_py_err)?;
197            return Ok(field);
198        }
199
200        validate_class("Field", value)?;
201
202        let c_schema = FFI_ArrowSchema::empty();
203        let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
204        value.call_method1("_export_to_c", (c_schema_ptr as Py_uintptr_t,))?;
205        let field = Field::try_from(&c_schema).map_err(to_py_err)?;
206        Ok(field)
207    }
208}
209
210impl ToPyArrow for Field {
211    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
212        let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?;
213        let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
214        let module = py.import("pyarrow")?;
215        let class = module.getattr("Field")?;
216        let dtype = class.call_method1("_import_from_c", (c_schema_ptr as Py_uintptr_t,))?;
217        Ok(dtype)
218    }
219}
220
221impl FromPyArrow for Schema {
222    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
223        if value.hasattr("__arrow_c_schema__")? {
227            let capsule = value.getattr("__arrow_c_schema__")?.call0()?;
228            let capsule = capsule.downcast::<PyCapsule>()?;
229            validate_pycapsule(capsule, "arrow_schema")?;
230
231            let schema_ptr = unsafe { capsule.reference::<FFI_ArrowSchema>() };
232            let schema = Schema::try_from(schema_ptr).map_err(to_py_err)?;
233            return Ok(schema);
234        }
235
236        validate_class("Schema", value)?;
237
238        let c_schema = FFI_ArrowSchema::empty();
239        let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
240        value.call_method1("_export_to_c", (c_schema_ptr as Py_uintptr_t,))?;
241        let schema = Schema::try_from(&c_schema).map_err(to_py_err)?;
242        Ok(schema)
243    }
244}
245
246impl ToPyArrow for Schema {
247    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
248        let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?;
249        let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
250        let module = py.import("pyarrow")?;
251        let class = module.getattr("Schema")?;
252        let schema = class.call_method1("_import_from_c", (c_schema_ptr as Py_uintptr_t,))?;
253        Ok(schema)
254    }
255}
256
257impl FromPyArrow for ArrayData {
258    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
259        if value.hasattr("__arrow_c_array__")? {
263            let tuple = value.getattr("__arrow_c_array__")?.call0()?;
264
265            if !tuple.is_instance_of::<PyTuple>() {
266                return Err(PyTypeError::new_err(
267                    "Expected __arrow_c_array__ to return a tuple.",
268                ));
269            }
270
271            let schema_capsule = tuple.get_item(0)?;
272            let schema_capsule = schema_capsule.downcast::<PyCapsule>()?;
273            let array_capsule = tuple.get_item(1)?;
274            let array_capsule = array_capsule.downcast::<PyCapsule>()?;
275
276            validate_pycapsule(schema_capsule, "arrow_schema")?;
277            validate_pycapsule(array_capsule, "arrow_array")?;
278
279            let schema_ptr = unsafe { schema_capsule.reference::<FFI_ArrowSchema>() };
280            let array = unsafe { FFI_ArrowArray::from_raw(array_capsule.pointer() as _) };
281            return unsafe { ffi::from_ffi(array, schema_ptr) }.map_err(to_py_err);
282        }
283
284        validate_class("Array", value)?;
285
286        let mut array = FFI_ArrowArray::empty();
288        let mut schema = FFI_ArrowSchema::empty();
289
290        value.call_method1(
294            "_export_to_c",
295            (
296                addr_of_mut!(array) as Py_uintptr_t,
297                addr_of_mut!(schema) as Py_uintptr_t,
298            ),
299        )?;
300
301        unsafe { ffi::from_ffi(array, &schema) }.map_err(to_py_err)
302    }
303}
304
305impl ToPyArrow for ArrayData {
306    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
307        let array = FFI_ArrowArray::new(self);
308        let schema = FFI_ArrowSchema::try_from(self.data_type()).map_err(to_py_err)?;
309
310        let module = py.import("pyarrow")?;
311        let class = module.getattr("Array")?;
312        let array = class.call_method1(
313            "_import_from_c",
314            (
315                addr_of!(array) as Py_uintptr_t,
316                addr_of!(schema) as Py_uintptr_t,
317            ),
318        )?;
319        Ok(array)
320    }
321}
322
323impl<T: FromPyArrow> FromPyArrow for Vec<T> {
324    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
325        let list = value.downcast::<PyList>()?;
326        list.iter().map(|x| T::from_pyarrow_bound(&x)).collect()
327    }
328}
329
330impl<T: ToPyArrow> ToPyArrow for Vec<T> {
331    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
332        let values = self
333            .iter()
334            .map(|v| v.to_pyarrow(py))
335            .collect::<PyResult<Vec<_>>>()?;
336        Ok(PyList::new(py, values)?.into_any())
337    }
338}
339
340impl FromPyArrow for RecordBatch {
341    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
342        if value.hasattr("__arrow_c_array__")? {
346            let tuple = value.getattr("__arrow_c_array__")?.call0()?;
347
348            if !tuple.is_instance_of::<PyTuple>() {
349                return Err(PyTypeError::new_err(
350                    "Expected __arrow_c_array__ to return a tuple.",
351                ));
352            }
353
354            let schema_capsule = tuple.get_item(0)?;
355            let schema_capsule = schema_capsule.downcast::<PyCapsule>()?;
356            let array_capsule = tuple.get_item(1)?;
357            let array_capsule = array_capsule.downcast::<PyCapsule>()?;
358
359            validate_pycapsule(schema_capsule, "arrow_schema")?;
360            validate_pycapsule(array_capsule, "arrow_array")?;
361
362            let schema_ptr = unsafe { schema_capsule.reference::<FFI_ArrowSchema>() };
363            let ffi_array = unsafe { FFI_ArrowArray::from_raw(array_capsule.pointer().cast()) };
364            let mut array_data =
365                unsafe { ffi::from_ffi(ffi_array, schema_ptr) }.map_err(to_py_err)?;
366            if !matches!(array_data.data_type(), DataType::Struct(_)) {
367                return Err(PyTypeError::new_err(
368                    "Expected Struct type from __arrow_c_array.",
369                ));
370            }
371            let options = RecordBatchOptions::default().with_row_count(Some(array_data.len()));
372            array_data.align_buffers();
377            let array = StructArray::from(array_data);
378            let schema = Arc::new(Schema::try_from(schema_ptr).map_err(to_py_err)?);
381            let (_fields, columns, nulls) = array.into_parts();
382            assert_eq!(
383                nulls.map(|n| n.null_count()).unwrap_or_default(),
384                0,
385                "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"
386            );
387            return RecordBatch::try_new_with_options(schema, columns, &options).map_err(to_py_err);
388        }
389
390        validate_class("RecordBatch", value)?;
391        let schema = value.getattr("schema")?;
393        let schema = Arc::new(Schema::from_pyarrow_bound(&schema)?);
394
395        let arrays = value.getattr("columns")?;
396        let arrays = arrays
397            .downcast::<PyList>()?
398            .iter()
399            .map(|a| Ok(make_array(ArrayData::from_pyarrow_bound(&a)?)))
400            .collect::<PyResult<_>>()?;
401
402        let row_count = value
403            .getattr("num_rows")
404            .ok()
405            .and_then(|x| x.extract().ok());
406        let options = RecordBatchOptions::default().with_row_count(row_count);
407
408        let batch =
409            RecordBatch::try_new_with_options(schema, arrays, &options).map_err(to_py_err)?;
410        Ok(batch)
411    }
412}
413
414impl ToPyArrow for RecordBatch {
415    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
416        let reader = RecordBatchIterator::new(vec![Ok(self.clone())], self.schema());
418        let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
419        let py_reader = reader.into_pyarrow(py)?;
420        py_reader.call_method0("read_next_batch")
421    }
422}
423
424impl FromPyArrow for ArrowArrayStreamReader {
426    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
427        if value.hasattr("__arrow_c_stream__")? {
431            let capsule = value.getattr("__arrow_c_stream__")?.call0()?;
432            let capsule = capsule.downcast::<PyCapsule>()?;
433            validate_pycapsule(capsule, "arrow_array_stream")?;
434
435            let stream = unsafe { FFI_ArrowArrayStream::from_raw(capsule.pointer() as _) };
436
437            let stream_reader = ArrowArrayStreamReader::try_new(stream)
438                .map_err(|err| PyValueError::new_err(err.to_string()))?;
439
440            return Ok(stream_reader);
441        }
442
443        validate_class("RecordBatchReader", value)?;
444
445        let mut stream = FFI_ArrowArrayStream::empty();
447        let stream_ptr = &mut stream as *mut FFI_ArrowArrayStream;
448
449        let args = PyTuple::new(value.py(), [stream_ptr as Py_uintptr_t])?;
453        value.call_method1("_export_to_c", args)?;
454
455        let stream_reader = ArrowArrayStreamReader::try_new(stream)
456            .map_err(|err| PyValueError::new_err(err.to_string()))?;
457
458        Ok(stream_reader)
459    }
460}
461
462impl IntoPyArrow for Box<dyn RecordBatchReader + Send> {
464    fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
467        let mut stream = FFI_ArrowArrayStream::new(self);
468
469        let stream_ptr = (&mut stream) as *mut FFI_ArrowArrayStream;
470        let module = py.import("pyarrow")?;
471        let class = module.getattr("RecordBatchReader")?;
472        let args = PyTuple::new(py, [stream_ptr as Py_uintptr_t])?;
473        let reader = class.call_method1("_import_from_c", args)?;
474
475        Ok(reader)
476    }
477}
478
479impl IntoPyArrow for ArrowArrayStreamReader {
481    fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
482        let boxed: Box<dyn RecordBatchReader + Send> = Box::new(self);
483        boxed.into_pyarrow(py)
484    }
485}
486
/// Newtype wrapper that lets Arrow-rs values cross the pyo3 boundary.
///
/// It implements pyo3's `FromPyObject` for any `T: FromPyArrow` and `IntoPyObject`
/// for any `T: IntoPyArrow`. This allows it to appear directly as a `#[pyfunction]`
/// argument or return type. The inner value is public.
#[derive(Debug)]
pub struct PyArrowType<T>(pub T);
494
495impl<'source, T: FromPyArrow> FromPyObject<'source> for PyArrowType<T> {
496    fn extract_bound(value: &Bound<'source, PyAny>) -> PyResult<Self> {
497        Ok(Self(T::from_pyarrow_bound(value)?))
498    }
499}
500
501impl<'py, T: IntoPyArrow> IntoPyObject<'py> for PyArrowType<T> {
502    type Target = PyAny;
503
504    type Output = Bound<'py, Self::Target>;
505
506    type Error = PyErr;
507
508    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, PyErr> {
509        self.0.into_pyarrow(py)
510    }
511}
512
513impl<T> From<T> for PyArrowType<T> {
514    fn from(s: T) -> Self {
515        Self(s)
516    }
517}