1use crate::schema::{Attributes, ComplexType, PrimitiveType, Record, Schema, TypeName};
19use arrow_schema::{
20 ArrowError, DataType, Field, FieldRef, Fields, IntervalUnit, SchemaBuilder, SchemaRef, TimeUnit,
21};
22use std::borrow::Cow;
23use std::collections::HashMap;
24use std::sync::Arc;
25
/// Records where the `null` branch sits inside a two-member Avro union.
///
/// Schema conversion stores this on an `AvroDataType` when it meets a union made of
/// `null` and exactly one other schema.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Nullability {
    /// `null` is the first member of the union, e.g. `["null", T]`.
    NullFirst,
    /// `null` is the second member of the union, e.g. `[T, "null"]`.
    NullSecond,
}
38
39#[derive(Debug, Clone)]
41pub struct AvroDataType {
42 nullability: Option<Nullability>,
43 metadata: HashMap<String, String>,
44 codec: Codec,
45}
46
47impl AvroDataType {
48 pub fn new(
50 codec: Codec,
51 metadata: HashMap<String, String>,
52 nullability: Option<Nullability>,
53 ) -> Self {
54 AvroDataType {
55 codec,
56 metadata,
57 nullability,
58 }
59 }
60
61 pub fn field_with_name(&self, name: &str) -> Field {
63 let d = self.codec.data_type();
64 Field::new(name, d, self.nullability.is_some()).with_metadata(self.metadata.clone())
65 }
66
67 pub fn codec(&self) -> &Codec {
72 &self.codec
73 }
74
75 pub fn nullability(&self) -> Option<Nullability> {
83 self.nullability
84 }
85}
86
87#[derive(Debug, Clone)]
89pub struct AvroField {
90 name: String,
91 data_type: AvroDataType,
92}
93
94impl AvroField {
95 pub fn field(&self) -> Field {
97 self.data_type.field_with_name(&self.name)
98 }
99
100 pub fn data_type(&self) -> &AvroDataType {
102 &self.data_type
103 }
104
105 pub fn with_utf8view(&self) -> Self {
114 let mut field = self.clone();
115 if let Codec::Utf8 = field.data_type.codec {
116 field.data_type.codec = Codec::Utf8View;
117 }
118 field
119 }
120
121 pub fn name(&self) -> &str {
126 &self.name
127 }
128}
129
130impl<'a> TryFrom<&Schema<'a>> for AvroField {
131 type Error = ArrowError;
132
133 fn try_from(schema: &Schema<'a>) -> Result<Self, Self::Error> {
134 match schema {
135 Schema::Complex(ComplexType::Record(r)) => {
136 let mut resolver = Resolver::default();
137 let data_type = make_data_type(schema, None, &mut resolver, false)?;
138 Ok(AvroField {
139 data_type,
140 name: r.name.to_string(),
141 })
142 }
143 _ => Err(ArrowError::ParseError(format!(
144 "Expected record got {schema:?}"
145 ))),
146 }
147 }
148}
149
150#[derive(Debug, Clone)]
154pub enum Codec {
155 Null,
157 Boolean,
159 Int32,
161 Int64,
163 Float32,
165 Float64,
167 Binary,
169 Utf8,
171 Utf8View,
176 Date32,
178 TimeMillis,
180 TimeMicros,
182 TimestampMillis(bool),
187 TimestampMicros(bool),
192 Fixed(i32),
195 List(Arc<AvroDataType>),
197 Struct(Arc<[AvroField]>),
199 Map(Arc<AvroDataType>),
201 Interval,
203}
204
205impl Codec {
206 fn data_type(&self) -> DataType {
207 match self {
208 Self::Null => DataType::Null,
209 Self::Boolean => DataType::Boolean,
210 Self::Int32 => DataType::Int32,
211 Self::Int64 => DataType::Int64,
212 Self::Float32 => DataType::Float32,
213 Self::Float64 => DataType::Float64,
214 Self::Binary => DataType::Binary,
215 Self::Utf8 => DataType::Utf8,
216 Self::Utf8View => DataType::Utf8View,
217 Self::Date32 => DataType::Date32,
218 Self::TimeMillis => DataType::Time32(TimeUnit::Millisecond),
219 Self::TimeMicros => DataType::Time64(TimeUnit::Microsecond),
220 Self::TimestampMillis(is_utc) => {
221 DataType::Timestamp(TimeUnit::Millisecond, is_utc.then(|| "+00:00".into()))
222 }
223 Self::TimestampMicros(is_utc) => {
224 DataType::Timestamp(TimeUnit::Microsecond, is_utc.then(|| "+00:00".into()))
225 }
226 Self::Interval => DataType::Interval(IntervalUnit::MonthDayNano),
227 Self::Fixed(size) => DataType::FixedSizeBinary(*size),
228 Self::List(f) => {
229 DataType::List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME)))
230 }
231 Self::Struct(f) => DataType::Struct(f.iter().map(|x| x.field()).collect()),
232 Self::Map(value_type) => {
233 let val_dt = value_type.codec.data_type();
234 let val_field = Field::new("value", val_dt, value_type.nullability.is_some())
235 .with_metadata(value_type.metadata.clone());
236 DataType::Map(
237 Arc::new(Field::new(
238 "entries",
239 DataType::Struct(Fields::from(vec![
240 Field::new("key", DataType::Utf8, false),
241 val_field,
242 ])),
243 false,
244 )),
245 false,
246 )
247 }
248 }
249 }
250}
251
252impl From<PrimitiveType> for Codec {
253 fn from(value: PrimitiveType) -> Self {
254 match value {
255 PrimitiveType::Null => Self::Null,
256 PrimitiveType::Boolean => Self::Boolean,
257 PrimitiveType::Int => Self::Int32,
258 PrimitiveType::Long => Self::Int64,
259 PrimitiveType::Float => Self::Float32,
260 PrimitiveType::Double => Self::Float64,
261 PrimitiveType::Bytes => Self::Binary,
262 PrimitiveType::String => Self::Utf8,
263 }
264 }
265}
266
267impl Codec {
268 pub fn with_utf8view(self, use_utf8view: bool) -> Self {
289 if use_utf8view && matches!(self, Self::Utf8) {
290 Self::Utf8View
291 } else {
292 self
293 }
294 }
295}
296
297#[derive(Debug, Default)]
301struct Resolver<'a> {
302 map: HashMap<(&'a str, &'a str), AvroDataType>,
303}
304
305impl<'a> Resolver<'a> {
306 fn register(&mut self, name: &'a str, namespace: Option<&'a str>, schema: AvroDataType) {
307 self.map.insert((name, namespace.unwrap_or("")), schema);
308 }
309
310 fn resolve(&self, name: &str, namespace: Option<&'a str>) -> Result<AvroDataType, ArrowError> {
311 let (namespace, name) = name
312 .rsplit_once('.')
313 .unwrap_or_else(|| (namespace.unwrap_or(""), name));
314
315 self.map
316 .get(&(namespace, name))
317 .ok_or_else(|| ArrowError::ParseError(format!("Failed to resolve {namespace}.{name}")))
318 .cloned()
319 }
320}
321
322fn make_data_type<'a>(
335 schema: &Schema<'a>,
336 namespace: Option<&'a str>,
337 resolver: &mut Resolver<'a>,
338 use_utf8view: bool,
339) -> Result<AvroDataType, ArrowError> {
340 match schema {
341 Schema::TypeName(TypeName::Primitive(p)) => {
342 let codec: Codec = (*p).into();
343 let codec = codec.with_utf8view(use_utf8view);
344 Ok(AvroDataType {
345 nullability: None,
346 metadata: Default::default(),
347 codec,
348 })
349 }
350 Schema::TypeName(TypeName::Ref(name)) => resolver.resolve(name, namespace),
351 Schema::Union(f) => {
352 let null = f
354 .iter()
355 .position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)));
356 match (f.len() == 2, null) {
357 (true, Some(0)) => {
358 let mut field = make_data_type(&f[1], namespace, resolver, use_utf8view)?;
359 field.nullability = Some(Nullability::NullFirst);
360 Ok(field)
361 }
362 (true, Some(1)) => {
363 let mut field = make_data_type(&f[0], namespace, resolver, use_utf8view)?;
364 field.nullability = Some(Nullability::NullSecond);
365 Ok(field)
366 }
367 _ => Err(ArrowError::NotYetImplemented(format!(
368 "Union of {f:?} not currently supported"
369 ))),
370 }
371 }
372 Schema::Complex(c) => match c {
373 ComplexType::Record(r) => {
374 let namespace = r.namespace.or(namespace);
375 let fields = r
376 .fields
377 .iter()
378 .map(|field| {
379 Ok(AvroField {
380 name: field.name.to_string(),
381 data_type: make_data_type(
382 &field.r#type,
383 namespace,
384 resolver,
385 use_utf8view,
386 )?,
387 })
388 })
389 .collect::<Result<_, ArrowError>>()?;
390
391 let field = AvroDataType {
392 nullability: None,
393 codec: Codec::Struct(fields),
394 metadata: r.attributes.field_metadata(),
395 };
396 resolver.register(r.name, namespace, field.clone());
397 Ok(field)
398 }
399 ComplexType::Array(a) => {
400 let mut field =
401 make_data_type(a.items.as_ref(), namespace, resolver, use_utf8view)?;
402 Ok(AvroDataType {
403 nullability: None,
404 metadata: a.attributes.field_metadata(),
405 codec: Codec::List(Arc::new(field)),
406 })
407 }
408 ComplexType::Fixed(f) => {
409 let size = f.size.try_into().map_err(|e| {
410 ArrowError::ParseError(format!("Overflow converting size to i32: {e}"))
411 })?;
412
413 let field = AvroDataType {
414 nullability: None,
415 metadata: f.attributes.field_metadata(),
416 codec: Codec::Fixed(size),
417 };
418 resolver.register(f.name, namespace, field.clone());
419 Ok(field)
420 }
421 ComplexType::Enum(e) => Err(ArrowError::NotYetImplemented(format!(
422 "Enum of {e:?} not currently supported"
423 ))),
424 ComplexType::Map(m) => {
425 let val = make_data_type(&m.values, namespace, resolver, use_utf8view)?;
426 Ok(AvroDataType {
427 nullability: None,
428 metadata: m.attributes.field_metadata(),
429 codec: Codec::Map(Arc::new(val)),
430 })
431 }
432 },
433 Schema::Type(t) => {
434 let mut field = make_data_type(
435 &Schema::TypeName(t.r#type.clone()),
436 namespace,
437 resolver,
438 use_utf8view,
439 )?;
440
441 match (t.attributes.logical_type, &mut field.codec) {
443 (Some("decimal"), c @ Codec::Fixed(_)) => {
444 return Err(ArrowError::NotYetImplemented(
445 "Decimals are not currently supported".to_string(),
446 ))
447 }
448 (Some("date"), c @ Codec::Int32) => *c = Codec::Date32,
449 (Some("time-millis"), c @ Codec::Int32) => *c = Codec::TimeMillis,
450 (Some("time-micros"), c @ Codec::Int64) => *c = Codec::TimeMicros,
451 (Some("timestamp-millis"), c @ Codec::Int64) => *c = Codec::TimestampMillis(true),
452 (Some("timestamp-micros"), c @ Codec::Int64) => *c = Codec::TimestampMicros(true),
453 (Some("local-timestamp-millis"), c @ Codec::Int64) => {
454 *c = Codec::TimestampMillis(false)
455 }
456 (Some("local-timestamp-micros"), c @ Codec::Int64) => {
457 *c = Codec::TimestampMicros(false)
458 }
459 (Some("duration"), c @ Codec::Fixed(12)) => *c = Codec::Interval,
460 (Some(logical), _) => {
461 field.metadata.insert("logicalType".into(), logical.into());
463 }
464 (None, _) => {}
465 }
466
467 if !t.attributes.additional.is_empty() {
468 for (k, v) in &t.attributes.additional {
469 field.metadata.insert(k.to_string(), v.to_string());
470 }
471 }
472 Ok(field)
473 }
474 }
475}
476
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::{
        Attributes, ComplexType, Fixed, PrimitiveType, Record, Schema, Type, TypeName,
    };
    use std::collections::HashMap;

    /// Name of the fixed type used by the fixed-type helpers below.
    const FIXED_TYPE_NAME: &str = "fixed_type";

    /// Builds a primitive schema annotated with `logical_type`.
    fn create_schema_with_logical_type(
        primitive_type: PrimitiveType,
        logical_type: &'static str,
    ) -> Schema<'static> {
        Schema::Type(Type {
            r#type: TypeName::Primitive(primitive_type),
            attributes: Attributes {
                logical_type: Some(logical_type),
                additional: Default::default(),
            },
        })
    }

    /// Builds the definition of a fixed type of `size` bytes annotated with `logical_type`.
    fn create_fixed_schema(size: usize, logical_type: &'static str) -> Schema<'static> {
        Schema::Complex(ComplexType::Fixed(Fixed {
            name: FIXED_TYPE_NAME,
            namespace: None,
            aliases: Vec::new(),
            size,
            attributes: Attributes {
                logical_type: Some(logical_type),
                additional: Default::default(),
            },
        }))
    }

    /// Builds a reference to the fixed type annotated with `logical_type`.
    fn create_fixed_ref_schema(logical_type: &'static str) -> Schema<'static> {
        Schema::Type(Type {
            r#type: TypeName::Ref(FIXED_TYPE_NAME),
            attributes: Attributes {
                logical_type: Some(logical_type),
                additional: Default::default(),
            },
        })
    }

    /// Returns a resolver that already knows a fixed type of `size` bytes.
    fn resolver_with_fixed(size: i32) -> Resolver<'static> {
        let mut resolver = Resolver::default();
        // Seed the map directly with the `(namespace, name)` key that
        // `Resolver::resolve` looks up, so these tests depend only on the lookup path.
        resolver.map.insert(
            ("", FIXED_TYPE_NAME),
            AvroDataType::new(Codec::Fixed(size), HashMap::new(), None),
        );
        resolver
    }

    /// Converts `schema` with a fresh resolver and returns the resulting codec.
    fn convert(schema: &Schema<'static>, use_utf8view: bool) -> Codec {
        let mut resolver = Resolver::default();
        make_data_type(schema, None, &mut resolver, use_utf8view)
            .unwrap()
            .codec
    }

    #[test]
    fn test_date_logical_type() {
        let schema = create_schema_with_logical_type(PrimitiveType::Int, "date");
        assert!(matches!(convert(&schema, false), Codec::Date32));
    }

    #[test]
    fn test_time_millis_logical_type() {
        let schema = create_schema_with_logical_type(PrimitiveType::Int, "time-millis");
        assert!(matches!(convert(&schema, false), Codec::TimeMillis));
    }

    #[test]
    fn test_time_micros_logical_type() {
        let schema = create_schema_with_logical_type(PrimitiveType::Long, "time-micros");
        assert!(matches!(convert(&schema, false), Codec::TimeMicros));
    }

    #[test]
    fn test_timestamp_millis_logical_type() {
        let schema = create_schema_with_logical_type(PrimitiveType::Long, "timestamp-millis");
        assert!(matches!(
            convert(&schema, false),
            Codec::TimestampMillis(true)
        ));
    }

    #[test]
    fn test_timestamp_micros_logical_type() {
        let schema = create_schema_with_logical_type(PrimitiveType::Long, "timestamp-micros");
        assert!(matches!(
            convert(&schema, false),
            Codec::TimestampMicros(true)
        ));
    }

    #[test]
    fn test_local_timestamp_millis_logical_type() {
        let schema = create_schema_with_logical_type(PrimitiveType::Long, "local-timestamp-millis");
        assert!(matches!(
            convert(&schema, false),
            Codec::TimestampMillis(false)
        ));
    }

    #[test]
    fn test_local_timestamp_micros_logical_type() {
        let schema = create_schema_with_logical_type(PrimitiveType::Long, "local-timestamp-micros");
        assert!(matches!(
            convert(&schema, false),
            Codec::TimestampMicros(false)
        ));
    }

    #[test]
    fn test_fixed_schema_produces_fixed_codec() {
        let schema = create_fixed_schema(16, "custom-type");
        assert!(matches!(convert(&schema, false), Codec::Fixed(16)));
    }

    #[test]
    fn test_fixed_size_overflow_is_rejected() {
        // `usize::MAX` never fits in the `i32` size stored by `Codec::Fixed`.
        let schema = create_fixed_schema(usize::MAX, "custom-type");

        let mut resolver = Resolver::default();
        let result = make_data_type(&schema, None, &mut resolver, false);

        assert!(matches!(result, Err(ArrowError::ParseError(_))));
    }

    #[test]
    fn test_duration_logical_type() {
        let schema = create_fixed_ref_schema("duration");

        let mut resolver = resolver_with_fixed(12);
        let result = make_data_type(&schema, None, &mut resolver, false).unwrap();

        assert!(matches!(result.codec, Codec::Interval));
    }

    #[test]
    fn test_decimal_logical_type_not_implemented() {
        let schema = create_fixed_ref_schema("decimal");

        let mut resolver = resolver_with_fixed(16);
        let result = make_data_type(&schema, None, &mut resolver, false);

        match result {
            Err(ArrowError::NotYetImplemented(msg)) => {
                assert!(msg.contains("Decimals are not currently supported"));
            }
            other => panic!("Expected NotYetImplemented error, got {other:?}"),
        }
    }

    #[test]
    fn test_unknown_logical_type_added_to_metadata() {
        let schema = create_schema_with_logical_type(PrimitiveType::Int, "custom-type");

        let mut resolver = Resolver::default();
        let result = make_data_type(&schema, None, &mut resolver, false).unwrap();

        assert_eq!(
            result.metadata.get("logicalType"),
            Some(&"custom-type".to_string())
        );
    }

    #[test]
    fn test_string_with_utf8view_enabled() {
        let schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::String));
        assert!(matches!(convert(&schema, true), Codec::Utf8View));
    }

    #[test]
    fn test_string_without_utf8view_enabled() {
        let schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::String));
        assert!(matches!(convert(&schema, false), Codec::Utf8));
    }

    #[test]
    fn test_record_with_string_and_utf8view_enabled() {
        let field_schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::String));

        let avro_field = crate::schema::Field {
            name: "string_field",
            r#type: field_schema,
            default: None,
            doc: None,
        };

        let record = Record {
            name: "test_record",
            namespace: None,
            aliases: vec![],
            doc: None,
            fields: vec![avro_field],
            attributes: Attributes::default(),
        };

        let schema = Schema::Complex(ComplexType::Record(record));

        if let Codec::Struct(fields) = convert(&schema, true) {
            let first_field_codec = &fields[0].data_type().codec;
            assert!(matches!(first_field_codec, Codec::Utf8View));
        } else {
            panic!("Expected Struct codec");
        }
    }
}