arrow_avro/reader/record.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::codec::{AvroDataType, Codec, Nullability};
use crate::reader::block::{Block, BlockDecoder};
use crate::reader::cursor::AvroCursor;
use crate::reader::header::Header;
use crate::reader::ReadOptions;
use crate::schema::*;
use arrow_array::types::*;
use arrow_array::*;
use arrow_buffer::*;
use arrow_schema::{
    ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef,
};
use std::cmp::Ordering;
use std::collections::HashMap;
use std::io::Read;
use std::sync::Arc;

/// Decodes Avro-encoded data into [`RecordBatch`]
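///
/// # Example
///
/// A minimal sketch of the intended flow (not taken from the crate's docs), assuming an
/// [`AvroDataType`] describing a top-level record is already at hand and `buf` holds
/// `count` consecutive Avro-encoded records:
///
/// ```ignore
/// let mut decoder = RecordDecoder::try_new(&record_data_type)?;
/// let _bytes_consumed = decoder.decode(buf, count)?;
/// let batch = decoder.flush()?;
/// assert_eq!(batch.num_rows(), count);
/// ```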
pub struct RecordDecoder {
    schema: SchemaRef,
    fields: Vec<Decoder>,
    use_utf8view: bool,
}

impl RecordDecoder {
    /// Create a new [`RecordDecoder`] from the provided [`AvroDataType`] with default options
    pub fn try_new(data_type: &AvroDataType) -> Result<Self, ArrowError> {
        Self::try_new_with_options(data_type, ReadOptions::default())
    }

    /// Create a new [`RecordDecoder`] from the provided [`AvroDataType`] with additional options
    ///
    /// This method allows you to customize how the Avro data is decoded into Arrow arrays.
    ///
    /// # Parameters
    /// * `data_type` - The Avro data type to decode
    /// * `options` - Configuration options for decoding
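    ///
    /// # Example
    /// A minimal usage sketch (the variable names are illustrative only); with the default
    /// options this is equivalent to calling [`RecordDecoder::try_new`]:
    /// ```ignore
    /// let decoder = RecordDecoder::try_new_with_options(&record_data_type, ReadOptions::default())?;
    /// ```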
    pub fn try_new_with_options(
        data_type: &AvroDataType,
        options: ReadOptions,
    ) -> Result<Self, ArrowError> {
        match Decoder::try_new(data_type)? {
            Decoder::Record(fields, encodings) => Ok(Self {
                schema: Arc::new(ArrowSchema::new(fields)),
                fields: encodings,
                use_utf8view: options.use_utf8view(),
            }),
            encoding => Err(ArrowError::ParseError(format!(
                "Expected record got {encoding:?}"
            ))),
        }
    }

    /// Return the [`SchemaRef`] of the [`RecordBatch`]es produced by this decoder
    pub fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    /// Decode `count` records from `buf`, returning the number of bytes consumed
    pub fn decode(&mut self, buf: &[u8], count: usize) -> Result<usize, ArrowError> {
        let mut cursor = AvroCursor::new(buf);
        for _ in 0..count {
            for field in &mut self.fields {
                field.decode(&mut cursor)?;
            }
        }
        Ok(cursor.position())
    }

    /// Flush the decoded records into a [`RecordBatch`]
    pub fn flush(&mut self) -> Result<RecordBatch, ArrowError> {
        let arrays = self
            .fields
            .iter_mut()
            .map(|x| x.flush(None))
            .collect::<Result<Vec<_>, _>>()?;

        RecordBatch::try_new(self.schema.clone(), arrays)
    }
}

#[derive(Debug)]
enum Decoder {
    Null(usize),
    Boolean(BooleanBufferBuilder),
    Int32(Vec<i32>),
    Int64(Vec<i64>),
    Float32(Vec<f32>),
    Float64(Vec<f64>),
    Date32(Vec<i32>),
    TimeMillis(Vec<i32>),
    TimeMicros(Vec<i64>),
    TimestampMillis(bool, Vec<i64>),
    TimestampMicros(bool, Vec<i64>),
    Binary(OffsetBufferBuilder<i32>, Vec<u8>),
    /// String data encoded as UTF-8 bytes, mapped to Arrow's StringArray
    String(OffsetBufferBuilder<i32>, Vec<u8>),
    /// String data encoded as UTF-8 bytes, but mapped to Arrow's StringViewArray
    StringView(OffsetBufferBuilder<i32>, Vec<u8>),
    List(FieldRef, OffsetBufferBuilder<i32>, Box<Decoder>),
    Record(Fields, Vec<Decoder>),
    Map(
        FieldRef,
        OffsetBufferBuilder<i32>,
        OffsetBufferBuilder<i32>,
        Vec<u8>,
        Box<Decoder>,
    ),
    Nullable(Nullability, NullBufferBuilder, Box<Decoder>),
}

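// Each `Decoder` variant owns the in-progress column builders for one Avro value:
// primitive variants buffer native values, `Binary`/`String`/`StringView` keep an
// offset builder plus a byte buffer, and `Nullable` wraps an inner decoder together
// with a validity bitmap builder.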
impl Decoder {
    fn try_new(data_type: &AvroDataType) -> Result<Self, ArrowError> {
        let nyi = |s: &str| Err(ArrowError::NotYetImplemented(s.to_string()));

        let decoder = match data_type.codec() {
            Codec::Null => Self::Null(0),
            Codec::Boolean => Self::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)),
            Codec::Int32 => Self::Int32(Vec::with_capacity(DEFAULT_CAPACITY)),
            Codec::Int64 => Self::Int64(Vec::with_capacity(DEFAULT_CAPACITY)),
            Codec::Float32 => Self::Float32(Vec::with_capacity(DEFAULT_CAPACITY)),
            Codec::Float64 => Self::Float64(Vec::with_capacity(DEFAULT_CAPACITY)),
            Codec::Binary => Self::Binary(
                OffsetBufferBuilder::new(DEFAULT_CAPACITY),
                Vec::with_capacity(DEFAULT_CAPACITY),
            ),
            Codec::Utf8 => Self::String(
                OffsetBufferBuilder::new(DEFAULT_CAPACITY),
                Vec::with_capacity(DEFAULT_CAPACITY),
            ),
            Codec::Utf8View => Self::StringView(
                OffsetBufferBuilder::new(DEFAULT_CAPACITY),
                Vec::with_capacity(DEFAULT_CAPACITY),
            ),
            Codec::Date32 => Self::Date32(Vec::with_capacity(DEFAULT_CAPACITY)),
            Codec::TimeMillis => Self::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)),
            Codec::TimeMicros => Self::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)),
            Codec::TimestampMillis(is_utc) => {
                Self::TimestampMillis(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY))
            }
            Codec::TimestampMicros(is_utc) => {
                Self::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY))
            }
            Codec::Fixed(_) => return nyi("decoding fixed"),
            Codec::Interval => return nyi("decoding interval"),
            Codec::List(item) => {
                let decoder = Self::try_new(item)?;
                Self::List(
                    Arc::new(item.field_with_name("item")),
                    OffsetBufferBuilder::new(DEFAULT_CAPACITY),
                    Box::new(decoder),
                )
            }
            Codec::Struct(fields) => {
                let mut arrow_fields = Vec::with_capacity(fields.len());
                let mut encodings = Vec::with_capacity(fields.len());
                for avro_field in fields.iter() {
                    let encoding = Self::try_new(avro_field.data_type())?;
                    arrow_fields.push(avro_field.field());
                    encodings.push(encoding);
                }
                Self::Record(arrow_fields.into(), encodings)
            }
            Codec::Map(child) => {
                let val_field = child.field_with_name("value").with_nullable(true);
                let map_field = Arc::new(ArrowField::new(
                    "entries",
                    DataType::Struct(Fields::from(vec![
                        ArrowField::new("key", DataType::Utf8, false),
                        val_field,
                    ])),
                    false,
                ));
                let val_dec = Self::try_new(child)?;
                Self::Map(
                    map_field,
                    OffsetBufferBuilder::new(DEFAULT_CAPACITY),
                    OffsetBufferBuilder::new(DEFAULT_CAPACITY),
                    Vec::with_capacity(DEFAULT_CAPACITY),
                    Box::new(val_dec),
                )
            }
        };

        Ok(match data_type.nullability() {
            Some(nullability) => Self::Nullable(
                nullability,
                NullBufferBuilder::new(DEFAULT_CAPACITY),
                Box::new(decoder),
            ),
            None => decoder,
        })
    }

    /// Append a null record
    fn append_null(&mut self) {
        match self {
            Self::Null(count) => *count += 1,
            Self::Boolean(b) => b.append(false),
            Self::Int32(v) | Self::Date32(v) | Self::TimeMillis(v) => v.push(0),
            Self::Int64(v)
            | Self::TimeMicros(v)
            | Self::TimestampMillis(_, v)
            | Self::TimestampMicros(_, v) => v.push(0),
            Self::Float32(v) => v.push(0.),
            Self::Float64(v) => v.push(0.),
            Self::Binary(offsets, _) | Self::String(offsets, _) | Self::StringView(offsets, _) => {
                offsets.push_length(0);
            }
            Self::List(_, offsets, e) => {
                offsets.push_length(0);
                e.append_null();
            }
            Self::Record(_, e) => e.iter_mut().for_each(|e| e.append_null()),
            Self::Map(_, _koff, moff, _, _) => {
                moff.push_length(0);
            }
            Self::Nullable(_, _, _) => unreachable!("Nulls cannot be nested"),
        }
    }

    /// Decode a single record from `buf`
    fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> {
        match self {
            Self::Null(x) => *x += 1,
            Self::Boolean(values) => values.append(buf.get_bool()?),
            Self::Int32(values) | Self::Date32(values) | Self::TimeMillis(values) => {
                values.push(buf.get_int()?)
            }
            Self::Int64(values)
            | Self::TimeMicros(values)
            | Self::TimestampMillis(_, values)
            | Self::TimestampMicros(_, values) => values.push(buf.get_long()?),
            Self::Float32(values) => values.push(buf.get_float()?),
            Self::Float64(values) => values.push(buf.get_double()?),
            Self::Binary(offsets, values)
            | Self::String(offsets, values)
            | Self::StringView(offsets, values) => {
                let data = buf.get_bytes()?;
                offsets.push_length(data.len());
                values.extend_from_slice(data);
            }
            Self::List(_, _, _) => {
                return Err(ArrowError::NotYetImplemented(
                    "Decoding ListArray".to_string(),
                ))
            }
            Self::Record(_, encodings) => {
                for encoding in encodings {
                    encoding.decode(buf)?;
                }
            }
            Self::Map(_, koff, moff, kdata, valdec) => {
                let newly_added = read_map_blocks(buf, |cur| {
                    let kb = cur.get_bytes()?;
                    koff.push_length(kb.len());
                    kdata.extend_from_slice(kb);
                    valdec.decode(cur)
                })?;
                moff.push_length(newly_added);
            }
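            // A nullable value is an Avro union of null and the inner type: read the
            // branch indicator, record validity, then either decode the value or append
            // a placeholder null to keep the inner builder aligned.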
            Self::Nullable(nullability, nulls, e) => {
                let is_valid = buf.get_bool()? == matches!(nullability, Nullability::NullFirst);
                nulls.append(is_valid);
                match is_valid {
                    true => e.decode(buf)?,
                    false => e.append_null(),
                }
            }
        }
        Ok(())
    }

    /// Flush decoded records to an [`ArrayRef`]
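    ///
    /// This consumes the accumulated builders, leaving the decoder ready for the next batch.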
    fn flush(&mut self, nulls: Option<NullBuffer>) -> Result<ArrayRef, ArrowError> {
        Ok(match self {
            Self::Nullable(_, n, e) => e.flush(n.finish())?,
            Self::Null(size) => Arc::new(NullArray::new(std::mem::replace(size, 0))),
            Self::Boolean(b) => Arc::new(BooleanArray::new(b.finish(), nulls)),
            Self::Int32(values) => Arc::new(flush_primitive::<Int32Type>(values, nulls)),
            Self::Date32(values) => Arc::new(flush_primitive::<Date32Type>(values, nulls)),
            Self::Int64(values) => Arc::new(flush_primitive::<Int64Type>(values, nulls)),
            Self::TimeMillis(values) => {
                Arc::new(flush_primitive::<Time32MillisecondType>(values, nulls))
            }
            Self::TimeMicros(values) => {
                Arc::new(flush_primitive::<Time64MicrosecondType>(values, nulls))
            }
            Self::TimestampMillis(is_utc, values) => Arc::new(
                flush_primitive::<TimestampMillisecondType>(values, nulls)
                    .with_timezone_opt(is_utc.then(|| "+00:00")),
            ),
            Self::TimestampMicros(is_utc, values) => Arc::new(
                flush_primitive::<TimestampMicrosecondType>(values, nulls)
                    .with_timezone_opt(is_utc.then(|| "+00:00")),
            ),
            Self::Float32(values) => Arc::new(flush_primitive::<Float32Type>(values, nulls)),
            Self::Float64(values) => Arc::new(flush_primitive::<Float64Type>(values, nulls)),
            Self::Binary(offsets, values) => {
                let offsets = flush_offsets(offsets);
                let values = flush_values(values).into();
                Arc::new(BinaryArray::new(offsets, values, nulls))
            }
            Self::String(offsets, values) => {
                let offsets = flush_offsets(offsets);
                let values = flush_values(values).into();
                Arc::new(StringArray::new(offsets, values, nulls))
            }
            Self::StringView(offsets, values) => {
                let offsets = flush_offsets(offsets);
                let values = flush_values(values);
                // Build an intermediate StringArray, then copy it into a StringViewArray,
                // carrying each slot's validity across so nulls are not silently dropped.
                let array = StringArray::new(offsets, values.into(), nulls);
                let view_array: StringViewArray = (0..array.len())
                    .map(|i| array.is_valid(i).then(|| array.value(i)))
                    .collect();
                Arc::new(view_array)
            }
            Self::List(field, offsets, values) => {
                let values = values.flush(None)?;
                let offsets = flush_offsets(offsets);
                Arc::new(ListArray::new(field.clone(), offsets, values, nulls))
            }
            Self::Record(fields, encodings) => {
                let arrays = encodings
                    .iter_mut()
                    .map(|x| x.flush(None))
                    .collect::<Result<Vec<_>, _>>()?;
                Arc::new(StructArray::new(fields.clone(), arrays, nulls))
            }
            Self::Map(map_field, k_off, m_off, kdata, valdec) => {
                let moff = flush_offsets(m_off);
                let koff = flush_offsets(k_off);
                let kd = flush_values(kdata).into();
                let val_arr = valdec.flush(None)?;
                let key_arr = StringArray::new(koff, kd, None);
                if key_arr.len() != val_arr.len() {
                    return Err(ArrowError::InvalidArgumentError(format!(
                        "Map keys length ({}) != map values length ({})",
                        key_arr.len(),
                        val_arr.len()
                    )));
                }
                let final_len = moff.len() - 1;
                if let Some(n) = &nulls {
                    if n.len() != final_len {
                        return Err(ArrowError::InvalidArgumentError(format!(
                            "Map array null buffer length {} != final map length {final_len}",
                            n.len()
                        )));
                    }
                }
                let entries_struct = StructArray::new(
                    Fields::from(vec![
                        Arc::new(ArrowField::new("key", DataType::Utf8, false)),
                        Arc::new(ArrowField::new("value", val_arr.data_type().clone(), true)),
                    ]),
                    vec![Arc::new(key_arr), val_arr],
                    None,
                );
                let map_arr = MapArray::new(map_field.clone(), moff, entries_struct, nulls, false);
                Arc::new(map_arr)
            }
        })
    }
}

fn read_map_blocks(
    buf: &mut AvroCursor,
    decode_entry: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>,
) -> Result<usize, ArrowError> {
    read_blockwise_items(buf, true, decode_entry)
}

fn read_blockwise_items(
    buf: &mut AvroCursor,
    read_size_after_negative: bool,
    mut decode_fn: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>,
) -> Result<usize, ArrowError> {
    let mut total = 0usize;
    loop {
        // Read the block count
        //  positive = that many items
        //  negative = that many items + read block size
        //  See: https://avro.apache.org/docs/1.11.1/specification/#maps
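        //
        //  For example, a map with two entries may be encoded as
        //    long(2), entry, entry, long(0)
        //  or, when the writer prefixes the block size in bytes,
        //    long(-2), long(size_in_bytes), entry, entry, long(0)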
        let block_count = buf.get_long()?;
        match block_count.cmp(&0) {
            Ordering::Equal => break,
            Ordering::Less => {
                // If block_count is negative, read the absolute value of count,
                // then read the block size as a long and discard
                let count = (-block_count) as usize;
                if read_size_after_negative {
                    let _size_in_bytes = buf.get_long()?;
                }
                for _ in 0..count {
                    decode_fn(buf)?;
                }
                total += count;
            }
            Ordering::Greater => {
                // If block_count is positive, decode that many items
                let count = block_count as usize;
                for _i in 0..count {
                    decode_fn(buf)?;
                }
                total += count;
            }
        }
    }
    Ok(total)
}

#[inline]
fn flush_values<T>(values: &mut Vec<T>) -> Vec<T> {
    std::mem::replace(values, Vec::with_capacity(DEFAULT_CAPACITY))
}

#[inline]
fn flush_offsets(offsets: &mut OffsetBufferBuilder<i32>) -> OffsetBuffer<i32> {
    std::mem::replace(offsets, OffsetBufferBuilder::new(DEFAULT_CAPACITY)).finish()
}

#[inline]
fn flush_primitive<T: ArrowPrimitiveType>(
    values: &mut Vec<T::Native>,
    nulls: Option<NullBuffer>,
) -> PrimitiveArray<T> {
    PrimitiveArray::new(flush_values(values).into(), nulls)
}

const DEFAULT_CAPACITY: usize = 1024;

#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{
        cast::AsArray, Array, Decimal128Array, DictionaryArray, FixedSizeBinaryArray,
        IntervalMonthDayNanoArray, ListArray, MapArray, StringArray, StructArray,
    };

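    /// Encode a signed long the way Avro does: zig-zag the value, then emit it as a
    /// little-endian base-128 varint (e.g. 0 -> [0x00], 1 -> [0x02], -1 -> [0x01]).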
    fn encode_avro_long(value: i64) -> Vec<u8> {
        let mut buf = Vec::new();
        let mut v = (value << 1) ^ (value >> 63);
        while v & !0x7F != 0 {
            buf.push(((v & 0x7F) | 0x80) as u8);
            v >>= 7;
        }
        buf.push(v as u8);
        buf
    }

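    /// Encode Avro `bytes`/`string`: the length as an Avro long followed by the raw bytes.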
    fn encode_avro_bytes(bytes: &[u8]) -> Vec<u8> {
        let mut buf = encode_avro_long(bytes.len() as i64);
        buf.extend_from_slice(bytes);
        buf
    }

    fn avro_from_codec(codec: Codec) -> AvroDataType {
        AvroDataType::new(codec, Default::default(), None)
    }

    #[test]
    fn test_map_decoding_one_entry() {
        let value_type = avro_from_codec(Codec::Utf8);
        let map_type = avro_from_codec(Codec::Map(Arc::new(value_type)));
        let mut decoder = Decoder::try_new(&map_type).unwrap();
        // Encode a single map with one entry: {"hello": "world"}
        let mut data = Vec::new();
        data.extend_from_slice(&encode_avro_long(1));
        data.extend_from_slice(&encode_avro_bytes(b"hello")); // key
        data.extend_from_slice(&encode_avro_bytes(b"world")); // value
        data.extend_from_slice(&encode_avro_long(0));
        let mut cursor = AvroCursor::new(&data);
        decoder.decode(&mut cursor).unwrap();
        let array = decoder.flush(None).unwrap();
        let map_arr = array.as_any().downcast_ref::<MapArray>().unwrap();
        assert_eq!(map_arr.len(), 1); // one map
        assert_eq!(map_arr.value_length(0), 1);
        let entries = map_arr.value(0);
        let struct_entries = entries.as_any().downcast_ref::<StructArray>().unwrap();
        assert_eq!(struct_entries.len(), 1);
        let key_arr = struct_entries
            .column_by_name("key")
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let val_arr = struct_entries
            .column_by_name("value")
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(key_arr.value(0), "hello");
        assert_eq!(val_arr.value(0), "world");
    }

    #[test]
    fn test_map_decoding_empty() {
        let value_type = avro_from_codec(Codec::Utf8);
        let map_type = avro_from_codec(Codec::Map(Arc::new(value_type)));
        let mut decoder = Decoder::try_new(&map_type).unwrap();
        let data = encode_avro_long(0);
        decoder.decode(&mut AvroCursor::new(&data)).unwrap();
        let array = decoder.flush(None).unwrap();
        let map_arr = array.as_any().downcast_ref::<MapArray>().unwrap();
        assert_eq!(map_arr.len(), 1);
        assert_eq!(map_arr.value_length(0), 0);
    }
}