Skip to main content

arrow_cast/cast/
decimal.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::cast::*;
19
/// A utility trait that provides checked conversions between
/// decimal types inspired by [`NumCast`]
///
/// Every method returns an `Option`; `None` signals that the value could not be
/// converted to the requested type.
pub trait DecimalCast: Sized {
    /// Convert the decimal to an i32
    fn to_i32(self) -> Option<i32>;

    /// Convert the decimal to an i64
    fn to_i64(self) -> Option<i64>;

    /// Convert the decimal to an i128
    fn to_i128(self) -> Option<i128>;

    /// Convert the decimal to an i256
    fn to_i256(self) -> Option<i256>;

    /// Convert a decimal from a decimal
    ///
    /// `n` may be any other [`DecimalCast`] implementor, which makes this the
    /// generic entry point for converting between native decimal representations.
    fn from_decimal<T: DecimalCast>(n: T) -> Option<Self>;

    /// Convert a decimal from a f64
    fn from_f64(n: f64) -> Option<Self>;
}
41
42impl DecimalCast for i32 {
43    fn to_i32(self) -> Option<i32> {
44        Some(self)
45    }
46
47    fn to_i64(self) -> Option<i64> {
48        Some(self as i64)
49    }
50
51    fn to_i128(self) -> Option<i128> {
52        Some(self as i128)
53    }
54
55    fn to_i256(self) -> Option<i256> {
56        Some(i256::from_i128(self as i128))
57    }
58
59    fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
60        n.to_i32()
61    }
62
63    fn from_f64(n: f64) -> Option<Self> {
64        n.to_i32()
65    }
66}
67
68impl DecimalCast for i64 {
69    fn to_i32(self) -> Option<i32> {
70        i32::try_from(self).ok()
71    }
72
73    fn to_i64(self) -> Option<i64> {
74        Some(self)
75    }
76
77    fn to_i128(self) -> Option<i128> {
78        Some(self as i128)
79    }
80
81    fn to_i256(self) -> Option<i256> {
82        Some(i256::from_i128(self as i128))
83    }
84
85    fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
86        n.to_i64()
87    }
88
89    fn from_f64(n: f64) -> Option<Self> {
90        // Call implementation explicitly otherwise this resolves to `to_i64`
91        // in arrow-buffer that behaves differently.
92        num_traits::ToPrimitive::to_i64(&n)
93    }
94}
95
96impl DecimalCast for i128 {
97    fn to_i32(self) -> Option<i32> {
98        i32::try_from(self).ok()
99    }
100
101    fn to_i64(self) -> Option<i64> {
102        i64::try_from(self).ok()
103    }
104
105    fn to_i128(self) -> Option<i128> {
106        Some(self)
107    }
108
109    fn to_i256(self) -> Option<i256> {
110        Some(i256::from_i128(self))
111    }
112
113    fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
114        n.to_i128()
115    }
116
117    fn from_f64(n: f64) -> Option<Self> {
118        n.to_i128()
119    }
120}
121
122impl DecimalCast for i256 {
123    fn to_i32(self) -> Option<i32> {
124        self.to_i128().map(|x| i32::try_from(x).ok())?
125    }
126
127    fn to_i64(self) -> Option<i64> {
128        self.to_i128().map(|x| i64::try_from(x).ok())?
129    }
130
131    fn to_i128(self) -> Option<i128> {
132        self.to_i128()
133    }
134
135    fn to_i256(self) -> Option<i256> {
136        Some(self)
137    }
138
139    fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> {
140        n.to_i256()
141    }
142
143    fn from_f64(n: f64) -> Option<Self> {
144        i256::from_f64(n)
145    }
146}
147
148/// Construct closures to upscale decimals from `(input_precision, input_scale)` to
149/// `(output_precision, output_scale)`.
150///
151/// Returns `(f_fallible, f_infallible)` where:
152/// * `f_fallible` yields `None` when the requested cast would overflow
153/// * `f_infallible` is present only when every input is guaranteed to succeed; otherwise it is `None`
154///   and callers must fall back to `f_fallible`
155///
156/// Returns `None` if the required scale increase `delta_scale = output_scale - input_scale`
157/// exceeds the supported precomputed precision table `O::MAX_FOR_EACH_PRECISION`.
158/// In that case, the caller should treat this as an overflow for the output scale
159/// and handle it accordingly (e.g., return a cast error).
160#[allow(clippy::type_complexity)]
161fn make_upscaler<I: DecimalType, O: DecimalType>(
162    input_precision: u8,
163    input_scale: i8,
164    output_precision: u8,
165    output_scale: i8,
166) -> Option<(
167    impl Fn(I::Native) -> Option<O::Native>,
168    Option<impl Fn(I::Native) -> O::Native>,
169)>
170where
171    I::Native: DecimalCast + ArrowNativeTypeOp,
172    O::Native: DecimalCast + ArrowNativeTypeOp,
173{
174    let delta_scale = output_scale - input_scale;
175
176    // O::MAX_FOR_EACH_PRECISION[k] stores 10^k - 1 (e.g., 9, 99, 999, ...).
177    // Adding 1 yields exactly 10^k without computing a power at runtime.
178    // Using the precomputed table avoids pow(10, k) and its checked/overflow
179    // handling, which is faster and simpler for scaling by 10^delta_scale.
180    let max = O::MAX_FOR_EACH_PRECISION.get(delta_scale as usize)?;
181    let mul = max.add_wrapping(O::Native::ONE);
182    let f_fallible = move |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok());
183
184    // if the gain in precision (digits) is greater than the multiplication due to scaling
185    // every number will fit into the output type
186    // Example: If we are starting with any number of precision 5 [xxxxx],
187    // then an increase of scale by 3 will have the following effect on the representation:
188    // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type
189    // needs to provide at least 8 digits precision
190    let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8);
191    let f_infallible = is_infallible_cast
192        .then_some(move |x| O::Native::from_decimal(x).unwrap().mul_wrapping(mul));
193    Some((f_fallible, f_infallible))
194}
195
196/// Construct closures to downscale decimals from `(input_precision, input_scale)` to
197/// `(output_precision, output_scale)`.
198///
199/// Returns `(f_fallible, f_infallible)` where:
200/// * `f_fallible` yields `None` when the requested cast would overflow
201/// * `f_infallible` is present only when every input is guaranteed to succeed; otherwise it is `None`
202///   and callers must fall back to `f_fallible`
203///
204/// Returns `None` if the required scale reduction `delta_scale = input_scale - output_scale`
205/// exceeds the supported precomputed precision table `I::MAX_FOR_EACH_PRECISION`.
206/// In this scenario, any value would round to zero (e.g., dividing by 10^k where k exceeds the
207/// available precision). Callers should therefore produce zero values (preserving nulls) rather
208/// than returning an error.
209#[allow(clippy::type_complexity)]
210fn make_downscaler<I: DecimalType, O: DecimalType>(
211    input_precision: u8,
212    input_scale: i8,
213    output_precision: u8,
214    output_scale: i8,
215) -> Option<(
216    impl Fn(I::Native) -> Option<O::Native>,
217    Option<impl Fn(I::Native) -> O::Native>,
218)>
219where
220    I::Native: DecimalCast + ArrowNativeTypeOp,
221    O::Native: DecimalCast + ArrowNativeTypeOp,
222{
223    let delta_scale = input_scale - output_scale;
224
225    // delta_scale is guaranteed to be > 0, but may also be larger than I::MAX_PRECISION. If so, the
226    // scale change divides out more digits than the input has precision and the result of the cast
227    // is always zero. For example, if we try to apply delta_scale=10 a decimal32 value, the largest
228    // possible result is 999999999/10000000000 = 0.0999999999, which rounds to zero. Smaller values
229    // (e.g. 1/10000000000) or larger delta_scale (e.g. 999999999/10000000000000) produce even
230    // smaller results, which also round to zero. In that case, just return an array of zeros.
231    let max = I::MAX_FOR_EACH_PRECISION.get(delta_scale as usize)?;
232
233    let div = max.add_wrapping(I::Native::ONE);
234    let half = div.div_wrapping(I::Native::ONE.add_wrapping(I::Native::ONE));
235    let half_neg = half.neg_wrapping();
236
237    let f_fallible = move |x: I::Native| {
238        // div is >= 10 and so this cannot overflow
239        let d = x.div_wrapping(div);
240        let r = x.mod_wrapping(div);
241
242        // Round result
243        let adjusted = match x >= I::Native::ZERO {
244            true if r >= half => d.add_wrapping(I::Native::ONE),
245            false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
246            _ => d,
247        };
248        O::Native::from_decimal(adjusted)
249    };
250
251    // if the reduction of the input number through scaling (dividing) is greater
252    // than a possible precision loss (plus potential increase via rounding)
253    // every input number will fit into the output type
254    // Example: If we are starting with any number of precision 5 [xxxxx],
255    // then and decrease the scale by 3 will have the following effect on the representation:
256    // [xxxxx] -> [xx] (+ 1 possibly, due to rounding).
257    // The rounding may add a digit, so the cast to be infallible,
258    // the output type needs to have at least 3 digits of precision.
259    // e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100:
260    // [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible
261    let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8);
262    let f_infallible = is_infallible_cast.then_some(move |x| f_fallible(x).unwrap());
263    Some((f_fallible, f_infallible))
264}
265
266/// Apply the rescaler function to the value.
267/// If the rescaler is infallible, use the infallible function.
268/// Otherwise, use the fallible function and validate the precision.
269fn apply_rescaler<I: DecimalType, O: DecimalType>(
270    value: I::Native,
271    output_precision: u8,
272    f: impl Fn(I::Native) -> Option<O::Native>,
273    f_infallible: Option<impl Fn(I::Native) -> O::Native>,
274) -> Option<O::Native>
275where
276    I::Native: DecimalCast,
277    O::Native: DecimalCast,
278{
279    if let Some(f_infallible) = f_infallible {
280        Some(f_infallible(value))
281    } else {
282        f(value).filter(|v| O::is_valid_decimal_precision(*v, output_precision))
283    }
284}
285
286/// Rescales a decimal value from `(input_precision, input_scale)` to
287/// `(output_precision, output_scale)` and returns the converted number when it fits
288/// within the output precision.
289///
290/// The function first validates that the requested precision and scale are supported for
291/// both the source and destination decimal types. It then either upscales (multiplying
292/// by an appropriate power of ten) or downscales (dividing with rounding) the input value.
293/// When the scaling factor exceeds the precision table of the destination type, the value
294/// is treated as an overflow for upscaling, or rounded to zero for downscaling (as any
295/// possible result would be zero at the requested scale).
296///
297/// This mirrors the column-oriented helpers of decimal casting but operates on a single value
298/// (row-level) instead of an entire array.
299///
300/// Returns `None` if the value cannot be represented with the requested precision.
301pub fn rescale_decimal<I: DecimalType, O: DecimalType>(
302    value: I::Native,
303    input_precision: u8,
304    input_scale: i8,
305    output_precision: u8,
306    output_scale: i8,
307) -> Option<O::Native>
308where
309    I::Native: DecimalCast + ArrowNativeTypeOp,
310    O::Native: DecimalCast + ArrowNativeTypeOp,
311{
312    validate_decimal_precision_and_scale::<I>(input_precision, input_scale).ok()?;
313    validate_decimal_precision_and_scale::<O>(output_precision, output_scale).ok()?;
314
315    if input_scale <= output_scale {
316        let (f, f_infallible) =
317            make_upscaler::<I, O>(input_precision, input_scale, output_precision, output_scale)?;
318        apply_rescaler::<I, O>(value, output_precision, f, f_infallible)
319    } else {
320        let Some((f, f_infallible)) =
321            make_downscaler::<I, O>(input_precision, input_scale, output_precision, output_scale)
322        else {
323            // Scale reduction exceeds supported precision; result mathematically rounds to zero
324            return Some(O::Native::ZERO);
325        };
326        apply_rescaler::<I, O>(value, output_precision, f, f_infallible)
327    }
328}
329
330fn cast_decimal_to_decimal_error<I, O>(
331    output_precision: u8,
332    output_scale: i8,
333) -> impl Fn(<I as ArrowPrimitiveType>::Native) -> ArrowError
334where
335    I: DecimalType,
336    O: DecimalType,
337    I::Native: DecimalCast + ArrowNativeTypeOp,
338    O::Native: DecimalCast + ArrowNativeTypeOp,
339{
340    move |x: I::Native| {
341        ArrowError::CastError(format!(
342            "Cannot cast to {}({}, {}). Overflowing on {:?}",
343            O::PREFIX,
344            output_precision,
345            output_scale,
346            x
347        ))
348    }
349}
350
351fn apply_decimal_cast<I: DecimalType, O: DecimalType>(
352    array: &PrimitiveArray<I>,
353    output_precision: u8,
354    output_scale: i8,
355    f_fallible: impl Fn(I::Native) -> Option<O::Native>,
356    f_infallible: Option<impl Fn(I::Native) -> O::Native>,
357    cast_options: &CastOptions,
358) -> Result<PrimitiveArray<O>, ArrowError>
359where
360    I::Native: DecimalCast + ArrowNativeTypeOp,
361    O::Native: DecimalCast + ArrowNativeTypeOp,
362{
363    let array = if let Some(f_infallible) = f_infallible {
364        array.unary(f_infallible)
365    } else if cast_options.safe {
366        array.unary_opt(|x| {
367            f_fallible(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision))
368        })
369    } else {
370        let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
371        array.try_unary(|x| {
372            f_fallible(x).ok_or_else(|| error(x)).and_then(|v| {
373                O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v)
374            })
375        })?
376    };
377    Ok(array)
378}
379
380fn convert_to_smaller_scale_decimal<I, O>(
381    array: &PrimitiveArray<I>,
382    input_precision: u8,
383    input_scale: i8,
384    output_precision: u8,
385    output_scale: i8,
386    cast_options: &CastOptions,
387) -> Result<PrimitiveArray<O>, ArrowError>
388where
389    I: DecimalType,
390    O: DecimalType,
391    I::Native: DecimalCast + ArrowNativeTypeOp,
392    O::Native: DecimalCast + ArrowNativeTypeOp,
393{
394    if let Some((f_fallible, f_infallible)) =
395        make_downscaler::<I, O>(input_precision, input_scale, output_precision, output_scale)
396    {
397        apply_decimal_cast(
398            array,
399            output_precision,
400            output_scale,
401            f_fallible,
402            f_infallible,
403            cast_options,
404        )
405    } else {
406        // Scale reduction exceeds supported precision; result mathematically rounds to zero
407        let zeros = vec![O::Native::ZERO; array.len()];
408        Ok(PrimitiveArray::new(zeros.into(), array.nulls().cloned()))
409    }
410}
411
412fn convert_to_bigger_or_equal_scale_decimal<I, O>(
413    array: &PrimitiveArray<I>,
414    input_precision: u8,
415    input_scale: i8,
416    output_precision: u8,
417    output_scale: i8,
418    cast_options: &CastOptions,
419) -> Result<PrimitiveArray<O>, ArrowError>
420where
421    I: DecimalType,
422    O: DecimalType,
423    I::Native: DecimalCast + ArrowNativeTypeOp,
424    O::Native: DecimalCast + ArrowNativeTypeOp,
425{
426    if let Some((f, f_infallible)) =
427        make_upscaler::<I, O>(input_precision, input_scale, output_precision, output_scale)
428    {
429        apply_decimal_cast(
430            array,
431            output_precision,
432            output_scale,
433            f,
434            f_infallible,
435            cast_options,
436        )
437    } else {
438        // Scale increase exceeds supported precision; return overflow error
439        Err(ArrowError::CastError(format!(
440            "Cannot cast to {}({}, {}). Value overflows for output scale",
441            O::PREFIX,
442            output_precision,
443            output_scale
444        )))
445    }
446}
447
448// Only support one type of decimal cast operations
449pub(crate) fn cast_decimal_to_decimal_same_type<T>(
450    array: &PrimitiveArray<T>,
451    input_precision: u8,
452    input_scale: i8,
453    output_precision: u8,
454    output_scale: i8,
455    cast_options: &CastOptions,
456) -> Result<ArrayRef, ArrowError>
457where
458    T: DecimalType,
459    T::Native: DecimalCast + ArrowNativeTypeOp,
460{
461    let array: PrimitiveArray<T> =
462        if input_scale == output_scale && input_precision <= output_precision {
463            array.clone()
464        } else if input_scale <= output_scale {
465            convert_to_bigger_or_equal_scale_decimal::<T, T>(
466                array,
467                input_precision,
468                input_scale,
469                output_precision,
470                output_scale,
471                cast_options,
472            )?
473        } else {
474            // input_scale > output_scale
475            convert_to_smaller_scale_decimal::<T, T>(
476                array,
477                input_precision,
478                input_scale,
479                output_precision,
480                output_scale,
481                cast_options,
482            )?
483        };
484
485    Ok(Arc::new(array.with_precision_and_scale(
486        output_precision,
487        output_scale,
488    )?))
489}
490
491// Support two different types of decimal cast operations
492pub(crate) fn cast_decimal_to_decimal<I, O>(
493    array: &PrimitiveArray<I>,
494    input_precision: u8,
495    input_scale: i8,
496    output_precision: u8,
497    output_scale: i8,
498    cast_options: &CastOptions,
499) -> Result<ArrayRef, ArrowError>
500where
501    I: DecimalType,
502    O: DecimalType,
503    I::Native: DecimalCast + ArrowNativeTypeOp,
504    O::Native: DecimalCast + ArrowNativeTypeOp,
505{
506    let array: PrimitiveArray<O> = if input_scale > output_scale {
507        convert_to_smaller_scale_decimal::<I, O>(
508            array,
509            input_precision,
510            input_scale,
511            output_precision,
512            output_scale,
513            cast_options,
514        )?
515    } else {
516        convert_to_bigger_or_equal_scale_decimal::<I, O>(
517            array,
518            input_precision,
519            input_scale,
520            output_precision,
521            output_scale,
522            cast_options,
523        )?
524    };
525
526    Ok(Arc::new(array.with_precision_and_scale(
527        output_precision,
528        output_scale,
529    )?))
530}
531
532/// Parses given string to specified decimal native (i128/i256) based on given
533/// scale. Returns an `Err` if it cannot parse given string.
534pub fn parse_string_to_decimal_native<T: DecimalType>(
535    value_str: &str,
536    scale: usize,
537) -> Result<T::Native, ArrowError>
538where
539    T::Native: DecimalCast + ArrowNativeTypeOp,
540{
541    let value_str = value_str.trim();
542    let parts: Vec<&str> = value_str.split('.').collect();
543    if parts.len() > 2 {
544        return Err(ArrowError::InvalidArgumentError(format!(
545            "Invalid decimal format: {value_str:?}"
546        )));
547    }
548
549    let (negative, first_part) = if parts[0].is_empty() {
550        (false, parts[0])
551    } else {
552        match parts[0].as_bytes()[0] {
553            b'-' => (true, &parts[0][1..]),
554            b'+' => (false, &parts[0][1..]),
555            _ => (false, parts[0]),
556        }
557    };
558
559    let integers = first_part;
560    let decimals = if parts.len() == 2 { parts[1] } else { "" };
561
562    if integers.is_empty() && decimals.is_empty() {
563        return Err(ArrowError::InvalidArgumentError(format!(
564            "Invalid decimal format: {value_str:?}"
565        )));
566    }
567
568    if !integers.is_empty() && !integers.as_bytes()[0].is_ascii_digit() {
569        return Err(ArrowError::InvalidArgumentError(format!(
570            "Invalid decimal format: {value_str:?}"
571        )));
572    }
573
574    if !decimals.is_empty() && !decimals.as_bytes()[0].is_ascii_digit() {
575        return Err(ArrowError::InvalidArgumentError(format!(
576            "Invalid decimal format: {value_str:?}"
577        )));
578    }
579
580    // Adjust decimal based on scale
581    let mut number_decimals = if decimals.len() > scale {
582        let decimal_number = i256::from_string(decimals).ok_or_else(|| {
583            ArrowError::InvalidArgumentError(format!("Cannot parse decimal format: {value_str}"))
584        })?;
585
586        let div = i256::from_i128(10_i128).pow_checked((decimals.len() - scale) as u32)?;
587
588        let half = div.div_wrapping(i256::from_i128(2));
589        let half_neg = half.neg_wrapping();
590
591        let d = decimal_number.div_wrapping(div);
592        let r = decimal_number.mod_wrapping(div);
593
594        // Round result
595        let adjusted = match decimal_number >= i256::ZERO {
596            true if r >= half => d.add_wrapping(i256::ONE),
597            false if r <= half_neg => d.sub_wrapping(i256::ONE),
598            _ => d,
599        };
600
601        let integers = if !integers.is_empty() {
602            i256::from_string(integers)
603                .ok_or_else(|| {
604                    ArrowError::InvalidArgumentError(format!(
605                        "Cannot parse decimal format: {value_str}"
606                    ))
607                })
608                .map(|v| v.mul_wrapping(i256::from_i128(10_i128).pow_wrapping(scale as u32)))?
609        } else {
610            i256::ZERO
611        };
612
613        format!("{}", integers.add_wrapping(adjusted))
614    } else {
615        let padding = if scale > decimals.len() { scale } else { 0 };
616
617        let decimals = format!("{decimals:0<padding$}");
618        format!("{integers}{decimals}")
619    };
620
621    if negative {
622        number_decimals.insert(0, '-');
623    }
624
625    let value = i256::from_string(number_decimals.as_str()).ok_or_else(|| {
626        ArrowError::InvalidArgumentError(format!(
627            "Cannot convert {} to {}: Overflow",
628            value_str,
629            T::PREFIX
630        ))
631    })?;
632
633    T::Native::from_decimal(value).ok_or_else(|| {
634        ArrowError::InvalidArgumentError(format!("Cannot convert {} to {}", value_str, T::PREFIX))
635    })
636}
637
638pub(crate) fn generic_string_to_decimal_cast<'a, T, S>(
639    from: &'a S,
640    precision: u8,
641    scale: i8,
642    cast_options: &CastOptions,
643) -> Result<PrimitiveArray<T>, ArrowError>
644where
645    T: DecimalType,
646    T::Native: DecimalCast + ArrowNativeTypeOp,
647    &'a S: StringArrayType<'a>,
648{
649    if cast_options.safe {
650        let iter = from.iter().map(|v| {
651            v.and_then(|v| parse_string_to_decimal_native::<T>(v, scale as usize).ok())
652                .and_then(|v| T::is_valid_decimal_precision(v, precision).then_some(v))
653        });
654        // Benefit:
655        //     20% performance improvement
656        // Soundness:
657        //     The iterator is trustedLen because it comes from an `StringArray`.
658        Ok(unsafe {
659            PrimitiveArray::<T>::from_trusted_len_iter(iter)
660                .with_precision_and_scale(precision, scale)?
661        })
662    } else {
663        let vec = from
664            .iter()
665            .map(|v| {
666                v.map(|v| {
667                    parse_string_to_decimal_native::<T>(v, scale as usize)
668                        .map_err(|_| {
669                            ArrowError::CastError(format!(
670                                "Cannot cast string '{v}' to value of {} type",
671                                T::DATA_TYPE,
672                            ))
673                        })
674                        .and_then(|v| T::validate_decimal_precision(v, precision, scale).map(|_| v))
675                })
676                .transpose()
677            })
678            .collect::<Result<Vec<_>, _>>()?;
679        // Benefit:
680        //     20% performance improvement
681        // Soundness:
682        //     The iterator is trustedLen because it comes from an `StringArray`.
683        Ok(unsafe {
684            PrimitiveArray::<T>::from_trusted_len_iter(vec.iter())
685                .with_precision_and_scale(precision, scale)?
686        })
687    }
688}
689
690pub(crate) fn string_to_decimal_cast<T, Offset: OffsetSizeTrait>(
691    from: &GenericStringArray<Offset>,
692    precision: u8,
693    scale: i8,
694    cast_options: &CastOptions,
695) -> Result<PrimitiveArray<T>, ArrowError>
696where
697    T: DecimalType,
698    T::Native: DecimalCast + ArrowNativeTypeOp,
699{
700    generic_string_to_decimal_cast::<T, GenericStringArray<Offset>>(
701        from,
702        precision,
703        scale,
704        cast_options,
705    )
706}
707
708pub(crate) fn string_view_to_decimal_cast<T>(
709    from: &StringViewArray,
710    precision: u8,
711    scale: i8,
712    cast_options: &CastOptions,
713) -> Result<PrimitiveArray<T>, ArrowError>
714where
715    T: DecimalType,
716    T::Native: DecimalCast + ArrowNativeTypeOp,
717{
718    generic_string_to_decimal_cast::<T, StringViewArray>(from, precision, scale, cast_options)
719}
720
721/// Cast Utf8 to decimal
722pub(crate) fn cast_string_to_decimal<T, Offset: OffsetSizeTrait>(
723    from: &dyn Array,
724    precision: u8,
725    scale: i8,
726    cast_options: &CastOptions,
727) -> Result<ArrayRef, ArrowError>
728where
729    T: DecimalType,
730    T::Native: DecimalCast + ArrowNativeTypeOp,
731{
732    if scale < 0 {
733        return Err(ArrowError::InvalidArgumentError(format!(
734            "Cannot cast string to decimal with negative scale {scale}"
735        )));
736    }
737
738    if scale > T::MAX_SCALE {
739        return Err(ArrowError::InvalidArgumentError(format!(
740            "Cannot cast string to decimal greater than maximum scale {}",
741            T::MAX_SCALE
742        )));
743    }
744
745    let result = match from.data_type() {
746        DataType::Utf8View => string_view_to_decimal_cast::<T>(
747            from.as_any().downcast_ref::<StringViewArray>().unwrap(),
748            precision,
749            scale,
750            cast_options,
751        )?,
752        DataType::Utf8 | DataType::LargeUtf8 => string_to_decimal_cast::<T, Offset>(
753            from.as_any()
754                .downcast_ref::<GenericStringArray<Offset>>()
755                .unwrap(),
756            precision,
757            scale,
758            cast_options,
759        )?,
760        other => {
761            return Err(ArrowError::ComputeError(format!(
762                "Cannot cast {other:?} to decimal",
763            )));
764        }
765    };
766
767    Ok(Arc::new(result))
768}
769
770pub(crate) fn cast_floating_point_to_decimal<T: ArrowPrimitiveType, D>(
771    array: &PrimitiveArray<T>,
772    precision: u8,
773    scale: i8,
774    cast_options: &CastOptions,
775) -> Result<ArrayRef, ArrowError>
776where
777    <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
778    D: DecimalType + ArrowPrimitiveType,
779    <D as ArrowPrimitiveType>::Native: DecimalCast,
780{
781    let mul = 10_f64.powi(scale as i32);
782
783    if cast_options.safe {
784        array
785            .unary_opt::<_, D>(|v| {
786                single_float_to_decimal::<D>(v.as_(), mul)
787                    .filter(|v| D::is_valid_decimal_precision(*v, precision))
788            })
789            .with_precision_and_scale(precision, scale)
790            .map(|a| Arc::new(a) as ArrayRef)
791    } else {
792        array
793            .try_unary::<_, D, _>(|v| {
794                single_float_to_decimal::<D>(v.as_(), mul)
795                    .ok_or_else(|| {
796                        ArrowError::CastError(format!(
797                            "Cannot cast to {}({}, {}). Overflowing on {:?}",
798                            D::PREFIX,
799                            precision,
800                            scale,
801                            v
802                        ))
803                    })
804                    .and_then(|v| D::validate_decimal_precision(v, precision, scale).map(|_| v))
805            })?
806            .with_precision_and_scale(precision, scale)
807            .map(|a| Arc::new(a) as ArrayRef)
808    }
809}
810
811/// Cast a single floating point value to a decimal native with the given multiple.
812/// Returns `None` if the value cannot be represented with the requested precision.
813#[inline(always)]
814pub fn single_float_to_decimal<D>(input: f64, mul: f64) -> Option<D::Native>
815where
816    D: DecimalType + ArrowPrimitiveType,
817    <D as ArrowPrimitiveType>::Native: DecimalCast,
818{
819    D::Native::from_f64((mul * input).round())
820}
821
822pub(crate) fn cast_decimal_to_integer<D, T>(
823    array: &dyn Array,
824    base: D::Native,
825    scale: i8,
826    cast_options: &CastOptions,
827) -> Result<ArrayRef, ArrowError>
828where
829    T: ArrowPrimitiveType,
830    <T as ArrowPrimitiveType>::Native: NumCast,
831    D: DecimalType + ArrowPrimitiveType,
832    <D as ArrowPrimitiveType>::Native: ToPrimitive,
833{
834    let array = array.as_primitive::<D>();
835
836    let div: D::Native = base.pow_checked(scale.unsigned_abs() as u32).map_err(|_| {
837        ArrowError::CastError(format!(
838            "Cannot cast to {:?}. The scale {} causes overflow.",
839            D::PREFIX,
840            scale,
841        ))
842    })?;
843
844    let mut value_builder = PrimitiveBuilder::<T>::with_capacity(array.len());
845
846    if scale < 0 {
847        match cast_options.safe {
848            true => {
849                for i in 0..array.len() {
850                    if array.is_null(i) {
851                        value_builder.append_null();
852                    } else {
853                        let v = array
854                            .value(i)
855                            .mul_checked(div)
856                            .ok()
857                            .and_then(<T::Native as NumCast>::from::<D::Native>);
858                        value_builder.append_option(v);
859                    }
860                }
861            }
862            false => {
863                for i in 0..array.len() {
864                    if array.is_null(i) {
865                        value_builder.append_null();
866                    } else {
867                        let v = array.value(i).mul_checked(div)?;
868
869                        let value =
870                            <T::Native as NumCast>::from::<D::Native>(v).ok_or_else(|| {
871                                ArrowError::CastError(format!(
872                                    "value of {:?} is out of range {}",
873                                    v,
874                                    T::DATA_TYPE
875                                ))
876                            })?;
877
878                        value_builder.append_value(value);
879                    }
880                }
881            }
882        }
883    } else {
884        match cast_options.safe {
885            true => {
886                for i in 0..array.len() {
887                    if array.is_null(i) {
888                        value_builder.append_null();
889                    } else {
890                        let v = array
891                            .value(i)
892                            .div_checked(div)
893                            .ok()
894                            .and_then(<T::Native as NumCast>::from::<D::Native>);
895                        value_builder.append_option(v);
896                    }
897                }
898            }
899            false => {
900                for i in 0..array.len() {
901                    if array.is_null(i) {
902                        value_builder.append_null();
903                    } else {
904                        let v = array.value(i).div_checked(div)?;
905
906                        let value =
907                            <T::Native as NumCast>::from::<D::Native>(v).ok_or_else(|| {
908                                ArrowError::CastError(format!(
909                                    "value of {:?} is out of range {}",
910                                    v,
911                                    T::DATA_TYPE
912                                ))
913                            })?;
914
915                        value_builder.append_value(value);
916                    }
917                }
918            }
919        }
920    }
921    Ok(Arc::new(value_builder.finish()))
922}
923
924/// Cast a decimal array to a floating point array.
925///
926/// Conversion is lossy and follows standard floating point semantics. Values
927/// that exceed the representable range become `INFINITY` or `-INFINITY` without
928/// returning an error.
929pub(crate) fn cast_decimal_to_float<D: DecimalType, T: ArrowPrimitiveType, F>(
930    array: &dyn Array,
931    op: F,
932) -> Result<ArrayRef, ArrowError>
933where
934    F: Fn(D::Native) -> T::Native,
935{
936    let array = array.as_primitive::<D>();
937    let array = array.unary::<_, T>(op);
938    Ok(Arc::new(array))
939}
940
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `parse_string_to_decimal_native` on integral and fractional
    /// inputs at scale 0 and scale 5, then on inputs that must be rejected for
    /// both `Decimal128` and `Decimal256`.
    #[test]
    fn test_parse_string_to_decimal_native() -> Result<(), ArrowError> {
        // Shorthand for the Decimal128 parser used by every accepted case.
        let parse128 =
            |text: &str, scale| parse_string_to_decimal_native::<Decimal128Type>(text, scale);

        // Zero at scale 0 and scale 5.
        assert_eq!(parse128("0", 0)?, 0_i128);
        assert_eq!(parse128("0", 5)?, 0_i128);

        // An integral input is padded with zeros up to the requested scale.
        assert_eq!(parse128("123", 0)?, 123_i128);
        assert_eq!(parse128("123", 5)?, 12_300_000_i128);

        // Fewer fractional digits than the scale.
        assert_eq!(parse128("123.45", 0)?, 123_i128);
        assert_eq!(parse128("123.45", 5)?, 12_345_000_i128);

        // More fractional digits than the scale: the expected value at scale 5
        // ends in ...679, i.e. the last kept digit is rounded up.
        assert_eq!(parse128("123.4567891", 0)?, 123_i128);
        assert_eq!(parse128("123.4567891", 5)?, 12_345_679_i128);

        // Inputs without any digit must be rejected by both decimal widths.
        for value in ["", " ", ".", "+", "-", "+.", "-."] {
            let rejected_128 = parse_string_to_decimal_native::<Decimal128Type>(value, 2).is_err();
            assert!(
                rejected_128,
                "expected {value:?} to fail parsing as Decimal128"
            );

            let rejected_256 = parse_string_to_decimal_native::<Decimal256Type>(value, 2).is_err();
            assert!(
                rejected_256,
                "expected {value:?} to fail parsing as Decimal256"
            );
        }

        Ok(())
    }

    /// 123.45 at scale 2 rescaled to scale 5 yields the same value with three
    /// extra fractional digits.
    #[test]
    fn test_rescale_decimal_upscale_within_precision() {
        assert_eq!(
            rescale_decimal::<Decimal128Type, Decimal128Type>(12_345_i128, 5, 2, 8, 5),
            Some(12_345_000_i128)
        );
    }

    /// 1.050 and -1.050 at scale 3 become 1.1 and -1.1 at scale 1.
    #[test]
    fn test_rescale_decimal_downscale_rounds_half_away_from_zero() {
        // Positive case first, then its mirror image.
        for sign in [1_i128, -1_i128] {
            let rescaled =
                rescale_decimal::<Decimal128Type, Decimal128Type>(sign * 1_050_i128, 5, 3, 5, 1);
            assert_eq!(rescaled, Some(sign * 11_i128));
        }
    }

    /// Going from scale 9 down to scale 4 with a small value yields zero.
    #[test]
    fn test_rescale_decimal_downscale_large_delta_returns_zero() {
        assert_eq!(
            rescale_decimal::<Decimal32Type, Decimal32Type>(12_345_i32, 9, 9, 9, 4),
            Some(0_i32)
        );
    }

    /// An upscale that does not fit the requested output yields `None`.
    #[test]
    fn test_rescale_decimal_upscale_overflow_returns_none() {
        assert_eq!(
            rescale_decimal::<Decimal32Type, Decimal32Type>(9_999_i32, 4, 0, 5, 2),
            None
        );
    }

    /// An input precision/scale of 39 is rejected for `Decimal128`.
    #[test]
    fn test_rescale_decimal_invalid_input_precision_scale_returns_none() {
        assert_eq!(
            rescale_decimal::<Decimal128Type, Decimal128Type>(123_i128, 39, 39, 38, 38),
            None
        );
    }

    /// An output precision/scale of 39 is rejected for `Decimal128`.
    #[test]
    fn test_rescale_decimal_invalid_output_precision_scale_returns_none() {
        assert_eq!(
            rescale_decimal::<Decimal128Type, Decimal128Type>(123_i128, 38, 38, 39, 39),
            None
        );
    }
}