1use std::ops::AddAssign;
21use std::sync::Arc;
22
23use arrow_array::builder::BooleanBufferBuilder;
24use arrow_array::cast::AsArray;
25use arrow_array::types::{
26 ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
27};
28use arrow_array::*;
29use arrow_buffer::{bit_util, ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer};
30use arrow_buffer::{Buffer, MutableBuffer};
31use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
32use arrow_data::transform::MutableArrayData;
33use arrow_data::ArrayDataBuilder;
34use arrow_schema::*;
35
/// Selectivity (fraction of filter rows that are selected) above which a filter is
/// applied by copying contiguous runs of selected rows (slices) instead of
/// visiting individual selected indices.
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
43
44#[derive(Debug)]
56pub struct SlicesIterator<'a>(BitSliceIterator<'a>);
57
58impl<'a> SlicesIterator<'a> {
59 pub fn new(filter: &'a BooleanArray) -> Self {
61 Self(filter.values().set_slices())
62 }
63}
64
65impl Iterator for SlicesIterator<'_> {
66 type Item = (usize, usize);
67
68 fn next(&mut self) -> Option<Self::Item> {
69 self.0.next()
70 }
71}
72
/// Iterator over the indices of the set bits of a null-free filter, yielding
/// exactly `remaining` items.
struct IndexIterator<'a> {
    // Number of indices still to be yielded; doubles as the exact size hint.
    remaining: usize,
    iter: BitIndexIterator<'a>,
}

impl<'a> IndexIterator<'a> {
    /// Creates an iterator over the set-bit indices of `filter`, yielding
    /// `remaining` of them. Callers in this file pass the filter's set-bit count.
    ///
    /// # Panics
    ///
    /// Panics if `filter` contains nulls.
    fn new(filter: &'a BooleanArray, remaining: usize) -> Self {
        assert_eq!(filter.null_count(), 0);
        let iter = filter.values().set_indices();
        Self { remaining, iter }
    }
}

impl Iterator for IndexIterator<'_> {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining != 0 {
            // Panic instead of ending early: this iterator is handed to
            // `MutableBuffer::from_trusted_len_iter*`, which relies on the size
            // hint below being exact.
            let next = self.iter.next().expect("IndexIterator exhausted early");
            self.remaining -= 1;
            return Some(next);
        }
        None
    }

    // Exact lower and upper bound, which is what makes this a trusted-length
    // iterator.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}
