Skip to main content

arrow_avro/reader/
header.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Decoder for [`Header`]
19
20use crate::compression::{CODEC_METADATA_KEY, CompressionCodec};
21use crate::errors::AvroError;
22use crate::reader::vlq::VLQDecoder;
23use crate::schema::{AvroSchema, SCHEMA_METADATA_KEY, Schema};
24use std::io::BufRead;
25use std::str;
26use std::sync::Arc;
27
28/// Read the Avro file header (magic, metadata, sync marker) from `reader`.
29///
30/// On success, returns the parsed [`Header`] and the number of bytes read from `reader`.
31pub(crate) fn read_header<R: BufRead>(mut reader: R) -> Result<(Header, u64), AvroError> {
32    let mut decoder = HeaderDecoder::default();
33    let mut position = 0;
34    loop {
35        let buf = reader.fill_buf()?;
36        if buf.is_empty() {
37            break;
38        }
39        let read = buf.len();
40        let decoded = decoder.decode(buf)?;
41        reader.consume(decoded);
42        position += decoded as u64;
43        if decoded != read {
44            break;
45        }
46    }
47    decoder
48        .flush()
49        .map(|header| (header, position))
50        .ok_or_else(|| AvroError::EOF("Unexpected EOF while reading Avro header".to_string()))
51}
52
/// Position of a [`HeaderDecoder`] within the Avro header.
///
/// The header consists of the magic prefix, a metadata map encoded as one or more
/// blocks of key/value pairs, and finally a 16-byte sync marker.
#[derive(Debug)]
enum HeaderDecoderState {
    /// Matching the [`MAGIC`] prefix bytes
    Magic,
    /// Reading the pair count that opens a metadata block (a count of zero ends the map)
    BlockCount,
    /// Reading, and discarding, the byte size that follows a negative block count
    BlockLen,
    /// Reading the length of the next metadata key
    KeyLen,
    /// Copying the bytes of a metadata key
    Key,
    /// Reading the length of the next metadata value
    ValueLen,
    /// Copying the bytes of a metadata value
    Value,
    /// Copying the 16-byte sync marker
    Sync,
    /// The header is complete and no further input is consumed
    Finished,
}
74
75/// A decoded header for an [Object Container File](https://avro.apache.org/docs/1.11.1/specification/#object-container-files)
76#[derive(Debug, Clone)]
77pub struct Header {
78    meta_offsets: Vec<usize>,
79    meta_buf: Vec<u8>,
80    sync: [u8; 16],
81}
82
83impl Header {
84    /// Returns an iterator over the meta keys in this header
85    pub fn metadata(&self) -> impl Iterator<Item = (&[u8], &[u8])> {
86        let mut last = 0;
87        self.meta_offsets.chunks_exact(2).map(move |w| {
88            let start = last;
89            last = w[1];
90            (&self.meta_buf[start..w[0]], &self.meta_buf[w[0]..w[1]])
91        })
92    }
93
94    /// Returns the value for a given metadata key if present
95    pub fn get(&self, key: impl AsRef<[u8]>) -> Option<&[u8]> {
96        self.metadata()
97            .find_map(|(k, v)| (k == key.as_ref()).then_some(v))
98    }
99
100    /// Returns the sync token for this file
101    pub fn sync(&self) -> [u8; 16] {
102        self.sync
103    }
104
105    /// Returns the [`CompressionCodec`] if any
106    pub fn compression(&self) -> Result<Option<CompressionCodec>, AvroError> {
107        let v = self.get(CODEC_METADATA_KEY);
108        match v {
109            None | Some(b"null") => Ok(None),
110            Some(b"deflate") => Ok(Some(CompressionCodec::Deflate)),
111            Some(b"snappy") => Ok(Some(CompressionCodec::Snappy)),
112            Some(b"zstandard") => Ok(Some(CompressionCodec::ZStandard)),
113            Some(b"bzip2") => Ok(Some(CompressionCodec::Bzip2)),
114            Some(b"xz") => Ok(Some(CompressionCodec::Xz)),
115            Some(v) => Err(AvroError::ParseError(format!(
116                "Unrecognized compression codec \'{}\'",
117                String::from_utf8_lossy(v)
118            ))),
119        }
120    }
121
122    /// Returns the `Schema` if any
123    pub(crate) fn schema(&self) -> Result<Option<Schema<'_>>, AvroError> {
124        self.get(SCHEMA_METADATA_KEY)
125            .map(|x| {
126                serde_json::from_slice(x).map_err(|e| {
127                    AvroError::ParseError(format!("Failed to parse Avro schema JSON: {e}"))
128                })
129            })
130            .transpose()
131    }
132}
133
134/// Header information for an Avro OCF file.
135///
136/// The header can be parsed once and shared to construct multiple readers
137/// for the same file, and so this struct is designed to be cheaply clonable.
138#[derive(Clone)]
139pub struct HeaderInfo(Arc<HeaderInfoInner>);
140
141struct HeaderInfoInner {
142    header: Header,
143    header_len: u64,
144}
145
146/// Reads the Avro file header (magic, metadata, sync marker) from `reader`.
147///
148/// On success, returns the parsed [`HeaderInfo`] containing the header and its length in bytes.
149pub fn read_header_info<R: BufRead>(reader: R) -> Result<HeaderInfo, AvroError> {
150    let (header, header_len) = read_header(reader)?;
151    Ok(HeaderInfo::new(header, header_len))
152}
153
154impl HeaderInfo {
155    pub(crate) fn new(header: Header, header_len: u64) -> Self {
156        Self(Arc::new(HeaderInfoInner { header, header_len }))
157    }
158
159    /// Returns the writer schema for this file.
160    pub fn writer_schema(&self) -> Result<AvroSchema, AvroError> {
161        let raw = self.0.header.get(SCHEMA_METADATA_KEY).ok_or_else(|| {
162            AvroError::ParseError("No Avro schema present in file header".to_string())
163        })?;
164        let json_string = str::from_utf8(raw)
165            .map_err(|e| {
166                AvroError::ParseError(format!("Invalid UTF-8 in Avro schema header: {e}"))
167            })?
168            .to_string();
169        Ok(AvroSchema::new(json_string))
170    }
171
172    /// Returns the [`CompressionCodec`] if any
173    pub fn compression(&self) -> Result<Option<CompressionCodec>, AvroError> {
174        self.0.header.compression()
175    }
176
177    /// Returns the length of the header in bytes.
178    pub fn header_len(&self) -> u64 {
179        self.0.header_len
180    }
181
182    /// Returns the sync token for this file.
183    pub fn sync(&self) -> [u8; 16] {
184        self.0.header.sync()
185    }
186}
187
188/// A decoder for [`Header`]
189///
190/// The avro file format does not encode the length of the header, and so it
191/// is necessary to provide a push-based decoder that can be used with streams
192#[derive(Debug)]
193pub struct HeaderDecoder {
194    state: HeaderDecoderState,
195    vlq_decoder: VLQDecoder,
196
197    /// The end offsets of strings in `meta_buf`
198    meta_offsets: Vec<usize>,
199    /// The raw binary data of the metadata map
200    meta_buf: Vec<u8>,
201
202    /// The decoded sync marker
203    sync_marker: [u8; 16],
204
205    /// The number of remaining tuples in the current block
206    tuples_remaining: usize,
207    /// The number of bytes remaining in the current string/bytes payload
208    bytes_remaining: usize,
209}
210
211impl Default for HeaderDecoder {
212    fn default() -> Self {
213        Self {
214            state: HeaderDecoderState::Magic,
215            meta_offsets: vec![],
216            meta_buf: vec![],
217            sync_marker: [0; 16],
218            vlq_decoder: Default::default(),
219            tuples_remaining: 0,
220            bytes_remaining: MAGIC.len(),
221        }
222    }
223}
224
225const MAGIC: &[u8; 4] = b"Obj\x01";
226
227impl HeaderDecoder {
228    /// Parse [`Header`] from `buf`, returning the number of bytes read
229    ///
230    /// This method can be called multiple times with consecutive chunks of data, allowing
231    /// integration with chunked IO systems like [`BufRead::fill_buf`]
232    ///
233    /// All errors should be considered fatal, and decoding aborted
234    ///
235    /// Once the entire [`Header`] has been decoded this method will not read any further
236    /// input bytes, and the header can be obtained with [`Self::flush`]
237    ///
238    /// [`BufRead::fill_buf`]: std::io::BufRead::fill_buf
239    pub fn decode(&mut self, mut buf: &[u8]) -> Result<usize, AvroError> {
240        let max_read = buf.len();
241        while !buf.is_empty() {
242            match self.state {
243                HeaderDecoderState::Magic => {
244                    let remaining = &MAGIC[MAGIC.len() - self.bytes_remaining..];
245                    let to_decode = buf.len().min(remaining.len());
246                    if !buf.starts_with(&remaining[..to_decode]) {
247                        return Err(AvroError::ParseError("Incorrect avro magic".to_string()));
248                    }
249                    self.bytes_remaining -= to_decode;
250                    buf = &buf[to_decode..];
251                    if self.bytes_remaining == 0 {
252                        self.state = HeaderDecoderState::BlockCount;
253                    }
254                }
255                HeaderDecoderState::BlockCount => {
256                    if let Some(block_count) = self.vlq_decoder.long(&mut buf) {
257                        match block_count.try_into() {
258                            Ok(0) => {
259                                self.state = HeaderDecoderState::Sync;
260                                self.bytes_remaining = 16;
261                            }
262                            Ok(remaining) => {
263                                self.tuples_remaining = remaining;
264                                self.state = HeaderDecoderState::KeyLen;
265                            }
266                            Err(_) => {
267                                self.tuples_remaining = block_count.unsigned_abs() as _;
268                                self.state = HeaderDecoderState::BlockLen;
269                            }
270                        }
271                    }
272                }
273                HeaderDecoderState::BlockLen => {
274                    if self.vlq_decoder.long(&mut buf).is_some() {
275                        self.state = HeaderDecoderState::KeyLen
276                    }
277                }
278                HeaderDecoderState::Key => {
279                    let to_read = self.bytes_remaining.min(buf.len());
280                    self.meta_buf.extend_from_slice(&buf[..to_read]);
281                    self.bytes_remaining -= to_read;
282                    buf = &buf[to_read..];
283                    if self.bytes_remaining == 0 {
284                        self.meta_offsets.push(self.meta_buf.len());
285                        self.state = HeaderDecoderState::ValueLen;
286                    }
287                }
288                HeaderDecoderState::Value => {
289                    let to_read = self.bytes_remaining.min(buf.len());
290                    self.meta_buf.extend_from_slice(&buf[..to_read]);
291                    self.bytes_remaining -= to_read;
292                    buf = &buf[to_read..];
293                    if self.bytes_remaining == 0 {
294                        self.meta_offsets.push(self.meta_buf.len());
295
296                        self.tuples_remaining -= 1;
297                        match self.tuples_remaining {
298                            0 => self.state = HeaderDecoderState::BlockCount,
299                            _ => self.state = HeaderDecoderState::KeyLen,
300                        }
301                    }
302                }
303                HeaderDecoderState::KeyLen => {
304                    if let Some(len) = self.vlq_decoder.long(&mut buf) {
305                        self.bytes_remaining = len as _;
306                        self.state = HeaderDecoderState::Key;
307                    }
308                }
309                HeaderDecoderState::ValueLen => {
310                    if let Some(len) = self.vlq_decoder.long(&mut buf) {
311                        self.bytes_remaining = len as _;
312                        self.state = HeaderDecoderState::Value;
313                    }
314                }
315                HeaderDecoderState::Sync => {
316                    let to_decode = buf.len().min(self.bytes_remaining);
317                    let write = &mut self.sync_marker[16 - to_decode..];
318                    write[..to_decode].copy_from_slice(&buf[..to_decode]);
319                    self.bytes_remaining -= to_decode;
320                    buf = &buf[to_decode..];
321                    if self.bytes_remaining == 0 {
322                        self.state = HeaderDecoderState::Finished;
323                    }
324                }
325                HeaderDecoderState::Finished => return Ok(max_read - buf.len()),
326            }
327        }
328        Ok(max_read)
329    }
330
331    /// Flush this decoder returning the parsed [`Header`] if any
332    pub fn flush(&mut self) -> Option<Header> {
333        match self.state {
334            HeaderDecoderState::Finished => {
335                self.state = HeaderDecoderState::Magic;
336                Some(Header {
337                    meta_offsets: std::mem::take(&mut self.meta_offsets),
338                    meta_buf: std::mem::take(&mut self.meta_buf),
339                    sync: self.sync_marker,
340                })
341            }
342            _ => None,
343        }
344    }
345}
346
#[cfg(test)]
mod test {
    use super::*;
    use crate::codec::AvroField;
    use crate::reader::read_header;
    use crate::schema::{
        AVRO_NAME_METADATA_KEY, AVRO_ROOT_RECORD_DEFAULT_NAME, SCHEMA_METADATA_KEY,
    };
    use crate::test_util::arrow_test_data;
    use arrow_schema::{DataType, Field, Fields, TimeUnit};
    use std::collections::HashMap;
    use std::fs::File;
    use std::io::BufReader;

