1use crate::util::bit_util::ceil;
21use std::fmt::Debug;
22
#[derive(Debug)]
/// A view of a bit range as a prefix word, a slice of aligned `u64` chunks,
/// and a suffix word, allowing word-at-a-time processing of arbitrarily
/// offset bit data. Padding bits (outside the logical range) are zeroed.
pub struct UnalignedBitChunk<'a> {
    // Number of zeroed low bits in `prefix` (or in the first chunk when
    // `prefix` is `None`) that precede the logical start of the bit range.
    lead_padding: usize,
    // Number of zeroed high bits in the last populated word that follow
    // the logical end of the bit range.
    trailing_padding: usize,

    // Leading word covering the unaligned head of the range, if any.
    prefix: Option<u64>,
    // Fully-populated, 8-byte-aligned 64-bit chunks.
    chunks: &'a [u64],
    // Trailing word covering the unaligned tail of the range, if any.
    suffix: Option<u64>,
}
39
impl<'a> UnalignedBitChunk<'a> {
    /// Creates a new [`UnalignedBitChunk`] from `len` bits of `buffer`
    /// starting at bit `offset`.
    ///
    /// The returned view exposes the selected bits as an optional prefix
    /// word, zero or more aligned `u64` chunks, and an optional suffix word,
    /// with all bits outside the `[offset, offset + len)` range masked to 0.
    ///
    /// # Panics
    ///
    /// Panics if the byte range implied by `offset` and `len` lies outside
    /// `buffer`.
    ///
    /// NOTE(review): the `align_to::<u64>()` reinterpretation together with
    /// `u64::from_le_bytes` in `read_u64` suggests this assumes a
    /// little-endian target — confirm against the crate's platform support.
    pub fn new(buffer: &'a [u8], offset: usize, len: usize) -> Self {
        // Empty range: nothing to read, everything empty/None.
        if len == 0 {
            return Self {
                lead_padding: 0,
                trailing_padding: 0,
                prefix: None,
                chunks: &[],
                suffix: None,
            };
        }

        // Split the bit offset into whole bytes plus a sub-byte remainder.
        let byte_offset = offset / 8;
        let offset_padding = offset % 8;

        // Number of bytes touched by the bit range, then narrow `buffer`
        // to exactly those bytes.
        let bytes_len = (len + offset_padding).div_ceil(8);
        let buffer = &buffer[byte_offset..byte_offset + bytes_len];

        // Mask that zeroes the `offset_padding` low bits of the first word.
        let prefix_mask = compute_prefix_mask(offset_padding);

        // Fits entirely in a single u64: everything goes in `prefix`.
        if buffer.len() <= 8 {
            let (suffix_mask, trailing_padding) = compute_suffix_mask(len, offset_padding);
            let prefix = read_u64(buffer) & suffix_mask & prefix_mask;

            return Self {
                lead_padding: offset_padding,
                trailing_padding,
                prefix: Some(prefix),
                chunks: &[],
                suffix: None,
            };
        }

        // Fits in two u64s: a prefix and a suffix, no aligned chunks.
        if buffer.len() <= 16 {
            let (suffix_mask, trailing_padding) = compute_suffix_mask(len, offset_padding);
            let prefix = read_u64(&buffer[..8]) & prefix_mask;
            let suffix = read_u64(&buffer[8..]) & suffix_mask;

            return Self {
                lead_padding: offset_padding,
                trailing_padding,
                prefix: Some(prefix),
                chunks: &[],
                suffix: Some(suffix),
            };
        }

        // General case: reinterpret the interior of the byte slice as
        // aligned u64s, leaving unaligned head/tail byte runs.
        let (prefix, mut chunks, suffix) = unsafe { buffer.align_to::<u64>() };
        // With more than 16 bytes, align_to must be able to produce at
        // least one aligned chunk, leaving fewer than 8 bytes on each side.
        assert!(
            prefix.len() < 8 && suffix.len() < 8,
            "align_to did not return largest possible aligned slice"
        );

        let (alignment_padding, prefix) = match (offset_padding, prefix.is_empty()) {
            // Byte-aligned start on an 8-byte boundary: no prefix needed.
            (0, true) => (0, None),
            // 8-byte-aligned but bit-offset start: peel the first chunk
            // off into the prefix so its low bits can be masked.
            (_, true) => {
                let prefix = chunks[0] & prefix_mask;
                chunks = &chunks[1..];
                (0, Some(prefix))
            }
            // Unaligned head bytes: read them into a u64 and shift them to
            // the high end so the lead padding is contiguous at the bottom.
            (_, false) => {
                let alignment_padding = (8 - prefix.len()) * 8;

                let prefix = (read_u64(prefix) & prefix_mask) << alignment_padding;
                (alignment_padding, Some(prefix))
            }
        };

        // Total zeroed low bits in the first word (bit offset + shift).
        let lead_padding = offset_padding + alignment_padding;
        let (suffix_mask, trailing_padding) = compute_suffix_mask(len, lead_padding);

        let suffix = match (trailing_padding, suffix.is_empty()) {
            // Range ends exactly on a 64-bit boundary: no suffix needed.
            (0, _) => None,
            // Aligned tail but partial final word: peel the last chunk off
            // into the suffix so its high bits can be masked.
            (_, true) => {
                let suffix = chunks[chunks.len() - 1] & suffix_mask;
                chunks = &chunks[..chunks.len() - 1];
                Some(suffix)
            }
            // Unaligned tail bytes: read and mask them.
            (_, false) => Some(read_u64(suffix) & suffix_mask),
        };

        Self {
            lead_padding,
            trailing_padding,
            prefix,
            chunks,
            suffix,
        }
    }

