1use crate::take::take;
21use arrow_array::{
22 Array, ArrayRef, BooleanArray, Int32Array, Scalar, UnionArray, make_array, new_empty_array,
23 new_null_array,
24};
25use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer, ScalarBuffer, bit_util};
26use arrow_data::layout;
27use arrow_schema::{ArrowError, DataType, UnionFields};
28use std::cmp::Ordering;
29use std::sync::Arc;
30
31pub fn union_extract(union_array: &UnionArray, target: &str) -> Result<ArrayRef, ArrowError> {
80 let fields = match union_array.data_type() {
81 DataType::Union(fields, _) => fields,
82 _ => unreachable!(),
83 };
84
85 let (target_type_id, _) = fields
86 .iter()
87 .find(|field| field.1.name() == target)
88 .ok_or_else(|| {
89 ArrowError::InvalidArgumentError(format!("field {target} not found on union"))
90 })?;
91
92 union_extract_impl(union_array, fields, target_type_id)
93}
94
95pub fn union_extract_by_id(
104 union_array: &UnionArray,
105 target_type_id: i8,
106) -> Result<ArrayRef, ArrowError> {
107 let fields = match union_array.data_type() {
108 DataType::Union(fields, _) => fields,
109 _ => unreachable!(),
110 };
111
112 if fields.iter().all(|(id, _)| id != target_type_id) {
113 return Err(ArrowError::InvalidArgumentError(format!(
114 "type_id {target_type_id} not found on union"
115 )));
116 }
117
118 union_extract_impl(union_array, fields, target_type_id)
119}
120
121fn union_extract_impl(
122 union_array: &UnionArray,
123 fields: &UnionFields,
124 target_type_id: i8,
125) -> Result<ArrayRef, ArrowError> {
126 match union_array.offsets() {
127 Some(_) => extract_dense(union_array, fields, target_type_id),
128 None => extract_sparse(union_array, fields, target_type_id),
129 }
130}
131
132fn extract_sparse(
133 union_array: &UnionArray,
134 fields: &UnionFields,
135 target_type_id: i8,
136) -> Result<ArrayRef, ArrowError> {
137 let target = union_array.child(target_type_id);
138
139 if fields.len() == 1 || union_array.is_empty() || target.null_count() == target.len() || target.data_type().is_null()
142 {
144 Ok(Arc::clone(target))
145 } else {
146 match eq_scalar(union_array.type_ids(), target_type_id) {
147 BoolValue::Scalar(true) => Ok(Arc::clone(target)),
149 BoolValue::Scalar(false) => {
151 if layout(target.data_type()).can_contain_null_mask {
152 let data = unsafe {
155 target
156 .into_data()
157 .into_builder()
158 .nulls(Some(NullBuffer::new_null(target.len())))
159 .build_unchecked()
160 };
161
162 Ok(make_array(data))
163 } else {
164 Ok(new_null_array(target.data_type(), target.len()))
166 }
167 }
168 BoolValue::Buffer(selected) => {
170 if layout(target.data_type()).can_contain_null_mask {
171 let nulls = match target.nulls().filter(|n| n.null_count() > 0) {
173 Some(nulls) => &selected & nulls.inner(),
176 None => selected,
178 };
179
180 let data = unsafe {
182 assert_eq!(nulls.len(), target.len());
183
184 target
185 .into_data()
186 .into_builder()
187 .nulls(Some(nulls.into()))
188 .build_unchecked()
189 };
190
191 Ok(make_array(data))
192 } else {
193 Ok(crate::zip::zip(
195 &BooleanArray::new(selected, None),
196 target,
197 &Scalar::new(new_null_array(target.data_type(), 1)),
198 )?)
199 }
200 }
201 }
202 }
203}
204
205fn extract_dense(
206 union_array: &UnionArray,
207 fields: &UnionFields,
208 target_type_id: i8,
209) -> Result<ArrayRef, ArrowError> {
210 let target = union_array.child(target_type_id);
211 let offsets = union_array.offsets().unwrap();
212
213 if union_array.is_empty() {
214 if target.is_empty() {
216 Ok(Arc::clone(target))
218 } else {
219 Ok(new_empty_array(target.data_type()))
221 }
222 } else if target.is_empty() {
223 Ok(new_null_array(target.data_type(), union_array.len()))
225 } else if target.null_count() == target.len() || target.data_type().is_null() {
226 match target.len().cmp(&union_array.len()) {
228 Ordering::Less => Ok(new_null_array(target.data_type(), union_array.len())),
230 Ordering::Equal => Ok(Arc::clone(target)),
232 Ordering::Greater => Ok(target.slice(0, union_array.len())),
234 }
235 } else if fields.len() == 1 || fields
237 .iter()
238 .filter(|(field_type_id, _)| *field_type_id != target_type_id)
239 .all(|(sibling_type_id, _)| union_array.child(sibling_type_id).is_empty())
240 {
242 Ok(extract_dense_all_selected(union_array, target, offsets)?)
244 } else {
245 match eq_scalar(union_array.type_ids(), target_type_id) {
246 BoolValue::Scalar(true) => {
250 Ok(extract_dense_all_selected(union_array, target, offsets)?)
251 }
252 BoolValue::Scalar(false) => {
253 match (target.len().cmp(&union_array.len()), layout(target.data_type()).can_contain_null_mask) {
257 (Ordering::Less, _) | (_, false) => { Ok(new_null_array(target.data_type(), union_array.len()))
260 }
261 (Ordering::Equal, true) => {
263 let data = unsafe {
265 target
266 .into_data()
267 .into_builder()
268 .nulls(Some(NullBuffer::new_null(union_array.len())))
269 .build_unchecked()
270 };
271
272 Ok(make_array(data))
273 }
274 (Ordering::Greater, true) => {
276 let data = unsafe {
278 target
279 .into_data()
280 .slice(0, union_array.len())
281 .into_builder()
282 .nulls(Some(NullBuffer::new_null(union_array.len())))
283 .build_unchecked()
284 };
285
286 Ok(make_array(data))
287 }
288 }
289 }
290 BoolValue::Buffer(selected) => {
291 Ok(take(
293 target,
294 &Int32Array::try_new(offsets.clone(), Some(selected.into()))?,
295 None,
296 )?)
297 }
298 }
299 }
300}
301
302fn extract_dense_all_selected(
303 union_array: &UnionArray,
304 target: &Arc<dyn Array>,
305 offsets: &ScalarBuffer<i32>,
306) -> Result<ArrayRef, ArrowError> {
307 let sequential =
308 target.len() - offsets[0] as usize >= union_array.len() && is_sequential(offsets);
309
310 if sequential && target.len() == union_array.len() {
311 Ok(Arc::clone(target))
313 } else if sequential && target.len() > union_array.len() {
314 Ok(target.slice(offsets[0] as usize, union_array.len()))
316 } else {
317 let indices = Int32Array::try_new(offsets.clone(), None)?;
319
320 Ok(take(target, &indices, None)?)
321 }
322}
323
/// Number of type ids per chunk that [`eq_scalar`] passes on to [`eq_scalar_inner`] for
/// its initial scan.
const EQ_SCALAR_CHUNK_SIZE: usize = 512;
325
/// Outcome of comparing every type id of a union against a single target type id.
#[derive(Debug, PartialEq)]
enum BoolValue {
    /// All comparisons agree: `true` when every type id equals the target (also used for
    /// an empty input), `false` when none does.
    Scalar(bool),
    /// The comparisons are mixed: one bit per type id, set where it equals the target.
    Buffer(BooleanBuffer),
}
335
/// Compares each of `type_ids` against `target`.
///
/// Forwards to [`eq_scalar_inner`] with the chunk size [`EQ_SCALAR_CHUNK_SIZE`].
fn eq_scalar(type_ids: &[i8], target: i8) -> BoolValue {
    eq_scalar_inner(EQ_SCALAR_CHUNK_SIZE, type_ids, target)
}
339
/// Counts the elements in the leading chunks of `type_ids` for which `f` holds throughout.
///
/// `type_ids` is walked in chunks of `chunk_size` elements (the last chunk may be
/// shorter). The walk stops at the first chunk containing an element for which `f` returns
/// `false`; the lengths of the chunks before it are summed. The result is therefore a
/// multiple of `chunk_size`, or `type_ids.len()` when every chunk passes.
///
/// `f` is evaluated for every element of each visited chunk, including the elements that
/// follow a failure within that chunk.
///
/// # Panics
///
/// Panics if `chunk_size` is zero.
fn count_first_run(chunk_size: usize, type_ids: &[i8], mut f: impl FnMut(i8) -> bool) -> usize {
    let mut run_len = 0;

    for chunk in type_ids.chunks(chunk_size) {
        // Deliberately non-short-circuiting: `f` runs on the whole chunk.
        let mut chunk_passes = true;
        for &type_id in chunk {
            chunk_passes &= f(type_id);
        }

        if !chunk_passes {
            break;
        }

        run_len += chunk.len();
    }

    run_len
}
347
348fn eq_scalar_inner(chunk_size: usize, type_ids: &[i8], target: i8) -> BoolValue {
350 let true_bits = count_first_run(chunk_size, type_ids, |v| v == target);
351
352 let (set_bits, val) = if true_bits == type_ids.len() {
353 return BoolValue::Scalar(true);
354 } else if true_bits == 0 {
355 let false_bits = count_first_run(chunk_size, type_ids, |v| v != target);
356
357 if false_bits == type_ids.len() {
358 return BoolValue::Scalar(false);
359 } else {
360 (false_bits, false)
361 }
362 } else {
363 (true_bits, true)
364 };
365
366 let set_bits = set_bits - set_bits % 64;
368
369 let mut buffer =
370 MutableBuffer::new(bit_util::ceil(type_ids.len(), 8)).with_bitset(set_bits / 8, val);
371
372 buffer.extend(type_ids[set_bits..].chunks(64).map(|chunk| {
373 chunk
374 .iter()
375 .copied()
376 .enumerate()
377 .fold(0, |packed, (bit_idx, v)| {
378 packed | (((v == target) as u64) << bit_idx)
379 })
380 }));
381
382 BoolValue::Buffer(BooleanBuffer::new(buffer.into(), 0, type_ids.len()))
383}
384
/// Chunk size, in offsets, that [`is_sequential`] uses to instantiate
/// [`is_sequential_generic`].
const IS_SEQUENTIAL_CHUNK_SIZE: usize = 64;

