Skip to main content

arrow_select/coalesce/
primitive.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::coalesce::InProgressArray;
19use crate::filter::{
20    FilterIndices, FilterPredicate, FilterSelection, FilterSlices, filter_null_mask,
21};
22use arrow_array::cast::AsArray;
23use arrow_array::{Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray};
24use arrow_buffer::{BooleanBuffer, NullBuffer, NullBufferBuilder, ScalarBuffer};
25use arrow_schema::{ArrowError, DataType};
26use std::fmt::Debug;
27use std::sync::Arc;
28
29/// InProgressArray for [`PrimitiveArray`]
30#[derive(Debug)]
31pub(crate) struct InProgressPrimitiveArray<T: ArrowPrimitiveType> {
32    /// Data type of the array
33    data_type: DataType,
34    /// The current source, if any
35    source: Option<ArrayRef>,
36    /// the target batch size (and thus size for views allocation)
37    batch_size: usize,
38    /// In progress nulls
39    nulls: NullBufferBuilder,
40    /// The currently in progress array
41    current: Vec<T::Native>,
42}
43
44impl<T: ArrowPrimitiveType> InProgressPrimitiveArray<T> {
45    /// Create a new `InProgressPrimitiveArray`
46    pub(crate) fn new(batch_size: usize, data_type: DataType) -> Self {
47        Self {
48            data_type,
49            batch_size,
50            source: None,
51            nulls: NullBufferBuilder::new(batch_size),
52            current: vec![],
53        }
54    }
55
56    /// Allocate space for output values if necessary.
57    ///
58    /// This is done on write (when we know it is necessary) rather than
59    /// eagerly to avoid allocations that are not used.
60    fn ensure_capacity(&mut self) {
61        if self.current.capacity() == 0 {
62            self.current.reserve(self.batch_size);
63        }
64    }
65
66    fn append_values_by_indices(
67        current: &mut Vec<T::Native>,
68        values: &[T::Native],
69        indices: FilterIndices<'_>,
70        selected_count: usize,
71    ) {
72        let current_len = current.len();
73        let mut written = 0;
74
75        unsafe {
76            let mut out = current
77                .spare_capacity_mut()
78                .as_mut_ptr()
79                .cast::<T::Native>();
80
81            indices.for_each(|idx| {
82                // SAFETY: indices are derived from the filter predicate for this source.
83                out.write(*values.get_unchecked(idx));
84                out = out.add(1);
85                written += 1;
86            });
87
88            current.set_len(current_len + written);
89        }
90
91        debug_assert_eq!(written, selected_count);
92    }
93
94    fn append_values_by_slices(
95        current: &mut Vec<T::Native>,
96        values: &[T::Native],
97        slices: FilterSlices<'_>,
98        selected_count: usize,
99    ) {
100        let current_len = current.len();
101        let mut written = 0;
102
103        unsafe {
104            let mut out = current
105                .spare_capacity_mut()
106                .as_mut_ptr()
107                .cast::<T::Native>();
108
109            slices.for_each(|(start, end)| {
110                let len = end - start;
111                // SAFETY: slices are derived from the filter predicate for this source.
112                std::ptr::copy_nonoverlapping(values.as_ptr().add(start), out, len);
113                out = out.add(len);
114                written += len;
115            });
116
117            current.set_len(current_len + written);
118        }
119
120        debug_assert_eq!(written, selected_count);
121    }
122}
123
124#[inline]
125fn primitive_source<T: ArrowPrimitiveType>(
126    source: &Option<ArrayRef>,
127) -> Result<&PrimitiveArray<T>, ArrowError> {
128    Ok(source
129        .as_ref()
130        .ok_or_else(|| {
131            ArrowError::InvalidArgumentError(
132                "Internal Error: InProgressPrimitiveArray: source not set".to_string(),
133            )
134        })?
135        .as_primitive::<T>())
136}
137
138fn append_filtered_nulls(
139    nulls: &mut NullBufferBuilder,
140    source_nulls: Option<&NullBuffer>,
141    filter: &FilterPredicate,
142) {
143    if let Some((null_count, filtered_nulls)) = filter_null_mask(source_nulls, filter) {
144        let filtered_nulls = unsafe {
145            NullBuffer::new_unchecked(
146                BooleanBuffer::new(filtered_nulls, 0, filter.count()),
147                null_count,
148            )
149        };
150        nulls.append_buffer(&filtered_nulls);
151    } else {
152        nulls.append_n_non_nulls(filter.count());
153    }
154}
155
156impl<T: ArrowPrimitiveType + Debug> InProgressArray for InProgressPrimitiveArray<T> {
157    fn set_source(&mut self, source: Option<ArrayRef>) {
158        self.source = source;
159    }
160
161    fn copy_rows(&mut self, offset: usize, len: usize) -> Result<(), ArrowError> {
162        self.ensure_capacity();
163
164        let s = primitive_source::<T>(&self.source)?;
165
166        // add nulls if necessary
167        if let Some(nulls) = s.nulls().as_ref() {
168            let nulls = nulls.slice(offset, len);
169            self.nulls.append_buffer(&nulls);
170        } else {
171            self.nulls.append_n_non_nulls(len);
172        };
173
174        // Copy the values
175        let values = s.values();
176        // SAFETY: copy_rows is called with ranges derived from the source array.
177        self.current
178            .extend_from_slice(unsafe { values.get_unchecked(offset..offset + len) });
179
180        Ok(())
181    }
182
183    fn copy_rows_by_filter(&mut self, filter: &FilterPredicate) -> Result<(), ArrowError> {
184        match filter.selection() {
185            FilterSelection::Indices(indices) => {
186                self.ensure_capacity();
187                let s = primitive_source::<T>(&self.source)?;
188
189                append_filtered_nulls(&mut self.nulls, s.nulls(), filter);
190                self.current.reserve(filter.count());
191                Self::append_values_by_indices(
192                    &mut self.current,
193                    s.values(),
194                    indices,
195                    filter.count(),
196                );
197                Ok(())
198            }
199            FilterSelection::Slices(slices) => {
200                self.ensure_capacity();
201                let s = primitive_source::<T>(&self.source)?;
202
203                append_filtered_nulls(&mut self.nulls, s.nulls(), filter);
204                self.current.reserve(filter.count());
205                Self::append_values_by_slices(
206                    &mut self.current,
207                    s.values(),
208                    slices,
209                    filter.count(),
210                );
211                Ok(())
212            }
213            // Other selection shapes reuse the generic copy_rows path.
214            selection => self.copy_rows_by_selection(selection),
215        }
216    }
217
218    fn finish(&mut self) -> Result<ArrayRef, ArrowError> {
219        // take and reset the current values and nulls
220        let values = std::mem::take(&mut self.current);
221        let nulls = self.nulls.finish();
222        self.nulls = NullBufferBuilder::new(self.batch_size);
223
224        let array = PrimitiveArray::<T>::try_new(ScalarBuffer::from(values), nulls)?
225            // preserve timezone / precision+scale if applicable
226            .with_data_type(self.data_type.clone());
227        Ok(Arc::new(array))
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::filter::FilterBuilder;
    use arrow_array::types::Int32Type;
    use arrow_array::{BooleanArray, Int32Array};

    #[test]
    fn test_copy_rows_by_filter_index_iterator() {
        let source =
            Int32Array::from_iter((0..21).map(|idx| if idx % 5 == 0 { None } else { Some(idx) }));
        let filter = BooleanArray::from_iter(
            (0..21).map(|idx| Some(matches!(idx, 0 | 1 | 2 | 3 | 5 | 8 | 13))),
        );
        let predicate = FilterBuilder::new(&filter).build();
        let FilterSelection::Indices(indices) = predicate.selection() else {
            panic!("expected index iterator selection");
        };
        let mut selected_indices = Vec::new();
        indices.for_each(|idx| selected_indices.push(idx));
        assert_eq!(selected_indices, vec![0, 1, 2, 3, 5, 8, 13]);

        let mut in_progress = InProgressPrimitiveArray::<Int32Type>::new(7, DataType::Int32);
        in_progress.set_source(Some(Arc::new(source)));
        in_progress.copy_rows_by_filter(&predicate).unwrap();

        let result = in_progress.finish().unwrap();
        let result = result.as_primitive::<Int32Type>();
        let expected = Int32Array::from(vec![
            None,
            Some(1),
            Some(2),
            Some(3),
            None,
            Some(8),
            Some(13),
        ]);
        assert_eq!(result, &expected);
    }

    #[test]
    fn test_copy_rows_by_filter_slice_iterator() {
        let source =
            Int32Array::from_iter((0..16).map(|idx| if idx % 5 == 0 { None } else { Some(idx) }));
        let filter = BooleanArray::from_iter((0..16).map(|idx| Some(!matches!(idx, 3 | 9))));
        let predicate = FilterBuilder::new(&filter).build();
        let FilterSelection::Slices(slices) = predicate.selection() else {
            panic!("expected slice iterator selection");
        };
        let mut selected_slices = Vec::new();
        slices.for_each(|slice| selected_slices.push(slice));
        assert_eq!(selected_slices, vec![(0, 3), (4, 9), (10, 16)]);

        let mut in_progress = InProgressPrimitiveArray::<Int32Type>::new(14, DataType::Int32);
        in_progress.set_source(Some(Arc::new(source)));
        in_progress.copy_rows_by_filter(&predicate).unwrap();

        let result = in_progress.finish().unwrap();
        let result = result.as_primitive::<Int32Type>();
        let expected = Int32Array::from(vec![
            None,
            Some(1),
            Some(2),
            Some(4),
            None,
            Some(6),
            Some(7),
            Some(8),
            None,
            Some(11),
            Some(12),
            Some(13),
            Some(14),
            None,
        ]);
        assert_eq!(result, &expected);
    }

    /// `copy_rows` appends ranges (values and nulls) in call order, and
    /// `finish` resets the builder so the next batch starts empty.
    #[test]
    fn test_copy_rows_across_batches() {
        let source = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]);
        let mut in_progress = InProgressPrimitiveArray::<Int32Type>::new(4, DataType::Int32);
        in_progress.set_source(Some(Arc::new(source)));

        in_progress.copy_rows(1, 2).unwrap();
        in_progress.copy_rows(0, 1).unwrap();
        let result = in_progress.finish().unwrap();
        let expected = Int32Array::from(vec![None, Some(3), Some(1)]);
        assert_eq!(result.as_primitive::<Int32Type>(), &expected);

        // The source is kept across `finish`; only the output is reset.
        in_progress.copy_rows(3, 1).unwrap();
        let result = in_progress.finish().unwrap();
        let expected = Int32Array::from(vec![Some(4)]);
        assert_eq!(result.as_primitive::<Int32Type>(), &expected);
    }

    /// Copying without a source is reported as an error, not a panic.
    #[test]
    fn test_copy_rows_without_source_errors() {
        let mut in_progress = InProgressPrimitiveArray::<Int32Type>::new(4, DataType::Int32);
        assert!(in_progress.copy_rows(0, 0).is_err());
    }
}