1use crate::take::take;
21use arrow_array::{
22 Array, ArrayRef, BooleanArray, Int32Array, Scalar, UnionArray, make_array, new_empty_array,
23 new_null_array,
24};
25use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer, ScalarBuffer, bit_util};
26use arrow_data::layout;
27use arrow_schema::{ArrowError, DataType, UnionFields};
28use std::cmp::Ordering;
29use std::sync::Arc;
30
31pub fn union_extract(union_array: &UnionArray, target: &str) -> Result<ArrayRef, ArrowError> {
80 let fields = match union_array.data_type() {
81 DataType::Union(fields, _) => fields,
82 _ => unreachable!(),
83 };
84
85 let (target_type_id, _) = fields
86 .iter()
87 .find(|field| field.1.name() == target)
88 .ok_or_else(|| {
89 ArrowError::InvalidArgumentError(format!("field {target} not found on union"))
90 })?;
91
92 match union_array.offsets() {
93 Some(_) => extract_dense(union_array, fields, target_type_id),
94 None => extract_sparse(union_array, fields, target_type_id),
95 }
96}
97
98fn extract_sparse(
99 union_array: &UnionArray,
100 fields: &UnionFields,
101 target_type_id: i8,
102) -> Result<ArrayRef, ArrowError> {
103 let target = union_array.child(target_type_id);
104
105 if fields.len() == 1 || union_array.is_empty() || target.null_count() == target.len() || target.data_type().is_null()
108 {
110 Ok(Arc::clone(target))
111 } else {
112 match eq_scalar(union_array.type_ids(), target_type_id) {
113 BoolValue::Scalar(true) => Ok(Arc::clone(target)),
115 BoolValue::Scalar(false) => {
117 if layout(target.data_type()).can_contain_null_mask {
118 let data = unsafe {
121 target
122 .into_data()
123 .into_builder()
124 .nulls(Some(NullBuffer::new_null(target.len())))
125 .build_unchecked()
126 };
127
128 Ok(make_array(data))
129 } else {
130 Ok(new_null_array(target.data_type(), target.len()))
132 }
133 }
134 BoolValue::Buffer(selected) => {
136 if layout(target.data_type()).can_contain_null_mask {
137 let nulls = match target.nulls().filter(|n| n.null_count() > 0) {
139 Some(nulls) => &selected & nulls.inner(),
142 None => selected,
144 };
145
146 let data = unsafe {
148 assert_eq!(nulls.len(), target.len());
149
150 target
151 .into_data()
152 .into_builder()
153 .nulls(Some(nulls.into()))
154 .build_unchecked()
155 };
156
157 Ok(make_array(data))
158 } else {
159 Ok(crate::zip::zip(
161 &BooleanArray::new(selected, None),
162 target,
163 &Scalar::new(new_null_array(target.data_type(), 1)),
164 )?)
165 }
166 }
167 }
168 }
169}
170
171fn extract_dense(
172 union_array: &UnionArray,
173 fields: &UnionFields,
174 target_type_id: i8,
175) -> Result<ArrayRef, ArrowError> {
176 let target = union_array.child(target_type_id);
177 let offsets = union_array.offsets().unwrap();
178
179 if union_array.is_empty() {
180 if target.is_empty() {
182 Ok(Arc::clone(target))
184 } else {
185 Ok(new_empty_array(target.data_type()))
187 }
188 } else if target.is_empty() {
189 Ok(new_null_array(target.data_type(), union_array.len()))
191 } else if target.null_count() == target.len() || target.data_type().is_null() {
192 match target.len().cmp(&union_array.len()) {
194 Ordering::Less => Ok(new_null_array(target.data_type(), union_array.len())),
196 Ordering::Equal => Ok(Arc::clone(target)),
198 Ordering::Greater => Ok(target.slice(0, union_array.len())),
200 }
201 } else if fields.len() == 1 || fields
203 .iter()
204 .filter(|(field_type_id, _)| *field_type_id != target_type_id)
205 .all(|(sibling_type_id, _)| union_array.child(sibling_type_id).is_empty())
206 {
208 Ok(extract_dense_all_selected(union_array, target, offsets)?)
210 } else {
211 match eq_scalar(union_array.type_ids(), target_type_id) {
212 BoolValue::Scalar(true) => {
216 Ok(extract_dense_all_selected(union_array, target, offsets)?)
217 }
218 BoolValue::Scalar(false) => {
219 match (target.len().cmp(&union_array.len()), layout(target.data_type()).can_contain_null_mask) {
223 (Ordering::Less, _) | (_, false) => { Ok(new_null_array(target.data_type(), union_array.len()))
226 }
227 (Ordering::Equal, true) => {
229 let data = unsafe {
231 target
232 .into_data()
233 .into_builder()
234 .nulls(Some(NullBuffer::new_null(union_array.len())))
235 .build_unchecked()
236 };
237
238 Ok(make_array(data))
239 }
240 (Ordering::Greater, true) => {
242 let data = unsafe {
244 target
245 .into_data()
246 .slice(0, union_array.len())
247 .into_builder()
248 .nulls(Some(NullBuffer::new_null(union_array.len())))
249 .build_unchecked()
250 };
251
252 Ok(make_array(data))
253 }
254 }
255 }
256 BoolValue::Buffer(selected) => {
257 Ok(take(
259 target,
260 &Int32Array::try_new(offsets.clone(), Some(selected.into()))?,
261 None,
262 )?)
263 }
264 }
265 }
266}
267
268fn extract_dense_all_selected(
269 union_array: &UnionArray,
270 target: &Arc<dyn Array>,
271 offsets: &ScalarBuffer<i32>,
272) -> Result<ArrayRef, ArrowError> {
273 let sequential =
274 target.len() - offsets[0] as usize >= union_array.len() && is_sequential(offsets);
275
276 if sequential && target.len() == union_array.len() {
277 Ok(Arc::clone(target))
279 } else if sequential && target.len() > union_array.len() {
280 Ok(target.slice(offsets[0] as usize, union_array.len()))
282 } else {
283 let indices = Int32Array::try_new(offsets.clone(), None)?;
285
286 Ok(take(target, &indices, None)?)
287 }
288}
289
/// Number of type ids inspected per chunk when `eq_scalar` scans for a leading run
/// of identical comparison results (passed through to `count_first_run`).
const EQ_SCALAR_CHUNK_SIZE: usize = 512;

