arrow_avro/reader/header.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Decoder for [`Header`]
19
20use crate::compression::{CompressionCodec, CODEC_METADATA_KEY};
21use crate::reader::vlq::VLQDecoder;
22use crate::schema::{Schema, SCHEMA_METADATA_KEY};
23use arrow_schema::ArrowError;
24
/// The states of the push-based [`HeaderDecoder`] state machine.
///
/// Each variant names the part of the header that the next input bytes belong
/// to; see [`HeaderDecoder::decode`] for the transitions between them.
#[derive(Debug)]
enum HeaderDecoderState {
    /// Decoding the [`MAGIC`] prefix
    Magic,
    /// Decoding a block count of the metadata map (zero ends the map)
    BlockCount,
    /// Decoding a block byte length, which follows a negative block count
    BlockLen,
    /// Decoding a key length
    KeyLen,
    /// Decoding a key string
    Key,
    /// Decoding a value length
    ValueLen,
    /// Decoding a value payload
    Value,
    /// Decoding the 16-byte sync marker
    Sync,
    /// Finished decoding; no further input is consumed
    Finished,
}
46
/// A decoded header for an [Object Container File](https://avro.apache.org/docs/1.11.1/specification/#object-container-files)
///
/// The metadata map is stored flattened: every key and value is appended to
/// `meta_buf` in file order, and `meta_offsets` records the end offset of each
/// string, alternating between key end and value end.
#[derive(Debug, Clone)]
pub struct Header {
    /// End offsets into `meta_buf`, as (key end, value end) pairs
    meta_offsets: Vec<usize>,
    /// Concatenated raw bytes of all metadata keys and values
    meta_buf: Vec<u8>,
    /// The 16-byte sync marker decoded after the metadata map
    sync: [u8; 16],
}
54
55impl Header {
56    /// Returns an iterator over the meta keys in this header
57    pub fn metadata(&self) -> impl Iterator<Item = (&[u8], &[u8])> {
58        let mut last = 0;
59        self.meta_offsets.chunks_exact(2).map(move |w| {
60            let start = last;
61            last = w[1];
62            (&self.meta_buf[start..w[0]], &self.meta_buf[w[0]..w[1]])
63        })
64    }
65
66    /// Returns the value for a given metadata key if present
67    pub fn get(&self, key: impl AsRef<[u8]>) -> Option<&[u8]> {
68        self.metadata()
69            .find_map(|(k, v)| (k == key.as_ref()).then_some(v))
70    }
71
72    /// Returns the sync token for this file
73    pub fn sync(&self) -> [u8; 16] {
74        self.sync
75    }
76
77    /// Returns the [`CompressionCodec`] if any
78    pub fn compression(&self) -> Result<Option<CompressionCodec>, ArrowError> {
79        let v = self.get(CODEC_METADATA_KEY);
80        match v {
81            None | Some(b"null") => Ok(None),
82            Some(b"deflate") => Ok(Some(CompressionCodec::Deflate)),
83            Some(b"snappy") => Ok(Some(CompressionCodec::Snappy)),
84            Some(b"zstandard") => Ok(Some(CompressionCodec::ZStandard)),
85            Some(b"bzip2") => Ok(Some(CompressionCodec::Bzip2)),
86            Some(b"xz") => Ok(Some(CompressionCodec::Xz)),
87            Some(v) => Err(ArrowError::ParseError(format!(
88                "Unrecognized compression codec \'{}\'",
89                String::from_utf8_lossy(v)
90            ))),
91        }
92    }
93
94    /// Returns the [`Schema`] if any
95    pub(crate) fn schema(&self) -> Result<Option<Schema<'_>>, ArrowError> {
96        self.get(SCHEMA_METADATA_KEY)
97            .map(|x| {
98                serde_json::from_slice(x).map_err(|e| {
99                    ArrowError::ParseError(format!("Failed to parse Avro schema JSON: {e}"))
100                })
101            })
102            .transpose()
103    }
104}
105
/// A decoder for [`Header`]
///
/// The avro file format does not encode the length of the header, and so it
/// is necessary to provide a push-based decoder that can be used with streams
///
/// Input is fed with [`HeaderDecoder::decode`], and the decoded header is
/// retrieved with [`HeaderDecoder::flush`] once decoding has finished.
#[derive(Debug)]
pub struct HeaderDecoder {
    /// The current state of the decoding state machine
    state: HeaderDecoderState,
    /// Decoder for the variable-length `long` values (block counts, block sizes
    /// and string lengths) read while decoding the metadata map
    vlq_decoder: VLQDecoder,

