1use crate::builder::{ArrayBuilder, PrimitiveBuilder};
19use crate::types::ArrowDictionaryKeyType;
20use crate::{
21 Array, ArrayRef, ArrowPrimitiveType, DictionaryArray, PrimitiveArray, TypedDictionaryArray,
22};
23use arrow_buffer::{ArrowNativeType, ToByteSlice};
24use arrow_schema::{ArrowError, DataType};
25use num::NumCast;
26use std::any::Any;
27use std::collections::HashMap;
28use std::sync::Arc;
29
30#[derive(Debug)]
34struct Value<T>(T);
35
36impl<T: ToByteSlice> std::hash::Hash for Value<T> {
37 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
38 self.0.to_byte_slice().hash(state)
39 }
40}
41
42impl<T: ToByteSlice> PartialEq for Value<T> {
43 fn eq(&self, other: &Self) -> bool {
44 self.0.to_byte_slice().eq(other.0.to_byte_slice())
45 }
46}
47
48impl<T: ToByteSlice> Eq for Value<T> {}
49
50#[derive(Debug)]
83pub struct PrimitiveDictionaryBuilder<K, V>
84where
85 K: ArrowPrimitiveType,
86 V: ArrowPrimitiveType,
87{
88 keys_builder: PrimitiveBuilder<K>,
89 values_builder: PrimitiveBuilder<V>,
90 map: HashMap<Value<V::Native>, usize>,
91}
92
93impl<K, V> Default for PrimitiveDictionaryBuilder<K, V>
94where
95 K: ArrowPrimitiveType,
96 V: ArrowPrimitiveType,
97{
98 fn default() -> Self {
99 Self::new()
100 }
101}
102
103impl<K, V> PrimitiveDictionaryBuilder<K, V>
104where
105 K: ArrowPrimitiveType,
106 V: ArrowPrimitiveType,
107{
108 pub fn new() -> Self {
110 Self {
111 keys_builder: PrimitiveBuilder::new(),
112 values_builder: PrimitiveBuilder::new(),
113 map: HashMap::new(),
114 }
115 }
116
117 pub fn new_from_empty_builders(
123 keys_builder: PrimitiveBuilder<K>,
124 values_builder: PrimitiveBuilder<V>,
125 ) -> Self {
126 assert!(
127 keys_builder.is_empty() && values_builder.is_empty(),
128 "keys and values builders must be empty"
129 );
130 let values_capacity = values_builder.capacity();
131 Self {
132 keys_builder,
133 values_builder,
134 map: HashMap::with_capacity(values_capacity),
135 }
136 }
137
138 pub unsafe fn new_from_builders(
144 keys_builder: PrimitiveBuilder<K>,
145 values_builder: PrimitiveBuilder<V>,
146 ) -> Self {
147 let keys = keys_builder.values_slice();
148 let values = values_builder.values_slice();
149 let mut map = HashMap::with_capacity(values.len());
150
151 keys.iter().zip(values.iter()).for_each(|(key, value)| {
152 map.insert(Value(*value), K::Native::to_usize(*key).unwrap());
153 });
154
155 Self {
156 keys_builder,
157 values_builder,
158 map,
159 }
160 }
161
162 pub fn with_capacity(keys_capacity: usize, values_capacity: usize) -> Self {
167 Self {
168 keys_builder: PrimitiveBuilder::with_capacity(keys_capacity),
169 values_builder: PrimitiveBuilder::with_capacity(values_capacity),
170 map: HashMap::with_capacity(values_capacity),
171 }
172 }
173
174 pub fn try_new_from_builder<K2>(
201 mut source: PrimitiveDictionaryBuilder<K2, V>,
202 ) -> Result<Self, ArrowError>
203 where
204 K::Native: NumCast,
205 K2: ArrowDictionaryKeyType,
206 K2::Native: NumCast,
207 {
208 let map = source.map;
209 let values_builder = source.values_builder;
210
211 let source_keys = source.keys_builder.finish();
212 let new_keys: PrimitiveArray<K> = source_keys.try_unary(|value| {
213 num::cast::cast::<K2::Native, K::Native>(value).ok_or_else(|| {
214 ArrowError::CastError(format!(
215 "Can't cast dictionary keys from source type {:?} to type {:?}",
216 K2::DATA_TYPE,
217 K::DATA_TYPE
218 ))
219 })
220 })?;
221
222 drop(source_keys);
226
227 Ok(Self {
228 map,
229 keys_builder: new_keys
230 .into_builder()
231 .expect("underlying buffer has no references"),
232 values_builder,
233 })
234 }
235}
236
237impl<K, V> ArrayBuilder for PrimitiveDictionaryBuilder<K, V>
238where
239 K: ArrowDictionaryKeyType,
240 V: ArrowPrimitiveType,
241{
242 fn as_any(&self) -> &dyn Any {
244 self
245 }
246
247 fn as_any_mut(&mut self) -> &mut dyn Any {
249 self
250 }
251
252 fn into_box_any(self: Box<Self>) -> Box<dyn Any> {
254 self
255 }
256
257 fn len(&self) -> usize {
259 self.keys_builder.len()
260 }
261
262 fn finish(&mut self) -> ArrayRef {
264 Arc::new(self.finish())
265 }
266
267 fn finish_cloned(&self) -> ArrayRef {
269 Arc::new(self.finish_cloned())
270 }
271}
272
273impl<K, V> PrimitiveDictionaryBuilder<K, V>
274where
275 K: ArrowDictionaryKeyType,
276 V: ArrowPrimitiveType,
277{
278 #[inline]
279 fn get_or_insert_key(&mut self, value: V::Native) -> Result<K::Native, ArrowError> {
280 match self.map.get(&Value(value)) {
281 Some(&key) => {
282 Ok(K::Native::from_usize(key).ok_or(ArrowError::DictionaryKeyOverflowError)?)
283 }
284 None => {
285 let key = self.values_builder.len();
286 self.values_builder.append_value(value);
287 self.map.insert(Value(value), key);
288 Ok(K::Native::from_usize(key).ok_or(ArrowError::DictionaryKeyOverflowError)?)
289 }
290 }
291 }
292
293 #[inline]
297 pub fn append(&mut self, value: V::Native) -> Result<K::Native, ArrowError> {
298 let key = self.get_or_insert_key(value)?;
299 self.keys_builder.append_value(key);
300 Ok(key)
301 }
302
303 pub fn append_n(&mut self, value: V::Native, count: usize) -> Result<K::Native, ArrowError> {
308 let key = self.get_or_insert_key(value)?;
309 self.keys_builder.append_value_n(key, count);
310 Ok(key)
311 }
312
313 #[inline]
319 pub fn append_value(&mut self, value: V::Native) {
320 self.append(value).expect("dictionary key overflow");
321 }
322
323 pub fn append_values(&mut self, value: V::Native, count: usize) {
330 self.append_n(value, count)
331 .expect("dictionary key overflow");
332 }
333
334 #[inline]
336 pub fn append_null(&mut self) {
337 self.keys_builder.append_null()
338 }
339
340 #[inline]
342 pub fn append_nulls(&mut self, n: usize) {
343 self.keys_builder.append_nulls(n)
344 }
345
346 #[inline]
352 pub fn append_option(&mut self, value: Option<V::Native>) {
353 match value {
354 None => self.append_null(),
355 Some(v) => self.append_value(v),
356 };
357 }
358
359 pub fn append_options(&mut self, value: Option<V::Native>, count: usize) {
366 match value {
367 None => self.keys_builder.append_nulls(count),
368 Some(v) => self.append_values(v, count),
369 };
370 }
371
372 pub fn extend_dictionary(
380 &mut self,
381 dictionary: &TypedDictionaryArray<K, PrimitiveArray<V>>,
382 ) -> Result<(), ArrowError> {
383 let values = dictionary.values();
384
385 let v_len = values.len();
386 let k_len = dictionary.keys().len();
387 if v_len == 0 && k_len == 0 {
388 return Ok(());
389 }
390
391 if v_len == 0 {
393 self.append_nulls(k_len);
394 return Ok(());
395 }
396
397 if k_len == 0 {
398 return Err(ArrowError::InvalidArgumentError(
399 "Dictionary keys should not be empty when values are not empty".to_string(),
400 ));
401 }
402
403 let mapped_values = values
405 .iter()
406 .map(|dict_value| {
408 dict_value
409 .map(|dict_value| self.get_or_insert_key(dict_value))
410 .transpose()
411 })
412 .collect::<Result<Vec<_>, _>>()?;
413
414 dictionary.keys().iter().for_each(|key| match key {
416 None => self.append_null(),
417 Some(original_dict_index) => {
418 let index = original_dict_index.as_usize().min(v_len - 1);
419 match mapped_values[index] {
420 None => self.append_null(),
421 Some(mapped_value) => self.keys_builder.append_value(mapped_value),
422 }
423 }
424 });
425
426 Ok(())
427 }
428
429 pub fn finish(&mut self) -> DictionaryArray<K> {
431 self.map.clear();
432 let values = self.values_builder.finish();
433 let keys = self.keys_builder.finish();
434
435 let data_type =
436 DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(values.data_type().clone()));
437
438 let builder = keys
439 .into_data()
440 .into_builder()
441 .data_type(data_type)
442 .child_data(vec![values.into_data()]);
443
444 DictionaryArray::from(unsafe { builder.build_unchecked() })
445 }
446
447 pub fn finish_cloned(&self) -> DictionaryArray<K> {
449 let values = self.values_builder.finish_cloned();
450 let keys = self.keys_builder.finish_cloned();
451
452 let data_type = DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(V::DATA_TYPE));
453
454 let builder = keys
455 .into_data()
456 .into_builder()
457 .data_type(data_type)
458 .child_data(vec![values.into_data()]);
459
460 DictionaryArray::from(unsafe { builder.build_unchecked() })
461 }
462
463 pub fn values_slice(&self) -> &[V::Native] {
465 self.values_builder.values_slice()
466 }
467
468 pub fn values_slice_mut(&mut self) -> &mut [V::Native] {
470 self.values_builder.values_slice_mut()
471 }
472
473 pub fn validity_slice(&self) -> Option<&[u8]> {
475 self.keys_builder.validity_slice()
476 }
477}
478
479impl<K: ArrowDictionaryKeyType, P: ArrowPrimitiveType> Extend<Option<P::Native>>
480 for PrimitiveDictionaryBuilder<K, P>
481{
482 #[inline]
483 fn extend<T: IntoIterator<Item = Option<P::Native>>>(&mut self, iter: T) {
484 for v in iter {
485 self.append_option(v)
486 }
487 }
488}
489
#[cfg(test)]
mod tests {
    use super::*;