    /// Returns the number of zeroed low bits preceding the range in the
    /// first populated word.
    pub fn lead_padding(&self) -> usize {
        self.lead_padding
    }

    /// Returns the number of zeroed high bits following the range in the
    /// last populated word.
    pub fn trailing_padding(&self) -> usize {
        self.trailing_padding
    }

    /// Returns the unaligned leading word, if any.
    pub fn prefix(&self) -> Option<u64> {
        self.prefix
    }

    /// Returns the unaligned trailing word, if any.
    pub fn suffix(&self) -> Option<u64> {
        self.suffix
    }

    /// Returns the aligned 64-bit chunks between prefix and suffix.
    pub fn chunks(&self) -> &'a [u64] {
        self.chunks
    }

    /// Iterates all populated words in order: prefix, chunks, then suffix.
    pub fn iter(&self) -> UnalignedBitChunkIterator<'a> {
        self.prefix
            .into_iter()
            .chain(self.chunks.iter().cloned())
            .chain(self.suffix)
    }

    /// Counts the set bits in the range. Padding bits are zero by
    /// construction, so summing popcounts over all words is exact.
    pub fn count_ones(&self) -> usize {
        self.iter().map(|x| x.count_ones() as usize).sum()
    }
}
172
/// Iterator over the words of an [`UnalignedBitChunk`]: the optional prefix,
/// the aligned chunks, then the optional suffix.
pub type UnalignedBitChunkIterator<'a> = std::iter::Chain<
    std::iter::Chain<std::option::IntoIter<u64>, std::iter::Cloned<std::slice::Iter<'a, u64>>>,
    std::option::IntoIter<u64>,
