arrow_select/
dictionary.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Dictionary utilities for Arrow arrays
19
20use std::sync::Arc;
21
22use crate::filter::filter;
23use crate::interleave::interleave;
24use ahash::RandomState;
25use arrow_array::builder::BooleanBufferBuilder;
26use arrow_array::types::{
27    ArrowDictionaryKeyType, ArrowPrimitiveType, BinaryType, ByteArrayType, LargeBinaryType,
28    LargeUtf8Type, Utf8Type,
29};
30use arrow_array::{cast::AsArray, downcast_primitive};
31use arrow_array::{
32    downcast_dictionary_array, AnyDictionaryArray, Array, ArrayRef, ArrowNativeTypeOp,
33    BooleanArray, DictionaryArray, GenericByteArray, PrimitiveArray,
34};
35use arrow_buffer::{ArrowNativeType, BooleanBuffer, ScalarBuffer, ToByteSlice};
36use arrow_schema::{ArrowError, DataType};
37
38/// Garbage collects a [DictionaryArray] by removing unreferenced values.
39///
40/// Returns a new [DictionaryArray] such that there are no values
41/// that are not referenced by at least one key. There may still be duplicate
42/// values.
43///
44/// See also [`garbage_collect_any_dictionary`] if you need to handle multiple dictionary types
45pub fn garbage_collect_dictionary<K: ArrowDictionaryKeyType>(
46    dictionary: &DictionaryArray<K>,
47) -> Result<DictionaryArray<K>, ArrowError> {
48    let keys = dictionary.keys();
49    let values = dictionary.values();
50
51    let mask = dictionary.occupancy();
52
53    // If no work to do, return the original dictionary
54    if mask.count_set_bits() == values.len() {
55        return Ok(dictionary.clone());
56    }
57
58    // Create a mapping from the old keys to the new keys, use a Vec for easy indexing
59    let mut key_remap = vec![K::Native::ZERO; values.len()];
60    for (new_idx, old_idx) in mask.set_indices().enumerate() {
61        key_remap[old_idx] = K::Native::from_usize(new_idx)
62            .expect("new index should fit in K::Native, as old index was in range");
63    }
64
65    // ... and then build the new keys array
66    let new_keys = keys.unary(|key| {
67        key_remap
68            .get(key.as_usize())
69            .copied()
70            // nulls may be present in the keys, and they will have arbitrary value; we don't care
71            // and can safely return zero
72            .unwrap_or(K::Native::ZERO)
73    });
74
75    // Create a new values array by filtering using the mask
76    let values = filter(dictionary.values(), &BooleanArray::new(mask, None))?;
77
78    Ok(DictionaryArray::new(new_keys, values))
79}
80
81/// Equivalent to [`garbage_collect_dictionary`] but without requiring casting to a specific key type.
82pub fn garbage_collect_any_dictionary(
83    dictionary: &dyn AnyDictionaryArray,
84) -> Result<ArrayRef, ArrowError> {
85    // FIXME: this is a workaround for MSRV Rust versions below 1.86 where trait upcasting is not stable.
86    // From 1.86 onward, `&dyn AnyDictionaryArray` can be directly passed to `downcast_dictionary_array!`.
87    let dictionary = &*dictionary.slice(0, dictionary.len());
88    downcast_dictionary_array!(
89        dictionary => garbage_collect_dictionary(dictionary).map(|dict| Arc::new(dict) as ArrayRef),
90        _ => unreachable!("have a dictionary array")
91    )
92}
93
/// A best effort interner that maintains a fixed number of buckets
/// and interns keys based on their hash value
///
/// Hash collisions will result in replacement: a bucket only remembers the
/// key most recently interned into it, so the same key may be interned more
/// than once if a colliding key was seen in between
struct Interner<'a, V> {
    /// Hasher used to map a key to its bucket
    state: RandomState,
    /// Bucket storage; `None` marks a bucket that has not been populated yet
    buckets: Vec<Option<InternerBucket<'a, V>>>,
    /// Number of bits the 64-bit hash is shifted right by to obtain a bucket index
    shift: u32,
}
103
/// A single bucket in [`Interner`]: the key currently held by the bucket
/// (optional bytes) paired with the value that was interned for that key.
type InternerBucket<'a, V> = (Option<&'a [u8]>, V);
106
107impl<'a, V> Interner<'a, V> {
108    /// Capacity controls the number of unique buckets allocated within the Interner
109    ///
110    /// A larger capacity reduces the probability of hash collisions, and should be set
111    /// based on an approximation of the upper bound of unique values
112    fn new(capacity: usize) -> Self {
113        // Add additional buckets to help reduce collisions
114        let shift = (capacity as u64 + 128).leading_zeros();
115        let num_buckets = (u64::MAX >> shift) as usize;
116        let buckets = (0..num_buckets.saturating_add(1)).map(|_| None).collect();
117        Self {
118            // A fixed seed to ensure deterministic behaviour
119            state: RandomState::with_seeds(0, 0, 0, 0),
120            buckets,
121            shift,
122        }
123    }
124
125    fn intern<F: FnOnce() -> Result<V, E>, E>(
126        &mut self,
127        new: Option<&'a [u8]>,
128        f: F,
129    ) -> Result<&V, E> {
130        let hash = self.state.hash_one(new);
131        let bucket_idx = hash >> self.shift;
132        Ok(match &mut self.buckets[bucket_idx as usize] {
133            Some((current, v)) => {
134                if *current != new {
135                    *v = f()?;
136                    *current = new;
137                }
138                v
139            }
140            slot => &slot.insert((new, f()?)).1,
141        })
142    }
143}
144
/// The result of [`merge_dictionary_values`]: a combined values array together
/// with, for each input dictionary, a table translating its keys into keys of
/// the combined values
pub(crate) struct MergedDictionaries<K: ArrowDictionaryKeyType> {
    /// Provides `key_mappings[array_idx][old_key] -> new_key`
    ///
    /// Each inner `Vec` has one entry per value of the corresponding input dictionary
    pub key_mappings: Vec<Vec<K::Native>>,
    /// The new values
    pub values: ArrayRef,
}
151
152/// Performs a cheap, pointer-based comparison of two byte array
153///
154/// See [`ScalarBuffer::ptr_eq`]
155fn bytes_ptr_eq<T: ByteArrayType>(a: &dyn Array, b: &dyn Array) -> bool {
156    match (a.as_bytes_opt::<T>(), b.as_bytes_opt::<T>()) {
157        (Some(a), Some(b)) => {
158            let values_eq = a.values().ptr_eq(b.values()) && a.offsets().ptr_eq(b.offsets());
159            match (a.nulls(), b.nulls()) {
160                (Some(a), Some(b)) => values_eq && a.inner().ptr_eq(b.inner()),
161                (None, None) => values_eq,
162                _ => false,
163            }
164        }
165        _ => false,
166    }
167}
168
/// A type-erased function that compares two arrays for pointer equality
type PtrEq = fn(&dyn Array, &dyn Array) -> bool;
171
172/// A weak heuristic of whether to merge dictionary values that aims to only
173/// perform the expensive merge computation when it is likely to yield at least
174/// some return over the naive approach used by MutableArrayData
175///
176/// `len` is the total length of the merged output
177pub(crate) fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
178    dictionaries: &[&DictionaryArray<K>],
179    len: usize,
180) -> bool {
181    use DataType::*;
182    let first_values = dictionaries[0].values().as_ref();
183    let ptr_eq: PtrEq = match first_values.data_type() {
184        Utf8 => bytes_ptr_eq::<Utf8Type>,
185        LargeUtf8 => bytes_ptr_eq::<LargeUtf8Type>,
186        Binary => bytes_ptr_eq::<BinaryType>,
187        LargeBinary => bytes_ptr_eq::<LargeBinaryType>,
188        dt => {
189            if !dt.is_primitive() {
190                return false;
191            }
192            |a, b| a.to_data().ptr_eq(&b.to_data())
193        }
194    };
195
196    let mut single_dictionary = true;
197    let mut total_values = first_values.len();
198    for dict in dictionaries.iter().skip(1) {
199        let values = dict.values().as_ref();
200        total_values += values.len();
201        if single_dictionary {
202            single_dictionary = ptr_eq(first_values, values)
203        }
204    }
205
206    let overflow = K::Native::from_usize(total_values).is_none();
207    let values_exceed_length = total_values >= len;
208
209    !single_dictionary && (overflow || values_exceed_length)
210}
211
212/// Given an array of dictionaries and an optional key mask compute a values array
213/// containing referenced values, along with mappings from the [`DictionaryArray`]
214/// keys to the new keys within this values array. Best-effort will be made to ensure
215/// that the dictionary values are unique
216///
217/// This method is meant to be very fast and the output dictionary values
218/// may not be unique, unlike `GenericByteDictionaryBuilder` which is slower
219/// but produces unique values
220pub(crate) fn merge_dictionary_values<K: ArrowDictionaryKeyType>(
221    dictionaries: &[&DictionaryArray<K>],
222    masks: Option<&[BooleanBuffer]>,
223) -> Result<MergedDictionaries<K>, ArrowError> {
224    let mut num_values = 0;
225
226    let mut values_arrays = Vec::with_capacity(dictionaries.len());
227    let mut value_slices = Vec::with_capacity(dictionaries.len());
228
229    for (idx, dictionary) in dictionaries.iter().enumerate() {
230        let mask = masks.and_then(|m| m.get(idx));
231        let key_mask_owned;
232        let key_mask = match (dictionary.nulls(), mask) {
233            (Some(n), None) => Some(n.inner()),
234            (None, Some(n)) => Some(n),
235            (Some(n), Some(m)) => {
236                key_mask_owned = n.inner() & m;
237                Some(&key_mask_owned)
238            }
239            (None, None) => None,
240        };
241        let keys = dictionary.keys().values();
242        let values = dictionary.values().as_ref();
243        let values_mask = compute_values_mask(keys, key_mask, values.len());
244
245        let masked_values = get_masked_values(values, &values_mask);
246        num_values += masked_values.len();
247        value_slices.push(masked_values);
248        values_arrays.push(values)
249    }
250
251    // Map from value to new index
252    let mut interner = Interner::new(num_values);
253    // Interleave indices for new values array
254    let mut indices = Vec::with_capacity(num_values);
255
256    // Compute the mapping for each dictionary
257    let key_mappings = dictionaries
258        .iter()
259        .enumerate()
260        .zip(value_slices)
261        .map(|((dictionary_idx, dictionary), values)| {
262            let zero = K::Native::from_usize(0).unwrap();
263            let mut mapping = vec![zero; dictionary.values().len()];
264
265            for (value_idx, value) in values {
266                mapping[value_idx] =
267                    *interner.intern(value, || match K::Native::from_usize(indices.len()) {
268                        Some(idx) => {
269                            indices.push((dictionary_idx, value_idx));
270                            Ok(idx)
271                        }
272                        None => Err(ArrowError::DictionaryKeyOverflowError),
273                    })?;
274            }
275            Ok(mapping)
276        })
277        .collect::<Result<Vec<_>, ArrowError>>()?;
278
279    Ok(MergedDictionaries {
280        key_mappings,
281        values: interleave(&values_arrays, &indices)?,
282    })
283}
284
285/// Return a mask identifying the values that are referenced by keys in `dictionary`
286/// at the positions indicated by `selection`
287fn compute_values_mask<K: ArrowNativeType>(
288    keys: &ScalarBuffer<K>,
289    mask: Option<&BooleanBuffer>,
290    max_key: usize,
291) -> BooleanBuffer {
292    let mut builder = BooleanBufferBuilder::new(max_key);
293    builder.advance(max_key);
294
295    match mask {
296        Some(n) => n
297            .set_indices()
298            .for_each(|idx| builder.set_bit(keys[idx].as_usize(), true)),
299        None => keys
300            .iter()
301            .for_each(|k| builder.set_bit(k.as_usize(), true)),
302    }
303    builder.finish()
304}
305
306/// Process primitive array values to bytes
307fn masked_primitives_to_bytes<'a, T: ArrowPrimitiveType>(
308    array: &'a PrimitiveArray<T>,
309    mask: &BooleanBuffer,
310) -> Vec<(usize, Option<&'a [u8]>)>
311where
312    T::Native: ToByteSlice,
313{
314    let mut out = Vec::with_capacity(mask.count_set_bits());
315    let values = array.values();
316    for idx in mask.set_indices() {
317        out.push((
318            idx,
319            array.is_valid(idx).then_some(values[idx].to_byte_slice()),
320        ))
321    }
322    out
323}
324
/// Adapter handed to `downcast_primitive!` in [`get_masked_values`]: downcasts
/// `$array` to a primitive array of type `$t` and forwards to
/// [`masked_primitives_to_bytes`]
macro_rules! masked_primitive_to_bytes_helper {
    ($t:ty, $array:expr, $mask:expr) => {
        masked_primitives_to_bytes::<$t>($array.as_primitive(), $mask)
    };
}
330
/// Return a Vec containing for each set index in `mask`, the index and byte value of that index
///
/// Dispatches on the data type of `array`: primitive types go through
/// [`masked_primitives_to_bytes`], string and binary types through [`masked_bytes`].
///
/// # Panics
///
/// Panics (via `unimplemented!`) for any other data type
fn get_masked_values<'a>(
    array: &'a dyn Array,
    mask: &BooleanBuffer,
) -> Vec<(usize, Option<&'a [u8]>)> {
    downcast_primitive! {
        array.data_type() => (masked_primitive_to_bytes_helper, array, mask),
        DataType::Utf8 => masked_bytes(array.as_string::<i32>(), mask),
        DataType::LargeUtf8 => masked_bytes(array.as_string::<i64>(), mask),
        DataType::Binary => masked_bytes(array.as_binary::<i32>(), mask),
        DataType::LargeBinary => masked_bytes(array.as_binary::<i64>(), mask),
        _ => unimplemented!("Dictionary merging for type {} is not implemented", array.data_type()),
    }
}
345
346/// Compute [`get_masked_values`] for a [`GenericByteArray`]
347///
348/// Note: this does not check the null mask and will return values contained in null slots
349fn masked_bytes<'a, T: ByteArrayType>(
350    array: &'a GenericByteArray<T>,
351    mask: &BooleanBuffer,
352) -> Vec<(usize, Option<&'a [u8]>)> {
353    let mut out = Vec::with_capacity(mask.count_set_bits());
354    for idx in mask.set_indices() {
355        out.push((
356            idx,
357            array.is_valid(idx).then_some(array.value(idx).as_ref()),
358        ))
359    }
360    out
361}
362
// Unit tests for dictionary garbage collection and dictionary value merging.
#[cfg(test)]
mod tests {
    use super::*;

