arrow_cast/cast/
decimal.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::cast::*;
19
20/// A utility trait that provides checked conversions between
21/// decimal types inspired by [`NumCast`]
22pub(crate) trait DecimalCast: Sized {
23    fn to_i128(self) -> Option<i128>;
24
25    fn to_i256(self) -> Option<i256>;
26
27    fn from_decimal<T: DecimalCast>(n: T) -> Option<Self>;
28
29    fn from_f64(n: f64) -> Option<Self>;
30}
31
32impl DecimalCast for i128 {
33    fn to_i128(self) -> Option<i128> {
34        Some(self)
35    }
36
37    fn to_i256(self) -> Option<i256> {
38        Some(i256::from_i128(self))
39    }
40
41    fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
42        n.to_i128()
43    }
44
45    fn from_f64(n: f64) -> Option<Self> {
46        n.to_i128()
47    }
48}
49
50impl DecimalCast for i256 {
51    fn to_i128(self) -> Option<i128> {
52        self.to_i128()
53    }
54
55    fn to_i256(self) -> Option<i256> {
56        Some(self)
57    }
58
59    fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
60        n.to_i256()
61    }
62
63    fn from_f64(n: f64) -> Option<Self> {
64        i256::from_f64(n)
65    }
66}
67
68pub(crate) fn cast_decimal_to_decimal_error<I, O>(
69    output_precision: u8,
70    output_scale: i8,
71) -> impl Fn(<I as ArrowPrimitiveType>::Native) -> ArrowError
72where
73    I: DecimalType,
74    O: DecimalType,
75    I::Native: DecimalCast + ArrowNativeTypeOp,
76    O::Native: DecimalCast + ArrowNativeTypeOp,
77{
78    move |x: I::Native| {
79        ArrowError::CastError(format!(
80            "Cannot cast to {}({}, {}). Overflowing on {:?}",
81            O::PREFIX,
82            output_precision,
83            output_scale,
84            x
85        ))
86    }
87}
88
89pub(crate) fn convert_to_smaller_scale_decimal<I, O>(
90    array: &PrimitiveArray<I>,
91    input_precision: u8,
92    input_scale: i8,
93    output_precision: u8,
94    output_scale: i8,
95    cast_options: &CastOptions,
96) -> Result<PrimitiveArray<O>, ArrowError>
97where
98    I: DecimalType,
99    O: DecimalType,
100    I::Native: DecimalCast + ArrowNativeTypeOp,
101    O::Native: DecimalCast + ArrowNativeTypeOp,
102{
103    let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
104    let delta_scale = input_scale - output_scale;
105    // if the reduction of the input number through scaling (dividing) is greater
106    // than a possible precision loss (plus potential increase via rounding)
107    // every input number will fit into the output type
108    // Example: If we are starting with any number of precision 5 [xxxxx],
109    // then and decrease the scale by 3 will have the following effect on the representation:
110    // [xxxxx] -> [xx] (+ 1 possibly, due to rounding).
111    // The rounding may add an additional digit, so the cast to be infallible,
112    // the output type needs to have at least 3 digits of precision.
113    // e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100:
114    // [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible
115    let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8);
116
117    let div = I::Native::from_decimal(10_i128)
118        .unwrap()
119        .pow_checked(delta_scale as u32)?;
120
121    let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
122    let half_neg = half.neg_wrapping();
123
124    let f = |x: I::Native| {
125        // div is >= 10 and so this cannot overflow
126        let d = x.div_wrapping(div);
127        let r = x.mod_wrapping(div);
128
129        // Round result
130        let adjusted = match x >= I::Native::ZERO {
131            true if r >= half => d.add_wrapping(I::Native::ONE),
132            false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
133            _ => d,
134        };
135        O::Native::from_decimal(adjusted)
136    };
137
138    Ok(if is_infallible_cast {
139        // make sure we don't perform calculations that don't make sense w/o validation
140        validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;
141        let g = |x: I::Native| f(x).unwrap(); // unwrapping is safe since the result is guaranteed
142                                              // to fit into the target type
143        array.unary(g)
144    } else if cast_options.safe {
145        array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
146    } else {
147        array.try_unary(|x| {
148            f(x).ok_or_else(|| error(x))
149                .and_then(|v| O::validate_decimal_precision(v, output_precision).map(|_| v))
150        })?
151    })
152}
153
154pub(crate) fn convert_to_bigger_or_equal_scale_decimal<I, O>(
155    array: &PrimitiveArray<I>,
156    input_precision: u8,
157    input_scale: i8,
158    output_precision: u8,
159    output_scale: i8,
160    cast_options: &CastOptions,
161) -> Result<PrimitiveArray<O>, ArrowError>
162where
163    I: DecimalType,
164    O: DecimalType,
165    I::Native: DecimalCast + ArrowNativeTypeOp,
166    O::Native: DecimalCast + ArrowNativeTypeOp,
167{
168    let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
169    let delta_scale = output_scale - input_scale;
170    let mul = O::Native::from_decimal(10_i128)
171        .unwrap()
172        .pow_checked(delta_scale as u32)?;
173
174    // if the gain in precision (digits) is greater than the multiplication due to scaling
175    // every number will fit into the output type
176    // Example: If we are starting with any number of precision 5 [xxxxx],
177    // then an increase of scale by 3 will have the following effect on the representation:
178    // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type
179    // needs to provide at least 8 digits precision
180    let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8);
181    let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok());
182
183    Ok(if is_infallible_cast {
184        // make sure we don't perform calculations that don't make sense w/o validation
185        validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;
186        // unwrapping is safe since the result is guaranteed to fit into the target type
187        let f = |x| O::Native::from_decimal(x).unwrap().mul_wrapping(mul);
188        array.unary(f)
189    } else if cast_options.safe {
190        array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
191    } else {
192        array.try_unary(|x| {
193            f(x).ok_or_else(|| error(x))
194                .and_then(|v| O::validate_decimal_precision(v, output_precision).map(|_| v))
195        })?
196    })
197}
198
199// Only support one type of decimal cast operations
200pub(crate) fn cast_decimal_to_decimal_same_type<T>(
201    array: &PrimitiveArray<T>,
202    input_precision: u8,
203    input_scale: i8,
204    output_precision: u8,
205    output_scale: i8,
206    cast_options: &CastOptions,
207) -> Result<ArrayRef, ArrowError>
208where
209    T: DecimalType,
210    T::Native: DecimalCast + ArrowNativeTypeOp,
211{
212    let array: PrimitiveArray<T> =
213        if input_scale == output_scale && input_precision <= output_precision {
214            array.clone()
215        } else if input_scale <= output_scale {
216            convert_to_bigger_or_equal_scale_decimal::<T, T>(
217                array,
218                input_precision,
219                input_scale,
220                output_precision,
221                output_scale,
222                cast_options,
223            )?
224        } else {
225            // input_scale > output_scale
226            convert_to_smaller_scale_decimal::<T, T>(
227                array,
228                input_precision,
229                input_scale,
230                output_precision,
231                output_scale,
232                cast_options,
233            )?
234        };
235
236    Ok(Arc::new(array.with_precision_and_scale(
237        output_precision,
238        output_scale,
239    )?))
240}
241
242// Support two different types of decimal cast operations
243pub(crate) fn cast_decimal_to_decimal<I, O>(
244    array: &PrimitiveArray<I>,
245    input_precision: u8,
246    input_scale: i8,
247    output_precision: u8,
248    output_scale: i8,
249    cast_options: &CastOptions,
250) -> Result<ArrayRef, ArrowError>
251where
252    I: DecimalType,
253    O: DecimalType,
254    I::Native: DecimalCast + ArrowNativeTypeOp,
255    O::Native: DecimalCast + ArrowNativeTypeOp,
256{
257    let array: PrimitiveArray<O> = if input_scale > output_scale {
258        convert_to_smaller_scale_decimal::<I, O>(
259            array,
260            input_precision,
261            input_scale,
262            output_precision,
263            output_scale,
264            cast_options,
265        )?
266    } else {
267        convert_to_bigger_or_equal_scale_decimal::<I, O>(
268            array,
269            input_precision,
270            input_scale,
271            output_precision,
272            output_scale,
273            cast_options,
274        )?
275    };
276
277    Ok(Arc::new(array.with_precision_and_scale(
278        output_precision,
279        output_scale,
280    )?))
281}
282
283/// Parses given string to specified decimal native (i128/i256) based on given
284/// scale. Returns an `Err` if it cannot parse given string.
285pub(crate) fn parse_string_to_decimal_native<T: DecimalType>(
286    value_str: &str,
287    scale: usize,
288) -> Result<T::Native, ArrowError>
289where
290    T::Native: DecimalCast + ArrowNativeTypeOp,
291{
292    let value_str = value_str.trim();
293    let parts: Vec<&str> = value_str.split('.').collect();
294    if parts.len() > 2 {
295        return Err(ArrowError::InvalidArgumentError(format!(
296            "Invalid decimal format: {value_str:?}"
297        )));
298    }
299
300    let (negative, first_part) = if parts[0].is_empty() {
301        (false, parts[0])
302    } else {
303        match parts[0].as_bytes()[0] {
304            b'-' => (true, &parts[0][1..]),
305            b'+' => (false, &parts[0][1..]),
306            _ => (false, parts[0]),
307        }
308    };
309
310    let integers = first_part;
311    let decimals = if parts.len() == 2 { parts[1] } else { "" };
312
313    if !integers.is_empty() && !integers.as_bytes()[0].is_ascii_digit() {
314        return Err(ArrowError::InvalidArgumentError(format!(
315            "Invalid decimal format: {value_str:?}"
316        )));
317    }
318
319    if !decimals.is_empty() && !decimals.as_bytes()[0].is_ascii_digit() {
320        return Err(ArrowError::InvalidArgumentError(format!(
321            "Invalid decimal format: {value_str:?}"
322        )));
323    }
324
325    // Adjust decimal based on scale
326    let mut number_decimals = if decimals.len() > scale {
327        let decimal_number = i256::from_string(decimals).ok_or_else(|| {
328            ArrowError::InvalidArgumentError(format!("Cannot parse decimal format: {value_str}"))
329        })?;
330
331        let div = i256::from_i128(10_i128).pow_checked((decimals.len() - scale) as u32)?;
332
333        let half = div.div_wrapping(i256::from_i128(2));
334        let half_neg = half.neg_wrapping();
335
336        let d = decimal_number.div_wrapping(div);
337        let r = decimal_number.mod_wrapping(div);
338
339        // Round result
340        let adjusted = match decimal_number >= i256::ZERO {
341            true if r >= half => d.add_wrapping(i256::ONE),
342            false if r <= half_neg => d.sub_wrapping(i256::ONE),
343            _ => d,
344        };
345
346        let integers = if !integers.is_empty() {
347            i256::from_string(integers)
348                .ok_or_else(|| {
349                    ArrowError::InvalidArgumentError(format!(
350                        "Cannot parse decimal format: {value_str}"
351                    ))
352                })
353                .map(|v| v.mul_wrapping(i256::from_i128(10_i128).pow_wrapping(scale as u32)))?
354        } else {
355            i256::ZERO
356        };
357
358        format!("{}", integers.add_wrapping(adjusted))
359    } else {
360        let padding = if scale > decimals.len() { scale } else { 0 };
361
362        let decimals = format!("{decimals:0<padding$}");
363        format!("{integers}{decimals}")
364    };
365
366    if negative {
367        number_decimals.insert(0, '-');
368    }
369
370    let value = i256::from_string(number_decimals.as_str()).ok_or_else(|| {
371        ArrowError::InvalidArgumentError(format!(
372            "Cannot convert {} to {}: Overflow",
373            value_str,
374            T::PREFIX
375        ))
376    })?;
377
378    T::Native::from_decimal(value).ok_or_else(|| {
379        ArrowError::InvalidArgumentError(format!("Cannot convert {} to {}", value_str, T::PREFIX))
380    })
381}
382
383pub(crate) fn generic_string_to_decimal_cast<'a, T, S>(
384    from: &'a S,
385    precision: u8,
386    scale: i8,
387    cast_options: &CastOptions,
388) -> Result<PrimitiveArray<T>, ArrowError>
389where
390    T: DecimalType,
391    T::Native: DecimalCast + ArrowNativeTypeOp,
392    &'a S: StringArrayType<'a>,
393{
394    if cast_options.safe {
395        let iter = from.iter().map(|v| {
396            v.and_then(|v| parse_string_to_decimal_native::<T>(v, scale as usize).ok())
397                .and_then(|v| T::is_valid_decimal_precision(v, precision).then_some(v))
398        });
399        // Benefit:
400        //     20% performance improvement
401        // Soundness:
402        //     The iterator is trustedLen because it comes from an `StringArray`.
403        Ok(unsafe {
404            PrimitiveArray::<T>::from_trusted_len_iter(iter)
405                .with_precision_and_scale(precision, scale)?
406        })
407    } else {
408        let vec = from
409            .iter()
410            .map(|v| {
411                v.map(|v| {
412                    parse_string_to_decimal_native::<T>(v, scale as usize)
413                        .map_err(|_| {
414                            ArrowError::CastError(format!(
415                                "Cannot cast string '{}' to value of {:?} type",
416                                v,
417                                T::DATA_TYPE,
418                            ))
419                        })
420                        .and_then(|v| T::validate_decimal_precision(v, precision).map(|_| v))
421                })
422                .transpose()
423            })
424            .collect::<Result<Vec<_>, _>>()?;
425        // Benefit:
426        //     20% performance improvement
427        // Soundness:
428        //     The iterator is trustedLen because it comes from an `StringArray`.
429        Ok(unsafe {
430            PrimitiveArray::<T>::from_trusted_len_iter(vec.iter())
431                .with_precision_and_scale(precision, scale)?
432        })
433    }
434}
435
436pub(crate) fn string_to_decimal_cast<T, Offset: OffsetSizeTrait>(
437    from: &GenericStringArray<Offset>,
438    precision: u8,
439    scale: i8,
440    cast_options: &CastOptions,
441) -> Result<PrimitiveArray<T>, ArrowError>
442where
443    T: DecimalType,
444    T::Native: DecimalCast + ArrowNativeTypeOp,
445{
446    generic_string_to_decimal_cast::<T, GenericStringArray<Offset>>(
447        from,
448        precision,
449        scale,
450        cast_options,
451    )
452}
453
454pub(crate) fn string_view_to_decimal_cast<T>(
455    from: &StringViewArray,
456    precision: u8,
457    scale: i8,
458    cast_options: &CastOptions,
459) -> Result<PrimitiveArray<T>, ArrowError>
460where
461    T: DecimalType,
462    T::Native: DecimalCast + ArrowNativeTypeOp,
463{
464    generic_string_to_decimal_cast::<T, StringViewArray>(from, precision, scale, cast_options)
465}
466
467/// Cast Utf8 to decimal
468pub(crate) fn cast_string_to_decimal<T, Offset: OffsetSizeTrait>(
469    from: &dyn Array,
470    precision: u8,
471    scale: i8,
472    cast_options: &CastOptions,
473) -> Result<ArrayRef, ArrowError>
474where
475    T: DecimalType,
476    T::Native: DecimalCast + ArrowNativeTypeOp,
477{
478    if scale < 0 {
479        return Err(ArrowError::InvalidArgumentError(format!(
480            "Cannot cast string to decimal with negative scale {scale}"
481        )));
482    }
483
484    if scale > T::MAX_SCALE {
485        return Err(ArrowError::InvalidArgumentError(format!(
486            "Cannot cast string to decimal greater than maximum scale {}",
487            T::MAX_SCALE
488        )));
489    }
490
491    let result = match from.data_type() {
492        DataType::Utf8View => string_view_to_decimal_cast::<T>(
493            from.as_any().downcast_ref::<StringViewArray>().unwrap(),
494            precision,
495            scale,
496            cast_options,
497        )?,
498        DataType::Utf8 | DataType::LargeUtf8 => string_to_decimal_cast::<T, Offset>(
499            from.as_any()
500                .downcast_ref::<GenericStringArray<Offset>>()
501                .unwrap(),
502            precision,
503            scale,
504            cast_options,
505        )?,
506        other => {
507            return Err(ArrowError::ComputeError(format!(
508                "Cannot cast {:?} to decimal",
509                other
510            )))
511        }
512    };
513
514    Ok(Arc::new(result))
515}
516
517pub(crate) fn cast_floating_point_to_decimal<T: ArrowPrimitiveType, D>(
518    array: &PrimitiveArray<T>,
519    precision: u8,
520    scale: i8,
521    cast_options: &CastOptions,
522) -> Result<ArrayRef, ArrowError>
523where
524    <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
525    D: DecimalType + ArrowPrimitiveType,
526    <D as ArrowPrimitiveType>::Native: DecimalCast,
527{
528    let mul = 10_f64.powi(scale as i32);
529
530    if cast_options.safe {
531        array
532            .unary_opt::<_, D>(|v| {
533                D::Native::from_f64((mul * v.as_()).round())
534                    .filter(|v| D::is_valid_decimal_precision(*v, precision))
535            })
536            .with_precision_and_scale(precision, scale)
537            .map(|a| Arc::new(a) as ArrayRef)
538    } else {
539        array
540            .try_unary::<_, D, _>(|v| {
541                D::Native::from_f64((mul * v.as_()).round())
542                    .ok_or_else(|| {
543                        ArrowError::CastError(format!(
544                            "Cannot cast to {}({}, {}). Overflowing on {:?}",
545                            D::PREFIX,
546                            precision,
547                            scale,
548                            v
549                        ))
550                    })
551                    .and_then(|v| D::validate_decimal_precision(v, precision).map(|_| v))
552            })?
553            .with_precision_and_scale(precision, scale)
554            .map(|a| Arc::new(a) as ArrayRef)
555    }
556}
557
558pub(crate) fn cast_decimal_to_integer<D, T>(
559    array: &dyn Array,
560    base: D::Native,
561    scale: i8,
562    cast_options: &CastOptions,
563) -> Result<ArrayRef, ArrowError>
564where
565    T: ArrowPrimitiveType,
566    <T as ArrowPrimitiveType>::Native: NumCast,
567    D: DecimalType + ArrowPrimitiveType,
568    <D as ArrowPrimitiveType>::Native: ArrowNativeTypeOp + ToPrimitive,
569{
570    let array = array.as_primitive::<D>();
571
572    let div: D::Native = base.pow_checked(scale as u32).map_err(|_| {
573        ArrowError::CastError(format!(
574            "Cannot cast to {:?}. The scale {} causes overflow.",
575            D::PREFIX,
576            scale,
577        ))
578    })?;
579
580    let mut value_builder = PrimitiveBuilder::<T>::with_capacity(array.len());
581
582    if cast_options.safe {
583        for i in 0..array.len() {
584            if array.is_null(i) {
585                value_builder.append_null();
586            } else {
587                let v = array
588                    .value(i)
589                    .div_checked(div)
590                    .ok()
591                    .and_then(<T::Native as NumCast>::from::<D::Native>);
592
593                value_builder.append_option(v);
594            }
595        }
596    } else {
597        for i in 0..array.len() {
598            if array.is_null(i) {
599                value_builder.append_null();
600            } else {
601                let v = array.value(i).div_checked(div)?;
602
603                let value = <T::Native as NumCast>::from::<D::Native>(v).ok_or_else(|| {
604                    ArrowError::CastError(format!(
605                        "value of {:?} is out of range {}",
606                        v,
607                        T::DATA_TYPE
608                    ))
609                })?;
610
611                value_builder.append_value(value);
612            }
613        }
614    }
615    Ok(Arc::new(value_builder.finish()))
616}
617
618// Cast the decimal array to floating-point array
619pub(crate) fn cast_decimal_to_float<D: DecimalType, T: ArrowPrimitiveType, F>(
620    array: &dyn Array,
621    op: F,
622) -> Result<ArrayRef, ArrowError>
623where
624    F: Fn(D::Native) -> T::Native,
625{
626    let array = array.as_primitive::<D>();
627    let array = array.unary::<_, T>(op);
628    Ok(Arc::new(array))
629}
630
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_string_to_decimal_native() -> Result<(), ArrowError> {
        // (input, scale, expected native value)
        let cases: [(&str, usize, i128); 8] = [
            ("0", 0, 0),
            ("0", 5, 0),
            ("123", 0, 123),
            ("123", 5, 12300000),
            ("123.45", 0, 123),
            ("123.45", 5, 12345000),
            ("123.4567891", 0, 123),
            ("123.4567891", 5, 12345679),
        ];
        for (input, scale, expected) in cases {
            assert_eq!(
                parse_string_to_decimal_native::<Decimal128Type>(input, scale)?,
                expected,
                "parsing {input:?} with scale {scale}"
            );
        }
        Ok(())
    }

    #[test]
    fn test_parse_string_to_decimal_native_sign_whitespace_and_rounding(
    ) -> Result<(), ArrowError> {
        // Rounding is half away from zero; signs and surrounding whitespace are accepted.
        let cases: [(&str, usize, i128); 8] = [
            ("-123.45", 5, -12345000),
            ("+123.45", 2, 12345),
            ("  1.5  ", 0, 2),
            ("-1.5", 0, -2),
            ("1.4", 0, 1),
            ("-1.4", 0, -1),
            (".5", 1, 5),
            ("0.05", 1, 1),
        ];
        for (input, scale, expected) in cases {
            assert_eq!(
                parse_string_to_decimal_native::<Decimal128Type>(input, scale)?,
                expected,
                "parsing {input:?} with scale {scale}"
            );
        }

        assert_eq!(
            parse_string_to_decimal_native::<Decimal256Type>("123.45", 2)?,
            i256::from_i128(12345)
        );
        Ok(())
    }

    #[test]
    fn test_parse_string_to_decimal_native_invalid() {
        for input in ["1.2.3", "abc", "1.-5", "--1"] {
            assert!(
                parse_string_to_decimal_native::<Decimal128Type>(input, 2).is_err(),
                "{input:?} should not parse"
            );
        }
    }
}