/// Outcome of comparing every type id of a union against a target type id.
#[derive(Debug, PartialEq)]
enum BoolValue {
    /// All type ids compared the same way: `true` if every one equals the target
    /// (this is also what an empty input reports), `false` if none does.
    Scalar(bool),
    /// Mixed result: bit `i` is set when type id `i` equals the target.
    Buffer(BooleanBuffer),
}
301
302fn eq_scalar(type_ids: &[i8], target: i8) -> BoolValue {
303 eq_scalar_inner(EQ_SCALAR_CHUNK_SIZE, type_ids, target)
304}
305
/// Counts the leading type ids covered by consecutive `chunk_size`-sized chunks in
/// which `f` holds for every element.
///
/// Scanning stops at the first chunk containing an element for which `f` is false.
/// The count is therefore a multiple of `chunk_size`, unless the run reaches the end
/// of `type_ids` (whose last chunk may be shorter). Within a chunk, `f` is evaluated
/// on every element; there is no early exit.
///
/// # Panics
///
/// Panics if `chunk_size` is 0.
fn count_first_run(chunk_size: usize, type_ids: &[i8], mut f: impl FnMut(i8) -> bool) -> usize {
    let mut covered = 0;

    for chunk in type_ids.chunks(chunk_size) {
        let mut chunk_holds = true;
        for &type_id in chunk {
            chunk_holds &= f(type_id);
        }

        if !chunk_holds {
            break;
        }
        covered += chunk.len();
    }

    covered
}
313
314fn eq_scalar_inner(chunk_size: usize, type_ids: &[i8], target: i8) -> BoolValue {
316 let true_bits = count_first_run(chunk_size, type_ids, |v| v == target);
317
318 let (set_bits, val) = if true_bits == type_ids.len() {
319 return BoolValue::Scalar(true);
320 } else if true_bits == 0 {
321 let false_bits = count_first_run(chunk_size, type_ids, |v| v != target);
322
323 if false_bits == type_ids.len() {
324 return BoolValue::Scalar(false);
325 } else {
326 (false_bits, false)
327 }
328 } else {
329 (true_bits, true)
330 };
331
332 let set_bits = set_bits - set_bits % 64;
334
335 let mut buffer =
336 MutableBuffer::new(bit_util::ceil(type_ids.len(), 8)).with_bitset(set_bits / 8, val);
337
338 buffer.extend(type_ids[set_bits..].chunks(64).map(|chunk| {
339 chunk
340 .iter()
341 .copied()
342 .enumerate()
343 .fold(0, |packed, (bit_idx, v)| {
344 packed | (((v == target) as u64) << bit_idx)
345 })
346 }));
347
348 BoolValue::Buffer(BooleanBuffer::new(buffer.into(), 0, type_ids.len()))
349}
350
351const IS_SEQUENTIAL_CHUNK_SIZE: usize = 64;
352
353fn is_sequential(offsets: &[i32]) -> bool {
354 is_sequential_generic::<IS_SEQUENTIAL_CHUNK_SIZE>(offsets)
355}
356
/// Returns `true` when `offsets` is one run of consecutive ascending integers
/// (`first, first + 1, first + 2, ...`). An empty slice is considered sequential.
///
/// The offsets are checked in fixed-size chunks of `N`, each viewed as a `&[i32; N]`,
/// plus a shorter remainder. Comparisons within a chunk are folded with `&` rather
/// than short-circuited, presumably so the compiler can vectorize the loop.
///
/// All arithmetic is done in `i64`. With `i32` arithmetic, offsets close to `i32::MAX`
/// overflow: this panics in debug builds, and in release builds the wrapped values
/// can make a non-sequential slice such as `[i32::MAX, i32::MIN]` look sequential.
///
/// # Panics
///
/// Panics if `N` is 0.
fn is_sequential_generic<const N: usize>(offsets: &[i32]) -> bool {
    let (Some(&first), Some(&last)) = (offsets.first(), offsets.last()) else {
        return true;
    };
    let first = i64::from(first);

    // O(1) rejection: a sequential run must end at `first + len - 1`. Combined with the
    // remainder's internal check below, this also fixes where the remainder must start.
    if first + offsets.len() as i64 - 1 != i64::from(last) {
        return false;
    }

    let chunks = offsets.chunks_exact(N);
    let remainder = chunks.remainder();

    // Each full chunk must be internally sequential and start exactly
    // `chunk_index * N` past the first offset.
    let chunks_sequential = chunks.enumerate().all(|(chunk_index, chunk)| {
        let chunk = <&[i32; N]>::try_from(chunk)
            .expect("chunks_exact yields chunks of exactly N elements");
        let chunk_start = i64::from(chunk[0]);

        chunk
            .iter()
            .enumerate()
            .fold(true, |acc, (i, &offset)| {
                acc & (i64::from(offset) == chunk_start + i as i64)
            })
            && first + (chunk_index * N) as i64 == chunk_start
    });

    chunks_sequential
        && remainder.first().map_or(true, |&start| {
            let start = i64::from(start);
            remainder
                .iter()
                .enumerate()
                .fold(true, |acc, (i, &offset)| {
                    acc & (i64::from(offset) == start + i as i64)
                })
        })
}
399
#[cfg(test)]
mod tests {
    use super::{BoolValue, eq_scalar_inner, is_sequential_generic, union_extract};
    use arrow_array::{Array, Int32Array, NullArray, StringArray, UnionArray, new_null_array};
    use arrow_buffer::{BooleanBuffer, ScalarBuffer};
    use arrow_schema::{ArrowError, DataType, Field, UnionFields, UnionMode};
    use std::sync::Arc;

