1use std::convert::{From, TryFrom};
63use std::ffi::CStr;
64use std::sync::Arc;
65
66use arrow_array::ffi;
67use arrow_array::ffi::{FFI_ArrowArray, FFI_ArrowSchema};
68use arrow_array::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream};
69use arrow_array::{
70 RecordBatch, RecordBatchIterator, RecordBatchOptions, RecordBatchReader, StructArray,
71 make_array,
72};
73use arrow_data::ArrayData;
74use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
75use pyo3::exceptions::{PyTypeError, PyValueError};
76use pyo3::ffi::Py_uintptr_t;
77use pyo3::import_exception;
78use pyo3::prelude::*;
79use pyo3::sync::PyOnceLock;
80use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
81
// Binds the `ArrowException` class from the `pyarrow` Python module so it can be
// raised from Rust code.
import_exception!(pyarrow, ArrowException);
/// Alias for the `ArrowException` class imported from the `pyarrow` Python module.
pub type PyArrowException = ArrowException;
85
/// Capsule name passed to `pointer_checked` when reading an Arrow array stream capsule.
const ARROW_ARRAY_STREAM_CAPSULE_NAME: &CStr = c"arrow_array_stream";
/// Capsule name passed to `pointer_checked` when reading an Arrow schema capsule.
const ARROW_SCHEMA_CAPSULE_NAME: &CStr = c"arrow_schema";
/// Capsule name passed to `pointer_checked` when reading an Arrow array capsule.
const ARROW_ARRAY_CAPSULE_NAME: &CStr = c"arrow_array";
89
90fn to_py_err(err: ArrowError) -> PyErr {
91 PyArrowException::new_err(err.to_string())
92}
93
/// Conversion from a Python object into a Rust value of type `Self`.
pub trait FromPyArrow: Sized {
    /// Builds `Self` from the borrowed Python object `value`.
    ///
    /// # Errors
    ///
    /// Returns a Python error when the conversion fails.
    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self>;
}
101
/// Conversion of a borrowed Rust value into a Python object.
pub trait ToPyArrow {
    /// Produces a Python object from `self` without consuming it.
    ///
    /// # Errors
    ///
    /// Returns a Python error when the conversion fails.
    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>>;
}
107
/// Conversion of an owned Rust value into a Python object.
pub trait IntoPyArrow {
    /// Consumes `self` and produces a Python object from it.
    ///
    /// # Errors
    ///
    /// Returns a Python error when the conversion fails.
    fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>>;
}
113
114impl<T: ToPyArrow> IntoPyArrow for T {
115 fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
116 self.to_pyarrow(py)
117 }
118}
119
120fn validate_class(expected: &Bound<PyType>, value: &Bound<PyAny>) -> PyResult<()> {
121 if !value.is_instance(expected)? {
122 let expected_module = expected.getattr("__module__")?;
123 let expected_name = expected.getattr("__name__")?;
124 let found_class = value.get_type();
125 let found_module = found_class.getattr("__module__")?;
126 let found_name = found_class.getattr("__name__")?;
127 return Err(PyTypeError::new_err(format!(
128 "Expected instance of {expected_module}.{expected_name}, got {found_module}.{found_name}",
129 )));
130 }
131 Ok(())
132}
133
134fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
135 let capsule_name = capsule.name()?;
136 if capsule_name.is_none() {
137 return Err(PyValueError::new_err(
138 "Expected schema PyCapsule to have name set.",
139 ));
140 }
141
142 let capsule_name = unsafe { capsule_name.unwrap().as_cstr().to_str()? };
143 if capsule_name != name {
144 return Err(PyValueError::new_err(format!(
145 "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'",
146 )));
147 }
148
149 Ok(())
150}
151
152fn extract_arrow_c_array_capsules<'py>(
153 value: &Bound<'py, PyAny>,
154) -> PyResult<(Bound<'py, PyCapsule>, Bound<'py, PyCapsule>)> {
155 let tuple = value.call_method0("__arrow_c_array__")?;
156
157 if !tuple.is_instance_of::<PyTuple>() {
158 return Err(PyTypeError::new_err(
159 "Expected __arrow_c_array__ to return a tuple of (schema, array) capsules.",
160 ));
161 }
162
163 tuple.extract().map_err(|_| {
164 PyTypeError::new_err(
165 "Expected __arrow_c_array__ to return a tuple of (schema, array) capsules.",
166 )
167 })
168}
169
170impl FromPyArrow for DataType {
171 fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
172 if value.hasattr("__arrow_c_schema__")? {
176 let capsule = value.call_method0("__arrow_c_schema__")?.extract()?;
177 validate_pycapsule(&capsule, "arrow_schema")?;
178
179 let schema_ptr = capsule
180 .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
181 .cast::<FFI_ArrowSchema>();
182 return unsafe { DataType::try_from(schema_ptr.as_ref()) }.map_err(to_py_err);
183 }
184
185 validate_class(data_type_class(value.py())?, value)?;
186
187 let mut c_schema = FFI_ArrowSchema::empty();
188 value.call_method1("_export_to_c", (&raw mut c_schema as Py_uintptr_t,))?;
189 DataType::try_from(&c_schema).map_err(to_py_err)
190 }
191}
192
193impl ToPyArrow for DataType {
194 fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
195 let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?;
196 data_type_class(py)?.call_method1("_import_from_c", (&raw const c_schema as Py_uintptr_t,))
197 }
198}
199
200impl FromPyArrow for Field {
201 fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
202 if value.hasattr("__arrow_c_schema__")? {
206 let capsule = value.call_method0("__arrow_c_schema__")?.extract()?;
207 validate_pycapsule(&capsule, "arrow_schema")?;
208
209 let schema_ptr = capsule
210 .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
211 .cast::<FFI_ArrowSchema>();
212 return unsafe { Field::try_from(schema_ptr.as_ref()) }.map_err(to_py_err);
213 }
214
215 validate_class(field_class(value.py())?, value)?;
216
217 let mut c_schema = FFI_ArrowSchema::empty();
218 value.call_method1("_export_to_c", (&raw mut c_schema as Py_uintptr_t,))?;
219 Field::try_from(&c_schema).map_err(to_py_err)
220 }
221}
222
223impl ToPyArrow for Field {
224 fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
225 let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?;
226 field_class(py)?.call_method1("_import_from_c", (&raw const c_schema as Py_uintptr_t,))
227 }
228}
229
230impl FromPyArrow for Schema {
231 fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
232 if value.hasattr("__arrow_c_schema__")? {
236 let capsule = value.call_method0("__arrow_c_schema__")?.extract()?;
237 validate_pycapsule(&capsule, "arrow_schema")?;
238
239 let schema_ptr = capsule
240 .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
241 .cast::<FFI_ArrowSchema>();
242 return unsafe { Schema::try_from(schema_ptr.as_ref()) }.map_err(to_py_err);
243 }
244
245 validate_class(schema_class(value.py())?, value)?;
246
247 let mut c_schema = FFI_ArrowSchema::empty();
248 value.call_method1("_export_to_c", (&raw mut c_schema as Py_uintptr_t,))?;
249 Schema::try_from(&c_schema).map_err(to_py_err)
250 }
251}
252
253impl ToPyArrow for Schema {
254 fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
255 let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?;
256 schema_class(py)?.call_method1("_import_from_c", (&raw const c_schema as Py_uintptr_t,))
257 }
258}
259
260impl FromPyArrow for ArrayData {
261 fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
262 if value.hasattr("__arrow_c_array__")? {
266 let (schema_capsule, array_capsule) = extract_arrow_c_array_capsules(value)?;
267
268 validate_pycapsule(&schema_capsule, "arrow_schema")?;
269 validate_pycapsule(&array_capsule, "arrow_array")?;
270
271 let schema_ptr = schema_capsule
272 .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
273 .cast::<FFI_ArrowSchema>();
274 let array = unsafe {
275 FFI_ArrowArray::from_raw(
276 array_capsule
277 .pointer_checked(Some(ARROW_ARRAY_CAPSULE_NAME))?
278 .cast::<FFI_ArrowArray>()
279 .as_ptr(),
280 )
281 };
282 return unsafe { ffi::from_ffi(array, schema_ptr.as_ref()) }.map_err(to_py_err);
283 }
284
285 validate_class(array_class(value.py())?, value)?;
286
287 let mut array = FFI_ArrowArray::empty();
289 let mut schema = FFI_ArrowSchema::empty();
290
291 value.call_method1(
295 "_export_to_c",
296 (
297 &raw mut array as Py_uintptr_t,
298 &raw mut schema as Py_uintptr_t,
299 ),
300 )?;
301
302 unsafe { ffi::from_ffi(array, &schema) }.map_err(to_py_err)
303 }
304}
305
306impl ToPyArrow for ArrayData {
307 fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
308 let array = FFI_ArrowArray::new(self);
309 let schema = FFI_ArrowSchema::try_from(self.data_type()).map_err(to_py_err)?;
310 array_class(py)?.call_method1(
311 "_import_from_c",
312 (
313 &raw const array as Py_uintptr_t,
314 &raw const schema as Py_uintptr_t,
315 ),
316 )
317 }
318}
319
320impl<T: FromPyArrow> FromPyArrow for Vec<T> {
321 fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
322 let list = value.cast::<PyList>()?;
323 list.iter().map(|x| T::from_pyarrow_bound(&x)).collect()
324 }
325}
326
327impl<T: ToPyArrow> ToPyArrow for Vec<T> {
328 fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
329 let values = self
330 .iter()
331 .map(|v| v.to_pyarrow(py))
332 .collect::<PyResult<Vec<_>>>()?;
333 Ok(PyList::new(py, values)?.into_any())
334 }
335}
336
337impl FromPyArrow for RecordBatch {
338 fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
339 if value.hasattr("__arrow_c_array__")? {
344 let (schema_capsule, array_capsule) = extract_arrow_c_array_capsules(value)?;
345
346 validate_pycapsule(&schema_capsule, "arrow_schema")?;
347 validate_pycapsule(&array_capsule, "arrow_array")?;
348
349 let schema_ptr = schema_capsule
350 .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
351 .cast::<FFI_ArrowSchema>();
352 let array_ptr = array_capsule
353 .pointer_checked(Some(ARROW_ARRAY_CAPSULE_NAME))?
354 .cast::<FFI_ArrowArray>();
355 let ffi_array = unsafe { FFI_ArrowArray::from_raw(array_ptr.as_ptr()) };
356 let mut array_data =
357 unsafe { ffi::from_ffi(ffi_array, schema_ptr.as_ref()) }.map_err(to_py_err)?;
358 if !matches!(array_data.data_type(), DataType::Struct(_)) {
359 return Err(PyTypeError::new_err(
360 "Expected Struct type from __arrow_c_array.",
361 ));
362 }
363 let options = RecordBatchOptions::default().with_row_count(Some(array_data.len()));
364 array_data.align_buffers();
369 let array = StructArray::from(array_data);
370 let schema =
373 unsafe { Arc::new(Schema::try_from(schema_ptr.as_ref()).map_err(to_py_err)?) };
374 let (_fields, columns, nulls) = array.into_parts();
375 assert_eq!(
376 nulls.map(|n| n.null_count()).unwrap_or_default(),
377 0,
378 "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"
379 );
380 return RecordBatch::try_new_with_options(schema, columns, &options).map_err(to_py_err);
381 }
382
383 validate_class(record_batch_class(value.py())?, value)?;
384 let schema = value.getattr("schema")?;
386 let schema = Arc::new(Schema::from_pyarrow_bound(&schema)?);
387
388 let arrays = value.getattr("columns")?;
389 let arrays = arrays
390 .cast::<PyList>()?
391 .iter()
392 .map(|a| Ok(make_array(ArrayData::from_pyarrow_bound(&a)?)))
393 .collect::<PyResult<_>>()?;
394
395 let row_count = value
396 .getattr("num_rows")
397 .ok()
398 .and_then(|x| x.extract().ok());
399 let options = RecordBatchOptions::default().with_row_count(row_count);
400
401 let batch =
402 RecordBatch::try_new_with_options(schema, arrays, &options).map_err(to_py_err)?;
403 Ok(batch)
404 }
405}
406
407impl ToPyArrow for RecordBatch {
408 fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
409 let reader = RecordBatchIterator::new(vec![Ok(self.clone())], self.schema());
411 let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
412 let py_reader = reader.into_pyarrow(py)?;
413 py_reader.call_method0("read_next_batch")
414 }
415}
416
417impl FromPyArrow for ArrowArrayStreamReader {
419 fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
420 if value.hasattr("__arrow_c_stream__")? {
424 let capsule = value.call_method0("__arrow_c_stream__")?.extract()?;
425
426 validate_pycapsule(&capsule, "arrow_array_stream")?;
427
428 let stream = unsafe {
429 FFI_ArrowArrayStream::from_raw(
430 capsule
431 .pointer_checked(Some(ARROW_ARRAY_STREAM_CAPSULE_NAME))?
432 .cast::<FFI_ArrowArrayStream>()
433 .as_ptr(),
434 )
435 };
436
437 let stream_reader = ArrowArrayStreamReader::try_new(stream)
438 .map_err(|err| PyValueError::new_err(err.to_string()))?;
439
440 return Ok(stream_reader);
441 }
442
443 validate_class(record_batch_reader_class(value.py())?, value)?;
444
445 let mut stream = FFI_ArrowArrayStream::empty();
447
448 let args = PyTuple::new(value.py(), [&raw mut stream as Py_uintptr_t])?;
452 value.call_method1("_export_to_c", args)?;
453
454 ArrowArrayStreamReader::try_new(stream)
455 .map_err(|err| PyValueError::new_err(err.to_string()))
456 }
457}
458
459impl IntoPyArrow for Box<dyn RecordBatchReader + Send> {
461 fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
464 let stream = FFI_ArrowArrayStream::new(self);
465 record_batch_reader_class(py)?
466 .call_method1("_import_from_c", (&raw const stream as Py_uintptr_t,))
467 }
468}
469
470impl IntoPyArrow for ArrowArrayStreamReader {
472 fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
473 let boxed: Box<dyn RecordBatchReader + Send> = Box::new(self);
474 boxed.into_pyarrow(py)
475 }
476}
477
/// A list of [`RecordBatch`]es together with the schema they share.
///
/// Can be converted from and to Python objects through the [`FromPyArrow`] and
/// [`IntoPyArrow`] implementations in this file.
#[derive(Clone)]
pub struct Table {
    // Batches making up the table; `try_new` checks each one against `schema`.
    record_batches: Vec<RecordBatch>,
    // Schema shared by every batch.
    schema: SchemaRef,
}
500
501impl Table {
502 pub fn try_new(
503 record_batches: Vec<RecordBatch>,
504 schema: SchemaRef,
505 ) -> Result<Self, ArrowError> {
506 for record_batch in &record_batches {
507 if schema != record_batch.schema() {
508 return Err(ArrowError::SchemaError(format!(
509 "All record batches must have the same schema. \
510 Expected schema: {:?}, got schema: {:?}",
511 schema,
512 record_batch.schema()
513 )));
514 }
515 }
516 Ok(Self {
517 record_batches,
518 schema,
519 })
520 }
521
522 pub fn record_batches(&self) -> &[RecordBatch] {
523 &self.record_batches
524 }
525
526 pub fn schema(&self) -> SchemaRef {
527 self.schema.clone()
528 }
529
530 pub fn into_inner(self) -> (Vec<RecordBatch>, SchemaRef) {
531 (self.record_batches, self.schema)
532 }
533}
534
535impl TryFrom<Box<dyn RecordBatchReader>> for Table {
536 type Error = ArrowError;
537
538 fn try_from(value: Box<dyn RecordBatchReader>) -> Result<Self, ArrowError> {
539 let schema = value.schema();
540 let batches = value.collect::<Result<Vec<_>, _>>()?;
541 Self::try_new(batches, schema)
542 }
543}
544
545impl FromPyArrow for Table {
547 fn from_pyarrow_bound(ob: &Bound<PyAny>) -> PyResult<Self> {
548 let reader: Box<dyn RecordBatchReader> =
549 Box::new(ArrowArrayStreamReader::from_pyarrow_bound(ob)?);
550 Self::try_from(reader).map_err(|err| PyValueError::new_err(err.to_string()))
551 }
552}
553
554impl IntoPyArrow for Table {
556 fn into_pyarrow(self, py: Python) -> PyResult<Bound<PyAny>> {
557 let py_batches = PyList::new(py, self.record_batches.into_iter().map(PyArrowType))?;
558 let py_schema = PyArrowType(Arc::unwrap_or_clone(self.schema));
559
560 let kwargs = PyDict::new(py);
561 kwargs.set_item("schema", py_schema)?;
562
563 table_class(py)?.call_method("from_batches", (py_batches,), Some(&kwargs))
564 }
565}
566
567fn array_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
568 static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
569 TYPE.import(py, "pyarrow", "Array")
570}
571
572fn record_batch_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
573 static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
574 TYPE.import(py, "pyarrow", "RecordBatch")
575}
576
577fn record_batch_reader_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
578 static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
579 TYPE.import(py, "pyarrow", "RecordBatchReader")
580}
581fn data_type_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
582 static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
583 TYPE.import(py, "pyarrow", "DataType")
584}
585
586fn field_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
587 static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
588 TYPE.import(py, "pyarrow", "Field")
589}
590
591fn schema_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
592 static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
593 TYPE.import(py, "pyarrow", "Schema")
594}
595
596fn table_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
597 static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
598 TYPE.import(py, "pyarrow", "Table")
599}
600
/// Newtype wrapper around a value `T`.
///
/// Implements PyO3's `FromPyObject` when `T: FromPyArrow` and `IntoPyObject`
/// when `T: IntoPyArrow`, so wrapped values can be used where PyO3 expects
/// those traits.
#[derive(Debug)]
pub struct PyArrowType<T>(pub T);
608
609impl<T: FromPyArrow> FromPyObject<'_, '_> for PyArrowType<T> {
610 type Error = PyErr;
611
612 fn extract(value: Borrowed<'_, '_, PyAny>) -> PyResult<Self> {
613 Ok(Self(T::from_pyarrow_bound(&value)?))
614 }
615}
616
617impl<'py, T: IntoPyArrow> IntoPyObject<'py> for PyArrowType<T> {
618 type Target = PyAny;
619
620 type Output = Bound<'py, Self::Target>;
621
622 type Error = PyErr;
623
624 fn into_pyobject(self, py: Python<'py>) -> PyResult<Self::Output> {
625 self.0.into_pyarrow(py)
626 }
627}
628
629impl<T> From<T> for PyArrowType<T> {
630 fn from(s: T) -> Self {
631 Self(s)
632 }
633}