1use std::ops::AddAssign;
21use std::sync::Arc;
22
23use arrow_array::builder::BooleanBufferBuilder;
24use arrow_array::cast::AsArray;
25use arrow_array::types::{
26 ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
27};
28use arrow_array::*;
29use arrow_buffer::{ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer, bit_util};
30use arrow_buffer::{Buffer, MutableBuffer};
31use arrow_data::ArrayDataBuilder;
32use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
33use arrow_data::transform::MutableArrayData;
34use arrow_schema::*;
35
/// Selectivity cut-off used when picking the default [`IterationStrategy`].
///
/// When the fraction of selected rows (set bits divided by filter length) is
/// strictly greater than this value, the filter is traversed as ranges
/// ([`IterationStrategy::SlicesIterator`]); otherwise it is traversed one
/// selected index at a time ([`IterationStrategy::IndexIterator`]).
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
43
/// An iterator of `(usize, usize)` pairs derived from the set bits of a
/// boolean filter.
///
/// This is a thin wrapper around the [`BitSliceIterator`] returned by
/// `BooleanBuffer::set_slices`. Within this file each pair is consumed as a
/// half-open `start..end` range of selected positions.
#[derive(Debug)]
pub struct SlicesIterator<'a>(BitSliceIterator<'a>);
57
58impl<'a> SlicesIterator<'a> {
59 pub fn new(filter: &'a BooleanArray) -> Self {
61 filter.values().into()
62 }
63}
64
65impl<'a> From<&'a BooleanBuffer> for SlicesIterator<'a> {
66 fn from(filter: &'a BooleanBuffer) -> Self {
67 Self(filter.set_slices())
68 }
69}
70
71impl Iterator for SlicesIterator<'_> {
72 type Item = (usize, usize);
73
74 fn next(&mut self) -> Option<Self::Item> {
75 self.0.next()
76 }
77}
78
/// An iterator of the `usize` positions of the set bits in a boolean filter.
///
/// It additionally tracks how many items are still to be produced, so that
/// `size_hint` can report an exact length.
struct IndexIterator<'a> {
    /// Number of indices that have not been yielded yet.
    remaining: usize,
    /// Underlying iterator over the set-bit positions of the filter.
    iter: BitIndexIterator<'a>,
}
87
88impl<'a> IndexIterator<'a> {
89 fn new(filter: &'a BooleanArray, remaining: usize) -> Self {
90 assert_eq!(filter.null_count(), 0);
91 let iter = filter.values().set_indices();
92 Self { remaining, iter }
93 }
94}
95
96impl Iterator for IndexIterator<'_> {
97 type Item = usize;
98
99 fn next(&mut self) -> Option<Self::Item> {
100 if self.remaining != 0 {
101 let next = self.iter.next().expect("IndexIterator exhausted early");
104 self.remaining -= 1;
105 return Some(next);
107 }
108 None
109 }
110
111 fn size_hint(&self) -> (usize, Option<usize>) {
112 (self.remaining, Some(self.remaining))
113 }
114}
115
116fn filter_count(filter: &BooleanArray) -> usize {
118 filter.values().count_set_bits()
119}
120
121pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
149 let nulls = filter.nulls().unwrap();
150 let mask = filter.values() & nulls.inner();
151 BooleanArray::new(mask, None)
152}
153
154pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
183 let mut filter_builder = FilterBuilder::new(predicate);
184
185 if FilterBuilder::is_optimize_beneficial(values.data_type()) {
186 filter_builder = filter_builder.optimize();
189 }
190
191 let predicate = filter_builder.build();
192
193 filter_array(values, &predicate)
194}
195
196pub fn filter_record_batch(
207 record_batch: &RecordBatch,
208 predicate: &BooleanArray,
209) -> Result<RecordBatch, ArrowError> {
210 let mut filter_builder = FilterBuilder::new(predicate);
211 let num_cols = record_batch.num_columns();
212 if num_cols > 1
213 || (num_cols > 0
214 && FilterBuilder::is_optimize_beneficial(
215 record_batch.schema_ref().field(0).data_type(),
216 ))
217 {
218 filter_builder = filter_builder.optimize();
221 }
222 let filter = filter_builder.build();
223
224 filter.filter_record_batch(record_batch)
225}
226
/// A builder used to construct a [`FilterPredicate`].
#[derive(Debug)]
pub struct FilterBuilder {
    /// The selection mask; `FilterBuilder::new` stores a mask without nulls here.
    filter: BooleanArray,
    /// Number of set bits in `filter`.
    count: usize,
    /// How the set bits of `filter` will be traversed when filtering.
    strategy: IterationStrategy,
}
234
235impl FilterBuilder {
236 pub fn new(filter: &BooleanArray) -> Self {
238 let filter = match filter.null_count() {
239 0 => filter.clone(),
240 _ => prep_null_mask_filter(filter),
241 };
242
243 let count = filter_count(&filter);
244 let strategy = IterationStrategy::default_strategy(filter.len(), count);
245
246 Self {
247 filter,
248 count,
249 strategy,
250 }
251 }
252
253 pub fn optimize(mut self) -> Self {
264 match self.strategy {
265 IterationStrategy::SlicesIterator => {
266 let slices = SlicesIterator::new(&self.filter).collect();
267 self.strategy = IterationStrategy::Slices(slices)
268 }
269 IterationStrategy::IndexIterator => {
270 let indices = IndexIterator::new(&self.filter, self.count).collect();
271 self.strategy = IterationStrategy::Indices(indices)
272 }
273 _ => {}
274 }
275 self
276 }
277
278 pub fn is_optimize_beneficial(data_type: &DataType) -> bool {
283 match data_type {
284 DataType::Struct(fields) => {
285 fields.len() > 1
286 || fields.len() == 1
287 && FilterBuilder::is_optimize_beneficial(fields[0].data_type())
288 }
289 DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
290 _ => false,
291 }
292 }
293
294 pub fn build(self) -> FilterPredicate {
296 FilterPredicate {
297 filter: self.filter,
298 count: self.count,
299 strategy: self.strategy,
300 }
301 }
302}
303
/// The way the selected positions of a filter are traversed.
#[derive(Debug)]
enum IterationStrategy {
    /// Traverse selected ranges lazily with a [`SlicesIterator`].
    SlicesIterator,
    /// Traverse selected positions lazily with an [`IndexIterator`].
    IndexIterator,
    /// Selected positions collected up front (see `FilterBuilder::optimize`).
    Indices(Vec<usize>),
    /// Selected `(start, end)` ranges collected up front (see `FilterBuilder::optimize`).
    Slices(Vec<(usize, usize)>),
    /// Chosen when the number of selected rows equals the filter length.
    All,
    /// Chosen when the filter is empty or selects no rows.
    None,
}
320
321impl IterationStrategy {
322 fn default_strategy(filter_length: usize, filter_count: usize) -> Self {
325 if filter_length == 0 || filter_count == 0 {
326 return IterationStrategy::None;
327 }
328
329 if filter_count == filter_length {
330 return IterationStrategy::All;
331 }
332
333 let selectivity_frac = filter_count as f64 / filter_length as f64;
338 if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD {
339 return IterationStrategy::SlicesIterator;
340 }
341 IterationStrategy::IndexIterator
342 }
343}
344
/// A filter mask together with its pre-computed selection count and
/// iteration strategy, as produced by `FilterBuilder::build`.
#[derive(Debug)]
pub struct FilterPredicate {
    /// The selection mask.
    filter: BooleanArray,
    /// Number of selected rows; used as the length of every filtered output.
    count: usize,
    /// How the selected positions of `filter` are traversed.
    strategy: IterationStrategy,
}
352
353impl FilterPredicate {
354 pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef, ArrowError> {
356 filter_array(values, self)
357 }
358
359 pub fn filter_record_batch(
364 &self,
365 record_batch: &RecordBatch,
366 ) -> Result<RecordBatch, ArrowError> {
367 let filtered_arrays = record_batch
368 .columns()
369 .iter()
370 .map(|a| filter_array(a, self))
371 .collect::<Result<Vec<_>, _>>()?;
372
373 unsafe {
376 Ok(RecordBatch::new_unchecked(
377 record_batch.schema(),
378 filtered_arrays,
379 self.count,
380 ))
381 }
382 }
383
384 pub fn count(&self) -> usize {
386 self.count
387 }
388}
389
/// Filters `values` with an already-built `predicate`, dispatching on the
/// data type of `values`.
///
/// # Errors
///
/// Returns `ArrowError::InvalidArgumentError` when the predicate is longer
/// than `values` (a shorter predicate is accepted), and propagates errors from
/// the type-specific helpers that are fallible.
///
/// # Panics
///
/// Hits `unimplemented!` for run-end encoded or dictionary arrays that the
/// respective downcast macro does not handle.
fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> {
    if predicate.filter.len() > values.len() {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Filter predicate of length {} is larger than target array of length {}",
            predicate.filter.len(),
            values.len()
        )));
    }

    match predicate.strategy {
        // Nothing selected: an empty array of the same type.
        IterationStrategy::None => Ok(new_empty_array(values.data_type())),
        // Everything selected: slice off the first `count` rows.
        IterationStrategy::All => Ok(values.slice(0, predicate.count)),
        // Otherwise dispatch to a type-specific kernel.
        _ => downcast_primitive_array! {
            values => Ok(Arc::new(filter_primitive(values, predicate))),
            DataType::Boolean => {
                let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
                Ok(Arc::new(filter_boolean(values, predicate)))
            }
            DataType::Utf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i32>(), predicate)))
            }
            DataType::LargeUtf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
            }
            DataType::Utf8View => {
                Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
            }
            DataType::Binary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
            }
            DataType::LargeBinary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
            }
            DataType::BinaryView => {
                Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
            }
            DataType::FixedSizeBinary(_) => {
                Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate)))
            }
            DataType::ListView(_) => {
                Ok(Arc::new(filter_list_view::<i32>(values.as_list_view(), predicate)))
            }
            DataType::LargeListView(_) => {
                Ok(Arc::new(filter_list_view::<i64>(values.as_list_view(), predicate)))
            }
            DataType::RunEndEncoded(_, _) => {
                downcast_run_array!{
                    values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
                    t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t)
                }
            }
            DataType::Dictionary(_, _) => downcast_dictionary_array! {
                values => Ok(Arc::new(filter_dict(values, predicate))),
                t => unimplemented!("Filter not supported for dictionary type {:?}", t)
            }
            DataType::Struct(_) => {
                Ok(Arc::new(filter_struct(values.as_struct(), predicate)?))
            }
            DataType::Union(_, UnionMode::Sparse) => {
                Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?))
            }
            _ => {
                // Generic fallback for every other type: copy the selected
                // ranges through `MutableArrayData`.
                let data = values.to_data();
                let mut mutable = MutableArrayData::new(
                    vec![&data],
                    false,
                    predicate.count,
                );

                match &predicate.strategy {
                    // Re-use the pre-computed ranges when available.
                    IterationStrategy::Slices(slices) => {
                        slices
                            .iter()
                            .for_each(|(start, end)| mutable.extend(0, *start, *end));
                    }
                    // Any other strategy: derive the ranges from the mask.
                    _ => {
                        let iter = SlicesIterator::new(&predicate.filter);
                        iter.for_each(|(start, end)| mutable.extend(0, start, end));
                    }
                }

                let data = mutable.freeze();
                Ok(make_array(data))
            }
        },
    }
}
479
/// Filters a run-end encoded array.
///
/// For each entry of the run-ends buffer this counts how many selected rows
/// of the predicate fall inside that run. Runs with at least one selected row
/// are kept, their new run end being the running count of selected rows; the
/// values array is then filtered by the resulting per-run mask.
///
/// NOTE(review): the run ends are read through `run_ends.inner()` and the
/// values through `array.values()` with no visible handling of a slice
/// offset — confirm the behavior for sliced `RunArray`s.
///
/// # Errors
///
/// Propagates errors from filtering the values and from constructing the
/// resulting run-ends array and `RunArray`.
fn filter_run_end_array<R: RunEndIndexType>(
    array: &RunArray<R>,
    predicate: &FilterPredicate,
) -> Result<RunArray<R>, ArrowError>
where
    R::Native: Into<i64> + From<bool>,
    R::Native: AddAssign,
{
    let run_ends: &RunEndBuffer<R::Native> = array.run_ends();
    // Scratch space for the new run ends; truncated to the kept runs below.
    let mut new_run_ends = vec![R::default_value(); run_ends.len()];

    // `start` is the first row of the current run, `j` the number of runs kept
    // so far and `count` the number of selected rows seen so far.
    let mut start = 0u64;
    let mut j = 0;
    let mut count = R::default_value();
    let filter_values = predicate.filter.values();
    let run_ends = run_ends.inner();

    // One bit per run: `true` when the run keeps at least one row.
    let pred: BooleanArray = BooleanBuffer::collect_bool(run_ends.len(), |i| {
        let mut keep = false;
        let mut end = run_ends[i].into() as u64;
        // Clamp `end` to the length of the predicate.
        let difference = end.saturating_sub(filter_values.len() as u64);
        end -= difference;

        // SAFETY: `end` was clamped to `filter_values.len()` above, so every
        // index in `start..end` is within the predicate.
        for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
            count += R::Native::from(pred);
            keep |= pred
        }
        // Written unconditionally; the slot is overwritten by the next run
        // unless `j` advances because this run is kept.
        new_run_ends[j] = count;
        j += keep as usize;

        start = end;
        keep
    })
    .into();

    new_run_ends.truncate(j);

    let values = array.values();
    let values = filter(&values, &pred)?;

    let run_ends = PrimitiveArray::<R>::try_new(new_run_ends.into(), None)?;
    RunArray::try_new(&run_ends, &values)
}
526
527fn filter_null_mask(
534 nulls: Option<&NullBuffer>,
535 predicate: &FilterPredicate,
536) -> Option<(usize, Buffer)> {
537 let nulls = nulls?;
538 if nulls.null_count() == 0 {
539 return None;
540 }
541
542 let nulls = filter_bits(nulls.inner(), predicate);
543 let null_count = predicate.count - nulls.count_set_bits_offset(0, predicate.count);
546
547 if null_count == 0 {
548 return None;
549 }
550
551 Some((null_count, nulls))
552}
553
554fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer {
556 let src = buffer.values();
557 let offset = buffer.offset();
558 assert!(buffer.len() >= predicate.filter.len());
559
560 match &predicate.strategy {
561 IterationStrategy::IndexIterator => {
562 let bits =
563 IndexIterator::new(&predicate.filter, predicate.count).map(|src_idx| unsafe {
565 bit_util::get_bit_raw(buffer.values().as_ptr(), src_idx + offset)
566 });
567
568 unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
570 }
571 IterationStrategy::Indices(indices) => {
572 let bits = indices.iter().map(|src_idx| unsafe {
574 bit_util::get_bit_raw(buffer.values().as_ptr(), *src_idx + offset)
575 });
576 unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
578 }
579 IterationStrategy::SlicesIterator => {
580 let mut builder = BooleanBufferBuilder::new(predicate.count);
581 for (start, end) in SlicesIterator::new(&predicate.filter) {
582 builder.append_packed_range(start + offset..end + offset, src)
583 }
584 builder.into()
585 }
586 IterationStrategy::Slices(slices) => {
587 let mut builder = BooleanBufferBuilder::new(predicate.count);
588 for (start, end) in slices {
589 builder.append_packed_range(*start + offset..*end + offset, src)
590 }
591 builder.into()
592 }
593 IterationStrategy::All | IterationStrategy::None => unreachable!(),
594 }
595}
596
597fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
599 let values = filter_bits(array.values(), predicate);
600
601 let mut builder = ArrayDataBuilder::new(DataType::Boolean)
602 .len(predicate.count)
603 .add_buffer(values);
604
605 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
606 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
607 }
608
609 let data = unsafe { builder.build_unchecked() };
610 BooleanArray::from(data)
611}
612
613#[inline(never)]
614fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer {
615 assert!(values.len() >= predicate.filter.len());
616
617 match &predicate.strategy {
618 IterationStrategy::SlicesIterator => {
619 let mut buffer = Vec::with_capacity(predicate.count);
620 for (start, end) in SlicesIterator::new(&predicate.filter) {
621 buffer.extend_from_slice(unsafe { values.get_unchecked(start..end) });
623 }
624 buffer.into()
625 }
626 IterationStrategy::Slices(slices) => {
627 let mut buffer = Vec::with_capacity(predicate.count);
628 for (start, end) in slices {
629 buffer.extend_from_slice(unsafe { values.get_unchecked(*start..*end) });
631 }
632 buffer.into()
633 }
634 IterationStrategy::IndexIterator => {
635 let iter = IndexIterator::new(&predicate.filter, predicate.count)
637 .map(|x| unsafe { *values.get_unchecked(x) });
638
639 unsafe { MutableBuffer::from_trusted_len_iter(iter) }.into()
641 }
642 IterationStrategy::Indices(indices) => {
643 let iter = indices.iter().map(|x| unsafe { *values.get_unchecked(*x) });
645 iter.collect::<Vec<_>>().into()
646 }
647 IterationStrategy::All | IterationStrategy::None => unreachable!(),
648 }
649}
650
651fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
653where
654 T: ArrowPrimitiveType,
655{
656 let values = array.values();
657 let buffer = filter_native(values, predicate);
658 let mut builder = ArrayDataBuilder::new(array.data_type().clone())
659 .len(predicate.count)
660 .add_buffer(buffer);
661
662 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
663 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
664 }
665
666 let data = unsafe { builder.build_unchecked() };
667 PrimitiveArray::from(data)
668}
669
/// Working state for filtering a variable-length byte array
/// (`GenericByteArray`): borrows the source offsets and bytes and accumulates
/// the filtered offsets and bytes.
struct FilterBytes<'a, OffsetSize> {
    /// Offsets of the source array (`array.value_offsets()`).
    src_offsets: &'a [OffsetSize],
    /// Value bytes of the source array (`array.value_data()`).
    src_values: &'a [u8],
    /// Offsets of the filtered output; starts with a single zero offset.
    dst_offsets: Vec<OffsetSize>,
    /// Value bytes of the filtered output.
    dst_values: Vec<u8>,
    /// Running total of the value lengths appended to `dst_offsets` so far.
    cur_offset: OffsetSize,
}
681
682impl<'a, OffsetSize> FilterBytes<'a, OffsetSize>
683where
684 OffsetSize: OffsetSizeTrait,
685{
686 fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self
687 where
688 T: ByteArrayType<Offset = OffsetSize>,
689 {
690 let dst_values = Vec::new();
691 let mut dst_offsets: Vec<OffsetSize> = Vec::with_capacity(capacity + 1);
692 let cur_offset = OffsetSize::from_usize(0).unwrap();
693
694 dst_offsets.push(cur_offset);
695
696 Self {
697 src_offsets: array.value_offsets(),
698 src_values: array.value_data(),
699 dst_offsets,
700 dst_values,
701 cur_offset,
702 }
703 }
704
705 #[inline]
707 fn get_value_offset(&self, idx: usize) -> usize {
708 self.src_offsets[idx].as_usize()
709 }
710
711 #[inline]
713 fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
714 let start = self.get_value_offset(idx);
716 let end = self.get_value_offset(idx + 1);
717 let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
718 (start, end, len)
719 }
720
721 fn extend_offsets_idx(&mut self, iter: impl Iterator<Item = usize>) {
722 self.dst_offsets.extend(iter.map(|idx| {
723 let start = self.src_offsets[idx].as_usize();
724 let end = self.src_offsets[idx + 1].as_usize();
725 let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
726 self.cur_offset += len;
727
728 self.cur_offset
729 }));
730 }
731
732 fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
734 self.dst_values.reserve_exact(self.cur_offset.as_usize());
735
736 for idx in iter {
737 let start = self.src_offsets[idx].as_usize();
738 let end = self.src_offsets[idx + 1].as_usize();
739 self.dst_values
740 .extend_from_slice(&self.src_values[start..end]);
741 }
742 }
743
744 fn extend_offsets_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>, count: usize) {
745 self.dst_offsets.reserve_exact(count);
746 for (start, end) in iter {
747 for idx in start..end {
749 let (_, _, len) = self.get_value_range(idx);
750 self.cur_offset += len;
751 self.dst_offsets.push(self.cur_offset);
752 }
753 }
754 }
755
756 fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
758 self.dst_values.reserve_exact(self.cur_offset.as_usize());
759
760 for (start, end) in iter {
761 let value_start = self.get_value_offset(start);
762 let value_end = self.get_value_offset(end);
763 self.dst_values
764 .extend_from_slice(&self.src_values[value_start..value_end]);
765 }
766 }
767}
768
769fn filter_bytes<T>(array: &GenericByteArray<T>, predicate: &FilterPredicate) -> GenericByteArray<T>
774where
775 T: ByteArrayType,
776{
777 let mut filter = FilterBytes::new(predicate.count, array);
778
779 match &predicate.strategy {
780 IterationStrategy::SlicesIterator => {
781 filter.extend_offsets_slices(SlicesIterator::new(&predicate.filter), predicate.count);
782 filter.extend_slices(SlicesIterator::new(&predicate.filter))
783 }
784 IterationStrategy::Slices(slices) => {
785 filter.extend_offsets_slices(slices.iter().cloned(), predicate.count);
786 filter.extend_slices(slices.iter().cloned())
787 }
788 IterationStrategy::IndexIterator => {
789 filter.extend_offsets_idx(IndexIterator::new(&predicate.filter, predicate.count));
790 filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count))
791 }
792 IterationStrategy::Indices(indices) => {
793 filter.extend_offsets_idx(indices.iter().cloned());
794 filter.extend_idx(indices.iter().cloned())
795 }
796 IterationStrategy::All | IterationStrategy::None => unreachable!(),
797 }
798
799 let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
800 .len(predicate.count)
801 .add_buffer(filter.dst_offsets.into())
802 .add_buffer(filter.dst_values.into());
803
804 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
805 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
806 }
807
808 let data = unsafe { builder.build_unchecked() };
809 GenericByteArray::from(data)
810}
811
812fn filter_byte_view<T: ByteViewType>(
814 array: &GenericByteViewArray<T>,
815 predicate: &FilterPredicate,
816) -> GenericByteViewArray<T> {
817 let new_view_buffer = filter_native(array.views(), predicate);
818
819 let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
820 .len(predicate.count)
821 .add_buffer(new_view_buffer)
822 .add_buffers(array.data_buffers().to_vec());
823
824 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
825 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
826 }
827
828 GenericByteViewArray::from(unsafe { builder.build_unchecked() })
829}
830
831fn filter_fixed_size_binary(
832 array: &FixedSizeBinaryArray,
833 predicate: &FilterPredicate,
834) -> FixedSizeBinaryArray {
835 let values: &[u8] = array.values();
836 let value_length = array.value_length() as usize;
837 let calculate_offset_from_index = |index: usize| index * value_length;
838 let buffer = match &predicate.strategy {
839 IterationStrategy::SlicesIterator => {
840 let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
841 for (start, end) in SlicesIterator::new(&predicate.filter) {
842 buffer.extend_from_slice(
843 &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)],
844 );
845 }
846 buffer
847 }
848 IterationStrategy::Slices(slices) => {
849 let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
850 for (start, end) in slices {
851 buffer.extend_from_slice(
852 &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)],
853 );
854 }
855 buffer
856 }
857 IterationStrategy::IndexIterator => {
858 let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| {
859 &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)]
860 });
861
862 let mut buffer = MutableBuffer::new(predicate.count * value_length);
863 iter.for_each(|item| buffer.extend_from_slice(item));
864 buffer
865 }
866 IterationStrategy::Indices(indices) => {
867 let iter = indices.iter().map(|x| {
868 &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)]
869 });
870
871 let mut buffer = MutableBuffer::new(predicate.count * value_length);
872 iter.for_each(|item| buffer.extend_from_slice(item));
873 buffer
874 }
875 IterationStrategy::All | IterationStrategy::None => unreachable!(),
876 };
877 let mut builder = ArrayDataBuilder::new(array.data_type().clone())
878 .len(predicate.count)
879 .add_buffer(buffer.into());
880
881 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
882 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
883 }
884
885 let data = unsafe { builder.build_unchecked() };
886 FixedSizeBinaryArray::from(data)
887}
888
889fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
891where
892 T: ArrowDictionaryKeyType,
893 T::Native: num_traits::Num,
894{
895 let builder = filter_primitive::<T>(array.keys(), predicate)
896 .into_data()
897 .into_builder()
898 .data_type(array.data_type().clone())
899 .child_data(vec![array.values().to_data()]);
900
901 DictionaryArray::from(unsafe { builder.build_unchecked() })
904}
905
906fn filter_struct(
908 array: &StructArray,
909 predicate: &FilterPredicate,
910) -> Result<StructArray, ArrowError> {
911 let columns = array
912 .columns()
913 .iter()
914 .map(|column| filter_array(column, predicate))
915 .collect::<Result<_, _>>()?;
916
917 let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
918 let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
919
920 Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
921 } else {
922 None
923 };
924
925 Ok(unsafe {
926 StructArray::new_unchecked_with_length(
927 array.fields().clone(),
928 columns,
929 nulls,
930 predicate.count(),
931 )
932 })
933}
934
935fn filter_sparse_union(
937 array: &UnionArray,
938 predicate: &FilterPredicate,
939) -> Result<UnionArray, ArrowError> {
940 let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
941 unreachable!()
942 };
943
944 let type_ids = filter_primitive(
945 &Int8Array::try_new(array.type_ids().clone(), None)?,
946 predicate,
947 );
948
949 let children = fields
950 .iter()
951 .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
952 .collect::<Result<_, _>>()?;
953
954 Ok(unsafe {
955 UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
956 })
957}
958
959fn filter_list_view<OffsetType: OffsetSizeTrait>(
961 array: &GenericListViewArray<OffsetType>,
962 predicate: &FilterPredicate,
963) -> GenericListViewArray<OffsetType> {
964 let filtered_offsets = filter_native::<OffsetType>(array.offsets(), predicate);
965 let filtered_sizes = filter_native::<OffsetType>(array.sizes(), predicate);
966
967 let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
969 let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
970
971 Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
972 } else {
973 None
974 };
975
976 let list_data = ArrayDataBuilder::new(array.data_type().clone())
977 .nulls(nulls)
978 .buffers(vec![filtered_offsets, filtered_sizes])
979 .child_data(vec![array.values().to_data()])
980 .len(predicate.count);
981
982 let list_data = unsafe { list_data.build_unchecked() };
983
984 GenericListViewArray::from(list_data)
985}
986
987#[cfg(test)]
988mod tests {
989 use super::*;
990 use arrow_array::builder::*;
991 use arrow_array::cast::as_run_array;
992 use arrow_array::types::*;
993 use arrow_data::ArrayData;
994 use rand::distr::uniform::{UniformSampler, UniformUsize};
995 use rand::distr::{Alphanumeric, StandardUniform};
996 use rand::prelude::*;
997 use rand::rng;
998
999 macro_rules! def_temporal_test {
1000 ($test:ident, $array_type: ident, $data: expr) => {
1001 #[test]
1002 fn $test() {
1003 let a = $data;
1004 let b = BooleanArray::from(vec![true, false, true, false]);
1005 let c = filter(&a, &b).unwrap();
1006 let d = c.as_ref().as_any().downcast_ref::<$array_type>().unwrap();
1007 assert_eq!(2, d.len());
1008 assert_eq!(1, d.value(0));
1009 assert_eq!(3, d.value(1));
1010 }
1011 };
1012 }
1013
1014 def_temporal_test!(
1015 test_filter_date32,
1016 Date32Array,
1017 Date32Array::from(vec![1, 2, 3, 4])
1018 );
1019 def_temporal_test!(
1020 test_filter_date64,
1021 Date64Array,
1022 Date64Array::from(vec![1, 2, 3, 4])
1023 );
1024 def_temporal_test!(
1025 test_filter_time32_second,
1026 Time32SecondArray,
1027 Time32SecondArray::from(vec![1, 2, 3, 4])
1028 );
1029 def_temporal_test!(
1030 test_filter_time32_millisecond,
1031 Time32MillisecondArray,
1032 Time32MillisecondArray::from(vec![1, 2, 3, 4])
1033 );
1034 def_temporal_test!(
1035 test_filter_time64_microsecond,
1036 Time64MicrosecondArray,
1037 Time64MicrosecondArray::from(vec![1, 2, 3, 4])
1038 );
1039 def_temporal_test!(
1040 test_filter_time64_nanosecond,
1041 Time64NanosecondArray,
1042 Time64NanosecondArray::from(vec![1, 2, 3, 4])
1043 );
1044 def_temporal_test!(
1045 test_filter_duration_second,
1046 DurationSecondArray,
1047 DurationSecondArray::from(vec![1, 2, 3, 4])
1048 );
1049 def_temporal_test!(
1050 test_filter_duration_millisecond,
1051 DurationMillisecondArray,
1052 DurationMillisecondArray::from(vec![1, 2, 3, 4])
1053 );
1054 def_temporal_test!(
1055 test_filter_duration_microsecond,
1056 DurationMicrosecondArray,
1057 DurationMicrosecondArray::from(vec![1, 2, 3, 4])
1058 );
1059 def_temporal_test!(
1060 test_filter_duration_nanosecond,
1061 DurationNanosecondArray,
1062 DurationNanosecondArray::from(vec![1, 2, 3, 4])
1063 );
1064 def_temporal_test!(
1065 test_filter_timestamp_second,
1066 TimestampSecondArray,
1067 TimestampSecondArray::from(vec![1, 2, 3, 4])
1068 );
1069 def_temporal_test!(
1070 test_filter_timestamp_millisecond,
1071 TimestampMillisecondArray,
1072 TimestampMillisecondArray::from(vec![1, 2, 3, 4])
1073 );
1074 def_temporal_test!(
1075 test_filter_timestamp_microsecond,
1076 TimestampMicrosecondArray,
1077 TimestampMicrosecondArray::from(vec![1, 2, 3, 4])
1078 );
1079 def_temporal_test!(
1080 test_filter_timestamp_nanosecond,
1081 TimestampNanosecondArray,
1082 TimestampNanosecondArray::from(vec![1, 2, 3, 4])
1083 );
1084
1085 #[test]
1086 fn test_filter_array_slice() {
1087 let a = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
1088 let b = BooleanArray::from(vec![true, false, false, true]);
1089 let c = filter(&a, &b).unwrap();
1093 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1094 assert_eq!(2, d.len());
1095 assert_eq!(6, d.value(0));
1096 assert_eq!(9, d.value(1));
1097 }
1098
1099 #[test]
1100 fn test_filter_array_low_density() {
1101 let mut data_values = (1..=65).collect::<Vec<i32>>();
1103 let mut filter_values = (1..=65).map(|i| matches!(i % 65, 0)).collect::<Vec<bool>>();
1104 data_values.extend_from_slice(&[66, 67]);
1106 filter_values.extend_from_slice(&[false, true]);
1107 let a = Int32Array::from(data_values);
1108 let b = BooleanArray::from(filter_values);
1109 let c = filter(&a, &b).unwrap();
1110 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1111 assert_eq!(2, d.len());
1112 assert_eq!(65, d.value(0));
1113 assert_eq!(67, d.value(1));
1114 }
1115
1116 #[test]
1117 fn test_filter_array_high_density() {
1118 let mut data_values = (1..=65).map(Some).collect::<Vec<_>>();
1120 let mut filter_values = (1..=65)
1121 .map(|i| !matches!(i % 65, 0))
1122 .collect::<Vec<bool>>();
1123 data_values[1] = None;
1125 data_values.extend_from_slice(&[Some(66), None, Some(67), None]);
1127 filter_values.extend_from_slice(&[false, true, true, true]);
1128 let a = Int32Array::from(data_values);
1129 let b = BooleanArray::from(filter_values);
1130 let c = filter(&a, &b).unwrap();
1131 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1132 assert_eq!(67, d.len());
1133 assert_eq!(3, d.null_count());
1134 assert_eq!(1, d.value(0));
1135 assert!(d.is_null(1));
1136 assert_eq!(64, d.value(63));
1137 assert!(d.is_null(64));
1138 assert_eq!(67, d.value(65));
1139 }
1140
1141 #[test]
1142 fn test_filter_string_array_simple() {
1143 let a = StringArray::from(vec!["hello", " ", "world", "!"]);
1144 let b = BooleanArray::from(vec![true, false, true, false]);
1145 let c = filter(&a, &b).unwrap();
1146 let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1147 assert_eq!(2, d.len());
1148 assert_eq!("hello", d.value(0));
1149 assert_eq!("world", d.value(1));
1150 }
1151
1152 #[test]
1153 fn test_filter_primitive_array_with_null() {
1154 let a = Int32Array::from(vec![Some(5), None]);
1155 let b = BooleanArray::from(vec![false, true]);
1156 let c = filter(&a, &b).unwrap();
1157 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1158 assert_eq!(1, d.len());
1159 assert!(d.is_null(0));
1160 }
1161
1162 #[test]
1163 fn test_filter_string_array_with_null() {
1164 let a = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
1165 let b = BooleanArray::from(vec![true, false, false, true]);
1166 let c = filter(&a, &b).unwrap();
1167 let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1168 assert_eq!(2, d.len());
1169 assert_eq!("hello", d.value(0));
1170 assert!(!d.is_null(0));
1171 assert!(d.is_null(1));
1172 }
1173
1174 #[test]
1175 fn test_filter_binary_array_with_null() {
1176 let data: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
1177 let a = BinaryArray::from(data);
1178 let b = BooleanArray::from(vec![true, false, false, true]);
1179 let c = filter(&a, &b).unwrap();
1180 let d = c.as_ref().as_any().downcast_ref::<BinaryArray>().unwrap();
1181 assert_eq!(2, d.len());
1182 assert_eq!(b"hello", d.value(0));
1183 assert!(!d.is_null(0));
1184 assert!(d.is_null(1));
1185 }
1186
1187 fn _test_filter_byte_view<T>()
1188 where
1189 T: ByteViewType,
1190 str: AsRef<T::Native>,
1191 T::Native: PartialEq,
1192 {
1193 let array = {
1194 let mut builder = GenericByteViewBuilder::<T>::new();
1196 builder.append_value("hello");
1197 builder.append_value("world");
1198 builder.append_null();
1199 builder.append_value("large payload over 12 bytes");
1200 builder.append_value("lulu");
1201 builder.finish()
1202 };
1203
1204 {
1205 let predicate = BooleanArray::from(vec![true, false, true, true, false]);
1206 let actual = filter(&array, &predicate).unwrap();
1207
1208 assert_eq!(actual.len(), 3);
1209
1210 let expected = {
1211 let mut builder = GenericByteViewBuilder::<T>::new();
1213 builder.append_value("hello");
1214 builder.append_null();
1215 builder.append_value("large payload over 12 bytes");
1216 builder.finish()
1217 };
1218
1219 assert_eq!(actual.as_ref(), &expected);
1220 }
1221
1222 {
1223 let predicate = BooleanArray::from(vec![true, false, false, false, true]);
1224 let actual = filter(&array, &predicate).unwrap();
1225
1226 assert_eq!(actual.len(), 2);
1227
1228 let expected = {
1229 let mut builder = GenericByteViewBuilder::<T>::new();
1231 builder.append_value("hello");
1232 builder.append_value("lulu");
1233 builder.finish()
1234 };
1235
1236 assert_eq!(actual.as_ref(), &expected);
1237 }
1238 }
1239
1240 #[test]
1241 fn test_filter_string_view() {
1242 _test_filter_byte_view::<StringViewType>()
1243 }
1244
/// Runs the shared byte-view filter checks for `BinaryViewType`.
#[test]
fn test_filter_binary_view() {
    _test_filter_byte_view::<BinaryViewType>();
}
1249
/// Filters a three-row `FixedSizeBinaryArray` with several masks and, for two
/// of them, also checks that an optimized `FilterBuilder` predicate yields the
/// same array.
#[test]
fn test_filter_fixed_binary() {
    let v1 = [1_u8, 2];
    let v2 = [3_u8, 4];
    let v3 = [5_u8, 6];
    let input = FixedSizeBinaryArray::from(vec![&v1, &v2, &v3]);

    // Filters `input` with `mask`, compares the surviving rows against
    // `expected_rows`, and optionally cross-checks the optimized predicate.
    let check = |mask: Vec<bool>, expected_rows: &[&[u8; 2]], also_optimized: bool| {
        let predicate = BooleanArray::from(mask);

        let plain = filter(&input, &predicate).unwrap();
        let plain = plain
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        assert_eq!(plain.len(), expected_rows.len());
        for (idx, row) in expected_rows.iter().enumerate() {
            assert_eq!(plain.value(idx), *row);
        }

        if also_optimized {
            let optimized = FilterBuilder::new(&predicate)
                .optimize()
                .build()
                .filter(&input)
                .unwrap();
            let optimized = optimized
                .as_any()
                .downcast_ref::<FixedSizeBinaryArray>()
                .unwrap();
            assert_eq!(plain, optimized);
        }
    };

    check(vec![true, false, true], &[&v1, &v3], true);
    check(vec![false, false, false], &[], false);
    check(vec![true, true, true], &[&v1, &v2, &v3], false);
    check(vec![false, false, true], &[&v3], true);
}
1321
/// Filters a sliced `Int32Array` whose first visible element is null.
#[test]
fn test_filter_array_slice_with_null() {
    // After slicing, the visible elements are [null, 7, 8, 9].
    let sliced = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
    let predicate = BooleanArray::from(vec![true, false, false, true]);

    let filtered = filter(&sliced, &predicate).unwrap();
    let filtered = filtered.as_any().downcast_ref::<Int32Array>().unwrap();

    assert_eq!(2, filtered.len());
    assert!(filtered.is_null(0));
    assert!(!filtered.is_null(1));
    assert_eq!(9, filtered.value(1));
}
1336
/// Filters an eight-element run-end encoded array with an alternating mask and
/// compares run ends and values against a hand-built expectation.
#[test]
fn test_filter_run_end_encoding_array() {
    let input = RunArray::try_new(
        &Int64Array::from(vec![2, 3, 8]),
        &Int64Array::from(vec![7, -2, 9]),
    )
    .expect("Failed to create RunArray");
    let predicate = BooleanArray::from(vec![true, false, true, false, true, false, true, false]);

    let expected = RunArray::try_new(
        &Int64Array::from(vec![1, 2, 4]),
        &Int64Array::from(vec![7, -2, 9]),
    )
    .expect("Failed to make expected RunArray test is broken");

    let filtered = filter(&input, &predicate).unwrap();
    let actual: &RunArray<Int64Type> = as_run_array(&filtered);

    assert_eq!(4, actual.len());
    assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1356
/// Filters a run-end encoded array with a mask that drops two of the four
/// runs entirely, so their values must disappear from the result.
#[test]
fn test_filter_run_end_encoding_array_remove_value() {
    let input = RunArray::try_new(
        &Int32Array::from(vec![2, 3, 8, 10]),
        &Int32Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    let predicate = BooleanArray::from(vec![
        false, true, false, false, true, false, true, false, false, false,
    ]);

    let expected =
        RunArray::try_new(&Int32Array::from(vec![1, 3]), &Int32Array::from(vec![7, 9]))
            .expect("Failed to make expected RunArray test is broken");

    let filtered = filter(&input, &predicate).unwrap();
    let actual: &RunArray<Int32Type> = as_run_array(&filtered);

    assert_eq!(3, actual.len());
    assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1376
/// Filters a run-end encoded array with a mask that selects a single element.
#[test]
fn test_filter_run_end_encoding_array_remove_all_but_one() {
    let input = RunArray::try_new(
        &Int16Array::from(vec![2, 3, 8, 10]),
        &Int16Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    // Only logical index 6 is selected.
    let predicate = BooleanArray::from(vec![
        false, false, false, false, false, false, true, false, false, false,
    ]);

    let expected = RunArray::try_new(&Int16Array::from(vec![1]), &Int16Array::from(vec![9]))
        .expect("Failed to make expected RunArray test is broken");

    let filtered = filter(&input, &predicate).unwrap();
    let actual: &RunArray<Int16Type> = as_run_array(&filtered);

    assert_eq!(1, actual.len());
    assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1395
/// Filters a run-end encoded array with an all-false mask and expects an
/// empty result.
#[test]
fn test_filter_run_end_encoding_array_empty() {
    let input = RunArray::try_new(
        &Int64Array::from(vec![2, 3, 8, 10]),
        &Int64Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    let predicate = BooleanArray::from(vec![false; 10]);

    let filtered = filter(&input, &predicate).unwrap();
    let actual: &RunArray<Int64Type> = as_run_array(&filtered);

    assert_eq!(0, actual.len());
}
1408
/// Filters a ten-element run-end encoded array with a three-element mask, so
/// the largest run end exceeds the predicate length.
#[test]
fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() {
    let input = RunArray::try_new(
        &Int64Array::from(vec![2, 3, 8, 10]),
        &Int64Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    let predicate = BooleanArray::from(vec![false, true, true]);

    let expected = RunArray::try_new(
        &Int64Array::from(vec![1, 2]),
        &Int64Array::from(vec![7, -2]),
    )
    .expect("Failed to make expected RunArray test is broken");

    let filtered = filter(&input, &predicate).unwrap();
    let actual: &RunArray<Int64Type> = as_run_array(&filtered);

    assert_eq!(2, actual.len());
    assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1428
/// Filters an `Int8DictionaryArray` and checks that the dictionary values are
/// untouched while the keys reflect the selected (null and "world") slots.
#[test]
fn test_filter_dictionary_array() {
    let input: Int8DictionaryArray = vec![Some("hello"), None, Some("world"), Some("!")]
        .into_iter()
        .collect();
    let predicate = BooleanArray::from(vec![false, true, true, false]);

    let filtered = filter(&input, &predicate).unwrap();
    let dict = filtered
        .as_any()
        .downcast_ref::<Int8DictionaryArray>()
        .unwrap();
    let dict_values = dict
        .values()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();

    // The dictionary still holds all three distinct strings.
    assert_eq!(3, dict_values.len());
    assert_eq!(2, dict.len());
    assert!(dict.is_null(0));

    let second_key = dict.keys().value(1) as usize;
    assert_eq!("world", dict_values.value(second_key));
}
1449
/// Filters a four-element `LargeListArray` (last element null) and compares
/// the result with a hand-built two-element list array.
#[test]
fn test_filter_list_array() {
    // Assembles `LargeList<Int32>` array data from raw offsets, child values
    // and a one-byte validity bitmap.
    let large_list = |len: usize, offsets: &[i64], child: &[i32], validity: u8| {
        let child_data = ArrayData::builder(DataType::Int32)
            .len(child.len())
            .add_buffer(Buffer::from_slice_ref(child))
            .build()
            .unwrap();
        let list_type =
            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
        ArrayData::builder(list_type)
            .len(len)
            .add_buffer(Buffer::from_slice_ref(offsets))
            .add_child_data(child_data)
            .null_bit_buffer(Some(Buffer::from([validity])))
            .build()
            .unwrap()
    };

    let input = LargeListArray::from(large_list(
        4,
        &[0, 3, 6, 8, 8],
        &[0, 1, 2, 3, 4, 5, 6, 7],
        0b00000111,
    ));
    let predicate = BooleanArray::from(vec![false, true, false, true]);
    let result = filter(&input, &predicate).unwrap();

    // The second list ([3, 4, 5]) and the trailing null survive.
    let expected = large_list(2, &[0, 3, 3], &[3, 4, 5], 0b00000001);

    assert_eq!(&make_array(expected), &result);
}
1496
1497 fn test_case_filter_list_view<T: OffsetSizeTrait>() {
1498 let mut list_array = GenericListViewBuilder::<T, _>::new(Int32Builder::new());
1500 list_array.append_value([Some(1), Some(2)]);
1501 list_array.append_null();
1502 list_array.append_value([]);
1503 list_array.append_value([Some(3), Some(4)]);
1504
1505 let list_array = list_array.finish();
1506 let predicate = BooleanArray::from_iter([true, false, true, false]);
1507
1508 let filtered = filter(&list_array, &predicate)
1510 .unwrap()
1511 .as_list_view::<T>()
1512 .clone();
1513
1514 let mut expected =
1515 GenericListViewBuilder::<T, _>::with_capacity(Int32Builder::with_capacity(5), 3);
1516 expected.append_value([Some(1), Some(2)]);
1517 expected.append_value([]);
1518 let expected = expected.finish();
1519
1520 assert_eq!(&filtered, &expected);
1521 }
1522
1523 fn test_case_filter_sliced_list_view<T: OffsetSizeTrait>() {
1524 let mut list_array =
1526 GenericListViewBuilder::<T, _>::with_capacity(Int32Builder::with_capacity(6), 4);
1527 list_array.append_value([Some(1), Some(2)]);
1528 list_array.append_null();
1529 list_array.append_value([]);
1530 list_array.append_value([Some(3), Some(4)]);
1531
1532 let list_array = list_array.finish();
1533
1534 let sliced = list_array.slice(1, 3);
1536 let predicate = BooleanArray::from_iter([false, false, true]);
1537
1538 let filtered = filter(&sliced, &predicate)
1540 .unwrap()
1541 .as_list_view::<T>()
1542 .clone();
1543
1544 let mut expected = GenericListViewBuilder::<T, _>::new(Int32Builder::new());
1545 expected.append_value([Some(3), Some(4)]);
1546 let expected = expected.finish();
1547
1548 assert_eq!(&filtered, &expected);
1549 }
1550
/// Runs the list-view filter cases (plain and sliced) for both `i32` and `i64`
/// offsets.
#[test]
fn test_filter_list_view_array() {
    // Plain arrays.
    test_case_filter_list_view::<i32>();
    test_case_filter_list_view::<i64>();

    // Sliced arrays.
    test_case_filter_sliced_list_view::<i32>();
    test_case_filter_sliced_list_view::<i64>();
}
1559
/// A 64-bit mask with only bit 1 set yields one slice and a set-bit count of 1.
#[test]
fn test_slice_iterator_bits() {
    let mask = BooleanArray::from((0..64).map(|i| i == 1).collect::<Vec<bool>>());

    let set_bits = filter_count(&mask);
    let slices: Vec<_> = SlicesIterator::new(&mask).collect();

    assert_eq!(slices, vec![(1, 2)]);
    assert_eq!(set_bits, 1);
}
1572
/// A 64-bit mask with only bit 1 cleared yields two slices around the gap.
#[test]
fn test_slice_iterator_bits1() {
    let mask = BooleanArray::from((0..64).map(|i| i != 1).collect::<Vec<bool>>());

    let set_bits = filter_count(&mask);
    let slices: Vec<_> = SlicesIterator::new(&mask).collect();

    assert_eq!(slices, vec![(0, 1), (2, 64)]);
    assert_eq!(set_bits, 64 - 1);
}
1585
/// A 130-bit mask with every 62nd bit cleared yields three slices.
#[test]
fn test_slice_iterator_chunk_and_bits() {
    let mask = BooleanArray::from((0..130).map(|i| i % 62 != 0).collect::<Vec<bool>>());

    let set_bits = filter_count(&mask);
    let slices: Vec<_> = SlicesIterator::new(&mask).collect();

    assert_eq!(slices, vec![(1, 62), (63, 124), (125, 130)]);
    assert_eq!(set_bits, 61 + 61 + 5);
}
1598
/// A null entry in the predicate drops the corresponding element: the result
/// equals the first two elements of the input.
#[test]
fn test_null_mask() {
    let values = Int64Array::from(vec![Some(1), Some(2), None]);
    let nullable_mask = BooleanArray::from(vec![Some(true), Some(true), None]);

    let filtered = filter(&values, &nullable_mask).unwrap();
    let expected = values.slice(0, 2);

    assert_eq!(filtered.as_ref(), &expected);
}
1607
/// Filtering a column-less record batch still produces the row count implied
/// by the predicate (two `true` entries, one null).
#[test]
fn test_filter_record_batch_no_columns() {
    let options = RecordBatchOptions::default().with_row_count(Some(100));
    let empty_schema = Arc::new(Schema::empty());
    let batch = RecordBatch::try_new_with_options(empty_schema, vec![], &options).unwrap();

    let predicate = BooleanArray::from(vec![Some(true), Some(true), None]);
    let filtered = filter_record_batch(&batch, &predicate).unwrap();

    assert_eq!(filtered.num_rows(), 2);
}
1618
/// Checks the two trivial predicates: all-true returns an equal array and
/// all-false returns an empty array of the same data type.
#[test]
fn test_fast_path() {
    let input: PrimitiveArray<Int64Type> = PrimitiveArray::from(vec![Some(1), Some(2), None]);

    // Everything selected.
    let keep_all = BooleanArray::from(vec![true; 3]);
    let kept = filter(&input, &keep_all).unwrap();
    let kept = kept
        .as_any()
        .downcast_ref::<PrimitiveArray<Int64Type>>()
        .unwrap();
    assert_eq!(&input, kept);

    // Nothing selected.
    let keep_none = BooleanArray::from(vec![false; 3]);
    let dropped = filter(&input, &keep_none).unwrap();
    assert_eq!(dropped.len(), 0);
    assert_eq!(dropped.data_type(), &DataType::Int64);
}
1638
/// Checks the slices reported by `SlicesIterator` for a mask made of
/// alternating runs, both on the whole mask and on a slice of it.
#[test]
fn test_slices() {
    // Runs: 10 true, 30 false, 20 true, 17 false, 4 true.
    let bool_array: BooleanArray = std::iter::repeat_n(true, 10)
        .chain(std::iter::repeat_n(false, 30))
        .chain(std::iter::repeat_n(true, 20))
        .chain(std::iter::repeat_n(false, 17))
        .chain(std::iter::repeat_n(true, 4))
        .map(Some)
        .collect();

    let whole: Vec<_> = SlicesIterator::new(&bool_array).collect();
    assert_eq!(whole, vec![(0, 10), (40, 60), (77, 81)]);

    // Skip the first 7 bits and shorten the mask by 10 in total.
    let shortened_len = bool_array.len() - 10;
    let sliced_array = bool_array.slice(7, shortened_len);
    let sliced_array = sliced_array
        .as_any()
        .downcast_ref::<BooleanArray>()
        .unwrap();

    let partial: Vec<_> = SlicesIterator::new(sliced_array).collect();
    assert_eq!(partial, vec![(0, 3), (33, 53), (70, 71)]);
}
1665
1666 fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) {
1667 let mut rng = rng();
1668
1669 let bools: Vec<bool> = std::iter::from_fn(|| Some(rng.random()))
1670 .take(mask_len)
1671 .collect();
1672
1673 let buffer = Buffer::from_iter(bools.iter().cloned());
1674
1675 let truncated_length = mask_len - offset - truncate;
1676
1677 let data = ArrayDataBuilder::new(DataType::Boolean)
1678 .len(truncated_length)
1679 .offset(offset)
1680 .add_buffer(buffer)
1681 .build()
1682 .unwrap();
1683
1684 let filter = BooleanArray::from(data);
1685
1686 let slice_bits: Vec<_> = SlicesIterator::new(&filter)
1687 .flat_map(|(start, end)| start..end)
1688 .collect();
1689
1690 let count = filter_count(&filter);
1691 let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect();
1692
1693 let expected_bits: Vec<_> = bools
1694 .iter()
1695 .skip(offset)
1696 .take(truncated_length)
1697 .enumerate()
1698 .flat_map(|(idx, v)| v.then(|| idx))
1699 .collect();
1700
1701 assert_eq!(slice_bits, expected_bits);
1702 assert_eq!(index_bits, expected_bits);
1703 }
1704
/// Drives `test_slices_fuzz` with 100 random mask geometries followed by a
/// handful of fixed ones.
#[test]
#[cfg_attr(miri, ignore)]
fn fuzz_test_slices_iterator() {
    let mut random = rng();
    let any_usize = UniformUsize::new(usize::MIN, usize::MAX).unwrap();

    for _ in 0..100 {
        let mask_len = random.random_range(0..1024);

        // `checked_rem` yields `None` for a zero bound; fall back to 0 then.
        let offset_bound = mask_len.min(64);
        let offset = any_usize
            .sample(&mut random)
            .checked_rem(offset_bound)
            .unwrap_or(0);

        let truncate_bound = (mask_len - offset).min(128);
        let truncate = any_usize
            .sample(&mut random)
            .checked_rem(truncate_bound)
            .unwrap_or(0);

        test_slices_fuzz(mask_len, offset, truncate);
    }

    for (mask_len, offset, truncate) in [(64, 0, 0), (64, 8, 0), (64, 8, 8), (32, 8, 8), (32, 5, 9)]
    {
        test_slices_fuzz(mask_len, offset, truncate);
    }
}
1731
/// Plain-Rust reference filter: returns the items of `values` whose paired
/// entry in `predicate` is `true`.
///
/// Pairing stops at the shorter of the two inputs.
fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> {
    values
        .into_iter()
        .zip(predicate.iter().copied())
        .filter_map(|(item, keep)| keep.then_some(item))
        .collect()
}
1741
1742 fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>>
1744 where
1745 StandardUniform: Distribution<T>,
1746 {
1747 let mut rng = rng();
1748 (0..len)
1749 .map(|_| rng.random_bool(valid_percent).then(|| rng.random()))
1750 .collect()
1751 }
1752
1753 fn gen_strings(
1755 len: usize,
1756 valid_percent: f64,
1757 str_len_range: std::ops::Range<usize>,
1758 ) -> Vec<Option<String>> {
1759 let mut rng = rng();
1760 (0..len)
1761 .map(|_| {
1762 rng.random_bool(valid_percent).then(|| {
1763 let len = rng.random_range(str_len_range.clone());
1764 (0..len)
1765 .map(|_| char::from(rng.sample(Alphanumeric)))
1766 .collect()
1767 })
1768 })
1769 .collect()
1770 }
1771
/// Iterates over `src`, turning each `Option<T>` into an `Option<&T::Target>`.
fn as_deref<T>(src: &[Option<T>]) -> impl Iterator<Item = Option<&T::Target>>
where
    T: std::ops::Deref,
{
    src.iter().map(Option::as_deref)
}
1776
/// Randomized comparison of `filter` against the plain-Rust `filter_rust`
/// reference, using sliced `Int32Array`, `StringArray` and
/// `DictionaryArray<Int32Type>` inputs and a sliced predicate.
#[test]
#[cfg_attr(miri, ignore)]
fn fuzz_filter() {
    let mut random = rng();

    for round in 0..100 {
        // The first rounds pin the selection probability to 1, the next ones
        // to 0; afterwards it is drawn at random.
        let filter_percent = if round <= 4 {
            1.
        } else if round <= 10 {
            0.
        } else {
            random.random_range(0.0..1.0)
        };

        let valid_percent = random.random_range(0.0..1.0);

        let array_len = random.random_range(32..256);
        let array_offset = random.random_range(0..10);

        let filter_offset = random.random_range(0..10);
        let filter_truncate = random.random_range(0..10);

        let raw_len = array_len + filter_offset - filter_truncate;
        let raw_bools: Vec<_> = (0..raw_len)
            .map(|_| random.random_bool(filter_percent))
            .collect();

        let predicate = BooleanArray::from_iter(raw_bools.iter().cloned().map(Some))
            .slice(filter_offset, array_len - filter_truncate);
        let predicate = predicate.as_any().downcast_ref::<BooleanArray>().unwrap();
        // The bits the sliced predicate actually exposes.
        let selected = &raw_bools[filter_offset..];

        // Primitive values.
        let ints = gen_primitive(array_len + array_offset, valid_percent);
        let int_array =
            Int32Array::from_iter(ints.iter().cloned()).slice(array_offset, array_len);
        let int_array = int_array.as_any().downcast_ref::<Int32Array>().unwrap();

        let filtered = filter(int_array, predicate).unwrap();
        let filtered = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
        let actual: Vec<_> = filtered.iter().collect();

        assert_eq!(
            actual,
            filter_rust(ints[array_offset..].iter().cloned(), selected)
        );

        // Strings.
        let strings = gen_strings(array_len + array_offset, valid_percent, 0..20);
        let string_array =
            StringArray::from_iter(as_deref(&strings)).slice(array_offset, array_len);
        let string_array = string_array
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        let filtered = filter(string_array, predicate).unwrap();
        let filtered = filtered.as_any().downcast_ref::<StringArray>().unwrap();
        let actual: Vec<_> = filtered.iter().collect();

        let expected_strings = filter_rust(as_deref(&strings[array_offset..]), selected);
        assert_eq!(actual, expected_strings);

        // The same strings, dictionary encoded.
        let dict_array = DictionaryArray::<Int32Type>::from_iter(as_deref(&strings))
            .slice(array_offset, array_len);
        let dict_array = dict_array
            .as_any()
            .downcast_ref::<DictionaryArray<Int32Type>>()
            .unwrap();

        let filtered = filter(dict_array, predicate).unwrap();
        let filtered = filtered
            .as_any()
            .downcast_ref::<DictionaryArray<Int32Type>>()
            .unwrap();
        let dict_values = filtered
            .values()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        // Resolve every key back to its string for comparison.
        let actual: Vec<_> = filtered
            .keys()
            .iter()
            .map(|key| key.map(|key| dict_values.value(key as usize)))
            .collect();

        assert_eq!(actual, expected_strings);
    }
}
1867
/// Filters a four-row map array (third row null), keeping the first and last
/// rows, and compares with a map array built from just those rows.
#[test]
fn test_filter_map() {
    // Builds a map array; `None` rows are appended as null maps.
    let build_map = |value_capacity: usize, rows: &[Option<&[(&str, i64)]>]| -> ArrayRef {
        let mut builder = MapBuilder::new(
            None,
            StringBuilder::new(),
            Int64Builder::with_capacity(value_capacity),
        );
        for row in rows {
            if let Some(entries) = row {
                for (key, value) in entries.iter() {
                    builder.keys().append_value(*key);
                    builder.values().append_value(*value);
                }
            }
            builder.append(row.is_some()).unwrap();
        }
        Arc::new(builder.finish())
    };

    let source_rows: [Option<&[(&str, i64)]>; 4] = [
        Some(&[("key1", 1)]),
        Some(&[("key2", 2), ("key3", 3)]),
        None,
        Some(&[("key1", 1)]),
    ];
    let maparray = build_map(4, &source_rows);

    let indices = vec![Some(true), Some(false), Some(false), Some(true)]
        .into_iter()
        .collect::<BooleanArray>();
    let got = filter(&maparray, &indices).unwrap();

    let expected_rows: [Option<&[(&str, i64)]>; 2] =
        [Some(&[("key1", 1)]), Some(&[("key1", 1)])];
    let expected = build_map(2, &expected_rows);

    assert_eq!(&expected, &got);
}
1904
/// Filters a `FixedSizeListArray` of three lists of three `Int32` values each
/// and checks the child values of every surviving list.
#[test]
fn test_filter_fixed_size_list_arrays() {
    let child = ArrayData::builder(DataType::Int32)
        .len(9)
        .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8]))
        .build()
        .unwrap();
    let list_data = ArrayData::builder(DataType::new_fixed_size_list(DataType::Int32, 3, false))
        .len(3)
        .add_child_data(child)
        .build()
        .unwrap();
    let array = FixedSizeListArray::from(list_data);

    // Filters `array` with `mask` and compares each surviving list.
    let check = |mask: Vec<bool>, expected_lists: &[[i32; 3]]| {
        let predicate = BooleanArray::from(mask);
        let result = filter(&array, &predicate).unwrap();
        let filtered = result.as_any().downcast_ref::<FixedSizeListArray>().unwrap();

        assert_eq!(filtered.len(), expected_lists.len());

        for (idx, expected) in expected_lists.iter().enumerate() {
            let list = filtered.value(idx);
            assert_eq!(
                expected,
                list.as_any().downcast_ref::<Int32Array>().unwrap().values()
            );
        }
    };

    check(vec![true, false, false], &[[0, 1, 2]]);
    check(vec![true, false, true], &[[0, 1, 2], [6, 7, 8]]);
}
1951
/// Filters a `FixedSizeListArray` of five two-element lists, two of which are
/// null, and checks values and nullness of the three selected lists.
#[test]
fn test_filter_fixed_size_list_arrays_with_null() {
    let child = ArrayData::builder(DataType::Int32)
        .len(10)
        .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
        .build()
        .unwrap();

    // Lists 0, 3 and 4 are valid; lists 1 and 2 stay null.
    let mut validity: [u8; 1] = [0; 1];
    for valid_idx in [0, 3, 4] {
        bit_util::set_bit(&mut validity, valid_idx);
    }

    let list_data = ArrayData::builder(DataType::new_fixed_size_list(DataType::Int32, 2, false))
        .len(5)
        .add_child_data(child)
        .null_bit_buffer(Some(Buffer::from(validity)))
        .build()
        .unwrap();
    let array = FixedSizeListArray::from(list_data);

    let predicate = BooleanArray::from(vec![true, true, false, true, false]);
    let result = filter(&array, &predicate).unwrap();
    let filtered = result.as_any().downcast_ref::<FixedSizeListArray>().unwrap();

    assert_eq!(filtered.len(), 3);

    let first = filtered.value(0);
    assert_eq!(
        &[0, 1],
        first.as_any().downcast_ref::<Int32Array>().unwrap().values()
    );

    assert!(filtered.is_null(1));

    let last = filtered.value(2);
    assert_eq!(
        &[6, 7],
        last.as_any().downcast_ref::<Int32Array>().unwrap().values()
    );
}
1996
1997 fn test_filter_union_array(array: UnionArray) {
1998 let filter_array = BooleanArray::from(vec![true, false, false]);
1999 let c = filter(&array, &filter_array).unwrap();
2000 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2001
2002 let mut builder = UnionBuilder::new_dense();
2003 builder.append::<Int32Type>("A", 1).unwrap();
2004 let expected_array = builder.build().unwrap();
2005
2006 compare_union_arrays(filtered, &expected_array);
2007
2008 let filter_array = BooleanArray::from(vec![true, false, true]);
2009 let c = filter(&array, &filter_array).unwrap();
2010 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2011
2012 let mut builder = UnionBuilder::new_dense();
2013 builder.append::<Int32Type>("A", 1).unwrap();
2014 builder.append::<Int32Type>("A", 34).unwrap();
2015 let expected_array = builder.build().unwrap();
2016
2017 compare_union_arrays(filtered, &expected_array);
2018
2019 let filter_array = BooleanArray::from(vec![true, true, false]);
2020 let c = filter(&array, &filter_array).unwrap();
2021 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2022
2023 let mut builder = UnionBuilder::new_dense();
2024 builder.append::<Int32Type>("A", 1).unwrap();
2025 builder.append::<Float64Type>("B", 3.2).unwrap();
2026 let expected_array = builder.build().unwrap();
2027
2028 compare_union_arrays(filtered, &expected_array);
2029 }
2030
/// Runs the shared union filter checks on a dense union.
#[test]
fn test_filter_union_array_dense() {
    let mut dense = UnionBuilder::new_dense();
    dense.append::<Int32Type>("A", 1).unwrap();
    dense.append::<Float64Type>("B", 3.2).unwrap();
    dense.append::<Int32Type>("A", 34).unwrap();

    test_filter_union_array(dense.build().unwrap());
}
2041
/// Filters a dense union whose slots all share one variant and compares the
/// resulting array data with a directly built two-slot union.
#[test]
fn test_filter_run_union_array_dense() {
    let mut source = UnionBuilder::new_dense();
    for value in [1, 3, 34] {
        source.append::<Int32Type>("A", value).unwrap();
    }
    let array = source.build().unwrap();

    let mut wanted = UnionBuilder::new_dense();
    for value in [1, 3] {
        wanted.append::<Int32Type>("A", value).unwrap();
    }
    let expected = wanted.build().unwrap();

    let predicate = BooleanArray::from(vec![true, true, false]);
    let result = filter(&array, &predicate).unwrap();
    let filtered = result.as_any().downcast_ref::<UnionArray>().unwrap();

    assert_eq!(filtered.to_data(), expected.to_data());
}
2061
/// Filters a dense union containing a null slot, once dropping the null and
/// once keeping it.
#[test]
fn test_filter_union_array_dense_with_nulls() {
    let mut source = UnionBuilder::new_dense();
    source.append::<Int32Type>("A", 1).unwrap();
    source.append::<Float64Type>("B", 3.2).unwrap();
    source.append_null::<Float64Type>("B").unwrap();
    source.append::<Int32Type>("A", 34).unwrap();
    let array = source.build().unwrap();

    // Filters `array` with `mask` and compares against `expected`.
    let check = |mask: Vec<bool>, expected: UnionArray| {
        let predicate = BooleanArray::from(mask);
        let result = filter(&array, &predicate).unwrap();
        let filtered = result.as_any().downcast_ref::<UnionArray>().unwrap();
        compare_union_arrays(filtered, &expected);
    };

    // The two leading non-null slots.
    let mut without_null = UnionBuilder::new_dense();
    without_null.append::<Int32Type>("A", 1).unwrap();
    without_null.append::<Float64Type>("B", 3.2).unwrap();
    check(vec![true, true, false, false], without_null.build().unwrap());

    // The first slot and the null slot.
    let mut with_null = UnionBuilder::new_dense();
    with_null.append::<Int32Type>("A", 1).unwrap();
    with_null.append_null::<Float64Type>("B").unwrap();
    check(vec![true, false, true, false], with_null.build().unwrap());
}
2093
/// Runs the shared union filter checks on a sparse union.
#[test]
fn test_filter_union_array_sparse() {
    let mut sparse = UnionBuilder::new_sparse();
    sparse.append::<Int32Type>("A", 1).unwrap();
    sparse.append::<Float64Type>("B", 3.2).unwrap();
    sparse.append::<Int32Type>("A", 34).unwrap();

    test_filter_union_array(sparse.build().unwrap());
}
2104
/// Filters a sparse union containing a null slot, keeping the first slot and
/// the null slot.
#[test]
fn test_filter_union_array_sparse_with_nulls() {
    let mut source = UnionBuilder::new_sparse();
    source.append::<Int32Type>("A", 1).unwrap();
    source.append::<Float64Type>("B", 3.2).unwrap();
    source.append_null::<Float64Type>("B").unwrap();
    source.append::<Int32Type>("A", 34).unwrap();
    let array = source.build().unwrap();

    let mut wanted = UnionBuilder::new_sparse();
    wanted.append::<Int32Type>("A", 1).unwrap();
    wanted.append_null::<Float64Type>("B").unwrap();
    let expected_array = wanted.build().unwrap();

    let predicate = BooleanArray::from(vec![true, false, true, false]);
    let result = filter(&array, &predicate).unwrap();
    let filtered = result.as_any().downcast_ref::<UnionArray>().unwrap();

    compare_union_arrays(filtered, &expected_array);
}
2125
2126 fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) {
2127 assert_eq!(union1.len(), union2.len());
2128
2129 for i in 0..union1.len() {
2130 let type_id = union1.type_id(i);
2131
2132 let slot1 = union1.value(i);
2133 let slot2 = union2.value(i);
2134
2135 assert_eq!(slot1.is_null(0), slot2.is_null(0));
2136
2137 if !slot1.is_null(0) && !slot2.is_null(0) {
2138 match type_id {
2139 0 => {
2140 let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
2141 assert_eq!(slot1.len(), 1);
2142 let value1 = slot1.value(0);
2143
2144 let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap();
2145 assert_eq!(slot2.len(), 1);
2146 let value2 = slot2.value(0);
2147 assert_eq!(value1, value2);
2148 }
2149 1 => {
2150 let slot1 = slot1.as_any().downcast_ref::<Float64Array>().unwrap();
2151 assert_eq!(slot1.len(), 1);
2152 let value1 = slot1.value(0);
2153
2154 let slot2 = slot2.as_any().downcast_ref::<Float64Array>().unwrap();
2155 assert_eq!(slot2.len(), 1);
2156 let value2 = slot2.value(0);
2157 assert_eq!(value1, value2);
2158 }
2159 _ => unreachable!(),
2160 }
2161 }
2162 }
2163 }
2164
/// Filters struct arrays with one and with two child columns, each with and
/// without a struct-level null mask, and compares the resulting array data.
#[test]
fn test_filter_struct() {
    let predicate = BooleanArray::from(vec![true, false, true, false]);

    let a: ArrayRef = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
    let a_filtered: ArrayRef = Arc::new(StringArray::from(vec!["hello", "world"]));

    let b: ArrayRef = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
    let b_filtered: ArrayRef = Arc::new(Int32Array::from(vec![5, 7]));

    let null_mask = NullBuffer::from(vec![true, false, false, true]);
    let null_mask_filtered = NullBuffer::from(vec![true, false]);

    let a_field = Field::new("a", DataType::Utf8, false);
    let b_field = Field::new("b", DataType::Int32, false);

    // (fields, input columns, expected filtered columns)
    let column_sets: [(Vec<Field>, Vec<ArrayRef>, Vec<ArrayRef>); 2] = [
        (
            vec![a_field.clone()],
            vec![a.clone()],
            vec![a_filtered.clone()],
        ),
        (
            vec![a_field, b_field],
            vec![a, b],
            vec![a_filtered, b_filtered],
        ),
    ];

    for (fields, columns, filtered_columns) in &column_sets {
        // First without struct-level nulls, then with the mask applied.
        for nulls in [None, Some((&null_mask, &null_mask_filtered))] {
            let array = StructArray::new(
                fields.clone().into(),
                columns.clone(),
                nulls.map(|(input_nulls, _)| input_nulls.clone()),
            );
            let expected = StructArray::new(
                fields.clone().into(),
                filtered_columns.clone(),
                nulls.map(|(_, output_nulls)| output_nulls.clone()),
            );

            let result = filter(&array, &predicate).unwrap();

            assert_eq!(result.to_data(), expected.to_data());
        }
    }
}
2235
/// Filters a record batch whose single struct column contains a nested struct
/// with no fields, and checks the resulting row count.
#[test]
fn test_filter_empty_struct() {
    // Children of the outer struct: a nullable Int64 and a field-less struct.
    let inner_fields = vec![
        Field::new("b", DataType::Int64, true),
        Field::new("c", DataType::Struct(Fields::empty()), true),
    ];

    let outer_field = Field::new(
        "a",
        DataType::Struct(Fields::from(inner_fields.clone())),
        true,
    );
    let schema = Arc::new(Schema::new(vec![outer_field]));

    let all_valid = || Some(NullBuffer::from(vec![true, true, true]));

    let b: ArrayRef = Arc::new(Int64Array::from(vec![None, None, None]));
    let c: ArrayRef = Arc::new(StructArray::new_empty_fields(3, all_valid()));
    let a = StructArray::new(inner_fields.into(), vec![b, c], all_valid());

    let record_batch = RecordBatch::try_new(schema, vec![Arc::new(a)]).unwrap();
    println!("{record_batch:?}");

    let predicate = BooleanArray::from(vec![true, false, true]);
    let filtered_batch = filter_record_batch(&record_batch, &predicate).unwrap();

    assert_eq!(filtered_batch.num_rows(), 2);
}
2290
/// `filter_bits` is expected to panic when the predicate (9 entries) is longer
/// than the buffer being filtered (8 bits).
#[test]
#[should_panic]
fn test_filter_bits_too_large() {
    let bits = BooleanBuffer::from(vec![false; 8]);
    let mask = BooleanArray::from(vec![true; 9]);
    let predicate = FilterBuilder::new(&mask).build();
    filter_bits(&bits, &predicate);
}
2299
/// `filter_native` is expected to panic when the predicate (9 entries) is
/// longer than the value slice being filtered (8 values).
#[test]
#[should_panic]
fn test_filter_native_too_large() {
    let natives = vec![1; 8];
    let mask = BooleanArray::from(vec![false; 9]);
    let predicate = FilterBuilder::new(&mask).build();
    filter_native(&natives, &predicate);
}
2308}