arrow_schema/
ffi.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Contains declarations to bind to the [C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html).
19//!
20//! ```
21//! # use arrow_schema::{DataType, Field, Schema};
22//! # use arrow_schema::ffi::FFI_ArrowSchema;
23//!
24//! // Create from data type
25//! let ffi_data_type = FFI_ArrowSchema::try_from(&DataType::LargeUtf8).unwrap();
26//! let back = DataType::try_from(&ffi_data_type).unwrap();
27//! assert_eq!(back, DataType::LargeUtf8);
28//!
29//! // Create from schema
30//! let schema = Schema::new(vec![Field::new("foo", DataType::Int64, false)]);
31//! let ffi_schema = FFI_ArrowSchema::try_from(&schema).unwrap();
32//! let back = Schema::try_from(&ffi_schema).unwrap();
33//!
34//! assert_eq!(schema, back);
35//! ```
36
37use crate::{
38    ArrowError, DataType, Field, FieldRef, IntervalUnit, Schema, TimeUnit, UnionFields, UnionMode,
39};
40use bitflags::bitflags;
41use std::borrow::Cow;
42use std::sync::Arc;
43use std::{
44    collections::HashMap,
45    ffi::{c_char, c_void, CStr, CString},
46};
47
48bitflags! {
49    /// Flags for [`FFI_ArrowSchema`]
50    ///
51    /// Old Workaround at <https://github.com/bitflags/bitflags/issues/356>
52    /// is no longer required as `bitflags` [fixed the issue](https://github.com/bitflags/bitflags/pull/355).
53    pub struct Flags: i64 {
54        /// Indicates that the dictionary is ordered
55        const DICTIONARY_ORDERED = 0b00000001;
56        /// Indicates that the field is nullable
57        const NULLABLE = 0b00000010;
58        /// Indicates that the map keys are sorted
59        const MAP_KEYS_SORTED = 0b00000100;
60    }
61}
62
/// ABI-compatible struct for `ArrowSchema` from C Data Interface
/// See <https://arrow.apache.org/docs/format/CDataInterface.html#structure-definitions>
///
/// The field order and types below are part of the C ABI (`#[repr(C)]`) and
/// must not be changed.
///
/// ```
/// # use arrow_schema::DataType;
/// # use arrow_schema::ffi::FFI_ArrowSchema;
/// fn array_schema(data_type: &DataType) -> FFI_ArrowSchema {
///     FFI_ArrowSchema::try_from(data_type).unwrap()
/// }
/// ```
///
#[repr(C)]
#[derive(Debug)]
#[allow(non_camel_case_types)]
pub struct FFI_ArrowSchema {
    /// NUL-terminated format string describing the data type (read by [`Self::format`]).
    format: *const c_char,
    /// NUL-terminated field name, or null when no name is set (read by [`Self::name`]).
    name: *const c_char,
    /// Binary-encoded key/value metadata, or null when there is none (decoded by [`Self::metadata`]).
    metadata: *const c_char,
    /// Refer to [Arrow Flags](https://arrow.apache.org/docs/format/CDataInterface.html#c.ArrowSchema.flags)
    flags: i64,
    /// Number of entries in `children`.
    n_children: i64,
    /// Array of `n_children` pointers to the child schemas.
    children: *mut *mut FFI_ArrowSchema,
    /// Schema of the dictionary values, or null when there is no dictionary.
    dictionary: *mut FFI_ArrowSchema,
    /// Callback that frees the resources of this schema; `None` marks an
    /// empty or already released struct (see the `Drop` impl).
    release: Option<unsafe extern "C" fn(arg1: *mut FFI_ArrowSchema)>,
    /// Opaque producer-owned data; for schemas built by [`Self::try_new`] this
    /// points to a boxed `SchemaPrivateData`.
    private_data: *mut c_void,
}
89
/// Rust-side state of an exported [`FFI_ArrowSchema`], stored behind its
/// `private_data` pointer and reclaimed by `release_schema`.
struct SchemaPrivateData {
    /// Child schemas, each leaked with `Box::into_raw`; the exported
    /// `children` pointer refers to this slice.
    children: Box<[*mut FFI_ArrowSchema]>,
    /// Dictionary schema leaked with `Box::into_raw`, or null when absent.
    dictionary: *mut FFI_ArrowSchema,
    /// Serialized metadata bytes that the exported `metadata` pointer refers
    /// to, if any were set.
    metadata: Option<Vec<u8>>,
}
95
96// callback used to drop [FFI_ArrowSchema] when it is exported.
97unsafe extern "C" fn release_schema(schema: *mut FFI_ArrowSchema) {
98    if schema.is_null() {
99        return;
100    }
101    let schema = &mut *schema;
102
103    // take ownership back to release it.
104    drop(CString::from_raw(schema.format as *mut c_char));
105    if !schema.name.is_null() {
106        drop(CString::from_raw(schema.name as *mut c_char));
107    }
108    if !schema.private_data.is_null() {
109        let private_data = Box::from_raw(schema.private_data as *mut SchemaPrivateData);
110        for child in private_data.children.iter() {
111            drop(Box::from_raw(*child))
112        }
113        if !private_data.dictionary.is_null() {
114            drop(Box::from_raw(private_data.dictionary));
115        }
116
117        drop(private_data);
118    }
119
120    schema.release = None;
121}
122
123impl FFI_ArrowSchema {
124    /// create a new [`FFI_ArrowSchema`]. This fails if the fields'
125    /// [`DataType`] is not supported.
126    pub fn try_new(
127        format: &str,
128        children: Vec<FFI_ArrowSchema>,
129        dictionary: Option<FFI_ArrowSchema>,
130    ) -> Result<Self, ArrowError> {
131        let mut this = Self::empty();
132
133        let children_ptr = children
134            .into_iter()
135            .map(Box::new)
136            .map(Box::into_raw)
137            .collect::<Box<_>>();
138
139        this.format = CString::new(format).unwrap().into_raw();
140        this.release = Some(release_schema);
141        this.n_children = children_ptr.len() as i64;
142
143        let dictionary_ptr = dictionary
144            .map(|d| Box::into_raw(Box::new(d)))
145            .unwrap_or(std::ptr::null_mut());
146
147        let mut private_data = Box::new(SchemaPrivateData {
148            children: children_ptr,
149            dictionary: dictionary_ptr,
150            metadata: None,
151        });
152
153        // intentionally set from private_data (see https://github.com/apache/arrow-rs/issues/580)
154        this.children = private_data.children.as_mut_ptr();
155
156        this.dictionary = dictionary_ptr;
157
158        this.private_data = Box::into_raw(private_data) as *mut c_void;
159
160        Ok(this)
161    }
162
163    /// Set the name of the schema
164    pub fn with_name(mut self, name: &str) -> Result<Self, ArrowError> {
165        self.name = CString::new(name).unwrap().into_raw();
166        Ok(self)
167    }
168
169    /// Set the flags of the schema
170    pub fn with_flags(mut self, flags: Flags) -> Result<Self, ArrowError> {
171        self.flags = flags.bits();
172        Ok(self)
173    }
174
175    /// Add metadata to the schema
176    pub fn with_metadata<I, S>(mut self, metadata: I) -> Result<Self, ArrowError>
177    where
178        I: IntoIterator<Item = (S, S)>,
179        S: AsRef<str>,
180    {
181        let metadata: Vec<(S, S)> = metadata.into_iter().collect();
182        // https://arrow.apache.org/docs/format/CDataInterface.html#c.ArrowSchema.metadata
183        let new_metadata = if !metadata.is_empty() {
184            let mut metadata_serialized: Vec<u8> = Vec::new();
185            let num_entries: i32 = metadata.len().try_into().map_err(|_| {
186                ArrowError::CDataInterface(format!(
187                    "metadata can only have {} entries, but {} were provided",
188                    i32::MAX,
189                    metadata.len()
190                ))
191            })?;
192            metadata_serialized.extend(num_entries.to_ne_bytes());
193
194            for (key, value) in metadata.into_iter() {
195                let key_len: i32 = key.as_ref().len().try_into().map_err(|_| {
196                    ArrowError::CDataInterface(format!(
197                        "metadata key can only have {} bytes, but {} were provided",
198                        i32::MAX,
199                        key.as_ref().len()
200                    ))
201                })?;
202                let value_len: i32 = value.as_ref().len().try_into().map_err(|_| {
203                    ArrowError::CDataInterface(format!(
204                        "metadata value can only have {} bytes, but {} were provided",
205                        i32::MAX,
206                        value.as_ref().len()
207                    ))
208                })?;
209
210                metadata_serialized.extend(key_len.to_ne_bytes());
211                metadata_serialized.extend_from_slice(key.as_ref().as_bytes());
212                metadata_serialized.extend(value_len.to_ne_bytes());
213                metadata_serialized.extend_from_slice(value.as_ref().as_bytes());
214            }
215
216            self.metadata = metadata_serialized.as_ptr() as *const c_char;
217            Some(metadata_serialized)
218        } else {
219            self.metadata = std::ptr::null_mut();
220            None
221        };
222
223        unsafe {
224            let mut private_data = Box::from_raw(self.private_data as *mut SchemaPrivateData);
225            private_data.metadata = new_metadata;
226            self.private_data = Box::into_raw(private_data) as *mut c_void;
227        }
228
229        Ok(self)
230    }
231
232    /// Takes ownership of the pointed to [`FFI_ArrowSchema`]
233    ///
234    /// This acts to [move] the data out of `schema`, setting the release callback to NULL
235    ///
236    /// # Safety
237    ///
238    /// * `schema` must be [valid] for reads and writes
239    /// * `schema` must be properly aligned
240    /// * `schema` must point to a properly initialized value of [`FFI_ArrowSchema`]
241    ///
242    /// [move]: https://arrow.apache.org/docs/format/CDataInterface.html#moving-an-array
243    /// [valid]: https://doc.rust-lang.org/std/ptr/index.html#safety
244    pub unsafe fn from_raw(schema: *mut FFI_ArrowSchema) -> Self {
245        std::ptr::replace(schema, Self::empty())
246    }
247
248    /// Create an empty [`FFI_ArrowSchema`]
249    pub fn empty() -> Self {
250        Self {
251            format: std::ptr::null_mut(),
252            name: std::ptr::null_mut(),
253            metadata: std::ptr::null_mut(),
254            flags: 0,
255            n_children: 0,
256            children: std::ptr::null_mut(),
257            dictionary: std::ptr::null_mut(),
258            release: None,
259            private_data: std::ptr::null_mut(),
260        }
261    }
262
263    /// Returns the format of this schema.
264    pub fn format(&self) -> &str {
265        assert!(!self.format.is_null());
266        // safe because the lifetime of `self.format` equals `self`
267        unsafe { CStr::from_ptr(self.format) }
268            .to_str()
269            .expect("The external API has a non-utf8 as format")
270    }
271
272    /// Returns the name of this schema.
273    pub fn name(&self) -> Option<&str> {
274        if self.name.is_null() {
275            None
276        } else {
277            // safe because the lifetime of `self.name` equals `self`
278            Some(
279                unsafe { CStr::from_ptr(self.name) }
280                    .to_str()
281                    .expect("The external API has a non-utf8 as name"),
282            )
283        }
284    }
285
286    /// Returns the flags of this schema.
287    pub fn flags(&self) -> Option<Flags> {
288        Flags::from_bits(self.flags)
289    }
290
291    /// Returns the child of this schema at `index`.
292    ///
293    /// # Panics
294    ///
295    /// Panics if `index` is greater than or equal to the number of children.
296    ///
297    /// This is to make sure that the unsafe acces to raw pointer is sound.
298    pub fn child(&self, index: usize) -> &Self {
299        assert!(index < self.n_children as usize);
300        unsafe { self.children.add(index).as_ref().unwrap().as_ref().unwrap() }
301    }
302
303    /// Returns an iterator to the schema's children.
304    pub fn children(&self) -> impl Iterator<Item = &Self> {
305        (0..self.n_children as usize).map(move |i| self.child(i))
306    }
307
308    /// Returns if the field is semantically nullable,
309    /// regardless of whether it actually has null values.
310    pub fn nullable(&self) -> bool {
311        (self.flags / 2) & 1 == 1
312    }
313
314    /// Returns the reference to the underlying dictionary of the schema.
315    /// Check [ArrowSchema.dictionary](https://arrow.apache.org/docs/format/CDataInterface.html#c.ArrowSchema.dictionary).
316    ///
317    /// This must be `Some` if the schema represents a dictionary-encoded type, `None` otherwise.
318    pub fn dictionary(&self) -> Option<&Self> {
319        unsafe { self.dictionary.as_ref() }
320    }
321
322    /// For map types, returns whether the keys within each map value are sorted.
323    ///
324    /// Refer to [Arrow Flags](https://arrow.apache.org/docs/format/CDataInterface.html#c.ArrowSchema.flags)
325    pub fn map_keys_sorted(&self) -> bool {
326        self.flags & 0b00000100 != 0
327    }
328
329    /// For dictionary-encoded types, returns whether the ordering of dictionary indices is semantically meaningful.
330    pub fn dictionary_ordered(&self) -> bool {
331        self.flags & 0b00000001 != 0
332    }
333
    /// Returns the metadata in the schema as `Key-Value` pairs
    ///
    /// A null metadata pointer yields an empty map. Otherwise the buffer is
    /// decoded as: a native-endian `i32` entry count, followed for each entry
    /// by an `i32` key length, the key bytes, an `i32` value length and the
    /// value bytes. If a key occurs more than once, the last value wins.
    ///
    /// # Errors
    ///
    /// Returns an error if the entry count or any key/value length is
    /// negative, or if a key or value is not valid UTF-8.
    ///
    /// NOTE(review): the buffer carries no total length, so the reads below
    /// are not bounds-checked; they rely on the producer having written a
    /// well-formed buffer.
    pub fn metadata(&self) -> Result<HashMap<String, String>, ArrowError> {
        if self.metadata.is_null() {
            Ok(HashMap::new())
        } else {
            // Read offset (in bytes) from the start of the metadata buffer.
            let mut pos = 0;

            // On some platforms, c_char = u8, and on some, c_char = i8. Where c_char = u8, clippy
            // wants to complain that we're casting to the same type, but if we remove the cast,
            // this will fail to compile on the other platforms. So we must allow it.
            #[allow(clippy::unnecessary_cast)]
            let buffer: *const u8 = self.metadata as *const u8;

            // Reads the four bytes at `pos` and advances `pos` past them.
            fn next_four_bytes(buffer: *const u8, pos: &mut isize) -> [u8; 4] {
                let out = unsafe {
                    [
                        *buffer.offset(*pos),
                        *buffer.offset(*pos + 1),
                        *buffer.offset(*pos + 2),
                        *buffer.offset(*pos + 3),
                    ]
                };
                *pos += 4;
                out
            }

            // Borrows the `n` bytes at `pos` and advances `pos` past them.
            // Panics if `n` is negative, so callers check the sign first.
            // By lifetime elision the returned slice borrows from `pos`, which
            // is why callers copy it with `to_vec()` right away.
            fn next_n_bytes(buffer: *const u8, pos: &mut isize, n: i32) -> &[u8] {
                let out = unsafe {
                    std::slice::from_raw_parts(buffer.offset(*pos), n.try_into().unwrap())
                };
                *pos += isize::try_from(n).unwrap();
                out
            }

            let num_entries = i32::from_ne_bytes(next_four_bytes(buffer, &mut pos));
            if num_entries < 0 {
                return Err(ArrowError::CDataInterface(
                    "Negative number of metadata entries".to_string(),
                ));
            }

            let mut metadata =
                HashMap::with_capacity(num_entries.try_into().expect("Too many metadata entries"));

            for _ in 0..num_entries {
                // Each entry is: key length, key bytes, value length, value bytes.
                let key_length = i32::from_ne_bytes(next_four_bytes(buffer, &mut pos));
                if key_length < 0 {
                    return Err(ArrowError::CDataInterface(
                        "Negative key length in metadata".to_string(),
                    ));
                }
                let key = String::from_utf8(next_n_bytes(buffer, &mut pos, key_length).to_vec())?;
                let value_length = i32::from_ne_bytes(next_four_bytes(buffer, &mut pos));
                if value_length < 0 {
                    return Err(ArrowError::CDataInterface(
                        "Negative value length in metadata".to_string(),
                    ));
                }
                let value =
                    String::from_utf8(next_n_bytes(buffer, &mut pos, value_length).to_vec())?;
                metadata.insert(key, value);
            }

            Ok(metadata)
        }
    }
400}
401
402impl Drop for FFI_ArrowSchema {
403    fn drop(&mut self) {
404        match self.release {
405            None => (),
406            Some(release) => unsafe { release(self) },
407        };
408    }
409}
410
// The struct holds raw pointers, so `Send` is not implemented automatically.
// NOTE(review): this asserts that moving a schema to another thread is sound,
// presumably because the struct is the sole owner of everything it points to;
// confirm for schemas imported from foreign producers.
unsafe impl Send for FFI_ArrowSchema {}
412
impl TryFrom<&FFI_ArrowSchema> for DataType {
    type Error = ArrowError;

