arrow_arith/
arithmetic.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Defines basic arithmetic kernels for `PrimitiveArray`s.
//!
//! These kernels can leverage SIMD if it is available on your system. Currently no runtime
//! detection is provided; you should enable the specific SIMD intrinsics explicitly, for
//! example with `RUSTFLAGS="-C target-feature=+avx2"`. See the documentation
//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
24
25use crate::arity::*;
26use arrow_array::types::*;
27use arrow_array::*;
28use arrow_buffer::i256;
29use arrow_buffer::ArrowNativeType;
30use arrow_schema::*;
31use std::cmp::min;
32use std::sync::Arc;
33
34/// Returns the precision and scale of the result of a multiplication of two decimal types,
35/// and the divisor for fixed point multiplication.
36fn get_fixed_point_info(
37    left: (u8, i8),
38    right: (u8, i8),
39    required_scale: i8,
40) -> Result<(u8, i8, i256), ArrowError> {
41    let product_scale = left.1 + right.1;
42    let precision = min(left.0 + right.0 + 1, DECIMAL128_MAX_PRECISION);
43
44    if required_scale > product_scale {
45        return Err(ArrowError::ComputeError(format!(
46            "Required scale {} is greater than product scale {}",
47            required_scale, product_scale
48        )));
49    }
50
51    let divisor = i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32);
52
53    Ok((precision, product_scale, divisor))
54}
55
56/// Perform `left * right` operation on two decimal arrays. If either left or right value is
57/// null then the result is also null.
58///
59/// This performs decimal multiplication which allows precision loss if an exact representation
60/// is not possible for the result, according to the required scale. In the case, the result
61/// will be rounded to the required scale.
62///
63/// If the required scale is greater than the product scale, an error is returned.
64///
65/// This doesn't detect overflow. Once overflowing, the result will wrap around.
66///
67/// It is implemented for compatibility with precision loss `multiply` function provided by
68/// other data processing engines. For multiplication with precision loss detection, use
69/// `multiply_dyn` or `multiply_dyn_checked` instead.
70pub fn multiply_fixed_point_dyn(
71    left: &dyn Array,
72    right: &dyn Array,
73    required_scale: i8,
74) -> Result<ArrayRef, ArrowError> {
75    match (left.data_type(), right.data_type()) {
76        (DataType::Decimal128(_, _), DataType::Decimal128(_, _)) => {
77            let left = left.as_any().downcast_ref::<Decimal128Array>().unwrap();
78            let right = right.as_any().downcast_ref::<Decimal128Array>().unwrap();
79
80            multiply_fixed_point(left, right, required_scale).map(|a| Arc::new(a) as ArrayRef)
81        }
82        (_, _) => Err(ArrowError::CastError(format!(
83            "Unsupported data type {}, {}",
84            left.data_type(),
85            right.data_type()
86        ))),
87    }
88}
89
90/// Perform `left * right` operation on two decimal arrays. If either left or right value is
91/// null then the result is also null.
92///
93/// This performs decimal multiplication which allows precision loss if an exact representation
94/// is not possible for the result, according to the required scale. In the case, the result
95/// will be rounded to the required scale.
96///
97/// If the required scale is greater than the product scale, an error is returned.
98///
99/// It is implemented for compatibility with precision loss `multiply` function provided by
100/// other data processing engines. For multiplication with precision loss detection, use
101/// `multiply` or `multiply_checked` instead.
102pub fn multiply_fixed_point_checked(
103    left: &PrimitiveArray<Decimal128Type>,
104    right: &PrimitiveArray<Decimal128Type>,
105    required_scale: i8,
106) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
107    let (precision, product_scale, divisor) = get_fixed_point_info(
108        (left.precision(), left.scale()),
109        (right.precision(), right.scale()),
110        required_scale,
111    )?;
112
113    if required_scale == product_scale {
114        return try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| a.mul_checked(b))?
115            .with_precision_and_scale(precision, required_scale);
116    }
117
118    try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| {
119        let a = i256::from_i128(a);
120        let b = i256::from_i128(b);
121
122        let mut mul = a.wrapping_mul(b);
123        mul = divide_and_round::<Decimal256Type>(mul, divisor);
124        mul.to_i128().ok_or_else(|| {
125            ArrowError::ArithmeticOverflow(format!("Overflow happened on: {:?} * {:?}", a, b))
126        })
127    })
128    .and_then(|a| a.with_precision_and_scale(precision, required_scale))
129}
130
131/// Perform `left * right` operation on two decimal arrays. If either left or right value is
132/// null then the result is also null.
133///
134/// This performs decimal multiplication which allows precision loss if an exact representation
135/// is not possible for the result, according to the required scale. In the case, the result
136/// will be rounded to the required scale.
137///
138/// If the required scale is greater than the product scale, an error is returned.
139///
140/// This doesn't detect overflow. Once overflowing, the result will wrap around.
141/// For an overflow-checking variant, use `multiply_fixed_point_checked` instead.
142///
143/// It is implemented for compatibility with precision loss `multiply` function provided by
144/// other data processing engines. For multiplication with precision loss detection, use
145/// `multiply` or `multiply_checked` instead.
146pub fn multiply_fixed_point(
147    left: &PrimitiveArray<Decimal128Type>,
148    right: &PrimitiveArray<Decimal128Type>,
149    required_scale: i8,
150) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
151    let (precision, product_scale, divisor) = get_fixed_point_info(
152        (left.precision(), left.scale()),
153        (right.precision(), right.scale()),
154        required_scale,
155    )?;
156
157    if required_scale == product_scale {
158        return binary(left, right, |a, b| a.mul_wrapping(b))?
159            .with_precision_and_scale(precision, required_scale);
160    }
161
162    binary::<_, _, _, Decimal128Type>(left, right, |a, b| {
163        let a = i256::from_i128(a);
164        let b = i256::from_i128(b);
165
166        let mut mul = a.wrapping_mul(b);
167        mul = divide_and_round::<Decimal256Type>(mul, divisor);
168        mul.as_i128()
169    })
170    .and_then(|a| a.with_precision_and_scale(precision, required_scale))
171}
172
173/// Divide a decimal native value by given divisor and round the result.
174fn divide_and_round<I>(input: I::Native, div: I::Native) -> I::Native
175where
176    I: DecimalType,
177    I::Native: ArrowNativeTypeOp,
178{
179    let d = input.div_wrapping(div);
180    let r = input.mod_wrapping(div);
181
182    let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
183    let half_neg = half.neg_wrapping();
184
185    // Round result
186    match input >= I::Native::ZERO {
187        true if r >= half => d.add_wrapping(I::Native::ONE),
188        false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
189        _ => d,
190    }
191}
192
#[cfg(test)]
mod tests {
    // Tests for the fixed-point Decimal128 multiplication kernels, cross-checked against
    // the exact (scale-preserving) `mul` kernel from `crate::numeric`.
    use super::*;
    use crate::numeric::mul;