>;
178
/// Reads up to the first 8 bytes of `input` as a little-endian `u64`,
/// zero-padding the high bytes when `input` is shorter than 8 bytes.
#[inline]
fn read_u64(input: &[u8]) -> u64 {
    let len = input.len().min(8);
    let mut buf = [0_u8; 8];
    // Copy only the bytes actually used: copying from the full `input`
    // would panic on a slice-length mismatch whenever `input.len() > 8`,
    // defeating the truncation that `min(8)` above clearly intends.
    buf[..len].copy_from_slice(&input[..len]);
    u64::from_le_bytes(buf)
}
186
/// Returns a mask with the low `lead_padding` bits cleared and all higher
/// bits set, used to zero out padding bits before the start of a range.
#[inline]
fn compute_prefix_mask(lead_padding: usize) -> u64 {
    u64::MAX << lead_padding
}
191
/// Computes the mask for the final word of a range of `len` bits preceded
/// by `lead_padding` zeroed bits.
///
/// Returns `(suffix_mask, trailing_padding)` where `suffix_mask` keeps only
/// the valid low bits of the last word and `trailing_padding` is the number
/// of high bits masked off (0 when the range ends on a 64-bit boundary).
#[inline]
fn compute_suffix_mask(len: usize, lead_padding: usize) -> (u64, usize) {
    match (len + lead_padding) % 64 {
        // Range ends exactly on a word boundary: keep everything.
        0 => (u64::MAX, 0),
        trailing_bits => {
            let trailing_padding = 64 - trailing_bits;
            (u64::MAX >> trailing_padding, trailing_padding)
        }
    }
}
204
#[derive(Debug)]
/// Iterates over a bit range of a byte buffer as unpadded 64-bit words,
/// plus a final partial word of `remainder_len` bits.
pub struct BitChunks<'a> {
    // Byte slice starting at the byte containing the first bit.
    buffer: &'a [u8],
    // Bit offset (0..8) of the first bit within `buffer[0]`.
    bit_offset: usize,
    // Number of complete 64-bit chunks in the range.
    chunk_len: usize,
    // Number of bits (0..64) left over after the complete chunks.
    remainder_len: usize,
}
219
220impl<'a> BitChunks<'a> {
221 pub fn new(buffer: &'a [u8], offset: usize, len: usize) -> Self {
223 let end = offset.checked_add(len).expect("offset + len out of bounds");
224 assert!(ceil(end, 8) <= buffer.len(), "offset + len out of bounds");
225
226 let byte_offset = offset / 8;
227 let bit_offset = offset % 8;
228
229 let chunk_len = len / 64;
231 let remainder_len = len % 64;
233
234 BitChunks::<'a> {
235 buffer: &buffer[byte_offset..],
236 bit_offset,
237 chunk_len,
238 remainder_len,
239 }
240 }
241}
242
#[derive(Debug)]
/// Iterator yielding the complete 64-bit chunks of a [`BitChunks`];
/// the trailing partial word is exposed separately via `remainder_bits`.
pub struct BitChunkIterator<'a> {
    // Byte slice starting at the byte containing the first bit.
    buffer: &'a [u8],
    // Bit offset (0..8) of the first bit within `buffer[0]`.
    bit_offset: usize,
    // Total number of complete 64-bit chunks to yield.
    chunk_len: usize,
    // Index of the next chunk to yield.
    index: usize,
}
251
impl<'a> BitChunks<'a> {
    /// Returns the number of remaining bits (0..64) not covered by a
    /// complete 64-bit chunk.
    #[inline]
    pub const fn remainder_len(&self) -> usize {
        self.remainder_len
    }

    /// Returns the number of complete 64-bit chunks in the range.
    #[inline]
    pub const fn chunk_len(&self) -> usize {
        self.chunk_len
    }

    /// Returns the trailing `remainder_len` bits, right-aligned in a `u64`
    /// with higher bits zeroed; returns 0 when there is no remainder.
    #[inline]
    pub fn remainder_bits(&self) -> u64 {
        let bit_len = self.remainder_len;
        if bit_len == 0 {
            0
        } else {
            let bit_offset = self.bit_offset;
            // Number of bytes the remainder bits span, including the
            // partial leading byte shared with the last full chunk.
            let byte_len = ceil(bit_len + bit_offset, 8);
            // SAFETY(review): `new` asserted the full bit range fits in
            // `buffer`, so the remainder's first byte — 8 bytes per full
            // chunk past the start — is in bounds.
            let base = unsafe {
                self.buffer
                    .as_ptr()
                    .add(self.chunk_len * std::mem::size_of::<u64>())
            };

            // Assemble the remainder byte-by-byte, discarding the
            // `bit_offset` low bits of the first byte.
            let mut bits = unsafe { std::ptr::read(base) } as u64 >> bit_offset;
            for i in 1..byte_len {
                let byte = unsafe { std::ptr::read(base.add(i)) };
                bits |= (byte as u64) << (i * 8 - bit_offset);
            }

            // Zero any bits beyond the logical end of the range.
            bits & ((1 << bit_len) - 1)
        }
    }

    /// Returns the number of `u64` words needed to hold the whole range,
    /// counting a partial remainder word as one.
    #[inline]
    pub fn num_u64s(&self) -> usize {
        if self.remainder_len == 0 {
            self.chunk_len
        } else {
            self.chunk_len + 1
        }
    }

    /// Returns the number of bytes needed to hold the whole range,
    /// rounding the final partial byte up.
    #[inline]
    pub fn num_bytes(&self) -> usize {
        ceil(self.chunk_len * 64 + self.remainder_len, 8)
    }

