1use std::ops::AddAssign;
21use std::sync::Arc;
22
23use arrow_array::builder::BooleanBufferBuilder;
24use arrow_array::cast::AsArray;
25use arrow_array::types::{
26 ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
27};
28use arrow_array::*;
29use arrow_buffer::{ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer, bit_util};
30use arrow_buffer::{Buffer, MutableBuffer};
31use arrow_data::ArrayDataBuilder;
32use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
33use arrow_data::transform::MutableArrayData;
34use arrow_schema::*;
35
/// Selectivity cut-off used by `IterationStrategy::default_strategy`.
///
/// When the fraction of selected rows (`filter_count / filter_length`) is
/// strictly greater than this value, the filter is walked as contiguous
/// `(start, end)` slices; otherwise it is walked one selected index at a time.
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
43
44#[derive(Debug)]
56pub struct SlicesIterator<'a>(BitSliceIterator<'a>);
57
58impl<'a> SlicesIterator<'a> {
59 pub fn new(filter: &'a BooleanArray) -> Self {
61 Self(filter.values().set_slices())
62 }
63}
64
65impl Iterator for SlicesIterator<'_> {
66 type Item = (usize, usize);
67
68 fn next(&mut self) -> Option<Self::Item> {
69 self.0.next()
70 }
71}
72
73struct IndexIterator<'a> {
78 remaining: usize,
79 iter: BitIndexIterator<'a>,
80}
81
82impl<'a> IndexIterator<'a> {
83 fn new(filter: &'a BooleanArray, remaining: usize) -> Self {
84 assert_eq!(filter.null_count(), 0);
85 let iter = filter.values().set_indices();
86 Self { remaining, iter }
87 }
88}
89
90impl Iterator for IndexIterator<'_> {
91 type Item = usize;
92
93 fn next(&mut self) -> Option<Self::Item> {
94 if self.remaining != 0 {
95 let next = self.iter.next().expect("IndexIterator exhausted early");
98 self.remaining -= 1;
99 return Some(next);
101 }
102 None
103 }
104
105 fn size_hint(&self) -> (usize, Option<usize>) {
106 (self.remaining, Some(self.remaining))
107 }
108}
109
110fn filter_count(filter: &BooleanArray) -> usize {
112 filter.values().count_set_bits()
113}
114
115pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
117 let nulls = filter.nulls().unwrap();
118 let mask = filter.values() & nulls.inner();
119 BooleanArray::new(mask, None)
120}
121
122pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
144 let mut filter_builder = FilterBuilder::new(predicate);
145
146 if multiple_arrays(values.data_type()) {
147 filter_builder = filter_builder.optimize();
150 }
151
152 let predicate = filter_builder.build();
153
154 filter_array(values, &predicate)
155}
156
157fn multiple_arrays(data_type: &DataType) -> bool {
158 match data_type {
159 DataType::Struct(fields) => {
160 fields.len() > 1 || fields.len() == 1 && multiple_arrays(fields[0].data_type())
161 }
162 DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
163 _ => false,
164 }
165}
166
167pub fn filter_record_batch(
172 record_batch: &RecordBatch,
173 predicate: &BooleanArray,
174) -> Result<RecordBatch, ArrowError> {
175 let mut filter_builder = FilterBuilder::new(predicate);
176 if record_batch.num_columns() > 1 {
177 filter_builder = filter_builder.optimize();
180 }
181 let filter = filter_builder.build();
182
183 let filtered_arrays = record_batch
184 .columns()
185 .iter()
186 .map(|a| filter_array(a, &filter))
187 .collect::<Result<Vec<_>, _>>()?;
188 let options = RecordBatchOptions::default().with_row_count(Some(filter.count()));
189 RecordBatch::try_new_with_options(record_batch.schema(), filtered_arrays, &options)
190}
191
192#[derive(Debug)]
194pub struct FilterBuilder {
195 filter: BooleanArray,
196 count: usize,
197 strategy: IterationStrategy,
198}
199
200impl FilterBuilder {
201 pub fn new(filter: &BooleanArray) -> Self {
203 let filter = match filter.null_count() {
204 0 => filter.clone(),
205 _ => prep_null_mask_filter(filter),
206 };
207
208 let count = filter_count(&filter);
209 let strategy = IterationStrategy::default_strategy(filter.len(), count);
210
211 Self {
212 filter,
213 count,
214 strategy,
215 }
216 }
217
218 pub fn optimize(mut self) -> Self {
224 match self.strategy {
225 IterationStrategy::SlicesIterator => {
226 let slices = SlicesIterator::new(&self.filter).collect();
227 self.strategy = IterationStrategy::Slices(slices)
228 }
229 IterationStrategy::IndexIterator => {
230 let indices = IndexIterator::new(&self.filter, self.count).collect();
231 self.strategy = IterationStrategy::Indices(indices)
232 }
233 _ => {}
234 }
235 self
236 }
237
238 pub fn build(self) -> FilterPredicate {
240 FilterPredicate {
241 filter: self.filter,
242 count: self.count,
243 strategy: self.strategy,
244 }
245 }
246}
247
248#[derive(Debug)]
250enum IterationStrategy {
251 SlicesIterator,
253 IndexIterator,
255 Indices(Vec<usize>),
257 Slices(Vec<(usize, usize)>),
259 All,
261 None,
263}
264
265impl IterationStrategy {
266 fn default_strategy(filter_length: usize, filter_count: usize) -> Self {
269 if filter_length == 0 || filter_count == 0 {
270 return IterationStrategy::None;
271 }
272
273 if filter_count == filter_length {
274 return IterationStrategy::All;
275 }
276
277 let selectivity_frac = filter_count as f64 / filter_length as f64;
282 if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD {
283 return IterationStrategy::SlicesIterator;
284 }
285 IterationStrategy::IndexIterator
286 }
287}
288
289#[derive(Debug)]
291pub struct FilterPredicate {
292 filter: BooleanArray,
293 count: usize,
294 strategy: IterationStrategy,
295}
296
297impl FilterPredicate {
298 pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef, ArrowError> {
300 filter_array(values, self)
301 }
302
303 pub fn count(&self) -> usize {
305 self.count
306 }
307}
308
/// Applies `predicate` to `values`, dispatching on the predicate's strategy and
/// on the data type of `values`.
///
/// # Errors
///
/// Returns `ArrowError::InvalidArgumentError` when the predicate is longer than
/// `values`, and propagates errors from the type-specific kernels.
///
/// # Panics
///
/// Hits `unimplemented!` for run-end-encoded or dictionary arrays whose index /
/// key type is not covered by the respective downcast macro.
fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> {
    // A predicate shorter than `values` is accepted; a longer one is rejected.
    if predicate.filter.len() > values.len() {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Filter predicate of length {} is larger than target array of length {}",
            predicate.filter.len(),
            values.len()
        )));
    }

    match predicate.strategy {
        // Nothing selected: an empty array of the same type.
        IterationStrategy::None => Ok(new_empty_array(values.data_type())),
        // Everything selected: a slice covering the first `count` elements.
        IterationStrategy::All => Ok(values.slice(0, predicate.count)),
        // Actually filter the contents of the array.
        _ => downcast_primitive_array! {
            // Primitive arrays of any native type.
            values => Ok(Arc::new(filter_primitive(values, predicate))),
            DataType::Boolean => {
                let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
                Ok(Arc::new(filter_boolean(values, predicate)))
            }
            DataType::Utf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i32>(), predicate)))
            }
            DataType::LargeUtf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
            }
            DataType::Utf8View => {
                Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
            }
            DataType::Binary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
            }
            DataType::LargeBinary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
            }
            DataType::BinaryView => {
                Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
            }
            DataType::FixedSizeBinary(_) => {
                Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate)))
            }
            DataType::RunEndEncoded(_, _) => {
                downcast_run_array!{
                    values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
                    t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t)
                }
            }
            DataType::Dictionary(_, _) => downcast_dictionary_array! {
                values => Ok(Arc::new(filter_dict(values, predicate))),
                t => unimplemented!("Filter not supported for dictionary type {:?}", t)
            }
            DataType::Struct(_) => {
                Ok(Arc::new(filter_struct(values.as_struct(), predicate)?))
            }
            DataType::Union(_, UnionMode::Sparse) => {
                Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?))
            }
            _ => {
                // Generic fallback for every remaining type: copy the selected
                // ranges with `MutableArrayData`.
                let data = values.to_data();
                // NOTE(review): the `false` argument is presumably `use_nulls`
                // and `predicate.count` the capacity hint; confirm against the
                // `MutableArrayData::new` signature.
                let mut mutable = MutableArrayData::new(
                    vec![&data],
                    false,
                    predicate.count,
                );

                match &predicate.strategy {
                    // Reuse the pre-computed runs when available.
                    IterationStrategy::Slices(slices) => {
                        slices
                            .iter()
                            .for_each(|(start, end)| mutable.extend(0, *start, *end));
                    }
                    // Otherwise derive the runs from the mask, whatever the
                    // strategy is.
                    _ => {
                        let iter = SlicesIterator::new(&predicate.filter);
                        iter.for_each(|(start, end)| mutable.extend(0, start, end));
                    }
                }

                let data = mutable.freeze();
                Ok(make_array(data))
            }
        },
    }
}
392
/// Filters a run-end-encoded array.
///
/// For every run the selected predicate bits falling inside that run are
/// counted. Runs with at least one selected row are kept, with their new run
/// end set to the running total of selected rows; runs without any selected row
/// are dropped, and the values array is filtered accordingly.
///
/// NOTE(review): the physical run ends are read via `run_ends().inner()`.
/// Whether a sliced `RunArray` (non-zero logical offset) is handled correctly
/// cannot be determined from this function alone; verify against callers.
///
/// # Errors
///
/// Propagates errors from filtering the values array and from constructing the
/// resulting run-end and run arrays.
fn filter_run_end_array<R: RunEndIndexType>(
    array: &RunArray<R>,
    predicate: &FilterPredicate,
) -> Result<RunArray<R>, ArrowError>
where
    R::Native: Into<i64> + From<bool>,
    R::Native: AddAssign,
{
    let run_ends: &RunEndBuffer<R::Native> = array.run_ends();
    // Sized for the worst case (every run kept); truncated to `j` afterwards.
    let mut new_run_ends = vec![R::default_value(); run_ends.len()];

    // Logical start of the current run (the previous run's clamped end).
    let mut start = 0u64;
    // Number of runs kept so far, i.e. the next write position in `new_run_ends`.
    let mut j = 0;
    // Running total of selected rows across all runs seen so far.
    let mut count = R::default_value();
    let filter_values = predicate.filter.values();
    let run_ends = run_ends.inner();

    // One bit per run: `true` when at least one row of that run is selected.
    let pred: BooleanArray = BooleanBuffer::collect_bool(run_ends.len(), |i| {
        let mut keep = false;
        let mut end = run_ends[i].into() as u64;
        // Clamp `end` to the predicate length, since the predicate may be
        // shorter than the array.
        let difference = end.saturating_sub(filter_values.len() as u64);
        end -= difference;

        // SAFETY: `end` was clamped to `filter_values.len()` above, so every
        // index in `start..end` is in bounds.
        for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
            count += R::Native::from(pred);
            keep |= pred
        }
        // Written unconditionally; the slot is only "committed" when `j`
        // advances, which avoids a branch on `keep`.
        new_run_ends[j] = count;
        j += keep as usize;

        start = end;
        keep
    })
    .into();

    new_run_ends.truncate(j);

    let values = array.values();
    let values = filter(&values, &pred)?;

    let run_ends = PrimitiveArray::<R>::try_new(new_run_ends.into(), None)?;
    RunArray::try_new(&run_ends, &values)
}
439
440fn filter_null_mask(
447 nulls: Option<&NullBuffer>,
448 predicate: &FilterPredicate,
449) -> Option<(usize, Buffer)> {
450 let nulls = nulls?;
451 if nulls.null_count() == 0 {
452 return None;
453 }
454
455 let nulls = filter_bits(nulls.inner(), predicate);
456 let null_count = predicate.count - nulls.count_set_bits_offset(0, predicate.count);
459
460 if null_count == 0 {
461 return None;
462 }
463
464 Some((null_count, nulls))
465}
466
467fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer {
469 let src = buffer.values();
470 let offset = buffer.offset();
471
472 match &predicate.strategy {
473 IterationStrategy::IndexIterator => {
474 let bits = IndexIterator::new(&predicate.filter, predicate.count)
475 .map(|src_idx| bit_util::get_bit(src, src_idx + offset));
476
477 unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
479 }
480 IterationStrategy::Indices(indices) => {
481 let bits = indices
482 .iter()
483 .map(|src_idx| bit_util::get_bit(src, *src_idx + offset));
484
485 unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
487 }
488 IterationStrategy::SlicesIterator => {
489 let mut builder = BooleanBufferBuilder::new(predicate.count);
490 for (start, end) in SlicesIterator::new(&predicate.filter) {
491 builder.append_packed_range(start + offset..end + offset, src)
492 }
493 builder.into()
494 }
495 IterationStrategy::Slices(slices) => {
496 let mut builder = BooleanBufferBuilder::new(predicate.count);
497 for (start, end) in slices {
498 builder.append_packed_range(*start + offset..*end + offset, src)
499 }
500 builder.into()
501 }
502 IterationStrategy::All | IterationStrategy::None => unreachable!(),
503 }
504}
505
506fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
508 let values = filter_bits(array.values(), predicate);
509
510 let mut builder = ArrayDataBuilder::new(DataType::Boolean)
511 .len(predicate.count)
512 .add_buffer(values);
513
514 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
515 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
516 }
517
518 let data = unsafe { builder.build_unchecked() };
519 BooleanArray::from(data)
520}
521
522#[inline(never)]
523fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer {
524 assert!(values.len() >= predicate.filter.len());
525
526 match &predicate.strategy {
527 IterationStrategy::SlicesIterator => {
528 let mut buffer = Vec::with_capacity(predicate.count);
529 for (start, end) in SlicesIterator::new(&predicate.filter) {
530 buffer.extend_from_slice(&values[start..end]);
531 }
532 buffer.into()
533 }
534 IterationStrategy::Slices(slices) => {
535 let mut buffer = Vec::with_capacity(predicate.count);
536 for (start, end) in slices {
537 buffer.extend_from_slice(&values[*start..*end]);
538 }
539 buffer.into()
540 }
541 IterationStrategy::IndexIterator => {
542 let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| values[x]);
543
544 unsafe { MutableBuffer::from_trusted_len_iter(iter) }.into()
546 }
547 IterationStrategy::Indices(indices) => {
548 let iter = indices.iter().map(|x| values[*x]);
549 iter.collect::<Vec<_>>().into()
550 }
551 IterationStrategy::All | IterationStrategy::None => unreachable!(),
552 }
553}
554
555fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
557where
558 T: ArrowPrimitiveType,
559{
560 let values = array.values();
561 let buffer = filter_native(values, predicate);
562 let mut builder = ArrayDataBuilder::new(array.data_type().clone())
563 .len(predicate.count)
564 .add_buffer(buffer);
565
566 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
567 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
568 }
569
570 let data = unsafe { builder.build_unchecked() };
571 PrimitiveArray::from(data)
572}
573
574struct FilterBytes<'a, OffsetSize> {
579 src_offsets: &'a [OffsetSize],
580 src_values: &'a [u8],
581 dst_offsets: Vec<OffsetSize>,
582 dst_values: Vec<u8>,
583 cur_offset: OffsetSize,
584}
585
586impl<'a, OffsetSize> FilterBytes<'a, OffsetSize>
587where
588 OffsetSize: OffsetSizeTrait,
589{
590 fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self
591 where
592 T: ByteArrayType<Offset = OffsetSize>,
593 {
594 let dst_values = Vec::new();
595 let mut dst_offsets: Vec<OffsetSize> = Vec::with_capacity(capacity + 1);
596 let cur_offset = OffsetSize::from_usize(0).unwrap();
597
598 dst_offsets.push(cur_offset);
599
600 Self {
601 src_offsets: array.value_offsets(),
602 src_values: array.value_data(),
603 dst_offsets,
604 dst_values,
605 cur_offset,
606 }
607 }
608
609 #[inline]
611 fn get_value_offset(&self, idx: usize) -> usize {
612 self.src_offsets[idx].as_usize()
613 }
614
615 #[inline]
617 fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
618 let start = self.get_value_offset(idx);
620 let end = self.get_value_offset(idx + 1);
621 let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
622 (start, end, len)
623 }
624
625 fn extend_offsets_idx(&mut self, iter: impl Iterator<Item = usize>) {
626 self.dst_offsets.extend(iter.map(|idx| {
627 let start = self.src_offsets[idx].as_usize();
628 let end = self.src_offsets[idx + 1].as_usize();
629 let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
630 self.cur_offset += len;
631
632 self.cur_offset
633 }));
634 }
635
636 fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
638 self.dst_values.reserve_exact(self.cur_offset.as_usize());
639
640 for idx in iter {
641 let start = self.src_offsets[idx].as_usize();
642 let end = self.src_offsets[idx + 1].as_usize();
643 self.dst_values
644 .extend_from_slice(&self.src_values[start..end]);
645 }
646 }
647
648 fn extend_offsets_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>, count: usize) {
649 self.dst_offsets.reserve_exact(count);
650 for (start, end) in iter {
651 for idx in start..end {
653 let (_, _, len) = self.get_value_range(idx);
654 self.cur_offset += len;
655 self.dst_offsets.push(self.cur_offset);
656 }
657 }
658 }
659
660 fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
662 self.dst_values.reserve_exact(self.cur_offset.as_usize());
663
664 for (start, end) in iter {
665 let value_start = self.get_value_offset(start);
666 let value_end = self.get_value_offset(end);
667 self.dst_values
668 .extend_from_slice(&self.src_values[value_start..value_end]);
669 }
670 }
671}
672
673fn filter_bytes<T>(array: &GenericByteArray<T>, predicate: &FilterPredicate) -> GenericByteArray<T>
678where
679 T: ByteArrayType,
680{
681 let mut filter = FilterBytes::new(predicate.count, array);
682
683 match &predicate.strategy {
684 IterationStrategy::SlicesIterator => {
685 filter.extend_offsets_slices(SlicesIterator::new(&predicate.filter), predicate.count);
686 filter.extend_slices(SlicesIterator::new(&predicate.filter))
687 }
688 IterationStrategy::Slices(slices) => {
689 filter.extend_offsets_slices(slices.iter().cloned(), predicate.count);
690 filter.extend_slices(slices.iter().cloned())
691 }
692 IterationStrategy::IndexIterator => {
693 filter.extend_offsets_idx(IndexIterator::new(&predicate.filter, predicate.count));
694 filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count))
695 }
696 IterationStrategy::Indices(indices) => {
697 filter.extend_offsets_idx(indices.iter().cloned());
698 filter.extend_idx(indices.iter().cloned())
699 }
700 IterationStrategy::All | IterationStrategy::None => unreachable!(),
701 }
702
703 let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
704 .len(predicate.count)
705 .add_buffer(filter.dst_offsets.into())
706 .add_buffer(filter.dst_values.into());
707
708 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
709 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
710 }
711
712 let data = unsafe { builder.build_unchecked() };
713 GenericByteArray::from(data)
714}
715
716fn filter_byte_view<T: ByteViewType>(
718 array: &GenericByteViewArray<T>,
719 predicate: &FilterPredicate,
720) -> GenericByteViewArray<T> {
721 let new_view_buffer = filter_native(array.views(), predicate);
722
723 let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
724 .len(predicate.count)
725 .add_buffer(new_view_buffer)
726 .add_buffers(array.data_buffers().to_vec());
727
728 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
729 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
730 }
731
732 GenericByteViewArray::from(unsafe { builder.build_unchecked() })
733}
734
735fn filter_fixed_size_binary(
736 array: &FixedSizeBinaryArray,
737 predicate: &FilterPredicate,
738) -> FixedSizeBinaryArray {
739 let values: &[u8] = array.values();
740 let value_length = array.value_length() as usize;
741 let calculate_offset_from_index = |index: usize| index * value_length;
742 let buffer = match &predicate.strategy {
743 IterationStrategy::SlicesIterator => {
744 let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
745 for (start, end) in SlicesIterator::new(&predicate.filter) {
746 buffer.extend_from_slice(
747 &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)],
748 );
749 }
750 buffer
751 }
752 IterationStrategy::Slices(slices) => {
753 let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
754 for (start, end) in slices {
755 buffer.extend_from_slice(
756 &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)],
757 );
758 }
759 buffer
760 }
761 IterationStrategy::IndexIterator => {
762 let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| {
763 &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)]
764 });
765
766 let mut buffer = MutableBuffer::new(predicate.count * value_length);
767 iter.for_each(|item| buffer.extend_from_slice(item));
768 buffer
769 }
770 IterationStrategy::Indices(indices) => {
771 let iter = indices.iter().map(|x| {
772 &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)]
773 });
774
775 let mut buffer = MutableBuffer::new(predicate.count * value_length);
776 iter.for_each(|item| buffer.extend_from_slice(item));
777 buffer
778 }
779 IterationStrategy::All | IterationStrategy::None => unreachable!(),
780 };
781 let mut builder = ArrayDataBuilder::new(array.data_type().clone())
782 .len(predicate.count)
783 .add_buffer(buffer.into());
784
785 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
786 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
787 }
788
789 let data = unsafe { builder.build_unchecked() };
790 FixedSizeBinaryArray::from(data)
791}
792
793fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
795where
796 T: ArrowDictionaryKeyType,
797 T::Native: num_traits::Num,
798{
799 let builder = filter_primitive::<T>(array.keys(), predicate)
800 .into_data()
801 .into_builder()
802 .data_type(array.data_type().clone())
803 .child_data(vec![array.values().to_data()]);
804
805 DictionaryArray::from(unsafe { builder.build_unchecked() })
808}
809
810fn filter_struct(
812 array: &StructArray,
813 predicate: &FilterPredicate,
814) -> Result<StructArray, ArrowError> {
815 let columns = array
816 .columns()
817 .iter()
818 .map(|column| filter_array(column, predicate))
819 .collect::<Result<_, _>>()?;
820
821 let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
822 let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
823
824 Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
825 } else {
826 None
827 };
828
829 Ok(unsafe {
830 StructArray::new_unchecked_with_length(
831 array.fields().clone(),
832 columns,
833 nulls,
834 predicate.count(),
835 )
836 })
837}
838
839fn filter_sparse_union(
841 array: &UnionArray,
842 predicate: &FilterPredicate,
843) -> Result<UnionArray, ArrowError> {
844 let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
845 unreachable!()
846 };
847
848 let type_ids = filter_primitive(
849 &Int8Array::try_new(array.type_ids().clone(), None)?,
850 predicate,
851 );
852
853 let children = fields
854 .iter()
855 .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
856 .collect::<Result<_, _>>()?;
857
858 Ok(unsafe {
859 UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
860 })
861}
862
863#[cfg(test)]
864mod tests {
865 use super::*;
866 use arrow_array::builder::*;
867 use arrow_array::cast::as_run_array;
868 use arrow_array::types::*;
869 use arrow_data::ArrayData;
870 use rand::distr::uniform::{UniformSampler, UniformUsize};
871 use rand::distr::{Alphanumeric, StandardUniform};
872 use rand::prelude::*;
873 use rand::rng;
874
875 macro_rules! def_temporal_test {
876 ($test:ident, $array_type: ident, $data: expr) => {
877 #[test]
878 fn $test() {
879 let a = $data;
880 let b = BooleanArray::from(vec![true, false, true, false]);
881 let c = filter(&a, &b).unwrap();
882 let d = c.as_ref().as_any().downcast_ref::<$array_type>().unwrap();
883 assert_eq!(2, d.len());
884 assert_eq!(1, d.value(0));
885 assert_eq!(3, d.value(1));
886 }
887 };
888 }
889
    // One generated test per temporal array type (dates, times, durations,
    // timestamps); each filters the values 1..=4 via `def_temporal_test!`.
    def_temporal_test!(
        test_filter_date32,
        Date32Array,
        Date32Array::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_date64,
        Date64Array,
        Date64Array::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time32_second,
        Time32SecondArray,
        Time32SecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time32_millisecond,
        Time32MillisecondArray,
        Time32MillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time64_microsecond,
        Time64MicrosecondArray,
        Time64MicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time64_nanosecond,
        Time64NanosecondArray,
        Time64NanosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_second,
        DurationSecondArray,
        DurationSecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_millisecond,
        DurationMillisecondArray,
        DurationMillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_microsecond,
        DurationMicrosecondArray,
        DurationMicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_nanosecond,
        DurationNanosecondArray,
        DurationNanosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_second,
        TimestampSecondArray,
        TimestampSecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_millisecond,
        TimestampMillisecondArray,
        TimestampMillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_microsecond,
        TimestampMicrosecondArray,
        TimestampMicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_nanosecond,
        TimestampNanosecondArray,
        TimestampNanosecondArray::from(vec![1, 2, 3, 4])
    );
960
961 #[test]
962 fn test_filter_array_slice() {
963 let a = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
964 let b = BooleanArray::from(vec![true, false, false, true]);
965 let c = filter(&a, &b).unwrap();
969 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
970 assert_eq!(2, d.len());
971 assert_eq!(6, d.value(0));
972 assert_eq!(9, d.value(1));
973 }
974
975 #[test]
976 fn test_filter_array_low_density() {
977 let mut data_values = (1..=65).collect::<Vec<i32>>();
979 let mut filter_values = (1..=65).map(|i| matches!(i % 65, 0)).collect::<Vec<bool>>();
980 data_values.extend_from_slice(&[66, 67]);
982 filter_values.extend_from_slice(&[false, true]);
983 let a = Int32Array::from(data_values);
984 let b = BooleanArray::from(filter_values);
985 let c = filter(&a, &b).unwrap();
986 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
987 assert_eq!(2, d.len());
988 assert_eq!(65, d.value(0));
989 assert_eq!(67, d.value(1));
990 }
991
992 #[test]
993 fn test_filter_array_high_density() {
994 let mut data_values = (1..=65).map(Some).collect::<Vec<_>>();
996 let mut filter_values = (1..=65)
997 .map(|i| !matches!(i % 65, 0))
998 .collect::<Vec<bool>>();
999 data_values[1] = None;
1001 data_values.extend_from_slice(&[Some(66), None, Some(67), None]);
1003 filter_values.extend_from_slice(&[false, true, true, true]);
1004 let a = Int32Array::from(data_values);
1005 let b = BooleanArray::from(filter_values);
1006 let c = filter(&a, &b).unwrap();
1007 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1008 assert_eq!(67, d.len());
1009 assert_eq!(3, d.null_count());
1010 assert_eq!(1, d.value(0));
1011 assert!(d.is_null(1));
1012 assert_eq!(64, d.value(63));
1013 assert!(d.is_null(64));
1014 assert_eq!(67, d.value(65));
1015 }
1016
1017 #[test]
1018 fn test_filter_string_array_simple() {
1019 let a = StringArray::from(vec!["hello", " ", "world", "!"]);
1020 let b = BooleanArray::from(vec![true, false, true, false]);
1021 let c = filter(&a, &b).unwrap();
1022 let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1023 assert_eq!(2, d.len());
1024 assert_eq!("hello", d.value(0));
1025 assert_eq!("world", d.value(1));
1026 }
1027
1028 #[test]
1029 fn test_filter_primitive_array_with_null() {
1030 let a = Int32Array::from(vec![Some(5), None]);
1031 let b = BooleanArray::from(vec![false, true]);
1032 let c = filter(&a, &b).unwrap();
1033 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1034 assert_eq!(1, d.len());
1035 assert!(d.is_null(0));
1036 }
1037
1038 #[test]
1039 fn test_filter_string_array_with_null() {
1040 let a = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
1041 let b = BooleanArray::from(vec![true, false, false, true]);
1042 let c = filter(&a, &b).unwrap();
1043 let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1044 assert_eq!(2, d.len());
1045 assert_eq!("hello", d.value(0));
1046 assert!(!d.is_null(0));
1047 assert!(d.is_null(1));
1048 }
1049
1050 #[test]
1051 fn test_filter_binary_array_with_null() {
1052 let data: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
1053 let a = BinaryArray::from(data);
1054 let b = BooleanArray::from(vec![true, false, false, true]);
1055 let c = filter(&a, &b).unwrap();
1056 let d = c.as_ref().as_any().downcast_ref::<BinaryArray>().unwrap();
1057 assert_eq!(2, d.len());
1058 assert_eq!(b"hello", d.value(0));
1059 assert!(!d.is_null(0));
1060 assert!(d.is_null(1));
1061 }
1062
1063 fn _test_filter_byte_view<T>()
1064 where
1065 T: ByteViewType,
1066 str: AsRef<T::Native>,
1067 T::Native: PartialEq,
1068 {
1069 let array = {
1070 let mut builder = GenericByteViewBuilder::<T>::new();
1072 builder.append_value("hello");
1073 builder.append_value("world");
1074 builder.append_null();
1075 builder.append_value("large payload over 12 bytes");
1076 builder.append_value("lulu");
1077 builder.finish()
1078 };
1079
1080 {
1081 let predicate = BooleanArray::from(vec![true, false, true, true, false]);
1082 let actual = filter(&array, &predicate).unwrap();
1083
1084 assert_eq!(actual.len(), 3);
1085
1086 let expected = {
1087 let mut builder = GenericByteViewBuilder::<T>::new();
1089 builder.append_value("hello");
1090 builder.append_null();
1091 builder.append_value("large payload over 12 bytes");
1092 builder.finish()
1093 };
1094
1095 assert_eq!(actual.as_ref(), &expected);
1096 }
1097
1098 {
1099 let predicate = BooleanArray::from(vec![true, false, false, false, true]);
1100 let actual = filter(&array, &predicate).unwrap();
1101
1102 assert_eq!(actual.len(), 2);
1103
1104 let expected = {
1105 let mut builder = GenericByteViewBuilder::<T>::new();
1107 builder.append_value("hello");
1108 builder.append_value("lulu");
1109 builder.finish()
1110 };
1111
1112 assert_eq!(actual.as_ref(), &expected);
1113 }
1114 }
1115
    /// Runs the shared byte-view filter checks for string views.
    #[test]
    fn test_filter_string_view() {
        _test_filter_byte_view::<StringViewType>()
    }
1120
    /// Runs the shared byte-view filter checks for binary views.
    #[test]
    fn test_filter_binary_view() {
        _test_filter_byte_view::<BinaryViewType>()
    }
1125
1126 #[test]
1127 fn test_filter_fixed_binary() {
1128 let v1 = [1_u8, 2];
1129 let v2 = [3_u8, 4];
1130 let v3 = [5_u8, 6];
1131 let v = vec![&v1, &v2, &v3];
1132 let a = FixedSizeBinaryArray::from(v);
1133 let b = BooleanArray::from(vec![true, false, true]);
1134 let c = filter(&a, &b).unwrap();
1135 let d = c
1136 .as_ref()
1137 .as_any()
1138 .downcast_ref::<FixedSizeBinaryArray>()
1139 .unwrap();
1140 assert_eq!(d.len(), 2);
1141 assert_eq!(d.value(0), &v1);
1142 assert_eq!(d.value(1), &v3);
1143 let c2 = FilterBuilder::new(&b)
1144 .optimize()
1145 .build()
1146 .filter(&a)
1147 .unwrap();
1148 let d2 = c2
1149 .as_ref()
1150 .as_any()
1151 .downcast_ref::<FixedSizeBinaryArray>()
1152 .unwrap();
1153 assert_eq!(d, d2);
1154
1155 let b = BooleanArray::from(vec![false, false, false]);
1156 let c = filter(&a, &b).unwrap();
1157 let d = c
1158 .as_ref()
1159 .as_any()
1160 .downcast_ref::<FixedSizeBinaryArray>()
1161 .unwrap();
1162 assert_eq!(d.len(), 0);
1163
1164 let b = BooleanArray::from(vec![true, true, true]);
1165 let c = filter(&a, &b).unwrap();
1166 let d = c
1167 .as_ref()
1168 .as_any()
1169 .downcast_ref::<FixedSizeBinaryArray>()
1170 .unwrap();
1171 assert_eq!(d.len(), 3);
1172 assert_eq!(d.value(0), &v1);
1173 assert_eq!(d.value(1), &v2);
1174 assert_eq!(d.value(2), &v3);
1175
1176 let b = BooleanArray::from(vec![false, false, true]);
1177 let c = filter(&a, &b).unwrap();
1178 let d = c
1179 .as_ref()
1180 .as_any()
1181 .downcast_ref::<FixedSizeBinaryArray>()
1182 .unwrap();
1183 assert_eq!(d.len(), 1);
1184 assert_eq!(d.value(0), &v3);
1185 let c2 = FilterBuilder::new(&b)
1186 .optimize()
1187 .build()
1188 .filter(&a)
1189 .unwrap();
1190 let d2 = c2
1191 .as_ref()
1192 .as_any()
1193 .downcast_ref::<FixedSizeBinaryArray>()
1194 .unwrap();
1195 assert_eq!(d, d2);
1196 }
1197
1198 #[test]
1199 fn test_filter_array_slice_with_null() {
1200 let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
1201 let b = BooleanArray::from(vec![true, false, false, true]);
1202 let c = filter(&a, &b).unwrap();
1206 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1207 assert_eq!(2, d.len());
1208 assert!(d.is_null(0));
1209 assert!(!d.is_null(1));
1210 assert_eq!(9, d.value(1));
1211 }
1212
1213 #[test]
1214 fn test_filter_run_end_encoding_array() {
1215 let run_ends = Int64Array::from(vec![2, 3, 8]);
1216 let values = Int64Array::from(vec![7, -2, 9]);
1217 let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1218 let b = BooleanArray::from(vec![true, false, true, false, true, false, true, false]);
1219 let c = filter(&a, &b).unwrap();
1220 let actual: &RunArray<Int64Type> = as_run_array(&c);
1221 assert_eq!(4, actual.len());
1222
1223 let expected = RunArray::try_new(
1224 &Int64Array::from(vec![1, 2, 4]),
1225 &Int64Array::from(vec![7, -2, 9]),
1226 )
1227 .expect("Failed to make expected RunArray test is broken");
1228
1229 assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1230 assert_eq!(actual.values(), expected.values())
1231 }
1232
1233 #[test]
1234 fn test_filter_run_end_encoding_array_remove_value() {
1235 let run_ends = Int32Array::from(vec![2, 3, 8, 10]);
1236 let values = Int32Array::from(vec![7, -2, 9, -8]);
1237 let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1238 let b = BooleanArray::from(vec![
1239 false, true, false, false, true, false, true, false, false, false,
1240 ]);
1241 let c = filter(&a, &b).unwrap();
1242 let actual: &RunArray<Int32Type> = as_run_array(&c);
1243 assert_eq!(3, actual.len());
1244
1245 let expected =
1246 RunArray::try_new(&Int32Array::from(vec![1, 3]), &Int32Array::from(vec![7, 9]))
1247 .expect("Failed to make expected RunArray test is broken");
1248
1249 assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1250 assert_eq!(actual.values(), expected.values())
1251 }
1252
1253 #[test]
1254 fn test_filter_run_end_encoding_array_remove_all_but_one() {
1255 let run_ends = Int16Array::from(vec![2, 3, 8, 10]);
1256 let values = Int16Array::from(vec![7, -2, 9, -8]);
1257 let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1258 let b = BooleanArray::from(vec![
1259 false, false, false, false, false, false, true, false, false, false,
1260 ]);
1261 let c = filter(&a, &b).unwrap();
1262 let actual: &RunArray<Int16Type> = as_run_array(&c);
1263 assert_eq!(1, actual.len());
1264
1265 let expected = RunArray::try_new(&Int16Array::from(vec![1]), &Int16Array::from(vec![9]))
1266 .expect("Failed to make expected RunArray test is broken");
1267
1268 assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1269 assert_eq!(actual.values(), expected.values())
1270 }
1271
1272 #[test]
1273 fn test_filter_run_end_encoding_array_empty() {
1274 let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1275 let values = Int64Array::from(vec![7, -2, 9, -8]);
1276 let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1277 let b = BooleanArray::from(vec![
1278 false, false, false, false, false, false, false, false, false, false,
1279 ]);
1280 let c = filter(&a, &b).unwrap();
1281 let actual: &RunArray<Int64Type> = as_run_array(&c);
1282 assert_eq!(0, actual.len());
1283 }
1284
1285 #[test]
1286 fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() {
1287 let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1288 let values = Int64Array::from(vec![7, -2, 9, -8]);
1289 let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1290 let b = BooleanArray::from(vec![false, true, true]);
1291 let c = filter(&a, &b).unwrap();
1292 let actual: &RunArray<Int64Type> = as_run_array(&c);
1293 assert_eq!(2, actual.len());
1294
1295 let expected = RunArray::try_new(
1296 &Int64Array::from(vec![1, 2]),
1297 &Int64Array::from(vec![7, -2]),
1298 )
1299 .expect("Failed to make expected RunArray test is broken");
1300
1301 assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1302 assert_eq!(actual.values(), expected.values())
1303 }
1304
1305 #[test]
1306 fn test_filter_dictionary_array() {
1307 let values = [Some("hello"), None, Some("world"), Some("!")];
1308 let a: Int8DictionaryArray = values.iter().copied().collect();
1309 let b = BooleanArray::from(vec![false, true, true, false]);
1310 let c = filter(&a, &b).unwrap();
1311 let d = c
1312 .as_ref()
1313 .as_any()
1314 .downcast_ref::<Int8DictionaryArray>()
1315 .unwrap();
1316 let value_array = d.values();
1317 let values = value_array.as_any().downcast_ref::<StringArray>().unwrap();
1318 assert_eq!(3, values.len());
1320 assert_eq!(2, d.len());
1322 assert!(d.is_null(0));
1323 assert_eq!("world", values.value(d.keys().value(1) as usize));
1324 }
1325
1326 #[test]
1327 fn test_filter_list_array() {
1328 let value_data = ArrayData::builder(DataType::Int32)
1329 .len(8)
1330 .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
1331 .build()
1332 .unwrap();
1333
1334 let value_offsets = Buffer::from_slice_ref([0i64, 3, 6, 8, 8]);
1335
1336 let list_data_type =
1337 DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
1338 let list_data = ArrayData::builder(list_data_type)
1339 .len(4)
1340 .add_buffer(value_offsets)
1341 .add_child_data(value_data)
1342 .null_bit_buffer(Some(Buffer::from([0b00000111])))
1343 .build()
1344 .unwrap();
1345
1346 let a = LargeListArray::from(list_data);
1348 let b = BooleanArray::from(vec![false, true, false, true]);
1349 let result = filter(&a, &b).unwrap();
1350
1351 let value_data = ArrayData::builder(DataType::Int32)
1353 .len(3)
1354 .add_buffer(Buffer::from_slice_ref([3, 4, 5]))
1355 .build()
1356 .unwrap();
1357
1358 let value_offsets = Buffer::from_slice_ref([0i64, 3, 3]);
1359
1360 let list_data_type =
1361 DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
1362 let expected = ArrayData::builder(list_data_type)
1363 .len(2)
1364 .add_buffer(value_offsets)
1365 .add_child_data(value_data)
1366 .null_bit_buffer(Some(Buffer::from([0b00000001])))
1367 .build()
1368 .unwrap();
1369
1370 assert_eq!(&make_array(expected), &result);
1371 }
1372
1373 #[test]
1374 fn test_slice_iterator_bits() {
1375 let filter_values = (0..64).map(|i| i == 1).collect::<Vec<bool>>();
1376 let filter = BooleanArray::from(filter_values);
1377 let filter_count = filter_count(&filter);
1378
1379 let iter = SlicesIterator::new(&filter);
1380 let chunks = iter.collect::<Vec<_>>();
1381
1382 assert_eq!(chunks, vec![(1, 2)]);
1383 assert_eq!(filter_count, 1);
1384 }
1385
1386 #[test]
1387 fn test_slice_iterator_bits1() {
1388 let filter_values = (0..64).map(|i| i != 1).collect::<Vec<bool>>();
1389 let filter = BooleanArray::from(filter_values);
1390 let filter_count = filter_count(&filter);
1391
1392 let iter = SlicesIterator::new(&filter);
1393 let chunks = iter.collect::<Vec<_>>();
1394
1395 assert_eq!(chunks, vec![(0, 1), (2, 64)]);
1396 assert_eq!(filter_count, 64 - 1);
1397 }
1398
1399 #[test]
1400 fn test_slice_iterator_chunk_and_bits() {
1401 let filter_values = (0..130).map(|i| i % 62 != 0).collect::<Vec<bool>>();
1402 let filter = BooleanArray::from(filter_values);
1403 let filter_count = filter_count(&filter);
1404
1405 let iter = SlicesIterator::new(&filter);
1406 let chunks = iter.collect::<Vec<_>>();
1407
1408 assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]);
1409 assert_eq!(filter_count, 61 + 61 + 5);
1410 }
1411
1412 #[test]
1413 fn test_null_mask() {
1414 let a = Int64Array::from(vec![Some(1), Some(2), None]);
1415
1416 let mask1 = BooleanArray::from(vec![Some(true), Some(true), None]);
1417 let out = filter(&a, &mask1).unwrap();
1418 assert_eq!(out.as_ref(), &a.slice(0, 2));
1419 }
1420
1421 #[test]
1422 fn test_filter_record_batch_no_columns() {
1423 let pred = BooleanArray::from(vec![Some(true), Some(true), None]);
1424 let options = RecordBatchOptions::default().with_row_count(Some(100));
1425 let record_batch =
1426 RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &options).unwrap();
1427 let out = filter_record_batch(&record_batch, &pred).unwrap();
1428
1429 assert_eq!(out.num_rows(), 2);
1430 }
1431
1432 #[test]
1433 fn test_fast_path() {
1434 let a: PrimitiveArray<Int64Type> = PrimitiveArray::from(vec![Some(1), Some(2), None]);
1435
1436 let mask = BooleanArray::from(vec![true, true, true]);
1438 let out = filter(&a, &mask).unwrap();
1439 let b = out
1440 .as_any()
1441 .downcast_ref::<PrimitiveArray<Int64Type>>()
1442 .unwrap();
1443 assert_eq!(&a, b);
1444
1445 let mask = BooleanArray::from(vec![false, false, false]);
1447 let out = filter(&a, &mask).unwrap();
1448 assert_eq!(out.len(), 0);
1449 assert_eq!(out.data_type(), &DataType::Int64);
1450 }
1451
1452 #[test]
1453 fn test_slices() {
1454 let bools = std::iter::repeat_n(true, 10)
1456 .chain(std::iter::repeat_n(false, 30))
1457 .chain(std::iter::repeat_n(true, 20))
1458 .chain(std::iter::repeat_n(false, 17))
1459 .chain(std::iter::repeat_n(true, 4));
1460
1461 let bool_array: BooleanArray = bools.map(Some).collect();
1462
1463 let slices: Vec<_> = SlicesIterator::new(&bool_array).collect();
1464 let expected = vec![(0, 10), (40, 60), (77, 81)];
1465 assert_eq!(slices, expected);
1466
1467 let len = bool_array.len();
1469 let sliced_array = bool_array.slice(7, len - 10);
1470 let sliced_array = sliced_array
1471 .as_any()
1472 .downcast_ref::<BooleanArray>()
1473 .unwrap();
1474 let slices: Vec<_> = SlicesIterator::new(sliced_array).collect();
1475 let expected = vec![(0, 3), (33, 53), (70, 71)];
1476 assert_eq!(slices, expected);
1477 }
1478
1479 fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) {
1480 let mut rng = rng();
1481
1482 let bools: Vec<bool> = std::iter::from_fn(|| Some(rng.random()))
1483 .take(mask_len)
1484 .collect();
1485
1486 let buffer = Buffer::from_iter(bools.iter().cloned());
1487
1488 let truncated_length = mask_len - offset - truncate;
1489
1490 let data = ArrayDataBuilder::new(DataType::Boolean)
1491 .len(truncated_length)
1492 .offset(offset)
1493 .add_buffer(buffer)
1494 .build()
1495 .unwrap();
1496
1497 let filter = BooleanArray::from(data);
1498
1499 let slice_bits: Vec<_> = SlicesIterator::new(&filter)
1500 .flat_map(|(start, end)| start..end)
1501 .collect();
1502
1503 let count = filter_count(&filter);
1504 let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect();
1505
1506 let expected_bits: Vec<_> = bools
1507 .iter()
1508 .skip(offset)
1509 .take(truncated_length)
1510 .enumerate()
1511 .flat_map(|(idx, v)| v.then(|| idx))
1512 .collect();
1513
1514 assert_eq!(slice_bits, expected_bits);
1515 assert_eq!(index_bits, expected_bits);
1516 }
1517
1518 #[test]
1519 #[cfg_attr(miri, ignore)]
1520 fn fuzz_test_slices_iterator() {
1521 let mut rng = rng();
1522
1523 let uusize = UniformUsize::new(usize::MIN, usize::MAX).unwrap();
1524 for _ in 0..100 {
1525 let mask_len = rng.random_range(0..1024);
1526 let max_offset = 64.min(mask_len);
1527 let offset = uusize.sample(&mut rng).checked_rem(max_offset).unwrap_or(0);
1528
1529 let max_truncate = 128.min(mask_len - offset);
1530 let truncate = uusize
1531 .sample(&mut rng)
1532 .checked_rem(max_truncate)
1533 .unwrap_or(0);
1534
1535 test_slices_fuzz(mask_len, offset, truncate);
1536 }
1537
1538 test_slices_fuzz(64, 0, 0);
1539 test_slices_fuzz(64, 8, 0);
1540 test_slices_fuzz(64, 8, 8);
1541 test_slices_fuzz(32, 8, 8);
1542 test_slices_fuzz(32, 5, 9);
1543 }
1544
1545 fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> {
1547 values
1548 .into_iter()
1549 .zip(predicate)
1550 .filter(|(_, x)| **x)
1551 .map(|(a, _)| a)
1552 .collect()
1553 }
1554
1555 fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>>
1557 where
1558 StandardUniform: Distribution<T>,
1559 {
1560 let mut rng = rng();
1561 (0..len)
1562 .map(|_| rng.random_bool(valid_percent).then(|| rng.random()))
1563 .collect()
1564 }
1565
1566 fn gen_strings(
1568 len: usize,
1569 valid_percent: f64,
1570 str_len_range: std::ops::Range<usize>,
1571 ) -> Vec<Option<String>> {
1572 let mut rng = rng();
1573 (0..len)
1574 .map(|_| {
1575 rng.random_bool(valid_percent).then(|| {
1576 let len = rng.random_range(str_len_range.clone());
1577 (0..len)
1578 .map(|_| char::from(rng.sample(Alphanumeric)))
1579 .collect()
1580 })
1581 })
1582 .collect()
1583 }
1584
1585 fn as_deref<T: std::ops::Deref>(src: &[Option<T>]) -> impl Iterator<Item = Option<&T::Target>> {
1587 src.iter().map(|x| x.as_deref())
1588 }
1589
1590 #[test]
1591 #[cfg_attr(miri, ignore)]
1592 fn fuzz_filter() {
1593 let mut rng = rng();
1594
1595 for i in 0..100 {
1596 let filter_percent = match i {
1597 0..=4 => 1.,
1598 5..=10 => 0.,
1599 _ => rng.random_range(0.0..1.0),
1600 };
1601
1602 let valid_percent = rng.random_range(0.0..1.0);
1603
1604 let array_len = rng.random_range(32..256);
1605 let array_offset = rng.random_range(0..10);
1606
1607 let filter_offset = rng.random_range(0..10);
1609 let filter_truncate = rng.random_range(0..10);
1610 let bools: Vec<_> = std::iter::from_fn(|| Some(rng.random_bool(filter_percent)))
1611 .take(array_len + filter_offset - filter_truncate)
1612 .collect();
1613
1614 let predicate = BooleanArray::from_iter(bools.iter().cloned().map(Some));
1615
1616 let predicate = predicate.slice(filter_offset, array_len - filter_truncate);
1618 let predicate = predicate.as_any().downcast_ref::<BooleanArray>().unwrap();
1619 let bools = &bools[filter_offset..];
1620
1621 let values = gen_primitive(array_len + array_offset, valid_percent);
1623 let src = Int32Array::from_iter(values.iter().cloned());
1624
1625 let src = src.slice(array_offset, array_len);
1626 let src = src.as_any().downcast_ref::<Int32Array>().unwrap();
1627 let values = &values[array_offset..];
1628
1629 let filtered = filter(src, predicate).unwrap();
1630 let array = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
1631 let actual: Vec<_> = array.iter().collect();
1632
1633 assert_eq!(actual, filter_rust(values.iter().cloned(), bools));
1634
1635 let strings = gen_strings(array_len + array_offset, valid_percent, 0..20);
1637 let src = StringArray::from_iter(as_deref(&strings));
1638
1639 let src = src.slice(array_offset, array_len);
1640 let src = src.as_any().downcast_ref::<StringArray>().unwrap();
1641
1642 let filtered = filter(src, predicate).unwrap();
1643 let array = filtered.as_any().downcast_ref::<StringArray>().unwrap();
1644 let actual: Vec<_> = array.iter().collect();
1645
1646 let expected_strings = filter_rust(as_deref(&strings[array_offset..]), bools);
1647 assert_eq!(actual, expected_strings);
1648
1649 let src = DictionaryArray::<Int32Type>::from_iter(as_deref(&strings));
1651
1652 let src = src.slice(array_offset, array_len);
1653 let src = src
1654 .as_any()
1655 .downcast_ref::<DictionaryArray<Int32Type>>()
1656 .unwrap();
1657
1658 let filtered = filter(src, predicate).unwrap();
1659
1660 let array = filtered
1661 .as_any()
1662 .downcast_ref::<DictionaryArray<Int32Type>>()
1663 .unwrap();
1664
1665 let values = array
1666 .values()
1667 .as_any()
1668 .downcast_ref::<StringArray>()
1669 .unwrap();
1670
1671 let actual: Vec<_> = array
1672 .keys()
1673 .iter()
1674 .map(|key| key.map(|key| values.value(key as usize)))
1675 .collect();
1676
1677 assert_eq!(actual, expected_strings);
1678 }
1679 }
1680
1681 #[test]
1682 fn test_filter_map() {
1683 let mut builder =
1684 MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(4));
1685 builder.keys().append_value("key1");
1687 builder.values().append_value(1);
1688 builder.append(true).unwrap();
1689 builder.keys().append_value("key2");
1690 builder.keys().append_value("key3");
1691 builder.values().append_value(2);
1692 builder.values().append_value(3);
1693 builder.append(true).unwrap();
1694 builder.append(false).unwrap();
1695 builder.keys().append_value("key1");
1696 builder.values().append_value(1);
1697 builder.append(true).unwrap();
1698 let maparray = Arc::new(builder.finish()) as ArrayRef;
1699
1700 let indices = vec![Some(true), Some(false), Some(false), Some(true)]
1701 .into_iter()
1702 .collect::<BooleanArray>();
1703 let got = filter(&maparray, &indices).unwrap();
1704
1705 let mut builder =
1706 MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(2));
1707 builder.keys().append_value("key1");
1708 builder.values().append_value(1);
1709 builder.append(true).unwrap();
1710 builder.keys().append_value("key1");
1711 builder.values().append_value(1);
1712 builder.append(true).unwrap();
1713 let expected = Arc::new(builder.finish()) as ArrayRef;
1714
1715 assert_eq!(&expected, &got);
1716 }
1717
1718 #[test]
1719 fn test_filter_fixed_size_list_arrays() {
1720 let value_data = ArrayData::builder(DataType::Int32)
1721 .len(9)
1722 .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8]))
1723 .build()
1724 .unwrap();
1725 let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 3, false);
1726 let list_data = ArrayData::builder(list_data_type)
1727 .len(3)
1728 .add_child_data(value_data)
1729 .build()
1730 .unwrap();
1731 let array = FixedSizeListArray::from(list_data);
1732
1733 let filter_array = BooleanArray::from(vec![true, false, false]);
1734
1735 let c = filter(&array, &filter_array).unwrap();
1736 let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1737
1738 assert_eq!(filtered.len(), 1);
1739
1740 let list = filtered.value(0);
1741 assert_eq!(
1742 &[0, 1, 2],
1743 list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1744 );
1745
1746 let filter_array = BooleanArray::from(vec![true, false, true]);
1747
1748 let c = filter(&array, &filter_array).unwrap();
1749 let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1750
1751 assert_eq!(filtered.len(), 2);
1752
1753 let list = filtered.value(0);
1754 assert_eq!(
1755 &[0, 1, 2],
1756 list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1757 );
1758 let list = filtered.value(1);
1759 assert_eq!(
1760 &[6, 7, 8],
1761 list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1762 );
1763 }
1764
1765 #[test]
1766 fn test_filter_fixed_size_list_arrays_with_null() {
1767 let value_data = ArrayData::builder(DataType::Int32)
1768 .len(10)
1769 .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
1770 .build()
1771 .unwrap();
1772
1773 let mut null_bits: [u8; 1] = [0; 1];
1777 bit_util::set_bit(&mut null_bits, 0);
1778 bit_util::set_bit(&mut null_bits, 3);
1779 bit_util::set_bit(&mut null_bits, 4);
1780
1781 let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 2, false);
1782 let list_data = ArrayData::builder(list_data_type)
1783 .len(5)
1784 .add_child_data(value_data)
1785 .null_bit_buffer(Some(Buffer::from(null_bits)))
1786 .build()
1787 .unwrap();
1788 let array = FixedSizeListArray::from(list_data);
1789
1790 let filter_array = BooleanArray::from(vec![true, true, false, true, false]);
1791
1792 let c = filter(&array, &filter_array).unwrap();
1793 let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1794
1795 assert_eq!(filtered.len(), 3);
1796
1797 let list = filtered.value(0);
1798 assert_eq!(
1799 &[0, 1],
1800 list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1801 );
1802 assert!(filtered.is_null(1));
1803 let list = filtered.value(2);
1804 assert_eq!(
1805 &[6, 7],
1806 list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1807 );
1808 }
1809
1810 fn test_filter_union_array(array: UnionArray) {
1811 let filter_array = BooleanArray::from(vec![true, false, false]);
1812 let c = filter(&array, &filter_array).unwrap();
1813 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1814
1815 let mut builder = UnionBuilder::new_dense();
1816 builder.append::<Int32Type>("A", 1).unwrap();
1817 let expected_array = builder.build().unwrap();
1818
1819 compare_union_arrays(filtered, &expected_array);
1820
1821 let filter_array = BooleanArray::from(vec![true, false, true]);
1822 let c = filter(&array, &filter_array).unwrap();
1823 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1824
1825 let mut builder = UnionBuilder::new_dense();
1826 builder.append::<Int32Type>("A", 1).unwrap();
1827 builder.append::<Int32Type>("A", 34).unwrap();
1828 let expected_array = builder.build().unwrap();
1829
1830 compare_union_arrays(filtered, &expected_array);
1831
1832 let filter_array = BooleanArray::from(vec![true, true, false]);
1833 let c = filter(&array, &filter_array).unwrap();
1834 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1835
1836 let mut builder = UnionBuilder::new_dense();
1837 builder.append::<Int32Type>("A", 1).unwrap();
1838 builder.append::<Float64Type>("B", 3.2).unwrap();
1839 let expected_array = builder.build().unwrap();
1840
1841 compare_union_arrays(filtered, &expected_array);
1842 }
1843
1844 #[test]
1845 fn test_filter_union_array_dense() {
1846 let mut builder = UnionBuilder::new_dense();
1847 builder.append::<Int32Type>("A", 1).unwrap();
1848 builder.append::<Float64Type>("B", 3.2).unwrap();
1849 builder.append::<Int32Type>("A", 34).unwrap();
1850 let array = builder.build().unwrap();
1851
1852 test_filter_union_array(array);
1853 }
1854
1855 #[test]
1856 fn test_filter_run_union_array_dense() {
1857 let mut builder = UnionBuilder::new_dense();
1858 builder.append::<Int32Type>("A", 1).unwrap();
1859 builder.append::<Int32Type>("A", 3).unwrap();
1860 builder.append::<Int32Type>("A", 34).unwrap();
1861 let array = builder.build().unwrap();
1862
1863 let filter_array = BooleanArray::from(vec![true, true, false]);
1864 let c = filter(&array, &filter_array).unwrap();
1865 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1866
1867 let mut builder = UnionBuilder::new_dense();
1868 builder.append::<Int32Type>("A", 1).unwrap();
1869 builder.append::<Int32Type>("A", 3).unwrap();
1870 let expected = builder.build().unwrap();
1871
1872 assert_eq!(filtered.to_data(), expected.to_data());
1873 }
1874
1875 #[test]
1876 fn test_filter_union_array_dense_with_nulls() {
1877 let mut builder = UnionBuilder::new_dense();
1878 builder.append::<Int32Type>("A", 1).unwrap();
1879 builder.append::<Float64Type>("B", 3.2).unwrap();
1880 builder.append_null::<Float64Type>("B").unwrap();
1881 builder.append::<Int32Type>("A", 34).unwrap();
1882 let array = builder.build().unwrap();
1883
1884 let filter_array = BooleanArray::from(vec![true, true, false, false]);
1885 let c = filter(&array, &filter_array).unwrap();
1886 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1887
1888 let mut builder = UnionBuilder::new_dense();
1889 builder.append::<Int32Type>("A", 1).unwrap();
1890 builder.append::<Float64Type>("B", 3.2).unwrap();
1891 let expected_array = builder.build().unwrap();
1892
1893 compare_union_arrays(filtered, &expected_array);
1894
1895 let filter_array = BooleanArray::from(vec![true, false, true, false]);
1896 let c = filter(&array, &filter_array).unwrap();
1897 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1898
1899 let mut builder = UnionBuilder::new_dense();
1900 builder.append::<Int32Type>("A", 1).unwrap();
1901 builder.append_null::<Float64Type>("B").unwrap();
1902 let expected_array = builder.build().unwrap();
1903
1904 compare_union_arrays(filtered, &expected_array);
1905 }
1906
1907 #[test]
1908 fn test_filter_union_array_sparse() {
1909 let mut builder = UnionBuilder::new_sparse();
1910 builder.append::<Int32Type>("A", 1).unwrap();
1911 builder.append::<Float64Type>("B", 3.2).unwrap();
1912 builder.append::<Int32Type>("A", 34).unwrap();
1913 let array = builder.build().unwrap();
1914
1915 test_filter_union_array(array);
1916 }
1917
1918 #[test]
1919 fn test_filter_union_array_sparse_with_nulls() {
1920 let mut builder = UnionBuilder::new_sparse();
1921 builder.append::<Int32Type>("A", 1).unwrap();
1922 builder.append::<Float64Type>("B", 3.2).unwrap();
1923 builder.append_null::<Float64Type>("B").unwrap();
1924 builder.append::<Int32Type>("A", 34).unwrap();
1925 let array = builder.build().unwrap();
1926
1927 let filter_array = BooleanArray::from(vec![true, false, true, false]);
1928 let c = filter(&array, &filter_array).unwrap();
1929 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1930
1931 let mut builder = UnionBuilder::new_sparse();
1932 builder.append::<Int32Type>("A", 1).unwrap();
1933 builder.append_null::<Float64Type>("B").unwrap();
1934 let expected_array = builder.build().unwrap();
1935
1936 compare_union_arrays(filtered, &expected_array);
1937 }
1938
1939 fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) {
1940 assert_eq!(union1.len(), union2.len());
1941
1942 for i in 0..union1.len() {
1943 let type_id = union1.type_id(i);
1944
1945 let slot1 = union1.value(i);
1946 let slot2 = union2.value(i);
1947
1948 assert_eq!(slot1.is_null(0), slot2.is_null(0));
1949
1950 if !slot1.is_null(0) && !slot2.is_null(0) {
1951 match type_id {
1952 0 => {
1953 let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
1954 assert_eq!(slot1.len(), 1);
1955 let value1 = slot1.value(0);
1956
1957 let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap();
1958 assert_eq!(slot2.len(), 1);
1959 let value2 = slot2.value(0);
1960 assert_eq!(value1, value2);
1961 }
1962 1 => {
1963 let slot1 = slot1.as_any().downcast_ref::<Float64Array>().unwrap();
1964 assert_eq!(slot1.len(), 1);
1965 let value1 = slot1.value(0);
1966
1967 let slot2 = slot2.as_any().downcast_ref::<Float64Array>().unwrap();
1968 assert_eq!(slot2.len(), 1);
1969 let value2 = slot2.value(0);
1970 assert_eq!(value1, value2);
1971 }
1972 _ => unreachable!(),
1973 }
1974 }
1975 }
1976 }
1977
1978 #[test]
1979 fn test_filter_struct() {
1980 let predicate = BooleanArray::from(vec![true, false, true, false]);
1981
1982 let a = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
1983 let a_filtered = Arc::new(StringArray::from(vec!["hello", "world"]));
1984
1985 let b = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
1986 let b_filtered = Arc::new(Int32Array::from(vec![5, 7]));
1987
1988 let null_mask = NullBuffer::from(vec![true, false, false, true]);
1989 let null_mask_filtered = NullBuffer::from(vec![true, false]);
1990
1991 let a_field = Field::new("a", DataType::Utf8, false);
1992 let b_field = Field::new("b", DataType::Int32, false);
1993
1994 let array = StructArray::new(vec![a_field.clone()].into(), vec![a.clone()], None);
1995 let expected =
1996 StructArray::new(vec![a_field.clone()].into(), vec![a_filtered.clone()], None);
1997
1998 let result = filter(&array, &predicate).unwrap();
1999
2000 assert_eq!(result.to_data(), expected.to_data());
2001
2002 let array = StructArray::new(
2003 vec![a_field.clone()].into(),
2004 vec![a.clone()],
2005 Some(null_mask.clone()),
2006 );
2007 let expected = StructArray::new(
2008 vec![a_field.clone()].into(),
2009 vec![a_filtered.clone()],
2010 Some(null_mask_filtered.clone()),
2011 );
2012
2013 let result = filter(&array, &predicate).unwrap();
2014
2015 assert_eq!(result.to_data(), expected.to_data());
2016
2017 let array = StructArray::new(
2018 vec![a_field.clone(), b_field.clone()].into(),
2019 vec![a.clone(), b.clone()],
2020 None,
2021 );
2022 let expected = StructArray::new(
2023 vec![a_field.clone(), b_field.clone()].into(),
2024 vec![a_filtered.clone(), b_filtered.clone()],
2025 None,
2026 );
2027
2028 let result = filter(&array, &predicate).unwrap();
2029
2030 assert_eq!(result.to_data(), expected.to_data());
2031
2032 let array = StructArray::new(
2033 vec![a_field.clone(), b_field.clone()].into(),
2034 vec![a.clone(), b.clone()],
2035 Some(null_mask.clone()),
2036 );
2037
2038 let expected = StructArray::new(
2039 vec![a_field.clone(), b_field.clone()].into(),
2040 vec![a_filtered.clone(), b_filtered.clone()],
2041 Some(null_mask_filtered.clone()),
2042 );
2043
2044 let result = filter(&array, &predicate).unwrap();
2045
2046 assert_eq!(result.to_data(), expected.to_data());
2047 }
2048
2049 #[test]
2050 fn test_filter_empty_struct() {
2051 let fields = arrow_schema::Field::new(
2058 "a",
2059 arrow_schema::DataType::Struct(arrow_schema::Fields::from(vec![
2060 arrow_schema::Field::new("b", arrow_schema::DataType::Int64, true),
2061 arrow_schema::Field::new(
2062 "c",
2063 arrow_schema::DataType::Struct(arrow_schema::Fields::empty()),
2064 true,
2065 ),
2066 ])),
2067 true,
2068 );
2069
2070 let schema = Arc::new(Schema::new(vec![fields]));
2078
2079 let b = Arc::new(Int64Array::from(vec![None, None, None]));
2080 let c = Arc::new(StructArray::new_empty_fields(
2081 3,
2082 Some(NullBuffer::from(vec![true, true, true])),
2083 ));
2084 let a = StructArray::new(
2085 vec![
2086 Field::new("b", DataType::Int64, true),
2087 Field::new("c", DataType::Struct(Fields::empty()), true),
2088 ]
2089 .into(),
2090 vec![b.clone(), c.clone()],
2091 Some(NullBuffer::from(vec![true, true, true])),
2092 );
2093 let record_batch = RecordBatch::try_new(schema, vec![Arc::new(a)]).unwrap();
2094 println!("{record_batch:?}");
2095
2096 let predicate = BooleanArray::from(vec![true, false, true]);
2098 let filtered_batch = filter_record_batch(&record_batch, &predicate).unwrap();
2099
2100 assert_eq!(filtered_batch.num_rows(), 2);
2102 }
2103}