1use std::ops::AddAssign;
21use std::sync::Arc;
22
23use arrow_array::builder::BooleanBufferBuilder;
24use arrow_array::cast::AsArray;
25use arrow_array::types::{
26    ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
27};
28use arrow_array::*;
29use arrow_buffer::{ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer, bit_util};
30use arrow_buffer::{Buffer, MutableBuffer};
31use arrow_data::ArrayDataBuilder;
32use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
33use arrow_data::transform::MutableArrayData;
34use arrow_schema::*;
35
/// Selectivity (fraction of filter rows that are `true`) above which a filter is
/// walked as contiguous `(start, end)` runs of selected rows rather than as
/// individual selected indices.
///
/// Used by `IterationStrategy::default_strategy`: a selectivity strictly greater
/// than this value picks `IterationStrategy::SlicesIterator`, otherwise
/// `IterationStrategy::IndexIterator` is used.
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
43
44#[derive(Debug)]
56pub struct SlicesIterator<'a>(BitSliceIterator<'a>);
57
58impl<'a> SlicesIterator<'a> {
59    pub fn new(filter: &'a BooleanArray) -> Self {
61        filter.values().into()
62    }
63}
64
65impl<'a> From<&'a BooleanBuffer> for SlicesIterator<'a> {
66    fn from(filter: &'a BooleanBuffer) -> Self {
67        Self(filter.set_slices())
68    }
69}
70
71impl Iterator for SlicesIterator<'_> {
72    type Item = (usize, usize);
73
74    fn next(&mut self) -> Option<Self::Item> {
75        self.0.next()
76    }
77}
78
79struct IndexIterator<'a> {
84    remaining: usize,
85    iter: BitIndexIterator<'a>,
86}
87
88impl<'a> IndexIterator<'a> {
89    fn new(filter: &'a BooleanArray, remaining: usize) -> Self {
90        assert_eq!(filter.null_count(), 0);
91        let iter = filter.values().set_indices();
92        Self { remaining, iter }
93    }
94}
95
96impl Iterator for IndexIterator<'_> {
97    type Item = usize;
98
99    fn next(&mut self) -> Option<Self::Item> {
100        if self.remaining != 0 {
101            let next = self.iter.next().expect("IndexIterator exhausted early");
104            self.remaining -= 1;
105            return Some(next);
107        }
108        None
109    }
110
111    fn size_hint(&self) -> (usize, Option<usize>) {
112        (self.remaining, Some(self.remaining))
113    }
114}
115
116fn filter_count(filter: &BooleanArray) -> usize {
118    filter.values().count_set_bits()
119}
120
121pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
123    let nulls = filter.nulls().unwrap();
124    let mask = filter.values() & nulls.inner();
125    BooleanArray::new(mask, None)
126}
127
128pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
156    let mut filter_builder = FilterBuilder::new(predicate);
157
158    if multiple_arrays(values.data_type()) {
159        filter_builder = filter_builder.optimize();
162    }
163
164    let predicate = filter_builder.build();
165
166    filter_array(values, &predicate)
167}
168
169fn multiple_arrays(data_type: &DataType) -> bool {
170    match data_type {
171        DataType::Struct(fields) => {
172            fields.len() > 1 || fields.len() == 1 && multiple_arrays(fields[0].data_type())
173        }
174        DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
175        _ => false,
176    }
177}
178
179pub fn filter_record_batch(
190    record_batch: &RecordBatch,
191    predicate: &BooleanArray,
192) -> Result<RecordBatch, ArrowError> {
193    let mut filter_builder = FilterBuilder::new(predicate);
194    let num_cols = record_batch.num_columns();
195    if num_cols > 1
196        || (num_cols > 0 && multiple_arrays(record_batch.schema_ref().field(0).data_type()))
197    {
198        filter_builder = filter_builder.optimize();
201    }
202    let filter = filter_builder.build();
203
204    filter.filter_record_batch(record_batch)
205}
206
207#[derive(Debug)]
209pub struct FilterBuilder {
210    filter: BooleanArray,
211    count: usize,
212    strategy: IterationStrategy,
213}
214
215impl FilterBuilder {
216    pub fn new(filter: &BooleanArray) -> Self {
218        let filter = match filter.null_count() {
219            0 => filter.clone(),
220            _ => prep_null_mask_filter(filter),
221        };
222
223        let count = filter_count(&filter);
224        let strategy = IterationStrategy::default_strategy(filter.len(), count);
225
226        Self {
227            filter,
228            count,
229            strategy,
230        }
231    }
232
233    pub fn optimize(mut self) -> Self {
239        match self.strategy {
240            IterationStrategy::SlicesIterator => {
241                let slices = SlicesIterator::new(&self.filter).collect();
242                self.strategy = IterationStrategy::Slices(slices)
243            }
244            IterationStrategy::IndexIterator => {
245                let indices = IndexIterator::new(&self.filter, self.count).collect();
246                self.strategy = IterationStrategy::Indices(indices)
247            }
248            _ => {}
249        }
250        self
251    }
252
253    pub fn build(self) -> FilterPredicate {
255        FilterPredicate {
256            filter: self.filter,
257            count: self.count,
258            strategy: self.strategy,
259        }
260    }
261}
262
263#[derive(Debug)]
265enum IterationStrategy {
266    SlicesIterator,
268    IndexIterator,
270    Indices(Vec<usize>),
272    Slices(Vec<(usize, usize)>),
274    All,
276    None,
278}
279
280impl IterationStrategy {
281    fn default_strategy(filter_length: usize, filter_count: usize) -> Self {
284        if filter_length == 0 || filter_count == 0 {
285            return IterationStrategy::None;
286        }
287
288        if filter_count == filter_length {
289            return IterationStrategy::All;
290        }
291
292        let selectivity_frac = filter_count as f64 / filter_length as f64;
297        if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD {
298            return IterationStrategy::SlicesIterator;
299        }
300        IterationStrategy::IndexIterator
301    }
302}
303
304#[derive(Debug)]
306pub struct FilterPredicate {
307    filter: BooleanArray,
308    count: usize,
309    strategy: IterationStrategy,
310}
311
312impl FilterPredicate {
313    pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef, ArrowError> {
315        filter_array(values, self)
316    }
317
318    pub fn filter_record_batch(
323        &self,
324        record_batch: &RecordBatch,
325    ) -> Result<RecordBatch, ArrowError> {
326        let filtered_arrays = record_batch
327            .columns()
328            .iter()
329            .map(|a| filter_array(a, self))
330            .collect::<Result<Vec<_>, _>>()?;
331
332        unsafe {
335            Ok(RecordBatch::new_unchecked(
336                record_batch.schema(),
337                filtered_arrays,
338                self.count,
339            ))
340        }
341    }
342
343    pub fn count(&self) -> usize {
345        self.count
346    }
347}
348
/// Filters `values` with an already-built `predicate`, dispatching on the data
/// type to a specialized kernel.
///
/// The predicate may be shorter than `values` but not longer. Only its first
/// `predicate.filter.len()` rows take part in the selection.
///
/// # Errors
///
/// Returns `ArrowError::InvalidArgumentError` if the predicate is longer than
/// `values`. Errors from the run-end, struct and sparse-union kernels are
/// propagated.
///
/// # Panics
///
/// Panics via `unimplemented!` if the run-end or dictionary downcast macro does
/// not match the array's run-end / key type.
fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> {
    if predicate.filter.len() > values.len() {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Filter predicate of length {} is larger than target array of length {}",
            predicate.filter.len(),
            values.len()
        )));
    }

    match predicate.strategy {
        // Nothing selected: return an empty array of the same type.
        IterationStrategy::None => Ok(new_empty_array(values.data_type())),
        // Every filter bit is set, so the result is the leading `count` rows.
        IterationStrategy::All => Ok(values.slice(0, predicate.count)),
        _ => downcast_primitive_array! {
            // Primitive arrays are handled by this first arm (via `filter_primitive`);
            // the remaining arms match on the data type.
            values => Ok(Arc::new(filter_primitive(values, predicate))),
            DataType::Boolean => {
                let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
                Ok(Arc::new(filter_boolean(values, predicate)))
            }
            DataType::Utf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i32>(), predicate)))
            }
            DataType::LargeUtf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
            }
            DataType::Utf8View => {
                Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
            }
            DataType::Binary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
            }
            DataType::LargeBinary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
            }
            DataType::BinaryView => {
                Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
            }
            DataType::FixedSizeBinary(_) => {
                Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate)))
            }
            DataType::ListView(_) => {
                Ok(Arc::new(filter_list_view::<i32>(values.as_list_view(), predicate)))
            }
            DataType::LargeListView(_) => {
                Ok(Arc::new(filter_list_view::<i64>(values.as_list_view(), predicate)))
            }
            DataType::RunEndEncoded(_, _) => {
                downcast_run_array!{
                    values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
                    t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t)
                }
            }
            DataType::Dictionary(_, _) => downcast_dictionary_array! {
                values => Ok(Arc::new(filter_dict(values, predicate))),
                t => unimplemented!("Filter not supported for dictionary type {:?}", t)
            }
            DataType::Struct(_) => {
                Ok(Arc::new(filter_struct(values.as_struct(), predicate)?))
            }
            DataType::Union(_, UnionMode::Sparse) => {
                Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?))
            }
            _ => {
                // Generic fallback for all other types: copy the selected row
                // ranges through `MutableArrayData`.
                let data = values.to_data();
                let mut mutable = MutableArrayData::new(
                    vec![&data],
                    false,
                    predicate.count,
                );

                match &predicate.strategy {
                    IterationStrategy::Slices(slices) => {
                        slices
                            .iter()
                            .for_each(|(start, end)| mutable.extend(0, *start, *end));
                    }
                    // Every other strategy (including pre-computed indices) falls
                    // back to walking the filter's set ranges.
                    _ => {
                        let iter = SlicesIterator::new(&predicate.filter);
                        iter.for_each(|(start, end)| mutable.extend(0, start, end));
                    }
                }

                let data = mutable.freeze();
                Ok(make_array(data))
            }
        },
    }
}
438
439fn filter_run_end_array<R: RunEndIndexType>(
441    array: &RunArray<R>,
442    predicate: &FilterPredicate,
443) -> Result<RunArray<R>, ArrowError>
444where
445    R::Native: Into<i64> + From<bool>,
446    R::Native: AddAssign,
447{
448    let run_ends: &RunEndBuffer<R::Native> = array.run_ends();
449    let mut new_run_ends = vec![R::default_value(); run_ends.len()];
450
451    let mut start = 0u64;
452    let mut j = 0;
453    let mut count = R::default_value();
454    let filter_values = predicate.filter.values();
455    let run_ends = run_ends.inner();
456
457    let pred: BooleanArray = BooleanBuffer::collect_bool(run_ends.len(), |i| {
458        let mut keep = false;
459        let mut end = run_ends[i].into() as u64;
460        let difference = end.saturating_sub(filter_values.len() as u64);
461        end -= difference;
462
463        for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
465            count += R::Native::from(pred);
466            keep |= pred
467        }
468        new_run_ends[j] = count;
470        j += keep as usize;
471
472        start = end;
473        keep
474    })
475    .into();
476
477    new_run_ends.truncate(j);
478
479    let values = array.values();
480    let values = filter(&values, &pred)?;
481
482    let run_ends = PrimitiveArray::<R>::try_new(new_run_ends.into(), None)?;
483    RunArray::try_new(&run_ends, &values)
484}
485
486fn filter_null_mask(
493    nulls: Option<&NullBuffer>,
494    predicate: &FilterPredicate,
495) -> Option<(usize, Buffer)> {
496    let nulls = nulls?;
497    if nulls.null_count() == 0 {
498        return None;
499    }
500
501    let nulls = filter_bits(nulls.inner(), predicate);
502    let null_count = predicate.count - nulls.count_set_bits_offset(0, predicate.count);
505
506    if null_count == 0 {
507        return None;
508    }
509
510    Some((null_count, nulls))
511}
512
513fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer {
515    let src = buffer.values();
516    let offset = buffer.offset();
517
518    match &predicate.strategy {
519        IterationStrategy::IndexIterator => {
520            let bits = IndexIterator::new(&predicate.filter, predicate.count)
521                .map(|src_idx| bit_util::get_bit(src, src_idx + offset));
522
523            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
525        }
526        IterationStrategy::Indices(indices) => {
527            let bits = indices
528                .iter()
529                .map(|src_idx| bit_util::get_bit(src, *src_idx + offset));
530
531            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
533        }
534        IterationStrategy::SlicesIterator => {
535            let mut builder = BooleanBufferBuilder::new(predicate.count);
536            for (start, end) in SlicesIterator::new(&predicate.filter) {
537                builder.append_packed_range(start + offset..end + offset, src)
538            }
539            builder.into()
540        }
541        IterationStrategy::Slices(slices) => {
542            let mut builder = BooleanBufferBuilder::new(predicate.count);
543            for (start, end) in slices {
544                builder.append_packed_range(*start + offset..*end + offset, src)
545            }
546            builder.into()
547        }
548        IterationStrategy::All | IterationStrategy::None => unreachable!(),
549    }
550}
551
552fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
554    let values = filter_bits(array.values(), predicate);
555
556    let mut builder = ArrayDataBuilder::new(DataType::Boolean)
557        .len(predicate.count)
558        .add_buffer(values);
559
560    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
561        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
562    }
563
564    let data = unsafe { builder.build_unchecked() };
565    BooleanArray::from(data)
566}
567
568#[inline(never)]
569fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer {
570    assert!(values.len() >= predicate.filter.len());
571
572    match &predicate.strategy {
573        IterationStrategy::SlicesIterator => {
574            let mut buffer = Vec::with_capacity(predicate.count);
575            for (start, end) in SlicesIterator::new(&predicate.filter) {
576                buffer.extend_from_slice(&values[start..end]);
577            }
578            buffer.into()
579        }
580        IterationStrategy::Slices(slices) => {
581            let mut buffer = Vec::with_capacity(predicate.count);
582            for (start, end) in slices {
583                buffer.extend_from_slice(&values[*start..*end]);
584            }
585            buffer.into()
586        }
587        IterationStrategy::IndexIterator => {
588            let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| values[x]);
589
590            unsafe { MutableBuffer::from_trusted_len_iter(iter) }.into()
592        }
593        IterationStrategy::Indices(indices) => {
594            let iter = indices.iter().map(|x| values[*x]);
595            iter.collect::<Vec<_>>().into()
596        }
597        IterationStrategy::All | IterationStrategy::None => unreachable!(),
598    }
599}
600
601fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
603where
604    T: ArrowPrimitiveType,
605{
606    let values = array.values();
607    let buffer = filter_native(values, predicate);
608    let mut builder = ArrayDataBuilder::new(array.data_type().clone())
609        .len(predicate.count)
610        .add_buffer(buffer);
611
612    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
613        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
614    }
615
616    let data = unsafe { builder.build_unchecked() };
617    PrimitiveArray::from(data)
618}
619
620struct FilterBytes<'a, OffsetSize> {
625    src_offsets: &'a [OffsetSize],
626    src_values: &'a [u8],
627    dst_offsets: Vec<OffsetSize>,
628    dst_values: Vec<u8>,
629    cur_offset: OffsetSize,
630}
631
632impl<'a, OffsetSize> FilterBytes<'a, OffsetSize>
633where
634    OffsetSize: OffsetSizeTrait,
635{
636    fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self
637    where
638        T: ByteArrayType<Offset = OffsetSize>,
639    {
640        let dst_values = Vec::new();
641        let mut dst_offsets: Vec<OffsetSize> = Vec::with_capacity(capacity + 1);
642        let cur_offset = OffsetSize::from_usize(0).unwrap();
643
644        dst_offsets.push(cur_offset);
645
646        Self {
647            src_offsets: array.value_offsets(),
648            src_values: array.value_data(),
649            dst_offsets,
650            dst_values,
651            cur_offset,
652        }
653    }
654
655    #[inline]
657    fn get_value_offset(&self, idx: usize) -> usize {
658        self.src_offsets[idx].as_usize()
659    }
660
661    #[inline]
663    fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
664        let start = self.get_value_offset(idx);
666        let end = self.get_value_offset(idx + 1);
667        let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
668        (start, end, len)
669    }
670
671    fn extend_offsets_idx(&mut self, iter: impl Iterator<Item = usize>) {
672        self.dst_offsets.extend(iter.map(|idx| {
673            let start = self.src_offsets[idx].as_usize();
674            let end = self.src_offsets[idx + 1].as_usize();
675            let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
676            self.cur_offset += len;
677
678            self.cur_offset
679        }));
680    }
681
682    fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
684        self.dst_values.reserve_exact(self.cur_offset.as_usize());
685
686        for idx in iter {
687            let start = self.src_offsets[idx].as_usize();
688            let end = self.src_offsets[idx + 1].as_usize();
689            self.dst_values
690                .extend_from_slice(&self.src_values[start..end]);
691        }
692    }
693
694    fn extend_offsets_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>, count: usize) {
695        self.dst_offsets.reserve_exact(count);
696        for (start, end) in iter {
697            for idx in start..end {
699                let (_, _, len) = self.get_value_range(idx);
700                self.cur_offset += len;
701                self.dst_offsets.push(self.cur_offset);
702            }
703        }
704    }
705
706    fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
708        self.dst_values.reserve_exact(self.cur_offset.as_usize());
709
710        for (start, end) in iter {
711            let value_start = self.get_value_offset(start);
712            let value_end = self.get_value_offset(end);
713            self.dst_values
714                .extend_from_slice(&self.src_values[value_start..value_end]);
715        }
716    }
717}
718
719fn filter_bytes<T>(array: &GenericByteArray<T>, predicate: &FilterPredicate) -> GenericByteArray<T>
724where
725    T: ByteArrayType,
726{
727    let mut filter = FilterBytes::new(predicate.count, array);
728
729    match &predicate.strategy {
730        IterationStrategy::SlicesIterator => {
731            filter.extend_offsets_slices(SlicesIterator::new(&predicate.filter), predicate.count);
732            filter.extend_slices(SlicesIterator::new(&predicate.filter))
733        }
734        IterationStrategy::Slices(slices) => {
735            filter.extend_offsets_slices(slices.iter().cloned(), predicate.count);
736            filter.extend_slices(slices.iter().cloned())
737        }
738        IterationStrategy::IndexIterator => {
739            filter.extend_offsets_idx(IndexIterator::new(&predicate.filter, predicate.count));
740            filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count))
741        }
742        IterationStrategy::Indices(indices) => {
743            filter.extend_offsets_idx(indices.iter().cloned());
744            filter.extend_idx(indices.iter().cloned())
745        }
746        IterationStrategy::All | IterationStrategy::None => unreachable!(),
747    }
748
749    let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
750        .len(predicate.count)
751        .add_buffer(filter.dst_offsets.into())
752        .add_buffer(filter.dst_values.into());
753
754    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
755        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
756    }
757
758    let data = unsafe { builder.build_unchecked() };
759    GenericByteArray::from(data)
760}
761
762fn filter_byte_view<T: ByteViewType>(
764    array: &GenericByteViewArray<T>,
765    predicate: &FilterPredicate,
766) -> GenericByteViewArray<T> {
767    let new_view_buffer = filter_native(array.views(), predicate);
768
769    let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
770        .len(predicate.count)
771        .add_buffer(new_view_buffer)
772        .add_buffers(array.data_buffers().to_vec());
773
774    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
775        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
776    }
777
778    GenericByteViewArray::from(unsafe { builder.build_unchecked() })
779}
780
781fn filter_fixed_size_binary(
782    array: &FixedSizeBinaryArray,
783    predicate: &FilterPredicate,
784) -> FixedSizeBinaryArray {
785    let values: &[u8] = array.values();
786    let value_length = array.value_length() as usize;
787    let calculate_offset_from_index = |index: usize| index * value_length;
788    let buffer = match &predicate.strategy {
789        IterationStrategy::SlicesIterator => {
790            let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
791            for (start, end) in SlicesIterator::new(&predicate.filter) {
792                buffer.extend_from_slice(
793                    &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)],
794                );
795            }
796            buffer
797        }
798        IterationStrategy::Slices(slices) => {
799            let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
800            for (start, end) in slices {
801                buffer.extend_from_slice(
802                    &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)],
803                );
804            }
805            buffer
806        }
807        IterationStrategy::IndexIterator => {
808            let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| {
809                &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)]
810            });
811
812            let mut buffer = MutableBuffer::new(predicate.count * value_length);
813            iter.for_each(|item| buffer.extend_from_slice(item));
814            buffer
815        }
816        IterationStrategy::Indices(indices) => {
817            let iter = indices.iter().map(|x| {
818                &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)]
819            });
820
821            let mut buffer = MutableBuffer::new(predicate.count * value_length);
822            iter.for_each(|item| buffer.extend_from_slice(item));
823            buffer
824        }
825        IterationStrategy::All | IterationStrategy::None => unreachable!(),
826    };
827    let mut builder = ArrayDataBuilder::new(array.data_type().clone())
828        .len(predicate.count)
829        .add_buffer(buffer.into());
830
831    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
832        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
833    }
834
835    let data = unsafe { builder.build_unchecked() };
836    FixedSizeBinaryArray::from(data)
837}
838
839fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
841where
842    T: ArrowDictionaryKeyType,
843    T::Native: num_traits::Num,
844{
845    let builder = filter_primitive::<T>(array.keys(), predicate)
846        .into_data()
847        .into_builder()
848        .data_type(array.data_type().clone())
849        .child_data(vec![array.values().to_data()]);
850
851    DictionaryArray::from(unsafe { builder.build_unchecked() })
854}
855
856fn filter_struct(
858    array: &StructArray,
859    predicate: &FilterPredicate,
860) -> Result<StructArray, ArrowError> {
861    let columns = array
862        .columns()
863        .iter()
864        .map(|column| filter_array(column, predicate))
865        .collect::<Result<_, _>>()?;
866
867    let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
868        let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
869
870        Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
871    } else {
872        None
873    };
874
875    Ok(unsafe {
876        StructArray::new_unchecked_with_length(
877            array.fields().clone(),
878            columns,
879            nulls,
880            predicate.count(),
881        )
882    })
883}
884
885fn filter_sparse_union(
887    array: &UnionArray,
888    predicate: &FilterPredicate,
889) -> Result<UnionArray, ArrowError> {
890    let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
891        unreachable!()
892    };
893
894    let type_ids = filter_primitive(
895        &Int8Array::try_new(array.type_ids().clone(), None)?,
896        predicate,
897    );
898
899    let children = fields
900        .iter()
901        .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
902        .collect::<Result<_, _>>()?;
903
904    Ok(unsafe {
905        UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
906    })
907}
908
909fn filter_list_view<OffsetType: OffsetSizeTrait>(
911    array: &GenericListViewArray<OffsetType>,
912    predicate: &FilterPredicate,
913) -> GenericListViewArray<OffsetType> {
914    let filtered_offsets = filter_native::<OffsetType>(array.offsets(), predicate);
915    let filtered_sizes = filter_native::<OffsetType>(array.sizes(), predicate);
916
917    let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
919        let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
920
921        Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
922    } else {
923        None
924    };
925
926    let list_data = ArrayDataBuilder::new(array.data_type().clone())
927        .nulls(nulls)
928        .buffers(vec![filtered_offsets, filtered_sizes])
929        .child_data(vec![array.values().to_data()])
930        .len(predicate.count);
931
932    let list_data = unsafe { list_data.build_unchecked() };
933
934    GenericListViewArray::from(list_data)
935}
936
937#[cfg(test)]
938mod tests {
939    use super::*;
940    use arrow_array::builder::*;
941    use arrow_array::cast::as_run_array;
942    use arrow_array::types::*;
943    use arrow_data::ArrayData;
944    use rand::distr::uniform::{UniformSampler, UniformUsize};
945    use rand::distr::{Alphanumeric, StandardUniform};
946    use rand::prelude::*;
947    use rand::rng;
948
949    macro_rules! def_temporal_test {
950        ($test:ident, $array_type: ident, $data: expr) => {
951            #[test]
952            fn $test() {
953                let a = $data;
954                let b = BooleanArray::from(vec![true, false, true, false]);
955                let c = filter(&a, &b).unwrap();
956                let d = c.as_ref().as_any().downcast_ref::<$array_type>().unwrap();
957                assert_eq!(2, d.len());
958                assert_eq!(1, d.value(0));
959                assert_eq!(3, d.value(1));
960            }
961        };
962    }
963
964    def_temporal_test!(
965        test_filter_date32,
966        Date32Array,
967        Date32Array::from(vec![1, 2, 3, 4])
968    );
969    def_temporal_test!(
970        test_filter_date64,
971        Date64Array,
972        Date64Array::from(vec![1, 2, 3, 4])
973    );
974    def_temporal_test!(
975        test_filter_time32_second,
976        Time32SecondArray,
977        Time32SecondArray::from(vec![1, 2, 3, 4])
978    );
979    def_temporal_test!(
980        test_filter_time32_millisecond,
981        Time32MillisecondArray,
982        Time32MillisecondArray::from(vec![1, 2, 3, 4])
983    );
984    def_temporal_test!(
985        test_filter_time64_microsecond,
986        Time64MicrosecondArray,
987        Time64MicrosecondArray::from(vec![1, 2, 3, 4])
988    );
989    def_temporal_test!(
990        test_filter_time64_nanosecond,
991        Time64NanosecondArray,
992        Time64NanosecondArray::from(vec![1, 2, 3, 4])
993    );
994    def_temporal_test!(
995        test_filter_duration_second,
996        DurationSecondArray,
997        DurationSecondArray::from(vec![1, 2, 3, 4])
998    );
999    def_temporal_test!(
1000        test_filter_duration_millisecond,
1001        DurationMillisecondArray,
1002        DurationMillisecondArray::from(vec![1, 2, 3, 4])
1003    );
1004    def_temporal_test!(
1005        test_filter_duration_microsecond,
1006        DurationMicrosecondArray,
1007        DurationMicrosecondArray::from(vec![1, 2, 3, 4])
1008    );
1009    def_temporal_test!(
1010        test_filter_duration_nanosecond,
1011        DurationNanosecondArray,
1012        DurationNanosecondArray::from(vec![1, 2, 3, 4])
1013    );
1014    def_temporal_test!(
1015        test_filter_timestamp_second,
1016        TimestampSecondArray,
1017        TimestampSecondArray::from(vec![1, 2, 3, 4])
1018    );
1019    def_temporal_test!(
1020        test_filter_timestamp_millisecond,
1021        TimestampMillisecondArray,
1022        TimestampMillisecondArray::from(vec![1, 2, 3, 4])
1023    );
1024    def_temporal_test!(
1025        test_filter_timestamp_microsecond,
1026        TimestampMicrosecondArray,
1027        TimestampMicrosecondArray::from(vec![1, 2, 3, 4])
1028    );
1029    def_temporal_test!(
1030        test_filter_timestamp_nanosecond,
1031        TimestampNanosecondArray,
1032        TimestampNanosecondArray::from(vec![1, 2, 3, 4])
1033    );
1034
1035    #[test]
1036    fn test_filter_array_slice() {
1037        let a = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
1038        let b = BooleanArray::from(vec![true, false, false, true]);
1039        let c = filter(&a, &b).unwrap();
1043        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1044        assert_eq!(2, d.len());
1045        assert_eq!(6, d.value(0));
1046        assert_eq!(9, d.value(1));
1047    }
1048
1049    #[test]
1050    fn test_filter_array_low_density() {
1051        let mut data_values = (1..=65).collect::<Vec<i32>>();
1053        let mut filter_values = (1..=65).map(|i| matches!(i % 65, 0)).collect::<Vec<bool>>();
1054        data_values.extend_from_slice(&[66, 67]);
1056        filter_values.extend_from_slice(&[false, true]);
1057        let a = Int32Array::from(data_values);
1058        let b = BooleanArray::from(filter_values);
1059        let c = filter(&a, &b).unwrap();
1060        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1061        assert_eq!(2, d.len());
1062        assert_eq!(65, d.value(0));
1063        assert_eq!(67, d.value(1));
1064    }
1065
1066    #[test]
1067    fn test_filter_array_high_density() {
1068        let mut data_values = (1..=65).map(Some).collect::<Vec<_>>();
1070        let mut filter_values = (1..=65)
1071            .map(|i| !matches!(i % 65, 0))
1072            .collect::<Vec<bool>>();
1073        data_values[1] = None;
1075        data_values.extend_from_slice(&[Some(66), None, Some(67), None]);
1077        filter_values.extend_from_slice(&[false, true, true, true]);
1078        let a = Int32Array::from(data_values);
1079        let b = BooleanArray::from(filter_values);
1080        let c = filter(&a, &b).unwrap();
1081        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1082        assert_eq!(67, d.len());
1083        assert_eq!(3, d.null_count());
1084        assert_eq!(1, d.value(0));
1085        assert!(d.is_null(1));
1086        assert_eq!(64, d.value(63));
1087        assert!(d.is_null(64));
1088        assert_eq!(67, d.value(65));
1089    }
1090
1091    #[test]
1092    fn test_filter_string_array_simple() {
1093        let a = StringArray::from(vec!["hello", " ", "world", "!"]);
1094        let b = BooleanArray::from(vec![true, false, true, false]);
1095        let c = filter(&a, &b).unwrap();
1096        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1097        assert_eq!(2, d.len());
1098        assert_eq!("hello", d.value(0));
1099        assert_eq!("world", d.value(1));
1100    }
1101
1102    #[test]
1103    fn test_filter_primitive_array_with_null() {
1104        let a = Int32Array::from(vec![Some(5), None]);
1105        let b = BooleanArray::from(vec![false, true]);
1106        let c = filter(&a, &b).unwrap();
1107        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1108        assert_eq!(1, d.len());
1109        assert!(d.is_null(0));
1110    }
1111
1112    #[test]
1113    fn test_filter_string_array_with_null() {
1114        let a = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
1115        let b = BooleanArray::from(vec![true, false, false, true]);
1116        let c = filter(&a, &b).unwrap();
1117        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1118        assert_eq!(2, d.len());
1119        assert_eq!("hello", d.value(0));
1120        assert!(!d.is_null(0));
1121        assert!(d.is_null(1));
1122    }
1123
1124    #[test]
1125    fn test_filter_binary_array_with_null() {
1126        let data: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
1127        let a = BinaryArray::from(data);
1128        let b = BooleanArray::from(vec![true, false, false, true]);
1129        let c = filter(&a, &b).unwrap();
1130        let d = c.as_ref().as_any().downcast_ref::<BinaryArray>().unwrap();
1131        assert_eq!(2, d.len());
1132        assert_eq!(b"hello", d.value(0));
1133        assert!(!d.is_null(0));
1134        assert!(d.is_null(1));
1135    }
1136
1137    fn _test_filter_byte_view<T>()
1138    where
1139        T: ByteViewType,
1140        str: AsRef<T::Native>,
1141        T::Native: PartialEq,
1142    {
1143        let array = {
1144            let mut builder = GenericByteViewBuilder::<T>::new();
1146            builder.append_value("hello");
1147            builder.append_value("world");
1148            builder.append_null();
1149            builder.append_value("large payload over 12 bytes");
1150            builder.append_value("lulu");
1151            builder.finish()
1152        };
1153
1154        {
1155            let predicate = BooleanArray::from(vec![true, false, true, true, false]);
1156            let actual = filter(&array, &predicate).unwrap();
1157
1158            assert_eq!(actual.len(), 3);
1159
1160            let expected = {
1161                let mut builder = GenericByteViewBuilder::<T>::new();
1163                builder.append_value("hello");
1164                builder.append_null();
1165                builder.append_value("large payload over 12 bytes");
1166                builder.finish()
1167            };
1168
1169            assert_eq!(actual.as_ref(), &expected);
1170        }
1171
1172        {
1173            let predicate = BooleanArray::from(vec![true, false, false, false, true]);
1174            let actual = filter(&array, &predicate).unwrap();
1175
1176            assert_eq!(actual.len(), 2);
1177
1178            let expected = {
1179                let mut builder = GenericByteViewBuilder::<T>::new();
1181                builder.append_value("hello");
1182                builder.append_value("lulu");
1183                builder.finish()
1184            };
1185
1186            assert_eq!(actual.as_ref(), &expected);
1187        }
1188    }
1189
1190    #[test]
1191    fn test_filter_string_view() {
1192        _test_filter_byte_view::<StringViewType>()
1193    }
1194
1195    #[test]
1196    fn test_filter_binary_view() {
1197        _test_filter_byte_view::<BinaryViewType>()
1198    }
1199
1200    #[test]
1201    fn test_filter_fixed_binary() {
1202        let v1 = [1_u8, 2];
1203        let v2 = [3_u8, 4];
1204        let v3 = [5_u8, 6];
1205        let v = vec![&v1, &v2, &v3];
1206        let a = FixedSizeBinaryArray::from(v);
1207        let b = BooleanArray::from(vec![true, false, true]);
1208        let c = filter(&a, &b).unwrap();
1209        let d = c
1210            .as_ref()
1211            .as_any()
1212            .downcast_ref::<FixedSizeBinaryArray>()
1213            .unwrap();
1214        assert_eq!(d.len(), 2);
1215        assert_eq!(d.value(0), &v1);
1216        assert_eq!(d.value(1), &v3);
1217        let c2 = FilterBuilder::new(&b)
1218            .optimize()
1219            .build()
1220            .filter(&a)
1221            .unwrap();
1222        let d2 = c2
1223            .as_ref()
1224            .as_any()
1225            .downcast_ref::<FixedSizeBinaryArray>()
1226            .unwrap();
1227        assert_eq!(d, d2);
1228
1229        let b = BooleanArray::from(vec![false, false, false]);
1230        let c = filter(&a, &b).unwrap();
1231        let d = c
1232            .as_ref()
1233            .as_any()
1234            .downcast_ref::<FixedSizeBinaryArray>()
1235            .unwrap();
1236        assert_eq!(d.len(), 0);
1237
1238        let b = BooleanArray::from(vec![true, true, true]);
1239        let c = filter(&a, &b).unwrap();
1240        let d = c
1241            .as_ref()
1242            .as_any()
1243            .downcast_ref::<FixedSizeBinaryArray>()
1244            .unwrap();
1245        assert_eq!(d.len(), 3);
1246        assert_eq!(d.value(0), &v1);
1247        assert_eq!(d.value(1), &v2);
1248        assert_eq!(d.value(2), &v3);
1249
1250        let b = BooleanArray::from(vec![false, false, true]);
1251        let c = filter(&a, &b).unwrap();
1252        let d = c
1253            .as_ref()
1254            .as_any()
1255            .downcast_ref::<FixedSizeBinaryArray>()
1256            .unwrap();
1257        assert_eq!(d.len(), 1);
1258        assert_eq!(d.value(0), &v3);
1259        let c2 = FilterBuilder::new(&b)
1260            .optimize()
1261            .build()
1262            .filter(&a)
1263            .unwrap();
1264        let d2 = c2
1265            .as_ref()
1266            .as_any()
1267            .downcast_ref::<FixedSizeBinaryArray>()
1268            .unwrap();
1269        assert_eq!(d, d2);
1270    }
1271
1272    #[test]
1273    fn test_filter_array_slice_with_null() {
1274        let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
1275        let b = BooleanArray::from(vec![true, false, false, true]);
1276        let c = filter(&a, &b).unwrap();
1280        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1281        assert_eq!(2, d.len());
1282        assert!(d.is_null(0));
1283        assert!(!d.is_null(1));
1284        assert_eq!(9, d.value(1));
1285    }
1286
1287    #[test]
1288    fn test_filter_run_end_encoding_array() {
1289        let run_ends = Int64Array::from(vec![2, 3, 8]);
1290        let values = Int64Array::from(vec![7, -2, 9]);
1291        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1292        let b = BooleanArray::from(vec![true, false, true, false, true, false, true, false]);
1293        let c = filter(&a, &b).unwrap();
1294        let actual: &RunArray<Int64Type> = as_run_array(&c);
1295        assert_eq!(4, actual.len());
1296
1297        let expected = RunArray::try_new(
1298            &Int64Array::from(vec![1, 2, 4]),
1299            &Int64Array::from(vec![7, -2, 9]),
1300        )
1301        .expect("Failed to make expected RunArray test is broken");
1302
1303        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1304        assert_eq!(actual.values(), expected.values())
1305    }
1306
1307    #[test]
1308    fn test_filter_run_end_encoding_array_remove_value() {
1309        let run_ends = Int32Array::from(vec![2, 3, 8, 10]);
1310        let values = Int32Array::from(vec![7, -2, 9, -8]);
1311        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1312        let b = BooleanArray::from(vec![
1313            false, true, false, false, true, false, true, false, false, false,
1314        ]);
1315        let c = filter(&a, &b).unwrap();
1316        let actual: &RunArray<Int32Type> = as_run_array(&c);
1317        assert_eq!(3, actual.len());
1318
1319        let expected =
1320            RunArray::try_new(&Int32Array::from(vec![1, 3]), &Int32Array::from(vec![7, 9]))
1321                .expect("Failed to make expected RunArray test is broken");
1322
1323        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1324        assert_eq!(actual.values(), expected.values())
1325    }
1326
1327    #[test]
1328    fn test_filter_run_end_encoding_array_remove_all_but_one() {
1329        let run_ends = Int16Array::from(vec![2, 3, 8, 10]);
1330        let values = Int16Array::from(vec![7, -2, 9, -8]);
1331        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1332        let b = BooleanArray::from(vec![
1333            false, false, false, false, false, false, true, false, false, false,
1334        ]);
1335        let c = filter(&a, &b).unwrap();
1336        let actual: &RunArray<Int16Type> = as_run_array(&c);
1337        assert_eq!(1, actual.len());
1338
1339        let expected = RunArray::try_new(&Int16Array::from(vec![1]), &Int16Array::from(vec![9]))
1340            .expect("Failed to make expected RunArray test is broken");
1341
1342        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1343        assert_eq!(actual.values(), expected.values())
1344    }
1345
1346    #[test]
1347    fn test_filter_run_end_encoding_array_empty() {
1348        let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1349        let values = Int64Array::from(vec![7, -2, 9, -8]);
1350        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1351        let b = BooleanArray::from(vec![
1352            false, false, false, false, false, false, false, false, false, false,
1353        ]);
1354        let c = filter(&a, &b).unwrap();
1355        let actual: &RunArray<Int64Type> = as_run_array(&c);
1356        assert_eq!(0, actual.len());
1357    }
1358
1359    #[test]
1360    fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() {
1361        let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1362        let values = Int64Array::from(vec![7, -2, 9, -8]);
1363        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1364        let b = BooleanArray::from(vec![false, true, true]);
1365        let c = filter(&a, &b).unwrap();
1366        let actual: &RunArray<Int64Type> = as_run_array(&c);
1367        assert_eq!(2, actual.len());
1368
1369        let expected = RunArray::try_new(
1370            &Int64Array::from(vec![1, 2]),
1371            &Int64Array::from(vec![7, -2]),
1372        )
1373        .expect("Failed to make expected RunArray test is broken");
1374
1375        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1376        assert_eq!(actual.values(), expected.values())
1377    }
1378
1379    #[test]
1380    fn test_filter_dictionary_array() {
1381        let values = [Some("hello"), None, Some("world"), Some("!")];
1382        let a: Int8DictionaryArray = values.iter().copied().collect();
1383        let b = BooleanArray::from(vec![false, true, true, false]);
1384        let c = filter(&a, &b).unwrap();
1385        let d = c
1386            .as_ref()
1387            .as_any()
1388            .downcast_ref::<Int8DictionaryArray>()
1389            .unwrap();
1390        let value_array = d.values();
1391        let values = value_array.as_any().downcast_ref::<StringArray>().unwrap();
1392        assert_eq!(3, values.len());
1394        assert_eq!(2, d.len());
1396        assert!(d.is_null(0));
1397        assert_eq!("world", values.value(d.keys().value(1) as usize));
1398    }
1399
1400    #[test]
1401    fn test_filter_list_array() {
1402        let value_data = ArrayData::builder(DataType::Int32)
1403            .len(8)
1404            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
1405            .build()
1406            .unwrap();
1407
1408        let value_offsets = Buffer::from_slice_ref([0i64, 3, 6, 8, 8]);
1409
1410        let list_data_type =
1411            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
1412        let list_data = ArrayData::builder(list_data_type)
1413            .len(4)
1414            .add_buffer(value_offsets)
1415            .add_child_data(value_data)
1416            .null_bit_buffer(Some(Buffer::from([0b00000111])))
1417            .build()
1418            .unwrap();
1419
1420        let a = LargeListArray::from(list_data);
1422        let b = BooleanArray::from(vec![false, true, false, true]);
1423        let result = filter(&a, &b).unwrap();
1424
1425        let value_data = ArrayData::builder(DataType::Int32)
1427            .len(3)
1428            .add_buffer(Buffer::from_slice_ref([3, 4, 5]))
1429            .build()
1430            .unwrap();
1431
1432        let value_offsets = Buffer::from_slice_ref([0i64, 3, 3]);
1433
1434        let list_data_type =
1435            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
1436        let expected = ArrayData::builder(list_data_type)
1437            .len(2)
1438            .add_buffer(value_offsets)
1439            .add_child_data(value_data)
1440            .null_bit_buffer(Some(Buffer::from([0b00000001])))
1441            .build()
1442            .unwrap();
1443
1444        assert_eq!(&make_array(expected), &result);
1445    }
1446
1447    fn test_case_filter_list_view<T: OffsetSizeTrait>() {
1448        let mut list_array = GenericListViewBuilder::<T, _>::new(Int32Builder::new());
1450        list_array.append_value([Some(1), Some(2)]);
1451        list_array.append_null();
1452        list_array.append_value([]);
1453        list_array.append_value([Some(3), Some(4)]);
1454
1455        let list_array = list_array.finish();
1456        let predicate = BooleanArray::from_iter([true, false, true, false]);
1457
1458        let filtered = filter(&list_array, &predicate)
1460            .unwrap()
1461            .as_list_view::<T>()
1462            .clone();
1463
1464        let mut expected =
1465            GenericListViewBuilder::<T, _>::with_capacity(Int32Builder::with_capacity(5), 3);
1466        expected.append_value([Some(1), Some(2)]);
1467        expected.append_value([]);
1468        let expected = expected.finish();
1469
1470        assert_eq!(&filtered, &expected);
1471    }
1472
1473    fn test_case_filter_sliced_list_view<T: OffsetSizeTrait>() {
1474        let mut list_array =
1476            GenericListViewBuilder::<T, _>::with_capacity(Int32Builder::with_capacity(6), 4);
1477        list_array.append_value([Some(1), Some(2)]);
1478        list_array.append_null();
1479        list_array.append_value([]);
1480        list_array.append_value([Some(3), Some(4)]);
1481
1482        let list_array = list_array.finish();
1483
1484        let sliced = list_array.slice(1, 3);
1486        let predicate = BooleanArray::from_iter([false, false, true]);
1487
1488        let filtered = filter(&sliced, &predicate)
1490            .unwrap()
1491            .as_list_view::<T>()
1492            .clone();
1493
1494        let mut expected = GenericListViewBuilder::<T, _>::new(Int32Builder::new());
1495        expected.append_value([Some(3), Some(4)]);
1496        let expected = expected.finish();
1497
1498        assert_eq!(&filtered, &expected);
1499    }
1500
1501    #[test]
1502    fn test_filter_list_view_array() {
1503        test_case_filter_list_view::<i32>();
1504        test_case_filter_list_view::<i64>();
1505
1506        test_case_filter_sliced_list_view::<i32>();
1507        test_case_filter_sliced_list_view::<i64>();
1508    }
1509
1510    #[test]
1511    fn test_slice_iterator_bits() {
1512        let filter_values = (0..64).map(|i| i == 1).collect::<Vec<bool>>();
1513        let filter = BooleanArray::from(filter_values);
1514        let filter_count = filter_count(&filter);
1515
1516        let iter = SlicesIterator::new(&filter);
1517        let chunks = iter.collect::<Vec<_>>();
1518
1519        assert_eq!(chunks, vec![(1, 2)]);
1520        assert_eq!(filter_count, 1);
1521    }
1522
1523    #[test]
1524    fn test_slice_iterator_bits1() {
1525        let filter_values = (0..64).map(|i| i != 1).collect::<Vec<bool>>();
1526        let filter = BooleanArray::from(filter_values);
1527        let filter_count = filter_count(&filter);
1528
1529        let iter = SlicesIterator::new(&filter);
1530        let chunks = iter.collect::<Vec<_>>();
1531
1532        assert_eq!(chunks, vec![(0, 1), (2, 64)]);
1533        assert_eq!(filter_count, 64 - 1);
1534    }
1535
1536    #[test]
1537    fn test_slice_iterator_chunk_and_bits() {
1538        let filter_values = (0..130).map(|i| i % 62 != 0).collect::<Vec<bool>>();
1539        let filter = BooleanArray::from(filter_values);
1540        let filter_count = filter_count(&filter);
1541
1542        let iter = SlicesIterator::new(&filter);
1543        let chunks = iter.collect::<Vec<_>>();
1544
1545        assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]);
1546        assert_eq!(filter_count, 61 + 61 + 5);
1547    }
1548
1549    #[test]
1550    fn test_null_mask() {
1551        let a = Int64Array::from(vec![Some(1), Some(2), None]);
1552
1553        let mask1 = BooleanArray::from(vec![Some(true), Some(true), None]);
1554        let out = filter(&a, &mask1).unwrap();
1555        assert_eq!(out.as_ref(), &a.slice(0, 2));
1556    }
1557
1558    #[test]
1559    fn test_filter_record_batch_no_columns() {
1560        let pred = BooleanArray::from(vec![Some(true), Some(true), None]);
1561        let options = RecordBatchOptions::default().with_row_count(Some(100));
1562        let record_batch =
1563            RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &options).unwrap();
1564        let out = filter_record_batch(&record_batch, &pred).unwrap();
1565
1566        assert_eq!(out.num_rows(), 2);
1567    }
1568
1569    #[test]
1570    fn test_fast_path() {
1571        let a: PrimitiveArray<Int64Type> = PrimitiveArray::from(vec![Some(1), Some(2), None]);
1572
1573        let mask = BooleanArray::from(vec![true, true, true]);
1575        let out = filter(&a, &mask).unwrap();
1576        let b = out
1577            .as_any()
1578            .downcast_ref::<PrimitiveArray<Int64Type>>()
1579            .unwrap();
1580        assert_eq!(&a, b);
1581
1582        let mask = BooleanArray::from(vec![false, false, false]);
1584        let out = filter(&a, &mask).unwrap();
1585        assert_eq!(out.len(), 0);
1586        assert_eq!(out.data_type(), &DataType::Int64);
1587    }
1588
1589    #[test]
1590    fn test_slices() {
1591        let bools = std::iter::repeat_n(true, 10)
1593            .chain(std::iter::repeat_n(false, 30))
1594            .chain(std::iter::repeat_n(true, 20))
1595            .chain(std::iter::repeat_n(false, 17))
1596            .chain(std::iter::repeat_n(true, 4));
1597
1598        let bool_array: BooleanArray = bools.map(Some).collect();
1599
1600        let slices: Vec<_> = SlicesIterator::new(&bool_array).collect();
1601        let expected = vec![(0, 10), (40, 60), (77, 81)];
1602        assert_eq!(slices, expected);
1603
1604        let len = bool_array.len();
1606        let sliced_array = bool_array.slice(7, len - 10);
1607        let sliced_array = sliced_array
1608            .as_any()
1609            .downcast_ref::<BooleanArray>()
1610            .unwrap();
1611        let slices: Vec<_> = SlicesIterator::new(sliced_array).collect();
1612        let expected = vec![(0, 3), (33, 53), (70, 71)];
1613        assert_eq!(slices, expected);
1614    }
1615
1616    fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) {
1617        let mut rng = rng();
1618
1619        let bools: Vec<bool> = std::iter::from_fn(|| Some(rng.random()))
1620            .take(mask_len)
1621            .collect();
1622
1623        let buffer = Buffer::from_iter(bools.iter().cloned());
1624
1625        let truncated_length = mask_len - offset - truncate;
1626
1627        let data = ArrayDataBuilder::new(DataType::Boolean)
1628            .len(truncated_length)
1629            .offset(offset)
1630            .add_buffer(buffer)
1631            .build()
1632            .unwrap();
1633
1634        let filter = BooleanArray::from(data);
1635
1636        let slice_bits: Vec<_> = SlicesIterator::new(&filter)
1637            .flat_map(|(start, end)| start..end)
1638            .collect();
1639
1640        let count = filter_count(&filter);
1641        let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect();
1642
1643        let expected_bits: Vec<_> = bools
1644            .iter()
1645            .skip(offset)
1646            .take(truncated_length)
1647            .enumerate()
1648            .flat_map(|(idx, v)| v.then(|| idx))
1649            .collect();
1650
1651        assert_eq!(slice_bits, expected_bits);
1652        assert_eq!(index_bits, expected_bits);
1653    }
1654
1655    #[test]
1656    #[cfg_attr(miri, ignore)]
1657    fn fuzz_test_slices_iterator() {
1658        let mut rng = rng();
1659
1660        let uusize = UniformUsize::new(usize::MIN, usize::MAX).unwrap();
1661        for _ in 0..100 {
1662            let mask_len = rng.random_range(0..1024);
1663            let max_offset = 64.min(mask_len);
1664            let offset = uusize.sample(&mut rng).checked_rem(max_offset).unwrap_or(0);
1665
1666            let max_truncate = 128.min(mask_len - offset);
1667            let truncate = uusize
1668                .sample(&mut rng)
1669                .checked_rem(max_truncate)
1670                .unwrap_or(0);
1671
1672            test_slices_fuzz(mask_len, offset, truncate);
1673        }
1674
1675        test_slices_fuzz(64, 0, 0);
1676        test_slices_fuzz(64, 8, 0);
1677        test_slices_fuzz(64, 8, 8);
1678        test_slices_fuzz(32, 8, 8);
1679        test_slices_fuzz(32, 5, 9);
1680    }
1681
1682    fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> {
1684        values
1685            .into_iter()
1686            .zip(predicate)
1687            .filter(|(_, x)| **x)
1688            .map(|(a, _)| a)
1689            .collect()
1690    }
1691
1692    fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>>
1694    where
1695        StandardUniform: Distribution<T>,
1696    {
1697        let mut rng = rng();
1698        (0..len)
1699            .map(|_| rng.random_bool(valid_percent).then(|| rng.random()))
1700            .collect()
1701    }
1702
1703    fn gen_strings(
1705        len: usize,
1706        valid_percent: f64,
1707        str_len_range: std::ops::Range<usize>,
1708    ) -> Vec<Option<String>> {
1709        let mut rng = rng();
1710        (0..len)
1711            .map(|_| {
1712                rng.random_bool(valid_percent).then(|| {
1713                    let len = rng.random_range(str_len_range.clone());
1714                    (0..len)
1715                        .map(|_| char::from(rng.sample(Alphanumeric)))
1716                        .collect()
1717                })
1718            })
1719            .collect()
1720    }
1721
1722    fn as_deref<T: std::ops::Deref>(src: &[Option<T>]) -> impl Iterator<Item = Option<&T::Target>> {
1724        src.iter().map(|x| x.as_deref())
1725    }
1726
1727    #[test]
1728    #[cfg_attr(miri, ignore)]
1729    fn fuzz_filter() {
1730        let mut rng = rng();
1731
1732        for i in 0..100 {
1733            let filter_percent = match i {
1734                0..=4 => 1.,
1735                5..=10 => 0.,
1736                _ => rng.random_range(0.0..1.0),
1737            };
1738
1739            let valid_percent = rng.random_range(0.0..1.0);
1740
1741            let array_len = rng.random_range(32..256);
1742            let array_offset = rng.random_range(0..10);
1743
1744            let filter_offset = rng.random_range(0..10);
1746            let filter_truncate = rng.random_range(0..10);
1747            let bools: Vec<_> = std::iter::from_fn(|| Some(rng.random_bool(filter_percent)))
1748                .take(array_len + filter_offset - filter_truncate)
1749                .collect();
1750
1751            let predicate = BooleanArray::from_iter(bools.iter().cloned().map(Some));
1752
1753            let predicate = predicate.slice(filter_offset, array_len - filter_truncate);
1755            let predicate = predicate.as_any().downcast_ref::<BooleanArray>().unwrap();
1756            let bools = &bools[filter_offset..];
1757
1758            let values = gen_primitive(array_len + array_offset, valid_percent);
1760            let src = Int32Array::from_iter(values.iter().cloned());
1761
1762            let src = src.slice(array_offset, array_len);
1763            let src = src.as_any().downcast_ref::<Int32Array>().unwrap();
1764            let values = &values[array_offset..];
1765
1766            let filtered = filter(src, predicate).unwrap();
1767            let array = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
1768            let actual: Vec<_> = array.iter().collect();
1769
1770            assert_eq!(actual, filter_rust(values.iter().cloned(), bools));
1771
1772            let strings = gen_strings(array_len + array_offset, valid_percent, 0..20);
1774            let src = StringArray::from_iter(as_deref(&strings));
1775
1776            let src = src.slice(array_offset, array_len);
1777            let src = src.as_any().downcast_ref::<StringArray>().unwrap();
1778
1779            let filtered = filter(src, predicate).unwrap();
1780            let array = filtered.as_any().downcast_ref::<StringArray>().unwrap();
1781            let actual: Vec<_> = array.iter().collect();
1782
1783            let expected_strings = filter_rust(as_deref(&strings[array_offset..]), bools);
1784            assert_eq!(actual, expected_strings);
1785
1786            let src = DictionaryArray::<Int32Type>::from_iter(as_deref(&strings));
1788
1789            let src = src.slice(array_offset, array_len);
1790            let src = src
1791                .as_any()
1792                .downcast_ref::<DictionaryArray<Int32Type>>()
1793                .unwrap();
1794
1795            let filtered = filter(src, predicate).unwrap();
1796
1797            let array = filtered
1798                .as_any()
1799                .downcast_ref::<DictionaryArray<Int32Type>>()
1800                .unwrap();
1801
1802            let values = array
1803                .values()
1804                .as_any()
1805                .downcast_ref::<StringArray>()
1806                .unwrap();
1807
1808            let actual: Vec<_> = array
1809                .keys()
1810                .iter()
1811                .map(|key| key.map(|key| values.value(key as usize)))
1812                .collect();
1813
1814            assert_eq!(actual, expected_strings);
1815        }
1816    }
1817
1818    #[test]
1819    fn test_filter_map() {
1820        let mut builder =
1821            MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(4));
1822        builder.keys().append_value("key1");
1824        builder.values().append_value(1);
1825        builder.append(true).unwrap();
1826        builder.keys().append_value("key2");
1827        builder.keys().append_value("key3");
1828        builder.values().append_value(2);
1829        builder.values().append_value(3);
1830        builder.append(true).unwrap();
1831        builder.append(false).unwrap();
1832        builder.keys().append_value("key1");
1833        builder.values().append_value(1);
1834        builder.append(true).unwrap();
1835        let maparray = Arc::new(builder.finish()) as ArrayRef;
1836
1837        let indices = vec![Some(true), Some(false), Some(false), Some(true)]
1838            .into_iter()
1839            .collect::<BooleanArray>();
1840        let got = filter(&maparray, &indices).unwrap();
1841
1842        let mut builder =
1843            MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(2));
1844        builder.keys().append_value("key1");
1845        builder.values().append_value(1);
1846        builder.append(true).unwrap();
1847        builder.keys().append_value("key1");
1848        builder.values().append_value(1);
1849        builder.append(true).unwrap();
1850        let expected = Arc::new(builder.finish()) as ArrayRef;
1851
1852        assert_eq!(&expected, &got);
1853    }
1854
1855    #[test]
1856    fn test_filter_fixed_size_list_arrays() {
1857        let value_data = ArrayData::builder(DataType::Int32)
1858            .len(9)
1859            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8]))
1860            .build()
1861            .unwrap();
1862        let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 3, false);
1863        let list_data = ArrayData::builder(list_data_type)
1864            .len(3)
1865            .add_child_data(value_data)
1866            .build()
1867            .unwrap();
1868        let array = FixedSizeListArray::from(list_data);
1869
1870        let filter_array = BooleanArray::from(vec![true, false, false]);
1871
1872        let c = filter(&array, &filter_array).unwrap();
1873        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1874
1875        assert_eq!(filtered.len(), 1);
1876
1877        let list = filtered.value(0);
1878        assert_eq!(
1879            &[0, 1, 2],
1880            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1881        );
1882
1883        let filter_array = BooleanArray::from(vec![true, false, true]);
1884
1885        let c = filter(&array, &filter_array).unwrap();
1886        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1887
1888        assert_eq!(filtered.len(), 2);
1889
1890        let list = filtered.value(0);
1891        assert_eq!(
1892            &[0, 1, 2],
1893            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1894        );
1895        let list = filtered.value(1);
1896        assert_eq!(
1897            &[6, 7, 8],
1898            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1899        );
1900    }
1901
1902    #[test]
1903    fn test_filter_fixed_size_list_arrays_with_null() {
1904        let value_data = ArrayData::builder(DataType::Int32)
1905            .len(10)
1906            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
1907            .build()
1908            .unwrap();
1909
1910        let mut null_bits: [u8; 1] = [0; 1];
1914        bit_util::set_bit(&mut null_bits, 0);
1915        bit_util::set_bit(&mut null_bits, 3);
1916        bit_util::set_bit(&mut null_bits, 4);
1917
1918        let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 2, false);
1919        let list_data = ArrayData::builder(list_data_type)
1920            .len(5)
1921            .add_child_data(value_data)
1922            .null_bit_buffer(Some(Buffer::from(null_bits)))
1923            .build()
1924            .unwrap();
1925        let array = FixedSizeListArray::from(list_data);
1926
1927        let filter_array = BooleanArray::from(vec![true, true, false, true, false]);
1928
1929        let c = filter(&array, &filter_array).unwrap();
1930        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1931
1932        assert_eq!(filtered.len(), 3);
1933
1934        let list = filtered.value(0);
1935        assert_eq!(
1936            &[0, 1],
1937            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1938        );
1939        assert!(filtered.is_null(1));
1940        let list = filtered.value(2);
1941        assert_eq!(
1942            &[6, 7],
1943            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1944        );
1945    }
1946
1947    fn test_filter_union_array(array: UnionArray) {
1948        let filter_array = BooleanArray::from(vec![true, false, false]);
1949        let c = filter(&array, &filter_array).unwrap();
1950        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1951
1952        let mut builder = UnionBuilder::new_dense();
1953        builder.append::<Int32Type>("A", 1).unwrap();
1954        let expected_array = builder.build().unwrap();
1955
1956        compare_union_arrays(filtered, &expected_array);
1957
1958        let filter_array = BooleanArray::from(vec![true, false, true]);
1959        let c = filter(&array, &filter_array).unwrap();
1960        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1961
1962        let mut builder = UnionBuilder::new_dense();
1963        builder.append::<Int32Type>("A", 1).unwrap();
1964        builder.append::<Int32Type>("A", 34).unwrap();
1965        let expected_array = builder.build().unwrap();
1966
1967        compare_union_arrays(filtered, &expected_array);
1968
1969        let filter_array = BooleanArray::from(vec![true, true, false]);
1970        let c = filter(&array, &filter_array).unwrap();
1971        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1972
1973        let mut builder = UnionBuilder::new_dense();
1974        builder.append::<Int32Type>("A", 1).unwrap();
1975        builder.append::<Float64Type>("B", 3.2).unwrap();
1976        let expected_array = builder.build().unwrap();
1977
1978        compare_union_arrays(filtered, &expected_array);
1979    }
1980
1981    #[test]
1982    fn test_filter_union_array_dense() {
1983        let mut builder = UnionBuilder::new_dense();
1984        builder.append::<Int32Type>("A", 1).unwrap();
1985        builder.append::<Float64Type>("B", 3.2).unwrap();
1986        builder.append::<Int32Type>("A", 34).unwrap();
1987        let array = builder.build().unwrap();
1988
1989        test_filter_union_array(array);
1990    }
1991
1992    #[test]
1993    fn test_filter_run_union_array_dense() {
1994        let mut builder = UnionBuilder::new_dense();
1995        builder.append::<Int32Type>("A", 1).unwrap();
1996        builder.append::<Int32Type>("A", 3).unwrap();
1997        builder.append::<Int32Type>("A", 34).unwrap();
1998        let array = builder.build().unwrap();
1999
2000        let filter_array = BooleanArray::from(vec![true, true, false]);
2001        let c = filter(&array, &filter_array).unwrap();
2002        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2003
2004        let mut builder = UnionBuilder::new_dense();
2005        builder.append::<Int32Type>("A", 1).unwrap();
2006        builder.append::<Int32Type>("A", 3).unwrap();
2007        let expected = builder.build().unwrap();
2008
2009        assert_eq!(filtered.to_data(), expected.to_data());
2010    }
2011
2012    #[test]
2013    fn test_filter_union_array_dense_with_nulls() {
2014        let mut builder = UnionBuilder::new_dense();
2015        builder.append::<Int32Type>("A", 1).unwrap();
2016        builder.append::<Float64Type>("B", 3.2).unwrap();
2017        builder.append_null::<Float64Type>("B").unwrap();
2018        builder.append::<Int32Type>("A", 34).unwrap();
2019        let array = builder.build().unwrap();
2020
2021        let filter_array = BooleanArray::from(vec![true, true, false, false]);
2022        let c = filter(&array, &filter_array).unwrap();
2023        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2024
2025        let mut builder = UnionBuilder::new_dense();
2026        builder.append::<Int32Type>("A", 1).unwrap();
2027        builder.append::<Float64Type>("B", 3.2).unwrap();
2028        let expected_array = builder.build().unwrap();
2029
2030        compare_union_arrays(filtered, &expected_array);
2031
2032        let filter_array = BooleanArray::from(vec![true, false, true, false]);
2033        let c = filter(&array, &filter_array).unwrap();
2034        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2035
2036        let mut builder = UnionBuilder::new_dense();
2037        builder.append::<Int32Type>("A", 1).unwrap();
2038        builder.append_null::<Float64Type>("B").unwrap();
2039        let expected_array = builder.build().unwrap();
2040
2041        compare_union_arrays(filtered, &expected_array);
2042    }
2043
2044    #[test]
2045    fn test_filter_union_array_sparse() {
2046        let mut builder = UnionBuilder::new_sparse();
2047        builder.append::<Int32Type>("A", 1).unwrap();
2048        builder.append::<Float64Type>("B", 3.2).unwrap();
2049        builder.append::<Int32Type>("A", 34).unwrap();
2050        let array = builder.build().unwrap();
2051
2052        test_filter_union_array(array);
2053    }
2054
2055    #[test]
2056    fn test_filter_union_array_sparse_with_nulls() {
2057        let mut builder = UnionBuilder::new_sparse();
2058        builder.append::<Int32Type>("A", 1).unwrap();
2059        builder.append::<Float64Type>("B", 3.2).unwrap();
2060        builder.append_null::<Float64Type>("B").unwrap();
2061        builder.append::<Int32Type>("A", 34).unwrap();
2062        let array = builder.build().unwrap();
2063
2064        let filter_array = BooleanArray::from(vec![true, false, true, false]);
2065        let c = filter(&array, &filter_array).unwrap();
2066        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2067
2068        let mut builder = UnionBuilder::new_sparse();
2069        builder.append::<Int32Type>("A", 1).unwrap();
2070        builder.append_null::<Float64Type>("B").unwrap();
2071        let expected_array = builder.build().unwrap();
2072
2073        compare_union_arrays(filtered, &expected_array);
2074    }
2075
2076    fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) {
2077        assert_eq!(union1.len(), union2.len());
2078
2079        for i in 0..union1.len() {
2080            let type_id = union1.type_id(i);
2081
2082            let slot1 = union1.value(i);
2083            let slot2 = union2.value(i);
2084
2085            assert_eq!(slot1.is_null(0), slot2.is_null(0));
2086
2087            if !slot1.is_null(0) && !slot2.is_null(0) {
2088                match type_id {
2089                    0 => {
2090                        let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
2091                        assert_eq!(slot1.len(), 1);
2092                        let value1 = slot1.value(0);
2093
2094                        let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap();
2095                        assert_eq!(slot2.len(), 1);
2096                        let value2 = slot2.value(0);
2097                        assert_eq!(value1, value2);
2098                    }
2099                    1 => {
2100                        let slot1 = slot1.as_any().downcast_ref::<Float64Array>().unwrap();
2101                        assert_eq!(slot1.len(), 1);
2102                        let value1 = slot1.value(0);
2103
2104                        let slot2 = slot2.as_any().downcast_ref::<Float64Array>().unwrap();
2105                        assert_eq!(slot2.len(), 1);
2106                        let value2 = slot2.value(0);
2107                        assert_eq!(value1, value2);
2108                    }
2109                    _ => unreachable!(),
2110                }
2111            }
2112        }
2113    }
2114
2115    #[test]
2116    fn test_filter_struct() {
2117        let predicate = BooleanArray::from(vec![true, false, true, false]);
2118
2119        let a = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
2120        let a_filtered = Arc::new(StringArray::from(vec!["hello", "world"]));
2121
2122        let b = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
2123        let b_filtered = Arc::new(Int32Array::from(vec![5, 7]));
2124
2125        let null_mask = NullBuffer::from(vec![true, false, false, true]);
2126        let null_mask_filtered = NullBuffer::from(vec![true, false]);
2127
2128        let a_field = Field::new("a", DataType::Utf8, false);
2129        let b_field = Field::new("b", DataType::Int32, false);
2130
2131        let array = StructArray::new(vec![a_field.clone()].into(), vec![a.clone()], None);
2132        let expected =
2133            StructArray::new(vec![a_field.clone()].into(), vec![a_filtered.clone()], None);
2134
2135        let result = filter(&array, &predicate).unwrap();
2136
2137        assert_eq!(result.to_data(), expected.to_data());
2138
2139        let array = StructArray::new(
2140            vec![a_field.clone()].into(),
2141            vec![a.clone()],
2142            Some(null_mask.clone()),
2143        );
2144        let expected = StructArray::new(
2145            vec![a_field.clone()].into(),
2146            vec![a_filtered.clone()],
2147            Some(null_mask_filtered.clone()),
2148        );
2149
2150        let result = filter(&array, &predicate).unwrap();
2151
2152        assert_eq!(result.to_data(), expected.to_data());
2153
2154        let array = StructArray::new(
2155            vec![a_field.clone(), b_field.clone()].into(),
2156            vec![a.clone(), b.clone()],
2157            None,
2158        );
2159        let expected = StructArray::new(
2160            vec![a_field.clone(), b_field.clone()].into(),
2161            vec![a_filtered.clone(), b_filtered.clone()],
2162            None,
2163        );
2164
2165        let result = filter(&array, &predicate).unwrap();
2166
2167        assert_eq!(result.to_data(), expected.to_data());
2168
2169        let array = StructArray::new(
2170            vec![a_field.clone(), b_field.clone()].into(),
2171            vec![a.clone(), b.clone()],
2172            Some(null_mask.clone()),
2173        );
2174
2175        let expected = StructArray::new(
2176            vec![a_field.clone(), b_field.clone()].into(),
2177            vec![a_filtered.clone(), b_filtered.clone()],
2178            Some(null_mask_filtered.clone()),
2179        );
2180
2181        let result = filter(&array, &predicate).unwrap();
2182
2183        assert_eq!(result.to_data(), expected.to_data());
2184    }
2185
2186    #[test]
2187    fn test_filter_empty_struct() {
2188        let fields = arrow_schema::Field::new(
2195            "a",
2196            arrow_schema::DataType::Struct(arrow_schema::Fields::from(vec![
2197                arrow_schema::Field::new("b", arrow_schema::DataType::Int64, true),
2198                arrow_schema::Field::new(
2199                    "c",
2200                    arrow_schema::DataType::Struct(arrow_schema::Fields::empty()),
2201                    true,
2202                ),
2203            ])),
2204            true,
2205        );
2206
2207        let schema = Arc::new(Schema::new(vec![fields]));
2215
2216        let b = Arc::new(Int64Array::from(vec![None, None, None]));
2217        let c = Arc::new(StructArray::new_empty_fields(
2218            3,
2219            Some(NullBuffer::from(vec![true, true, true])),
2220        ));
2221        let a = StructArray::new(
2222            vec![
2223                Field::new("b", DataType::Int64, true),
2224                Field::new("c", DataType::Struct(Fields::empty()), true),
2225            ]
2226            .into(),
2227            vec![b.clone(), c.clone()],
2228            Some(NullBuffer::from(vec![true, true, true])),
2229        );
2230        let record_batch = RecordBatch::try_new(schema, vec![Arc::new(a)]).unwrap();
2231        println!("{record_batch:?}");
2232
2233        let predicate = BooleanArray::from(vec![true, false, true]);
2235        let filtered_batch = filter_record_batch(&record_batch, &predicate).unwrap();
2236
2237        assert_eq!(filtered_batch.num_rows(), 2);
2239    }
2240}