parquet/util/
bit_util.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::{cmp, mem::size_of};
19
20use bytes::Bytes;
21
22use crate::data_type::{AsBytes, ByteArray, FixedLenByteArray, Int96};
23use crate::errors::{ParquetError, Result};
24use crate::util::bit_pack::{unpack16, unpack32, unpack64, unpack8};
25
26#[inline]
27fn array_from_slice<const N: usize>(bs: &[u8]) -> Result<[u8; N]> {
28    // Need to slice as may be called with zero-padded values
29    match bs.get(..N) {
30        Some(b) => Ok(b.try_into().unwrap()),
31        None => Err(general_err!(
32            "error converting value, expected {} bytes got {}",
33            N,
34            bs.len()
35        )),
36    }
37}
38
/// # Safety
/// All bit patterns 00000xxxx, where there are `BIT_CAPACITY` `x`s,
/// must be valid, unless BIT_CAPACITY is 0.
pub unsafe trait FromBytes: Sized {
    /// Number of low-order bits that may safely be set when a value of this type
    /// is materialized by the bit-unpacking fast path; 0 opts the type out of
    /// that path entirely (see the `assert_ne!` in `BitReader::get_batch`).
    const BIT_CAPACITY: usize;
    /// Byte container holding the little-endian encoding of a value.
    type Buffer: AsMut<[u8]> + Default;
    /// Parses a value from the leading bytes of `b`; fixed-size implementations
    /// error when `b` is shorter than the required width.
    fn try_from_le_slice(b: &[u8]) -> Result<Self>;
    /// Converts an owned little-endian byte buffer into a value.
    fn from_le_bytes(bs: Self::Buffer) -> Self;
}
48
// Implements `FromBytes` for primitive types that already provide an inherent
// `from_le_bytes` constructor (fixed-width integers and floats).
macro_rules! from_le_bytes {
    ($($ty: ty),*) => {
        $(
        // SAFETY: this macro is used for types for which all bit patterns are valid.
        unsafe impl FromBytes for $ty {
            // Every bit of the primitive is usable by the bit-unpacking fast path.
            const BIT_CAPACITY: usize = std::mem::size_of::<$ty>() * 8;
            type Buffer = [u8; size_of::<Self>()];
            fn try_from_le_slice(b: &[u8]) -> Result<Self> {
                Ok(Self::from_le_bytes(array_from_slice(b)?))
            }
            fn from_le_bytes(bs: Self::Buffer) -> Self {
                <$ty>::from_le_bytes(bs)
            }
        }
        )*
    };
}

from_le_bytes! { u8, u16, u32, u64, i8, i16, i32, i64, f32, f64 }
68
69// SAFETY: the 0000000x bit pattern is always valid for `bool`.
70unsafe impl FromBytes for bool {
71    const BIT_CAPACITY: usize = 1;
72    type Buffer = [u8; 1];
73
74    fn try_from_le_slice(b: &[u8]) -> Result<Self> {
75        Ok(Self::from_le_bytes(array_from_slice(b)?))
76    }
77    fn from_le_bytes(bs: Self::Buffer) -> Self {
78        bs[0] != 0
79    }
80}
81
82// SAFETY: BIT_CAPACITY is 0.
83unsafe impl FromBytes for Int96 {
84    const BIT_CAPACITY: usize = 0;
85    type Buffer = [u8; 12];
86
87    fn try_from_le_slice(b: &[u8]) -> Result<Self> {
88        let bs: [u8; 12] = array_from_slice(b)?;
89        let mut i = Int96::new();
90        i.set_data(
91            u32::try_from_le_slice(&bs[0..4])?,
92            u32::try_from_le_slice(&bs[4..8])?,
93            u32::try_from_le_slice(&bs[8..12])?,
94        );
95        Ok(i)
96    }
97
98    fn from_le_bytes(bs: Self::Buffer) -> Self {
99        let mut i = Int96::new();
100        i.set_data(
101            u32::try_from_le_slice(&bs[0..4]).unwrap(),
102            u32::try_from_le_slice(&bs[4..8]).unwrap(),
103            u32::try_from_le_slice(&bs[8..12]).unwrap(),
104        );
105        i
106    }
107}
108
109// SAFETY: BIT_CAPACITY is 0.
110unsafe impl FromBytes for ByteArray {
111    const BIT_CAPACITY: usize = 0;
112    type Buffer = Vec<u8>;
113
114    fn try_from_le_slice(b: &[u8]) -> Result<Self> {
115        Ok(b.to_vec().into())
116    }
117    fn from_le_bytes(bs: Self::Buffer) -> Self {
118        bs.into()
119    }
120}
121
122// SAFETY: BIT_CAPACITY is 0.
123unsafe impl FromBytes for FixedLenByteArray {
124    const BIT_CAPACITY: usize = 0;
125    type Buffer = Vec<u8>;
126
127    fn try_from_le_slice(b: &[u8]) -> Result<Self> {
128        Ok(b.to_vec().into())
129    }
130    fn from_le_bytes(bs: Self::Buffer) -> Self {
131        bs.into()
132    }
133}
134
135/// Reads `size` of bytes from `src`, and reinterprets them as type `ty`, in
136/// little-endian order.
137/// This is copied and modified from byteorder crate.
138pub(crate) fn read_num_bytes<T>(size: usize, src: &[u8]) -> T
139where
140    T: FromBytes,
141{
142    assert!(size <= src.len());
143    let mut buffer = <T as FromBytes>::Buffer::default();
144    buffer.as_mut()[..size].copy_from_slice(&src[..size]);
145    <T>::from_le_bytes(buffer)
146}
147
148/// Returns the ceil of value/divisor.
149///
150/// This function should be removed after
151/// [`int_roundings`](https://github.com/rust-lang/rust/issues/88581) is stable.
152#[inline]
153pub fn ceil<T: num::Integer>(value: T, divisor: T) -> T {
154    num::Integer::div_ceil(&value, &divisor)
155}
156
/// Returns the `num_bits` least-significant bits of `v`.
/// For `num_bits >= 64` the value is returned unchanged.
#[inline]
pub fn trailing_bits(v: u64, num_bits: usize) -> u64 {
    if num_bits >= 64 {
        return v;
    }
    // Build the mask by clearing every bit at or above position `num_bits`.
    v & !(u64::MAX << num_bits)
}
166
/// Returns the minimum number of bits needed to represent the value `x`
/// (0 for `x == 0`).
#[inline]
pub fn num_required_bits(x: u64) -> u8 {
    (u64::BITS - x.leading_zeros()) as u8
}
172
// Single-bit masks indexed by bit position within a byte (LSB first).
static BIT_MASK: [u8; 8] = [1, 2, 4, 8, 16, 32, 64, 128];

