1use std::ops::AddAssign;
21use std::sync::Arc;
22
23use arrow_array::builder::BooleanBufferBuilder;
24use arrow_array::cast::AsArray;
25use arrow_array::types::{
26 ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
27};
28use arrow_array::*;
29use arrow_buffer::{bit_util, ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer};
30use arrow_buffer::{Buffer, MutableBuffer};
31use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
32use arrow_data::transform::MutableArrayData;
33use arrow_data::ArrayDataBuilder;
34use arrow_schema::*;
35
/// Selectivity above which a filter is applied by copying contiguous ranges of
/// selected rows, rather than visiting each selected row index individually.
///
/// Selectivity is the fraction of filter bits that are set, i.e.
/// `filter_count / filter_length`.
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
43
44#[derive(Debug)]
56pub struct SlicesIterator<'a>(BitSliceIterator<'a>);
57
58impl<'a> SlicesIterator<'a> {
59 pub fn new(filter: &'a BooleanArray) -> Self {
61 Self(filter.values().set_slices())
62 }
63}
64
65impl Iterator for SlicesIterator<'_> {
66 type Item = (usize, usize);
67
68 fn next(&mut self) -> Option<Self::Item> {
69 self.0.next()
70 }
71}
72
73struct IndexIterator<'a> {
78 remaining: usize,
79 iter: BitIndexIterator<'a>,
80}
81
82impl<'a> IndexIterator<'a> {
83 fn new(filter: &'a BooleanArray, remaining: usize) -> Self {
84 assert_eq!(filter.null_count(), 0);
85 let iter = filter.values().set_indices();
86 Self { remaining, iter }
87 }
88}
89
90impl Iterator for IndexIterator<'_> {
91 type Item = usize;
92
93 fn next(&mut self) -> Option<Self::Item> {
94 if self.remaining != 0 {
95 let next = self.iter.next().expect("IndexIterator exhausted early");
98 self.remaining -= 1;
99 return Some(next);
101 }
102 None
103 }
104
105 fn size_hint(&self) -> (usize, Option<usize>) {
106 (self.remaining, Some(self.remaining))
107 }
108}
109
110fn filter_count(filter: &BooleanArray) -> usize {
112 filter.values().count_set_bits()
113}
114
115pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
117 let nulls = filter.nulls().unwrap();
118 let mask = filter.values() & nulls.inner();
119 BooleanArray::new(mask, None)
120}
121
122pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
144 let mut filter_builder = FilterBuilder::new(predicate);
145
146 if multiple_arrays(values.data_type()) {
147 filter_builder = filter_builder.optimize();
150 }
151
152 let predicate = filter_builder.build();
153
154 filter_array(values, &predicate)
155}
156
157fn multiple_arrays(data_type: &DataType) -> bool {
158 match data_type {
159 DataType::Struct(fields) => {
160 fields.len() > 1 || fields.len() == 1 && multiple_arrays(fields[0].data_type())
161 }
162 DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
163 _ => false,
164 }
165}
166
167pub fn filter_record_batch(
172 record_batch: &RecordBatch,
173 predicate: &BooleanArray,
174) -> Result<RecordBatch, ArrowError> {
175 let mut filter_builder = FilterBuilder::new(predicate);
176 if record_batch.num_columns() > 1 {
177 filter_builder = filter_builder.optimize();
180 }
181 let filter = filter_builder.build();
182
183 let filtered_arrays = record_batch
184 .columns()
185 .iter()
186 .map(|a| filter_array(a, &filter))
187 .collect::<Result<Vec<_>, _>>()?;
188 let options = RecordBatchOptions::default().with_row_count(Some(filter.count()));
189 RecordBatch::try_new_with_options(record_batch.schema(), filtered_arrays, &options)
190}
191
192#[derive(Debug)]
194pub struct FilterBuilder {
195 filter: BooleanArray,
196 count: usize,
197 strategy: IterationStrategy,
198}
199
200impl FilterBuilder {
201 pub fn new(filter: &BooleanArray) -> Self {
203 let filter = match filter.null_count() {
204 0 => filter.clone(),
205 _ => prep_null_mask_filter(filter),
206 };
207
208 let count = filter_count(&filter);
209 let strategy = IterationStrategy::default_strategy(filter.len(), count);
210
211 Self {
212 filter,
213 count,
214 strategy,
215 }
216 }
217
218 pub fn optimize(mut self) -> Self {
224 match self.strategy {
225 IterationStrategy::SlicesIterator => {
226 let slices = SlicesIterator::new(&self.filter).collect();
227 self.strategy = IterationStrategy::Slices(slices)
228 }
229 IterationStrategy::IndexIterator => {
230 let indices = IndexIterator::new(&self.filter, self.count).collect();
231 self.strategy = IterationStrategy::Indices(indices)
232 }
233 _ => {}
234 }
235 self
236 }
237
238 pub fn build(self) -> FilterPredicate {
240 FilterPredicate {
241 filter: self.filter,
242 count: self.count,
243 strategy: self.strategy,
244 }
245 }
246}
247
248#[derive(Debug)]
250enum IterationStrategy {
251 SlicesIterator,
253 IndexIterator,
255 Indices(Vec<usize>),
257 Slices(Vec<(usize, usize)>),
259 All,
261 None,
263}
264
265impl IterationStrategy {
266 fn default_strategy(filter_length: usize, filter_count: usize) -> Self {
269 if filter_length == 0 || filter_count == 0 {
270 return IterationStrategy::None;
271 }
272
273 if filter_count == filter_length {
274 return IterationStrategy::All;
275 }
276
277 let selectivity_frac = filter_count as f64 / filter_length as f64;
282 if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD {
283 return IterationStrategy::SlicesIterator;
284 }
285 IterationStrategy::IndexIterator
286 }
287}
288
289#[derive(Debug)]
291pub struct FilterPredicate {
292 filter: BooleanArray,
293 count: usize,
294 strategy: IterationStrategy,
295}
296
297impl FilterPredicate {
298 pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef, ArrowError> {
300 filter_array(values, self)
301 }
302
303 pub fn count(&self) -> usize {
305 self.count
306 }
307}
308
/// Filters `values` according to `predicate`. Dispatches on the data type to a
/// specialized kernel where one exists.
///
/// Types without a specialized kernel fall back to copying the selected ranges
/// with [`MutableArrayData`].
///
/// # Errors
///
/// Returns [`ArrowError::InvalidArgumentError`] if the predicate is longer than
/// `values`. Also propagates errors from the run-end-encoded, struct and sparse
/// union kernels.
///
/// # Panics
///
/// Panics (via `unimplemented!`) for run-end-encoded or dictionary arrays whose
/// index/key type is not handled by the downcast macros.
fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> {
    // A shorter predicate is allowed (the trailing rows are dropped). A longer
    // one would read past the end of `values`.
    if predicate.filter.len() > values.len() {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Filter predicate of length {} is larger than target array of length {}",
            predicate.filter.len(),
            values.len()
        )));
    }

    match predicate.strategy {
        // Fast paths that need no per-row work.
        IterationStrategy::None => Ok(new_empty_array(values.data_type())),
        IterationStrategy::All => Ok(values.slice(0, predicate.count)),
        // The first macro arm handles every primitive type. The remaining arms
        // are matched against the data type.
        _ => downcast_primitive_array! {
            values => Ok(Arc::new(filter_primitive(values, predicate))),
            DataType::Boolean => {
                let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
                Ok(Arc::new(filter_boolean(values, predicate)))
            }
            DataType::Utf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i32>(), predicate)))
            }
            DataType::LargeUtf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
            }
            DataType::Utf8View => {
                Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
            }
            DataType::Binary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
            }
            DataType::LargeBinary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
            }
            DataType::BinaryView => {
                Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
            }
            DataType::FixedSizeBinary(_) => {
                Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate)))
            }
            DataType::RunEndEncoded(_, _) => {
                downcast_run_array!{
                    values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
                    t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t)
                }
            }
            DataType::Dictionary(_, _) => downcast_dictionary_array! {
                values => Ok(Arc::new(filter_dict(values, predicate))),
                t => unimplemented!("Filter not supported for dictionary type {:?}", t)
            }
            DataType::Struct(_) => {
                Ok(Arc::new(filter_struct(values.as_struct(), predicate)?))
            }
            DataType::Union(_, UnionMode::Sparse) => {
                Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?))
            }
            _ => {
                // Generic fallback: copy the selected ranges of the underlying
                // ArrayData, sized for `predicate.count` output rows.
                let data = values.to_data();
                let mut mutable = MutableArrayData::new(
                    vec![&data],
                    false,
                    predicate.count,
                );

                match &predicate.strategy {
                    // Reuse the pre-computed ranges of an optimized predicate.
                    IterationStrategy::Slices(slices) => {
                        slices
                            .iter()
                            .for_each(|(start, end)| mutable.extend(0, *start, *end));
                    }
                    // Otherwise recompute the ranges from the filter bitmap.
                    // An `Indices` strategy is also served this way.
                    _ => {
                        let iter = SlicesIterator::new(&predicate.filter);
                        iter.for_each(|(start, end)| mutable.extend(0, start, end));
                    }
                }

                let data = mutable.freeze();
                Ok(make_array(data))
            }
        },
    }
}
392
393fn filter_run_end_array<R: RunEndIndexType>(
395 array: &RunArray<R>,
396 predicate: &FilterPredicate,
397) -> Result<RunArray<R>, ArrowError>
398where
399 R::Native: Into<i64> + From<bool>,
400 R::Native: AddAssign,
401{
402 let run_ends: &RunEndBuffer<R::Native> = array.run_ends();
403 let mut new_run_ends = vec![R::default_value(); run_ends.len()];
404
405 let mut start = 0u64;
406 let mut j = 0;
407 let mut count = R::default_value();
408 let filter_values = predicate.filter.values();
409 let run_ends = run_ends.inner();
410
411 let pred: BooleanArray = BooleanBuffer::collect_bool(run_ends.len(), |i| {
412 let mut keep = false;
413 let mut end = run_ends[i].into() as u64;
414 let difference = end.saturating_sub(filter_values.len() as u64);
415 end -= difference;
416
417 for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
419 count += R::Native::from(pred);
420 keep |= pred
421 }
422 new_run_ends[j] = count;
424 j += keep as usize;
425
426 start = end;
427 keep
428 })
429 .into();
430
431 new_run_ends.truncate(j);
432
433 let values = array.values();
434 let values = filter(&values, &pred)?;
435
436 let run_ends = PrimitiveArray::<R>::new(new_run_ends.into(), None);
437 RunArray::try_new(&run_ends, &values)
438}
439
440fn filter_null_mask(
447 nulls: Option<&NullBuffer>,
448 predicate: &FilterPredicate,
449) -> Option<(usize, Buffer)> {
450 let nulls = nulls?;
451 if nulls.null_count() == 0 {
452 return None;
453 }
454
455 let nulls = filter_bits(nulls.inner(), predicate);
456 let null_count = predicate.count - nulls.count_set_bits_offset(0, predicate.count);
459
460 if null_count == 0 {
461 return None;
462 }
463
464 Some((null_count, nulls))
465}
466
467fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer {
469 let src = buffer.values();
470 let offset = buffer.offset();
471
472 match &predicate.strategy {
473 IterationStrategy::IndexIterator => {
474 let bits = IndexIterator::new(&predicate.filter, predicate.count)
475 .map(|src_idx| bit_util::get_bit(src, src_idx + offset));
476
477 unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
479 }
480 IterationStrategy::Indices(indices) => {
481 let bits = indices
482 .iter()
483 .map(|src_idx| bit_util::get_bit(src, *src_idx + offset));
484
485 unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
487 }
488 IterationStrategy::SlicesIterator => {
489 let mut builder = BooleanBufferBuilder::new(predicate.count);
490 for (start, end) in SlicesIterator::new(&predicate.filter) {
491 builder.append_packed_range(start + offset..end + offset, src)
492 }
493 builder.into()
494 }
495 IterationStrategy::Slices(slices) => {
496 let mut builder = BooleanBufferBuilder::new(predicate.count);
497 for (start, end) in slices {
498 builder.append_packed_range(*start + offset..*end + offset, src)
499 }
500 builder.into()
501 }
502 IterationStrategy::All | IterationStrategy::None => unreachable!(),
503 }
504}
505
506fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
508 let values = filter_bits(array.values(), predicate);
509
510 let mut builder = ArrayDataBuilder::new(DataType::Boolean)
511 .len(predicate.count)
512 .add_buffer(values);
513
514 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
515 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
516 }
517
518 let data = unsafe { builder.build_unchecked() };
519 BooleanArray::from(data)
520}
521
522#[inline(never)]
523fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer {
524 assert!(values.len() >= predicate.filter.len());
525
526 match &predicate.strategy {
527 IterationStrategy::SlicesIterator => {
528 let mut buffer = Vec::with_capacity(predicate.count);
529 for (start, end) in SlicesIterator::new(&predicate.filter) {
530 buffer.extend_from_slice(&values[start..end]);
531 }
532 buffer.into()
533 }
534 IterationStrategy::Slices(slices) => {
535 let mut buffer = Vec::with_capacity(predicate.count);
536 for (start, end) in slices {
537 buffer.extend_from_slice(&values[*start..*end]);
538 }
539 buffer.into()
540 }
541 IterationStrategy::IndexIterator => {
542 let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| values[x]);
543
544 unsafe { MutableBuffer::from_trusted_len_iter(iter) }.into()
546 }
547 IterationStrategy::Indices(indices) => {
548 let iter = indices.iter().map(|x| values[*x]);
549 iter.collect::<Vec<_>>().into()
550 }
551 IterationStrategy::All | IterationStrategy::None => unreachable!(),
552 }
553}
554
555fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
557where
558 T: ArrowPrimitiveType,
559{
560 let values = array.values();
561 let buffer = filter_native(values, predicate);
562 let mut builder = ArrayDataBuilder::new(array.data_type().clone())
563 .len(predicate.count)
564 .add_buffer(buffer);
565
566 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
567 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
568 }
569
570 let data = unsafe { builder.build_unchecked() };
571 PrimitiveArray::from(data)
572}
573
574struct FilterBytes<'a, OffsetSize> {
579 src_offsets: &'a [OffsetSize],
580 src_values: &'a [u8],
581 dst_offsets: Vec<OffsetSize>,
582 dst_values: Vec<u8>,
583 cur_offset: OffsetSize,
584}
585
586impl<'a, OffsetSize> FilterBytes<'a, OffsetSize>
587where
588 OffsetSize: OffsetSizeTrait,
589{
590 fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self
591 where
592 T: ByteArrayType<Offset = OffsetSize>,
593 {
594 let dst_values = Vec::new();
595 let mut dst_offsets: Vec<OffsetSize> = Vec::with_capacity(capacity + 1);
596 let cur_offset = OffsetSize::from_usize(0).unwrap();
597
598 dst_offsets.push(cur_offset);
599
600 Self {
601 src_offsets: array.value_offsets(),
602 src_values: array.value_data(),
603 dst_offsets,
604 dst_values,
605 cur_offset,
606 }
607 }
608
609 #[inline]
611 fn get_value_offset(&self, idx: usize) -> usize {
612 self.src_offsets[idx].as_usize()
613 }
614
615 #[inline]
617 fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
618 let start = self.get_value_offset(idx);
620 let end = self.get_value_offset(idx + 1);
621 let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
622 (start, end, len)
623 }
624
625 fn extend_offsets_idx(&mut self, iter: impl Iterator<Item = usize>) {
626 self.dst_offsets.extend(iter.map(|idx| {
627 let start = self.src_offsets[idx].as_usize();
628 let end = self.src_offsets[idx + 1].as_usize();
629 let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
630 self.cur_offset += len;
631
632 self.cur_offset
633 }));
634 }
635
636 fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
638 self.dst_values.reserve_exact(self.cur_offset.as_usize());
639
640 for idx in iter {
641 let start = self.src_offsets[idx].as_usize();
642 let end = self.src_offsets[idx + 1].as_usize();
643 self.dst_values
644 .extend_from_slice(&self.src_values[start..end]);
645 }
646 }
647
648 fn extend_offsets_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>, count: usize) {
649 self.dst_offsets.reserve_exact(count);
650 for (start, end) in iter {
651 for idx in start..end {
653 let (_, _, len) = self.get_value_range(idx);
654 self.cur_offset += len;
655 self.dst_offsets.push(self.cur_offset);
656 }
657 }
658 }
659
660 fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
662 self.dst_values.reserve_exact(self.cur_offset.as_usize());
663
664 for (start, end) in iter {
665 let value_start = self.get_value_offset(start);
666 let value_end = self.get_value_offset(end);
667 self.dst_values
668 .extend_from_slice(&self.src_values[value_start..value_end]);
669 }
670 }
671}
672
673fn filter_bytes<T>(array: &GenericByteArray<T>, predicate: &FilterPredicate) -> GenericByteArray<T>
678where
679 T: ByteArrayType,
680{
681 let mut filter = FilterBytes::new(predicate.count, array);
682
683 match &predicate.strategy {
684 IterationStrategy::SlicesIterator => {
685 filter.extend_offsets_slices(SlicesIterator::new(&predicate.filter), predicate.count);
686 filter.extend_slices(SlicesIterator::new(&predicate.filter))
687 }
688 IterationStrategy::Slices(slices) => {
689 filter.extend_offsets_slices(slices.iter().cloned(), predicate.count);
690 filter.extend_slices(slices.iter().cloned())
691 }
692 IterationStrategy::IndexIterator => {
693 filter.extend_offsets_idx(IndexIterator::new(&predicate.filter, predicate.count));
694 filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count))
695 }
696 IterationStrategy::Indices(indices) => {
697 filter.extend_offsets_idx(indices.iter().cloned());
698 filter.extend_idx(indices.iter().cloned())
699 }
700 IterationStrategy::All | IterationStrategy::None => unreachable!(),
701 }
702
703 let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
704 .len(predicate.count)
705 .add_buffer(filter.dst_offsets.into())
706 .add_buffer(filter.dst_values.into());
707
708 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
709 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
710 }
711
712 let data = unsafe { builder.build_unchecked() };
713 GenericByteArray::from(data)
714}
715
716fn filter_byte_view<T: ByteViewType>(
718 array: &GenericByteViewArray<T>,
719 predicate: &FilterPredicate,
720) -> GenericByteViewArray<T> {
721 let new_view_buffer = filter_native(array.views(), predicate);
722
723 let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
724 .len(predicate.count)
725 .add_buffer(new_view_buffer)
726 .add_buffers(array.data_buffers().to_vec());
727
728 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
729 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
730 }
731
732 GenericByteViewArray::from(unsafe { builder.build_unchecked() })
733}
734
735fn filter_fixed_size_binary(
736 array: &FixedSizeBinaryArray,
737 predicate: &FilterPredicate,
738) -> FixedSizeBinaryArray {
739 let values: &[u8] = array.values();
740 let value_length = array.value_length() as usize;
741 let calculate_offset_from_index = |index: usize| index * value_length;
742 let buffer = match &predicate.strategy {
743 IterationStrategy::SlicesIterator => {
744 let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
745 for (start, end) in SlicesIterator::new(&predicate.filter) {
746 buffer.extend_from_slice(
747 &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)],
748 );
749 }
750 buffer
751 }
752 IterationStrategy::Slices(slices) => {
753 let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
754 for (start, end) in slices {
755 buffer.extend_from_slice(
756 &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)],
757 );
758 }
759 buffer
760 }
761 IterationStrategy::IndexIterator => {
762 let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| {
763 &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)]
764 });
765
766 let mut buffer = MutableBuffer::new(predicate.count * value_length);
767 iter.for_each(|item| buffer.extend_from_slice(item));
768 buffer
769 }
770 IterationStrategy::Indices(indices) => {
771 let iter = indices.iter().map(|x| {
772 &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)]
773 });
774
775 let mut buffer = MutableBuffer::new(predicate.count * value_length);
776 iter.for_each(|item| buffer.extend_from_slice(item));
777 buffer
778 }
779 IterationStrategy::All | IterationStrategy::None => unreachable!(),
780 };
781 let mut builder = ArrayDataBuilder::new(array.data_type().clone())
782 .len(predicate.count)
783 .add_buffer(buffer.into());
784
785 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
786 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
787 }
788
789 let data = unsafe { builder.build_unchecked() };
790 FixedSizeBinaryArray::from(data)
791}
792
793fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
795where
796 T: ArrowDictionaryKeyType,
797 T::Native: num::Num,
798{
799 let builder = filter_primitive::<T>(array.keys(), predicate)
800 .into_data()
801 .into_builder()
802 .data_type(array.data_type().clone())
803 .child_data(vec![array.values().to_data()]);
804
805 DictionaryArray::from(unsafe { builder.build_unchecked() })
808}
809
810fn filter_struct(
812 array: &StructArray,
813 predicate: &FilterPredicate,
814) -> Result<StructArray, ArrowError> {
815 let columns = array
816 .columns()
817 .iter()
818 .map(|column| filter_array(column, predicate))
819 .collect::<Result<_, _>>()?;
820
821 let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
822 let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
823
824 Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
825 } else {
826 None
827 };
828
829 Ok(unsafe {
830 StructArray::new_unchecked_with_length(
831 array.fields().clone(),
832 columns,
833 nulls,
834 predicate.count(),
835 )
836 })
837}
838
839fn filter_sparse_union(
841 array: &UnionArray,
842 predicate: &FilterPredicate,
843) -> Result<UnionArray, ArrowError> {
844 let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
845 unreachable!()
846 };
847
848 let type_ids = filter_primitive(&Int8Array::new(array.type_ids().clone(), None), predicate);
849
850 let children = fields
851 .iter()
852 .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
853 .collect::<Result<_, _>>()?;
854
855 Ok(unsafe {
856 UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
857 })
858}
859
860#[cfg(test)]
861mod tests {
862 use super::*;
863 use arrow_array::builder::*;
864 use arrow_array::cast::as_run_array;
865 use arrow_array::types::*;
866 use arrow_data::ArrayData;
867 use rand::distr::uniform::{UniformSampler, UniformUsize};
868 use rand::distr::{Alphanumeric, StandardUniform};
869 use rand::prelude::*;
870 use rand::rng;
871
872 macro_rules! def_temporal_test {
873 ($test:ident, $array_type: ident, $data: expr) => {
874 #[test]
875 fn $test() {
876 let a = $data;
877 let b = BooleanArray::from(vec![true, false, true, false]);
878 let c = filter(&a, &b).unwrap();
879 let d = c.as_ref().as_any().downcast_ref::<$array_type>().unwrap();
880 assert_eq!(2, d.len());
881 assert_eq!(1, d.value(0));
882 assert_eq!(3, d.value(1));
883 }
884 };
885 }
886
887 def_temporal_test!(
888 test_filter_date32,
889 Date32Array,
890 Date32Array::from(vec![1, 2, 3, 4])
891 );
892 def_temporal_test!(
893 test_filter_date64,
894 Date64Array,
895 Date64Array::from(vec![1, 2, 3, 4])
896 );
897 def_temporal_test!(
898 test_filter_time32_second,
899 Time32SecondArray,
900 Time32SecondArray::from(vec![1, 2, 3, 4])
901 );
902 def_temporal_test!(
903 test_filter_time32_millisecond,
904 Time32MillisecondArray,
905 Time32MillisecondArray::from(vec![1, 2, 3, 4])
906 );
907 def_temporal_test!(
908 test_filter_time64_microsecond,
909 Time64MicrosecondArray,
910 Time64MicrosecondArray::from(vec![1, 2, 3, 4])
911 );
912 def_temporal_test!(
913 test_filter_time64_nanosecond,
914 Time64NanosecondArray,
915 Time64NanosecondArray::from(vec![1, 2, 3, 4])
916 );
917 def_temporal_test!(
918 test_filter_duration_second,
919 DurationSecondArray,
920 DurationSecondArray::from(vec![1, 2, 3, 4])
921 );
922 def_temporal_test!(
923 test_filter_duration_millisecond,
924 DurationMillisecondArray,
925 DurationMillisecondArray::from(vec![1, 2, 3, 4])
926 );
927 def_temporal_test!(
928 test_filter_duration_microsecond,
929 DurationMicrosecondArray,
930 DurationMicrosecondArray::from(vec![1, 2, 3, 4])
931 );
932 def_temporal_test!(
933 test_filter_duration_nanosecond,
934 DurationNanosecondArray,
935 DurationNanosecondArray::from(vec![1, 2, 3, 4])
936 );
937 def_temporal_test!(
938 test_filter_timestamp_second,
939 TimestampSecondArray,
940 TimestampSecondArray::from(vec![1, 2, 3, 4])
941 );
942 def_temporal_test!(
943 test_filter_timestamp_millisecond,
944 TimestampMillisecondArray,
945 TimestampMillisecondArray::from(vec![1, 2, 3, 4])
946 );
947 def_temporal_test!(
948 test_filter_timestamp_microsecond,
949 TimestampMicrosecondArray,
950 TimestampMicrosecondArray::from(vec![1, 2, 3, 4])
951 );
952 def_temporal_test!(
953 test_filter_timestamp_nanosecond,
954 TimestampNanosecondArray,
955 TimestampNanosecondArray::from(vec![1, 2, 3, 4])
956 );
957
958 #[test]
959 fn test_filter_array_slice() {
960 let a = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
961 let b = BooleanArray::from(vec![true, false, false, true]);
962 let c = filter(&a, &b).unwrap();
966 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
967 assert_eq!(2, d.len());
968 assert_eq!(6, d.value(0));
969 assert_eq!(9, d.value(1));
970 }
971
972 #[test]
973 fn test_filter_array_low_density() {
974 let mut data_values = (1..=65).collect::<Vec<i32>>();
976 let mut filter_values = (1..=65).map(|i| matches!(i % 65, 0)).collect::<Vec<bool>>();
977 data_values.extend_from_slice(&[66, 67]);
979 filter_values.extend_from_slice(&[false, true]);
980 let a = Int32Array::from(data_values);
981 let b = BooleanArray::from(filter_values);
982 let c = filter(&a, &b).unwrap();
983 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
984 assert_eq!(2, d.len());
985 assert_eq!(65, d.value(0));
986 assert_eq!(67, d.value(1));
987 }
988
989 #[test]
990 fn test_filter_array_high_density() {
991 let mut data_values = (1..=65).map(Some).collect::<Vec<_>>();
993 let mut filter_values = (1..=65)
994 .map(|i| !matches!(i % 65, 0))
995 .collect::<Vec<bool>>();
996 data_values[1] = None;
998 data_values.extend_from_slice(&[Some(66), None, Some(67), None]);
1000 filter_values.extend_from_slice(&[false, true, true, true]);
1001 let a = Int32Array::from(data_values);
1002 let b = BooleanArray::from(filter_values);
1003 let c = filter(&a, &b).unwrap();
1004 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1005 assert_eq!(67, d.len());
1006 assert_eq!(3, d.null_count());
1007 assert_eq!(1, d.value(0));
1008 assert!(d.is_null(1));
1009 assert_eq!(64, d.value(63));
1010 assert!(d.is_null(64));
1011 assert_eq!(67, d.value(65));
1012 }
1013
1014 #[test]
1015 fn test_filter_string_array_simple() {
1016 let a = StringArray::from(vec!["hello", " ", "world", "!"]);
1017 let b = BooleanArray::from(vec![true, false, true, false]);
1018 let c = filter(&a, &b).unwrap();
1019 let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1020 assert_eq!(2, d.len());
1021 assert_eq!("hello", d.value(0));
1022 assert_eq!("world", d.value(1));
1023 }
1024
1025 #[test]
1026 fn test_filter_primitive_array_with_null() {
1027 let a = Int32Array::from(vec![Some(5), None]);
1028 let b = BooleanArray::from(vec![false, true]);
1029 let c = filter(&a, &b).unwrap();
1030 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1031 assert_eq!(1, d.len());
1032 assert!(d.is_null(0));
1033 }
1034
1035 #[test]
1036 fn test_filter_string_array_with_null() {
1037 let a = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
1038 let b = BooleanArray::from(vec![true, false, false, true]);
1039 let c = filter(&a, &b).unwrap();
1040 let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1041 assert_eq!(2, d.len());
1042 assert_eq!("hello", d.value(0));
1043 assert!(!d.is_null(0));
1044 assert!(d.is_null(1));
1045 }
1046
1047 #[test]
1048 fn test_filter_binary_array_with_null() {
1049 let data: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
1050 let a = BinaryArray::from(data);
1051 let b = BooleanArray::from(vec![true, false, false, true]);
1052 let c = filter(&a, &b).unwrap();
1053 let d = c.as_ref().as_any().downcast_ref::<BinaryArray>().unwrap();
1054 assert_eq!(2, d.len());
1055 assert_eq!(b"hello", d.value(0));
1056 assert!(!d.is_null(0));
1057 assert!(d.is_null(1));
1058 }
1059
1060 fn _test_filter_byte_view<T>()
1061 where
1062 T: ByteViewType,
1063 str: AsRef<T::Native>,
1064 T::Native: PartialEq,
1065 {
1066 let array = {
1067 let mut builder = GenericByteViewBuilder::<T>::new();
1069 builder.append_value("hello");
1070 builder.append_value("world");
1071 builder.append_null();
1072 builder.append_value("large payload over 12 bytes");
1073 builder.append_value("lulu");
1074 builder.finish()
1075 };
1076
1077 {
1078 let predicate = BooleanArray::from(vec![true, false, true, true, false]);
1079 let actual = filter(&array, &predicate).unwrap();
1080
1081 assert_eq!(actual.len(), 3);
1082
1083 let expected = {
1084 let mut builder = GenericByteViewBuilder::<T>::new();
1086 builder.append_value("hello");
1087 builder.append_null();
1088 builder.append_value("large payload over 12 bytes");
1089 builder.finish()
1090 };
1091
1092 assert_eq!(actual.as_ref(), &expected);
1093 }
1094
1095 {
1096 let predicate = BooleanArray::from(vec![true, false, false, false, true]);
1097 let actual = filter(&array, &predicate).unwrap();
1098
1099 assert_eq!(actual.len(), 2);
1100
1101 let expected = {
1102 let mut builder = GenericByteViewBuilder::<T>::new();
1104 builder.append_value("hello");
1105 builder.append_value("lulu");
1106 builder.finish()
1107 };
1108
1109 assert_eq!(actual.as_ref(), &expected);
1110 }
1111 }
1112
1113 #[test]
1114 fn test_filter_string_view() {
1115 _test_filter_byte_view::<StringViewType>()
1116 }
1117
1118 #[test]
1119 fn test_filter_binary_view() {
1120 _test_filter_byte_view::<BinaryViewType>()
1121 }
1122
1123 #[test]
1124 fn test_filter_fixed_binary() {
1125 let v1 = [1_u8, 2];
1126 let v2 = [3_u8, 4];
1127 let v3 = [5_u8, 6];
1128 let v = vec![&v1, &v2, &v3];
1129 let a = FixedSizeBinaryArray::from(v);
1130 let b = BooleanArray::from(vec![true, false, true]);
1131 let c = filter(&a, &b).unwrap();
1132 let d = c
1133 .as_ref()
1134 .as_any()
1135 .downcast_ref::<FixedSizeBinaryArray>()
1136 .unwrap();
1137 assert_eq!(d.len(), 2);
1138 assert_eq!(d.value(0), &v1);
1139 assert_eq!(d.value(1), &v3);
1140 let c2 = FilterBuilder::new(&b)
1141 .optimize()
1142 .build()
1143 .filter(&a)
1144 .unwrap();
1145 let d2 = c2
1146 .as_ref()
1147 .as_any()
1148 .downcast_ref::<FixedSizeBinaryArray>()
1149 .unwrap();
1150 assert_eq!(d, d2);
1151
1152 let b = BooleanArray::from(vec![false, false, false]);
1153 let c = filter(&a, &b).unwrap();
1154 let d = c
1155 .as_ref()
1156 .as_any()
1157 .downcast_ref::<FixedSizeBinaryArray>()
1158 .unwrap();
1159 assert_eq!(d.len(), 0);
1160
1161 let b = BooleanArray::from(vec![true, true, true]);
1162 let c = filter(&a, &b).unwrap();
1163 let d = c
1164 .as_ref()
1165 .as_any()
1166 .downcast_ref::<FixedSizeBinaryArray>()
1167 .unwrap();
1168 assert_eq!(d.len(), 3);
1169 assert_eq!(d.value(0), &v1);
1170 assert_eq!(d.value(1), &v2);
1171 assert_eq!(d.value(2), &v3);
1172
1173 let b = BooleanArray::from(vec![false, false, true]);
1174 let c = filter(&a, &b).unwrap();
1175 let d = c
1176 .as_ref()
1177 .as_any()
1178 .downcast_ref::<FixedSizeBinaryArray>()
1179 .unwrap();
1180 assert_eq!(d.len(), 1);
1181 assert_eq!(d.value(0), &v3);
1182 let c2 = FilterBuilder::new(&b)
1183 .optimize()
1184 .build()
1185 .filter(&a)
1186 .unwrap();
1187 let d2 = c2
1188 .as_ref()
1189 .as_any()
1190 .downcast_ref::<FixedSizeBinaryArray>()
1191 .unwrap();
1192 assert_eq!(d, d2);
1193 }
1194
#[test]
fn test_filter_array_slice_with_null() {
    // After slicing off the leading Some(5) the input reads [None, 7, 8, 9].
    let sliced = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
    let mask = BooleanArray::from(vec![true, false, false, true]);

    let output = filter(&sliced, &mask).unwrap();
    let ints = output.as_any().downcast_ref::<Int32Array>().unwrap();

    // Rows 0 and 3 survive: the null is preserved, followed by 9.
    assert_eq!(ints.iter().collect::<Vec<_>>(), vec![None, Some(9)]);
}
1209
#[test]
fn test_filter_run_end_encoding_array() {
    // Logical values: [7, 7, -2, 9, 9, 9, 9, 9]; even positions are kept.
    let source = RunArray::try_new(
        &Int64Array::from(vec![2, 3, 8]),
        &Int64Array::from(vec![7, -2, 9]),
    )
    .expect("Failed to create RunArray");
    let mask = BooleanArray::from(vec![true, false, true, false, true, false, true, false]);

    let output = filter(&source, &mask).unwrap();
    let runs: &RunArray<Int64Type> = as_run_array(&output);
    assert_eq!(4, runs.len());

    // Kept values [7, -2, 9, 9] re-encode as three runs.
    let expected = RunArray::try_new(
        &Int64Array::from(vec![1, 2, 4]),
        &Int64Array::from(vec![7, -2, 9]),
    )
    .expect("Failed to make expected RunArray test is broken");

    assert_eq!(runs.run_ends().values(), expected.run_ends().values());
    assert_eq!(runs.values(), expected.values());
}
1229
#[test]
fn test_filter_run_end_encoding_array_remove_value() {
    // Logical values: [7, 7, -2, 9, 9, 9, 9, 9, -8, -8].
    let source = RunArray::try_new(
        &Int32Array::from(vec![2, 3, 8, 10]),
        &Int32Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    // Keep rows 1, 4 and 6: the -2 and -8 runs vanish entirely.
    let mask = BooleanArray::from(vec![
        false, true, false, false, true, false, true, false, false, false,
    ]);

    let output = filter(&source, &mask).unwrap();
    let runs: &RunArray<Int32Type> = as_run_array(&output);
    assert_eq!(3, runs.len());

    let expected =
        RunArray::try_new(&Int32Array::from(vec![1, 3]), &Int32Array::from(vec![7, 9]))
            .expect("Failed to make expected RunArray test is broken");

    assert_eq!(runs.run_ends().values(), expected.run_ends().values());
    assert_eq!(runs.values(), expected.values());
}
1249
#[test]
fn test_filter_run_end_encoding_array_remove_all_but_one() {
    // Logical values: [7, 7, -2, 9, 9, 9, 9, 9, -8, -8].
    let source = RunArray::try_new(
        &Int16Array::from(vec![2, 3, 8, 10]),
        &Int16Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    // Only row 6 (value 9) is selected.
    let mask = BooleanArray::from(vec![
        false, false, false, false, false, false, true, false, false, false,
    ]);

    let output = filter(&source, &mask).unwrap();
    let runs: &RunArray<Int16Type> = as_run_array(&output);
    assert_eq!(1, runs.len());

    let expected = RunArray::try_new(&Int16Array::from(vec![1]), &Int16Array::from(vec![9]))
        .expect("Failed to make expected RunArray test is broken");

    assert_eq!(runs.run_ends().values(), expected.run_ends().values());
    assert_eq!(runs.values(), expected.values());
}
1268
#[test]
fn test_filter_run_end_encoding_array_empty() {
    let source = RunArray::try_new(
        &Int64Array::from(vec![2, 3, 8, 10]),
        &Int64Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");

    // An all-false mask over the ten logical rows removes everything.
    let mask = BooleanArray::from(vec![false; 10]);

    let output = filter(&source, &mask).unwrap();
    let runs: &RunArray<Int64Type> = as_run_array(&output);
    assert_eq!(0, runs.len());
}
1281
#[test]
fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() {
    // Ten logical rows, but the predicate only covers the first three.
    let source = RunArray::try_new(
        &Int64Array::from(vec![2, 3, 8, 10]),
        &Int64Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    let mask = BooleanArray::from(vec![false, true, true]);

    let output = filter(&source, &mask).unwrap();
    let runs: &RunArray<Int64Type> = as_run_array(&output);
    assert_eq!(2, runs.len());

    // Rows 1 and 2 hold 7 and -2: one run each.
    let expected = RunArray::try_new(
        &Int64Array::from(vec![1, 2]),
        &Int64Array::from(vec![7, -2]),
    )
    .expect("Failed to make expected RunArray test is broken");

    assert_eq!(runs.run_ends().values(), expected.run_ends().values());
    assert_eq!(runs.values(), expected.values());
}
1301
#[test]
fn test_filter_dictionary_array() {
    let dict: Int8DictionaryArray = [Some("hello"), None, Some("world"), Some("!")]
        .iter()
        .copied()
        .collect();
    let mask = BooleanArray::from(vec![false, true, true, false]);

    let output = filter(&dict, &mask).unwrap();
    let filtered = output
        .as_any()
        .downcast_ref::<Int8DictionaryArray>()
        .unwrap();

    // The three distinct strings remain in the dictionary; only keys are selected.
    let dictionary = filtered
        .values()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    assert_eq!(3, dictionary.len());

    // Surviving rows: the null, then "world".
    assert_eq!(2, filtered.len());
    assert!(filtered.is_null(0));
    let key = filtered.keys().value(1) as usize;
    assert_eq!("world", dictionary.value(key));
}
1322
#[test]
fn test_filter_list_array() {
    let list_type =
        || DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));

    // Input lists: [0, 1, 2], [3, 4, 5], [6, 7], and a null (validity 0b0111).
    let child = ArrayData::builder(DataType::Int32)
        .len(8)
        .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
        .build()
        .unwrap();
    let input = LargeListArray::from(
        ArrayData::builder(list_type())
            .len(4)
            .add_buffer(Buffer::from_slice_ref([0i64, 3, 6, 8, 8]))
            .add_child_data(child)
            .null_bit_buffer(Some(Buffer::from([0b00000111])))
            .build()
            .unwrap(),
    );

    // Rows 1 and 3 survive: [3, 4, 5] followed by the null entry.
    let mask = BooleanArray::from(vec![false, true, false, true]);
    let actual = filter(&input, &mask).unwrap();

    let expected_child = ArrayData::builder(DataType::Int32)
        .len(3)
        .add_buffer(Buffer::from_slice_ref([3, 4, 5]))
        .build()
        .unwrap();
    let expected = ArrayData::builder(list_type())
        .len(2)
        .add_buffer(Buffer::from_slice_ref([0i64, 3, 3]))
        .add_child_data(expected_child)
        .null_bit_buffer(Some(Buffer::from([0b00000001])))
        .build()
        .unwrap();

    assert_eq!(&make_array(expected), &actual);
}
1369
#[test]
fn test_slice_iterator_bits() {
    // Only bit 1 of a 64-bit mask is set.
    let mask = BooleanArray::from((0..64).map(|i| i == 1).collect::<Vec<bool>>());
    let set_bits = filter_count(&mask);

    let slices: Vec<_> = SlicesIterator::new(&mask).collect();

    assert_eq!(slices, vec![(1, 2)]);
    assert_eq!(set_bits, 1);
}
1382
#[test]
fn test_slice_iterator_bits1() {
    // Every bit except bit 1 of a 64-bit mask is set.
    let mask = BooleanArray::from((0..64).map(|i| i != 1).collect::<Vec<bool>>());
    let set_bits = filter_count(&mask);

    let slices: Vec<_> = SlicesIterator::new(&mask).collect();

    assert_eq!(slices, vec![(0, 1), (2, 64)]);
    assert_eq!(set_bits, 64 - 1);
}
1395
#[test]
fn test_slice_iterator_chunk_and_bits() {
    // Bits 0, 62 and 124 are clear; the middle run crosses the 64-bit boundary.
    let mask = BooleanArray::from((0..130).map(|i| i % 62 != 0).collect::<Vec<bool>>());
    let set_bits = filter_count(&mask);

    let slices: Vec<_> = SlicesIterator::new(&mask).collect();

    assert_eq!(slices, vec![(1, 62), (63, 124), (125, 130)]);
    assert_eq!(set_bits, 61 + 61 + 5);
}
1408
#[test]
fn test_null_mask() {
    let values = Int64Array::from(vec![Some(1), Some(2), None]);

    // A null predicate slot acts like `false`, so only the first two rows remain.
    let mask = BooleanArray::from(vec![Some(true), Some(true), None]);
    let output = filter(&values, &mask).unwrap();

    assert_eq!(output.as_ref(), &values.slice(0, 2));
}
1417
#[test]
fn test_filter_record_batch_no_columns() {
    // A column-less batch only carries an explicit row count.
    let options = RecordBatchOptions::default().with_row_count(Some(100));
    let batch =
        RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &options).unwrap();

    // Two true slots and one null slot: the output must report two rows.
    let predicate = BooleanArray::from(vec![Some(true), Some(true), None]);
    let filtered = filter_record_batch(&batch, &predicate).unwrap();

    assert_eq!(filtered.num_rows(), 2);
}
1428
#[test]
fn test_fast_path() {
    let input: PrimitiveArray<Int64Type> = PrimitiveArray::from(vec![Some(1), Some(2), None]);

    // An all-true mask must return the input unchanged.
    let keep_all = filter(&input, &BooleanArray::from(vec![true, true, true])).unwrap();
    let kept = keep_all
        .as_any()
        .downcast_ref::<PrimitiveArray<Int64Type>>()
        .unwrap();
    assert_eq!(&input, kept);

    // An all-false mask must return an empty array of the same type.
    let keep_none = filter(&input, &BooleanArray::from(vec![false, false, false])).unwrap();
    assert_eq!(keep_none.len(), 0);
    assert_eq!(keep_none.data_type(), &DataType::Int64);
}
1448
#[test]
fn test_slices() {
    // 10 set, 30 clear, 20 set, 17 clear, 4 set: 81 bits in total.
    let runs = [(true, 10), (false, 30), (true, 20), (false, 17), (true, 4)];
    let mask: BooleanArray = runs
        .iter()
        .flat_map(|&(bit, count)| std::iter::repeat_n(bit, count))
        .map(Some)
        .collect();

    let slices: Vec<_> = SlicesIterator::new(&mask).collect();
    assert_eq!(slices, vec![(0, 10), (40, 60), (77, 81)]);

    // A view starting at bit 7 that drops the last 3 bits; slice bounds are
    // reported relative to the view.
    let total = mask.len();
    let view = mask.slice(7, total - 10);
    let view = view.as_any().downcast_ref::<BooleanArray>().unwrap();
    let slices: Vec<_> = SlicesIterator::new(view).collect();
    assert_eq!(slices, vec![(0, 3), (33, 53), (70, 71)]);
}
1475
1476 fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) {
1477 let mut rng = rng();
1478
1479 let bools: Vec<bool> = std::iter::from_fn(|| Some(rng.random()))
1480 .take(mask_len)
1481 .collect();
1482
1483 let buffer = Buffer::from_iter(bools.iter().cloned());
1484
1485 let truncated_length = mask_len - offset - truncate;
1486
1487 let data = ArrayDataBuilder::new(DataType::Boolean)
1488 .len(truncated_length)
1489 .offset(offset)
1490 .add_buffer(buffer)
1491 .build()
1492 .unwrap();
1493
1494 let filter = BooleanArray::from(data);
1495
1496 let slice_bits: Vec<_> = SlicesIterator::new(&filter)
1497 .flat_map(|(start, end)| start..end)
1498 .collect();
1499
1500 let count = filter_count(&filter);
1501 let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect();
1502
1503 let expected_bits: Vec<_> = bools
1504 .iter()
1505 .skip(offset)
1506 .take(truncated_length)
1507 .enumerate()
1508 .flat_map(|(idx, v)| v.then(|| idx))
1509 .collect();
1510
1511 assert_eq!(slice_bits, expected_bits);
1512 assert_eq!(index_bits, expected_bits);
1513 }
1514
#[test]
#[cfg_attr(miri, ignore)]
fn fuzz_test_slices_iterator() {
    let mut rng = rng();
    let any_usize = UniformUsize::new(usize::MIN, usize::MAX).unwrap();

    // Reduce `raw` into `0..bound`; a zero bound (no room) maps to 0.
    let bounded = |raw: usize, bound: usize| raw.checked_rem(bound).unwrap_or(0);

    for _ in 0..100 {
        let mask_len = rng.random_range(0..1024);
        let offset = bounded(any_usize.sample(&mut rng), 64.min(mask_len));
        let truncate = bounded(any_usize.sample(&mut rng), 128.min(mask_len - offset));
        test_slices_fuzz(mask_len, offset, truncate);
    }

    // Deterministic cases: aligned, byte-offset, truncated and odd-offset views.
    for (mask_len, offset, truncate) in [(64, 0, 0), (64, 8, 0), (64, 8, 8), (32, 8, 8), (32, 5, 9)]
    {
        test_slices_fuzz(mask_len, offset, truncate);
    }
}
1541
/// Reference filter: keeps each value whose paired predicate entry is `true`.
/// Pairing stops at whichever of `values` / `predicate` ends first.
fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> {
    let mut kept = Vec::new();
    for (value, &keep) in values.into_iter().zip(predicate) {
        if keep {
            kept.push(value);
        }
    }
    kept
}
1551
1552 fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>>
1554 where
1555 StandardUniform: Distribution<T>,
1556 {
1557 let mut rng = rng();
1558 (0..len)
1559 .map(|_| rng.random_bool(valid_percent).then(|| rng.random()))
1560 .collect()
1561 }
1562
1563 fn gen_strings(
1565 len: usize,
1566 valid_percent: f64,
1567 str_len_range: std::ops::Range<usize>,
1568 ) -> Vec<Option<String>> {
1569 let mut rng = rng();
1570 (0..len)
1571 .map(|_| {
1572 rng.random_bool(valid_percent).then(|| {
1573 let len = rng.random_range(str_len_range.clone());
1574 (0..len)
1575 .map(|_| char::from(rng.sample(Alphanumeric)))
1576 .collect()
1577 })
1578 })
1579 .collect()
1580 }
1581
/// Borrows each `Option<T>` in `src` as `Option<&T::Target>`
/// (e.g. `Option<String>` becomes `Option<&str>`).
fn as_deref<T: std::ops::Deref>(src: &[Option<T>]) -> impl Iterator<Item = Option<&T::Target>> {
    src.iter().map(Option::as_deref)
}
1586
#[test]
#[cfg_attr(miri, ignore)]
fn fuzz_filter() {
    let mut rng = rng();

    for iteration in 0..100 {
        // Pin a few all-true (1.0) and all-false (0.0) masks before going random.
        let filter_percent = if iteration <= 4 {
            1.
        } else if iteration <= 10 {
            0.
        } else {
            rng.random_range(0.0..1.0)
        };

        let valid_percent = rng.random_range(0.0..1.0);

        let array_len = rng.random_range(32..256);
        let array_offset = rng.random_range(0..10);

        let filter_offset = rng.random_range(0..10);
        let filter_truncate = rng.random_range(0..10);

        // Over-generate mask bits, then slice so the predicate has a non-zero
        // offset and is `filter_truncate` rows shorter than the values.
        let mask_bits: Vec<_> = std::iter::from_fn(|| Some(rng.random_bool(filter_percent)))
            .take(array_len + filter_offset - filter_truncate)
            .collect();
        let full_mask = BooleanArray::from_iter(mask_bits.iter().cloned().map(Some));
        let mask_view = full_mask.slice(filter_offset, array_len - filter_truncate);
        let mask = mask_view.as_any().downcast_ref::<BooleanArray>().unwrap();
        let expected_mask = &mask_bits[filter_offset..];

        // Int32 values.
        let ints: Vec<Option<i32>> = gen_primitive(array_len + array_offset, valid_percent);
        let int_array = Int32Array::from_iter(ints.iter().cloned());
        let int_view = int_array.slice(array_offset, array_len);
        let int_input = int_view.as_any().downcast_ref::<Int32Array>().unwrap();

        let int_output = filter(int_input, mask).unwrap();
        let int_result = int_output.as_any().downcast_ref::<Int32Array>().unwrap();
        let actual_ints: Vec<_> = int_result.iter().collect();
        assert_eq!(
            actual_ints,
            filter_rust(ints[array_offset..].iter().cloned(), expected_mask)
        );

        // Utf8 values.
        let strings = gen_strings(array_len + array_offset, valid_percent, 0..20);
        let expected_strings = filter_rust(as_deref(&strings[array_offset..]), expected_mask);

        let str_array = StringArray::from_iter(as_deref(&strings));
        let str_view = str_array.slice(array_offset, array_len);
        let str_input = str_view.as_any().downcast_ref::<StringArray>().unwrap();

        let str_output = filter(str_input, mask).unwrap();
        let str_result = str_output.as_any().downcast_ref::<StringArray>().unwrap();
        let actual_strings: Vec<_> = str_result.iter().collect();
        assert_eq!(actual_strings, expected_strings);

        // Dictionary-encoded Utf8 values, decoded back through the keys.
        let dict_array = DictionaryArray::<Int32Type>::from_iter(as_deref(&strings));
        let dict_view = dict_array.slice(array_offset, array_len);
        let dict_input = dict_view
            .as_any()
            .downcast_ref::<DictionaryArray<Int32Type>>()
            .unwrap();

        let dict_output = filter(dict_input, mask).unwrap();
        let dict_result = dict_output
            .as_any()
            .downcast_ref::<DictionaryArray<Int32Type>>()
            .unwrap();
        let dictionary = dict_result
            .values()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let actual_decoded: Vec<_> = dict_result
            .keys()
            .iter()
            .map(|key| key.map(|key| dictionary.value(key as usize)))
            .collect();
        assert_eq!(actual_decoded, expected_strings);
    }
}
1677
#[test]
fn test_filter_map() {
    // Appends one valid map row made of `entries`.
    fn push_row(builder: &mut MapBuilder<StringBuilder, Int64Builder>, entries: &[(&str, i64)]) {
        for &(key, value) in entries {
            builder.keys().append_value(key);
            builder.values().append_value(value);
        }
        builder.append(true).unwrap();
    }

    // Input rows: {key1: 1}, {key2: 2, key3: 3}, null, {key1: 1}.
    let mut input = MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(4));
    push_row(&mut input, &[("key1", 1)]);
    push_row(&mut input, &[("key2", 2), ("key3", 3)]);
    input.append(false).unwrap();
    push_row(&mut input, &[("key1", 1)]);
    let maparray = Arc::new(input.finish()) as ArrayRef;

    // Keep only the first and last rows.
    let mask = vec![Some(true), Some(false), Some(false), Some(true)]
        .into_iter()
        .collect::<BooleanArray>();
    let got = filter(&maparray, &mask).unwrap();

    let mut expected = MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(2));
    push_row(&mut expected, &[("key1", 1)]);
    push_row(&mut expected, &[("key1", 1)]);
    let expected = Arc::new(expected.finish()) as ArrayRef;

    assert_eq!(&expected, &got);
}
1714
#[test]
fn test_filter_fixed_size_list_arrays() {
    // Copies the Int32 payload of list `idx` out of `lists`.
    fn list_at(lists: &FixedSizeListArray, idx: usize) -> Vec<i32> {
        let item = lists.value(idx);
        item.as_any()
            .downcast_ref::<Int32Array>()
            .unwrap()
            .values()
            .to_vec()
    }

    // Three lists of width 3: [0, 1, 2], [3, 4, 5], [6, 7, 8].
    let child = ArrayData::builder(DataType::Int32)
        .len(9)
        .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8]))
        .build()
        .unwrap();
    let array = FixedSizeListArray::from(
        ArrayData::builder(DataType::new_fixed_size_list(DataType::Int32, 3, false))
            .len(3)
            .add_child_data(child)
            .build()
            .unwrap(),
    );

    // Only the first list.
    let first_only = filter(&array, &BooleanArray::from(vec![true, false, false])).unwrap();
    let lists = first_only
        .as_any()
        .downcast_ref::<FixedSizeListArray>()
        .unwrap();
    assert_eq!(lists.len(), 1);
    assert_eq!(list_at(lists, 0), vec![0, 1, 2]);

    // First and last lists.
    let first_and_last = filter(&array, &BooleanArray::from(vec![true, false, true])).unwrap();
    let lists = first_and_last
        .as_any()
        .downcast_ref::<FixedSizeListArray>()
        .unwrap();
    assert_eq!(lists.len(), 2);
    assert_eq!(list_at(lists, 0), vec![0, 1, 2]);
    assert_eq!(list_at(lists, 1), vec![6, 7, 8]);
}
1761
#[test]
fn test_filter_fixed_size_list_arrays_with_null() {
    // Five lists of width 2 over the values 0..10.
    let child = ArrayData::builder(DataType::Int32)
        .len(10)
        .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
        .build()
        .unwrap();

    // Lists 0, 3 and 4 are valid; lists 1 and 2 are null.
    let mut validity = [0_u8; 1];
    for slot in [0, 3, 4] {
        bit_util::set_bit(&mut validity, slot);
    }

    let array = FixedSizeListArray::from(
        ArrayData::builder(DataType::new_fixed_size_list(DataType::Int32, 2, false))
            .len(5)
            .add_child_data(child)
            .null_bit_buffer(Some(Buffer::from(validity)))
            .build()
            .unwrap(),
    );

    // Keep rows 0, 1 and 3, giving [0, 1], null, [6, 7].
    let mask = BooleanArray::from(vec![true, true, false, true, false]);
    let output = filter(&array, &mask).unwrap();
    let lists = output
        .as_any()
        .downcast_ref::<FixedSizeListArray>()
        .unwrap();
    assert_eq!(lists.len(), 3);

    let payload = |idx: usize| -> Vec<i32> {
        let item = lists.value(idx);
        item.as_any()
            .downcast_ref::<Int32Array>()
            .unwrap()
            .values()
            .to_vec()
    };
    assert_eq!(payload(0), vec![0, 1]);
    assert!(lists.is_null(1));
    assert_eq!(payload(2), vec![6, 7]);
}
1806
1807 fn test_filter_union_array(array: UnionArray) {
1808 let filter_array = BooleanArray::from(vec![true, false, false]);
1809 let c = filter(&array, &filter_array).unwrap();
1810 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1811
1812 let mut builder = UnionBuilder::new_dense();
1813 builder.append::<Int32Type>("A", 1).unwrap();
1814 let expected_array = builder.build().unwrap();
1815
1816 compare_union_arrays(filtered, &expected_array);
1817
1818 let filter_array = BooleanArray::from(vec![true, false, true]);
1819 let c = filter(&array, &filter_array).unwrap();
1820 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1821
1822 let mut builder = UnionBuilder::new_dense();
1823 builder.append::<Int32Type>("A", 1).unwrap();
1824 builder.append::<Int32Type>("A", 34).unwrap();
1825 let expected_array = builder.build().unwrap();
1826
1827 compare_union_arrays(filtered, &expected_array);
1828
1829 let filter_array = BooleanArray::from(vec![true, true, false]);
1830 let c = filter(&array, &filter_array).unwrap();
1831 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1832
1833 let mut builder = UnionBuilder::new_dense();
1834 builder.append::<Int32Type>("A", 1).unwrap();
1835 builder.append::<Float64Type>("B", 3.2).unwrap();
1836 let expected_array = builder.build().unwrap();
1837
1838 compare_union_arrays(filtered, &expected_array);
1839 }
1840
#[test]
fn test_filter_union_array_dense() {
    // Dense union holding [A=1, B=3.2, A=34].
    let mut dense = UnionBuilder::new_dense();
    dense.append::<Int32Type>("A", 1).unwrap();
    dense.append::<Float64Type>("B", 3.2).unwrap();
    dense.append::<Int32Type>("A", 34).unwrap();

    test_filter_union_array(dense.build().unwrap());
}
1851
#[test]
fn test_filter_run_union_array_dense() {
    // Dense union whose rows all use the same child: [A=1, A=3, A=34].
    let mut input = UnionBuilder::new_dense();
    for value in [1, 3, 34] {
        input.append::<Int32Type>("A", value).unwrap();
    }
    let array = input.build().unwrap();

    let output = filter(&array, &BooleanArray::from(vec![true, true, false])).unwrap();
    let filtered = output.as_any().downcast_ref::<UnionArray>().unwrap();

    let mut expected = UnionBuilder::new_dense();
    for value in [1, 3] {
        expected.append::<Int32Type>("A", value).unwrap();
    }
    let expected = expected.build().unwrap();

    assert_eq!(filtered.to_data(), expected.to_data());
}
1871
#[test]
fn test_filter_union_array_dense_with_nulls() {
    // Dense union holding [A=1, B=3.2, B=null, A=34].
    let mut input = UnionBuilder::new_dense();
    input.append::<Int32Type>("A", 1).unwrap();
    input.append::<Float64Type>("B", 3.2).unwrap();
    input.append_null::<Float64Type>("B").unwrap();
    input.append::<Int32Type>("A", 34).unwrap();
    let array = input.build().unwrap();

    // Filters `array` by `mask` and checks the result against `expected`.
    let check = |mask: Vec<bool>, expected: UnionArray| {
        let output = filter(&array, &BooleanArray::from(mask)).unwrap();
        let filtered = output.as_any().downcast_ref::<UnionArray>().unwrap();
        compare_union_arrays(filtered, &expected);
    };

    // Rows 0 and 1: both non-null.
    let mut expected = UnionBuilder::new_dense();
    expected.append::<Int32Type>("A", 1).unwrap();
    expected.append::<Float64Type>("B", 3.2).unwrap();
    check(vec![true, true, false, false], expected.build().unwrap());

    // Rows 0 and 2: the null Float64 slot must survive as null.
    let mut expected = UnionBuilder::new_dense();
    expected.append::<Int32Type>("A", 1).unwrap();
    expected.append_null::<Float64Type>("B").unwrap();
    check(vec![true, false, true, false], expected.build().unwrap());
}
1903
#[test]
fn test_filter_union_array_sparse() {
    // Sparse union holding [A=1, B=3.2, A=34].
    let mut sparse = UnionBuilder::new_sparse();
    sparse.append::<Int32Type>("A", 1).unwrap();
    sparse.append::<Float64Type>("B", 3.2).unwrap();
    sparse.append::<Int32Type>("A", 34).unwrap();

    test_filter_union_array(sparse.build().unwrap());
}
1914
#[test]
fn test_filter_union_array_sparse_with_nulls() {
    // Sparse union holding [A=1, B=3.2, B=null, A=34].
    let mut input = UnionBuilder::new_sparse();
    input.append::<Int32Type>("A", 1).unwrap();
    input.append::<Float64Type>("B", 3.2).unwrap();
    input.append_null::<Float64Type>("B").unwrap();
    input.append::<Int32Type>("A", 34).unwrap();
    let array = input.build().unwrap();

    // Rows 0 and 2: the null Float64 slot must survive as null.
    let mask = BooleanArray::from(vec![true, false, true, false]);
    let output = filter(&array, &mask).unwrap();
    let filtered = output.as_any().downcast_ref::<UnionArray>().unwrap();

    let mut expected = UnionBuilder::new_sparse();
    expected.append::<Int32Type>("A", 1).unwrap();
    expected.append_null::<Float64Type>("B").unwrap();

    compare_union_arrays(filtered, &expected.build().unwrap());
}
1935
1936 fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) {
1937 assert_eq!(union1.len(), union2.len());
1938
1939 for i in 0..union1.len() {
1940 let type_id = union1.type_id(i);
1941
1942 let slot1 = union1.value(i);
1943 let slot2 = union2.value(i);
1944
1945 assert_eq!(slot1.is_null(0), slot2.is_null(0));
1946
1947 if !slot1.is_null(0) && !slot2.is_null(0) {
1948 match type_id {
1949 0 => {
1950 let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
1951 assert_eq!(slot1.len(), 1);
1952 let value1 = slot1.value(0);
1953
1954 let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap();
1955 assert_eq!(slot2.len(), 1);
1956 let value2 = slot2.value(0);
1957 assert_eq!(value1, value2);
1958 }
1959 1 => {
1960 let slot1 = slot1.as_any().downcast_ref::<Float64Array>().unwrap();
1961 assert_eq!(slot1.len(), 1);
1962 let value1 = slot1.value(0);
1963
1964 let slot2 = slot2.as_any().downcast_ref::<Float64Array>().unwrap();
1965 assert_eq!(slot2.len(), 1);
1966 let value2 = slot2.value(0);
1967 assert_eq!(value1, value2);
1968 }
1969 _ => unreachable!(),
1970 }
1971 }
1972 }
1973 }
1974
#[test]
fn test_filter_struct() {
    let predicate = BooleanArray::from(vec![true, false, true, false]);

    let a: ArrayRef = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
    let a_filtered: ArrayRef = Arc::new(StringArray::from(vec!["hello", "world"]));

    let b: ArrayRef = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
    let b_filtered: ArrayRef = Arc::new(Int32Array::from(vec![5, 7]));

    let null_mask = NullBuffer::from(vec![true, false, false, true]);
    let null_mask_filtered = NullBuffer::from(vec![true, false]);

    let a_field = Field::new("a", DataType::Utf8, false);
    let b_field = Field::new("b", DataType::Int32, false);

    // Each layout is (fields, input columns, expected filtered columns).
    let layouts: [(Vec<Field>, Vec<ArrayRef>, Vec<ArrayRef>); 2] = [
        (vec![a_field.clone()], vec![a.clone()], vec![a_filtered.clone()]),
        (vec![a_field, b_field], vec![a, b], vec![a_filtered, b_filtered]),
    ];

    // Every layout is checked without a struct-level null mask, then with one.
    for (field_list, input_columns, expected_columns) in layouts {
        let fields: Fields = field_list.into();
        let null_cases = [
            (None, None),
            (Some(null_mask.clone()), Some(null_mask_filtered.clone())),
        ];
        for (input_nulls, expected_nulls) in null_cases {
            let array = StructArray::new(fields.clone(), input_columns.clone(), input_nulls);
            let expected =
                StructArray::new(fields.clone(), expected_columns.clone(), expected_nulls);

            let result = filter(&array, &predicate).unwrap();

            assert_eq!(result.to_data(), expected.to_data());
        }
    }
}
2045
#[test]
fn test_filter_empty_struct() {
    // Column `a` is a struct with an Int64 child `b` and a field-less struct
    // child `c`. The field list is shared by the schema and the array, so the
    // two cannot drift apart.
    let struct_fields = Fields::from(vec![
        Field::new("b", DataType::Int64, true),
        Field::new("c", DataType::Struct(Fields::empty()), true),
    ]);
    let schema = Arc::new(Schema::new(vec![Field::new(
        "a",
        DataType::Struct(struct_fields.clone()),
        true,
    )]));

    let b: ArrayRef = Arc::new(Int64Array::from(vec![None, None, None]));
    let c: ArrayRef = Arc::new(StructArray::new_empty_fields(
        3,
        Some(NullBuffer::from(vec![true, true, true])),
    ));
    let a = StructArray::new(
        struct_fields,
        vec![b, c],
        Some(NullBuffer::from(vec![true, true, true])),
    );
    let record_batch = RecordBatch::try_new(schema, vec![Arc::new(a)]).unwrap();

    let predicate = BooleanArray::from(vec![true, false, true]);
    let filtered_batch = filter_record_batch(&record_batch, &predicate).unwrap();

    assert_eq!(filtered_batch.num_rows(), 2);

    // The struct column and both of its children must shrink to the new length,
    // including the child that has no fields of its own.
    let filtered_struct = filtered_batch
        .column(0)
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    assert_eq!(filtered_struct.len(), 2);
    assert_eq!(filtered_struct.column(0).len(), 2);
    assert_eq!(filtered_struct.column(1).len(), 2);
}
2100}