1use crate::compression::{CompressionCodec, CODEC_METADATA_KEY};
21use crate::reader::vlq::VLQDecoder;
22use crate::schema::{Schema, SCHEMA_METADATA_KEY};
23use arrow_schema::ArrowError;
24
/// State machine for incrementally decoding an Avro container-file header
#[derive(Debug)]
enum HeaderDecoderState {
    /// Decoding the 4-byte file magic `Obj\x01`
    Magic,
    /// Decoding the count of key/value pairs in the next metadata block
    /// (a variable-length long; `0` terminates the metadata map)
    BlockCount,
    /// Decoding the byte length of a metadata block — only present when the
    /// preceding block count was negative
    BlockLen,
    /// Decoding the byte length of a metadata key
    KeyLen,
    /// Accumulating the bytes of a metadata key into `meta_buf`
    Key,
    /// Decoding the byte length of a metadata value
    ValueLen,
    /// Accumulating the bytes of a metadata value into `meta_buf`
    Value,
    /// Decoding the 16-byte sync marker that ends the header
    Sync,
    /// The header has been fully decoded; no further input is consumed
    Finished,
}
46
/// A decoded Avro object container file header
#[derive(Debug, Clone)]
pub struct Header {
    // End offsets into `meta_buf`, alternating key-end then value-end for
    // each metadata entry; consumed pairwise by `Self::metadata`
    meta_offsets: Vec<usize>,
    // Concatenated raw bytes of all metadata keys and values
    meta_buf: Vec<u8>,
    // The 16-byte sync marker read at the end of the header
    sync: [u8; 16],
}
54
55impl Header {
56 pub fn metadata(&self) -> impl Iterator<Item = (&[u8], &[u8])> {
58 let mut last = 0;
59 self.meta_offsets.chunks_exact(2).map(move |w| {
60 let start = last;
61 last = w[1];
62 (&self.meta_buf[start..w[0]], &self.meta_buf[w[0]..w[1]])
63 })
64 }
65
66 pub fn get(&self, key: impl AsRef<[u8]>) -> Option<&[u8]> {
68 self.metadata()
69 .find_map(|(k, v)| (k == key.as_ref()).then_some(v))
70 }
71
72 pub fn sync(&self) -> [u8; 16] {
74 self.sync
75 }
76
77 pub fn compression(&self) -> Result<Option<CompressionCodec>, ArrowError> {
79 let v = self.get(CODEC_METADATA_KEY);
80 match v {
81 None | Some(b"null") => Ok(None),
82 Some(b"deflate") => Ok(Some(CompressionCodec::Deflate)),
83 Some(b"snappy") => Ok(Some(CompressionCodec::Snappy)),
84 Some(b"zstandard") => Ok(Some(CompressionCodec::ZStandard)),
85 Some(b"bzip2") => Ok(Some(CompressionCodec::Bzip2)),
86 Some(b"xz") => Ok(Some(CompressionCodec::Xz)),
87 Some(v) => Err(ArrowError::ParseError(format!(
88 "Unrecognized compression codec \'{}\'",
89 String::from_utf8_lossy(v)
90 ))),
91 }
92 }
93
94 pub(crate) fn schema(&self) -> Result<Option<Schema<'_>>, ArrowError> {
96 self.get(SCHEMA_METADATA_KEY)
97 .map(|x| {
98 serde_json::from_slice(x).map_err(|e| {
99 ArrowError::ParseError(format!("Failed to parse Avro schema JSON: {e}"))
100 })
101 })
102 .transpose()
103 }
104}
105
/// An incremental decoder for an Avro container-file [`Header`]
///
/// Feed input with `decode`, then obtain the result with `flush`.
#[derive(Debug)]
pub struct HeaderDecoder {
    // Current position in the decoding state machine
    state: HeaderDecoderState,
    // Decoder for variable-length longs (counts and lengths)
    vlq_decoder: VLQDecoder,

    // End offsets of each decoded key/value within `meta_buf`
    meta_offsets: Vec<usize>,
    // Raw bytes of all metadata keys and values, concatenated
    meta_buf: Vec<u8>,

    // The 16-byte sync marker read at the end of the header
    sync_marker: [u8; 16],

