1use std::ops::AddAssign;
21use std::sync::Arc;
22
23use arrow_array::builder::BooleanBufferBuilder;
24use arrow_array::cast::AsArray;
25use arrow_array::types::{
26 ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
27};
28use arrow_array::*;
29use arrow_buffer::{ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer, bit_util};
30use arrow_buffer::{Buffer, MutableBuffer};
31use arrow_data::ArrayDataBuilder;
32use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
33use arrow_data::transform::MutableArrayData;
34use arrow_schema::*;
35
/// Selectivity cut-off consulted by `IterationStrategy::default_strategy`.
///
/// When the fraction of selected rows is strictly greater than this value the
/// filter is walked as contiguous slices, otherwise as individual indices.
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
43
44#[derive(Debug)]
56pub struct SlicesIterator<'a>(BitSliceIterator<'a>);
57
58impl<'a> SlicesIterator<'a> {
59 pub fn new(filter: &'a BooleanArray) -> Self {
61 filter.values().into()
62 }
63}
64
65impl<'a> From<&'a BooleanBuffer> for SlicesIterator<'a> {
66 fn from(filter: &'a BooleanBuffer) -> Self {
67 Self(filter.set_slices())
68 }
69}
70
71impl Iterator for SlicesIterator<'_> {
72 type Item = (usize, usize);
73
74 fn next(&mut self) -> Option<Self::Item> {
75 self.0.next()
76 }
77}
78
79struct IndexIterator<'a> {
84 remaining: usize,
85 iter: BitIndexIterator<'a>,
86}
87
88impl<'a> IndexIterator<'a> {
89 fn new(filter: &'a BooleanArray, remaining: usize) -> Self {
90 assert_eq!(filter.null_count(), 0);
91 let iter = filter.values().set_indices();
92 Self { remaining, iter }
93 }
94}
95
96impl Iterator for IndexIterator<'_> {
97 type Item = usize;
98
99 fn next(&mut self) -> Option<Self::Item> {
100 if self.remaining != 0 {
101 let next = self.iter.next().expect("IndexIterator exhausted early");
104 self.remaining -= 1;
105 return Some(next);
107 }
108 None
109 }
110
111 fn size_hint(&self) -> (usize, Option<usize>) {
112 (self.remaining, Some(self.remaining))
113 }
114}
115
116fn filter_count(filter: &BooleanArray) -> usize {
118 filter.values().count_set_bits()
119}
120
121pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
123 let nulls = filter.nulls().unwrap();
124 let mask = filter.values() & nulls.inner();
125 BooleanArray::new(mask, None)
126}
127
128pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
157 let mut filter_builder = FilterBuilder::new(predicate);
158
159 if FilterBuilder::is_optimize_beneficial(values.data_type()) {
160 filter_builder = filter_builder.optimize();
163 }
164
165 let predicate = filter_builder.build();
166
167 filter_array(values, &predicate)
168}
169
170pub fn filter_record_batch(
181 record_batch: &RecordBatch,
182 predicate: &BooleanArray,
183) -> Result<RecordBatch, ArrowError> {
184 let mut filter_builder = FilterBuilder::new(predicate);
185 let num_cols = record_batch.num_columns();
186 if num_cols > 1
187 || (num_cols > 0
188 && FilterBuilder::is_optimize_beneficial(
189 record_batch.schema_ref().field(0).data_type(),
190 ))
191 {
192 filter_builder = filter_builder.optimize();
195 }
196 let filter = filter_builder.build();
197
198 filter.filter_record_batch(record_batch)
199}
200
201#[derive(Debug)]
203pub struct FilterBuilder {
204 filter: BooleanArray,
205 count: usize,
206 strategy: IterationStrategy,
207}
208
209impl FilterBuilder {
210 pub fn new(filter: &BooleanArray) -> Self {
212 let filter = match filter.null_count() {
213 0 => filter.clone(),
214 _ => prep_null_mask_filter(filter),
215 };
216
217 let count = filter_count(&filter);
218 let strategy = IterationStrategy::default_strategy(filter.len(), count);
219
220 Self {
221 filter,
222 count,
223 strategy,
224 }
225 }
226
227 pub fn optimize(mut self) -> Self {
238 match self.strategy {
239 IterationStrategy::SlicesIterator => {
240 let slices = SlicesIterator::new(&self.filter).collect();
241 self.strategy = IterationStrategy::Slices(slices)
242 }
243 IterationStrategy::IndexIterator => {
244 let indices = IndexIterator::new(&self.filter, self.count).collect();
245 self.strategy = IterationStrategy::Indices(indices)
246 }
247 _ => {}
248 }
249 self
250 }
251
252 pub fn is_optimize_beneficial(data_type: &DataType) -> bool {
257 match data_type {
258 DataType::Struct(fields) => {
259 fields.len() > 1
260 || fields.len() == 1
261 && FilterBuilder::is_optimize_beneficial(fields[0].data_type())
262 }
263 DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
264 _ => false,
265 }
266 }
267
268 pub fn build(self) -> FilterPredicate {
270 FilterPredicate {
271 filter: self.filter,
272 count: self.count,
273 strategy: self.strategy,
274 }
275 }
276}
277
278#[derive(Debug)]
280enum IterationStrategy {
281 SlicesIterator,
283 IndexIterator,
285 Indices(Vec<usize>),
287 Slices(Vec<(usize, usize)>),
289 All,
291 None,
293}
294
295impl IterationStrategy {
296 fn default_strategy(filter_length: usize, filter_count: usize) -> Self {
299 if filter_length == 0 || filter_count == 0 {
300 return IterationStrategy::None;
301 }
302
303 if filter_count == filter_length {
304 return IterationStrategy::All;
305 }
306
307 let selectivity_frac = filter_count as f64 / filter_length as f64;
312 if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD {
313 return IterationStrategy::SlicesIterator;
314 }
315 IterationStrategy::IndexIterator
316 }
317}
318
319#[derive(Debug)]
321pub struct FilterPredicate {
322 filter: BooleanArray,
323 count: usize,
324 strategy: IterationStrategy,
325}
326
327impl FilterPredicate {
328 pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef, ArrowError> {
330 filter_array(values, self)
331 }
332
333 pub fn filter_record_batch(
338 &self,
339 record_batch: &RecordBatch,
340 ) -> Result<RecordBatch, ArrowError> {
341 let filtered_arrays = record_batch
342 .columns()
343 .iter()
344 .map(|a| filter_array(a, self))
345 .collect::<Result<Vec<_>, _>>()?;
346
347 unsafe {
350 Ok(RecordBatch::new_unchecked(
351 record_batch.schema(),
352 filtered_arrays,
353 self.count,
354 ))
355 }
356 }
357
358 pub fn count(&self) -> usize {
360 self.count
361 }
362}
363
364fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> {
365 if predicate.filter.len() > values.len() {
366 return Err(ArrowError::InvalidArgumentError(format!(
367 "Filter predicate of length {} is larger than target array of length {}",
368 predicate.filter.len(),
369 values.len()
370 )));
371 }
372
373 match predicate.strategy {
374 IterationStrategy::None => Ok(new_empty_array(values.data_type())),
375 IterationStrategy::All => Ok(values.slice(0, predicate.count)),
376 _ => downcast_primitive_array! {
378 values => Ok(Arc::new(filter_primitive(values, predicate))),
379 DataType::Boolean => {
380 let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
381 Ok(Arc::new(filter_boolean(values, predicate)))
382 }
383 DataType::Utf8 => {
384 Ok(Arc::new(filter_bytes(values.as_string::<i32>(), predicate)))
385 }
386 DataType::LargeUtf8 => {
387 Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
388 }
389 DataType::Utf8View => {
390 Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
391 }
392 DataType::Binary => {
393 Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
394 }
395 DataType::LargeBinary => {
396 Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
397 }
398 DataType::BinaryView => {
399 Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
400 }
401 DataType::FixedSizeBinary(_) => {
402 Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate)))
403 }
404 DataType::ListView(_) => {
405 Ok(Arc::new(filter_list_view::<i32>(values.as_list_view(), predicate)))
406 }
407 DataType::LargeListView(_) => {
408 Ok(Arc::new(filter_list_view::<i64>(values.as_list_view(), predicate)))
409 }
410 DataType::RunEndEncoded(_, _) => {
411 downcast_run_array!{
412 values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
413 t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t)
414 }
415 }
416 DataType::Dictionary(_, _) => downcast_dictionary_array! {
417 values => Ok(Arc::new(filter_dict(values, predicate))),
418 t => unimplemented!("Filter not supported for dictionary type {:?}", t)
419 }
420 DataType::Struct(_) => {
421 Ok(Arc::new(filter_struct(values.as_struct(), predicate)?))
422 }
423 DataType::Union(_, UnionMode::Sparse) => {
424 Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?))
425 }
426 _ => {
427 let data = values.to_data();
428 let mut mutable = MutableArrayData::new(
430 vec![&data],
431 false,
432 predicate.count,
433 );
434
435 match &predicate.strategy {
436 IterationStrategy::Slices(slices) => {
437 slices
438 .iter()
439 .for_each(|(start, end)| mutable.extend(0, *start, *end));
440 }
441 _ => {
442 let iter = SlicesIterator::new(&predicate.filter);
443 iter.for_each(|(start, end)| mutable.extend(0, start, end));
444 }
445 }
446
447 let data = mutable.freeze();
448 Ok(make_array(data))
449 }
450 },
451 }
452}
453
454fn filter_run_end_array<R: RunEndIndexType>(
456 array: &RunArray<R>,
457 predicate: &FilterPredicate,
458) -> Result<RunArray<R>, ArrowError>
459where
460 R::Native: Into<i64> + From<bool>,
461 R::Native: AddAssign,
462{
463 let run_ends: &RunEndBuffer<R::Native> = array.run_ends();
464 let mut new_run_ends = vec![R::default_value(); run_ends.len()];
465
466 let mut start = 0u64;
467 let mut j = 0;
468 let mut count = R::default_value();
469 let filter_values = predicate.filter.values();
470 let run_ends = run_ends.inner();
471
472 let pred: BooleanArray = BooleanBuffer::collect_bool(run_ends.len(), |i| {
473 let mut keep = false;
474 let mut end = run_ends[i].into() as u64;
475 let difference = end.saturating_sub(filter_values.len() as u64);
476 end -= difference;
477
478 for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
480 count += R::Native::from(pred);
481 keep |= pred
482 }
483 new_run_ends[j] = count;
485 j += keep as usize;
486
487 start = end;
488 keep
489 })
490 .into();
491
492 new_run_ends.truncate(j);
493
494 let values = array.values();
495 let values = filter(&values, &pred)?;
496
497 let run_ends = PrimitiveArray::<R>::try_new(new_run_ends.into(), None)?;
498 RunArray::try_new(&run_ends, &values)
499}
500
501fn filter_null_mask(
508 nulls: Option<&NullBuffer>,
509 predicate: &FilterPredicate,
510) -> Option<(usize, Buffer)> {
511 let nulls = nulls?;
512 if nulls.null_count() == 0 {
513 return None;
514 }
515
516 let nulls = filter_bits(nulls.inner(), predicate);
517 let null_count = predicate.count - nulls.count_set_bits_offset(0, predicate.count);
520
521 if null_count == 0 {
522 return None;
523 }
524
525 Some((null_count, nulls))
526}
527
528fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer {
530 let src = buffer.values();
531 let offset = buffer.offset();
532 assert!(buffer.len() >= predicate.filter.len());
533
534 match &predicate.strategy {
535 IterationStrategy::IndexIterator => {
536 let bits =
537 IndexIterator::new(&predicate.filter, predicate.count).map(|src_idx| unsafe {
539 bit_util::get_bit_raw(buffer.values().as_ptr(), src_idx + offset)
540 });
541
542 unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
544 }
545 IterationStrategy::Indices(indices) => {
546 let bits = indices.iter().map(|src_idx| unsafe {
548 bit_util::get_bit_raw(buffer.values().as_ptr(), *src_idx + offset)
549 });
550 unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
552 }
553 IterationStrategy::SlicesIterator => {
554 let mut builder = BooleanBufferBuilder::new(predicate.count);
555 for (start, end) in SlicesIterator::new(&predicate.filter) {
556 builder.append_packed_range(start + offset..end + offset, src)
557 }
558 builder.into()
559 }
560 IterationStrategy::Slices(slices) => {
561 let mut builder = BooleanBufferBuilder::new(predicate.count);
562 for (start, end) in slices {
563 builder.append_packed_range(*start + offset..*end + offset, src)
564 }
565 builder.into()
566 }
567 IterationStrategy::All | IterationStrategy::None => unreachable!(),
568 }
569}
570
571fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
573 let values = filter_bits(array.values(), predicate);
574
575 let mut builder = ArrayDataBuilder::new(DataType::Boolean)
576 .len(predicate.count)
577 .add_buffer(values);
578
579 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
580 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
581 }
582
583 let data = unsafe { builder.build_unchecked() };
584 BooleanArray::from(data)
585}
586
587#[inline(never)]
588fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer {
589 assert!(values.len() >= predicate.filter.len());
590
591 match &predicate.strategy {
592 IterationStrategy::SlicesIterator => {
593 let mut buffer = Vec::with_capacity(predicate.count);
594 for (start, end) in SlicesIterator::new(&predicate.filter) {
595 buffer.extend_from_slice(unsafe { values.get_unchecked(start..end) });
597 }
598 buffer.into()
599 }
600 IterationStrategy::Slices(slices) => {
601 let mut buffer = Vec::with_capacity(predicate.count);
602 for (start, end) in slices {
603 buffer.extend_from_slice(unsafe { values.get_unchecked(*start..*end) });
605 }
606 buffer.into()
607 }
608 IterationStrategy::IndexIterator => {
609 let iter = IndexIterator::new(&predicate.filter, predicate.count)
611 .map(|x| unsafe { *values.get_unchecked(x) });
612
613 unsafe { MutableBuffer::from_trusted_len_iter(iter) }.into()
615 }
616 IterationStrategy::Indices(indices) => {
617 let iter = indices.iter().map(|x| unsafe { *values.get_unchecked(*x) });
619 iter.collect::<Vec<_>>().into()
620 }
621 IterationStrategy::All | IterationStrategy::None => unreachable!(),
622 }
623}
624
625fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
627where
628 T: ArrowPrimitiveType,
629{
630 let values = array.values();
631 let buffer = filter_native(values, predicate);
632 let mut builder = ArrayDataBuilder::new(array.data_type().clone())
633 .len(predicate.count)
634 .add_buffer(buffer);
635
636 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
637 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
638 }
639
640 let data = unsafe { builder.build_unchecked() };
641 PrimitiveArray::from(data)
642}
643
644struct FilterBytes<'a, OffsetSize> {
649 src_offsets: &'a [OffsetSize],
650 src_values: &'a [u8],
651 dst_offsets: Vec<OffsetSize>,
652 dst_values: Vec<u8>,
653 cur_offset: OffsetSize,
654}
655
656impl<'a, OffsetSize> FilterBytes<'a, OffsetSize>
657where
658 OffsetSize: OffsetSizeTrait,
659{
660 fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self
661 where
662 T: ByteArrayType<Offset = OffsetSize>,
663 {
664 let dst_values = Vec::new();
665 let mut dst_offsets: Vec<OffsetSize> = Vec::with_capacity(capacity + 1);
666 let cur_offset = OffsetSize::from_usize(0).unwrap();
667
668 dst_offsets.push(cur_offset);
669
670 Self {
671 src_offsets: array.value_offsets(),
672 src_values: array.value_data(),
673 dst_offsets,
674 dst_values,
675 cur_offset,
676 }
677 }
678
679 #[inline]
681 fn get_value_offset(&self, idx: usize) -> usize {
682 self.src_offsets[idx].as_usize()
683 }
684
685 #[inline]
687 fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
688 let start = self.get_value_offset(idx);
690 let end = self.get_value_offset(idx + 1);
691 let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
692 (start, end, len)
693 }
694
695 fn extend_offsets_idx(&mut self, iter: impl Iterator<Item = usize>) {
696 self.dst_offsets.extend(iter.map(|idx| {
697 let start = self.src_offsets[idx].as_usize();
698 let end = self.src_offsets[idx + 1].as_usize();
699 let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
700 self.cur_offset += len;
701
702 self.cur_offset
703 }));
704 }
705
706 fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
708 self.dst_values.reserve_exact(self.cur_offset.as_usize());
709
710 for idx in iter {
711 let start = self.src_offsets[idx].as_usize();
712 let end = self.src_offsets[idx + 1].as_usize();
713 self.dst_values
714 .extend_from_slice(&self.src_values[start..end]);
715 }
716 }
717
718 fn extend_offsets_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>, count: usize) {
719 self.dst_offsets.reserve_exact(count);
720 for (start, end) in iter {
721 for idx in start..end {
723 let (_, _, len) = self.get_value_range(idx);
724 self.cur_offset += len;
725 self.dst_offsets.push(self.cur_offset);
726 }
727 }
728 }
729
730 fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
732 self.dst_values.reserve_exact(self.cur_offset.as_usize());
733
734 for (start, end) in iter {
735 let value_start = self.get_value_offset(start);
736 let value_end = self.get_value_offset(end);
737 self.dst_values
738 .extend_from_slice(&self.src_values[value_start..value_end]);
739 }
740 }
741}
742
743fn filter_bytes<T>(array: &GenericByteArray<T>, predicate: &FilterPredicate) -> GenericByteArray<T>
748where
749 T: ByteArrayType,
750{
751 let mut filter = FilterBytes::new(predicate.count, array);
752
753 match &predicate.strategy {
754 IterationStrategy::SlicesIterator => {
755 filter.extend_offsets_slices(SlicesIterator::new(&predicate.filter), predicate.count);
756 filter.extend_slices(SlicesIterator::new(&predicate.filter))
757 }
758 IterationStrategy::Slices(slices) => {
759 filter.extend_offsets_slices(slices.iter().cloned(), predicate.count);
760 filter.extend_slices(slices.iter().cloned())
761 }
762 IterationStrategy::IndexIterator => {
763 filter.extend_offsets_idx(IndexIterator::new(&predicate.filter, predicate.count));
764 filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count))
765 }
766 IterationStrategy::Indices(indices) => {
767 filter.extend_offsets_idx(indices.iter().cloned());
768 filter.extend_idx(indices.iter().cloned())
769 }
770 IterationStrategy::All | IterationStrategy::None => unreachable!(),
771 }
772
773 let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
774 .len(predicate.count)
775 .add_buffer(filter.dst_offsets.into())
776 .add_buffer(filter.dst_values.into());
777
778 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
779 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
780 }
781
782 let data = unsafe { builder.build_unchecked() };
783 GenericByteArray::from(data)
784}
785
786fn filter_byte_view<T: ByteViewType>(
788 array: &GenericByteViewArray<T>,
789 predicate: &FilterPredicate,
790) -> GenericByteViewArray<T> {
791 let new_view_buffer = filter_native(array.views(), predicate);
792
793 let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
794 .len(predicate.count)
795 .add_buffer(new_view_buffer)
796 .add_buffers(array.data_buffers().to_vec());
797
798 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
799 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
800 }
801
802 GenericByteViewArray::from(unsafe { builder.build_unchecked() })
803}
804
805fn filter_fixed_size_binary(
806 array: &FixedSizeBinaryArray,
807 predicate: &FilterPredicate,
808) -> FixedSizeBinaryArray {
809 let values: &[u8] = array.values();
810 let value_length = array.value_length() as usize;
811 let calculate_offset_from_index = |index: usize| index * value_length;
812 let buffer = match &predicate.strategy {
813 IterationStrategy::SlicesIterator => {
814 let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
815 for (start, end) in SlicesIterator::new(&predicate.filter) {
816 buffer.extend_from_slice(
817 &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)],
818 );
819 }
820 buffer
821 }
822 IterationStrategy::Slices(slices) => {
823 let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
824 for (start, end) in slices {
825 buffer.extend_from_slice(
826 &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)],
827 );
828 }
829 buffer
830 }
831 IterationStrategy::IndexIterator => {
832 let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| {
833 &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)]
834 });
835
836 let mut buffer = MutableBuffer::new(predicate.count * value_length);
837 iter.for_each(|item| buffer.extend_from_slice(item));
838 buffer
839 }
840 IterationStrategy::Indices(indices) => {
841 let iter = indices.iter().map(|x| {
842 &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)]
843 });
844
845 let mut buffer = MutableBuffer::new(predicate.count * value_length);
846 iter.for_each(|item| buffer.extend_from_slice(item));
847 buffer
848 }
849 IterationStrategy::All | IterationStrategy::None => unreachable!(),
850 };
851 let mut builder = ArrayDataBuilder::new(array.data_type().clone())
852 .len(predicate.count)
853 .add_buffer(buffer.into());
854
855 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
856 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
857 }
858
859 let data = unsafe { builder.build_unchecked() };
860 FixedSizeBinaryArray::from(data)
861}
862
863fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
865where
866 T: ArrowDictionaryKeyType,
867 T::Native: num_traits::Num,
868{
869 let builder = filter_primitive::<T>(array.keys(), predicate)
870 .into_data()
871 .into_builder()
872 .data_type(array.data_type().clone())
873 .child_data(vec![array.values().to_data()]);
874
875 DictionaryArray::from(unsafe { builder.build_unchecked() })
878}
879
880fn filter_struct(
882 array: &StructArray,
883 predicate: &FilterPredicate,
884) -> Result<StructArray, ArrowError> {
885 let columns = array
886 .columns()
887 .iter()
888 .map(|column| filter_array(column, predicate))
889 .collect::<Result<_, _>>()?;
890
891 let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
892 let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
893
894 Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
895 } else {
896 None
897 };
898
899 Ok(unsafe {
900 StructArray::new_unchecked_with_length(
901 array.fields().clone(),
902 columns,
903 nulls,
904 predicate.count(),
905 )
906 })
907}
908
909fn filter_sparse_union(
911 array: &UnionArray,
912 predicate: &FilterPredicate,
913) -> Result<UnionArray, ArrowError> {
914 let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
915 unreachable!()
916 };
917
918 let type_ids = filter_primitive(
919 &Int8Array::try_new(array.type_ids().clone(), None)?,
920 predicate,
921 );
922
923 let children = fields
924 .iter()
925 .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
926 .collect::<Result<_, _>>()?;
927
928 Ok(unsafe {
929 UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
930 })
931}
932
933fn filter_list_view<OffsetType: OffsetSizeTrait>(
935 array: &GenericListViewArray<OffsetType>,
936 predicate: &FilterPredicate,
937) -> GenericListViewArray<OffsetType> {
938 let filtered_offsets = filter_native::<OffsetType>(array.offsets(), predicate);
939 let filtered_sizes = filter_native::<OffsetType>(array.sizes(), predicate);
940
941 let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
943 let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
944
945 Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
946 } else {
947 None
948 };
949
950 let list_data = ArrayDataBuilder::new(array.data_type().clone())
951 .nulls(nulls)
952 .buffers(vec![filtered_offsets, filtered_sizes])
953 .child_data(vec![array.values().to_data()])
954 .len(predicate.count);
955
956 let list_data = unsafe { list_data.build_unchecked() };
957
958 GenericListViewArray::from(list_data)
959}
960
961#[cfg(test)]
962mod tests {
963 use super::*;
964 use arrow_array::builder::*;
965 use arrow_array::cast::as_run_array;
966 use arrow_array::types::*;
967 use arrow_data::ArrayData;
968 use rand::distr::uniform::{UniformSampler, UniformUsize};
969 use rand::distr::{Alphanumeric, StandardUniform};
970 use rand::prelude::*;
971 use rand::rng;
972
973 macro_rules! def_temporal_test {
974 ($test:ident, $array_type: ident, $data: expr) => {
975 #[test]
976 fn $test() {
977 let a = $data;
978 let b = BooleanArray::from(vec![true, false, true, false]);
979 let c = filter(&a, &b).unwrap();
980 let d = c.as_ref().as_any().downcast_ref::<$array_type>().unwrap();
981 assert_eq!(2, d.len());
982 assert_eq!(1, d.value(0));
983 assert_eq!(3, d.value(1));
984 }
985 };
986 }
987
988 def_temporal_test!(
989 test_filter_date32,
990 Date32Array,
991 Date32Array::from(vec![1, 2, 3, 4])
992 );
993 def_temporal_test!(
994 test_filter_date64,
995 Date64Array,
996 Date64Array::from(vec![1, 2, 3, 4])
997 );
998 def_temporal_test!(
999 test_filter_time32_second,
1000 Time32SecondArray,
1001 Time32SecondArray::from(vec![1, 2, 3, 4])
1002 );
1003 def_temporal_test!(
1004 test_filter_time32_millisecond,
1005 Time32MillisecondArray,
1006 Time32MillisecondArray::from(vec![1, 2, 3, 4])
1007 );
1008 def_temporal_test!(
1009 test_filter_time64_microsecond,
1010 Time64MicrosecondArray,
1011 Time64MicrosecondArray::from(vec![1, 2, 3, 4])
1012 );
1013 def_temporal_test!(
1014 test_filter_time64_nanosecond,
1015 Time64NanosecondArray,
1016 Time64NanosecondArray::from(vec![1, 2, 3, 4])
1017 );
1018 def_temporal_test!(
1019 test_filter_duration_second,
1020 DurationSecondArray,
1021 DurationSecondArray::from(vec![1, 2, 3, 4])
1022 );
1023 def_temporal_test!(
1024 test_filter_duration_millisecond,
1025 DurationMillisecondArray,
1026 DurationMillisecondArray::from(vec![1, 2, 3, 4])
1027 );
1028 def_temporal_test!(
1029 test_filter_duration_microsecond,
1030 DurationMicrosecondArray,
1031 DurationMicrosecondArray::from(vec![1, 2, 3, 4])
1032 );
1033 def_temporal_test!(
1034 test_filter_duration_nanosecond,
1035 DurationNanosecondArray,
1036 DurationNanosecondArray::from(vec![1, 2, 3, 4])
1037 );
1038 def_temporal_test!(
1039 test_filter_timestamp_second,
1040 TimestampSecondArray,
1041 TimestampSecondArray::from(vec![1, 2, 3, 4])
1042 );
1043 def_temporal_test!(
1044 test_filter_timestamp_millisecond,
1045 TimestampMillisecondArray,
1046 TimestampMillisecondArray::from(vec![1, 2, 3, 4])
1047 );
1048 def_temporal_test!(
1049 test_filter_timestamp_microsecond,
1050 TimestampMicrosecondArray,
1051 TimestampMicrosecondArray::from(vec![1, 2, 3, 4])
1052 );
1053 def_temporal_test!(
1054 test_filter_timestamp_nanosecond,
1055 TimestampNanosecondArray,
1056 TimestampNanosecondArray::from(vec![1, 2, 3, 4])
1057 );
1058
1059 #[test]
1060 fn test_filter_array_slice() {
1061 let a = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
1062 let b = BooleanArray::from(vec![true, false, false, true]);
1063 let c = filter(&a, &b).unwrap();
1067 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1068 assert_eq!(2, d.len());
1069 assert_eq!(6, d.value(0));
1070 assert_eq!(9, d.value(1));
1071 }
1072
1073 #[test]
1074 fn test_filter_array_low_density() {
1075 let mut data_values = (1..=65).collect::<Vec<i32>>();
1077 let mut filter_values = (1..=65).map(|i| matches!(i % 65, 0)).collect::<Vec<bool>>();
1078 data_values.extend_from_slice(&[66, 67]);
1080 filter_values.extend_from_slice(&[false, true]);
1081 let a = Int32Array::from(data_values);
1082 let b = BooleanArray::from(filter_values);
1083 let c = filter(&a, &b).unwrap();
1084 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1085 assert_eq!(2, d.len());
1086 assert_eq!(65, d.value(0));
1087 assert_eq!(67, d.value(1));
1088 }
1089
1090 #[test]
1091 fn test_filter_array_high_density() {
1092 let mut data_values = (1..=65).map(Some).collect::<Vec<_>>();
1094 let mut filter_values = (1..=65)
1095 .map(|i| !matches!(i % 65, 0))
1096 .collect::<Vec<bool>>();
1097 data_values[1] = None;
1099 data_values.extend_from_slice(&[Some(66), None, Some(67), None]);
1101 filter_values.extend_from_slice(&[false, true, true, true]);
1102 let a = Int32Array::from(data_values);
1103 let b = BooleanArray::from(filter_values);
1104 let c = filter(&a, &b).unwrap();
1105 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1106 assert_eq!(67, d.len());
1107 assert_eq!(3, d.null_count());
1108 assert_eq!(1, d.value(0));
1109 assert!(d.is_null(1));
1110 assert_eq!(64, d.value(63));
1111 assert!(d.is_null(64));
1112 assert_eq!(67, d.value(65));
1113 }
1114
1115 #[test]
1116 fn test_filter_string_array_simple() {
1117 let a = StringArray::from(vec!["hello", " ", "world", "!"]);
1118 let b = BooleanArray::from(vec![true, false, true, false]);
1119 let c = filter(&a, &b).unwrap();
1120 let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1121 assert_eq!(2, d.len());
1122 assert_eq!("hello", d.value(0));
1123 assert_eq!("world", d.value(1));
1124 }
1125
1126 #[test]
1127 fn test_filter_primitive_array_with_null() {
1128 let a = Int32Array::from(vec![Some(5), None]);
1129 let b = BooleanArray::from(vec![false, true]);
1130 let c = filter(&a, &b).unwrap();
1131 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1132 assert_eq!(1, d.len());
1133 assert!(d.is_null(0));
1134 }
1135
1136 #[test]
1137 fn test_filter_string_array_with_null() {
1138 let a = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
1139 let b = BooleanArray::from(vec![true, false, false, true]);
1140 let c = filter(&a, &b).unwrap();
1141 let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1142 assert_eq!(2, d.len());
1143 assert_eq!("hello", d.value(0));
1144 assert!(!d.is_null(0));
1145 assert!(d.is_null(1));
1146 }
1147
1148 #[test]
1149 fn test_filter_binary_array_with_null() {
1150 let data: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
1151 let a = BinaryArray::from(data);
1152 let b = BooleanArray::from(vec![true, false, false, true]);
1153 let c = filter(&a, &b).unwrap();
1154 let d = c.as_ref().as_any().downcast_ref::<BinaryArray>().unwrap();
1155 assert_eq!(2, d.len());
1156 assert_eq!(b"hello", d.value(0));
1157 assert!(!d.is_null(0));
1158 assert!(d.is_null(1));
1159 }
1160
1161 fn _test_filter_byte_view<T>()
1162 where
1163 T: ByteViewType,
1164 str: AsRef<T::Native>,
1165 T::Native: PartialEq,
1166 {
1167 let array = {
1168 let mut builder = GenericByteViewBuilder::<T>::new();
1170 builder.append_value("hello");
1171 builder.append_value("world");
1172 builder.append_null();
1173 builder.append_value("large payload over 12 bytes");
1174 builder.append_value("lulu");
1175 builder.finish()
1176 };
1177
1178 {
1179 let predicate = BooleanArray::from(vec![true, false, true, true, false]);
1180 let actual = filter(&array, &predicate).unwrap();
1181
1182 assert_eq!(actual.len(), 3);
1183
1184 let expected = {
1185 let mut builder = GenericByteViewBuilder::<T>::new();
1187 builder.append_value("hello");
1188 builder.append_null();
1189 builder.append_value("large payload over 12 bytes");
1190 builder.finish()
1191 };
1192
1193 assert_eq!(actual.as_ref(), &expected);
1194 }
1195
1196 {
1197 let predicate = BooleanArray::from(vec![true, false, false, false, true]);
1198 let actual = filter(&array, &predicate).unwrap();
1199
1200 assert_eq!(actual.len(), 2);
1201
1202 let expected = {
1203 let mut builder = GenericByteViewBuilder::<T>::new();
1205 builder.append_value("hello");
1206 builder.append_value("lulu");
1207 builder.finish()
1208 };
1209
1210 assert_eq!(actual.as_ref(), &expected);
1211 }
1212 }
1213
1214 #[test]
1215 fn test_filter_string_view() {
1216 _test_filter_byte_view::<StringViewType>()
1217 }
1218
1219 #[test]
1220 fn test_filter_binary_view() {
1221 _test_filter_byte_view::<BinaryViewType>()
1222 }
1223
1224 #[test]
1225 fn test_filter_fixed_binary() {
1226 let v1 = [1_u8, 2];
1227 let v2 = [3_u8, 4];
1228 let v3 = [5_u8, 6];
1229 let v = vec![&v1, &v2, &v3];
1230 let a = FixedSizeBinaryArray::from(v);
1231 let b = BooleanArray::from(vec![true, false, true]);
1232 let c = filter(&a, &b).unwrap();
1233 let d = c
1234 .as_ref()
1235 .as_any()
1236 .downcast_ref::<FixedSizeBinaryArray>()
1237 .unwrap();
1238 assert_eq!(d.len(), 2);
1239 assert_eq!(d.value(0), &v1);
1240 assert_eq!(d.value(1), &v3);
1241 let c2 = FilterBuilder::new(&b)
1242 .optimize()
1243 .build()
1244 .filter(&a)
1245 .unwrap();
1246 let d2 = c2
1247 .as_ref()
1248 .as_any()
1249 .downcast_ref::<FixedSizeBinaryArray>()
1250 .unwrap();
1251 assert_eq!(d, d2);
1252
1253 let b = BooleanArray::from(vec![false, false, false]);
1254 let c = filter(&a, &b).unwrap();
1255 let d = c
1256 .as_ref()
1257 .as_any()
1258 .downcast_ref::<FixedSizeBinaryArray>()
1259 .unwrap();
1260 assert_eq!(d.len(), 0);
1261
1262 let b = BooleanArray::from(vec![true, true, true]);
1263 let c = filter(&a, &b).unwrap();
1264 let d = c
1265 .as_ref()
1266 .as_any()
1267 .downcast_ref::<FixedSizeBinaryArray>()
1268 .unwrap();
1269 assert_eq!(d.len(), 3);
1270 assert_eq!(d.value(0), &v1);
1271 assert_eq!(d.value(1), &v2);
1272 assert_eq!(d.value(2), &v3);
1273
1274 let b = BooleanArray::from(vec![false, false, true]);
1275 let c = filter(&a, &b).unwrap();
1276 let d = c
1277 .as_ref()
1278 .as_any()
1279 .downcast_ref::<FixedSizeBinaryArray>()
1280 .unwrap();
1281 assert_eq!(d.len(), 1);
1282 assert_eq!(d.value(0), &v3);
1283 let c2 = FilterBuilder::new(&b)
1284 .optimize()
1285 .build()
1286 .filter(&a)
1287 .unwrap();
1288 let d2 = c2
1289 .as_ref()
1290 .as_any()
1291 .downcast_ref::<FixedSizeBinaryArray>()
1292 .unwrap();
1293 assert_eq!(d, d2);
1294 }
1295
1296 #[test]
1297 fn test_filter_array_slice_with_null() {
1298 let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
1299 let b = BooleanArray::from(vec![true, false, false, true]);
1300 let c = filter(&a, &b).unwrap();
1304 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1305 assert_eq!(2, d.len());
1306 assert!(d.is_null(0));
1307 assert!(!d.is_null(1));
1308 assert_eq!(9, d.value(1));
1309 }
1310
1311 #[test]
1312 fn test_filter_run_end_encoding_array() {
1313 let run_ends = Int64Array::from(vec![2, 3, 8]);
1314 let values = Int64Array::from(vec![7, -2, 9]);
1315 let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1316 let b = BooleanArray::from(vec![true, false, true, false, true, false, true, false]);
1317 let c = filter(&a, &b).unwrap();
1318 let actual: &RunArray<Int64Type> = as_run_array(&c);
1319 assert_eq!(4, actual.len());
1320
1321 let expected = RunArray::try_new(
1322 &Int64Array::from(vec![1, 2, 4]),
1323 &Int64Array::from(vec![7, -2, 9]),
1324 )
1325 .expect("Failed to make expected RunArray test is broken");
1326
1327 assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1328 assert_eq!(actual.values(), expected.values())
1329 }
1330
1331 #[test]
1332 fn test_filter_run_end_encoding_array_remove_value() {
1333 let run_ends = Int32Array::from(vec![2, 3, 8, 10]);
1334 let values = Int32Array::from(vec![7, -2, 9, -8]);
1335 let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1336 let b = BooleanArray::from(vec![
1337 false, true, false, false, true, false, true, false, false, false,
1338 ]);
1339 let c = filter(&a, &b).unwrap();
1340 let actual: &RunArray<Int32Type> = as_run_array(&c);
1341 assert_eq!(3, actual.len());
1342
1343 let expected =
1344 RunArray::try_new(&Int32Array::from(vec![1, 3]), &Int32Array::from(vec![7, 9]))
1345 .expect("Failed to make expected RunArray test is broken");
1346
1347 assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1348 assert_eq!(actual.values(), expected.values())
1349 }
1350
1351 #[test]
1352 fn test_filter_run_end_encoding_array_remove_all_but_one() {
1353 let run_ends = Int16Array::from(vec![2, 3, 8, 10]);
1354 let values = Int16Array::from(vec![7, -2, 9, -8]);
1355 let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1356 let b = BooleanArray::from(vec![
1357 false, false, false, false, false, false, true, false, false, false,
1358 ]);
1359 let c = filter(&a, &b).unwrap();
1360 let actual: &RunArray<Int16Type> = as_run_array(&c);
1361 assert_eq!(1, actual.len());
1362
1363 let expected = RunArray::try_new(&Int16Array::from(vec![1]), &Int16Array::from(vec![9]))
1364 .expect("Failed to make expected RunArray test is broken");
1365
1366 assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1367 assert_eq!(actual.values(), expected.values())
1368 }
1369
1370 #[test]
1371 fn test_filter_run_end_encoding_array_empty() {
1372 let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1373 let values = Int64Array::from(vec![7, -2, 9, -8]);
1374 let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1375 let b = BooleanArray::from(vec![
1376 false, false, false, false, false, false, false, false, false, false,
1377 ]);
1378 let c = filter(&a, &b).unwrap();
1379 let actual: &RunArray<Int64Type> = as_run_array(&c);
1380 assert_eq!(0, actual.len());
1381 }
1382
1383 #[test]
1384 fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() {
1385 let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1386 let values = Int64Array::from(vec![7, -2, 9, -8]);
1387 let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1388 let b = BooleanArray::from(vec![false, true, true]);
1389 let c = filter(&a, &b).unwrap();
1390 let actual: &RunArray<Int64Type> = as_run_array(&c);
1391 assert_eq!(2, actual.len());
1392
1393 let expected = RunArray::try_new(
1394 &Int64Array::from(vec![1, 2]),
1395 &Int64Array::from(vec![7, -2]),
1396 )
1397 .expect("Failed to make expected RunArray test is broken");
1398
1399 assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1400 assert_eq!(actual.values(), expected.values())
1401 }
1402
1403 #[test]
1404 fn test_filter_dictionary_array() {
1405 let values = [Some("hello"), None, Some("world"), Some("!")];
1406 let a: Int8DictionaryArray = values.iter().copied().collect();
1407 let b = BooleanArray::from(vec![false, true, true, false]);
1408 let c = filter(&a, &b).unwrap();
1409 let d = c
1410 .as_ref()
1411 .as_any()
1412 .downcast_ref::<Int8DictionaryArray>()
1413 .unwrap();
1414 let value_array = d.values();
1415 let values = value_array.as_any().downcast_ref::<StringArray>().unwrap();
1416 assert_eq!(3, values.len());
1418 assert_eq!(2, d.len());
1420 assert!(d.is_null(0));
1421 assert_eq!("world", values.value(d.keys().value(1) as usize));
1422 }
1423
1424 #[test]
1425 fn test_filter_list_array() {
1426 let value_data = ArrayData::builder(DataType::Int32)
1427 .len(8)
1428 .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
1429 .build()
1430 .unwrap();
1431
1432 let value_offsets = Buffer::from_slice_ref([0i64, 3, 6, 8, 8]);
1433
1434 let list_data_type =
1435 DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
1436 let list_data = ArrayData::builder(list_data_type)
1437 .len(4)
1438 .add_buffer(value_offsets)
1439 .add_child_data(value_data)
1440 .null_bit_buffer(Some(Buffer::from([0b00000111])))
1441 .build()
1442 .unwrap();
1443
1444 let a = LargeListArray::from(list_data);
1446 let b = BooleanArray::from(vec![false, true, false, true]);
1447 let result = filter(&a, &b).unwrap();
1448
1449 let value_data = ArrayData::builder(DataType::Int32)
1451 .len(3)
1452 .add_buffer(Buffer::from_slice_ref([3, 4, 5]))
1453 .build()
1454 .unwrap();
1455
1456 let value_offsets = Buffer::from_slice_ref([0i64, 3, 3]);
1457
1458 let list_data_type =
1459 DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
1460 let expected = ArrayData::builder(list_data_type)
1461 .len(2)
1462 .add_buffer(value_offsets)
1463 .add_child_data(value_data)
1464 .null_bit_buffer(Some(Buffer::from([0b00000001])))
1465 .build()
1466 .unwrap();
1467
1468 assert_eq!(&make_array(expected), &result);
1469 }
1470
1471 fn test_case_filter_list_view<T: OffsetSizeTrait>() {
1472 let mut list_array = GenericListViewBuilder::<T, _>::new(Int32Builder::new());
1474 list_array.append_value([Some(1), Some(2)]);
1475 list_array.append_null();
1476 list_array.append_value([]);
1477 list_array.append_value([Some(3), Some(4)]);
1478
1479 let list_array = list_array.finish();
1480 let predicate = BooleanArray::from_iter([true, false, true, false]);
1481
1482 let filtered = filter(&list_array, &predicate)
1484 .unwrap()
1485 .as_list_view::<T>()
1486 .clone();
1487
1488 let mut expected =
1489 GenericListViewBuilder::<T, _>::with_capacity(Int32Builder::with_capacity(5), 3);
1490 expected.append_value([Some(1), Some(2)]);
1491 expected.append_value([]);
1492 let expected = expected.finish();
1493
1494 assert_eq!(&filtered, &expected);
1495 }
1496
1497 fn test_case_filter_sliced_list_view<T: OffsetSizeTrait>() {
1498 let mut list_array =
1500 GenericListViewBuilder::<T, _>::with_capacity(Int32Builder::with_capacity(6), 4);
1501 list_array.append_value([Some(1), Some(2)]);
1502 list_array.append_null();
1503 list_array.append_value([]);
1504 list_array.append_value([Some(3), Some(4)]);
1505
1506 let list_array = list_array.finish();
1507
1508 let sliced = list_array.slice(1, 3);
1510 let predicate = BooleanArray::from_iter([false, false, true]);
1511
1512 let filtered = filter(&sliced, &predicate)
1514 .unwrap()
1515 .as_list_view::<T>()
1516 .clone();
1517
1518 let mut expected = GenericListViewBuilder::<T, _>::new(Int32Builder::new());
1519 expected.append_value([Some(3), Some(4)]);
1520 let expected = expected.finish();
1521
1522 assert_eq!(&filtered, &expected);
1523 }
1524
1525 #[test]
1526 fn test_filter_list_view_array() {
1527 test_case_filter_list_view::<i32>();
1528 test_case_filter_list_view::<i64>();
1529
1530 test_case_filter_sliced_list_view::<i32>();
1531 test_case_filter_sliced_list_view::<i64>();
1532 }
1533
1534 #[test]
1535 fn test_slice_iterator_bits() {
1536 let filter_values = (0..64).map(|i| i == 1).collect::<Vec<bool>>();
1537 let filter = BooleanArray::from(filter_values);
1538 let filter_count = filter_count(&filter);
1539
1540 let iter = SlicesIterator::new(&filter);
1541 let chunks = iter.collect::<Vec<_>>();
1542
1543 assert_eq!(chunks, vec![(1, 2)]);
1544 assert_eq!(filter_count, 1);
1545 }
1546
1547 #[test]
1548 fn test_slice_iterator_bits1() {
1549 let filter_values = (0..64).map(|i| i != 1).collect::<Vec<bool>>();
1550 let filter = BooleanArray::from(filter_values);
1551 let filter_count = filter_count(&filter);
1552
1553 let iter = SlicesIterator::new(&filter);
1554 let chunks = iter.collect::<Vec<_>>();
1555
1556 assert_eq!(chunks, vec![(0, 1), (2, 64)]);
1557 assert_eq!(filter_count, 64 - 1);
1558 }
1559
1560 #[test]
1561 fn test_slice_iterator_chunk_and_bits() {
1562 let filter_values = (0..130).map(|i| i % 62 != 0).collect::<Vec<bool>>();
1563 let filter = BooleanArray::from(filter_values);
1564 let filter_count = filter_count(&filter);
1565
1566 let iter = SlicesIterator::new(&filter);
1567 let chunks = iter.collect::<Vec<_>>();
1568
1569 assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]);
1570 assert_eq!(filter_count, 61 + 61 + 5);
1571 }
1572
1573 #[test]
1574 fn test_null_mask() {
1575 let a = Int64Array::from(vec![Some(1), Some(2), None]);
1576
1577 let mask1 = BooleanArray::from(vec![Some(true), Some(true), None]);
1578 let out = filter(&a, &mask1).unwrap();
1579 assert_eq!(out.as_ref(), &a.slice(0, 2));
1580 }
1581
1582 #[test]
1583 fn test_filter_record_batch_no_columns() {
1584 let pred = BooleanArray::from(vec![Some(true), Some(true), None]);
1585 let options = RecordBatchOptions::default().with_row_count(Some(100));
1586 let record_batch =
1587 RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &options).unwrap();
1588 let out = filter_record_batch(&record_batch, &pred).unwrap();
1589
1590 assert_eq!(out.num_rows(), 2);
1591 }
1592
1593 #[test]
1594 fn test_fast_path() {
1595 let a: PrimitiveArray<Int64Type> = PrimitiveArray::from(vec![Some(1), Some(2), None]);
1596
1597 let mask = BooleanArray::from(vec![true, true, true]);
1599 let out = filter(&a, &mask).unwrap();
1600 let b = out
1601 .as_any()
1602 .downcast_ref::<PrimitiveArray<Int64Type>>()
1603 .unwrap();
1604 assert_eq!(&a, b);
1605
1606 let mask = BooleanArray::from(vec![false, false, false]);
1608 let out = filter(&a, &mask).unwrap();
1609 assert_eq!(out.len(), 0);
1610 assert_eq!(out.data_type(), &DataType::Int64);
1611 }
1612
1613 #[test]
1614 fn test_slices() {
1615 let bools = std::iter::repeat_n(true, 10)
1617 .chain(std::iter::repeat_n(false, 30))
1618 .chain(std::iter::repeat_n(true, 20))
1619 .chain(std::iter::repeat_n(false, 17))
1620 .chain(std::iter::repeat_n(true, 4));
1621
1622 let bool_array: BooleanArray = bools.map(Some).collect();
1623
1624 let slices: Vec<_> = SlicesIterator::new(&bool_array).collect();
1625 let expected = vec![(0, 10), (40, 60), (77, 81)];
1626 assert_eq!(slices, expected);
1627
1628 let len = bool_array.len();
1630 let sliced_array = bool_array.slice(7, len - 10);
1631 let sliced_array = sliced_array
1632 .as_any()
1633 .downcast_ref::<BooleanArray>()
1634 .unwrap();
1635 let slices: Vec<_> = SlicesIterator::new(sliced_array).collect();
1636 let expected = vec![(0, 3), (33, 53), (70, 71)];
1637 assert_eq!(slices, expected);
1638 }
1639
1640 fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) {
1641 let mut rng = rng();
1642
1643 let bools: Vec<bool> = std::iter::from_fn(|| Some(rng.random()))
1644 .take(mask_len)
1645 .collect();
1646
1647 let buffer = Buffer::from_iter(bools.iter().cloned());
1648
1649 let truncated_length = mask_len - offset - truncate;
1650
1651 let data = ArrayDataBuilder::new(DataType::Boolean)
1652 .len(truncated_length)
1653 .offset(offset)
1654 .add_buffer(buffer)
1655 .build()
1656 .unwrap();
1657
1658 let filter = BooleanArray::from(data);
1659
1660 let slice_bits: Vec<_> = SlicesIterator::new(&filter)
1661 .flat_map(|(start, end)| start..end)
1662 .collect();
1663
1664 let count = filter_count(&filter);
1665 let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect();
1666
1667 let expected_bits: Vec<_> = bools
1668 .iter()
1669 .skip(offset)
1670 .take(truncated_length)
1671 .enumerate()
1672 .flat_map(|(idx, v)| v.then(|| idx))
1673 .collect();
1674
1675 assert_eq!(slice_bits, expected_bits);
1676 assert_eq!(index_bits, expected_bits);
1677 }
1678
1679 #[test]
1680 #[cfg_attr(miri, ignore)]
1681 fn fuzz_test_slices_iterator() {
1682 let mut rng = rng();
1683
1684 let uusize = UniformUsize::new(usize::MIN, usize::MAX).unwrap();
1685 for _ in 0..100 {
1686 let mask_len = rng.random_range(0..1024);
1687 let max_offset = 64.min(mask_len);
1688 let offset = uusize.sample(&mut rng).checked_rem(max_offset).unwrap_or(0);
1689
1690 let max_truncate = 128.min(mask_len - offset);
1691 let truncate = uusize
1692 .sample(&mut rng)
1693 .checked_rem(max_truncate)
1694 .unwrap_or(0);
1695
1696 test_slices_fuzz(mask_len, offset, truncate);
1697 }
1698
1699 test_slices_fuzz(64, 0, 0);
1700 test_slices_fuzz(64, 8, 0);
1701 test_slices_fuzz(64, 8, 8);
1702 test_slices_fuzz(32, 8, 8);
1703 test_slices_fuzz(32, 5, 9);
1704 }
1705
1706 fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> {
1708 values
1709 .into_iter()
1710 .zip(predicate)
1711 .filter(|(_, x)| **x)
1712 .map(|(a, _)| a)
1713 .collect()
1714 }
1715
1716 fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>>
1718 where
1719 StandardUniform: Distribution<T>,
1720 {
1721 let mut rng = rng();
1722 (0..len)
1723 .map(|_| rng.random_bool(valid_percent).then(|| rng.random()))
1724 .collect()
1725 }
1726
1727 fn gen_strings(
1729 len: usize,
1730 valid_percent: f64,
1731 str_len_range: std::ops::Range<usize>,
1732 ) -> Vec<Option<String>> {
1733 let mut rng = rng();
1734 (0..len)
1735 .map(|_| {
1736 rng.random_bool(valid_percent).then(|| {
1737 let len = rng.random_range(str_len_range.clone());
1738 (0..len)
1739 .map(|_| char::from(rng.sample(Alphanumeric)))
1740 .collect()
1741 })
1742 })
1743 .collect()
1744 }
1745
1746 fn as_deref<T: std::ops::Deref>(src: &[Option<T>]) -> impl Iterator<Item = Option<&T::Target>> {
1748 src.iter().map(|x| x.as_deref())
1749 }
1750
1751 #[test]
1752 #[cfg_attr(miri, ignore)]
1753 fn fuzz_filter() {
1754 let mut rng = rng();
1755
1756 for i in 0..100 {
1757 let filter_percent = match i {
1758 0..=4 => 1.,
1759 5..=10 => 0.,
1760 _ => rng.random_range(0.0..1.0),
1761 };
1762
1763 let valid_percent = rng.random_range(0.0..1.0);
1764
1765 let array_len = rng.random_range(32..256);
1766 let array_offset = rng.random_range(0..10);
1767
1768 let filter_offset = rng.random_range(0..10);
1770 let filter_truncate = rng.random_range(0..10);
1771 let bools: Vec<_> = std::iter::from_fn(|| Some(rng.random_bool(filter_percent)))
1772 .take(array_len + filter_offset - filter_truncate)
1773 .collect();
1774
1775 let predicate = BooleanArray::from_iter(bools.iter().cloned().map(Some));
1776
1777 let predicate = predicate.slice(filter_offset, array_len - filter_truncate);
1779 let predicate = predicate.as_any().downcast_ref::<BooleanArray>().unwrap();
1780 let bools = &bools[filter_offset..];
1781
1782 let values = gen_primitive(array_len + array_offset, valid_percent);
1784 let src = Int32Array::from_iter(values.iter().cloned());
1785
1786 let src = src.slice(array_offset, array_len);
1787 let src = src.as_any().downcast_ref::<Int32Array>().unwrap();
1788 let values = &values[array_offset..];
1789
1790 let filtered = filter(src, predicate).unwrap();
1791 let array = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
1792 let actual: Vec<_> = array.iter().collect();
1793
1794 assert_eq!(actual, filter_rust(values.iter().cloned(), bools));
1795
1796 let strings = gen_strings(array_len + array_offset, valid_percent, 0..20);
1798 let src = StringArray::from_iter(as_deref(&strings));
1799
1800 let src = src.slice(array_offset, array_len);
1801 let src = src.as_any().downcast_ref::<StringArray>().unwrap();
1802
1803 let filtered = filter(src, predicate).unwrap();
1804 let array = filtered.as_any().downcast_ref::<StringArray>().unwrap();
1805 let actual: Vec<_> = array.iter().collect();
1806
1807 let expected_strings = filter_rust(as_deref(&strings[array_offset..]), bools);
1808 assert_eq!(actual, expected_strings);
1809
1810 let src = DictionaryArray::<Int32Type>::from_iter(as_deref(&strings));
1812
1813 let src = src.slice(array_offset, array_len);
1814 let src = src
1815 .as_any()
1816 .downcast_ref::<DictionaryArray<Int32Type>>()
1817 .unwrap();
1818
1819 let filtered = filter(src, predicate).unwrap();
1820
1821 let array = filtered
1822 .as_any()
1823 .downcast_ref::<DictionaryArray<Int32Type>>()
1824 .unwrap();
1825
1826 let values = array
1827 .values()
1828 .as_any()
1829 .downcast_ref::<StringArray>()
1830 .unwrap();
1831
1832 let actual: Vec<_> = array
1833 .keys()
1834 .iter()
1835 .map(|key| key.map(|key| values.value(key as usize)))
1836 .collect();
1837
1838 assert_eq!(actual, expected_strings);
1839 }
1840 }
1841
1842 #[test]
1843 fn test_filter_map() {
1844 let mut builder =
1845 MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(4));
1846 builder.keys().append_value("key1");
1848 builder.values().append_value(1);
1849 builder.append(true).unwrap();
1850 builder.keys().append_value("key2");
1851 builder.keys().append_value("key3");
1852 builder.values().append_value(2);
1853 builder.values().append_value(3);
1854 builder.append(true).unwrap();
1855 builder.append(false).unwrap();
1856 builder.keys().append_value("key1");
1857 builder.values().append_value(1);
1858 builder.append(true).unwrap();
1859 let maparray = Arc::new(builder.finish()) as ArrayRef;
1860
1861 let indices = vec![Some(true), Some(false), Some(false), Some(true)]
1862 .into_iter()
1863 .collect::<BooleanArray>();
1864 let got = filter(&maparray, &indices).unwrap();
1865
1866 let mut builder =
1867 MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(2));
1868 builder.keys().append_value("key1");
1869 builder.values().append_value(1);
1870 builder.append(true).unwrap();
1871 builder.keys().append_value("key1");
1872 builder.values().append_value(1);
1873 builder.append(true).unwrap();
1874 let expected = Arc::new(builder.finish()) as ArrayRef;
1875
1876 assert_eq!(&expected, &got);
1877 }
1878
1879 #[test]
1880 fn test_filter_fixed_size_list_arrays() {
1881 let value_data = ArrayData::builder(DataType::Int32)
1882 .len(9)
1883 .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8]))
1884 .build()
1885 .unwrap();
1886 let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 3, false);
1887 let list_data = ArrayData::builder(list_data_type)
1888 .len(3)
1889 .add_child_data(value_data)
1890 .build()
1891 .unwrap();
1892 let array = FixedSizeListArray::from(list_data);
1893
1894 let filter_array = BooleanArray::from(vec![true, false, false]);
1895
1896 let c = filter(&array, &filter_array).unwrap();
1897 let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1898
1899 assert_eq!(filtered.len(), 1);
1900
1901 let list = filtered.value(0);
1902 assert_eq!(
1903 &[0, 1, 2],
1904 list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1905 );
1906
1907 let filter_array = BooleanArray::from(vec![true, false, true]);
1908
1909 let c = filter(&array, &filter_array).unwrap();
1910 let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1911
1912 assert_eq!(filtered.len(), 2);
1913
1914 let list = filtered.value(0);
1915 assert_eq!(
1916 &[0, 1, 2],
1917 list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1918 );
1919 let list = filtered.value(1);
1920 assert_eq!(
1921 &[6, 7, 8],
1922 list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1923 );
1924 }
1925
1926 #[test]
1927 fn test_filter_fixed_size_list_arrays_with_null() {
1928 let value_data = ArrayData::builder(DataType::Int32)
1929 .len(10)
1930 .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
1931 .build()
1932 .unwrap();
1933
1934 let mut null_bits: [u8; 1] = [0; 1];
1938 bit_util::set_bit(&mut null_bits, 0);
1939 bit_util::set_bit(&mut null_bits, 3);
1940 bit_util::set_bit(&mut null_bits, 4);
1941
1942 let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 2, false);
1943 let list_data = ArrayData::builder(list_data_type)
1944 .len(5)
1945 .add_child_data(value_data)
1946 .null_bit_buffer(Some(Buffer::from(null_bits)))
1947 .build()
1948 .unwrap();
1949 let array = FixedSizeListArray::from(list_data);
1950
1951 let filter_array = BooleanArray::from(vec![true, true, false, true, false]);
1952
1953 let c = filter(&array, &filter_array).unwrap();
1954 let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1955
1956 assert_eq!(filtered.len(), 3);
1957
1958 let list = filtered.value(0);
1959 assert_eq!(
1960 &[0, 1],
1961 list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1962 );
1963 assert!(filtered.is_null(1));
1964 let list = filtered.value(2);
1965 assert_eq!(
1966 &[6, 7],
1967 list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1968 );
1969 }
1970
1971 fn test_filter_union_array(array: UnionArray) {
1972 let filter_array = BooleanArray::from(vec![true, false, false]);
1973 let c = filter(&array, &filter_array).unwrap();
1974 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1975
1976 let mut builder = UnionBuilder::new_dense();
1977 builder.append::<Int32Type>("A", 1).unwrap();
1978 let expected_array = builder.build().unwrap();
1979
1980 compare_union_arrays(filtered, &expected_array);
1981
1982 let filter_array = BooleanArray::from(vec![true, false, true]);
1983 let c = filter(&array, &filter_array).unwrap();
1984 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1985
1986 let mut builder = UnionBuilder::new_dense();
1987 builder.append::<Int32Type>("A", 1).unwrap();
1988 builder.append::<Int32Type>("A", 34).unwrap();
1989 let expected_array = builder.build().unwrap();
1990
1991 compare_union_arrays(filtered, &expected_array);
1992
1993 let filter_array = BooleanArray::from(vec![true, true, false]);
1994 let c = filter(&array, &filter_array).unwrap();
1995 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1996
1997 let mut builder = UnionBuilder::new_dense();
1998 builder.append::<Int32Type>("A", 1).unwrap();
1999 builder.append::<Float64Type>("B", 3.2).unwrap();
2000 let expected_array = builder.build().unwrap();
2001
2002 compare_union_arrays(filtered, &expected_array);
2003 }
2004
2005 #[test]
2006 fn test_filter_union_array_dense() {
2007 let mut builder = UnionBuilder::new_dense();
2008 builder.append::<Int32Type>("A", 1).unwrap();
2009 builder.append::<Float64Type>("B", 3.2).unwrap();
2010 builder.append::<Int32Type>("A", 34).unwrap();
2011 let array = builder.build().unwrap();
2012
2013 test_filter_union_array(array);
2014 }
2015
2016 #[test]
2017 fn test_filter_run_union_array_dense() {
2018 let mut builder = UnionBuilder::new_dense();
2019 builder.append::<Int32Type>("A", 1).unwrap();
2020 builder.append::<Int32Type>("A", 3).unwrap();
2021 builder.append::<Int32Type>("A", 34).unwrap();
2022 let array = builder.build().unwrap();
2023
2024 let filter_array = BooleanArray::from(vec![true, true, false]);
2025 let c = filter(&array, &filter_array).unwrap();
2026 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2027
2028 let mut builder = UnionBuilder::new_dense();
2029 builder.append::<Int32Type>("A", 1).unwrap();
2030 builder.append::<Int32Type>("A", 3).unwrap();
2031 let expected = builder.build().unwrap();
2032
2033 assert_eq!(filtered.to_data(), expected.to_data());
2034 }
2035
2036 #[test]
2037 fn test_filter_union_array_dense_with_nulls() {
2038 let mut builder = UnionBuilder::new_dense();
2039 builder.append::<Int32Type>("A", 1).unwrap();
2040 builder.append::<Float64Type>("B", 3.2).unwrap();
2041 builder.append_null::<Float64Type>("B").unwrap();
2042 builder.append::<Int32Type>("A", 34).unwrap();
2043 let array = builder.build().unwrap();
2044
2045 let filter_array = BooleanArray::from(vec![true, true, false, false]);
2046 let c = filter(&array, &filter_array).unwrap();
2047 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2048
2049 let mut builder = UnionBuilder::new_dense();
2050 builder.append::<Int32Type>("A", 1).unwrap();
2051 builder.append::<Float64Type>("B", 3.2).unwrap();
2052 let expected_array = builder.build().unwrap();
2053
2054 compare_union_arrays(filtered, &expected_array);
2055
2056 let filter_array = BooleanArray::from(vec![true, false, true, false]);
2057 let c = filter(&array, &filter_array).unwrap();
2058 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2059
2060 let mut builder = UnionBuilder::new_dense();
2061 builder.append::<Int32Type>("A", 1).unwrap();
2062 builder.append_null::<Float64Type>("B").unwrap();
2063 let expected_array = builder.build().unwrap();
2064
2065 compare_union_arrays(filtered, &expected_array);
2066 }
2067
2068 #[test]
2069 fn test_filter_union_array_sparse() {
2070 let mut builder = UnionBuilder::new_sparse();
2071 builder.append::<Int32Type>("A", 1).unwrap();
2072 builder.append::<Float64Type>("B", 3.2).unwrap();
2073 builder.append::<Int32Type>("A", 34).unwrap();
2074 let array = builder.build().unwrap();
2075
2076 test_filter_union_array(array);
2077 }
2078
2079 #[test]
2080 fn test_filter_union_array_sparse_with_nulls() {
2081 let mut builder = UnionBuilder::new_sparse();
2082 builder.append::<Int32Type>("A", 1).unwrap();
2083 builder.append::<Float64Type>("B", 3.2).unwrap();
2084 builder.append_null::<Float64Type>("B").unwrap();
2085 builder.append::<Int32Type>("A", 34).unwrap();
2086 let array = builder.build().unwrap();
2087
2088 let filter_array = BooleanArray::from(vec![true, false, true, false]);
2089 let c = filter(&array, &filter_array).unwrap();
2090 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2091
2092 let mut builder = UnionBuilder::new_sparse();
2093 builder.append::<Int32Type>("A", 1).unwrap();
2094 builder.append_null::<Float64Type>("B").unwrap();
2095 let expected_array = builder.build().unwrap();
2096
2097 compare_union_arrays(filtered, &expected_array);
2098 }
2099
2100 fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) {
2101 assert_eq!(union1.len(), union2.len());
2102
2103 for i in 0..union1.len() {
2104 let type_id = union1.type_id(i);
2105
2106 let slot1 = union1.value(i);
2107 let slot2 = union2.value(i);
2108
2109 assert_eq!(slot1.is_null(0), slot2.is_null(0));
2110
2111 if !slot1.is_null(0) && !slot2.is_null(0) {
2112 match type_id {
2113 0 => {
2114 let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
2115 assert_eq!(slot1.len(), 1);
2116 let value1 = slot1.value(0);
2117
2118 let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap();
2119 assert_eq!(slot2.len(), 1);
2120 let value2 = slot2.value(0);
2121 assert_eq!(value1, value2);
2122 }
2123 1 => {
2124 let slot1 = slot1.as_any().downcast_ref::<Float64Array>().unwrap();
2125 assert_eq!(slot1.len(), 1);
2126 let value1 = slot1.value(0);
2127
2128 let slot2 = slot2.as_any().downcast_ref::<Float64Array>().unwrap();
2129 assert_eq!(slot2.len(), 1);
2130 let value2 = slot2.value(0);
2131 assert_eq!(value1, value2);
2132 }
2133 _ => unreachable!(),
2134 }
2135 }
2136 }
2137 }
2138
2139 #[test]
2140 fn test_filter_struct() {
2141 let predicate = BooleanArray::from(vec![true, false, true, false]);
2142
2143 let a = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
2144 let a_filtered = Arc::new(StringArray::from(vec!["hello", "world"]));
2145
2146 let b = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
2147 let b_filtered = Arc::new(Int32Array::from(vec![5, 7]));
2148
2149 let null_mask = NullBuffer::from(vec![true, false, false, true]);
2150 let null_mask_filtered = NullBuffer::from(vec![true, false]);
2151
2152 let a_field = Field::new("a", DataType::Utf8, false);
2153 let b_field = Field::new("b", DataType::Int32, false);
2154
2155 let array = StructArray::new(vec![a_field.clone()].into(), vec![a.clone()], None);
2156 let expected =
2157 StructArray::new(vec![a_field.clone()].into(), vec![a_filtered.clone()], None);
2158
2159 let result = filter(&array, &predicate).unwrap();
2160
2161 assert_eq!(result.to_data(), expected.to_data());
2162
2163 let array = StructArray::new(
2164 vec![a_field.clone()].into(),
2165 vec![a.clone()],
2166 Some(null_mask.clone()),
2167 );
2168 let expected = StructArray::new(
2169 vec![a_field.clone()].into(),
2170 vec![a_filtered.clone()],
2171 Some(null_mask_filtered.clone()),
2172 );
2173
2174 let result = filter(&array, &predicate).unwrap();
2175
2176 assert_eq!(result.to_data(), expected.to_data());
2177
2178 let array = StructArray::new(
2179 vec![a_field.clone(), b_field.clone()].into(),
2180 vec![a.clone(), b.clone()],
2181 None,
2182 );
2183 let expected = StructArray::new(
2184 vec![a_field.clone(), b_field.clone()].into(),
2185 vec![a_filtered.clone(), b_filtered.clone()],
2186 None,
2187 );
2188
2189 let result = filter(&array, &predicate).unwrap();
2190
2191 assert_eq!(result.to_data(), expected.to_data());
2192
2193 let array = StructArray::new(
2194 vec![a_field.clone(), b_field.clone()].into(),
2195 vec![a.clone(), b.clone()],
2196 Some(null_mask.clone()),
2197 );
2198
2199 let expected = StructArray::new(
2200 vec![a_field.clone(), b_field.clone()].into(),
2201 vec![a_filtered.clone(), b_filtered.clone()],
2202 Some(null_mask_filtered.clone()),
2203 );
2204
2205 let result = filter(&array, &predicate).unwrap();
2206
2207 assert_eq!(result.to_data(), expected.to_data());
2208 }
2209
2210 #[test]
2211 fn test_filter_empty_struct() {
2212 let fields = arrow_schema::Field::new(
2219 "a",
2220 arrow_schema::DataType::Struct(arrow_schema::Fields::from(vec![
2221 arrow_schema::Field::new("b", arrow_schema::DataType::Int64, true),
2222 arrow_schema::Field::new(
2223 "c",
2224 arrow_schema::DataType::Struct(arrow_schema::Fields::empty()),
2225 true,
2226 ),
2227 ])),
2228 true,
2229 );
2230
2231 let schema = Arc::new(Schema::new(vec![fields]));
2239
2240 let b = Arc::new(Int64Array::from(vec![None, None, None]));
2241 let c = Arc::new(StructArray::new_empty_fields(
2242 3,
2243 Some(NullBuffer::from(vec![true, true, true])),
2244 ));
2245 let a = StructArray::new(
2246 vec![
2247 Field::new("b", DataType::Int64, true),
2248 Field::new("c", DataType::Struct(Fields::empty()), true),
2249 ]
2250 .into(),
2251 vec![b.clone(), c.clone()],
2252 Some(NullBuffer::from(vec![true, true, true])),
2253 );
2254 let record_batch = RecordBatch::try_new(schema, vec![Arc::new(a)]).unwrap();
2255 println!("{record_batch:?}");
2256
2257 let predicate = BooleanArray::from(vec![true, false, true]);
2259 let filtered_batch = filter_record_batch(&record_batch, &predicate).unwrap();
2260
2261 assert_eq!(filtered_batch.num_rows(), 2);
2263 }
2264
2265 #[test]
2266 #[should_panic]
2267 fn test_filter_bits_too_large() {
2268 let buffer = BooleanBuffer::from(vec![false; 8]);
2269 let predicate = BooleanArray::from(vec![true; 9]);
2270 let filter = FilterBuilder::new(&predicate).build();
2271 filter_bits(&buffer, &filter);
2272 }
2273
2274 #[test]
2275 #[should_panic]
2276 fn test_filter_native_too_large() {
2277 let values = vec![1; 8];
2278 let predicate = BooleanArray::from(vec![false; 9]);
2279 let filter = FilterBuilder::new(&predicate).build();
2280 filter_native(&values, &filter);
2281 }
2282}