arrow_arith/
numeric.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines numeric arithmetic kernels on [`PrimitiveArray`], such as [`add`]
19
20use std::cmp::Ordering;
21use std::fmt::Formatter;
22use std::sync::Arc;
23
24use arrow_array::cast::AsArray;
25use arrow_array::timezone::Tz;
26use arrow_array::types::*;
27use arrow_array::*;
28use arrow_buffer::{ArrowNativeType, IntervalDayTime, IntervalMonthDayNano};
29use arrow_schema::{ArrowError, DataType, IntervalUnit, TimeUnit};
30
31use crate::arity::{binary, try_binary};
32
33/// Perform `lhs + rhs`, returning an error on overflow
34pub fn add(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
35    arithmetic_op(Op::Add, lhs, rhs)
36}
37
38/// Perform `lhs + rhs`, wrapping on overflow for [`DataType::is_integer`]
39pub fn add_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
40    arithmetic_op(Op::AddWrapping, lhs, rhs)
41}
42
43/// Perform `lhs - rhs`, returning an error on overflow
44pub fn sub(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
45    arithmetic_op(Op::Sub, lhs, rhs)
46}
47
48/// Perform `lhs - rhs`, wrapping on overflow for [`DataType::is_integer`]
49pub fn sub_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
50    arithmetic_op(Op::SubWrapping, lhs, rhs)
51}
52
53/// Perform `lhs * rhs`, returning an error on overflow
54pub fn mul(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
55    arithmetic_op(Op::Mul, lhs, rhs)
56}
57
58/// Perform `lhs * rhs`, wrapping on overflow for [`DataType::is_integer`]
59pub fn mul_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
60    arithmetic_op(Op::MulWrapping, lhs, rhs)
61}
62
63/// Perform `lhs / rhs`
64///
65/// Overflow or division by zero will result in an error, with exception to
66/// floating point numbers, which instead follow the IEEE 754 rules
67pub fn div(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
68    arithmetic_op(Op::Div, lhs, rhs)
69}
70
71/// Perform `lhs % rhs`
72///
73/// Division by zero will result in an error, with exception to
74/// floating point numbers, which instead follow the IEEE 754 rules
75///
76/// `signed_integer::MIN % -1` will not result in an error but return 0
77pub fn rem(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
78    arithmetic_op(Op::Rem, lhs, rhs)
79}
80
/// Checked negation of every element of `$a`, viewed as a `PrimitiveArray<$t>`
///
/// Any overflow error is propagated with `?`, so this may only be expanded inside a
/// function returning `Result<_, ArrowError>`.
macro_rules! neg_checked {
    ($t:ty, $a:ident) => {{
        let input = $a.as_primitive::<$t>();
        let negated = input.try_unary::<_, $t, _>(|v| v.neg_checked())?;
        Ok(Arc::new(negated))
    }};
}
89
/// Wrapping (infallible) negation of every element of `$a`, viewed as a `PrimitiveArray<$t>`
macro_rules! neg_wrapping {
    ($t:ty, $a:ident) => {{
        let input = $a.as_primitive::<$t>();
        let negated = input.unary::<_, $t>(|v| v.neg_wrapping());
        Ok(Arc::new(negated))
    }};
}
96
97/// Negates each element of  `array`, returning an error on overflow
98///
99/// Note: negation of unsigned arrays is not supported and will return in an error,
100/// for wrapping unsigned negation consider using [`neg_wrapping`][neg_wrapping()]
101pub fn neg(array: &dyn Array) -> Result<ArrayRef, ArrowError> {
102    use DataType::*;
103    use IntervalUnit::*;
104    use TimeUnit::*;
105
106    match array.data_type() {
107        Int8 => neg_checked!(Int8Type, array),
108        Int16 => neg_checked!(Int16Type, array),
109        Int32 => neg_checked!(Int32Type, array),
110        Int64 => neg_checked!(Int64Type, array),
111        Float16 => neg_wrapping!(Float16Type, array),
112        Float32 => neg_wrapping!(Float32Type, array),
113        Float64 => neg_wrapping!(Float64Type, array),
114        Decimal128(p, s) => {
115            let a = array
116                .as_primitive::<Decimal128Type>()
117                .try_unary::<_, Decimal128Type, _>(|x| x.neg_checked())?;
118
119            Ok(Arc::new(a.with_precision_and_scale(*p, *s)?))
120        }
121        Decimal256(p, s) => {
122            let a = array
123                .as_primitive::<Decimal256Type>()
124                .try_unary::<_, Decimal256Type, _>(|x| x.neg_checked())?;
125
126            Ok(Arc::new(a.with_precision_and_scale(*p, *s)?))
127        }
128        Duration(Second) => neg_checked!(DurationSecondType, array),
129        Duration(Millisecond) => neg_checked!(DurationMillisecondType, array),
130        Duration(Microsecond) => neg_checked!(DurationMicrosecondType, array),
131        Duration(Nanosecond) => neg_checked!(DurationNanosecondType, array),
132        Interval(YearMonth) => neg_checked!(IntervalYearMonthType, array),
133        Interval(DayTime) => {
134            let a = array
135                .as_primitive::<IntervalDayTimeType>()
136                .try_unary::<_, IntervalDayTimeType, ArrowError>(|x| {
137                    let (days, ms) = IntervalDayTimeType::to_parts(x);
138                    Ok(IntervalDayTimeType::make_value(
139                        days.neg_checked()?,
140                        ms.neg_checked()?,
141                    ))
142                })?;
143            Ok(Arc::new(a))
144        }
145        Interval(MonthDayNano) => {
146            let a = array
147                .as_primitive::<IntervalMonthDayNanoType>()
148                .try_unary::<_, IntervalMonthDayNanoType, ArrowError>(|x| {
149                    let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(x);
150                    Ok(IntervalMonthDayNanoType::make_value(
151                        months.neg_checked()?,
152                        days.neg_checked()?,
153                        nanos.neg_checked()?,
154                    ))
155                })?;
156            Ok(Arc::new(a))
157        }
158        t => Err(ArrowError::InvalidArgumentError(format!(
159            "Invalid arithmetic operation: !{t}"
160        ))),
161    }
162}
163
164/// Negates each element of  `array`, wrapping on overflow for [`DataType::is_integer`]
165pub fn neg_wrapping(array: &dyn Array) -> Result<ArrayRef, ArrowError> {
166    downcast_integer! {
167        array.data_type() => (neg_wrapping, array),
168        _ => neg(array),
169    }
170}
171
/// The arithmetic operations understood by the kernels in this module
///
/// Passing the operation around as a value lets every type-specific kernel
/// share the same type dispatch logic.
#[derive(Clone, Copy, Debug)]
enum Op {
    AddWrapping,
    Add,
    SubWrapping,
    Sub,
    MulWrapping,
    Mul,
    Div,
    Rem,
}

impl std::fmt::Display for Op {
    /// Writes the infix symbol of the operation; checked and wrapping variants
    /// share the same symbol.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let symbol = match self {
            Self::Add | Self::AddWrapping => "+",
            Self::Sub | Self::SubWrapping => "-",
            Self::Mul | Self::MulWrapping => "*",
            Self::Div => "/",
            Self::Rem => "%",
        };
        f.write_str(symbol)
    }
}