    /// Sync marker used by the synthetic headers built in these tests
    const TEST_SYNC: [u8; 16] = *b"0123456789abcdef";

    /// Builds a minimal valid header holding one `avro.codec = null` entry, followed by
    /// `trailing` bytes that stand in for the start of the file body.
    fn synthetic_header(trailing: &[u8]) -> Vec<u8> {
        let mut bytes = MAGIC.to_vec();
        bytes.push(0x02); // zigzag(1): one key/value pair in this block
        bytes.push(0x14); // zigzag(10): key length
        bytes.extend_from_slice(b"avro.codec");
        bytes.push(0x08); // zigzag(4): value length
        bytes.extend_from_slice(b"null");
        bytes.push(0x00); // zigzag(0): end of the metadata map
        bytes.extend_from_slice(&TEST_SYNC);
        bytes.extend_from_slice(trailing);
        bytes
    }

    #[test]
    fn test_header_decode() {
        let mut decoder = HeaderDecoder::default();
        for m in MAGIC {
            assert_eq!(decoder.decode(std::slice::from_ref(m)).unwrap(), 1);
        }

        let mut decoder = HeaderDecoder::default();
        assert_eq!(decoder.decode(MAGIC).unwrap(), 4);

        let mut decoder = HeaderDecoder::default();
        decoder.decode(b"Ob").unwrap();
        let err = decoder.decode(b"s").unwrap_err().to_string();
        assert_eq!(err, "Parser error: Incorrect avro magic");
    }

