Skip to main content

arrow_select/
filter.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines filter kernels
19
20use std::ops::AddAssign;
21use std::sync::Arc;
22
23use arrow_array::builder::BooleanBufferBuilder;
24use arrow_array::cast::AsArray;
25use arrow_array::types::{
26    ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
27};
28use arrow_array::*;
29use arrow_buffer::{ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer, bit_util};
30use arrow_buffer::{Buffer, MutableBuffer};
31use arrow_data::ArrayDataBuilder;
32use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
33use arrow_data::transform::MutableArrayData;
34use arrow_schema::*;
35
/// Selectivity cut-off between the two lazy iteration strategies.
///
/// A filter whose fraction of selected rows is strictly greater than this
/// value copies ranges of values using [`SlicesIterator`]; any other filter
/// visits the selected rows one at a time using [`IndexIterator`].
///
/// Threshold of 0.8 chosen based on <https://dl.acm.org/doi/abs/10.1145/3465998.3466009>
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
43
44/// An iterator of `(usize, usize)` each representing an interval
45/// `[start, end)` whose slots of a bitmap [Buffer] are true.
46///
47/// Each interval corresponds to a contiguous region of memory to be
48/// "taken" from an array to be filtered.
49///
50/// ## Notes:
51///
52/// 1. Ignores the validity bitmap (ignores nulls)
53///
54/// 2. Only performant for filters that copy across long contiguous runs
55#[derive(Debug)]
56pub struct SlicesIterator<'a>(BitSliceIterator<'a>);
57
58impl<'a> SlicesIterator<'a> {
59    /// Creates a new iterator from a [BooleanArray]
60    pub fn new(filter: &'a BooleanArray) -> Self {
61        filter.values().into()
62    }
63}
64
65impl<'a> From<&'a BooleanBuffer> for SlicesIterator<'a> {
66    fn from(filter: &'a BooleanBuffer) -> Self {
67        Self(filter.set_slices())
68    }
69}
70
71impl Iterator for SlicesIterator<'_> {
72    type Item = (usize, usize);
73
74    fn next(&mut self) -> Option<Self::Item> {
75        self.0.next()
76    }
77}
78
79/// An iterator of `usize` whose index in [`BooleanArray`] is true
80///
81/// This provides the best performance on most predicates, apart from those which keep
82/// large runs and therefore favour [`SlicesIterator`]
83struct IndexIterator<'a> {
84    remaining: usize,
85    iter: BitIndexIterator<'a>,
86}
87
88impl<'a> IndexIterator<'a> {
89    fn new(filter: &'a BooleanArray, remaining: usize) -> Self {
90        assert_eq!(filter.null_count(), 0);
91        let iter = filter.values().set_indices();
92        Self { remaining, iter }
93    }
94
95    /// Collect this iterator as a [`Vec`]
96    /// This is more efficient than the standard `collect` as we can
97    /// pre-allocate the entire uninitialized buffer and then fill it (roughly 1.6x faster)
98    pub fn collect(mut self) -> Vec<usize> {
99        let len = self.remaining;
100        let mut result = Vec::with_capacity(len);
101        let ptr: *mut usize = result.as_mut_ptr();
102        for i in 0..len {
103            // SAFETY: we have allocated enough space in `result` and remaining
104            // correctly tracks the number of elements
105            let next = self.iter.next();
106            debug_assert!(next.is_some(), "IndexIterator exhausted early");
107            unsafe {
108                *ptr.add(i) = next.unwrap_unchecked();
109            }
110        }
111        // SAFETY: we have initialized `len` elements
112        unsafe {
113            result.set_len(len);
114        }
115        result
116    }
117}
118
119impl Iterator for IndexIterator<'_> {
120    type Item = usize;
121
122    fn next(&mut self) -> Option<Self::Item> {
123        if self.remaining != 0 {
124            // Fascinatingly swapping these two lines around results in a 50%
125            // performance regression for some benchmarks
126            let next = self.iter.next().expect("IndexIterator exhausted early");
127            self.remaining -= 1;
128            // Must panic if exhausted early as trusted length iterator
129            return Some(next);
130        }
131        None
132    }
133
134    fn size_hint(&self) -> (usize, Option<usize>) {
135        (self.remaining, Some(self.remaining))
136    }
137}
138
139/// Counts the number of set bits in `filter`
140fn filter_count(filter: &BooleanArray) -> usize {
141    filter.values().count_set_bits()
142}
143
144/// Convert all null values in `BooleanArray` to `false`
145///
146/// This is useful for filter-like operations which select only `true`
147/// values, but not `false` or `NULL` values
148///
149/// Internally this is implemented as a bitwise `AND` operation with null bits
150/// and the boolean bits.
151///
152/// # Example
153/// ```
154/// # use arrow_array::{Array, BooleanArray};
155/// # use arrow_select::filter::prep_null_mask_filter;
156/// let filter = BooleanArray::from(vec![
157///   Some(true),
158///   Some(false),
159///   None
160/// ]);
161/// // convert Boolean array to a filter mask
162/// let null_mask = prep_null_mask_filter(&filter);
163/// // there are no nulls in the output mask
164/// assert!(null_mask.nulls().is_none());
165/// assert_eq!(null_mask, BooleanArray::from(vec![
166///  true,
167///  false,
168///  false, // Null is converted to false
169/// ]));
170/// ```
171pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
172    let nulls = filter.nulls().unwrap();
173    let mask = filter.values() & nulls.inner();
174    BooleanArray::new(mask, None)
175}
176
177/// Returns a filtered `values` [`Array`] where the corresponding elements of
178/// `predicate` are `true`.
179///
180/// If multiple arrays (or record batches) need to be filtered using the same predicate array,
181/// consider using [FilterBuilder] to create a single [FilterPredicate] and then
182/// calling [FilterPredicate::filter_record_batch].
183///
184/// In contrast to this function, it is then the responsibility of the caller
185/// to use [FilterBuilder::optimize] if appropriate.
186///
187/// # See also
188/// * [`FilterBuilder`] for more control over the filtering process.
189/// * [`filter_record_batch`] to filter a [`RecordBatch`]
190/// * [`BatchCoalescer`]: to filter multiple [`RecordBatch`] and coalesce
191///   the results into a single array.
192///
193/// [`BatchCoalescer`]: crate::coalesce::BatchCoalescer
194///
195/// # Example
196/// ```rust
197/// # use arrow_array::{Int32Array, BooleanArray};
198/// # use arrow_select::filter::filter;
199/// let array = Int32Array::from(vec![5, 6, 7, 8, 9]);
200/// let filter_array = BooleanArray::from(vec![true, false, false, true, false]);
201/// let c = filter(&array, &filter_array).unwrap();
202/// let c = c.as_any().downcast_ref::<Int32Array>().unwrap();
203/// assert_eq!(c, &Int32Array::from(vec![5, 8]));
204/// ```
205pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
206    let mut filter_builder = FilterBuilder::new(predicate);
207
208    if FilterBuilder::is_optimize_beneficial(values.data_type()) {
209        // Only optimize if filtering more than one array
210        // Otherwise, the overhead of optimization can be more than the benefit
211        filter_builder = filter_builder.optimize();
212    }
213
214    let predicate = filter_builder.build();
215
216    filter_array(values, &predicate)
217}
218
219/// Returns a filtered [RecordBatch] where the corresponding elements of
220/// `predicate` are true.
221///
222/// This is the equivalent of calling [filter] on each column of the [RecordBatch].
223///
224/// If multiple record batches (or arrays) need to be filtered using the same predicate array,
225/// consider using [FilterBuilder] to create a single [FilterPredicate] and then
226/// calling [FilterPredicate::filter_record_batch].
227/// In contrast to this function, it is then the responsibility of the caller
228/// to use [FilterBuilder::optimize] if appropriate.
229pub fn filter_record_batch(
230    record_batch: &RecordBatch,
231    predicate: &BooleanArray,
232) -> Result<RecordBatch, ArrowError> {
233    let mut filter_builder = FilterBuilder::new(predicate);
234    let num_cols = record_batch.num_columns();
235    if num_cols > 1
236        || (num_cols > 0
237            && FilterBuilder::is_optimize_beneficial(
238                record_batch.schema_ref().field(0).data_type(),
239            ))
240    {
241        // Only optimize if filtering more than one column or if the column contains multiple internal arrays
242        // Otherwise, the overhead of optimization can be more than the benefit
243        filter_builder = filter_builder.optimize();
244    }
245    let filter = filter_builder.build();
246
247    filter.filter_record_batch(record_batch)
248}
249
250/// A builder to construct [`FilterPredicate`]
251#[derive(Debug)]
252pub struct FilterBuilder {
253    filter: BooleanArray,
254    count: usize,
255    strategy: IterationStrategy,
256}
257
258impl FilterBuilder {
259    /// Create a new [`FilterBuilder`] that can be used to construct a [`FilterPredicate`]
260    pub fn new(filter: &BooleanArray) -> Self {
261        let filter = match filter.null_count() {
262            0 => filter.clone(),
263            _ => prep_null_mask_filter(filter),
264        };
265
266        let count = filter_count(&filter);
267        let strategy = IterationStrategy::default_strategy(filter.len(), count);
268
269        Self {
270            filter,
271            count,
272            strategy,
273        }
274    }
275
276    /// Compute an optimized representation of the provided `filter` mask that can be
277    /// applied to an array more quickly.
278    ///
279    /// When filtering multiple arrays (e.g. a [`RecordBatch`] or a
280    /// [`StructArray`] with multiple fields), optimizing the filter can provide
281    /// significant performance benefits.
282    ///
283    /// However, optimization takes time and can have a larger memory footprint
284    /// than the original mask, so it is often faster to filter a single array,
285    /// without filter optimization.
286    pub fn optimize(mut self) -> Self {
287        match self.strategy {
288            IterationStrategy::SlicesIterator => {
289                let slices = SlicesIterator::new(&self.filter).collect();
290                self.strategy = IterationStrategy::Slices(slices)
291            }
292            IterationStrategy::IndexIterator => {
293                let indices = IndexIterator::new(&self.filter, self.count).collect();
294                self.strategy = IterationStrategy::Indices(indices)
295            }
296            _ => {}
297        }
298        self
299    }
300
301    /// Determines if calling [FilterBuilder::optimize] is beneficial for the
302    /// given type even when filtering just a single array.
303    ///
304    /// See [`FilterBuilder::optimize`] for more details.
305    pub fn is_optimize_beneficial(data_type: &DataType) -> bool {
306        match data_type {
307            DataType::Struct(fields) => {
308                fields.len() > 1
309                    || fields.len() == 1
310                        && FilterBuilder::is_optimize_beneficial(fields[0].data_type())
311            }
312            DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
313            _ => false,
314        }
315    }
316
317    /// Construct the final `FilterPredicate`
318    pub fn build(self) -> FilterPredicate {
319        FilterPredicate {
320            filter: self.filter,
321            count: self.count,
322            strategy: self.strategy,
323        }
324    }
325}
326
327/// The iteration strategy used to evaluate [`FilterPredicate`]
328#[derive(Debug)]
329enum IterationStrategy {
330    /// A lazily evaluated iterator of ranges
331    SlicesIterator,
332    /// A lazily evaluated iterator of indices
333    IndexIterator,
334    /// A precomputed list of indices
335    Indices(Vec<usize>),
336    /// A precomputed array of ranges
337    Slices(Vec<(usize, usize)>),
338    /// Select all rows
339    All,
340    /// Select no rows
341    None,
342}
343
344impl IterationStrategy {
345    /// The default [`IterationStrategy`] for a filter of length `filter_length`
346    /// and selecting `filter_count` rows
347    fn default_strategy(filter_length: usize, filter_count: usize) -> Self {
348        if filter_length == 0 || filter_count == 0 {
349            return IterationStrategy::None;
350        }
351
352        if filter_count == filter_length {
353            return IterationStrategy::All;
354        }
355
356        // Compute the selectivity of the predicate by dividing the number of true
357        // bits in the predicate by the predicate's total length
358        //
359        // This can then be used as a heuristic for the optimal iteration strategy
360        let selectivity_frac = filter_count as f64 / filter_length as f64;
361        if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD {
362            return IterationStrategy::SlicesIterator;
363        }
364        IterationStrategy::IndexIterator
365    }
366}
367
368/// A filtering predicate that can be applied to an [`Array`]
369#[derive(Debug)]
370pub struct FilterPredicate {
371    filter: BooleanArray,
372    count: usize,
373    strategy: IterationStrategy,
374}
375
376impl FilterPredicate {
377    /// Selects rows from `values` based on this [`FilterPredicate`]
378    pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef, ArrowError> {
379        filter_array(values, self)
380    }
381
382    /// Returns a filtered [`RecordBatch`] containing only the rows that are selected by this
383    /// [`FilterPredicate`].
384    ///
385    /// This is the equivalent of calling [filter] on each column of the [`RecordBatch`].
386    pub fn filter_record_batch(
387        &self,
388        record_batch: &RecordBatch,
389    ) -> Result<RecordBatch, ArrowError> {
390        let filtered_arrays = record_batch
391            .columns()
392            .iter()
393            .map(|a| filter_array(a, self))
394            .collect::<Result<Vec<_>, _>>()?;
395
396        // SAFETY: we know that the set of filtered arrays will match the schema of the original
397        // record batch
398        unsafe {
399            Ok(RecordBatch::new_unchecked(
400                record_batch.schema(),
401                filtered_arrays,
402                self.count,
403            ))
404        }
405    }
406
407    /// Number of rows being selected based on this [`FilterPredicate`]
408    pub fn count(&self) -> usize {
409        self.count
410    }
411}
412
413fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> {
414    if predicate.filter.len() > values.len() {
415        return Err(ArrowError::InvalidArgumentError(format!(
416            "Filter predicate of length {} is larger than target array of length {}",
417            predicate.filter.len(),
418            values.len()
419        )));
420    }
421
422    match predicate.strategy {
423        IterationStrategy::None => Ok(new_empty_array(values.data_type())),
424        IterationStrategy::All => Ok(values.slice(0, predicate.count)),
425        // actually filter
426        _ => downcast_primitive_array! {
427            values => Ok(Arc::new(filter_primitive(values, predicate))),
428            DataType::Boolean => {
429                let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
430                Ok(Arc::new(filter_boolean(values, predicate)))
431            }
432            DataType::Utf8 => {
433                Ok(Arc::new(filter_bytes(values.as_string::<i32>(), predicate)))
434            }
435            DataType::LargeUtf8 => {
436                Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
437            }
438            DataType::Utf8View => {
439                Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
440            }
441            DataType::Binary => {
442                Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
443            }
444            DataType::LargeBinary => {
445                Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
446            }
447            DataType::BinaryView => {
448                Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
449            }
450            DataType::FixedSizeBinary(_) => {
451                Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate)))
452            }
453            DataType::ListView(_) => {
454                Ok(Arc::new(filter_list_view::<i32>(values.as_list_view(), predicate)))
455            }
456            DataType::LargeListView(_) => {
457                Ok(Arc::new(filter_list_view::<i64>(values.as_list_view(), predicate)))
458            }
459            DataType::RunEndEncoded(_, _) => {
460                downcast_run_array!{
461                    values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
462                    t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t)
463                }
464            }
465            DataType::Dictionary(_, _) => downcast_dictionary_array! {
466                values => Ok(Arc::new(filter_dict(values, predicate))),
467                t => unimplemented!("Filter not supported for dictionary type {:?}", t)
468            }
469            DataType::Struct(_) => {
470                Ok(Arc::new(filter_struct(values.as_struct(), predicate)?))
471            }
472            DataType::Union(_, UnionMode::Sparse) => {
473                Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?))
474            }
475            _ => {
476                let data = values.to_data();
477                // fallback to using MutableArrayData
478                let mut mutable = MutableArrayData::new(
479                    vec![&data],
480                    false,
481                    predicate.count,
482                );
483
484                match &predicate.strategy {
485                    IterationStrategy::Slices(slices) => {
486                        slices
487                            .iter()
488                            .for_each(|(start, end)| mutable.extend(0, *start, *end));
489                    }
490                    _ => {
491                        let iter = SlicesIterator::new(&predicate.filter);
492                        iter.for_each(|(start, end)| mutable.extend(0, start, end));
493                    }
494                }
495
496                let data = mutable.freeze();
497                Ok(make_array(data))
498            }
499        },
500    }
501}
502
503/// Filter any supported [`RunArray`] based on a [`FilterPredicate`]
504fn filter_run_end_array<R: RunEndIndexType>(
505    array: &RunArray<R>,
506    predicate: &FilterPredicate,
507) -> Result<RunArray<R>, ArrowError>
508where
509    R::Native: Into<i64> + From<bool>,
510    R::Native: AddAssign,
511{
512    let run_ends: &RunEndBuffer<R::Native> = array.run_ends();
513    let start_physical = run_ends.get_start_physical_index();
514    let end_physical = run_ends.get_end_physical_index();
515    let physical_len = end_physical - start_physical + 1;
516
517    let mut new_run_ends = vec![R::default_value(); physical_len];
518    let offset = run_ends.offset() as u64;
519
520    let mut start = 0u64;
521    let mut j = 0;
522    let mut count = R::default_value();
523    let filter_values = predicate.filter.values();
524    let run_ends = run_ends.inner();
525
526    let pred: BooleanArray = BooleanBuffer::collect_bool(physical_len, |i| {
527        let mut keep = false;
528        let mut end = (run_ends[i + start_physical].into() as u64).saturating_sub(offset);
529        let difference = end.saturating_sub(filter_values.len() as u64);
530        end -= difference;
531
532        // Safety: we subtract the difference off `end` so we are always within bounds
533        for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
534            count += R::Native::from(pred);
535            keep |= pred
536        }
537        // this is to avoid branching
538        new_run_ends[j] = count;
539        j += keep as usize;
540
541        start = end;
542        keep
543    })
544    .into();
545
546    new_run_ends.truncate(j);
547
548    let values = array.values_slice();
549    let values = filter(values.as_ref(), &pred)?;
550
551    let run_ends = PrimitiveArray::<R>::try_new(new_run_ends.into(), None)?;
552    RunArray::try_new(&run_ends, &values)
553}
554
555/// Computes a new null mask for `data` based on `predicate`
556///
557/// If the predicate selected no null-rows, returns `None`, otherwise returns
558/// `Some((null_count, null_buffer))` where `null_count` is the number of nulls
559/// in the filtered output, and `null_buffer` is the filtered null buffer
560///
561fn filter_null_mask(
562    nulls: Option<&NullBuffer>,
563    predicate: &FilterPredicate,
564) -> Option<(usize, Buffer)> {
565    let nulls = nulls?;
566    if nulls.null_count() == 0 {
567        return None;
568    }
569
570    let nulls = filter_bits(nulls.inner(), predicate);
571    // The filtered `nulls` has a length of `predicate.count` bits and
572    // therefore the null count is this minus the number of valid bits
573    let null_count = predicate.count - nulls.count_set_bits_offset(0, predicate.count);
574
575    if null_count == 0 {
576        return None;
577    }
578
579    Some((null_count, nulls))
580}
581
582/// Filter the packed bitmask `buffer`, with `predicate` starting at bit offset `offset`
583fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer {
584    let src = buffer.values();
585    let offset = buffer.offset();
586    assert!(buffer.len() >= predicate.filter.len());
587
588    match &predicate.strategy {
589        IterationStrategy::IndexIterator => {
590            let bits =
591                // SAFETY: IndexIterator uses the filter predicate to derive indices
592                IndexIterator::new(&predicate.filter, predicate.count).map(|src_idx| unsafe {
593                    bit_util::get_bit_raw(buffer.values().as_ptr(), src_idx + offset)
594                });
595
596            // SAFETY: `IndexIterator` reports its size correctly
597            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
598        }
599        IterationStrategy::Indices(indices) => {
600            // SAFETY: indices were derived from the filter predicate
601            let bits = indices.iter().map(|src_idx| unsafe {
602                bit_util::get_bit_raw(buffer.values().as_ptr(), *src_idx + offset)
603            });
604            // SAFETY: `Vec::iter()` reports its size correctly
605            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
606        }
607        IterationStrategy::SlicesIterator => {
608            let mut builder = BooleanBufferBuilder::new(predicate.count);
609            for (start, end) in SlicesIterator::new(&predicate.filter) {
610                builder.append_packed_range(start + offset..end + offset, src)
611            }
612            builder.into()
613        }
614        IterationStrategy::Slices(slices) => {
615            let mut builder = BooleanBufferBuilder::new(predicate.count);
616            for (start, end) in slices {
617                builder.append_packed_range(*start + offset..*end + offset, src)
618            }
619            builder.into()
620        }
621        IterationStrategy::All | IterationStrategy::None => unreachable!(),
622    }
623}
624
625/// `filter` implementation for boolean buffers
626fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
627    let values = filter_bits(array.values(), predicate);
628
629    let mut builder = ArrayDataBuilder::new(DataType::Boolean)
630        .len(predicate.count)
631        .add_buffer(values);
632
633    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
634        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
635    }
636
637    let data = unsafe { builder.build_unchecked() };
638    BooleanArray::from(data)
639}
640
641#[inline(never)]
642fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer {
643    assert!(values.len() >= predicate.filter.len());
644
645    match &predicate.strategy {
646        IterationStrategy::SlicesIterator => {
647            let mut buffer = Vec::with_capacity(predicate.count);
648            for (start, end) in SlicesIterator::new(&predicate.filter) {
649                // SAFETY: indices were derived from the filter predicate
650                buffer.extend_from_slice(unsafe { values.get_unchecked(start..end) });
651            }
652            buffer.into()
653        }
654        IterationStrategy::Slices(slices) => {
655            let mut buffer = Vec::with_capacity(predicate.count);
656            for (start, end) in slices {
657                // SAFETY: indices were derived from the filter predicate
658                buffer.extend_from_slice(unsafe { values.get_unchecked(*start..*end) });
659            }
660            buffer.into()
661        }
662        IterationStrategy::IndexIterator => {
663            // SAFETY: indices were derived from the filter predicate
664            let iter = IndexIterator::new(&predicate.filter, predicate.count)
665                .map(|x| unsafe { *values.get_unchecked(x) });
666
667            // SAFETY: IndexIterator is trusted length
668            unsafe { MutableBuffer::from_trusted_len_iter(iter) }.into()
669        }
670        IterationStrategy::Indices(indices) => {
671            // SAFETY: indices were derived from the filter predicate
672            let iter = indices.iter().map(|x| unsafe { *values.get_unchecked(*x) });
673            iter.collect::<Vec<_>>().into()
674        }
675        IterationStrategy::All | IterationStrategy::None => unreachable!(),
676    }
677}
678
679/// `filter` implementation for primitive arrays
680fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
681where
682    T: ArrowPrimitiveType,
683{
684    let values = array.values();
685    let buffer = filter_native(values, predicate);
686    let mut builder = ArrayDataBuilder::new(array.data_type().clone())
687        .len(predicate.count)
688        .add_buffer(buffer);
689
690    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
691        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
692    }
693
694    let data = unsafe { builder.build_unchecked() };
695    PrimitiveArray::from(data)
696}
697
698/// [`FilterBytes`] is created from a source [`GenericByteArray`] and can be
699/// used to build a new [`GenericByteArray`] by copying values from the source
700///
701/// TODO(raphael): Could this be used for the take kernel as well?
702struct FilterBytes<'a, OffsetSize> {
703    src_offsets: &'a [OffsetSize],
704    src_values: &'a [u8],
705    dst_offsets: Vec<OffsetSize>,
706    dst_values: Vec<u8>,
707    cur_offset: OffsetSize,
708}
709
710impl<'a, OffsetSize> FilterBytes<'a, OffsetSize>
711where
712    OffsetSize: OffsetSizeTrait,
713{
714    fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self
715    where
716        T: ByteArrayType<Offset = OffsetSize>,
717    {
718        let dst_values = Vec::new();
719        let mut dst_offsets: Vec<OffsetSize> = Vec::with_capacity(capacity + 1);
720        let cur_offset = OffsetSize::from_usize(0).unwrap();
721
722        dst_offsets.push(cur_offset);
723
724        Self {
725            src_offsets: array.value_offsets(),
726            src_values: array.value_data(),
727            dst_offsets,
728            dst_values,
729            cur_offset,
730        }
731    }
732
733    /// Returns the byte offset at `idx`
734    #[inline]
735    fn get_value_offset(&self, idx: usize) -> usize {
736        self.src_offsets[idx].as_usize()
737    }
738
739    /// Returns the start and end of the value at index `idx` along with its length
740    #[inline]
741    fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
742        // These can only fail if `array` contains invalid data
743        let start = self.get_value_offset(idx);
744        let end = self.get_value_offset(idx + 1);
745        let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
746        (start, end, len)
747    }
748
749    fn extend_offsets_idx(&mut self, iter: impl Iterator<Item = usize>) {
750        self.dst_offsets.extend(iter.map(|idx| {
751            let start = self.src_offsets[idx].as_usize();
752            let end = self.src_offsets[idx + 1].as_usize();
753            let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
754            self.cur_offset += len;
755
756            self.cur_offset
757        }));
758    }
759
760    /// Extends the in-progress array by the indexes in the provided iterator
761    fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
762        self.dst_values.reserve_exact(self.cur_offset.as_usize());
763
764        for idx in iter {
765            let start = self.src_offsets[idx].as_usize();
766            let end = self.src_offsets[idx + 1].as_usize();
767            self.dst_values
768                .extend_from_slice(&self.src_values[start..end]);
769        }
770    }
771
772    fn extend_offsets_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>, count: usize) {
773        self.dst_offsets.reserve_exact(count);
774        for (start, end) in iter {
775            // These can only fail if `array` contains invalid data
776            for idx in start..end {
777                let (_, _, len) = self.get_value_range(idx);
778                self.cur_offset += len;
779                self.dst_offsets.push(self.cur_offset);
780            }
781        }
782    }
783
784    /// Extends the in-progress array by the ranges in the provided iterator
785    fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
786        self.dst_values.reserve_exact(self.cur_offset.as_usize());
787
788        for (start, end) in iter {
789            let value_start = self.get_value_offset(start);
790            let value_end = self.get_value_offset(end);
791            self.dst_values
792                .extend_from_slice(&self.src_values[value_start..value_end]);
793        }
794    }
795}
796
797/// `filter` implementation for byte arrays
798///
799/// Note: NULLs with a non-zero slot length in `array` will have the corresponding
800/// data copied across. This allows handling the null mask separately from the data
801fn filter_bytes<T>(array: &GenericByteArray<T>, predicate: &FilterPredicate) -> GenericByteArray<T>
802where
803    T: ByteArrayType,
804{
805    let mut filter = FilterBytes::new(predicate.count, array);
806
807    match &predicate.strategy {
808        IterationStrategy::SlicesIterator => {
809            filter.extend_offsets_slices(SlicesIterator::new(&predicate.filter), predicate.count);
810            filter.extend_slices(SlicesIterator::new(&predicate.filter))
811        }
812        IterationStrategy::Slices(slices) => {
813            filter.extend_offsets_slices(slices.iter().cloned(), predicate.count);
814            filter.extend_slices(slices.iter().cloned())
815        }
816        IterationStrategy::IndexIterator => {
817            filter.extend_offsets_idx(IndexIterator::new(&predicate.filter, predicate.count));
818            filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count))
819        }
820        IterationStrategy::Indices(indices) => {
821            filter.extend_offsets_idx(indices.iter().cloned());
822            filter.extend_idx(indices.iter().cloned())
823        }
824        IterationStrategy::All | IterationStrategy::None => unreachable!(),
825    }
826
827    let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
828        .len(predicate.count)
829        .add_buffer(filter.dst_offsets.into())
830        .add_buffer(filter.dst_values.into());
831
832    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
833        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
834    }
835
836    let data = unsafe { builder.build_unchecked() };
837    GenericByteArray::from(data)
838}
839
840/// `filter` implementation for byte view arrays.
841fn filter_byte_view<T: ByteViewType>(
842    array: &GenericByteViewArray<T>,
843    predicate: &FilterPredicate,
844) -> GenericByteViewArray<T> {
845    let new_view_buffer = filter_native(array.views(), predicate);
846
847    let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
848        .len(predicate.count)
849        .add_buffer(new_view_buffer)
850        .add_buffers(array.data_buffers().to_vec());
851
852    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
853        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
854    }
855
856    GenericByteViewArray::from(unsafe { builder.build_unchecked() })
857}
858
859fn filter_fixed_size_binary(
860    array: &FixedSizeBinaryArray,
861    predicate: &FilterPredicate,
862) -> FixedSizeBinaryArray {
863    let values: &[u8] = array.values();
864    let value_length = array.value_length() as usize;
865    let calculate_offset_from_index = |index: usize| index * value_length;
866    let buffer = match &predicate.strategy {
867        IterationStrategy::SlicesIterator => {
868            let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
869            for (start, end) in SlicesIterator::new(&predicate.filter) {
870                buffer.extend_from_slice(
871                    &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)],
872                );
873            }
874            buffer
875        }
876        IterationStrategy::Slices(slices) => {
877            let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
878            for (start, end) in slices {
879                buffer.extend_from_slice(
880                    &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)],
881                );
882            }
883            buffer
884        }
885        IterationStrategy::IndexIterator => {
886            let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| {
887                &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)]
888            });
889
890            let mut buffer = MutableBuffer::new(predicate.count * value_length);
891            iter.for_each(|item| buffer.extend_from_slice(item));
892            buffer
893        }
894        IterationStrategy::Indices(indices) => {
895            let iter = indices.iter().map(|x| {
896                &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)]
897            });
898
899            let mut buffer = MutableBuffer::new(predicate.count * value_length);
900            iter.for_each(|item| buffer.extend_from_slice(item));
901            buffer
902        }
903        IterationStrategy::All | IterationStrategy::None => unreachable!(),
904    };
905    let mut builder = ArrayDataBuilder::new(array.data_type().clone())
906        .len(predicate.count)
907        .add_buffer(buffer.into());
908
909    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
910        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
911    }
912
913    let data = unsafe { builder.build_unchecked() };
914    FixedSizeBinaryArray::from(data)
915}
916
917/// `filter` implementation for dictionaries
918fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
919where
920    T: ArrowDictionaryKeyType,
921    T::Native: num_traits::Num,
922{
923    let builder = filter_primitive::<T>(array.keys(), predicate)
924        .into_data()
925        .into_builder()
926        .data_type(array.data_type().clone())
927        .child_data(vec![array.values().to_data()]);
928
929    // SAFETY:
930    // Keys were valid before, filtered subset is therefore still valid
931    DictionaryArray::from(unsafe { builder.build_unchecked() })
932}
933
934/// `filter` implementation for structs
935fn filter_struct(
936    array: &StructArray,
937    predicate: &FilterPredicate,
938) -> Result<StructArray, ArrowError> {
939    let columns = array
940        .columns()
941        .iter()
942        .map(|column| filter_array(column, predicate))
943        .collect::<Result<_, _>>()?;
944
945    let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
946        let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
947
948        Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
949    } else {
950        None
951    };
952
953    Ok(unsafe {
954        StructArray::new_unchecked_with_length(
955            array.fields().clone(),
956            columns,
957            nulls,
958            predicate.count(),
959        )
960    })
961}
962
963/// `filter` implementation for sparse unions
964fn filter_sparse_union(
965    array: &UnionArray,
966    predicate: &FilterPredicate,
967) -> Result<UnionArray, ArrowError> {
968    let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
969        unreachable!()
970    };
971
972    let type_ids = filter_primitive(
973        &Int8Array::try_new(array.type_ids().clone(), None)?,
974        predicate,
975    );
976
977    let children = fields
978        .iter()
979        .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
980        .collect::<Result<_, _>>()?;
981
982    Ok(unsafe {
983        UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
984    })
985}
986
987/// `filter` implementation for list views
988fn filter_list_view<OffsetType: OffsetSizeTrait>(
989    array: &GenericListViewArray<OffsetType>,
990    predicate: &FilterPredicate,
991) -> GenericListViewArray<OffsetType> {
992    let filtered_offsets = filter_native::<OffsetType>(array.offsets(), predicate);
993    let filtered_sizes = filter_native::<OffsetType>(array.sizes(), predicate);
994
995    // Filter the nulls
996    let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
997        let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
998
999        Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
1000    } else {
1001        None
1002    };
1003
1004    let list_data = ArrayDataBuilder::new(array.data_type().clone())
1005        .nulls(nulls)
1006        .buffers(vec![filtered_offsets, filtered_sizes])
1007        .child_data(vec![array.values().to_data()])
1008        .len(predicate.count);
1009
1010    let list_data = unsafe { list_data.build_unchecked() };
1011
1012    GenericListViewArray::from(list_data)
1013}
1014
1015#[cfg(test)]
1016mod tests {
1017    use super::*;
1018    use arrow_array::builder::*;
1019    use arrow_array::cast::as_run_array;
1020    use arrow_array::types::*;
1021    use arrow_data::ArrayData;
1022    use rand::distr::uniform::{UniformSampler, UniformUsize};
1023    use rand::distr::{Alphanumeric, StandardUniform};
1024    use rand::prelude::*;
1025    use rand::rng;
1026
/// Generates a `#[test]` named `$test` that filters `$data` with the mask
/// `[true, false, true, false]`, downcasts the result to `$array_type` and
/// checks that exactly the values `1` and `3` remain.
macro_rules! def_temporal_test {
    ($test:ident, $array_type: ident, $data: expr) => {
        #[test]
        fn $test() {
            let input = $data;
            let mask = BooleanArray::from(vec![true, false, true, false]);
            let output = filter(&input, &mask).unwrap();
            let typed = output
                .as_ref()
                .as_any()
                .downcast_ref::<$array_type>()
                .unwrap();
            assert_eq!(2, typed.len());
            assert_eq!(1, typed.value(0));
            assert_eq!(3, typed.value(1));
        }
    };
}

// Date types
def_temporal_test!(
    test_filter_date32,
    Date32Array,
    Date32Array::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
    test_filter_date64,
    Date64Array,
    Date64Array::from(vec![1, 2, 3, 4])
);

// Time-of-day types
def_temporal_test!(
    test_filter_time32_second,
    Time32SecondArray,
    Time32SecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
    test_filter_time32_millisecond,
    Time32MillisecondArray,
    Time32MillisecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
    test_filter_time64_microsecond,
    Time64MicrosecondArray,
    Time64MicrosecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
    test_filter_time64_nanosecond,
    Time64NanosecondArray,
    Time64NanosecondArray::from(vec![1, 2, 3, 4])
);

// Duration types
def_temporal_test!(
    test_filter_duration_second,
    DurationSecondArray,
    DurationSecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
    test_filter_duration_millisecond,
    DurationMillisecondArray,
    DurationMillisecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
    test_filter_duration_microsecond,
    DurationMicrosecondArray,
    DurationMicrosecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
    test_filter_duration_nanosecond,
    DurationNanosecondArray,
    DurationNanosecondArray::from(vec![1, 2, 3, 4])
);

// Timestamp types
def_temporal_test!(
    test_filter_timestamp_second,
    TimestampSecondArray,
    TimestampSecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
    test_filter_timestamp_millisecond,
    TimestampMillisecondArray,
    TimestampMillisecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
    test_filter_timestamp_microsecond,
    TimestampMicrosecondArray,
    TimestampMicrosecondArray::from(vec![1, 2, 3, 4])
);
def_temporal_test!(
    test_filter_timestamp_nanosecond,
    TimestampNanosecondArray,
    TimestampNanosecondArray::from(vec![1, 2, 3, 4])
);
1112
#[test]
fn test_filter_array_slice() {
    // The input is a slice covering rows 1..5 of the backing data: [6, 7, 8, 9].
    let input = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
    // The mask is deliberately NOT a slice: when this test was written, filtering
    // with a sliced filter array was noted as not currently supported.
    let mask = BooleanArray::from(vec![true, false, false, true]);

    let output = filter(&input, &mask).unwrap();
    let ints = output
        .as_ref()
        .as_any()
        .downcast_ref::<Int32Array>()
        .unwrap();

    assert_eq!(2, ints.len());
    assert_eq!(6, ints.value(0));
    assert_eq!(9, ints.value(1));
}
1126
#[test]
fn test_filter_array_low_density() {
    // this test exercises the all 0's branch of the filter algorithm:
    // of the first 65 rows only the last one (value 65) is selected
    let mut data_values: Vec<i32> = (1..=65).collect();
    let mut filter_values: Vec<bool> = (1..=65).map(|i| i % 65 == 0).collect();

    // set up two more values after the batch
    data_values.extend([66, 67]);
    filter_values.extend([false, true]);

    let input = Int32Array::from(data_values);
    let mask = BooleanArray::from(filter_values);
    let output = filter(&input, &mask).unwrap();
    let ints = output
        .as_ref()
        .as_any()
        .downcast_ref::<Int32Array>()
        .unwrap();

    assert_eq!(2, ints.len());
    assert_eq!(65, ints.value(0));
    assert_eq!(67, ints.value(1));
}
1143
#[test]
fn test_filter_array_high_density() {
    // this test exercises the all 1's branch of the filter algorithm:
    // of the first 65 rows every one except the last is selected
    let mut data_values: Vec<Option<i32>> = (1..=65).map(Some).collect();
    let mut filter_values: Vec<bool> = (1..=65).map(|i| i % 65 != 0).collect();

    // set second data value to null
    data_values[1] = None;

    // set up four more values after the batch, the last three of them selected
    data_values.extend([Some(66), None, Some(67), None]);
    filter_values.extend([false, true, true, true]);

    let input = Int32Array::from(data_values);
    let mask = BooleanArray::from(filter_values);
    let output = filter(&input, &mask).unwrap();
    let ints = output
        .as_ref()
        .as_any()
        .downcast_ref::<Int32Array>()
        .unwrap();

    assert_eq!(67, ints.len());
    assert_eq!(3, ints.null_count());
    assert_eq!(1, ints.value(0));
    assert!(ints.is_null(1));
    assert_eq!(64, ints.value(63));
    assert!(ints.is_null(64));
    assert_eq!(67, ints.value(65));
}
1168
#[test]
fn test_filter_string_array_simple() {
    // Keep rows 0 and 2 of a string array without nulls.
    let input = StringArray::from(vec!["hello", " ", "world", "!"]);
    let mask = BooleanArray::from(vec![true, false, true, false]);

    let output = filter(&input, &mask).unwrap();
    let strings = output
        .as_ref()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();

    assert_eq!(2, strings.len());
    assert_eq!("hello", strings.value(0));
    assert_eq!("world", strings.value(1));
}
1179
#[test]
fn test_filter_primitive_array_with_null() {
    // Only the null row is selected, so the single output row must be null.
    let input = Int32Array::from(vec![Some(5), None]);
    let mask = BooleanArray::from(vec![false, true]);

    let output = filter(&input, &mask).unwrap();
    let ints = output
        .as_ref()
        .as_any()
        .downcast_ref::<Int32Array>()
        .unwrap();

    assert_eq!(1, ints.len());
    assert!(ints.is_null(0));
}
1189
#[test]
fn test_filter_string_array_with_null() {
    // Rows 0 ("hello") and 3 (null) are selected.
    let input = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
    let mask = BooleanArray::from(vec![true, false, false, true]);

    let output = filter(&input, &mask).unwrap();
    let strings = output
        .as_ref()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();

    assert_eq!(2, strings.len());
    assert_eq!("hello", strings.value(0));
    assert!(!strings.is_null(0));
    assert!(strings.is_null(1));
}
1201
#[test]
fn test_filter_binary_array_with_null() {
    // Rows 0 (b"hello") and 3 (null) are selected.
    let rows: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
    let input = BinaryArray::from(rows);
    let mask = BooleanArray::from(vec![true, false, false, true]);

    let output = filter(&input, &mask).unwrap();
    let binary = output
        .as_ref()
        .as_any()
        .downcast_ref::<BinaryArray>()
        .unwrap();

    assert_eq!(2, binary.len());
    assert_eq!(b"hello", binary.value(0));
    assert!(!binary.is_null(0));
    assert!(binary.is_null(1));
}
1214
1215    fn _test_filter_byte_view<T>()
1216    where
1217        T: ByteViewType,
1218        str: AsRef<T::Native>,
1219        T::Native: PartialEq,
1220    {
1221        let array = {
1222            // ["hello", "world", null, "large payload over 12 bytes", "lulu"]
1223            let mut builder = GenericByteViewBuilder::<T>::new();
1224            builder.append_value("hello");
1225            builder.append_value("world");
1226            builder.append_null();
1227            builder.append_value("large payload over 12 bytes");
1228            builder.append_value("lulu");
1229            builder.finish()
1230        };
1231
1232        {
1233            let predicate = BooleanArray::from(vec![true, false, true, true, false]);
1234            let actual = filter(&array, &predicate).unwrap();
1235
1236            assert_eq!(actual.len(), 3);
1237
1238            let expected = {
1239                // ["hello", null, "large payload over 12 bytes"]
1240                let mut builder = GenericByteViewBuilder::<T>::new();
1241                builder.append_value("hello");
1242                builder.append_null();
1243                builder.append_value("large payload over 12 bytes");
1244                builder.finish()
1245            };
1246
1247            assert_eq!(actual.as_ref(), &expected);
1248        }
1249
1250        {
1251            let predicate = BooleanArray::from(vec![true, false, false, false, true]);
1252            let actual = filter(&array, &predicate).unwrap();
1253
1254            assert_eq!(actual.len(), 2);
1255
1256            let expected = {
1257                // ["hello", "lulu"]
1258                let mut builder = GenericByteViewBuilder::<T>::new();
1259                builder.append_value("hello");
1260                builder.append_value("lulu");
1261                builder.finish()
1262            };
1263
1264            assert_eq!(actual.as_ref(), &expected);
1265        }
1266    }
1267
#[test]
fn test_filter_string_view() {
    // Run the shared byte view checks for the string view type.
    _test_filter_byte_view::<StringViewType>();
}
1272
#[test]
fn test_filter_binary_view() {
    // Run the shared byte view checks for the binary view type.
    _test_filter_byte_view::<BinaryViewType>();
}
1277
#[test]
fn test_filter_fixed_binary() {
    // Views a filter result as a `FixedSizeBinaryArray`, panicking on any other type.
    fn as_fixed_size_binary(array: &ArrayRef) -> &FixedSizeBinaryArray {
        array
            .as_ref()
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap()
    }

    let (v1, v2, v3) = ([1_u8, 2], [3_u8, 4], [5_u8, 6]);
    let input = FixedSizeBinaryArray::from(vec![&v1, &v2, &v3]);

    // Rows 0 and 2 selected; the optimized predicate must give the same result.
    {
        let mask = BooleanArray::from(vec![true, false, true]);
        let plain_output = filter(&input, &mask).unwrap();
        let plain = as_fixed_size_binary(&plain_output);
        assert_eq!(plain.len(), 2);
        assert_eq!(plain.value(0), &v1);
        assert_eq!(plain.value(1), &v3);

        let optimized_output = FilterBuilder::new(&mask)
            .optimize()
            .build()
            .filter(&input)
            .unwrap();
        let optimized = as_fixed_size_binary(&optimized_output);
        assert_eq!(plain, optimized);
    }

    // Nothing selected.
    {
        let mask = BooleanArray::from(vec![false, false, false]);
        let output = filter(&input, &mask).unwrap();
        let filtered = as_fixed_size_binary(&output);
        assert_eq!(filtered.len(), 0);
    }

    // Everything selected.
    {
        let mask = BooleanArray::from(vec![true, true, true]);
        let output = filter(&input, &mask).unwrap();
        let filtered = as_fixed_size_binary(&output);
        assert_eq!(filtered.len(), 3);
        assert_eq!(filtered.value(0), &v1);
        assert_eq!(filtered.value(1), &v2);
        assert_eq!(filtered.value(2), &v3);
    }

    // Only the last row selected; again compared against the optimized predicate.
    {
        let mask = BooleanArray::from(vec![false, false, true]);
        let plain_output = filter(&input, &mask).unwrap();
        let plain = as_fixed_size_binary(&plain_output);
        assert_eq!(plain.len(), 1);
        assert_eq!(plain.value(0), &v3);

        let optimized_output = FilterBuilder::new(&mask)
            .optimize()
            .build()
            .filter(&input)
            .unwrap();
        let optimized = as_fixed_size_binary(&optimized_output);
        assert_eq!(plain, optimized);
    }
}
1349
#[test]
fn test_filter_array_slice_with_null() {
    // The input is a slice covering rows 1..5 of the backing data: [null, 7, 8, 9].
    let input = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
    // The mask is deliberately NOT a slice: when this test was written, filtering
    // with a sliced filter array was noted as not currently supported.
    let mask = BooleanArray::from(vec![true, false, false, true]);

    let output = filter(&input, &mask).unwrap();
    let ints = output
        .as_ref()
        .as_any()
        .downcast_ref::<Int32Array>()
        .unwrap();

    assert_eq!(2, ints.len());
    assert!(ints.is_null(0));
    assert!(!ints.is_null(1));
    assert_eq!(9, ints.value(1));
}
1364
#[test]
fn test_filter_run_end_encoding_array() {
    // Logical input: [7, 7, -2, 9, 9, 9, 9, 9]; every other row is selected.
    let run_ends = Int64Array::from(vec![2, 3, 8]);
    let values = Int64Array::from(vec![7, -2, 9]);
    let input = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
    let mask = BooleanArray::from(vec![true, false, true, false, true, false, true, false]);

    let output = filter(&input, &mask).unwrap();
    let actual: &RunArray<Int64Type> = as_run_array(&output);
    assert_eq!(4, actual.len());

    // Logical output: [7, -2, 9, 9]
    let expected_run_ends = Int64Array::from(vec![1, 2, 4]);
    let expected_values = Int64Array::from(vec![7, -2, 9]);
    let expected = RunArray::try_new(&expected_run_ends, &expected_values)
        .expect("Failed to make expected RunArray test is broken");

    assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1384
#[test]
fn test_filter_run_end_encoding_array_sliced() {
    let run_ends = Int64Array::from(vec![2, 3, 8]);
    let values = Int64Array::from(vec![7, -2, 9]);
    // Full array: [7, 7, -2, 9, 9, 9, 9, 9]
    let full = RunArray::try_new(&run_ends, &values).unwrap();
    // Sliced input: [-2, 9, 9]
    let input = full.slice(2, 3);
    let mask = BooleanArray::from(vec![true, false, true]);

    let output = filter(&input, &mask).unwrap();
    let typed = output
        .as_run::<Int64Type>()
        .downcast::<Int64Array>()
        .unwrap();

    let expected = vec![-2, 9];
    let actual: Vec<i64> = typed.into_iter().flatten().collect();
    assert_eq!(expected, actual);
}
1401
#[test]
fn test_filter_run_end_encoding_array_remove_value() {
    // Logical input: [7, 7, -2, 9, 9, 9, 9, 9, -8, -8]; the mask drops every row
    // of the `-2` and `-8` runs.
    let run_ends = Int32Array::from(vec![2, 3, 8, 10]);
    let values = Int32Array::from(vec![7, -2, 9, -8]);
    let input = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
    let mask = BooleanArray::from(vec![
        false, true, false, false, true, false, true, false, false, false,
    ]);

    let output = filter(&input, &mask).unwrap();
    let actual: &RunArray<Int32Type> = as_run_array(&output);
    assert_eq!(3, actual.len());

    // Logical output: [7, 9, 9]
    let expected_run_ends = Int32Array::from(vec![1, 3]);
    let expected_values = Int32Array::from(vec![7, 9]);
    let expected = RunArray::try_new(&expected_run_ends, &expected_values)
        .expect("Failed to make expected RunArray test is broken");

    assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1421
#[test]
fn test_filter_run_end_encoding_array_remove_all_but_one() {
    // Logical input: [7, 7, -2, 9, 9, 9, 9, 9, -8, -8]; only row 6 (a `9`) is kept.
    let run_ends = Int16Array::from(vec![2, 3, 8, 10]);
    let values = Int16Array::from(vec![7, -2, 9, -8]);
    let input = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
    let mask = BooleanArray::from(vec![
        false, false, false, false, false, false, true, false, false, false,
    ]);

    let output = filter(&input, &mask).unwrap();
    let actual: &RunArray<Int16Type> = as_run_array(&output);
    assert_eq!(1, actual.len());

    // Logical output: [9]
    let expected_run_ends = Int16Array::from(vec![1]);
    let expected_values = Int16Array::from(vec![9]);
    let expected = RunArray::try_new(&expected_run_ends, &expected_values)
        .expect("Failed to make expected RunArray test is broken");

    assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1440
#[test]
fn test_filter_run_end_encoding_array_empty() {
    // No row is selected, so the filtered run array must be empty.
    let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
    let values = Int64Array::from(vec![7, -2, 9, -8]);
    let input = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
    let mask = BooleanArray::from(vec![
        false, false, false, false, false, false, false, false, false, false,
    ]);

    let output = filter(&input, &mask).unwrap();
    let actual: &RunArray<Int64Type> = as_run_array(&output);
    assert_eq!(0, actual.len());
}
1453
#[test]
fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() {
    // The run array describes 10 logical rows but the mask only covers the first 3:
    // [7, 7, -2] filtered with [false, true, true].
    let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
    let values = Int64Array::from(vec![7, -2, 9, -8]);
    let input = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
    let mask = BooleanArray::from(vec![false, true, true]);

    let output = filter(&input, &mask).unwrap();
    let actual: &RunArray<Int64Type> = as_run_array(&output);
    assert_eq!(2, actual.len());

    // Logical output: [7, -2]
    let expected_run_ends = Int64Array::from(vec![1, 2]);
    let expected_values = Int64Array::from(vec![7, -2]);
    let expected = RunArray::try_new(&expected_run_ends, &expected_values)
        .expect("Failed to make expected RunArray test is broken");

    assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1473
#[test]
fn test_filter_dictionary_array() {
    // Rows 1 (null) and 2 ("world") are selected.
    let rows = [Some("hello"), None, Some("world"), Some("!")];
    let input: Int8DictionaryArray = rows.iter().copied().collect();
    let mask = BooleanArray::from(vec![false, true, true, false]);

    let output = filter(&input, &mask).unwrap();
    let dict = output
        .as_ref()
        .as_any()
        .downcast_ref::<Int8DictionaryArray>()
        .unwrap();
    let dict_values = dict.values();
    let strings = dict_values
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();

    // values are cloned in the filtered dictionary array
    assert_eq!(3, strings.len());
    // but keys are filtered
    assert_eq!(2, dict.len());
    assert!(dict.is_null(0));
    let second_key = dict.keys().value(1) as usize;
    assert_eq!("world", strings.value(second_key));
}
1494
#[test]
fn test_filter_list_array() {
    let list_type = DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));

    // input = [[0, 1, 2], [3, 4, 5], [6, 7], null]
    let input_values = ArrayData::builder(DataType::Int32)
        .len(8)
        .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
        .build()
        .unwrap();
    let input_offsets = Buffer::from_slice_ref([0i64, 3, 6, 8, 8]);
    let input_data = ArrayData::builder(list_type.clone())
        .len(4)
        .add_buffer(input_offsets)
        .add_child_data(input_values)
        .null_bit_buffer(Some(Buffer::from([0b00000111])))
        .build()
        .unwrap();
    let input = LargeListArray::from(input_data);

    let mask = BooleanArray::from(vec![false, true, false, true]);
    let result = filter(&input, &mask).unwrap();

    // expected = [[3, 4, 5], null]
    let expected_values = ArrayData::builder(DataType::Int32)
        .len(3)
        .add_buffer(Buffer::from_slice_ref([3, 4, 5]))
        .build()
        .unwrap();
    let expected_offsets = Buffer::from_slice_ref([0i64, 3, 3]);
    let expected_data = ArrayData::builder(list_type)
        .len(2)
        .add_buffer(expected_offsets)
        .add_child_data(expected_values)
        .null_bit_buffer(Some(Buffer::from([0b00000001])))
        .build()
        .unwrap();

    assert_eq!(&make_array(expected_data), &result);
}
1541
1542    fn test_case_filter_list_view<T: OffsetSizeTrait>() {
1543        // [[1, 2], null, [], [3,4]]
1544        let mut list_array = GenericListViewBuilder::<T, _>::new(Int32Builder::new());
1545        list_array.append_value([Some(1), Some(2)]);
1546        list_array.append_null();
1547        list_array.append_value([]);
1548        list_array.append_value([Some(3), Some(4)]);
1549
1550        let list_array = list_array.finish();
1551        let predicate = BooleanArray::from_iter([true, false, true, false]);
1552
1553        // Filter result: [[1, 2], []]
1554        let filtered = filter(&list_array, &predicate)
1555            .unwrap()
1556            .as_list_view::<T>()
1557            .clone();
1558
1559        let mut expected =
1560            GenericListViewBuilder::<T, _>::with_capacity(Int32Builder::with_capacity(5), 3);
1561        expected.append_value([Some(1), Some(2)]);
1562        expected.append_value([]);
1563        let expected = expected.finish();
1564
1565        assert_eq!(&filtered, &expected);
1566    }
1567
1568    fn test_case_filter_sliced_list_view<T: OffsetSizeTrait>() {
1569        // [[1, 2], null, [], [3,4]]
1570        let mut list_array =
1571            GenericListViewBuilder::<T, _>::with_capacity(Int32Builder::with_capacity(6), 4);
1572        list_array.append_value([Some(1), Some(2)]);
1573        list_array.append_null();
1574        list_array.append_value([]);
1575        list_array.append_value([Some(3), Some(4)]);
1576
1577        let list_array = list_array.finish();
1578
1579        // Sliced: [null, [], [3, 4]]
1580        let sliced = list_array.slice(1, 3);
1581        let predicate = BooleanArray::from_iter([false, false, true]);
1582
1583        // Filter result: [[1, 2], []]
1584        let filtered = filter(&sliced, &predicate)
1585            .unwrap()
1586            .as_list_view::<T>()
1587            .clone();
1588
1589        let mut expected = GenericListViewBuilder::<T, _>::new(Int32Builder::new());
1590        expected.append_value([Some(3), Some(4)]);
1591        let expected = expected.finish();
1592
1593        assert_eq!(&filtered, &expected);
1594    }
1595
#[test]
fn test_filter_list_view_array() {
    // Unsliced input, for both offset widths.
    test_case_filter_list_view::<i32>();
    test_case_filter_list_view::<i64>();

    // Sliced input, for both offset widths.
    test_case_filter_sliced_list_view::<i32>();
    test_case_filter_sliced_list_view::<i64>();
}
1604
#[test]
fn test_slice_iterator_bits() {
    // 64 bits with only bit 1 set: a single one-row slice is expected.
    let bits: Vec<bool> = (0..64).map(|i| i == 1).collect();
    let mask = BooleanArray::from(bits);
    let selected = filter_count(&mask);

    let chunks: Vec<_> = SlicesIterator::new(&mask).collect();

    assert_eq!(chunks, vec![(1, 2)]);
    assert_eq!(selected, 1);
}
1617
1618    #[test]
1619    fn test_slice_iterator_bits1() {
1620        let filter_values = (0..64).map(|i| i != 1).collect::<Vec<bool>>();
1621        let filter = BooleanArray::from(filter_values);
1622        let filter_count = filter_count(&filter);
1623
1624        let iter = SlicesIterator::new(&filter);
1625        let chunks = iter.collect::<Vec<_>>();
1626
1627        assert_eq!(chunks, vec![(0, 1), (2, 64)]);
1628        assert_eq!(filter_count, 64 - 1);
1629    }
1630
1631    #[test]
1632    fn test_slice_iterator_chunk_and_bits() {
1633        let filter_values = (0..130).map(|i| i % 62 != 0).collect::<Vec<bool>>();
1634        let filter = BooleanArray::from(filter_values);
1635        let filter_count = filter_count(&filter);
1636
1637        let iter = SlicesIterator::new(&filter);
1638        let chunks = iter.collect::<Vec<_>>();
1639
1640        assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]);
1641        assert_eq!(filter_count, 61 + 61 + 5);
1642    }
1643
1644    #[test]
1645    fn test_null_mask() {
1646        let a = Int64Array::from(vec![Some(1), Some(2), None]);
1647
1648        let mask1 = BooleanArray::from(vec![Some(true), Some(true), None]);
1649        let out = filter(&a, &mask1).unwrap();
1650        assert_eq!(out.as_ref(), &a.slice(0, 2));
1651    }
1652
1653    #[test]
1654    fn test_filter_record_batch_no_columns() {
1655        let pred = BooleanArray::from(vec![Some(true), Some(true), None]);
1656        let options = RecordBatchOptions::default().with_row_count(Some(100));
1657        let record_batch =
1658            RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &options).unwrap();
1659        let out = filter_record_batch(&record_batch, &pred).unwrap();
1660
1661        assert_eq!(out.num_rows(), 2);
1662    }
1663
1664    #[test]
1665    fn test_fast_path() {
1666        let a: PrimitiveArray<Int64Type> = PrimitiveArray::from(vec![Some(1), Some(2), None]);
1667
1668        // all true
1669        let mask = BooleanArray::from(vec![true, true, true]);
1670        let out = filter(&a, &mask).unwrap();
1671        let b = out
1672            .as_any()
1673            .downcast_ref::<PrimitiveArray<Int64Type>>()
1674            .unwrap();
1675        assert_eq!(&a, b);
1676
1677        // all false
1678        let mask = BooleanArray::from(vec![false, false, false]);
1679        let out = filter(&a, &mask).unwrap();
1680        assert_eq!(out.len(), 0);
1681        assert_eq!(out.data_type(), &DataType::Int64);
1682    }
1683
1684    #[test]
1685    fn test_slices() {
1686        // takes up 2 u64s
1687        let bools = std::iter::repeat_n(true, 10)
1688            .chain(std::iter::repeat_n(false, 30))
1689            .chain(std::iter::repeat_n(true, 20))
1690            .chain(std::iter::repeat_n(false, 17))
1691            .chain(std::iter::repeat_n(true, 4));
1692
1693        let bool_array: BooleanArray = bools.map(Some).collect();
1694
1695        let slices: Vec<_> = SlicesIterator::new(&bool_array).collect();
1696        let expected = vec![(0, 10), (40, 60), (77, 81)];
1697        assert_eq!(slices, expected);
1698
1699        // slice with offset and truncated len
1700        let len = bool_array.len();
1701        let sliced_array = bool_array.slice(7, len - 10);
1702        let sliced_array = sliced_array
1703            .as_any()
1704            .downcast_ref::<BooleanArray>()
1705            .unwrap();
1706        let slices: Vec<_> = SlicesIterator::new(sliced_array).collect();
1707        let expected = vec![(0, 3), (33, 53), (70, 71)];
1708        assert_eq!(slices, expected);
1709    }
1710
1711    fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) {
1712        let mut rng = rng();
1713
1714        let bools: Vec<bool> = std::iter::from_fn(|| Some(rng.random()))
1715            .take(mask_len)
1716            .collect();
1717
1718        let buffer = Buffer::from_iter(bools.iter().cloned());
1719
1720        let truncated_length = mask_len - offset - truncate;
1721
1722        let data = ArrayDataBuilder::new(DataType::Boolean)
1723            .len(truncated_length)
1724            .offset(offset)
1725            .add_buffer(buffer)
1726            .build()
1727            .unwrap();
1728
1729        let filter = BooleanArray::from(data);
1730
1731        let slice_bits: Vec<_> = SlicesIterator::new(&filter)
1732            .flat_map(|(start, end)| start..end)
1733            .collect();
1734
1735        let count = filter_count(&filter);
1736        let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect();
1737
1738        let expected_bits: Vec<_> = bools
1739            .iter()
1740            .skip(offset)
1741            .take(truncated_length)
1742            .enumerate()
1743            .flat_map(|(idx, v)| v.then(|| idx))
1744            .collect();
1745
1746        assert_eq!(slice_bits, expected_bits);
1747        assert_eq!(index_bits, expected_bits);
1748    }
1749
1750    #[test]
1751    #[cfg_attr(miri, ignore)]
1752    fn fuzz_test_slices_iterator() {
1753        let mut rng = rng();
1754
1755        let uusize = UniformUsize::new(usize::MIN, usize::MAX).unwrap();
1756        for _ in 0..100 {
1757            let mask_len = rng.random_range(0..1024);
1758            let max_offset = 64.min(mask_len);
1759            let offset = uusize.sample(&mut rng).checked_rem(max_offset).unwrap_or(0);
1760
1761            let max_truncate = 128.min(mask_len - offset);
1762            let truncate = uusize
1763                .sample(&mut rng)
1764                .checked_rem(max_truncate)
1765                .unwrap_or(0);
1766
1767            test_slices_fuzz(mask_len, offset, truncate);
1768        }
1769
1770        test_slices_fuzz(64, 0, 0);
1771        test_slices_fuzz(64, 8, 0);
1772        test_slices_fuzz(64, 8, 8);
1773        test_slices_fuzz(32, 8, 8);
1774        test_slices_fuzz(32, 5, 9);
1775    }
1776
1777    /// Filters `values` by `predicate` using standard rust iterators
1778    fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> {
1779        values
1780            .into_iter()
1781            .zip(predicate)
1782            .filter(|(_, x)| **x)
1783            .map(|(a, _)| a)
1784            .collect()
1785    }
1786
1787    /// Generates an array of length `len` with `valid_percent` non-null values
1788    fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>>
1789    where
1790        StandardUniform: Distribution<T>,
1791    {
1792        let mut rng = rng();
1793        (0..len)
1794            .map(|_| rng.random_bool(valid_percent).then(|| rng.random()))
1795            .collect()
1796    }
1797
1798    /// Generates an array of length `len` with `valid_percent` non-null values
1799    fn gen_strings(
1800        len: usize,
1801        valid_percent: f64,
1802        str_len_range: std::ops::Range<usize>,
1803    ) -> Vec<Option<String>> {
1804        let mut rng = rng();
1805        (0..len)
1806            .map(|_| {
1807                rng.random_bool(valid_percent).then(|| {
1808                    let len = rng.random_range(str_len_range.clone());
1809                    (0..len)
1810                        .map(|_| char::from(rng.sample(Alphanumeric)))
1811                        .collect()
1812                })
1813            })
1814            .collect()
1815    }
1816
1817    /// Returns an iterator that calls `Option::as_deref` on each item
1818    fn as_deref<T: std::ops::Deref>(src: &[Option<T>]) -> impl Iterator<Item = Option<&T::Target>> {
1819        src.iter().map(|x| x.as_deref())
1820    }
1821
1822    #[test]
1823    #[cfg_attr(miri, ignore)]
1824    fn fuzz_filter() {
1825        let mut rng = rng();
1826
1827        for i in 0..100 {
1828            let filter_percent = match i {
1829                0..=4 => 1.,
1830                5..=10 => 0.,
1831                _ => rng.random_range(0.0..1.0),
1832            };
1833
1834            let valid_percent = rng.random_range(0.0..1.0);
1835
1836            let array_len = rng.random_range(32..256);
1837            let array_offset = rng.random_range(0..10);
1838
1839            // Construct a predicate
1840            let filter_offset = rng.random_range(0..10);
1841            let filter_truncate = rng.random_range(0..10);
1842            let bools: Vec<_> = std::iter::from_fn(|| Some(rng.random_bool(filter_percent)))
1843                .take(array_len + filter_offset - filter_truncate)
1844                .collect();
1845
1846            let predicate = BooleanArray::from_iter(bools.iter().cloned().map(Some));
1847
1848            // Offset predicate
1849            let predicate = predicate.slice(filter_offset, array_len - filter_truncate);
1850            let predicate = predicate.as_any().downcast_ref::<BooleanArray>().unwrap();
1851            let bools = &bools[filter_offset..];
1852
1853            // Test i32
1854            let values = gen_primitive(array_len + array_offset, valid_percent);
1855            let src = Int32Array::from_iter(values.iter().cloned());
1856
1857            let src = src.slice(array_offset, array_len);
1858            let src = src.as_any().downcast_ref::<Int32Array>().unwrap();
1859            let values = &values[array_offset..];
1860
1861            let filtered = filter(src, predicate).unwrap();
1862            let array = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
1863            let actual: Vec<_> = array.iter().collect();
1864
1865            assert_eq!(actual, filter_rust(values.iter().cloned(), bools));
1866
1867            // Test string
1868            let strings = gen_strings(array_len + array_offset, valid_percent, 0..20);
1869            let src = StringArray::from_iter(as_deref(&strings));
1870
1871            let src = src.slice(array_offset, array_len);
1872            let src = src.as_any().downcast_ref::<StringArray>().unwrap();
1873
1874            let filtered = filter(src, predicate).unwrap();
1875            let array = filtered.as_any().downcast_ref::<StringArray>().unwrap();
1876            let actual: Vec<_> = array.iter().collect();
1877
1878            let expected_strings = filter_rust(as_deref(&strings[array_offset..]), bools);
1879            assert_eq!(actual, expected_strings);
1880
1881            // Test string dictionary
1882            let src = DictionaryArray::<Int32Type>::from_iter(as_deref(&strings));
1883
1884            let src = src.slice(array_offset, array_len);
1885            let src = src
1886                .as_any()
1887                .downcast_ref::<DictionaryArray<Int32Type>>()
1888                .unwrap();
1889
1890            let filtered = filter(src, predicate).unwrap();
1891
1892            let array = filtered
1893                .as_any()
1894                .downcast_ref::<DictionaryArray<Int32Type>>()
1895                .unwrap();
1896
1897            let values = array
1898                .values()
1899                .as_any()
1900                .downcast_ref::<StringArray>()
1901                .unwrap();
1902
1903            let actual: Vec<_> = array
1904                .keys()
1905                .iter()
1906                .map(|key| key.map(|key| values.value(key as usize)))
1907                .collect();
1908
1909            assert_eq!(actual, expected_strings);
1910        }
1911    }
1912
1913    #[test]
1914    fn test_filter_map() {
1915        let mut builder =
1916            MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(4));
1917        // [{"key1": 1}, {"key2": 2, "key3": 3}, null, {"key1": 1}
1918        builder.keys().append_value("key1");
1919        builder.values().append_value(1);
1920        builder.append(true).unwrap();
1921        builder.keys().append_value("key2");
1922        builder.keys().append_value("key3");
1923        builder.values().append_value(2);
1924        builder.values().append_value(3);
1925        builder.append(true).unwrap();
1926        builder.append(false).unwrap();
1927        builder.keys().append_value("key1");
1928        builder.values().append_value(1);
1929        builder.append(true).unwrap();
1930        let maparray = Arc::new(builder.finish()) as ArrayRef;
1931
1932        let indices = vec![Some(true), Some(false), Some(false), Some(true)]
1933            .into_iter()
1934            .collect::<BooleanArray>();
1935        let got = filter(&maparray, &indices).unwrap();
1936
1937        let mut builder =
1938            MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(2));
1939        builder.keys().append_value("key1");
1940        builder.values().append_value(1);
1941        builder.append(true).unwrap();
1942        builder.keys().append_value("key1");
1943        builder.values().append_value(1);
1944        builder.append(true).unwrap();
1945        let expected = Arc::new(builder.finish()) as ArrayRef;
1946
1947        assert_eq!(&expected, &got);
1948    }
1949
1950    #[test]
1951    fn test_filter_fixed_size_list_arrays() {
1952        let value_data = ArrayData::builder(DataType::Int32)
1953            .len(9)
1954            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8]))
1955            .build()
1956            .unwrap();
1957        let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 3, false);
1958        let list_data = ArrayData::builder(list_data_type)
1959            .len(3)
1960            .add_child_data(value_data)
1961            .build()
1962            .unwrap();
1963        let array = FixedSizeListArray::from(list_data);
1964
1965        let filter_array = BooleanArray::from(vec![true, false, false]);
1966
1967        let c = filter(&array, &filter_array).unwrap();
1968        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1969
1970        assert_eq!(filtered.len(), 1);
1971
1972        let list = filtered.value(0);
1973        assert_eq!(
1974            &[0, 1, 2],
1975            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1976        );
1977
1978        let filter_array = BooleanArray::from(vec![true, false, true]);
1979
1980        let c = filter(&array, &filter_array).unwrap();
1981        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1982
1983        assert_eq!(filtered.len(), 2);
1984
1985        let list = filtered.value(0);
1986        assert_eq!(
1987            &[0, 1, 2],
1988            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1989        );
1990        let list = filtered.value(1);
1991        assert_eq!(
1992            &[6, 7, 8],
1993            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1994        );
1995    }
1996
1997    #[test]
1998    fn test_filter_fixed_size_list_arrays_with_null() {
1999        let value_data = ArrayData::builder(DataType::Int32)
2000            .len(10)
2001            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
2002            .build()
2003            .unwrap();
2004
2005        // Set null buts for the nested array:
2006        //  [[0, 1], null, null, [6, 7], [8, 9]]
2007        // 01011001 00000001
2008        let mut null_bits: [u8; 1] = [0; 1];
2009        bit_util::set_bit(&mut null_bits, 0);
2010        bit_util::set_bit(&mut null_bits, 3);
2011        bit_util::set_bit(&mut null_bits, 4);
2012
2013        let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 2, false);
2014        let list_data = ArrayData::builder(list_data_type)
2015            .len(5)
2016            .add_child_data(value_data)
2017            .null_bit_buffer(Some(Buffer::from(null_bits)))
2018            .build()
2019            .unwrap();
2020        let array = FixedSizeListArray::from(list_data);
2021
2022        let filter_array = BooleanArray::from(vec![true, true, false, true, false]);
2023
2024        let c = filter(&array, &filter_array).unwrap();
2025        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
2026
2027        assert_eq!(filtered.len(), 3);
2028
2029        let list = filtered.value(0);
2030        assert_eq!(
2031            &[0, 1],
2032            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
2033        );
2034        assert!(filtered.is_null(1));
2035        let list = filtered.value(2);
2036        assert_eq!(
2037            &[6, 7],
2038            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
2039        );
2040    }
2041
2042    fn test_filter_union_array(array: UnionArray) {
2043        let filter_array = BooleanArray::from(vec![true, false, false]);
2044        let c = filter(&array, &filter_array).unwrap();
2045        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2046
2047        let mut builder = UnionBuilder::new_dense();
2048        builder.append::<Int32Type>("A", 1).unwrap();
2049        let expected_array = builder.build().unwrap();
2050
2051        compare_union_arrays(filtered, &expected_array);
2052
2053        let filter_array = BooleanArray::from(vec![true, false, true]);
2054        let c = filter(&array, &filter_array).unwrap();
2055        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2056
2057        let mut builder = UnionBuilder::new_dense();
2058        builder.append::<Int32Type>("A", 1).unwrap();
2059        builder.append::<Int32Type>("A", 34).unwrap();
2060        let expected_array = builder.build().unwrap();
2061
2062        compare_union_arrays(filtered, &expected_array);
2063
2064        let filter_array = BooleanArray::from(vec![true, true, false]);
2065        let c = filter(&array, &filter_array).unwrap();
2066        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2067
2068        let mut builder = UnionBuilder::new_dense();
2069        builder.append::<Int32Type>("A", 1).unwrap();
2070        builder.append::<Float64Type>("B", 3.2).unwrap();
2071        let expected_array = builder.build().unwrap();
2072
2073        compare_union_arrays(filtered, &expected_array);
2074    }
2075
2076    #[test]
2077    fn test_filter_union_array_dense() {
2078        let mut builder = UnionBuilder::new_dense();
2079        builder.append::<Int32Type>("A", 1).unwrap();
2080        builder.append::<Float64Type>("B", 3.2).unwrap();
2081        builder.append::<Int32Type>("A", 34).unwrap();
2082        let array = builder.build().unwrap();
2083
2084        test_filter_union_array(array);
2085    }
2086
2087    #[test]
2088    fn test_filter_run_union_array_dense() {
2089        let mut builder = UnionBuilder::new_dense();
2090        builder.append::<Int32Type>("A", 1).unwrap();
2091        builder.append::<Int32Type>("A", 3).unwrap();
2092        builder.append::<Int32Type>("A", 34).unwrap();
2093        let array = builder.build().unwrap();
2094
2095        let filter_array = BooleanArray::from(vec![true, true, false]);
2096        let c = filter(&array, &filter_array).unwrap();
2097        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2098
2099        let mut builder = UnionBuilder::new_dense();
2100        builder.append::<Int32Type>("A", 1).unwrap();
2101        builder.append::<Int32Type>("A", 3).unwrap();
2102        let expected = builder.build().unwrap();
2103
2104        assert_eq!(filtered.to_data(), expected.to_data());
2105    }
2106
2107    #[test]
2108    fn test_filter_union_array_dense_with_nulls() {
2109        let mut builder = UnionBuilder::new_dense();
2110        builder.append::<Int32Type>("A", 1).unwrap();
2111        builder.append::<Float64Type>("B", 3.2).unwrap();
2112        builder.append_null::<Float64Type>("B").unwrap();
2113        builder.append::<Int32Type>("A", 34).unwrap();
2114        let array = builder.build().unwrap();
2115
2116        let filter_array = BooleanArray::from(vec![true, true, false, false]);
2117        let c = filter(&array, &filter_array).unwrap();
2118        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2119
2120        let mut builder = UnionBuilder::new_dense();
2121        builder.append::<Int32Type>("A", 1).unwrap();
2122        builder.append::<Float64Type>("B", 3.2).unwrap();
2123        let expected_array = builder.build().unwrap();
2124
2125        compare_union_arrays(filtered, &expected_array);
2126
2127        let filter_array = BooleanArray::from(vec![true, false, true, false]);
2128        let c = filter(&array, &filter_array).unwrap();
2129        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2130
2131        let mut builder = UnionBuilder::new_dense();
2132        builder.append::<Int32Type>("A", 1).unwrap();
2133        builder.append_null::<Float64Type>("B").unwrap();
2134        let expected_array = builder.build().unwrap();
2135
2136        compare_union_arrays(filtered, &expected_array);
2137    }
2138
2139    #[test]
2140    fn test_filter_union_array_sparse() {
2141        let mut builder = UnionBuilder::new_sparse();
2142        builder.append::<Int32Type>("A", 1).unwrap();
2143        builder.append::<Float64Type>("B", 3.2).unwrap();
2144        builder.append::<Int32Type>("A", 34).unwrap();
2145        let array = builder.build().unwrap();
2146
2147        test_filter_union_array(array);
2148    }
2149
2150    #[test]
2151    fn test_filter_union_array_sparse_with_nulls() {
2152        let mut builder = UnionBuilder::new_sparse();
2153        builder.append::<Int32Type>("A", 1).unwrap();
2154        builder.append::<Float64Type>("B", 3.2).unwrap();
2155        builder.append_null::<Float64Type>("B").unwrap();
2156        builder.append::<Int32Type>("A", 34).unwrap();
2157        let array = builder.build().unwrap();
2158
2159        let filter_array = BooleanArray::from(vec![true, false, true, false]);
2160        let c = filter(&array, &filter_array).unwrap();
2161        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2162
2163        let mut builder = UnionBuilder::new_sparse();
2164        builder.append::<Int32Type>("A", 1).unwrap();
2165        builder.append_null::<Float64Type>("B").unwrap();
2166        let expected_array = builder.build().unwrap();
2167
2168        compare_union_arrays(filtered, &expected_array);
2169    }
2170
2171    fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) {
2172        assert_eq!(union1.len(), union2.len());
2173
2174        for i in 0..union1.len() {
2175            let type_id = union1.type_id(i);
2176
2177            let slot1 = union1.value(i);
2178            let slot2 = union2.value(i);
2179
2180            assert_eq!(slot1.is_null(0), slot2.is_null(0));
2181
2182            if !slot1.is_null(0) && !slot2.is_null(0) {
2183                match type_id {
2184                    0 => {
2185                        let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
2186                        assert_eq!(slot1.len(), 1);
2187                        let value1 = slot1.value(0);
2188
2189                        let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap();
2190                        assert_eq!(slot2.len(), 1);
2191                        let value2 = slot2.value(0);
2192                        assert_eq!(value1, value2);
2193                    }
2194                    1 => {
2195                        let slot1 = slot1.as_any().downcast_ref::<Float64Array>().unwrap();
2196                        assert_eq!(slot1.len(), 1);
2197                        let value1 = slot1.value(0);
2198
2199                        let slot2 = slot2.as_any().downcast_ref::<Float64Array>().unwrap();
2200                        assert_eq!(slot2.len(), 1);
2201                        let value2 = slot2.value(0);
2202                        assert_eq!(value1, value2);
2203                    }
2204                    _ => unreachable!(),
2205                }
2206            }
2207        }
2208    }
2209
2210    #[test]
2211    fn test_filter_struct() {
2212        let predicate = BooleanArray::from(vec![true, false, true, false]);
2213
2214        let a = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
2215        let a_filtered = Arc::new(StringArray::from(vec!["hello", "world"]));
2216
2217        let b = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
2218        let b_filtered = Arc::new(Int32Array::from(vec![5, 7]));
2219
2220        let null_mask = NullBuffer::from(vec![true, false, false, true]);
2221        let null_mask_filtered = NullBuffer::from(vec![true, false]);
2222
2223        let a_field = Field::new("a", DataType::Utf8, false);
2224        let b_field = Field::new("b", DataType::Int32, false);
2225
2226        let array = StructArray::new(vec![a_field.clone()].into(), vec![a.clone()], None);
2227        let expected =
2228            StructArray::new(vec![a_field.clone()].into(), vec![a_filtered.clone()], None);
2229
2230        let result = filter(&array, &predicate).unwrap();
2231
2232        assert_eq!(result.to_data(), expected.to_data());
2233
2234        let array = StructArray::new(
2235            vec![a_field.clone()].into(),
2236            vec![a.clone()],
2237            Some(null_mask.clone()),
2238        );
2239        let expected = StructArray::new(
2240            vec![a_field.clone()].into(),
2241            vec![a_filtered.clone()],
2242            Some(null_mask_filtered.clone()),
2243        );
2244
2245        let result = filter(&array, &predicate).unwrap();
2246
2247        assert_eq!(result.to_data(), expected.to_data());
2248
2249        let array = StructArray::new(
2250            vec![a_field.clone(), b_field.clone()].into(),
2251            vec![a.clone(), b.clone()],
2252            None,
2253        );
2254        let expected = StructArray::new(
2255            vec![a_field.clone(), b_field.clone()].into(),
2256            vec![a_filtered.clone(), b_filtered.clone()],
2257            None,
2258        );
2259
2260        let result = filter(&array, &predicate).unwrap();
2261
2262        assert_eq!(result.to_data(), expected.to_data());
2263
2264        let array = StructArray::new(
2265            vec![a_field.clone(), b_field.clone()].into(),
2266            vec![a.clone(), b.clone()],
2267            Some(null_mask.clone()),
2268        );
2269
2270        let expected = StructArray::new(
2271            vec![a_field.clone(), b_field.clone()].into(),
2272            vec![a_filtered.clone(), b_filtered.clone()],
2273            Some(null_mask_filtered.clone()),
2274        );
2275
2276        let result = filter(&array, &predicate).unwrap();
2277
2278        assert_eq!(result.to_data(), expected.to_data());
2279    }
2280
2281    #[test]
2282    fn test_filter_empty_struct() {
2283        /*
2284            "a": {
2285                "b": int64,
2286                "c": {}
2287            },
2288        */
2289        let fields = arrow_schema::Field::new(
2290            "a",
2291            arrow_schema::DataType::Struct(arrow_schema::Fields::from(vec![
2292                arrow_schema::Field::new("b", arrow_schema::DataType::Int64, true),
2293                arrow_schema::Field::new(
2294                    "c",
2295                    arrow_schema::DataType::Struct(arrow_schema::Fields::empty()),
2296                    true,
2297                ),
2298            ])),
2299            true,
2300        );
2301
2302        /* Test record
2303            {"a":{"c": {}}}
2304            {"a":{"c": {}}}
2305            {"a":{"c": {}}}
2306        */
2307
2308        // Create the record batch with the nested struct array
2309        let schema = Arc::new(Schema::new(vec![fields]));
2310
2311        let b = Arc::new(Int64Array::from(vec![None, None, None]));
2312        let c = Arc::new(StructArray::new_empty_fields(
2313            3,
2314            Some(NullBuffer::from(vec![true, true, true])),
2315        ));
2316        let a = StructArray::new(
2317            vec![
2318                Field::new("b", DataType::Int64, true),
2319                Field::new("c", DataType::Struct(Fields::empty()), true),
2320            ]
2321            .into(),
2322            vec![b.clone(), c.clone()],
2323            Some(NullBuffer::from(vec![true, true, true])),
2324        );
2325        let record_batch = RecordBatch::try_new(schema, vec![Arc::new(a)]).unwrap();
2326        println!("{record_batch:?}");
2327
2328        // Apply the filter
2329        let predicate = BooleanArray::from(vec![true, false, true]);
2330        let filtered_batch = filter_record_batch(&record_batch, &predicate).unwrap();
2331
2332        // The filtered batch should have 2 rows (the 1st and 3rd)
2333        assert_eq!(filtered_batch.num_rows(), 2);
2334    }
2335
2336    #[test]
2337    #[should_panic]
2338    fn test_filter_bits_too_large() {
2339        let buffer = BooleanBuffer::from(vec![false; 8]);
2340        let predicate = BooleanArray::from(vec![true; 9]);
2341        let filter = FilterBuilder::new(&predicate).build();
2342        filter_bits(&buffer, &filter);
2343    }
2344
2345    #[test]
2346    #[should_panic]
2347    fn test_filter_native_too_large() {
2348        let values = vec![1; 8];
2349        let predicate = BooleanArray::from(vec![false; 9]);
2350        let filter = FilterBuilder::new(&predicate).build();
2351        filter_native(&values, &filter);
2352    }
2353}