arrow_array/
ffi_stream.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Contains declarations to bind to the [C Stream Interface](https://arrow.apache.org/docs/format/CStreamInterface.html).
19//!
20//! This module has two main interfaces:
21//! One interface maps C ABI to native Rust types, i.e. convert c-pointers, c_char, to native rust.
22//! This is handled by [FFI_ArrowArrayStream].
23//!
24//! The second interface is used to import `FFI_ArrowArrayStream` as Rust implementation `RecordBatch` reader.
25//! This is handled by `ArrowArrayStreamReader`.
26//!
27//! ```ignore
28//! # use std::fs::File;
29//! # use std::sync::Arc;
30//! # use arrow::error::Result;
//! # use arrow::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream};
//! # use arrow::ipc::reader::FileReader;
//! # use arrow::record_batch::RecordBatchReader;
//! # fn main() -> Result<()> {
//! // create a record batch reader natively
//! let file = File::open("arrow_file").unwrap();
//! let reader = Box::new(FileReader::try_new(file).unwrap());
//!
//! // export it
//! let mut stream = FFI_ArrowArrayStream::new(reader);
42//!
43//! // consumed and used by something else...
44//!
45//! // import it
46//! let stream_reader = unsafe { ArrowArrayStreamReader::from_raw(&mut stream).unwrap() };
47//! let imported_schema = stream_reader.schema();
48//!
49//! let mut produced_batches = vec![];
50//! for batch in stream_reader {
51//!      produced_batches.push(batch.unwrap());
52//! }
53//! Ok(())
54//! }
55//! ```
56
57use arrow_schema::DataType;
58use std::ffi::CStr;
59use std::ptr::addr_of;
60use std::{
61    ffi::CString,
62    os::raw::{c_char, c_int, c_void},
63    sync::Arc,
64};
65
66use arrow_data::ffi::FFI_ArrowArray;
67use arrow_schema::{ArrowError, Schema, SchemaRef, ffi::FFI_ArrowSchema};
68
69use crate::array::Array;
70use crate::array::StructArray;
71use crate::ffi::from_ffi_and_data_type;
72use crate::record_batch::{RecordBatch, RecordBatchReader};
73
/// Module-local shorthand for a `std::result::Result` whose error type is [`ArrowError`].
type Result<T> = std::result::Result<T, ArrowError>;
75
// Integer codes handed back to C consumers by the exported stream callbacks;
// `get_error_code` selects one of them based on the `ArrowError` variant.

/// Code returned for `ArrowError::IoError`.
const EIO: i32 = 5;
/// Code returned for `ArrowError::MemoryError`.
const ENOMEM: i32 = 12;
/// Fallback code returned for every `ArrowError` variant without a dedicated code.
const EINVAL: i32 = 22;
/// Code returned for `ArrowError::NotYetImplemented`.
// NOTE(review): 78 is the ENOSYS value on macOS/BSD; Linux defines ENOSYS as 38 —
// confirm which value consumers are expected to see.
const ENOSYS: i32 = 78;
80
81/// ABI-compatible struct for `ArrayStream` from C Stream Interface
82/// See <https://arrow.apache.org/docs/format/CStreamInterface.html#structure-definitions>
83/// This was created by bindgen
84#[repr(C)]
85#[derive(Debug)]
86#[allow(non_camel_case_types)]
87pub struct FFI_ArrowArrayStream {
88    /// C function to get schema from the stream
89    pub get_schema:
90        Option<unsafe extern "C" fn(arg1: *mut Self, out: *mut FFI_ArrowSchema) -> c_int>,
91    /// C function to get next array from the stream
92    pub get_next: Option<unsafe extern "C" fn(arg1: *mut Self, out: *mut FFI_ArrowArray) -> c_int>,
93    /// C function to get the error from last operation on the stream
94    pub get_last_error: Option<unsafe extern "C" fn(arg1: *mut Self) -> *const c_char>,
95    /// C function to release the stream
96    pub release: Option<unsafe extern "C" fn(arg1: *mut Self)>,
97    /// Private data used by the stream
98    pub private_data: *mut c_void,
99}
100
101unsafe impl Send for FFI_ArrowArrayStream {}
102
103// callback used to drop [FFI_ArrowArrayStream] when it is exported.
104unsafe extern "C" fn release_stream(stream: *mut FFI_ArrowArrayStream) {
105    if stream.is_null() {
106        return;
107    }
108    let stream = unsafe { &mut *stream };
109
110    stream.get_schema = None;
111    stream.get_next = None;
112    stream.get_last_error = None;
113
114    let private_data = unsafe { Box::from_raw(stream.private_data as *mut StreamPrivateData) };
115    drop(private_data);
116
117    stream.release = None;
118}
119
/// Producer-side state of an exported stream, stored behind
/// `FFI_ArrowArrayStream::private_data` and freed by `release_stream`.
struct StreamPrivateData {
    /// The reader whose schema and batches are served through the stream callbacks.
    batch_reader: Box<dyn RecordBatchReader + Send>,
    /// Message of the most recent failed `get_schema`/`get_next` call, if any.
    /// It is kept here so the pointer returned by `get_last_error` stays backed by
    /// storage owned by the stream.
    last_error: Option<CString>,
}
124
125// The callback used to get array schema
126unsafe extern "C" fn get_schema(
127    stream: *mut FFI_ArrowArrayStream,
128    schema: *mut FFI_ArrowSchema,
129) -> c_int {
130    ExportedArrayStream { stream }.get_schema(schema)
131}
132
133// The callback used to get next array
134unsafe extern "C" fn get_next(
135    stream: *mut FFI_ArrowArrayStream,
136    array: *mut FFI_ArrowArray,
137) -> c_int {
138    ExportedArrayStream { stream }.get_next(array)
139}
140
141// The callback used to get the error from last operation on the `FFI_ArrowArrayStream`
142unsafe extern "C" fn get_last_error(stream: *mut FFI_ArrowArrayStream) -> *const c_char {
143    let mut ffi_stream = ExportedArrayStream { stream };
144    // The consumer should not take ownership of this string, we should return
145    // a const pointer to it.
146    match ffi_stream.get_last_error() {
147        Some(err_string) => err_string.as_ptr(),
148        None => std::ptr::null(),
149    }
150}
151
152impl Drop for FFI_ArrowArrayStream {
153    fn drop(&mut self) {
154        match self.release {
155            None => (),
156            Some(release) => unsafe { release(self) },
157        };
158    }
159}
160
161impl FFI_ArrowArrayStream {
162    /// Creates a new [`FFI_ArrowArrayStream`].
163    pub fn new(batch_reader: Box<dyn RecordBatchReader + Send>) -> Self {
164        let private_data = Box::new(StreamPrivateData {
165            batch_reader,
166            last_error: None,
167        });
168
169        Self {
170            get_schema: Some(get_schema),
171            get_next: Some(get_next),
172            get_last_error: Some(get_last_error),
173            release: Some(release_stream),
174            private_data: Box::into_raw(private_data) as *mut c_void,
175        }
176    }
177
178    /// Takes ownership of the pointed to [`FFI_ArrowArrayStream`]
179    ///
180    /// This acts to [move] the data out of `raw_stream`, setting the release callback to NULL
181    ///
182    /// # Safety
183    ///
184    /// * `raw_stream` must be [valid] for reads and writes
185    /// * `raw_stream` must be properly aligned
186    /// * `raw_stream` must point to a properly initialized value of [`FFI_ArrowArrayStream`]
187    ///
188    /// [move]: https://arrow.apache.org/docs/format/CDataInterface.html#moving-an-array
189    /// [valid]: https://doc.rust-lang.org/std/ptr/index.html#safety
190    pub unsafe fn from_raw(raw_stream: *mut FFI_ArrowArrayStream) -> Self {
191        unsafe { std::ptr::replace(raw_stream, Self::empty()) }
192    }
193
194    /// Creates a new empty [FFI_ArrowArrayStream]. Used to import from the C Stream Interface.
195    pub fn empty() -> Self {
196        Self {
197            get_schema: None,
198            get_next: None,
199            get_last_error: None,
200            release: None,
201            private_data: std::ptr::null_mut(),
202        }
203    }
204}
205
/// Thin wrapper around the raw stream pointer received by the exported callbacks;
/// its methods implement those callbacks on top of `StreamPrivateData`.
struct ExportedArrayStream {
    /// Pointer to the stream the callback was invoked on.
    stream: *mut FFI_ArrowArrayStream,
}
209
210impl ExportedArrayStream {
211    fn get_private_data(&mut self) -> &mut StreamPrivateData {
212        unsafe { &mut *((*self.stream).private_data as *mut StreamPrivateData) }
213    }
214
215    pub fn get_schema(&mut self, out: *mut FFI_ArrowSchema) -> i32 {
216        let private_data = self.get_private_data();
217        let reader = &private_data.batch_reader;
218
219        let schema = FFI_ArrowSchema::try_from(reader.schema().as_ref());
220
221        match schema {
222            Ok(schema) => {
223                unsafe { std::ptr::copy(addr_of!(schema), out, 1) };
224                std::mem::forget(schema);
225                0
226            }
227            Err(ref err) => {
228                private_data.last_error = Some(
229                    CString::new(err.to_string()).expect("Error string has a null byte in it."),
230                );
231                get_error_code(err)
232            }
233        }
234    }
235
236    pub fn get_next(&mut self, out: *mut FFI_ArrowArray) -> i32 {
237        let private_data = self.get_private_data();
238        let reader = &mut private_data.batch_reader;
239
240        match reader.next() {
241            None => {
242                // Marks ArrowArray released to indicate reaching the end of stream.
243                unsafe { std::ptr::write(out, FFI_ArrowArray::empty()) }
244                0
245            }
246            Some(next_batch) => {
247                if let Ok(batch) = next_batch {
248                    let struct_array = StructArray::from(batch);
249                    let array = FFI_ArrowArray::new(&struct_array.to_data());
250
251                    unsafe { std::ptr::write_unaligned(out, array) };
252                    0
253                } else {
254                    let err = &next_batch.unwrap_err();
255                    private_data.last_error = Some(
256                        CString::new(err.to_string()).expect("Error string has a null byte in it."),
257                    );
258                    get_error_code(err)
259                }
260            }
261        }
262    }
263
264    pub fn get_last_error(&mut self) -> Option<&CString> {
265        self.get_private_data().last_error.as_ref()
266    }
267}
268
269fn get_error_code(err: &ArrowError) -> i32 {
270    match err {
271        ArrowError::NotYetImplemented(_) => ENOSYS,
272        ArrowError::MemoryError(_) => ENOMEM,
273        ArrowError::IoError(_, _) => EIO,
274        _ => EINVAL,
275    }
276}
277
/// A `RecordBatchReader` which imports Arrays from `FFI_ArrowArrayStream`.
///
/// Struct used to fetch `RecordBatch` from the C Stream Interface.
/// Its main responsibility is to expose `RecordBatchReader` functionality
/// that requires [FFI_ArrowArrayStream].
#[derive(Debug)]
pub struct ArrowArrayStreamReader {
    /// The imported stream; owned by the reader, so it is released when the reader is dropped.
    stream: FFI_ArrowArrayStream,
    /// Schema fetched from the stream once at construction time and reused for every batch.
    schema: SchemaRef,
}
288
289/// Gets schema from a raw pointer of `FFI_ArrowArrayStream`. This is used when constructing
290/// `ArrowArrayStreamReader` to cache schema.
291fn get_stream_schema(stream_ptr: *mut FFI_ArrowArrayStream) -> Result<SchemaRef> {
292    let mut schema = FFI_ArrowSchema::empty();
293
294    let ret_code = unsafe { (*stream_ptr).get_schema.unwrap()(stream_ptr, &mut schema) };
295
296    if ret_code == 0 {
297        let schema = Schema::try_from(&schema)?;
298        Ok(Arc::new(schema))
299    } else {
300        Err(ArrowError::CDataInterface(format!(
301            "Cannot get schema from input stream. Error code: {ret_code:?}"
302        )))
303    }
304}
305
306impl ArrowArrayStreamReader {
307    /// Creates a new `ArrowArrayStreamReader` from a `FFI_ArrowArrayStream`.
308    /// This is used to import from the C Stream Interface.
309    #[allow(dead_code)]
310    pub fn try_new(mut stream: FFI_ArrowArrayStream) -> Result<Self> {
311        if stream.release.is_none() {
312            return Err(ArrowError::CDataInterface(
313                "input stream is already released".to_string(),
314            ));
315        }
316
317        let schema = get_stream_schema(&mut stream)?;
318
319        Ok(Self { stream, schema })
320    }
321
322    /// Creates a new `ArrowArrayStreamReader` from a raw pointer of `FFI_ArrowArrayStream`.
323    ///
324    /// Assumes that the pointer represents valid C Stream Interfaces.
325    /// This function copies the content from the raw pointer and cleans up it to prevent
326    /// double-dropping. The caller is responsible for freeing up the memory allocated for
327    /// the pointer.
328    ///
329    /// # Safety
330    ///
331    /// See [`FFI_ArrowArrayStream::from_raw`]
332    pub unsafe fn from_raw(raw_stream: *mut FFI_ArrowArrayStream) -> Result<Self> {
333        Self::try_new(unsafe { FFI_ArrowArrayStream::from_raw(raw_stream) })
334    }
335
336    /// Get the last error from `ArrowArrayStreamReader`
337    fn get_stream_last_error(&mut self) -> Option<String> {
338        let get_last_error = self.stream.get_last_error?;
339
340        let error_str = unsafe { get_last_error(&mut self.stream) };
341        if error_str.is_null() {
342            return None;
343        }
344
345        let error_str = unsafe { CStr::from_ptr(error_str) };
346        Some(error_str.to_string_lossy().to_string())
347    }
348}
349
350impl Iterator for ArrowArrayStreamReader {
351    type Item = Result<RecordBatch>;
352
353    fn next(&mut self) -> Option<Self::Item> {
354        let mut array = FFI_ArrowArray::empty();
355
356        let ret_code = unsafe { self.stream.get_next.unwrap()(&mut self.stream, &mut array) };
357
358        if ret_code == 0 {
359            // The end of stream has been reached
360            if array.is_released() {
361                return None;
362            }
363
364            let result = unsafe {
365                from_ffi_and_data_type(array, DataType::Struct(self.schema().fields().clone()))
366            };
367            Some(result.and_then(|data| {
368                RecordBatch::try_new(self.schema.clone(), StructArray::from(data).into_parts().1)
369            }))
370        } else {
371            let last_error = self.get_stream_last_error();
372            let err = ArrowError::CDataInterface(last_error.unwrap());
373            Some(Err(err))
374        }
375    }
376}
377
378impl RecordBatchReader for ArrowArrayStreamReader {
379    fn schema(&self) -> SchemaRef {
380        self.schema.clone()
381    }
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    use arrow_schema::Field;

    use crate::array::Int32Array;
    use crate::ffi::from_ffi;

    /// Minimal [`RecordBatchReader`] that replays a boxed iterator under a fixed schema.
    struct TestRecordBatchReader {
        schema: SchemaRef,
        iter: Box<dyn Iterator<Item = Result<RecordBatch>> + Send>,
    }

    impl TestRecordBatchReader {
        pub fn new(
            schema: SchemaRef,
            iter: Box<dyn Iterator<Item = Result<RecordBatch>> + Send>,
        ) -> Box<TestRecordBatchReader> {
            Box::new(Self { schema, iter })
        }
    }

    impl Iterator for TestRecordBatchReader {
        type Item = Result<RecordBatch>;

        fn next(&mut self) -> Option<Self::Item> {
            self.iter.next()
        }
    }

    impl RecordBatchReader for TestRecordBatchReader {
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }
    }

    /// Builds the schema (nullable columns "a", "b", "c", with metadata on every field
    /// and on the schema itself) and one batch holding the first three `arrays`.
    fn three_column_fixture(arrays: Vec<Arc<dyn Array>>) -> (SchemaRef, RecordBatch) {
        let metadata = HashMap::from([("foo".to_owned(), "bar".to_owned())]);
        let fields: Vec<Field> = ["a", "b", "c"]
            .iter()
            .enumerate()
            .map(|(index, name)| {
                Field::new(*name, arrays[index].data_type().clone(), true)
                    .with_metadata(metadata.clone())
            })
            .collect();
        let schema = Arc::new(Schema::new_with_metadata(fields, metadata));
        let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap();
        (schema, batch)
    }

    /// Creates a reader that yields `batch` twice under `schema`.
    fn two_batch_reader(schema: &SchemaRef, batch: &RecordBatch) -> Box<TestRecordBatchReader> {
        let iter = Box::new(vec![batch.clone(), batch.clone()].into_iter().map(Ok)) as _;
        TestRecordBatchReader::new(schema.clone(), iter)
    }

    fn _test_round_trip_export(arrays: Vec<Arc<dyn Array>>) -> Result<()> {
        let (schema, batch) = three_column_fixture(arrays);
        let reader = two_batch_reader(&schema, &batch);

        // Export a `RecordBatchReader` through `FFI_ArrowArrayStream`
        let mut ffi_stream = FFI_ArrowArrayStream::new(reader);

        // Get schema from `FFI_ArrowArrayStream`
        let mut ffi_schema = FFI_ArrowSchema::empty();
        let ret_code = unsafe { get_schema(&mut ffi_stream, &mut ffi_schema) };
        assert_eq!(ret_code, 0);

        let exported_schema: SchemaRef = Arc::new(Schema::try_from(&ffi_schema).unwrap());
        assert_eq!(exported_schema, schema);

        // Get array from `FFI_ArrowArrayStream`
        let mut produced_batches = Vec::new();
        loop {
            let mut ffi_array = FFI_ArrowArray::empty();
            let ret_code = unsafe { get_next(&mut ffi_stream, &mut ffi_array) };
            assert_eq!(ret_code, 0);

            // The end of stream has been reached
            if ffi_array.is_released() {
                break;
            }

            let array = unsafe { from_ffi(ffi_array, &ffi_schema) }.unwrap();
            let columns = StructArray::from(array).into_parts().1;
            let record_batch = RecordBatch::try_new(exported_schema.clone(), columns).unwrap();
            produced_batches.push(record_batch);
        }

        assert_eq!(produced_batches, vec![batch.clone(), batch]);

        Ok(())
    }

    fn _test_round_trip_import(arrays: Vec<Arc<dyn Array>>) -> Result<()> {
        let (schema, batch) = three_column_fixture(arrays);
        let reader = two_batch_reader(&schema, &batch);

        // Import through `FFI_ArrowArrayStream` as `ArrowArrayStreamReader`
        let stream = FFI_ArrowArrayStream::new(reader);
        let stream_reader = ArrowArrayStreamReader::try_new(stream).unwrap();

        let imported_schema = stream_reader.schema();
        assert_eq!(imported_schema, schema);

        let mut produced_batches = Vec::new();
        for produced in stream_reader {
            produced_batches.push(produced.unwrap());
        }

        assert_eq!(produced_batches, vec![batch.clone(), batch]);

        Ok(())
    }

    /// Three references to the same nullable `Int32Array`, one per test column.
    fn sample_columns() -> Vec<Arc<dyn Array>> {
        let array: Arc<dyn Array> = Arc::new(Int32Array::from(vec![Some(2), None, Some(1), None]));
        vec![array.clone(), array.clone(), array]
    }

    #[test]
    fn test_stream_round_trip_export() -> Result<()> {
        _test_round_trip_export(sample_columns())
    }

    #[test]
    fn test_stream_round_trip_import() -> Result<()> {
        _test_round_trip_import(sample_columns())
    }

    #[test]
    fn test_error_import() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));

        // A reader whose only item is an error.
        let iter = Box::new(vec![Err(ArrowError::MemoryError("".to_string()))].into_iter());
        let reader = TestRecordBatchReader::new(schema.clone(), iter);

        // Import through `FFI_ArrowArrayStream` as `ArrowArrayStreamReader`
        let stream = FFI_ArrowArrayStream::new(reader);
        let stream_reader = ArrowArrayStreamReader::try_new(stream).unwrap();

        let imported_schema = stream_reader.schema();
        assert_eq!(imported_schema, schema);

        let mut produced_batches = Vec::new();
        for produced in stream_reader {
            produced_batches.push(produced);
        }

        // The results should outlive the lifetime of the stream itself.
        assert_eq!(produced_batches.len(), 1);
        assert!(produced_batches[0].is_err());

        Ok(())
    }
}