parquet_variant_compute/
cast_to_variant.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::{VariantArray, VariantArrayBuilder};
19use arrow::array::{Array, AsArray};
20use arrow::datatypes::{
21    BinaryType, BinaryViewType, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type,
22    Int64Type, Int8Type, LargeBinaryType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
23};
24use arrow_schema::{ArrowError, DataType};
25use half::f16;
26use parquet_variant::Variant;
27
/// Appends every row of a primitive-typed input array to a variant builder.
///
/// `$t` is the Arrow primitive type used to downcast `$input`. Each valid
/// value is wrapped with `Variant::from` and appended to `$builder`; each
/// null row is appended as a null entry instead.
macro_rules! primitive_conversion {
    ($t:ty, $input:expr, $builder:expr) => {{
        let typed = $input.as_primitive::<$t>();
        for row in 0..typed.len() {
            if typed.is_valid(row) {
                $builder.append_variant(Variant::from(typed.value(row)));
            } else {
                $builder.append_null();
            }
        }
    }};
}
42
/// Appends every row of the input array to a variant builder, mapping each
/// value through a caller-supplied function first.
///
/// `$method::<$t>()` downcasts `$input` to a concrete array type, and
/// `$cast_fn` turns each valid element into a value that `Variant::from`
/// accepts. Null rows are appended to `$builder` as null entries.
macro_rules! cast_conversion {
    ($t:ty, $method:ident, $cast_fn:expr, $input:expr, $builder:expr) => {{
        let typed = $input.$method::<$t>();
        for row in 0..typed.len() {
            if typed.is_valid(row) {
                let converted = $cast_fn(typed.value(row));
                $builder.append_variant(Variant::from(converted));
            } else {
                $builder.append_null();
            }
        }
    }};
}
59
/// Same row-by-row conversion as `cast_conversion!`, for downcast methods
/// that take no type parameter (e.g. `as_fixed_size_binary`).
///
/// `$method()` downcasts `$input`, `$cast_fn` maps each valid element to a
/// value that `Variant::from` accepts, and null rows are appended to
/// `$builder` as null entries.
macro_rules! cast_conversion_nongeneric {
    ($method:ident, $cast_fn:expr, $input:expr, $builder:expr) => {{
        let typed = $input.$method();
        for row in 0..typed.len() {
            if typed.is_valid(row) {
                let converted = $cast_fn(typed.value(row));
                $builder.append_variant(Variant::from(converted));
            } else {
                $builder.append_null();
            }
        }
    }};
}
73
74/// Casts a typed arrow [`Array`] to a [`VariantArray`]. This is useful when you
75/// need to convert a specific data type
76///
77/// # Arguments
78/// * `input` - A reference to the input [`Array`] to cast
79///
80/// # Notes
81/// If the input array element is null, the corresponding element in the
82/// output `VariantArray` will also be null (not `Variant::Null`).
83///
84/// # Example
85/// ```
86/// # use arrow::array::{Array, ArrayRef, Int64Array};
87/// # use parquet_variant::Variant;
88/// # use parquet_variant_compute::cast_to_variant::cast_to_variant;
89/// // input is an Int64Array, which will be cast to a VariantArray
90/// let input = Int64Array::from(vec![Some(1), None, Some(3)]);
91/// let result = cast_to_variant(&input).unwrap();
92/// assert_eq!(result.len(), 3);
93/// assert_eq!(result.value(0), Variant::Int64(1));
94/// assert!(result.is_null(1)); // note null, not Variant::Null
95/// assert_eq!(result.value(2), Variant::Int64(3));
96/// ```
97pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
98    let mut builder = VariantArrayBuilder::new(input.len());
99
100    let input_type = input.data_type();
101    // todo: handle other types like Boolean, Strings, Date, Timestamp, etc.
102    match input_type {
103        DataType::Binary => {
104            cast_conversion!(BinaryType, as_bytes, |v| v, input, builder);
105        }
106        DataType::LargeBinary => {
107            cast_conversion!(LargeBinaryType, as_bytes, |v| v, input, builder);
108        }
109        DataType::BinaryView => {
110            cast_conversion!(BinaryViewType, as_byte_view, |v| v, input, builder);
111        }
112        DataType::Int8 => {
113            primitive_conversion!(Int8Type, input, builder);
114        }
115        DataType::Int16 => {
116            primitive_conversion!(Int16Type, input, builder);
117        }
118        DataType::Int32 => {
119            primitive_conversion!(Int32Type, input, builder);
120        }
121        DataType::Int64 => {
122            primitive_conversion!(Int64Type, input, builder);
123        }
124        DataType::UInt8 => {
125            primitive_conversion!(UInt8Type, input, builder);
126        }
127        DataType::UInt16 => {
128            primitive_conversion!(UInt16Type, input, builder);
129        }
130        DataType::UInt32 => {
131            primitive_conversion!(UInt32Type, input, builder);
132        }
133        DataType::UInt64 => {
134            primitive_conversion!(UInt64Type, input, builder);
135        }
136        DataType::Float16 => {
137            cast_conversion!(
138                Float16Type,
139                as_primitive,
140                |v: f16| -> f32 { v.into() },
141                input,
142                builder
143            );
144        }
145        DataType::Float32 => {
146            primitive_conversion!(Float32Type, input, builder);
147        }
148        DataType::Float64 => {
149            primitive_conversion!(Float64Type, input, builder);
150        }
151        DataType::FixedSizeBinary(_) => {
152            cast_conversion_nongeneric!(as_fixed_size_binary, |v| v, input, builder);
153        }
154        dt => {
155            return Err(ArrowError::CastError(format!(
156                "Unsupported data type for casting to Variant: {dt:?}",
157            )));
158        }
159    };
160    Ok(builder.build())
161}
162
163// TODO do we need a cast_with_options to allow specifying conversion behavior,
164// e.g. how to handle overflows, whether to convert to Variant::Null or return
165// an error, etc. ?
166
167#[cfg(test)]
168mod tests {
169    use super::*;
170    use arrow::array::{
171        ArrayRef, FixedSizeBinaryBuilder, Float16Array, Float32Array, Float64Array,
172        GenericByteBuilder, GenericByteViewBuilder, Int16Array, Int32Array, Int64Array, Int8Array,
173        UInt16Array, UInt32Array, UInt64Array, UInt8Array,
174    };
175    use parquet_variant::{Variant, VariantDecimal16};
176    use std::{sync::Arc, vec};
177
178    #[test]
179    fn test_cast_to_variant_fixed_size_binary() {
180        let v1 = vec![1, 2];
181        let v2 = vec![3, 4];
182        let v3 = vec![5, 6];
183
184        let mut builder = FixedSizeBinaryBuilder::new(2);
185        builder.append_value(&v1).unwrap();
186        builder.append_value(&v2).unwrap();
187        builder.append_null();
188        builder.append_value(&v3).unwrap();
189        let array = builder.finish();
190
191        run_test(
192            Arc::new(array),
193            vec![
194                Some(Variant::Binary(&v1)),
195                Some(Variant::Binary(&v2)),
196                None,
197                Some(Variant::Binary(&v3)),
198            ],
199        );
200    }
201
202    #[test]
203    fn test_cast_to_variant_binary() {
204        // BinaryType
205        let mut builder = GenericByteBuilder::<BinaryType>::new();
206        builder.append_value(b"hello");
207        builder.append_value(b"");
208        builder.append_null();
209        builder.append_value(b"world");
210        let binary_array = builder.finish();
211        run_test(
212            Arc::new(binary_array),
213            vec![
214                Some(Variant::Binary(b"hello")),
215                Some(Variant::Binary(b"")),
216                None,
217                Some(Variant::Binary(b"world")),
218            ],
219        );
220
221        // LargeBinaryType
222        let mut builder = GenericByteBuilder::<LargeBinaryType>::new();
223        builder.append_value(b"hello");
224        builder.append_value(b"");
225        builder.append_null();
226        builder.append_value(b"world");
227        let large_binary_array = builder.finish();
228        run_test(
229            Arc::new(large_binary_array),
230            vec![
231                Some(Variant::Binary(b"hello")),
232                Some(Variant::Binary(b"")),
233                None,
234                Some(Variant::Binary(b"world")),
235            ],
236        );
237
238        // BinaryViewType
239        let mut builder = GenericByteViewBuilder::<BinaryViewType>::new();
240        builder.append_value(b"hello");
241        builder.append_value(b"");
242        builder.append_null();
243        builder.append_value(b"world");
244        let byte_view_array = builder.finish();
245        run_test(
246            Arc::new(byte_view_array),
247            vec![
248                Some(Variant::Binary(b"hello")),
249                Some(Variant::Binary(b"")),
250                None,
251                Some(Variant::Binary(b"world")),
252            ],
253        );
254    }
255
256    #[test]
257    fn test_cast_to_variant_int8() {
258        run_test(
259            Arc::new(Int8Array::from(vec![
260                Some(i8::MIN),
261                None,
262                Some(-1),
263                Some(1),
264                Some(i8::MAX),
265            ])),
266            vec![
267                Some(Variant::Int8(i8::MIN)),
268                None,
269                Some(Variant::Int8(-1)),
270                Some(Variant::Int8(1)),
271                Some(Variant::Int8(i8::MAX)),
272            ],
273        )
274    }
275
276    #[test]
277    fn test_cast_to_variant_int16() {
278        run_test(
279            Arc::new(Int16Array::from(vec![
280                Some(i16::MIN),
281                None,
282                Some(-1),
283                Some(1),
284                Some(i16::MAX),
285            ])),
286            vec![
287                Some(Variant::Int16(i16::MIN)),
288                None,
289                Some(Variant::Int16(-1)),
290                Some(Variant::Int16(1)),
291                Some(Variant::Int16(i16::MAX)),
292            ],
293        )
294    }
295
296    #[test]
297    fn test_cast_to_variant_int32() {
298        run_test(
299            Arc::new(Int32Array::from(vec![
300                Some(i32::MIN),
301                None,
302                Some(-1),
303                Some(1),
304                Some(i32::MAX),
305            ])),
306            vec![
307                Some(Variant::Int32(i32::MIN)),
308                None,
309                Some(Variant::Int32(-1)),
310                Some(Variant::Int32(1)),
311                Some(Variant::Int32(i32::MAX)),
312            ],
313        )
314    }
315
316    #[test]
317    fn test_cast_to_variant_int64() {
318        run_test(
319            Arc::new(Int64Array::from(vec![
320                Some(i64::MIN),
321                None,
322                Some(-1),
323                Some(1),
324                Some(i64::MAX),
325            ])),
326            vec![
327                Some(Variant::Int64(i64::MIN)),
328                None,
329                Some(Variant::Int64(-1)),
330                Some(Variant::Int64(1)),
331                Some(Variant::Int64(i64::MAX)),
332            ],
333        )
334    }
335
336    #[test]
337    fn test_cast_to_variant_uint8() {
338        run_test(
339            Arc::new(UInt8Array::from(vec![
340                Some(0),
341                None,
342                Some(1),
343                Some(127),
344                Some(u8::MAX),
345            ])),
346            vec![
347                Some(Variant::Int8(0)),
348                None,
349                Some(Variant::Int8(1)),
350                Some(Variant::Int8(127)),
351                Some(Variant::Int16(255)), // u8::MAX cannot fit in Int8
352            ],
353        )
354    }
355
356    #[test]
357    fn test_cast_to_variant_uint16() {
358        run_test(
359            Arc::new(UInt16Array::from(vec![
360                Some(0),
361                None,
362                Some(1),
363                Some(32767),
364                Some(u16::MAX),
365            ])),
366            vec![
367                Some(Variant::Int16(0)),
368                None,
369                Some(Variant::Int16(1)),
370                Some(Variant::Int16(32767)),
371                Some(Variant::Int32(65535)), // u16::MAX cannot fit in Int16
372            ],
373        )
374    }
375
376    #[test]
377    fn test_cast_to_variant_uint32() {
378        run_test(
379            Arc::new(UInt32Array::from(vec![
380                Some(0),
381                None,
382                Some(1),
383                Some(2147483647),
384                Some(u32::MAX),
385            ])),
386            vec![
387                Some(Variant::Int32(0)),
388                None,
389                Some(Variant::Int32(1)),
390                Some(Variant::Int32(2147483647)),
391                Some(Variant::Int64(4294967295)), // u32::MAX cannot fit in Int32
392            ],
393        )
394    }
395
396    #[test]
397    fn test_cast_to_variant_uint64() {
398        run_test(
399            Arc::new(UInt64Array::from(vec![
400                Some(0),
401                None,
402                Some(1),
403                Some(9223372036854775807),
404                Some(u64::MAX),
405            ])),
406            vec![
407                Some(Variant::Int64(0)),
408                None,
409                Some(Variant::Int64(1)),
410                Some(Variant::Int64(9223372036854775807)),
411                Some(Variant::Decimal16(
412                    // u64::MAX cannot fit in Int64
413                    VariantDecimal16::try_from(18446744073709551615).unwrap(),
414                )),
415            ],
416        )
417    }
418
419    #[test]
420    fn test_cast_to_variant_float16() {
421        run_test(
422            Arc::new(Float16Array::from(vec![
423                Some(f16::MIN),
424                None,
425                Some(f16::from_f32(-1.5)),
426                Some(f16::from_f32(0.0)),
427                Some(f16::from_f32(1.5)),
428                Some(f16::MAX),
429            ])),
430            vec![
431                Some(Variant::Float(f16::MIN.into())),
432                None,
433                Some(Variant::Float(-1.5)),
434                Some(Variant::Float(0.0)),
435                Some(Variant::Float(1.5)),
436                Some(Variant::Float(f16::MAX.into())),
437            ],
438        )
439    }
440
441    #[test]
442    fn test_cast_to_variant_float32() {
443        run_test(
444            Arc::new(Float32Array::from(vec![
445                Some(f32::MIN),
446                None,
447                Some(-1.5),
448                Some(0.0),
449                Some(1.5),
450                Some(f32::MAX),
451            ])),
452            vec![
453                Some(Variant::Float(f32::MIN)),
454                None,
455                Some(Variant::Float(-1.5)),
456                Some(Variant::Float(0.0)),
457                Some(Variant::Float(1.5)),
458                Some(Variant::Float(f32::MAX)),
459            ],
460        )
461    }
462
463    #[test]
464    fn test_cast_to_variant_float64() {
465        run_test(
466            Arc::new(Float64Array::from(vec![
467                Some(f64::MIN),
468                None,
469                Some(-1.5),
470                Some(0.0),
471                Some(1.5),
472                Some(f64::MAX),
473            ])),
474            vec![
475                Some(Variant::Double(f64::MIN)),
476                None,
477                Some(Variant::Double(-1.5)),
478                Some(Variant::Double(0.0)),
479                Some(Variant::Double(1.5)),
480                Some(Variant::Double(f64::MAX)),
481            ],
482        )
483    }
484
485    /// Converts the given `Array` to a `VariantArray` and tests the conversion
486    /// against the expected values. It also tests the handling of nulls by
487    /// setting one element to null and verifying the output.
488    fn run_test(values: ArrayRef, expected: Vec<Option<Variant>>) {
489        // test without nulls
490        let variant_array = cast_to_variant(&values).unwrap();
491        assert_eq!(variant_array.len(), expected.len());
492        for (i, expected_value) in expected.iter().enumerate() {
493            match expected_value {
494                Some(value) => {
495                    assert!(!variant_array.is_null(i), "Expected non-null at index {i}");
496                    assert_eq!(variant_array.value(i), *value, "mismatch at index {i}");
497                }
498                None => {
499                    assert!(variant_array.is_null(i), "Expected null at index {i}");
500                }
501            }
502        }
503    }
504}