1use std::ops::AddAssign;
21use std::sync::Arc;
22
23use arrow_array::builder::BooleanBufferBuilder;
24use arrow_array::cast::AsArray;
25use arrow_array::types::{
26 ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
27};
28use arrow_array::*;
29use arrow_buffer::{
30 ArrowNativeType, BooleanBuffer, NullBuffer, OffsetBuffer, RunEndBuffer, ScalarBuffer, bit_util,
31};
32use arrow_buffer::{Buffer, MutableBuffer};
33use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
34use arrow_data::transform::MutableArrayData;
35use arrow_schema::*;
36
/// Selectivity cut-off used when picking the default iteration strategy.
///
/// When the fraction of selected rows (selected count divided by filter
/// length) is strictly greater than this value, the filter is walked as
/// contiguous ranges via [`SlicesIterator`]; otherwise it is walked one
/// selected index at a time via `IndexIterator`.
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
44
45#[derive(Debug)]
57pub struct SlicesIterator<'a>(BitSliceIterator<'a>);
58
59impl<'a> SlicesIterator<'a> {
60 pub fn new(filter: &'a BooleanArray) -> Self {
62 filter.values().into()
63 }
64}
65
66impl<'a> From<&'a BooleanBuffer> for SlicesIterator<'a> {
67 fn from(filter: &'a BooleanBuffer) -> Self {
68 Self(filter.set_slices())
69 }
70}
71
72impl Iterator for SlicesIterator<'_> {
73 type Item = (usize, usize);
74
75 fn next(&mut self) -> Option<Self::Item> {
76 self.0.next()
77 }
78}
79
80pub(crate) struct IndexIterator<'a> {
85 remaining: usize,
86 iter: BitIndexIterator<'a>,
87}
88
89impl<'a> IndexIterator<'a> {
90 pub(crate) fn new(filter: &'a BooleanArray, remaining: usize) -> Self {
91 assert_eq!(filter.null_count(), 0);
92 let iter = filter.values().set_indices();
93 Self { remaining, iter }
94 }
95
96 pub fn collect(mut self) -> Vec<usize> {
100 let len = self.remaining;
101 let mut result = Vec::with_capacity(len);
102 let ptr: *mut usize = result.as_mut_ptr();
103 for i in 0..len {
104 let next = self.iter.next();
107 debug_assert!(next.is_some(), "IndexIterator exhausted early");
108 unsafe {
109 *ptr.add(i) = next.unwrap_unchecked();
110 }
111 }
112 unsafe {
114 result.set_len(len);
115 }
116 result
117 }
118}
119
120impl Iterator for IndexIterator<'_> {
121 type Item = usize;
122
123 fn next(&mut self) -> Option<Self::Item> {
124 if self.remaining != 0 {
125 let next = self.iter.next().expect("IndexIterator exhausted early");
128 self.remaining -= 1;
129 return Some(next);
131 }
132 None
133 }
134
135 fn size_hint(&self) -> (usize, Option<usize>) {
136 (self.remaining, Some(self.remaining))
137 }
138}
139
140pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
168 let nulls = filter.nulls().unwrap();
169 let mask = filter.values() & nulls.inner();
170 BooleanArray::new(mask, None)
171}
172
173pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
202 let mut filter_builder = FilterBuilder::new(predicate);
203
204 if FilterBuilder::is_optimize_beneficial(values.data_type()) {
205 filter_builder = filter_builder.optimize();
208 }
209
210 let predicate = filter_builder.build();
211
212 filter_array(values, &predicate)
213}
214
215pub fn filter_record_batch(
226 record_batch: &RecordBatch,
227 predicate: &BooleanArray,
228) -> Result<RecordBatch, ArrowError> {
229 let mut filter_builder = FilterBuilder::new(predicate);
230 let num_cols = record_batch.num_columns();
231 if num_cols > 1
232 || (num_cols > 0
233 && FilterBuilder::is_optimize_beneficial(
234 record_batch.schema_ref().field(0).data_type(),
235 ))
236 {
237 filter_builder = filter_builder.optimize();
240 }
241 let filter = filter_builder.build();
242
243 filter.filter_record_batch(record_batch)
244}
245
246#[derive(Debug)]
248pub struct FilterBuilder {
249 filter: BooleanArray,
250 count: usize,
251 strategy: IterationStrategy,
252}
253
254impl FilterBuilder {
255 pub fn new(filter: &BooleanArray) -> Self {
257 Self::new_with_count(filter, filter.true_count())
258 }
259
260 pub(crate) fn new_with_count(filter: &BooleanArray, count: usize) -> Self {
261 let filter = match filter.null_count() {
262 0 => filter.clone(),
263 _ => prep_null_mask_filter(filter),
264 };
265
266 let strategy = IterationStrategy::default_strategy(filter.len(), count);
267
268 Self {
269 filter,
270 count,
271 strategy,
272 }
273 }
274
275 pub fn optimize(mut self) -> Self {
286 match self.strategy {
287 IterationStrategy::SlicesIterator => {
288 let slices = SlicesIterator::new(&self.filter).collect();
289 self.strategy = IterationStrategy::Slices(slices)
290 }
291 IterationStrategy::IndexIterator => {
292 let indices = IndexIterator::new(&self.filter, self.count).collect();
293 self.strategy = IterationStrategy::Indices(indices)
294 }
295 _ => {}
296 }
297 self
298 }
299
300 pub fn is_optimize_beneficial(data_type: &DataType) -> bool {
305 match data_type {
306 DataType::Struct(fields) => {
307 fields.len() > 1
308 || fields.len() == 1
309 && FilterBuilder::is_optimize_beneficial(fields[0].data_type())
310 }
311 DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
312 _ => false,
313 }
314 }
315
316 pub fn build(self) -> FilterPredicate {
318 FilterPredicate {
319 filter: self.filter,
320 count: self.count,
321 strategy: self.strategy,
322 }
323 }
324}
325
326#[derive(Debug)]
328enum IterationStrategy {
329 SlicesIterator,
331 IndexIterator,
333 Indices(Vec<usize>),
335 Slices(Vec<(usize, usize)>),
337 All,
339 None,
341}
342
343impl IterationStrategy {
344 fn default_strategy(filter_length: usize, filter_count: usize) -> Self {
347 if filter_length == 0 || filter_count == 0 {
348 return IterationStrategy::None;
349 }
350
351 if filter_count == filter_length {
352 return IterationStrategy::All;
353 }
354
355 let selectivity_frac = filter_count as f64 / filter_length as f64;
360 if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD {
361 return IterationStrategy::SlicesIterator;
362 }
363 IterationStrategy::IndexIterator
364 }
365}
366
367pub(crate) enum FilterSelection<'a> {
369 None,
370 All { len: usize },
371 Slices(FilterSlices<'a>),
372 Indices(FilterIndices<'a>),
373}
374
375pub(crate) type FilterSlices<'a> =
376 FilterIterator<std::iter::Copied<std::slice::Iter<'a, (usize, usize)>>, SlicesIterator<'a>>;
377
378pub(crate) type FilterIndices<'a> =
379 FilterIterator<std::iter::Copied<std::slice::Iter<'a, usize>>, IndexIterator<'a>>;
380
/// Either of two iterator types that yield the same item type.
///
/// `Materialized` wraps an iterator over precomputed data and `Lazy` wraps
/// one that computes items on the fly. Dispatch happens once per call, not
/// once per item.
pub(crate) enum FilterIterator<M, I> {
    /// Iterator over already-computed items.
    Materialized(M),
    /// Iterator that produces items on demand.
    Lazy(I),
}

impl<M, I> FilterIterator<M, I>
where
    M: Iterator,
    I: Iterator<Item = M::Item>,
{
    /// Consumes the iterator, calling `f` on every item in order.
    pub(crate) fn for_each<F>(self, mut f: F)
    where
        F: FnMut(M::Item),
    {
        match self {
            Self::Materialized(items) => {
                for item in items {
                    f(item);
                }
            }
            Self::Lazy(items) => {
                for item in items {
                    f(item);
                }
            }
        }
    }

    /// Consumes the iterator, calling `f` on every item in order and stopping
    /// at the first error, which is returned.
    pub(crate) fn try_for_each<F, E>(self, f: F) -> Result<(), E>
    where
        F: FnMut(M::Item) -> Result<(), E>,
    {
        match self {
            Self::Materialized(mut items) => items.try_for_each(f),
            Self::Lazy(mut items) => items.try_for_each(f),
        }
    }
}
426
427#[derive(Debug)]
429pub struct FilterPredicate {
430 filter: BooleanArray,
431 count: usize,
432 strategy: IterationStrategy,
433}
434
435impl FilterPredicate {
436 pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef, ArrowError> {
438 filter_array(values, self)
439 }
440
441 pub fn filter_record_batch(
446 &self,
447 record_batch: &RecordBatch,
448 ) -> Result<RecordBatch, ArrowError> {
449 let filtered_arrays = record_batch
450 .columns()
451 .iter()
452 .map(|a| filter_array(a, self))
453 .collect::<Result<Vec<_>, _>>()?;
454
455 unsafe {
458 Ok(RecordBatch::new_unchecked(
459 record_batch.schema(),
460 filtered_arrays,
461 self.count,
462 ))
463 }
464 }
465
466 pub fn count(&self) -> usize {
468 self.count
469 }
470
471 pub(crate) fn selection(&self) -> FilterSelection<'_> {
472 match &self.strategy {
473 IterationStrategy::None => FilterSelection::None,
474 IterationStrategy::All => FilterSelection::All { len: self.count },
475 IterationStrategy::Slices(slices) => {
476 FilterSelection::Slices(FilterIterator::Materialized(slices.iter().copied()))
477 }
478 IterationStrategy::SlicesIterator => {
479 FilterSelection::Slices(FilterIterator::Lazy(SlicesIterator::new(&self.filter)))
480 }
481 IterationStrategy::Indices(indices) => {
482 FilterSelection::Indices(FilterIterator::Materialized(indices.iter().copied()))
483 }
484 IterationStrategy::IndexIterator => FilterSelection::Indices(FilterIterator::Lazy(
485 IndexIterator::new(&self.filter, self.count),
486 )),
487 }
488 }
489
490 pub fn filter_nulls(&self, nulls: Option<&NullBuffer>) -> Option<NullBuffer> {
497 let nulls = nulls?;
498 if nulls.null_count() == 0 {
499 return None;
500 }
501
502 let nulls = filter_bits(nulls.inner(), self);
503 let null_count = self.count - nulls.count_set_bits_offset(0, self.count);
506
507 if null_count == 0 {
508 return None;
509 }
510
511 let buffer = BooleanBuffer::new(nulls, 0, self.count);
512 debug_assert_eq!(null_count, buffer.len() - buffer.count_set_bits());
513 Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
516 }
517}
518
/// Applies `predicate` to `values`, dispatching on the array's data type.
///
/// # Errors
///
/// Returns `ArrowError::InvalidArgumentError` when the predicate is longer
/// than `values`, and propagates errors from the type-specific kernels and
/// from the `MutableArrayData` fallback.
///
/// # Panics
///
/// Hits `unimplemented!` for run-end-encoded or dictionary arrays whose
/// index/key type is not handled by the respective downcast macro.
fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> {
    if predicate.filter.len() > values.len() {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Filter predicate of length {} is larger than target array of length {}",
            predicate.filter.len(),
            values.len()
        )));
    }

    match predicate.strategy {
        // Nothing selected: an empty array of the same type.
        IterationStrategy::None => Ok(new_empty_array(values.data_type())),
        // Everything selected: a slice covering the first `count` rows.
        IterationStrategy::All => Ok(values.slice(0, predicate.count)),
        // Partial selection: pick a kernel specialised for the data type.
        _ => downcast_primitive_array! {
            // All primitive types share one kernel.
            values => Ok(Arc::new(filter_primitive(values, predicate))),
            DataType::Boolean => {
                let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
                Ok(Arc::new(filter_boolean(values, predicate)))
            }
            DataType::Utf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i32>(), predicate)))
            }
            DataType::LargeUtf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
            }
            DataType::Utf8View => {
                Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
            }
            DataType::Binary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
            }
            DataType::LargeBinary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
            }
            DataType::BinaryView => {
                Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
            }
            DataType::FixedSizeBinary(_) => {
                Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate)))
            }
            DataType::ListView(_) => {
                Ok(Arc::new(filter_list_view::<i32>(values.as_list_view(), predicate)))
            }
            DataType::LargeListView(_) => {
                Ok(Arc::new(filter_list_view::<i64>(values.as_list_view(), predicate)))
            }
            DataType::RunEndEncoded(_, _) => {
                downcast_run_array!{
                    values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
                    t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t)
                }
            }
            DataType::Dictionary(_, _) => downcast_dictionary_array! {
                values => Ok(Arc::new(filter_dict(values, predicate))),
                t => unimplemented!("Filter not supported for dictionary type {:?}", t)
            }
            DataType::Struct(_) => {
                Ok(Arc::new(filter_struct(values.as_struct(), predicate)?))
            }
            DataType::Union(_, UnionMode::Sparse) => {
                Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?))
            }
            // No dedicated kernel: copy the selected ranges with `MutableArrayData`.
            _ => {
                let data = values.to_data();
                let mut mutable = MutableArrayData::new(
                    vec![&data],
                    false,
                    predicate.count,
                );

                match &predicate.strategy {
                    // Reuse the precomputed ranges when they are available.
                    IterationStrategy::Slices(slices) => {
                        for (start, end) in slices {
                            mutable.try_extend(0, *start, *end)?;
                        }
                    }
                    // Any other strategy is walked range by range over the mask.
                    _ => {
                        let iter = SlicesIterator::new(&predicate.filter);
                        for (start, end) in iter {
                            mutable.try_extend(0, start, end)?;
                        }
                    }
                }

                let data = mutable.freeze();
                Ok(make_array(data))
            }
        },
    }
}
610
/// Filters a run-end-encoded array.
///
/// Walks the physical runs covered by the array, counts the selected logical
/// rows that fall inside each run, and keeps a run only when at least one of
/// its rows is selected. The kept runs' values are then filtered with a
/// per-run boolean mask and paired with the recomputed run ends.
///
/// # Errors
///
/// Propagates errors from filtering the values and from constructing the
/// resulting run ends or `RunArray`.
fn filter_run_end_array<R: RunEndIndexType>(
    array: &RunArray<R>,
    predicate: &FilterPredicate,
) -> Result<RunArray<R>, ArrowError>
where
    R::Native: Into<i64> + From<bool>,
    R::Native: AddAssign,
{
    let run_ends: &RunEndBuffer<R::Native> = array.run_ends();
    let start_physical = run_ends.get_start_physical_index();
    let end_physical = run_ends.get_end_physical_index();
    // Number of physical runs between the start and end index, inclusive.
    let physical_len = end_physical - start_physical + 1;

    let mut new_run_ends = vec![R::default_value(); physical_len];
    let offset = run_ends.offset() as u64;

    // Logical position (relative to the array start) where the current run begins.
    let mut start = 0u64;
    // Number of runs kept so far; also the next write slot in `new_run_ends`.
    let mut j = 0;
    // Running total of selected rows, which becomes the new run end.
    let mut count = R::default_value();
    let filter_values = predicate.filter.values();
    let run_ends = run_ends.inner();

    // NOTE(review): the closure carries state between calls, so this relies on
    // `collect_bool` invoking it for i = 0, 1, .. in ascending order — confirm.
    let pred: BooleanArray = BooleanBuffer::collect_bool(physical_len, |i| {
        let mut keep = false;
        // Run end relative to the array start, then clamped to the filter length.
        let mut end = (run_ends[i + start_physical].into() as u64).saturating_sub(offset);
        let difference = end.saturating_sub(filter_values.len() as u64);
        end -= difference;

        // SAFETY: `end` was clamped to `filter_values.len()` above, so every
        // index in `start..end` is within the filter.
        for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
            count += R::Native::from(pred);
            keep |= pred
        }
        // Written unconditionally; the slot is only advanced past when the run
        // is kept, so a dropped run's value is overwritten by the next one.
        new_run_ends[j] = count;
        j += keep as usize;

        start = end;
        keep
    })
    .into();

    // Drop the unused tail left over from runs that were not kept.
    new_run_ends.truncate(j);

    let values = array.values_slice();
    let values = filter(values.as_ref(), &pred)?;

    let run_ends = PrimitiveArray::<R>::try_new(new_run_ends.into(), None)?;
    RunArray::try_new(&run_ends, &values)
}
662
663pub(crate) fn filter_null_mask(
670 nulls: Option<&NullBuffer>,
671 predicate: &FilterPredicate,
672) -> Option<(usize, Buffer)> {
673 let nulls = nulls?;
674 if nulls.null_count() == 0 {
675 return None;
676 }
677
678 let nulls = filter_bits(nulls.inner(), predicate);
679 let null_count = predicate.count - nulls.count_set_bits_offset(0, predicate.count);
682
683 if null_count == 0 {
684 return None;
685 }
686
687 Some((null_count, nulls))
688}
689
690fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer {
692 let src = buffer.values();
693 let offset = buffer.offset();
694 assert!(buffer.len() >= predicate.filter.len());
695
696 match &predicate.strategy {
697 IterationStrategy::IndexIterator => {
698 let bits =
699 IndexIterator::new(&predicate.filter, predicate.count).map(|src_idx| unsafe {
701 bit_util::get_bit_raw(buffer.values().as_ptr(), src_idx + offset)
702 });
703
704 unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
706 }
707 IterationStrategy::Indices(indices) => {
708 let bits = indices.iter().map(|src_idx| unsafe {
710 bit_util::get_bit_raw(buffer.values().as_ptr(), *src_idx + offset)
711 });
712 unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
714 }
715 IterationStrategy::SlicesIterator => {
716 let mut builder = BooleanBufferBuilder::new(predicate.count);
717 for (start, end) in SlicesIterator::new(&predicate.filter) {
718 builder.append_packed_range(start + offset..end + offset, src)
719 }
720 builder.into()
721 }
722 IterationStrategy::Slices(slices) => {
723 let mut builder = BooleanBufferBuilder::new(predicate.count);
724 for (start, end) in slices {
725 builder.append_packed_range(*start + offset..*end + offset, src)
726 }
727 builder.into()
728 }
729 IterationStrategy::All | IterationStrategy::None => unreachable!(),
730 }
731}
732
733fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
735 let buffer = filter_bits(array.values(), predicate);
736 let values = BooleanBuffer::new(buffer, 0, predicate.count);
737 let nulls = predicate.filter_nulls(array.nulls());
738
739 BooleanArray::new(values, nulls)
740}
741
742#[inline(never)]
743pub(crate) fn filter_native<T: ArrowNativeType>(
744 values: &[T],
745 predicate: &FilterPredicate,
746) -> Buffer {
747 assert!(values.len() >= predicate.filter.len());
748
749 match &predicate.strategy {
750 IterationStrategy::SlicesIterator => {
751 let mut buffer = Vec::with_capacity(predicate.count);
752 for (start, end) in SlicesIterator::new(&predicate.filter) {
753 buffer.extend_from_slice(unsafe { values.get_unchecked(start..end) });
755 }
756 buffer.into()
757 }
758 IterationStrategy::Slices(slices) => {
759 let mut buffer = Vec::with_capacity(predicate.count);
760 for (start, end) in slices {
761 buffer.extend_from_slice(unsafe { values.get_unchecked(*start..*end) });
763 }
764 buffer.into()
765 }
766 IterationStrategy::IndexIterator => {
767 let iter = IndexIterator::new(&predicate.filter, predicate.count)
769 .map(|x| unsafe { *values.get_unchecked(x) });
770
771 unsafe { MutableBuffer::from_trusted_len_iter(iter) }.into()
773 }
774 IterationStrategy::Indices(indices) => {
775 let iter = indices.iter().map(|x| unsafe { *values.get_unchecked(*x) });
777 iter.collect::<Vec<_>>().into()
778 }
779 IterationStrategy::All | IterationStrategy::None => unreachable!(),
780 }
781}
782
783fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
785where
786 T: ArrowPrimitiveType,
787{
788 let buffer = filter_native(array.values(), predicate);
789 let values = ScalarBuffer::new(buffer, 0, predicate.count);
790 let nulls = predicate.filter_nulls(array.nulls());
791 let filtered = PrimitiveArray::new(values, nulls);
792
793 if array.data_type() == &T::DATA_TYPE {
795 filtered
796 } else {
797 filtered.with_data_type(array.data_type().clone())
798 }
799}
800
801struct FilterBytes<'a, OffsetSize> {
806 src_offsets: &'a [OffsetSize],
807 src_values: &'a [u8],
808 dst_offsets: Vec<OffsetSize>,
809 dst_values: Vec<u8>,
810 cur_offset: OffsetSize,
811}
812
813impl<'a, OffsetSize> FilterBytes<'a, OffsetSize>
814where
815 OffsetSize: OffsetSizeTrait,
816{
817 fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self
818 where
819 T: ByteArrayType<Offset = OffsetSize>,
820 {
821 let dst_values = Vec::new();
822 let mut dst_offsets: Vec<OffsetSize> = Vec::with_capacity(capacity + 1);
823 let cur_offset = OffsetSize::from_usize(0).unwrap();
824
825 dst_offsets.push(cur_offset);
826
827 Self {
828 src_offsets: array.value_offsets(),
829 src_values: array.value_data(),
830 dst_offsets,
831 dst_values,
832 cur_offset,
833 }
834 }
835
836 #[inline]
838 fn get_value_offset(&self, idx: usize) -> usize {
839 self.src_offsets[idx].as_usize()
840 }
841
842 #[inline]
844 fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
845 let start = self.get_value_offset(idx);
847 let end = self.get_value_offset(idx + 1);
848 let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
849 (start, end, len)
850 }
851
852 fn extend_offsets_idx(&mut self, iter: impl Iterator<Item = usize>) {
853 self.dst_offsets.extend(iter.map(|idx| {
854 let start = self.src_offsets[idx].as_usize();
855 let end = self.src_offsets[idx + 1].as_usize();
856 let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
857 self.cur_offset += len;
858
859 self.cur_offset
860 }));
861 }
862
863 fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
865 self.dst_values.reserve_exact(self.cur_offset.as_usize());
866
867 for idx in iter {
868 let start = self.src_offsets[idx].as_usize();
869 let end = self.src_offsets[idx + 1].as_usize();
870 self.dst_values
871 .extend_from_slice(&self.src_values[start..end]);
872 }
873 }
874
875 fn extend_offsets_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>, count: usize) {
876 self.dst_offsets.reserve_exact(count);
877 for (start, end) in iter {
878 for idx in start..end {
880 let (_, _, len) = self.get_value_range(idx);
881 self.cur_offset += len;
882 self.dst_offsets.push(self.cur_offset);
883 }
884 }
885 }
886
887 fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
889 self.dst_values.reserve_exact(self.cur_offset.as_usize());
890
891 for (start, end) in iter {
892 let value_start = self.get_value_offset(start);
893 let value_end = self.get_value_offset(end);
894 self.dst_values
895 .extend_from_slice(&self.src_values[value_start..value_end]);
896 }
897 }
898}
899
900fn filter_bytes<T>(array: &GenericByteArray<T>, predicate: &FilterPredicate) -> GenericByteArray<T>
905where
906 T: ByteArrayType,
907{
908 let mut filter = FilterBytes::new(predicate.count, array);
909
910 match &predicate.strategy {
911 IterationStrategy::SlicesIterator => {
912 filter.extend_offsets_slices(SlicesIterator::new(&predicate.filter), predicate.count);
913 filter.extend_slices(SlicesIterator::new(&predicate.filter))
914 }
915 IterationStrategy::Slices(slices) => {
916 filter.extend_offsets_slices(slices.iter().cloned(), predicate.count);
917 filter.extend_slices(slices.iter().cloned())
918 }
919 IterationStrategy::IndexIterator => {
920 filter.extend_offsets_idx(IndexIterator::new(&predicate.filter, predicate.count));
921 filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count))
922 }
923 IterationStrategy::Indices(indices) => {
924 filter.extend_offsets_idx(indices.iter().cloned());
925 filter.extend_idx(indices.iter().cloned())
926 }
927 IterationStrategy::All | IterationStrategy::None => unreachable!(),
928 }
929
930 let offsets = unsafe { OffsetBuffer::new_unchecked(filter.dst_offsets.into()) };
933 let nulls = predicate.filter_nulls(array.nulls());
934
935 unsafe { GenericByteArray::new_unchecked(offsets, filter.dst_values.into(), nulls) }
939}
940
941fn filter_byte_view<T: ByteViewType>(
943 array: &GenericByteViewArray<T>,
944 predicate: &FilterPredicate,
945) -> GenericByteViewArray<T> {
946 let new_view_buffer = filter_native(array.views(), predicate);
947 let views = ScalarBuffer::new(new_view_buffer, 0, predicate.count);
948 let buffers = array.data_buffers().to_vec();
949 let nulls = predicate.filter_nulls(array.nulls());
950
951 unsafe { GenericByteViewArray::new_unchecked(views, buffers, nulls) }
955}
956
957fn filter_fixed_size_binary(
958 array: &FixedSizeBinaryArray,
959 predicate: &FilterPredicate,
960) -> FixedSizeBinaryArray {
961 let values: &[u8] = array.values();
962 let value_length = array.value_length() as usize;
963 let calculate_offset_from_index = |index: usize| index * value_length;
964 let buffer = match &predicate.strategy {
965 IterationStrategy::SlicesIterator => {
966 let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
967 for (start, end) in SlicesIterator::new(&predicate.filter) {
968 buffer.extend_from_slice(
969 &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)],
970 );
971 }
972 buffer
973 }
974 IterationStrategy::Slices(slices) => {
975 let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
976 for (start, end) in slices {
977 buffer.extend_from_slice(
978 &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)],
979 );
980 }
981 buffer
982 }
983 IterationStrategy::IndexIterator => {
984 let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| {
985 &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)]
986 });
987
988 let mut buffer = MutableBuffer::new(predicate.count * value_length);
989 iter.for_each(|item| buffer.extend_from_slice(item));
990 buffer
991 }
992 IterationStrategy::Indices(indices) => {
993 let iter = indices.iter().map(|x| {
994 &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)]
995 });
996
997 let mut buffer = MutableBuffer::new(predicate.count * value_length);
998 iter.for_each(|item| buffer.extend_from_slice(item));
999 buffer
1000 }
1001 IterationStrategy::All | IterationStrategy::None => unreachable!(),
1002 };
1003
1004 let nulls = predicate.filter_nulls(array.nulls());
1005
1006 FixedSizeBinaryArray::new(array.value_length(), buffer.into(), nulls)
1007}
1008
1009fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
1011where
1012 T: ArrowDictionaryKeyType,
1013 T::Native: num_traits::Num,
1014{
1015 let builder = filter_primitive::<T>(array.keys(), predicate)
1016 .into_data()
1017 .into_builder()
1018 .data_type(array.data_type().clone())
1019 .child_data(vec![array.values().to_data()]);
1020
1021 DictionaryArray::from(unsafe { builder.build_unchecked() })
1024}
1025
1026fn filter_struct(
1028 array: &StructArray,
1029 predicate: &FilterPredicate,
1030) -> Result<StructArray, ArrowError> {
1031 let columns = array
1032 .columns()
1033 .iter()
1034 .map(|column| filter_array(column, predicate))
1035 .collect::<Result<_, _>>()?;
1036
1037 let nulls = predicate.filter_nulls(array.nulls());
1038
1039 Ok(unsafe {
1040 StructArray::new_unchecked_with_length(
1041 array.fields().clone(),
1042 columns,
1043 nulls,
1044 predicate.count(),
1045 )
1046 })
1047}
1048
1049fn filter_sparse_union(
1051 array: &UnionArray,
1052 predicate: &FilterPredicate,
1053) -> Result<UnionArray, ArrowError> {
1054 let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
1055 unreachable!()
1056 };
1057
1058 let type_ids = filter_primitive(
1059 &Int8Array::try_new(array.type_ids().clone(), None)?,
1060 predicate,
1061 );
1062
1063 let children = fields
1064 .iter()
1065 .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
1066 .collect::<Result<_, _>>()?;
1067
1068 Ok(unsafe {
1069 UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
1070 })
1071}
1072
1073fn filter_list_view<OffsetType: OffsetSizeTrait>(
1075 array: &GenericListViewArray<OffsetType>,
1076 predicate: &FilterPredicate,
1077) -> GenericListViewArray<OffsetType> {
1078 let filtered_offsets = filter_native::<OffsetType>(array.offsets(), predicate);
1079 let filtered_sizes = filter_native::<OffsetType>(array.sizes(), predicate);
1080
1081 let field = match array.data_type() {
1082 DataType::ListView(field) | DataType::LargeListView(field) => field.clone(),
1083 _ => unreachable!(),
1084 };
1085 let offsets = ScalarBuffer::new(filtered_offsets, 0, predicate.count);
1086 let sizes = ScalarBuffer::new(filtered_sizes, 0, predicate.count);
1087 let values = array.values().clone();
1088 let nulls = predicate.filter_nulls(array.nulls());
1089
1090 unsafe { GenericListViewArray::new_unchecked(field, offsets, sizes, values, nulls) }
1094}
1095
1096#[cfg(test)]
1097mod tests {
1098 use super::*;
1099 use arrow_array::builder::*;
1100 use arrow_array::cast::as_run_array;
1101 use arrow_array::types::*;
1102 use rand::distr::uniform::{UniformSampler, UniformUsize};
1103 use rand::distr::{Alphanumeric, StandardUniform};
1104 use rand::prelude::*;
1105 use rand::rng;
1106
1107 macro_rules! def_temporal_test {
1108 ($test:ident, $array_type: ident, $data: expr) => {
1109 #[test]
1110 fn $test() {
1111 let a = $data;
1112 let b = BooleanArray::from(vec![true, false, true, false]);
1113 let c = filter(&a, &b).unwrap();
1114 let d = c.as_ref().as_any().downcast_ref::<$array_type>().unwrap();
1115 assert_eq!(2, d.len());
1116 assert_eq!(1, d.value(0));
1117 assert_eq!(3, d.value(1));
1118 }
1119 };
1120 }
1121
1122 def_temporal_test!(
1123 test_filter_date32,
1124 Date32Array,
1125 Date32Array::from(vec![1, 2, 3, 4])
1126 );
1127 def_temporal_test!(
1128 test_filter_date64,
1129 Date64Array,
1130 Date64Array::from(vec![1, 2, 3, 4])
1131 );
1132 def_temporal_test!(
1133 test_filter_time32_second,
1134 Time32SecondArray,
1135 Time32SecondArray::from(vec![1, 2, 3, 4])
1136 );
1137 def_temporal_test!(
1138 test_filter_time32_millisecond,
1139 Time32MillisecondArray,
1140 Time32MillisecondArray::from(vec![1, 2, 3, 4])
1141 );
1142 def_temporal_test!(
1143 test_filter_time64_microsecond,
1144 Time64MicrosecondArray,
1145 Time64MicrosecondArray::from(vec![1, 2, 3, 4])
1146 );
1147 def_temporal_test!(
1148 test_filter_time64_nanosecond,
1149 Time64NanosecondArray,
1150 Time64NanosecondArray::from(vec![1, 2, 3, 4])
1151 );
1152 def_temporal_test!(
1153 test_filter_duration_second,
1154 DurationSecondArray,
1155 DurationSecondArray::from(vec![1, 2, 3, 4])
1156 );
1157 def_temporal_test!(
1158 test_filter_duration_millisecond,
1159 DurationMillisecondArray,
1160 DurationMillisecondArray::from(vec![1, 2, 3, 4])
1161 );
1162 def_temporal_test!(
1163 test_filter_duration_microsecond,
1164 DurationMicrosecondArray,
1165 DurationMicrosecondArray::from(vec![1, 2, 3, 4])
1166 );
1167 def_temporal_test!(
1168 test_filter_duration_nanosecond,
1169 DurationNanosecondArray,
1170 DurationNanosecondArray::from(vec![1, 2, 3, 4])
1171 );
1172 def_temporal_test!(
1173 test_filter_timestamp_second,
1174 TimestampSecondArray,
1175 TimestampSecondArray::from(vec![1, 2, 3, 4])
1176 );
1177 def_temporal_test!(
1178 test_filter_timestamp_millisecond,
1179 TimestampMillisecondArray,
1180 TimestampMillisecondArray::from(vec![1, 2, 3, 4])
1181 );
1182 def_temporal_test!(
1183 test_filter_timestamp_microsecond,
1184 TimestampMicrosecondArray,
1185 TimestampMicrosecondArray::from(vec![1, 2, 3, 4])
1186 );
1187 def_temporal_test!(
1188 test_filter_timestamp_nanosecond,
1189 TimestampNanosecondArray,
1190 TimestampNanosecondArray::from(vec![1, 2, 3, 4])
1191 );
1192
1193 #[test]
1194 fn test_filter_array_slice() {
1195 let a = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
1196 let b = BooleanArray::from(vec![true, false, false, true]);
1197 let c = filter(&a, &b).unwrap();
1201 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1202 assert_eq!(2, d.len());
1203 assert_eq!(6, d.value(0));
1204 assert_eq!(9, d.value(1));
1205 }
1206
1207 #[test]
1208 fn test_filter_array_low_density() {
1209 let mut data_values = (1..=65).collect::<Vec<i32>>();
1211 let mut filter_values = (1..=65).map(|i| matches!(i % 65, 0)).collect::<Vec<bool>>();
1212 data_values.extend_from_slice(&[66, 67]);
1214 filter_values.extend_from_slice(&[false, true]);
1215 let a = Int32Array::from(data_values);
1216 let b = BooleanArray::from(filter_values);
1217 let c = filter(&a, &b).unwrap();
1218 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1219 assert_eq!(2, d.len());
1220 assert_eq!(65, d.value(0));
1221 assert_eq!(67, d.value(1));
1222 }
1223
1224 #[test]
1225 fn test_filter_array_high_density() {
1226 let mut data_values = (1..=65).map(Some).collect::<Vec<_>>();
1228 let mut filter_values = (1..=65)
1229 .map(|i| !matches!(i % 65, 0))
1230 .collect::<Vec<bool>>();
1231 data_values[1] = None;
1233 data_values.extend_from_slice(&[Some(66), None, Some(67), None]);
1235 filter_values.extend_from_slice(&[false, true, true, true]);
1236 let a = Int32Array::from(data_values);
1237 let b = BooleanArray::from(filter_values);
1238 let c = filter(&a, &b).unwrap();
1239 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1240 assert_eq!(67, d.len());
1241 assert_eq!(3, d.null_count());
1242 assert_eq!(1, d.value(0));
1243 assert!(d.is_null(1));
1244 assert_eq!(64, d.value(63));
1245 assert!(d.is_null(64));
1246 assert_eq!(67, d.value(65));
1247 }
1248
1249 #[test]
1250 fn test_filter_string_array_simple() {
1251 let a = StringArray::from(vec!["hello", " ", "world", "!"]);
1252 let b = BooleanArray::from(vec![true, false, true, false]);
1253 let c = filter(&a, &b).unwrap();
1254 let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1255 assert_eq!(2, d.len());
1256 assert_eq!("hello", d.value(0));
1257 assert_eq!("world", d.value(1));
1258 }
1259
1260 #[test]
1261 fn test_filter_primitive_array_with_null() {
1262 let a = Int32Array::from(vec![Some(5), None]);
1263 let b = BooleanArray::from(vec![false, true]);
1264 let c = filter(&a, &b).unwrap();
1265 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1266 assert_eq!(1, d.len());
1267 assert!(d.is_null(0));
1268 }
1269
// Keeps one valid and one null entry of a string array and checks both survive as such.
#[test]
fn test_filter_string_array_with_null() {
    let input = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
    let keep_ends = BooleanArray::from(vec![true, false, false, true]);

    let filtered = filter(&input, &keep_ends).unwrap();
    let strings = filtered
        .as_ref()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();

    assert_eq!(strings.len(), 2);
    assert_eq!(strings.value(0), "hello");
    assert!(!strings.is_null(0));
    assert!(strings.is_null(1));
}
1281
// Binary counterpart of the nullable string test: first and last slots are kept.
#[test]
fn test_filter_binary_array_with_null() {
    let raw: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
    let input = BinaryArray::from(raw);
    let keep_ends = BooleanArray::from(vec![true, false, false, true]);

    let filtered = filter(&input, &keep_ends).unwrap();
    let bytes = filtered
        .as_ref()
        .as_any()
        .downcast_ref::<BinaryArray>()
        .unwrap();

    assert_eq!(bytes.len(), 2);
    assert_eq!(b"hello", bytes.value(0));
    assert!(!bytes.is_null(0));
    assert!(bytes.is_null(1));
}
1294
1295 fn _test_filter_byte_view<T>()
1296 where
1297 T: ByteViewType,
1298 str: AsRef<T::Native>,
1299 T::Native: PartialEq,
1300 {
1301 let array = {
1302 let mut builder = GenericByteViewBuilder::<T>::new();
1304 builder.append_value("hello");
1305 builder.append_value("world");
1306 builder.append_null();
1307 builder.append_value("large payload over 12 bytes");
1308 builder.append_value("lulu");
1309 builder.finish()
1310 };
1311
1312 {
1313 let predicate = BooleanArray::from(vec![true, false, true, true, false]);
1314 let actual = filter(&array, &predicate).unwrap();
1315
1316 assert_eq!(actual.len(), 3);
1317
1318 let expected = {
1319 let mut builder = GenericByteViewBuilder::<T>::new();
1321 builder.append_value("hello");
1322 builder.append_null();
1323 builder.append_value("large payload over 12 bytes");
1324 builder.finish()
1325 };
1326
1327 assert_eq!(actual.as_ref(), &expected);
1328 }
1329
1330 {
1331 let predicate = BooleanArray::from(vec![true, false, false, false, true]);
1332 let actual = filter(&array, &predicate).unwrap();
1333
1334 assert_eq!(actual.len(), 2);
1335
1336 let expected = {
1337 let mut builder = GenericByteViewBuilder::<T>::new();
1339 builder.append_value("hello");
1340 builder.append_value("lulu");
1341 builder.finish()
1342 };
1343
1344 assert_eq!(actual.as_ref(), &expected);
1345 }
1346 }
1347
// Runs the shared byte-view filter checks for string views.
#[test]
fn test_filter_string_view() {
    _test_filter_byte_view::<StringViewType>();
}
1352
// Runs the shared byte-view filter checks for binary views.
#[test]
fn test_filter_binary_view() {
    _test_filter_byte_view::<BinaryViewType>();
}
1357
// Filters a three-element fixed-size-binary array with several masks, and for two of
// them checks that the optimized `FilterBuilder` path gives the same result as `filter`.
#[test]
fn test_filter_fixed_binary() {
    let v1 = [1_u8, 2];
    let v2 = [3_u8, 4];
    let v3 = [5_u8, 6];
    let input = FixedSizeBinaryArray::try_from(vec![&v1, &v2, &v3]).unwrap();

    // Plain kernel entry point.
    let plain = |mask: &BooleanArray| filter(&input, mask).unwrap();
    // Same selection, but through an optimized predicate.
    let optimized = |mask: &BooleanArray| {
        FilterBuilder::new(mask)
            .optimize()
            .build()
            .filter(&input)
            .unwrap()
    };

    // First and last kept.
    let mask = BooleanArray::from(vec![true, false, true]);
    let out = plain(&mask);
    let out = out
        .as_ref()
        .as_any()
        .downcast_ref::<FixedSizeBinaryArray>()
        .unwrap();
    assert_eq!(out.len(), 2);
    assert_eq!(out.value(0), &v1);
    assert_eq!(out.value(1), &v3);
    let out_opt = optimized(&mask);
    let out_opt = out_opt
        .as_ref()
        .as_any()
        .downcast_ref::<FixedSizeBinaryArray>()
        .unwrap();
    assert_eq!(out, out_opt);

    // Nothing kept.
    let mask = BooleanArray::from(vec![false, false, false]);
    let out = plain(&mask);
    let out = out
        .as_ref()
        .as_any()
        .downcast_ref::<FixedSizeBinaryArray>()
        .unwrap();
    assert_eq!(out.len(), 0);

    // Everything kept.
    let mask = BooleanArray::from(vec![true, true, true]);
    let out = plain(&mask);
    let out = out
        .as_ref()
        .as_any()
        .downcast_ref::<FixedSizeBinaryArray>()
        .unwrap();
    assert_eq!(out.len(), 3);
    assert_eq!(out.value(0), &v1);
    assert_eq!(out.value(1), &v2);
    assert_eq!(out.value(2), &v3);

    // Only the last kept.
    let mask = BooleanArray::from(vec![false, false, true]);
    let out = plain(&mask);
    let out = out
        .as_ref()
        .as_any()
        .downcast_ref::<FixedSizeBinaryArray>()
        .unwrap();
    assert_eq!(out.len(), 1);
    assert_eq!(out.value(0), &v3);
    let out_opt = optimized(&mask);
    let out_opt = out_opt
        .as_ref()
        .as_any()
        .downcast_ref::<FixedSizeBinaryArray>()
        .unwrap();
    assert_eq!(out, out_opt);
}
1429
// Filters a sliced (offset 1, length 4) nullable array; the slice starts on the null.
#[test]
fn test_filter_array_slice_with_null() {
    let full = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]);
    let window = full.slice(1, 4);
    let keep_ends = BooleanArray::from(vec![true, false, false, true]);

    let filtered = filter(&window, &keep_ends).unwrap();
    let ints = filtered
        .as_ref()
        .as_any()
        .downcast_ref::<Int32Array>()
        .unwrap();

    assert_eq!(ints.len(), 2);
    assert!(ints.is_null(0));
    assert!(!ints.is_null(1));
    assert_eq!(ints.value(1), 9);
}
1444
// Keeps every other logical element of a run-end-encoded array and checks the
// resulting run ends and values.
#[test]
fn test_filter_run_end_encoding_array() {
    let input = RunArray::try_new(
        &Int64Array::from(vec![2, 3, 8]),
        &Int64Array::from(vec![7, -2, 9]),
    )
    .expect("Failed to create RunArray");
    let alternating = BooleanArray::from(vec![true, false, true, false, true, false, true, false]);

    let filtered = filter(&input, &alternating).unwrap();
    let actual: &RunArray<Int64Type> = as_run_array(&filtered);
    assert_eq!(actual.len(), 4);

    let expected = RunArray::try_new(
        &Int64Array::from(vec![1, 2, 4]),
        &Int64Array::from(vec![7, -2, 9]),
    )
    .expect("Failed to make expected RunArray test is broken");

    assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1464
// Filters a slice (offset 2, length 3) of a run-end-encoded array and compares the
// decoded logical values.
#[test]
fn test_filter_run_end_encoding_array_sliced() {
    let run_ends = Int64Array::from(vec![2, 3, 8]);
    let values = Int64Array::from(vec![7, -2, 9]);
    let whole = RunArray::try_new(&run_ends, &values).unwrap();
    let window = whole.slice(2, 3);
    let keep_ends = BooleanArray::from(vec![true, false, true]);

    let filtered = filter(&window, &keep_ends).unwrap();
    let runs = filtered.as_run::<Int64Type>();
    let typed = runs.downcast::<Int64Array>().unwrap();

    let actual = typed.into_iter().flatten().collect::<Vec<_>>();
    let expected = vec![-2, 9];
    assert_eq!(expected, actual);
}
1481
// A mask that selects nothing from two of the four runs: those runs must vanish
// from the output entirely.
#[test]
fn test_filter_run_end_encoding_array_remove_value() {
    let input = RunArray::try_new(
        &Int32Array::from(vec![2, 3, 8, 10]),
        &Int32Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    let mask = BooleanArray::from(vec![
        false, true, false, false, true, false, true, false, false, false,
    ]);

    let filtered = filter(&input, &mask).unwrap();
    let actual: &RunArray<Int32Type> = as_run_array(&filtered);
    assert_eq!(actual.len(), 3);

    let expected_ends = Int32Array::from(vec![1, 3]);
    let expected_values = Int32Array::from(vec![7, 9]);
    let expected = RunArray::try_new(&expected_ends, &expected_values)
        .expect("Failed to make expected RunArray test is broken");

    assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1501
// A mask with a single set bit (logical index 6) leaves exactly one run of length one.
#[test]
fn test_filter_run_end_encoding_array_remove_all_but_one() {
    let input = RunArray::try_new(
        &Int16Array::from(vec![2, 3, 8, 10]),
        &Int16Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    let single_bit = BooleanArray::from(vec![
        false, false, false, false, false, false, true, false, false, false,
    ]);

    let filtered = filter(&input, &single_bit).unwrap();
    let actual: &RunArray<Int16Type> = as_run_array(&filtered);
    assert_eq!(actual.len(), 1);

    let expected_ends = Int16Array::from(vec![1]);
    let expected_values = Int16Array::from(vec![9]);
    let expected = RunArray::try_new(&expected_ends, &expected_values)
        .expect("Failed to make expected RunArray test is broken");

    assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1520
// An all-false mask over a run-end-encoded array yields an empty run array.
#[test]
fn test_filter_run_end_encoding_array_empty() {
    let input = RunArray::try_new(
        &Int64Array::from(vec![2, 3, 8, 10]),
        &Int64Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    let nothing = BooleanArray::from(vec![false; 10]);

    let filtered = filter(&input, &nothing).unwrap();
    let actual: &RunArray<Int64Type> = as_run_array(&filtered);
    assert_eq!(actual.len(), 0);
}
1533
// The predicate (3 bits) is shorter than the run array (10 logical elements); only
// the covered prefix contributes to the output.
#[test]
fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() {
    let input = RunArray::try_new(
        &Int64Array::from(vec![2, 3, 8, 10]),
        &Int64Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    let short_mask = BooleanArray::from(vec![false, true, true]);

    let filtered = filter(&input, &short_mask).unwrap();
    let actual: &RunArray<Int64Type> = as_run_array(&filtered);
    assert_eq!(actual.len(), 2);

    let expected = RunArray::try_new(
        &Int64Array::from(vec![1, 2]),
        &Int64Array::from(vec![7, -2]),
    )
    .expect("Failed to make expected RunArray test is broken");

    assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1553
// Filters a dictionary array: the value dictionary keeps all three entries while
// only the two selected keys (one of them null) remain.
#[test]
fn test_filter_dictionary_array() {
    let raw = [Some("hello"), None, Some("world"), Some("!")];
    let input: Int8DictionaryArray = raw.iter().copied().collect();
    let keep_middle = BooleanArray::from(vec![false, true, true, false]);

    let filtered = filter(&input, &keep_middle).unwrap();
    let dict = filtered
        .as_ref()
        .as_any()
        .downcast_ref::<Int8DictionaryArray>()
        .unwrap();

    let dict_values = dict.values();
    let strings = dict_values
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();

    // The dictionary itself is left as it was built.
    assert_eq!(strings.len(), 3);
    // Only the keys are filtered.
    assert_eq!(dict.len(), 2);
    assert!(dict.is_null(0));
    assert_eq!("world", strings.value(dict.keys().value(1) as usize));
}
1574
// Filters a large-list array of four lists (the last one null), keeping lists 1 and 3.
#[test]
fn test_filter_list_array() {
    let item_field = Arc::new(Field::new_list_field(DataType::Int32, false));

    let input = LargeListArray::new(
        item_field.clone(),
        OffsetBuffer::new(vec![0i64, 3, 6, 8, 8].into()),
        Arc::new(Int32Array::from_iter_values(0..8)),
        Some(NullBuffer::from(vec![true, true, true, false])),
    );
    let predicate = BooleanArray::from(vec![false, true, false, true]);
    let result = filter(&input, &predicate).unwrap();

    // Expected: the second list ([3, 4, 5]) followed by the null list.
    let expected: ArrayRef = Arc::new(LargeListArray::new(
        item_field,
        OffsetBuffer::new(vec![0i64, 3, 3].into()),
        Arc::new(Int32Array::from_iter_values([3, 4, 5])),
        Some(NullBuffer::from(vec![true, false])),
    ));

    assert_eq!(&expected, &result);
}
1594
1595 fn test_case_filter_list_view<T: OffsetSizeTrait>() {
1596 let mut list_array = GenericListViewBuilder::<T, _>::new(Int32Builder::new());
1598 list_array.append_value([Some(1), Some(2)]);
1599 list_array.append_null();
1600 list_array.append_value([]);
1601 list_array.append_value([Some(3), Some(4)]);
1602
1603 let list_array = list_array.finish();
1604 let predicate = BooleanArray::from_iter([true, false, true, false]);
1605
1606 let filtered = filter(&list_array, &predicate)
1608 .unwrap()
1609 .as_list_view::<T>()
1610 .clone();
1611
1612 let mut expected =
1613 GenericListViewBuilder::<T, _>::with_capacity(Int32Builder::with_capacity(5), 3);
1614 expected.append_value([Some(1), Some(2)]);
1615 expected.append_value([]);
1616 let expected = expected.finish();
1617
1618 assert_eq!(&filtered, &expected);
1619 }
1620
1621 fn test_case_filter_sliced_list_view<T: OffsetSizeTrait>() {
1622 let mut list_array =
1624 GenericListViewBuilder::<T, _>::with_capacity(Int32Builder::with_capacity(6), 4);
1625 list_array.append_value([Some(1), Some(2)]);
1626 list_array.append_null();
1627 list_array.append_value([]);
1628 list_array.append_value([Some(3), Some(4)]);
1629
1630 let list_array = list_array.finish();
1631
1632 let sliced = list_array.slice(1, 3);
1634 let predicate = BooleanArray::from_iter([false, false, true]);
1635
1636 let filtered = filter(&sliced, &predicate)
1638 .unwrap()
1639 .as_list_view::<T>()
1640 .clone();
1641
1642 let mut expected = GenericListViewBuilder::<T, _>::new(Int32Builder::new());
1643 expected.append_value([Some(3), Some(4)]);
1644 let expected = expected.finish();
1645
1646 assert_eq!(&filtered, &expected);
1647 }
1648
// Runs both list-view scenarios for 32-bit and 64-bit offsets.
#[test]
fn test_filter_list_view_array() {
    // Unsliced input.
    test_case_filter_list_view::<i32>();
    test_case_filter_list_view::<i64>();

    // Sliced input.
    test_case_filter_sliced_list_view::<i32>();
    test_case_filter_sliced_list_view::<i64>();
}
1657
// A 64-bit mask with only bit 1 set yields the single range (1, 2).
#[test]
fn test_slice_iterator_bits() {
    let mut bits = vec![false; 64];
    bits[1] = true;
    let mask = BooleanArray::from(bits);
    let set_bits = mask.true_count();

    let ranges: Vec<_> = SlicesIterator::new(&mask).collect();

    assert_eq!(ranges, vec![(1, 2)]);
    assert_eq!(set_bits, 1);
}
1670
// A 64-bit mask with only bit 1 cleared yields the ranges (0, 1) and (2, 64).
#[test]
fn test_slice_iterator_bits1() {
    let mut bits = vec![true; 64];
    bits[1] = false;
    let mask = BooleanArray::from(bits);
    let set_bits = mask.true_count();

    let ranges: Vec<_> = SlicesIterator::new(&mask).collect();

    assert_eq!(ranges, vec![(0, 1), (2, 64)]);
    assert_eq!(set_bits, 64 - 1);
}
1683
// A 130-bit mask with every 62nd bit cleared (0, 62, 124) yields three ranges.
#[test]
fn test_slice_iterator_chunk_and_bits() {
    let bits: Vec<bool> = (0..130).map(|i| i % 62 != 0).collect();
    let mask = BooleanArray::from(bits);
    let set_bits = mask.true_count();

    let ranges: Vec<_> = SlicesIterator::new(&mask).collect();

    assert_eq!(ranges, vec![(1, 62), (63, 124), (125, 130)]);
    assert_eq!(set_bits, 61 + 61 + 5);
}
1696
// Exercises `for_each` and `try_for_each` on both the materialized and lazy variants
// of the slice and index selection iterators.
#[test]
fn test_filter_selection_iterators() {
    // Slices, materialized.
    {
        let slices = [(0, 2), (4, 5)];
        let mut seen = Vec::new();
        let selection: FilterSlices<'_> = FilterIterator::Materialized(slices.iter().copied());
        selection.for_each(|range| seen.push(range));
        assert_eq!(seen, slices);
    }

    // Slices, lazily derived from a boolean mask.
    {
        let filter = BooleanArray::from(vec![true, true, false, false, true]);
        let mut seen = Vec::new();
        let selection: FilterSlices<'_> = FilterIterator::Lazy(SlicesIterator::new(&filter));
        selection
            .try_for_each(|range| {
                seen.push(range);
                Ok::<(), ArrowError>(())
            })
            .unwrap();
        assert_eq!(seen, vec![(0, 2), (4, 5)]);
    }

    // Indices, materialized.
    {
        let indices = [1, 3, 5];
        let mut seen = Vec::new();
        let selection: FilterIndices<'_> = FilterIterator::Materialized(indices.iter().copied());
        selection.for_each(|idx| seen.push(idx));
        assert_eq!(seen, indices);
    }

    // Indices, lazily derived from a boolean mask with two set bits.
    {
        let filter = BooleanArray::from(vec![false, true, false, true]);
        let mut seen = Vec::new();
        let selection: FilterIndices<'_> = FilterIterator::Lazy(IndexIterator::new(&filter, 2));
        selection
            .try_for_each(|idx| {
                seen.push(idx);
                Ok::<(), ArrowError>(())
            })
            .unwrap();
        assert_eq!(seen, vec![1, 3]);
    }
}
1733
// A null entry in the predicate is expected to drop the corresponding row: the
// result equals the first two elements of the input.
#[test]
fn test_null_mask() {
    let input = Int64Array::from(vec![Some(1), Some(2), None]);
    let nullable_mask = BooleanArray::from(vec![Some(true), Some(true), None]);

    let filtered = filter(&input, &nullable_mask).unwrap();

    assert_eq!(filtered.as_ref(), &input.slice(0, 2));
}
1742
// Filtering a column-less record batch must still report the number of selected rows.
#[test]
fn test_filter_record_batch_no_columns() {
    let predicate = BooleanArray::from(vec![Some(true), Some(true), None]);

    // A batch with an empty schema carries its row count only through the options.
    let options = RecordBatchOptions::default().with_row_count(Some(100));
    let empty_schema = Arc::new(Schema::empty());
    let batch = RecordBatch::try_new_with_options(empty_schema, vec![], &options).unwrap();

    let filtered = filter_record_batch(&batch, &predicate).unwrap();

    assert_eq!(filtered.num_rows(), 2);
}
1753
// Checks the two trivial predicates: all-true returns an equal array, all-false
// returns an empty array of the same data type.
#[test]
fn test_fast_path() {
    let input: PrimitiveArray<Int64Type> = PrimitiveArray::from(vec![Some(1), Some(2), None]);

    // Everything selected.
    let select_all = BooleanArray::from(vec![true, true, true]);
    let kept = filter(&input, &select_all).unwrap();
    let kept = kept
        .as_any()
        .downcast_ref::<PrimitiveArray<Int64Type>>()
        .unwrap();
    assert_eq!(&input, kept);

    // Nothing selected.
    let select_none = BooleanArray::from(vec![false, false, false]);
    let dropped = filter(&input, &select_none).unwrap();
    assert_eq!(dropped.len(), 0);
    assert_eq!(dropped.data_type(), &DataType::Int64);
}
1773
// Checks the ranges produced by `SlicesIterator` for a mask made of five alternating
// runs, first unsliced and then sliced at offset 7 with three trailing bits dropped.
#[test]
fn test_slices() {
    // (bit value, run length) pairs, 81 bits in total.
    let runs = [(true, 10), (false, 30), (true, 20), (false, 17), (true, 4)];
    let mask: BooleanArray = runs
        .iter()
        .flat_map(|&(bit, count)| std::iter::repeat_n(bit, count))
        .map(Some)
        .collect();

    let ranges: Vec<_> = SlicesIterator::new(&mask).collect();
    assert_eq!(ranges, vec![(0, 10), (40, 60), (77, 81)]);

    // Slice with a non-zero offset and a shortened length.
    let total = mask.len();
    let window = mask.slice(7, total - 10);
    let window = window.as_any().downcast_ref::<BooleanArray>().unwrap();

    let ranges: Vec<_> = SlicesIterator::new(window).collect();
    assert_eq!(ranges, vec![(0, 3), (33, 53), (70, 71)]);
}
1800
1801 fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) {
1802 let mut rng = rng();
1803
1804 let bools: Vec<bool> = std::iter::from_fn(|| Some(rng.random()))
1805 .take(mask_len)
1806 .collect();
1807
1808 let buffer = Buffer::from_iter(bools.iter().cloned());
1809
1810 let truncated_length = mask_len - offset - truncate;
1811
1812 let filter = BooleanArray::new(BooleanBuffer::new(buffer, offset, truncated_length), None);
1813
1814 let slice_bits: Vec<_> = SlicesIterator::new(&filter)
1815 .flat_map(|(start, end)| start..end)
1816 .collect();
1817
1818 let count = filter.true_count();
1819 let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect();
1820
1821 let expected_bits: Vec<_> = bools
1822 .iter()
1823 .skip(offset)
1824 .take(truncated_length)
1825 .enumerate()
1826 .flat_map(|(idx, v)| v.then(|| idx))
1827 .collect();
1828
1829 assert_eq!(slice_bits, expected_bits);
1830 assert_eq!(index_bits, expected_bits);
1831 }
1832
// Drives `test_slices_fuzz` with 100 random (length, offset, truncate) triples and
// then with a handful of fixed ones. Skipped under miri.
#[test]
#[cfg_attr(miri, ignore)]
fn fuzz_test_slices_iterator() {
    let mut random = rng();
    let any_usize = UniformUsize::new(usize::MIN, usize::MAX).unwrap();

    for _ in 0..100 {
        let mask_len = random.random_range(0..1024);

        // `checked_rem` returns None for a zero bound, in which case 0 is used.
        let offset_bound = 64.min(mask_len);
        let offset = any_usize
            .sample(&mut random)
            .checked_rem(offset_bound)
            .unwrap_or(0);

        let truncate_bound = 128.min(mask_len - offset);
        let truncate = any_usize
            .sample(&mut random)
            .checked_rem(truncate_bound)
            .unwrap_or(0);

        test_slices_fuzz(mask_len, offset, truncate);
    }

    // Fixed combinations, run after the random ones.
    for (mask_len, offset, truncate) in [(64, 0, 0), (64, 8, 0), (64, 8, 8), (32, 8, 8), (32, 5, 9)]
    {
        test_slices_fuzz(mask_len, offset, truncate);
    }
}
1859
// Reference filter: returns the items of `values` whose paired entry in `predicate`
// is `true`. Pairing stops at the shorter of the two inputs.
fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> {
    let mut kept = Vec::new();
    for (value, &keep) in values.into_iter().zip(predicate) {
        if keep {
            kept.push(value);
        }
    }
    kept
}
1869
1870 fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>>
1872 where
1873 StandardUniform: Distribution<T>,
1874 {
1875 let mut rng = rng();
1876 (0..len)
1877 .map(|_| rng.random_bool(valid_percent).then(|| rng.random()))
1878 .collect()
1879 }
1880
1881 fn gen_strings(
1883 len: usize,
1884 valid_percent: f64,
1885 str_len_range: std::ops::Range<usize>,
1886 ) -> Vec<Option<String>> {
1887 let mut rng = rng();
1888 (0..len)
1889 .map(|_| {
1890 rng.random_bool(valid_percent).then(|| {
1891 let len = rng.random_range(str_len_range.clone());
1892 (0..len)
1893 .map(|_| char::from(rng.sample(Alphanumeric)))
1894 .collect()
1895 })
1896 })
1897 .collect()
1898 }
1899
// Iterates over `src`, turning each `Option<T>` into an `Option<&T::Target>`
// (for example `Option<String>` into `Option<&str>`).
fn as_deref<T: std::ops::Deref>(src: &[Option<T>]) -> impl Iterator<Item = Option<&T::Target>> {
    src.iter().map(|slot| match slot {
        Some(inner) => Some(&**inner),
        None => None,
    })
}
1904
// Randomized comparison of `filter` against the plain-Rust reference `filter_rust`.
//
// Each of the 100 iterations draws a random predicate and random nullable data, slices
// both at random offsets, and checks the kernel output for an Int32 array, a string
// array and an Int32-keyed dictionary array built from the same strings.
// Skipped under miri.
#[test]
#[cfg_attr(miri, ignore)]
fn fuzz_filter() {
    let mut rng = rng();

    for i in 0..100 {
        // Probability handed to `random_bool` for each predicate bit: the first
        // iterations pin it to 1 and then to 0, the rest draw it at random.
        let filter_percent = match i {
            0..=4 => 1.,
            5..=10 => 0.,
            _ => rng.random_range(0.0..1.0),
        };

        // Probability handed to the generators for each data slot being non-null.
        let valid_percent = rng.random_range(0.0..1.0);

        let array_len = rng.random_range(32..256);
        let array_offset = rng.random_range(0..10);

        let filter_offset = rng.random_range(0..10);
        // The sliced predicate ends up `filter_truncate` entries shorter than the array.
        let filter_truncate = rng.random_range(0..10);
        let bools: Vec<_> = std::iter::from_fn(|| Some(rng.random_bool(filter_percent)))
            .take(array_len + filter_offset - filter_truncate)
            .collect();

        let predicate = BooleanArray::from_iter(bools.iter().cloned().map(Some));

        // Drop the first `filter_offset` bits so the predicate has a non-zero offset;
        // `bools` is re-sliced the same way to stay aligned with it.
        let predicate = predicate.slice(filter_offset, array_len - filter_truncate);
        let predicate = predicate.as_any().downcast_ref::<BooleanArray>().unwrap();
        let bools = &bools[filter_offset..];

        // --- Int32 array ---
        let values = gen_primitive(array_len + array_offset, valid_percent);
        let src = Int32Array::from_iter(values.iter().cloned());

        // Slice the data too, keeping the reference values aligned with the slice.
        let src = src.slice(array_offset, array_len);
        let src = src.as_any().downcast_ref::<Int32Array>().unwrap();
        let values = &values[array_offset..];

        let filtered = filter(src, predicate).unwrap();
        let array = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
        let actual: Vec<_> = array.iter().collect();

        assert_eq!(actual, filter_rust(values.iter().cloned(), bools));

        // --- String array (strings of 0..20 characters) ---
        let strings = gen_strings(array_len + array_offset, valid_percent, 0..20);
        let src = StringArray::from_iter(as_deref(&strings));

        let src = src.slice(array_offset, array_len);
        let src = src.as_any().downcast_ref::<StringArray>().unwrap();

        let filtered = filter(src, predicate).unwrap();
        let array = filtered.as_any().downcast_ref::<StringArray>().unwrap();
        let actual: Vec<_> = array.iter().collect();

        // Also reused as the expectation for the dictionary case below.
        let expected_strings = filter_rust(as_deref(&strings[array_offset..]), bools);
        assert_eq!(actual, expected_strings);

        // --- Dictionary array built from the same strings ---
        let src = DictionaryArray::<Int32Type>::from_iter(as_deref(&strings));

        let src = src.slice(array_offset, array_len);
        let src = src
            .as_any()
            .downcast_ref::<DictionaryArray<Int32Type>>()
            .unwrap();

        let filtered = filter(src, predicate).unwrap();

        let array = filtered
            .as_any()
            .downcast_ref::<DictionaryArray<Int32Type>>()
            .unwrap();

        let values = array
            .values()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        // Resolve every non-null key through the dictionary values before comparing.
        let actual: Vec<_> = array
            .keys()
            .iter()
            .map(|key| key.map(|key| values.value(key as usize)))
            .collect();

        assert_eq!(actual, expected_strings);
    }
}
1995
// Filters a four-entry map array (one entry null), keeping the first and last entries.
#[test]
fn test_filter_map() {
    let mut source =
        MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(4));
    // Entry 0: {key1: 1}
    source.keys().append_value("key1");
    source.values().append_value(1);
    source.append(true).unwrap();
    // Entry 1: {key2: 2, key3: 3}
    source.keys().append_value("key2");
    source.keys().append_value("key3");
    source.values().append_value(2);
    source.values().append_value(3);
    source.append(true).unwrap();
    // Entry 2: null
    source.append(false).unwrap();
    // Entry 3: {key1: 1}
    source.keys().append_value("key1");
    source.values().append_value(1);
    source.append(true).unwrap();
    let input = Arc::new(source.finish()) as ArrayRef;

    let predicate = BooleanArray::from(vec![Some(true), Some(false), Some(false), Some(true)]);
    let got = filter(&input, &predicate).unwrap();

    // Expected: two identical {key1: 1} entries.
    let mut wanted =
        MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(2));
    wanted.keys().append_value("key1");
    wanted.values().append_value(1);
    wanted.append(true).unwrap();
    wanted.keys().append_value("key1");
    wanted.values().append_value(1);
    wanted.append(true).unwrap();
    let expected = Arc::new(wanted.finish()) as ArrayRef;

    assert_eq!(&expected, &got);
}
2032
// Filters a fixed-size-list array of three lists (size 3, values 0..9) with two masks.
#[test]
fn test_filter_fixed_size_list_arrays() {
    let item_field = Arc::new(Field::new_list_field(DataType::Int32, false));
    let items = Arc::new(Int32Array::from_iter_values(0..9));
    let input = FixedSizeListArray::new(item_field, 3, items, None);

    // Keep only the first list.
    let only_first = BooleanArray::from(vec![true, false, false]);
    let out = filter(&input, &only_first).unwrap();
    let lists = out.as_any().downcast_ref::<FixedSizeListArray>().unwrap();

    assert_eq!(lists.len(), 1);

    let first = lists.value(0);
    assert_eq!(
        &[0, 1, 2],
        first.as_any().downcast_ref::<Int32Array>().unwrap().values()
    );

    // Keep the first and the last list.
    let first_and_last = BooleanArray::from(vec![true, false, true]);
    let out = filter(&input, &first_and_last).unwrap();
    let lists = out.as_any().downcast_ref::<FixedSizeListArray>().unwrap();

    assert_eq!(lists.len(), 2);

    let first = lists.value(0);
    assert_eq!(
        &[0, 1, 2],
        first.as_any().downcast_ref::<Int32Array>().unwrap().values()
    );
    let second = lists.value(1);
    assert_eq!(
        &[6, 7, 8],
        second.as_any().downcast_ref::<Int32Array>().unwrap().values()
    );
}
2070
// Filters a fixed-size-list array (size 2, five lists, lists 1 and 2 null), keeping
// lists 0, 1 and 3.
#[test]
fn test_filter_fixed_size_list_arrays_with_null() {
    let item_field = Arc::new(Field::new_list_field(DataType::Int32, false));
    let items = Arc::new(Int32Array::from_iter_values(0..10));
    let validity = Some(NullBuffer::from(vec![true, false, false, true, true]));
    let input = FixedSizeListArray::new(item_field, 2, items, validity);

    let predicate = BooleanArray::from(vec![true, true, false, true, false]);
    let out = filter(&input, &predicate).unwrap();
    let lists = out.as_any().downcast_ref::<FixedSizeListArray>().unwrap();

    assert_eq!(lists.len(), 3);

    let head = lists.value(0);
    assert_eq!(
        &[0, 1],
        head.as_any().downcast_ref::<Int32Array>().unwrap().values()
    );
    assert!(lists.is_null(1));
    let tail = lists.value(2);
    assert_eq!(
        &[6, 7],
        tail.as_any().downcast_ref::<Int32Array>().unwrap().values()
    );
}
2097
2098 fn test_filter_union_array(array: UnionArray) {
2099 let filter_array = BooleanArray::from(vec![true, false, false]);
2100 let c = filter(&array, &filter_array).unwrap();
2101 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2102
2103 let mut builder = UnionBuilder::new_dense();
2104 builder.append::<Int32Type>("A", 1).unwrap();
2105 let expected_array = builder.build().unwrap();
2106
2107 compare_union_arrays(filtered, &expected_array);
2108
2109 let filter_array = BooleanArray::from(vec![true, false, true]);
2110 let c = filter(&array, &filter_array).unwrap();
2111 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2112
2113 let mut builder = UnionBuilder::new_dense();
2114 builder.append::<Int32Type>("A", 1).unwrap();
2115 builder.append::<Int32Type>("A", 34).unwrap();
2116 let expected_array = builder.build().unwrap();
2117
2118 compare_union_arrays(filtered, &expected_array);
2119
2120 let filter_array = BooleanArray::from(vec![true, true, false]);
2121 let c = filter(&array, &filter_array).unwrap();
2122 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2123
2124 let mut builder = UnionBuilder::new_dense();
2125 builder.append::<Int32Type>("A", 1).unwrap();
2126 builder.append::<Float64Type>("B", 3.2).unwrap();
2127 let expected_array = builder.build().unwrap();
2128
2129 compare_union_arrays(filtered, &expected_array);
2130 }
2131
// Runs the shared union checks against a dense union.
#[test]
fn test_filter_union_array_dense() {
    let mut dense = UnionBuilder::new_dense();
    dense.append::<Int32Type>("A", 1).unwrap();
    dense.append::<Float64Type>("B", 3.2).unwrap();
    dense.append::<Int32Type>("A", 34).unwrap();

    test_filter_union_array(dense.build().unwrap());
}
2142
// Filters a dense union whose slots all belong to field "A", keeping the first two,
// and compares the full array data against an equivalent freshly built union.
#[test]
fn test_filter_run_union_array_dense() {
    let mut source = UnionBuilder::new_dense();
    source.append::<Int32Type>("A", 1).unwrap();
    source.append::<Int32Type>("A", 3).unwrap();
    source.append::<Int32Type>("A", 34).unwrap();
    let input = source.build().unwrap();

    let first_two = BooleanArray::from(vec![true, true, false]);
    let out = filter(&input, &first_two).unwrap();
    let filtered = out.as_any().downcast_ref::<UnionArray>().unwrap();

    let mut wanted = UnionBuilder::new_dense();
    wanted.append::<Int32Type>("A", 1).unwrap();
    wanted.append::<Int32Type>("A", 3).unwrap();
    let expected = wanted.build().unwrap();

    assert_eq!(filtered.to_data(), expected.to_data());
}
2162
// Filters a dense union containing a null "B" slot, once without and once with that
// null slot selected.
#[test]
fn test_filter_union_array_dense_with_nulls() {
    let mut source = UnionBuilder::new_dense();
    source.append::<Int32Type>("A", 1).unwrap();
    source.append::<Float64Type>("B", 3.2).unwrap();
    source.append_null::<Float64Type>("B").unwrap();
    source.append::<Int32Type>("A", 34).unwrap();
    let input = source.build().unwrap();

    // First two slots: no null involved.
    let first_two = BooleanArray::from(vec![true, true, false, false]);
    let out = filter(&input, &first_two).unwrap();
    let filtered = out.as_any().downcast_ref::<UnionArray>().unwrap();

    let mut wanted = UnionBuilder::new_dense();
    wanted.append::<Int32Type>("A", 1).unwrap();
    wanted.append::<Float64Type>("B", 3.2).unwrap();
    let expected = wanted.build().unwrap();

    compare_union_arrays(filtered, &expected);

    // Slots 0 and 2: the second one is the null "B".
    let with_null = BooleanArray::from(vec![true, false, true, false]);
    let out = filter(&input, &with_null).unwrap();
    let filtered = out.as_any().downcast_ref::<UnionArray>().unwrap();

    let mut wanted = UnionBuilder::new_dense();
    wanted.append::<Int32Type>("A", 1).unwrap();
    wanted.append_null::<Float64Type>("B").unwrap();
    let expected = wanted.build().unwrap();

    compare_union_arrays(filtered, &expected);
}
2194
// Runs the shared union checks against a sparse union.
#[test]
fn test_filter_union_array_sparse() {
    let mut sparse = UnionBuilder::new_sparse();
    sparse.append::<Int32Type>("A", 1).unwrap();
    sparse.append::<Float64Type>("B", 3.2).unwrap();
    sparse.append::<Int32Type>("A", 34).unwrap();

    test_filter_union_array(sparse.build().unwrap());
}
2205
// Filters a sparse union containing a null "B" slot, selecting slot 0 and the null slot.
#[test]
fn test_filter_union_array_sparse_with_nulls() {
    let mut source = UnionBuilder::new_sparse();
    source.append::<Int32Type>("A", 1).unwrap();
    source.append::<Float64Type>("B", 3.2).unwrap();
    source.append_null::<Float64Type>("B").unwrap();
    source.append::<Int32Type>("A", 34).unwrap();
    let input = source.build().unwrap();

    let with_null = BooleanArray::from(vec![true, false, true, false]);
    let out = filter(&input, &with_null).unwrap();
    let filtered = out.as_any().downcast_ref::<UnionArray>().unwrap();

    let mut wanted = UnionBuilder::new_sparse();
    wanted.append::<Int32Type>("A", 1).unwrap();
    wanted.append_null::<Float64Type>("B").unwrap();
    let expected = wanted.build().unwrap();

    compare_union_arrays(filtered, &expected);
}
2226
2227 fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) {
2228 assert_eq!(union1.len(), union2.len());
2229
2230 for i in 0..union1.len() {
2231 let type_id = union1.type_id(i);
2232
2233 let slot1 = union1.value(i);
2234 let slot2 = union2.value(i);
2235
2236 assert_eq!(slot1.is_null(0), slot2.is_null(0));
2237
2238 if !slot1.is_null(0) && !slot2.is_null(0) {
2239 match type_id {
2240 0 => {
2241 let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
2242 assert_eq!(slot1.len(), 1);
2243 let value1 = slot1.value(0);
2244
2245 let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap();
2246 assert_eq!(slot2.len(), 1);
2247 let value2 = slot2.value(0);
2248 assert_eq!(value1, value2);
2249 }
2250 1 => {
2251 let slot1 = slot1.as_any().downcast_ref::<Float64Array>().unwrap();
2252 assert_eq!(slot1.len(), 1);
2253 let value1 = slot1.value(0);
2254
2255 let slot2 = slot2.as_any().downcast_ref::<Float64Array>().unwrap();
2256 assert_eq!(slot2.len(), 1);
2257 let value2 = slot2.value(0);
2258 assert_eq!(value1, value2);
2259 }
2260 _ => unreachable!(),
2261 }
2262 }
2263 }
2264 }
2265
/// Filters struct arrays with one and two child columns, each with and
/// without a struct-level null mask, and checks that both the children and
/// the null mask are filtered by the predicate.
#[test]
fn test_filter_struct() {
    let predicate = BooleanArray::from(vec![true, false, true, false]);

    // Child columns together with their expected filtered counterparts.
    let a: ArrayRef = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
    let a_filtered: ArrayRef = Arc::new(StringArray::from(vec!["hello", "world"]));
    let b: ArrayRef = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
    let b_filtered: ArrayRef = Arc::new(Int32Array::from(vec![5, 7]));

    // Struct-level validity before and after applying the predicate.
    let null_mask = NullBuffer::from(vec![true, false, false, true]);
    let null_mask_filtered = NullBuffer::from(vec![true, false]);

    let a_field = Field::new("a", DataType::Utf8, false);
    let b_field = Field::new("b", DataType::Int32, false);
    let single = Fields::from(vec![a_field.clone()]);
    let double = Fields::from(vec![a_field, b_field]);

    // (fields, input columns, expected columns, input nulls, expected nulls)
    let scenarios = [
        (
            single.clone(),
            vec![a.clone()],
            vec![a_filtered.clone()],
            None,
            None,
        ),
        (
            single,
            vec![a.clone()],
            vec![a_filtered.clone()],
            Some(null_mask.clone()),
            Some(null_mask_filtered.clone()),
        ),
        (
            double.clone(),
            vec![a.clone(), b.clone()],
            vec![a_filtered.clone(), b_filtered.clone()],
            None,
            None,
        ),
        (
            double,
            vec![a, b],
            vec![a_filtered, b_filtered],
            Some(null_mask),
            Some(null_mask_filtered),
        ),
    ];

    for (fields, columns, columns_filtered, nulls, nulls_filtered) in scenarios {
        let array = StructArray::new(fields.clone(), columns, nulls);
        let expected = StructArray::new(fields, columns_filtered, nulls_filtered);

        let result = filter(&array, &predicate).unwrap();

        assert_eq!(result.to_data(), expected.to_data());
    }
}
2336
/// Filters a record batch whose only column is a struct containing a nested
/// struct with no fields, and checks that the filtered batch reports the
/// number of rows selected by the predicate.
#[test]
fn test_filter_empty_struct() {
    // Layout of the outer struct column "a": a nullable Int64 "b" and a
    // nullable struct "c" that has no child fields at all.
    let inner_fields = Fields::from(vec![
        Field::new("b", DataType::Int64, true),
        Field::new("c", DataType::Struct(Fields::empty()), true),
    ]);

    let schema = Arc::new(Schema::new(vec![Field::new(
        "a",
        DataType::Struct(inner_fields.clone()),
        true,
    )]));

    let b: ArrayRef = Arc::new(Int64Array::from(vec![None, None, None]));
    // With no children the row count has to be passed explicitly.
    let c: ArrayRef = Arc::new(StructArray::new_empty_fields(
        3,
        Some(NullBuffer::from(vec![true, true, true])),
    ));
    let a = StructArray::new(
        inner_fields,
        vec![b, c],
        Some(NullBuffer::from(vec![true, true, true])),
    );

    let record_batch = RecordBatch::try_new(schema, vec![Arc::new(a)]).unwrap();
    println!("{record_batch:?}");

    let predicate = BooleanArray::from(vec![true, false, true]);
    let filtered_batch = filter_record_batch(&record_batch, &predicate).unwrap();

    // Two of the three predicate values are true.
    assert_eq!(filtered_batch.num_rows(), 2);
}
2391
/// `filter_bits` is expected to panic when given a 9-element predicate for
/// an 8-bit buffer.
#[test]
#[should_panic]
fn test_filter_bits_too_large() {
    let bits = BooleanBuffer::from(vec![false; 8]);
    // One element longer than the buffer being filtered.
    let mask = BooleanArray::from(vec![true; 9]);
    let oversized = FilterBuilder::new(&mask).build();
    filter_bits(&bits, &oversized);
}
2400
/// `filter_native` is expected to panic when given a 9-element predicate for
/// an 8-element value slice, even though the predicate selects nothing.
#[test]
#[should_panic]
fn test_filter_native_too_large() {
    let natives = vec![1; 8];
    // One element longer than the values being filtered.
    let mask = BooleanArray::from(vec![false; 9]);
    let oversized = FilterBuilder::new(&mask).build();
    filter_native(&natives, &oversized);
}
2409}