    /// Builds a sparse union from its parts, panicking if they are inconsistent.
    fn sparse_union(
        fields: UnionFields,
        type_ids: Vec<i8>,
        children: Vec<Arc<dyn Array>>,
    ) -> UnionArray {
        UnionArray::try_new(fields, ScalarBuffer::from(type_ids), None, children).unwrap()
    }

    /// Builds a dense union from its parts, panicking if they are inconsistent.
    fn dense_union(
        fields: UnionFields,
        type_ids: Vec<i8>,
        offsets: Vec<i32>,
        children: Vec<Arc<dyn Array>>,
    ) -> UnionArray {
        UnionArray::try_new(
            fields,
            ScalarBuffer::from(type_ids),
            Some(ScalarBuffer::from(offsets)),
            children,
        )
        .unwrap()
    }

    /// Extracts `field` from `union` and asserts the result equals `expected`,
    /// comparing both as `ArrayData`.
    fn assert_extracts(union: &UnionArray, field: &str, expected: &dyn Array) {
        let extracted = union_extract(union, field).unwrap();
        assert_eq!(extracted.into_data(), expected.to_data());
    }

    #[test]
    fn test_eq_scalar() {
        const ARRAY_LEN: usize = 64 * 4;
        // A tiny chunk size exercises runs ending mid-chunk and partial final chunks.
        const EQ_SCALAR_CHUNK_SIZE: usize = 3;

        fn eq_scalar(type_ids: &[i8], target: i8) -> BoolValue {
            eq_scalar_inner(EQ_SCALAR_CHUNK_SIZE, type_ids, target)
        }

        /// Reference mask: one bit per type id, set when it equals `right`.
        fn cross_check(left: &[i8], right: i8) -> BooleanBuffer {
            BooleanBuffer::collect_bool(left.len(), |i| left[i] == right)
        }

        assert_eq!(eq_scalar(&[], 1), BoolValue::Scalar(true));
        assert_eq!(eq_scalar(&[1], 1), BoolValue::Scalar(true));
        assert_eq!(eq_scalar(&[2], 1), BoolValue::Scalar(false));

        let mut values = [1i8; ARRAY_LEN];

        assert_eq!(eq_scalar(&values, 1), BoolValue::Scalar(true));
        assert_eq!(eq_scalar(&values, 2), BoolValue::Scalar(false));

        // Uniform prefixes of every length report a scalar result.
        for len in 1..ARRAY_LEN {
            let prefix = &values[..len];
            assert_eq!(eq_scalar(prefix, 1), BoolValue::Scalar(true));
            assert_eq!(eq_scalar(prefix, 2), BoolValue::Scalar(false));
        }

        // A single differing value at each position forces the mask path.
        for pos in 0..ARRAY_LEN {
            values[pos] = 2;
            for target in [1, 2] {
                assert_eq!(
                    eq_scalar(&values, target),
                    BoolValue::Buffer(cross_check(&values, target))
                );
            }
            values[pos] = 1;
        }
    }

