1use crate::schema::{Attributes, ComplexType, PrimitiveType, Record, Schema, TypeName};
19use arrow_schema::DataType::{Decimal128, Decimal256};
20use arrow_schema::{
21 ArrowError, DataType, Field, FieldRef, Fields, IntervalUnit, SchemaBuilder, SchemaRef,
22 TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
23};
24use std::borrow::Cow;
25use std::collections::HashMap;
26use std::sync::Arc;
27
/// Describes a type that came from a two-branch Avro union containing `null`,
/// recording on which side of the union the `null` branch appeared.
#[derive(Debug, Copy, Clone)]
pub enum Nullability {
    /// The union was written as `["null", T]` (`null` is the first branch).
    NullFirst,
    /// The union was written as `[T, "null"]` (`null` is the second branch).
    NullSecond,
}
40
/// An Avro type together with what is needed to build the corresponding Arrow
/// [`Field`]: its codec, its nullability and the field metadata.
#[derive(Debug, Clone)]
pub struct AvroDataType {
    /// `Some` when the type came from a nullable two-branch union; `None` otherwise.
    /// The resulting Arrow field is nullable exactly when this is `Some`.
    nullability: Option<Nullability>,
    /// Metadata copied onto the Arrow field built for this type.
    metadata: HashMap<String, String>,
    /// How this Avro type maps onto an Arrow data type.
    codec: Codec,
}
48
49impl AvroDataType {
50 pub fn new(
52 codec: Codec,
53 metadata: HashMap<String, String>,
54 nullability: Option<Nullability>,
55 ) -> Self {
56 AvroDataType {
57 codec,
58 metadata,
59 nullability,
60 }
61 }
62
63 pub fn field_with_name(&self, name: &str) -> Field {
65 let d = self.codec.data_type();
66 Field::new(name, d, self.nullability.is_some()).with_metadata(self.metadata.clone())
67 }
68
69 pub fn codec(&self) -> &Codec {
74 &self.codec
75 }
76
77 pub fn nullability(&self) -> Option<Nullability> {
85 self.nullability
86 }
87}
88
/// A named Avro field: a field name paired with its [`AvroDataType`].
#[derive(Debug, Clone)]
pub struct AvroField {
    /// Name used for the Arrow [`Field`] built from this field.
    name: String,
    /// The Avro type of this field.
    data_type: AvroDataType,
}
95
96impl AvroField {
97 pub fn field(&self) -> Field {
99 self.data_type.field_with_name(&self.name)
100 }
101
102 pub fn data_type(&self) -> &AvroDataType {
104 &self.data_type
105 }
106
107 pub fn with_utf8view(&self) -> Self {
116 let mut field = self.clone();
117 if let Codec::Utf8 = field.data_type.codec {
118 field.data_type.codec = Codec::Utf8View;
119 }
120 field
121 }
122
123 pub fn name(&self) -> &str {
128 &self.name
129 }
130}
131
132impl<'a> TryFrom<&Schema<'a>> for AvroField {
133 type Error = ArrowError;
134
135 fn try_from(schema: &Schema<'a>) -> Result<Self, Self::Error> {
136 match schema {
137 Schema::Complex(ComplexType::Record(r)) => {
138 let mut resolver = Resolver::default();
139 let data_type = make_data_type(schema, None, &mut resolver, false)?;
140 Ok(AvroField {
141 data_type,
142 name: r.name.to_string(),
143 })
144 }
145 _ => Err(ArrowError::ParseError(format!(
146 "Expected record got {schema:?}"
147 ))),
148 }
149 }
150}
151
/// Describes how an Avro type is mapped onto an Arrow [`DataType`].
#[derive(Debug, Clone)]
pub enum Codec {
    /// Avro `null`, represented as `DataType::Null`.
    Null,
    /// Avro `boolean`, represented as `DataType::Boolean`.
    Boolean,
    /// Avro `int`, represented as `DataType::Int32`.
    Int32,
    /// Avro `long`, represented as `DataType::Int64`.
    Int64,
    /// Avro `float`, represented as `DataType::Float32`.
    Float32,
    /// Avro `double`, represented as `DataType::Float64`.
    Float64,
    /// Avro `bytes`, represented as `DataType::Binary`.
    Binary,
    /// Avro `string`, represented as `DataType::Utf8`.
    Utf8,
    /// Avro `string` represented as `DataType::Utf8View` (chosen when `use_utf8view` is on).
    Utf8View,
    /// `int` with logical type `date`, represented as `DataType::Date32`.
    Date32,
    /// `int` with logical type `time-millis`, represented as `Time32(Millisecond)`.
    TimeMillis,
    /// `long` with logical type `time-micros`, represented as `Time64(Microsecond)`.
    TimeMicros,
    /// Millisecond timestamp stored in a `long`. `true` for `timestamp-millis` (time zone
    /// `+00:00`), `false` for `local-timestamp-millis` (no time zone).
    TimestampMillis(bool),
    /// Microsecond timestamp stored in a `long`. `true` for `timestamp-micros` (time zone
    /// `+00:00`), `false` for `local-timestamp-micros` (no time zone).
    TimestampMicros(bool),
    /// Avro `fixed` with the given byte size, represented as `FixedSizeBinary(size)`.
    Fixed(i32),
    /// Logical type `decimal` as `(precision, scale, size)`. `size` is the byte width
    /// when one is known (e.g. a `fixed` backing type) and `None` for `bytes`.
    /// Represented as `Decimal128` or `Decimal256` depending on size / precision / scale.
    Decimal(usize, Option<usize>, Option<usize>),
    /// `string` with logical type `uuid`, represented as `FixedSizeBinary(16)`.
    Uuid,
    /// Avro `array` of the given item type, represented as `DataType::List`.
    List(Arc<AvroDataType>),
    /// Avro `record` with the given fields, represented as `DataType::Struct`.
    Struct(Arc<[AvroField]>),
    /// Avro `map` with string keys and the given value type, represented as `DataType::Map`.
    Map(Arc<AvroDataType>),
    /// `fixed(12)` with logical type `duration`, represented as `Interval(MonthDayNano)`.
    Interval,
}
215
216impl Codec {
217 fn data_type(&self) -> DataType {
218 match self {
219 Self::Null => DataType::Null,
220 Self::Boolean => DataType::Boolean,
221 Self::Int32 => DataType::Int32,
222 Self::Int64 => DataType::Int64,
223 Self::Float32 => DataType::Float32,
224 Self::Float64 => DataType::Float64,
225 Self::Binary => DataType::Binary,
226 Self::Utf8 => DataType::Utf8,
227 Self::Utf8View => DataType::Utf8View,
228 Self::Date32 => DataType::Date32,
229 Self::TimeMillis => DataType::Time32(TimeUnit::Millisecond),
230 Self::TimeMicros => DataType::Time64(TimeUnit::Microsecond),
231 Self::TimestampMillis(is_utc) => {
232 DataType::Timestamp(TimeUnit::Millisecond, is_utc.then(|| "+00:00".into()))
233 }
234 Self::TimestampMicros(is_utc) => {
235 DataType::Timestamp(TimeUnit::Microsecond, is_utc.then(|| "+00:00".into()))
236 }
237 Self::Interval => DataType::Interval(IntervalUnit::MonthDayNano),
238 Self::Fixed(size) => DataType::FixedSizeBinary(*size),
239 Self::Decimal(precision, scale, size) => {
240 let p = *precision as u8;
241 let s = scale.unwrap_or(0) as i8;
242 let too_large_for_128 = match *size {
243 Some(sz) => sz > 16,
244 None => {
245 (p as usize) > DECIMAL128_MAX_PRECISION as usize
246 || (s as usize) > DECIMAL128_MAX_SCALE as usize
247 }
248 };
249 if too_large_for_128 {
250 Decimal256(p, s)
251 } else {
252 Decimal128(p, s)
253 }
254 }
255 Self::Uuid => DataType::FixedSizeBinary(16),
256 Self::List(f) => {
257 DataType::List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME)))
258 }
259 Self::Struct(f) => DataType::Struct(f.iter().map(|x| x.field()).collect()),
260 Self::Map(value_type) => {
261 let val_dt = value_type.codec.data_type();
262 let val_field = Field::new("value", val_dt, value_type.nullability.is_some())
263 .with_metadata(value_type.metadata.clone());
264 DataType::Map(
265 Arc::new(Field::new(
266 "entries",
267 DataType::Struct(Fields::from(vec![
268 Field::new("key", DataType::Utf8, false),
269 val_field,
270 ])),
271 false,
272 )),
273 false,
274 )
275 }
276 }
277 }
278}
279
280impl From<PrimitiveType> for Codec {
281 fn from(value: PrimitiveType) -> Self {
282 match value {
283 PrimitiveType::Null => Self::Null,
284 PrimitiveType::Boolean => Self::Boolean,
285 PrimitiveType::Int => Self::Int32,
286 PrimitiveType::Long => Self::Int64,
287 PrimitiveType::Float => Self::Float32,
288 PrimitiveType::Double => Self::Float64,
289 PrimitiveType::Bytes => Self::Binary,
290 PrimitiveType::String => Self::Utf8,
291 }
292 }
293}
294
295fn parse_decimal_attributes(
296 attributes: &Attributes,
297 fallback_size: Option<usize>,
298 precision_required: bool,
299) -> Result<(usize, usize, Option<usize>), ArrowError> {
300 let precision = attributes
301 .additional
302 .get("precision")
303 .and_then(|v| v.as_u64())
304 .or(if precision_required { None } else { Some(10) })
305 .ok_or_else(|| ArrowError::ParseError("Decimal requires precision".to_string()))?
306 as usize;
307 let scale = attributes
308 .additional
309 .get("scale")
310 .and_then(|v| v.as_u64())
311 .unwrap_or(0) as usize;
312 let size = attributes
313 .additional
314 .get("size")
315 .and_then(|v| v.as_u64())
316 .map(|s| s as usize)
317 .or(fallback_size);
318 Ok((precision, scale, size))
319}
320
321impl Codec {
322 pub fn with_utf8view(self, use_utf8view: bool) -> Self {
343 if use_utf8view && matches!(self, Self::Utf8) {
344 Self::Utf8View
345 } else {
346 self
347 }
348 }
349}
350
/// Registry of named Avro types (records and fixed types) seen so far, so that later
/// references to them by name can be resolved.
#[derive(Debug, Default)]
struct Resolver<'a> {
    /// Registered types, keyed by the type's name and namespace
    /// (an absent namespace is stored as `""`).
    map: HashMap<(&'a str, &'a str), AvroDataType>,
}
358
359impl<'a> Resolver<'a> {
360 fn register(&mut self, name: &'a str, namespace: Option<&'a str>, schema: AvroDataType) {
361 self.map.insert((name, namespace.unwrap_or("")), schema);
362 }
363
364 fn resolve(&self, name: &str, namespace: Option<&'a str>) -> Result<AvroDataType, ArrowError> {
365 let (namespace, name) = name
366 .rsplit_once('.')
367 .unwrap_or_else(|| (namespace.unwrap_or(""), name));
368
369 self.map
370 .get(&(namespace, name))
371 .ok_or_else(|| ArrowError::ParseError(format!("Failed to resolve {namespace}.{name}")))
372 .cloned()
373 }
374}
375
376fn make_data_type<'a>(
389 schema: &Schema<'a>,
390 namespace: Option<&'a str>,
391 resolver: &mut Resolver<'a>,
392 use_utf8view: bool,
393) -> Result<AvroDataType, ArrowError> {
394 match schema {
395 Schema::TypeName(TypeName::Primitive(p)) => {
396 let codec: Codec = (*p).into();
397 let codec = codec.with_utf8view(use_utf8view);
398 Ok(AvroDataType {
399 nullability: None,
400 metadata: Default::default(),
401 codec,
402 })
403 }
404 Schema::TypeName(TypeName::Ref(name)) => resolver.resolve(name, namespace),
405 Schema::Union(f) => {
406 let null = f
408 .iter()
409 .position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)));
410 match (f.len() == 2, null) {
411 (true, Some(0)) => {
412 let mut field = make_data_type(&f[1], namespace, resolver, use_utf8view)?;
413 field.nullability = Some(Nullability::NullFirst);
414 Ok(field)
415 }
416 (true, Some(1)) => {
417 let mut field = make_data_type(&f[0], namespace, resolver, use_utf8view)?;
418 field.nullability = Some(Nullability::NullSecond);
419 Ok(field)
420 }
421 _ => Err(ArrowError::NotYetImplemented(format!(
422 "Union of {f:?} not currently supported"
423 ))),
424 }
425 }
426 Schema::Complex(c) => match c {
427 ComplexType::Record(r) => {
428 let namespace = r.namespace.or(namespace);
429 let fields = r
430 .fields
431 .iter()
432 .map(|field| {
433 Ok(AvroField {
434 name: field.name.to_string(),
435 data_type: make_data_type(
436 &field.r#type,
437 namespace,
438 resolver,
439 use_utf8view,
440 )?,
441 })
442 })
443 .collect::<Result<_, ArrowError>>()?;
444
445 let field = AvroDataType {
446 nullability: None,
447 codec: Codec::Struct(fields),
448 metadata: r.attributes.field_metadata(),
449 };
450 resolver.register(r.name, namespace, field.clone());
451 Ok(field)
452 }
453 ComplexType::Array(a) => {
454 let mut field =
455 make_data_type(a.items.as_ref(), namespace, resolver, use_utf8view)?;
456 Ok(AvroDataType {
457 nullability: None,
458 metadata: a.attributes.field_metadata(),
459 codec: Codec::List(Arc::new(field)),
460 })
461 }
462 ComplexType::Fixed(f) => {
463 let size = f.size.try_into().map_err(|e| {
464 ArrowError::ParseError(format!("Overflow converting size to i32: {e}"))
465 })?;
466 let field = AvroDataType {
467 nullability: None,
468 metadata: f.attributes.field_metadata(),
469 codec: Codec::Fixed(size),
470 };
471 resolver.register(f.name, namespace, field.clone());
472 Ok(field)
473 }
474 ComplexType::Enum(e) => Err(ArrowError::NotYetImplemented(format!(
475 "Enum of {e:?} not currently supported"
476 ))),
477 ComplexType::Map(m) => {
478 let val = make_data_type(&m.values, namespace, resolver, use_utf8view)?;
479 Ok(AvroDataType {
480 nullability: None,
481 metadata: m.attributes.field_metadata(),
482 codec: Codec::Map(Arc::new(val)),
483 })
484 }
485 },
486 Schema::Type(t) => {
487 let mut field = make_data_type(
488 &Schema::TypeName(t.r#type.clone()),
489 namespace,
490 resolver,
491 use_utf8view,
492 )?;
493
494 match (t.attributes.logical_type, &mut field.codec) {
496 (Some("decimal"), c) => match *c {
497 Codec::Fixed(sz_val) => {
498 let (prec, sc, size_opt) =
499 parse_decimal_attributes(&t.attributes, Some(sz_val as usize), true)?;
500 let final_sz = if let Some(sz_actual) = size_opt {
501 sz_actual
502 } else {
503 sz_val as usize
504 };
505 *c = Codec::Decimal(prec, Some(sc), Some(final_sz));
506 }
507 Codec::Binary => {
508 let (prec, sc, _) = parse_decimal_attributes(&t.attributes, None, false)?;
509 *c = Codec::Decimal(prec, Some(sc), None);
510 }
511 _ => {
512 return Err(ArrowError::SchemaError(format!(
513 "Decimal logical type can only be backed by Fixed or Bytes, found {c:?}"
514 )))
515 }
516 },
517 (Some("date"), c @ Codec::Int32) => *c = Codec::Date32,
518 (Some("time-millis"), c @ Codec::Int32) => *c = Codec::TimeMillis,
519 (Some("time-micros"), c @ Codec::Int64) => *c = Codec::TimeMicros,
520 (Some("timestamp-millis"), c @ Codec::Int64) => *c = Codec::TimestampMillis(true),
521 (Some("timestamp-micros"), c @ Codec::Int64) => *c = Codec::TimestampMicros(true),
522 (Some("local-timestamp-millis"), c @ Codec::Int64) => {
523 *c = Codec::TimestampMillis(false)
524 }
525 (Some("local-timestamp-micros"), c @ Codec::Int64) => {
526 *c = Codec::TimestampMicros(false)
527 }
528 (Some("duration"), c @ Codec::Fixed(12)) => *c = Codec::Interval,
529 (Some("uuid"), c @ Codec::Utf8) => *c = Codec::Uuid,
530 (Some(logical), _) => {
531 field.metadata.insert("logicalType".into(), logical.into());
533 }
534 (None, _) => {}
535 }
536
537 if !t.attributes.additional.is_empty() {
538 for (k, v) in &t.attributes.additional {
539 field.metadata.insert(k.to_string(), v.to_string());
540 }
541 }
542 Ok(field)
543 }
544 }
545}
546
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::{
        Attributes, ComplexType, Fixed, PrimitiveType, Record, Schema, Type, TypeName,
    };
    use serde_json;
    use std::collections::HashMap;

    /// Builds a primitive schema annotated with `logical_type`.
    fn create_schema_with_logical_type(
        primitive_type: PrimitiveType,
        logical_type: &'static str,
    ) -> Schema<'static> {
        let attributes = Attributes {
            logical_type: Some(logical_type),
            additional: Default::default(),
        };

        Schema::Type(Type {
            r#type: TypeName::Primitive(primitive_type),
            attributes,
        })
    }

    /// Builds an anonymous-namespace fixed schema of `size` bytes annotated with `logical_type`.
    fn create_fixed_schema(size: usize, logical_type: &'static str) -> Schema<'static> {
        let attributes = Attributes {
            logical_type: Some(logical_type),
            additional: Default::default(),
        };

        Schema::Complex(ComplexType::Fixed(Fixed {
            name: "fixed_type",
            namespace: None,
            aliases: Vec::new(),
            size,
            attributes,
        }))
    }

    /// Builds a `bytes` schema with logical type `decimal` and the given precision / scale.
    fn create_bytes_decimal_schema(precision: u64, scale: u64) -> Schema<'static> {
        let mut attributes = Attributes {
            logical_type: Some("decimal"),
            additional: Default::default(),
        };
        attributes
            .additional
            .insert("precision".into(), serde_json::json!(precision));
        attributes
            .additional
            .insert("scale".into(), serde_json::json!(scale));

        Schema::Type(Type {
            r#type: TypeName::Primitive(PrimitiveType::Bytes),
            attributes,
        })
    }

    fn primitive(p: PrimitiveType) -> Schema<'static> {
        Schema::TypeName(TypeName::Primitive(p))
    }

    #[test]
    fn test_date_logical_type() {
        let schema = create_schema_with_logical_type(PrimitiveType::Int, "date");

        let mut resolver = Resolver::default();
        let result = make_data_type(&schema, None, &mut resolver, false).unwrap();

        assert!(matches!(result.codec, Codec::Date32));
    }

    #[test]
    fn test_time_millis_logical_type() {
        let schema = create_schema_with_logical_type(PrimitiveType::Int, "time-millis");

        let mut resolver = Resolver::default();
        let result = make_data_type(&schema, None, &mut resolver, false).unwrap();

        assert!(matches!(result.codec, Codec::TimeMillis));
    }

    #[test]
    fn test_time_micros_logical_type() {
        let schema = create_schema_with_logical_type(PrimitiveType::Long, "time-micros");

        let mut resolver = Resolver::default();
        let result = make_data_type(&schema, None, &mut resolver, false).unwrap();

        assert!(matches!(result.codec, Codec::TimeMicros));
    }

    #[test]
    fn test_timestamp_millis_logical_type() {
        let schema = create_schema_with_logical_type(PrimitiveType::Long, "timestamp-millis");

        let mut resolver = Resolver::default();
        let result = make_data_type(&schema, None, &mut resolver, false).unwrap();

        assert!(matches!(result.codec, Codec::TimestampMillis(true)));
    }

    #[test]
    fn test_timestamp_micros_logical_type() {
        let schema = create_schema_with_logical_type(PrimitiveType::Long, "timestamp-micros");

        let mut resolver = Resolver::default();
        let result = make_data_type(&schema, None, &mut resolver, false).unwrap();

        assert!(matches!(result.codec, Codec::TimestampMicros(true)));
    }

    #[test]
    fn test_local_timestamp_millis_logical_type() {
        let schema = create_schema_with_logical_type(PrimitiveType::Long, "local-timestamp-millis");

        let mut resolver = Resolver::default();
        let result = make_data_type(&schema, None, &mut resolver, false).unwrap();

        assert!(matches!(result.codec, Codec::TimestampMillis(false)));
    }

    #[test]
    fn test_local_timestamp_micros_logical_type() {
        let schema = create_schema_with_logical_type(PrimitiveType::Long, "local-timestamp-micros");

        let mut resolver = Resolver::default();
        let result = make_data_type(&schema, None, &mut resolver, false).unwrap();

        assert!(matches!(result.codec, Codec::TimestampMicros(false)));
    }

    #[test]
    fn test_uuid_logical_type() {
        let schema = create_schema_with_logical_type(PrimitiveType::String, "uuid");

        let mut resolver = Resolver::default();
        let result = make_data_type(&schema, None, &mut resolver, false).unwrap();

        assert!(matches!(result.codec, Codec::Uuid));
        assert_eq!(result.codec.data_type(), DataType::FixedSizeBinary(16));
    }

    #[test]
    fn test_interval_codec_data_type() {
        assert_eq!(
            Codec::Interval.data_type(),
            DataType::Interval(IntervalUnit::MonthDayNano)
        );
    }

    #[test]
    fn test_fixed_type_with_unknown_logical_type() {
        let schema = create_fixed_schema(16, "custom-fixed");

        let mut resolver = Resolver::default();
        let result = make_data_type(&schema, None, &mut resolver, false).unwrap();

        assert!(matches!(result.codec, Codec::Fixed(16)));
        assert_eq!(result.codec.data_type(), DataType::FixedSizeBinary(16));
    }

    #[test]
    fn test_decimal_bytes_logical_type() {
        let schema = create_bytes_decimal_schema(10, 2);

        let mut resolver = Resolver::default();
        let result = make_data_type(&schema, None, &mut resolver, false).unwrap();

        assert!(matches!(result.codec, Codec::Decimal(10, Some(2), None)));
        assert_eq!(result.codec.data_type(), DataType::Decimal128(10, 2));
    }

    #[test]
    fn test_decimal_data_type_width_selection() {
        assert_eq!(
            Codec::Decimal(10, Some(2), None).data_type(),
            DataType::Decimal128(10, 2)
        );
        assert_eq!(
            Codec::Decimal(40, Some(2), None).data_type(),
            DataType::Decimal256(40, 2)
        );
        assert_eq!(
            Codec::Decimal(10, Some(2), Some(8)).data_type(),
            DataType::Decimal128(10, 2)
        );
        assert_eq!(
            Codec::Decimal(10, Some(2), Some(32)).data_type(),
            DataType::Decimal256(10, 2)
        );
    }

    #[test]
    fn test_nullable_unions() {
        let mut resolver = Resolver::default();

        let null_first = Schema::Union(vec![
            primitive(PrimitiveType::Null),
            primitive(PrimitiveType::String),
        ]);
        let result = make_data_type(&null_first, None, &mut resolver, false).unwrap();
        assert!(matches!(result.nullability, Some(Nullability::NullFirst)));
        assert!(matches!(result.codec, Codec::Utf8));
        assert!(result.field_with_name("f").is_nullable());

        let null_second = Schema::Union(vec![
            primitive(PrimitiveType::Int),
            primitive(PrimitiveType::Null),
        ]);
        let result = make_data_type(&null_second, None, &mut resolver, false).unwrap();
        assert!(matches!(result.nullability, Some(Nullability::NullSecond)));
        assert!(matches!(result.codec, Codec::Int32));
    }

    #[test]
    fn test_unsupported_union_is_rejected() {
        let schema = Schema::Union(vec![
            primitive(PrimitiveType::Int),
            primitive(PrimitiveType::String),
            primitive(PrimitiveType::Null),
        ]);

        let mut resolver = Resolver::default();
        let result = make_data_type(&schema, None, &mut resolver, false);

        assert!(matches!(result, Err(ArrowError::NotYetImplemented(_))));
    }

    #[test]
    fn test_unknown_logical_type_added_to_metadata() {
        let schema = create_schema_with_logical_type(PrimitiveType::Int, "custom-type");

        let mut resolver = Resolver::default();
        let result = make_data_type(&schema, None, &mut resolver, false).unwrap();

        assert_eq!(
            result.metadata.get("logicalType"),
            Some(&"custom-type".to_string())
        );
    }

    #[test]
    fn test_string_with_utf8view_enabled() {
        let schema = primitive(PrimitiveType::String);

        let mut resolver = Resolver::default();
        let result = make_data_type(&schema, None, &mut resolver, true).unwrap();

        assert!(matches!(result.codec, Codec::Utf8View));
    }

    #[test]
    fn test_string_without_utf8view_enabled() {
        let schema = primitive(PrimitiveType::String);

        let mut resolver = Resolver::default();
        let result = make_data_type(&schema, None, &mut resolver, false).unwrap();

        assert!(matches!(result.codec, Codec::Utf8));
    }

    #[test]
    fn test_record_with_string_and_utf8view_enabled() {
        let field_schema = primitive(PrimitiveType::String);

        let avro_field = crate::schema::Field {
            name: "string_field",
            r#type: field_schema,
            default: None,
            doc: None,
        };

        let record = Record {
            name: "test_record",
            namespace: None,
            aliases: vec![],
            doc: None,
            fields: vec![avro_field],
            attributes: Attributes::default(),
        };

        let schema = Schema::Complex(ComplexType::Record(record));

        let mut resolver = Resolver::default();
        let result = make_data_type(&schema, None, &mut resolver, true).unwrap();

        if let Codec::Struct(fields) = &result.codec {
            let first_field_codec = &fields[0].data_type().codec;
            assert!(matches!(first_field_codec, Codec::Utf8View));
        } else {
            panic!("Expected Struct codec");
        }
    }
}