    use crate::array::{Int32Array, UInt32Array, UInt8Array};
    use crate::builder::Decimal128Builder;
    use crate::cast::AsArray;
    use crate::types::{
        Date32Type, Decimal128Type, DurationNanosecondType, Float32Type, Float64Type, Int16Type,
        Int32Type, Int64Type, Int8Type, TimestampNanosecondType, UInt16Type, UInt32Type,
        UInt64Type, UInt8Type,
    };

    /// Two distinct values around a null produce keys 0 and 1 and a null key.
    #[test]
    fn test_primitive_dictionary_builder() {
        let mut builder = PrimitiveDictionaryBuilder::<UInt8Type, UInt32Type>::with_capacity(3, 2);
        builder.append(12345678).unwrap();
        builder.append_null();
        builder.append(22345678).unwrap();
        let array = builder.finish();

        let expected_keys = UInt8Array::from(vec![Some(0), None, Some(1)]);
        assert_eq!(array.keys(), &expected_keys);

        // The values live in the dictionary child array.
        let dictionary_values = array.values();
        let typed_values: &UInt32Array = dictionary_values
            .as_any()
            .downcast_ref::<UInt32Array>()
            .unwrap();
        let raw_values: &[u32] = typed_values.values();

        assert!(!array.is_null(0));
        assert!(array.is_null(1));
        assert!(!array.is_null(2));

        assert_eq!(raw_values, &[12345678, 22345678]);
    }

