arrow_schema/
ffi.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Contains declarations to bind to the [C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html).
19//!
20//! ```
21//! # use arrow_schema::{DataType, Field, Schema};
22//! # use arrow_schema::ffi::FFI_ArrowSchema;
23//!
24//! // Create from data type
25//! let ffi_data_type = FFI_ArrowSchema::try_from(&DataType::LargeUtf8).unwrap();
26//! let back = DataType::try_from(&ffi_data_type).unwrap();
27//! assert_eq!(back, DataType::LargeUtf8);
28//!
29//! // Create from schema
30//! let schema = Schema::new(vec![Field::new("foo", DataType::Int64, false)]);
31//! let ffi_schema = FFI_ArrowSchema::try_from(&schema).unwrap();
32//! let back = Schema::try_from(&ffi_schema).unwrap();
33//!
34//! assert_eq!(schema, back);
35//! ```
36
37use crate::{
38    ArrowError, DataType, Field, FieldRef, IntervalUnit, Schema, TimeUnit, UnionFields, UnionMode,
39};
40use bitflags::bitflags;
41use std::borrow::Cow;
42use std::sync::Arc;
43use std::{
44    collections::HashMap,
45    ffi::{c_char, c_void, CStr, CString},
46};
47
bitflags! {
    /// Flags for [`FFI_ArrowSchema`]
    ///
    /// Each constant occupies a distinct bit, so several flags can be combined
    /// into the single `i64` stored in the schema's `flags` field.
    ///
    /// Old Workaround at <https://github.com/bitflags/bitflags/issues/356>
    /// is no longer required as `bitflags` [fixed the issue](https://github.com/bitflags/bitflags/pull/355).
    pub struct Flags: i64 {
        /// Indicates that the dictionary is ordered
        const DICTIONARY_ORDERED = 0b00000001;
        /// Indicates that the field is nullable
        const NULLABLE = 0b00000010;
        /// Indicates that the map keys are sorted
        const MAP_KEYS_SORTED = 0b00000100;
    }
}
62
/// ABI-compatible struct for `ArrowSchema` from C Data Interface
/// See <https://arrow.apache.org/docs/format/CDataInterface.html#structure-definitions>
///
/// ```
/// # use arrow_schema::DataType;
/// # use arrow_schema::ffi::FFI_ArrowSchema;
/// fn array_schema(data_type: &DataType) -> FFI_ArrowSchema {
///     FFI_ArrowSchema::try_from(data_type).unwrap()
/// }
/// ```
#[repr(C)]
#[derive(Debug)]
#[allow(non_camel_case_types)]
pub struct FFI_ArrowSchema {
    /// NUL-terminated C string describing the data type; read by [`FFI_ArrowSchema::format`].
    format: *const c_char,
    /// NUL-terminated C string holding the name, or null when no name is set.
    name: *const c_char,
    /// Binary-encoded key/value metadata, or null when there is none;
    /// decoded by [`FFI_ArrowSchema::metadata`].
    metadata: *const c_char,
    /// Refer to [Arrow Flags](https://arrow.apache.org/docs/format/CDataInterface.html#c.ArrowSchema.flags)
    flags: i64,
    /// Number of entries in `children`.
    n_children: i64,
    /// Array of `n_children` pointers to the child schemas.
    children: *mut *mut FFI_ArrowSchema,
    /// Schema of the dictionary values, or null when there is no dictionary.
    dictionary: *mut FFI_ArrowSchema,
    /// Callback invoked when this struct is dropped; `None` for an empty or
    /// already moved-out schema.
    release: Option<unsafe extern "C" fn(arg1: *mut FFI_ArrowSchema)>,
    /// Opaque producer-owned data. For schemas built by [`FFI_ArrowSchema::try_new`]
    /// this points to a heap-allocated `SchemaPrivateData`.
    private_data: *mut c_void,
}
89
/// Heap-allocated bookkeeping stored behind `FFI_ArrowSchema::private_data`
/// for schemas created by `FFI_ArrowSchema::try_new`.
///
/// It keeps the allocations the exported raw pointers refer to, so that
/// `release_schema` can free them.
struct SchemaPrivateData {
    /// Backing storage of the exported `children` pointer array; each entry
    /// comes from `Box::into_raw`.
    children: Box<[*mut FFI_ArrowSchema]>,
    /// Dictionary schema obtained from `Box::into_raw`, or null when absent.
    dictionary: *mut FFI_ArrowSchema,
    /// Serialized metadata buffer that the exported `metadata` pointer refers to.
    metadata: Option<Vec<u8>>,
}
95
96// callback used to drop [FFI_ArrowSchema] when it is exported.
97unsafe extern "C" fn release_schema(schema: *mut FFI_ArrowSchema) {
98    if schema.is_null() {
99        return;
100    }
101    let schema = &mut *schema;
102
103    // take ownership back to release it.
104    drop(CString::from_raw(schema.format as *mut c_char));
105    if !schema.name.is_null() {
106        drop(CString::from_raw(schema.name as *mut c_char));
107    }
108    if !schema.private_data.is_null() {
109        let private_data = Box::from_raw(schema.private_data as *mut SchemaPrivateData);
110        for child in private_data.children.iter() {
111            drop(Box::from_raw(*child))
112        }
113        if !private_data.dictionary.is_null() {
114            drop(Box::from_raw(private_data.dictionary));
115        }
116
117        drop(private_data);
118    }
119
120    schema.release = None;
121}
122
123impl FFI_ArrowSchema {
124    /// create a new [`FFI_ArrowSchema`]. This fails if the fields'
125    /// [`DataType`] is not supported.
126    pub fn try_new(
127        format: &str,
128        children: Vec<FFI_ArrowSchema>,
129        dictionary: Option<FFI_ArrowSchema>,
130    ) -> Result<Self, ArrowError> {
131        let mut this = Self::empty();
132
133        let children_ptr = children
134            .into_iter()
135            .map(Box::new)
136            .map(Box::into_raw)
137            .collect::<Box<_>>();
138
139        this.format = CString::new(format).unwrap().into_raw();
140        this.release = Some(release_schema);
141        this.n_children = children_ptr.len() as i64;
142
143        let dictionary_ptr = dictionary
144            .map(|d| Box::into_raw(Box::new(d)))
145            .unwrap_or(std::ptr::null_mut());
146
147        let mut private_data = Box::new(SchemaPrivateData {
148            children: children_ptr,
149            dictionary: dictionary_ptr,
150            metadata: None,
151        });
152
153        // intentionally set from private_data (see https://github.com/apache/arrow-rs/issues/580)
154        this.children = private_data.children.as_mut_ptr();
155
156        this.dictionary = dictionary_ptr;
157
158        this.private_data = Box::into_raw(private_data) as *mut c_void;
159
160        Ok(this)
161    }
162
163    /// Set the name of the schema
164    pub fn with_name(mut self, name: &str) -> Result<Self, ArrowError> {
165        self.name = CString::new(name).unwrap().into_raw();
166        Ok(self)
167    }
168
169    /// Set the flags of the schema
170    pub fn with_flags(mut self, flags: Flags) -> Result<Self, ArrowError> {
171        self.flags = flags.bits();
172        Ok(self)
173    }
174
175    /// Add metadata to the schema
176    pub fn with_metadata<I, S>(mut self, metadata: I) -> Result<Self, ArrowError>
177    where
178        I: IntoIterator<Item = (S, S)>,
179        S: AsRef<str>,
180    {
181        let metadata: Vec<(S, S)> = metadata.into_iter().collect();
182        // https://arrow.apache.org/docs/format/CDataInterface.html#c.ArrowSchema.metadata
183        let new_metadata = if !metadata.is_empty() {
184            let mut metadata_serialized: Vec<u8> = Vec::new();
185            let num_entries: i32 = metadata.len().try_into().map_err(|_| {
186                ArrowError::CDataInterface(format!(
187                    "metadata can only have {} entries, but {} were provided",
188                    i32::MAX,
189                    metadata.len()
190                ))
191            })?;
192            metadata_serialized.extend(num_entries.to_ne_bytes());
193
194            for (key, value) in metadata.into_iter() {
195                let key_len: i32 = key.as_ref().len().try_into().map_err(|_| {
196                    ArrowError::CDataInterface(format!(
197                        "metadata key can only have {} bytes, but {} were provided",
198                        i32::MAX,
199                        key.as_ref().len()
200                    ))
201                })?;
202                let value_len: i32 = value.as_ref().len().try_into().map_err(|_| {
203                    ArrowError::CDataInterface(format!(
204                        "metadata value can only have {} bytes, but {} were provided",
205                        i32::MAX,
206                        value.as_ref().len()
207                    ))
208                })?;
209
210                metadata_serialized.extend(key_len.to_ne_bytes());
211                metadata_serialized.extend_from_slice(key.as_ref().as_bytes());
212                metadata_serialized.extend(value_len.to_ne_bytes());
213                metadata_serialized.extend_from_slice(value.as_ref().as_bytes());
214            }
215
216            self.metadata = metadata_serialized.as_ptr() as *const c_char;
217            Some(metadata_serialized)
218        } else {
219            self.metadata = std::ptr::null_mut();
220            None
221        };
222
223        unsafe {
224            let mut private_data = Box::from_raw(self.private_data as *mut SchemaPrivateData);
225            private_data.metadata = new_metadata;
226            self.private_data = Box::into_raw(private_data) as *mut c_void;
227        }
228
229        Ok(self)
230    }
231
232    /// Takes ownership of the pointed to [`FFI_ArrowSchema`]
233    ///
234    /// This acts to [move] the data out of `schema`, setting the release callback to NULL
235    ///
236    /// # Safety
237    ///
238    /// * `schema` must be [valid] for reads and writes
239    /// * `schema` must be properly aligned
240    /// * `schema` must point to a properly initialized value of [`FFI_ArrowSchema`]
241    ///
242    /// [move]: https://arrow.apache.org/docs/format/CDataInterface.html#moving-an-array
243    /// [valid]: https://doc.rust-lang.org/std/ptr/index.html#safety
244    pub unsafe fn from_raw(schema: *mut FFI_ArrowSchema) -> Self {
245        std::ptr::replace(schema, Self::empty())
246    }
247
248    /// Create an empty [`FFI_ArrowSchema`]
249    pub fn empty() -> Self {
250        Self {
251            format: std::ptr::null_mut(),
252            name: std::ptr::null_mut(),
253            metadata: std::ptr::null_mut(),
254            flags: 0,
255            n_children: 0,
256            children: std::ptr::null_mut(),
257            dictionary: std::ptr::null_mut(),
258            release: None,
259            private_data: std::ptr::null_mut(),
260        }
261    }
262
263    /// Returns the format of this schema.
264    pub fn format(&self) -> &str {
265        assert!(!self.format.is_null());
266        // safe because the lifetime of `self.format` equals `self`
267        unsafe { CStr::from_ptr(self.format) }
268            .to_str()
269            .expect("The external API has a non-utf8 as format")
270    }
271
272    /// Returns the name of this schema.
273    pub fn name(&self) -> Option<&str> {
274        if self.name.is_null() {
275            None
276        } else {
277            // safe because the lifetime of `self.name` equals `self`
278            Some(
279                unsafe { CStr::from_ptr(self.name) }
280                    .to_str()
281                    .expect("The external API has a non-utf8 as name"),
282            )
283        }
284    }
285
286    /// Returns the flags of this schema.
287    pub fn flags(&self) -> Option<Flags> {
288        Flags::from_bits(self.flags)
289    }
290
291    /// Returns the child of this schema at `index`.
292    ///
293    /// # Panics
294    ///
295    /// Panics if `index` is greater than or equal to the number of children.
296    ///
297    /// This is to make sure that the unsafe acces to raw pointer is sound.
298    pub fn child(&self, index: usize) -> &Self {
299        assert!(index < self.n_children as usize);
300        unsafe { self.children.add(index).as_ref().unwrap().as_ref().unwrap() }
301    }
302
303    /// Returns an iterator to the schema's children.
304    pub fn children(&self) -> impl Iterator<Item = &Self> {
305        (0..self.n_children as usize).map(move |i| self.child(i))
306    }
307
308    /// Returns if the field is semantically nullable,
309    /// regardless of whether it actually has null values.
310    pub fn nullable(&self) -> bool {
311        (self.flags / 2) & 1 == 1
312    }
313
314    /// Returns the reference to the underlying dictionary of the schema.
315    /// Check [ArrowSchema.dictionary](https://arrow.apache.org/docs/format/CDataInterface.html#c.ArrowSchema.dictionary).
316    ///
317    /// This must be `Some` if the schema represents a dictionary-encoded type, `None` otherwise.
318    pub fn dictionary(&self) -> Option<&Self> {
319        unsafe { self.dictionary.as_ref() }
320    }
321
322    /// For map types, returns whether the keys within each map value are sorted.
323    ///
324    /// Refer to [Arrow Flags](https://arrow.apache.org/docs/format/CDataInterface.html#c.ArrowSchema.flags)
325    pub fn map_keys_sorted(&self) -> bool {
326        self.flags & 0b00000100 != 0
327    }
328
329    /// For dictionary-encoded types, returns whether the ordering of dictionary indices is semantically meaningful.
330    pub fn dictionary_ordered(&self) -> bool {
331        self.flags & 0b00000001 != 0
332    }
333
334    /// Returns the metadata in the schema as `Key-Value` pairs
335    pub fn metadata(&self) -> Result<HashMap<String, String>, ArrowError> {
336        if self.metadata.is_null() {
337            Ok(HashMap::new())
338        } else {
339            let mut pos = 0;
340
341            // On some platforms, c_char = u8, and on some, c_char = i8. Where c_char = u8, clippy
342            // wants to complain that we're casting to the same type, but if we remove the cast,
343            // this will fail to compile on the other platforms. So we must allow it.
344            #[allow(clippy::unnecessary_cast)]
345            let buffer: *const u8 = self.metadata as *const u8;
346
347            fn next_four_bytes(buffer: *const u8, pos: &mut isize) -> [u8; 4] {
348                let out = unsafe {
349                    [
350                        *buffer.offset(*pos),
351                        *buffer.offset(*pos + 1),
352                        *buffer.offset(*pos + 2),
353                        *buffer.offset(*pos + 3),
354                    ]
355                };
356                *pos += 4;
357                out
358            }
359
360            fn next_n_bytes(buffer: *const u8, pos: &mut isize, n: i32) -> &[u8] {
361                let out = unsafe {
362                    std::slice::from_raw_parts(buffer.offset(*pos), n.try_into().unwrap())
363                };
364                *pos += isize::try_from(n).unwrap();
365                out
366            }
367
368            let num_entries = i32::from_ne_bytes(next_four_bytes(buffer, &mut pos));
369            if num_entries < 0 {
370                return Err(ArrowError::CDataInterface(
371                    "Negative number of metadata entries".to_string(),
372                ));
373            }
374
375            let mut metadata =
376                HashMap::with_capacity(num_entries.try_into().expect("Too many metadata entries"));
377
378            for _ in 0..num_entries {
379                let key_length = i32::from_ne_bytes(next_four_bytes(buffer, &mut pos));
380                if key_length < 0 {
381                    return Err(ArrowError::CDataInterface(
382                        "Negative key length in metadata".to_string(),
383                    ));
384                }
385                let key = String::from_utf8(next_n_bytes(buffer, &mut pos, key_length).to_vec())?;
386                let value_length = i32::from_ne_bytes(next_four_bytes(buffer, &mut pos));
387                if value_length < 0 {
388                    return Err(ArrowError::CDataInterface(
389                        "Negative value length in metadata".to_string(),
390                    ));
391                }
392                let value =
393                    String::from_utf8(next_n_bytes(buffer, &mut pos, value_length).to_vec())?;
394                metadata.insert(key, value);
395            }
396
397            Ok(metadata)
398        }
399    }
400}
401
402impl Drop for FFI_ArrowSchema {
403    fn drop(&mut self) {
404        match self.release {
405            None => (),
406            Some(release) => unsafe { release(self) },
407        };
408    }
409}
410
// Raw pointer fields make `FFI_ArrowSchema` `!Send` by default; this impl opts back in.
// NOTE(review): presumably sound because the schema exclusively owns what its pointers
// refer to — nothing in this file enforces that for foreign-produced schemas; confirm.
unsafe impl Send for FFI_ArrowSchema {}
412
413impl TryFrom<&FFI_ArrowSchema> for DataType {
414    type Error = ArrowError;
415
416    /// See [CDataInterface docs](https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings)
417    fn try_from(c_schema: &FFI_ArrowSchema) -> Result<Self, ArrowError> {
418        let mut dtype = match c_schema.format() {
419            "n" => DataType::Null,
420            "b" => DataType::Boolean,
421            "c" => DataType::Int8,
422            "C" => DataType::UInt8,
423            "s" => DataType::Int16,
424            "S" => DataType::UInt16,
425            "i" => DataType::Int32,
426            "I" => DataType::UInt32,
427            "l" => DataType::Int64,
428            "L" => DataType::UInt64,
429            "e" => DataType::Float16,
430            "f" => DataType::Float32,
431            "g" => DataType::Float64,
432            "vz" => DataType::BinaryView,
433            "z" => DataType::Binary,
434            "Z" => DataType::LargeBinary,
435            "vu" => DataType::Utf8View,
436            "u" => DataType::Utf8,
437            "U" => DataType::LargeUtf8,
438            "tdD" => DataType::Date32,
439            "tdm" => DataType::Date64,
440            "tts" => DataType::Time32(TimeUnit::Second),
441            "ttm" => DataType::Time32(TimeUnit::Millisecond),
442            "ttu" => DataType::Time64(TimeUnit::Microsecond),
443            "ttn" => DataType::Time64(TimeUnit::Nanosecond),
444            "tDs" => DataType::Duration(TimeUnit::Second),
445            "tDm" => DataType::Duration(TimeUnit::Millisecond),
446            "tDu" => DataType::Duration(TimeUnit::Microsecond),
447            "tDn" => DataType::Duration(TimeUnit::Nanosecond),
448            "tiM" => DataType::Interval(IntervalUnit::YearMonth),
449            "tiD" => DataType::Interval(IntervalUnit::DayTime),
450            "tin" => DataType::Interval(IntervalUnit::MonthDayNano),
451            "+l" => {
452                let c_child = c_schema.child(0);
453                DataType::List(Arc::new(Field::try_from(c_child)?))
454            }
455            "+L" => {
456                let c_child = c_schema.child(0);
457                DataType::LargeList(Arc::new(Field::try_from(c_child)?))
458            }
459            "+s" => {
460                let fields = c_schema.children().map(Field::try_from);
461                DataType::Struct(fields.collect::<Result<_, ArrowError>>()?)
462            }
463            "+m" => {
464                let c_child = c_schema.child(0);
465                let map_keys_sorted = c_schema.map_keys_sorted();
466                DataType::Map(Arc::new(Field::try_from(c_child)?), map_keys_sorted)
467            }
468            "+r" => {
469                let c_run_ends = c_schema.child(0);
470                let c_values = c_schema.child(1);
471                DataType::RunEndEncoded(
472                    Arc::new(Field::try_from(c_run_ends)?),
473                    Arc::new(Field::try_from(c_values)?),
474                )
475            }
476            // Parametrized types, requiring string parse
477            other => {
478                match other.splitn(2, ':').collect::<Vec<&str>>().as_slice() {
479                    // FixedSizeBinary type in format "w:num_bytes"
480                    ["w", num_bytes] => {
481                        let parsed_num_bytes = num_bytes.parse::<i32>().map_err(|_| {
482                            ArrowError::CDataInterface(
483                                "FixedSizeBinary requires an integer parameter representing number of bytes per element".to_string())
484                        })?;
485                        DataType::FixedSizeBinary(parsed_num_bytes)
486                    },
487                    // FixedSizeList type in format "+w:num_elems"
488                    ["+w", num_elems] => {
489                        let c_child = c_schema.child(0);
490                        let parsed_num_elems = num_elems.parse::<i32>().map_err(|_| {
491                            ArrowError::CDataInterface(
492                                "The FixedSizeList type requires an integer parameter representing number of elements per list".to_string())
493                        })?;
494                        DataType::FixedSizeList(Arc::new(Field::try_from(c_child)?), parsed_num_elems)
495                    },
496                    // Decimal types in format "d:precision,scale" or "d:precision,scale,bitWidth"
497                    ["d", extra] => {
498                        match extra.splitn(3, ',').collect::<Vec<&str>>().as_slice() {
499                            [precision, scale] => {
500                                let parsed_precision = precision.parse::<u8>().map_err(|_| {
501                                    ArrowError::CDataInterface(
502                                        "The decimal type requires an integer precision".to_string(),
503                                    )
504                                })?;
505                                let parsed_scale = scale.parse::<i8>().map_err(|_| {
506                                    ArrowError::CDataInterface(
507                                        "The decimal type requires an integer scale".to_string(),
508                                    )
509                                })?;
510                                DataType::Decimal128(parsed_precision, parsed_scale)
511                            },
512                            [precision, scale, bits] => {
513                                let parsed_precision = precision.parse::<u8>().map_err(|_| {
514                                    ArrowError::CDataInterface(
515                                        "The decimal type requires an integer precision".to_string(),
516                                    )
517                                })?;
518                                let parsed_scale = scale.parse::<i8>().map_err(|_| {
519                                    ArrowError::CDataInterface(
520                                        "The decimal type requires an integer scale".to_string(),
521                                    )
522                                })?;
523                                match *bits {
524                                    "128" => DataType::Decimal128(parsed_precision, parsed_scale),
525                                    "256" => DataType::Decimal256(parsed_precision, parsed_scale),
526                                    _ => return Err(ArrowError::CDataInterface("Only 128- and 256- bit wide decimals are supported in the Rust implementation".to_string())),
527                                }
528                            }
529                            _ => {
530                                return Err(ArrowError::CDataInterface(format!(
531                                    "The decimal pattern \"d:{extra:?}\" is not supported in the Rust implementation"
532                                )))
533                            }
534                        }
535                    }
536                    // DenseUnion
537                    ["+ud", extra] => {
538                        let type_ids = extra.split(',').map(|t| t.parse::<i8>().map_err(|_| {
539                            ArrowError::CDataInterface(
540                                "The Union type requires an integer type id".to_string(),
541                            )
542                        })).collect::<Result<Vec<_>, ArrowError>>()?;
543                        let mut fields = Vec::with_capacity(type_ids.len());
544                        for idx in 0..c_schema.n_children {
545                            let c_child = c_schema.child(idx as usize);
546                            let field = Field::try_from(c_child)?;
547                            fields.push(field);
548                        }
549
550                        if fields.len() != type_ids.len() {
551                            return Err(ArrowError::CDataInterface(
552                                "The Union type requires same number of fields and type ids".to_string(),
553                            ));
554                        }
555
556                        DataType::Union(UnionFields::new(type_ids, fields), UnionMode::Dense)
557                    }
558                    // SparseUnion
559                    ["+us", extra] => {
560                        let type_ids = extra.split(',').map(|t| t.parse::<i8>().map_err(|_| {
561                            ArrowError::CDataInterface(
562                                "The Union type requires an integer type id".to_string(),
563                            )
564                        })).collect::<Result<Vec<_>, ArrowError>>()?;
565                        let mut fields = Vec::with_capacity(type_ids.len());
566                        for idx in 0..c_schema.n_children {
567                            let c_child = c_schema.child(idx as usize);
568                            let field = Field::try_from(c_child)?;
569                            fields.push(field);
570                        }
571
572                        if fields.len() != type_ids.len() {
573                            return Err(ArrowError::CDataInterface(
574                                "The Union type requires same number of fields and type ids".to_string(),
575                            ));
576                        }
577
578                        DataType::Union(UnionFields::new(type_ids, fields), UnionMode::Sparse)
579                    }
580
581                    // Timestamps in format "tts:" and "tts:America/New_York" for no timezones and timezones resp.
582                    ["tss", ""] => DataType::Timestamp(TimeUnit::Second, None),
583                    ["tsm", ""] => DataType::Timestamp(TimeUnit::Millisecond, None),
584                    ["tsu", ""] => DataType::Timestamp(TimeUnit::Microsecond, None),
585                    ["tsn", ""] => DataType::Timestamp(TimeUnit::Nanosecond, None),
586                    ["tss", tz] => {
587                        DataType::Timestamp(TimeUnit::Second, Some(Arc::from(*tz)))
588                    }
589                    ["tsm", tz] => {
590                        DataType::Timestamp(TimeUnit::Millisecond, Some(Arc::from(*tz)))
591                    }
592                    ["tsu", tz] => {
593                        DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from(*tz)))
594                    }
595                    ["tsn", tz] => {
596                        DataType::Timestamp(TimeUnit::Nanosecond, Some(Arc::from(*tz)))
597                    }
598                    _ => {
599                        return Err(ArrowError::CDataInterface(format!(
600                            "The datatype \"{other:?}\" is still not supported in Rust implementation"
601                        )))
602                    }
603                }
604            }
605        };
606
607        if let Some(dict_schema) = c_schema.dictionary() {
608            let value_type = Self::try_from(dict_schema)?;
609            dtype = DataType::Dictionary(Box::new(dtype), Box::new(value_type));
610        }
611
612        Ok(dtype)
613    }
614}
615
616impl TryFrom<&FFI_ArrowSchema> for Field {
617    type Error = ArrowError;
618
619    fn try_from(c_schema: &FFI_ArrowSchema) -> Result<Self, ArrowError> {
620        let dtype = DataType::try_from(c_schema)?;
621        let mut field = Field::new(c_schema.name().unwrap_or(""), dtype, c_schema.nullable());
622        field.set_metadata(c_schema.metadata()?);
623        Ok(field)
624    }
625}
626
627impl TryFrom<&FFI_ArrowSchema> for Schema {
628    type Error = ArrowError;
629
630    fn try_from(c_schema: &FFI_ArrowSchema) -> Result<Self, ArrowError> {
631        // interpret it as a struct type then extract its fields
632        let dtype = DataType::try_from(c_schema)?;
633        if let DataType::Struct(fields) = dtype {
634            Ok(Schema::new(fields).with_metadata(c_schema.metadata()?))
635        } else {
636            Err(ArrowError::CDataInterface(
637                "Unable to interpret C data struct as a Schema".to_string(),
638            ))
639        }
640    }
641}
642
643impl TryFrom<&DataType> for FFI_ArrowSchema {
644    type Error = ArrowError;
645
646    /// See [CDataInterface docs](https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings)
647    fn try_from(dtype: &DataType) -> Result<Self, ArrowError> {
648        let format = get_format_string(dtype)?;
649        // allocate and hold the children
650        let children = match dtype {
651            DataType::List(child)
652            | DataType::LargeList(child)
653            | DataType::FixedSizeList(child, _)
654            | DataType::Map(child, _) => {
655                vec![FFI_ArrowSchema::try_from(child.as_ref())?]
656            }
657            DataType::Union(fields, _) => fields
658                .iter()
659                .map(|(_, f)| f.as_ref().try_into())
660                .collect::<Result<Vec<_>, ArrowError>>()?,
661            DataType::Struct(fields) => fields
662                .iter()
663                .map(FFI_ArrowSchema::try_from)
664                .collect::<Result<Vec<_>, ArrowError>>()?,
665            DataType::RunEndEncoded(run_ends, values) => vec![
666                FFI_ArrowSchema::try_from(run_ends.as_ref())?,
667                FFI_ArrowSchema::try_from(values.as_ref())?,
668            ],
669            _ => vec![],
670        };
671        let dictionary = if let DataType::Dictionary(_, value_data_type) = dtype {
672            Some(Self::try_from(value_data_type.as_ref())?)
673        } else {
674            None
675        };
676
677        let flags = match dtype {
678            DataType::Map(_, true) => Flags::MAP_KEYS_SORTED,
679            _ => Flags::empty(),
680        };
681
682        FFI_ArrowSchema::try_new(&format, children, dictionary)?.with_flags(flags)
683    }
684}
685
686fn get_format_string(dtype: &DataType) -> Result<Cow<'static, str>, ArrowError> {
687    match dtype {
688        DataType::Null => Ok("n".into()),
689        DataType::Boolean => Ok("b".into()),
690        DataType::Int8 => Ok("c".into()),
691        DataType::UInt8 => Ok("C".into()),
692        DataType::Int16 => Ok("s".into()),
693        DataType::UInt16 => Ok("S".into()),
694        DataType::Int32 => Ok("i".into()),
695        DataType::UInt32 => Ok("I".into()),
696        DataType::Int64 => Ok("l".into()),
697        DataType::UInt64 => Ok("L".into()),
698        DataType::Float16 => Ok("e".into()),
699        DataType::Float32 => Ok("f".into()),
700        DataType::Float64 => Ok("g".into()),
701        DataType::BinaryView => Ok("vz".into()),
702        DataType::Binary => Ok("z".into()),
703        DataType::LargeBinary => Ok("Z".into()),
704        DataType::Utf8View => Ok("vu".into()),
705        DataType::Utf8 => Ok("u".into()),
706        DataType::LargeUtf8 => Ok("U".into()),
707        DataType::FixedSizeBinary(num_bytes) => Ok(Cow::Owned(format!("w:{num_bytes}"))),
708        DataType::FixedSizeList(_, num_elems) => Ok(Cow::Owned(format!("+w:{num_elems}"))),
709        DataType::Decimal128(precision, scale) => Ok(Cow::Owned(format!("d:{precision},{scale}"))),
710        DataType::Decimal256(precision, scale) => {
711            Ok(Cow::Owned(format!("d:{precision},{scale},256")))
712        }
713        DataType::Date32 => Ok("tdD".into()),
714        DataType::Date64 => Ok("tdm".into()),
715        DataType::Time32(TimeUnit::Second) => Ok("tts".into()),
716        DataType::Time32(TimeUnit::Millisecond) => Ok("ttm".into()),
717        DataType::Time64(TimeUnit::Microsecond) => Ok("ttu".into()),
718        DataType::Time64(TimeUnit::Nanosecond) => Ok("ttn".into()),
719        DataType::Timestamp(TimeUnit::Second, None) => Ok("tss:".into()),
720        DataType::Timestamp(TimeUnit::Millisecond, None) => Ok("tsm:".into()),
721        DataType::Timestamp(TimeUnit::Microsecond, None) => Ok("tsu:".into()),
722        DataType::Timestamp(TimeUnit::Nanosecond, None) => Ok("tsn:".into()),
723        DataType::Timestamp(TimeUnit::Second, Some(tz)) => Ok(Cow::Owned(format!("tss:{tz}"))),
724        DataType::Timestamp(TimeUnit::Millisecond, Some(tz)) => Ok(Cow::Owned(format!("tsm:{tz}"))),
725        DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => Ok(Cow::Owned(format!("tsu:{tz}"))),
726        DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) => Ok(Cow::Owned(format!("tsn:{tz}"))),
727        DataType::Duration(TimeUnit::Second) => Ok("tDs".into()),
728        DataType::Duration(TimeUnit::Millisecond) => Ok("tDm".into()),
729        DataType::Duration(TimeUnit::Microsecond) => Ok("tDu".into()),
730        DataType::Duration(TimeUnit::Nanosecond) => Ok("tDn".into()),
731        DataType::Interval(IntervalUnit::YearMonth) => Ok("tiM".into()),
732        DataType::Interval(IntervalUnit::DayTime) => Ok("tiD".into()),
733        DataType::Interval(IntervalUnit::MonthDayNano) => Ok("tin".into()),
734        DataType::List(_) => Ok("+l".into()),
735        DataType::LargeList(_) => Ok("+L".into()),
736        DataType::Struct(_) => Ok("+s".into()),
737        DataType::Map(_, _) => Ok("+m".into()),
738        DataType::RunEndEncoded(_, _) => Ok("+r".into()),
739        DataType::Dictionary(key_data_type, _) => get_format_string(key_data_type),
740        DataType::Union(fields, mode) => {
741            let formats = fields
742                .iter()
743                .map(|(t, _)| t.to_string())
744                .collect::<Vec<_>>();
745            match mode {
746                UnionMode::Dense => Ok(Cow::Owned(format!("{}:{}", "+ud", formats.join(",")))),
747                UnionMode::Sparse => Ok(Cow::Owned(format!("{}:{}", "+us", formats.join(",")))),
748            }
749        }
750        other => Err(ArrowError::CDataInterface(format!(
751            "The datatype \"{other:?}\" is still not supported in Rust implementation"
752        ))),
753    }
754}
755
756impl TryFrom<&FieldRef> for FFI_ArrowSchema {
757    type Error = ArrowError;
758
759    fn try_from(value: &FieldRef) -> Result<Self, Self::Error> {
760        value.as_ref().try_into()
761    }
762}
763
764impl TryFrom<&Field> for FFI_ArrowSchema {
765    type Error = ArrowError;
766
767    fn try_from(field: &Field) -> Result<Self, ArrowError> {
768        let mut flags = if field.is_nullable() {
769            Flags::NULLABLE
770        } else {
771            Flags::empty()
772        };
773
774        if let Some(true) = field.dict_is_ordered() {
775            flags |= Flags::DICTIONARY_ORDERED;
776        }
777
778        FFI_ArrowSchema::try_from(field.data_type())?
779            .with_name(field.name())?
780            .with_flags(flags)?
781            .with_metadata(field.metadata())
782    }
783}
784
785impl TryFrom<&Schema> for FFI_ArrowSchema {
786    type Error = ArrowError;
787
788    fn try_from(schema: &Schema) -> Result<Self, ArrowError> {
789        let dtype = DataType::Struct(schema.fields().clone());
790        let c_schema = FFI_ArrowSchema::try_from(&dtype)?.with_metadata(&schema.metadata)?;
791        Ok(c_schema)
792    }
793}
794
795impl TryFrom<DataType> for FFI_ArrowSchema {
796    type Error = ArrowError;
797
798    fn try_from(dtype: DataType) -> Result<Self, ArrowError> {
799        FFI_ArrowSchema::try_from(&dtype)
800    }
801}
802
803impl TryFrom<Field> for FFI_ArrowSchema {
804    type Error = ArrowError;
805
806    fn try_from(field: Field) -> Result<Self, ArrowError> {
807        FFI_ArrowSchema::try_from(&field)
808    }
809}
810
811impl TryFrom<Schema> for FFI_ArrowSchema {
812    type Error = ArrowError;
813
814    fn try_from(schema: Schema) -> Result<Self, ArrowError> {
815        FFI_ArrowSchema::try_from(&schema)
816    }
817}
818
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Fields;

    /// Exports `dtype` and asserts that importing it back yields an equal [`DataType`].
    fn round_trip_type(dtype: DataType) {
        let exported = FFI_ArrowSchema::try_from(&dtype).unwrap();
        let imported = DataType::try_from(&exported).unwrap();
        assert_eq!(imported, dtype);
    }

    /// Exports `field` and asserts that importing it back yields an equal [`Field`].
    fn round_trip_field(field: Field) {
        let exported = FFI_ArrowSchema::try_from(&field).unwrap();
        let imported = Field::try_from(&exported).unwrap();
        assert_eq!(imported, field);
    }

    /// Exports `schema` and asserts that importing it back yields an equal [`Schema`].
    fn round_trip_schema(schema: Schema) {
        let exported = FFI_ArrowSchema::try_from(&schema).unwrap();
        let imported = Schema::try_from(&exported).unwrap();
        assert_eq!(imported, schema);
    }

    #[test]
    fn test_type() {
        // Each entry must survive an export/import round trip unchanged.
        let cases = [
            DataType::Int64,
            DataType::UInt64,
            DataType::Float64,
            DataType::Date64,
            DataType::Time64(TimeUnit::Nanosecond),
            DataType::FixedSizeBinary(12),
            DataType::FixedSizeList(Arc::new(Field::new("a", DataType::Int64, false)), 5),
            DataType::Utf8,
            DataType::Utf8View,
            DataType::BinaryView,
            DataType::Binary,
            DataType::LargeBinary,
            DataType::List(Arc::new(Field::new("a", DataType::Int16, false))),
            DataType::Struct(vec![Field::new("a", DataType::Utf8, true)].into()),
            DataType::RunEndEncoded(
                Arc::new(Field::new("run_ends", DataType::Int32, false)),
                Arc::new(Field::new("values", DataType::Binary, true)),
            ),
        ];

        for dtype in cases {
            round_trip_type(dtype);
        }
    }

    #[test]
    fn test_field() {
        let children = Fields::from(vec![Field::new("a", DataType::Utf8, true)]);
        let struct_field = Field::new("test", DataType::Struct(children), true);
        round_trip_field(struct_field);
    }

    #[test]
    fn test_schema() {
        let metadata = HashMap::from([("hello".to_string(), "world".to_string())]);
        let fields = vec![
            Field::new("name", DataType::Utf8, false),
            Field::new("address", DataType::Utf8, false),
            Field::new("priority", DataType::UInt8, false),
        ];
        round_trip_schema(Schema::new(fields).with_metadata(metadata));

        // A struct data type can be imported as a schema.
        let struct_type = DataType::Struct(
            vec![
                Field::new("a", DataType::Utf8, true),
                Field::new("b", DataType::Int16, false),
            ]
            .into(),
        );
        let exported = FFI_ArrowSchema::try_from(&struct_type).unwrap();
        let imported = Schema::try_from(&exported).unwrap();
        assert_eq!(imported.fields().len(), 2);

        // Importing a non-struct type as a schema must be rejected.
        let exported = FFI_ArrowSchema::try_from(&DataType::Float64).unwrap();
        assert!(Schema::try_from(&exported).is_err());
    }

    #[test]
    fn test_map_keys_sorted() {
        let entry_struct = DataType::Struct(Fields::from(vec![
            Field::new("keys", DataType::Int32, false),
            Field::new("values", DataType::UInt32, false),
        ]));

        // Map type over the key/value entries above, declared with sorted keys.
        let entries = Arc::new(Field::new("entries", entry_struct, false));
        let map_data_type = DataType::Map(entries, true);

        let exported = FFI_ArrowSchema::try_from(map_data_type).unwrap();
        assert!(exported.map_keys_sorted());
    }

    #[test]
    fn test_dictionary_ordered() {
        let dict_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
        // NOTE(review): presumably `Field::new_dict` is the deprecated item being allowed here.
        #[allow(deprecated)]
        let dict_field = Field::new_dict("dict", dict_type, false, 0, true);

        let exported = FFI_ArrowSchema::try_from(Schema::new(vec![dict_field])).unwrap();
        assert!(exported.child(0).dictionary_ordered());
    }

    #[test]
    fn test_set_field_metadata() {
        // Covers: no entries, a single entry, and empty/non-ASCII keys and values.
        let metadata_cases: Vec<HashMap<String, String>> = vec![
            HashMap::new(),
            HashMap::from([("key".to_string(), "value".to_string())]),
            HashMap::from([
                ("key".to_string(), "".to_string()),
                ("ascii123".to_string(), "你好".to_string()),
                ("".to_string(), "value".to_string()),
            ]),
        ];

        let unnamed = FFI_ArrowSchema::try_new("b", vec![], None).unwrap();
        let mut schema = unnamed.with_name("test").unwrap();

        // The same schema is reused, so each case overwrites the previous metadata.
        for expected in metadata_cases {
            schema = schema.with_metadata(&expected).unwrap();
            let imported = Field::try_from(&schema).unwrap();
            assert_eq!(imported.metadata(), &expected);
        }
    }

    #[test]
    fn test_import_field_with_null_name() {
        // A schema exported from a bare data type carries no name.
        let exported = FFI_ArrowSchema::try_from(&DataType::Int16).unwrap();
        assert!(exported.name().is_none());

        // Importing it as a field yields an empty name.
        let imported = Field::try_from(&exported).unwrap();
        assert_eq!(imported.name(), "");
    }
}