arrow_schema/
ffi.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Contains declarations to bind to the [C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html).
19//!
20//! ```
21//! # use arrow_schema::{DataType, Field, Schema};
22//! # use arrow_schema::ffi::FFI_ArrowSchema;
23//!
24//! // Create from data type
25//! let ffi_data_type = FFI_ArrowSchema::try_from(&DataType::LargeUtf8).unwrap();
26//! let back = DataType::try_from(&ffi_data_type).unwrap();
27//! assert_eq!(back, DataType::LargeUtf8);
28//!
29//! // Create from schema
30//! let schema = Schema::new(vec![Field::new("foo", DataType::Int64, false)]);
31//! let ffi_schema = FFI_ArrowSchema::try_from(&schema).unwrap();
32//! let back = Schema::try_from(&ffi_schema).unwrap();
33//!
34//! assert_eq!(schema, back);
35//! ```
36
37use crate::{
38    ArrowError, DataType, Field, FieldRef, IntervalUnit, Schema, TimeUnit, UnionFields, UnionMode,
39};
40use bitflags::bitflags;
41use std::borrow::Cow;
42use std::sync::Arc;
43use std::{
44    collections::HashMap,
45    ffi::{CStr, CString, c_char, c_void},
46};
47
bitflags! {
    /// Flags for [`FFI_ArrowSchema`]
    ///
    /// Each constant is a single bit of the `flags` field of [`FFI_ArrowSchema`].
    ///
    /// The old workaround described at <https://github.com/bitflags/bitflags/issues/356>
    /// is no longer required, as `bitflags` [fixed the issue](https://github.com/bitflags/bitflags/pull/355).
    pub struct Flags: i64 {
        /// Indicates that the dictionary is ordered (bit 0)
        const DICTIONARY_ORDERED = 0b00000001;
        /// Indicates that the field is nullable (bit 1)
        const NULLABLE = 0b00000010;
        /// Indicates that the map keys are sorted (bit 2)
        const MAP_KEYS_SORTED = 0b00000100;
    }
}
62
/// ABI-compatible struct for `ArrowSchema` from C Data Interface
/// See <https://arrow.apache.org/docs/format/CDataInterface.html#structure-definitions>
///
/// ```
/// # use arrow_schema::DataType;
/// # use arrow_schema::ffi::FFI_ArrowSchema;
/// fn array_schema(data_type: &DataType) -> FFI_ArrowSchema {
///     FFI_ArrowSchema::try_from(data_type).unwrap()
/// }
/// ```
///
#[repr(C)]
#[derive(Debug)]
#[allow(non_camel_case_types)]
pub struct FFI_ArrowSchema {
    /// NUL-terminated string describing the data type; null only for an empty schema
    format: *const c_char,
    /// NUL-terminated field name, or null when no name is set
    name: *const c_char,
    /// Length-prefixed binary encoding of the key/value metadata, or null when absent
    metadata: *const c_char,
    /// Refer to [Arrow Flags](https://arrow.apache.org/docs/format/CDataInterface.html#c.ArrowSchema.flags)
    flags: i64,
    /// Number of entries reachable through `children`
    n_children: i64,
    /// Pointer to an array of `n_children` child schema pointers
    children: *mut *mut FFI_ArrowSchema,
    /// Schema of the dictionary values, or null when the type is not dictionary-encoded
    dictionary: *mut FFI_ArrowSchema,
    /// Callback that frees the resources owned by this schema; `None` once released
    release: Option<unsafe extern "C" fn(arg1: *mut FFI_ArrowSchema)>,
    /// Opaque producer-owned data; schemas built by `try_new` store a `SchemaPrivateData` here
    private_data: *mut c_void,
}
89
/// Rust-side state kept alive behind `FFI_ArrowSchema::private_data` for schemas
/// created by `FFI_ArrowSchema::try_new`; it is freed by `release_schema`.
struct SchemaPrivateData {
    // Owned array of child pointers; `FFI_ArrowSchema::children` points into this allocation.
    children: Box<[*mut FFI_ArrowSchema]>,
    // Owned dictionary schema pointer, or null when there is no dictionary.
    dictionary: *mut FFI_ArrowSchema,
    // Backing buffer for `FFI_ArrowSchema::metadata`, when metadata has been attached.
    metadata: Option<Vec<u8>>,
}
95
96// callback used to drop [FFI_ArrowSchema] when it is exported.
97unsafe extern "C" fn release_schema(schema: *mut FFI_ArrowSchema) {
98    if schema.is_null() {
99        return;
100    }
101    let schema = unsafe { &mut *schema };
102
103    // take ownership back to release it.
104    drop(unsafe { CString::from_raw(schema.format as *mut c_char) });
105    if !schema.name.is_null() {
106        drop(unsafe { CString::from_raw(schema.name as *mut c_char) });
107    }
108    if !schema.private_data.is_null() {
109        let private_data = unsafe { Box::from_raw(schema.private_data as *mut SchemaPrivateData) };
110        for child in private_data.children.iter() {
111            drop(unsafe { Box::from_raw(*child) })
112        }
113        if !private_data.dictionary.is_null() {
114            drop(unsafe { Box::from_raw(private_data.dictionary) });
115        }
116
117        drop(private_data);
118    }
119
120    schema.release = None;
121}
122
123impl FFI_ArrowSchema {
124    /// create a new [`FFI_ArrowSchema`]. This fails if the fields'
125    /// [`DataType`] is not supported.
126    pub fn try_new(
127        format: &str,
128        children: Vec<FFI_ArrowSchema>,
129        dictionary: Option<FFI_ArrowSchema>,
130    ) -> Result<Self, ArrowError> {
131        let mut this = Self::empty();
132
133        let children_ptr = children
134            .into_iter()
135            .map(Box::new)
136            .map(Box::into_raw)
137            .collect::<Box<_>>();
138
139        this.format = CString::new(format).unwrap().into_raw();
140        this.release = Some(release_schema);
141        this.n_children = children_ptr.len() as i64;
142
143        let dictionary_ptr = dictionary
144            .map(|d| Box::into_raw(Box::new(d)))
145            .unwrap_or(std::ptr::null_mut());
146
147        let mut private_data = Box::new(SchemaPrivateData {
148            children: children_ptr,
149            dictionary: dictionary_ptr,
150            metadata: None,
151        });
152
153        // intentionally set from private_data (see https://github.com/apache/arrow-rs/issues/580)
154        this.children = private_data.children.as_mut_ptr();
155
156        this.dictionary = dictionary_ptr;
157
158        this.private_data = Box::into_raw(private_data) as *mut c_void;
159
160        Ok(this)
161    }
162
163    /// Set the name of the schema
164    pub fn with_name(mut self, name: &str) -> Result<Self, ArrowError> {
165        self.name = CString::new(name).unwrap().into_raw();
166        Ok(self)
167    }
168
169    /// Set the flags of the schema
170    pub fn with_flags(mut self, flags: Flags) -> Result<Self, ArrowError> {
171        self.flags = flags.bits();
172        Ok(self)
173    }
174
175    /// Add metadata to the schema
176    pub fn with_metadata<I, S>(mut self, metadata: I) -> Result<Self, ArrowError>
177    where
178        I: IntoIterator<Item = (S, S)>,
179        S: AsRef<str>,
180    {
181        let metadata: Vec<(S, S)> = metadata.into_iter().collect();
182        // https://arrow.apache.org/docs/format/CDataInterface.html#c.ArrowSchema.metadata
183        let new_metadata = if !metadata.is_empty() {
184            let mut metadata_serialized: Vec<u8> = Vec::new();
185            let num_entries: i32 = metadata.len().try_into().map_err(|_| {
186                ArrowError::CDataInterface(format!(
187                    "metadata can only have {} entries, but {} were provided",
188                    i32::MAX,
189                    metadata.len()
190                ))
191            })?;
192            metadata_serialized.extend(num_entries.to_ne_bytes());
193
194            for (key, value) in metadata.into_iter() {
195                let key_len: i32 = key.as_ref().len().try_into().map_err(|_| {
196                    ArrowError::CDataInterface(format!(
197                        "metadata key can only have {} bytes, but {} were provided",
198                        i32::MAX,
199                        key.as_ref().len()
200                    ))
201                })?;
202                let value_len: i32 = value.as_ref().len().try_into().map_err(|_| {
203                    ArrowError::CDataInterface(format!(
204                        "metadata value can only have {} bytes, but {} were provided",
205                        i32::MAX,
206                        value.as_ref().len()
207                    ))
208                })?;
209
210                metadata_serialized.extend(key_len.to_ne_bytes());
211                metadata_serialized.extend_from_slice(key.as_ref().as_bytes());
212                metadata_serialized.extend(value_len.to_ne_bytes());
213                metadata_serialized.extend_from_slice(value.as_ref().as_bytes());
214            }
215
216            self.metadata = metadata_serialized.as_ptr() as *const c_char;
217            Some(metadata_serialized)
218        } else {
219            self.metadata = std::ptr::null_mut();
220            None
221        };
222
223        unsafe {
224            let mut private_data = Box::from_raw(self.private_data as *mut SchemaPrivateData);
225            private_data.metadata = new_metadata;
226            self.private_data = Box::into_raw(private_data) as *mut c_void;
227        }
228
229        Ok(self)
230    }
231
232    /// Takes ownership of the pointed to [`FFI_ArrowSchema`]
233    ///
234    /// This acts to [move] the data out of `schema`, setting the release callback to NULL
235    ///
236    /// # Safety
237    ///
238    /// * `schema` must be [valid] for reads and writes
239    /// * `schema` must be properly aligned
240    /// * `schema` must point to a properly initialized value of [`FFI_ArrowSchema`]
241    ///
242    /// [move]: https://arrow.apache.org/docs/format/CDataInterface.html#moving-an-array
243    /// [valid]: https://doc.rust-lang.org/std/ptr/index.html#safety
244    pub unsafe fn from_raw(schema: *mut FFI_ArrowSchema) -> Self {
245        unsafe { std::ptr::replace(schema, Self::empty()) }
246    }
247
248    /// Create an empty [`FFI_ArrowSchema`]
249    pub fn empty() -> Self {
250        Self {
251            format: std::ptr::null_mut(),
252            name: std::ptr::null_mut(),
253            metadata: std::ptr::null_mut(),
254            flags: 0,
255            n_children: 0,
256            children: std::ptr::null_mut(),
257            dictionary: std::ptr::null_mut(),
258            release: None,
259            private_data: std::ptr::null_mut(),
260        }
261    }
262
263    /// Returns the format of this schema.
264    pub fn format(&self) -> &str {
265        assert!(!self.format.is_null());
266        // safe because the lifetime of `self.format` equals `self`
267        unsafe { CStr::from_ptr(self.format) }
268            .to_str()
269            .expect("The external API has a non-utf8 as format")
270    }
271
272    /// Returns the name of this schema.
273    pub fn name(&self) -> Option<&str> {
274        if self.name.is_null() {
275            None
276        } else {
277            // safe because the lifetime of `self.name` equals `self`
278            Some(
279                unsafe { CStr::from_ptr(self.name) }
280                    .to_str()
281                    .expect("The external API has a non-utf8 as name"),
282            )
283        }
284    }
285
    /// Returns the flags of this schema.
    ///
    /// The raw `flags` value is converted with `Flags::from_bits`, so the result
    /// is `None` when that conversion rejects the value (per `bitflags`, when bits
    /// that do not correspond to a defined flag are set).
    pub fn flags(&self) -> Option<Flags> {
        Flags::from_bits(self.flags)
    }
290
291    /// Returns the child of this schema at `index`.
292    ///
293    /// # Panics
294    ///
295    /// Panics if `index` is greater than or equal to the number of children.
296    ///
297    /// This is to make sure that the unsafe acces to raw pointer is sound.
298    pub fn child(&self, index: usize) -> &Self {
299        assert!(index < self.n_children as usize);
300        unsafe { self.children.add(index).as_ref().unwrap().as_ref().unwrap() }
301    }
302
303    /// Returns an iterator to the schema's children.
304    pub fn children(&self) -> impl Iterator<Item = &Self> {
305        (0..self.n_children as usize).map(move |i| self.child(i))
306    }
307
308    /// Returns if the field is semantically nullable,
309    /// regardless of whether it actually has null values.
310    pub fn nullable(&self) -> bool {
311        (self.flags / 2) & 1 == 1
312    }
313
314    /// Returns the reference to the underlying dictionary of the schema.
315    /// Check [ArrowSchema.dictionary](https://arrow.apache.org/docs/format/CDataInterface.html#c.ArrowSchema.dictionary).
316    ///
317    /// This must be `Some` if the schema represents a dictionary-encoded type, `None` otherwise.
318    pub fn dictionary(&self) -> Option<&Self> {
319        unsafe { self.dictionary.as_ref() }
320    }
321
322    /// For map types, returns whether the keys within each map value are sorted.
323    ///
324    /// Refer to [Arrow Flags](https://arrow.apache.org/docs/format/CDataInterface.html#c.ArrowSchema.flags)
325    pub fn map_keys_sorted(&self) -> bool {
326        self.flags & 0b00000100 != 0
327    }
328
329    /// For dictionary-encoded types, returns whether the ordering of dictionary indices is semantically meaningful.
330    pub fn dictionary_ordered(&self) -> bool {
331        self.flags & 0b00000001 != 0
332    }
333
334    /// Returns the metadata in the schema as `Key-Value` pairs
335    pub fn metadata(&self) -> Result<HashMap<String, String>, ArrowError> {
336        if self.metadata.is_null() {
337            Ok(HashMap::new())
338        } else {
339            let mut pos = 0;
340
341            // On some platforms, c_char = u8, and on some, c_char = i8. Where c_char = u8, clippy
342            // wants to complain that we're casting to the same type, but if we remove the cast,
343            // this will fail to compile on the other platforms. So we must allow it.
344            #[allow(clippy::unnecessary_cast)]
345            let buffer: *const u8 = self.metadata as *const u8;
346
347            fn next_four_bytes(buffer: *const u8, pos: &mut isize) -> [u8; 4] {
348                let out = unsafe {
349                    [
350                        *buffer.offset(*pos),
351                        *buffer.offset(*pos + 1),
352                        *buffer.offset(*pos + 2),
353                        *buffer.offset(*pos + 3),
354                    ]
355                };
356                *pos += 4;
357                out
358            }
359
360            fn next_n_bytes(buffer: *const u8, pos: &mut isize, n: i32) -> &[u8] {
361                let out = unsafe {
362                    std::slice::from_raw_parts(buffer.offset(*pos), n.try_into().unwrap())
363                };
364                *pos += isize::try_from(n).unwrap();
365                out
366            }
367
368            let num_entries = i32::from_ne_bytes(next_four_bytes(buffer, &mut pos));
369            if num_entries < 0 {
370                return Err(ArrowError::CDataInterface(
371                    "Negative number of metadata entries".to_string(),
372                ));
373            }
374
375            let mut metadata =
376                HashMap::with_capacity(num_entries.try_into().expect("Too many metadata entries"));
377
378            for _ in 0..num_entries {
379                let key_length = i32::from_ne_bytes(next_four_bytes(buffer, &mut pos));
380                if key_length < 0 {
381                    return Err(ArrowError::CDataInterface(
382                        "Negative key length in metadata".to_string(),
383                    ));
384                }
385                let key = String::from_utf8(next_n_bytes(buffer, &mut pos, key_length).to_vec())?;
386                let value_length = i32::from_ne_bytes(next_four_bytes(buffer, &mut pos));
387                if value_length < 0 {
388                    return Err(ArrowError::CDataInterface(
389                        "Negative value length in metadata".to_string(),
390                    ));
391                }
392                let value =
393                    String::from_utf8(next_n_bytes(buffer, &mut pos, value_length).to_vec())?;
394                metadata.insert(key, value);
395            }
396
397            Ok(metadata)
398        }
399    }
400}
401
402impl Drop for FFI_ArrowSchema {
403    fn drop(&mut self) {
404        match self.release {
405            None => (),
406            Some(release) => unsafe { release(self) },
407        };
408    }
409}
410
// `FFI_ArrowSchema` holds raw pointers, so `Send` is not derived automatically.
// NOTE(review): this impl asserts that moving a schema to another thread is sound —
// confirm that this also holds for schemas whose release callback comes from foreign code.
unsafe impl Send for FFI_ArrowSchema {}
412
413impl TryFrom<&FFI_ArrowSchema> for DataType {
414    type Error = ArrowError;
415
416    /// See [CDataInterface docs](https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings)
417    fn try_from(c_schema: &FFI_ArrowSchema) -> Result<Self, ArrowError> {
418        let mut dtype = match c_schema.format() {
419            "n" => DataType::Null,
420            "b" => DataType::Boolean,
421            "c" => DataType::Int8,
422            "C" => DataType::UInt8,
423            "s" => DataType::Int16,
424            "S" => DataType::UInt16,
425            "i" => DataType::Int32,
426            "I" => DataType::UInt32,
427            "l" => DataType::Int64,
428            "L" => DataType::UInt64,
429            "e" => DataType::Float16,
430            "f" => DataType::Float32,
431            "g" => DataType::Float64,
432            "vz" => DataType::BinaryView,
433            "z" => DataType::Binary,
434            "Z" => DataType::LargeBinary,
435            "vu" => DataType::Utf8View,
436            "u" => DataType::Utf8,
437            "U" => DataType::LargeUtf8,
438            "tdD" => DataType::Date32,
439            "tdm" => DataType::Date64,
440            "tts" => DataType::Time32(TimeUnit::Second),
441            "ttm" => DataType::Time32(TimeUnit::Millisecond),
442            "ttu" => DataType::Time64(TimeUnit::Microsecond),
443            "ttn" => DataType::Time64(TimeUnit::Nanosecond),
444            "tDs" => DataType::Duration(TimeUnit::Second),
445            "tDm" => DataType::Duration(TimeUnit::Millisecond),
446            "tDu" => DataType::Duration(TimeUnit::Microsecond),
447            "tDn" => DataType::Duration(TimeUnit::Nanosecond),
448            "tiM" => DataType::Interval(IntervalUnit::YearMonth),
449            "tiD" => DataType::Interval(IntervalUnit::DayTime),
450            "tin" => DataType::Interval(IntervalUnit::MonthDayNano),
451            "+l" => {
452                let c_child = c_schema.child(0);
453                DataType::List(Arc::new(Field::try_from(c_child)?))
454            }
455            "+L" => {
456                let c_child = c_schema.child(0);
457                DataType::LargeList(Arc::new(Field::try_from(c_child)?))
458            }
459            "+s" => {
460                let fields = c_schema.children().map(Field::try_from);
461                DataType::Struct(fields.collect::<Result<_, ArrowError>>()?)
462            }
463            "+m" => {
464                let c_child = c_schema.child(0);
465                let map_keys_sorted = c_schema.map_keys_sorted();
466                DataType::Map(Arc::new(Field::try_from(c_child)?), map_keys_sorted)
467            }
468            "+r" => {
469                let c_run_ends = c_schema.child(0);
470                let c_values = c_schema.child(1);
471                DataType::RunEndEncoded(
472                    Arc::new(Field::try_from(c_run_ends)?),
473                    Arc::new(Field::try_from(c_values)?),
474                )
475            }
476            // Parametrized types, requiring string parse
477            other => {
478                match other.splitn(2, ':').collect::<Vec<&str>>().as_slice() {
479                    // FixedSizeBinary type in format "w:num_bytes"
480                    ["w", num_bytes] => {
481                        let parsed_num_bytes = num_bytes.parse::<i32>().map_err(|_| {
482                            ArrowError::CDataInterface(
483                                "FixedSizeBinary requires an integer parameter representing number of bytes per element".to_string())
484                        })?;
485                        DataType::FixedSizeBinary(parsed_num_bytes)
486                    }
487                    // FixedSizeList type in format "+w:num_elems"
488                    ["+w", num_elems] => {
489                        let c_child = c_schema.child(0);
490                        let parsed_num_elems = num_elems.parse::<i32>().map_err(|_| {
491                            ArrowError::CDataInterface(
492                                "The FixedSizeList type requires an integer parameter representing number of elements per list".to_string())
493                        })?;
494                        DataType::FixedSizeList(
495                            Arc::new(Field::try_from(c_child)?),
496                            parsed_num_elems,
497                        )
498                    }
499                    // Decimal types in format "d:precision,scale" or "d:precision,scale,bitWidth"
500                    ["d", extra] => match extra.splitn(3, ',').collect::<Vec<&str>>().as_slice() {
501                        [precision, scale] => {
502                            let parsed_precision = precision.parse::<u8>().map_err(|_| {
503                                ArrowError::CDataInterface(
504                                    "The decimal type requires an integer precision".to_string(),
505                                )
506                            })?;
507                            let parsed_scale = scale.parse::<i8>().map_err(|_| {
508                                ArrowError::CDataInterface(
509                                    "The decimal type requires an integer scale".to_string(),
510                                )
511                            })?;
512                            DataType::Decimal128(parsed_precision, parsed_scale)
513                        }
514                        [precision, scale, bits] => {
515                            let parsed_precision = precision.parse::<u8>().map_err(|_| {
516                                ArrowError::CDataInterface(
517                                    "The decimal type requires an integer precision".to_string(),
518                                )
519                            })?;
520                            let parsed_scale = scale.parse::<i8>().map_err(|_| {
521                                ArrowError::CDataInterface(
522                                    "The decimal type requires an integer scale".to_string(),
523                                )
524                            })?;
525                            match *bits {
526                                    "32" => DataType::Decimal32(parsed_precision, parsed_scale),
527                                    "64" => DataType::Decimal64(parsed_precision, parsed_scale),
528                                    "128" => DataType::Decimal128(parsed_precision, parsed_scale),
529                                    "256" => DataType::Decimal256(parsed_precision, parsed_scale),
530                                    _ => return Err(ArrowError::CDataInterface("Only 32/64/128/256 bit wide decimals are supported in the Rust implementation".to_string())),
531                                }
532                        }
533                        _ => {
534                            return Err(ArrowError::CDataInterface(format!(
535                                "The decimal pattern \"d:{extra:?}\" is not supported in the Rust implementation"
536                            )));
537                        }
538                    },
539                    // DenseUnion
540                    ["+ud", extra] => {
541                        let type_ids = extra
542                            .split(',')
543                            .map(|t| {
544                                t.parse::<i8>().map_err(|_| {
545                                    ArrowError::CDataInterface(
546                                        "The Union type requires an integer type id".to_string(),
547                                    )
548                                })
549                            })
550                            .collect::<Result<Vec<_>, ArrowError>>()?;
551                        let mut fields = Vec::with_capacity(type_ids.len());
552                        for idx in 0..c_schema.n_children {
553                            let c_child = c_schema.child(idx as usize);
554                            let field = Field::try_from(c_child)?;
555                            fields.push(field);
556                        }
557
558                        if fields.len() != type_ids.len() {
559                            return Err(ArrowError::CDataInterface(
560                                "The Union type requires same number of fields and type ids"
561                                    .to_string(),
562                            ));
563                        }
564
565                        DataType::Union(UnionFields::new(type_ids, fields), UnionMode::Dense)
566                    }
567                    // SparseUnion
568                    ["+us", extra] => {
569                        let type_ids = extra
570                            .split(',')
571                            .map(|t| {
572                                t.parse::<i8>().map_err(|_| {
573                                    ArrowError::CDataInterface(
574                                        "The Union type requires an integer type id".to_string(),
575                                    )
576                                })
577                            })
578                            .collect::<Result<Vec<_>, ArrowError>>()?;
579                        let mut fields = Vec::with_capacity(type_ids.len());
580                        for idx in 0..c_schema.n_children {
581                            let c_child = c_schema.child(idx as usize);
582                            let field = Field::try_from(c_child)?;
583                            fields.push(field);
584                        }
585
586                        if fields.len() != type_ids.len() {
587                            return Err(ArrowError::CDataInterface(
588                                "The Union type requires same number of fields and type ids"
589                                    .to_string(),
590                            ));
591                        }
592
593                        DataType::Union(UnionFields::new(type_ids, fields), UnionMode::Sparse)
594                    }
595
596                    // Timestamps in format "tts:" and "tts:America/New_York" for no timezones and timezones resp.
597                    ["tss", ""] => DataType::Timestamp(TimeUnit::Second, None),
598                    ["tsm", ""] => DataType::Timestamp(TimeUnit::Millisecond, None),
599                    ["tsu", ""] => DataType::Timestamp(TimeUnit::Microsecond, None),
600                    ["tsn", ""] => DataType::Timestamp(TimeUnit::Nanosecond, None),
601                    ["tss", tz] => DataType::Timestamp(TimeUnit::Second, Some(Arc::from(*tz))),
602                    ["tsm", tz] => DataType::Timestamp(TimeUnit::Millisecond, Some(Arc::from(*tz))),
603                    ["tsu", tz] => DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from(*tz))),
604                    ["tsn", tz] => DataType::Timestamp(TimeUnit::Nanosecond, Some(Arc::from(*tz))),
605                    _ => {
606                        return Err(ArrowError::CDataInterface(format!(
607                            "The datatype \"{other:?}\" is still not supported in Rust implementation"
608                        )));
609                    }
610                }
611            }
612        };
613
614        if let Some(dict_schema) = c_schema.dictionary() {
615            let value_type = Self::try_from(dict_schema)?;
616            dtype = DataType::Dictionary(Box::new(dtype), Box::new(value_type));
617        }
618
619        Ok(dtype)
620    }
621}
622
623impl TryFrom<&FFI_ArrowSchema> for Field {
624    type Error = ArrowError;
625
626    fn try_from(c_schema: &FFI_ArrowSchema) -> Result<Self, ArrowError> {
627        let dtype = DataType::try_from(c_schema)?;
628        let mut field = Field::new(c_schema.name().unwrap_or(""), dtype, c_schema.nullable());
629        field.set_metadata(c_schema.metadata()?);
630        Ok(field)
631    }
632}
633
634impl TryFrom<&FFI_ArrowSchema> for Schema {
635    type Error = ArrowError;
636
637    fn try_from(c_schema: &FFI_ArrowSchema) -> Result<Self, ArrowError> {
638        // interpret it as a struct type then extract its fields
639        let dtype = DataType::try_from(c_schema)?;
640        if let DataType::Struct(fields) = dtype {
641            Ok(Schema::new(fields).with_metadata(c_schema.metadata()?))
642        } else {
643            Err(ArrowError::CDataInterface(
644                "Unable to interpret C data struct as a Schema".to_string(),
645            ))
646        }
647    }
648}
649
650impl TryFrom<&DataType> for FFI_ArrowSchema {
651    type Error = ArrowError;
652
653    /// See [CDataInterface docs](https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings)
654    fn try_from(dtype: &DataType) -> Result<Self, ArrowError> {
655        let format = get_format_string(dtype)?;
656        // allocate and hold the children
657        let children = match dtype {
658            DataType::List(child)
659            | DataType::LargeList(child)
660            | DataType::FixedSizeList(child, _)
661            | DataType::Map(child, _) => {
662                vec![FFI_ArrowSchema::try_from(child.as_ref())?]
663            }
664            DataType::Union(fields, _) => fields
665                .iter()
666                .map(|(_, f)| f.as_ref().try_into())
667                .collect::<Result<Vec<_>, ArrowError>>()?,
668            DataType::Struct(fields) => fields
669                .iter()
670                .map(FFI_ArrowSchema::try_from)
671                .collect::<Result<Vec<_>, ArrowError>>()?,
672            DataType::RunEndEncoded(run_ends, values) => vec![
673                FFI_ArrowSchema::try_from(run_ends.as_ref())?,
674                FFI_ArrowSchema::try_from(values.as_ref())?,
675            ],
676            _ => vec![],
677        };
678        let dictionary = if let DataType::Dictionary(_, value_data_type) = dtype {
679            Some(Self::try_from(value_data_type.as_ref())?)
680        } else {
681            None
682        };
683
684        let flags = match dtype {
685            DataType::Map(_, true) => Flags::MAP_KEYS_SORTED,
686            _ => Flags::empty(),
687        };
688
689        FFI_ArrowSchema::try_new(&format, children, dictionary)?.with_flags(flags)
690    }
691}
692
693fn get_format_string(dtype: &DataType) -> Result<Cow<'static, str>, ArrowError> {
694    match dtype {
695        DataType::Null => Ok("n".into()),
696        DataType::Boolean => Ok("b".into()),
697        DataType::Int8 => Ok("c".into()),
698        DataType::UInt8 => Ok("C".into()),
699        DataType::Int16 => Ok("s".into()),
700        DataType::UInt16 => Ok("S".into()),
701        DataType::Int32 => Ok("i".into()),
702        DataType::UInt32 => Ok("I".into()),
703        DataType::Int64 => Ok("l".into()),
704        DataType::UInt64 => Ok("L".into()),
705        DataType::Float16 => Ok("e".into()),
706        DataType::Float32 => Ok("f".into()),
707        DataType::Float64 => Ok("g".into()),
708        DataType::BinaryView => Ok("vz".into()),
709        DataType::Binary => Ok("z".into()),
710        DataType::LargeBinary => Ok("Z".into()),
711        DataType::Utf8View => Ok("vu".into()),
712        DataType::Utf8 => Ok("u".into()),
713        DataType::LargeUtf8 => Ok("U".into()),
714        DataType::FixedSizeBinary(num_bytes) => Ok(Cow::Owned(format!("w:{num_bytes}"))),
715        DataType::FixedSizeList(_, num_elems) => Ok(Cow::Owned(format!("+w:{num_elems}"))),
716        DataType::Decimal32(precision, scale) => {
717            Ok(Cow::Owned(format!("d:{precision},{scale},32")))
718        }
719        DataType::Decimal64(precision, scale) => {
720            Ok(Cow::Owned(format!("d:{precision},{scale},64")))
721        }
722        DataType::Decimal128(precision, scale) => Ok(Cow::Owned(format!("d:{precision},{scale}"))),
723        DataType::Decimal256(precision, scale) => {
724            Ok(Cow::Owned(format!("d:{precision},{scale},256")))
725        }
726        DataType::Date32 => Ok("tdD".into()),
727        DataType::Date64 => Ok("tdm".into()),
728        DataType::Time32(TimeUnit::Second) => Ok("tts".into()),
729        DataType::Time32(TimeUnit::Millisecond) => Ok("ttm".into()),
730        DataType::Time64(TimeUnit::Microsecond) => Ok("ttu".into()),
731        DataType::Time64(TimeUnit::Nanosecond) => Ok("ttn".into()),
732        DataType::Timestamp(TimeUnit::Second, None) => Ok("tss:".into()),
733        DataType::Timestamp(TimeUnit::Millisecond, None) => Ok("tsm:".into()),
734        DataType::Timestamp(TimeUnit::Microsecond, None) => Ok("tsu:".into()),
735        DataType::Timestamp(TimeUnit::Nanosecond, None) => Ok("tsn:".into()),
736        DataType::Timestamp(TimeUnit::Second, Some(tz)) => Ok(Cow::Owned(format!("tss:{tz}"))),
737        DataType::Timestamp(TimeUnit::Millisecond, Some(tz)) => Ok(Cow::Owned(format!("tsm:{tz}"))),
738        DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => Ok(Cow::Owned(format!("tsu:{tz}"))),
739        DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) => Ok(Cow::Owned(format!("tsn:{tz}"))),
740        DataType::Duration(TimeUnit::Second) => Ok("tDs".into()),
741        DataType::Duration(TimeUnit::Millisecond) => Ok("tDm".into()),
742        DataType::Duration(TimeUnit::Microsecond) => Ok("tDu".into()),
743        DataType::Duration(TimeUnit::Nanosecond) => Ok("tDn".into()),
744        DataType::Interval(IntervalUnit::YearMonth) => Ok("tiM".into()),
745        DataType::Interval(IntervalUnit::DayTime) => Ok("tiD".into()),
746        DataType::Interval(IntervalUnit::MonthDayNano) => Ok("tin".into()),
747        DataType::List(_) => Ok("+l".into()),
748        DataType::LargeList(_) => Ok("+L".into()),
749        DataType::Struct(_) => Ok("+s".into()),
750        DataType::Map(_, _) => Ok("+m".into()),
751        DataType::RunEndEncoded(_, _) => Ok("+r".into()),
752        DataType::Dictionary(key_data_type, _) => get_format_string(key_data_type),
753        DataType::Union(fields, mode) => {
754            let formats = fields
755                .iter()
756                .map(|(t, _)| t.to_string())
757                .collect::<Vec<_>>();
758            match mode {
759                UnionMode::Dense => Ok(Cow::Owned(format!("{}:{}", "+ud", formats.join(",")))),
760                UnionMode::Sparse => Ok(Cow::Owned(format!("{}:{}", "+us", formats.join(",")))),
761            }
762        }
763        other => Err(ArrowError::CDataInterface(format!(
764            "The datatype \"{other:?}\" is still not supported in Rust implementation"
765        ))),
766    }
767}
768
769impl TryFrom<&FieldRef> for FFI_ArrowSchema {
770    type Error = ArrowError;
771
772    fn try_from(value: &FieldRef) -> Result<Self, Self::Error> {
773        value.as_ref().try_into()
774    }
775}
776
777impl TryFrom<&Field> for FFI_ArrowSchema {
778    type Error = ArrowError;
779
780    fn try_from(field: &Field) -> Result<Self, ArrowError> {
781        let mut flags = if field.is_nullable() {
782            Flags::NULLABLE
783        } else {
784            Flags::empty()
785        };
786
787        if let Some(true) = field.dict_is_ordered() {
788            flags |= Flags::DICTIONARY_ORDERED;
789        }
790
791        FFI_ArrowSchema::try_from(field.data_type())?
792            .with_name(field.name())?
793            .with_flags(flags)?
794            .with_metadata(field.metadata())
795    }
796}
797
798impl TryFrom<&Schema> for FFI_ArrowSchema {
799    type Error = ArrowError;
800
801    fn try_from(schema: &Schema) -> Result<Self, ArrowError> {
802        let dtype = DataType::Struct(schema.fields().clone());
803        let c_schema = FFI_ArrowSchema::try_from(&dtype)?.with_metadata(&schema.metadata)?;
804        Ok(c_schema)
805    }
806}
807
808impl TryFrom<DataType> for FFI_ArrowSchema {
809    type Error = ArrowError;
810
811    fn try_from(dtype: DataType) -> Result<Self, ArrowError> {
812        FFI_ArrowSchema::try_from(&dtype)
813    }
814}
815
816impl TryFrom<Field> for FFI_ArrowSchema {
817    type Error = ArrowError;
818
819    fn try_from(field: Field) -> Result<Self, ArrowError> {
820        FFI_ArrowSchema::try_from(&field)
821    }
822}
823
824impl TryFrom<Schema> for FFI_ArrowSchema {
825    type Error = ArrowError;
826
827    fn try_from(schema: Schema) -> Result<Self, ArrowError> {
828        FFI_ArrowSchema::try_from(&schema)
829    }
830}
831
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Fields;

    /// Exports `dtype` to a C schema, imports it back and requires equality.
    fn round_trip_type(dtype: DataType) {
        let exported = FFI_ArrowSchema::try_from(&dtype).unwrap();
        let imported = DataType::try_from(&exported).unwrap();
        assert_eq!(imported, dtype);
    }

    /// Exports `field` to a C schema, imports it back and requires equality.
    fn round_trip_field(field: Field) {
        let exported = FFI_ArrowSchema::try_from(&field).unwrap();
        let imported = Field::try_from(&exported).unwrap();
        assert_eq!(imported, field);
    }

    /// Exports `schema` to a C schema, imports it back and requires equality.
    fn round_trip_schema(schema: Schema) {
        let exported = FFI_ArrowSchema::try_from(&schema).unwrap();
        let imported = Schema::try_from(&exported).unwrap();
        assert_eq!(imported, schema);
    }

    #[test]
    fn test_type() {
        // Table of data types that must survive an export/import round trip.
        let cases = [
            DataType::Int64,
            DataType::UInt64,
            DataType::Float64,
            DataType::Date64,
            DataType::Time64(TimeUnit::Nanosecond),
            DataType::FixedSizeBinary(12),
            DataType::FixedSizeList(Arc::new(Field::new("a", DataType::Int64, false)), 5),
            DataType::Utf8,
            DataType::Utf8View,
            DataType::BinaryView,
            DataType::Binary,
            DataType::LargeBinary,
            DataType::List(Arc::new(Field::new("a", DataType::Int16, false))),
            DataType::Struct(Fields::from(vec![Field::new("a", DataType::Utf8, true)])),
            DataType::RunEndEncoded(
                Arc::new(Field::new("run_ends", DataType::Int32, false)),
                Arc::new(Field::new("values", DataType::Binary, true)),
            ),
        ];

        for dtype in cases {
            round_trip_type(dtype);
        }
    }

    #[test]
    fn test_field() {
        let struct_type = DataType::Struct(Fields::from(vec![Field::new(
            "a",
            DataType::Utf8,
            true,
        )]));
        round_trip_field(Field::new("test", struct_type, true));
    }

    #[test]
    fn test_schema() {
        let metadata = HashMap::from([("hello".to_string(), "world".to_string())]);
        let columns = vec![
            Field::new("name", DataType::Utf8, false),
            Field::new("address", DataType::Utf8, false),
            Field::new("priority", DataType::UInt8, false),
        ];
        round_trip_schema(Schema::new(columns).with_metadata(metadata));

        // A struct-typed C schema can be interpreted as a schema.
        let struct_type = DataType::Struct(Fields::from(vec![
            Field::new("a", DataType::Utf8, true),
            Field::new("b", DataType::Int16, false),
        ]));
        let exported = FFI_ArrowSchema::try_from(&struct_type).unwrap();
        let imported = Schema::try_from(&exported).unwrap();
        assert_eq!(imported.fields().len(), 2);

        // A non-struct C schema must be rejected when imported as a schema.
        let exported = FFI_ArrowSchema::try_from(&DataType::Float64).unwrap();
        assert!(Schema::try_from(&exported).is_err());
    }

    #[test]
    fn test_map_keys_sorted() {
        let entries = DataType::Struct(Fields::from(vec![
            Field::new("keys", DataType::Int32, false),
            Field::new("values", DataType::UInt32, false),
        ]));

        // Map type over the entries struct, declared with sorted keys.
        let sorted_map = DataType::Map(Arc::new(Field::new("entries", entries, false)), true);

        let exported = FFI_ArrowSchema::try_from(sorted_map).unwrap();
        assert!(exported.map_keys_sorted());
    }

    #[test]
    fn test_dictionary_ordered() {
        let dict_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
        #[allow(deprecated)]
        let dict_field = Field::new_dict("dict", dict_type, false, 0, true);
        let schema = Schema::new(vec![dict_field]);

        let exported = FFI_ArrowSchema::try_from(schema).unwrap();
        assert!(exported.child(0).dictionary_ordered());
    }

    #[test]
    fn test_set_field_metadata() {
        let no_entries: HashMap<String, String> = HashMap::new();
        let one_entry = HashMap::from([("key".to_string(), "value".to_string())]);
        let edge_entries = HashMap::from([
            ("key".to_string(), "".to_string()),
            ("ascii123".to_string(), "你好".to_string()),
            ("".to_string(), "value".to_string()),
        ]);

        let mut schema = FFI_ArrowSchema::try_new("b", vec![], None)
            .unwrap()
            .with_name("test")
            .unwrap();

        // Each case is set on the same schema, one after the other.
        for metadata in [no_entries, one_entry, edge_entries] {
            schema = schema.with_metadata(&metadata).unwrap();
            let imported = Field::try_from(&schema).unwrap();
            assert_eq!(imported.metadata(), &metadata);
        }
    }

    #[test]
    fn test_import_field_with_null_name() {
        // A schema exported from a bare data type has no name.
        let exported = FFI_ArrowSchema::try_from(&DataType::Int16).unwrap();
        assert!(exported.name().is_none());

        let imported = Field::try_from(&exported).unwrap();
        assert_eq!(imported.name(), "");
    }
}