arrow_select/
filter.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines filter kernels
19
20use std::ops::AddAssign;
21use std::sync::Arc;
22
23use arrow_array::builder::BooleanBufferBuilder;
24use arrow_array::cast::AsArray;
25use arrow_array::types::{
26    ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
27};
28use arrow_array::*;
29use arrow_buffer::{ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer, bit_util};
30use arrow_buffer::{Buffer, MutableBuffer};
31use arrow_data::ArrayDataBuilder;
32use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
33use arrow_data::transform::MutableArrayData;
34use arrow_schema::*;
35
/// Selectivity threshold used to pick the default iteration strategy.
///
/// If the filter selects more than this fraction of rows, use
/// [`SlicesIterator`] to copy ranges of values. Otherwise iterate
/// over individual rows using [`IndexIterator`]
///
/// Threshold of 0.8 chosen based on <https://dl.acm.org/doi/abs/10.1145/3465998.3466009>
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
43
/// An iterator of `(usize, usize)` each representing an interval
/// `[start, end)` whose slots of a bitmap ([`BooleanBuffer`]) are true.
///
/// Each interval corresponds to a contiguous region of memory to be
/// "taken" from an array to be filtered.
///
/// This is a thin wrapper around a [`BitSliceIterator`].
///
/// ## Notes:
///
/// 1. Ignores the validity bitmap (ignores nulls)
///
/// 2. Only performant for filters that copy across long contiguous runs
#[derive(Debug)]
pub struct SlicesIterator<'a>(BitSliceIterator<'a>);
57
58impl<'a> SlicesIterator<'a> {
59    /// Creates a new iterator from a [BooleanArray]
60    pub fn new(filter: &'a BooleanArray) -> Self {
61        filter.values().into()
62    }
63}
64
65impl<'a> From<&'a BooleanBuffer> for SlicesIterator<'a> {
66    fn from(filter: &'a BooleanBuffer) -> Self {
67        Self(filter.set_slices())
68    }
69}
70
71impl Iterator for SlicesIterator<'_> {
72    type Item = (usize, usize);
73
74    fn next(&mut self) -> Option<Self::Item> {
75        self.0.next()
76    }
77}
78
/// An iterator of `usize` whose index in [`BooleanArray`] is true
///
/// This provides the best performance on most predicates, apart from those which keep
/// large runs and therefore favour [`SlicesIterator`]
struct IndexIterator<'a> {
    /// Number of indices still to be yielded; reported as the exact size hint
    remaining: usize,
    /// Iterator over the positions of the set bits in the filter values
    iter: BitIndexIterator<'a>,
}
87
88impl<'a> IndexIterator<'a> {
89    fn new(filter: &'a BooleanArray, remaining: usize) -> Self {
90        assert_eq!(filter.null_count(), 0);
91        let iter = filter.values().set_indices();
92        Self { remaining, iter }
93    }
94}
95
impl Iterator for IndexIterator<'_> {
    type Item = usize;

    /// Returns the next set index, or `None` once `remaining` indices have
    /// been yielded.
    ///
    /// # Panics
    ///
    /// Panics if the underlying bit iterator is exhausted while `remaining`
    /// is still non-zero.
    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining != 0 {
            // Fascinatingly swapping these two lines around results in a 50%
            // performance regression for some benchmarks
            let next = self.iter.next().expect("IndexIterator exhausted early");
            self.remaining -= 1;
            // Must panic if exhausted early as trusted length iterator
            return Some(next);
        }
        None
    }

    /// Reports `remaining` as both the lower and the upper bound, which is
    /// what allows callers to treat this as a trusted length iterator
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}
115
116/// Counts the number of set bits in `filter`
117fn filter_count(filter: &BooleanArray) -> usize {
118    filter.values().count_set_bits()
119}
120
121/// Convert all null values in `BooleanArray` to `false`
122///
123/// This is useful for filter-like operations which select only `true`
124/// values, but not `false` or `NULL` values
125///
126/// Internally this is implemented as a bitwise `AND` operation with null bits
127/// and the boolean bits.
128///
129/// # Example
130/// ```
131/// # use arrow_array::{Array, BooleanArray};
132/// # use arrow_select::filter::prep_null_mask_filter;
133/// let filter = BooleanArray::from(vec![
134///   Some(true),
135///   Some(false),
136///   None
137/// ]);
138/// // convert Boolean array to a filter mask
139/// let null_mask = prep_null_mask_filter(&filter);
140/// // there are no nulls in the output mask
141/// assert!(null_mask.nulls().is_none());
142/// assert_eq!(null_mask, BooleanArray::from(vec![
143///  true,
144///  false,
145///  false, // Null is converted to false
146/// ]));
147/// ```
148pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
149    let nulls = filter.nulls().unwrap();
150    let mask = filter.values() & nulls.inner();
151    BooleanArray::new(mask, None)
152}
153
154/// Returns a filtered `values` [`Array`] where the corresponding elements of
155/// `predicate` are `true`.
156///
157/// If multiple arrays (or record batches) need to be filtered using the same predicate array,
158/// consider using [FilterBuilder] to create a single [FilterPredicate] and then
159/// calling [FilterPredicate::filter_record_batch].
160///
161/// In contrast to this function, it is then the responsibility of the caller
162/// to use [FilterBuilder::optimize] if appropriate.
163///
164/// # See also
165/// * [`FilterBuilder`] for more control over the filtering process.
166/// * [`filter_record_batch`] to filter a [`RecordBatch`]
167/// * [`BatchCoalescer`]: to filter multiple [`RecordBatch`] and coalesce
168///   the results into a single array.
169///
170/// [`BatchCoalescer`]: crate::coalesce::BatchCoalescer
171///
172/// # Example
173/// ```rust
174/// # use arrow_array::{Int32Array, BooleanArray};
175/// # use arrow_select::filter::filter;
176/// let array = Int32Array::from(vec![5, 6, 7, 8, 9]);
177/// let filter_array = BooleanArray::from(vec![true, false, false, true, false]);
178/// let c = filter(&array, &filter_array).unwrap();
179/// let c = c.as_any().downcast_ref::<Int32Array>().unwrap();
180/// assert_eq!(c, &Int32Array::from(vec![5, 8]));
181/// ```
182pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
183    let mut filter_builder = FilterBuilder::new(predicate);
184
185    if FilterBuilder::is_optimize_beneficial(values.data_type()) {
186        // Only optimize if filtering more than one array
187        // Otherwise, the overhead of optimization can be more than the benefit
188        filter_builder = filter_builder.optimize();
189    }
190
191    let predicate = filter_builder.build();
192
193    filter_array(values, &predicate)
194}
195
196/// Returns a filtered [RecordBatch] where the corresponding elements of
197/// `predicate` are true.
198///
199/// This is the equivalent of calling [filter] on each column of the [RecordBatch].
200///
201/// If multiple record batches (or arrays) need to be filtered using the same predicate array,
202/// consider using [FilterBuilder] to create a single [FilterPredicate] and then
203/// calling [FilterPredicate::filter_record_batch].
204/// In contrast to this function, it is then the responsibility of the caller
205/// to use [FilterBuilder::optimize] if appropriate.
206pub fn filter_record_batch(
207    record_batch: &RecordBatch,
208    predicate: &BooleanArray,
209) -> Result<RecordBatch, ArrowError> {
210    let mut filter_builder = FilterBuilder::new(predicate);
211    let num_cols = record_batch.num_columns();
212    if num_cols > 1
213        || (num_cols > 0
214            && FilterBuilder::is_optimize_beneficial(
215                record_batch.schema_ref().field(0).data_type(),
216            ))
217    {
218        // Only optimize if filtering more than one column or if the column contains multiple internal arrays
219        // Otherwise, the overhead of optimization can be more than the benefit
220        filter_builder = filter_builder.optimize();
221    }
222    let filter = filter_builder.build();
223
224    filter.filter_record_batch(record_batch)
225}
226
/// A builder to construct [`FilterPredicate`]
#[derive(Debug)]
pub struct FilterBuilder {
    /// The filter mask; [`FilterBuilder::new`] replaces any nulls with `false`
    filter: BooleanArray,
    /// Number of set bits in `filter`, i.e. the number of rows selected
    count: usize,
    /// How the set bits of `filter` will be iterated when filtering
    strategy: IterationStrategy,
}
234
235impl FilterBuilder {
236    /// Create a new [`FilterBuilder`] that can be used to construct a [`FilterPredicate`]
237    pub fn new(filter: &BooleanArray) -> Self {
238        let filter = match filter.null_count() {
239            0 => filter.clone(),
240            _ => prep_null_mask_filter(filter),
241        };
242
243        let count = filter_count(&filter);
244        let strategy = IterationStrategy::default_strategy(filter.len(), count);
245
246        Self {
247            filter,
248            count,
249            strategy,
250        }
251    }
252
253    /// Compute an optimized representation of the provided `filter` mask that can be
254    /// applied to an array more quickly.
255    ///
256    /// When filtering multiple arrays (e.g. a [`RecordBatch`] or a
257    /// [`StructArray`] with multiple fields), optimizing the filter can provide
258    /// significant performance benefits.
259    ///
260    /// However, optimization takes time and can have a larger memory footprint
261    /// than the original mask, so it is often faster to filter a single array,
262    /// without filter optimization.
263    pub fn optimize(mut self) -> Self {
264        match self.strategy {
265            IterationStrategy::SlicesIterator => {
266                let slices = SlicesIterator::new(&self.filter).collect();
267                self.strategy = IterationStrategy::Slices(slices)
268            }
269            IterationStrategy::IndexIterator => {
270                let indices = IndexIterator::new(&self.filter, self.count).collect();
271                self.strategy = IterationStrategy::Indices(indices)
272            }
273            _ => {}
274        }
275        self
276    }
277
278    /// Determines if calling [FilterBuilder::optimize] is beneficial for the
279    /// given type even when filtering just a single array.
280    ///
281    /// See [`FilterBuilder::optimize`] for more details.
282    pub fn is_optimize_beneficial(data_type: &DataType) -> bool {
283        match data_type {
284            DataType::Struct(fields) => {
285                fields.len() > 1
286                    || fields.len() == 1
287                        && FilterBuilder::is_optimize_beneficial(fields[0].data_type())
288            }
289            DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
290            _ => false,
291        }
292    }
293
294    /// Construct the final `FilterPredicate`
295    pub fn build(self) -> FilterPredicate {
296        FilterPredicate {
297            filter: self.filter,
298            count: self.count,
299            strategy: self.strategy,
300        }
301    }
302}
303
/// The iteration strategy used to evaluate [`FilterPredicate`]
///
/// The lazily evaluated variants are chosen by
/// [`IterationStrategy::default_strategy`]; the precomputed variants are
/// produced from them by [`FilterBuilder::optimize`].
#[derive(Debug)]
enum IterationStrategy {
    /// A lazily evaluated iterator of ranges
    SlicesIterator,
    /// A lazily evaluated iterator of indices
    IndexIterator,
    /// A precomputed list of indices
    Indices(Vec<usize>),
    /// A precomputed array of `(start, end)` ranges
    Slices(Vec<(usize, usize)>),
    /// Select all rows
    All,
    /// Select no rows
    None,
}
320
321impl IterationStrategy {
322    /// The default [`IterationStrategy`] for a filter of length `filter_length`
323    /// and selecting `filter_count` rows
324    fn default_strategy(filter_length: usize, filter_count: usize) -> Self {
325        if filter_length == 0 || filter_count == 0 {
326            return IterationStrategy::None;
327        }
328
329        if filter_count == filter_length {
330            return IterationStrategy::All;
331        }
332
333        // Compute the selectivity of the predicate by dividing the number of true
334        // bits in the predicate by the predicate's total length
335        //
336        // This can then be used as a heuristic for the optimal iteration strategy
337        let selectivity_frac = filter_count as f64 / filter_length as f64;
338        if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD {
339            return IterationStrategy::SlicesIterator;
340        }
341        IterationStrategy::IndexIterator
342    }
343}
344
/// A filtering predicate that can be applied to an [`Array`]
///
/// Constructed with [`FilterBuilder::build`].
#[derive(Debug)]
pub struct FilterPredicate {
    /// The filter mask, taken over from the [`FilterBuilder`]
    filter: BooleanArray,
    /// Number of rows selected by `filter`; see [`FilterPredicate::count`]
    count: usize,
    /// How the set bits of `filter` are iterated when filtering
    strategy: IterationStrategy,
}
352
353impl FilterPredicate {
354    /// Selects rows from `values` based on this [`FilterPredicate`]
355    pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef, ArrowError> {
356        filter_array(values, self)
357    }
358
359    /// Returns a filtered [`RecordBatch`] containing only the rows that are selected by this
360    /// [`FilterPredicate`].
361    ///
362    /// This is the equivalent of calling [filter] on each column of the [`RecordBatch`].
363    pub fn filter_record_batch(
364        &self,
365        record_batch: &RecordBatch,
366    ) -> Result<RecordBatch, ArrowError> {
367        let filtered_arrays = record_batch
368            .columns()
369            .iter()
370            .map(|a| filter_array(a, self))
371            .collect::<Result<Vec<_>, _>>()?;
372
373        // SAFETY: we know that the set of filtered arrays will match the schema of the original
374        // record batch
375        unsafe {
376            Ok(RecordBatch::new_unchecked(
377                record_batch.schema(),
378                filtered_arrays,
379                self.count,
380            ))
381        }
382    }
383
384    /// Number of rows being selected based on this [`FilterPredicate`]
385    pub fn count(&self) -> usize {
386        self.count
387    }
388}
389
/// Filters `values` with `predicate`, dispatching on the data type of `values`.
///
/// The [`IterationStrategy::None`] and [`IterationStrategy::All`] strategies are
/// answered directly with an empty array and a slice of `values` respectively.
/// All other strategies dispatch to a type-specific kernel; data types without
/// a dedicated arm fall back to [`MutableArrayData`].
///
/// # Errors
///
/// Returns [`ArrowError::InvalidArgumentError`] if the predicate is longer
/// than `values`. Errors from the run-end, struct and sparse union kernels
/// are propagated.
///
/// # Panics
///
/// Panics via `unimplemented!` in the fallback arms of the run-end encoded
/// and dictionary downcast macros.
fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> {
    // A predicate shorter than `values` is accepted; a longer one is an error
    if predicate.filter.len() > values.len() {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Filter predicate of length {} is larger than target array of length {}",
            predicate.filter.len(),
            values.len()
        )));
    }

    match predicate.strategy {
        IterationStrategy::None => Ok(new_empty_array(values.data_type())),
        IterationStrategy::All => Ok(values.slice(0, predicate.count)),
        // actually filter
        _ => downcast_primitive_array! {
            values => Ok(Arc::new(filter_primitive(values, predicate))),
            DataType::Boolean => {
                let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
                Ok(Arc::new(filter_boolean(values, predicate)))
            }
            DataType::Utf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i32>(), predicate)))
            }
            DataType::LargeUtf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
            }
            DataType::Utf8View => {
                Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
            }
            DataType::Binary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
            }
            DataType::LargeBinary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
            }
            DataType::BinaryView => {
                Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
            }
            DataType::FixedSizeBinary(_) => {
                Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate)))
            }
            DataType::ListView(_) => {
                Ok(Arc::new(filter_list_view::<i32>(values.as_list_view(), predicate)))
            }
            DataType::LargeListView(_) => {
                Ok(Arc::new(filter_list_view::<i64>(values.as_list_view(), predicate)))
            }
            DataType::RunEndEncoded(_, _) => {
                downcast_run_array!{
                    values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
                    t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t)
                }
            }
            DataType::Dictionary(_, _) => downcast_dictionary_array! {
                values => Ok(Arc::new(filter_dict(values, predicate))),
                t => unimplemented!("Filter not supported for dictionary type {:?}", t)
            }
            DataType::Struct(_) => {
                Ok(Arc::new(filter_struct(values.as_struct(), predicate)?))
            }
            DataType::Union(_, UnionMode::Sparse) => {
                Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?))
            }
            _ => {
                let data = values.to_data();
                // fallback to using MutableArrayData
                let mut mutable = MutableArrayData::new(
                    vec![&data],
                    false,
                    predicate.count,
                );

                // The fallback always copies ranges: precomputed slices are
                // reused, every other strategy re-derives them from the mask
                match &predicate.strategy {
                    IterationStrategy::Slices(slices) => {
                        slices
                            .iter()
                            .for_each(|(start, end)| mutable.extend(0, *start, *end));
                    }
                    _ => {
                        let iter = SlicesIterator::new(&predicate.filter);
                        iter.for_each(|(start, end)| mutable.extend(0, start, end));
                    }
                }

                let data = mutable.freeze();
                Ok(make_array(data))
            }
        },
    }
}
479
/// Filter any supported [`RunArray`] based on a [`FilterPredicate`]
///
/// Each run is scanned against the filter bits it covers: a run is kept if
/// at least one of its rows is selected, and its new run end is the running
/// total of selected rows up to and including that run. The values array is
/// then filtered with the resulting per-run mask.
///
/// # Errors
///
/// Propagates errors from filtering the values and from constructing the
/// new run ends array and [`RunArray`].
fn filter_run_end_array<R: RunEndIndexType>(
    array: &RunArray<R>,
    predicate: &FilterPredicate,
) -> Result<RunArray<R>, ArrowError>
where
    R::Native: Into<i64> + From<bool>,
    R::Native: AddAssign,
{
    let run_ends: &RunEndBuffer<R::Native> = array.run_ends();
    // Sized for the worst case of every run being kept; truncated below
    let mut new_run_ends = vec![R::default_value(); run_ends.len()];

    // Start of the current run within the filter
    let mut start = 0u64;
    // Number of runs kept so far, i.e. the next write position in `new_run_ends`
    let mut j = 0;
    // Running total of selected rows across all runs visited so far
    let mut count = R::default_value();
    let filter_values = predicate.filter.values();
    let run_ends = run_ends.inner();

    // NOTE(review): the closure carries state between calls (`start`, `j`,
    // `count`), so this presumably relies on `collect_bool` invoking it once
    // per index in increasing order — confirm
    let pred: BooleanArray = BooleanBuffer::collect_bool(run_ends.len(), |i| {
        let mut keep = false;
        let mut end = run_ends[i].into() as u64;
        // Clamp `end` to the filter length so a shorter filter is never overrun
        let difference = end.saturating_sub(filter_values.len() as u64);
        end -= difference;

        // Safety: we subtract the difference off `end` so we are always within bounds
        for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
            count += R::Native::from(pred);
            keep |= pred
        }
        // this is to avoid branching
        // (the slot is always written, but only claimed by advancing `j` when the run is kept)
        new_run_ends[j] = count;
        j += keep as usize;

        start = end;
        keep
    })
    .into();

    // Drop the unused tail left by discarded runs
    new_run_ends.truncate(j);

    let values = array.values();
    let values = filter(&values, &pred)?;

    let run_ends = PrimitiveArray::<R>::try_new(new_run_ends.into(), None)?;
    RunArray::try_new(&run_ends, &values)
}
526
527/// Computes a new null mask for `data` based on `predicate`
528///
529/// If the predicate selected no null-rows, returns `None`, otherwise returns
530/// `Some((null_count, null_buffer))` where `null_count` is the number of nulls
531/// in the filtered output, and `null_buffer` is the filtered null buffer
532///
533fn filter_null_mask(
534    nulls: Option<&NullBuffer>,
535    predicate: &FilterPredicate,
536) -> Option<(usize, Buffer)> {
537    let nulls = nulls?;
538    if nulls.null_count() == 0 {
539        return None;
540    }
541
542    let nulls = filter_bits(nulls.inner(), predicate);
543    // The filtered `nulls` has a length of `predicate.count` bits and
544    // therefore the null count is this minus the number of valid bits
545    let null_count = predicate.count - nulls.count_set_bits_offset(0, predicate.count);
546
547    if null_count == 0 {
548        return None;
549    }
550
551    Some((null_count, nulls))
552}
553
/// Filter the packed bitmask `buffer` with `predicate`, returning the selected
/// bits packed into a new [`Buffer`]
///
/// The bit offset of `buffer` is added to every source position that is read.
///
/// # Panics
///
/// Panics if `buffer` is shorter than the predicate's filter, or if the
/// strategy is [`IterationStrategy::All`] or [`IterationStrategy::None`]
/// (callers are expected to handle those before reaching this point).
fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer {
    let src = buffer.values();
    let offset = buffer.offset();
    // Establishes the bound that the unchecked bit reads below depend on
    assert!(buffer.len() >= predicate.filter.len());

    match &predicate.strategy {
        IterationStrategy::IndexIterator => {
            let bits =
                // SAFETY: IndexIterator uses the filter predicate to derive indices
                IndexIterator::new(&predicate.filter, predicate.count).map(|src_idx| unsafe {
                    bit_util::get_bit_raw(buffer.values().as_ptr(), src_idx + offset)
                });

            // SAFETY: `IndexIterator` reports its size correctly
            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
        }
        IterationStrategy::Indices(indices) => {
            // SAFETY: indices were derived from the filter predicate
            let bits = indices.iter().map(|src_idx| unsafe {
                bit_util::get_bit_raw(buffer.values().as_ptr(), *src_idx + offset)
            });
            // SAFETY: `Vec::iter()` reports its size correctly
            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
        }
        IterationStrategy::SlicesIterator => {
            // Copy whole runs of bits at a time rather than bit by bit
            let mut builder = BooleanBufferBuilder::new(predicate.count);
            for (start, end) in SlicesIterator::new(&predicate.filter) {
                builder.append_packed_range(start + offset..end + offset, src)
            }
            builder.into()
        }
        IterationStrategy::Slices(slices) => {
            let mut builder = BooleanBufferBuilder::new(predicate.count);
            for (start, end) in slices {
                builder.append_packed_range(*start + offset..*end + offset, src)
            }
            builder.into()
        }
        IterationStrategy::All | IterationStrategy::None => unreachable!(),
    }
}
596
597/// `filter` implementation for boolean buffers
598fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
599    let values = filter_bits(array.values(), predicate);
600
601    let mut builder = ArrayDataBuilder::new(DataType::Boolean)
602        .len(predicate.count)
603        .add_buffer(values);
604
605    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
606        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
607    }
608
609    let data = unsafe { builder.build_unchecked() };
610    BooleanArray::from(data)
611}
612
/// Copies the elements of `values` selected by `predicate` into a new [`Buffer`]
///
/// # Panics
///
/// Panics if `values` is shorter than the predicate's filter, or if the
/// strategy is [`IterationStrategy::All`] or [`IterationStrategy::None`]
/// (callers are expected to handle those before reaching this point).
#[inline(never)]
fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer {
    // Establishes the bound that the unchecked accesses below depend on
    assert!(values.len() >= predicate.filter.len());

    match &predicate.strategy {
        IterationStrategy::SlicesIterator => {
            let mut buffer = Vec::with_capacity(predicate.count);
            for (start, end) in SlicesIterator::new(&predicate.filter) {
                // SAFETY: indices were derived from the filter predicate
                buffer.extend_from_slice(unsafe { values.get_unchecked(start..end) });
            }
            buffer.into()
        }
        IterationStrategy::Slices(slices) => {
            let mut buffer = Vec::with_capacity(predicate.count);
            for (start, end) in slices {
                // SAFETY: indices were derived from the filter predicate
                buffer.extend_from_slice(unsafe { values.get_unchecked(*start..*end) });
            }
            buffer.into()
        }
        IterationStrategy::IndexIterator => {
            // SAFETY: indices were derived from the filter predicate
            let iter = IndexIterator::new(&predicate.filter, predicate.count)
                .map(|x| unsafe { *values.get_unchecked(x) });

            // SAFETY: IndexIterator is trusted length
            unsafe { MutableBuffer::from_trusted_len_iter(iter) }.into()
        }
        IterationStrategy::Indices(indices) => {
            // SAFETY: indices were derived from the filter predicate
            let iter = indices.iter().map(|x| unsafe { *values.get_unchecked(*x) });
            iter.collect::<Vec<_>>().into()
        }
        IterationStrategy::All | IterationStrategy::None => unreachable!(),
    }
}
650
651/// `filter` implementation for primitive arrays
652fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
653where
654    T: ArrowPrimitiveType,
655{
656    let values = array.values();
657    let buffer = filter_native(values, predicate);
658    let mut builder = ArrayDataBuilder::new(array.data_type().clone())
659        .len(predicate.count)
660        .add_buffer(buffer);
661
662    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
663        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
664    }
665
666    let data = unsafe { builder.build_unchecked() };
667    PrimitiveArray::from(data)
668}
669
/// [`FilterBytes`] is created from a source [`GenericByteArray`] and can be
/// used to build a new [`GenericByteArray`] by copying values from the source
///
/// TODO(raphael): Could this be used for the take kernel as well?
struct FilterBytes<'a, OffsetSize> {
    /// Value offsets of the source array
    src_offsets: &'a [OffsetSize],
    /// Value bytes of the source array
    src_values: &'a [u8],
    /// Offsets of the array being built; seeded with a single zero offset
    dst_offsets: Vec<OffsetSize>,
    /// Value bytes of the array being built
    dst_values: Vec<u8>,
    /// Running total of the lengths of the values added to `dst_offsets` so far
    cur_offset: OffsetSize,
}
681
682impl<'a, OffsetSize> FilterBytes<'a, OffsetSize>
683where
684    OffsetSize: OffsetSizeTrait,
685{
686    fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self
687    where
688        T: ByteArrayType<Offset = OffsetSize>,
689    {
690        let dst_values = Vec::new();
691        let mut dst_offsets: Vec<OffsetSize> = Vec::with_capacity(capacity + 1);
692        let cur_offset = OffsetSize::from_usize(0).unwrap();
693
694        dst_offsets.push(cur_offset);
695
696        Self {
697            src_offsets: array.value_offsets(),
698            src_values: array.value_data(),
699            dst_offsets,
700            dst_values,
701            cur_offset,
702        }
703    }
704
705    /// Returns the byte offset at `idx`
706    #[inline]
707    fn get_value_offset(&self, idx: usize) -> usize {
708        self.src_offsets[idx].as_usize()
709    }
710
711    /// Returns the start and end of the value at index `idx` along with its length
712    #[inline]
713    fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
714        // These can only fail if `array` contains invalid data
715        let start = self.get_value_offset(idx);
716        let end = self.get_value_offset(idx + 1);
717        let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
718        (start, end, len)
719    }
720
721    fn extend_offsets_idx(&mut self, iter: impl Iterator<Item = usize>) {
722        self.dst_offsets.extend(iter.map(|idx| {
723            let start = self.src_offsets[idx].as_usize();
724            let end = self.src_offsets[idx + 1].as_usize();
725            let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
726            self.cur_offset += len;
727
728            self.cur_offset
729        }));
730    }
731
732    /// Extends the in-progress array by the indexes in the provided iterator
733    fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
734        self.dst_values.reserve_exact(self.cur_offset.as_usize());
735
736        for idx in iter {
737            let start = self.src_offsets[idx].as_usize();
738            let end = self.src_offsets[idx + 1].as_usize();
739            self.dst_values
740                .extend_from_slice(&self.src_values[start..end]);
741        }
742    }
743
744    fn extend_offsets_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>, count: usize) {
745        self.dst_offsets.reserve_exact(count);
746        for (start, end) in iter {
747            // These can only fail if `array` contains invalid data
748            for idx in start..end {
749                let (_, _, len) = self.get_value_range(idx);
750                self.cur_offset += len;
751                self.dst_offsets.push(self.cur_offset);
752            }
753        }
754    }
755
756    /// Extends the in-progress array by the ranges in the provided iterator
757    fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
758        self.dst_values.reserve_exact(self.cur_offset.as_usize());
759
760        for (start, end) in iter {
761            let value_start = self.get_value_offset(start);
762            let value_end = self.get_value_offset(end);
763            self.dst_values
764                .extend_from_slice(&self.src_values[value_start..value_end]);
765        }
766    }
767}
768
769/// `filter` implementation for byte arrays
770///
771/// Note: NULLs with a non-zero slot length in `array` will have the corresponding
772/// data copied across. This allows handling the null mask separately from the data
773fn filter_bytes<T>(array: &GenericByteArray<T>, predicate: &FilterPredicate) -> GenericByteArray<T>
774where
775    T: ByteArrayType,
776{
777    let mut filter = FilterBytes::new(predicate.count, array);
778
779    match &predicate.strategy {
780        IterationStrategy::SlicesIterator => {
781            filter.extend_offsets_slices(SlicesIterator::new(&predicate.filter), predicate.count);
782            filter.extend_slices(SlicesIterator::new(&predicate.filter))
783        }
784        IterationStrategy::Slices(slices) => {
785            filter.extend_offsets_slices(slices.iter().cloned(), predicate.count);
786            filter.extend_slices(slices.iter().cloned())
787        }
788        IterationStrategy::IndexIterator => {
789            filter.extend_offsets_idx(IndexIterator::new(&predicate.filter, predicate.count));
790            filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count))
791        }
792        IterationStrategy::Indices(indices) => {
793            filter.extend_offsets_idx(indices.iter().cloned());
794            filter.extend_idx(indices.iter().cloned())
795        }
796        IterationStrategy::All | IterationStrategy::None => unreachable!(),
797    }
798
799    let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
800        .len(predicate.count)
801        .add_buffer(filter.dst_offsets.into())
802        .add_buffer(filter.dst_values.into());
803
804    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
805        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
806    }
807
808    let data = unsafe { builder.build_unchecked() };
809    GenericByteArray::from(data)
810}
811
812/// `filter` implementation for byte view arrays.
813fn filter_byte_view<T: ByteViewType>(
814    array: &GenericByteViewArray<T>,
815    predicate: &FilterPredicate,
816) -> GenericByteViewArray<T> {
817    let new_view_buffer = filter_native(array.views(), predicate);
818
819    let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
820        .len(predicate.count)
821        .add_buffer(new_view_buffer)
822        .add_buffers(array.data_buffers().to_vec());
823
824    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
825        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
826    }
827
828    GenericByteViewArray::from(unsafe { builder.build_unchecked() })
829}
830
831fn filter_fixed_size_binary(
832    array: &FixedSizeBinaryArray,
833    predicate: &FilterPredicate,
834) -> FixedSizeBinaryArray {
835    let values: &[u8] = array.values();
836    let value_length = array.value_length() as usize;
837    let calculate_offset_from_index = |index: usize| index * value_length;
838    let buffer = match &predicate.strategy {
839        IterationStrategy::SlicesIterator => {
840            let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
841            for (start, end) in SlicesIterator::new(&predicate.filter) {
842                buffer.extend_from_slice(
843                    &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)],
844                );
845            }
846            buffer
847        }
848        IterationStrategy::Slices(slices) => {
849            let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
850            for (start, end) in slices {
851                buffer.extend_from_slice(
852                    &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)],
853                );
854            }
855            buffer
856        }
857        IterationStrategy::IndexIterator => {
858            let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| {
859                &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)]
860            });
861
862            let mut buffer = MutableBuffer::new(predicate.count * value_length);
863            iter.for_each(|item| buffer.extend_from_slice(item));
864            buffer
865        }
866        IterationStrategy::Indices(indices) => {
867            let iter = indices.iter().map(|x| {
868                &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)]
869            });
870
871            let mut buffer = MutableBuffer::new(predicate.count * value_length);
872            iter.for_each(|item| buffer.extend_from_slice(item));
873            buffer
874        }
875        IterationStrategy::All | IterationStrategy::None => unreachable!(),
876    };
877    let mut builder = ArrayDataBuilder::new(array.data_type().clone())
878        .len(predicate.count)
879        .add_buffer(buffer.into());
880
881    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
882        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
883    }
884
885    let data = unsafe { builder.build_unchecked() };
886    FixedSizeBinaryArray::from(data)
887}
888
889/// `filter` implementation for dictionaries
890fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
891where
892    T: ArrowDictionaryKeyType,
893    T::Native: num_traits::Num,
894{
895    let builder = filter_primitive::<T>(array.keys(), predicate)
896        .into_data()
897        .into_builder()
898        .data_type(array.data_type().clone())
899        .child_data(vec![array.values().to_data()]);
900
901    // SAFETY:
902    // Keys were valid before, filtered subset is therefore still valid
903    DictionaryArray::from(unsafe { builder.build_unchecked() })
904}
905
906/// `filter` implementation for structs
907fn filter_struct(
908    array: &StructArray,
909    predicate: &FilterPredicate,
910) -> Result<StructArray, ArrowError> {
911    let columns = array
912        .columns()
913        .iter()
914        .map(|column| filter_array(column, predicate))
915        .collect::<Result<_, _>>()?;
916
917    let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
918        let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
919
920        Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
921    } else {
922        None
923    };
924
925    Ok(unsafe {
926        StructArray::new_unchecked_with_length(
927            array.fields().clone(),
928            columns,
929            nulls,
930            predicate.count(),
931        )
932    })
933}
934
935/// `filter` implementation for sparse unions
936fn filter_sparse_union(
937    array: &UnionArray,
938    predicate: &FilterPredicate,
939) -> Result<UnionArray, ArrowError> {
940    let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
941        unreachable!()
942    };
943
944    let type_ids = filter_primitive(
945        &Int8Array::try_new(array.type_ids().clone(), None)?,
946        predicate,
947    );
948
949    let children = fields
950        .iter()
951        .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
952        .collect::<Result<_, _>>()?;
953
954    Ok(unsafe {
955        UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
956    })
957}
958
959/// `filter` implementation for list views
960fn filter_list_view<OffsetType: OffsetSizeTrait>(
961    array: &GenericListViewArray<OffsetType>,
962    predicate: &FilterPredicate,
963) -> GenericListViewArray<OffsetType> {
964    let filtered_offsets = filter_native::<OffsetType>(array.offsets(), predicate);
965    let filtered_sizes = filter_native::<OffsetType>(array.sizes(), predicate);
966
967    // Filter the nulls
968    let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
969        let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
970
971        Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
972    } else {
973        None
974    };
975
976    let list_data = ArrayDataBuilder::new(array.data_type().clone())
977        .nulls(nulls)
978        .buffers(vec![filtered_offsets, filtered_sizes])
979        .child_data(vec![array.values().to_data()])
980        .len(predicate.count);
981
982    let list_data = unsafe { list_data.build_unchecked() };
983
984    GenericListViewArray::from(list_data)
985}
986
987#[cfg(test)]
988mod tests {
989    use super::*;
990    use arrow_array::builder::*;
991    use arrow_array::cast::as_run_array;
992    use arrow_array::types::*;
993    use arrow_data::ArrayData;
994    use rand::distr::uniform::{UniformSampler, UniformUsize};
995    use rand::distr::{Alphanumeric, StandardUniform};
996    use rand::prelude::*;
997    use rand::rng;
998
999    macro_rules! def_temporal_test {
1000        ($test:ident, $array_type: ident, $data: expr) => {
1001            #[test]
1002            fn $test() {
1003                let a = $data;
1004                let b = BooleanArray::from(vec![true, false, true, false]);
1005                let c = filter(&a, &b).unwrap();
1006                let d = c.as_ref().as_any().downcast_ref::<$array_type>().unwrap();
1007                assert_eq!(2, d.len());
1008                assert_eq!(1, d.value(0));
1009                assert_eq!(3, d.value(1));
1010            }
1011        };
1012    }
1013
    // One generated filter test per temporal array type; every case filters an
    // array holding `[1, 2, 3, 4]`.

    // Date arrays
    def_temporal_test!(
        test_filter_date32,
        Date32Array,
        Date32Array::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_date64,
        Date64Array,
        Date64Array::from(vec![1, 2, 3, 4])
    );
    // Time arrays
    def_temporal_test!(
        test_filter_time32_second,
        Time32SecondArray,
        Time32SecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time32_millisecond,
        Time32MillisecondArray,
        Time32MillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time64_microsecond,
        Time64MicrosecondArray,
        Time64MicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time64_nanosecond,
        Time64NanosecondArray,
        Time64NanosecondArray::from(vec![1, 2, 3, 4])
    );
    // Duration arrays
    def_temporal_test!(
        test_filter_duration_second,
        DurationSecondArray,
        DurationSecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_millisecond,
        DurationMillisecondArray,
        DurationMillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_microsecond,
        DurationMicrosecondArray,
        DurationMicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_nanosecond,
        DurationNanosecondArray,
        DurationNanosecondArray::from(vec![1, 2, 3, 4])
    );
    // Timestamp arrays
    def_temporal_test!(
        test_filter_timestamp_second,
        TimestampSecondArray,
        TimestampSecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_millisecond,
        TimestampMillisecondArray,
        TimestampMillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_microsecond,
        TimestampMicrosecondArray,
        TimestampMicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_nanosecond,
        TimestampNanosecondArray,
        TimestampNanosecondArray::from(vec![1, 2, 3, 4])
    );
1084
1085    #[test]
1086    fn test_filter_array_slice() {
1087        let a = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
1088        let b = BooleanArray::from(vec![true, false, false, true]);
1089        // filtering with sliced filter array is not currently supported
1090        // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4);
1091        // let b = b_slice.as_any().downcast_ref().unwrap();
1092        let c = filter(&a, &b).unwrap();
1093        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1094        assert_eq!(2, d.len());
1095        assert_eq!(6, d.value(0));
1096        assert_eq!(9, d.value(1));
1097    }
1098
1099    #[test]
1100    fn test_filter_array_low_density() {
1101        // this test exercises the all 0's branch of the filter algorithm
1102        let mut data_values = (1..=65).collect::<Vec<i32>>();
1103        let mut filter_values = (1..=65).map(|i| matches!(i % 65, 0)).collect::<Vec<bool>>();
1104        // set up two more values after the batch
1105        data_values.extend_from_slice(&[66, 67]);
1106        filter_values.extend_from_slice(&[false, true]);
1107        let a = Int32Array::from(data_values);
1108        let b = BooleanArray::from(filter_values);
1109        let c = filter(&a, &b).unwrap();
1110        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1111        assert_eq!(2, d.len());
1112        assert_eq!(65, d.value(0));
1113        assert_eq!(67, d.value(1));
1114    }
1115
1116    #[test]
1117    fn test_filter_array_high_density() {
1118        // this test exercises the all 1's branch of the filter algorithm
1119        let mut data_values = (1..=65).map(Some).collect::<Vec<_>>();
1120        let mut filter_values = (1..=65)
1121            .map(|i| !matches!(i % 65, 0))
1122            .collect::<Vec<bool>>();
1123        // set second data value to null
1124        data_values[1] = None;
1125        // set up two more values after the batch
1126        data_values.extend_from_slice(&[Some(66), None, Some(67), None]);
1127        filter_values.extend_from_slice(&[false, true, true, true]);
1128        let a = Int32Array::from(data_values);
1129        let b = BooleanArray::from(filter_values);
1130        let c = filter(&a, &b).unwrap();
1131        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1132        assert_eq!(67, d.len());
1133        assert_eq!(3, d.null_count());
1134        assert_eq!(1, d.value(0));
1135        assert!(d.is_null(1));
1136        assert_eq!(64, d.value(63));
1137        assert!(d.is_null(64));
1138        assert_eq!(67, d.value(65));
1139    }
1140
1141    #[test]
1142    fn test_filter_string_array_simple() {
1143        let a = StringArray::from(vec!["hello", " ", "world", "!"]);
1144        let b = BooleanArray::from(vec![true, false, true, false]);
1145        let c = filter(&a, &b).unwrap();
1146        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1147        assert_eq!(2, d.len());
1148        assert_eq!("hello", d.value(0));
1149        assert_eq!("world", d.value(1));
1150    }
1151
1152    #[test]
1153    fn test_filter_primitive_array_with_null() {
1154        let a = Int32Array::from(vec![Some(5), None]);
1155        let b = BooleanArray::from(vec![false, true]);
1156        let c = filter(&a, &b).unwrap();
1157        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1158        assert_eq!(1, d.len());
1159        assert!(d.is_null(0));
1160    }
1161
1162    #[test]
1163    fn test_filter_string_array_with_null() {
1164        let a = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
1165        let b = BooleanArray::from(vec![true, false, false, true]);
1166        let c = filter(&a, &b).unwrap();
1167        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1168        assert_eq!(2, d.len());
1169        assert_eq!("hello", d.value(0));
1170        assert!(!d.is_null(0));
1171        assert!(d.is_null(1));
1172    }
1173
1174    #[test]
1175    fn test_filter_binary_array_with_null() {
1176        let data: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
1177        let a = BinaryArray::from(data);
1178        let b = BooleanArray::from(vec![true, false, false, true]);
1179        let c = filter(&a, &b).unwrap();
1180        let d = c.as_ref().as_any().downcast_ref::<BinaryArray>().unwrap();
1181        assert_eq!(2, d.len());
1182        assert_eq!(b"hello", d.value(0));
1183        assert!(!d.is_null(0));
1184        assert!(d.is_null(1));
1185    }
1186
1187    fn _test_filter_byte_view<T>()
1188    where
1189        T: ByteViewType,
1190        str: AsRef<T::Native>,
1191        T::Native: PartialEq,
1192    {
1193        let array = {
1194            // ["hello", "world", null, "large payload over 12 bytes", "lulu"]
1195            let mut builder = GenericByteViewBuilder::<T>::new();
1196            builder.append_value("hello");
1197            builder.append_value("world");
1198            builder.append_null();
1199            builder.append_value("large payload over 12 bytes");
1200            builder.append_value("lulu");
1201            builder.finish()
1202        };
1203
1204        {
1205            let predicate = BooleanArray::from(vec![true, false, true, true, false]);
1206            let actual = filter(&array, &predicate).unwrap();
1207
1208            assert_eq!(actual.len(), 3);
1209
1210            let expected = {
1211                // ["hello", null, "large payload over 12 bytes"]
1212                let mut builder = GenericByteViewBuilder::<T>::new();
1213                builder.append_value("hello");
1214                builder.append_null();
1215                builder.append_value("large payload over 12 bytes");
1216                builder.finish()
1217            };
1218
1219            assert_eq!(actual.as_ref(), &expected);
1220        }
1221
1222        {
1223            let predicate = BooleanArray::from(vec![true, false, false, false, true]);
1224            let actual = filter(&array, &predicate).unwrap();
1225
1226            assert_eq!(actual.len(), 2);
1227
1228            let expected = {
1229                // ["hello", "lulu"]
1230                let mut builder = GenericByteViewBuilder::<T>::new();
1231                builder.append_value("hello");
1232                builder.append_value("lulu");
1233                builder.finish()
1234            };
1235
1236            assert_eq!(actual.as_ref(), &expected);
1237        }
1238    }
1239
1240    #[test]
1241    fn test_filter_string_view() {
1242        _test_filter_byte_view::<StringViewType>()
1243    }
1244
1245    #[test]
1246    fn test_filter_binary_view() {
1247        _test_filter_byte_view::<BinaryViewType>()
1248    }
1249
1250    #[test]
1251    fn test_filter_fixed_binary() {
1252        let v1 = [1_u8, 2];
1253        let v2 = [3_u8, 4];
1254        let v3 = [5_u8, 6];
1255        let v = vec![&v1, &v2, &v3];
1256        let a = FixedSizeBinaryArray::from(v);
1257        let b = BooleanArray::from(vec![true, false, true]);
1258        let c = filter(&a, &b).unwrap();
1259        let d = c
1260            .as_ref()
1261            .as_any()
1262            .downcast_ref::<FixedSizeBinaryArray>()
1263            .unwrap();
1264        assert_eq!(d.len(), 2);
1265        assert_eq!(d.value(0), &v1);
1266        assert_eq!(d.value(1), &v3);
1267        let c2 = FilterBuilder::new(&b)
1268            .optimize()
1269            .build()
1270            .filter(&a)
1271            .unwrap();
1272        let d2 = c2
1273            .as_ref()
1274            .as_any()
1275            .downcast_ref::<FixedSizeBinaryArray>()
1276            .unwrap();
1277        assert_eq!(d, d2);
1278
1279        let b = BooleanArray::from(vec![false, false, false]);
1280        let c = filter(&a, &b).unwrap();
1281        let d = c
1282            .as_ref()
1283            .as_any()
1284            .downcast_ref::<FixedSizeBinaryArray>()
1285            .unwrap();
1286        assert_eq!(d.len(), 0);
1287
1288        let b = BooleanArray::from(vec![true, true, true]);
1289        let c = filter(&a, &b).unwrap();
1290        let d = c
1291            .as_ref()
1292            .as_any()
1293            .downcast_ref::<FixedSizeBinaryArray>()
1294            .unwrap();
1295        assert_eq!(d.len(), 3);
1296        assert_eq!(d.value(0), &v1);
1297        assert_eq!(d.value(1), &v2);
1298        assert_eq!(d.value(2), &v3);
1299
1300        let b = BooleanArray::from(vec![false, false, true]);
1301        let c = filter(&a, &b).unwrap();
1302        let d = c
1303            .as_ref()
1304            .as_any()
1305            .downcast_ref::<FixedSizeBinaryArray>()
1306            .unwrap();
1307        assert_eq!(d.len(), 1);
1308        assert_eq!(d.value(0), &v3);
1309        let c2 = FilterBuilder::new(&b)
1310            .optimize()
1311            .build()
1312            .filter(&a)
1313            .unwrap();
1314        let d2 = c2
1315            .as_ref()
1316            .as_any()
1317            .downcast_ref::<FixedSizeBinaryArray>()
1318            .unwrap();
1319        assert_eq!(d, d2);
1320    }
1321
1322    #[test]
1323    fn test_filter_array_slice_with_null() {
1324        let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
1325        let b = BooleanArray::from(vec![true, false, false, true]);
1326        // filtering with sliced filter array is not currently supported
1327        // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4);
1328        // let b = b_slice.as_any().downcast_ref().unwrap();
1329        let c = filter(&a, &b).unwrap();
1330        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1331        assert_eq!(2, d.len());
1332        assert!(d.is_null(0));
1333        assert!(!d.is_null(1));
1334        assert_eq!(9, d.value(1));
1335    }
1336
1337    #[test]
1338    fn test_filter_run_end_encoding_array() {
1339        let run_ends = Int64Array::from(vec![2, 3, 8]);
1340        let values = Int64Array::from(vec![7, -2, 9]);
1341        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1342        let b = BooleanArray::from(vec![true, false, true, false, true, false, true, false]);
1343        let c = filter(&a, &b).unwrap();
1344        let actual: &RunArray<Int64Type> = as_run_array(&c);
1345        assert_eq!(4, actual.len());
1346
1347        let expected = RunArray::try_new(
1348            &Int64Array::from(vec![1, 2, 4]),
1349            &Int64Array::from(vec![7, -2, 9]),
1350        )
1351        .expect("Failed to make expected RunArray test is broken");
1352
1353        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1354        assert_eq!(actual.values(), expected.values())
1355    }
1356
1357    #[test]
1358    fn test_filter_run_end_encoding_array_remove_value() {
1359        let run_ends = Int32Array::from(vec![2, 3, 8, 10]);
1360        let values = Int32Array::from(vec![7, -2, 9, -8]);
1361        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1362        let b = BooleanArray::from(vec![
1363            false, true, false, false, true, false, true, false, false, false,
1364        ]);
1365        let c = filter(&a, &b).unwrap();
1366        let actual: &RunArray<Int32Type> = as_run_array(&c);
1367        assert_eq!(3, actual.len());
1368
1369        let expected =
1370            RunArray::try_new(&Int32Array::from(vec![1, 3]), &Int32Array::from(vec![7, 9]))
1371                .expect("Failed to make expected RunArray test is broken");
1372
1373        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1374        assert_eq!(actual.values(), expected.values())
1375    }
1376
1377    #[test]
1378    fn test_filter_run_end_encoding_array_remove_all_but_one() {
1379        let run_ends = Int16Array::from(vec![2, 3, 8, 10]);
1380        let values = Int16Array::from(vec![7, -2, 9, -8]);
1381        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1382        let b = BooleanArray::from(vec![
1383            false, false, false, false, false, false, true, false, false, false,
1384        ]);
1385        let c = filter(&a, &b).unwrap();
1386        let actual: &RunArray<Int16Type> = as_run_array(&c);
1387        assert_eq!(1, actual.len());
1388
1389        let expected = RunArray::try_new(&Int16Array::from(vec![1]), &Int16Array::from(vec![9]))
1390            .expect("Failed to make expected RunArray test is broken");
1391
1392        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1393        assert_eq!(actual.values(), expected.values())
1394    }
1395
1396    #[test]
1397    fn test_filter_run_end_encoding_array_empty() {
1398        let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1399        let values = Int64Array::from(vec![7, -2, 9, -8]);
1400        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1401        let b = BooleanArray::from(vec![
1402            false, false, false, false, false, false, false, false, false, false,
1403        ]);
1404        let c = filter(&a, &b).unwrap();
1405        let actual: &RunArray<Int64Type> = as_run_array(&c);
1406        assert_eq!(0, actual.len());
1407    }
1408
1409    #[test]
1410    fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() {
1411        let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1412        let values = Int64Array::from(vec![7, -2, 9, -8]);
1413        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1414        let b = BooleanArray::from(vec![false, true, true]);
1415        let c = filter(&a, &b).unwrap();
1416        let actual: &RunArray<Int64Type> = as_run_array(&c);
1417        assert_eq!(2, actual.len());
1418
1419        let expected = RunArray::try_new(
1420            &Int64Array::from(vec![1, 2]),
1421            &Int64Array::from(vec![7, -2]),
1422        )
1423        .expect("Failed to make expected RunArray test is broken");
1424
1425        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1426        assert_eq!(actual.values(), expected.values())
1427    }
1428
1429    #[test]
1430    fn test_filter_dictionary_array() {
1431        let values = [Some("hello"), None, Some("world"), Some("!")];
1432        let a: Int8DictionaryArray = values.iter().copied().collect();
1433        let b = BooleanArray::from(vec![false, true, true, false]);
1434        let c = filter(&a, &b).unwrap();
1435        let d = c
1436            .as_ref()
1437            .as_any()
1438            .downcast_ref::<Int8DictionaryArray>()
1439            .unwrap();
1440        let value_array = d.values();
1441        let values = value_array.as_any().downcast_ref::<StringArray>().unwrap();
1442        // values are cloned in the filtered dictionary array
1443        assert_eq!(3, values.len());
1444        // but keys are filtered
1445        assert_eq!(2, d.len());
1446        assert!(d.is_null(0));
1447        assert_eq!("world", values.value(d.keys().value(1) as usize));
1448    }
1449
1450    #[test]
1451    fn test_filter_list_array() {
1452        let value_data = ArrayData::builder(DataType::Int32)
1453            .len(8)
1454            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
1455            .build()
1456            .unwrap();
1457
1458        let value_offsets = Buffer::from_slice_ref([0i64, 3, 6, 8, 8]);
1459
1460        let list_data_type =
1461            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
1462        let list_data = ArrayData::builder(list_data_type)
1463            .len(4)
1464            .add_buffer(value_offsets)
1465            .add_child_data(value_data)
1466            .null_bit_buffer(Some(Buffer::from([0b00000111])))
1467            .build()
1468            .unwrap();
1469
1470        //  a = [[0, 1, 2], [3, 4, 5], [6, 7], null]
1471        let a = LargeListArray::from(list_data);
1472        let b = BooleanArray::from(vec![false, true, false, true]);
1473        let result = filter(&a, &b).unwrap();
1474
1475        // expected: [[3, 4, 5], null]
1476        let value_data = ArrayData::builder(DataType::Int32)
1477            .len(3)
1478            .add_buffer(Buffer::from_slice_ref([3, 4, 5]))
1479            .build()
1480            .unwrap();
1481
1482        let value_offsets = Buffer::from_slice_ref([0i64, 3, 3]);
1483
1484        let list_data_type =
1485            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
1486        let expected = ArrayData::builder(list_data_type)
1487            .len(2)
1488            .add_buffer(value_offsets)
1489            .add_child_data(value_data)
1490            .null_bit_buffer(Some(Buffer::from([0b00000001])))
1491            .build()
1492            .unwrap();
1493
1494        assert_eq!(&make_array(expected), &result);
1495    }
1496
1497    fn test_case_filter_list_view<T: OffsetSizeTrait>() {
1498        // [[1, 2], null, [], [3,4]]
1499        let mut list_array = GenericListViewBuilder::<T, _>::new(Int32Builder::new());
1500        list_array.append_value([Some(1), Some(2)]);
1501        list_array.append_null();
1502        list_array.append_value([]);
1503        list_array.append_value([Some(3), Some(4)]);
1504
1505        let list_array = list_array.finish();
1506        let predicate = BooleanArray::from_iter([true, false, true, false]);
1507
1508        // Filter result: [[1, 2], []]
1509        let filtered = filter(&list_array, &predicate)
1510            .unwrap()
1511            .as_list_view::<T>()
1512            .clone();
1513
1514        let mut expected =
1515            GenericListViewBuilder::<T, _>::with_capacity(Int32Builder::with_capacity(5), 3);
1516        expected.append_value([Some(1), Some(2)]);
1517        expected.append_value([]);
1518        let expected = expected.finish();
1519
1520        assert_eq!(&filtered, &expected);
1521    }
1522
1523    fn test_case_filter_sliced_list_view<T: OffsetSizeTrait>() {
1524        // [[1, 2], null, [], [3,4]]
1525        let mut list_array =
1526            GenericListViewBuilder::<T, _>::with_capacity(Int32Builder::with_capacity(6), 4);
1527        list_array.append_value([Some(1), Some(2)]);
1528        list_array.append_null();
1529        list_array.append_value([]);
1530        list_array.append_value([Some(3), Some(4)]);
1531
1532        let list_array = list_array.finish();
1533
1534        // Sliced: [null, [], [3, 4]]
1535        let sliced = list_array.slice(1, 3);
1536        let predicate = BooleanArray::from_iter([false, false, true]);
1537
1538        // Filter result: [[1, 2], []]
1539        let filtered = filter(&sliced, &predicate)
1540            .unwrap()
1541            .as_list_view::<T>()
1542            .clone();
1543
1544        let mut expected = GenericListViewBuilder::<T, _>::new(Int32Builder::new());
1545        expected.append_value([Some(3), Some(4)]);
1546        let expected = expected.finish();
1547
1548        assert_eq!(&filtered, &expected);
1549    }
1550
1551    #[test]
1552    fn test_filter_list_view_array() {
1553        test_case_filter_list_view::<i32>();
1554        test_case_filter_list_view::<i64>();
1555
1556        test_case_filter_sliced_list_view::<i32>();
1557        test_case_filter_sliced_list_view::<i64>();
1558    }
1559
1560    #[test]
1561    fn test_slice_iterator_bits() {
1562        let filter_values = (0..64).map(|i| i == 1).collect::<Vec<bool>>();
1563        let filter = BooleanArray::from(filter_values);
1564        let filter_count = filter_count(&filter);
1565
1566        let iter = SlicesIterator::new(&filter);
1567        let chunks = iter.collect::<Vec<_>>();
1568
1569        assert_eq!(chunks, vec![(1, 2)]);
1570        assert_eq!(filter_count, 1);
1571    }
1572
1573    #[test]
1574    fn test_slice_iterator_bits1() {
1575        let filter_values = (0..64).map(|i| i != 1).collect::<Vec<bool>>();
1576        let filter = BooleanArray::from(filter_values);
1577        let filter_count = filter_count(&filter);
1578
1579        let iter = SlicesIterator::new(&filter);
1580        let chunks = iter.collect::<Vec<_>>();
1581
1582        assert_eq!(chunks, vec![(0, 1), (2, 64)]);
1583        assert_eq!(filter_count, 64 - 1);
1584    }
1585
1586    #[test]
1587    fn test_slice_iterator_chunk_and_bits() {
1588        let filter_values = (0..130).map(|i| i % 62 != 0).collect::<Vec<bool>>();
1589        let filter = BooleanArray::from(filter_values);
1590        let filter_count = filter_count(&filter);
1591
1592        let iter = SlicesIterator::new(&filter);
1593        let chunks = iter.collect::<Vec<_>>();
1594
1595        assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]);
1596        assert_eq!(filter_count, 61 + 61 + 5);
1597    }
1598
1599    #[test]
1600    fn test_null_mask() {
1601        let a = Int64Array::from(vec![Some(1), Some(2), None]);
1602
1603        let mask1 = BooleanArray::from(vec![Some(true), Some(true), None]);
1604        let out = filter(&a, &mask1).unwrap();
1605        assert_eq!(out.as_ref(), &a.slice(0, 2));
1606    }
1607
1608    #[test]
1609    fn test_filter_record_batch_no_columns() {
1610        let pred = BooleanArray::from(vec![Some(true), Some(true), None]);
1611        let options = RecordBatchOptions::default().with_row_count(Some(100));
1612        let record_batch =
1613            RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &options).unwrap();
1614        let out = filter_record_batch(&record_batch, &pred).unwrap();
1615
1616        assert_eq!(out.num_rows(), 2);
1617    }
1618
1619    #[test]
1620    fn test_fast_path() {
1621        let a: PrimitiveArray<Int64Type> = PrimitiveArray::from(vec![Some(1), Some(2), None]);
1622
1623        // all true
1624        let mask = BooleanArray::from(vec![true, true, true]);
1625        let out = filter(&a, &mask).unwrap();
1626        let b = out
1627            .as_any()
1628            .downcast_ref::<PrimitiveArray<Int64Type>>()
1629            .unwrap();
1630        assert_eq!(&a, b);
1631
1632        // all false
1633        let mask = BooleanArray::from(vec![false, false, false]);
1634        let out = filter(&a, &mask).unwrap();
1635        assert_eq!(out.len(), 0);
1636        assert_eq!(out.data_type(), &DataType::Int64);
1637    }
1638
1639    #[test]
1640    fn test_slices() {
1641        // takes up 2 u64s
1642        let bools = std::iter::repeat_n(true, 10)
1643            .chain(std::iter::repeat_n(false, 30))
1644            .chain(std::iter::repeat_n(true, 20))
1645            .chain(std::iter::repeat_n(false, 17))
1646            .chain(std::iter::repeat_n(true, 4));
1647
1648        let bool_array: BooleanArray = bools.map(Some).collect();
1649
1650        let slices: Vec<_> = SlicesIterator::new(&bool_array).collect();
1651        let expected = vec![(0, 10), (40, 60), (77, 81)];
1652        assert_eq!(slices, expected);
1653
1654        // slice with offset and truncated len
1655        let len = bool_array.len();
1656        let sliced_array = bool_array.slice(7, len - 10);
1657        let sliced_array = sliced_array
1658            .as_any()
1659            .downcast_ref::<BooleanArray>()
1660            .unwrap();
1661        let slices: Vec<_> = SlicesIterator::new(sliced_array).collect();
1662        let expected = vec![(0, 3), (33, 53), (70, 71)];
1663        assert_eq!(slices, expected);
1664    }
1665
1666    fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) {
1667        let mut rng = rng();
1668
1669        let bools: Vec<bool> = std::iter::from_fn(|| Some(rng.random()))
1670            .take(mask_len)
1671            .collect();
1672
1673        let buffer = Buffer::from_iter(bools.iter().cloned());
1674
1675        let truncated_length = mask_len - offset - truncate;
1676
1677        let data = ArrayDataBuilder::new(DataType::Boolean)
1678            .len(truncated_length)
1679            .offset(offset)
1680            .add_buffer(buffer)
1681            .build()
1682            .unwrap();
1683
1684        let filter = BooleanArray::from(data);
1685
1686        let slice_bits: Vec<_> = SlicesIterator::new(&filter)
1687            .flat_map(|(start, end)| start..end)
1688            .collect();
1689
1690        let count = filter_count(&filter);
1691        let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect();
1692
1693        let expected_bits: Vec<_> = bools
1694            .iter()
1695            .skip(offset)
1696            .take(truncated_length)
1697            .enumerate()
1698            .flat_map(|(idx, v)| v.then(|| idx))
1699            .collect();
1700
1701        assert_eq!(slice_bits, expected_bits);
1702        assert_eq!(index_bits, expected_bits);
1703    }
1704
1705    #[test]
1706    #[cfg_attr(miri, ignore)]
1707    fn fuzz_test_slices_iterator() {
1708        let mut rng = rng();
1709
1710        let uusize = UniformUsize::new(usize::MIN, usize::MAX).unwrap();
1711        for _ in 0..100 {
1712            let mask_len = rng.random_range(0..1024);
1713            let max_offset = 64.min(mask_len);
1714            let offset = uusize.sample(&mut rng).checked_rem(max_offset).unwrap_or(0);
1715
1716            let max_truncate = 128.min(mask_len - offset);
1717            let truncate = uusize
1718                .sample(&mut rng)
1719                .checked_rem(max_truncate)
1720                .unwrap_or(0);
1721
1722            test_slices_fuzz(mask_len, offset, truncate);
1723        }
1724
1725        test_slices_fuzz(64, 0, 0);
1726        test_slices_fuzz(64, 8, 0);
1727        test_slices_fuzz(64, 8, 8);
1728        test_slices_fuzz(32, 8, 8);
1729        test_slices_fuzz(32, 5, 9);
1730    }
1731
1732    /// Filters `values` by `predicate` using standard rust iterators
1733    fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> {
1734        values
1735            .into_iter()
1736            .zip(predicate)
1737            .filter(|(_, x)| **x)
1738            .map(|(a, _)| a)
1739            .collect()
1740    }
1741
1742    /// Generates an array of length `len` with `valid_percent` non-null values
1743    fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>>
1744    where
1745        StandardUniform: Distribution<T>,
1746    {
1747        let mut rng = rng();
1748        (0..len)
1749            .map(|_| rng.random_bool(valid_percent).then(|| rng.random()))
1750            .collect()
1751    }
1752
1753    /// Generates an array of length `len` with `valid_percent` non-null values
1754    fn gen_strings(
1755        len: usize,
1756        valid_percent: f64,
1757        str_len_range: std::ops::Range<usize>,
1758    ) -> Vec<Option<String>> {
1759        let mut rng = rng();
1760        (0..len)
1761            .map(|_| {
1762                rng.random_bool(valid_percent).then(|| {
1763                    let len = rng.random_range(str_len_range.clone());
1764                    (0..len)
1765                        .map(|_| char::from(rng.sample(Alphanumeric)))
1766                        .collect()
1767                })
1768            })
1769            .collect()
1770    }
1771
1772    /// Returns an iterator that calls `Option::as_deref` on each item
1773    fn as_deref<T: std::ops::Deref>(src: &[Option<T>]) -> impl Iterator<Item = Option<&T::Target>> {
1774        src.iter().map(|x| x.as_deref())
1775    }
1776
1777    #[test]
1778    #[cfg_attr(miri, ignore)]
1779    fn fuzz_filter() {
1780        let mut rng = rng();
1781
1782        for i in 0..100 {
1783            let filter_percent = match i {
1784                0..=4 => 1.,
1785                5..=10 => 0.,
1786                _ => rng.random_range(0.0..1.0),
1787            };
1788
1789            let valid_percent = rng.random_range(0.0..1.0);
1790
1791            let array_len = rng.random_range(32..256);
1792            let array_offset = rng.random_range(0..10);
1793
1794            // Construct a predicate
1795            let filter_offset = rng.random_range(0..10);
1796            let filter_truncate = rng.random_range(0..10);
1797            let bools: Vec<_> = std::iter::from_fn(|| Some(rng.random_bool(filter_percent)))
1798                .take(array_len + filter_offset - filter_truncate)
1799                .collect();
1800
1801            let predicate = BooleanArray::from_iter(bools.iter().cloned().map(Some));
1802
1803            // Offset predicate
1804            let predicate = predicate.slice(filter_offset, array_len - filter_truncate);
1805            let predicate = predicate.as_any().downcast_ref::<BooleanArray>().unwrap();
1806            let bools = &bools[filter_offset..];
1807
1808            // Test i32
1809            let values = gen_primitive(array_len + array_offset, valid_percent);
1810            let src = Int32Array::from_iter(values.iter().cloned());
1811
1812            let src = src.slice(array_offset, array_len);
1813            let src = src.as_any().downcast_ref::<Int32Array>().unwrap();
1814            let values = &values[array_offset..];
1815
1816            let filtered = filter(src, predicate).unwrap();
1817            let array = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
1818            let actual: Vec<_> = array.iter().collect();
1819
1820            assert_eq!(actual, filter_rust(values.iter().cloned(), bools));
1821
1822            // Test string
1823            let strings = gen_strings(array_len + array_offset, valid_percent, 0..20);
1824            let src = StringArray::from_iter(as_deref(&strings));
1825
1826            let src = src.slice(array_offset, array_len);
1827            let src = src.as_any().downcast_ref::<StringArray>().unwrap();
1828
1829            let filtered = filter(src, predicate).unwrap();
1830            let array = filtered.as_any().downcast_ref::<StringArray>().unwrap();
1831            let actual: Vec<_> = array.iter().collect();
1832
1833            let expected_strings = filter_rust(as_deref(&strings[array_offset..]), bools);
1834            assert_eq!(actual, expected_strings);
1835
1836            // Test string dictionary
1837            let src = DictionaryArray::<Int32Type>::from_iter(as_deref(&strings));
1838
1839            let src = src.slice(array_offset, array_len);
1840            let src = src
1841                .as_any()
1842                .downcast_ref::<DictionaryArray<Int32Type>>()
1843                .unwrap();
1844
1845            let filtered = filter(src, predicate).unwrap();
1846
1847            let array = filtered
1848                .as_any()
1849                .downcast_ref::<DictionaryArray<Int32Type>>()
1850                .unwrap();
1851
1852            let values = array
1853                .values()
1854                .as_any()
1855                .downcast_ref::<StringArray>()
1856                .unwrap();
1857
1858            let actual: Vec<_> = array
1859                .keys()
1860                .iter()
1861                .map(|key| key.map(|key| values.value(key as usize)))
1862                .collect();
1863
1864            assert_eq!(actual, expected_strings);
1865        }
1866    }
1867
1868    #[test]
1869    fn test_filter_map() {
1870        let mut builder =
1871            MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(4));
1872        // [{"key1": 1}, {"key2": 2, "key3": 3}, null, {"key1": 1}
1873        builder.keys().append_value("key1");
1874        builder.values().append_value(1);
1875        builder.append(true).unwrap();
1876        builder.keys().append_value("key2");
1877        builder.keys().append_value("key3");
1878        builder.values().append_value(2);
1879        builder.values().append_value(3);
1880        builder.append(true).unwrap();
1881        builder.append(false).unwrap();
1882        builder.keys().append_value("key1");
1883        builder.values().append_value(1);
1884        builder.append(true).unwrap();
1885        let maparray = Arc::new(builder.finish()) as ArrayRef;
1886
1887        let indices = vec![Some(true), Some(false), Some(false), Some(true)]
1888            .into_iter()
1889            .collect::<BooleanArray>();
1890        let got = filter(&maparray, &indices).unwrap();
1891
1892        let mut builder =
1893            MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(2));
1894        builder.keys().append_value("key1");
1895        builder.values().append_value(1);
1896        builder.append(true).unwrap();
1897        builder.keys().append_value("key1");
1898        builder.values().append_value(1);
1899        builder.append(true).unwrap();
1900        let expected = Arc::new(builder.finish()) as ArrayRef;
1901
1902        assert_eq!(&expected, &got);
1903    }
1904
1905    #[test]
1906    fn test_filter_fixed_size_list_arrays() {
1907        let value_data = ArrayData::builder(DataType::Int32)
1908            .len(9)
1909            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8]))
1910            .build()
1911            .unwrap();
1912        let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 3, false);
1913        let list_data = ArrayData::builder(list_data_type)
1914            .len(3)
1915            .add_child_data(value_data)
1916            .build()
1917            .unwrap();
1918        let array = FixedSizeListArray::from(list_data);
1919
1920        let filter_array = BooleanArray::from(vec![true, false, false]);
1921
1922        let c = filter(&array, &filter_array).unwrap();
1923        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1924
1925        assert_eq!(filtered.len(), 1);
1926
1927        let list = filtered.value(0);
1928        assert_eq!(
1929            &[0, 1, 2],
1930            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1931        );
1932
1933        let filter_array = BooleanArray::from(vec![true, false, true]);
1934
1935        let c = filter(&array, &filter_array).unwrap();
1936        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1937
1938        assert_eq!(filtered.len(), 2);
1939
1940        let list = filtered.value(0);
1941        assert_eq!(
1942            &[0, 1, 2],
1943            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1944        );
1945        let list = filtered.value(1);
1946        assert_eq!(
1947            &[6, 7, 8],
1948            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1949        );
1950    }
1951
1952    #[test]
1953    fn test_filter_fixed_size_list_arrays_with_null() {
1954        let value_data = ArrayData::builder(DataType::Int32)
1955            .len(10)
1956            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
1957            .build()
1958            .unwrap();
1959
1960        // Set null buts for the nested array:
1961        //  [[0, 1], null, null, [6, 7], [8, 9]]
1962        // 01011001 00000001
1963        let mut null_bits: [u8; 1] = [0; 1];
1964        bit_util::set_bit(&mut null_bits, 0);
1965        bit_util::set_bit(&mut null_bits, 3);
1966        bit_util::set_bit(&mut null_bits, 4);
1967
1968        let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 2, false);
1969        let list_data = ArrayData::builder(list_data_type)
1970            .len(5)
1971            .add_child_data(value_data)
1972            .null_bit_buffer(Some(Buffer::from(null_bits)))
1973            .build()
1974            .unwrap();
1975        let array = FixedSizeListArray::from(list_data);
1976
1977        let filter_array = BooleanArray::from(vec![true, true, false, true, false]);
1978
1979        let c = filter(&array, &filter_array).unwrap();
1980        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1981
1982        assert_eq!(filtered.len(), 3);
1983
1984        let list = filtered.value(0);
1985        assert_eq!(
1986            &[0, 1],
1987            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1988        );
1989        assert!(filtered.is_null(1));
1990        let list = filtered.value(2);
1991        assert_eq!(
1992            &[6, 7],
1993            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1994        );
1995    }
1996
1997    fn test_filter_union_array(array: UnionArray) {
1998        let filter_array = BooleanArray::from(vec![true, false, false]);
1999        let c = filter(&array, &filter_array).unwrap();
2000        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2001
2002        let mut builder = UnionBuilder::new_dense();
2003        builder.append::<Int32Type>("A", 1).unwrap();
2004        let expected_array = builder.build().unwrap();
2005
2006        compare_union_arrays(filtered, &expected_array);
2007
2008        let filter_array = BooleanArray::from(vec![true, false, true]);
2009        let c = filter(&array, &filter_array).unwrap();
2010        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2011
2012        let mut builder = UnionBuilder::new_dense();
2013        builder.append::<Int32Type>("A", 1).unwrap();
2014        builder.append::<Int32Type>("A", 34).unwrap();
2015        let expected_array = builder.build().unwrap();
2016
2017        compare_union_arrays(filtered, &expected_array);
2018
2019        let filter_array = BooleanArray::from(vec![true, true, false]);
2020        let c = filter(&array, &filter_array).unwrap();
2021        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2022
2023        let mut builder = UnionBuilder::new_dense();
2024        builder.append::<Int32Type>("A", 1).unwrap();
2025        builder.append::<Float64Type>("B", 3.2).unwrap();
2026        let expected_array = builder.build().unwrap();
2027
2028        compare_union_arrays(filtered, &expected_array);
2029    }
2030
2031    #[test]
2032    fn test_filter_union_array_dense() {
2033        let mut builder = UnionBuilder::new_dense();
2034        builder.append::<Int32Type>("A", 1).unwrap();
2035        builder.append::<Float64Type>("B", 3.2).unwrap();
2036        builder.append::<Int32Type>("A", 34).unwrap();
2037        let array = builder.build().unwrap();
2038
2039        test_filter_union_array(array);
2040    }
2041
2042    #[test]
2043    fn test_filter_run_union_array_dense() {
2044        let mut builder = UnionBuilder::new_dense();
2045        builder.append::<Int32Type>("A", 1).unwrap();
2046        builder.append::<Int32Type>("A", 3).unwrap();
2047        builder.append::<Int32Type>("A", 34).unwrap();
2048        let array = builder.build().unwrap();
2049
2050        let filter_array = BooleanArray::from(vec![true, true, false]);
2051        let c = filter(&array, &filter_array).unwrap();
2052        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2053
2054        let mut builder = UnionBuilder::new_dense();
2055        builder.append::<Int32Type>("A", 1).unwrap();
2056        builder.append::<Int32Type>("A", 3).unwrap();
2057        let expected = builder.build().unwrap();
2058
2059        assert_eq!(filtered.to_data(), expected.to_data());
2060    }
2061
2062    #[test]
2063    fn test_filter_union_array_dense_with_nulls() {
2064        let mut builder = UnionBuilder::new_dense();
2065        builder.append::<Int32Type>("A", 1).unwrap();
2066        builder.append::<Float64Type>("B", 3.2).unwrap();
2067        builder.append_null::<Float64Type>("B").unwrap();
2068        builder.append::<Int32Type>("A", 34).unwrap();
2069        let array = builder.build().unwrap();
2070
2071        let filter_array = BooleanArray::from(vec![true, true, false, false]);
2072        let c = filter(&array, &filter_array).unwrap();
2073        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2074
2075        let mut builder = UnionBuilder::new_dense();
2076        builder.append::<Int32Type>("A", 1).unwrap();
2077        builder.append::<Float64Type>("B", 3.2).unwrap();
2078        let expected_array = builder.build().unwrap();
2079
2080        compare_union_arrays(filtered, &expected_array);
2081
2082        let filter_array = BooleanArray::from(vec![true, false, true, false]);
2083        let c = filter(&array, &filter_array).unwrap();
2084        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2085
2086        let mut builder = UnionBuilder::new_dense();
2087        builder.append::<Int32Type>("A", 1).unwrap();
2088        builder.append_null::<Float64Type>("B").unwrap();
2089        let expected_array = builder.build().unwrap();
2090
2091        compare_union_arrays(filtered, &expected_array);
2092    }
2093
2094    #[test]
2095    fn test_filter_union_array_sparse() {
2096        let mut builder = UnionBuilder::new_sparse();
2097        builder.append::<Int32Type>("A", 1).unwrap();
2098        builder.append::<Float64Type>("B", 3.2).unwrap();
2099        builder.append::<Int32Type>("A", 34).unwrap();
2100        let array = builder.build().unwrap();
2101
2102        test_filter_union_array(array);
2103    }
2104
2105    #[test]
2106    fn test_filter_union_array_sparse_with_nulls() {
2107        let mut builder = UnionBuilder::new_sparse();
2108        builder.append::<Int32Type>("A", 1).unwrap();
2109        builder.append::<Float64Type>("B", 3.2).unwrap();
2110        builder.append_null::<Float64Type>("B").unwrap();
2111        builder.append::<Int32Type>("A", 34).unwrap();
2112        let array = builder.build().unwrap();
2113
2114        let filter_array = BooleanArray::from(vec![true, false, true, false]);
2115        let c = filter(&array, &filter_array).unwrap();
2116        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2117
2118        let mut builder = UnionBuilder::new_sparse();
2119        builder.append::<Int32Type>("A", 1).unwrap();
2120        builder.append_null::<Float64Type>("B").unwrap();
2121        let expected_array = builder.build().unwrap();
2122
2123        compare_union_arrays(filtered, &expected_array);
2124    }
2125
2126    fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) {
2127        assert_eq!(union1.len(), union2.len());
2128
2129        for i in 0..union1.len() {
2130            let type_id = union1.type_id(i);
2131
2132            let slot1 = union1.value(i);
2133            let slot2 = union2.value(i);
2134
2135            assert_eq!(slot1.is_null(0), slot2.is_null(0));
2136
2137            if !slot1.is_null(0) && !slot2.is_null(0) {
2138                match type_id {
2139                    0 => {
2140                        let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
2141                        assert_eq!(slot1.len(), 1);
2142                        let value1 = slot1.value(0);
2143
2144                        let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap();
2145                        assert_eq!(slot2.len(), 1);
2146                        let value2 = slot2.value(0);
2147                        assert_eq!(value1, value2);
2148                    }
2149                    1 => {
2150                        let slot1 = slot1.as_any().downcast_ref::<Float64Array>().unwrap();
2151                        assert_eq!(slot1.len(), 1);
2152                        let value1 = slot1.value(0);
2153
2154                        let slot2 = slot2.as_any().downcast_ref::<Float64Array>().unwrap();
2155                        assert_eq!(slot2.len(), 1);
2156                        let value2 = slot2.value(0);
2157                        assert_eq!(value1, value2);
2158                    }
2159                    _ => unreachable!(),
2160                }
2161            }
2162        }
2163    }
2164
2165    #[test]
2166    fn test_filter_struct() {
2167        let predicate = BooleanArray::from(vec![true, false, true, false]);
2168
2169        let a = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
2170        let a_filtered = Arc::new(StringArray::from(vec!["hello", "world"]));
2171
2172        let b = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
2173        let b_filtered = Arc::new(Int32Array::from(vec![5, 7]));
2174
2175        let null_mask = NullBuffer::from(vec![true, false, false, true]);
2176        let null_mask_filtered = NullBuffer::from(vec![true, false]);
2177
2178        let a_field = Field::new("a", DataType::Utf8, false);
2179        let b_field = Field::new("b", DataType::Int32, false);
2180
2181        let array = StructArray::new(vec![a_field.clone()].into(), vec![a.clone()], None);
2182        let expected =
2183            StructArray::new(vec![a_field.clone()].into(), vec![a_filtered.clone()], None);
2184
2185        let result = filter(&array, &predicate).unwrap();
2186
2187        assert_eq!(result.to_data(), expected.to_data());
2188
2189        let array = StructArray::new(
2190            vec![a_field.clone()].into(),
2191            vec![a.clone()],
2192            Some(null_mask.clone()),
2193        );
2194        let expected = StructArray::new(
2195            vec![a_field.clone()].into(),
2196            vec![a_filtered.clone()],
2197            Some(null_mask_filtered.clone()),
2198        );
2199
2200        let result = filter(&array, &predicate).unwrap();
2201
2202        assert_eq!(result.to_data(), expected.to_data());
2203
2204        let array = StructArray::new(
2205            vec![a_field.clone(), b_field.clone()].into(),
2206            vec![a.clone(), b.clone()],
2207            None,
2208        );
2209        let expected = StructArray::new(
2210            vec![a_field.clone(), b_field.clone()].into(),
2211            vec![a_filtered.clone(), b_filtered.clone()],
2212            None,
2213        );
2214
2215        let result = filter(&array, &predicate).unwrap();
2216
2217        assert_eq!(result.to_data(), expected.to_data());
2218
2219        let array = StructArray::new(
2220            vec![a_field.clone(), b_field.clone()].into(),
2221            vec![a.clone(), b.clone()],
2222            Some(null_mask.clone()),
2223        );
2224
2225        let expected = StructArray::new(
2226            vec![a_field.clone(), b_field.clone()].into(),
2227            vec![a_filtered.clone(), b_filtered.clone()],
2228            Some(null_mask_filtered.clone()),
2229        );
2230
2231        let result = filter(&array, &predicate).unwrap();
2232
2233        assert_eq!(result.to_data(), expected.to_data());
2234    }
2235
2236    #[test]
2237    fn test_filter_empty_struct() {
2238        /*
2239            "a": {
2240                "b": int64,
2241                "c": {}
2242            },
2243        */
2244        let fields = arrow_schema::Field::new(
2245            "a",
2246            arrow_schema::DataType::Struct(arrow_schema::Fields::from(vec![
2247                arrow_schema::Field::new("b", arrow_schema::DataType::Int64, true),
2248                arrow_schema::Field::new(
2249                    "c",
2250                    arrow_schema::DataType::Struct(arrow_schema::Fields::empty()),
2251                    true,
2252                ),
2253            ])),
2254            true,
2255        );
2256
2257        /* Test record
2258            {"a":{"c": {}}}
2259            {"a":{"c": {}}}
2260            {"a":{"c": {}}}
2261        */
2262
2263        // Create the record batch with the nested struct array
2264        let schema = Arc::new(Schema::new(vec![fields]));
2265
2266        let b = Arc::new(Int64Array::from(vec![None, None, None]));
2267        let c = Arc::new(StructArray::new_empty_fields(
2268            3,
2269            Some(NullBuffer::from(vec![true, true, true])),
2270        ));
2271        let a = StructArray::new(
2272            vec![
2273                Field::new("b", DataType::Int64, true),
2274                Field::new("c", DataType::Struct(Fields::empty()), true),
2275            ]
2276            .into(),
2277            vec![b.clone(), c.clone()],
2278            Some(NullBuffer::from(vec![true, true, true])),
2279        );
2280        let record_batch = RecordBatch::try_new(schema, vec![Arc::new(a)]).unwrap();
2281        println!("{record_batch:?}");
2282
2283        // Apply the filter
2284        let predicate = BooleanArray::from(vec![true, false, true]);
2285        let filtered_batch = filter_record_batch(&record_batch, &predicate).unwrap();
2286
2287        // The filtered batch should have 2 rows (the 1st and 3rd)
2288        assert_eq!(filtered_batch.num_rows(), 2);
2289    }
2290
2291    #[test]
2292    #[should_panic]
2293    fn test_filter_bits_too_large() {
2294        let buffer = BooleanBuffer::from(vec![false; 8]);
2295        let predicate = BooleanArray::from(vec![true; 9]);
2296        let filter = FilterBuilder::new(&predicate).build();
2297        filter_bits(&buffer, &filter);
2298    }
2299
2300    #[test]
2301    #[should_panic]
2302    fn test_filter_native_too_large() {
2303        let values = vec![1; 8];
2304        let predicate = BooleanArray::from(vec![false; 9]);
2305        let filter = FilterBuilder::new(&predicate).build();
2306        filter_native(&values, &filter);
2307    }
2308}