/// Returns whether bit at position `i` in `data` is set or not
#[inline]
pub fn get_bit(data: &[u8], i: usize) -> bool {
    let (byte, bit) = (i / 8, i % 8);
    data[byte] & BIT_MASK[bit] != 0
}
180
/// Utility class for writing bit/byte streams. This class can write data in either
/// bit packed or byte aligned fashion.
pub struct BitWriter {
    /// Bytes already flushed to their final, byte-aligned form.
    buffer: Vec<u8>,
    /// Staging word for bit-packed writes; spilled into `buffer` when full or on flush.
    buffered_values: u64,
    /// Number of bits of `buffered_values` currently occupied.
    bit_offset: u8,
}
188
189impl BitWriter {
190    pub fn new(initial_capacity: usize) -> Self {
191        Self {
192            buffer: Vec::with_capacity(initial_capacity),
193            buffered_values: 0,
194            bit_offset: 0,
195        }
196    }
197
198    /// Initializes the writer appending to the existing buffer `buffer`
199    pub fn new_from_buf(buffer: Vec<u8>) -> Self {
200        Self {
201            buffer,
202            buffered_values: 0,
203            bit_offset: 0,
204        }
205    }
206
207    /// Consumes and returns the current buffer.
208    #[inline]
209    pub fn consume(mut self) -> Vec<u8> {
210        self.flush();
211        self.buffer
212    }
213
214    /// Flushes the internal buffered bits and returns the buffer's content.
215    /// This is a borrow equivalent of `consume` method.
216    #[inline]
217    pub fn flush_buffer(&mut self) -> &[u8] {
218        self.flush();
219        self.buffer()
220    }
221
222    /// Clears the internal state so the buffer can be reused.
223    #[inline]
224    pub fn clear(&mut self) {
225        self.buffer.clear();
226        self.buffered_values = 0;
227        self.bit_offset = 0;
228    }
229
230    /// Flushes the internal buffered bits and the align the buffer to the next byte.
231    #[inline]
232    pub fn flush(&mut self) {
233        let num_bytes = ceil(self.bit_offset, 8);
234        let slice = &self.buffered_values.to_le_bytes()[..num_bytes as usize];
235        self.buffer.extend_from_slice(slice);
236        self.buffered_values = 0;
237        self.bit_offset = 0;
238    }
239
240    /// Advances the current offset by skipping `num_bytes`, flushing the internal bit
241    /// buffer first.
242    /// This is useful when you want to jump over `num_bytes` bytes and come back later
243    /// to fill these bytes.
244    #[inline]
245    pub fn skip(&mut self, num_bytes: usize) -> usize {
246        self.flush();
247        let result = self.buffer.len();
248        self.buffer.extend(std::iter::repeat(0).take(num_bytes));
249        result
250    }
251
252    /// Returns a slice containing the next `num_bytes` bytes starting from the current
253    /// offset, and advances the underlying buffer by `num_bytes`.
254    /// This is useful when you want to jump over `num_bytes` bytes and come back later
255    /// to fill these bytes.
256    #[inline]
257    pub fn get_next_byte_ptr(&mut self, num_bytes: usize) -> &mut [u8] {
258        let offset = self.skip(num_bytes);
259        &mut self.buffer[offset..offset + num_bytes]
260    }
261
262    #[inline]
263    pub fn bytes_written(&self) -> usize {
264        self.buffer.len() + ceil(self.bit_offset, 8) as usize
265    }
266
267    #[inline]
268    pub fn buffer(&self) -> &[u8] {
269        &self.buffer
270    }
271
272    #[inline]
273    pub fn byte_offset(&self) -> usize {
274        self.buffer.len()
275    }
276
277    /// Writes the entire byte `value` at the byte `offset`
278    pub fn write_at(&mut self, offset: usize, value: u8) {
279        self.buffer[offset] = value;
280    }
281
282    /// Writes the `num_bits` LSB of value `v` to the internal buffer of this writer.
283    /// The `num_bits` must not be greater than 64. This is bit packed.
284    #[inline]
285    pub fn put_value(&mut self, v: u64, num_bits: usize) {
286        assert!(num_bits <= 64);
287        let num_bits = num_bits as u8;
288        assert_eq!(v.checked_shr(num_bits as u32).unwrap_or(0), 0); // covers case v >> 64
289
290        // Add value to buffered_values
291        self.buffered_values |= v << self.bit_offset;
292        self.bit_offset += num_bits;
293        if let Some(remaining) = self.bit_offset.checked_sub(64) {
294            self.buffer
295                .extend_from_slice(&self.buffered_values.to_le_bytes());
296            self.bit_offset = remaining;
297
298            // Perform checked right shift: v >> offset, where offset < 64, otherwise we
299            // shift all bits
300            self.buffered_values = v
301                .checked_shr((num_bits - self.bit_offset) as u32)
302                .unwrap_or(0);
303        }
304    }
305
306    /// Writes `val` of `num_bytes` bytes to the next aligned byte. If size of `T` is
307    /// larger than `num_bytes`, extra higher ordered bytes will be ignored.
308    #[inline]
309    pub fn put_aligned<T: AsBytes>(&mut self, val: T, num_bytes: usize) {
310        self.flush();
311        let slice = val.as_bytes();
312        let len = num_bytes.min(slice.len());
313        self.buffer.extend_from_slice(&slice[..len]);
314    }
315
316    /// Writes `val` of `num_bytes` bytes at the designated `offset`. The `offset` is the
317    /// offset starting from the beginning of the internal buffer that this writer
318    /// maintains. Note that this will overwrite any existing data between `offset` and
319    /// `offset + num_bytes`. Also that if size of `T` is larger than `num_bytes`, extra
320    /// higher ordered bytes will be ignored.
321    #[inline]
322    pub fn put_aligned_offset<T: AsBytes>(&mut self, val: T, num_bytes: usize, offset: usize) {
323        let slice = val.as_bytes();
324        let len = num_bytes.min(slice.len());
325        self.buffer[offset..offset + len].copy_from_slice(&slice[..len])
326    }
327
328    /// Writes a VLQ encoded integer `v` to this buffer. The value is byte aligned.
329    #[inline]
330    pub fn put_vlq_int(&mut self, mut v: u64) {
331        while v & 0xFFFFFFFFFFFFFF80 != 0 {
332            self.put_aligned::<u8>(((v & 0x7F) | 0x80) as u8, 1);
333            v >>= 7;
334        }
335        self.put_aligned::<u8>((v & 0x7F) as u8, 1);
336    }
337
338    /// Writes a zigzag-VLQ encoded (in little endian order) int `v` to this buffer.
339    /// Zigzag-VLQ is a variant of VLQ encoding where negative and positive
340    /// numbers are encoded in a zigzag fashion.
341    /// See: https://developers.google.com/protocol-buffers/docs/encoding
342    #[inline]
343    pub fn put_zigzag_vlq_int(&mut self, v: i64) {
344        let u: u64 = ((v << 1) ^ (v >> 63)) as u64;
345        self.put_vlq_int(u)
346    }
347
348    /// Returns an estimate of the memory used, in bytes
349    pub fn estimated_memory_size(&self) -> usize {
350        self.buffer.capacity() * size_of::<u8>()
351    }
352}
353
/// Maximum byte length for a VLQ encoded integer
/// MAX_VLQ_BYTE_LEN = 5 for i32, and MAX_VLQ_BYTE_LEN = 10 for i64
///
/// Each VLQ byte carries 7 payload bits, so a 64-bit value needs at most
/// ceil(64 / 7) = 10 bytes.
pub const MAX_VLQ_BYTE_LEN: usize = 10;
357
pub struct BitReader {
    /// The byte buffer to read from, passed in by client
    buffer: Bytes,

