1use crate::cast::can_cast_types;
21use crate::cast_with_options;
22use arrow_array::{Array, ArrayRef, UnionArray};
23use arrow_schema::{ArrowError, DataType, FieldRef, UnionFields};
24use arrow_select::union_extract::union_extract_by_id;
25
26use super::CastOptions;
27
28fn same_type_family(a: &DataType, b: &DataType) -> bool {
31 use DataType::*;
32 matches!(
33 (a, b),
34 (Utf8 | LargeUtf8 | Utf8View, Utf8 | LargeUtf8 | Utf8View)
35 | (
36 Binary | LargeBinary | BinaryView,
37 Binary | LargeBinary | BinaryView
38 )
39 | (Int8 | Int16 | Int32 | Int64, Int8 | Int16 | Int32 | Int64)
40 | (
41 UInt8 | UInt16 | UInt32 | UInt64,
42 UInt8 | UInt16 | UInt32 | UInt64
43 )
44 | (Float16 | Float32 | Float64, Float16 | Float32 | Float64)
45 )
46}
47
48pub(crate) fn resolve_child_array<'a>(
65 fields: &'a UnionFields,
66 target_type: &DataType,
67) -> Option<(i8, &'a FieldRef)> {
68 fields
69 .iter()
70 .find(|(_, f)| f.data_type() == target_type)
71 .or_else(|| {
72 fields
73 .iter()
74 .find(|(_, f)| same_type_family(f.data_type(), target_type))
75 })
76 .or_else(|| {
77 if target_type.is_nested() {
81 return None;
82 }
83 fields
84 .iter()
85 .find(|(_, f)| can_cast_types(f.data_type(), target_type))
86 })
87}
88
89pub fn union_extract_by_type(
130 union_array: &UnionArray,
131 target_type: &DataType,
132 cast_options: &CastOptions,
133) -> Result<ArrayRef, ArrowError> {
134 let fields = match union_array.data_type() {
135 DataType::Union(fields, _) => fields,
136 _ => unreachable!("union_extract_by_type called on non-union array"),
137 };
138
139 let Some((type_id, _)) = resolve_child_array(fields, target_type) else {
140 return Err(ArrowError::CastError(format!(
141 "cannot cast Union with fields {} to {}",
142 fields
143 .iter()
144 .map(|(_, f)| f.data_type().to_string())
145 .collect::<Vec<_>>()
146 .join(", "),
147 target_type
148 )));
149 };
150
151 let extracted = union_extract_by_id(union_array, type_id)?;
152
153 if extracted.data_type() == target_type {
154 return Ok(extracted);
155 }
156
157 cast_with_options(&extracted, target_type, cast_options)
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cast;
    use arrow_array::*;
    use arrow_schema::{Field, UnionFields, UnionMode};
    use std::sync::Arc;

    /// Two-child union fixture: type id 0 is a nullable `Int32` named "int",
    /// type id 1 is a nullable `Utf8` named "str".
    fn int_str_fields() -> UnionFields {
        UnionFields::try_new(
            [0, 1],
            [
                Field::new("int", DataType::Int32, true),
                Field::new("str", DataType::Utf8, true),
            ],
        )
        .unwrap()
    }

    /// `DataType::Union` over `int_str_fields` in the requested mode.
    fn int_str_union_type(mode: UnionMode) -> DataType {
        DataType::Union(int_str_fields(), mode)
    }

    /// Tier 1: the `Utf8` child matches a `Utf8` target exactly, for both
    /// sparse and dense layouts.
    #[test]
    fn test_exact_type_match() {
        let target = DataType::Utf8;

        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Sparse),
            &target
        ));

        let sparse = UnionArray::try_new(
            int_str_fields(),
            vec![1_i8, 0, 1].into(),
            None,
            vec![
                Arc::new(Int32Array::from(vec![None, Some(42), None])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("hello"), None, Some("world")])),
            ],
        )
        .unwrap();

        let result = cast::cast(&sparse, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(arr.value(0), "hello");
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), "world");

        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Dense),
            &target
        ));

        let dense = UnionArray::try_new(
            int_str_fields(),
            vec![1_i8, 0, 1].into(),
            Some(vec![0_i32, 0, 1].into()),
            vec![
                Arc::new(Int32Array::from(vec![Some(42)])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("hello"), Some("world")])),
            ],
        )
        .unwrap();

        let result = cast::cast(&dense, &target).unwrap();
        let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(arr.value(0), "hello");
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), "world");
    }

    /// Tier 2: with no exact `Utf8View` child, the `Utf8` child (same string
    /// family) is extracted and cast to `Utf8View`.
    #[test]
    fn test_same_family_utf8_to_utf8view() {
        let target = DataType::Utf8View;

        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Sparse),
            &target
        ));

        let sparse = UnionArray::try_new(
            int_str_fields(),
            vec![1_i8, 0, 1, 1].into(),
            None,
            vec![
                Arc::new(Int32Array::from(vec![None, Some(42), None, None])) as ArrayRef,
                Arc::new(StringArray::from(vec![
                    Some("agent_alpha"),
                    None,
                    Some("agent_beta"),
                    None,
                ])),
            ],
        )
        .unwrap();

        let result = cast::cast(&sparse, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        let arr = result.as_any().downcast_ref::<StringViewArray>().unwrap();
        assert_eq!(arr.value(0), "agent_alpha");
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), "agent_beta");
        assert!(arr.is_null(3));

        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Dense),
            &target
        ));

        let dense = UnionArray::try_new(
            int_str_fields(),
            vec![1_i8, 0, 1].into(),
            Some(vec![0_i32, 0, 1].into()),
            vec![
                Arc::new(Int32Array::from(vec![Some(42)])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("alpha"), Some("beta")])),
            ],
        )
        .unwrap();

        let result = cast::cast(&dense, &target).unwrap();
        let arr = result.as_any().downcast_ref::<StringViewArray>().unwrap();
        assert_eq!(arr.value(0), "alpha");
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), "beta");
    }

    /// Tier 3: no child is `Boolean` or in its family, so the first castable
    /// child (`Int32`) is used; 42 casts to true and 0 to false.
    #[test]
    fn test_one_directional_cast() {
        let target = DataType::Boolean;

        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Sparse),
            &target
        ));

        let sparse = UnionArray::try_new(
            int_str_fields(),
            vec![0_i8, 1, 0].into(),
            None,
            vec![
                Arc::new(Int32Array::from(vec![Some(42), None, Some(0)])) as ArrayRef,
                Arc::new(StringArray::from(vec![None, Some("hello"), None])),
            ],
        )
        .unwrap();

        let result = cast::cast(&sparse, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        let arr = result.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert!(arr.value(0));
        assert!(arr.is_null(1));
        assert!(!arr.value(2));

        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Dense),
            &target
        ));

        let dense = UnionArray::try_new(
            int_str_fields(),
            vec![0_i8, 1, 0].into(),
            Some(vec![0_i32, 0, 1].into()),
            vec![
                Arc::new(Int32Array::from(vec![Some(42), Some(0)])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("hello")])),
            ],
        )
        .unwrap();

        let result = cast::cast(&dense, &target).unwrap();
        let arr = result.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert!(arr.value(0));
        assert!(arr.is_null(1));
        assert!(!arr.value(2));
    }

    /// Children are selected by type id, so two children sharing the field
    /// name "val" do not confuse the extraction.
    #[test]
    fn test_duplicate_field_names() {
        let fields = UnionFields::try_new(
            [0, 1],
            [
                Field::new("val", DataType::Int32, true),
                Field::new("val", DataType::Utf8, true),
            ],
        )
        .unwrap();

        let target = DataType::Utf8;

        let sparse = UnionArray::try_new(
            fields.clone(),
            vec![0_i8, 1, 0, 1].into(),
            None,
            vec![
                Arc::new(Int32Array::from(vec![Some(42), None, Some(99), None])) as ArrayRef,
                Arc::new(StringArray::from(vec![
                    None,
                    Some("hello"),
                    None,
                    Some("world"),
                ])),
            ],
        )
        .unwrap();

        let result = cast::cast(&sparse, &target).unwrap();
        let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
        assert!(arr.is_null(0));
        assert_eq!(arr.value(1), "hello");
        assert!(arr.is_null(2));
        assert_eq!(arr.value(3), "world");

        let dense = UnionArray::try_new(
            fields,
            vec![0_i8, 1, 1].into(),
            Some(vec![0_i32, 0, 1].into()),
            vec![
                Arc::new(Int32Array::from(vec![Some(42)])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("hello"), Some("world")])),
            ],
        )
        .unwrap();

        let result = cast::cast(&dense, &target).unwrap();
        let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
        assert!(arr.is_null(0));
        assert_eq!(arr.value(1), "hello");
        assert_eq!(arr.value(2), "world");
    }

    /// A `Struct` target matches no child: `can_cast_types` reports false and
    /// the cast returns an error instead of panicking.
    #[test]
    fn test_no_match_errors() {
        let target = DataType::Struct(vec![Field::new("x", DataType::Int32, true)].into());

        assert!(!can_cast_types(
            &int_str_union_type(UnionMode::Sparse),
            &target
        ));

        let union = UnionArray::try_new(
            int_str_fields(),
            vec![0_i8, 1].into(),
            None,
            vec![
                Arc::new(Int32Array::from(vec![Some(42), None])) as ArrayRef,
                Arc::new(StringArray::from(vec![None, Some("hello")])),
            ],
        )
        .unwrap();

        assert!(cast::cast(&union, &target).is_err());
    }

    /// When both an exact (`Utf8View`) and a same-family (`Utf8`) child exist,
    /// the exact one wins even though it is declared second.
    #[test]
    fn test_exact_match_preferred_over_family() {
        let fields = UnionFields::try_new(
            [0, 1],
            [
                Field::new("a", DataType::Utf8, true),
                Field::new("b", DataType::Utf8View, true),
            ],
        )
        .unwrap();
        let target = DataType::Utf8View;

        assert!(can_cast_types(
            &DataType::Union(fields.clone(), UnionMode::Sparse),
            &target,
        ));

        let union = UnionArray::try_new(
            fields,
            vec![0_i8, 1, 0].into(),
            None,
            vec![
                Arc::new(StringArray::from(vec![
                    Some("from_a"),
                    None,
                    Some("also_a"),
                ])) as ArrayRef,
                Arc::new(StringViewArray::from(vec![None, Some("from_b"), None])),
            ],
        )
        .unwrap();

        let result = cast::cast(&union, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        let arr = result.as_any().downcast_ref::<StringViewArray>().unwrap();

        // Rows 0 and 2 belong to child "a", which was not selected, so they
        // are null rather than carrying "from_a"/"also_a".
        assert!(arr.is_null(0));
        assert_eq!(arr.value(1), "from_b");
        assert!(arr.is_null(2));
    }

    /// Nulls stored inside the selected child survive extraction.
    #[test]
    fn test_null_in_selected_child_array() {
        let target = DataType::Utf8;

        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Sparse),
            &target
        ));

        let union = UnionArray::try_new(
            int_str_fields(),
            vec![1_i8, 1, 1].into(),
            None,
            vec![
                Arc::new(Int32Array::from(vec![None, None, None])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("hello"), None, Some("world")])),
            ],
        )
        .unwrap();

        let result = cast::cast(&union, &target).unwrap();
        let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(arr.value(0), "hello");
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), "world");
    }

    /// A zero-length union casts to an empty array of the target type.
    #[test]
    fn test_empty_union() {
        let target = DataType::Utf8View;

        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Sparse),
            &target
        ));

        let union = UnionArray::try_new(
            int_str_fields(),
            Vec::<i8>::new().into(),
            None,
            vec![
                Arc::new(Int32Array::from(Vec::<Option<i32>>::new())) as ArrayRef,
                Arc::new(StringArray::from(Vec::<Option<&str>>::new())),
            ],
        )
        .unwrap();

        let result = cast::cast(&union, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        assert_eq!(result.len(), 0);
    }

    /// Tier 2 for integers: with no `Int64` child, the signed-integer `Int32`
    /// child is selected and widened by the cast.
    #[test]
    fn test_same_family_int_widening() {
        let target = DataType::Int64;

        let (type_id, field) = resolve_child_array(&int_str_fields(), &target).unwrap();
        assert_eq!(type_id, 0);
        assert_eq!(field.name(), "int");

        let sparse = UnionArray::try_new(
            int_str_fields(),
            vec![0_i8, 1, 0].into(),
            None,
            vec![
                Arc::new(Int32Array::from(vec![Some(42), None, Some(7)])) as ArrayRef,
                Arc::new(StringArray::from(vec![None, Some("hello"), None])),
            ],
        )
        .unwrap();

        let result = union_extract_by_type(&sparse, &target, &CastOptions::default()).unwrap();
        assert_eq!(result.data_type(), &target);
        let arr = result.as_any().downcast_ref::<Int64Array>().unwrap();
        assert_eq!(arr.value(0), 42);
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), 7);
    }

    /// Tier 3 is skipped for nested targets: a `List` target resolves to no
    /// child, even if a scalar child might otherwise be castable to it.
    #[test]
    fn test_nested_target_not_resolved() {
        let target = DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
        assert!(resolve_child_array(&int_str_fields(), &target).is_none());
    }
}