1use std::ops::AddAssign;
21use std::sync::Arc;
22
23use arrow_array::builder::BooleanBufferBuilder;
24use arrow_array::cast::AsArray;
25use arrow_array::types::{
26 ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
27};
28use arrow_array::*;
29use arrow_buffer::{bit_util, ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer};
30use arrow_buffer::{Buffer, MutableBuffer};
31use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
32use arrow_data::transform::MutableArrayData;
33use arrow_data::{ArrayData, ArrayDataBuilder};
34use arrow_schema::*;
35
/// Selectivity (fraction of filter slots that are `true`) above which the
/// filter copies contiguous ranges of selected rows (`SlicesIterator`) instead
/// of visiting selected rows one index at a time (`IndexIterator`).
///
/// Compared with a strict `>` in `IterationStrategy::default_strategy`.
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
43
44#[derive(Debug)]
56pub struct SlicesIterator<'a>(BitSliceIterator<'a>);
57
58impl<'a> SlicesIterator<'a> {
59 pub fn new(filter: &'a BooleanArray) -> Self {
61 Self(filter.values().set_slices())
62 }
63}
64
65impl Iterator for SlicesIterator<'_> {
66 type Item = (usize, usize);
67
68 fn next(&mut self) -> Option<Self::Item> {
69 self.0.next()
70 }
71}
72
/// Iterator over the indices of the `true` value bits of a null-free
/// [`BooleanArray`], yielding exactly `remaining` items.
///
/// `size_hint` reports `remaining` exactly. Callers such as `filter_native` and
/// `filter_bits` hand this iterator to `MutableBuffer::from_trusted_len_iter*`,
/// so the reported length must be honest.
struct IndexIterator<'a> {
    /// Number of indices still to be yielded
    remaining: usize,
    /// Underlying iterator over the positions of the set bits
    iter: BitIndexIterator<'a>,
}

impl<'a> IndexIterator<'a> {
    /// Creates an iterator over the set-bit indices of `filter`.
    ///
    /// `remaining` must equal the number of set bits in `filter`. Callers in
    /// this file pass a count obtained from `filter_count`.
    ///
    /// # Panics
    ///
    /// Panics if `filter` contains nulls.
    fn new(filter: &'a BooleanArray, remaining: usize) -> Self {
        assert_eq!(filter.null_count(), 0);
        let iter = filter.values().set_indices();
        Self { remaining, iter }
    }
}

