1use arrow_schema::DataType;
58use std::ffi::CStr;
59use std::ptr::addr_of;
60use std::{
61 ffi::CString,
62 os::raw::{c_char, c_int, c_void},
63 sync::Arc,
64};
65
66use arrow_data::ffi::FFI_ArrowArray;
67use arrow_schema::{ffi::FFI_ArrowSchema, ArrowError, Schema, SchemaRef};
68
69use crate::array::Array;
70use crate::array::StructArray;
71use crate::ffi::from_ffi_and_data_type;
72use crate::record_batch::{RecordBatch, RecordBatchReader};
73
/// Module-local result alias whose error type is always [`ArrowError`].
type Result<T> = std::result::Result<T, ArrowError>;

// errno-style status codes handed back through the C callbacks; `get_error_code`
// maps `ArrowError` variants onto them.
// NOTE(review): the values are hard-coded rather than taken from the platform;
// 78 is ENOSYS on macOS/BSD while Linux uses 38 — confirm this is intended.
const ENOMEM: i32 = 12;
const EIO: i32 = 5;
const EINVAL: i32 = 22;
const ENOSYS: i32 = 78;
80
/// C-layout (`#[repr(C)]`) table of stream callbacks plus an opaque state pointer.
///
/// The field order is part of the binary layout and must not be changed.
/// Every callback is optional; [`FFI_ArrowArrayStream::empty`] produces a value
/// with all of them unset.
#[repr(C)]
#[derive(Debug)]
#[allow(non_camel_case_types)]
pub struct FFI_ArrowArrayStream {
    /// Writes the stream's schema into `out` and returns a status code.
    /// Callers in this file treat `0` as success.
    pub get_schema:
        Option<unsafe extern "C" fn(arg1: *mut Self, out: *mut FFI_ArrowSchema) -> c_int>,
    /// Writes the next array into `out` and returns a status code. In this file
    /// a released array written to `out` marks the end of the stream.
    pub get_next: Option<unsafe extern "C" fn(arg1: *mut Self, out: *mut FFI_ArrowArray) -> c_int>,
    /// Returns a C string describing the most recent error, or a null pointer
    /// when no message is available.
    pub get_last_error: Option<unsafe extern "C" fn(arg1: *mut Self) -> *const c_char>,
    /// Releases the stream's resources; invoked from `Drop` when it is set.
    pub release: Option<unsafe extern "C" fn(arg1: *mut Self)>,
    /// Opaque producer state. For streams built by `FFI_ArrowArrayStream::new`
    /// this points at a heap-allocated `StreamPrivateData`.
    pub private_data: *mut c_void,
}

