1use crate::codec::{AvroDataType, Codec, Nullability};
19use crate::reader::block::{Block, BlockDecoder};
20use crate::reader::cursor::AvroCursor;
21use crate::reader::header::Header;
22use crate::reader::ReadOptions;
23use crate::schema::*;
24use arrow_array::types::*;
25use arrow_array::*;
26use arrow_buffer::*;
27use arrow_schema::{
28 ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef,
29};
30use std::cmp::Ordering;
31use std::collections::HashMap;
32use std::io::Read;
33use std::sync::Arc;
34
35pub struct RecordDecoder {
37 schema: SchemaRef,
38 fields: Vec<Decoder>,
39 use_utf8view: bool,
40}
41
42impl RecordDecoder {
43 pub fn try_new(data_type: &AvroDataType) -> Result<Self, ArrowError> {
45 Self::try_new_with_options(data_type, ReadOptions::default())
46 }
47
48 pub fn try_new_with_options(
56 data_type: &AvroDataType,
57 options: ReadOptions,
58 ) -> Result<Self, ArrowError> {
59 match Decoder::try_new(data_type)? {
60 Decoder::Record(fields, encodings) => Ok(Self {
61 schema: Arc::new(ArrowSchema::new(fields)),
62 fields: encodings,
63 use_utf8view: options.use_utf8view(),
64 }),
65 encoding => Err(ArrowError::ParseError(format!(
66 "Expected record got {encoding:?}"
67 ))),
68 }
69 }
70
71 pub fn schema(&self) -> &SchemaRef {
72 &self.schema
73 }
74
75 pub fn decode(&mut self, buf: &[u8], count: usize) -> Result<usize, ArrowError> {
77 let mut cursor = AvroCursor::new(buf);
78 for _ in 0..count {
79 for field in &mut self.fields {
80 field.decode(&mut cursor)?;
81 }
82 }
83 Ok(cursor.position())
84 }
85
86 pub fn flush(&mut self) -> Result<RecordBatch, ArrowError> {
88 let arrays = self
89 .fields
90 .iter_mut()
91 .map(|x| x.flush(None))
92 .collect::<Result<Vec<_>, _>>()?;
93
94 RecordBatch::try_new(self.schema.clone(), arrays)
95 }
96}
97
/// Column-level decoder that buffers Avro values until they are flushed into an
/// Arrow array.
#[derive(Debug)]
enum Decoder {
    /// Number of values decoded so far; flushed as a `NullArray` of that length.
    Null(usize),
    Boolean(BooleanBufferBuilder),
    Int32(Vec<i32>),
    Int64(Vec<i64>),
    Float32(Vec<f32>),
    Float64(Vec<f64>),
    Date32(Vec<i32>),
    TimeMillis(Vec<i32>),
    TimeMicros(Vec<i64>),
    /// `bool` is the `is_utc` flag; when set, the flushed array gets a "+00:00" timezone.
    TimestampMillis(bool, Vec<i64>),
    /// `bool` is the `is_utc` flag; when set, the flushed array gets a "+00:00" timezone.
    TimestampMicros(bool, Vec<i64>),
    /// Offsets and concatenated bytes of each value.
    Binary(OffsetBufferBuilder<i32>, Vec<u8>),
    /// Offsets and concatenated UTF-8 bytes of each value.
    String(OffsetBufferBuilder<i32>, Vec<u8>),
    /// Buffered like `String`; converted to a `StringViewArray` at flush time.
    StringView(OffsetBufferBuilder<i32>, Vec<u8>),
    /// Item field, list offsets, and the item decoder (decoding is not yet implemented).
    List(FieldRef, OffsetBufferBuilder<i32>, Box<Decoder>),
    /// Arrow fields of the record and one decoder per field.
    Record(Fields, Vec<Decoder>),
    /// Map column. The payload fields are, in order:
    /// - the `entries` field;
    /// - the key offsets;
    /// - the per-row map offsets;
    /// - the concatenated key bytes;
    /// - the value decoder.
    Map(
        FieldRef,
        OffsetBufferBuilder<i32>,
        OffsetBufferBuilder<i32>,
        Vec<u8>,
        Box<Decoder>,
    ),
    /// Wraps a decoder with validity tracking; the `Nullability` decides which branch
    /// selector means "present".
    Nullable(Nullability, NullBufferBuilder, Box<Decoder>),
}
127
128impl Decoder {
129 fn try_new(data_type: &AvroDataType) -> Result<Self, ArrowError> {
130 let nyi = |s: &str| Err(ArrowError::NotYetImplemented(s.to_string()));
131
132 let decoder = match data_type.codec() {
133 Codec::Null => Self::Null(0),
134 Codec::Boolean => Self::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)),
135 Codec::Int32 => Self::Int32(Vec::with_capacity(DEFAULT_CAPACITY)),
136 Codec::Int64 => Self::Int64(Vec::with_capacity(DEFAULT_CAPACITY)),
137 Codec::Float32 => Self::Float32(Vec::with_capacity(DEFAULT_CAPACITY)),
138 Codec::Float64 => Self::Float64(Vec::with_capacity(DEFAULT_CAPACITY)),
139 Codec::Binary => Self::Binary(
140 OffsetBufferBuilder::new(DEFAULT_CAPACITY),
141 Vec::with_capacity(DEFAULT_CAPACITY),
142 ),
143 Codec::Utf8 => Self::String(
144 OffsetBufferBuilder::new(DEFAULT_CAPACITY),
145 Vec::with_capacity(DEFAULT_CAPACITY),
146 ),
147 Codec::Utf8View => Self::StringView(
148 OffsetBufferBuilder::new(DEFAULT_CAPACITY),
149 Vec::with_capacity(DEFAULT_CAPACITY),
150 ),
151 Codec::Date32 => Self::Date32(Vec::with_capacity(DEFAULT_CAPACITY)),
152 Codec::TimeMillis => Self::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)),
153 Codec::TimeMicros => Self::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)),
154 Codec::TimestampMillis(is_utc) => {
155 Self::TimestampMillis(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY))
156 }
157 Codec::TimestampMicros(is_utc) => {
158 Self::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY))
159 }
160 Codec::Fixed(_) => return nyi("decoding fixed"),
161 Codec::Interval => return nyi("decoding interval"),
162 Codec::List(item) => {
163 let decoder = Self::try_new(item)?;
164 Self::List(
165 Arc::new(item.field_with_name("item")),
166 OffsetBufferBuilder::new(DEFAULT_CAPACITY),
167 Box::new(decoder),
168 )
169 }
170 Codec::Struct(fields) => {
171 let mut arrow_fields = Vec::with_capacity(fields.len());
172 let mut encodings = Vec::with_capacity(fields.len());
173 for avro_field in fields.iter() {
174 let encoding = Self::try_new(avro_field.data_type())?;
175 arrow_fields.push(avro_field.field());
176 encodings.push(encoding);
177 }
178 Self::Record(arrow_fields.into(), encodings)
179 }
180 Codec::Map(child) => {
181 let val_field = child.field_with_name("value").with_nullable(true);
182 let map_field = Arc::new(ArrowField::new(
183 "entries",
184 DataType::Struct(Fields::from(vec![
185 ArrowField::new("key", DataType::Utf8, false),
186 val_field,
187 ])),
188 false,
189 ));
190 let val_dec = Self::try_new(child)?;
191 Self::Map(
192 map_field,
193 OffsetBufferBuilder::new(DEFAULT_CAPACITY),
194 OffsetBufferBuilder::new(DEFAULT_CAPACITY),
195 Vec::with_capacity(DEFAULT_CAPACITY),
196 Box::new(val_dec),
197 )
198 }
199 };
200
201 Ok(match data_type.nullability() {
202 Some(nullability) => Self::Nullable(
203 nullability,
204 NullBufferBuilder::new(DEFAULT_CAPACITY),
205 Box::new(decoder),
206 ),
207 None => decoder,
208 })
209 }
210
211 fn append_null(&mut self) {
213 match self {
214 Self::Null(count) => *count += 1,
215 Self::Boolean(b) => b.append(false),
216 Self::Int32(v) | Self::Date32(v) | Self::TimeMillis(v) => v.push(0),
217 Self::Int64(v)
218 | Self::TimeMicros(v)
219 | Self::TimestampMillis(_, v)
220 | Self::TimestampMicros(_, v) => v.push(0),
221 Self::Float32(v) => v.push(0.),
222 Self::Float64(v) => v.push(0.),
223 Self::Binary(offsets, _) | Self::String(offsets, _) | Self::StringView(offsets, _) => {
224 offsets.push_length(0);
225 }
226 Self::List(_, offsets, e) => {
227 offsets.push_length(0);
228 e.append_null();
229 }
230 Self::Record(_, e) => e.iter_mut().for_each(|e| e.append_null()),
231 Self::Map(_, _koff, moff, _, _) => {
232 moff.push_length(0);
233 }
234 Self::Nullable(_, _, _) => unreachable!("Nulls cannot be nested"),
235 }
236 }
237
238 fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> {
240 match self {
241 Self::Null(x) => *x += 1,
242 Self::Boolean(values) => values.append(buf.get_bool()?),
243 Self::Int32(values) | Self::Date32(values) | Self::TimeMillis(values) => {
244 values.push(buf.get_int()?)
245 }
246 Self::Int64(values)
247 | Self::TimeMicros(values)
248 | Self::TimestampMillis(_, values)
249 | Self::TimestampMicros(_, values) => values.push(buf.get_long()?),
250 Self::Float32(values) => values.push(buf.get_float()?),
251 Self::Float64(values) => values.push(buf.get_double()?),
252 Self::Binary(offsets, values)
253 | Self::String(offsets, values)
254 | Self::StringView(offsets, values) => {
255 let data = buf.get_bytes()?;
256 offsets.push_length(data.len());
257 values.extend_from_slice(data);
258 }
259 Self::List(_, _, _) => {
260 return Err(ArrowError::NotYetImplemented(
261 "Decoding ListArray".to_string(),
262 ))
263 }
264 Self::Record(_, encodings) => {
265 for encoding in encodings {
266 encoding.decode(buf)?;
267 }
268 }
269 Self::Map(_, koff, moff, kdata, valdec) => {
270 let newly_added = read_map_blocks(buf, |cur| {
271 let kb = cur.get_bytes()?;
272 koff.push_length(kb.len());
273 kdata.extend_from_slice(kb);
274 valdec.decode(cur)
275 })?;
276 moff.push_length(newly_added);
277 }
278 Self::Nullable(nullability, nulls, e) => {
279 let is_valid = buf.get_bool()? == matches!(nullability, Nullability::NullFirst);
280 nulls.append(is_valid);
281 match is_valid {
282 true => e.decode(buf)?,
283 false => e.append_null(),
284 }
285 }
286 }
287 Ok(())
288 }
289
290 fn flush(&mut self, nulls: Option<NullBuffer>) -> Result<ArrayRef, ArrowError> {
292 Ok(match self {
293 Self::Nullable(_, n, e) => e.flush(n.finish())?,
294 Self::Null(size) => Arc::new(NullArray::new(std::mem::replace(size, 0))),
295 Self::Boolean(b) => Arc::new(BooleanArray::new(b.finish(), nulls)),
296 Self::Int32(values) => Arc::new(flush_primitive::<Int32Type>(values, nulls)),
297 Self::Date32(values) => Arc::new(flush_primitive::<Date32Type>(values, nulls)),
298 Self::Int64(values) => Arc::new(flush_primitive::<Int64Type>(values, nulls)),
299 Self::TimeMillis(values) => {
300 Arc::new(flush_primitive::<Time32MillisecondType>(values, nulls))
301 }
302 Self::TimeMicros(values) => {
303 Arc::new(flush_primitive::<Time64MicrosecondType>(values, nulls))
304 }
305 Self::TimestampMillis(is_utc, values) => Arc::new(
306 flush_primitive::<TimestampMillisecondType>(values, nulls)
307 .with_timezone_opt(is_utc.then(|| "+00:00")),
308 ),
309 Self::TimestampMicros(is_utc, values) => Arc::new(
310 flush_primitive::<TimestampMicrosecondType>(values, nulls)
311 .with_timezone_opt(is_utc.then(|| "+00:00")),
312 ),
313 Self::Float32(values) => Arc::new(flush_primitive::<Float32Type>(values, nulls)),
314 Self::Float64(values) => Arc::new(flush_primitive::<Float64Type>(values, nulls)),
315 Self::Binary(offsets, values) => {
316 let offsets = flush_offsets(offsets);
317 let values = flush_values(values).into();
318 Arc::new(BinaryArray::new(offsets, values, nulls))
319 }
320 Self::String(offsets, values) => {
321 let offsets = flush_offsets(offsets);
322 let values = flush_values(values).into();
323 Arc::new(StringArray::new(offsets, values, nulls))
324 }
325 Self::StringView(offsets, values) => {
326 let offsets = flush_offsets(offsets);
327 let values = flush_values(values);
328 let array = StringArray::new(offsets, values.into(), nulls.clone());
329
330 let values: Vec<&str> = (0..array.len())
331 .map(|i| {
332 if array.is_valid(i) {
333 array.value(i)
334 } else {
335 ""
336 }
337 })
338 .collect();
339
340 Arc::new(StringViewArray::from(values))
341 }
342 Self::List(field, offsets, values) => {
343 let values = values.flush(None)?;
344 let offsets = flush_offsets(offsets);
345 Arc::new(ListArray::new(field.clone(), offsets, values, nulls))
346 }
347 Self::Record(fields, encodings) => {
348 let arrays = encodings
349 .iter_mut()
350 .map(|x| x.flush(None))
351 .collect::<Result<Vec<_>, _>>()?;
352 Arc::new(StructArray::new(fields.clone(), arrays, nulls))
353 }
354 Self::Map(map_field, k_off, m_off, kdata, valdec) => {
355 let moff = flush_offsets(m_off);
356 let koff = flush_offsets(k_off);
357 let kd = flush_values(kdata).into();
358 let val_arr = valdec.flush(None)?;
359 let key_arr = StringArray::new(koff, kd, None);
360 if key_arr.len() != val_arr.len() {
361 return Err(ArrowError::InvalidArgumentError(format!(
362 "Map keys length ({}) != map values length ({})",
363 key_arr.len(),
364 val_arr.len()
365 )));
366 }
367 let final_len = moff.len() - 1;
368 if let Some(n) = &nulls {
369 if n.len() != final_len {
370 return Err(ArrowError::InvalidArgumentError(format!(
371 "Map array null buffer length {} != final map length {final_len}",
372 n.len()
373 )));
374 }
375 }
376 let entries_struct = StructArray::new(
377 Fields::from(vec![
378 Arc::new(ArrowField::new("key", DataType::Utf8, false)),
379 Arc::new(ArrowField::new("value", val_arr.data_type().clone(), true)),
380 ]),
381 vec![Arc::new(key_arr), val_arr],
382 None,
383 );
384 let map_arr = MapArray::new(map_field.clone(), moff, entries_struct, nulls, false);
385 Arc::new(map_arr)
386 }
387 })
388 }
389}
390
391fn read_map_blocks(
392 buf: &mut AvroCursor,
393 decode_entry: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>,
394) -> Result<usize, ArrowError> {
395 read_blockwise_items(buf, true, decode_entry)
396}
397
398fn read_blockwise_items(
399 buf: &mut AvroCursor,
400 read_size_after_negative: bool,
401 mut decode_fn: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>,
402) -> Result<usize, ArrowError> {
403 let mut total = 0usize;
404 loop {
405 let block_count = buf.get_long()?;
410 match block_count.cmp(&0) {
411 Ordering::Equal => break,
412 Ordering::Less => {
413 let count = (-block_count) as usize;
416 if read_size_after_negative {
417 let _size_in_bytes = buf.get_long()?;
418 }
419 for _ in 0..count {
420 decode_fn(buf)?;
421 }
422 total += count;
423 }
424 Ordering::Greater => {
425 let count = block_count as usize;
427 for _i in 0..count {
428 decode_fn(buf)?;
429 }
430 total += count;
431 }
432 }
433 }
434 Ok(total)
435}
436
437#[inline]
438fn flush_values<T>(values: &mut Vec<T>) -> Vec<T> {
439 std::mem::replace(values, Vec::with_capacity(DEFAULT_CAPACITY))
440}
441
442#[inline]
443fn flush_offsets(offsets: &mut OffsetBufferBuilder<i32>) -> OffsetBuffer<i32> {
444 std::mem::replace(offsets, OffsetBufferBuilder::new(DEFAULT_CAPACITY)).finish()
445}
446
447#[inline]
448fn flush_primitive<T: ArrowPrimitiveType>(
449 values: &mut Vec<T::Native>,
450 nulls: Option<NullBuffer>,
451) -> PrimitiveArray<T> {
452 PrimitiveArray::new(flush_values(values).into(), nulls)
453}
454
/// Initial capacity, in elements, used for every decoder buffer and builder.
const DEFAULT_CAPACITY: usize = 1 << 10;
456
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Array, MapArray, StringArray, StructArray};

    /// Encodes `value` as an Avro `long` (zig-zag followed by a base-128 varint).
    fn encode_avro_long(value: i64) -> Vec<u8> {
        let mut buf = Vec::new();
        // Work on the zig-zagged value as `u64` so `>>=` is a logical shift. With `i64`,
        // any input whose zig-zag form sets the top bit (e.g. `i64::MIN`) never terminates.
        let mut v = ((value << 1) ^ (value >> 63)) as u64;
        while v & !0x7F != 0 {
            buf.push(((v & 0x7F) | 0x80) as u8);
            v >>= 7;
        }
        buf.push(v as u8);
        buf
    }

    /// Encodes `bytes` as Avro `bytes`: a length `long` followed by the raw data.
    fn encode_avro_bytes(bytes: &[u8]) -> Vec<u8> {
        let mut buf = encode_avro_long(bytes.len() as i64);
        buf.extend_from_slice(bytes);
        buf
    }

    /// Wraps `codec` in a non-nullable `AvroDataType` with default metadata.
    fn avro_from_codec(codec: Codec) -> AvroDataType {
        AvroDataType::new(codec, Default::default(), None)
    }

    /// Builds a decoder for `map<string>`.
    fn string_map_decoder() -> Decoder {
        let value_type = avro_from_codec(Codec::Utf8);
        let map_type = avro_from_codec(Codec::Map(Arc::new(value_type)));
        Decoder::try_new(&map_type).unwrap()
    }

    #[test]
    fn test_encode_avro_long() {
        assert_eq!(encode_avro_long(0), vec![0x00]);
        assert_eq!(encode_avro_long(-1), vec![0x01]);
        assert_eq!(encode_avro_long(1), vec![0x02]);
        assert_eq!(encode_avro_long(64), vec![0x80, 0x01]);
        assert_eq!(encode_avro_long(i64::MIN).len(), 10);
        assert_eq!(encode_avro_long(i64::MAX).len(), 10);
    }

    #[test]
    fn test_map_decoding_one_entry() {
        let mut decoder = string_map_decoder();
        let mut data = Vec::new();
        data.extend_from_slice(&encode_avro_long(1));
        data.extend_from_slice(&encode_avro_bytes(b"hello"));
        data.extend_from_slice(&encode_avro_bytes(b"world"));
        data.extend_from_slice(&encode_avro_long(0));
        let mut cursor = AvroCursor::new(&data);
        decoder.decode(&mut cursor).unwrap();
        let array = decoder.flush(None).unwrap();
        let map_arr = array.as_any().downcast_ref::<MapArray>().unwrap();
        assert_eq!(map_arr.len(), 1);
        assert_eq!(map_arr.value_length(0), 1);
        let entries = map_arr.value(0);
        let struct_entries = entries.as_any().downcast_ref::<StructArray>().unwrap();
        assert_eq!(struct_entries.len(), 1);
        let key_arr = struct_entries
            .column_by_name("key")
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let val_arr = struct_entries
            .column_by_name("value")
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(key_arr.value(0), "hello");
        assert_eq!(val_arr.value(0), "world");
    }

    #[test]
    fn test_map_decoding_empty() {
        let mut decoder = string_map_decoder();
        let data = encode_avro_long(0);
        decoder.decode(&mut AvroCursor::new(&data)).unwrap();
        let array = decoder.flush(None).unwrap();
        let map_arr = array.as_any().downcast_ref::<MapArray>().unwrap();
        assert_eq!(map_arr.len(), 1);
        assert_eq!(map_arr.value_length(0), 0);
    }

    #[test]
    fn test_map_decoding_multiple_blocks() {
        let mut decoder = string_map_decoder();
        let mut data = Vec::new();
        data.extend_from_slice(&encode_avro_long(1));
        data.extend_from_slice(&encode_avro_bytes(b"a"));
        data.extend_from_slice(&encode_avro_bytes(b"1"));
        data.extend_from_slice(&encode_avro_long(1));
        data.extend_from_slice(&encode_avro_bytes(b"b"));
        data.extend_from_slice(&encode_avro_bytes(b"2"));
        data.extend_from_slice(&encode_avro_long(0));
        let mut cursor = AvroCursor::new(&data);
        decoder.decode(&mut cursor).unwrap();
        assert_eq!(cursor.position(), data.len());
        let array = decoder.flush(None).unwrap();
        let map_arr = array.as_any().downcast_ref::<MapArray>().unwrap();
        assert_eq!(map_arr.len(), 1);
        assert_eq!(map_arr.value_length(0), 2);
        let keys = map_arr
            .keys()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(keys.value(0), "a");
        assert_eq!(keys.value(1), "b");
    }

    #[test]
    fn test_map_decoding_negative_block_count() {
        let mut decoder = string_map_decoder();
        let mut entry = Vec::new();
        entry.extend_from_slice(&encode_avro_bytes(b"k"));
        entry.extend_from_slice(&encode_avro_bytes(b"v"));
        // A negative count is followed by the block's size in bytes.
        let mut data = encode_avro_long(-1);
        data.extend_from_slice(&encode_avro_long(entry.len() as i64));
        data.extend_from_slice(&entry);
        data.extend_from_slice(&encode_avro_long(0));
        let mut cursor = AvroCursor::new(&data);
        decoder.decode(&mut cursor).unwrap();
        assert_eq!(cursor.position(), data.len());
        let array = decoder.flush(None).unwrap();
        let map_arr = array.as_any().downcast_ref::<MapArray>().unwrap();
        assert_eq!(map_arr.len(), 1);
        assert_eq!(map_arr.value_length(0), 1);
    }
}