1use crate::arity::*;
26use arrow_array::types::*;
27use arrow_array::*;
28use arrow_buffer::i256;
29use arrow_buffer::ArrowNativeType;
30use arrow_schema::*;
31use std::cmp::min;
32use std::sync::Arc;
33
34fn get_fixed_point_info(
37 left: (u8, i8),
38 right: (u8, i8),
39 required_scale: i8,
40) -> Result<(u8, i8, i256), ArrowError> {
41 let product_scale = left.1 + right.1;
42 let precision = min(left.0 + right.0 + 1, DECIMAL128_MAX_PRECISION);
43
44 if required_scale > product_scale {
45 return Err(ArrowError::ComputeError(format!(
46 "Required scale {} is greater than product scale {}",
47 required_scale, product_scale
48 )));
49 }
50
51 let divisor = i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32);
52
53 Ok((precision, product_scale, divisor))
54}
55
56pub fn multiply_fixed_point_dyn(
71 left: &dyn Array,
72 right: &dyn Array,
73 required_scale: i8,
74) -> Result<ArrayRef, ArrowError> {
75 match (left.data_type(), right.data_type()) {
76 (DataType::Decimal128(_, _), DataType::Decimal128(_, _)) => {
77 let left = left.as_any().downcast_ref::<Decimal128Array>().unwrap();
78 let right = right.as_any().downcast_ref::<Decimal128Array>().unwrap();
79
80 multiply_fixed_point(left, right, required_scale).map(|a| Arc::new(a) as ArrayRef)
81 }
82 (_, _) => Err(ArrowError::CastError(format!(
83 "Unsupported data type {}, {}",
84 left.data_type(),
85 right.data_type()
86 ))),
87 }
88}
89
90pub fn multiply_fixed_point_checked(
103 left: &PrimitiveArray<Decimal128Type>,
104 right: &PrimitiveArray<Decimal128Type>,
105 required_scale: i8,
106) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
107 let (precision, product_scale, divisor) = get_fixed_point_info(
108 (left.precision(), left.scale()),
109 (right.precision(), right.scale()),
110 required_scale,
111 )?;
112
113 if required_scale == product_scale {
114 return try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| a.mul_checked(b))?
115 .with_precision_and_scale(precision, required_scale);
116 }
117
118 try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| {
119 let a = i256::from_i128(a);
120 let b = i256::from_i128(b);
121
122 let mut mul = a.wrapping_mul(b);
123 mul = divide_and_round::<Decimal256Type>(mul, divisor);
124 mul.to_i128().ok_or_else(|| {
125 ArrowError::ArithmeticOverflow(format!("Overflow happened on: {:?} * {:?}", a, b))
126 })
127 })
128 .and_then(|a| a.with_precision_and_scale(precision, required_scale))
129}
130
131pub fn multiply_fixed_point(
147 left: &PrimitiveArray<Decimal128Type>,
148 right: &PrimitiveArray<Decimal128Type>,
149 required_scale: i8,
150) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
151 let (precision, product_scale, divisor) = get_fixed_point_info(
152 (left.precision(), left.scale()),
153 (right.precision(), right.scale()),
154 required_scale,
155 )?;
156
157 if required_scale == product_scale {
158 return binary(left, right, |a, b| a.mul_wrapping(b))?
159 .with_precision_and_scale(precision, required_scale);
160 }
161
162 binary::<_, _, _, Decimal128Type>(left, right, |a, b| {
163 let a = i256::from_i128(a);
164 let b = i256::from_i128(b);
165
166 let mut mul = a.wrapping_mul(b);
167 mul = divide_and_round::<Decimal256Type>(mul, divisor);
168 mul.as_i128()
169 })
170 .and_then(|a| a.with_precision_and_scale(precision, required_scale))
171}
172
173fn divide_and_round<I>(input: I::Native, div: I::Native) -> I::Native
175where
176 I: DecimalType,
177 I::Native: ArrowNativeTypeOp,
178{
179 let d = input.div_wrapping(div);
180 let r = input.mod_wrapping(div);
181
182 let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
183 let half_neg = half.neg_wrapping();
184
185 match input >= I::Native::ZERO {
187 true if r >= half => d.add_wrapping(I::Native::ONE),
188 false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
189 _ => d,
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::numeric::mul;

    #[test]
    fn test_decimal_multiply_allow_precision_loss() {
        // Overflow i128 as the product scale would be 36: 123456789.0 * 10.0
        let a = Decimal128Array::from(vec![123456789000000000000000000])
            .with_precision_and_scale(38, 18)
            .unwrap();

        let b = Decimal128Array::from(vec![10000000000000000000])
            .with_precision_and_scale(38, 18)
            .unwrap();

        let err = mul(&a, &b).unwrap_err();
        assert!(err
            .to_string()
            .contains("Overflow happened on: 123456789000000000000000000 * 10000000000000000000"));

        // Allow precision loss by rounding the product down to scale 28.
        let result = multiply_fixed_point_checked(&a, &b, 28).unwrap();
        let expected = Decimal128Array::from(vec![12345678900000000000000000000000000000])
            .with_precision_and_scale(38, 28)
            .unwrap();

        assert_eq!(&expected, &result);
        assert_eq!(
            result.value_as_string(0),
            "1234567890.0000000000000000000000000000"
        );

        let a = Decimal128Array::from(vec![1, 123456789555555555555555555, 1555555555555555555])
            .with_precision_and_scale(38, 18)
            .unwrap();

        let b = Decimal128Array::from(vec![1555555555555555555, 11222222222222222222, 1])
            .with_precision_and_scale(38, 18)
            .unwrap();

        let result = multiply_fixed_point_checked(&a, &b, 28).unwrap();
        let expected = Decimal128Array::from(vec![
            15555555556,
            13854595272345679012071330528765432099,
            15555555556,
        ])
        .with_precision_and_scale(38, 28)
        .unwrap();

        assert_eq!(&expected, &result);

        assert_eq!(
            result.value_as_string(1),
            "1385459527.2345679012071330528765432099"
        );
        assert_eq!(result.value_as_string(0), "0.0000000000000000015555555556");
        assert_eq!(result.value_as_string(2), "0.0000000000000000015555555556");

        let a = Decimal128Array::from(vec![1230])
            .with_precision_and_scale(4, 2)
            .unwrap();

        let b = Decimal128Array::from(vec![1000])
            .with_precision_and_scale(4, 2)
            .unwrap();

        // Required scale equals the product scale: result must match plain multiplication.
        let result = multiply_fixed_point_checked(&a, &b, 4).unwrap();
        assert_eq!(result.precision(), 9);
        assert_eq!(result.scale(), 4);

        let expected = mul(&a, &b).unwrap();
        assert_eq!(expected.as_ref(), &result);

        // Required scale larger than the product scale is rejected.
        let result = multiply_fixed_point_checked(&a, &b, 5).unwrap_err();
        assert!(result
            .to_string()
            .contains("Required scale 5 is greater than product scale 4"));
    }

    #[test]
    fn test_decimal_multiply_allow_precision_loss_overflow() {
        // [99999999999123456789000000000000000000]
        let a = Decimal128Array::from(vec![99999999999123456789000000000000000000])
            .with_precision_and_scale(38, 18)
            .unwrap();

        let b = Decimal128Array::from(vec![9999999999910000000000000000000])
            .with_precision_and_scale(38, 18)
            .unwrap();

        let err = multiply_fixed_point_checked(&a, &b, 28).unwrap_err();
        assert!(err.to_string().contains(
            "Overflow happened on: 99999999999123456789000000000000000000 * 9999999999910000000000000000000"
        ));

        // The unchecked variant wraps instead of failing.
        let result = multiply_fixed_point(&a, &b, 28).unwrap();
        let expected = Decimal128Array::from(vec![62946009661555981610246871926660136960])
            .with_precision_and_scale(38, 28)
            .unwrap();

        assert_eq!(&expected, &result);
    }

    #[test]
    fn test_decimal_multiply_fixed_point() {
        // Overflow i128 as the product scale would be 36: 123456789.0 * 10.0
        let a = Decimal128Array::from(vec![123456789000000000000000000])
            .with_precision_and_scale(38, 18)
            .unwrap();

        let b = Decimal128Array::from(vec![10000000000000000000])
            .with_precision_and_scale(38, 18)
            .unwrap();

        let err = mul(&a, &b).unwrap_err();
        assert_eq!(err.to_string(), "Arithmetic overflow: Overflow happened on: 123456789000000000000000000 * 10000000000000000000");

        // Allow precision loss by rounding the product down to scale 28.
        let result = multiply_fixed_point(&a, &b, 28).unwrap();
        let expected = Decimal128Array::from(vec![12345678900000000000000000000000000000])
            .with_precision_and_scale(38, 28)
            .unwrap();

        assert_eq!(&expected, &result);
        assert_eq!(
            result.value_as_string(0),
            "1234567890.0000000000000000000000000000"
        );
    }

    #[test]
    fn test_decimal_multiply_fixed_point_rounding_and_nulls() {
        // -12.50 * 1.01 = -12.6250, 12.50 * 1.01 = 12.6250, -12.50 * 1.00 = -12.5000, null
        let a = Decimal128Array::from(vec![Some(-1250), Some(1250), Some(-1250), None])
            .with_precision_and_scale(4, 2)
            .unwrap();
        let b = Decimal128Array::from(vec![Some(101), Some(101), Some(100), Some(101)])
            .with_precision_and_scale(4, 2)
            .unwrap();

        // Ties round away from zero in both directions; nulls propagate.
        let expected = Decimal128Array::from(vec![Some(-1263), Some(1263), Some(-1250), None])
            .with_precision_and_scale(9, 2)
            .unwrap();

        let result = multiply_fixed_point(&a, &b, 2).unwrap();
        assert_eq!(&expected, &result);
        assert_eq!(result.value_as_string(0), "-12.63");

        let result = multiply_fixed_point_checked(&a, &b, 2).unwrap();
        assert_eq!(&expected, &result);

        // Below-half remainders round toward zero: -12.625 -> -12.6, 12.625 -> 12.6
        let expected = Decimal128Array::from(vec![Some(-126), Some(126), Some(-125), None])
            .with_precision_and_scale(9, 1)
            .unwrap();

        let result = multiply_fixed_point_checked(&a, &b, 1).unwrap();
        assert_eq!(&expected, &result);
        assert_eq!(result.value_as_string(0), "-12.6");

        let result = multiply_fixed_point(&a, &b, 1).unwrap();
        assert_eq!(&expected, &result);
    }

    #[test]
    fn test_multiply_fixed_point_dyn() {
        let a = Decimal128Array::from(vec![1230])
            .with_precision_and_scale(4, 2)
            .unwrap();
        let b = Decimal128Array::from(vec![1000])
            .with_precision_and_scale(4, 2)
            .unwrap();

        let result = multiply_fixed_point_dyn(&a, &b, 4).unwrap();
        let expected = multiply_fixed_point(&a, &b, 4).unwrap();
        assert_eq!(result.as_ref(), &expected);

        // Non-decimal inputs are rejected rather than silently cast.
        let ints = Int32Array::from(vec![1]);
        let err = multiply_fixed_point_dyn(&ints, &b, 4).unwrap_err();
        assert!(err.to_string().contains("Unsupported data type"));
    }
}