arrow_avro/
codec.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::schema::{Attributes, ComplexType, PrimitiveType, Record, Schema, TypeName};
19use arrow_schema::{
20    ArrowError, DataType, Field, FieldRef, IntervalUnit, SchemaBuilder, SchemaRef, TimeUnit,
21};
22use std::borrow::Cow;
23use std::collections::HashMap;
24use std::sync::Arc;
25
/// Avro types are not nullable; nullability is instead encoded as a union
/// where one of the variants is the null type.
///
/// To accommodate this, two-variant unions where one of the variants is the
/// null type are special-cased. The position of the null variant is recorded
/// here, and its presence is used to derive arrow's notion of nullability.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Nullability {
    /// The nulls are encoded as the first union variant
    NullFirst,
    /// The nulls are encoded as the second union variant
    NullSecond,
}
38
39/// An Avro datatype mapped to the arrow data model
40#[derive(Debug, Clone)]
41pub struct AvroDataType {
42    nullability: Option<Nullability>,
43    metadata: HashMap<String, String>,
44    codec: Codec,
45}
46
47impl AvroDataType {
48    /// Returns an arrow [`Field`] with the given name
49    pub fn field_with_name(&self, name: &str) -> Field {
50        let d = self.codec.data_type();
51        Field::new(name, d, self.nullability.is_some()).with_metadata(self.metadata.clone())
52    }
53
54    pub fn codec(&self) -> &Codec {
55        &self.codec
56    }
57
58    pub fn nullability(&self) -> Option<Nullability> {
59        self.nullability
60    }
61}
62
63/// A named [`AvroDataType`]
64#[derive(Debug, Clone)]
65pub struct AvroField {
66    name: String,
67    data_type: AvroDataType,
68}
69
70impl AvroField {
71    /// Returns the arrow [`Field`]
72    pub fn field(&self) -> Field {
73        self.data_type.field_with_name(&self.name)
74    }
75
76    /// Returns the [`AvroDataType`]
77    pub fn data_type(&self) -> &AvroDataType {
78        &self.data_type
79    }
80
81    pub fn name(&self) -> &str {
82        &self.name
83    }
84}
85
86impl<'a> TryFrom<&Schema<'a>> for AvroField {
87    type Error = ArrowError;
88
89    fn try_from(schema: &Schema<'a>) -> Result<Self, Self::Error> {
90        match schema {
91            Schema::Complex(ComplexType::Record(r)) => {
92                let mut resolver = Resolver::default();
93                let data_type = make_data_type(schema, None, &mut resolver)?;
94                Ok(AvroField {
95                    data_type,
96                    name: r.name.to_string(),
97                })
98            }
99            _ => Err(ArrowError::ParseError(format!(
100                "Expected record got {schema:?}"
101            ))),
102        }
103    }
104}
105
106/// An Avro encoding
107///
108/// <https://avro.apache.org/docs/1.11.1/specification/#encodings>
109#[derive(Debug, Clone)]
110pub enum Codec {
111    Null,
112    Boolean,
113    Int32,
114    Int64,
115    Float32,
116    Float64,
117    Binary,
118    Utf8,
119    Date32,
120    TimeMillis,
121    TimeMicros,
122    /// TimestampMillis(is_utc)
123    TimestampMillis(bool),
124    /// TimestampMicros(is_utc)
125    TimestampMicros(bool),
126    Fixed(i32),
127    List(Arc<AvroDataType>),
128    Struct(Arc<[AvroField]>),
129    Interval,
130}
131
132impl Codec {
133    fn data_type(&self) -> DataType {
134        match self {
135            Self::Null => DataType::Null,
136            Self::Boolean => DataType::Boolean,
137            Self::Int32 => DataType::Int32,
138            Self::Int64 => DataType::Int64,
139            Self::Float32 => DataType::Float32,
140            Self::Float64 => DataType::Float64,
141            Self::Binary => DataType::Binary,
142            Self::Utf8 => DataType::Utf8,
143            Self::Date32 => DataType::Date32,
144            Self::TimeMillis => DataType::Time32(TimeUnit::Millisecond),
145            Self::TimeMicros => DataType::Time64(TimeUnit::Microsecond),
146            Self::TimestampMillis(is_utc) => {
147                DataType::Timestamp(TimeUnit::Millisecond, is_utc.then(|| "+00:00".into()))
148            }
149            Self::TimestampMicros(is_utc) => {
150                DataType::Timestamp(TimeUnit::Microsecond, is_utc.then(|| "+00:00".into()))
151            }
152            Self::Interval => DataType::Interval(IntervalUnit::MonthDayNano),
153            Self::Fixed(size) => DataType::FixedSizeBinary(*size),
154            Self::List(f) => {
155                DataType::List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME)))
156            }
157            Self::Struct(f) => DataType::Struct(f.iter().map(|x| x.field()).collect()),
158        }
159    }
160}
161
162impl From<PrimitiveType> for Codec {
163    fn from(value: PrimitiveType) -> Self {
164        match value {
165            PrimitiveType::Null => Self::Null,
166            PrimitiveType::Boolean => Self::Boolean,
167            PrimitiveType::Int => Self::Int32,
168            PrimitiveType::Long => Self::Int64,
169            PrimitiveType::Float => Self::Float32,
170            PrimitiveType::Double => Self::Float64,
171            PrimitiveType::Bytes => Self::Binary,
172            PrimitiveType::String => Self::Utf8,
173        }
174    }
175}
176
177/// Resolves Avro type names to [`AvroDataType`]
178///
179/// See <https://avro.apache.org/docs/1.11.1/specification/#names>
180#[derive(Debug, Default)]
181struct Resolver<'a> {
182    map: HashMap<(&'a str, &'a str), AvroDataType>,
183}
184
185impl<'a> Resolver<'a> {
186    fn register(&mut self, name: &'a str, namespace: Option<&'a str>, schema: AvroDataType) {
187        self.map.insert((name, namespace.unwrap_or("")), schema);
188    }
189
190    fn resolve(&self, name: &str, namespace: Option<&'a str>) -> Result<AvroDataType, ArrowError> {
191        let (namespace, name) = name
192            .rsplit_once('.')
193            .unwrap_or_else(|| (namespace.unwrap_or(""), name));
194
195        self.map
196            .get(&(namespace, name))
197            .ok_or_else(|| ArrowError::ParseError(format!("Failed to resolve {namespace}.{name}")))
198            .cloned()
199    }
200}
201
202/// Parses a [`AvroDataType`] from the provided [`Schema`] and the given `name` and `namespace`
203///
204/// `name`: is name used to refer to `schema` in its parent
205/// `namespace`: an optional qualifier used as part of a type hierarchy
206///
207/// See [`Resolver`] for more information
208fn make_data_type<'a>(
209    schema: &Schema<'a>,
210    namespace: Option<&'a str>,
211    resolver: &mut Resolver<'a>,
212) -> Result<AvroDataType, ArrowError> {
213    match schema {
214        Schema::TypeName(TypeName::Primitive(p)) => Ok(AvroDataType {
215            nullability: None,
216            metadata: Default::default(),
217            codec: (*p).into(),
218        }),
219        Schema::TypeName(TypeName::Ref(name)) => resolver.resolve(name, namespace),
220        Schema::Union(f) => {
221            // Special case the common case of nullable primitives
222            let null = f
223                .iter()
224                .position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)));
225            match (f.len() == 2, null) {
226                (true, Some(0)) => {
227                    let mut field = make_data_type(&f[1], namespace, resolver)?;
228                    field.nullability = Some(Nullability::NullFirst);
229                    Ok(field)
230                }
231                (true, Some(1)) => {
232                    let mut field = make_data_type(&f[0], namespace, resolver)?;
233                    field.nullability = Some(Nullability::NullSecond);
234                    Ok(field)
235                }
236                _ => Err(ArrowError::NotYetImplemented(format!(
237                    "Union of {f:?} not currently supported"
238                ))),
239            }
240        }
241        Schema::Complex(c) => match c {
242            ComplexType::Record(r) => {
243                let namespace = r.namespace.or(namespace);
244                let fields = r
245                    .fields
246                    .iter()
247                    .map(|field| {
248                        Ok(AvroField {
249                            name: field.name.to_string(),
250                            data_type: make_data_type(&field.r#type, namespace, resolver)?,
251                        })
252                    })
253                    .collect::<Result<_, ArrowError>>()?;
254
255                let field = AvroDataType {
256                    nullability: None,
257                    codec: Codec::Struct(fields),
258                    metadata: r.attributes.field_metadata(),
259                };
260                resolver.register(r.name, namespace, field.clone());
261                Ok(field)
262            }
263            ComplexType::Array(a) => {
264                let mut field = make_data_type(a.items.as_ref(), namespace, resolver)?;
265                Ok(AvroDataType {
266                    nullability: None,
267                    metadata: a.attributes.field_metadata(),
268                    codec: Codec::List(Arc::new(field)),
269                })
270            }
271            ComplexType::Fixed(f) => {
272                let size = f.size.try_into().map_err(|e| {
273                    ArrowError::ParseError(format!("Overflow converting size to i32: {e}"))
274                })?;
275
276                let field = AvroDataType {
277                    nullability: None,
278                    metadata: f.attributes.field_metadata(),
279                    codec: Codec::Fixed(size),
280                };
281                resolver.register(f.name, namespace, field.clone());
282                Ok(field)
283            }
284            ComplexType::Enum(e) => Err(ArrowError::NotYetImplemented(format!(
285                "Enum of {e:?} not currently supported"
286            ))),
287            ComplexType::Map(m) => Err(ArrowError::NotYetImplemented(format!(
288                "Map of {m:?} not currently supported"
289            ))),
290        },
291        Schema::Type(t) => {
292            let mut field =
293                make_data_type(&Schema::TypeName(t.r#type.clone()), namespace, resolver)?;
294
295            // https://avro.apache.org/docs/1.11.1/specification/#logical-types
296            match (t.attributes.logical_type, &mut field.codec) {
297                (Some("decimal"), c @ Codec::Fixed(_)) => {
298                    return Err(ArrowError::NotYetImplemented(
299                        "Decimals are not currently supported".to_string(),
300                    ))
301                }
302                (Some("date"), c @ Codec::Int32) => *c = Codec::Date32,
303                (Some("time-millis"), c @ Codec::Int32) => *c = Codec::TimeMillis,
304                (Some("time-micros"), c @ Codec::Int64) => *c = Codec::TimeMicros,
305                (Some("timestamp-millis"), c @ Codec::Int64) => *c = Codec::TimestampMillis(true),
306                (Some("timestamp-micros"), c @ Codec::Int64) => *c = Codec::TimestampMicros(true),
307                (Some("local-timestamp-millis"), c @ Codec::Int64) => {
308                    *c = Codec::TimestampMillis(false)
309                }
310                (Some("local-timestamp-micros"), c @ Codec::Int64) => {
311                    *c = Codec::TimestampMicros(false)
312                }
313                (Some("duration"), c @ Codec::Fixed(12)) => *c = Codec::Interval,
314                (Some(logical), _) => {
315                    // Insert unrecognized logical type into metadata map
316                    field.metadata.insert("logicalType".into(), logical.into());
317                }
318                (None, _) => {}
319            }
320
321            if !t.attributes.additional.is_empty() {
322                for (k, v) in &t.attributes.additional {
323                    field.metadata.insert(k.to_string(), v.to_string());
324                }
325            }
326            Ok(field)
327        }
328    }
329}