    /// Bytes are memcpy'd from `buffer` and values are read from this variable.
    /// This is faster than reading values byte by byte directly from `buffer`
    ///
    /// This is only populated when `self.bit_offset != 0`
    buffered_values: u64,

    ///
    /// End                                         Start
    /// |............|B|B|B|B|B|B|B|B|..............|
    ///                   ^          ^
    ///                 bit_offset   byte_offset
    ///
    /// Current byte offset in `buffer`
    byte_offset: usize,

    /// Current bit offset in `buffered_values` (always < 64)
    bit_offset: usize,
}
380
/// Utility class to read bit/byte stream. This class can read bits or bytes that are
/// either byte aligned or not.
impl BitReader {
    /// Creates a reader positioned at the start of `buffer`.
    pub fn new(buffer: Bytes) -> Self {
        BitReader {
            buffer,
            buffered_values: 0,
            byte_offset: 0,
            bit_offset: 0,
        }
    }

    /// Resets this reader to read from the start of `buffer`, discarding any
    /// previously buffered state.
    pub fn reset(&mut self, buffer: Bytes) {
        self.buffer = buffer;
        self.buffered_values = 0;
        self.byte_offset = 0;
        self.bit_offset = 0;
    }

    /// Gets the current byte offset
    #[inline]
    pub fn get_byte_offset(&self) -> usize {
        // A partially consumed byte counts as fully consumed.
        self.byte_offset + ceil(self.bit_offset, 8)
    }

    /// Reads a value of type `T` and of size `num_bits`.
    ///
    /// Returns `None` if there's not enough data available. `Some` otherwise.
    pub fn get_value<T: FromBytes>(&mut self, num_bits: usize) -> Option<T> {
        assert!(num_bits <= 64);
        assert!(num_bits <= size_of::<T>() * 8);

        // Bounds check: the requested bits must fit in what remains of the buffer.
        if self.byte_offset * 8 + self.bit_offset + num_bits > self.buffer.len() * 8 {
            return None;
        }

        // If buffer is not byte aligned, `self.buffered_values` will
        // have already been populated
        if self.bit_offset == 0 {
            self.load_buffered_values()
        }

        let mut v =
            trailing_bits(self.buffered_values, self.bit_offset + num_bits) >> self.bit_offset;
        self.bit_offset += num_bits;

        if self.bit_offset >= 64 {
            self.byte_offset += 8;
            self.bit_offset -= 64;

            // If the new bit_offset is not 0, we need to read the next 64-bit chunk
            // to buffered_values and update `v`
            if self.bit_offset != 0 {
                self.load_buffered_values();

                // Splice the low `bit_offset` bits of the new chunk onto the high
                // end of the value read so far.
                v |= trailing_bits(self.buffered_values, self.bit_offset)
                    .wrapping_shl((num_bits - self.bit_offset) as u32);
            }
        }

        // TODO: better to avoid copying here
        T::try_from_le_slice(v.as_bytes()).ok()
    }