/// Returns whether `offsets` is sequential according to [`is_sequential_generic`],
/// instantiated with [`IS_SEQUENTIAL_CHUNK_SIZE`].
fn is_sequential(offsets: &[i32]) -> bool {
    is_sequential_generic::<IS_SEQUENTIAL_CHUNK_SIZE>(offsets)
}
390
/// Returns `true` when `offsets` is a run of consecutive increasing values, that is
/// `offsets[i] == offsets[0] + i` for every index `i`. An empty slice is sequential.
///
/// `N` is the number of offsets examined per chunk by the main loop; the offsets left over
/// after the last full chunk are examined separately.
///
/// No arithmetic in this function can overflow: the expected last value is computed with
/// checked arithmetic and the per-element comparisons use wrapping arithmetic.
///
/// # Panics
///
/// Panics if `N` is zero and `offsets` is not rejected by the up-front checks.
fn is_sequential_generic<const N: usize>(offsets: &[i32]) -> bool {
    let (Some(&first), Some(&last)) = (offsets.first(), offsets.last()) else {
        return true;
    };

    // Cheap rejection: a sequential run must end at `first + len - 1`. If that value does
    // not fit in an `i32`, no such run exists.
    let expected_last = i32::try_from(offsets.len() - 1)
        .ok()
        .and_then(|span| first.checked_add(span));

    if expected_last != Some(last) {
        return false;
    }

    // From here on `first + i` fits in an `i32` for every index `i`. The comparisons below
    // still use wrapping additions because a chunk of a non-sequential input may start at
    // an arbitrary value; a wrapped sum can never produce a false positive, since a chunk
    // is only accepted if its first value is the in-range `first + chunk_idx * N`, and the
    // remainder is pinned by the end-value check above.
    let chunks = offsets.chunks_exact(N);

    let remainder = chunks.remainder();

    chunks.enumerate().all(|(chunk_idx, chunk)| {
        let chunk_array = <&[i32; N]>::try_from(chunk).unwrap();
        let chunk_first = chunk_array[0];

        // Non-short-circuiting `&`: every element of the chunk is compared.
        let consecutive_within_chunk = chunk_array
            .iter()
            .copied()
            .enumerate()
            .fold(true, |acc, (i, offset)| {
                acc & (offset == chunk_first.wrapping_add(i as i32))
            });

        // The chunk must also start where the run started by `first` says it should.
        consecutive_within_chunk && first.wrapping_add((chunk_idx * N) as i32) == chunk_first
    }) && remainder
        .iter()
        .copied()
        .enumerate()
        .fold(true, |acc, (i, offset)| {
            acc & (offset == remainder[0].wrapping_add(i as i32))
        })
}
433
#[cfg(test)]
mod tests {
    use super::{
        BoolValue, eq_scalar_inner, is_sequential_generic, union_extract, union_extract_by_id,
    };
    use arrow_array::{Array, Int32Array, NullArray, StringArray, UnionArray, new_null_array};
    use arrow_buffer::{BooleanBuffer, ScalarBuffer};
    use arrow_schema::{ArrowError, DataType, Field, UnionFields, UnionMode};
    use std::sync::Arc;

