arrow_select/
filter.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines filter kernels
19
20use std::ops::AddAssign;
21use std::sync::Arc;
22
23use arrow_array::builder::BooleanBufferBuilder;
24use arrow_array::cast::AsArray;
25use arrow_array::types::{
26    ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
27};
28use arrow_array::*;
29use arrow_buffer::{bit_util, ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer};
30use arrow_buffer::{Buffer, MutableBuffer};
31use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
32use arrow_data::transform::MutableArrayData;
33use arrow_data::{ArrayData, ArrayDataBuilder};
34use arrow_schema::*;
35
/// Selectivity cut-off used to pick between the two lazy iteration strategies.
///
/// If the filter selects more than this fraction of rows, use
/// [`SlicesIterator`] to copy ranges of values. Otherwise iterate
/// over individual rows using [`IndexIterator`]
///
/// Threshold of 0.8 chosen based on <https://dl.acm.org/doi/abs/10.1145/3465998.3466009>
///
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
43
44/// An iterator of `(usize, usize)` each representing an interval
45/// `[start, end)` whose slots of a bitmap [Buffer] are true.
46///
47/// Each interval corresponds to a contiguous region of memory to be
48/// "taken" from an array to be filtered.
49///
50/// ## Notes:
51///
52/// 1. Ignores the validity bitmap (ignores nulls)
53///
54/// 2. Only performant for filters that copy across long contiguous runs
55#[derive(Debug)]
56pub struct SlicesIterator<'a>(BitSliceIterator<'a>);
57
58impl<'a> SlicesIterator<'a> {
59    /// Creates a new iterator from a [BooleanArray]
60    pub fn new(filter: &'a BooleanArray) -> Self {
61        Self(filter.values().set_slices())
62    }
63}
64
65impl Iterator for SlicesIterator<'_> {
66    type Item = (usize, usize);
67
68    fn next(&mut self) -> Option<Self::Item> {
69        self.0.next()
70    }
71}
72
73/// An iterator of `usize` whose index in [`BooleanArray`] is true
74///
75/// This provides the best performance on most predicates, apart from those which keep
76/// large runs and therefore favour [`SlicesIterator`]
77struct IndexIterator<'a> {
78    remaining: usize,
79    iter: BitIndexIterator<'a>,
80}
81
82impl<'a> IndexIterator<'a> {
83    fn new(filter: &'a BooleanArray, remaining: usize) -> Self {
84        assert_eq!(filter.null_count(), 0);
85        let iter = filter.values().set_indices();
86        Self { remaining, iter }
87    }
88}
89
impl Iterator for IndexIterator<'_> {
    type Item = usize;

    /// Yields the next set index, or `None` once `remaining` indices have been produced.
    ///
    /// # Panics
    ///
    /// Panics if the underlying bit iterator runs out while `remaining` is still non-zero.
    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining != 0 {
            // Fascinatingly swapping these two lines around results in a 50%
            // performance regression for some benchmarks
            let next = self.iter.next().expect("IndexIterator exhausted early");
            self.remaining -= 1;
            // Must panic if exhausted early as trusted length iterator
            return Some(next);
        }
        None
    }

    /// Reports `remaining` as both the lower and the upper bound, which is what
    /// callers treating this as a trusted-length iterator rely on.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}
109
110/// Counts the number of set bits in `filter`
111fn filter_count(filter: &BooleanArray) -> usize {
112    filter.values().count_set_bits()
113}
114
/// Function that can filter arbitrary arrays
///
/// It is a boxed closure taking an [`ArrayData`] by reference and returning the
/// filtered [`ArrayData`]; the lifetime `'a` bounds whatever the closure borrows.
///
/// Deprecated: Use [`FilterPredicate`] instead
#[deprecated]
pub type Filter<'a> = Box<dyn Fn(&ArrayData) -> ArrayData + 'a>;
120
121/// Returns a prepared function optimized to filter multiple arrays.
122///
123/// Creating this function requires time, but using it is faster than [filter] when the
124/// same filter needs to be applied to multiple arrays (e.g. a multi-column `RecordBatch`).
125/// WARNING: the nulls of `filter` are ignored and the value on its slot is considered.
126/// Therefore, it is considered undefined behavior to pass `filter` with null values.
127///
128/// Deprecated: Use [`FilterBuilder`] instead
129#[deprecated]
130#[allow(deprecated)]
131pub fn build_filter(filter: &BooleanArray) -> Result<Filter, ArrowError> {
132    let iter = SlicesIterator::new(filter);
133    let filter_count = filter_count(filter);
134    let chunks = iter.collect::<Vec<_>>();
135
136    Ok(Box::new(move |array: &ArrayData| {
137        match filter_count {
138            // return all
139            len if len == array.len() => array.clone(),
140            0 => ArrayData::new_empty(array.data_type()),
141            _ => {
142                let mut mutable = MutableArrayData::new(vec![array], false, filter_count);
143                chunks
144                    .iter()
145                    .for_each(|(start, end)| mutable.extend(0, *start, *end));
146                mutable.freeze()
147            }
148        }
149    }))
150}
151
152/// Remove null values by do a bitmask AND operation with null bits and the boolean bits.
153pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
154    let nulls = filter.nulls().unwrap();
155    let mask = filter.values() & nulls.inner();
156    BooleanArray::new(mask, None)
157}
158
159/// Returns a filtered `values` [`Array`] where the corresponding elements of
160/// `predicate` are `true`.
161///
162/// # See also
163/// * [`FilterBuilder`] for more control over the filtering process.
164/// * [`filter_record_batch`] to filter a [`RecordBatch`]
165/// * [`BatchCoalescer`]: to filter multiple [`RecordBatch`] and coalesce
166///   the results into a single array.
167///
168/// [`BatchCoalescer`]: crate::coalesce::BatchCoalescer
169///
170/// # Example
171/// ```rust
172/// # use arrow_array::{Int32Array, BooleanArray};
173/// # use arrow_select::filter::filter;
174/// let array = Int32Array::from(vec![5, 6, 7, 8, 9]);
175/// let filter_array = BooleanArray::from(vec![true, false, false, true, false]);
176/// let c = filter(&array, &filter_array).unwrap();
177/// let c = c.as_any().downcast_ref::<Int32Array>().unwrap();
178/// assert_eq!(c, &Int32Array::from(vec![5, 8]));
179/// ```
180pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
181    let mut filter_builder = FilterBuilder::new(predicate);
182
183    if multiple_arrays(values.data_type()) {
184        // Only optimize if filtering more than one array
185        // Otherwise, the overhead of optimization can be more than the benefit
186        filter_builder = filter_builder.optimize();
187    }
188
189    let predicate = filter_builder.build();
190
191    filter_array(values, &predicate)
192}
193
194fn multiple_arrays(data_type: &DataType) -> bool {
195    match data_type {
196        DataType::Struct(fields) => {
197            fields.len() > 1 || fields.len() == 1 && multiple_arrays(fields[0].data_type())
198        }
199        DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
200        _ => false,
201    }
202}
203
204/// Returns a filtered [RecordBatch] where the corresponding elements of
205/// `predicate` are true.
206///
207/// This is the equivalent of calling [filter] on each column of the [RecordBatch].
208pub fn filter_record_batch(
209    record_batch: &RecordBatch,
210    predicate: &BooleanArray,
211) -> Result<RecordBatch, ArrowError> {
212    let mut filter_builder = FilterBuilder::new(predicate);
213    if record_batch.num_columns() > 1 {
214        // Only optimize if filtering more than one column
215        // Otherwise, the overhead of optimization can be more than the benefit
216        filter_builder = filter_builder.optimize();
217    }
218    let filter = filter_builder.build();
219
220    let filtered_arrays = record_batch
221        .columns()
222        .iter()
223        .map(|a| filter_array(a, &filter))
224        .collect::<Result<Vec<_>, _>>()?;
225    let options = RecordBatchOptions::default().with_row_count(Some(filter.count()));
226    RecordBatch::try_new_with_options(record_batch.schema(), filtered_arrays, &options)
227}
228
229/// A builder to construct [`FilterPredicate`]
230#[derive(Debug)]
231pub struct FilterBuilder {
232    filter: BooleanArray,
233    count: usize,
234    strategy: IterationStrategy,
235}
236
237impl FilterBuilder {
238    /// Create a new [`FilterBuilder`] that can be used to construct a [`FilterPredicate`]
239    pub fn new(filter: &BooleanArray) -> Self {
240        let filter = match filter.null_count() {
241            0 => filter.clone(),
242            _ => prep_null_mask_filter(filter),
243        };
244
245        let count = filter_count(&filter);
246        let strategy = IterationStrategy::default_strategy(filter.len(), count);
247
248        Self {
249            filter,
250            count,
251            strategy,
252        }
253    }
254
255    /// Compute an optimised representation of the provided `filter` mask that can be
256    /// applied to an array more quickly.
257    ///
258    /// Note: There is limited benefit to calling this to then filter a single array
259    /// Note: This will likely have a larger memory footprint than the original mask
260    pub fn optimize(mut self) -> Self {
261        match self.strategy {
262            IterationStrategy::SlicesIterator => {
263                let slices = SlicesIterator::new(&self.filter).collect();
264                self.strategy = IterationStrategy::Slices(slices)
265            }
266            IterationStrategy::IndexIterator => {
267                let indices = IndexIterator::new(&self.filter, self.count).collect();
268                self.strategy = IterationStrategy::Indices(indices)
269            }
270            _ => {}
271        }
272        self
273    }
274
275    /// Construct the final `FilterPredicate`
276    pub fn build(self) -> FilterPredicate {
277        FilterPredicate {
278            filter: self.filter,
279            count: self.count,
280            strategy: self.strategy,
281        }
282    }
283}
284
285/// The iteration strategy used to evaluate [`FilterPredicate`]
286#[derive(Debug)]
287enum IterationStrategy {
288    /// A lazily evaluated iterator of ranges
289    SlicesIterator,
290    /// A lazily evaluated iterator of indices
291    IndexIterator,
292    /// A precomputed list of indices
293    Indices(Vec<usize>),
294    /// A precomputed array of ranges
295    Slices(Vec<(usize, usize)>),
296    /// Select all rows
297    All,
298    /// Select no rows
299    None,
300}
301
302impl IterationStrategy {
303    /// The default [`IterationStrategy`] for a filter of length `filter_length`
304    /// and selecting `filter_count` rows
305    fn default_strategy(filter_length: usize, filter_count: usize) -> Self {
306        if filter_length == 0 || filter_count == 0 {
307            return IterationStrategy::None;
308        }
309
310        if filter_count == filter_length {
311            return IterationStrategy::All;
312        }
313
314        // Compute the selectivity of the predicate by dividing the number of true
315        // bits in the predicate by the predicate's total length
316        //
317        // This can then be used as a heuristic for the optimal iteration strategy
318        let selectivity_frac = filter_count as f64 / filter_length as f64;
319        if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD {
320            return IterationStrategy::SlicesIterator;
321        }
322        IterationStrategy::IndexIterator
323    }
324}
325
/// A filtering predicate that can be applied to an [`Array`]
///
/// Instances are produced by [`FilterBuilder::build`].
#[derive(Debug)]
pub struct FilterPredicate {
    // The filter mask whose set bits mark the rows to keep
    filter: BooleanArray,
    // Number of rows selected; this is the value returned by `count()`
    count: usize,
    // How the selected rows are visited when filtering
    strategy: IterationStrategy,
}

