1use crate::builder::buffer_builder::{Int32BufferBuilder, Int8BufferBuilder};
19use crate::builder::{ArrayBuilder, BufferBuilder};
20use crate::{make_array, ArrayRef, ArrowPrimitiveType, UnionArray};
21use arrow_buffer::NullBufferBuilder;
22use arrow_buffer::{ArrowNativeType, Buffer, ScalarBuffer};
23use arrow_data::ArrayDataBuilder;
24use arrow_schema::{ArrowError, DataType, Field};
25use std::any::Any;
26use std::collections::BTreeMap;
27use std::sync::Arc;
28
29#[derive(Debug)]
31struct FieldData {
32 type_id: i8,
34 data_type: DataType,
36 values_buffer: Box<dyn FieldDataValues>,
38 slots: usize,
40 null_buffer_builder: NullBufferBuilder,
42}
43
44trait FieldDataValues: std::fmt::Debug + Send + Sync {
46 fn as_mut_any(&mut self) -> &mut dyn Any;
47
48 fn append_null(&mut self);
49
50 fn finish(&mut self) -> Buffer;
51
52 fn finish_cloned(&self) -> Buffer;
53}
54
55impl<T: ArrowNativeType> FieldDataValues for BufferBuilder<T> {
56 fn as_mut_any(&mut self) -> &mut dyn Any {
57 self
58 }
59
60 fn append_null(&mut self) {
61 self.advance(1)
62 }
63
64 fn finish(&mut self) -> Buffer {
65 self.finish()
66 }
67
68 fn finish_cloned(&self) -> Buffer {
69 Buffer::from_slice_ref(self.as_slice())
70 }
71}
72
73impl FieldData {
74 fn new<T: ArrowPrimitiveType>(type_id: i8, data_type: DataType, capacity: usize) -> Self {
76 Self {
77 type_id,
78 data_type,
79 slots: 0,
80 values_buffer: Box::new(BufferBuilder::<T::Native>::new(capacity)),
81 null_buffer_builder: NullBufferBuilder::new(capacity),
82 }
83 }
84
85 fn append_value<T: ArrowPrimitiveType>(&mut self, v: T::Native) {
87 self.values_buffer
88 .as_mut_any()
89 .downcast_mut::<BufferBuilder<T::Native>>()
90 .expect("Tried to append unexpected type")
91 .append(v);
92
93 self.null_buffer_builder.append(true);
94 self.slots += 1;
95 }
96
97 fn append_null(&mut self) {
99 self.values_buffer.append_null();
100 self.null_buffer_builder.append(false);
101 self.slots += 1;
102 }
103}
104
105#[derive(Debug, Default)]
148pub struct UnionBuilder {
149 len: usize,
151 fields: BTreeMap<String, FieldData>,
153 type_id_builder: Int8BufferBuilder,
155 value_offset_builder: Option<Int32BufferBuilder>,
157 initial_capacity: usize,
158}
159
160impl UnionBuilder {
161 pub fn new_dense() -> Self {
163 Self::with_capacity_dense(1024)
164 }
165
166 pub fn new_sparse() -> Self {
168 Self::with_capacity_sparse(1024)
169 }
170
171 pub fn with_capacity_dense(capacity: usize) -> Self {
173 Self {
174 len: 0,
175 fields: Default::default(),
176 type_id_builder: Int8BufferBuilder::new(capacity),
177 value_offset_builder: Some(Int32BufferBuilder::new(capacity)),
178 initial_capacity: capacity,
179 }
180 }
181
182 pub fn with_capacity_sparse(capacity: usize) -> Self {
184 Self {
185 len: 0,
186 fields: Default::default(),
187 type_id_builder: Int8BufferBuilder::new(capacity),
188 value_offset_builder: None,
189 initial_capacity: capacity,
190 }
191 }
192
193 #[inline]
201 pub fn append_null<T: ArrowPrimitiveType>(
202 &mut self,
203 type_name: &str,
204 ) -> Result<(), ArrowError> {
205 self.append_option::<T>(type_name, None)
206 }
207
208 #[inline]
210 pub fn append<T: ArrowPrimitiveType>(
211 &mut self,
212 type_name: &str,
213 v: T::Native,
214 ) -> Result<(), ArrowError> {
215 self.append_option::<T>(type_name, Some(v))
216 }
217
218 fn append_option<T: ArrowPrimitiveType>(
219 &mut self,
220 type_name: &str,
221 v: Option<T::Native>,
222 ) -> Result<(), ArrowError> {
223 let type_name = type_name.to_string();
224
225 let mut field_data = match self.fields.remove(&type_name) {
226 Some(data) => {
227 if data.data_type != T::DATA_TYPE {
228 return Err(ArrowError::InvalidArgumentError(format!(
229 "Attempt to write col \"{}\" with type {} doesn't match existing type {}",
230 type_name,
231 T::DATA_TYPE,
232 data.data_type
233 )));
234 }
235 data
236 }
237 None => match self.value_offset_builder {
238 Some(_) => FieldData::new::<T>(
239 self.fields.len() as i8,
240 T::DATA_TYPE,
241 self.initial_capacity,
242 ),
243 None => {
245 let mut fd = FieldData::new::<T>(
246 self.fields.len() as i8,
247 T::DATA_TYPE,
248 self.len.max(self.initial_capacity),
249 );
250 for _ in 0..self.len {
251 fd.append_null();
252 }
253 fd
254 }
255 },
256 };
257 self.type_id_builder.append(field_data.type_id);
258
259 match &mut self.value_offset_builder {
260 Some(offset_builder) => {
262 offset_builder.append(field_data.slots as i32);
263 }
264 None => {
266 for (_, fd) in self.fields.iter_mut() {
267 fd.append_null();
269 }
270 }
271 }
272
273 match v {
274 Some(v) => field_data.append_value::<T>(v),
275 None => field_data.append_null(),
276 }
277
278 self.fields.insert(type_name, field_data);
279 self.len += 1;
280 Ok(())
281 }
282
283 pub fn build(self) -> Result<UnionArray, ArrowError> {
285 let mut children = Vec::with_capacity(self.fields.len());
286 let union_fields = self
287 .fields
288 .into_iter()
289 .map(
290 |(
291 name,
292 FieldData {
293 type_id,
294 data_type,
295 mut values_buffer,
296 slots,
297 mut null_buffer_builder,
298 },
299 )| {
300 let array_ref = make_array(unsafe {
301 ArrayDataBuilder::new(data_type.clone())
302 .add_buffer(values_buffer.finish())
303 .len(slots)
304 .nulls(null_buffer_builder.finish())
305 .build_unchecked()
306 });
307 children.push(array_ref);
308 (type_id, Arc::new(Field::new(name, data_type, false)))
309 },
310 )
311 .collect();
312 UnionArray::try_new(
313 union_fields,
314 self.type_id_builder.into(),
315 self.value_offset_builder.map(Into::into),
316 children,
317 )
318 }
319
320 fn build_cloned(&self) -> Result<UnionArray, ArrowError> {
324 let mut children = Vec::with_capacity(self.fields.len());
325 let union_fields: Vec<_> = self
326 .fields
327 .iter()
328 .map(|(name, field_data)| {
329 let FieldData {
330 type_id,
331 data_type,
332 values_buffer,
333 slots,
334 null_buffer_builder,
335 } = field_data;
336
337 let array_ref = make_array(unsafe {
338 ArrayDataBuilder::new(data_type.clone())
339 .add_buffer(values_buffer.finish_cloned())
340 .len(*slots)
341 .nulls(null_buffer_builder.finish_cloned())
342 .build_unchecked()
343 });
344 children.push(array_ref);
345 (
346 *type_id,
347 Arc::new(Field::new(name.clone(), data_type.clone(), false)),
348 )
349 })
350 .collect();
351 UnionArray::try_new(
352 union_fields.into_iter().collect(),
353 ScalarBuffer::from(self.type_id_builder.as_slice().to_vec()),
354 self.value_offset_builder
355 .as_ref()
356 .map(|builder| ScalarBuffer::from(builder.as_slice().to_vec())),
357 children,
358 )
359 }
360}
361
362impl ArrayBuilder for UnionBuilder {
363 fn len(&self) -> usize {
365 self.len
366 }
367
368 fn finish(&mut self) -> ArrayRef {
370 let builder = std::mem::take(self);
372
373 Arc::new(builder.build().unwrap())
375 }
376
377 fn finish_cloned(&self) -> ArrayRef {
379 Arc::new(self.build_cloned().unwrap_or_else(|err| {
382 panic!("UnionBuilder::build_cloned failed unexpectedly: {}", err)
383 }))
384 }
385
386 fn as_any(&self) -> &dyn Any {
388 self
389 }
390
391 fn as_any_mut(&mut self) -> &mut dyn Any {
393 self
394 }
395
396 fn into_box_any(self: Box<Self>) -> Box<dyn Any> {
398 self
399 }
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array::Array;
    use crate::cast::AsArray;
    use crate::types::{Float64Type, Int32Type};

    /// Exercises a dense builder through the `ArrayBuilder` trait:
    /// `finish_cloned` must snapshot without consuming, and `finish` must
    /// include values appended afterwards.
    #[test]
    fn test_union_builder_array_builder_trait() {
        let mut dense = UnionBuilder::new_dense();

        dense.append::<Int32Type>("a", 1).unwrap();
        dense.append::<Float64Type>("b", 3.0).unwrap();
        dense.append::<Int32Type>("a", 4).unwrap();
        assert_eq!(dense.len(), 3);

        // Snapshot: the builder stays usable afterwards.
        let snapshot = dense.finish_cloned();
        assert_eq!(snapshot.len(), 3);

        let snapshot_union = snapshot.as_any().downcast_ref::<UnionArray>().unwrap();
        assert_eq!(snapshot_union.type_ids(), &[0, 1, 0]);
        assert_eq!(snapshot_union.offsets().unwrap().as_ref(), &[0, 0, 1]);

        let snapshot_ints = snapshot_union.child(0).as_primitive::<Int32Type>();
        let snapshot_floats = snapshot_union.child(1).as_primitive::<Float64Type>();
        assert_eq!(snapshot_ints.value(0), 1);
        assert_eq!(snapshot_ints.value(1), 4);
        assert_eq!(snapshot_floats.value(0), 3.0);

        // Keep appending after the snapshot.
        dense.append::<Float64Type>("b", 5.0).unwrap();
        assert_eq!(dense.len(), 4);

        let finished = dense.finish();
        assert_eq!(finished.len(), 4);

        let finished_union = finished.as_any().downcast_ref::<UnionArray>().unwrap();
        assert_eq!(finished_union.type_ids(), &[0, 1, 0, 1]);
        assert_eq!(finished_union.offsets().unwrap().as_ref(), &[0, 0, 1, 1]);

        let finished_ints = finished_union.child(0).as_primitive::<Int32Type>();
        let finished_floats = finished_union.child(1).as_primitive::<Float64Type>();
        assert_eq!(finished_ints.value(0), 1);
        assert_eq!(finished_ints.value(1), 4);
        assert_eq!(finished_floats.value(0), 3.0);
        assert_eq!(finished_floats.value(1), 5.0);
    }

    /// Exercises a sparse builder stored as `Box<dyn ArrayBuilder>`.
    #[test]
    fn test_union_builder_type_erased() {
        let mut erased: Vec<Box<dyn ArrayBuilder>> = vec![Box::new(UnionBuilder::new_sparse())];

        let sparse = erased[0]
            .as_any_mut()
            .downcast_mut::<UnionBuilder>()
            .unwrap();
        sparse.append::<Int32Type>("x", 10).unwrap();
        sparse.append::<Float64Type>("y", 20.0).unwrap();

        assert_eq!(erased[0].len(), 2);

        let arrays = erased
            .into_iter()
            .map(|mut b| b.finish())
            .collect::<Vec<_>>();
        assert_eq!(arrays[0].len(), 2);

        let sparse_union = arrays[0].as_any().downcast_ref::<UnionArray>().unwrap();
        assert_eq!(sparse_union.type_ids(), &[0, 1]);
        assert!(sparse_union.offsets().is_none());

        let ints = sparse_union.child(0).as_primitive::<Int32Type>();
        let floats = sparse_union.child(1).as_primitive::<Float64Type>();
        assert_eq!(ints.value(0), 10);
        assert!(ints.is_null(1));
        assert!(floats.is_null(0));
        assert_eq!(floats.value(1), 20.0);
    }
}