    /// Returns an iterator over the complete 64-bit chunks (remainder
    /// bits excluded).
    #[inline]
    pub const fn iter(&self) -> BitChunkIterator<'a> {
        BitChunkIterator::<'a> {
            buffer: self.buffer,
            bit_offset: self.bit_offset,
            chunk_len: self.chunk_len,
            index: 0,
        }
    }

    /// Returns an iterator over all bits as `u64` words, with the
    /// remainder appended as a final (possibly zero) padded word.
    #[inline]
    pub fn iter_padded(&self) -> impl Iterator<Item = u64> + 'a {
        self.iter().chain(std::iter::once(self.remainder_bits()))
    }
}
331
332impl<'a> IntoIterator for BitChunks<'a> {
333 type Item = u64;
334 type IntoIter = BitChunkIterator<'a>;
335
336 fn into_iter(self) -> Self::IntoIter {
337 self.iter()
338 }
339}
340
impl Iterator for BitChunkIterator<'_> {
    type Item = u64;

    /// Yields the next complete 64-bit chunk of the bit range, stitching
    /// together two unaligned reads when `bit_offset != 0`.
    #[inline]
    fn next(&mut self) -> Option<u64> {
        let index = self.index;
        if index >= self.chunk_len {
            return None;
        }

        #[allow(clippy::cast_ptr_alignment)]
        let raw_data = self.buffer.as_ptr() as *const u64;

        // SAFETY(review): `index < chunk_len` and the constructor's bounds
        // check guarantee bytes `[index*8, index*8 + 8)` are within the
        // buffer; `read_unaligned` tolerates the misaligned pointer.
        // NOTE(review): `.to_le()` implies the stored bytes are treated as
        // little-endian — confirm against the crate's platform assumptions.
        let current = unsafe { std::ptr::read_unaligned(raw_data.add(index)).to_le() };

        let bit_offset = self.bit_offset;

        let combined = if bit_offset == 0 {
            current
        } else {
            // With a non-zero bit offset the chunk straddles a byte
            // boundary: fetch one extra byte (in bounds, since the range
            // extends `bit_offset` bits into it) to supply the high bits.
            let next =
                unsafe { std::ptr::read_unaligned(raw_data.add(index + 1) as *const u8) as u64 };

            (current >> bit_offset) | (next << (64 - bit_offset))
        };

        self.index = index + 1;

        Some(combined)
    }

    /// Exact remaining count: `chunk_len - index` chunks are left.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (
            self.chunk_len - self.index,
            Some(self.chunk_len - self.index),
        )
    }
}
385
impl ExactSizeIterator for BitChunkIterator<'_> {
    /// Number of complete 64-bit chunks still to be yielded.
    #[inline]
    fn len(&self) -> usize {
        self.chunk_len - self.index
    }
}
392
#[cfg(test)]
mod tests {
    //! Unit and fuzz tests for `BitChunks` and `UnalignedBitChunk`.
    use rand::distr::uniform::UniformSampler;
    use rand::distr::uniform::UniformUsize;
    use rand::prelude::*;
    use rand::rng;

    use crate::buffer::Buffer;
    use crate::util::bit_chunk_iterator::UnalignedBitChunk;

    // A byte- and word-aligned range yields exactly one chunk.
    #[test]
    fn test_iter_aligned() {
        let input: &[u8] = &[0, 1, 2, 3, 4, 5, 6, 7];
        let buffer: Buffer = Buffer::from(input);

        let bitchunks = buffer.bit_chunks(0, 64);
        let result = bitchunks.into_iter().collect::<Vec<_>>();

        assert_eq!(vec![0x0706050403020100], result);
    }

    // A 4-bit offset shifts bits across byte boundaries into one chunk.
    #[test]
    fn test_iter_unaligned() {
        let input: &[u8] = &[
            0b00000000, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, 0b00100000,
            0b01000000, 0b11111111,
        ];
        let buffer: Buffer = Buffer::from(input);

        let bitchunks = buffer.bit_chunks(4, 64);

        assert_eq!(0, bitchunks.remainder_len());
        assert_eq!(0, bitchunks.remainder_bits());

        let result = bitchunks.into_iter().collect::<Vec<_>>();

        assert_eq!(
            vec![0b1111010000000010000000010000000010000000010000000010000000010000],
            result
        );
    }