    // Key/value pairs still expected in the current metadata block
    tuples_remaining: usize,
    // Bytes still needed by the current state (magic, key, value or sync)
    bytes_remaining: usize,
}
128
129impl Default for HeaderDecoder {
130 fn default() -> Self {
131 Self {
132 state: HeaderDecoderState::Magic,
133 meta_offsets: vec![],
134 meta_buf: vec![],
135 sync_marker: [0; 16],
136 vlq_decoder: Default::default(),
137 tuples_remaining: 0,
138 bytes_remaining: MAGIC.len(),
139 }
140 }
141}
142
143const MAGIC: &[u8; 4] = b"Obj\x01";
144
145impl HeaderDecoder {
146 pub fn decode(&mut self, mut buf: &[u8]) -> Result<usize, ArrowError> {
158 let max_read = buf.len();
159 while !buf.is_empty() {
160 match self.state {
161 HeaderDecoderState::Magic => {
162 let remaining = &MAGIC[MAGIC.len() - self.bytes_remaining..];
163 let to_decode = buf.len().min(remaining.len());
164 if !buf.starts_with(&remaining[..to_decode]) {
165 return Err(ArrowError::ParseError("Incorrect avro magic".to_string()));
166 }
167 self.bytes_remaining -= to_decode;
168 buf = &buf[to_decode..];
169 if self.bytes_remaining == 0 {
170 self.state = HeaderDecoderState::BlockCount;
171 }
172 }
173 HeaderDecoderState::BlockCount => {
174 if let Some(block_count) = self.vlq_decoder.long(&mut buf) {
175 match block_count.try_into() {
176 Ok(0) => {
177 self.state = HeaderDecoderState::Sync;
178 self.bytes_remaining = 16;
179 }
180 Ok(remaining) => {
181 self.tuples_remaining = remaining;
182 self.state = HeaderDecoderState::KeyLen;
183 }
184 Err(_) => {
185 self.tuples_remaining = block_count.unsigned_abs() as _;
186 self.state = HeaderDecoderState::BlockLen;
187 }
188 }
189 }
190 }
191 HeaderDecoderState::BlockLen => {
192 if self.vlq_decoder.long(&mut buf).is_some() {
193 self.state = HeaderDecoderState::KeyLen
194 }
195 }
196 HeaderDecoderState::Key => {
197 let to_read = self.bytes_remaining.min(buf.len());
198 self.meta_buf.extend_from_slice(&buf[..to_read]);
199 self.bytes_remaining -= to_read;
200 buf = &buf[to_read..];
201 if self.bytes_remaining == 0 {
202 self.meta_offsets.push(self.meta_buf.len());
203 self.state = HeaderDecoderState::ValueLen;
204 }
205 }
206 HeaderDecoderState::Value => {
207 let to_read = self.bytes_remaining.min(buf.len());
208 self.meta_buf.extend_from_slice(&buf[..to_read]);
209 self.bytes_remaining -= to_read;
210 buf = &buf[to_read..];
211 if self.bytes_remaining == 0 {
212 self.meta_offsets.push(self.meta_buf.len());
213
214 self.tuples_remaining -= 1;
215 match self.tuples_remaining {
216 0 => self.state = HeaderDecoderState::BlockCount,
217 _ => self.state = HeaderDecoderState::KeyLen,
218 }
219 }
220 }
221 HeaderDecoderState::KeyLen => {
222 if let Some(len) = self.vlq_decoder.long(&mut buf) {
223 self.bytes_remaining = len as _;
224 self.state = HeaderDecoderState::Key;
225 }
226 }
227 HeaderDecoderState::ValueLen => {
228 if let Some(len) = self.vlq_decoder.long(&mut buf) {
229 self.bytes_remaining = len as _;
230 self.state = HeaderDecoderState::Value;
231 }
232 }
233 HeaderDecoderState::Sync => {
234 let to_decode = buf.len().min(self.bytes_remaining);
235 let write = &mut self.sync_marker[16 - to_decode..];
236 write[..to_decode].copy_from_slice(&buf[..to_decode]);
237 self.bytes_remaining -= to_decode;
238 buf = &buf[to_decode..];
239 if self.bytes_remaining == 0 {
240 self.state = HeaderDecoderState::Finished;
241 }
242 }
243 HeaderDecoderState::Finished => return Ok(max_read - buf.len()),
244 }
245 }
246 Ok(max_read)
247 }
248
249 pub fn flush(&mut self) -> Option<Header> {
251 match self.state {
252 HeaderDecoderState::Finished => {
253 self.state = HeaderDecoderState::Magic;
254 Some(Header {
255 meta_offsets: std::mem::take(&mut self.meta_offsets),
256 meta_buf: std::mem::take(&mut self.meta_buf),
257 sync: self.sync_marker,
258 })
259 }
260 _ => None,
261 }
262 }
263}
264
#[cfg(test)]
mod test {
    use super::*;
    use crate::codec::{AvroDataType, AvroField};
    use crate::reader::read_header;
    use crate::schema::SCHEMA_METADATA_KEY;
    use crate::test_util::arrow_test_data;
    use arrow_schema::{DataType, Field, Fields, TimeUnit};
    use std::fs::File;
    use std::io::{BufRead, BufReader};

