1use crate::schema::{Attributes, ComplexType, PrimitiveType, Record, Schema, TypeName};
19use arrow_schema::{
20 ArrowError, DataType, Field, FieldRef, Fields, IntervalUnit, SchemaBuilder, SchemaRef,
21 TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
22};
23use std::borrow::Cow;
24use std::collections::HashMap;
25use std::sync::Arc;
26
/// Position of the `null` branch inside a two-branch Avro union.
///
/// `make_data_type` attaches this to the type built from the non-null branch,
/// so the original branch order of the union is not lost.
#[derive(Clone, Copy, Debug)]
pub enum Nullability {
    /// `null` is the first branch of the union (index 0).
    NullFirst,
    /// `null` is the second branch of the union (index 1).
    NullSecond,
}
39
40#[derive(Debug, Clone)]
42pub struct AvroDataType {
43 nullability: Option<Nullability>,
44 metadata: HashMap<String, String>,
45 codec: Codec,
46}
47
48impl AvroDataType {
49 pub fn new(
51 codec: Codec,
52 metadata: HashMap<String, String>,
53 nullability: Option<Nullability>,
54 ) -> Self {
55 AvroDataType {
56 codec,
57 metadata,
58 nullability,
59 }
60 }
61
62 pub fn field_with_name(&self, name: &str) -> Field {
64 let d = self.codec.data_type();
65 Field::new(name, d, self.nullability.is_some()).with_metadata(self.metadata.clone())
66 }
67
68 pub fn codec(&self) -> &Codec {
73 &self.codec
74 }
75
76 pub fn nullability(&self) -> Option<Nullability> {
84 self.nullability
85 }
86}
87
88#[derive(Debug, Clone)]
90pub struct AvroField {
91 name: String,
92 data_type: AvroDataType,
93}
94
95impl AvroField {
96 pub fn field(&self) -> Field {
98 self.data_type.field_with_name(&self.name)
99 }
100
101 pub fn data_type(&self) -> &AvroDataType {
103 &self.data_type
104 }
105
106 pub fn with_utf8view(&self) -> Self {
115 let mut field = self.clone();
116 if let Codec::Utf8 = field.data_type.codec {
117 field.data_type.codec = Codec::Utf8View;
118 }
119 field
120 }
121
122 pub fn name(&self) -> &str {
127 &self.name
128 }
129}
130
131impl<'a> TryFrom<&Schema<'a>> for AvroField {
132 type Error = ArrowError;
133
134 fn try_from(schema: &Schema<'a>) -> Result<Self, Self::Error> {
135 match schema {
136 Schema::Complex(ComplexType::Record(r)) => {
137 let mut resolver = Resolver::default();
138 let data_type = make_data_type(schema, None, &mut resolver, false)?;
139 Ok(AvroField {
140 data_type,
141 name: r.name.to_string(),
142 })
143 }
144 _ => Err(ArrowError::ParseError(format!(
145 "Expected record got {schema:?}"
146 ))),
147 }
148 }
149}
150
151#[derive(Debug, Clone)]
155pub enum Codec {
156 Null,
158 Boolean,
160 Int32,
162 Int64,
164 Float32,
166 Float64,
168 Binary,
170 Utf8,
172 Utf8View,
177 Date32,
179 TimeMillis,
181 TimeMicros,
183 TimestampMillis(bool),
188 TimestampMicros(bool),
193 Fixed(i32),
196 Decimal(usize, Option<usize>, Option<usize>),
203 Uuid,
205 Enum(Arc<[String]>),
209 List(Arc<AvroDataType>),
211 Struct(Arc<[AvroField]>),
213 Map(Arc<AvroDataType>),
215 Interval,
217}
218
219impl Codec {
220 fn data_type(&self) -> DataType {
221 match self {
222 Self::Null => DataType::Null,
223 Self::Boolean => DataType::Boolean,
224 Self::Int32 => DataType::Int32,
225 Self::Int64 => DataType::Int64,
226 Self::Float32 => DataType::Float32,
227 Self::Float64 => DataType::Float64,
228 Self::Binary => DataType::Binary,
229 Self::Utf8 => DataType::Utf8,
230 Self::Utf8View => DataType::Utf8View,
231 Self::Date32 => DataType::Date32,
232 Self::TimeMillis => DataType::Time32(TimeUnit::Millisecond),
233 Self::TimeMicros => DataType::Time64(TimeUnit::Microsecond),
234 Self::TimestampMillis(is_utc) => {
235 DataType::Timestamp(TimeUnit::Millisecond, is_utc.then(|| "+00:00".into()))
236 }
237 Self::TimestampMicros(is_utc) => {
238 DataType::Timestamp(TimeUnit::Microsecond, is_utc.then(|| "+00:00".into()))
239 }
240 Self::Interval => DataType::Interval(IntervalUnit::MonthDayNano),
241 Self::Fixed(size) => DataType::FixedSizeBinary(*size),
242 Self::Decimal(precision, scale, size) => {
243 let p = *precision as u8;
244 let s = scale.unwrap_or(0) as i8;
245 let too_large_for_128 = match *size {
246 Some(sz) => sz > 16,
247 None => {
248 (p as usize) > DECIMAL128_MAX_PRECISION as usize
249 || (s as usize) > DECIMAL128_MAX_SCALE as usize
250 }
251 };
252 if too_large_for_128 {
253 DataType::Decimal256(p, s)
254 } else {
255 DataType::Decimal128(p, s)
256 }
257 }
258 Self::Uuid => DataType::FixedSizeBinary(16),
259 Self::Enum(_) => {
260 DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8))
261 }
262 Self::List(f) => {
263 DataType::List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME)))
264 }
265 Self::Struct(f) => DataType::Struct(f.iter().map(|x| x.field()).collect()),
266 Self::Map(value_type) => {
267 let val_dt = value_type.codec.data_type();
268 let val_field = Field::new("value", val_dt, value_type.nullability.is_some())
269 .with_metadata(value_type.metadata.clone());
270 DataType::Map(
271 Arc::new(Field::new(
272 "entries",
273 DataType::Struct(Fields::from(vec![
274 Field::new("key", DataType::Utf8, false),
275 val_field,
276 ])),
277 false,
278 )),
279 false,
280 )
281 }
282 }
283 }
284}
285
286impl From<PrimitiveType> for Codec {
287 fn from(value: PrimitiveType) -> Self {
288 match value {
289 PrimitiveType::Null => Self::Null,
290 PrimitiveType::Boolean => Self::Boolean,
291 PrimitiveType::Int => Self::Int32,
292 PrimitiveType::Long => Self::Int64,
293 PrimitiveType::Float => Self::Float32,
294 PrimitiveType::Double => Self::Float64,
295 PrimitiveType::Bytes => Self::Binary,
296 PrimitiveType::String => Self::Utf8,
297 }
298 }
299}
300
301fn parse_decimal_attributes(
302 attributes: &Attributes,
303 fallback_size: Option<usize>,
304 precision_required: bool,
305) -> Result<(usize, usize, Option<usize>), ArrowError> {
306 let precision = attributes
307 .additional
308 .get("precision")
309 .and_then(|v| v.as_u64())
310 .or(if precision_required { None } else { Some(10) })
311 .ok_or_else(|| ArrowError::ParseError("Decimal requires precision".to_string()))?
312 as usize;
313 let scale = attributes
314 .additional
315 .get("scale")
316 .and_then(|v| v.as_u64())
317 .unwrap_or(0) as usize;
318 let size = attributes
319 .additional
320 .get("size")
321 .and_then(|v| v.as_u64())
322 .map(|s| s as usize)
323 .or(fallback_size);
324 Ok((precision, scale, size))
325}
326
327impl Codec {
328 pub fn with_utf8view(self, use_utf8view: bool) -> Self {
349 if use_utf8view && matches!(self, Self::Utf8) {
350 Self::Utf8View
351 } else {
352 self
353 }
354 }
355}
356
357#[derive(Debug, Default)]
361struct Resolver<'a> {
362 map: HashMap<(&'a str, &'a str), AvroDataType>,
363}
364
365impl<'a> Resolver<'a> {
366 fn register(&mut self, name: &'a str, namespace: Option<&'a str>, schema: AvroDataType) {
367 self.map.insert((name, namespace.unwrap_or("")), schema);
368 }
369
370 fn resolve(&self, name: &str, namespace: Option<&'a str>) -> Result<AvroDataType, ArrowError> {
371 let (namespace, name) = name
372 .rsplit_once('.')
373 .unwrap_or_else(|| (namespace.unwrap_or(""), name));
374
375 self.map
376 .get(&(namespace, name))
377 .ok_or_else(|| ArrowError::ParseError(format!("Failed to resolve {namespace}.{name}")))
378 .cloned()
379 }
380}
381
382fn make_data_type<'a>(
395 schema: &Schema<'a>,
396 namespace: Option<&'a str>,
397 resolver: &mut Resolver<'a>,
398 use_utf8view: bool,
399) -> Result<AvroDataType, ArrowError> {
400 match schema {
401 Schema::TypeName(TypeName::Primitive(p)) => {
402 let codec: Codec = (*p).into();
403 let codec = codec.with_utf8view(use_utf8view);
404 Ok(AvroDataType {
405 nullability: None,
406 metadata: Default::default(),
407 codec,
408 })
409 }
410 Schema::TypeName(TypeName::Ref(name)) => resolver.resolve(name, namespace),
411 Schema::Union(f) => {
412 let null = f
414 .iter()
415 .position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)));
416 match (f.len() == 2, null) {
417 (true, Some(0)) => {
418 let mut field = make_data_type(&f[1], namespace, resolver, use_utf8view)?;
419 field.nullability = Some(Nullability::NullFirst);
420 Ok(field)
421 }
422 (true, Some(1)) => {
423 let mut field = make_data_type(&f[0], namespace, resolver, use_utf8view)?;
424 field.nullability = Some(Nullability::NullSecond);
425 Ok(field)
426 }
427 _ => Err(ArrowError::NotYetImplemented(format!(
428 "Union of {f:?} not currently supported"
429 ))),
430 }
431 }
432 Schema::Complex(c) => match c {
433 ComplexType::Record(r) => {
434 let namespace = r.namespace.or(namespace);
435 let fields = r
436 .fields
437 .iter()
438 .map(|field| {
439 Ok(AvroField {
440 name: field.name.to_string(),
441 data_type: make_data_type(
442 &field.r#type,
443 namespace,
444 resolver,
445 use_utf8view,
446 )?,
447 })
448 })
449 .collect::<Result<_, ArrowError>>()?;
450 let field = AvroDataType {
451 nullability: None,
452 codec: Codec::Struct(fields),
453 metadata: r.attributes.field_metadata(),
454 };
455 resolver.register(r.name, namespace, field.clone());
456 Ok(field)
457 }
458 ComplexType::Array(a) => {
459 let mut field =
460 make_data_type(a.items.as_ref(), namespace, resolver, use_utf8view)?;
461 Ok(AvroDataType {
462 nullability: None,
463 metadata: a.attributes.field_metadata(),
464 codec: Codec::List(Arc::new(field)),
465 })
466 }
467 ComplexType::Fixed(f) => {
468 let size = f.size.try_into().map_err(|e| {
469 ArrowError::ParseError(format!("Overflow converting size to i32: {e}"))
470 })?;
471 let md = f.attributes.field_metadata();
472 let field = match f.attributes.logical_type {
473 Some("decimal") => {
474 let (precision, scale, _) =
475 parse_decimal_attributes(&f.attributes, Some(size as usize), true)?;
476 AvroDataType {
477 nullability: None,
478 metadata: md,
479 codec: Codec::Decimal(precision, Some(scale), Some(size as usize)),
480 }
481 }
482 _ => AvroDataType {
483 nullability: None,
484 metadata: md,
485 codec: Codec::Fixed(size),
486 },
487 };
488 resolver.register(f.name, namespace, field.clone());
489 Ok(field)
490 }
491 ComplexType::Enum(e) => {
492 let namespace = e.namespace.or(namespace);
493 let symbols = e
494 .symbols
495 .iter()
496 .map(|s| s.to_string())
497 .collect::<Arc<[String]>>();
498
499 let mut metadata = e.attributes.field_metadata();
500 let symbols_json = serde_json::to_string(&e.symbols).map_err(|e| {
501 ArrowError::ParseError(format!("Failed to serialize enum symbols: {e}"))
502 })?;
503 metadata.insert("avro.enum.symbols".to_string(), symbols_json);
504 let field = AvroDataType {
505 nullability: None,
506 metadata,
507 codec: Codec::Enum(symbols),
508 };
509 resolver.register(e.name, namespace, field.clone());
510 Ok(field)
511 }
512 ComplexType::Map(m) => {
513 let val = make_data_type(&m.values, namespace, resolver, use_utf8view)?;
514 Ok(AvroDataType {
515 nullability: None,
516 metadata: m.attributes.field_metadata(),
517 codec: Codec::Map(Arc::new(val)),
518 })
519 }
520 },
521 Schema::Type(t) => {
522 let mut field = make_data_type(
523 &Schema::TypeName(t.r#type.clone()),
524 namespace,
525 resolver,
526 use_utf8view,
527 )?;
528
529 match (t.attributes.logical_type, &mut field.codec) {
531 (Some("decimal"), c @ Codec::Binary) => {
532 let (prec, sc, _) = parse_decimal_attributes(&t.attributes, None, false)?;
533 *c = Codec::Decimal(prec, Some(sc), None);
534 }
535 (Some("date"), c @ Codec::Int32) => *c = Codec::Date32,
536 (Some("time-millis"), c @ Codec::Int32) => *c = Codec::TimeMillis,
537 (Some("time-micros"), c @ Codec::Int64) => *c = Codec::TimeMicros,
538 (Some("timestamp-millis"), c @ Codec::Int64) => *c = Codec::TimestampMillis(true),
539 (Some("timestamp-micros"), c @ Codec::Int64) => *c = Codec::TimestampMicros(true),
540 (Some("local-timestamp-millis"), c @ Codec::Int64) => {
541 *c = Codec::TimestampMillis(false)
542 }
543 (Some("local-timestamp-micros"), c @ Codec::Int64) => {
544 *c = Codec::TimestampMicros(false)
545 }
546 (Some("duration"), c @ Codec::Fixed(12)) => *c = Codec::Interval,
547 (Some("uuid"), c @ Codec::Utf8) => *c = Codec::Uuid,
548 (Some(logical), _) => {
549 field.metadata.insert("logicalType".into(), logical.into());
551 }
552 (None, _) => {}
553 }
554
555 if !t.attributes.additional.is_empty() {
556 for (k, v) in &t.attributes.additional {
557 field.metadata.insert(k.to_string(), v.to_string());
558 }
559 }
560 Ok(field)
561 }
562 }
563}
564
#[cfg(test)]
mod tests {
    //! Tests for the Avro schema to codec conversion.
    //!
    //! Every test drives the production code (`make_data_type` or
    //! `Codec::data_type`) rather than re-implementing the mapping locally.

