arrow_arith/
arity.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Kernels for operating on [`PrimitiveArray`]s
19
20use arrow_array::builder::BufferBuilder;
21use arrow_array::*;
22use arrow_buffer::ArrowNativeType;
23use arrow_buffer::MutableBuffer;
24use arrow_buffer::buffer::NullBuffer;
25use arrow_data::ArrayData;
26use arrow_schema::ArrowError;
27
28/// See [`PrimitiveArray::unary`]
29pub fn unary<I, F, O>(array: &PrimitiveArray<I>, op: F) -> PrimitiveArray<O>
30where
31    I: ArrowPrimitiveType,
32    O: ArrowPrimitiveType,
33    F: Fn(I::Native) -> O::Native,
34{
35    array.unary(op)
36}
37
38/// See [`PrimitiveArray::unary_mut`]
39pub fn unary_mut<I, F>(
40    array: PrimitiveArray<I>,
41    op: F,
42) -> Result<PrimitiveArray<I>, PrimitiveArray<I>>
43where
44    I: ArrowPrimitiveType,
45    F: Fn(I::Native) -> I::Native,
46{
47    array.unary_mut(op)
48}
49
50/// See [`PrimitiveArray::try_unary`]
51pub fn try_unary<I, F, O>(array: &PrimitiveArray<I>, op: F) -> Result<PrimitiveArray<O>, ArrowError>
52where
53    I: ArrowPrimitiveType,
54    O: ArrowPrimitiveType,
55    F: Fn(I::Native) -> Result<O::Native, ArrowError>,
56{
57    array.try_unary(op)
58}
59
60/// See [`PrimitiveArray::try_unary_mut`]
61pub fn try_unary_mut<I, F>(
62    array: PrimitiveArray<I>,
63    op: F,
64) -> Result<Result<PrimitiveArray<I>, ArrowError>, PrimitiveArray<I>>
65where
66    I: ArrowPrimitiveType,
67    F: Fn(I::Native) -> Result<I::Native, ArrowError>,
68{
69    array.try_unary_mut(op)
70}
71
72/// Allies a binary infallable function to two [`PrimitiveArray`]s,
73/// producing a new [`PrimitiveArray`]
74///
75/// # Details
76///
77/// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in `0..len`, collecting
78/// the results in a [`PrimitiveArray`].
79///
80/// If any index is null in either `a` or `b`, the
81/// corresponding index in the result will also be null
82///
83/// Like [`unary`], the `op` is evaluated for every element in the two arrays,
84/// including those elements which are NULL. This is beneficial as the cost of
85/// the operation is low compared to the cost of branching, and especially when
86/// the operation can be vectorised, however, requires `op` to be infallible for
87/// all possible values of its inputs
88///
89/// # Errors
90///
91/// * if the arrays have different lengths.
92///
93/// # Example
94/// ```
95/// # use arrow_arith::arity::binary;
96/// # use arrow_array::{Float32Array, Int32Array};
97/// # use arrow_array::types::Int32Type;
98/// let a = Float32Array::from(vec![Some(5.1f32), None, Some(6.8), Some(7.2)]);
99/// let b = Int32Array::from(vec![1, 2, 4, 9]);
100/// // compute int(a) + b for each element
101/// let c = binary(&a, &b, |a, b| a as i32 + b).unwrap();
102/// assert_eq!(c, Int32Array::from(vec![Some(6), None, Some(10), Some(16)]));
103/// ```
104pub fn binary<A, B, F, O>(
105    a: &PrimitiveArray<A>,
106    b: &PrimitiveArray<B>,
107    op: F,
108) -> Result<PrimitiveArray<O>, ArrowError>
109where
110    A: ArrowPrimitiveType,
111    B: ArrowPrimitiveType,
112    O: ArrowPrimitiveType,
113    F: Fn(A::Native, B::Native) -> O::Native,
114{
115    if a.len() != b.len() {
116        return Err(ArrowError::ComputeError(
117            "Cannot perform binary operation on arrays of different length".to_string(),
118        ));
119    }
120
121    if a.is_empty() {
122        return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE)));
123    }
124
125    let nulls = NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref());
126
127    let values = a
128        .values()
129        .into_iter()
130        .zip(b.values())
131        .map(|(l, r)| op(*l, *r));
132
133    let buffer: Vec<_> = values.collect();
134    Ok(PrimitiveArray::new(buffer.into(), nulls))
135}
136
137/// Applies a binary and infallible function to values in two arrays, replacing
138/// the values in the first array in place.
139///
140/// # Details
141///
142/// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in
143/// `0..len`, modifying the [`PrimitiveArray`] `a` in place, if possible.
144///
145/// If any index is null in either `a` or `b`, the corresponding index in the
146/// result will also be null.
147///
148/// # Buffer Reuse
149///
150/// If the underlying buffers in `a` are not shared with other arrays,  mutates
151/// the underlying buffer in place, without allocating.
152///
153/// If the underlying buffer in `a` are shared, returns Err(self)
154///
155/// Like [`unary`] the provided function is evaluated for every index, ignoring validity. This
156/// is beneficial when the cost of the operation is low compared to the cost of branching, and
157/// especially when the operation can be vectorised, however, requires `op` to be infallible
158/// for all possible values of its inputs
159///
160/// # Errors
161///
162/// * If the arrays have different lengths
163/// * If the array is not mutable (see "Buffer Reuse")
164///
165/// # See Also
166///
167/// * Documentation on [`PrimitiveArray::unary_mut`] for operating on [`ArrayRef`].
168///
169/// # Example
170/// ```
171/// # use arrow_arith::arity::binary_mut;
172/// # use arrow_array::{Float32Array, Int32Array};
173/// # use arrow_array::types::Int32Type;
174/// // compute a + b for each element
175/// let a = Float32Array::from(vec![Some(5.1f32), None, Some(6.8)]);
176/// let b = Int32Array::from(vec![Some(1), None, Some(2)]);
177/// // compute a + b, updating the value in a in place if possible
178/// let a = binary_mut(a, &b, |a, b| a + b as f32).unwrap().unwrap();
179/// // a is updated in place
180/// assert_eq!(a, Float32Array::from(vec![Some(6.1), None, Some(8.8)]));
181/// ```
182///
183/// # Example with shared buffers
184/// ```
185/// # use arrow_arith::arity::binary_mut;
186/// # use arrow_array::Float32Array;
187/// # use arrow_array::types::Int32Type;
188/// let a = Float32Array::from(vec![Some(5.1f32), None, Some(6.8)]);
189/// let b = Float32Array::from(vec![Some(1.0f32), None, Some(2.0)]);
190/// // a_clone shares the buffer with a
191/// let a_cloned = a.clone();
192/// // try to update a in place, but it is shared. Returns Err(a)
193/// let a = binary_mut(a, &b, |a, b| a + b).unwrap_err();
194/// assert_eq!(a_cloned, a);
195/// // drop shared reference
196/// drop(a_cloned);
197/// // now a is not shared, so we can update it in place
198/// let a = binary_mut(a, &b, |a, b| a + b).unwrap().unwrap();
199/// assert_eq!(a, Float32Array::from(vec![Some(6.1), None, Some(8.8)]));
200/// ```
201pub fn binary_mut<T, U, F>(
202    a: PrimitiveArray<T>,
203    b: &PrimitiveArray<U>,
204    op: F,
205) -> Result<Result<PrimitiveArray<T>, ArrowError>, PrimitiveArray<T>>
206where
207    T: ArrowPrimitiveType,
208    U: ArrowPrimitiveType,
209    F: Fn(T::Native, U::Native) -> T::Native,
210{
211    if a.len() != b.len() {
212        return Ok(Err(ArrowError::ComputeError(
213            "Cannot perform binary operation on arrays of different length".to_string(),
214        )));
215    }
216
217    if a.is_empty() {
218        return Ok(Ok(PrimitiveArray::from(ArrayData::new_empty(
219            &T::DATA_TYPE,
220        ))));
221    }
222
223    let mut builder = a.into_builder()?;
224
225    builder
226        .values_slice_mut()
227        .iter_mut()
228        .zip(b.values())
229        .for_each(|(l, r)| *l = op(*l, *r));
230
231    let array = builder.finish();
232
233    // The builder has the null buffer from `a`, it is not changed.
234    let nulls = NullBuffer::union(array.logical_nulls().as_ref(), b.logical_nulls().as_ref());
235
236    let array_builder = array.into_data().into_builder().nulls(nulls);
237
238    let array_data = unsafe { array_builder.build_unchecked() };
239    Ok(Ok(PrimitiveArray::<T>::from(array_data)))
240}
241
242/// Applies the provided fallible binary operation across `a` and `b`.
243///
244/// This will return any error encountered, or collect the results into
245/// a [`PrimitiveArray`]. If any index is null in either `a`
246/// or `b`, the corresponding index in the result will also be null
247///
248/// Like [`try_unary`] the function is only evaluated for non-null indices
249///
250/// # Error
251///
252/// Return an error if the arrays have different lengths or
253/// the operation is under erroneous
254pub fn try_binary<A: ArrayAccessor, B: ArrayAccessor, F, O>(
255    a: A,
256    b: B,
257    op: F,
258) -> Result<PrimitiveArray<O>, ArrowError>
259where
260    O: ArrowPrimitiveType,
261    F: Fn(A::Item, B::Item) -> Result<O::Native, ArrowError>,
262{
263    if a.len() != b.len() {
264        return Err(ArrowError::ComputeError(
265            "Cannot perform a binary operation on arrays of different length".to_string(),
266        ));
267    }
268    if a.is_empty() {
269        return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE)));
270    }
271    let len = a.len();
272
273    if a.null_count() == 0 && b.null_count() == 0 {
274        try_binary_no_nulls(len, a, b, op)
275    } else {
276        let nulls =
277            NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref()).unwrap();
278
279        let mut buffer = BufferBuilder::<O::Native>::new(len);
280        buffer.append_n_zeroed(len);
281        let slice = buffer.as_slice_mut();
282
283        nulls.try_for_each_valid_idx(|idx| {
284            unsafe {
285                *slice.get_unchecked_mut(idx) = op(a.value_unchecked(idx), b.value_unchecked(idx))?
286            };
287            Ok::<_, ArrowError>(())
288        })?;
289
290        let values = buffer.finish().into();
291        Ok(PrimitiveArray::new(values, Some(nulls)))
292    }
293}
294
295/// Applies the provided fallible binary operation across `a` and `b` by mutating the mutable
296/// [`PrimitiveArray`] `a` with the results.
297///
298/// Returns any error encountered, or collects the results into a [`PrimitiveArray`] as return
299/// value. If any index is null in either `a` or `b`, the corresponding index in the result will
300/// also be null.
301///
302/// Like [`try_unary`] the function is only evaluated for non-null indices.
303///
304/// See [`binary_mut`] for errors and buffer reuse information.
305pub fn try_binary_mut<T, F>(
306    a: PrimitiveArray<T>,
307    b: &PrimitiveArray<T>,
308    op: F,
309) -> Result<Result<PrimitiveArray<T>, ArrowError>, PrimitiveArray<T>>
310where
311    T: ArrowPrimitiveType,
312    F: Fn(T::Native, T::Native) -> Result<T::Native, ArrowError>,
313{
314    if a.len() != b.len() {
315        return Ok(Err(ArrowError::ComputeError(
316            "Cannot perform binary operation on arrays of different length".to_string(),
317        )));
318    }
319    let len = a.len();
320
321    if a.is_empty() {
322        return Ok(Ok(PrimitiveArray::from(ArrayData::new_empty(
323            &T::DATA_TYPE,
324        ))));
325    }
326
327    if a.null_count() == 0 && b.null_count() == 0 {
328        try_binary_no_nulls_mut(len, a, b, op)
329    } else {
330        let nulls =
331            create_union_null_buffer(a.logical_nulls().as_ref(), b.logical_nulls().as_ref())
332                .unwrap();
333
334        let mut builder = a.into_builder()?;
335
336        let slice = builder.values_slice_mut();
337
338        let r = nulls.try_for_each_valid_idx(|idx| {
339            unsafe {
340                *slice.get_unchecked_mut(idx) =
341                    op(*slice.get_unchecked(idx), b.value_unchecked(idx))?
342            };
343            Ok::<_, ArrowError>(())
344        });
345        if let Err(err) = r {
346            return Ok(Err(err));
347        }
348        let array_builder = builder.finish().into_data().into_builder();
349        let array_data = unsafe { array_builder.nulls(Some(nulls)).build_unchecked() };
350        Ok(Ok(PrimitiveArray::<T>::from(array_data)))
351    }
352}
353
354/// Computes the union of the nulls in two optional [`NullBuffer`] which
355/// is not shared with the input buffers.
356///
357/// The union of the nulls is the same as `NullBuffer::union(lhs, rhs)` but
358/// it does not increase the reference count of the null buffer.
359fn create_union_null_buffer(
360    lhs: Option<&NullBuffer>,
361    rhs: Option<&NullBuffer>,
362) -> Option<NullBuffer> {
363    match (lhs, rhs) {
364        (Some(lhs), Some(rhs)) => Some(NullBuffer::new(lhs.inner() & rhs.inner())),
365        (Some(n), None) | (None, Some(n)) => Some(NullBuffer::new(n.inner() & n.inner())),
366        (None, None) => None,
367    }
368}
369
370/// This intentional inline(never) attribute helps LLVM optimize the loop.
371#[inline(never)]
372fn try_binary_no_nulls<A: ArrayAccessor, B: ArrayAccessor, F, O>(
373    len: usize,
374    a: A,
375    b: B,
376    op: F,
377) -> Result<PrimitiveArray<O>, ArrowError>
378where
379    O: ArrowPrimitiveType,
380    F: Fn(A::Item, B::Item) -> Result<O::Native, ArrowError>,
381{
382    let mut buffer = MutableBuffer::new(len * O::Native::get_byte_width());
383    for idx in 0..len {
384        unsafe {
385            buffer.push_unchecked(op(a.value_unchecked(idx), b.value_unchecked(idx))?);
386        };
387    }
388    Ok(PrimitiveArray::new(buffer.into(), None))
389}
390
391/// This intentional inline(never) attribute helps LLVM optimize the loop.
392#[inline(never)]
393fn try_binary_no_nulls_mut<T, F>(
394    len: usize,
395    a: PrimitiveArray<T>,
396    b: &PrimitiveArray<T>,
397    op: F,
398) -> Result<Result<PrimitiveArray<T>, ArrowError>, PrimitiveArray<T>>
399where
400    T: ArrowPrimitiveType,
401    F: Fn(T::Native, T::Native) -> Result<T::Native, ArrowError>,
402{
403    let mut builder = a.into_builder()?;
404    let slice = builder.values_slice_mut();
405
406    for idx in 0..len {
407        unsafe {
408            match op(*slice.get_unchecked(idx), b.value_unchecked(idx)) {
409                Ok(value) => *slice.get_unchecked_mut(idx) = value,
410                Err(err) => return Ok(Err(err)),
411            };
412        };
413    }
414    Ok(Ok(builder.finish()))
415}
416
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::types::*;
    use std::sync::Arc;

    #[test]
    #[allow(deprecated)]
    fn test_unary_f64_slice() {
        let input = Float64Array::from(vec![Some(5.1f64), None, Some(6.8), None, Some(7.2)]);
        let input_slice = input.slice(1, 4);
        let result = unary(&input_slice, |n| n.round());
        assert_eq!(
            result,
            Float64Array::from(vec![None, Some(7.0), None, Some(7.0)])
        );
    }

    #[test]
    fn test_binary_with_nulls() {
        let a = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]);
        let b = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]);
        let c = binary::<_, _, _, Int32Type>(&a, &b, |l, r| l + r).unwrap();
        assert_eq!(c, Int32Array::from(vec![Some(11), None, None, Some(44)]));
    }

    #[test]
    fn test_binary_length_mismatch() {
        let a = Int32Array::from(vec![1, 2, 3]);
        let b = Int32Array::from(vec![1, 2]);
        let err = binary::<_, _, _, Int32Type>(&a, &b, |l, r| l + r).unwrap_err();
        assert!(err.to_string().contains("different length"));
    }

    #[test]
    fn test_binary_mut() {
        let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
        let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
        let c = binary_mut(a, &b, |l, r| l + r).unwrap().unwrap();

        let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]);
        assert_eq!(c, expected);
    }

    #[test]
    fn test_binary_mut_length_mismatch() {
        let a = Int32Array::from(vec![1, 2, 3]);
        let b = Int32Array::from(vec![1, 2]);
        // The length error is reported in the inner `Result`, not as `Err(a)`.
        let err = binary_mut(a, &b, |l, r| l + r).unwrap().unwrap_err();
        assert!(err.to_string().contains("different length"));
    }

    #[test]
    fn test_binary_mut_null_buffer() {
        let a = Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]);

        let b = Int32Array::from(vec![Some(10), Some(11), Some(12), Some(13), Some(14)]);

        let r1 = binary_mut(a, &b, |a, b| a + b).unwrap();

        let a = Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]);
        let b = Int32Array::new(
            vec![10, 11, 12, 13, 14].into(),
            Some(vec![true, true, true, true, true].into()),
        );

        // unwrap here means that no copying occured
        let r2 = binary_mut(a, &b, |a, b| a + b).unwrap();
        assert_eq!(r1.unwrap(), r2.unwrap());
    }

    #[test]
    fn test_try_binary() {
        // nullable path
        let a = Int32Array::from(vec![Some(1), None, Some(3)]);
        let b = Int32Array::from(vec![Some(10), Some(20), None]);
        let c = try_binary::<_, _, _, Int32Type>(&a, &b, |l, r| Ok(l + r)).unwrap();
        assert_eq!(c, Int32Array::from(vec![Some(11), None, None]));

        // no-nulls path
        let a = Int32Array::from(vec![1, 2, 3]);
        let b = Int32Array::from(vec![10, 20, 30]);
        let c = try_binary::<_, _, _, Int32Type>(&a, &b, |l, r| Ok(l + r)).unwrap();
        assert_eq!(c, Int32Array::from(vec![11, 22, 33]));
    }

    #[test]
    fn test_try_binary_error() {
        let checked_div = |l: i32, r: i32| {
            if r == 0 {
                Err(ArrowError::DivideByZero)
            } else {
                Ok(l / r)
            }
        };

        // an error at a valid index is propagated
        let a = Int32Array::from(vec![Some(1), None, Some(3)]);
        let b = Int32Array::from(vec![1, 0, 0]);
        let err = try_binary::<_, _, _, Int32Type>(&a, &b, checked_div).unwrap_err();
        assert!(matches!(err, ArrowError::DivideByZero));

        // `op` is not evaluated at null slots, so a zero divisor there is ignored
        let a = Int32Array::from(vec![Some(4), None]);
        let b = Int32Array::from(vec![2, 0]);
        let c = try_binary::<_, _, _, Int32Type>(&a, &b, checked_div).unwrap();
        assert_eq!(c, Int32Array::from(vec![Some(2), None]));

        // an error on the no-nulls path is propagated as well
        let a = Int32Array::from(vec![4, 5]);
        let b = Int32Array::from(vec![2, 0]);
        let err = try_binary::<_, _, _, Int32Type>(&a, &b, checked_div).unwrap_err();
        assert!(matches!(err, ArrowError::DivideByZero));
    }

    #[test]
    fn test_try_binary_length_mismatch() {
        let a = Int32Array::from(vec![1, 2, 3]);
        let b = Int32Array::from(vec![1, 2]);
        let err = try_binary::<_, _, _, Int32Type>(&a, &b, |l, r| Ok(l + r)).unwrap_err();
        assert!(err.to_string().contains("different length"));
    }

    #[test]
    fn test_try_binary_mut() {
        let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
        let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
        let c = try_binary_mut(a, &b, |l, r| Ok(l + r)).unwrap().unwrap();

        let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]);
        assert_eq!(c, expected);

        let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
        let b = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let c = try_binary_mut(a, &b, |l, r| Ok(l + r)).unwrap().unwrap();
        let expected = Int32Array::from(vec![16, 16, 12, 12, 6]);
        assert_eq!(c, expected);

        let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
        let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
        let _ = try_binary_mut(a, &b, |l, r| {
            if l == 1 {
                Err(ArrowError::InvalidArgumentError(
                    "got error".parse().unwrap(),
                ))
            } else {
                Ok(l + r)
            }
        })
        .unwrap()
        .expect_err("should got error");
    }

    #[test]
    fn test_try_binary_mut_null_buffer() {
        let a = Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]);

        let b = Int32Array::from(vec![Some(10), Some(11), Some(12), Some(13), Some(14)]);

        let r1 = try_binary_mut(a, &b, |a, b| Ok(a + b)).unwrap();

        let a = Int32Array::from(vec![Some(3), Some(4), Some(5), Some(6), None]);
        let b = Int32Array::new(
            vec![10, 11, 12, 13, 14].into(),
            Some(vec![true, true, true, true, true].into()),
        );

        // unwrap here means that no copying occured
        let r2 = try_binary_mut(a, &b, |a, b| Ok(a + b)).unwrap();
        assert_eq!(r1.unwrap(), r2.unwrap());
    }

    #[test]
    fn test_unary_dict_mut() {
        let values = Int32Array::from(vec![Some(10), Some(20), None]);
        let keys = Int8Array::from_iter_values([0, 0, 1, 2]);
        let dictionary = DictionaryArray::new(keys, Arc::new(values));

        let updated = dictionary.unary_mut::<_, Int32Type>(|x| x + 1).unwrap();
        let typed = updated.downcast_dict::<Int32Array>().unwrap();
        assert_eq!(typed.value(0), 11);
        assert_eq!(typed.value(1), 11);
        assert_eq!(typed.value(2), 21);

        let values = updated.values();
        assert!(values.is_null(2));
    }
}