Skip to main content

arrow_select/
union_extract.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines union_extract kernel for [UnionArray]
19
20use crate::take::take;
21use arrow_array::{
22    Array, ArrayRef, BooleanArray, Int32Array, Scalar, UnionArray, make_array, new_empty_array,
23    new_null_array,
24};
25use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer, ScalarBuffer, bit_util};
26use arrow_data::layout;
27use arrow_schema::{ArrowError, DataType, UnionFields};
28use std::cmp::Ordering;
29use std::sync::Arc;
30
31/// Returns the value of the target field when selected, or NULL otherwise.
32/// ```text
33/// ┌─────────────────┐                                   ┌─────────────────┐
34/// │       A=1       │                                   │        1        │
35/// ├─────────────────┤                                   ├─────────────────┤
36/// │      A=NULL     │                                   │       NULL      │
37/// ├─────────────────┤    union_extract(values, 'A')     ├─────────────────┤
38/// │      B='t'      │  ────────────────────────────▶    │       NULL      │
39/// ├─────────────────┤                                   ├─────────────────┤
40/// │       A=3       │                                   │        3        │
41/// ├─────────────────┤                                   ├─────────────────┤
42/// │      B=NULL     │                                   │       NULL      │
43/// └─────────────────┘                                   └─────────────────┘
44///    union array                                              result
45/// ```
46/// # Errors
47///
48/// Returns error if target field is not found
49///
50/// # Examples
51/// ```
52/// # use std::sync::Arc;
53/// # use arrow_schema::{DataType, Field, UnionFields};
54/// # use arrow_array::{UnionArray, StringArray, Int32Array};
55/// # use arrow_select::union_extract::union_extract;
56/// let fields = UnionFields::try_new(
57///     [1, 3],
58///     [
59///         Field::new("A", DataType::Int32, true),
60///         Field::new("B", DataType::Utf8, true)
61///     ]
62/// ).unwrap();
63///
64/// let union = UnionArray::try_new(
65///     fields,
66///     vec![1, 1, 3, 1, 3].into(),
67///     None,
68///     vec![
69///         Arc::new(Int32Array::from(vec![Some(1), None, None, Some(3), Some(0)])),
70///         Arc::new(StringArray::from(vec![None, None, Some("t"), Some("."), None]))
71///     ]
72/// ).unwrap();
73///
74/// // Extract field A
75/// let extracted = union_extract(&union, "A").unwrap();
76///
77/// assert_eq!(*extracted, Int32Array::from(vec![Some(1), None, None, Some(3), None]));
78/// ```
79pub fn union_extract(union_array: &UnionArray, target: &str) -> Result<ArrayRef, ArrowError> {
80    let fields = match union_array.data_type() {
81        DataType::Union(fields, _) => fields,
82        _ => unreachable!(),
83    };
84
85    let (target_type_id, _) = fields
86        .iter()
87        .find(|field| field.1.name() == target)
88        .ok_or_else(|| {
89            ArrowError::InvalidArgumentError(format!("field {target} not found on union"))
90        })?;
91
92    union_extract_impl(union_array, fields, target_type_id)
93}
94
95/// Like [`union_extract`], but selects the child by `type_id` rather than by
96/// field name.
97///
98/// This avoids ambiguity when the union contains duplicate field names.
99///
100/// # Errors
101///
102/// Returns error if `target_type_id` does not correspond to a field in the union.
103pub fn union_extract_by_id(
104    union_array: &UnionArray,
105    target_type_id: i8,
106) -> Result<ArrayRef, ArrowError> {
107    let fields = match union_array.data_type() {
108        DataType::Union(fields, _) => fields,
109        _ => unreachable!(),
110    };
111
112    if fields.iter().all(|(id, _)| id != target_type_id) {
113        return Err(ArrowError::InvalidArgumentError(format!(
114            "type_id {target_type_id} not found on union"
115        )));
116    }
117
118    union_extract_impl(union_array, fields, target_type_id)
119}
120
121fn union_extract_impl(
122    union_array: &UnionArray,
123    fields: &UnionFields,
124    target_type_id: i8,
125) -> Result<ArrayRef, ArrowError> {
126    match union_array.offsets() {
127        Some(_) => extract_dense(union_array, fields, target_type_id),
128        None => extract_sparse(union_array, fields, target_type_id),
129    }
130}
131
132fn extract_sparse(
133    union_array: &UnionArray,
134    fields: &UnionFields,
135    target_type_id: i8,
136) -> Result<ArrayRef, ArrowError> {
137    let target = union_array.child(target_type_id);
138
139    if fields.len() == 1 // case 1.1: if there is a single field, all type ids are the same, and since union doesn't have a null mask, the result array is exactly the same as it only child
140        || union_array.is_empty() // case 1.2: sparse union length and childrens length must match, if the union is empty, so is any children
141        || target.null_count() == target.len() || target.data_type().is_null()
142    // case 1.3: if all values of the target children are null, regardless of selected type ids, the result will also be completely null
143    {
144        Ok(Arc::clone(target))
145    } else {
146        match eq_scalar(union_array.type_ids(), target_type_id) {
147            // case 2: all type ids equals our target, and since unions doesn't have a null mask, the result array is exactly the same as our target
148            BoolValue::Scalar(true) => Ok(Arc::clone(target)),
149            // case 3: none type_id matches our target, the result is a null array
150            BoolValue::Scalar(false) => {
151                if layout(target.data_type()).can_contain_null_mask {
152                    // case 3.1: target array can contain a null mask
153                    //SAFETY: The only change to the array data is the addition of a null mask, and if the target data type can contain a null mask was just checked above
154                    let data = unsafe {
155                        target
156                            .into_data()
157                            .into_builder()
158                            .nulls(Some(NullBuffer::new_null(target.len())))
159                            .build_unchecked()
160                    };
161
162                    Ok(make_array(data))
163                } else {
164                    // case 3.2: target can't contain a null mask
165                    Ok(new_null_array(target.data_type(), target.len()))
166                }
167            }
168            // case 4: some but not all type_id matches our target
169            BoolValue::Buffer(selected) => {
170                if layout(target.data_type()).can_contain_null_mask {
171                    // case 4.1: target array can contain a null mask
172                    let nulls = match target.nulls().filter(|n| n.null_count() > 0) {
173                        // case 4.1.1: our target child has nulls and types other than our target are selected, union the masks
174                        // the case where n.null_count() == n.len() is cheaply handled at case 1.3
175                        Some(nulls) => &selected & nulls.inner(),
176                        // case 4.1.2: target child has no nulls, but types other than our target are selected, use the selected mask as a null mask
177                        None => selected,
178                    };
179
180                    //SAFETY: The only change to the array data is the addition of a null mask, and if the target data type can contain a null mask was just checked above
181                    let data = unsafe {
182                        assert_eq!(nulls.len(), target.len());
183
184                        target
185                            .into_data()
186                            .into_builder()
187                            .nulls(Some(nulls.into()))
188                            .build_unchecked()
189                    };
190
191                    Ok(make_array(data))
192                } else {
193                    // case 4.2: target can't containt a null mask, zip the values that match with a null value
194                    Ok(crate::zip::zip(
195                        &BooleanArray::new(selected, None),
196                        target,
197                        &Scalar::new(new_null_array(target.data_type(), 1)),
198                    )?)
199                }
200            }
201        }
202    }
203}
204
205fn extract_dense(
206    union_array: &UnionArray,
207    fields: &UnionFields,
208    target_type_id: i8,
209) -> Result<ArrayRef, ArrowError> {
210    let target = union_array.child(target_type_id);
211    let offsets = union_array.offsets().unwrap();
212
213    if union_array.is_empty() {
214        // case 1: the union is empty
215        if target.is_empty() {
216            // case 1.1: the target is also empty, do a cheap Arc::clone instead of allocating a new empty array
217            Ok(Arc::clone(target))
218        } else {
219            // case 1.2: the target is not empty, allocate a new empty array
220            Ok(new_empty_array(target.data_type()))
221        }
222    } else if target.is_empty() {
223        // case 2: the union is not empty but the target is, which implies that none type_id points to it. The result is a null array
224        Ok(new_null_array(target.data_type(), union_array.len()))
225    } else if target.null_count() == target.len() || target.data_type().is_null() {
226        // case 3: since all values on our target are null, regardless of selected type ids and offsets, the result is a null array
227        match target.len().cmp(&union_array.len()) {
228            // case 3.1: since the target is smaller than the union, allocate a new correclty sized null array
229            Ordering::Less => Ok(new_null_array(target.data_type(), union_array.len())),
230            // case 3.2: target equals the union len, return it direcly
231            Ordering::Equal => Ok(Arc::clone(target)),
232            // case 3.3: target len is bigger than the union len, slice it
233            Ordering::Greater => Ok(target.slice(0, union_array.len())),
234        }
235    } else if fields.len() == 1 // case A: since there's a single field, our target, every type id must matches our target
236        || fields
237            .iter()
238            .filter(|(field_type_id, _)| *field_type_id != target_type_id)
239            .all(|(sibling_type_id, _)| union_array.child(sibling_type_id).is_empty())
240    // case B: since siblings are empty, every type id must matches our target
241    {
242        // case 4: every type id matches our target
243        Ok(extract_dense_all_selected(union_array, target, offsets)?)
244    } else {
245        match eq_scalar(union_array.type_ids(), target_type_id) {
246            // case 4C: all type ids matches our target.
247            // Non empty sibling without any selected value may happen after slicing the parent union,
248            // since only type_ids and offsets are sliced, not the children
249            BoolValue::Scalar(true) => {
250                Ok(extract_dense_all_selected(union_array, target, offsets)?)
251            }
252            BoolValue::Scalar(false) => {
253                // case 5: none type_id matches our target, so the result array will be completely null
254                // Non empty target without any selected value may happen after slicing the parent union,
255                // since only type_ids and offsets are sliced, not the children
256                match (target.len().cmp(&union_array.len()), layout(target.data_type()).can_contain_null_mask) {
257                    (Ordering::Less, _) // case 5.1A: our target is smaller than the parent union, allocate a new correclty sized null array
258                    | (_, false) => { // case 5.1B: target array can't contain a null mask
259                        Ok(new_null_array(target.data_type(), union_array.len()))
260                    }
261                    // case 5.2: target and parent union lengths are equal, and the target can contain a null mask, let's set it to a all-null null-buffer
262                    (Ordering::Equal, true) => {
263                        //SAFETY: The only change to the array data is the addition of a null mask, and if the target data type can contain a null mask was just checked above
264                        let data = unsafe {
265                            target
266                                .into_data()
267                                .into_builder()
268                                .nulls(Some(NullBuffer::new_null(union_array.len())))
269                                .build_unchecked()
270                        };
271
272                        Ok(make_array(data))
273                    }
274                    // case 5.3: target is bigger than it's parent union and can contain a null mask, let's slice it, and set it's nulls to a all-null null-buffer
275                    (Ordering::Greater, true) => {
276                        //SAFETY: The only change to the array data is the addition of a null mask, and if the target data type can contain a null mask was just checked above
277                        let data = unsafe {
278                            target
279                                .into_data()
280                                .slice(0, union_array.len())
281                                .into_builder()
282                                .nulls(Some(NullBuffer::new_null(union_array.len())))
283                                .build_unchecked()
284                        };
285
286                        Ok(make_array(data))
287                    }
288                }
289            }
290            BoolValue::Buffer(selected) => {
291                //case 6: some type_ids matches our target, but not all. For selected values, take the value pointed by the offset. For unselected, use a valid null
292                Ok(take(
293                    target,
294                    &Int32Array::try_new(offsets.clone(), Some(selected.into()))?,
295                    None,
296                )?)
297            }
298        }
299    }
300}
301
302fn extract_dense_all_selected(
303    union_array: &UnionArray,
304    target: &Arc<dyn Array>,
305    offsets: &ScalarBuffer<i32>,
306) -> Result<ArrayRef, ArrowError> {
307    let sequential =
308        target.len() - offsets[0] as usize >= union_array.len() && is_sequential(offsets);
309
310    if sequential && target.len() == union_array.len() {
311        // case 1: all offsets are sequential and both lengths match, return the array directly
312        Ok(Arc::clone(target))
313    } else if sequential && target.len() > union_array.len() {
314        // case 2: All offsets are sequential, but our target is bigger than our union, slice it, starting at the first offset
315        Ok(target.slice(offsets[0] as usize, union_array.len()))
316    } else {
317        // case 3: Since offsets are not sequential, take them from the child to a new sequential and correcly sized array
318        let indices = Int32Array::try_new(offsets.clone(), None)?;
319
320        Ok(take(target, &indices, None)?)
321    }
322}
323
324const EQ_SCALAR_CHUNK_SIZE: usize = 512;
325
326/// The result of checking which type_ids matches the target type_id
327#[derive(Debug, PartialEq)]
328enum BoolValue {
329    /// If true, all type_ids matches the target type_id
330    /// If false, none type_ids matches the target type_id
331    Scalar(bool),
332    /// A mask represeting which type_ids matches the target type_id
333    Buffer(BooleanBuffer),
334}
335
336fn eq_scalar(type_ids: &[i8], target: i8) -> BoolValue {
337    eq_scalar_inner(EQ_SCALAR_CHUNK_SIZE, type_ids, target)
338}
339
/// Returns how many leading elements of `type_ids` belong to consecutive whole
/// `chunk_size`-sized chunks in which `f` holds for every element.
///
/// Counting stops at the first chunk that contains a failing element; that chunk
/// contributes nothing, even if a prefix of it matched. Within a chunk `f` is evaluated
/// on every element (`&`, not `&&`), presumably so the inner loop can be vectorized.
/// Panics if `chunk_size` is zero (via `slice::chunks`).
fn count_first_run(chunk_size: usize, type_ids: &[i8], mut f: impl FnMut(i8) -> bool) -> usize {
    let mut matched = 0;
    for chunk in type_ids.chunks(chunk_size) {
        let whole_chunk_matches = chunk.iter().fold(true, |acc, &v| acc & f(v));
        if !whole_chunk_matches {
            break;
        }
        matched += chunk.len();
    }
    matched
}
347
348// This is like MutableBuffer::collect_bool(type_ids.len(), |i| type_ids[i] == target) with fast paths for all true or all false values.
349fn eq_scalar_inner(chunk_size: usize, type_ids: &[i8], target: i8) -> BoolValue {
350    let true_bits = count_first_run(chunk_size, type_ids, |v| v == target);
351
352    let (set_bits, val) = if true_bits == type_ids.len() {
353        return BoolValue::Scalar(true);
354    } else if true_bits == 0 {
355        let false_bits = count_first_run(chunk_size, type_ids, |v| v != target);
356
357        if false_bits == type_ids.len() {
358            return BoolValue::Scalar(false);
359        } else {
360            (false_bits, false)
361        }
362    } else {
363        (true_bits, true)
364    };
365
366    // restrict to chunk boundaries
367    let set_bits = set_bits - set_bits % 64;
368
369    let mut buffer =
370        MutableBuffer::new(bit_util::ceil(type_ids.len(), 8)).with_bitset(set_bits / 8, val);
371
372    buffer.extend(type_ids[set_bits..].chunks(64).map(|chunk| {
373        chunk
374            .iter()
375            .copied()
376            .enumerate()
377            .fold(0, |packed, (bit_idx, v)| {
378                packed | (((v == target) as u64) << bit_idx)
379            })
380    }));
381
382    BoolValue::Buffer(BooleanBuffer::new(buffer.into(), 0, type_ids.len()))
383}
384
/// Chunk width used by [`is_sequential`].
const IS_SEQUENTIAL_CHUNK_SIZE: usize = 64;

