arrow_select/
filter.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines filter kernels
19
20use std::ops::AddAssign;
21use std::sync::Arc;
22
23use arrow_array::builder::BooleanBufferBuilder;
24use arrow_array::cast::AsArray;
25use arrow_array::types::{
26    ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
27};
28use arrow_array::*;
29use arrow_buffer::{ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer, bit_util};
30use arrow_buffer::{Buffer, MutableBuffer};
31use arrow_data::ArrayDataBuilder;
32use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
33use arrow_data::transform::MutableArrayData;
34use arrow_schema::*;
35
/// Selectivity threshold used by `IterationStrategy::default_strategy` to choose
/// between the two lazily evaluated iteration strategies.
///
/// If the filter selects more than this fraction of rows, use
/// [`SlicesIterator`] to copy ranges of values. Otherwise iterate
/// over individual rows using [`IndexIterator`]. The comparison is strict,
/// so a selectivity exactly equal to the threshold iterates individual rows.
///
/// Threshold of 0.8 chosen based on <https://dl.acm.org/doi/abs/10.1145/3465998.3466009>
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
43
44/// An iterator of `(usize, usize)` each representing an interval
45/// `[start, end)` whose slots of a bitmap [Buffer] are true.
46///
47/// Each interval corresponds to a contiguous region of memory to be
48/// "taken" from an array to be filtered.
49///
50/// ## Notes:
51///
52/// 1. Ignores the validity bitmap (ignores nulls)
53///
54/// 2. Only performant for filters that copy across long contiguous runs
55#[derive(Debug)]
56pub struct SlicesIterator<'a>(BitSliceIterator<'a>);
57
58impl<'a> SlicesIterator<'a> {
59    /// Creates a new iterator from a [BooleanArray]
60    pub fn new(filter: &'a BooleanArray) -> Self {
61        Self(filter.values().set_slices())
62    }
63}
64
65impl Iterator for SlicesIterator<'_> {
66    type Item = (usize, usize);
67
68    fn next(&mut self) -> Option<Self::Item> {
69        self.0.next()
70    }
71}
72
/// An iterator of `usize` whose index in [`BooleanArray`] is true
///
/// This provides the best performance on most predicates, apart from those which keep
/// large runs and therefore favour [`SlicesIterator`]
///
/// The iterator yields exactly `remaining` items and reports that count as an exact
/// `size_hint`, so it can be handed to trusted-length consumers such as
/// `MutableBuffer::from_trusted_len_iter`.
struct IndexIterator<'a> {
    /// Number of set indices still to be yielded
    remaining: usize,
    /// Indices of the set bits in the filter's values buffer
    iter: BitIndexIterator<'a>,
}

impl<'a> IndexIterator<'a> {
    /// Creates an iterator over the set indices of `filter`.
    ///
    /// `remaining` should equal the number of set bits in `filter`: if it is larger,
    /// [`Iterator::next`] panics once the set bits run out; if smaller, iteration
    /// stops early.
    ///
    /// # Panics
    ///
    /// Panics if `filter` contains nulls.
    fn new(filter: &'a BooleanArray, remaining: usize) -> Self {
        assert_eq!(filter.null_count(), 0);
        let iter = filter.values().set_indices();
        Self { remaining, iter }
    }
}