    // 66 bits at offset 4: one full chunk plus a 2-bit remainder.
    #[test]
    fn test_iter_unaligned_remainder_1_byte() {
        let input: &[u8] = &[
            0b00000000, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, 0b00100000,
            0b01000000, 0b11111111,
        ];
        let buffer: Buffer = Buffer::from(input);

        let bitchunks = buffer.bit_chunks(4, 66);

        assert_eq!(2, bitchunks.remainder_len());
        assert_eq!(0b00000011, bitchunks.remainder_bits());

        let result = bitchunks.into_iter().collect::<Vec<_>>();

        assert_eq!(
            vec![0b1111010000000010000000010000000010000000010000000010000000010000],
            result
        );
    }

    // Remainder-only range whose bits straddle a byte boundary.
    #[test]
    fn test_iter_unaligned_remainder_bits_across_bytes() {
        let input: &[u8] = &[0b00111111, 0b11111100];
        let buffer: Buffer = Buffer::from(input);

        let bitchunks = buffer.bit_chunks(6, 7);

        assert_eq!(7, bitchunks.remainder_len());
        assert_eq!(0b1110000, bitchunks.remainder_bits());
    }

    // Maximum-size (63-bit) remainder with no full chunks.
    #[test]
    fn test_iter_unaligned_remainder_bits_large() {
        let input: &[u8] = &[
            0b11111111, 0b00000000, 0b11111111, 0b00000000, 0b11111111, 0b00000000, 0b11111111,
            0b00000000, 0b11111111,
        ];
        let buffer: Buffer = Buffer::from(input);

        let bitchunks = buffer.bit_chunks(2, 63);

        assert_eq!(63, bitchunks.remainder_len());
        assert_eq!(
            0b100_0000_0011_1111_1100_0000_0011_1111_1100_0000_0011_1111_1100_0000_0011_1111,
            bitchunks.remainder_bits()
        );
    }

    // Regression check: the final chunk/remainder reads must not run past
    // the end of the allocation.
    #[test]
    fn test_iter_remainder_out_of_bounds() {
        const ALLOC_SIZE: usize = 4 * 1024;
        let input = vec![0xFF_u8; ALLOC_SIZE];

        let buffer: Buffer = Buffer::from_vec(input);

        let bitchunks = buffer.bit_chunks(57, ALLOC_SIZE * 8 - 57);

        assert_eq!(u64::MAX, bitchunks.iter().last().unwrap());
        assert_eq!(0x7F, bitchunks.remainder_bits());
    }

    // Bounds-check panics: len longer than the buffer (by whole bytes).
    #[test]
    #[should_panic(expected = "offset + len out of bounds")]
    fn test_out_of_bound_should_panic_length_is_more_than_buffer_length() {
        const ALLOC_SIZE: usize = 4 * 1024;
        let input = vec![0xFF_u8; ALLOC_SIZE];

        let buffer: Buffer = Buffer::from_vec(input);

        buffer.bit_chunks(0, (ALLOC_SIZE + 1) * 8);
    }

    // Bounds-check panics: len longer by a single bit (exercises the ceil).
    #[test]
    #[should_panic(expected = "offset + len out of bounds")]
    fn test_out_of_bound_should_panic_length_is_more_than_buffer_length_but_not_when_not_using_ceil()
    {
        const ALLOC_SIZE: usize = 4 * 1024;
        let input = vec![0xFF_u8; ALLOC_SIZE];

        let buffer: Buffer = Buffer::from_vec(input);

        buffer.bit_chunks(0, (ALLOC_SIZE * 8) + 1);
    }

    // Bounds-check panics: full-length range pushed out by a byte offset.
    #[test]
    #[should_panic(expected = "offset + len out of bounds")]
    fn test_out_of_bound_should_panic_when_offset_is_not_zero_and_length_is_the_entire_buffer_length()
    {
        const ALLOC_SIZE: usize = 4 * 1024;
        let input = vec![0xFF_u8; ALLOC_SIZE];

        let buffer: Buffer = Buffer::from_vec(input);

        buffer.bit_chunks(8, ALLOC_SIZE * 8);
    }

    // Bounds-check panics: full-length range pushed out by a 1-bit offset.
    #[test]
    #[should_panic(expected = "offset + len out of bounds")]
    fn test_out_of_bound_should_panic_when_offset_is_not_zero_and_length_is_the_entire_buffer_length_with_ceil()
    {
        const ALLOC_SIZE: usize = 4 * 1024;
        let input = vec![0xFF_u8; ALLOC_SIZE];

        let buffer: Buffer = Buffer::from_vec(input);

        buffer.bit_chunks(1, ALLOC_SIZE * 8);
    }

