arrow_select/
filter.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines filter kernels
19
20use std::ops::AddAssign;
21use std::sync::Arc;
22
23use arrow_array::builder::BooleanBufferBuilder;
24use arrow_array::cast::AsArray;
25use arrow_array::types::{
26    ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
27};
28use arrow_array::*;
29use arrow_buffer::{bit_util, ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer};
30use arrow_buffer::{Buffer, MutableBuffer};
31use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
32use arrow_data::transform::MutableArrayData;
33use arrow_data::ArrayDataBuilder;
34use arrow_schema::*;
35
/// Selectivity cut-off used to choose between the two lazy iteration strategies.
///
/// If the filter selects more than this fraction of rows, use
/// [`SlicesIterator`] to copy ranges of values. Otherwise iterate
/// over individual rows using [`IndexIterator`]
///
/// Threshold of 0.8 chosen based on <https://dl.acm.org/doi/abs/10.1145/3465998.3466009>
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
43
44/// An iterator of `(usize, usize)` each representing an interval
45/// `[start, end)` whose slots of a bitmap [Buffer] are true.
46///
47/// Each interval corresponds to a contiguous region of memory to be
48/// "taken" from an array to be filtered.
49///
50/// ## Notes:
51///
52/// 1. Ignores the validity bitmap (ignores nulls)
53///
54/// 2. Only performant for filters that copy across long contiguous runs
55#[derive(Debug)]
56pub struct SlicesIterator<'a>(BitSliceIterator<'a>);
57
58impl<'a> SlicesIterator<'a> {
59    /// Creates a new iterator from a [BooleanArray]
60    pub fn new(filter: &'a BooleanArray) -> Self {
61        Self(filter.values().set_slices())
62    }
63}
64
65impl Iterator for SlicesIterator<'_> {
66    type Item = (usize, usize);
67
68    fn next(&mut self) -> Option<Self::Item> {
69        self.0.next()
70    }
71}
72
/// An iterator of `usize` whose index in [`BooleanArray`] is true
///
/// This provides the best performance on most predicates, apart from those which keep
/// large runs and therefore favour [`SlicesIterator`]
///
/// The iterator yields exactly `remaining` indices and reports that exact number
/// from `size_hint`; callers in this file rely on that when passing it to the
/// `from_trusted_len_iter*` buffer constructors.
struct IndexIterator<'a> {
    /// Number of indices that still have to be yielded
    remaining: usize,
    /// Underlying iterator over the positions of set bits in the filter
    iter: BitIndexIterator<'a>,
}

impl<'a> IndexIterator<'a> {
    /// Creates an iterator over the set bits of `filter`, expected to yield
    /// `remaining` indices in total.
    ///
    /// # Panics
    ///
    /// Panics if `filter` contains nulls.
    fn new(filter: &'a BooleanArray, remaining: usize) -> Self {
        assert_eq!(filter.null_count(), 0);
        let iter = filter.values().set_indices();
        Self { remaining, iter }
    }
}