    // Exercises `multiply_fixed_point_checked`. It covers the case where `mul` overflows but
    // a reduced scale does not, rounding behaviour, the equal-scale fast path, and rejection
    // of a required scale above the product scale.
    #[test]
    fn test_decimal_multiply_allow_precision_loss() {
        // Overflow happening as i128 cannot hold multiplying result.
        // [123456789]
        let a = Decimal128Array::from(vec![123456789000000000000000000])
            .with_precision_and_scale(38, 18)
            .unwrap();

        // [10]
        let b = Decimal128Array::from(vec![10000000000000000000])
            .with_precision_and_scale(38, 18)
            .unwrap();

        let err = mul(&a, &b).unwrap_err();
        assert!(err
            .to_string()
            .contains("Overflow happened on: 123456789000000000000000000 * 10000000000000000000"));

        // Allow precision loss.
        let result = multiply_fixed_point_checked(&a, &b, 28).unwrap();
        // [1234567890]
        let expected = Decimal128Array::from(vec![12345678900000000000000000000000000000])
            .with_precision_and_scale(38, 28)
            .unwrap();

        assert_eq!(&expected, &result);
        assert_eq!(
            result.value_as_string(0),
            "1234567890.0000000000000000000000000000"
        );

        // Rounding case
        // [0.000000000000000001, 123456789.555555555555555555, 1.555555555555555555]
        let a = Decimal128Array::from(vec![1, 123456789555555555555555555, 1555555555555555555])
            .with_precision_and_scale(38, 18)
            .unwrap();

        // [1.555555555555555555, 11.222222222222222222, 0.000000000000000001]
        let b = Decimal128Array::from(vec![1555555555555555555, 11222222222222222222, 1])
            .with_precision_and_scale(38, 18)
            .unwrap();

        // Scale 36 -> 28 divides each raw product by 10^8 and rounds the remainder.
        let result = multiply_fixed_point_checked(&a, &b, 28).unwrap();
        // [
        //    0.0000000000000000015555555556,
        //    1385459527.2345679012071330528765432099,
        //    0.0000000000000000015555555556
        // ]
        let expected = Decimal128Array::from(vec![
            15555555556,
            13854595272345679012071330528765432099,
            15555555556,
        ])
        .with_precision_and_scale(38, 28)
        .unwrap();

        assert_eq!(&expected, &result);

        // Rounded the value "1385459527.234567901207133052876543209876543210".
        assert_eq!(
            result.value_as_string(1),
            "1385459527.2345679012071330528765432099"
        );
        assert_eq!(result.value_as_string(0), "0.0000000000000000015555555556");
        assert_eq!(result.value_as_string(2), "0.0000000000000000015555555556");

        let a = Decimal128Array::from(vec![1230])
            .with_precision_and_scale(4, 2)
            .unwrap();

        let b = Decimal128Array::from(vec![1000])
            .with_precision_and_scale(4, 2)
            .unwrap();

        // Required scale is same as the product of the input scales. Behavior is same as multiply.
        // Precision is 4 + 4 + 1 = 9.
        let result = multiply_fixed_point_checked(&a, &b, 4).unwrap();
        assert_eq!(result.precision(), 9);
        assert_eq!(result.scale(), 4);

        let expected = mul(&a, &b).unwrap();
        assert_eq!(expected.as_ref(), &result);

        // Required scale cannot be larger than the product of the input scales.
        let result = multiply_fixed_point_checked(&a, &b, 5).unwrap_err();
        assert!(result
            .to_string()
            .contains("Required scale 5 is greater than product scale 4"));
    }