// NOTE(review): this asserts the stream may be moved across threads. Streams
// built by `new` wrap a reader that is itself `Send`; streams received from a
// foreign producer are assumed to uphold the same guarantee — confirm.
unsafe impl Send for FFI_ArrowArrayStream {}
102
103unsafe extern "C" fn release_stream(stream: *mut FFI_ArrowArrayStream) {
105 if stream.is_null() {
106 return;
107 }
108 let stream = &mut *stream;
109
110 stream.get_schema = None;
111 stream.get_next = None;
112 stream.get_last_error = None;
113
114 let private_data = Box::from_raw(stream.private_data as *mut StreamPrivateData);
115 drop(private_data);
116
117 stream.release = None;
118}
119
/// Producer-side state stored behind `FFI_ArrowArrayStream::private_data` for
/// streams created by `FFI_ArrowArrayStream::new`.
struct StreamPrivateData {
    /// Reader that supplies the schema and the record batches being exported.
    batch_reader: Box<dyn RecordBatchReader + Send>,
    /// Message of the most recent failure, kept alive here so the pointer handed
    /// out by `get_last_error` stays valid; `None` until an error is recorded.
    last_error: Option<CString>,
}
124
125unsafe extern "C" fn get_schema(
127 stream: *mut FFI_ArrowArrayStream,
128 schema: *mut FFI_ArrowSchema,
129) -> c_int {
130 ExportedArrayStream { stream }.get_schema(schema)
131}
132
133unsafe extern "C" fn get_next(
135 stream: *mut FFI_ArrowArrayStream,
136 array: *mut FFI_ArrowArray,
137) -> c_int {
138 ExportedArrayStream { stream }.get_next(array)
139}
140
141unsafe extern "C" fn get_last_error(stream: *mut FFI_ArrowArrayStream) -> *const c_char {
143 let mut ffi_stream = ExportedArrayStream { stream };
144 match ffi_stream.get_last_error() {
147 Some(err_string) => err_string.as_ptr(),
148 None => std::ptr::null(),
149 }
150}
151
152impl Drop for FFI_ArrowArrayStream {
153 fn drop(&mut self) {
154 match self.release {
155 None => (),
156 Some(release) => unsafe { release(self) },
157 };
158 }
159}
160
161impl FFI_ArrowArrayStream {
162 pub fn new(batch_reader: Box<dyn RecordBatchReader + Send>) -> Self {
164 let private_data = Box::new(StreamPrivateData {
165 batch_reader,
166 last_error: None,
167 });
168
169 Self {
170 get_schema: Some(get_schema),
171 get_next: Some(get_next),
172 get_last_error: Some(get_last_error),
173 release: Some(release_stream),
174 private_data: Box::into_raw(private_data) as *mut c_void,
175 }
176 }
177
178 pub unsafe fn from_raw(raw_stream: *mut FFI_ArrowArrayStream) -> Self {
191 std::ptr::replace(raw_stream, Self::empty())
192 }
193
194 pub fn empty() -> Self {
196 Self {
197 get_schema: None,
198 get_next: None,
199 get_last_error: None,
200 release: None,
201 private_data: std::ptr::null_mut(),
202 }
203 }
204}
205
206struct ExportedArrayStream {
207 stream: *mut FFI_ArrowArrayStream,
208}
209
210impl ExportedArrayStream {
211 fn get_private_data(&mut self) -> &mut StreamPrivateData {
212 unsafe { &mut *((*self.stream).private_data as *mut StreamPrivateData) }
213 }
214
215 pub fn get_schema(&mut self, out: *mut FFI_ArrowSchema) -> i32 {
216 let private_data = self.get_private_data();
217 let reader = &private_data.batch_reader;
218
219 let schema = FFI_ArrowSchema::try_from(reader.schema().as_ref());
220
221 match schema {
222 Ok(schema) => {
223 unsafe { std::ptr::copy(addr_of!(schema), out, 1) };
224 std::mem::forget(schema);
225 0
226 }
227 Err(ref err) => {
228 private_data.last_error = Some(
229 CString::new(err.to_string()).expect("Error string has a null byte in it."),
230 );
231 get_error_code(err)
232 }
233 }
234 }
235
236 pub fn get_next(&mut self, out: *mut FFI_ArrowArray) -> i32 {
237 let private_data = self.get_private_data();
238 let reader = &mut private_data.batch_reader;
239
240 match reader.next() {
241 None => {
242 unsafe { std::ptr::write(out, FFI_ArrowArray::empty()) }
244 0
245 }
246 Some(next_batch) => {
247 if let Ok(batch) = next_batch {
248 let struct_array = StructArray::from(batch);
249 let array = FFI_ArrowArray::new(&struct_array.to_data());
250
251 unsafe { std::ptr::write_unaligned(out, array) };
252 0
253 } else {
254 let err = &next_batch.unwrap_err();
255 private_data.last_error = Some(
256 CString::new(err.to_string()).expect("Error string has a null byte in it."),
257 );
258 get_error_code(err)
259 }
260 }
261 }
262 }
263
264 pub fn get_last_error(&mut self) -> Option<&CString> {
265 self.get_private_data().last_error.as_ref()
266 }
267}
268
269fn get_error_code(err: &ArrowError) -> i32 {
270 match err {
271 ArrowError::NotYetImplemented(_) => ENOSYS,
272 ArrowError::MemoryError(_) => ENOMEM,
273 ArrowError::IoError(_, _) => EIO,
274 _ => EINVAL,
275 }
276}
277
/// Consumer-side reader over an [`FFI_ArrowArrayStream`].
///
/// It yields record batches through its `Iterator` implementation and exposes
/// the stream's schema through `RecordBatchReader`.
#[derive(Debug)]
pub struct ArrowArrayStreamReader {
    /// The stream being consumed; it is released when the reader is dropped
    /// (via the stream's `Drop` implementation).
    stream: FFI_ArrowArrayStream,
    /// Schema fetched from the stream once, when the reader is constructed.
    schema: SchemaRef,
}
288
289fn get_stream_schema(stream_ptr: *mut FFI_ArrowArrayStream) -> Result<SchemaRef> {
292 let mut schema = FFI_ArrowSchema::empty();
293
294 let ret_code = unsafe { (*stream_ptr).get_schema.unwrap()(stream_ptr, &mut schema) };
295
296 if ret_code == 0 {
297 let schema = Schema::try_from(&schema)?;
298 Ok(Arc::new(schema))
299 } else {
300 Err(ArrowError::CDataInterface(format!(
301 "Cannot get schema from input stream. Error code: {ret_code:?}"
302 )))
303 }
304}
305
306impl ArrowArrayStreamReader {
307 #[allow(dead_code)]
310 pub fn try_new(mut stream: FFI_ArrowArrayStream) -> Result<Self> {
311 if stream.release.is_none() {
312 return Err(ArrowError::CDataInterface(
313 "input stream is already released".to_string(),
314 ));
315 }
316
317 let schema = get_stream_schema(&mut stream)?;
318
319 Ok(Self { stream, schema })
320 }
321
322 pub unsafe fn from_raw(raw_stream: *mut FFI_ArrowArrayStream) -> Result<Self> {
333 Self::try_new(FFI_ArrowArrayStream::from_raw(raw_stream))
334 }
335
336 fn get_stream_last_error(&mut self) -> Option<String> {
338 let get_last_error = self.stream.get_last_error?;
339
340 let error_str = unsafe { get_last_error(&mut self.stream) };
341 if error_str.is_null() {
342 return None;
343 }
344
345 let error_str = unsafe { CStr::from_ptr(error_str) };
346 Some(error_str.to_string_lossy().to_string())
347 }
348}
349
350impl Iterator for ArrowArrayStreamReader {
351 type Item = Result<RecordBatch>;
352
353 fn next(&mut self) -> Option<Self::Item> {
354 let mut array = FFI_ArrowArray::empty();
355
356 let ret_code = unsafe { self.stream.get_next.unwrap()(&mut self.stream, &mut array) };
357
358 if ret_code == 0 {
359 if array.is_released() {
361 return None;
362 }
363
364 let result = unsafe {
365 from_ffi_and_data_type(array, DataType::Struct(self.schema().fields().clone()))
366 };
367 Some(result.map(|data| RecordBatch::from(StructArray::from(data))))
368 } else {
369 let last_error = self.get_stream_last_error();
370 let err = ArrowError::CDataInterface(last_error.unwrap());
371 Some(Err(err))
372 }
373 }
374}
375
376impl RecordBatchReader for ArrowArrayStreamReader {
377 fn schema(&self) -> SchemaRef {
378 self.schema.clone()
379 }
380}
381
382#[cfg(test)]
383mod tests {
384 use super::*;
385
386 use arrow_schema::Field;
387
388 use crate::array::Int32Array;
389 use crate::ffi::from_ffi;
390
391 struct TestRecordBatchReader {
392 schema: SchemaRef,
393 iter: Box<dyn Iterator<Item = Result<RecordBatch>> + Send>,
394 }
395
396 impl TestRecordBatchReader {
397 pub fn new(
398 schema: SchemaRef,
399 iter: Box<dyn Iterator<Item = Result<RecordBatch>> + Send>,
400 ) -> Box<TestRecordBatchReader> {
401 Box::new(TestRecordBatchReader { schema, iter })
402 }
403 }
404
405 impl Iterator for TestRecordBatchReader {
406 type Item = Result<RecordBatch>;
407
408 fn next(&mut self) -> Option<Self::Item> {
409 self.iter.next()
410 }
411 }
412
413 impl RecordBatchReader for TestRecordBatchReader {
414 fn schema(&self) -> SchemaRef {
415 self.schema.clone()
416 }
417 }
418
419 fn _test_round_trip_export(arrays: Vec<Arc<dyn Array>>) -> Result<()> {
420 let schema = Arc::new(Schema::new(vec![
421 Field::new("a", arrays[0].data_type().clone(), true),
422 Field::new("b", arrays[1].data_type().clone(), true),
423 Field::new("c", arrays[2].data_type().clone(), true),
424 ]));
425 let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap();
426 let iter = Box::new(vec![batch.clone(), batch.clone()].into_iter().map(Ok)) as _;
427
428 let reader = TestRecordBatchReader::new(schema.clone(), iter);
429
430 let mut ffi_stream = FFI_ArrowArrayStream::new(reader);
432
433 let mut ffi_schema = FFI_ArrowSchema::empty();
435 let ret_code = unsafe { get_schema(&mut ffi_stream, &mut ffi_schema) };
436 assert_eq!(ret_code, 0);
437
438 let exported_schema = Schema::try_from(&ffi_schema).unwrap();
439 assert_eq!(&exported_schema, schema.as_ref());
440
441 let mut produced_batches = vec![];
443 loop {
444 let mut ffi_array = FFI_ArrowArray::empty();
445 let ret_code = unsafe { get_next(&mut ffi_stream, &mut ffi_array) };
446 assert_eq!(ret_code, 0);
447
448 if ffi_array.is_released() {
450 break;
451 }
452
453 let array = unsafe { from_ffi(ffi_array, &ffi_schema) }.unwrap();
454
455 let record_batch = RecordBatch::from(StructArray::from(array));
456 produced_batches.push(record_batch);
457 }
458
459 assert_eq!(produced_batches, vec![batch.clone(), batch]);
460
461 Ok(())
462 }
463
464 fn _test_round_trip_import(arrays: Vec<Arc<dyn Array>>) -> Result<()> {
465 let schema = Arc::new(Schema::new(vec![
466 Field::new("a", arrays[0].data_type().clone(), true),
467 Field::new("b", arrays[1].data_type().clone(), true),
468 Field::new("c", arrays[2].data_type().clone(), true),
469 ]));
470 let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap();
471 let iter = Box::new(vec![batch.clone(), batch.clone()].into_iter().map(Ok)) as _;
472
473 let reader = TestRecordBatchReader::new(schema.clone(), iter);
474
475 let stream = FFI_ArrowArrayStream::new(reader);
477 let stream_reader = ArrowArrayStreamReader::try_new(stream).unwrap();
478
479 let imported_schema = stream_reader.schema();
480 assert_eq!(imported_schema, schema);
481
482 let mut produced_batches = vec![];
483 for batch in stream_reader {
484 produced_batches.push(batch.unwrap());
485 }
486
487 assert_eq!(produced_batches, vec![batch.clone(), batch]);
488
489 Ok(())
490 }
491
492 #[test]
493 fn test_stream_round_trip_export() -> Result<()> {
494 let array = Int32Array::from(vec![Some(2), None, Some(1), None]);
495 let array: Arc<dyn Array> = Arc::new(array);
496
497 _test_round_trip_export(vec![array.clone(), array.clone(), array])
498 }
499
500 #[test]
501 fn test_stream_round_trip_import() -> Result<()> {
502 let array = Int32Array::from(vec![Some(2), None, Some(1), None]);
503 let array: Arc<dyn Array> = Arc::new(array);
504
505 _test_round_trip_import(vec![array.clone(), array.clone(), array])
506 }
507
508 #[test]
509 fn test_error_import() -> Result<()> {
510 let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
511
512 let iter = Box::new(vec![Err(ArrowError::MemoryError("".to_string()))].into_iter());
513
514 let reader = TestRecordBatchReader::new(schema.clone(), iter);
515
516 let stream = FFI_ArrowArrayStream::new(reader);
518 let stream_reader = ArrowArrayStreamReader::try_new(stream).unwrap();
519
520 let imported_schema = stream_reader.schema();
521 assert_eq!(imported_schema, schema);
522
523 let mut produced_batches = vec![];
524 for batch in stream_reader {
525 produced_batches.push(batch);
526 }
527
528 assert_eq!(produced_batches.len(), 1);
530 assert!(produced_batches[0].is_err());
531
532 Ok(())
533 }
534}