    /// Exercises magic decoding: byte-at-a-time, all-at-once, and rejection
    /// of input that is not an Avro file
    #[test]
    fn test_header_decode() {
        // Feed the magic one byte at a time to exercise partial decoding
        let mut decoder = HeaderDecoder::default();
        for m in MAGIC {
            decoder.decode(std::slice::from_ref(m)).unwrap();
        }

        // The whole magic at once consumes exactly 4 bytes
        let mut decoder = HeaderDecoder::default();
        assert_eq!(decoder.decode(MAGIC).unwrap(), 4);

        // A mismatching byte anywhere in the magic is a fatal parse error
        let mut decoder = HeaderDecoder::default();
        decoder.decode(b"Ob").unwrap();
        let err = decoder.decode(b"s").unwrap_err().to_string();
        assert_eq!(err, "Parser error: Incorrect avro magic");
    }

    /// Decodes the header of `file` through a deliberately small (100 byte)
    /// buffer so state transitions across chunk boundaries are exercised
    fn decode_file(file: &str) -> Header {
        let file = File::open(file).unwrap();
        read_header(BufReader::with_capacity(100, file)).unwrap()
    }

    /// End-to-end checks against two reference files: embedded schema JSON,
    /// its Arrow translation, metadata ordering, and the sync markers
    #[test]
    fn test_header() {
        let header = decode_file(&arrow_test_data("avro/alltypes_plain.avro"));
        let schema_json = header.get(SCHEMA_METADATA_KEY).unwrap();
        let expected = br#"{"type":"record","name":"topLevelRecord","fields":[{"name":"id","type":["int","null"]},{"name":"bool_col","type":["boolean","null"]},{"name":"tinyint_col","type":["int","null"]},{"name":"smallint_col","type":["int","null"]},{"name":"int_col","type":["int","null"]},{"name":"bigint_col","type":["long","null"]},{"name":"float_col","type":["float","null"]},{"name":"double_col","type":["double","null"]},{"name":"date_string_col","type":["bytes","null"]},{"name":"string_col","type":["bytes","null"]},{"name":"timestamp_col","type":[{"type":"long","logicalType":"timestamp-micros"},"null"]}]}"#;
        assert_eq!(schema_json, expected);
        let schema: Schema<'_> = serde_json::from_slice(schema_json).unwrap();
        let field = AvroField::try_from(&schema).unwrap();

        assert_eq!(
            field.field(),
            Field::new(
                "topLevelRecord",
                DataType::Struct(Fields::from(vec![
                    Field::new("id", DataType::Int32, true),
                    Field::new("bool_col", DataType::Boolean, true),
                    Field::new("tinyint_col", DataType::Int32, true),
                    Field::new("smallint_col", DataType::Int32, true),
                    Field::new("int_col", DataType::Int32, true),
                    Field::new("bigint_col", DataType::Int64, true),
                    Field::new("float_col", DataType::Float32, true),
                    Field::new("double_col", DataType::Float64, true),
                    Field::new("date_string_col", DataType::Binary, true),
                    Field::new("string_col", DataType::Binary, true),
                    Field::new(
                        "timestamp_col",
                        DataType::Timestamp(TimeUnit::Microsecond, Some("+00:00".into())),
                        true
                    ),
                ])),
                false
            )
        );

        // Compare the 16-byte sync marker as a little-endian u128
        assert_eq!(
            u128::from_le_bytes(header.sync()),
            226966037233754408753420635932530907102
        );

        let header = decode_file(&arrow_test_data("avro/fixed_length_decimal.avro"));

        // Metadata keys must come back in file order
        let meta: Vec<_> = header
            .metadata()
            .map(|(k, _)| std::str::from_utf8(k).unwrap())
            .collect();

        assert_eq!(
            meta,
            &["avro.schema", "org.apache.spark.version", "avro.codec"]
        );

        let schema_json = header.get(SCHEMA_METADATA_KEY).unwrap();
        let expected = br#"{"type":"record","name":"topLevelRecord","fields":[{"name":"value","type":[{"type":"fixed","name":"fixed","namespace":"topLevelRecord.value","size":11,"logicalType":"decimal","precision":25,"scale":2},"null"]}]}"#;
        assert_eq!(schema_json, expected);
        let _schema: Schema<'_> = serde_json::from_slice(schema_json).unwrap();
        assert_eq!(
            u128::from_le_bytes(header.sync()),
            325166208089902833952788552656412487328
        );
    }
}