arrow_select/
union_extract.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines union_extract kernel for [UnionArray]
19
20use crate::take::take;
21use arrow_array::{
22    Array, ArrayRef, BooleanArray, Int32Array, Scalar, UnionArray, make_array, new_empty_array,
23    new_null_array,
24};
25use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer, ScalarBuffer, bit_util};
26use arrow_data::layout;
27use arrow_schema::{ArrowError, DataType, UnionFields};
28use std::cmp::Ordering;
29use std::sync::Arc;
30
31/// Returns the value of the target field when selected, or NULL otherwise.
32/// ```text
33/// ┌─────────────────┐                                   ┌─────────────────┐
34/// │       A=1       │                                   │        1        │
35/// ├─────────────────┤                                   ├─────────────────┤
36/// │      A=NULL     │                                   │       NULL      │
37/// ├─────────────────┤    union_extract(values, 'A')     ├─────────────────┤
38/// │      B='t'      │  ────────────────────────────▶    │       NULL      │
39/// ├─────────────────┤                                   ├─────────────────┤
40/// │       A=3       │                                   │        3        │
41/// ├─────────────────┤                                   ├─────────────────┤
42/// │      B=NULL     │                                   │       NULL      │
43/// └─────────────────┘                                   └─────────────────┘
44///    union array                                              result
45/// ```
46/// # Errors
47///
48/// Returns error if target field is not found
49///
50/// # Examples
51/// ```
52/// # use std::sync::Arc;
53/// # use arrow_schema::{DataType, Field, UnionFields};
54/// # use arrow_array::{UnionArray, StringArray, Int32Array};
55/// # use arrow_select::union_extract::union_extract;
56/// let fields = UnionFields::try_new(
57///     [1, 3],
58///     [
59///         Field::new("A", DataType::Int32, true),
60///         Field::new("B", DataType::Utf8, true)
61///     ]
62/// ).unwrap();
63///
64/// let union = UnionArray::try_new(
65///     fields,
66///     vec![1, 1, 3, 1, 3].into(),
67///     None,
68///     vec![
69///         Arc::new(Int32Array::from(vec![Some(1), None, None, Some(3), Some(0)])),
70///         Arc::new(StringArray::from(vec![None, None, Some("t"), Some("."), None]))
71///     ]
72/// ).unwrap();
73///
74/// // Extract field A
75/// let extracted = union_extract(&union, "A").unwrap();
76///
77/// assert_eq!(*extracted, Int32Array::from(vec![Some(1), None, None, Some(3), None]));
78/// ```
79pub fn union_extract(union_array: &UnionArray, target: &str) -> Result<ArrayRef, ArrowError> {
80    let fields = match union_array.data_type() {
81        DataType::Union(fields, _) => fields,
82        _ => unreachable!(),
83    };
84
85    let (target_type_id, _) = fields
86        .iter()
87        .find(|field| field.1.name() == target)
88        .ok_or_else(|| {
89            ArrowError::InvalidArgumentError(format!("field {target} not found on union"))
90        })?;
91
92    match union_array.offsets() {
93        Some(_) => extract_dense(union_array, fields, target_type_id),
94        None => extract_sparse(union_array, fields, target_type_id),
95    }
96}
97
98fn extract_sparse(
99    union_array: &UnionArray,
100    fields: &UnionFields,
101    target_type_id: i8,
102) -> Result<ArrayRef, ArrowError> {
103    let target = union_array.child(target_type_id);
104
105    if fields.len() == 1 // case 1.1: if there is a single field, all type ids are the same, and since union doesn't have a null mask, the result array is exactly the same as it only child
106        || union_array.is_empty() // case 1.2: sparse union length and childrens length must match, if the union is empty, so is any children
107        || target.null_count() == target.len() || target.data_type().is_null()
108    // case 1.3: if all values of the target children are null, regardless of selected type ids, the result will also be completely null
109    {
110        Ok(Arc::clone(target))
111    } else {
112        match eq_scalar(union_array.type_ids(), target_type_id) {
113            // case 2: all type ids equals our target, and since unions doesn't have a null mask, the result array is exactly the same as our target
114            BoolValue::Scalar(true) => Ok(Arc::clone(target)),
115            // case 3: none type_id matches our target, the result is a null array
116            BoolValue::Scalar(false) => {
117                if layout(target.data_type()).can_contain_null_mask {
118                    // case 3.1: target array can contain a null mask
119                    //SAFETY: The only change to the array data is the addition of a null mask, and if the target data type can contain a null mask was just checked above
120                    let data = unsafe {
121                        target
122                            .into_data()
123                            .into_builder()
124                            .nulls(Some(NullBuffer::new_null(target.len())))
125                            .build_unchecked()
126                    };
127
128                    Ok(make_array(data))
129                } else {
130                    // case 3.2: target can't contain a null mask
131                    Ok(new_null_array(target.data_type(), target.len()))
132                }
133            }
134            // case 4: some but not all type_id matches our target
135            BoolValue::Buffer(selected) => {
136                if layout(target.data_type()).can_contain_null_mask {
137                    // case 4.1: target array can contain a null mask
138                    let nulls = match target.nulls().filter(|n| n.null_count() > 0) {
139                        // case 4.1.1: our target child has nulls and types other than our target are selected, union the masks
140                        // the case where n.null_count() == n.len() is cheaply handled at case 1.3
141                        Some(nulls) => &selected & nulls.inner(),
142                        // case 4.1.2: target child has no nulls, but types other than our target are selected, use the selected mask as a null mask
143                        None => selected,
144                    };
145
146                    //SAFETY: The only change to the array data is the addition of a null mask, and if the target data type can contain a null mask was just checked above
147                    let data = unsafe {
148                        assert_eq!(nulls.len(), target.len());
149
150                        target
151                            .into_data()
152                            .into_builder()
153                            .nulls(Some(nulls.into()))
154                            .build_unchecked()
155                    };
156
157                    Ok(make_array(data))
158                } else {
159                    // case 4.2: target can't containt a null mask, zip the values that match with a null value
160                    Ok(crate::zip::zip(
161                        &BooleanArray::new(selected, None),
162                        target,
163                        &Scalar::new(new_null_array(target.data_type(), 1)),
164                    )?)
165                }
166            }
167        }
168    }
169}
170
171fn extract_dense(
172    union_array: &UnionArray,
173    fields: &UnionFields,
174    target_type_id: i8,
175) -> Result<ArrayRef, ArrowError> {
176    let target = union_array.child(target_type_id);
177    let offsets = union_array.offsets().unwrap();
178
179    if union_array.is_empty() {
180        // case 1: the union is empty
181        if target.is_empty() {
182            // case 1.1: the target is also empty, do a cheap Arc::clone instead of allocating a new empty array
183            Ok(Arc::clone(target))
184        } else {
185            // case 1.2: the target is not empty, allocate a new empty array
186            Ok(new_empty_array(target.data_type()))
187        }
188    } else if target.is_empty() {
189        // case 2: the union is not empty but the target is, which implies that none type_id points to it. The result is a null array
190        Ok(new_null_array(target.data_type(), union_array.len()))
191    } else if target.null_count() == target.len() || target.data_type().is_null() {
192        // case 3: since all values on our target are null, regardless of selected type ids and offsets, the result is a null array
193        match target.len().cmp(&union_array.len()) {
194            // case 3.1: since the target is smaller than the union, allocate a new correclty sized null array
195            Ordering::Less => Ok(new_null_array(target.data_type(), union_array.len())),
196            // case 3.2: target equals the union len, return it direcly
197            Ordering::Equal => Ok(Arc::clone(target)),
198            // case 3.3: target len is bigger than the union len, slice it
199            Ordering::Greater => Ok(target.slice(0, union_array.len())),
200        }
201    } else if fields.len() == 1 // case A: since there's a single field, our target, every type id must matches our target
202        || fields
203            .iter()
204            .filter(|(field_type_id, _)| *field_type_id != target_type_id)
205            .all(|(sibling_type_id, _)| union_array.child(sibling_type_id).is_empty())
206    // case B: since siblings are empty, every type id must matches our target
207    {
208        // case 4: every type id matches our target
209        Ok(extract_dense_all_selected(union_array, target, offsets)?)
210    } else {
211        match eq_scalar(union_array.type_ids(), target_type_id) {
212            // case 4C: all type ids matches our target.
213            // Non empty sibling without any selected value may happen after slicing the parent union,
214            // since only type_ids and offsets are sliced, not the children
215            BoolValue::Scalar(true) => {
216                Ok(extract_dense_all_selected(union_array, target, offsets)?)
217            }
218            BoolValue::Scalar(false) => {
219                // case 5: none type_id matches our target, so the result array will be completely null
220                // Non empty target without any selected value may happen after slicing the parent union,
221                // since only type_ids and offsets are sliced, not the children
222                match (target.len().cmp(&union_array.len()), layout(target.data_type()).can_contain_null_mask) {
223                    (Ordering::Less, _) // case 5.1A: our target is smaller than the parent union, allocate a new correclty sized null array
224                    | (_, false) => { // case 5.1B: target array can't contain a null mask
225                        Ok(new_null_array(target.data_type(), union_array.len()))
226                    }
227                    // case 5.2: target and parent union lengths are equal, and the target can contain a null mask, let's set it to a all-null null-buffer
228                    (Ordering::Equal, true) => {
229                        //SAFETY: The only change to the array data is the addition of a null mask, and if the target data type can contain a null mask was just checked above
230                        let data = unsafe {
231                            target
232                                .into_data()
233                                .into_builder()
234                                .nulls(Some(NullBuffer::new_null(union_array.len())))
235                                .build_unchecked()
236                        };
237
238                        Ok(make_array(data))
239                    }
240                    // case 5.3: target is bigger than it's parent union and can contain a null mask, let's slice it, and set it's nulls to a all-null null-buffer
241                    (Ordering::Greater, true) => {
242                        //SAFETY: The only change to the array data is the addition of a null mask, and if the target data type can contain a null mask was just checked above
243                        let data = unsafe {
244                            target
245                                .into_data()
246                                .slice(0, union_array.len())
247                                .into_builder()
248                                .nulls(Some(NullBuffer::new_null(union_array.len())))
249                                .build_unchecked()
250                        };
251
252                        Ok(make_array(data))
253                    }
254                }
255            }
256            BoolValue::Buffer(selected) => {
257                //case 6: some type_ids matches our target, but not all. For selected values, take the value pointed by the offset. For unselected, use a valid null
258                Ok(take(
259                    target,
260                    &Int32Array::try_new(offsets.clone(), Some(selected.into()))?,
261                    None,
262                )?)
263            }
264        }
265    }
266}
267
268fn extract_dense_all_selected(
269    union_array: &UnionArray,
270    target: &Arc<dyn Array>,
271    offsets: &ScalarBuffer<i32>,
272) -> Result<ArrayRef, ArrowError> {
273    let sequential =
274        target.len() - offsets[0] as usize >= union_array.len() && is_sequential(offsets);
275
276    if sequential && target.len() == union_array.len() {
277        // case 1: all offsets are sequential and both lengths match, return the array directly
278        Ok(Arc::clone(target))
279    } else if sequential && target.len() > union_array.len() {
280        // case 2: All offsets are sequential, but our target is bigger than our union, slice it, starting at the first offset
281        Ok(target.slice(offsets[0] as usize, union_array.len()))
282    } else {
283        // case 3: Since offsets are not sequential, take them from the child to a new sequential and correcly sized array
284        let indices = Int32Array::try_new(offsets.clone(), None)?;
285
286        Ok(take(target, &indices, None)?)
287    }
288}
289
290const EQ_SCALAR_CHUNK_SIZE: usize = 512;
291
292/// The result of checking which type_ids matches the target type_id
293#[derive(Debug, PartialEq)]
294enum BoolValue {
295    /// If true, all type_ids matches the target type_id
296    /// If false, none type_ids matches the target type_id
297    Scalar(bool),
298    /// A mask represeting which type_ids matches the target type_id
299    Buffer(BooleanBuffer),
300}
301
302fn eq_scalar(type_ids: &[i8], target: i8) -> BoolValue {
303    eq_scalar_inner(EQ_SCALAR_CHUNK_SIZE, type_ids, target)
304}
305
/// Counts the leading elements of `type_ids` that belong to chunks (of `chunk_size`
/// elements) in which `f` holds for every element.
///
/// Counting stops at the first chunk containing an element that fails `f`. The result is
/// therefore a multiple of `chunk_size`, unless the run reaches the end of the slice.
/// Every element of each visited chunk is passed to `f`; there is no early exit inside a chunk.
///
/// Panics if `chunk_size` is zero (via `slice::chunks`).
fn count_first_run(chunk_size: usize, type_ids: &[i8], mut f: impl FnMut(i8) -> bool) -> usize {
    let mut run_len = 0;
    for chunk in type_ids.chunks(chunk_size) {
        // Non-short-circuiting `&` keeps the per-chunk evaluation branch-free.
        let chunk_matches = chunk.iter().fold(true, |all, &v| all & f(v));
        if !chunk_matches {
            break;
        }
        run_len += chunk.len();
    }
    run_len
}
313
314// This is like MutableBuffer::collect_bool(type_ids.len(), |i| type_ids[i] == target) with fast paths for all true or all false values.
315fn eq_scalar_inner(chunk_size: usize, type_ids: &[i8], target: i8) -> BoolValue {
316    let true_bits = count_first_run(chunk_size, type_ids, |v| v == target);
317
318    let (set_bits, val) = if true_bits == type_ids.len() {
319        return BoolValue::Scalar(true);
320    } else if true_bits == 0 {
321        let false_bits = count_first_run(chunk_size, type_ids, |v| v != target);
322
323        if false_bits == type_ids.len() {
324            return BoolValue::Scalar(false);
325        } else {
326            (false_bits, false)
327        }
328    } else {
329        (true_bits, true)
330    };
331
332    // restrict to chunk boundaries
333    let set_bits = set_bits - set_bits % 64;
334
335    let mut buffer =
336        MutableBuffer::new(bit_util::ceil(type_ids.len(), 8)).with_bitset(set_bits / 8, val);
337
338    buffer.extend(type_ids[set_bits..].chunks(64).map(|chunk| {
339        chunk
340            .iter()
341            .copied()
342            .enumerate()
343            .fold(0, |packed, (bit_idx, v)| {
344                packed | (((v == target) as u64) << bit_idx)
345            })
346    }));
347
348    BoolValue::Buffer(BooleanBuffer::new(buffer.into(), 0, type_ids.len()))
349}
350
351const IS_SEQUENTIAL_CHUNK_SIZE: usize = 64;
352
353fn is_sequential(offsets: &[i32]) -> bool {
354    is_sequential_generic::<IS_SEQUENTIAL_CHUNK_SIZE>(offsets)
355}
356
/// Returns `true` when `offsets` holds consecutive integers, i.e. `offsets[i] == offsets[0] + i`
/// for every index `i`. An empty slice is considered sequential.
///
/// The check runs over chunks of exactly `N` offsets, followed by the shorter remainder.
/// Within a chunk every comparison is evaluated and combined with a non-short-circuiting `&`,
/// which keeps the loop body free of early exits.
///
/// The arithmetic cannot overflow. Endpoint and chunk-start comparisons are widened to `i64`,
/// and within-chunk comparisons use wrapping `i32` addition, which is exact once those checks
/// pass. Offsets close to `i32::MAX` (e.g. `[i32::MAX - 1, i32::MAX]`) therefore no longer
/// panic in debug builds.
fn is_sequential_generic<const N: usize>(offsets: &[i32]) -> bool {
    let (Some(&first), Some(&last)) = (offsets.first(), offsets.last()) else {
        return true;
    };
    let first = i64::from(first);

    // Cheap rejection: a sequential run of `len` offsets must end at `first + len - 1`.
    // This quickly rejects a common dense-union pattern where consecutive null slots share a
    // single child value, e.g. offsets `0 1 1 2 3`. It also pins the final chunk/remainder
    // relative to the first offset. Once it passes, every `first + k` with `k < len` lies in
    // `first..=last` and therefore fits in an `i32`.
    if first + offsets.len() as i64 - 1 != i64::from(last) {
        return false;
    }

    let chunks = offsets.chunks_exact(N);
    let remainder = chunks.remainder();

    // `wrapping_add` cannot cause a false positive. A chunk is accepted only when its first
    // value is exactly `first + i * N` (checked in i64). Then every `start + j` lies within
    // `first..=last` and does not wrap.
    let chunks_sequential = chunks.enumerate().all(|(i, chunk)| {
        let chunk = <&[i32; N]>::try_from(chunk).unwrap();
        let start = chunk[0];

        // values within the chunk increase one by one
        chunk
            .iter()
            .enumerate()
            .fold(true, |acc, (j, &offset)| {
                acc & (offset == start.wrapping_add(j as i32))
            })
            // the chunk starts where it must, relative to the first offset
            && i64::from(start) == first + (i * N) as i64
    });

    // The remainder ends at `last` (checked above), so being internally consecutive places it
    // correctly. `wrapping_add` is safe here because `last - (remainder.len() - 1)` is itself a
    // valid i32, so the wrapped comparison pins `remainder[0]` to exactly that value.
    chunks_sequential
        && remainder
            .iter()
            .enumerate()
            .fold(true, |acc, (j, &offset)| {
                acc & (offset == remainder[0].wrapping_add(j as i32))
            })
}
399
400#[cfg(test)]
401mod tests {
402    use super::{BoolValue, eq_scalar_inner, is_sequential_generic, union_extract};
403    use arrow_array::{Array, Int32Array, NullArray, StringArray, UnionArray, new_null_array};
404    use arrow_buffer::{BooleanBuffer, ScalarBuffer};
405    use arrow_schema::{ArrowError, DataType, Field, UnionFields, UnionMode};
406    use std::sync::Arc;
407
408    #[test]
409    fn test_eq_scalar() {
410        //multiple all equal chunks, so it's loop and sum logic it's tested
411        //multiple chunks after, so it's loop logic it's tested
412        const ARRAY_LEN: usize = 64 * 4;
413
414        //so out of 64 boundaries chunks can be generated and checked for
415        const EQ_SCALAR_CHUNK_SIZE: usize = 3;
416
417        fn eq_scalar(type_ids: &[i8], target: i8) -> BoolValue {
418            eq_scalar_inner(EQ_SCALAR_CHUNK_SIZE, type_ids, target)
419        }
420
421        fn cross_check(left: &[i8], right: i8) -> BooleanBuffer {
422            BooleanBuffer::collect_bool(left.len(), |i| left[i] == right)
423        }
424
425        assert_eq!(eq_scalar(&[], 1), BoolValue::Scalar(true));
426
427        assert_eq!(eq_scalar(&[1], 1), BoolValue::Scalar(true));
428        assert_eq!(eq_scalar(&[2], 1), BoolValue::Scalar(false));
429
430        let mut values = [1; ARRAY_LEN];
431
432        assert_eq!(eq_scalar(&values, 1), BoolValue::Scalar(true));
433        assert_eq!(eq_scalar(&values, 2), BoolValue::Scalar(false));
434
435        //every subslice should return the same value
436        for i in 1..ARRAY_LEN {
437            assert_eq!(eq_scalar(&values[..i], 1), BoolValue::Scalar(true));
438            assert_eq!(eq_scalar(&values[..i], 2), BoolValue::Scalar(false));
439        }
440
441        // test that a single change anywhere is checked for
442        for i in 0..ARRAY_LEN {
443            values[i] = 2;
444
445            assert_eq!(
446                eq_scalar(&values, 1),
447                BoolValue::Buffer(cross_check(&values, 1))
448            );
449            assert_eq!(
450                eq_scalar(&values, 2),
451                BoolValue::Buffer(cross_check(&values, 2))
452            );
453
454            values[i] = 1;
455        }
456    }
457
458    #[test]
459    fn test_is_sequential() {
460        /*
461        the smallest value that satisfies:
462        >1 so the fold logic of a exact chunk executes
463        >2 so a >1 non-exact remainder can exist, and it's fold logic executes
464         */
465        const CHUNK_SIZE: usize = 3;
466        //we test arrays of size up to 8 = 2 * CHUNK_SIZE + 2:
467        //multiple(2) exact chunks, so the AND logic between them executes
468        //a >1(2) remainder, so:
469        //    the AND logic between all exact chunks and the remainder executes
470        //    the remainder fold logic executes
471
472        fn is_sequential(v: &[i32]) -> bool {
473            is_sequential_generic::<CHUNK_SIZE>(v)
474        }
475
476        assert!(is_sequential(&[])); //empty
477        assert!(is_sequential(&[1])); //single
478
479        assert!(is_sequential(&[1, 2]));
480        assert!(is_sequential(&[1, 2, 3]));
481        assert!(is_sequential(&[1, 2, 3, 4]));
482        assert!(is_sequential(&[1, 2, 3, 4, 5]));
483        assert!(is_sequential(&[1, 2, 3, 4, 5, 6]));
484        assert!(is_sequential(&[1, 2, 3, 4, 5, 6, 7]));
485        assert!(is_sequential(&[1, 2, 3, 4, 5, 6, 7, 8]));
486
487        assert!(!is_sequential(&[8, 7]));
488        assert!(!is_sequential(&[8, 7, 6]));
489        assert!(!is_sequential(&[8, 7, 6, 5]));
490        assert!(!is_sequential(&[8, 7, 6, 5, 4]));
491        assert!(!is_sequential(&[8, 7, 6, 5, 4, 3]));
492        assert!(!is_sequential(&[8, 7, 6, 5, 4, 3, 2]));
493        assert!(!is_sequential(&[8, 7, 6, 5, 4, 3, 2, 1]));
494
495        assert!(!is_sequential(&[0, 2]));
496        assert!(!is_sequential(&[1, 0]));
497
498        assert!(!is_sequential(&[0, 2, 3]));
499        assert!(!is_sequential(&[1, 0, 3]));
500        assert!(!is_sequential(&[1, 2, 0]));
501
502        assert!(!is_sequential(&[0, 2, 3, 4]));
503        assert!(!is_sequential(&[1, 0, 3, 4]));
504        assert!(!is_sequential(&[1, 2, 0, 4]));
505        assert!(!is_sequential(&[1, 2, 3, 0]));
506
507        assert!(!is_sequential(&[0, 2, 3, 4, 5]));
508        assert!(!is_sequential(&[1, 0, 3, 4, 5]));
509        assert!(!is_sequential(&[1, 2, 0, 4, 5]));
510        assert!(!is_sequential(&[1, 2, 3, 0, 5]));
511        assert!(!is_sequential(&[1, 2, 3, 4, 0]));
512
513        assert!(!is_sequential(&[0, 2, 3, 4, 5, 6]));
514        assert!(!is_sequential(&[1, 0, 3, 4, 5, 6]));
515        assert!(!is_sequential(&[1, 2, 0, 4, 5, 6]));
516        assert!(!is_sequential(&[1, 2, 3, 0, 5, 6]));
517        assert!(!is_sequential(&[1, 2, 3, 4, 0, 6]));
518        assert!(!is_sequential(&[1, 2, 3, 4, 5, 0]));
519
520        assert!(!is_sequential(&[0, 2, 3, 4, 5, 6, 7]));
521        assert!(!is_sequential(&[1, 0, 3, 4, 5, 6, 7]));
522        assert!(!is_sequential(&[1, 2, 0, 4, 5, 6, 7]));
523        assert!(!is_sequential(&[1, 2, 3, 0, 5, 6, 7]));
524        assert!(!is_sequential(&[1, 2, 3, 4, 0, 6, 7]));
525        assert!(!is_sequential(&[1, 2, 3, 4, 5, 0, 7]));
526        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 0]));
527
528        assert!(!is_sequential(&[0, 2, 3, 4, 5, 6, 7, 8]));
529        assert!(!is_sequential(&[1, 0, 3, 4, 5, 6, 7, 8]));
530        assert!(!is_sequential(&[1, 2, 0, 4, 5, 6, 7, 8]));
531        assert!(!is_sequential(&[1, 2, 3, 0, 5, 6, 7, 8]));
532        assert!(!is_sequential(&[1, 2, 3, 4, 0, 6, 7, 8]));
533        assert!(!is_sequential(&[1, 2, 3, 4, 5, 0, 7, 8]));
534        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 0, 8]));
535        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 7, 0]));
536
537        // checks increments at the chunk boundary
538        assert!(!is_sequential(&[1, 2, 3, 5]));
539        assert!(!is_sequential(&[1, 2, 3, 5, 6]));
540        assert!(!is_sequential(&[1, 2, 3, 5, 6, 7]));
541        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 8]));
542        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 8, 9]));
543    }
544
545    fn str1() -> UnionFields {
546        UnionFields::try_new(vec![1], vec![Field::new("str", DataType::Utf8, true)]).unwrap()
547    }
548
549    fn str1_int3() -> UnionFields {
550        UnionFields::try_new(
551            vec![1, 3],
552            vec![
553                Field::new("str", DataType::Utf8, true),
554                Field::new("int", DataType::Int32, true),
555            ],
556        )
557        .unwrap()
558    }
559
560    #[test]
561    fn sparse_1_1_single_field() {
562        let union = UnionArray::try_new(
563            //single field
564            str1(),
565            ScalarBuffer::from(vec![1, 1]), // non empty, every type id must match
566            None,                           //sparse
567            vec![
568                Arc::new(StringArray::from(vec!["a", "b"])), // not null
569            ],
570        )
571        .unwrap();
572
573        let expected = StringArray::from(vec!["a", "b"]);
574        let extracted = union_extract(&union, "str").unwrap();
575
576        assert_eq!(extracted.into_data(), expected.into_data());
577    }
578
579    #[test]
580    fn sparse_1_2_empty() {
581        let union = UnionArray::try_new(
582            // multiple fields
583            str1_int3(),
584            ScalarBuffer::from(vec![]), //empty union
585            None,                       // sparse
586            vec![
587                Arc::new(StringArray::new_null(0)),
588                Arc::new(Int32Array::new_null(0)),
589            ],
590        )
591        .unwrap();
592
593        let expected = StringArray::new_null(0);
594        let extracted = union_extract(&union, "str").unwrap(); //target type is not Null
595
596        assert_eq!(extracted.into_data(), expected.into_data());
597    }
598
599    #[test]
600    fn sparse_1_3a_null_target() {
601        let union = UnionArray::try_new(
602            // multiple fields
603            UnionFields::try_new(
604                vec![1, 3],
605                vec![
606                    Field::new("str", DataType::Utf8, true),
607                    Field::new("null", DataType::Null, true), // target type is Null
608                ],
609            )
610            .unwrap(),
611            ScalarBuffer::from(vec![1]), //not empty
612            None,                        // sparse
613            vec![
614                Arc::new(StringArray::new_null(1)),
615                Arc::new(NullArray::new(1)), // null data type
616            ],
617        )
618        .unwrap();
619
620        let expected = NullArray::new(1);
621        let extracted = union_extract(&union, "null").unwrap();
622
623        assert_eq!(extracted.into_data(), expected.into_data());
624    }
625
626    #[test]
627    fn sparse_1_3b_null_target() {
628        let union = UnionArray::try_new(
629            // multiple fields
630            str1_int3(),
631            ScalarBuffer::from(vec![1]), //not empty
632            None,                        // sparse
633            vec![
634                Arc::new(StringArray::new_null(1)), //all null
635                Arc::new(Int32Array::new_null(1)),
636            ],
637        )
638        .unwrap();
639
640        let expected = StringArray::new_null(1);
641        let extracted = union_extract(&union, "str").unwrap(); //target type is not Null
642
643        assert_eq!(extracted.into_data(), expected.into_data());
644    }
645
646    #[test]
647    fn sparse_2_all_types_match() {
648        let union = UnionArray::try_new(
649            //multiple fields
650            str1_int3(),
651            ScalarBuffer::from(vec![3, 3]), // all types match
652            None,                           //sparse
653            vec![
654                Arc::new(StringArray::new_null(2)),
655                Arc::new(Int32Array::from(vec![1, 4])), // not null
656            ],
657        )
658        .unwrap();
659
660        let expected = Int32Array::from(vec![1, 4]);
661        let extracted = union_extract(&union, "int").unwrap();
662
663        assert_eq!(extracted.into_data(), expected.into_data());
664    }
665
666    #[test]
667    fn sparse_3_1_none_match_target_can_contain_null_mask() {
668        let union = UnionArray::try_new(
669            //multiple fields
670            str1_int3(),
671            ScalarBuffer::from(vec![1, 1, 1, 1]), // none match
672            None,                                 // sparse
673            vec![
674                Arc::new(StringArray::new_null(4)),
675                Arc::new(Int32Array::from(vec![None, Some(4), None, Some(8)])), // target is not null
676            ],
677        )
678        .unwrap();
679
680        let expected = Int32Array::new_null(4);
681        let extracted = union_extract(&union, "int").unwrap();
682
683        assert_eq!(extracted.into_data(), expected.into_data());
684    }
685
686    fn str1_union3(union3_datatype: DataType) -> UnionFields {
687        UnionFields::try_new(
688            vec![1, 3],
689            vec![
690                Field::new("str", DataType::Utf8, true),
691                Field::new("union", union3_datatype, true),
692            ],
693        )
694        .unwrap()
695    }
696
697    #[test]
698    fn sparse_3_2_none_match_cant_contain_null_mask_union_target() {
699        let target_fields = str1();
700        let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse);
701
702        let union = UnionArray::try_new(
703            //multiple fields
704            str1_union3(target_type.clone()),
705            ScalarBuffer::from(vec![1, 1]), // none match
706            None,                           //sparse
707            vec![
708                Arc::new(StringArray::new_null(2)),
709                //target is not null
710                Arc::new(
711                    UnionArray::try_new(
712                        target_fields.clone(),
713                        ScalarBuffer::from(vec![1, 1]),
714                        None,
715                        vec![Arc::new(StringArray::from(vec!["a", "b"]))],
716                    )
717                    .unwrap(),
718                ),
719            ],
720        )
721        .unwrap();
722
723        let expected = new_null_array(&target_type, 2);
724        let extracted = union_extract(&union, "union").unwrap();
725
726        assert_eq!(extracted.into_data(), expected.into_data());
727    }
728
729    #[test]
730    fn sparse_4_1_1_target_with_nulls() {
731        let union = UnionArray::try_new(
732            //multiple fields
733            str1_int3(),
734            ScalarBuffer::from(vec![3, 3, 1, 1]), // multiple selected types
735            None,                                 // sparse
736            vec![
737                Arc::new(StringArray::new_null(4)),
738                Arc::new(Int32Array::from(vec![None, Some(4), None, Some(8)])), // target with nulls
739            ],
740        )
741        .unwrap();
742
743        let expected = Int32Array::from(vec![None, Some(4), None, None]);
744        let extracted = union_extract(&union, "int").unwrap();
745
746        assert_eq!(extracted.into_data(), expected.into_data());
747    }
748
749    #[test]
750    fn sparse_4_1_2_target_without_nulls() {
751        let union = UnionArray::try_new(
752            //multiple fields
753            str1_int3(),
754            ScalarBuffer::from(vec![1, 3, 3]), // multiple selected types
755            None,                              // sparse
756            vec![
757                Arc::new(StringArray::new_null(3)),
758                Arc::new(Int32Array::from(vec![2, 4, 8])), // target without nulls
759            ],
760        )
761        .unwrap();
762
763        let expected = Int32Array::from(vec![None, Some(4), Some(8)]);
764        let extracted = union_extract(&union, "int").unwrap();
765
766        assert_eq!(extracted.into_data(), expected.into_data());
767    }
768
769    #[test]
770    fn sparse_4_2_some_match_target_cant_contain_null_mask() {
771        let target_fields = str1();
772        let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse);
773
774        let union = UnionArray::try_new(
775            //multiple fields
776            str1_union3(target_type),
777            ScalarBuffer::from(vec![3, 1]), // some types match, but not all
778            None,                           //sparse
779            vec![
780                Arc::new(StringArray::new_null(2)),
781                Arc::new(
782                    UnionArray::try_new(
783                        target_fields.clone(),
784                        ScalarBuffer::from(vec![1, 1]),
785                        None,
786                        vec![Arc::new(StringArray::from(vec!["a", "b"]))],
787                    )
788                    .unwrap(),
789                ),
790            ],
791        )
792        .unwrap();
793
794        let expected = UnionArray::try_new(
795            target_fields,
796            ScalarBuffer::from(vec![1, 1]),
797            None,
798            vec![Arc::new(StringArray::from(vec![Some("a"), None]))],
799        )
800        .unwrap();
801        let extracted = union_extract(&union, "union").unwrap();
802
803        assert_eq!(extracted.into_data(), expected.into_data());
804    }
805
806    #[test]
807    fn dense_1_1_both_empty() {
808        let union = UnionArray::try_new(
809            str1_int3(),
810            ScalarBuffer::from(vec![]),       //empty union
811            Some(ScalarBuffer::from(vec![])), // dense
812            vec![
813                Arc::new(StringArray::new_null(0)), //empty target
814                Arc::new(Int32Array::new_null(0)),
815            ],
816        )
817        .unwrap();
818
819        let expected = StringArray::new_null(0);
820        let extracted = union_extract(&union, "str").unwrap();
821
822        assert_eq!(extracted.into_data(), expected.into_data());
823    }
824
825    #[test]
826    fn dense_1_2_empty_union_target_non_empty() {
827        let union = UnionArray::try_new(
828            str1_int3(),
829            ScalarBuffer::from(vec![]),       //empty union
830            Some(ScalarBuffer::from(vec![])), // dense
831            vec![
832                Arc::new(StringArray::new_null(1)), //non empty target
833                Arc::new(Int32Array::new_null(0)),
834            ],
835        )
836        .unwrap();
837
838        let expected = StringArray::new_null(0);
839        let extracted = union_extract(&union, "str").unwrap();
840
841        assert_eq!(extracted.into_data(), expected.into_data());
842    }
843
844    #[test]
845    fn dense_2_non_empty_union_target_empty() {
846        let union = UnionArray::try_new(
847            str1_int3(),
848            ScalarBuffer::from(vec![3, 3]),       //non empty union
849            Some(ScalarBuffer::from(vec![0, 1])), // dense
850            vec![
851                Arc::new(StringArray::new_null(0)), //empty target
852                Arc::new(Int32Array::new_null(2)),
853            ],
854        )
855        .unwrap();
856
857        let expected = StringArray::new_null(2);
858        let extracted = union_extract(&union, "str").unwrap();
859
860        assert_eq!(extracted.into_data(), expected.into_data());
861    }
862
863    #[test]
864    fn dense_3_1_null_target_smaller_len() {
865        let union = UnionArray::try_new(
866            str1_int3(),
867            ScalarBuffer::from(vec![3, 3]),       //non empty union
868            Some(ScalarBuffer::from(vec![0, 0])), //dense
869            vec![
870                Arc::new(StringArray::new_null(1)), //smaller target
871                Arc::new(Int32Array::new_null(2)),
872            ],
873        )
874        .unwrap();
875
876        let expected = StringArray::new_null(2);
877        let extracted = union_extract(&union, "str").unwrap();
878
879        assert_eq!(extracted.into_data(), expected.into_data());
880    }
881
882    #[test]
883    fn dense_3_2_null_target_equal_len() {
884        let union = UnionArray::try_new(
885            str1_int3(),
886            ScalarBuffer::from(vec![3, 3]),       //non empty union
887            Some(ScalarBuffer::from(vec![0, 0])), //dense
888            vec![
889                Arc::new(StringArray::new_null(2)), //equal len
890                Arc::new(Int32Array::new_null(2)),
891            ],
892        )
893        .unwrap();
894
895        let expected = StringArray::new_null(2);
896        let extracted = union_extract(&union, "str").unwrap();
897
898        assert_eq!(extracted.into_data(), expected.into_data());
899    }
900
901    #[test]
902    fn dense_3_3_null_target_bigger_len() {
903        let union = UnionArray::try_new(
904            str1_int3(),
905            ScalarBuffer::from(vec![3, 3]),       //non empty union
906            Some(ScalarBuffer::from(vec![0, 0])), //dense
907            vec![
908                Arc::new(StringArray::new_null(3)), //bigger len
909                Arc::new(Int32Array::new_null(3)),
910            ],
911        )
912        .unwrap();
913
914        let expected = StringArray::new_null(2);
915        let extracted = union_extract(&union, "str").unwrap();
916
917        assert_eq!(extracted.into_data(), expected.into_data());
918    }
919
920    #[test]
921    fn dense_4_1a_single_type_sequential_offsets_equal_len() {
922        let union = UnionArray::try_new(
923            // single field
924            str1(),
925            ScalarBuffer::from(vec![1, 1]),       //non empty union
926            Some(ScalarBuffer::from(vec![0, 1])), //sequential
927            vec![
928                Arc::new(StringArray::from(vec!["a1", "b2"])), //equal len, non null
929            ],
930        )
931        .unwrap();
932
933        let expected = StringArray::from(vec!["a1", "b2"]);
934        let extracted = union_extract(&union, "str").unwrap();
935
936        assert_eq!(extracted.into_data(), expected.into_data());
937    }
938
939    #[test]
940    fn dense_4_2a_single_type_sequential_offsets_bigger() {
941        let union = UnionArray::try_new(
942            // single field
943            str1(),
944            ScalarBuffer::from(vec![1, 1]),       //non empty union
945            Some(ScalarBuffer::from(vec![0, 1])), //sequential
946            vec![
947                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), //equal len, non null
948            ],
949        )
950        .unwrap();
951
952        let expected = StringArray::from(vec!["a1", "b2"]);
953        let extracted = union_extract(&union, "str").unwrap();
954
955        assert_eq!(extracted.into_data(), expected.into_data());
956    }
957
958    #[test]
959    fn dense_4_3a_single_type_non_sequential() {
960        let union = UnionArray::try_new(
961            // single field
962            str1(),
963            ScalarBuffer::from(vec![1, 1]),       //non empty union
964            Some(ScalarBuffer::from(vec![0, 2])), //non sequential
965            vec![
966                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), //equal len, non null
967            ],
968        )
969        .unwrap();
970
971        let expected = StringArray::from(vec!["a1", "c3"]);
972        let extracted = union_extract(&union, "str").unwrap();
973
974        assert_eq!(extracted.into_data(), expected.into_data());
975    }
976
977    #[test]
978    fn dense_4_1b_empty_siblings_sequential_equal_len() {
979        let union = UnionArray::try_new(
980            // multiple fields
981            str1_int3(),
982            ScalarBuffer::from(vec![1, 1]),       //non empty union
983            Some(ScalarBuffer::from(vec![0, 1])), //sequential
984            vec![
985                Arc::new(StringArray::from(vec!["a", "b"])), //equal len, non null
986                Arc::new(Int32Array::new_null(0)),           //empty sibling
987            ],
988        )
989        .unwrap();
990
991        let expected = StringArray::from(vec!["a", "b"]);
992        let extracted = union_extract(&union, "str").unwrap();
993
994        assert_eq!(extracted.into_data(), expected.into_data());
995    }
996
997    #[test]
998    fn dense_4_2b_empty_siblings_sequential_bigger_len() {
999        let union = UnionArray::try_new(
1000            // multiple fields
1001            str1_int3(),
1002            ScalarBuffer::from(vec![1, 1]),       //non empty union
1003            Some(ScalarBuffer::from(vec![0, 1])), //sequential
1004            vec![
1005                Arc::new(StringArray::from(vec!["a", "b", "c"])), //bigger len, non null
1006                Arc::new(Int32Array::new_null(0)),                //empty sibling
1007            ],
1008        )
1009        .unwrap();
1010
1011        let expected = StringArray::from(vec!["a", "b"]);
1012        let extracted = union_extract(&union, "str").unwrap();
1013
1014        assert_eq!(extracted.into_data(), expected.into_data());
1015    }
1016
1017    #[test]
1018    fn dense_4_3b_empty_sibling_non_sequential() {
1019        let union = UnionArray::try_new(
1020            // multiple fields
1021            str1_int3(),
1022            ScalarBuffer::from(vec![1, 1]),       //non empty union
1023            Some(ScalarBuffer::from(vec![0, 2])), //non sequential
1024            vec![
1025                Arc::new(StringArray::from(vec!["a", "b", "c"])), //non null
1026                Arc::new(Int32Array::new_null(0)),                //empty sibling
1027            ],
1028        )
1029        .unwrap();
1030
1031        let expected = StringArray::from(vec!["a", "c"]);
1032        let extracted = union_extract(&union, "str").unwrap();
1033
1034        assert_eq!(extracted.into_data(), expected.into_data());
1035    }
1036
1037    #[test]
1038    fn dense_4_1c_all_types_match_sequential_equal_len() {
1039        let union = UnionArray::try_new(
1040            // multiple fields
1041            str1_int3(),
1042            ScalarBuffer::from(vec![1, 1]),       //all types match
1043            Some(ScalarBuffer::from(vec![0, 1])), //sequential
1044            vec![
1045                Arc::new(StringArray::from(vec!["a1", "b2"])), //equal len
1046                Arc::new(Int32Array::new_null(2)),             //non empty sibling
1047            ],
1048        )
1049        .unwrap();
1050
1051        let expected = StringArray::from(vec!["a1", "b2"]);
1052        let extracted = union_extract(&union, "str").unwrap();
1053
1054        assert_eq!(extracted.into_data(), expected.into_data());
1055    }
1056
1057    #[test]
1058    fn dense_4_2c_all_types_match_sequential_bigger_len() {
1059        let union = UnionArray::try_new(
1060            // multiple fields
1061            str1_int3(),
1062            ScalarBuffer::from(vec![1, 1]),       //all types match
1063            Some(ScalarBuffer::from(vec![0, 1])), //sequential
1064            vec![
1065                Arc::new(StringArray::from(vec!["a1", "b2", "b3"])), //bigger len
1066                Arc::new(Int32Array::new_null(2)),                   //non empty sibling
1067            ],
1068        )
1069        .unwrap();
1070
1071        let expected = StringArray::from(vec!["a1", "b2"]);
1072        let extracted = union_extract(&union, "str").unwrap();
1073
1074        assert_eq!(extracted.into_data(), expected.into_data());
1075    }
1076
1077    #[test]
1078    fn dense_4_3c_all_types_match_non_sequential() {
1079        let union = UnionArray::try_new(
1080            // multiple fields
1081            str1_int3(),
1082            ScalarBuffer::from(vec![1, 1]),       //all types match
1083            Some(ScalarBuffer::from(vec![0, 2])), //non sequential
1084            vec![
1085                Arc::new(StringArray::from(vec!["a1", "b2", "b3"])),
1086                Arc::new(Int32Array::new_null(2)), //non empty sibling
1087            ],
1088        )
1089        .unwrap();
1090
1091        let expected = StringArray::from(vec!["a1", "b3"]);
1092        let extracted = union_extract(&union, "str").unwrap();
1093
1094        assert_eq!(extracted.into_data(), expected.into_data());
1095    }
1096
1097    #[test]
1098    fn dense_5_1a_none_match_less_len() {
1099        let union = UnionArray::try_new(
1100            // multiple fields
1101            str1_int3(),
1102            ScalarBuffer::from(vec![3, 3, 3, 3, 3]), //none matches
1103            Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), // dense
1104            vec![
1105                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), // less len
1106                Arc::new(Int32Array::from(vec![1, 2])),
1107            ],
1108        )
1109        .unwrap();
1110
1111        let expected = StringArray::new_null(5);
1112        let extracted = union_extract(&union, "str").unwrap();
1113
1114        assert_eq!(extracted.into_data(), expected.into_data());
1115    }
1116
1117    #[test]
1118    fn dense_5_1b_cant_contain_null_mask() {
1119        let target_fields = str1();
1120        let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse);
1121
1122        let union = UnionArray::try_new(
1123            // multiple fields
1124            str1_union3(target_type.clone()),
1125            ScalarBuffer::from(vec![1, 1, 1, 1, 1]), //none matches
1126            Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), // dense
1127            vec![
1128                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), // less len
1129                Arc::new(
1130                    UnionArray::try_new(
1131                        target_fields.clone(),
1132                        ScalarBuffer::from(vec![1]),
1133                        None,
1134                        vec![Arc::new(StringArray::from(vec!["a"]))],
1135                    )
1136                    .unwrap(),
1137                ), // non empty
1138            ],
1139        )
1140        .unwrap();
1141
1142        let expected = new_null_array(&target_type, 5);
1143        let extracted = union_extract(&union, "union").unwrap();
1144
1145        assert_eq!(extracted.into_data(), expected.into_data());
1146    }
1147
1148    #[test]
1149    fn dense_5_2_none_match_equal_len() {
1150        let union = UnionArray::try_new(
1151            // multiple fields
1152            str1_int3(),
1153            ScalarBuffer::from(vec![3, 3, 3, 3, 3]), //none matches
1154            Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), // dense
1155            vec![
1156                Arc::new(StringArray::from(vec!["a1", "b2", "c3", "d4", "e5"])), // equal len
1157                Arc::new(Int32Array::from(vec![1, 2])),
1158            ],
1159        )
1160        .unwrap();
1161
1162        let expected = StringArray::new_null(5);
1163        let extracted = union_extract(&union, "str").unwrap();
1164
1165        assert_eq!(extracted.into_data(), expected.into_data());
1166    }
1167
1168    #[test]
1169    fn dense_5_3_none_match_greater_len() {
1170        let union = UnionArray::try_new(
1171            // multiple fields
1172            str1_int3(),
1173            ScalarBuffer::from(vec![3, 3, 3, 3, 3]), //none matches
1174            Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), // dense
1175            vec![
1176                Arc::new(StringArray::from(vec!["a1", "b2", "c3", "d4", "e5", "f6"])), // greater len
1177                Arc::new(Int32Array::from(vec![1, 2])),                                //non null
1178            ],
1179        )
1180        .unwrap();
1181
1182        let expected = StringArray::new_null(5);
1183        let extracted = union_extract(&union, "str").unwrap();
1184
1185        assert_eq!(extracted.into_data(), expected.into_data());
1186    }
1187
1188    #[test]
1189    fn dense_6_some_matches() {
1190        let union = UnionArray::try_new(
1191            // multiple fields
1192            str1_int3(),
1193            ScalarBuffer::from(vec![3, 3, 1, 1, 1]), //some matches
1194            Some(ScalarBuffer::from(vec![0, 1, 0, 1, 2])), // dense
1195            vec![
1196                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), // non null
1197                Arc::new(Int32Array::from(vec![1, 2])),
1198            ],
1199        )
1200        .unwrap();
1201
1202        let expected = Int32Array::from(vec![Some(1), Some(2), None, None, None]);
1203        let extracted = union_extract(&union, "int").unwrap();
1204
1205        assert_eq!(extracted.into_data(), expected.into_data());
1206    }
1207
1208    #[test]
1209    fn empty_sparse_union() {
1210        let union = UnionArray::try_new(
1211            UnionFields::empty(),
1212            ScalarBuffer::from(vec![]),
1213            None,
1214            vec![],
1215        )
1216        .unwrap();
1217
1218        assert_eq!(
1219            union_extract(&union, "a").unwrap_err().to_string(),
1220            ArrowError::InvalidArgumentError("field a not found on union".into()).to_string()
1221        );
1222    }
1223
1224    #[test]
1225    fn empty_dense_union() {
1226        let union = UnionArray::try_new(
1227            UnionFields::empty(),
1228            ScalarBuffer::from(vec![]),
1229            Some(ScalarBuffer::from(vec![])),
1230            vec![],
1231        )
1232        .unwrap();
1233
1234        assert_eq!(
1235            union_extract(&union, "a").unwrap_err().to_string(),
1236            ArrowError::InvalidArgumentError("field a not found on union".into()).to_string()
1237        );
1238    }
1239}