1use crate::cast::can_cast_types;
21use crate::cast_with_options;
22use arrow_array::{Array, ArrayRef, UnionArray};
23use arrow_schema::{ArrowError, DataType, FieldRef, UnionFields};
24use arrow_select::union_extract::union_extract;
25
26use super::CastOptions;
27
28fn same_type_family(a: &DataType, b: &DataType) -> bool {
31 use DataType::*;
32 matches!(
33 (a, b),
34 (Utf8 | LargeUtf8 | Utf8View, Utf8 | LargeUtf8 | Utf8View)
35 | (
36 Binary | LargeBinary | BinaryView,
37 Binary | LargeBinary | BinaryView
38 )
39 | (Int8 | Int16 | Int32 | Int64, Int8 | Int16 | Int32 | Int64)
40 | (
41 UInt8 | UInt16 | UInt32 | UInt64,
42 UInt8 | UInt16 | UInt32 | UInt64
43 )
44 | (Float16 | Float32 | Float64, Float16 | Float32 | Float64)
45 )
46}
47
48pub(crate) fn resolve_child_array<'a>(
65 fields: &'a UnionFields,
66 target_type: &DataType,
67) -> Option<&'a FieldRef> {
68 fields
69 .iter()
70 .find(|(_, f)| f.data_type() == target_type)
71 .or_else(|| {
72 fields
73 .iter()
74 .find(|(_, f)| same_type_family(f.data_type(), target_type))
75 })
76 .or_else(|| {
77 if target_type.is_nested() {
81 return None;
82 }
83 fields
84 .iter()
85 .find(|(_, f)| can_cast_types(f.data_type(), target_type))
86 })
87 .map(|(_, f)| f)
88}
89
90pub fn union_extract_by_type(
131 union_array: &UnionArray,
132 target_type: &DataType,
133 cast_options: &CastOptions,
134) -> Result<ArrayRef, ArrowError> {
135 let fields = match union_array.data_type() {
136 DataType::Union(fields, _) => fields,
137 _ => unreachable!("union_extract_by_type called on non-union array"),
138 };
139
140 let Some(field) = resolve_child_array(fields, target_type) else {
141 return Err(ArrowError::CastError(format!(
142 "cannot cast Union with fields {} to {}",
143 fields
144 .iter()
145 .map(|(_, f)| f.data_type().to_string())
146 .collect::<Vec<_>>()
147 .join(", "),
148 target_type
149 )));
150 };
151
152 let extracted = union_extract(union_array, field.name())?;
153
154 if extracted.data_type() == target_type {
155 return Ok(extracted);
156 }
157
158 cast_with_options(&extracted, target_type, cast_options)
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cast;
    use arrow_array::*;
    use arrow_schema::{Field, UnionFields, UnionMode};
    use std::sync::Arc;

    /// Union fields shared by most tests: type id 0 is a nullable `Int32`
    /// child named "int", and type id 1 is a nullable `Utf8` child named "str".
    fn int_str_fields() -> UnionFields {
        UnionFields::try_new(
            [0, 1],
            [
                Field::new("int", DataType::Int32, true),
                Field::new("str", DataType::Utf8, true),
            ],
        )
        .unwrap()
    }

    /// The `DataType::Union` built from `int_str_fields` in the given mode.
    fn int_str_union_type(mode: UnionMode) -> DataType {
        DataType::Union(int_str_fields(), mode)
    }

    /// The target (`Utf8`) matches the "str" child exactly, so that child is
    /// extracted as-is. Rows holding the "int" variant become null. Covers
    /// both sparse and dense layouts.
    #[test]
    fn test_exact_type_match() {
        let target = DataType::Utf8;

        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Sparse),
            &target
        ));

        let sparse = UnionArray::try_new(
            int_str_fields(),
            vec![1_i8, 0, 1].into(),
            None,
            vec![
                Arc::new(Int32Array::from(vec![None, Some(42), None])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("hello"), None, Some("world")])),
            ],
        )
        .unwrap();

        let result = cast::cast(&sparse, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(arr.value(0), "hello");
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), "world");

        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Dense),
            &target
        ));

        let dense = UnionArray::try_new(
            int_str_fields(),
            vec![1_i8, 0, 1].into(),
            Some(vec![0_i32, 0, 1].into()),
            vec![
                Arc::new(Int32Array::from(vec![Some(42)])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("hello"), Some("world")])),
            ],
        )
        .unwrap();

        let result = cast::cast(&dense, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        assert_eq!(result.len(), 3);
        let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(arr.value(0), "hello");
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), "world");
    }

    /// There is no exact `Utf8View` child, so the same-family `Utf8` child is
    /// chosen and cast to `Utf8View`.
    #[test]
    fn test_same_family_utf8_to_utf8view() {
        let target = DataType::Utf8View;

        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Sparse),
            &target
        ));

        let sparse = UnionArray::try_new(
            int_str_fields(),
            vec![1_i8, 0, 1, 1].into(),
            None,
            vec![
                Arc::new(Int32Array::from(vec![None, Some(42), None, None])) as ArrayRef,
                Arc::new(StringArray::from(vec![
                    Some("agent_alpha"),
                    None,
                    Some("agent_beta"),
                    None,
                ])),
            ],
        )
        .unwrap();

        let result = cast::cast(&sparse, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        let arr = result.as_any().downcast_ref::<StringViewArray>().unwrap();
        assert_eq!(arr.value(0), "agent_alpha");
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), "agent_beta");
        assert!(arr.is_null(3));

        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Dense),
            &target
        ));

        let dense = UnionArray::try_new(
            int_str_fields(),
            vec![1_i8, 0, 1].into(),
            Some(vec![0_i32, 0, 1].into()),
            vec![
                Arc::new(Int32Array::from(vec![Some(42)])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("alpha"), Some("beta")])),
            ],
        )
        .unwrap();

        let result = cast::cast(&dense, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        assert_eq!(result.len(), 3);
        let arr = result.as_any().downcast_ref::<StringViewArray>().unwrap();
        assert_eq!(arr.value(0), "alpha");
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), "beta");
    }

    /// There is no exact or same-family child for `Boolean`. The first
    /// castable child ("int") is used, so 42 becomes true and 0 becomes false.
    #[test]
    fn test_one_directional_cast() {
        let target = DataType::Boolean;

        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Sparse),
            &target
        ));

        let sparse = UnionArray::try_new(
            int_str_fields(),
            vec![0_i8, 1, 0].into(),
            None,
            vec![
                Arc::new(Int32Array::from(vec![Some(42), None, Some(0)])) as ArrayRef,
                Arc::new(StringArray::from(vec![None, Some("hello"), None])),
            ],
        )
        .unwrap();

        let result = cast::cast(&sparse, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        let arr = result.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert!(arr.value(0));
        assert!(arr.is_null(1));
        assert!(!arr.value(2));

        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Dense),
            &target
        ));

        let dense = UnionArray::try_new(
            int_str_fields(),
            vec![0_i8, 1, 0].into(),
            Some(vec![0_i32, 0, 1].into()),
            vec![
                Arc::new(Int32Array::from(vec![Some(42), Some(0)])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("hello")])),
            ],
        )
        .unwrap();

        let result = cast::cast(&dense, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        assert_eq!(result.len(), 3);
        let arr = result.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert!(arr.value(0));
        assert!(arr.is_null(1));
        assert!(!arr.value(2));
    }

    /// A nested target with no matching child is rejected both by
    /// `can_cast_types` and by the cast itself, which must report a
    /// `CastError`.
    #[test]
    fn test_no_match_errors() {
        let target = DataType::Struct(vec![Field::new("x", DataType::Int32, true)].into());

        assert!(!can_cast_types(
            &int_str_union_type(UnionMode::Sparse),
            &target
        ));

        let union = UnionArray::try_new(
            int_str_fields(),
            vec![0_i8, 1].into(),
            None,
            vec![
                Arc::new(Int32Array::from(vec![Some(42), None])) as ArrayRef,
                Arc::new(StringArray::from(vec![None, Some("hello")])),
            ],
        )
        .unwrap();

        let result = cast::cast(&union, &target);
        assert!(
            matches!(result, Err(arrow_schema::ArrowError::CastError(_))),
            "expected a CastError, got {result:?}"
        );
    }

    /// When one child matches the target exactly and another is only in the
    /// same family, the exact match ("b", `Utf8View`) wins. Rows holding "a"
    /// therefore become null.
    #[test]
    fn test_exact_match_preferred_over_family() {
        let fields = UnionFields::try_new(
            [0, 1],
            [
                Field::new("a", DataType::Utf8, true),
                Field::new("b", DataType::Utf8View, true),
            ],
        )
        .unwrap();
        let target = DataType::Utf8View;

        assert!(can_cast_types(
            &DataType::Union(fields.clone(), UnionMode::Sparse),
            &target,
        ));

        let union = UnionArray::try_new(
            fields,
            vec![0_i8, 1, 0].into(),
            None,
            vec![
                Arc::new(StringArray::from(vec![
                    Some("from_a"),
                    None,
                    Some("also_a"),
                ])) as ArrayRef,
                Arc::new(StringViewArray::from(vec![None, Some("from_b"), None])),
            ],
        )
        .unwrap();

        let result = cast::cast(&union, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        let arr = result.as_any().downcast_ref::<StringViewArray>().unwrap();

        // Rows 0 and 2 select child "a", which was not chosen, so they are null.
        assert!(arr.is_null(0));
        assert_eq!(arr.value(1), "from_b");
        assert!(arr.is_null(2));
    }

    /// A null stored inside the selected child stays null in the output,
    /// independently of the union's type ids.
    #[test]
    fn test_null_in_selected_child_array() {
        let target = DataType::Utf8;

        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Sparse),
            &target
        ));

        // Every row selects the "str" child. Row 1 is null inside that child.
        let union = UnionArray::try_new(
            int_str_fields(),
            vec![1_i8, 1, 1].into(),
            None,
            vec![
                Arc::new(Int32Array::from(vec![None, None, None])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("hello"), None, Some("world")])),
            ],
        )
        .unwrap();

        let result = cast::cast(&union, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(arr.value(0), "hello");
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), "world");
    }

    /// An empty union casts to an empty array of the target type.
    #[test]
    fn test_empty_union() {
        let target = DataType::Utf8View;

        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Sparse),
            &target
        ));

        let union = UnionArray::try_new(
            int_str_fields(),
            Vec::<i8>::new().into(),
            None,
            vec![
                Arc::new(Int32Array::from(Vec::<Option<i32>>::new())) as ArrayRef,
                Arc::new(StringArray::from(Vec::<Option<&str>>::new())),
            ],
        )
        .unwrap();

        let result = cast::cast(&union, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        assert_eq!(result.len(), 0);
    }
}