parquet_variant_compute/variant_get/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17use arrow::{
18    array::{Array, ArrayRef},
19    compute::CastOptions,
20    error::Result,
21};
22use arrow_schema::{ArrowError, FieldRef};
23use parquet_variant::VariantPath;
24
25use crate::variant_array::ShreddingState;
26use crate::variant_get::output::instantiate_output_builder;
27use crate::VariantArray;
28
29mod output;
30
31/// Returns an array with the specified path extracted from the variant values.
32///
33/// The return array type depends on the `as_type` field of the options parameter
34/// 1. `as_type: None`: a VariantArray is returned. The values in this new VariantArray will point
35///    to the specified path.
36/// 2. `as_type: Some(<specific field>)`: an array of the specified type is returned.
37pub fn variant_get(input: &ArrayRef, options: GetOptions) -> Result<ArrayRef> {
38    let variant_array: &VariantArray = input.as_any().downcast_ref().ok_or_else(|| {
39        ArrowError::InvalidArgumentError(
40            "expected a VariantArray as the input for variant_get".to_owned(),
41        )
42    })?;
43
44    // Create the output writer based on the specified output options
45    let output_builder = instantiate_output_builder(options.clone())?;
46
47    // Dispatch based on the shredding state of the input variant array
48    match variant_array.shredding_state() {
49        ShreddingState::PartiallyShredded {
50            metadata,
51            value,
52            typed_value,
53        } => output_builder.partially_shredded(variant_array, metadata, value, typed_value),
54        ShreddingState::Typed {
55            metadata,
56            typed_value,
57        } => output_builder.typed(variant_array, metadata, typed_value),
58        ShreddingState::Unshredded { metadata, value } => {
59            output_builder.unshredded(variant_array, metadata, value)
60        }
61    }
62}
63
64/// Controls the action of the variant_get kernel.
65#[derive(Debug, Clone, Default)]
66pub struct GetOptions<'a> {
67    /// What path to extract
68    pub path: VariantPath<'a>,
69    /// if `as_type` is None, the returned array will itself be a VariantArray.
70    ///
71    /// if `as_type` is `Some(type)` the field is returned as the specified type.
72    pub as_type: Option<FieldRef>,
73    /// Controls the casting behavior (e.g. error vs substituting null on cast error).
74    pub cast_options: CastOptions<'a>,
75}
76
77impl<'a> GetOptions<'a> {
78    /// Construct default options to get the specified path as a variant.
79    pub fn new() -> Self {
80        Default::default()
81    }
82
83    /// Construct options to get the specified path as a variant.
84    pub fn new_with_path(path: VariantPath<'a>) -> Self {
85        Self {
86            path,
87            as_type: None,
88            cast_options: Default::default(),
89        }
90    }
91
92    /// Specify the type to return.
93    pub fn with_as_type(mut self, as_type: Option<FieldRef>) -> Self {
94        self.as_type = as_type;
95        self
96    }
97
98    /// Specify the cast options to use when casting to the specified type.
99    pub fn with_cast_options(mut self, cast_options: CastOptions<'a>) -> Self {
100        self.cast_options = cast_options;
101        self
102    }
103}
104
105#[cfg(test)]
106mod test {
107    use std::sync::Arc;
108
109    use arrow::array::{Array, ArrayRef, BinaryViewArray, Int32Array, StringArray, StructArray};
110    use arrow::buffer::NullBuffer;
111    use arrow::compute::CastOptions;
112    use arrow_schema::{DataType, Field, FieldRef, Fields};
113    use parquet_variant::{Variant, VariantPath};
114
115    use crate::batch_json_string_to_variant;
116    use crate::VariantArray;
117
118    use super::{variant_get, GetOptions};
119
120    fn single_variant_get_test(input_json: &str, path: VariantPath, expected_json: &str) {
121        // Create input array from JSON string
122        let input_array_ref: ArrayRef = Arc::new(StringArray::from(vec![Some(input_json)]));
123        let input_variant_array_ref: ArrayRef =
124            Arc::new(batch_json_string_to_variant(&input_array_ref).unwrap());
125
126        let result =
127            variant_get(&input_variant_array_ref, GetOptions::new_with_path(path)).unwrap();
128
129        // Create expected array from JSON string
130        let expected_array_ref: ArrayRef = Arc::new(StringArray::from(vec![Some(expected_json)]));
131        let expected_variant_array = batch_json_string_to_variant(&expected_array_ref).unwrap();
132
133        let result_array: &VariantArray = result.as_any().downcast_ref().unwrap();
134        assert_eq!(
135            result_array.len(),
136            1,
137            "Expected result array to have length 1"
138        );
139        assert!(
140            result_array.nulls().is_none(),
141            "Expected no nulls in result array"
142        );
143        let result_variant = result_array.value(0);
144        let expected_variant = expected_variant_array.value(0);
145        assert_eq!(
146            result_variant, expected_variant,
147            "Result variant does not match expected variant"
148        );
149    }
150
151    #[test]
152    fn get_primitive_variant_field() {
153        single_variant_get_test(
154            r#"{"some_field": 1234}"#,
155            VariantPath::from("some_field"),
156            "1234",
157        );
158    }
159
160    #[test]
161    fn get_primitive_variant_list_index() {
162        single_variant_get_test("[1234, 5678]", VariantPath::from(0), "1234");
163    }
164
165    #[test]
166    fn get_primitive_variant_inside_object_of_object() {
167        single_variant_get_test(
168            r#"{"top_level_field": {"inner_field": 1234}}"#,
169            VariantPath::from("top_level_field").join("inner_field"),
170            "1234",
171        );
172    }
173
174    #[test]
175    fn get_primitive_variant_inside_list_of_object() {
176        single_variant_get_test(
177            r#"[{"some_field": 1234}]"#,
178            VariantPath::from(0).join("some_field"),
179            "1234",
180        );
181    }
182
183    #[test]
184    fn get_primitive_variant_inside_object_of_list() {
185        single_variant_get_test(
186            r#"{"some_field": [1234]}"#,
187            VariantPath::from("some_field").join(0),
188            "1234",
189        );
190    }
191
192    #[test]
193    fn get_complex_variant() {
194        single_variant_get_test(
195            r#"{"top_level_field": {"inner_field": 1234}}"#,
196            VariantPath::from("top_level_field"),
197            r#"{"inner_field": 1234}"#,
198        );
199    }
200
201    /// Shredding: extract a value as a VariantArray
202    #[test]
203    fn get_variant_shredded_int32_as_variant() {
204        let array = shredded_int32_variant_array();
205        let options = GetOptions::new();
206        let result = variant_get(&array, options).unwrap();
207
208        // expect the result is a VariantArray
209        let result: &VariantArray = result.as_any().downcast_ref().unwrap();
210        assert_eq!(result.len(), 4);
211
212        // Expect the values are the same as the original values
213        assert_eq!(result.value(0), Variant::Int32(34));
214        assert!(!result.is_valid(1));
215        assert_eq!(result.value(2), Variant::from("n/a"));
216        assert_eq!(result.value(3), Variant::Int32(100));
217    }
218
219    /// Shredding: extract a value as an Int32Array
220    #[test]
221    fn get_variant_shredded_int32_as_int32_safe_cast() {
222        // Extract the typed value as Int32Array
223        let array = shredded_int32_variant_array();
224        // specify we want the typed value as Int32
225        let field = Field::new("typed_value", DataType::Int32, true);
226        let options = GetOptions::new().with_as_type(Some(FieldRef::from(field)));
227        let result = variant_get(&array, options).unwrap();
228        let expected: ArrayRef = Arc::new(Int32Array::from(vec![
229            Some(34),
230            None,
231            None, // "n/a" is not an Int32 so converted to null
232            Some(100),
233        ]));
234        assert_eq!(&result, &expected)
235    }
236
237    /// Shredding: extract a value as an Int32Array, unsafe cast (should error on "n/a")
238
239    #[test]
240    fn get_variant_shredded_int32_as_int32_unsafe_cast() {
241        // Extract the typed value as Int32Array
242        let array = shredded_int32_variant_array();
243        let field = Field::new("typed_value", DataType::Int32, true);
244        let cast_options = CastOptions {
245            safe: false, // unsafe cast
246            ..Default::default()
247        };
248        let options = GetOptions::new()
249            .with_as_type(Some(FieldRef::from(field)))
250            .with_cast_options(cast_options);
251
252        let err = variant_get(&array, options).unwrap_err();
253        // TODO make this error message nicer (not Debug format)
254        assert_eq!(err.to_string(), "Cast error: Failed to extract primitive of type Int32 from variant ShortString(ShortString(\"n/a\")) at path VariantPath([])");
255    }
256
257    /// Perfect Shredding: extract the typed value as a VariantArray
258    #[test]
259    fn get_variant_perfectly_shredded_int32_as_variant() {
260        let array = perfectly_shredded_int32_variant_array();
261        let options = GetOptions::new();
262        let result = variant_get(&array, options).unwrap();
263
264        // expect the result is a VariantArray
265        let result: &VariantArray = result.as_any().downcast_ref().unwrap();
266        assert_eq!(result.len(), 3);
267
268        // Expect the values are the same as the original values
269        assert_eq!(result.value(0), Variant::Int32(1));
270        assert_eq!(result.value(1), Variant::Int32(2));
271        assert_eq!(result.value(2), Variant::Int32(3));
272    }
273
274    /// Shredding: Extract the typed value as Int32Array
275    #[test]
276    fn get_variant_perfectly_shredded_int32_as_int32() {
277        // Extract the typed value as Int32Array
278        let array = perfectly_shredded_int32_variant_array();
279        // specify we want the typed value as Int32
280        let field = Field::new("typed_value", DataType::Int32, true);
281        let options = GetOptions::new().with_as_type(Some(FieldRef::from(field)));
282        let result = variant_get(&array, options).unwrap();
283        let expected: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)]));
284        assert_eq!(&result, &expected)
285    }
286
287    /// Return a VariantArray that represents a perfectly "shredded" variant
288    /// for the following example (3 Variant::Int32 values):
289    ///
290    /// ```text
291    /// 1
292    /// 2
293    /// 3
294    /// ```
295    ///
296    /// The schema of the corresponding `StructArray` would look like this:
297    ///
298    /// ```text
299    /// StructArray {
300    ///   metadata: BinaryViewArray,
301    ///   typed_value: Int32Array,
302    /// }
303    /// ```
304    fn perfectly_shredded_int32_variant_array() -> ArrayRef {
305        // At the time of writing, the `VariantArrayBuilder` does not support shredding.
306        // so we must construct the array manually.  see https://github.com/apache/arrow-rs/issues/7895
307        let (metadata, _value) = { parquet_variant::VariantBuilder::new().finish() };
308
309        let metadata = BinaryViewArray::from_iter_values(std::iter::repeat_n(&metadata, 3));
310        let typed_value = Int32Array::from(vec![Some(1), Some(2), Some(3)]);
311
312        let struct_array = StructArrayBuilder::new()
313            .with_field("metadata", Arc::new(metadata))
314            .with_field("typed_value", Arc::new(typed_value))
315            .build();
316
317        Arc::new(
318            VariantArray::try_new(Arc::new(struct_array)).expect("should create variant array"),
319        )
320    }
321
322    /// Return a VariantArray that represents a normal "shredded" variant
323    /// for the following example
324    ///
325    /// Based on the example from [the doc]
326    ///
327    /// [the doc]: https://docs.google.com/document/d/1pw0AWoMQY3SjD7R4LgbPvMjG_xSCtXp3rZHkVp9jpZ4/edit?tab=t.0
328    ///
329    /// ```text
330    /// 34
331    /// null (an Arrow NULL, not a Variant::Null)
332    /// "n/a" (a string)
333    /// 100
334    /// ```
335    ///
336    /// The schema of the corresponding `StructArray` would look like this:
337    ///
338    /// ```text
339    /// StructArray {
340    ///   metadata: BinaryViewArray,
341    ///   value: BinaryViewArray,
342    ///   typed_value: Int32Array,
343    /// }
344    /// ```
345    fn shredded_int32_variant_array() -> ArrayRef {
346        // At the time of writing, the `VariantArrayBuilder` does not support shredding.
347        // so we must construct the array manually.  see https://github.com/apache/arrow-rs/issues/7895
348        let (metadata, string_value) = {
349            let mut builder = parquet_variant::VariantBuilder::new();
350            builder.append_value("n/a");
351            builder.finish()
352        };
353
354        let nulls = NullBuffer::from(vec![
355            true,  // row 0 non null
356            false, // row 1 is null
357            true,  // row 2 non null
358            true,  // row 3 non null
359        ]);
360
361        // metadata is the same for all rows
362        let metadata = BinaryViewArray::from_iter_values(std::iter::repeat_n(&metadata, 4));
363
364        // See https://docs.google.com/document/d/1pw0AWoMQY3SjD7R4LgbPvMjG_xSCtXp3rZHkVp9jpZ4/edit?disco=AAABml8WQrY
365        // about why row1 is an empty but non null, value.
366        let values = BinaryViewArray::from(vec![
367            None,                // row 0 is shredded, so no value
368            Some(b"" as &[u8]),  // row 1 is null, so empty value (why?)
369            Some(&string_value), // copy the string value "N/A"
370            None,                // row 3 is shredded, so no value
371        ]);
372
373        let typed_value = Int32Array::from(vec![
374            Some(34),  // row 0 is shredded, so it has a value
375            None,      // row 1 is null, so no value
376            None,      // row 2 is a string, so no typed value
377            Some(100), // row 3 is shredded, so it has a value
378        ]);
379
380        let struct_array = StructArrayBuilder::new()
381            .with_field("metadata", Arc::new(metadata))
382            .with_field("typed_value", Arc::new(typed_value))
383            .with_field("value", Arc::new(values))
384            .with_nulls(nulls)
385            .build();
386
387        Arc::new(
388            VariantArray::try_new(Arc::new(struct_array)).expect("should create variant array"),
389        )
390    }
391
392    /// Builds struct arrays from component fields
393    ///
394    /// TODO: move to arrow crate
395    #[derive(Debug, Default, Clone)]
396    struct StructArrayBuilder {
397        fields: Vec<FieldRef>,
398        arrays: Vec<ArrayRef>,
399        nulls: Option<NullBuffer>,
400    }
401
402    impl StructArrayBuilder {
403        fn new() -> Self {
404            Default::default()
405        }
406
407        /// Add an array to this struct array as a field with the specified name.
408        fn with_field(mut self, field_name: &str, array: ArrayRef) -> Self {
409            let field = Field::new(field_name, array.data_type().clone(), true);
410            self.fields.push(Arc::new(field));
411            self.arrays.push(array);
412            self
413        }
414
415        /// Set the null buffer for this struct array.
416        fn with_nulls(mut self, nulls: NullBuffer) -> Self {
417            self.nulls = Some(nulls);
418            self
419        }
420
421        pub fn build(self) -> StructArray {
422            let Self {
423                fields,
424                arrays,
425                nulls,
426            } = self;
427            StructArray::new(Fields::from(fields), arrays, nulls)
428        }
429    }
430}