1use crate::arity::*;
26use arrow_array::types::*;
27use arrow_array::*;
28use arrow_buffer::ArrowNativeType;
29use arrow_buffer::i256;
30use arrow_schema::*;
31use std::cmp::min;
32use std::sync::Arc;
33
34fn get_fixed_point_info(
37 left: (u8, i8),
38 right: (u8, i8),
39 required_scale: i8,
40) -> Result<(u8, i8, i256), ArrowError> {
41 let product_scale = left.1 + right.1;
42 let precision = min(left.0 + right.0 + 1, DECIMAL128_MAX_PRECISION);
43
44 if required_scale > product_scale {
45 return Err(ArrowError::ComputeError(format!(
46 "Required scale {required_scale} is greater than product scale {product_scale}",
47 )));
48 }
49
50 let divisor = i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32);
51
52 Ok((precision, product_scale, divisor))
53}
54
55pub fn multiply_fixed_point_dyn(
70 left: &dyn Array,
71 right: &dyn Array,
72 required_scale: i8,
73) -> Result<ArrayRef, ArrowError> {
74 match (left.data_type(), right.data_type()) {
75 (DataType::Decimal128(_, _), DataType::Decimal128(_, _)) => {
76 let left = left.as_any().downcast_ref::<Decimal128Array>().unwrap();
77 let right = right.as_any().downcast_ref::<Decimal128Array>().unwrap();
78
79 multiply_fixed_point(left, right, required_scale).map(|a| Arc::new(a) as ArrayRef)
80 }
81 (_, _) => Err(ArrowError::CastError(format!(
82 "Unsupported data type {}, {}",
83 left.data_type(),
84 right.data_type()
85 ))),
86 }
87}
88
89pub fn multiply_fixed_point_checked(
102 left: &PrimitiveArray<Decimal128Type>,
103 right: &PrimitiveArray<Decimal128Type>,
104 required_scale: i8,
105) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
106 let (precision, product_scale, divisor) = get_fixed_point_info(
107 (left.precision(), left.scale()),
108 (right.precision(), right.scale()),
109 required_scale,
110 )?;
111
112 if required_scale == product_scale {
113 return try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| a.mul_checked(b))?
114 .with_precision_and_scale(precision, required_scale);
115 }
116
117 try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| {
118 let a = i256::from_i128(a);
119 let b = i256::from_i128(b);
120
121 let mut mul = a.wrapping_mul(b);
122 mul = divide_and_round::<Decimal256Type>(mul, divisor);
123 mul.to_i128().ok_or_else(|| {
124 ArrowError::ArithmeticOverflow(format!("Overflow happened on: {a:?} * {b:?}"))
125 })
126 })
127 .and_then(|a| a.with_precision_and_scale(precision, required_scale))
128}
129
130pub fn multiply_fixed_point(
146 left: &PrimitiveArray<Decimal128Type>,
147 right: &PrimitiveArray<Decimal128Type>,
148 required_scale: i8,
149) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
150 let (precision, product_scale, divisor) = get_fixed_point_info(
151 (left.precision(), left.scale()),
152 (right.precision(), right.scale()),
153 required_scale,
154 )?;
155
156 if required_scale == product_scale {
157 return binary(left, right, |a, b| a.mul_wrapping(b))?
158 .with_precision_and_scale(precision, required_scale);
159 }
160
161 binary::<_, _, _, Decimal128Type>(left, right, |a, b| {
162 let a = i256::from_i128(a);
163 let b = i256::from_i128(b);
164
165 let mut mul = a.wrapping_mul(b);
166 mul = divide_and_round::<Decimal256Type>(mul, divisor);
167 mul.as_i128()
168 })
169 .and_then(|a| a.with_precision_and_scale(precision, required_scale))
170}
171
172fn divide_and_round<I: DecimalType>(input: I::Native, div: I::Native) -> I::Native {
174 let d = input.div_wrapping(div);
175 let r = input.mod_wrapping(div);
176
177 let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
178 let half_neg = half.neg_wrapping();
179
180 match input >= I::Native::ZERO {
182 true if r >= half => d.add_wrapping(I::Native::ONE),
183 false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
184 _ => d,
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::numeric::mul;

    /// Builds a `Decimal128Array` from raw unscaled values with the given
    /// precision and scale.
    fn decimal128(values: Vec<i128>, precision: u8, scale: i8) -> Decimal128Array {
        Decimal128Array::from(values)
            .with_precision_and_scale(precision, scale)
            .unwrap()
    }

    #[test]
    fn test_decimal_multiply_allow_precision_loss() {
        let lhs = decimal128(vec![123456789000000000000000000], 38, 18);
        let rhs = decimal128(vec![10000000000000000000], 38, 18);

        // The plain `mul` kernel reports an overflow for these operands.
        let mul_err = mul(&lhs, &rhs).unwrap_err();
        assert!(
            mul_err.to_string().contains(
                "Overflow happened on: 123456789000000000000000000 * 10000000000000000000"
            )
        );

        // Reducing the scale to 28 makes the product representable.
        let product = multiply_fixed_point_checked(&lhs, &rhs, 28).unwrap();
        let want = decimal128(vec![12345678900000000000000000000000000000], 38, 28);

        assert_eq!(&want, &product);
        assert_eq!(
            product.value_as_string(0),
            "1234567890.0000000000000000000000000000"
        );

        // Products that need rounding when rescaled.
        let lhs = decimal128(
            vec![1, 123456789555555555555555555, 1555555555555555555],
            38,
            18,
        );
        let rhs = decimal128(
            vec![1555555555555555555, 11222222222222222222, 1],
            38,
            18,
        );

        let product = multiply_fixed_point_checked(&lhs, &rhs, 28).unwrap();
        let want = decimal128(
            vec![
                15555555556,
                13854595272345679012071330528765432099,
                15555555556,
            ],
            38,
            28,
        );

        assert_eq!(&want, &product);

        assert_eq!(
            product.value_as_string(1),
            "1385459527.2345679012071330528765432099"
        );
        assert_eq!(product.value_as_string(0), "0.0000000000000000015555555556");
        assert_eq!(product.value_as_string(2), "0.0000000000000000015555555556");

        let lhs = decimal128(vec![1230], 4, 2);
        let rhs = decimal128(vec![1000], 4, 2);

        // Required scale equal to the product scale: must agree with `mul`.
        let product = multiply_fixed_point_checked(&lhs, &rhs, 4).unwrap();
        assert_eq!(product.precision(), 9);
        assert_eq!(product.scale(), 4);

        let via_mul = mul(&lhs, &rhs).unwrap();
        assert_eq!(via_mul.as_ref(), &product);

        // Required scale larger than the product scale is rejected.
        let scale_err = multiply_fixed_point_checked(&lhs, &rhs, 5).unwrap_err();
        assert!(
            scale_err
                .to_string()
                .contains("Required scale 5 is greater than product scale 4")
        );
    }

    #[test]
    fn test_decimal_multiply_allow_precision_loss_overflow() {
        let lhs = decimal128(vec![99999999999123456789000000000000000000], 38, 18);
        let rhs = decimal128(vec![9999999999910000000000000000000], 38, 18);

        // The checked variant reports the overflow...
        let overflow = multiply_fixed_point_checked(&lhs, &rhs, 28).unwrap_err();
        assert!(overflow.to_string().contains(
            "Overflow happened on: 99999999999123456789000000000000000000 * 9999999999910000000000000000000"
        ));

        // ...while the unchecked variant returns a truncated value.
        let product = multiply_fixed_point(&lhs, &rhs, 28).unwrap();
        let want = decimal128(vec![62946009661555981610246871926660136960], 38, 28);

        assert_eq!(&want, &product);
    }

    #[test]
    fn test_decimal_multiply_fixed_point() {
        let lhs = decimal128(vec![123456789000000000000000000], 38, 18);
        let rhs = decimal128(vec![10000000000000000000], 38, 18);

        // The plain `mul` kernel reports an overflow for these operands.
        let mul_err = mul(&lhs, &rhs).unwrap_err();
        assert_eq!(
            mul_err.to_string(),
            "Arithmetic overflow: Overflow happened on: 123456789000000000000000000 * 10000000000000000000"
        );

        // Reducing the scale to 28 makes the product representable.
        let product = multiply_fixed_point(&lhs, &rhs, 28).unwrap();
        let want = decimal128(vec![12345678900000000000000000000000000000], 38, 28);

        assert_eq!(&want, &product);
        assert_eq!(
            product.value_as_string(0),
            "1234567890.0000000000000000000000000000"
        );
    }
}