Skip to main content

arrow_array/array/
union_array.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17#![allow(clippy::enum_clike_unportable_variant)]
18
19use crate::{Array, ArrayRef, make_array};
20use arrow_buffer::bit_chunk_iterator::{BitChunkIterator, BitChunks};
21use arrow_buffer::buffer::NullBuffer;
22use arrow_buffer::{BooleanBuffer, Buffer, MutableBuffer, ScalarBuffer};
23use arrow_data::{ArrayData, ArrayDataBuilder};
24use arrow_schema::{ArrowError, DataType, UnionFields, UnionMode};
25/// Contains the `UnionArray` type.
26///
27use std::any::Any;
28use std::collections::HashSet;
29use std::sync::Arc;
30
31/// An array of [values of varying types](https://arrow.apache.org/docs/format/Columnar.html#union-layout)
32///
33/// Each slot in a [UnionArray] can have a value chosen from a number
34/// of types.  Each of the possible types are named like the fields of
35/// a [`StructArray`](crate::StructArray).  A `UnionArray` can
36/// have two possible memory layouts, "dense" or "sparse".  For more
/// information, please see the
38/// [specification](https://arrow.apache.org/docs/format/Columnar.html#union-layout).
39///
40/// [UnionBuilder](crate::builder::UnionBuilder) can be used to
41/// create [UnionArray]'s of primitive types. `UnionArray`'s of nested
42/// types are also supported but not via `UnionBuilder`, see the tests
43/// for examples.
44///
45/// # Examples
46/// ## Create a dense UnionArray `[1, 3.2, 34]`
47/// ```
48/// use arrow_buffer::ScalarBuffer;
49/// use arrow_schema::*;
50/// use std::sync::Arc;
51/// use arrow_array::{Array, Int32Array, Float64Array, UnionArray};
52///
53/// let int_array = Int32Array::from(vec![1, 34]);
54/// let float_array = Float64Array::from(vec![3.2]);
55/// let type_ids = [0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
56/// let offsets = [0, 0, 1].into_iter().collect::<ScalarBuffer<i32>>();
57///
58/// let union_fields = [
59///     (0, Arc::new(Field::new("A", DataType::Int32, false))),
60///     (1, Arc::new(Field::new("B", DataType::Float64, false))),
61/// ].into_iter().collect::<UnionFields>();
62///
63/// let children = vec![
64///     Arc::new(int_array) as Arc<dyn Array>,
65///     Arc::new(float_array),
66/// ];
67///
68/// let array = UnionArray::try_new(
69///     union_fields,
70///     type_ids,
71///     Some(offsets),
72///     children,
73/// ).unwrap();
74///
75/// let value = array.value(0).as_any().downcast_ref::<Int32Array>().unwrap().value(0);
76/// assert_eq!(1, value);
77///
78/// let value = array.value(1).as_any().downcast_ref::<Float64Array>().unwrap().value(0);
79/// assert!(3.2 - value < f64::EPSILON);
80///
81/// let value = array.value(2).as_any().downcast_ref::<Int32Array>().unwrap().value(0);
82/// assert_eq!(34, value);
83/// ```
84///
85/// ## Create a sparse UnionArray `[1, 3.2, 34]`
86/// ```
87/// use arrow_buffer::ScalarBuffer;
88/// use arrow_schema::*;
89/// use std::sync::Arc;
90/// use arrow_array::{Array, Int32Array, Float64Array, UnionArray};
91///
92/// let int_array = Int32Array::from(vec![Some(1), None, Some(34)]);
93/// let float_array = Float64Array::from(vec![None, Some(3.2), None]);
94/// let type_ids = [0_i8, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
95///
96/// let union_fields = [
97///     (0, Arc::new(Field::new("A", DataType::Int32, false))),
98///     (1, Arc::new(Field::new("B", DataType::Float64, false))),
99/// ].into_iter().collect::<UnionFields>();
100///
101/// let children = vec![
102///     Arc::new(int_array) as Arc<dyn Array>,
103///     Arc::new(float_array),
104/// ];
105///
106/// let array = UnionArray::try_new(
107///     union_fields,
108///     type_ids,
109///     None,
110///     children,
111/// ).unwrap();
112///
113/// let value = array.value(0).as_any().downcast_ref::<Int32Array>().unwrap().value(0);
114/// assert_eq!(1, value);
115///
116/// let value = array.value(1).as_any().downcast_ref::<Float64Array>().unwrap().value(0);
117/// assert!(3.2 - value < f64::EPSILON);
118///
119/// let value = array.value(2).as_any().downcast_ref::<Int32Array>().unwrap().value(0);
120/// assert_eq!(34, value);
121/// ```
#[derive(Clone)]
pub struct UnionArray {
    /// The array's data type; every accessor in this file expects it to be a `DataType::Union`,
    /// which carries the `UnionFields` and the `UnionMode`.
    data_type: DataType,
    /// One type id per slot, selecting which child array the slot's value comes from.
    type_ids: ScalarBuffer<i8>,
    /// Per-slot offsets into the selected child: `Some` for dense unions, `None` for sparse ones.
    offsets: Option<ScalarBuffer<i32>>,
    /// Child arrays indexed by type id; positions whose type id is not used by any field are `None`.
    fields: Vec<Option<ArrayRef>>,
}
129
130impl UnionArray {
131    /// Creates a new `UnionArray`.
132    ///
133    /// Accepts type ids, child arrays and optionally offsets (for dense unions) to create
134    /// a new `UnionArray`.  This method makes no attempt to validate the data provided by the
135    /// caller and assumes that each of the components are correct and consistent with each other.
136    /// See `try_new` for an alternative that validates the data provided.
137    ///
138    /// # Safety
139    ///
140    /// The `type_ids` values should be non-negative and must match one of the type ids of the fields provided in `fields`.
141    /// These values are used to index into the `children` arrays.
142    ///
143    /// The `offsets` is provided in the case of a dense union, sparse unions should use `None`.
144    /// If provided the `offsets` values should be non-negative and must be less than the length of the
145    /// corresponding array.
146    ///
147    /// In both cases above we use signed integer types to maintain compatibility with other
148    /// Arrow implementations.
149    pub unsafe fn new_unchecked(
150        fields: UnionFields,
151        type_ids: ScalarBuffer<i8>,
152        offsets: Option<ScalarBuffer<i32>>,
153        children: Vec<ArrayRef>,
154    ) -> Self {
155        let mode = if offsets.is_some() {
156            UnionMode::Dense
157        } else {
158            UnionMode::Sparse
159        };
160
161        let len = type_ids.len();
162        let builder = ArrayData::builder(DataType::Union(fields, mode))
163            .add_buffer(type_ids.into_inner())
164            .child_data(children.into_iter().map(Array::into_data).collect())
165            .len(len);
166
167        let data = match offsets {
168            Some(offsets) => unsafe { builder.add_buffer(offsets.into_inner()).build_unchecked() },
169            None => unsafe { builder.build_unchecked() },
170        };
171        Self::from(data)
172    }
173
174    /// Attempts to create a new `UnionArray`, validating the inputs provided.
175    ///
176    /// The order of child arrays child array order must match the fields order
177    pub fn try_new(
178        fields: UnionFields,
179        type_ids: ScalarBuffer<i8>,
180        offsets: Option<ScalarBuffer<i32>>,
181        children: Vec<ArrayRef>,
182    ) -> Result<Self, ArrowError> {
183        // There must be a child array for every field.
184        if fields.len() != children.len() {
185            return Err(ArrowError::InvalidArgumentError(
186                "Union fields length must match child arrays length".to_string(),
187            ));
188        }
189
190        if let Some(offsets) = &offsets {
191            // There must be an offset value for every type id value.
192            if offsets.len() != type_ids.len() {
193                return Err(ArrowError::InvalidArgumentError(
194                    "Type Ids and Offsets lengths must match".to_string(),
195                ));
196            }
197        } else {
198            // Sparse union child arrays must be equal in length to the length of the union
199            for child in &children {
200                if child.len() != type_ids.len() {
201                    return Err(ArrowError::InvalidArgumentError(
202                        "Sparse union child arrays must be equal in length to the length of the union".to_string(),
203                    ));
204                }
205            }
206        }
207
208        // Create mapping from type id to array lengths.
209        let max_id = fields.iter().map(|(i, _)| i).max().unwrap_or_default() as usize;
210        let mut array_lens = vec![i32::MIN; max_id + 1];
211        for (cd, (field_id, _)) in children.iter().zip(fields.iter()) {
212            array_lens[field_id as usize] = cd.len() as i32;
213        }
214
215        // Type id values must match one of the fields.
216        for id in &type_ids {
217            match array_lens.get(*id as usize) {
218                Some(x) if *x != i32::MIN => {}
219                _ => {
220                    return Err(ArrowError::InvalidArgumentError(
221                        "Type Ids values must match one of the field type ids".to_owned(),
222                    ));
223                }
224            }
225        }
226
227        // Check the value offsets are in bounds.
228        if let Some(offsets) = &offsets {
229            let mut iter = type_ids.iter().zip(offsets.iter());
230            if iter.any(|(type_id, &offset)| offset < 0 || offset >= array_lens[*type_id as usize])
231            {
232                return Err(ArrowError::InvalidArgumentError(
233                    "Offsets must be non-negative and within the length of the Array".to_owned(),
234                ));
235            }
236        }
237
238        // Safety:
239        // - Arguments validated above.
240        let union_array = unsafe { Self::new_unchecked(fields, type_ids, offsets, children) };
241        Ok(union_array)
242    }
243
244    /// Accesses the child array for `type_id`.
245    ///
246    /// # Panics
247    ///
248    /// Panics if the `type_id` provided is not present in the array's DataType
249    /// in the `Union`.
250    pub fn child(&self, type_id: i8) -> &ArrayRef {
251        assert!((type_id as usize) < self.fields.len());
252        let boxed = &self.fields[type_id as usize];
253        boxed.as_ref().expect("invalid type id")
254    }
255
    /// Returns the `type_id` for the array slot at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is greater than or equal to the length of the array
    /// (i.e. the length of the type ids buffer).
    pub fn type_id(&self, index: usize) -> i8 {
        assert!(index < self.type_ids.len());
        self.type_ids[index]
    }
265
    /// Returns the `type_ids` buffer for this array, holding one type id per slot.
    pub fn type_ids(&self) -> &ScalarBuffer<i8> {
        &self.type_ids
    }
270
    /// Returns the `offsets` buffer if this is a dense array, or `None` if it is sparse.
    pub fn offsets(&self) -> Option<&ScalarBuffer<i32>> {
        self.offsets.as_ref()
    }
275
276    /// Returns the offset into the underlying values array for the array slot at `index`.
277    ///
278    /// # Panics
279    ///
280    /// Panics if `index` is greater than or equal the length of the array.
281    pub fn value_offset(&self, index: usize) -> usize {
282        assert!(index < self.len());
283        match &self.offsets {
284            Some(offsets) => offsets[index] as usize,
285            None => self.offset() + index,
286        }
287    }
288
289    /// Returns the array's value at index `i`.
290    ///
291    /// Note: This method does not check for nulls and the value is arbitrary
292    /// (but still well-defined) if [`is_null`](Self::is_null) returns true for the index.
293    ///
294    /// # Panics
295    /// Panics if index `i` is out of bounds
296    pub fn value(&self, i: usize) -> ArrayRef {
297        let type_id = self.type_id(i);
298        let value_offset = self.value_offset(i);
299        let child = self.child(type_id);
300        child.slice(value_offset, 1)
301    }
302
303    /// Returns the names of the types in the union.
304    pub fn type_names(&self) -> Vec<&str> {
305        match self.data_type() {
306            DataType::Union(fields, _) => fields
307                .iter()
308                .map(|(_, f)| f.name().as_str())
309                .collect::<Vec<&str>>(),
310            _ => unreachable!("Union array's data type is not a union!"),
311        }
312    }
313
314    /// Returns the [`UnionFields`] for the union.
315    pub fn fields(&self) -> &UnionFields {
316        match self.data_type() {
317            DataType::Union(fields, _) => fields,
318            _ => unreachable!("Union array's data type is not a union!"),
319        }
320    }
321
322    /// Returns whether the `UnionArray` is dense (or sparse if `false`).
323    pub fn is_dense(&self) -> bool {
324        match self.data_type() {
325            DataType::Union(_, mode) => mode == &UnionMode::Dense,
326            _ => unreachable!("Union array's data type is not a union!"),
327        }
328    }
329
330    /// Returns a zero-copy slice of this array with the indicated offset and length.
331    pub fn slice(&self, offset: usize, length: usize) -> Self {
332        let (offsets, fields) = match self.offsets.as_ref() {
333            // If dense union, slice offsets
334            Some(offsets) => (Some(offsets.slice(offset, length)), self.fields.clone()),
335            // Otherwise need to slice sparse children
336            None => {
337                let fields = self
338                    .fields
339                    .iter()
340                    .map(|x| x.as_ref().map(|x| x.slice(offset, length)))
341                    .collect();
342                (None, fields)
343            }
344        };
345
346        Self {
347            data_type: self.data_type.clone(),
348            type_ids: self.type_ids.slice(offset, length),
349            offsets,
350            fields,
351        }
352    }
353
354    /// Deconstruct this array into its constituent parts
355    ///
356    /// # Example
357    ///
358    /// ```
359    /// # use arrow_array::array::UnionArray;
360    /// # use arrow_array::types::Int32Type;
361    /// # use arrow_array::builder::UnionBuilder;
362    /// # use arrow_buffer::ScalarBuffer;
363    /// # fn main() -> Result<(), arrow_schema::ArrowError> {
364    /// let mut builder = UnionBuilder::new_dense();
365    /// builder.append::<Int32Type>("a", 1).unwrap();
366    /// let union_array = builder.build()?;
367    ///
368    /// // Deconstruct into parts
369    /// let (union_fields, type_ids, offsets, children) = union_array.into_parts();
370    ///
371    /// // Reconstruct from parts
372    /// let union_array = UnionArray::try_new(
373    ///     union_fields,
374    ///     type_ids,
375    ///     offsets,
376    ///     children,
377    /// );
378    /// # Ok(())
379    /// # }
380    /// ```
381    #[allow(clippy::type_complexity)]
382    pub fn into_parts(
383        self,
384    ) -> (
385        UnionFields,
386        ScalarBuffer<i8>,
387        Option<ScalarBuffer<i32>>,
388        Vec<ArrayRef>,
389    ) {
390        let Self {
391            data_type,
392            type_ids,
393            offsets,
394            mut fields,
395        } = self;
396        match data_type {
397            DataType::Union(union_fields, _) => {
398                let children = union_fields
399                    .iter()
400                    .map(|(type_id, _)| fields[type_id as usize].take().unwrap())
401                    .collect();
402                (union_fields, type_ids, offsets, children)
403            }
404            _ => unreachable!(),
405        }
406    }
407
408    /// Computes the logical nulls for a sparse union, optimized for when there's a lot of fields without nulls
409    fn mask_sparse_skip_without_nulls(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
410        // Example logic for a union with 5 fields, a, b & c with nulls, d & e without nulls:
411        // let [a_nulls, b_nulls, c_nulls] = nulls;
412        // let [is_a, is_b, is_c] = masks;
413        // let is_d_or_e = !(is_a | is_b | is_c)
414        // let union_chunk_nulls = is_d_or_e  | (is_a & a_nulls) | (is_b & b_nulls) | (is_c & c_nulls)
415        let fold = |(with_nulls_selected, union_nulls), (is_field, field_nulls)| {
416            (
417                with_nulls_selected | is_field,
418                union_nulls | (is_field & field_nulls),
419            )
420        };
421
422        self.mask_sparse_helper(
423            nulls,
424            |type_ids_chunk_array, nulls_masks_iters| {
425                let (with_nulls_selected, union_nulls) = nulls_masks_iters
426                    .iter_mut()
427                    .map(|(field_type_id, field_nulls)| {
428                        let field_nulls = field_nulls.next().unwrap();
429                        let is_field = selection_mask(type_ids_chunk_array, *field_type_id);
430
431                        (is_field, field_nulls)
432                    })
433                    .fold((0, 0), fold);
434
435                // In the example above, this is the is_d_or_e = !(is_a | is_b) part
436                let without_nulls_selected = !with_nulls_selected;
437
438                // if a field without nulls is selected, the value is always true(set bit)
439                // otherwise, the true/set bits have been computed above
440                without_nulls_selected | union_nulls
441            },
442            |type_ids_remainder, bit_chunks| {
443                let (with_nulls_selected, union_nulls) = bit_chunks
444                    .iter()
445                    .map(|(field_type_id, field_bit_chunks)| {
446                        let field_nulls = field_bit_chunks.remainder_bits();
447                        let is_field = selection_mask(type_ids_remainder, *field_type_id);
448
449                        (is_field, field_nulls)
450                    })
451                    .fold((0, 0), fold);
452
453                let without_nulls_selected = !with_nulls_selected;
454
455                without_nulls_selected | union_nulls
456            },
457        )
458    }
459
460    /// Computes the logical nulls for a sparse union, optimized for when there's a lot of fields fully null
461    fn mask_sparse_skip_fully_null(&self, mut nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
462        let fields = match self.data_type() {
463            DataType::Union(fields, _) => fields,
464            _ => unreachable!("Union array's data type is not a union!"),
465        };
466
467        let type_ids = fields.iter().map(|(id, _)| id).collect::<HashSet<_>>();
468        let with_nulls = nulls.iter().map(|(id, _)| *id).collect::<HashSet<_>>();
469
470        let without_nulls_ids = type_ids
471            .difference(&with_nulls)
472            .copied()
473            .collect::<Vec<_>>();
474
475        nulls.retain(|(_, nulls)| nulls.null_count() < nulls.len());
476
477        // Example logic for a union with 6 fields, a, b & c with nulls, d & e without nulls, and f fully_null:
478        // let [a_nulls, b_nulls, c_nulls] = nulls;
479        // let [is_a, is_b, is_c, is_d, is_e] = masks;
480        // let union_chunk_nulls = is_d | is_e | (is_a & a_nulls) | (is_b & b_nulls) | (is_c & c_nulls)
481        self.mask_sparse_helper(
482            nulls,
483            |type_ids_chunk_array, nulls_masks_iters| {
484                let union_nulls = nulls_masks_iters.iter_mut().fold(
485                    0,
486                    |union_nulls, (field_type_id, nulls_iter)| {
487                        let field_nulls = nulls_iter.next().unwrap();
488
489                        if field_nulls == 0 {
490                            union_nulls
491                        } else {
492                            let is_field = selection_mask(type_ids_chunk_array, *field_type_id);
493
494                            union_nulls | (is_field & field_nulls)
495                        }
496                    },
497                );
498
499                // Given the example above, this is the is_d_or_e = (is_d | is_e) part
500                let without_nulls_selected =
501                    without_nulls_selected(type_ids_chunk_array, &without_nulls_ids);
502
503                // if a field without nulls is selected, the value is always true(set bit)
504                // otherwise, the true/set bits have been computed above
505                union_nulls | without_nulls_selected
506            },
507            |type_ids_remainder, bit_chunks| {
508                let union_nulls =
509                    bit_chunks
510                        .iter()
511                        .fold(0, |union_nulls, (field_type_id, field_bit_chunks)| {
512                            let is_field = selection_mask(type_ids_remainder, *field_type_id);
513                            let field_nulls = field_bit_chunks.remainder_bits();
514
515                            union_nulls | is_field & field_nulls
516                        });
517
518                union_nulls | without_nulls_selected(type_ids_remainder, &without_nulls_ids)
519            },
520        )
521    }
522
523    /// Computes the logical nulls for a sparse union, optimized for when all fields contains nulls
524    fn mask_sparse_all_with_nulls_skip_one(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
525        // Example logic for a union with 3 fields, a, b & c, all containing nulls:
526        // let [a_nulls, b_nulls, c_nulls] = nulls;
527        // We can skip the first field: it's selection mask is the negation of all others selection mask
528        // let [is_b, is_c] = selection_masks;
529        // let is_a = !(is_b | is_c)
530        // let union_chunk_nulls = (is_a & a_nulls) | (is_b & b_nulls) | (is_c & c_nulls)
531        self.mask_sparse_helper(
532            nulls,
533            |type_ids_chunk_array, nulls_masks_iters| {
534                let (is_not_first, union_nulls) = nulls_masks_iters[1..] // skip first
535                    .iter_mut()
536                    .fold(
537                        (0, 0),
538                        |(is_not_first, union_nulls), (field_type_id, nulls_iter)| {
539                            let field_nulls = nulls_iter.next().unwrap();
540                            let is_field = selection_mask(type_ids_chunk_array, *field_type_id);
541
542                            (
543                                is_not_first | is_field,
544                                union_nulls | (is_field & field_nulls),
545                            )
546                        },
547                    );
548
549                let is_first = !is_not_first;
550                let first_nulls = nulls_masks_iters[0].1.next().unwrap();
551
552                (is_first & first_nulls) | union_nulls
553            },
554            |type_ids_remainder, bit_chunks| {
555                bit_chunks
556                    .iter()
557                    .fold(0, |union_nulls, (field_type_id, field_bit_chunks)| {
558                        let field_nulls = field_bit_chunks.remainder_bits();
559                        // The same logic as above, except that since this runs at most once,
560                        // it doesn't make difference to speed-up the first selection mask
561                        let is_field = selection_mask(type_ids_remainder, *field_type_id);
562
563                        union_nulls | (is_field & field_nulls)
564                    })
565            },
566        )
567    }
568
    /// Maps `nulls` to `BitChunk's` and then to `BitChunkIterator's`, then divides `self.type_ids` into exact chunks of 64 values,
    /// calling `mask_chunk` for every exact chunk, and `mask_remainder` for the remainder, if any, collecting the result in a `BooleanBuffer`
    ///
    /// Each callback returns one `u64` word of the resulting bitmap. `mask_chunk` receives the
    /// 64 type ids of the chunk together with the per-field chunk iterators; `mask_remainder`
    /// receives the leftover type ids (fewer than 64) together with the per-field `BitChunks`.
    fn mask_sparse_helper(
        &self,
        nulls: Vec<(i8, NullBuffer)>,
        mut mask_chunk: impl FnMut(&[i8; 64], &mut [(i8, BitChunkIterator)]) -> u64,
        mask_remainder: impl FnOnce(&[i8], &[(i8, BitChunks)]) -> u64,
    ) -> BooleanBuffer {
        // One `BitChunks` view per field that was passed in, keyed by the field's type id.
        let bit_chunks = nulls
            .iter()
            .map(|(type_id, nulls)| (*type_id, nulls.inner().bit_chunks()))
            .collect::<Vec<_>>();

        // Iterators over the full 64-bit chunks of each field; `mask_chunk` advances them.
        let mut nulls_masks_iter = bit_chunks
            .iter()
            .map(|(type_id, bit_chunks)| (*type_id, bit_chunks.iter()))
            .collect::<Vec<_>>();

        let chunks_exact = self.type_ids.chunks_exact(64);
        let remainder = chunks_exact.remainder();

        let chunks = chunks_exact.map(|type_ids_chunk| {
            // Cannot fail: `chunks_exact(64)` only yields slices of exactly 64 elements.
            let type_ids_chunk_array = <&[i8; 64]>::try_from(type_ids_chunk).unwrap();

            mask_chunk(type_ids_chunk_array, &mut nulls_masks_iter)
        });

        // SAFETY:
        // chunks is a ChunksExact iterator, which implements TrustedLen, and correctly reports its length
        // (it is wrapped in a `map`, which does not change the number of items)
        let mut buffer = unsafe { MutableBuffer::from_trusted_len_iter(chunks) };

        // The trailing slots that did not fill a whole chunk get one extra word.
        if !remainder.is_empty() {
            buffer.push(mask_remainder(remainder, &bit_chunks));
        }

        BooleanBuffer::new(buffer.into(), 0, self.type_ids.len())
    }
606
    /// Computes the logical nulls for a sparse or dense union, by gathering individual bits from the null buffer of the selected field
    ///
    /// `nulls` holds `(type_id, logical nulls)` pairs; type ids that are not listed are
    /// treated as always valid.
    fn gather_nulls(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
        let one_null = NullBuffer::new_null(1);
        let one_valid = NullBuffer::new_valid(1);

        // Unsafe code below depend on it:
        // To remove one branch from the loop, if the a type_id is not utilized, or it's logical_nulls is None/all set,
        // we use a null buffer of len 1 and a index_mask of 0, or the true null buffer and usize::MAX otherwise.
        // We then unconditionally access the null buffer with index & index_mask,
        // which always return 0 for the 1-len buffer, or the true index unchanged otherwise
        // We also use a 256 array, so llvm knows that `type_id as u8 as usize` is always in bounds
        // NOTE(review): `Mask` is defined outside this chunk; the comment above implies
        // `Mask::Zero as usize == 0` and `Mask::Max as usize == usize::MAX` — confirm.
        let mut logical_nulls_array = [(&one_valid, Mask::Zero); 256];

        for (type_id, nulls) in &nulls {
            if nulls.null_count() == nulls.len() {
                // Similarly, if all values are null, use a 1-null null-buffer to reduce cache pressure a bit
                logical_nulls_array[*type_id as u8 as usize] = (&one_null, Mask::Zero);
            } else {
                logical_nulls_array[*type_id as u8 as usize] = (nulls, Mask::Max);
            }
        }

        match &self.offsets {
            // Dense union: the bit to read is at the slot's offset into the selected child.
            Some(offsets) => {
                assert_eq!(self.type_ids.len(), offsets.len());

                BooleanBuffer::collect_bool(self.type_ids.len(), |i| unsafe {
                    // SAFETY: BooleanBuffer::collect_bool calls us 0..self.type_ids.len()
                    let type_id = *self.type_ids.get_unchecked(i);
                    // SAFETY: We asserted that offsets len and self.type_ids len are equal
                    let offset = *offsets.get_unchecked(i);

                    let (nulls, offset_mask) = &logical_nulls_array[type_id as u8 as usize];

                    // SAFETY:
                    // If offset_mask is Max
                    // 1. Offset validity is checked at union creation
                    // 2. If the null buffer len equals it's array len is checked at array creation
                    // If offset_mask is Zero, the null buffer len is 1
                    nulls
                        .inner()
                        .value_unchecked(offset as usize & *offset_mask as usize)
                })
            }
            // Sparse union: the bit to read is at the slot's own index in the selected child.
            None => {
                BooleanBuffer::collect_bool(self.type_ids.len(), |index| unsafe {
                    // SAFETY: BooleanBuffer::collect_bool calls us 0..self.type_ids.len()
                    let type_id = *self.type_ids.get_unchecked(index);

                    let (nulls, index_mask) = &logical_nulls_array[type_id as u8 as usize];

                    // SAFETY:
                    // If index_mask is Max
                    // 1. On sparse union, every child len match it's parent, this is checked at union creation
                    // 2. If the null buffer len equals it's array len is checked at array creation
                    // If index_mask is Zero, the null buffer len is 1
                    nulls.inner().value_unchecked(index & *index_mask as usize)
                })
            }
        }
    }
668
669    /// Returns a vector of tuples containing each field's type_id and its logical null buffer.
670    /// Only fields with non-zero null counts are included.
671    fn fields_logical_nulls(&self) -> Vec<(i8, NullBuffer)> {
672        self.fields
673            .iter()
674            .enumerate()
675            .filter_map(|(type_id, field)| Some((type_id as i8, field.as_ref()?.logical_nulls()?)))
676            .filter(|(_, nulls)| nulls.null_count() > 0)
677            .collect()
678    }
679}
680
681impl From<ArrayData> for UnionArray {
682    fn from(data: ArrayData) -> Self {
683        let (data_type, len, _nulls, offset, buffers, child_data) = data.into_parts();
684
685        let (fields, mode) = match &data_type {
686            DataType::Union(fields, mode) => (fields, mode),
687            d => panic!("UnionArray expected ArrayData with type Union got {d}"),
688        };
689
690        let (type_ids, offsets) = match mode {
691            UnionMode::Sparse => {
692                let [buffer]: [Buffer; 1] = buffers.try_into().expect("1 buffer for type_ids");
693                (ScalarBuffer::new(buffer, offset, len), None)
694            }
695            UnionMode::Dense => {
696                let [type_ids_buffer, offsets_buffer]: [Buffer; 2] = buffers
697                    .try_into()
698                    .expect("2 buffers for type_ids and offsets");
699                (
700                    ScalarBuffer::new(type_ids_buffer, offset, len),
701                    Some(ScalarBuffer::new(offsets_buffer, offset, len)),
702                )
703            }
704        };
705
706        let max_id = fields.iter().map(|(i, _)| i).max().unwrap_or_default() as usize;
707        let mut boxed_fields = vec![None; max_id + 1];
708        for (cd, (field_id, _)) in child_data.into_iter().zip(fields.iter()) {
709            boxed_fields[field_id as usize] = Some(make_array(cd));
710        }
711        Self {
712            data_type,
713            type_ids,
714            offsets,
715            fields: boxed_fields,
716        }
717    }
718}
719
720impl From<UnionArray> for ArrayData {
721    fn from(array: UnionArray) -> Self {
722        let len = array.len();
723        let f = match &array.data_type {
724            DataType::Union(f, _) => f,
725            _ => unreachable!(),
726        };
727        let buffers = match array.offsets {
728            Some(o) => vec![array.type_ids.into_inner(), o.into_inner()],
729            None => vec![array.type_ids.into_inner()],
730        };
731
732        let child = f
733            .iter()
734            .map(|(i, _)| array.fields[i as usize].as_ref().unwrap().to_data())
735            .collect();
736
737        let builder = ArrayDataBuilder::new(array.data_type)
738            .len(len)
739            .buffers(buffers)
740            .child_data(child);
741        unsafe { builder.build_unchecked() }
742    }
743}
744
745/// SAFETY: Correctly implements the contract of Arrow Arrays
746unsafe impl Array for UnionArray {
747    fn as_any(&self) -> &dyn Any {
748        self
749    }
750
751    fn to_data(&self) -> ArrayData {
752        self.clone().into()
753    }
754
755    fn into_data(self) -> ArrayData {
756        self.into()
757    }
758
759    fn data_type(&self) -> &DataType {
760        &self.data_type
761    }
762
763    fn slice(&self, offset: usize, length: usize) -> ArrayRef {
764        Arc::new(self.slice(offset, length))
765    }
766
767    fn len(&self) -> usize {
768        self.type_ids.len()
769    }
770
771    fn is_empty(&self) -> bool {
772        self.type_ids.is_empty()
773    }
774
775    fn shrink_to_fit(&mut self) {
776        self.type_ids.shrink_to_fit();
777        if let Some(offsets) = &mut self.offsets {
778            offsets.shrink_to_fit();
779        }
780        for array in self.fields.iter_mut().flatten() {
781            array.shrink_to_fit();
782        }
783        self.fields.shrink_to_fit();
784    }
785
786    fn offset(&self) -> usize {
787        0
788    }
789
790    fn nulls(&self) -> Option<&NullBuffer> {
791        None
792    }
793
794    fn logical_nulls(&self) -> Option<NullBuffer> {
795        let fields = match self.data_type() {
796            DataType::Union(fields, _) => fields,
797            _ => unreachable!(),
798        };
799
800        if fields.len() <= 1 {
801            return self.fields.iter().find_map(|field_opt| {
802                field_opt
803                    .as_ref()
804                    .and_then(|field| field.logical_nulls())
805                    .map(|logical_nulls| {
806                        if self.is_dense() {
807                            self.gather_nulls(vec![(0, logical_nulls)]).into()
808                        } else {
809                            logical_nulls
810                        }
811                    })
812            });
813        }
814
815        let logical_nulls = self.fields_logical_nulls();
816
817        if logical_nulls.is_empty() {
818            return None;
819        }
820
821        let fully_null_count = logical_nulls
822            .iter()
823            .filter(|(_, nulls)| nulls.null_count() == nulls.len())
824            .count();
825
826        if fully_null_count == fields.len() {
827            if let Some((_, exactly_sized)) = logical_nulls
828                .iter()
829                .find(|(_, nulls)| nulls.len() == self.len())
830            {
831                return Some(exactly_sized.clone());
832            }
833
834            if let Some((_, bigger)) = logical_nulls
835                .iter()
836                .find(|(_, nulls)| nulls.len() > self.len())
837            {
838                return Some(bigger.slice(0, self.len()));
839            }
840
841            return Some(NullBuffer::new_null(self.len()));
842        }
843
844        let boolean_buffer = match &self.offsets {
845            Some(_) => self.gather_nulls(logical_nulls),
846            None => {
847                // Choose the fastest way to compute the logical nulls
848                // Gather computes one null per iteration, while the others work on 64 nulls chunks,
849                // but must also compute selection masks, which is expensive,
850                // so it's cost is the number of selection masks computed per chunk
851                // Since computing the selection mask gets auto-vectorized, it's performance depends on which simd feature is enabled
852                // For gather, the cost is the threshold where masking becomes slower than gather, which is determined with benchmarks
853                // TODO: bench on avx512f(feature is still unstable)
854                let gather_relative_cost = if cfg!(target_feature = "avx2") {
855                    10
856                } else if cfg!(target_feature = "sse4.1") {
857                    3
858                } else if cfg!(target_arch = "x86") || cfg!(target_arch = "x86_64") {
859                    // x86 baseline includes sse2
860                    2
861                } else {
862                    // TODO: bench on non x86
863                    // Always use gather on non benchmarked archs because even though it may slower on some cases,
864                    // it's performance depends only on the union length, without being affected by the number of fields
865                    0
866                };
867
868                let strategies = [
869                    (SparseStrategy::Gather, gather_relative_cost, true),
870                    (
871                        SparseStrategy::MaskAllFieldsWithNullsSkipOne,
872                        fields.len() - 1,
873                        fields.len() == logical_nulls.len(),
874                    ),
875                    (
876                        SparseStrategy::MaskSkipWithoutNulls,
877                        logical_nulls.len(),
878                        true,
879                    ),
880                    (
881                        SparseStrategy::MaskSkipFullyNull,
882                        fields.len() - fully_null_count,
883                        true,
884                    ),
885                ];
886
887                let (strategy, _, _) = strategies
888                    .iter()
889                    .filter(|(_, _, applicable)| *applicable)
890                    .min_by_key(|(_, cost, _)| cost)
891                    .unwrap();
892
893                match strategy {
894                    SparseStrategy::Gather => self.gather_nulls(logical_nulls),
895                    SparseStrategy::MaskAllFieldsWithNullsSkipOne => {
896                        self.mask_sparse_all_with_nulls_skip_one(logical_nulls)
897                    }
898                    SparseStrategy::MaskSkipWithoutNulls => {
899                        self.mask_sparse_skip_without_nulls(logical_nulls)
900                    }
901                    SparseStrategy::MaskSkipFullyNull => {
902                        self.mask_sparse_skip_fully_null(logical_nulls)
903                    }
904                }
905            }
906        };
907
908        let null_buffer = NullBuffer::from(boolean_buffer);
909
910        if null_buffer.null_count() > 0 {
911            Some(null_buffer)
912        } else {
913            None
914        }
915    }
916
917    fn is_nullable(&self) -> bool {
918        self.fields
919            .iter()
920            .flatten()
921            .any(|field| field.is_nullable())
922    }
923
924    fn get_buffer_memory_size(&self) -> usize {
925        let mut sum = self.type_ids.inner().capacity();
926        if let Some(o) = self.offsets.as_ref() {
927            sum += o.inner().capacity()
928        }
929        self.fields
930            .iter()
931            .flat_map(|x| x.as_ref().map(|x| x.get_buffer_memory_size()))
932            .sum::<usize>()
933            + sum
934    }
935
936    fn get_array_memory_size(&self) -> usize {
937        let mut sum = self.type_ids.inner().capacity();
938        if let Some(o) = self.offsets.as_ref() {
939            sum += o.inner().capacity()
940        }
941        std::mem::size_of::<Self>()
942            + self
943                .fields
944                .iter()
945                .flat_map(|x| x.as_ref().map(|x| x.get_array_memory_size()))
946                .sum::<usize>()
947            + sum
948    }
949
950    #[cfg(feature = "pool")]
951    fn claim(&self, pool: &dyn arrow_buffer::MemoryPool) {
952        self.type_ids.claim(pool);
953        if let Some(offsets) = &self.offsets {
954            offsets.claim(pool);
955        }
956        for field in self.fields.iter().flatten() {
957            field.claim(pool);
958        }
959    }
960}
961
962impl std::fmt::Debug for UnionArray {
963    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
964        let header = if self.is_dense() {
965            "UnionArray(Dense)\n["
966        } else {
967            "UnionArray(Sparse)\n["
968        };
969        writeln!(f, "{header}")?;
970
971        writeln!(f, "-- type id buffer:")?;
972        writeln!(f, "{:?}", self.type_ids)?;
973
974        if let Some(offsets) = &self.offsets {
975            writeln!(f, "-- offsets buffer:")?;
976            writeln!(f, "{offsets:?}")?;
977        }
978
979        let fields = match self.data_type() {
980            DataType::Union(fields, _) => fields,
981            _ => unreachable!(),
982        };
983
984        for (type_id, field) in fields.iter() {
985            let child = self.child(type_id);
986            writeln!(
987                f,
988                "-- child {}: \"{}\" ({:?})",
989                type_id,
990                field.name(),
991                field.data_type()
992            )?;
993            std::fmt::Debug::fmt(child, f)?;
994            writeln!(f)?;
995        }
996        writeln!(f, "]")
997    }
998}
999
/// How to compute the logical nulls of a sparse union. All strategies return the same result.
/// Those starting with `Mask` perform bitwise masking for each chunk of 64 values, including
/// computing expensive selection masks of fields; the strategies differ in which fields'
/// selection masks must be computed.
enum SparseStrategy {
    /// Gather individual bits from the null buffer of the selected field
    Gather,
    /// All fields contain nulls, so we can skip the selection mask computation of one field by negating the others
    MaskAllFieldsWithNullsSkipOne,
    /// Skip the selection mask computation of the fields without nulls
    MaskSkipWithoutNulls,
    /// Skip the selection mask computation of the fully null fields
    MaskSkipFullyNull,
}
1014
/// A `usize`-sized mask that is either all zero bits or all one bits.
///
/// NOTE(review): the code using this type is outside this section; presumably it is
/// ANDed with an index to either force it to `0` or leave it unchanged. Confirm.
#[derive(Copy, Clone)]
#[repr(usize)]
enum Mask {
    /// No bits set.
    Zero = 0,
    /// All bits set.
    // false positive, see https://github.com/rust-lang/rust-clippy/issues/8043
    #[allow(clippy::enum_clike_unportable_variant)]
    Max = usize::MAX,
}
1023
/// Packs one bit per entry of `type_ids_chunk` into a `u64`: bit `i` is set when
/// `type_ids_chunk[i] == type_id`.
///
/// The chunk is expected to hold at most 64 type ids, one per bit of the result.
fn selection_mask(type_ids_chunk: &[i8], type_id: i8) -> u64 {
    let mut packed = 0_u64;
    for (bit_idx, &candidate) in type_ids_chunk.iter().enumerate() {
        packed |= u64::from(candidate == type_id) << bit_idx;
    }
    packed
}
1033
1034/// Returns a bitmask where bits indicate if any id from `without_nulls_ids` exist in `type_ids_chunk`.
1035fn without_nulls_selected(type_ids_chunk: &[i8], without_nulls_ids: &[i8]) -> u64 {
1036    without_nulls_ids
1037        .iter()
1038        .fold(0, |fully_valid_selected, field_type_id| {
1039            fully_valid_selected | selection_mask(type_ids_chunk, *field_type_id)
1040        })
1041}
1042
1043#[cfg(test)]
1044mod tests {
1045    use super::*;
1046    use std::collections::HashSet;
1047
1048    use crate::array::Int8Type;
1049    use crate::builder::UnionBuilder;
1050    use crate::cast::AsArray;
1051    use crate::types::{Float32Type, Float64Type, Int32Type, Int64Type};
1052    use crate::{Float64Array, Int32Array, Int64Array, StringArray};
1053    use crate::{Int8Array, RecordBatch};
1054    use arrow_buffer::Buffer;
1055    use arrow_schema::{Field, Schema};
1056
1057    #[test]
1058    fn test_dense_i32() {
1059        let mut builder = UnionBuilder::new_dense();
1060        builder.append::<Int32Type>("a", 1).unwrap();
1061        builder.append::<Int32Type>("b", 2).unwrap();
1062        builder.append::<Int32Type>("c", 3).unwrap();
1063        builder.append::<Int32Type>("a", 4).unwrap();
1064        builder.append::<Int32Type>("c", 5).unwrap();
1065        builder.append::<Int32Type>("a", 6).unwrap();
1066        builder.append::<Int32Type>("b", 7).unwrap();
1067        let union = builder.build().unwrap();
1068
1069        let expected_type_ids = vec![0_i8, 1, 2, 0, 2, 0, 1];
1070        let expected_offsets = vec![0_i32, 0, 0, 1, 1, 2, 1];
1071        let expected_array_values = [1_i32, 2, 3, 4, 5, 6, 7];
1072
1073        // Check type ids
1074        assert_eq!(*union.type_ids(), expected_type_ids);
1075        for (i, id) in expected_type_ids.iter().enumerate() {
1076            assert_eq!(id, &union.type_id(i));
1077        }
1078
1079        // Check offsets
1080        assert_eq!(*union.offsets().unwrap(), expected_offsets);
1081        for (i, id) in expected_offsets.iter().enumerate() {
1082            assert_eq!(union.value_offset(i), *id as usize);
1083        }
1084
1085        // Check data
1086        assert_eq!(
1087            *union.child(0).as_primitive::<Int32Type>().values(),
1088            [1_i32, 4, 6]
1089        );
1090        assert_eq!(
1091            *union.child(1).as_primitive::<Int32Type>().values(),
1092            [2_i32, 7]
1093        );
1094        assert_eq!(
1095            *union.child(2).as_primitive::<Int32Type>().values(),
1096            [3_i32, 5]
1097        );
1098
1099        assert_eq!(expected_array_values.len(), union.len());
1100        for (i, expected_value) in expected_array_values.iter().enumerate() {
1101            assert!(!union.is_null(i));
1102            let slot = union.value(i);
1103            let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1104            assert_eq!(slot.len(), 1);
1105            let value = slot.value(0);
1106            assert_eq!(expected_value, &value);
1107        }
1108    }
1109
1110    #[test]
1111    fn slice_union_array_single_field() {
1112        // Dense Union
1113        // [1, null, 3, null, 4]
1114        let union_array = {
1115            let mut builder = UnionBuilder::new_dense();
1116            builder.append::<Int32Type>("a", 1).unwrap();
1117            builder.append_null::<Int32Type>("a").unwrap();
1118            builder.append::<Int32Type>("a", 3).unwrap();
1119            builder.append_null::<Int32Type>("a").unwrap();
1120            builder.append::<Int32Type>("a", 4).unwrap();
1121            builder.build().unwrap()
1122        };
1123
1124        // [null, 3, null]
1125        let union_slice = union_array.slice(1, 3);
1126        let logical_nulls = union_slice.logical_nulls().unwrap();
1127
1128        assert_eq!(logical_nulls.len(), 3);
1129        assert!(logical_nulls.is_null(0));
1130        assert!(logical_nulls.is_valid(1));
1131        assert!(logical_nulls.is_null(2));
1132    }
1133
1134    #[test]
1135    #[cfg_attr(miri, ignore)]
1136    fn test_dense_i32_large() {
1137        let mut builder = UnionBuilder::new_dense();
1138
1139        let expected_type_ids = vec![0_i8; 1024];
1140        let expected_offsets: Vec<_> = (0..1024).collect();
1141        let expected_array_values: Vec<_> = (1..=1024).collect();
1142
1143        expected_array_values
1144            .iter()
1145            .for_each(|v| builder.append::<Int32Type>("a", *v).unwrap());
1146
1147        let union = builder.build().unwrap();
1148
1149        // Check type ids
1150        assert_eq!(*union.type_ids(), expected_type_ids);
1151        for (i, id) in expected_type_ids.iter().enumerate() {
1152            assert_eq!(id, &union.type_id(i));
1153        }
1154
1155        // Check offsets
1156        assert_eq!(*union.offsets().unwrap(), expected_offsets);
1157        for (i, id) in expected_offsets.iter().enumerate() {
1158            assert_eq!(union.value_offset(i), *id as usize);
1159        }
1160
1161        for (i, expected_value) in expected_array_values.iter().enumerate() {
1162            assert!(!union.is_null(i));
1163            let slot = union.value(i);
1164            let slot = slot.as_primitive::<Int32Type>();
1165            assert_eq!(slot.len(), 1);
1166            let value = slot.value(0);
1167            assert_eq!(expected_value, &value);
1168        }
1169    }
1170
1171    #[test]
1172    fn test_dense_mixed() {
1173        let mut builder = UnionBuilder::new_dense();
1174        builder.append::<Int32Type>("a", 1).unwrap();
1175        builder.append::<Int64Type>("c", 3).unwrap();
1176        builder.append::<Int32Type>("a", 4).unwrap();
1177        builder.append::<Int64Type>("c", 5).unwrap();
1178        builder.append::<Int32Type>("a", 6).unwrap();
1179        let union = builder.build().unwrap();
1180
1181        assert_eq!(5, union.len());
1182        for i in 0..union.len() {
1183            let slot = union.value(i);
1184            assert!(!union.is_null(i));
1185            match i {
1186                0 => {
1187                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1188                    assert_eq!(slot.len(), 1);
1189                    let value = slot.value(0);
1190                    assert_eq!(1_i32, value);
1191                }
1192                1 => {
1193                    let slot = slot.as_any().downcast_ref::<Int64Array>().unwrap();
1194                    assert_eq!(slot.len(), 1);
1195                    let value = slot.value(0);
1196                    assert_eq!(3_i64, value);
1197                }
1198                2 => {
1199                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1200                    assert_eq!(slot.len(), 1);
1201                    let value = slot.value(0);
1202                    assert_eq!(4_i32, value);
1203                }
1204                3 => {
1205                    let slot = slot.as_any().downcast_ref::<Int64Array>().unwrap();
1206                    assert_eq!(slot.len(), 1);
1207                    let value = slot.value(0);
1208                    assert_eq!(5_i64, value);
1209                }
1210                4 => {
1211                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1212                    assert_eq!(slot.len(), 1);
1213                    let value = slot.value(0);
1214                    assert_eq!(6_i32, value);
1215                }
1216                _ => unreachable!(),
1217            }
1218        }
1219    }
1220
1221    #[test]
1222    fn test_dense_mixed_with_nulls() {
1223        let mut builder = UnionBuilder::new_dense();
1224        builder.append::<Int32Type>("a", 1).unwrap();
1225        builder.append::<Int64Type>("c", 3).unwrap();
1226        builder.append::<Int32Type>("a", 10).unwrap();
1227        builder.append_null::<Int32Type>("a").unwrap();
1228        builder.append::<Int32Type>("a", 6).unwrap();
1229        let union = builder.build().unwrap();
1230
1231        assert_eq!(5, union.len());
1232        for i in 0..union.len() {
1233            let slot = union.value(i);
1234            match i {
1235                0 => {
1236                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1237                    assert!(!slot.is_null(0));
1238                    assert_eq!(slot.len(), 1);
1239                    let value = slot.value(0);
1240                    assert_eq!(1_i32, value);
1241                }
1242                1 => {
1243                    let slot = slot.as_any().downcast_ref::<Int64Array>().unwrap();
1244                    assert!(!slot.is_null(0));
1245                    assert_eq!(slot.len(), 1);
1246                    let value = slot.value(0);
1247                    assert_eq!(3_i64, value);
1248                }
1249                2 => {
1250                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1251                    assert!(!slot.is_null(0));
1252                    assert_eq!(slot.len(), 1);
1253                    let value = slot.value(0);
1254                    assert_eq!(10_i32, value);
1255                }
1256                3 => assert!(slot.is_null(0)),
1257                4 => {
1258                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1259                    assert!(!slot.is_null(0));
1260                    assert_eq!(slot.len(), 1);
1261                    let value = slot.value(0);
1262                    assert_eq!(6_i32, value);
1263                }
1264                _ => unreachable!(),
1265            }
1266        }
1267    }
1268
1269    #[test]
1270    fn test_dense_mixed_with_nulls_and_offset() {
1271        let mut builder = UnionBuilder::new_dense();
1272        builder.append::<Int32Type>("a", 1).unwrap();
1273        builder.append::<Int64Type>("c", 3).unwrap();
1274        builder.append::<Int32Type>("a", 10).unwrap();
1275        builder.append_null::<Int32Type>("a").unwrap();
1276        builder.append::<Int32Type>("a", 6).unwrap();
1277        let union = builder.build().unwrap();
1278
1279        let slice = union.slice(2, 3);
1280        let new_union = slice.as_any().downcast_ref::<UnionArray>().unwrap();
1281
1282        assert_eq!(3, new_union.len());
1283        for i in 0..new_union.len() {
1284            let slot = new_union.value(i);
1285            match i {
1286                0 => {
1287                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1288                    assert!(!slot.is_null(0));
1289                    assert_eq!(slot.len(), 1);
1290                    let value = slot.value(0);
1291                    assert_eq!(10_i32, value);
1292                }
1293                1 => assert!(slot.is_null(0)),
1294                2 => {
1295                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1296                    assert!(!slot.is_null(0));
1297                    assert_eq!(slot.len(), 1);
1298                    let value = slot.value(0);
1299                    assert_eq!(6_i32, value);
1300                }
1301                _ => unreachable!(),
1302            }
1303        }
1304    }
1305
1306    #[test]
1307    fn test_dense_mixed_with_str() {
1308        let string_array = StringArray::from(vec!["foo", "bar", "baz"]);
1309        let int_array = Int32Array::from(vec![5, 6]);
1310        let float_array = Float64Array::from(vec![10.0]);
1311
1312        let type_ids = [1, 0, 0, 2, 0, 1].into_iter().collect::<ScalarBuffer<i8>>();
1313        let offsets = [0, 0, 1, 0, 2, 1]
1314            .into_iter()
1315            .collect::<ScalarBuffer<i32>>();
1316
1317        let fields = [
1318            (0, Arc::new(Field::new("A", DataType::Utf8, false))),
1319            (1, Arc::new(Field::new("B", DataType::Int32, false))),
1320            (2, Arc::new(Field::new("C", DataType::Float64, false))),
1321        ]
1322        .into_iter()
1323        .collect::<UnionFields>();
1324        let children = [
1325            Arc::new(string_array) as Arc<dyn Array>,
1326            Arc::new(int_array),
1327            Arc::new(float_array),
1328        ]
1329        .into_iter()
1330        .collect();
1331        let array =
1332            UnionArray::try_new(fields, type_ids.clone(), Some(offsets.clone()), children).unwrap();
1333
1334        // Check type ids
1335        assert_eq!(*array.type_ids(), type_ids);
1336        for (i, id) in type_ids.iter().enumerate() {
1337            assert_eq!(id, &array.type_id(i));
1338        }
1339
1340        // Check offsets
1341        assert_eq!(*array.offsets().unwrap(), offsets);
1342        for (i, id) in offsets.iter().enumerate() {
1343            assert_eq!(*id as usize, array.value_offset(i));
1344        }
1345
1346        // Check values
1347        assert_eq!(6, array.len());
1348
1349        let slot = array.value(0);
1350        let value = slot.as_any().downcast_ref::<Int32Array>().unwrap().value(0);
1351        assert_eq!(5, value);
1352
1353        let slot = array.value(1);
1354        let value = slot
1355            .as_any()
1356            .downcast_ref::<StringArray>()
1357            .unwrap()
1358            .value(0);
1359        assert_eq!("foo", value);
1360
1361        let slot = array.value(2);
1362        let value = slot
1363            .as_any()
1364            .downcast_ref::<StringArray>()
1365            .unwrap()
1366            .value(0);
1367        assert_eq!("bar", value);
1368
1369        let slot = array.value(3);
1370        let value = slot
1371            .as_any()
1372            .downcast_ref::<Float64Array>()
1373            .unwrap()
1374            .value(0);
1375        assert_eq!(10.0, value);
1376
1377        let slot = array.value(4);
1378        let value = slot
1379            .as_any()
1380            .downcast_ref::<StringArray>()
1381            .unwrap()
1382            .value(0);
1383        assert_eq!("baz", value);
1384
1385        let slot = array.value(5);
1386        let value = slot.as_any().downcast_ref::<Int32Array>().unwrap().value(0);
1387        assert_eq!(6, value);
1388    }
1389
1390    #[test]
1391    fn test_sparse_i32() {
1392        let mut builder = UnionBuilder::new_sparse();
1393        builder.append::<Int32Type>("a", 1).unwrap();
1394        builder.append::<Int32Type>("b", 2).unwrap();
1395        builder.append::<Int32Type>("c", 3).unwrap();
1396        builder.append::<Int32Type>("a", 4).unwrap();
1397        builder.append::<Int32Type>("c", 5).unwrap();
1398        builder.append::<Int32Type>("a", 6).unwrap();
1399        builder.append::<Int32Type>("b", 7).unwrap();
1400        let union = builder.build().unwrap();
1401
1402        let expected_type_ids = vec![0_i8, 1, 2, 0, 2, 0, 1];
1403        let expected_array_values = [1_i32, 2, 3, 4, 5, 6, 7];
1404
1405        // Check type ids
1406        assert_eq!(*union.type_ids(), expected_type_ids);
1407        for (i, id) in expected_type_ids.iter().enumerate() {
1408            assert_eq!(id, &union.type_id(i));
1409        }
1410
1411        // Check offsets, sparse union should only have a single buffer
1412        assert!(union.offsets().is_none());
1413
1414        // Check data
1415        assert_eq!(
1416            *union.child(0).as_primitive::<Int32Type>().values(),
1417            [1_i32, 0, 0, 4, 0, 6, 0],
1418        );
1419        assert_eq!(
1420            *union.child(1).as_primitive::<Int32Type>().values(),
1421            [0_i32, 2_i32, 0, 0, 0, 0, 7]
1422        );
1423        assert_eq!(
1424            *union.child(2).as_primitive::<Int32Type>().values(),
1425            [0_i32, 0, 3_i32, 0, 5, 0, 0]
1426        );
1427
1428        assert_eq!(expected_array_values.len(), union.len());
1429        for (i, expected_value) in expected_array_values.iter().enumerate() {
1430            assert!(!union.is_null(i));
1431            let slot = union.value(i);
1432            let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1433            assert_eq!(slot.len(), 1);
1434            let value = slot.value(0);
1435            assert_eq!(expected_value, &value);
1436        }
1437    }
1438
1439    #[test]
1440    fn test_sparse_mixed() {
1441        let mut builder = UnionBuilder::new_sparse();
1442        builder.append::<Int32Type>("a", 1).unwrap();
1443        builder.append::<Float64Type>("c", 3.0).unwrap();
1444        builder.append::<Int32Type>("a", 4).unwrap();
1445        builder.append::<Float64Type>("c", 5.0).unwrap();
1446        builder.append::<Int32Type>("a", 6).unwrap();
1447        let union = builder.build().unwrap();
1448
1449        let expected_type_ids = vec![0_i8, 1, 0, 1, 0];
1450
1451        // Check type ids
1452        assert_eq!(*union.type_ids(), expected_type_ids);
1453        for (i, id) in expected_type_ids.iter().enumerate() {
1454            assert_eq!(id, &union.type_id(i));
1455        }
1456
1457        // Check offsets, sparse union should only have a single buffer, i.e. no offsets
1458        assert!(union.offsets().is_none());
1459
1460        for i in 0..union.len() {
1461            let slot = union.value(i);
1462            assert!(!union.is_null(i));
1463            match i {
1464                0 => {
1465                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1466                    assert_eq!(slot.len(), 1);
1467                    let value = slot.value(0);
1468                    assert_eq!(1_i32, value);
1469                }
1470                1 => {
1471                    let slot = slot.as_any().downcast_ref::<Float64Array>().unwrap();
1472                    assert_eq!(slot.len(), 1);
1473                    let value = slot.value(0);
1474                    assert_eq!(value, 3_f64);
1475                }
1476                2 => {
1477                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1478                    assert_eq!(slot.len(), 1);
1479                    let value = slot.value(0);
1480                    assert_eq!(4_i32, value);
1481                }
1482                3 => {
1483                    let slot = slot.as_any().downcast_ref::<Float64Array>().unwrap();
1484                    assert_eq!(slot.len(), 1);
1485                    let value = slot.value(0);
1486                    assert_eq!(5_f64, value);
1487                }
1488                4 => {
1489                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1490                    assert_eq!(slot.len(), 1);
1491                    let value = slot.value(0);
1492                    assert_eq!(6_i32, value);
1493                }
1494                _ => unreachable!(),
1495            }
1496        }
1497    }
1498
1499    #[test]
1500    fn test_sparse_mixed_with_nulls() {
1501        let mut builder = UnionBuilder::new_sparse();
1502        builder.append::<Int32Type>("a", 1).unwrap();
1503        builder.append_null::<Int32Type>("a").unwrap();
1504        builder.append::<Float64Type>("c", 3.0).unwrap();
1505        builder.append::<Int32Type>("a", 4).unwrap();
1506        let union = builder.build().unwrap();
1507
1508        let expected_type_ids = vec![0_i8, 0, 1, 0];
1509
1510        // Check type ids
1511        assert_eq!(*union.type_ids(), expected_type_ids);
1512        for (i, id) in expected_type_ids.iter().enumerate() {
1513            assert_eq!(id, &union.type_id(i));
1514        }
1515
1516        // Check offsets, sparse union should only have a single buffer, i.e. no offsets
1517        assert!(union.offsets().is_none());
1518
1519        for i in 0..union.len() {
1520            let slot = union.value(i);
1521            match i {
1522                0 => {
1523                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1524                    assert!(!slot.is_null(0));
1525                    assert_eq!(slot.len(), 1);
1526                    let value = slot.value(0);
1527                    assert_eq!(1_i32, value);
1528                }
1529                1 => assert!(slot.is_null(0)),
1530                2 => {
1531                    let slot = slot.as_any().downcast_ref::<Float64Array>().unwrap();
1532                    assert!(!slot.is_null(0));
1533                    assert_eq!(slot.len(), 1);
1534                    let value = slot.value(0);
1535                    assert_eq!(value, 3_f64);
1536                }
1537                3 => {
1538                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1539                    assert!(!slot.is_null(0));
1540                    assert_eq!(slot.len(), 1);
1541                    let value = slot.value(0);
1542                    assert_eq!(4_i32, value);
1543                }
1544                _ => unreachable!(),
1545            }
1546        }
1547    }
1548
1549    #[test]
1550    fn test_sparse_mixed_with_nulls_and_offset() {
1551        let mut builder = UnionBuilder::new_sparse();
1552        builder.append::<Int32Type>("a", 1).unwrap();
1553        builder.append_null::<Int32Type>("a").unwrap();
1554        builder.append::<Float64Type>("c", 3.0).unwrap();
1555        builder.append_null::<Float64Type>("c").unwrap();
1556        builder.append::<Int32Type>("a", 4).unwrap();
1557        let union = builder.build().unwrap();
1558
1559        let slice = union.slice(1, 4);
1560        let new_union = slice.as_any().downcast_ref::<UnionArray>().unwrap();
1561
1562        assert_eq!(4, new_union.len());
1563        for i in 0..new_union.len() {
1564            let slot = new_union.value(i);
1565            match i {
1566                0 => assert!(slot.is_null(0)),
1567                1 => {
1568                    let slot = slot.as_primitive::<Float64Type>();
1569                    assert!(!slot.is_null(0));
1570                    assert_eq!(slot.len(), 1);
1571                    let value = slot.value(0);
1572                    assert_eq!(value, 3_f64);
1573                }
1574                2 => assert!(slot.is_null(0)),
1575                3 => {
1576                    let slot = slot.as_primitive::<Int32Type>();
1577                    assert!(!slot.is_null(0));
1578                    assert_eq!(slot.len(), 1);
1579                    let value = slot.value(0);
1580                    assert_eq!(4_i32, value);
1581                }
1582                _ => unreachable!(),
1583            }
1584        }
1585    }
1586
1587    fn test_union_validity(union_array: &UnionArray) {
1588        assert_eq!(union_array.null_count(), 0);
1589
1590        for i in 0..union_array.len() {
1591            assert!(!union_array.is_null(i));
1592            assert!(union_array.is_valid(i));
1593        }
1594    }
1595
1596    #[test]
1597    fn test_union_array_validity() {
1598        let mut builder = UnionBuilder::new_sparse();
1599        builder.append::<Int32Type>("a", 1).unwrap();
1600        builder.append_null::<Int32Type>("a").unwrap();
1601        builder.append::<Float64Type>("c", 3.0).unwrap();
1602        builder.append_null::<Float64Type>("c").unwrap();
1603        builder.append::<Int32Type>("a", 4).unwrap();
1604        let union = builder.build().unwrap();
1605
1606        test_union_validity(&union);
1607
1608        let mut builder = UnionBuilder::new_dense();
1609        builder.append::<Int32Type>("a", 1).unwrap();
1610        builder.append_null::<Int32Type>("a").unwrap();
1611        builder.append::<Float64Type>("c", 3.0).unwrap();
1612        builder.append_null::<Float64Type>("c").unwrap();
1613        builder.append::<Int32Type>("a", 4).unwrap();
1614        let union = builder.build().unwrap();
1615
1616        test_union_validity(&union);
1617    }
1618
1619    #[test]
1620    fn test_type_check() {
1621        let mut builder = UnionBuilder::new_sparse();
1622        builder.append::<Float32Type>("a", 1.0).unwrap();
1623        let err = builder.append::<Int32Type>("a", 1).unwrap_err().to_string();
1624        assert!(
1625            err.contains(
1626                "Attempt to write col \"a\" with type Int32 doesn't match existing type Float32"
1627            ),
1628            "{}",
1629            err
1630        );
1631    }
1632
1633    #[test]
1634    fn slice_union_array() {
1635        // [1, null, 3.0, null, 4]
1636        fn create_union(mut builder: UnionBuilder) -> UnionArray {
1637            builder.append::<Int32Type>("a", 1).unwrap();
1638            builder.append_null::<Int32Type>("a").unwrap();
1639            builder.append::<Float64Type>("c", 3.0).unwrap();
1640            builder.append_null::<Float64Type>("c").unwrap();
1641            builder.append::<Int32Type>("a", 4).unwrap();
1642            builder.build().unwrap()
1643        }
1644
1645        fn create_batch(union: UnionArray) -> RecordBatch {
1646            let schema = Schema::new(vec![Field::new(
1647                "struct_array",
1648                union.data_type().clone(),
1649                true,
1650            )]);
1651
1652            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union)]).unwrap()
1653        }
1654
1655        fn test_slice_union(record_batch_slice: RecordBatch) {
1656            let union_slice = record_batch_slice
1657                .column(0)
1658                .as_any()
1659                .downcast_ref::<UnionArray>()
1660                .unwrap();
1661
1662            assert_eq!(union_slice.type_id(0), 0);
1663            assert_eq!(union_slice.type_id(1), 1);
1664            assert_eq!(union_slice.type_id(2), 1);
1665
1666            let slot = union_slice.value(0);
1667            let array = slot.as_primitive::<Int32Type>();
1668            assert_eq!(array.len(), 1);
1669            assert!(array.is_null(0));
1670
1671            let slot = union_slice.value(1);
1672            let array = slot.as_primitive::<Float64Type>();
1673            assert_eq!(array.len(), 1);
1674            assert!(array.is_valid(0));
1675            assert_eq!(array.value(0), 3.0);
1676
1677            let slot = union_slice.value(2);
1678            let array = slot.as_primitive::<Float64Type>();
1679            assert_eq!(array.len(), 1);
1680            assert!(array.is_null(0));
1681        }
1682
1683        // Sparse Union
1684        let builder = UnionBuilder::new_sparse();
1685        let record_batch = create_batch(create_union(builder));
1686        // [null, 3.0, null]
1687        let record_batch_slice = record_batch.slice(1, 3);
1688        test_slice_union(record_batch_slice);
1689
1690        // Dense Union
1691        let builder = UnionBuilder::new_dense();
1692        let record_batch = create_batch(create_union(builder));
1693        // [null, 3.0, null]
1694        let record_batch_slice = record_batch.slice(1, 3);
1695        test_slice_union(record_batch_slice);
1696    }
1697
1698    #[test]
1699    fn test_custom_type_ids() {
1700        let data_type = DataType::Union(
1701            UnionFields::try_new(
1702                vec![8, 4, 9],
1703                vec![
1704                    Field::new("strings", DataType::Utf8, false),
1705                    Field::new("integers", DataType::Int32, false),
1706                    Field::new("floats", DataType::Float64, false),
1707                ],
1708            )
1709            .unwrap(),
1710            UnionMode::Dense,
1711        );
1712
1713        let string_array = StringArray::from(vec!["foo", "bar", "baz"]);
1714        let int_array = Int32Array::from(vec![5, 6, 4]);
1715        let float_array = Float64Array::from(vec![10.0]);
1716
1717        let type_ids = Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]);
1718        let value_offsets = Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]);
1719
1720        let data = ArrayData::builder(data_type)
1721            .len(7)
1722            .buffers(vec![type_ids, value_offsets])
1723            .child_data(vec![
1724                string_array.into_data(),
1725                int_array.into_data(),
1726                float_array.into_data(),
1727            ])
1728            .build()
1729            .unwrap();
1730
1731        let array = UnionArray::from(data);
1732
1733        let v = array.value(0);
1734        assert_eq!(v.data_type(), &DataType::Int32);
1735        assert_eq!(v.len(), 1);
1736        assert_eq!(v.as_primitive::<Int32Type>().value(0), 5);
1737
1738        let v = array.value(1);
1739        assert_eq!(v.data_type(), &DataType::Utf8);
1740        assert_eq!(v.len(), 1);
1741        assert_eq!(v.as_string::<i32>().value(0), "foo");
1742
1743        let v = array.value(2);
1744        assert_eq!(v.data_type(), &DataType::Int32);
1745        assert_eq!(v.len(), 1);
1746        assert_eq!(v.as_primitive::<Int32Type>().value(0), 6);
1747
1748        let v = array.value(3);
1749        assert_eq!(v.data_type(), &DataType::Utf8);
1750        assert_eq!(v.len(), 1);
1751        assert_eq!(v.as_string::<i32>().value(0), "bar");
1752
1753        let v = array.value(4);
1754        assert_eq!(v.data_type(), &DataType::Float64);
1755        assert_eq!(v.len(), 1);
1756        assert_eq!(v.as_primitive::<Float64Type>().value(0), 10.0);
1757
1758        let v = array.value(5);
1759        assert_eq!(v.data_type(), &DataType::Int32);
1760        assert_eq!(v.len(), 1);
1761        assert_eq!(v.as_primitive::<Int32Type>().value(0), 4);
1762
1763        let v = array.value(6);
1764        assert_eq!(v.data_type(), &DataType::Utf8);
1765        assert_eq!(v.len(), 1);
1766        assert_eq!(v.as_string::<i32>().value(0), "baz");
1767    }
1768
1769    #[test]
1770    fn into_parts() {
1771        let mut builder = UnionBuilder::new_dense();
1772        builder.append::<Int32Type>("a", 1).unwrap();
1773        builder.append::<Int8Type>("b", 2).unwrap();
1774        builder.append::<Int32Type>("a", 3).unwrap();
1775        let dense_union = builder.build().unwrap();
1776
1777        let field = [
1778            &Arc::new(Field::new("a", DataType::Int32, false)),
1779            &Arc::new(Field::new("b", DataType::Int8, false)),
1780        ];
1781        let (union_fields, type_ids, offsets, children) = dense_union.into_parts();
1782        assert_eq!(
1783            union_fields
1784                .iter()
1785                .map(|(_, field)| field)
1786                .collect::<Vec<_>>(),
1787            field
1788        );
1789        assert_eq!(type_ids, [0, 1, 0]);
1790        assert!(offsets.is_some());
1791        assert_eq!(offsets.as_ref().unwrap(), &[0, 0, 1]);
1792
1793        let result = UnionArray::try_new(union_fields, type_ids, offsets, children);
1794        assert!(result.is_ok());
1795        assert_eq!(result.unwrap().len(), 3);
1796
1797        let mut builder = UnionBuilder::new_sparse();
1798        builder.append::<Int32Type>("a", 1).unwrap();
1799        builder.append::<Int8Type>("b", 2).unwrap();
1800        builder.append::<Int32Type>("a", 3).unwrap();
1801        let sparse_union = builder.build().unwrap();
1802
1803        let (union_fields, type_ids, offsets, children) = sparse_union.into_parts();
1804        assert_eq!(type_ids, [0, 1, 0]);
1805        assert!(offsets.is_none());
1806
1807        let result = UnionArray::try_new(union_fields, type_ids, offsets, children);
1808        assert!(result.is_ok());
1809        assert_eq!(result.unwrap().len(), 3);
1810    }
1811
1812    #[test]
1813    fn into_parts_custom_type_ids() {
1814        let set_field_type_ids: [i8; 3] = [8, 4, 9];
1815        let data_type = DataType::Union(
1816            UnionFields::try_new(
1817                set_field_type_ids,
1818                [
1819                    Field::new("strings", DataType::Utf8, false),
1820                    Field::new("integers", DataType::Int32, false),
1821                    Field::new("floats", DataType::Float64, false),
1822                ],
1823            )
1824            .unwrap(),
1825            UnionMode::Dense,
1826        );
1827        let string_array = StringArray::from(vec!["foo", "bar", "baz"]);
1828        let int_array = Int32Array::from(vec![5, 6, 4]);
1829        let float_array = Float64Array::from(vec![10.0]);
1830        let type_ids = Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]);
1831        let value_offsets = Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]);
1832        let data = ArrayData::builder(data_type)
1833            .len(7)
1834            .buffers(vec![type_ids, value_offsets])
1835            .child_data(vec![
1836                string_array.into_data(),
1837                int_array.into_data(),
1838                float_array.into_data(),
1839            ])
1840            .build()
1841            .unwrap();
1842        let array = UnionArray::from(data);
1843
1844        let (union_fields, type_ids, offsets, children) = array.into_parts();
1845        assert_eq!(
1846            type_ids.iter().collect::<HashSet<_>>(),
1847            set_field_type_ids.iter().collect::<HashSet<_>>()
1848        );
1849        let result = UnionArray::try_new(union_fields, type_ids, offsets, children);
1850        assert!(result.is_ok());
1851        let array = result.unwrap();
1852        assert_eq!(array.len(), 7);
1853    }
1854
1855    #[test]
1856    fn test_invalid() {
1857        let fields = UnionFields::try_new(
1858            [3, 2],
1859            [
1860                Field::new("a", DataType::Utf8, false),
1861                Field::new("b", DataType::Utf8, false),
1862            ],
1863        )
1864        .unwrap();
1865        let children = vec![
1866            Arc::new(StringArray::from_iter_values(["a", "b"])) as _,
1867            Arc::new(StringArray::from_iter_values(["c", "d"])) as _,
1868        ];
1869
1870        let type_ids = vec![3, 3, 2].into();
1871        let err =
1872            UnionArray::try_new(fields.clone(), type_ids, None, children.clone()).unwrap_err();
1873        assert_eq!(
1874            err.to_string(),
1875            "Invalid argument error: Sparse union child arrays must be equal in length to the length of the union"
1876        );
1877
1878        let type_ids = vec![1, 2].into();
1879        let err =
1880            UnionArray::try_new(fields.clone(), type_ids, None, children.clone()).unwrap_err();
1881        assert_eq!(
1882            err.to_string(),
1883            "Invalid argument error: Type Ids values must match one of the field type ids"
1884        );
1885
1886        let type_ids = vec![7, 2].into();
1887        let err = UnionArray::try_new(fields.clone(), type_ids, None, children).unwrap_err();
1888        assert_eq!(
1889            err.to_string(),
1890            "Invalid argument error: Type Ids values must match one of the field type ids"
1891        );
1892
1893        let children = vec![
1894            Arc::new(StringArray::from_iter_values(["a", "b"])) as _,
1895            Arc::new(StringArray::from_iter_values(["c"])) as _,
1896        ];
1897        let type_ids = ScalarBuffer::from(vec![3_i8, 3, 2]);
1898        let offsets = Some(vec![0, 1, 0].into());
1899        UnionArray::try_new(fields.clone(), type_ids.clone(), offsets, children.clone()).unwrap();
1900
1901        let offsets = Some(vec![0, 1, 1].into());
1902        let err = UnionArray::try_new(fields.clone(), type_ids.clone(), offsets, children.clone())
1903            .unwrap_err();
1904
1905        assert_eq!(
1906            err.to_string(),
1907            "Invalid argument error: Offsets must be non-negative and within the length of the Array"
1908        );
1909
1910        let offsets = Some(vec![0, 1].into());
1911        let err =
1912            UnionArray::try_new(fields.clone(), type_ids.clone(), offsets, children).unwrap_err();
1913
1914        assert_eq!(
1915            err.to_string(),
1916            "Invalid argument error: Type Ids and Offsets lengths must match"
1917        );
1918
1919        let err = UnionArray::try_new(fields.clone(), type_ids, None, vec![]).unwrap_err();
1920
1921        assert_eq!(
1922            err.to_string(),
1923            "Invalid argument error: Union fields length must match child arrays length"
1924        );
1925    }
1926
1927    #[test]
1928    fn test_logical_nulls_fast_paths() {
1929        // fields.len() <= 1
1930        let array = UnionArray::try_new(UnionFields::empty(), vec![].into(), None, vec![]).unwrap();
1931
1932        assert_eq!(array.logical_nulls(), None);
1933
1934        let fields = UnionFields::try_new(
1935            [1, 3],
1936            [
1937                Field::new("a", DataType::Int8, false), // non nullable
1938                Field::new("b", DataType::Int8, false), // non nullable
1939            ],
1940        )
1941        .unwrap();
1942        let array = UnionArray::try_new(
1943            fields,
1944            vec![1].into(),
1945            None,
1946            vec![
1947                Arc::new(Int8Array::from_value(5, 1)),
1948                Arc::new(Int8Array::from_value(5, 1)),
1949            ],
1950        )
1951        .unwrap();
1952
1953        assert_eq!(array.logical_nulls(), None);
1954
1955        let nullable_fields = UnionFields::try_new(
1956            [1, 3],
1957            [
1958                Field::new("a", DataType::Int8, true), // nullable but without nulls
1959                Field::new("b", DataType::Int8, true), // nullable but without nulls
1960            ],
1961        )
1962        .unwrap();
1963        let array = UnionArray::try_new(
1964            nullable_fields.clone(),
1965            vec![1, 1].into(),
1966            None,
1967            vec![
1968                Arc::new(Int8Array::from_value(-5, 2)), // nullable but without nulls
1969                Arc::new(Int8Array::from_value(-5, 2)), // nullable but without nulls
1970            ],
1971        )
1972        .unwrap();
1973
1974        assert_eq!(array.logical_nulls(), None);
1975
1976        let array = UnionArray::try_new(
1977            nullable_fields.clone(),
1978            vec![1, 1].into(),
1979            None,
1980            vec![
1981                // every children is completly null
1982                Arc::new(Int8Array::new_null(2)), // all null, same len as it's parent
1983                Arc::new(Int8Array::new_null(2)), // all null, same len as it's parent
1984            ],
1985        )
1986        .unwrap();
1987
1988        assert_eq!(array.logical_nulls(), Some(NullBuffer::new_null(2)));
1989
1990        let array = UnionArray::try_new(
1991            nullable_fields.clone(),
1992            vec![1, 1].into(),
1993            Some(vec![0, 1].into()),
1994            vec![
1995                // every children is completly null
1996                Arc::new(Int8Array::new_null(3)), // bigger that parent
1997                Arc::new(Int8Array::new_null(3)), // bigger that parent
1998            ],
1999        )
2000        .unwrap();
2001
2002        assert_eq!(array.logical_nulls(), Some(NullBuffer::new_null(2)));
2003    }
2004
2005    #[test]
2006    fn test_dense_union_logical_nulls_gather() {
2007        // union of [{A=1}, {A=2}, {B=3.2}, {B=}, {C=}, {C=}]
2008        let int_array = Int32Array::from(vec![1, 2]);
2009        let float_array = Float64Array::from(vec![Some(3.2), None]);
2010        let str_array = StringArray::new_null(1);
2011        let type_ids = [1, 1, 3, 3, 4, 4].into_iter().collect::<ScalarBuffer<i8>>();
2012        let offsets = [0, 1, 0, 1, 0, 0]
2013            .into_iter()
2014            .collect::<ScalarBuffer<i32>>();
2015
2016        let children = vec![
2017            Arc::new(int_array) as Arc<dyn Array>,
2018            Arc::new(float_array),
2019            Arc::new(str_array),
2020        ];
2021
2022        let array = UnionArray::try_new(union_fields(), type_ids, Some(offsets), children).unwrap();
2023
2024        let expected = BooleanBuffer::from(vec![true, true, true, false, false, false]);
2025
2026        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
2027        assert_eq!(expected, array.gather_nulls(array.fields_logical_nulls()));
2028    }
2029
2030    #[test]
2031    fn test_sparse_union_logical_nulls_mask_all_nulls_skip_one() {
2032        let fields: UnionFields = [
2033            (1, Arc::new(Field::new("A", DataType::Int32, true))),
2034            (3, Arc::new(Field::new("B", DataType::Float64, true))),
2035        ]
2036        .into_iter()
2037        .collect();
2038
2039        // union of [{A=}, {A=}, {B=3.2}, {B=}]
2040        let int_array = Int32Array::new_null(4);
2041        let float_array = Float64Array::from(vec![None, None, Some(3.2), None]);
2042        let type_ids = [1, 1, 3, 3].into_iter().collect::<ScalarBuffer<i8>>();
2043
2044        let children = vec![Arc::new(int_array) as Arc<dyn Array>, Arc::new(float_array)];
2045
2046        let array = UnionArray::try_new(fields.clone(), type_ids, None, children).unwrap();
2047
2048        let expected = BooleanBuffer::from(vec![false, false, true, false]);
2049
2050        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
2051        assert_eq!(
2052            expected,
2053            array.mask_sparse_all_with_nulls_skip_one(array.fields_logical_nulls())
2054        );
2055
2056        //like above, but repeated to genereate two exact bitmasks and a non empty remainder
2057        let len = 2 * 64 + 32;
2058
2059        let int_array = Int32Array::new_null(len);
2060        let float_array = Float64Array::from_iter([Some(3.2), None].into_iter().cycle().take(len));
2061        let type_ids = ScalarBuffer::from_iter([1, 1, 3, 3].into_iter().cycle().take(len));
2062
2063        let array = UnionArray::try_new(
2064            fields,
2065            type_ids,
2066            None,
2067            vec![Arc::new(int_array), Arc::new(float_array)],
2068        )
2069        .unwrap();
2070
2071        let expected =
2072            BooleanBuffer::from_iter([false, false, true, false].into_iter().cycle().take(len));
2073
2074        assert_eq!(array.len(), len);
2075        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
2076        assert_eq!(
2077            expected,
2078            array.mask_sparse_all_with_nulls_skip_one(array.fields_logical_nulls())
2079        );
2080    }
2081
2082    #[test]
2083    fn test_sparse_union_logical_mask_mixed_nulls_skip_fully_valid() {
2084        // union of [{A=2}, {A=2}, {B=3.2}, {B=}, {C=}, {C=}]
2085        let int_array = Int32Array::from_value(2, 6);
2086        let float_array = Float64Array::from_value(4.2, 6);
2087        let str_array = StringArray::new_null(6);
2088        let type_ids = [1, 1, 3, 3, 4, 4].into_iter().collect::<ScalarBuffer<i8>>();
2089
2090        let children = vec![
2091            Arc::new(int_array) as Arc<dyn Array>,
2092            Arc::new(float_array),
2093            Arc::new(str_array),
2094        ];
2095
2096        let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap();
2097
2098        let expected = BooleanBuffer::from(vec![true, true, true, true, false, false]);
2099
2100        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
2101        assert_eq!(
2102            expected,
2103            array.mask_sparse_skip_without_nulls(array.fields_logical_nulls())
2104        );
2105
2106        //like above, but repeated to genereate two exact bitmasks and a non empty remainder
2107        let len = 2 * 64 + 32;
2108
2109        let int_array = Int32Array::from_value(2, len);
2110        let float_array = Float64Array::from_value(4.2, len);
2111        let str_array = StringArray::from_iter([None, Some("a")].into_iter().cycle().take(len));
2112        let type_ids = ScalarBuffer::from_iter([1, 1, 3, 3, 4, 4].into_iter().cycle().take(len));
2113
2114        let children = vec![
2115            Arc::new(int_array) as Arc<dyn Array>,
2116            Arc::new(float_array),
2117            Arc::new(str_array),
2118        ];
2119
2120        let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap();
2121
2122        let expected = BooleanBuffer::from_iter(
2123            [true, true, true, true, false, true]
2124                .into_iter()
2125                .cycle()
2126                .take(len),
2127        );
2128
2129        assert_eq!(array.len(), len);
2130        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
2131        assert_eq!(
2132            expected,
2133            array.mask_sparse_skip_without_nulls(array.fields_logical_nulls())
2134        );
2135    }
2136
2137    #[test]
2138    fn test_sparse_union_logical_mask_mixed_nulls_skip_fully_null() {
2139        // union of [{A=}, {A=}, {B=4.2}, {B=4.2}, {C=}, {C=}]
2140        let int_array = Int32Array::new_null(6);
2141        let float_array = Float64Array::from_value(4.2, 6);
2142        let str_array = StringArray::new_null(6);
2143        let type_ids = [1, 1, 3, 3, 4, 4].into_iter().collect::<ScalarBuffer<i8>>();
2144
2145        let children = vec![
2146            Arc::new(int_array) as Arc<dyn Array>,
2147            Arc::new(float_array),
2148            Arc::new(str_array),
2149        ];
2150
2151        let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap();
2152
2153        let expected = BooleanBuffer::from(vec![false, false, true, true, false, false]);
2154
2155        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
2156        assert_eq!(
2157            expected,
2158            array.mask_sparse_skip_fully_null(array.fields_logical_nulls())
2159        );
2160
2161        //like above, but repeated to genereate two exact bitmasks and a non empty remainder
2162        let len = 2 * 64 + 32;
2163
2164        let int_array = Int32Array::new_null(len);
2165        let float_array = Float64Array::from_value(4.2, len);
2166        let str_array = StringArray::new_null(len);
2167        let type_ids = ScalarBuffer::from_iter([1, 1, 3, 3, 4, 4].into_iter().cycle().take(len));
2168
2169        let children = vec![
2170            Arc::new(int_array) as Arc<dyn Array>,
2171            Arc::new(float_array),
2172            Arc::new(str_array),
2173        ];
2174
2175        let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap();
2176
2177        let expected = BooleanBuffer::from_iter(
2178            [false, false, true, true, false, false]
2179                .into_iter()
2180                .cycle()
2181                .take(len),
2182        );
2183
2184        assert_eq!(array.len(), len);
2185        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
2186        assert_eq!(
2187            expected,
2188            array.mask_sparse_skip_fully_null(array.fields_logical_nulls())
2189        );
2190    }
2191
2192    #[test]
2193    fn test_sparse_union_logical_nulls_gather() {
2194        let n_fields = 50;
2195
2196        let non_null = Int32Array::from_value(2, 4);
2197        let mixed = Int32Array::from(vec![None, None, Some(1), None]);
2198        let fully_null = Int32Array::new_null(4);
2199
2200        let array = UnionArray::try_new(
2201            (1..)
2202                .step_by(2)
2203                .map(|i| {
2204                    (
2205                        i,
2206                        Arc::new(Field::new(format!("f{i}"), DataType::Int32, true)),
2207                    )
2208                })
2209                .take(n_fields)
2210                .collect(),
2211            vec![1, 3, 3, 5].into(),
2212            None,
2213            [
2214                Arc::new(non_null) as ArrayRef,
2215                Arc::new(mixed),
2216                Arc::new(fully_null),
2217            ]
2218            .into_iter()
2219            .cycle()
2220            .take(n_fields)
2221            .collect(),
2222        )
2223        .unwrap();
2224
2225        let expected = BooleanBuffer::from(vec![true, false, true, false]);
2226
2227        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
2228        assert_eq!(expected, array.gather_nulls(array.fields_logical_nulls()));
2229    }
2230
2231    fn union_fields() -> UnionFields {
2232        [
2233            (1, Arc::new(Field::new("A", DataType::Int32, true))),
2234            (3, Arc::new(Field::new("B", DataType::Float64, true))),
2235            (4, Arc::new(Field::new("C", DataType::Utf8, true))),
2236        ]
2237        .into_iter()
2238        .collect()
2239    }
2240
2241    #[test]
2242    fn test_is_nullable() {
2243        assert!(!create_union_array(false, false).is_nullable());
2244        assert!(create_union_array(true, false).is_nullable());
2245        assert!(create_union_array(false, true).is_nullable());
2246        assert!(create_union_array(true, true).is_nullable());
2247    }
2248
2249    /// Create a union array with a float and integer field
2250    ///
2251    /// If the `int_nullable` is true, the integer field will have nulls
2252    /// If the `float_nullable` is true, the float field will have nulls
2253    ///
2254    /// Note the `Field` definitions are always declared to be nullable
2255    fn create_union_array(int_nullable: bool, float_nullable: bool) -> UnionArray {
2256        let int_array = if int_nullable {
2257            Int32Array::from(vec![Some(1), None, Some(3)])
2258        } else {
2259            Int32Array::from(vec![1, 2, 3])
2260        };
2261        let float_array = if float_nullable {
2262            Float64Array::from(vec![Some(3.2), None, Some(4.2)])
2263        } else {
2264            Float64Array::from(vec![3.2, 4.2, 5.2])
2265        };
2266        let type_ids = [0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
2267        let offsets = [0, 0, 0].into_iter().collect::<ScalarBuffer<i32>>();
2268        let union_fields = [
2269            (0, Arc::new(Field::new("A", DataType::Int32, true))),
2270            (1, Arc::new(Field::new("B", DataType::Float64, true))),
2271        ]
2272        .into_iter()
2273        .collect::<UnionFields>();
2274
2275        let children = vec![Arc::new(int_array) as Arc<dyn Array>, Arc::new(float_array)];
2276
2277        UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap()
2278    }
2279}