parquet_variant_compute/
variant_get.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17use std::sync::Arc;
18
19use arrow::{
20    array::{Array, ArrayRef},
21    compute::CastOptions,
22    error::Result,
23};
24use arrow_schema::{ArrowError, Field};
25use parquet_variant::VariantPath;
26
27use crate::{VariantArray, VariantArrayBuilder};
28
29/// Returns an array with the specified path extracted from the variant values.
30///
31/// The return array type depends on the `as_type` field of the options parameter
32/// 1. `as_type: None`: a VariantArray is returned. The values in this new VariantArray will point
33///    to the specified path.
34/// 2. `as_type: Some(<specific field>)`: an array of the specified type is returned.
35pub fn variant_get(input: &ArrayRef, options: GetOptions) -> Result<ArrayRef> {
36    let variant_array: &VariantArray = input.as_any().downcast_ref().ok_or_else(|| {
37        ArrowError::InvalidArgumentError(
38            "expected a VariantArray as the input for variant_get".to_owned(),
39        )
40    })?;
41
42    if let Some(as_type) = options.as_type {
43        return Err(ArrowError::NotYetImplemented(format!(
44            "getting a {as_type} from a VariantArray is not implemented yet",
45        )));
46    }
47
48    let mut builder = VariantArrayBuilder::new(variant_array.len());
49    for i in 0..variant_array.len() {
50        let new_variant = variant_array.value(i);
51        // TODO: perf?
52        let new_variant = new_variant.get_path(&options.path);
53        match new_variant {
54            // TODO: we're decoding the value and doing a copy into a variant value again. This
55            // copy can be much smarter.
56            Some(new_variant) => builder.append_variant(new_variant),
57            None => builder.append_null(),
58        }
59    }
60
61    Ok(Arc::new(builder.build()))
62}
63
64/// Controls the action of the variant_get kernel.
65#[derive(Debug, Clone)]
66pub struct GetOptions<'a> {
67    /// What path to extract
68    pub path: VariantPath<'a>,
69    /// if `as_type` is None, the returned array will itself be a VariantArray.
70    ///
71    /// if `as_type` is `Some(type)` the field is returned as the specified type.
72    pub as_type: Option<Field>,
73    /// Controls the casting behavior (e.g. error vs substituting null on cast error).
74    pub cast_options: CastOptions<'a>,
75}
76
77impl<'a> GetOptions<'a> {
78    /// Construct options to get the specified path as a variant.
79    pub fn new_with_path(path: VariantPath<'a>) -> Self {
80        Self {
81            path,
82            as_type: None,
83            cast_options: Default::default(),
84        }
85    }
86}
87
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{Array, ArrayRef, StringArray};
    use parquet_variant::VariantPath;

    use crate::{batch_json_string_to_variant, VariantArray};

    use super::{variant_get, GetOptions};

    /// Runs `variant_get` over a one-row variant array built from `input_json` and checks
    /// that the single, non-null result equals the variant built from `expected_json`.
    fn single_variant_get_test(input_json: &str, path: VariantPath, expected_json: &str) {
        // Input: JSON text -> string array -> variant array.
        let json_input: ArrayRef = Arc::new(StringArray::from(vec![Some(input_json)]));
        let variant_input: ArrayRef =
            Arc::new(batch_json_string_to_variant(&json_input).unwrap());

        let extracted =
            variant_get(&variant_input, GetOptions::new_with_path(path)).unwrap();

        // Expected output goes through the same JSON -> variant conversion.
        let json_expected: ArrayRef = Arc::new(StringArray::from(vec![Some(expected_json)]));
        let expected = batch_json_string_to_variant(&json_expected).unwrap();

        let actual: &VariantArray = extracted.as_any().downcast_ref().unwrap();
        assert_eq!(actual.len(), 1, "Expected result array to have length 1");
        assert!(
            actual.nulls().is_none(),
            "Expected no nulls in result array"
        );
        assert_eq!(
            actual.value(0),
            expected.value(0),
            "Result variant does not match expected variant"
        );
    }

    /// A top-level object field holding a primitive.
    #[test]
    fn get_primitive_variant_field() {
        let path = VariantPath::from("some_field");
        single_variant_get_test(r#"{"some_field": 1234}"#, path, "1234");
    }

    /// A top-level list element holding a primitive.
    #[test]
    fn get_primitive_variant_list_index() {
        let path = VariantPath::from(0);
        single_variant_get_test("[1234, 5678]", path, "1234");
    }

    /// A primitive nested as object -> object.
    #[test]
    fn get_primitive_variant_inside_object_of_object() {
        let path = VariantPath::from("top_level_field").join("inner_field");
        single_variant_get_test(
            r#"{"top_level_field": {"inner_field": 1234}}"#,
            path,
            "1234",
        );
    }

    /// A primitive nested as list -> object.
    #[test]
    fn get_primitive_variant_inside_list_of_object() {
        let path = VariantPath::from(0).join("some_field");
        single_variant_get_test(r#"[{"some_field": 1234}]"#, path, "1234");
    }

    /// A primitive nested as object -> list.
    #[test]
    fn get_primitive_variant_inside_object_of_list() {
        let path = VariantPath::from("some_field").join(0);
        single_variant_get_test(r#"{"some_field": [1234]}"#, path, "1234");
    }

    /// Extracting a field whose value is itself an object.
    #[test]
    fn get_complex_variant() {
        let path = VariantPath::from("top_level_field");
        single_variant_get_test(
            r#"{"top_level_field": {"inner_field": 1234}}"#,
            path,
            r#"{"inner_field": 1234}"#,
        );
    }
}