1use crate::cast::*;
19
20pub(crate) trait DecimalCast: Sized {
23 fn to_i32(self) -> Option<i32>;
24
25 fn to_i64(self) -> Option<i64>;
26
27 fn to_i128(self) -> Option<i128>;
28
29 fn to_i256(self) -> Option<i256>;
30
31 fn from_decimal<T: DecimalCast>(n: T) -> Option<Self>;
32
33 fn from_f64(n: f64) -> Option<Self>;
34}
35
36impl DecimalCast for i32 {
37 fn to_i32(self) -> Option<i32> {
38 Some(self)
39 }
40
41 fn to_i64(self) -> Option<i64> {
42 Some(self as i64)
43 }
44
45 fn to_i128(self) -> Option<i128> {
46 Some(self as i128)
47 }
48
49 fn to_i256(self) -> Option<i256> {
50 Some(i256::from_i128(self as i128))
51 }
52
53 fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
54 n.to_i32()
55 }
56
57 fn from_f64(n: f64) -> Option<Self> {
58 n.to_i32()
59 }
60}
61
62impl DecimalCast for i64 {
63 fn to_i32(self) -> Option<i32> {
64 i32::try_from(self).ok()
65 }
66
67 fn to_i64(self) -> Option<i64> {
68 Some(self)
69 }
70
71 fn to_i128(self) -> Option<i128> {
72 Some(self as i128)
73 }
74
75 fn to_i256(self) -> Option<i256> {
76 Some(i256::from_i128(self as i128))
77 }
78
79 fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
80 n.to_i64()
81 }
82
83 fn from_f64(n: f64) -> Option<Self> {
84 n.to_i64()
85 }
86}
87
88impl DecimalCast for i128 {
89 fn to_i32(self) -> Option<i32> {
90 i32::try_from(self).ok()
91 }
92
93 fn to_i64(self) -> Option<i64> {
94 i64::try_from(self).ok()
95 }
96
97 fn to_i128(self) -> Option<i128> {
98 Some(self)
99 }
100
101 fn to_i256(self) -> Option<i256> {
102 Some(i256::from_i128(self))
103 }
104
105 fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
106 n.to_i128()
107 }
108
109 fn from_f64(n: f64) -> Option<Self> {
110 n.to_i128()
111 }
112}
113
114impl DecimalCast for i256 {
115 fn to_i32(self) -> Option<i32> {
116 self.to_i128().map(|x| i32::try_from(x).ok())?
117 }
118
119 fn to_i64(self) -> Option<i64> {
120 self.to_i128().map(|x| i64::try_from(x).ok())?
121 }
122
123 fn to_i128(self) -> Option<i128> {
124 self.to_i128()
125 }
126
127 fn to_i256(self) -> Option<i256> {
128 Some(self)
129 }
130
131 fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
132 n.to_i256()
133 }
134
135 fn from_f64(n: f64) -> Option<Self> {
136 i256::from_f64(n)
137 }
138}
139
140pub(crate) fn cast_decimal_to_decimal_error<I, O>(
141 output_precision: u8,
142 output_scale: i8,
143) -> impl Fn(<I as ArrowPrimitiveType>::Native) -> ArrowError
144where
145 I: DecimalType,
146 O: DecimalType,
147 I::Native: DecimalCast + ArrowNativeTypeOp,
148 O::Native: DecimalCast + ArrowNativeTypeOp,
149{
150 move |x: I::Native| {
151 ArrowError::CastError(format!(
152 "Cannot cast to {}({}, {}). Overflowing on {:?}",
153 O::PREFIX,
154 output_precision,
155 output_scale,
156 x
157 ))
158 }
159}
160
161pub(crate) fn convert_to_smaller_scale_decimal<I, O>(
162 array: &PrimitiveArray<I>,
163 input_precision: u8,
164 input_scale: i8,
165 output_precision: u8,
166 output_scale: i8,
167 cast_options: &CastOptions,
168) -> Result<PrimitiveArray<O>, ArrowError>
169where
170 I: DecimalType,
171 O: DecimalType,
172 I::Native: DecimalCast + ArrowNativeTypeOp,
173 O::Native: DecimalCast + ArrowNativeTypeOp,
174{
175 let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
176 let delta_scale = input_scale - output_scale;
177 let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8);
188
189 let div = I::Native::from_decimal(10_i128)
190 .unwrap()
191 .pow_checked(delta_scale as u32)?;
192
193 let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
194 let half_neg = half.neg_wrapping();
195
196 let f = |x: I::Native| {
197 let d = x.div_wrapping(div);
199 let r = x.mod_wrapping(div);
200
201 let adjusted = match x >= I::Native::ZERO {
203 true if r >= half => d.add_wrapping(I::Native::ONE),
204 false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
205 _ => d,
206 };
207 O::Native::from_decimal(adjusted)
208 };
209
210 Ok(if is_infallible_cast {
211 validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;
213 let g = |x: I::Native| f(x).unwrap(); array.unary(g)
216 } else if cast_options.safe {
217 array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
218 } else {
219 array.try_unary(|x| {
220 f(x).ok_or_else(|| error(x))
221 .and_then(|v| O::validate_decimal_precision(v, output_precision).map(|_| v))
222 })?
223 })
224}
225
226pub(crate) fn convert_to_bigger_or_equal_scale_decimal<I, O>(
227 array: &PrimitiveArray<I>,
228 input_precision: u8,
229 input_scale: i8,
230 output_precision: u8,
231 output_scale: i8,
232 cast_options: &CastOptions,
233) -> Result<PrimitiveArray<O>, ArrowError>
234where
235 I: DecimalType,
236 O: DecimalType,
237 I::Native: DecimalCast + ArrowNativeTypeOp,
238 O::Native: DecimalCast + ArrowNativeTypeOp,
239{
240 let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
241 let delta_scale = output_scale - input_scale;
242 let mul = O::Native::from_decimal(10_i128)
243 .unwrap()
244 .pow_checked(delta_scale as u32)?;
245
246 let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8);
253 let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok());
254
255 Ok(if is_infallible_cast {
256 validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;
258 let f = |x| O::Native::from_decimal(x).unwrap().mul_wrapping(mul);
260 array.unary(f)
261 } else if cast_options.safe {
262 array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
263 } else {
264 array.try_unary(|x| {
265 f(x).ok_or_else(|| error(x))
266 .and_then(|v| O::validate_decimal_precision(v, output_precision).map(|_| v))
267 })?
268 })
269}
270
271pub(crate) fn cast_decimal_to_decimal_same_type<T>(
273 array: &PrimitiveArray<T>,
274 input_precision: u8,
275 input_scale: i8,
276 output_precision: u8,
277 output_scale: i8,
278 cast_options: &CastOptions,
279) -> Result<ArrayRef, ArrowError>
280where
281 T: DecimalType,
282 T::Native: DecimalCast + ArrowNativeTypeOp,
283{
284 let array: PrimitiveArray<T> =
285 if input_scale == output_scale && input_precision <= output_precision {
286 array.clone()
287 } else if input_scale <= output_scale {
288 convert_to_bigger_or_equal_scale_decimal::<T, T>(
289 array,
290 input_precision,
291 input_scale,
292 output_precision,
293 output_scale,
294 cast_options,
295 )?
296 } else {
297 convert_to_smaller_scale_decimal::<T, T>(
299 array,
300 input_precision,
301 input_scale,
302 output_precision,
303 output_scale,
304 cast_options,
305 )?
306 };
307
308 Ok(Arc::new(array.with_precision_and_scale(
309 output_precision,
310 output_scale,
311 )?))
312}
313
314pub(crate) fn cast_decimal_to_decimal<I, O>(
316 array: &PrimitiveArray<I>,
317 input_precision: u8,
318 input_scale: i8,
319 output_precision: u8,
320 output_scale: i8,
321 cast_options: &CastOptions,
322) -> Result<ArrayRef, ArrowError>
323where
324 I: DecimalType,
325 O: DecimalType,
326 I::Native: DecimalCast + ArrowNativeTypeOp,
327 O::Native: DecimalCast + ArrowNativeTypeOp,
328{
329 let array: PrimitiveArray<O> = if input_scale > output_scale {
330 convert_to_smaller_scale_decimal::<I, O>(
331 array,
332 input_precision,
333 input_scale,
334 output_precision,
335 output_scale,
336 cast_options,
337 )?
338 } else {
339 convert_to_bigger_or_equal_scale_decimal::<I, O>(
340 array,
341 input_precision,
342 input_scale,
343 output_precision,
344 output_scale,
345 cast_options,
346 )?
347 };
348
349 Ok(Arc::new(array.with_precision_and_scale(
350 output_precision,
351 output_scale,
352 )?))
353}
354
355pub(crate) fn parse_string_to_decimal_native<T: DecimalType>(
358 value_str: &str,
359 scale: usize,
360) -> Result<T::Native, ArrowError>
361where
362 T::Native: DecimalCast + ArrowNativeTypeOp,
363{
364 let value_str = value_str.trim();
365 let parts: Vec<&str> = value_str.split('.').collect();
366 if parts.len() > 2 {
367 return Err(ArrowError::InvalidArgumentError(format!(
368 "Invalid decimal format: {value_str:?}"
369 )));
370 }
371
372 let (negative, first_part) = if parts[0].is_empty() {
373 (false, parts[0])
374 } else {
375 match parts[0].as_bytes()[0] {
376 b'-' => (true, &parts[0][1..]),
377 b'+' => (false, &parts[0][1..]),
378 _ => (false, parts[0]),
379 }
380 };
381
382 let integers = first_part;
383 let decimals = if parts.len() == 2 { parts[1] } else { "" };
384
385 if !integers.is_empty() && !integers.as_bytes()[0].is_ascii_digit() {
386 return Err(ArrowError::InvalidArgumentError(format!(
387 "Invalid decimal format: {value_str:?}"
388 )));
389 }
390
391 if !decimals.is_empty() && !decimals.as_bytes()[0].is_ascii_digit() {
392 return Err(ArrowError::InvalidArgumentError(format!(
393 "Invalid decimal format: {value_str:?}"
394 )));
395 }
396
397 let mut number_decimals = if decimals.len() > scale {
399 let decimal_number = i256::from_string(decimals).ok_or_else(|| {
400 ArrowError::InvalidArgumentError(format!("Cannot parse decimal format: {value_str}"))
401 })?;
402
403 let div = i256::from_i128(10_i128).pow_checked((decimals.len() - scale) as u32)?;
404
405 let half = div.div_wrapping(i256::from_i128(2));
406 let half_neg = half.neg_wrapping();
407
408 let d = decimal_number.div_wrapping(div);
409 let r = decimal_number.mod_wrapping(div);
410
411 let adjusted = match decimal_number >= i256::ZERO {
413 true if r >= half => d.add_wrapping(i256::ONE),
414 false if r <= half_neg => d.sub_wrapping(i256::ONE),
415 _ => d,
416 };
417
418 let integers = if !integers.is_empty() {
419 i256::from_string(integers)
420 .ok_or_else(|| {
421 ArrowError::InvalidArgumentError(format!(
422 "Cannot parse decimal format: {value_str}"
423 ))
424 })
425 .map(|v| v.mul_wrapping(i256::from_i128(10_i128).pow_wrapping(scale as u32)))?
426 } else {
427 i256::ZERO
428 };
429
430 format!("{}", integers.add_wrapping(adjusted))
431 } else {
432 let padding = if scale > decimals.len() { scale } else { 0 };
433
434 let decimals = format!("{decimals:0<padding$}");
435 format!("{integers}{decimals}")
436 };
437
438 if negative {
439 number_decimals.insert(0, '-');
440 }
441
442 let value = i256::from_string(number_decimals.as_str()).ok_or_else(|| {
443 ArrowError::InvalidArgumentError(format!(
444 "Cannot convert {} to {}: Overflow",
445 value_str,
446 T::PREFIX
447 ))
448 })?;
449
450 T::Native::from_decimal(value).ok_or_else(|| {
451 ArrowError::InvalidArgumentError(format!("Cannot convert {} to {}", value_str, T::PREFIX))
452 })
453}
454
455pub(crate) fn generic_string_to_decimal_cast<'a, T, S>(
456 from: &'a S,
457 precision: u8,
458 scale: i8,
459 cast_options: &CastOptions,
460) -> Result<PrimitiveArray<T>, ArrowError>
461where
462 T: DecimalType,
463 T::Native: DecimalCast + ArrowNativeTypeOp,
464 &'a S: StringArrayType<'a>,
465{
466 if cast_options.safe {
467 let iter = from.iter().map(|v| {
468 v.and_then(|v| parse_string_to_decimal_native::<T>(v, scale as usize).ok())
469 .and_then(|v| T::is_valid_decimal_precision(v, precision).then_some(v))
470 });
471 Ok(unsafe {
476 PrimitiveArray::<T>::from_trusted_len_iter(iter)
477 .with_precision_and_scale(precision, scale)?
478 })
479 } else {
480 let vec = from
481 .iter()
482 .map(|v| {
483 v.map(|v| {
484 parse_string_to_decimal_native::<T>(v, scale as usize)
485 .map_err(|_| {
486 ArrowError::CastError(format!(
487 "Cannot cast string '{}' to value of {:?} type",
488 v,
489 T::DATA_TYPE,
490 ))
491 })
492 .and_then(|v| T::validate_decimal_precision(v, precision).map(|_| v))
493 })
494 .transpose()
495 })
496 .collect::<Result<Vec<_>, _>>()?;
497 Ok(unsafe {
502 PrimitiveArray::<T>::from_trusted_len_iter(vec.iter())
503 .with_precision_and_scale(precision, scale)?
504 })
505 }
506}
507
508pub(crate) fn string_to_decimal_cast<T, Offset: OffsetSizeTrait>(
509 from: &GenericStringArray<Offset>,
510 precision: u8,
511 scale: i8,
512 cast_options: &CastOptions,
513) -> Result<PrimitiveArray<T>, ArrowError>
514where
515 T: DecimalType,
516 T::Native: DecimalCast + ArrowNativeTypeOp,
517{
518 generic_string_to_decimal_cast::<T, GenericStringArray<Offset>>(
519 from,
520 precision,
521 scale,
522 cast_options,
523 )
524}
525
526pub(crate) fn string_view_to_decimal_cast<T>(
527 from: &StringViewArray,
528 precision: u8,
529 scale: i8,
530 cast_options: &CastOptions,
531) -> Result<PrimitiveArray<T>, ArrowError>
532where
533 T: DecimalType,
534 T::Native: DecimalCast + ArrowNativeTypeOp,
535{
536 generic_string_to_decimal_cast::<T, StringViewArray>(from, precision, scale, cast_options)
537}
538
539pub(crate) fn cast_string_to_decimal<T, Offset: OffsetSizeTrait>(
541 from: &dyn Array,
542 precision: u8,
543 scale: i8,
544 cast_options: &CastOptions,
545) -> Result<ArrayRef, ArrowError>
546where
547 T: DecimalType,
548 T::Native: DecimalCast + ArrowNativeTypeOp,
549{
550 if scale < 0 {
551 return Err(ArrowError::InvalidArgumentError(format!(
552 "Cannot cast string to decimal with negative scale {scale}"
553 )));
554 }
555
556 if scale > T::MAX_SCALE {
557 return Err(ArrowError::InvalidArgumentError(format!(
558 "Cannot cast string to decimal greater than maximum scale {}",
559 T::MAX_SCALE
560 )));
561 }
562
563 let result = match from.data_type() {
564 DataType::Utf8View => string_view_to_decimal_cast::<T>(
565 from.as_any().downcast_ref::<StringViewArray>().unwrap(),
566 precision,
567 scale,
568 cast_options,
569 )?,
570 DataType::Utf8 | DataType::LargeUtf8 => string_to_decimal_cast::<T, Offset>(
571 from.as_any()
572 .downcast_ref::<GenericStringArray<Offset>>()
573 .unwrap(),
574 precision,
575 scale,
576 cast_options,
577 )?,
578 other => {
579 return Err(ArrowError::ComputeError(format!(
580 "Cannot cast {other:?} to decimal",
581 )))
582 }
583 };
584
585 Ok(Arc::new(result))
586}
587
588pub(crate) fn cast_floating_point_to_decimal<T: ArrowPrimitiveType, D>(
589 array: &PrimitiveArray<T>,
590 precision: u8,
591 scale: i8,
592 cast_options: &CastOptions,
593) -> Result<ArrayRef, ArrowError>
594where
595 <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
596 D: DecimalType + ArrowPrimitiveType,
597 <D as ArrowPrimitiveType>::Native: DecimalCast,
598{
599 let mul = 10_f64.powi(scale as i32);
600
601 if cast_options.safe {
602 array
603 .unary_opt::<_, D>(|v| {
604 D::Native::from_f64((mul * v.as_()).round())
605 .filter(|v| D::is_valid_decimal_precision(*v, precision))
606 })
607 .with_precision_and_scale(precision, scale)
608 .map(|a| Arc::new(a) as ArrayRef)
609 } else {
610 array
611 .try_unary::<_, D, _>(|v| {
612 D::Native::from_f64((mul * v.as_()).round())
613 .ok_or_else(|| {
614 ArrowError::CastError(format!(
615 "Cannot cast to {}({}, {}). Overflowing on {:?}",
616 D::PREFIX,
617 precision,
618 scale,
619 v
620 ))
621 })
622 .and_then(|v| D::validate_decimal_precision(v, precision).map(|_| v))
623 })?
624 .with_precision_and_scale(precision, scale)
625 .map(|a| Arc::new(a) as ArrayRef)
626 }
627}
628
629pub(crate) fn cast_decimal_to_integer<D, T>(
630 array: &dyn Array,
631 base: D::Native,
632 scale: i8,
633 cast_options: &CastOptions,
634) -> Result<ArrayRef, ArrowError>
635where
636 T: ArrowPrimitiveType,
637 <T as ArrowPrimitiveType>::Native: NumCast,
638 D: DecimalType + ArrowPrimitiveType,
639 <D as ArrowPrimitiveType>::Native: ArrowNativeTypeOp + ToPrimitive,
640{
641 let array = array.as_primitive::<D>();
642
643 let div: D::Native = base.pow_checked(scale as u32).map_err(|_| {
644 ArrowError::CastError(format!(
645 "Cannot cast to {:?}. The scale {} causes overflow.",
646 D::PREFIX,
647 scale,
648 ))
649 })?;
650
651 let mut value_builder = PrimitiveBuilder::<T>::with_capacity(array.len());
652
653 if cast_options.safe {
654 for i in 0..array.len() {
655 if array.is_null(i) {
656 value_builder.append_null();
657 } else {
658 let v = array
659 .value(i)
660 .div_checked(div)
661 .ok()
662 .and_then(<T::Native as NumCast>::from::<D::Native>);
663
664 value_builder.append_option(v);
665 }
666 }
667 } else {
668 for i in 0..array.len() {
669 if array.is_null(i) {
670 value_builder.append_null();
671 } else {
672 let v = array.value(i).div_checked(div)?;
673
674 let value = <T::Native as NumCast>::from::<D::Native>(v).ok_or_else(|| {
675 ArrowError::CastError(format!(
676 "value of {:?} is out of range {}",
677 v,
678 T::DATA_TYPE
679 ))
680 })?;
681
682 value_builder.append_value(value);
683 }
684 }
685 }
686 Ok(Arc::new(value_builder.finish()))
687}
688
689pub(crate) fn cast_decimal_to_float<D: DecimalType, T: ArrowPrimitiveType, F>(
695 array: &dyn Array,
696 op: F,
697) -> Result<ArrayRef, ArrowError>
698where
699 F: Fn(D::Native) -> T::Native,
700{
701 let array = array.as_primitive::<D>();
702 let array = array.unary::<_, T>(op);
703 Ok(Arc::new(array))
704}
705
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_string_to_decimal_native() -> Result<(), ArrowError> {
        // (input string, scale, expected Decimal128 native value)
        let cases = [
            ("0", 0, 0_i128),
            ("0", 5, 0),
            ("123", 0, 123),
            ("123", 5, 12300000),
            ("123.45", 0, 123),
            ("123.45", 5, 12345000),
            ("123.4567891", 0, 123),
            ("123.4567891", 5, 12345679),
        ];

        for (input, scale, expected) in cases {
            assert_eq!(
                parse_string_to_decimal_native::<Decimal128Type>(input, scale)?,
                expected,
                "parsing {input:?} with scale {scale}"
            );
        }
        Ok(())
    }
}