    // `eq_scalar_inner` must agree with a naive element-wise comparison, and must collapse
    // uniform inputs to a scalar, for inputs spanning several 64-bit words.
    #[test]
    fn test_eq_scalar() {
        const ARRAY_LEN: usize = 64 * 4;

        // A small chunk size, so that the chunked scan is exercised by short inputs.
        const EQ_SCALAR_CHUNK_SIZE: usize = 3;

        fn eq_scalar(type_ids: &[i8], target: i8) -> BoolValue {
            eq_scalar_inner(EQ_SCALAR_CHUNK_SIZE, type_ids, target)
        }

        fn naive_eq(type_ids: &[i8], target: i8) -> BooleanBuffer {
            BooleanBuffer::collect_bool(type_ids.len(), |idx| type_ids[idx] == target)
        }

        assert_eq!(eq_scalar(&[], 1), BoolValue::Scalar(true));

        assert_eq!(eq_scalar(&[1], 1), BoolValue::Scalar(true));
        assert_eq!(eq_scalar(&[2], 1), BoolValue::Scalar(false));

        let mut values = [1; ARRAY_LEN];

        assert_eq!(eq_scalar(&values, 1), BoolValue::Scalar(true));
        assert_eq!(eq_scalar(&values, 2), BoolValue::Scalar(false));

        // Uniform inputs of every length below ARRAY_LEN.
        for prefix_len in 1..ARRAY_LEN {
            let prefix = &values[..prefix_len];

            assert_eq!(eq_scalar(prefix, 1), BoolValue::Scalar(true));
            assert_eq!(eq_scalar(prefix, 2), BoolValue::Scalar(false));
        }

        // A single differing element at every possible position.
        for flipped in 0..ARRAY_LEN {
            values[flipped] = 2;

            assert_eq!(
                eq_scalar(&values, 1),
                BoolValue::Buffer(naive_eq(&values, 1))
            );
            assert_eq!(
                eq_scalar(&values, 2),
                BoolValue::Buffer(naive_eq(&values, 2))
            );

            values[flipped] = 1;
        }
    }