    #[test]
    fn test_is_sequential() {
        // A small chunk size exercises both full chunks and the remainder.
        const CHUNK_SIZE: usize = 3;

        fn is_sequential(v: &[i32]) -> bool {
            is_sequential_generic::<CHUNK_SIZE>(v)
        }

        // 1, 2, ..., len is sequential for every len in 0..=8.
        for len in 0..=8 {
            let run: Vec<i32> = (1..=len).collect();
            assert!(is_sequential(&run), "{run:?}");
        }

        // Descending runs 8, 7, ... of length 2..=8 are not.
        for len in 2..=8 {
            let run: Vec<i32> = (1..=8).rev().take(len).collect();
            assert!(!is_sequential(&run), "{run:?}");
        }

        // Replacing any single element of 1, 2, ..., len (len in 2..=8) with 0 breaks the run.
        for len in 2..=8usize {
            for pos in 0..len {
                let mut run: Vec<i32> = (1..=len as i32).collect();
                run[pos] = 0;
                assert!(!is_sequential(&run), "{run:?}");
            }
        }

        // Runs containing a gap.
        let gaps: [&[i32]; 5] = [
            &[1, 2, 3, 5],
            &[1, 2, 3, 5, 6],
            &[1, 2, 3, 5, 6, 7],
            &[1, 2, 3, 4, 5, 6, 8],
            &[1, 2, 3, 4, 5, 6, 8, 9],
        ];
        for run in gaps {
            assert!(!is_sequential(run), "{run:?}");
        }
    }

    fn str1() -> UnionFields {
        UnionFields::try_new(vec![1], vec![Field::new("str", DataType::Utf8, true)]).unwrap()
    }

    fn str1_int3() -> UnionFields {
        UnionFields::try_new(
            vec![1, 3],
            vec![
                Field::new("str", DataType::Utf8, true),
                Field::new("int", DataType::Int32, true),
            ],
        )
        .unwrap()
    }

    fn str1_union3(union3_datatype: DataType) -> UnionFields {
        UnionFields::try_new(
            vec![1, 3],
            vec![
                Field::new("str", DataType::Utf8, true),
                Field::new("union", union3_datatype, true),
            ],
        )
        .unwrap()
    }