    use arrow_array::cast::as_string_array;
    use arrow_array::types::Int32Type;
    use arrow_array::types::Int8Type;
    use arrow_array::{DictionaryArray, Int32Array, Int8Array, StringArray};
    use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer, OffsetBuffer};
    use std::sync::Arc;

    // An unreferenced value is dropped and the keys after it are shifted down.
    #[test]
    fn test_garbage_collect_i32_dictionary() {
        let values = StringArray::from_iter_values(["a", "b", "c", "d"]);
        let keys = Int32Array::from_iter_values([0, 1, 1, 3, 0, 0, 1]);
        let dict = DictionaryArray::<Int32Type>::new(keys, Arc::new(values));

        // Only "a", "b", "d" are referenced, "c" is not
        let gc = garbage_collect_dictionary(&dict).unwrap();

        let expected_values = StringArray::from_iter_values(["a", "b", "d"]);
        let expected_keys = Int32Array::from_iter_values([0, 1, 1, 2, 0, 0, 1]);
        let expected = DictionaryArray::<Int32Type>::new(expected_keys, Arc::new(expected_values));

        assert_eq!(gc, expected);
    }

    // The type-erased entry point yields the same result as the typed one.
    #[test]
    fn test_garbage_collect_any_dictionary() {
        let values = StringArray::from_iter_values(["a", "b", "c", "d"]);
        let keys = Int32Array::from_iter_values([0, 1, 1, 3, 0, 0, 1]);
        let dict = DictionaryArray::<Int32Type>::new(keys, Arc::new(values));

        let gc = garbage_collect_any_dictionary(&dict).unwrap();

        let expected_values = StringArray::from_iter_values(["a", "b", "d"]);
        let expected_keys = Int32Array::from_iter_values([0, 1, 1, 2, 0, 0, 1]);
        let expected = DictionaryArray::<Int32Type>::new(expected_keys, Arc::new(expected_values));

        assert_eq!(gc.as_ref(), &expected);
    }

    // Null keys stay null and do not keep any value alive ("b" is dropped).
    #[test]
    fn test_garbage_collect_with_nulls() {
        let values = StringArray::from_iter_values(["a", "b", "c"]);
        let keys = Int8Array::from(vec![Some(2), None, Some(0)]);
        let dict = DictionaryArray::<Int8Type>::new(keys, Arc::new(values));

        let gc = garbage_collect_dictionary(&dict).unwrap();

        let expected_values = StringArray::from_iter_values(["a", "c"]);
        let expected_keys = Int8Array::from(vec![Some(1), None, Some(0)]);
        let expected = DictionaryArray::<Int8Type>::new(expected_keys, Arc::new(expected_values));

        assert_eq!(gc, expected);
    }

    // A dictionary without keys or values is returned unchanged.
    #[test]
    fn test_garbage_collect_empty_dictionary() {
        let values = StringArray::from_iter_values::<&str, _>([]);
        let keys = Int32Array::from_iter_values([]);
        let dict = DictionaryArray::<Int32Type>::new(keys, Arc::new(values));

        let gc = garbage_collect_dictionary(&dict).unwrap();

        assert_eq!(gc, dict);
    }

    // When every key is null, all values are collected.
    #[test]
    fn test_garbage_collect_dictionary_all_unreferenced() {
        let values = StringArray::from_iter_values(["a", "b", "c"]);
        let keys = Int32Array::from(vec![None, None, None]);
        let dict = DictionaryArray::<Int32Type>::new(keys, Arc::new(values));

        let gc = garbage_collect_dictionary(&dict).unwrap();

        // All keys are null, so dictionary values can be empty
        let expected_values = StringArray::from_iter_values::<&str, _>([]);
        let expected_keys = Int32Array::from(vec![None, None, None]);
        let expected = DictionaryArray::<Int32Type>::new(expected_keys, Arc::new(expected_values));

        assert_eq!(gc, expected);
    }

    // Merges two string dictionaries: whole, with one input sliced, and with explicit masks.
    #[test]
    fn test_merge_strings() {
        let a = DictionaryArray::<Int32Type>::from_iter(["a", "b", "a", "b", "d", "c", "e"]);
        let b = DictionaryArray::<Int32Type>::from_iter(["c", "f", "c", "d", "a", "d"]);
        let merged = merge_dictionary_values(&[&a, &b], None).unwrap();

        // Values shared between `a` and `b` appear only once in the output
        let values = as_string_array(merged.values.as_ref());
        let actual: Vec<_> = values.iter().map(Option::unwrap).collect();
        assert_eq!(&actual, &["a", "b", "d", "c", "e", "f"]);

        assert_eq!(merged.key_mappings.len(), 2);
        assert_eq!(&merged.key_mappings[0], &[0, 1, 2, 3, 4]);
        assert_eq!(&merged.key_mappings[1], &[3, 5, 2, 0]);

        // The slice only references "b", "a" and "d"; "c" and "e" are no longer
        // contributed by `a` and keep the default mapping of 0
        let a_slice = a.slice(1, 4);
        let merged = merge_dictionary_values(&[&a_slice, &b], None).unwrap();

        let values = as_string_array(merged.values.as_ref());
        let actual: Vec<_> = values.iter().map(Option::unwrap).collect();
        assert_eq!(&actual, &["a", "b", "d", "c", "f"]);

        assert_eq!(merged.key_mappings.len(), 2);
        assert_eq!(&merged.key_mappings[0], &[0, 1, 2, 0, 0]);
        assert_eq!(&merged.key_mappings[1], &[3, 4, 2, 0]);

        // Mask out only ["b", "b", "d"] from a
        let a_mask = BooleanBuffer::from_iter([false, true, false, true, true, false, false]);
        let b_mask = BooleanBuffer::new_set(b.len());
        let merged = merge_dictionary_values(&[&a, &b], Some(&[a_mask, b_mask])).unwrap();

        let values = as_string_array(merged.values.as_ref());
        let actual: Vec<_> = values.iter().map(Option::unwrap).collect();
        assert_eq!(&actual, &["b", "d", "c", "f", "a"]);

        assert_eq!(merged.key_mappings.len(), 2);
        assert_eq!(&merged.key_mappings[0], &[0, 0, 1, 0, 0]);
        assert_eq!(&merged.key_mappings[1], &[2, 3, 1, 4]);
    }

    // Null values and null keys: a referenced null value is retained as a null
    // entry in the output, while null keys reference nothing.
    #[test]
    fn test_merge_nulls() {
        let buffer = Buffer::from(b"helloworldbingohelloworld");
        let offsets = OffsetBuffer::from_lengths([5, 5, 5, 5, 5]);
        let nulls = NullBuffer::from(vec![true, false, true, true, true]);
        let values = StringArray::new(offsets, buffer, Some(nulls));

        // Key 8 is out of range for the 5 values, but sits in a null slot
        let key_values = vec![1, 2, 3, 1, 8, 2, 3];
        let key_nulls = NullBuffer::from(vec![true, true, false, true, false, true, true]);
        let keys = Int32Array::new(key_values.into(), Some(key_nulls));
        let a = DictionaryArray::new(keys, Arc::new(values));
        // [NULL, "bingo", NULL, NULL, NULL, "bingo", "hello"]

        // An all-null dictionary with no values at all
        let b = DictionaryArray::new(Int32Array::new_null(10), Arc::new(StringArray::new_null(0)));

        let merged = merge_dictionary_values(&[&a, &b], None).unwrap();
        let expected = StringArray::from(vec![None, Some("bingo"), Some("hello")]);
        assert_eq!(merged.values.as_ref(), &expected);
        assert_eq!(merged.key_mappings.len(), 2);
        assert_eq!(&merged.key_mappings[0], &[0, 0, 1, 2, 0]);
        assert_eq!(&merged.key_mappings[1], &[] as &[i32; 0]);
    }

    // Fewer keys than values: only the referenced value is kept.
    #[test]
    fn test_merge_keys_smaller() {
        let values = StringArray::from_iter_values(["a", "b"]);
        let keys = Int32Array::from_iter_values([1]);
        let a = DictionaryArray::new(keys, Arc::new(values));

        let merged = merge_dictionary_values(&[&a], None).unwrap();
        let expected = StringArray::from(vec!["b"]);
        assert_eq!(merged.values.as_ref(), &expected);
    }
}