arrow_select/
filter.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines filter kernels
19
20use std::ops::AddAssign;
21use std::sync::Arc;
22
23use arrow_array::builder::BooleanBufferBuilder;
24use arrow_array::cast::AsArray;
25use arrow_array::types::{
26    ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
27};
28use arrow_array::*;
29use arrow_buffer::{ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer, bit_util};
30use arrow_buffer::{Buffer, MutableBuffer};
31use arrow_data::ArrayDataBuilder;
32use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
33use arrow_data::transform::MutableArrayData;
34use arrow_schema::*;
35
/// Selectivity cut-over between the two lazy iteration strategies.
///
/// When the fraction of rows selected by a filter is greater than this value,
/// ranges of values are copied using [`SlicesIterator`]; otherwise the selected
/// rows are visited one at a time using [`IndexIterator`].
///
/// Threshold of 0.8 chosen based on <https://dl.acm.org/doi/abs/10.1145/3465998.3466009>
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
43
44/// An iterator of `(usize, usize)` each representing an interval
45/// `[start, end)` whose slots of a bitmap [Buffer] are true.
46///
47/// Each interval corresponds to a contiguous region of memory to be
48/// "taken" from an array to be filtered.
49///
50/// ## Notes:
51///
52/// 1. Ignores the validity bitmap (ignores nulls)
53///
54/// 2. Only performant for filters that copy across long contiguous runs
55#[derive(Debug)]
56pub struct SlicesIterator<'a>(BitSliceIterator<'a>);
57
58impl<'a> SlicesIterator<'a> {
59    /// Creates a new iterator from a [BooleanArray]
60    pub fn new(filter: &'a BooleanArray) -> Self {
61        filter.values().into()
62    }
63}
64
65impl<'a> From<&'a BooleanBuffer> for SlicesIterator<'a> {
66    fn from(filter: &'a BooleanBuffer) -> Self {
67        Self(filter.set_slices())
68    }
69}
70
71impl Iterator for SlicesIterator<'_> {
72    type Item = (usize, usize);
73
74    fn next(&mut self) -> Option<Self::Item> {
75        self.0.next()
76    }
77}
78
79/// An iterator of `usize` whose index in [`BooleanArray`] is true
80///
81/// This provides the best performance on most predicates, apart from those which keep
82/// large runs and therefore favour [`SlicesIterator`]
83struct IndexIterator<'a> {
84    remaining: usize,
85    iter: BitIndexIterator<'a>,
86}
87
88impl<'a> IndexIterator<'a> {
89    fn new(filter: &'a BooleanArray, remaining: usize) -> Self {
90        assert_eq!(filter.null_count(), 0);
91        let iter = filter.values().set_indices();
92        Self { remaining, iter }
93    }
94}
95
96impl Iterator for IndexIterator<'_> {
97    type Item = usize;
98
99    fn next(&mut self) -> Option<Self::Item> {
100        if self.remaining != 0 {
101            // Fascinatingly swapping these two lines around results in a 50%
102            // performance regression for some benchmarks
103            let next = self.iter.next().expect("IndexIterator exhausted early");
104            self.remaining -= 1;
105            // Must panic if exhausted early as trusted length iterator
106            return Some(next);
107        }
108        None
109    }
110
111    fn size_hint(&self) -> (usize, Option<usize>) {
112        (self.remaining, Some(self.remaining))
113    }
114}
115
116/// Counts the number of set bits in `filter`
117fn filter_count(filter: &BooleanArray) -> usize {
118    filter.values().count_set_bits()
119}
120
121/// Remove null values by do a bitmask AND operation with null bits and the boolean bits.
122pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
123    let nulls = filter.nulls().unwrap();
124    let mask = filter.values() & nulls.inner();
125    BooleanArray::new(mask, None)
126}
127
128/// Returns a filtered `values` [`Array`] where the corresponding elements of
129/// `predicate` are `true`.
130///
131/// If multiple arrays (or record batches) need to be filtered using the same predicate array,
132/// consider using [FilterBuilder] to create a single [FilterPredicate] and then
133/// calling [FilterPredicate::filter_record_batch].
134///
135/// In contrast to this function, it is then the responsibility of the caller
136/// to use [FilterBuilder::optimize] if appropriate.
137///
138/// # See also
139/// * [`FilterBuilder`] for more control over the filtering process.
140/// * [`filter_record_batch`] to filter a [`RecordBatch`]
141/// * [`BatchCoalescer`]: to filter multiple [`RecordBatch`] and coalesce
142///   the results into a single array.
143///
144/// [`BatchCoalescer`]: crate::coalesce::BatchCoalescer
145///
146/// # Example
147/// ```rust
148/// # use arrow_array::{Int32Array, BooleanArray};
149/// # use arrow_select::filter::filter;
150/// let array = Int32Array::from(vec![5, 6, 7, 8, 9]);
151/// let filter_array = BooleanArray::from(vec![true, false, false, true, false]);
152/// let c = filter(&array, &filter_array).unwrap();
153/// let c = c.as_any().downcast_ref::<Int32Array>().unwrap();
154/// assert_eq!(c, &Int32Array::from(vec![5, 8]));
155/// ```
156pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
157    let mut filter_builder = FilterBuilder::new(predicate);
158
159    if FilterBuilder::is_optimize_beneficial(values.data_type()) {
160        // Only optimize if filtering more than one array
161        // Otherwise, the overhead of optimization can be more than the benefit
162        filter_builder = filter_builder.optimize();
163    }
164
165    let predicate = filter_builder.build();
166
167    filter_array(values, &predicate)
168}
169
170/// Returns a filtered [RecordBatch] where the corresponding elements of
171/// `predicate` are true.
172///
173/// This is the equivalent of calling [filter] on each column of the [RecordBatch].
174///
175/// If multiple record batches (or arrays) need to be filtered using the same predicate array,
176/// consider using [FilterBuilder] to create a single [FilterPredicate] and then
177/// calling [FilterPredicate::filter_record_batch].
178/// In contrast to this function, it is then the responsibility of the caller
179/// to use [FilterBuilder::optimize] if appropriate.
180pub fn filter_record_batch(
181    record_batch: &RecordBatch,
182    predicate: &BooleanArray,
183) -> Result<RecordBatch, ArrowError> {
184    let mut filter_builder = FilterBuilder::new(predicate);
185    let num_cols = record_batch.num_columns();
186    if num_cols > 1
187        || (num_cols > 0
188            && FilterBuilder::is_optimize_beneficial(
189                record_batch.schema_ref().field(0).data_type(),
190            ))
191    {
192        // Only optimize if filtering more than one column or if the column contains multiple internal arrays
193        // Otherwise, the overhead of optimization can be more than the benefit
194        filter_builder = filter_builder.optimize();
195    }
196    let filter = filter_builder.build();
197
198    filter.filter_record_batch(record_batch)
199}
200
201/// A builder to construct [`FilterPredicate`]
202#[derive(Debug)]
203pub struct FilterBuilder {
204    filter: BooleanArray,
205    count: usize,
206    strategy: IterationStrategy,
207}
208
209impl FilterBuilder {
210    /// Create a new [`FilterBuilder`] that can be used to construct a [`FilterPredicate`]
211    pub fn new(filter: &BooleanArray) -> Self {
212        let filter = match filter.null_count() {
213            0 => filter.clone(),
214            _ => prep_null_mask_filter(filter),
215        };
216
217        let count = filter_count(&filter);
218        let strategy = IterationStrategy::default_strategy(filter.len(), count);
219
220        Self {
221            filter,
222            count,
223            strategy,
224        }
225    }
226
227    /// Compute an optimized representation of the provided `filter` mask that can be
228    /// applied to an array more quickly.
229    ///
230    /// When filtering multiple arrays (e.g. a [`RecordBatch`] or a
231    /// [`StructArray`] with multiple fields), optimizing the filter can provide
232    /// significant performance benefits.
233    ///
234    /// However, optimization takes time and can have a larger memory footprint
235    /// than the original mask, so it is often faster to filter a single array,
236    /// without filter optimization.
237    pub fn optimize(mut self) -> Self {
238        match self.strategy {
239            IterationStrategy::SlicesIterator => {
240                let slices = SlicesIterator::new(&self.filter).collect();
241                self.strategy = IterationStrategy::Slices(slices)
242            }
243            IterationStrategy::IndexIterator => {
244                let indices = IndexIterator::new(&self.filter, self.count).collect();
245                self.strategy = IterationStrategy::Indices(indices)
246            }
247            _ => {}
248        }
249        self
250    }
251
252    /// Determines if calling [FilterBuilder::optimize] is beneficial for the
253    /// given type even when filtering just a single array.
254    ///
255    /// See [`FilterBuilder::optimize`] for more details.
256    pub fn is_optimize_beneficial(data_type: &DataType) -> bool {
257        match data_type {
258            DataType::Struct(fields) => {
259                fields.len() > 1
260                    || fields.len() == 1
261                        && FilterBuilder::is_optimize_beneficial(fields[0].data_type())
262            }
263            DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
264            _ => false,
265        }
266    }
267
268    /// Construct the final `FilterPredicate`
269    pub fn build(self) -> FilterPredicate {
270        FilterPredicate {
271            filter: self.filter,
272            count: self.count,
273            strategy: self.strategy,
274        }
275    }
276}
277
278/// The iteration strategy used to evaluate [`FilterPredicate`]
279#[derive(Debug)]
280enum IterationStrategy {
281    /// A lazily evaluated iterator of ranges
282    SlicesIterator,
283    /// A lazily evaluated iterator of indices
284    IndexIterator,
285    /// A precomputed list of indices
286    Indices(Vec<usize>),
287    /// A precomputed array of ranges
288    Slices(Vec<(usize, usize)>),
289    /// Select all rows
290    All,
291    /// Select no rows
292    None,
293}
294
295impl IterationStrategy {
296    /// The default [`IterationStrategy`] for a filter of length `filter_length`
297    /// and selecting `filter_count` rows
298    fn default_strategy(filter_length: usize, filter_count: usize) -> Self {
299        if filter_length == 0 || filter_count == 0 {
300            return IterationStrategy::None;
301        }
302
303        if filter_count == filter_length {
304            return IterationStrategy::All;
305        }
306
307        // Compute the selectivity of the predicate by dividing the number of true
308        // bits in the predicate by the predicate's total length
309        //
310        // This can then be used as a heuristic for the optimal iteration strategy
311        let selectivity_frac = filter_count as f64 / filter_length as f64;
312        if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD {
313            return IterationStrategy::SlicesIterator;
314        }
315        IterationStrategy::IndexIterator
316    }
317}
318
319/// A filtering predicate that can be applied to an [`Array`]
320#[derive(Debug)]
321pub struct FilterPredicate {
322    filter: BooleanArray,
323    count: usize,
324    strategy: IterationStrategy,
325}
326
327impl FilterPredicate {
328    /// Selects rows from `values` based on this [`FilterPredicate`]
329    pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef, ArrowError> {
330        filter_array(values, self)
331    }
332
333    /// Returns a filtered [`RecordBatch`] containing only the rows that are selected by this
334    /// [`FilterPredicate`].
335    ///
336    /// This is the equivalent of calling [filter] on each column of the [`RecordBatch`].
337    pub fn filter_record_batch(
338        &self,
339        record_batch: &RecordBatch,
340    ) -> Result<RecordBatch, ArrowError> {
341        let filtered_arrays = record_batch
342            .columns()
343            .iter()
344            .map(|a| filter_array(a, self))
345            .collect::<Result<Vec<_>, _>>()?;
346
347        // SAFETY: we know that the set of filtered arrays will match the schema of the original
348        // record batch
349        unsafe {
350            Ok(RecordBatch::new_unchecked(
351                record_batch.schema(),
352                filtered_arrays,
353                self.count,
354            ))
355        }
356    }
357
358    /// Number of rows being selected based on this [`FilterPredicate`]
359    pub fn count(&self) -> usize {
360        self.count
361    }
362}
363
364fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> {
365    if predicate.filter.len() > values.len() {
366        return Err(ArrowError::InvalidArgumentError(format!(
367            "Filter predicate of length {} is larger than target array of length {}",
368            predicate.filter.len(),
369            values.len()
370        )));
371    }
372
373    match predicate.strategy {
374        IterationStrategy::None => Ok(new_empty_array(values.data_type())),
375        IterationStrategy::All => Ok(values.slice(0, predicate.count)),
376        // actually filter
377        _ => downcast_primitive_array! {
378            values => Ok(Arc::new(filter_primitive(values, predicate))),
379            DataType::Boolean => {
380                let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
381                Ok(Arc::new(filter_boolean(values, predicate)))
382            }
383            DataType::Utf8 => {
384                Ok(Arc::new(filter_bytes(values.as_string::<i32>(), predicate)))
385            }
386            DataType::LargeUtf8 => {
387                Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
388            }
389            DataType::Utf8View => {
390                Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
391            }
392            DataType::Binary => {
393                Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
394            }
395            DataType::LargeBinary => {
396                Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
397            }
398            DataType::BinaryView => {
399                Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
400            }
401            DataType::FixedSizeBinary(_) => {
402                Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate)))
403            }
404            DataType::ListView(_) => {
405                Ok(Arc::new(filter_list_view::<i32>(values.as_list_view(), predicate)))
406            }
407            DataType::LargeListView(_) => {
408                Ok(Arc::new(filter_list_view::<i64>(values.as_list_view(), predicate)))
409            }
410            DataType::RunEndEncoded(_, _) => {
411                downcast_run_array!{
412                    values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
413                    t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t)
414                }
415            }
416            DataType::Dictionary(_, _) => downcast_dictionary_array! {
417                values => Ok(Arc::new(filter_dict(values, predicate))),
418                t => unimplemented!("Filter not supported for dictionary type {:?}", t)
419            }
420            DataType::Struct(_) => {
421                Ok(Arc::new(filter_struct(values.as_struct(), predicate)?))
422            }
423            DataType::Union(_, UnionMode::Sparse) => {
424                Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?))
425            }
426            _ => {
427                let data = values.to_data();
428                // fallback to using MutableArrayData
429                let mut mutable = MutableArrayData::new(
430                    vec![&data],
431                    false,
432                    predicate.count,
433                );
434
435                match &predicate.strategy {
436                    IterationStrategy::Slices(slices) => {
437                        slices
438                            .iter()
439                            .for_each(|(start, end)| mutable.extend(0, *start, *end));
440                    }
441                    _ => {
442                        let iter = SlicesIterator::new(&predicate.filter);
443                        iter.for_each(|(start, end)| mutable.extend(0, start, end));
444                    }
445                }
446
447                let data = mutable.freeze();
448                Ok(make_array(data))
449            }
450        },
451    }
452}
453
454/// Filter any supported [`RunArray`] based on a [`FilterPredicate`]
455fn filter_run_end_array<R: RunEndIndexType>(
456    array: &RunArray<R>,
457    predicate: &FilterPredicate,
458) -> Result<RunArray<R>, ArrowError>
459where
460    R::Native: Into<i64> + From<bool>,
461    R::Native: AddAssign,
462{
463    let run_ends: &RunEndBuffer<R::Native> = array.run_ends();
464    let mut new_run_ends = vec![R::default_value(); run_ends.len()];
465
466    let mut start = 0u64;
467    let mut j = 0;
468    let mut count = R::default_value();
469    let filter_values = predicate.filter.values();
470    let run_ends = run_ends.inner();
471
472    let pred: BooleanArray = BooleanBuffer::collect_bool(run_ends.len(), |i| {
473        let mut keep = false;
474        let mut end = run_ends[i].into() as u64;
475        let difference = end.saturating_sub(filter_values.len() as u64);
476        end -= difference;
477
478        // Safety: we subtract the difference off `end` so we are always within bounds
479        for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
480            count += R::Native::from(pred);
481            keep |= pred
482        }
483        // this is to avoid branching
484        new_run_ends[j] = count;
485        j += keep as usize;
486
487        start = end;
488        keep
489    })
490    .into();
491
492    new_run_ends.truncate(j);
493
494    let values = array.values();
495    let values = filter(&values, &pred)?;
496
497    let run_ends = PrimitiveArray::<R>::try_new(new_run_ends.into(), None)?;
498    RunArray::try_new(&run_ends, &values)
499}
500
501/// Computes a new null mask for `data` based on `predicate`
502///
503/// If the predicate selected no null-rows, returns `None`, otherwise returns
504/// `Some((null_count, null_buffer))` where `null_count` is the number of nulls
505/// in the filtered output, and `null_buffer` is the filtered null buffer
506///
507fn filter_null_mask(
508    nulls: Option<&NullBuffer>,
509    predicate: &FilterPredicate,
510) -> Option<(usize, Buffer)> {
511    let nulls = nulls?;
512    if nulls.null_count() == 0 {
513        return None;
514    }
515
516    let nulls = filter_bits(nulls.inner(), predicate);
517    // The filtered `nulls` has a length of `predicate.count` bits and
518    // therefore the null count is this minus the number of valid bits
519    let null_count = predicate.count - nulls.count_set_bits_offset(0, predicate.count);
520
521    if null_count == 0 {
522        return None;
523    }
524
525    Some((null_count, nulls))
526}
527
528/// Filter the packed bitmask `buffer`, with `predicate` starting at bit offset `offset`
529fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer {
530    let src = buffer.values();
531    let offset = buffer.offset();
532    assert!(buffer.len() >= predicate.filter.len());
533
534    match &predicate.strategy {
535        IterationStrategy::IndexIterator => {
536            let bits =
537                // SAFETY: IndexIterator uses the filter predicate to derive indices
538                IndexIterator::new(&predicate.filter, predicate.count).map(|src_idx| unsafe {
539                    bit_util::get_bit_raw(buffer.values().as_ptr(), src_idx + offset)
540                });
541
542            // SAFETY: `IndexIterator` reports its size correctly
543            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
544        }
545        IterationStrategy::Indices(indices) => {
546            // SAFETY: indices were derived from the filter predicate
547            let bits = indices.iter().map(|src_idx| unsafe {
548                bit_util::get_bit_raw(buffer.values().as_ptr(), *src_idx + offset)
549            });
550            // SAFETY: `Vec::iter()` reports its size correctly
551            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
552        }
553        IterationStrategy::SlicesIterator => {
554            let mut builder = BooleanBufferBuilder::new(predicate.count);
555            for (start, end) in SlicesIterator::new(&predicate.filter) {
556                builder.append_packed_range(start + offset..end + offset, src)
557            }
558            builder.into()
559        }
560        IterationStrategy::Slices(slices) => {
561            let mut builder = BooleanBufferBuilder::new(predicate.count);
562            for (start, end) in slices {
563                builder.append_packed_range(*start + offset..*end + offset, src)
564            }
565            builder.into()
566        }
567        IterationStrategy::All | IterationStrategy::None => unreachable!(),
568    }
569}
570
571/// `filter` implementation for boolean buffers
572fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
573    let values = filter_bits(array.values(), predicate);
574
575    let mut builder = ArrayDataBuilder::new(DataType::Boolean)
576        .len(predicate.count)
577        .add_buffer(values);
578
579    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
580        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
581    }
582
583    let data = unsafe { builder.build_unchecked() };
584    BooleanArray::from(data)
585}
586
587#[inline(never)]
588fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer {
589    assert!(values.len() >= predicate.filter.len());
590
591    match &predicate.strategy {
592        IterationStrategy::SlicesIterator => {
593            let mut buffer = Vec::with_capacity(predicate.count);
594            for (start, end) in SlicesIterator::new(&predicate.filter) {
595                // SAFETY: indices were derived from the filter predicate
596                buffer.extend_from_slice(unsafe { values.get_unchecked(start..end) });
597            }
598            buffer.into()
599        }
600        IterationStrategy::Slices(slices) => {
601            let mut buffer = Vec::with_capacity(predicate.count);
602            for (start, end) in slices {
603                // SAFETY: indices were derived from the filter predicate
604                buffer.extend_from_slice(unsafe { values.get_unchecked(*start..*end) });
605            }
606            buffer.into()
607        }
608        IterationStrategy::IndexIterator => {
609            // SAFETY: indices were derived from the filter predicate
610            let iter = IndexIterator::new(&predicate.filter, predicate.count)
611                .map(|x| unsafe { *values.get_unchecked(x) });
612
613            // SAFETY: IndexIterator is trusted length
614            unsafe { MutableBuffer::from_trusted_len_iter(iter) }.into()
615        }
616        IterationStrategy::Indices(indices) => {
617            // SAFETY: indices were derived from the filter predicate
618            let iter = indices.iter().map(|x| unsafe { *values.get_unchecked(*x) });
619            iter.collect::<Vec<_>>().into()
620        }
621        IterationStrategy::All | IterationStrategy::None => unreachable!(),
622    }
623}
624
625/// `filter` implementation for primitive arrays
626fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
627where
628    T: ArrowPrimitiveType,
629{
630    let values = array.values();
631    let buffer = filter_native(values, predicate);
632    let mut builder = ArrayDataBuilder::new(array.data_type().clone())
633        .len(predicate.count)
634        .add_buffer(buffer);
635
636    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
637        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
638    }
639
640    let data = unsafe { builder.build_unchecked() };
641    PrimitiveArray::from(data)
642}
643
644/// [`FilterBytes`] is created from a source [`GenericByteArray`] and can be
645/// used to build a new [`GenericByteArray`] by copying values from the source
646///
647/// TODO(raphael): Could this be used for the take kernel as well?
648struct FilterBytes<'a, OffsetSize> {
649    src_offsets: &'a [OffsetSize],
650    src_values: &'a [u8],
651    dst_offsets: Vec<OffsetSize>,
652    dst_values: Vec<u8>,
653    cur_offset: OffsetSize,
654}
655
656impl<'a, OffsetSize> FilterBytes<'a, OffsetSize>
657where
658    OffsetSize: OffsetSizeTrait,
659{
660    fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self
661    where
662        T: ByteArrayType<Offset = OffsetSize>,
663    {
664        let dst_values = Vec::new();
665        let mut dst_offsets: Vec<OffsetSize> = Vec::with_capacity(capacity + 1);
666        let cur_offset = OffsetSize::from_usize(0).unwrap();
667
668        dst_offsets.push(cur_offset);
669
670        Self {
671            src_offsets: array.value_offsets(),
672            src_values: array.value_data(),
673            dst_offsets,
674            dst_values,
675            cur_offset,
676        }
677    }
678
679    /// Returns the byte offset at `idx`
680    #[inline]
681    fn get_value_offset(&self, idx: usize) -> usize {
682        self.src_offsets[idx].as_usize()
683    }
684
685    /// Returns the start and end of the value at index `idx` along with its length
686    #[inline]
687    fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
688        // These can only fail if `array` contains invalid data
689        let start = self.get_value_offset(idx);
690        let end = self.get_value_offset(idx + 1);
691        let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
692        (start, end, len)
693    }
694
695    fn extend_offsets_idx(&mut self, iter: impl Iterator<Item = usize>) {
696        self.dst_offsets.extend(iter.map(|idx| {
697            let start = self.src_offsets[idx].as_usize();
698            let end = self.src_offsets[idx + 1].as_usize();
699            let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
700            self.cur_offset += len;
701
702            self.cur_offset
703        }));
704    }
705
706    /// Extends the in-progress array by the indexes in the provided iterator
707    fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
708        self.dst_values.reserve_exact(self.cur_offset.as_usize());
709
710        for idx in iter {
711            let start = self.src_offsets[idx].as_usize();
712            let end = self.src_offsets[idx + 1].as_usize();
713            self.dst_values
714                .extend_from_slice(&self.src_values[start..end]);
715        }
716    }
717
718    fn extend_offsets_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>, count: usize) {
719        self.dst_offsets.reserve_exact(count);
720        for (start, end) in iter {
721            // These can only fail if `array` contains invalid data
722            for idx in start..end {
723                let (_, _, len) = self.get_value_range(idx);
724                self.cur_offset += len;
725                self.dst_offsets.push(self.cur_offset);
726            }
727        }
728    }
729
730    /// Extends the in-progress array by the ranges in the provided iterator
731    fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
732        self.dst_values.reserve_exact(self.cur_offset.as_usize());
733
734        for (start, end) in iter {
735            let value_start = self.get_value_offset(start);
736            let value_end = self.get_value_offset(end);
737            self.dst_values
738                .extend_from_slice(&self.src_values[value_start..value_end]);
739        }
740    }
741}
742
743/// `filter` implementation for byte arrays
744///
745/// Note: NULLs with a non-zero slot length in `array` will have the corresponding
746/// data copied across. This allows handling the null mask separately from the data
747fn filter_bytes<T>(array: &GenericByteArray<T>, predicate: &FilterPredicate) -> GenericByteArray<T>
748where
749    T: ByteArrayType,
750{
751    let mut filter = FilterBytes::new(predicate.count, array);
752
753    match &predicate.strategy {
754        IterationStrategy::SlicesIterator => {
755            filter.extend_offsets_slices(SlicesIterator::new(&predicate.filter), predicate.count);
756            filter.extend_slices(SlicesIterator::new(&predicate.filter))
757        }
758        IterationStrategy::Slices(slices) => {
759            filter.extend_offsets_slices(slices.iter().cloned(), predicate.count);
760            filter.extend_slices(slices.iter().cloned())
761        }
762        IterationStrategy::IndexIterator => {
763            filter.extend_offsets_idx(IndexIterator::new(&predicate.filter, predicate.count));
764            filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count))
765        }
766        IterationStrategy::Indices(indices) => {
767            filter.extend_offsets_idx(indices.iter().cloned());
768            filter.extend_idx(indices.iter().cloned())
769        }
770        IterationStrategy::All | IterationStrategy::None => unreachable!(),
771    }
772
773    let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
774        .len(predicate.count)
775        .add_buffer(filter.dst_offsets.into())
776        .add_buffer(filter.dst_values.into());
777
778    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
779        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
780    }
781
782    let data = unsafe { builder.build_unchecked() };
783    GenericByteArray::from(data)
784}
785
786/// `filter` implementation for byte view arrays.
787fn filter_byte_view<T: ByteViewType>(
788    array: &GenericByteViewArray<T>,
789    predicate: &FilterPredicate,
790) -> GenericByteViewArray<T> {
791    let new_view_buffer = filter_native(array.views(), predicate);
792
793    let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
794        .len(predicate.count)
795        .add_buffer(new_view_buffer)
796        .add_buffers(array.data_buffers().to_vec());
797
798    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
799        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
800    }
801
802    GenericByteViewArray::from(unsafe { builder.build_unchecked() })
803}
804
805fn filter_fixed_size_binary(
806    array: &FixedSizeBinaryArray,
807    predicate: &FilterPredicate,
808) -> FixedSizeBinaryArray {
809    let values: &[u8] = array.values();
810    let value_length = array.value_length() as usize;
811    let calculate_offset_from_index = |index: usize| index * value_length;
812    let buffer = match &predicate.strategy {
813        IterationStrategy::SlicesIterator => {
814            let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
815            for (start, end) in SlicesIterator::new(&predicate.filter) {
816                buffer.extend_from_slice(
817                    &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)],
818                );
819            }
820            buffer
821        }
822        IterationStrategy::Slices(slices) => {
823            let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
824            for (start, end) in slices {
825                buffer.extend_from_slice(
826                    &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)],
827                );
828            }
829            buffer
830        }
831        IterationStrategy::IndexIterator => {
832            let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| {
833                &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)]
834            });
835
836            let mut buffer = MutableBuffer::new(predicate.count * value_length);
837            iter.for_each(|item| buffer.extend_from_slice(item));
838            buffer
839        }
840        IterationStrategy::Indices(indices) => {
841            let iter = indices.iter().map(|x| {
842                &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)]
843            });
844
845            let mut buffer = MutableBuffer::new(predicate.count * value_length);
846            iter.for_each(|item| buffer.extend_from_slice(item));
847            buffer
848        }
849        IterationStrategy::All | IterationStrategy::None => unreachable!(),
850    };
851    let mut builder = ArrayDataBuilder::new(array.data_type().clone())
852        .len(predicate.count)
853        .add_buffer(buffer.into());
854
855    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
856        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
857    }
858
859    let data = unsafe { builder.build_unchecked() };
860    FixedSizeBinaryArray::from(data)
861}
862
863/// `filter` implementation for dictionaries
864fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
865where
866    T: ArrowDictionaryKeyType,
867    T::Native: num_traits::Num,
868{
869    let builder = filter_primitive::<T>(array.keys(), predicate)
870        .into_data()
871        .into_builder()
872        .data_type(array.data_type().clone())
873        .child_data(vec![array.values().to_data()]);
874
875    // SAFETY:
876    // Keys were valid before, filtered subset is therefore still valid
877    DictionaryArray::from(unsafe { builder.build_unchecked() })
878}
879
880/// `filter` implementation for structs
881fn filter_struct(
882    array: &StructArray,
883    predicate: &FilterPredicate,
884) -> Result<StructArray, ArrowError> {
885    let columns = array
886        .columns()
887        .iter()
888        .map(|column| filter_array(column, predicate))
889        .collect::<Result<_, _>>()?;
890
891    let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
892        let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
893
894        Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
895    } else {
896        None
897    };
898
899    Ok(unsafe {
900        StructArray::new_unchecked_with_length(
901            array.fields().clone(),
902            columns,
903            nulls,
904            predicate.count(),
905        )
906    })
907}
908
909/// `filter` implementation for sparse unions
910fn filter_sparse_union(
911    array: &UnionArray,
912    predicate: &FilterPredicate,
913) -> Result<UnionArray, ArrowError> {
914    let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
915        unreachable!()
916    };
917
918    let type_ids = filter_primitive(
919        &Int8Array::try_new(array.type_ids().clone(), None)?,
920        predicate,
921    );
922
923    let children = fields
924        .iter()
925        .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
926        .collect::<Result<_, _>>()?;
927
928    Ok(unsafe {
929        UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
930    })
931}
932
933/// `filter` implementation for list views
934fn filter_list_view<OffsetType: OffsetSizeTrait>(
935    array: &GenericListViewArray<OffsetType>,
936    predicate: &FilterPredicate,
937) -> GenericListViewArray<OffsetType> {
938    let filtered_offsets = filter_native::<OffsetType>(array.offsets(), predicate);
939    let filtered_sizes = filter_native::<OffsetType>(array.sizes(), predicate);
940
941    // Filter the nulls
942    let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
943        let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
944
945        Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
946    } else {
947        None
948    };
949
950    let list_data = ArrayDataBuilder::new(array.data_type().clone())
951        .nulls(nulls)
952        .buffers(vec![filtered_offsets, filtered_sizes])
953        .child_data(vec![array.values().to_data()])
954        .len(predicate.count);
955
956    let list_data = unsafe { list_data.build_unchecked() };
957
958    GenericListViewArray::from(list_data)
959}
960
961#[cfg(test)]
962mod tests {
963    use super::*;
964    use arrow_array::builder::*;
965    use arrow_array::cast::as_run_array;
966    use arrow_array::types::*;
967    use arrow_data::ArrayData;
968    use rand::distr::uniform::{UniformSampler, UniformUsize};
969    use rand::distr::{Alphanumeric, StandardUniform};
970    use rand::prelude::*;
971    use rand::rng;
972
973    macro_rules! def_temporal_test {
974        ($test:ident, $array_type: ident, $data: expr) => {
975            #[test]
976            fn $test() {
977                let a = $data;
978                let b = BooleanArray::from(vec![true, false, true, false]);
979                let c = filter(&a, &b).unwrap();
980                let d = c.as_ref().as_any().downcast_ref::<$array_type>().unwrap();
981                assert_eq!(2, d.len());
982                assert_eq!(1, d.value(0));
983                assert_eq!(3, d.value(1));
984            }
985        };
986    }
987
    // One generated test per temporal array type. Every input holds the values
    // [1, 2, 3, 4], which is what the assertions in `def_temporal_test!` expect.

    // Date arrays.
    def_temporal_test!(
        test_filter_date32,
        Date32Array,
        Date32Array::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_date64,
        Date64Array,
        Date64Array::from(vec![1, 2, 3, 4])
    );
    // Time-of-day arrays.
    def_temporal_test!(
        test_filter_time32_second,
        Time32SecondArray,
        Time32SecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time32_millisecond,
        Time32MillisecondArray,
        Time32MillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time64_microsecond,
        Time64MicrosecondArray,
        Time64MicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time64_nanosecond,
        Time64NanosecondArray,
        Time64NanosecondArray::from(vec![1, 2, 3, 4])
    );
    // Duration arrays.
    def_temporal_test!(
        test_filter_duration_second,
        DurationSecondArray,
        DurationSecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_millisecond,
        DurationMillisecondArray,
        DurationMillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_microsecond,
        DurationMicrosecondArray,
        DurationMicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_nanosecond,
        DurationNanosecondArray,
        DurationNanosecondArray::from(vec![1, 2, 3, 4])
    );
    // Timestamp arrays.
    def_temporal_test!(
        test_filter_timestamp_second,
        TimestampSecondArray,
        TimestampSecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_millisecond,
        TimestampMillisecondArray,
        TimestampMillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_microsecond,
        TimestampMicrosecondArray,
        TimestampMicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_nanosecond,
        TimestampNanosecondArray,
        TimestampNanosecondArray::from(vec![1, 2, 3, 4])
    );
1058
1059    #[test]
1060    fn test_filter_array_slice() {
1061        let a = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
1062        let b = BooleanArray::from(vec![true, false, false, true]);
1063        // filtering with sliced filter array is not currently supported
1064        // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4);
1065        // let b = b_slice.as_any().downcast_ref().unwrap();
1066        let c = filter(&a, &b).unwrap();
1067        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1068        assert_eq!(2, d.len());
1069        assert_eq!(6, d.value(0));
1070        assert_eq!(9, d.value(1));
1071    }
1072
1073    #[test]
1074    fn test_filter_array_low_density() {
1075        // this test exercises the all 0's branch of the filter algorithm
1076        let mut data_values = (1..=65).collect::<Vec<i32>>();
1077        let mut filter_values = (1..=65).map(|i| matches!(i % 65, 0)).collect::<Vec<bool>>();
1078        // set up two more values after the batch
1079        data_values.extend_from_slice(&[66, 67]);
1080        filter_values.extend_from_slice(&[false, true]);
1081        let a = Int32Array::from(data_values);
1082        let b = BooleanArray::from(filter_values);
1083        let c = filter(&a, &b).unwrap();
1084        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1085        assert_eq!(2, d.len());
1086        assert_eq!(65, d.value(0));
1087        assert_eq!(67, d.value(1));
1088    }
1089
1090    #[test]
1091    fn test_filter_array_high_density() {
1092        // this test exercises the all 1's branch of the filter algorithm
1093        let mut data_values = (1..=65).map(Some).collect::<Vec<_>>();
1094        let mut filter_values = (1..=65)
1095            .map(|i| !matches!(i % 65, 0))
1096            .collect::<Vec<bool>>();
1097        // set second data value to null
1098        data_values[1] = None;
1099        // set up two more values after the batch
1100        data_values.extend_from_slice(&[Some(66), None, Some(67), None]);
1101        filter_values.extend_from_slice(&[false, true, true, true]);
1102        let a = Int32Array::from(data_values);
1103        let b = BooleanArray::from(filter_values);
1104        let c = filter(&a, &b).unwrap();
1105        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1106        assert_eq!(67, d.len());
1107        assert_eq!(3, d.null_count());
1108        assert_eq!(1, d.value(0));
1109        assert!(d.is_null(1));
1110        assert_eq!(64, d.value(63));
1111        assert!(d.is_null(64));
1112        assert_eq!(67, d.value(65));
1113    }
1114
1115    #[test]
1116    fn test_filter_string_array_simple() {
1117        let a = StringArray::from(vec!["hello", " ", "world", "!"]);
1118        let b = BooleanArray::from(vec![true, false, true, false]);
1119        let c = filter(&a, &b).unwrap();
1120        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1121        assert_eq!(2, d.len());
1122        assert_eq!("hello", d.value(0));
1123        assert_eq!("world", d.value(1));
1124    }
1125
1126    #[test]
1127    fn test_filter_primitive_array_with_null() {
1128        let a = Int32Array::from(vec![Some(5), None]);
1129        let b = BooleanArray::from(vec![false, true]);
1130        let c = filter(&a, &b).unwrap();
1131        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1132        assert_eq!(1, d.len());
1133        assert!(d.is_null(0));
1134    }
1135
1136    #[test]
1137    fn test_filter_string_array_with_null() {
1138        let a = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
1139        let b = BooleanArray::from(vec![true, false, false, true]);
1140        let c = filter(&a, &b).unwrap();
1141        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1142        assert_eq!(2, d.len());
1143        assert_eq!("hello", d.value(0));
1144        assert!(!d.is_null(0));
1145        assert!(d.is_null(1));
1146    }
1147
1148    #[test]
1149    fn test_filter_binary_array_with_null() {
1150        let data: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
1151        let a = BinaryArray::from(data);
1152        let b = BooleanArray::from(vec![true, false, false, true]);
1153        let c = filter(&a, &b).unwrap();
1154        let d = c.as_ref().as_any().downcast_ref::<BinaryArray>().unwrap();
1155        assert_eq!(2, d.len());
1156        assert_eq!(b"hello", d.value(0));
1157        assert!(!d.is_null(0));
1158        assert!(d.is_null(1));
1159    }
1160
1161    fn _test_filter_byte_view<T>()
1162    where
1163        T: ByteViewType,
1164        str: AsRef<T::Native>,
1165        T::Native: PartialEq,
1166    {
1167        let array = {
1168            // ["hello", "world", null, "large payload over 12 bytes", "lulu"]
1169            let mut builder = GenericByteViewBuilder::<T>::new();
1170            builder.append_value("hello");
1171            builder.append_value("world");
1172            builder.append_null();
1173            builder.append_value("large payload over 12 bytes");
1174            builder.append_value("lulu");
1175            builder.finish()
1176        };
1177
1178        {
1179            let predicate = BooleanArray::from(vec![true, false, true, true, false]);
1180            let actual = filter(&array, &predicate).unwrap();
1181
1182            assert_eq!(actual.len(), 3);
1183
1184            let expected = {
1185                // ["hello", null, "large payload over 12 bytes"]
1186                let mut builder = GenericByteViewBuilder::<T>::new();
1187                builder.append_value("hello");
1188                builder.append_null();
1189                builder.append_value("large payload over 12 bytes");
1190                builder.finish()
1191            };
1192
1193            assert_eq!(actual.as_ref(), &expected);
1194        }
1195
1196        {
1197            let predicate = BooleanArray::from(vec![true, false, false, false, true]);
1198            let actual = filter(&array, &predicate).unwrap();
1199
1200            assert_eq!(actual.len(), 2);
1201
1202            let expected = {
1203                // ["hello", "lulu"]
1204                let mut builder = GenericByteViewBuilder::<T>::new();
1205                builder.append_value("hello");
1206                builder.append_value("lulu");
1207                builder.finish()
1208            };
1209
1210            assert_eq!(actual.as_ref(), &expected);
1211        }
1212    }
1213
1214    #[test]
1215    fn test_filter_string_view() {
1216        _test_filter_byte_view::<StringViewType>()
1217    }
1218
1219    #[test]
1220    fn test_filter_binary_view() {
1221        _test_filter_byte_view::<BinaryViewType>()
1222    }
1223
1224    #[test]
1225    fn test_filter_fixed_binary() {
1226        let v1 = [1_u8, 2];
1227        let v2 = [3_u8, 4];
1228        let v3 = [5_u8, 6];
1229        let v = vec![&v1, &v2, &v3];
1230        let a = FixedSizeBinaryArray::from(v);
1231        let b = BooleanArray::from(vec![true, false, true]);
1232        let c = filter(&a, &b).unwrap();
1233        let d = c
1234            .as_ref()
1235            .as_any()
1236            .downcast_ref::<FixedSizeBinaryArray>()
1237            .unwrap();
1238        assert_eq!(d.len(), 2);
1239        assert_eq!(d.value(0), &v1);
1240        assert_eq!(d.value(1), &v3);
1241        let c2 = FilterBuilder::new(&b)
1242            .optimize()
1243            .build()
1244            .filter(&a)
1245            .unwrap();
1246        let d2 = c2
1247            .as_ref()
1248            .as_any()
1249            .downcast_ref::<FixedSizeBinaryArray>()
1250            .unwrap();
1251        assert_eq!(d, d2);
1252
1253        let b = BooleanArray::from(vec![false, false, false]);
1254        let c = filter(&a, &b).unwrap();
1255        let d = c
1256            .as_ref()
1257            .as_any()
1258            .downcast_ref::<FixedSizeBinaryArray>()
1259            .unwrap();
1260        assert_eq!(d.len(), 0);
1261
1262        let b = BooleanArray::from(vec![true, true, true]);
1263        let c = filter(&a, &b).unwrap();
1264        let d = c
1265            .as_ref()
1266            .as_any()
1267            .downcast_ref::<FixedSizeBinaryArray>()
1268            .unwrap();
1269        assert_eq!(d.len(), 3);
1270        assert_eq!(d.value(0), &v1);
1271        assert_eq!(d.value(1), &v2);
1272        assert_eq!(d.value(2), &v3);
1273
1274        let b = BooleanArray::from(vec![false, false, true]);
1275        let c = filter(&a, &b).unwrap();
1276        let d = c
1277            .as_ref()
1278            .as_any()
1279            .downcast_ref::<FixedSizeBinaryArray>()
1280            .unwrap();
1281        assert_eq!(d.len(), 1);
1282        assert_eq!(d.value(0), &v3);
1283        let c2 = FilterBuilder::new(&b)
1284            .optimize()
1285            .build()
1286            .filter(&a)
1287            .unwrap();
1288        let d2 = c2
1289            .as_ref()
1290            .as_any()
1291            .downcast_ref::<FixedSizeBinaryArray>()
1292            .unwrap();
1293        assert_eq!(d, d2);
1294    }
1295
1296    #[test]
1297    fn test_filter_array_slice_with_null() {
1298        let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
1299        let b = BooleanArray::from(vec![true, false, false, true]);
1300        // filtering with sliced filter array is not currently supported
1301        // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4);
1302        // let b = b_slice.as_any().downcast_ref().unwrap();
1303        let c = filter(&a, &b).unwrap();
1304        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1305        assert_eq!(2, d.len());
1306        assert!(d.is_null(0));
1307        assert!(!d.is_null(1));
1308        assert_eq!(9, d.value(1));
1309    }
1310
1311    #[test]
1312    fn test_filter_run_end_encoding_array() {
1313        let run_ends = Int64Array::from(vec![2, 3, 8]);
1314        let values = Int64Array::from(vec![7, -2, 9]);
1315        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1316        let b = BooleanArray::from(vec![true, false, true, false, true, false, true, false]);
1317        let c = filter(&a, &b).unwrap();
1318        let actual: &RunArray<Int64Type> = as_run_array(&c);
1319        assert_eq!(4, actual.len());
1320
1321        let expected = RunArray::try_new(
1322            &Int64Array::from(vec![1, 2, 4]),
1323            &Int64Array::from(vec![7, -2, 9]),
1324        )
1325        .expect("Failed to make expected RunArray test is broken");
1326
1327        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1328        assert_eq!(actual.values(), expected.values())
1329    }
1330
1331    #[test]
1332    fn test_filter_run_end_encoding_array_remove_value() {
1333        let run_ends = Int32Array::from(vec![2, 3, 8, 10]);
1334        let values = Int32Array::from(vec![7, -2, 9, -8]);
1335        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1336        let b = BooleanArray::from(vec![
1337            false, true, false, false, true, false, true, false, false, false,
1338        ]);
1339        let c = filter(&a, &b).unwrap();
1340        let actual: &RunArray<Int32Type> = as_run_array(&c);
1341        assert_eq!(3, actual.len());
1342
1343        let expected =
1344            RunArray::try_new(&Int32Array::from(vec![1, 3]), &Int32Array::from(vec![7, 9]))
1345                .expect("Failed to make expected RunArray test is broken");
1346
1347        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1348        assert_eq!(actual.values(), expected.values())
1349    }
1350
1351    #[test]
1352    fn test_filter_run_end_encoding_array_remove_all_but_one() {
1353        let run_ends = Int16Array::from(vec![2, 3, 8, 10]);
1354        let values = Int16Array::from(vec![7, -2, 9, -8]);
1355        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1356        let b = BooleanArray::from(vec![
1357            false, false, false, false, false, false, true, false, false, false,
1358        ]);
1359        let c = filter(&a, &b).unwrap();
1360        let actual: &RunArray<Int16Type> = as_run_array(&c);
1361        assert_eq!(1, actual.len());
1362
1363        let expected = RunArray::try_new(&Int16Array::from(vec![1]), &Int16Array::from(vec![9]))
1364            .expect("Failed to make expected RunArray test is broken");
1365
1366        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1367        assert_eq!(actual.values(), expected.values())
1368    }
1369
1370    #[test]
1371    fn test_filter_run_end_encoding_array_empty() {
1372        let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1373        let values = Int64Array::from(vec![7, -2, 9, -8]);
1374        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1375        let b = BooleanArray::from(vec![
1376            false, false, false, false, false, false, false, false, false, false,
1377        ]);
1378        let c = filter(&a, &b).unwrap();
1379        let actual: &RunArray<Int64Type> = as_run_array(&c);
1380        assert_eq!(0, actual.len());
1381    }
1382
1383    #[test]
1384    fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() {
1385        let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1386        let values = Int64Array::from(vec![7, -2, 9, -8]);
1387        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1388        let b = BooleanArray::from(vec![false, true, true]);
1389        let c = filter(&a, &b).unwrap();
1390        let actual: &RunArray<Int64Type> = as_run_array(&c);
1391        assert_eq!(2, actual.len());
1392
1393        let expected = RunArray::try_new(
1394            &Int64Array::from(vec![1, 2]),
1395            &Int64Array::from(vec![7, -2]),
1396        )
1397        .expect("Failed to make expected RunArray test is broken");
1398
1399        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1400        assert_eq!(actual.values(), expected.values())
1401    }
1402
1403    #[test]
1404    fn test_filter_dictionary_array() {
1405        let values = [Some("hello"), None, Some("world"), Some("!")];
1406        let a: Int8DictionaryArray = values.iter().copied().collect();
1407        let b = BooleanArray::from(vec![false, true, true, false]);
1408        let c = filter(&a, &b).unwrap();
1409        let d = c
1410            .as_ref()
1411            .as_any()
1412            .downcast_ref::<Int8DictionaryArray>()
1413            .unwrap();
1414        let value_array = d.values();
1415        let values = value_array.as_any().downcast_ref::<StringArray>().unwrap();
1416        // values are cloned in the filtered dictionary array
1417        assert_eq!(3, values.len());
1418        // but keys are filtered
1419        assert_eq!(2, d.len());
1420        assert!(d.is_null(0));
1421        assert_eq!("world", values.value(d.keys().value(1) as usize));
1422    }
1423
1424    #[test]
1425    fn test_filter_list_array() {
1426        let value_data = ArrayData::builder(DataType::Int32)
1427            .len(8)
1428            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
1429            .build()
1430            .unwrap();
1431
1432        let value_offsets = Buffer::from_slice_ref([0i64, 3, 6, 8, 8]);
1433
1434        let list_data_type =
1435            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
1436        let list_data = ArrayData::builder(list_data_type)
1437            .len(4)
1438            .add_buffer(value_offsets)
1439            .add_child_data(value_data)
1440            .null_bit_buffer(Some(Buffer::from([0b00000111])))
1441            .build()
1442            .unwrap();
1443
1444        //  a = [[0, 1, 2], [3, 4, 5], [6, 7], null]
1445        let a = LargeListArray::from(list_data);
1446        let b = BooleanArray::from(vec![false, true, false, true]);
1447        let result = filter(&a, &b).unwrap();
1448
1449        // expected: [[3, 4, 5], null]
1450        let value_data = ArrayData::builder(DataType::Int32)
1451            .len(3)
1452            .add_buffer(Buffer::from_slice_ref([3, 4, 5]))
1453            .build()
1454            .unwrap();
1455
1456        let value_offsets = Buffer::from_slice_ref([0i64, 3, 3]);
1457
1458        let list_data_type =
1459            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
1460        let expected = ArrayData::builder(list_data_type)
1461            .len(2)
1462            .add_buffer(value_offsets)
1463            .add_child_data(value_data)
1464            .null_bit_buffer(Some(Buffer::from([0b00000001])))
1465            .build()
1466            .unwrap();
1467
1468        assert_eq!(&make_array(expected), &result);
1469    }
1470
1471    fn test_case_filter_list_view<T: OffsetSizeTrait>() {
1472        // [[1, 2], null, [], [3,4]]
1473        let mut list_array = GenericListViewBuilder::<T, _>::new(Int32Builder::new());
1474        list_array.append_value([Some(1), Some(2)]);
1475        list_array.append_null();
1476        list_array.append_value([]);
1477        list_array.append_value([Some(3), Some(4)]);
1478
1479        let list_array = list_array.finish();
1480        let predicate = BooleanArray::from_iter([true, false, true, false]);
1481
1482        // Filter result: [[1, 2], []]
1483        let filtered = filter(&list_array, &predicate)
1484            .unwrap()
1485            .as_list_view::<T>()
1486            .clone();
1487
1488        let mut expected =
1489            GenericListViewBuilder::<T, _>::with_capacity(Int32Builder::with_capacity(5), 3);
1490        expected.append_value([Some(1), Some(2)]);
1491        expected.append_value([]);
1492        let expected = expected.finish();
1493
1494        assert_eq!(&filtered, &expected);
1495    }
1496
1497    fn test_case_filter_sliced_list_view<T: OffsetSizeTrait>() {
1498        // [[1, 2], null, [], [3,4]]
1499        let mut list_array =
1500            GenericListViewBuilder::<T, _>::with_capacity(Int32Builder::with_capacity(6), 4);
1501        list_array.append_value([Some(1), Some(2)]);
1502        list_array.append_null();
1503        list_array.append_value([]);
1504        list_array.append_value([Some(3), Some(4)]);
1505
1506        let list_array = list_array.finish();
1507
1508        // Sliced: [null, [], [3, 4]]
1509        let sliced = list_array.slice(1, 3);
1510        let predicate = BooleanArray::from_iter([false, false, true]);
1511
1512        // Filter result: [[1, 2], []]
1513        let filtered = filter(&sliced, &predicate)
1514            .unwrap()
1515            .as_list_view::<T>()
1516            .clone();
1517
1518        let mut expected = GenericListViewBuilder::<T, _>::new(Int32Builder::new());
1519        expected.append_value([Some(3), Some(4)]);
1520        let expected = expected.finish();
1521
1522        assert_eq!(&filtered, &expected);
1523    }
1524
1525    #[test]
1526    fn test_filter_list_view_array() {
1527        test_case_filter_list_view::<i32>();
1528        test_case_filter_list_view::<i64>();
1529
1530        test_case_filter_sliced_list_view::<i32>();
1531        test_case_filter_sliced_list_view::<i64>();
1532    }
1533
1534    #[test]
1535    fn test_slice_iterator_bits() {
1536        let filter_values = (0..64).map(|i| i == 1).collect::<Vec<bool>>();
1537        let filter = BooleanArray::from(filter_values);
1538        let filter_count = filter_count(&filter);
1539
1540        let iter = SlicesIterator::new(&filter);
1541        let chunks = iter.collect::<Vec<_>>();
1542
1543        assert_eq!(chunks, vec![(1, 2)]);
1544        assert_eq!(filter_count, 1);
1545    }
1546
1547    #[test]
1548    fn test_slice_iterator_bits1() {
1549        let filter_values = (0..64).map(|i| i != 1).collect::<Vec<bool>>();
1550        let filter = BooleanArray::from(filter_values);
1551        let filter_count = filter_count(&filter);
1552
1553        let iter = SlicesIterator::new(&filter);
1554        let chunks = iter.collect::<Vec<_>>();
1555
1556        assert_eq!(chunks, vec![(0, 1), (2, 64)]);
1557        assert_eq!(filter_count, 64 - 1);
1558    }
1559
1560    #[test]
1561    fn test_slice_iterator_chunk_and_bits() {
1562        let filter_values = (0..130).map(|i| i % 62 != 0).collect::<Vec<bool>>();
1563        let filter = BooleanArray::from(filter_values);
1564        let filter_count = filter_count(&filter);
1565
1566        let iter = SlicesIterator::new(&filter);
1567        let chunks = iter.collect::<Vec<_>>();
1568
1569        assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]);
1570        assert_eq!(filter_count, 61 + 61 + 5);
1571    }
1572
1573    #[test]
1574    fn test_null_mask() {
1575        let a = Int64Array::from(vec![Some(1), Some(2), None]);
1576
1577        let mask1 = BooleanArray::from(vec![Some(true), Some(true), None]);
1578        let out = filter(&a, &mask1).unwrap();
1579        assert_eq!(out.as_ref(), &a.slice(0, 2));
1580    }
1581
1582    #[test]
1583    fn test_filter_record_batch_no_columns() {
1584        let pred = BooleanArray::from(vec![Some(true), Some(true), None]);
1585        let options = RecordBatchOptions::default().with_row_count(Some(100));
1586        let record_batch =
1587            RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &options).unwrap();
1588        let out = filter_record_batch(&record_batch, &pred).unwrap();
1589
1590        assert_eq!(out.num_rows(), 2);
1591    }
1592
1593    #[test]
1594    fn test_fast_path() {
1595        let a: PrimitiveArray<Int64Type> = PrimitiveArray::from(vec![Some(1), Some(2), None]);
1596
1597        // all true
1598        let mask = BooleanArray::from(vec![true, true, true]);
1599        let out = filter(&a, &mask).unwrap();
1600        let b = out
1601            .as_any()
1602            .downcast_ref::<PrimitiveArray<Int64Type>>()
1603            .unwrap();
1604        assert_eq!(&a, b);
1605
1606        // all false
1607        let mask = BooleanArray::from(vec![false, false, false]);
1608        let out = filter(&a, &mask).unwrap();
1609        assert_eq!(out.len(), 0);
1610        assert_eq!(out.data_type(), &DataType::Int64);
1611    }
1612
1613    #[test]
1614    fn test_slices() {
1615        // takes up 2 u64s
1616        let bools = std::iter::repeat_n(true, 10)
1617            .chain(std::iter::repeat_n(false, 30))
1618            .chain(std::iter::repeat_n(true, 20))
1619            .chain(std::iter::repeat_n(false, 17))
1620            .chain(std::iter::repeat_n(true, 4));
1621
1622        let bool_array: BooleanArray = bools.map(Some).collect();
1623
1624        let slices: Vec<_> = SlicesIterator::new(&bool_array).collect();
1625        let expected = vec![(0, 10), (40, 60), (77, 81)];
1626        assert_eq!(slices, expected);
1627
1628        // slice with offset and truncated len
1629        let len = bool_array.len();
1630        let sliced_array = bool_array.slice(7, len - 10);
1631        let sliced_array = sliced_array
1632            .as_any()
1633            .downcast_ref::<BooleanArray>()
1634            .unwrap();
1635        let slices: Vec<_> = SlicesIterator::new(sliced_array).collect();
1636        let expected = vec![(0, 3), (33, 53), (70, 71)];
1637        assert_eq!(slices, expected);
1638    }
1639
1640    fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) {
1641        let mut rng = rng();
1642
1643        let bools: Vec<bool> = std::iter::from_fn(|| Some(rng.random()))
1644            .take(mask_len)
1645            .collect();
1646
1647        let buffer = Buffer::from_iter(bools.iter().cloned());
1648
1649        let truncated_length = mask_len - offset - truncate;
1650
1651        let data = ArrayDataBuilder::new(DataType::Boolean)
1652            .len(truncated_length)
1653            .offset(offset)
1654            .add_buffer(buffer)
1655            .build()
1656            .unwrap();
1657
1658        let filter = BooleanArray::from(data);
1659
1660        let slice_bits: Vec<_> = SlicesIterator::new(&filter)
1661            .flat_map(|(start, end)| start..end)
1662            .collect();
1663
1664        let count = filter_count(&filter);
1665        let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect();
1666
1667        let expected_bits: Vec<_> = bools
1668            .iter()
1669            .skip(offset)
1670            .take(truncated_length)
1671            .enumerate()
1672            .flat_map(|(idx, v)| v.then(|| idx))
1673            .collect();
1674
1675        assert_eq!(slice_bits, expected_bits);
1676        assert_eq!(index_bits, expected_bits);
1677    }
1678
1679    #[test]
1680    #[cfg_attr(miri, ignore)]
1681    fn fuzz_test_slices_iterator() {
1682        let mut rng = rng();
1683
1684        let uusize = UniformUsize::new(usize::MIN, usize::MAX).unwrap();
1685        for _ in 0..100 {
1686            let mask_len = rng.random_range(0..1024);
1687            let max_offset = 64.min(mask_len);
1688            let offset = uusize.sample(&mut rng).checked_rem(max_offset).unwrap_or(0);
1689
1690            let max_truncate = 128.min(mask_len - offset);
1691            let truncate = uusize
1692                .sample(&mut rng)
1693                .checked_rem(max_truncate)
1694                .unwrap_or(0);
1695
1696            test_slices_fuzz(mask_len, offset, truncate);
1697        }
1698
1699        test_slices_fuzz(64, 0, 0);
1700        test_slices_fuzz(64, 8, 0);
1701        test_slices_fuzz(64, 8, 8);
1702        test_slices_fuzz(32, 8, 8);
1703        test_slices_fuzz(32, 5, 9);
1704    }
1705
1706    /// Filters `values` by `predicate` using standard rust iterators
1707    fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> {
1708        values
1709            .into_iter()
1710            .zip(predicate)
1711            .filter(|(_, x)| **x)
1712            .map(|(a, _)| a)
1713            .collect()
1714    }
1715
1716    /// Generates an array of length `len` with `valid_percent` non-null values
1717    fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>>
1718    where
1719        StandardUniform: Distribution<T>,
1720    {
1721        let mut rng = rng();
1722        (0..len)
1723            .map(|_| rng.random_bool(valid_percent).then(|| rng.random()))
1724            .collect()
1725    }
1726
1727    /// Generates an array of length `len` with `valid_percent` non-null values
1728    fn gen_strings(
1729        len: usize,
1730        valid_percent: f64,
1731        str_len_range: std::ops::Range<usize>,
1732    ) -> Vec<Option<String>> {
1733        let mut rng = rng();
1734        (0..len)
1735            .map(|_| {
1736                rng.random_bool(valid_percent).then(|| {
1737                    let len = rng.random_range(str_len_range.clone());
1738                    (0..len)
1739                        .map(|_| char::from(rng.sample(Alphanumeric)))
1740                        .collect()
1741                })
1742            })
1743            .collect()
1744    }
1745
1746    /// Returns an iterator that calls `Option::as_deref` on each item
1747    fn as_deref<T: std::ops::Deref>(src: &[Option<T>]) -> impl Iterator<Item = Option<&T::Target>> {
1748        src.iter().map(|x| x.as_deref())
1749    }
1750
1751    #[test]
1752    #[cfg_attr(miri, ignore)]
1753    fn fuzz_filter() {
1754        let mut rng = rng();
1755
1756        for i in 0..100 {
1757            let filter_percent = match i {
1758                0..=4 => 1.,
1759                5..=10 => 0.,
1760                _ => rng.random_range(0.0..1.0),
1761            };
1762
1763            let valid_percent = rng.random_range(0.0..1.0);
1764
1765            let array_len = rng.random_range(32..256);
1766            let array_offset = rng.random_range(0..10);
1767
1768            // Construct a predicate
1769            let filter_offset = rng.random_range(0..10);
1770            let filter_truncate = rng.random_range(0..10);
1771            let bools: Vec<_> = std::iter::from_fn(|| Some(rng.random_bool(filter_percent)))
1772                .take(array_len + filter_offset - filter_truncate)
1773                .collect();
1774
1775            let predicate = BooleanArray::from_iter(bools.iter().cloned().map(Some));
1776
1777            // Offset predicate
1778            let predicate = predicate.slice(filter_offset, array_len - filter_truncate);
1779            let predicate = predicate.as_any().downcast_ref::<BooleanArray>().unwrap();
1780            let bools = &bools[filter_offset..];
1781
1782            // Test i32
1783            let values = gen_primitive(array_len + array_offset, valid_percent);
1784            let src = Int32Array::from_iter(values.iter().cloned());
1785
1786            let src = src.slice(array_offset, array_len);
1787            let src = src.as_any().downcast_ref::<Int32Array>().unwrap();
1788            let values = &values[array_offset..];
1789
1790            let filtered = filter(src, predicate).unwrap();
1791            let array = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
1792            let actual: Vec<_> = array.iter().collect();
1793
1794            assert_eq!(actual, filter_rust(values.iter().cloned(), bools));
1795
1796            // Test string
1797            let strings = gen_strings(array_len + array_offset, valid_percent, 0..20);
1798            let src = StringArray::from_iter(as_deref(&strings));
1799
1800            let src = src.slice(array_offset, array_len);
1801            let src = src.as_any().downcast_ref::<StringArray>().unwrap();
1802
1803            let filtered = filter(src, predicate).unwrap();
1804            let array = filtered.as_any().downcast_ref::<StringArray>().unwrap();
1805            let actual: Vec<_> = array.iter().collect();
1806
1807            let expected_strings = filter_rust(as_deref(&strings[array_offset..]), bools);
1808            assert_eq!(actual, expected_strings);
1809
1810            // Test string dictionary
1811            let src = DictionaryArray::<Int32Type>::from_iter(as_deref(&strings));
1812
1813            let src = src.slice(array_offset, array_len);
1814            let src = src
1815                .as_any()
1816                .downcast_ref::<DictionaryArray<Int32Type>>()
1817                .unwrap();
1818
1819            let filtered = filter(src, predicate).unwrap();
1820
1821            let array = filtered
1822                .as_any()
1823                .downcast_ref::<DictionaryArray<Int32Type>>()
1824                .unwrap();
1825
1826            let values = array
1827                .values()
1828                .as_any()
1829                .downcast_ref::<StringArray>()
1830                .unwrap();
1831
1832            let actual: Vec<_> = array
1833                .keys()
1834                .iter()
1835                .map(|key| key.map(|key| values.value(key as usize)))
1836                .collect();
1837
1838            assert_eq!(actual, expected_strings);
1839        }
1840    }
1841
1842    #[test]
1843    fn test_filter_map() {
1844        let mut builder =
1845            MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(4));
1846        // [{"key1": 1}, {"key2": 2, "key3": 3}, null, {"key1": 1}
1847        builder.keys().append_value("key1");
1848        builder.values().append_value(1);
1849        builder.append(true).unwrap();
1850        builder.keys().append_value("key2");
1851        builder.keys().append_value("key3");
1852        builder.values().append_value(2);
1853        builder.values().append_value(3);
1854        builder.append(true).unwrap();
1855        builder.append(false).unwrap();
1856        builder.keys().append_value("key1");
1857        builder.values().append_value(1);
1858        builder.append(true).unwrap();
1859        let maparray = Arc::new(builder.finish()) as ArrayRef;
1860
1861        let indices = vec![Some(true), Some(false), Some(false), Some(true)]
1862            .into_iter()
1863            .collect::<BooleanArray>();
1864        let got = filter(&maparray, &indices).unwrap();
1865
1866        let mut builder =
1867            MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(2));
1868        builder.keys().append_value("key1");
1869        builder.values().append_value(1);
1870        builder.append(true).unwrap();
1871        builder.keys().append_value("key1");
1872        builder.values().append_value(1);
1873        builder.append(true).unwrap();
1874        let expected = Arc::new(builder.finish()) as ArrayRef;
1875
1876        assert_eq!(&expected, &got);
1877    }
1878
1879    #[test]
1880    fn test_filter_fixed_size_list_arrays() {
1881        let value_data = ArrayData::builder(DataType::Int32)
1882            .len(9)
1883            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8]))
1884            .build()
1885            .unwrap();
1886        let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 3, false);
1887        let list_data = ArrayData::builder(list_data_type)
1888            .len(3)
1889            .add_child_data(value_data)
1890            .build()
1891            .unwrap();
1892        let array = FixedSizeListArray::from(list_data);
1893
1894        let filter_array = BooleanArray::from(vec![true, false, false]);
1895
1896        let c = filter(&array, &filter_array).unwrap();
1897        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1898
1899        assert_eq!(filtered.len(), 1);
1900
1901        let list = filtered.value(0);
1902        assert_eq!(
1903            &[0, 1, 2],
1904            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1905        );
1906
1907        let filter_array = BooleanArray::from(vec![true, false, true]);
1908
1909        let c = filter(&array, &filter_array).unwrap();
1910        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1911
1912        assert_eq!(filtered.len(), 2);
1913
1914        let list = filtered.value(0);
1915        assert_eq!(
1916            &[0, 1, 2],
1917            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1918        );
1919        let list = filtered.value(1);
1920        assert_eq!(
1921            &[6, 7, 8],
1922            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1923        );
1924    }
1925
1926    #[test]
1927    fn test_filter_fixed_size_list_arrays_with_null() {
1928        let value_data = ArrayData::builder(DataType::Int32)
1929            .len(10)
1930            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
1931            .build()
1932            .unwrap();
1933
1934        // Set null buts for the nested array:
1935        //  [[0, 1], null, null, [6, 7], [8, 9]]
1936        // 01011001 00000001
1937        let mut null_bits: [u8; 1] = [0; 1];
1938        bit_util::set_bit(&mut null_bits, 0);
1939        bit_util::set_bit(&mut null_bits, 3);
1940        bit_util::set_bit(&mut null_bits, 4);
1941
1942        let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 2, false);
1943        let list_data = ArrayData::builder(list_data_type)
1944            .len(5)
1945            .add_child_data(value_data)
1946            .null_bit_buffer(Some(Buffer::from(null_bits)))
1947            .build()
1948            .unwrap();
1949        let array = FixedSizeListArray::from(list_data);
1950
1951        let filter_array = BooleanArray::from(vec![true, true, false, true, false]);
1952
1953        let c = filter(&array, &filter_array).unwrap();
1954        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1955
1956        assert_eq!(filtered.len(), 3);
1957
1958        let list = filtered.value(0);
1959        assert_eq!(
1960            &[0, 1],
1961            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1962        );
1963        assert!(filtered.is_null(1));
1964        let list = filtered.value(2);
1965        assert_eq!(
1966            &[6, 7],
1967            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1968        );
1969    }
1970
1971    fn test_filter_union_array(array: UnionArray) {
1972        let filter_array = BooleanArray::from(vec![true, false, false]);
1973        let c = filter(&array, &filter_array).unwrap();
1974        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1975
1976        let mut builder = UnionBuilder::new_dense();
1977        builder.append::<Int32Type>("A", 1).unwrap();
1978        let expected_array = builder.build().unwrap();
1979
1980        compare_union_arrays(filtered, &expected_array);
1981
1982        let filter_array = BooleanArray::from(vec![true, false, true]);
1983        let c = filter(&array, &filter_array).unwrap();
1984        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1985
1986        let mut builder = UnionBuilder::new_dense();
1987        builder.append::<Int32Type>("A", 1).unwrap();
1988        builder.append::<Int32Type>("A", 34).unwrap();
1989        let expected_array = builder.build().unwrap();
1990
1991        compare_union_arrays(filtered, &expected_array);
1992
1993        let filter_array = BooleanArray::from(vec![true, true, false]);
1994        let c = filter(&array, &filter_array).unwrap();
1995        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1996
1997        let mut builder = UnionBuilder::new_dense();
1998        builder.append::<Int32Type>("A", 1).unwrap();
1999        builder.append::<Float64Type>("B", 3.2).unwrap();
2000        let expected_array = builder.build().unwrap();
2001
2002        compare_union_arrays(filtered, &expected_array);
2003    }
2004
2005    #[test]
2006    fn test_filter_union_array_dense() {
2007        let mut builder = UnionBuilder::new_dense();
2008        builder.append::<Int32Type>("A", 1).unwrap();
2009        builder.append::<Float64Type>("B", 3.2).unwrap();
2010        builder.append::<Int32Type>("A", 34).unwrap();
2011        let array = builder.build().unwrap();
2012
2013        test_filter_union_array(array);
2014    }
2015
2016    #[test]
2017    fn test_filter_run_union_array_dense() {
2018        let mut builder = UnionBuilder::new_dense();
2019        builder.append::<Int32Type>("A", 1).unwrap();
2020        builder.append::<Int32Type>("A", 3).unwrap();
2021        builder.append::<Int32Type>("A", 34).unwrap();
2022        let array = builder.build().unwrap();
2023
2024        let filter_array = BooleanArray::from(vec![true, true, false]);
2025        let c = filter(&array, &filter_array).unwrap();
2026        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2027
2028        let mut builder = UnionBuilder::new_dense();
2029        builder.append::<Int32Type>("A", 1).unwrap();
2030        builder.append::<Int32Type>("A", 3).unwrap();
2031        let expected = builder.build().unwrap();
2032
2033        assert_eq!(filtered.to_data(), expected.to_data());
2034    }
2035
2036    #[test]
2037    fn test_filter_union_array_dense_with_nulls() {
2038        let mut builder = UnionBuilder::new_dense();
2039        builder.append::<Int32Type>("A", 1).unwrap();
2040        builder.append::<Float64Type>("B", 3.2).unwrap();
2041        builder.append_null::<Float64Type>("B").unwrap();
2042        builder.append::<Int32Type>("A", 34).unwrap();
2043        let array = builder.build().unwrap();
2044
2045        let filter_array = BooleanArray::from(vec![true, true, false, false]);
2046        let c = filter(&array, &filter_array).unwrap();
2047        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2048
2049        let mut builder = UnionBuilder::new_dense();
2050        builder.append::<Int32Type>("A", 1).unwrap();
2051        builder.append::<Float64Type>("B", 3.2).unwrap();
2052        let expected_array = builder.build().unwrap();
2053
2054        compare_union_arrays(filtered, &expected_array);
2055
2056        let filter_array = BooleanArray::from(vec![true, false, true, false]);
2057        let c = filter(&array, &filter_array).unwrap();
2058        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2059
2060        let mut builder = UnionBuilder::new_dense();
2061        builder.append::<Int32Type>("A", 1).unwrap();
2062        builder.append_null::<Float64Type>("B").unwrap();
2063        let expected_array = builder.build().unwrap();
2064
2065        compare_union_arrays(filtered, &expected_array);
2066    }
2067
2068    #[test]
2069    fn test_filter_union_array_sparse() {
2070        let mut builder = UnionBuilder::new_sparse();
2071        builder.append::<Int32Type>("A", 1).unwrap();
2072        builder.append::<Float64Type>("B", 3.2).unwrap();
2073        builder.append::<Int32Type>("A", 34).unwrap();
2074        let array = builder.build().unwrap();
2075
2076        test_filter_union_array(array);
2077    }
2078
2079    #[test]
2080    fn test_filter_union_array_sparse_with_nulls() {
2081        let mut builder = UnionBuilder::new_sparse();
2082        builder.append::<Int32Type>("A", 1).unwrap();
2083        builder.append::<Float64Type>("B", 3.2).unwrap();
2084        builder.append_null::<Float64Type>("B").unwrap();
2085        builder.append::<Int32Type>("A", 34).unwrap();
2086        let array = builder.build().unwrap();
2087
2088        let filter_array = BooleanArray::from(vec![true, false, true, false]);
2089        let c = filter(&array, &filter_array).unwrap();
2090        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2091
2092        let mut builder = UnionBuilder::new_sparse();
2093        builder.append::<Int32Type>("A", 1).unwrap();
2094        builder.append_null::<Float64Type>("B").unwrap();
2095        let expected_array = builder.build().unwrap();
2096
2097        compare_union_arrays(filtered, &expected_array);
2098    }
2099
2100    fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) {
2101        assert_eq!(union1.len(), union2.len());
2102
2103        for i in 0..union1.len() {
2104            let type_id = union1.type_id(i);
2105
2106            let slot1 = union1.value(i);
2107            let slot2 = union2.value(i);
2108
2109            assert_eq!(slot1.is_null(0), slot2.is_null(0));
2110
2111            if !slot1.is_null(0) && !slot2.is_null(0) {
2112                match type_id {
2113                    0 => {
2114                        let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
2115                        assert_eq!(slot1.len(), 1);
2116                        let value1 = slot1.value(0);
2117
2118                        let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap();
2119                        assert_eq!(slot2.len(), 1);
2120                        let value2 = slot2.value(0);
2121                        assert_eq!(value1, value2);
2122                    }
2123                    1 => {
2124                        let slot1 = slot1.as_any().downcast_ref::<Float64Array>().unwrap();
2125                        assert_eq!(slot1.len(), 1);
2126                        let value1 = slot1.value(0);
2127
2128                        let slot2 = slot2.as_any().downcast_ref::<Float64Array>().unwrap();
2129                        assert_eq!(slot2.len(), 1);
2130                        let value2 = slot2.value(0);
2131                        assert_eq!(value1, value2);
2132                    }
2133                    _ => unreachable!(),
2134                }
2135            }
2136        }
2137    }
2138
2139    #[test]
2140    fn test_filter_struct() {
2141        let predicate = BooleanArray::from(vec![true, false, true, false]);
2142
2143        let a = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
2144        let a_filtered = Arc::new(StringArray::from(vec!["hello", "world"]));
2145
2146        let b = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
2147        let b_filtered = Arc::new(Int32Array::from(vec![5, 7]));
2148
2149        let null_mask = NullBuffer::from(vec![true, false, false, true]);
2150        let null_mask_filtered = NullBuffer::from(vec![true, false]);
2151
2152        let a_field = Field::new("a", DataType::Utf8, false);
2153        let b_field = Field::new("b", DataType::Int32, false);
2154
2155        let array = StructArray::new(vec![a_field.clone()].into(), vec![a.clone()], None);
2156        let expected =
2157            StructArray::new(vec![a_field.clone()].into(), vec![a_filtered.clone()], None);
2158
2159        let result = filter(&array, &predicate).unwrap();
2160
2161        assert_eq!(result.to_data(), expected.to_data());
2162
2163        let array = StructArray::new(
2164            vec![a_field.clone()].into(),
2165            vec![a.clone()],
2166            Some(null_mask.clone()),
2167        );
2168        let expected = StructArray::new(
2169            vec![a_field.clone()].into(),
2170            vec![a_filtered.clone()],
2171            Some(null_mask_filtered.clone()),
2172        );
2173
2174        let result = filter(&array, &predicate).unwrap();
2175
2176        assert_eq!(result.to_data(), expected.to_data());
2177
2178        let array = StructArray::new(
2179            vec![a_field.clone(), b_field.clone()].into(),
2180            vec![a.clone(), b.clone()],
2181            None,
2182        );
2183        let expected = StructArray::new(
2184            vec![a_field.clone(), b_field.clone()].into(),
2185            vec![a_filtered.clone(), b_filtered.clone()],
2186            None,
2187        );
2188
2189        let result = filter(&array, &predicate).unwrap();
2190
2191        assert_eq!(result.to_data(), expected.to_data());
2192
2193        let array = StructArray::new(
2194            vec![a_field.clone(), b_field.clone()].into(),
2195            vec![a.clone(), b.clone()],
2196            Some(null_mask.clone()),
2197        );
2198
2199        let expected = StructArray::new(
2200            vec![a_field.clone(), b_field.clone()].into(),
2201            vec![a_filtered.clone(), b_filtered.clone()],
2202            Some(null_mask_filtered.clone()),
2203        );
2204
2205        let result = filter(&array, &predicate).unwrap();
2206
2207        assert_eq!(result.to_data(), expected.to_data());
2208    }
2209
2210    #[test]
2211    fn test_filter_empty_struct() {
2212        /*
2213            "a": {
2214                "b": int64,
2215                "c": {}
2216            },
2217        */
2218        let fields = arrow_schema::Field::new(
2219            "a",
2220            arrow_schema::DataType::Struct(arrow_schema::Fields::from(vec![
2221                arrow_schema::Field::new("b", arrow_schema::DataType::Int64, true),
2222                arrow_schema::Field::new(
2223                    "c",
2224                    arrow_schema::DataType::Struct(arrow_schema::Fields::empty()),
2225                    true,
2226                ),
2227            ])),
2228            true,
2229        );
2230
2231        /* Test record
2232            {"a":{"c": {}}}
2233            {"a":{"c": {}}}
2234            {"a":{"c": {}}}
2235        */
2236
2237        // Create the record batch with the nested struct array
2238        let schema = Arc::new(Schema::new(vec![fields]));
2239
2240        let b = Arc::new(Int64Array::from(vec![None, None, None]));
2241        let c = Arc::new(StructArray::new_empty_fields(
2242            3,
2243            Some(NullBuffer::from(vec![true, true, true])),
2244        ));
2245        let a = StructArray::new(
2246            vec![
2247                Field::new("b", DataType::Int64, true),
2248                Field::new("c", DataType::Struct(Fields::empty()), true),
2249            ]
2250            .into(),
2251            vec![b.clone(), c.clone()],
2252            Some(NullBuffer::from(vec![true, true, true])),
2253        );
2254        let record_batch = RecordBatch::try_new(schema, vec![Arc::new(a)]).unwrap();
2255        println!("{record_batch:?}");
2256
2257        // Apply the filter
2258        let predicate = BooleanArray::from(vec![true, false, true]);
2259        let filtered_batch = filter_record_batch(&record_batch, &predicate).unwrap();
2260
2261        // The filtered batch should have 2 rows (the 1st and 3rd)
2262        assert_eq!(filtered_batch.num_rows(), 2);
2263    }
2264
2265    #[test]
2266    #[should_panic]
2267    fn test_filter_bits_too_large() {
2268        let buffer = BooleanBuffer::from(vec![false; 8]);
2269        let predicate = BooleanArray::from(vec![true; 9]);
2270        let filter = FilterBuilder::new(&predicate).build();
2271        filter_bits(&buffer, &filter);
2272    }
2273
2274    #[test]
2275    #[should_panic]
2276    fn test_filter_native_too_large() {
2277        let values = vec![1; 8];
2278        let predicate = BooleanArray::from(vec![false; 9]);
2279        let filter = FilterBuilder::new(&predicate).build();
2280        filter_native(&values, &filter);
2281    }
2282}