1use crate::{VariantArray, VariantArrayBuilder};
19use arrow::array::{Array, AsArray};
20use arrow::datatypes::{
21 BinaryType, BinaryViewType, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type,
22 Int64Type, Int8Type, LargeBinaryType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
23};
24use arrow_schema::{ArrowError, DataType};
25use half::f16;
26use parquet_variant::Variant;
27
/// Appends every slot of a primitive array to a variant builder.
///
/// * `$t` - type argument passed to `as_primitive` when downcasting `$input`
/// * `$input` - the array to read from
/// * `$builder` - receives one entry per slot: `append_null()` for a null slot,
///   otherwise `append_variant(Variant::from(value))`
macro_rules! primitive_conversion {
    ($t:ty, $input:expr, $builder:expr) => {{
        let values = $input.as_primitive::<$t>();
        for row in 0..values.len() {
            if values.is_null(row) {
                // Null slots stay null in the output.
                $builder.append_null();
            } else {
                $builder.append_variant(Variant::from(values.value(row)));
            }
        }
    }};
}
42
/// Appends every slot of an array to a variant builder, running each non-null
/// value through a conversion function first.
///
/// * `$t` - type argument passed to the `$method` accessor
/// * `$method` - generic accessor invoked on `$input` to obtain the typed array
/// * `$cast_fn` - applied to each non-null value before `Variant::from`
/// * `$input` - the array to read from
/// * `$builder` - receives `append_null()` for null slots and
///   `append_variant(..)` for converted values
macro_rules! cast_conversion {
    ($t:ty, $method:ident, $cast_fn:expr, $input:expr, $builder:expr) => {{
        let values = $input.$method::<$t>();
        for row in 0..values.len() {
            if values.is_null(row) {
                // Null slots stay null in the output.
                $builder.append_null();
            } else {
                let converted = $cast_fn(values.value(row));
                $builder.append_variant(Variant::from(converted));
            }
        }
    }};
}
59
/// Same as `cast_conversion!`, but for accessors that take no type argument.
///
/// * `$method` - non-generic accessor invoked on `$input` to obtain the typed array
/// * `$cast_fn` - applied to each non-null value before `Variant::from`
/// * `$input` - the array to read from
/// * `$builder` - receives `append_null()` for null slots and
///   `append_variant(..)` for converted values
macro_rules! cast_conversion_nongeneric {
    ($method:ident, $cast_fn:expr, $input:expr, $builder:expr) => {{
        let values = $input.$method();
        for row in 0..values.len() {
            if values.is_null(row) {
                // Null slots stay null in the output.
                $builder.append_null();
            } else {
                let converted = $cast_fn(values.value(row));
                $builder.append_variant(Variant::from(converted));
            }
        }
    }};
}
73
74pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
98 let mut builder = VariantArrayBuilder::new(input.len());
99
100 let input_type = input.data_type();
101 match input_type {
103 DataType::Binary => {
104 cast_conversion!(BinaryType, as_bytes, |v| v, input, builder);
105 }
106 DataType::LargeBinary => {
107 cast_conversion!(LargeBinaryType, as_bytes, |v| v, input, builder);
108 }
109 DataType::BinaryView => {
110 cast_conversion!(BinaryViewType, as_byte_view, |v| v, input, builder);
111 }
112 DataType::Int8 => {
113 primitive_conversion!(Int8Type, input, builder);
114 }
115 DataType::Int16 => {
116 primitive_conversion!(Int16Type, input, builder);
117 }
118 DataType::Int32 => {
119 primitive_conversion!(Int32Type, input, builder);
120 }
121 DataType::Int64 => {
122 primitive_conversion!(Int64Type, input, builder);
123 }
124 DataType::UInt8 => {
125 primitive_conversion!(UInt8Type, input, builder);
126 }
127 DataType::UInt16 => {
128 primitive_conversion!(UInt16Type, input, builder);
129 }
130 DataType::UInt32 => {
131 primitive_conversion!(UInt32Type, input, builder);
132 }
133 DataType::UInt64 => {
134 primitive_conversion!(UInt64Type, input, builder);
135 }
136 DataType::Float16 => {
137 cast_conversion!(
138 Float16Type,
139 as_primitive,
140 |v: f16| -> f32 { v.into() },
141 input,
142 builder
143 );
144 }
145 DataType::Float32 => {
146 primitive_conversion!(Float32Type, input, builder);
147 }
148 DataType::Float64 => {
149 primitive_conversion!(Float64Type, input, builder);
150 }
151 DataType::FixedSizeBinary(_) => {
152 cast_conversion_nongeneric!(as_fixed_size_binary, |v| v, input, builder);
153 }
154 dt => {
155 return Err(ArrowError::CastError(format!(
156 "Unsupported data type for casting to Variant: {dt:?}",
157 )));
158 }
159 };
160 Ok(builder.build())
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{
        ArrayRef, FixedSizeBinaryBuilder, Float16Array, Float32Array, Float64Array,
        GenericByteBuilder, GenericByteViewBuilder, Int16Array, Int32Array, Int64Array, Int8Array,
        UInt16Array, UInt32Array, UInt64Array, UInt8Array,
    };
    use parquet_variant::{Variant, VariantDecimal16};
    use std::{sync::Arc, vec};

    /// Casts `values` with `cast_to_variant` and checks the result slot by slot:
    /// `None` in `expected` means the output slot must be null, `Some(v)` means it
    /// must be non-null and equal to `v`.
    fn run_test(values: ArrayRef, expected: Vec<Option<Variant>>) {
        let actual = cast_to_variant(&values).unwrap();
        assert_eq!(actual.len(), expected.len());

        for (i, expected_value) in expected.iter().enumerate() {
            if let Some(value) = expected_value {
                assert!(!actual.is_null(i), "Expected non-null at index {i}");
                assert_eq!(actual.value(i), *value, "mismatch at index {i}");
            } else {
                assert!(actual.is_null(i), "Expected null at index {i}");
            }
        }
    }

    #[test]
    fn test_cast_to_variant_fixed_size_binary() {
        let first = vec![1u8, 2];
        let second = vec![3u8, 4];
        let third = vec![5u8, 6];

        let mut fixed = FixedSizeBinaryBuilder::new(2);
        fixed.append_value(&first).unwrap();
        fixed.append_value(&second).unwrap();
        fixed.append_null();
        fixed.append_value(&third).unwrap();

        let expected = vec![
            Some(Variant::Binary(&first)),
            Some(Variant::Binary(&second)),
            None,
            Some(Variant::Binary(&third)),
        ];
        run_test(Arc::new(fixed.finish()), expected);
    }

    #[test]
    fn test_cast_to_variant_binary() {
        // All three binary layouts are filled with the same rows, so they share
        // one set of expectations.
        let expected = || {
            vec![
                Some(Variant::Binary(b"hello")),
                Some(Variant::Binary(b"")),
                None,
                Some(Variant::Binary(b"world")),
            ]
        };

        // Binary
        let mut binary = GenericByteBuilder::<BinaryType>::new();
        binary.append_value(b"hello");
        binary.append_value(b"");
        binary.append_null();
        binary.append_value(b"world");
        run_test(Arc::new(binary.finish()), expected());

        // LargeBinary
        let mut large_binary = GenericByteBuilder::<LargeBinaryType>::new();
        large_binary.append_value(b"hello");
        large_binary.append_value(b"");
        large_binary.append_null();
        large_binary.append_value(b"world");
        run_test(Arc::new(large_binary.finish()), expected());

        // BinaryView
        let mut binary_view = GenericByteViewBuilder::<BinaryViewType>::new();
        binary_view.append_value(b"hello");
        binary_view.append_value(b"");
        binary_view.append_null();
        binary_view.append_value(b"world");
        run_test(Arc::new(binary_view.finish()), expected());
    }

    #[test]
    fn test_cast_to_variant_int8() {
        let input = Int8Array::from(vec![
            Some(i8::MIN),
            None,
            Some(-1),
            Some(1),
            Some(i8::MAX),
        ]);
        let expected = vec![
            Some(Variant::Int8(i8::MIN)),
            None,
            Some(Variant::Int8(-1)),
            Some(Variant::Int8(1)),
            Some(Variant::Int8(i8::MAX)),
        ];
        run_test(Arc::new(input), expected);
    }

    #[test]
    fn test_cast_to_variant_int16() {
        let input = Int16Array::from(vec![
            Some(i16::MIN),
            None,
            Some(-1),
            Some(1),
            Some(i16::MAX),
        ]);
        let expected = vec![
            Some(Variant::Int16(i16::MIN)),
            None,
            Some(Variant::Int16(-1)),
            Some(Variant::Int16(1)),
            Some(Variant::Int16(i16::MAX)),
        ];
        run_test(Arc::new(input), expected);
    }

    #[test]
    fn test_cast_to_variant_int32() {
        let input = Int32Array::from(vec![
            Some(i32::MIN),
            None,
            Some(-1),
            Some(1),
            Some(i32::MAX),
        ]);
        let expected = vec![
            Some(Variant::Int32(i32::MIN)),
            None,
            Some(Variant::Int32(-1)),
            Some(Variant::Int32(1)),
            Some(Variant::Int32(i32::MAX)),
        ];
        run_test(Arc::new(input), expected);
    }

    #[test]
    fn test_cast_to_variant_int64() {
        let input = Int64Array::from(vec![
            Some(i64::MIN),
            None,
            Some(-1),
            Some(1),
            Some(i64::MAX),
        ]);
        let expected = vec![
            Some(Variant::Int64(i64::MIN)),
            None,
            Some(Variant::Int64(-1)),
            Some(Variant::Int64(1)),
            Some(Variant::Int64(i64::MAX)),
        ];
        run_test(Arc::new(input), expected);
    }

    #[test]
    fn test_cast_to_variant_uint8() {
        let input = UInt8Array::from(vec![
            Some(0),
            None,
            Some(1),
            Some(127),
            Some(u8::MAX),
        ]);
        let expected = vec![
            Some(Variant::Int8(0)),
            None,
            Some(Variant::Int8(1)),
            Some(Variant::Int8(127)),
            // u8::MAX is outside the Int8 range, so the next wider type is expected.
            Some(Variant::Int16(255)),
        ];
        run_test(Arc::new(input), expected);
    }

    #[test]
    fn test_cast_to_variant_uint16() {
        let input = UInt16Array::from(vec![
            Some(0),
            None,
            Some(1),
            Some(32767),
            Some(u16::MAX),
        ]);
        let expected = vec![
            Some(Variant::Int16(0)),
            None,
            Some(Variant::Int16(1)),
            Some(Variant::Int16(32767)),
            // u16::MAX is outside the Int16 range, so the next wider type is expected.
            Some(Variant::Int32(65535)),
        ];
        run_test(Arc::new(input), expected);
    }

    #[test]
    fn test_cast_to_variant_uint32() {
        let input = UInt32Array::from(vec![
            Some(0),
            None,
            Some(1),
            Some(2147483647),
            Some(u32::MAX),
        ]);
        let expected = vec![
            Some(Variant::Int32(0)),
            None,
            Some(Variant::Int32(1)),
            Some(Variant::Int32(2147483647)),
            // u32::MAX is outside the Int32 range, so the next wider type is expected.
            Some(Variant::Int64(4294967295)),
        ];
        run_test(Arc::new(input), expected);
    }

    #[test]
    fn test_cast_to_variant_uint64() {
        let input = UInt64Array::from(vec![
            Some(0),
            None,
            Some(1),
            Some(9223372036854775807),
            Some(u64::MAX),
        ]);
        // u64::MAX is outside the Int64 range; a 16-byte decimal is expected instead.
        let max_as_decimal = VariantDecimal16::try_from(18446744073709551615).unwrap();
        let expected = vec![
            Some(Variant::Int64(0)),
            None,
            Some(Variant::Int64(1)),
            Some(Variant::Int64(9223372036854775807)),
            Some(Variant::Decimal16(max_as_decimal)),
        ];
        run_test(Arc::new(input), expected);
    }

    #[test]
    fn test_cast_to_variant_float16() {
        let input = Float16Array::from(vec![
            Some(f16::MIN),
            None,
            Some(f16::from_f32(-1.5)),
            Some(f16::from_f32(0.0)),
            Some(f16::from_f32(1.5)),
            Some(f16::MAX),
        ]);
        // Half-precision inputs are expected to come back as single-precision floats.
        let expected = vec![
            Some(Variant::Float(f32::from(f16::MIN))),
            None,
            Some(Variant::Float(-1.5)),
            Some(Variant::Float(0.0)),
            Some(Variant::Float(1.5)),
            Some(Variant::Float(f32::from(f16::MAX))),
        ];
        run_test(Arc::new(input), expected);
    }

    #[test]
    fn test_cast_to_variant_float32() {
        let input = Float32Array::from(vec![
            Some(f32::MIN),
            None,
            Some(-1.5),
            Some(0.0),
            Some(1.5),
            Some(f32::MAX),
        ]);
        let expected = vec![
            Some(Variant::Float(f32::MIN)),
            None,
            Some(Variant::Float(-1.5)),
            Some(Variant::Float(0.0)),
            Some(Variant::Float(1.5)),
            Some(Variant::Float(f32::MAX)),
        ];
        run_test(Arc::new(input), expected);
    }

    #[test]
    fn test_cast_to_variant_float64() {
        let input = Float64Array::from(vec![
            Some(f64::MIN),
            None,
            Some(-1.5),
            Some(0.0),
            Some(1.5),
            Some(f64::MAX),
        ]);
        let expected = vec![
            Some(Variant::Double(f64::MIN)),
            None,
            Some(Variant::Double(-1.5)),
            Some(Variant::Double(0.0)),
            Some(Variant::Double(1.5)),
            Some(Variant::Double(f64::MAX)),
        ];
        run_test(Arc::new(input), expected);
    }
}