arrow_avro/reader/
record.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::codec::{AvroDataType, Codec, Nullability};
19use crate::reader::block::{Block, BlockDecoder};
20use crate::reader::cursor::AvroCursor;
21use crate::reader::header::Header;
22use crate::schema::*;
23use arrow_array::types::*;
24use arrow_array::*;
25use arrow_buffer::*;
26use arrow_schema::{
27    ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef,
28};
29use std::collections::HashMap;
30use std::io::Read;
31use std::sync::Arc;
32
33/// Decodes avro encoded data into [`RecordBatch`]
34pub struct RecordDecoder {
35    schema: SchemaRef,
36    fields: Vec<Decoder>,
37}
38
39impl RecordDecoder {
40    pub fn try_new(data_type: &AvroDataType) -> Result<Self, ArrowError> {
41        match Decoder::try_new(data_type)? {
42            Decoder::Record(fields, encodings) => Ok(Self {
43                schema: Arc::new(ArrowSchema::new(fields)),
44                fields: encodings,
45            }),
46            encoding => Err(ArrowError::ParseError(format!(
47                "Expected record got {encoding:?}"
48            ))),
49        }
50    }
51
52    pub fn schema(&self) -> &SchemaRef {
53        &self.schema
54    }
55
56    /// Decode `count` records from `buf`
57    pub fn decode(&mut self, buf: &[u8], count: usize) -> Result<usize, ArrowError> {
58        let mut cursor = AvroCursor::new(buf);
59        for _ in 0..count {
60            for field in &mut self.fields {
61                field.decode(&mut cursor)?;
62            }
63        }
64        Ok(cursor.position())
65    }
66
67    /// Flush the decoded records into a [`RecordBatch`]
68    pub fn flush(&mut self) -> Result<RecordBatch, ArrowError> {
69        let arrays = self
70            .fields
71            .iter_mut()
72            .map(|x| x.flush(None))
73            .collect::<Result<Vec<_>, _>>()?;
74
75        RecordBatch::try_new(self.schema.clone(), arrays)
76    }
77}
78
79#[derive(Debug)]
80enum Decoder {
81    Null(usize),
82    Boolean(BooleanBufferBuilder),
83    Int32(Vec<i32>),
84    Int64(Vec<i64>),
85    Float32(Vec<f32>),
86    Float64(Vec<f64>),
87    Date32(Vec<i32>),
88    TimeMillis(Vec<i32>),
89    TimeMicros(Vec<i64>),
90    TimestampMillis(bool, Vec<i64>),
91    TimestampMicros(bool, Vec<i64>),
92    Binary(OffsetBufferBuilder<i32>, Vec<u8>),
93    String(OffsetBufferBuilder<i32>, Vec<u8>),
94    List(FieldRef, OffsetBufferBuilder<i32>, Box<Decoder>),
95    Record(Fields, Vec<Decoder>),
96    Nullable(Nullability, NullBufferBuilder, Box<Decoder>),
97}
98
99impl Decoder {
100    fn try_new(data_type: &AvroDataType) -> Result<Self, ArrowError> {
101        let nyi = |s: &str| Err(ArrowError::NotYetImplemented(s.to_string()));
102
103        let decoder = match data_type.codec() {
104            Codec::Null => Self::Null(0),
105            Codec::Boolean => Self::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)),
106            Codec::Int32 => Self::Int32(Vec::with_capacity(DEFAULT_CAPACITY)),
107            Codec::Int64 => Self::Int64(Vec::with_capacity(DEFAULT_CAPACITY)),
108            Codec::Float32 => Self::Float32(Vec::with_capacity(DEFAULT_CAPACITY)),
109            Codec::Float64 => Self::Float64(Vec::with_capacity(DEFAULT_CAPACITY)),
110            Codec::Binary => Self::Binary(
111                OffsetBufferBuilder::new(DEFAULT_CAPACITY),
112                Vec::with_capacity(DEFAULT_CAPACITY),
113            ),
114            Codec::Utf8 => Self::String(
115                OffsetBufferBuilder::new(DEFAULT_CAPACITY),
116                Vec::with_capacity(DEFAULT_CAPACITY),
117            ),
118            Codec::Date32 => Self::Date32(Vec::with_capacity(DEFAULT_CAPACITY)),
119            Codec::TimeMillis => Self::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)),
120            Codec::TimeMicros => Self::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)),
121            Codec::TimestampMillis(is_utc) => {
122                Self::TimestampMillis(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY))
123            }
124            Codec::TimestampMicros(is_utc) => {
125                Self::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY))
126            }
127            Codec::Fixed(_) => return nyi("decoding fixed"),
128            Codec::Interval => return nyi("decoding interval"),
129            Codec::List(item) => {
130                let decoder = Self::try_new(item)?;
131                Self::List(
132                    Arc::new(item.field_with_name("item")),
133                    OffsetBufferBuilder::new(DEFAULT_CAPACITY),
134                    Box::new(decoder),
135                )
136            }
137            Codec::Struct(fields) => {
138                let mut arrow_fields = Vec::with_capacity(fields.len());
139                let mut encodings = Vec::with_capacity(fields.len());
140                for avro_field in fields.iter() {
141                    let encoding = Self::try_new(avro_field.data_type())?;
142                    arrow_fields.push(avro_field.field());
143                    encodings.push(encoding);
144                }
145                Self::Record(arrow_fields.into(), encodings)
146            }
147        };
148
149        Ok(match data_type.nullability() {
150            Some(nullability) => Self::Nullable(
151                nullability,
152                NullBufferBuilder::new(DEFAULT_CAPACITY),
153                Box::new(decoder),
154            ),
155            None => decoder,
156        })
157    }
158
159    /// Append a null record
160    fn append_null(&mut self) {
161        match self {
162            Self::Null(count) => *count += 1,
163            Self::Boolean(b) => b.append(false),
164            Self::Int32(v) | Self::Date32(v) | Self::TimeMillis(v) => v.push(0),
165            Self::Int64(v)
166            | Self::TimeMicros(v)
167            | Self::TimestampMillis(_, v)
168            | Self::TimestampMicros(_, v) => v.push(0),
169            Self::Float32(v) => v.push(0.),
170            Self::Float64(v) => v.push(0.),
171            Self::Binary(offsets, _) | Self::String(offsets, _) => offsets.push_length(0),
172            Self::List(_, offsets, e) => {
173                offsets.push_length(0);
174                e.append_null();
175            }
176            Self::Record(_, e) => e.iter_mut().for_each(|e| e.append_null()),
177            Self::Nullable(_, _, _) => unreachable!("Nulls cannot be nested"),
178        }
179    }
180
181    /// Decode a single record from `buf`
182    fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> {
183        match self {
184            Self::Null(x) => *x += 1,
185            Self::Boolean(values) => values.append(buf.get_bool()?),
186            Self::Int32(values) | Self::Date32(values) | Self::TimeMillis(values) => {
187                values.push(buf.get_int()?)
188            }
189            Self::Int64(values)
190            | Self::TimeMicros(values)
191            | Self::TimestampMillis(_, values)
192            | Self::TimestampMicros(_, values) => values.push(buf.get_long()?),
193            Self::Float32(values) => values.push(buf.get_float()?),
194            Self::Float64(values) => values.push(buf.get_double()?),
195            Self::Binary(offsets, values) | Self::String(offsets, values) => {
196                let data = buf.get_bytes()?;
197                offsets.push_length(data.len());
198                values.extend_from_slice(data);
199            }
200            Self::List(_, _, _) => {
201                return Err(ArrowError::NotYetImplemented(
202                    "Decoding ListArray".to_string(),
203                ))
204            }
205            Self::Record(_, encodings) => {
206                for encoding in encodings {
207                    encoding.decode(buf)?;
208                }
209            }
210            Self::Nullable(nullability, nulls, e) => {
211                let is_valid = buf.get_bool()? == matches!(nullability, Nullability::NullFirst);
212                nulls.append(is_valid);
213                match is_valid {
214                    true => e.decode(buf)?,
215                    false => e.append_null(),
216                }
217            }
218        }
219        Ok(())
220    }
221
222    /// Flush decoded records to an [`ArrayRef`]
223    fn flush(&mut self, nulls: Option<NullBuffer>) -> Result<ArrayRef, ArrowError> {
224        Ok(match self {
225            Self::Nullable(_, n, e) => e.flush(n.finish())?,
226            Self::Null(size) => Arc::new(NullArray::new(std::mem::replace(size, 0))),
227            Self::Boolean(b) => Arc::new(BooleanArray::new(b.finish(), nulls)),
228            Self::Int32(values) => Arc::new(flush_primitive::<Int32Type>(values, nulls)),
229            Self::Date32(values) => Arc::new(flush_primitive::<Date32Type>(values, nulls)),
230            Self::Int64(values) => Arc::new(flush_primitive::<Int64Type>(values, nulls)),
231            Self::TimeMillis(values) => {
232                Arc::new(flush_primitive::<Time32MillisecondType>(values, nulls))
233            }
234            Self::TimeMicros(values) => {
235                Arc::new(flush_primitive::<Time64MicrosecondType>(values, nulls))
236            }
237            Self::TimestampMillis(is_utc, values) => Arc::new(
238                flush_primitive::<TimestampMillisecondType>(values, nulls)
239                    .with_timezone_opt(is_utc.then(|| "+00:00")),
240            ),
241            Self::TimestampMicros(is_utc, values) => Arc::new(
242                flush_primitive::<TimestampMicrosecondType>(values, nulls)
243                    .with_timezone_opt(is_utc.then(|| "+00:00")),
244            ),
245            Self::Float32(values) => Arc::new(flush_primitive::<Float32Type>(values, nulls)),
246            Self::Float64(values) => Arc::new(flush_primitive::<Float64Type>(values, nulls)),
247
248            Self::Binary(offsets, values) => {
249                let offsets = flush_offsets(offsets);
250                let values = flush_values(values).into();
251                Arc::new(BinaryArray::new(offsets, values, nulls))
252            }
253            Self::String(offsets, values) => {
254                let offsets = flush_offsets(offsets);
255                let values = flush_values(values).into();
256                Arc::new(StringArray::new(offsets, values, nulls))
257            }
258            Self::List(field, offsets, values) => {
259                let values = values.flush(None)?;
260                let offsets = flush_offsets(offsets);
261                Arc::new(ListArray::new(field.clone(), offsets, values, nulls))
262            }
263            Self::Record(fields, encodings) => {
264                let arrays = encodings
265                    .iter_mut()
266                    .map(|x| x.flush(None))
267                    .collect::<Result<Vec<_>, _>>()?;
268                Arc::new(StructArray::new(fields.clone(), arrays, nulls))
269            }
270        })
271    }
272}
273
274#[inline]
275fn flush_values<T>(values: &mut Vec<T>) -> Vec<T> {
276    std::mem::replace(values, Vec::with_capacity(DEFAULT_CAPACITY))
277}
278
279#[inline]
280fn flush_offsets(offsets: &mut OffsetBufferBuilder<i32>) -> OffsetBuffer<i32> {
281    std::mem::replace(offsets, OffsetBufferBuilder::new(DEFAULT_CAPACITY)).finish()
282}
283
284#[inline]
285fn flush_primitive<T: ArrowPrimitiveType>(
286    values: &mut Vec<T::Native>,
287    nulls: Option<NullBuffer>,
288) -> PrimitiveArray<T> {
289    PrimitiveArray::new(flush_values(values).into(), nulls)
290}
291
/// Initial capacity reserved for each decoder's buffers, both when a decoder
/// is created and again after every flush.
const DEFAULT_CAPACITY: usize = 1 << 10;