arrow_cast/cast/decimal.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use crate::cast::*;
19
20/// A utility trait that provides checked conversions between
21/// decimal types inspired by [`NumCast`]
22pub trait DecimalCast: Sized {
23    /// Convert the decimal to an i32
24    fn to_i32(self) -> Option<i32>;
25
26    /// Convert the decimal to an i64
27    fn to_i64(self) -> Option<i64>;
28
29    /// Convert the decimal to an i128
30    fn to_i128(self) -> Option<i128>;
31
32    /// Convert the decimal to an i256
33    fn to_i256(self) -> Option<i256>;
34
35    /// Convert a decimal from a decimal
36    fn from_decimal<T: DecimalCast>(n: T) -> Option<Self>;
37
38    /// Convert a decimal from a f64
39    fn from_f64(n: f64) -> Option<Self>;
40}
41
42impl DecimalCast for i32 {
43    fn to_i32(self) -> Option<i32> {
44        Some(self)
45    }
46
47    fn to_i64(self) -> Option<i64> {
48        Some(self as i64)
49    }
50
51    fn to_i128(self) -> Option<i128> {
52        Some(self as i128)
53    }
54
55    fn to_i256(self) -> Option<i256> {
56        Some(i256::from_i128(self as i128))
57    }
58
59    fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
60        n.to_i32()
61    }
62
63    fn from_f64(n: f64) -> Option<Self> {
64        n.to_i32()
65    }
66}
67
68impl DecimalCast for i64 {
69    fn to_i32(self) -> Option<i32> {
70        i32::try_from(self).ok()
71    }
72
73    fn to_i64(self) -> Option<i64> {
74        Some(self)
75    }
76
77    fn to_i128(self) -> Option<i128> {
78        Some(self as i128)
79    }
80
81    fn to_i256(self) -> Option<i256> {
82        Some(i256::from_i128(self as i128))
83    }
84
85    fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
86        n.to_i64()
87    }
88
89    fn from_f64(n: f64) -> Option<Self> {
90        // Call implementation explicitly otherwise this resolves to `to_i64`
91        // in arrow-buffer that behaves differently.
92        num_traits::ToPrimitive::to_i64(&n)
93    }
94}
95
96impl DecimalCast for i128 {
97    fn to_i32(self) -> Option<i32> {
98        i32::try_from(self).ok()
99    }
100
101    fn to_i64(self) -> Option<i64> {
102        i64::try_from(self).ok()
103    }
104
105    fn to_i128(self) -> Option<i128> {
106        Some(self)
107    }
108
109    fn to_i256(self) -> Option<i256> {
110        Some(i256::from_i128(self))
111    }
112
113    fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
114        n.to_i128()
115    }
116
117    fn from_f64(n: f64) -> Option<Self> {
118        n.to_i128()
119    }
120}
121
122impl DecimalCast for i256 {
123    fn to_i32(self) -> Option<i32> {
124        self.to_i128().map(|x| i32::try_from(x).ok())?
125    }
126
127    fn to_i64(self) -> Option<i64> {
128        self.to_i128().map(|x| i64::try_from(x).ok())?
129    }
130
131    fn to_i128(self) -> Option<i128> {
132        self.to_i128()
133    }
134
135    fn to_i256(self) -> Option<i256> {
136        Some(self)
137    }
138
139    fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
140        n.to_i256()
141    }
142
143    fn from_f64(n: f64) -> Option<Self> {
144        i256::from_f64(n)
145    }
146}
147
148/// Construct closures to upscale decimals from `(input_precision, input_scale)` to
149/// `(output_precision, output_scale)`.
150///
151/// Returns `(f_fallible, f_infallible)` where:
152/// * `f_fallible` yields `None` when the requested cast would overflow
153/// * `f_infallible` is present only when every input is guaranteed to succeed; otherwise it is `None`
154///   and callers must fall back to `f_fallible`
155///
156/// Returns `None` if the required scale increase `delta_scale = output_scale - input_scale`
157/// exceeds the supported precomputed precision table `O::MAX_FOR_EACH_PRECISION`.
158/// In that case, the caller should treat this as an overflow for the output scale
159/// and handle it accordingly (e.g., return a cast error).
160#[allow(clippy::type_complexity)]
161fn make_upscaler<I: DecimalType, O: DecimalType>(
162    input_precision: u8,
163    input_scale: i8,
164    output_precision: u8,
165    output_scale: i8,
166) -> Option<(
167    impl Fn(I::Native) -> Option<O::Native>,
168    Option<impl Fn(I::Native) -> O::Native>,
169)>
170where
171    I::Native: DecimalCast + ArrowNativeTypeOp,
172    O::Native: DecimalCast + ArrowNativeTypeOp,
173{
174    let delta_scale = output_scale - input_scale;
175
176    // O::MAX_FOR_EACH_PRECISION[k] stores 10^k - 1 (e.g., 9, 99, 999, ...).
177    // Adding 1 yields exactly 10^k without computing a power at runtime.
178    // Using the precomputed table avoids pow(10, k) and its checked/overflow
179    // handling, which is faster and simpler for scaling by 10^delta_scale.
180    let max = O::MAX_FOR_EACH_PRECISION.get(delta_scale as usize)?;
181    let mul = max.add_wrapping(O::Native::ONE);
182    let f_fallible = move |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok());
183
184    // if the gain in precision (digits) is greater than the multiplication due to scaling
185    // every number will fit into the output type
186    // Example: If we are starting with any number of precision 5 [xxxxx],
187    // then an increase of scale by 3 will have the following effect on the representation:
188    // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type
189    // needs to provide at least 8 digits precision
190    let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8);
191    let f_infallible = is_infallible_cast
192        .then_some(move |x| O::Native::from_decimal(x).unwrap().mul_wrapping(mul));
193    Some((f_fallible, f_infallible))
194}
195
196/// Construct closures to downscale decimals from `(input_precision, input_scale)` to
197/// `(output_precision, output_scale)`.
198///
199/// Returns `(f_fallible, f_infallible)` where:
200/// * `f_fallible` yields `None` when the requested cast would overflow
201/// * `f_infallible` is present only when every input is guaranteed to succeed; otherwise it is `None`
202///   and callers must fall back to `f_fallible`
203///
204/// Returns `None` if the required scale reduction `delta_scale = input_scale - output_scale`
205/// exceeds the supported precomputed precision table `I::MAX_FOR_EACH_PRECISION`.
206/// In this scenario, any value would round to zero (e.g., dividing by 10^k where k exceeds the
207/// available precision). Callers should therefore produce zero values (preserving nulls) rather
208/// than returning an error.
209#[allow(clippy::type_complexity)]
210fn make_downscaler<I: DecimalType, O: DecimalType>(
211    input_precision: u8,
212    input_scale: i8,
213    output_precision: u8,
214    output_scale: i8,
215) -> Option<(
216    impl Fn(I::Native) -> Option<O::Native>,
217    Option<impl Fn(I::Native) -> O::Native>,
218)>
219where
220    I::Native: DecimalCast + ArrowNativeTypeOp,
221    O::Native: DecimalCast + ArrowNativeTypeOp,
222{
223    let delta_scale = input_scale - output_scale;
224
225    // delta_scale is guaranteed to be > 0, but may also be larger than I::MAX_PRECISION. If so, the
226    // scale change divides out more digits than the input has precision and the result of the cast
227    // is always zero. For example, if we try to apply delta_scale=10 a decimal32 value, the largest
228    // possible result is 999999999/10000000000 = 0.0999999999, which rounds to zero. Smaller values
229    // (e.g. 1/10000000000) or larger delta_scale (e.g. 999999999/10000000000000) produce even
230    // smaller results, which also round to zero. In that case, just return an array of zeros.
231    let max = I::MAX_FOR_EACH_PRECISION.get(delta_scale as usize)?;
232
233    let div = max.add_wrapping(I::Native::ONE);
234    let half = div.div_wrapping(I::Native::ONE.add_wrapping(I::Native::ONE));
235    let half_neg = half.neg_wrapping();
236
237    let f_fallible = move |x: I::Native| {
238        // div is >= 10 and so this cannot overflow
239        let d = x.div_wrapping(div);
240        let r = x.mod_wrapping(div);
241
242        // Round result
243        let adjusted = match x >= I::Native::ZERO {
244            true if r >= half => d.add_wrapping(I::Native::ONE),
245            false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
246            _ => d,
247        };
248        O::Native::from_decimal(adjusted)
249    };
250
251    // if the reduction of the input number through scaling (dividing) is greater
252    // than a possible precision loss (plus potential increase via rounding)
253    // every input number will fit into the output type
254    // Example: If we are starting with any number of precision 5 [xxxxx],
255    // then and decrease the scale by 3 will have the following effect on the representation:
256    // [xxxxx] -> [xx] (+ 1 possibly, due to rounding).
257    // The rounding may add a digit, so the cast to be infallible,
258    // the output type needs to have at least 3 digits of precision.
259    // e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100:
260    // [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible
261    let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8);
262    let f_infallible = is_infallible_cast.then_some(move |x| f_fallible(x).unwrap());
263    Some((f_fallible, f_infallible))
264}
265
266/// Apply the rescaler function to the value.
267/// If the rescaler is infallible, use the infallible function.
268/// Otherwise, use the fallible function and validate the precision.
269fn apply_rescaler<I: DecimalType, O: DecimalType>(
270    value: I::Native,
271    output_precision: u8,
272    f: impl Fn(I::Native) -> Option<O::Native>,
273    f_infallible: Option<impl Fn(I::Native) -> O::Native>,
274) -> Option<O::Native>
275where
276    I::Native: DecimalCast,
277    O::Native: DecimalCast,
278{
279    if let Some(f_infallible) = f_infallible {
280        Some(f_infallible(value))
281    } else {
282        f(value).filter(|v| O::is_valid_decimal_precision(*v, output_precision))
283    }
284}
285
286/// Rescales a decimal value from `(input_precision, input_scale)` to
287/// `(output_precision, output_scale)` and returns the converted number when it fits
288/// within the output precision.
289///
290/// The function first validates that the requested precision and scale are supported for
291/// both the source and destination decimal types. It then either upscales (multiplying
292/// by an appropriate power of ten) or downscales (dividing with rounding) the input value.
293/// When the scaling factor exceeds the precision table of the destination type, the value
294/// is treated as an overflow for upscaling, or rounded to zero for downscaling (as any
295/// possible result would be zero at the requested scale).
296///
297/// This mirrors the column-oriented helpers of decimal casting but operates on a single value
298/// (row-level) instead of an entire array.
299///
300/// Returns `None` if the value cannot be represented with the requested precision.
301pub fn rescale_decimal<I: DecimalType, O: DecimalType>(
302    value: I::Native,
303    input_precision: u8,
304    input_scale: i8,
305    output_precision: u8,
306    output_scale: i8,
307) -> Option<O::Native>
308where
309    I::Native: DecimalCast + ArrowNativeTypeOp,
310    O::Native: DecimalCast + ArrowNativeTypeOp,
311{
312    validate_decimal_precision_and_scale::<I>(input_precision, input_scale).ok()?;
313    validate_decimal_precision_and_scale::<O>(output_precision, output_scale).ok()?;
314
315    if input_scale <= output_scale {
316        let (f, f_infallible) =
317            make_upscaler::<I, O>(input_precision, input_scale, output_precision, output_scale)?;
318        apply_rescaler::<I, O>(value, output_precision, f, f_infallible)
319    } else {
320        let Some((f, f_infallible)) =
321            make_downscaler::<I, O>(input_precision, input_scale, output_precision, output_scale)
322        else {
323            // Scale reduction exceeds supported precision; result mathematically rounds to zero
324            return Some(O::Native::ZERO);
325        };
326        apply_rescaler::<I, O>(value, output_precision, f, f_infallible)
327    }
328}
329
330fn cast_decimal_to_decimal_error<I, O>(
331    output_precision: u8,
332    output_scale: i8,
333) -> impl Fn(<I as ArrowPrimitiveType>::Native) -> ArrowError
334where
335    I: DecimalType,
336    O: DecimalType,
337    I::Native: DecimalCast + ArrowNativeTypeOp,
338    O::Native: DecimalCast + ArrowNativeTypeOp,
339{
340    move |x: I::Native| {
341        ArrowError::CastError(format!(
342            "Cannot cast to {}({}, {}). Overflowing on {:?}",
343            O::PREFIX,
344            output_precision,
345            output_scale,
346            x
347        ))
348    }
349}
350
351fn apply_decimal_cast<I: DecimalType, O: DecimalType>(
352    array: &PrimitiveArray<I>,
353    output_precision: u8,
354    output_scale: i8,
355    f_fallible: impl Fn(I::Native) -> Option<O::Native>,
356    f_infallible: Option<impl Fn(I::Native) -> O::Native>,
357    cast_options: &CastOptions,
358) -> Result<PrimitiveArray<O>, ArrowError>
359where
360    I::Native: DecimalCast + ArrowNativeTypeOp,
361    O::Native: DecimalCast + ArrowNativeTypeOp,
362{
363    let array = if let Some(f_infallible) = f_infallible {
364        array.unary(f_infallible)
365    } else if cast_options.safe {
366        array.unary_opt(|x| {
367            f_fallible(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision))
368        })
369    } else {
370        let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
371        array.try_unary(|x| {
372            f_fallible(x).ok_or_else(|| error(x)).and_then(|v| {
373                O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v)
374            })
375        })?
376    };
377    Ok(array)
378}
379
380fn convert_to_smaller_scale_decimal<I, O>(
381    array: &PrimitiveArray<I>,
382    input_precision: u8,
383    input_scale: i8,
384    output_precision: u8,
385    output_scale: i8,
386    cast_options: &CastOptions,
387) -> Result<PrimitiveArray<O>, ArrowError>
388where
389    I: DecimalType,
390    O: DecimalType,
391    I::Native: DecimalCast + ArrowNativeTypeOp,
392    O::Native: DecimalCast + ArrowNativeTypeOp,
393{
394    if let Some((f_fallible, f_infallible)) =
395        make_downscaler::<I, O>(input_precision, input_scale, output_precision, output_scale)
396    {
397        apply_decimal_cast(
398            array,
399            output_precision,
400            output_scale,
401            f_fallible,
402            f_infallible,
403            cast_options,
404        )
405    } else {
406        // Scale reduction exceeds supported precision; result mathematically rounds to zero
407        let zeros = vec![O::Native::ZERO; array.len()];
408        Ok(PrimitiveArray::new(zeros.into(), array.nulls().cloned()))
409    }
410}
411
412fn convert_to_bigger_or_equal_scale_decimal<I, O>(
413    array: &PrimitiveArray<I>,
414    input_precision: u8,
415    input_scale: i8,
416    output_precision: u8,
417    output_scale: i8,
418    cast_options: &CastOptions,
419) -> Result<PrimitiveArray<O>, ArrowError>
420where
421    I: DecimalType,
422    O: DecimalType,
423    I::Native: DecimalCast + ArrowNativeTypeOp,
424    O::Native: DecimalCast + ArrowNativeTypeOp,
425{
426    if let Some((f, f_infallible)) =
427        make_upscaler::<I, O>(input_precision, input_scale, output_precision, output_scale)
428    {
429        apply_decimal_cast(
430            array,
431            output_precision,
432            output_scale,
433            f,
434            f_infallible,
435            cast_options,
436        )
437    } else {
438        // Scale increase exceeds supported precision; return overflow error
439        Err(ArrowError::CastError(format!(
440            "Cannot cast to {}({}, {}). Value overflows for output scale",
441            O::PREFIX,
442            output_precision,
443            output_scale
444        )))
445    }
446}
447
448// Only support one type of decimal cast operations
449pub(crate) fn cast_decimal_to_decimal_same_type<T>(
450    array: &PrimitiveArray<T>,
451    input_precision: u8,
452    input_scale: i8,
453    output_precision: u8,
454    output_scale: i8,
455    cast_options: &CastOptions,
456) -> Result<ArrayRef, ArrowError>
457where
458    T: DecimalType,
459    T::Native: DecimalCast + ArrowNativeTypeOp,
460{
461    let array: PrimitiveArray<T> =
462        if input_scale == output_scale && input_precision <= output_precision {
463            array.clone()
464        } else if input_scale <= output_scale {
465            convert_to_bigger_or_equal_scale_decimal::<T, T>(
466                array,
467                input_precision,
468                input_scale,
469                output_precision,
470                output_scale,
471                cast_options,
472            )?
473        } else {
474            // input_scale > output_scale
475            convert_to_smaller_scale_decimal::<T, T>(
476                array,
477                input_precision,
478                input_scale,
479                output_precision,
480                output_scale,
481                cast_options,
482            )?
483        };
484
485    Ok(Arc::new(array.with_precision_and_scale(
486        output_precision,
487        output_scale,
488    )?))
489}
490
491// Support two different types of decimal cast operations
492pub(crate) fn cast_decimal_to_decimal<I, O>(
493    array: &PrimitiveArray<I>,
494    input_precision: u8,
495    input_scale: i8,
496    output_precision: u8,
497    output_scale: i8,
498    cast_options: &CastOptions,
499) -> Result<ArrayRef, ArrowError>
500where
501    I: DecimalType,
502    O: DecimalType,
503    I::Native: DecimalCast + ArrowNativeTypeOp,
504    O::Native: DecimalCast + ArrowNativeTypeOp,
505{
506    let array: PrimitiveArray<O> = if input_scale > output_scale {
507        convert_to_smaller_scale_decimal::<I, O>(
508            array,
509            input_precision,
510            input_scale,
511            output_precision,
512            output_scale,
513            cast_options,
514        )?
515    } else {
516        convert_to_bigger_or_equal_scale_decimal::<I, O>(
517            array,
518            input_precision,
519            input_scale,
520            output_precision,
521            output_scale,
522            cast_options,
523        )?
524    };
525
526    Ok(Arc::new(array.with_precision_and_scale(
527        output_precision,
528        output_scale,
529    )?))
530}
531
532/// Parses given string to specified decimal native (i128/i256) based on given
533/// scale. Returns an `Err` if it cannot parse given string.
534pub(crate) fn parse_string_to_decimal_native<T: DecimalType>(
535    value_str: &str,
536    scale: usize,
537) -> Result<T::Native, ArrowError>
538where
539    T::Native: DecimalCast + ArrowNativeTypeOp,
540{
541    let value_str = value_str.trim();
542    let parts: Vec<&str> = value_str.split('.').collect();
543    if parts.len() > 2 {
544        return Err(ArrowError::InvalidArgumentError(format!(
545            "Invalid decimal format: {value_str:?}"
546        )));
547    }
548
549    let (negative, first_part) = if parts[0].is_empty() {
550        (false, parts[0])
551    } else {
552        match parts[0].as_bytes()[0] {
553            b'-' => (true, &parts[0][1..]),
554            b'+' => (false, &parts[0][1..]),
555            _ => (false, parts[0]),
556        }
557    };
558
559    let integers = first_part;
560    let decimals = if parts.len() == 2 { parts[1] } else { "" };
561
562    if !integers.is_empty() && !integers.as_bytes()[0].is_ascii_digit() {
563        return Err(ArrowError::InvalidArgumentError(format!(
564            "Invalid decimal format: {value_str:?}"
565        )));
566    }
567
568    if !decimals.is_empty() && !decimals.as_bytes()[0].is_ascii_digit() {
569        return Err(ArrowError::InvalidArgumentError(format!(
570            "Invalid decimal format: {value_str:?}"
571        )));
572    }
573
574    // Adjust decimal based on scale
575    let mut number_decimals = if decimals.len() > scale {
576        let decimal_number = i256::from_string(decimals).ok_or_else(|| {
577            ArrowError::InvalidArgumentError(format!("Cannot parse decimal format: {value_str}"))
578        })?;
579
580        let div = i256::from_i128(10_i128).pow_checked((decimals.len() - scale) as u32)?;
581
582        let half = div.div_wrapping(i256::from_i128(2));
583        let half_neg = half.neg_wrapping();
584
585        let d = decimal_number.div_wrapping(div);
586        let r = decimal_number.mod_wrapping(div);
587
588        // Round result
589        let adjusted = match decimal_number >= i256::ZERO {
590            true if r >= half => d.add_wrapping(i256::ONE),
591            false if r <= half_neg => d.sub_wrapping(i256::ONE),
592            _ => d,
593        };
594
595        let integers = if !integers.is_empty() {
596            i256::from_string(integers)
597                .ok_or_else(|| {
598                    ArrowError::InvalidArgumentError(format!(
599                        "Cannot parse decimal format: {value_str}"
600                    ))
601                })
602                .map(|v| v.mul_wrapping(i256::from_i128(10_i128).pow_wrapping(scale as u32)))?
603        } else {
604            i256::ZERO
605        };
606
607        format!("{}", integers.add_wrapping(adjusted))
608    } else {
609        let padding = if scale > decimals.len() { scale } else { 0 };
610
611        let decimals = format!("{decimals:0<padding$}");
612        format!("{integers}{decimals}")
613    };
614
615    if negative {
616        number_decimals.insert(0, '-');
617    }
618
619    let value = i256::from_string(number_decimals.as_str()).ok_or_else(|| {
620        ArrowError::InvalidArgumentError(format!(
621            "Cannot convert {} to {}: Overflow",
622            value_str,
623            T::PREFIX
624        ))
625    })?;
626
627    T::Native::from_decimal(value).ok_or_else(|| {
628        ArrowError::InvalidArgumentError(format!("Cannot convert {} to {}", value_str, T::PREFIX))
629    })
630}
631
632pub(crate) fn generic_string_to_decimal_cast<'a, T, S>(
633    from: &'a S,
634    precision: u8,
635    scale: i8,
636    cast_options: &CastOptions,
637) -> Result<PrimitiveArray<T>, ArrowError>
638where
639    T: DecimalType,
640    T::Native: DecimalCast + ArrowNativeTypeOp,
641    &'a S: StringArrayType<'a>,
642{
643    if cast_options.safe {
644        let iter = from.iter().map(|v| {
645            v.and_then(|v| parse_string_to_decimal_native::<T>(v, scale as usize).ok())
646                .and_then(|v| T::is_valid_decimal_precision(v, precision).then_some(v))
647        });
648        // Benefit:
649        //     20% performance improvement
650        // Soundness:
651        //     The iterator is trustedLen because it comes from an `StringArray`.
652        Ok(unsafe {
653            PrimitiveArray::<T>::from_trusted_len_iter(iter)
654                .with_precision_and_scale(precision, scale)?
655        })
656    } else {
657        let vec = from
658            .iter()
659            .map(|v| {
660                v.map(|v| {
661                    parse_string_to_decimal_native::<T>(v, scale as usize)
662                        .map_err(|_| {
663                            ArrowError::CastError(format!(
664                                "Cannot cast string '{v}' to value of {} type",
665                                T::DATA_TYPE,
666                            ))
667                        })
668                        .and_then(|v| T::validate_decimal_precision(v, precision, scale).map(|_| v))
669                })
670                .transpose()
671            })
672            .collect::<Result<Vec<_>, _>>()?;
673        // Benefit:
674        //     20% performance improvement
675        // Soundness:
676        //     The iterator is trustedLen because it comes from an `StringArray`.
677        Ok(unsafe {
678            PrimitiveArray::<T>::from_trusted_len_iter(vec.iter())
679                .with_precision_and_scale(precision, scale)?
680        })
681    }
682}
683
684pub(crate) fn string_to_decimal_cast<T, Offset: OffsetSizeTrait>(
685    from: &GenericStringArray<Offset>,
686    precision: u8,
687    scale: i8,
688    cast_options: &CastOptions,
689) -> Result<PrimitiveArray<T>, ArrowError>
690where
691    T: DecimalType,
692    T::Native: DecimalCast + ArrowNativeTypeOp,
693{
694    generic_string_to_decimal_cast::<T, GenericStringArray<Offset>>(
695        from,
696        precision,
697        scale,
698        cast_options,
699    )
700}
701
702pub(crate) fn string_view_to_decimal_cast<T>(
703    from: &StringViewArray,
704    precision: u8,
705    scale: i8,
706    cast_options: &CastOptions,
707) -> Result<PrimitiveArray<T>, ArrowError>
708where
709    T: DecimalType,
710    T::Native: DecimalCast + ArrowNativeTypeOp,
711{
712    generic_string_to_decimal_cast::<T, StringViewArray>(from, precision, scale, cast_options)
713}
714
715/// Cast Utf8 to decimal
716pub(crate) fn cast_string_to_decimal<T, Offset: OffsetSizeTrait>(
717    from: &dyn Array,
718    precision: u8,
719    scale: i8,
720    cast_options: &CastOptions,
721) -> Result<ArrayRef, ArrowError>
722where
723    T: DecimalType,
724    T::Native: DecimalCast + ArrowNativeTypeOp,
725{
726    if scale < 0 {
727        return Err(ArrowError::InvalidArgumentError(format!(
728            "Cannot cast string to decimal with negative scale {scale}"
729        )));
730    }
731
732    if scale > T::MAX_SCALE {
733        return Err(ArrowError::InvalidArgumentError(format!(
734            "Cannot cast string to decimal greater than maximum scale {}",
735            T::MAX_SCALE
736        )));
737    }
738
739    let result = match from.data_type() {
740        DataType::Utf8View => string_view_to_decimal_cast::<T>(
741            from.as_any().downcast_ref::<StringViewArray>().unwrap(),
742            precision,
743            scale,
744            cast_options,
745        )?,
746        DataType::Utf8 | DataType::LargeUtf8 => string_to_decimal_cast::<T, Offset>(
747            from.as_any()
748                .downcast_ref::<GenericStringArray<Offset>>()
749                .unwrap(),
750            precision,
751            scale,
752            cast_options,
753        )?,
754        other => {
755            return Err(ArrowError::ComputeError(format!(
756                "Cannot cast {other:?} to decimal",
757            )));
758        }
759    };
760
761    Ok(Arc::new(result))
762}
763
764pub(crate) fn cast_floating_point_to_decimal<T: ArrowPrimitiveType, D>(
765    array: &PrimitiveArray<T>,
766    precision: u8,
767    scale: i8,
768    cast_options: &CastOptions,
769) -> Result<ArrayRef, ArrowError>
770where
771    <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
772    D: DecimalType + ArrowPrimitiveType,
773    <D as ArrowPrimitiveType>::Native: DecimalCast,
774{
775    let mul = 10_f64.powi(scale as i32);
776
777    if cast_options.safe {
778        array
779            .unary_opt::<_, D>(|v| {
780                D::Native::from_f64((mul * v.as_()).round())
781                    .filter(|v| D::is_valid_decimal_precision(*v, precision))
782            })
783            .with_precision_and_scale(precision, scale)
784            .map(|a| Arc::new(a) as ArrayRef)
785    } else {
786        array
787            .try_unary::<_, D, _>(|v| {
788                D::Native::from_f64((mul * v.as_()).round())
789                    .ok_or_else(|| {
790                        ArrowError::CastError(format!(
791                            "Cannot cast to {}({}, {}). Overflowing on {:?}",
792                            D::PREFIX,
793                            precision,
794                            scale,
795                            v
796                        ))
797                    })
798                    .and_then(|v| D::validate_decimal_precision(v, precision, scale).map(|_| v))
799            })?
800            .with_precision_and_scale(precision, scale)
801            .map(|a| Arc::new(a) as ArrayRef)
802    }
803}
804
805pub(crate) fn cast_decimal_to_integer<D, T>(
806    array: &dyn Array,
807    base: D::Native,
808    scale: i8,
809    cast_options: &CastOptions,
810) -> Result<ArrayRef, ArrowError>
811where
812    T: ArrowPrimitiveType,
813    <T as ArrowPrimitiveType>::Native: NumCast,
814    D: DecimalType + ArrowPrimitiveType,
815    <D as ArrowPrimitiveType>::Native: ArrowNativeTypeOp + ToPrimitive,
816{
817    let array = array.as_primitive::<D>();
818
819    let div: D::Native = base.pow_checked(scale as u32).map_err(|_| {
820        ArrowError::CastError(format!(
821            "Cannot cast to {:?}. The scale {} causes overflow.",
822            D::PREFIX,
823            scale,
824        ))
825    })?;
826
827    let mut value_builder = PrimitiveBuilder::<T>::with_capacity(array.len());
828
829    if cast_options.safe {
830        for i in 0..array.len() {
831            if array.is_null(i) {
832                value_builder.append_null();
833            } else {
834                let v = array
835                    .value(i)
836                    .div_checked(div)
837                    .ok()
838                    .and_then(<T::Native as NumCast>::from::<D::Native>);
839
840                value_builder.append_option(v);
841            }
842        }
843    } else {
844        for i in 0..array.len() {
845            if array.is_null(i) {
846                value_builder.append_null();
847            } else {
848                let v = array.value(i).div_checked(div)?;
849
850                let value = <T::Native as NumCast>::from::<D::Native>(v).ok_or_else(|| {
851                    ArrowError::CastError(format!(
852                        "value of {:?} is out of range {}",
853                        v,
854                        T::DATA_TYPE
855                    ))
856                })?;
857
858                value_builder.append_value(value);
859            }
860        }
861    }
862    Ok(Arc::new(value_builder.finish()))
863}
864
865/// Cast a decimal array to a floating point array.
866///
867/// Conversion is lossy and follows standard floating point semantics. Values
868/// that exceed the representable range become `INFINITY` or `-INFINITY` without
869/// returning an error.
870pub(crate) fn cast_decimal_to_float<D: DecimalType, T: ArrowPrimitiveType, F>(
871    array: &dyn Array,
872    op: F,
873) -> Result<ArrayRef, ArrowError>
874where
875    F: Fn(D::Native) -> T::Native,
876{
877    let array = array.as_primitive::<D>();
878    let array = array.unary::<_, T>(op);
879    Ok(Arc::new(array))
880}
881
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_string_to_decimal_native() -> Result<(), ArrowError> {
        // (input, scale, expected native value)
        let cases: [(&str, usize, i128); 8] = [
            ("0", 0, 0),
            ("0", 5, 0),
            ("123", 0, 123),
            ("123", 5, 12_300_000),
            ("123.45", 0, 123),
            ("123.45", 5, 12_345_000),
            ("123.4567891", 0, 123),
            ("123.4567891", 5, 12_345_679),
        ];
        for (input, scale, expected) in cases {
            let parsed = parse_string_to_decimal_native::<Decimal128Type>(input, scale)?;
            assert_eq!(parsed, expected, "parsing {input:?} with scale {scale}");
        }
        Ok(())
    }

    #[test]
    fn test_rescale_decimal_upscale_within_precision() {
        // 123.45 as Decimal(5, 2) -> 123.45000 as Decimal(8, 5)
        let rescaled = rescale_decimal::<Decimal128Type, Decimal128Type>(12_345_i128, 5, 2, 8, 5);
        assert_eq!(rescaled, Some(12_345_000_i128));
    }

    #[test]
    fn test_rescale_decimal_downscale_rounds_half_away_from_zero() {
        // +/-1.050 with scale 3 -> +/-1.1 with scale 1
        for (input, expected) in [(1_050_i128, 11_i128), (-1_050_i128, -11_i128)] {
            let rescaled = rescale_decimal::<Decimal128Type, Decimal128Type>(input, 5, 3, 5, 1);
            assert_eq!(rescaled, Some(expected), "rescaling {input}");
        }
    }

    #[test]
    fn test_rescale_decimal_downscale_large_delta_returns_zero() {
        let rescaled = rescale_decimal::<Decimal32Type, Decimal32Type>(12_345_i32, 9, 9, 9, 4);
        assert_eq!(rescaled, Some(0_i32));
    }

    #[test]
    fn test_rescale_decimal_upscale_overflow_returns_none() {
        let rescaled = rescale_decimal::<Decimal32Type, Decimal32Type>(9_999_i32, 4, 0, 5, 2);
        assert_eq!(rescaled, None);
    }

    #[test]
    fn test_rescale_decimal_invalid_input_precision_scale_returns_none() {
        let rescaled =
            rescale_decimal::<Decimal128Type, Decimal128Type>(123_i128, 39, 39, 38, 38);
        assert_eq!(rescaled, None);
    }

    #[test]
    fn test_rescale_decimal_invalid_output_precision_scale_returns_none() {
        let rescaled =
            rescale_decimal::<Decimal128Type, Decimal128Type>(123_i128, 38, 38, 39, 39);
        assert_eq!(rescaled, None);
    }
}