Skip to main content

arrow_cast/cast/
union.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Cast support for union arrays.
19
20use crate::cast::can_cast_types;
21use crate::cast_with_options;
22use arrow_array::{Array, ArrayRef, UnionArray};
23use arrow_schema::{ArrowError, DataType, FieldRef, UnionFields};
24use arrow_select::union_extract::union_extract_by_id;
25
26use super::CastOptions;
27
28// this is used during child array selection to prefer a "close" type over a distant cast
29// for example: when targeting Utf8View, a Utf8 child is preferred over Int32 despite both being castable
30fn same_type_family(a: &DataType, b: &DataType) -> bool {
31    use DataType::*;
32    matches!(
33        (a, b),
34        (Utf8 | LargeUtf8 | Utf8View, Utf8 | LargeUtf8 | Utf8View)
35            | (
36                Binary | LargeBinary | BinaryView,
37                Binary | LargeBinary | BinaryView
38            )
39            | (Int8 | Int16 | Int32 | Int64, Int8 | Int16 | Int32 | Int64)
40            | (
41                UInt8 | UInt16 | UInt32 | UInt64,
42                UInt8 | UInt16 | UInt32 | UInt64
43            )
44            | (Float16 | Float32 | Float64, Float16 | Float32 | Float64)
45    )
46}
47
48/// Selects the best-matching child array from a [`UnionArray`] for a given target type
49///
50/// The goal is to find the source field whose type is closest to the target,
51/// so that the subsequent cast is as lossless as possible. The heuristic uses
52/// three passes with decreasing specificity:
53///
54/// 1. **Exact match**: field type equals the target type.
55/// 2. **Same type family**: field and target belong to the same logical family
56///    (e.g. `Utf8` and `Utf8View` are both strings). This avoids a greedy
57///    cross-family cast in pass 3 (e.g. picking `Int32` over `Utf8` when the
58///    target is `Utf8View`, since `can_cast_types(Int32, Utf8View)` is true)
59/// 3. **Castable**:`can_cast_types` reports the field can be cast to the target
60///    Nested target types are skipped here because union extraction introduces
61///    nulls, which can conflict with non-nullable inner fields
62///
63/// Each pass greedily picks the first matching field by type_id order
64pub(crate) fn resolve_child_array<'a>(
65    fields: &'a UnionFields,
66    target_type: &DataType,
67) -> Option<(i8, &'a FieldRef)> {
68    fields
69        .iter()
70        .find(|(_, f)| f.data_type() == target_type)
71        .or_else(|| {
72            fields
73                .iter()
74                .find(|(_, f)| same_type_family(f.data_type(), target_type))
75        })
76        .or_else(|| {
77            // skip nested types in pass 3 because union extraction introduces nulls,
78            // and casting nullable arrays to nested types like List/Struct/Map can fail
79            // when inner fields are non-nullable.
80            if target_type.is_nested() {
81                return None;
82            }
83            fields
84                .iter()
85                .find(|(_, f)| can_cast_types(f.data_type(), target_type))
86        })
87}
88
89/// Extracts the best-matching child array from a [`UnionArray`] for a given target type,
90/// and casts it to that type.
91///
92/// Rows where a different child array is active become NULL.
93/// If no child array matches, returns an error.
94///
95/// # Example
96///
97/// ```
98/// # use std::sync::Arc;
99/// # use arrow_schema::{DataType, Field, UnionFields};
100/// # use arrow_array::{UnionArray, StringArray, Int32Array, Array};
101/// # use arrow_cast::cast::union_extract_by_type;
102/// # use arrow_cast::CastOptions;
103/// let fields = UnionFields::try_new(
104///     [0, 1],
105///     [
106///         Field::new("int", DataType::Int32, true),
107///         Field::new("str", DataType::Utf8, true),
108///     ],
109/// ).unwrap();
110///
111/// let union = UnionArray::try_new(
112///     fields,
113///     vec![0, 1, 0].into(),
114///     None,
115///     vec![
116///         Arc::new(Int32Array::from(vec![Some(42), None, Some(99)])),
117///         Arc::new(StringArray::from(vec![None, Some("hello"), None])),
118///     ],
119/// )
120/// .unwrap();
121///
122/// // extract the Utf8 child array and cast to Utf8View
123/// let result = union_extract_by_type(&union, &DataType::Utf8View, &CastOptions::default()).unwrap();
124/// assert_eq!(result.data_type(), &DataType::Utf8View);
125/// assert!(result.is_null(0));   // Int32 row -> NULL
126/// assert!(!result.is_null(1));  // Utf8 row -> "hello"
127/// assert!(result.is_null(2));   // Int32 row -> NULL
128/// ```
129pub fn union_extract_by_type(
130    union_array: &UnionArray,
131    target_type: &DataType,
132    cast_options: &CastOptions,
133) -> Result<ArrayRef, ArrowError> {
134    let fields = match union_array.data_type() {
135        DataType::Union(fields, _) => fields,
136        _ => unreachable!("union_extract_by_type called on non-union array"),
137    };
138
139    let Some((type_id, _)) = resolve_child_array(fields, target_type) else {
140        return Err(ArrowError::CastError(format!(
141            "cannot cast Union with fields {} to {}",
142            fields
143                .iter()
144                .map(|(_, f)| f.data_type().to_string())
145                .collect::<Vec<_>>()
146                .join(", "),
147            target_type
148        )));
149    };
150
151    let extracted = union_extract_by_id(union_array, type_id)?;
152
153    if extracted.data_type() == target_type {
154        return Ok(extracted);
155    }
156
157    cast_with_options(&extracted, target_type, cast_options)
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cast;
    use arrow_array::*;
    use arrow_schema::{Field, UnionFields, UnionMode};
    use std::sync::Arc;

    /// Fields shared by most tests: Int32 under type_id 0 and Utf8 under type_id 1.
    fn int_str_fields() -> UnionFields {
        let children = [
            Field::new("int", DataType::Int32, true),
            Field::new("str", DataType::Utf8, true),
        ];
        UnionFields::try_new([0, 1], children).unwrap()
    }

    /// Union data type over `int_str_fields()` in the requested mode.
    fn int_str_union_type(mode: UnionMode) -> DataType {
        DataType::Union(int_str_fields(), mode)
    }

    /// Asserts that the Int32/Utf8 union in `mode` is reported as castable to `target`.
    fn assert_int_str_castable(mode: UnionMode, target: &DataType) {
        assert!(can_cast_types(&int_str_union_type(mode), target));
    }

    /// Int32 child array from optional values.
    fn int32_child(values: Vec<Option<i32>>) -> ArrayRef {
        Arc::new(Int32Array::from(values))
    }

    /// Utf8 child array from optional values.
    fn utf8_child(values: Vec<Option<&str>>) -> ArrayRef {
        Arc::new(StringArray::from(values))
    }

    /// Sparse union: every child has one slot per row, no offsets buffer.
    fn sparse_union(fields: UnionFields, type_ids: Vec<i8>, children: Vec<ArrayRef>) -> UnionArray {
        UnionArray::try_new(fields, type_ids.into(), None, children).unwrap()
    }

    /// Dense union: `offsets` points each row into its active child.
    fn dense_union(
        fields: UnionFields,
        type_ids: Vec<i8>,
        offsets: Vec<i32>,
        children: Vec<ArrayRef>,
    ) -> UnionArray {
        UnionArray::try_new(fields, type_ids.into(), Some(offsets.into()), children).unwrap()
    }

    /// Views a cast result as a `StringArray`, panicking on any other type.
    fn as_utf8(array: &ArrayRef) -> &StringArray {
        array.as_any().downcast_ref::<StringArray>().unwrap()
    }

    /// Views a cast result as a `StringViewArray`, panicking on any other type.
    fn as_utf8_view(array: &ArrayRef) -> &StringViewArray {
        array.as_any().downcast_ref::<StringViewArray>().unwrap()
    }

    /// Views a cast result as a `BooleanArray`, panicking on any other type.
    fn as_bool(array: &ArrayRef) -> &BooleanArray {
        array.as_any().downcast_ref::<BooleanArray>().unwrap()
    }

    // pass 1: exact type match.
    // Union(Int32, Utf8) -> Utf8: the Utf8 child is an exact match, so rows where
    // Int32 is active become NULL. Covers sparse and dense layouts.
    #[test]
    fn test_exact_type_match() {
        let target = DataType::Utf8;

        // sparse
        assert_int_str_castable(UnionMode::Sparse, &target);

        let sparse = sparse_union(
            int_str_fields(),
            vec![1, 0, 1],
            vec![
                int32_child(vec![None, Some(42), None]),
                utf8_child(vec![Some("hello"), None, Some("world")]),
            ],
        );

        let result = cast::cast(&sparse, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        let arr = as_utf8(&result);
        assert_eq!(arr.value(0), "hello");
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), "world");

        // dense
        assert_int_str_castable(UnionMode::Dense, &target);

        let dense = dense_union(
            int_str_fields(),
            vec![1, 0, 1],
            vec![0, 0, 1],
            vec![
                int32_child(vec![Some(42)]),
                utf8_child(vec![Some("hello"), Some("world")]),
            ],
        );

        let result = cast::cast(&dense, &target).unwrap();
        let arr = as_utf8(&result);
        assert_eq!(arr.value(0), "hello");
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), "world");
    }

    // pass 2: same type family match.
    // Union(Int32, Utf8) -> Utf8View: there is no exact match, but Utf8 shares a
    // family with Utf8View, so the Utf8 child is selected and cast to Utf8View.
    // This is the bug that motivated the work: without pass 2, pass 3 would
    // greedily pick Int32 (since can_cast_types(Int32, Utf8View) is true).
    #[test]
    fn test_same_family_utf8_to_utf8view() {
        let target = DataType::Utf8View;

        // sparse
        assert_int_str_castable(UnionMode::Sparse, &target);

        let sparse = sparse_union(
            int_str_fields(),
            vec![1, 0, 1, 1],
            vec![
                int32_child(vec![None, Some(42), None, None]),
                utf8_child(vec![Some("agent_alpha"), None, Some("agent_beta"), None]),
            ],
        );

        let result = cast::cast(&sparse, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        let arr = as_utf8_view(&result);
        assert_eq!(arr.value(0), "agent_alpha");
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), "agent_beta");
        assert!(arr.is_null(3));

        // dense
        assert_int_str_castable(UnionMode::Dense, &target);

        let dense = dense_union(
            int_str_fields(),
            vec![1, 0, 1],
            vec![0, 0, 1],
            vec![
                int32_child(vec![Some(42)]),
                utf8_child(vec![Some("alpha"), Some("beta")]),
            ],
        );

        let result = cast::cast(&dense, &target).unwrap();
        let arr = as_utf8_view(&result);
        assert_eq!(arr.value(0), "alpha");
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), "beta");
    }

    // pass 3: one-directional cast across type families.
    // Union(Int32, Utf8) -> Boolean: no exact match and no family match. Pass 3
    // takes Int32 (the first child for which can_cast_types is true) and casts it
    // to Boolean (0 -> false, nonzero -> true). Rows where Utf8 is active become NULL.
    #[test]
    fn test_one_directional_cast() {
        let target = DataType::Boolean;

        // sparse
        assert_int_str_castable(UnionMode::Sparse, &target);

        let sparse = sparse_union(
            int_str_fields(),
            vec![0, 1, 0],
            vec![
                int32_child(vec![Some(42), None, Some(0)]),
                utf8_child(vec![None, Some("hello"), None]),
            ],
        );

        let result = cast::cast(&sparse, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        let arr = as_bool(&result);
        assert!(arr.value(0));
        assert!(arr.is_null(1));
        assert!(!arr.value(2));

        // dense
        assert_int_str_castable(UnionMode::Dense, &target);

        let dense = dense_union(
            int_str_fields(),
            vec![0, 1, 0],
            vec![0, 0, 1],
            vec![
                int32_child(vec![Some(42), Some(0)]),
                utf8_child(vec![Some("hello")]),
            ],
        );

        let result = cast::cast(&dense, &target).unwrap();
        let arr = as_bool(&result);
        assert!(arr.value(0));
        assert!(arr.is_null(1));
        assert!(!arr.value(2));
    }

    // duplicate field names: resolution must go by type_id, not by field name.
    // Both children are called "val": Int32 (type_id 0) and Utf8 (type_id 1).
    // Casting to Utf8 has to select the Utf8 child (type_id 1), not Int32 (type_id 0).
    #[test]
    fn test_duplicate_field_names() {
        let fields = UnionFields::try_new(
            [0, 1],
            [
                Field::new("val", DataType::Int32, true),
                Field::new("val", DataType::Utf8, true),
            ],
        )
        .unwrap();

        let target = DataType::Utf8;

        let sparse = sparse_union(
            fields.clone(),
            vec![0, 1, 0, 1],
            vec![
                int32_child(vec![Some(42), None, Some(99), None]),
                utf8_child(vec![None, Some("hello"), None, Some("world")]),
            ],
        );

        let result = cast::cast(&sparse, &target).unwrap();
        let arr = as_utf8(&result);
        assert!(arr.is_null(0));
        assert_eq!(arr.value(1), "hello");
        assert!(arr.is_null(2));
        assert_eq!(arr.value(3), "world");

        let dense = dense_union(
            fields,
            vec![0, 1, 1],
            vec![0, 0, 1],
            vec![
                int32_child(vec![Some(42)]),
                utf8_child(vec![Some("hello"), Some("world")]),
            ],
        );

        let result = cast::cast(&dense, &target).unwrap();
        let arr = as_utf8(&result);
        assert!(arr.is_null(0));
        assert_eq!(arr.value(1), "hello");
        assert_eq!(arr.value(2), "world");
    }

    // no matching child array: all three passes come up empty.
    // Union(Int32, Utf8) -> Struct({x: Int32}). Neither Int32 nor Utf8 can be cast
    // to a Struct, so can_cast_types is false and cast returns an error.
    #[test]
    fn test_no_match_errors() {
        let target = DataType::Struct(vec![Field::new("x", DataType::Int32, true)].into());

        assert!(!can_cast_types(
            &int_str_union_type(UnionMode::Sparse),
            &target
        ));

        let union = sparse_union(
            int_str_fields(),
            vec![0, 1],
            vec![
                int32_child(vec![Some(42), None]),
                utf8_child(vec![None, Some("hello")]),
            ],
        );

        assert!(cast::cast(&union, &target).is_err());
    }

    // priority: an exact match (pass 1) beats a family match (pass 2).
    // Union(Utf8, Utf8View) -> Utf8View. Both children are in the string family,
    // but only Utf8View matches exactly, so pass 1 must select it over Utf8.
    #[test]
    fn test_exact_match_preferred_over_family() {
        let fields = UnionFields::try_new(
            [0, 1],
            [
                Field::new("a", DataType::Utf8, true),
                Field::new("b", DataType::Utf8View, true),
            ],
        )
        .unwrap();
        let target = DataType::Utf8View;

        assert!(can_cast_types(
            &DataType::Union(fields.clone(), UnionMode::Sparse),
            &target,
        ));

        // [Utf8("from_a"), Utf8View("from_b"), Utf8("also_a")]
        let view_child: ArrayRef =
            Arc::new(StringViewArray::from(vec![None, Some("from_b"), None]));
        let union = sparse_union(
            fields,
            vec![0, 1, 0],
            vec![
                utf8_child(vec![Some("from_a"), None, Some("also_a")]),
                view_child,
            ],
        );

        let result = cast::cast(&union, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        let arr = as_utf8_view(&result);

        // pass 1 selects child "b" (Utf8View), so rows of child "a" become NULL
        assert!(arr.is_null(0));
        assert_eq!(arr.value(1), "from_b");
        assert!(arr.is_null(2));
    }

    // null values inside the selected child array stay null.
    // Unlike "wrong child array -> NULL", here the right child is active but the
    // value it holds is itself null.
    #[test]
    fn test_null_in_selected_child_array() {
        let target = DataType::Utf8;

        assert_int_str_castable(UnionMode::Sparse, &target);

        // ["hello", NULL(str), "world"]
        // every row uses the Utf8 child array, and row 1 holds a null value
        let union = sparse_union(
            int_str_fields(),
            vec![1, 1, 1],
            vec![
                int32_child(vec![None, None, None]),
                utf8_child(vec![Some("hello"), None, Some("world")]),
            ],
        );

        let result = cast::cast(&union, &target).unwrap();
        let arr = as_utf8(&result);
        assert_eq!(arr.value(0), "hello");
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), "world");
    }

    // an empty union array yields a zero-length result of the target type.
    #[test]
    fn test_empty_union() {
        let target = DataType::Utf8View;

        assert_int_str_castable(UnionMode::Sparse, &target);

        let union = sparse_union(
            int_str_fields(),
            Vec::new(),
            vec![int32_child(Vec::new()), utf8_child(Vec::new())],
        );

        let result = cast::cast(&union, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        assert_eq!(result.len(), 0);
    }
}