1use crate::compression::{CODEC_METADATA_KEY, CompressionCodec};
21use crate::reader::vlq::VLQDecoder;
22use crate::schema::{SCHEMA_METADATA_KEY, Schema};
23use arrow_schema::ArrowError;
24use std::io::BufRead;
25
26pub(crate) fn read_header<R: BufRead>(mut reader: R) -> Result<Header, ArrowError> {
28 let mut decoder = HeaderDecoder::default();
29 loop {
30 let buf = reader.fill_buf()?;
31 if buf.is_empty() {
32 break;
33 }
34 let read = buf.len();
35 let decoded = decoder.decode(buf)?;
36 reader.consume(decoded);
37 if decoded != read {
38 break;
39 }
40 }
41 decoder.flush().ok_or_else(|| {
42 ArrowError::ParseError("Unexpected EOF while reading Avro header".to_string())
43 })
44}
45
/// State machine for [`HeaderDecoder`]
///
/// The Avro header is: 4-byte magic, a blocked map of metadata
/// (count, then `count` length-prefixed key/value pairs, repeated until a
/// zero count), then a 16-byte sync marker.
#[derive(Debug)]
enum HeaderDecoderState {
    /// Decoding the 4-byte file magic
    Magic,
    /// Decoding the zig-zag encoded count of metadata tuples in a block
    BlockCount,
    /// Decoding the byte length of a metadata block (present when the count is negative)
    BlockLen,
    /// Decoding the byte length of a metadata key
    KeyLen,
    /// Decoding the bytes of a metadata key
    Key,
    /// Decoding the byte length of a metadata value
    ValueLen,
    /// Decoding the bytes of a metadata value
    Value,
    /// Decoding the 16-byte sync marker
    Sync,
    /// The header has been fully decoded
    Finished,
}
67
/// A decoded Avro object container file header
#[derive(Debug, Clone)]
pub struct Header {
    /// End offsets into `meta_buf`, two per metadata entry: key end, then value end
    meta_offsets: Vec<usize>,
    /// Concatenated key/value bytes of all metadata entries
    meta_buf: Vec<u8>,
    /// The 16-byte sync marker that delimits data blocks in the file
    sync: [u8; 16],
}
75
76impl Header {
77 pub fn metadata(&self) -> impl Iterator<Item = (&[u8], &[u8])> {
79 let mut last = 0;
80 self.meta_offsets.chunks_exact(2).map(move |w| {
81 let start = last;
82 last = w[1];
83 (&self.meta_buf[start..w[0]], &self.meta_buf[w[0]..w[1]])
84 })
85 }
86
87 pub fn get(&self, key: impl AsRef<[u8]>) -> Option<&[u8]> {
89 self.metadata()
90 .find_map(|(k, v)| (k == key.as_ref()).then_some(v))
91 }
92
93 pub fn sync(&self) -> [u8; 16] {
95 self.sync
96 }
97
98 pub fn compression(&self) -> Result<Option<CompressionCodec>, ArrowError> {
100 let v = self.get(CODEC_METADATA_KEY);
101 match v {
102 None | Some(b"null") => Ok(None),
103 Some(b"deflate") => Ok(Some(CompressionCodec::Deflate)),
104 Some(b"snappy") => Ok(Some(CompressionCodec::Snappy)),
105 Some(b"zstandard") => Ok(Some(CompressionCodec::ZStandard)),
106 Some(b"bzip2") => Ok(Some(CompressionCodec::Bzip2)),
107 Some(b"xz") => Ok(Some(CompressionCodec::Xz)),
108 Some(v) => Err(ArrowError::ParseError(format!(
109 "Unrecognized compression codec \'{}\'",
110 String::from_utf8_lossy(v)
111 ))),
112 }
113 }
114
115 pub(crate) fn schema(&self) -> Result<Option<Schema<'_>>, ArrowError> {
117 self.get(SCHEMA_METADATA_KEY)
118 .map(|x| {
119 serde_json::from_slice(x).map_err(|e| {
120 ArrowError::ParseError(format!("Failed to parse Avro schema JSON: {e}"))
121 })
122 })
123 .transpose()
124 }
125}
126
/// An incremental decoder for an Avro object container file [`Header`]
///
/// Bytes are pushed in via [`HeaderDecoder::decode`] and the completed header
/// is extracted with [`HeaderDecoder::flush`].
#[derive(Debug)]
pub struct HeaderDecoder {
    state: HeaderDecoderState,
    vlq_decoder: VLQDecoder,

    /// End offsets into `meta_buf`, two per metadata entry: key end, then value end
    meta_offsets: Vec<usize>,
    /// Concatenated key/value bytes of all metadata entries decoded so far
    meta_buf: Vec<u8>,

    /// Scratch buffer for the 16-byte sync marker
    sync_marker: [u8; 16],

