1use crate::interleave::interleave;
19use ahash::RandomState;
20use arrow_array::builder::BooleanBufferBuilder;
21use arrow_array::types::{
22 ArrowDictionaryKeyType, ArrowPrimitiveType, BinaryType, ByteArrayType, LargeBinaryType,
23 LargeUtf8Type, Utf8Type,
24};
25use arrow_array::{cast::AsArray, downcast_primitive};
26use arrow_array::{Array, ArrayRef, DictionaryArray, GenericByteArray, PrimitiveArray};
27use arrow_buffer::{ArrowNativeType, BooleanBuffer, ScalarBuffer, ToByteSlice};
28use arrow_schema::{ArrowError, DataType};
29
/// A best-effort interner over byte slices.
///
/// Values are hashed into a fixed set of buckets. When a different value hashes to
/// an occupied bucket it replaces the previous occupant, so the same value may be
/// interned more than once over the lifetime of the interner.
struct Interner<'a, V> {
    /// Hasher used to assign a value to a bucket
    state: RandomState,
    /// One slot per bucket; a slot is `None` until a value first hashes to it
    buckets: Vec<Option<InternerBucket<'a, V>>>,
    /// Right shift applied to a 64-bit hash to turn it into a bucket index
    shift: u32,
}
39
/// The contents of an occupied [`Interner`] bucket: the interned bytes (`None`
/// stands for a null value) paired with the value stored for them.
type InternerBucket<'a, V> = (Option<&'a [u8]>, V);
42
43impl<'a, V> Interner<'a, V> {
44 fn new(capacity: usize) -> Self {
49 let shift = (capacity as u64 + 128).leading_zeros();
51 let num_buckets = (u64::MAX >> shift) as usize;
52 let buckets = (0..num_buckets.saturating_add(1)).map(|_| None).collect();
53 Self {
54 state: RandomState::with_seeds(0, 0, 0, 0),
56 buckets,
57 shift,
58 }
59 }
60
61 fn intern<F: FnOnce() -> Result<V, E>, E>(
62 &mut self,
63 new: Option<&'a [u8]>,
64 f: F,
65 ) -> Result<&V, E> {
66 let hash = self.state.hash_one(new);
67 let bucket_idx = hash >> self.shift;
68 Ok(match &mut self.buckets[bucket_idx as usize] {
69 Some((current, v)) => {
70 if *current != new {
71 *v = f()?;
72 *current = new;
73 }
74 v
75 }
76 slot => &slot.insert((new, f()?)).1,
77 })
78 }
79}
80
/// The output of [`merge_dictionary_values`].
pub struct MergedDictionaries<K: ArrowDictionaryKeyType> {
    /// `key_mappings[i][old_key]` is the key into `values` that replaces `old_key`
    /// of the `i`-th input dictionary. Entries for values that no considered key
    /// referenced are left at `0`.
    pub key_mappings: Vec<Vec<K::Native>>,
    /// The merged dictionary values
    pub values: ArrayRef,
}
87
88fn bytes_ptr_eq<T: ByteArrayType>(a: &dyn Array, b: &dyn Array) -> bool {
92 match (a.as_bytes_opt::<T>(), b.as_bytes_opt::<T>()) {
93 (Some(a), Some(b)) => {
94 let values_eq = a.values().ptr_eq(b.values()) && a.offsets().ptr_eq(b.offsets());
95 match (a.nulls(), b.nulls()) {
96 (Some(a), Some(b)) => values_eq && a.inner().ptr_eq(b.inner()),
97 (None, None) => values_eq,
98 _ => false,
99 }
100 }
101 _ => false,
102 }
103}
104
/// Type-erased signature of a function that checks two arrays for pointer equality
type PtrEq = fn(&dyn Array, &dyn Array) -> bool;
107
108pub fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
114 dictionaries: &[&DictionaryArray<K>],
115 len: usize,
116) -> bool {
117 use DataType::*;
118 let first_values = dictionaries[0].values().as_ref();
119 let ptr_eq: PtrEq = match first_values.data_type() {
120 Utf8 => bytes_ptr_eq::<Utf8Type>,
121 LargeUtf8 => bytes_ptr_eq::<LargeUtf8Type>,
122 Binary => bytes_ptr_eq::<BinaryType>,
123 LargeBinary => bytes_ptr_eq::<LargeBinaryType>,
124 dt => {
125 if !dt.is_primitive() {
126 return false;
127 }
128 |a, b| a.to_data().ptr_eq(&b.to_data())
129 }
130 };
131
132 let mut single_dictionary = true;
133 let mut total_values = first_values.len();
134 for dict in dictionaries.iter().skip(1) {
135 let values = dict.values().as_ref();
136 total_values += values.len();
137 if single_dictionary {
138 single_dictionary = ptr_eq(first_values, values)
139 }
140 }
141
142 let overflow = K::Native::from_usize(total_values).is_none();
143 let values_exceed_length = total_values >= len;
144
145 !single_dictionary && (overflow || values_exceed_length)
146}
147
148pub fn merge_dictionary_values<K: ArrowDictionaryKeyType>(
157 dictionaries: &[&DictionaryArray<K>],
158 masks: Option<&[BooleanBuffer]>,
159) -> Result<MergedDictionaries<K>, ArrowError> {
160 let mut num_values = 0;
161
162 let mut values_arrays = Vec::with_capacity(dictionaries.len());
163 let mut value_slices = Vec::with_capacity(dictionaries.len());
164
165 for (idx, dictionary) in dictionaries.iter().enumerate() {
166 let mask = masks.and_then(|m| m.get(idx));
167 let key_mask_owned;
168 let key_mask = match (dictionary.nulls(), mask) {
169 (Some(n), None) => Some(n.inner()),
170 (None, Some(n)) => Some(n),
171 (Some(n), Some(m)) => {
172 key_mask_owned = n.inner() & m;
173 Some(&key_mask_owned)
174 }
175 (None, None) => None,
176 };
177 let keys = dictionary.keys().values();
178 let values = dictionary.values().as_ref();
179 let values_mask = compute_values_mask(keys, key_mask, values.len());
180
181 let masked_values = get_masked_values(values, &values_mask);
182 num_values += masked_values.len();
183 value_slices.push(masked_values);
184 values_arrays.push(values)
185 }
186
187 let mut interner = Interner::new(num_values);
189 let mut indices = Vec::with_capacity(num_values);
191
192 let key_mappings = dictionaries
194 .iter()
195 .enumerate()
196 .zip(value_slices)
197 .map(|((dictionary_idx, dictionary), values)| {
198 let zero = K::Native::from_usize(0).unwrap();
199 let mut mapping = vec![zero; dictionary.values().len()];
200
201 for (value_idx, value) in values {
202 mapping[value_idx] =
203 *interner.intern(value, || match K::Native::from_usize(indices.len()) {
204 Some(idx) => {
205 indices.push((dictionary_idx, value_idx));
206 Ok(idx)
207 }
208 None => Err(ArrowError::DictionaryKeyOverflowError),
209 })?;
210 }
211 Ok(mapping)
212 })
213 .collect::<Result<Vec<_>, ArrowError>>()?;
214
215 Ok(MergedDictionaries {
216 key_mappings,
217 values: interleave(&values_arrays, &indices)?,
218 })
219}
220
221fn compute_values_mask<K: ArrowNativeType>(
224 keys: &ScalarBuffer<K>,
225 mask: Option<&BooleanBuffer>,
226 max_key: usize,
227) -> BooleanBuffer {
228 let mut builder = BooleanBufferBuilder::new(max_key);
229 builder.advance(max_key);
230
231 match mask {
232 Some(n) => n
233 .set_indices()
234 .for_each(|idx| builder.set_bit(keys[idx].as_usize(), true)),
235 None => keys
236 .iter()
237 .for_each(|k| builder.set_bit(k.as_usize(), true)),
238 }
239 builder.finish()
240}
241
242fn masked_primitives_to_bytes<'a, T: ArrowPrimitiveType>(
244 array: &'a PrimitiveArray<T>,
245 mask: &BooleanBuffer,
246) -> Vec<(usize, Option<&'a [u8]>)>
247where
248 T::Native: ToByteSlice,
249{
250 let mut out = Vec::with_capacity(mask.count_set_bits());
251 let values = array.values();
252 for idx in mask.set_indices() {
253 out.push((
254 idx,
255 array.is_valid(idx).then_some(values[idx].to_byte_slice()),
256 ))
257 }
258 out
259}
260
/// Adapter in the `(type, args...)` shape that `downcast_primitive!` invokes:
/// treats `$array` as a `PrimitiveArray<$t>` and forwards to
/// [`masked_primitives_to_bytes`].
macro_rules! masked_primitive_to_bytes_helper {
    ($t:ty, $array:expr, $mask:expr) => {
        masked_primitives_to_bytes::<$t>($array.as_primitive(), $mask)
    };
}
266
/// Returns, for every index set in `mask`, that index together with the raw bytes
/// of `array`'s value at that index (`None` when the value is null).
///
/// Dispatches on the data type of `array`: primitive types go through
/// `downcast_primitive!`, and Utf8 / LargeUtf8 / Binary / LargeBinary go through
/// [`masked_bytes`].
///
/// # Panics
///
/// Panics via `unimplemented!` for any other data type.
fn get_masked_values<'a>(
    array: &'a dyn Array,
    mask: &BooleanBuffer,
) -> Vec<(usize, Option<&'a [u8]>)> {
    downcast_primitive! {
        array.data_type() => (masked_primitive_to_bytes_helper, array, mask),
        DataType::Utf8 => masked_bytes(array.as_string::<i32>(), mask),
        DataType::LargeUtf8 => masked_bytes(array.as_string::<i64>(), mask),
        DataType::Binary => masked_bytes(array.as_binary::<i32>(), mask),
        DataType::LargeBinary => masked_bytes(array.as_binary::<i64>(), mask),
        _ => unimplemented!("Dictionary merging for type {} is not implemented", array.data_type()),
    }
}
281
282fn masked_bytes<'a, T: ByteArrayType>(
286 array: &'a GenericByteArray<T>,
287 mask: &BooleanBuffer,
288) -> Vec<(usize, Option<&'a [u8]>)> {
289 let mut out = Vec::with_capacity(mask.count_set_bits());
290 for idx in mask.set_indices() {
291 out.push((
292 idx,
293 array.is_valid(idx).then_some(array.value(idx).as_ref()),
294 ))
295 }
296 out
297}
298
#[cfg(test)]
mod tests {
    use crate::dictionary::merge_dictionary_values;
    use arrow_array::cast::as_string_array;
    use arrow_array::types::Int32Type;
    use arrow_array::{DictionaryArray, Int32Array, StringArray};
    use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer, OffsetBuffer};
    use std::sync::Arc;

    /// Merging string dictionaries: full arrays, a sliced array, and with masks
    #[test]
    fn test_merge_strings() {
        let a = DictionaryArray::<Int32Type>::from_iter(["a", "b", "a", "b", "d", "c", "e"]);
        let b = DictionaryArray::<Int32Type>::from_iter(["c", "f", "c", "d", "a", "d"]);
        let merged = merge_dictionary_values(&[&a, &b], None).unwrap();

        // Values shared between `a` and `b` appear only once in the merged values
        let values = as_string_array(merged.values.as_ref());
        let actual: Vec<_> = values.iter().map(Option::unwrap).collect();
        assert_eq!(&actual, &["a", "b", "d", "c", "e", "f"]);

        assert_eq!(merged.key_mappings.len(), 2);
        assert_eq!(&merged.key_mappings[0], &[0, 1, 2, 3, 4]);
        assert_eq!(&merged.key_mappings[1], &[3, 5, 2, 0]);

        // The slice only references "b", "a", "b", "d", so "c" and "e" from `a`
        // are not merged and their mapping entries keep the placeholder 0
        let a_slice = a.slice(1, 4);
        let merged = merge_dictionary_values(&[&a_slice, &b], None).unwrap();

        let values = as_string_array(merged.values.as_ref());
        let actual: Vec<_> = values.iter().map(Option::unwrap).collect();
        assert_eq!(&actual, &["a", "b", "d", "c", "f"]);

        assert_eq!(merged.key_mappings.len(), 2);
        assert_eq!(&merged.key_mappings[0], &[0, 1, 2, 0, 0]);
        assert_eq!(&merged.key_mappings[1], &[3, 4, 2, 0]);

        // Select only positions 1, 3 and 4 of `a`, i.e. "b", "b", "d"
        let a_mask = BooleanBuffer::from_iter([false, true, false, true, true, false, false]);
        // Select every position of `b`
        let b_mask = BooleanBuffer::new_set(b.len());
        let merged = merge_dictionary_values(&[&a, &b], Some(&[a_mask, b_mask])).unwrap();

        let values = as_string_array(merged.values.as_ref());
        let actual: Vec<_> = values.iter().map(Option::unwrap).collect();
        assert_eq!(&actual, &["b", "d", "c", "f", "a"]);

        assert_eq!(merged.key_mappings.len(), 2);
        assert_eq!(&merged.key_mappings[0], &[0, 0, 1, 0, 0]);
        assert_eq!(&merged.key_mappings[1], &[2, 3, 1, 4]);
    }

    /// Null keys are ignored (even when out of range) and null values are merged
    #[test]
    fn test_merge_nulls() {
        // Five values of five bytes each: "hello", "world", "bingo", "hello", "world",
        // of which the second is null
        let buffer = Buffer::from(b"helloworldbingohelloworld");
        let offsets = OffsetBuffer::from_lengths([5, 5, 5, 5, 5]);
        let nulls = NullBuffer::from(vec![true, false, true, true, true]);
        let values = StringArray::new(offsets, buffer, Some(nulls));

        // The key 8 is out of bounds for the values but sits in a null slot
        let key_values = vec![1, 2, 3, 1, 8, 2, 3];
        let key_nulls = NullBuffer::from(vec![true, true, false, true, false, true, true]);
        let keys = Int32Array::new(key_values.into(), Some(key_nulls));
        let a = DictionaryArray::new(keys, Arc::new(values));
        // All-null keys over an empty values array
        let b = DictionaryArray::new(Int32Array::new_null(10), Arc::new(StringArray::new_null(0)));

        // Only the values at indices 1, 2 and 3 are referenced by non-null keys
        let merged = merge_dictionary_values(&[&a, &b], None).unwrap();
        let expected = StringArray::from(vec![None, Some("bingo"), Some("hello")]);
        assert_eq!(merged.values.as_ref(), &expected);
        assert_eq!(merged.key_mappings.len(), 2);
        assert_eq!(&merged.key_mappings[0], &[0, 0, 1, 2, 0]);
        assert_eq!(&merged.key_mappings[1], &[] as &[i32; 0]);
    }

    /// A dictionary with fewer keys than values only contributes referenced values
    #[test]
    fn test_merge_keys_smaller() {
        let values = StringArray::from_iter_values(["a", "b"]);
        let keys = Int32Array::from_iter_values([1]);
        let a = DictionaryArray::new(keys, Arc::new(values));

        let merged = merge_dictionary_values(&[&a], None).unwrap();
        let expected = StringArray::from(vec!["b"]);
        assert_eq!(merged.values.as_ref(), &expected);
    }
}