    /// `Extend` deduplicates across separate `extend` calls.
    #[test]
    fn test_extend() {
        let first_batch = [1, 2, 3, 1, 2, 3, 1, 2, 3];
        let second_batch = [4, 5, 1, 3, 1];

        let mut builder = PrimitiveDictionaryBuilder::<Int32Type, Int32Type>::new();
        builder.extend(first_batch.into_iter().map(Some));
        builder.extend(second_batch.into_iter().map(Some));
        let dict = builder.finish();

        assert_eq!(
            dict.keys().values(),
            &[0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 4, 0, 2, 0]
        );
        assert_eq!(dict.values().len(), 5);
    }

    /// A `u8` key can address 256 distinct values; the 257th must fail.
    #[test]
    #[should_panic(expected = "DictionaryKeyOverflowError")]
    fn test_primitive_dictionary_overflow() {
        let mut builder =
            PrimitiveDictionaryBuilder::<UInt8Type, UInt32Type>::with_capacity(257, 257);
        // 256 distinct values: 1000..=1255.
        for value in 1000..1256 {
            builder.append(value).unwrap();
        }
        // One more distinct value does not fit the key type any more.
        builder.append(1257).unwrap();
    }

    /// The data type configured on the values builder ends up in the array.
    #[test]
    fn test_primitive_dictionary_with_builders() {
        let keys_builder = PrimitiveBuilder::<Int32Type>::new();
        let values_builder = Decimal128Builder::new().with_data_type(DataType::Decimal128(1, 2));
        let mut builder =
            PrimitiveDictionaryBuilder::<Int32Type, Decimal128Type>::new_from_empty_builders(
                keys_builder,
                values_builder,
            );
        let dict_array = builder.finish();

        assert_eq!(dict_array.value_type(), DataType::Decimal128(1, 2));

        let expected_type = DataType::Dictionary(
            Box::new(DataType::Int32),
            Box::new(DataType::Decimal128(1, 2)),
        );
        assert_eq!(dict_array.data_type(), &expected_type);
    }

