1use crate::bit_util::ceil;
21
22pub fn set_bits(
29 write_data: &mut [u8],
30 data: &[u8],
31 offset_write: usize,
32 offset_read: usize,
33 len: usize,
34) -> usize {
35 assert!(offset_write + len <= write_data.len() * 8);
36 assert!(offset_read + len <= data.len() * 8);
37 let mut null_count = 0;
38 let mut acc = 0;
39 while len > acc {
40 let (n, len_set) = unsafe {
44 set_upto_64bits(
45 write_data,
46 data,
47 offset_write + acc,
48 offset_read + acc,
49 len - acc,
50 )
51 };
52 null_count += n;
53 acc += len_set;
54 }
55
56 null_count
57}
58
59#[inline]
65unsafe fn set_upto_64bits(
66 write_data: &mut [u8],
67 data: &[u8],
68 offset_write: usize,
69 offset_read: usize,
70 len: usize,
71) -> (usize, usize) {
72 let read_byte = offset_read / 8;
73 let read_shift = offset_read % 8;
74 let write_byte = offset_write / 8;
75 let write_shift = offset_write % 8;
76
77 if len >= 64 {
78 let chunk = unsafe { (data.as_ptr().add(read_byte) as *const u64).read_unaligned() };
79 if read_shift == 0 {
80 if write_shift == 0 {
81 let len = 64;
83 let null_count = chunk.count_zeros() as usize;
84 unsafe { write_u64_bytes(write_data, write_byte, chunk) };
85 (null_count, len)
86 } else {
87 let len = 64 - write_shift;
89 let chunk = chunk << write_shift;
90 let null_count = len - chunk.count_ones() as usize;
91 unsafe { or_write_u64_bytes(write_data, write_byte, chunk) };
92 (null_count, len)
93 }
94 } else if write_shift == 0 {
95 let len = 64 - 8; let chunk = (chunk >> read_shift) & 0x00FFFFFFFFFFFFFF; let null_count = len - chunk.count_ones() as usize;
99 unsafe { write_u64_bytes(write_data, write_byte, chunk) };
100 (null_count, len)
101 } else {
102 let len = 64 - std::cmp::max(read_shift, write_shift);
103 let chunk = (chunk >> read_shift) << write_shift;
104 let null_count = len - chunk.count_ones() as usize;
105 unsafe { or_write_u64_bytes(write_data, write_byte, chunk) };
106 (null_count, len)
107 }
108 } else if len == 1 {
109 let byte_chunk = (unsafe { data.get_unchecked(read_byte) } >> read_shift) & 1;
110 unsafe { *write_data.get_unchecked_mut(write_byte) |= byte_chunk << write_shift };
111 ((byte_chunk ^ 1) as usize, 1)
112 } else {
113 let len = std::cmp::min(len, 64 - std::cmp::max(read_shift, write_shift));
114 let bytes = ceil(len + read_shift, 8);
115 let chunk = unsafe { read_bytes_to_u64(data, read_byte, bytes) };
117 let mask = u64::MAX >> (64 - len);
118 let chunk = (chunk >> read_shift) & mask; let chunk = chunk << write_shift; let null_count = len - chunk.count_ones() as usize;
121 let bytes = ceil(len + write_shift, 8);
122 for (i, c) in chunk.to_le_bytes().iter().enumerate().take(bytes) {
123 unsafe { *write_data.get_unchecked_mut(write_byte + i) |= c };
124 }
125 (null_count, len)
126 }
127}
128
129#[inline]
132unsafe fn read_bytes_to_u64(data: &[u8], offset: usize, count: usize) -> u64 {
133 debug_assert!(count <= 8);
134 let mut tmp: u64 = 0;
135 let src = unsafe { data.as_ptr().add(offset) };
136 unsafe { std::ptr::copy_nonoverlapping(src, &mut tmp as *mut _ as *mut u8, count) };
137 tmp
138}
139
140#[inline]
143unsafe fn write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) {
144 let ptr = unsafe { data.as_mut_ptr().add(offset) } as *mut u64;
145 unsafe { ptr.write_unaligned(chunk) };
146}
147
148#[inline]
154unsafe fn or_write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) {
155 let ptr = unsafe { data.as_mut_ptr().add(offset) };
156 let chunk = chunk | (unsafe { *ptr }) as u64;
157 unsafe { (ptr as *mut u64).write_unaligned(chunk) };
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bit_util::{get_bit, set_bit, unset_bit};
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng, TryRngCore};
    use std::fmt::Display;

    #[test]
    fn test_set_bits_aligned() {
        SetBitsTest {
            write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            data: vec![
                0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111,
                0b10100101,
            ],
            offset_write: 8,
            offset_read: 0,
            len: 64,
            expected_data: vec![
                0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011,
                0b11100111, 0b10100101, 0,
            ],
            expected_null_count: 24,
        }
        .verify();
    }

    #[test]
    fn test_set_bits_unaligned_destination_start() {
        SetBitsTest {
            write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            data: vec![
                0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111,
                0b10100101,
            ],
            offset_write: 3,
            offset_read: 0,
            len: 64,
            expected_data: vec![
                0b00111000, 0b00101111, 0b11001101, 0b11011100, 0b01011110, 0b00011111, 0b00111110,
                0b00101111, 0b00000101, 0b00000000,
            ],
            expected_null_count: 24,
        }
        .verify();
    }

    #[test]
    fn test_set_bits_unaligned_destination_end() {
        SetBitsTest {
            write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            data: vec![
                0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111,
                0b10100101,
            ],
            offset_write: 8,
            offset_read: 0,
            len: 62,
            expected_data: vec![
                0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011,
                0b11100111, 0b00100101, 0,
            ],
            expected_null_count: 23,
        }
        .verify();
    }

    #[test]
    fn test_set_bits_unaligned() {
        SetBitsTest {
            write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            data: vec![
                0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111,
                0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111, 0b10100101,
                0b10011001, 0b11011011, 0b11101011, 0b11000011,
            ],
            offset_write: 3,
            offset_read: 5,
            len: 95,
            expected_data: vec![
                0b01111000, 0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, 0b01111001,
                0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, 0b00000001,
            ],
            expected_null_count: 35,
        }
        .verify();
    }

    /// Randomized comparison of `set_bits` against a bit-by-bit reference implementation.
    #[test]
    fn set_bits_fuzz() {
        let mut rng = StdRng::seed_from_u64(42);
        let mut data = SetBitsTest::new();
        for _ in 0..100 {
            data.regen(&mut rng);
            data.verify();
        }
    }

    /// One `set_bits` scenario: the inputs plus the expected destination buffer and
    /// expected number of zero bits copied.
    #[derive(Debug, Default)]
    struct SetBitsTest {
        /// Destination buffer before the call.
        write_data: Vec<u8>,
        /// Source buffer.
        data: Vec<u8>,
        offset_write: usize,
        offset_read: usize,
        len: usize,
        /// Destination buffer expected after the call.
        expected_data: Vec<u8>,
        /// Expected return value of `set_bits`.
        expected_null_count: usize,
    }

    /// Formats a byte slice as space-separated 8-digit binary numbers.
    struct BinaryFormatter<'a>(&'a [u8]);

    impl Display for BinaryFormatter<'_> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            for byte in self.0 {
                write!(f, "{byte:08b} ")?;
            }
            write!(f, " ")?;
            Ok(())
        }
    }

    impl Display for SetBitsTest {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            writeln!(f, "SetBitsTest {{")?;
            writeln!(f, "  write_data:    {}", BinaryFormatter(&self.write_data))?;
            writeln!(f, "  data:          {}", BinaryFormatter(&self.data))?;
            writeln!(
                f,
                "  expected_data: {}",
                BinaryFormatter(&self.expected_data)
            )?;
            writeln!(f, "  offset_write: {}", self.offset_write)?;
            writeln!(f, "  offset_read: {}", self.offset_read)?;
            writeln!(f, "  len: {}", self.len)?;
            writeln!(f, "  expected_null_count: {}", self.expected_null_count)?;
            writeln!(f, "}}")
        }
    }

    impl SetBitsTest {
        /// Creates an empty scenario; call `regen` to fill it.
        fn new() -> Self {
            Self::default()
        }

        /// Replaces this scenario with a random one and recomputes the expected results
        /// bit by bit with `get_bit`/`set_bit`/`unset_bit`.
        fn regen(&mut self, rng: &mut StdRng) {
            // (read) data
            // ------------------+-----------------+-------
            // .. offset_read .. | data            | ...
            // ------------------+-----------------+-------
            //
            // write data
            // -------------------+-----------------+-------
            // .. offset_write .. | (data to write) | ...
            // -------------------+-----------------+-------

            // Number of bits to copy.
            let len = rng.random_range(0..=200);

            // Randomly pick where to write to.
            let offset_write_bits = rng.random_range(0..=200);
            let offset_write_bytes = if offset_write_bits % 8 == 0 {
                offset_write_bits / 8
            } else {
                (offset_write_bits / 8) + 1
            };
            let extra_write_data_bytes = rng.random_range(0..=5); // make sure 0 shows up often

            // Randomly pick where to read from.
            let extra_read_data_bytes = rng.random_range(0..=5); // make sure 0 shows up often
            let offset_read_bits = rng.random_range(0..=200);
            let offset_read_bytes = if offset_read_bits % 8 != 0 {
                (offset_read_bits / 8) + 1
            } else {
                offset_read_bits / 8
            };

            // `len` is in bits: size the buffers in bytes so that zero slack really means the
            // copied range ends at (or just before) the end of the buffer.
            let len_bytes = if len % 8 == 0 { len / 8 } else { (len / 8) + 1 };

            // Destination: random bytes, so bits around the range must survive untouched,
            // with the range itself cleared because `set_bits` expects it to be zero.
            self.write_data.clear();
            self.write_data
                .resize(offset_write_bytes + len_bytes + extra_write_data_bytes, 0);
            rng.try_fill_bytes(self.write_data.as_mut_slice()).unwrap();
            for i in 0..len {
                unset_bit(&mut self.write_data, offset_write_bits + i);
            }
            self.offset_write = offset_write_bits;

            // Source: random bytes.
            self.data
                .resize(offset_read_bytes + len_bytes + extra_read_data_bytes, 0);
            rng.try_fill_bytes(self.data.as_mut_slice()).unwrap();
            self.offset_read = offset_read_bits;

            self.len = len;

            // Expected output, computed one bit at a time (slow but obviously correct).
            self.expected_data.resize(self.write_data.len(), 0);
            self.expected_data.copy_from_slice(&self.write_data);

            self.expected_null_count = 0;
            for i in 0..self.len {
                let bit = get_bit(&self.data, self.offset_read + i);
                if bit {
                    set_bit(&mut self.expected_data, self.offset_write + i);
                } else {
                    unset_bit(&mut self.expected_data, self.offset_write + i);
                    self.expected_null_count += 1;
                }
            }
        }

        /// Runs `set_bits` on a copy of `write_data` and checks both the buffer and the
        /// returned null count.
        fn verify(&self) {
            // Work on a copy so the scenario itself is not mutated.
            let mut actual = self.write_data.to_vec();

            let null_count = set_bits(
                &mut actual,
                &self.data,
                self.offset_write,
                self.offset_read,
                self.len,
            );

            assert_eq!(actual, self.expected_data, "self: {self}");
            assert_eq!(null_count, self.expected_null_count, "self: {self}");
        }
    }

    #[test]
    fn test_set_upto_64bits() {
        // Byte-aligned source, destination offset by one bit: 63 bits fit in one chunk.
        let write_data: &mut [u8] = &mut [0; 9];
        let data: &[u8] = &[
            0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001,
            0b00000001, 0b00000001,
        ];
        let offset_write = 1;
        let offset_read = 0;
        let len = 65;
        let (n, len_set) =
            unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) };
        assert_eq!(n, 55);
        assert_eq!(len_set, 63);
        assert_eq!(
            write_data,
            &[
                0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010,
                0b00000010, 0b00000000
            ]
        );

        // Single-bit copy.
        let write_data: &mut [u8] = &mut [0b00000000];
        let data: &[u8] = &[0b00000001];
        let offset_write = 1;
        let offset_read = 0;
        let len = 1;
        let (n, len_set) =
            unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) };
        assert_eq!(n, 0);
        assert_eq!(len_set, 1);
        assert_eq!(write_data, &[0b00000010]);
    }
}