impl Iterator for IndexIterator<'_> {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining != 0 {
            // Panic rather than return `None` if the bits run out before
            // `remaining` reaches zero: stopping early would violate the exact
            // length promised by `size_hint`.
            let next = self.iter.next().expect("IndexIterator exhausted early");
            self.remaining -= 1;
            return Some(next);
        }
        None
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exact: precisely `remaining` more items will be produced
        (self.remaining, Some(self.remaining))
    }
}
109
110fn filter_count(filter: &BooleanArray) -> usize {
112 filter.values().count_set_bits()
113}
114
/// Boxed closure that filters an [`ArrayData`], as returned by the deprecated
/// `build_filter`.
///
/// Deprecated: `FilterBuilder` / `FilterPredicate` provide a reusable filter instead.
#[deprecated]
pub type Filter<'a> = Box<dyn Fn(&ArrayData) -> ArrayData + 'a>;
120
121#[deprecated]
130#[allow(deprecated)]
131pub fn build_filter(filter: &BooleanArray) -> Result<Filter, ArrowError> {
132 let iter = SlicesIterator::new(filter);
133 let filter_count = filter_count(filter);
134 let chunks = iter.collect::<Vec<_>>();
135
136 Ok(Box::new(move |array: &ArrayData| {
137 match filter_count {
138 len if len == array.len() => array.clone(),
140 0 => ArrayData::new_empty(array.data_type()),
141 _ => {
142 let mut mutable = MutableArrayData::new(vec![array], false, filter_count);
143 chunks
144 .iter()
145 .for_each(|(start, end)| mutable.extend(0, *start, *end));
146 mutable.freeze()
147 }
148 }
149 }))
150}
151
152pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
154 let nulls = filter.nulls().unwrap();
155 let mask = filter.values() & nulls.inner();
156 BooleanArray::new(mask, None)
157}
158
159pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
181 let mut filter_builder = FilterBuilder::new(predicate);
182
183 if multiple_arrays(values.data_type()) {
184 filter_builder = filter_builder.optimize();
187 }
188
189 let predicate = filter_builder.build();
190
191 filter_array(values, &predicate)
192}
193
194fn multiple_arrays(data_type: &DataType) -> bool {
195 match data_type {
196 DataType::Struct(fields) => {
197 fields.len() > 1 || fields.len() == 1 && multiple_arrays(fields[0].data_type())
198 }
199 DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
200 _ => false,
201 }
202}
203
204pub fn filter_record_batch(
209 record_batch: &RecordBatch,
210 predicate: &BooleanArray,
211) -> Result<RecordBatch, ArrowError> {
212 let mut filter_builder = FilterBuilder::new(predicate);
213 if record_batch.num_columns() > 1 {
214 filter_builder = filter_builder.optimize();
217 }
218 let filter = filter_builder.build();
219
220 let filtered_arrays = record_batch
221 .columns()
222 .iter()
223 .map(|a| filter_array(a, &filter))
224 .collect::<Result<Vec<_>, _>>()?;
225 let options = RecordBatchOptions::default().with_row_count(Some(filter.count()));
226 RecordBatch::try_new_with_options(record_batch.schema(), filtered_arrays, &options)
227}
228
229#[derive(Debug)]
231pub struct FilterBuilder {
232 filter: BooleanArray,
233 count: usize,
234 strategy: IterationStrategy,
235}
236
237impl FilterBuilder {
238 pub fn new(filter: &BooleanArray) -> Self {
240 let filter = match filter.null_count() {
241 0 => filter.clone(),
242 _ => prep_null_mask_filter(filter),
243 };
244
245 let count = filter_count(&filter);
246 let strategy = IterationStrategy::default_strategy(filter.len(), count);
247
248 Self {
249 filter,
250 count,
251 strategy,
252 }
253 }
254
255 pub fn optimize(mut self) -> Self {
261 match self.strategy {
262 IterationStrategy::SlicesIterator => {
263 let slices = SlicesIterator::new(&self.filter).collect();
264 self.strategy = IterationStrategy::Slices(slices)
265 }
266 IterationStrategy::IndexIterator => {
267 let indices = IndexIterator::new(&self.filter, self.count).collect();
268 self.strategy = IterationStrategy::Indices(indices)
269 }
270 _ => {}
271 }
272 self
273 }
274
275 pub fn build(self) -> FilterPredicate {
277 FilterPredicate {
278 filter: self.filter,
279 count: self.count,
280 strategy: self.strategy,
281 }
282 }
283}
284
285#[derive(Debug)]
287enum IterationStrategy {
288 SlicesIterator,
290 IndexIterator,
292 Indices(Vec<usize>),
294 Slices(Vec<(usize, usize)>),
296 All,
298 None,
300}
301
302impl IterationStrategy {
303 fn default_strategy(filter_length: usize, filter_count: usize) -> Self {
306 if filter_length == 0 || filter_count == 0 {
307 return IterationStrategy::None;
308 }
309
310 if filter_count == filter_length {
311 return IterationStrategy::All;
312 }
313
314 let selectivity_frac = filter_count as f64 / filter_length as f64;
319 if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD {
320 return IterationStrategy::SlicesIterator;
321 }
322 IterationStrategy::IndexIterator
323 }
324}
325
326#[derive(Debug)]
328pub struct FilterPredicate {
329 filter: BooleanArray,
330 count: usize,
331 strategy: IterationStrategy,
332}
333
334impl FilterPredicate {
335 pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef, ArrowError> {
337 filter_array(values, self)
338 }
339
340 pub fn count(&self) -> usize {
342 self.count
343 }
344}
345
/// Applies `predicate` to `values` by dispatching on the data type to a
/// specialised kernel.
///
/// Types without a dedicated kernel fall back to copying the selected ranges
/// with [`MutableArrayData`].
///
/// # Errors
///
/// Returns [`ArrowError::InvalidArgumentError`] if the predicate is longer than
/// `values`. Also propagates errors from the run-end, struct and sparse-union kernels.
///
/// # Panics
///
/// Panics (`unimplemented!`) if a run-end-encoded or dictionary array's index
/// type is not matched by the corresponding downcast macro.
fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> {
    if predicate.filter.len() > values.len() {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Filter predicate of length {} is larger than target array of length {}",
            predicate.filter.len(),
            values.len()
        )));
    }

    match predicate.strategy {
        // Nothing selected
        IterationStrategy::None => Ok(new_empty_array(values.data_type())),
        // Every predicate slot is selected. A predicate shorter than `values`
        // therefore selects exactly the leading `predicate.count` rows.
        IterationStrategy::All => Ok(values.slice(0, predicate.count)),
        _ => downcast_primitive_array! {
            values => Ok(Arc::new(filter_primitive(values, predicate))),
            DataType::Boolean => {
                let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
                Ok(Arc::new(filter_boolean(values, predicate)))
            }
            DataType::Utf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i32>(), predicate)))
            }
            DataType::LargeUtf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
            }
            DataType::Utf8View => {
                Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
            }
            DataType::Binary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
            }
            DataType::LargeBinary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
            }
            DataType::BinaryView => {
                Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
            }
            DataType::FixedSizeBinary(_) => {
                Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate)))
            }
            DataType::RunEndEncoded(_, _) => {
                downcast_run_array!{
                    values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
                    t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t)
                }
            }
            DataType::Dictionary(_, _) => downcast_dictionary_array! {
                values => Ok(Arc::new(filter_dict(values, predicate))),
                t => unimplemented!("Filter not supported for dictionary type {:?}", t)
            }
            DataType::Struct(_) => {
                Ok(Arc::new(filter_struct(values.as_struct(), predicate)?))
            }
            DataType::Union(_, UnionMode::Sparse) => {
                Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?))
            }
            _ => {
                // Generic fallback: copy the selected ranges via MutableArrayData
                let data = values.to_data();
                let mut mutable = MutableArrayData::new(
                    vec![&data],
                    false,
                    predicate.count,
                );

                // MutableArrayData copies ranges, so the index-based strategies
                // also fall back to walking the filter's set ranges here
                match &predicate.strategy {
                    IterationStrategy::Slices(slices) => {
                        slices
                            .iter()
                            .for_each(|(start, end)| mutable.extend(0, *start, *end));
                    }
                    _ => {
                        let iter = SlicesIterator::new(&predicate.filter);
                        iter.for_each(|(start, end)| mutable.extend(0, start, end));
                    }
                }

                let data = mutable.freeze();
                Ok(make_array(data))
            }
        },
    }
}
429
430fn filter_run_end_array<R: RunEndIndexType>(
432 array: &RunArray<R>,
433 predicate: &FilterPredicate,
434) -> Result<RunArray<R>, ArrowError>
435where
436 R::Native: Into<i64> + From<bool>,
437 R::Native: AddAssign,
438{
439 let run_ends: &RunEndBuffer<R::Native> = array.run_ends();
440 let mut new_run_ends = vec![R::default_value(); run_ends.len()];
441
442 let mut start = 0u64;
443 let mut j = 0;
444 let mut count = R::default_value();
445 let filter_values = predicate.filter.values();
446 let run_ends = run_ends.inner();
447
448 let pred: BooleanArray = BooleanBuffer::collect_bool(run_ends.len(), |i| {
449 let mut keep = false;
450 let mut end = run_ends[i].into() as u64;
451 let difference = end.saturating_sub(filter_values.len() as u64);
452 end -= difference;
453
454 for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
456 count += R::Native::from(pred);
457 keep |= pred
458 }
459 new_run_ends[j] = count;
461 j += keep as usize;
462
463 start = end;
464 keep
465 })
466 .into();
467
468 new_run_ends.truncate(j);
469
470 let values = array.values();
471 let values = filter(&values, &pred)?;
472
473 let run_ends = PrimitiveArray::<R>::new(new_run_ends.into(), None);
474 RunArray::try_new(&run_ends, &values)
475}
476
477fn filter_null_mask(
484 nulls: Option<&NullBuffer>,
485 predicate: &FilterPredicate,
486) -> Option<(usize, Buffer)> {
487 let nulls = nulls?;
488 if nulls.null_count() == 0 {
489 return None;
490 }
491
492 let nulls = filter_bits(nulls.inner(), predicate);
493 let null_count = predicate.count - nulls.count_set_bits_offset(0, predicate.count);
496
497 if null_count == 0 {
498 return None;
499 }
500
501 Some((null_count, nulls))
502}
503
504fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer {
506 let src = buffer.values();
507 let offset = buffer.offset();
508
509 match &predicate.strategy {
510 IterationStrategy::IndexIterator => {
511 let bits = IndexIterator::new(&predicate.filter, predicate.count)
512 .map(|src_idx| bit_util::get_bit(src, src_idx + offset));
513
514 unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
516 }
517 IterationStrategy::Indices(indices) => {
518 let bits = indices
519 .iter()
520 .map(|src_idx| bit_util::get_bit(src, *src_idx + offset));
521
522 unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
524 }
525 IterationStrategy::SlicesIterator => {
526 let mut builder = BooleanBufferBuilder::new(predicate.count);
527 for (start, end) in SlicesIterator::new(&predicate.filter) {
528 builder.append_packed_range(start + offset..end + offset, src)
529 }
530 builder.into()
531 }
532 IterationStrategy::Slices(slices) => {
533 let mut builder = BooleanBufferBuilder::new(predicate.count);
534 for (start, end) in slices {
535 builder.append_packed_range(*start + offset..*end + offset, src)
536 }
537 builder.into()
538 }
539 IterationStrategy::All | IterationStrategy::None => unreachable!(),
540 }
541}
542
543fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
545 let values = filter_bits(array.values(), predicate);
546
547 let mut builder = ArrayDataBuilder::new(DataType::Boolean)
548 .len(predicate.count)
549 .add_buffer(values);
550
551 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
552 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
553 }
554
555 let data = unsafe { builder.build_unchecked() };
556 BooleanArray::from(data)
557}
558
559#[inline(never)]
560fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer {
561 assert!(values.len() >= predicate.filter.len());
562
563 match &predicate.strategy {
564 IterationStrategy::SlicesIterator => {
565 let mut buffer = Vec::with_capacity(predicate.count);
566 for (start, end) in SlicesIterator::new(&predicate.filter) {
567 buffer.extend_from_slice(&values[start..end]);
568 }
569 buffer.into()
570 }
571 IterationStrategy::Slices(slices) => {
572 let mut buffer = Vec::with_capacity(predicate.count);
573 for (start, end) in slices {
574 buffer.extend_from_slice(&values[*start..*end]);
575 }
576 buffer.into()
577 }
578 IterationStrategy::IndexIterator => {
579 let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| values[x]);
580
581 unsafe { MutableBuffer::from_trusted_len_iter(iter) }.into()
583 }
584 IterationStrategy::Indices(indices) => {
585 let iter = indices.iter().map(|x| values[*x]);
586 iter.collect::<Vec<_>>().into()
587 }
588 IterationStrategy::All | IterationStrategy::None => unreachable!(),
589 }
590}
591
592fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
594where
595 T: ArrowPrimitiveType,
596{
597 let values = array.values();
598 let buffer = filter_native(values, predicate);
599 let mut builder = ArrayDataBuilder::new(array.data_type().clone())
600 .len(predicate.count)
601 .add_buffer(buffer);
602
603 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
604 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
605 }
606
607 let data = unsafe { builder.build_unchecked() };
608 PrimitiveArray::from(data)
609}
610
/// Incremental builder for the offsets and value bytes of a filtered
/// [`GenericByteArray`], reading from a source array's offsets and value bytes.
///
/// `filter_bytes` extends the offsets before the values. The value-extending
/// methods reserve `cur_offset` bytes, which is only the right size once the
/// offset pass has accumulated the output length.
struct FilterBytes<'a, OffsetSize> {
    /// Offsets of the source array
    src_offsets: &'a [OffsetSize],
    /// Value bytes of the source array
    src_values: &'a [u8],
    /// Offsets of the output array, starting with a leading zero
    dst_offsets: Vec<OffsetSize>,
    /// Value bytes of the output array
    dst_values: Vec<u8>,
    /// Running end offset of the output (total bytes selected so far)
    cur_offset: OffsetSize,
}

