arrow_avro/reader/
record.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::codec::{AvroDataType, Codec, Nullability};
19use crate::reader::block::{Block, BlockDecoder};
20use crate::reader::cursor::AvroCursor;
21use crate::reader::header::Header;
22use crate::schema::*;
23use arrow_array::types::*;
24use arrow_array::*;
25use arrow_buffer::*;
26use arrow_schema::{
27    ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef,
28};
29use std::collections::HashMap;
30use std::io::Read;
31use std::sync::Arc;
32
33/// Decodes avro encoded data into [`RecordBatch`]
34pub struct RecordDecoder {
35    schema: SchemaRef,
36    fields: Vec<Decoder>,
37}
38
39impl RecordDecoder {
40    /// Create a new [`RecordDecoder`] from the provided [`AvroDataType`]
41    pub fn try_new(data_type: &AvroDataType) -> Result<Self, ArrowError> {
42        match Decoder::try_new(data_type)? {
43            Decoder::Record(fields, encodings) => Ok(Self {
44                schema: Arc::new(ArrowSchema::new(fields)),
45                fields: encodings,
46            }),
47            encoding => Err(ArrowError::ParseError(format!(
48                "Expected record got {encoding:?}"
49            ))),
50        }
51    }
52
53    pub fn schema(&self) -> &SchemaRef {
54        &self.schema
55    }
56
57    /// Decode `count` records from `buf`
58    pub fn decode(&mut self, buf: &[u8], count: usize) -> Result<usize, ArrowError> {
59        let mut cursor = AvroCursor::new(buf);
60        for _ in 0..count {
61            for field in &mut self.fields {
62                field.decode(&mut cursor)?;
63            }
64        }
65        Ok(cursor.position())
66    }
67
68    /// Flush the decoded records into a [`RecordBatch`]
69    pub fn flush(&mut self) -> Result<RecordBatch, ArrowError> {
70        let arrays = self
71            .fields
72            .iter_mut()
73            .map(|x| x.flush(None))
74            .collect::<Result<Vec<_>, _>>()?;
75
76        RecordBatch::try_new(self.schema.clone(), arrays)
77    }
78}
79
80#[derive(Debug)]
81enum Decoder {
82    Null(usize),
83    Boolean(BooleanBufferBuilder),
84    Int32(Vec<i32>),
85    Int64(Vec<i64>),
86    Float32(Vec<f32>),
87    Float64(Vec<f64>),
88    Date32(Vec<i32>),
89    TimeMillis(Vec<i32>),
90    TimeMicros(Vec<i64>),
91    TimestampMillis(bool, Vec<i64>),
92    TimestampMicros(bool, Vec<i64>),
93    Binary(OffsetBufferBuilder<i32>, Vec<u8>),
94    String(OffsetBufferBuilder<i32>, Vec<u8>),
95    List(FieldRef, OffsetBufferBuilder<i32>, Box<Decoder>),
96    Record(Fields, Vec<Decoder>),
97    Nullable(Nullability, NullBufferBuilder, Box<Decoder>),
98}
99
100impl Decoder {
101    fn try_new(data_type: &AvroDataType) -> Result<Self, ArrowError> {
102        let nyi = |s: &str| Err(ArrowError::NotYetImplemented(s.to_string()));
103
104        let decoder = match data_type.codec() {
105            Codec::Null => Self::Null(0),
106            Codec::Boolean => Self::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)),
107            Codec::Int32 => Self::Int32(Vec::with_capacity(DEFAULT_CAPACITY)),
108            Codec::Int64 => Self::Int64(Vec::with_capacity(DEFAULT_CAPACITY)),
109            Codec::Float32 => Self::Float32(Vec::with_capacity(DEFAULT_CAPACITY)),
110            Codec::Float64 => Self::Float64(Vec::with_capacity(DEFAULT_CAPACITY)),
111            Codec::Binary => Self::Binary(
112                OffsetBufferBuilder::new(DEFAULT_CAPACITY),
113                Vec::with_capacity(DEFAULT_CAPACITY),
114            ),
115            Codec::Utf8 => Self::String(
116                OffsetBufferBuilder::new(DEFAULT_CAPACITY),
117                Vec::with_capacity(DEFAULT_CAPACITY),
118            ),
119            Codec::Date32 => Self::Date32(Vec::with_capacity(DEFAULT_CAPACITY)),
120            Codec::TimeMillis => Self::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)),
121            Codec::TimeMicros => Self::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)),
122            Codec::TimestampMillis(is_utc) => {
123                Self::TimestampMillis(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY))
124            }
125            Codec::TimestampMicros(is_utc) => {
126                Self::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY))
127            }
128            Codec::Fixed(_) => return nyi("decoding fixed"),
129            Codec::Interval => return nyi("decoding interval"),
130            Codec::List(item) => {
131                let decoder = Self::try_new(item)?;
132                Self::List(
133                    Arc::new(item.field_with_name("item")),
134                    OffsetBufferBuilder::new(DEFAULT_CAPACITY),
135                    Box::new(decoder),
136                )
137            }
138            Codec::Struct(fields) => {
139                let mut arrow_fields = Vec::with_capacity(fields.len());
140                let mut encodings = Vec::with_capacity(fields.len());
141                for avro_field in fields.iter() {
142                    let encoding = Self::try_new(avro_field.data_type())?;
143                    arrow_fields.push(avro_field.field());
144                    encodings.push(encoding);
145                }
146                Self::Record(arrow_fields.into(), encodings)
147            }
148        };
149
150        Ok(match data_type.nullability() {
151            Some(nullability) => Self::Nullable(
152                nullability,
153                NullBufferBuilder::new(DEFAULT_CAPACITY),
154                Box::new(decoder),
155            ),
156            None => decoder,
157        })
158    }
159
160    /// Append a null record
161    fn append_null(&mut self) {
162        match self {
163            Self::Null(count) => *count += 1,
164            Self::Boolean(b) => b.append(false),
165            Self::Int32(v) | Self::Date32(v) | Self::TimeMillis(v) => v.push(0),
166            Self::Int64(v)
167            | Self::TimeMicros(v)
168            | Self::TimestampMillis(_, v)
169            | Self::TimestampMicros(_, v) => v.push(0),
170            Self::Float32(v) => v.push(0.),
171            Self::Float64(v) => v.push(0.),
172            Self::Binary(offsets, _) | Self::String(offsets, _) => offsets.push_length(0),
173            Self::List(_, offsets, e) => {
174                offsets.push_length(0);
175                e.append_null();
176            }
177            Self::Record(_, e) => e.iter_mut().for_each(|e| e.append_null()),
178            Self::Nullable(_, _, _) => unreachable!("Nulls cannot be nested"),
179        }
180    }
181
182    /// Decode a single record from `buf`
183    fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> {
184        match self {
185            Self::Null(x) => *x += 1,
186            Self::Boolean(values) => values.append(buf.get_bool()?),
187            Self::Int32(values) | Self::Date32(values) | Self::TimeMillis(values) => {
188                values.push(buf.get_int()?)
189            }
190            Self::Int64(values)
191            | Self::TimeMicros(values)
192            | Self::TimestampMillis(_, values)
193            | Self::TimestampMicros(_, values) => values.push(buf.get_long()?),
194            Self::Float32(values) => values.push(buf.get_float()?),
195            Self::Float64(values) => values.push(buf.get_double()?),
196            Self::Binary(offsets, values) | Self::String(offsets, values) => {
197                let data = buf.get_bytes()?;
198                offsets.push_length(data.len());
199                values.extend_from_slice(data);
200            }
201            Self::List(_, _, _) => {
202                return Err(ArrowError::NotYetImplemented(
203                    "Decoding ListArray".to_string(),
204                ))
205            }
206            Self::Record(_, encodings) => {
207                for encoding in encodings {
208                    encoding.decode(buf)?;
209                }
210            }
211            Self::Nullable(nullability, nulls, e) => {
212                let is_valid = buf.get_bool()? == matches!(nullability, Nullability::NullFirst);
213                nulls.append(is_valid);
214                match is_valid {
215                    true => e.decode(buf)?,
216                    false => e.append_null(),
217                }
218            }
219        }
220        Ok(())
221    }
222
223    /// Flush decoded records to an [`ArrayRef`]
224    fn flush(&mut self, nulls: Option<NullBuffer>) -> Result<ArrayRef, ArrowError> {
225        Ok(match self {
226            Self::Nullable(_, n, e) => e.flush(n.finish())?,
227            Self::Null(size) => Arc::new(NullArray::new(std::mem::replace(size, 0))),
228            Self::Boolean(b) => Arc::new(BooleanArray::new(b.finish(), nulls)),
229            Self::Int32(values) => Arc::new(flush_primitive::<Int32Type>(values, nulls)),
230            Self::Date32(values) => Arc::new(flush_primitive::<Date32Type>(values, nulls)),
231            Self::Int64(values) => Arc::new(flush_primitive::<Int64Type>(values, nulls)),
232            Self::TimeMillis(values) => {
233                Arc::new(flush_primitive::<Time32MillisecondType>(values, nulls))
234            }
235            Self::TimeMicros(values) => {
236                Arc::new(flush_primitive::<Time64MicrosecondType>(values, nulls))
237            }
238            Self::TimestampMillis(is_utc, values) => Arc::new(
239                flush_primitive::<TimestampMillisecondType>(values, nulls)
240                    .with_timezone_opt(is_utc.then(|| "+00:00")),
241            ),
242            Self::TimestampMicros(is_utc, values) => Arc::new(
243                flush_primitive::<TimestampMicrosecondType>(values, nulls)
244                    .with_timezone_opt(is_utc.then(|| "+00:00")),
245            ),
246            Self::Float32(values) => Arc::new(flush_primitive::<Float32Type>(values, nulls)),
247            Self::Float64(values) => Arc::new(flush_primitive::<Float64Type>(values, nulls)),
248
249            Self::Binary(offsets, values) => {
250                let offsets = flush_offsets(offsets);
251                let values = flush_values(values).into();
252                Arc::new(BinaryArray::new(offsets, values, nulls))
253            }
254            Self::String(offsets, values) => {
255                let offsets = flush_offsets(offsets);
256                let values = flush_values(values).into();
257                Arc::new(StringArray::new(offsets, values, nulls))
258            }
259            Self::List(field, offsets, values) => {
260                let values = values.flush(None)?;
261                let offsets = flush_offsets(offsets);
262                Arc::new(ListArray::new(field.clone(), offsets, values, nulls))
263            }
264            Self::Record(fields, encodings) => {
265                let arrays = encodings
266                    .iter_mut()
267                    .map(|x| x.flush(None))
268                    .collect::<Result<Vec<_>, _>>()?;
269                Arc::new(StructArray::new(fields.clone(), arrays, nulls))
270            }
271        })
272    }
273}
274
275#[inline]
276fn flush_values<T>(values: &mut Vec<T>) -> Vec<T> {
277    std::mem::replace(values, Vec::with_capacity(DEFAULT_CAPACITY))
278}
279
280#[inline]
281fn flush_offsets(offsets: &mut OffsetBufferBuilder<i32>) -> OffsetBuffer<i32> {
282    std::mem::replace(offsets, OffsetBufferBuilder::new(DEFAULT_CAPACITY)).finish()
283}
284
285#[inline]
286fn flush_primitive<T: ArrowPrimitiveType>(
287    values: &mut Vec<T::Native>,
288    nulls: Option<NullBuffer>,
289) -> PrimitiveArray<T> {
290    PrimitiveArray::new(flush_values(values).into(), nulls)
291}
292
/// Initial capacity used for every value, offset and null buffer, both when
/// a decoder is created and when a buffer is replaced after a flush.
const DEFAULT_CAPACITY: usize = 1024;