1use std::convert::{From, TryFrom};
63use std::ffi::CStr;
64use std::sync::Arc;
65
66use arrow_array::ffi;
67use arrow_array::ffi::{FFI_ArrowArray, FFI_ArrowSchema};
68use arrow_array::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream};
69use arrow_array::{
70 RecordBatch, RecordBatchIterator, RecordBatchOptions, RecordBatchReader, StructArray,
71 make_array,
72};
73use arrow_data::ArrayData;
74use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
75use pyo3::exceptions::{PyTypeError, PyValueError};
76use pyo3::ffi::Py_uintptr_t;
77use pyo3::import_exception;
78use pyo3::prelude::*;
79use pyo3::sync::PyOnceLock;
80use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
81
// Bind pyarrow's Python-level `ArrowException` class so Rust code can raise it.
import_exception!(pyarrow, ArrowException);
/// Alias for pyarrow's `ArrowException`; this module raises it for errors reported by arrow-rs.
pub type PyArrowException = ArrowException;
85
// Capsule names mandated by the Arrow PyCapsule Interface; passed to `pointer_checked`
// so PyO3 refuses to hand out the pointer of a capsule with a different name.
/// Name of a capsule wrapping an `ArrowArrayStream` (returned by `__arrow_c_stream__`).
const ARROW_ARRAY_STREAM_CAPSULE_NAME: &CStr = c"arrow_array_stream";
/// Name of a capsule wrapping an `ArrowSchema` (returned by `__arrow_c_schema__`).
const ARROW_SCHEMA_CAPSULE_NAME: &CStr = c"arrow_schema";
/// Name of a capsule wrapping an `ArrowArray` (second element of `__arrow_c_array__`).
const ARROW_ARRAY_CAPSULE_NAME: &CStr = c"arrow_array";
89
90fn to_py_err(err: ArrowError) -> PyErr {
91 PyArrowException::new_err(err.to_string())
92}
93
/// Conversion from a Python object into an arrow-rs value.
pub trait FromPyArrow: Sized {
    /// Builds `Self` from a bound Python object.
    ///
    /// Returns a `PyErr` when the object cannot be interpreted as `Self`.
    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self>;
}
101
/// Conversion from a borrowed arrow-rs value into a Python (pyarrow) object.
pub trait ToPyArrow {
    /// Creates a Python object representing `self` without consuming it.
    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>>;
}
107
/// Conversion that consumes an arrow-rs value to produce a Python (pyarrow) object.
///
/// Every [`ToPyArrow`] type gets this trait through a blanket impl; types that can only
/// be converted by value (e.g. record batch readers) implement it directly.
pub trait IntoPyArrow {
    /// Consumes `self` and returns the equivalent Python object.
    fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>>;
}
113
114impl<T: ToPyArrow> IntoPyArrow for T {
115 fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
116 self.to_pyarrow(py)
117 }
118}
119
120fn validate_class(expected: &Bound<PyType>, value: &Bound<PyAny>) -> PyResult<()> {
121 if !value.is_instance(expected)? {
122 let expected_module = expected.getattr("__module__")?;
123 let expected_name = expected.getattr("__name__")?;
124 let found_class = value.get_type();
125 let found_module = found_class.getattr("__module__")?;
126 let found_name = found_class.getattr("__name__")?;
127 return Err(PyTypeError::new_err(format!(
128 "Expected instance of {expected_module}.{expected_name}, got {found_module}.{found_name}",
129 )));
130 }
131 Ok(())
132}
133
134fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
135 let capsule_name = capsule.name()?;
136 if capsule_name.is_none() {
137 return Err(PyValueError::new_err(
138 "Expected schema PyCapsule to have name set.",
139 ));
140 }
141
142 let capsule_name = unsafe { capsule_name.unwrap().as_cstr().to_str()? };
143 if capsule_name != name {
144 return Err(PyValueError::new_err(format!(
145 "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'",
146 )));
147 }
148
149 Ok(())
150}
151
152fn extract_arrow_c_array_capsules<'py>(
153 value: &Bound<'py, PyAny>,
154) -> PyResult<(Bound<'py, PyCapsule>, Bound<'py, PyCapsule>)> {
155 let tuple = value.call_method0("__arrow_c_array__")?;
156
157 if !tuple.is_instance_of::<PyTuple>() {
158 return Err(PyTypeError::new_err(
159 "Expected __arrow_c_array__ to return a tuple of (schema, array) capsules.",
160 ));
161 }
162
163 tuple.extract().map_err(|_| {
164 PyTypeError::new_err(
165 "Expected __arrow_c_array__ to return a tuple of (schema, array) capsules.",
166 )
167 })
168}
169
170impl FromPyArrow for DataType {
171 fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
172 if value.hasattr("__arrow_c_schema__")? {
176 let capsule = value.call_method0("__arrow_c_schema__")?.extract()?;
177 validate_pycapsule(&capsule, "arrow_schema")?;
178
179 let schema_ptr = capsule
180 .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
181 .cast::<FFI_ArrowSchema>();
182 return unsafe { DataType::try_from(schema_ptr.as_ref()) }.map_err(to_py_err);
183 }
184
185 validate_class(data_type_class(value.py())?, value)?;
186
187 let mut c_schema = FFI_ArrowSchema::empty();
188 value.call_method1("_export_to_c", (&raw mut c_schema as Py_uintptr_t,))?;
189 DataType::try_from(&c_schema).map_err(to_py_err)
190 }
191}
192
193impl ToPyArrow for DataType {
194 fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
195 let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?;
196 data_type_class(py)?.call_method1("_import_from_c", (&raw const c_schema as Py_uintptr_t,))
197 }
198}
199
200impl FromPyArrow for Field {
201 fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
202 if value.hasattr("__arrow_c_schema__")? {
206 let capsule = value.call_method0("__arrow_c_schema__")?.extract()?;
207 validate_pycapsule(&capsule, "arrow_schema")?;
208
209 let schema_ptr = capsule
210 .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
211 .cast::<FFI_ArrowSchema>();
212 return unsafe { Field::try_from(schema_ptr.as_ref()) }.map_err(to_py_err);
213 }
214
215 validate_class(field_class(value.py())?, value)?;
216
217 let mut c_schema = FFI_ArrowSchema::empty();
218 value.call_method1("_export_to_c", (&raw mut c_schema as Py_uintptr_t,))?;
219 Field::try_from(&c_schema).map_err(to_py_err)
220 }
221}
222
223impl ToPyArrow for Field {
224 fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
225 let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?;
226 field_class(py)?.call_method1("_import_from_c", (&raw const c_schema as Py_uintptr_t,))
227 }
228}
229
230impl FromPyArrow for Schema {
231 fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
232 if value.hasattr("__arrow_c_schema__")? {
236 let capsule = value.call_method0("__arrow_c_schema__")?.extract()?;
237 validate_pycapsule(&capsule, "arrow_schema")?;
238
239 let schema_ptr = capsule
240 .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
241 .cast::<FFI_ArrowSchema>();
242 return unsafe { Schema::try_from(schema_ptr.as_ref()) }.map_err(to_py_err);
243 }
244
245 validate_class(schema_class(value.py())?, value)?;
246
247 let mut c_schema = FFI_ArrowSchema::empty();
248 value.call_method1("_export_to_c", (&raw mut c_schema as Py_uintptr_t,))?;
249 Schema::try_from(&c_schema).map_err(to_py_err)
250 }
251}
252
253impl ToPyArrow for Schema {
254 fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
255 let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?;
256 schema_class(py)?.call_method1("_import_from_c", (&raw const c_schema as Py_uintptr_t,))
257 }
258}
259
260impl FromPyArrow for ArrayData {
261 fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
262 if value.hasattr("__arrow_c_array__")? {
266 let (schema_capsule, array_capsule) = extract_arrow_c_array_capsules(value)?;
267
268 validate_pycapsule(&schema_capsule, "arrow_schema")?;
269 validate_pycapsule(&array_capsule, "arrow_array")?;
270
271 let schema_ptr = schema_capsule
272 .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
273 .cast::<FFI_ArrowSchema>();
274 let array = unsafe {
275 FFI_ArrowArray::from_raw(
276 array_capsule
277 .pointer_checked(Some(ARROW_ARRAY_CAPSULE_NAME))?
278 .cast::<FFI_ArrowArray>()
279 .as_ptr(),
280 )
281 };
282 return unsafe { ffi::from_ffi(array, schema_ptr.as_ref()) }.map_err(to_py_err);
283 }
284
285 validate_class(array_class(value.py())?, value)?;
286
287 let mut array = FFI_ArrowArray::empty();
289 let mut schema = FFI_ArrowSchema::empty();
290
291 value.call_method1(
295 "_export_to_c",
296 (
297 &raw mut array as Py_uintptr_t,
298 &raw mut schema as Py_uintptr_t,
299 ),
300 )?;
301
302 unsafe { ffi::from_ffi(array, &schema) }.map_err(to_py_err)
303 }
304}
305
306impl ToPyArrow for ArrayData {
307 fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
308 let array = FFI_ArrowArray::new(self);
309 let schema = FFI_ArrowSchema::try_from(self.data_type()).map_err(to_py_err)?;
310 array_class(py)?.call_method1(
311 "_import_from_c",
312 (
313 &raw const array as Py_uintptr_t,
314 &raw const schema as Py_uintptr_t,
315 ),
316 )
317 }
318}
319
320impl<T: FromPyArrow> FromPyArrow for Vec<T> {
321 fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
322 let list = value.cast::<PyList>()?;
323 list.iter().map(|x| T::from_pyarrow_bound(&x)).collect()
324 }
325}
326
327impl<T: ToPyArrow> ToPyArrow for Vec<T> {
328 fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
329 let values = self
330 .iter()
331 .map(|v| v.to_pyarrow(py))
332 .collect::<PyResult<Vec<_>>>()?;
333 Ok(PyList::new(py, values)?.into_any())
334 }
335}
336
337impl FromPyArrow for RecordBatch {
338 fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
339 if value.hasattr("__arrow_c_array__")? {
344 let (schema_capsule, array_capsule) = extract_arrow_c_array_capsules(value)?;
345
346 validate_pycapsule(&schema_capsule, "arrow_schema")?;
347 validate_pycapsule(&array_capsule, "arrow_array")?;
348
349 let schema_ptr = schema_capsule
350 .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
351 .cast::<FFI_ArrowSchema>();
352 let array_ptr = array_capsule
353 .pointer_checked(Some(ARROW_ARRAY_CAPSULE_NAME))?
354 .cast::<FFI_ArrowArray>();
355 let ffi_array = unsafe { FFI_ArrowArray::from_raw(array_ptr.as_ptr()) };
356 let array_data =
357 unsafe { ffi::from_ffi(ffi_array, schema_ptr.as_ref()) }.map_err(to_py_err)?;
358 if !matches!(array_data.data_type(), DataType::Struct(_)) {
359 return Err(PyTypeError::new_err(
360 "Expected Struct type from __arrow_c_array.",
361 ));
362 }
363 let options = RecordBatchOptions::default().with_row_count(Some(array_data.len()));
364 let array = StructArray::from(array_data);
365 let schema =
368 unsafe { Arc::new(Schema::try_from(schema_ptr.as_ref()).map_err(to_py_err)?) };
369 let (_fields, columns, nulls) = array.into_parts();
370 assert_eq!(
371 nulls.map(|n| n.null_count()).unwrap_or_default(),
372 0,
373 "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"
374 );
375 return RecordBatch::try_new_with_options(schema, columns, &options).map_err(to_py_err);
376 }
377
378 validate_class(record_batch_class(value.py())?, value)?;
379 let schema = value.getattr("schema")?;
381 let schema = Arc::new(Schema::from_pyarrow_bound(&schema)?);
382
383 let arrays = value.getattr("columns")?;
384 let arrays = arrays
385 .cast::<PyList>()?
386 .iter()
387 .map(|a| Ok(make_array(ArrayData::from_pyarrow_bound(&a)?)))
388 .collect::<PyResult<_>>()?;
389
390 let row_count = value
391 .getattr("num_rows")
392 .ok()
393 .and_then(|x| x.extract().ok());
394 let options = RecordBatchOptions::default().with_row_count(row_count);
395
396 let batch =
397 RecordBatch::try_new_with_options(schema, arrays, &options).map_err(to_py_err)?;
398 Ok(batch)
399 }
400}
401
402impl ToPyArrow for RecordBatch {
403 fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
404 let reader = RecordBatchIterator::new(vec![Ok(self.clone())], self.schema());
406 let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
407 let py_reader = reader.into_pyarrow(py)?;
408 py_reader.call_method0("read_next_batch")
409 }
410}
411
412impl FromPyArrow for ArrowArrayStreamReader {
414 fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
415 if value.hasattr("__arrow_c_stream__")? {
419 let capsule = value.call_method0("__arrow_c_stream__")?.extract()?;
420
421 validate_pycapsule(&capsule, "arrow_array_stream")?;
422
423 let stream = unsafe {
424 FFI_ArrowArrayStream::from_raw(
425 capsule
426 .pointer_checked(Some(ARROW_ARRAY_STREAM_CAPSULE_NAME))?
427 .cast::<FFI_ArrowArrayStream>()
428 .as_ptr(),
429 )
430 };
431
432 let stream_reader = ArrowArrayStreamReader::try_new(stream)
433 .map_err(|err| PyValueError::new_err(err.to_string()))?;
434
435 return Ok(stream_reader);
436 }
437
438 validate_class(record_batch_reader_class(value.py())?, value)?;
439
440 let mut stream = FFI_ArrowArrayStream::empty();
442
443 let args = PyTuple::new(value.py(), [&raw mut stream as Py_uintptr_t])?;
447 value.call_method1("_export_to_c", args)?;
448
449 ArrowArrayStreamReader::try_new(stream)
450 .map_err(|err| PyValueError::new_err(err.to_string()))
451 }
452}
453
454impl IntoPyArrow for Box<dyn RecordBatchReader + Send> {
456 fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
459 let stream = FFI_ArrowArrayStream::new(self);
460 record_batch_reader_class(py)?
461 .call_method1("_import_from_c", (&raw const stream as Py_uintptr_t,))
462 }
463}
464
465impl IntoPyArrow for ArrowArrayStreamReader {
467 fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
468 let boxed: Box<dyn RecordBatchReader + Send> = Box::new(self);
469 boxed.into_pyarrow(py)
470 }
471}
472
/// A list of [`RecordBatch`]es that all share a single schema.
///
/// Build one with [`Table::try_new`], which rejects batches whose schema differs from the
/// table schema. On the Python side it is imported from any Arrow stream and exported
/// through `pyarrow.Table.from_batches`.
#[derive(Clone)]
pub struct Table {
    // Batches making up the table; `try_new` checks each one against `schema`.
    record_batches: Vec<RecordBatch>,
    // Schema shared by the table and every batch in it.
    schema: SchemaRef,
}
495
496impl Table {
497 pub fn try_new(
498 record_batches: Vec<RecordBatch>,
499 schema: SchemaRef,
500 ) -> Result<Self, ArrowError> {
501 for record_batch in &record_batches {
502 if schema != record_batch.schema() {
503 return Err(ArrowError::SchemaError(format!(
504 "All record batches must have the same schema. \
505 Expected schema: {:?}, got schema: {:?}",
506 schema,
507 record_batch.schema()
508 )));
509 }
510 }
511 Ok(Self {
512 record_batches,
513 schema,
514 })
515 }
516
517 pub fn record_batches(&self) -> &[RecordBatch] {
518 &self.record_batches
519 }
520
521 pub fn schema(&self) -> SchemaRef {
522 self.schema.clone()
523 }
524
525 pub fn into_inner(self) -> (Vec<RecordBatch>, SchemaRef) {
526 (self.record_batches, self.schema)
527 }
528}
529
530impl TryFrom<Box<dyn RecordBatchReader>> for Table {
531 type Error = ArrowError;
532
533 fn try_from(value: Box<dyn RecordBatchReader>) -> Result<Self, ArrowError> {
534 let schema = value.schema();
535 let batches = value.collect::<Result<Vec<_>, _>>()?;
536 Self::try_new(batches, schema)
537 }
538}
539
540impl FromPyArrow for Table {
542 fn from_pyarrow_bound(ob: &Bound<PyAny>) -> PyResult<Self> {
543 let reader: Box<dyn RecordBatchReader> =
544 Box::new(ArrowArrayStreamReader::from_pyarrow_bound(ob)?);
545 Self::try_from(reader).map_err(|err| PyValueError::new_err(err.to_string()))
546 }
547}
548
549impl IntoPyArrow for Table {
551 fn into_pyarrow(self, py: Python) -> PyResult<Bound<PyAny>> {
552 let py_batches = PyList::new(py, self.record_batches.into_iter().map(PyArrowType))?;
553 let py_schema = PyArrowType(Arc::unwrap_or_clone(self.schema));
554
555 let kwargs = PyDict::new(py);
556 kwargs.set_item("schema", py_schema)?;
557
558 table_class(py)?.call_method("from_batches", (py_batches,), Some(&kwargs))
559 }
560}
561
562fn array_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
563 static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
564 TYPE.import(py, "pyarrow", "Array")
565}
566
567fn record_batch_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
568 static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
569 TYPE.import(py, "pyarrow", "RecordBatch")
570}
571
572fn record_batch_reader_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
573 static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
574 TYPE.import(py, "pyarrow", "RecordBatchReader")
575}
576fn data_type_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
577 static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
578 TYPE.import(py, "pyarrow", "DataType")
579}
580
581fn field_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
582 static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
583 TYPE.import(py, "pyarrow", "Field")
584}
585
586fn schema_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
587 static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
588 TYPE.import(py, "pyarrow", "Schema")
589}
590
591fn table_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
592 static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
593 TYPE.import(py, "pyarrow", "Table")
594}
595
596#[derive(Debug)]
602pub struct PyArrowType<T>(pub T);
603
604impl<T: FromPyArrow> FromPyObject<'_, '_> for PyArrowType<T> {
605 type Error = PyErr;
606
607 fn extract(value: Borrowed<'_, '_, PyAny>) -> PyResult<Self> {
608 Ok(Self(T::from_pyarrow_bound(&value)?))
609 }
610}
611
612impl<'py, T: IntoPyArrow> IntoPyObject<'py> for PyArrowType<T> {
613 type Target = PyAny;
614
615 type Output = Bound<'py, Self::Target>;
616
617 type Error = PyErr;
618
619 fn into_pyobject(self, py: Python<'py>) -> PyResult<Self::Output> {
620 self.0.into_pyarrow(py)
621 }
622}
623
624impl<T> From<T> for PyArrowType<T> {
625 fn from(s: T) -> Self {
626 Self(s)
627 }
628}