    #[test]
    fn test_is_sequential() {
        // A small chunk size, so that short inputs cover full chunks and remainders.
        const CHUNK_SIZE: usize = 3;

        fn is_sequential(v: &[i32]) -> bool {
            is_sequential_generic::<CHUNK_SIZE>(v)
        }

        // Ascending runs `1..=len` of every length from 0 to 8.
        for len in 0..=8 {
            let run: Vec<i32> = (1..=len).collect();

            assert!(is_sequential(&run), "{:?}", run);
        }

        // Descending runs starting at 8, of every length from 2 to 8.
        for len in 2..=8 {
            let run: Vec<i32> = (0..len).map(|step| 8 - step).collect();

            assert!(!is_sequential(&run), "{:?}", run);
        }

        // Ascending runs of length 2 to 8 with a single element, at each position in
        // turn, replaced by 0.
        for len in 2..=8_usize {
            for zeroed in 0..len {
                let mut run: Vec<i32> = (1..=len as i32).collect();
                run[zeroed] = 0;

                assert!(!is_sequential(&run), "{:?}", run);
            }
        }

        // Runs with a gap of one, continuing consecutively after the gap.
        let gapped: [&[i32]; 5] = [
            &[1, 2, 3, 5],
            &[1, 2, 3, 5, 6],
            &[1, 2, 3, 5, 6, 7],
            &[1, 2, 3, 4, 5, 6, 8],
            &[1, 2, 3, 4, 5, 6, 8, 9],
        ];

        for run in gapped {
            assert!(!is_sequential(run), "{:?}", run);
        }
    }

    // Fields of a union with a single Utf8 field "str" (type id 1).
    fn str1() -> UnionFields {
        let fields = vec![Field::new("str", DataType::Utf8, true)];

        UnionFields::try_new(vec![1], fields).unwrap()
    }

    // Fields of a union with a Utf8 field "str" (type id 1) and an Int32 field "int"
    // (type id 3).
    fn str1_int3() -> UnionFields {
        let fields = vec![
            Field::new("str", DataType::Utf8, true),
            Field::new("int", DataType::Int32, true),
        ];

        UnionFields::try_new(vec![1, 3], fields).unwrap()
    }

    // Sparse union with a single field.
    #[test]
    fn sparse_1_1_single_field() {
        let input = UnionArray::try_new(
            str1(),
            ScalarBuffer::from(vec![1, 1]),
            None,
            vec![Arc::new(StringArray::from(vec!["a", "b"]))],
        )
        .unwrap();

        let want = StringArray::from(vec!["a", "b"]);
        let got = union_extract(&input, "str").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Sparse union without rows.
    #[test]
    fn sparse_1_2_empty() {
        let input = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![]),
            None,
            vec![
                Arc::new(StringArray::new_null(0)),
                Arc::new(Int32Array::new_null(0)),
            ],
        )
        .unwrap();

