Skip to main content

arrow_cast/cast/
union.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Cast support for union arrays.
19
20use crate::cast::can_cast_types;
21use crate::cast_with_options;
22use arrow_array::{Array, ArrayRef, UnionArray};
23use arrow_schema::{ArrowError, DataType, FieldRef, UnionFields};
24use arrow_select::union_extract::union_extract;
25
26use super::CastOptions;
27
28// this is used during child array selection to prefer a "close" type over a distant cast
29// for example: when targeting Utf8View, a Utf8 child is preferred over Int32 despite both being castable
30fn same_type_family(a: &DataType, b: &DataType) -> bool {
31    use DataType::*;
32    matches!(
33        (a, b),
34        (Utf8 | LargeUtf8 | Utf8View, Utf8 | LargeUtf8 | Utf8View)
35            | (
36                Binary | LargeBinary | BinaryView,
37                Binary | LargeBinary | BinaryView
38            )
39            | (Int8 | Int16 | Int32 | Int64, Int8 | Int16 | Int32 | Int64)
40            | (
41                UInt8 | UInt16 | UInt32 | UInt64,
42                UInt8 | UInt16 | UInt32 | UInt64
43            )
44            | (Float16 | Float32 | Float64, Float16 | Float32 | Float64)
45    )
46}
47
48/// Selects the best-matching child array from a [`UnionArray`] for a given target type
49///
50/// The goal is to find the source field whose type is closest to the target,
51/// so that the subsequent cast is as lossless as possible. The heuristic uses
52/// three passes with decreasing specificity:
53///
54/// 1. **Exact match**: field type equals the target type.
55/// 2. **Same type family**: field and target belong to the same logical family
56///    (e.g. `Utf8` and `Utf8View` are both strings). This avoids a greedy
57///    cross-family cast in pass 3 (e.g. picking `Int32` over `Utf8` when the
58///    target is `Utf8View`, since `can_cast_types(Int32, Utf8View)` is true)
59/// 3. **Castable**:`can_cast_types` reports the field can be cast to the target
60///    Nested target types are skipped here because union extraction introduces
61///    nulls, which can conflict with non-nullable inner fields
62///
63/// Each pass greedily picks the first matching field by type_id order
64pub(crate) fn resolve_child_array<'a>(
65    fields: &'a UnionFields,
66    target_type: &DataType,
67) -> Option<&'a FieldRef> {
68    fields
69        .iter()
70        .find(|(_, f)| f.data_type() == target_type)
71        .or_else(|| {
72            fields
73                .iter()
74                .find(|(_, f)| same_type_family(f.data_type(), target_type))
75        })
76        .or_else(|| {
77            // skip nested types in pass 3 because union extraction introduces nulls,
78            // and casting nullable arrays to nested types like List/Struct/Map can fail
79            // when inner fields are non-nullable.
80            if target_type.is_nested() {
81                return None;
82            }
83            fields
84                .iter()
85                .find(|(_, f)| can_cast_types(f.data_type(), target_type))
86        })
87        .map(|(_, f)| f)
88}
89
90/// Extracts the best-matching child array from a [`UnionArray`] for a given target type,
91/// and casts it to that type.
92///
93/// Rows where a different child array is active become NULL.
94/// If no child array matches, returns an error.
95///
96/// # Example
97///
98/// ```
99/// # use std::sync::Arc;
100/// # use arrow_schema::{DataType, Field, UnionFields};
101/// # use arrow_array::{UnionArray, StringArray, Int32Array, Array};
102/// # use arrow_cast::cast::union_extract_by_type;
103/// # use arrow_cast::CastOptions;
104/// let fields = UnionFields::try_new(
105///     [0, 1],
106///     [
107///         Field::new("int", DataType::Int32, true),
108///         Field::new("str", DataType::Utf8, true),
109///     ],
110/// ).unwrap();
111///
112/// let union = UnionArray::try_new(
113///     fields,
114///     vec![0, 1, 0].into(),
115///     None,
116///     vec![
117///         Arc::new(Int32Array::from(vec![Some(42), None, Some(99)])),
118///         Arc::new(StringArray::from(vec![None, Some("hello"), None])),
119///     ],
120/// )
121/// .unwrap();
122///
123/// // extract the Utf8 child array and cast to Utf8View
124/// let result = union_extract_by_type(&union, &DataType::Utf8View, &CastOptions::default()).unwrap();
125/// assert_eq!(result.data_type(), &DataType::Utf8View);
126/// assert!(result.is_null(0));   // Int32 row -> NULL
127/// assert!(!result.is_null(1));  // Utf8 row -> "hello"
128/// assert!(result.is_null(2));   // Int32 row -> NULL
129/// ```
130pub fn union_extract_by_type(
131    union_array: &UnionArray,
132    target_type: &DataType,
133    cast_options: &CastOptions,
134) -> Result<ArrayRef, ArrowError> {
135    let fields = match union_array.data_type() {
136        DataType::Union(fields, _) => fields,
137        _ => unreachable!("union_extract_by_type called on non-union array"),
138    };
139
140    let Some(field) = resolve_child_array(fields, target_type) else {
141        return Err(ArrowError::CastError(format!(
142            "cannot cast Union with fields {} to {}",
143            fields
144                .iter()
145                .map(|(_, f)| f.data_type().to_string())
146                .collect::<Vec<_>>()
147                .join(", "),
148            target_type
149        )));
150    };
151
152    let extracted = union_extract(union_array, field.name())?;
153
154    if extracted.data_type() == target_type {
155        return Ok(extracted);
156    }
157
158    cast_with_options(&extracted, target_type, cast_options)
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cast;
    use arrow_array::*;
    use arrow_schema::{Field, UnionFields, UnionMode};
    use std::sync::Arc;

    fn int_str_fields() -> UnionFields {
        UnionFields::try_new(
            [0, 1],
            [
                Field::new("int", DataType::Int32, true),
                Field::new("str", DataType::Utf8, true),
            ],
        )
        .unwrap()
    }

    fn int_str_union_type(mode: UnionMode) -> DataType {
        DataType::Union(int_str_fields(), mode)
    }

    // pass 1: exact type match.
    // Union(Int32, Utf8) targeting Utf8. The Utf8 child matches exactly.
    // Int32 rows become NULL. Tested for both sparse and dense.
    #[test]
    fn test_exact_type_match() {
        let target = DataType::Utf8;

        // sparse
        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Sparse),
            &target
        ));

        let sparse = UnionArray::try_new(
            int_str_fields(),
            vec![1_i8, 0, 1].into(),
            None,
            vec![
                Arc::new(Int32Array::from(vec![None, Some(42), None])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("hello"), None, Some("world")])),
            ],
        )
        .unwrap();

        let result = cast::cast(&sparse, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(arr.value(0), "hello");
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), "world");

        // dense
        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Dense),
            &target
        ));

        let dense = UnionArray::try_new(
            int_str_fields(),
            vec![1_i8, 0, 1].into(),
            Some(vec![0_i32, 0, 1].into()),
            vec![
                Arc::new(Int32Array::from(vec![Some(42)])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("hello"), Some("world")])),
            ],
        )
        .unwrap();

        let result = cast::cast(&dense, &target).unwrap();
        let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(arr.value(0), "hello");
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), "world");
    }

    // pass 2: same type family match.
    // Union(Int32, Utf8) targeting Utf8View. There is no exact match, but Utf8 and
    // Utf8View are in the same family, so the Utf8 child array is picked and cast to
    // Utf8View. This is the bug that motivated this work: without pass 2, pass 3 would
    // greedily pick Int32 (since can_cast_types(Int32, Utf8View) is true).
    #[test]
    fn test_same_family_utf8_to_utf8view() {
        let target = DataType::Utf8View;

        // sparse
        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Sparse),
            &target
        ));

        let sparse = UnionArray::try_new(
            int_str_fields(),
            vec![1_i8, 0, 1, 1].into(),
            None,
            vec![
                Arc::new(Int32Array::from(vec![None, Some(42), None, None])) as ArrayRef,
                Arc::new(StringArray::from(vec![
                    Some("agent_alpha"),
                    None,
                    Some("agent_beta"),
                    None,
                ])),
            ],
        )
        .unwrap();

        let result = cast::cast(&sparse, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        let arr = result.as_any().downcast_ref::<StringViewArray>().unwrap();
        assert_eq!(arr.value(0), "agent_alpha");
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), "agent_beta");
        assert!(arr.is_null(3));

        // dense
        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Dense),
            &target
        ));

        let dense = UnionArray::try_new(
            int_str_fields(),
            vec![1_i8, 0, 1].into(),
            Some(vec![0_i32, 0, 1].into()),
            vec![
                Arc::new(Int32Array::from(vec![Some(42)])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("alpha"), Some("beta")])),
            ],
        )
        .unwrap();

        let result = cast::cast(&dense, &target).unwrap();
        let arr = result.as_any().downcast_ref::<StringViewArray>().unwrap();
        assert_eq!(arr.value(0), "alpha");
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), "beta");
    }

    // pass 3: one-directional cast across type families.
    // Union(Int32, Utf8) targeting Boolean: no exact match, no family match.
    // Pass 3 picks Int32 (the first child array where can_cast_types is true) and
    // casts to Boolean (0 -> false, nonzero -> true). Utf8 rows become NULL.
    #[test]
    fn test_one_directional_cast() {
        let target = DataType::Boolean;

        // sparse
        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Sparse),
            &target
        ));

        let sparse = UnionArray::try_new(
            int_str_fields(),
            vec![0_i8, 1, 0].into(),
            None,
            vec![
                Arc::new(Int32Array::from(vec![Some(42), None, Some(0)])) as ArrayRef,
                Arc::new(StringArray::from(vec![None, Some("hello"), None])),
            ],
        )
        .unwrap();

        let result = cast::cast(&sparse, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        let arr = result.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert!(arr.value(0));
        assert!(arr.is_null(1));
        assert!(!arr.value(2));

        // dense
        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Dense),
            &target
        ));

        let dense = UnionArray::try_new(
            int_str_fields(),
            vec![0_i8, 1, 0].into(),
            Some(vec![0_i32, 0, 1].into()),
            vec![
                Arc::new(Int32Array::from(vec![Some(42), Some(0)])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("hello")])),
            ],
        )
        .unwrap();

        let result = cast::cast(&dense, &target).unwrap();
        let arr = result.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert!(arr.value(0));
        assert!(arr.is_null(1));
        assert!(!arr.value(2));
    }

    // No matching child array: all three passes fail.
    // Union(Int32, Utf8) targeting Struct({x: Int32}). Neither child is an exact or
    // same-family match, and pass 3 skips nested targets such as Struct. So
    // can_cast_types returns false and cast returns an error. Tested for both
    // sparse and dense.
    #[test]
    fn test_no_match_errors() {
        let target = DataType::Struct(vec![Field::new("x", DataType::Int32, true)].into());

        // sparse
        assert!(!can_cast_types(
            &int_str_union_type(UnionMode::Sparse),
            &target
        ));

        let sparse = UnionArray::try_new(
            int_str_fields(),
            vec![0_i8, 1].into(),
            None,
            vec![
                Arc::new(Int32Array::from(vec![Some(42), None])) as ArrayRef,
                Arc::new(StringArray::from(vec![None, Some("hello")])),
            ],
        )
        .unwrap();

        assert!(cast::cast(&sparse, &target).is_err());

        // dense
        assert!(!can_cast_types(
            &int_str_union_type(UnionMode::Dense),
            &target
        ));

        let dense = UnionArray::try_new(
            int_str_fields(),
            vec![0_i8, 1].into(),
            Some(vec![0_i32, 0].into()),
            vec![
                Arc::new(Int32Array::from(vec![Some(42)])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("hello")])),
            ],
        )
        .unwrap();

        assert!(cast::cast(&dense, &target).is_err());
    }

    // Priority: exact match (pass 1) wins over family match (pass 2).
    // Union(Utf8, Utf8View) targeting Utf8View. Both child arrays are in the string
    // family, but Utf8View is an exact match. Pass 1 should pick it, not Utf8.
    #[test]
    fn test_exact_match_preferred_over_family() {
        let fields = UnionFields::try_new(
            [0, 1],
            [
                Field::new("a", DataType::Utf8, true),
                Field::new("b", DataType::Utf8View, true),
            ],
        )
        .unwrap();
        let target = DataType::Utf8View;

        assert!(can_cast_types(
            &DataType::Union(fields.clone(), UnionMode::Sparse),
            &target,
        ));

        // [Utf8("from_a"), Utf8View("from_b"), Utf8("also_a")]
        let union = UnionArray::try_new(
            fields,
            vec![0_i8, 1, 0].into(),
            None,
            vec![
                Arc::new(StringArray::from(vec![
                    Some("from_a"),
                    None,
                    Some("also_a"),
                ])) as ArrayRef,
                Arc::new(StringViewArray::from(vec![None, Some("from_b"), None])),
            ],
        )
        .unwrap();

        let result = cast::cast(&union, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        let arr = result.as_any().downcast_ref::<StringViewArray>().unwrap();

        // pass 1 picks child "b" (Utf8View), so child "a" rows become NULL
        assert!(arr.is_null(0));
        assert_eq!(arr.value(1), "from_b");
        assert!(arr.is_null(2));
    }

    // Null values within the selected child array stay null.
    // This is distinct from "wrong child array -> NULL": here the correct child array
    // is active but its value is null.
    #[test]
    fn test_null_in_selected_child_array() {
        let target = DataType::Utf8;

        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Sparse),
            &target
        ));

        // ["hello", NULL(str), "world"]
        // all rows are the Utf8 child array, but row 1 has a null value
        let union = UnionArray::try_new(
            int_str_fields(),
            vec![1_i8, 1, 1].into(),
            None,
            vec![
                Arc::new(Int32Array::from(vec![None, None, None])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("hello"), None, Some("world")])),
            ],
        )
        .unwrap();

        let result = cast::cast(&union, &target).unwrap();
        let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(arr.value(0), "hello");
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), "world");
    }

    // An empty union array returns a zero-length result of the target type.
    // Tested for both sparse and dense.
    #[test]
    fn test_empty_union() {
        let target = DataType::Utf8View;

        // sparse
        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Sparse),
            &target
        ));

        let sparse = UnionArray::try_new(
            int_str_fields(),
            Vec::<i8>::new().into(),
            None,
            vec![
                Arc::new(Int32Array::from(Vec::<Option<i32>>::new())) as ArrayRef,
                Arc::new(StringArray::from(Vec::<Option<&str>>::new())),
            ],
        )
        .unwrap();

        let result = cast::cast(&sparse, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        assert_eq!(result.len(), 0);

        // dense
        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Dense),
            &target
        ));

        let dense = UnionArray::try_new(
            int_str_fields(),
            Vec::<i8>::new().into(),
            Some(Vec::<i32>::new().into()),
            vec![
                Arc::new(Int32Array::from(Vec::<Option<i32>>::new())) as ArrayRef,
                Arc::new(StringArray::from(Vec::<Option<&str>>::new())),
            ],
        )
        .unwrap();

        let result = cast::cast(&dense, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        assert_eq!(result.len(), 0);
    }
}