arrow_buffer/util/
bit_mask.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Utils for working with packed bit masks
19
20use crate::bit_util::ceil;
21
22/// Util function to set bits in a slice of bytes.
23///
24/// This will sets all bits on `write_data` in the range `[offset_write..offset_write+len]`
25/// to be equal to the bits in `data` in the range `[offset_read..offset_read+len]`
26/// returns the number of `0` bits `data[offset_read..offset_read+len]`
27/// `offset_write`, `offset_read`, and `len` are in terms of bits
28pub fn set_bits(
29    write_data: &mut [u8],
30    data: &[u8],
31    offset_write: usize,
32    offset_read: usize,
33    len: usize,
34) -> usize {
35    assert!(offset_write + len <= write_data.len() * 8);
36    assert!(offset_read + len <= data.len() * 8);
37    let mut null_count = 0;
38    let mut acc = 0;
39    while len > acc {
40        // SAFETY: the arguments to `set_upto_64bits` are within the valid range because
41        // (offset_write + acc) + (len - acc) == offset_write + len <= write_data.len() * 8
42        // (offset_read + acc) + (len - acc) == offset_read + len <= data.len() * 8
43        let (n, len_set) = unsafe {
44            set_upto_64bits(
45                write_data,
46                data,
47                offset_write + acc,
48                offset_read + acc,
49                len - acc,
50            )
51        };
52        null_count += n;
53        acc += len_set;
54    }
55
56    null_count
57}
58
59/// Similar to `set_bits` but sets only upto 64 bits, actual number of bits set may vary.
60/// Returns a pair of the number of `0` bits and the number of bits set
61///
62/// # Safety
63/// The caller must ensure all arguments are within the valid range.
64#[inline]
65unsafe fn set_upto_64bits(
66    write_data: &mut [u8],
67    data: &[u8],
68    offset_write: usize,
69    offset_read: usize,
70    len: usize,
71) -> (usize, usize) {
72    let read_byte = offset_read / 8;
73    let read_shift = offset_read % 8;
74    let write_byte = offset_write / 8;
75    let write_shift = offset_write % 8;
76
77    if len >= 64 {
78        let chunk = unsafe { (data.as_ptr().add(read_byte) as *const u64).read_unaligned() };
79        if read_shift == 0 {
80            if write_shift == 0 {
81                // no shifting necessary
82                let len = 64;
83                let null_count = chunk.count_zeros() as usize;
84                unsafe { write_u64_bytes(write_data, write_byte, chunk) };
85                (null_count, len)
86            } else {
87                // only write shifting necessary
88                let len = 64 - write_shift;
89                let chunk = chunk << write_shift;
90                let null_count = len - chunk.count_ones() as usize;
91                unsafe { or_write_u64_bytes(write_data, write_byte, chunk) };
92                (null_count, len)
93            }
94        } else if write_shift == 0 {
95            // only read shifting necessary
96            let len = 64 - 8; // 56 bits so the next set_upto_64bits call will see write_shift == 0
97            let chunk = (chunk >> read_shift) & 0x00FFFFFFFFFFFFFF; // 56 bits mask
98            let null_count = len - chunk.count_ones() as usize;
99            unsafe { write_u64_bytes(write_data, write_byte, chunk) };
100            (null_count, len)
101        } else {
102            let len = 64 - std::cmp::max(read_shift, write_shift);
103            let chunk = (chunk >> read_shift) << write_shift;
104            let null_count = len - chunk.count_ones() as usize;
105            unsafe { or_write_u64_bytes(write_data, write_byte, chunk) };
106            (null_count, len)
107        }
108    } else if len == 1 {
109        let byte_chunk = (unsafe { data.get_unchecked(read_byte) } >> read_shift) & 1;
110        unsafe { *write_data.get_unchecked_mut(write_byte) |= byte_chunk << write_shift };
111        ((byte_chunk ^ 1) as usize, 1)
112    } else {
113        let len = std::cmp::min(len, 64 - std::cmp::max(read_shift, write_shift));
114        let bytes = ceil(len + read_shift, 8);
115        // SAFETY: the args of `read_bytes_to_u64` are valid as read_byte + bytes <= data.len()
116        let chunk = unsafe { read_bytes_to_u64(data, read_byte, bytes) };
117        let mask = u64::MAX >> (64 - len);
118        let chunk = (chunk >> read_shift) & mask; // masking to read `len` bits only
119        let chunk = chunk << write_shift; // shifting back to align with `write_data`
120        let null_count = len - chunk.count_ones() as usize;
121        let bytes = ceil(len + write_shift, 8);
122        for (i, c) in chunk.to_le_bytes().iter().enumerate().take(bytes) {
123            unsafe { *write_data.get_unchecked_mut(write_byte + i) |= c };
124        }
125        (null_count, len)
126    }
127}
128
/// Reads `count` bytes of `data` starting at byte `offset` and returns them as a `u64`,
/// interpreting them as little-endian: `data[offset]` supplies the least significant byte
/// and the bytes beyond `count` are zero.
///
/// Little-endian assembly is used regardless of the host byte order so that bit `i` of the
/// result is always bit `i` of the packed input.
///
/// # Safety
/// The caller must ensure `data` has `offset..(offset + count)` range, and `count <= 8`.
#[inline]
unsafe fn read_bytes_to_u64(data: &[u8], offset: usize, count: usize) -> u64 {
    debug_assert!(count <= 8);
    let mut buf = [0u8; 8];
    // SAFETY: the caller guarantees `offset + count <= data.len()` and `count <= 8`, so the
    // source range is readable and fits in `buf`; `buf` is a fresh local, so no overlap.
    unsafe {
        std::ptr::copy_nonoverlapping(data.as_ptr().add(offset), buf.as_mut_ptr(), count);
    }
    u64::from_le_bytes(buf)
}
139
/// Stores `chunk` into the eight bytes of `data` that begin at byte `offset`, replacing
/// whatever was there. The destination does not need to be aligned for `u64`.
///
/// # Safety
/// The caller must ensure `data` has `offset..(offset + 8)` range
#[inline]
unsafe fn write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) {
    // SAFETY: the caller guarantees that `offset + 8 <= data.len()`, so all eight bytes
    // written here lie inside `data`.
    unsafe {
        data.as_mut_ptr()
            .add(offset)
            .cast::<u64>()
            .write_unaligned(chunk)
    };
}
147
/// Similar to `write_u64_bytes`, but the bits already present in the first destination byte
/// (`data[offset]`) are kept by ORing them into the first byte of `chunk` before writing.
///
/// Only that first byte is merged: the remaining seven bytes `data[offset + 1..offset + 8]`
/// are overwritten with the corresponding bytes of `chunk`. The merge is done on the
/// in-memory bytes of `chunk`, so the preserved byte stays at `data[offset]` whatever the
/// host byte order.
///
/// # Safety
/// The caller must ensure `data` has `offset..(offset + 8)` range
#[inline]
unsafe fn or_write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) {
    // SAFETY: the caller guarantees `offset + 8 <= data.len()`, so `ptr` and the seven
    // bytes after it are inside `data`.
    let ptr = unsafe { data.as_mut_ptr().add(offset) };
    let mut bytes = chunk.to_ne_bytes();
    // SAFETY: `ptr` points at `data[offset]`, which is in bounds (see above).
    bytes[0] |= unsafe { *ptr };
    // SAFETY: eight bytes starting at `ptr` are writable; `[u8; 8]` has alignment 1.
    unsafe { ptr.cast::<[u8; 8]>().write(bytes) };
}
159
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bit_util::{get_bit, set_bit, unset_bit};
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng, TryRngCore};
    use std::fmt::Display;

    #[test]
    fn test_set_bits_aligned() {
        SetBitsTest {
            write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            data: vec![
                0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111,
                0b10100101,
            ],
            offset_write: 8,
            offset_read: 0,
            len: 64,
            expected_data: vec![
                0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011,
                0b11100111, 0b10100101, 0,
            ],
            expected_null_count: 24,
        }
        .verify();
    }

    #[test]
    fn test_set_bits_unaligned_destination_start() {
        SetBitsTest {
            write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            data: vec![
                0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111,
                0b10100101,
            ],
            offset_write: 3,
            offset_read: 0,
            len: 64,
            expected_data: vec![
                0b00111000, 0b00101111, 0b11001101, 0b11011100, 0b01011110, 0b00011111, 0b00111110,
                0b00101111, 0b00000101, 0b00000000,
            ],
            expected_null_count: 24,
        }
        .verify();
    }

    #[test]
    fn test_set_bits_unaligned_destination_end() {
        SetBitsTest {
            write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            data: vec![
                0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111,
                0b10100101,
            ],
            offset_write: 8,
            offset_read: 0,
            len: 62,
            expected_data: vec![
                0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011,
                0b11100111, 0b00100101, 0,
            ],
            expected_null_count: 23,
        }
        .verify();
    }

    #[test]
    fn test_set_bits_unaligned() {
        SetBitsTest {
            write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            data: vec![
                0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111,
                0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111, 0b10100101,
                0b10011001, 0b11011011, 0b11101011, 0b11000011,
            ],
            offset_write: 3,
            offset_read: 5,
            len: 95,
            expected_data: vec![
                0b01111000, 0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, 0b01111001,
                0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, 0b00000001,
            ],
            expected_null_count: 35,
        }
        .verify();
    }

    #[test]
    fn set_bits_fuzz() {
        let mut rng = StdRng::seed_from_u64(42);
        let mut data = SetBitsTest::new();
        for _ in 0..100 {
            data.regen(&mut rng);
            data.verify();
        }
    }

    /// One `set_bits` scenario: the inputs to pass and the outcome that is expected.
    #[derive(Debug, Default)]
    struct SetBitsTest {
        /// target write data
        write_data: Vec<u8>,
        /// source data
        data: Vec<u8>,
        /// bit offset into `write_data` at which the copy starts
        offset_write: usize,
        /// bit offset into `data` at which the copy starts
        offset_read: usize,
        /// number of bits to copy
        len: usize,
        /// the expected contents of write_data after the test
        expected_data: Vec<u8>,
        /// the expected number of nulls copied at the end of the test
        expected_null_count: usize,
    }

    /// prints a byte slice as a binary string like "01010101 10101010"
    struct BinaryFormatter<'a>(&'a [u8]);
    impl Display for BinaryFormatter<'_> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            for byte in self.0 {
                write!(f, "{byte:08b} ")?;
            }
            write!(f, " ")?;
            Ok(())
        }
    }

    impl Display for SetBitsTest {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            writeln!(f, "SetBitsTest {{")?;
            writeln!(f, "  write_data:    {}", BinaryFormatter(&self.write_data))?;
            writeln!(f, "  data:          {}", BinaryFormatter(&self.data))?;
            writeln!(
                f,
                "  expected_data: {}",
                BinaryFormatter(&self.expected_data)
            )?;
            writeln!(f, "  offset_write: {}", self.offset_write)?;
            writeln!(f, "  offset_read: {}", self.offset_read)?;
            writeln!(f, "  len: {}", self.len)?;
            writeln!(f, "  expected_null_count: {}", self.expected_null_count)?;
            writeln!(f, "}}")
        }
    }

    impl SetBitsTest {
        /// create a new, empty instance of `SetBitsTest`
        fn new() -> Self {
            Self::default()
        }

        /// Update this instance's fields with randomly selected values and expected data
        fn regen(&mut self, rng: &mut StdRng) {
            //  (read) data
            // ------------------+-----------------+-------
            // .. offset_read .. | data            | ...
            // ------------------+-----------------+-------

            // Write data
            // -------------------+-----------------+-------
            // .. offset_write .. | (data to write) | ...
            // -------------------+-----------------+-------

            // length of data to copy, in bits
            let len: usize = rng.random_range(0..=200);
            // Bytes needed to hold `len` bits. Sizing the buffers in bytes (rather than
            // adding the bit count as if it were a byte count) lets long copies end right
            // at the end of a buffer when the extra byte counts below are 0.
            let len_bytes = len.div_ceil(8);

            // randomly pick where we will write to
            let offset_write_bits: usize = rng.random_range(0..=200);
            let offset_write_bytes = offset_write_bits.div_ceil(8);
            let extra_write_data_bytes = rng.random_range(0..=5); // ensure 0 shows up often

            // randomly decide where we will read from
            let extra_read_data_bytes = rng.random_range(0..=5); // make sure 0 shows up often
            let offset_read_bits: usize = rng.random_range(0..=200);
            let offset_read_bytes = offset_read_bits.div_ceil(8);

            // create space for writing
            self.write_data.clear();
            self.write_data
                .resize(offset_write_bytes + len_bytes + extra_write_data_bytes, 0);

            // NOTE: `set_bits` seems to assume the output is already zeroed: filling
            // `write_data` with random bytes here makes the fuzz test fail, so the
            // destination is deliberately left as all zeros.
            self.offset_write = offset_write_bits;

            // make source data
            self.data
                .resize(offset_read_bytes + len_bytes + extra_read_data_bytes, 0);
            // fill source data with random bytes
            rng.try_fill_bytes(self.data.as_mut_slice()).unwrap();
            self.offset_read = offset_read_bits;

            self.len = len;

            // generate expected output (not efficient)
            self.expected_data.resize(self.write_data.len(), 0);
            self.expected_data.copy_from_slice(&self.write_data);

            self.expected_null_count = 0;
            for i in 0..self.len {
                let bit = get_bit(&self.data, self.offset_read + i);
                if bit {
                    set_bit(&mut self.expected_data, self.offset_write + i);
                } else {
                    unset_bit(&mut self.expected_data, self.offset_write + i);
                    self.expected_null_count += 1;
                }
            }
        }

        /// call set_bits with the given parameters and compare with the expected output
        fn verify(&self) {
            // call set_bits and compare
            let mut actual = self.write_data.to_vec();
            let null_count = set_bits(
                &mut actual,
                &self.data,
                self.offset_write,
                self.offset_read,
                self.len,
            );

            assert_eq!(actual, self.expected_data, "self: {self}");
            assert_eq!(null_count, self.expected_null_count, "self: {self}");
        }
    }

    #[test]
    fn test_set_upto_64bits() {
        // len >= 64
        let write_data: &mut [u8] = &mut [0; 9];
        let data: &[u8] = &[
            0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001,
            0b00000001, 0b00000001,
        ];
        let offset_write = 1;
        let offset_read = 0;
        let len = 65;
        let (n, len_set) =
            unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) };
        assert_eq!(n, 55);
        assert_eq!(len_set, 63);
        assert_eq!(
            write_data,
            &[
                0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010,
                0b00000010, 0b00000000
            ]
        );

        // len = 1
        let write_data: &mut [u8] = &mut [0b00000000];
        let data: &[u8] = &[0b00000001];
        let offset_write = 1;
        let offset_read = 0;
        let len = 1;
        let (n, len_set) =
            unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) };
        assert_eq!(n, 0);
        assert_eq!(len_set, 1);
        assert_eq!(write_data, &[0b00000010]);
    }
}