    // Bounds-check panics: offset + len overflows usize.
    #[test]
    #[should_panic(expected = "offset + len out of bounds")]
    fn test_out_of_bound_should_panic_when_offset_and_length_overflow() {
        let buffer = Buffer::from(vec![0xFF_u8; 8]);
        buffer.bit_chunks(1, usize::MAX);
    }

    // Exercises every prefix/chunks/suffix decomposition case of
    // UnalignedBitChunk::new against hand-computed expectations.
    #[test]
    #[allow(clippy::assertions_on_constants)]
    fn test_unaligned_bit_chunk_iterator() {
        let buffer = Buffer::from(&[0xFF; 5]);
        let unaligned = UnalignedBitChunk::new(buffer.as_slice(), 0, 40);

        assert!(unaligned.chunks().is_empty()); assert_eq!(unaligned.lead_padding(), 0);
        assert_eq!(unaligned.trailing_padding(), 24);
        assert_eq!(
            unaligned.prefix(),
            Some(0b0000000000000000000000001111111111111111111111111111111111111111)
        );
        assert_eq!(unaligned.suffix(), None);

        let buffer = buffer.slice(1);
        let unaligned = UnalignedBitChunk::new(buffer.as_slice(), 0, 32);

        assert!(unaligned.chunks().is_empty()); assert_eq!(unaligned.lead_padding(), 0);
        assert_eq!(unaligned.trailing_padding(), 32);
        assert_eq!(
            unaligned.prefix(),
            Some(0b0000000000000000000000000000000011111111111111111111111111111111)
        );
        assert_eq!(unaligned.suffix(), None);

        let unaligned = UnalignedBitChunk::new(buffer.as_slice(), 5, 27);

        assert!(unaligned.chunks().is_empty()); assert_eq!(unaligned.lead_padding(), 5); assert_eq!(unaligned.trailing_padding(), 32);
        assert_eq!(
            unaligned.prefix(),
            Some(0b0000000000000000000000000000000011111111111111111111111111100000)
        );
        assert_eq!(unaligned.suffix(), None);

        let unaligned = UnalignedBitChunk::new(buffer.as_slice(), 12, 20);

        assert!(unaligned.chunks().is_empty()); assert_eq!(unaligned.lead_padding(), 4); assert_eq!(unaligned.trailing_padding(), 40);
        assert_eq!(
            unaligned.prefix(),
            Some(0b0000000000000000000000000000000000000000111111111111111111110000)
        );
        assert_eq!(unaligned.suffix(), None);

        // 14-byte buffer: prefix + suffix path (buffer.len() <= 16).
        let buffer = Buffer::from(&[0xFF; 14]);

        let (prefix, aligned, suffix) = unsafe { buffer.as_slice().align_to::<u64>() };
        assert_eq!(prefix.len(), 0);
        assert_eq!(aligned.len(), 1);
        assert_eq!(suffix.len(), 6);

        let unaligned = UnalignedBitChunk::new(buffer.as_slice(), 0, 112);

        assert!(unaligned.chunks().is_empty()); assert_eq!(unaligned.lead_padding(), 0); assert_eq!(unaligned.trailing_padding(), 16);
        assert_eq!(unaligned.prefix(), Some(u64::MAX));
        assert_eq!(unaligned.suffix(), Some((1 << 48) - 1));

        // Exactly 16 bytes: still the prefix + suffix path, no chunks.
        let buffer = Buffer::from(&[0xFF; 16]);

        let (prefix, aligned, suffix) = unsafe { buffer.as_slice().align_to::<u64>() };
        assert_eq!(prefix.len(), 0);
        assert_eq!(aligned.len(), 2);
        assert_eq!(suffix.len(), 0);

        let unaligned = UnalignedBitChunk::new(buffer.as_slice(), 0, 128);

        assert_eq!(unaligned.prefix(), Some(u64::MAX));
        assert_eq!(unaligned.suffix(), Some(u64::MAX));
        assert!(unaligned.chunks().is_empty()); let buffer = Buffer::from(&[0xFF; 64]);

        let (prefix, aligned, suffix) = unsafe { buffer.as_slice().align_to::<u64>() };
        assert_eq!(prefix.len(), 0);
        assert_eq!(aligned.len(), 8);
        assert_eq!(suffix.len(), 0);

        // Fully aligned 64-byte buffer: chunks only, no prefix/suffix.
        let unaligned = UnalignedBitChunk::new(buffer.as_slice(), 0, 512);

        assert_eq!(unaligned.suffix(), None);
        assert_eq!(unaligned.prefix(), None);
        assert_eq!(unaligned.chunks(), [u64::MAX; 8].as_slice());
        assert_eq!(unaligned.lead_padding(), 0);
        assert_eq!(unaligned.trailing_padding(), 0);

        // Slicing by 1 byte forces an unaligned head (align_to prefix).
        let buffer = buffer.slice(1); let (prefix, aligned, suffix) = unsafe { buffer.as_slice().align_to::<u64>() };
        assert_eq!(prefix.len(), 7);
        assert_eq!(aligned.len(), 7);
        assert_eq!(suffix.len(), 0);

        let unaligned = UnalignedBitChunk::new(buffer.as_slice(), 0, 504);

        assert_eq!(unaligned.prefix(), Some(u64::MAX - 0xFF));
        assert_eq!(unaligned.suffix(), None);
        assert_eq!(unaligned.chunks(), [u64::MAX; 7].as_slice());
        assert_eq!(unaligned.lead_padding(), 8);
        assert_eq!(unaligned.trailing_padding(), 0);

        // Bit offset + unaligned head + partial tail all at once.
        let unaligned = UnalignedBitChunk::new(buffer.as_slice(), 17, 300);

        assert_eq!(unaligned.lead_padding(), 25);
        assert_eq!(unaligned.trailing_padding(), 59);
        assert_eq!(unaligned.prefix(), Some(u64::MAX - (1 << 25) + 1));
        assert_eq!(unaligned.suffix(), Some(0b11111));
        assert_eq!(unaligned.chunks(), [u64::MAX; 4].as_slice());

        // Zero-length range: everything empty regardless of offset.
        let unaligned = UnalignedBitChunk::new(buffer.as_slice(), 17, 0);

        assert_eq!(unaligned.prefix(), None);
        assert_eq!(unaligned.suffix(), None);
        assert!(unaligned.chunks().is_empty());
        assert_eq!(unaligned.lead_padding(), 0);
        assert_eq!(unaligned.trailing_padding(), 0);

        // Single-bit range.
        let unaligned = UnalignedBitChunk::new(buffer.as_slice(), 17, 1);

        assert_eq!(unaligned.prefix(), Some(2));
        assert_eq!(unaligned.suffix(), None);
        assert!(unaligned.chunks().is_empty());
        assert_eq!(unaligned.lead_padding(), 1);
        assert_eq!(unaligned.trailing_padding(), 62);
    }