impl FilterPredicate {
    /// Selects rows from `values` based on this [`FilterPredicate`]
    ///
    /// # Errors
    ///
    /// Returns an error if this predicate is longer than `values`.
    pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef, ArrowError> {
        filter_array(values, self)
    }

    /// Number of rows being selected based on this [`FilterPredicate`]
    pub fn count(&self) -> usize {
        self.count
    }
}
345
/// Applies `predicate` to `values`, dispatching on the data type of `values`.
///
/// The `None` and `All` strategies are answered directly: an empty array of the same
/// data type, or a slice of the first `predicate.count` rows. Every other strategy is
/// forwarded to a type-specific kernel, with a [`MutableArrayData`] based fallback for
/// data types that have no dedicated kernel here.
///
/// # Errors
///
/// Returns [`ArrowError::InvalidArgumentError`] if the predicate is longer than `values`.
fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> {
    // A predicate shorter than `values` is accepted; only a longer one is rejected
    if predicate.filter.len() > values.len() {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Filter predicate of length {} is larger than target array of length {}",
            predicate.filter.len(),
            values.len()
        )));
    }

    match predicate.strategy {
        IterationStrategy::None => Ok(new_empty_array(values.data_type())),
        IterationStrategy::All => Ok(values.slice(0, predicate.count)),
        // actually filter
        _ => downcast_primitive_array! {
            values => Ok(Arc::new(filter_primitive(values, predicate))),
            DataType::Boolean => {
                let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
                Ok(Arc::new(filter_boolean(values, predicate)))
            }
            DataType::Utf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i32>(), predicate)))
            }
            DataType::LargeUtf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
            }
            DataType::Utf8View => {
                Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
            }
            DataType::Binary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
            }
            DataType::LargeBinary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
            }
            DataType::BinaryView => {
                Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
            }
            DataType::FixedSizeBinary(_) => {
                Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate)))
            }
            DataType::RunEndEncoded(_, _) => {
                downcast_run_array!{
                    values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
                    t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t)
                }
            }
            DataType::Dictionary(_, _) => downcast_dictionary_array! {
                values => Ok(Arc::new(filter_dict(values, predicate))),
                t => unimplemented!("Filter not supported for dictionary type {:?}", t)
            }
            // Structs and sparse unions filter each of their children with the same predicate
            DataType::Struct(_) => {
                Ok(Arc::new(filter_struct(values.as_struct(), predicate)?))
            }
            DataType::Union(_, UnionMode::Sparse) => {
                Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?))
            }
            _ => {
                let data = values.to_data();
                // fallback to using MutableArrayData
                let mut mutable = MutableArrayData::new(
                    vec![&data],
                    false,
                    predicate.count,
                );

                match &predicate.strategy {
                    // Reuse the precomputed ranges when they are available
                    IterationStrategy::Slices(slices) => {
                        slices
                            .iter()
                            .for_each(|(start, end)| mutable.extend(0, *start, *end));
                    }
                    // Any other strategy: derive the ranges from the filter mask itself
                    _ => {
                        let iter = SlicesIterator::new(&predicate.filter);
                        iter.for_each(|(start, end)| mutable.extend(0, start, end));
                    }
                }

                let data = mutable.freeze();
                Ok(make_array(data))
            }
        },
    }
}
429
/// Filter any supported [`RunArray`] based on a [`FilterPredicate`]
///
/// Each run is examined in order: the selected rows falling inside the run are
/// counted, and the run is kept only if at least one of its rows is selected. The
/// kept runs get new run ends equal to the running total of selected rows, and the
/// values array is filtered with the resulting per-run mask.
///
/// # Errors
///
/// Returns an error if filtering the values array or constructing the resulting
/// [`RunArray`] fails.
fn filter_run_end_array<R: RunEndIndexType>(
    array: &RunArray<R>,
    predicate: &FilterPredicate,
) -> Result<RunArray<R>, ArrowError>
where
    R::Native: Into<i64> + From<bool>,
    R::Native: AddAssign,
{
    let run_ends: &RunEndBuffer<R::Native> = array.run_ends();
    // Sized for the case where every run is kept; truncated to the kept count below
    let mut new_run_ends = vec![R::default_value(); run_ends.len()];

    // `start` is the first logical row of the run currently being examined
    let mut start = 0u64;
    // `j` is the number of runs kept so far, i.e. the next write slot in `new_run_ends`
    let mut j = 0;
    // `count` is the running total of selected rows across all runs seen so far
    let mut count = R::default_value();
    let filter_values = predicate.filter.values();
    let run_ends = run_ends.inner();

    // Builds one bit per run: `true` if the run has at least one selected row
    let pred: BooleanArray = BooleanBuffer::collect_bool(run_ends.len(), |i| {
        let mut keep = false;
        let mut end = run_ends[i].into() as u64;
        // Clamp `end` to the length of the filter
        let difference = end.saturating_sub(filter_values.len() as u64);
        end -= difference;

        // Safety: we subtract the difference off `end` so we are always within bounds
        for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
            count += R::Native::from(pred);
            keep |= pred
        }
        // this is to avoid branching
        // (the slot is always written, but `j` only advances when the run is kept,
        // so the value for a dropped run is overwritten by the next run)
        new_run_ends[j] = count;
        j += keep as usize;

        start = end;
        keep
    })
    .into();

    new_run_ends.truncate(j);

    let values = array.values();
    let values = filter(&values, &pred)?;

    let run_ends = PrimitiveArray::<R>::new(new_run_ends.into(), None);
    RunArray::try_new(&run_ends, &values)
}
476
477/// Computes a new null mask for `data` based on `predicate`
478///
479/// If the predicate selected no null-rows, returns `None`, otherwise returns
480/// `Some((null_count, null_buffer))` where `null_count` is the number of nulls
481/// in the filtered output, and `null_buffer` is the filtered null buffer
482///
483fn filter_null_mask(
484    nulls: Option<&NullBuffer>,
485    predicate: &FilterPredicate,
486) -> Option<(usize, Buffer)> {
487    let nulls = nulls?;
488    if nulls.null_count() == 0 {
489        return None;
490    }
491
492    let nulls = filter_bits(nulls.inner(), predicate);
493    // The filtered `nulls` has a length of `predicate.count` bits and
494    // therefore the null count is this minus the number of valid bits
495    let null_count = predicate.count - nulls.count_set_bits_offset(0, predicate.count);
496
497    if null_count == 0 {
498        return None;
499    }
500
501    Some((null_count, nulls))
502}
503
504/// Filter the packed bitmask `buffer`, with `predicate` starting at bit offset `offset`
505fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer {
506    let src = buffer.values();
507    let offset = buffer.offset();
508
509    match &predicate.strategy {
510        IterationStrategy::IndexIterator => {
511            let bits = IndexIterator::new(&predicate.filter, predicate.count)
512                .map(|src_idx| bit_util::get_bit(src, src_idx + offset));
513
514            // SAFETY: `IndexIterator` reports its size correctly
515            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
516        }
517        IterationStrategy::Indices(indices) => {
518            let bits = indices
519                .iter()
520                .map(|src_idx| bit_util::get_bit(src, *src_idx + offset));
521
522            // SAFETY: `Vec::iter()` reports its size correctly
523            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
524        }
525        IterationStrategy::SlicesIterator => {
526            let mut builder = BooleanBufferBuilder::new(predicate.count);
527            for (start, end) in SlicesIterator::new(&predicate.filter) {
528                builder.append_packed_range(start + offset..end + offset, src)
529            }
530            builder.into()
531        }
532        IterationStrategy::Slices(slices) => {
533            let mut builder = BooleanBufferBuilder::new(predicate.count);
534            for (start, end) in slices {
535                builder.append_packed_range(*start + offset..*end + offset, src)
536            }
537            builder.into()
538        }
539        IterationStrategy::All | IterationStrategy::None => unreachable!(),
540    }
541}
542
543/// `filter` implementation for boolean buffers
544fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
545    let values = filter_bits(array.values(), predicate);
546
547    let mut builder = ArrayDataBuilder::new(DataType::Boolean)
548        .len(predicate.count)
549        .add_buffer(values);
550
551    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
552        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
553    }
554
555    let data = unsafe { builder.build_unchecked() };
556    BooleanArray::from(data)
557}
558
559#[inline(never)]
560fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer {
561    assert!(values.len() >= predicate.filter.len());
562
563    match &predicate.strategy {
564        IterationStrategy::SlicesIterator => {
565            let mut buffer = Vec::with_capacity(predicate.count);
566            for (start, end) in SlicesIterator::new(&predicate.filter) {
567                buffer.extend_from_slice(&values[start..end]);
568            }
569            buffer.into()
570        }
571        IterationStrategy::Slices(slices) => {
572            let mut buffer = Vec::with_capacity(predicate.count);
573            for (start, end) in slices {
574                buffer.extend_from_slice(&values[*start..*end]);
575            }
576            buffer.into()
577        }
578        IterationStrategy::IndexIterator => {
579            let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| values[x]);
580
581            // SAFETY: IndexIterator is trusted length
582            unsafe { MutableBuffer::from_trusted_len_iter(iter) }.into()
583        }
584        IterationStrategy::Indices(indices) => {
585            let iter = indices.iter().map(|x| values[*x]);
586            iter.collect::<Vec<_>>().into()
587        }
588        IterationStrategy::All | IterationStrategy::None => unreachable!(),
589    }
590}
591
592/// `filter` implementation for primitive arrays
593fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
594where
595    T: ArrowPrimitiveType,
596{
597    let values = array.values();
598    let buffer = filter_native(values, predicate);
599    let mut builder = ArrayDataBuilder::new(array.data_type().clone())
600        .len(predicate.count)
601        .add_buffer(buffer);
602
603    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
604        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
605    }
606
607    let data = unsafe { builder.build_unchecked() };
608    PrimitiveArray::from(data)
609}
610
611/// [`FilterBytes`] is created from a source [`GenericByteArray`] and can be
612/// used to build a new [`GenericByteArray`] by copying values from the source
613///
614/// TODO(raphael): Could this be used for the take kernel as well?
615struct FilterBytes<'a, OffsetSize> {
616    src_offsets: &'a [OffsetSize],
617    src_values: &'a [u8],
618    dst_offsets: Vec<OffsetSize>,
619    dst_values: Vec<u8>,
620    cur_offset: OffsetSize,
621}
622
623impl<'a, OffsetSize> FilterBytes<'a, OffsetSize>
624where
625    OffsetSize: OffsetSizeTrait,
626{
627    fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self
628    where
629        T: ByteArrayType<Offset = OffsetSize>,
630    {
631        let dst_values = Vec::new();
632        let mut dst_offsets: Vec<OffsetSize> = Vec::with_capacity(capacity + 1);
633        let cur_offset = OffsetSize::from_usize(0).unwrap();
634
635        dst_offsets.push(cur_offset);
636
637        Self {
638            src_offsets: array.value_offsets(),
639            src_values: array.value_data(),
640            dst_offsets,
641            dst_values,
642            cur_offset,
643        }
644    }
645
646    /// Returns the byte offset at `idx`
647    #[inline]
648    fn get_value_offset(&self, idx: usize) -> usize {
649        self.src_offsets[idx].as_usize()
650    }
651
652    /// Returns the start and end of the value at index `idx` along with its length
653    #[inline]
654    fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
655        // These can only fail if `array` contains invalid data
656        let start = self.get_value_offset(idx);
657        let end = self.get_value_offset(idx + 1);
658        let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
659        (start, end, len)
660    }
661
662    fn extend_offsets_idx(&mut self, iter: impl Iterator<Item = usize>) {
663        self.dst_offsets.extend(iter.map(|idx| {
664            let start = self.src_offsets[idx].as_usize();
665            let end = self.src_offsets[idx + 1].as_usize();
666            let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
667            self.cur_offset += len;
668
669            self.cur_offset
670        }));
671    }
672
673    /// Extends the in-progress array by the indexes in the provided iterator
674    fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
675        self.dst_values.reserve_exact(self.cur_offset.as_usize());
676
677        for idx in iter {
678            let start = self.src_offsets[idx].as_usize();
679            let end = self.src_offsets[idx + 1].as_usize();
680            self.dst_values
681                .extend_from_slice(&self.src_values[start..end]);
682        }
683    }
684
685    fn extend_offsets_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>, count: usize) {
686        self.dst_offsets.reserve_exact(count);
687        for (start, end) in iter {
688            // These can only fail if `array` contains invalid data
689            for idx in start..end {
690                let (_, _, len) = self.get_value_range(idx);
691                self.cur_offset += len;
692                self.dst_offsets.push(self.cur_offset);
693            }
694        }
695    }
696
697    /// Extends the in-progress array by the ranges in the provided iterator
698    fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
699        self.dst_values.reserve_exact(self.cur_offset.as_usize());
700
701        for (start, end) in iter {
702            let value_start = self.get_value_offset(start);
703            let value_end = self.get_value_offset(end);
704            self.dst_values
705                .extend_from_slice(&self.src_values[value_start..value_end]);
706        }
707    }
708}
709
710/// `filter` implementation for byte arrays
711///
712/// Note: NULLs with a non-zero slot length in `array` will have the corresponding
713/// data copied across. This allows handling the null mask separately from the data
714fn filter_bytes<T>(array: &GenericByteArray<T>, predicate: &FilterPredicate) -> GenericByteArray<T>
715where
716    T: ByteArrayType,
717{
718    let mut filter = FilterBytes::new(predicate.count, array);
719
720    match &predicate.strategy {
721        IterationStrategy::SlicesIterator => {
722            filter.extend_offsets_slices(SlicesIterator::new(&predicate.filter), predicate.count);
723            filter.extend_slices(SlicesIterator::new(&predicate.filter))
724        }
725        IterationStrategy::Slices(slices) => {
726            filter.extend_offsets_slices(slices.iter().cloned(), predicate.count);
727            filter.extend_slices(slices.iter().cloned())
728        }
729        IterationStrategy::IndexIterator => {
730            filter.extend_offsets_idx(IndexIterator::new(&predicate.filter, predicate.count));
731            filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count))
732        }
733        IterationStrategy::Indices(indices) => {
734            filter.extend_offsets_idx(indices.iter().cloned());
735            filter.extend_idx(indices.iter().cloned())
736        }
737        IterationStrategy::All | IterationStrategy::None => unreachable!(),
738    }
739
740    let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
741        .len(predicate.count)
742        .add_buffer(filter.dst_offsets.into())
743        .add_buffer(filter.dst_values.into());
744
745    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
746        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
747    }
748
749    let data = unsafe { builder.build_unchecked() };
750    GenericByteArray::from(data)
751}
752
753/// `filter` implementation for byte view arrays.
754fn filter_byte_view<T: ByteViewType>(
755    array: &GenericByteViewArray<T>,
756    predicate: &FilterPredicate,
757) -> GenericByteViewArray<T> {
758    let new_view_buffer = filter_native(array.views(), predicate);
759
760    let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
761        .len(predicate.count)
762        .add_buffer(new_view_buffer)
763        .add_buffers(array.data_buffers().to_vec());
764
765    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
766        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
767    }
768
769    GenericByteViewArray::from(unsafe { builder.build_unchecked() })
770}
771
772fn filter_fixed_size_binary(
773    array: &FixedSizeBinaryArray,
774    predicate: &FilterPredicate,
775) -> FixedSizeBinaryArray {
776    let values: &[u8] = array.values();
777    let value_length = array.value_length() as usize;
778    let calculate_offset_from_index = |index: usize| index * value_length;
779    let buffer = match &predicate.strategy {
780        IterationStrategy::SlicesIterator => {
781            let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
782            for (start, end) in SlicesIterator::new(&predicate.filter) {
783                buffer.extend_from_slice(
784                    &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)],
785                );
786            }
787            buffer
788        }
789        IterationStrategy::Slices(slices) => {
790            let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
791            for (start, end) in slices {
792                buffer.extend_from_slice(
793                    &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)],
794                );
795            }
796            buffer
797        }
798        IterationStrategy::IndexIterator => {
799            let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| {
800                &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)]
801            });
802
803            let mut buffer = MutableBuffer::new(predicate.count * value_length);
804            iter.for_each(|item| buffer.extend_from_slice(item));
805            buffer
806        }
807        IterationStrategy::Indices(indices) => {
808            let iter = indices.iter().map(|x| {
809                &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)]
810            });
811
812            let mut buffer = MutableBuffer::new(predicate.count * value_length);
813            iter.for_each(|item| buffer.extend_from_slice(item));
814            buffer
815        }
816        IterationStrategy::All | IterationStrategy::None => unreachable!(),
817    };
818    let mut builder = ArrayDataBuilder::new(array.data_type().clone())
819        .len(predicate.count)
820        .add_buffer(buffer.into());
821
822    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
823        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
824    }
825
826    let data = unsafe { builder.build_unchecked() };
827    FixedSizeBinaryArray::from(data)
828}
829
830/// `filter` implementation for dictionaries
831fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
832where
833    T: ArrowDictionaryKeyType,
834    T::Native: num::Num,
835{
836    let builder = filter_primitive::<T>(array.keys(), predicate)
837        .into_data()
838        .into_builder()
839        .data_type(array.data_type().clone())
840        .child_data(vec![array.values().to_data()]);
841
842    // SAFETY:
843    // Keys were valid before, filtered subset is therefore still valid
844    DictionaryArray::from(unsafe { builder.build_unchecked() })
845}
846
847/// `filter` implementation for structs
848fn filter_struct(
849    array: &StructArray,
850    predicate: &FilterPredicate,
851) -> Result<StructArray, ArrowError> {
852    let columns = array
853        .columns()
854        .iter()
855        .map(|column| filter_array(column, predicate))
856        .collect::<Result<_, _>>()?;
857
858    let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
859        let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
860
861        Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
862    } else {
863        None
864    };
865
866    Ok(unsafe {
867        StructArray::new_unchecked_with_length(
868            array.fields().clone(),
869            columns,
870            nulls,
871            predicate.count(),
872        )
873    })
874}
875
876/// `filter` implementation for sparse unions
877fn filter_sparse_union(
878    array: &UnionArray,
879    predicate: &FilterPredicate,
880) -> Result<UnionArray, ArrowError> {
881    let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
882        unreachable!()
883    };
884
885    let type_ids = filter_primitive(&Int8Array::new(array.type_ids().clone(), None), predicate);
886
887    let children = fields
888        .iter()
889        .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
890        .collect::<Result<_, _>>()?;
891
892    Ok(unsafe {
893        UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
894    })
895}
896
897#[cfg(test)]
898mod tests {
899    use arrow_array::builder::*;
900    use arrow_array::cast::as_run_array;
901    use arrow_array::types::*;
902    use rand::distr::uniform::{UniformSampler, UniformUsize};
903    use rand::distr::{Alphanumeric, StandardUniform};
904    use rand::prelude::*;
905    use rand::rng;
906
907    use super::*;
908
909    macro_rules! def_temporal_test {
910        ($test:ident, $array_type: ident, $data: expr) => {
911            #[test]
912            fn $test() {
913                let a = $data;
914                let b = BooleanArray::from(vec![true, false, true, false]);
915                let c = filter(&a, &b).unwrap();
916                let d = c.as_ref().as_any().downcast_ref::<$array_type>().unwrap();
917                assert_eq!(2, d.len());
918                assert_eq!(1, d.value(0));
919                assert_eq!(3, d.value(1));
920            }
921        };
922    }
923
924    def_temporal_test!(
925        test_filter_date32,
926        Date32Array,
927        Date32Array::from(vec![1, 2, 3, 4])
928    );
929    def_temporal_test!(
930        test_filter_date64,
931        Date64Array,
932        Date64Array::from(vec![1, 2, 3, 4])
933    );
934    def_temporal_test!(
935        test_filter_time32_second,
936        Time32SecondArray,
937        Time32SecondArray::from(vec![1, 2, 3, 4])
938    );
939    def_temporal_test!(
940        test_filter_time32_millisecond,
941        Time32MillisecondArray,
942        Time32MillisecondArray::from(vec![1, 2, 3, 4])
943    );
944    def_temporal_test!(
945        test_filter_time64_microsecond,
946        Time64MicrosecondArray,
947        Time64MicrosecondArray::from(vec![1, 2, 3, 4])
948    );
949    def_temporal_test!(
950        test_filter_time64_nanosecond,
951        Time64NanosecondArray,
952        Time64NanosecondArray::from(vec![1, 2, 3, 4])
953    );
954    def_temporal_test!(
955        test_filter_duration_second,
956        DurationSecondArray,
957        DurationSecondArray::from(vec![1, 2, 3, 4])
958    );
959    def_temporal_test!(
960        test_filter_duration_millisecond,
961        DurationMillisecondArray,
962        DurationMillisecondArray::from(vec![1, 2, 3, 4])
963    );
964    def_temporal_test!(
965        test_filter_duration_microsecond,
966        DurationMicrosecondArray,
967        DurationMicrosecondArray::from(vec![1, 2, 3, 4])
968    );
969    def_temporal_test!(
970        test_filter_duration_nanosecond,
971        DurationNanosecondArray,
972        DurationNanosecondArray::from(vec![1, 2, 3, 4])
973    );
974    def_temporal_test!(
975        test_filter_timestamp_second,
976        TimestampSecondArray,
977        TimestampSecondArray::from(vec![1, 2, 3, 4])
978    );
979    def_temporal_test!(
980        test_filter_timestamp_millisecond,
981        TimestampMillisecondArray,
982        TimestampMillisecondArray::from(vec![1, 2, 3, 4])
983    );
984    def_temporal_test!(
985        test_filter_timestamp_microsecond,
986        TimestampMicrosecondArray,
987        TimestampMicrosecondArray::from(vec![1, 2, 3, 4])
988    );
989    def_temporal_test!(
990        test_filter_timestamp_nanosecond,
991        TimestampNanosecondArray,
992        TimestampNanosecondArray::from(vec![1, 2, 3, 4])
993    );
994
995    #[test]
996    fn test_filter_array_slice() {
997        let a = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
998        let b = BooleanArray::from(vec![true, false, false, true]);
999        // filtering with sliced filter array is not currently supported
1000        // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4);
1001        // let b = b_slice.as_any().downcast_ref().unwrap();
1002        let c = filter(&a, &b).unwrap();
1003        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1004        assert_eq!(2, d.len());
1005        assert_eq!(6, d.value(0));
1006        assert_eq!(9, d.value(1));
1007    }
1008
1009    #[test]
1010    fn test_filter_array_low_density() {
1011        // this test exercises the all 0's branch of the filter algorithm
1012        let mut data_values = (1..=65).collect::<Vec<i32>>();
1013        let mut filter_values = (1..=65).map(|i| matches!(i % 65, 0)).collect::<Vec<bool>>();
1014        // set up two more values after the batch
1015        data_values.extend_from_slice(&[66, 67]);
1016        filter_values.extend_from_slice(&[false, true]);
1017        let a = Int32Array::from(data_values);
1018        let b = BooleanArray::from(filter_values);
1019        let c = filter(&a, &b).unwrap();
1020        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1021        assert_eq!(2, d.len());
1022        assert_eq!(65, d.value(0));
1023        assert_eq!(67, d.value(1));
1024    }
1025
1026    #[test]
1027    fn test_filter_array_high_density() {
1028        // this test exercises the all 1's branch of the filter algorithm
1029        let mut data_values = (1..=65).map(Some).collect::<Vec<_>>();
1030        let mut filter_values = (1..=65)
1031            .map(|i| !matches!(i % 65, 0))
1032            .collect::<Vec<bool>>();
1033        // set second data value to null
1034        data_values[1] = None;
1035        // set up two more values after the batch
1036        data_values.extend_from_slice(&[Some(66), None, Some(67), None]);
1037        filter_values.extend_from_slice(&[false, true, true, true]);
1038        let a = Int32Array::from(data_values);
1039        let b = BooleanArray::from(filter_values);
1040        let c = filter(&a, &b).unwrap();
1041        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1042        assert_eq!(67, d.len());
1043        assert_eq!(3, d.null_count());
1044        assert_eq!(1, d.value(0));
1045        assert!(d.is_null(1));
1046        assert_eq!(64, d.value(63));
1047        assert!(d.is_null(64));
1048        assert_eq!(67, d.value(65));
1049    }
1050
1051    #[test]
1052    fn test_filter_string_array_simple() {
1053        let a = StringArray::from(vec!["hello", " ", "world", "!"]);
1054        let b = BooleanArray::from(vec![true, false, true, false]);
1055        let c = filter(&a, &b).unwrap();
1056        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1057        assert_eq!(2, d.len());
1058        assert_eq!("hello", d.value(0));
1059        assert_eq!("world", d.value(1));
1060    }
1061
1062    #[test]
1063    fn test_filter_primitive_array_with_null() {
1064        let a = Int32Array::from(vec![Some(5), None]);
1065        let b = BooleanArray::from(vec![false, true]);
1066        let c = filter(&a, &b).unwrap();
1067        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1068        assert_eq!(1, d.len());
1069        assert!(d.is_null(0));
1070    }
1071
1072    #[test]
1073    fn test_filter_string_array_with_null() {
1074        let a = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
1075        let b = BooleanArray::from(vec![true, false, false, true]);
1076        let c = filter(&a, &b).unwrap();
1077        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1078        assert_eq!(2, d.len());
1079        assert_eq!("hello", d.value(0));
1080        assert!(!d.is_null(0));
1081        assert!(d.is_null(1));
1082    }
1083
1084    #[test]
1085    fn test_filter_binary_array_with_null() {
1086        let data: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
1087        let a = BinaryArray::from(data);
1088        let b = BooleanArray::from(vec![true, false, false, true]);
1089        let c = filter(&a, &b).unwrap();
1090        let d = c.as_ref().as_any().downcast_ref::<BinaryArray>().unwrap();
1091        assert_eq!(2, d.len());
1092        assert_eq!(b"hello", d.value(0));
1093        assert!(!d.is_null(0));
1094        assert!(d.is_null(1));
1095    }
1096
1097    fn _test_filter_byte_view<T>()
1098    where
1099        T: ByteViewType,
1100        str: AsRef<T::Native>,
1101        T::Native: PartialEq,
1102    {
1103        let array = {
1104            // ["hello", "world", null, "large payload over 12 bytes", "lulu"]
1105            let mut builder = GenericByteViewBuilder::<T>::new();
1106            builder.append_value("hello");
1107            builder.append_value("world");
1108            builder.append_null();
1109            builder.append_value("large payload over 12 bytes");
1110            builder.append_value("lulu");
1111            builder.finish()
1112        };
1113
1114        {
1115            let predicate = BooleanArray::from(vec![true, false, true, true, false]);
1116            let actual = filter(&array, &predicate).unwrap();
1117
1118            assert_eq!(actual.len(), 3);
1119
1120            let expected = {
1121                // ["hello", null, "large payload over 12 bytes"]
1122                let mut builder = GenericByteViewBuilder::<T>::new();
1123                builder.append_value("hello");
1124                builder.append_null();
1125                builder.append_value("large payload over 12 bytes");
1126                builder.finish()
1127            };
1128
1129            assert_eq!(actual.as_ref(), &expected);
1130        }
1131
1132        {
1133            let predicate = BooleanArray::from(vec![true, false, false, false, true]);
1134            let actual = filter(&array, &predicate).unwrap();
1135
1136            assert_eq!(actual.len(), 2);
1137
1138            let expected = {
1139                // ["hello", "lulu"]
1140                let mut builder = GenericByteViewBuilder::<T>::new();
1141                builder.append_value("hello");
1142                builder.append_value("lulu");
1143                builder.finish()
1144            };
1145
1146            assert_eq!(actual.as_ref(), &expected);
1147        }
1148    }
1149
1150    #[test]
1151    fn test_filter_string_view() {
1152        _test_filter_byte_view::<StringViewType>()
1153    }
1154
1155    #[test]
1156    fn test_filter_binary_view() {
1157        _test_filter_byte_view::<BinaryViewType>()
1158    }
1159
1160    #[test]
1161    fn test_filter_fixed_binary() {
1162        let v1 = [1_u8, 2];
1163        let v2 = [3_u8, 4];
1164        let v3 = [5_u8, 6];
1165        let v = vec![&v1, &v2, &v3];
1166        let a = FixedSizeBinaryArray::from(v);
1167        let b = BooleanArray::from(vec![true, false, true]);
1168        let c = filter(&a, &b).unwrap();
1169        let d = c
1170            .as_ref()
1171            .as_any()
1172            .downcast_ref::<FixedSizeBinaryArray>()
1173            .unwrap();
1174        assert_eq!(d.len(), 2);
1175        assert_eq!(d.value(0), &v1);
1176        assert_eq!(d.value(1), &v3);
1177        let c2 = FilterBuilder::new(&b)
1178            .optimize()
1179            .build()
1180            .filter(&a)
1181            .unwrap();
1182        let d2 = c2
1183            .as_ref()
1184            .as_any()
1185            .downcast_ref::<FixedSizeBinaryArray>()
1186            .unwrap();
1187        assert_eq!(d, d2);
1188
1189        let b = BooleanArray::from(vec![false, false, false]);
1190        let c = filter(&a, &b).unwrap();
1191        let d = c
1192            .as_ref()
1193            .as_any()
1194            .downcast_ref::<FixedSizeBinaryArray>()
1195            .unwrap();
1196        assert_eq!(d.len(), 0);
1197
1198        let b = BooleanArray::from(vec![true, true, true]);
1199        let c = filter(&a, &b).unwrap();
1200        let d = c
1201            .as_ref()
1202            .as_any()
1203            .downcast_ref::<FixedSizeBinaryArray>()
1204            .unwrap();
1205        assert_eq!(d.len(), 3);
1206        assert_eq!(d.value(0), &v1);
1207        assert_eq!(d.value(1), &v2);
1208        assert_eq!(d.value(2), &v3);
1209
1210        let b = BooleanArray::from(vec![false, false, true]);
1211        let c = filter(&a, &b).unwrap();
1212        let d = c
1213            .as_ref()
1214            .as_any()
1215            .downcast_ref::<FixedSizeBinaryArray>()
1216            .unwrap();
1217        assert_eq!(d.len(), 1);
1218        assert_eq!(d.value(0), &v3);
1219        let c2 = FilterBuilder::new(&b)
1220            .optimize()
1221            .build()
1222            .filter(&a)
1223            .unwrap();
1224        let d2 = c2
1225            .as_ref()
1226            .as_any()
1227            .downcast_ref::<FixedSizeBinaryArray>()
1228            .unwrap();
1229        assert_eq!(d, d2);
1230    }
1231
1232    #[test]
1233    fn test_filter_array_slice_with_null() {
1234        let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
1235        let b = BooleanArray::from(vec![true, false, false, true]);
1236        // filtering with sliced filter array is not currently supported
1237        // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4);
1238        // let b = b_slice.as_any().downcast_ref().unwrap();
1239        let c = filter(&a, &b).unwrap();
1240        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1241        assert_eq!(2, d.len());
1242        assert!(d.is_null(0));
1243        assert!(!d.is_null(1));
1244        assert_eq!(9, d.value(1));
1245    }
1246
1247    #[test]
1248    fn test_filter_run_end_encoding_array() {
1249        let run_ends = Int64Array::from(vec![2, 3, 8]);
1250        let values = Int64Array::from(vec![7, -2, 9]);
1251        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1252        let b = BooleanArray::from(vec![true, false, true, false, true, false, true, false]);
1253        let c = filter(&a, &b).unwrap();
1254        let actual: &RunArray<Int64Type> = as_run_array(&c);
1255        assert_eq!(4, actual.len());
1256
1257        let expected = RunArray::try_new(
1258            &Int64Array::from(vec![1, 2, 4]),
1259            &Int64Array::from(vec![7, -2, 9]),
1260        )
1261        .expect("Failed to make expected RunArray test is broken");
1262
1263        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1264        assert_eq!(actual.values(), expected.values())
1265    }
1266
1267    #[test]
1268    fn test_filter_run_end_encoding_array_remove_value() {
1269        let run_ends = Int32Array::from(vec![2, 3, 8, 10]);
1270        let values = Int32Array::from(vec![7, -2, 9, -8]);
1271        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1272        let b = BooleanArray::from(vec![
1273            false, true, false, false, true, false, true, false, false, false,
1274        ]);
1275        let c = filter(&a, &b).unwrap();
1276        let actual: &RunArray<Int32Type> = as_run_array(&c);
1277        assert_eq!(3, actual.len());
1278
1279        let expected =
1280            RunArray::try_new(&Int32Array::from(vec![1, 3]), &Int32Array::from(vec![7, 9]))
1281                .expect("Failed to make expected RunArray test is broken");
1282
1283        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1284        assert_eq!(actual.values(), expected.values())
1285    }
1286
1287    #[test]
1288    fn test_filter_run_end_encoding_array_remove_all_but_one() {
1289        let run_ends = Int16Array::from(vec![2, 3, 8, 10]);
1290        let values = Int16Array::from(vec![7, -2, 9, -8]);
1291        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1292        let b = BooleanArray::from(vec![
1293            false, false, false, false, false, false, true, false, false, false,
1294        ]);
1295        let c = filter(&a, &b).unwrap();
1296        let actual: &RunArray<Int16Type> = as_run_array(&c);
1297        assert_eq!(1, actual.len());
1298
1299        let expected = RunArray::try_new(&Int16Array::from(vec![1]), &Int16Array::from(vec![9]))
1300            .expect("Failed to make expected RunArray test is broken");
1301
1302        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1303        assert_eq!(actual.values(), expected.values())
1304    }
1305
1306    #[test]
1307    fn test_filter_run_end_encoding_array_empty() {
1308        let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1309        let values = Int64Array::from(vec![7, -2, 9, -8]);
1310        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1311        let b = BooleanArray::from(vec![
1312            false, false, false, false, false, false, false, false, false, false,
1313        ]);
1314        let c = filter(&a, &b).unwrap();
1315        let actual: &RunArray<Int64Type> = as_run_array(&c);
1316        assert_eq!(0, actual.len());
1317    }
1318
1319    #[test]
1320    fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() {
1321        let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1322        let values = Int64Array::from(vec![7, -2, 9, -8]);
1323        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1324        let b = BooleanArray::from(vec![false, true, true]);
1325        let c = filter(&a, &b).unwrap();
1326        let actual: &RunArray<Int64Type> = as_run_array(&c);
1327        assert_eq!(2, actual.len());
1328
1329        let expected = RunArray::try_new(
1330            &Int64Array::from(vec![1, 2]),
1331            &Int64Array::from(vec![7, -2]),
1332        )
1333        .expect("Failed to make expected RunArray test is broken");
1334
1335        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1336        assert_eq!(actual.values(), expected.values())
1337    }
1338
1339    #[test]
1340    fn test_filter_dictionary_array() {
1341        let values = [Some("hello"), None, Some("world"), Some("!")];
1342        let a: Int8DictionaryArray = values.iter().copied().collect();
1343        let b = BooleanArray::from(vec![false, true, true, false]);
1344        let c = filter(&a, &b).unwrap();
1345        let d = c
1346            .as_ref()
1347            .as_any()
1348            .downcast_ref::<Int8DictionaryArray>()
1349            .unwrap();
1350        let value_array = d.values();
1351        let values = value_array.as_any().downcast_ref::<StringArray>().unwrap();
1352        // values are cloned in the filtered dictionary array
1353        assert_eq!(3, values.len());
1354        // but keys are filtered
1355        assert_eq!(2, d.len());
1356        assert!(d.is_null(0));
1357        assert_eq!("world", values.value(d.keys().value(1) as usize));
1358    }
1359
1360    #[test]
1361    fn test_filter_list_array() {
1362        let value_data = ArrayData::builder(DataType::Int32)
1363            .len(8)
1364            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
1365            .build()
1366            .unwrap();
1367
1368        let value_offsets = Buffer::from_slice_ref([0i64, 3, 6, 8, 8]);
1369
1370        let list_data_type =
1371            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
1372        let list_data = ArrayData::builder(list_data_type)
1373            .len(4)
1374            .add_buffer(value_offsets)
1375            .add_child_data(value_data)
1376            .null_bit_buffer(Some(Buffer::from([0b00000111])))
1377            .build()
1378            .unwrap();
1379
1380        //  a = [[0, 1, 2], [3, 4, 5], [6, 7], null]
1381        let a = LargeListArray::from(list_data);
1382        let b = BooleanArray::from(vec![false, true, false, true]);
1383        let result = filter(&a, &b).unwrap();
1384
1385        // expected: [[3, 4, 5], null]
1386        let value_data = ArrayData::builder(DataType::Int32)
1387            .len(3)
1388            .add_buffer(Buffer::from_slice_ref([3, 4, 5]))
1389            .build()
1390            .unwrap();
1391
1392        let value_offsets = Buffer::from_slice_ref([0i64, 3, 3]);
1393
1394        let list_data_type =
1395            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
1396        let expected = ArrayData::builder(list_data_type)
1397            .len(2)
1398            .add_buffer(value_offsets)
1399            .add_child_data(value_data)
1400            .null_bit_buffer(Some(Buffer::from([0b00000001])))
1401            .build()
1402            .unwrap();
1403
1404        assert_eq!(&make_array(expected), &result);
1405    }
1406
1407    #[test]
1408    fn test_slice_iterator_bits() {
1409        let filter_values = (0..64).map(|i| i == 1).collect::<Vec<bool>>();
1410        let filter = BooleanArray::from(filter_values);
1411        let filter_count = filter_count(&filter);
1412
1413        let iter = SlicesIterator::new(&filter);
1414        let chunks = iter.collect::<Vec<_>>();
1415
1416        assert_eq!(chunks, vec![(1, 2)]);
1417        assert_eq!(filter_count, 1);
1418    }
1419
1420    #[test]
1421    fn test_slice_iterator_bits1() {
1422        let filter_values = (0..64).map(|i| i != 1).collect::<Vec<bool>>();
1423        let filter = BooleanArray::from(filter_values);
1424        let filter_count = filter_count(&filter);
1425
1426        let iter = SlicesIterator::new(&filter);
1427        let chunks = iter.collect::<Vec<_>>();
1428
1429        assert_eq!(chunks, vec![(0, 1), (2, 64)]);
1430        assert_eq!(filter_count, 64 - 1);
1431    }
1432
1433    #[test]
1434    fn test_slice_iterator_chunk_and_bits() {
1435        let filter_values = (0..130).map(|i| i % 62 != 0).collect::<Vec<bool>>();
1436        let filter = BooleanArray::from(filter_values);
1437        let filter_count = filter_count(&filter);
1438
1439        let iter = SlicesIterator::new(&filter);
1440        let chunks = iter.collect::<Vec<_>>();
1441
1442        assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]);
1443        assert_eq!(filter_count, 61 + 61 + 5);
1444    }
1445
1446    #[test]
1447    fn test_null_mask() {
1448        let a = Int64Array::from(vec![Some(1), Some(2), None]);
1449
1450        let mask1 = BooleanArray::from(vec![Some(true), Some(true), None]);
1451        let out = filter(&a, &mask1).unwrap();
1452        assert_eq!(out.as_ref(), &a.slice(0, 2));
1453    }
1454
1455    #[test]
1456    fn test_filter_record_batch_no_columns() {
1457        let pred = BooleanArray::from(vec![Some(true), Some(true), None]);
1458        let options = RecordBatchOptions::default().with_row_count(Some(100));
1459        let record_batch =
1460            RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &options).unwrap();
1461        let out = filter_record_batch(&record_batch, &pred).unwrap();
1462
1463        assert_eq!(out.num_rows(), 2);
1464    }
1465
1466    #[test]
1467    fn test_fast_path() {
1468        let a: PrimitiveArray<Int64Type> = PrimitiveArray::from(vec![Some(1), Some(2), None]);
1469
1470        // all true
1471        let mask = BooleanArray::from(vec![true, true, true]);
1472        let out = filter(&a, &mask).unwrap();
1473        let b = out
1474            .as_any()
1475            .downcast_ref::<PrimitiveArray<Int64Type>>()
1476            .unwrap();
1477        assert_eq!(&a, b);
1478
1479        // all false
1480        let mask = BooleanArray::from(vec![false, false, false]);
1481        let out = filter(&a, &mask).unwrap();
1482        assert_eq!(out.len(), 0);
1483        assert_eq!(out.data_type(), &DataType::Int64);
1484    }
1485
1486    #[test]
1487    fn test_slices() {
1488        // takes up 2 u64s
1489        let bools = std::iter::repeat(true)
1490            .take(10)
1491            .chain(std::iter::repeat(false).take(30))
1492            .chain(std::iter::repeat(true).take(20))
1493            .chain(std::iter::repeat(false).take(17))
1494            .chain(std::iter::repeat(true).take(4));
1495
1496        let bool_array: BooleanArray = bools.map(Some).collect();
1497
1498        let slices: Vec<_> = SlicesIterator::new(&bool_array).collect();
1499        let expected = vec![(0, 10), (40, 60), (77, 81)];
1500        assert_eq!(slices, expected);
1501
1502        // slice with offset and truncated len
1503        let len = bool_array.len();
1504        let sliced_array = bool_array.slice(7, len - 10);
1505        let sliced_array = sliced_array
1506            .as_any()
1507            .downcast_ref::<BooleanArray>()
1508            .unwrap();
1509        let slices: Vec<_> = SlicesIterator::new(sliced_array).collect();
1510        let expected = vec![(0, 3), (33, 53), (70, 71)];
1511        assert_eq!(slices, expected);
1512    }
1513
1514    fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) {
1515        let mut rng = rng();
1516
1517        let bools: Vec<bool> = std::iter::from_fn(|| Some(rng.random()))
1518            .take(mask_len)
1519            .collect();
1520
1521        let buffer = Buffer::from_iter(bools.iter().cloned());
1522
1523        let truncated_length = mask_len - offset - truncate;
1524
1525        let data = ArrayDataBuilder::new(DataType::Boolean)
1526            .len(truncated_length)
1527            .offset(offset)
1528            .add_buffer(buffer)
1529            .build()
1530            .unwrap();
1531
1532        let filter = BooleanArray::from(data);
1533
1534        let slice_bits: Vec<_> = SlicesIterator::new(&filter)
1535            .flat_map(|(start, end)| start..end)
1536            .collect();
1537
1538        let count = filter_count(&filter);
1539        let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect();
1540
1541        let expected_bits: Vec<_> = bools
1542            .iter()
1543            .skip(offset)
1544            .take(truncated_length)
1545            .enumerate()
1546            .flat_map(|(idx, v)| v.then(|| idx))
1547            .collect();
1548
1549        assert_eq!(slice_bits, expected_bits);
1550        assert_eq!(index_bits, expected_bits);
1551    }
1552
1553    #[test]
1554    #[cfg_attr(miri, ignore)]
1555    fn fuzz_test_slices_iterator() {
1556        let mut rng = rng();
1557
1558        let uusize = UniformUsize::new(usize::MIN, usize::MAX).unwrap();
1559        for _ in 0..100 {
1560            let mask_len = rng.random_range(0..1024);
1561            let max_offset = 64.min(mask_len);
1562            let offset = uusize.sample(&mut rng).checked_rem(max_offset).unwrap_or(0);
1563
1564            let max_truncate = 128.min(mask_len - offset);
1565            let truncate = uusize
1566                .sample(&mut rng)
1567                .checked_rem(max_truncate)
1568                .unwrap_or(0);
1569
1570            test_slices_fuzz(mask_len, offset, truncate);
1571        }
1572
1573        test_slices_fuzz(64, 0, 0);
1574        test_slices_fuzz(64, 8, 0);
1575        test_slices_fuzz(64, 8, 8);
1576        test_slices_fuzz(32, 8, 8);
1577        test_slices_fuzz(32, 5, 9);
1578    }
1579
1580    /// Filters `values` by `predicate` using standard rust iterators
1581    fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> {
1582        values
1583            .into_iter()
1584            .zip(predicate)
1585            .filter(|(_, x)| **x)
1586            .map(|(a, _)| a)
1587            .collect()
1588    }
1589
1590    /// Generates an array of length `len` with `valid_percent` non-null values
1591    fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>>
1592    where
1593        StandardUniform: Distribution<T>,
1594    {
1595        let mut rng = rng();
1596        (0..len)
1597            .map(|_| rng.random_bool(valid_percent).then(|| rng.random()))
1598            .collect()
1599    }
1600
1601    /// Generates an array of length `len` with `valid_percent` non-null values
1602    fn gen_strings(
1603        len: usize,
1604        valid_percent: f64,
1605        str_len_range: std::ops::Range<usize>,
1606    ) -> Vec<Option<String>> {
1607        let mut rng = rng();
1608        (0..len)
1609            .map(|_| {
1610                rng.random_bool(valid_percent).then(|| {
1611                    let len = rng.random_range(str_len_range.clone());
1612                    (0..len)
1613                        .map(|_| char::from(rng.sample(Alphanumeric)))
1614                        .collect()
1615                })
1616            })
1617            .collect()
1618    }
1619
1620    /// Returns an iterator that calls `Option::as_deref` on each item
1621    fn as_deref<T: std::ops::Deref>(src: &[Option<T>]) -> impl Iterator<Item = Option<&T::Target>> {
1622        src.iter().map(|x| x.as_deref())
1623    }
1624
1625    #[test]
1626    #[cfg_attr(miri, ignore)]
1627    fn fuzz_filter() {
1628        let mut rng = rng();
1629
1630        for i in 0..100 {
1631            let filter_percent = match i {
1632                0..=4 => 1.,
1633                5..=10 => 0.,
1634                _ => rng.random_range(0.0..1.0),
1635            };
1636
1637            let valid_percent = rng.random_range(0.0..1.0);
1638
1639            let array_len = rng.random_range(32..256);
1640            let array_offset = rng.random_range(0..10);
1641
1642            // Construct a predicate
1643            let filter_offset = rng.random_range(0..10);
1644            let filter_truncate = rng.random_range(0..10);
1645            let bools: Vec<_> = std::iter::from_fn(|| Some(rng.random_bool(filter_percent)))
1646                .take(array_len + filter_offset - filter_truncate)
1647                .collect();
1648
1649            let predicate = BooleanArray::from_iter(bools.iter().cloned().map(Some));
1650
1651            // Offset predicate
1652            let predicate = predicate.slice(filter_offset, array_len - filter_truncate);
1653            let predicate = predicate.as_any().downcast_ref::<BooleanArray>().unwrap();
1654            let bools = &bools[filter_offset..];
1655
1656            // Test i32
1657            let values = gen_primitive(array_len + array_offset, valid_percent);
1658            let src = Int32Array::from_iter(values.iter().cloned());
1659
1660            let src = src.slice(array_offset, array_len);
1661            let src = src.as_any().downcast_ref::<Int32Array>().unwrap();
1662            let values = &values[array_offset..];
1663
1664            let filtered = filter(src, predicate).unwrap();
1665            let array = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
1666            let actual: Vec<_> = array.iter().collect();
1667
1668            assert_eq!(actual, filter_rust(values.iter().cloned(), bools));
1669
1670            // Test string
1671            let strings = gen_strings(array_len + array_offset, valid_percent, 0..20);
1672            let src = StringArray::from_iter(as_deref(&strings));
1673
1674            let src = src.slice(array_offset, array_len);
1675            let src = src.as_any().downcast_ref::<StringArray>().unwrap();
1676
1677            let filtered = filter(src, predicate).unwrap();
1678            let array = filtered.as_any().downcast_ref::<StringArray>().unwrap();
1679            let actual: Vec<_> = array.iter().collect();
1680
1681            let expected_strings = filter_rust(as_deref(&strings[array_offset..]), bools);
1682            assert_eq!(actual, expected_strings);
1683
1684            // Test string dictionary
1685            let src = DictionaryArray::<Int32Type>::from_iter(as_deref(&strings));
1686
1687            let src = src.slice(array_offset, array_len);
1688            let src = src
1689                .as_any()
1690                .downcast_ref::<DictionaryArray<Int32Type>>()
1691                .unwrap();
1692
1693            let filtered = filter(src, predicate).unwrap();
1694
1695            let array = filtered
1696                .as_any()
1697                .downcast_ref::<DictionaryArray<Int32Type>>()
1698                .unwrap();
1699
1700            let values = array
1701                .values()
1702                .as_any()
1703                .downcast_ref::<StringArray>()
1704                .unwrap();
1705
1706            let actual: Vec<_> = array
1707                .keys()
1708                .iter()
1709                .map(|key| key.map(|key| values.value(key as usize)))
1710                .collect();
1711
1712            assert_eq!(actual, expected_strings);
1713        }
1714    }
1715
1716    #[test]
1717    fn test_filter_map() {
1718        let mut builder =
1719            MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(4));
1720        // [{"key1": 1}, {"key2": 2, "key3": 3}, null, {"key1": 1}
1721        builder.keys().append_value("key1");
1722        builder.values().append_value(1);
1723        builder.append(true).unwrap();
1724        builder.keys().append_value("key2");
1725        builder.keys().append_value("key3");
1726        builder.values().append_value(2);
1727        builder.values().append_value(3);
1728        builder.append(true).unwrap();
1729        builder.append(false).unwrap();
1730        builder.keys().append_value("key1");
1731        builder.values().append_value(1);
1732        builder.append(true).unwrap();
1733        let maparray = Arc::new(builder.finish()) as ArrayRef;
1734
1735        let indices = vec![Some(true), Some(false), Some(false), Some(true)]
1736            .into_iter()
1737            .collect::<BooleanArray>();
1738        let got = filter(&maparray, &indices).unwrap();
1739
1740        let mut builder =
1741            MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(2));
1742        builder.keys().append_value("key1");
1743        builder.values().append_value(1);
1744        builder.append(true).unwrap();
1745        builder.keys().append_value("key1");
1746        builder.values().append_value(1);
1747        builder.append(true).unwrap();
1748        let expected = Arc::new(builder.finish()) as ArrayRef;
1749
1750        assert_eq!(&expected, &got);
1751    }
1752
1753    #[test]
1754    fn test_filter_fixed_size_list_arrays() {
1755        let value_data = ArrayData::builder(DataType::Int32)
1756            .len(9)
1757            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8]))
1758            .build()
1759            .unwrap();
1760        let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 3, false);
1761        let list_data = ArrayData::builder(list_data_type)
1762            .len(3)
1763            .add_child_data(value_data)
1764            .build()
1765            .unwrap();
1766        let array = FixedSizeListArray::from(list_data);
1767
1768        let filter_array = BooleanArray::from(vec![true, false, false]);
1769
1770        let c = filter(&array, &filter_array).unwrap();
1771        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1772
1773        assert_eq!(filtered.len(), 1);
1774
1775        let list = filtered.value(0);
1776        assert_eq!(
1777            &[0, 1, 2],
1778            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1779        );
1780
1781        let filter_array = BooleanArray::from(vec![true, false, true]);
1782
1783        let c = filter(&array, &filter_array).unwrap();
1784        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1785
1786        assert_eq!(filtered.len(), 2);
1787
1788        let list = filtered.value(0);
1789        assert_eq!(
1790            &[0, 1, 2],
1791            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1792        );
1793        let list = filtered.value(1);
1794        assert_eq!(
1795            &[6, 7, 8],
1796            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1797        );
1798    }
1799
1800    #[test]
1801    fn test_filter_fixed_size_list_arrays_with_null() {
1802        let value_data = ArrayData::builder(DataType::Int32)
1803            .len(10)
1804            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
1805            .build()
1806            .unwrap();
1807
1808        // Set null buts for the nested array:
1809        //  [[0, 1], null, null, [6, 7], [8, 9]]
1810        // 01011001 00000001
1811        let mut null_bits: [u8; 1] = [0; 1];
1812        bit_util::set_bit(&mut null_bits, 0);
1813        bit_util::set_bit(&mut null_bits, 3);
1814        bit_util::set_bit(&mut null_bits, 4);
1815
1816        let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 2, false);
1817        let list_data = ArrayData::builder(list_data_type)
1818            .len(5)
1819            .add_child_data(value_data)
1820            .null_bit_buffer(Some(Buffer::from(null_bits)))
1821            .build()
1822            .unwrap();
1823        let array = FixedSizeListArray::from(list_data);
1824
1825        let filter_array = BooleanArray::from(vec![true, true, false, true, false]);
1826
1827        let c = filter(&array, &filter_array).unwrap();
1828        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1829
1830        assert_eq!(filtered.len(), 3);
1831
1832        let list = filtered.value(0);
1833        assert_eq!(
1834            &[0, 1],
1835            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1836        );
1837        assert!(filtered.is_null(1));
1838        let list = filtered.value(2);
1839        assert_eq!(
1840            &[6, 7],
1841            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1842        );
1843    }
1844
1845    fn test_filter_union_array(array: UnionArray) {
1846        let filter_array = BooleanArray::from(vec![true, false, false]);
1847        let c = filter(&array, &filter_array).unwrap();
1848        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1849
1850        let mut builder = UnionBuilder::new_dense();
1851        builder.append::<Int32Type>("A", 1).unwrap();
1852        let expected_array = builder.build().unwrap();
1853
1854        compare_union_arrays(filtered, &expected_array);
1855
1856        let filter_array = BooleanArray::from(vec![true, false, true]);
1857        let c = filter(&array, &filter_array).unwrap();
1858        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1859
1860        let mut builder = UnionBuilder::new_dense();
1861        builder.append::<Int32Type>("A", 1).unwrap();
1862        builder.append::<Int32Type>("A", 34).unwrap();
1863        let expected_array = builder.build().unwrap();
1864
1865        compare_union_arrays(filtered, &expected_array);
1866
1867        let filter_array = BooleanArray::from(vec![true, true, false]);
1868        let c = filter(&array, &filter_array).unwrap();
1869        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1870
1871        let mut builder = UnionBuilder::new_dense();
1872        builder.append::<Int32Type>("A", 1).unwrap();
1873        builder.append::<Float64Type>("B", 3.2).unwrap();
1874        let expected_array = builder.build().unwrap();
1875
1876        compare_union_arrays(filtered, &expected_array);
1877    }
1878
1879    #[test]
1880    fn test_filter_union_array_dense() {
1881        let mut builder = UnionBuilder::new_dense();
1882        builder.append::<Int32Type>("A", 1).unwrap();
1883        builder.append::<Float64Type>("B", 3.2).unwrap();
1884        builder.append::<Int32Type>("A", 34).unwrap();
1885        let array = builder.build().unwrap();
1886
1887        test_filter_union_array(array);
1888    }
1889
1890    #[test]
1891    fn test_filter_run_union_array_dense() {
1892        let mut builder = UnionBuilder::new_dense();
1893        builder.append::<Int32Type>("A", 1).unwrap();
1894        builder.append::<Int32Type>("A", 3).unwrap();
1895        builder.append::<Int32Type>("A", 34).unwrap();
1896        let array = builder.build().unwrap();
1897
1898        let filter_array = BooleanArray::from(vec![true, true, false]);
1899        let c = filter(&array, &filter_array).unwrap();
1900        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1901
1902        let mut builder = UnionBuilder::new_dense();
1903        builder.append::<Int32Type>("A", 1).unwrap();
1904        builder.append::<Int32Type>("A", 3).unwrap();
1905        let expected = builder.build().unwrap();
1906
1907        assert_eq!(filtered.to_data(), expected.to_data());
1908    }
1909
1910    #[test]
1911    fn test_filter_union_array_dense_with_nulls() {
1912        let mut builder = UnionBuilder::new_dense();
1913        builder.append::<Int32Type>("A", 1).unwrap();
1914        builder.append::<Float64Type>("B", 3.2).unwrap();
1915        builder.append_null::<Float64Type>("B").unwrap();
1916        builder.append::<Int32Type>("A", 34).unwrap();
1917        let array = builder.build().unwrap();
1918
1919        let filter_array = BooleanArray::from(vec![true, true, false, false]);
1920        let c = filter(&array, &filter_array).unwrap();
1921        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1922
1923        let mut builder = UnionBuilder::new_dense();
1924        builder.append::<Int32Type>("A", 1).unwrap();
1925        builder.append::<Float64Type>("B", 3.2).unwrap();
1926        let expected_array = builder.build().unwrap();
1927
1928        compare_union_arrays(filtered, &expected_array);
1929
1930        let filter_array = BooleanArray::from(vec![true, false, true, false]);
1931        let c = filter(&array, &filter_array).unwrap();
1932        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1933
1934        let mut builder = UnionBuilder::new_dense();
1935        builder.append::<Int32Type>("A", 1).unwrap();
1936        builder.append_null::<Float64Type>("B").unwrap();
1937        let expected_array = builder.build().unwrap();
1938
1939        compare_union_arrays(filtered, &expected_array);
1940    }
1941
1942    #[test]
1943    fn test_filter_union_array_sparse() {
1944        let mut builder = UnionBuilder::new_sparse();
1945        builder.append::<Int32Type>("A", 1).unwrap();
1946        builder.append::<Float64Type>("B", 3.2).unwrap();
1947        builder.append::<Int32Type>("A", 34).unwrap();
1948        let array = builder.build().unwrap();
1949
1950        test_filter_union_array(array);
1951    }
1952
1953    #[test]
1954    fn test_filter_union_array_sparse_with_nulls() {
1955        let mut builder = UnionBuilder::new_sparse();
1956        builder.append::<Int32Type>("A", 1).unwrap();
1957        builder.append::<Float64Type>("B", 3.2).unwrap();
1958        builder.append_null::<Float64Type>("B").unwrap();
1959        builder.append::<Int32Type>("A", 34).unwrap();
1960        let array = builder.build().unwrap();
1961
1962        let filter_array = BooleanArray::from(vec![true, false, true, false]);
1963        let c = filter(&array, &filter_array).unwrap();
1964        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1965
1966        let mut builder = UnionBuilder::new_sparse();
1967        builder.append::<Int32Type>("A", 1).unwrap();
1968        builder.append_null::<Float64Type>("B").unwrap();
1969        let expected_array = builder.build().unwrap();
1970
1971        compare_union_arrays(filtered, &expected_array);
1972    }
1973
1974    fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) {
1975        assert_eq!(union1.len(), union2.len());
1976
1977        for i in 0..union1.len() {
1978            let type_id = union1.type_id(i);
1979
1980            let slot1 = union1.value(i);
1981            let slot2 = union2.value(i);
1982
1983            assert_eq!(slot1.is_null(0), slot2.is_null(0));
1984
1985            if !slot1.is_null(0) && !slot2.is_null(0) {
1986                match type_id {
1987                    0 => {
1988                        let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
1989                        assert_eq!(slot1.len(), 1);
1990                        let value1 = slot1.value(0);
1991
1992                        let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap();
1993                        assert_eq!(slot2.len(), 1);
1994                        let value2 = slot2.value(0);
1995                        assert_eq!(value1, value2);
1996                    }
1997                    1 => {
1998                        let slot1 = slot1.as_any().downcast_ref::<Float64Array>().unwrap();
1999                        assert_eq!(slot1.len(), 1);
2000                        let value1 = slot1.value(0);
2001
2002                        let slot2 = slot2.as_any().downcast_ref::<Float64Array>().unwrap();
2003                        assert_eq!(slot2.len(), 1);
2004                        let value2 = slot2.value(0);
2005                        assert_eq!(value1, value2);
2006                    }
2007                    _ => unreachable!(),
2008                }
2009            }
2010        }
2011    }
2012
2013    #[test]
2014    fn test_filter_struct() {
2015        let predicate = BooleanArray::from(vec![true, false, true, false]);
2016
2017        let a = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
2018        let a_filtered = Arc::new(StringArray::from(vec!["hello", "world"]));
2019
2020        let b = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
2021        let b_filtered = Arc::new(Int32Array::from(vec![5, 7]));
2022
2023        let null_mask = NullBuffer::from(vec![true, false, false, true]);
2024        let null_mask_filtered = NullBuffer::from(vec![true, false]);
2025
2026        let a_field = Field::new("a", DataType::Utf8, false);
2027        let b_field = Field::new("b", DataType::Int32, false);
2028
2029        let array = StructArray::new(vec![a_field.clone()].into(), vec![a.clone()], None);
2030        let expected =
2031            StructArray::new(vec![a_field.clone()].into(), vec![a_filtered.clone()], None);
2032
2033        let result = filter(&array, &predicate).unwrap();
2034
2035        assert_eq!(result.to_data(), expected.to_data());
2036
2037        let array = StructArray::new(
2038            vec![a_field.clone()].into(),
2039            vec![a.clone()],
2040            Some(null_mask.clone()),
2041        );
2042        let expected = StructArray::new(
2043            vec![a_field.clone()].into(),
2044            vec![a_filtered.clone()],
2045            Some(null_mask_filtered.clone()),
2046        );
2047
2048        let result = filter(&array, &predicate).unwrap();
2049
2050        assert_eq!(result.to_data(), expected.to_data());
2051
2052        let array = StructArray::new(
2053            vec![a_field.clone(), b_field.clone()].into(),
2054            vec![a.clone(), b.clone()],
2055            None,
2056        );
2057        let expected = StructArray::new(
2058            vec![a_field.clone(), b_field.clone()].into(),
2059            vec![a_filtered.clone(), b_filtered.clone()],
2060            None,
2061        );
2062
2063        let result = filter(&array, &predicate).unwrap();
2064
2065        assert_eq!(result.to_data(), expected.to_data());
2066
2067        let array = StructArray::new(
2068            vec![a_field.clone(), b_field.clone()].into(),
2069            vec![a.clone(), b.clone()],
2070            Some(null_mask.clone()),
2071        );
2072
2073        let expected = StructArray::new(
2074            vec![a_field.clone(), b_field.clone()].into(),
2075            vec![a_filtered.clone(), b_filtered.clone()],
2076            Some(null_mask_filtered.clone()),
2077        );
2078
2079        let result = filter(&array, &predicate).unwrap();
2080
2081        assert_eq!(result.to_data(), expected.to_data());
2082    }
2083
2084    #[test]
2085    fn test_filter_empty_struct() {
2086        /*
2087            "a": {
2088                "b": int64,
2089                "c": {}
2090            },
2091        */
2092        let fields = arrow_schema::Field::new(
2093            "a",
2094            arrow_schema::DataType::Struct(arrow_schema::Fields::from(vec![
2095                arrow_schema::Field::new("b", arrow_schema::DataType::Int64, true),
2096                arrow_schema::Field::new(
2097                    "c",
2098                    arrow_schema::DataType::Struct(arrow_schema::Fields::empty()),
2099                    true,
2100                ),
2101            ])),
2102            true,
2103        );
2104
2105        /* Test record
2106            {"a":{"c": {}}}
2107            {"a":{"c": {}}}
2108            {"a":{"c": {}}}
2109        */
2110
2111        // Create the record batch with the nested struct array
2112        let schema = Arc::new(Schema::new(vec![fields]));
2113
2114        let b = Arc::new(Int64Array::from(vec![None, None, None]));
2115        let c = Arc::new(StructArray::new_empty_fields(
2116            3,
2117            Some(NullBuffer::from(vec![true, true, true])),
2118        ));
2119        let a = StructArray::new(
2120            vec![
2121                Field::new("b", DataType::Int64, true),
2122                Field::new("c", DataType::Struct(Fields::empty()), true),
2123            ]
2124            .into(),
2125            vec![b.clone(), c.clone()],
2126            Some(NullBuffer::from(vec![true, true, true])),
2127        );
2128        let record_batch = RecordBatch::try_new(schema, vec![Arc::new(a)]).unwrap();
2129        println!("{record_batch:?}");
2130
2131        // Apply the filter
2132        let predicate = BooleanArray::from(vec![true, false, true]);
2133        let filtered_batch = filter_record_batch(&record_batch, &predicate).unwrap();
2134
2135        // The filtered batch should have 2 rows (the 1st and 3rd)
2136        assert_eq!(filtered_batch.num_rows(), 2);
2137    }
2138}