impl Op {
    /// Returns `true` if the dispatcher may swap the operands of this operation
    ///
    /// Only addition is treated as commutative.
    fn commutative(&self) -> bool {
        match self {
            Self::Add | Self::AddWrapping => true,
            Self::Sub
            | Self::SubWrapping
            | Self::Mul
            | Self::MulWrapping
            | Self::Div
            | Self::Rem => false,
        }
    }
}
204
/// Dispatch the given `op` to the appropriate specialized kernel
///
/// Each datum is unpacked into an array plus a flag that is `true` when the datum
/// is a scalar. The pair of data types then selects the kernel:
///
/// * the same integer type on both sides -> `integer_op` (which downcasts both
///   operands to the same primitive type)
/// * matching float types -> `float_op`
/// * a timestamp on the left (any right-hand type) -> `timestamp_op`
/// * matching duration or interval units -> `duration_op` / `interval_op`
/// * a date on the left (any right-hand type) -> `date_op`
/// * decimals of the same width -> `decimal_op`
///
/// If nothing matches, `op` is commutative, the left side is a duration or
/// interval and the right side is a date or timestamp, the operands are swapped
/// and dispatch is retried. Any other combination returns
/// [`ArrowError::InvalidArgumentError`].
fn arithmetic_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
    use DataType::*;
    use IntervalUnit::*;
    use TimeUnit::*;

    // Adapts the argument shape produced by `downcast_integer!` to `integer_op`
    macro_rules! integer_helper {
        ($t:ty, $op:ident, $l:ident, $l_scalar:ident, $r:ident, $r_scalar:ident) => {
            integer_op::<$t>($op, $l, $l_scalar, $r, $r_scalar)
        };
    }

    let (l, l_scalar) = lhs.get();
    let (r, r_scalar) = rhs.get();
    downcast_integer! {
        l.data_type(), r.data_type() => (integer_helper, op, l, l_scalar, r, r_scalar),
        (Float16, Float16) => float_op::<Float16Type>(op, l, l_scalar, r, r_scalar),
        (Float32, Float32) => float_op::<Float32Type>(op, l, l_scalar, r, r_scalar),
        (Float64, Float64) => float_op::<Float64Type>(op, l, l_scalar, r, r_scalar),
        // Timestamps accept several right-hand types; `timestamp_op` validates them
        (Timestamp(Second, _), _) => timestamp_op::<TimestampSecondType>(op, l, l_scalar, r, r_scalar),
        (Timestamp(Millisecond, _), _) => timestamp_op::<TimestampMillisecondType>(op, l, l_scalar, r, r_scalar),
        (Timestamp(Microsecond, _), _) => timestamp_op::<TimestampMicrosecondType>(op, l, l_scalar, r, r_scalar),
        (Timestamp(Nanosecond, _), _) => timestamp_op::<TimestampNanosecondType>(op, l, l_scalar, r, r_scalar),
        (Duration(Second), Duration(Second)) => duration_op::<DurationSecondType>(op, l, l_scalar, r, r_scalar),
        (Duration(Millisecond), Duration(Millisecond)) => duration_op::<DurationMillisecondType>(op, l, l_scalar, r, r_scalar),
        (Duration(Microsecond), Duration(Microsecond)) => duration_op::<DurationMicrosecondType>(op, l, l_scalar, r, r_scalar),
        (Duration(Nanosecond), Duration(Nanosecond)) => duration_op::<DurationNanosecondType>(op, l, l_scalar, r, r_scalar),
        (Interval(YearMonth), Interval(YearMonth)) => interval_op::<IntervalYearMonthType>(op, l, l_scalar, r, r_scalar),
        (Interval(DayTime), Interval(DayTime)) => interval_op::<IntervalDayTimeType>(op, l, l_scalar, r, r_scalar),
        (Interval(MonthDayNano), Interval(MonthDayNano)) => interval_op::<IntervalMonthDayNanoType>(op, l, l_scalar, r, r_scalar),
        // Dates also accept several right-hand types; `date_op` validates them
        (Date32, _) => date_op::<Date32Type>(op, l, l_scalar, r, r_scalar),
        (Date64, _) => date_op::<Date64Type>(op, l, l_scalar, r, r_scalar),
        (Decimal128(_, _), Decimal128(_, _)) => decimal_op::<Decimal128Type>(op, l, l_scalar, r, r_scalar),
        (Decimal256(_, _), Decimal256(_, _)) => decimal_op::<Decimal256Type>(op, l, l_scalar, r, r_scalar),
        (l_t, r_t) => match (l_t, r_t) {
            // e.g. `interval + timestamp` is evaluated as `timestamp + interval`
            (Duration(_) | Interval(_), Date32 | Date64 | Timestamp(_, _)) if op.commutative() => {
                arithmetic_op(op, rhs, lhs)
            }
            _ => Err(ArrowError::InvalidArgumentError(
              format!("Invalid arithmetic operation: {l_t} {op} {r_t}")
            ))
        }
    }
}
249
/// Perform an infallible binary operation on potentially scalar inputs
///
/// `$l_s` / `$r_s` flag whether `$l` / `$r` is a scalar (a single-element array).
/// `$op` is an expression over the identifiers `$l` and `$r`, which are rebound
/// to the element values before `$op` is evaluated. When one side is a scalar
/// containing a null, the result is an all-null array as long as the other side.
macro_rules! op {
    ($l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => {
        match ($l_s, $r_s) {
            // Scalar on the left: broadcast its value across `$r`
            (true, false) => {
                if $l.null_count() == 0 {
                    let $l = $l.value(0);
                    $r.unary(|$r| $op)
                } else {
                    PrimitiveArray::new_null($r.len())
                }
            }
            // Scalar on the right: broadcast its value across `$l`
            (false, true) => {
                if $r.null_count() == 0 {
                    let $r = $r.value(0);
                    $l.unary(|$l| $op)
                } else {
                    PrimitiveArray::new_null($l.len())
                }
            }
            // Scalar with scalar, or array with array: element-wise kernel
            _ => binary($l, $r, |$l, $r| $op)?,
        }
    };
}
266
/// Same as `op` but with a type hint for the returned array
///
/// Evaluates to an `Arc<PrimitiveArray<$t>>`, ready to be returned as an `ArrayRef`.
macro_rules! op_ref {
    ($t:ty, $l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => {
        Arc::<PrimitiveArray<$t>>::new(op!($l, $l_s, $r, $r_s, $op))
    };
}
274
/// Perform a fallible binary operation on potentially scalar inputs
///
/// Like `op`, but `$op` evaluates to a `Result` and the first error is
/// propagated with `?`. When one side is a scalar containing a null, the
/// result is an all-null array as long as the other side and `$op` is never
/// evaluated.
macro_rules! try_op {
    ($l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => {
        match ($l_s, $r_s) {
            // Scalar on the left: broadcast its value across `$r`
            (true, false) => {
                if $l.null_count() == 0 {
                    let $l = $l.value(0);
                    $r.try_unary(|$r| $op)?
                } else {
                    PrimitiveArray::new_null($r.len())
                }
            }
            // Scalar on the right: broadcast its value across `$l`
            (false, true) => {
                if $r.null_count() == 0 {
                    let $r = $r.value(0);
                    $l.try_unary(|$l| $op)?
                } else {
                    PrimitiveArray::new_null($l.len())
                }
            }
            // Scalar with scalar, or array with array: element-wise kernel
            _ => try_binary($l, $r, |$l, $r| $op)?,
        }
    };
}
291
/// Same as `try_op` but with a type hint for the returned array
///
/// Evaluates to an `Arc<PrimitiveArray<$t>>`, ready to be returned as an `ArrayRef`.
macro_rules! try_op_ref {
    ($t:ty, $l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => {
        Arc::<PrimitiveArray<$t>>::new(try_op!($l, $l_s, $r, $r_s, $op))
    };
}
299
300/// Perform an arithmetic operation on integers
301fn integer_op<T: ArrowPrimitiveType>(
302    op: Op,
303    l: &dyn Array,
304    l_s: bool,
305    r: &dyn Array,
306    r_s: bool,
307) -> Result<ArrayRef, ArrowError> {
308    let l = l.as_primitive::<T>();
309    let r = r.as_primitive::<T>();
310    let array: PrimitiveArray<T> = match op {
311        Op::AddWrapping => op!(l, l_s, r, r_s, l.add_wrapping(r)),
312        Op::Add => try_op!(l, l_s, r, r_s, l.add_checked(r)),
313        Op::SubWrapping => op!(l, l_s, r, r_s, l.sub_wrapping(r)),
314        Op::Sub => try_op!(l, l_s, r, r_s, l.sub_checked(r)),
315        Op::MulWrapping => op!(l, l_s, r, r_s, l.mul_wrapping(r)),
316        Op::Mul => try_op!(l, l_s, r, r_s, l.mul_checked(r)),
317        Op::Div => try_op!(l, l_s, r, r_s, l.div_checked(r)),
318        Op::Rem => try_op!(l, l_s, r, r_s, {
319            if r.is_zero() {
320                Err(ArrowError::DivideByZero)
321            } else {
322                Ok(l.mod_wrapping(r))
323            }
324        }),
325    };
326    Ok(Arc::new(array))
327}
328
329/// Perform an arithmetic operation on floats
330fn float_op<T: ArrowPrimitiveType>(
331    op: Op,
332    l: &dyn Array,
333    l_s: bool,
334    r: &dyn Array,
335    r_s: bool,
336) -> Result<ArrayRef, ArrowError> {
337    let l = l.as_primitive::<T>();
338    let r = r.as_primitive::<T>();
339    let array: PrimitiveArray<T> = match op {
340        Op::AddWrapping | Op::Add => op!(l, l_s, r, r_s, l.add_wrapping(r)),
341        Op::SubWrapping | Op::Sub => op!(l, l_s, r, r_s, l.sub_wrapping(r)),
342        Op::MulWrapping | Op::Mul => op!(l, l_s, r, r_s, l.mul_wrapping(r)),
343        Op::Div => op!(l, l_s, r, r_s, l.div_wrapping(r)),
344        Op::Rem => op!(l, l_s, r, r_s, l.mod_wrapping(r)),
345    };
346    Ok(Arc::new(array))
347}
348
349/// Arithmetic trait for timestamp arrays
350trait TimestampOp: ArrowTimestampType {
351    type Duration: ArrowPrimitiveType<Native = i64>;
352
353    fn add_year_month(timestamp: i64, delta: i32, tz: Tz) -> Option<i64>;
354    fn add_day_time(timestamp: i64, delta: IntervalDayTime, tz: Tz) -> Option<i64>;
355    fn add_month_day_nano(timestamp: i64, delta: IntervalMonthDayNano, tz: Tz) -> Option<i64>;
356
357    fn sub_year_month(timestamp: i64, delta: i32, tz: Tz) -> Option<i64>;
358    fn sub_day_time(timestamp: i64, delta: IntervalDayTime, tz: Tz) -> Option<i64>;
359    fn sub_month_day_nano(timestamp: i64, delta: IntervalMonthDayNano, tz: Tz) -> Option<i64>;
360}
361
362macro_rules! timestamp {
363    ($t:ty, $d:ty) => {
364        impl TimestampOp for $t {
365            type Duration = $d;
366
367            fn add_year_month(left: i64, right: i32, tz: Tz) -> Option<i64> {
368                Self::add_year_months(left, right, tz)
369            }
370
371            fn add_day_time(left: i64, right: IntervalDayTime, tz: Tz) -> Option<i64> {
372                Self::add_day_time(left, right, tz)
373            }
374
375            fn add_month_day_nano(left: i64, right: IntervalMonthDayNano, tz: Tz) -> Option<i64> {
376                Self::add_month_day_nano(left, right, tz)
377            }
378
379            fn sub_year_month(left: i64, right: i32, tz: Tz) -> Option<i64> {
380                Self::subtract_year_months(left, right, tz)
381            }
382
383            fn sub_day_time(left: i64, right: IntervalDayTime, tz: Tz) -> Option<i64> {
384                Self::subtract_day_time(left, right, tz)
385            }
386
387            fn sub_month_day_nano(left: i64, right: IntervalMonthDayNano, tz: Tz) -> Option<i64> {
388                Self::subtract_month_day_nano(left, right, tz)
389            }
390        }
391    };
392}
393timestamp!(TimestampSecondType, DurationSecondType);
394timestamp!(TimestampMillisecondType, DurationMillisecondType);
395timestamp!(TimestampMicrosecondType, DurationMicrosecondType);
396timestamp!(TimestampNanosecondType, DurationNanosecondType);
397
398/// Perform arithmetic operation on a timestamp array
399fn timestamp_op<T: TimestampOp>(
400    op: Op,
401    l: &dyn Array,
402    l_s: bool,
403    r: &dyn Array,
404    r_s: bool,
405) -> Result<ArrayRef, ArrowError> {
406    use DataType::*;
407    use IntervalUnit::*;
408
409    let l = l.as_primitive::<T>();
410    let l_tz: Tz = l.timezone().unwrap_or("+00:00").parse()?;
411
412    let array: PrimitiveArray<T> = match (op, r.data_type()) {
413        (Op::Sub | Op::SubWrapping, Timestamp(unit, _)) if unit == &T::UNIT => {
414            let r = r.as_primitive::<T>();
415            return Ok(try_op_ref!(T::Duration, l, l_s, r, r_s, l.sub_checked(r)));
416        }
417
418        (Op::Add | Op::AddWrapping, Duration(unit)) if unit == &T::UNIT => {
419            let r = r.as_primitive::<T::Duration>();
420            try_op!(l, l_s, r, r_s, l.add_checked(r))
421        }
422        (Op::Sub | Op::SubWrapping, Duration(unit)) if unit == &T::UNIT => {
423            let r = r.as_primitive::<T::Duration>();
424            try_op!(l, l_s, r, r_s, l.sub_checked(r))
425        }
426
427        (Op::Add | Op::AddWrapping, Interval(YearMonth)) => {
428            let r = r.as_primitive::<IntervalYearMonthType>();
429            try_op!(
430                l,
431                l_s,
432                r,
433                r_s,
434                T::add_year_month(l, r, l_tz).ok_or(ArrowError::ComputeError(
435                    "Timestamp out of range".to_string()
436                ))
437            )
438        }
439        (Op::Sub | Op::SubWrapping, Interval(YearMonth)) => {
440            let r = r.as_primitive::<IntervalYearMonthType>();
441            try_op!(
442                l,
443                l_s,
444                r,
445                r_s,
446                T::sub_year_month(l, r, l_tz).ok_or(ArrowError::ComputeError(
447                    "Timestamp out of range".to_string()
448                ))
449            )
450        }
451
452        (Op::Add | Op::AddWrapping, Interval(DayTime)) => {
453            let r = r.as_primitive::<IntervalDayTimeType>();
454            try_op!(
455                l,
456                l_s,
457                r,
458                r_s,
459                T::add_day_time(l, r, l_tz).ok_or(ArrowError::ComputeError(
460                    "Timestamp out of range".to_string()
461                ))
462            )
463        }
464        (Op::Sub | Op::SubWrapping, Interval(DayTime)) => {
465            let r = r.as_primitive::<IntervalDayTimeType>();
466            try_op!(
467                l,
468                l_s,
469                r,
470                r_s,
471                T::sub_day_time(l, r, l_tz).ok_or(ArrowError::ComputeError(
472                    "Timestamp out of range".to_string()
473                ))
474            )
475        }
476
477        (Op::Add | Op::AddWrapping, Interval(MonthDayNano)) => {
478            let r = r.as_primitive::<IntervalMonthDayNanoType>();
479            try_op!(
480                l,
481                l_s,
482                r,
483                r_s,
484                T::add_month_day_nano(l, r, l_tz).ok_or(ArrowError::ComputeError(
485                    "Timestamp out of range".to_string()
486                ))
487            )
488        }
489        (Op::Sub | Op::SubWrapping, Interval(MonthDayNano)) => {
490            let r = r.as_primitive::<IntervalMonthDayNanoType>();
491            try_op!(
492                l,
493                l_s,
494                r,
495                r_s,
496                T::sub_month_day_nano(l, r, l_tz).ok_or(ArrowError::ComputeError(
497                    "Timestamp out of range".to_string()
498                ))
499            )
500        }
501        _ => {
502            return Err(ArrowError::InvalidArgumentError(format!(
503                "Invalid timestamp arithmetic operation: {} {op} {}",
504                l.data_type(),
505                r.data_type()
506            )))
507        }
508    };
509    Ok(Arc::new(array.with_timezone_opt(l.timezone())))
510}
511
512/// Arithmetic trait for date arrays
513///
514/// Note: these should be fallible (#4456)
515trait DateOp: ArrowTemporalType {
516    fn add_year_month(timestamp: Self::Native, delta: i32) -> Self::Native;
517    fn add_day_time(timestamp: Self::Native, delta: IntervalDayTime) -> Self::Native;
518    fn add_month_day_nano(timestamp: Self::Native, delta: IntervalMonthDayNano) -> Self::Native;
519
520    fn sub_year_month(timestamp: Self::Native, delta: i32) -> Self::Native;
521    fn sub_day_time(timestamp: Self::Native, delta: IntervalDayTime) -> Self::Native;
522    fn sub_month_day_nano(timestamp: Self::Native, delta: IntervalMonthDayNano) -> Self::Native;
523}
524
525macro_rules! date {
526    ($t:ty) => {
527        impl DateOp for $t {
528            fn add_year_month(left: Self::Native, right: i32) -> Self::Native {
529                Self::add_year_months(left, right)
530            }
531
532            fn add_day_time(left: Self::Native, right: IntervalDayTime) -> Self::Native {
533                Self::add_day_time(left, right)
534            }
535
536            fn add_month_day_nano(left: Self::Native, right: IntervalMonthDayNano) -> Self::Native {
537                Self::add_month_day_nano(left, right)
538            }
539
540            fn sub_year_month(left: Self::Native, right: i32) -> Self::Native {
541                Self::subtract_year_months(left, right)
542            }
543
544            fn sub_day_time(left: Self::Native, right: IntervalDayTime) -> Self::Native {
545                Self::subtract_day_time(left, right)
546            }
547
548            fn sub_month_day_nano(left: Self::Native, right: IntervalMonthDayNano) -> Self::Native {
549                Self::subtract_month_day_nano(left, right)
550            }
551        }
552    };
553}
554date!(Date32Type);
555date!(Date64Type);
556
557/// Arithmetic trait for interval arrays
558trait IntervalOp: ArrowPrimitiveType {
559    fn add(left: Self::Native, right: Self::Native) -> Result<Self::Native, ArrowError>;
560    fn sub(left: Self::Native, right: Self::Native) -> Result<Self::Native, ArrowError>;
561}
562
563impl IntervalOp for IntervalYearMonthType {
564    fn add(left: Self::Native, right: Self::Native) -> Result<Self::Native, ArrowError> {
565        left.add_checked(right)
566    }
567
568    fn sub(left: Self::Native, right: Self::Native) -> Result<Self::Native, ArrowError> {
569        left.sub_checked(right)
570    }
571}
572
573impl IntervalOp for IntervalDayTimeType {
574    fn add(left: Self::Native, right: Self::Native) -> Result<Self::Native, ArrowError> {
575        let (l_days, l_ms) = Self::to_parts(left);
576        let (r_days, r_ms) = Self::to_parts(right);
577        let days = l_days.add_checked(r_days)?;
578        let ms = l_ms.add_checked(r_ms)?;
579        Ok(Self::make_value(days, ms))
580    }
581
582    fn sub(left: Self::Native, right: Self::Native) -> Result<Self::Native, ArrowError> {
583        let (l_days, l_ms) = Self::to_parts(left);
584        let (r_days, r_ms) = Self::to_parts(right);
585        let days = l_days.sub_checked(r_days)?;
586        let ms = l_ms.sub_checked(r_ms)?;
587        Ok(Self::make_value(days, ms))
588    }
589}
590
591impl IntervalOp for IntervalMonthDayNanoType {
592    fn add(left: Self::Native, right: Self::Native) -> Result<Self::Native, ArrowError> {
593        let (l_months, l_days, l_nanos) = Self::to_parts(left);
594        let (r_months, r_days, r_nanos) = Self::to_parts(right);
595        let months = l_months.add_checked(r_months)?;
596        let days = l_days.add_checked(r_days)?;
597        let nanos = l_nanos.add_checked(r_nanos)?;
598        Ok(Self::make_value(months, days, nanos))
599    }
600
601    fn sub(left: Self::Native, right: Self::Native) -> Result<Self::Native, ArrowError> {
602        let (l_months, l_days, l_nanos) = Self::to_parts(left);
603        let (r_months, r_days, r_nanos) = Self::to_parts(right);
604        let months = l_months.sub_checked(r_months)?;
605        let days = l_days.sub_checked(r_days)?;
606        let nanos = l_nanos.sub_checked(r_nanos)?;
607        Ok(Self::make_value(months, days, nanos))
608    }
609}
610
611/// Perform arithmetic operation on an interval array
612fn interval_op<T: IntervalOp>(
613    op: Op,
614    l: &dyn Array,
615    l_s: bool,
616    r: &dyn Array,
617    r_s: bool,
618) -> Result<ArrayRef, ArrowError> {
619    let l = l.as_primitive::<T>();
620    let r = r.as_primitive::<T>();
621    match op {
622        Op::Add | Op::AddWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, T::add(l, r))),
623        Op::Sub | Op::SubWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, T::sub(l, r))),
624        _ => Err(ArrowError::InvalidArgumentError(format!(
625            "Invalid interval arithmetic operation: {} {op} {}",
626            l.data_type(),
627            r.data_type()
628        ))),
629    }
630}
631
632fn duration_op<T: ArrowPrimitiveType>(
633    op: Op,
634    l: &dyn Array,
635    l_s: bool,
636    r: &dyn Array,
637    r_s: bool,
638) -> Result<ArrayRef, ArrowError> {
639    let l = l.as_primitive::<T>();
640    let r = r.as_primitive::<T>();
641    match op {
642        Op::Add | Op::AddWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, l.add_checked(r))),
643        Op::Sub | Op::SubWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, l.sub_checked(r))),
644        _ => Err(ArrowError::InvalidArgumentError(format!(
645            "Invalid duration arithmetic operation: {} {op} {}",
646            l.data_type(),
647            r.data_type()
648        ))),
649    }
650}
651
652/// Perform arithmetic operation on a date array
653fn date_op<T: DateOp>(
654    op: Op,
655    l: &dyn Array,
656    l_s: bool,
657    r: &dyn Array,
658    r_s: bool,
659) -> Result<ArrayRef, ArrowError> {
660    use DataType::*;
661    use IntervalUnit::*;
662
663    const NUM_SECONDS_IN_DAY: i64 = 60 * 60 * 24;
664
665    let r_t = r.data_type();
666    match (T::DATA_TYPE, op, r_t) {
667        (Date32, Op::Sub | Op::SubWrapping, Date32) => {
668            let l = l.as_primitive::<Date32Type>();
669            let r = r.as_primitive::<Date32Type>();
670            return Ok(op_ref!(
671                DurationSecondType,
672                l,
673                l_s,
674                r,
675                r_s,
676                ((l as i64) - (r as i64)) * NUM_SECONDS_IN_DAY
677            ));
678        }
679        (Date64, Op::Sub | Op::SubWrapping, Date64) => {
680            let l = l.as_primitive::<Date64Type>();
681            let r = r.as_primitive::<Date64Type>();
682            let result = try_op_ref!(DurationMillisecondType, l, l_s, r, r_s, l.sub_checked(r));
683            return Ok(result);
684        }
685        _ => {}
686    }
687
688    let l = l.as_primitive::<T>();
689    match (op, r_t) {
690        (Op::Add | Op::AddWrapping, Interval(YearMonth)) => {
691            let r = r.as_primitive::<IntervalYearMonthType>();
692            Ok(op_ref!(T, l, l_s, r, r_s, T::add_year_month(l, r)))
693        }
694        (Op::Sub | Op::SubWrapping, Interval(YearMonth)) => {
695            let r = r.as_primitive::<IntervalYearMonthType>();
696            Ok(op_ref!(T, l, l_s, r, r_s, T::sub_year_month(l, r)))
697        }
698
699        (Op::Add | Op::AddWrapping, Interval(DayTime)) => {
700            let r = r.as_primitive::<IntervalDayTimeType>();
701            Ok(op_ref!(T, l, l_s, r, r_s, T::add_day_time(l, r)))
702        }
703        (Op::Sub | Op::SubWrapping, Interval(DayTime)) => {
704            let r = r.as_primitive::<IntervalDayTimeType>();
705            Ok(op_ref!(T, l, l_s, r, r_s, T::sub_day_time(l, r)))
706        }
707
708        (Op::Add | Op::AddWrapping, Interval(MonthDayNano)) => {
709            let r = r.as_primitive::<IntervalMonthDayNanoType>();
710            Ok(op_ref!(T, l, l_s, r, r_s, T::add_month_day_nano(l, r)))
711        }
712        (Op::Sub | Op::SubWrapping, Interval(MonthDayNano)) => {
713            let r = r.as_primitive::<IntervalMonthDayNanoType>();
714            Ok(op_ref!(T, l, l_s, r, r_s, T::sub_month_day_nano(l, r)))
715        }
716
717        _ => Err(ArrowError::InvalidArgumentError(format!(
718            "Invalid date arithmetic operation: {} {op} {}",
719            l.data_type(),
720            r.data_type()
721        ))),
722    }
723}
724
725/// Perform arithmetic operation on decimal arrays
726fn decimal_op<T: DecimalType>(
727    op: Op,
728    l: &dyn Array,
729    l_s: bool,
730    r: &dyn Array,
731    r_s: bool,
732) -> Result<ArrayRef, ArrowError> {
733    let l = l.as_primitive::<T>();
734    let r = r.as_primitive::<T>();
735
736    let (p1, s1, p2, s2) = match (l.data_type(), r.data_type()) {
737        (DataType::Decimal128(p1, s1), DataType::Decimal128(p2, s2)) => (p1, s1, p2, s2),
738        (DataType::Decimal256(p1, s1), DataType::Decimal256(p2, s2)) => (p1, s1, p2, s2),
739        _ => unreachable!(),
740    };
741
742    // Follow the Hive decimal arithmetic rules
743    // https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf
744    let array: PrimitiveArray<T> = match op {
745        Op::Add | Op::AddWrapping | Op::Sub | Op::SubWrapping => {
746            // max(s1, s2)
747            let result_scale = *s1.max(s2);
748
749            // max(s1, s2) + max(p1-s1, p2-s2) + 1
750            let result_precision =
751                (result_scale.saturating_add((*p1 as i8 - s1).max(*p2 as i8 - s2)) as u8)
752                    .saturating_add(1)
753                    .min(T::MAX_PRECISION);
754
755            let l_mul = T::Native::usize_as(10).pow_checked((result_scale - s1) as _)?;
756            let r_mul = T::Native::usize_as(10).pow_checked((result_scale - s2) as _)?;
757
758            match op {
759                Op::Add | Op::AddWrapping => {
760                    try_op!(
761                        l,
762                        l_s,
763                        r,
764                        r_s,
765                        l.mul_checked(l_mul)?.add_checked(r.mul_checked(r_mul)?)
766                    )
767                }
768                Op::Sub | Op::SubWrapping => {
769                    try_op!(
770                        l,
771                        l_s,
772                        r,
773                        r_s,
774                        l.mul_checked(l_mul)?.sub_checked(r.mul_checked(r_mul)?)
775                    )
776                }
777                _ => unreachable!(),
778            }
779            .with_precision_and_scale(result_precision, result_scale)?
780        }
781        Op::Mul | Op::MulWrapping => {
782            let result_precision = p1.saturating_add(p2 + 1).min(T::MAX_PRECISION);
783            let result_scale = s1.saturating_add(*s2);
784            if result_scale > T::MAX_SCALE {
785                // SQL standard says that if the resulting scale of a multiply operation goes
786                // beyond the maximum, rounding is not acceptable and thus an error occurs
787                return Err(ArrowError::InvalidArgumentError(format!(
788                    "Output scale of {} {op} {} would exceed max scale of {}",
789                    l.data_type(),
790                    r.data_type(),
791                    T::MAX_SCALE
792                )));
793            }
794
795            try_op!(l, l_s, r, r_s, l.mul_checked(r))
796                .with_precision_and_scale(result_precision, result_scale)?
797        }
798
799        Op::Div => {
800            // Follow postgres and MySQL adding a fixed scale increment of 4
801            // s1 + 4
802            let result_scale = s1.saturating_add(4).min(T::MAX_SCALE);
803            let mul_pow = result_scale - s1 + s2;
804
805            // p1 - s1 + s2 + result_scale
806            let result_precision = (mul_pow.saturating_add(*p1 as i8) as u8).min(T::MAX_PRECISION);
807
808            let (l_mul, r_mul) = match mul_pow.cmp(&0) {
809                Ordering::Greater => (
810                    T::Native::usize_as(10).pow_checked(mul_pow as _)?,
811                    T::Native::ONE,
812                ),
813                Ordering::Equal => (T::Native::ONE, T::Native::ONE),
814                Ordering::Less => (
815                    T::Native::ONE,
816                    T::Native::usize_as(10).pow_checked(mul_pow.neg_wrapping() as _)?,
817                ),
818            };
819
820            try_op!(
821                l,
822                l_s,
823                r,
824                r_s,
825                l.mul_checked(l_mul)?.div_checked(r.mul_checked(r_mul)?)
826            )
827            .with_precision_and_scale(result_precision, result_scale)?
828        }
829
830        Op::Rem => {
831            // max(s1, s2)
832            let result_scale = *s1.max(s2);
833            // min(p1-s1, p2 -s2) + max( s1,s2 )
834            let result_precision =
835                (result_scale.saturating_add((*p1 as i8 - s1).min(*p2 as i8 - s2)) as u8)
836                    .min(T::MAX_PRECISION);
837
838            let l_mul = T::Native::usize_as(10).pow_wrapping((result_scale - s1) as _);
839            let r_mul = T::Native::usize_as(10).pow_wrapping((result_scale - s2) as _);
840
841            try_op!(
842                l,
843                l_s,
844                r,
845                r_s,
846                l.mul_checked(l_mul)?.mod_checked(r.mul_checked(r_mul)?)
847            )
848            .with_precision_and_scale(result_precision, result_scale)?
849        }
850    };
851
852    Ok(Arc::new(array))
853}
854
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::temporal_conversions::{as_date, as_datetime};
    use arrow_buffer::{i256, ScalarBuffer};
    use chrono::{DateTime, NaiveDate};

    /// Builds a null-free `PrimitiveArray<T>` from `input`, applies `neg`, and checks
    /// the outcome against `out`.
    ///
    /// `Ok(expected)` asserts the negated values equal `expected`; `Err(msg)` asserts
    /// that `neg` fails and that the error's `Display` output equals `msg` exactly.
    fn test_neg_primitive<T: ArrowPrimitiveType>(
        input: &[T::Native],
        out: Result<&[T::Native], &str>,
    ) {
        let a = PrimitiveArray::<T>::new(ScalarBuffer::from(input.to_vec()), None);
        match out {
            Ok(expected) => {
                let result = neg(&a).unwrap();
                assert_eq!(result.as_primitive::<T>().values(), expected);
            }
            Err(e) => {
                let err = neg(&a).unwrap_err().to_string();
                assert_eq!(e, err);
            }
        }
    }

    #[test]
    fn test_neg() {
        let input = &[1, -5, 2, 693, 3929];
        let output = &[-1, 5, -2, -693, -3929];
        test_neg_primitive::<Int32Type>(input, Ok(output));

        let input = &[1, -5, 2, 693, 3929];
        let output = &[-1, 5, -2, -693, -3929];
        test_neg_primitive::<Int64Type>(input, Ok(output));
        test_neg_primitive::<DurationSecondType>(input, Ok(output));
        test_neg_primitive::<DurationMillisecondType>(input, Ok(output));
        test_neg_primitive::<DurationMicrosecondType>(input, Ok(output));
        test_neg_primitive::<DurationNanosecondType>(input, Ok(output));

        let input = &[f32::MAX, f32::MIN, f32::INFINITY, 1.3, 0.5];
        let output = &[f32::MIN, f32::MAX, f32::NEG_INFINITY, -1.3, -0.5];
        test_neg_primitive::<Float32Type>(input, Ok(output));

        // Negating the minimum value of a signed type overflows, so `neg` errors
        test_neg_primitive::<Int32Type>(
            &[i32::MIN],
            Err("Arithmetic overflow: Overflow happened on: - -2147483648"),
        );
        test_neg_primitive::<Int64Type>(
            &[i64::MIN],
            Err("Arithmetic overflow: Overflow happened on: - -9223372036854775808"),
        );
        test_neg_primitive::<DurationSecondType>(
            &[i64::MIN],
            Err("Arithmetic overflow: Overflow happened on: - -9223372036854775808"),
        );

        // `neg_wrapping` wraps plain integers back around to MIN ...
        let r = neg_wrapping(&Int32Array::from(vec![i32::MIN])).unwrap();
        assert_eq!(r.as_primitive::<Int32Type>().value(0), i32::MIN);

        let r = neg_wrapping(&Int64Array::from(vec![i64::MIN])).unwrap();
        assert_eq!(r.as_primitive::<Int64Type>().value(0), i64::MIN);

        // ... but durations still report the overflow as an error
        let err = neg_wrapping(&DurationSecondArray::from(vec![i64::MIN]))
            .unwrap_err()
            .to_string();

        assert_eq!(
            err,
            "Arithmetic overflow: Overflow happened on: - -9223372036854775808"
        );

        // Decimal negation keeps the input precision and scale
        let a = Decimal128Array::from(vec![1, 3, -44, 2, 4])
            .with_precision_and_scale(9, 6)
            .unwrap();

        let r = neg(&a).unwrap();
        assert_eq!(r.data_type(), a.data_type());
        assert_eq!(
            r.as_primitive::<Decimal128Type>().values(),
            &[-1, -3, 44, -2, -4]
        );

        let a = Decimal256Array::from(vec![
            i256::from_i128(342),
            i256::from_i128(-4949),
            i256::from_i128(3),
        ])
        .with_precision_and_scale(9, 6)
        .unwrap();

        let r = neg(&a).unwrap();
        assert_eq!(r.data_type(), a.data_type());
        assert_eq!(
            r.as_primitive::<Decimal256Type>().values(),
            &[
                i256::from_i128(-342),
                i256::from_i128(4949),
                i256::from_i128(-3),
            ]
        );

        // Interval negation flips the sign of every component
        let a = IntervalYearMonthArray::from(vec![
            IntervalYearMonthType::make_value(2, 4),
            IntervalYearMonthType::make_value(2, -4),
            IntervalYearMonthType::make_value(-3, -5),
        ]);
        let r = neg(&a).unwrap();
        assert_eq!(
            r.as_primitive::<IntervalYearMonthType>().values(),
            &[
                IntervalYearMonthType::make_value(-2, -4),
                IntervalYearMonthType::make_value(-2, 4),
                IntervalYearMonthType::make_value(3, 5),
            ]
        );

        let a = IntervalDayTimeArray::from(vec![
            IntervalDayTimeType::make_value(2, 4),
            IntervalDayTimeType::make_value(2, -4),
            IntervalDayTimeType::make_value(-3, -5),
        ]);
        let r = neg(&a).unwrap();
        assert_eq!(
            r.as_primitive::<IntervalDayTimeType>().values(),
            &[
                IntervalDayTimeType::make_value(-2, -4),
                IntervalDayTimeType::make_value(-2, 4),
                IntervalDayTimeType::make_value(3, 5),
            ]
        );

        let a = IntervalMonthDayNanoArray::from(vec![
            IntervalMonthDayNanoType::make_value(2, 4, 5953394),
            IntervalMonthDayNanoType::make_value(2, -4, -45839),
            IntervalMonthDayNanoType::make_value(-3, -5, 6944),
        ]);
        let r = neg(&a).unwrap();
        assert_eq!(
            r.as_primitive::<IntervalMonthDayNanoType>().values(),
            &[
                IntervalMonthDayNanoType::make_value(-2, -4, -5953394),
                IntervalMonthDayNanoType::make_value(-2, 4, 45839),
                IntervalMonthDayNanoType::make_value(3, 5, -6944),
            ]
        );
    }

    #[test]
    fn test_integer() {
        let a = Int32Array::from(vec![4, 3, 5, -6, 100]);
        let b = Int32Array::from(vec![6, 2, 5, -7, 3]);
        let result = add(&a, &b).unwrap();
        assert_eq!(
            result.as_ref(),
            &Int32Array::from(vec![10, 5, 10, -13, 103])
        );
        let result = sub(&a, &b).unwrap();
        assert_eq!(result.as_ref(), &Int32Array::from(vec![-2, 1, 0, 1, 97]));
        let result = div(&a, &b).unwrap();
        assert_eq!(result.as_ref(), &Int32Array::from(vec![0, 1, 1, 0, 33]));
        let result = mul(&a, &b).unwrap();
        assert_eq!(result.as_ref(), &Int32Array::from(vec![24, 6, 25, 42, 300]));
        let result = rem(&a, &b).unwrap();
        assert_eq!(result.as_ref(), &Int32Array::from(vec![4, 1, 0, -6, 1]));

        // A null in either input yields a null in the output
        let a = Int8Array::from(vec![Some(2), None, Some(45)]);
        let b = Int8Array::from(vec![Some(5), Some(3), None]);
        let result = add(&a, &b).unwrap();
        assert_eq!(result.as_ref(), &Int8Array::from(vec![Some(7), None, None]));

        // Checked kernels error on overflow; the `_wrapping` variants wrap instead
        let a = UInt8Array::from(vec![56, 5, 3]);
        let b = UInt8Array::from(vec![200, 2, 5]);
        let err = add(&a, &b).unwrap_err().to_string();
        assert_eq!(err, "Arithmetic overflow: Overflow happened on: 56 + 200");
        let result = add_wrapping(&a, &b).unwrap();
        assert_eq!(result.as_ref(), &UInt8Array::from(vec![0, 7, 8]));

        let a = UInt8Array::from(vec![34, 5, 3]);
        let b = UInt8Array::from(vec![200, 2, 5]);
        let err = sub(&a, &b).unwrap_err().to_string();
        assert_eq!(err, "Arithmetic overflow: Overflow happened on: 34 - 200");
        let result = sub_wrapping(&a, &b).unwrap();
        assert_eq!(result.as_ref(), &UInt8Array::from(vec![90, 3, 254]));

        let a = UInt8Array::from(vec![34, 5, 3]);
        let b = UInt8Array::from(vec![200, 2, 5]);
        let err = mul(&a, &b).unwrap_err().to_string();
        assert_eq!(err, "Arithmetic overflow: Overflow happened on: 34 * 200");
        let result = mul_wrapping(&a, &b).unwrap();
        assert_eq!(result.as_ref(), &UInt8Array::from(vec![144, 10, 15]));

        // MIN / -1 overflows the signed range ...
        let a = Int16Array::from(vec![i16::MIN]);
        let b = Int16Array::from(vec![-1]);
        let err = div(&a, &b).unwrap_err().to_string();
        assert_eq!(
            err,
            "Arithmetic overflow: Overflow happened on: -32768 / -1"
        );

        // ... whereas MIN % -1 is well defined and equals 0
        let a = Int16Array::from(vec![i16::MIN]);
        let b = Int16Array::from(vec![-1]);
        let result = rem(&a, &b).unwrap();
        assert_eq!(result.as_ref(), &Int16Array::from(vec![0]));

        // Integer division and remainder by zero are errors
        let a = Int16Array::from(vec![21]);
        let b = Int16Array::from(vec![0]);
        let err = div(&a, &b).unwrap_err().to_string();
        assert_eq!(err, "Divide by zero error");

        let a = Int16Array::from(vec![21]);
        let b = Int16Array::from(vec![0]);
        let err = rem(&a, &b).unwrap_err().to_string();
        assert_eq!(err, "Divide by zero error");
    }

    #[test]
    fn test_float() {
        // Float arithmetic never errors: overflow produces infinity and 0 / 0 is NaN
        let a = Float32Array::from(vec![1., f32::MAX, 6., -4., -1., 0.]);
        let b = Float32Array::from(vec![1., f32::MAX, f32::MAX, -3., 45., 0.]);
        let result = add(&a, &b).unwrap();
        assert_eq!(
            result.as_ref(),
            &Float32Array::from(vec![2., f32::INFINITY, f32::MAX, -7., 44.0, 0.])
        );

        let result = sub(&a, &b).unwrap();
        assert_eq!(
            result.as_ref(),
            &Float32Array::from(vec![0., 0., f32::MIN, -1., -46., 0.])
        );

        let result = mul(&a, &b).unwrap();
        assert_eq!(
            result.as_ref(),
            &Float32Array::from(vec![1., f32::INFINITY, f32::INFINITY, 12., -45., 0.])
        );

        let result = div(&a, &b).unwrap();
        let r = result.as_primitive::<Float32Type>();
        assert_eq!(r.value(0), 1.);
        assert_eq!(r.value(1), 1.);
        assert!(r.value(2) < f32::EPSILON);
        assert_eq!(r.value(3), -4. / -3.);
        assert_eq!(r.value(4), -1. / 45.);
        assert!(r.value(5).is_nan());

        let result = rem(&a, &b).unwrap();
        let r = result.as_primitive::<Float32Type>();
        assert_eq!(&r.values()[..5], &[0., 0., 6., -1., -1.]);
        assert!(r.value(5).is_nan());
    }

    #[test]
    fn test_decimal() {
        // 0.015 0.000 -0.577 0.334 -0.078 0.003
        let a = Decimal128Array::from(vec![15, 0, -577, 334, -78, 3])
            .with_precision_and_scale(12, 3)
            .unwrap();

        // 5.4 3.4 -35.6 0.3 0.6 74.5
        let b = Decimal128Array::from(vec![54, 34, -356, 3, 6, 745])
            .with_precision_and_scale(12, 1)
            .unwrap();

        // add/sub: scale = max(3, 1) = 3,
        // precision = max(3, 1) + max(12 - 3, 12 - 1) + 1 = 15
        let result = add(&a, &b).unwrap();
        assert_eq!(result.data_type(), &DataType::Decimal128(15, 3));
        assert_eq!(
            result.as_primitive::<Decimal128Type>().values(),
            &[5415, 3400, -36177, 634, 522, 74503]
        );

        let result = sub(&a, &b).unwrap();
        assert_eq!(result.data_type(), &DataType::Decimal128(15, 3));
        assert_eq!(
            result.as_primitive::<Decimal128Type>().values(),
            &[-5385, -3400, 35023, 34, -678, -74497]
        );

        // mul: precision = 12 + 12 + 1 = 25, scale = 3 + 1 = 4
        let result = mul(&a, &b).unwrap();
        assert_eq!(result.data_type(), &DataType::Decimal128(25, 4));
        assert_eq!(
            result.as_primitive::<Decimal128Type>().values(),
            &[810, 0, 205412, 1002, -468, 2235]
        );

        // div: scale = 3 + 4 = 7, precision = (7 - 3 + 1) + 12 = 17
        let result = div(&a, &b).unwrap();
        assert_eq!(result.data_type(), &DataType::Decimal128(17, 7));
        assert_eq!(
            result.as_primitive::<Decimal128Type>().values(),
            &[27777, 0, 162078, 11133333, -1300000, 402]
        );

        // rem: scale = max(3, 1) = 3, precision = 3 + min(12 - 3, 12 - 1) = 12
        let result = rem(&a, &b).unwrap();
        assert_eq!(result.data_type(), &DataType::Decimal128(12, 3));
        assert_eq!(
            result.as_primitive::<Decimal128Type>().values(),
            &[15, 0, -577, 34, -78, 3]
        );

        // mul: scale 3 + 37 = 40 exceeds the Decimal128 max scale of 38
        let a = Decimal128Array::from(vec![1])
            .with_precision_and_scale(3, 3)
            .unwrap();
        let b = Decimal128Array::from(vec![1])
            .with_precision_and_scale(37, 37)
            .unwrap();
        let err = mul(&a, &b).unwrap_err().to_string();
        assert_eq!(err, "Invalid argument error: Output scale of Decimal128(3, 3) * Decimal128(37, 37) would exceed max scale of 38");

        // Rescaling from -2 to 37 needs a 10^39 multiplier, which overflows i128
        let a = Decimal128Array::from(vec![1])
            .with_precision_and_scale(3, -2)
            .unwrap();
        let err = add(&a, &b).unwrap_err().to_string();
        assert_eq!(err, "Arithmetic overflow: Overflow happened on: 10 ^ 39");

        // Rescaling from -1 to 37 needs 10^38, which fits, but 10 * 10^38 does not
        let a = Decimal128Array::from(vec![10])
            .with_precision_and_scale(3, -1)
            .unwrap();
        let err = add(&a, &b).unwrap_err().to_string();
        assert_eq!(
            err,
            "Arithmetic overflow: Overflow happened on: 10 * 100000000000000000000000000000000000000"
        );

        // Decimal division and remainder by zero are errors
        let b = Decimal128Array::from(vec![0])
            .with_precision_and_scale(1, 1)
            .unwrap();
        let err = div(&a, &b).unwrap_err().to_string();
        assert_eq!(err, "Divide by zero error");
        let err = rem(&a, &b).unwrap_err().to_string();
        assert_eq!(err, "Divide by zero error");
    }

    /// Exercises timestamp arithmetic for the timestamp type `T`.
    ///
    /// Checks that timestamp - timestamp yields a `T::Duration`, that adding that
    /// duration back (on either side) restores the original timestamps, and that adding
    /// each interval type produces the expected calendar datetimes.
    fn test_timestamp_impl<T: TimestampOp>() {
        let a = PrimitiveArray::<T>::new(vec![2000000, 434030324, 53943340].into(), None);
        let b = PrimitiveArray::<T>::new(vec![329593, 59349, 694994].into(), None);

        let result = sub(&a, &b).unwrap();
        assert_eq!(
            result.as_primitive::<T::Duration>().values(),
            &[1670407, 433970975, 53248346]
        );

        let r2 = add(&b, &result.as_ref()).unwrap();
        assert_eq!(r2.as_ref(), &a);

        let r3 = add(&result.as_ref(), &b).unwrap();
        assert_eq!(r3.as_ref(), &a);

        // Renders every value of a `T` array as a naive datetime string
        let format_array = |x: &dyn Array| -> Vec<String> {
            x.as_primitive::<T>()
                .values()
                .iter()
                .map(|x| as_datetime::<T>(*x).unwrap().to_string())
                .collect()
        };

        let values = vec![
            "1970-01-01T00:00:00Z",
            "2010-04-01T04:00:20Z",
            "1960-01-30T04:23:20Z",
        ]
        .into_iter()
        .map(|x| T::make_value(DateTime::parse_from_rfc3339(x).unwrap().naive_utc()).unwrap())
        .collect();

        let a = PrimitiveArray::<T>::new(values, None);
        let b = IntervalYearMonthArray::from(vec![
            IntervalYearMonthType::make_value(5, 34),
            IntervalYearMonthType::make_value(-2, 4),
            IntervalYearMonthType::make_value(7, -4),
        ]);
        let r4 = add(&a, &b).unwrap();
        assert_eq!(
            &format_array(r4.as_ref()),
            &[
                "1977-11-01 00:00:00".to_string(),
                "2008-08-01 04:00:20".to_string(),
                "1966-09-30 04:23:20".to_string()
            ]
        );

        let r5 = sub(&r4, &b).unwrap();
        assert_eq!(r5.as_ref(), &a);

        let b = IntervalDayTimeArray::from(vec![
            IntervalDayTimeType::make_value(5, 454000),
            IntervalDayTimeType::make_value(-34, 0),
            IntervalDayTimeType::make_value(7, -4000),
        ]);
        let r6 = add(&a, &b).unwrap();
        assert_eq!(
            &format_array(r6.as_ref()),
            &[
                "1970-01-06 00:07:34".to_string(),
                "2010-02-26 04:00:20".to_string(),
                "1960-02-06 04:23:16".to_string()
            ]
        );

        let r7 = sub(&r6, &b).unwrap();
        assert_eq!(r7.as_ref(), &a);

        let b = IntervalMonthDayNanoArray::from(vec![
            IntervalMonthDayNanoType::make_value(344, 34, -43_000_000_000),
            IntervalMonthDayNanoType::make_value(-593, -33, 13_000_000_000),
            IntervalMonthDayNanoType::make_value(5, 2, 493_000_000_000),
        ]);
        let r8 = add(&a, &b).unwrap();
        assert_eq!(
            &format_array(r8.as_ref()),
            &[
                "1998-10-04 23:59:17".to_string(),
                "1960-09-29 04:00:33".to_string(),
                "1960-07-02 04:31:33".to_string()
            ]
        );

        let r9 = sub(&r8, &b).unwrap();
        // Note: subtraction is not the inverse of addition for intervals
        assert_eq!(
            &format_array(r9.as_ref()),
            &[
                "1970-01-02 00:00:00".to_string(),
                "2010-04-02 04:00:20".to_string(),
                "1960-01-31 04:23:20".to_string()
            ]
        );
    }

    #[test]
    fn test_timestamp() {
        test_timestamp_impl::<TimestampSecondType>();
        test_timestamp_impl::<TimestampMillisecondType>();
        test_timestamp_impl::<TimestampMicrosecondType>();
        test_timestamp_impl::<TimestampNanosecondType>();
    }

    #[test]
    fn test_interval() {
        let a = IntervalYearMonthArray::from(vec![
            IntervalYearMonthType::make_value(32, 4),
            IntervalYearMonthType::make_value(32, 4),
        ]);
        let b = IntervalYearMonthArray::from(vec![
            IntervalYearMonthType::make_value(-4, 6),
            IntervalYearMonthType::make_value(-3, 23),
        ]);
        let result = add(&a, &b).unwrap();
        assert_eq!(
            result.as_ref(),
            &IntervalYearMonthArray::from(vec![
                IntervalYearMonthType::make_value(28, 10),
                IntervalYearMonthType::make_value(29, 27)
            ])
        );
        let result = sub(&a, &b).unwrap();
        assert_eq!(
            result.as_ref(),
            &IntervalYearMonthArray::from(vec![
                IntervalYearMonthType::make_value(36, -2),
                IntervalYearMonthType::make_value(35, -19)
            ])
        );

        let a = IntervalDayTimeArray::from(vec![
            IntervalDayTimeType::make_value(32, 4),
            IntervalDayTimeType::make_value(32, 4),
        ]);
        let b = IntervalDayTimeArray::from(vec![
            IntervalDayTimeType::make_value(-4, 6),
            IntervalDayTimeType::make_value(-3, 23),
        ]);
        let result = add(&a, &b).unwrap();
        assert_eq!(
            result.as_ref(),
            &IntervalDayTimeArray::from(vec![
                IntervalDayTimeType::make_value(28, 10),
                IntervalDayTimeType::make_value(29, 27)
            ])
        );
        let result = sub(&a, &b).unwrap();
        assert_eq!(
            result.as_ref(),
            &IntervalDayTimeArray::from(vec![
                IntervalDayTimeType::make_value(36, -2),
                IntervalDayTimeType::make_value(35, -19)
            ])
        );
        let a = IntervalMonthDayNanoArray::from(vec![
            IntervalMonthDayNanoType::make_value(32, 4, 4000000000000),
            IntervalMonthDayNanoType::make_value(32, 4, 45463000000000000),
        ]);
        let b = IntervalMonthDayNanoArray::from(vec![
            IntervalMonthDayNanoType::make_value(-4, 6, 46000000000000),
            IntervalMonthDayNanoType::make_value(-3, 23, 3564000000000000),
        ]);
        let result = add(&a, &b).unwrap();
        assert_eq!(
            result.as_ref(),
            &IntervalMonthDayNanoArray::from(vec![
                IntervalMonthDayNanoType::make_value(28, 10, 50000000000000),
                IntervalMonthDayNanoType::make_value(29, 27, 49027000000000000)
            ])
        );
        let result = sub(&a, &b).unwrap();
        assert_eq!(
            result.as_ref(),
            &IntervalMonthDayNanoArray::from(vec![
                IntervalMonthDayNanoType::make_value(36, -2, -42000000000000),
                IntervalMonthDayNanoType::make_value(35, -19, 41899000000000000)
            ])
        );

        // Adding ONE to MAX overflows and is reported as an error
        let a = IntervalMonthDayNanoArray::from(vec![IntervalMonthDayNano::MAX]);
        let b = IntervalMonthDayNanoArray::from(vec![IntervalMonthDayNano::ONE]);
        let err = add(&a, &b).unwrap_err().to_string();
        assert_eq!(
            err,
            "Arithmetic overflow: Overflow happened on: 2147483647 + 1"
        );
    }

    /// Exercises duration arithmetic for the `i64`-backed duration type `T`.
    ///
    /// Durations support `add` and `sub` (with checked overflow), while `mul`, `div` and
    /// `rem` between two durations are rejected.
    fn test_duration_impl<T: ArrowPrimitiveType<Native = i64>>() {
        let a = PrimitiveArray::<T>::new(vec![1000, 4394, -3944].into(), None);
        let b = PrimitiveArray::<T>::new(vec![4, -5, -243].into(), None);

        let result = add(&a, &b).unwrap();
        assert_eq!(result.as_primitive::<T>().values(), &[1004, 4389, -4187]);
        let result = sub(&a, &b).unwrap();
        assert_eq!(result.as_primitive::<T>().values(), &[996, 4399, -3701]);

        let err = mul(&a, &b).unwrap_err().to_string();
        assert!(
            err.contains("Invalid duration arithmetic operation"),
            "{err}"
        );

        let err = div(&a, &b).unwrap_err().to_string();
        assert!(
            err.contains("Invalid duration arithmetic operation"),
            "{err}"
        );

        let err = rem(&a, &b).unwrap_err().to_string();
        assert!(
            err.contains("Invalid duration arithmetic operation"),
            "{err}"
        );

        let a = PrimitiveArray::<T>::new(vec![i64::MAX].into(), None);
        let b = PrimitiveArray::<T>::new(vec![1].into(), None);
        let err = add(&a, &b).unwrap_err().to_string();
        assert_eq!(
            err,
            "Arithmetic overflow: Overflow happened on: 9223372036854775807 + 1"
        );
    }

    #[test]
    fn test_duration() {
        test_duration_impl::<DurationSecondType>();
        test_duration_impl::<DurationMillisecondType>();
        test_duration_impl::<DurationMicrosecondType>();
        test_duration_impl::<DurationNanosecondType>();
    }

    /// Exercises date + interval arithmetic for the date type `T`.
    ///
    /// `f` converts a `NaiveDate` into `T`'s native value. Adding then subtracting a
    /// year-month or day-time interval restores the original dates; for month-day-nano
    /// intervals the round trip is not exact, and the test pins the observed dates.
    fn test_date_impl<T: ArrowPrimitiveType, F>(f: F)
    where
        F: Fn(NaiveDate) -> T::Native,
        T::Native: TryInto<i64>,
    {
        let a = PrimitiveArray::<T>::new(
            vec![
                f(NaiveDate::from_ymd_opt(1979, 1, 30).unwrap()),
                f(NaiveDate::from_ymd_opt(2010, 4, 3).unwrap()),
                f(NaiveDate::from_ymd_opt(2008, 2, 29).unwrap()),
            ]
            .into(),
            None,
        );

        let b = IntervalYearMonthArray::from(vec![
            IntervalYearMonthType::make_value(34, 2),
            IntervalYearMonthType::make_value(3, -3),
            IntervalYearMonthType::make_value(-12, 4),
        ]);

        // Renders every value of a `T` array as a date string
        let format_array = |x: &dyn Array| -> Vec<String> {
            x.as_primitive::<T>()
                .values()
                .iter()
                .map(|x| {
                    as_date::<T>((*x).try_into().ok().unwrap())
                        .unwrap()
                        .to_string()
                })
                .collect()
        };

        let result = add(&a, &b).unwrap();
        assert_eq!(
            &format_array(result.as_ref()),
            &[
                "2013-03-30".to_string(),
                "2013-01-03".to_string(),
                "1996-06-29".to_string(),
            ]
        );
        let result = sub(&result, &b).unwrap();
        assert_eq!(result.as_ref(), &a);

        let b = IntervalDayTimeArray::from(vec![
            IntervalDayTimeType::make_value(34, 2),
            IntervalDayTimeType::make_value(3, -3),
            IntervalDayTimeType::make_value(-12, 4),
        ]);

        let result = add(&a, &b).unwrap();
        assert_eq!(
            &format_array(result.as_ref()),
            &[
                "1979-03-05".to_string(),
                "2010-04-06".to_string(),
                "2008-02-17".to_string(),
            ]
        );
        let result = sub(&result, &b).unwrap();
        assert_eq!(result.as_ref(), &a);

        let b = IntervalMonthDayNanoArray::from(vec![
            IntervalMonthDayNanoType::make_value(34, 2, -34353534),
            IntervalMonthDayNanoType::make_value(3, -3, 2443),
            IntervalMonthDayNanoType::make_value(-12, 4, 2323242423232),
        ]);

        let result = add(&a, &b).unwrap();
        assert_eq!(
            &format_array(result.as_ref()),
            &[
                "1981-12-02".to_string(),
                "2010-06-30".to_string(),
                "2007-03-04".to_string(),
            ]
        );
        // Not an exact inverse of the addition above
        let result = sub(&result, &b).unwrap();
        assert_eq!(
            &format_array(result.as_ref()),
            &[
                "1979-01-31".to_string(),
                "2010-04-02".to_string(),
                "2008-02-29".to_string(),
            ]
        );
    }

    #[test]
    fn test_date() {
        test_date_impl::<Date32Type, _>(Date32Type::from_naive_date);
        test_date_impl::<Date64Type, _>(Date64Type::from_naive_date);

        // Date32 - Date32 yields a DurationSecond
        let a = Date32Array::from(vec![i32::MIN, i32::MAX, 23, 7684]);
        let b = Date32Array::from(vec![i32::MIN, i32::MIN, -2, 45]);
        let result = sub(&a, &b).unwrap();
        assert_eq!(
            result.as_primitive::<DurationSecondType>().values(),
            &[0, 371085174288000, 2160000, 660009600]
        );

        // Date64 - Date64 yields a DurationMillisecond
        let a = Date64Array::from(vec![4343, 76676, 3434]);
        let b = Date64Array::from(vec![3, -5, 5]);
        let result = sub(&a, &b).unwrap();
        assert_eq!(
            result.as_primitive::<DurationMillisecondType>().values(),
            &[4340, 76681, 3429]
        );

        let a = Date64Array::from(vec![i64::MAX]);
        let b = Date64Array::from(vec![-1]);
        let err = sub(&a, &b).unwrap_err().to_string();
        assert_eq!(
            err,
            "Arithmetic overflow: Overflow happened on: 9223372036854775807 - -1"
        );
    }
}