    // When even the rescaled product does not fit in i128, the checked kernel errors while
    // the unchecked kernel returns the truncated (wrapped) value.
    #[test]
    fn test_decimal_multiply_allow_precision_loss_overflow() {
        // [99999999999123456789]
        let a = Decimal128Array::from(vec![99999999999123456789000000000000000000])
            .with_precision_and_scale(38, 18)
            .unwrap();

        // [9999999999910]
        let b = Decimal128Array::from(vec![9999999999910000000000000000000])
            .with_precision_and_scale(38, 18)
            .unwrap();

        let err = multiply_fixed_point_checked(&a, &b, 28).unwrap_err();
        assert!(err.to_string().contains(
            "Overflow happened on: 99999999999123456789000000000000000000 * 9999999999910000000000000000000"
        ));

        // Wrapped result of the unchecked variant.
        let result = multiply_fixed_point(&a, &b, 28).unwrap();
        let expected = Decimal128Array::from(vec![62946009661555981610246871926660136960])
            .with_precision_and_scale(38, 28)
            .unwrap();

        assert_eq!(&expected, &result);
    }

    // The unchecked kernel avoids the overflow that `mul` reports by reducing the scale.
    #[test]
    fn test_decimal_multiply_fixed_point() {
        // [123456789]
        let a = Decimal128Array::from(vec![123456789000000000000000000])
            .with_precision_and_scale(38, 18)
            .unwrap();

        // [10]
        let b = Decimal128Array::from(vec![10000000000000000000])
            .with_precision_and_scale(38, 18)
            .unwrap();

        // `multiply` overflows on this case.
        let err = mul(&a, &b).unwrap_err();
        assert_eq!(err.to_string(), "Arithmetic overflow: Overflow happened on: 123456789000000000000000000 * 10000000000000000000");

        // Avoid overflow by reducing the scale.
        let result = multiply_fixed_point(&a, &b, 28).unwrap();
        // [1234567890]
        let expected = Decimal128Array::from(vec![12345678900000000000000000000000000000])
            .with_precision_and_scale(38, 28)
            .unwrap();

        assert_eq!(&expected, &result);
        assert_eq!(
            result.value_as_string(0),
            "1234567890.0000000000000000000000000000"
        );
    }
}