109
110fn filter_count(filter: &BooleanArray) -> usize {
112 filter.values().count_set_bits()
113}
114
115pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
117 let nulls = filter.nulls().unwrap();
118 let mask = filter.values() & nulls.inner();
119 BooleanArray::new(mask, None)
120}
121
122pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
144 let mut filter_builder = FilterBuilder::new(predicate);
145
146 if multiple_arrays(values.data_type()) {
147 filter_builder = filter_builder.optimize();
150 }
151
152 let predicate = filter_builder.build();
153
154 filter_array(values, &predicate)
155}
156
157fn multiple_arrays(data_type: &DataType) -> bool {
158 match data_type {
159 DataType::Struct(fields) => {
160 fields.len() > 1 || fields.len() == 1 && multiple_arrays(fields[0].data_type())
161 }
162 DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
163 _ => false,
164 }
165}
166
167pub fn filter_record_batch(
172 record_batch: &RecordBatch,
173 predicate: &BooleanArray,
174) -> Result<RecordBatch, ArrowError> {
175 let mut filter_builder = FilterBuilder::new(predicate);
176 if record_batch.num_columns() > 1 {
177 filter_builder = filter_builder.optimize();
180 }
181 let filter = filter_builder.build();
182
183 let filtered_arrays = record_batch
184 .columns()
185 .iter()
186 .map(|a| filter_array(a, &filter))
187 .collect::<Result<Vec<_>, _>>()?;
188 let options = RecordBatchOptions::default().with_row_count(Some(filter.count()));
189 RecordBatch::try_new_with_options(record_batch.schema(), filtered_arrays, &options)
190}
191
192#[derive(Debug)]
194pub struct FilterBuilder {
195 filter: BooleanArray,
196 count: usize,
197 strategy: IterationStrategy,
198}
199
200impl FilterBuilder {
201 pub fn new(filter: &BooleanArray) -> Self {
203 let filter = match filter.null_count() {
204 0 => filter.clone(),
205 _ => prep_null_mask_filter(filter),
206 };
207
208 let count = filter_count(&filter);
209 let strategy = IterationStrategy::default_strategy(filter.len(), count);
210
211 Self {
212 filter,
213 count,
214 strategy,
215 }
216 }
217
218 pub fn optimize(mut self) -> Self {
224 match self.strategy {
225 IterationStrategy::SlicesIterator => {
226 let slices = SlicesIterator::new(&self.filter).collect();
227 self.strategy = IterationStrategy::Slices(slices)
228 }
229 IterationStrategy::IndexIterator => {
230 let indices = IndexIterator::new(&self.filter, self.count).collect();
231 self.strategy = IterationStrategy::Indices(indices)
232 }
233 _ => {}
234 }
235 self
236 }
237
238 pub fn build(self) -> FilterPredicate {
240 FilterPredicate {
241 filter: self.filter,
242 count: self.count,
243 strategy: self.strategy,
244 }
245 }
246}
247
248#[derive(Debug)]
250enum IterationStrategy {
251 SlicesIterator,
253 IndexIterator,
255 Indices(Vec<usize>),
257 Slices(Vec<(usize, usize)>),
259 All,
261 None,
263}
264
265impl IterationStrategy {
266 fn default_strategy(filter_length: usize, filter_count: usize) -> Self {
269 if filter_length == 0 || filter_count == 0 {
270 return IterationStrategy::None;
271 }
272
273 if filter_count == filter_length {
274 return IterationStrategy::All;
275 }
276
277 let selectivity_frac = filter_count as f64 / filter_length as f64;
282 if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD {
283 return IterationStrategy::SlicesIterator;
284 }
285 IterationStrategy::IndexIterator
286 }
287}
288
289#[derive(Debug)]
291pub struct FilterPredicate {
292 filter: BooleanArray,
293 count: usize,
294 strategy: IterationStrategy,
295}
296
297impl FilterPredicate {
298 pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef, ArrowError> {
300 filter_array(values, self)
301 }
302
303 pub fn count(&self) -> usize {
305 self.count
306 }
307}
308
/// Applies a prepared [`FilterPredicate`] to `values`, dispatching on its data type.
///
/// The `None` and `All` strategies are answered without touching the data: an
/// empty array, or a slice of the first `count` rows. Otherwise the matching typed
/// kernel is used. Types without a dedicated kernel fall back to
/// [`MutableArrayData`].
///
/// # Errors
///
/// Returns [`ArrowError::InvalidArgumentError`] if the predicate is longer than
/// `values`, and propagates errors from the run-end, struct and sparse-union
/// kernels.
///
/// # Panics
///
/// The run-end and dictionary arms call `unimplemented!` for index/key types their
/// downcast macros do not handle.
fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> {
    if predicate.filter.len() > values.len() {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Filter predicate of length {} is larger than target array of length {}",
            predicate.filter.len(),
            values.len()
        )));
    }

    match predicate.strategy {
        IterationStrategy::None => Ok(new_empty_array(values.data_type())),
        // Every row is selected: take a slice of the first `count` rows.
        IterationStrategy::All => Ok(values.slice(0, predicate.count)),
        // The first arm handles primitive arrays; the remaining arms match on
        // `DataType` for everything else.
        _ => downcast_primitive_array! {
            values => Ok(Arc::new(filter_primitive(values, predicate))),
            DataType::Boolean => {
                let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
                Ok(Arc::new(filter_boolean(values, predicate)))
            }
            DataType::Utf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i32>(), predicate)))
            }
            DataType::LargeUtf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
            }
            DataType::Utf8View => {
                Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
            }
            DataType::Binary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
            }
            DataType::LargeBinary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
            }
            DataType::BinaryView => {
                Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
            }
            DataType::FixedSizeBinary(_) => {
                Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate)))
            }
            DataType::RunEndEncoded(_, _) => {
                downcast_run_array!{
                    values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
                    t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t)
                }
            }
            DataType::Dictionary(_, _) => downcast_dictionary_array! {
                values => Ok(Arc::new(filter_dict(values, predicate))),
                t => unimplemented!("Filter not supported for dictionary type {:?}", t)
            }
            DataType::Struct(_) => {
                Ok(Arc::new(filter_struct(values.as_struct(), predicate)?))
            }
            DataType::Union(_, UnionMode::Sparse) => {
                Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?))
            }
            _ => {
                // Generic fallback: copy each run of selected rows via MutableArrayData.
                let data = values.to_data();
                let mut mutable = MutableArrayData::new(
                    vec![&data],
                    false,
                    predicate.count,
                );

                match &predicate.strategy {
                    IterationStrategy::Slices(slices) => {
                        slices
                            .iter()
                            .for_each(|(start, end)| mutable.extend(0, *start, *end));
                    }
                    // Index-based strategies are served by walking slices here too.
                    _ => {
                        let iter = SlicesIterator::new(&predicate.filter);
                        iter.for_each(|(start, end)| mutable.extend(0, start, end));
                    }
                }

                let data = mutable.freeze();
                Ok(make_array(data))
            }
        },
    }
}
392
393fn filter_run_end_array<R: RunEndIndexType>(
395 array: &RunArray<R>,
396 predicate: &FilterPredicate,
397) -> Result<RunArray<R>, ArrowError>
398where
399 R::Native: Into<i64> + From<bool>,
400 R::Native: AddAssign,
401{
402 let run_ends: &RunEndBuffer<R::Native> = array.run_ends();
403 let mut new_run_ends = vec![R::default_value(); run_ends.len()];
404
405 let mut start = 0u64;
406 let mut j = 0;
407 let mut count = R::default_value();
408 let filter_values = predicate.filter.values();
409 let run_ends = run_ends.inner();
410
411 let pred: BooleanArray = BooleanBuffer::collect_bool(run_ends.len(), |i| {
412 let mut keep = false;
413 let mut end = run_ends[i].into() as u64;
414 let difference = end.saturating_sub(filter_values.len() as u64);
415 end -= difference;
416
417 for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
419 count += R::Native::from(pred);
420 keep |= pred
421 }
422 new_run_ends[j] = count;
424 j += keep as usize;
425
426 start = end;
427 keep
428 })
429 .into();
430
431 new_run_ends.truncate(j);
432
433 let values = array.values();
434 let values = filter(&values, &pred)?;
435
436 let run_ends = PrimitiveArray::<R>::new(new_run_ends.into(), None);
437 RunArray::try_new(&run_ends, &values)
438}
439
440fn filter_null_mask(
447 nulls: Option<&NullBuffer>,
448 predicate: &FilterPredicate,
449) -> Option<(usize, Buffer)> {
450 let nulls = nulls?;
451 if nulls.null_count() == 0 {
452 return None;
453 }
454
455 let nulls = filter_bits(nulls.inner(), predicate);
456 let null_count = predicate.count - nulls.count_set_bits_offset(0, predicate.count);
459
460 if null_count == 0 {
461 return None;
462 }
463
464 Some((null_count, nulls))
465}
466
/// Gathers the bits of `buffer` (honouring its bit offset) selected by `predicate`.
///
/// Returns a packed buffer of `predicate.count` bits starting at bit offset 0.
///
/// # Panics
///
/// Panics via `unreachable!` for the `All` and `None` strategies. `filter_array`
/// resolves those before any kernel runs.
fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer {
    let src = buffer.values();
    let offset = buffer.offset();

    match &predicate.strategy {
        IterationStrategy::IndexIterator => {
            let bits = IndexIterator::new(&predicate.filter, predicate.count)
                .map(|src_idx| bit_util::get_bit(src, src_idx + offset));

            // SAFETY: `IndexIterator` reports an exact size hint and panics rather
            // than ending early, so the trusted-length contract holds.
            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
        }
        IterationStrategy::Indices(indices) => {
            let bits = indices
                .iter()
                .map(|src_idx| bit_util::get_bit(src, *src_idx + offset));

            // SAFETY: a mapped slice iterator has an exact length.
            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
        }
        IterationStrategy::SlicesIterator => {
            let mut builder = BooleanBufferBuilder::new(predicate.count);
            for (start, end) in SlicesIterator::new(&predicate.filter) {
                builder.append_packed_range(start + offset..end + offset, src)
            }
            builder.into()
        }
        IterationStrategy::Slices(slices) => {
            let mut builder = BooleanBufferBuilder::new(predicate.count);
            for (start, end) in slices {
                builder.append_packed_range(*start + offset..*end + offset, src)
            }
            builder.into()
        }
        IterationStrategy::All | IterationStrategy::None => unreachable!(),
    }
}
505
506fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
508 let values = filter_bits(array.values(), predicate);
509
510 let mut builder = ArrayDataBuilder::new(DataType::Boolean)
511 .len(predicate.count)
512 .add_buffer(values);
513
514 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
515 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
516 }
517
518 let data = unsafe { builder.build_unchecked() };
519 BooleanArray::from(data)
520}
521
/// Gathers the elements of `values` selected by `predicate` into a new buffer of
/// `predicate.count` elements.
///
/// # Panics
///
/// Panics if `values` is shorter than the filter. Also panics via `unreachable!`
/// for the `All` and `None` strategies, which `filter_array` resolves beforehand.
#[inline(never)]
fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer {
    assert!(values.len() >= predicate.filter.len());

    match &predicate.strategy {
        IterationStrategy::SlicesIterator => {
            let mut buffer = Vec::with_capacity(predicate.count);
            for (start, end) in SlicesIterator::new(&predicate.filter) {
                buffer.extend_from_slice(&values[start..end]);
            }
            buffer.into()
        }
        IterationStrategy::Slices(slices) => {
            let mut buffer = Vec::with_capacity(predicate.count);
            for (start, end) in slices {
                buffer.extend_from_slice(&values[*start..*end]);
            }
            buffer.into()
        }
        IterationStrategy::IndexIterator => {
            let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| values[x]);

            // SAFETY: `IndexIterator` reports an exact size hint and panics instead
            // of ending early, satisfying `from_trusted_len_iter`'s contract.
            unsafe { MutableBuffer::from_trusted_len_iter(iter) }.into()
        }
        IterationStrategy::Indices(indices) => {
            let iter = indices.iter().map(|x| values[*x]);
            iter.collect::<Vec<_>>().into()
        }
        IterationStrategy::All | IterationStrategy::None => unreachable!(),
    }
}
554
555fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
557where
558 T: ArrowPrimitiveType,
559{
560 let values = array.values();
561 let buffer = filter_native(values, predicate);
562 let mut builder = ArrayDataBuilder::new(array.data_type().clone())
563 .len(predicate.count)
564 .add_buffer(buffer);
565
566 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
567 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
568 }
569
570 let data = unsafe { builder.build_unchecked() };
571 PrimitiveArray::from(data)
572}
573
/// Accumulator for filtering variable-length byte arrays (`Utf8`, `Binary` and
/// their large variants). It builds the destination offsets and value bytes
/// separately.
struct FilterBytes<'a, OffsetSize> {
    // Source offsets; row `i` spans `src_offsets[i]..src_offsets[i + 1]`.
    src_offsets: &'a [OffsetSize],
    // Source value bytes that the offsets index into.
    src_values: &'a [u8],
    // Destination offsets; starts with a single zero.
    dst_offsets: Vec<OffsetSize>,
    // Destination value bytes.
    dst_values: Vec<u8>,
    // End offset of the last row appended to `dst_offsets`.
    cur_offset: OffsetSize,
}

