1use crate::codec::AvroFieldBuilder;
61use crate::compression::CompressionCodec;
62use crate::schema::{
63 AvroSchema, Fingerprint, FingerprintAlgorithm, FingerprintStrategy, SCHEMA_METADATA_KEY,
64};
65use crate::writer::encoder::{write_long, RecordEncoder, RecordEncoderBuilder};
66use crate::writer::format::{AvroBinaryFormat, AvroFormat, AvroOcfFormat};
67use arrow_array::RecordBatch;
68use arrow_schema::{ArrowError, Schema};
69use std::io::Write;
70use std::sync::Arc;
71
72pub mod encoder;
74pub mod format;
76
77#[derive(Debug, Clone)]
79pub struct WriterBuilder {
80 schema: Schema,
81 codec: Option<CompressionCodec>,
82 capacity: usize,
83 fingerprint_strategy: Option<FingerprintStrategy>,
84}
85
86impl WriterBuilder {
87 pub fn new(schema: Schema) -> Self {
94 Self {
95 schema,
96 codec: None,
97 capacity: 1024,
98 fingerprint_strategy: None,
99 }
100 }
101
102 pub fn with_fingerprint_strategy(mut self, strategy: FingerprintStrategy) -> Self {
105 self.fingerprint_strategy = Some(strategy);
106 self
107 }
108
109 pub fn with_compression(mut self, codec: Option<CompressionCodec>) -> Self {
111 self.codec = codec;
112 self
113 }
114
115 pub fn with_capacity(mut self, capacity: usize) -> Self {
117 self.capacity = capacity;
118 self
119 }
120
121 pub fn build<W, F>(self, mut writer: W) -> Result<Writer<W, F>, ArrowError>
124 where
125 W: Write,
126 F: AvroFormat,
127 {
128 let mut format = F::default();
129 let avro_schema = match self.schema.metadata.get(SCHEMA_METADATA_KEY) {
130 Some(json) => AvroSchema::new(json.clone()),
131 None => AvroSchema::try_from(&self.schema)?,
132 };
133 let maybe_fingerprint = if F::NEEDS_PREFIX {
134 match self.fingerprint_strategy {
135 Some(FingerprintStrategy::Id(id)) => Some(Fingerprint::Id(id)),
136 Some(strategy) => {
137 Some(avro_schema.fingerprint(FingerprintAlgorithm::from(strategy))?)
138 }
139 None => Some(
140 avro_schema
141 .fingerprint(FingerprintAlgorithm::from(FingerprintStrategy::Rabin))?,
142 ),
143 }
144 } else {
145 None
146 };
147 let mut md = self.schema.metadata().clone();
148 md.insert(
149 SCHEMA_METADATA_KEY.to_string(),
150 avro_schema.clone().json_string,
151 );
152 let schema = Arc::new(Schema::new_with_metadata(self.schema.fields().clone(), md));
153 format.start_stream(&mut writer, &schema, self.codec)?;
154 let avro_root = AvroFieldBuilder::new(&avro_schema.schema()?).build()?;
155 let encoder = RecordEncoderBuilder::new(&avro_root, schema.as_ref())
156 .with_fingerprint(maybe_fingerprint)
157 .build()?;
158 Ok(Writer {
159 writer,
160 schema,
161 format,
162 compression: self.codec,
163 capacity: self.capacity,
164 encoder,
165 })
166 }
167}
168
169#[derive(Debug)]
177pub struct Writer<W: Write, F: AvroFormat> {
178 writer: W,
179 schema: Arc<Schema>,
180 format: F,
181 compression: Option<CompressionCodec>,
182 capacity: usize,
183 encoder: RecordEncoder,
184}
185
186pub type AvroWriter<W> = Writer<W, AvroOcfFormat>;
227
228pub type AvroStreamWriter<W> = Writer<W, AvroBinaryFormat>;
259
260impl<W: Write> Writer<W, AvroOcfFormat> {
261 pub fn new(writer: W, schema: Schema) -> Result<Self, ArrowError> {
287 WriterBuilder::new(schema).build::<W, AvroOcfFormat>(writer)
288 }
289
290 pub fn sync_marker(&self) -> Option<&[u8; 16]> {
292 self.format.sync_marker()
293 }
294}
295
296impl<W: Write> Writer<W, AvroBinaryFormat> {
297 pub fn new(writer: W, schema: Schema) -> Result<Self, ArrowError> {
326 WriterBuilder::new(schema).build::<W, AvroBinaryFormat>(writer)
327 }
328}
329
330impl<W: Write, F: AvroFormat> Writer<W, F> {
331 pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
333 if batch.schema().fields() != self.schema.fields() {
334 return Err(ArrowError::SchemaError(
335 "Schema of RecordBatch differs from Writer schema".to_string(),
336 ));
337 }
338 match self.format.sync_marker() {
339 Some(&sync) => self.write_ocf_block(batch, &sync),
340 None => self.write_stream(batch),
341 }
342 }
343
344 pub fn write_batches(&mut self, batches: &[&RecordBatch]) -> Result<(), ArrowError> {
348 for b in batches {
349 self.write(b)?;
350 }
351 Ok(())
352 }
353
354 pub fn finish(&mut self) -> Result<(), ArrowError> {
356 self.writer
357 .flush()
358 .map_err(|e| ArrowError::IoError(format!("Error flushing writer: {e}"), e))
359 }
360
361 pub fn into_inner(self) -> W {
363 self.writer
364 }
365
366 fn write_ocf_block(&mut self, batch: &RecordBatch, sync: &[u8; 16]) -> Result<(), ArrowError> {
367 let mut buf = Vec::<u8>::with_capacity(1024);
368 self.encoder.encode(&mut buf, batch)?;
369 let encoded = match self.compression {
370 Some(codec) => codec.compress(&buf)?,
371 None => buf,
372 };
373 write_long(&mut self.writer, batch.num_rows() as i64)?;
374 write_long(&mut self.writer, encoded.len() as i64)?;
375 self.writer
376 .write_all(&encoded)
377 .map_err(|e| ArrowError::IoError(format!("Error writing Avro block: {e}"), e))?;
378 self.writer
379 .write_all(sync)
380 .map_err(|e| ArrowError::IoError(format!("Error writing Avro sync: {e}"), e))?;
381 Ok(())
382 }
383
384 fn write_stream(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
385 self.encoder.encode(&mut self.writer, batch)?;
386 Ok(())
387 }
388}
389
390#[cfg(test)]
391mod tests {
392 use super::*;
393 use crate::compression::CompressionCodec;
394 use crate::reader::ReaderBuilder;
395 use crate::schema::{AvroSchema, SchemaStore, CONFLUENT_MAGIC};
396 use crate::test_util::arrow_test_data;
397 use arrow_array::{ArrayRef, BinaryArray, DurationSecondArray, Int32Array, RecordBatch};
398 use arrow_schema::{DataType, Field, IntervalUnit, Schema, TimeUnit};
399 use std::collections::HashSet;
400 use std::fs::File;
401 use std::io::{BufReader, Cursor};
402 use std::path::PathBuf;
403 use std::sync::Arc;
404 use tempfile::NamedTempFile;
405
406 fn make_schema() -> Schema {
407 Schema::new(vec![
408 Field::new("id", DataType::Int32, false),
409 Field::new("name", DataType::Binary, false),
410 ])
411 }
412
413 fn make_batch() -> RecordBatch {
414 let ids = Int32Array::from(vec![1, 2, 3]);
415 let names = BinaryArray::from_vec(vec![b"a".as_ref(), b"b".as_ref(), b"c".as_ref()]);
416 RecordBatch::try_new(
417 Arc::new(make_schema()),
418 vec![Arc::new(ids) as ArrayRef, Arc::new(names) as ArrayRef],
419 )
420 .expect("failed to build test RecordBatch")
421 }
422
423 #[test]
424 fn test_stream_writer_writes_prefix_per_row_rt() -> Result<(), ArrowError> {
425 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
426 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
427 let batch = RecordBatch::try_new(
428 Arc::new(schema.clone()),
429 vec![Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef],
430 )?;
431 let buf: Vec<u8> = Vec::new();
432 let mut writer = AvroStreamWriter::new(buf, schema.clone())?;
433 writer.write(&batch)?;
434 let encoded = writer.into_inner();
435 let mut store = SchemaStore::new(); let avro_schema = AvroSchema::try_from(&schema)?;
437 let _fp = store.register(avro_schema)?;
438 let mut decoder = ReaderBuilder::new()
439 .with_writer_schema_store(store)
440 .build_decoder()?;
441 let _consumed = decoder.decode(&encoded)?;
442 let decoded = decoder
443 .flush()?
444 .expect("expected at least one batch from decoder");
445 assert_eq!(decoded.num_columns(), 1);
446 assert_eq!(decoded.num_rows(), 2);
447 let col = decoded
448 .column(0)
449 .as_any()
450 .downcast_ref::<Int32Array>()
451 .expect("int column");
452 assert_eq!(col, &Int32Array::from(vec![10, 20]));
453 Ok(())
454 }
455
456 #[test]
457 fn test_stream_writer_with_id_fingerprint_rt() -> Result<(), ArrowError> {
458 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
459 let batch = RecordBatch::try_new(
460 Arc::new(schema.clone()),
461 vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
462 )?;
463 let schema_id: u32 = 42;
464 let mut writer = WriterBuilder::new(schema.clone())
465 .with_fingerprint_strategy(FingerprintStrategy::Id(schema_id))
466 .build::<_, AvroBinaryFormat>(Vec::new())?;
467 writer.write(&batch)?;
468 let encoded = writer.into_inner();
469 let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::None);
470 let avro_schema = AvroSchema::try_from(&schema)?;
471 let _ = store.set(Fingerprint::Id(schema_id), avro_schema)?;
472 let mut decoder = ReaderBuilder::new()
473 .with_writer_schema_store(store)
474 .build_decoder()?;
475 let _ = decoder.decode(&encoded)?;
476 let decoded = decoder
477 .flush()?
478 .expect("expected at least one batch from decoder");
479 assert_eq!(decoded.num_columns(), 1);
480 assert_eq!(decoded.num_rows(), 3);
481 let col = decoded
482 .column(0)
483 .as_any()
484 .downcast_ref::<Int32Array>()
485 .expect("int column");
486 assert_eq!(col, &Int32Array::from(vec![1, 2, 3]));
487 Ok(())
488 }
489
490 #[test]
491 fn test_ocf_writer_generates_header_and_sync() -> Result<(), ArrowError> {
492 let batch = make_batch();
493 let buffer: Vec<u8> = Vec::new();
494 let mut writer = AvroWriter::new(buffer, make_schema())?;
495 writer.write(&batch)?;
496 writer.finish()?;
497 let out = writer.into_inner();
498 assert_eq!(&out[..4], b"Obj\x01", "OCF magic bytes missing/incorrect");
499 let trailer = &out[out.len() - 16..];
500 assert_eq!(trailer.len(), 16, "expected 16‑byte sync marker");
501 Ok(())
502 }
503
504 #[test]
505 fn test_schema_mismatch_yields_error() {
506 let batch = make_batch();
507 let alt_schema = Schema::new(vec![Field::new("x", DataType::Int32, false)]);
508 let buffer = Vec::<u8>::new();
509 let mut writer = AvroWriter::new(buffer, alt_schema).unwrap();
510 let err = writer.write(&batch).unwrap_err();
511 assert!(matches!(err, ArrowError::SchemaError(_)));
512 }
513
514 #[test]
515 fn test_write_batches_accumulates_multiple() -> Result<(), ArrowError> {
516 let batch1 = make_batch();
517 let batch2 = make_batch();
518 let buffer = Vec::<u8>::new();
519 let mut writer = AvroWriter::new(buffer, make_schema())?;
520 writer.write_batches(&[&batch1, &batch2])?;
521 writer.finish()?;
522 let out = writer.into_inner();
523 assert!(out.len() > 4, "combined batches produced tiny file");
524 Ok(())
525 }
526
527 #[test]
528 fn test_finish_without_write_adds_header() -> Result<(), ArrowError> {
529 let buffer = Vec::<u8>::new();
530 let mut writer = AvroWriter::new(buffer, make_schema())?;
531 writer.finish()?;
532 let out = writer.into_inner();
533 assert_eq!(&out[..4], b"Obj\x01", "finish() should emit OCF header");
534 Ok(())
535 }
536
537 #[test]
538 fn test_write_long_encodes_zigzag_varint() -> Result<(), ArrowError> {
539 let mut buf = Vec::new();
540 write_long(&mut buf, 0)?;
541 write_long(&mut buf, -1)?;
542 write_long(&mut buf, 1)?;
543 write_long(&mut buf, -2)?;
544 write_long(&mut buf, 2147483647)?;
545 assert!(
546 buf.starts_with(&[0x00, 0x01, 0x02, 0x03]),
547 "zig‑zag varint encodings incorrect: {buf:?}"
548 );
549 Ok(())
550 }
551
552 #[test]
553 fn test_roundtrip_alltypes_roundtrip_writer() -> Result<(), ArrowError> {
554 let files = [
555 "avro/alltypes_plain.avro",
556 "avro/alltypes_plain.snappy.avro",
557 "avro/alltypes_plain.zstandard.avro",
558 "avro/alltypes_plain.bzip2.avro",
559 "avro/alltypes_plain.xz.avro",
560 ];
561 for rel in files {
562 let path = arrow_test_data(rel);
563 let rdr_file = File::open(&path).expect("open input avro");
564 let mut reader = ReaderBuilder::new()
565 .build(BufReader::new(rdr_file))
566 .expect("build reader");
567 let schema = reader.schema();
568 let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
569 let original =
570 arrow::compute::concat_batches(&schema, &input_batches).expect("concat input");
571 let tmp = NamedTempFile::new().expect("create temp file");
572 let out_path = tmp.into_temp_path();
573 let out_file = File::create(&out_path).expect("create temp avro");
574 let codec = if rel.contains(".snappy.") {
575 Some(CompressionCodec::Snappy)
576 } else if rel.contains(".zstandard.") {
577 Some(CompressionCodec::ZStandard)
578 } else if rel.contains(".bzip2.") {
579 Some(CompressionCodec::Bzip2)
580 } else if rel.contains(".xz.") {
581 Some(CompressionCodec::Xz)
582 } else {
583 None
584 };
585 let mut writer = WriterBuilder::new(original.schema().as_ref().clone())
586 .with_compression(codec)
587 .build::<_, AvroOcfFormat>(out_file)?;
588 writer.write(&original)?;
589 writer.finish()?;
590 drop(writer);
591 let rt_file = File::open(&out_path).expect("open roundtrip avro");
592 let mut rt_reader = ReaderBuilder::new()
593 .build(BufReader::new(rt_file))
594 .expect("build roundtrip reader");
595 let rt_schema = rt_reader.schema();
596 let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
597 let roundtrip =
598 arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip");
599 assert_eq!(
600 roundtrip, original,
601 "Round-trip batch mismatch for file: {}",
602 rel
603 );
604 }
605 Ok(())
606 }
607
608 #[test]
609 fn test_roundtrip_nested_records_writer() -> Result<(), ArrowError> {
610 let path = arrow_test_data("avro/nested_records.avro");
611 let rdr_file = File::open(&path).expect("open nested_records.avro");
612 let mut reader = ReaderBuilder::new()
613 .build(BufReader::new(rdr_file))
614 .expect("build reader for nested_records.avro");
615 let schema = reader.schema();
616 let batches = reader.collect::<Result<Vec<_>, _>>()?;
617 let original = arrow::compute::concat_batches(&schema, &batches).expect("concat original");
618 let tmp = NamedTempFile::new().expect("create temp file");
619 let out_path = tmp.into_temp_path();
620 {
621 let out_file = File::create(&out_path).expect("create output avro");
622 let mut writer = AvroWriter::new(out_file, original.schema().as_ref().clone())?;
623 writer.write(&original)?;
624 writer.finish()?;
625 }
626 let rt_file = File::open(&out_path).expect("open round_trip avro");
627 let mut rt_reader = ReaderBuilder::new()
628 .build(BufReader::new(rt_file))
629 .expect("build round_trip reader");
630 let rt_schema = rt_reader.schema();
631 let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
632 let round_trip =
633 arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip");
634 assert_eq!(
635 round_trip, original,
636 "Round-trip batch mismatch for nested_records.avro"
637 );
638 Ok(())
639 }
640
641 #[test]
642 fn test_roundtrip_nested_lists_writer() -> Result<(), ArrowError> {
643 let path = arrow_test_data("avro/nested_lists.snappy.avro");
644 let rdr_file = File::open(&path).expect("open nested_lists.snappy.avro");
645 let mut reader = ReaderBuilder::new()
646 .build(BufReader::new(rdr_file))
647 .expect("build reader for nested_lists.snappy.avro");
648 let schema = reader.schema();
649 let batches = reader.collect::<Result<Vec<_>, _>>()?;
650 let original = arrow::compute::concat_batches(&schema, &batches).expect("concat original");
651 let tmp = NamedTempFile::new().expect("create temp file");
652 let out_path = tmp.into_temp_path();
653 {
654 let out_file = File::create(&out_path).expect("create output avro");
655 let mut writer = WriterBuilder::new(original.schema().as_ref().clone())
656 .with_compression(Some(CompressionCodec::Snappy))
657 .build::<_, AvroOcfFormat>(out_file)?;
658 writer.write(&original)?;
659 writer.finish()?;
660 }
661 let rt_file = File::open(&out_path).expect("open round_trip avro");
662 let mut rt_reader = ReaderBuilder::new()
663 .build(BufReader::new(rt_file))
664 .expect("build round_trip reader");
665 let rt_schema = rt_reader.schema();
666 let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
667 let round_trip =
668 arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip");
669 assert_eq!(
670 round_trip, original,
671 "Round-trip batch mismatch for nested_lists.snappy.avro"
672 );
673 Ok(())
674 }
675
676 #[test]
677 fn test_round_trip_simple_fixed_ocf() -> Result<(), ArrowError> {
678 let path = arrow_test_data("avro/simple_fixed.avro");
679 let rdr_file = File::open(&path).expect("open avro/simple_fixed.avro");
680 let mut reader = ReaderBuilder::new()
681 .build(BufReader::new(rdr_file))
682 .expect("build avro reader");
683 let schema = reader.schema();
684 let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
685 let original =
686 arrow::compute::concat_batches(&schema, &input_batches).expect("concat input");
687 let tmp = NamedTempFile::new().expect("create temp file");
688 let out_file = File::create(tmp.path()).expect("create temp avro");
689 let mut writer = AvroWriter::new(out_file, original.schema().as_ref().clone())?;
690 writer.write(&original)?;
691 writer.finish()?;
692 drop(writer);
693 let rt_file = File::open(tmp.path()).expect("open round_trip avro");
694 let mut rt_reader = ReaderBuilder::new()
695 .build(BufReader::new(rt_file))
696 .expect("build round_trip reader");
697 let rt_schema = rt_reader.schema();
698 let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
699 let round_trip =
700 arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip");
701 assert_eq!(round_trip, original);
702 Ok(())
703 }
704
705 #[cfg(not(feature = "canonical_extension_types"))]
706 #[test]
707 fn test_round_trip_duration_and_uuid_ocf() -> Result<(), ArrowError> {
708 let in_file =
709 File::open("test/data/duration_uuid.avro").expect("open test/data/duration_uuid.avro");
710 let mut reader = ReaderBuilder::new()
711 .build(BufReader::new(in_file))
712 .expect("build reader for duration_uuid.avro");
713 let in_schema = reader.schema();
714 let has_mdn = in_schema.fields().iter().any(|f| {
715 matches!(
716 f.data_type(),
717 DataType::Interval(IntervalUnit::MonthDayNano)
718 )
719 });
720 assert!(
721 has_mdn,
722 "expected at least one Interval(MonthDayNano) field in duration_uuid.avro"
723 );
724 let has_uuid_fixed = in_schema
725 .fields()
726 .iter()
727 .any(|f| matches!(f.data_type(), DataType::FixedSizeBinary(16)));
728 assert!(
729 has_uuid_fixed,
730 "expected at least one FixedSizeBinary(16) (uuid) field in duration_uuid.avro"
731 );
732 let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
733 let input =
734 arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input");
735 let tmp = NamedTempFile::new().expect("create temp file");
736 {
737 let out_file = File::create(tmp.path()).expect("create temp avro");
738 let mut writer = AvroWriter::new(out_file, in_schema.as_ref().clone())?;
739 writer.write(&input)?;
740 writer.finish()?;
741 }
742 let rt_file = File::open(tmp.path()).expect("open round_trip avro");
743 let mut rt_reader = ReaderBuilder::new()
744 .build(BufReader::new(rt_file))
745 .expect("build round_trip reader");
746 let rt_schema = rt_reader.schema();
747 let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
748 let round_trip =
749 arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip");
750 assert_eq!(round_trip, input);
751 Ok(())
752 }
753
754 #[test]
758 fn test_nonnullable_impala_roundtrip_writer() -> Result<(), ArrowError> {
759 let path = arrow_test_data("avro/nonnullable.impala.avro");
761 let rdr_file = File::open(&path).expect("open avro/nonnullable.impala.avro");
762 let mut reader = ReaderBuilder::new()
763 .build(BufReader::new(rdr_file))
764 .expect("build reader for nonnullable.impala.avro");
765 let in_schema = reader.schema();
767 let has_map = in_schema
769 .fields()
770 .iter()
771 .any(|f| matches!(f.data_type(), DataType::Map(_, _)));
772 assert!(
773 has_map,
774 "expected at least one Map field in avro/nonnullable.impala.avro"
775 );
776
777 let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
778 let original =
779 arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input");
780 let buffer = Vec::<u8>::new();
782 let mut writer = AvroWriter::new(buffer, in_schema.as_ref().clone())?;
783 writer.write(&original)?;
784 writer.finish()?;
785 let out_bytes = writer.into_inner();
786 let mut rt_reader = ReaderBuilder::new()
788 .build(Cursor::new(out_bytes))
789 .expect("build reader for round-tripped in-memory OCF");
790 let rt_schema = rt_reader.schema();
791 let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
792 let roundtrip =
793 arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip");
794 assert_eq!(
796 roundtrip, original,
797 "Round-trip Avro map data mismatch for nonnullable.impala.avro"
798 );
799 Ok(())
800 }
801
802 #[test]
803 fn test_roundtrip_decimals_via_writer() -> Result<(), ArrowError> {
804 let files: [(&str, bool); 8] = [
806 ("avro/fixed_length_decimal.avro", true), ("avro/fixed_length_decimal_legacy.avro", true), ("avro/int32_decimal.avro", true), ("avro/int64_decimal.avro", true), ("test/data/int256_decimal.avro", false), ("test/data/fixed256_decimal.avro", false), ("test/data/fixed_length_decimal_legacy_32.avro", false), ("test/data/int128_decimal.avro", false), ];
815 for (rel, in_test_data_dir) in files {
816 let path: String = if in_test_data_dir {
818 arrow_test_data(rel)
819 } else {
820 PathBuf::from(env!("CARGO_MANIFEST_DIR"))
821 .join(rel)
822 .to_string_lossy()
823 .into_owned()
824 };
825 let f_in = File::open(&path).expect("open input avro");
827 let mut rdr = ReaderBuilder::new().build(BufReader::new(f_in))?;
828 let in_schema = rdr.schema();
829 let in_batches = rdr.collect::<Result<Vec<_>, _>>()?;
830 let original =
831 arrow::compute::concat_batches(&in_schema, &in_batches).expect("concat input");
832 let tmp = NamedTempFile::new().expect("create temp file");
834 let out_path = tmp.into_temp_path();
835 let out_file = File::create(&out_path).expect("create temp avro");
836 let mut writer = AvroWriter::new(out_file, original.schema().as_ref().clone())?;
837 writer.write(&original)?;
838 writer.finish()?;
839 let f_rt = File::open(&out_path).expect("open roundtrip avro");
841 let mut rt_rdr = ReaderBuilder::new().build(BufReader::new(f_rt))?;
842 let rt_schema = rt_rdr.schema();
843 let rt_batches = rt_rdr.collect::<Result<Vec<_>, _>>()?;
844 let roundtrip =
845 arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat rt");
846 assert_eq!(roundtrip, original, "decimal round-trip mismatch for {rel}");
847 }
848 Ok(())
849 }
850
851 #[test]
852 fn test_enum_roundtrip_uses_reader_fixture() -> Result<(), ArrowError> {
853 let path = arrow_test_data("avro/simple_enum.avro");
855 let rdr_file = File::open(&path).expect("open avro/simple_enum.avro");
856 let mut reader = ReaderBuilder::new()
857 .build(BufReader::new(rdr_file))
858 .expect("build reader for simple_enum.avro");
859 let in_schema = reader.schema();
861 let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
862 let original =
863 arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input");
864 let has_enum_dict = in_schema.fields().iter().any(|f| {
866 matches!(
867 f.data_type(),
868 DataType::Dictionary(k, v) if **k == DataType::Int32 && **v == DataType::Utf8
869 )
870 });
871 assert!(
872 has_enum_dict,
873 "Expected at least one enum-mapped Dictionary<Int32, Utf8> field"
874 );
875 let buffer: Vec<u8> = Vec::new();
878 let mut writer = AvroWriter::new(buffer, in_schema.as_ref().clone())?;
879 writer.write(&original)?;
880 writer.finish()?;
881 let bytes = writer.into_inner();
882 let mut rt_reader = ReaderBuilder::new()
884 .build(Cursor::new(bytes))
885 .expect("reader for round-trip");
886 let rt_schema = rt_reader.schema();
887 let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
888 let roundtrip =
889 arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip");
890 assert_eq!(roundtrip, original, "Avro enum round-trip mismatch");
891 Ok(())
892 }
893
894 #[test]
895 #[cfg(feature = "avro_custom_types")]
896 fn test_roundtrip_duration_logical_types_ocf() -> Result<(), ArrowError> {
897 let file_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
898 .join("test/data/duration_logical_types.avro")
899 .to_string_lossy()
900 .into_owned();
901
902 let in_file = File::open(&file_path)
903 .unwrap_or_else(|_| panic!("Failed to open test file: {}", file_path));
904
905 let mut reader = ReaderBuilder::new()
906 .build(BufReader::new(in_file))
907 .expect("build reader for duration_logical_types.avro");
908 let in_schema = reader.schema();
909
910 let expected_units: HashSet<TimeUnit> = [
911 TimeUnit::Nanosecond,
912 TimeUnit::Microsecond,
913 TimeUnit::Millisecond,
914 TimeUnit::Second,
915 ]
916 .into_iter()
917 .collect();
918
919 let found_units: HashSet<TimeUnit> = in_schema
920 .fields()
921 .iter()
922 .filter_map(|f| match f.data_type() {
923 DataType::Duration(unit) => Some(*unit),
924 _ => None,
925 })
926 .collect();
927
928 assert_eq!(
929 found_units, expected_units,
930 "Expected to find all four Duration TimeUnits in the schema from the initial read"
931 );
932
933 let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
934 let input =
935 arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input");
936
937 let tmp = NamedTempFile::new().expect("create temp file");
938 {
939 let out_file = File::create(tmp.path()).expect("create temp avro");
940 let mut writer = AvroWriter::new(out_file, in_schema.as_ref().clone())?;
941 writer.write(&input)?;
942 writer.finish()?;
943 }
944
945 let rt_file = File::open(tmp.path()).expect("open round_trip avro");
946 let mut rt_reader = ReaderBuilder::new()
947 .build(BufReader::new(rt_file))
948 .expect("build round_trip reader");
949 let rt_schema = rt_reader.schema();
950 let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
951 let round_trip =
952 arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip");
953
954 assert_eq!(round_trip, input);
955
956 Ok(())
957 }
958}