    /// Extending with another dictionary merges values and keeps nulls.
    #[test]
    fn test_extend_dictionary() {
        let some_dict = {
            let mut source = PrimitiveDictionaryBuilder::<Int32Type, Int32Type>::new();
            source.extend([1, 2, 3, 1, 2, 3, 1, 2, 3].into_iter().map(Some));
            source.extend([None::<i32>]);
            source.extend([4, 5, 1, 3, 1].into_iter().map(Some));
            source.append_null();
            source.finish()
        };

        let mut builder = PrimitiveDictionaryBuilder::<Int32Type, Int32Type>::new();
        builder.extend([6, 6, 7, 6, 5].into_iter().map(Some));
        builder
            .extend_dictionary(&some_dict.downcast_dict().unwrap())
            .unwrap();
        let dict = builder.finish();

        // 6, 7, 5 from the builder itself plus 1, 2, 3, 4 from the source.
        assert_eq!(dict.values().len(), 7);

        let values = dict
            .downcast_dict::<Int32Array>()
            .unwrap()
            .into_iter()
            .collect::<Vec<_>>();

        let expected: Vec<Option<i32>> = [6, 6, 7, 6, 5]
            .into_iter()
            .map(Some)
            .chain([1, 2, 3, 1, 2, 3, 1, 2, 3].into_iter().map(Some))
            .chain([None])
            .chain([4, 5, 1, 3, 1].into_iter().map(Some))
            .chain([None])
            .collect();

        assert_eq!(values, expected);
    }

