1use crate::arrow_to_variant::ListLikeArray;
21use crate::{BorrowedShreddingState, VariantArray, VariantValueArrayBuilder};
22use arrow::array::{
23    Array, AsArray as _, BinaryViewArray, BooleanArray, FixedSizeBinaryArray, FixedSizeListArray,
24    GenericListArray, GenericListViewArray, PrimitiveArray, StringArray, StructArray,
25};
26use arrow::buffer::NullBuffer;
27use arrow::datatypes::{
28    ArrowPrimitiveType, DataType, Date32Type, Decimal32Type, Decimal64Type, Decimal128Type,
29    DecimalType, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type,
30    Time64MicrosecondType, TimeUnit, TimestampMicrosecondType, TimestampNanosecondType,
31};
32use arrow::error::{ArrowError, Result};
33use arrow::temporal_conversions::time64us_to_time;
34use chrono::{DateTime, Utc};
35use indexmap::IndexMap;
36use parquet_variant::{
37    ObjectFieldBuilder, Variant, VariantBuilderExt, VariantDecimal4, VariantDecimal8,
38    VariantDecimal16, VariantDecimalType, VariantMetadata,
39};
40use std::marker::PhantomData;
41use uuid::Uuid;
42
43pub fn unshred_variant(array: &VariantArray) -> Result<VariantArray> {
59    if array.typed_value_field().is_none() && array.value_field().is_some() {
61        return Ok(array.clone());
62    }
63
64    let nulls = array.nulls();
67    let mut row_builder = UnshredVariantRowBuilder::try_new_opt(array.shredding_state().borrow())?
68        .unwrap_or_else(|| UnshredVariantRowBuilder::null(nulls));
69
70    let metadata = array.metadata_field();
71    let mut value_builder = VariantValueArrayBuilder::new(array.len());
72    for i in 0..array.len() {
73        if array.is_null(i) {
74            value_builder.append_null();
75        } else {
76            let metadata = VariantMetadata::new(metadata.value(i));
77            let mut value_builder = value_builder.builder_ext(&metadata);
78            row_builder.append_row(&mut value_builder, &metadata, i)?;
79        }
80    }
81
82    let value = value_builder.build()?;
83    Ok(VariantArray::from_parts(
84        metadata.clone(),
85        Some(value),
86        None,
87        nulls.cloned(),
88    ))
89}
90
91enum UnshredVariantRowBuilder<'a> {
93    PrimitiveInt8(UnshredPrimitiveRowBuilder<'a, PrimitiveArray<Int8Type>>),
94    PrimitiveInt16(UnshredPrimitiveRowBuilder<'a, PrimitiveArray<Int16Type>>),
95    PrimitiveInt32(UnshredPrimitiveRowBuilder<'a, PrimitiveArray<Int32Type>>),
96    PrimitiveInt64(UnshredPrimitiveRowBuilder<'a, PrimitiveArray<Int64Type>>),
97    PrimitiveFloat32(UnshredPrimitiveRowBuilder<'a, PrimitiveArray<Float32Type>>),
98    PrimitiveFloat64(UnshredPrimitiveRowBuilder<'a, PrimitiveArray<Float64Type>>),
99    Decimal32(DecimalUnshredRowBuilder<'a, Decimal32Type, VariantDecimal4>),
100    Decimal64(DecimalUnshredRowBuilder<'a, Decimal64Type, VariantDecimal8>),
101    Decimal128(DecimalUnshredRowBuilder<'a, Decimal128Type, VariantDecimal16>),
102    PrimitiveDate32(UnshredPrimitiveRowBuilder<'a, PrimitiveArray<Date32Type>>),
103    PrimitiveTime64(UnshredPrimitiveRowBuilder<'a, PrimitiveArray<Time64MicrosecondType>>),
104    TimestampMicrosecond(TimestampUnshredRowBuilder<'a, TimestampMicrosecondType>),
105    TimestampNanosecond(TimestampUnshredRowBuilder<'a, TimestampNanosecondType>),
106    PrimitiveBoolean(UnshredPrimitiveRowBuilder<'a, BooleanArray>),
107    PrimitiveString(UnshredPrimitiveRowBuilder<'a, StringArray>),
108    PrimitiveBinaryView(UnshredPrimitiveRowBuilder<'a, BinaryViewArray>),
109    PrimitiveUuid(UnshredPrimitiveRowBuilder<'a, FixedSizeBinaryArray>),
110    List(ListUnshredVariantBuilder<'a, GenericListArray<i32>>),
111    LargeList(ListUnshredVariantBuilder<'a, GenericListArray<i64>>),
112    ListView(ListUnshredVariantBuilder<'a, GenericListViewArray<i32>>),
113    LargeListView(ListUnshredVariantBuilder<'a, GenericListViewArray<i64>>),
114    FixedSizeList(ListUnshredVariantBuilder<'a, FixedSizeListArray>),
115    Struct(StructUnshredVariantBuilder<'a>),
116    ValueOnly(ValueOnlyUnshredVariantBuilder<'a>),
117    Null(NullUnshredVariantBuilder<'a>),
118}
119
120impl<'a> UnshredVariantRowBuilder<'a> {
121    fn null(nulls: Option<&'a NullBuffer>) -> Self {
123        Self::Null(NullUnshredVariantBuilder::new(nulls))
124    }
125
126    fn append_row(
128        &mut self,
129        builder: &mut impl VariantBuilderExt,
130        metadata: &VariantMetadata,
131        index: usize,
132    ) -> Result<()> {
133        match self {
134            Self::PrimitiveInt8(b) => b.append_row(builder, metadata, index),
135            Self::PrimitiveInt16(b) => b.append_row(builder, metadata, index),
136            Self::PrimitiveInt32(b) => b.append_row(builder, metadata, index),
137            Self::PrimitiveInt64(b) => b.append_row(builder, metadata, index),
138            Self::PrimitiveFloat32(b) => b.append_row(builder, metadata, index),
139            Self::PrimitiveFloat64(b) => b.append_row(builder, metadata, index),
140            Self::Decimal32(b) => b.append_row(builder, metadata, index),
141            Self::Decimal64(b) => b.append_row(builder, metadata, index),
142            Self::Decimal128(b) => b.append_row(builder, metadata, index),
143            Self::PrimitiveDate32(b) => b.append_row(builder, metadata, index),
144            Self::PrimitiveTime64(b) => b.append_row(builder, metadata, index),
145            Self::TimestampMicrosecond(b) => b.append_row(builder, metadata, index),
146            Self::TimestampNanosecond(b) => b.append_row(builder, metadata, index),
147            Self::PrimitiveBoolean(b) => b.append_row(builder, metadata, index),
148            Self::PrimitiveString(b) => b.append_row(builder, metadata, index),
149            Self::PrimitiveBinaryView(b) => b.append_row(builder, metadata, index),
150            Self::PrimitiveUuid(b) => b.append_row(builder, metadata, index),
151            Self::List(b) => b.append_row(builder, metadata, index),
152            Self::LargeList(b) => b.append_row(builder, metadata, index),
153            Self::ListView(b) => b.append_row(builder, metadata, index),
154            Self::LargeListView(b) => b.append_row(builder, metadata, index),
155            Self::FixedSizeList(b) => b.append_row(builder, metadata, index),
156            Self::Struct(b) => b.append_row(builder, metadata, index),
157            Self::ValueOnly(b) => b.append_row(builder, metadata, index),
158            Self::Null(b) => b.append_row(builder, metadata, index),
159        }
160    }
161
162    fn try_new_opt(shredding_state: BorrowedShreddingState<'a>) -> Result<Option<Self>> {
165        let value = shredding_state.value_field();
166        let typed_value = shredding_state.typed_value_field();
167        let Some(typed_value) = typed_value else {
168            return Ok(value.map(|v| Self::ValueOnly(ValueOnlyUnshredVariantBuilder::new(v))));
170        };
171
172        macro_rules! primitive_builder {
174            ($enum_variant:ident, $cast_fn:ident) => {
175                Self::$enum_variant(UnshredPrimitiveRowBuilder::new(
176                    value,
177                    typed_value.$cast_fn(),
178                ))
179            };
180        }
181
182        let builder = match typed_value.data_type() {
183            DataType::Int8 => primitive_builder!(PrimitiveInt8, as_primitive),
184            DataType::Int16 => primitive_builder!(PrimitiveInt16, as_primitive),
185            DataType::Int32 => primitive_builder!(PrimitiveInt32, as_primitive),
186            DataType::Int64 => primitive_builder!(PrimitiveInt64, as_primitive),
187            DataType::Float32 => primitive_builder!(PrimitiveFloat32, as_primitive),
188            DataType::Float64 => primitive_builder!(PrimitiveFloat64, as_primitive),
189            DataType::Decimal32(p, s) if VariantDecimal4::is_valid_precision_and_scale(p, s) => {
190                Self::Decimal32(DecimalUnshredRowBuilder::new(value, typed_value, *s as _))
191            }
192            DataType::Decimal64(p, s) if VariantDecimal8::is_valid_precision_and_scale(p, s) => {
193                Self::Decimal64(DecimalUnshredRowBuilder::new(value, typed_value, *s as _))
194            }
195            DataType::Decimal128(p, s) if VariantDecimal16::is_valid_precision_and_scale(p, s) => {
196                Self::Decimal128(DecimalUnshredRowBuilder::new(value, typed_value, *s as _))
197            }
198            DataType::Decimal32(_, _)
199            | DataType::Decimal64(_, _)
200            | DataType::Decimal128(_, _)
201            | DataType::Decimal256(_, _) => {
202                return Err(ArrowError::InvalidArgumentError(format!(
203                    "{} is not a valid variant shredding type",
204                    typed_value.data_type()
205                )));
206            }
207            DataType::Date32 => primitive_builder!(PrimitiveDate32, as_primitive),
208            DataType::Time64(TimeUnit::Microsecond) => {
209                primitive_builder!(PrimitiveTime64, as_primitive)
210            }
211            DataType::Time64(time_unit) => {
212                return Err(ArrowError::InvalidArgumentError(format!(
213                    "Time64({time_unit}) is not a valid variant shredding type",
214                )));
215            }
216            DataType::Timestamp(TimeUnit::Microsecond, timezone) => Self::TimestampMicrosecond(
217                TimestampUnshredRowBuilder::new(value, typed_value, timezone.is_some()),
218            ),
219            DataType::Timestamp(TimeUnit::Nanosecond, timezone) => Self::TimestampNanosecond(
220                TimestampUnshredRowBuilder::new(value, typed_value, timezone.is_some()),
221            ),
222            DataType::Timestamp(time_unit, _) => {
223                return Err(ArrowError::InvalidArgumentError(format!(
224                    "Timestamp({time_unit}) is not a valid variant shredding type",
225                )));
226            }
227            DataType::Boolean => primitive_builder!(PrimitiveBoolean, as_boolean),
228            DataType::Utf8 => primitive_builder!(PrimitiveString, as_string),
229            DataType::BinaryView => primitive_builder!(PrimitiveBinaryView, as_binary_view),
230            DataType::FixedSizeBinary(16) => {
231                primitive_builder!(PrimitiveUuid, as_fixed_size_binary)
232            }
233            DataType::FixedSizeBinary(size) => {
234                return Err(ArrowError::InvalidArgumentError(format!(
235                    "FixedSizeBinary({size}) is not a valid variant shredding type",
236                )));
237            }
238            DataType::Struct(_) => Self::Struct(StructUnshredVariantBuilder::try_new(
239                value,
240                typed_value.as_struct(),
241            )?),
242            DataType::List(_) => Self::List(ListUnshredVariantBuilder::try_new(
243                value,
244                typed_value.as_list(),
245            )?),
246            DataType::LargeList(_) => Self::LargeList(ListUnshredVariantBuilder::try_new(
247                value,
248                typed_value.as_list(),
249            )?),
250            DataType::ListView(_) => Self::ListView(ListUnshredVariantBuilder::try_new(
251                value,
252                typed_value.as_list_view(),
253            )?),
254            DataType::LargeListView(_) => Self::LargeListView(ListUnshredVariantBuilder::try_new(
255                value,
256                typed_value.as_list_view(),
257            )?),
258            DataType::FixedSizeList(_, _) => Self::FixedSizeList(
259                ListUnshredVariantBuilder::try_new(value, typed_value.as_fixed_size_list())?,
260            ),
261            _ => {
262                return Err(ArrowError::NotYetImplemented(format!(
263                    "Unshredding not yet supported for type: {}",
264                    typed_value.data_type()
265                )));
266            }
267        };
268        Ok(Some(builder))
269    }
270}
271
272struct NullUnshredVariantBuilder<'a> {
274    nulls: Option<&'a NullBuffer>,
275}
276
277impl<'a> NullUnshredVariantBuilder<'a> {
278    fn new(nulls: Option<&'a NullBuffer>) -> Self {
279        Self { nulls }
280    }
281
282    fn append_row(
283        &mut self,
284        builder: &mut impl VariantBuilderExt,
285        _metadata: &VariantMetadata,
286        index: usize,
287    ) -> Result<()> {
288        if self.nulls.is_some_and(|nulls| nulls.is_null(index)) {
289            builder.append_null();
290        } else {
291            builder.append_value(Variant::Null);
292        }
293        Ok(())
294    }
295}
296
297struct ValueOnlyUnshredVariantBuilder<'a> {
299    value: &'a arrow::array::BinaryViewArray,
300}
301
302impl<'a> ValueOnlyUnshredVariantBuilder<'a> {
303    fn new(value: &'a BinaryViewArray) -> Self {
304        Self { value }
305    }
306
307    fn append_row(
308        &mut self,
309        builder: &mut impl VariantBuilderExt,
310        metadata: &VariantMetadata,
311        index: usize,
312    ) -> Result<()> {
313        if self.value.is_null(index) {
314            builder.append_null();
315        } else {
316            let variant = Variant::new_with_metadata(metadata.clone(), self.value.value(index));
317            builder.append_value(variant);
318        }
319        Ok(())
320    }
321}
322
323trait AppendToVariantBuilder: Array {
326    fn append_to_variant_builder(
327        &self,
328        builder: &mut impl VariantBuilderExt,
329        index: usize,
330    ) -> Result<()>;
331}
332
/// Handles the parts of a row that do not need `typed_value`. It is shared by every row builder
/// that has `value: Option<&BinaryViewArray>` and `typed_value` fields.
///
/// Must be expanded inside a method returning `Result<()>`:
/// - If `typed_value` is null at `$index`, it appends the `value` variant, or a null when `value`
///   is absent or null at that row, and then returns `Ok(())` from the *enclosing method*.
/// - If `typed_value` is non-null and `$partial_shredding` is `false`, a non-null `value` is
///   rejected with `InvalidArgumentError`. Only the struct builder passes `true`.
/// - Otherwise it evaluates to the `value` variant at `$index`, if any. A partially shredded
///   struct merges those leftover fields.
macro_rules! handle_unshredded_case {
    ($self:expr, $builder:expr, $metadata:expr, $index:expr, $partial_shredding:expr) => {{
        // A null slot in `value` is treated the same as a missing `value` column.
        let value = $self.value.as_ref().filter(|v| v.is_valid($index));
        let value = value.map(|v| Variant::new_with_metadata($metadata.clone(), v.value($index)));

        if $self.typed_value.is_null($index) {
            // Not shredded at this row: fall back to `value`, or append a null if it is absent.
            match value {
                Some(value) => $builder.append_value(value),
                None => $builder.append_null(),
            }
            return Ok(());
        }

        if !$partial_shredding && value.is_some() {
            // Only builders that allow partial shredding may see both columns set.
            return Err(ArrowError::InvalidArgumentError(
                "Invalid shredded variant: both value and typed_value are non-null".to_string(),
            ));
        }

        // Leftover `value` for the caller; always `None` when partial shredding is disallowed.
        value
    }};
}
360
361struct UnshredPrimitiveRowBuilder<'a, T> {
363    value: Option<&'a BinaryViewArray>,
364    typed_value: &'a T,
365}
366
367impl<'a, T: AppendToVariantBuilder> UnshredPrimitiveRowBuilder<'a, T> {
368    fn new(value: Option<&'a BinaryViewArray>, typed_value: &'a T) -> Self {
369        Self { value, typed_value }
370    }
371
372    fn append_row(
373        &mut self,
374        builder: &mut impl VariantBuilderExt,
375        metadata: &VariantMetadata,
376        index: usize,
377    ) -> Result<()> {
378        handle_unshredded_case!(self, builder, metadata, index, false);
379
380        self.typed_value.append_to_variant_builder(builder, index)
382    }
383}
384
/// Implements [`AppendToVariantBuilder`] for an Arrow array type by appending `self.value(index)`.
///
/// The optional `|v| transform` form binds the raw element to `v` and appends the result of
/// `transform` instead. `transform` is expanded inside the generated method, so it may use `?` to
/// return an `ArrowError`.
macro_rules! impl_append_to_variant_builder {
    ($array_type:ty $(, |$v:ident| $transform:expr)? ) => {
        impl AppendToVariantBuilder for $array_type {
            fn append_to_variant_builder(
                &self,
                builder: &mut impl VariantBuilderExt,
                index: usize,
            ) -> Result<()> {
                let value = self.value(index);
                $(
                    // Rebind under the caller-chosen name so `$transform` can refer to it.
                    let $v = value;
                    let value = $transform;
                )?
                builder.append_value(value);
                Ok(())
            }
        }
    };
}
405
406impl_append_to_variant_builder!(BooleanArray);
407impl_append_to_variant_builder!(StringArray);
408impl_append_to_variant_builder!(BinaryViewArray);
409impl_append_to_variant_builder!(PrimitiveArray<Int8Type>);
410impl_append_to_variant_builder!(PrimitiveArray<Int16Type>);
411impl_append_to_variant_builder!(PrimitiveArray<Int32Type>);
412impl_append_to_variant_builder!(PrimitiveArray<Int64Type>);
413impl_append_to_variant_builder!(PrimitiveArray<Float32Type>);
414impl_append_to_variant_builder!(PrimitiveArray<Float64Type>);
415
416impl_append_to_variant_builder!(PrimitiveArray<Date32Type>, |days_since_epoch| {
417    Date32Type::to_naive_date(days_since_epoch)
418});
419
420impl_append_to_variant_builder!(
421    PrimitiveArray<Time64MicrosecondType>,
422    |micros_since_midnight| {
423        time64us_to_time(micros_since_midnight).ok_or_else(|| {
424            ArrowError::InvalidArgumentError(format!(
425                "Invalid Time64 microsecond value: {micros_since_midnight}"
426            ))
427        })?
428    }
429);
430
431impl_append_to_variant_builder!(FixedSizeBinaryArray, |bytes| {
434    Uuid::from_slice(bytes).unwrap()
435});
436
437trait TimestampType: ArrowPrimitiveType<Native = i64> {
439    fn to_datetime_utc(value: i64) -> Result<DateTime<Utc>>;
440}
441
442impl TimestampType for TimestampMicrosecondType {
443    fn to_datetime_utc(micros: i64) -> Result<DateTime<Utc>> {
444        DateTime::from_timestamp_micros(micros).ok_or_else(|| {
445            ArrowError::InvalidArgumentError(format!(
446                "Invalid timestamp microsecond value: {micros}"
447            ))
448        })
449    }
450}
451
452impl TimestampType for TimestampNanosecondType {
453    fn to_datetime_utc(nanos: i64) -> Result<DateTime<Utc>> {
454        Ok(DateTime::from_timestamp_nanos(nanos))
455    }
456}
457
458struct TimestampUnshredRowBuilder<'a, T: TimestampType> {
460    value: Option<&'a BinaryViewArray>,
461    typed_value: &'a PrimitiveArray<T>,
462    has_timezone: bool,
463}
464
465impl<'a, T: TimestampType> TimestampUnshredRowBuilder<'a, T> {
466    fn new(
467        value: Option<&'a BinaryViewArray>,
468        typed_value: &'a dyn Array,
469        has_timezone: bool,
470    ) -> Self {
471        Self {
472            value,
473            typed_value: typed_value.as_primitive(),
474            has_timezone,
475        }
476    }
477
478    fn append_row(
479        &mut self,
480        builder: &mut impl VariantBuilderExt,
481        metadata: &VariantMetadata,
482        index: usize,
483    ) -> Result<()> {
484        handle_unshredded_case!(self, builder, metadata, index, false);
485
486        let timestamp_value = self.typed_value.value(index);
488        let dt = T::to_datetime_utc(timestamp_value)?;
489        if self.has_timezone {
490            builder.append_value(dt);
491        } else {
492            builder.append_value(dt.naive_utc());
493        }
494        Ok(())
495    }
496}
497
498struct DecimalUnshredRowBuilder<'a, A: DecimalType, V>
500where
501    V: VariantDecimalType<Native = A::Native>,
502{
503    value: Option<&'a BinaryViewArray>,
504    typed_value: &'a PrimitiveArray<A>,
505    scale: i8,
506    _phantom: PhantomData<V>,
507}
508
509impl<'a, A: DecimalType, V> DecimalUnshredRowBuilder<'a, A, V>
510where
511    V: VariantDecimalType<Native = A::Native>,
512{
513    fn new(value: Option<&'a BinaryViewArray>, typed_value: &'a dyn Array, scale: i8) -> Self {
514        Self {
515            value,
516            typed_value: typed_value.as_primitive(),
517            scale,
518            _phantom: PhantomData,
519        }
520    }
521
522    fn append_row(
523        &mut self,
524        builder: &mut impl VariantBuilderExt,
525        metadata: &VariantMetadata,
526        index: usize,
527    ) -> Result<()> {
528        handle_unshredded_case!(self, builder, metadata, index, false);
529
530        let raw = self.typed_value.value(index);
531        let variant = V::try_new_with_signed_scale(raw, self.scale)?;
532        builder.append_value(variant);
533        Ok(())
534    }
535}
536
537struct StructUnshredVariantBuilder<'a> {
539    value: Option<&'a arrow::array::BinaryViewArray>,
540    typed_value: &'a arrow::array::StructArray,
541    field_unshredders: IndexMap<&'a str, Option<UnshredVariantRowBuilder<'a>>>,
542}
543
544impl<'a> StructUnshredVariantBuilder<'a> {
545    fn try_new(value: Option<&'a BinaryViewArray>, typed_value: &'a StructArray) -> Result<Self> {
546        let mut field_unshredders = IndexMap::new();
548        for (field, field_array) in typed_value.fields().iter().zip(typed_value.columns()) {
549            let Some(field_array) = field_array.as_struct_opt() else {
551                return Err(ArrowError::InvalidArgumentError(format!(
552                    "Invalid shredded variant object field: expected Struct, got {}",
553                    field_array.data_type()
554                )));
555            };
556            let field_unshredder = UnshredVariantRowBuilder::try_new_opt(field_array.try_into()?)?;
557            field_unshredders.insert(field.name().as_ref(), field_unshredder);
558        }
559
560        Ok(Self {
561            value,
562            typed_value,
563            field_unshredders,
564        })
565    }
566
567    fn append_row(
568        &mut self,
569        builder: &mut impl VariantBuilderExt,
570        metadata: &VariantMetadata,
571        index: usize,
572    ) -> Result<()> {
573        let value = handle_unshredded_case!(self, builder, metadata, index, true);
574
575        let mut object_builder = builder.try_new_object()?;
577
578        for (field_name, field_unshredder_opt) in &mut self.field_unshredders {
580            if let Some(field_unshredder) = field_unshredder_opt {
581                let mut field_builder = ObjectFieldBuilder::new(field_name, &mut object_builder);
582                field_unshredder.append_row(&mut field_builder, metadata, index)?;
583            }
584        }
585
586        if let Some(value) = value {
588            let Variant::Object(object) = value else {
589                return Err(ArrowError::InvalidArgumentError(
590                    "Expected object in value field for partially shredded struct".to_string(),
591                ));
592            };
593
594            for (field_name, field_value) in object.iter() {
595                if self.field_unshredders.contains_key(field_name) {
596                    return Err(ArrowError::InvalidArgumentError(format!(
597                        "Field '{field_name}' appears in both typed_value and value",
598                    )));
599                }
600                object_builder.insert_bytes(field_name, field_value);
601            }
602        }
603
604        object_builder.finish();
605        Ok(())
606    }
607}
608
609struct ListUnshredVariantBuilder<'a, L: ListLikeArray> {
611    value: Option<&'a BinaryViewArray>,
612    typed_value: &'a L,
613    element_unshredder: Box<UnshredVariantRowBuilder<'a>>,
614}
615
616impl<'a, L: ListLikeArray> ListUnshredVariantBuilder<'a, L> {
617    fn try_new(value: Option<&'a BinaryViewArray>, typed_value: &'a L) -> Result<Self> {
618        let element_values = typed_value.values();
621
622        let Some(element_values) = element_values.as_struct_opt() else {
625            return Err(ArrowError::InvalidArgumentError(format!(
626                "Invalid shredded variant array element: expected Struct, got {}",
627                element_values.data_type()
628            )));
629        };
630
631        let element_unshredder = UnshredVariantRowBuilder::try_new_opt(element_values.try_into()?)?
636            .unwrap_or_else(|| UnshredVariantRowBuilder::null(None));
637
638        Ok(Self {
639            value,
640            typed_value,
641            element_unshredder: Box::new(element_unshredder),
642        })
643    }
644
645    fn append_row(
646        &mut self,
647        builder: &mut impl VariantBuilderExt,
648        metadata: &VariantMetadata,
649        index: usize,
650    ) -> Result<()> {
651        handle_unshredded_case!(self, builder, metadata, index, false);
652
653        let mut list_builder = builder.try_new_list()?;
655        for element_index in self.typed_value.element_range(index) {
656            self.element_unshredder
657                .append_row(&mut list_builder, metadata, element_index)?;
658        }
659
660        list_builder.finish();
661        Ok(())
662    }
663}
664
665