    #[test]
    fn sparse_1_1_single_field() {
        let union = sparse_union(
            str1(),
            vec![1, 1],
            vec![Arc::new(StringArray::from(vec!["a", "b"]))],
        );

        assert_extracts(&union, "str", &StringArray::from(vec!["a", "b"]));
    }

    #[test]
    fn sparse_1_2_empty() {
        let union = sparse_union(
            str1_int3(),
            vec![],
            vec![
                Arc::new(StringArray::new_null(0)),
                Arc::new(Int32Array::new_null(0)),
            ],
        );

        assert_extracts(&union, "str", &StringArray::new_null(0));
    }

    #[test]
    fn sparse_1_3a_null_target() {
        let fields = UnionFields::try_new(
            vec![1, 3],
            vec![
                Field::new("str", DataType::Utf8, true),
                Field::new("null", DataType::Null, true),
            ],
        )
        .unwrap();
        let union = sparse_union(
            fields,
            vec![1],
            vec![Arc::new(StringArray::new_null(1)), Arc::new(NullArray::new(1))],
        );

        assert_extracts(&union, "null", &NullArray::new(1));
    }

    #[test]
    fn sparse_1_3b_null_target() {
        let union = sparse_union(
            str1_int3(),
            vec![1],
            vec![
                Arc::new(StringArray::new_null(1)),
                Arc::new(Int32Array::new_null(1)),
            ],
        );

        assert_extracts(&union, "str", &StringArray::new_null(1));
    }

    #[test]
    fn sparse_2_all_types_match() {
        let union = sparse_union(
            str1_int3(),
            vec![3, 3],
            vec![
                Arc::new(StringArray::new_null(2)),
                Arc::new(Int32Array::from(vec![1, 4])),
            ],
        );

        assert_extracts(&union, "int", &Int32Array::from(vec![1, 4]));
    }

    #[test]
    fn sparse_3_1_none_match_target_can_contain_null_mask() {
        let union = sparse_union(
            str1_int3(),
            vec![1, 1, 1, 1],
            vec![
                Arc::new(StringArray::new_null(4)),
                Arc::new(Int32Array::from(vec![None, Some(4), None, Some(8)])),
            ],
        );

        assert_extracts(&union, "int", &Int32Array::new_null(4));
    }

    #[test]
    fn sparse_3_2_none_match_cant_contain_null_mask_union_target() {
        let target_fields = str1();
        let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse);

        let inner = sparse_union(
            target_fields.clone(),
            vec![1, 1],
            vec![Arc::new(StringArray::from(vec!["a", "b"]))],
        );
        let union = sparse_union(
            str1_union3(target_type.clone()),
            vec![1, 1],
            vec![Arc::new(StringArray::new_null(2)), Arc::new(inner)],
        );