    use super::*;
    use crate::schema::{
        Attributes, ComplexType, Fixed, PrimitiveType, Record, Schema, Type, TypeName,
    };

    /// Builds a primitive schema annotated with `logical_type` and no other attributes.
    fn create_schema_with_logical_type(
        primitive_type: PrimitiveType,
        logical_type: &'static str,
    ) -> Schema<'static> {
        let attributes = Attributes {
            logical_type: Some(logical_type),
            additional: Default::default(),
        };

        Schema::Type(Type {
            r#type: TypeName::Primitive(primitive_type),
            attributes,
        })
    }

    /// Builds a fixed schema of `size` bytes annotated with `logical_type` and
    /// no other attributes.
    fn create_fixed_schema(size: usize, logical_type: &'static str) -> Schema<'static> {
        let attributes = Attributes {
            logical_type: Some(logical_type),
            additional: Default::default(),
        };

        Schema::Complex(ComplexType::Fixed(Fixed {
            name: "fixed_type",
            namespace: None,
            aliases: Vec::new(),
            size,
            attributes,
        }))
    }

    /// Converts `schema` with a fresh resolver and no enclosing namespace.
    fn convert(schema: &Schema<'static>, use_utf8view: bool) -> Result<AvroDataType, ArrowError> {
        let mut resolver = Resolver::default();
        make_data_type(schema, None, &mut resolver, use_utf8view)
    }

    #[test]
    fn test_date_logical_type() {
        let schema = create_schema_with_logical_type(PrimitiveType::Int, "date");
        let result = convert(&schema, false).unwrap();
        assert!(matches!(result.codec, Codec::Date32));
    }

    #[test]
    fn test_time_millis_logical_type() {
        let schema = create_schema_with_logical_type(PrimitiveType::Int, "time-millis");
        let result = convert(&schema, false).unwrap();
        assert!(matches!(result.codec, Codec::TimeMillis));
    }

    #[test]
    fn test_time_micros_logical_type() {
        let schema = create_schema_with_logical_type(PrimitiveType::Long, "time-micros");
        let result = convert(&schema, false).unwrap();
        assert!(matches!(result.codec, Codec::TimeMicros));
    }

    #[test]
    fn test_timestamp_millis_logical_type() {
        let schema = create_schema_with_logical_type(PrimitiveType::Long, "timestamp-millis");
        let result = convert(&schema, false).unwrap();
        assert!(matches!(result.codec, Codec::TimestampMillis(true)));
    }

    #[test]
    fn test_timestamp_micros_logical_type() {
        let schema = create_schema_with_logical_type(PrimitiveType::Long, "timestamp-micros");
        let result = convert(&schema, false).unwrap();
        assert!(matches!(result.codec, Codec::TimestampMicros(true)));
    }

    #[test]
    fn test_local_timestamp_millis_logical_type() {
        let schema = create_schema_with_logical_type(PrimitiveType::Long, "local-timestamp-millis");
        let result = convert(&schema, false).unwrap();
        assert!(matches!(result.codec, Codec::TimestampMillis(false)));
    }

    #[test]
    fn test_local_timestamp_micros_logical_type() {
        let schema = create_schema_with_logical_type(PrimitiveType::Long, "local-timestamp-micros");
        let result = convert(&schema, false).unwrap();
        assert!(matches!(result.codec, Codec::TimestampMicros(false)));
    }

    #[test]
    fn test_uuid_type() {
        let schema = create_schema_with_logical_type(PrimitiveType::String, "uuid");
        let result = convert(&schema, false).unwrap();
        assert!(matches!(result.codec, Codec::Uuid));
        assert_eq!(result.codec.data_type(), DataType::FixedSizeBinary(16));
    }

    #[test]
    fn test_interval_codec_data_type() {
        assert_eq!(
            Codec::Interval.data_type(),
            DataType::Interval(IntervalUnit::MonthDayNano)
        );
    }

    #[test]
    fn test_fixed_with_unknown_logical_type_stays_fixed() {
        let schema = create_fixed_schema(16, "custom-type");
        let result = convert(&schema, false).unwrap();
        assert!(matches!(result.codec, Codec::Fixed(16)));
    }

    #[test]
    fn test_fixed_decimal_requires_precision() {
        // The helper adds no `precision` attribute, which a fixed decimal requires.
        let schema = create_fixed_schema(16, "decimal");
        let err = convert(&schema, false).unwrap_err();
        assert!(matches!(
            err,
            ArrowError::ParseError(ref msg) if msg.contains("Decimal requires precision")
        ));
    }

    #[test]
    fn test_bytes_decimal_uses_default_precision_and_scale() {
        let schema = create_schema_with_logical_type(PrimitiveType::Bytes, "decimal");
        let result = convert(&schema, false).unwrap();
        assert!(matches!(result.codec, Codec::Decimal(10, Some(0), None)));
    }

    #[test]
    fn test_unknown_logical_type_added_to_metadata() {
        let schema = create_schema_with_logical_type(PrimitiveType::Int, "custom-type");
        let result = convert(&schema, false).unwrap();

        assert_eq!(
            result.metadata.get("logicalType"),
            Some(&"custom-type".to_string())
        );
    }

    #[test]
    fn test_string_with_utf8view_enabled() {
        let schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::String));
        let result = convert(&schema, true).unwrap();
        assert!(matches!(result.codec, Codec::Utf8View));
    }

    #[test]
    fn test_string_without_utf8view_enabled() {
        let schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::String));
        let result = convert(&schema, false).unwrap();
        assert!(matches!(result.codec, Codec::Utf8));
    }

    #[test]
    fn test_record_with_string_and_utf8view_enabled() {
        let field_schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::String));

        let avro_field = crate::schema::Field {
            name: "string_field",
            r#type: field_schema,
            default: None,
            doc: None,
        };

        let record = Record {
            name: "test_record",
            namespace: None,
            aliases: vec![],
            doc: None,
            fields: vec![avro_field],
            attributes: Attributes::default(),
        };

        let schema = Schema::Complex(ComplexType::Record(record));
        let result = convert(&schema, true).unwrap();

        if let Codec::Struct(fields) = &result.codec {
            let first_field_codec = &fields[0].data_type().codec;
            assert!(matches!(first_field_codec, Codec::Utf8View));
        } else {
            panic!("Expected Struct codec");
        }
    }
}