    /// Number of key/value tuples still expected in the current metadata block
    tuples_remaining: usize,
    /// Number of bytes still expected for the current fixed-size item
    /// (magic, key, value, or sync marker)
    bytes_remaining: usize,
}
149
150impl Default for HeaderDecoder {
151 fn default() -> Self {
152 Self {
153 state: HeaderDecoderState::Magic,
154 meta_offsets: vec![],
155 meta_buf: vec![],
156 sync_marker: [0; 16],
157 vlq_decoder: Default::default(),
158 tuples_remaining: 0,
159 bytes_remaining: MAGIC.len(),
160 }
161 }
162}
163
/// The Avro object container file magic: ASCII "Obj" followed by version byte 1
const MAGIC: &[u8; 4] = b"Obj\x01";
165
166impl HeaderDecoder {
167 pub fn decode(&mut self, mut buf: &[u8]) -> Result<usize, ArrowError> {
179 let max_read = buf.len();
180 while !buf.is_empty() {
181 match self.state {
182 HeaderDecoderState::Magic => {
183 let remaining = &MAGIC[MAGIC.len() - self.bytes_remaining..];
184 let to_decode = buf.len().min(remaining.len());
185 if !buf.starts_with(&remaining[..to_decode]) {
186 return Err(ArrowError::ParseError("Incorrect avro magic".to_string()));
187 }
188 self.bytes_remaining -= to_decode;
189 buf = &buf[to_decode..];
190 if self.bytes_remaining == 0 {
191 self.state = HeaderDecoderState::BlockCount;
192 }
193 }
194 HeaderDecoderState::BlockCount => {
195 if let Some(block_count) = self.vlq_decoder.long(&mut buf) {
196 match block_count.try_into() {
197 Ok(0) => {
198 self.state = HeaderDecoderState::Sync;
199 self.bytes_remaining = 16;
200 }
201 Ok(remaining) => {
202 self.tuples_remaining = remaining;
203 self.state = HeaderDecoderState::KeyLen;
204 }
205 Err(_) => {
206 self.tuples_remaining = block_count.unsigned_abs() as _;
207 self.state = HeaderDecoderState::BlockLen;
208 }
209 }
210 }
211 }
212 HeaderDecoderState::BlockLen => {
213 if self.vlq_decoder.long(&mut buf).is_some() {
214 self.state = HeaderDecoderState::KeyLen
215 }
216 }
217 HeaderDecoderState::Key => {
218 let to_read = self.bytes_remaining.min(buf.len());
219 self.meta_buf.extend_from_slice(&buf[..to_read]);
220 self.bytes_remaining -= to_read;
221 buf = &buf[to_read..];
222 if self.bytes_remaining == 0 {
223 self.meta_offsets.push(self.meta_buf.len());
224 self.state = HeaderDecoderState::ValueLen;
225 }
226 }
227 HeaderDecoderState::Value => {
228 let to_read = self.bytes_remaining.min(buf.len());
229 self.meta_buf.extend_from_slice(&buf[..to_read]);
230 self.bytes_remaining -= to_read;
231 buf = &buf[to_read..];
232 if self.bytes_remaining == 0 {
233 self.meta_offsets.push(self.meta_buf.len());
234
235 self.tuples_remaining -= 1;
236 match self.tuples_remaining {
237 0 => self.state = HeaderDecoderState::BlockCount,
238 _ => self.state = HeaderDecoderState::KeyLen,
239 }
240 }
241 }
242 HeaderDecoderState::KeyLen => {
243 if let Some(len) = self.vlq_decoder.long(&mut buf) {
244 self.bytes_remaining = len as _;
245 self.state = HeaderDecoderState::Key;
246 }
247 }
248 HeaderDecoderState::ValueLen => {
249 if let Some(len) = self.vlq_decoder.long(&mut buf) {
250 self.bytes_remaining = len as _;
251 self.state = HeaderDecoderState::Value;
252 }
253 }
254 HeaderDecoderState::Sync => {
255 let to_decode = buf.len().min(self.bytes_remaining);
256 let write = &mut self.sync_marker[16 - to_decode..];
257 write[..to_decode].copy_from_slice(&buf[..to_decode]);
258 self.bytes_remaining -= to_decode;
259 buf = &buf[to_decode..];
260 if self.bytes_remaining == 0 {
261 self.state = HeaderDecoderState::Finished;
262 }
263 }
264 HeaderDecoderState::Finished => return Ok(max_read - buf.len()),
265 }
266 }
267 Ok(max_read)
268 }
269
270 pub fn flush(&mut self) -> Option<Header> {
272 match self.state {
273 HeaderDecoderState::Finished => {
274 self.state = HeaderDecoderState::Magic;
275 Some(Header {
276 meta_offsets: std::mem::take(&mut self.meta_offsets),
277 meta_buf: std::mem::take(&mut self.meta_buf),
278 sync: self.sync_marker,
279 })
280 }
281 _ => None,
282 }
283 }
284}
285
#[cfg(test)]
mod test {
    use super::*;
    use crate::codec::AvroField;
    use crate::reader::read_header;
    use crate::schema::{
        AVRO_NAME_METADATA_KEY, AVRO_ROOT_RECORD_DEFAULT_NAME, SCHEMA_METADATA_KEY,
    };
    use crate::test_util::arrow_test_data;
    use arrow_schema::{DataType, Field, Fields, TimeUnit};
    use std::collections::HashMap;
    use std::fs::File;
    use std::io::BufReader;