    /// Decodes the format string of `c_schema` (and, where needed, its
    /// children and dictionary) into a [`DataType`].
    ///
    /// See [CDataInterface docs](https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings)
    ///
    /// # Errors
    ///
    /// Returns [`ArrowError::CDataInterface`] for unknown format strings, for
    /// parameters that fail to parse, and for union types whose number of type
    /// ids differs from the number of children.
    ///
    /// # Panics
    ///
    /// Panics if a nested format refers to a child that does not exist (see
    /// [`FFI_ArrowSchema::child`]) or if the format string is null or not
    /// valid UTF-8 (see [`FFI_ArrowSchema::format`]).
    fn try_from(c_schema: &FFI_ArrowSchema) -> Result<Self, ArrowError> {
        let mut dtype = match c_schema.format() {
            // Fixed format strings without parameters
            "n" => DataType::Null,
            "b" => DataType::Boolean,
            "c" => DataType::Int8,
            "C" => DataType::UInt8,
            "s" => DataType::Int16,
            "S" => DataType::UInt16,
            "i" => DataType::Int32,
            "I" => DataType::UInt32,
            "l" => DataType::Int64,
            "L" => DataType::UInt64,
            "e" => DataType::Float16,
            "f" => DataType::Float32,
            "g" => DataType::Float64,
            "vz" => DataType::BinaryView,
            "z" => DataType::Binary,
            "Z" => DataType::LargeBinary,
            "vu" => DataType::Utf8View,
            "u" => DataType::Utf8,
            "U" => DataType::LargeUtf8,
            "tdD" => DataType::Date32,
            "tdm" => DataType::Date64,
            "tts" => DataType::Time32(TimeUnit::Second),
            "ttm" => DataType::Time32(TimeUnit::Millisecond),
            "ttu" => DataType::Time64(TimeUnit::Microsecond),
            "ttn" => DataType::Time64(TimeUnit::Nanosecond),
            "tDs" => DataType::Duration(TimeUnit::Second),
            "tDm" => DataType::Duration(TimeUnit::Millisecond),
            "tDu" => DataType::Duration(TimeUnit::Microsecond),
            "tDn" => DataType::Duration(TimeUnit::Nanosecond),
            "tiM" => DataType::Interval(IntervalUnit::YearMonth),
            "tiD" => DataType::Interval(IntervalUnit::DayTime),
            "tin" => DataType::Interval(IntervalUnit::MonthDayNano),
            // Nested types: the element/field types come from the children
            "+l" => {
                let c_child = c_schema.child(0);
                DataType::List(Arc::new(Field::try_from(c_child)?))
            }
            "+L" => {
                let c_child = c_schema.child(0);
                DataType::LargeList(Arc::new(Field::try_from(c_child)?))
            }
            "+s" => {
                let fields = c_schema.children().map(Field::try_from);
                DataType::Struct(fields.collect::<Result<_, ArrowError>>()?)
            }
            "+m" => {
                let c_child = c_schema.child(0);
                let map_keys_sorted = c_schema.map_keys_sorted();
                DataType::Map(Arc::new(Field::try_from(c_child)?), map_keys_sorted)
            }
            "+r" => {
                // Child 0 holds the run ends, child 1 the values
                let c_run_ends = c_schema.child(0);
                let c_values = c_schema.child(1);
                DataType::RunEndEncoded(
                    Arc::new(Field::try_from(c_run_ends)?),
                    Arc::new(Field::try_from(c_values)?),
                )
            }
            // Parametrized types, requiring string parse
            other => {
                // Split once at the first ':' into the type code and its parameters
                match other.splitn(2, ':').collect::<Vec<&str>>().as_slice() {
                    // FixedSizeBinary type in format "w:num_bytes"
                    ["w", num_bytes] => {
                        let parsed_num_bytes = num_bytes.parse::<i32>().map_err(|_| {
                            ArrowError::CDataInterface(
                                "FixedSizeBinary requires an integer parameter representing number of bytes per element".to_string())
                        })?;
                        DataType::FixedSizeBinary(parsed_num_bytes)
                    },
                    // FixedSizeList type in format "+w:num_elems"
                    ["+w", num_elems] => {
                        let c_child = c_schema.child(0);
                        let parsed_num_elems = num_elems.parse::<i32>().map_err(|_| {
                            ArrowError::CDataInterface(
                                "The FixedSizeList type requires an integer parameter representing number of elements per list".to_string())
                        })?;
                        DataType::FixedSizeList(Arc::new(Field::try_from(c_child)?), parsed_num_elems)
                    },
                    // Decimal types in format "d:precision,scale" or "d:precision,scale,bitWidth"
                    ["d", extra] => {
                        match extra.splitn(3, ',').collect::<Vec<&str>>().as_slice() {
                            // Without an explicit bit width the decimal is 128 bits wide
                            [precision, scale] => {
                                let parsed_precision = precision.parse::<u8>().map_err(|_| {
                                    ArrowError::CDataInterface(
                                        "The decimal type requires an integer precision".to_string(),
                                    )
                                })?;
                                let parsed_scale = scale.parse::<i8>().map_err(|_| {
                                    ArrowError::CDataInterface(
                                        "The decimal type requires an integer scale".to_string(),
                                    )
                                })?;
                                DataType::Decimal128(parsed_precision, parsed_scale)
                            },
                            [precision, scale, bits] => {
                                let parsed_precision = precision.parse::<u8>().map_err(|_| {
                                    ArrowError::CDataInterface(
                                        "The decimal type requires an integer precision".to_string(),
                                    )
                                })?;
                                let parsed_scale = scale.parse::<i8>().map_err(|_| {
                                    ArrowError::CDataInterface(
                                        "The decimal type requires an integer scale".to_string(),
                                    )
                                })?;
                                match *bits {
                                    "32" => DataType::Decimal32(parsed_precision, parsed_scale),
                                    "64" => DataType::Decimal64(parsed_precision, parsed_scale),
                                    "128" => DataType::Decimal128(parsed_precision, parsed_scale),
                                    "256" => DataType::Decimal256(parsed_precision, parsed_scale),
                                    _ => return Err(ArrowError::CDataInterface("Only 32/64/128/256 bit wide decimals are supported in the Rust implementation".to_string())),
                                }
                            }
                            _ => {
                                return Err(ArrowError::CDataInterface(format!(
                                    "The decimal pattern \"d:{extra:?}\" is not supported in the Rust implementation"
                                )))
                            }
                        }
                    }
                    // DenseUnion in format "+ud:type_id,type_id,..." with one child per type id
                    ["+ud", extra] => {
                        let type_ids = extra.split(',').map(|t| t.parse::<i8>().map_err(|_| {
                            ArrowError::CDataInterface(
                                "The Union type requires an integer type id".to_string(),
                            )
                        })).collect::<Result<Vec<_>, ArrowError>>()?;
                        let mut fields = Vec::with_capacity(type_ids.len());
                        for idx in 0..c_schema.n_children {
                            let c_child = c_schema.child(idx as usize);
                            let field = Field::try_from(c_child)?;
                            fields.push(field);
                        }

                        if fields.len() != type_ids.len() {
                            return Err(ArrowError::CDataInterface(
                                "The Union type requires same number of fields and type ids".to_string(),
                            ));
                        }

                        DataType::Union(UnionFields::new(type_ids, fields), UnionMode::Dense)
                    }
                    // SparseUnion in format "+us:type_id,type_id,..." with one child per type id
                    ["+us", extra] => {
                        let type_ids = extra.split(',').map(|t| t.parse::<i8>().map_err(|_| {
                            ArrowError::CDataInterface(
                                "The Union type requires an integer type id".to_string(),
                            )
                        })).collect::<Result<Vec<_>, ArrowError>>()?;
                        let mut fields = Vec::with_capacity(type_ids.len());
                        for idx in 0..c_schema.n_children {
                            let c_child = c_schema.child(idx as usize);
                            let field = Field::try_from(c_child)?;
                            fields.push(field);
                        }

                        if fields.len() != type_ids.len() {
                            return Err(ArrowError::CDataInterface(
                                "The Union type requires same number of fields and type ids".to_string(),
                            ));
                        }

                        DataType::Union(UnionFields::new(type_ids, fields), UnionMode::Sparse)
                    }

                    // Timestamps in format "tss:" and "tss:America/New_York" for no timezones and timezones resp.
                    ["tss", ""] => DataType::Timestamp(TimeUnit::Second, None),
                    ["tsm", ""] => DataType::Timestamp(TimeUnit::Millisecond, None),
                    ["tsu", ""] => DataType::Timestamp(TimeUnit::Microsecond, None),
                    ["tsn", ""] => DataType::Timestamp(TimeUnit::Nanosecond, None),
                    ["tss", tz] => {
                        DataType::Timestamp(TimeUnit::Second, Some(Arc::from(*tz)))
                    }
                    ["tsm", tz] => {
                        DataType::Timestamp(TimeUnit::Millisecond, Some(Arc::from(*tz)))
                    }
                    ["tsu", tz] => {
                        DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from(*tz)))
                    }
                    ["tsn", tz] => {
                        DataType::Timestamp(TimeUnit::Nanosecond, Some(Arc::from(*tz)))
                    }
                    _ => {
                        return Err(ArrowError::CDataInterface(format!(
                            "The datatype \"{other:?}\" is still not supported in Rust implementation"
                        )))
                    }
                }
            }
        };

        // When a dictionary is attached, the format decoded above is the key
        // (index) type and the dictionary schema gives the value type.
        if let Some(dict_schema) = c_schema.dictionary() {
            let value_type = Self::try_from(dict_schema)?;
            dtype = DataType::Dictionary(Box::new(dtype), Box::new(value_type));
        }

        Ok(dtype)
    }
}
617
618impl TryFrom<&FFI_ArrowSchema> for Field {
619    type Error = ArrowError;
620
621    fn try_from(c_schema: &FFI_ArrowSchema) -> Result<Self, ArrowError> {
622        let dtype = DataType::try_from(c_schema)?;
623        let mut field = Field::new(c_schema.name().unwrap_or(""), dtype, c_schema.nullable());
624        field.set_metadata(c_schema.metadata()?);
625        Ok(field)
626    }
627}
628
629impl TryFrom<&FFI_ArrowSchema> for Schema {
630    type Error = ArrowError;
631
632    fn try_from(c_schema: &FFI_ArrowSchema) -> Result<Self, ArrowError> {
633        // interpret it as a struct type then extract its fields
634        let dtype = DataType::try_from(c_schema)?;
635        if let DataType::Struct(fields) = dtype {
636            Ok(Schema::new(fields).with_metadata(c_schema.metadata()?))
637        } else {
638            Err(ArrowError::CDataInterface(
639                "Unable to interpret C data struct as a Schema".to_string(),
640            ))
641        }
642    }
643}
644
645impl TryFrom<&DataType> for FFI_ArrowSchema {
646    type Error = ArrowError;
647
648    /// See [CDataInterface docs](https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings)
649    fn try_from(dtype: &DataType) -> Result<Self, ArrowError> {
650        let format = get_format_string(dtype)?;
651        // allocate and hold the children
652        let children = match dtype {
653            DataType::List(child)
654            | DataType::LargeList(child)
655            | DataType::FixedSizeList(child, _)
656            | DataType::Map(child, _) => {
657                vec![FFI_ArrowSchema::try_from(child.as_ref())?]
658            }
659            DataType::Union(fields, _) => fields
660                .iter()
661                .map(|(_, f)| f.as_ref().try_into())
662                .collect::<Result<Vec<_>, ArrowError>>()?,
663            DataType::Struct(fields) => fields
664                .iter()
665                .map(FFI_ArrowSchema::try_from)
666                .collect::<Result<Vec<_>, ArrowError>>()?,
667            DataType::RunEndEncoded(run_ends, values) => vec![
668                FFI_ArrowSchema::try_from(run_ends.as_ref())?,
669                FFI_ArrowSchema::try_from(values.as_ref())?,
670            ],
671            _ => vec![],
672        };
673        let dictionary = if let DataType::Dictionary(_, value_data_type) = dtype {
674            Some(Self::try_from(value_data_type.as_ref())?)
675        } else {
676            None
677        };
678
679        let flags = match dtype {
680            DataType::Map(_, true) => Flags::MAP_KEYS_SORTED,
681            _ => Flags::empty(),
682        };
683
684        FFI_ArrowSchema::try_new(&format, children, dictionary)?.with_flags(flags)
685    }
686}
687
688fn get_format_string(dtype: &DataType) -> Result<Cow<'static, str>, ArrowError> {
689    match dtype {
690        DataType::Null => Ok("n".into()),
691        DataType::Boolean => Ok("b".into()),
692        DataType::Int8 => Ok("c".into()),
693        DataType::UInt8 => Ok("C".into()),
694        DataType::Int16 => Ok("s".into()),
695        DataType::UInt16 => Ok("S".into()),
696        DataType::Int32 => Ok("i".into()),
697        DataType::UInt32 => Ok("I".into()),
698        DataType::Int64 => Ok("l".into()),
699        DataType::UInt64 => Ok("L".into()),
700        DataType::Float16 => Ok("e".into()),
701        DataType::Float32 => Ok("f".into()),
702        DataType::Float64 => Ok("g".into()),
703        DataType::BinaryView => Ok("vz".into()),
704        DataType::Binary => Ok("z".into()),
705        DataType::LargeBinary => Ok("Z".into()),
706        DataType::Utf8View => Ok("vu".into()),
707        DataType::Utf8 => Ok("u".into()),
708        DataType::LargeUtf8 => Ok("U".into()),
709        DataType::FixedSizeBinary(num_bytes) => Ok(Cow::Owned(format!("w:{num_bytes}"))),
710        DataType::FixedSizeList(_, num_elems) => Ok(Cow::Owned(format!("+w:{num_elems}"))),
711        DataType::Decimal32(precision, scale) => {
712            Ok(Cow::Owned(format!("d:{precision},{scale},32")))
713        }
714        DataType::Decimal64(precision, scale) => {
715            Ok(Cow::Owned(format!("d:{precision},{scale},64")))
716        }
717        DataType::Decimal128(precision, scale) => Ok(Cow::Owned(format!("d:{precision},{scale}"))),
718        DataType::Decimal256(precision, scale) => {
719            Ok(Cow::Owned(format!("d:{precision},{scale},256")))
720        }
721        DataType::Date32 => Ok("tdD".into()),
722        DataType::Date64 => Ok("tdm".into()),
723        DataType::Time32(TimeUnit::Second) => Ok("tts".into()),
724        DataType::Time32(TimeUnit::Millisecond) => Ok("ttm".into()),
725        DataType::Time64(TimeUnit::Microsecond) => Ok("ttu".into()),
726        DataType::Time64(TimeUnit::Nanosecond) => Ok("ttn".into()),
727        DataType::Timestamp(TimeUnit::Second, None) => Ok("tss:".into()),
728        DataType::Timestamp(TimeUnit::Millisecond, None) => Ok("tsm:".into()),
729        DataType::Timestamp(TimeUnit::Microsecond, None) => Ok("tsu:".into()),
730        DataType::Timestamp(TimeUnit::Nanosecond, None) => Ok("tsn:".into()),
731        DataType::Timestamp(TimeUnit::Second, Some(tz)) => Ok(Cow::Owned(format!("tss:{tz}"))),
732        DataType::Timestamp(TimeUnit::Millisecond, Some(tz)) => Ok(Cow::Owned(format!("tsm:{tz}"))),
733        DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => Ok(Cow::Owned(format!("tsu:{tz}"))),
734        DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) => Ok(Cow::Owned(format!("tsn:{tz}"))),
735        DataType::Duration(TimeUnit::Second) => Ok("tDs".into()),
736        DataType::Duration(TimeUnit::Millisecond) => Ok("tDm".into()),
737        DataType::Duration(TimeUnit::Microsecond) => Ok("tDu".into()),
738        DataType::Duration(TimeUnit::Nanosecond) => Ok("tDn".into()),
739        DataType::Interval(IntervalUnit::YearMonth) => Ok("tiM".into()),
740        DataType::Interval(IntervalUnit::DayTime) => Ok("tiD".into()),
741        DataType::Interval(IntervalUnit::MonthDayNano) => Ok("tin".into()),
742        DataType::List(_) => Ok("+l".into()),
743        DataType::LargeList(_) => Ok("+L".into()),
744        DataType::Struct(_) => Ok("+s".into()),
745        DataType::Map(_, _) => Ok("+m".into()),
746        DataType::RunEndEncoded(_, _) => Ok("+r".into()),
747        DataType::Dictionary(key_data_type, _) => get_format_string(key_data_type),
748        DataType::Union(fields, mode) => {
749            let formats = fields
750                .iter()
751                .map(|(t, _)| t.to_string())
752                .collect::<Vec<_>>();
753            match mode {
754                UnionMode::Dense => Ok(Cow::Owned(format!("{}:{}", "+ud", formats.join(",")))),
755                UnionMode::Sparse => Ok(Cow::Owned(format!("{}:{}", "+us", formats.join(",")))),
756            }
757        }
758        other => Err(ArrowError::CDataInterface(format!(
759            "The datatype \"{other:?}\" is still not supported in Rust implementation"
760        ))),
761    }
762}
763
764impl TryFrom<&FieldRef> for FFI_ArrowSchema {
765    type Error = ArrowError;
766
767    fn try_from(value: &FieldRef) -> Result<Self, Self::Error> {
768        value.as_ref().try_into()
769    }
770}
771
772impl TryFrom<&Field> for FFI_ArrowSchema {
773    type Error = ArrowError;
774
775    fn try_from(field: &Field) -> Result<Self, ArrowError> {
776        let mut flags = if field.is_nullable() {
777            Flags::NULLABLE
778        } else {
779            Flags::empty()
780        };
781
782        if let Some(true) = field.dict_is_ordered() {
783            flags |= Flags::DICTIONARY_ORDERED;
784        }
785
786        FFI_ArrowSchema::try_from(field.data_type())?
787            .with_name(field.name())?
788            .with_flags(flags)?
789            .with_metadata(field.metadata())
790    }
791}
792
793impl TryFrom<&Schema> for FFI_ArrowSchema {
794    type Error = ArrowError;
795
796    fn try_from(schema: &Schema) -> Result<Self, ArrowError> {
797        let dtype = DataType::Struct(schema.fields().clone());
798        let c_schema = FFI_ArrowSchema::try_from(&dtype)?.with_metadata(&schema.metadata)?;
799        Ok(c_schema)
800    }
801}
802
803impl TryFrom<DataType> for FFI_ArrowSchema {
804    type Error = ArrowError;
805
806    fn try_from(dtype: DataType) -> Result<Self, ArrowError> {
807        FFI_ArrowSchema::try_from(&dtype)
808    }
809}
810
811impl TryFrom<Field> for FFI_ArrowSchema {
812    type Error = ArrowError;
813
814    fn try_from(field: Field) -> Result<Self, ArrowError> {
815        FFI_ArrowSchema::try_from(&field)
816    }
817}
818
819impl TryFrom<Schema> for FFI_ArrowSchema {
820    type Error = ArrowError;
821
822    fn try_from(schema: Schema) -> Result<Self, ArrowError> {
823        FFI_ArrowSchema::try_from(&schema)
824    }
825}
826
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Fields;

    /// Exports `dtype` and asserts that importing it back yields an equal [`DataType`].
    fn round_trip_type(dtype: DataType) {
        let exported = FFI_ArrowSchema::try_from(&dtype).unwrap();
        let imported = DataType::try_from(&exported).unwrap();
        assert_eq!(imported, dtype);
    }

    /// Exports `field` and asserts that importing it back yields an equal [`Field`].
    fn round_trip_field(field: Field) {
        let exported = FFI_ArrowSchema::try_from(&field).unwrap();
        let imported = Field::try_from(&exported).unwrap();
        assert_eq!(imported, field);
    }

    /// Exports `schema` and asserts that importing it back yields an equal [`Schema`].
    fn round_trip_schema(schema: Schema) {
        let exported = FFI_ArrowSchema::try_from(&schema).unwrap();
        let imported = Schema::try_from(&exported).unwrap();
        assert_eq!(imported, schema);
    }

    /// Round-trips a selection of primitive, binary/string and nested data types.
    #[test]
    fn test_type() {
        let int64_item = Arc::new(Field::new("a", DataType::Int64, false));
        let int16_item = Arc::new(Field::new("a", DataType::Int16, false));
        let struct_fields = Fields::from(vec![Field::new("a", DataType::Utf8, true)]);
        let run_ends = Arc::new(Field::new("run_ends", DataType::Int32, false));
        let values = Arc::new(Field::new("values", DataType::Binary, true));

        let cases = vec![
            DataType::Int64,
            DataType::UInt64,
            DataType::Float64,
            DataType::Date64,
            DataType::Time64(TimeUnit::Nanosecond),
            DataType::FixedSizeBinary(12),
            DataType::FixedSizeList(int64_item, 5),
            DataType::Utf8,
            DataType::Utf8View,
            DataType::BinaryView,
            DataType::Binary,
            DataType::LargeBinary,
            DataType::List(int16_item),
            DataType::Struct(struct_fields),
            DataType::RunEndEncoded(run_ends, values),
        ];

        for dtype in cases {
            round_trip_type(dtype);
        }
    }

    /// Round-trips a nullable field whose type is a struct.
    #[test]
    fn test_field() {
        let children = Fields::from(vec![Field::new("a", DataType::Utf8, true)]);
        let field = Field::new("test", DataType::Struct(children), true);
        round_trip_field(field);
    }

    /// Covers schema round-tripping, struct-as-schema import and rejection of
    /// a non-struct type.
    #[test]
    fn test_schema() {
        let fields = vec![
            Field::new("name", DataType::Utf8, false),
            Field::new("address", DataType::Utf8, false),
            Field::new("priority", DataType::UInt8, false),
        ];
        let metadata = HashMap::from([("hello".to_string(), "world".to_string())]);
        round_trip_schema(Schema::new(fields).with_metadata(metadata));

        // A struct data type can be imported as a schema.
        let struct_type = DataType::Struct(Fields::from(vec![
            Field::new("a", DataType::Utf8, true),
            Field::new("b", DataType::Int16, false),
        ]));
        let exported_struct = FFI_ArrowSchema::try_from(&struct_type).unwrap();
        let imported = Schema::try_from(&exported_struct).unwrap();
        assert_eq!(imported.fields().len(), 2);

        // Importing a non-struct type as a schema is rejected.
        let exported_float = FFI_ArrowSchema::try_from(&DataType::Float64).unwrap();
        assert!(Schema::try_from(&exported_float).is_err());
    }

    /// A map type declared with sorted keys exports the corresponding flag.
    #[test]
    fn test_map_keys_sorted() {
        let entry_struct = DataType::Struct(
            vec![
                Field::new("keys", DataType::Int32, false),
                Field::new("values", DataType::UInt32, false),
            ]
            .into(),
        );

        // Map data type over the entries struct, with the sorted-keys marker set.
        let entries = Arc::new(Field::new("entries", entry_struct, false));
        let map_data_type = DataType::Map(entries, true);

        let exported = FFI_ArrowSchema::try_from(map_data_type).unwrap();
        assert!(exported.map_keys_sorted());
    }

    /// A field created as an ordered dictionary exports the ordered flag on its child schema.
    #[test]
    fn test_dictionary_ordered() {
        let dict_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));

        #[allow(deprecated)]
        let dict_field = Field::new_dict("dict", dict_type, false, 0, true);

        let exported = FFI_ArrowSchema::try_from(Schema::new(vec![dict_field])).unwrap();
        assert!(exported.child(0).dictionary_ordered());
    }

    /// Metadata set on an FFI schema is read back unchanged, including empty
    /// maps, empty keys/values and non-ASCII values.
    #[test]
    fn test_set_field_metadata() {
        let metadata_cases: Vec<HashMap<String, String>> = vec![
            HashMap::new(),
            HashMap::from([("key".to_string(), "value".to_string())]),
            HashMap::from([
                ("key".to_string(), "".to_string()),
                ("ascii123".to_string(), "你好".to_string()),
                ("".to_string(), "value".to_string()),
            ]),
        ];

        let unnamed = FFI_ArrowSchema::try_new("b", vec![], None).unwrap();
        let mut schema = unnamed.with_name("test").unwrap();

        // Each case replaces the metadata on the same schema before re-importing it.
        for expected in &metadata_cases {
            schema = schema.with_metadata(expected).unwrap();
            let imported = Field::try_from(&schema).unwrap();
            assert_eq!(imported.metadata(), expected);
        }
    }

    /// A schema exported from a bare data type has no name; importing it as a
    /// field yields an empty name.
    #[test]
    fn test_import_field_with_null_name() {
        let exported = FFI_ArrowSchema::try_from(&DataType::Int16).unwrap();
        assert!(exported.name().is_none());

        let imported = Field::try_from(&exported).unwrap();
        assert_eq!(imported.name(), "");
    }
}