Skip to main content

arrow_avro/reader/
vlq.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
/// Decoder for zig-zag encoded variable length (VLQ) integers
///
/// The decoder keeps its partial result between calls, so a value whose
/// bytes are split across several input buffers can be decoded
/// incrementally.
///
/// See also:
/// <https://avro.apache.org/docs/1.11.1/specification/#primitive-types-1>
/// <https://protobuf.dev/programming-guides/encoding/#varints>
#[derive(Debug, Default)]
pub struct VLQDecoder {
    /// Scratch space for decoding VLQ integers
    in_progress: u64,
    /// Bit offset at which the next 7-bit group is placed in `in_progress`
    shift: u32,
}
29
30impl VLQDecoder {
31    /// Decode a signed long from `buf`
32    pub fn long(&mut self, buf: &mut &[u8]) -> Option<i64> {
33        while let Some(byte) = buf.first().copied() {
34            *buf = &buf[1..];
35            self.in_progress |= ((byte & 0x7F) as u64) << self.shift;
36            self.shift += 7;
37            if byte & 0x80 == 0 {
38                let val = self.in_progress;
39                self.in_progress = 0;
40                self.shift = 0;
41                return Some((val >> 1) as i64 ^ -((val & 1) as i64));
42            }
43        }
44        None
45    }
46}
47
48/// Read a varint from `buf` returning the decoded `u64` and the number of bytes read
49#[inline]
50pub(crate) fn read_varint(buf: &[u8]) -> Option<(u64, usize)> {
51    let first = *buf.first()?;
52    if first < 0x80 {
53        return Some((first as u64, 1));
54    }
55
56    if let Some(array) = buf.get(..10) {
57        return read_varint_array(array.try_into().unwrap());
58    }
59
60    read_varint_slow(buf)
61}
62
/// Decode a varint from a buffer holding at least 10 bytes.
///
/// Only the first nine bytes are scanned for a terminator (a byte with the
/// high bit clear). If all nine carry the continuation bit, the tenth byte
/// must be `0x00` or `0x01`, because it can only supply bit 63 of the result.
/// Returns the decoded value and the number of bytes consumed, or `None` when
/// the encoding would overflow `u64`.
///
/// Based on
/// - <https://github.com/tokio-rs/prost/blob/master/prost/src/encoding/varint.rs#L71>
/// - <https://github.com/google/protobuf/blob/3.3.x/src/google/protobuf/io/coded_stream.cc#L365-L406>
/// - <https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358>
#[inline]
fn read_varint_array(buf: [u8; 10]) -> Option<(u64, usize)> {
    let mut value = 0_u64;
    let mut idx = 0_usize;
    while idx < 9 {
        let byte = u64::from(buf[idx]);
        let shift = 7 * idx;
        // Add the whole byte, continuation bit included...
        value += byte << shift;
        if byte < 0x80 {
            return Some((value, idx + 1));
        }
        // ...then cancel the continuation bit now that we know it was set.
        value -= 0x80 << shift;
        idx += 1;
    }
    let last = u64::from(buf[9]);
    value += last << 63;
    if last < 0x02 {
        Some((value, 10))
    } else {
        None
    }
}
82
/// Byte-at-a-time varint decoder used when fewer than 10 bytes are available.
///
/// Decodes at most the first 10 bytes of `buf`. Returns the value and the
/// number of bytes consumed, or `None` if no terminating byte (high bit
/// clear) is found, or if a 10-byte encoding would overflow `u64`.
#[inline(never)]
#[cold]
fn read_varint_slow(buf: &[u8]) -> Option<(u64, usize)> {
    let mut value = 0;
    // Use the iterated byte directly instead of re-indexing `buf`.
    for (count, &byte) in buf.iter().take(10).enumerate() {
        value |= u64::from(byte & 0x7F) << (count * 7);
        if byte <= 0x7F {
            // Check for u64::MAX overflow. See [`ConsumeVarint`][1] for details.
            // [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
            return (count != 9 || byte < 2).then_some((value, count + 1));
        }
    }

    None
}
99
100pub(crate) fn skip_varint(buf: &[u8]) -> Option<usize> {
101    if let Some(array) = buf.get(..10) {
102        return skip_varint_array(array.try_into().unwrap());
103    }
104    skip_varint_slow(buf)
105}
106
/// Fast path of `skip_varint` for inputs with at least 10 bytes available.
///
/// Returns one plus the index of the first of the first nine bytes whose
/// continuation bit is clear. If all nine carry the continuation bit, the
/// varint is 10 bytes long only when the tenth byte is `0x00` or `0x01`;
/// otherwise `None` is returned.
fn skip_varint_array(buf: [u8; 10]) -> Option<usize> {
    // Index-based on purpose: buf.into_iter().enumerate() regresses performance by 1% on x86-64
    #[allow(clippy::needless_range_loop)]
    for pos in 0..9 {
        let continues = buf[pos] >= 0x80;
        if !continues {
            return Some(pos + 1);
        }
    }
    // Nine continuation bytes supply 63 bits, so the tenth may only set bit 63.
    if buf[9] < 0x02 {
        Some(10)
    } else {
        None
    }
}
117
/// Slow path of `skip_varint` for inputs shorter than 10 bytes.
///
/// Returns one plus the index of the first byte with the continuation bit
/// clear, or `None` if every byte in `buf` carries the continuation bit.
#[cold]
fn skip_varint_slow(buf: &[u8]) -> Option<usize> {
    debug_assert!(
        buf.len() < 10,
        "should be only called on buffers too short for the fast path"
    );
    buf.iter()
        .position(|&byte| byte < 0x80)
        .map(|last| last + 1)
}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode `n` as an unsigned varint into `dst`, returning the number of bytes written
    fn encode_var(mut n: u64, dst: &mut [u8]) -> usize {
        let mut i = 0;

        while n >= 0x80 {
            dst[i] = 0x80 | (n as u8);
            i += 1;
            n >>= 7;
        }

        dst[i] = n as u8;
        i + 1
    }

    /// Round-trip `a` through `read_varint`, once with a buffer sized exactly to
    /// the encoding and once with a full 10-byte buffer, so that both the
    /// short-buffer and 10-byte-buffer code paths are exercised.
    fn varint_test(a: u64) {
        let mut buf = [0_u8; 10];
        let len = encode_var(a, &mut buf);
        assert_eq!(read_varint(&buf[..len]).unwrap(), (a, len));
        assert_eq!(read_varint(&buf).unwrap(), (a, len));
    }

    #[test]
    fn test_varint() {
        varint_test(0);
        varint_test(4395932);
        varint_test(u64::MAX);

        for _ in 0..1000 {
            varint_test(rand::random());
        }
    }

    #[test]
    fn test_varint_boundaries() {
        // Values on either side of every power of two cross each encoded-length boundary.
        for shift in 0..64 {
            let p = 1_u64 << shift;
            varint_test(p - 1);
            varint_test(p);
        }
    }

    #[test]
    fn test_varint_invalid() {
        // Empty input
        assert_eq!(read_varint(&[]), None);
        // Truncated encodings: every available byte has the continuation bit set
        assert_eq!(read_varint(&[0x80]), None);
        assert_eq!(read_varint(&[0xFF; 9]), None);
        // Ten continuation bytes never terminate
        assert_eq!(read_varint(&[0xFF; 10]), None);
        // After nine continuation bytes the tenth may only contribute bit 63
        let mut overflow = [0xFF_u8; 10];
        overflow[9] = 0x02;
        assert_eq!(read_varint(&overflow), None);
        overflow[9] = 0x01;
        assert_eq!(read_varint(&overflow), Some((u64::MAX, 10)));
    }
}