    /// Exercises magic-number decoding: byte-at-a-time, all-at-once, and rejection
    #[test]
    fn test_header_decode() {
        // Feeding the magic one byte at a time must succeed across calls
        let mut decoder = HeaderDecoder::default();
        for m in MAGIC {
            decoder.decode(std::slice::from_ref(m)).unwrap();
        }

        // Feeding the full magic in one call consumes exactly 4 bytes
        let mut decoder = HeaderDecoder::default();
        assert_eq!(decoder.decode(MAGIC).unwrap(), 4);

        // A byte that diverges from the magic mid-stream is a fatal parse error
        let mut decoder = HeaderDecoder::default();
        decoder.decode(b"Ob").unwrap();
        let err = decoder.decode(b"s").unwrap_err().to_string();
        assert_eq!(err, "Parser error: Incorrect avro magic");
    }

    /// Reads and decodes the header of the Avro file at `file`
    fn decode_file(file: &str) -> Header {
        let file = File::open(file).unwrap();
        // Small buffer capacity forces the decoder through multiple fill_buf chunks
        read_header(BufReader::with_capacity(1000, file)).unwrap()
    }

    /// End-to-end header decoding against the arrow-testing fixture files
    #[test]
    fn test_header() {
        let header = decode_file(&arrow_test_data("avro/alltypes_plain.avro"));
        // The embedded schema JSON must round-trip byte-for-byte
        let schema_json = header.get(SCHEMA_METADATA_KEY).unwrap();
        let expected = br#"{"type":"record","name":"topLevelRecord","fields":[{"name":"id","type":["int","null"]},{"name":"bool_col","type":["boolean","null"]},{"name":"tinyint_col","type":["int","null"]},{"name":"smallint_col","type":["int","null"]},{"name":"int_col","type":["int","null"]},{"name":"bigint_col","type":["long","null"]},{"name":"float_col","type":["float","null"]},{"name":"double_col","type":["double","null"]},{"name":"date_string_col","type":["bytes","null"]},{"name":"string_col","type":["bytes","null"]},{"name":"timestamp_col","type":[{"type":"long","logicalType":"timestamp-micros"},"null"]}]}"#;
        assert_eq!(schema_json, expected);
        let schema: Schema<'_> = serde_json::from_slice(schema_json).unwrap();
        let field = AvroField::try_from(&schema).unwrap();

        // The Avro schema must map to the expected Arrow field layout
        assert_eq!(
            field.field(),
            Field::new(
                "topLevelRecord",
                DataType::Struct(Fields::from(vec![
                    Field::new("id", DataType::Int32, true),
                    Field::new("bool_col", DataType::Boolean, true),
                    Field::new("tinyint_col", DataType::Int32, true),
                    Field::new("smallint_col", DataType::Int32, true),
                    Field::new("int_col", DataType::Int32, true),
                    Field::new("bigint_col", DataType::Int64, true),
                    Field::new("float_col", DataType::Float32, true),
                    Field::new("double_col", DataType::Float64, true),
                    Field::new("date_string_col", DataType::Binary, true),
                    Field::new("string_col", DataType::Binary, true),
                    Field::new(
                        "timestamp_col",
                        DataType::Timestamp(TimeUnit::Microsecond, Some("+00:00".into())),
                        true
                    ),
                ])),
                false
            )
            .with_metadata(HashMap::from([(
                AVRO_NAME_METADATA_KEY.to_string(),
                AVRO_ROOT_RECORD_DEFAULT_NAME.to_string()
            )]))
        );

        // Check the 16-byte sync marker, compared as a little-endian u128
        assert_eq!(
            u128::from_le_bytes(header.sync()),
            226966037233754408753420635932530907102
        );

        let header = decode_file(&arrow_test_data("avro/fixed_length_decimal.avro"));

        // Metadata keys must be decoded in file order
        let meta: Vec<_> = header
            .metadata()
            .map(|(k, _)| std::str::from_utf8(k).unwrap())
            .collect();

        assert_eq!(
            meta,
            &["avro.schema", "org.apache.spark.version", "avro.codec"]
        );

        let schema_json = header.get(SCHEMA_METADATA_KEY).unwrap();
        let expected = br#"{"type":"record","name":"topLevelRecord","fields":[{"name":"value","type":[{"type":"fixed","name":"fixed","namespace":"topLevelRecord.value","size":11,"logicalType":"decimal","precision":25,"scale":2},"null"]}]}"#;
        assert_eq!(schema_json, expected);
        let _schema: Schema<'_> = serde_json::from_slice(schema_json).unwrap();
        assert_eq!(
            u128::from_le_bytes(header.sync()),
            325166208089902833952788552656412487328
        );
    }
}