1#![allow(clippy::enum_clike_unportable_variant)]
18
19use crate::{Array, ArrayRef, make_array};
20use arrow_buffer::bit_chunk_iterator::{BitChunkIterator, BitChunks};
21use arrow_buffer::buffer::NullBuffer;
22use arrow_buffer::{BooleanBuffer, MutableBuffer, ScalarBuffer};
23use arrow_data::{ArrayData, ArrayDataBuilder};
24use arrow_schema::{ArrowError, DataType, UnionFields, UnionMode};
25use std::any::Any;
28use std::collections::HashSet;
29use std::sync::Arc;
30
31#[derive(Clone)]
123pub struct UnionArray {
124 data_type: DataType,
125 type_ids: ScalarBuffer<i8>,
126 offsets: Option<ScalarBuffer<i32>>,
127 fields: Vec<Option<ArrayRef>>,
128}
129
130impl UnionArray {
131 pub unsafe fn new_unchecked(
150 fields: UnionFields,
151 type_ids: ScalarBuffer<i8>,
152 offsets: Option<ScalarBuffer<i32>>,
153 children: Vec<ArrayRef>,
154 ) -> Self {
155 let mode = if offsets.is_some() {
156 UnionMode::Dense
157 } else {
158 UnionMode::Sparse
159 };
160
161 let len = type_ids.len();
162 let builder = ArrayData::builder(DataType::Union(fields, mode))
163 .add_buffer(type_ids.into_inner())
164 .child_data(children.into_iter().map(Array::into_data).collect())
165 .len(len);
166
167 let data = match offsets {
168 Some(offsets) => unsafe { builder.add_buffer(offsets.into_inner()).build_unchecked() },
169 None => unsafe { builder.build_unchecked() },
170 };
171 Self::from(data)
172 }
173
174 pub fn try_new(
178 fields: UnionFields,
179 type_ids: ScalarBuffer<i8>,
180 offsets: Option<ScalarBuffer<i32>>,
181 children: Vec<ArrayRef>,
182 ) -> Result<Self, ArrowError> {
183 if fields.len() != children.len() {
185 return Err(ArrowError::InvalidArgumentError(
186 "Union fields length must match child arrays length".to_string(),
187 ));
188 }
189
190 if let Some(offsets) = &offsets {
191 if offsets.len() != type_ids.len() {
193 return Err(ArrowError::InvalidArgumentError(
194 "Type Ids and Offsets lengths must match".to_string(),
195 ));
196 }
197 } else {
198 for child in &children {
200 if child.len() != type_ids.len() {
201 return Err(ArrowError::InvalidArgumentError(
202 "Sparse union child arrays must be equal in length to the length of the union".to_string(),
203 ));
204 }
205 }
206 }
207
208 let max_id = fields.iter().map(|(i, _)| i).max().unwrap_or_default() as usize;
210 let mut array_lens = vec![i32::MIN; max_id + 1];
211 for (cd, (field_id, _)) in children.iter().zip(fields.iter()) {
212 array_lens[field_id as usize] = cd.len() as i32;
213 }
214
215 for id in &type_ids {
217 match array_lens.get(*id as usize) {
218 Some(x) if *x != i32::MIN => {}
219 _ => {
220 return Err(ArrowError::InvalidArgumentError(
221 "Type Ids values must match one of the field type ids".to_owned(),
222 ));
223 }
224 }
225 }
226
227 if let Some(offsets) = &offsets {
229 let mut iter = type_ids.iter().zip(offsets.iter());
230 if iter.any(|(type_id, &offset)| offset < 0 || offset >= array_lens[*type_id as usize])
231 {
232 return Err(ArrowError::InvalidArgumentError(
233 "Offsets must be non-negative and within the length of the Array".to_owned(),
234 ));
235 }
236 }
237
238 let union_array = unsafe { Self::new_unchecked(fields, type_ids, offsets, children) };
241 Ok(union_array)
242 }
243
244 pub fn child(&self, type_id: i8) -> &ArrayRef {
251 assert!((type_id as usize) < self.fields.len());
252 let boxed = &self.fields[type_id as usize];
253 boxed.as_ref().expect("invalid type id")
254 }
255
256 pub fn type_id(&self, index: usize) -> i8 {
262 assert!(index < self.type_ids.len());
263 self.type_ids[index]
264 }
265
266 pub fn type_ids(&self) -> &ScalarBuffer<i8> {
268 &self.type_ids
269 }
270
271 pub fn offsets(&self) -> Option<&ScalarBuffer<i32>> {
273 self.offsets.as_ref()
274 }
275
276 pub fn value_offset(&self, index: usize) -> usize {
282 assert!(index < self.len());
283 match &self.offsets {
284 Some(offsets) => offsets[index] as usize,
285 None => self.offset() + index,
286 }
287 }
288
289 pub fn value(&self, i: usize) -> ArrayRef {
297 let type_id = self.type_id(i);
298 let value_offset = self.value_offset(i);
299 let child = self.child(type_id);
300 child.slice(value_offset, 1)
301 }
302
303 pub fn type_names(&self) -> Vec<&str> {
305 match self.data_type() {
306 DataType::Union(fields, _) => fields
307 .iter()
308 .map(|(_, f)| f.name().as_str())
309 .collect::<Vec<&str>>(),
310 _ => unreachable!("Union array's data type is not a union!"),
311 }
312 }
313
314 pub fn fields(&self) -> &UnionFields {
316 match self.data_type() {
317 DataType::Union(fields, _) => fields,
318 _ => unreachable!("Union array's data type is not a union!"),
319 }
320 }
321
322 pub fn is_dense(&self) -> bool {
324 match self.data_type() {
325 DataType::Union(_, mode) => mode == &UnionMode::Dense,
326 _ => unreachable!("Union array's data type is not a union!"),
327 }
328 }
329
330 pub fn slice(&self, offset: usize, length: usize) -> Self {
332 let (offsets, fields) = match self.offsets.as_ref() {
333 Some(offsets) => (Some(offsets.slice(offset, length)), self.fields.clone()),
335 None => {
337 let fields = self
338 .fields
339 .iter()
340 .map(|x| x.as_ref().map(|x| x.slice(offset, length)))
341 .collect();
342 (None, fields)
343 }
344 };
345
346 Self {
347 data_type: self.data_type.clone(),
348 type_ids: self.type_ids.slice(offset, length),
349 offsets,
350 fields,
351 }
352 }
353
354 #[allow(clippy::type_complexity)]
382 pub fn into_parts(
383 self,
384 ) -> (
385 UnionFields,
386 ScalarBuffer<i8>,
387 Option<ScalarBuffer<i32>>,
388 Vec<ArrayRef>,
389 ) {
390 let Self {
391 data_type,
392 type_ids,
393 offsets,
394 mut fields,
395 } = self;
396 match data_type {
397 DataType::Union(union_fields, _) => {
398 let children = union_fields
399 .iter()
400 .map(|(type_id, _)| fields[type_id as usize].take().unwrap())
401 .collect();
402 (union_fields, type_ids, offsets, children)
403 }
404 _ => unreachable!(),
405 }
406 }
407
408 fn mask_sparse_skip_without_nulls(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
410 let fold = |(with_nulls_selected, union_nulls), (is_field, field_nulls)| {
416 (
417 with_nulls_selected | is_field,
418 union_nulls | (is_field & field_nulls),
419 )
420 };
421
422 self.mask_sparse_helper(
423 nulls,
424 |type_ids_chunk_array, nulls_masks_iters| {
425 let (with_nulls_selected, union_nulls) = nulls_masks_iters
426 .iter_mut()
427 .map(|(field_type_id, field_nulls)| {
428 let field_nulls = field_nulls.next().unwrap();
429 let is_field = selection_mask(type_ids_chunk_array, *field_type_id);
430
431 (is_field, field_nulls)
432 })
433 .fold((0, 0), fold);
434
435 let without_nulls_selected = !with_nulls_selected;
437
438 without_nulls_selected | union_nulls
441 },
442 |type_ids_remainder, bit_chunks| {
443 let (with_nulls_selected, union_nulls) = bit_chunks
444 .iter()
445 .map(|(field_type_id, field_bit_chunks)| {
446 let field_nulls = field_bit_chunks.remainder_bits();
447 let is_field = selection_mask(type_ids_remainder, *field_type_id);
448
449 (is_field, field_nulls)
450 })
451 .fold((0, 0), fold);
452
453 let without_nulls_selected = !with_nulls_selected;
454
455 without_nulls_selected | union_nulls
456 },
457 )
458 }
459
460 fn mask_sparse_skip_fully_null(&self, mut nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
462 let fields = match self.data_type() {
463 DataType::Union(fields, _) => fields,
464 _ => unreachable!("Union array's data type is not a union!"),
465 };
466
467 let type_ids = fields.iter().map(|(id, _)| id).collect::<HashSet<_>>();
468 let with_nulls = nulls.iter().map(|(id, _)| *id).collect::<HashSet<_>>();
469
470 let without_nulls_ids = type_ids
471 .difference(&with_nulls)
472 .copied()
473 .collect::<Vec<_>>();
474
475 nulls.retain(|(_, nulls)| nulls.null_count() < nulls.len());
476
477 self.mask_sparse_helper(
482 nulls,
483 |type_ids_chunk_array, nulls_masks_iters| {
484 let union_nulls = nulls_masks_iters.iter_mut().fold(
485 0,
486 |union_nulls, (field_type_id, nulls_iter)| {
487 let field_nulls = nulls_iter.next().unwrap();
488
489 if field_nulls == 0 {
490 union_nulls
491 } else {
492 let is_field = selection_mask(type_ids_chunk_array, *field_type_id);
493
494 union_nulls | (is_field & field_nulls)
495 }
496 },
497 );
498
499 let without_nulls_selected =
501 without_nulls_selected(type_ids_chunk_array, &without_nulls_ids);
502
503 union_nulls | without_nulls_selected
506 },
507 |type_ids_remainder, bit_chunks| {
508 let union_nulls =
509 bit_chunks
510 .iter()
511 .fold(0, |union_nulls, (field_type_id, field_bit_chunks)| {
512 let is_field = selection_mask(type_ids_remainder, *field_type_id);
513 let field_nulls = field_bit_chunks.remainder_bits();
514
515 union_nulls | is_field & field_nulls
516 });
517
518 union_nulls | without_nulls_selected(type_ids_remainder, &without_nulls_ids)
519 },
520 )
521 }
522
523 fn mask_sparse_all_with_nulls_skip_one(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
525 self.mask_sparse_helper(
532 nulls,
533 |type_ids_chunk_array, nulls_masks_iters| {
534 let (is_not_first, union_nulls) = nulls_masks_iters[1..] .iter_mut()
536 .fold(
537 (0, 0),
538 |(is_not_first, union_nulls), (field_type_id, nulls_iter)| {
539 let field_nulls = nulls_iter.next().unwrap();
540 let is_field = selection_mask(type_ids_chunk_array, *field_type_id);
541
542 (
543 is_not_first | is_field,
544 union_nulls | (is_field & field_nulls),
545 )
546 },
547 );
548
549 let is_first = !is_not_first;
550 let first_nulls = nulls_masks_iters[0].1.next().unwrap();
551
552 (is_first & first_nulls) | union_nulls
553 },
554 |type_ids_remainder, bit_chunks| {
555 bit_chunks
556 .iter()
557 .fold(0, |union_nulls, (field_type_id, field_bit_chunks)| {
558 let field_nulls = field_bit_chunks.remainder_bits();
559 let is_field = selection_mask(type_ids_remainder, *field_type_id);
562
563 union_nulls | (is_field & field_nulls)
564 })
565 },
566 )
567 }
568
569 fn mask_sparse_helper(
572 &self,
573 nulls: Vec<(i8, NullBuffer)>,
574 mut mask_chunk: impl FnMut(&[i8; 64], &mut [(i8, BitChunkIterator)]) -> u64,
575 mask_remainder: impl FnOnce(&[i8], &[(i8, BitChunks)]) -> u64,
576 ) -> BooleanBuffer {
577 let bit_chunks = nulls
578 .iter()
579 .map(|(type_id, nulls)| (*type_id, nulls.inner().bit_chunks()))
580 .collect::<Vec<_>>();
581
582 let mut nulls_masks_iter = bit_chunks
583 .iter()
584 .map(|(type_id, bit_chunks)| (*type_id, bit_chunks.iter()))
585 .collect::<Vec<_>>();
586
587 let chunks_exact = self.type_ids.chunks_exact(64);
588 let remainder = chunks_exact.remainder();
589
590 let chunks = chunks_exact.map(|type_ids_chunk| {
591 let type_ids_chunk_array = <&[i8; 64]>::try_from(type_ids_chunk).unwrap();
592
593 mask_chunk(type_ids_chunk_array, &mut nulls_masks_iter)
594 });
595
596 let mut buffer = unsafe { MutableBuffer::from_trusted_len_iter(chunks) };
599
600 if !remainder.is_empty() {
601 buffer.push(mask_remainder(remainder, &bit_chunks));
602 }
603
604 BooleanBuffer::new(buffer.into(), 0, self.type_ids.len())
605 }
606
607 fn gather_nulls(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
609 let one_null = NullBuffer::new_null(1);
610 let one_valid = NullBuffer::new_valid(1);
611
612 let mut logical_nulls_array = [(&one_valid, Mask::Zero); 256];
619
620 for (type_id, nulls) in &nulls {
621 if nulls.null_count() == nulls.len() {
622 logical_nulls_array[*type_id as u8 as usize] = (&one_null, Mask::Zero);
624 } else {
625 logical_nulls_array[*type_id as u8 as usize] = (nulls, Mask::Max);
626 }
627 }
628
629 match &self.offsets {
630 Some(offsets) => {
631 assert_eq!(self.type_ids.len(), offsets.len());
632
633 BooleanBuffer::collect_bool(self.type_ids.len(), |i| unsafe {
634 let type_id = *self.type_ids.get_unchecked(i);
636 let offset = *offsets.get_unchecked(i);
638
639 let (nulls, offset_mask) = &logical_nulls_array[type_id as u8 as usize];
640
641 nulls
647 .inner()
648 .value_unchecked(offset as usize & *offset_mask as usize)
649 })
650 }
651 None => {
652 BooleanBuffer::collect_bool(self.type_ids.len(), |index| unsafe {
653 let type_id = *self.type_ids.get_unchecked(index);
655
656 let (nulls, index_mask) = &logical_nulls_array[type_id as u8 as usize];
657
658 nulls.inner().value_unchecked(index & *index_mask as usize)
664 })
665 }
666 }
667 }
668
669 fn fields_logical_nulls(&self) -> Vec<(i8, NullBuffer)> {
672 self.fields
673 .iter()
674 .enumerate()
675 .filter_map(|(type_id, field)| Some((type_id as i8, field.as_ref()?.logical_nulls()?)))
676 .filter(|(_, nulls)| nulls.null_count() > 0)
677 .collect()
678 }
679}
680
681impl From<ArrayData> for UnionArray {
682 fn from(data: ArrayData) -> Self {
683 let (fields, mode) = match data.data_type() {
684 DataType::Union(fields, mode) => (fields, *mode),
685 d => panic!("UnionArray expected ArrayData with type Union got {d}"),
686 };
687 let (type_ids, offsets) = match mode {
688 UnionMode::Sparse => (
689 ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()),
690 None,
691 ),
692 UnionMode::Dense => (
693 ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()),
694 Some(ScalarBuffer::new(
695 data.buffers()[1].clone(),
696 data.offset(),
697 data.len(),
698 )),
699 ),
700 };
701
702 let max_id = fields.iter().map(|(i, _)| i).max().unwrap_or_default() as usize;
703 let mut boxed_fields = vec![None; max_id + 1];
704 for (cd, (field_id, _)) in data.child_data().iter().zip(fields.iter()) {
705 boxed_fields[field_id as usize] = Some(make_array(cd.clone()));
706 }
707 Self {
708 data_type: data.data_type().clone(),
709 type_ids,
710 offsets,
711 fields: boxed_fields,
712 }
713 }
714}
715
716impl From<UnionArray> for ArrayData {
717 fn from(array: UnionArray) -> Self {
718 let len = array.len();
719 let f = match &array.data_type {
720 DataType::Union(f, _) => f,
721 _ => unreachable!(),
722 };
723 let buffers = match array.offsets {
724 Some(o) => vec![array.type_ids.into_inner(), o.into_inner()],
725 None => vec![array.type_ids.into_inner()],
726 };
727
728 let child = f
729 .iter()
730 .map(|(i, _)| array.fields[i as usize].as_ref().unwrap().to_data())
731 .collect();
732
733 let builder = ArrayDataBuilder::new(array.data_type)
734 .len(len)
735 .buffers(buffers)
736 .child_data(child);
737 unsafe { builder.build_unchecked() }
738 }
739}
740
741impl Array for UnionArray {
742 fn as_any(&self) -> &dyn Any {
743 self
744 }
745
746 fn to_data(&self) -> ArrayData {
747 self.clone().into()
748 }
749
750 fn into_data(self) -> ArrayData {
751 self.into()
752 }
753
754 fn data_type(&self) -> &DataType {
755 &self.data_type
756 }
757
758 fn slice(&self, offset: usize, length: usize) -> ArrayRef {
759 Arc::new(self.slice(offset, length))
760 }
761
762 fn len(&self) -> usize {
763 self.type_ids.len()
764 }
765
766 fn is_empty(&self) -> bool {
767 self.type_ids.is_empty()
768 }
769
770 fn shrink_to_fit(&mut self) {
771 self.type_ids.shrink_to_fit();
772 if let Some(offsets) = &mut self.offsets {
773 offsets.shrink_to_fit();
774 }
775 for array in self.fields.iter_mut().flatten() {
776 array.shrink_to_fit();
777 }
778 self.fields.shrink_to_fit();
779 }
780
781 fn offset(&self) -> usize {
782 0
783 }
784
785 fn nulls(&self) -> Option<&NullBuffer> {
786 None
787 }
788
789 fn logical_nulls(&self) -> Option<NullBuffer> {
790 let fields = match self.data_type() {
791 DataType::Union(fields, _) => fields,
792 _ => unreachable!(),
793 };
794
795 if fields.len() <= 1 {
796 return self.fields.iter().find_map(|field_opt| {
797 field_opt
798 .as_ref()
799 .and_then(|field| field.logical_nulls())
800 .map(|logical_nulls| {
801 if self.is_dense() {
802 self.gather_nulls(vec![(0, logical_nulls)]).into()
803 } else {
804 logical_nulls
805 }
806 })
807 });
808 }
809
810 let logical_nulls = self.fields_logical_nulls();
811
812 if logical_nulls.is_empty() {
813 return None;
814 }
815
816 let fully_null_count = logical_nulls
817 .iter()
818 .filter(|(_, nulls)| nulls.null_count() == nulls.len())
819 .count();
820
821 if fully_null_count == fields.len() {
822 if let Some((_, exactly_sized)) = logical_nulls
823 .iter()
824 .find(|(_, nulls)| nulls.len() == self.len())
825 {
826 return Some(exactly_sized.clone());
827 }
828
829 if let Some((_, bigger)) = logical_nulls
830 .iter()
831 .find(|(_, nulls)| nulls.len() > self.len())
832 {
833 return Some(bigger.slice(0, self.len()));
834 }
835
836 return Some(NullBuffer::new_null(self.len()));
837 }
838
839 let boolean_buffer = match &self.offsets {
840 Some(_) => self.gather_nulls(logical_nulls),
841 None => {
842 let gather_relative_cost = if cfg!(target_feature = "avx2") {
850 10
851 } else if cfg!(target_feature = "sse4.1") {
852 3
853 } else if cfg!(target_arch = "x86") || cfg!(target_arch = "x86_64") {
854 2
856 } else {
857 0
861 };
862
863 let strategies = [
864 (SparseStrategy::Gather, gather_relative_cost, true),
865 (
866 SparseStrategy::MaskAllFieldsWithNullsSkipOne,
867 fields.len() - 1,
868 fields.len() == logical_nulls.len(),
869 ),
870 (
871 SparseStrategy::MaskSkipWithoutNulls,
872 logical_nulls.len(),
873 true,
874 ),
875 (
876 SparseStrategy::MaskSkipFullyNull,
877 fields.len() - fully_null_count,
878 true,
879 ),
880 ];
881
882 let (strategy, _, _) = strategies
883 .iter()
884 .filter(|(_, _, applicable)| *applicable)
885 .min_by_key(|(_, cost, _)| cost)
886 .unwrap();
887
888 match strategy {
889 SparseStrategy::Gather => self.gather_nulls(logical_nulls),
890 SparseStrategy::MaskAllFieldsWithNullsSkipOne => {
891 self.mask_sparse_all_with_nulls_skip_one(logical_nulls)
892 }
893 SparseStrategy::MaskSkipWithoutNulls => {
894 self.mask_sparse_skip_without_nulls(logical_nulls)
895 }
896 SparseStrategy::MaskSkipFullyNull => {
897 self.mask_sparse_skip_fully_null(logical_nulls)
898 }
899 }
900 }
901 };
902
903 let null_buffer = NullBuffer::from(boolean_buffer);
904
905 if null_buffer.null_count() > 0 {
906 Some(null_buffer)
907 } else {
908 None
909 }
910 }
911
912 fn is_nullable(&self) -> bool {
913 self.fields
914 .iter()
915 .flatten()
916 .any(|field| field.is_nullable())
917 }
918
919 fn get_buffer_memory_size(&self) -> usize {
920 let mut sum = self.type_ids.inner().capacity();
921 if let Some(o) = self.offsets.as_ref() {
922 sum += o.inner().capacity()
923 }
924 self.fields
925 .iter()
926 .flat_map(|x| x.as_ref().map(|x| x.get_buffer_memory_size()))
927 .sum::<usize>()
928 + sum
929 }
930
931 fn get_array_memory_size(&self) -> usize {
932 let mut sum = self.type_ids.inner().capacity();
933 if let Some(o) = self.offsets.as_ref() {
934 sum += o.inner().capacity()
935 }
936 std::mem::size_of::<Self>()
937 + self
938 .fields
939 .iter()
940 .flat_map(|x| x.as_ref().map(|x| x.get_array_memory_size()))
941 .sum::<usize>()
942 + sum
943 }
944}
945
946impl std::fmt::Debug for UnionArray {
947 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
948 let header = if self.is_dense() {
949 "UnionArray(Dense)\n["
950 } else {
951 "UnionArray(Sparse)\n["
952 };
953 writeln!(f, "{header}")?;
954
955 writeln!(f, "-- type id buffer:")?;
956 writeln!(f, "{:?}", self.type_ids)?;
957
958 if let Some(offsets) = &self.offsets {
959 writeln!(f, "-- offsets buffer:")?;
960 writeln!(f, "{offsets:?}")?;
961 }
962
963 let fields = match self.data_type() {
964 DataType::Union(fields, _) => fields,
965 _ => unreachable!(),
966 };
967
968 for (type_id, field) in fields.iter() {
969 let child = self.child(type_id);
970 writeln!(
971 f,
972 "-- child {}: \"{}\" ({:?})",
973 type_id,
974 field.name(),
975 field.data_type()
976 )?;
977 std::fmt::Debug::fmt(child, f)?;
978 writeln!(f)?;
979 }
980 writeln!(f, "]")
981 }
982}
983
/// The ways `UnionArray::logical_nulls` can compute the nulls of a sparse
/// union. Each variant maps to one private method on `UnionArray`.
enum SparseStrategy {
    /// Look up one validity bit per slot (`gather_nulls`).
    Gather,
    /// Chunked masking for when every field has nulls
    /// (`mask_sparse_all_with_nulls_skip_one`).
    MaskAllFieldsWithNullsSkipOne,
    /// Chunked masking that only visits fields with nulls
    /// (`mask_sparse_skip_without_nulls`).
    MaskSkipWithoutNulls,
    /// Chunked masking that drops fields which are entirely null
    /// (`mask_sparse_skip_fully_null`).
    MaskSkipFullyNull,
}
998
/// A bit mask AND-ed onto an index before it is used for a lookup.
///
/// The discriminant is the mask itself, so `mask as usize` yields it directly.
#[derive(Copy, Clone)]
#[repr(usize)]
enum Mask {
    /// Forces the index to 0.
    Zero = 0,
    /// Leaves the index unchanged.
    // The discriminant is deliberately tied to the target's pointer width.
    #[allow(clippy::enum_clike_unportable_variant)]
    Max = usize::MAX,
}
1007
/// Packs, into a `u64`, one bit per element of `type_ids_chunk` telling whether
/// that element equals `type_id`. Bit `i` corresponds to element `i`.
///
/// The chunk is expected to hold at most 64 elements, one per bit of the result.
fn selection_mask(type_ids_chunk: &[i8], type_id: i8) -> u64 {
    let mut packed = 0_u64;
    for (bit_idx, &candidate) in type_ids_chunk.iter().enumerate() {
        // Branch-free: a non-matching element contributes a zero bit.
        packed |= u64::from(candidate == type_id) << bit_idx;
    }
    packed
}
1017
1018fn without_nulls_selected(type_ids_chunk: &[i8], without_nulls_ids: &[i8]) -> u64 {
1020 without_nulls_ids
1021 .iter()
1022 .fold(0, |fully_valid_selected, field_type_id| {
1023 fully_valid_selected | selection_mask(type_ids_chunk, *field_type_id)
1024 })
1025}
1026
1027#[cfg(test)]
1028mod tests {
1029 use super::*;
1030 use std::collections::HashSet;
1031
1032 use crate::array::Int8Type;
1033 use crate::builder::UnionBuilder;
1034 use crate::cast::AsArray;
1035 use crate::types::{Float32Type, Float64Type, Int32Type, Int64Type};
1036 use crate::{Float64Array, Int32Array, Int64Array, StringArray};
1037 use crate::{Int8Array, RecordBatch};
1038 use arrow_buffer::Buffer;
1039 use arrow_schema::{Field, Schema};
1040
/// A dense union of three Int32 children exposes the expected type ids,
/// offsets, child values and per-slot values.
#[test]
fn test_dense_i32() {
    let appended = [
        ("a", 1_i32),
        ("b", 2),
        ("c", 3),
        ("a", 4),
        ("c", 5),
        ("a", 6),
        ("b", 7),
    ];
    let mut builder = UnionBuilder::new_dense();
    for (name, value) in appended {
        builder.append::<Int32Type>(name, value).unwrap();
    }
    let union = builder.build().unwrap();

    let expected_type_ids = vec![0_i8, 1, 2, 0, 2, 0, 1];
    let expected_offsets = vec![0_i32, 0, 0, 1, 1, 2, 1];
    let expected_array_values = [1_i32, 2, 3, 4, 5, 6, 7];

    // Type ids: the whole buffer, then slot by slot.
    assert_eq!(*union.type_ids(), expected_type_ids);
    for (slot, &type_id) in expected_type_ids.iter().enumerate() {
        assert_eq!(type_id, union.type_id(slot));
    }

    // Offsets: the whole buffer, then slot by slot.
    assert_eq!(*union.offsets().unwrap(), expected_offsets);
    for (slot, &offset) in expected_offsets.iter().enumerate() {
        assert_eq!(union.value_offset(slot), offset as usize);
    }

    // Each child holds only the values appended under its name.
    assert_eq!(
        *union.child(0).as_primitive::<Int32Type>().values(),
        [1_i32, 4, 6]
    );
    assert_eq!(
        *union.child(1).as_primitive::<Int32Type>().values(),
        [2_i32, 7]
    );
    assert_eq!(
        *union.child(2).as_primitive::<Int32Type>().values(),
        [3_i32, 5]
    );

    assert_eq!(expected_array_values.len(), union.len());
    for (slot, &expected) in expected_array_values.iter().enumerate() {
        assert!(!union.is_null(slot));
        let value = union.value(slot);
        let value = value.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(value.len(), 1);
        assert_eq!(expected, value.value(0));
    }
}
1093
/// Slicing a single-field dense union keeps the logical nulls aligned with
/// the sliced slots.
#[test]
fn slice_union_array_single_field() {
    // Dense union holding [1, null, 3, null, 4] in a single field "a".
    let mut builder = UnionBuilder::new_dense();
    for value in [Some(1), None, Some(3), None, Some(4)] {
        match value {
            Some(v) => builder.append::<Int32Type>("a", v).unwrap(),
            None => builder.append_null::<Int32Type>("a").unwrap(),
        }
    }
    let union_array = builder.build().unwrap();

    // Slots 1..4 are [null, 3, null].
    let sliced = union_array.slice(1, 3);
    let nulls = sliced.logical_nulls().unwrap();

    assert_eq!(nulls.len(), 3);
    assert!(nulls.is_null(0));
    assert!(nulls.is_valid(1));
    assert!(nulls.is_null(2));
}
1117
/// A dense union with 1024 values in a single Int32 child.
#[test]
#[cfg_attr(miri, ignore)]
fn test_dense_i32_large() {
    let expected_type_ids = vec![0_i8; 1024];
    let expected_offsets: Vec<_> = (0..1024).collect();
    let expected_array_values: Vec<_> = (1..=1024).collect();

    let mut builder = UnionBuilder::new_dense();
    for &value in &expected_array_values {
        builder.append::<Int32Type>("a", value).unwrap();
    }
    let union = builder.build().unwrap();

    // Type ids: the whole buffer, then slot by slot.
    assert_eq!(*union.type_ids(), expected_type_ids);
    for (slot, &type_id) in expected_type_ids.iter().enumerate() {
        assert_eq!(type_id, union.type_id(slot));
    }

    // Offsets: the whole buffer, then slot by slot.
    assert_eq!(*union.offsets().unwrap(), expected_offsets);
    for (slot, &offset) in expected_offsets.iter().enumerate() {
        assert_eq!(union.value_offset(slot), offset as usize);
    }

    for (slot, &expected) in expected_array_values.iter().enumerate() {
        assert!(!union.is_null(slot));
        let value = union.value(slot);
        let value = value.as_primitive::<Int32Type>();
        assert_eq!(value.len(), 1);
        assert_eq!(expected, value.value(0));
    }
}
1154
/// A dense union mixing an Int32 child and an Int64 child returns each slot
/// with the right type and value.
#[test]
fn test_dense_mixed() {
    /// The type and value expected in one slot.
    enum Expected {
        I32(i32),
        I64(i64),
    }

    let mut builder = UnionBuilder::new_dense();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append::<Int64Type>("c", 3).unwrap();
    builder.append::<Int32Type>("a", 4).unwrap();
    builder.append::<Int64Type>("c", 5).unwrap();
    builder.append::<Int32Type>("a", 6).unwrap();
    let union = builder.build().unwrap();

    let expected = [
        Expected::I32(1),
        Expected::I64(3),
        Expected::I32(4),
        Expected::I64(5),
        Expected::I32(6),
    ];

    assert_eq!(5, union.len());
    for (slot, expected) in expected.iter().enumerate() {
        let value = union.value(slot);
        assert!(!union.is_null(slot));
        match expected {
            Expected::I32(want) => {
                let value = value.as_any().downcast_ref::<Int32Array>().unwrap();
                assert_eq!(value.len(), 1);
                assert_eq!(*want, value.value(0));
            }
            Expected::I64(want) => {
                let value = value.as_any().downcast_ref::<Int64Array>().unwrap();
                assert_eq!(value.len(), 1);
                assert_eq!(*want, value.value(0));
            }
        }
    }
}
1204
/// A dense union with a null appended to its Int32 child reports that slot as
/// null and every other slot with its value.
#[test]
fn test_dense_mixed_with_nulls() {
    /// What one slot is expected to contain.
    enum Expected {
        I32(i32),
        I64(i64),
        Null,
    }

    let mut builder = UnionBuilder::new_dense();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append::<Int64Type>("c", 3).unwrap();
    builder.append::<Int32Type>("a", 10).unwrap();
    builder.append_null::<Int32Type>("a").unwrap();
    builder.append::<Int32Type>("a", 6).unwrap();
    let union = builder.build().unwrap();

    let expected = [
        Expected::I32(1),
        Expected::I64(3),
        Expected::I32(10),
        Expected::Null,
        Expected::I32(6),
    ];

    assert_eq!(5, union.len());
    for (slot, expected) in expected.iter().enumerate() {
        let value = union.value(slot);
        match expected {
            Expected::I32(want) => {
                let value = value.as_any().downcast_ref::<Int32Array>().unwrap();
                assert!(!value.is_null(0));
                assert_eq!(value.len(), 1);
                assert_eq!(*want, value.value(0));
            }
            Expected::I64(want) => {
                let value = value.as_any().downcast_ref::<Int64Array>().unwrap();
                assert!(!value.is_null(0));
                assert_eq!(value.len(), 1);
                assert_eq!(*want, value.value(0));
            }
            Expected::Null => assert!(value.is_null(0)),
        }
    }
}
1252
/// Slicing a dense union with nulls keeps values and nulls aligned with the
/// sliced slots.
#[test]
fn test_dense_mixed_with_nulls_and_offset() {
    let mut builder = UnionBuilder::new_dense();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append::<Int64Type>("c", 3).unwrap();
    builder.append::<Int32Type>("a", 10).unwrap();
    builder.append_null::<Int32Type>("a").unwrap();
    builder.append::<Int32Type>("a", 6).unwrap();
    let union = builder.build().unwrap();

    // Keep the last three slots: [10, null, 6].
    let slice = union.slice(2, 3);
    let new_union = slice.as_any().downcast_ref::<UnionArray>().unwrap();

    let expected = [Some(10_i32), None, Some(6)];

    assert_eq!(3, new_union.len());
    for (slot, expected) in expected.iter().enumerate() {
        let value = new_union.value(slot);
        match expected {
            Some(want) => {
                let value = value.as_any().downcast_ref::<Int32Array>().unwrap();
                assert!(!value.is_null(0));
                assert_eq!(value.len(), 1);
                assert_eq!(*want, value.value(0));
            }
            None => assert!(value.is_null(0)),
        }
    }
}
1289
/// A dense union built with `try_new` from Utf8, Int32 and Float64 children
/// exposes the given type ids and offsets and returns the right slot values.
#[test]
fn test_dense_mixed_with_str() {
    let type_ids: ScalarBuffer<i8> = [1, 0, 0, 2, 0, 1].into_iter().collect();
    let offsets: ScalarBuffer<i32> = [0, 0, 1, 0, 2, 1].into_iter().collect();

    let fields: UnionFields = [
        (0, Arc::new(Field::new("A", DataType::Utf8, false))),
        (1, Arc::new(Field::new("B", DataType::Int32, false))),
        (2, Arc::new(Field::new("C", DataType::Float64, false))),
    ]
    .into_iter()
    .collect();

    let children: Vec<ArrayRef> = vec![
        Arc::new(StringArray::from(vec!["foo", "bar", "baz"])),
        Arc::new(Int32Array::from(vec![5, 6])),
        Arc::new(Float64Array::from(vec![10.0])),
    ];

    let array =
        UnionArray::try_new(fields, type_ids.clone(), Some(offsets.clone()), children).unwrap();

    // Type ids: the whole buffer, then slot by slot.
    assert_eq!(*array.type_ids(), type_ids);
    for (slot, &type_id) in type_ids.iter().enumerate() {
        assert_eq!(type_id, array.type_id(slot));
    }

    // Offsets: the whole buffer, then slot by slot.
    assert_eq!(*array.offsets().unwrap(), offsets);
    for (slot, &offset) in offsets.iter().enumerate() {
        assert_eq!(offset as usize, array.value_offset(slot));
    }

    assert_eq!(6, array.len());

    // Helpers that read slot `i` as a given concrete array type.
    let int_at = |i: usize| {
        let slot = array.value(i);
        slot.as_any().downcast_ref::<Int32Array>().unwrap().value(0)
    };
    let str_at = |i: usize| {
        let slot = array.value(i);
        slot.as_any()
            .downcast_ref::<StringArray>()
            .unwrap()
            .value(0)
            .to_string()
    };
    let float_at = |i: usize| {
        let slot = array.value(i);
        slot.as_any()
            .downcast_ref::<Float64Array>()
            .unwrap()
            .value(0)
    };

    assert_eq!(5, int_at(0));
    assert_eq!("foo", str_at(1));
    assert_eq!("bar", str_at(2));
    assert_eq!(10.0, float_at(3));
    assert_eq!("baz", str_at(4));
    assert_eq!(6, int_at(5));
}
1373
/// A sparse union of three Int32 children has no offsets, children as long as
/// the union, and the expected per-slot values.
#[test]
fn test_sparse_i32() {
    let appended = [
        ("a", 1_i32),
        ("b", 2),
        ("c", 3),
        ("a", 4),
        ("c", 5),
        ("a", 6),
        ("b", 7),
    ];
    let mut builder = UnionBuilder::new_sparse();
    for (name, value) in appended {
        builder.append::<Int32Type>(name, value).unwrap();
    }
    let union = builder.build().unwrap();

    let expected_type_ids = vec![0_i8, 1, 2, 0, 2, 0, 1];
    let expected_array_values = [1_i32, 2, 3, 4, 5, 6, 7];

    // Type ids: the whole buffer, then slot by slot.
    assert_eq!(*union.type_ids(), expected_type_ids);
    for (slot, &type_id) in expected_type_ids.iter().enumerate() {
        assert_eq!(type_id, union.type_id(slot));
    }

    // A sparse union carries no offsets buffer.
    assert!(union.offsets().is_none());

    // Every child spans all seven slots, with 0 where another child is selected.
    assert_eq!(
        *union.child(0).as_primitive::<Int32Type>().values(),
        [1_i32, 0, 0, 4, 0, 6, 0],
    );
    assert_eq!(
        *union.child(1).as_primitive::<Int32Type>().values(),
        [0_i32, 2_i32, 0, 0, 0, 0, 7]
    );
    assert_eq!(
        *union.child(2).as_primitive::<Int32Type>().values(),
        [0_i32, 0, 3_i32, 0, 5, 0, 0]
    );

    assert_eq!(expected_array_values.len(), union.len());
    for (slot, &expected) in expected_array_values.iter().enumerate() {
        assert!(!union.is_null(slot));
        let value = union.value(slot);
        let value = value.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(value.len(), 1);
        assert_eq!(expected, value.value(0));
    }
}
1422
/// Builds a sparse union alternating between an `Int32` field "a" and a `Float64` field "c"
/// and checks type ids, the absence of offsets and the value held by every slot.
#[test]
fn test_sparse_mixed() {
    // Asserts that `slot` is a one-element Int32 array holding `expected`.
    let expect_i32 = |slot: &ArrayRef, expected: i32| {
        let ints = slot.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(ints.len(), 1);
        assert_eq!(expected, ints.value(0));
    };
    // Asserts that `slot` is a one-element Float64 array holding `expected`.
    let expect_f64 = |slot: &ArrayRef, expected: f64| {
        let floats = slot.as_any().downcast_ref::<Float64Array>().unwrap();
        assert_eq!(floats.len(), 1);
        assert_eq!(expected, floats.value(0));
    };

    let mut builder = UnionBuilder::new_sparse();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append::<Float64Type>("c", 3.0).unwrap();
    builder.append::<Int32Type>("a", 4).unwrap();
    builder.append::<Float64Type>("c", 5.0).unwrap();
    builder.append::<Int32Type>("a", 6).unwrap();
    let union = builder.build().unwrap();

    let expected_type_ids = vec![0_i8, 1, 0, 1, 0];
    assert_eq!(*union.type_ids(), expected_type_ids);
    for (index, expected_id) in expected_type_ids.iter().enumerate() {
        assert_eq!(expected_id, &union.type_id(index));
    }

    assert!(union.offsets().is_none());

    for index in 0..union.len() {
        let slot = union.value(index);
        assert!(!union.is_null(index));
        match index {
            0 => expect_i32(&slot, 1),
            1 => expect_f64(&slot, 3.0),
            2 => expect_i32(&slot, 4),
            3 => expect_f64(&slot, 5.0),
            4 => expect_i32(&slot, 6),
            _ => unreachable!(),
        }
    }
}
1482
/// Builds a sparse union that contains one null `Int32` entry and checks that the null
/// shows up in the extracted slot while the remaining slots keep their values.
#[test]
fn test_sparse_mixed_with_nulls() {
    // Asserts that `slot` is a valid one-element Int32 array holding `expected`.
    let expect_valid_i32 = |slot: &ArrayRef, expected: i32| {
        let ints = slot.as_any().downcast_ref::<Int32Array>().unwrap();
        assert!(!ints.is_null(0));
        assert_eq!(ints.len(), 1);
        assert_eq!(expected, ints.value(0));
    };
    // Asserts that `slot` is a valid one-element Float64 array holding `expected`.
    let expect_valid_f64 = |slot: &ArrayRef, expected: f64| {
        let floats = slot.as_any().downcast_ref::<Float64Array>().unwrap();
        assert!(!floats.is_null(0));
        assert_eq!(floats.len(), 1);
        assert_eq!(expected, floats.value(0));
    };

    let mut builder = UnionBuilder::new_sparse();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append_null::<Int32Type>("a").unwrap();
    builder.append::<Float64Type>("c", 3.0).unwrap();
    builder.append::<Int32Type>("a", 4).unwrap();
    let union = builder.build().unwrap();

    let expected_type_ids = vec![0_i8, 0, 1, 0];
    assert_eq!(*union.type_ids(), expected_type_ids);
    for (index, expected_id) in expected_type_ids.iter().enumerate() {
        assert_eq!(expected_id, &union.type_id(index));
    }

    assert!(union.offsets().is_none());

    for index in 0..union.len() {
        let slot = union.value(index);
        match index {
            0 => expect_valid_i32(&slot, 1),
            // The null appended to "a" is visible on the extracted child slot.
            1 => assert!(slot.is_null(0)),
            2 => expect_valid_f64(&slot, 3.0),
            3 => expect_valid_i32(&slot, 4),
            _ => unreachable!(),
        }
    }
}
1532
/// Slices a sparse union containing nulls (dropping the first element) and checks that
/// the slots of the sliced union line up with the expected values and nulls.
#[test]
fn test_sparse_mixed_with_nulls_and_offset() {
    /// What a single slot of the sliced union is expected to contain.
    enum Expected {
        Null,
        Int(i32),
        Float(f64),
    }

    let mut builder = UnionBuilder::new_sparse();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append_null::<Int32Type>("a").unwrap();
    builder.append::<Float64Type>("c", 3.0).unwrap();
    builder.append_null::<Float64Type>("c").unwrap();
    builder.append::<Int32Type>("a", 4).unwrap();
    let union = builder.build().unwrap();

    let sliced = union.slice(1, 4);
    let new_union = sliced.as_any().downcast_ref::<UnionArray>().unwrap();

    let expectations = [
        Expected::Null,
        Expected::Float(3.0),
        Expected::Null,
        Expected::Int(4),
    ];

    assert_eq!(4, new_union.len());
    for (index, expected) in expectations.into_iter().enumerate() {
        let slot = new_union.value(index);
        match expected {
            Expected::Null => assert!(slot.is_null(0)),
            Expected::Float(want) => {
                let floats = slot.as_primitive::<Float64Type>();
                assert!(!floats.is_null(0));
                assert_eq!(floats.len(), 1);
                assert_eq!(floats.value(0), want);
            }
            Expected::Int(want) => {
                let ints = slot.as_primitive::<Int32Type>();
                assert!(!ints.is_null(0));
                assert_eq!(ints.len(), 1);
                assert_eq!(want, ints.value(0));
            }
        }
    }
}
1570
1571 fn test_union_validity(union_array: &UnionArray) {
1572 assert_eq!(union_array.null_count(), 0);
1573
1574 for i in 0..union_array.len() {
1575 assert!(!union_array.is_null(i));
1576 assert!(union_array.is_valid(i));
1577 }
1578 }
1579
/// Runs `test_union_validity` against a sparse and a dense union that were both built
/// with null entries appended to their children.
#[test]
fn test_union_array_validity() {
    /// Appends the same mix of values and nulls to `builder` and finishes it.
    fn build_with_nulls(mut builder: UnionBuilder) -> UnionArray {
        builder.append::<Int32Type>("a", 1).unwrap();
        builder.append_null::<Int32Type>("a").unwrap();
        builder.append::<Float64Type>("c", 3.0).unwrap();
        builder.append_null::<Float64Type>("c").unwrap();
        builder.append::<Int32Type>("a", 4).unwrap();
        builder.build().unwrap()
    }

    let sparse = build_with_nulls(UnionBuilder::new_sparse());
    test_union_validity(&sparse);

    let dense = build_with_nulls(UnionBuilder::new_dense());
    test_union_validity(&dense);
}
1602
/// Appending a value of a different primitive type to an existing field name must fail
/// with an error that names both the attempted and the existing type.
#[test]
fn test_type_check() {
    let mut builder = UnionBuilder::new_sparse();
    builder.append::<Float32Type>("a", 1.0).unwrap();

    let message = builder.append::<Int32Type>("a", 1).unwrap_err().to_string();
    let expected_fragment =
        "Attempt to write col \"a\" with type Int32 doesn't match existing type Float32";

    assert!(message.contains(expected_fragment), "{}", message);
}
1616
/// Wraps a sparse and a dense union in a `RecordBatch`, slices the batch to rows 1..4 and
/// checks type ids, validity and values of the sliced union column.
#[test]
fn slice_union_array() {
    /// Fills `builder` with values and nulls for fields "a" and "c" and finishes it.
    fn populate(mut builder: UnionBuilder) -> UnionArray {
        builder.append::<Int32Type>("a", 1).unwrap();
        builder.append_null::<Int32Type>("a").unwrap();
        builder.append::<Float64Type>("c", 3.0).unwrap();
        builder.append_null::<Float64Type>("c").unwrap();
        builder.append::<Int32Type>("a", 4).unwrap();
        builder.build().unwrap()
    }

    /// Builds a single-column record batch holding `union`.
    fn into_batch(union: UnionArray) -> RecordBatch {
        let column_type = union.data_type().clone();
        let schema = Schema::new(vec![Field::new("struct_array", column_type, true)]);

        RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union)]).unwrap()
    }

    /// Verifies the three rows that remain after slicing off the first row.
    fn check_sliced(batch: &RecordBatch) {
        let sliced = batch
            .column(0)
            .as_any()
            .downcast_ref::<UnionArray>()
            .unwrap();

        let ids: Vec<i8> = (0..3).map(|index| sliced.type_id(index)).collect();
        assert_eq!(ids, [0, 1, 1]);

        // Row 0 of the slice: the null appended to "a".
        let slot = sliced.value(0);
        let ints = slot.as_primitive::<Int32Type>();
        assert_eq!(ints.len(), 1);
        assert!(ints.is_null(0));

        // Row 1 of the slice: 3.0 from "c".
        let slot = sliced.value(1);
        let floats = slot.as_primitive::<Float64Type>();
        assert_eq!(floats.len(), 1);
        assert!(floats.is_valid(0));
        assert_eq!(floats.value(0), 3.0);

        // Row 2 of the slice: the null appended to "c".
        let slot = sliced.value(2);
        let floats = slot.as_primitive::<Float64Type>();
        assert_eq!(floats.len(), 1);
        assert!(floats.is_null(0));
    }

    // Sparse first, then dense.
    for builder in [UnionBuilder::new_sparse(), UnionBuilder::new_dense()] {
        let batch = into_batch(populate(builder));
        check_sliced(&batch.slice(1, 3));
    }
}
1681
/// Builds a dense union whose fields use the type ids 8, 4 and 9 and checks that each
/// slot resolves to the child selected by its type id.
#[test]
fn test_custom_type_ids() {
    /// Expected data type and value of a single slot.
    enum Expected {
        Int(i32),
        Text(&'static str),
        Float(f64),
    }

    let fields = UnionFields::new(
        vec![8, 4, 9],
        vec![
            Field::new("strings", DataType::Utf8, false),
            Field::new("integers", DataType::Int32, false),
            Field::new("floats", DataType::Float64, false),
        ],
    );

    let data = ArrayData::builder(DataType::Union(fields, UnionMode::Dense))
        .len(7)
        .buffers(vec![
            // type ids
            Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]),
            // value offsets into the selected child
            Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]),
        ])
        .child_data(vec![
            StringArray::from(vec!["foo", "bar", "baz"]).into_data(),
            Int32Array::from(vec![5, 6, 4]).into_data(),
            Float64Array::from(vec![10.0]).into_data(),
        ])
        .build()
        .unwrap();

    let array = UnionArray::from(data);

    let expectations = [
        Expected::Int(5),
        Expected::Text("foo"),
        Expected::Int(6),
        Expected::Text("bar"),
        Expected::Float(10.0),
        Expected::Int(4),
        Expected::Text("baz"),
    ];

    for (index, expected) in expectations.into_iter().enumerate() {
        let v = array.value(index);
        match expected {
            Expected::Int(want) => {
                assert_eq!(v.data_type(), &DataType::Int32);
                assert_eq!(v.len(), 1);
                assert_eq!(v.as_primitive::<Int32Type>().value(0), want);
            }
            Expected::Text(want) => {
                assert_eq!(v.data_type(), &DataType::Utf8);
                assert_eq!(v.len(), 1);
                assert_eq!(v.as_string::<i32>().value(0), want);
            }
            Expected::Float(want) => {
                assert_eq!(v.data_type(), &DataType::Float64);
                assert_eq!(v.len(), 1);
                assert_eq!(v.as_primitive::<Float64Type>().value(0), want);
            }
        }
    }
}
1751
/// Decomposes a dense and a sparse union with `into_parts`, checks the returned pieces and
/// verifies that `try_new` accepts them again.
#[test]
fn into_parts() {
    /// Appends the same three values to `builder` and finishes it.
    fn populate(mut builder: UnionBuilder) -> UnionArray {
        builder.append::<Int32Type>("a", 1).unwrap();
        builder.append::<Int8Type>("b", 2).unwrap();
        builder.append::<Int32Type>("a", 3).unwrap();
        builder.build().unwrap()
    }

    // Dense union: offsets are present.
    let expected_fields = [
        Arc::new(Field::new("a", DataType::Int32, false)),
        Arc::new(Field::new("b", DataType::Int8, false)),
    ];
    let dense_union = populate(UnionBuilder::new_dense());
    let (union_fields, type_ids, offsets, children) = dense_union.into_parts();
    assert_eq!(
        union_fields
            .iter()
            .map(|(_, field)| field)
            .collect::<Vec<_>>(),
        expected_fields.iter().collect::<Vec<_>>()
    );
    assert_eq!(type_ids, [0, 1, 0]);
    assert!(offsets.is_some());
    assert_eq!(offsets.as_ref().unwrap(), &[0, 0, 1]);

    let rebuilt = UnionArray::try_new(union_fields, type_ids, offsets, children);
    assert!(rebuilt.is_ok());
    assert_eq!(rebuilt.unwrap().len(), 3);

    // Sparse union: no offsets.
    let sparse_union = populate(UnionBuilder::new_sparse());
    let (union_fields, type_ids, offsets, children) = sparse_union.into_parts();
    assert_eq!(type_ids, [0, 1, 0]);
    assert!(offsets.is_none());

    let rebuilt = UnionArray::try_new(union_fields, type_ids, offsets, children);
    assert!(rebuilt.is_ok());
    assert_eq!(rebuilt.unwrap().len(), 3);
}
1794
/// Decomposes a dense union with the type ids 8, 4 and 9, checks that the returned type
/// ids use exactly that set, and verifies that `try_new` accepts the parts again.
#[test]
fn into_parts_custom_type_ids() {
    let set_field_type_ids: [i8; 3] = [8, 4, 9];
    let fields = UnionFields::new(
        set_field_type_ids,
        [
            Field::new("strings", DataType::Utf8, false),
            Field::new("integers", DataType::Int32, false),
            Field::new("floats", DataType::Float64, false),
        ],
    );

    let child_data = vec![
        StringArray::from(vec!["foo", "bar", "baz"]).into_data(),
        Int32Array::from(vec![5, 6, 4]).into_data(),
        Float64Array::from(vec![10.0]).into_data(),
    ];
    let buffers = vec![
        // type ids
        Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]),
        // value offsets
        Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]),
    ];

    let data = ArrayData::builder(DataType::Union(fields, UnionMode::Dense))
        .len(7)
        .buffers(buffers)
        .child_data(child_data)
        .build()
        .unwrap();
    let array = UnionArray::from(data);

    let (union_fields, type_ids, offsets, children) = array.into_parts();

    let seen_ids = type_ids.iter().collect::<HashSet<_>>();
    let declared_ids = set_field_type_ids.iter().collect::<HashSet<_>>();
    assert_eq!(seen_ids, declared_ids);

    let rebuilt = UnionArray::try_new(union_fields, type_ids, offsets, children);
    assert!(rebuilt.is_ok());
    let array = rebuilt.unwrap();
    assert_eq!(array.len(), 7);
}
1836
/// Exercises the validation performed by `UnionArray::try_new` and pins the error message
/// reported for each rejected input.
#[test]
fn test_invalid() {
    let fields = UnionFields::new(
        [3, 2],
        [
            Field::new("a", DataType::Utf8, false),
            Field::new("b", DataType::Utf8, false),
        ],
    );

    // --- Sparse attempts: both children have length 2 ---
    let sparse_children = vec![
        Arc::new(StringArray::from_iter_values(["a", "b"])) as ArrayRef,
        Arc::new(StringArray::from_iter_values(["c", "d"])) as ArrayRef,
    ];
    // Returns the error text produced by a sparse `try_new` call with the given type ids.
    let sparse_error = |type_ids: Vec<i8>| {
        UnionArray::try_new(fields.clone(), type_ids.into(), None, sparse_children.clone())
            .unwrap_err()
            .to_string()
    };

    // Three type ids, but children of length two.
    assert_eq!(
        sparse_error(vec![3, 3, 2]),
        "Invalid argument error: Sparse union child arrays must be equal in length to the length of the union"
    );

    // Type id 1 is not one of the declared ids (3 and 2).
    assert_eq!(
        sparse_error(vec![1, 2]),
        "Invalid argument error: Type Ids values must match one of the field type ids"
    );

    // Type id 7 is not one of the declared ids either.
    assert_eq!(
        sparse_error(vec![7, 2]),
        "Invalid argument error: Type Ids values must match one of the field type ids"
    );

    // --- Dense attempts: children of length 2 and 1 ---
    let dense_children = vec![
        Arc::new(StringArray::from_iter_values(["a", "b"])) as ArrayRef,
        Arc::new(StringArray::from_iter_values(["c"])) as ArrayRef,
    ];
    let type_ids = ScalarBuffer::from(vec![3_i8, 3, 2]);
    // Runs a dense `try_new` call with the given offsets.
    let dense_attempt = |offsets: Vec<i32>| {
        UnionArray::try_new(
            fields.clone(),
            type_ids.clone(),
            Some(offsets.into()),
            dense_children.clone(),
        )
    };

    // Valid: every offset points inside the child selected by its type id.
    dense_attempt(vec![0, 1, 0]).unwrap();

    // The last offset (1) is out of range for the single-element child.
    let err = dense_attempt(vec![0, 1, 1]).unwrap_err();
    assert_eq!(
        err.to_string(),
        "Invalid argument error: Offsets must be non-negative and within the length of the Array"
    );

    // Two offsets for three type ids.
    let err = dense_attempt(vec![0, 1]).unwrap_err();
    assert_eq!(
        err.to_string(),
        "Invalid argument error: Type Ids and Offsets lengths must match"
    );

    // Two declared fields, but no child arrays at all.
    let err = UnionArray::try_new(fields.clone(), type_ids, None, vec![]).unwrap_err();
    assert_eq!(
        err.to_string(),
        "Invalid argument error: Union fields length must match child arrays length"
    );
}
1907
/// Checks `logical_nulls` on small unions: no fields, children without nulls, and
/// children that are entirely null (sparse and dense).
#[test]
fn test_logical_nulls_fast_paths() {
    // Child of `len` copies of `value`, i.e. without any null.
    let filled = |value: i8, len: usize| Arc::new(Int8Array::from_value(value, len)) as ArrayRef;
    // Child of `len` nulls.
    let all_null = |len: usize| Arc::new(Int8Array::new_null(len)) as ArrayRef;

    // Union without any field.
    let array = UnionArray::try_new(UnionFields::empty(), vec![].into(), None, vec![]).unwrap();
    assert_eq!(array.logical_nulls(), None);

    // Fields declared non-nullable, children without nulls.
    let fields = UnionFields::new(
        [1, 3],
        [
            Field::new("a", DataType::Int8, false),
            Field::new("b", DataType::Int8, false),
        ],
    );
    let array = UnionArray::try_new(
        fields,
        vec![1].into(),
        None,
        vec![filled(5, 1), filled(5, 1)],
    )
    .unwrap();
    assert_eq!(array.logical_nulls(), None);

    // Fields declared nullable, children still without nulls.
    let nullable_fields = UnionFields::new(
        [1, 3],
        [
            Field::new("a", DataType::Int8, true),
            Field::new("b", DataType::Int8, true),
        ],
    );
    let array = UnionArray::try_new(
        nullable_fields.clone(),
        vec![1, 1].into(),
        None,
        vec![filled(-5, 2), filled(-5, 2)],
    )
    .unwrap();
    assert_eq!(array.logical_nulls(), None);

    // Sparse: every child is entirely null and as long as the union.
    let array = UnionArray::try_new(
        nullable_fields.clone(),
        vec![1, 1].into(),
        None,
        vec![all_null(2), all_null(2)],
    )
    .unwrap();
    assert_eq!(array.logical_nulls(), Some(NullBuffer::new_null(2)));

    // Dense: every child is entirely null and longer than the union.
    let array = UnionArray::try_new(
        nullable_fields.clone(),
        vec![1, 1].into(),
        Some(vec![0, 1].into()),
        vec![all_null(3), all_null(3)],
    )
    .unwrap();
    assert_eq!(array.logical_nulls(), Some(NullBuffer::new_null(2)));
}
1983
/// Dense union over a fully valid, a partially null and a fully null child: both
/// `logical_nulls` and `gather_nulls` must yield the same validity bits.
#[test]
fn test_dense_union_logical_nulls_gather() {
    let children = vec![
        // no nulls
        Arc::new(Int32Array::from(vec![1, 2])) as Arc<dyn Array>,
        // one null
        Arc::new(Float64Array::from(vec![Some(3.2), None])),
        // only nulls
        Arc::new(StringArray::new_null(1)),
    ];
    let type_ids = ScalarBuffer::<i8>::from(vec![1, 1, 3, 3, 4, 4]);
    let offsets = ScalarBuffer::<i32>::from(vec![0, 1, 0, 1, 0, 0]);

    let array = UnionArray::try_new(union_fields(), type_ids, Some(offsets), children).unwrap();

    let expected = BooleanBuffer::from(vec![true, true, true, false, false, false]);

    let logical = array.logical_nulls().unwrap().into_inner();
    assert_eq!(expected, logical);

    let gathered = array.gather_nulls(array.fields_logical_nulls());
    assert_eq!(expected, gathered);
}
2008
/// Sparse union with one fully null child and one partially null child: `logical_nulls`
/// and `mask_sparse_all_with_nulls_skip_one` must agree, for a short and a longer input.
#[test]
fn test_sparse_union_logical_nulls_mask_all_nulls_skip_one() {
    // Checks both code paths against the same expected validity bits.
    let assert_validity = |array: &UnionArray, expected: BooleanBuffer| {
        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
        assert_eq!(
            expected,
            array.mask_sparse_all_with_nulls_skip_one(array.fields_logical_nulls())
        );
    };

    let fields = UnionFields::from_iter([
        (1_i8, Arc::new(Field::new("A", DataType::Int32, true))),
        (3, Arc::new(Field::new("B", DataType::Float64, true))),
    ]);

    // Short input: four elements.
    let children = vec![
        Arc::new(Int32Array::new_null(4)) as Arc<dyn Array>,
        Arc::new(Float64Array::from(vec![None, None, Some(3.2), None])),
    ];
    let type_ids = ScalarBuffer::<i8>::from(vec![1, 1, 3, 3]);

    let array = UnionArray::try_new(fields.clone(), type_ids, None, children).unwrap();
    assert_validity(
        &array,
        BooleanBuffer::from(vec![false, false, true, false]),
    );

    // Longer input; presumably sized to span two full 64-bit chunks plus a remainder.
    let len = 2 * 64 + 32;

    let nulls_only = Int32Array::new_null(len);
    let alternating = Float64Array::from_iter([Some(3.2), None].into_iter().cycle().take(len));
    let type_ids = ScalarBuffer::from_iter([1, 1, 3, 3].into_iter().cycle().take(len));

    let array = UnionArray::try_new(
        fields,
        type_ids,
        None,
        vec![Arc::new(nulls_only), Arc::new(alternating)],
    )
    .unwrap();

    assert_eq!(array.len(), len);
    assert_validity(
        &array,
        BooleanBuffer::from_iter([false, false, true, false].into_iter().cycle().take(len)),
    );
}
2060
/// Sparse union with two fully valid children and one child containing nulls:
/// `logical_nulls` and `mask_sparse_skip_without_nulls` must agree, for a short and a
/// longer input.
#[test]
fn test_sparse_union_logical_mask_mixed_nulls_skip_fully_valid() {
    // Checks both code paths against the same expected validity bits.
    let assert_validity = |array: &UnionArray, expected: BooleanBuffer| {
        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
        assert_eq!(
            expected,
            array.mask_sparse_skip_without_nulls(array.fields_logical_nulls())
        );
    };

    // Short input: six elements, the string child is entirely null.
    let children = vec![
        Arc::new(Int32Array::from_value(2, 6)) as Arc<dyn Array>,
        Arc::new(Float64Array::from_value(4.2, 6)),
        Arc::new(StringArray::new_null(6)),
    ];
    let type_ids = ScalarBuffer::<i8>::from(vec![1, 1, 3, 3, 4, 4]);

    let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap();
    assert_validity(
        &array,
        BooleanBuffer::from(vec![true, true, true, true, false, false]),
    );

    // Longer input; presumably sized to span two full 64-bit chunks plus a remainder.
    let len = 2 * 64 + 32;

    let strings = StringArray::from_iter([None, Some("a")].into_iter().cycle().take(len));
    let children = vec![
        Arc::new(Int32Array::from_value(2, len)) as Arc<dyn Array>,
        Arc::new(Float64Array::from_value(4.2, len)),
        Arc::new(strings),
    ];
    let type_ids = ScalarBuffer::from_iter([1, 1, 3, 3, 4, 4].into_iter().cycle().take(len));

    let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap();

    assert_eq!(array.len(), len);
    assert_validity(
        &array,
        BooleanBuffer::from_iter(
            [true, true, true, true, false, true]
                .into_iter()
                .cycle()
                .take(len),
        ),
    );
}
2115
/// Sparse union with two fully null children and one fully valid child:
/// `logical_nulls` and `mask_sparse_skip_fully_null` must agree, for a short and a longer
/// input.
#[test]
fn test_sparse_union_logical_mask_mixed_nulls_skip_fully_null() {
    // Checks both code paths against the same expected validity bits.
    let assert_validity = |array: &UnionArray, expected: BooleanBuffer| {
        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
        assert_eq!(
            expected,
            array.mask_sparse_skip_fully_null(array.fields_logical_nulls())
        );
    };

    // Builds the three children for a union of `len` elements.
    let make_children = |len: usize| {
        vec![
            Arc::new(Int32Array::new_null(len)) as Arc<dyn Array>,
            Arc::new(Float64Array::from_value(4.2, len)),
            Arc::new(StringArray::new_null(len)),
        ]
    };

    // Short input: six elements.
    let type_ids = ScalarBuffer::<i8>::from(vec![1, 1, 3, 3, 4, 4]);
    let array = UnionArray::try_new(union_fields(), type_ids, None, make_children(6)).unwrap();
    assert_validity(
        &array,
        BooleanBuffer::from(vec![false, false, true, true, false, false]),
    );

    // Longer input; presumably sized to span two full 64-bit chunks plus a remainder.
    let len = 2 * 64 + 32;

    let type_ids = ScalarBuffer::from_iter([1, 1, 3, 3, 4, 4].into_iter().cycle().take(len));
    let array = UnionArray::try_new(union_fields(), type_ids, None, make_children(len)).unwrap();

    assert_eq!(array.len(), len);
    assert_validity(
        &array,
        BooleanBuffer::from_iter(
            [false, false, true, true, false, false]
                .into_iter()
                .cycle()
                .take(len),
        ),
    );
}
2170
/// Sparse union with 50 `Int32` fields (odd type ids starting at 1) whose children cycle
/// through fully valid, partially null and fully null arrays: `logical_nulls` and
/// `gather_nulls` must yield the same validity bits.
#[test]
fn test_sparse_union_logical_nulls_gather() {
    let n_fields = 50;

    // Type ids 1, 3, 5, ... paired with fields named after their id.
    let fields: UnionFields = (1_i8..)
        .step_by(2)
        .take(n_fields)
        .map(|i| {
            let field = Field::new(format!("f{i}"), DataType::Int32, true);
            (i, Arc::new(field))
        })
        .collect();

    let non_null = Int32Array::from_value(2, 4);
    let mixed = Int32Array::from(vec![None, None, Some(1), None]);
    let fully_null = Int32Array::new_null(4);

    // The three child kinds are repeated until there is one child per field.
    let children: Vec<ArrayRef> = [
        Arc::new(non_null) as ArrayRef,
        Arc::new(mixed),
        Arc::new(fully_null),
    ]
    .into_iter()
    .cycle()
    .take(n_fields)
    .collect();

    let array = UnionArray::try_new(fields, vec![1, 3, 3, 5].into(), None, children).unwrap();

    let expected = BooleanBuffer::from(vec![true, false, true, false]);

    let logical = array.logical_nulls().unwrap().into_inner();
    assert_eq!(expected, logical);

    let gathered = array.gather_nulls(array.fields_logical_nulls());
    assert_eq!(expected, gathered);
}
2209
2210 fn union_fields() -> UnionFields {
2211 [
2212 (1, Arc::new(Field::new("A", DataType::Int32, true))),
2213 (3, Arc::new(Field::new("B", DataType::Float64, true))),
2214 (4, Arc::new(Field::new("C", DataType::Utf8, true))),
2215 ]
2216 .into_iter()
2217 .collect()
2218 }
2219
/// `is_nullable` must be false only when neither child of the union contains a null.
#[test]
fn test_is_nullable() {
    // (int child has nulls, float child has nulls, expected `is_nullable`)
    let cases = [
        (false, false, false),
        (true, false, true),
        (false, true, true),
        (true, true, true),
    ];

    for (int_nullable, float_nullable, expected) in cases {
        let array = create_union_array(int_nullable, float_nullable);
        assert_eq!(array.is_nullable(), expected);
    }
}
2227
2228 fn create_union_array(int_nullable: bool, float_nullable: bool) -> UnionArray {
2235 let int_array = if int_nullable {
2236 Int32Array::from(vec![Some(1), None, Some(3)])
2237 } else {
2238 Int32Array::from(vec![1, 2, 3])
2239 };
2240 let float_array = if float_nullable {
2241 Float64Array::from(vec![Some(3.2), None, Some(4.2)])
2242 } else {
2243 Float64Array::from(vec![3.2, 4.2, 5.2])
2244 };
2245 let type_ids = [0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
2246 let offsets = [0, 0, 0].into_iter().collect::<ScalarBuffer<i32>>();
2247 let union_fields = [
2248 (0, Arc::new(Field::new("A", DataType::Int32, true))),
2249 (1, Arc::new(Field::new("B", DataType::Float64, true))),
2250 ]
2251 .into_iter()
2252 .collect::<UnionFields>();
2253
2254 let children = vec![Arc::new(int_array) as Arc<dyn Array>, Arc::new(float_array)];
2255
2256 UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap()
2257 }
2258}