arrow_avro/reader/
record.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::codec::{AvroDataType, Codec, Nullability};
19use crate::reader::block::{Block, BlockDecoder};
20use crate::reader::cursor::AvroCursor;
21use crate::reader::header::Header;
22use crate::schema::*;
23use arrow_array::builder::{Decimal128Builder, Decimal256Builder};
24use arrow_array::types::*;
25use arrow_array::*;
26use arrow_buffer::*;
27use arrow_schema::{
28    ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef,
29    DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
30};
31use std::cmp::Ordering;
32use std::collections::HashMap;
33use std::io::Read;
34use std::sync::Arc;
35
36const DEFAULT_CAPACITY: usize = 1024;
37
/// Configures and constructs a [`RecordDecoder`].
///
/// Both options start out disabled; see [`RecordDecoderBuilder::new`].
#[derive(Debug)]
pub(crate) struct RecordDecoderBuilder<'a> {
    /// Avro type to decode; `build` fails unless this describes a record.
    data_type: &'a AvroDataType,
    /// Forwarded as `use_utf8view` to `RecordDecoder::try_new_with_options`.
    use_utf8view: bool,
    /// Forwarded as `strict_mode` to `RecordDecoder::try_new_with_options`.
    strict_mode: bool,
}
44
45impl<'a> RecordDecoderBuilder<'a> {
46    pub(crate) fn new(data_type: &'a AvroDataType) -> Self {
47        Self {
48            data_type,
49            use_utf8view: false,
50            strict_mode: false,
51        }
52    }
53
54    pub(crate) fn with_utf8_view(mut self, use_utf8view: bool) -> Self {
55        self.use_utf8view = use_utf8view;
56        self
57    }
58
59    pub(crate) fn with_strict_mode(mut self, strict_mode: bool) -> Self {
60        self.strict_mode = strict_mode;
61        self
62    }
63
64    /// Builds the `RecordDecoder`.
65    pub(crate) fn build(self) -> Result<RecordDecoder, ArrowError> {
66        RecordDecoder::try_new_with_options(self.data_type, self.use_utf8view, self.strict_mode)
67    }
68}
69
/// Decodes avro encoded data into [`RecordBatch`]
///
/// Holds one column [`Decoder`] per top-level record field; rows are appended by
/// [`RecordDecoder::decode`] and drained by [`RecordDecoder::flush`].
#[derive(Debug)]
pub(crate) struct RecordDecoder {
    /// Arrow schema of the batches produced by `flush`.
    schema: SchemaRef,
    /// One decoder per top-level record field, in schema order.
    fields: Vec<Decoder>,
    /// Option recorded at construction time.
    // NOTE(review): not read by any decoding code visible in this file; the
    // Utf8 vs Utf8View choice comes from the data type's codec — confirm intent.
    use_utf8view: bool,
    /// Option recorded at construction time.
    // NOTE(review): likewise not read by the decoding code in this file.
    strict_mode: bool,
}
78
79impl RecordDecoder {
80    /// Creates a new `RecordDecoderBuilder` for configuring a `RecordDecoder`.
81    pub(crate) fn new(data_type: &'_ AvroDataType) -> Self {
82        RecordDecoderBuilder::new(data_type).build().unwrap()
83    }
84
85    /// Create a new [`RecordDecoder`] from the provided [`AvroDataType`] with default options
86    pub(crate) fn try_new(data_type: &AvroDataType) -> Result<Self, ArrowError> {
87        RecordDecoderBuilder::new(data_type)
88            .with_utf8_view(true)
89            .with_strict_mode(true)
90            .build()
91    }
92
93    /// Creates a new [`RecordDecoder`] from the provided [`AvroDataType`] with additional options.
94    ///
95    /// This method allows you to customize how the Avro data is decoded into Arrow arrays.
96    ///
97    /// # Arguments
98    /// * `data_type` - The Avro data type to decode.
99    /// * `use_utf8view` - A flag indicating whether to use `Utf8View` for string types.
100    /// * `strict_mode` - A flag to enable strict decoding, returning an error if the data
101    ///   does not conform to the schema.
102    ///
103    /// # Errors
104    /// This function will return an error if the provided `data_type` is not a `Record`.
105    pub(crate) fn try_new_with_options(
106        data_type: &AvroDataType,
107        use_utf8view: bool,
108        strict_mode: bool,
109    ) -> Result<Self, ArrowError> {
110        match Decoder::try_new(data_type)? {
111            Decoder::Record(fields, encodings) => Ok(Self {
112                schema: Arc::new(ArrowSchema::new(fields)),
113                fields: encodings,
114                use_utf8view,
115                strict_mode,
116            }),
117            encoding => Err(ArrowError::ParseError(format!(
118                "Expected record got {encoding:?}"
119            ))),
120        }
121    }
122
123    /// Returns the decoder's `SchemaRef`
124    pub(crate) fn schema(&self) -> &SchemaRef {
125        &self.schema
126    }
127
128    /// Decode `count` records from `buf`
129    pub(crate) fn decode(&mut self, buf: &[u8], count: usize) -> Result<usize, ArrowError> {
130        let mut cursor = AvroCursor::new(buf);
131        for _ in 0..count {
132            for field in &mut self.fields {
133                field.decode(&mut cursor)?;
134            }
135        }
136        Ok(cursor.position())
137    }
138
139    /// Flush the decoded records into a [`RecordBatch`]
140    pub(crate) fn flush(&mut self) -> Result<RecordBatch, ArrowError> {
141        let arrays = self
142            .fields
143            .iter_mut()
144            .map(|x| x.flush(None))
145            .collect::<Result<Vec<_>, _>>()?;
146
147        RecordBatch::try_new(self.schema.clone(), arrays)
148    }
149}
150
/// Per-column decoder that accumulates decoded Avro values for one Arrow array.
///
/// Each variant owns the buffers for one Arrow data type; values are appended by
/// `Decoder::decode` / `Decoder::append_null` and drained by `Decoder::flush`.
#[derive(Debug)]
enum Decoder {
    /// Number of `null` values decoded so far (flushed as a `NullArray`).
    Null(usize),
    Boolean(BooleanBufferBuilder),
    Int32(Vec<i32>),
    Int64(Vec<i64>),
    Float32(Vec<f32>),
    Float64(Vec<f64>),
    /// Flushed as a `Date32` array.
    Date32(Vec<i32>),
    /// Flushed as a `Time32(Millisecond)` array.
    TimeMillis(Vec<i32>),
    /// Flushed as a `Time64(Microsecond)` array.
    TimeMicros(Vec<i64>),
    /// (is_utc, values): when `is_utc` is true the array gets a `+00:00` timezone.
    TimestampMillis(bool, Vec<i64>),
    /// (is_utc, values): when `is_utc` is true the array gets a `+00:00` timezone.
    TimestampMicros(bool, Vec<i64>),
    /// (value offsets, concatenated bytes).
    Binary(OffsetBufferBuilder<i32>, Vec<u8>),
    /// String data encoded as UTF-8 bytes, mapped to Arrow's StringArray
    String(OffsetBufferBuilder<i32>, Vec<u8>),
    /// String data encoded as UTF-8 bytes, but mapped to Arrow's StringViewArray
    StringView(OffsetBufferBuilder<i32>, Vec<u8>),
    /// List: (item field, per-list item offsets, item decoder).
    Array(FieldRef, OffsetBufferBuilder<i32>, Box<Decoder>),
    /// Struct: (Arrow fields, one decoder per field in the same order).
    Record(Fields, Vec<Decoder>),
    /// Map: (entries field, key offsets, per-map entry offsets,
    /// concatenated key bytes, value decoder).
    Map(
        FieldRef,
        OffsetBufferBuilder<i32>,
        OffsetBufferBuilder<i32>,
        Vec<u8>,
        Box<Decoder>,
    ),
    /// Fixed-size binary: (byte width, concatenated values).
    Fixed(i32, Vec<u8>),
    /// Enum: (symbol indices, symbol table); flushed as an `Int32`-keyed dictionary.
    Enum(Vec<i32>, Arc<[String]>),
    /// (precision, scale, fixed byte size when `fixed`-encoded, builder).
    /// Chosen for fixed sizes up to 16 bytes or, for `bytes` encoding, precisions
    /// that fit `DECIMAL128_MAX_PRECISION`.
    Decimal128(usize, Option<usize>, Option<usize>, Decimal128Builder),
    /// (precision, scale, fixed byte size when `fixed`-encoded, builder).
    /// Chosen for fixed sizes of 17..=32 bytes or larger `bytes` precisions.
    Decimal256(usize, Option<usize>, Option<usize>, Decimal256Builder),
    /// Nullable wrapper: (which union branch holds `null`, validity bitmap,
    /// decoder for the non-null branch).
    Nullable(Nullability, NullBufferBuilder, Box<Decoder>),
}
184
185impl Decoder {
186    fn try_new(data_type: &AvroDataType) -> Result<Self, ArrowError> {
187        let nyi = |s: &str| Err(ArrowError::NotYetImplemented(s.to_string()));
188
189        let decoder = match data_type.codec() {
190            Codec::Null => Self::Null(0),
191            Codec::Boolean => Self::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)),
192            Codec::Int32 => Self::Int32(Vec::with_capacity(DEFAULT_CAPACITY)),
193            Codec::Int64 => Self::Int64(Vec::with_capacity(DEFAULT_CAPACITY)),
194            Codec::Float32 => Self::Float32(Vec::with_capacity(DEFAULT_CAPACITY)),
195            Codec::Float64 => Self::Float64(Vec::with_capacity(DEFAULT_CAPACITY)),
196            Codec::Binary => Self::Binary(
197                OffsetBufferBuilder::new(DEFAULT_CAPACITY),
198                Vec::with_capacity(DEFAULT_CAPACITY),
199            ),
200            Codec::Utf8 => Self::String(
201                OffsetBufferBuilder::new(DEFAULT_CAPACITY),
202                Vec::with_capacity(DEFAULT_CAPACITY),
203            ),
204            Codec::Utf8View => Self::StringView(
205                OffsetBufferBuilder::new(DEFAULT_CAPACITY),
206                Vec::with_capacity(DEFAULT_CAPACITY),
207            ),
208            Codec::Date32 => Self::Date32(Vec::with_capacity(DEFAULT_CAPACITY)),
209            Codec::TimeMillis => Self::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)),
210            Codec::TimeMicros => Self::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)),
211            Codec::TimestampMillis(is_utc) => {
212                Self::TimestampMillis(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY))
213            }
214            Codec::TimestampMicros(is_utc) => {
215                Self::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY))
216            }
217            Codec::Fixed(sz) => Self::Fixed(*sz, Vec::with_capacity(DEFAULT_CAPACITY)),
218            Codec::Decimal(precision, scale, size) => {
219                let p = *precision;
220                let s = *scale;
221                let sz = *size;
222                let prec = p as u8;
223                let scl = s.unwrap_or(0) as i8;
224                match (sz, p) {
225                    (Some(fixed_size), _) if fixed_size <= 16 => {
226                        let builder =
227                            Decimal128Builder::new().with_precision_and_scale(prec, scl)?;
228                        Self::Decimal128(p, s, sz, builder)
229                    }
230                    (Some(fixed_size), _) if fixed_size <= 32 => {
231                        let builder =
232                            Decimal256Builder::new().with_precision_and_scale(prec, scl)?;
233                        Self::Decimal256(p, s, sz, builder)
234                    }
235                    (Some(fixed_size), _) => {
236                        return Err(ArrowError::ParseError(format!(
237                            "Unsupported decimal size: {fixed_size:?}"
238                        )));
239                    }
240                    (None, p) if p <= DECIMAL128_MAX_PRECISION as usize => {
241                        let builder =
242                            Decimal128Builder::new().with_precision_and_scale(prec, scl)?;
243                        Self::Decimal128(p, s, sz, builder)
244                    }
245                    (None, p) if p <= DECIMAL256_MAX_PRECISION as usize => {
246                        let builder =
247                            Decimal256Builder::new().with_precision_and_scale(prec, scl)?;
248                        Self::Decimal256(p, s, sz, builder)
249                    }
250                    (None, _) => {
251                        return Err(ArrowError::ParseError(format!(
252                            "Decimal precision {p} exceeds maximum supported"
253                        )));
254                    }
255                }
256            }
257            Codec::Interval => return nyi("decoding interval"),
258            Codec::List(item) => {
259                let decoder = Self::try_new(item)?;
260                Self::Array(
261                    Arc::new(item.field_with_name("item")),
262                    OffsetBufferBuilder::new(DEFAULT_CAPACITY),
263                    Box::new(decoder),
264                )
265            }
266            Codec::Enum(symbols) => {
267                Self::Enum(Vec::with_capacity(DEFAULT_CAPACITY), symbols.clone())
268            }
269            Codec::Struct(fields) => {
270                let mut arrow_fields = Vec::with_capacity(fields.len());
271                let mut encodings = Vec::with_capacity(fields.len());
272                for avro_field in fields.iter() {
273                    let encoding = Self::try_new(avro_field.data_type())?;
274                    arrow_fields.push(avro_field.field());
275                    encodings.push(encoding);
276                }
277                Self::Record(arrow_fields.into(), encodings)
278            }
279            Codec::Map(child) => {
280                let val_field = child.field_with_name("value").with_nullable(true);
281                let map_field = Arc::new(ArrowField::new(
282                    "entries",
283                    DataType::Struct(Fields::from(vec![
284                        ArrowField::new("key", DataType::Utf8, false),
285                        val_field,
286                    ])),
287                    false,
288                ));
289                let val_dec = Self::try_new(child)?;
290                Self::Map(
291                    map_field,
292                    OffsetBufferBuilder::new(DEFAULT_CAPACITY),
293                    OffsetBufferBuilder::new(DEFAULT_CAPACITY),
294                    Vec::with_capacity(DEFAULT_CAPACITY),
295                    Box::new(val_dec),
296                )
297            }
298            Codec::Uuid => Self::Fixed(16, Vec::with_capacity(DEFAULT_CAPACITY)),
299        };
300        Ok(match data_type.nullability() {
301            Some(nullability) => Self::Nullable(
302                nullability,
303                NullBufferBuilder::new(DEFAULT_CAPACITY),
304                Box::new(decoder),
305            ),
306            None => decoder,
307        })
308    }
309
310    /// Append a null record
311    fn append_null(&mut self) {
312        match self {
313            Self::Null(count) => *count += 1,
314            Self::Boolean(b) => b.append(false),
315            Self::Int32(v) | Self::Date32(v) | Self::TimeMillis(v) => v.push(0),
316            Self::Int64(v)
317            | Self::TimeMicros(v)
318            | Self::TimestampMillis(_, v)
319            | Self::TimestampMicros(_, v) => v.push(0),
320            Self::Float32(v) => v.push(0.),
321            Self::Float64(v) => v.push(0.),
322            Self::Binary(offsets, _) | Self::String(offsets, _) | Self::StringView(offsets, _) => {
323                offsets.push_length(0);
324            }
325            Self::Array(_, offsets, e) => {
326                offsets.push_length(0);
327                e.append_null();
328            }
329            Self::Record(_, e) => e.iter_mut().for_each(|e| e.append_null()),
330            Self::Map(_, _koff, moff, _, _) => {
331                moff.push_length(0);
332            }
333            Self::Fixed(sz, accum) => {
334                accum.extend(std::iter::repeat(0u8).take(*sz as usize));
335            }
336            Self::Decimal128(_, _, _, builder) => builder.append_value(0),
337            Self::Decimal256(_, _, _, builder) => builder.append_value(i256::ZERO),
338            Self::Enum(indices, _) => indices.push(0),
339            Self::Nullable(_, _, _) => unreachable!("Nulls cannot be nested"),
340        }
341    }
342
343    /// Decode a single record from `buf`
344    fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> {
345        match self {
346            Self::Null(x) => *x += 1,
347            Self::Boolean(values) => values.append(buf.get_bool()?),
348            Self::Int32(values) | Self::Date32(values) | Self::TimeMillis(values) => {
349                values.push(buf.get_int()?)
350            }
351            Self::Int64(values)
352            | Self::TimeMicros(values)
353            | Self::TimestampMillis(_, values)
354            | Self::TimestampMicros(_, values) => values.push(buf.get_long()?),
355            Self::Float32(values) => values.push(buf.get_float()?),
356            Self::Float64(values) => values.push(buf.get_double()?),
357            Self::Binary(offsets, values)
358            | Self::String(offsets, values)
359            | Self::StringView(offsets, values) => {
360                let data = buf.get_bytes()?;
361                offsets.push_length(data.len());
362                values.extend_from_slice(data);
363            }
364            Self::Array(_, off, encoding) => {
365                let total_items = read_blocks(buf, |cursor| encoding.decode(cursor))?;
366                off.push_length(total_items);
367            }
368            Self::Record(_, encodings) => {
369                for encoding in encodings {
370                    encoding.decode(buf)?;
371                }
372            }
373            Self::Map(_, koff, moff, kdata, valdec) => {
374                let newly_added = read_blocks(buf, |cur| {
375                    let kb = cur.get_bytes()?;
376                    koff.push_length(kb.len());
377                    kdata.extend_from_slice(kb);
378                    valdec.decode(cur)
379                })?;
380                moff.push_length(newly_added);
381            }
382            Self::Fixed(sz, accum) => {
383                let fx = buf.get_fixed(*sz as usize)?;
384                accum.extend_from_slice(fx);
385            }
386            Self::Decimal128(_, _, size, builder) => {
387                let raw = if let Some(s) = size {
388                    buf.get_fixed(*s)?
389                } else {
390                    buf.get_bytes()?
391                };
392                let ext = sign_extend_to::<16>(raw)?;
393                let val = i128::from_be_bytes(ext);
394                builder.append_value(val);
395            }
396            Self::Decimal256(_, _, size, builder) => {
397                let raw = if let Some(s) = size {
398                    buf.get_fixed(*s)?
399                } else {
400                    buf.get_bytes()?
401                };
402                let ext = sign_extend_to::<32>(raw)?;
403                let val = i256::from_be_bytes(ext);
404                builder.append_value(val);
405            }
406            Self::Enum(indices, _) => {
407                indices.push(buf.get_int()?);
408            }
409            Self::Nullable(nullability, nulls, e) => {
410                let is_valid = buf.get_bool()? == matches!(nullability, Nullability::NullFirst);
411                nulls.append(is_valid);
412                match is_valid {
413                    true => e.decode(buf)?,
414                    false => e.append_null(),
415                }
416            }
417        }
418        Ok(())
419    }
420
421    /// Flush decoded records to an [`ArrayRef`]
422    fn flush(&mut self, nulls: Option<NullBuffer>) -> Result<ArrayRef, ArrowError> {
423        Ok(match self {
424            Self::Nullable(_, n, e) => e.flush(n.finish())?,
425            Self::Null(size) => Arc::new(NullArray::new(std::mem::replace(size, 0))),
426            Self::Boolean(b) => Arc::new(BooleanArray::new(b.finish(), nulls)),
427            Self::Int32(values) => Arc::new(flush_primitive::<Int32Type>(values, nulls)),
428            Self::Date32(values) => Arc::new(flush_primitive::<Date32Type>(values, nulls)),
429            Self::Int64(values) => Arc::new(flush_primitive::<Int64Type>(values, nulls)),
430            Self::TimeMillis(values) => {
431                Arc::new(flush_primitive::<Time32MillisecondType>(values, nulls))
432            }
433            Self::TimeMicros(values) => {
434                Arc::new(flush_primitive::<Time64MicrosecondType>(values, nulls))
435            }
436            Self::TimestampMillis(is_utc, values) => Arc::new(
437                flush_primitive::<TimestampMillisecondType>(values, nulls)
438                    .with_timezone_opt(is_utc.then(|| "+00:00")),
439            ),
440            Self::TimestampMicros(is_utc, values) => Arc::new(
441                flush_primitive::<TimestampMicrosecondType>(values, nulls)
442                    .with_timezone_opt(is_utc.then(|| "+00:00")),
443            ),
444            Self::Float32(values) => Arc::new(flush_primitive::<Float32Type>(values, nulls)),
445            Self::Float64(values) => Arc::new(flush_primitive::<Float64Type>(values, nulls)),
446            Self::Binary(offsets, values) => {
447                let offsets = flush_offsets(offsets);
448                let values = flush_values(values).into();
449                Arc::new(BinaryArray::new(offsets, values, nulls))
450            }
451            Self::String(offsets, values) => {
452                let offsets = flush_offsets(offsets);
453                let values = flush_values(values).into();
454                Arc::new(StringArray::new(offsets, values, nulls))
455            }
456            Self::StringView(offsets, values) => {
457                let offsets = flush_offsets(offsets);
458                let values = flush_values(values);
459                let array = StringArray::new(offsets, values.into(), nulls.clone());
460                let values: Vec<&str> = (0..array.len())
461                    .map(|i| {
462                        if array.is_valid(i) {
463                            array.value(i)
464                        } else {
465                            ""
466                        }
467                    })
468                    .collect();
469
470                Arc::new(StringViewArray::from(values))
471            }
472            Self::Array(field, offsets, values) => {
473                let values = values.flush(None)?;
474                let offsets = flush_offsets(offsets);
475                Arc::new(ListArray::new(field.clone(), offsets, values, nulls))
476            }
477            Self::Record(fields, encodings) => {
478                let arrays = encodings
479                    .iter_mut()
480                    .map(|x| x.flush(None))
481                    .collect::<Result<Vec<_>, _>>()?;
482                Arc::new(StructArray::new(fields.clone(), arrays, nulls))
483            }
484            Self::Map(map_field, k_off, m_off, kdata, valdec) => {
485                let moff = flush_offsets(m_off);
486                let koff = flush_offsets(k_off);
487                let kd = flush_values(kdata).into();
488                let val_arr = valdec.flush(None)?;
489                let key_arr = StringArray::new(koff, kd, None);
490                if key_arr.len() != val_arr.len() {
491                    return Err(ArrowError::InvalidArgumentError(format!(
492                        "Map keys length ({}) != map values length ({})",
493                        key_arr.len(),
494                        val_arr.len()
495                    )));
496                }
497                let final_len = moff.len() - 1;
498                if let Some(n) = &nulls {
499                    if n.len() != final_len {
500                        return Err(ArrowError::InvalidArgumentError(format!(
501                            "Map array null buffer length {} != final map length {final_len}",
502                            n.len()
503                        )));
504                    }
505                }
506                let entries_struct = StructArray::new(
507                    Fields::from(vec![
508                        Arc::new(ArrowField::new("key", DataType::Utf8, false)),
509                        Arc::new(ArrowField::new("value", val_arr.data_type().clone(), true)),
510                    ]),
511                    vec![Arc::new(key_arr), val_arr],
512                    None,
513                );
514                let map_arr = MapArray::new(map_field.clone(), moff, entries_struct, nulls, false);
515                Arc::new(map_arr)
516            }
517            Self::Fixed(sz, accum) => {
518                let b: Buffer = flush_values(accum).into();
519                let arr = FixedSizeBinaryArray::try_new(*sz, b, nulls)
520                    .map_err(|e| ArrowError::ParseError(e.to_string()))?;
521                Arc::new(arr)
522            }
523            Self::Decimal128(precision, scale, _, builder) => {
524                let mut b = std::mem::take(builder);
525                let (_, vals, _) = b.finish().into_parts();
526                let scl = scale.unwrap_or(0);
527                let dec = Decimal128Array::new(vals, nulls)
528                    .with_precision_and_scale(*precision as u8, scl as i8)
529                    .map_err(|e| ArrowError::ParseError(e.to_string()))?;
530                Arc::new(dec)
531            }
532            Self::Decimal256(precision, scale, _, builder) => {
533                let mut b = std::mem::take(builder);
534                let (_, vals, _) = b.finish().into_parts();
535                let scl = scale.unwrap_or(0);
536                let dec = Decimal256Array::new(vals, nulls)
537                    .with_precision_and_scale(*precision as u8, scl as i8)
538                    .map_err(|e| ArrowError::ParseError(e.to_string()))?;
539                Arc::new(dec)
540            }
541            Self::Enum(indices, symbols) => {
542                let keys = flush_primitive::<Int32Type>(indices, nulls);
543                let values = Arc::new(StringArray::from(
544                    symbols.iter().map(|s| s.as_str()).collect::<Vec<_>>(),
545                ));
546                Arc::new(DictionaryArray::try_new(keys, values)?)
547            }
548        })
549    }
550}
551
552fn read_blocks(
553    buf: &mut AvroCursor,
554    decode_entry: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>,
555) -> Result<usize, ArrowError> {
556    read_blockwise_items(buf, true, decode_entry)
557}
558
559fn read_blockwise_items(
560    buf: &mut AvroCursor,
561    read_size_after_negative: bool,
562    mut decode_fn: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>,
563) -> Result<usize, ArrowError> {
564    let mut total = 0usize;
565    loop {
566        // Read the block count
567        //  positive = that many items
568        //  negative = that many items + read block size
569        //  See: https://avro.apache.org/docs/1.11.1/specification/#maps
570        let block_count = buf.get_long()?;
571        match block_count.cmp(&0) {
572            Ordering::Equal => break,
573            Ordering::Less => {
574                // If block_count is negative, read the absolute value of count,
575                // then read the block size as a long and discard
576                let count = (-block_count) as usize;
577                if read_size_after_negative {
578                    let _size_in_bytes = buf.get_long()?;
579                }
580                for _ in 0..count {
581                    decode_fn(buf)?;
582                }
583                total += count;
584            }
585            Ordering::Greater => {
586                // If block_count is positive, decode that many items
587                let count = block_count as usize;
588                for _i in 0..count {
589                    decode_fn(buf)?;
590                }
591                total += count;
592            }
593        }
594    }
595    Ok(total)
596}
597
598#[inline]
599fn flush_values<T>(values: &mut Vec<T>) -> Vec<T> {
600    std::mem::replace(values, Vec::with_capacity(DEFAULT_CAPACITY))
601}
602
603#[inline]
604fn flush_offsets(offsets: &mut OffsetBufferBuilder<i32>) -> OffsetBuffer<i32> {
605    std::mem::replace(offsets, OffsetBufferBuilder::new(DEFAULT_CAPACITY)).finish()
606}
607
608#[inline]
609fn flush_primitive<T: ArrowPrimitiveType>(
610    values: &mut Vec<T::Native>,
611    nulls: Option<NullBuffer>,
612) -> PrimitiveArray<T> {
613    PrimitiveArray::new(flush_values(values).into(), nulls)
614}
615
616/// Sign extends a byte slice to a fixed-size array of N bytes.
617/// This is done by filling the leading bytes with 0x00 for positive numbers
618/// or 0xFF for negative numbers.
619#[inline]
620fn sign_extend_to<const N: usize>(raw: &[u8]) -> Result<[u8; N], ArrowError> {
621    if raw.len() > N {
622        return Err(ArrowError::ParseError(format!(
623            "Cannot extend a slice of length {} to {} bytes.",
624            raw.len(),
625            N
626        )));
627    }
628    let mut arr = [0u8; N];
629    let pad_len = N - raw.len();
630    // Determine the byte to use for padding based on the sign bit of the raw data.
631    let extension_byte = if raw.is_empty() || (raw[0] & 0x80 == 0) {
632        0x00
633    } else {
634        0xFF
635    };
636    arr[..pad_len].fill(extension_byte);
637    arr[pad_len..].copy_from_slice(raw);
638    Ok(arr)
639}
640
641#[cfg(test)]
642mod tests {
643    use super::*;
644    use arrow_array::{
645        cast::AsArray, Array, Decimal128Array, DictionaryArray, FixedSizeBinaryArray,
646        IntervalMonthDayNanoArray, ListArray, MapArray, StringArray, StructArray,
647    };
648
649    fn encode_avro_int(value: i32) -> Vec<u8> {
650        let mut buf = Vec::new();
651        let mut v = (value << 1) ^ (value >> 31);
652        while v & !0x7F != 0 {
653            buf.push(((v & 0x7F) | 0x80) as u8);
654            v >>= 7;
655        }
656        buf.push(v as u8);
657        buf
658    }
659
660    fn encode_avro_long(value: i64) -> Vec<u8> {
661        let mut buf = Vec::new();
662        let mut v = (value << 1) ^ (value >> 63);
663        while v & !0x7F != 0 {
664            buf.push(((v & 0x7F) | 0x80) as u8);
665            v >>= 7;
666        }
667        buf.push(v as u8);
668        buf
669    }
670
671    fn encode_avro_bytes(bytes: &[u8]) -> Vec<u8> {
672        let mut buf = encode_avro_long(bytes.len() as i64);
673        buf.extend_from_slice(bytes);
674        buf
675    }
676
677    fn avro_from_codec(codec: Codec) -> AvroDataType {
678        AvroDataType::new(codec, Default::default(), None)
679    }
680
681    #[test]
682    fn test_map_decoding_one_entry() {
683        let value_type = avro_from_codec(Codec::Utf8);
684        let map_type = avro_from_codec(Codec::Map(Arc::new(value_type)));
685        let mut decoder = Decoder::try_new(&map_type).unwrap();
686        // Encode a single map with one entry: {"hello": "world"}
687        let mut data = Vec::new();
688        data.extend_from_slice(&encode_avro_long(1));
689        data.extend_from_slice(&encode_avro_bytes(b"hello")); // key
690        data.extend_from_slice(&encode_avro_bytes(b"world")); // value
691        data.extend_from_slice(&encode_avro_long(0));
692        let mut cursor = AvroCursor::new(&data);
693        decoder.decode(&mut cursor).unwrap();
694        let array = decoder.flush(None).unwrap();
695        let map_arr = array.as_any().downcast_ref::<MapArray>().unwrap();
696        assert_eq!(map_arr.len(), 1); // one map
697        assert_eq!(map_arr.value_length(0), 1);
698        let entries = map_arr.value(0);
699        let struct_entries = entries.as_any().downcast_ref::<StructArray>().unwrap();
700        assert_eq!(struct_entries.len(), 1);
701        let key_arr = struct_entries
702            .column_by_name("key")
703            .unwrap()
704            .as_any()
705            .downcast_ref::<StringArray>()
706            .unwrap();
707        let val_arr = struct_entries
708            .column_by_name("value")
709            .unwrap()
710            .as_any()
711            .downcast_ref::<StringArray>()
712            .unwrap();
713        assert_eq!(key_arr.value(0), "hello");
714        assert_eq!(val_arr.value(0), "world");
715    }
716
717    #[test]
718    fn test_map_decoding_empty() {
719        let value_type = avro_from_codec(Codec::Utf8);
720        let map_type = avro_from_codec(Codec::Map(Arc::new(value_type)));
721        let mut decoder = Decoder::try_new(&map_type).unwrap();
722        let data = encode_avro_long(0);
723        decoder.decode(&mut AvroCursor::new(&data)).unwrap();
724        let array = decoder.flush(None).unwrap();
725        let map_arr = array.as_any().downcast_ref::<MapArray>().unwrap();
726        assert_eq!(map_arr.len(), 1);
727        assert_eq!(map_arr.value_length(0), 0);
728    }
729
730    #[test]
731    fn test_fixed_decoding() {
732        let avro_type = avro_from_codec(Codec::Fixed(3));
733        let mut decoder = Decoder::try_new(&avro_type).expect("Failed to create decoder");
734
735        let data1 = [1u8, 2, 3];
736        let mut cursor1 = AvroCursor::new(&data1);
737        decoder
738            .decode(&mut cursor1)
739            .expect("Failed to decode data1");
740        assert_eq!(cursor1.position(), 3, "Cursor should advance by fixed size");
741        let data2 = [4u8, 5, 6];
742        let mut cursor2 = AvroCursor::new(&data2);
743        decoder
744            .decode(&mut cursor2)
745            .expect("Failed to decode data2");
746        assert_eq!(cursor2.position(), 3, "Cursor should advance by fixed size");
747        let array = decoder.flush(None).expect("Failed to flush decoder");
748        assert_eq!(array.len(), 2, "Array should contain two items");
749        let fixed_size_binary_array = array
750            .as_any()
751            .downcast_ref::<FixedSizeBinaryArray>()
752            .expect("Failed to downcast to FixedSizeBinaryArray");
753        assert_eq!(
754            fixed_size_binary_array.value_length(),
755            3,
756            "Fixed size of binary values should be 3"
757        );
758        assert_eq!(
759            fixed_size_binary_array.value(0),
760            &[1, 2, 3],
761            "First item mismatch"
762        );
763        assert_eq!(
764            fixed_size_binary_array.value(1),
765            &[4, 5, 6],
766            "Second item mismatch"
767        );
768    }
769
770    #[test]
771    fn test_fixed_decoding_empty() {
772        let avro_type = avro_from_codec(Codec::Fixed(5));
773        let mut decoder = Decoder::try_new(&avro_type).expect("Failed to create decoder");
774
775        let array = decoder
776            .flush(None)
777            .expect("Failed to flush decoder for empty input");
778
779        assert_eq!(array.len(), 0, "Array should be empty");
780        let fixed_size_binary_array = array
781            .as_any()
782            .downcast_ref::<FixedSizeBinaryArray>()
783            .expect("Failed to downcast to FixedSizeBinaryArray for empty array");
784
785        assert_eq!(
786            fixed_size_binary_array.value_length(),
787            5,
788            "Fixed size of binary values should be 5 as per type"
789        );
790    }
791
792    #[test]
793    fn test_uuid_decoding() {
794        let avro_type = avro_from_codec(Codec::Uuid);
795        let mut decoder = Decoder::try_new(&avro_type).expect("Failed to create decoder");
796
797        let data1 = [1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
798        let mut cursor1 = AvroCursor::new(&data1);
799        decoder
800            .decode(&mut cursor1)
801            .expect("Failed to decode data1");
802        assert_eq!(
803            cursor1.position(),
804            16,
805            "Cursor should advance by fixed size"
806        );
807    }
808
809    #[test]
810    fn test_array_decoding() {
811        let item_dt = avro_from_codec(Codec::Int32);
812        let list_dt = avro_from_codec(Codec::List(Arc::new(item_dt)));
813        let mut decoder = Decoder::try_new(&list_dt).unwrap();
814        let mut row1 = Vec::new();
815        row1.extend_from_slice(&encode_avro_long(2));
816        row1.extend_from_slice(&encode_avro_int(10));
817        row1.extend_from_slice(&encode_avro_int(20));
818        row1.extend_from_slice(&encode_avro_long(0));
819        let row2 = encode_avro_long(0);
820        let mut cursor = AvroCursor::new(&row1);
821        decoder.decode(&mut cursor).unwrap();
822        let mut cursor2 = AvroCursor::new(&row2);
823        decoder.decode(&mut cursor2).unwrap();
824        let array = decoder.flush(None).unwrap();
825        let list_arr = array.as_any().downcast_ref::<ListArray>().unwrap();
826        assert_eq!(list_arr.len(), 2);
827        let offsets = list_arr.value_offsets();
828        assert_eq!(offsets, &[0, 2, 2]);
829        let values = list_arr.values();
830        let int_arr = values.as_primitive::<Int32Type>();
831        assert_eq!(int_arr.len(), 2);
832        assert_eq!(int_arr.value(0), 10);
833        assert_eq!(int_arr.value(1), 20);
834    }
835
836    #[test]
837    fn test_array_decoding_with_negative_block_count() {
838        let item_dt = avro_from_codec(Codec::Int32);
839        let list_dt = avro_from_codec(Codec::List(Arc::new(item_dt)));
840        let mut decoder = Decoder::try_new(&list_dt).unwrap();
841        let mut data = encode_avro_long(-3);
842        data.extend_from_slice(&encode_avro_long(12));
843        data.extend_from_slice(&encode_avro_int(1));
844        data.extend_from_slice(&encode_avro_int(2));
845        data.extend_from_slice(&encode_avro_int(3));
846        data.extend_from_slice(&encode_avro_long(0));
847        let mut cursor = AvroCursor::new(&data);
848        decoder.decode(&mut cursor).unwrap();
849        let array = decoder.flush(None).unwrap();
850        let list_arr = array.as_any().downcast_ref::<ListArray>().unwrap();
851        assert_eq!(list_arr.len(), 1);
852        assert_eq!(list_arr.value_length(0), 3);
853        let values = list_arr.values().as_primitive::<Int32Type>();
854        assert_eq!(values.len(), 3);
855        assert_eq!(values.value(0), 1);
856        assert_eq!(values.value(1), 2);
857        assert_eq!(values.value(2), 3);
858    }
859
860    #[test]
861    fn test_nested_array_decoding() {
862        let inner_ty = avro_from_codec(Codec::List(Arc::new(avro_from_codec(Codec::Int32))));
863        let nested_ty = avro_from_codec(Codec::List(Arc::new(inner_ty.clone())));
864        let mut decoder = Decoder::try_new(&nested_ty).unwrap();
865        let mut buf = Vec::new();
866        buf.extend(encode_avro_long(1));
867        buf.extend(encode_avro_long(2));
868        buf.extend(encode_avro_int(5));
869        buf.extend(encode_avro_int(6));
870        buf.extend(encode_avro_long(0));
871        buf.extend(encode_avro_long(0));
872        let mut cursor = AvroCursor::new(&buf);
873        decoder.decode(&mut cursor).unwrap();
874        let arr = decoder.flush(None).unwrap();
875        let outer = arr.as_any().downcast_ref::<ListArray>().unwrap();
876        assert_eq!(outer.len(), 1);
877        assert_eq!(outer.value_length(0), 1);
878        let inner = outer.values().as_any().downcast_ref::<ListArray>().unwrap();
879        assert_eq!(inner.len(), 1);
880        assert_eq!(inner.value_length(0), 2);
881        let values = inner
882            .values()
883            .as_any()
884            .downcast_ref::<Int32Array>()
885            .unwrap();
886        assert_eq!(values.values(), &[5, 6]);
887    }
888
889    #[test]
890    fn test_array_decoding_empty_array() {
891        let value_type = avro_from_codec(Codec::Utf8);
892        let map_type = avro_from_codec(Codec::List(Arc::new(value_type)));
893        let mut decoder = Decoder::try_new(&map_type).unwrap();
894        let data = encode_avro_long(0);
895        decoder.decode(&mut AvroCursor::new(&data)).unwrap();
896        let array = decoder.flush(None).unwrap();
897        let list_arr = array.as_any().downcast_ref::<ListArray>().unwrap();
898        assert_eq!(list_arr.len(), 1);
899        assert_eq!(list_arr.value_length(0), 0);
900    }
901
902    #[test]
903    fn test_decimal_decoding_fixed256() {
904        let dt = avro_from_codec(Codec::Decimal(5, Some(2), Some(32)));
905        let mut decoder = Decoder::try_new(&dt).unwrap();
906        let row1 = [
907            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
908            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
909            0x00, 0x00, 0x30, 0x39,
910        ];
911        let row2 = [
912            0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
913            0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
914            0xFF, 0xFF, 0xFF, 0x85,
915        ];
916        let mut data = Vec::new();
917        data.extend_from_slice(&row1);
918        data.extend_from_slice(&row2);
919        let mut cursor = AvroCursor::new(&data);
920        decoder.decode(&mut cursor).unwrap();
921        decoder.decode(&mut cursor).unwrap();
922        let arr = decoder.flush(None).unwrap();
923        let dec = arr.as_any().downcast_ref::<Decimal256Array>().unwrap();
924        assert_eq!(dec.len(), 2);
925        assert_eq!(dec.value_as_string(0), "123.45");
926        assert_eq!(dec.value_as_string(1), "-1.23");
927    }
928
929    #[test]
930    fn test_decimal_decoding_fixed128() {
931        let dt = avro_from_codec(Codec::Decimal(5, Some(2), Some(16)));
932        let mut decoder = Decoder::try_new(&dt).unwrap();
933        let row1 = [
934            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
935            0x30, 0x39,
936        ];
937        let row2 = [
938            0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
939            0xFF, 0x85,
940        ];
941        let mut data = Vec::new();
942        data.extend_from_slice(&row1);
943        data.extend_from_slice(&row2);
944        let mut cursor = AvroCursor::new(&data);
945        decoder.decode(&mut cursor).unwrap();
946        decoder.decode(&mut cursor).unwrap();
947        let arr = decoder.flush(None).unwrap();
948        let dec = arr.as_any().downcast_ref::<Decimal128Array>().unwrap();
949        assert_eq!(dec.len(), 2);
950        assert_eq!(dec.value_as_string(0), "123.45");
951        assert_eq!(dec.value_as_string(1), "-1.23");
952    }
953
954    #[test]
955    fn test_decimal_decoding_bytes_with_nulls() {
956        let dt = avro_from_codec(Codec::Decimal(4, Some(1), None));
957        let inner = Decoder::try_new(&dt).unwrap();
958        let mut decoder = Decoder::Nullable(
959            Nullability::NullSecond,
960            NullBufferBuilder::new(DEFAULT_CAPACITY),
961            Box::new(inner),
962        );
963        let mut data = Vec::new();
964        data.extend_from_slice(&encode_avro_int(0));
965        data.extend_from_slice(&encode_avro_bytes(&[0x04, 0xD2]));
966        data.extend_from_slice(&encode_avro_int(1));
967        data.extend_from_slice(&encode_avro_int(0));
968        data.extend_from_slice(&encode_avro_bytes(&[0xFB, 0x2E]));
969        let mut cursor = AvroCursor::new(&data);
970        decoder.decode(&mut cursor).unwrap(); // row1
971        decoder.decode(&mut cursor).unwrap(); // row2
972        decoder.decode(&mut cursor).unwrap(); // row3
973        let arr = decoder.flush(None).unwrap();
974        let dec_arr = arr.as_any().downcast_ref::<Decimal128Array>().unwrap();
975        assert_eq!(dec_arr.len(), 3);
976        assert!(dec_arr.is_valid(0));
977        assert!(!dec_arr.is_valid(1));
978        assert!(dec_arr.is_valid(2));
979        assert_eq!(dec_arr.value_as_string(0), "123.4");
980        assert_eq!(dec_arr.value_as_string(2), "-123.4");
981    }
982
983    #[test]
984    fn test_decimal_decoding_bytes_with_nulls_fixed_size() {
985        let dt = avro_from_codec(Codec::Decimal(6, Some(2), Some(16)));
986        let inner = Decoder::try_new(&dt).unwrap();
987        let mut decoder = Decoder::Nullable(
988            Nullability::NullSecond,
989            NullBufferBuilder::new(DEFAULT_CAPACITY),
990            Box::new(inner),
991        );
992        let row1 = [
993            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
994            0xE2, 0x40,
995        ];
996        let row3 = [
997            0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE,
998            0x1D, 0xC0,
999        ];
1000        let mut data = Vec::new();
1001        data.extend_from_slice(&encode_avro_int(0));
1002        data.extend_from_slice(&row1);
1003        data.extend_from_slice(&encode_avro_int(1));
1004        data.extend_from_slice(&encode_avro_int(0));
1005        data.extend_from_slice(&row3);
1006        let mut cursor = AvroCursor::new(&data);
1007        decoder.decode(&mut cursor).unwrap();
1008        decoder.decode(&mut cursor).unwrap();
1009        decoder.decode(&mut cursor).unwrap();
1010        let arr = decoder.flush(None).unwrap();
1011        let dec_arr = arr.as_any().downcast_ref::<Decimal128Array>().unwrap();
1012        assert_eq!(dec_arr.len(), 3);
1013        assert!(dec_arr.is_valid(0));
1014        assert!(!dec_arr.is_valid(1));
1015        assert!(dec_arr.is_valid(2));
1016        assert_eq!(dec_arr.value_as_string(0), "1234.56");
1017        assert_eq!(dec_arr.value_as_string(2), "-1234.56");
1018    }
1019
1020    #[test]
1021    fn test_enum_decoding() {
1022        let symbols: Arc<[String]> = vec!["A", "B", "C"].into_iter().map(String::from).collect();
1023        let avro_type = avro_from_codec(Codec::Enum(symbols.clone()));
1024        let mut decoder = Decoder::try_new(&avro_type).unwrap();
1025        let mut data = Vec::new();
1026        data.extend_from_slice(&encode_avro_int(2));
1027        data.extend_from_slice(&encode_avro_int(0));
1028        data.extend_from_slice(&encode_avro_int(1));
1029        let mut cursor = AvroCursor::new(&data);
1030        decoder.decode(&mut cursor).unwrap();
1031        decoder.decode(&mut cursor).unwrap();
1032        decoder.decode(&mut cursor).unwrap();
1033        let array = decoder.flush(None).unwrap();
1034        let dict_array = array
1035            .as_any()
1036            .downcast_ref::<DictionaryArray<Int32Type>>()
1037            .unwrap();
1038
1039        assert_eq!(dict_array.len(), 3);
1040        let values = dict_array
1041            .values()
1042            .as_any()
1043            .downcast_ref::<StringArray>()
1044            .unwrap();
1045        assert_eq!(values.value(0), "A");
1046        assert_eq!(values.value(1), "B");
1047        assert_eq!(values.value(2), "C");
1048        assert_eq!(dict_array.keys().values(), &[2, 0, 1]);
1049    }
1050
1051    #[test]
1052    fn test_enum_decoding_with_nulls() {
1053        let symbols: Arc<[String]> = vec!["X", "Y"].into_iter().map(String::from).collect();
1054        let enum_codec = Codec::Enum(symbols.clone());
1055        let avro_type =
1056            AvroDataType::new(enum_codec, Default::default(), Some(Nullability::NullFirst));
1057        let mut decoder = Decoder::try_new(&avro_type).unwrap();
1058        let mut data = Vec::new();
1059        data.extend_from_slice(&encode_avro_long(1));
1060        data.extend_from_slice(&encode_avro_int(1));
1061        data.extend_from_slice(&encode_avro_long(0));
1062        data.extend_from_slice(&encode_avro_long(1));
1063        data.extend_from_slice(&encode_avro_int(0));
1064        let mut cursor = AvroCursor::new(&data);
1065        decoder.decode(&mut cursor).unwrap();
1066        decoder.decode(&mut cursor).unwrap();
1067        decoder.decode(&mut cursor).unwrap();
1068        let array = decoder.flush(None).unwrap();
1069        let dict_array = array
1070            .as_any()
1071            .downcast_ref::<DictionaryArray<Int32Type>>()
1072            .unwrap();
1073        assert_eq!(dict_array.len(), 3);
1074        assert!(dict_array.is_valid(0));
1075        assert!(dict_array.is_null(1));
1076        assert!(dict_array.is_valid(2));
1077        let expected_keys = Int32Array::from(vec![Some(1), None, Some(0)]);
1078        assert_eq!(dict_array.keys(), &expected_keys);
1079        let values = dict_array
1080            .values()
1081            .as_any()
1082            .downcast_ref::<StringArray>()
1083            .unwrap();
1084        assert_eq!(values.value(0), "X");
1085        assert_eq!(values.value(1), "Y");
1086    }
1087}