Skip to main content

parquet_variant_compute/
type_conversion.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Module for transforming a typed arrow `Array` to `VariantArray`.
19
20use arrow::compute::{
21    CastOptions, DecimalCast, parse_string_to_decimal_native, rescale_decimal,
22    single_float_to_decimal,
23};
24use arrow::datatypes::{
25    self, ArrowPrimitiveType, ArrowTimestampType, Decimal32Type, Decimal64Type, Decimal128Type,
26    DecimalType,
27};
28use arrow::error::{ArrowError, Result};
29use chrono::Timelike;
30use parquet_variant::{Variant, VariantDecimal4, VariantDecimal8, VariantDecimal16};
31
32/// Extension trait for Arrow primitive types that can extract their native value from a Variant
33pub(crate) trait PrimitiveFromVariant: ArrowPrimitiveType {
34    fn from_variant(variant: &Variant<'_, '_>) -> Option<Self::Native>;
35}
36
37/// Extension trait for Arrow timestamp types that can extract their native value from a Variant
38/// We can't use [`PrimitiveFromVariant`] directly because we need _two_ implementations for each
39/// timestamp type -- the `NTZ` param here.
40pub(crate) trait TimestampFromVariant<const NTZ: bool>: ArrowTimestampType {
41    fn from_variant(variant: &Variant<'_, '_>) -> Option<Self::Native>;
42}
43
44/// Cast a single `Variant` value with safe/strict semantics.
45///
46/// Returns `Ok(Some(_))` on successful conversion.
47/// Returns `Ok(None)` when conversion fails in safe mode or the source value is `Variant::Null`.
48/// Returns `Err(_)` when conversion fails in strict mode.
49pub(crate) fn variant_cast_with_options<'a, 'm, 'v, T>(
50    variant: &'a Variant<'m, 'v>,
51    cast_options: &CastOptions<'_>,
52    cast: impl FnOnce(&'a Variant<'m, 'v>) -> Option<T>,
53) -> Result<Option<T>> {
54    if let Some(value) = cast(variant) {
55        Ok(Some(value))
56    } else if matches!(variant, Variant::Null) || cast_options.safe {
57        Ok(None)
58    } else {
59        Err(ArrowError::CastError(format!(
60            "Failed to cast variant value {variant:?}"
61        )))
62    }
63}
64
/// Generates a [`PrimitiveFromVariant`] impl for `$arrow_type`.
///
/// The value is read with `Variant::$variant_method`. When `$cast_fn` is supplied, the
/// extracted value is further transformed with `Option::and_then($cast_fn)`, so the
/// closure may reject it by returning `None`.
macro_rules! impl_primitive_from_variant {
    ($arrow_type:ty, $variant_method:ident $(, $cast_fn:expr)?) => {
        impl PrimitiveFromVariant for $arrow_type {
            fn from_variant(variant: &Variant<'_, '_>) -> Option<Self::Native> {
                variant.$variant_method()$(.and_then($cast_fn))?
            }
        }
    };
}
77
/// Generates a [`TimestampFromVariant`] impl for `$timestamp_type`, selected by `ntz`.
///
/// The value is read with `Variant::$variant_method` and then converted to the Arrow
/// native representation by `$cast_fn`, which may reject it by returning `None`.
macro_rules! impl_timestamp_from_variant {
    ($timestamp_type:ty, $variant_method:ident, ntz=$ntz:ident, $cast_fn:expr $(,)?) => {
        impl TimestampFromVariant<{ $ntz }> for $timestamp_type {
            fn from_variant(variant: &Variant<'_, '_>) -> Option<Self::Native> {
                let extracted = variant.$variant_method();
                extracted.and_then($cast_fn)
            }
        }
    };
}
87
88impl_primitive_from_variant!(datatypes::Int32Type, as_int32);
89impl_primitive_from_variant!(datatypes::Int16Type, as_int16);
90impl_primitive_from_variant!(datatypes::Int8Type, as_int8);
91impl_primitive_from_variant!(datatypes::Int64Type, as_int64);
92impl_primitive_from_variant!(datatypes::UInt8Type, as_u8);
93impl_primitive_from_variant!(datatypes::UInt16Type, as_u16);
94impl_primitive_from_variant!(datatypes::UInt32Type, as_u32);
95impl_primitive_from_variant!(datatypes::UInt64Type, as_u64);
96impl_primitive_from_variant!(datatypes::Float16Type, as_f16);
97impl_primitive_from_variant!(datatypes::Float32Type, as_f32);
98impl_primitive_from_variant!(datatypes::Float64Type, as_f64);
99impl_primitive_from_variant!(datatypes::Date32Type, as_naive_date, |v| {
100    Some(datatypes::Date32Type::from_naive_date(v))
101});
102impl_primitive_from_variant!(datatypes::Date64Type, as_naive_date, |v| {
103    Some(datatypes::Date64Type::from_naive_date(v))
104});
105impl_primitive_from_variant!(datatypes::Time32SecondType, as_time_utc, |v| {
106    // Return None if there are leftover nanoseconds
107    if v.nanosecond() != 0 {
108        None
109    } else {
110        Some(v.num_seconds_from_midnight() as i32)
111    }
112});
113impl_primitive_from_variant!(datatypes::Time32MillisecondType, as_time_utc, |v| {
114    // Return None if there are leftover microseconds
115    if v.nanosecond() % 1_000_000 != 0 {
116        None
117    } else {
118        Some((v.num_seconds_from_midnight() * 1_000) as i32 + (v.nanosecond() / 1_000_000) as i32)
119    }
120});
121impl_primitive_from_variant!(datatypes::Time64MicrosecondType, as_time_utc, |v| {
122    Some(v.num_seconds_from_midnight() as i64 * 1_000_000 + v.nanosecond() as i64 / 1_000)
123});
124impl_primitive_from_variant!(datatypes::Time64NanosecondType, as_time_utc, |v| {
125    // convert micro to nano seconds
126    Some(v.num_seconds_from_midnight() as i64 * 1_000_000_000 + v.nanosecond() as i64)
127});
128impl_timestamp_from_variant!(
129    datatypes::TimestampSecondType,
130    as_timestamp_ntz_nanos,
131    ntz = true,
132    |timestamp| {
133        // Return None if there are leftover nanoseconds
134        if timestamp.nanosecond() != 0 {
135            None
136        } else {
137            Self::from_naive_datetime(timestamp, None)
138        }
139    }
140);
141impl_timestamp_from_variant!(
142    datatypes::TimestampSecondType,
143    as_timestamp_nanos,
144    ntz = false,
145    |timestamp| {
146        // Return None if there are leftover nanoseconds
147        if timestamp.nanosecond() != 0 {
148            None
149        } else {
150            Self::from_naive_datetime(timestamp.naive_utc(), None)
151        }
152    }
153);
154impl_timestamp_from_variant!(
155    datatypes::TimestampMillisecondType,
156    as_timestamp_ntz_nanos,
157    ntz = true,
158    |timestamp| {
159        // Return None if there are leftover microseconds
160        if timestamp.nanosecond() % 1_000_000 != 0 {
161            None
162        } else {
163            Self::from_naive_datetime(timestamp, None)
164        }
165    }
166);
167impl_timestamp_from_variant!(
168    datatypes::TimestampMillisecondType,
169    as_timestamp_nanos,
170    ntz = false,
171    |timestamp| {
172        // Return None if there are leftover microseconds
173        if timestamp.nanosecond() % 1_000_000 != 0 {
174            None
175        } else {
176            Self::from_naive_datetime(timestamp.naive_utc(), None)
177        }
178    }
179);
180impl_timestamp_from_variant!(
181    datatypes::TimestampMicrosecondType,
182    as_timestamp_ntz_micros,
183    ntz = true,
184    |timestamp| Self::from_naive_datetime(timestamp, None),
185);
186impl_timestamp_from_variant!(
187    datatypes::TimestampMicrosecondType,
188    as_timestamp_micros,
189    ntz = false,
190    |timestamp| Self::from_naive_datetime(timestamp.naive_utc(), None)
191);
192impl_timestamp_from_variant!(
193    datatypes::TimestampNanosecondType,
194    as_timestamp_ntz_nanos,
195    ntz = true,
196    |timestamp| Self::from_naive_datetime(timestamp, None)
197);
198impl_timestamp_from_variant!(
199    datatypes::TimestampNanosecondType,
200    as_timestamp_nanos,
201    ntz = false,
202    |timestamp| Self::from_naive_datetime(timestamp.naive_utc(), None)
203);
204
205/// Returns the unscaled integer representation for Arrow decimal type `O`
206/// from a `Variant`.
207///
208/// - `precision` and `scale` specify the target Arrow decimal parameters
209/// - Integer variants (`Int8/16/32/64`) are treated as decimals with scale 0
210/// - Floating point variants (`Float/Double`) are converted to decimals with the given scale
211/// - String variants (`String/ShortString`) are parsed as decimals with the given scale
212/// - Decimal variants (`Decimal4/8/16`) use their embedded precision and scale
213///
214/// The value is rescaled to (`precision`, `scale`) using `rescale_decimal` for integers,
215/// `single_float_to_decimal` for floats, and `parse_string_to_decimal_native` for strings.
216/// returns `None` if it cannot fit the requested precision.
217pub(crate) fn variant_to_unscaled_decimal<O>(
218    variant: &Variant<'_, '_>,
219    precision: u8,
220    scale: i8,
221) -> Option<O::Native>
222where
223    O: DecimalType,
224    O::Native: DecimalCast,
225{
226    let mul = 10_f64.powi(scale as i32);
227
228    match variant {
229        Variant::Int8(i) => rescale_decimal::<Decimal32Type, O>(
230            *i as i32,
231            VariantDecimal4::MAX_PRECISION,
232            0,
233            precision,
234            scale,
235        ),
236        Variant::Int16(i) => rescale_decimal::<Decimal32Type, O>(
237            *i as i32,
238            VariantDecimal4::MAX_PRECISION,
239            0,
240            precision,
241            scale,
242        ),
243        Variant::Int32(i) => rescale_decimal::<Decimal32Type, O>(
244            *i,
245            VariantDecimal4::MAX_PRECISION,
246            0,
247            precision,
248            scale,
249        ),
250        Variant::Int64(i) => rescale_decimal::<Decimal64Type, O>(
251            *i,
252            VariantDecimal8::MAX_PRECISION,
253            0,
254            precision,
255            scale,
256        ),
257        Variant::Float(f) => single_float_to_decimal::<O>(f64::from(*f), mul),
258        Variant::Double(f) => single_float_to_decimal::<O>(*f, mul),
259        // arrow-cast only support cast string to decimal with scale >=0 for now
260        // Please see `cast_string_to_decimal` in arrow-cast/src/cast/decimal.rs for more detail
261        Variant::String(v) if scale >= 0 => parse_string_to_decimal_native::<O>(v, scale as _).ok(),
262        Variant::ShortString(v) if scale >= 0 => {
263            parse_string_to_decimal_native::<O>(v, scale as _).ok()
264        }
265        Variant::Decimal4(d) => rescale_decimal::<Decimal32Type, O>(
266            d.integer(),
267            VariantDecimal4::MAX_PRECISION,
268            d.scale() as i8,
269            precision,
270            scale,
271        ),
272        Variant::Decimal8(d) => rescale_decimal::<Decimal64Type, O>(
273            d.integer(),
274            VariantDecimal8::MAX_PRECISION,
275            d.scale() as i8,
276            precision,
277            scale,
278        ),
279        Variant::Decimal16(d) => rescale_decimal::<Decimal128Type, O>(
280            d.integer(),
281            VariantDecimal16::MAX_PRECISION,
282            d.scale() as i8,
283            precision,
284            scale,
285        ),
286        _ => None,
287    }
288}
289
/// Converts the element at `$index` of `$array` into a `Variant`, applying `$cast_fn` to
/// non-null elements. Null slots become `Variant::Null`. The expansion evaluates to a
/// `Result` that is always `Ok`.
macro_rules! non_generic_conversion_single_value {
    ($array:expr, $cast_fn:expr, $index:expr) => {{
        let source = $array;
        if !source.is_null($index) {
            Ok(Variant::from($cast_fn(source.value($index))))
        } else {
            Ok(Variant::Null)
        }
    }};
}
pub(crate) use non_generic_conversion_single_value;
303
/// Converts the element at `$index` of `$input` into a `Variant`.
///
/// `$input` is first downcast with `$method::<$t>()` (for example
/// `as_primitive::<Int32Type>()`), and the element is then transformed with `$cast_fn`.
/// Null slots become `Variant::Null`.
macro_rules! generic_conversion_single_value {
    ($t:ty, $method:ident, $cast_fn:expr, $input:expr, $index:expr) => {{
        let typed_array = $input.$method::<$t>();
        $crate::type_conversion::non_generic_conversion_single_value!(
            typed_array,
            $cast_fn,
            $index
        )
    }};
}
pub(crate) use generic_conversion_single_value;
317
/// Converts the element at `$index` of `$input` into a `Variant`.
///
/// `$input` is downcast with `$method::<$t>()`, and the element is transformed with the
/// fallible `$cast_fn`, which returns a `Result`.
///
/// Null slots evaluate to `Ok(Variant::Null)` without calling `$cast_fn`, matching
/// `non_generic_conversion_single_value`. The value stored under a null slot is arbitrary,
/// so casting it could report a spurious error. A failed cast becomes
/// `ArrowError::CastError`, naming the index and the array's data type.
macro_rules! generic_conversion_single_value_with_result {
    ($t:ty, $method:ident, $cast_fn:expr, $input:expr, $index:expr) => {{
        let arr = $input.$method::<$t>();
        // Fully qualified so call sites need not import the `Array` trait.
        if ::arrow::array::Array::is_null(arr, $index) {
            Ok(Variant::Null)
        } else {
            let v = arr.value($index);
            match ($cast_fn)(v) {
                Ok(var) => Ok(Variant::from(var)),
                Err(e) => Err(ArrowError::CastError(format!(
                    "Cast failed at index {idx} (array type: {ty}): {e}",
                    idx = $index,
                    ty = <$t as ::arrow::datatypes::ArrowPrimitiveType>::DATA_TYPE
                ))),
            }
        }
    }};
}

pub(crate) use generic_conversion_single_value_with_result;
334
/// Converts the element at `$index` of the primitive array `$input` (element type `$t`)
/// into a `Variant`, passing the native value through unchanged. Null slots become
/// `Variant::Null`.
macro_rules! primitive_conversion_single_value {
    ($t:ty, $input:expr, $index:expr) => {{
        $crate::type_conversion::generic_conversion_single_value!(
            $t,
            as_primitive,
            ::std::convert::identity,
            $input,
            $index
        )
    }};
}
pub(crate) use primitive_conversion_single_value;