    #[test]
    fn test_synthetic_header_single_chunk() {
        let trailing = b"block data";
        let bytes = synthetic_header(trailing);

        let mut decoder = HeaderDecoder::default();
        let consumed = decoder.decode(&bytes).unwrap();
        assert_eq!(consumed, bytes.len() - trailing.len());
        // Once the header is complete, no further input is consumed
        assert_eq!(decoder.decode(trailing).unwrap(), 0);

        let header = decoder.flush().unwrap();
        assert_eq!(header.sync(), TEST_SYNC);
        let meta: Vec<_> = header.metadata().collect();
        assert_eq!(meta, vec![(&b"avro.codec"[..], &b"null"[..])]);
        assert!(header.compression().unwrap().is_none());
    }

    #[test]
    fn test_flush_incomplete_header() {
        let mut decoder = HeaderDecoder::default();
        decoder.decode(MAGIC).unwrap();
        assert!(decoder.flush().is_none());
    }

    #[test]
    fn test_read_header_reports_length() {
        let bytes = synthetic_header(b"rest");
        let (header, len) = read_header(&bytes[..]).unwrap();
        assert_eq!(len, (bytes.len() - 4) as u64);
        assert_eq!(header.sync(), TEST_SYNC);
    }

    fn decode_file(file: &str) -> Header {
        let file = File::open(file).unwrap();
        read_header(BufReader::with_capacity(1000, file)).unwrap().0
    }

