parquet_variant_compute/
variant_to_arrow.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::shred_variant::{
19    VariantToShreddedVariantRowBuilder, make_variant_to_shredded_variant_arrow_row_builder,
20};
21use crate::type_conversion::{
22    PrimitiveFromVariant, TimestampFromVariant, variant_to_unscaled_decimal,
23};
24use crate::variant_array::ShreddedVariantFieldArray;
25use crate::{VariantArray, VariantValueArrayBuilder};
26use arrow::array::{
27    ArrayRef, ArrowNativeTypeOp, BinaryBuilder, BinaryLikeArrayBuilder, BinaryViewArray,
28    BinaryViewBuilder, BooleanBuilder, FixedSizeBinaryBuilder, GenericListArray,
29    GenericListViewArray, LargeBinaryBuilder, LargeStringBuilder, NullArray, NullBufferBuilder,
30    OffsetSizeTrait, PrimitiveBuilder, StringBuilder, StringLikeArrayBuilder, StringViewBuilder,
31};
32use arrow::buffer::{OffsetBuffer, ScalarBuffer};
33use arrow::compute::{CastOptions, DecimalCast};
34use arrow::datatypes::{self, DataType, DecimalType};
35use arrow::error::{ArrowError, Result};
36use arrow_schema::{FieldRef, TimeUnit};
37use parquet_variant::{Variant, VariantList, VariantPath};
38use std::sync::Arc;
39
40/// Builder for converting variant values into strongly typed Arrow arrays.
41///
42/// Useful for variant_get kernels that need to extract specific paths from variant values, possibly
43/// with casting of leaf values to specific types.
44pub(crate) enum VariantToArrowRowBuilder<'a> {
45    Primitive(PrimitiveVariantToArrowRowBuilder<'a>),
46    BinaryVariant(VariantToBinaryVariantArrowRowBuilder),
47
48    // Path extraction wrapper - contains a boxed enum for any of the above
49    WithPath(VariantPathRowBuilder<'a>),
50}
51
52impl<'a> VariantToArrowRowBuilder<'a> {
53    pub fn append_null(&mut self) -> Result<()> {
54        use VariantToArrowRowBuilder::*;
55        match self {
56            Primitive(b) => b.append_null(),
57            BinaryVariant(b) => b.append_null(),
58            WithPath(path_builder) => path_builder.append_null(),
59        }
60    }
61
62    pub fn append_value(&mut self, value: Variant<'_, '_>) -> Result<bool> {
63        use VariantToArrowRowBuilder::*;
64        match self {
65            Primitive(b) => b.append_value(&value),
66            BinaryVariant(b) => b.append_value(value),
67            WithPath(path_builder) => path_builder.append_value(value),
68        }
69    }
70
71    pub fn finish(self) -> Result<ArrayRef> {
72        use VariantToArrowRowBuilder::*;
73        match self {
74            Primitive(b) => b.finish(),
75            BinaryVariant(b) => b.finish(),
76            WithPath(path_builder) => path_builder.finish(),
77        }
78    }
79}
80
81pub(crate) fn make_variant_to_arrow_row_builder<'a>(
82    metadata: &BinaryViewArray,
83    path: VariantPath<'a>,
84    data_type: Option<&'a DataType>,
85    cast_options: &'a CastOptions,
86    capacity: usize,
87) -> Result<VariantToArrowRowBuilder<'a>> {
88    use VariantToArrowRowBuilder::*;
89
90    let mut builder = match data_type {
91        // If no data type was requested, build an unshredded VariantArray.
92        None => BinaryVariant(VariantToBinaryVariantArrowRowBuilder::new(
93            metadata.clone(),
94            capacity,
95        )),
96        Some(DataType::Struct(_)) => {
97            return Err(ArrowError::NotYetImplemented(
98                "Converting unshredded variant objects to arrow structs".to_string(),
99            ));
100        }
101        Some(
102            DataType::List(_)
103            | DataType::LargeList(_)
104            | DataType::ListView(_)
105            | DataType::LargeListView(_)
106            | DataType::FixedSizeList(..),
107        ) => {
108            return Err(ArrowError::NotYetImplemented(
109                "Converting unshredded variant arrays to arrow lists".to_string(),
110            ));
111        }
112        Some(data_type) => {
113            let builder =
114                make_primitive_variant_to_arrow_row_builder(data_type, cast_options, capacity)?;
115            Primitive(builder)
116        }
117    };
118
119    // Wrap with path extraction if needed
120    if !path.is_empty() {
121        builder = WithPath(VariantPathRowBuilder {
122            builder: Box::new(builder),
123            path,
124        })
125    };
126
127    Ok(builder)
128}
129
130/// Builder for converting primitive variant values to Arrow arrays. It is used by both
131/// `VariantToArrowRowBuilder` (below) and `VariantToShreddedPrimitiveVariantRowBuilder` (in
132/// `shred_variant.rs`).
133pub(crate) enum PrimitiveVariantToArrowRowBuilder<'a> {
134    Null(VariantToNullArrowRowBuilder<'a>),
135    Boolean(VariantToBooleanArrowRowBuilder<'a>),
136    Int8(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Int8Type>),
137    Int16(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Int16Type>),
138    Int32(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Int32Type>),
139    Int64(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Int64Type>),
140    UInt8(VariantToPrimitiveArrowRowBuilder<'a, datatypes::UInt8Type>),
141    UInt16(VariantToPrimitiveArrowRowBuilder<'a, datatypes::UInt16Type>),
142    UInt32(VariantToPrimitiveArrowRowBuilder<'a, datatypes::UInt32Type>),
143    UInt64(VariantToPrimitiveArrowRowBuilder<'a, datatypes::UInt64Type>),
144    Float16(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Float16Type>),
145    Float32(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Float32Type>),
146    Float64(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Float64Type>),
147    Decimal32(VariantToDecimalArrowRowBuilder<'a, datatypes::Decimal32Type>),
148    Decimal64(VariantToDecimalArrowRowBuilder<'a, datatypes::Decimal64Type>),
149    Decimal128(VariantToDecimalArrowRowBuilder<'a, datatypes::Decimal128Type>),
150    Decimal256(VariantToDecimalArrowRowBuilder<'a, datatypes::Decimal256Type>),
151    TimestampSecond(VariantToTimestampArrowRowBuilder<'a, datatypes::TimestampSecondType>),
152    TimestampSecondNtz(VariantToTimestampNtzArrowRowBuilder<'a, datatypes::TimestampSecondType>),
153    TimestampMilli(VariantToTimestampArrowRowBuilder<'a, datatypes::TimestampMillisecondType>),
154    TimestampMilliNtz(
155        VariantToTimestampNtzArrowRowBuilder<'a, datatypes::TimestampMillisecondType>,
156    ),
157    TimestampMicro(VariantToTimestampArrowRowBuilder<'a, datatypes::TimestampMicrosecondType>),
158    TimestampMicroNtz(
159        VariantToTimestampNtzArrowRowBuilder<'a, datatypes::TimestampMicrosecondType>,
160    ),
161    TimestampNano(VariantToTimestampArrowRowBuilder<'a, datatypes::TimestampNanosecondType>),
162    TimestampNanoNtz(VariantToTimestampNtzArrowRowBuilder<'a, datatypes::TimestampNanosecondType>),
163    Time32Second(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Time32SecondType>),
164    Time32Milli(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Time32MillisecondType>),
165    Time64Micro(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Time64MicrosecondType>),
166    Time64Nano(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Time64NanosecondType>),
167    Date32(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Date32Type>),
168    Date64(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Date64Type>),
169    Uuid(VariantToUuidArrowRowBuilder<'a>),
170    String(VariantToStringArrowBuilder<'a, StringBuilder>),
171    LargeString(VariantToStringArrowBuilder<'a, LargeStringBuilder>),
172    StringView(VariantToStringArrowBuilder<'a, StringViewBuilder>),
173    Binary(VariantToBinaryArrowRowBuilder<'a, BinaryBuilder>),
174    LargeBinary(VariantToBinaryArrowRowBuilder<'a, LargeBinaryBuilder>),
175    BinaryView(VariantToBinaryArrowRowBuilder<'a, BinaryViewBuilder>),
176}
177
178impl<'a> PrimitiveVariantToArrowRowBuilder<'a> {
179    pub fn append_null(&mut self) -> Result<()> {
180        use PrimitiveVariantToArrowRowBuilder::*;
181        match self {
182            Null(b) => b.append_null(),
183            Boolean(b) => b.append_null(),
184            Int8(b) => b.append_null(),
185            Int16(b) => b.append_null(),
186            Int32(b) => b.append_null(),
187            Int64(b) => b.append_null(),
188            UInt8(b) => b.append_null(),
189            UInt16(b) => b.append_null(),
190            UInt32(b) => b.append_null(),
191            UInt64(b) => b.append_null(),
192            Float16(b) => b.append_null(),
193            Float32(b) => b.append_null(),
194            Float64(b) => b.append_null(),
195            Decimal32(b) => b.append_null(),
196            Decimal64(b) => b.append_null(),
197            Decimal128(b) => b.append_null(),
198            Decimal256(b) => b.append_null(),
199            TimestampSecond(b) => b.append_null(),
200            TimestampSecondNtz(b) => b.append_null(),
201            TimestampMilli(b) => b.append_null(),
202            TimestampMilliNtz(b) => b.append_null(),
203            TimestampMicro(b) => b.append_null(),
204            TimestampMicroNtz(b) => b.append_null(),
205            TimestampNano(b) => b.append_null(),
206            TimestampNanoNtz(b) => b.append_null(),
207            Time32Second(b) => b.append_null(),
208            Time32Milli(b) => b.append_null(),
209            Time64Micro(b) => b.append_null(),
210            Time64Nano(b) => b.append_null(),
211            Date32(b) => b.append_null(),
212            Date64(b) => b.append_null(),
213            Uuid(b) => b.append_null(),
214            String(b) => b.append_null(),
215            LargeString(b) => b.append_null(),
216            StringView(b) => b.append_null(),
217            Binary(b) => b.append_null(),
218            LargeBinary(b) => b.append_null(),
219            BinaryView(b) => b.append_null(),
220        }
221    }
222
223    pub fn append_value(&mut self, value: &Variant<'_, '_>) -> Result<bool> {
224        use PrimitiveVariantToArrowRowBuilder::*;
225        match self {
226            Null(b) => b.append_value(value),
227            Boolean(b) => b.append_value(value),
228            Int8(b) => b.append_value(value),
229            Int16(b) => b.append_value(value),
230            Int32(b) => b.append_value(value),
231            Int64(b) => b.append_value(value),
232            UInt8(b) => b.append_value(value),
233            UInt16(b) => b.append_value(value),
234            UInt32(b) => b.append_value(value),
235            UInt64(b) => b.append_value(value),
236            Float16(b) => b.append_value(value),
237            Float32(b) => b.append_value(value),
238            Float64(b) => b.append_value(value),
239            Decimal32(b) => b.append_value(value),
240            Decimal64(b) => b.append_value(value),
241            Decimal128(b) => b.append_value(value),
242            Decimal256(b) => b.append_value(value),
243            TimestampSecond(b) => b.append_value(value),
244            TimestampSecondNtz(b) => b.append_value(value),
245            TimestampMilli(b) => b.append_value(value),
246            TimestampMilliNtz(b) => b.append_value(value),
247            TimestampMicro(b) => b.append_value(value),
248            TimestampMicroNtz(b) => b.append_value(value),
249            TimestampNano(b) => b.append_value(value),
250            TimestampNanoNtz(b) => b.append_value(value),
251            Time32Second(b) => b.append_value(value),
252            Time32Milli(b) => b.append_value(value),
253            Time64Micro(b) => b.append_value(value),
254            Time64Nano(b) => b.append_value(value),
255            Date32(b) => b.append_value(value),
256            Date64(b) => b.append_value(value),
257            Uuid(b) => b.append_value(value),
258            String(b) => b.append_value(value),
259            LargeString(b) => b.append_value(value),
260            StringView(b) => b.append_value(value),
261            Binary(b) => b.append_value(value),
262            LargeBinary(b) => b.append_value(value),
263            BinaryView(b) => b.append_value(value),
264        }
265    }
266
267    pub fn finish(self) -> Result<ArrayRef> {
268        use PrimitiveVariantToArrowRowBuilder::*;
269        match self {
270            Null(b) => b.finish(),
271            Boolean(b) => b.finish(),
272            Int8(b) => b.finish(),
273            Int16(b) => b.finish(),
274            Int32(b) => b.finish(),
275            Int64(b) => b.finish(),
276            UInt8(b) => b.finish(),
277            UInt16(b) => b.finish(),
278            UInt32(b) => b.finish(),
279            UInt64(b) => b.finish(),
280            Float16(b) => b.finish(),
281            Float32(b) => b.finish(),
282            Float64(b) => b.finish(),
283            Decimal32(b) => b.finish(),
284            Decimal64(b) => b.finish(),
285            Decimal128(b) => b.finish(),
286            Decimal256(b) => b.finish(),
287            TimestampSecond(b) => b.finish(),
288            TimestampSecondNtz(b) => b.finish(),
289            TimestampMilli(b) => b.finish(),
290            TimestampMilliNtz(b) => b.finish(),
291            TimestampMicro(b) => b.finish(),
292            TimestampMicroNtz(b) => b.finish(),
293            TimestampNano(b) => b.finish(),
294            TimestampNanoNtz(b) => b.finish(),
295            Time32Second(b) => b.finish(),
296            Time32Milli(b) => b.finish(),
297            Time64Micro(b) => b.finish(),
298            Time64Nano(b) => b.finish(),
299            Date32(b) => b.finish(),
300            Date64(b) => b.finish(),
301            Uuid(b) => b.finish(),
302            String(b) => b.finish(),
303            LargeString(b) => b.finish(),
304            StringView(b) => b.finish(),
305            Binary(b) => b.finish(),
306            LargeBinary(b) => b.finish(),
307            BinaryView(b) => b.finish(),
308        }
309    }
310}
311
312/// Creates a row builder that converts primitive `Variant` values into the requested Arrow data type.
313pub(crate) fn make_primitive_variant_to_arrow_row_builder<'a>(
314    data_type: &'a DataType,
315    cast_options: &'a CastOptions,
316    capacity: usize,
317) -> Result<PrimitiveVariantToArrowRowBuilder<'a>> {
318    use PrimitiveVariantToArrowRowBuilder::*;
319
320    let builder =
321        match data_type {
322            DataType::Null => Null(VariantToNullArrowRowBuilder::new(cast_options, capacity)),
323            DataType::Boolean => {
324                Boolean(VariantToBooleanArrowRowBuilder::new(cast_options, capacity))
325            }
326            DataType::Int8 => Int8(VariantToPrimitiveArrowRowBuilder::new(
327                cast_options,
328                capacity,
329            )),
330            DataType::Int16 => Int16(VariantToPrimitiveArrowRowBuilder::new(
331                cast_options,
332                capacity,
333            )),
334            DataType::Int32 => Int32(VariantToPrimitiveArrowRowBuilder::new(
335                cast_options,
336                capacity,
337            )),
338            DataType::Int64 => Int64(VariantToPrimitiveArrowRowBuilder::new(
339                cast_options,
340                capacity,
341            )),
342            DataType::UInt8 => UInt8(VariantToPrimitiveArrowRowBuilder::new(
343                cast_options,
344                capacity,
345            )),
346            DataType::UInt16 => UInt16(VariantToPrimitiveArrowRowBuilder::new(
347                cast_options,
348                capacity,
349            )),
350            DataType::UInt32 => UInt32(VariantToPrimitiveArrowRowBuilder::new(
351                cast_options,
352                capacity,
353            )),
354            DataType::UInt64 => UInt64(VariantToPrimitiveArrowRowBuilder::new(
355                cast_options,
356                capacity,
357            )),
358            DataType::Float16 => Float16(VariantToPrimitiveArrowRowBuilder::new(
359                cast_options,
360                capacity,
361            )),
362            DataType::Float32 => Float32(VariantToPrimitiveArrowRowBuilder::new(
363                cast_options,
364                capacity,
365            )),
366            DataType::Float64 => Float64(VariantToPrimitiveArrowRowBuilder::new(
367                cast_options,
368                capacity,
369            )),
370            DataType::Decimal32(precision, scale) => Decimal32(
371                VariantToDecimalArrowRowBuilder::new(cast_options, capacity, *precision, *scale)?,
372            ),
373            DataType::Decimal64(precision, scale) => Decimal64(
374                VariantToDecimalArrowRowBuilder::new(cast_options, capacity, *precision, *scale)?,
375            ),
376            DataType::Decimal128(precision, scale) => Decimal128(
377                VariantToDecimalArrowRowBuilder::new(cast_options, capacity, *precision, *scale)?,
378            ),
379            DataType::Decimal256(precision, scale) => Decimal256(
380                VariantToDecimalArrowRowBuilder::new(cast_options, capacity, *precision, *scale)?,
381            ),
382            DataType::Date32 => Date32(VariantToPrimitiveArrowRowBuilder::new(
383                cast_options,
384                capacity,
385            )),
386            DataType::Date64 => Date64(VariantToPrimitiveArrowRowBuilder::new(
387                cast_options,
388                capacity,
389            )),
390            DataType::Time32(TimeUnit::Second) => Time32Second(
391                VariantToPrimitiveArrowRowBuilder::new(cast_options, capacity),
392            ),
393            DataType::Time32(TimeUnit::Millisecond) => Time32Milli(
394                VariantToPrimitiveArrowRowBuilder::new(cast_options, capacity),
395            ),
396            DataType::Time32(t) => {
397                return Err(ArrowError::InvalidArgumentError(format!(
398                    "The unit for Time32 must be second/millisecond, received {t:?}"
399                )));
400            }
401            DataType::Time64(TimeUnit::Microsecond) => Time64Micro(
402                VariantToPrimitiveArrowRowBuilder::new(cast_options, capacity),
403            ),
404            DataType::Time64(TimeUnit::Nanosecond) => Time64Nano(
405                VariantToPrimitiveArrowRowBuilder::new(cast_options, capacity),
406            ),
407            DataType::Time64(t) => {
408                return Err(ArrowError::InvalidArgumentError(format!(
409                    "The unit for Time64 must be micro/nano seconds, received {t:?}"
410                )));
411            }
412            DataType::Timestamp(TimeUnit::Second, None) => TimestampSecondNtz(
413                VariantToTimestampNtzArrowRowBuilder::new(cast_options, capacity),
414            ),
415            DataType::Timestamp(TimeUnit::Second, tz) => TimestampSecond(
416                VariantToTimestampArrowRowBuilder::new(cast_options, capacity, tz.clone()),
417            ),
418            DataType::Timestamp(TimeUnit::Millisecond, None) => TimestampMilliNtz(
419                VariantToTimestampNtzArrowRowBuilder::new(cast_options, capacity),
420            ),
421            DataType::Timestamp(TimeUnit::Millisecond, tz) => TimestampMilli(
422                VariantToTimestampArrowRowBuilder::new(cast_options, capacity, tz.clone()),
423            ),
424            DataType::Timestamp(TimeUnit::Microsecond, None) => TimestampMicroNtz(
425                VariantToTimestampNtzArrowRowBuilder::new(cast_options, capacity),
426            ),
427            DataType::Timestamp(TimeUnit::Microsecond, tz) => TimestampMicro(
428                VariantToTimestampArrowRowBuilder::new(cast_options, capacity, tz.clone()),
429            ),
430            DataType::Timestamp(TimeUnit::Nanosecond, None) => TimestampNanoNtz(
431                VariantToTimestampNtzArrowRowBuilder::new(cast_options, capacity),
432            ),
433            DataType::Timestamp(TimeUnit::Nanosecond, tz) => TimestampNano(
434                VariantToTimestampArrowRowBuilder::new(cast_options, capacity, tz.clone()),
435            ),
436            DataType::Duration(_) | DataType::Interval(_) => {
437                return Err(ArrowError::InvalidArgumentError(
438                    "Casting Variant to duration/interval types is not supported. \
439                    The Variant format does not define duration/interval types."
440                        .to_string(),
441                ));
442            }
443            DataType::Binary => Binary(VariantToBinaryArrowRowBuilder::new(cast_options, capacity)),
444            DataType::LargeBinary => {
445                LargeBinary(VariantToBinaryArrowRowBuilder::new(cast_options, capacity))
446            }
447            DataType::BinaryView => {
448                BinaryView(VariantToBinaryArrowRowBuilder::new(cast_options, capacity))
449            }
450            DataType::FixedSizeBinary(16) => {
451                Uuid(VariantToUuidArrowRowBuilder::new(cast_options, capacity))
452            }
453            DataType::FixedSizeBinary(_) => {
454                return Err(ArrowError::NotYetImplemented(format!(
455                    "DataType {data_type:?} not yet implemented"
456                )));
457            }
458            DataType::Utf8 => String(VariantToStringArrowBuilder::new(cast_options, capacity)),
459            DataType::LargeUtf8 => {
460                LargeString(VariantToStringArrowBuilder::new(cast_options, capacity))
461            }
462            DataType::Utf8View => {
463                StringView(VariantToStringArrowBuilder::new(cast_options, capacity))
464            }
465            DataType::List(_)
466            | DataType::LargeList(_)
467            | DataType::ListView(_)
468            | DataType::LargeListView(_)
469            | DataType::FixedSizeList(..)
470            | DataType::Struct(_)
471            | DataType::Map(..)
472            | DataType::Union(..)
473            | DataType::Dictionary(..)
474            | DataType::RunEndEncoded(..) => {
475                return Err(ArrowError::InvalidArgumentError(format!(
476                    "Casting to {data_type:?} is not applicable for primitive Variant types"
477                )));
478            }
479        };
480    Ok(builder)
481}
482
483pub(crate) enum ArrayVariantToArrowRowBuilder<'a> {
484    List(VariantToListArrowRowBuilder<'a, i32, false>),
485    LargeList(VariantToListArrowRowBuilder<'a, i64, false>),
486    ListView(VariantToListArrowRowBuilder<'a, i32, true>),
487    LargeListView(VariantToListArrowRowBuilder<'a, i64, true>),
488}
489
490impl<'a> ArrayVariantToArrowRowBuilder<'a> {
491    pub(crate) fn try_new(
492        data_type: &'a DataType,
493        cast_options: &'a CastOptions,
494        capacity: usize,
495    ) -> Result<Self> {
496        use ArrayVariantToArrowRowBuilder::*;
497
498        // Make List/ListView builders without repeating the constructor boilerplate.
499        macro_rules! make_list_builder {
500            ($variant:ident, $offset:ty, $is_view:expr, $field:ident) => {
501                $variant(VariantToListArrowRowBuilder::<$offset, $is_view>::try_new(
502                    $field.clone(),
503                    $field.data_type(),
504                    cast_options,
505                    capacity,
506                )?)
507            };
508        }
509
510        let builder = match data_type {
511            DataType::List(field) => make_list_builder!(List, i32, false, field),
512            DataType::LargeList(field) => make_list_builder!(LargeList, i64, false, field),
513            DataType::ListView(field) => make_list_builder!(ListView, i32, true, field),
514            DataType::LargeListView(field) => make_list_builder!(LargeListView, i64, true, field),
515            DataType::FixedSizeList(..) => {
516                return Err(ArrowError::NotYetImplemented(
517                    "Converting unshredded variant arrays to arrow fixed-size lists".to_string(),
518                ));
519            }
520            other => {
521                return Err(ArrowError::InvalidArgumentError(format!(
522                    "Casting to {other:?} is not applicable for array Variant types"
523                )));
524            }
525        };
526        Ok(builder)
527    }
528
529    pub(crate) fn append_null(&mut self) {
530        match self {
531            Self::List(builder) => builder.append_null(),
532            Self::LargeList(builder) => builder.append_null(),
533            Self::ListView(builder) => builder.append_null(),
534            Self::LargeListView(builder) => builder.append_null(),
535        }
536    }
537
538    pub(crate) fn append_value(&mut self, list: VariantList<'_, '_>) -> Result<()> {
539        match self {
540            Self::List(builder) => builder.append_value(list),
541            Self::LargeList(builder) => builder.append_value(list),
542            Self::ListView(builder) => builder.append_value(list),
543            Self::LargeListView(builder) => builder.append_value(list),
544        }
545    }
546
547    pub(crate) fn finish(self) -> Result<ArrayRef> {
548        match self {
549            Self::List(builder) => builder.finish(),
550            Self::LargeList(builder) => builder.finish(),
551            Self::ListView(builder) => builder.finish(),
552            Self::LargeListView(builder) => builder.finish(),
553        }
554    }
555}
556
557/// A thin wrapper whose only job is to extract a specific path from a variant value and pass the
558/// result to a nested builder.
559pub(crate) struct VariantPathRowBuilder<'a> {
560    builder: Box<VariantToArrowRowBuilder<'a>>,
561    path: VariantPath<'a>,
562}
563
564impl<'a> VariantPathRowBuilder<'a> {
565    fn append_null(&mut self) -> Result<()> {
566        self.builder.append_null()
567    }
568
569    fn append_value(&mut self, value: Variant<'_, '_>) -> Result<bool> {
570        if let Some(v) = value.get_path(&self.path) {
571            self.builder.append_value(v)
572        } else {
573            self.builder.append_null()?;
574            Ok(false)
575        }
576    }
577
578    fn finish(self) -> Result<ArrayRef> {
579        self.builder.finish()
580    }
581}
582
/// Generates a row-builder struct, plus its `new` / `append_null` / `append_value` / `finish`
/// methods, for the common "convert one variant value, append it to one Arrow builder" pattern.
///
/// Invocation shape:
///
/// ```text
/// struct Name<'a [, T: Bound]>
/// |capacity_param [, extra: ExtraType]| -> BuilderType[<Arg>] { init_expr },
/// |value| transform_expr,
/// type_name: expr
/// ```
///
/// * `init_expr` builds the inner Arrow builder. It may refer to `capacity_param` and `extra`,
///   which become parameters of the generated `new`.
/// * `transform_expr` is evaluated with `value: &Variant` in scope and must yield an `Option`.
///   `Some(v)` is appended to the inner builder; `None` means the conversion failed.
/// * `type_name` is only used to format the error message on failure.
///
/// On a failed conversion, `append_value` returns an `ArrowError::CastError` when
/// `cast_options.safe` is false. Otherwise it appends a null and returns `Ok(false)`.
/// The path in that error message is the literal `VariantPath([])`: the generated struct has no
/// access to the path being extracted.
macro_rules! define_variant_to_primitive_builder {
    (struct $name:ident<$lifetime:lifetime $(, $generic:ident: $bound:path )?>
    |$array_param:ident $(, $field:ident: $field_type:ty)?| -> $builder_name:ident $(< $array_type:ty >)? { $init_expr: expr },
    |$value: ident| $value_transform:expr,
    type_name: $type_name:expr) => {
        pub(crate) struct $name<$lifetime $(, $generic : $bound )?>
        {
            // The Arrow builder that accumulates converted values.
            builder: $builder_name $(<$array_type>)?,
            // `safe` decides between null-on-failure and error-on-failure.
            cast_options: &$lifetime CastOptions<$lifetime>,
        }

        impl<$lifetime $(, $generic: $bound+ )?> $name<$lifetime $(, $generic )?> {
            fn new(
                cast_options: &$lifetime CastOptions<$lifetime>,
                $array_param: usize,
                // add this so that $init_expr can use it
                $( $field: $field_type, )?
            ) -> Self {
                Self {
                    builder: $init_expr,
                    cast_options,
                }
            }

            fn append_null(&mut self) -> Result<()> {
                self.builder.append_null();
                Ok(())
            }

            // Returns Ok(true) if a converted value was appended, Ok(false) if a null was
            // appended instead (safe mode), or a CastError (unsafe mode).
            fn append_value(&mut self, $value: &Variant<'_, '_>) -> Result<bool> {
                if let Some(v) = $value_transform {
                    self.builder.append_value(v);
                    Ok(true)
                } else {
                    if !self.cast_options.safe {
                        // Unsafe casting: return error on conversion failure
                        return Err(ArrowError::CastError(format!(
                            "Failed to extract primitive of type {} from variant {:?} at path VariantPath([])",
                            $type_name,
                            $value
                        )));
                    }
                    // Safe casting: append null on conversion failure
                    self.builder.append_null();
                    Ok(false)
                }
            }

            // Add this to silence unused mut warning from macro-generated code
            // This is mainly for `FakeNullBuilder`
            #[allow(unused_mut)]
            fn finish(mut self) -> Result<ArrayRef> {
                Ok(Arc::new(self.builder.finish()))
            }
        }
    }
}
640
// String targets (Utf8 / LargeUtf8 / Utf8View, selected by `B`): appends `value.as_string()`.
define_variant_to_primitive_builder!(
    struct VariantToStringArrowBuilder<'a, B: StringLikeArrayBuilder>
    |capacity| -> B { B::with_capacity(capacity) },
    |value| value.as_string(),
    type_name: B::type_name()
);

// Boolean target: appends `value.as_boolean()`.
define_variant_to_primitive_builder!(
    struct VariantToBooleanArrowRowBuilder<'a>
    |capacity| -> BooleanBuilder { BooleanBuilder::with_capacity(capacity) },
    |value|  value.as_boolean(),
    type_name: datatypes::BooleanType::DATA_TYPE
);

// Numeric, date and time-of-day targets: conversion is delegated to `T::from_variant`.
define_variant_to_primitive_builder!(
    struct VariantToPrimitiveArrowRowBuilder<'a, T:PrimitiveFromVariant>
    |capacity| -> PrimitiveBuilder<T> { PrimitiveBuilder::<T>::with_capacity(capacity) },
    |value| T::from_variant(value),
    type_name: T::DATA_TYPE
);

// Timestamps without a timezone (`Timestamp(_, None)`), via `TimestampFromVariant<true>`.
define_variant_to_primitive_builder!(
    struct VariantToTimestampNtzArrowRowBuilder<'a, T:TimestampFromVariant<true>>
    |capacity| -> PrimitiveBuilder<T> { PrimitiveBuilder::<T>::with_capacity(capacity) },
    |value| T::from_variant(value),
    type_name: T::DATA_TYPE
);

// Timestamps with a timezone, via `TimestampFromVariant<false>`. The generated `new` takes an
// extra `tz` argument, which is attached to the output array's data type.
define_variant_to_primitive_builder!(
    struct VariantToTimestampArrowRowBuilder<'a, T:TimestampFromVariant<false>>
    |capacity, tz: Option<Arc<str>> | -> PrimitiveBuilder<T> {
        PrimitiveBuilder::<T>::with_capacity(capacity).with_timezone_opt(tz)
    },
    |value| T::from_variant(value),
    type_name: T::DATA_TYPE
);

// Binary targets (Binary / LargeBinary / BinaryView, selected by `B`): appends the bytes from
// `value.as_u8_slice()`.
define_variant_to_primitive_builder!(
    struct VariantToBinaryArrowRowBuilder<'a, B: BinaryLikeArrayBuilder>
    |capacity| -> B { B::with_capacity(capacity) },
    |value| value.as_u8_slice(),
    type_name: B::type_name()
);
684
685/// Builder for converting variant values to arrow Decimal values
686pub(crate) struct VariantToDecimalArrowRowBuilder<'a, T>
687where
688    T: DecimalType,
689    T::Native: DecimalCast,
690{
691    builder: PrimitiveBuilder<T>,
692    cast_options: &'a CastOptions<'a>,
693    precision: u8,
694    scale: i8,
695}
696
697impl<'a, T> VariantToDecimalArrowRowBuilder<'a, T>
698where
699    T: DecimalType,
700    T::Native: DecimalCast,
701{
702    fn new(
703        cast_options: &'a CastOptions<'a>,
704        capacity: usize,
705        precision: u8,
706        scale: i8,
707    ) -> Result<Self> {
708        let builder = PrimitiveBuilder::<T>::with_capacity(capacity)
709            .with_precision_and_scale(precision, scale)?;
710        Ok(Self {
711            builder,
712            cast_options,
713            precision,
714            scale,
715        })
716    }
717
718    fn append_null(&mut self) -> Result<()> {
719        self.builder.append_null();
720        Ok(())
721    }
722
723    fn append_value(&mut self, value: &Variant<'_, '_>) -> Result<bool> {
724        if let Some(scaled) = variant_to_unscaled_decimal::<T>(value, self.precision, self.scale) {
725            self.builder.append_value(scaled);
726            Ok(true)
727        } else if self.cast_options.safe {
728            self.builder.append_null();
729            Ok(false)
730        } else {
731            Err(ArrowError::CastError(format!(
732                "Failed to cast to {}(precision={}, scale={}) from variant {:?}",
733                T::PREFIX,
734                self.precision,
735                self.scale,
736                value
737            )))
738        }
739    }
740
741    fn finish(mut self) -> Result<ArrayRef> {
742        Ok(Arc::new(self.builder.finish()))
743    }
744}
745
746/// Builder for converting variant values to FixedSizeBinary(16) for UUIDs
747pub(crate) struct VariantToUuidArrowRowBuilder<'a> {
748    builder: FixedSizeBinaryBuilder,
749    cast_options: &'a CastOptions<'a>,
750}
751
752impl<'a> VariantToUuidArrowRowBuilder<'a> {
753    fn new(cast_options: &'a CastOptions<'a>, capacity: usize) -> Self {
754        Self {
755            builder: FixedSizeBinaryBuilder::with_capacity(capacity, 16),
756            cast_options,
757        }
758    }
759
760    fn append_null(&mut self) -> Result<()> {
761        self.builder.append_null();
762        Ok(())
763    }
764
765    fn append_value(&mut self, value: &Variant<'_, '_>) -> Result<bool> {
766        match value.as_uuid() {
767            Some(uuid) => {
768                self.builder
769                    .append_value(uuid.as_bytes())
770                    .map_err(|e| ArrowError::ExternalError(Box::new(e)))?;
771
772                Ok(true)
773            }
774            None if self.cast_options.safe => {
775                self.builder.append_null();
776                Ok(false)
777            }
778            None => Err(ArrowError::CastError(format!(
779                "Failed to extract UUID from variant {value:?}",
780            ))),
781        }
782    }
783
784    fn finish(mut self) -> Result<ArrayRef> {
785        Ok(Arc::new(self.builder.finish()))
786    }
787}
788
789pub(crate) struct VariantToListArrowRowBuilder<'a, O, const IS_VIEW: bool>
790where
791    O: OffsetSizeTrait + ArrowNativeTypeOp,
792{
793    field: FieldRef,
794    offsets: Vec<O>,
795    element_builder: Box<VariantToShreddedVariantRowBuilder<'a>>,
796    nulls: NullBufferBuilder,
797    current_offset: O,
798}
799
800impl<'a, O, const IS_VIEW: bool> VariantToListArrowRowBuilder<'a, O, IS_VIEW>
801where
802    O: OffsetSizeTrait + ArrowNativeTypeOp,
803{
804    fn try_new(
805        field: FieldRef,
806        element_data_type: &'a DataType,
807        cast_options: &'a CastOptions,
808        capacity: usize,
809    ) -> Result<Self> {
810        if capacity >= isize::MAX as usize {
811            return Err(ArrowError::ComputeError(
812                "Capacity exceeds isize::MAX when reserving list offsets".to_string(),
813            ));
814        }
815        let mut offsets = Vec::with_capacity(capacity + 1);
816        offsets.push(O::ZERO);
817        let element_builder = make_variant_to_shredded_variant_arrow_row_builder(
818            element_data_type,
819            cast_options,
820            capacity,
821            false,
822        )?;
823        Ok(Self {
824            field,
825            offsets,
826            element_builder: Box::new(element_builder),
827            nulls: NullBufferBuilder::new(capacity),
828            current_offset: O::ZERO,
829        })
830    }
831
832    fn append_null(&mut self) {
833        self.offsets.push(self.current_offset);
834        self.nulls.append_null();
835    }
836
837    fn append_value(&mut self, list: VariantList<'_, '_>) -> Result<()> {
838        for element in list.iter() {
839            self.element_builder.append_value(element)?;
840            self.current_offset = self.current_offset.add_checked(O::ONE)?;
841        }
842        self.offsets.push(self.current_offset);
843        self.nulls.append_non_null();
844        Ok(())
845    }
846
847    fn finish(mut self) -> Result<ArrayRef> {
848        let (value, typed_value, nulls) = self.element_builder.finish()?;
849        let element_array =
850            ShreddedVariantFieldArray::from_parts(Some(value), Some(typed_value), nulls);
851        let field = Arc::new(
852            self.field
853                .as_ref()
854                .clone()
855                .with_data_type(element_array.data_type().clone()),
856        );
857
858        if IS_VIEW {
859            // NOTE: `offsets` is never empty (constructor pushes an entry)
860            let mut sizes = Vec::with_capacity(self.offsets.len() - 1);
861            for i in 1..self.offsets.len() {
862                sizes.push(self.offsets[i] - self.offsets[i - 1]);
863            }
864            self.offsets.pop();
865            let list_view_array = GenericListViewArray::<O>::new(
866                field,
867                ScalarBuffer::from(self.offsets),
868                ScalarBuffer::from(sizes),
869                ArrayRef::from(element_array),
870                self.nulls.finish(),
871            );
872            Ok(Arc::new(list_view_array))
873        } else {
874            let list_array = GenericListArray::<O>::new(
875                field,
876                OffsetBuffer::<O>::new(ScalarBuffer::from(self.offsets)),
877                ArrayRef::from(element_array),
878                self.nulls.finish(),
879            );
880            Ok(Arc::new(list_array))
881        }
882    }
883}
884
885/// Builder for creating VariantArray output (for path extraction without type conversion)
886pub(crate) struct VariantToBinaryVariantArrowRowBuilder {
887    metadata: BinaryViewArray,
888    builder: VariantValueArrayBuilder,
889    nulls: NullBufferBuilder,
890}
891
892impl VariantToBinaryVariantArrowRowBuilder {
893    fn new(metadata: BinaryViewArray, capacity: usize) -> Self {
894        Self {
895            metadata,
896            builder: VariantValueArrayBuilder::new(capacity),
897            nulls: NullBufferBuilder::new(capacity),
898        }
899    }
900}
901
902impl VariantToBinaryVariantArrowRowBuilder {
903    fn append_null(&mut self) -> Result<()> {
904        self.builder.append_null();
905        self.nulls.append_null();
906        Ok(())
907    }
908
909    fn append_value(&mut self, value: Variant<'_, '_>) -> Result<bool> {
910        self.builder.append_value(value);
911        self.nulls.append_non_null();
912        Ok(true)
913    }
914
915    fn finish(mut self) -> Result<ArrayRef> {
916        let variant_array = VariantArray::from_parts(
917            self.metadata,
918            Some(self.builder.build()?),
919            None, // no typed_value column
920            self.nulls.finish(),
921        );
922
923        Ok(ArrayRef::from(variant_array))
924    }
925}
926
927#[derive(Default)]
928struct FakeNullBuilder {
929    item_count: usize,
930}
931
932impl FakeNullBuilder {
933    fn append_value(&mut self, _: ()) {
934        self.item_count += 1;
935    }
936
937    fn append_null(&mut self) {
938        self.item_count += 1;
939    }
940
941    fn finish(self) -> NullArray {
942        NullArray::new(self.item_count)
943    }
944}
945
// Row builder for `Null` outputs. `FakeNullBuilder` only counts rows, so the resulting
// `NullArray` has one entry per appended row. Each variant is converted with `value.as_null()`.
// This presumably succeeds only for variant nulls, so other values take the macro's
// conversion-failure path — confirm against `Variant::as_null` and the macro definition.
define_variant_to_primitive_builder!(
    struct VariantToNullArrowRowBuilder<'a>
    |_capacity| -> FakeNullBuilder { FakeNullBuilder::default() },
    |value| value.as_null(),
    type_name: "Null"
);
952
#[cfg(test)]
mod tests {
    use super::make_primitive_variant_to_arrow_row_builder;
    use arrow::compute::CastOptions;
    use arrow::datatypes::{DataType, Field, Fields, UnionFields, UnionMode};
    use arrow::error::ArrowError;
    use std::sync::Arc;

    /// The primitive row-builder factory must refuse every nested, dictionary, and
    /// run-end-encoded type. Each refusal must be an `InvalidArgumentError` whose message
    /// names the offending type.
    #[test]
    fn make_primitive_builder_rejects_non_primitive_types() {
        let cast_options = CastOptions::default();

        let item = Arc::new(Field::new("item", DataType::Int32, true));
        let struct_fields = Fields::from(vec![Field::new("child", DataType::Int32, true)]);
        let map_entry_fields = Fields::from(vec![
            Field::new("key", DataType::Utf8, false),
            Field::new("value", DataType::Float64, true),
        ]);
        let map_entries = Arc::new(Field::new("entries", DataType::Struct(map_entry_fields), true));
        let union_fields =
            UnionFields::try_new(vec![1], vec![Field::new("child", DataType::Int32, true)])
                .unwrap();
        let run_ends = Arc::new(Field::new("run_ends", DataType::Int32, false));
        let ree_values = Arc::new(Field::new("values", DataType::Utf8, true));

        let rejected = [
            DataType::List(Arc::clone(&item)),
            DataType::LargeList(Arc::clone(&item)),
            DataType::ListView(Arc::clone(&item)),
            DataType::LargeListView(Arc::clone(&item)),
            DataType::FixedSizeList(item, 2),
            DataType::Struct(struct_fields),
            DataType::Map(map_entries, false),
            DataType::Union(union_fields, UnionMode::Dense),
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            DataType::RunEndEncoded(run_ends, ree_values),
        ];

        for data_type in rejected {
            let outcome = make_primitive_variant_to_arrow_row_builder(&data_type, &cast_options, 1);
            let Err(err) = outcome else {
                panic!("non-primitive type {data_type:?} should be rejected");
            };

            let expected = format!("{data_type:?}");
            match err {
                ArrowError::InvalidArgumentError(msg) => assert!(msg.contains(&expected)),
                other => panic!("expected InvalidArgumentError, got {other:?}"),
            }
        }
    }
}