impl Iterator for IndexIterator<'_> {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining != 0 {
            // Fascinatingly swapping these two lines around results in a 50%
            // performance regression for some benchmarks
            let next = self.iter.next().expect("IndexIterator exhausted early");
            self.remaining -= 1;
            // The `expect` above must panic rather than end early: consumers treat
            // this as a trusted-length iterator yielding exactly `remaining` items
            return Some(next);
        }
        None
    }

    /// Exact: always `(remaining, Some(remaining))`
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}
109
110/// Counts the number of set bits in `filter`
111fn filter_count(filter: &BooleanArray) -> usize {
112    filter.values().count_set_bits()
113}
114
115/// Remove null values by do a bitmask AND operation with null bits and the boolean bits.
116pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
117    let nulls = filter.nulls().unwrap();
118    let mask = filter.values() & nulls.inner();
119    BooleanArray::new(mask, None)
120}
121
122/// Returns a filtered `values` [`Array`] where the corresponding elements of
123/// `predicate` are `true`.
124///
125/// # See also
126/// * [`FilterBuilder`] for more control over the filtering process.
127/// * [`filter_record_batch`] to filter a [`RecordBatch`]
128/// * [`BatchCoalescer`]: to filter multiple [`RecordBatch`] and coalesce
129///   the results into a single array.
130///
131/// [`BatchCoalescer`]: crate::coalesce::BatchCoalescer
132///
133/// # Example
134/// ```rust
135/// # use arrow_array::{Int32Array, BooleanArray};
136/// # use arrow_select::filter::filter;
137/// let array = Int32Array::from(vec![5, 6, 7, 8, 9]);
138/// let filter_array = BooleanArray::from(vec![true, false, false, true, false]);
139/// let c = filter(&array, &filter_array).unwrap();
140/// let c = c.as_any().downcast_ref::<Int32Array>().unwrap();
141/// assert_eq!(c, &Int32Array::from(vec![5, 8]));
142/// ```
143pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
144    let mut filter_builder = FilterBuilder::new(predicate);
145
146    if multiple_arrays(values.data_type()) {
147        // Only optimize if filtering more than one array
148        // Otherwise, the overhead of optimization can be more than the benefit
149        filter_builder = filter_builder.optimize();
150    }
151
152    let predicate = filter_builder.build();
153
154    filter_array(values, &predicate)
155}
156
157fn multiple_arrays(data_type: &DataType) -> bool {
158    match data_type {
159        DataType::Struct(fields) => {
160            fields.len() > 1 || fields.len() == 1 && multiple_arrays(fields[0].data_type())
161        }
162        DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
163        _ => false,
164    }
165}
166
167/// Returns a filtered [RecordBatch] where the corresponding elements of
168/// `predicate` are true.
169///
170/// This is the equivalent of calling [filter] on each column of the [RecordBatch].
171pub fn filter_record_batch(
172    record_batch: &RecordBatch,
173    predicate: &BooleanArray,
174) -> Result<RecordBatch, ArrowError> {
175    let mut filter_builder = FilterBuilder::new(predicate);
176    if record_batch.num_columns() > 1 {
177        // Only optimize if filtering more than one column
178        // Otherwise, the overhead of optimization can be more than the benefit
179        filter_builder = filter_builder.optimize();
180    }
181    let filter = filter_builder.build();
182
183    let filtered_arrays = record_batch
184        .columns()
185        .iter()
186        .map(|a| filter_array(a, &filter))
187        .collect::<Result<Vec<_>, _>>()?;
188    let options = RecordBatchOptions::default().with_row_count(Some(filter.count()));
189    RecordBatch::try_new_with_options(record_batch.schema(), filtered_arrays, &options)
190}
191
192/// A builder to construct [`FilterPredicate`]
193#[derive(Debug)]
194pub struct FilterBuilder {
195    filter: BooleanArray,
196    count: usize,
197    strategy: IterationStrategy,
198}
199
200impl FilterBuilder {
201    /// Create a new [`FilterBuilder`] that can be used to construct a [`FilterPredicate`]
202    pub fn new(filter: &BooleanArray) -> Self {
203        let filter = match filter.null_count() {
204            0 => filter.clone(),
205            _ => prep_null_mask_filter(filter),
206        };
207
208        let count = filter_count(&filter);
209        let strategy = IterationStrategy::default_strategy(filter.len(), count);
210
211        Self {
212            filter,
213            count,
214            strategy,
215        }
216    }
217
218    /// Compute an optimised representation of the provided `filter` mask that can be
219    /// applied to an array more quickly.
220    ///
221    /// Note: There is limited benefit to calling this to then filter a single array
222    /// Note: This will likely have a larger memory footprint than the original mask
223    pub fn optimize(mut self) -> Self {
224        match self.strategy {
225            IterationStrategy::SlicesIterator => {
226                let slices = SlicesIterator::new(&self.filter).collect();
227                self.strategy = IterationStrategy::Slices(slices)
228            }
229            IterationStrategy::IndexIterator => {
230                let indices = IndexIterator::new(&self.filter, self.count).collect();
231                self.strategy = IterationStrategy::Indices(indices)
232            }
233            _ => {}
234        }
235        self
236    }
237
238    /// Construct the final `FilterPredicate`
239    pub fn build(self) -> FilterPredicate {
240        FilterPredicate {
241            filter: self.filter,
242            count: self.count,
243            strategy: self.strategy,
244        }
245    }
246}
247
/// The iteration strategy used to evaluate [`FilterPredicate`]
///
/// `SlicesIterator` and `IndexIterator` re-derive their positions from the filter
/// mask every time it is applied. `Slices` and `Indices` are their materialised
/// counterparts, produced by [`FilterBuilder::optimize`]. `All` and `None` let
/// filtering short-circuit without visiting the mask.
#[derive(Debug)]
enum IterationStrategy {
    /// A lazily evaluated iterator of ranges
    SlicesIterator,
    /// A lazily evaluated iterator of indices
    IndexIterator,
    /// A precomputed list of indices
    Indices(Vec<usize>),
    /// A precomputed array of ranges
    Slices(Vec<(usize, usize)>),
    /// Select all rows
    All,
    /// Select no rows
    None,
}
264
265impl IterationStrategy {
266    /// The default [`IterationStrategy`] for a filter of length `filter_length`
267    /// and selecting `filter_count` rows
268    fn default_strategy(filter_length: usize, filter_count: usize) -> Self {
269        if filter_length == 0 || filter_count == 0 {
270            return IterationStrategy::None;
271        }
272
273        if filter_count == filter_length {
274            return IterationStrategy::All;
275        }
276
277        // Compute the selectivity of the predicate by dividing the number of true
278        // bits in the predicate by the predicate's total length
279        //
280        // This can then be used as a heuristic for the optimal iteration strategy
281        let selectivity_frac = filter_count as f64 / filter_length as f64;
282        if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD {
283            return IterationStrategy::SlicesIterator;
284        }
285        IterationStrategy::IndexIterator
286    }
287}
288
289/// A filtering predicate that can be applied to an [`Array`]
290#[derive(Debug)]
291pub struct FilterPredicate {
292    filter: BooleanArray,
293    count: usize,
294    strategy: IterationStrategy,
295}
296
297impl FilterPredicate {
298    /// Selects rows from `values` based on this [`FilterPredicate`]
299    pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef, ArrowError> {
300        filter_array(values, self)
301    }
302
303    /// Number of rows being selected based on this [`FilterPredicate`]
304    pub fn count(&self) -> usize {
305        self.count
306    }
307}
308
/// Applies `predicate` to `values`, dispatching on the data type of `values` to a
/// specialised kernel.
///
/// * Returns an error if `predicate` is longer than `values`; rows of `values`
///   past the end of a shorter predicate are never selected.
/// * The `None` / `All` strategies short-circuit to an empty array or a slice.
/// * Types without a dedicated kernel fall back to [`MutableArrayData`], which
///   always copies contiguous selected ranges.
///
/// # Panics
///
/// Panics (via `unimplemented!`) if a run-end-encoded or dictionary array has a
/// run-end / key type that the downcast macros do not handle.
fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> {
    if predicate.filter.len() > values.len() {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Filter predicate of length {} is larger than target array of length {}",
            predicate.filter.len(),
            values.len()
        )));
    }

    match predicate.strategy {
        IterationStrategy::None => Ok(new_empty_array(values.data_type())),
        IterationStrategy::All => Ok(values.slice(0, predicate.count)),
        // actually filter
        _ => downcast_primitive_array! {
            values => Ok(Arc::new(filter_primitive(values, predicate))),
            DataType::Boolean => {
                let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
                Ok(Arc::new(filter_boolean(values, predicate)))
            }
            DataType::Utf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i32>(), predicate)))
            }
            DataType::LargeUtf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
            }
            DataType::Utf8View => {
                Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
            }
            DataType::Binary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
            }
            DataType::LargeBinary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
            }
            DataType::BinaryView => {
                Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
            }
            DataType::FixedSizeBinary(_) => {
                Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate)))
            }
            DataType::RunEndEncoded(_, _) => {
                downcast_run_array!{
                    values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
                    t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t)
                }
            }
            DataType::Dictionary(_, _) => downcast_dictionary_array! {
                values => Ok(Arc::new(filter_dict(values, predicate))),
                t => unimplemented!("Filter not supported for dictionary type {:?}", t)
            }
            DataType::Struct(_) => {
                Ok(Arc::new(filter_struct(values.as_struct(), predicate)?))
            }
            DataType::Union(_, UnionMode::Sparse) => {
                Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?))
            }
            _ => {
                let data = values.to_data();
                // fallback to using MutableArrayData, reserving room for the
                // `predicate.count` selected rows
                let mut mutable = MutableArrayData::new(
                    vec![&data],
                    false,
                    predicate.count,
                );

                // Only precomputed ranges are reused here; every other strategy
                // (including index-based ones) is replayed as ranges of the mask
                match &predicate.strategy {
                    IterationStrategy::Slices(slices) => {
                        slices
                            .iter()
                            .for_each(|(start, end)| mutable.extend(0, *start, *end));
                    }
                    _ => {
                        let iter = SlicesIterator::new(&predicate.filter);
                        iter.for_each(|(start, end)| mutable.extend(0, start, end));
                    }
                }

                let data = mutable.freeze();
                Ok(make_array(data))
            }
        },
    }
}
392
393/// Filter any supported [`RunArray`] based on a [`FilterPredicate`]
394fn filter_run_end_array<R: RunEndIndexType>(
395    array: &RunArray<R>,
396    predicate: &FilterPredicate,
397) -> Result<RunArray<R>, ArrowError>
398where
399    R::Native: Into<i64> + From<bool>,
400    R::Native: AddAssign,
401{
402    let run_ends: &RunEndBuffer<R::Native> = array.run_ends();
403    let mut new_run_ends = vec![R::default_value(); run_ends.len()];
404
405    let mut start = 0u64;
406    let mut j = 0;
407    let mut count = R::default_value();
408    let filter_values = predicate.filter.values();
409    let run_ends = run_ends.inner();
410
411    let pred: BooleanArray = BooleanBuffer::collect_bool(run_ends.len(), |i| {
412        let mut keep = false;
413        let mut end = run_ends[i].into() as u64;
414        let difference = end.saturating_sub(filter_values.len() as u64);
415        end -= difference;
416
417        // Safety: we subtract the difference off `end` so we are always within bounds
418        for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
419            count += R::Native::from(pred);
420            keep |= pred
421        }
422        // this is to avoid branching
423        new_run_ends[j] = count;
424        j += keep as usize;
425
426        start = end;
427        keep
428    })
429    .into();
430
431    new_run_ends.truncate(j);
432
433    let values = array.values();
434    let values = filter(&values, &pred)?;
435
436    let run_ends = PrimitiveArray::<R>::try_new(new_run_ends.into(), None)?;
437    RunArray::try_new(&run_ends, &values)
438}
439
440/// Computes a new null mask for `data` based on `predicate`
441///
442/// If the predicate selected no null-rows, returns `None`, otherwise returns
443/// `Some((null_count, null_buffer))` where `null_count` is the number of nulls
444/// in the filtered output, and `null_buffer` is the filtered null buffer
445///
446fn filter_null_mask(
447    nulls: Option<&NullBuffer>,
448    predicate: &FilterPredicate,
449) -> Option<(usize, Buffer)> {
450    let nulls = nulls?;
451    if nulls.null_count() == 0 {
452        return None;
453    }
454
455    let nulls = filter_bits(nulls.inner(), predicate);
456    // The filtered `nulls` has a length of `predicate.count` bits and
457    // therefore the null count is this minus the number of valid bits
458    let null_count = predicate.count - nulls.count_set_bits_offset(0, predicate.count);
459
460    if null_count == 0 {
461        return None;
462    }
463
464    Some((null_count, nulls))
465}
466
467/// Filter the packed bitmask `buffer`, with `predicate` starting at bit offset `offset`
468fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer {
469    let src = buffer.values();
470    let offset = buffer.offset();
471
472    match &predicate.strategy {
473        IterationStrategy::IndexIterator => {
474            let bits = IndexIterator::new(&predicate.filter, predicate.count)
475                .map(|src_idx| bit_util::get_bit(src, src_idx + offset));
476
477            // SAFETY: `IndexIterator` reports its size correctly
478            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
479        }
480        IterationStrategy::Indices(indices) => {
481            let bits = indices
482                .iter()
483                .map(|src_idx| bit_util::get_bit(src, *src_idx + offset));
484
485            // SAFETY: `Vec::iter()` reports its size correctly
486            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
487        }
488        IterationStrategy::SlicesIterator => {
489            let mut builder = BooleanBufferBuilder::new(predicate.count);
490            for (start, end) in SlicesIterator::new(&predicate.filter) {
491                builder.append_packed_range(start + offset..end + offset, src)
492            }
493            builder.into()
494        }
495        IterationStrategy::Slices(slices) => {
496            let mut builder = BooleanBufferBuilder::new(predicate.count);
497            for (start, end) in slices {
498                builder.append_packed_range(*start + offset..*end + offset, src)
499            }
500            builder.into()
501        }
502        IterationStrategy::All | IterationStrategy::None => unreachable!(),
503    }
504}
505
506/// `filter` implementation for boolean buffers
507fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
508    let values = filter_bits(array.values(), predicate);
509
510    let mut builder = ArrayDataBuilder::new(DataType::Boolean)
511        .len(predicate.count)
512        .add_buffer(values);
513
514    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
515        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
516    }
517
518    let data = unsafe { builder.build_unchecked() };
519    BooleanArray::from(data)
520}
521
522#[inline(never)]
523fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer {
524    assert!(values.len() >= predicate.filter.len());
525
526    match &predicate.strategy {
527        IterationStrategy::SlicesIterator => {
528            let mut buffer = Vec::with_capacity(predicate.count);
529            for (start, end) in SlicesIterator::new(&predicate.filter) {
530                buffer.extend_from_slice(&values[start..end]);
531            }
532            buffer.into()
533        }
534        IterationStrategy::Slices(slices) => {
535            let mut buffer = Vec::with_capacity(predicate.count);
536            for (start, end) in slices {
537                buffer.extend_from_slice(&values[*start..*end]);
538            }
539            buffer.into()
540        }
541        IterationStrategy::IndexIterator => {
542            let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| values[x]);
543
544            // SAFETY: IndexIterator is trusted length
545            unsafe { MutableBuffer::from_trusted_len_iter(iter) }.into()
546        }
547        IterationStrategy::Indices(indices) => {
548            let iter = indices.iter().map(|x| values[*x]);
549            iter.collect::<Vec<_>>().into()
550        }
551        IterationStrategy::All | IterationStrategy::None => unreachable!(),
552    }
553}
554
555/// `filter` implementation for primitive arrays
556fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
557where
558    T: ArrowPrimitiveType,
559{
560    let values = array.values();
561    let buffer = filter_native(values, predicate);
562    let mut builder = ArrayDataBuilder::new(array.data_type().clone())
563        .len(predicate.count)
564        .add_buffer(buffer);
565
566    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
567        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
568    }
569
570    let data = unsafe { builder.build_unchecked() };
571    PrimitiveArray::from(data)
572}
573
574/// [`FilterBytes`] is created from a source [`GenericByteArray`] and can be
575/// used to build a new [`GenericByteArray`] by copying values from the source
576///
577/// TODO(raphael): Could this be used for the take kernel as well?
578struct FilterBytes<'a, OffsetSize> {
579    src_offsets: &'a [OffsetSize],
580    src_values: &'a [u8],
581    dst_offsets: Vec<OffsetSize>,
582    dst_values: Vec<u8>,
583    cur_offset: OffsetSize,
584}
585
586impl<'a, OffsetSize> FilterBytes<'a, OffsetSize>
587where
588    OffsetSize: OffsetSizeTrait,
589{
590    fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self
591    where
592        T: ByteArrayType<Offset = OffsetSize>,
593    {
594        let dst_values = Vec::new();
595        let mut dst_offsets: Vec<OffsetSize> = Vec::with_capacity(capacity + 1);
596        let cur_offset = OffsetSize::from_usize(0).unwrap();
597
598        dst_offsets.push(cur_offset);
599
600        Self {
601            src_offsets: array.value_offsets(),
602            src_values: array.value_data(),
603            dst_offsets,
604            dst_values,
605            cur_offset,
606        }
607    }
608
609    /// Returns the byte offset at `idx`
610    #[inline]
611    fn get_value_offset(&self, idx: usize) -> usize {
612        self.src_offsets[idx].as_usize()
613    }
614
615    /// Returns the start and end of the value at index `idx` along with its length
616    #[inline]
617    fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
618        // These can only fail if `array` contains invalid data
619        let start = self.get_value_offset(idx);
620        let end = self.get_value_offset(idx + 1);
621        let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
622        (start, end, len)
623    }
624
625    fn extend_offsets_idx(&mut self, iter: impl Iterator<Item = usize>) {
626        self.dst_offsets.extend(iter.map(|idx| {
627            let start = self.src_offsets[idx].as_usize();
628            let end = self.src_offsets[idx + 1].as_usize();
629            let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
630            self.cur_offset += len;
631
632            self.cur_offset
633        }));
634    }
635
636    /// Extends the in-progress array by the indexes in the provided iterator
637    fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
638        self.dst_values.reserve_exact(self.cur_offset.as_usize());
639
640        for idx in iter {
641            let start = self.src_offsets[idx].as_usize();
642            let end = self.src_offsets[idx + 1].as_usize();
643            self.dst_values
644                .extend_from_slice(&self.src_values[start..end]);
645        }
646    }
647
648    fn extend_offsets_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>, count: usize) {
649        self.dst_offsets.reserve_exact(count);
650        for (start, end) in iter {
651            // These can only fail if `array` contains invalid data
652            for idx in start..end {
653                let (_, _, len) = self.get_value_range(idx);
654                self.cur_offset += len;
655                self.dst_offsets.push(self.cur_offset);
656            }
657        }
658    }
659
660    /// Extends the in-progress array by the ranges in the provided iterator
661    fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
662        self.dst_values.reserve_exact(self.cur_offset.as_usize());
663
664        for (start, end) in iter {
665            let value_start = self.get_value_offset(start);
666            let value_end = self.get_value_offset(end);
667            self.dst_values
668                .extend_from_slice(&self.src_values[value_start..value_end]);
669        }
670    }
671}
672
673/// `filter` implementation for byte arrays
674///
675/// Note: NULLs with a non-zero slot length in `array` will have the corresponding
676/// data copied across. This allows handling the null mask separately from the data
677fn filter_bytes<T>(array: &GenericByteArray<T>, predicate: &FilterPredicate) -> GenericByteArray<T>
678where
679    T: ByteArrayType,
680{
681    let mut filter = FilterBytes::new(predicate.count, array);
682
683    match &predicate.strategy {
684        IterationStrategy::SlicesIterator => {
685            filter.extend_offsets_slices(SlicesIterator::new(&predicate.filter), predicate.count);
686            filter.extend_slices(SlicesIterator::new(&predicate.filter))
687        }
688        IterationStrategy::Slices(slices) => {
689            filter.extend_offsets_slices(slices.iter().cloned(), predicate.count);
690            filter.extend_slices(slices.iter().cloned())
691        }
692        IterationStrategy::IndexIterator => {
693            filter.extend_offsets_idx(IndexIterator::new(&predicate.filter, predicate.count));
694            filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count))
695        }
696        IterationStrategy::Indices(indices) => {
697            filter.extend_offsets_idx(indices.iter().cloned());
698            filter.extend_idx(indices.iter().cloned())
699        }
700        IterationStrategy::All | IterationStrategy::None => unreachable!(),
701    }
702
703    let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
704        .len(predicate.count)
705        .add_buffer(filter.dst_offsets.into())
706        .add_buffer(filter.dst_values.into());
707
708    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
709        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
710    }
711
712    let data = unsafe { builder.build_unchecked() };
713    GenericByteArray::from(data)
714}
715
716/// `filter` implementation for byte view arrays.
717fn filter_byte_view<T: ByteViewType>(
718    array: &GenericByteViewArray<T>,
719    predicate: &FilterPredicate,
720) -> GenericByteViewArray<T> {
721    let new_view_buffer = filter_native(array.views(), predicate);
722
723    let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
724        .len(predicate.count)
725        .add_buffer(new_view_buffer)
726        .add_buffers(array.data_buffers().to_vec());
727
728    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
729        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
730    }
731
732    GenericByteViewArray::from(unsafe { builder.build_unchecked() })
733}
734
735fn filter_fixed_size_binary(
736    array: &FixedSizeBinaryArray,
737    predicate: &FilterPredicate,
738) -> FixedSizeBinaryArray {
739    let values: &[u8] = array.values();
740    let value_length = array.value_length() as usize;
741    let calculate_offset_from_index = |index: usize| index * value_length;
742    let buffer = match &predicate.strategy {
743        IterationStrategy::SlicesIterator => {
744            let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
745            for (start, end) in SlicesIterator::new(&predicate.filter) {
746                buffer.extend_from_slice(
747                    &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)],
748                );
749            }
750            buffer
751        }
752        IterationStrategy::Slices(slices) => {
753            let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
754            for (start, end) in slices {
755                buffer.extend_from_slice(
756                    &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)],
757                );
758            }
759            buffer
760        }
761        IterationStrategy::IndexIterator => {
762            let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| {
763                &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)]
764            });
765
766            let mut buffer = MutableBuffer::new(predicate.count * value_length);
767            iter.for_each(|item| buffer.extend_from_slice(item));
768            buffer
769        }
770        IterationStrategy::Indices(indices) => {
771            let iter = indices.iter().map(|x| {
772                &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)]
773            });
774
775            let mut buffer = MutableBuffer::new(predicate.count * value_length);
776            iter.for_each(|item| buffer.extend_from_slice(item));
777            buffer
778        }
779        IterationStrategy::All | IterationStrategy::None => unreachable!(),
780    };
781    let mut builder = ArrayDataBuilder::new(array.data_type().clone())
782        .len(predicate.count)
783        .add_buffer(buffer.into());
784
785    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
786        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
787    }
788
789    let data = unsafe { builder.build_unchecked() };
790    FixedSizeBinaryArray::from(data)
791}
792
793/// `filter` implementation for dictionaries
794fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
795where
796    T: ArrowDictionaryKeyType,
797    T::Native: num_traits::Num,
798{
799    let builder = filter_primitive::<T>(array.keys(), predicate)
800        .into_data()
801        .into_builder()
802        .data_type(array.data_type().clone())
803        .child_data(vec![array.values().to_data()]);
804
805    // SAFETY:
806    // Keys were valid before, filtered subset is therefore still valid
807    DictionaryArray::from(unsafe { builder.build_unchecked() })
808}
809
810/// `filter` implementation for structs
811fn filter_struct(
812    array: &StructArray,
813    predicate: &FilterPredicate,
814) -> Result<StructArray, ArrowError> {
815    let columns = array
816        .columns()
817        .iter()
818        .map(|column| filter_array(column, predicate))
819        .collect::<Result<_, _>>()?;
820
821    let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
822        let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
823
824        Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
825    } else {
826        None
827    };
828
829    Ok(unsafe {
830        StructArray::new_unchecked_with_length(
831            array.fields().clone(),
832            columns,
833            nulls,
834            predicate.count(),
835        )
836    })
837}
838
839/// `filter` implementation for sparse unions
840fn filter_sparse_union(
841    array: &UnionArray,
842    predicate: &FilterPredicate,
843) -> Result<UnionArray, ArrowError> {
844    let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
845        unreachable!()
846    };
847
848    let type_ids = filter_primitive(
849        &Int8Array::try_new(array.type_ids().clone(), None)?,
850        predicate,
851    );
852
853    let children = fields
854        .iter()
855        .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
856        .collect::<Result<_, _>>()?;
857
858    Ok(unsafe {
859        UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
860    })
861}
862
863#[cfg(test)]
864mod tests {
865    use super::*;
866    use arrow_array::builder::*;
867    use arrow_array::cast::as_run_array;
868    use arrow_array::types::*;
869    use arrow_data::ArrayData;
870    use rand::distr::uniform::{UniformSampler, UniformUsize};
871    use rand::distr::{Alphanumeric, StandardUniform};
872    use rand::prelude::*;
873    use rand::rng;
874
875    macro_rules! def_temporal_test {
876        ($test:ident, $array_type: ident, $data: expr) => {
877            #[test]
878            fn $test() {
879                let a = $data;
880                let b = BooleanArray::from(vec![true, false, true, false]);
881                let c = filter(&a, &b).unwrap();
882                let d = c.as_ref().as_any().downcast_ref::<$array_type>().unwrap();
883                assert_eq!(2, d.len());
884                assert_eq!(1, d.value(0));
885                assert_eq!(3, d.value(1));
886            }
887        };
888    }
889
890    def_temporal_test!(
891        test_filter_date32,
892        Date32Array,
893        Date32Array::from(vec![1, 2, 3, 4])
894    );
895    def_temporal_test!(
896        test_filter_date64,
897        Date64Array,
898        Date64Array::from(vec![1, 2, 3, 4])
899    );
900    def_temporal_test!(
901        test_filter_time32_second,
902        Time32SecondArray,
903        Time32SecondArray::from(vec![1, 2, 3, 4])
904    );
905    def_temporal_test!(
906        test_filter_time32_millisecond,
907        Time32MillisecondArray,
908        Time32MillisecondArray::from(vec![1, 2, 3, 4])
909    );
910    def_temporal_test!(
911        test_filter_time64_microsecond,
912        Time64MicrosecondArray,
913        Time64MicrosecondArray::from(vec![1, 2, 3, 4])
914    );
915    def_temporal_test!(
916        test_filter_time64_nanosecond,
917        Time64NanosecondArray,
918        Time64NanosecondArray::from(vec![1, 2, 3, 4])
919    );
920    def_temporal_test!(
921        test_filter_duration_second,
922        DurationSecondArray,
923        DurationSecondArray::from(vec![1, 2, 3, 4])
924    );
925    def_temporal_test!(
926        test_filter_duration_millisecond,
927        DurationMillisecondArray,
928        DurationMillisecondArray::from(vec![1, 2, 3, 4])
929    );
930    def_temporal_test!(
931        test_filter_duration_microsecond,
932        DurationMicrosecondArray,
933        DurationMicrosecondArray::from(vec![1, 2, 3, 4])
934    );
935    def_temporal_test!(
936        test_filter_duration_nanosecond,
937        DurationNanosecondArray,
938        DurationNanosecondArray::from(vec![1, 2, 3, 4])
939    );
940    def_temporal_test!(
941        test_filter_timestamp_second,
942        TimestampSecondArray,
943        TimestampSecondArray::from(vec![1, 2, 3, 4])
944    );
945    def_temporal_test!(
946        test_filter_timestamp_millisecond,
947        TimestampMillisecondArray,
948        TimestampMillisecondArray::from(vec![1, 2, 3, 4])
949    );
950    def_temporal_test!(
951        test_filter_timestamp_microsecond,
952        TimestampMicrosecondArray,
953        TimestampMicrosecondArray::from(vec![1, 2, 3, 4])
954    );
955    def_temporal_test!(
956        test_filter_timestamp_nanosecond,
957        TimestampNanosecondArray,
958        TimestampNanosecondArray::from(vec![1, 2, 3, 4])
959    );
960
961    #[test]
962    fn test_filter_array_slice() {
963        let a = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
964        let b = BooleanArray::from(vec![true, false, false, true]);
965        // filtering with sliced filter array is not currently supported
966        // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4);
967        // let b = b_slice.as_any().downcast_ref().unwrap();
968        let c = filter(&a, &b).unwrap();
969        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
970        assert_eq!(2, d.len());
971        assert_eq!(6, d.value(0));
972        assert_eq!(9, d.value(1));
973    }
974
975    #[test]
976    fn test_filter_array_low_density() {
977        // this test exercises the all 0's branch of the filter algorithm
978        let mut data_values = (1..=65).collect::<Vec<i32>>();
979        let mut filter_values = (1..=65).map(|i| matches!(i % 65, 0)).collect::<Vec<bool>>();
980        // set up two more values after the batch
981        data_values.extend_from_slice(&[66, 67]);
982        filter_values.extend_from_slice(&[false, true]);
983        let a = Int32Array::from(data_values);
984        let b = BooleanArray::from(filter_values);
985        let c = filter(&a, &b).unwrap();
986        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
987        assert_eq!(2, d.len());
988        assert_eq!(65, d.value(0));
989        assert_eq!(67, d.value(1));
990    }
991
992    #[test]
993    fn test_filter_array_high_density() {
994        // this test exercises the all 1's branch of the filter algorithm
995        let mut data_values = (1..=65).map(Some).collect::<Vec<_>>();
996        let mut filter_values = (1..=65)
997            .map(|i| !matches!(i % 65, 0))
998            .collect::<Vec<bool>>();
999        // set second data value to null
1000        data_values[1] = None;
1001        // set up two more values after the batch
1002        data_values.extend_from_slice(&[Some(66), None, Some(67), None]);
1003        filter_values.extend_from_slice(&[false, true, true, true]);
1004        let a = Int32Array::from(data_values);
1005        let b = BooleanArray::from(filter_values);
1006        let c = filter(&a, &b).unwrap();
1007        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1008        assert_eq!(67, d.len());
1009        assert_eq!(3, d.null_count());
1010        assert_eq!(1, d.value(0));
1011        assert!(d.is_null(1));
1012        assert_eq!(64, d.value(63));
1013        assert!(d.is_null(64));
1014        assert_eq!(67, d.value(65));
1015    }
1016
1017    #[test]
1018    fn test_filter_string_array_simple() {
1019        let a = StringArray::from(vec!["hello", " ", "world", "!"]);
1020        let b = BooleanArray::from(vec![true, false, true, false]);
1021        let c = filter(&a, &b).unwrap();
1022        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1023        assert_eq!(2, d.len());
1024        assert_eq!("hello", d.value(0));
1025        assert_eq!("world", d.value(1));
1026    }
1027
1028    #[test]
1029    fn test_filter_primitive_array_with_null() {
1030        let a = Int32Array::from(vec![Some(5), None]);
1031        let b = BooleanArray::from(vec![false, true]);
1032        let c = filter(&a, &b).unwrap();
1033        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1034        assert_eq!(1, d.len());
1035        assert!(d.is_null(0));
1036    }
1037
1038    #[test]
1039    fn test_filter_string_array_with_null() {
1040        let a = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
1041        let b = BooleanArray::from(vec![true, false, false, true]);
1042        let c = filter(&a, &b).unwrap();
1043        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1044        assert_eq!(2, d.len());
1045        assert_eq!("hello", d.value(0));
1046        assert!(!d.is_null(0));
1047        assert!(d.is_null(1));
1048    }
1049
1050    #[test]
1051    fn test_filter_binary_array_with_null() {
1052        let data: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
1053        let a = BinaryArray::from(data);
1054        let b = BooleanArray::from(vec![true, false, false, true]);
1055        let c = filter(&a, &b).unwrap();
1056        let d = c.as_ref().as_any().downcast_ref::<BinaryArray>().unwrap();
1057        assert_eq!(2, d.len());
1058        assert_eq!(b"hello", d.value(0));
1059        assert!(!d.is_null(0));
1060        assert!(d.is_null(1));
1061    }
1062
1063    fn _test_filter_byte_view<T>()
1064    where
1065        T: ByteViewType,
1066        str: AsRef<T::Native>,
1067        T::Native: PartialEq,
1068    {
1069        let array = {
1070            // ["hello", "world", null, "large payload over 12 bytes", "lulu"]
1071            let mut builder = GenericByteViewBuilder::<T>::new();
1072            builder.append_value("hello");
1073            builder.append_value("world");
1074            builder.append_null();
1075            builder.append_value("large payload over 12 bytes");
1076            builder.append_value("lulu");
1077            builder.finish()
1078        };
1079
1080        {
1081            let predicate = BooleanArray::from(vec![true, false, true, true, false]);
1082            let actual = filter(&array, &predicate).unwrap();
1083
1084            assert_eq!(actual.len(), 3);
1085
1086            let expected = {
1087                // ["hello", null, "large payload over 12 bytes"]
1088                let mut builder = GenericByteViewBuilder::<T>::new();
1089                builder.append_value("hello");
1090                builder.append_null();
1091                builder.append_value("large payload over 12 bytes");
1092                builder.finish()
1093            };
1094
1095            assert_eq!(actual.as_ref(), &expected);
1096        }
1097
1098        {
1099            let predicate = BooleanArray::from(vec![true, false, false, false, true]);
1100            let actual = filter(&array, &predicate).unwrap();
1101
1102            assert_eq!(actual.len(), 2);
1103
1104            let expected = {
1105                // ["hello", "lulu"]
1106                let mut builder = GenericByteViewBuilder::<T>::new();
1107                builder.append_value("hello");
1108                builder.append_value("lulu");
1109                builder.finish()
1110            };
1111
1112            assert_eq!(actual.as_ref(), &expected);
1113        }
1114    }
1115
1116    #[test]
1117    fn test_filter_string_view() {
1118        _test_filter_byte_view::<StringViewType>()
1119    }
1120
1121    #[test]
1122    fn test_filter_binary_view() {
1123        _test_filter_byte_view::<BinaryViewType>()
1124    }
1125
1126    #[test]
1127    fn test_filter_fixed_binary() {
1128        let v1 = [1_u8, 2];
1129        let v2 = [3_u8, 4];
1130        let v3 = [5_u8, 6];
1131        let v = vec![&v1, &v2, &v3];
1132        let a = FixedSizeBinaryArray::from(v);
1133        let b = BooleanArray::from(vec![true, false, true]);
1134        let c = filter(&a, &b).unwrap();
1135        let d = c
1136            .as_ref()
1137            .as_any()
1138            .downcast_ref::<FixedSizeBinaryArray>()
1139            .unwrap();
1140        assert_eq!(d.len(), 2);
1141        assert_eq!(d.value(0), &v1);
1142        assert_eq!(d.value(1), &v3);
1143        let c2 = FilterBuilder::new(&b)
1144            .optimize()
1145            .build()
1146            .filter(&a)
1147            .unwrap();
1148        let d2 = c2
1149            .as_ref()
1150            .as_any()
1151            .downcast_ref::<FixedSizeBinaryArray>()
1152            .unwrap();
1153        assert_eq!(d, d2);
1154
1155        let b = BooleanArray::from(vec![false, false, false]);
1156        let c = filter(&a, &b).unwrap();
1157        let d = c
1158            .as_ref()
1159            .as_any()
1160            .downcast_ref::<FixedSizeBinaryArray>()
1161            .unwrap();
1162        assert_eq!(d.len(), 0);
1163
1164        let b = BooleanArray::from(vec![true, true, true]);
1165        let c = filter(&a, &b).unwrap();
1166        let d = c
1167            .as_ref()
1168            .as_any()
1169            .downcast_ref::<FixedSizeBinaryArray>()
1170            .unwrap();
1171        assert_eq!(d.len(), 3);
1172        assert_eq!(d.value(0), &v1);
1173        assert_eq!(d.value(1), &v2);
1174        assert_eq!(d.value(2), &v3);
1175
1176        let b = BooleanArray::from(vec![false, false, true]);
1177        let c = filter(&a, &b).unwrap();
1178        let d = c
1179            .as_ref()
1180            .as_any()
1181            .downcast_ref::<FixedSizeBinaryArray>()
1182            .unwrap();
1183        assert_eq!(d.len(), 1);
1184        assert_eq!(d.value(0), &v3);
1185        let c2 = FilterBuilder::new(&b)
1186            .optimize()
1187            .build()
1188            .filter(&a)
1189            .unwrap();
1190        let d2 = c2
1191            .as_ref()
1192            .as_any()
1193            .downcast_ref::<FixedSizeBinaryArray>()
1194            .unwrap();
1195        assert_eq!(d, d2);
1196    }
1197
1198    #[test]
1199    fn test_filter_array_slice_with_null() {
1200        let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
1201        let b = BooleanArray::from(vec![true, false, false, true]);
1202        // filtering with sliced filter array is not currently supported
1203        // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4);
1204        // let b = b_slice.as_any().downcast_ref().unwrap();
1205        let c = filter(&a, &b).unwrap();
1206        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1207        assert_eq!(2, d.len());
1208        assert!(d.is_null(0));
1209        assert!(!d.is_null(1));
1210        assert_eq!(9, d.value(1));
1211    }
1212
1213    #[test]
1214    fn test_filter_run_end_encoding_array() {
1215        let run_ends = Int64Array::from(vec![2, 3, 8]);
1216        let values = Int64Array::from(vec![7, -2, 9]);
1217        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1218        let b = BooleanArray::from(vec![true, false, true, false, true, false, true, false]);
1219        let c = filter(&a, &b).unwrap();
1220        let actual: &RunArray<Int64Type> = as_run_array(&c);
1221        assert_eq!(4, actual.len());
1222
1223        let expected = RunArray::try_new(
1224            &Int64Array::from(vec![1, 2, 4]),
1225            &Int64Array::from(vec![7, -2, 9]),
1226        )
1227        .expect("Failed to make expected RunArray test is broken");
1228
1229        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1230        assert_eq!(actual.values(), expected.values())
1231    }
1232
1233    #[test]
1234    fn test_filter_run_end_encoding_array_remove_value() {
1235        let run_ends = Int32Array::from(vec![2, 3, 8, 10]);
1236        let values = Int32Array::from(vec![7, -2, 9, -8]);
1237        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1238        let b = BooleanArray::from(vec![
1239            false, true, false, false, true, false, true, false, false, false,
1240        ]);
1241        let c = filter(&a, &b).unwrap();
1242        let actual: &RunArray<Int32Type> = as_run_array(&c);
1243        assert_eq!(3, actual.len());
1244
1245        let expected =
1246            RunArray::try_new(&Int32Array::from(vec![1, 3]), &Int32Array::from(vec![7, 9]))
1247                .expect("Failed to make expected RunArray test is broken");
1248
1249        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1250        assert_eq!(actual.values(), expected.values())
1251    }
1252
1253    #[test]
1254    fn test_filter_run_end_encoding_array_remove_all_but_one() {
1255        let run_ends = Int16Array::from(vec![2, 3, 8, 10]);
1256        let values = Int16Array::from(vec![7, -2, 9, -8]);
1257        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1258        let b = BooleanArray::from(vec![
1259            false, false, false, false, false, false, true, false, false, false,
1260        ]);
1261        let c = filter(&a, &b).unwrap();
1262        let actual: &RunArray<Int16Type> = as_run_array(&c);
1263        assert_eq!(1, actual.len());
1264
1265        let expected = RunArray::try_new(&Int16Array::from(vec![1]), &Int16Array::from(vec![9]))
1266            .expect("Failed to make expected RunArray test is broken");
1267
1268        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1269        assert_eq!(actual.values(), expected.values())
1270    }
1271
1272    #[test]
1273    fn test_filter_run_end_encoding_array_empty() {
1274        let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1275        let values = Int64Array::from(vec![7, -2, 9, -8]);
1276        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1277        let b = BooleanArray::from(vec![
1278            false, false, false, false, false, false, false, false, false, false,
1279        ]);
1280        let c = filter(&a, &b).unwrap();
1281        let actual: &RunArray<Int64Type> = as_run_array(&c);
1282        assert_eq!(0, actual.len());
1283    }
1284
1285    #[test]
1286    fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() {
1287        let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1288        let values = Int64Array::from(vec![7, -2, 9, -8]);
1289        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1290        let b = BooleanArray::from(vec![false, true, true]);
1291        let c = filter(&a, &b).unwrap();
1292        let actual: &RunArray<Int64Type> = as_run_array(&c);
1293        assert_eq!(2, actual.len());
1294
1295        let expected = RunArray::try_new(
1296            &Int64Array::from(vec![1, 2]),
1297            &Int64Array::from(vec![7, -2]),
1298        )
1299        .expect("Failed to make expected RunArray test is broken");
1300
1301        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1302        assert_eq!(actual.values(), expected.values())
1303    }
1304
1305    #[test]
1306    fn test_filter_dictionary_array() {
1307        let values = [Some("hello"), None, Some("world"), Some("!")];
1308        let a: Int8DictionaryArray = values.iter().copied().collect();
1309        let b = BooleanArray::from(vec![false, true, true, false]);
1310        let c = filter(&a, &b).unwrap();
1311        let d = c
1312            .as_ref()
1313            .as_any()
1314            .downcast_ref::<Int8DictionaryArray>()
1315            .unwrap();
1316        let value_array = d.values();
1317        let values = value_array.as_any().downcast_ref::<StringArray>().unwrap();
1318        // values are cloned in the filtered dictionary array
1319        assert_eq!(3, values.len());
1320        // but keys are filtered
1321        assert_eq!(2, d.len());
1322        assert!(d.is_null(0));
1323        assert_eq!("world", values.value(d.keys().value(1) as usize));
1324    }
1325
1326    #[test]
1327    fn test_filter_list_array() {
1328        let value_data = ArrayData::builder(DataType::Int32)
1329            .len(8)
1330            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
1331            .build()
1332            .unwrap();
1333
1334        let value_offsets = Buffer::from_slice_ref([0i64, 3, 6, 8, 8]);
1335
1336        let list_data_type =
1337            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
1338        let list_data = ArrayData::builder(list_data_type)
1339            .len(4)
1340            .add_buffer(value_offsets)
1341            .add_child_data(value_data)
1342            .null_bit_buffer(Some(Buffer::from([0b00000111])))
1343            .build()
1344            .unwrap();
1345
1346        //  a = [[0, 1, 2], [3, 4, 5], [6, 7], null]
1347        let a = LargeListArray::from(list_data);
1348        let b = BooleanArray::from(vec![false, true, false, true]);
1349        let result = filter(&a, &b).unwrap();
1350
1351        // expected: [[3, 4, 5], null]
1352        let value_data = ArrayData::builder(DataType::Int32)
1353            .len(3)
1354            .add_buffer(Buffer::from_slice_ref([3, 4, 5]))
1355            .build()
1356            .unwrap();
1357
1358        let value_offsets = Buffer::from_slice_ref([0i64, 3, 3]);
1359
1360        let list_data_type =
1361            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
1362        let expected = ArrayData::builder(list_data_type)
1363            .len(2)
1364            .add_buffer(value_offsets)
1365            .add_child_data(value_data)
1366            .null_bit_buffer(Some(Buffer::from([0b00000001])))
1367            .build()
1368            .unwrap();
1369
1370        assert_eq!(&make_array(expected), &result);
1371    }
1372
1373    #[test]
1374    fn test_slice_iterator_bits() {
1375        let filter_values = (0..64).map(|i| i == 1).collect::<Vec<bool>>();
1376        let filter = BooleanArray::from(filter_values);
1377        let filter_count = filter_count(&filter);
1378
1379        let iter = SlicesIterator::new(&filter);
1380        let chunks = iter.collect::<Vec<_>>();
1381
1382        assert_eq!(chunks, vec![(1, 2)]);
1383        assert_eq!(filter_count, 1);
1384    }
1385
1386    #[test]
1387    fn test_slice_iterator_bits1() {
1388        let filter_values = (0..64).map(|i| i != 1).collect::<Vec<bool>>();
1389        let filter = BooleanArray::from(filter_values);
1390        let filter_count = filter_count(&filter);
1391
1392        let iter = SlicesIterator::new(&filter);
1393        let chunks = iter.collect::<Vec<_>>();
1394
1395        assert_eq!(chunks, vec![(0, 1), (2, 64)]);
1396        assert_eq!(filter_count, 64 - 1);
1397    }
1398
1399    #[test]
1400    fn test_slice_iterator_chunk_and_bits() {
1401        let filter_values = (0..130).map(|i| i % 62 != 0).collect::<Vec<bool>>();
1402        let filter = BooleanArray::from(filter_values);
1403        let filter_count = filter_count(&filter);
1404
1405        let iter = SlicesIterator::new(&filter);
1406        let chunks = iter.collect::<Vec<_>>();
1407
1408        assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]);
1409        assert_eq!(filter_count, 61 + 61 + 5);
1410    }
1411
1412    #[test]
1413    fn test_null_mask() {
1414        let a = Int64Array::from(vec![Some(1), Some(2), None]);
1415
1416        let mask1 = BooleanArray::from(vec![Some(true), Some(true), None]);
1417        let out = filter(&a, &mask1).unwrap();
1418        assert_eq!(out.as_ref(), &a.slice(0, 2));
1419    }
1420
1421    #[test]
1422    fn test_filter_record_batch_no_columns() {
1423        let pred = BooleanArray::from(vec![Some(true), Some(true), None]);
1424        let options = RecordBatchOptions::default().with_row_count(Some(100));
1425        let record_batch =
1426            RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &options).unwrap();
1427        let out = filter_record_batch(&record_batch, &pred).unwrap();
1428
1429        assert_eq!(out.num_rows(), 2);
1430    }
1431
1432    #[test]
1433    fn test_fast_path() {
1434        let a: PrimitiveArray<Int64Type> = PrimitiveArray::from(vec![Some(1), Some(2), None]);
1435
1436        // all true
1437        let mask = BooleanArray::from(vec![true, true, true]);
1438        let out = filter(&a, &mask).unwrap();
1439        let b = out
1440            .as_any()
1441            .downcast_ref::<PrimitiveArray<Int64Type>>()
1442            .unwrap();
1443        assert_eq!(&a, b);
1444
1445        // all false
1446        let mask = BooleanArray::from(vec![false, false, false]);
1447        let out = filter(&a, &mask).unwrap();
1448        assert_eq!(out.len(), 0);
1449        assert_eq!(out.data_type(), &DataType::Int64);
1450    }
1451
1452    #[test]
1453    fn test_slices() {
1454        // takes up 2 u64s
1455        let bools = std::iter::repeat_n(true, 10)
1456            .chain(std::iter::repeat_n(false, 30))
1457            .chain(std::iter::repeat_n(true, 20))
1458            .chain(std::iter::repeat_n(false, 17))
1459            .chain(std::iter::repeat_n(true, 4));
1460
1461        let bool_array: BooleanArray = bools.map(Some).collect();
1462
1463        let slices: Vec<_> = SlicesIterator::new(&bool_array).collect();
1464        let expected = vec![(0, 10), (40, 60), (77, 81)];
1465        assert_eq!(slices, expected);
1466
1467        // slice with offset and truncated len
1468        let len = bool_array.len();
1469        let sliced_array = bool_array.slice(7, len - 10);
1470        let sliced_array = sliced_array
1471            .as_any()
1472            .downcast_ref::<BooleanArray>()
1473            .unwrap();
1474        let slices: Vec<_> = SlicesIterator::new(sliced_array).collect();
1475        let expected = vec![(0, 3), (33, 53), (70, 71)];
1476        assert_eq!(slices, expected);
1477    }
1478
1479    fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) {
1480        let mut rng = rng();
1481
1482        let bools: Vec<bool> = std::iter::from_fn(|| Some(rng.random()))
1483            .take(mask_len)
1484            .collect();
1485
1486        let buffer = Buffer::from_iter(bools.iter().cloned());
1487
1488        let truncated_length = mask_len - offset - truncate;
1489
1490        let data = ArrayDataBuilder::new(DataType::Boolean)
1491            .len(truncated_length)
1492            .offset(offset)
1493            .add_buffer(buffer)
1494            .build()
1495            .unwrap();
1496
1497        let filter = BooleanArray::from(data);
1498
1499        let slice_bits: Vec<_> = SlicesIterator::new(&filter)
1500            .flat_map(|(start, end)| start..end)
1501            .collect();
1502
1503        let count = filter_count(&filter);
1504        let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect();
1505
1506        let expected_bits: Vec<_> = bools
1507            .iter()
1508            .skip(offset)
1509            .take(truncated_length)
1510            .enumerate()
1511            .flat_map(|(idx, v)| v.then(|| idx))
1512            .collect();
1513
1514        assert_eq!(slice_bits, expected_bits);
1515        assert_eq!(index_bits, expected_bits);
1516    }
1517
1518    #[test]
1519    #[cfg_attr(miri, ignore)]
1520    fn fuzz_test_slices_iterator() {
1521        let mut rng = rng();
1522
1523        let uusize = UniformUsize::new(usize::MIN, usize::MAX).unwrap();
1524        for _ in 0..100 {
1525            let mask_len = rng.random_range(0..1024);
1526            let max_offset = 64.min(mask_len);
1527            let offset = uusize.sample(&mut rng).checked_rem(max_offset).unwrap_or(0);
1528
1529            let max_truncate = 128.min(mask_len - offset);
1530            let truncate = uusize
1531                .sample(&mut rng)
1532                .checked_rem(max_truncate)
1533                .unwrap_or(0);
1534
1535            test_slices_fuzz(mask_len, offset, truncate);
1536        }
1537
1538        test_slices_fuzz(64, 0, 0);
1539        test_slices_fuzz(64, 8, 0);
1540        test_slices_fuzz(64, 8, 8);
1541        test_slices_fuzz(32, 8, 8);
1542        test_slices_fuzz(32, 5, 9);
1543    }
1544
1545    /// Filters `values` by `predicate` using standard rust iterators
1546    fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> {
1547        values
1548            .into_iter()
1549            .zip(predicate)
1550            .filter(|(_, x)| **x)
1551            .map(|(a, _)| a)
1552            .collect()
1553    }
1554
1555    /// Generates an array of length `len` with `valid_percent` non-null values
1556    fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>>
1557    where
1558        StandardUniform: Distribution<T>,
1559    {
1560        let mut rng = rng();
1561        (0..len)
1562            .map(|_| rng.random_bool(valid_percent).then(|| rng.random()))
1563            .collect()
1564    }
1565
1566    /// Generates an array of length `len` with `valid_percent` non-null values
1567    fn gen_strings(
1568        len: usize,
1569        valid_percent: f64,
1570        str_len_range: std::ops::Range<usize>,
1571    ) -> Vec<Option<String>> {
1572        let mut rng = rng();
1573        (0..len)
1574            .map(|_| {
1575                rng.random_bool(valid_percent).then(|| {
1576                    let len = rng.random_range(str_len_range.clone());
1577                    (0..len)
1578                        .map(|_| char::from(rng.sample(Alphanumeric)))
1579                        .collect()
1580                })
1581            })
1582            .collect()
1583    }
1584
1585    /// Returns an iterator that calls `Option::as_deref` on each item
1586    fn as_deref<T: std::ops::Deref>(src: &[Option<T>]) -> impl Iterator<Item = Option<&T::Target>> {
1587        src.iter().map(|x| x.as_deref())
1588    }
1589
1590    #[test]
1591    #[cfg_attr(miri, ignore)]
1592    fn fuzz_filter() {
1593        let mut rng = rng();
1594
1595        for i in 0..100 {
1596            let filter_percent = match i {
1597                0..=4 => 1.,
1598                5..=10 => 0.,
1599                _ => rng.random_range(0.0..1.0),
1600            };
1601
1602            let valid_percent = rng.random_range(0.0..1.0);
1603
1604            let array_len = rng.random_range(32..256);
1605            let array_offset = rng.random_range(0..10);
1606
1607            // Construct a predicate
1608            let filter_offset = rng.random_range(0..10);
1609            let filter_truncate = rng.random_range(0..10);
1610            let bools: Vec<_> = std::iter::from_fn(|| Some(rng.random_bool(filter_percent)))
1611                .take(array_len + filter_offset - filter_truncate)
1612                .collect();
1613
1614            let predicate = BooleanArray::from_iter(bools.iter().cloned().map(Some));
1615
1616            // Offset predicate
1617            let predicate = predicate.slice(filter_offset, array_len - filter_truncate);
1618            let predicate = predicate.as_any().downcast_ref::<BooleanArray>().unwrap();
1619            let bools = &bools[filter_offset..];
1620
1621            // Test i32
1622            let values = gen_primitive(array_len + array_offset, valid_percent);
1623            let src = Int32Array::from_iter(values.iter().cloned());
1624
1625            let src = src.slice(array_offset, array_len);
1626            let src = src.as_any().downcast_ref::<Int32Array>().unwrap();
1627            let values = &values[array_offset..];
1628
1629            let filtered = filter(src, predicate).unwrap();
1630            let array = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
1631            let actual: Vec<_> = array.iter().collect();
1632
1633            assert_eq!(actual, filter_rust(values.iter().cloned(), bools));
1634
1635            // Test string
1636            let strings = gen_strings(array_len + array_offset, valid_percent, 0..20);
1637            let src = StringArray::from_iter(as_deref(&strings));
1638
1639            let src = src.slice(array_offset, array_len);
1640            let src = src.as_any().downcast_ref::<StringArray>().unwrap();
1641
1642            let filtered = filter(src, predicate).unwrap();
1643            let array = filtered.as_any().downcast_ref::<StringArray>().unwrap();
1644            let actual: Vec<_> = array.iter().collect();
1645
1646            let expected_strings = filter_rust(as_deref(&strings[array_offset..]), bools);
1647            assert_eq!(actual, expected_strings);
1648
1649            // Test string dictionary
1650            let src = DictionaryArray::<Int32Type>::from_iter(as_deref(&strings));
1651
1652            let src = src.slice(array_offset, array_len);
1653            let src = src
1654                .as_any()
1655                .downcast_ref::<DictionaryArray<Int32Type>>()
1656                .unwrap();
1657
1658            let filtered = filter(src, predicate).unwrap();
1659
1660            let array = filtered
1661                .as_any()
1662                .downcast_ref::<DictionaryArray<Int32Type>>()
1663                .unwrap();
1664
1665            let values = array
1666                .values()
1667                .as_any()
1668                .downcast_ref::<StringArray>()
1669                .unwrap();
1670
1671            let actual: Vec<_> = array
1672                .keys()
1673                .iter()
1674                .map(|key| key.map(|key| values.value(key as usize)))
1675                .collect();
1676
1677            assert_eq!(actual, expected_strings);
1678        }
1679    }
1680
1681    #[test]
1682    fn test_filter_map() {
1683        let mut builder =
1684            MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(4));
1685        // [{"key1": 1}, {"key2": 2, "key3": 3}, null, {"key1": 1}
1686        builder.keys().append_value("key1");
1687        builder.values().append_value(1);
1688        builder.append(true).unwrap();
1689        builder.keys().append_value("key2");
1690        builder.keys().append_value("key3");
1691        builder.values().append_value(2);
1692        builder.values().append_value(3);
1693        builder.append(true).unwrap();
1694        builder.append(false).unwrap();
1695        builder.keys().append_value("key1");
1696        builder.values().append_value(1);
1697        builder.append(true).unwrap();
1698        let maparray = Arc::new(builder.finish()) as ArrayRef;
1699
1700        let indices = vec![Some(true), Some(false), Some(false), Some(true)]
1701            .into_iter()
1702            .collect::<BooleanArray>();
1703        let got = filter(&maparray, &indices).unwrap();
1704
1705        let mut builder =
1706            MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(2));
1707        builder.keys().append_value("key1");
1708        builder.values().append_value(1);
1709        builder.append(true).unwrap();
1710        builder.keys().append_value("key1");
1711        builder.values().append_value(1);
1712        builder.append(true).unwrap();
1713        let expected = Arc::new(builder.finish()) as ArrayRef;
1714
1715        assert_eq!(&expected, &got);
1716    }
1717
1718    #[test]
1719    fn test_filter_fixed_size_list_arrays() {
1720        let value_data = ArrayData::builder(DataType::Int32)
1721            .len(9)
1722            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8]))
1723            .build()
1724            .unwrap();
1725        let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 3, false);
1726        let list_data = ArrayData::builder(list_data_type)
1727            .len(3)
1728            .add_child_data(value_data)
1729            .build()
1730            .unwrap();
1731        let array = FixedSizeListArray::from(list_data);
1732
1733        let filter_array = BooleanArray::from(vec![true, false, false]);
1734
1735        let c = filter(&array, &filter_array).unwrap();
1736        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1737
1738        assert_eq!(filtered.len(), 1);
1739
1740        let list = filtered.value(0);
1741        assert_eq!(
1742            &[0, 1, 2],
1743            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1744        );
1745
1746        let filter_array = BooleanArray::from(vec![true, false, true]);
1747
1748        let c = filter(&array, &filter_array).unwrap();
1749        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1750
1751        assert_eq!(filtered.len(), 2);
1752
1753        let list = filtered.value(0);
1754        assert_eq!(
1755            &[0, 1, 2],
1756            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1757        );
1758        let list = filtered.value(1);
1759        assert_eq!(
1760            &[6, 7, 8],
1761            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1762        );
1763    }
1764
1765    #[test]
1766    fn test_filter_fixed_size_list_arrays_with_null() {
1767        let value_data = ArrayData::builder(DataType::Int32)
1768            .len(10)
1769            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
1770            .build()
1771            .unwrap();
1772
1773        // Set null buts for the nested array:
1774        //  [[0, 1], null, null, [6, 7], [8, 9]]
1775        // 01011001 00000001
1776        let mut null_bits: [u8; 1] = [0; 1];
1777        bit_util::set_bit(&mut null_bits, 0);
1778        bit_util::set_bit(&mut null_bits, 3);
1779        bit_util::set_bit(&mut null_bits, 4);
1780
1781        let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 2, false);
1782        let list_data = ArrayData::builder(list_data_type)
1783            .len(5)
1784            .add_child_data(value_data)
1785            .null_bit_buffer(Some(Buffer::from(null_bits)))
1786            .build()
1787            .unwrap();
1788        let array = FixedSizeListArray::from(list_data);
1789
1790        let filter_array = BooleanArray::from(vec![true, true, false, true, false]);
1791
1792        let c = filter(&array, &filter_array).unwrap();
1793        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1794
1795        assert_eq!(filtered.len(), 3);
1796
1797        let list = filtered.value(0);
1798        assert_eq!(
1799            &[0, 1],
1800            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1801        );
1802        assert!(filtered.is_null(1));
1803        let list = filtered.value(2);
1804        assert_eq!(
1805            &[6, 7],
1806            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1807        );
1808    }
1809
1810    fn test_filter_union_array(array: UnionArray) {
1811        let filter_array = BooleanArray::from(vec![true, false, false]);
1812        let c = filter(&array, &filter_array).unwrap();
1813        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1814
1815        let mut builder = UnionBuilder::new_dense();
1816        builder.append::<Int32Type>("A", 1).unwrap();
1817        let expected_array = builder.build().unwrap();
1818
1819        compare_union_arrays(filtered, &expected_array);
1820
1821        let filter_array = BooleanArray::from(vec![true, false, true]);
1822        let c = filter(&array, &filter_array).unwrap();
1823        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1824
1825        let mut builder = UnionBuilder::new_dense();
1826        builder.append::<Int32Type>("A", 1).unwrap();
1827        builder.append::<Int32Type>("A", 34).unwrap();
1828        let expected_array = builder.build().unwrap();
1829
1830        compare_union_arrays(filtered, &expected_array);
1831
1832        let filter_array = BooleanArray::from(vec![true, true, false]);
1833        let c = filter(&array, &filter_array).unwrap();
1834        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1835
1836        let mut builder = UnionBuilder::new_dense();
1837        builder.append::<Int32Type>("A", 1).unwrap();
1838        builder.append::<Float64Type>("B", 3.2).unwrap();
1839        let expected_array = builder.build().unwrap();
1840
1841        compare_union_arrays(filtered, &expected_array);
1842    }
1843
1844    #[test]
1845    fn test_filter_union_array_dense() {
1846        let mut builder = UnionBuilder::new_dense();
1847        builder.append::<Int32Type>("A", 1).unwrap();
1848        builder.append::<Float64Type>("B", 3.2).unwrap();
1849        builder.append::<Int32Type>("A", 34).unwrap();
1850        let array = builder.build().unwrap();
1851
1852        test_filter_union_array(array);
1853    }
1854
1855    #[test]
1856    fn test_filter_run_union_array_dense() {
1857        let mut builder = UnionBuilder::new_dense();
1858        builder.append::<Int32Type>("A", 1).unwrap();
1859        builder.append::<Int32Type>("A", 3).unwrap();
1860        builder.append::<Int32Type>("A", 34).unwrap();
1861        let array = builder.build().unwrap();
1862
1863        let filter_array = BooleanArray::from(vec![true, true, false]);
1864        let c = filter(&array, &filter_array).unwrap();
1865        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1866
1867        let mut builder = UnionBuilder::new_dense();
1868        builder.append::<Int32Type>("A", 1).unwrap();
1869        builder.append::<Int32Type>("A", 3).unwrap();
1870        let expected = builder.build().unwrap();
1871
1872        assert_eq!(filtered.to_data(), expected.to_data());
1873    }
1874
1875    #[test]
1876    fn test_filter_union_array_dense_with_nulls() {
1877        let mut builder = UnionBuilder::new_dense();
1878        builder.append::<Int32Type>("A", 1).unwrap();
1879        builder.append::<Float64Type>("B", 3.2).unwrap();
1880        builder.append_null::<Float64Type>("B").unwrap();
1881        builder.append::<Int32Type>("A", 34).unwrap();
1882        let array = builder.build().unwrap();
1883
1884        let filter_array = BooleanArray::from(vec![true, true, false, false]);
1885        let c = filter(&array, &filter_array).unwrap();
1886        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1887
1888        let mut builder = UnionBuilder::new_dense();
1889        builder.append::<Int32Type>("A", 1).unwrap();
1890        builder.append::<Float64Type>("B", 3.2).unwrap();
1891        let expected_array = builder.build().unwrap();
1892
1893        compare_union_arrays(filtered, &expected_array);
1894
1895        let filter_array = BooleanArray::from(vec![true, false, true, false]);
1896        let c = filter(&array, &filter_array).unwrap();
1897        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1898
1899        let mut builder = UnionBuilder::new_dense();
1900        builder.append::<Int32Type>("A", 1).unwrap();
1901        builder.append_null::<Float64Type>("B").unwrap();
1902        let expected_array = builder.build().unwrap();
1903
1904        compare_union_arrays(filtered, &expected_array);
1905    }
1906
1907    #[test]
1908    fn test_filter_union_array_sparse() {
1909        let mut builder = UnionBuilder::new_sparse();
1910        builder.append::<Int32Type>("A", 1).unwrap();
1911        builder.append::<Float64Type>("B", 3.2).unwrap();
1912        builder.append::<Int32Type>("A", 34).unwrap();
1913        let array = builder.build().unwrap();
1914
1915        test_filter_union_array(array);
1916    }
1917
1918    #[test]
1919    fn test_filter_union_array_sparse_with_nulls() {
1920        let mut builder = UnionBuilder::new_sparse();
1921        builder.append::<Int32Type>("A", 1).unwrap();
1922        builder.append::<Float64Type>("B", 3.2).unwrap();
1923        builder.append_null::<Float64Type>("B").unwrap();
1924        builder.append::<Int32Type>("A", 34).unwrap();
1925        let array = builder.build().unwrap();
1926
1927        let filter_array = BooleanArray::from(vec![true, false, true, false]);
1928        let c = filter(&array, &filter_array).unwrap();
1929        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1930
1931        let mut builder = UnionBuilder::new_sparse();
1932        builder.append::<Int32Type>("A", 1).unwrap();
1933        builder.append_null::<Float64Type>("B").unwrap();
1934        let expected_array = builder.build().unwrap();
1935
1936        compare_union_arrays(filtered, &expected_array);
1937    }
1938
1939    fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) {
1940        assert_eq!(union1.len(), union2.len());
1941
1942        for i in 0..union1.len() {
1943            let type_id = union1.type_id(i);
1944
1945            let slot1 = union1.value(i);
1946            let slot2 = union2.value(i);
1947
1948            assert_eq!(slot1.is_null(0), slot2.is_null(0));
1949
1950            if !slot1.is_null(0) && !slot2.is_null(0) {
1951                match type_id {
1952                    0 => {
1953                        let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
1954                        assert_eq!(slot1.len(), 1);
1955                        let value1 = slot1.value(0);
1956
1957                        let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap();
1958                        assert_eq!(slot2.len(), 1);
1959                        let value2 = slot2.value(0);
1960                        assert_eq!(value1, value2);
1961                    }
1962                    1 => {
1963                        let slot1 = slot1.as_any().downcast_ref::<Float64Array>().unwrap();
1964                        assert_eq!(slot1.len(), 1);
1965                        let value1 = slot1.value(0);
1966
1967                        let slot2 = slot2.as_any().downcast_ref::<Float64Array>().unwrap();
1968                        assert_eq!(slot2.len(), 1);
1969                        let value2 = slot2.value(0);
1970                        assert_eq!(value1, value2);
1971                    }
1972                    _ => unreachable!(),
1973                }
1974            }
1975        }
1976    }
1977
1978    #[test]
1979    fn test_filter_struct() {
1980        let predicate = BooleanArray::from(vec![true, false, true, false]);
1981
1982        let a = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
1983        let a_filtered = Arc::new(StringArray::from(vec!["hello", "world"]));
1984
1985        let b = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
1986        let b_filtered = Arc::new(Int32Array::from(vec![5, 7]));
1987
1988        let null_mask = NullBuffer::from(vec![true, false, false, true]);
1989        let null_mask_filtered = NullBuffer::from(vec![true, false]);
1990
1991        let a_field = Field::new("a", DataType::Utf8, false);
1992        let b_field = Field::new("b", DataType::Int32, false);
1993
1994        let array = StructArray::new(vec![a_field.clone()].into(), vec![a.clone()], None);
1995        let expected =
1996            StructArray::new(vec![a_field.clone()].into(), vec![a_filtered.clone()], None);
1997
1998        let result = filter(&array, &predicate).unwrap();
1999
2000        assert_eq!(result.to_data(), expected.to_data());
2001
2002        let array = StructArray::new(
2003            vec![a_field.clone()].into(),
2004            vec![a.clone()],
2005            Some(null_mask.clone()),
2006        );
2007        let expected = StructArray::new(
2008            vec![a_field.clone()].into(),
2009            vec![a_filtered.clone()],
2010            Some(null_mask_filtered.clone()),
2011        );
2012
2013        let result = filter(&array, &predicate).unwrap();
2014
2015        assert_eq!(result.to_data(), expected.to_data());
2016
2017        let array = StructArray::new(
2018            vec![a_field.clone(), b_field.clone()].into(),
2019            vec![a.clone(), b.clone()],
2020            None,
2021        );
2022        let expected = StructArray::new(
2023            vec![a_field.clone(), b_field.clone()].into(),
2024            vec![a_filtered.clone(), b_filtered.clone()],
2025            None,
2026        );
2027
2028        let result = filter(&array, &predicate).unwrap();
2029
2030        assert_eq!(result.to_data(), expected.to_data());
2031
2032        let array = StructArray::new(
2033            vec![a_field.clone(), b_field.clone()].into(),
2034            vec![a.clone(), b.clone()],
2035            Some(null_mask.clone()),
2036        );
2037
2038        let expected = StructArray::new(
2039            vec![a_field.clone(), b_field.clone()].into(),
2040            vec![a_filtered.clone(), b_filtered.clone()],
2041            Some(null_mask_filtered.clone()),
2042        );
2043
2044        let result = filter(&array, &predicate).unwrap();
2045
2046        assert_eq!(result.to_data(), expected.to_data());
2047    }
2048
2049    #[test]
2050    fn test_filter_empty_struct() {
2051        /*
2052            "a": {
2053                "b": int64,
2054                "c": {}
2055            },
2056        */
2057        let fields = arrow_schema::Field::new(
2058            "a",
2059            arrow_schema::DataType::Struct(arrow_schema::Fields::from(vec![
2060                arrow_schema::Field::new("b", arrow_schema::DataType::Int64, true),
2061                arrow_schema::Field::new(
2062                    "c",
2063                    arrow_schema::DataType::Struct(arrow_schema::Fields::empty()),
2064                    true,
2065                ),
2066            ])),
2067            true,
2068        );
2069
2070        /* Test record
2071            {"a":{"c": {}}}
2072            {"a":{"c": {}}}
2073            {"a":{"c": {}}}
2074        */
2075
2076        // Create the record batch with the nested struct array
2077        let schema = Arc::new(Schema::new(vec![fields]));
2078
2079        let b = Arc::new(Int64Array::from(vec![None, None, None]));
2080        let c = Arc::new(StructArray::new_empty_fields(
2081            3,
2082            Some(NullBuffer::from(vec![true, true, true])),
2083        ));
2084        let a = StructArray::new(
2085            vec![
2086                Field::new("b", DataType::Int64, true),
2087                Field::new("c", DataType::Struct(Fields::empty()), true),
2088            ]
2089            .into(),
2090            vec![b.clone(), c.clone()],
2091            Some(NullBuffer::from(vec![true, true, true])),
2092        );
2093        let record_batch = RecordBatch::try_new(schema, vec![Arc::new(a)]).unwrap();
2094        println!("{record_batch:?}");
2095
2096        // Apply the filter
2097        let predicate = BooleanArray::from(vec![true, false, true]);
2098        let filtered_batch = filter_record_batch(&record_batch, &predicate).unwrap();
2099
2100        // The filtered batch should have 2 rows (the 1st and 3rd)
2101        assert_eq!(filtered_batch.num_rows(), 2);
2102    }
2103}