arrow_array/builder/union_builder.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::builder::buffer_builder::{Int32BufferBuilder, Int8BufferBuilder};
19use crate::builder::{ArrayBuilder, BufferBuilder};
20use crate::{make_array, ArrayRef, ArrowPrimitiveType, UnionArray};
21use arrow_buffer::NullBufferBuilder;
22use arrow_buffer::{ArrowNativeType, Buffer, ScalarBuffer};
23use arrow_data::ArrayDataBuilder;
24use arrow_schema::{ArrowError, DataType, Field};
25use std::any::Any;
26use std::collections::BTreeMap;
27use std::sync::Arc;
28
/// `FieldData` is a helper struct to track the state of the fields in the `UnionBuilder`.
///
/// One `FieldData` exists per distinct field (child) of the union. It accumulates that
/// child's values and validity bits until the builder is finished.
#[derive(Debug)]
struct FieldData {
    /// The type id for this field, written to the union's type-id buffer for every
    /// slot routed to this child.
    type_id: i8,
    /// The Arrow data type represented in the `values_buffer`, which is untyped
    data_type: DataType,
    /// A type-erased `BufferBuilder<T::Native>` holding this field's values; the concrete
    /// builder is recovered by downcasting in `append_value`.
    values_buffer: Box<dyn FieldDataValues>,
    /// The number of array slots (values and nulls) appended so far. It is incremented
    /// together with every push to `values_buffer` and `null_buffer_builder`.
    slots: usize,
    /// A builder for the null bitmap
    null_buffer_builder: NullBufferBuilder,
}
43
44/// A type-erased [`BufferBuilder`] used by [`FieldData`]
45trait FieldDataValues: std::fmt::Debug + Send + Sync {
46    fn as_mut_any(&mut self) -> &mut dyn Any;
47
48    fn append_null(&mut self);
49
50    fn finish(&mut self) -> Buffer;
51
52    fn finish_cloned(&self) -> Buffer;
53}
54
55impl<T: ArrowNativeType> FieldDataValues for BufferBuilder<T> {
56    fn as_mut_any(&mut self) -> &mut dyn Any {
57        self
58    }
59
60    fn append_null(&mut self) {
61        self.advance(1)
62    }
63
64    fn finish(&mut self) -> Buffer {
65        self.finish()
66    }
67
68    fn finish_cloned(&self) -> Buffer {
69        Buffer::from_slice_ref(self.as_slice())
70    }
71}
72
73impl FieldData {
74    /// Creates a new `FieldData`.
75    fn new<T: ArrowPrimitiveType>(type_id: i8, data_type: DataType, capacity: usize) -> Self {
76        Self {
77            type_id,
78            data_type,
79            slots: 0,
80            values_buffer: Box::new(BufferBuilder::<T::Native>::new(capacity)),
81            null_buffer_builder: NullBufferBuilder::new(capacity),
82        }
83    }
84
85    /// Appends a single value to this `FieldData`'s `values_buffer`.
86    fn append_value<T: ArrowPrimitiveType>(&mut self, v: T::Native) {
87        self.values_buffer
88            .as_mut_any()
89            .downcast_mut::<BufferBuilder<T::Native>>()
90            .expect("Tried to append unexpected type")
91            .append(v);
92
93        self.null_buffer_builder.append(true);
94        self.slots += 1;
95    }
96
97    /// Appends a null to this `FieldData`.
98    fn append_null(&mut self) {
99        self.values_buffer.append_null();
100        self.null_buffer_builder.append(false);
101        self.slots += 1;
102    }
103}
104
/// Builder for [`UnionArray`]
///
/// Each distinct field name passed to [`UnionBuilder::append`] or
/// [`UnionBuilder::append_null`] becomes one child of the union. Children are given
/// type ids in the order they are first appended to (`0`, `1`, ...).
///
/// `UnionBuilder::default()` produces an empty *sparse* builder with an initial
/// capacity of zero; use [`UnionBuilder::new_dense`] or
/// [`UnionBuilder::with_capacity_dense`] to build a dense union.
///
/// Example: **Dense Memory Layout**
///
/// ```
/// # use arrow_array::builder::UnionBuilder;
/// # use arrow_array::types::{Float64Type, Int32Type};
///
/// let mut builder = UnionBuilder::new_dense();
/// builder.append::<Int32Type>("a", 1).unwrap();
/// builder.append::<Float64Type>("b", 3.0).unwrap();
/// builder.append::<Int32Type>("a", 4).unwrap();
/// let union = builder.build().unwrap();
///
/// assert_eq!(union.type_id(0), 0);
/// assert_eq!(union.type_id(1), 1);
/// assert_eq!(union.type_id(2), 0);
///
/// assert_eq!(union.value_offset(0), 0);
/// assert_eq!(union.value_offset(1), 0);
/// assert_eq!(union.value_offset(2), 1);
/// ```
///
/// Example: **Sparse Memory Layout**
/// ```
/// # use arrow_array::builder::UnionBuilder;
/// # use arrow_array::types::{Float64Type, Int32Type};
///
/// let mut builder = UnionBuilder::new_sparse();
/// builder.append::<Int32Type>("a", 1).unwrap();
/// builder.append::<Float64Type>("b", 3.0).unwrap();
/// builder.append::<Int32Type>("a", 4).unwrap();
/// let union = builder.build().unwrap();
///
/// assert_eq!(union.type_id(0), 0);
/// assert_eq!(union.type_id(1), 1);
/// assert_eq!(union.type_id(2), 0);
///
/// assert_eq!(union.value_offset(0), 0);
/// assert_eq!(union.value_offset(1), 1);
/// assert_eq!(union.value_offset(2), 2);
/// ```
#[derive(Debug, Default)]
pub struct UnionBuilder {
    /// The current number of slots in the array
    len: usize,
    /// Maps field names to `FieldData` instances which track the builders for that field
    fields: BTreeMap<String, FieldData>,
    /// Builder to keep track of type ids
    type_id_builder: Int8BufferBuilder,
    /// Builder to keep track of offsets (`None` for sparse unions)
    value_offset_builder: Option<Int32BufferBuilder>,
    /// Capacity passed when creating each new child's buffers (sparse unions use the
    /// larger of this and the current length)
    initial_capacity: usize,
}
159
160impl UnionBuilder {
161    /// Creates a new dense array builder.
162    pub fn new_dense() -> Self {
163        Self::with_capacity_dense(1024)
164    }
165
166    /// Creates a new sparse array builder.
167    pub fn new_sparse() -> Self {
168        Self::with_capacity_sparse(1024)
169    }
170
171    /// Creates a new dense array builder with capacity.
172    pub fn with_capacity_dense(capacity: usize) -> Self {
173        Self {
174            len: 0,
175            fields: Default::default(),
176            type_id_builder: Int8BufferBuilder::new(capacity),
177            value_offset_builder: Some(Int32BufferBuilder::new(capacity)),
178            initial_capacity: capacity,
179        }
180    }
181
182    /// Creates a new sparse array builder  with capacity.
183    pub fn with_capacity_sparse(capacity: usize) -> Self {
184        Self {
185            len: 0,
186            fields: Default::default(),
187            type_id_builder: Int8BufferBuilder::new(capacity),
188            value_offset_builder: None,
189            initial_capacity: capacity,
190        }
191    }
192
193    /// Appends a null to this builder, encoding the null in the array
194    /// of the `type_name` child / field.
195    ///
196    /// Since `UnionArray` encodes nulls as an entry in its children
197    /// (it doesn't have a validity bitmap itself), and where the null
198    /// is part of the final array, appending a NULL requires
199    /// specifying which field (child) to use.
200    #[inline]
201    pub fn append_null<T: ArrowPrimitiveType>(
202        &mut self,
203        type_name: &str,
204    ) -> Result<(), ArrowError> {
205        self.append_option::<T>(type_name, None)
206    }
207
208    /// Appends a value to this builder.
209    #[inline]
210    pub fn append<T: ArrowPrimitiveType>(
211        &mut self,
212        type_name: &str,
213        v: T::Native,
214    ) -> Result<(), ArrowError> {
215        self.append_option::<T>(type_name, Some(v))
216    }
217
218    fn append_option<T: ArrowPrimitiveType>(
219        &mut self,
220        type_name: &str,
221        v: Option<T::Native>,
222    ) -> Result<(), ArrowError> {
223        let type_name = type_name.to_string();
224
225        let mut field_data = match self.fields.remove(&type_name) {
226            Some(data) => {
227                if data.data_type != T::DATA_TYPE {
228                    return Err(ArrowError::InvalidArgumentError(format!(
229                        "Attempt to write col \"{}\" with type {} doesn't match existing type {}",
230                        type_name,
231                        T::DATA_TYPE,
232                        data.data_type
233                    )));
234                }
235                data
236            }
237            None => match self.value_offset_builder {
238                Some(_) => FieldData::new::<T>(
239                    self.fields.len() as i8,
240                    T::DATA_TYPE,
241                    self.initial_capacity,
242                ),
243                // In the case of a sparse union, we should pass the maximum of the currently length and the capacity.
244                None => {
245                    let mut fd = FieldData::new::<T>(
246                        self.fields.len() as i8,
247                        T::DATA_TYPE,
248                        self.len.max(self.initial_capacity),
249                    );
250                    for _ in 0..self.len {
251                        fd.append_null();
252                    }
253                    fd
254                }
255            },
256        };
257        self.type_id_builder.append(field_data.type_id);
258
259        match &mut self.value_offset_builder {
260            // Dense Union
261            Some(offset_builder) => {
262                offset_builder.append(field_data.slots as i32);
263            }
264            // Sparse Union
265            None => {
266                for (_, fd) in self.fields.iter_mut() {
267                    // Append to all bar the FieldData currently being appended to
268                    fd.append_null();
269                }
270            }
271        }
272
273        match v {
274            Some(v) => field_data.append_value::<T>(v),
275            None => field_data.append_null(),
276        }
277
278        self.fields.insert(type_name, field_data);
279        self.len += 1;
280        Ok(())
281    }
282
283    /// Builds this builder creating a new `UnionArray`.
284    pub fn build(self) -> Result<UnionArray, ArrowError> {
285        let mut children = Vec::with_capacity(self.fields.len());
286        let union_fields = self
287            .fields
288            .into_iter()
289            .map(
290                |(
291                    name,
292                    FieldData {
293                        type_id,
294                        data_type,
295                        mut values_buffer,
296                        slots,
297                        mut null_buffer_builder,
298                    },
299                )| {
300                    let array_ref = make_array(unsafe {
301                        ArrayDataBuilder::new(data_type.clone())
302                            .add_buffer(values_buffer.finish())
303                            .len(slots)
304                            .nulls(null_buffer_builder.finish())
305                            .build_unchecked()
306                    });
307                    children.push(array_ref);
308                    (type_id, Arc::new(Field::new(name, data_type, false)))
309                },
310            )
311            .collect();
312        UnionArray::try_new(
313            union_fields,
314            self.type_id_builder.into(),
315            self.value_offset_builder.map(Into::into),
316            children,
317        )
318    }
319
320    /// Builds this builder creating a new `UnionArray` without consuming the builder.
321    ///
322    /// This is used for the `finish_cloned` implementation in `ArrayBuilder`.
323    fn build_cloned(&self) -> Result<UnionArray, ArrowError> {
324        let mut children = Vec::with_capacity(self.fields.len());
325        let union_fields: Vec<_> = self
326            .fields
327            .iter()
328            .map(|(name, field_data)| {
329                let FieldData {
330                    type_id,
331                    data_type,
332                    values_buffer,
333                    slots,
334                    null_buffer_builder,
335                } = field_data;
336
337                let array_ref = make_array(unsafe {
338                    ArrayDataBuilder::new(data_type.clone())
339                        .add_buffer(values_buffer.finish_cloned())
340                        .len(*slots)
341                        .nulls(null_buffer_builder.finish_cloned())
342                        .build_unchecked()
343                });
344                children.push(array_ref);
345                (
346                    *type_id,
347                    Arc::new(Field::new(name.clone(), data_type.clone(), false)),
348                )
349            })
350            .collect();
351        UnionArray::try_new(
352            union_fields.into_iter().collect(),
353            ScalarBuffer::from(self.type_id_builder.as_slice().to_vec()),
354            self.value_offset_builder
355                .as_ref()
356                .map(|builder| ScalarBuffer::from(builder.as_slice().to_vec())),
357            children,
358        )
359    }
360}
361
362impl ArrayBuilder for UnionBuilder {
363    /// Returns the number of array slots in the builder
364    fn len(&self) -> usize {
365        self.len
366    }
367
368    /// Builds the array
369    fn finish(&mut self) -> ArrayRef {
370        // Even simpler - just move the builder using mem::take and replace with default
371        let builder = std::mem::take(self);
372
373        // Since UnionBuilder controls all invariants, this should never fail
374        Arc::new(builder.build().unwrap())
375    }
376
377    /// Builds the array without resetting the underlying builder
378    fn finish_cloned(&self) -> ArrayRef {
379        // We construct the UnionArray carefully to ensure try_new cannot fail.
380        // Since UnionBuilder controls all the invariants, this should never panic.
381        Arc::new(self.build_cloned().unwrap_or_else(|err| {
382            panic!("UnionBuilder::build_cloned failed unexpectedly: {}", err)
383        }))
384    }
385
386    /// Returns the builder as a non-mutable `Any` reference
387    fn as_any(&self) -> &dyn Any {
388        self
389    }
390
391    /// Returns the builder as a mutable `Any` reference
392    fn as_any_mut(&mut self) -> &mut dyn Any {
393        self
394    }
395
396    /// Returns the boxed builder as a box of `Any`
397    fn into_box_any(self: Box<Self>) -> Box<dyn Any> {
398        self
399    }
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array::Array;
    use crate::cast::AsArray;
    use crate::types::{Float64Type, Int32Type};

    /// Downcasts a finished builder output to the concrete `UnionArray`.
    fn downcast_union(array: &ArrayRef) -> &UnionArray {
        array.as_any().downcast_ref::<UnionArray>().unwrap()
    }

    #[test]
    fn test_union_builder_array_builder_trait() {
        // Drive a dense UnionBuilder through the ArrayBuilder trait methods
        let mut builder = UnionBuilder::new_dense();
        builder.append::<Int32Type>("a", 1).unwrap();
        builder.append::<Float64Type>("b", 3.0).unwrap();
        builder.append::<Int32Type>("a", 4).unwrap();
        assert_eq!(builder.len(), 3);

        // finish_cloned takes a snapshot without disturbing the builder
        let snapshot = builder.finish_cloned();
        assert_eq!(snapshot.len(), 3);
        let union = downcast_union(&snapshot);
        assert_eq!(union.type_ids(), &[0, 1, 0]);
        assert_eq!(union.offsets().unwrap().as_ref(), &[0, 0, 1]);
        let ints = union.child(0).as_primitive::<Int32Type>();
        let floats = union.child(1).as_primitive::<Float64Type>();
        assert_eq!(ints.value(0), 1);
        assert_eq!(ints.value(1), 4);
        assert_eq!(floats.value(0), 3.0);

        // The builder keeps accepting values after the snapshot
        builder.append::<Float64Type>("b", 5.0).unwrap();
        assert_eq!(builder.len(), 4);

        // finish hands over everything accumulated so far
        let finished = builder.finish();
        assert_eq!(finished.len(), 4);
        let union = downcast_union(&finished);
        assert_eq!(union.type_ids(), &[0, 1, 0, 1]);
        assert_eq!(union.offsets().unwrap().as_ref(), &[0, 0, 1, 1]);
        let ints = union.child(0).as_primitive::<Int32Type>();
        let floats = union.child(1).as_primitive::<Float64Type>();
        assert_eq!(ints.value(0), 1);
        assert_eq!(ints.value(1), 4);
        assert_eq!(floats.value(0), 3.0);
        assert_eq!(floats.value(1), 5.0);
    }

    #[test]
    fn test_union_builder_type_erased() {
        // Use a sparse UnionBuilder behind Box<dyn ArrayBuilder>
        let mut builders: Vec<Box<dyn ArrayBuilder>> = vec![Box::new(UnionBuilder::new_sparse())];

        {
            let typed = builders[0]
                .as_any_mut()
                .downcast_mut::<UnionBuilder>()
                .unwrap();
            typed.append::<Int32Type>("x", 10).unwrap();
            typed.append::<Float64Type>("y", 20.0).unwrap();
        }
        assert_eq!(builders[0].len(), 2);

        let finished: Vec<ArrayRef> = builders.iter_mut().map(|b| b.finish()).collect();
        assert_eq!(finished[0].len(), 2);

        let union = downcast_union(&finished[0]);
        assert_eq!(union.type_ids(), &[0, 1]);
        // A sparse union carries no offsets buffer
        assert!(union.offsets().is_none());

        // Every sparse child spans all slots; slots owned by the sibling are null
        let ints = union.child(0).as_primitive::<Int32Type>();
        let floats = union.child(1).as_primitive::<Float64Type>();
        assert_eq!(ints.value(0), 10);
        assert!(ints.is_null(1));
        assert!(floats.is_null(0));
        assert_eq!(floats.value(1), 20.0);
    }
}