    // Randomized cross-check of count_ones and per-bit iteration against a
    // Vec<bool> reference model. Skipped under miri (raw pointer heavy).
    #[test]
    #[cfg_attr(miri, ignore)]
    fn fuzz_unaligned_bit_chunk_iterator() {
        let mut rng = rng();

        let uusize = UniformUsize::new(usize::MIN, usize::MAX).unwrap();
        for _ in 0..100 {
            let mask_len = rng.random_range(0..1024);
            let bools: Vec<_> = std::iter::from_fn(|| Some(rng.random()))
                .take(mask_len)
                .collect();

            let buffer = Buffer::from_iter(bools.iter().cloned());

            let max_offset = 64.min(mask_len);
            let offset = uusize.sample(&mut rng).checked_rem(max_offset).unwrap_or(0);

            let max_truncate = 128.min(mask_len - offset);
            let truncate = uusize
                .sample(&mut rng)
                .checked_rem(max_truncate)
                .unwrap_or(0);

            let unaligned =
                UnalignedBitChunk::new(buffer.as_slice(), offset, mask_len - offset - truncate);

            let bool_slice = &bools[offset..mask_len - truncate];

            let count = unaligned.count_ones();
            let expected_count = bool_slice.iter().filter(|x| **x).count();

            assert_eq!(count, expected_count);

            let collected: Vec<u64> = unaligned.iter().collect();

            // Look up bit `idx` of the logical range inside the collected
            // words, accounting for the lead padding of the first word.
            let get_bit = |idx: usize| -> bool {
                let padded_index = idx + unaligned.lead_padding();
                let byte_idx = padded_index / 64;
                let bit_idx = padded_index % 64;
                (collected[byte_idx] & (1 << bit_idx)) != 0
            };

            for (idx, b) in bool_slice.iter().enumerate() {
                assert_eq!(*b, get_bit(idx))
            }
        }
    }
}