    /// The end offsets of strings in `meta_buf`, alternating key end / value end
    meta_offsets: Vec<usize>,
    /// The raw binary data of the metadata map
    meta_buf: Vec<u8>,

    /// The decoded sync marker
    sync_marker: [u8; 16],

    /// The number of remaining tuples in the current block
    tuples_remaining: usize,
    /// The number of bytes remaining in the current string/bytes payload; also
    /// used as the countdown while reading the magic prefix and the sync marker
    bytes_remaining: usize,
}
128
129impl Default for HeaderDecoder {
130    fn default() -> Self {
131        Self {
132            state: HeaderDecoderState::Magic,
133            meta_offsets: vec![],
134            meta_buf: vec![],
135            sync_marker: [0; 16],
136            vlq_decoder: Default::default(),
137            tuples_remaining: 0,
138            bytes_remaining: MAGIC.len(),
139        }
140    }
141}
142
/// The four bytes (`O`, `b`, `j`, `0x01`) that must open an Avro object container file
const MAGIC: &[u8; 4] = &[b'O', b'b', b'j', 1];
144
145impl HeaderDecoder {
146    /// Parse [`Header`] from `buf`, returning the number of bytes read
147    ///
148    /// This method can be called multiple times with consecutive chunks of data, allowing
149    /// integration with chunked IO systems like [`BufRead::fill_buf`]
150    ///
151    /// All errors should be considered fatal, and decoding aborted
152    ///
153    /// Once the entire [`Header`] has been decoded this method will not read any further
154    /// input bytes, and the header can be obtained with [`Self::flush`]
155    ///
156    /// [`BufRead::fill_buf`]: std::io::BufRead::fill_buf
157    pub fn decode(&mut self, mut buf: &[u8]) -> Result<usize, ArrowError> {
158        let max_read = buf.len();
159        while !buf.is_empty() {
160            match self.state {
161                HeaderDecoderState::Magic => {
162                    let remaining = &MAGIC[MAGIC.len() - self.bytes_remaining..];
163                    let to_decode = buf.len().min(remaining.len());
164                    if !buf.starts_with(&remaining[..to_decode]) {
165                        return Err(ArrowError::ParseError("Incorrect avro magic".to_string()));
166                    }
167                    self.bytes_remaining -= to_decode;
168                    buf = &buf[to_decode..];
169                    if self.bytes_remaining == 0 {
170                        self.state = HeaderDecoderState::BlockCount;
171                    }
172                }
173                HeaderDecoderState::BlockCount => {
174                    if let Some(block_count) = self.vlq_decoder.long(&mut buf) {
175                        match block_count.try_into() {
176                            Ok(0) => {
177                                self.state = HeaderDecoderState::Sync;
178                                self.bytes_remaining = 16;
179                            }
180                            Ok(remaining) => {
181                                self.tuples_remaining = remaining;
182                                self.state = HeaderDecoderState::KeyLen;
183                            }
184                            Err(_) => {
185                                self.tuples_remaining = block_count.unsigned_abs() as _;
186                                self.state = HeaderDecoderState::BlockLen;
187                            }
188                        }
189                    }
190                }
191                HeaderDecoderState::BlockLen => {
192                    if self.vlq_decoder.long(&mut buf).is_some() {
193                        self.state = HeaderDecoderState::KeyLen
194                    }
195                }
196                HeaderDecoderState::Key => {
197                    let to_read = self.bytes_remaining.min(buf.len());
198                    self.meta_buf.extend_from_slice(&buf[..to_read]);
199                    self.bytes_remaining -= to_read;
200                    buf = &buf[to_read..];
201                    if self.bytes_remaining == 0 {
202                        self.meta_offsets.push(self.meta_buf.len());
203                        self.state = HeaderDecoderState::ValueLen;
204                    }
205                }
206                HeaderDecoderState::Value => {
207                    let to_read = self.bytes_remaining.min(buf.len());
208                    self.meta_buf.extend_from_slice(&buf[..to_read]);
209                    self.bytes_remaining -= to_read;
210                    buf = &buf[to_read..];
211                    if self.bytes_remaining == 0 {
212                        self.meta_offsets.push(self.meta_buf.len());
213
214                        self.tuples_remaining -= 1;
215                        match self.tuples_remaining {
216                            0 => self.state = HeaderDecoderState::BlockCount,
217                            _ => self.state = HeaderDecoderState::KeyLen,
218                        }
219                    }
220                }
221                HeaderDecoderState::KeyLen => {
222                    if let Some(len) = self.vlq_decoder.long(&mut buf) {
223                        self.bytes_remaining = len as _;
224                        self.state = HeaderDecoderState::Key;
225                    }
226                }
227                HeaderDecoderState::ValueLen => {
228                    if let Some(len) = self.vlq_decoder.long(&mut buf) {
229                        self.bytes_remaining = len as _;
230                        self.state = HeaderDecoderState::Value;
231                    }
232                }
233                HeaderDecoderState::Sync => {
234                    let to_decode = buf.len().min(self.bytes_remaining);
235                    let write = &mut self.sync_marker[16 - to_decode..];
236                    write[..to_decode].copy_from_slice(&buf[..to_decode]);
237                    self.bytes_remaining -= to_decode;
238                    buf = &buf[to_decode..];
239                    if self.bytes_remaining == 0 {
240                        self.state = HeaderDecoderState::Finished;
241                    }
242                }
243                HeaderDecoderState::Finished => return Ok(max_read - buf.len()),
244            }
245        }
246        Ok(max_read)
247    }
248
249    /// Flush this decoder returning the parsed [`Header`] if any
250    pub fn flush(&mut self) -> Option<Header> {
251        match self.state {
252            HeaderDecoderState::Finished => {
253                self.state = HeaderDecoderState::Magic;
254                Some(Header {
255                    meta_offsets: std::mem::take(&mut self.meta_offsets),
256                    meta_buf: std::mem::take(&mut self.meta_buf),
257                    sync: self.sync_marker,
258                })
259            }
260            _ => None,
261        }
262    }
263}
264
#[cfg(test)]
mod test {
    use super::*;
    use crate::codec::{AvroDataType, AvroField};
    use crate::reader::read_header;
    use crate::schema::SCHEMA_METADATA_KEY;
    use crate::test_util::arrow_test_data;
    use arrow_schema::{DataType, Field, Fields, TimeUnit};
    use std::fs::File;
    use std::io::{BufRead, BufReader};

