arrow_avro/writer/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Avro writer implementation for the `arrow-avro` crate.
19//!
20//! # Overview
21//!
22//! Use this module to serialize Arrow `RecordBatch` values into Avro. Two output
23//! formats are supported:
24//!
25//! * **[`AvroWriter`](crate::writer::AvroWriter)** — writes an **Object Container File (OCF)**: a self‑describing
26//!   file with header (schema JSON + metadata), optional compression, data blocks, and
27//!   sync markers. See Avro 1.11.1 “Object Container Files.”
28//!   <https://avro.apache.org/docs/1.11.1/specification/#object-container-files>
//! * **[`AvroStreamWriter`](crate::writer::AvroStreamWriter)** — writes a **raw Avro binary stream** (“datum” bytes) without
//!   any container framing. This is useful when the schema is known out‑of‑band (e.g.,
//!   via a registry) and you want minimal overhead.
32//!
33//! ## Which format should I use?
34//!
35//! * Use **OCF** when you need a portable, self‑contained file. The schema travels with
36//!   the data, making it easy to read elsewhere.
//! * Use the **raw stream** when your surrounding protocol supplies schema information
//!   (e.g., a schema registry). If you need **single‑object encoding (SOE)** or Confluent
39//!   **Schema Registry** framing, you must add the appropriate prefix *outside* this writer:
40//!   - **SOE**: `0xC3 0x01` + 8‑byte little‑endian CRC‑64‑AVRO fingerprint + Avro body
41//!     (see Avro 1.11.1 “Single object encoding”).
42//!     <https://avro.apache.org/docs/1.11.1/specification/#single-object-encoding>
43//!   - **Confluent wire format**: magic `0x00` + **big‑endian** 4‑byte schema ID and Avro body.
44//!     <https://docs.confluent.io/platform/current/schema-registry/fundamentals/serdes-develop/index.html#wire-format>
45//!
46//! ## Choosing the Avro schema
47//!
//! By default, the writer converts your Arrow schema to Avro (including a top‑level record
//! name) and stores the resulting JSON under the `avro::schema` metadata key. If you already
//! have an Avro schema JSON that you want to use verbatim, put it into the Arrow schema metadata
//! under the same key before constructing the writer. The builder will pick it up.
52//!
53//! ## Compression
54//!
55//! For OCF, you may enable a compression codec via `WriterBuilder::with_compression`. The
56//! chosen codec is written into the file header and used for subsequent blocks. Raw stream
57//! writing doesn’t apply container‑level compression.
58//!
59//! ---
60use crate::codec::AvroFieldBuilder;
61use crate::compression::CompressionCodec;
62use crate::schema::{
63    AvroSchema, Fingerprint, FingerprintAlgorithm, FingerprintStrategy, SCHEMA_METADATA_KEY,
64};
65use crate::writer::encoder::{write_long, RecordEncoder, RecordEncoderBuilder};
66use crate::writer::format::{AvroBinaryFormat, AvroFormat, AvroOcfFormat};
67use arrow_array::RecordBatch;
68use arrow_schema::{ArrowError, Schema};
69use std::io::Write;
70use std::sync::Arc;
71
72/// Encodes `RecordBatch` into the Avro binary format.
73pub mod encoder;
74/// Logic for different Avro container file formats.
75pub mod format;
76
/// Builder to configure and create a `Writer`.
///
/// Holds the Arrow schema together with the optional settings (compression codec,
/// capacity, fingerprint strategy) that `WriterBuilder::build` consumes.
#[derive(Debug, Clone)]
pub struct WriterBuilder {
    // Arrow schema of the batches to write; its metadata may carry an Avro schema JSON
    // under `SCHEMA_METADATA_KEY`.
    schema: Schema,
    // Compression codec passed to the format at stream start and stored on the writer;
    // `None` by default.
    codec: Option<CompressionCodec>,
    // Capacity value stored on the built writer; defaults to 1024.
    capacity: usize,
    // Strategy used to derive the fingerprint when the format needs a prefix;
    // `None` by default.
    fingerprint_strategy: Option<FingerprintStrategy>,
}
85
86impl WriterBuilder {
87    /// Create a new builder with default settings.
88    ///
89    /// The Avro schema used for writing is determined as follows:
90    /// 1) If the Arrow schema metadata contains `avro::schema` (see `SCHEMA_METADATA_KEY`),
91    ///    that JSON is used verbatim.
92    /// 2) Otherwise, the Arrow schema is converted to an Avro record schema.
93    pub fn new(schema: Schema) -> Self {
94        Self {
95            schema,
96            codec: None,
97            capacity: 1024,
98            fingerprint_strategy: None,
99        }
100    }
101
102    /// Set the fingerprinting strategy for the stream writer.
103    /// This determines the per-record prefix format.
104    pub fn with_fingerprint_strategy(mut self, strategy: FingerprintStrategy) -> Self {
105        self.fingerprint_strategy = Some(strategy);
106        self
107    }
108
109    /// Change the compression codec.
110    pub fn with_compression(mut self, codec: Option<CompressionCodec>) -> Self {
111        self.codec = codec;
112        self
113    }
114
115    /// Sets the capacity for the given object and returns the modified instance.
116    pub fn with_capacity(mut self, capacity: usize) -> Self {
117        self.capacity = capacity;
118        self
119    }
120
121    /// Create a new `Writer` with specified `AvroFormat` and builder options.
122    /// Performs one‑time startup (header/stream init, encoder plan).
123    pub fn build<W, F>(self, mut writer: W) -> Result<Writer<W, F>, ArrowError>
124    where
125        W: Write,
126        F: AvroFormat,
127    {
128        let mut format = F::default();
129        let avro_schema = match self.schema.metadata.get(SCHEMA_METADATA_KEY) {
130            Some(json) => AvroSchema::new(json.clone()),
131            None => AvroSchema::try_from(&self.schema)?,
132        };
133        let maybe_fingerprint = if F::NEEDS_PREFIX {
134            match self.fingerprint_strategy {
135                Some(FingerprintStrategy::Id(id)) => Some(Fingerprint::Id(id)),
136                Some(strategy) => {
137                    Some(avro_schema.fingerprint(FingerprintAlgorithm::from(strategy))?)
138                }
139                None => Some(
140                    avro_schema
141                        .fingerprint(FingerprintAlgorithm::from(FingerprintStrategy::Rabin))?,
142                ),
143            }
144        } else {
145            None
146        };
147        let mut md = self.schema.metadata().clone();
148        md.insert(
149            SCHEMA_METADATA_KEY.to_string(),
150            avro_schema.clone().json_string,
151        );
152        let schema = Arc::new(Schema::new_with_metadata(self.schema.fields().clone(), md));
153        format.start_stream(&mut writer, &schema, self.codec)?;
154        let avro_root = AvroFieldBuilder::new(&avro_schema.schema()?).build()?;
155        let encoder = RecordEncoderBuilder::new(&avro_root, schema.as_ref())
156            .with_fingerprint(maybe_fingerprint)
157            .build()?;
158        Ok(Writer {
159            writer,
160            schema,
161            format,
162            compression: self.codec,
163            capacity: self.capacity,
164            encoder,
165        })
166    }
167}
168
/// Generic Avro writer.
///
/// This type is generic over the output Write sink (`W`) and the Avro format (`F`).
/// You’ll usually use the concrete aliases:
///
/// * **[`AvroWriter`]** for **OCF** (self‑describing container file)
/// * **[`AvroStreamWriter`]** for **raw** Avro binary streams
#[derive(Debug)]
pub struct Writer<W: Write, F: AvroFormat> {
    // Destination sink that receives all encoded bytes.
    writer: W,
    // Arrow schema batches are checked against; its metadata carries the Avro schema JSON.
    schema: Arc<Schema>,
    // Format state; its `sync_marker()` decides between block and stream writing.
    format: F,
    // Codec applied to each encoded OCF block, if any.
    compression: Option<CompressionCodec>,
    // Capacity value carried over from the builder.
    capacity: usize,
    // Encoder that serializes a `RecordBatch` into Avro bytes.
    encoder: RecordEncoder,
}
185
/// Alias for an Avro **Object Container File** writer.
///
/// Create one with `AvroWriter::new` for default settings, or through
/// `WriterBuilder::build` with `AvroOcfFormat` when you need options such as
/// `WriterBuilder::with_compression`.
///
/// ### Quickstart (runnable)
///
/// ```
/// use std::io::Cursor;
/// use std::sync::Arc;
/// use arrow_array::{ArrayRef, Int64Array, StringArray, RecordBatch};
/// use arrow_schema::{DataType, Field, Schema};
/// use arrow_avro::writer::AvroWriter;
/// use arrow_avro::reader::ReaderBuilder;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// // Writer schema: { id: long, name: string }
/// let writer_schema = Schema::new(vec![
///     Field::new("id", DataType::Int64, false),
///     Field::new("name", DataType::Utf8, false),
/// ]);
///
/// // Build a RecordBatch with two rows
/// let batch = RecordBatch::try_new(
///     Arc::new(writer_schema.clone()),
///     vec![
///         Arc::new(Int64Array::from(vec![1, 2])) as ArrayRef,
///         Arc::new(StringArray::from(vec!["a", "b"])) as ArrayRef,
///     ],
/// )?;
///
/// // Write an Avro **Object Container File** (OCF) to memory
/// let mut w = AvroWriter::new(Vec::<u8>::new(), writer_schema.clone())?;
/// w.write(&batch)?;
/// w.finish()?;
/// let bytes = w.into_inner();
///
/// // Build a Reader and decode the batch back
/// let mut r = ReaderBuilder::new().build(Cursor::new(bytes))?;
/// let out = r.next().unwrap()?;
/// assert_eq!(out.num_rows(), 2);
/// # Ok(()) }
/// ```
pub type AvroWriter<W> = Writer<W, AvroOcfFormat>;
227
/// Alias for a raw Avro **binary stream** writer.
///
/// ### Example
///
/// This writes only the **Avro body** bytes — no OCF header/sync and no
/// single‑object or Confluent framing. If you need those frames, add them externally.
///
/// NOTE(review): `WriterBuilder::build` computes a schema fingerprint whenever the
/// format's `NEEDS_PREFIX` is `true`, and `WriterBuilder::with_fingerprint_strategy`
/// is documented as selecting a per-record prefix for the stream writer. The
/// "body only" statement above may therefore be out of date — confirm against
/// `AvroBinaryFormat::NEEDS_PREFIX`.
///
/// ```
/// use std::sync::Arc;
/// use arrow_array::{ArrayRef, Int64Array, RecordBatch};
/// use arrow_schema::{DataType, Field, Schema};
/// use arrow_avro::writer::AvroStreamWriter;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// // One‑column Arrow batch
/// let schema = Schema::new(vec![Field::new("x", DataType::Int64, false)]);
/// let batch = RecordBatch::try_new(
///     Arc::new(schema.clone()),
///     vec![Arc::new(Int64Array::from(vec![10, 20])) as ArrayRef],
/// )?;
///
/// // Write a raw Avro stream to a Vec<u8>
/// let sink: Vec<u8> = Vec::new();
/// let mut w = AvroStreamWriter::new(sink, schema)?;
/// w.write(&batch)?;
/// w.finish()?;
/// let bytes = w.into_inner();
/// assert!(!bytes.is_empty());
/// # Ok(()) }
/// ```
pub type AvroStreamWriter<W> = Writer<W, AvroBinaryFormat>;
259
260impl<W: Write> Writer<W, AvroOcfFormat> {
261    /// Convenience constructor – same as [`WriterBuilder::build`] with `AvroOcfFormat`.
262    ///
263    /// ### Example
264    ///
265    /// ```
266    /// use std::sync::Arc;
267    /// use arrow_array::{ArrayRef, Int32Array, RecordBatch};
268    /// use arrow_schema::{DataType, Field, Schema};
269    /// use arrow_avro::writer::AvroWriter;
270    ///
271    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
272    /// let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
273    /// let batch = RecordBatch::try_new(
274    ///     Arc::new(schema.clone()),
275    ///     vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
276    /// )?;
277    ///
278    /// let buf: Vec<u8> = Vec::new();
279    /// let mut w = AvroWriter::new(buf, schema)?;
280    /// w.write(&batch)?;
281    /// w.finish()?;
282    /// let bytes = w.into_inner();
283    /// assert!(!bytes.is_empty());
284    /// # Ok(()) }
285    /// ```
286    pub fn new(writer: W, schema: Schema) -> Result<Self, ArrowError> {
287        WriterBuilder::new(schema).build::<W, AvroOcfFormat>(writer)
288    }
289
290    /// Return a reference to the 16‑byte sync marker generated for this file.
291    pub fn sync_marker(&self) -> Option<&[u8; 16]> {
292        self.format.sync_marker()
293    }
294}
295
296impl<W: Write> Writer<W, AvroBinaryFormat> {
297    /// Convenience constructor to create a new [`AvroStreamWriter`].
298    ///
299    /// The resulting stream contains just **Avro binary** bodies (no OCF header/sync and no
300    /// single‑object or Confluent framing). If you need those frames, add them externally.
301    ///
302    /// ### Example
303    ///
304    /// ```
305    /// use std::sync::Arc;
306    /// use arrow_array::{ArrayRef, Int64Array, RecordBatch};
307    /// use arrow_schema::{DataType, Field, Schema};
308    /// use arrow_avro::writer::AvroStreamWriter;
309    ///
310    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
311    /// let schema = Schema::new(vec![Field::new("x", DataType::Int64, false)]);
312    /// let batch = RecordBatch::try_new(
313    ///     Arc::new(schema.clone()),
314    ///     vec![Arc::new(Int64Array::from(vec![10, 20])) as ArrayRef],
315    /// )?;
316    ///
317    /// let sink: Vec<u8> = Vec::new();
318    /// let mut w = AvroStreamWriter::new(sink, schema)?;
319    /// w.write(&batch)?;
320    /// w.finish()?;
321    /// let bytes = w.into_inner();
322    /// assert!(!bytes.is_empty());
323    /// # Ok(()) }
324    /// ```
325    pub fn new(writer: W, schema: Schema) -> Result<Self, ArrowError> {
326        WriterBuilder::new(schema).build::<W, AvroBinaryFormat>(writer)
327    }
328}
329
330impl<W: Write, F: AvroFormat> Writer<W, F> {
331    /// Serialize one [`RecordBatch`] to the output.
332    pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
333        if batch.schema().fields() != self.schema.fields() {
334            return Err(ArrowError::SchemaError(
335                "Schema of RecordBatch differs from Writer schema".to_string(),
336            ));
337        }
338        match self.format.sync_marker() {
339            Some(&sync) => self.write_ocf_block(batch, &sync),
340            None => self.write_stream(batch),
341        }
342    }
343
344    /// A convenience method to write a slice of [`RecordBatch`].
345    ///
346    /// This is equivalent to calling `write` for each batch in the slice.
347    pub fn write_batches(&mut self, batches: &[&RecordBatch]) -> Result<(), ArrowError> {
348        for b in batches {
349            self.write(b)?;
350        }
351        Ok(())
352    }
353
354    /// Flush remaining buffered data and (for OCF) ensure the header is present.
355    pub fn finish(&mut self) -> Result<(), ArrowError> {
356        self.writer
357            .flush()
358            .map_err(|e| ArrowError::IoError(format!("Error flushing writer: {e}"), e))
359    }
360
361    /// Consume the writer, returning the underlying output object.
362    pub fn into_inner(self) -> W {
363        self.writer
364    }
365
366    fn write_ocf_block(&mut self, batch: &RecordBatch, sync: &[u8; 16]) -> Result<(), ArrowError> {
367        let mut buf = Vec::<u8>::with_capacity(1024);
368        self.encoder.encode(&mut buf, batch)?;
369        let encoded = match self.compression {
370            Some(codec) => codec.compress(&buf)?,
371            None => buf,
372        };
373        write_long(&mut self.writer, batch.num_rows() as i64)?;
374        write_long(&mut self.writer, encoded.len() as i64)?;
375        self.writer
376            .write_all(&encoded)
377            .map_err(|e| ArrowError::IoError(format!("Error writing Avro block: {e}"), e))?;
378        self.writer
379            .write_all(sync)
380            .map_err(|e| ArrowError::IoError(format!("Error writing Avro sync: {e}"), e))?;
381        Ok(())
382    }
383
384    fn write_stream(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
385        self.encoder.encode(&mut self.writer, batch)?;
386        Ok(())
387    }
388}
389
390#[cfg(test)]
391mod tests {
392    use super::*;
393    use crate::compression::CompressionCodec;
394    use crate::reader::ReaderBuilder;
395    use crate::schema::{AvroSchema, SchemaStore, CONFLUENT_MAGIC};
396    use crate::test_util::arrow_test_data;
397    use arrow_array::{ArrayRef, BinaryArray, DurationSecondArray, Int32Array, RecordBatch};
398    use arrow_schema::{DataType, Field, IntervalUnit, Schema, TimeUnit};
399    use std::collections::HashSet;
400    use std::fs::File;
401    use std::io::{BufReader, Cursor};
402    use std::path::PathBuf;
403    use std::sync::Arc;
404    use tempfile::NamedTempFile;
405
406    fn make_schema() -> Schema {
407        Schema::new(vec![
408            Field::new("id", DataType::Int32, false),
409            Field::new("name", DataType::Binary, false),
410        ])
411    }
412
413    fn make_batch() -> RecordBatch {
414        let ids = Int32Array::from(vec![1, 2, 3]);
415        let names = BinaryArray::from_vec(vec![b"a".as_ref(), b"b".as_ref(), b"c".as_ref()]);
416        RecordBatch::try_new(
417            Arc::new(make_schema()),
418            vec![Arc::new(ids) as ArrayRef, Arc::new(names) as ArrayRef],
419        )
420        .expect("failed to build test RecordBatch")
421    }
422
423    #[test]
424    fn test_stream_writer_writes_prefix_per_row_rt() -> Result<(), ArrowError> {
425        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
426        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
427        let batch = RecordBatch::try_new(
428            Arc::new(schema.clone()),
429            vec![Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef],
430        )?;
431        let buf: Vec<u8> = Vec::new();
432        let mut writer = AvroStreamWriter::new(buf, schema.clone())?;
433        writer.write(&batch)?;
434        let encoded = writer.into_inner();
435        let mut store = SchemaStore::new(); // Rabin by default
436        let avro_schema = AvroSchema::try_from(&schema)?;
437        let _fp = store.register(avro_schema)?;
438        let mut decoder = ReaderBuilder::new()
439            .with_writer_schema_store(store)
440            .build_decoder()?;
441        let _consumed = decoder.decode(&encoded)?;
442        let decoded = decoder
443            .flush()?
444            .expect("expected at least one batch from decoder");
445        assert_eq!(decoded.num_columns(), 1);
446        assert_eq!(decoded.num_rows(), 2);
447        let col = decoded
448            .column(0)
449            .as_any()
450            .downcast_ref::<Int32Array>()
451            .expect("int column");
452        assert_eq!(col, &Int32Array::from(vec![10, 20]));
453        Ok(())
454    }
455
456    #[test]
457    fn test_stream_writer_with_id_fingerprint_rt() -> Result<(), ArrowError> {
458        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
459        let batch = RecordBatch::try_new(
460            Arc::new(schema.clone()),
461            vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
462        )?;
463        let schema_id: u32 = 42;
464        let mut writer = WriterBuilder::new(schema.clone())
465            .with_fingerprint_strategy(FingerprintStrategy::Id(schema_id))
466            .build::<_, AvroBinaryFormat>(Vec::new())?;
467        writer.write(&batch)?;
468        let encoded = writer.into_inner();
469        let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::None);
470        let avro_schema = AvroSchema::try_from(&schema)?;
471        let _ = store.set(Fingerprint::Id(schema_id), avro_schema)?;
472        let mut decoder = ReaderBuilder::new()
473            .with_writer_schema_store(store)
474            .build_decoder()?;
475        let _ = decoder.decode(&encoded)?;
476        let decoded = decoder
477            .flush()?
478            .expect("expected at least one batch from decoder");
479        assert_eq!(decoded.num_columns(), 1);
480        assert_eq!(decoded.num_rows(), 3);
481        let col = decoded
482            .column(0)
483            .as_any()
484            .downcast_ref::<Int32Array>()
485            .expect("int column");
486        assert_eq!(col, &Int32Array::from(vec![1, 2, 3]));
487        Ok(())
488    }
489
490    #[test]
491    fn test_ocf_writer_generates_header_and_sync() -> Result<(), ArrowError> {
492        let batch = make_batch();
493        let buffer: Vec<u8> = Vec::new();
494        let mut writer = AvroWriter::new(buffer, make_schema())?;
495        writer.write(&batch)?;
496        writer.finish()?;
497        let out = writer.into_inner();
498        assert_eq!(&out[..4], b"Obj\x01", "OCF magic bytes missing/incorrect");
499        let trailer = &out[out.len() - 16..];
500        assert_eq!(trailer.len(), 16, "expected 16‑byte sync marker");
501        Ok(())
502    }
503
504    #[test]
505    fn test_schema_mismatch_yields_error() {
506        let batch = make_batch();
507        let alt_schema = Schema::new(vec![Field::new("x", DataType::Int32, false)]);
508        let buffer = Vec::<u8>::new();
509        let mut writer = AvroWriter::new(buffer, alt_schema).unwrap();
510        let err = writer.write(&batch).unwrap_err();
511        assert!(matches!(err, ArrowError::SchemaError(_)));
512    }
513
514    #[test]
515    fn test_write_batches_accumulates_multiple() -> Result<(), ArrowError> {
516        let batch1 = make_batch();
517        let batch2 = make_batch();
518        let buffer = Vec::<u8>::new();
519        let mut writer = AvroWriter::new(buffer, make_schema())?;
520        writer.write_batches(&[&batch1, &batch2])?;
521        writer.finish()?;
522        let out = writer.into_inner();
523        assert!(out.len() > 4, "combined batches produced tiny file");
524        Ok(())
525    }
526
527    #[test]
528    fn test_finish_without_write_adds_header() -> Result<(), ArrowError> {
529        let buffer = Vec::<u8>::new();
530        let mut writer = AvroWriter::new(buffer, make_schema())?;
531        writer.finish()?;
532        let out = writer.into_inner();
533        assert_eq!(&out[..4], b"Obj\x01", "finish() should emit OCF header");
534        Ok(())
535    }
536
537    #[test]
538    fn test_write_long_encodes_zigzag_varint() -> Result<(), ArrowError> {
539        let mut buf = Vec::new();
540        write_long(&mut buf, 0)?;
541        write_long(&mut buf, -1)?;
542        write_long(&mut buf, 1)?;
543        write_long(&mut buf, -2)?;
544        write_long(&mut buf, 2147483647)?;
545        assert!(
546            buf.starts_with(&[0x00, 0x01, 0x02, 0x03]),
547            "zig‑zag varint encodings incorrect: {buf:?}"
548        );
549        Ok(())
550    }
551
552    #[test]
553    fn test_roundtrip_alltypes_roundtrip_writer() -> Result<(), ArrowError> {
554        let files = [
555            "avro/alltypes_plain.avro",
556            "avro/alltypes_plain.snappy.avro",
557            "avro/alltypes_plain.zstandard.avro",
558            "avro/alltypes_plain.bzip2.avro",
559            "avro/alltypes_plain.xz.avro",
560        ];
561        for rel in files {
562            let path = arrow_test_data(rel);
563            let rdr_file = File::open(&path).expect("open input avro");
564            let mut reader = ReaderBuilder::new()
565                .build(BufReader::new(rdr_file))
566                .expect("build reader");
567            let schema = reader.schema();
568            let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
569            let original =
570                arrow::compute::concat_batches(&schema, &input_batches).expect("concat input");
571            let tmp = NamedTempFile::new().expect("create temp file");
572            let out_path = tmp.into_temp_path();
573            let out_file = File::create(&out_path).expect("create temp avro");
574            let codec = if rel.contains(".snappy.") {
575                Some(CompressionCodec::Snappy)
576            } else if rel.contains(".zstandard.") {
577                Some(CompressionCodec::ZStandard)
578            } else if rel.contains(".bzip2.") {
579                Some(CompressionCodec::Bzip2)
580            } else if rel.contains(".xz.") {
581                Some(CompressionCodec::Xz)
582            } else {
583                None
584            };
585            let mut writer = WriterBuilder::new(original.schema().as_ref().clone())
586                .with_compression(codec)
587                .build::<_, AvroOcfFormat>(out_file)?;
588            writer.write(&original)?;
589            writer.finish()?;
590            drop(writer);
591            let rt_file = File::open(&out_path).expect("open roundtrip avro");
592            let mut rt_reader = ReaderBuilder::new()
593                .build(BufReader::new(rt_file))
594                .expect("build roundtrip reader");
595            let rt_schema = rt_reader.schema();
596            let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
597            let roundtrip =
598                arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip");
599            assert_eq!(
600                roundtrip, original,
601                "Round-trip batch mismatch for file: {}",
602                rel
603            );
604        }
605        Ok(())
606    }
607
608    #[test]
609    fn test_roundtrip_nested_records_writer() -> Result<(), ArrowError> {
610        let path = arrow_test_data("avro/nested_records.avro");
611        let rdr_file = File::open(&path).expect("open nested_records.avro");
612        let mut reader = ReaderBuilder::new()
613            .build(BufReader::new(rdr_file))
614            .expect("build reader for nested_records.avro");
615        let schema = reader.schema();
616        let batches = reader.collect::<Result<Vec<_>, _>>()?;
617        let original = arrow::compute::concat_batches(&schema, &batches).expect("concat original");
618        let tmp = NamedTempFile::new().expect("create temp file");
619        let out_path = tmp.into_temp_path();
620        {
621            let out_file = File::create(&out_path).expect("create output avro");
622            let mut writer = AvroWriter::new(out_file, original.schema().as_ref().clone())?;
623            writer.write(&original)?;
624            writer.finish()?;
625        }
626        let rt_file = File::open(&out_path).expect("open round_trip avro");
627        let mut rt_reader = ReaderBuilder::new()
628            .build(BufReader::new(rt_file))
629            .expect("build round_trip reader");
630        let rt_schema = rt_reader.schema();
631        let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
632        let round_trip =
633            arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip");
634        assert_eq!(
635            round_trip, original,
636            "Round-trip batch mismatch for nested_records.avro"
637        );
638        Ok(())
639    }
640
641    #[test]
642    fn test_roundtrip_nested_lists_writer() -> Result<(), ArrowError> {
643        let path = arrow_test_data("avro/nested_lists.snappy.avro");
644        let rdr_file = File::open(&path).expect("open nested_lists.snappy.avro");
645        let mut reader = ReaderBuilder::new()
646            .build(BufReader::new(rdr_file))
647            .expect("build reader for nested_lists.snappy.avro");
648        let schema = reader.schema();
649        let batches = reader.collect::<Result<Vec<_>, _>>()?;
650        let original = arrow::compute::concat_batches(&schema, &batches).expect("concat original");
651        let tmp = NamedTempFile::new().expect("create temp file");
652        let out_path = tmp.into_temp_path();
653        {
654            let out_file = File::create(&out_path).expect("create output avro");
655            let mut writer = WriterBuilder::new(original.schema().as_ref().clone())
656                .with_compression(Some(CompressionCodec::Snappy))
657                .build::<_, AvroOcfFormat>(out_file)?;
658            writer.write(&original)?;
659            writer.finish()?;
660        }
661        let rt_file = File::open(&out_path).expect("open round_trip avro");
662        let mut rt_reader = ReaderBuilder::new()
663            .build(BufReader::new(rt_file))
664            .expect("build round_trip reader");
665        let rt_schema = rt_reader.schema();
666        let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
667        let round_trip =
668            arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip");
669        assert_eq!(
670            round_trip, original,
671            "Round-trip batch mismatch for nested_lists.snappy.avro"
672        );
673        Ok(())
674    }
675
676    #[test]
677    fn test_round_trip_simple_fixed_ocf() -> Result<(), ArrowError> {
678        let path = arrow_test_data("avro/simple_fixed.avro");
679        let rdr_file = File::open(&path).expect("open avro/simple_fixed.avro");
680        let mut reader = ReaderBuilder::new()
681            .build(BufReader::new(rdr_file))
682            .expect("build avro reader");
683        let schema = reader.schema();
684        let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
685        let original =
686            arrow::compute::concat_batches(&schema, &input_batches).expect("concat input");
687        let tmp = NamedTempFile::new().expect("create temp file");
688        let out_file = File::create(tmp.path()).expect("create temp avro");
689        let mut writer = AvroWriter::new(out_file, original.schema().as_ref().clone())?;
690        writer.write(&original)?;
691        writer.finish()?;
692        drop(writer);
693        let rt_file = File::open(tmp.path()).expect("open round_trip avro");
694        let mut rt_reader = ReaderBuilder::new()
695            .build(BufReader::new(rt_file))
696            .expect("build round_trip reader");
697        let rt_schema = rt_reader.schema();
698        let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
699        let round_trip =
700            arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip");
701        assert_eq!(round_trip, original);
702        Ok(())
703    }
704
705    #[cfg(not(feature = "canonical_extension_types"))]
706    #[test]
707    fn test_round_trip_duration_and_uuid_ocf() -> Result<(), ArrowError> {
708        let in_file =
709            File::open("test/data/duration_uuid.avro").expect("open test/data/duration_uuid.avro");
710        let mut reader = ReaderBuilder::new()
711            .build(BufReader::new(in_file))
712            .expect("build reader for duration_uuid.avro");
713        let in_schema = reader.schema();
714        let has_mdn = in_schema.fields().iter().any(|f| {
715            matches!(
716                f.data_type(),
717                DataType::Interval(IntervalUnit::MonthDayNano)
718            )
719        });
720        assert!(
721            has_mdn,
722            "expected at least one Interval(MonthDayNano) field in duration_uuid.avro"
723        );
724        let has_uuid_fixed = in_schema
725            .fields()
726            .iter()
727            .any(|f| matches!(f.data_type(), DataType::FixedSizeBinary(16)));
728        assert!(
729            has_uuid_fixed,
730            "expected at least one FixedSizeBinary(16) (uuid) field in duration_uuid.avro"
731        );
732        let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
733        let input =
734            arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input");
735        let tmp = NamedTempFile::new().expect("create temp file");
736        {
737            let out_file = File::create(tmp.path()).expect("create temp avro");
738            let mut writer = AvroWriter::new(out_file, in_schema.as_ref().clone())?;
739            writer.write(&input)?;
740            writer.finish()?;
741        }
742        let rt_file = File::open(tmp.path()).expect("open round_trip avro");
743        let mut rt_reader = ReaderBuilder::new()
744            .build(BufReader::new(rt_file))
745            .expect("build round_trip reader");
746        let rt_schema = rt_reader.schema();
747        let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
748        let round_trip =
749            arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip");
750        assert_eq!(round_trip, input);
751        Ok(())
752    }
753
754    // This test reads the same 'nonnullable.impala.avro' used by the reader tests,
755    // writes it back out with the writer (hitting Map encoding paths), then reads it
756    // again and asserts exact Arrow equivalence.
757    #[test]
758    fn test_nonnullable_impala_roundtrip_writer() -> Result<(), ArrowError> {
759        // Load source Avro with Map fields
760        let path = arrow_test_data("avro/nonnullable.impala.avro");
761        let rdr_file = File::open(&path).expect("open avro/nonnullable.impala.avro");
762        let mut reader = ReaderBuilder::new()
763            .build(BufReader::new(rdr_file))
764            .expect("build reader for nonnullable.impala.avro");
765        // Collect all input batches and concatenate to a single RecordBatch
766        let in_schema = reader.schema();
767        // Sanity: ensure the file actually contains at least one Map field
768        let has_map = in_schema
769            .fields()
770            .iter()
771            .any(|f| matches!(f.data_type(), DataType::Map(_, _)));
772        assert!(
773            has_map,
774            "expected at least one Map field in avro/nonnullable.impala.avro"
775        );
776
777        let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
778        let original =
779            arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input");
780        // Write out using the OCF writer into an in-memory Vec<u8>
781        let buffer = Vec::<u8>::new();
782        let mut writer = AvroWriter::new(buffer, in_schema.as_ref().clone())?;
783        writer.write(&original)?;
784        writer.finish()?;
785        let out_bytes = writer.into_inner();
786        // Read the produced bytes back with the Reader
787        let mut rt_reader = ReaderBuilder::new()
788            .build(Cursor::new(out_bytes))
789            .expect("build reader for round-tripped in-memory OCF");
790        let rt_schema = rt_reader.schema();
791        let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
792        let roundtrip =
793            arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip");
794        // Exact value fidelity (schema + data)
795        assert_eq!(
796            roundtrip, original,
797            "Round-trip Avro map data mismatch for nonnullable.impala.avro"
798        );
799        Ok(())
800    }
801
802    #[test]
803    fn test_roundtrip_decimals_via_writer() -> Result<(), ArrowError> {
804        // (file, resolve via ARROW_TEST_DATA?)
805        let files: [(&str, bool); 8] = [
806            ("avro/fixed_length_decimal.avro", true), // fixed-backed -> Decimal128(25,2)
807            ("avro/fixed_length_decimal_legacy.avro", true), // legacy fixed[8] -> Decimal64(13,2)
808            ("avro/int32_decimal.avro", true),        // bytes-backed -> Decimal32(4,2)
809            ("avro/int64_decimal.avro", true),        // bytes-backed -> Decimal64(10,2)
810            ("test/data/int256_decimal.avro", false), // bytes-backed -> Decimal256(76,2)
811            ("test/data/fixed256_decimal.avro", false), // fixed[32]-backed -> Decimal256(76,10)
812            ("test/data/fixed_length_decimal_legacy_32.avro", false), // legacy fixed[4] -> Decimal32(9,2)
813            ("test/data/int128_decimal.avro", false), // bytes-backed -> Decimal128(38,2)
814        ];
815        for (rel, in_test_data_dir) in files {
816            // Resolve path the same way as reader::test_decimal
817            let path: String = if in_test_data_dir {
818                arrow_test_data(rel)
819            } else {
820                PathBuf::from(env!("CARGO_MANIFEST_DIR"))
821                    .join(rel)
822                    .to_string_lossy()
823                    .into_owned()
824            };
825            // Read original file into a single RecordBatch for comparison
826            let f_in = File::open(&path).expect("open input avro");
827            let mut rdr = ReaderBuilder::new().build(BufReader::new(f_in))?;
828            let in_schema = rdr.schema();
829            let in_batches = rdr.collect::<Result<Vec<_>, _>>()?;
830            let original =
831                arrow::compute::concat_batches(&in_schema, &in_batches).expect("concat input");
832            // Write it out with the OCF writer (no special compression)
833            let tmp = NamedTempFile::new().expect("create temp file");
834            let out_path = tmp.into_temp_path();
835            let out_file = File::create(&out_path).expect("create temp avro");
836            let mut writer = AvroWriter::new(out_file, original.schema().as_ref().clone())?;
837            writer.write(&original)?;
838            writer.finish()?;
839            // Read back the file we just wrote and compare equality (schema + data)
840            let f_rt = File::open(&out_path).expect("open roundtrip avro");
841            let mut rt_rdr = ReaderBuilder::new().build(BufReader::new(f_rt))?;
842            let rt_schema = rt_rdr.schema();
843            let rt_batches = rt_rdr.collect::<Result<Vec<_>, _>>()?;
844            let roundtrip =
845                arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat rt");
846            assert_eq!(roundtrip, original, "decimal round-trip mismatch for {rel}");
847        }
848        Ok(())
849    }
850
851    #[test]
852    fn test_enum_roundtrip_uses_reader_fixture() -> Result<(), ArrowError> {
853        // Read the known-good enum file (same as reader::test_simple)
854        let path = arrow_test_data("avro/simple_enum.avro");
855        let rdr_file = File::open(&path).expect("open avro/simple_enum.avro");
856        let mut reader = ReaderBuilder::new()
857            .build(BufReader::new(rdr_file))
858            .expect("build reader for simple_enum.avro");
859        // Concatenate all batches to one RecordBatch for a clean equality check
860        let in_schema = reader.schema();
861        let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
862        let original =
863            arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input");
864        // Sanity: expect at least one Dictionary(Int32, Utf8) column (enum)
865        let has_enum_dict = in_schema.fields().iter().any(|f| {
866            matches!(
867                f.data_type(),
868                DataType::Dictionary(k, v) if **k == DataType::Int32 && **v == DataType::Utf8
869            )
870        });
871        assert!(
872            has_enum_dict,
873            "Expected at least one enum-mapped Dictionary<Int32, Utf8> field"
874        );
875        // Write with OCF writer into memory using the reader-provided Arrow schema.
876        // The writer will embed the Avro JSON from `avro.schema` metadata if present.
877        let buffer: Vec<u8> = Vec::new();
878        let mut writer = AvroWriter::new(buffer, in_schema.as_ref().clone())?;
879        writer.write(&original)?;
880        writer.finish()?;
881        let bytes = writer.into_inner();
882        // Read back and compare for exact equality (schema + data)
883        let mut rt_reader = ReaderBuilder::new()
884            .build(Cursor::new(bytes))
885            .expect("reader for round-trip");
886        let rt_schema = rt_reader.schema();
887        let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
888        let roundtrip =
889            arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip");
890        assert_eq!(roundtrip, original, "Avro enum round-trip mismatch");
891        Ok(())
892    }
893
894    #[test]
895    #[cfg(feature = "avro_custom_types")]
896    fn test_roundtrip_duration_logical_types_ocf() -> Result<(), ArrowError> {
897        let file_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
898            .join("test/data/duration_logical_types.avro")
899            .to_string_lossy()
900            .into_owned();
901
902        let in_file = File::open(&file_path)
903            .unwrap_or_else(|_| panic!("Failed to open test file: {}", file_path));
904
905        let mut reader = ReaderBuilder::new()
906            .build(BufReader::new(in_file))
907            .expect("build reader for duration_logical_types.avro");
908        let in_schema = reader.schema();
909
910        let expected_units: HashSet<TimeUnit> = [
911            TimeUnit::Nanosecond,
912            TimeUnit::Microsecond,
913            TimeUnit::Millisecond,
914            TimeUnit::Second,
915        ]
916        .into_iter()
917        .collect();
918
919        let found_units: HashSet<TimeUnit> = in_schema
920            .fields()
921            .iter()
922            .filter_map(|f| match f.data_type() {
923                DataType::Duration(unit) => Some(*unit),
924                _ => None,
925            })
926            .collect();
927
928        assert_eq!(
929            found_units, expected_units,
930            "Expected to find all four Duration TimeUnits in the schema from the initial read"
931        );
932
933        let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
934        let input =
935            arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input");
936
937        let tmp = NamedTempFile::new().expect("create temp file");
938        {
939            let out_file = File::create(tmp.path()).expect("create temp avro");
940            let mut writer = AvroWriter::new(out_file, in_schema.as_ref().clone())?;
941            writer.write(&input)?;
942            writer.finish()?;
943        }
944
945        let rt_file = File::open(tmp.path()).expect("open round_trip avro");
946        let mut rt_reader = ReaderBuilder::new()
947            .build(BufReader::new(rt_file))
948            .expect("build round_trip reader");
949        let rt_schema = rt_reader.schema();
950        let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
951        let round_trip =
952            arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip");
953
954        assert_eq!(round_trip, input);
955
956        Ok(())
957    }
958}