    /// A key that points at a null dictionary value becomes a null row.
    #[test]
    fn test_extend_dictionary_with_null_in_mapped_value() {
        let some_dict = {
            let mut source_values = PrimitiveBuilder::<Int32Type>::new();
            let mut source_keys = PrimitiveBuilder::<Int32Type>::new();

            // Dictionary entry 0 is null, entry 1 is 42; each is referenced once.
            source_values.append_null();
            source_keys.append_value(0);
            source_values.append_value(42);
            source_keys.append_value(1);

            let values = source_values.finish();
            let keys = source_keys.finish();

            let data_type = DataType::Dictionary(
                Box::new(Int32Type::DATA_TYPE),
                Box::new(values.data_type().clone()),
            );

            let data_builder = keys
                .into_data()
                .into_builder()
                .data_type(data_type)
                .child_data(vec![values.into_data()]);

            DictionaryArray::from(unsafe { data_builder.build_unchecked() })
        };

        let some_dict_values = some_dict.values().as_primitive::<Int32Type>();
        assert_eq!(
            some_dict_values.into_iter().collect::<Vec<_>>(),
            &[None, Some(42)]
        );

        let mut builder = PrimitiveDictionaryBuilder::<Int32Type, Int32Type>::new();
        builder
            .extend_dictionary(&some_dict.downcast_dict().unwrap())
            .unwrap();
        let dict = builder.finish();

        // Only the non-null value is carried over into the new dictionary.
        assert_eq!(dict.values().len(), 1);

        let values = dict
            .downcast_dict::<Int32Array>()
            .unwrap()
            .into_iter()
            .collect::<Vec<_>>();

        assert_eq!(values, [None, Some(42)]);
    }

    /// A source without any dictionary values contributes only null rows.
    #[test]
    fn test_extend_all_null_dictionary() {
        let some_dict = {
            let mut source = PrimitiveDictionaryBuilder::<Int32Type, Int32Type>::new();
            source.append_nulls(2);
            source.finish()
        };

        let mut builder = PrimitiveDictionaryBuilder::<Int32Type, Int32Type>::new();
        builder
            .extend_dictionary(&some_dict.downcast_dict().unwrap())
            .unwrap();
        let dict = builder.finish();

        assert_eq!(dict.values().len(), 0);

        let values = dict
            .downcast_dict::<Int32Array>()
            .unwrap()
            .into_iter()
            .collect::<Vec<_>>();

        assert_eq!(values, [None, None]);
    }

    #[test]
    fn creating_dictionary_from_builders_should_use_values_capacity_for_the_map() {
        let keys_builder =
            PrimitiveBuilder::<Int32Type>::with_capacity(1).with_data_type(DataType::Int32);
        let values_builder =
            PrimitiveBuilder::<crate::types::TimestampMicrosecondType>::with_capacity(2)
                .with_data_type(DataType::Timestamp(
                    arrow_schema::TimeUnit::Microsecond,
                    Some("+08:00".into()),
                ));

        let builder = PrimitiveDictionaryBuilder::<
            Int32Type,
            crate::types::TimestampMicrosecondType,
        >::new_from_empty_builders(keys_builder, values_builder);

        assert!(
            builder.map.capacity() >= builder.values_builder.capacity(),
            "map capacity {} should be at least the values capacity {}",
            builder.map.capacity(),
            builder.values_builder.capacity()
        )
    }

    /// Builds `[v0, null, v1, v2]` with key type `K1`, converts the builder to
    /// key type `K2` and checks that keys and values survive the conversion.
    fn _test_try_new_from_builder_generic_for_key_types<K1, K2, V>(values: Vec<V::Native>)
    where
        K1: ArrowDictionaryKeyType,
        K1::Native: NumCast,
        K2: ArrowDictionaryKeyType,
        K2::Native: NumCast + From<u8>,
        V: ArrowPrimitiveType,
    {
        let mut source = PrimitiveDictionaryBuilder::<K1, V>::new();
        source.append(values[0]).unwrap();
        source.append_null();
        source.append(values[1]).unwrap();
        source.append(values[2]).unwrap();

        let mut result = PrimitiveDictionaryBuilder::<K2, V>::try_new_from_builder(source).unwrap();
        let array = result.finish();

        // `From` is spelled out because `NumCast` also provides a `from`.
        let key = |n: u8| <K2::Native as From<u8>>::from(n);

        let mut expected_keys_builder = PrimitiveBuilder::<K2>::new();
        expected_keys_builder.append_value(key(0u8));
        expected_keys_builder.append_null();
        expected_keys_builder.append_value(key(1u8));
        expected_keys_builder.append_value(key(2u8));
        let expected_keys = expected_keys_builder.finish();
        assert_eq!(array.keys(), &expected_keys);

        let dictionary_values = array.values();
        let typed_values = dictionary_values
            .as_any()
            .downcast_ref::<PrimitiveArray<V>>()
            .unwrap();
        for (index, expected) in values.iter().enumerate().take(3) {
            assert_eq!(typed_values.value(index), *expected);
        }
    }