impl<'a, OffsetSize> FilterBytes<'a, OffsetSize>
where
    OffsetSize: OffsetSizeTrait,
{
    /// Creates a builder over `array` with room for `capacity` output offsets
    /// plus the leading zero offset.
    fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self
    where
        T: ByteArrayType<Offset = OffsetSize>,
    {
        let dst_values = Vec::new();
        let mut dst_offsets: Vec<OffsetSize> = Vec::with_capacity(capacity + 1);
        let cur_offset = OffsetSize::from_usize(0).unwrap();

        // Output offsets always begin at zero
        dst_offsets.push(cur_offset);

        Self {
            src_offsets: array.value_offsets(),
            src_values: array.value_data(),
            dst_offsets,
            dst_values,
            cur_offset,
        }
    }

    /// Returns the source byte offset stored at `idx`.
    #[inline]
    fn get_value_offset(&self, idx: usize) -> usize {
        self.src_offsets[idx].as_usize()
    }

    /// Returns the source byte range `(start, end)` of element `idx`, together
    /// with its length as an `OffsetSize`.
    ///
    /// Panics if the offsets are out of bounds or the length does not fit in
    /// `OffsetSize`.
    #[inline]
    fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
        let start = self.get_value_offset(idx);
        let end = self.get_value_offset(idx + 1);
        let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
        (start, end, len)
    }

    /// Appends one output offset per source index in `iter`.
    fn extend_offsets_idx(&mut self, iter: impl Iterator<Item = usize>) {
        // Presumably the fields are read directly (not via `get_value_range`)
        // so that the closure borrows only `src_offsets` and `cur_offset`.
        // A `&self` method call would conflict with the mutable borrow of
        // `dst_offsets` held by `extend`.
        self.dst_offsets.extend(iter.map(|idx| {
            let start = self.src_offsets[idx].as_usize();
            let end = self.src_offsets[idx + 1].as_usize();
            let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
            self.cur_offset += len;

            self.cur_offset
        }));
    }

    /// Copies the value bytes of each source index in `iter`.
    fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
        // `cur_offset` already holds the total output size after the offset pass
        self.dst_values.reserve_exact(self.cur_offset.as_usize());

        for idx in iter {
            let start = self.src_offsets[idx].as_usize();
            let end = self.src_offsets[idx + 1].as_usize();
            self.dst_values
                .extend_from_slice(&self.src_values[start..end]);
        }
    }

    /// Appends one output offset per element of each `[start, end)` range in
    /// `iter`. `count` is the total number of offsets expected.
    fn extend_offsets_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>, count: usize) {
        self.dst_offsets.reserve_exact(count);
        for (start, end) in iter {
            for idx in start..end {
                let (_, _, len) = self.get_value_range(idx);
                self.cur_offset += len;
                self.dst_offsets.push(self.cur_offset);
            }
        }
    }

    /// Copies the value bytes of each `[start, end)` element range in `iter`
    /// as one contiguous slice per range.
    fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
        // `cur_offset` already holds the total output size after the offset pass
        self.dst_values.reserve_exact(self.cur_offset.as_usize());

        for (start, end) in iter {
            let value_start = self.get_value_offset(start);
            let value_end = self.get_value_offset(end);
            self.dst_values
                .extend_from_slice(&self.src_values[value_start..value_end]);
        }
    }
}
709
710fn filter_bytes<T>(array: &GenericByteArray<T>, predicate: &FilterPredicate) -> GenericByteArray<T>
715where
716 T: ByteArrayType,
717{
718 let mut filter = FilterBytes::new(predicate.count, array);
719
720 match &predicate.strategy {
721 IterationStrategy::SlicesIterator => {
722 filter.extend_offsets_slices(SlicesIterator::new(&predicate.filter), predicate.count);
723 filter.extend_slices(SlicesIterator::new(&predicate.filter))
724 }
725 IterationStrategy::Slices(slices) => {
726 filter.extend_offsets_slices(slices.iter().cloned(), predicate.count);
727 filter.extend_slices(slices.iter().cloned())
728 }
729 IterationStrategy::IndexIterator => {
730 filter.extend_offsets_idx(IndexIterator::new(&predicate.filter, predicate.count));
731 filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count))
732 }
733 IterationStrategy::Indices(indices) => {
734 filter.extend_offsets_idx(indices.iter().cloned());
735 filter.extend_idx(indices.iter().cloned())
736 }
737 IterationStrategy::All | IterationStrategy::None => unreachable!(),
738 }
739
740 let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
741 .len(predicate.count)
742 .add_buffer(filter.dst_offsets.into())
743 .add_buffer(filter.dst_values.into());
744
745 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
746 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
747 }
748
749 let data = unsafe { builder.build_unchecked() };
750 GenericByteArray::from(data)
751}
752
753fn filter_byte_view<T: ByteViewType>(
755 array: &GenericByteViewArray<T>,
756 predicate: &FilterPredicate,
757) -> GenericByteViewArray<T> {
758 let new_view_buffer = filter_native(array.views(), predicate);
759
760 let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
761 .len(predicate.count)
762 .add_buffer(new_view_buffer)
763 .add_buffers(array.data_buffers().to_vec());
764
765 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
766 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
767 }
768
769 GenericByteViewArray::from(unsafe { builder.build_unchecked() })
770}
771
772fn filter_fixed_size_binary(
773 array: &FixedSizeBinaryArray,
774 predicate: &FilterPredicate,
775) -> FixedSizeBinaryArray {
776 let values: &[u8] = array.values();
777 let value_length = array.value_length() as usize;
778 let calculate_offset_from_index = |index: usize| index * value_length;
779 let buffer = match &predicate.strategy {
780 IterationStrategy::SlicesIterator => {
781 let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
782 for (start, end) in SlicesIterator::new(&predicate.filter) {
783 buffer.extend_from_slice(
784 &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)],
785 );
786 }
787 buffer
788 }
789 IterationStrategy::Slices(slices) => {
790 let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
791 for (start, end) in slices {
792 buffer.extend_from_slice(
793 &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)],
794 );
795 }
796 buffer
797 }
798 IterationStrategy::IndexIterator => {
799 let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| {
800 &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)]
801 });
802
803 let mut buffer = MutableBuffer::new(predicate.count * value_length);
804 iter.for_each(|item| buffer.extend_from_slice(item));
805 buffer
806 }
807 IterationStrategy::Indices(indices) => {
808 let iter = indices.iter().map(|x| {
809 &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)]
810 });
811
812 let mut buffer = MutableBuffer::new(predicate.count * value_length);
813 iter.for_each(|item| buffer.extend_from_slice(item));
814 buffer
815 }
816 IterationStrategy::All | IterationStrategy::None => unreachable!(),
817 };
818 let mut builder = ArrayDataBuilder::new(array.data_type().clone())
819 .len(predicate.count)
820 .add_buffer(buffer.into());
821
822 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
823 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
824 }
825
826 let data = unsafe { builder.build_unchecked() };
827 FixedSizeBinaryArray::from(data)
828}
829
830fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
832where
833 T: ArrowDictionaryKeyType,
834 T::Native: num::Num,
835{
836 let builder = filter_primitive::<T>(array.keys(), predicate)
837 .into_data()
838 .into_builder()
839 .data_type(array.data_type().clone())
840 .child_data(vec![array.values().to_data()]);
841
842 DictionaryArray::from(unsafe { builder.build_unchecked() })
845}
846
847fn filter_struct(
849 array: &StructArray,
850 predicate: &FilterPredicate,
851) -> Result<StructArray, ArrowError> {
852 let columns = array
853 .columns()
854 .iter()
855 .map(|column| filter_array(column, predicate))
856 .collect::<Result<_, _>>()?;
857
858 let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
859 let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
860
861 Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
862 } else {
863 None
864 };
865
866 Ok(unsafe {
867 StructArray::new_unchecked_with_length(
868 array.fields().clone(),
869 columns,
870 nulls,
871 predicate.count(),
872 )
873 })
874}
875
876fn filter_sparse_union(
878 array: &UnionArray,
879 predicate: &FilterPredicate,
880) -> Result<UnionArray, ArrowError> {
881 let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
882 unreachable!()
883 };
884
885 let type_ids = filter_primitive(&Int8Array::new(array.type_ids().clone(), None), predicate);
886
887 let children = fields
888 .iter()
889 .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
890 .collect::<Result<_, _>>()?;
891
892 Ok(unsafe {
893 UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
894 })
895}
896
897#[cfg(test)]
898mod tests {
899 use arrow_array::builder::*;
900 use arrow_array::cast::as_run_array;
901 use arrow_array::types::*;
902 use rand::distr::uniform::{UniformSampler, UniformUsize};
903 use rand::distr::{Alphanumeric, StandardUniform};
904 use rand::prelude::*;
905 use rand::rng;
906
907 use super::*;
908
// Generates a test that filters a four-element array built by `$data` with
// `[true, false, true, false]` and expects two values, 1 and 3.
macro_rules! def_temporal_test {
    ($test:ident, $array_type: ident, $data: expr) => {
        #[test]
        fn $test() {
            let input = $data;
            let mask = BooleanArray::from(vec![true, false, true, false]);
            let output = filter(&input, &mask).unwrap();
            let typed = output
                .as_ref()
                .as_any()
                .downcast_ref::<$array_type>()
                .unwrap();
            assert_eq!(typed.len(), 2);
            assert_eq!(typed.value(0), 1);
            assert_eq!(typed.value(1), 3);
        }
    };
}
923
// Instantiate the temporal filter test for the date, time, duration and
// timestamp array types below, each built from the values [1, 2, 3, 4].
def_temporal_test!(
    test_filter_date32,
    Date32Array,
    Date32Array::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
    test_filter_date64,
    Date64Array,
    Date64Array::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
    test_filter_time32_second,
    Time32SecondArray,
    Time32SecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
    test_filter_time32_millisecond,
    Time32MillisecondArray,
    Time32MillisecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
    test_filter_time64_microsecond,
    Time64MicrosecondArray,
    Time64MicrosecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
    test_filter_time64_nanosecond,
    Time64NanosecondArray,
    Time64NanosecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
    test_filter_duration_second,
    DurationSecondArray,
    DurationSecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
    test_filter_duration_millisecond,
    DurationMillisecondArray,
    DurationMillisecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
    test_filter_duration_microsecond,
    DurationMicrosecondArray,
    DurationMicrosecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
    test_filter_duration_nanosecond,
    DurationNanosecondArray,
    DurationNanosecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
    test_filter_timestamp_second,
    TimestampSecondArray,
    TimestampSecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
    test_filter_timestamp_millisecond,
    TimestampMillisecondArray,
    TimestampMillisecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
    test_filter_timestamp_microsecond,
    TimestampMicrosecondArray,
    TimestampMicrosecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
    test_filter_timestamp_nanosecond,
    TimestampNanosecondArray,
    TimestampNanosecondArray::from(vec![1, 2, 3, 4])
);
994
#[test]
fn test_filter_array_slice() {
    // Drop the first element so the filter runs against a non-zero offset
    let sliced = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
    let mask = BooleanArray::from(vec![true, false, false, true]);
    let filtered = filter(&sliced, &mask).unwrap();
    let values = filtered.as_primitive::<Int32Type>();
    assert_eq!(values.len(), 2);
    assert_eq!(values.value(0), 6);
    assert_eq!(values.value(1), 9);
}
1008
#[test]
fn test_filter_array_low_density() {
    // Only rows 65 and 67 of 67 are selected, a sparse (index-based) filter
    let mut data_values: Vec<i32> = (1..=65).collect();
    let mut filter_values: Vec<bool> = (1..=65).map(|i| i % 65 == 0).collect();
    data_values.extend_from_slice(&[66, 67]);
    filter_values.extend_from_slice(&[false, true]);

    let input = Int32Array::from(data_values);
    let mask = BooleanArray::from(filter_values);
    let filtered = filter(&input, &mask).unwrap();
    let values = filtered.as_primitive::<Int32Type>();
    assert_eq!(values.len(), 2);
    assert_eq!(values.value(0), 65);
    assert_eq!(values.value(1), 67);
}
1025
#[test]
fn test_filter_array_high_density() {
    // All but two of 69 rows are selected, a dense (range-based) filter with nulls
    let mut data_values: Vec<Option<i32>> = (1..=65).map(Some).collect();
    let mut filter_values: Vec<bool> = (1..=65).map(|i| i % 65 != 0).collect();
    data_values[1] = None;
    data_values.extend_from_slice(&[Some(66), None, Some(67), None]);
    filter_values.extend_from_slice(&[false, true, true, true]);

    let input = Int32Array::from(data_values);
    let mask = BooleanArray::from(filter_values);
    let filtered = filter(&input, &mask).unwrap();
    let values = filtered.as_primitive::<Int32Type>();
    assert_eq!(values.len(), 67);
    assert_eq!(values.null_count(), 3);
    assert_eq!(values.value(0), 1);
    assert!(values.is_null(1));
    assert_eq!(values.value(63), 64);
    assert!(values.is_null(64));
    assert_eq!(values.value(65), 67);
}
1050
#[test]
fn test_filter_string_array_simple() {
    let input = StringArray::from(vec!["hello", " ", "world", "!"]);
    let mask = BooleanArray::from(vec![true, false, true, false]);
    let filtered = filter(&input, &mask).unwrap();
    let strings = filtered.as_string::<i32>();
    assert_eq!(strings.len(), 2);
    assert_eq!(strings.value(0), "hello");
    assert_eq!(strings.value(1), "world");
}
1061
#[test]
fn test_filter_primitive_array_with_null() {
    let input = Int32Array::from(vec![Some(5), None]);
    let mask = BooleanArray::from(vec![false, true]);
    let filtered = filter(&input, &mask).unwrap();
    let values = filtered.as_primitive::<Int32Type>();
    assert_eq!(values.len(), 1);
    assert!(values.is_null(0));
}
1071
#[test]
fn test_filter_string_array_with_null() {
    let input = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
    let mask = BooleanArray::from(vec![true, false, false, true]);
    let filtered = filter(&input, &mask).unwrap();
    let strings = filtered.as_string::<i32>();
    assert_eq!(strings.len(), 2);
    assert_eq!(strings.value(0), "hello");
    assert!(strings.is_valid(0));
    assert!(strings.is_null(1));
}
1083
#[test]
fn test_filter_binary_array_with_null() {
    let items: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
    let input = BinaryArray::from(items);
    let mask = BooleanArray::from(vec![true, false, false, true]);
    let filtered = filter(&input, &mask).unwrap();
    let binaries = filtered.as_binary::<i32>();
    assert_eq!(binaries.len(), 2);
    assert_eq!(binaries.value(0), b"hello");
    assert!(binaries.is_valid(0));
    assert!(binaries.is_null(1));
}
1096
1097 fn _test_filter_byte_view<T>()
1098 where
1099 T: ByteViewType,
1100 str: AsRef<T::Native>,
1101 T::Native: PartialEq,
1102 {
1103 let array = {
1104 let mut builder = GenericByteViewBuilder::<T>::new();
1106 builder.append_value("hello");
1107 builder.append_value("world");
1108 builder.append_null();
1109 builder.append_value("large payload over 12 bytes");
1110 builder.append_value("lulu");
1111 builder.finish()
1112 };
1113
1114 {
1115 let predicate = BooleanArray::from(vec![true, false, true, true, false]);
1116 let actual = filter(&array, &predicate).unwrap();
1117
1118 assert_eq!(actual.len(), 3);
1119
1120 let expected = {
1121 let mut builder = GenericByteViewBuilder::<T>::new();
1123 builder.append_value("hello");
1124 builder.append_null();
1125 builder.append_value("large payload over 12 bytes");
1126 builder.finish()
1127 };
1128
1129 assert_eq!(actual.as_ref(), &expected);
1130 }
1131
1132 {
1133 let predicate = BooleanArray::from(vec![true, false, false, false, true]);
1134 let actual = filter(&array, &predicate).unwrap();
1135
1136 assert_eq!(actual.len(), 2);
1137
1138 let expected = {
1139 let mut builder = GenericByteViewBuilder::<T>::new();
1141 builder.append_value("hello");
1142 builder.append_value("lulu");
1143 builder.finish()
1144 };
1145
1146 assert_eq!(actual.as_ref(), &expected);
1147 }
1148 }
1149
#[test]
fn test_filter_string_view() {
    // Shared byte-view scenario, instantiated for UTF-8 views
    _test_filter_byte_view::<StringViewType>();
}
1154
#[test]
fn test_filter_binary_view() {
    // Shared byte-view scenario, instantiated for binary views
    _test_filter_byte_view::<BinaryViewType>();
}
1159
#[test]
fn test_filter_fixed_binary() {
    let v1 = [1_u8, 2];
    let v2 = [3_u8, 4];
    let v3 = [5_u8, 6];
    let input = FixedSizeBinaryArray::from(vec![&v1, &v2, &v3]);

    // Mixed selection: the lazy and optimized predicates must agree
    let mask = BooleanArray::from(vec![true, false, true]);
    let lazy_out = filter(&input, &mask).unwrap();
    let lazy = lazy_out.as_fixed_size_binary();
    assert_eq!(lazy.len(), 2);
    assert_eq!(lazy.value(0), &v1);
    assert_eq!(lazy.value(1), &v3);
    let optimized_out = FilterBuilder::new(&mask)
        .optimize()
        .build()
        .filter(&input)
        .unwrap();
    assert_eq!(lazy, optimized_out.as_fixed_size_binary());

    // Nothing selected
    let mask = BooleanArray::from(vec![false, false, false]);
    let none_out = filter(&input, &mask).unwrap();
    assert_eq!(none_out.as_fixed_size_binary().len(), 0);

    // Everything selected
    let mask = BooleanArray::from(vec![true, true, true]);
    let all_out = filter(&input, &mask).unwrap();
    let all = all_out.as_fixed_size_binary();
    assert_eq!(all.len(), 3);
    assert_eq!(all.value(0), &v1);
    assert_eq!(all.value(1), &v2);
    assert_eq!(all.value(2), &v3);

    // Only the last row, again comparing lazy and optimized predicates
    let mask = BooleanArray::from(vec![false, false, true]);
    let last_out = filter(&input, &mask).unwrap();
    let last = last_out.as_fixed_size_binary();
    assert_eq!(last.len(), 1);
    assert_eq!(last.value(0), &v3);
    let last_optimized_out = FilterBuilder::new(&mask)
        .optimize()
        .build()
        .filter(&input)
        .unwrap();
    assert_eq!(last, last_optimized_out.as_fixed_size_binary());
}
1231
1232 #[test]
1233 fn test_filter_array_slice_with_null() {
1234 let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
1235 let b = BooleanArray::from(vec![true, false, false, true]);
1236 let c = filter(&a, &b).unwrap();
1240 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1241 assert_eq!(2, d.len());
1242 assert!(d.is_null(0));
1243 assert!(!d.is_null(1));
1244 assert_eq!(9, d.value(1));
1245 }
1246
1247 #[test]
1248 fn test_filter_run_end_encoding_array() {
1249 let run_ends = Int64Array::from(vec![2, 3, 8]);
1250 let values = Int64Array::from(vec![7, -2, 9]);
1251 let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1252 let b = BooleanArray::from(vec![true, false, true, false, true, false, true, false]);
1253 let c = filter(&a, &b).unwrap();
1254 let actual: &RunArray<Int64Type> = as_run_array(&c);
1255 assert_eq!(4, actual.len());
1256
1257 let expected = RunArray::try_new(
1258 &Int64Array::from(vec![1, 2, 4]),
1259 &Int64Array::from(vec![7, -2, 9]),
1260 )
1261 .expect("Failed to make expected RunArray test is broken");
1262
1263 assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1264 assert_eq!(actual.values(), expected.values())
1265 }
1266
1267 #[test]
1268 fn test_filter_run_end_encoding_array_remove_value() {
1269 let run_ends = Int32Array::from(vec![2, 3, 8, 10]);
1270 let values = Int32Array::from(vec![7, -2, 9, -8]);
1271 let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1272 let b = BooleanArray::from(vec![
1273 false, true, false, false, true, false, true, false, false, false,
1274 ]);
1275 let c = filter(&a, &b).unwrap();
1276 let actual: &RunArray<Int32Type> = as_run_array(&c);
1277 assert_eq!(3, actual.len());
1278
1279 let expected =
1280 RunArray::try_new(&Int32Array::from(vec![1, 3]), &Int32Array::from(vec![7, 9]))
1281 .expect("Failed to make expected RunArray test is broken");
1282
1283 assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1284 assert_eq!(actual.values(), expected.values())
1285 }
1286
1287 #[test]
1288 fn test_filter_run_end_encoding_array_remove_all_but_one() {
1289 let run_ends = Int16Array::from(vec![2, 3, 8, 10]);
1290 let values = Int16Array::from(vec![7, -2, 9, -8]);
1291 let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1292 let b = BooleanArray::from(vec![
1293 false, false, false, false, false, false, true, false, false, false,
1294 ]);
1295 let c = filter(&a, &b).unwrap();
1296 let actual: &RunArray<Int16Type> = as_run_array(&c);
1297 assert_eq!(1, actual.len());
1298
1299 let expected = RunArray::try_new(&Int16Array::from(vec![1]), &Int16Array::from(vec![9]))
1300 .expect("Failed to make expected RunArray test is broken");
1301
1302 assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1303 assert_eq!(actual.values(), expected.values())
1304 }
1305
1306 #[test]
1307 fn test_filter_run_end_encoding_array_empty() {
1308 let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1309 let values = Int64Array::from(vec![7, -2, 9, -8]);
1310 let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1311 let b = BooleanArray::from(vec![
1312 false, false, false, false, false, false, false, false, false, false,
1313 ]);
1314 let c = filter(&a, &b).unwrap();
1315 let actual: &RunArray<Int64Type> = as_run_array(&c);
1316 assert_eq!(0, actual.len());
1317 }
1318
1319 #[test]
1320 fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() {
1321 let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1322 let values = Int64Array::from(vec![7, -2, 9, -8]);
1323 let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1324 let b = BooleanArray::from(vec![false, true, true]);
1325 let c = filter(&a, &b).unwrap();
1326 let actual: &RunArray<Int64Type> = as_run_array(&c);
1327 assert_eq!(2, actual.len());
1328
1329 let expected = RunArray::try_new(
1330 &Int64Array::from(vec![1, 2]),
1331 &Int64Array::from(vec![7, -2]),
1332 )
1333 .expect("Failed to make expected RunArray test is broken");
1334
1335 assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1336 assert_eq!(actual.values(), expected.values())
1337 }
1338
1339 #[test]
1340 fn test_filter_dictionary_array() {
1341 let values = [Some("hello"), None, Some("world"), Some("!")];
1342 let a: Int8DictionaryArray = values.iter().copied().collect();
1343 let b = BooleanArray::from(vec![false, true, true, false]);
1344 let c = filter(&a, &b).unwrap();
1345 let d = c
1346 .as_ref()
1347 .as_any()
1348 .downcast_ref::<Int8DictionaryArray>()
1349 .unwrap();
1350 let value_array = d.values();
1351 let values = value_array.as_any().downcast_ref::<StringArray>().unwrap();
1352 assert_eq!(3, values.len());
1354 assert_eq!(2, d.len());
1356 assert!(d.is_null(0));
1357 assert_eq!("world", values.value(d.keys().value(1) as usize));
1358 }
1359
1360 #[test]
1361 fn test_filter_list_array() {
1362 let value_data = ArrayData::builder(DataType::Int32)
1363 .len(8)
1364 .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
1365 .build()
1366 .unwrap();
1367
1368 let value_offsets = Buffer::from_slice_ref([0i64, 3, 6, 8, 8]);
1369
1370 let list_data_type =
1371 DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
1372 let list_data = ArrayData::builder(list_data_type)
1373 .len(4)
1374 .add_buffer(value_offsets)
1375 .add_child_data(value_data)
1376 .null_bit_buffer(Some(Buffer::from([0b00000111])))
1377 .build()
1378 .unwrap();
1379
1380 let a = LargeListArray::from(list_data);
1382 let b = BooleanArray::from(vec![false, true, false, true]);
1383 let result = filter(&a, &b).unwrap();
1384
1385 let value_data = ArrayData::builder(DataType::Int32)
1387 .len(3)
1388 .add_buffer(Buffer::from_slice_ref([3, 4, 5]))
1389 .build()
1390 .unwrap();
1391
1392 let value_offsets = Buffer::from_slice_ref([0i64, 3, 3]);
1393
1394 let list_data_type =
1395 DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
1396 let expected = ArrayData::builder(list_data_type)
1397 .len(2)
1398 .add_buffer(value_offsets)
1399 .add_child_data(value_data)
1400 .null_bit_buffer(Some(Buffer::from([0b00000001])))
1401 .build()
1402 .unwrap();
1403
1404 assert_eq!(&make_array(expected), &result);
1405 }
1406
1407 #[test]
1408 fn test_slice_iterator_bits() {
1409 let filter_values = (0..64).map(|i| i == 1).collect::<Vec<bool>>();
1410 let filter = BooleanArray::from(filter_values);
1411 let filter_count = filter_count(&filter);
1412
1413 let iter = SlicesIterator::new(&filter);
1414 let chunks = iter.collect::<Vec<_>>();
1415
1416 assert_eq!(chunks, vec![(1, 2)]);
1417 assert_eq!(filter_count, 1);
1418 }
1419
1420 #[test]
1421 fn test_slice_iterator_bits1() {
1422 let filter_values = (0..64).map(|i| i != 1).collect::<Vec<bool>>();
1423 let filter = BooleanArray::from(filter_values);
1424 let filter_count = filter_count(&filter);
1425
1426 let iter = SlicesIterator::new(&filter);
1427 let chunks = iter.collect::<Vec<_>>();
1428
1429 assert_eq!(chunks, vec![(0, 1), (2, 64)]);
1430 assert_eq!(filter_count, 64 - 1);
1431 }
1432
1433 #[test]
1434 fn test_slice_iterator_chunk_and_bits() {
1435 let filter_values = (0..130).map(|i| i % 62 != 0).collect::<Vec<bool>>();
1436 let filter = BooleanArray::from(filter_values);
1437 let filter_count = filter_count(&filter);
1438
1439 let iter = SlicesIterator::new(&filter);
1440 let chunks = iter.collect::<Vec<_>>();
1441
1442 assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]);
1443 assert_eq!(filter_count, 61 + 61 + 5);
1444 }
1445
1446 #[test]
1447 fn test_null_mask() {
1448 let a = Int64Array::from(vec![Some(1), Some(2), None]);
1449
1450 let mask1 = BooleanArray::from(vec![Some(true), Some(true), None]);
1451 let out = filter(&a, &mask1).unwrap();
1452 assert_eq!(out.as_ref(), &a.slice(0, 2));
1453 }
1454
1455 #[test]
1456 fn test_filter_record_batch_no_columns() {
1457 let pred = BooleanArray::from(vec![Some(true), Some(true), None]);
1458 let options = RecordBatchOptions::default().with_row_count(Some(100));
1459 let record_batch =
1460 RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &options).unwrap();
1461 let out = filter_record_batch(&record_batch, &pred).unwrap();
1462
1463 assert_eq!(out.num_rows(), 2);
1464 }
1465
1466 #[test]
1467 fn test_fast_path() {
1468 let a: PrimitiveArray<Int64Type> = PrimitiveArray::from(vec![Some(1), Some(2), None]);
1469
1470 let mask = BooleanArray::from(vec![true, true, true]);
1472 let out = filter(&a, &mask).unwrap();
1473 let b = out
1474 .as_any()
1475 .downcast_ref::<PrimitiveArray<Int64Type>>()
1476 .unwrap();
1477 assert_eq!(&a, b);
1478
1479 let mask = BooleanArray::from(vec![false, false, false]);
1481 let out = filter(&a, &mask).unwrap();
1482 assert_eq!(out.len(), 0);
1483 assert_eq!(out.data_type(), &DataType::Int64);
1484 }
1485
1486 #[test]
1487 fn test_slices() {
1488 let bools = std::iter::repeat(true)
1490 .take(10)
1491 .chain(std::iter::repeat(false).take(30))
1492 .chain(std::iter::repeat(true).take(20))
1493 .chain(std::iter::repeat(false).take(17))
1494 .chain(std::iter::repeat(true).take(4));
1495
1496 let bool_array: BooleanArray = bools.map(Some).collect();
1497
1498 let slices: Vec<_> = SlicesIterator::new(&bool_array).collect();
1499 let expected = vec![(0, 10), (40, 60), (77, 81)];
1500 assert_eq!(slices, expected);
1501
1502 let len = bool_array.len();
1504 let sliced_array = bool_array.slice(7, len - 10);
1505 let sliced_array = sliced_array
1506 .as_any()
1507 .downcast_ref::<BooleanArray>()
1508 .unwrap();
1509 let slices: Vec<_> = SlicesIterator::new(sliced_array).collect();
1510 let expected = vec![(0, 3), (33, 53), (70, 71)];
1511 assert_eq!(slices, expected);
1512 }
1513
1514 fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) {
1515 let mut rng = rng();
1516
1517 let bools: Vec<bool> = std::iter::from_fn(|| Some(rng.random()))
1518 .take(mask_len)
1519 .collect();
1520
1521 let buffer = Buffer::from_iter(bools.iter().cloned());
1522
1523 let truncated_length = mask_len - offset - truncate;
1524
1525 let data = ArrayDataBuilder::new(DataType::Boolean)
1526 .len(truncated_length)
1527 .offset(offset)
1528 .add_buffer(buffer)
1529 .build()
1530 .unwrap();
1531
1532 let filter = BooleanArray::from(data);
1533
1534 let slice_bits: Vec<_> = SlicesIterator::new(&filter)
1535 .flat_map(|(start, end)| start..end)
1536 .collect();
1537
1538 let count = filter_count(&filter);
1539 let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect();
1540
1541 let expected_bits: Vec<_> = bools
1542 .iter()
1543 .skip(offset)
1544 .take(truncated_length)
1545 .enumerate()
1546 .flat_map(|(idx, v)| v.then(|| idx))
1547 .collect();
1548
1549 assert_eq!(slice_bits, expected_bits);
1550 assert_eq!(index_bits, expected_bits);
1551 }
1552
1553 #[test]
1554 #[cfg_attr(miri, ignore)]
1555 fn fuzz_test_slices_iterator() {
1556 let mut rng = rng();
1557
1558 let uusize = UniformUsize::new(usize::MIN, usize::MAX).unwrap();
1559 for _ in 0..100 {
1560 let mask_len = rng.random_range(0..1024);
1561 let max_offset = 64.min(mask_len);
1562 let offset = uusize.sample(&mut rng).checked_rem(max_offset).unwrap_or(0);
1563
1564 let max_truncate = 128.min(mask_len - offset);
1565 let truncate = uusize
1566 .sample(&mut rng)
1567 .checked_rem(max_truncate)
1568 .unwrap_or(0);
1569
1570 test_slices_fuzz(mask_len, offset, truncate);
1571 }
1572
1573 test_slices_fuzz(64, 0, 0);
1574 test_slices_fuzz(64, 8, 0);
1575 test_slices_fuzz(64, 8, 8);
1576 test_slices_fuzz(32, 8, 8);
1577 test_slices_fuzz(32, 5, 9);
1578 }
1579
1580 fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> {
1582 values
1583 .into_iter()
1584 .zip(predicate)
1585 .filter(|(_, x)| **x)
1586 .map(|(a, _)| a)
1587 .collect()
1588 }
1589
1590 fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>>
1592 where
1593 StandardUniform: Distribution<T>,
1594 {
1595 let mut rng = rng();
1596 (0..len)
1597 .map(|_| rng.random_bool(valid_percent).then(|| rng.random()))
1598 .collect()
1599 }
1600
1601 fn gen_strings(
1603 len: usize,
1604 valid_percent: f64,
1605 str_len_range: std::ops::Range<usize>,
1606 ) -> Vec<Option<String>> {
1607 let mut rng = rng();
1608 (0..len)
1609 .map(|_| {
1610 rng.random_bool(valid_percent).then(|| {
1611 let len = rng.random_range(str_len_range.clone());
1612 (0..len)
1613 .map(|_| char::from(rng.sample(Alphanumeric)))
1614 .collect()
1615 })
1616 })
1617 .collect()
1618 }
1619
1620 fn as_deref<T: std::ops::Deref>(src: &[Option<T>]) -> impl Iterator<Item = Option<&T::Target>> {
1622 src.iter().map(|x| x.as_deref())
1623 }
1624
1625 #[test]
1626 #[cfg_attr(miri, ignore)]
1627 fn fuzz_filter() {
1628 let mut rng = rng();
1629
1630 for i in 0..100 {
1631 let filter_percent = match i {
1632 0..=4 => 1.,
1633 5..=10 => 0.,
1634 _ => rng.random_range(0.0..1.0),
1635 };
1636
1637 let valid_percent = rng.random_range(0.0..1.0);
1638
1639 let array_len = rng.random_range(32..256);
1640 let array_offset = rng.random_range(0..10);
1641
1642 let filter_offset = rng.random_range(0..10);
1644 let filter_truncate = rng.random_range(0..10);
1645 let bools: Vec<_> = std::iter::from_fn(|| Some(rng.random_bool(filter_percent)))
1646 .take(array_len + filter_offset - filter_truncate)
1647 .collect();
1648
1649 let predicate = BooleanArray::from_iter(bools.iter().cloned().map(Some));
1650
1651 let predicate = predicate.slice(filter_offset, array_len - filter_truncate);
1653 let predicate = predicate.as_any().downcast_ref::<BooleanArray>().unwrap();
1654 let bools = &bools[filter_offset..];
1655
1656 let values = gen_primitive(array_len + array_offset, valid_percent);
1658 let src = Int32Array::from_iter(values.iter().cloned());
1659
1660 let src = src.slice(array_offset, array_len);
1661 let src = src.as_any().downcast_ref::<Int32Array>().unwrap();
1662 let values = &values[array_offset..];
1663
1664 let filtered = filter(src, predicate).unwrap();
1665 let array = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
1666 let actual: Vec<_> = array.iter().collect();
1667
1668 assert_eq!(actual, filter_rust(values.iter().cloned(), bools));
1669
1670 let strings = gen_strings(array_len + array_offset, valid_percent, 0..20);
1672 let src = StringArray::from_iter(as_deref(&strings));
1673
1674 let src = src.slice(array_offset, array_len);
1675 let src = src.as_any().downcast_ref::<StringArray>().unwrap();
1676
1677 let filtered = filter(src, predicate).unwrap();
1678 let array = filtered.as_any().downcast_ref::<StringArray>().unwrap();
1679 let actual: Vec<_> = array.iter().collect();
1680
1681 let expected_strings = filter_rust(as_deref(&strings[array_offset..]), bools);
1682 assert_eq!(actual, expected_strings);
1683
1684 let src = DictionaryArray::<Int32Type>::from_iter(as_deref(&strings));
1686
1687 let src = src.slice(array_offset, array_len);
1688 let src = src
1689 .as_any()
1690 .downcast_ref::<DictionaryArray<Int32Type>>()
1691 .unwrap();
1692
1693 let filtered = filter(src, predicate).unwrap();
1694
1695 let array = filtered
1696 .as_any()
1697 .downcast_ref::<DictionaryArray<Int32Type>>()
1698 .unwrap();
1699
1700 let values = array
1701 .values()
1702 .as_any()
1703 .downcast_ref::<StringArray>()
1704 .unwrap();
1705
1706 let actual: Vec<_> = array
1707 .keys()
1708 .iter()
1709 .map(|key| key.map(|key| values.value(key as usize)))
1710 .collect();
1711
1712 assert_eq!(actual, expected_strings);
1713 }
1714 }
1715
1716 #[test]
1717 fn test_filter_map() {
1718 let mut builder =
1719 MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(4));
1720 builder.keys().append_value("key1");
1722 builder.values().append_value(1);
1723 builder.append(true).unwrap();
1724 builder.keys().append_value("key2");
1725 builder.keys().append_value("key3");
1726 builder.values().append_value(2);
1727 builder.values().append_value(3);
1728 builder.append(true).unwrap();
1729 builder.append(false).unwrap();
1730 builder.keys().append_value("key1");
1731 builder.values().append_value(1);
1732 builder.append(true).unwrap();
1733 let maparray = Arc::new(builder.finish()) as ArrayRef;
1734
1735 let indices = vec![Some(true), Some(false), Some(false), Some(true)]
1736 .into_iter()
1737 .collect::<BooleanArray>();
1738 let got = filter(&maparray, &indices).unwrap();
1739
1740 let mut builder =
1741 MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(2));
1742 builder.keys().append_value("key1");
1743 builder.values().append_value(1);
1744 builder.append(true).unwrap();
1745 builder.keys().append_value("key1");
1746 builder.values().append_value(1);
1747 builder.append(true).unwrap();
1748 let expected = Arc::new(builder.finish()) as ArrayRef;
1749
1750 assert_eq!(&expected, &got);
1751 }
1752
1753 #[test]
1754 fn test_filter_fixed_size_list_arrays() {
1755 let value_data = ArrayData::builder(DataType::Int32)
1756 .len(9)
1757 .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8]))
1758 .build()
1759 .unwrap();
1760 let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 3, false);
1761 let list_data = ArrayData::builder(list_data_type)
1762 .len(3)
1763 .add_child_data(value_data)
1764 .build()
1765 .unwrap();
1766 let array = FixedSizeListArray::from(list_data);
1767
1768 let filter_array = BooleanArray::from(vec![true, false, false]);
1769
1770 let c = filter(&array, &filter_array).unwrap();
1771 let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1772
1773 assert_eq!(filtered.len(), 1);
1774
1775 let list = filtered.value(0);
1776 assert_eq!(
1777 &[0, 1, 2],
1778 list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1779 );
1780
1781 let filter_array = BooleanArray::from(vec![true, false, true]);
1782
1783 let c = filter(&array, &filter_array).unwrap();
1784 let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1785
1786 assert_eq!(filtered.len(), 2);
1787
1788 let list = filtered.value(0);
1789 assert_eq!(
1790 &[0, 1, 2],
1791 list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1792 );
1793 let list = filtered.value(1);
1794 assert_eq!(
1795 &[6, 7, 8],
1796 list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1797 );
1798 }
1799
1800 #[test]
1801 fn test_filter_fixed_size_list_arrays_with_null() {
1802 let value_data = ArrayData::builder(DataType::Int32)
1803 .len(10)
1804 .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
1805 .build()
1806 .unwrap();
1807
1808 let mut null_bits: [u8; 1] = [0; 1];
1812 bit_util::set_bit(&mut null_bits, 0);
1813 bit_util::set_bit(&mut null_bits, 3);
1814 bit_util::set_bit(&mut null_bits, 4);
1815
1816 let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 2, false);
1817 let list_data = ArrayData::builder(list_data_type)
1818 .len(5)
1819 .add_child_data(value_data)
1820 .null_bit_buffer(Some(Buffer::from(null_bits)))
1821 .build()
1822 .unwrap();
1823 let array = FixedSizeListArray::from(list_data);
1824
1825 let filter_array = BooleanArray::from(vec![true, true, false, true, false]);
1826
1827 let c = filter(&array, &filter_array).unwrap();
1828 let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1829
1830 assert_eq!(filtered.len(), 3);
1831
1832 let list = filtered.value(0);
1833 assert_eq!(
1834 &[0, 1],
1835 list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1836 );
1837 assert!(filtered.is_null(1));
1838 let list = filtered.value(2);
1839 assert_eq!(
1840 &[6, 7],
1841 list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1842 );
1843 }
1844
1845 fn test_filter_union_array(array: UnionArray) {
1846 let filter_array = BooleanArray::from(vec![true, false, false]);
1847 let c = filter(&array, &filter_array).unwrap();
1848 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1849
1850 let mut builder = UnionBuilder::new_dense();
1851 builder.append::<Int32Type>("A", 1).unwrap();
1852 let expected_array = builder.build().unwrap();
1853
1854 compare_union_arrays(filtered, &expected_array);
1855
1856 let filter_array = BooleanArray::from(vec![true, false, true]);
1857 let c = filter(&array, &filter_array).unwrap();
1858 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1859
1860 let mut builder = UnionBuilder::new_dense();
1861 builder.append::<Int32Type>("A", 1).unwrap();
1862 builder.append::<Int32Type>("A", 34).unwrap();
1863 let expected_array = builder.build().unwrap();
1864
1865 compare_union_arrays(filtered, &expected_array);
1866
1867 let filter_array = BooleanArray::from(vec![true, true, false]);
1868 let c = filter(&array, &filter_array).unwrap();
1869 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1870
1871 let mut builder = UnionBuilder::new_dense();
1872 builder.append::<Int32Type>("A", 1).unwrap();
1873 builder.append::<Float64Type>("B", 3.2).unwrap();
1874 let expected_array = builder.build().unwrap();
1875
1876 compare_union_arrays(filtered, &expected_array);
1877 }
1878
1879 #[test]
1880 fn test_filter_union_array_dense() {
1881 let mut builder = UnionBuilder::new_dense();
1882 builder.append::<Int32Type>("A", 1).unwrap();
1883 builder.append::<Float64Type>("B", 3.2).unwrap();
1884 builder.append::<Int32Type>("A", 34).unwrap();
1885 let array = builder.build().unwrap();
1886
1887 test_filter_union_array(array);
1888 }
1889
1890 #[test]
1891 fn test_filter_run_union_array_dense() {
1892 let mut builder = UnionBuilder::new_dense();
1893 builder.append::<Int32Type>("A", 1).unwrap();
1894 builder.append::<Int32Type>("A", 3).unwrap();
1895 builder.append::<Int32Type>("A", 34).unwrap();
1896 let array = builder.build().unwrap();
1897
1898 let filter_array = BooleanArray::from(vec![true, true, false]);
1899 let c = filter(&array, &filter_array).unwrap();
1900 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1901
1902 let mut builder = UnionBuilder::new_dense();
1903 builder.append::<Int32Type>("A", 1).unwrap();
1904 builder.append::<Int32Type>("A", 3).unwrap();
1905 let expected = builder.build().unwrap();
1906
1907 assert_eq!(filtered.to_data(), expected.to_data());
1908 }
1909
1910 #[test]
1911 fn test_filter_union_array_dense_with_nulls() {
1912 let mut builder = UnionBuilder::new_dense();
1913 builder.append::<Int32Type>("A", 1).unwrap();
1914 builder.append::<Float64Type>("B", 3.2).unwrap();
1915 builder.append_null::<Float64Type>("B").unwrap();
1916 builder.append::<Int32Type>("A", 34).unwrap();
1917 let array = builder.build().unwrap();
1918
1919 let filter_array = BooleanArray::from(vec![true, true, false, false]);
1920 let c = filter(&array, &filter_array).unwrap();
1921 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1922
1923 let mut builder = UnionBuilder::new_dense();
1924 builder.append::<Int32Type>("A", 1).unwrap();
1925 builder.append::<Float64Type>("B", 3.2).unwrap();
1926 let expected_array = builder.build().unwrap();
1927
1928 compare_union_arrays(filtered, &expected_array);
1929
1930 let filter_array = BooleanArray::from(vec![true, false, true, false]);
1931 let c = filter(&array, &filter_array).unwrap();
1932 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1933
1934 let mut builder = UnionBuilder::new_dense();
1935 builder.append::<Int32Type>("A", 1).unwrap();
1936 builder.append_null::<Float64Type>("B").unwrap();
1937 let expected_array = builder.build().unwrap();
1938
1939 compare_union_arrays(filtered, &expected_array);
1940 }
1941
1942 #[test]
1943 fn test_filter_union_array_sparse() {
1944 let mut builder = UnionBuilder::new_sparse();
1945 builder.append::<Int32Type>("A", 1).unwrap();
1946 builder.append::<Float64Type>("B", 3.2).unwrap();
1947 builder.append::<Int32Type>("A", 34).unwrap();
1948 let array = builder.build().unwrap();
1949
1950 test_filter_union_array(array);
1951 }
1952
1953 #[test]
1954 fn test_filter_union_array_sparse_with_nulls() {
1955 let mut builder = UnionBuilder::new_sparse();
1956 builder.append::<Int32Type>("A", 1).unwrap();
1957 builder.append::<Float64Type>("B", 3.2).unwrap();
1958 builder.append_null::<Float64Type>("B").unwrap();
1959 builder.append::<Int32Type>("A", 34).unwrap();
1960 let array = builder.build().unwrap();
1961
1962 let filter_array = BooleanArray::from(vec![true, false, true, false]);
1963 let c = filter(&array, &filter_array).unwrap();
1964 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1965
1966 let mut builder = UnionBuilder::new_sparse();
1967 builder.append::<Int32Type>("A", 1).unwrap();
1968 builder.append_null::<Float64Type>("B").unwrap();
1969 let expected_array = builder.build().unwrap();
1970
1971 compare_union_arrays(filtered, &expected_array);
1972 }
1973
1974 fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) {
1975 assert_eq!(union1.len(), union2.len());
1976
1977 for i in 0..union1.len() {
1978 let type_id = union1.type_id(i);
1979
1980 let slot1 = union1.value(i);
1981 let slot2 = union2.value(i);
1982
1983 assert_eq!(slot1.is_null(0), slot2.is_null(0));
1984
1985 if !slot1.is_null(0) && !slot2.is_null(0) {
1986 match type_id {
1987 0 => {
1988 let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
1989 assert_eq!(slot1.len(), 1);
1990 let value1 = slot1.value(0);
1991
1992 let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap();
1993 assert_eq!(slot2.len(), 1);
1994 let value2 = slot2.value(0);
1995 assert_eq!(value1, value2);
1996 }
1997 1 => {
1998 let slot1 = slot1.as_any().downcast_ref::<Float64Array>().unwrap();
1999 assert_eq!(slot1.len(), 1);
2000 let value1 = slot1.value(0);
2001
2002 let slot2 = slot2.as_any().downcast_ref::<Float64Array>().unwrap();
2003 assert_eq!(slot2.len(), 1);
2004 let value2 = slot2.value(0);
2005 assert_eq!(value1, value2);
2006 }
2007 _ => unreachable!(),
2008 }
2009 }
2010 }
2011 }
2012
2013 #[test]
2014 fn test_filter_struct() {
2015 let predicate = BooleanArray::from(vec![true, false, true, false]);
2016
2017 let a = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
2018 let a_filtered = Arc::new(StringArray::from(vec!["hello", "world"]));
2019
2020 let b = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
2021 let b_filtered = Arc::new(Int32Array::from(vec![5, 7]));
2022
2023 let null_mask = NullBuffer::from(vec![true, false, false, true]);
2024 let null_mask_filtered = NullBuffer::from(vec![true, false]);
2025
2026 let a_field = Field::new("a", DataType::Utf8, false);
2027 let b_field = Field::new("b", DataType::Int32, false);
2028
2029 let array = StructArray::new(vec![a_field.clone()].into(), vec![a.clone()], None);
2030 let expected =
2031 StructArray::new(vec![a_field.clone()].into(), vec![a_filtered.clone()], None);
2032
2033 let result = filter(&array, &predicate).unwrap();
2034
2035 assert_eq!(result.to_data(), expected.to_data());
2036
2037 let array = StructArray::new(
2038 vec![a_field.clone()].into(),
2039 vec![a.clone()],
2040 Some(null_mask.clone()),
2041 );
2042 let expected = StructArray::new(
2043 vec![a_field.clone()].into(),
2044 vec![a_filtered.clone()],
2045 Some(null_mask_filtered.clone()),
2046 );
2047
2048 let result = filter(&array, &predicate).unwrap();
2049
2050 assert_eq!(result.to_data(), expected.to_data());
2051
2052 let array = StructArray::new(
2053 vec![a_field.clone(), b_field.clone()].into(),
2054 vec![a.clone(), b.clone()],
2055 None,
2056 );
2057 let expected = StructArray::new(
2058 vec![a_field.clone(), b_field.clone()].into(),
2059 vec![a_filtered.clone(), b_filtered.clone()],
2060 None,
2061 );
2062
2063 let result = filter(&array, &predicate).unwrap();
2064
2065 assert_eq!(result.to_data(), expected.to_data());
2066
2067 let array = StructArray::new(
2068 vec![a_field.clone(), b_field.clone()].into(),
2069 vec![a.clone(), b.clone()],
2070 Some(null_mask.clone()),
2071 );
2072
2073 let expected = StructArray::new(
2074 vec![a_field.clone(), b_field.clone()].into(),
2075 vec![a_filtered.clone(), b_filtered.clone()],
2076 Some(null_mask_filtered.clone()),
2077 );
2078
2079 let result = filter(&array, &predicate).unwrap();
2080
2081 assert_eq!(result.to_data(), expected.to_data());
2082 }
2083
2084 #[test]
2085 fn test_filter_empty_struct() {
2086 let fields = arrow_schema::Field::new(
2093 "a",
2094 arrow_schema::DataType::Struct(arrow_schema::Fields::from(vec![
2095 arrow_schema::Field::new("b", arrow_schema::DataType::Int64, true),
2096 arrow_schema::Field::new(
2097 "c",
2098 arrow_schema::DataType::Struct(arrow_schema::Fields::empty()),
2099 true,
2100 ),
2101 ])),
2102 true,
2103 );
2104
2105 let schema = Arc::new(Schema::new(vec![fields]));
2113
2114 let b = Arc::new(Int64Array::from(vec![None, None, None]));
2115 let c = Arc::new(StructArray::new_empty_fields(
2116 3,
2117 Some(NullBuffer::from(vec![true, true, true])),
2118 ));
2119 let a = StructArray::new(
2120 vec![
2121 Field::new("b", DataType::Int64, true),
2122 Field::new("c", DataType::Struct(Fields::empty()), true),
2123 ]
2124 .into(),
2125 vec![b.clone(), c.clone()],
2126 Some(NullBuffer::from(vec![true, true, true])),
2127 );
2128 let record_batch = RecordBatch::try_new(schema, vec![Arc::new(a)]).unwrap();
2129 println!("{record_batch:?}");
2130
2131 let predicate = BooleanArray::from(vec![true, false, true]);
2133 let filtered_batch = filter_record_batch(&record_batch, &predicate).unwrap();
2134
2135 assert_eq!(filtered_batch.num_rows(), 2);
2137 }
2138}