impl<'a, OffsetSize> FilterBytes<'a, OffsetSize>
where
    OffsetSize: OffsetSizeTrait,
{
    /// Creates an accumulator for `capacity` output rows of `array`.
    fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self
    where
        T: ByteArrayType<Offset = OffsetSize>,
    {
        let dst_values = Vec::new();
        // One extra slot for the leading zero offset.
        let mut dst_offsets: Vec<OffsetSize> = Vec::with_capacity(capacity + 1);
        let cur_offset = OffsetSize::from_usize(0).unwrap();

        dst_offsets.push(cur_offset);

        Self {
            src_offsets: array.value_offsets(),
            src_values: array.value_data(),
            dst_offsets,
            dst_values,
            cur_offset,
        }
    }

    /// Returns the source byte offset stored at position `idx`.
    #[inline]
    fn get_value_offset(&self, idx: usize) -> usize {
        self.src_offsets[idx].as_usize()
    }

    /// Returns `(start, end, len)` of the source bytes of row `idx`.
    ///
    /// Panics with "illegal offset range" if `len` does not fit in `OffsetSize`.
    #[inline]
    fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
        let start = self.get_value_offset(idx);
        let end = self.get_value_offset(idx + 1);
        let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
        (start, end, len)
    }

    /// Appends one end offset per row index yielded by `iter`.
    fn extend_offsets_idx(&mut self, iter: impl Iterator<Item = usize>) {
        // Fields are read directly (not via `get_value_range`) so the closure borrows
        // only `src_offsets` and `cur_offset` while `dst_offsets` is borrowed mutably.
        self.dst_offsets.extend(iter.map(|idx| {
            let start = self.src_offsets[idx].as_usize();
            let end = self.src_offsets[idx + 1].as_usize();
            let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
            self.cur_offset += len;

            self.cur_offset
        }));
    }

    /// Copies the bytes of each row index yielded by `iter`.
    ///
    /// Intended to run after the offsets were extended, so that `cur_offset` holds
    /// the total output byte length.
    fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
        // `dst_values` is still empty here, so this reserves exactly the total size.
        self.dst_values.reserve_exact(self.cur_offset.as_usize());

        for idx in iter {
            let start = self.src_offsets[idx].as_usize();
            let end = self.src_offsets[idx + 1].as_usize();
            self.dst_values
                .extend_from_slice(&self.src_values[start..end]);
        }
    }

    /// Appends one end offset per row of each half-open `(start, end)` slice.
    fn extend_offsets_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>, count: usize) {
        self.dst_offsets.reserve_exact(count);
        for (start, end) in iter {
            for idx in start..end {
                let (_, _, len) = self.get_value_range(idx);
                self.cur_offset += len;
                self.dst_offsets.push(self.cur_offset);
            }
        }
    }

    /// Copies each slice's bytes as a single contiguous range.
    ///
    /// Intended to run after the offsets were extended, so that `cur_offset` holds
    /// the total output byte length.
    fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
        // `dst_values` is still empty here, so this reserves exactly the total size.
        self.dst_values.reserve_exact(self.cur_offset.as_usize());

        for (start, end) in iter {
            let value_start = self.get_value_offset(start);
            let value_end = self.get_value_offset(end);
            self.dst_values
                .extend_from_slice(&self.src_values[value_start..value_end]);
        }
    }
}
672
673fn filter_bytes<T>(array: &GenericByteArray<T>, predicate: &FilterPredicate) -> GenericByteArray<T>
678where
679 T: ByteArrayType,
680{
681 let mut filter = FilterBytes::new(predicate.count, array);
682
683 match &predicate.strategy {
684 IterationStrategy::SlicesIterator => {
685 filter.extend_offsets_slices(SlicesIterator::new(&predicate.filter), predicate.count);
686 filter.extend_slices(SlicesIterator::new(&predicate.filter))
687 }
688 IterationStrategy::Slices(slices) => {
689 filter.extend_offsets_slices(slices.iter().cloned(), predicate.count);
690 filter.extend_slices(slices.iter().cloned())
691 }
692 IterationStrategy::IndexIterator => {
693 filter.extend_offsets_idx(IndexIterator::new(&predicate.filter, predicate.count));
694 filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count))
695 }
696 IterationStrategy::Indices(indices) => {
697 filter.extend_offsets_idx(indices.iter().cloned());
698 filter.extend_idx(indices.iter().cloned())
699 }
700 IterationStrategy::All | IterationStrategy::None => unreachable!(),
701 }
702
703 let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
704 .len(predicate.count)
705 .add_buffer(filter.dst_offsets.into())
706 .add_buffer(filter.dst_values.into());
707
708 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
709 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
710 }
711
712 let data = unsafe { builder.build_unchecked() };
713 GenericByteArray::from(data)
714}
715
716fn filter_byte_view<T: ByteViewType>(
718 array: &GenericByteViewArray<T>,
719 predicate: &FilterPredicate,
720) -> GenericByteViewArray<T> {
721 let new_view_buffer = filter_native(array.views(), predicate);
722
723 let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
724 .len(predicate.count)
725 .add_buffer(new_view_buffer)
726 .add_buffers(array.data_buffers().to_vec());
727
728 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
729 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
730 }
731
732 GenericByteViewArray::from(unsafe { builder.build_unchecked() })
733}
734
735fn filter_fixed_size_binary(
736 array: &FixedSizeBinaryArray,
737 predicate: &FilterPredicate,
738) -> FixedSizeBinaryArray {
739 let values: &[u8] = array.values();
740 let value_length = array.value_length() as usize;
741 let calculate_offset_from_index = |index: usize| index * value_length;
742 let buffer = match &predicate.strategy {
743 IterationStrategy::SlicesIterator => {
744 let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
745 for (start, end) in SlicesIterator::new(&predicate.filter) {
746 buffer.extend_from_slice(
747 &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)],
748 );
749 }
750 buffer
751 }
752 IterationStrategy::Slices(slices) => {
753 let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
754 for (start, end) in slices {
755 buffer.extend_from_slice(
756 &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)],
757 );
758 }
759 buffer
760 }
761 IterationStrategy::IndexIterator => {
762 let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| {
763 &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)]
764 });
765
766 let mut buffer = MutableBuffer::new(predicate.count * value_length);
767 iter.for_each(|item| buffer.extend_from_slice(item));
768 buffer
769 }
770 IterationStrategy::Indices(indices) => {
771 let iter = indices.iter().map(|x| {
772 &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)]
773 });
774
775 let mut buffer = MutableBuffer::new(predicate.count * value_length);
776 iter.for_each(|item| buffer.extend_from_slice(item));
777 buffer
778 }
779 IterationStrategy::All | IterationStrategy::None => unreachable!(),
780 };
781 let mut builder = ArrayDataBuilder::new(array.data_type().clone())
782 .len(predicate.count)
783 .add_buffer(buffer.into());
784
785 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
786 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
787 }
788
789 let data = unsafe { builder.build_unchecked() };
790 FixedSizeBinaryArray::from(data)
791}
792
793fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
795where
796 T: ArrowDictionaryKeyType,
797 T::Native: num::Num,
798{
799 let builder = filter_primitive::<T>(array.keys(), predicate)
800 .into_data()
801 .into_builder()
802 .data_type(array.data_type().clone())
803 .child_data(vec![array.values().to_data()]);
804
805 DictionaryArray::from(unsafe { builder.build_unchecked() })
808}
809
810fn filter_struct(
812 array: &StructArray,
813 predicate: &FilterPredicate,
814) -> Result<StructArray, ArrowError> {
815 let columns = array
816 .columns()
817 .iter()
818 .map(|column| filter_array(column, predicate))
819 .collect::<Result<_, _>>()?;
820
821 let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
822 let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
823
824 Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
825 } else {
826 None
827 };
828
829 Ok(unsafe {
830 StructArray::new_unchecked_with_length(
831 array.fields().clone(),
832 columns,
833 nulls,
834 predicate.count(),
835 )
836 })
837}
838
839fn filter_sparse_union(
841 array: &UnionArray,
842 predicate: &FilterPredicate,
843) -> Result<UnionArray, ArrowError> {
844 let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
845 unreachable!()
846 };
847
848 let type_ids = filter_primitive(&Int8Array::new(array.type_ids().clone(), None), predicate);
849
850 let children = fields
851 .iter()
852 .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
853 .collect::<Result<_, _>>()?;
854
855 Ok(unsafe {
856 UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
857 })
858}
859
860#[cfg(test)]
861mod tests {
862 use super::*;
863 use arrow_array::builder::*;
864 use arrow_array::cast::as_run_array;
865 use arrow_array::types::*;
866 use arrow_data::ArrayData;
867 use rand::distr::uniform::{UniformSampler, UniformUsize};
868 use rand::distr::{Alphanumeric, StandardUniform};
869 use rand::prelude::*;
870 use rand::rng;
871
    // Generates a test that filters `$data` with the mask
    // `[true, false, true, false]` and expects the values `[1, 3]` to remain.
    // Every invocation below passes an array holding `[1, 2, 3, 4]`.
    macro_rules! def_temporal_test {
        ($test:ident, $array_type: ident, $data: expr) => {
            #[test]
            fn $test() {
                let a = $data;
                let b = BooleanArray::from(vec![true, false, true, false]);
                let c = filter(&a, &b).unwrap();
                let d = c.as_ref().as_any().downcast_ref::<$array_type>().unwrap();
                assert_eq!(2, d.len());
                assert_eq!(1, d.value(0));
                assert_eq!(3, d.value(1));
            }
        };
    }

    def_temporal_test!(
        test_filter_date32,
        Date32Array,
        Date32Array::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_date64,
        Date64Array,
        Date64Array::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time32_second,
        Time32SecondArray,
        Time32SecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time32_millisecond,
        Time32MillisecondArray,
        Time32MillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time64_microsecond,
        Time64MicrosecondArray,
        Time64MicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time64_nanosecond,
        Time64NanosecondArray,
        Time64NanosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_second,
        DurationSecondArray,
        DurationSecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_millisecond,
        DurationMillisecondArray,
        DurationMillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_microsecond,
        DurationMicrosecondArray,
        DurationMicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_nanosecond,
        DurationNanosecondArray,
        DurationNanosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_second,
        TimestampSecondArray,
        TimestampSecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_millisecond,
        TimestampMillisecondArray,
        TimestampMillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_microsecond,
        TimestampMicrosecondArray,
        TimestampMicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_nanosecond,
        TimestampNanosecondArray,
        TimestampNanosecondArray::from(vec![1, 2, 3, 4])
    );
957
    // Filters a sliced array: [5, 6, 7, 8, 9] at offset 1 is [6, 7, 8, 9], and the
    // mask keeps its first and last rows.
    #[test]
    fn test_filter_array_slice() {
        let a = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
        let b = BooleanArray::from(vec![true, false, false, true]);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!(6, d.value(0));
        assert_eq!(9, d.value(1));
    }
971
    // Low-selectivity filter: of 67 rows only 65 and 67 are selected, which is
    // below the slices threshold, so index-based iteration is used.
    #[test]
    fn test_filter_array_low_density() {
        let mut data_values = (1..=65).collect::<Vec<i32>>();
        let mut filter_values = (1..=65).map(|i| matches!(i % 65, 0)).collect::<Vec<bool>>();
        data_values.extend_from_slice(&[66, 67]);
        filter_values.extend_from_slice(&[false, true]);
        let a = Int32Array::from(data_values);
        let b = BooleanArray::from(filter_values);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!(65, d.value(0));
        assert_eq!(67, d.value(1));
    }
988
    // High-selectivity filter: 67 of 69 rows are selected (above the slices
    // threshold). Nulls sit both in the selected prefix and in the tail.
    #[test]
    fn test_filter_array_high_density() {
        let mut data_values = (1..=65).map(Some).collect::<Vec<_>>();
        let mut filter_values = (1..=65)
            .map(|i| !matches!(i % 65, 0))
            .collect::<Vec<bool>>();
        data_values[1] = None;
        data_values.extend_from_slice(&[Some(66), None, Some(67), None]);
        filter_values.extend_from_slice(&[false, true, true, true]);
        let a = Int32Array::from(data_values);
        let b = BooleanArray::from(filter_values);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(67, d.len());
        assert_eq!(3, d.null_count());
        assert_eq!(1, d.value(0));
        assert!(d.is_null(1));
        assert_eq!(64, d.value(63));
        assert!(d.is_null(64));
        assert_eq!(67, d.value(65));
    }
1013
    // Basic Utf8 filtering without nulls.
    #[test]
    fn test_filter_string_array_simple() {
        let a = StringArray::from(vec!["hello", " ", "world", "!"]);
        let b = BooleanArray::from(vec![true, false, true, false]);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!("hello", d.value(0));
        assert_eq!("world", d.value(1));
    }
1024
    // Selecting only a null row yields a single null.
    #[test]
    fn test_filter_primitive_array_with_null() {
        let a = Int32Array::from(vec![Some(5), None]);
        let b = BooleanArray::from(vec![false, true]);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(1, d.len());
        assert!(d.is_null(0));
    }
1034
    // Utf8 filtering preserves validity of the selected rows.
    #[test]
    fn test_filter_string_array_with_null() {
        let a = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
        let b = BooleanArray::from(vec![true, false, false, true]);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!("hello", d.value(0));
        assert!(!d.is_null(0));
        assert!(d.is_null(1));
    }
1046
    // Binary filtering preserves validity of the selected rows.
    #[test]
    fn test_filter_binary_array_with_null() {
        let data: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
        let a = BinaryArray::from(data);
        let b = BooleanArray::from(vec![true, false, false, true]);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<BinaryArray>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!(b"hello", d.value(0));
        assert!(!d.is_null(0));
        assert!(d.is_null(1));
    }
1059
    // Shared body for the string-view and binary-view tests. The input mixes short
    // values, a null, and a value longer than 12 bytes (presumably to exercise
    // out-of-line view storage).
    fn _test_filter_byte_view<T>()
    where
        T: ByteViewType,
        str: AsRef<T::Native>,
        T::Native: PartialEq,
    {
        let array = {
            let mut builder = GenericByteViewBuilder::<T>::new();
            builder.append_value("hello");
            builder.append_value("world");
            builder.append_null();
            builder.append_value("large payload over 12 bytes");
            builder.append_value("lulu");
            builder.finish()
        };

        // Keeps a short value, the null, and the long value.
        {
            let predicate = BooleanArray::from(vec![true, false, true, true, false]);
            let actual = filter(&array, &predicate).unwrap();

            assert_eq!(actual.len(), 3);

            let expected = {
                let mut builder = GenericByteViewBuilder::<T>::new();
                builder.append_value("hello");
                builder.append_null();
                builder.append_value("large payload over 12 bytes");
                builder.finish()
            };

            assert_eq!(actual.as_ref(), &expected);
        }

        // Keeps only the first and last (short) values.
        {
            let predicate = BooleanArray::from(vec![true, false, false, false, true]);
            let actual = filter(&array, &predicate).unwrap();

            assert_eq!(actual.len(), 2);

            let expected = {
                let mut builder = GenericByteViewBuilder::<T>::new();
                builder.append_value("hello");
                builder.append_value("lulu");
                builder.finish()
            };

            assert_eq!(actual.as_ref(), &expected);
        }
    }
1112
    // Runs the shared byte-view filter test for `Utf8View`.
    #[test]
    fn test_filter_string_view() {
        _test_filter_byte_view::<StringViewType>()
    }
1117
    // Runs the shared byte-view filter test for `BinaryView`.
    #[test]
    fn test_filter_binary_view() {
        _test_filter_byte_view::<BinaryViewType>()
    }
1122
    // FixedSizeBinary filtering across partial, empty and full selections. It also
    // checks that the optimized (materialized) predicate gives identical results.
    #[test]
    fn test_filter_fixed_binary() {
        let v1 = [1_u8, 2];
        let v2 = [3_u8, 4];
        let v3 = [5_u8, 6];
        let v = vec![&v1, &v2, &v3];
        let a = FixedSizeBinaryArray::from(v);
        let b = BooleanArray::from(vec![true, false, true]);
        let c = filter(&a, &b).unwrap();
        let d = c
            .as_ref()
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        assert_eq!(d.len(), 2);
        assert_eq!(d.value(0), &v1);
        assert_eq!(d.value(1), &v3);
        let c2 = FilterBuilder::new(&b)
            .optimize()
            .build()
            .filter(&a)
            .unwrap();
        let d2 = c2
            .as_ref()
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        assert_eq!(d, d2);

        // Nothing selected.
        let b = BooleanArray::from(vec![false, false, false]);
        let c = filter(&a, &b).unwrap();
        let d = c
            .as_ref()
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        assert_eq!(d.len(), 0);

        // Everything selected.
        let b = BooleanArray::from(vec![true, true, true]);
        let c = filter(&a, &b).unwrap();
        let d = c
            .as_ref()
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        assert_eq!(d.len(), 3);
        assert_eq!(d.value(0), &v1);
        assert_eq!(d.value(1), &v2);
        assert_eq!(d.value(2), &v3);

        // Only the last row selected, plain and optimized.
        let b = BooleanArray::from(vec![false, false, true]);
        let c = filter(&a, &b).unwrap();
        let d = c
            .as_ref()
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        assert_eq!(d.len(), 1);
        assert_eq!(d.value(0), &v3);
        let c2 = FilterBuilder::new(&b)
            .optimize()
            .build()
            .filter(&a)
            .unwrap();
        let d2 = c2
            .as_ref()
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        assert_eq!(d, d2);
    }
1194
    // Filters a sliced array whose first logical row is null, checking that the
    // validity bits honour the slice offset.
    #[test]
    fn test_filter_array_slice_with_null() {
        let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
        let b = BooleanArray::from(vec![true, false, false, true]);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(2, d.len());
        assert!(d.is_null(0));
        assert!(!d.is_null(1));
        assert_eq!(9, d.value(1));
    }
1209
/// Filters an Int64 run-end encoded array, keeping every other logical row.
#[test]
fn test_filter_run_end_encoding_array() {
    // Logical contents: [7, 7, -2, 9, 9, 9, 9, 9]
    let input = RunArray::try_new(
        &Int64Array::from(vec![2, 3, 8]),
        &Int64Array::from(vec![7, -2, 9]),
    )
    .expect("Failed to create RunArray");
    let every_other = BooleanArray::from(vec![true, false, true, false, true, false, true, false]);

    let filtered = filter(&input, &every_other).unwrap();
    let actual: &RunArray<Int64Type> = as_run_array(&filtered);
    assert_eq!(4, actual.len());

    // Surviving rows [7, -2, 9, 9] collapse into three runs.
    let expected = RunArray::try_new(
        &Int64Array::from(vec![1, 2, 4]),
        &Int64Array::from(vec![7, -2, 9]),
    )
    .expect("Failed to make expected RunArray test is broken");

    assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1229
/// Filtering an Int32 run-end encoded array can drop whole runs (here the `-2` and `-8` runs).
#[test]
fn test_filter_run_end_encoding_array_remove_value() {
    // Logical contents: [7, 7, -2, 9, 9, 9, 9, 9, -8, -8]
    let input = RunArray::try_new(
        &Int32Array::from(vec![2, 3, 8, 10]),
        &Int32Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    let predicate = BooleanArray::from(vec![
        false, true, false, false, true, false, true, false, false, false,
    ]);

    let filtered = filter(&input, &predicate).unwrap();
    let actual: &RunArray<Int32Type> = as_run_array(&filtered);
    assert_eq!(3, actual.len());

    // Surviving rows [7, 9, 9] form two runs.
    let expected =
        RunArray::try_new(&Int32Array::from(vec![1, 3]), &Int32Array::from(vec![7, 9]))
            .expect("Failed to make expected RunArray test is broken");

    assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1249
/// Selecting a single logical row of an Int16 run-end encoded array yields a single run.
#[test]
fn test_filter_run_end_encoding_array_remove_all_but_one() {
    // Logical contents: [7, 7, -2, 9, 9, 9, 9, 9, -8, -8]
    let input = RunArray::try_new(
        &Int16Array::from(vec![2, 3, 8, 10]),
        &Int16Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    let only_row_six = BooleanArray::from(vec![
        false, false, false, false, false, false, true, false, false, false,
    ]);

    let filtered = filter(&input, &only_row_six).unwrap();
    let actual: &RunArray<Int16Type> = as_run_array(&filtered);
    assert_eq!(1, actual.len());

    let expected = RunArray::try_new(&Int16Array::from(vec![1]), &Int16Array::from(vec![9]))
        .expect("Failed to make expected RunArray test is broken");

    assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1268
/// An all-false predicate over a run-end encoded array produces an empty run array.
#[test]
fn test_filter_run_end_encoding_array_empty() {
    let input = RunArray::try_new(
        &Int64Array::from(vec![2, 3, 8, 10]),
        &Int64Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    let keep_none = BooleanArray::from(vec![false; 10]);

    let filtered = filter(&input, &keep_none).unwrap();
    let actual: &RunArray<Int64Type> = as_run_array(&filtered);
    assert_eq!(0, actual.len());
}
1281
/// The predicate may be shorter than the run-end encoded array it filters.
#[test]
fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() {
    // Ten logical rows: [7, 7, -2, 9, 9, 9, 9, 9, -8, -8]; the predicate only covers three.
    let input = RunArray::try_new(
        &Int64Array::from(vec![2, 3, 8, 10]),
        &Int64Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    let short_predicate = BooleanArray::from(vec![false, true, true]);

    let filtered = filter(&input, &short_predicate).unwrap();
    let actual: &RunArray<Int64Type> = as_run_array(&filtered);
    assert_eq!(2, actual.len());

    // Rows 1 and 2 survive: [7, -2].
    let expected = RunArray::try_new(
        &Int64Array::from(vec![1, 2]),
        &Int64Array::from(vec![7, -2]),
    )
    .expect("Failed to make expected RunArray test is broken");

    assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1301
/// Filtering a dictionary array filters the keys; the dictionary values are left intact.
#[test]
fn test_filter_dictionary_array() {
    let a: Int8DictionaryArray = [Some("hello"), None, Some("world"), Some("!")]
        .into_iter()
        .collect();
    let b = BooleanArray::from(vec![false, true, true, false]);

    let filtered = filter(&a, &b).unwrap();
    let dict = filtered
        .as_any()
        .downcast_ref::<Int8DictionaryArray>()
        .unwrap();
    let dict_values = dict.values().as_any().downcast_ref::<StringArray>().unwrap();

    // All three distinct values remain in the dictionary even though only two keys survive.
    assert_eq!(3, dict_values.len());
    assert_eq!(2, dict.len());
    assert!(dict.is_null(0));
    assert_eq!("world", dict_values.value(dict.keys().value(1) as usize));
}
1322
/// Filters a nullable `LargeList<Int32>` array and compares it with a hand-built expectation.
#[test]
fn test_filter_list_array() {
    let item_field = || Arc::new(Field::new_list_field(DataType::Int32, false));

    // Input lists: [0, 1, 2], [3, 4, 5], [6, 7], null (validity 0b0111 leaves slot 3 null).
    let input_values = ArrayData::builder(DataType::Int32)
        .len(8)
        .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
        .build()
        .unwrap();
    let input = LargeListArray::from(
        ArrayData::builder(DataType::LargeList(item_field()))
            .len(4)
            .add_buffer(Buffer::from_slice_ref([0i64, 3, 6, 8, 8]))
            .add_child_data(input_values)
            .null_bit_buffer(Some(Buffer::from([0b00000111])))
            .build()
            .unwrap(),
    );

    let predicate = BooleanArray::from(vec![false, true, false, true]);
    let result = filter(&input, &predicate).unwrap();

    // Rows 1 and 3 survive: [3, 4, 5] and the null slot.
    let expected_values = ArrayData::builder(DataType::Int32)
        .len(3)
        .add_buffer(Buffer::from_slice_ref([3, 4, 5]))
        .build()
        .unwrap();
    let expected = ArrayData::builder(DataType::LargeList(item_field()))
        .len(2)
        .add_buffer(Buffer::from_slice_ref([0i64, 3, 3]))
        .add_child_data(expected_values)
        .null_bit_buffer(Some(Buffer::from([0b00000001])))
        .build()
        .unwrap();

    assert_eq!(&make_array(expected), &result);
}
1369
/// A single set bit yields exactly one one-element slice.
#[test]
fn test_slice_iterator_bits() {
    let mask = BooleanArray::from((0..64).map(|i| i == 1).collect::<Vec<bool>>());
    let set_bits = filter_count(&mask);

    let chunks: Vec<_> = SlicesIterator::new(&mask).collect();

    assert_eq!(chunks, vec![(1, 2)]);
    assert_eq!(set_bits, 1);
}
1382
/// A single unset bit splits an otherwise full 64-bit mask into two slices.
#[test]
fn test_slice_iterator_bits1() {
    let mask = BooleanArray::from((0..64).map(|i| i != 1).collect::<Vec<bool>>());
    let set_bits = filter_count(&mask);

    let chunks: Vec<_> = SlicesIterator::new(&mask).collect();

    assert_eq!(chunks, vec![(0, 1), (2, 64)]);
    assert_eq!(set_bits, 64 - 1);
}
1395
/// Set-bit runs longer than 64 bits (e.g. `(63, 124)`, which crosses bit 64) are reported as
/// single slices.
#[test]
fn test_slice_iterator_chunk_and_bits() {
    // Bits 0, 62 and 124 are unset; every other bit in 0..130 is set.
    let mask = BooleanArray::from((0..130).map(|i| i % 62 != 0).collect::<Vec<bool>>());
    let set_bits = filter_count(&mask);

    let chunks: Vec<_> = SlicesIterator::new(&mask).collect();

    assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]);
    assert_eq!(set_bits, 61 + 61 + 5);
}
1408
/// A null slot in the predicate drops the corresponding row, like `false`.
#[test]
fn test_null_mask() {
    let a = Int64Array::from(vec![Some(1), Some(2), None]);
    let mask_with_null = BooleanArray::from(vec![Some(true), Some(true), None]);

    let out = filter(&a, &mask_with_null).unwrap();
    let expected = a.slice(0, 2);
    assert_eq!(out.as_ref(), &expected);
}
1417
/// Filtering a batch with no columns yields a row count equal to the number of `true` predicate
/// slots (null slots do not count).
#[test]
fn test_filter_record_batch_no_columns() {
    let options = RecordBatchOptions::default().with_row_count(Some(100));
    let record_batch =
        RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &options).unwrap();
    let pred = BooleanArray::from(vec![Some(true), Some(true), None]);

    let out = filter_record_batch(&record_batch, &pred).unwrap();

    assert_eq!(out.num_rows(), 2);
}
1428
/// Exercises the all-true and all-false predicate cases.
#[test]
fn test_fast_path() {
    let a: PrimitiveArray<Int64Type> = PrimitiveArray::from(vec![Some(1), Some(2), None]);

    // Keeping every row must reproduce the input, nulls included.
    let keep_all = BooleanArray::from(vec![true; 3]);
    let out = filter(&a, &keep_all).unwrap();
    let kept = out
        .as_any()
        .downcast_ref::<PrimitiveArray<Int64Type>>()
        .unwrap();
    assert_eq!(&a, kept);

    // Dropping every row must still produce an array of the input's type.
    let keep_none = BooleanArray::from(vec![false; 3]);
    let out = filter(&a, &keep_none).unwrap();
    assert_eq!(out.len(), 0);
    assert_eq!(out.data_type(), &DataType::Int64);
}
1448
/// `SlicesIterator` reports runs of set bits as half-open `(start, end)` ranges relative to the
/// array's own start, also for sliced arrays.
#[test]
fn test_slices() {
    // 10 true, 30 false, 20 true, 17 false, 4 true (81 values)
    let bools = std::iter::repeat(true)
        .take(10)
        .chain(std::iter::repeat(false).take(30))
        .chain(std::iter::repeat(true).take(20))
        .chain(std::iter::repeat(false).take(17))
        .chain(std::iter::repeat(true).take(4));

    let bool_array: BooleanArray = bools.map(Some).collect();

    let slices: Vec<_> = SlicesIterator::new(&bool_array).collect();
    let expected = vec![(0, 10), (40, 60), (77, 81)];
    assert_eq!(slices, expected);

    // Drop the first 7 and last 3 values; the reported ranges must be relative to the slice.
    // `BooleanArray::slice` already returns a `BooleanArray`, so no downcast is needed.
    let len = bool_array.len();
    let sliced_array = bool_array.slice(7, len - 10);
    let slices: Vec<_> = SlicesIterator::new(&sliced_array).collect();
    let expected = vec![(0, 3), (33, 53), (70, 71)];
    assert_eq!(slices, expected);
}
1476
1477 fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) {
1478 let mut rng = rng();
1479
1480 let bools: Vec<bool> = std::iter::from_fn(|| Some(rng.random()))
1481 .take(mask_len)
1482 .collect();
1483
1484 let buffer = Buffer::from_iter(bools.iter().cloned());
1485
1486 let truncated_length = mask_len - offset - truncate;
1487
1488 let data = ArrayDataBuilder::new(DataType::Boolean)
1489 .len(truncated_length)
1490 .offset(offset)
1491 .add_buffer(buffer)
1492 .build()
1493 .unwrap();
1494
1495 let filter = BooleanArray::from(data);
1496
1497 let slice_bits: Vec<_> = SlicesIterator::new(&filter)
1498 .flat_map(|(start, end)| start..end)
1499 .collect();
1500
1501 let count = filter_count(&filter);
1502 let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect();
1503
1504 let expected_bits: Vec<_> = bools
1505 .iter()
1506 .skip(offset)
1507 .take(truncated_length)
1508 .enumerate()
1509 .flat_map(|(idx, v)| v.then(|| idx))
1510 .collect();
1511
1512 assert_eq!(slice_bits, expected_bits);
1513 assert_eq!(index_bits, expected_bits);
1514 }
1515
/// Runs `test_slices_fuzz` on random mask shapes, then on a few fixed shapes.
#[test]
#[cfg_attr(miri, ignore)]
fn fuzz_test_slices_iterator() {
    let mut rng = rng();
    let any_usize = UniformUsize::new(usize::MIN, usize::MAX).unwrap();

    for _ in 0..100 {
        let mask_len = rng.random_range(0..1024);

        // `checked_rem` returns `None` for a zero bound, in which case 0 is used.
        let offset_bound = mask_len.min(64);
        let offset = any_usize
            .sample(&mut rng)
            .checked_rem(offset_bound)
            .unwrap_or(0);

        let truncate_bound = (mask_len - offset).min(128);
        let truncate = any_usize
            .sample(&mut rng)
            .checked_rem(truncate_bound)
            .unwrap_or(0);

        test_slices_fuzz(mask_len, offset, truncate);
    }

    let fixed_shapes = [(64, 0, 0), (64, 8, 0), (64, 8, 8), (32, 8, 8), (32, 5, 9)];
    for (mask_len, offset, truncate) in fixed_shapes {
        test_slices_fuzz(mask_len, offset, truncate);
    }
}
1542
/// Scalar reference filter: keeps each value whose paired `predicate` entry is `true`.
///
/// Values and predicate are zipped, so iteration stops at whichever is shorter.
fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> {
    values
        .into_iter()
        .zip(predicate.iter().copied())
        .filter_map(|(value, keep)| keep.then_some(value))
        .collect()
}
1552
1553 fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>>
1555 where
1556 StandardUniform: Distribution<T>,
1557 {
1558 let mut rng = rng();
1559 (0..len)
1560 .map(|_| rng.random_bool(valid_percent).then(|| rng.random()))
1561 .collect()
1562 }
1563
1564 fn gen_strings(
1566 len: usize,
1567 valid_percent: f64,
1568 str_len_range: std::ops::Range<usize>,
1569 ) -> Vec<Option<String>> {
1570 let mut rng = rng();
1571 (0..len)
1572 .map(|_| {
1573 rng.random_bool(valid_percent).then(|| {
1574 let len = rng.random_range(str_len_range.clone());
1575 (0..len)
1576 .map(|_| char::from(rng.sample(Alphanumeric)))
1577 .collect()
1578 })
1579 })
1580 .collect()
1581 }
1582
/// Borrows each `Option<T>` in `src` as `Option<&T::Target>` (e.g. `Option<String>` to
/// `Option<&str>`).
fn as_deref<T: std::ops::Deref>(src: &[Option<T>]) -> impl Iterator<Item = Option<&T::Target>> {
    src.iter().map(Option::as_deref)
}
1587
/// Randomized differential test of `filter` against the scalar `filter_rust`.
///
/// Sliced primitive, string and dictionary arrays are filtered with a predicate that is itself
/// sliced and may be shorter than the arrays.
#[test]
#[cfg_attr(miri, ignore)]
fn fuzz_filter() {
    let mut rng = rng();

    for i in 0..100 {
        // Iterations 0..=4 use all-true predicates and 5..=10 all-false ones; the rest use a
        // random selectivity.
        let filter_percent = match i {
            0..=4 => 1.,
            5..=10 => 0.,
            _ => rng.random_range(0.0..1.0),
        };

        let valid_percent = rng.random_range(0.0..1.0);

        let array_len = rng.random_range(32..256);
        let array_offset = rng.random_range(0..10);

        // The predicate starts at `filter_offset` and is `filter_truncate` shorter than the
        // arrays being filtered.
        let filter_offset = rng.random_range(0..10);
        let filter_truncate = rng.random_range(0..10);
        let bools: Vec<_> = std::iter::from_fn(|| Some(rng.random_bool(filter_percent)))
            .take(array_len + filter_offset - filter_truncate)
            .collect();

        let predicate = BooleanArray::from_iter(bools.iter().cloned().map(Some));
        // `slice` on a concrete array type already returns that type; no downcast is needed.
        let predicate = predicate.slice(filter_offset, array_len - filter_truncate);
        let bools = &bools[filter_offset..];

        // Primitive values
        let values = gen_primitive(array_len + array_offset, valid_percent);
        let src = Int32Array::from_iter(values.iter().cloned());
        let src = src.slice(array_offset, array_len);
        let values = &values[array_offset..];

        let filtered = filter(&src, &predicate).unwrap();
        let array = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
        let actual: Vec<_> = array.iter().collect();

        assert_eq!(actual, filter_rust(values.iter().cloned(), bools));

        // String values
        let strings = gen_strings(array_len + array_offset, valid_percent, 0..20);
        let src = StringArray::from_iter(as_deref(&strings));
        let src = src.slice(array_offset, array_len);

        let filtered = filter(&src, &predicate).unwrap();
        let array = filtered.as_any().downcast_ref::<StringArray>().unwrap();
        let actual: Vec<_> = array.iter().collect();

        let expected_strings = filter_rust(as_deref(&strings[array_offset..]), bools);
        assert_eq!(actual, expected_strings);

        // Dictionary-encoded strings: decode the filtered keys through the dictionary values.
        let src = DictionaryArray::<Int32Type>::from_iter(as_deref(&strings));
        let src = src.slice(array_offset, array_len);

        let filtered = filter(&src, &predicate).unwrap();
        let array = filtered
            .as_any()
            .downcast_ref::<DictionaryArray<Int32Type>>()
            .unwrap();
        let values = array
            .values()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let actual: Vec<_> = array
            .keys()
            .iter()
            .map(|key| key.map(|key| values.value(key as usize)))
            .collect();

        assert_eq!(actual, expected_strings);
    }
}
1678
/// Filtering a map array keeps the selected entries, including identical maps at both ends.
#[test]
fn test_filter_map() {
    // Appends one valid map built from `entries` (keys and values go to separate child builders).
    fn append_entries(
        builder: &mut MapBuilder<StringBuilder, Int64Builder>,
        entries: &[(&str, i64)],
    ) {
        for &(key, value) in entries {
            builder.keys().append_value(key);
            builder.values().append_value(value);
        }
        builder.append(true).unwrap();
    }

    // Input maps: {key1: 1}, {key2: 2, key3: 3}, null, {key1: 1}
    let mut input = MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(4));
    append_entries(&mut input, &[("key1", 1)]);
    append_entries(&mut input, &[("key2", 2), ("key3", 3)]);
    input.append(false).unwrap();
    append_entries(&mut input, &[("key1", 1)]);
    let maparray = Arc::new(input.finish()) as ArrayRef;

    let indices = vec![Some(true), Some(false), Some(false), Some(true)]
        .into_iter()
        .collect::<BooleanArray>();
    let got = filter(&maparray, &indices).unwrap();

    let mut output = MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(2));
    append_entries(&mut output, &[("key1", 1)]);
    append_entries(&mut output, &[("key1", 1)]);
    let expected = Arc::new(output.finish()) as ArrayRef;

    assert_eq!(&expected, &got);
}
1715
/// Filters a non-null `FixedSizeList<Int32, 3>` array and checks the surviving lists.
#[test]
fn test_filter_fixed_size_list_arrays() {
    // Lists: [0, 1, 2], [3, 4, 5], [6, 7, 8]
    let value_data = ArrayData::builder(DataType::Int32)
        .len(9)
        .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8]))
        .build()
        .unwrap();
    let list_data = ArrayData::builder(DataType::new_fixed_size_list(DataType::Int32, 3, false))
        .len(3)
        .add_child_data(value_data)
        .build()
        .unwrap();
    let array = FixedSizeListArray::from(list_data);

    // Collects the Int32 values of the list at `idx`.
    let list_values = |lists: &FixedSizeListArray, idx: usize| -> Vec<i32> {
        let list = lists.value(idx);
        let ints = list.as_any().downcast_ref::<Int32Array>().unwrap();
        ints.values().to_vec()
    };

    let c = filter(&array, &BooleanArray::from(vec![true, false, false])).unwrap();
    let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
    assert_eq!(filtered.len(), 1);
    assert_eq!(list_values(filtered, 0), [0, 1, 2]);

    let c = filter(&array, &BooleanArray::from(vec![true, false, true])).unwrap();
    let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
    assert_eq!(filtered.len(), 2);
    assert_eq!(list_values(filtered, 0), [0, 1, 2]);
    assert_eq!(list_values(filtered, 1), [6, 7, 8]);
}
1762
/// Filtering a `FixedSizeList<Int32, 2>` array with null slots carries selected nulls through.
#[test]
fn test_filter_fixed_size_list_arrays_with_null() {
    let value_data = ArrayData::builder(DataType::Int32)
        .len(10)
        .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
        .build()
        .unwrap();

    // Validity bitmap: slots 0, 3 and 4 are valid; slots 1 and 2 are null.
    let mut null_bits = [0_u8; 1];
    for valid_slot in [0, 3, 4] {
        bit_util::set_bit(&mut null_bits, valid_slot);
    }

    let list_data = ArrayData::builder(DataType::new_fixed_size_list(DataType::Int32, 2, false))
        .len(5)
        .add_child_data(value_data)
        .null_bit_buffer(Some(Buffer::from(null_bits)))
        .build()
        .unwrap();
    let array = FixedSizeListArray::from(list_data);

    let predicate = BooleanArray::from(vec![true, true, false, true, false]);
    let result = filter(&array, &predicate).unwrap();
    let filtered = result.as_any().downcast_ref::<FixedSizeListArray>().unwrap();

    // Slots 0, 1 and 3 survive: [0, 1], null, [6, 7].
    assert_eq!(filtered.len(), 3);

    let first = filtered.value(0);
    assert_eq!(
        &[0, 1],
        first.as_any().downcast_ref::<Int32Array>().unwrap().values()
    );
    assert!(filtered.is_null(1));
    let last = filtered.value(2);
    assert_eq!(
        &[6, 7],
        last.as_any().downcast_ref::<Int32Array>().unwrap().values()
    );
}
1807
1808 fn test_filter_union_array(array: UnionArray) {
1809 let filter_array = BooleanArray::from(vec![true, false, false]);
1810 let c = filter(&array, &filter_array).unwrap();
1811 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1812
1813 let mut builder = UnionBuilder::new_dense();
1814 builder.append::<Int32Type>("A", 1).unwrap();
1815 let expected_array = builder.build().unwrap();
1816
1817 compare_union_arrays(filtered, &expected_array);
1818
1819 let filter_array = BooleanArray::from(vec![true, false, true]);
1820 let c = filter(&array, &filter_array).unwrap();
1821 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1822
1823 let mut builder = UnionBuilder::new_dense();
1824 builder.append::<Int32Type>("A", 1).unwrap();
1825 builder.append::<Int32Type>("A", 34).unwrap();
1826 let expected_array = builder.build().unwrap();
1827
1828 compare_union_arrays(filtered, &expected_array);
1829
1830 let filter_array = BooleanArray::from(vec![true, true, false]);
1831 let c = filter(&array, &filter_array).unwrap();
1832 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1833
1834 let mut builder = UnionBuilder::new_dense();
1835 builder.append::<Int32Type>("A", 1).unwrap();
1836 builder.append::<Float64Type>("B", 3.2).unwrap();
1837 let expected_array = builder.build().unwrap();
1838
1839 compare_union_arrays(filtered, &expected_array);
1840 }
1841
/// Runs the shared union filter checks on a dense union.
#[test]
fn test_filter_union_array_dense() {
    let mut dense = UnionBuilder::new_dense();
    dense.append::<Int32Type>("A", 1).unwrap();
    dense.append::<Float64Type>("B", 3.2).unwrap();
    dense.append::<Int32Type>("A", 34).unwrap();

    test_filter_union_array(dense.build().unwrap());
}
1852
/// Filtering a dense union whose values all come from child "A" must produce exactly the
/// `ArrayData` of a freshly built union of the surviving values.
#[test]
fn test_filter_run_union_array_dense() {
    let mut input = UnionBuilder::new_dense();
    for v in [1, 3, 34] {
        input.append::<Int32Type>("A", v).unwrap();
    }
    let array = input.build().unwrap();

    let predicate = BooleanArray::from(vec![true, true, false]);
    let result = filter(&array, &predicate).unwrap();
    let filtered = result.as_any().downcast_ref::<UnionArray>().unwrap();

    let mut expected = UnionBuilder::new_dense();
    for v in [1, 3] {
        expected.append::<Int32Type>("A", v).unwrap();
    }
    let expected = expected.build().unwrap();

    assert_eq!(filtered.to_data(), expected.to_data());
}
1872
/// Filtering a dense union containing a null "B" slot keeps the null when it is selected.
#[test]
fn test_filter_union_array_dense_with_nulls() {
    // Input: [A: 1, B: 3.2, B: null, A: 34]
    let mut input = UnionBuilder::new_dense();
    input.append::<Int32Type>("A", 1).unwrap();
    input.append::<Float64Type>("B", 3.2).unwrap();
    input.append_null::<Float64Type>("B").unwrap();
    input.append::<Int32Type>("A", 34).unwrap();
    let array = input.build().unwrap();

    let check = |mask: Vec<bool>, expected: &UnionArray| {
        let result = filter(&array, &BooleanArray::from(mask)).unwrap();
        let filtered = result.as_any().downcast_ref::<UnionArray>().unwrap();
        compare_union_arrays(filtered, expected);
    };

    // Keep the two leading non-null slots.
    let mut expected = UnionBuilder::new_dense();
    expected.append::<Int32Type>("A", 1).unwrap();
    expected.append::<Float64Type>("B", 3.2).unwrap();
    check(vec![true, true, false, false], &expected.build().unwrap());

    // Keep slot 0 and the null "B" slot.
    let mut expected = UnionBuilder::new_dense();
    expected.append::<Int32Type>("A", 1).unwrap();
    expected.append_null::<Float64Type>("B").unwrap();
    check(vec![true, false, true, false], &expected.build().unwrap());
}
1904
/// Runs the shared union filter checks on a sparse union.
#[test]
fn test_filter_union_array_sparse() {
    let mut sparse = UnionBuilder::new_sparse();
    sparse.append::<Int32Type>("A", 1).unwrap();
    sparse.append::<Float64Type>("B", 3.2).unwrap();
    sparse.append::<Int32Type>("A", 34).unwrap();

    test_filter_union_array(sparse.build().unwrap());
}
1915
/// Filtering a sparse union keeps a selected null "B" slot.
#[test]
fn test_filter_union_array_sparse_with_nulls() {
    // Input: [A: 1, B: 3.2, B: null, A: 34]
    let mut input = UnionBuilder::new_sparse();
    input.append::<Int32Type>("A", 1).unwrap();
    input.append::<Float64Type>("B", 3.2).unwrap();
    input.append_null::<Float64Type>("B").unwrap();
    input.append::<Int32Type>("A", 34).unwrap();
    let array = input.build().unwrap();

    let keep_first_and_null = BooleanArray::from(vec![true, false, true, false]);
    let result = filter(&array, &keep_first_and_null).unwrap();
    let filtered = result.as_any().downcast_ref::<UnionArray>().unwrap();

    let mut expected = UnionBuilder::new_sparse();
    expected.append::<Int32Type>("A", 1).unwrap();
    expected.append_null::<Float64Type>("B").unwrap();
    let expected = expected.build().unwrap();

    compare_union_arrays(filtered, &expected);
}
1936
1937 fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) {
1938 assert_eq!(union1.len(), union2.len());
1939
1940 for i in 0..union1.len() {
1941 let type_id = union1.type_id(i);
1942
1943 let slot1 = union1.value(i);
1944 let slot2 = union2.value(i);
1945
1946 assert_eq!(slot1.is_null(0), slot2.is_null(0));
1947
1948 if !slot1.is_null(0) && !slot2.is_null(0) {
1949 match type_id {
1950 0 => {
1951 let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
1952 assert_eq!(slot1.len(), 1);
1953 let value1 = slot1.value(0);
1954
1955 let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap();
1956 assert_eq!(slot2.len(), 1);
1957 let value2 = slot2.value(0);
1958 assert_eq!(value1, value2);
1959 }
1960 1 => {
1961 let slot1 = slot1.as_any().downcast_ref::<Float64Array>().unwrap();
1962 assert_eq!(slot1.len(), 1);
1963 let value1 = slot1.value(0);
1964
1965 let slot2 = slot2.as_any().downcast_ref::<Float64Array>().unwrap();
1966 assert_eq!(slot2.len(), 1);
1967 let value2 = slot2.value(0);
1968 assert_eq!(value1, value2);
1969 }
1970 _ => unreachable!(),
1971 }
1972 }
1973 }
1974 }
1975
/// Filters struct arrays with one and two children, each with and without a struct-level
/// null buffer.
#[test]
fn test_filter_struct() {
    let predicate = BooleanArray::from(vec![true, false, true, false]);

    let a = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
    let a_filtered = Arc::new(StringArray::from(vec!["hello", "world"]));

    let b = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
    let b_filtered = Arc::new(Int32Array::from(vec![5, 7]));

    let null_mask = NullBuffer::from(vec![true, false, false, true]);
    let null_mask_filtered = NullBuffer::from(vec![true, false]);

    let a_field = Field::new("a", DataType::Utf8, false);
    let b_field = Field::new("b", DataType::Int32, false);

    // Filters `array` with `predicate` and compares the resulting `ArrayData` to `expected`'s.
    let check = |array: StructArray, expected: StructArray| {
        let result = filter(&array, &predicate).unwrap();
        assert_eq!(result.to_data(), expected.to_data());
    };

    // One child, no struct nulls
    check(
        StructArray::new(vec![a_field.clone()].into(), vec![a.clone()], None),
        StructArray::new(vec![a_field.clone()].into(), vec![a_filtered.clone()], None),
    );

    // One child, with struct nulls
    check(
        StructArray::new(
            vec![a_field.clone()].into(),
            vec![a.clone()],
            Some(null_mask.clone()),
        ),
        StructArray::new(
            vec![a_field.clone()].into(),
            vec![a_filtered.clone()],
            Some(null_mask_filtered.clone()),
        ),
    );

    // Two children, no struct nulls
    check(
        StructArray::new(
            vec![a_field.clone(), b_field.clone()].into(),
            vec![a.clone(), b.clone()],
            None,
        ),
        StructArray::new(
            vec![a_field.clone(), b_field.clone()].into(),
            vec![a_filtered.clone(), b_filtered.clone()],
            None,
        ),
    );

    // Two children, with struct nulls
    check(
        StructArray::new(
            vec![a_field.clone(), b_field.clone()].into(),
            vec![a.clone(), b.clone()],
            Some(null_mask.clone()),
        ),
        StructArray::new(
            vec![a_field.clone(), b_field.clone()].into(),
            vec![a_filtered.clone(), b_filtered.clone()],
            Some(null_mask_filtered.clone()),
        ),
    );
}
2046
/// Filters a record batch whose single struct column contains a child struct with no fields.
///
/// The empty child struct is built with an explicit length via `StructArray::new_empty_fields`.
/// The filtered batch must have the row count selected by the predicate.
#[test]
fn test_filter_empty_struct() {
    // Defined once and shared by the schema and the array so the two always agree.
    let child_fields = Fields::from(vec![
        Field::new("b", DataType::Int64, true),
        Field::new("c", DataType::Struct(Fields::empty()), true),
    ]);
    let field = Field::new("a", DataType::Struct(child_fields.clone()), true);
    let schema = Arc::new(Schema::new(vec![field]));

    let b = Arc::new(Int64Array::from(vec![None, None, None]));
    let c = Arc::new(StructArray::new_empty_fields(
        3,
        Some(NullBuffer::from(vec![true, true, true])),
    ));
    let a = StructArray::new(
        child_fields,
        vec![b, c],
        Some(NullBuffer::from(vec![true, true, true])),
    );
    let record_batch = RecordBatch::try_new(schema, vec![Arc::new(a)]).unwrap();

    let predicate = BooleanArray::from(vec![true, false, true]);
    let filtered_batch = filter_record_batch(&record_batch, &predicate).unwrap();

    assert_eq!(filtered_batch.num_rows(), 2);
}
2101}