    #[test]
    fn test_header() {
        let header = decode_file(&arrow_test_data("avro/alltypes_plain.avro"));
        let schema_json = header.get(SCHEMA_METADATA_KEY).unwrap();
        let expected = br#"{"type":"record","name":"topLevelRecord","fields":[{"name":"id","type":["int","null"]},{"name":"bool_col","type":["boolean","null"]},{"name":"tinyint_col","type":["int","null"]},{"name":"smallint_col","type":["int","null"]},{"name":"int_col","type":["int","null"]},{"name":"bigint_col","type":["long","null"]},{"name":"float_col","type":["float","null"]},{"name":"double_col","type":["double","null"]},{"name":"date_string_col","type":["bytes","null"]},{"name":"string_col","type":["bytes","null"]},{"name":"timestamp_col","type":[{"type":"long","logicalType":"timestamp-micros"},"null"]}]}"#;
        assert_eq!(schema_json, expected);
        let schema: Schema<'_> = serde_json::from_slice(schema_json).unwrap();
        let field = AvroField::try_from(&schema).unwrap();

        assert_eq!(
            field.field(),
            Field::new(
                "topLevelRecord",
                DataType::Struct(Fields::from(vec![
                    Field::new("id", DataType::Int32, true),
                    Field::new("bool_col", DataType::Boolean, true),
                    Field::new("tinyint_col", DataType::Int32, true),
                    Field::new("smallint_col", DataType::Int32, true),
                    Field::new("int_col", DataType::Int32, true),
                    Field::new("bigint_col", DataType::Int64, true),
                    Field::new("float_col", DataType::Float32, true),
                    Field::new("double_col", DataType::Float64, true),
                    Field::new("date_string_col", DataType::Binary, true),
                    Field::new("string_col", DataType::Binary, true),
                    Field::new(
                        "timestamp_col",
                        DataType::Timestamp(TimeUnit::Microsecond, Some("+00:00".into())),
                        true
                    ),
                ])),
                false
            )
            .with_metadata(HashMap::from([(
                AVRO_NAME_METADATA_KEY.to_string(),
                AVRO_ROOT_RECORD_DEFAULT_NAME.to_string()
            )]))
        );

        assert_eq!(
            u128::from_le_bytes(header.sync()),
            226966037233754408753420635932530907102
        );

        let header = decode_file(&arrow_test_data("avro/fixed_length_decimal.avro"));

        let meta: Vec<_> = header
            .metadata()
            .map(|(k, _)| std::str::from_utf8(k).unwrap())
            .collect();

        assert_eq!(
            meta,
            &["avro.schema", "org.apache.spark.version", "avro.codec"]
        );

        let schema_json = header.get(SCHEMA_METADATA_KEY).unwrap();
        let expected = br#"{"type":"record","name":"topLevelRecord","fields":[{"name":"value","type":[{"type":"fixed","name":"fixed","namespace":"topLevelRecord.value","size":11,"logicalType":"decimal","precision":25,"scale":2},"null"]}]}"#;
        assert_eq!(schema_json, expected);
        let _schema: Schema<'_> = serde_json::from_slice(schema_json).unwrap();
        assert_eq!(
            u128::from_le_bytes(header.sync()),
            325166208089902833952788552656412487328
        );
    }
}