1use crate::interleave::interleave;
19use ahash::RandomState;
20use arrow_array::builder::BooleanBufferBuilder;
21use arrow_array::cast::AsArray;
22use arrow_array::types::{
23 ArrowDictionaryKeyType, BinaryType, ByteArrayType, LargeBinaryType, LargeUtf8Type, Utf8Type,
24};
25use arrow_array::{Array, ArrayRef, DictionaryArray, GenericByteArray};
26use arrow_buffer::{ArrowNativeType, BooleanBuffer, ScalarBuffer};
27use arrow_schema::{ArrowError, DataType};
28
/// A best-effort interner backed by a fixed number of hash-indexed buckets.
///
/// Each bucket stores at most one `(key, value)` pair. Interning a key whose
/// hash selects a bucket that already holds a *different* key overwrites that
/// bucket, so callers must tolerate a key being assigned a fresh value after
/// it has been evicted.
struct Interner<'a, V> {
    /// Hasher used to assign keys to buckets
    state: RandomState,
    /// Bucket storage; `None` marks a bucket that has not been written yet
    buckets: Vec<Option<InternerBucket<'a, V>>>,
    /// Number of bits a 64-bit hash is shifted right by to obtain a bucket index
    shift: u32,
}
38
/// The contents of one occupied [`Interner`] bucket: the interned key
/// (`None` is itself a valid key) paired with the value stored for it.
type InternerBucket<'a, V> = (Option<&'a [u8]>, V);
41
42impl<'a, V> Interner<'a, V> {
43 fn new(capacity: usize) -> Self {
48 let shift = (capacity as u64 + 128).leading_zeros();
50 let num_buckets = (u64::MAX >> shift) as usize;
51 let buckets = (0..num_buckets.saturating_add(1)).map(|_| None).collect();
52 Self {
53 state: RandomState::with_seeds(0, 0, 0, 0),
55 buckets,
56 shift,
57 }
58 }
59
60 fn intern<F: FnOnce() -> Result<V, E>, E>(
61 &mut self,
62 new: Option<&'a [u8]>,
63 f: F,
64 ) -> Result<&V, E> {
65 let hash = self.state.hash_one(new);
66 let bucket_idx = hash >> self.shift;
67 Ok(match &mut self.buckets[bucket_idx as usize] {
68 Some((current, v)) => {
69 if *current != new {
70 *v = f()?;
71 *current = new;
72 }
73 v
74 }
75 slot => &slot.insert((new, f()?)).1,
76 })
77 }
78}
79
/// The result of merging the values of several dictionaries, as produced by
/// [`merge_dictionary_values`].
pub struct MergedDictionaries<K: ArrowDictionaryKeyType> {
    /// One mapping per input dictionary, indexed by that dictionary's original
    /// value index and yielding the corresponding key into `values`.
    ///
    /// Entries for values that were not retained during the merge are left as zero.
    pub key_mappings: Vec<Vec<K::Native>>,
    /// The merged values array that the remapped keys refer to
    pub values: ArrayRef,
}
86
87fn bytes_ptr_eq<T: ByteArrayType>(a: &dyn Array, b: &dyn Array) -> bool {
91 match (a.as_bytes_opt::<T>(), b.as_bytes_opt::<T>()) {
92 (Some(a), Some(b)) => {
93 let values_eq = a.values().ptr_eq(b.values()) && a.offsets().ptr_eq(b.offsets());
94 match (a.nulls(), b.nulls()) {
95 (Some(a), Some(b)) => values_eq && a.inner().ptr_eq(b.inner()),
96 (None, None) => values_eq,
97 _ => false,
98 }
99 }
100 _ => false,
101 }
102}
103
/// Signature of a predicate that decides whether two arrays share the same
/// underlying storage; used to hold a type-specific instantiation of
/// [`bytes_ptr_eq`] behind a `Box`.
type PtrEq = dyn Fn(&dyn Array, &dyn Array) -> bool;
106
107pub fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
113 dictionaries: &[&DictionaryArray<K>],
114 len: usize,
115) -> bool {
116 use DataType::*;
117 let first_values = dictionaries[0].values().as_ref();
118 let ptr_eq: Box<PtrEq> = match first_values.data_type() {
119 Utf8 => Box::new(bytes_ptr_eq::<Utf8Type>),
120 LargeUtf8 => Box::new(bytes_ptr_eq::<LargeUtf8Type>),
121 Binary => Box::new(bytes_ptr_eq::<BinaryType>),
122 LargeBinary => Box::new(bytes_ptr_eq::<LargeBinaryType>),
123 _ => return false,
124 };
125
126 let mut single_dictionary = true;
127 let mut total_values = first_values.len();
128 for dict in dictionaries.iter().skip(1) {
129 let values = dict.values().as_ref();
130 total_values += values.len();
131 if single_dictionary {
132 single_dictionary = ptr_eq(first_values, values)
133 }
134 }
135
136 let overflow = K::Native::from_usize(total_values).is_none();
137 let values_exceed_length = total_values >= len;
138
139 !single_dictionary && (overflow || values_exceed_length)
140}
141
142pub fn merge_dictionary_values<K: ArrowDictionaryKeyType>(
151 dictionaries: &[&DictionaryArray<K>],
152 masks: Option<&[BooleanBuffer]>,
153) -> Result<MergedDictionaries<K>, ArrowError> {
154 let mut num_values = 0;
155
156 let mut values_arrays = Vec::with_capacity(dictionaries.len());
157 let mut value_slices = Vec::with_capacity(dictionaries.len());
158
159 for (idx, dictionary) in dictionaries.iter().enumerate() {
160 let mask = masks.and_then(|m| m.get(idx));
161 let key_mask_owned;
162 let key_mask = match (dictionary.nulls(), mask) {
163 (Some(n), None) => Some(n.inner()),
164 (None, Some(n)) => Some(n),
165 (Some(n), Some(m)) => {
166 key_mask_owned = n.inner() & m;
167 Some(&key_mask_owned)
168 }
169 (None, None) => None,
170 };
171 let keys = dictionary.keys().values();
172 let values = dictionary.values().as_ref();
173 let values_mask = compute_values_mask(keys, key_mask, values.len());
174
175 let masked_values = get_masked_values(values, &values_mask);
176 num_values += masked_values.len();
177 value_slices.push(masked_values);
178 values_arrays.push(values)
179 }
180
181 let mut interner = Interner::new(num_values);
183 let mut indices = Vec::with_capacity(num_values);
185
186 let key_mappings = dictionaries
188 .iter()
189 .enumerate()
190 .zip(value_slices)
191 .map(|((dictionary_idx, dictionary), values)| {
192 let zero = K::Native::from_usize(0).unwrap();
193 let mut mapping = vec![zero; dictionary.values().len()];
194
195 for (value_idx, value) in values {
196 mapping[value_idx] =
197 *interner.intern(value, || match K::Native::from_usize(indices.len()) {
198 Some(idx) => {
199 indices.push((dictionary_idx, value_idx));
200 Ok(idx)
201 }
202 None => Err(ArrowError::DictionaryKeyOverflowError),
203 })?;
204 }
205 Ok(mapping)
206 })
207 .collect::<Result<Vec<_>, ArrowError>>()?;
208
209 Ok(MergedDictionaries {
210 key_mappings,
211 values: interleave(&values_arrays, &indices)?,
212 })
213}
214
215fn compute_values_mask<K: ArrowNativeType>(
218 keys: &ScalarBuffer<K>,
219 mask: Option<&BooleanBuffer>,
220 max_key: usize,
221) -> BooleanBuffer {
222 let mut builder = BooleanBufferBuilder::new(max_key);
223 builder.advance(max_key);
224
225 match mask {
226 Some(n) => n
227 .set_indices()
228 .for_each(|idx| builder.set_bit(keys[idx].as_usize(), true)),
229 None => keys
230 .iter()
231 .for_each(|k| builder.set_bit(k.as_usize(), true)),
232 }
233 builder.finish()
234}
235
236fn get_masked_values<'a>(
238 array: &'a dyn Array,
239 mask: &BooleanBuffer,
240) -> Vec<(usize, Option<&'a [u8]>)> {
241 match array.data_type() {
242 DataType::Utf8 => masked_bytes(array.as_string::<i32>(), mask),
243 DataType::LargeUtf8 => masked_bytes(array.as_string::<i64>(), mask),
244 DataType::Binary => masked_bytes(array.as_binary::<i32>(), mask),
245 DataType::LargeBinary => masked_bytes(array.as_binary::<i64>(), mask),
246 _ => unimplemented!(),
247 }
248}
249
250fn masked_bytes<'a, T: ByteArrayType>(
254 array: &'a GenericByteArray<T>,
255 mask: &BooleanBuffer,
256) -> Vec<(usize, Option<&'a [u8]>)> {
257 let mut out = Vec::with_capacity(mask.count_set_bits());
258 for idx in mask.set_indices() {
259 out.push((
260 idx,
261 array.is_valid(idx).then_some(array.value(idx).as_ref()),
262 ))
263 }
264 out
265}
266
#[cfg(test)]
mod tests {
    use crate::dictionary::merge_dictionary_values;
    use arrow_array::cast::as_string_array;
    use arrow_array::types::Int32Type;
    use arrow_array::{DictionaryArray, Int32Array, StringArray};
    use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer, OffsetBuffer};
    use std::sync::Arc;

    /// Merging string dictionaries: whole arrays, a sliced array, and with explicit masks.
    #[test]
    fn test_merge_strings() {
        // Case 1: two complete dictionaries with overlapping values
        let a = DictionaryArray::<Int32Type>::from_iter(["a", "b", "a", "b", "d", "c", "e"]);
        let b = DictionaryArray::<Int32Type>::from_iter(["c", "f", "c", "d", "a", "d"]);
        let merged = merge_dictionary_values(&[&a, &b], None).unwrap();

        let values = as_string_array(merged.values.as_ref());
        let actual: Vec<_> = values.iter().map(Option::unwrap).collect();
        assert_eq!(&actual, &["a", "b", "d", "c", "e", "f"]);

        assert_eq!(merged.key_mappings.len(), 2);
        assert_eq!(&merged.key_mappings[0], &[0, 1, 2, 3, 4]);
        assert_eq!(&merged.key_mappings[1], &[3, 5, 2, 0]);

        // Case 2: a slice of `a` only references "a", "b" and "d", so "c" and "e"
        // from `a` are dropped and their mapping entries stay at zero
        let a_slice = a.slice(1, 4);
        let merged = merge_dictionary_values(&[&a_slice, &b], None).unwrap();

        let values = as_string_array(merged.values.as_ref());
        let actual: Vec<_> = values.iter().map(Option::unwrap).collect();
        assert_eq!(&actual, &["a", "b", "d", "c", "f"]);

        assert_eq!(merged.key_mappings.len(), 2);
        assert_eq!(&merged.key_mappings[0], &[0, 1, 2, 0, 0]);
        assert_eq!(&merged.key_mappings[1], &[3, 4, 2, 0]);

        // Case 3: mask `a` so that only the rows ["b", "b", "d"] are selected
        let a_mask = BooleanBuffer::from_iter([false, true, false, true, true, false, false]);
        let b_mask = BooleanBuffer::new_set(b.len());
        let merged = merge_dictionary_values(&[&a, &b], Some(&[a_mask, b_mask])).unwrap();

        let values = as_string_array(merged.values.as_ref());
        let actual: Vec<_> = values.iter().map(Option::unwrap).collect();
        assert_eq!(&actual, &["b", "d", "c", "f", "a"]);

        assert_eq!(merged.key_mappings.len(), 2);
        assert_eq!(&merged.key_mappings[0], &[0, 0, 1, 0, 0]);
        assert_eq!(&merged.key_mappings[1], &[2, 3, 1, 4]);
    }

    /// Null keys are ignored (even when their key value is out of range) and a
    /// null dictionary value is retained as a value in its own right.
    #[test]
    fn test_merge_nulls() {
        // Values: ["hello", NULL, "bingo", "hello", "world"]
        let buffer = Buffer::from(b"helloworldbingohelloworld");
        let offsets = OffsetBuffer::from_lengths([5, 5, 5, 5, 5]);
        let nulls = NullBuffer::from(vec![true, false, true, true, true]);
        let values = StringArray::new(offsets, buffer, Some(nulls));

        // The key `8` is out of range for the values but sits in a null slot
        let key_values = vec![1, 2, 3, 1, 8, 2, 3];
        let key_nulls = NullBuffer::from(vec![true, true, false, true, false, true, true]);
        let keys = Int32Array::new(key_values.into(), Some(key_nulls));
        // Logical content of `a`: [NULL, "bingo", NULL, NULL, NULL, "bingo", "hello"]
        let a = DictionaryArray::new(keys, Arc::new(values));
        // `b` consists solely of null keys over an empty values array
        let b = DictionaryArray::new(Int32Array::new_null(10), Arc::new(StringArray::new_null(0)));

        let merged = merge_dictionary_values(&[&a, &b], None).unwrap();
        let expected = StringArray::from(vec![None, Some("bingo"), Some("hello")]);
        assert_eq!(merged.values.as_ref(), &expected);
        assert_eq!(merged.key_mappings.len(), 2);
        assert_eq!(&merged.key_mappings[0], &[0, 0, 1, 2, 0]);
        assert_eq!(&merged.key_mappings[1], &[] as &[i32; 0]);
    }

    /// A dictionary with fewer keys than values only keeps the referenced value.
    #[test]
    fn test_merge_keys_smaller() {
        let values = StringArray::from_iter_values(["a", "b"]);
        let keys = Int32Array::from_iter_values([1]);
        let a = DictionaryArray::new(keys, Arc::new(values));

        let merged = merge_dictionary_values(&[&a], None).unwrap();
        let expected = StringArray::from(vec!["b"]);
        assert_eq!(merged.values.as_ref(), &expected);
    }
}