1use std::ops::AddAssign;
21use std::sync::Arc;
22
23use arrow_array::builder::BooleanBufferBuilder;
24use arrow_array::cast::AsArray;
25use arrow_array::types::{
26 ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
27};
28use arrow_array::*;
29use arrow_buffer::{ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer, bit_util};
30use arrow_buffer::{Buffer, MutableBuffer};
31use arrow_data::ArrayDataBuilder;
32use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
33use arrow_data::transform::MutableArrayData;
34use arrow_schema::*;
35
/// Selectivity cut-off used when choosing the default iteration strategy.
///
/// When the fraction of selected rows (`filter_count / filter_length`) is strictly
/// greater than this value, the filter is walked as contiguous slices; otherwise it
/// is walked one set index at a time.
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
43
44#[derive(Debug)]
56pub struct SlicesIterator<'a>(BitSliceIterator<'a>);
57
58impl<'a> SlicesIterator<'a> {
59 pub fn new(filter: &'a BooleanArray) -> Self {
61 filter.values().into()
62 }
63}
64
65impl<'a> From<&'a BooleanBuffer> for SlicesIterator<'a> {
66 fn from(filter: &'a BooleanBuffer) -> Self {
67 Self(filter.set_slices())
68 }
69}
70
71impl Iterator for SlicesIterator<'_> {
72 type Item = (usize, usize);
73
74 fn next(&mut self) -> Option<Self::Item> {
75 self.0.next()
76 }
77}
78
79struct IndexIterator<'a> {
84 remaining: usize,
85 iter: BitIndexIterator<'a>,
86}
87
88impl<'a> IndexIterator<'a> {
89 fn new(filter: &'a BooleanArray, remaining: usize) -> Self {
90 assert_eq!(filter.null_count(), 0);
91 let iter = filter.values().set_indices();
92 Self { remaining, iter }
93 }
94}
95
96impl Iterator for IndexIterator<'_> {
97 type Item = usize;
98
99 fn next(&mut self) -> Option<Self::Item> {
100 if self.remaining != 0 {
101 let next = self.iter.next().expect("IndexIterator exhausted early");
104 self.remaining -= 1;
105 return Some(next);
107 }
108 None
109 }
110
111 fn size_hint(&self) -> (usize, Option<usize>) {
112 (self.remaining, Some(self.remaining))
113 }
114}
115
116fn filter_count(filter: &BooleanArray) -> usize {
118 filter.values().count_set_bits()
119}
120
121pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
149 let nulls = filter.nulls().unwrap();
150 let mask = filter.values() & nulls.inner();
151 BooleanArray::new(mask, None)
152}
153
154pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
183 let mut filter_builder = FilterBuilder::new(predicate);
184
185 if FilterBuilder::is_optimize_beneficial(values.data_type()) {
186 filter_builder = filter_builder.optimize();
189 }
190
191 let predicate = filter_builder.build();
192
193 filter_array(values, &predicate)
194}
195
196pub fn filter_record_batch(
207 record_batch: &RecordBatch,
208 predicate: &BooleanArray,
209) -> Result<RecordBatch, ArrowError> {
210 let mut filter_builder = FilterBuilder::new(predicate);
211 let num_cols = record_batch.num_columns();
212 if num_cols > 1
213 || (num_cols > 0
214 && FilterBuilder::is_optimize_beneficial(
215 record_batch.schema_ref().field(0).data_type(),
216 ))
217 {
218 filter_builder = filter_builder.optimize();
221 }
222 let filter = filter_builder.build();
223
224 filter.filter_record_batch(record_batch)
225}
226
227#[derive(Debug)]
229pub struct FilterBuilder {
230 filter: BooleanArray,
231 count: usize,
232 strategy: IterationStrategy,
233}
234
235impl FilterBuilder {
236 pub fn new(filter: &BooleanArray) -> Self {
238 let filter = match filter.null_count() {
239 0 => filter.clone(),
240 _ => prep_null_mask_filter(filter),
241 };
242
243 let count = filter_count(&filter);
244 let strategy = IterationStrategy::default_strategy(filter.len(), count);
245
246 Self {
247 filter,
248 count,
249 strategy,
250 }
251 }
252
253 pub fn optimize(mut self) -> Self {
264 match self.strategy {
265 IterationStrategy::SlicesIterator => {
266 let slices = SlicesIterator::new(&self.filter).collect();
267 self.strategy = IterationStrategy::Slices(slices)
268 }
269 IterationStrategy::IndexIterator => {
270 let indices = IndexIterator::new(&self.filter, self.count).collect();
271 self.strategy = IterationStrategy::Indices(indices)
272 }
273 _ => {}
274 }
275 self
276 }
277
278 pub fn is_optimize_beneficial(data_type: &DataType) -> bool {
283 match data_type {
284 DataType::Struct(fields) => {
285 fields.len() > 1
286 || fields.len() == 1
287 && FilterBuilder::is_optimize_beneficial(fields[0].data_type())
288 }
289 DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
290 _ => false,
291 }
292 }
293
294 pub fn build(self) -> FilterPredicate {
296 FilterPredicate {
297 filter: self.filter,
298 count: self.count,
299 strategy: self.strategy,
300 }
301 }
302}
303
304#[derive(Debug)]
306enum IterationStrategy {
307 SlicesIterator,
309 IndexIterator,
311 Indices(Vec<usize>),
313 Slices(Vec<(usize, usize)>),
315 All,
317 None,
319}
320
321impl IterationStrategy {
322 fn default_strategy(filter_length: usize, filter_count: usize) -> Self {
325 if filter_length == 0 || filter_count == 0 {
326 return IterationStrategy::None;
327 }
328
329 if filter_count == filter_length {
330 return IterationStrategy::All;
331 }
332
333 let selectivity_frac = filter_count as f64 / filter_length as f64;
338 if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD {
339 return IterationStrategy::SlicesIterator;
340 }
341 IterationStrategy::IndexIterator
342 }
343}
344
345#[derive(Debug)]
347pub struct FilterPredicate {
348 filter: BooleanArray,
349 count: usize,
350 strategy: IterationStrategy,
351}
352
353impl FilterPredicate {
354 pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef, ArrowError> {
356 filter_array(values, self)
357 }
358
359 pub fn filter_record_batch(
364 &self,
365 record_batch: &RecordBatch,
366 ) -> Result<RecordBatch, ArrowError> {
367 let filtered_arrays = record_batch
368 .columns()
369 .iter()
370 .map(|a| filter_array(a, self))
371 .collect::<Result<Vec<_>, _>>()?;
372
373 unsafe {
376 Ok(RecordBatch::new_unchecked(
377 record_batch.schema(),
378 filtered_arrays,
379 self.count,
380 ))
381 }
382 }
383
384 pub fn count(&self) -> usize {
386 self.count
387 }
388}
389
390fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> {
391 if predicate.filter.len() > values.len() {
392 return Err(ArrowError::InvalidArgumentError(format!(
393 "Filter predicate of length {} is larger than target array of length {}",
394 predicate.filter.len(),
395 values.len()
396 )));
397 }
398
399 match predicate.strategy {
400 IterationStrategy::None => Ok(new_empty_array(values.data_type())),
401 IterationStrategy::All => Ok(values.slice(0, predicate.count)),
402 _ => downcast_primitive_array! {
404 values => Ok(Arc::new(filter_primitive(values, predicate))),
405 DataType::Boolean => {
406 let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
407 Ok(Arc::new(filter_boolean(values, predicate)))
408 }
409 DataType::Utf8 => {
410 Ok(Arc::new(filter_bytes(values.as_string::<i32>(), predicate)))
411 }
412 DataType::LargeUtf8 => {
413 Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
414 }
415 DataType::Utf8View => {
416 Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
417 }
418 DataType::Binary => {
419 Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
420 }
421 DataType::LargeBinary => {
422 Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
423 }
424 DataType::BinaryView => {
425 Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
426 }
427 DataType::FixedSizeBinary(_) => {
428 Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate)))
429 }
430 DataType::ListView(_) => {
431 Ok(Arc::new(filter_list_view::<i32>(values.as_list_view(), predicate)))
432 }
433 DataType::LargeListView(_) => {
434 Ok(Arc::new(filter_list_view::<i64>(values.as_list_view(), predicate)))
435 }
436 DataType::RunEndEncoded(_, _) => {
437 downcast_run_array!{
438 values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
439 t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t)
440 }
441 }
442 DataType::Dictionary(_, _) => downcast_dictionary_array! {
443 values => Ok(Arc::new(filter_dict(values, predicate))),
444 t => unimplemented!("Filter not supported for dictionary type {:?}", t)
445 }
446 DataType::Struct(_) => {
447 Ok(Arc::new(filter_struct(values.as_struct(), predicate)?))
448 }
449 DataType::Union(_, UnionMode::Sparse) => {
450 Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?))
451 }
452 _ => {
453 let data = values.to_data();
454 let mut mutable = MutableArrayData::new(
456 vec![&data],
457 false,
458 predicate.count,
459 );
460
461 match &predicate.strategy {
462 IterationStrategy::Slices(slices) => {
463 slices
464 .iter()
465 .for_each(|(start, end)| mutable.extend(0, *start, *end));
466 }
467 _ => {
468 let iter = SlicesIterator::new(&predicate.filter);
469 iter.for_each(|(start, end)| mutable.extend(0, start, end));
470 }
471 }
472
473 let data = mutable.freeze();
474 Ok(make_array(data))
475 }
476 },
477 }
478}
479
480fn filter_run_end_array<R: RunEndIndexType>(
482 array: &RunArray<R>,
483 predicate: &FilterPredicate,
484) -> Result<RunArray<R>, ArrowError>
485where
486 R::Native: Into<i64> + From<bool>,
487 R::Native: AddAssign,
488{
489 let run_ends: &RunEndBuffer<R::Native> = array.run_ends();
490 let mut new_run_ends = vec![R::default_value(); run_ends.len()];
491
492 let mut start = 0u64;
493 let mut j = 0;
494 let mut count = R::default_value();
495 let filter_values = predicate.filter.values();
496 let run_ends = run_ends.inner();
497
498 let pred: BooleanArray = BooleanBuffer::collect_bool(run_ends.len(), |i| {
499 let mut keep = false;
500 let mut end = run_ends[i].into() as u64;
501 let difference = end.saturating_sub(filter_values.len() as u64);
502 end -= difference;
503
504 for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
506 count += R::Native::from(pred);
507 keep |= pred
508 }
509 new_run_ends[j] = count;
511 j += keep as usize;
512
513 start = end;
514 keep
515 })
516 .into();
517
518 new_run_ends.truncate(j);
519
520 let values = array.values();
521 let values = filter(&values, &pred)?;
522
523 let run_ends = PrimitiveArray::<R>::try_new(new_run_ends.into(), None)?;
524 RunArray::try_new(&run_ends, &values)
525}
526
527fn filter_null_mask(
534 nulls: Option<&NullBuffer>,
535 predicate: &FilterPredicate,
536) -> Option<(usize, Buffer)> {
537 let nulls = nulls?;
538 if nulls.null_count() == 0 {
539 return None;
540 }
541
542 let nulls = filter_bits(nulls.inner(), predicate);
543 let null_count = predicate.count - nulls.count_set_bits_offset(0, predicate.count);
546
547 if null_count == 0 {
548 return None;
549 }
550
551 Some((null_count, nulls))
552}
553
554fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer {
556 let src = buffer.values();
557 let offset = buffer.offset();
558 assert!(buffer.len() >= predicate.filter.len());
559
560 match &predicate.strategy {
561 IterationStrategy::IndexIterator => {
562 let bits =
563 IndexIterator::new(&predicate.filter, predicate.count).map(|src_idx| unsafe {
565 bit_util::get_bit_raw(buffer.values().as_ptr(), src_idx + offset)
566 });
567
568 unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
570 }
571 IterationStrategy::Indices(indices) => {
572 let bits = indices.iter().map(|src_idx| unsafe {
574 bit_util::get_bit_raw(buffer.values().as_ptr(), *src_idx + offset)
575 });
576 unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
578 }
579 IterationStrategy::SlicesIterator => {
580 let mut builder = BooleanBufferBuilder::new(predicate.count);
581 for (start, end) in SlicesIterator::new(&predicate.filter) {
582 builder.append_packed_range(start + offset..end + offset, src)
583 }
584 builder.into()
585 }
586 IterationStrategy::Slices(slices) => {
587 let mut builder = BooleanBufferBuilder::new(predicate.count);
588 for (start, end) in slices {
589 builder.append_packed_range(*start + offset..*end + offset, src)
590 }
591 builder.into()
592 }
593 IterationStrategy::All | IterationStrategy::None => unreachable!(),
594 }
595}
596
597fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
599 let values = filter_bits(array.values(), predicate);
600
601 let mut builder = ArrayDataBuilder::new(DataType::Boolean)
602 .len(predicate.count)
603 .add_buffer(values);
604
605 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
606 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
607 }
608
609 let data = unsafe { builder.build_unchecked() };
610 BooleanArray::from(data)
611}
612
613#[inline(never)]
614fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer {
615 assert!(values.len() >= predicate.filter.len());
616
617 match &predicate.strategy {
618 IterationStrategy::SlicesIterator => {
619 let mut buffer = Vec::with_capacity(predicate.count);
620 for (start, end) in SlicesIterator::new(&predicate.filter) {
621 buffer.extend_from_slice(unsafe { values.get_unchecked(start..end) });
623 }
624 buffer.into()
625 }
626 IterationStrategy::Slices(slices) => {
627 let mut buffer = Vec::with_capacity(predicate.count);
628 for (start, end) in slices {
629 buffer.extend_from_slice(unsafe { values.get_unchecked(*start..*end) });
631 }
632 buffer.into()
633 }
634 IterationStrategy::IndexIterator => {
635 let iter = IndexIterator::new(&predicate.filter, predicate.count)
637 .map(|x| unsafe { *values.get_unchecked(x) });
638
639 unsafe { MutableBuffer::from_trusted_len_iter(iter) }.into()
641 }
642 IterationStrategy::Indices(indices) => {
643 let iter = indices.iter().map(|x| unsafe { *values.get_unchecked(*x) });
645 iter.collect::<Vec<_>>().into()
646 }
647 IterationStrategy::All | IterationStrategy::None => unreachable!(),
648 }
649}
650
651fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
653where
654 T: ArrowPrimitiveType,
655{
656 let values = array.values();
657 let buffer = filter_native(values, predicate);
658 let mut builder = ArrayDataBuilder::new(array.data_type().clone())
659 .len(predicate.count)
660 .add_buffer(buffer);
661
662 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
663 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
664 }
665
666 let data = unsafe { builder.build_unchecked() };
667 PrimitiveArray::from(data)
668}
669
670struct FilterBytes<'a, OffsetSize> {
675 src_offsets: &'a [OffsetSize],
676 src_values: &'a [u8],
677 dst_offsets: Vec<OffsetSize>,
678 dst_values: Vec<u8>,
679 cur_offset: OffsetSize,
680}
681
682impl<'a, OffsetSize> FilterBytes<'a, OffsetSize>
683where
684 OffsetSize: OffsetSizeTrait,
685{
686 fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self
687 where
688 T: ByteArrayType<Offset = OffsetSize>,
689 {
690 let dst_values = Vec::new();
691 let mut dst_offsets: Vec<OffsetSize> = Vec::with_capacity(capacity + 1);
692 let cur_offset = OffsetSize::from_usize(0).unwrap();
693
694 dst_offsets.push(cur_offset);
695
696 Self {
697 src_offsets: array.value_offsets(),
698 src_values: array.value_data(),
699 dst_offsets,
700 dst_values,
701 cur_offset,
702 }
703 }
704
705 #[inline]
707 fn get_value_offset(&self, idx: usize) -> usize {
708 self.src_offsets[idx].as_usize()
709 }
710
711 #[inline]
713 fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
714 let start = self.get_value_offset(idx);
716 let end = self.get_value_offset(idx + 1);
717 let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
718 (start, end, len)
719 }
720
721 fn extend_offsets_idx(&mut self, iter: impl Iterator<Item = usize>) {
722 self.dst_offsets.extend(iter.map(|idx| {
723 let start = self.src_offsets[idx].as_usize();
724 let end = self.src_offsets[idx + 1].as_usize();
725 let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
726 self.cur_offset += len;
727
728 self.cur_offset
729 }));
730 }
731
732 fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
734 self.dst_values.reserve_exact(self.cur_offset.as_usize());
735
736 for idx in iter {
737 let start = self.src_offsets[idx].as_usize();
738 let end = self.src_offsets[idx + 1].as_usize();
739 self.dst_values
740 .extend_from_slice(&self.src_values[start..end]);
741 }
742 }
743
744 fn extend_offsets_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>, count: usize) {
745 self.dst_offsets.reserve_exact(count);
746 for (start, end) in iter {
747 for idx in start..end {
749 let (_, _, len) = self.get_value_range(idx);
750 self.cur_offset += len;
751 self.dst_offsets.push(self.cur_offset);
752 }
753 }
754 }
755
756 fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
758 self.dst_values.reserve_exact(self.cur_offset.as_usize());
759
760 for (start, end) in iter {
761 let value_start = self.get_value_offset(start);
762 let value_end = self.get_value_offset(end);
763 self.dst_values
764 .extend_from_slice(&self.src_values[value_start..value_end]);
765 }
766 }
767}
768
769fn filter_bytes<T>(array: &GenericByteArray<T>, predicate: &FilterPredicate) -> GenericByteArray<T>
774where
775 T: ByteArrayType,
776{
777 let mut filter = FilterBytes::new(predicate.count, array);
778
779 match &predicate.strategy {
780 IterationStrategy::SlicesIterator => {
781 filter.extend_offsets_slices(SlicesIterator::new(&predicate.filter), predicate.count);
782 filter.extend_slices(SlicesIterator::new(&predicate.filter))
783 }
784 IterationStrategy::Slices(slices) => {
785 filter.extend_offsets_slices(slices.iter().cloned(), predicate.count);
786 filter.extend_slices(slices.iter().cloned())
787 }
788 IterationStrategy::IndexIterator => {
789 filter.extend_offsets_idx(IndexIterator::new(&predicate.filter, predicate.count));
790 filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count))
791 }
792 IterationStrategy::Indices(indices) => {
793 filter.extend_offsets_idx(indices.iter().cloned());
794 filter.extend_idx(indices.iter().cloned())
795 }
796 IterationStrategy::All | IterationStrategy::None => unreachable!(),
797 }
798
799 let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
800 .len(predicate.count)
801 .add_buffer(filter.dst_offsets.into())
802 .add_buffer(filter.dst_values.into());
803
804 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
805 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
806 }
807
808 let data = unsafe { builder.build_unchecked() };
809 GenericByteArray::from(data)
810}
811
812fn filter_byte_view<T: ByteViewType>(
814 array: &GenericByteViewArray<T>,
815 predicate: &FilterPredicate,
816) -> GenericByteViewArray<T> {
817 let new_view_buffer = filter_native(array.views(), predicate);
818
819 let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
820 .len(predicate.count)
821 .add_buffer(new_view_buffer)
822 .add_buffers(array.data_buffers().to_vec());
823
824 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
825 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
826 }
827
828 GenericByteViewArray::from(unsafe { builder.build_unchecked() })
829}
830
831fn filter_fixed_size_binary(
832 array: &FixedSizeBinaryArray,
833 predicate: &FilterPredicate,
834) -> FixedSizeBinaryArray {
835 let values: &[u8] = array.values();
836 let value_length = array.value_length() as usize;
837 let calculate_offset_from_index = |index: usize| index * value_length;
838 let buffer = match &predicate.strategy {
839 IterationStrategy::SlicesIterator => {
840 let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
841 for (start, end) in SlicesIterator::new(&predicate.filter) {
842 buffer.extend_from_slice(
843 &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)],
844 );
845 }
846 buffer
847 }
848 IterationStrategy::Slices(slices) => {
849 let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
850 for (start, end) in slices {
851 buffer.extend_from_slice(
852 &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)],
853 );
854 }
855 buffer
856 }
857 IterationStrategy::IndexIterator => {
858 let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| {
859 &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)]
860 });
861
862 let mut buffer = MutableBuffer::new(predicate.count * value_length);
863 iter.for_each(|item| buffer.extend_from_slice(item));
864 buffer
865 }
866 IterationStrategy::Indices(indices) => {
867 let iter = indices.iter().map(|x| {
868 &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)]
869 });
870
871 let mut buffer = MutableBuffer::new(predicate.count * value_length);
872 iter.for_each(|item| buffer.extend_from_slice(item));
873 buffer
874 }
875 IterationStrategy::All | IterationStrategy::None => unreachable!(),
876 };
877 let mut builder = ArrayDataBuilder::new(array.data_type().clone())
878 .len(predicate.count)
879 .add_buffer(buffer.into());
880
881 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
882 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
883 }
884
885 let data = unsafe { builder.build_unchecked() };
886 FixedSizeBinaryArray::from(data)
887}
888
889fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
891where
892 T: ArrowDictionaryKeyType,
893 T::Native: num_traits::Num,
894{
895 let builder = filter_primitive::<T>(array.keys(), predicate)
896 .into_data()
897 .into_builder()
898 .data_type(array.data_type().clone())
899 .child_data(vec![array.values().to_data()]);
900
901 DictionaryArray::from(unsafe { builder.build_unchecked() })
904}
905
906fn filter_struct(
908 array: &StructArray,
909 predicate: &FilterPredicate,
910) -> Result<StructArray, ArrowError> {
911 let columns = array
912 .columns()
913 .iter()
914 .map(|column| filter_array(column, predicate))
915 .collect::<Result<_, _>>()?;
916
917 let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
918 let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
919
920 Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
921 } else {
922 None
923 };
924
925 Ok(unsafe {
926 StructArray::new_unchecked_with_length(
927 array.fields().clone(),
928 columns,
929 nulls,
930 predicate.count(),
931 )
932 })
933}
934
935fn filter_sparse_union(
937 array: &UnionArray,
938 predicate: &FilterPredicate,
939) -> Result<UnionArray, ArrowError> {
940 let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
941 unreachable!()
942 };
943
944 let type_ids = filter_primitive(
945 &Int8Array::try_new(array.type_ids().clone(), None)?,
946 predicate,
947 );
948
949 let children = fields
950 .iter()
951 .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
952 .collect::<Result<_, _>>()?;
953
954 Ok(unsafe {
955 UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
956 })
957}
958
959fn filter_list_view<OffsetType: OffsetSizeTrait>(
961 array: &GenericListViewArray<OffsetType>,
962 predicate: &FilterPredicate,
963) -> GenericListViewArray<OffsetType> {
964 let filtered_offsets = filter_native::<OffsetType>(array.offsets(), predicate);
965 let filtered_sizes = filter_native::<OffsetType>(array.sizes(), predicate);
966
967 let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
969 let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
970
971 Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
972 } else {
973 None
974 };
975
976 let list_data = ArrayDataBuilder::new(array.data_type().clone())
977 .nulls(nulls)
978 .buffers(vec![filtered_offsets, filtered_sizes])
979 .child_data(vec![array.values().to_data()])
980 .len(predicate.count);
981
982 let list_data = unsafe { list_data.build_unchecked() };
983
984 GenericListViewArray::from(list_data)
985}
986
987#[cfg(test)]
988mod tests {
989 use super::*;
990 use arrow_array::builder::*;
991 use arrow_array::cast::as_run_array;
992 use arrow_array::types::*;
993 use arrow_data::ArrayData;
994 use rand::distr::uniform::{UniformSampler, UniformUsize};
995 use rand::distr::{Alphanumeric, StandardUniform};
996 use rand::prelude::*;
997 use rand::rng;
998
999 macro_rules! def_temporal_test {
1000 ($test:ident, $array_type: ident, $data: expr) => {
1001 #[test]
1002 fn $test() {
1003 let a = $data;
1004 let b = BooleanArray::from(vec![true, false, true, false]);
1005 let c = filter(&a, &b).unwrap();
1006 let d = c.as_ref().as_any().downcast_ref::<$array_type>().unwrap();
1007 assert_eq!(2, d.len());
1008 assert_eq!(1, d.value(0));
1009 assert_eq!(3, d.value(1));
1010 }
1011 };
1012 }
1013
1014 def_temporal_test!(
1015 test_filter_date32,
1016 Date32Array,
1017 Date32Array::from(vec![1, 2, 3, 4])
1018 );
1019 def_temporal_test!(
1020 test_filter_date64,
1021 Date64Array,
1022 Date64Array::from(vec![1, 2, 3, 4])
1023 );
1024 def_temporal_test!(
1025 test_filter_time32_second,
1026 Time32SecondArray,
1027 Time32SecondArray::from(vec![1, 2, 3, 4])
1028 );
1029 def_temporal_test!(
1030 test_filter_time32_millisecond,
1031 Time32MillisecondArray,
1032 Time32MillisecondArray::from(vec![1, 2, 3, 4])
1033 );
1034 def_temporal_test!(
1035 test_filter_time64_microsecond,
1036 Time64MicrosecondArray,
1037 Time64MicrosecondArray::from(vec![1, 2, 3, 4])
1038 );
1039 def_temporal_test!(
1040 test_filter_time64_nanosecond,
1041 Time64NanosecondArray,
1042 Time64NanosecondArray::from(vec![1, 2, 3, 4])
1043 );
1044 def_temporal_test!(
1045 test_filter_duration_second,
1046 DurationSecondArray,
1047 DurationSecondArray::from(vec![1, 2, 3, 4])
1048 );
1049 def_temporal_test!(
1050 test_filter_duration_millisecond,
1051 DurationMillisecondArray,
1052 DurationMillisecondArray::from(vec![1, 2, 3, 4])
1053 );
1054 def_temporal_test!(
1055 test_filter_duration_microsecond,
1056 DurationMicrosecondArray,
1057 DurationMicrosecondArray::from(vec![1, 2, 3, 4])
1058 );
1059 def_temporal_test!(
1060 test_filter_duration_nanosecond,
1061 DurationNanosecondArray,
1062 DurationNanosecondArray::from(vec![1, 2, 3, 4])
1063 );
1064 def_temporal_test!(
1065 test_filter_timestamp_second,
1066 TimestampSecondArray,
1067 TimestampSecondArray::from(vec![1, 2, 3, 4])
1068 );
1069 def_temporal_test!(
1070 test_filter_timestamp_millisecond,
1071 TimestampMillisecondArray,
1072 TimestampMillisecondArray::from(vec![1, 2, 3, 4])
1073 );
1074 def_temporal_test!(
1075 test_filter_timestamp_microsecond,
1076 TimestampMicrosecondArray,
1077 TimestampMicrosecondArray::from(vec![1, 2, 3, 4])
1078 );
1079 def_temporal_test!(
1080 test_filter_timestamp_nanosecond,
1081 TimestampNanosecondArray,
1082 TimestampNanosecondArray::from(vec![1, 2, 3, 4])
1083 );
1084
1085 #[test]
1086 fn test_filter_array_slice() {
1087 let a = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
1088 let b = BooleanArray::from(vec![true, false, false, true]);
1089 let c = filter(&a, &b).unwrap();
1093 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1094 assert_eq!(2, d.len());
1095 assert_eq!(6, d.value(0));
1096 assert_eq!(9, d.value(1));
1097 }
1098
1099 #[test]
1100 fn test_filter_array_low_density() {
1101 let mut data_values = (1..=65).collect::<Vec<i32>>();
1103 let mut filter_values = (1..=65).map(|i| matches!(i % 65, 0)).collect::<Vec<bool>>();
1104 data_values.extend_from_slice(&[66, 67]);
1106 filter_values.extend_from_slice(&[false, true]);
1107 let a = Int32Array::from(data_values);
1108 let b = BooleanArray::from(filter_values);
1109 let c = filter(&a, &b).unwrap();
1110 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1111 assert_eq!(2, d.len());
1112 assert_eq!(65, d.value(0));
1113 assert_eq!(67, d.value(1));
1114 }
1115
1116 #[test]
1117 fn test_filter_array_high_density() {
1118 let mut data_values = (1..=65).map(Some).collect::<Vec<_>>();
1120 let mut filter_values = (1..=65)
1121 .map(|i| !matches!(i % 65, 0))
1122 .collect::<Vec<bool>>();
1123 data_values[1] = None;
1125 data_values.extend_from_slice(&[Some(66), None, Some(67), None]);
1127 filter_values.extend_from_slice(&[false, true, true, true]);
1128 let a = Int32Array::from(data_values);
1129 let b = BooleanArray::from(filter_values);
1130 let c = filter(&a, &b).unwrap();
1131 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1132 assert_eq!(67, d.len());
1133 assert_eq!(3, d.null_count());
1134 assert_eq!(1, d.value(0));
1135 assert!(d.is_null(1));
1136 assert_eq!(64, d.value(63));
1137 assert!(d.is_null(64));
1138 assert_eq!(67, d.value(65));
1139 }
1140
1141 #[test]
1142 fn test_filter_string_array_simple() {
1143 let a = StringArray::from(vec!["hello", " ", "world", "!"]);
1144 let b = BooleanArray::from(vec![true, false, true, false]);
1145 let c = filter(&a, &b).unwrap();
1146 let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1147 assert_eq!(2, d.len());
1148 assert_eq!("hello", d.value(0));
1149 assert_eq!("world", d.value(1));
1150 }
1151
1152 #[test]
1153 fn test_filter_primitive_array_with_null() {
1154 let a = Int32Array::from(vec![Some(5), None]);
1155 let b = BooleanArray::from(vec![false, true]);
1156 let c = filter(&a, &b).unwrap();
1157 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1158 assert_eq!(1, d.len());
1159 assert!(d.is_null(0));
1160 }
1161
1162 #[test]
1163 fn test_filter_string_array_with_null() {
1164 let a = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
1165 let b = BooleanArray::from(vec![true, false, false, true]);
1166 let c = filter(&a, &b).unwrap();
1167 let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1168 assert_eq!(2, d.len());
1169 assert_eq!("hello", d.value(0));
1170 assert!(!d.is_null(0));
1171 assert!(d.is_null(1));
1172 }
1173
1174 #[test]
1175 fn test_filter_binary_array_with_null() {
1176 let data: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
1177 let a = BinaryArray::from(data);
1178 let b = BooleanArray::from(vec![true, false, false, true]);
1179 let c = filter(&a, &b).unwrap();
1180 let d = c.as_ref().as_any().downcast_ref::<BinaryArray>().unwrap();
1181 assert_eq!(2, d.len());
1182 assert_eq!(b"hello", d.value(0));
1183 assert!(!d.is_null(0));
1184 assert!(d.is_null(1));
1185 }
1186
1187 fn _test_filter_byte_view<T>()
1188 where
1189 T: ByteViewType,
1190 str: AsRef<T::Native>,
1191 T::Native: PartialEq,
1192 {
1193 let array = {
1194 let mut builder = GenericByteViewBuilder::<T>::new();
1196 builder.append_value("hello");
1197 builder.append_value("world");
1198 builder.append_null();
1199 builder.append_value("large payload over 12 bytes");
1200 builder.append_value("lulu");
1201 builder.finish()
1202 };
1203
1204 {
1205 let predicate = BooleanArray::from(vec![true, false, true, true, false]);
1206 let actual = filter(&array, &predicate).unwrap();
1207
1208 assert_eq!(actual.len(), 3);
1209
1210 let expected = {
1211 let mut builder = GenericByteViewBuilder::<T>::new();
1213 builder.append_value("hello");
1214 builder.append_null();
1215 builder.append_value("large payload over 12 bytes");
1216 builder.finish()
1217 };
1218
1219 assert_eq!(actual.as_ref(), &expected);
1220 }
1221
1222 {
1223 let predicate = BooleanArray::from(vec![true, false, false, false, true]);
1224 let actual = filter(&array, &predicate).unwrap();
1225
1226 assert_eq!(actual.len(), 2);
1227
1228 let expected = {
1229 let mut builder = GenericByteViewBuilder::<T>::new();
1231 builder.append_value("hello");
1232 builder.append_value("lulu");
1233 builder.finish()
1234 };
1235
1236 assert_eq!(actual.as_ref(), &expected);
1237 }
1238 }
1239
1240 #[test]
1241 fn test_filter_string_view() {
1242 _test_filter_byte_view::<StringViewType>()
1243 }
1244
1245 #[test]
1246 fn test_filter_binary_view() {
1247 _test_filter_byte_view::<BinaryViewType>()
1248 }
1249
1250 #[test]
1251 fn test_filter_fixed_binary() {
1252 let v1 = [1_u8, 2];
1253 let v2 = [3_u8, 4];
1254 let v3 = [5_u8, 6];
1255 let v = vec![&v1, &v2, &v3];
1256 let a = FixedSizeBinaryArray::from(v);
1257 let b = BooleanArray::from(vec![true, false, true]);
1258 let c = filter(&a, &b).unwrap();
1259 let d = c
1260 .as_ref()
1261 .as_any()
1262 .downcast_ref::<FixedSizeBinaryArray>()
1263 .unwrap();
1264 assert_eq!(d.len(), 2);
1265 assert_eq!(d.value(0), &v1);
1266 assert_eq!(d.value(1), &v3);
1267 let c2 = FilterBuilder::new(&b)
1268 .optimize()
1269 .build()
1270 .filter(&a)
1271 .unwrap();
1272 let d2 = c2
1273 .as_ref()
1274 .as_any()
1275 .downcast_ref::<FixedSizeBinaryArray>()
1276 .unwrap();
1277 assert_eq!(d, d2);
1278
1279 let b = BooleanArray::from(vec![false, false, false]);
1280 let c = filter(&a, &b).unwrap();
1281 let d = c
1282 .as_ref()
1283 .as_any()
1284 .downcast_ref::<FixedSizeBinaryArray>()
1285 .unwrap();
1286 assert_eq!(d.len(), 0);
1287
1288 let b = BooleanArray::from(vec![true, true, true]);
1289 let c = filter(&a, &b).unwrap();
1290 let d = c
1291 .as_ref()
1292 .as_any()
1293 .downcast_ref::<FixedSizeBinaryArray>()
1294 .unwrap();
1295 assert_eq!(d.len(), 3);
1296 assert_eq!(d.value(0), &v1);
1297 assert_eq!(d.value(1), &v2);
1298 assert_eq!(d.value(2), &v3);
1299
1300 let b = BooleanArray::from(vec![false, false, true]);
1301 let c = filter(&a, &b).unwrap();
1302 let d = c
1303 .as_ref()
1304 .as_any()
1305 .downcast_ref::<FixedSizeBinaryArray>()
1306 .unwrap();
1307 assert_eq!(d.len(), 1);
1308 assert_eq!(d.value(0), &v3);
1309 let c2 = FilterBuilder::new(&b)
1310 .optimize()
1311 .build()
1312 .filter(&a)
1313 .unwrap();
1314 let d2 = c2
1315 .as_ref()
1316 .as_any()
1317 .downcast_ref::<FixedSizeBinaryArray>()
1318 .unwrap();
1319 assert_eq!(d, d2);
1320 }
1321
1322 #[test]
1323 fn test_filter_array_slice_with_null() {
1324 let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
1325 let b = BooleanArray::from(vec![true, false, false, true]);
1326 let c = filter(&a, &b).unwrap();
1330 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1331 assert_eq!(2, d.len());
1332 assert!(d.is_null(0));
1333 assert!(!d.is_null(1));
1334 assert_eq!(9, d.value(1));
1335 }
1336
1337 #[test]
1338 fn test_filter_run_end_encoding_array() {
1339 let run_ends = Int64Array::from(vec![2, 3, 8]);
1340 let values = Int64Array::from(vec![7, -2, 9]);
1341 let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1342 let b = BooleanArray::from(vec![true, false, true, false, true, false, true, false]);
1343 let c = filter(&a, &b).unwrap();
1344 let actual: &RunArray<Int64Type> = as_run_array(&c);
1345 assert_eq!(4, actual.len());
1346
1347 let expected = RunArray::try_new(
1348 &Int64Array::from(vec![1, 2, 4]),
1349 &Int64Array::from(vec![7, -2, 9]),
1350 )
1351 .expect("Failed to make expected RunArray test is broken");
1352
1353 assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1354 assert_eq!(actual.values(), expected.values())
1355 }
1356
1357 #[test]
1358 #[should_panic = "assertion `left == right` failed\n left: [-2, 9]\n right: [7, -2]"]
1359 fn test_filter_run_end_encoding_array_sliced() {
1362 let run_ends = Int64Array::from(vec![2, 3, 8]);
1363 let values = Int64Array::from(vec![7, -2, 9]);
1364 let a = RunArray::try_new(&run_ends, &values).unwrap(); let a = a.slice(2, 3); let b = BooleanArray::from(vec![true, false, true]);
1367 let result = filter(&a, &b).unwrap();
1368
1369 let result = result.as_run::<Int64Type>();
1370 let result = result.downcast::<Int64Array>().unwrap();
1371
1372 let expected = vec![-2, 9];
1373 let actual = result.into_iter().flatten().collect::<Vec<_>>();
1374 assert_eq!(expected, actual);
1375 }
1376
1377 #[test]
1378 fn test_filter_run_end_encoding_array_remove_value() {
1379 let run_ends = Int32Array::from(vec![2, 3, 8, 10]);
1380 let values = Int32Array::from(vec![7, -2, 9, -8]);
1381 let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1382 let b = BooleanArray::from(vec![
1383 false, true, false, false, true, false, true, false, false, false,
1384 ]);
1385 let c = filter(&a, &b).unwrap();
1386 let actual: &RunArray<Int32Type> = as_run_array(&c);
1387 assert_eq!(3, actual.len());
1388
1389 let expected =
1390 RunArray::try_new(&Int32Array::from(vec![1, 3]), &Int32Array::from(vec![7, 9]))
1391 .expect("Failed to make expected RunArray test is broken");
1392
1393 assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1394 assert_eq!(actual.values(), expected.values())
1395 }
1396
1397 #[test]
1398 fn test_filter_run_end_encoding_array_remove_all_but_one() {
1399 let run_ends = Int16Array::from(vec![2, 3, 8, 10]);
1400 let values = Int16Array::from(vec![7, -2, 9, -8]);
1401 let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1402 let b = BooleanArray::from(vec![
1403 false, false, false, false, false, false, true, false, false, false,
1404 ]);
1405 let c = filter(&a, &b).unwrap();
1406 let actual: &RunArray<Int16Type> = as_run_array(&c);
1407 assert_eq!(1, actual.len());
1408
1409 let expected = RunArray::try_new(&Int16Array::from(vec![1]), &Int16Array::from(vec![9]))
1410 .expect("Failed to make expected RunArray test is broken");
1411
1412 assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1413 assert_eq!(actual.values(), expected.values())
1414 }
1415
1416 #[test]
1417 fn test_filter_run_end_encoding_array_empty() {
1418 let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1419 let values = Int64Array::from(vec![7, -2, 9, -8]);
1420 let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1421 let b = BooleanArray::from(vec![
1422 false, false, false, false, false, false, false, false, false, false,
1423 ]);
1424 let c = filter(&a, &b).unwrap();
1425 let actual: &RunArray<Int64Type> = as_run_array(&c);
1426 assert_eq!(0, actual.len());
1427 }
1428
1429 #[test]
1430 fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() {
1431 let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1432 let values = Int64Array::from(vec![7, -2, 9, -8]);
1433 let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1434 let b = BooleanArray::from(vec![false, true, true]);
1435 let c = filter(&a, &b).unwrap();
1436 let actual: &RunArray<Int64Type> = as_run_array(&c);
1437 assert_eq!(2, actual.len());
1438
1439 let expected = RunArray::try_new(
1440 &Int64Array::from(vec![1, 2]),
1441 &Int64Array::from(vec![7, -2]),
1442 )
1443 .expect("Failed to make expected RunArray test is broken");
1444
1445 assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1446 assert_eq!(actual.values(), expected.values())
1447 }
1448
1449 #[test]
1450 fn test_filter_dictionary_array() {
1451 let values = [Some("hello"), None, Some("world"), Some("!")];
1452 let a: Int8DictionaryArray = values.iter().copied().collect();
1453 let b = BooleanArray::from(vec![false, true, true, false]);
1454 let c = filter(&a, &b).unwrap();
1455 let d = c
1456 .as_ref()
1457 .as_any()
1458 .downcast_ref::<Int8DictionaryArray>()
1459 .unwrap();
1460 let value_array = d.values();
1461 let values = value_array.as_any().downcast_ref::<StringArray>().unwrap();
1462 assert_eq!(3, values.len());
1464 assert_eq!(2, d.len());
1466 assert!(d.is_null(0));
1467 assert_eq!("world", values.value(d.keys().value(1) as usize));
1468 }
1469
1470 #[test]
1471 fn test_filter_list_array() {
1472 let value_data = ArrayData::builder(DataType::Int32)
1473 .len(8)
1474 .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
1475 .build()
1476 .unwrap();
1477
1478 let value_offsets = Buffer::from_slice_ref([0i64, 3, 6, 8, 8]);
1479
1480 let list_data_type =
1481 DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
1482 let list_data = ArrayData::builder(list_data_type)
1483 .len(4)
1484 .add_buffer(value_offsets)
1485 .add_child_data(value_data)
1486 .null_bit_buffer(Some(Buffer::from([0b00000111])))
1487 .build()
1488 .unwrap();
1489
1490 let a = LargeListArray::from(list_data);
1492 let b = BooleanArray::from(vec![false, true, false, true]);
1493 let result = filter(&a, &b).unwrap();
1494
1495 let value_data = ArrayData::builder(DataType::Int32)
1497 .len(3)
1498 .add_buffer(Buffer::from_slice_ref([3, 4, 5]))
1499 .build()
1500 .unwrap();
1501
1502 let value_offsets = Buffer::from_slice_ref([0i64, 3, 3]);
1503
1504 let list_data_type =
1505 DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
1506 let expected = ArrayData::builder(list_data_type)
1507 .len(2)
1508 .add_buffer(value_offsets)
1509 .add_child_data(value_data)
1510 .null_bit_buffer(Some(Buffer::from([0b00000001])))
1511 .build()
1512 .unwrap();
1513
1514 assert_eq!(&make_array(expected), &result);
1515 }
1516
1517 fn test_case_filter_list_view<T: OffsetSizeTrait>() {
1518 let mut list_array = GenericListViewBuilder::<T, _>::new(Int32Builder::new());
1520 list_array.append_value([Some(1), Some(2)]);
1521 list_array.append_null();
1522 list_array.append_value([]);
1523 list_array.append_value([Some(3), Some(4)]);
1524
1525 let list_array = list_array.finish();
1526 let predicate = BooleanArray::from_iter([true, false, true, false]);
1527
1528 let filtered = filter(&list_array, &predicate)
1530 .unwrap()
1531 .as_list_view::<T>()
1532 .clone();
1533
1534 let mut expected =
1535 GenericListViewBuilder::<T, _>::with_capacity(Int32Builder::with_capacity(5), 3);
1536 expected.append_value([Some(1), Some(2)]);
1537 expected.append_value([]);
1538 let expected = expected.finish();
1539
1540 assert_eq!(&filtered, &expected);
1541 }
1542
1543 fn test_case_filter_sliced_list_view<T: OffsetSizeTrait>() {
1544 let mut list_array =
1546 GenericListViewBuilder::<T, _>::with_capacity(Int32Builder::with_capacity(6), 4);
1547 list_array.append_value([Some(1), Some(2)]);
1548 list_array.append_null();
1549 list_array.append_value([]);
1550 list_array.append_value([Some(3), Some(4)]);
1551
1552 let list_array = list_array.finish();
1553
1554 let sliced = list_array.slice(1, 3);
1556 let predicate = BooleanArray::from_iter([false, false, true]);
1557
1558 let filtered = filter(&sliced, &predicate)
1560 .unwrap()
1561 .as_list_view::<T>()
1562 .clone();
1563
1564 let mut expected = GenericListViewBuilder::<T, _>::new(Int32Builder::new());
1565 expected.append_value([Some(3), Some(4)]);
1566 let expected = expected.finish();
1567
1568 assert_eq!(&filtered, &expected);
1569 }
1570
1571 #[test]
1572 fn test_filter_list_view_array() {
1573 test_case_filter_list_view::<i32>();
1574 test_case_filter_list_view::<i64>();
1575
1576 test_case_filter_sliced_list_view::<i32>();
1577 test_case_filter_sliced_list_view::<i64>();
1578 }
1579
1580 #[test]
1581 fn test_slice_iterator_bits() {
1582 let filter_values = (0..64).map(|i| i == 1).collect::<Vec<bool>>();
1583 let filter = BooleanArray::from(filter_values);
1584 let filter_count = filter_count(&filter);
1585
1586 let iter = SlicesIterator::new(&filter);
1587 let chunks = iter.collect::<Vec<_>>();
1588
1589 assert_eq!(chunks, vec![(1, 2)]);
1590 assert_eq!(filter_count, 1);
1591 }
1592
1593 #[test]
1594 fn test_slice_iterator_bits1() {
1595 let filter_values = (0..64).map(|i| i != 1).collect::<Vec<bool>>();
1596 let filter = BooleanArray::from(filter_values);
1597 let filter_count = filter_count(&filter);
1598
1599 let iter = SlicesIterator::new(&filter);
1600 let chunks = iter.collect::<Vec<_>>();
1601
1602 assert_eq!(chunks, vec![(0, 1), (2, 64)]);
1603 assert_eq!(filter_count, 64 - 1);
1604 }
1605
1606 #[test]
1607 fn test_slice_iterator_chunk_and_bits() {
1608 let filter_values = (0..130).map(|i| i % 62 != 0).collect::<Vec<bool>>();
1609 let filter = BooleanArray::from(filter_values);
1610 let filter_count = filter_count(&filter);
1611
1612 let iter = SlicesIterator::new(&filter);
1613 let chunks = iter.collect::<Vec<_>>();
1614
1615 assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]);
1616 assert_eq!(filter_count, 61 + 61 + 5);
1617 }
1618
1619 #[test]
1620 fn test_null_mask() {
1621 let a = Int64Array::from(vec![Some(1), Some(2), None]);
1622
1623 let mask1 = BooleanArray::from(vec![Some(true), Some(true), None]);
1624 let out = filter(&a, &mask1).unwrap();
1625 assert_eq!(out.as_ref(), &a.slice(0, 2));
1626 }
1627
1628 #[test]
1629 fn test_filter_record_batch_no_columns() {
1630 let pred = BooleanArray::from(vec![Some(true), Some(true), None]);
1631 let options = RecordBatchOptions::default().with_row_count(Some(100));
1632 let record_batch =
1633 RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &options).unwrap();
1634 let out = filter_record_batch(&record_batch, &pred).unwrap();
1635
1636 assert_eq!(out.num_rows(), 2);
1637 }
1638
1639 #[test]
1640 fn test_fast_path() {
1641 let a: PrimitiveArray<Int64Type> = PrimitiveArray::from(vec![Some(1), Some(2), None]);
1642
1643 let mask = BooleanArray::from(vec![true, true, true]);
1645 let out = filter(&a, &mask).unwrap();
1646 let b = out
1647 .as_any()
1648 .downcast_ref::<PrimitiveArray<Int64Type>>()
1649 .unwrap();
1650 assert_eq!(&a, b);
1651
1652 let mask = BooleanArray::from(vec![false, false, false]);
1654 let out = filter(&a, &mask).unwrap();
1655 assert_eq!(out.len(), 0);
1656 assert_eq!(out.data_type(), &DataType::Int64);
1657 }
1658
1659 #[test]
1660 fn test_slices() {
1661 let bools = std::iter::repeat_n(true, 10)
1663 .chain(std::iter::repeat_n(false, 30))
1664 .chain(std::iter::repeat_n(true, 20))
1665 .chain(std::iter::repeat_n(false, 17))
1666 .chain(std::iter::repeat_n(true, 4));
1667
1668 let bool_array: BooleanArray = bools.map(Some).collect();
1669
1670 let slices: Vec<_> = SlicesIterator::new(&bool_array).collect();
1671 let expected = vec![(0, 10), (40, 60), (77, 81)];
1672 assert_eq!(slices, expected);
1673
1674 let len = bool_array.len();
1676 let sliced_array = bool_array.slice(7, len - 10);
1677 let sliced_array = sliced_array
1678 .as_any()
1679 .downcast_ref::<BooleanArray>()
1680 .unwrap();
1681 let slices: Vec<_> = SlicesIterator::new(sliced_array).collect();
1682 let expected = vec![(0, 3), (33, 53), (70, 71)];
1683 assert_eq!(slices, expected);
1684 }
1685
1686 fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) {
1687 let mut rng = rng();
1688
1689 let bools: Vec<bool> = std::iter::from_fn(|| Some(rng.random()))
1690 .take(mask_len)
1691 .collect();
1692
1693 let buffer = Buffer::from_iter(bools.iter().cloned());
1694
1695 let truncated_length = mask_len - offset - truncate;
1696
1697 let data = ArrayDataBuilder::new(DataType::Boolean)
1698 .len(truncated_length)
1699 .offset(offset)
1700 .add_buffer(buffer)
1701 .build()
1702 .unwrap();
1703
1704 let filter = BooleanArray::from(data);
1705
1706 let slice_bits: Vec<_> = SlicesIterator::new(&filter)
1707 .flat_map(|(start, end)| start..end)
1708 .collect();
1709
1710 let count = filter_count(&filter);
1711 let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect();
1712
1713 let expected_bits: Vec<_> = bools
1714 .iter()
1715 .skip(offset)
1716 .take(truncated_length)
1717 .enumerate()
1718 .flat_map(|(idx, v)| v.then(|| idx))
1719 .collect();
1720
1721 assert_eq!(slice_bits, expected_bits);
1722 assert_eq!(index_bits, expected_bits);
1723 }
1724
1725 #[test]
1726 #[cfg_attr(miri, ignore)]
1727 fn fuzz_test_slices_iterator() {
1728 let mut rng = rng();
1729
1730 let uusize = UniformUsize::new(usize::MIN, usize::MAX).unwrap();
1731 for _ in 0..100 {
1732 let mask_len = rng.random_range(0..1024);
1733 let max_offset = 64.min(mask_len);
1734 let offset = uusize.sample(&mut rng).checked_rem(max_offset).unwrap_or(0);
1735
1736 let max_truncate = 128.min(mask_len - offset);
1737 let truncate = uusize
1738 .sample(&mut rng)
1739 .checked_rem(max_truncate)
1740 .unwrap_or(0);
1741
1742 test_slices_fuzz(mask_len, offset, truncate);
1743 }
1744
1745 test_slices_fuzz(64, 0, 0);
1746 test_slices_fuzz(64, 8, 0);
1747 test_slices_fuzz(64, 8, 8);
1748 test_slices_fuzz(32, 8, 8);
1749 test_slices_fuzz(32, 5, 9);
1750 }
1751
1752 fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> {
1754 values
1755 .into_iter()
1756 .zip(predicate)
1757 .filter(|(_, x)| **x)
1758 .map(|(a, _)| a)
1759 .collect()
1760 }
1761
1762 fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>>
1764 where
1765 StandardUniform: Distribution<T>,
1766 {
1767 let mut rng = rng();
1768 (0..len)
1769 .map(|_| rng.random_bool(valid_percent).then(|| rng.random()))
1770 .collect()
1771 }
1772
1773 fn gen_strings(
1775 len: usize,
1776 valid_percent: f64,
1777 str_len_range: std::ops::Range<usize>,
1778 ) -> Vec<Option<String>> {
1779 let mut rng = rng();
1780 (0..len)
1781 .map(|_| {
1782 rng.random_bool(valid_percent).then(|| {
1783 let len = rng.random_range(str_len_range.clone());
1784 (0..len)
1785 .map(|_| char::from(rng.sample(Alphanumeric)))
1786 .collect()
1787 })
1788 })
1789 .collect()
1790 }
1791
1792 fn as_deref<T: std::ops::Deref>(src: &[Option<T>]) -> impl Iterator<Item = Option<&T::Target>> {
1794 src.iter().map(|x| x.as_deref())
1795 }
1796
1797 #[test]
1798 #[cfg_attr(miri, ignore)]
1799 fn fuzz_filter() {
1800 let mut rng = rng();
1801
1802 for i in 0..100 {
1803 let filter_percent = match i {
1804 0..=4 => 1.,
1805 5..=10 => 0.,
1806 _ => rng.random_range(0.0..1.0),
1807 };
1808
1809 let valid_percent = rng.random_range(0.0..1.0);
1810
1811 let array_len = rng.random_range(32..256);
1812 let array_offset = rng.random_range(0..10);
1813
1814 let filter_offset = rng.random_range(0..10);
1816 let filter_truncate = rng.random_range(0..10);
1817 let bools: Vec<_> = std::iter::from_fn(|| Some(rng.random_bool(filter_percent)))
1818 .take(array_len + filter_offset - filter_truncate)
1819 .collect();
1820
1821 let predicate = BooleanArray::from_iter(bools.iter().cloned().map(Some));
1822
1823 let predicate = predicate.slice(filter_offset, array_len - filter_truncate);
1825 let predicate = predicate.as_any().downcast_ref::<BooleanArray>().unwrap();
1826 let bools = &bools[filter_offset..];
1827
1828 let values = gen_primitive(array_len + array_offset, valid_percent);
1830 let src = Int32Array::from_iter(values.iter().cloned());
1831
1832 let src = src.slice(array_offset, array_len);
1833 let src = src.as_any().downcast_ref::<Int32Array>().unwrap();
1834 let values = &values[array_offset..];
1835
1836 let filtered = filter(src, predicate).unwrap();
1837 let array = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
1838 let actual: Vec<_> = array.iter().collect();
1839
1840 assert_eq!(actual, filter_rust(values.iter().cloned(), bools));
1841
1842 let strings = gen_strings(array_len + array_offset, valid_percent, 0..20);
1844 let src = StringArray::from_iter(as_deref(&strings));
1845
1846 let src = src.slice(array_offset, array_len);
1847 let src = src.as_any().downcast_ref::<StringArray>().unwrap();
1848
1849 let filtered = filter(src, predicate).unwrap();
1850 let array = filtered.as_any().downcast_ref::<StringArray>().unwrap();
1851 let actual: Vec<_> = array.iter().collect();
1852
1853 let expected_strings = filter_rust(as_deref(&strings[array_offset..]), bools);
1854 assert_eq!(actual, expected_strings);
1855
1856 let src = DictionaryArray::<Int32Type>::from_iter(as_deref(&strings));
1858
1859 let src = src.slice(array_offset, array_len);
1860 let src = src
1861 .as_any()
1862 .downcast_ref::<DictionaryArray<Int32Type>>()
1863 .unwrap();
1864
1865 let filtered = filter(src, predicate).unwrap();
1866
1867 let array = filtered
1868 .as_any()
1869 .downcast_ref::<DictionaryArray<Int32Type>>()
1870 .unwrap();
1871
1872 let values = array
1873 .values()
1874 .as_any()
1875 .downcast_ref::<StringArray>()
1876 .unwrap();
1877
1878 let actual: Vec<_> = array
1879 .keys()
1880 .iter()
1881 .map(|key| key.map(|key| values.value(key as usize)))
1882 .collect();
1883
1884 assert_eq!(actual, expected_strings);
1885 }
1886 }
1887
1888 #[test]
1889 fn test_filter_map() {
1890 let mut builder =
1891 MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(4));
1892 builder.keys().append_value("key1");
1894 builder.values().append_value(1);
1895 builder.append(true).unwrap();
1896 builder.keys().append_value("key2");
1897 builder.keys().append_value("key3");
1898 builder.values().append_value(2);
1899 builder.values().append_value(3);
1900 builder.append(true).unwrap();
1901 builder.append(false).unwrap();
1902 builder.keys().append_value("key1");
1903 builder.values().append_value(1);
1904 builder.append(true).unwrap();
1905 let maparray = Arc::new(builder.finish()) as ArrayRef;
1906
1907 let indices = vec![Some(true), Some(false), Some(false), Some(true)]
1908 .into_iter()
1909 .collect::<BooleanArray>();
1910 let got = filter(&maparray, &indices).unwrap();
1911
1912 let mut builder =
1913 MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(2));
1914 builder.keys().append_value("key1");
1915 builder.values().append_value(1);
1916 builder.append(true).unwrap();
1917 builder.keys().append_value("key1");
1918 builder.values().append_value(1);
1919 builder.append(true).unwrap();
1920 let expected = Arc::new(builder.finish()) as ArrayRef;
1921
1922 assert_eq!(&expected, &got);
1923 }
1924
1925 #[test]
1926 fn test_filter_fixed_size_list_arrays() {
1927 let value_data = ArrayData::builder(DataType::Int32)
1928 .len(9)
1929 .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8]))
1930 .build()
1931 .unwrap();
1932 let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 3, false);
1933 let list_data = ArrayData::builder(list_data_type)
1934 .len(3)
1935 .add_child_data(value_data)
1936 .build()
1937 .unwrap();
1938 let array = FixedSizeListArray::from(list_data);
1939
1940 let filter_array = BooleanArray::from(vec![true, false, false]);
1941
1942 let c = filter(&array, &filter_array).unwrap();
1943 let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1944
1945 assert_eq!(filtered.len(), 1);
1946
1947 let list = filtered.value(0);
1948 assert_eq!(
1949 &[0, 1, 2],
1950 list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1951 );
1952
1953 let filter_array = BooleanArray::from(vec![true, false, true]);
1954
1955 let c = filter(&array, &filter_array).unwrap();
1956 let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1957
1958 assert_eq!(filtered.len(), 2);
1959
1960 let list = filtered.value(0);
1961 assert_eq!(
1962 &[0, 1, 2],
1963 list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1964 );
1965 let list = filtered.value(1);
1966 assert_eq!(
1967 &[6, 7, 8],
1968 list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1969 );
1970 }
1971
1972 #[test]
1973 fn test_filter_fixed_size_list_arrays_with_null() {
1974 let value_data = ArrayData::builder(DataType::Int32)
1975 .len(10)
1976 .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
1977 .build()
1978 .unwrap();
1979
1980 let mut null_bits: [u8; 1] = [0; 1];
1984 bit_util::set_bit(&mut null_bits, 0);
1985 bit_util::set_bit(&mut null_bits, 3);
1986 bit_util::set_bit(&mut null_bits, 4);
1987
1988 let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 2, false);
1989 let list_data = ArrayData::builder(list_data_type)
1990 .len(5)
1991 .add_child_data(value_data)
1992 .null_bit_buffer(Some(Buffer::from(null_bits)))
1993 .build()
1994 .unwrap();
1995 let array = FixedSizeListArray::from(list_data);
1996
1997 let filter_array = BooleanArray::from(vec![true, true, false, true, false]);
1998
1999 let c = filter(&array, &filter_array).unwrap();
2000 let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
2001
2002 assert_eq!(filtered.len(), 3);
2003
2004 let list = filtered.value(0);
2005 assert_eq!(
2006 &[0, 1],
2007 list.as_any().downcast_ref::<Int32Array>().unwrap().values()
2008 );
2009 assert!(filtered.is_null(1));
2010 let list = filtered.value(2);
2011 assert_eq!(
2012 &[6, 7],
2013 list.as_any().downcast_ref::<Int32Array>().unwrap().values()
2014 );
2015 }
2016
2017 fn test_filter_union_array(array: UnionArray) {
2018 let filter_array = BooleanArray::from(vec![true, false, false]);
2019 let c = filter(&array, &filter_array).unwrap();
2020 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2021
2022 let mut builder = UnionBuilder::new_dense();
2023 builder.append::<Int32Type>("A", 1).unwrap();
2024 let expected_array = builder.build().unwrap();
2025
2026 compare_union_arrays(filtered, &expected_array);
2027
2028 let filter_array = BooleanArray::from(vec![true, false, true]);
2029 let c = filter(&array, &filter_array).unwrap();
2030 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2031
2032 let mut builder = UnionBuilder::new_dense();
2033 builder.append::<Int32Type>("A", 1).unwrap();
2034 builder.append::<Int32Type>("A", 34).unwrap();
2035 let expected_array = builder.build().unwrap();
2036
2037 compare_union_arrays(filtered, &expected_array);
2038
2039 let filter_array = BooleanArray::from(vec![true, true, false]);
2040 let c = filter(&array, &filter_array).unwrap();
2041 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2042
2043 let mut builder = UnionBuilder::new_dense();
2044 builder.append::<Int32Type>("A", 1).unwrap();
2045 builder.append::<Float64Type>("B", 3.2).unwrap();
2046 let expected_array = builder.build().unwrap();
2047
2048 compare_union_arrays(filtered, &expected_array);
2049 }
2050
2051 #[test]
2052 fn test_filter_union_array_dense() {
2053 let mut builder = UnionBuilder::new_dense();
2054 builder.append::<Int32Type>("A", 1).unwrap();
2055 builder.append::<Float64Type>("B", 3.2).unwrap();
2056 builder.append::<Int32Type>("A", 34).unwrap();
2057 let array = builder.build().unwrap();
2058
2059 test_filter_union_array(array);
2060 }
2061
2062 #[test]
2063 fn test_filter_run_union_array_dense() {
2064 let mut builder = UnionBuilder::new_dense();
2065 builder.append::<Int32Type>("A", 1).unwrap();
2066 builder.append::<Int32Type>("A", 3).unwrap();
2067 builder.append::<Int32Type>("A", 34).unwrap();
2068 let array = builder.build().unwrap();
2069
2070 let filter_array = BooleanArray::from(vec![true, true, false]);
2071 let c = filter(&array, &filter_array).unwrap();
2072 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2073
2074 let mut builder = UnionBuilder::new_dense();
2075 builder.append::<Int32Type>("A", 1).unwrap();
2076 builder.append::<Int32Type>("A", 3).unwrap();
2077 let expected = builder.build().unwrap();
2078
2079 assert_eq!(filtered.to_data(), expected.to_data());
2080 }
2081
2082 #[test]
2083 fn test_filter_union_array_dense_with_nulls() {
2084 let mut builder = UnionBuilder::new_dense();
2085 builder.append::<Int32Type>("A", 1).unwrap();
2086 builder.append::<Float64Type>("B", 3.2).unwrap();
2087 builder.append_null::<Float64Type>("B").unwrap();
2088 builder.append::<Int32Type>("A", 34).unwrap();
2089 let array = builder.build().unwrap();
2090
2091 let filter_array = BooleanArray::from(vec![true, true, false, false]);
2092 let c = filter(&array, &filter_array).unwrap();
2093 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2094
2095 let mut builder = UnionBuilder::new_dense();
2096 builder.append::<Int32Type>("A", 1).unwrap();
2097 builder.append::<Float64Type>("B", 3.2).unwrap();
2098 let expected_array = builder.build().unwrap();
2099
2100 compare_union_arrays(filtered, &expected_array);
2101
2102 let filter_array = BooleanArray::from(vec![true, false, true, false]);
2103 let c = filter(&array, &filter_array).unwrap();
2104 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2105
2106 let mut builder = UnionBuilder::new_dense();
2107 builder.append::<Int32Type>("A", 1).unwrap();
2108 builder.append_null::<Float64Type>("B").unwrap();
2109 let expected_array = builder.build().unwrap();
2110
2111 compare_union_arrays(filtered, &expected_array);
2112 }
2113
2114 #[test]
2115 fn test_filter_union_array_sparse() {
2116 let mut builder = UnionBuilder::new_sparse();
2117 builder.append::<Int32Type>("A", 1).unwrap();
2118 builder.append::<Float64Type>("B", 3.2).unwrap();
2119 builder.append::<Int32Type>("A", 34).unwrap();
2120 let array = builder.build().unwrap();
2121
2122 test_filter_union_array(array);
2123 }
2124
2125 #[test]
2126 fn test_filter_union_array_sparse_with_nulls() {
2127 let mut builder = UnionBuilder::new_sparse();
2128 builder.append::<Int32Type>("A", 1).unwrap();
2129 builder.append::<Float64Type>("B", 3.2).unwrap();
2130 builder.append_null::<Float64Type>("B").unwrap();
2131 builder.append::<Int32Type>("A", 34).unwrap();
2132 let array = builder.build().unwrap();
2133
2134 let filter_array = BooleanArray::from(vec![true, false, true, false]);
2135 let c = filter(&array, &filter_array).unwrap();
2136 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2137
2138 let mut builder = UnionBuilder::new_sparse();
2139 builder.append::<Int32Type>("A", 1).unwrap();
2140 builder.append_null::<Float64Type>("B").unwrap();
2141 let expected_array = builder.build().unwrap();
2142
2143 compare_union_arrays(filtered, &expected_array);
2144 }
2145
2146 fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) {
2147 assert_eq!(union1.len(), union2.len());
2148
2149 for i in 0..union1.len() {
2150 let type_id = union1.type_id(i);
2151
2152 let slot1 = union1.value(i);
2153 let slot2 = union2.value(i);
2154
2155 assert_eq!(slot1.is_null(0), slot2.is_null(0));
2156
2157 if !slot1.is_null(0) && !slot2.is_null(0) {
2158 match type_id {
2159 0 => {
2160 let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
2161 assert_eq!(slot1.len(), 1);
2162 let value1 = slot1.value(0);
2163
2164 let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap();
2165 assert_eq!(slot2.len(), 1);
2166 let value2 = slot2.value(0);
2167 assert_eq!(value1, value2);
2168 }
2169 1 => {
2170 let slot1 = slot1.as_any().downcast_ref::<Float64Array>().unwrap();
2171 assert_eq!(slot1.len(), 1);
2172 let value1 = slot1.value(0);
2173
2174 let slot2 = slot2.as_any().downcast_ref::<Float64Array>().unwrap();
2175 assert_eq!(slot2.len(), 1);
2176 let value2 = slot2.value(0);
2177 assert_eq!(value1, value2);
2178 }
2179 _ => unreachable!(),
2180 }
2181 }
2182 }
2183 }
2184
2185 #[test]
2186 fn test_filter_struct() {
2187 let predicate = BooleanArray::from(vec![true, false, true, false]);
2188
2189 let a = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
2190 let a_filtered = Arc::new(StringArray::from(vec!["hello", "world"]));
2191
2192 let b = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
2193 let b_filtered = Arc::new(Int32Array::from(vec![5, 7]));
2194
2195 let null_mask = NullBuffer::from(vec![true, false, false, true]);
2196 let null_mask_filtered = NullBuffer::from(vec![true, false]);
2197
2198 let a_field = Field::new("a", DataType::Utf8, false);
2199 let b_field = Field::new("b", DataType::Int32, false);
2200
2201 let array = StructArray::new(vec![a_field.clone()].into(), vec![a.clone()], None);
2202 let expected =
2203 StructArray::new(vec![a_field.clone()].into(), vec![a_filtered.clone()], None);
2204
2205 let result = filter(&array, &predicate).unwrap();
2206
2207 assert_eq!(result.to_data(), expected.to_data());
2208
2209 let array = StructArray::new(
2210 vec![a_field.clone()].into(),
2211 vec![a.clone()],
2212 Some(null_mask.clone()),
2213 );
2214 let expected = StructArray::new(
2215 vec![a_field.clone()].into(),
2216 vec![a_filtered.clone()],
2217 Some(null_mask_filtered.clone()),
2218 );
2219
2220 let result = filter(&array, &predicate).unwrap();
2221
2222 assert_eq!(result.to_data(), expected.to_data());
2223
2224 let array = StructArray::new(
2225 vec![a_field.clone(), b_field.clone()].into(),
2226 vec![a.clone(), b.clone()],
2227 None,
2228 );
2229 let expected = StructArray::new(
2230 vec![a_field.clone(), b_field.clone()].into(),
2231 vec![a_filtered.clone(), b_filtered.clone()],
2232 None,
2233 );
2234
2235 let result = filter(&array, &predicate).unwrap();
2236
2237 assert_eq!(result.to_data(), expected.to_data());
2238
2239 let array = StructArray::new(
2240 vec![a_field.clone(), b_field.clone()].into(),
2241 vec![a.clone(), b.clone()],
2242 Some(null_mask.clone()),
2243 );
2244
2245 let expected = StructArray::new(
2246 vec![a_field.clone(), b_field.clone()].into(),
2247 vec![a_filtered.clone(), b_filtered.clone()],
2248 Some(null_mask_filtered.clone()),
2249 );
2250
2251 let result = filter(&array, &predicate).unwrap();
2252
2253 assert_eq!(result.to_data(), expected.to_data());
2254 }
2255
2256 #[test]
2257 fn test_filter_empty_struct() {
2258 let fields = arrow_schema::Field::new(
2265 "a",
2266 arrow_schema::DataType::Struct(arrow_schema::Fields::from(vec![
2267 arrow_schema::Field::new("b", arrow_schema::DataType::Int64, true),
2268 arrow_schema::Field::new(
2269 "c",
2270 arrow_schema::DataType::Struct(arrow_schema::Fields::empty()),
2271 true,
2272 ),
2273 ])),
2274 true,
2275 );
2276
2277 let schema = Arc::new(Schema::new(vec![fields]));
2285
2286 let b = Arc::new(Int64Array::from(vec![None, None, None]));
2287 let c = Arc::new(StructArray::new_empty_fields(
2288 3,
2289 Some(NullBuffer::from(vec![true, true, true])),
2290 ));
2291 let a = StructArray::new(
2292 vec![
2293 Field::new("b", DataType::Int64, true),
2294 Field::new("c", DataType::Struct(Fields::empty()), true),
2295 ]
2296 .into(),
2297 vec![b.clone(), c.clone()],
2298 Some(NullBuffer::from(vec![true, true, true])),
2299 );
2300 let record_batch = RecordBatch::try_new(schema, vec![Arc::new(a)]).unwrap();
2301 println!("{record_batch:?}");
2302
2303 let predicate = BooleanArray::from(vec![true, false, true]);
2305 let filtered_batch = filter_record_batch(&record_batch, &predicate).unwrap();
2306
2307 assert_eq!(filtered_batch.num_rows(), 2);
2309 }
2310
2311 #[test]
2312 #[should_panic]
2313 fn test_filter_bits_too_large() {
2314 let buffer = BooleanBuffer::from(vec![false; 8]);
2315 let predicate = BooleanArray::from(vec![true; 9]);
2316 let filter = FilterBuilder::new(&predicate).build();
2317 filter_bits(&buffer, &filter);
2318 }
2319
2320 #[test]
2321 #[should_panic]
2322 fn test_filter_native_too_large() {
2323 let values = vec![1; 8];
2324 let predicate = BooleanArray::from(vec![false; 9]);
2325 let filter = FilterBuilder::new(&predicate).build();
2326 filter_native(&values, &filter);
2327 }
2328}