1use crate::cast::*;
19
20pub trait DecimalCast: Sized {
23 fn to_i32(self) -> Option<i32>;
25
26 fn to_i64(self) -> Option<i64>;
28
29 fn to_i128(self) -> Option<i128>;
31
32 fn to_i256(self) -> Option<i256>;
34
35 fn from_decimal<T: DecimalCast>(n: T) -> Option<Self>;
37
38 fn from_f64(n: f64) -> Option<Self>;
40}
41
42impl DecimalCast for i32 {
43 fn to_i32(self) -> Option<i32> {
44 Some(self)
45 }
46
47 fn to_i64(self) -> Option<i64> {
48 Some(self as i64)
49 }
50
51 fn to_i128(self) -> Option<i128> {
52 Some(self as i128)
53 }
54
55 fn to_i256(self) -> Option<i256> {
56 Some(i256::from_i128(self as i128))
57 }
58
59 fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
60 n.to_i32()
61 }
62
63 fn from_f64(n: f64) -> Option<Self> {
64 n.to_i32()
65 }
66}
67
68impl DecimalCast for i64 {
69 fn to_i32(self) -> Option<i32> {
70 i32::try_from(self).ok()
71 }
72
73 fn to_i64(self) -> Option<i64> {
74 Some(self)
75 }
76
77 fn to_i128(self) -> Option<i128> {
78 Some(self as i128)
79 }
80
81 fn to_i256(self) -> Option<i256> {
82 Some(i256::from_i128(self as i128))
83 }
84
85 fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
86 n.to_i64()
87 }
88
89 fn from_f64(n: f64) -> Option<Self> {
90 num_traits::ToPrimitive::to_i64(&n)
93 }
94}
95
96impl DecimalCast for i128 {
97 fn to_i32(self) -> Option<i32> {
98 i32::try_from(self).ok()
99 }
100
101 fn to_i64(self) -> Option<i64> {
102 i64::try_from(self).ok()
103 }
104
105 fn to_i128(self) -> Option<i128> {
106 Some(self)
107 }
108
109 fn to_i256(self) -> Option<i256> {
110 Some(i256::from_i128(self))
111 }
112
113 fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
114 n.to_i128()
115 }
116
117 fn from_f64(n: f64) -> Option<Self> {
118 n.to_i128()
119 }
120}
121
122impl DecimalCast for i256 {
123 fn to_i32(self) -> Option<i32> {
124 self.to_i128().map(|x| i32::try_from(x).ok())?
125 }
126
127 fn to_i64(self) -> Option<i64> {
128 self.to_i128().map(|x| i64::try_from(x).ok())?
129 }
130
131 fn to_i128(self) -> Option<i128> {
132 self.to_i128()
133 }
134
135 fn to_i256(self) -> Option<i256> {
136 Some(self)
137 }
138
139 fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
140 n.to_i256()
141 }
142
143 fn from_f64(n: f64) -> Option<Self> {
144 i256::from_f64(n)
145 }
146}
147
148pub(crate) fn cast_decimal_to_decimal_error<I, O>(
149 output_precision: u8,
150 output_scale: i8,
151) -> impl Fn(<I as ArrowPrimitiveType>::Native) -> ArrowError
152where
153 I: DecimalType,
154 O: DecimalType,
155 I::Native: DecimalCast + ArrowNativeTypeOp,
156 O::Native: DecimalCast + ArrowNativeTypeOp,
157{
158 move |x: I::Native| {
159 ArrowError::CastError(format!(
160 "Cannot cast to {}({}, {}). Overflowing on {:?}",
161 O::PREFIX,
162 output_precision,
163 output_scale,
164 x
165 ))
166 }
167}
168
/// Rescales decimals of type `I` into type `O` when the output scale is smaller
/// than the input scale: each value is divided by `div` (see below) and rounded
/// half away from zero.
///
/// Strategy depends on the arguments:
/// * if `is_infallible_cast` holds, every value is converted without per-value
///   checks (after validating the output precision/scale once);
/// * else if `cast_options.safe` is set, values that fail conversion or exceed
///   `output_precision` become null;
/// * otherwise the first such value yields an error.
pub(crate) fn convert_to_smaller_scale_decimal<I, O>(
    array: &PrimitiveArray<I>,
    input_precision: u8,
    input_scale: i8,
    output_precision: u8,
    output_scale: i8,
    cast_options: &CastOptions,
) -> Result<PrimitiveArray<O>, ArrowError>
where
    I: DecimalType,
    O: DecimalType,
    I::Native: DecimalCast + ArrowNativeTypeOp,
    O::Native: DecimalCast + ArrowNativeTypeOp,
{
    let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
    // Callers in this file only reach here with input_scale > output_scale.
    let delta_scale = input_scale - output_scale;
    // Dividing drops `delta_scale` digits, but rounding can carry one digit back
    // (e.g. 99.999 -> 100), which is why the comparison is strict.
    let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8);

    // NOTE(review): assumes `MAX_FOR_EACH_PRECISION[p]` is `10^p - 1`, so that
    // `max + 1` below is `10^delta_scale` — confirm against the table definition.
    let Some(max) = I::MAX_FOR_EACH_PRECISION.get(delta_scale as usize) else {
        // The scale reduction is beyond the table: every output slot is emitted as
        // zero while the input's null mask is preserved.
        let zeros = vec![O::Native::ZERO; array.len()];
        return Ok(PrimitiveArray::new(zeros.into(), array.nulls().cloned()));
    };

    let div = max.add_wrapping(I::Native::ONE);
    // `half` is div / 2: the remainder threshold at which rounding moves away from zero.
    let half = div.div_wrapping(I::Native::ONE.add_wrapping(I::Native::ONE));
    let half_neg = half.neg_wrapping();

    // Divide by `div`, rounding half away from zero, then narrow into `O::Native`.
    let f = |x: I::Native| {
        let d = x.div_wrapping(div);
        // The negative branch below relies on `r` taking the sign of `x`
        // (presumably Rust `%` semantics — confirm for `mod_wrapping`).
        let r = x.mod_wrapping(div);

        let adjusted = match x >= I::Native::ZERO {
            true if r >= half => d.add_wrapping(I::Native::ONE),
            false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
            _ => d,
        };
        O::Native::from_decimal(adjusted)
    };

    Ok(if is_infallible_cast {
        validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;
        // `is_infallible_cast` is expected to guarantee every rounded value fits
        // in `O::Native`, so this unwrap should never fire.
        let g = |x: I::Native| f(x).unwrap();
        array.unary(g)
    } else if cast_options.safe {
        array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
    } else {
        array.try_unary(|x| {
            f(x).ok_or_else(|| error(x)).and_then(|v| {
                O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v)
            })
        })?
    })
}
242
243pub(crate) fn convert_to_bigger_or_equal_scale_decimal<I, O>(
244 array: &PrimitiveArray<I>,
245 input_precision: u8,
246 input_scale: i8,
247 output_precision: u8,
248 output_scale: i8,
249 cast_options: &CastOptions,
250) -> Result<PrimitiveArray<O>, ArrowError>
251where
252 I: DecimalType,
253 O: DecimalType,
254 I::Native: DecimalCast + ArrowNativeTypeOp,
255 O::Native: DecimalCast + ArrowNativeTypeOp,
256{
257 let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
258 let delta_scale = output_scale - input_scale;
259 let mul = O::Native::from_decimal(10_i128)
260 .unwrap()
261 .pow_checked(delta_scale as u32)?;
262
263 let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8);
270 let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok());
271
272 Ok(if is_infallible_cast {
273 validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;
275 let f = |x| O::Native::from_decimal(x).unwrap().mul_wrapping(mul);
277 array.unary(f)
278 } else if cast_options.safe {
279 array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
280 } else {
281 array.try_unary(|x| {
282 f(x).ok_or_else(|| error(x)).and_then(|v| {
283 O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v)
284 })
285 })?
286 })
287}
288
289pub(crate) fn cast_decimal_to_decimal_same_type<T>(
291 array: &PrimitiveArray<T>,
292 input_precision: u8,
293 input_scale: i8,
294 output_precision: u8,
295 output_scale: i8,
296 cast_options: &CastOptions,
297) -> Result<ArrayRef, ArrowError>
298where
299 T: DecimalType,
300 T::Native: DecimalCast + ArrowNativeTypeOp,
301{
302 let array: PrimitiveArray<T> =
303 if input_scale == output_scale && input_precision <= output_precision {
304 array.clone()
305 } else if input_scale <= output_scale {
306 convert_to_bigger_or_equal_scale_decimal::<T, T>(
307 array,
308 input_precision,
309 input_scale,
310 output_precision,
311 output_scale,
312 cast_options,
313 )?
314 } else {
315 convert_to_smaller_scale_decimal::<T, T>(
317 array,
318 input_precision,
319 input_scale,
320 output_precision,
321 output_scale,
322 cast_options,
323 )?
324 };
325
326 Ok(Arc::new(array.with_precision_and_scale(
327 output_precision,
328 output_scale,
329 )?))
330}
331
332pub(crate) fn cast_decimal_to_decimal<I, O>(
334 array: &PrimitiveArray<I>,
335 input_precision: u8,
336 input_scale: i8,
337 output_precision: u8,
338 output_scale: i8,
339 cast_options: &CastOptions,
340) -> Result<ArrayRef, ArrowError>
341where
342 I: DecimalType,
343 O: DecimalType,
344 I::Native: DecimalCast + ArrowNativeTypeOp,
345 O::Native: DecimalCast + ArrowNativeTypeOp,
346{
347 let array: PrimitiveArray<O> = if input_scale > output_scale {
348 convert_to_smaller_scale_decimal::<I, O>(
349 array,
350 input_precision,
351 input_scale,
352 output_precision,
353 output_scale,
354 cast_options,
355 )?
356 } else {
357 convert_to_bigger_or_equal_scale_decimal::<I, O>(
358 array,
359 input_precision,
360 input_scale,
361 output_precision,
362 output_scale,
363 cast_options,
364 )?
365 };
366
367 Ok(Arc::new(array.with_precision_and_scale(
368 output_precision,
369 output_scale,
370 )?))
371}
372
373pub(crate) fn parse_string_to_decimal_native<T: DecimalType>(
376 value_str: &str,
377 scale: usize,
378) -> Result<T::Native, ArrowError>
379where
380 T::Native: DecimalCast + ArrowNativeTypeOp,
381{
382 let value_str = value_str.trim();
383 let parts: Vec<&str> = value_str.split('.').collect();
384 if parts.len() > 2 {
385 return Err(ArrowError::InvalidArgumentError(format!(
386 "Invalid decimal format: {value_str:?}"
387 )));
388 }
389
390 let (negative, first_part) = if parts[0].is_empty() {
391 (false, parts[0])
392 } else {
393 match parts[0].as_bytes()[0] {
394 b'-' => (true, &parts[0][1..]),
395 b'+' => (false, &parts[0][1..]),
396 _ => (false, parts[0]),
397 }
398 };
399
400 let integers = first_part;
401 let decimals = if parts.len() == 2 { parts[1] } else { "" };
402
403 if !integers.is_empty() && !integers.as_bytes()[0].is_ascii_digit() {
404 return Err(ArrowError::InvalidArgumentError(format!(
405 "Invalid decimal format: {value_str:?}"
406 )));
407 }
408
409 if !decimals.is_empty() && !decimals.as_bytes()[0].is_ascii_digit() {
410 return Err(ArrowError::InvalidArgumentError(format!(
411 "Invalid decimal format: {value_str:?}"
412 )));
413 }
414
415 let mut number_decimals = if decimals.len() > scale {
417 let decimal_number = i256::from_string(decimals).ok_or_else(|| {
418 ArrowError::InvalidArgumentError(format!("Cannot parse decimal format: {value_str}"))
419 })?;
420
421 let div = i256::from_i128(10_i128).pow_checked((decimals.len() - scale) as u32)?;
422
423 let half = div.div_wrapping(i256::from_i128(2));
424 let half_neg = half.neg_wrapping();
425
426 let d = decimal_number.div_wrapping(div);
427 let r = decimal_number.mod_wrapping(div);
428
429 let adjusted = match decimal_number >= i256::ZERO {
431 true if r >= half => d.add_wrapping(i256::ONE),
432 false if r <= half_neg => d.sub_wrapping(i256::ONE),
433 _ => d,
434 };
435
436 let integers = if !integers.is_empty() {
437 i256::from_string(integers)
438 .ok_or_else(|| {
439 ArrowError::InvalidArgumentError(format!(
440 "Cannot parse decimal format: {value_str}"
441 ))
442 })
443 .map(|v| v.mul_wrapping(i256::from_i128(10_i128).pow_wrapping(scale as u32)))?
444 } else {
445 i256::ZERO
446 };
447
448 format!("{}", integers.add_wrapping(adjusted))
449 } else {
450 let padding = if scale > decimals.len() { scale } else { 0 };
451
452 let decimals = format!("{decimals:0<padding$}");
453 format!("{integers}{decimals}")
454 };
455
456 if negative {
457 number_decimals.insert(0, '-');
458 }
459
460 let value = i256::from_string(number_decimals.as_str()).ok_or_else(|| {
461 ArrowError::InvalidArgumentError(format!(
462 "Cannot convert {} to {}: Overflow",
463 value_str,
464 T::PREFIX
465 ))
466 })?;
467
468 T::Native::from_decimal(value).ok_or_else(|| {
469 ArrowError::InvalidArgumentError(format!("Cannot convert {} to {}", value_str, T::PREFIX))
470 })
471}
472
473pub(crate) fn generic_string_to_decimal_cast<'a, T, S>(
474 from: &'a S,
475 precision: u8,
476 scale: i8,
477 cast_options: &CastOptions,
478) -> Result<PrimitiveArray<T>, ArrowError>
479where
480 T: DecimalType,
481 T::Native: DecimalCast + ArrowNativeTypeOp,
482 &'a S: StringArrayType<'a>,
483{
484 if cast_options.safe {
485 let iter = from.iter().map(|v| {
486 v.and_then(|v| parse_string_to_decimal_native::<T>(v, scale as usize).ok())
487 .and_then(|v| T::is_valid_decimal_precision(v, precision).then_some(v))
488 });
489 Ok(unsafe {
494 PrimitiveArray::<T>::from_trusted_len_iter(iter)
495 .with_precision_and_scale(precision, scale)?
496 })
497 } else {
498 let vec = from
499 .iter()
500 .map(|v| {
501 v.map(|v| {
502 parse_string_to_decimal_native::<T>(v, scale as usize)
503 .map_err(|_| {
504 ArrowError::CastError(format!(
505 "Cannot cast string '{v}' to value of {} type",
506 T::DATA_TYPE,
507 ))
508 })
509 .and_then(|v| T::validate_decimal_precision(v, precision, scale).map(|_| v))
510 })
511 .transpose()
512 })
513 .collect::<Result<Vec<_>, _>>()?;
514 Ok(unsafe {
519 PrimitiveArray::<T>::from_trusted_len_iter(vec.iter())
520 .with_precision_and_scale(precision, scale)?
521 })
522 }
523}
524
525pub(crate) fn string_to_decimal_cast<T, Offset: OffsetSizeTrait>(
526 from: &GenericStringArray<Offset>,
527 precision: u8,
528 scale: i8,
529 cast_options: &CastOptions,
530) -> Result<PrimitiveArray<T>, ArrowError>
531where
532 T: DecimalType,
533 T::Native: DecimalCast + ArrowNativeTypeOp,
534{
535 generic_string_to_decimal_cast::<T, GenericStringArray<Offset>>(
536 from,
537 precision,
538 scale,
539 cast_options,
540 )
541}
542
543pub(crate) fn string_view_to_decimal_cast<T>(
544 from: &StringViewArray,
545 precision: u8,
546 scale: i8,
547 cast_options: &CastOptions,
548) -> Result<PrimitiveArray<T>, ArrowError>
549where
550 T: DecimalType,
551 T::Native: DecimalCast + ArrowNativeTypeOp,
552{
553 generic_string_to_decimal_cast::<T, StringViewArray>(from, precision, scale, cast_options)
554}
555
556pub(crate) fn cast_string_to_decimal<T, Offset: OffsetSizeTrait>(
558 from: &dyn Array,
559 precision: u8,
560 scale: i8,
561 cast_options: &CastOptions,
562) -> Result<ArrayRef, ArrowError>
563where
564 T: DecimalType,
565 T::Native: DecimalCast + ArrowNativeTypeOp,
566{
567 if scale < 0 {
568 return Err(ArrowError::InvalidArgumentError(format!(
569 "Cannot cast string to decimal with negative scale {scale}"
570 )));
571 }
572
573 if scale > T::MAX_SCALE {
574 return Err(ArrowError::InvalidArgumentError(format!(
575 "Cannot cast string to decimal greater than maximum scale {}",
576 T::MAX_SCALE
577 )));
578 }
579
580 let result = match from.data_type() {
581 DataType::Utf8View => string_view_to_decimal_cast::<T>(
582 from.as_any().downcast_ref::<StringViewArray>().unwrap(),
583 precision,
584 scale,
585 cast_options,
586 )?,
587 DataType::Utf8 | DataType::LargeUtf8 => string_to_decimal_cast::<T, Offset>(
588 from.as_any()
589 .downcast_ref::<GenericStringArray<Offset>>()
590 .unwrap(),
591 precision,
592 scale,
593 cast_options,
594 )?,
595 other => {
596 return Err(ArrowError::ComputeError(format!(
597 "Cannot cast {other:?} to decimal",
598 )));
599 }
600 };
601
602 Ok(Arc::new(result))
603}
604
605pub(crate) fn cast_floating_point_to_decimal<T: ArrowPrimitiveType, D>(
606 array: &PrimitiveArray<T>,
607 precision: u8,
608 scale: i8,
609 cast_options: &CastOptions,
610) -> Result<ArrayRef, ArrowError>
611where
612 <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
613 D: DecimalType + ArrowPrimitiveType,
614 <D as ArrowPrimitiveType>::Native: DecimalCast,
615{
616 let mul = 10_f64.powi(scale as i32);
617
618 if cast_options.safe {
619 array
620 .unary_opt::<_, D>(|v| {
621 D::Native::from_f64((mul * v.as_()).round())
622 .filter(|v| D::is_valid_decimal_precision(*v, precision))
623 })
624 .with_precision_and_scale(precision, scale)
625 .map(|a| Arc::new(a) as ArrayRef)
626 } else {
627 array
628 .try_unary::<_, D, _>(|v| {
629 D::Native::from_f64((mul * v.as_()).round())
630 .ok_or_else(|| {
631 ArrowError::CastError(format!(
632 "Cannot cast to {}({}, {}). Overflowing on {:?}",
633 D::PREFIX,
634 precision,
635 scale,
636 v
637 ))
638 })
639 .and_then(|v| D::validate_decimal_precision(v, precision, scale).map(|_| v))
640 })?
641 .with_precision_and_scale(precision, scale)
642 .map(|a| Arc::new(a) as ArrayRef)
643 }
644}
645
646pub(crate) fn cast_decimal_to_integer<D, T>(
647 array: &dyn Array,
648 base: D::Native,
649 scale: i8,
650 cast_options: &CastOptions,
651) -> Result<ArrayRef, ArrowError>
652where
653 T: ArrowPrimitiveType,
654 <T as ArrowPrimitiveType>::Native: NumCast,
655 D: DecimalType + ArrowPrimitiveType,
656 <D as ArrowPrimitiveType>::Native: ArrowNativeTypeOp + ToPrimitive,
657{
658 let array = array.as_primitive::<D>();
659
660 let div: D::Native = base.pow_checked(scale as u32).map_err(|_| {
661 ArrowError::CastError(format!(
662 "Cannot cast to {:?}. The scale {} causes overflow.",
663 D::PREFIX,
664 scale,
665 ))
666 })?;
667
668 let mut value_builder = PrimitiveBuilder::<T>::with_capacity(array.len());
669
670 if cast_options.safe {
671 for i in 0..array.len() {
672 if array.is_null(i) {
673 value_builder.append_null();
674 } else {
675 let v = array
676 .value(i)
677 .div_checked(div)
678 .ok()
679 .and_then(<T::Native as NumCast>::from::<D::Native>);
680
681 value_builder.append_option(v);
682 }
683 }
684 } else {
685 for i in 0..array.len() {
686 if array.is_null(i) {
687 value_builder.append_null();
688 } else {
689 let v = array.value(i).div_checked(div)?;
690
691 let value = <T::Native as NumCast>::from::<D::Native>(v).ok_or_else(|| {
692 ArrowError::CastError(format!(
693 "value of {:?} is out of range {}",
694 v,
695 T::DATA_TYPE
696 ))
697 })?;
698
699 value_builder.append_value(value);
700 }
701 }
702 }
703 Ok(Arc::new(value_builder.finish()))
704}
705
706pub(crate) fn cast_decimal_to_float<D: DecimalType, T: ArrowPrimitiveType, F>(
712 array: &dyn Array,
713 op: F,
714) -> Result<ArrayRef, ArrowError>
715where
716 F: Fn(D::Native) -> T::Native,
717{
718 let array = array.as_primitive::<D>();
719 let array = array.unary::<_, T>(op);
720 Ok(Arc::new(array))
721}
722
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_string_to_decimal_native() -> Result<(), ArrowError> {
        // (input, scale, expected native Decimal128 value)
        let cases: [(&str, usize, i128); 8] = [
            ("0", 0, 0),
            ("0", 5, 0),
            ("123", 0, 123),
            ("123", 5, 12300000),
            ("123.45", 0, 123),
            ("123.45", 5, 12345000),
            ("123.4567891", 0, 123),
            ("123.4567891", 5, 12345679),
        ];

        for (input, scale, expected) in cases {
            assert_eq!(
                parse_string_to_decimal_native::<Decimal128Type>(input, scale)?,
                expected,
                "parsing {input:?} with scale {scale}"
            );
        }
        Ok(())
    }
}