    /// Checks that the magic prefix is accepted whole or split across calls, and
    /// that a mismatching byte after a valid partial prefix is rejected
    #[test]
    fn test_header_decode() {
        // Feeding the magic one byte at a time must succeed
        let mut decoder = HeaderDecoder::default();
        for m in MAGIC {
            decoder.decode(std::slice::from_ref(m)).unwrap();
        }

        // The whole magic in a single call is fully consumed
        let mut decoder = HeaderDecoder::default();
        assert_eq!(decoder.decode(MAGIC).unwrap(), 4);

        // "Ob" is a valid partial prefix; the following "s" is not the expected "j"
        let mut decoder = HeaderDecoder::default();
        decoder.decode(b"Ob").unwrap();
        let err = decoder.decode(b"s").unwrap_err().to_string();
        assert_eq!(err, "Parser error: Incorrect avro magic");
    }

    /// Reads the header of the Avro file at path `file`, panicking on any error
    fn decode_file(file: &str) -> Header {
        let file = File::open(file).unwrap();
        // NOTE(review): the small 100-byte buffer presumably exercises decoding the
        // header across multiple buffer refills — confirm against `read_header`
        read_header(BufReader::with_capacity(100, file)).unwrap()
    }

    /// Decodes headers of the bundled test files and checks the schema JSON,
    /// its Arrow conversion, the metadata key order and the sync markers
    #[test]
    fn test_header() {
        let header = decode_file(&arrow_test_data("avro/alltypes_plain.avro"));
        let schema_json = header.get(SCHEMA_METADATA_KEY).unwrap();
        let expected = br#"{"type":"record","name":"topLevelRecord","fields":[{"name":"id","type":["int","null"]},{"name":"bool_col","type":["boolean","null"]},{"name":"tinyint_col","type":["int","null"]},{"name":"smallint_col","type":["int","null"]},{"name":"int_col","type":["int","null"]},{"name":"bigint_col","type":["long","null"]},{"name":"float_col","type":["float","null"]},{"name":"double_col","type":["double","null"]},{"name":"date_string_col","type":["bytes","null"]},{"name":"string_col","type":["bytes","null"]},{"name":"timestamp_col","type":[{"type":"long","logicalType":"timestamp-micros"},"null"]}]}"#;
        assert_eq!(schema_json, expected);
        let schema: Schema<'_> = serde_json::from_slice(schema_json).unwrap();
        let field = AvroField::try_from(&schema).unwrap();

        // Each nullable `[T, "null"]` union maps to a nullable Arrow field
        assert_eq!(
            field.field(),
            Field::new(
                "topLevelRecord",
                DataType::Struct(Fields::from(vec![
                    Field::new("id", DataType::Int32, true),
                    Field::new("bool_col", DataType::Boolean, true),
                    Field::new("tinyint_col", DataType::Int32, true),
                    Field::new("smallint_col", DataType::Int32, true),
                    Field::new("int_col", DataType::Int32, true),
                    Field::new("bigint_col", DataType::Int64, true),
                    Field::new("float_col", DataType::Float32, true),
                    Field::new("double_col", DataType::Float64, true),
                    Field::new("date_string_col", DataType::Binary, true),
                    Field::new("string_col", DataType::Binary, true),
                    Field::new(
                        "timestamp_col",
                        DataType::Timestamp(TimeUnit::Microsecond, Some("+00:00".into())),
                        true
                    ),
                ])),
                false
            )
        );

        // The sync marker is compared as a little-endian u128 for brevity
        assert_eq!(
            u128::from_le_bytes(header.sync()),
            226966037233754408753420635932530907102
        );

        let header = decode_file(&arrow_test_data("avro/fixed_length_decimal.avro"));

        // Metadata entries are yielded in file order
        let meta: Vec<_> = header
            .metadata()
            .map(|(k, _)| std::str::from_utf8(k).unwrap())
            .collect();

        assert_eq!(
            meta,
            &["avro.schema", "org.apache.spark.version", "avro.codec"]
        );

        let schema_json = header.get(SCHEMA_METADATA_KEY).unwrap();
        let expected = br#"{"type":"record","name":"topLevelRecord","fields":[{"name":"value","type":[{"type":"fixed","name":"fixed","namespace":"topLevelRecord.value","size":11,"logicalType":"decimal","precision":25,"scale":2},"null"]}]}"#;
        assert_eq!(schema_json, expected);
        let _schema: Schema<'_> = serde_json::from_slice(schema_json).unwrap();
        assert_eq!(
            u128::from_le_bytes(header.sync()),
            325166208089902833952788552656412487328
        );
    }
}