        let want = StringArray::new_null(0);
        let got = union_extract(&input, "str").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Sparse union whose target child has the Null data type.
    #[test]
    fn sparse_1_3a_null_target() {
        let fields = UnionFields::try_new(
            vec![1, 3],
            vec![
                Field::new("str", DataType::Utf8, true),
                Field::new("null", DataType::Null, true),
            ],
        )
        .unwrap();

        let input = UnionArray::try_new(
            fields,
            ScalarBuffer::from(vec![1]),
            None,
            vec![
                Arc::new(StringArray::new_null(1)),
                Arc::new(NullArray::new(1)),
            ],
        )
        .unwrap();

        let want = NullArray::new(1);
        let got = union_extract(&input, "null").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Sparse union whose target child holds only nulls.
    #[test]
    fn sparse_1_3b_null_target() {
        let input = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![1]),
            None,
            vec![
                Arc::new(StringArray::new_null(1)),
                Arc::new(Int32Array::new_null(1)),
            ],
        )
        .unwrap();

        let want = StringArray::new_null(1);
        let got = union_extract(&input, "str").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Sparse union in which every row selects the target.
    #[test]
    fn sparse_2_all_types_match() {
        let input = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![3, 3]),
            None,
            vec![
                Arc::new(StringArray::new_null(2)),
                Arc::new(Int32Array::from(vec![1, 4])),
            ],
        )
        .unwrap();

        let want = Int32Array::from(vec![1, 4]);
        let got = union_extract(&input, "int").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Sparse union in which no row selects the Int32 target.
    #[test]
    fn sparse_3_1_none_match_target_can_contain_null_mask() {
        let input = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![1, 1, 1, 1]),
            None,
            vec![
                Arc::new(StringArray::new_null(4)),
                Arc::new(Int32Array::from(vec![None, Some(4), None, Some(8)])),
            ],
        )
        .unwrap();

        let want = Int32Array::new_null(4);
        let got = union_extract(&input, "int").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Fields of a union with a Utf8 field "str" (type id 1) and a field "union" (type id
    // 3) of the given data type.
    fn str1_union3(union3_datatype: DataType) -> UnionFields {
        let fields = vec![
            Field::new("str", DataType::Utf8, true),
            Field::new("union", union3_datatype, true),
        ];

        UnionFields::try_new(vec![1, 3], fields).unwrap()
    }

    // Sparse union in which no row selects the target, which is itself a union.
    #[test]
    fn sparse_3_2_none_match_cant_contain_null_mask_union_target() {
        let target_fields = str1();
        let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse);

        let input = UnionArray::try_new(
            str1_union3(target_type.clone()),
            ScalarBuffer::from(vec![1, 1]),
            None,
            vec![
                Arc::new(StringArray::new_null(2)),
                Arc::new(
                    UnionArray::try_new(
                        target_fields.clone(),
                        ScalarBuffer::from(vec![1, 1]),
                        None,
                        vec![Arc::new(StringArray::from(vec!["a", "b"]))],
                    )
                    .unwrap(),
                ),
            ],
        )
        .unwrap();

        let want = new_null_array(&target_type, 2);
        let got = union_extract(&input, "union").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Sparse union in which some rows select a target that already contains nulls.
    #[test]
    fn sparse_4_1_1_target_with_nulls() {
        let input = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![3, 3, 1, 1]),
            None,
            vec![
                Arc::new(StringArray::new_null(4)),
                Arc::new(Int32Array::from(vec![None, Some(4), None, Some(8)])),
            ],
        )
        .unwrap();

        let want = Int32Array::from(vec![None, Some(4), None, None]);
        let got = union_extract(&input, "int").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Sparse union in which some rows select a target that contains no nulls.
    #[test]
    fn sparse_4_1_2_target_without_nulls() {
        let input = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![1, 3, 3]),
            None,
            vec![
                Arc::new(StringArray::new_null(3)),
                Arc::new(Int32Array::from(vec![2, 4, 8])),
            ],
        )
        .unwrap();

        let want = Int32Array::from(vec![None, Some(4), Some(8)]);
        let got = union_extract(&input, "int").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Sparse union in which some rows select the target, which is itself a union.
    #[test]
    fn sparse_4_2_some_match_target_cant_contain_null_mask() {
        let target_fields = str1();
        let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse);

        let input = UnionArray::try_new(
            str1_union3(target_type),
            ScalarBuffer::from(vec![3, 1]),
            None,
            vec![
                Arc::new(StringArray::new_null(2)),
                Arc::new(
                    UnionArray::try_new(
                        target_fields.clone(),
                        ScalarBuffer::from(vec![1, 1]),
                        None,
                        vec![Arc::new(StringArray::from(vec!["a", "b"]))],
                    )
                    .unwrap(),
                ),
            ],
        )
        .unwrap();

        let want = UnionArray::try_new(
            target_fields,
            ScalarBuffer::from(vec![1, 1]),
            None,
            vec![Arc::new(StringArray::from(vec![Some("a"), None]))],
        )
        .unwrap();
        let got = union_extract(&input, "union").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Dense union without rows whose target child is empty too.
    #[test]
    fn dense_1_1_both_empty() {
        let input = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![]),
            Some(ScalarBuffer::from(vec![])),
            vec![
                Arc::new(StringArray::new_null(0)),
                Arc::new(Int32Array::new_null(0)),
            ],
        )
        .unwrap();

        let want = StringArray::new_null(0);
        let got = union_extract(&input, "str").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Dense union without rows whose target child is not empty.
    #[test]
    fn dense_1_2_empty_union_target_non_empty() {
        let input = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![]),
            Some(ScalarBuffer::from(vec![])),
            vec![
                Arc::new(StringArray::new_null(1)),
                Arc::new(Int32Array::new_null(0)),
            ],
        )
        .unwrap();

        let want = StringArray::new_null(0);
        let got = union_extract(&input, "str").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Dense union with rows whose target child is empty.
    #[test]
    fn dense_2_non_empty_union_target_empty() {
        let input = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![3, 3]),
            Some(ScalarBuffer::from(vec![0, 1])),
            vec![
                Arc::new(StringArray::new_null(0)),
                Arc::new(Int32Array::new_null(2)),
            ],
        )
        .unwrap();

        let want = StringArray::new_null(2);
        let got = union_extract(&input, "str").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Dense union with an all-null target child shorter than the union.
    #[test]
    fn dense_3_1_null_target_smaller_len() {
        let input = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![3, 3]),
            Some(ScalarBuffer::from(vec![0, 0])),
            vec![
                Arc::new(StringArray::new_null(1)),
                Arc::new(Int32Array::new_null(2)),
            ],
        )
        .unwrap();

        let want = StringArray::new_null(2);
        let got = union_extract(&input, "str").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Dense union with an all-null target child as long as the union.
    #[test]
    fn dense_3_2_null_target_equal_len() {
        let input = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![3, 3]),
            Some(ScalarBuffer::from(vec![0, 0])),
            vec![
                Arc::new(StringArray::new_null(2)),
                Arc::new(Int32Array::new_null(2)),
            ],
        )
        .unwrap();

        let want = StringArray::new_null(2);
        let got = union_extract(&input, "str").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Dense union with an all-null target child longer than the union.
    #[test]
    fn dense_3_3_null_target_bigger_len() {
        let input = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![3, 3]),
            Some(ScalarBuffer::from(vec![0, 0])),
            vec![
                Arc::new(StringArray::new_null(3)),
                Arc::new(Int32Array::new_null(3)),
            ],
        )
        .unwrap();

        let want = StringArray::new_null(2);
        let got = union_extract(&input, "str").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Single-field dense union, sequential offsets, child as long as the union.
    #[test]
    fn dense_4_1a_single_type_sequential_offsets_equal_len() {
        let input = UnionArray::try_new(
            str1(),
            ScalarBuffer::from(vec![1, 1]),
            Some(ScalarBuffer::from(vec![0, 1])),
            vec![Arc::new(StringArray::from(vec!["a1", "b2"]))],
        )
        .unwrap();

        let want = StringArray::from(vec!["a1", "b2"]);
        let got = union_extract(&input, "str").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Single-field dense union, sequential offsets, child longer than the union.
    #[test]
    fn dense_4_2a_single_type_sequential_offsets_bigger() {
        let input = UnionArray::try_new(
            str1(),
            ScalarBuffer::from(vec![1, 1]),
            Some(ScalarBuffer::from(vec![0, 1])),
            vec![Arc::new(StringArray::from(vec!["a1", "b2", "c3"]))],
        )
        .unwrap();

        let want = StringArray::from(vec!["a1", "b2"]);
        let got = union_extract(&input, "str").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Single-field dense union, non-sequential offsets.
    #[test]
    fn dense_4_3a_single_type_non_sequential() {
        let input = UnionArray::try_new(
            str1(),
            ScalarBuffer::from(vec![1, 1]),
            Some(ScalarBuffer::from(vec![0, 2])),
            vec![Arc::new(StringArray::from(vec!["a1", "b2", "c3"]))],
        )
        .unwrap();

        let want = StringArray::from(vec!["a1", "c3"]);
        let got = union_extract(&input, "str").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Dense union with an empty sibling, sequential offsets, child as long as the union.
    #[test]
    fn dense_4_1b_empty_siblings_sequential_equal_len() {
        let input = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![1, 1]),
            Some(ScalarBuffer::from(vec![0, 1])),
            vec![
                Arc::new(StringArray::from(vec!["a", "b"])),
                Arc::new(Int32Array::new_null(0)),
            ],
        )
        .unwrap();

        let want = StringArray::from(vec!["a", "b"]);
        let got = union_extract(&input, "str").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Dense union with an empty sibling, sequential offsets, child longer than the union.
    #[test]
    fn dense_4_2b_empty_siblings_sequential_bigger_len() {
        let input = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![1, 1]),
            Some(ScalarBuffer::from(vec![0, 1])),
            vec![
                Arc::new(StringArray::from(vec!["a", "b", "c"])),
                Arc::new(Int32Array::new_null(0)),
            ],
        )
        .unwrap();

        let want = StringArray::from(vec!["a", "b"]);
        let got = union_extract(&input, "str").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Dense union with an empty sibling, non-sequential offsets.
    #[test]
    fn dense_4_3b_empty_sibling_non_sequential() {
        let input = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![1, 1]),
            Some(ScalarBuffer::from(vec![0, 2])),
            vec![
                Arc::new(StringArray::from(vec!["a", "b", "c"])),
                Arc::new(Int32Array::new_null(0)),
            ],
        )
        .unwrap();

        let want = StringArray::from(vec!["a", "c"]);
        let got = union_extract(&input, "str").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Dense union, all rows select the target, sequential offsets, child as long as the
    // union.
    #[test]
    fn dense_4_1c_all_types_match_sequential_equal_len() {
        let input = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![1, 1]),
            Some(ScalarBuffer::from(vec![0, 1])),
            vec![
                Arc::new(StringArray::from(vec!["a1", "b2"])),
                Arc::new(Int32Array::new_null(2)),
            ],
        )
        .unwrap();

        let want = StringArray::from(vec!["a1", "b2"]);
        let got = union_extract(&input, "str").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Dense union, all rows select the target, sequential offsets, child longer than the
    // union.
    #[test]
    fn dense_4_2c_all_types_match_sequential_bigger_len() {
        let input = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![1, 1]),
            Some(ScalarBuffer::from(vec![0, 1])),
            vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "b3"])),
                Arc::new(Int32Array::new_null(2)),
            ],
        )
        .unwrap();

        let want = StringArray::from(vec!["a1", "b2"]);
        let got = union_extract(&input, "str").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Dense union, all rows select the target, non-sequential offsets.
    #[test]
    fn dense_4_3c_all_types_match_non_sequential() {
        let input = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![1, 1]),
            Some(ScalarBuffer::from(vec![0, 2])),
            vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "b3"])),
                Arc::new(Int32Array::new_null(2)),
            ],
        )
        .unwrap();

        let want = StringArray::from(vec!["a1", "b3"]);
        let got = union_extract(&input, "str").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Dense union, no row selects the target, target child shorter than the union.
    #[test]
    fn dense_5_1a_none_match_less_len() {
        let input = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![3, 3, 3, 3, 3]),
            Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])),
            vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])),
                Arc::new(Int32Array::from(vec![1, 2])),
            ],
        )
        .unwrap();

        let want = StringArray::new_null(5);
        let got = union_extract(&input, "str").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Dense union, no row selects the target, which is itself a union.
    #[test]
    fn dense_5_1b_cant_contain_null_mask() {
        let target_fields = str1();
        let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse);

        let input = UnionArray::try_new(
            str1_union3(target_type.clone()),
            ScalarBuffer::from(vec![1, 1, 1, 1, 1]),
            Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])),
            vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])),
                Arc::new(
                    UnionArray::try_new(
                        target_fields.clone(),
                        ScalarBuffer::from(vec![1]),
                        None,
                        vec![Arc::new(StringArray::from(vec!["a"]))],
                    )
                    .unwrap(),
                ),
            ],
        )
        .unwrap();

        let want = new_null_array(&target_type, 5);
        let got = union_extract(&input, "union").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Dense union, no row selects the target, target child as long as the union.
    #[test]
    fn dense_5_2_none_match_equal_len() {
        let input = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![3, 3, 3, 3, 3]),
            Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])),
            vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "c3", "d4", "e5"])),
                Arc::new(Int32Array::from(vec![1, 2])),
            ],
        )
        .unwrap();

        let want = StringArray::new_null(5);
        let got = union_extract(&input, "str").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Dense union, no row selects the target, target child longer than the union.
    #[test]
    fn dense_5_3_none_match_greater_len() {
        let input = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![3, 3, 3, 3, 3]),
            Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])),
            vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "c3", "d4", "e5", "f6"])),
                Arc::new(Int32Array::from(vec![1, 2])),
            ],
        )
        .unwrap();

        let want = StringArray::new_null(5);
        let got = union_extract(&input, "str").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // Dense union in which only some rows select the target.
    #[test]
    fn dense_6_some_matches() {
        let input = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![3, 3, 1, 1, 1]),
            Some(ScalarBuffer::from(vec![0, 1, 0, 1, 2])),
            vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])),
                Arc::new(Int32Array::from(vec![1, 2])),
            ],
        )
        .unwrap();

        let want = Int32Array::from(vec![Some(1), Some(2), None, None, None]);
        let got = union_extract(&input, "int").unwrap();

        assert_eq!(got.into_data(), want.into_data());
    }

    // A sparse union without fields has no field to extract.
    #[test]
    fn empty_sparse_union() {
        let input = UnionArray::try_new(
            UnionFields::empty(),
            ScalarBuffer::from(vec![]),
            None,
            vec![],
        )
        .unwrap();

        let want = ArrowError::InvalidArgumentError("field a not found on union".into());
        let got = union_extract(&input, "a").unwrap_err();

        assert_eq!(got.to_string(), want.to_string());
    }

    // A dense union without fields has no field to extract.
    #[test]
    fn empty_dense_union() {
        let input = UnionArray::try_new(
            UnionFields::empty(),
            ScalarBuffer::from(vec![]),
            Some(ScalarBuffer::from(vec![])),
            vec![],
        )
        .unwrap();

        let want = ArrowError::InvalidArgumentError("field a not found on union".into());
        let got = union_extract(&input, "a").unwrap_err();

        assert_eq!(got.to_string(), want.to_string());
    }

    // With two fields named "val", extraction by name yields the first one while
    // extraction by id can reach the second.
    #[test]
    fn extract_by_id_sparse_duplicate_names() {
        let fields = UnionFields::try_new(
            [0, 1],
            [
                Field::new("val", DataType::Int32, true),
                Field::new("val", DataType::Utf8, true),
            ],
        )
        .unwrap();

        let input = UnionArray::try_new(
            fields,
            vec![0_i8, 1, 0, 1].into(),
            None,
            vec![
                Arc::new(Int32Array::from(vec![Some(42), None, Some(99), None])) as _,
                Arc::new(StringArray::from(vec![
                    None,
                    Some("hello"),
                    None,
                    Some("world"),
                ])),
            ],
        )
        .unwrap();

        let first_val = union_extract(&input, "val").unwrap();

        assert_eq!(
            *first_val,
            Int32Array::from(vec![Some(42), None, Some(99), None])
        );

        let second_val = union_extract_by_id(&input, 1).unwrap();

        assert_eq!(
            *second_val,
            StringArray::from(vec![None, Some("hello"), None, Some("world")])
        );
    }

    // Extraction by id from a dense union whose two fields share a name.
    #[test]
    fn extract_by_id_dense_duplicate_names() {
        let fields = UnionFields::try_new(
            [0, 1],
            [
                Field::new("val", DataType::Int32, true),
                Field::new("val", DataType::Utf8, true),
            ],
        )
        .unwrap();

        let input = UnionArray::try_new(
            fields,
            vec![0_i8, 1, 0].into(),
            Some(vec![0_i32, 0, 1].into()),
            vec![
                Arc::new(Int32Array::from(vec![Some(42), Some(99)])) as _,
                Arc::new(StringArray::from(vec![Some("hello")])),
            ],
        )
        .unwrap();

        let ints = union_extract_by_id(&input, 0).unwrap();

        assert_eq!(*ints, Int32Array::from(vec![Some(42), None, Some(99)]));

        let strings = union_extract_by_id(&input, 1).unwrap();

        assert_eq!(*strings, StringArray::from(vec![None, Some("hello"), None]));
    }

    // Extraction by an id that no field of the union uses is an error.
    #[test]
    fn extract_by_id_not_found() {
        let fields = UnionFields::try_new(
            [0, 1],
            [
                Field::new("a", DataType::Int32, true),
                Field::new("b", DataType::Utf8, true),
            ],
        )
        .unwrap();

        let input = UnionArray::try_new(
            fields,
            vec![0_i8, 1].into(),
            None,
            vec![
                Arc::new(Int32Array::from(vec![Some(1), None])) as _,
                Arc::new(StringArray::from(vec![None, Some("x")])),
            ],
        )
        .unwrap();

        let want = ArrowError::InvalidArgumentError("type_id 5 not found on union".into());
        let got = union_extract_by_id(&input, 5).unwrap_err();

        assert_eq!(got.to_string(), want.to_string());
    }
}