arrow_avro/reader/
header.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Decoder for [`Header`]
19
20use crate::compression::{CODEC_METADATA_KEY, CompressionCodec};
21use crate::reader::vlq::VLQDecoder;
22use crate::schema::{SCHEMA_METADATA_KEY, Schema};
23use arrow_schema::ArrowError;
24use std::io::BufRead;
25
26/// Read the Avro file header (magic, metadata, sync marker) from `reader`.
27pub(crate) fn read_header<R: BufRead>(mut reader: R) -> Result<Header, ArrowError> {
28    let mut decoder = HeaderDecoder::default();
29    loop {
30        let buf = reader.fill_buf()?;
31        if buf.is_empty() {
32            break;
33        }
34        let read = buf.len();
35        let decoded = decoder.decode(buf)?;
36        reader.consume(decoded);
37        if decoded != read {
38            break;
39        }
40    }
41    decoder.flush().ok_or_else(|| {
42        ArrowError::ParseError("Unexpected EOF while reading Avro header".to_string())
43    })
44}
45
/// The stages [`HeaderDecoder`] moves through while parsing a header
#[derive(Debug)]
enum HeaderDecoderState {
    /// Matching the leading [`MAGIC`] bytes
    Magic,
    /// Reading the entry count of the next metadata block
    BlockCount,
    /// Reading the block byte length that follows a negative block count
    BlockLen,
    /// Reading the length of a metadata key
    KeyLen,
    /// Reading the bytes of a metadata key
    Key,
    /// Reading the length of a metadata value
    ValueLen,
    /// Reading the bytes of a metadata value
    Value,
    /// Reading the sync marker
    Sync,
    /// The complete header has been decoded
    Finished,
}
67
/// The decoded header of an Avro
/// [Object Container File](https://avro.apache.org/docs/1.11.1/specification/#object-container-files)
///
/// Holds the file metadata entries and the sync marker.
#[derive(Debug, Clone)]
pub struct Header {
    /// End offsets into `meta_buf`, two per metadata entry: the end of the key
    /// followed by the end of the value
    meta_offsets: Vec<usize>,
    /// The key and value bytes of every metadata entry, stored back to back
    meta_buf: Vec<u8>,
    /// The 16-byte sync marker
    sync: [u8; 16],
}
75
76impl Header {
77    /// Returns an iterator over the meta keys in this header
78    pub fn metadata(&self) -> impl Iterator<Item = (&[u8], &[u8])> {
79        let mut last = 0;
80        self.meta_offsets.chunks_exact(2).map(move |w| {
81            let start = last;
82            last = w[1];
83            (&self.meta_buf[start..w[0]], &self.meta_buf[w[0]..w[1]])
84        })
85    }
86
87    /// Returns the value for a given metadata key if present
88    pub fn get(&self, key: impl AsRef<[u8]>) -> Option<&[u8]> {
89        self.metadata()
90            .find_map(|(k, v)| (k == key.as_ref()).then_some(v))
91    }
92
93    /// Returns the sync token for this file
94    pub fn sync(&self) -> [u8; 16] {
95        self.sync
96    }
97
98    /// Returns the [`CompressionCodec`] if any
99    pub fn compression(&self) -> Result<Option<CompressionCodec>, ArrowError> {
100        let v = self.get(CODEC_METADATA_KEY);
101        match v {
102            None | Some(b"null") => Ok(None),
103            Some(b"deflate") => Ok(Some(CompressionCodec::Deflate)),
104            Some(b"snappy") => Ok(Some(CompressionCodec::Snappy)),
105            Some(b"zstandard") => Ok(Some(CompressionCodec::ZStandard)),
106            Some(b"bzip2") => Ok(Some(CompressionCodec::Bzip2)),
107            Some(b"xz") => Ok(Some(CompressionCodec::Xz)),
108            Some(v) => Err(ArrowError::ParseError(format!(
109                "Unrecognized compression codec \'{}\'",
110                String::from_utf8_lossy(v)
111            ))),
112        }
113    }
114
115    /// Returns the `Schema` if any
116    pub(crate) fn schema(&self) -> Result<Option<Schema<'_>>, ArrowError> {
117        self.get(SCHEMA_METADATA_KEY)
118            .map(|x| {
119                serde_json::from_slice(x).map_err(|e| {
120                    ArrowError::ParseError(format!("Failed to parse Avro schema JSON: {e}"))
121                })
122            })
123            .transpose()
124    }
125}
126
127/// A decoder for [`Header`]
128///
129/// The avro file format does not encode the length of the header, and so it
130/// is necessary to provide a push-based decoder that can be used with streams
131#[derive(Debug)]
132pub struct HeaderDecoder {
133    state: HeaderDecoderState,
134    vlq_decoder: VLQDecoder,
135
136    /// The end offsets of strings in `meta_buf`
137    meta_offsets: Vec<usize>,
138    /// The raw binary data of the metadata map
139    meta_buf: Vec<u8>,
140
141    /// The decoded sync marker
142    sync_marker: [u8; 16],
143
144    /// The number of remaining tuples in the current block
145    tuples_remaining: usize,
146    /// The number of bytes remaining in the current string/bytes payload
147    bytes_remaining: usize,
148}
149
150impl Default for HeaderDecoder {
151    fn default() -> Self {
152        Self {
153            state: HeaderDecoderState::Magic,
154            meta_offsets: vec![],
155            meta_buf: vec![],
156            sync_marker: [0; 16],
157            vlq_decoder: Default::default(),
158            tuples_remaining: 0,
159            bytes_remaining: MAGIC.len(),
160        }
161    }
162}
163
/// The four bytes (`Obj` followed by `0x01`) that the decoder requires at the start
/// of the input
const MAGIC: &[u8; 4] = b"Obj\x01";
165
166impl HeaderDecoder {
167    /// Parse [`Header`] from `buf`, returning the number of bytes read
168    ///
169    /// This method can be called multiple times with consecutive chunks of data, allowing
170    /// integration with chunked IO systems like [`BufRead::fill_buf`]
171    ///
172    /// All errors should be considered fatal, and decoding aborted
173    ///
174    /// Once the entire [`Header`] has been decoded this method will not read any further
175    /// input bytes, and the header can be obtained with [`Self::flush`]
176    ///
177    /// [`BufRead::fill_buf`]: std::io::BufRead::fill_buf
178    pub fn decode(&mut self, mut buf: &[u8]) -> Result<usize, ArrowError> {
179        let max_read = buf.len();
180        while !buf.is_empty() {
181            match self.state {
182                HeaderDecoderState::Magic => {
183                    let remaining = &MAGIC[MAGIC.len() - self.bytes_remaining..];
184                    let to_decode = buf.len().min(remaining.len());
185                    if !buf.starts_with(&remaining[..to_decode]) {
186                        return Err(ArrowError::ParseError("Incorrect avro magic".to_string()));
187                    }
188                    self.bytes_remaining -= to_decode;
189                    buf = &buf[to_decode..];
190                    if self.bytes_remaining == 0 {
191                        self.state = HeaderDecoderState::BlockCount;
192                    }
193                }
194                HeaderDecoderState::BlockCount => {
195                    if let Some(block_count) = self.vlq_decoder.long(&mut buf) {
196                        match block_count.try_into() {
197                            Ok(0) => {
198                                self.state = HeaderDecoderState::Sync;
199                                self.bytes_remaining = 16;
200                            }
201                            Ok(remaining) => {
202                                self.tuples_remaining = remaining;
203                                self.state = HeaderDecoderState::KeyLen;
204                            }
205                            Err(_) => {
206                                self.tuples_remaining = block_count.unsigned_abs() as _;
207                                self.state = HeaderDecoderState::BlockLen;
208                            }
209                        }
210                    }
211                }
212                HeaderDecoderState::BlockLen => {
213                    if self.vlq_decoder.long(&mut buf).is_some() {
214                        self.state = HeaderDecoderState::KeyLen
215                    }
216                }
217                HeaderDecoderState::Key => {
218                    let to_read = self.bytes_remaining.min(buf.len());
219                    self.meta_buf.extend_from_slice(&buf[..to_read]);
220                    self.bytes_remaining -= to_read;
221                    buf = &buf[to_read..];
222                    if self.bytes_remaining == 0 {
223                        self.meta_offsets.push(self.meta_buf.len());
224                        self.state = HeaderDecoderState::ValueLen;
225                    }
226                }
227                HeaderDecoderState::Value => {
228                    let to_read = self.bytes_remaining.min(buf.len());
229                    self.meta_buf.extend_from_slice(&buf[..to_read]);
230                    self.bytes_remaining -= to_read;
231                    buf = &buf[to_read..];
232                    if self.bytes_remaining == 0 {
233                        self.meta_offsets.push(self.meta_buf.len());
234
235                        self.tuples_remaining -= 1;
236                        match self.tuples_remaining {
237                            0 => self.state = HeaderDecoderState::BlockCount,
238                            _ => self.state = HeaderDecoderState::KeyLen,
239                        }
240                    }
241                }
242                HeaderDecoderState::KeyLen => {
243                    if let Some(len) = self.vlq_decoder.long(&mut buf) {
244                        self.bytes_remaining = len as _;
245                        self.state = HeaderDecoderState::Key;
246                    }
247                }
248                HeaderDecoderState::ValueLen => {
249                    if let Some(len) = self.vlq_decoder.long(&mut buf) {
250                        self.bytes_remaining = len as _;
251                        self.state = HeaderDecoderState::Value;
252                    }
253                }
254                HeaderDecoderState::Sync => {
255                    let to_decode = buf.len().min(self.bytes_remaining);
256                    let write = &mut self.sync_marker[16 - to_decode..];
257                    write[..to_decode].copy_from_slice(&buf[..to_decode]);
258                    self.bytes_remaining -= to_decode;
259                    buf = &buf[to_decode..];
260                    if self.bytes_remaining == 0 {
261                        self.state = HeaderDecoderState::Finished;
262                    }
263                }
264                HeaderDecoderState::Finished => return Ok(max_read - buf.len()),
265            }
266        }
267        Ok(max_read)
268    }
269
270    /// Flush this decoder returning the parsed [`Header`] if any
271    pub fn flush(&mut self) -> Option<Header> {
272        match self.state {
273            HeaderDecoderState::Finished => {
274                self.state = HeaderDecoderState::Magic;
275                Some(Header {
276                    meta_offsets: std::mem::take(&mut self.meta_offsets),
277                    meta_buf: std::mem::take(&mut self.meta_buf),
278                    sync: self.sync_marker,
279                })
280            }
281            _ => None,
282        }
283    }
284}
285
#[cfg(test)]
mod test {
    use super::*;
    use crate::codec::AvroField;
    use crate::reader::read_header;
    use crate::schema::{
        AVRO_NAME_METADATA_KEY, AVRO_ROOT_RECORD_DEFAULT_NAME, SCHEMA_METADATA_KEY,
    };
    use crate::test_util::arrow_test_data;
    use arrow_schema::{DataType, Field, Fields, TimeUnit};
    use std::collections::HashMap;
    use std::fs::File;
    use std::io::BufReader;

