1use crate::schema::{Attributes, ComplexType, PrimitiveType, Record, Schema, TypeName};
19use arrow_schema::{
20 ArrowError, DataType, Field, FieldRef, IntervalUnit, SchemaBuilder, SchemaRef, TimeUnit,
21};
22use std::borrow::Cow;
23use std::collections::HashMap;
24use std::sync::Arc;
25
/// Records which branch of a two-element nullable Avro union holds `"null"`.
///
/// Keeping the position, rather than a plain "nullable" flag, preserves the branch order of
/// the original union.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Nullability {
    /// The union was `["null", T]`: null is the first branch.
    NullFirst,
    /// The union was `[T, "null"]`: null is the second branch.
    NullSecond,
}
38
39#[derive(Debug, Clone)]
41pub struct AvroDataType {
42 nullability: Option<Nullability>,
43 metadata: HashMap<String, String>,
44 codec: Codec,
45}
46
47impl AvroDataType {
48 pub fn field_with_name(&self, name: &str) -> Field {
50 let d = self.codec.data_type();
51 Field::new(name, d, self.nullability.is_some()).with_metadata(self.metadata.clone())
52 }
53
54 pub fn codec(&self) -> &Codec {
55 &self.codec
56 }
57
58 pub fn nullability(&self) -> Option<Nullability> {
59 self.nullability
60 }
61}
62
63#[derive(Debug, Clone)]
65pub struct AvroField {
66 name: String,
67 data_type: AvroDataType,
68}
69
70impl AvroField {
71 pub fn field(&self) -> Field {
73 self.data_type.field_with_name(&self.name)
74 }
75
76 pub fn data_type(&self) -> &AvroDataType {
78 &self.data_type
79 }
80
81 pub fn name(&self) -> &str {
82 &self.name
83 }
84}
85
86impl<'a> TryFrom<&Schema<'a>> for AvroField {
87 type Error = ArrowError;
88
89 fn try_from(schema: &Schema<'a>) -> Result<Self, Self::Error> {
90 match schema {
91 Schema::Complex(ComplexType::Record(r)) => {
92 let mut resolver = Resolver::default();
93 let data_type = make_data_type(schema, None, &mut resolver)?;
94 Ok(AvroField {
95 data_type,
96 name: r.name.to_string(),
97 })
98 }
99 _ => Err(ArrowError::ParseError(format!(
100 "Expected record got {schema:?}"
101 ))),
102 }
103 }
104}
105
106#[derive(Debug, Clone)]
110pub enum Codec {
111 Null,
112 Boolean,
113 Int32,
114 Int64,
115 Float32,
116 Float64,
117 Binary,
118 Utf8,
119 Date32,
120 TimeMillis,
121 TimeMicros,
122 TimestampMillis(bool),
124 TimestampMicros(bool),
126 Fixed(i32),
127 List(Arc<AvroDataType>),
128 Struct(Arc<[AvroField]>),
129 Interval,
130}
131
132impl Codec {
133 fn data_type(&self) -> DataType {
134 match self {
135 Self::Null => DataType::Null,
136 Self::Boolean => DataType::Boolean,
137 Self::Int32 => DataType::Int32,
138 Self::Int64 => DataType::Int64,
139 Self::Float32 => DataType::Float32,
140 Self::Float64 => DataType::Float64,
141 Self::Binary => DataType::Binary,
142 Self::Utf8 => DataType::Utf8,
143 Self::Date32 => DataType::Date32,
144 Self::TimeMillis => DataType::Time32(TimeUnit::Millisecond),
145 Self::TimeMicros => DataType::Time64(TimeUnit::Microsecond),
146 Self::TimestampMillis(is_utc) => {
147 DataType::Timestamp(TimeUnit::Millisecond, is_utc.then(|| "+00:00".into()))
148 }
149 Self::TimestampMicros(is_utc) => {
150 DataType::Timestamp(TimeUnit::Microsecond, is_utc.then(|| "+00:00".into()))
151 }
152 Self::Interval => DataType::Interval(IntervalUnit::MonthDayNano),
153 Self::Fixed(size) => DataType::FixedSizeBinary(*size),
154 Self::List(f) => {
155 DataType::List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME)))
156 }
157 Self::Struct(f) => DataType::Struct(f.iter().map(|x| x.field()).collect()),
158 }
159 }
160}
161
162impl From<PrimitiveType> for Codec {
163 fn from(value: PrimitiveType) -> Self {
164 match value {
165 PrimitiveType::Null => Self::Null,
166 PrimitiveType::Boolean => Self::Boolean,
167 PrimitiveType::Int => Self::Int32,
168 PrimitiveType::Long => Self::Int64,
169 PrimitiveType::Float => Self::Float32,
170 PrimitiveType::Double => Self::Float64,
171 PrimitiveType::Bytes => Self::Binary,
172 PrimitiveType::String => Self::Utf8,
173 }
174 }
175}
176
177#[derive(Debug, Default)]
181struct Resolver<'a> {
182 map: HashMap<(&'a str, &'a str), AvroDataType>,
183}
184
185impl<'a> Resolver<'a> {
186 fn register(&mut self, name: &'a str, namespace: Option<&'a str>, schema: AvroDataType) {
187 self.map.insert((name, namespace.unwrap_or("")), schema);
188 }
189
190 fn resolve(&self, name: &str, namespace: Option<&'a str>) -> Result<AvroDataType, ArrowError> {
191 let (namespace, name) = name
192 .rsplit_once('.')
193 .unwrap_or_else(|| (namespace.unwrap_or(""), name));
194
195 self.map
196 .get(&(namespace, name))
197 .ok_or_else(|| ArrowError::ParseError(format!("Failed to resolve {namespace}.{name}")))
198 .cloned()
199 }
200}
201
202fn make_data_type<'a>(
209 schema: &Schema<'a>,
210 namespace: Option<&'a str>,
211 resolver: &mut Resolver<'a>,
212) -> Result<AvroDataType, ArrowError> {
213 match schema {
214 Schema::TypeName(TypeName::Primitive(p)) => Ok(AvroDataType {
215 nullability: None,
216 metadata: Default::default(),
217 codec: (*p).into(),
218 }),
219 Schema::TypeName(TypeName::Ref(name)) => resolver.resolve(name, namespace),
220 Schema::Union(f) => {
221 let null = f
223 .iter()
224 .position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)));
225 match (f.len() == 2, null) {
226 (true, Some(0)) => {
227 let mut field = make_data_type(&f[1], namespace, resolver)?;
228 field.nullability = Some(Nullability::NullFirst);
229 Ok(field)
230 }
231 (true, Some(1)) => {
232 let mut field = make_data_type(&f[0], namespace, resolver)?;
233 field.nullability = Some(Nullability::NullSecond);
234 Ok(field)
235 }
236 _ => Err(ArrowError::NotYetImplemented(format!(
237 "Union of {f:?} not currently supported"
238 ))),
239 }
240 }
241 Schema::Complex(c) => match c {
242 ComplexType::Record(r) => {
243 let namespace = r.namespace.or(namespace);
244 let fields = r
245 .fields
246 .iter()
247 .map(|field| {
248 Ok(AvroField {
249 name: field.name.to_string(),
250 data_type: make_data_type(&field.r#type, namespace, resolver)?,
251 })
252 })
253 .collect::<Result<_, ArrowError>>()?;
254
255 let field = AvroDataType {
256 nullability: None,
257 codec: Codec::Struct(fields),
258 metadata: r.attributes.field_metadata(),
259 };
260 resolver.register(r.name, namespace, field.clone());
261 Ok(field)
262 }
263 ComplexType::Array(a) => {
264 let mut field = make_data_type(a.items.as_ref(), namespace, resolver)?;
265 Ok(AvroDataType {
266 nullability: None,
267 metadata: a.attributes.field_metadata(),
268 codec: Codec::List(Arc::new(field)),
269 })
270 }
271 ComplexType::Fixed(f) => {
272 let size = f.size.try_into().map_err(|e| {
273 ArrowError::ParseError(format!("Overflow converting size to i32: {e}"))
274 })?;
275
276 let field = AvroDataType {
277 nullability: None,
278 metadata: f.attributes.field_metadata(),
279 codec: Codec::Fixed(size),
280 };
281 resolver.register(f.name, namespace, field.clone());
282 Ok(field)
283 }
284 ComplexType::Enum(e) => Err(ArrowError::NotYetImplemented(format!(
285 "Enum of {e:?} not currently supported"
286 ))),
287 ComplexType::Map(m) => Err(ArrowError::NotYetImplemented(format!(
288 "Map of {m:?} not currently supported"
289 ))),
290 },
291 Schema::Type(t) => {
292 let mut field =
293 make_data_type(&Schema::TypeName(t.r#type.clone()), namespace, resolver)?;
294
295 match (t.attributes.logical_type, &mut field.codec) {
297 (Some("decimal"), c @ Codec::Fixed(_)) => {
298 return Err(ArrowError::NotYetImplemented(
299 "Decimals are not currently supported".to_string(),
300 ))
301 }
302 (Some("date"), c @ Codec::Int32) => *c = Codec::Date32,
303 (Some("time-millis"), c @ Codec::Int32) => *c = Codec::TimeMillis,
304 (Some("time-micros"), c @ Codec::Int64) => *c = Codec::TimeMicros,
305 (Some("timestamp-millis"), c @ Codec::Int64) => *c = Codec::TimestampMillis(true),
306 (Some("timestamp-micros"), c @ Codec::Int64) => *c = Codec::TimestampMicros(true),
307 (Some("local-timestamp-millis"), c @ Codec::Int64) => {
308 *c = Codec::TimestampMillis(false)
309 }
310 (Some("local-timestamp-micros"), c @ Codec::Int64) => {
311 *c = Codec::TimestampMicros(false)
312 }
313 (Some("duration"), c @ Codec::Fixed(12)) => *c = Codec::Interval,
314 (Some(logical), _) => {
315 field.metadata.insert("logicalType".into(), logical.into());
317 }
318 (None, _) => {}
319 }
320
321 if !t.attributes.additional.is_empty() {
322 for (k, v) in &t.attributes.additional {
323 field.metadata.insert(k.to_string(), v.to_string());
324 }
325 }
326 Ok(field)
327 }
328 }
329}