    /// Runs the key conversion check for a range of key type combinations.
    fn _test_try_new_from_builder_generic_for_value<T>(values: Vec<T::Native>)
    where
        T: ArrowPrimitiveType,
    {
        // Unsigned keys, widening.
        _test_try_new_from_builder_generic_for_key_types::<UInt8Type, UInt16Type, T>(
            values.clone(),
        );
        // Unsigned keys, narrowing.
        _test_try_new_from_builder_generic_for_key_types::<UInt16Type, UInt8Type, T>(
            values.clone(),
        );
        // Signed keys, widening.
        _test_try_new_from_builder_generic_for_key_types::<Int8Type, Int16Type, T>(values.clone());
        // Signed keys, narrowing.
        _test_try_new_from_builder_generic_for_key_types::<Int32Type, Int16Type, T>(values.clone());
        // Further combinations, including changes of signedness.
        _test_try_new_from_builder_generic_for_key_types::<UInt8Type, Int16Type, T>(values.clone());
        _test_try_new_from_builder_generic_for_key_types::<Int8Type, UInt8Type, T>(values.clone());
        _test_try_new_from_builder_generic_for_key_types::<Int8Type, UInt16Type, T>(values.clone());
        _test_try_new_from_builder_generic_for_key_types::<Int32Type, Int16Type, T>(values.clone());
    }

    #[test]
    fn test_try_new_from_builder() {
        // Unsigned integer values.
        _test_try_new_from_builder_generic_for_value::<UInt8Type>(vec![1, 2, 3]);
        _test_try_new_from_builder_generic_for_value::<UInt16Type>(vec![1, 2, 3]);
        _test_try_new_from_builder_generic_for_value::<UInt32Type>(vec![1, 2, 3]);
        _test_try_new_from_builder_generic_for_value::<UInt64Type>(vec![1, 2, 3]);
        // Signed integer values.
        _test_try_new_from_builder_generic_for_value::<Int8Type>(vec![-1, 0, 1]);
        _test_try_new_from_builder_generic_for_value::<Int16Type>(vec![-1, 0, 1]);
        _test_try_new_from_builder_generic_for_value::<Int32Type>(vec![-1, 0, 1]);
        _test_try_new_from_builder_generic_for_value::<Int64Type>(vec![-1, 0, 1]);
        // Temporal values.
        _test_try_new_from_builder_generic_for_value::<Date32Type>(vec![5, 6, 7]);
        _test_try_new_from_builder_generic_for_value::<DurationNanosecondType>(vec![1, 2, 3]);
        _test_try_new_from_builder_generic_for_value::<TimestampNanosecondType>(vec![1, 2, 3]);
        // Floating point values.
        _test_try_new_from_builder_generic_for_value::<Float32Type>(vec![0.1, 0.2, 0.3]);
        _test_try_new_from_builder_generic_for_value::<Float64Type>(vec![-0.1, 0.2, 0.3]);
    }

    /// 257 distinct values need key 256, which does not fit into a `u8` key.
    #[test]
    fn test_try_new_from_builder_cast_fails() {
        let mut source_builder = PrimitiveDictionaryBuilder::<UInt16Type, UInt64Type>::new();
        for value in 0..257 {
            source_builder.append_value(value);
        }

        let result = PrimitiveDictionaryBuilder::<UInt8Type, UInt64Type>::try_new_from_builder(
            source_builder,
        );
        assert!(result.is_err());
        if let Err(error) = result {
            assert!(matches!(error, ArrowError::CastError(_)));
            assert_eq!(
                error.to_string(),
                "Cast error: Can't cast dictionary keys from source type UInt16 to type UInt8"
            );
        }
    }
}