    /// The magic is accepted byte by byte or in one piece, and a wrong byte is rejected
    #[test]
    fn test_header_decode() {
        let mut byte_wise = HeaderDecoder::default();
        for byte in MAGIC.iter() {
            byte_wise.decode(std::slice::from_ref(byte)).unwrap();
        }

        let mut whole = HeaderDecoder::default();
        let consumed = whole.decode(MAGIC).unwrap();
        assert_eq!(consumed, 4);

        let mut corrupted = HeaderDecoder::default();
        corrupted.decode(b"Ob").unwrap();
        let message = corrupted.decode(b"s").unwrap_err().to_string();
        assert_eq!(message, "Parser error: Incorrect avro magic");
    }

    /// Reads the header of the Avro file at path `file` through a 1000-byte buffer
    fn decode_file(file: &str) -> Header {
        let reader = BufReader::with_capacity(1000, File::open(file).unwrap());
        read_header(reader).unwrap()
    }

    /// Checks schema, metadata keys and sync marker of two reference files
    #[test]
    fn test_header() {
        let plain = decode_file(&arrow_test_data("avro/alltypes_plain.avro"));
        let schema_json = plain.get(SCHEMA_METADATA_KEY).unwrap();
        let expected = br#"{"type":"record","name":"topLevelRecord","fields":[{"name":"id","type":["int","null"]},{"name":"bool_col","type":["boolean","null"]},{"name":"tinyint_col","type":["int","null"]},{"name":"smallint_col","type":["int","null"]},{"name":"int_col","type":["int","null"]},{"name":"bigint_col","type":["long","null"]},{"name":"float_col","type":["float","null"]},{"name":"double_col","type":["double","null"]},{"name":"date_string_col","type":["bytes","null"]},{"name":"string_col","type":["bytes","null"]},{"name":"timestamp_col","type":[{"type":"long","logicalType":"timestamp-micros"},"null"]}]}"#;
        assert_eq!(schema_json, expected);
        let schema: Schema<'_> = serde_json::from_slice(schema_json).unwrap();
        let field = AvroField::try_from(&schema).unwrap();

        let expected_columns = Fields::from(vec![
            Field::new("id", DataType::Int32, true),
            Field::new("bool_col", DataType::Boolean, true),
            Field::new("tinyint_col", DataType::Int32, true),
            Field::new("smallint_col", DataType::Int32, true),
            Field::new("int_col", DataType::Int32, true),
            Field::new("bigint_col", DataType::Int64, true),
            Field::new("float_col", DataType::Float32, true),
            Field::new("double_col", DataType::Float64, true),
            Field::new("date_string_col", DataType::Binary, true),
            Field::new("string_col", DataType::Binary, true),
            Field::new(
                "timestamp_col",
                DataType::Timestamp(TimeUnit::Microsecond, Some("+00:00".into())),
                true,
            ),
        ]);
        let expected_metadata = HashMap::from([(
            AVRO_NAME_METADATA_KEY.to_string(),
            AVRO_ROOT_RECORD_DEFAULT_NAME.to_string(),
        )]);

        assert_eq!(
            field.field(),
            Field::new("topLevelRecord", DataType::Struct(expected_columns), false)
                .with_metadata(expected_metadata)
        );

        assert_eq!(
            u128::from_le_bytes(plain.sync()),
            226966037233754408753420635932530907102
        );

        let decimal = decode_file(&arrow_test_data("avro/fixed_length_decimal.avro"));

        let meta: Vec<_> = decimal
            .metadata()
            .map(|(key, _)| std::str::from_utf8(key).unwrap())
            .collect();

        assert_eq!(
            meta,
            &["avro.schema", "org.apache.spark.version", "avro.codec"]
        );

        let schema_json = decimal.get(SCHEMA_METADATA_KEY).unwrap();
        let expected = br#"{"type":"record","name":"topLevelRecord","fields":[{"name":"value","type":[{"type":"fixed","name":"fixed","namespace":"topLevelRecord.value","size":11,"logicalType":"decimal","precision":25,"scale":2},"null"]}]}"#;
        assert_eq!(schema_json, expected);
        let _schema: Schema<'_> = serde_json::from_slice(schema_json).unwrap();
        assert_eq!(
            u128::from_le_bytes(decimal.sync()),
            325166208089902833952788552656412487328
        );
    }
}