arrow_select/
dictionary.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Dictionary utilities for Arrow arrays
19
20use std::sync::Arc;
21
22use crate::filter::filter;
23use crate::interleave::interleave;
24use ahash::RandomState;
25use arrow_array::builder::BooleanBufferBuilder;
26use arrow_array::types::{
27    ArrowDictionaryKeyType, ArrowPrimitiveType, BinaryType, ByteArrayType, LargeBinaryType,
28    LargeUtf8Type, Utf8Type,
29};
30use arrow_array::{
31    AnyDictionaryArray, Array, ArrayRef, ArrowNativeTypeOp, BooleanArray, DictionaryArray,
32    GenericByteArray, PrimitiveArray, downcast_dictionary_array,
33};
34use arrow_array::{cast::AsArray, downcast_primitive};
35use arrow_buffer::{ArrowNativeType, BooleanBuffer, ScalarBuffer, ToByteSlice};
36use arrow_schema::{ArrowError, DataType};
37
38/// Garbage collects a [DictionaryArray] by removing unreferenced values.
39///
40/// Returns a new [DictionaryArray] such that there are no values
41/// that are not referenced by at least one key. There may still be duplicate
42/// values.
43///
44/// See also [`garbage_collect_any_dictionary`] if you need to handle multiple dictionary types
45pub fn garbage_collect_dictionary<K: ArrowDictionaryKeyType>(
46    dictionary: &DictionaryArray<K>,
47) -> Result<DictionaryArray<K>, ArrowError> {
48    let keys = dictionary.keys();
49    let values = dictionary.values();
50
51    let mask = dictionary.occupancy();
52
53    // If no work to do, return the original dictionary
54    if mask.count_set_bits() == values.len() {
55        return Ok(dictionary.clone());
56    }
57
58    // Create a mapping from the old keys to the new keys, use a Vec for easy indexing
59    let mut key_remap = vec![K::Native::ZERO; values.len()];
60    for (new_idx, old_idx) in mask.set_indices().enumerate() {
61        key_remap[old_idx] = K::Native::from_usize(new_idx)
62            .expect("new index should fit in K::Native, as old index was in range");
63    }
64
65    // ... and then build the new keys array
66    let new_keys = keys.unary(|key| {
67        key_remap
68            .get(key.as_usize())
69            .copied()
70            // nulls may be present in the keys, and they will have arbitrary value; we don't care
71            // and can safely return zero
72            .unwrap_or(K::Native::ZERO)
73    });
74
75    // Create a new values array by filtering using the mask
76    let values = filter(dictionary.values(), &BooleanArray::new(mask, None))?;
77
78    DictionaryArray::try_new(new_keys, values)
79}
80
81/// Equivalent to [`garbage_collect_dictionary`] but without requiring casting to a specific key type.
82pub fn garbage_collect_any_dictionary(
83    dictionary: &dyn AnyDictionaryArray,
84) -> Result<ArrayRef, ArrowError> {
85    // FIXME: this is a workaround for MSRV Rust versions below 1.86 where trait upcasting is not stable.
86    // From 1.86 onward, `&dyn AnyDictionaryArray` can be directly passed to `downcast_dictionary_array!`.
87    let dictionary = &*dictionary.slice(0, dictionary.len());
88    downcast_dictionary_array!(
89        dictionary => garbage_collect_dictionary(dictionary).map(|dict| Arc::new(dict) as ArrayRef),
90        _ => unreachable!("have a dictionary array")
91    )
92}
93
94/// A best effort interner that maintains a fixed number of buckets
95/// and interns keys based on their hash value
96///
97/// Hash collisions will result in replacement
98struct Interner<'a, V> {
99    state: RandomState,
100    buckets: Vec<Option<InternerBucket<'a, V>>>,
101    shift: u32,
102}
103
104/// A single bucket in [`Interner`].
105type InternerBucket<'a, V> = (Option<&'a [u8]>, V);
106
107impl<'a, V> Interner<'a, V> {
108    /// Capacity controls the number of unique buckets allocated within the Interner
109    ///
110    /// A larger capacity reduces the probability of hash collisions, and should be set
111    /// based on an approximation of the upper bound of unique values
112    fn new(capacity: usize) -> Self {
113        // Add additional buckets to help reduce collisions
114        let shift = (capacity as u64 + 128).leading_zeros();
115        let num_buckets = (u64::MAX >> shift) as usize;
116        let buckets = (0..num_buckets.saturating_add(1)).map(|_| None).collect();
117        Self {
118            // A fixed seed to ensure deterministic behaviour
119            state: RandomState::with_seeds(0, 0, 0, 0),
120            buckets,
121            shift,
122        }
123    }
124
125    fn intern<F: FnOnce() -> Result<V, E>, E>(
126        &mut self,
127        new: Option<&'a [u8]>,
128        f: F,
129    ) -> Result<&V, E> {
130        let hash = self.state.hash_one(new);
131        let bucket_idx = hash >> self.shift;
132        Ok(match &mut self.buckets[bucket_idx as usize] {
133            Some((current, v)) => {
134                if *current != new {
135                    *v = f()?;
136                    *current = new;
137                }
138                v
139            }
140            slot => &slot.insert((new, f()?)).1,
141        })
142    }
143}
144
145pub(crate) struct MergedDictionaries<K: ArrowDictionaryKeyType> {
146    /// Provides `key_mappings[`array_idx`][`old_key`] -> new_key`
147    pub key_mappings: Vec<Vec<K::Native>>,
148    /// The new values
149    pub values: ArrayRef,
150}
151
152/// Performs a cheap, pointer-based comparison of two byte array
153///
154/// See [`ScalarBuffer::ptr_eq`]
155fn bytes_ptr_eq<T: ByteArrayType>(a: &dyn Array, b: &dyn Array) -> bool {
156    match (a.as_bytes_opt::<T>(), b.as_bytes_opt::<T>()) {
157        (Some(a), Some(b)) => {
158            let values_eq = a.values().ptr_eq(b.values()) && a.offsets().ptr_eq(b.offsets());
159            match (a.nulls(), b.nulls()) {
160                (Some(a), Some(b)) => values_eq && a.inner().ptr_eq(b.inner()),
161                (None, None) => values_eq,
162                _ => false,
163            }
164        }
165        _ => false,
166    }
167}
168
169/// A type-erased function that compares two array for pointer equality
170type PtrEq = fn(&dyn Array, &dyn Array) -> bool;
171
172/// A weak heuristic of whether to merge dictionary values that aims to only
173/// perform the expensive merge computation when it is likely to yield at least
174/// some return over the naive approach used by MutableArrayData
175///
176/// `len` is the total length of the merged output
177///
178/// Returns `(should_merge, has_overflow)` where:
179/// - `should_merge`: whether dictionary values should be merged
180/// - `has_overflow`: whether the combined dictionary values would overflow the key type
181pub(crate) fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
182    dictionaries: &[&DictionaryArray<K>],
183    len: usize,
184) -> (bool, bool) {
185    use DataType::*;
186    let first_values = dictionaries[0].values().as_ref();
187    let ptr_eq: PtrEq = match first_values.data_type() {
188        Utf8 => bytes_ptr_eq::<Utf8Type>,
189        LargeUtf8 => bytes_ptr_eq::<LargeUtf8Type>,
190        Binary => bytes_ptr_eq::<BinaryType>,
191        LargeBinary => bytes_ptr_eq::<LargeBinaryType>,
192        dt => {
193            if !dt.is_primitive() {
194                return (
195                    false,
196                    K::Native::from_usize(dictionaries.iter().map(|d| d.values().len()).sum())
197                        .is_none(),
198                );
199            }
200            |a, b| a.to_data().ptr_eq(&b.to_data())
201        }
202    };
203
204    let mut single_dictionary = true;
205    let mut total_values = first_values.len();
206    for dict in dictionaries.iter().skip(1) {
207        let values = dict.values().as_ref();
208        total_values += values.len();
209        if single_dictionary {
210            single_dictionary = ptr_eq(first_values, values)
211        }
212    }
213
214    let overflow = K::Native::from_usize(total_values).is_none();
215    let values_exceed_length = total_values >= len;
216
217    (
218        !single_dictionary && (overflow || values_exceed_length),
219        overflow,
220    )
221}
222
223/// Given an array of dictionaries and an optional key mask compute a values array
224/// containing referenced values, along with mappings from the [`DictionaryArray`]
225/// keys to the new keys within this values array. Best-effort will be made to ensure
226/// that the dictionary values are unique
227///
228/// This method is meant to be very fast and the output dictionary values
229/// may not be unique, unlike `GenericByteDictionaryBuilder` which is slower
230/// but produces unique values
231pub(crate) fn merge_dictionary_values<K: ArrowDictionaryKeyType>(
232    dictionaries: &[&DictionaryArray<K>],
233    masks: Option<&[BooleanBuffer]>,
234) -> Result<MergedDictionaries<K>, ArrowError> {
235    let mut num_values = 0;
236
237    let mut values_arrays = Vec::with_capacity(dictionaries.len());
238    let mut value_slices = Vec::with_capacity(dictionaries.len());
239
240    for (idx, dictionary) in dictionaries.iter().enumerate() {
241        let mask = masks.and_then(|m| m.get(idx));
242        let key_mask_owned;
243        let key_mask = match (dictionary.nulls(), mask) {
244            (Some(n), None) => Some(n.inner()),
245            (None, Some(n)) => Some(n),
246            (Some(n), Some(m)) => {
247                key_mask_owned = n.inner() & m;
248                Some(&key_mask_owned)
249            }
250            (None, None) => None,
251        };
252        let keys = dictionary.keys().values();
253        let values = dictionary.values().as_ref();
254        let values_mask = compute_values_mask(keys, key_mask, values.len());
255
256        let masked_values = get_masked_values(values, &values_mask);
257        num_values += masked_values.len();
258        value_slices.push(masked_values);
259        values_arrays.push(values)
260    }
261
262    // Map from value to new index
263    let mut interner = Interner::new(num_values);
264    // Interleave indices for new values array
265    let mut indices = Vec::with_capacity(num_values);
266
267    // Compute the mapping for each dictionary
268    let key_mappings = dictionaries
269        .iter()
270        .enumerate()
271        .zip(value_slices)
272        .map(|((dictionary_idx, dictionary), values)| {
273            let zero = K::Native::from_usize(0).unwrap();
274            let mut mapping = vec![zero; dictionary.values().len()];
275
276            for (value_idx, value) in values {
277                mapping[value_idx] =
278                    *interner.intern(value, || match K::Native::from_usize(indices.len()) {
279                        Some(idx) => {
280                            indices.push((dictionary_idx, value_idx));
281                            Ok(idx)
282                        }
283                        None => Err(ArrowError::DictionaryKeyOverflowError),
284                    })?;
285            }
286            Ok(mapping)
287        })
288        .collect::<Result<Vec<_>, ArrowError>>()?;
289
290    Ok(MergedDictionaries {
291        key_mappings,
292        values: interleave(&values_arrays, &indices)?,
293    })
294}
295
296/// Return a mask identifying the values that are referenced by keys in `dictionary`
297/// at the positions indicated by `selection`
298fn compute_values_mask<K: ArrowNativeType>(
299    keys: &ScalarBuffer<K>,
300    mask: Option<&BooleanBuffer>,
301    max_key: usize,
302) -> BooleanBuffer {
303    let mut builder = BooleanBufferBuilder::new(max_key);
304    builder.advance(max_key);
305
306    match mask {
307        Some(n) => n
308            .set_indices()
309            .for_each(|idx| builder.set_bit(keys[idx].as_usize(), true)),
310        None => keys
311            .iter()
312            .for_each(|k| builder.set_bit(k.as_usize(), true)),
313    }
314    builder.finish()
315}
316
317/// Process primitive array values to bytes
318fn masked_primitives_to_bytes<'a, T: ArrowPrimitiveType>(
319    array: &'a PrimitiveArray<T>,
320    mask: &BooleanBuffer,
321) -> Vec<(usize, Option<&'a [u8]>)>
322where
323    T::Native: ToByteSlice,
324{
325    let mut out = Vec::with_capacity(mask.count_set_bits());
326    let values = array.values();
327    for idx in mask.set_indices() {
328        out.push((
329            idx,
330            array.is_valid(idx).then_some(values[idx].to_byte_slice()),
331        ))
332    }
333    out
334}
335
/// Adapter for `downcast_primitive!`: receives the concrete primitive type, downcasts the
/// array to it and forwards to `masked_primitives_to_bytes`.
macro_rules! masked_primitive_to_bytes_helper {
    ($primitive:ty, $values:expr, $selection:expr) => {
        masked_primitives_to_bytes::<$primitive>($values.as_primitive(), $selection)
    };
}
341
342/// Return a Vec containing for each set index in `mask`, the index and byte value of that index
343fn get_masked_values<'a>(
344    array: &'a dyn Array,
345    mask: &BooleanBuffer,
346) -> Vec<(usize, Option<&'a [u8]>)> {
347    downcast_primitive! {
348        array.data_type() => (masked_primitive_to_bytes_helper, array, mask),
349        DataType::Utf8 => masked_bytes(array.as_string::<i32>(), mask),
350        DataType::LargeUtf8 => masked_bytes(array.as_string::<i64>(), mask),
351        DataType::Binary => masked_bytes(array.as_binary::<i32>(), mask),
352        DataType::LargeBinary => masked_bytes(array.as_binary::<i64>(), mask),
353        _ => unimplemented!("Dictionary merging for type {} is not implemented", array.data_type()),
354    }
355}
356
357/// Compute [`get_masked_values`] for a [`GenericByteArray`]
358///
359/// Note: this does not check the null mask and will return values contained in null slots
360fn masked_bytes<'a, T: ByteArrayType>(
361    array: &'a GenericByteArray<T>,
362    mask: &BooleanBuffer,
363) -> Vec<(usize, Option<&'a [u8]>)> {
364    let mut out = Vec::with_capacity(mask.count_set_bits());
365    for idx in mask.set_indices() {
366        out.push((
367            idx,
368            array.is_valid(idx).then_some(array.value(idx).as_ref()),
369        ))
370    }
371    out
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    use arrow_array::cast::as_string_array;
    use arrow_array::types::Int8Type;
    use arrow_array::types::Int32Type;
    use arrow_array::{DictionaryArray, Int8Array, Int32Array, StringArray};
    use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer, OffsetBuffer};
    use std::sync::Arc;

    #[test]
    fn test_garbage_collect_i32_dictionary() {
        let values = StringArray::from_iter_values(["a", "b", "c", "d"]);
        let keys = Int32Array::from_iter_values([0, 1, 1, 3, 0, 0, 1]);
        let dict = DictionaryArray::<Int32Type>::new(keys, Arc::new(values));

        // Only "a", "b", "d" are referenced, "c" is not
        let gc = garbage_collect_dictionary(&dict).unwrap();

        let expected_values = StringArray::from_iter_values(["a", "b", "d"]);
        let expected_keys = Int32Array::from_iter_values([0, 1, 1, 2, 0, 0, 1]);
        let expected = DictionaryArray::<Int32Type>::new(expected_keys, Arc::new(expected_values));

        assert_eq!(gc, expected);
    }

    #[test]
    fn test_garbage_collect_any_dictionary() {
        let values = StringArray::from_iter_values(["a", "b", "c", "d"]);
        let keys = Int32Array::from_iter_values([0, 1, 1, 3, 0, 0, 1]);
        let dict = DictionaryArray::<Int32Type>::new(keys, Arc::new(values));

        let gc = garbage_collect_any_dictionary(&dict).unwrap();

        let expected_values = StringArray::from_iter_values(["a", "b", "d"]);
        let expected_keys = Int32Array::from_iter_values([0, 1, 1, 2, 0, 0, 1]);
        let expected = DictionaryArray::<Int32Type>::new(expected_keys, Arc::new(expected_values));

        assert_eq!(gc.as_ref(), &expected);
    }

    #[test]
    fn test_garbage_collect_with_nulls() {
        let values = StringArray::from_iter_values(["a", "b", "c"]);
        let keys = Int8Array::from(vec![Some(2), None, Some(0)]);
        let dict = DictionaryArray::<Int8Type>::new(keys, Arc::new(values));

        let gc = garbage_collect_dictionary(&dict).unwrap();

        let expected_values = StringArray::from_iter_values(["a", "c"]);
        let expected_keys = Int8Array::from(vec![Some(1), None, Some(0)]);
        let expected = DictionaryArray::<Int8Type>::new(expected_keys, Arc::new(expected_values));

        assert_eq!(gc, expected);
    }

    #[test]
    fn test_garbage_collect_empty_dictionary() {
        let values = StringArray::from_iter_values::<&str, _>([]);
        let keys = Int32Array::from_iter_values([]);
        let dict = DictionaryArray::<Int32Type>::new(keys, Arc::new(values));

        let gc = garbage_collect_dictionary(&dict).unwrap();

        assert_eq!(gc, dict);
    }

    #[test]
    fn test_garbage_collect_dictionary_all_unreferenced() {
        let values = StringArray::from_iter_values(["a", "b", "c"]);
        let keys = Int32Array::from(vec![None, None, None]);
        let dict = DictionaryArray::<Int32Type>::new(keys, Arc::new(values));

        let gc = garbage_collect_dictionary(&dict).unwrap();

        // All keys are null, so dictionary values can be empty
        let expected_values = StringArray::from_iter_values::<&str, _>([]);
        let expected_keys = Int32Array::from(vec![None, None, None]);
        let expected = DictionaryArray::<Int32Type>::new(expected_keys, Arc::new(expected_values));

        assert_eq!(gc, expected);
    }

    #[test]
    fn test_merge_strings() {
        let a = DictionaryArray::<Int32Type>::from_iter(["a", "b", "a", "b", "d", "c", "e"]);
        let b = DictionaryArray::<Int32Type>::from_iter(["c", "f", "c", "d", "a", "d"]);
        let merged = merge_dictionary_values(&[&a, &b], None).unwrap();

        let values = as_string_array(merged.values.as_ref());
        let actual: Vec<_> = values.iter().map(Option::unwrap).collect();
        assert_eq!(&actual, &["a", "b", "d", "c", "e", "f"]);

        assert_eq!(merged.key_mappings.len(), 2);
        assert_eq!(&merged.key_mappings[0], &[0, 1, 2, 3, 4]);
        assert_eq!(&merged.key_mappings[1], &[3, 5, 2, 0]);

        let a_slice = a.slice(1, 4);
        let merged = merge_dictionary_values(&[&a_slice, &b], None).unwrap();

        let values = as_string_array(merged.values.as_ref());
        let actual: Vec<_> = values.iter().map(Option::unwrap).collect();
        assert_eq!(&actual, &["a", "b", "d", "c", "f"]);

        assert_eq!(merged.key_mappings.len(), 2);
        assert_eq!(&merged.key_mappings[0], &[0, 1, 2, 0, 0]);
        assert_eq!(&merged.key_mappings[1], &[3, 4, 2, 0]);

        // Mask out only ["b", "b", "d"] from a
        let a_mask = BooleanBuffer::from_iter([false, true, false, true, true, false, false]);
        let b_mask = BooleanBuffer::new_set(b.len());
        let merged = merge_dictionary_values(&[&a, &b], Some(&[a_mask, b_mask])).unwrap();

        let values = as_string_array(merged.values.as_ref());
        let actual: Vec<_> = values.iter().map(Option::unwrap).collect();
        assert_eq!(&actual, &["b", "d", "c", "f", "a"]);

        assert_eq!(merged.key_mappings.len(), 2);
        assert_eq!(&merged.key_mappings[0], &[0, 0, 1, 0, 0]);
        assert_eq!(&merged.key_mappings[1], &[2, 3, 1, 4]);
    }

    #[test]
    fn test_merge_nulls() {
        let buffer = Buffer::from(b"helloworldbingohelloworld");
        let offsets = OffsetBuffer::from_lengths([5, 5, 5, 5, 5]);
        let nulls = NullBuffer::from(vec![true, false, true, true, true]);
        let values = StringArray::new(offsets, buffer, Some(nulls));

        let key_values = vec![1, 2, 3, 1, 8, 2, 3];
        let key_nulls = NullBuffer::from(vec![true, true, false, true, false, true, true]);
        let keys = Int32Array::new(key_values.into(), Some(key_nulls));
        let a = DictionaryArray::new(keys, Arc::new(values));
        // [NULL, "bingo", NULL, NULL, NULL, "bingo", "hello"]

        let b = DictionaryArray::new(Int32Array::new_null(10), Arc::new(StringArray::new_null(0)));

        let merged = merge_dictionary_values(&[&a, &b], None).unwrap();
        let expected = StringArray::from(vec![None, Some("bingo"), Some("hello")]);
        assert_eq!(merged.values.as_ref(), &expected);
        assert_eq!(merged.key_mappings.len(), 2);
        assert_eq!(&merged.key_mappings[0], &[0, 0, 1, 2, 0]);
        assert_eq!(&merged.key_mappings[1], &[] as &[i32; 0]);
    }

    #[test]
    fn test_merge_keys_smaller() {
        let values = StringArray::from_iter_values(["a", "b"]);
        let keys = Int32Array::from_iter_values([1]);
        let a = DictionaryArray::new(keys, Arc::new(values));

        let merged = merge_dictionary_values(&[&a], None).unwrap();
        let expected = StringArray::from(vec!["b"]);
        assert_eq!(merged.values.as_ref(), &expected);
    }

    #[test]
    fn test_should_merge_dictionary_values() {
        let shared = Arc::new(StringArray::from_iter_values(["a", "b"]));
        let a =
            DictionaryArray::<Int8Type>::new(Int8Array::from_iter_values([0, 1]), shared.clone());
        let b = DictionaryArray::<Int8Type>::new(Int8Array::from_iter_values([1, 0]), shared);

        // Dictionaries sharing the same values buffers are never merged
        assert_eq!(should_merge_dictionary_values(&[&a, &b], 4), (false, false));

        // Distinct values whose combined length reaches the output length are merged
        let c = DictionaryArray::<Int8Type>::from_iter(["a", "b"]);
        assert_eq!(should_merge_dictionary_values(&[&a, &c], 4), (true, false));

        // ... but not when the output is longer than the combined values
        assert_eq!(should_merge_dictionary_values(&[&a, &c], 10), (false, false));
    }
}