/// Returns `true` if `offsets` is one run of consecutive integers, i.e.
/// `offsets[i] == offsets[0] + i` for every `i`. An empty slice counts as sequential.
fn is_sequential(offsets: &[i32]) -> bool {
    is_sequential_generic::<IS_SEQUENTIAL_CHUNK_SIZE>(offsets)
}

/// Chunked implementation of [`is_sequential`], checking `N` offsets per fixed-size chunk.
///
/// Comparisons that involve the first offset are done in `i64`, so offsets close to
/// `i32::MAX` never overflow (which would panic in debug builds and could wrap into a
/// false positive in release builds). Comparisons inside a chunk use wrapping
/// arithmetic; that is exact whenever the `i64` position checks pass, and the result
/// is `false` anyway when they fail.
///
/// Panics if `N` is zero (via `slice::chunks_exact`).
fn is_sequential_generic<const N: usize>(offsets: &[i32]) -> bool {
    let (Some(&first), Some(&last)) = (offsets.first(), offsets.last()) else {
        return true;
    };
    let first = i64::from(first);

    // Cheap rejection: a sequential run must end exactly at `first + len - 1`.
    // This quickly rejects, among others, a common dense-union layout in which
    // consecutive nulls share one child slot and therefore repeat an offset, e.g. for a
    // union with a single field A with type_id 0:
    //   union    = A=7 A=NULL A=NULL A=5 A=9
    //   a values = 7 NULL 5 9
    //   offsets  = 0 1 1 2 3      (ends at 3, not 0 + 5 - 1 = 4)
    //   type_ids = 0 0 0 0 0
    // Combined with the internal check of the remainder below, it also guarantees the
    // trailing remainder starts where a sequential run would place it.
    if first + (offsets.len() as i64 - 1) != i64::from(last) {
        return false;
    }

    let chunks = offsets.chunks_exact(N);
    let remainder = chunks.remainder();

    chunks.enumerate().all(|(chunk_idx, chunk)| {
        // Fixed-size view so the per-chunk fold has a compile-time trip count.
        let chunk_array = <&[i32; N]>::try_from(chunk).unwrap();

        // values within the chunk are consecutive
        chunk_array
            .iter()
            .copied()
            .enumerate()
            .fold(true, |acc, (i, offset)| {
                acc & (offset == chunk_array[0].wrapping_add(i as i32))
            })
            // the chunk starts where a sequential run would place it
            && first + (chunk_idx * N) as i64 == i64::from(chunk_array[0])
    }) && remainder
        .iter()
        .copied()
        .enumerate()
        .fold(true, |acc, (i, offset)| {
            acc & (offset == remainder[0].wrapping_add(i as i32))
        }) // the remainder's position relative to `first` follows from the end check above
}
433
434#[cfg(test)]
435mod tests {
436    use super::{
437        BoolValue, eq_scalar_inner, is_sequential_generic, union_extract, union_extract_by_id,
438    };
439    use arrow_array::{Array, Int32Array, NullArray, StringArray, UnionArray, new_null_array};
440    use arrow_buffer::{BooleanBuffer, ScalarBuffer};
441    use arrow_schema::{ArrowError, DataType, Field, UnionFields, UnionMode};
442    use std::sync::Arc;
443
444    #[test]
445    fn test_eq_scalar() {
446        //multiple all equal chunks, so it's loop and sum logic it's tested
447        //multiple chunks after, so it's loop logic it's tested
448        const ARRAY_LEN: usize = 64 * 4;
449
450        //so out of 64 boundaries chunks can be generated and checked for
451        const EQ_SCALAR_CHUNK_SIZE: usize = 3;
452
453        fn eq_scalar(type_ids: &[i8], target: i8) -> BoolValue {
454            eq_scalar_inner(EQ_SCALAR_CHUNK_SIZE, type_ids, target)
455        }
456
457        fn cross_check(left: &[i8], right: i8) -> BooleanBuffer {
458            BooleanBuffer::collect_bool(left.len(), |i| left[i] == right)
459        }
460
461        assert_eq!(eq_scalar(&[], 1), BoolValue::Scalar(true));
462
463        assert_eq!(eq_scalar(&[1], 1), BoolValue::Scalar(true));
464        assert_eq!(eq_scalar(&[2], 1), BoolValue::Scalar(false));
465
466        let mut values = [1; ARRAY_LEN];
467
468        assert_eq!(eq_scalar(&values, 1), BoolValue::Scalar(true));
469        assert_eq!(eq_scalar(&values, 2), BoolValue::Scalar(false));
470
471        //every subslice should return the same value
472        for i in 1..ARRAY_LEN {
473            assert_eq!(eq_scalar(&values[..i], 1), BoolValue::Scalar(true));
474            assert_eq!(eq_scalar(&values[..i], 2), BoolValue::Scalar(false));
475        }
476
477        // test that a single change anywhere is checked for
478        for i in 0..ARRAY_LEN {
479            values[i] = 2;
480
481            assert_eq!(
482                eq_scalar(&values, 1),
483                BoolValue::Buffer(cross_check(&values, 1))
484            );
485            assert_eq!(
486                eq_scalar(&values, 2),
487                BoolValue::Buffer(cross_check(&values, 2))
488            );
489
490            values[i] = 1;
491        }
492    }
493
494    #[test]
495    fn test_is_sequential() {
496        /*
497        the smallest value that satisfies:
498        >1 so the fold logic of a exact chunk executes
499        >2 so a >1 non-exact remainder can exist, and it's fold logic executes
500         */
501        const CHUNK_SIZE: usize = 3;
502        //we test arrays of size up to 8 = 2 * CHUNK_SIZE + 2:
503        //multiple(2) exact chunks, so the AND logic between them executes
504        //a >1(2) remainder, so:
505        //    the AND logic between all exact chunks and the remainder executes
506        //    the remainder fold logic executes
507
508        fn is_sequential(v: &[i32]) -> bool {
509            is_sequential_generic::<CHUNK_SIZE>(v)
510        }
511
512        assert!(is_sequential(&[])); //empty
513        assert!(is_sequential(&[1])); //single
514
515        assert!(is_sequential(&[1, 2]));
516        assert!(is_sequential(&[1, 2, 3]));
517        assert!(is_sequential(&[1, 2, 3, 4]));
518        assert!(is_sequential(&[1, 2, 3, 4, 5]));
519        assert!(is_sequential(&[1, 2, 3, 4, 5, 6]));
520        assert!(is_sequential(&[1, 2, 3, 4, 5, 6, 7]));
521        assert!(is_sequential(&[1, 2, 3, 4, 5, 6, 7, 8]));
522
523        assert!(!is_sequential(&[8, 7]));
524        assert!(!is_sequential(&[8, 7, 6]));
525        assert!(!is_sequential(&[8, 7, 6, 5]));
526        assert!(!is_sequential(&[8, 7, 6, 5, 4]));
527        assert!(!is_sequential(&[8, 7, 6, 5, 4, 3]));
528        assert!(!is_sequential(&[8, 7, 6, 5, 4, 3, 2]));
529        assert!(!is_sequential(&[8, 7, 6, 5, 4, 3, 2, 1]));
530
531        assert!(!is_sequential(&[0, 2]));
532        assert!(!is_sequential(&[1, 0]));
533
534        assert!(!is_sequential(&[0, 2, 3]));
535        assert!(!is_sequential(&[1, 0, 3]));
536        assert!(!is_sequential(&[1, 2, 0]));
537
538        assert!(!is_sequential(&[0, 2, 3, 4]));
539        assert!(!is_sequential(&[1, 0, 3, 4]));
540        assert!(!is_sequential(&[1, 2, 0, 4]));
541        assert!(!is_sequential(&[1, 2, 3, 0]));
542
543        assert!(!is_sequential(&[0, 2, 3, 4, 5]));
544        assert!(!is_sequential(&[1, 0, 3, 4, 5]));
545        assert!(!is_sequential(&[1, 2, 0, 4, 5]));
546        assert!(!is_sequential(&[1, 2, 3, 0, 5]));
547        assert!(!is_sequential(&[1, 2, 3, 4, 0]));
548
549        assert!(!is_sequential(&[0, 2, 3, 4, 5, 6]));
550        assert!(!is_sequential(&[1, 0, 3, 4, 5, 6]));
551        assert!(!is_sequential(&[1, 2, 0, 4, 5, 6]));
552        assert!(!is_sequential(&[1, 2, 3, 0, 5, 6]));
553        assert!(!is_sequential(&[1, 2, 3, 4, 0, 6]));
554        assert!(!is_sequential(&[1, 2, 3, 4, 5, 0]));
555
556        assert!(!is_sequential(&[0, 2, 3, 4, 5, 6, 7]));
557        assert!(!is_sequential(&[1, 0, 3, 4, 5, 6, 7]));
558        assert!(!is_sequential(&[1, 2, 0, 4, 5, 6, 7]));
559        assert!(!is_sequential(&[1, 2, 3, 0, 5, 6, 7]));
560        assert!(!is_sequential(&[1, 2, 3, 4, 0, 6, 7]));
561        assert!(!is_sequential(&[1, 2, 3, 4, 5, 0, 7]));
562        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 0]));
563
564        assert!(!is_sequential(&[0, 2, 3, 4, 5, 6, 7, 8]));
565        assert!(!is_sequential(&[1, 0, 3, 4, 5, 6, 7, 8]));
566        assert!(!is_sequential(&[1, 2, 0, 4, 5, 6, 7, 8]));
567        assert!(!is_sequential(&[1, 2, 3, 0, 5, 6, 7, 8]));
568        assert!(!is_sequential(&[1, 2, 3, 4, 0, 6, 7, 8]));
569        assert!(!is_sequential(&[1, 2, 3, 4, 5, 0, 7, 8]));
570        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 0, 8]));
571        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 7, 0]));
572
573        // checks increments at the chunk boundary
574        assert!(!is_sequential(&[1, 2, 3, 5]));
575        assert!(!is_sequential(&[1, 2, 3, 5, 6]));
576        assert!(!is_sequential(&[1, 2, 3, 5, 6, 7]));
577        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 8]));
578        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 8, 9]));
579    }
580
581    fn str1() -> UnionFields {
582        UnionFields::try_new(vec![1], vec![Field::new("str", DataType::Utf8, true)]).unwrap()
583    }
584
585    fn str1_int3() -> UnionFields {
586        UnionFields::try_new(
587            vec![1, 3],
588            vec![
589                Field::new("str", DataType::Utf8, true),
590                Field::new("int", DataType::Int32, true),
591            ],
592        )
593        .unwrap()
594    }
595
596    #[test]
597    fn sparse_1_1_single_field() {
598        let union = UnionArray::try_new(
599            //single field
600            str1(),
601            ScalarBuffer::from(vec![1, 1]), // non empty, every type id must match
602            None,                           //sparse
603            vec![
604                Arc::new(StringArray::from(vec!["a", "b"])), // not null
605            ],
606        )
607        .unwrap();
608
609        let expected = StringArray::from(vec!["a", "b"]);
610        let extracted = union_extract(&union, "str").unwrap();
611
612        assert_eq!(extracted.into_data(), expected.into_data());
613    }
614
615    #[test]
616    fn sparse_1_2_empty() {
617        let union = UnionArray::try_new(
618            // multiple fields
619            str1_int3(),
620            ScalarBuffer::from(vec![]), //empty union
621            None,                       // sparse
622            vec![
623                Arc::new(StringArray::new_null(0)),
624                Arc::new(Int32Array::new_null(0)),
625            ],
626        )
627        .unwrap();
628
629        let expected = StringArray::new_null(0);
630        let extracted = union_extract(&union, "str").unwrap(); //target type is not Null
631
632        assert_eq!(extracted.into_data(), expected.into_data());
633    }
634
635    #[test]
636    fn sparse_1_3a_null_target() {
637        let union = UnionArray::try_new(
638            // multiple fields
639            UnionFields::try_new(
640                vec![1, 3],
641                vec![
642                    Field::new("str", DataType::Utf8, true),
643                    Field::new("null", DataType::Null, true), // target type is Null
644                ],
645            )
646            .unwrap(),
647            ScalarBuffer::from(vec![1]), //not empty
648            None,                        // sparse
649            vec![
650                Arc::new(StringArray::new_null(1)),
651                Arc::new(NullArray::new(1)), // null data type
652            ],
653        )
654        .unwrap();
655
656        let expected = NullArray::new(1);
657        let extracted = union_extract(&union, "null").unwrap();
658
659        assert_eq!(extracted.into_data(), expected.into_data());
660    }
661
662    #[test]
663    fn sparse_1_3b_null_target() {
664        let union = UnionArray::try_new(
665            // multiple fields
666            str1_int3(),
667            ScalarBuffer::from(vec![1]), //not empty
668            None,                        // sparse
669            vec![
670                Arc::new(StringArray::new_null(1)), //all null
671                Arc::new(Int32Array::new_null(1)),
672            ],
673        )
674        .unwrap();
675
676        let expected = StringArray::new_null(1);
677        let extracted = union_extract(&union, "str").unwrap(); //target type is not Null
678
679        assert_eq!(extracted.into_data(), expected.into_data());
680    }
681
682    #[test]
683    fn sparse_2_all_types_match() {
684        let union = UnionArray::try_new(
685            //multiple fields
686            str1_int3(),
687            ScalarBuffer::from(vec![3, 3]), // all types match
688            None,                           //sparse
689            vec![
690                Arc::new(StringArray::new_null(2)),
691                Arc::new(Int32Array::from(vec![1, 4])), // not null
692            ],
693        )
694        .unwrap();
695
696        let expected = Int32Array::from(vec![1, 4]);
697        let extracted = union_extract(&union, "int").unwrap();
698
699        assert_eq!(extracted.into_data(), expected.into_data());
700    }
701
702    #[test]
703    fn sparse_3_1_none_match_target_can_contain_null_mask() {
704        let union = UnionArray::try_new(
705            //multiple fields
706            str1_int3(),
707            ScalarBuffer::from(vec![1, 1, 1, 1]), // none match
708            None,                                 // sparse
709            vec![
710                Arc::new(StringArray::new_null(4)),
711                Arc::new(Int32Array::from(vec![None, Some(4), None, Some(8)])), // target is not null
712            ],
713        )
714        .unwrap();
715
716        let expected = Int32Array::new_null(4);
717        let extracted = union_extract(&union, "int").unwrap();
718
719        assert_eq!(extracted.into_data(), expected.into_data());
720    }
721
722    fn str1_union3(union3_datatype: DataType) -> UnionFields {
723        UnionFields::try_new(
724            vec![1, 3],
725            vec![
726                Field::new("str", DataType::Utf8, true),
727                Field::new("union", union3_datatype, true),
728            ],
729        )
730        .unwrap()
731    }
732
733    #[test]
734    fn sparse_3_2_none_match_cant_contain_null_mask_union_target() {
735        let target_fields = str1();
736        let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse);
737
738        let union = UnionArray::try_new(
739            //multiple fields
740            str1_union3(target_type.clone()),
741            ScalarBuffer::from(vec![1, 1]), // none match
742            None,                           //sparse
743            vec![
744                Arc::new(StringArray::new_null(2)),
745                //target is not null
746                Arc::new(
747                    UnionArray::try_new(
748                        target_fields.clone(),
749                        ScalarBuffer::from(vec![1, 1]),
750                        None,
751                        vec![Arc::new(StringArray::from(vec!["a", "b"]))],
752                    )
753                    .unwrap(),
754                ),
755            ],
756        )
757        .unwrap();
758
759        let expected = new_null_array(&target_type, 2);
760        let extracted = union_extract(&union, "union").unwrap();
761
762        assert_eq!(extracted.into_data(), expected.into_data());
763    }
764
765    #[test]
766    fn sparse_4_1_1_target_with_nulls() {
767        let union = UnionArray::try_new(
768            //multiple fields
769            str1_int3(),
770            ScalarBuffer::from(vec![3, 3, 1, 1]), // multiple selected types
771            None,                                 // sparse
772            vec![
773                Arc::new(StringArray::new_null(4)),
774                Arc::new(Int32Array::from(vec![None, Some(4), None, Some(8)])), // target with nulls
775            ],
776        )
777        .unwrap();
778
779        let expected = Int32Array::from(vec![None, Some(4), None, None]);
780        let extracted = union_extract(&union, "int").unwrap();
781
782        assert_eq!(extracted.into_data(), expected.into_data());
783    }
784
785    #[test]
786    fn sparse_4_1_2_target_without_nulls() {
787        let union = UnionArray::try_new(
788            //multiple fields
789            str1_int3(),
790            ScalarBuffer::from(vec![1, 3, 3]), // multiple selected types
791            None,                              // sparse
792            vec![
793                Arc::new(StringArray::new_null(3)),
794                Arc::new(Int32Array::from(vec![2, 4, 8])), // target without nulls
795            ],
796        )
797        .unwrap();
798
799        let expected = Int32Array::from(vec![None, Some(4), Some(8)]);
800        let extracted = union_extract(&union, "int").unwrap();
801
802        assert_eq!(extracted.into_data(), expected.into_data());
803    }
804
805    #[test]
806    fn sparse_4_2_some_match_target_cant_contain_null_mask() {
807        let target_fields = str1();
808        let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse);
809
810        let union = UnionArray::try_new(
811            //multiple fields
812            str1_union3(target_type),
813            ScalarBuffer::from(vec![3, 1]), // some types match, but not all
814            None,                           //sparse
815            vec![
816                Arc::new(StringArray::new_null(2)),
817                Arc::new(
818                    UnionArray::try_new(
819                        target_fields.clone(),
820                        ScalarBuffer::from(vec![1, 1]),
821                        None,
822                        vec![Arc::new(StringArray::from(vec!["a", "b"]))],
823                    )
824                    .unwrap(),
825                ),
826            ],
827        )
828        .unwrap();
829
830        let expected = UnionArray::try_new(
831            target_fields,
832            ScalarBuffer::from(vec![1, 1]),
833            None,
834            vec![Arc::new(StringArray::from(vec![Some("a"), None]))],
835        )
836        .unwrap();
837        let extracted = union_extract(&union, "union").unwrap();
838
839        assert_eq!(extracted.into_data(), expected.into_data());
840    }
841
842    #[test]
843    fn dense_1_1_both_empty() {
844        let union = UnionArray::try_new(
845            str1_int3(),
846            ScalarBuffer::from(vec![]),       //empty union
847            Some(ScalarBuffer::from(vec![])), // dense
848            vec![
849                Arc::new(StringArray::new_null(0)), //empty target
850                Arc::new(Int32Array::new_null(0)),
851            ],
852        )
853        .unwrap();
854
855        let expected = StringArray::new_null(0);
856        let extracted = union_extract(&union, "str").unwrap();
857
858        assert_eq!(extracted.into_data(), expected.into_data());
859    }
860
861    #[test]
862    fn dense_1_2_empty_union_target_non_empty() {
863        let union = UnionArray::try_new(
864            str1_int3(),
865            ScalarBuffer::from(vec![]),       //empty union
866            Some(ScalarBuffer::from(vec![])), // dense
867            vec![
868                Arc::new(StringArray::new_null(1)), //non empty target
869                Arc::new(Int32Array::new_null(0)),
870            ],
871        )
872        .unwrap();
873
874        let expected = StringArray::new_null(0);
875        let extracted = union_extract(&union, "str").unwrap();
876
877        assert_eq!(extracted.into_data(), expected.into_data());
878    }
879
880    #[test]
881    fn dense_2_non_empty_union_target_empty() {
882        let union = UnionArray::try_new(
883            str1_int3(),
884            ScalarBuffer::from(vec![3, 3]),       //non empty union
885            Some(ScalarBuffer::from(vec![0, 1])), // dense
886            vec![
887                Arc::new(StringArray::new_null(0)), //empty target
888                Arc::new(Int32Array::new_null(2)),
889            ],
890        )
891        .unwrap();
892
893        let expected = StringArray::new_null(2);
894        let extracted = union_extract(&union, "str").unwrap();
895
896        assert_eq!(extracted.into_data(), expected.into_data());
897    }
898
899    #[test]
900    fn dense_3_1_null_target_smaller_len() {
901        let union = UnionArray::try_new(
902            str1_int3(),
903            ScalarBuffer::from(vec![3, 3]),       //non empty union
904            Some(ScalarBuffer::from(vec![0, 0])), //dense
905            vec![
906                Arc::new(StringArray::new_null(1)), //smaller target
907                Arc::new(Int32Array::new_null(2)),
908            ],
909        )
910        .unwrap();
911
912        let expected = StringArray::new_null(2);
913        let extracted = union_extract(&union, "str").unwrap();
914
915        assert_eq!(extracted.into_data(), expected.into_data());
916    }
917
918    #[test]
919    fn dense_3_2_null_target_equal_len() {
920        let union = UnionArray::try_new(
921            str1_int3(),
922            ScalarBuffer::from(vec![3, 3]),       //non empty union
923            Some(ScalarBuffer::from(vec![0, 0])), //dense
924            vec![
925                Arc::new(StringArray::new_null(2)), //equal len
926                Arc::new(Int32Array::new_null(2)),
927            ],
928        )
929        .unwrap();
930
931        let expected = StringArray::new_null(2);
932        let extracted = union_extract(&union, "str").unwrap();
933
934        assert_eq!(extracted.into_data(), expected.into_data());
935    }
936
937    #[test]
938    fn dense_3_3_null_target_bigger_len() {
939        let union = UnionArray::try_new(
940            str1_int3(),
941            ScalarBuffer::from(vec![3, 3]),       //non empty union
942            Some(ScalarBuffer::from(vec![0, 0])), //dense
943            vec![
944                Arc::new(StringArray::new_null(3)), //bigger len
945                Arc::new(Int32Array::new_null(3)),
946            ],
947        )
948        .unwrap();
949
950        let expected = StringArray::new_null(2);
951        let extracted = union_extract(&union, "str").unwrap();
952
953        assert_eq!(extracted.into_data(), expected.into_data());
954    }
955
956    #[test]
957    fn dense_4_1a_single_type_sequential_offsets_equal_len() {
958        let union = UnionArray::try_new(
959            // single field
960            str1(),
961            ScalarBuffer::from(vec![1, 1]),       //non empty union
962            Some(ScalarBuffer::from(vec![0, 1])), //sequential
963            vec![
964                Arc::new(StringArray::from(vec!["a1", "b2"])), //equal len, non null
965            ],
966        )
967        .unwrap();
968
969        let expected = StringArray::from(vec!["a1", "b2"]);
970        let extracted = union_extract(&union, "str").unwrap();
971
972        assert_eq!(extracted.into_data(), expected.into_data());
973    }
974
975    #[test]
976    fn dense_4_2a_single_type_sequential_offsets_bigger() {
977        let union = UnionArray::try_new(
978            // single field
979            str1(),
980            ScalarBuffer::from(vec![1, 1]),       //non empty union
981            Some(ScalarBuffer::from(vec![0, 1])), //sequential
982            vec![
983                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), //equal len, non null
984            ],
985        )
986        .unwrap();
987
988        let expected = StringArray::from(vec!["a1", "b2"]);
989        let extracted = union_extract(&union, "str").unwrap();
990
991        assert_eq!(extracted.into_data(), expected.into_data());
992    }
993
994    #[test]
995    fn dense_4_3a_single_type_non_sequential() {
996        let union = UnionArray::try_new(
997            // single field
998            str1(),
999            ScalarBuffer::from(vec![1, 1]),       //non empty union
1000            Some(ScalarBuffer::from(vec![0, 2])), //non sequential
1001            vec![
1002                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), //equal len, non null
1003            ],
1004        )
1005        .unwrap();
1006
1007        let expected = StringArray::from(vec!["a1", "c3"]);
1008        let extracted = union_extract(&union, "str").unwrap();
1009
1010        assert_eq!(extracted.into_data(), expected.into_data());
1011    }
1012
1013    #[test]
1014    fn dense_4_1b_empty_siblings_sequential_equal_len() {
1015        let union = UnionArray::try_new(
1016            // multiple fields
1017            str1_int3(),
1018            ScalarBuffer::from(vec![1, 1]),       //non empty union
1019            Some(ScalarBuffer::from(vec![0, 1])), //sequential
1020            vec![
1021                Arc::new(StringArray::from(vec!["a", "b"])), //equal len, non null
1022                Arc::new(Int32Array::new_null(0)),           //empty sibling
1023            ],
1024        )
1025        .unwrap();
1026
1027        let expected = StringArray::from(vec!["a", "b"]);
1028        let extracted = union_extract(&union, "str").unwrap();
1029
1030        assert_eq!(extracted.into_data(), expected.into_data());
1031    }
1032
1033    #[test]
1034    fn dense_4_2b_empty_siblings_sequential_bigger_len() {
1035        let union = UnionArray::try_new(
1036            // multiple fields
1037            str1_int3(),
1038            ScalarBuffer::from(vec![1, 1]),       //non empty union
1039            Some(ScalarBuffer::from(vec![0, 1])), //sequential
1040            vec![
1041                Arc::new(StringArray::from(vec!["a", "b", "c"])), //bigger len, non null
1042                Arc::new(Int32Array::new_null(0)),                //empty sibling
1043            ],
1044        )
1045        .unwrap();
1046
1047        let expected = StringArray::from(vec!["a", "b"]);
1048        let extracted = union_extract(&union, "str").unwrap();
1049
1050        assert_eq!(extracted.into_data(), expected.into_data());
1051    }
1052
1053    #[test]
1054    fn dense_4_3b_empty_sibling_non_sequential() {
1055        let union = UnionArray::try_new(
1056            // multiple fields
1057            str1_int3(),
1058            ScalarBuffer::from(vec![1, 1]),       //non empty union
1059            Some(ScalarBuffer::from(vec![0, 2])), //non sequential
1060            vec![
1061                Arc::new(StringArray::from(vec!["a", "b", "c"])), //non null
1062                Arc::new(Int32Array::new_null(0)),                //empty sibling
1063            ],
1064        )
1065        .unwrap();
1066
1067        let expected = StringArray::from(vec!["a", "c"]);
1068        let extracted = union_extract(&union, "str").unwrap();
1069
1070        assert_eq!(extracted.into_data(), expected.into_data());
1071    }
1072
1073    #[test]
1074    fn dense_4_1c_all_types_match_sequential_equal_len() {
1075        let union = UnionArray::try_new(
1076            // multiple fields
1077            str1_int3(),
1078            ScalarBuffer::from(vec![1, 1]),       //all types match
1079            Some(ScalarBuffer::from(vec![0, 1])), //sequential
1080            vec![
1081                Arc::new(StringArray::from(vec!["a1", "b2"])), //equal len
1082                Arc::new(Int32Array::new_null(2)),             //non empty sibling
1083            ],
1084        )
1085        .unwrap();
1086
1087        let expected = StringArray::from(vec!["a1", "b2"]);
1088        let extracted = union_extract(&union, "str").unwrap();
1089
1090        assert_eq!(extracted.into_data(), expected.into_data());
1091    }
1092
1093    #[test]
1094    fn dense_4_2c_all_types_match_sequential_bigger_len() {
1095        let union = UnionArray::try_new(
1096            // multiple fields
1097            str1_int3(),
1098            ScalarBuffer::from(vec![1, 1]),       //all types match
1099            Some(ScalarBuffer::from(vec![0, 1])), //sequential
1100            vec![
1101                Arc::new(StringArray::from(vec!["a1", "b2", "b3"])), //bigger len
1102                Arc::new(Int32Array::new_null(2)),                   //non empty sibling
1103            ],
1104        )
1105        .unwrap();
1106
1107        let expected = StringArray::from(vec!["a1", "b2"]);
1108        let extracted = union_extract(&union, "str").unwrap();
1109
1110        assert_eq!(extracted.into_data(), expected.into_data());
1111    }
1112
1113    #[test]
1114    fn dense_4_3c_all_types_match_non_sequential() {
1115        let union = UnionArray::try_new(
1116            // multiple fields
1117            str1_int3(),
1118            ScalarBuffer::from(vec![1, 1]),       //all types match
1119            Some(ScalarBuffer::from(vec![0, 2])), //non sequential
1120            vec![
1121                Arc::new(StringArray::from(vec!["a1", "b2", "b3"])),
1122                Arc::new(Int32Array::new_null(2)), //non empty sibling
1123            ],
1124        )
1125        .unwrap();
1126
1127        let expected = StringArray::from(vec!["a1", "b3"]);
1128        let extracted = union_extract(&union, "str").unwrap();
1129
1130        assert_eq!(extracted.into_data(), expected.into_data());
1131    }
1132
1133    #[test]
1134    fn dense_5_1a_none_match_less_len() {
1135        let union = UnionArray::try_new(
1136            // multiple fields
1137            str1_int3(),
1138            ScalarBuffer::from(vec![3, 3, 3, 3, 3]), //none matches
1139            Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), // dense
1140            vec![
1141                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), // less len
1142                Arc::new(Int32Array::from(vec![1, 2])),
1143            ],
1144        )
1145        .unwrap();
1146
1147        let expected = StringArray::new_null(5);
1148        let extracted = union_extract(&union, "str").unwrap();
1149
1150        assert_eq!(extracted.into_data(), expected.into_data());
1151    }
1152
1153    #[test]
1154    fn dense_5_1b_cant_contain_null_mask() {
1155        let target_fields = str1();
1156        let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse);
1157
1158        let union = UnionArray::try_new(
1159            // multiple fields
1160            str1_union3(target_type.clone()),
1161            ScalarBuffer::from(vec![1, 1, 1, 1, 1]), //none matches
1162            Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), // dense
1163            vec![
1164                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), // less len
1165                Arc::new(
1166                    UnionArray::try_new(
1167                        target_fields.clone(),
1168                        ScalarBuffer::from(vec![1]),
1169                        None,
1170                        vec![Arc::new(StringArray::from(vec!["a"]))],
1171                    )
1172                    .unwrap(),
1173                ), // non empty
1174            ],
1175        )
1176        .unwrap();
1177
1178        let expected = new_null_array(&target_type, 5);
1179        let extracted = union_extract(&union, "union").unwrap();
1180
1181        assert_eq!(extracted.into_data(), expected.into_data());
1182    }
1183
1184    #[test]
1185    fn dense_5_2_none_match_equal_len() {
1186        let union = UnionArray::try_new(
1187            // multiple fields
1188            str1_int3(),
1189            ScalarBuffer::from(vec![3, 3, 3, 3, 3]), //none matches
1190            Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), // dense
1191            vec![
1192                Arc::new(StringArray::from(vec!["a1", "b2", "c3", "d4", "e5"])), // equal len
1193                Arc::new(Int32Array::from(vec![1, 2])),
1194            ],
1195        )
1196        .unwrap();
1197
1198        let expected = StringArray::new_null(5);
1199        let extracted = union_extract(&union, "str").unwrap();
1200
1201        assert_eq!(extracted.into_data(), expected.into_data());
1202    }
1203
1204    #[test]
1205    fn dense_5_3_none_match_greater_len() {
1206        let union = UnionArray::try_new(
1207            // multiple fields
1208            str1_int3(),
1209            ScalarBuffer::from(vec![3, 3, 3, 3, 3]), //none matches
1210            Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), // dense
1211            vec![
1212                Arc::new(StringArray::from(vec!["a1", "b2", "c3", "d4", "e5", "f6"])), // greater len
1213                Arc::new(Int32Array::from(vec![1, 2])),                                //non null
1214            ],
1215        )
1216        .unwrap();
1217
1218        let expected = StringArray::new_null(5);
1219        let extracted = union_extract(&union, "str").unwrap();
1220
1221        assert_eq!(extracted.into_data(), expected.into_data());
1222    }
1223
1224    #[test]
1225    fn dense_6_some_matches() {
1226        let union = UnionArray::try_new(
1227            // multiple fields
1228            str1_int3(),
1229            ScalarBuffer::from(vec![3, 3, 1, 1, 1]), //some matches
1230            Some(ScalarBuffer::from(vec![0, 1, 0, 1, 2])), // dense
1231            vec![
1232                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), // non null
1233                Arc::new(Int32Array::from(vec![1, 2])),
1234            ],
1235        )
1236        .unwrap();
1237
1238        let expected = Int32Array::from(vec![Some(1), Some(2), None, None, None]);
1239        let extracted = union_extract(&union, "int").unwrap();
1240
1241        assert_eq!(extracted.into_data(), expected.into_data());
1242    }
1243
1244    #[test]
1245    fn empty_sparse_union() {
1246        let union = UnionArray::try_new(
1247            UnionFields::empty(),
1248            ScalarBuffer::from(vec![]),
1249            None,
1250            vec![],
1251        )
1252        .unwrap();
1253
1254        assert_eq!(
1255            union_extract(&union, "a").unwrap_err().to_string(),
1256            ArrowError::InvalidArgumentError("field a not found on union".into()).to_string()
1257        );
1258    }
1259
1260    #[test]
1261    fn empty_dense_union() {
1262        let union = UnionArray::try_new(
1263            UnionFields::empty(),
1264            ScalarBuffer::from(vec![]),
1265            Some(ScalarBuffer::from(vec![])),
1266            vec![],
1267        )
1268        .unwrap();
1269
1270        assert_eq!(
1271            union_extract(&union, "a").unwrap_err().to_string(),
1272            ArrowError::InvalidArgumentError("field a not found on union".into()).to_string()
1273        );
1274    }
1275
1276    #[test]
1277    fn extract_by_id_sparse_duplicate_names() {
1278        // Two fields with the same name "val" but different type_ids and types
1279        let fields = UnionFields::try_new(
1280            [0, 1],
1281            [
1282                Field::new("val", DataType::Int32, true),
1283                Field::new("val", DataType::Utf8, true),
1284            ],
1285        )
1286        .unwrap();
1287
1288        let union = UnionArray::try_new(
1289            fields,
1290            vec![0_i8, 1, 0, 1].into(),
1291            None,
1292            vec![
1293                Arc::new(Int32Array::from(vec![Some(42), None, Some(99), None])) as _,
1294                Arc::new(StringArray::from(vec![
1295                    None,
1296                    Some("hello"),
1297                    None,
1298                    Some("world"),
1299                ])),
1300            ],
1301        )
1302        .unwrap();
1303
1304        // union_extract by name always returns type_id 0 (first match)
1305        let by_name = union_extract(&union, "val").unwrap();
1306        assert_eq!(
1307            *by_name,
1308            Int32Array::from(vec![Some(42), None, Some(99), None])
1309        );
1310
1311        // union_extract_by_id can select type_id 1 (the Utf8 child)
1312        let by_id = union_extract_by_id(&union, 1).unwrap();
1313        assert_eq!(
1314            *by_id,
1315            StringArray::from(vec![None, Some("hello"), None, Some("world")])
1316        );
1317    }
1318
1319    #[test]
1320    fn extract_by_id_dense_duplicate_names() {
1321        let fields = UnionFields::try_new(
1322            [0, 1],
1323            [
1324                Field::new("val", DataType::Int32, true),
1325                Field::new("val", DataType::Utf8, true),
1326            ],
1327        )
1328        .unwrap();
1329
1330        let union = UnionArray::try_new(
1331            fields,
1332            vec![0_i8, 1, 0].into(),
1333            Some(vec![0_i32, 0, 1].into()),
1334            vec![
1335                Arc::new(Int32Array::from(vec![Some(42), Some(99)])) as _,
1336                Arc::new(StringArray::from(vec![Some("hello")])),
1337            ],
1338        )
1339        .unwrap();
1340
1341        // by type_id 0 → Int32 child
1342        let by_id_0 = union_extract_by_id(&union, 0).unwrap();
1343        assert_eq!(*by_id_0, Int32Array::from(vec![Some(42), None, Some(99)]));
1344
1345        // by type_id 1 → Utf8 child
1346        let by_id_1 = union_extract_by_id(&union, 1).unwrap();
1347        assert_eq!(*by_id_1, StringArray::from(vec![None, Some("hello"), None]));
1348    }
1349
1350    #[test]
1351    fn extract_by_id_not_found() {
1352        let fields = UnionFields::try_new(
1353            [0, 1],
1354            [
1355                Field::new("a", DataType::Int32, true),
1356                Field::new("b", DataType::Utf8, true),
1357            ],
1358        )
1359        .unwrap();
1360
1361        let union = UnionArray::try_new(
1362            fields,
1363            vec![0_i8, 1].into(),
1364            None,
1365            vec![
1366                Arc::new(Int32Array::from(vec![Some(1), None])) as _,
1367                Arc::new(StringArray::from(vec![None, Some("x")])),
1368            ],
1369        )
1370        .unwrap();
1371
1372        assert_eq!(
1373            union_extract_by_id(&union, 5).unwrap_err().to_string(),
1374            ArrowError::InvalidArgumentError("type_id 5 not found on union".into()).to_string()
1375        );
1376    }
1377}