    /// Read multiple values from their packed representation where each element is represented
    /// by `num_bits` bits.
    ///
    /// # Panics
    ///
    /// This function panics if
    /// - `num_bits` is larger than the bit-capacity of `T`
    ///
    pub fn get_batch<T: FromBytes>(&mut self, batch: &mut [T], num_bits: usize) -> usize {
        assert!(num_bits <= size_of::<T>() * 8);

        let mut values_to_read = batch.len();
        let needed_bits = num_bits * values_to_read;
        let remaining_bits = (self.buffer.len() - self.byte_offset) * 8 - self.bit_offset;
        if remaining_bits < needed_bits {
            // Clamp to however many whole values the remaining bits can supply.
            values_to_read = remaining_bits / num_bits;
        }

        let mut i = 0;

        // First align bit offset to byte offset
        if self.bit_offset != 0 {
            while i < values_to_read && self.bit_offset != 0 {
                batch[i] = self
                    .get_value(num_bits)
                    .expect("expected to have more data");
                i += 1;
            }
        }

        // The unsafe reinterpretation below is only sound for types that opted in
        // via a non-zero BIT_CAPACITY (see the `FromBytes` safety contract).
        assert_ne!(T::BIT_CAPACITY, 0);
        assert!(num_bits <= T::BIT_CAPACITY);

        // Read directly into output buffer
        match size_of::<T>() {
            1 => {
                let ptr = batch.as_mut_ptr() as *mut u8;
                // SAFETY: batch is properly aligned and sized. Caller guarantees that all bit patterns
                // in which only the lowest T::BIT_CAPACITY bits of T are set are valid,
                // unpack{8,16,32,64} only set to non0 the lowest num_bits bits, and we
                // checked that num_bits <= T::BIT_CAPACITY.
                let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) };
                while values_to_read - i >= 8 {
                    let out_slice = (&mut out[i..i + 8]).try_into().unwrap();
                    unpack8(&self.buffer[self.byte_offset..], out_slice, num_bits);
                    self.byte_offset += num_bits;
                    i += 8;
                }
            }
            2 => {
                let ptr = batch.as_mut_ptr() as *mut u16;
                // SAFETY: batch is properly aligned and sized. Caller guarantees that all bit patterns
                // in which only the lowest T::BIT_CAPACITY bits of T are set are valid,
                // unpack{8,16,32,64} only set to non0 the lowest num_bits bits, and we
                // checked that num_bits <= T::BIT_CAPACITY.
                let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) };
                while values_to_read - i >= 16 {
                    let out_slice = (&mut out[i..i + 16]).try_into().unwrap();
                    unpack16(&self.buffer[self.byte_offset..], out_slice, num_bits);
                    self.byte_offset += 2 * num_bits;
                    i += 16;
                }
            }
            4 => {
                let ptr = batch.as_mut_ptr() as *mut u32;
                // SAFETY: batch is properly aligned and sized. Caller guarantees that all bit patterns
                // in which only the lowest T::BIT_CAPACITY bits of T are set are valid,
                // unpack{8,16,32,64} only set to non0 the lowest num_bits bits, and we
                // checked that num_bits <= T::BIT_CAPACITY.
                let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) };
                while values_to_read - i >= 32 {
                    let out_slice = (&mut out[i..i + 32]).try_into().unwrap();
                    unpack32(&self.buffer[self.byte_offset..], out_slice, num_bits);
                    self.byte_offset += 4 * num_bits;
                    i += 32;
                }
            }
            8 => {
                let ptr = batch.as_mut_ptr() as *mut u64;
                // SAFETY: batch is properly aligned and sized. Caller guarantees that all bit patterns
                // in which only the lowest T::BIT_CAPACITY bits of T are set are valid,
                // unpack{8,16,32,64} only set to non0 the lowest num_bits bits, and we
                // checked that num_bits <= T::BIT_CAPACITY.
                let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) };
                while values_to_read - i >= 64 {
                    let out_slice = (&mut out[i..i + 64]).try_into().unwrap();
                    unpack64(&self.buffer[self.byte_offset..], out_slice, num_bits);
                    self.byte_offset += 8 * num_bits;
                    i += 64;
                }
            }
            _ => unreachable!(),
        }

        // Try to read smaller batches if possible
        if size_of::<T>() > 4 && values_to_read - i >= 32 && num_bits <= 32 {
            let mut out_buf = [0_u32; 32];
            unpack32(&self.buffer[self.byte_offset..], &mut out_buf, num_bits);
            self.byte_offset += 4 * num_bits;

            for out in out_buf {
                // Zero-allocate buffer
                let mut out_bytes = T::Buffer::default();
                out_bytes.as_mut()[..4].copy_from_slice(&out.to_le_bytes());
                batch[i] = T::from_le_bytes(out_bytes);
                i += 1;
            }
        }

        if size_of::<T>() > 2 && values_to_read - i >= 16 && num_bits <= 16 {
            let mut out_buf = [0_u16; 16];
            unpack16(&self.buffer[self.byte_offset..], &mut out_buf, num_bits);
            self.byte_offset += 2 * num_bits;

            for out in out_buf {
                // Zero-allocate buffer
                let mut out_bytes = T::Buffer::default();
                out_bytes.as_mut()[..2].copy_from_slice(&out.to_le_bytes());
                batch[i] = T::from_le_bytes(out_bytes);
                i += 1;
            }
        }

        if size_of::<T>() > 1 && values_to_read - i >= 8 && num_bits <= 8 {
            let mut out_buf = [0_u8; 8];
            unpack8(&self.buffer[self.byte_offset..], &mut out_buf, num_bits);
            self.byte_offset += num_bits;

            for out in out_buf {
                // Zero-allocate buffer
                let mut out_bytes = T::Buffer::default();
                out_bytes.as_mut()[..1].copy_from_slice(&out.to_le_bytes());
                batch[i] = T::from_le_bytes(out_bytes);
                i += 1;
            }
        }

        // Read any trailing values
        while i < values_to_read {
            let value = self
                .get_value(num_bits)
                .expect("expected to have more data");
            batch[i] = value;
            i += 1;
        }

        values_to_read
    }

    /// Skip num_value values with num_bits bit width
    ///
    /// Return the number of values skipped (up to num_values)
    pub fn skip(&mut self, num_values: usize, num_bits: usize) -> usize {
        assert!(num_bits <= 64);

        let needed_bits = num_bits * num_values;
        let remaining_bits = (self.buffer.len() - self.byte_offset) * 8 - self.bit_offset;

        let values_to_read = match remaining_bits < needed_bits {
            true => remaining_bits / num_bits,
            false => num_values,
        };

        // Recompute the absolute bit position after the skip, then split it back
        // into byte and bit components.
        let end_bit_offset = self.byte_offset * 8 + values_to_read * num_bits + self.bit_offset;

        self.byte_offset = end_bit_offset / 8;
        self.bit_offset = end_bit_offset % 8;

        if self.bit_offset != 0 {
            self.load_buffered_values()
        }

        values_to_read
    }

    /// Reads up to `num_bytes` to `buf` returning the number of bytes read
    pub(crate) fn get_aligned_bytes(&mut self, buf: &mut Vec<u8>, num_bytes: usize) -> usize {
        // Align to byte offset
        self.byte_offset = self.get_byte_offset();
        self.bit_offset = 0;

        let src = &self.buffer[self.byte_offset..];
        let to_read = num_bytes.min(src.len());
        buf.extend_from_slice(&src[..to_read]);

        self.byte_offset += to_read;

        to_read
    }

    /// Reads a `num_bytes`-sized value from this buffer and return it.
    /// `T` needs to be a little-endian native type. The value is assumed to be byte
    /// aligned so the bit reader will be advanced to the start of the next byte before
    /// reading the value.
    ///
    /// Returns `Some` if there's enough bytes left to form a value of `T`.
    /// Otherwise `None`.
    pub fn get_aligned<T: FromBytes>(&mut self, num_bytes: usize) -> Option<T> {
        self.byte_offset = self.get_byte_offset();
        self.bit_offset = 0;

        if self.byte_offset + num_bytes > self.buffer.len() {
            return None;
        }

        // Advance byte_offset to next unread byte and read num_bytes
        let v = read_num_bytes::<T>(num_bytes, &self.buffer[self.byte_offset..]);
        self.byte_offset += num_bytes;

        Some(v)
    }

    /// Reads a VLQ encoded (in little endian order) int from the stream.
    /// The encoded int must start at the beginning of a byte.
    ///
    /// Returns `None` if there's not enough bytes in the stream. `Some` otherwise.
    pub fn get_vlq_int(&mut self) -> Option<i64> {
        let mut shift = 0;
        let mut v: i64 = 0;
        while let Some(byte) = self.get_aligned::<u8>(1) {
            // Each byte contributes 7 payload bits, little-endian first.
            v |= ((byte & 0x7F) as i64) << shift;
            shift += 7;
            assert!(
                shift <= MAX_VLQ_BYTE_LEN * 7,
                "Num of bytes exceed MAX_VLQ_BYTE_LEN ({MAX_VLQ_BYTE_LEN})"
            );
            // A cleared continuation (high) bit terminates the encoding.
            if byte & 0x80 == 0 {
                return Some(v);
            }
        }
        None
    }

    /// Reads a zigzag-VLQ encoded (in little endian order) int from the stream
    /// Zigzag-VLQ is a variant of VLQ encoding where negative and positive numbers are
    /// encoded in a zigzag fashion.
    /// See: https://developers.google.com/protocol-buffers/docs/encoding
    ///
    /// Note: the encoded int must start at the beginning of a byte.
    ///
    /// Returns `None` if the number of bytes there's not enough bytes in the stream.
    /// `Some` otherwise.
    #[inline]
    pub fn get_zigzag_vlq_int(&mut self) -> Option<i64> {
        self.get_vlq_int().map(|v| {
            let u = v as u64;
            (u >> 1) as i64 ^ -((u & 1) as i64)
        })
    }

    /// Loads up to the next 8 bytes from `self.buffer` at `self.byte_offset`
    /// into `self.buffered_values`.
    ///
    /// Reads fewer than 8 bytes if there are fewer than 8 bytes left
    #[inline]
    fn load_buffered_values(&mut self) {
        let bytes_to_read = cmp::min(self.buffer.len() - self.byte_offset, 8);
        self.buffered_values =
            read_num_bytes::<u64>(bytes_to_read, &self.buffer[self.byte_offset..]);
    }
}
706
707impl From<Vec<u8>> for BitReader {
708    #[inline]
709    fn from(buffer: Vec<u8>) -> Self {
710        BitReader::new(buffer.into())
711    }
712}
713
714#[cfg(test)]
715mod tests {
716    use super::*;
717
718    use crate::util::test_common::rand_gen::random_numbers;
719    use rand::distributions::{Distribution, Standard};
720    use std::fmt::Debug;
721
    #[test]
    fn test_ceil() {
        // Exact multiples, round-ups, and large 64-bit operands.
        assert_eq!(ceil(0, 1), 0);
        assert_eq!(ceil(1, 1), 1);
        assert_eq!(ceil(1, 2), 1);
        assert_eq!(ceil(1, 8), 1);
        assert_eq!(ceil(7, 8), 1);
        assert_eq!(ceil(8, 8), 1);
        assert_eq!(ceil(9, 8), 2);
        assert_eq!(ceil(9, 9), 1);
        assert_eq!(ceil(10000000000_u64, 10), 1000000000);
        assert_eq!(ceil(10_u64, 10000000000), 1);
        assert_eq!(ceil(10000000000_u64, 1000000000), 10);
    }
736
    #[test]
    fn test_bit_reader_get_byte_offset() {
        // The byte offset must round a partially consumed byte up.
        let buffer = vec![255; 10];
        let mut bit_reader = BitReader::from(buffer);
        assert_eq!(bit_reader.get_byte_offset(), 0); // offset (0 bytes, 0 bits)
        bit_reader.get_value::<i32>(6);
        assert_eq!(bit_reader.get_byte_offset(), 1); // offset (0 bytes, 6 bits)
        bit_reader.get_value::<i32>(10);
        assert_eq!(bit_reader.get_byte_offset(), 2); // offset (0 bytes, 16 bits)
        bit_reader.get_value::<i32>(20);
        assert_eq!(bit_reader.get_byte_offset(), 5); // offset (0 bytes, 36 bits)
        bit_reader.get_value::<i32>(30);
        assert_eq!(bit_reader.get_byte_offset(), 9); // offset (8 bytes, 2 bits)
    }
751
    #[test]
    fn test_bit_reader_get_value() {
        // Successive unaligned reads consume the LSB-first bit stream.
        let buffer = vec![255, 0];
        let mut bit_reader = BitReader::from(buffer);
        assert_eq!(bit_reader.get_value::<i32>(1), Some(1));
        assert_eq!(bit_reader.get_value::<i32>(2), Some(3));
        assert_eq!(bit_reader.get_value::<i32>(3), Some(7));
        assert_eq!(bit_reader.get_value::<i32>(4), Some(3));
    }
761
    #[test]
    fn test_bit_reader_skip() {
        // Skips interleave with reads; a skip past the end returns 0.
        let buffer = vec![255, 0];
        let mut bit_reader = BitReader::from(buffer);
        let skipped = bit_reader.skip(1, 1);
        assert_eq!(skipped, 1);
        assert_eq!(bit_reader.get_value::<i32>(1), Some(1));
        let skipped = bit_reader.skip(2, 2);
        assert_eq!(skipped, 2);
        assert_eq!(bit_reader.get_value::<i32>(2), Some(3));
        let skipped = bit_reader.skip(4, 1);
        assert_eq!(skipped, 4);
        assert_eq!(bit_reader.get_value::<i32>(4), Some(0));
        let skipped = bit_reader.skip(1, 1);
        assert_eq!(skipped, 0);
    }
778
    #[test]
    fn test_bit_reader_get_value_boundary() {
        // Reads that straddle the internal 64-bit buffered chunk boundary.
        let buffer = vec![10, 0, 0, 0, 20, 0, 30, 0, 0, 0, 40, 0];
        let mut bit_reader = BitReader::from(buffer);
        assert_eq!(bit_reader.get_value::<i64>(32), Some(10));
        assert_eq!(bit_reader.get_value::<i64>(16), Some(20));
        assert_eq!(bit_reader.get_value::<i64>(32), Some(30));
        assert_eq!(bit_reader.get_value::<i64>(16), Some(40));
    }
788
    #[test]
    fn test_bit_reader_skip_boundary() {
        // Skipping across the 64-bit buffered chunk boundary keeps reads aligned.
        let buffer = vec![10, 0, 0, 0, 20, 0, 30, 0, 0, 0, 40, 0];
        let mut bit_reader = BitReader::from(buffer);
        assert_eq!(bit_reader.get_value::<i64>(32), Some(10));
        assert_eq!(bit_reader.skip(1, 16), 1);
        assert_eq!(bit_reader.get_value::<i64>(32), Some(30));
        assert_eq!(bit_reader.get_value::<i64>(16), Some(40));
    }
798
    #[test]
    fn test_bit_reader_get_aligned() {
        // An aligned read after an unaligned one jumps to the next whole byte.
        // 01110101 11001011
        let buffer = Bytes::from(vec![0x75, 0xCB]);
        let mut bit_reader = BitReader::new(buffer.clone());
        assert_eq!(bit_reader.get_value::<i32>(3), Some(5));
        assert_eq!(bit_reader.get_aligned::<i32>(1), Some(203));
        assert_eq!(bit_reader.get_value::<i32>(1), None);
        bit_reader.reset(buffer.clone());
        assert_eq!(bit_reader.get_aligned::<i32>(3), None);
    }
810
    #[test]
    fn test_bit_reader_get_vlq_int() {
        // Two back-to-back VLQ values: a 2-byte and a 3-byte encoding.
        // 10001001 00000001 11110010 10110101 00000110
        let buffer: Vec<u8> = vec![0x89, 0x01, 0xF2, 0xB5, 0x06];
        let mut bit_reader = BitReader::from(buffer);
        assert_eq!(bit_reader.get_vlq_int(), Some(137));
        assert_eq!(bit_reader.get_vlq_int(), Some(105202));
    }
819
    #[test]
    fn test_bit_reader_get_zigzag_vlq_int() {
        // Zigzag maps 0,1,2,3 -> 0,-1,1,-2.
        let buffer: Vec<u8> = vec![0, 1, 2, 3];
        let mut bit_reader = BitReader::from(buffer);
        assert_eq!(bit_reader.get_zigzag_vlq_int(), Some(0));
        assert_eq!(bit_reader.get_zigzag_vlq_int(), Some(-1));
        assert_eq!(bit_reader.get_zigzag_vlq_int(), Some(1));
        assert_eq!(bit_reader.get_zigzag_vlq_int(), Some(-2));
    }
829
830    #[test]
831    fn test_num_required_bits() {
832        assert_eq!(num_required_bits(0), 0);
833        assert_eq!(num_required_bits(1), 1);
834        assert_eq!(num_required_bits(2), 2);
835        assert_eq!(num_required_bits(4), 3);
836        assert_eq!(num_required_bits(8), 4);
837        assert_eq!(num_required_bits(10), 4);
838        assert_eq!(num_required_bits(12), 4);
839        assert_eq!(num_required_bits(16), 5);
840        assert_eq!(num_required_bits(u64::MAX), 64);
841    }
842
843    #[test]
844    fn test_get_bit() {
845        // 00001101
846        assert!(get_bit(&[0b00001101], 0));
847        assert!(!get_bit(&[0b00001101], 1));
848        assert!(get_bit(&[0b00001101], 2));
849        assert!(get_bit(&[0b00001101], 3));
850
851        // 01001001 01010010
852        assert!(get_bit(&[0b01001001, 0b01010010], 0));
853        assert!(!get_bit(&[0b01001001, 0b01010010], 1));
854        assert!(!get_bit(&[0b01001001, 0b01010010], 2));
855        assert!(get_bit(&[0b01001001, 0b01010010], 3));
856        assert!(!get_bit(&[0b01001001, 0b01010010], 4));
857        assert!(!get_bit(&[0b01001001, 0b01010010], 5));
858        assert!(get_bit(&[0b01001001, 0b01010010], 6));
859        assert!(!get_bit(&[0b01001001, 0b01010010], 7));
860        assert!(!get_bit(&[0b01001001, 0b01010010], 8));
861        assert!(get_bit(&[0b01001001, 0b01010010], 9));
862        assert!(!get_bit(&[0b01001001, 0b01010010], 10));
863        assert!(!get_bit(&[0b01001001, 0b01010010], 11));
864        assert!(get_bit(&[0b01001001, 0b01010010], 12));
865        assert!(!get_bit(&[0b01001001, 0b01010010], 13));
866        assert!(get_bit(&[0b01001001, 0b01010010], 14));
867        assert!(!get_bit(&[0b01001001, 0b01010010], 15));
868    }
869
870    #[test]
871    fn test_skip() {
872        let mut writer = BitWriter::new(5);
873        let old_offset = writer.skip(1);
874        writer.put_aligned(42, 4);
875        writer.put_aligned_offset(0x10, 1, old_offset);
876        let result = writer.consume();
877        assert_eq!(result.as_ref(), [0x10, 42, 0, 0, 0]);
878
879        writer = BitWriter::new(4);
880        let result = writer.skip(5);
881        assert_eq!(result, 0);
882        assert_eq!(writer.buffer(), &[0; 5])
883    }
884
885    #[test]
886    fn test_get_next_byte_ptr() {
887        let mut writer = BitWriter::new(5);
888        {
889            let first_byte = writer.get_next_byte_ptr(1);
890            first_byte[0] = 0x10;
891        }
892        writer.put_aligned(42, 4);
893        let result = writer.consume();
894        assert_eq!(result.as_ref(), [0x10, 42, 0, 0, 0]);
895    }
896
897    #[test]
898    fn test_consume_flush_buffer() {
899        let mut writer1 = BitWriter::new(3);
900        let mut writer2 = BitWriter::new(3);
901        for i in 1..10 {
902            writer1.put_value(i, 4);
903            writer2.put_value(i, 4);
904        }
905        let res1 = writer1.flush_buffer();
906        let res2 = writer2.consume();
907        assert_eq!(res1, &res2[..]);
908    }
909
910    #[test]
911    fn test_put_get_bool() {
912        let len = 8;
913        let mut writer = BitWriter::new(len);
914
915        for i in 0..8 {
916            writer.put_value(i % 2, 1);
917        }
918
919        writer.flush();
920        {
921            let buffer = writer.buffer();
922            assert_eq!(buffer[0], 0b10101010);
923        }
924
925        // Write 00110011
926        for i in 0..8 {
927            match i {
928                0 | 1 | 4 | 5 => writer.put_value(false as u64, 1),
929                _ => writer.put_value(true as u64, 1),
930            }
931        }
932        writer.flush();
933        {
934            let buffer = writer.buffer();
935            assert_eq!(buffer[0], 0b10101010);
936            assert_eq!(buffer[1], 0b11001100);
937        }
938
939        let mut reader = BitReader::from(writer.consume());
940
941        for i in 0..8 {
942            let val = reader
943                .get_value::<u8>(1)
944                .expect("get_value() should return OK");
945            assert_eq!(val, i % 2);
946        }
947
948        for i in 0..8 {
949            let val = reader
950                .get_value::<bool>(1)
951                .expect("get_value() should return OK");
952            match i {
953                0 | 1 | 4 | 5 => assert!(!val),
954                _ => assert!(val),
955            }
956        }
957    }
958
959    #[test]
960    fn test_put_value_roundtrip() {
961        test_put_value_rand_numbers(32, 2);
962        test_put_value_rand_numbers(32, 3);
963        test_put_value_rand_numbers(32, 4);
964        test_put_value_rand_numbers(32, 5);
965        test_put_value_rand_numbers(32, 6);
966        test_put_value_rand_numbers(32, 7);
967        test_put_value_rand_numbers(32, 8);
968        test_put_value_rand_numbers(64, 16);
969        test_put_value_rand_numbers(64, 24);
970        test_put_value_rand_numbers(64, 32);
971    }
972
973    fn test_put_value_rand_numbers(total: usize, num_bits: usize) {
974        assert!(num_bits < 64);
975        let num_bytes = ceil(num_bits, 8);
976        let mut writer = BitWriter::new(num_bytes * total);
977        let values: Vec<u64> = random_numbers::<u64>(total)
978            .iter()
979            .map(|v| v & ((1 << num_bits) - 1))
980            .collect();
981        (0..total).for_each(|i| writer.put_value(values[i], num_bits));
982
983        let mut reader = BitReader::from(writer.consume());
984        (0..total).for_each(|i| {
985            let v = reader
986                .get_value::<u64>(num_bits)
987                .expect("get_value() should return OK");
988            assert_eq!(
989                v, values[i],
990                "[{}]: expected {} but got {}",
991                i, values[i], v
992            );
993        });
994    }
995
996    #[test]
997    fn test_get_batch() {
998        const SIZE: &[usize] = &[1, 31, 32, 33, 128, 129];
999        for s in SIZE {
1000            for i in 0..=64 {
1001                match i {
1002                    0..=8 => test_get_batch_helper::<u8>(*s, i),
1003                    9..=16 => test_get_batch_helper::<u16>(*s, i),
1004                    17..=32 => test_get_batch_helper::<u32>(*s, i),
1005                    _ => test_get_batch_helper::<u64>(*s, i),
1006                }
1007            }
1008        }
1009    }
1010
1011    fn test_get_batch_helper<T>(total: usize, num_bits: usize)
1012    where
1013        T: FromBytes + Default + Clone + Debug + Eq,
1014    {
1015        assert!(num_bits <= 64);
1016        let num_bytes = ceil(num_bits, 8);
1017        let mut writer = BitWriter::new(num_bytes * total);
1018
1019        let mask = match num_bits {
1020            64 => u64::MAX,
1021            _ => (1 << num_bits) - 1,
1022        };
1023
1024        let values: Vec<u64> = random_numbers::<u64>(total)
1025            .iter()
1026            .map(|v| v & mask)
1027            .collect();
1028
1029        // Generic values used to check against actual values read from `get_batch`.
1030        let expected_values: Vec<T> = values
1031            .iter()
1032            .map(|v| T::try_from_le_slice(v.as_bytes()).unwrap())
1033            .collect();
1034
1035        (0..total).for_each(|i| writer.put_value(values[i], num_bits));
1036
1037        let buf = writer.consume();
1038        let mut reader = BitReader::from(buf);
1039        let mut batch = vec![T::default(); values.len()];
1040        let values_read = reader.get_batch::<T>(&mut batch, num_bits);
1041        assert_eq!(values_read, values.len());
1042        for i in 0..batch.len() {
1043            assert_eq!(
1044                batch[i],
1045                expected_values[i],
1046                "max_num_bits = {}, num_bits = {}, index = {}",
1047                size_of::<T>() * 8,
1048                num_bits,
1049                i
1050            );
1051        }
1052    }
1053
    #[test]
    fn test_put_aligned_roundtrip() {
        // Exercise interleaved packed/aligned round-trips for every aligned
        // type width (1/2/4/8 bytes), each with two different bit widths.
        test_put_aligned_rand_numbers::<u8>(4, 3);
        test_put_aligned_rand_numbers::<u8>(16, 5);
        test_put_aligned_rand_numbers::<i16>(32, 7);
        test_put_aligned_rand_numbers::<i16>(32, 9);
        test_put_aligned_rand_numbers::<i32>(32, 11);
        test_put_aligned_rand_numbers::<i32>(32, 13);
        test_put_aligned_rand_numbers::<i64>(32, 17);
        test_put_aligned_rand_numbers::<i64>(32, 23);
    }
1065
1066    fn test_put_aligned_rand_numbers<T>(total: usize, num_bits: usize)
1067    where
1068        T: Copy + FromBytes + AsBytes + Debug + PartialEq,
1069        Standard: Distribution<T>,
1070    {
1071        assert!(num_bits <= 32);
1072        assert!(total % 2 == 0);
1073
1074        let aligned_value_byte_width = std::mem::size_of::<T>();
1075        let value_byte_width = ceil(num_bits, 8);
1076        let mut writer =
1077            BitWriter::new((total / 2) * (aligned_value_byte_width + value_byte_width));
1078        let values: Vec<u32> = random_numbers::<u32>(total / 2)
1079            .iter()
1080            .map(|v| v & ((1 << num_bits) - 1))
1081            .collect();
1082        let aligned_values = random_numbers::<T>(total / 2);
1083
1084        for i in 0..total {
1085            let j = i / 2;
1086            if i % 2 == 0 {
1087                writer.put_value(values[j] as u64, num_bits);
1088            } else {
1089                writer.put_aligned::<T>(aligned_values[j], aligned_value_byte_width)
1090            }
1091        }
1092
1093        let mut reader = BitReader::from(writer.consume());
1094        for i in 0..total {
1095            let j = i / 2;
1096            if i % 2 == 0 {
1097                let v = reader
1098                    .get_value::<u64>(num_bits)
1099                    .expect("get_value() should return OK");
1100                assert_eq!(
1101                    v, values[j] as u64,
1102                    "[{}]: expected {} but got {}",
1103                    i, values[j], v
1104                );
1105            } else {
1106                let v = reader
1107                    .get_aligned::<T>(aligned_value_byte_width)
1108                    .expect("get_aligned() should return OK");
1109                assert_eq!(
1110                    v, aligned_values[j],
1111                    "[{}]: expected {:?} but got {:?}",
1112                    i, aligned_values[j], v
1113                );
1114            }
1115        }
1116    }
1117
1118    #[test]
1119    fn test_put_vlq_int() {
1120        let total = 64;
1121        let mut writer = BitWriter::new(total * 32);
1122        let values = random_numbers::<u32>(total);
1123        (0..total).for_each(|i| writer.put_vlq_int(values[i] as u64));
1124
1125        let mut reader = BitReader::from(writer.consume());
1126        (0..total).for_each(|i| {
1127            let v = reader
1128                .get_vlq_int()
1129                .expect("get_vlq_int() should return OK");
1130            assert_eq!(
1131                v as u32, values[i],
1132                "[{}]: expected {} but got {}",
1133                i, values[i], v
1134            );
1135        });
1136    }
1137
1138    #[test]
1139    fn test_put_zigzag_vlq_int() {
1140        let total = 64;
1141        let mut writer = BitWriter::new(total * 32);
1142        let values = random_numbers::<i32>(total);
1143        (0..total).for_each(|i| writer.put_zigzag_vlq_int(values[i] as i64));
1144
1145        let mut reader = BitReader::from(writer.consume());
1146        (0..total).for_each(|i| {
1147            let v = reader
1148                .get_zigzag_vlq_int()
1149                .expect("get_zigzag_vlq_int() should return OK");
1150            assert_eq!(
1151                v as i32, values[i],
1152                "[{}]: expected {} but got {}",
1153                i, values[i], v
1154            );
1155        });
1156    }
1157
1158    #[test]
1159    fn test_get_batch_zero_extend() {
1160        let to_read = vec![0xFF; 4];
1161        let mut reader = BitReader::from(to_read);
1162
1163        // Create a non-zeroed output buffer
1164        let mut output = [u64::MAX; 32];
1165        reader.get_batch(&mut output, 1);
1166
1167        for v in output {
1168            // Values should be read correctly
1169            assert_eq!(v, 1);
1170        }
1171    }
1172}