impl Iterator for IndexIterator<'_> {
    type Item = usize;

    /// Returns the next selected index, or `None` once `remaining` hits zero.
    ///
    /// # Panics
    ///
    /// Panics if the underlying bit iterator runs out before `remaining`
    /// indices have been produced.
    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining != 0 {
            // Must panic if exhausted early as trusted length iterator
            //
            // Fascinatingly swapping these two lines around results in a 50%
            // performance regression for some benchmarks
            let next = self.iter.next().expect("IndexIterator exhausted early");
            self.remaining -= 1;
            return Some(next);
        }
        None
    }

    /// Reports the exact number of indices left to yield
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}
109
110/// Counts the number of set bits in `filter`
111fn filter_count(filter: &BooleanArray) -> usize {
112    filter.values().count_set_bits()
113}
114
115/// Remove null values by do a bitmask AND operation with null bits and the boolean bits.
116pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
117    let nulls = filter.nulls().unwrap();
118    let mask = filter.values() & nulls.inner();
119    BooleanArray::new(mask, None)
120}
121
122/// Returns a filtered `values` [`Array`] where the corresponding elements of
123/// `predicate` are `true`.
124///
125/// # See also
126/// * [`FilterBuilder`] for more control over the filtering process.
127/// * [`filter_record_batch`] to filter a [`RecordBatch`]
128/// * [`BatchCoalescer`]: to filter multiple [`RecordBatch`] and coalesce
129///   the results into a single array.
130///
131/// [`BatchCoalescer`]: crate::coalesce::BatchCoalescer
132///
133/// # Example
134/// ```rust
135/// # use arrow_array::{Int32Array, BooleanArray};
136/// # use arrow_select::filter::filter;
137/// let array = Int32Array::from(vec![5, 6, 7, 8, 9]);
138/// let filter_array = BooleanArray::from(vec![true, false, false, true, false]);
139/// let c = filter(&array, &filter_array).unwrap();
140/// let c = c.as_any().downcast_ref::<Int32Array>().unwrap();
141/// assert_eq!(c, &Int32Array::from(vec![5, 8]));
142/// ```
143pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
144    let mut filter_builder = FilterBuilder::new(predicate);
145
146    if multiple_arrays(values.data_type()) {
147        // Only optimize if filtering more than one array
148        // Otherwise, the overhead of optimization can be more than the benefit
149        filter_builder = filter_builder.optimize();
150    }
151
152    let predicate = filter_builder.build();
153
154    filter_array(values, &predicate)
155}
156
157fn multiple_arrays(data_type: &DataType) -> bool {
158    match data_type {
159        DataType::Struct(fields) => {
160            fields.len() > 1 || fields.len() == 1 && multiple_arrays(fields[0].data_type())
161        }
162        DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
163        _ => false,
164    }
165}
166
167/// Returns a filtered [RecordBatch] where the corresponding elements of
168/// `predicate` are true.
169///
170/// This is the equivalent of calling [filter] on each column of the [RecordBatch].
171pub fn filter_record_batch(
172    record_batch: &RecordBatch,
173    predicate: &BooleanArray,
174) -> Result<RecordBatch, ArrowError> {
175    let mut filter_builder = FilterBuilder::new(predicate);
176    if record_batch.num_columns() > 1 {
177        // Only optimize if filtering more than one column
178        // Otherwise, the overhead of optimization can be more than the benefit
179        filter_builder = filter_builder.optimize();
180    }
181    let filter = filter_builder.build();
182
183    let filtered_arrays = record_batch
184        .columns()
185        .iter()
186        .map(|a| filter_array(a, &filter))
187        .collect::<Result<Vec<_>, _>>()?;
188    let options = RecordBatchOptions::default().with_row_count(Some(filter.count()));
189    RecordBatch::try_new_with_options(record_batch.schema(), filtered_arrays, &options)
190}
191
192/// A builder to construct [`FilterPredicate`]
193#[derive(Debug)]
194pub struct FilterBuilder {
195    filter: BooleanArray,
196    count: usize,
197    strategy: IterationStrategy,
198}
199
200impl FilterBuilder {
201    /// Create a new [`FilterBuilder`] that can be used to construct a [`FilterPredicate`]
202    pub fn new(filter: &BooleanArray) -> Self {
203        let filter = match filter.null_count() {
204            0 => filter.clone(),
205            _ => prep_null_mask_filter(filter),
206        };
207
208        let count = filter_count(&filter);
209        let strategy = IterationStrategy::default_strategy(filter.len(), count);
210
211        Self {
212            filter,
213            count,
214            strategy,
215        }
216    }
217
218    /// Compute an optimised representation of the provided `filter` mask that can be
219    /// applied to an array more quickly.
220    ///
221    /// Note: There is limited benefit to calling this to then filter a single array
222    /// Note: This will likely have a larger memory footprint than the original mask
223    pub fn optimize(mut self) -> Self {
224        match self.strategy {
225            IterationStrategy::SlicesIterator => {
226                let slices = SlicesIterator::new(&self.filter).collect();
227                self.strategy = IterationStrategy::Slices(slices)
228            }
229            IterationStrategy::IndexIterator => {
230                let indices = IndexIterator::new(&self.filter, self.count).collect();
231                self.strategy = IterationStrategy::Indices(indices)
232            }
233            _ => {}
234        }
235        self
236    }
237
238    /// Construct the final `FilterPredicate`
239    pub fn build(self) -> FilterPredicate {
240        FilterPredicate {
241            filter: self.filter,
242            count: self.count,
243            strategy: self.strategy,
244        }
245    }
246}
247
248/// The iteration strategy used to evaluate [`FilterPredicate`]
249#[derive(Debug)]
250enum IterationStrategy {
251    /// A lazily evaluated iterator of ranges
252    SlicesIterator,
253    /// A lazily evaluated iterator of indices
254    IndexIterator,
255    /// A precomputed list of indices
256    Indices(Vec<usize>),
257    /// A precomputed array of ranges
258    Slices(Vec<(usize, usize)>),
259    /// Select all rows
260    All,
261    /// Select no rows
262    None,
263}
264
265impl IterationStrategy {
266    /// The default [`IterationStrategy`] for a filter of length `filter_length`
267    /// and selecting `filter_count` rows
268    fn default_strategy(filter_length: usize, filter_count: usize) -> Self {
269        if filter_length == 0 || filter_count == 0 {
270            return IterationStrategy::None;
271        }
272
273        if filter_count == filter_length {
274            return IterationStrategy::All;
275        }
276
277        // Compute the selectivity of the predicate by dividing the number of true
278        // bits in the predicate by the predicate's total length
279        //
280        // This can then be used as a heuristic for the optimal iteration strategy
281        let selectivity_frac = filter_count as f64 / filter_length as f64;
282        if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD {
283            return IterationStrategy::SlicesIterator;
284        }
285        IterationStrategy::IndexIterator
286    }
287}
288
289/// A filtering predicate that can be applied to an [`Array`]
290#[derive(Debug)]
291pub struct FilterPredicate {
292    filter: BooleanArray,
293    count: usize,
294    strategy: IterationStrategy,
295}
296
297impl FilterPredicate {
298    /// Selects rows from `values` based on this [`FilterPredicate`]
299    pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef, ArrowError> {
300        filter_array(values, self)
301    }
302
303    /// Number of rows being selected based on this [`FilterPredicate`]
304    pub fn count(&self) -> usize {
305        self.count
306    }
307}
308
/// Applies `predicate` to `values`, dispatching on the data type of `values`.
///
/// # Errors
///
/// Returns [`ArrowError::InvalidArgumentError`] if the predicate's mask is
/// longer than `values`; errors from the type-specific kernels are propagated.
///
/// # Panics
///
/// Hits `unimplemented!` for run-end-encoded or dictionary arrays whose
/// concrete type is not covered by the corresponding downcast macro.
fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> {
    // A mask longer than the array would select rows that do not exist
    if predicate.filter.len() > values.len() {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Filter predicate of length {} is larger than target array of length {}",
            predicate.filter.len(),
            values.len()
        )));
    }

    match predicate.strategy {
        // Nothing selected: an empty array of the same type
        IterationStrategy::None => Ok(new_empty_array(values.data_type())),
        // Everything selected: slice off the first `count` rows without copying values
        IterationStrategy::All => Ok(values.slice(0, predicate.count)),
        // actually filter
        _ => downcast_primitive_array! {
            values => Ok(Arc::new(filter_primitive(values, predicate))),
            DataType::Boolean => {
                let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
                Ok(Arc::new(filter_boolean(values, predicate)))
            }
            DataType::Utf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i32>(), predicate)))
            }
            DataType::LargeUtf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
            }
            DataType::Utf8View => {
                Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
            }
            DataType::Binary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
            }
            DataType::LargeBinary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
            }
            DataType::BinaryView => {
                Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
            }
            DataType::FixedSizeBinary(_) => {
                Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate)))
            }
            DataType::RunEndEncoded(_, _) => {
                downcast_run_array!{
                    values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
                    t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t)
                }
            }
            DataType::Dictionary(_, _) => downcast_dictionary_array! {
                values => Ok(Arc::new(filter_dict(values, predicate))),
                t => unimplemented!("Filter not supported for dictionary type {:?}", t)
            }
            DataType::Struct(_) => {
                Ok(Arc::new(filter_struct(values.as_struct(), predicate)?))
            }
            DataType::Union(_, UnionMode::Sparse) => {
                Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?))
            }
            // Every other type is copied range by range
            _ => {
                let data = values.to_data();
                // fallback to using MutableArrayData
                let mut mutable = MutableArrayData::new(
                    vec![&data],
                    false,
                    predicate.count,
                );

                match &predicate.strategy {
                    // Reuse the precomputed ranges when they are available
                    IterationStrategy::Slices(slices) => {
                        slices
                            .iter()
                            .for_each(|(start, end)| mutable.extend(0, *start, *end));
                    }
                    // Otherwise derive the ranges from the mask on the fly
                    _ => {
                        let iter = SlicesIterator::new(&predicate.filter);
                        iter.for_each(|(start, end)| mutable.extend(0, start, end));
                    }
                }

                let data = mutable.freeze();
                Ok(make_array(data))
            }
        },
    }
}
392
393/// Filter any supported [`RunArray`] based on a [`FilterPredicate`]
394fn filter_run_end_array<R: RunEndIndexType>(
395    array: &RunArray<R>,
396    predicate: &FilterPredicate,
397) -> Result<RunArray<R>, ArrowError>
398where
399    R::Native: Into<i64> + From<bool>,
400    R::Native: AddAssign,
401{
402    let run_ends: &RunEndBuffer<R::Native> = array.run_ends();
403    let mut new_run_ends = vec![R::default_value(); run_ends.len()];
404
405    let mut start = 0u64;
406    let mut j = 0;
407    let mut count = R::default_value();
408    let filter_values = predicate.filter.values();
409    let run_ends = run_ends.inner();
410
411    let pred: BooleanArray = BooleanBuffer::collect_bool(run_ends.len(), |i| {
412        let mut keep = false;
413        let mut end = run_ends[i].into() as u64;
414        let difference = end.saturating_sub(filter_values.len() as u64);
415        end -= difference;
416
417        // Safety: we subtract the difference off `end` so we are always within bounds
418        for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
419            count += R::Native::from(pred);
420            keep |= pred
421        }
422        // this is to avoid branching
423        new_run_ends[j] = count;
424        j += keep as usize;
425
426        start = end;
427        keep
428    })
429    .into();
430
431    new_run_ends.truncate(j);
432
433    let values = array.values();
434    let values = filter(&values, &pred)?;
435
436    let run_ends = PrimitiveArray::<R>::new(new_run_ends.into(), None);
437    RunArray::try_new(&run_ends, &values)
438}
439
440/// Computes a new null mask for `data` based on `predicate`
441///
442/// If the predicate selected no null-rows, returns `None`, otherwise returns
443/// `Some((null_count, null_buffer))` where `null_count` is the number of nulls
444/// in the filtered output, and `null_buffer` is the filtered null buffer
445///
446fn filter_null_mask(
447    nulls: Option<&NullBuffer>,
448    predicate: &FilterPredicate,
449) -> Option<(usize, Buffer)> {
450    let nulls = nulls?;
451    if nulls.null_count() == 0 {
452        return None;
453    }
454
455    let nulls = filter_bits(nulls.inner(), predicate);
456    // The filtered `nulls` has a length of `predicate.count` bits and
457    // therefore the null count is this minus the number of valid bits
458    let null_count = predicate.count - nulls.count_set_bits_offset(0, predicate.count);
459
460    if null_count == 0 {
461        return None;
462    }
463
464    Some((null_count, nulls))
465}
466
467/// Filter the packed bitmask `buffer`, with `predicate` starting at bit offset `offset`
468fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer {
469    let src = buffer.values();
470    let offset = buffer.offset();
471
472    match &predicate.strategy {
473        IterationStrategy::IndexIterator => {
474            let bits = IndexIterator::new(&predicate.filter, predicate.count)
475                .map(|src_idx| bit_util::get_bit(src, src_idx + offset));
476
477            // SAFETY: `IndexIterator` reports its size correctly
478            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
479        }
480        IterationStrategy::Indices(indices) => {
481            let bits = indices
482                .iter()
483                .map(|src_idx| bit_util::get_bit(src, *src_idx + offset));
484
485            // SAFETY: `Vec::iter()` reports its size correctly
486            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
487        }
488        IterationStrategy::SlicesIterator => {
489            let mut builder = BooleanBufferBuilder::new(predicate.count);
490            for (start, end) in SlicesIterator::new(&predicate.filter) {
491                builder.append_packed_range(start + offset..end + offset, src)
492            }
493            builder.into()
494        }
495        IterationStrategy::Slices(slices) => {
496            let mut builder = BooleanBufferBuilder::new(predicate.count);
497            for (start, end) in slices {
498                builder.append_packed_range(*start + offset..*end + offset, src)
499            }
500            builder.into()
501        }
502        IterationStrategy::All | IterationStrategy::None => unreachable!(),
503    }
504}
505
506/// `filter` implementation for boolean buffers
507fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
508    let values = filter_bits(array.values(), predicate);
509
510    let mut builder = ArrayDataBuilder::new(DataType::Boolean)
511        .len(predicate.count)
512        .add_buffer(values);
513
514    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
515        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
516    }
517
518    let data = unsafe { builder.build_unchecked() };
519    BooleanArray::from(data)
520}
521
522#[inline(never)]
523fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer {
524    assert!(values.len() >= predicate.filter.len());
525
526    match &predicate.strategy {
527        IterationStrategy::SlicesIterator => {
528            let mut buffer = Vec::with_capacity(predicate.count);
529            for (start, end) in SlicesIterator::new(&predicate.filter) {
530                buffer.extend_from_slice(&values[start..end]);
531            }
532            buffer.into()
533        }
534        IterationStrategy::Slices(slices) => {
535            let mut buffer = Vec::with_capacity(predicate.count);
536            for (start, end) in slices {
537                buffer.extend_from_slice(&values[*start..*end]);
538            }
539            buffer.into()
540        }
541        IterationStrategy::IndexIterator => {
542            let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| values[x]);
543
544            // SAFETY: IndexIterator is trusted length
545            unsafe { MutableBuffer::from_trusted_len_iter(iter) }.into()
546        }
547        IterationStrategy::Indices(indices) => {
548            let iter = indices.iter().map(|x| values[*x]);
549            iter.collect::<Vec<_>>().into()
550        }
551        IterationStrategy::All | IterationStrategy::None => unreachable!(),
552    }
553}
554
555/// `filter` implementation for primitive arrays
556fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
557where
558    T: ArrowPrimitiveType,
559{
560    let values = array.values();
561    let buffer = filter_native(values, predicate);
562    let mut builder = ArrayDataBuilder::new(array.data_type().clone())
563        .len(predicate.count)
564        .add_buffer(buffer);
565
566    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
567        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
568    }
569
570    let data = unsafe { builder.build_unchecked() };
571    PrimitiveArray::from(data)
572}
573
574/// [`FilterBytes`] is created from a source [`GenericByteArray`] and can be
575/// used to build a new [`GenericByteArray`] by copying values from the source
576///
577/// TODO(raphael): Could this be used for the take kernel as well?
578struct FilterBytes<'a, OffsetSize> {
579    src_offsets: &'a [OffsetSize],
580    src_values: &'a [u8],
581    dst_offsets: Vec<OffsetSize>,
582    dst_values: Vec<u8>,
583    cur_offset: OffsetSize,
584}
585
586impl<'a, OffsetSize> FilterBytes<'a, OffsetSize>
587where
588    OffsetSize: OffsetSizeTrait,
589{
590    fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self
591    where
592        T: ByteArrayType<Offset = OffsetSize>,
593    {
594        let dst_values = Vec::new();
595        let mut dst_offsets: Vec<OffsetSize> = Vec::with_capacity(capacity + 1);
596        let cur_offset = OffsetSize::from_usize(0).unwrap();
597
598        dst_offsets.push(cur_offset);
599
600        Self {
601            src_offsets: array.value_offsets(),
602            src_values: array.value_data(),
603            dst_offsets,
604            dst_values,
605            cur_offset,
606        }
607    }
608
609    /// Returns the byte offset at `idx`
610    #[inline]
611    fn get_value_offset(&self, idx: usize) -> usize {
612        self.src_offsets[idx].as_usize()
613    }
614
615    /// Returns the start and end of the value at index `idx` along with its length
616    #[inline]
617    fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
618        // These can only fail if `array` contains invalid data
619        let start = self.get_value_offset(idx);
620        let end = self.get_value_offset(idx + 1);
621        let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
622        (start, end, len)
623    }
624
625    fn extend_offsets_idx(&mut self, iter: impl Iterator<Item = usize>) {
626        self.dst_offsets.extend(iter.map(|idx| {
627            let start = self.src_offsets[idx].as_usize();
628            let end = self.src_offsets[idx + 1].as_usize();
629            let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
630            self.cur_offset += len;
631
632            self.cur_offset
633        }));
634    }
635
636    /// Extends the in-progress array by the indexes in the provided iterator
637    fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
638        self.dst_values.reserve_exact(self.cur_offset.as_usize());
639
640        for idx in iter {
641            let start = self.src_offsets[idx].as_usize();
642            let end = self.src_offsets[idx + 1].as_usize();
643            self.dst_values
644                .extend_from_slice(&self.src_values[start..end]);
645        }
646    }
647
648    fn extend_offsets_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>, count: usize) {
649        self.dst_offsets.reserve_exact(count);
650        for (start, end) in iter {
651            // These can only fail if `array` contains invalid data
652            for idx in start..end {
653                let (_, _, len) = self.get_value_range(idx);
654                self.cur_offset += len;
655                self.dst_offsets.push(self.cur_offset);
656            }
657        }
658    }
659
660    /// Extends the in-progress array by the ranges in the provided iterator
661    fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
662        self.dst_values.reserve_exact(self.cur_offset.as_usize());
663
664        for (start, end) in iter {
665            let value_start = self.get_value_offset(start);
666            let value_end = self.get_value_offset(end);
667            self.dst_values
668                .extend_from_slice(&self.src_values[value_start..value_end]);
669        }
670    }
671}
672
673/// `filter` implementation for byte arrays
674///
675/// Note: NULLs with a non-zero slot length in `array` will have the corresponding
676/// data copied across. This allows handling the null mask separately from the data
677fn filter_bytes<T>(array: &GenericByteArray<T>, predicate: &FilterPredicate) -> GenericByteArray<T>
678where
679    T: ByteArrayType,
680{
681    let mut filter = FilterBytes::new(predicate.count, array);
682
683    match &predicate.strategy {
684        IterationStrategy::SlicesIterator => {
685            filter.extend_offsets_slices(SlicesIterator::new(&predicate.filter), predicate.count);
686            filter.extend_slices(SlicesIterator::new(&predicate.filter))
687        }
688        IterationStrategy::Slices(slices) => {
689            filter.extend_offsets_slices(slices.iter().cloned(), predicate.count);
690            filter.extend_slices(slices.iter().cloned())
691        }
692        IterationStrategy::IndexIterator => {
693            filter.extend_offsets_idx(IndexIterator::new(&predicate.filter, predicate.count));
694            filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count))
695        }
696        IterationStrategy::Indices(indices) => {
697            filter.extend_offsets_idx(indices.iter().cloned());
698            filter.extend_idx(indices.iter().cloned())
699        }
700        IterationStrategy::All | IterationStrategy::None => unreachable!(),
701    }
702
703    let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
704        .len(predicate.count)
705        .add_buffer(filter.dst_offsets.into())
706        .add_buffer(filter.dst_values.into());
707
708    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
709        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
710    }
711
712    let data = unsafe { builder.build_unchecked() };
713    GenericByteArray::from(data)
714}
715
716/// `filter` implementation for byte view arrays.
717fn filter_byte_view<T: ByteViewType>(
718    array: &GenericByteViewArray<T>,
719    predicate: &FilterPredicate,
720) -> GenericByteViewArray<T> {
721    let new_view_buffer = filter_native(array.views(), predicate);
722
723    let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
724        .len(predicate.count)
725        .add_buffer(new_view_buffer)
726        .add_buffers(array.data_buffers().to_vec());
727
728    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
729        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
730    }
731
732    GenericByteViewArray::from(unsafe { builder.build_unchecked() })
733}
734
735fn filter_fixed_size_binary(
736    array: &FixedSizeBinaryArray,
737    predicate: &FilterPredicate,
738) -> FixedSizeBinaryArray {
739    let values: &[u8] = array.values();
740    let value_length = array.value_length() as usize;
741    let calculate_offset_from_index = |index: usize| index * value_length;
742    let buffer = match &predicate.strategy {
743        IterationStrategy::SlicesIterator => {
744            let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
745            for (start, end) in SlicesIterator::new(&predicate.filter) {
746                buffer.extend_from_slice(
747                    &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)],
748                );
749            }
750            buffer
751        }
752        IterationStrategy::Slices(slices) => {
753            let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
754            for (start, end) in slices {
755                buffer.extend_from_slice(
756                    &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)],
757                );
758            }
759            buffer
760        }
761        IterationStrategy::IndexIterator => {
762            let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| {
763                &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)]
764            });
765
766            let mut buffer = MutableBuffer::new(predicate.count * value_length);
767            iter.for_each(|item| buffer.extend_from_slice(item));
768            buffer
769        }
770        IterationStrategy::Indices(indices) => {
771            let iter = indices.iter().map(|x| {
772                &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)]
773            });
774
775            let mut buffer = MutableBuffer::new(predicate.count * value_length);
776            iter.for_each(|item| buffer.extend_from_slice(item));
777            buffer
778        }
779        IterationStrategy::All | IterationStrategy::None => unreachable!(),
780    };
781    let mut builder = ArrayDataBuilder::new(array.data_type().clone())
782        .len(predicate.count)
783        .add_buffer(buffer.into());
784
785    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
786        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
787    }
788
789    let data = unsafe { builder.build_unchecked() };
790    FixedSizeBinaryArray::from(data)
791}
792
793/// `filter` implementation for dictionaries
794fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
795where
796    T: ArrowDictionaryKeyType,
797    T::Native: num::Num,
798{
799    let builder = filter_primitive::<T>(array.keys(), predicate)
800        .into_data()
801        .into_builder()
802        .data_type(array.data_type().clone())
803        .child_data(vec![array.values().to_data()]);
804
805    // SAFETY:
806    // Keys were valid before, filtered subset is therefore still valid
807    DictionaryArray::from(unsafe { builder.build_unchecked() })
808}
809
810/// `filter` implementation for structs
811fn filter_struct(
812    array: &StructArray,
813    predicate: &FilterPredicate,
814) -> Result<StructArray, ArrowError> {
815    let columns = array
816        .columns()
817        .iter()
818        .map(|column| filter_array(column, predicate))
819        .collect::<Result<_, _>>()?;
820
821    let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
822        let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
823
824        Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
825    } else {
826        None
827    };
828
829    Ok(unsafe {
830        StructArray::new_unchecked_with_length(
831            array.fields().clone(),
832            columns,
833            nulls,
834            predicate.count(),
835        )
836    })
837}
838
839/// `filter` implementation for sparse unions
840fn filter_sparse_union(
841    array: &UnionArray,
842    predicate: &FilterPredicate,
843) -> Result<UnionArray, ArrowError> {
844    let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
845        unreachable!()
846    };
847
848    let type_ids = filter_primitive(&Int8Array::new(array.type_ids().clone(), None), predicate);
849
850    let children = fields
851        .iter()
852        .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
853        .collect::<Result<_, _>>()?;
854
855    Ok(unsafe {
856        UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
857    })
858}
859
860#[cfg(test)]
861mod tests {
862    use super::*;
863    use arrow_array::builder::*;
864    use arrow_array::cast::as_run_array;
865    use arrow_array::types::*;
866    use arrow_data::ArrayData;
867    use rand::distr::uniform::{UniformSampler, UniformUsize};
868    use rand::distr::{Alphanumeric, StandardUniform};
869    use rand::prelude::*;
870    use rand::rng;
871
872    macro_rules! def_temporal_test {
873        ($test:ident, $array_type: ident, $data: expr) => {
874            #[test]
875            fn $test() {
876                let a = $data;
877                let b = BooleanArray::from(vec![true, false, true, false]);
878                let c = filter(&a, &b).unwrap();
879                let d = c.as_ref().as_any().downcast_ref::<$array_type>().unwrap();
880                assert_eq!(2, d.len());
881                assert_eq!(1, d.value(0));
882                assert_eq!(3, d.value(1));
883            }
884        };
885    }
886
887    def_temporal_test!(
888        test_filter_date32,
889        Date32Array,
890        Date32Array::from(vec![1, 2, 3, 4])
891    );
892    def_temporal_test!(
893        test_filter_date64,
894        Date64Array,
895        Date64Array::from(vec![1, 2, 3, 4])
896    );
897    def_temporal_test!(
898        test_filter_time32_second,
899        Time32SecondArray,
900        Time32SecondArray::from(vec![1, 2, 3, 4])
901    );
902    def_temporal_test!(
903        test_filter_time32_millisecond,
904        Time32MillisecondArray,
905        Time32MillisecondArray::from(vec![1, 2, 3, 4])
906    );
907    def_temporal_test!(
908        test_filter_time64_microsecond,
909        Time64MicrosecondArray,
910        Time64MicrosecondArray::from(vec![1, 2, 3, 4])
911    );
912    def_temporal_test!(
913        test_filter_time64_nanosecond,
914        Time64NanosecondArray,
915        Time64NanosecondArray::from(vec![1, 2, 3, 4])
916    );
917    def_temporal_test!(
918        test_filter_duration_second,
919        DurationSecondArray,
920        DurationSecondArray::from(vec![1, 2, 3, 4])
921    );
922    def_temporal_test!(
923        test_filter_duration_millisecond,
924        DurationMillisecondArray,
925        DurationMillisecondArray::from(vec![1, 2, 3, 4])
926    );
927    def_temporal_test!(
928        test_filter_duration_microsecond,
929        DurationMicrosecondArray,
930        DurationMicrosecondArray::from(vec![1, 2, 3, 4])
931    );
932    def_temporal_test!(
933        test_filter_duration_nanosecond,
934        DurationNanosecondArray,
935        DurationNanosecondArray::from(vec![1, 2, 3, 4])
936    );
937    def_temporal_test!(
938        test_filter_timestamp_second,
939        TimestampSecondArray,
940        TimestampSecondArray::from(vec![1, 2, 3, 4])
941    );
942    def_temporal_test!(
943        test_filter_timestamp_millisecond,
944        TimestampMillisecondArray,
945        TimestampMillisecondArray::from(vec![1, 2, 3, 4])
946    );
947    def_temporal_test!(
948        test_filter_timestamp_microsecond,
949        TimestampMicrosecondArray,
950        TimestampMicrosecondArray::from(vec![1, 2, 3, 4])
951    );
952    def_temporal_test!(
953        test_filter_timestamp_nanosecond,
954        TimestampNanosecondArray,
955        TimestampNanosecondArray::from(vec![1, 2, 3, 4])
956    );
957
958    #[test]
959    fn test_filter_array_slice() {
960        let a = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
961        let b = BooleanArray::from(vec![true, false, false, true]);
962        // filtering with sliced filter array is not currently supported
963        // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4);
964        // let b = b_slice.as_any().downcast_ref().unwrap();
965        let c = filter(&a, &b).unwrap();
966        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
967        assert_eq!(2, d.len());
968        assert_eq!(6, d.value(0));
969        assert_eq!(9, d.value(1));
970    }
971
972    #[test]
973    fn test_filter_array_low_density() {
974        // this test exercises the all 0's branch of the filter algorithm
975        let mut data_values = (1..=65).collect::<Vec<i32>>();
976        let mut filter_values = (1..=65).map(|i| matches!(i % 65, 0)).collect::<Vec<bool>>();
977        // set up two more values after the batch
978        data_values.extend_from_slice(&[66, 67]);
979        filter_values.extend_from_slice(&[false, true]);
980        let a = Int32Array::from(data_values);
981        let b = BooleanArray::from(filter_values);
982        let c = filter(&a, &b).unwrap();
983        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
984        assert_eq!(2, d.len());
985        assert_eq!(65, d.value(0));
986        assert_eq!(67, d.value(1));
987    }
988
989    #[test]
990    fn test_filter_array_high_density() {
991        // this test exercises the all 1's branch of the filter algorithm
992        let mut data_values = (1..=65).map(Some).collect::<Vec<_>>();
993        let mut filter_values = (1..=65)
994            .map(|i| !matches!(i % 65, 0))
995            .collect::<Vec<bool>>();
996        // set second data value to null
997        data_values[1] = None;
998        // set up two more values after the batch
999        data_values.extend_from_slice(&[Some(66), None, Some(67), None]);
1000        filter_values.extend_from_slice(&[false, true, true, true]);
1001        let a = Int32Array::from(data_values);
1002        let b = BooleanArray::from(filter_values);
1003        let c = filter(&a, &b).unwrap();
1004        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1005        assert_eq!(67, d.len());
1006        assert_eq!(3, d.null_count());
1007        assert_eq!(1, d.value(0));
1008        assert!(d.is_null(1));
1009        assert_eq!(64, d.value(63));
1010        assert!(d.is_null(64));
1011        assert_eq!(67, d.value(65));
1012    }
1013
1014    #[test]
1015    fn test_filter_string_array_simple() {
1016        let a = StringArray::from(vec!["hello", " ", "world", "!"]);
1017        let b = BooleanArray::from(vec![true, false, true, false]);
1018        let c = filter(&a, &b).unwrap();
1019        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1020        assert_eq!(2, d.len());
1021        assert_eq!("hello", d.value(0));
1022        assert_eq!("world", d.value(1));
1023    }
1024
1025    #[test]
1026    fn test_filter_primitive_array_with_null() {
1027        let a = Int32Array::from(vec![Some(5), None]);
1028        let b = BooleanArray::from(vec![false, true]);
1029        let c = filter(&a, &b).unwrap();
1030        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1031        assert_eq!(1, d.len());
1032        assert!(d.is_null(0));
1033    }
1034
1035    #[test]
1036    fn test_filter_string_array_with_null() {
1037        let a = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
1038        let b = BooleanArray::from(vec![true, false, false, true]);
1039        let c = filter(&a, &b).unwrap();
1040        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1041        assert_eq!(2, d.len());
1042        assert_eq!("hello", d.value(0));
1043        assert!(!d.is_null(0));
1044        assert!(d.is_null(1));
1045    }
1046
1047    #[test]
1048    fn test_filter_binary_array_with_null() {
1049        let data: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
1050        let a = BinaryArray::from(data);
1051        let b = BooleanArray::from(vec![true, false, false, true]);
1052        let c = filter(&a, &b).unwrap();
1053        let d = c.as_ref().as_any().downcast_ref::<BinaryArray>().unwrap();
1054        assert_eq!(2, d.len());
1055        assert_eq!(b"hello", d.value(0));
1056        assert!(!d.is_null(0));
1057        assert!(d.is_null(1));
1058    }
1059
1060    fn _test_filter_byte_view<T>()
1061    where
1062        T: ByteViewType,
1063        str: AsRef<T::Native>,
1064        T::Native: PartialEq,
1065    {
1066        let array = {
1067            // ["hello", "world", null, "large payload over 12 bytes", "lulu"]
1068            let mut builder = GenericByteViewBuilder::<T>::new();
1069            builder.append_value("hello");
1070            builder.append_value("world");
1071            builder.append_null();
1072            builder.append_value("large payload over 12 bytes");
1073            builder.append_value("lulu");
1074            builder.finish()
1075        };
1076
1077        {
1078            let predicate = BooleanArray::from(vec![true, false, true, true, false]);
1079            let actual = filter(&array, &predicate).unwrap();
1080
1081            assert_eq!(actual.len(), 3);
1082
1083            let expected = {
1084                // ["hello", null, "large payload over 12 bytes"]
1085                let mut builder = GenericByteViewBuilder::<T>::new();
1086                builder.append_value("hello");
1087                builder.append_null();
1088                builder.append_value("large payload over 12 bytes");
1089                builder.finish()
1090            };
1091
1092            assert_eq!(actual.as_ref(), &expected);
1093        }
1094
1095        {
1096            let predicate = BooleanArray::from(vec![true, false, false, false, true]);
1097            let actual = filter(&array, &predicate).unwrap();
1098
1099            assert_eq!(actual.len(), 2);
1100
1101            let expected = {
1102                // ["hello", "lulu"]
1103                let mut builder = GenericByteViewBuilder::<T>::new();
1104                builder.append_value("hello");
1105                builder.append_value("lulu");
1106                builder.finish()
1107            };
1108
1109            assert_eq!(actual.as_ref(), &expected);
1110        }
1111    }
1112
1113    #[test]
1114    fn test_filter_string_view() {
1115        _test_filter_byte_view::<StringViewType>()
1116    }
1117
1118    #[test]
1119    fn test_filter_binary_view() {
1120        _test_filter_byte_view::<BinaryViewType>()
1121    }
1122
1123    #[test]
1124    fn test_filter_fixed_binary() {
1125        let v1 = [1_u8, 2];
1126        let v2 = [3_u8, 4];
1127        let v3 = [5_u8, 6];
1128        let v = vec![&v1, &v2, &v3];
1129        let a = FixedSizeBinaryArray::from(v);
1130        let b = BooleanArray::from(vec![true, false, true]);
1131        let c = filter(&a, &b).unwrap();
1132        let d = c
1133            .as_ref()
1134            .as_any()
1135            .downcast_ref::<FixedSizeBinaryArray>()
1136            .unwrap();
1137        assert_eq!(d.len(), 2);
1138        assert_eq!(d.value(0), &v1);
1139        assert_eq!(d.value(1), &v3);
1140        let c2 = FilterBuilder::new(&b)
1141            .optimize()
1142            .build()
1143            .filter(&a)
1144            .unwrap();
1145        let d2 = c2
1146            .as_ref()
1147            .as_any()
1148            .downcast_ref::<FixedSizeBinaryArray>()
1149            .unwrap();
1150        assert_eq!(d, d2);
1151
1152        let b = BooleanArray::from(vec![false, false, false]);
1153        let c = filter(&a, &b).unwrap();
1154        let d = c
1155            .as_ref()
1156            .as_any()
1157            .downcast_ref::<FixedSizeBinaryArray>()
1158            .unwrap();
1159        assert_eq!(d.len(), 0);
1160
1161        let b = BooleanArray::from(vec![true, true, true]);
1162        let c = filter(&a, &b).unwrap();
1163        let d = c
1164            .as_ref()
1165            .as_any()
1166            .downcast_ref::<FixedSizeBinaryArray>()
1167            .unwrap();
1168        assert_eq!(d.len(), 3);
1169        assert_eq!(d.value(0), &v1);
1170        assert_eq!(d.value(1), &v2);
1171        assert_eq!(d.value(2), &v3);
1172
1173        let b = BooleanArray::from(vec![false, false, true]);
1174        let c = filter(&a, &b).unwrap();
1175        let d = c
1176            .as_ref()
1177            .as_any()
1178            .downcast_ref::<FixedSizeBinaryArray>()
1179            .unwrap();
1180        assert_eq!(d.len(), 1);
1181        assert_eq!(d.value(0), &v3);
1182        let c2 = FilterBuilder::new(&b)
1183            .optimize()
1184            .build()
1185            .filter(&a)
1186            .unwrap();
1187        let d2 = c2
1188            .as_ref()
1189            .as_any()
1190            .downcast_ref::<FixedSizeBinaryArray>()
1191            .unwrap();
1192        assert_eq!(d, d2);
1193    }
1194
1195    #[test]
1196    fn test_filter_array_slice_with_null() {
1197        let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
1198        let b = BooleanArray::from(vec![true, false, false, true]);
1199        // filtering with sliced filter array is not currently supported
1200        // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4);
1201        // let b = b_slice.as_any().downcast_ref().unwrap();
1202        let c = filter(&a, &b).unwrap();
1203        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1204        assert_eq!(2, d.len());
1205        assert!(d.is_null(0));
1206        assert!(!d.is_null(1));
1207        assert_eq!(9, d.value(1));
1208    }
1209
1210    #[test]
1211    fn test_filter_run_end_encoding_array() {
1212        let run_ends = Int64Array::from(vec![2, 3, 8]);
1213        let values = Int64Array::from(vec![7, -2, 9]);
1214        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1215        let b = BooleanArray::from(vec![true, false, true, false, true, false, true, false]);
1216        let c = filter(&a, &b).unwrap();
1217        let actual: &RunArray<Int64Type> = as_run_array(&c);
1218        assert_eq!(4, actual.len());
1219
1220        let expected = RunArray::try_new(
1221            &Int64Array::from(vec![1, 2, 4]),
1222            &Int64Array::from(vec![7, -2, 9]),
1223        )
1224        .expect("Failed to make expected RunArray test is broken");
1225
1226        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1227        assert_eq!(actual.values(), expected.values())
1228    }
1229
1230    #[test]
1231    fn test_filter_run_end_encoding_array_remove_value() {
1232        let run_ends = Int32Array::from(vec![2, 3, 8, 10]);
1233        let values = Int32Array::from(vec![7, -2, 9, -8]);
1234        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1235        let b = BooleanArray::from(vec![
1236            false, true, false, false, true, false, true, false, false, false,
1237        ]);
1238        let c = filter(&a, &b).unwrap();
1239        let actual: &RunArray<Int32Type> = as_run_array(&c);
1240        assert_eq!(3, actual.len());
1241
1242        let expected =
1243            RunArray::try_new(&Int32Array::from(vec![1, 3]), &Int32Array::from(vec![7, 9]))
1244                .expect("Failed to make expected RunArray test is broken");
1245
1246        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1247        assert_eq!(actual.values(), expected.values())
1248    }
1249
1250    #[test]
1251    fn test_filter_run_end_encoding_array_remove_all_but_one() {
1252        let run_ends = Int16Array::from(vec![2, 3, 8, 10]);
1253        let values = Int16Array::from(vec![7, -2, 9, -8]);
1254        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1255        let b = BooleanArray::from(vec![
1256            false, false, false, false, false, false, true, false, false, false,
1257        ]);
1258        let c = filter(&a, &b).unwrap();
1259        let actual: &RunArray<Int16Type> = as_run_array(&c);
1260        assert_eq!(1, actual.len());
1261
1262        let expected = RunArray::try_new(&Int16Array::from(vec![1]), &Int16Array::from(vec![9]))
1263            .expect("Failed to make expected RunArray test is broken");
1264
1265        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1266        assert_eq!(actual.values(), expected.values())
1267    }
1268
1269    #[test]
1270    fn test_filter_run_end_encoding_array_empty() {
1271        let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1272        let values = Int64Array::from(vec![7, -2, 9, -8]);
1273        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1274        let b = BooleanArray::from(vec![
1275            false, false, false, false, false, false, false, false, false, false,
1276        ]);
1277        let c = filter(&a, &b).unwrap();
1278        let actual: &RunArray<Int64Type> = as_run_array(&c);
1279        assert_eq!(0, actual.len());
1280    }
1281
1282    #[test]
1283    fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() {
1284        let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1285        let values = Int64Array::from(vec![7, -2, 9, -8]);
1286        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1287        let b = BooleanArray::from(vec![false, true, true]);
1288        let c = filter(&a, &b).unwrap();
1289        let actual: &RunArray<Int64Type> = as_run_array(&c);
1290        assert_eq!(2, actual.len());
1291
1292        let expected = RunArray::try_new(
1293            &Int64Array::from(vec![1, 2]),
1294            &Int64Array::from(vec![7, -2]),
1295        )
1296        .expect("Failed to make expected RunArray test is broken");
1297
1298        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1299        assert_eq!(actual.values(), expected.values())
1300    }
1301
1302    #[test]
1303    fn test_filter_dictionary_array() {
1304        let values = [Some("hello"), None, Some("world"), Some("!")];
1305        let a: Int8DictionaryArray = values.iter().copied().collect();
1306        let b = BooleanArray::from(vec![false, true, true, false]);
1307        let c = filter(&a, &b).unwrap();
1308        let d = c
1309            .as_ref()
1310            .as_any()
1311            .downcast_ref::<Int8DictionaryArray>()
1312            .unwrap();
1313        let value_array = d.values();
1314        let values = value_array.as_any().downcast_ref::<StringArray>().unwrap();
1315        // values are cloned in the filtered dictionary array
1316        assert_eq!(3, values.len());
1317        // but keys are filtered
1318        assert_eq!(2, d.len());
1319        assert!(d.is_null(0));
1320        assert_eq!("world", values.value(d.keys().value(1) as usize));
1321    }
1322
1323    #[test]
1324    fn test_filter_list_array() {
1325        let value_data = ArrayData::builder(DataType::Int32)
1326            .len(8)
1327            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
1328            .build()
1329            .unwrap();
1330
1331        let value_offsets = Buffer::from_slice_ref([0i64, 3, 6, 8, 8]);
1332
1333        let list_data_type =
1334            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
1335        let list_data = ArrayData::builder(list_data_type)
1336            .len(4)
1337            .add_buffer(value_offsets)
1338            .add_child_data(value_data)
1339            .null_bit_buffer(Some(Buffer::from([0b00000111])))
1340            .build()
1341            .unwrap();
1342
1343        //  a = [[0, 1, 2], [3, 4, 5], [6, 7], null]
1344        let a = LargeListArray::from(list_data);
1345        let b = BooleanArray::from(vec![false, true, false, true]);
1346        let result = filter(&a, &b).unwrap();
1347
1348        // expected: [[3, 4, 5], null]
1349        let value_data = ArrayData::builder(DataType::Int32)
1350            .len(3)
1351            .add_buffer(Buffer::from_slice_ref([3, 4, 5]))
1352            .build()
1353            .unwrap();
1354
1355        let value_offsets = Buffer::from_slice_ref([0i64, 3, 3]);
1356
1357        let list_data_type =
1358            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
1359        let expected = ArrayData::builder(list_data_type)
1360            .len(2)
1361            .add_buffer(value_offsets)
1362            .add_child_data(value_data)
1363            .null_bit_buffer(Some(Buffer::from([0b00000001])))
1364            .build()
1365            .unwrap();
1366
1367        assert_eq!(&make_array(expected), &result);
1368    }
1369
1370    #[test]
1371    fn test_slice_iterator_bits() {
1372        let filter_values = (0..64).map(|i| i == 1).collect::<Vec<bool>>();
1373        let filter = BooleanArray::from(filter_values);
1374        let filter_count = filter_count(&filter);
1375
1376        let iter = SlicesIterator::new(&filter);
1377        let chunks = iter.collect::<Vec<_>>();
1378
1379        assert_eq!(chunks, vec![(1, 2)]);
1380        assert_eq!(filter_count, 1);
1381    }
1382
1383    #[test]
1384    fn test_slice_iterator_bits1() {
1385        let filter_values = (0..64).map(|i| i != 1).collect::<Vec<bool>>();
1386        let filter = BooleanArray::from(filter_values);
1387        let filter_count = filter_count(&filter);
1388
1389        let iter = SlicesIterator::new(&filter);
1390        let chunks = iter.collect::<Vec<_>>();
1391
1392        assert_eq!(chunks, vec![(0, 1), (2, 64)]);
1393        assert_eq!(filter_count, 64 - 1);
1394    }
1395
1396    #[test]
1397    fn test_slice_iterator_chunk_and_bits() {
1398        let filter_values = (0..130).map(|i| i % 62 != 0).collect::<Vec<bool>>();
1399        let filter = BooleanArray::from(filter_values);
1400        let filter_count = filter_count(&filter);
1401
1402        let iter = SlicesIterator::new(&filter);
1403        let chunks = iter.collect::<Vec<_>>();
1404
1405        assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]);
1406        assert_eq!(filter_count, 61 + 61 + 5);
1407    }
1408
1409    #[test]
1410    fn test_null_mask() {
1411        let a = Int64Array::from(vec![Some(1), Some(2), None]);
1412
1413        let mask1 = BooleanArray::from(vec![Some(true), Some(true), None]);
1414        let out = filter(&a, &mask1).unwrap();
1415        assert_eq!(out.as_ref(), &a.slice(0, 2));
1416    }
1417
1418    #[test]
1419    fn test_filter_record_batch_no_columns() {
1420        let pred = BooleanArray::from(vec![Some(true), Some(true), None]);
1421        let options = RecordBatchOptions::default().with_row_count(Some(100));
1422        let record_batch =
1423            RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &options).unwrap();
1424        let out = filter_record_batch(&record_batch, &pred).unwrap();
1425
1426        assert_eq!(out.num_rows(), 2);
1427    }
1428
1429    #[test]
1430    fn test_fast_path() {
1431        let a: PrimitiveArray<Int64Type> = PrimitiveArray::from(vec![Some(1), Some(2), None]);
1432
1433        // all true
1434        let mask = BooleanArray::from(vec![true, true, true]);
1435        let out = filter(&a, &mask).unwrap();
1436        let b = out
1437            .as_any()
1438            .downcast_ref::<PrimitiveArray<Int64Type>>()
1439            .unwrap();
1440        assert_eq!(&a, b);
1441
1442        // all false
1443        let mask = BooleanArray::from(vec![false, false, false]);
1444        let out = filter(&a, &mask).unwrap();
1445        assert_eq!(out.len(), 0);
1446        assert_eq!(out.data_type(), &DataType::Int64);
1447    }
1448
1449    #[test]
1450    fn test_slices() {
1451        // takes up 2 u64s
1452        let bools = std::iter::repeat(true)
1453            .take(10)
1454            .chain(std::iter::repeat(false).take(30))
1455            .chain(std::iter::repeat(true).take(20))
1456            .chain(std::iter::repeat(false).take(17))
1457            .chain(std::iter::repeat(true).take(4));
1458
1459        let bool_array: BooleanArray = bools.map(Some).collect();
1460
1461        let slices: Vec<_> = SlicesIterator::new(&bool_array).collect();
1462        let expected = vec![(0, 10), (40, 60), (77, 81)];
1463        assert_eq!(slices, expected);
1464
1465        // slice with offset and truncated len
1466        let len = bool_array.len();
1467        let sliced_array = bool_array.slice(7, len - 10);
1468        let sliced_array = sliced_array
1469            .as_any()
1470            .downcast_ref::<BooleanArray>()
1471            .unwrap();
1472        let slices: Vec<_> = SlicesIterator::new(sliced_array).collect();
1473        let expected = vec![(0, 3), (33, 53), (70, 71)];
1474        assert_eq!(slices, expected);
1475    }
1476
1477    fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) {
1478        let mut rng = rng();
1479
1480        let bools: Vec<bool> = std::iter::from_fn(|| Some(rng.random()))
1481            .take(mask_len)
1482            .collect();
1483
1484        let buffer = Buffer::from_iter(bools.iter().cloned());
1485
1486        let truncated_length = mask_len - offset - truncate;
1487
1488        let data = ArrayDataBuilder::new(DataType::Boolean)
1489            .len(truncated_length)
1490            .offset(offset)
1491            .add_buffer(buffer)
1492            .build()
1493            .unwrap();
1494
1495        let filter = BooleanArray::from(data);
1496
1497        let slice_bits: Vec<_> = SlicesIterator::new(&filter)
1498            .flat_map(|(start, end)| start..end)
1499            .collect();
1500
1501        let count = filter_count(&filter);
1502        let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect();
1503
1504        let expected_bits: Vec<_> = bools
1505            .iter()
1506            .skip(offset)
1507            .take(truncated_length)
1508            .enumerate()
1509            .flat_map(|(idx, v)| v.then(|| idx))
1510            .collect();
1511
1512        assert_eq!(slice_bits, expected_bits);
1513        assert_eq!(index_bits, expected_bits);
1514    }
1515
1516    #[test]
1517    #[cfg_attr(miri, ignore)]
1518    fn fuzz_test_slices_iterator() {
1519        let mut rng = rng();
1520
1521        let uusize = UniformUsize::new(usize::MIN, usize::MAX).unwrap();
1522        for _ in 0..100 {
1523            let mask_len = rng.random_range(0..1024);
1524            let max_offset = 64.min(mask_len);
1525            let offset = uusize.sample(&mut rng).checked_rem(max_offset).unwrap_or(0);
1526
1527            let max_truncate = 128.min(mask_len - offset);
1528            let truncate = uusize
1529                .sample(&mut rng)
1530                .checked_rem(max_truncate)
1531                .unwrap_or(0);
1532
1533            test_slices_fuzz(mask_len, offset, truncate);
1534        }
1535
1536        test_slices_fuzz(64, 0, 0);
1537        test_slices_fuzz(64, 8, 0);
1538        test_slices_fuzz(64, 8, 8);
1539        test_slices_fuzz(32, 8, 8);
1540        test_slices_fuzz(32, 5, 9);
1541    }
1542
1543    /// Filters `values` by `predicate` using standard rust iterators
1544    fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> {
1545        values
1546            .into_iter()
1547            .zip(predicate)
1548            .filter(|(_, x)| **x)
1549            .map(|(a, _)| a)
1550            .collect()
1551    }
1552
1553    /// Generates an array of length `len` with `valid_percent` non-null values
1554    fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>>
1555    where
1556        StandardUniform: Distribution<T>,
1557    {
1558        let mut rng = rng();
1559        (0..len)
1560            .map(|_| rng.random_bool(valid_percent).then(|| rng.random()))
1561            .collect()
1562    }
1563
1564    /// Generates an array of length `len` with `valid_percent` non-null values
1565    fn gen_strings(
1566        len: usize,
1567        valid_percent: f64,
1568        str_len_range: std::ops::Range<usize>,
1569    ) -> Vec<Option<String>> {
1570        let mut rng = rng();
1571        (0..len)
1572            .map(|_| {
1573                rng.random_bool(valid_percent).then(|| {
1574                    let len = rng.random_range(str_len_range.clone());
1575                    (0..len)
1576                        .map(|_| char::from(rng.sample(Alphanumeric)))
1577                        .collect()
1578                })
1579            })
1580            .collect()
1581    }
1582
1583    /// Returns an iterator that calls `Option::as_deref` on each item
1584    fn as_deref<T: std::ops::Deref>(src: &[Option<T>]) -> impl Iterator<Item = Option<&T::Target>> {
1585        src.iter().map(|x| x.as_deref())
1586    }
1587
/// Randomised comparison of [`filter`] against the iterator-based reference
/// `filter_rust`.
///
/// Every round draws a fresh array length, validity ratio and predicate
/// selectivity, slices both the source array and the predicate at a random
/// offset, and then checks an `Int32Array`, a `StringArray` and a string
/// `DictionaryArray<Int32Type>` against the same predicate.
#[test]
#[cfg_attr(miri, ignore)]
fn fuzz_filter() {
    let mut random = rng();

    for round in 0..100 {
        // Rounds 0..=4 keep every row and rounds 5..=10 keep none, so both
        // extremes are always exercised; later rounds use a random selectivity.
        let selectivity = if round <= 4 {
            1.
        } else if round <= 10 {
            0.
        } else {
            random.random_range(0.0..1.0)
        };

        let validity = random.random_range(0.0..1.0);

        let len = random.random_range(32..256);
        let src_offset = random.random_range(0..10);

        // Predicate bits: `mask_offset` extra leading bits (skipped again by the
        // slice below) and `mask_truncate` fewer trailing bits than the array.
        let mask_offset = random.random_range(0..10);
        let mask_truncate = random.random_range(0..10);
        let mask_bits: Vec<bool> = std::iter::repeat_with(|| random.random_bool(selectivity))
            .take(len + mask_offset - mask_truncate)
            .collect();

        let mask = BooleanArray::from_iter(mask_bits.iter().map(|&bit| Some(bit)));
        let mask = mask.slice(mask_offset, len - mask_truncate);
        let predicate = mask.as_any().downcast_ref::<BooleanArray>().unwrap();
        // Reference bits lined up with the sliced predicate.
        let expected_bits = &mask_bits[mask_offset..];

        // --- i32 ---
        let ints = gen_primitive(len + src_offset, validity);
        let int_array = Int32Array::from_iter(ints.iter().cloned()).slice(src_offset, len);
        let int_src = int_array.as_any().downcast_ref::<Int32Array>().unwrap();

        let int_out = filter(int_src, predicate).unwrap();
        let actual_ints: Vec<_> = int_out
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap()
            .iter()
            .collect();

        let expected_ints = filter_rust(ints[src_offset..].iter().cloned(), expected_bits);
        assert_eq!(actual_ints, expected_ints);

        // --- strings ---
        let strings = gen_strings(len + src_offset, validity, 0..20);
        let string_array = StringArray::from_iter(as_deref(&strings)).slice(src_offset, len);
        let string_src = string_array
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        let string_out = filter(string_src, predicate).unwrap();
        let actual_strings: Vec<_> = string_out
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap()
            .iter()
            .collect();

        let expected_strings = filter_rust(as_deref(&strings[src_offset..]), expected_bits);
        assert_eq!(actual_strings, expected_strings);

        // --- string dictionary (same source strings, same expectation) ---
        let dict_array =
            DictionaryArray::<Int32Type>::from_iter(as_deref(&strings)).slice(src_offset, len);
        let dict_src = dict_array
            .as_any()
            .downcast_ref::<DictionaryArray<Int32Type>>()
            .unwrap();

        let dict_out = filter(dict_src, predicate).unwrap();
        let dict = dict_out
            .as_any()
            .downcast_ref::<DictionaryArray<Int32Type>>()
            .unwrap();
        let dict_values = dict
            .values()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        // Resolve every key through the dictionary values before comparing.
        let decoded: Vec<_> = dict
            .keys()
            .iter()
            .map(|key| key.map(|k| dict_values.value(k as usize)))
            .collect();

        assert_eq!(decoded, expected_strings);
    }
}
1678
/// Filtering a map array keeps exactly the selected map rows.
#[test]
fn test_filter_map() {
    // Appends one key/value entry to the map row currently being built.
    fn push_entry(map: &mut MapBuilder<StringBuilder, Int64Builder>, key: &str, value: i64) {
        map.keys().append_value(key);
        map.values().append_value(value);
    }

    // Input rows: [{"key1": 1}, {"key2": 2, "key3": 3}, null, {"key1": 1}]
    let mut input = MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(4));
    push_entry(&mut input, "key1", 1);
    input.append(true).unwrap();
    push_entry(&mut input, "key2", 2);
    push_entry(&mut input, "key3", 3);
    input.append(true).unwrap();
    input.append(false).unwrap();
    push_entry(&mut input, "key1", 1);
    input.append(true).unwrap();
    let maparray = Arc::new(input.finish()) as ArrayRef;

    // Select the first and the last row.
    let indices: BooleanArray = vec![Some(true), Some(false), Some(false), Some(true)]
        .into_iter()
        .collect();
    let got = filter(&maparray, &indices).unwrap();

    // Expected rows: [{"key1": 1}, {"key1": 1}]
    let mut wanted = MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(2));
    for _ in 0..2 {
        push_entry(&mut wanted, "key1", 1);
        wanted.append(true).unwrap();
    }
    let expected = Arc::new(wanted.finish()) as ArrayRef;

    assert_eq!(&expected, &got);
}
1715
/// Filtering a `FixedSizeListArray` without nulls returns the selected lists
/// with their child values intact.
#[test]
fn test_filter_fixed_size_list_arrays() {
    // Three lists of three values each: [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    let child = ArrayData::builder(DataType::Int32)
        .len(9)
        .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8]))
        .build()
        .unwrap();
    let lists = ArrayData::builder(DataType::new_fixed_size_list(DataType::Int32, 3, false))
        .len(3)
        .add_child_data(child)
        .build()
        .unwrap();
    let array = FixedSizeListArray::from(lists);

    // (predicate, lists expected to survive it)
    let cases: Vec<(Vec<bool>, Vec<[i32; 3]>)> = vec![
        (vec![true, false, false], vec![[0, 1, 2]]),
        (vec![true, false, true], vec![[0, 1, 2], [6, 7, 8]]),
    ];

    for (mask, expected_rows) in cases {
        let filter_array = BooleanArray::from(mask);

        let output = filter(&array, &filter_array).unwrap();
        let filtered = output
            .as_any()
            .downcast_ref::<FixedSizeListArray>()
            .unwrap();

        assert_eq!(filtered.len(), expected_rows.len());

        for (row, expected) in expected_rows.iter().enumerate() {
            let list = filtered.value(row);
            assert_eq!(
                expected,
                list.as_any().downcast_ref::<Int32Array>().unwrap().values()
            );
        }
    }
}
1762
/// Filtering a `FixedSizeListArray` that contains null lists keeps both the
/// selected values and the selected null slot.
#[test]
fn test_filter_fixed_size_list_arrays_with_null() {
    let child = ArrayData::builder(DataType::Int32)
        .len(10)
        .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
        .build()
        .unwrap();

    // Validity for the five lists [[0, 1], null, null, [6, 7], [8, 9]]:
    // slots 0, 3 and 4 are valid, i.e. the bitmap byte is 0b0001_1001.
    let mut validity: [u8; 1] = [0; 1];
    for valid_slot in [0, 3, 4] {
        bit_util::set_bit(&mut validity, valid_slot);
    }

    let lists = ArrayData::builder(DataType::new_fixed_size_list(DataType::Int32, 2, false))
        .len(5)
        .add_child_data(child)
        .null_bit_buffer(Some(Buffer::from(validity)))
        .build()
        .unwrap();
    let array = FixedSizeListArray::from(lists);

    // Keep [0, 1], the first null and [6, 7].
    let filter_array = BooleanArray::from(vec![true, true, false, true, false]);

    let output = filter(&array, &filter_array).unwrap();
    let filtered = output
        .as_any()
        .downcast_ref::<FixedSizeListArray>()
        .unwrap();

    assert_eq!(filtered.len(), 3);

    let first = filtered.value(0);
    assert_eq!(
        &[0, 1],
        first.as_any().downcast_ref::<Int32Array>().unwrap().values()
    );

    assert!(filtered.is_null(1));

    let last = filtered.value(2);
    assert_eq!(
        &[6, 7],
        last.as_any().downcast_ref::<Int32Array>().unwrap().values()
    );
}
1807
1808    fn test_filter_union_array(array: UnionArray) {
1809        let filter_array = BooleanArray::from(vec![true, false, false]);
1810        let c = filter(&array, &filter_array).unwrap();
1811        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1812
1813        let mut builder = UnionBuilder::new_dense();
1814        builder.append::<Int32Type>("A", 1).unwrap();
1815        let expected_array = builder.build().unwrap();
1816
1817        compare_union_arrays(filtered, &expected_array);
1818
1819        let filter_array = BooleanArray::from(vec![true, false, true]);
1820        let c = filter(&array, &filter_array).unwrap();
1821        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1822
1823        let mut builder = UnionBuilder::new_dense();
1824        builder.append::<Int32Type>("A", 1).unwrap();
1825        builder.append::<Int32Type>("A", 34).unwrap();
1826        let expected_array = builder.build().unwrap();
1827
1828        compare_union_arrays(filtered, &expected_array);
1829
1830        let filter_array = BooleanArray::from(vec![true, true, false]);
1831        let c = filter(&array, &filter_array).unwrap();
1832        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1833
1834        let mut builder = UnionBuilder::new_dense();
1835        builder.append::<Int32Type>("A", 1).unwrap();
1836        builder.append::<Float64Type>("B", 3.2).unwrap();
1837        let expected_array = builder.build().unwrap();
1838
1839        compare_union_arrays(filtered, &expected_array);
1840    }
1841
/// Runs the shared union assertions against a dense union `[A: 1, B: 3.2, A: 34]`.
#[test]
fn test_filter_union_array_dense() {
    let mut dense = UnionBuilder::new_dense();
    dense.append::<Int32Type>("A", 1).unwrap();
    dense.append::<Float64Type>("B", 3.2).unwrap();
    dense.append::<Int32Type>("A", 34).unwrap();

    test_filter_union_array(dense.build().unwrap());
}
1852
/// Filtering a dense union whose slots all belong to the same field "A" yields
/// array data equal to a freshly built union of the kept values.
#[test]
fn test_filter_run_union_array_dense() {
    // Builds a dense union in which every slot is an Int32 value of field "A".
    fn dense_ints(items: &[i32]) -> UnionArray {
        let mut builder = UnionBuilder::new_dense();
        for &item in items {
            builder.append::<Int32Type>("A", item).unwrap();
        }
        builder.build().unwrap()
    }

    let array = dense_ints(&[1, 3, 34]);

    let filter_array = BooleanArray::from(vec![true, true, false]);
    let output = filter(&array, &filter_array).unwrap();
    let filtered = output.as_any().downcast_ref::<UnionArray>().unwrap();

    let expected = dense_ints(&[1, 3]);

    assert_eq!(filtered.to_data(), expected.to_data());
}
1872
/// Filtering a dense union that contains a null "B" slot, once dropping the
/// null and once keeping it.
#[test]
fn test_filter_union_array_dense_with_nulls() {
    // Input slots: [A: 1, B: 3.2, B: null, A: 34]
    let mut input = UnionBuilder::new_dense();
    input.append::<Int32Type>("A", 1).unwrap();
    input.append::<Float64Type>("B", 3.2).unwrap();
    input.append_null::<Float64Type>("B").unwrap();
    input.append::<Int32Type>("A", 34).unwrap();
    let array = input.build().unwrap();

    let select = |mask: Vec<bool>| filter(&array, &BooleanArray::from(mask)).unwrap();

    // Keep the two leading non-null slots.
    let kept = select(vec![true, true, false, false]);
    let mut expected = UnionBuilder::new_dense();
    expected.append::<Int32Type>("A", 1).unwrap();
    expected.append::<Float64Type>("B", 3.2).unwrap();
    compare_union_arrays(
        kept.as_any().downcast_ref::<UnionArray>().unwrap(),
        &expected.build().unwrap(),
    );

    // Keep the first slot and the null "B" slot.
    let kept = select(vec![true, false, true, false]);
    let mut expected = UnionBuilder::new_dense();
    expected.append::<Int32Type>("A", 1).unwrap();
    expected.append_null::<Float64Type>("B").unwrap();
    compare_union_arrays(
        kept.as_any().downcast_ref::<UnionArray>().unwrap(),
        &expected.build().unwrap(),
    );
}
1904
/// Runs the shared union assertions against a sparse union `[A: 1, B: 3.2, A: 34]`.
#[test]
fn test_filter_union_array_sparse() {
    let mut sparse = UnionBuilder::new_sparse();
    sparse.append::<Int32Type>("A", 1).unwrap();
    sparse.append::<Float64Type>("B", 3.2).unwrap();
    sparse.append::<Int32Type>("A", 34).unwrap();

    test_filter_union_array(sparse.build().unwrap());
}
1915
/// Filtering a sparse union keeps a selected null "B" slot.
#[test]
fn test_filter_union_array_sparse_with_nulls() {
    // Input slots: [A: 1, B: 3.2, B: null, A: 34]
    let mut input = UnionBuilder::new_sparse();
    input.append::<Int32Type>("A", 1).unwrap();
    input.append::<Float64Type>("B", 3.2).unwrap();
    input.append_null::<Float64Type>("B").unwrap();
    input.append::<Int32Type>("A", 34).unwrap();
    let array = input.build().unwrap();

    // Keep the first slot and the null "B" slot.
    let mask = BooleanArray::from(vec![true, false, true, false]);
    let output = filter(&array, &mask).unwrap();
    let filtered = output.as_any().downcast_ref::<UnionArray>().unwrap();

    let mut wanted = UnionBuilder::new_sparse();
    wanted.append::<Int32Type>("A", 1).unwrap();
    wanted.append_null::<Float64Type>("B").unwrap();
    let expected_array = wanted.build().unwrap();

    compare_union_arrays(filtered, &expected_array);
}
1936
1937    fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) {
1938        assert_eq!(union1.len(), union2.len());
1939
1940        for i in 0..union1.len() {
1941            let type_id = union1.type_id(i);
1942
1943            let slot1 = union1.value(i);
1944            let slot2 = union2.value(i);
1945
1946            assert_eq!(slot1.is_null(0), slot2.is_null(0));
1947
1948            if !slot1.is_null(0) && !slot2.is_null(0) {
1949                match type_id {
1950                    0 => {
1951                        let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
1952                        assert_eq!(slot1.len(), 1);
1953                        let value1 = slot1.value(0);
1954
1955                        let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap();
1956                        assert_eq!(slot2.len(), 1);
1957                        let value2 = slot2.value(0);
1958                        assert_eq!(value1, value2);
1959                    }
1960                    1 => {
1961                        let slot1 = slot1.as_any().downcast_ref::<Float64Array>().unwrap();
1962                        assert_eq!(slot1.len(), 1);
1963                        let value1 = slot1.value(0);
1964
1965                        let slot2 = slot2.as_any().downcast_ref::<Float64Array>().unwrap();
1966                        assert_eq!(slot2.len(), 1);
1967                        let value2 = slot2.value(0);
1968                        assert_eq!(value1, value2);
1969                    }
1970                    _ => unreachable!(),
1971                }
1972            }
1973        }
1974    }
1975
/// Filtering a `StructArray` filters every child column and the struct-level
/// null mask alike.
///
/// Covers one- and two-column structs, each without and with a null mask.
#[test]
fn test_filter_struct() {
    let predicate = BooleanArray::from(vec![true, false, true, false]);

    let a: ArrayRef = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
    let a_filtered: ArrayRef = Arc::new(StringArray::from(vec!["hello", "world"]));

    let b: ArrayRef = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
    let b_filtered: ArrayRef = Arc::new(Int32Array::from(vec![5, 7]));

    let null_mask = NullBuffer::from(vec![true, false, false, true]);
    let null_mask_filtered = NullBuffer::from(vec![true, false]);

    let a_field = Field::new("a", DataType::Utf8, false);
    let b_field = Field::new("b", DataType::Int32, false);

    let one_column = Fields::from(vec![a_field.clone()]);
    let two_columns = Fields::from(vec![a_field, b_field]);

    // (fields, input columns, expected columns, whether the struct has a null mask)
    let cases = vec![
        (
            one_column.clone(),
            vec![a.clone()],
            vec![a_filtered.clone()],
            false,
        ),
        (one_column, vec![a.clone()], vec![a_filtered.clone()], true),
        (
            two_columns.clone(),
            vec![a.clone(), b.clone()],
            vec![a_filtered.clone(), b_filtered.clone()],
            false,
        ),
        (
            two_columns,
            vec![a.clone(), b.clone()],
            vec![a_filtered.clone(), b_filtered.clone()],
            true,
        ),
    ];

    for (fields, columns, kept_columns, with_nulls) in cases {
        let array = StructArray::new(
            fields.clone(),
            columns,
            with_nulls.then(|| null_mask.clone()),
        );
        let expected = StructArray::new(
            fields,
            kept_columns,
            with_nulls.then(|| null_mask_filtered.clone()),
        );

        let result = filter(&array, &predicate).unwrap();

        assert_eq!(result.to_data(), expected.to_data());
    }
}
2046
/// Filtering a record batch whose struct column nests a struct without any
/// fields keeps the selected number of rows.
#[test]
fn test_filter_empty_struct() {
    // Column layout:
    //   "a": {
    //       "b": int64,
    //       "c": {}
    //   }
    // Defined once and shared by the schema and the struct array so the two
    // cannot drift apart.
    let inner_fields = Fields::from(vec![
        Field::new("b", DataType::Int64, true),
        Field::new("c", DataType::Struct(Fields::empty()), true),
    ]);
    let schema = Arc::new(Schema::new(vec![Field::new(
        "a",
        DataType::Struct(inner_fields.clone()),
        true,
    )]));

    // Three identical rows {"a": {"c": {}}}: "b" is null in every row, while
    // "c" and "a" are marked valid in every row.
    let b: ArrayRef = Arc::new(Int64Array::from(vec![None, None, None]));
    let c: ArrayRef = Arc::new(StructArray::new_empty_fields(
        3,
        Some(NullBuffer::from(vec![true, true, true])),
    ));
    let a = StructArray::new(
        inner_fields,
        vec![b, c],
        Some(NullBuffer::from(vec![true, true, true])),
    );
    let record_batch = RecordBatch::try_new(schema, vec![Arc::new(a)]).unwrap();

    // Apply the filter
    let predicate = BooleanArray::from(vec![true, false, true]);
    let filtered_batch = filter_record_batch(&record_batch, &predicate).unwrap();

    // The filtered batch should have 2 rows (the 1st and 3rd)
    assert_eq!(filtered_batch.num_rows(), 2);
}
2101}