Skip to main content

parquet_variant_compute/
variant_to_arrow.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::shred_variant::{
19    VariantToShreddedVariantRowBuilder, make_variant_to_shredded_variant_arrow_row_builder,
20};
21use crate::type_conversion::{
22    PrimitiveFromVariant, TimestampFromVariant, variant_to_unscaled_decimal,
23};
24use crate::variant_array::ShreddedVariantFieldArray;
25use crate::{VariantArray, VariantValueArrayBuilder};
26use arrow::array::{
27    ArrayRef, ArrowNativeTypeOp, BinaryBuilder, BinaryLikeArrayBuilder, BinaryViewArray,
28    BinaryViewBuilder, BooleanBuilder, FixedSizeBinaryBuilder, GenericListArray,
29    GenericListViewArray, LargeBinaryBuilder, LargeStringBuilder, NullArray, NullBufferBuilder,
30    OffsetSizeTrait, PrimitiveBuilder, StringBuilder, StringLikeArrayBuilder, StringViewBuilder,
31    StructArray,
32};
33use arrow::buffer::{OffsetBuffer, ScalarBuffer};
34use arrow::compute::{CastOptions, DecimalCast};
35use arrow::datatypes::{self, DataType, DecimalType};
36use arrow::error::{ArrowError, Result};
37use arrow_schema::{FieldRef, Fields, TimeUnit};
38use parquet_variant::{Variant, VariantPath};
39use std::sync::Arc;
40
41/// Builder for converting variant values into strongly typed Arrow arrays.
42///
43/// Useful for variant_get kernels that need to extract specific paths from variant values, possibly
44/// with casting of leaf values to specific types.
45pub(crate) enum VariantToArrowRowBuilder<'a> {
46    Primitive(PrimitiveVariantToArrowRowBuilder<'a>),
47    Array(ArrayVariantToArrowRowBuilder<'a>),
48    Struct(StructVariantToArrowRowBuilder<'a>),
49    BinaryVariant(VariantToBinaryVariantArrowRowBuilder),
50
51    // Path extraction wrapper - contains a boxed enum for any of the above
52    WithPath(VariantPathRowBuilder<'a>),
53}
54
55impl<'a> VariantToArrowRowBuilder<'a> {
56    pub fn append_null(&mut self) -> Result<()> {
57        use VariantToArrowRowBuilder::*;
58        match self {
59            Primitive(b) => b.append_null(),
60            Array(b) => b.append_null(),
61            Struct(b) => b.append_null(),
62            BinaryVariant(b) => b.append_null(),
63            WithPath(path_builder) => path_builder.append_null(),
64        }
65    }
66
67    pub fn append_value(&mut self, value: Variant<'_, '_>) -> Result<bool> {
68        use VariantToArrowRowBuilder::*;
69        match self {
70            Primitive(b) => b.append_value(&value),
71            Array(b) => b.append_value(&value),
72            Struct(b) => b.append_value(&value),
73            BinaryVariant(b) => b.append_value(value),
74            WithPath(path_builder) => path_builder.append_value(value),
75        }
76    }
77
78    pub fn finish(self) -> Result<ArrayRef> {
79        use VariantToArrowRowBuilder::*;
80        match self {
81            Primitive(b) => b.finish(),
82            Array(b) => b.finish(),
83            Struct(b) => b.finish(),
84            BinaryVariant(b) => b.finish(),
85            WithPath(path_builder) => path_builder.finish(),
86        }
87    }
88}
89
90fn make_typed_variant_to_arrow_row_builder<'a>(
91    data_type: &'a DataType,
92    cast_options: &'a CastOptions,
93    capacity: usize,
94) -> Result<VariantToArrowRowBuilder<'a>> {
95    use VariantToArrowRowBuilder::*;
96
97    match data_type {
98        DataType::Struct(fields) => {
99            let builder = StructVariantToArrowRowBuilder::try_new(fields, cast_options, capacity)?;
100            Ok(Struct(builder))
101        }
102        data_type @ (DataType::List(_)
103        | DataType::LargeList(_)
104        | DataType::ListView(_)
105        | DataType::LargeListView(_)
106        | DataType::FixedSizeList(..)) => {
107            let builder =
108                ArrayVariantToArrowRowBuilder::try_new(data_type, cast_options, capacity)?;
109            Ok(Array(builder))
110        }
111        data_type => {
112            let builder =
113                make_primitive_variant_to_arrow_row_builder(data_type, cast_options, capacity)?;
114            Ok(Primitive(builder))
115        }
116    }
117}
118
119pub(crate) fn make_variant_to_arrow_row_builder<'a>(
120    metadata: &BinaryViewArray,
121    path: VariantPath<'a>,
122    data_type: Option<&'a DataType>,
123    cast_options: &'a CastOptions,
124    capacity: usize,
125) -> Result<VariantToArrowRowBuilder<'a>> {
126    use VariantToArrowRowBuilder::*;
127
128    let mut builder = match data_type {
129        // If no data type was requested, build an unshredded VariantArray.
130        None => BinaryVariant(VariantToBinaryVariantArrowRowBuilder::new(
131            metadata.clone(),
132            capacity,
133        )),
134        Some(data_type) => {
135            make_typed_variant_to_arrow_row_builder(data_type, cast_options, capacity)?
136        }
137    };
138
139    // Wrap with path extraction if needed
140    if !path.is_empty() {
141        builder = WithPath(VariantPathRowBuilder {
142            builder: Box::new(builder),
143            path,
144        })
145    };
146
147    Ok(builder)
148}
149
150/// Builder for converting primitive variant values to Arrow arrays. It is used by both
151/// `VariantToArrowRowBuilder` (below) and `VariantToShreddedPrimitiveVariantRowBuilder` (in
152/// `shred_variant.rs`).
153pub(crate) enum PrimitiveVariantToArrowRowBuilder<'a> {
154    Null(VariantToNullArrowRowBuilder<'a>),
155    Boolean(VariantToBooleanArrowRowBuilder<'a>),
156    Int8(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Int8Type>),
157    Int16(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Int16Type>),
158    Int32(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Int32Type>),
159    Int64(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Int64Type>),
160    UInt8(VariantToPrimitiveArrowRowBuilder<'a, datatypes::UInt8Type>),
161    UInt16(VariantToPrimitiveArrowRowBuilder<'a, datatypes::UInt16Type>),
162    UInt32(VariantToPrimitiveArrowRowBuilder<'a, datatypes::UInt32Type>),
163    UInt64(VariantToPrimitiveArrowRowBuilder<'a, datatypes::UInt64Type>),
164    Float16(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Float16Type>),
165    Float32(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Float32Type>),
166    Float64(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Float64Type>),
167    Decimal32(VariantToDecimalArrowRowBuilder<'a, datatypes::Decimal32Type>),
168    Decimal64(VariantToDecimalArrowRowBuilder<'a, datatypes::Decimal64Type>),
169    Decimal128(VariantToDecimalArrowRowBuilder<'a, datatypes::Decimal128Type>),
170    Decimal256(VariantToDecimalArrowRowBuilder<'a, datatypes::Decimal256Type>),
171    TimestampSecond(VariantToTimestampArrowRowBuilder<'a, datatypes::TimestampSecondType>),
172    TimestampSecondNtz(VariantToTimestampNtzArrowRowBuilder<'a, datatypes::TimestampSecondType>),
173    TimestampMilli(VariantToTimestampArrowRowBuilder<'a, datatypes::TimestampMillisecondType>),
174    TimestampMilliNtz(
175        VariantToTimestampNtzArrowRowBuilder<'a, datatypes::TimestampMillisecondType>,
176    ),
177    TimestampMicro(VariantToTimestampArrowRowBuilder<'a, datatypes::TimestampMicrosecondType>),
178    TimestampMicroNtz(
179        VariantToTimestampNtzArrowRowBuilder<'a, datatypes::TimestampMicrosecondType>,
180    ),
181    TimestampNano(VariantToTimestampArrowRowBuilder<'a, datatypes::TimestampNanosecondType>),
182    TimestampNanoNtz(VariantToTimestampNtzArrowRowBuilder<'a, datatypes::TimestampNanosecondType>),
183    Time32Second(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Time32SecondType>),
184    Time32Milli(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Time32MillisecondType>),
185    Time64Micro(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Time64MicrosecondType>),
186    Time64Nano(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Time64NanosecondType>),
187    Date32(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Date32Type>),
188    Date64(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Date64Type>),
189    Uuid(VariantToUuidArrowRowBuilder<'a>),
190    String(VariantToStringArrowBuilder<'a, StringBuilder>),
191    LargeString(VariantToStringArrowBuilder<'a, LargeStringBuilder>),
192    StringView(VariantToStringArrowBuilder<'a, StringViewBuilder>),
193    Binary(VariantToBinaryArrowRowBuilder<'a, BinaryBuilder>),
194    LargeBinary(VariantToBinaryArrowRowBuilder<'a, LargeBinaryBuilder>),
195    BinaryView(VariantToBinaryArrowRowBuilder<'a, BinaryViewBuilder>),
196}
197
198impl<'a> PrimitiveVariantToArrowRowBuilder<'a> {
199    pub fn append_null(&mut self) -> Result<()> {
200        use PrimitiveVariantToArrowRowBuilder::*;
201        match self {
202            Null(b) => b.append_null(),
203            Boolean(b) => b.append_null(),
204            Int8(b) => b.append_null(),
205            Int16(b) => b.append_null(),
206            Int32(b) => b.append_null(),
207            Int64(b) => b.append_null(),
208            UInt8(b) => b.append_null(),
209            UInt16(b) => b.append_null(),
210            UInt32(b) => b.append_null(),
211            UInt64(b) => b.append_null(),
212            Float16(b) => b.append_null(),
213            Float32(b) => b.append_null(),
214            Float64(b) => b.append_null(),
215            Decimal32(b) => b.append_null(),
216            Decimal64(b) => b.append_null(),
217            Decimal128(b) => b.append_null(),
218            Decimal256(b) => b.append_null(),
219            TimestampSecond(b) => b.append_null(),
220            TimestampSecondNtz(b) => b.append_null(),
221            TimestampMilli(b) => b.append_null(),
222            TimestampMilliNtz(b) => b.append_null(),
223            TimestampMicro(b) => b.append_null(),
224            TimestampMicroNtz(b) => b.append_null(),
225            TimestampNano(b) => b.append_null(),
226            TimestampNanoNtz(b) => b.append_null(),
227            Time32Second(b) => b.append_null(),
228            Time32Milli(b) => b.append_null(),
229            Time64Micro(b) => b.append_null(),
230            Time64Nano(b) => b.append_null(),
231            Date32(b) => b.append_null(),
232            Date64(b) => b.append_null(),
233            Uuid(b) => b.append_null(),
234            String(b) => b.append_null(),
235            LargeString(b) => b.append_null(),
236            StringView(b) => b.append_null(),
237            Binary(b) => b.append_null(),
238            LargeBinary(b) => b.append_null(),
239            BinaryView(b) => b.append_null(),
240        }
241    }
242
243    pub fn append_value(&mut self, value: &Variant<'_, '_>) -> Result<bool> {
244        use PrimitiveVariantToArrowRowBuilder::*;
245        match self {
246            Null(b) => b.append_value(value),
247            Boolean(b) => b.append_value(value),
248            Int8(b) => b.append_value(value),
249            Int16(b) => b.append_value(value),
250            Int32(b) => b.append_value(value),
251            Int64(b) => b.append_value(value),
252            UInt8(b) => b.append_value(value),
253            UInt16(b) => b.append_value(value),
254            UInt32(b) => b.append_value(value),
255            UInt64(b) => b.append_value(value),
256            Float16(b) => b.append_value(value),
257            Float32(b) => b.append_value(value),
258            Float64(b) => b.append_value(value),
259            Decimal32(b) => b.append_value(value),
260            Decimal64(b) => b.append_value(value),
261            Decimal128(b) => b.append_value(value),
262            Decimal256(b) => b.append_value(value),
263            TimestampSecond(b) => b.append_value(value),
264            TimestampSecondNtz(b) => b.append_value(value),
265            TimestampMilli(b) => b.append_value(value),
266            TimestampMilliNtz(b) => b.append_value(value),
267            TimestampMicro(b) => b.append_value(value),
268            TimestampMicroNtz(b) => b.append_value(value),
269            TimestampNano(b) => b.append_value(value),
270            TimestampNanoNtz(b) => b.append_value(value),
271            Time32Second(b) => b.append_value(value),
272            Time32Milli(b) => b.append_value(value),
273            Time64Micro(b) => b.append_value(value),
274            Time64Nano(b) => b.append_value(value),
275            Date32(b) => b.append_value(value),
276            Date64(b) => b.append_value(value),
277            Uuid(b) => b.append_value(value),
278            String(b) => b.append_value(value),
279            LargeString(b) => b.append_value(value),
280            StringView(b) => b.append_value(value),
281            Binary(b) => b.append_value(value),
282            LargeBinary(b) => b.append_value(value),
283            BinaryView(b) => b.append_value(value),
284        }
285    }
286
287    pub fn finish(self) -> Result<ArrayRef> {
288        use PrimitiveVariantToArrowRowBuilder::*;
289        match self {
290            Null(b) => b.finish(),
291            Boolean(b) => b.finish(),
292            Int8(b) => b.finish(),
293            Int16(b) => b.finish(),
294            Int32(b) => b.finish(),
295            Int64(b) => b.finish(),
296            UInt8(b) => b.finish(),
297            UInt16(b) => b.finish(),
298            UInt32(b) => b.finish(),
299            UInt64(b) => b.finish(),
300            Float16(b) => b.finish(),
301            Float32(b) => b.finish(),
302            Float64(b) => b.finish(),
303            Decimal32(b) => b.finish(),
304            Decimal64(b) => b.finish(),
305            Decimal128(b) => b.finish(),
306            Decimal256(b) => b.finish(),
307            TimestampSecond(b) => b.finish(),
308            TimestampSecondNtz(b) => b.finish(),
309            TimestampMilli(b) => b.finish(),
310            TimestampMilliNtz(b) => b.finish(),
311            TimestampMicro(b) => b.finish(),
312            TimestampMicroNtz(b) => b.finish(),
313            TimestampNano(b) => b.finish(),
314            TimestampNanoNtz(b) => b.finish(),
315            Time32Second(b) => b.finish(),
316            Time32Milli(b) => b.finish(),
317            Time64Micro(b) => b.finish(),
318            Time64Nano(b) => b.finish(),
319            Date32(b) => b.finish(),
320            Date64(b) => b.finish(),
321            Uuid(b) => b.finish(),
322            String(b) => b.finish(),
323            LargeString(b) => b.finish(),
324            StringView(b) => b.finish(),
325            Binary(b) => b.finish(),
326            LargeBinary(b) => b.finish(),
327            BinaryView(b) => b.finish(),
328        }
329    }
330}
331
332/// Creates a row builder that converts primitive `Variant` values into the requested Arrow data type.
333pub(crate) fn make_primitive_variant_to_arrow_row_builder<'a>(
334    data_type: &'a DataType,
335    cast_options: &'a CastOptions,
336    capacity: usize,
337) -> Result<PrimitiveVariantToArrowRowBuilder<'a>> {
338    use PrimitiveVariantToArrowRowBuilder::*;
339
340    let builder =
341        match data_type {
342            DataType::Null => Null(VariantToNullArrowRowBuilder::new(cast_options, capacity)),
343            DataType::Boolean => {
344                Boolean(VariantToBooleanArrowRowBuilder::new(cast_options, capacity))
345            }
346            DataType::Int8 => Int8(VariantToPrimitiveArrowRowBuilder::new(
347                cast_options,
348                capacity,
349            )),
350            DataType::Int16 => Int16(VariantToPrimitiveArrowRowBuilder::new(
351                cast_options,
352                capacity,
353            )),
354            DataType::Int32 => Int32(VariantToPrimitiveArrowRowBuilder::new(
355                cast_options,
356                capacity,
357            )),
358            DataType::Int64 => Int64(VariantToPrimitiveArrowRowBuilder::new(
359                cast_options,
360                capacity,
361            )),
362            DataType::UInt8 => UInt8(VariantToPrimitiveArrowRowBuilder::new(
363                cast_options,
364                capacity,
365            )),
366            DataType::UInt16 => UInt16(VariantToPrimitiveArrowRowBuilder::new(
367                cast_options,
368                capacity,
369            )),
370            DataType::UInt32 => UInt32(VariantToPrimitiveArrowRowBuilder::new(
371                cast_options,
372                capacity,
373            )),
374            DataType::UInt64 => UInt64(VariantToPrimitiveArrowRowBuilder::new(
375                cast_options,
376                capacity,
377            )),
378            DataType::Float16 => Float16(VariantToPrimitiveArrowRowBuilder::new(
379                cast_options,
380                capacity,
381            )),
382            DataType::Float32 => Float32(VariantToPrimitiveArrowRowBuilder::new(
383                cast_options,
384                capacity,
385            )),
386            DataType::Float64 => Float64(VariantToPrimitiveArrowRowBuilder::new(
387                cast_options,
388                capacity,
389            )),
390            DataType::Decimal32(precision, scale) => Decimal32(
391                VariantToDecimalArrowRowBuilder::new(cast_options, capacity, *precision, *scale)?,
392            ),
393            DataType::Decimal64(precision, scale) => Decimal64(
394                VariantToDecimalArrowRowBuilder::new(cast_options, capacity, *precision, *scale)?,
395            ),
396            DataType::Decimal128(precision, scale) => Decimal128(
397                VariantToDecimalArrowRowBuilder::new(cast_options, capacity, *precision, *scale)?,
398            ),
399            DataType::Decimal256(precision, scale) => Decimal256(
400                VariantToDecimalArrowRowBuilder::new(cast_options, capacity, *precision, *scale)?,
401            ),
402            DataType::Date32 => Date32(VariantToPrimitiveArrowRowBuilder::new(
403                cast_options,
404                capacity,
405            )),
406            DataType::Date64 => Date64(VariantToPrimitiveArrowRowBuilder::new(
407                cast_options,
408                capacity,
409            )),
410            DataType::Time32(TimeUnit::Second) => Time32Second(
411                VariantToPrimitiveArrowRowBuilder::new(cast_options, capacity),
412            ),
413            DataType::Time32(TimeUnit::Millisecond) => Time32Milli(
414                VariantToPrimitiveArrowRowBuilder::new(cast_options, capacity),
415            ),
416            DataType::Time32(t) => {
417                return Err(ArrowError::InvalidArgumentError(format!(
418                    "The unit for Time32 must be second/millisecond, received {t:?}"
419                )));
420            }
421            DataType::Time64(TimeUnit::Microsecond) => Time64Micro(
422                VariantToPrimitiveArrowRowBuilder::new(cast_options, capacity),
423            ),
424            DataType::Time64(TimeUnit::Nanosecond) => Time64Nano(
425                VariantToPrimitiveArrowRowBuilder::new(cast_options, capacity),
426            ),
427            DataType::Time64(t) => {
428                return Err(ArrowError::InvalidArgumentError(format!(
429                    "The unit for Time64 must be micro/nano seconds, received {t:?}"
430                )));
431            }
432            DataType::Timestamp(TimeUnit::Second, None) => TimestampSecondNtz(
433                VariantToTimestampNtzArrowRowBuilder::new(cast_options, capacity),
434            ),
435            DataType::Timestamp(TimeUnit::Second, tz) => TimestampSecond(
436                VariantToTimestampArrowRowBuilder::new(cast_options, capacity, tz.clone()),
437            ),
438            DataType::Timestamp(TimeUnit::Millisecond, None) => TimestampMilliNtz(
439                VariantToTimestampNtzArrowRowBuilder::new(cast_options, capacity),
440            ),
441            DataType::Timestamp(TimeUnit::Millisecond, tz) => TimestampMilli(
442                VariantToTimestampArrowRowBuilder::new(cast_options, capacity, tz.clone()),
443            ),
444            DataType::Timestamp(TimeUnit::Microsecond, None) => TimestampMicroNtz(
445                VariantToTimestampNtzArrowRowBuilder::new(cast_options, capacity),
446            ),
447            DataType::Timestamp(TimeUnit::Microsecond, tz) => TimestampMicro(
448                VariantToTimestampArrowRowBuilder::new(cast_options, capacity, tz.clone()),
449            ),
450            DataType::Timestamp(TimeUnit::Nanosecond, None) => TimestampNanoNtz(
451                VariantToTimestampNtzArrowRowBuilder::new(cast_options, capacity),
452            ),
453            DataType::Timestamp(TimeUnit::Nanosecond, tz) => TimestampNano(
454                VariantToTimestampArrowRowBuilder::new(cast_options, capacity, tz.clone()),
455            ),
456            DataType::Duration(_) | DataType::Interval(_) => {
457                return Err(ArrowError::InvalidArgumentError(
458                    "Casting Variant to duration/interval types is not supported. \
459                    The Variant format does not define duration/interval types."
460                        .to_string(),
461                ));
462            }
463            DataType::Binary => Binary(VariantToBinaryArrowRowBuilder::new(cast_options, capacity)),
464            DataType::LargeBinary => {
465                LargeBinary(VariantToBinaryArrowRowBuilder::new(cast_options, capacity))
466            }
467            DataType::BinaryView => {
468                BinaryView(VariantToBinaryArrowRowBuilder::new(cast_options, capacity))
469            }
470            DataType::FixedSizeBinary(16) => {
471                Uuid(VariantToUuidArrowRowBuilder::new(cast_options, capacity))
472            }
473            DataType::FixedSizeBinary(_) => {
474                return Err(ArrowError::NotYetImplemented(format!(
475                    "DataType {data_type:?} not yet implemented"
476                )));
477            }
478            DataType::Utf8 => String(VariantToStringArrowBuilder::new(cast_options, capacity)),
479            DataType::LargeUtf8 => {
480                LargeString(VariantToStringArrowBuilder::new(cast_options, capacity))
481            }
482            DataType::Utf8View => {
483                StringView(VariantToStringArrowBuilder::new(cast_options, capacity))
484            }
485            DataType::List(_)
486            | DataType::LargeList(_)
487            | DataType::ListView(_)
488            | DataType::LargeListView(_)
489            | DataType::FixedSizeList(..)
490            | DataType::Struct(_)
491            | DataType::Map(..)
492            | DataType::Union(..)
493            | DataType::Dictionary(..)
494            | DataType::RunEndEncoded(..) => {
495                return Err(ArrowError::InvalidArgumentError(format!(
496                    "Casting to {data_type:?} is not applicable for primitive Variant types"
497                )));
498            }
499        };
500    Ok(builder)
501}
502
503pub(crate) enum ArrayVariantToArrowRowBuilder<'a> {
504    List(VariantToListArrowRowBuilder<'a, i32, false>),
505    LargeList(VariantToListArrowRowBuilder<'a, i64, false>),
506    ListView(VariantToListArrowRowBuilder<'a, i32, true>),
507    LargeListView(VariantToListArrowRowBuilder<'a, i64, true>),
508}
509
510pub(crate) struct StructVariantToArrowRowBuilder<'a> {
511    fields: &'a Fields,
512    field_builders: Vec<VariantToArrowRowBuilder<'a>>,
513    nulls: NullBufferBuilder,
514    cast_options: &'a CastOptions<'a>,
515}
516
517impl<'a> StructVariantToArrowRowBuilder<'a> {
518    fn try_new(
519        fields: &'a Fields,
520        cast_options: &'a CastOptions<'a>,
521        capacity: usize,
522    ) -> Result<Self> {
523        let mut field_builders = Vec::with_capacity(fields.len());
524        for field in fields.iter() {
525            field_builders.push(make_typed_variant_to_arrow_row_builder(
526                field.data_type(),
527                cast_options,
528                capacity,
529            )?);
530        }
531        Ok(Self {
532            fields,
533            field_builders,
534            nulls: NullBufferBuilder::new(capacity),
535            cast_options,
536        })
537    }
538
539    fn append_null(&mut self) -> Result<()> {
540        for builder in &mut self.field_builders {
541            builder.append_null()?;
542        }
543        self.nulls.append_null();
544        Ok(())
545    }
546
547    fn append_value(&mut self, value: &Variant<'_, '_>) -> Result<bool> {
548        let Variant::Object(obj) = value else {
549            if self.cast_options.safe {
550                self.append_null()?;
551                return Ok(false);
552            }
553            return Err(ArrowError::CastError(format!(
554                "Failed to extract struct from variant {:?}",
555                value
556            )));
557        };
558
559        for (index, field) in self.fields.iter().enumerate() {
560            match obj.get(field.name()) {
561                Some(field_value) => {
562                    self.field_builders[index].append_value(field_value)?;
563                }
564                None => {
565                    self.field_builders[index].append_null()?;
566                }
567            }
568        }
569
570        self.nulls.append_non_null();
571        Ok(true)
572    }
573
574    fn finish(mut self) -> Result<ArrayRef> {
575        let mut children = Vec::with_capacity(self.field_builders.len());
576        for builder in self.field_builders {
577            children.push(builder.finish()?);
578        }
579        Ok(Arc::new(StructArray::try_new(
580            self.fields.clone(),
581            children,
582            self.nulls.finish(),
583        )?))
584    }
585}
586
587impl<'a> ArrayVariantToArrowRowBuilder<'a> {
588    pub(crate) fn try_new(
589        data_type: &'a DataType,
590        cast_options: &'a CastOptions,
591        capacity: usize,
592    ) -> Result<Self> {
593        use ArrayVariantToArrowRowBuilder::*;
594
595        // Make List/ListView builders without repeating the constructor boilerplate.
596        macro_rules! make_list_builder {
597            ($variant:ident, $offset:ty, $is_view:expr, $field:ident) => {
598                $variant(VariantToListArrowRowBuilder::<$offset, $is_view>::try_new(
599                    $field.clone(),
600                    $field.data_type(),
601                    cast_options,
602                    capacity,
603                )?)
604            };
605        }
606
607        let builder = match data_type {
608            DataType::List(field) => make_list_builder!(List, i32, false, field),
609            DataType::LargeList(field) => make_list_builder!(LargeList, i64, false, field),
610            DataType::ListView(field) => make_list_builder!(ListView, i32, true, field),
611            DataType::LargeListView(field) => make_list_builder!(LargeListView, i64, true, field),
612            DataType::FixedSizeList(..) => {
613                return Err(ArrowError::NotYetImplemented(
614                    "Converting unshredded variant arrays to arrow fixed-size lists".to_string(),
615                ));
616            }
617            other => {
618                return Err(ArrowError::InvalidArgumentError(format!(
619                    "Casting to {other:?} is not applicable for array Variant types"
620                )));
621            }
622        };
623        Ok(builder)
624    }
625
626    pub(crate) fn append_null(&mut self) -> Result<()> {
627        match self {
628            Self::List(builder) => builder.append_null(),
629            Self::LargeList(builder) => builder.append_null(),
630            Self::ListView(builder) => builder.append_null(),
631            Self::LargeListView(builder) => builder.append_null(),
632        }
633    }
634
635    pub(crate) fn append_value(&mut self, value: &Variant<'_, '_>) -> Result<bool> {
636        match self {
637            Self::List(builder) => builder.append_value(value),
638            Self::LargeList(builder) => builder.append_value(value),
639            Self::ListView(builder) => builder.append_value(value),
640            Self::LargeListView(builder) => builder.append_value(value),
641        }
642    }
643
644    pub(crate) fn finish(self) -> Result<ArrayRef> {
645        match self {
646            Self::List(builder) => builder.finish(),
647            Self::LargeList(builder) => builder.finish(),
648            Self::ListView(builder) => builder.finish(),
649            Self::LargeListView(builder) => builder.finish(),
650        }
651    }
652}
653
654/// A thin wrapper whose only job is to extract a specific path from a variant value and pass the
655/// result to a nested builder.
656pub(crate) struct VariantPathRowBuilder<'a> {
657    builder: Box<VariantToArrowRowBuilder<'a>>,
658    path: VariantPath<'a>,
659}
660
661impl<'a> VariantPathRowBuilder<'a> {
662    fn append_null(&mut self) -> Result<()> {
663        self.builder.append_null()
664    }
665
666    fn append_value(&mut self, value: Variant<'_, '_>) -> Result<bool> {
667        if let Some(v) = value.get_path(&self.path) {
668            self.builder.append_value(v)
669        } else {
670            self.builder.append_null()?;
671            Ok(false)
672        }
673    }
674
675    fn finish(self) -> Result<ArrayRef> {
676        self.builder.finish()
677    }
678}
679
/// Generates a row builder that appends one converted primitive per `Variant` value.
///
/// Invocation shape:
/// * `struct Name<'a [, G: Bound]>`: the generated struct and its optional generic parameter;
/// * `|cap [, extra: Ty]| -> Builder[<T>] { init }`: the wrapped Arrow builder type and the
///   expression that constructs it from the arguments of the generated `new`;
/// * `|v| transform`: turns `v: &Variant` into an `Option` of the value to append;
/// * `type_name: expr`: how the target type is described in cast error messages.
///
/// When the transform yields `None`, the generated `append_value` appends a null and returns
/// `Ok(false)` if `CastOptions::safe` is set; otherwise it returns `ArrowError::CastError`.
macro_rules! define_variant_to_primitive_builder {
    (struct $name:ident<$lifetime:lifetime $(, $generic:ident: $bound:path )?>
    |$array_param:ident $(, $field:ident: $field_type:ty)?| -> $builder_name:ident $(< $array_type:ty >)? { $init_expr: expr },
    |$value: ident| $value_transform:expr,
    type_name: $type_name:expr) => {
        pub(crate) struct $name<$lifetime $(, $generic : $bound )?> {
            builder: $builder_name $(<$array_type>)?,
            cast_options: &$lifetime CastOptions<$lifetime>,
        }

        impl<$lifetime $(, $generic: $bound+ )?> $name<$lifetime $(, $generic )?> {
            fn new(
                cast_options: &$lifetime CastOptions<$lifetime>,
                $array_param: usize,
                // Optional extra constructor argument, in scope for `$init_expr`.
                $( $field: $field_type, )?
            ) -> Self {
                Self {
                    builder: $init_expr,
                    cast_options,
                }
            }

            fn append_null(&mut self) -> Result<()> {
                self.builder.append_null();
                Ok(())
            }

            fn append_value(&mut self, $value: &Variant<'_, '_>) -> Result<bool> {
                match $value_transform {
                    Some(converted) => {
                        self.builder.append_value(converted);
                        Ok(true)
                    }
                    // Safe casting: a failed conversion becomes a null row.
                    None if self.cast_options.safe => {
                        self.builder.append_null();
                        Ok(false)
                    }
                    // Unsafe casting: a failed conversion is reported as an error.
                    None => Err(ArrowError::CastError(format!(
                        "Failed to extract primitive of type {} from variant {:?} at path VariantPath([])",
                        $type_name,
                        $value
                    ))),
                }
            }

            // Some wrapped builders (e.g. `FakeNullBuilder`) finish through `&self`, which
            // would make `mut self` trigger an unused-mut warning for them.
            #[allow(unused_mut)]
            fn finish(mut self) -> Result<ArrayRef> {
                // A builder returning `T: Array` goes through `<Arc<T> as From<T>>::from` and
                // then coerces to `ArrayRef`; a builder that already returns `ArrayRef` hits the
                // blanket identity `From` impl.
                Ok(Arc::from(self.builder.finish()))
            }
        }
    }
}
740
// Strings: `B` is any `StringLikeArrayBuilder`. Rows are extracted with `Variant::as_string`,
// and `B::type_name()` names the target type in cast error messages.
define_variant_to_primitive_builder!(
    struct VariantToStringArrowBuilder<'a, B: StringLikeArrayBuilder>
    |capacity| -> B { B::with_capacity(capacity) },
    |value| value.as_string(),
    type_name: B::type_name()
);

// Booleans: rows are extracted with `Variant::as_boolean` into a `BooleanBuilder`.
define_variant_to_primitive_builder!(
    struct VariantToBooleanArrowRowBuilder<'a>
    |capacity| -> BooleanBuilder { BooleanBuilder::with_capacity(capacity) },
    |value|  value.as_boolean(),
    type_name: datatypes::BooleanType::DATA_TYPE
);

// Primitive Arrow types: conversion is delegated to `PrimitiveFromVariant::from_variant`.
define_variant_to_primitive_builder!(
    struct VariantToPrimitiveArrowRowBuilder<'a, T:PrimitiveFromVariant>
    |capacity| -> PrimitiveBuilder<T> { PrimitiveBuilder::<T>::with_capacity(capacity) },
    |value| T::from_variant(value),
    type_name: T::DATA_TYPE
);

// Timestamps converted through `TimestampFromVariant<true>`, with no time zone attached.
// NOTE(review): the `true` const parameter presumably selects the timezone-less (NTZ)
// conversion; confirm in `type_conversion`.
define_variant_to_primitive_builder!(
    struct VariantToTimestampNtzArrowRowBuilder<'a, T:TimestampFromVariant<true>>
    |capacity| -> PrimitiveBuilder<T> { PrimitiveBuilder::<T>::with_capacity(capacity) },
    |value| T::from_variant(value),
    type_name: T::DATA_TYPE
);

// Timestamps converted through `TimestampFromVariant<false>`. `new` takes an extra
// `tz: Option<Arc<str>>` argument that is attached to the output via `with_timezone_opt`.
define_variant_to_primitive_builder!(
    struct VariantToTimestampArrowRowBuilder<'a, T:TimestampFromVariant<false>>
    |capacity, tz: Option<Arc<str>> | -> PrimitiveBuilder<T> {
        PrimitiveBuilder::<T>::with_capacity(capacity).with_timezone_opt(tz)
    },
    |value| T::from_variant(value),
    type_name: T::DATA_TYPE
);

// Binary: `B` is any `BinaryLikeArrayBuilder`. Rows are extracted with `Variant::as_u8_slice`,
// and `B::type_name()` names the target type in cast error messages.
define_variant_to_primitive_builder!(
    struct VariantToBinaryArrowRowBuilder<'a, B: BinaryLikeArrayBuilder>
    |capacity| -> B { B::with_capacity(capacity) },
    |value| value.as_u8_slice(),
    type_name: B::type_name()
);
784
785/// Builder for converting variant values to arrow Decimal values
786pub(crate) struct VariantToDecimalArrowRowBuilder<'a, T>
787where
788    T: DecimalType,
789    T::Native: DecimalCast,
790{
791    builder: PrimitiveBuilder<T>,
792    cast_options: &'a CastOptions<'a>,
793    precision: u8,
794    scale: i8,
795}
796
797impl<'a, T> VariantToDecimalArrowRowBuilder<'a, T>
798where
799    T: DecimalType,
800    T::Native: DecimalCast,
801{
802    fn new(
803        cast_options: &'a CastOptions<'a>,
804        capacity: usize,
805        precision: u8,
806        scale: i8,
807    ) -> Result<Self> {
808        let builder = PrimitiveBuilder::<T>::with_capacity(capacity)
809            .with_precision_and_scale(precision, scale)?;
810        Ok(Self {
811            builder,
812            cast_options,
813            precision,
814            scale,
815        })
816    }
817
818    fn append_null(&mut self) -> Result<()> {
819        self.builder.append_null();
820        Ok(())
821    }
822
823    fn append_value(&mut self, value: &Variant<'_, '_>) -> Result<bool> {
824        if let Some(scaled) = variant_to_unscaled_decimal::<T>(value, self.precision, self.scale) {
825            self.builder.append_value(scaled);
826            Ok(true)
827        } else if self.cast_options.safe {
828            self.builder.append_null();
829            Ok(false)
830        } else {
831            Err(ArrowError::CastError(format!(
832                "Failed to cast to {}(precision={}, scale={}) from variant {:?}",
833                T::PREFIX,
834                self.precision,
835                self.scale,
836                value
837            )))
838        }
839    }
840
841    fn finish(mut self) -> Result<ArrayRef> {
842        Ok(Arc::new(self.builder.finish()))
843    }
844}
845
846/// Builder for converting variant values to FixedSizeBinary(16) for UUIDs
847pub(crate) struct VariantToUuidArrowRowBuilder<'a> {
848    builder: FixedSizeBinaryBuilder,
849    cast_options: &'a CastOptions<'a>,
850}
851
852impl<'a> VariantToUuidArrowRowBuilder<'a> {
853    fn new(cast_options: &'a CastOptions<'a>, capacity: usize) -> Self {
854        Self {
855            builder: FixedSizeBinaryBuilder::with_capacity(capacity, 16),
856            cast_options,
857        }
858    }
859
860    fn append_null(&mut self) -> Result<()> {
861        self.builder.append_null();
862        Ok(())
863    }
864
865    fn append_value(&mut self, value: &Variant<'_, '_>) -> Result<bool> {
866        match value.as_uuid() {
867            Some(uuid) => {
868                self.builder
869                    .append_value(uuid.as_bytes())
870                    .map_err(|e| ArrowError::ExternalError(Box::new(e)))?;
871
872                Ok(true)
873            }
874            None if self.cast_options.safe => {
875                self.builder.append_null();
876                Ok(false)
877            }
878            None => Err(ArrowError::CastError(format!(
879                "Failed to extract UUID from variant {value:?}",
880            ))),
881        }
882    }
883
884    fn finish(mut self) -> Result<ArrayRef> {
885        Ok(Arc::new(self.builder.finish()))
886    }
887}
888
889pub(crate) struct VariantToListArrowRowBuilder<'a, O, const IS_VIEW: bool>
890where
891    O: OffsetSizeTrait + ArrowNativeTypeOp,
892{
893    field: FieldRef,
894    offsets: Vec<O>,
895    element_builder: Box<VariantToShreddedVariantRowBuilder<'a>>,
896    nulls: NullBufferBuilder,
897    current_offset: O,
898    cast_options: &'a CastOptions<'a>,
899}
900
901impl<'a, O, const IS_VIEW: bool> VariantToListArrowRowBuilder<'a, O, IS_VIEW>
902where
903    O: OffsetSizeTrait + ArrowNativeTypeOp,
904{
905    fn try_new(
906        field: FieldRef,
907        element_data_type: &'a DataType,
908        cast_options: &'a CastOptions,
909        capacity: usize,
910    ) -> Result<Self> {
911        if capacity >= isize::MAX as usize {
912            return Err(ArrowError::ComputeError(
913                "Capacity exceeds isize::MAX when reserving list offsets".to_string(),
914            ));
915        }
916        let mut offsets = Vec::with_capacity(capacity + 1);
917        offsets.push(O::ZERO);
918        let element_builder = make_variant_to_shredded_variant_arrow_row_builder(
919            element_data_type,
920            cast_options,
921            capacity,
922            false,
923        )?;
924        Ok(Self {
925            field,
926            offsets,
927            element_builder: Box::new(element_builder),
928            nulls: NullBufferBuilder::new(capacity),
929            current_offset: O::ZERO,
930            cast_options,
931        })
932    }
933
934    fn append_null(&mut self) -> Result<()> {
935        self.offsets.push(self.current_offset);
936        self.nulls.append_null();
937        Ok(())
938    }
939
940    fn append_value(&mut self, value: &Variant<'_, '_>) -> Result<bool> {
941        match value {
942            Variant::List(list) => {
943                for element in list.iter() {
944                    self.element_builder.append_value(element)?;
945                    self.current_offset = self.current_offset.add_checked(O::ONE)?;
946                }
947                self.offsets.push(self.current_offset);
948                self.nulls.append_non_null();
949                Ok(true)
950            }
951            _ if self.cast_options.safe => {
952                self.append_null()?;
953                Ok(false)
954            }
955            _ => Err(ArrowError::CastError(format!(
956                "Failed to extract list from variant {:?}",
957                value
958            ))),
959        }
960    }
961
962    fn finish(mut self) -> Result<ArrayRef> {
963        let (value, typed_value, nulls) = self.element_builder.finish()?;
964        let element_array =
965            ShreddedVariantFieldArray::from_parts(Some(value), Some(typed_value), nulls);
966        let field = Arc::new(
967            self.field
968                .as_ref()
969                .clone()
970                .with_data_type(element_array.data_type().clone()),
971        );
972
973        if IS_VIEW {
974            // NOTE: `offsets` is never empty (constructor pushes an entry)
975            let mut sizes = Vec::with_capacity(self.offsets.len() - 1);
976            for i in 1..self.offsets.len() {
977                sizes.push(self.offsets[i] - self.offsets[i - 1]);
978            }
979            self.offsets.pop();
980            let list_view_array = GenericListViewArray::<O>::new(
981                field,
982                ScalarBuffer::from(self.offsets),
983                ScalarBuffer::from(sizes),
984                ArrayRef::from(element_array),
985                self.nulls.finish(),
986            );
987            Ok(Arc::new(list_view_array))
988        } else {
989            let list_array = GenericListArray::<O>::new(
990                field,
991                OffsetBuffer::<O>::new(ScalarBuffer::from(self.offsets)),
992                ArrayRef::from(element_array),
993                self.nulls.finish(),
994            );
995            Ok(Arc::new(list_array))
996        }
997    }
998}
999
1000/// Builder for creating VariantArray output (for path extraction without type conversion)
1001pub(crate) struct VariantToBinaryVariantArrowRowBuilder {
1002    metadata: BinaryViewArray,
1003    builder: VariantValueArrayBuilder,
1004    nulls: NullBufferBuilder,
1005}
1006
1007impl VariantToBinaryVariantArrowRowBuilder {
1008    fn new(metadata: BinaryViewArray, capacity: usize) -> Self {
1009        Self {
1010            metadata,
1011            builder: VariantValueArrayBuilder::new(capacity),
1012            nulls: NullBufferBuilder::new(capacity),
1013        }
1014    }
1015}
1016
1017impl VariantToBinaryVariantArrowRowBuilder {
1018    fn append_null(&mut self) -> Result<()> {
1019        self.builder.append_null();
1020        self.nulls.append_null();
1021        Ok(())
1022    }
1023
1024    fn append_value(&mut self, value: Variant<'_, '_>) -> Result<bool> {
1025        self.builder.append_value(value);
1026        self.nulls.append_non_null();
1027        Ok(true)
1028    }
1029
1030    fn finish(mut self) -> Result<ArrayRef> {
1031        let variant_array = VariantArray::from_parts(
1032            self.metadata,
1033            Some(self.builder.build()?),
1034            None, // no typed_value column
1035            self.nulls.finish(),
1036        );
1037
1038        Ok(ArrayRef::from(variant_array))
1039    }
1040}
1041
1042#[derive(Default)]
1043struct FakeNullBuilder {
1044    item_count: usize,
1045}
1046
1047impl FakeNullBuilder {
1048    fn append_value(&mut self, _: ()) {
1049        self.item_count += 1;
1050    }
1051
1052    fn append_null(&mut self) {
1053        self.item_count += 1;
1054    }
1055
1056    fn finish(self) -> NullArray {
1057        NullArray::new(self.item_count)
1058    }
1059}
1060
// Null: the output is a `NullArray` built by `FakeNullBuilder`. A variant for which
// `Variant::as_null` returns `None` is a cast failure. In safe mode it still becomes a
// (null) row; otherwise it is reported as a `CastError` naming the type "Null".
define_variant_to_primitive_builder!(
    struct VariantToNullArrowRowBuilder<'a>
    |_capacity| -> FakeNullBuilder { FakeNullBuilder::default() },
    |value| value.as_null(),
    type_name: "Null"
);
1067
#[cfg(test)]
mod tests {
    use super::make_primitive_variant_to_arrow_row_builder;
    use arrow::compute::CastOptions;
    use arrow::datatypes::{DataType, Field, Fields, UnionFields, UnionMode};
    use arrow::error::ArrowError;
    use std::sync::Arc;

    /// Nested and encoded (non-primitive) Arrow types must be rejected with an
    /// `InvalidArgumentError` whose message mentions the offending type.
    #[test]
    fn make_primitive_builder_rejects_non_primitive_types() {
        let options = CastOptions::default();

        let item = Arc::new(Field::new("item", DataType::Int32, true));
        let children = Fields::from(vec![Field::new("child", DataType::Int32, true)]);
        let entries = Arc::new(Field::new(
            "entries",
            DataType::Struct(Fields::from(vec![
                Field::new("key", DataType::Utf8, false),
                Field::new("value", DataType::Float64, true),
            ])),
            true,
        ));
        let union_fields =
            UnionFields::try_new(vec![1], vec![Field::new("child", DataType::Int32, true)])
                .unwrap();
        let run_ends = Arc::new(Field::new("run_ends", DataType::Int32, false));
        let ree_values = Arc::new(Field::new("values", DataType::Utf8, true));

        let rejected = [
            DataType::List(Arc::clone(&item)),
            DataType::LargeList(Arc::clone(&item)),
            DataType::ListView(Arc::clone(&item)),
            DataType::LargeListView(Arc::clone(&item)),
            DataType::FixedSizeList(item, 2),
            DataType::Struct(children),
            DataType::Map(entries, false),
            DataType::Union(union_fields, UnionMode::Dense),
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            DataType::RunEndEncoded(run_ends, ree_values),
        ];

        for data_type in rejected {
            // The builder type is not required to be `Debug`, so avoid `unwrap_err`.
            let Err(err) = make_primitive_variant_to_arrow_row_builder(&data_type, &options, 1)
            else {
                panic!("non-primitive type {data_type:?} should be rejected");
            };

            match err {
                ArrowError::InvalidArgumentError(msg) => {
                    assert!(msg.contains(&format!("{data_type:?}")));
                }
                other => panic!("expected InvalidArgumentError, got {other:?}"),
            }
        }
    }
}