1use crate::take::take;
21use arrow_array::{
22    Array, ArrayRef, BooleanArray, Int32Array, Scalar, UnionArray, make_array, new_empty_array,
23    new_null_array,
24};
25use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer, ScalarBuffer, bit_util};
26use arrow_data::layout;
27use arrow_schema::{ArrowError, DataType, UnionFields};
28use std::cmp::Ordering;
29use std::sync::Arc;
30
31pub fn union_extract(union_array: &UnionArray, target: &str) -> Result<ArrayRef, ArrowError> {
80    let fields = match union_array.data_type() {
81        DataType::Union(fields, _) => fields,
82        _ => unreachable!(),
83    };
84
85    let (target_type_id, _) = fields
86        .iter()
87        .find(|field| field.1.name() == target)
88        .ok_or_else(|| {
89            ArrowError::InvalidArgumentError(format!("field {target} not found on union"))
90        })?;
91
92    match union_array.offsets() {
93        Some(_) => extract_dense(union_array, fields, target_type_id),
94        None => extract_sparse(union_array, fields, target_type_id),
95    }
96}
97
98fn extract_sparse(
99    union_array: &UnionArray,
100    fields: &UnionFields,
101    target_type_id: i8,
102) -> Result<ArrayRef, ArrowError> {
103    let target = union_array.child(target_type_id);
104
105    if fields.len() == 1 || union_array.is_empty() || target.null_count() == target.len() || target.data_type().is_null()
108    {
110        Ok(Arc::clone(target))
111    } else {
112        match eq_scalar(union_array.type_ids(), target_type_id) {
113            BoolValue::Scalar(true) => Ok(Arc::clone(target)),
115            BoolValue::Scalar(false) => {
117                if layout(target.data_type()).can_contain_null_mask {
118                    let data = unsafe {
121                        target
122                            .into_data()
123                            .into_builder()
124                            .nulls(Some(NullBuffer::new_null(target.len())))
125                            .build_unchecked()
126                    };
127
128                    Ok(make_array(data))
129                } else {
130                    Ok(new_null_array(target.data_type(), target.len()))
132                }
133            }
134            BoolValue::Buffer(selected) => {
136                if layout(target.data_type()).can_contain_null_mask {
137                    let nulls = match target.nulls().filter(|n| n.null_count() > 0) {
139                        Some(nulls) => &selected & nulls.inner(),
142                        None => selected,
144                    };
145
146                    let data = unsafe {
148                        assert_eq!(nulls.len(), target.len());
149
150                        target
151                            .into_data()
152                            .into_builder()
153                            .nulls(Some(nulls.into()))
154                            .build_unchecked()
155                    };
156
157                    Ok(make_array(data))
158                } else {
159                    Ok(crate::zip::zip(
161                        &BooleanArray::new(selected, None),
162                        target,
163                        &Scalar::new(new_null_array(target.data_type(), 1)),
164                    )?)
165                }
166            }
167        }
168    }
169}
170
171fn extract_dense(
172    union_array: &UnionArray,
173    fields: &UnionFields,
174    target_type_id: i8,
175) -> Result<ArrayRef, ArrowError> {
176    let target = union_array.child(target_type_id);
177    let offsets = union_array.offsets().unwrap();
178
179    if union_array.is_empty() {
180        if target.is_empty() {
182            Ok(Arc::clone(target))
184        } else {
185            Ok(new_empty_array(target.data_type()))
187        }
188    } else if target.is_empty() {
189        Ok(new_null_array(target.data_type(), union_array.len()))
191    } else if target.null_count() == target.len() || target.data_type().is_null() {
192        match target.len().cmp(&union_array.len()) {
194            Ordering::Less => Ok(new_null_array(target.data_type(), union_array.len())),
196            Ordering::Equal => Ok(Arc::clone(target)),
198            Ordering::Greater => Ok(target.slice(0, union_array.len())),
200        }
201    } else if fields.len() == 1 || fields
203            .iter()
204            .filter(|(field_type_id, _)| *field_type_id != target_type_id)
205            .all(|(sibling_type_id, _)| union_array.child(sibling_type_id).is_empty())
206    {
208        Ok(extract_dense_all_selected(union_array, target, offsets)?)
210    } else {
211        match eq_scalar(union_array.type_ids(), target_type_id) {
212            BoolValue::Scalar(true) => {
216                Ok(extract_dense_all_selected(union_array, target, offsets)?)
217            }
218            BoolValue::Scalar(false) => {
219                match (target.len().cmp(&union_array.len()), layout(target.data_type()).can_contain_null_mask) {
223                    (Ordering::Less, _) | (_, false) => { Ok(new_null_array(target.data_type(), union_array.len()))
226                    }
227                    (Ordering::Equal, true) => {
229                        let data = unsafe {
231                            target
232                                .into_data()
233                                .into_builder()
234                                .nulls(Some(NullBuffer::new_null(union_array.len())))
235                                .build_unchecked()
236                        };
237
238                        Ok(make_array(data))
239                    }
240                    (Ordering::Greater, true) => {
242                        let data = unsafe {
244                            target
245                                .into_data()
246                                .slice(0, union_array.len())
247                                .into_builder()
248                                .nulls(Some(NullBuffer::new_null(union_array.len())))
249                                .build_unchecked()
250                        };
251
252                        Ok(make_array(data))
253                    }
254                }
255            }
256            BoolValue::Buffer(selected) => {
257                Ok(take(
259                    target,
260                    &Int32Array::try_new(offsets.clone(), Some(selected.into()))?,
261                    None,
262                )?)
263            }
264        }
265    }
266}
267
268fn extract_dense_all_selected(
269    union_array: &UnionArray,
270    target: &Arc<dyn Array>,
271    offsets: &ScalarBuffer<i32>,
272) -> Result<ArrayRef, ArrowError> {
273    let sequential =
274        target.len() - offsets[0] as usize >= union_array.len() && is_sequential(offsets);
275
276    if sequential && target.len() == union_array.len() {
277        Ok(Arc::clone(target))
279    } else if sequential && target.len() > union_array.len() {
280        Ok(target.slice(offsets[0] as usize, union_array.len()))
282    } else {
283        let indices = Int32Array::try_new(offsets.clone(), None)?;
285
286        Ok(take(target, &indices, None)?)
287    }
288}
289
/// Chunk length that `eq_scalar` hands to `eq_scalar_inner` when scanning for
/// a leading run of identical comparison results.
const EQ_SCALAR_CHUNK_SIZE: usize = 512;
291
/// Result of comparing every type id of a union against one target id.
#[derive(Debug, PartialEq)]
enum BoolValue {
    /// All comparisons gave this same result, so no bitmap was materialized.
    Scalar(bool),
    /// Comparisons were mixed; bit `i` is set when element `i` equals the target.
    Buffer(BooleanBuffer),
}
301
/// Compares each of `type_ids` with `target`, using the default chunk size
/// `EQ_SCALAR_CHUNK_SIZE`. See `eq_scalar_inner` for the result encoding.
fn eq_scalar(type_ids: &[i8], target: i8) -> BoolValue {
    eq_scalar_inner(EQ_SCALAR_CHUNK_SIZE, type_ids, target)
}
305
/// Counts how many leading elements of `type_ids` belong to whole chunks of
/// `chunk_size` in which `f` holds for every element.
///
/// Scanning stops at the first chunk containing an element that fails `f`;
/// that chunk contributes nothing, even if its first elements pass. Inside a
/// chunk `f` is evaluated for every element (no short-circuit).
///
/// Panics if `chunk_size` is zero (from `slice::chunks`).
fn count_first_run(chunk_size: usize, type_ids: &[i8], mut f: impl FnMut(i8) -> bool) -> usize {
    let mut run_len = 0;

    for chunk in type_ids.chunks(chunk_size) {
        let mut chunk_passes = true;

        for &type_id in chunk {
            // Deliberately branch-free: every element is tested.
            chunk_passes &= f(type_id);
        }

        if !chunk_passes {
            break;
        }

        run_len += chunk.len();
    }

    run_len
}
313
314fn eq_scalar_inner(chunk_size: usize, type_ids: &[i8], target: i8) -> BoolValue {
316    let true_bits = count_first_run(chunk_size, type_ids, |v| v == target);
317
318    let (set_bits, val) = if true_bits == type_ids.len() {
319        return BoolValue::Scalar(true);
320    } else if true_bits == 0 {
321        let false_bits = count_first_run(chunk_size, type_ids, |v| v != target);
322
323        if false_bits == type_ids.len() {
324            return BoolValue::Scalar(false);
325        } else {
326            (false_bits, false)
327        }
328    } else {
329        (true_bits, true)
330    };
331
332    let set_bits = set_bits - set_bits % 64;
334
335    let mut buffer =
336        MutableBuffer::new(bit_util::ceil(type_ids.len(), 8)).with_bitset(set_bits / 8, val);
337
338    buffer.extend(type_ids[set_bits..].chunks(64).map(|chunk| {
339        chunk
340            .iter()
341            .copied()
342            .enumerate()
343            .fold(0, |packed, (bit_idx, v)| {
344                packed | (((v == target) as u64) << bit_idx)
345            })
346    }));
347
348    BoolValue::Buffer(BooleanBuffer::new(buffer.into(), 0, type_ids.len()))
349}
350
/// Chunk length used by `is_sequential` when instantiating `is_sequential_generic`.
const IS_SEQUENTIAL_CHUNK_SIZE: usize = 64;
352
/// Returns whether `offsets` is a run of consecutive integers, checked in
/// chunks of `IS_SEQUENTIAL_CHUNK_SIZE`. An empty slice counts as sequential.
fn is_sequential(offsets: &[i32]) -> bool {
    is_sequential_generic::<IS_SEQUENTIAL_CHUNK_SIZE>(offsets)
}
356
/// Returns whether `offsets` is a run of consecutive integers, i.e. every
/// element equals its predecessor plus one. An empty slice is sequential.
///
/// `N` is the length of the fixed-size chunks the slice is checked in; it must
/// be non-zero (`chunks_exact` panics otherwise).
///
/// The end-point and chunk-start comparisons are done in `i64`, so offsets
/// near `i32::MAX` can neither overflow nor wrap into a false answer.
fn is_sequential_generic<const N: usize>(offsets: &[i32]) -> bool {
    let (first, last) = match (offsets.first(), offsets.last()) {
        (Some(&first), Some(&last)) => (first, last),
        _ => return true,
    };

    // Cheap rejection: a consecutive run must end exactly at `first + len - 1`.
    if i64::from(first) + offsets.len() as i64 - 1 != i64::from(last) {
        return false;
    }

    let chunks = offsets.chunks_exact(N);
    let remainder = chunks.remainder();

    let chunks_sequential = chunks.enumerate().all(|(chunk_idx, chunk)| {
        // Fixed-size view so the inner loop has a compile-time trip count.
        let chunk_array = <&[i32; N]>::try_from(chunk).unwrap();
        let chunk_start = chunk_array[0];

        // `wrapping_add` cannot yield a false positive: once the chunk start
        // is confirmed below, every expected value is at most `last`.
        chunk_array
            .iter()
            .copied()
            .enumerate()
            .fold(true, |acc, (lane, offset)| {
                acc & (offset == chunk_start.wrapping_add(lane as i32))
            })
            && i64::from(first) + (chunk_idx * N) as i64 == i64::from(chunk_start)
    });

    // The remainder ends at `last`, already verified above, so checking that it
    // is internally consecutive also pins down its first element.
    chunks_sequential
        && remainder
            .iter()
            .copied()
            .enumerate()
            .fold(true, |acc, (lane, offset)| {
                acc & (offset == remainder[0].wrapping_add(lane as i32))
            })
}
399
400#[cfg(test)]
401mod tests {
402    use super::{BoolValue, eq_scalar_inner, is_sequential_generic, union_extract};
403    use arrow_array::{Array, Int32Array, NullArray, StringArray, UnionArray, new_null_array};
404    use arrow_buffer::{BooleanBuffer, ScalarBuffer};
405    use arrow_schema::{ArrowError, DataType, Field, UnionFields, UnionMode};
406    use std::sync::Arc;
407
408    #[test]
409    fn test_eq_scalar() {
410        const ARRAY_LEN: usize = 64 * 4;
413
414        const EQ_SCALAR_CHUNK_SIZE: usize = 3;
416
417        fn eq_scalar(type_ids: &[i8], target: i8) -> BoolValue {
418            eq_scalar_inner(EQ_SCALAR_CHUNK_SIZE, type_ids, target)
419        }
420
421        fn cross_check(left: &[i8], right: i8) -> BooleanBuffer {
422            BooleanBuffer::collect_bool(left.len(), |i| left[i] == right)
423        }
424
425        assert_eq!(eq_scalar(&[], 1), BoolValue::Scalar(true));
426
427        assert_eq!(eq_scalar(&[1], 1), BoolValue::Scalar(true));
428        assert_eq!(eq_scalar(&[2], 1), BoolValue::Scalar(false));
429
430        let mut values = [1; ARRAY_LEN];
431
432        assert_eq!(eq_scalar(&values, 1), BoolValue::Scalar(true));
433        assert_eq!(eq_scalar(&values, 2), BoolValue::Scalar(false));
434
435        for i in 1..ARRAY_LEN {
437            assert_eq!(eq_scalar(&values[..i], 1), BoolValue::Scalar(true));
438            assert_eq!(eq_scalar(&values[..i], 2), BoolValue::Scalar(false));
439        }
440
441        for i in 0..ARRAY_LEN {
443            values[i] = 2;
444
445            assert_eq!(
446                eq_scalar(&values, 1),
447                BoolValue::Buffer(cross_check(&values, 1))
448            );
449            assert_eq!(
450                eq_scalar(&values, 2),
451                BoolValue::Buffer(cross_check(&values, 2))
452            );
453
454            values[i] = 1;
455        }
456    }
457
458    #[test]
459    fn test_is_sequential() {
460        const CHUNK_SIZE: usize = 3;
466        fn is_sequential(v: &[i32]) -> bool {
473            is_sequential_generic::<CHUNK_SIZE>(v)
474        }
475
476        assert!(is_sequential(&[])); assert!(is_sequential(&[1])); assert!(is_sequential(&[1, 2]));
480        assert!(is_sequential(&[1, 2, 3]));
481        assert!(is_sequential(&[1, 2, 3, 4]));
482        assert!(is_sequential(&[1, 2, 3, 4, 5]));
483        assert!(is_sequential(&[1, 2, 3, 4, 5, 6]));
484        assert!(is_sequential(&[1, 2, 3, 4, 5, 6, 7]));
485        assert!(is_sequential(&[1, 2, 3, 4, 5, 6, 7, 8]));
486
487        assert!(!is_sequential(&[8, 7]));
488        assert!(!is_sequential(&[8, 7, 6]));
489        assert!(!is_sequential(&[8, 7, 6, 5]));
490        assert!(!is_sequential(&[8, 7, 6, 5, 4]));
491        assert!(!is_sequential(&[8, 7, 6, 5, 4, 3]));
492        assert!(!is_sequential(&[8, 7, 6, 5, 4, 3, 2]));
493        assert!(!is_sequential(&[8, 7, 6, 5, 4, 3, 2, 1]));
494
495        assert!(!is_sequential(&[0, 2]));
496        assert!(!is_sequential(&[1, 0]));
497
498        assert!(!is_sequential(&[0, 2, 3]));
499        assert!(!is_sequential(&[1, 0, 3]));
500        assert!(!is_sequential(&[1, 2, 0]));
501
502        assert!(!is_sequential(&[0, 2, 3, 4]));
503        assert!(!is_sequential(&[1, 0, 3, 4]));
504        assert!(!is_sequential(&[1, 2, 0, 4]));
505        assert!(!is_sequential(&[1, 2, 3, 0]));
506
507        assert!(!is_sequential(&[0, 2, 3, 4, 5]));
508        assert!(!is_sequential(&[1, 0, 3, 4, 5]));
509        assert!(!is_sequential(&[1, 2, 0, 4, 5]));
510        assert!(!is_sequential(&[1, 2, 3, 0, 5]));
511        assert!(!is_sequential(&[1, 2, 3, 4, 0]));
512
513        assert!(!is_sequential(&[0, 2, 3, 4, 5, 6]));
514        assert!(!is_sequential(&[1, 0, 3, 4, 5, 6]));
515        assert!(!is_sequential(&[1, 2, 0, 4, 5, 6]));
516        assert!(!is_sequential(&[1, 2, 3, 0, 5, 6]));
517        assert!(!is_sequential(&[1, 2, 3, 4, 0, 6]));
518        assert!(!is_sequential(&[1, 2, 3, 4, 5, 0]));
519
520        assert!(!is_sequential(&[0, 2, 3, 4, 5, 6, 7]));
521        assert!(!is_sequential(&[1, 0, 3, 4, 5, 6, 7]));
522        assert!(!is_sequential(&[1, 2, 0, 4, 5, 6, 7]));
523        assert!(!is_sequential(&[1, 2, 3, 0, 5, 6, 7]));
524        assert!(!is_sequential(&[1, 2, 3, 4, 0, 6, 7]));
525        assert!(!is_sequential(&[1, 2, 3, 4, 5, 0, 7]));
526        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 0]));
527
528        assert!(!is_sequential(&[0, 2, 3, 4, 5, 6, 7, 8]));
529        assert!(!is_sequential(&[1, 0, 3, 4, 5, 6, 7, 8]));
530        assert!(!is_sequential(&[1, 2, 0, 4, 5, 6, 7, 8]));
531        assert!(!is_sequential(&[1, 2, 3, 0, 5, 6, 7, 8]));
532        assert!(!is_sequential(&[1, 2, 3, 4, 0, 6, 7, 8]));
533        assert!(!is_sequential(&[1, 2, 3, 4, 5, 0, 7, 8]));
534        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 0, 8]));
535        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 7, 0]));
536
537        assert!(!is_sequential(&[1, 2, 3, 5]));
539        assert!(!is_sequential(&[1, 2, 3, 5, 6]));
540        assert!(!is_sequential(&[1, 2, 3, 5, 6, 7]));
541        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 8]));
542        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 8, 9]));
543    }
544
545    fn str1() -> UnionFields {
546        UnionFields::new(vec![1], vec![Field::new("str", DataType::Utf8, true)])
547    }
548
549    fn str1_int3() -> UnionFields {
550        UnionFields::new(
551            vec![1, 3],
552            vec![
553                Field::new("str", DataType::Utf8, true),
554                Field::new("int", DataType::Int32, true),
555            ],
556        )
557    }
558
559    #[test]
560    fn sparse_1_1_single_field() {
561        let union = UnionArray::try_new(
562            str1(),
564            ScalarBuffer::from(vec![1, 1]), None,                           vec![
567                Arc::new(StringArray::from(vec!["a", "b"])), ],
569        )
570        .unwrap();
571
572        let expected = StringArray::from(vec!["a", "b"]);
573        let extracted = union_extract(&union, "str").unwrap();
574
575        assert_eq!(extracted.into_data(), expected.into_data());
576    }
577
578    #[test]
579    fn sparse_1_2_empty() {
580        let union = UnionArray::try_new(
581            str1_int3(),
583            ScalarBuffer::from(vec![]), None,                       vec![
586                Arc::new(StringArray::new_null(0)),
587                Arc::new(Int32Array::new_null(0)),
588            ],
589        )
590        .unwrap();
591
592        let expected = StringArray::new_null(0);
593        let extracted = union_extract(&union, "str").unwrap(); assert_eq!(extracted.into_data(), expected.into_data());
596    }
597
598    #[test]
599    fn sparse_1_3a_null_target() {
600        let union = UnionArray::try_new(
601            UnionFields::new(
603                vec![1, 3],
604                vec![
605                    Field::new("str", DataType::Utf8, true),
606                    Field::new("null", DataType::Null, true), ],
608            ),
609            ScalarBuffer::from(vec![1]), None,                        vec![
612                Arc::new(StringArray::new_null(1)),
613                Arc::new(NullArray::new(1)), ],
615        )
616        .unwrap();
617
618        let expected = NullArray::new(1);
619        let extracted = union_extract(&union, "null").unwrap();
620
621        assert_eq!(extracted.into_data(), expected.into_data());
622    }
623
624    #[test]
625    fn sparse_1_3b_null_target() {
626        let union = UnionArray::try_new(
627            str1_int3(),
629            ScalarBuffer::from(vec![1]), None,                        vec![
632                Arc::new(StringArray::new_null(1)), Arc::new(Int32Array::new_null(1)),
634            ],
635        )
636        .unwrap();
637
638        let expected = StringArray::new_null(1);
639        let extracted = union_extract(&union, "str").unwrap(); assert_eq!(extracted.into_data(), expected.into_data());
642    }
643
644    #[test]
645    fn sparse_2_all_types_match() {
646        let union = UnionArray::try_new(
647            str1_int3(),
649            ScalarBuffer::from(vec![3, 3]), None,                           vec![
652                Arc::new(StringArray::new_null(2)),
653                Arc::new(Int32Array::from(vec![1, 4])), ],
655        )
656        .unwrap();
657
658        let expected = Int32Array::from(vec![1, 4]);
659        let extracted = union_extract(&union, "int").unwrap();
660
661        assert_eq!(extracted.into_data(), expected.into_data());
662    }
663
664    #[test]
665    fn sparse_3_1_none_match_target_can_contain_null_mask() {
666        let union = UnionArray::try_new(
667            str1_int3(),
669            ScalarBuffer::from(vec![1, 1, 1, 1]), None,                                 vec![
672                Arc::new(StringArray::new_null(4)),
673                Arc::new(Int32Array::from(vec![None, Some(4), None, Some(8)])), ],
675        )
676        .unwrap();
677
678        let expected = Int32Array::new_null(4);
679        let extracted = union_extract(&union, "int").unwrap();
680
681        assert_eq!(extracted.into_data(), expected.into_data());
682    }
683
684    fn str1_union3(union3_datatype: DataType) -> UnionFields {
685        UnionFields::new(
686            vec![1, 3],
687            vec![
688                Field::new("str", DataType::Utf8, true),
689                Field::new("union", union3_datatype, true),
690            ],
691        )
692    }
693
694    #[test]
695    fn sparse_3_2_none_match_cant_contain_null_mask_union_target() {
696        let target_fields = str1();
697        let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse);
698
699        let union = UnionArray::try_new(
700            str1_union3(target_type.clone()),
702            ScalarBuffer::from(vec![1, 1]), None,                           vec![
705                Arc::new(StringArray::new_null(2)),
706                Arc::new(
708                    UnionArray::try_new(
709                        target_fields.clone(),
710                        ScalarBuffer::from(vec![1, 1]),
711                        None,
712                        vec![Arc::new(StringArray::from(vec!["a", "b"]))],
713                    )
714                    .unwrap(),
715                ),
716            ],
717        )
718        .unwrap();
719
720        let expected = new_null_array(&target_type, 2);
721        let extracted = union_extract(&union, "union").unwrap();
722
723        assert_eq!(extracted.into_data(), expected.into_data());
724    }
725
726    #[test]
727    fn sparse_4_1_1_target_with_nulls() {
728        let union = UnionArray::try_new(
729            str1_int3(),
731            ScalarBuffer::from(vec![3, 3, 1, 1]), None,                                 vec![
734                Arc::new(StringArray::new_null(4)),
735                Arc::new(Int32Array::from(vec![None, Some(4), None, Some(8)])), ],
737        )
738        .unwrap();
739
740        let expected = Int32Array::from(vec![None, Some(4), None, None]);
741        let extracted = union_extract(&union, "int").unwrap();
742
743        assert_eq!(extracted.into_data(), expected.into_data());
744    }
745
746    #[test]
747    fn sparse_4_1_2_target_without_nulls() {
748        let union = UnionArray::try_new(
749            str1_int3(),
751            ScalarBuffer::from(vec![1, 3, 3]), None,                              vec![
754                Arc::new(StringArray::new_null(3)),
755                Arc::new(Int32Array::from(vec![2, 4, 8])), ],
757        )
758        .unwrap();
759
760        let expected = Int32Array::from(vec![None, Some(4), Some(8)]);
761        let extracted = union_extract(&union, "int").unwrap();
762
763        assert_eq!(extracted.into_data(), expected.into_data());
764    }
765
766    #[test]
767    fn sparse_4_2_some_match_target_cant_contain_null_mask() {
768        let target_fields = str1();
769        let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse);
770
771        let union = UnionArray::try_new(
772            str1_union3(target_type),
774            ScalarBuffer::from(vec![3, 1]), None,                           vec![
777                Arc::new(StringArray::new_null(2)),
778                Arc::new(
779                    UnionArray::try_new(
780                        target_fields.clone(),
781                        ScalarBuffer::from(vec![1, 1]),
782                        None,
783                        vec![Arc::new(StringArray::from(vec!["a", "b"]))],
784                    )
785                    .unwrap(),
786                ),
787            ],
788        )
789        .unwrap();
790
791        let expected = UnionArray::try_new(
792            target_fields,
793            ScalarBuffer::from(vec![1, 1]),
794            None,
795            vec![Arc::new(StringArray::from(vec![Some("a"), None]))],
796        )
797        .unwrap();
798        let extracted = union_extract(&union, "union").unwrap();
799
800        assert_eq!(extracted.into_data(), expected.into_data());
801    }
802
803    #[test]
804    fn dense_1_1_both_empty() {
805        let union = UnionArray::try_new(
806            str1_int3(),
807            ScalarBuffer::from(vec![]),       Some(ScalarBuffer::from(vec![])), vec![
810                Arc::new(StringArray::new_null(0)), Arc::new(Int32Array::new_null(0)),
812            ],
813        )
814        .unwrap();
815
816        let expected = StringArray::new_null(0);
817        let extracted = union_extract(&union, "str").unwrap();
818
819        assert_eq!(extracted.into_data(), expected.into_data());
820    }
821
822    #[test]
823    fn dense_1_2_empty_union_target_non_empty() {
824        let union = UnionArray::try_new(
825            str1_int3(),
826            ScalarBuffer::from(vec![]),       Some(ScalarBuffer::from(vec![])), vec![
829                Arc::new(StringArray::new_null(1)), Arc::new(Int32Array::new_null(0)),
831            ],
832        )
833        .unwrap();
834
835        let expected = StringArray::new_null(0);
836        let extracted = union_extract(&union, "str").unwrap();
837
838        assert_eq!(extracted.into_data(), expected.into_data());
839    }
840
841    #[test]
842    fn dense_2_non_empty_union_target_empty() {
843        let union = UnionArray::try_new(
844            str1_int3(),
845            ScalarBuffer::from(vec![3, 3]),       Some(ScalarBuffer::from(vec![0, 1])), vec![
848                Arc::new(StringArray::new_null(0)), Arc::new(Int32Array::new_null(2)),
850            ],
851        )
852        .unwrap();
853
854        let expected = StringArray::new_null(2);
855        let extracted = union_extract(&union, "str").unwrap();
856
857        assert_eq!(extracted.into_data(), expected.into_data());
858    }
859
860    #[test]
861    fn dense_3_1_null_target_smaller_len() {
862        let union = UnionArray::try_new(
863            str1_int3(),
864            ScalarBuffer::from(vec![3, 3]),       Some(ScalarBuffer::from(vec![0, 0])), vec![
867                Arc::new(StringArray::new_null(1)), Arc::new(Int32Array::new_null(2)),
869            ],
870        )
871        .unwrap();
872
873        let expected = StringArray::new_null(2);
874        let extracted = union_extract(&union, "str").unwrap();
875
876        assert_eq!(extracted.into_data(), expected.into_data());
877    }
878
879    #[test]
880    fn dense_3_2_null_target_equal_len() {
881        let union = UnionArray::try_new(
882            str1_int3(),
883            ScalarBuffer::from(vec![3, 3]),       Some(ScalarBuffer::from(vec![0, 0])), vec![
886                Arc::new(StringArray::new_null(2)), Arc::new(Int32Array::new_null(2)),
888            ],
889        )
890        .unwrap();
891
892        let expected = StringArray::new_null(2);
893        let extracted = union_extract(&union, "str").unwrap();
894
895        assert_eq!(extracted.into_data(), expected.into_data());
896    }
897
898    #[test]
899    fn dense_3_3_null_target_bigger_len() {
900        let union = UnionArray::try_new(
901            str1_int3(),
902            ScalarBuffer::from(vec![3, 3]),       Some(ScalarBuffer::from(vec![0, 0])), vec![
905                Arc::new(StringArray::new_null(3)), Arc::new(Int32Array::new_null(3)),
907            ],
908        )
909        .unwrap();
910
911        let expected = StringArray::new_null(2);
912        let extracted = union_extract(&union, "str").unwrap();
913
914        assert_eq!(extracted.into_data(), expected.into_data());
915    }
916
917    #[test]
918    fn dense_4_1a_single_type_sequential_offsets_equal_len() {
919        let union = UnionArray::try_new(
920            str1(),
922            ScalarBuffer::from(vec![1, 1]),       Some(ScalarBuffer::from(vec![0, 1])), vec![
925                Arc::new(StringArray::from(vec!["a1", "b2"])), ],
927        )
928        .unwrap();
929
930        let expected = StringArray::from(vec!["a1", "b2"]);
931        let extracted = union_extract(&union, "str").unwrap();
932
933        assert_eq!(extracted.into_data(), expected.into_data());
934    }
935
936    #[test]
937    fn dense_4_2a_single_type_sequential_offsets_bigger() {
938        let union = UnionArray::try_new(
939            str1(),
941            ScalarBuffer::from(vec![1, 1]),       Some(ScalarBuffer::from(vec![0, 1])), vec![
944                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), ],
946        )
947        .unwrap();
948
949        let expected = StringArray::from(vec!["a1", "b2"]);
950        let extracted = union_extract(&union, "str").unwrap();
951
952        assert_eq!(extracted.into_data(), expected.into_data());
953    }
954
955    #[test]
956    fn dense_4_3a_single_type_non_sequential() {
957        let union = UnionArray::try_new(
958            str1(),
960            ScalarBuffer::from(vec![1, 1]),       Some(ScalarBuffer::from(vec![0, 2])), vec![
963                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), ],
965        )
966        .unwrap();
967
968        let expected = StringArray::from(vec!["a1", "c3"]);
969        let extracted = union_extract(&union, "str").unwrap();
970
971        assert_eq!(extracted.into_data(), expected.into_data());
972    }
973
974    #[test]
975    fn dense_4_1b_empty_siblings_sequential_equal_len() {
976        let union = UnionArray::try_new(
977            str1_int3(),
979            ScalarBuffer::from(vec![1, 1]),       Some(ScalarBuffer::from(vec![0, 1])), vec![
982                Arc::new(StringArray::from(vec!["a", "b"])), Arc::new(Int32Array::new_null(0)),           ],
985        )
986        .unwrap();
987
988        let expected = StringArray::from(vec!["a", "b"]);
989        let extracted = union_extract(&union, "str").unwrap();
990
991        assert_eq!(extracted.into_data(), expected.into_data());
992    }
993
994    #[test]
995    fn dense_4_2b_empty_siblings_sequential_bigger_len() {
996        let union = UnionArray::try_new(
997            str1_int3(),
999            ScalarBuffer::from(vec![1, 1]),       Some(ScalarBuffer::from(vec![0, 1])), vec![
1002                Arc::new(StringArray::from(vec!["a", "b", "c"])), Arc::new(Int32Array::new_null(0)),                ],
1005        )
1006        .unwrap();
1007
1008        let expected = StringArray::from(vec!["a", "b"]);
1009        let extracted = union_extract(&union, "str").unwrap();
1010
1011        assert_eq!(extracted.into_data(), expected.into_data());
1012    }
1013
1014    #[test]
1015    fn dense_4_3b_empty_sibling_non_sequential() {
1016        let union = UnionArray::try_new(
1017            str1_int3(),
1019            ScalarBuffer::from(vec![1, 1]),       Some(ScalarBuffer::from(vec![0, 2])), vec![
1022                Arc::new(StringArray::from(vec!["a", "b", "c"])), Arc::new(Int32Array::new_null(0)),                ],
1025        )
1026        .unwrap();
1027
1028        let expected = StringArray::from(vec!["a", "c"]);
1029        let extracted = union_extract(&union, "str").unwrap();
1030
1031        assert_eq!(extracted.into_data(), expected.into_data());
1032    }
1033
1034    #[test]
1035    fn dense_4_1c_all_types_match_sequential_equal_len() {
1036        let union = UnionArray::try_new(
1037            str1_int3(),
1039            ScalarBuffer::from(vec![1, 1]),       Some(ScalarBuffer::from(vec![0, 1])), vec![
1042                Arc::new(StringArray::from(vec!["a1", "b2"])), Arc::new(Int32Array::new_null(2)),             ],
1045        )
1046        .unwrap();
1047
1048        let expected = StringArray::from(vec!["a1", "b2"]);
1049        let extracted = union_extract(&union, "str").unwrap();
1050
1051        assert_eq!(extracted.into_data(), expected.into_data());
1052    }
1053
1054    #[test]
1055    fn dense_4_2c_all_types_match_sequential_bigger_len() {
1056        let union = UnionArray::try_new(
1057            str1_int3(),
1059            ScalarBuffer::from(vec![1, 1]),       Some(ScalarBuffer::from(vec![0, 1])), vec![
1062                Arc::new(StringArray::from(vec!["a1", "b2", "b3"])), Arc::new(Int32Array::new_null(2)),                   ],
1065        )
1066        .unwrap();
1067
1068        let expected = StringArray::from(vec!["a1", "b2"]);
1069        let extracted = union_extract(&union, "str").unwrap();
1070
1071        assert_eq!(extracted.into_data(), expected.into_data());
1072    }
1073
1074    #[test]
1075    fn dense_4_3c_all_types_match_non_sequential() {
1076        let union = UnionArray::try_new(
1077            str1_int3(),
1079            ScalarBuffer::from(vec![1, 1]),       Some(ScalarBuffer::from(vec![0, 2])), vec![
1082                Arc::new(StringArray::from(vec!["a1", "b2", "b3"])),
1083                Arc::new(Int32Array::new_null(2)), ],
1085        )
1086        .unwrap();
1087
1088        let expected = StringArray::from(vec!["a1", "b3"]);
1089        let extracted = union_extract(&union, "str").unwrap();
1090
1091        assert_eq!(extracted.into_data(), expected.into_data());
1092    }
1093
1094    #[test]
1095    fn dense_5_1a_none_match_less_len() {
1096        let union = UnionArray::try_new(
1097            str1_int3(),
1099            ScalarBuffer::from(vec![3, 3, 3, 3, 3]), Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), vec![
1102                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), Arc::new(Int32Array::from(vec![1, 2])),
1104            ],
1105        )
1106        .unwrap();
1107
1108        let expected = StringArray::new_null(5);
1109        let extracted = union_extract(&union, "str").unwrap();
1110
1111        assert_eq!(extracted.into_data(), expected.into_data());
1112    }
1113
1114    #[test]
1115    fn dense_5_1b_cant_contain_null_mask() {
1116        let target_fields = str1();
1117        let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse);
1118
1119        let union = UnionArray::try_new(
1120            str1_union3(target_type.clone()),
1122            ScalarBuffer::from(vec![1, 1, 1, 1, 1]), Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), vec![
1125                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), Arc::new(
1127                    UnionArray::try_new(
1128                        target_fields.clone(),
1129                        ScalarBuffer::from(vec![1]),
1130                        None,
1131                        vec![Arc::new(StringArray::from(vec!["a"]))],
1132                    )
1133                    .unwrap(),
1134                ), ],
1136        )
1137        .unwrap();
1138
1139        let expected = new_null_array(&target_type, 5);
1140        let extracted = union_extract(&union, "union").unwrap();
1141
1142        assert_eq!(extracted.into_data(), expected.into_data());
1143    }
1144
1145    #[test]
1146    fn dense_5_2_none_match_equal_len() {
1147        let union = UnionArray::try_new(
1148            str1_int3(),
1150            ScalarBuffer::from(vec![3, 3, 3, 3, 3]), Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), vec![
1153                Arc::new(StringArray::from(vec!["a1", "b2", "c3", "d4", "e5"])), Arc::new(Int32Array::from(vec![1, 2])),
1155            ],
1156        )
1157        .unwrap();
1158
1159        let expected = StringArray::new_null(5);
1160        let extracted = union_extract(&union, "str").unwrap();
1161
1162        assert_eq!(extracted.into_data(), expected.into_data());
1163    }
1164
1165    #[test]
1166    fn dense_5_3_none_match_greater_len() {
1167        let union = UnionArray::try_new(
1168            str1_int3(),
1170            ScalarBuffer::from(vec![3, 3, 3, 3, 3]), Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), vec![
1173                Arc::new(StringArray::from(vec!["a1", "b2", "c3", "d4", "e5", "f6"])), Arc::new(Int32Array::from(vec![1, 2])),                                ],
1176        )
1177        .unwrap();
1178
1179        let expected = StringArray::new_null(5);
1180        let extracted = union_extract(&union, "str").unwrap();
1181
1182        assert_eq!(extracted.into_data(), expected.into_data());
1183    }
1184
1185    #[test]
1186    fn dense_6_some_matches() {
1187        let union = UnionArray::try_new(
1188            str1_int3(),
1190            ScalarBuffer::from(vec![3, 3, 1, 1, 1]), Some(ScalarBuffer::from(vec![0, 1, 0, 1, 2])), vec![
1193                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), Arc::new(Int32Array::from(vec![1, 2])),
1195            ],
1196        )
1197        .unwrap();
1198
1199        let expected = Int32Array::from(vec![Some(1), Some(2), None, None, None]);
1200        let extracted = union_extract(&union, "int").unwrap();
1201
1202        assert_eq!(extracted.into_data(), expected.into_data());
1203    }
1204
1205    #[test]
1206    fn empty_sparse_union() {
1207        let union = UnionArray::try_new(
1208            UnionFields::empty(),
1209            ScalarBuffer::from(vec![]),
1210            None,
1211            vec![],
1212        )
1213        .unwrap();
1214
1215        assert_eq!(
1216            union_extract(&union, "a").unwrap_err().to_string(),
1217            ArrowError::InvalidArgumentError("field a not found on union".into()).to_string()
1218        );
1219    }
1220
1221    #[test]
1222    fn empty_dense_union() {
1223        let union = UnionArray::try_new(
1224            UnionFields::empty(),
1225            ScalarBuffer::from(vec![]),
1226            Some(ScalarBuffer::from(vec![])),
1227            vec![],
1228        )
1229        .unwrap();
1230
1231        assert_eq!(
1232            union_extract(&union, "a").unwrap_err().to_string(),
1233            ArrowError::InvalidArgumentError("field a not found on union".into()).to_string()
1234        );
1235    }
1236}