arrow_cast/cast/
decimal.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::cast::*;
19
20/// A utility trait that provides checked conversions between
21/// decimal types inspired by [`NumCast`]
22pub(crate) trait DecimalCast: Sized {
23    fn to_i128(self) -> Option<i128>;
24
25    fn to_i256(self) -> Option<i256>;
26
27    fn from_decimal<T: DecimalCast>(n: T) -> Option<Self>;
28
29    fn from_f64(n: f64) -> Option<Self>;
30}
31
32impl DecimalCast for i128 {
33    fn to_i128(self) -> Option<i128> {
34        Some(self)
35    }
36
37    fn to_i256(self) -> Option<i256> {
38        Some(i256::from_i128(self))
39    }
40
41    fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
42        n.to_i128()
43    }
44
45    fn from_f64(n: f64) -> Option<Self> {
46        n.to_i128()
47    }
48}
49
50impl DecimalCast for i256 {
51    fn to_i128(self) -> Option<i128> {
52        self.to_i128()
53    }
54
55    fn to_i256(self) -> Option<i256> {
56        Some(self)
57    }
58
59    fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
60        n.to_i256()
61    }
62
63    fn from_f64(n: f64) -> Option<Self> {
64        i256::from_f64(n)
65    }
66}
67
68pub(crate) fn cast_decimal_to_decimal_error<I, O>(
69    output_precision: u8,
70    output_scale: i8,
71) -> impl Fn(<I as ArrowPrimitiveType>::Native) -> ArrowError
72where
73    I: DecimalType,
74    O: DecimalType,
75    I::Native: DecimalCast + ArrowNativeTypeOp,
76    O::Native: DecimalCast + ArrowNativeTypeOp,
77{
78    move |x: I::Native| {
79        ArrowError::CastError(format!(
80            "Cannot cast to {}({}, {}). Overflowing on {:?}",
81            O::PREFIX,
82            output_precision,
83            output_scale,
84            x
85        ))
86    }
87}
88
89pub(crate) fn convert_to_smaller_scale_decimal<I, O>(
90    array: &PrimitiveArray<I>,
91    input_precision: u8,
92    input_scale: i8,
93    output_precision: u8,
94    output_scale: i8,
95    cast_options: &CastOptions,
96) -> Result<PrimitiveArray<O>, ArrowError>
97where
98    I: DecimalType,
99    O: DecimalType,
100    I::Native: DecimalCast + ArrowNativeTypeOp,
101    O::Native: DecimalCast + ArrowNativeTypeOp,
102{
103    let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
104    let delta_scale = input_scale - output_scale;
105    // if the reduction of the input number through scaling (dividing) is greater
106    // than a possible precision loss (plus potential increase via rounding)
107    // every input number will fit into the output type
108    // Example: If we are starting with any number of precision 5 [xxxxx],
109    // then and decrease the scale by 3 will have the following effect on the representation:
110    // [xxxxx] -> [xx] (+ 1 possibly, due to rounding).
111    // The rounding may add an additional digit, so the cast to be infallible,
112    // the output type needs to have at least 3 digits of precision.
113    // e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100:
114    // [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible
115    let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8);
116
117    let div = I::Native::from_decimal(10_i128)
118        .unwrap()
119        .pow_checked(delta_scale as u32)?;
120
121    let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
122    let half_neg = half.neg_wrapping();
123
124    let f = |x: I::Native| {
125        // div is >= 10 and so this cannot overflow
126        let d = x.div_wrapping(div);
127        let r = x.mod_wrapping(div);
128
129        // Round result
130        let adjusted = match x >= I::Native::ZERO {
131            true if r >= half => d.add_wrapping(I::Native::ONE),
132            false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
133            _ => d,
134        };
135        O::Native::from_decimal(adjusted)
136    };
137
138    Ok(if is_infallible_cast {
139        // make sure we don't perform calculations that don't make sense w/o validation
140        validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;
141        let g = |x: I::Native| f(x).unwrap(); // unwrapping is safe since the result is guaranteed
142                                              // to fit into the target type
143        array.unary(g)
144    } else if cast_options.safe {
145        array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
146    } else {
147        array.try_unary(|x| {
148            f(x).ok_or_else(|| error(x))
149                .and_then(|v| O::validate_decimal_precision(v, output_precision).map(|_| v))
150        })?
151    })
152}
153
154pub(crate) fn convert_to_bigger_or_equal_scale_decimal<I, O>(
155    array: &PrimitiveArray<I>,
156    input_precision: u8,
157    input_scale: i8,
158    output_precision: u8,
159    output_scale: i8,
160    cast_options: &CastOptions,
161) -> Result<PrimitiveArray<O>, ArrowError>
162where
163    I: DecimalType,
164    O: DecimalType,
165    I::Native: DecimalCast + ArrowNativeTypeOp,
166    O::Native: DecimalCast + ArrowNativeTypeOp,
167{
168    let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
169    let delta_scale = output_scale - input_scale;
170    let mul = O::Native::from_decimal(10_i128)
171        .unwrap()
172        .pow_checked(delta_scale as u32)?;
173
174    // if the gain in precision (digits) is greater than the multiplication due to scaling
175    // every number will fit into the output type
176    // Example: If we are starting with any number of precision 5 [xxxxx],
177    // then an increase of scale by 3 will have the following effect on the representation:
178    // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type
179    // needs to provide at least 8 digits precision
180    let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8);
181    let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok());
182
183    Ok(if is_infallible_cast {
184        // make sure we don't perform calculations that don't make sense w/o validation
185        validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;
186        // unwrapping is safe since the result is guaranteed to fit into the target type
187        let f = |x| O::Native::from_decimal(x).unwrap().mul_wrapping(mul);
188        array.unary(f)
189    } else if cast_options.safe {
190        array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
191    } else {
192        array.try_unary(|x| {
193            f(x).ok_or_else(|| error(x))
194                .and_then(|v| O::validate_decimal_precision(v, output_precision).map(|_| v))
195        })?
196    })
197}
198
199// Only support one type of decimal cast operations
200pub(crate) fn cast_decimal_to_decimal_same_type<T>(
201    array: &PrimitiveArray<T>,
202    input_precision: u8,
203    input_scale: i8,
204    output_precision: u8,
205    output_scale: i8,
206    cast_options: &CastOptions,
207) -> Result<ArrayRef, ArrowError>
208where
209    T: DecimalType,
210    T::Native: DecimalCast + ArrowNativeTypeOp,
211{
212    let array: PrimitiveArray<T> =
213        if input_scale == output_scale && input_precision <= output_precision {
214            array.clone()
215        } else if input_scale <= output_scale {
216            convert_to_bigger_or_equal_scale_decimal::<T, T>(
217                array,
218                input_precision,
219                input_scale,
220                output_precision,
221                output_scale,
222                cast_options,
223            )?
224        } else {
225            // input_scale > output_scale
226            convert_to_smaller_scale_decimal::<T, T>(
227                array,
228                input_precision,
229                input_scale,
230                output_precision,
231                output_scale,
232                cast_options,
233            )?
234        };
235
236    Ok(Arc::new(array.with_precision_and_scale(
237        output_precision,
238        output_scale,
239    )?))
240}
241
242// Support two different types of decimal cast operations
243pub(crate) fn cast_decimal_to_decimal<I, O>(
244    array: &PrimitiveArray<I>,
245    input_precision: u8,
246    input_scale: i8,
247    output_precision: u8,
248    output_scale: i8,
249    cast_options: &CastOptions,
250) -> Result<ArrayRef, ArrowError>
251where
252    I: DecimalType,
253    O: DecimalType,
254    I::Native: DecimalCast + ArrowNativeTypeOp,
255    O::Native: DecimalCast + ArrowNativeTypeOp,
256{
257    let array: PrimitiveArray<O> = if input_scale > output_scale {
258        convert_to_smaller_scale_decimal::<I, O>(
259            array,
260            input_precision,
261            input_scale,
262            output_precision,
263            output_scale,
264            cast_options,
265        )?
266    } else {
267        convert_to_bigger_or_equal_scale_decimal::<I, O>(
268            array,
269            input_precision,
270            input_scale,
271            output_precision,
272            output_scale,
273            cast_options,
274        )?
275    };
276
277    Ok(Arc::new(array.with_precision_and_scale(
278        output_precision,
279        output_scale,
280    )?))
281}
282
283/// Parses given string to specified decimal native (i128/i256) based on given
284/// scale. Returns an `Err` if it cannot parse given string.
285pub(crate) fn parse_string_to_decimal_native<T: DecimalType>(
286    value_str: &str,
287    scale: usize,
288) -> Result<T::Native, ArrowError>
289where
290    T::Native: DecimalCast + ArrowNativeTypeOp,
291{
292    let value_str = value_str.trim();
293    let parts: Vec<&str> = value_str.split('.').collect();
294    if parts.len() > 2 {
295        return Err(ArrowError::InvalidArgumentError(format!(
296            "Invalid decimal format: {value_str:?}"
297        )));
298    }
299
300    let (negative, first_part) = if parts[0].is_empty() {
301        (false, parts[0])
302    } else {
303        match parts[0].as_bytes()[0] {
304            b'-' => (true, &parts[0][1..]),
305            b'+' => (false, &parts[0][1..]),
306            _ => (false, parts[0]),
307        }
308    };
309
310    let integers = first_part;
311    let decimals = if parts.len() == 2 { parts[1] } else { "" };
312
313    if !integers.is_empty() && !integers.as_bytes()[0].is_ascii_digit() {
314        return Err(ArrowError::InvalidArgumentError(format!(
315            "Invalid decimal format: {value_str:?}"
316        )));
317    }
318
319    if !decimals.is_empty() && !decimals.as_bytes()[0].is_ascii_digit() {
320        return Err(ArrowError::InvalidArgumentError(format!(
321            "Invalid decimal format: {value_str:?}"
322        )));
323    }
324
325    // Adjust decimal based on scale
326    let mut number_decimals = if decimals.len() > scale {
327        let decimal_number = i256::from_string(decimals).ok_or_else(|| {
328            ArrowError::InvalidArgumentError(format!("Cannot parse decimal format: {value_str}"))
329        })?;
330
331        let div = i256::from_i128(10_i128).pow_checked((decimals.len() - scale) as u32)?;
332
333        let half = div.div_wrapping(i256::from_i128(2));
334        let half_neg = half.neg_wrapping();
335
336        let d = decimal_number.div_wrapping(div);
337        let r = decimal_number.mod_wrapping(div);
338
339        // Round result
340        let adjusted = match decimal_number >= i256::ZERO {
341            true if r >= half => d.add_wrapping(i256::ONE),
342            false if r <= half_neg => d.sub_wrapping(i256::ONE),
343            _ => d,
344        };
345
346        let integers = if !integers.is_empty() {
347            i256::from_string(integers)
348                .ok_or_else(|| {
349                    ArrowError::InvalidArgumentError(format!(
350                        "Cannot parse decimal format: {value_str}"
351                    ))
352                })
353                .map(|v| v.mul_wrapping(i256::from_i128(10_i128).pow_wrapping(scale as u32)))?
354        } else {
355            i256::ZERO
356        };
357
358        format!("{}", integers.add_wrapping(adjusted))
359    } else {
360        let padding = if scale > decimals.len() { scale } else { 0 };
361
362        let decimals = format!("{decimals:0<padding$}");
363        format!("{integers}{decimals}")
364    };
365
366    if negative {
367        number_decimals.insert(0, '-');
368    }
369
370    let value = i256::from_string(number_decimals.as_str()).ok_or_else(|| {
371        ArrowError::InvalidArgumentError(format!(
372            "Cannot convert {} to {}: Overflow",
373            value_str,
374            T::PREFIX
375        ))
376    })?;
377
378    T::Native::from_decimal(value).ok_or_else(|| {
379        ArrowError::InvalidArgumentError(format!("Cannot convert {} to {}", value_str, T::PREFIX))
380    })
381}
382
383pub(crate) fn generic_string_to_decimal_cast<'a, T, S>(
384    from: &'a S,
385    precision: u8,
386    scale: i8,
387    cast_options: &CastOptions,
388) -> Result<PrimitiveArray<T>, ArrowError>
389where
390    T: DecimalType,
391    T::Native: DecimalCast + ArrowNativeTypeOp,
392    &'a S: StringArrayType<'a>,
393{
394    if cast_options.safe {
395        let iter = from.iter().map(|v| {
396            v.and_then(|v| parse_string_to_decimal_native::<T>(v, scale as usize).ok())
397                .and_then(|v| T::is_valid_decimal_precision(v, precision).then_some(v))
398        });
399        // Benefit:
400        //     20% performance improvement
401        // Soundness:
402        //     The iterator is trustedLen because it comes from an `StringArray`.
403        Ok(unsafe {
404            PrimitiveArray::<T>::from_trusted_len_iter(iter)
405                .with_precision_and_scale(precision, scale)?
406        })
407    } else {
408        let vec = from
409            .iter()
410            .map(|v| {
411                v.map(|v| {
412                    parse_string_to_decimal_native::<T>(v, scale as usize)
413                        .map_err(|_| {
414                            ArrowError::CastError(format!(
415                                "Cannot cast string '{}' to value of {:?} type",
416                                v,
417                                T::DATA_TYPE,
418                            ))
419                        })
420                        .and_then(|v| T::validate_decimal_precision(v, precision).map(|_| v))
421                })
422                .transpose()
423            })
424            .collect::<Result<Vec<_>, _>>()?;
425        // Benefit:
426        //     20% performance improvement
427        // Soundness:
428        //     The iterator is trustedLen because it comes from an `StringArray`.
429        Ok(unsafe {
430            PrimitiveArray::<T>::from_trusted_len_iter(vec.iter())
431                .with_precision_and_scale(precision, scale)?
432        })
433    }
434}
435
436pub(crate) fn string_to_decimal_cast<T, Offset: OffsetSizeTrait>(
437    from: &GenericStringArray<Offset>,
438    precision: u8,
439    scale: i8,
440    cast_options: &CastOptions,
441) -> Result<PrimitiveArray<T>, ArrowError>
442where
443    T: DecimalType,
444    T::Native: DecimalCast + ArrowNativeTypeOp,
445{
446    generic_string_to_decimal_cast::<T, GenericStringArray<Offset>>(
447        from,
448        precision,
449        scale,
450        cast_options,
451    )
452}
453
454pub(crate) fn string_view_to_decimal_cast<T>(
455    from: &StringViewArray,
456    precision: u8,
457    scale: i8,
458    cast_options: &CastOptions,
459) -> Result<PrimitiveArray<T>, ArrowError>
460where
461    T: DecimalType,
462    T::Native: DecimalCast + ArrowNativeTypeOp,
463{
464    generic_string_to_decimal_cast::<T, StringViewArray>(from, precision, scale, cast_options)
465}
466
467/// Cast Utf8 to decimal
468pub(crate) fn cast_string_to_decimal<T, Offset: OffsetSizeTrait>(
469    from: &dyn Array,
470    precision: u8,
471    scale: i8,
472    cast_options: &CastOptions,
473) -> Result<ArrayRef, ArrowError>
474where
475    T: DecimalType,
476    T::Native: DecimalCast + ArrowNativeTypeOp,
477{
478    if scale < 0 {
479        return Err(ArrowError::InvalidArgumentError(format!(
480            "Cannot cast string to decimal with negative scale {scale}"
481        )));
482    }
483
484    if scale > T::MAX_SCALE {
485        return Err(ArrowError::InvalidArgumentError(format!(
486            "Cannot cast string to decimal greater than maximum scale {}",
487            T::MAX_SCALE
488        )));
489    }
490
491    let result = match from.data_type() {
492        DataType::Utf8View => string_view_to_decimal_cast::<T>(
493            from.as_any().downcast_ref::<StringViewArray>().unwrap(),
494            precision,
495            scale,
496            cast_options,
497        )?,
498        DataType::Utf8 | DataType::LargeUtf8 => string_to_decimal_cast::<T, Offset>(
499            from.as_any()
500                .downcast_ref::<GenericStringArray<Offset>>()
501                .unwrap(),
502            precision,
503            scale,
504            cast_options,
505        )?,
506        other => {
507            return Err(ArrowError::ComputeError(format!(
508                "Cannot cast {other:?} to decimal",
509            )))
510        }
511    };
512
513    Ok(Arc::new(result))
514}
515
516pub(crate) fn cast_floating_point_to_decimal<T: ArrowPrimitiveType, D>(
517    array: &PrimitiveArray<T>,
518    precision: u8,
519    scale: i8,
520    cast_options: &CastOptions,
521) -> Result<ArrayRef, ArrowError>
522where
523    <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
524    D: DecimalType + ArrowPrimitiveType,
525    <D as ArrowPrimitiveType>::Native: DecimalCast,
526{
527    let mul = 10_f64.powi(scale as i32);
528
529    if cast_options.safe {
530        array
531            .unary_opt::<_, D>(|v| {
532                D::Native::from_f64((mul * v.as_()).round())
533                    .filter(|v| D::is_valid_decimal_precision(*v, precision))
534            })
535            .with_precision_and_scale(precision, scale)
536            .map(|a| Arc::new(a) as ArrayRef)
537    } else {
538        array
539            .try_unary::<_, D, _>(|v| {
540                D::Native::from_f64((mul * v.as_()).round())
541                    .ok_or_else(|| {
542                        ArrowError::CastError(format!(
543                            "Cannot cast to {}({}, {}). Overflowing on {:?}",
544                            D::PREFIX,
545                            precision,
546                            scale,
547                            v
548                        ))
549                    })
550                    .and_then(|v| D::validate_decimal_precision(v, precision).map(|_| v))
551            })?
552            .with_precision_and_scale(precision, scale)
553            .map(|a| Arc::new(a) as ArrayRef)
554    }
555}
556
557pub(crate) fn cast_decimal_to_integer<D, T>(
558    array: &dyn Array,
559    base: D::Native,
560    scale: i8,
561    cast_options: &CastOptions,
562) -> Result<ArrayRef, ArrowError>
563where
564    T: ArrowPrimitiveType,
565    <T as ArrowPrimitiveType>::Native: NumCast,
566    D: DecimalType + ArrowPrimitiveType,
567    <D as ArrowPrimitiveType>::Native: ArrowNativeTypeOp + ToPrimitive,
568{
569    let array = array.as_primitive::<D>();
570
571    let div: D::Native = base.pow_checked(scale as u32).map_err(|_| {
572        ArrowError::CastError(format!(
573            "Cannot cast to {:?}. The scale {} causes overflow.",
574            D::PREFIX,
575            scale,
576        ))
577    })?;
578
579    let mut value_builder = PrimitiveBuilder::<T>::with_capacity(array.len());
580
581    if cast_options.safe {
582        for i in 0..array.len() {
583            if array.is_null(i) {
584                value_builder.append_null();
585            } else {
586                let v = array
587                    .value(i)
588                    .div_checked(div)
589                    .ok()
590                    .and_then(<T::Native as NumCast>::from::<D::Native>);
591
592                value_builder.append_option(v);
593            }
594        }
595    } else {
596        for i in 0..array.len() {
597            if array.is_null(i) {
598                value_builder.append_null();
599            } else {
600                let v = array.value(i).div_checked(div)?;
601
602                let value = <T::Native as NumCast>::from::<D::Native>(v).ok_or_else(|| {
603                    ArrowError::CastError(format!(
604                        "value of {:?} is out of range {}",
605                        v,
606                        T::DATA_TYPE
607                    ))
608                })?;
609
610                value_builder.append_value(value);
611            }
612        }
613    }
614    Ok(Arc::new(value_builder.finish()))
615}
616
617// Cast the decimal array to floating-point array
618pub(crate) fn cast_decimal_to_float<D: DecimalType, T: ArrowPrimitiveType, F>(
619    array: &dyn Array,
620    op: F,
621) -> Result<ArrayRef, ArrowError>
622where
623    F: Fn(D::Native) -> T::Native,
624{
625    let array = array.as_primitive::<D>();
626    let array = array.unary::<_, T>(op);
627    Ok(Arc::new(array))
628}
629
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_string_to_decimal_native() -> Result<(), ArrowError> {
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("0", 0)?,
            0_i128
        );
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("0", 5)?,
            0_i128
        );

        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("123", 0)?,
            123_i128
        );
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("123", 5)?,
            12300000_i128
        );

        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("123.45", 0)?,
            123_i128
        );
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("123.45", 5)?,
            12345000_i128
        );

        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("123.4567891", 0)?,
            123_i128
        );
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("123.4567891", 5)?,
            12345679_i128
        );
        Ok(())
    }

    #[test]
    fn test_parse_string_to_decimal_native_sign_and_rounding() -> Result<(), ArrowError> {
        // Explicit signs.
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("-123.45", 2)?,
            -12345_i128
        );
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("+123.45", 2)?,
            12345_i128
        );
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("-123.4567891", 5)?,
            -12345679_i128
        );

        // Rounding is half away from zero.
        assert_eq!(parse_string_to_decimal_native::<Decimal128Type>("1.5", 0)?, 2_i128);
        assert_eq!(parse_string_to_decimal_native::<Decimal128Type>("-1.5", 0)?, -2_i128);
        assert_eq!(parse_string_to_decimal_native::<Decimal128Type>("2.4", 0)?, 2_i128);
        assert_eq!(parse_string_to_decimal_native::<Decimal128Type>("-2.4", 0)?, -2_i128);

        // Rounding can carry into the integer part.
        assert_eq!(parse_string_to_decimal_native::<Decimal128Type>("0.995", 2)?, 100_i128);

        // Missing integer part and surrounding whitespace.
        assert_eq!(parse_string_to_decimal_native::<Decimal128Type>(".5", 0)?, 1_i128);
        assert_eq!(parse_string_to_decimal_native::<Decimal128Type>(".25", 2)?, 25_i128);
        assert_eq!(
            parse_string_to_decimal_native::<Decimal128Type>("  12.3  ", 1)?,
            123_i128
        );

        // Fractional digits are zero-padded up to the scale.
        assert_eq!(parse_string_to_decimal_native::<Decimal128Type>("1.2", 3)?, 1200_i128);

        // Decimal256 target.
        assert_eq!(
            parse_string_to_decimal_native::<Decimal256Type>("123.45", 2)?,
            i256::from_i128(12345)
        );
        Ok(())
    }

    #[test]
    fn test_parse_string_to_decimal_native_invalid() {
        for input in ["1.2.3", "abc", "1.x", "--1", "1e5"] {
            assert!(
                parse_string_to_decimal_native::<Decimal128Type>(input, 2).is_err(),
                "expected {input:?} to be rejected"
            );
        }

        // 10^39 fits into an i256 but not into the i128 native type of Decimal128.
        let too_big = format!("1{}", "0".repeat(39));
        assert!(parse_string_to_decimal_native::<Decimal128Type>(&too_big, 0).is_err());
    }
}