1use crate::cast::*;
19
20pub(crate) trait DecimalCast: Sized {
23 fn to_i128(self) -> Option<i128>;
24
25 fn to_i256(self) -> Option<i256>;
26
27 fn from_decimal<T: DecimalCast>(n: T) -> Option<Self>;
28
29 fn from_f64(n: f64) -> Option<Self>;
30}
31
32impl DecimalCast for i128 {
33 fn to_i128(self) -> Option<i128> {
34 Some(self)
35 }
36
37 fn to_i256(self) -> Option<i256> {
38 Some(i256::from_i128(self))
39 }
40
41 fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
42 n.to_i128()
43 }
44
45 fn from_f64(n: f64) -> Option<Self> {
46 n.to_i128()
47 }
48}
49
50impl DecimalCast for i256 {
51 fn to_i128(self) -> Option<i128> {
52 self.to_i128()
53 }
54
55 fn to_i256(self) -> Option<i256> {
56 Some(self)
57 }
58
59 fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
60 n.to_i256()
61 }
62
63 fn from_f64(n: f64) -> Option<Self> {
64 i256::from_f64(n)
65 }
66}
67
68pub(crate) fn cast_decimal_to_decimal_error<I, O>(
69 output_precision: u8,
70 output_scale: i8,
71) -> impl Fn(<I as ArrowPrimitiveType>::Native) -> ArrowError
72where
73 I: DecimalType,
74 O: DecimalType,
75 I::Native: DecimalCast + ArrowNativeTypeOp,
76 O::Native: DecimalCast + ArrowNativeTypeOp,
77{
78 move |x: I::Native| {
79 ArrowError::CastError(format!(
80 "Cannot cast to {}({}, {}). Overflowing on {:?}",
81 O::PREFIX,
82 output_precision,
83 output_scale,
84 x
85 ))
86 }
87}
88
89pub(crate) fn convert_to_smaller_scale_decimal<I, O>(
90 array: &PrimitiveArray<I>,
91 input_precision: u8,
92 input_scale: i8,
93 output_precision: u8,
94 output_scale: i8,
95 cast_options: &CastOptions,
96) -> Result<PrimitiveArray<O>, ArrowError>
97where
98 I: DecimalType,
99 O: DecimalType,
100 I::Native: DecimalCast + ArrowNativeTypeOp,
101 O::Native: DecimalCast + ArrowNativeTypeOp,
102{
103 let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
104 let delta_scale = input_scale - output_scale;
105 let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8);
116
117 let div = I::Native::from_decimal(10_i128)
118 .unwrap()
119 .pow_checked(delta_scale as u32)?;
120
121 let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
122 let half_neg = half.neg_wrapping();
123
124 let f = |x: I::Native| {
125 let d = x.div_wrapping(div);
127 let r = x.mod_wrapping(div);
128
129 let adjusted = match x >= I::Native::ZERO {
131 true if r >= half => d.add_wrapping(I::Native::ONE),
132 false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
133 _ => d,
134 };
135 O::Native::from_decimal(adjusted)
136 };
137
138 Ok(if is_infallible_cast {
139 validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;
141 let g = |x: I::Native| f(x).unwrap(); array.unary(g)
144 } else if cast_options.safe {
145 array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
146 } else {
147 array.try_unary(|x| {
148 f(x).ok_or_else(|| error(x))
149 .and_then(|v| O::validate_decimal_precision(v, output_precision).map(|_| v))
150 })?
151 })
152}
153
154pub(crate) fn convert_to_bigger_or_equal_scale_decimal<I, O>(
155 array: &PrimitiveArray<I>,
156 input_precision: u8,
157 input_scale: i8,
158 output_precision: u8,
159 output_scale: i8,
160 cast_options: &CastOptions,
161) -> Result<PrimitiveArray<O>, ArrowError>
162where
163 I: DecimalType,
164 O: DecimalType,
165 I::Native: DecimalCast + ArrowNativeTypeOp,
166 O::Native: DecimalCast + ArrowNativeTypeOp,
167{
168 let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
169 let delta_scale = output_scale - input_scale;
170 let mul = O::Native::from_decimal(10_i128)
171 .unwrap()
172 .pow_checked(delta_scale as u32)?;
173
174 let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8);
181 let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok());
182
183 Ok(if is_infallible_cast {
184 validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;
186 let f = |x| O::Native::from_decimal(x).unwrap().mul_wrapping(mul);
188 array.unary(f)
189 } else if cast_options.safe {
190 array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
191 } else {
192 array.try_unary(|x| {
193 f(x).ok_or_else(|| error(x))
194 .and_then(|v| O::validate_decimal_precision(v, output_precision).map(|_| v))
195 })?
196 })
197}
198
199pub(crate) fn cast_decimal_to_decimal_same_type<T>(
201 array: &PrimitiveArray<T>,
202 input_precision: u8,
203 input_scale: i8,
204 output_precision: u8,
205 output_scale: i8,
206 cast_options: &CastOptions,
207) -> Result<ArrayRef, ArrowError>
208where
209 T: DecimalType,
210 T::Native: DecimalCast + ArrowNativeTypeOp,
211{
212 let array: PrimitiveArray<T> =
213 if input_scale == output_scale && input_precision <= output_precision {
214 array.clone()
215 } else if input_scale <= output_scale {
216 convert_to_bigger_or_equal_scale_decimal::<T, T>(
217 array,
218 input_precision,
219 input_scale,
220 output_precision,
221 output_scale,
222 cast_options,
223 )?
224 } else {
225 convert_to_smaller_scale_decimal::<T, T>(
227 array,
228 input_precision,
229 input_scale,
230 output_precision,
231 output_scale,
232 cast_options,
233 )?
234 };
235
236 Ok(Arc::new(array.with_precision_and_scale(
237 output_precision,
238 output_scale,
239 )?))
240}
241
242pub(crate) fn cast_decimal_to_decimal<I, O>(
244 array: &PrimitiveArray<I>,
245 input_precision: u8,
246 input_scale: i8,
247 output_precision: u8,
248 output_scale: i8,
249 cast_options: &CastOptions,
250) -> Result<ArrayRef, ArrowError>
251where
252 I: DecimalType,
253 O: DecimalType,
254 I::Native: DecimalCast + ArrowNativeTypeOp,
255 O::Native: DecimalCast + ArrowNativeTypeOp,
256{
257 let array: PrimitiveArray<O> = if input_scale > output_scale {
258 convert_to_smaller_scale_decimal::<I, O>(
259 array,
260 input_precision,
261 input_scale,
262 output_precision,
263 output_scale,
264 cast_options,
265 )?
266 } else {
267 convert_to_bigger_or_equal_scale_decimal::<I, O>(
268 array,
269 input_precision,
270 input_scale,
271 output_precision,
272 output_scale,
273 cast_options,
274 )?
275 };
276
277 Ok(Arc::new(array.with_precision_and_scale(
278 output_precision,
279 output_scale,
280 )?))
281}
282
283pub(crate) fn parse_string_to_decimal_native<T: DecimalType>(
286 value_str: &str,
287 scale: usize,
288) -> Result<T::Native, ArrowError>
289where
290 T::Native: DecimalCast + ArrowNativeTypeOp,
291{
292 let value_str = value_str.trim();
293 let parts: Vec<&str> = value_str.split('.').collect();
294 if parts.len() > 2 {
295 return Err(ArrowError::InvalidArgumentError(format!(
296 "Invalid decimal format: {value_str:?}"
297 )));
298 }
299
300 let (negative, first_part) = if parts[0].is_empty() {
301 (false, parts[0])
302 } else {
303 match parts[0].as_bytes()[0] {
304 b'-' => (true, &parts[0][1..]),
305 b'+' => (false, &parts[0][1..]),
306 _ => (false, parts[0]),
307 }
308 };
309
310 let integers = first_part;
311 let decimals = if parts.len() == 2 { parts[1] } else { "" };
312
313 if !integers.is_empty() && !integers.as_bytes()[0].is_ascii_digit() {
314 return Err(ArrowError::InvalidArgumentError(format!(
315 "Invalid decimal format: {value_str:?}"
316 )));
317 }
318
319 if !decimals.is_empty() && !decimals.as_bytes()[0].is_ascii_digit() {
320 return Err(ArrowError::InvalidArgumentError(format!(
321 "Invalid decimal format: {value_str:?}"
322 )));
323 }
324
325 let mut number_decimals = if decimals.len() > scale {
327 let decimal_number = i256::from_string(decimals).ok_or_else(|| {
328 ArrowError::InvalidArgumentError(format!("Cannot parse decimal format: {value_str}"))
329 })?;
330
331 let div = i256::from_i128(10_i128).pow_checked((decimals.len() - scale) as u32)?;
332
333 let half = div.div_wrapping(i256::from_i128(2));
334 let half_neg = half.neg_wrapping();
335
336 let d = decimal_number.div_wrapping(div);
337 let r = decimal_number.mod_wrapping(div);
338
339 let adjusted = match decimal_number >= i256::ZERO {
341 true if r >= half => d.add_wrapping(i256::ONE),
342 false if r <= half_neg => d.sub_wrapping(i256::ONE),
343 _ => d,
344 };
345
346 let integers = if !integers.is_empty() {
347 i256::from_string(integers)
348 .ok_or_else(|| {
349 ArrowError::InvalidArgumentError(format!(
350 "Cannot parse decimal format: {value_str}"
351 ))
352 })
353 .map(|v| v.mul_wrapping(i256::from_i128(10_i128).pow_wrapping(scale as u32)))?
354 } else {
355 i256::ZERO
356 };
357
358 format!("{}", integers.add_wrapping(adjusted))
359 } else {
360 let padding = if scale > decimals.len() { scale } else { 0 };
361
362 let decimals = format!("{decimals:0<padding$}");
363 format!("{integers}{decimals}")
364 };
365
366 if negative {
367 number_decimals.insert(0, '-');
368 }
369
370 let value = i256::from_string(number_decimals.as_str()).ok_or_else(|| {
371 ArrowError::InvalidArgumentError(format!(
372 "Cannot convert {} to {}: Overflow",
373 value_str,
374 T::PREFIX
375 ))
376 })?;
377
378 T::Native::from_decimal(value).ok_or_else(|| {
379 ArrowError::InvalidArgumentError(format!("Cannot convert {} to {}", value_str, T::PREFIX))
380 })
381}
382
383pub(crate) fn generic_string_to_decimal_cast<'a, T, S>(
384 from: &'a S,
385 precision: u8,
386 scale: i8,
387 cast_options: &CastOptions,
388) -> Result<PrimitiveArray<T>, ArrowError>
389where
390 T: DecimalType,
391 T::Native: DecimalCast + ArrowNativeTypeOp,
392 &'a S: StringArrayType<'a>,
393{
394 if cast_options.safe {
395 let iter = from.iter().map(|v| {
396 v.and_then(|v| parse_string_to_decimal_native::<T>(v, scale as usize).ok())
397 .and_then(|v| T::is_valid_decimal_precision(v, precision).then_some(v))
398 });
399 Ok(unsafe {
404 PrimitiveArray::<T>::from_trusted_len_iter(iter)
405 .with_precision_and_scale(precision, scale)?
406 })
407 } else {
408 let vec = from
409 .iter()
410 .map(|v| {
411 v.map(|v| {
412 parse_string_to_decimal_native::<T>(v, scale as usize)
413 .map_err(|_| {
414 ArrowError::CastError(format!(
415 "Cannot cast string '{}' to value of {:?} type",
416 v,
417 T::DATA_TYPE,
418 ))
419 })
420 .and_then(|v| T::validate_decimal_precision(v, precision).map(|_| v))
421 })
422 .transpose()
423 })
424 .collect::<Result<Vec<_>, _>>()?;
425 Ok(unsafe {
430 PrimitiveArray::<T>::from_trusted_len_iter(vec.iter())
431 .with_precision_and_scale(precision, scale)?
432 })
433 }
434}
435
436pub(crate) fn string_to_decimal_cast<T, Offset: OffsetSizeTrait>(
437 from: &GenericStringArray<Offset>,
438 precision: u8,
439 scale: i8,
440 cast_options: &CastOptions,
441) -> Result<PrimitiveArray<T>, ArrowError>
442where
443 T: DecimalType,
444 T::Native: DecimalCast + ArrowNativeTypeOp,
445{
446 generic_string_to_decimal_cast::<T, GenericStringArray<Offset>>(
447 from,
448 precision,
449 scale,
450 cast_options,
451 )
452}
453
454pub(crate) fn string_view_to_decimal_cast<T>(
455 from: &StringViewArray,
456 precision: u8,
457 scale: i8,
458 cast_options: &CastOptions,
459) -> Result<PrimitiveArray<T>, ArrowError>
460where
461 T: DecimalType,
462 T::Native: DecimalCast + ArrowNativeTypeOp,
463{
464 generic_string_to_decimal_cast::<T, StringViewArray>(from, precision, scale, cast_options)
465}
466
467pub(crate) fn cast_string_to_decimal<T, Offset: OffsetSizeTrait>(
469 from: &dyn Array,
470 precision: u8,
471 scale: i8,
472 cast_options: &CastOptions,
473) -> Result<ArrayRef, ArrowError>
474where
475 T: DecimalType,
476 T::Native: DecimalCast + ArrowNativeTypeOp,
477{
478 if scale < 0 {
479 return Err(ArrowError::InvalidArgumentError(format!(
480 "Cannot cast string to decimal with negative scale {scale}"
481 )));
482 }
483
484 if scale > T::MAX_SCALE {
485 return Err(ArrowError::InvalidArgumentError(format!(
486 "Cannot cast string to decimal greater than maximum scale {}",
487 T::MAX_SCALE
488 )));
489 }
490
491 let result = match from.data_type() {
492 DataType::Utf8View => string_view_to_decimal_cast::<T>(
493 from.as_any().downcast_ref::<StringViewArray>().unwrap(),
494 precision,
495 scale,
496 cast_options,
497 )?,
498 DataType::Utf8 | DataType::LargeUtf8 => string_to_decimal_cast::<T, Offset>(
499 from.as_any()
500 .downcast_ref::<GenericStringArray<Offset>>()
501 .unwrap(),
502 precision,
503 scale,
504 cast_options,
505 )?,
506 other => {
507 return Err(ArrowError::ComputeError(format!(
508 "Cannot cast {:?} to decimal",
509 other
510 )))
511 }
512 };
513
514 Ok(Arc::new(result))
515}
516
517pub(crate) fn cast_floating_point_to_decimal<T: ArrowPrimitiveType, D>(
518 array: &PrimitiveArray<T>,
519 precision: u8,
520 scale: i8,
521 cast_options: &CastOptions,
522) -> Result<ArrayRef, ArrowError>
523where
524 <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
525 D: DecimalType + ArrowPrimitiveType,
526 <D as ArrowPrimitiveType>::Native: DecimalCast,
527{
528 let mul = 10_f64.powi(scale as i32);
529
530 if cast_options.safe {
531 array
532 .unary_opt::<_, D>(|v| {
533 D::Native::from_f64((mul * v.as_()).round())
534 .filter(|v| D::is_valid_decimal_precision(*v, precision))
535 })
536 .with_precision_and_scale(precision, scale)
537 .map(|a| Arc::new(a) as ArrayRef)
538 } else {
539 array
540 .try_unary::<_, D, _>(|v| {
541 D::Native::from_f64((mul * v.as_()).round())
542 .ok_or_else(|| {
543 ArrowError::CastError(format!(
544 "Cannot cast to {}({}, {}). Overflowing on {:?}",
545 D::PREFIX,
546 precision,
547 scale,
548 v
549 ))
550 })
551 .and_then(|v| D::validate_decimal_precision(v, precision).map(|_| v))
552 })?
553 .with_precision_and_scale(precision, scale)
554 .map(|a| Arc::new(a) as ArrayRef)
555 }
556}
557
558pub(crate) fn cast_decimal_to_integer<D, T>(
559 array: &dyn Array,
560 base: D::Native,
561 scale: i8,
562 cast_options: &CastOptions,
563) -> Result<ArrayRef, ArrowError>
564where
565 T: ArrowPrimitiveType,
566 <T as ArrowPrimitiveType>::Native: NumCast,
567 D: DecimalType + ArrowPrimitiveType,
568 <D as ArrowPrimitiveType>::Native: ArrowNativeTypeOp + ToPrimitive,
569{
570 let array = array.as_primitive::<D>();
571
572 let div: D::Native = base.pow_checked(scale as u32).map_err(|_| {
573 ArrowError::CastError(format!(
574 "Cannot cast to {:?}. The scale {} causes overflow.",
575 D::PREFIX,
576 scale,
577 ))
578 })?;
579
580 let mut value_builder = PrimitiveBuilder::<T>::with_capacity(array.len());
581
582 if cast_options.safe {
583 for i in 0..array.len() {
584 if array.is_null(i) {
585 value_builder.append_null();
586 } else {
587 let v = array
588 .value(i)
589 .div_checked(div)
590 .ok()
591 .and_then(<T::Native as NumCast>::from::<D::Native>);
592
593 value_builder.append_option(v);
594 }
595 }
596 } else {
597 for i in 0..array.len() {
598 if array.is_null(i) {
599 value_builder.append_null();
600 } else {
601 let v = array.value(i).div_checked(div)?;
602
603 let value = <T::Native as NumCast>::from::<D::Native>(v).ok_or_else(|| {
604 ArrowError::CastError(format!(
605 "value of {:?} is out of range {}",
606 v,
607 T::DATA_TYPE
608 ))
609 })?;
610
611 value_builder.append_value(value);
612 }
613 }
614 }
615 Ok(Arc::new(value_builder.finish()))
616}
617
618pub(crate) fn cast_decimal_to_float<D: DecimalType, T: ArrowPrimitiveType, F>(
620 array: &dyn Array,
621 op: F,
622) -> Result<ArrayRef, ArrowError>
623where
624 F: Fn(D::Native) -> T::Native,
625{
626 let array = array.as_primitive::<D>();
627 let array = array.unary::<_, T>(op);
628 Ok(Arc::new(array))
629}
630
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_string_to_decimal_native() -> Result<(), ArrowError> {
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("0", 0)?,
            0_i128
        );
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("0", 5)?,
            0_i128
        );

        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("123", 0)?,
            123_i128
        );
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("123", 5)?,
            12300000_i128
        );

        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("123.45", 0)?,
            123_i128
        );
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("123.45", 5)?,
            12345000_i128
        );

        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("123.4567891", 0)?,
            123_i128
        );
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("123.4567891", 5)?,
            12345679_i128
        );
        Ok(())
    }

    /// Signs, surrounding whitespace, and a missing integer part.
    #[test]
    fn test_parse_string_to_decimal_native_signs() -> Result<(), ArrowError> {
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("-123.45", 2)?,
            -12345_i128
        );
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("+1.5", 1)?,
            15_i128
        );
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("  42  ", 1)?,
            420_i128
        );
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>(".5", 1)?,
            5_i128
        );
        Ok(())
    }

    /// Surplus fractional digits round half away from zero and may carry into the
    /// integer part.
    #[test]
    fn test_parse_string_to_decimal_native_rounding() -> Result<(), ArrowError> {
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("0.999", 2)?,
            100_i128
        );
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("-0.999", 2)?,
            -100_i128
        );
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("-0.005", 2)?,
            -1_i128
        );
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("0.004", 2)?,
            0_i128
        );
        assert_eq!(
            parse_string_to_decimal_native::<Decimal256Type>("123.4567891", 5)?,
            i256::from_i128(12345679_i128)
        );
        Ok(())
    }

    #[test]
    fn test_parse_string_to_decimal_native_invalid() {
        for input in ["1.2.3", "abc", "1.x", "--1", "1.-5"] {
            assert!(
                parse_string_to_decimal_native::<Decimal128Type>(input, 2).is_err(),
                "expected {input:?} to be rejected"
            );
        }
    }
}