        let expected = new_null_array(&target_type, 2);
        assert_extracts(&union, "union", expected.as_ref());
    }

    #[test]
    fn sparse_4_1_1_target_with_nulls() {
        let union = sparse_union(
            str1_int3(),
            vec![3, 3, 1, 1],
            vec![
                Arc::new(StringArray::new_null(4)),
                Arc::new(Int32Array::from(vec![None, Some(4), None, Some(8)])),
            ],
        );

        assert_extracts(
            &union,
            "int",
            &Int32Array::from(vec![None, Some(4), None, None]),
        );
    }

    #[test]
    fn sparse_4_1_2_target_without_nulls() {
        let union = sparse_union(
            str1_int3(),
            vec![1, 3, 3],
            vec![
                Arc::new(StringArray::new_null(3)),
                Arc::new(Int32Array::from(vec![2, 4, 8])),
            ],
        );

        assert_extracts(&union, "int", &Int32Array::from(vec![None, Some(4), Some(8)]));
    }

    #[test]
    fn sparse_4_2_some_match_target_cant_contain_null_mask() {
        let target_fields = str1();
        let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse);

        let inner = sparse_union(
            target_fields.clone(),
            vec![1, 1],
            vec![Arc::new(StringArray::from(vec!["a", "b"]))],
        );
        let union = sparse_union(
            str1_union3(target_type),
            vec![3, 1],
            vec![Arc::new(StringArray::new_null(2)), Arc::new(inner)],
        );

        let expected = sparse_union(
            target_fields,
            vec![1, 1],
            vec![Arc::new(StringArray::from(vec![Some("a"), None]))],
        );
        assert_extracts(&union, "union", &expected);
    }

    #[test]
    fn dense_1_1_both_empty() {
        let union = dense_union(
            str1_int3(),
            vec![],
            vec![],
            vec![
                Arc::new(StringArray::new_null(0)),
                Arc::new(Int32Array::new_null(0)),
            ],
        );

        assert_extracts(&union, "str", &StringArray::new_null(0));
    }

    #[test]
    fn dense_1_2_empty_union_target_non_empty() {
        let union = dense_union(
            str1_int3(),
            vec![],
            vec![],
            vec![
                Arc::new(StringArray::new_null(1)),
                Arc::new(Int32Array::new_null(0)),
            ],
        );

        assert_extracts(&union, "str", &StringArray::new_null(0));
    }

    #[test]
    fn dense_2_non_empty_union_target_empty() {
        let union = dense_union(
            str1_int3(),
            vec![3, 3],
            vec![0, 1],
            vec![
                Arc::new(StringArray::new_null(0)),
                Arc::new(Int32Array::new_null(2)),
            ],
        );

        assert_extracts(&union, "str", &StringArray::new_null(2));
    }

    #[test]
    fn dense_3_1_null_target_smaller_len() {
        let union = dense_union(
            str1_int3(),
            vec![3, 3],
            vec![0, 0],
            vec![
                Arc::new(StringArray::new_null(1)),
                Arc::new(Int32Array::new_null(2)),
            ],
        );

        assert_extracts(&union, "str", &StringArray::new_null(2));
    }

    #[test]
    fn dense_3_2_null_target_equal_len() {
        let union = dense_union(
            str1_int3(),
            vec![3, 3],
            vec![0, 0],
            vec![
                Arc::new(StringArray::new_null(2)),
                Arc::new(Int32Array::new_null(2)),
            ],
        );

        assert_extracts(&union, "str", &StringArray::new_null(2));
    }

    #[test]
    fn dense_3_3_null_target_bigger_len() {
        let union = dense_union(
            str1_int3(),
            vec![3, 3],
            vec![0, 0],
            vec![
                Arc::new(StringArray::new_null(3)),
                Arc::new(Int32Array::new_null(3)),
            ],
        );

        assert_extracts(&union, "str", &StringArray::new_null(2));
    }

    #[test]
    fn dense_4_1a_single_type_sequential_offsets_equal_len() {
        let union = dense_union(
            str1(),
            vec![1, 1],
            vec![0, 1],
            vec![Arc::new(StringArray::from(vec!["a1", "b2"]))],
        );

        assert_extracts(&union, "str", &StringArray::from(vec!["a1", "b2"]));
    }

    #[test]
    fn dense_4_2a_single_type_sequential_offsets_bigger() {
        let union = dense_union(
            str1(),
            vec![1, 1],
            vec![0, 1],
            vec![Arc::new(StringArray::from(vec!["a1", "b2", "c3"]))],
        );

        assert_extracts(&union, "str", &StringArray::from(vec!["a1", "b2"]));
    }

    #[test]
    fn dense_4_3a_single_type_non_sequential() {
        let union = dense_union(
            str1(),
            vec![1, 1],
            vec![0, 2],
            vec![Arc::new(StringArray::from(vec!["a1", "b2", "c3"]))],
        );

        assert_extracts(&union, "str", &StringArray::from(vec!["a1", "c3"]));
    }

    #[test]
    fn dense_4_1b_empty_siblings_sequential_equal_len() {
        let union = dense_union(
            str1_int3(),
            vec![1, 1],
            vec![0, 1],
            vec![
                Arc::new(StringArray::from(vec!["a", "b"])),
                Arc::new(Int32Array::new_null(0)),
            ],
        );

        assert_extracts(&union, "str", &StringArray::from(vec!["a", "b"]));
    }

    #[test]
    fn dense_4_2b_empty_siblings_sequential_bigger_len() {
        let union = dense_union(
            str1_int3(),
            vec![1, 1],
            vec![0, 1],
            vec![
                Arc::new(StringArray::from(vec!["a", "b", "c"])),
                Arc::new(Int32Array::new_null(0)),
            ],
        );

        assert_extracts(&union, "str", &StringArray::from(vec!["a", "b"]));
    }

    #[test]
    fn dense_4_3b_empty_sibling_non_sequential() {
        let union = dense_union(
            str1_int3(),
            vec![1, 1],
            vec![0, 2],
            vec![
                Arc::new(StringArray::from(vec!["a", "b", "c"])),
                Arc::new(Int32Array::new_null(0)),
            ],
        );

        assert_extracts(&union, "str", &StringArray::from(vec!["a", "c"]));
    }

    #[test]
    fn dense_4_1c_all_types_match_sequential_equal_len() {
        let union = dense_union(
            str1_int3(),
            vec![1, 1],
            vec![0, 1],
            vec![
                Arc::new(StringArray::from(vec!["a1", "b2"])),
                Arc::new(Int32Array::new_null(2)),
            ],
        );

        assert_extracts(&union, "str", &StringArray::from(vec!["a1", "b2"]));
    }

    #[test]
    fn dense_4_2c_all_types_match_sequential_bigger_len() {
        let union = dense_union(
            str1_int3(),
            vec![1, 1],
            vec![0, 1],
            vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "b3"])),
                Arc::new(Int32Array::new_null(2)),
            ],
        );

        assert_extracts(&union, "str", &StringArray::from(vec!["a1", "b2"]));
    }

    #[test]
    fn dense_4_3c_all_types_match_non_sequential() {
        let union = dense_union(
            str1_int3(),
            vec![1, 1],
            vec![0, 2],
            vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "b3"])),
                Arc::new(Int32Array::new_null(2)),
            ],
        );

        assert_extracts(&union, "str", &StringArray::from(vec!["a1", "b3"]));
    }

    #[test]
    fn dense_5_1a_none_match_less_len() {
        let union = dense_union(
            str1_int3(),
            vec![3; 5],
            vec![0, 0, 0, 1, 1],
            vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])),
                Arc::new(Int32Array::from(vec![1, 2])),
            ],
        );

        assert_extracts(&union, "str", &StringArray::new_null(5));
    }

    #[test]
    fn dense_5_1b_cant_contain_null_mask() {
        let target_fields = str1();
        let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse);

        let inner = sparse_union(
            target_fields.clone(),
            vec![1],
            vec![Arc::new(StringArray::from(vec!["a"]))],
        );
        let union = dense_union(
            str1_union3(target_type.clone()),
            vec![1; 5],
            vec![0, 0, 0, 1, 1],
            vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])),
                Arc::new(inner),
            ],
        );

        let expected = new_null_array(&target_type, 5);
        assert_extracts(&union, "union", expected.as_ref());
    }

    #[test]
    fn dense_5_2_none_match_equal_len() {
        let union = dense_union(
            str1_int3(),
            vec![3; 5],
            vec![0, 0, 0, 1, 1],
            vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "c3", "d4", "e5"])),
                Arc::new(Int32Array::from(vec![1, 2])),
            ],
        );

        assert_extracts(&union, "str", &StringArray::new_null(5));
    }

    #[test]
    fn dense_5_3_none_match_greater_len() {
        let union = dense_union(
            str1_int3(),
            vec![3; 5],
            vec![0, 0, 0, 1, 1],
            vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "c3", "d4", "e5", "f6"])),
                Arc::new(Int32Array::from(vec![1, 2])),
            ],
        );

        assert_extracts(&union, "str", &StringArray::new_null(5));
    }

    #[test]
    fn dense_6_some_matches() {
        let union = dense_union(
            str1_int3(),
            vec![3, 3, 1, 1, 1],
            vec![0, 1, 0, 1, 2],
            vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])),
                Arc::new(Int32Array::from(vec![1, 2])),
            ],
        );

        assert_extracts(
            &union,
            "int",
            &Int32Array::from(vec![Some(1), Some(2), None, None, None]),
        );
    }

    #[test]
    fn empty_sparse_union() {
        let union = sparse_union(UnionFields::empty(), vec![], vec![]);

        assert_eq!(
            union_extract(&union, "a").unwrap_err().to_string(),
            ArrowError::InvalidArgumentError("field a not found on union".into()).to_string()
        );
    }

    #[test]
    fn empty_dense_union() {
        let union = dense_union(UnionFields::empty(), vec![], vec![], vec![]);

        assert_eq!(
            union_extract(&union, "a").unwrap_err().to_string(),
            ArrowError::InvalidArgumentError("field a not found on union".into()).to_string()
        );
    }
}