arrow_select/
dictionary.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::interleave::interleave;
19use ahash::RandomState;
20use arrow_array::builder::BooleanBufferBuilder;
21use arrow_array::types::{
22    ArrowDictionaryKeyType, ArrowPrimitiveType, BinaryType, ByteArrayType, LargeBinaryType,
23    LargeUtf8Type, Utf8Type,
24};
25use arrow_array::{cast::AsArray, downcast_primitive};
26use arrow_array::{Array, ArrayRef, DictionaryArray, GenericByteArray, PrimitiveArray};
27use arrow_buffer::{ArrowNativeType, BooleanBuffer, ScalarBuffer, ToByteSlice};
28use arrow_schema::{ArrowError, DataType};
29
30/// A best effort interner that maintains a fixed number of buckets
31/// and interns keys based on their hash value
32///
33/// Hash collisions will result in replacement
34struct Interner<'a, V> {
35    state: RandomState,
36    buckets: Vec<Option<InternerBucket<'a, V>>>,
37    shift: u32,
38}
39
40/// A single bucket in [`Interner`].
41type InternerBucket<'a, V> = (Option<&'a [u8]>, V);
42
43impl<'a, V> Interner<'a, V> {
44    /// Capacity controls the number of unique buckets allocated within the Interner
45    ///
46    /// A larger capacity reduces the probability of hash collisions, and should be set
47    /// based on an approximation of the upper bound of unique values
48    fn new(capacity: usize) -> Self {
49        // Add additional buckets to help reduce collisions
50        let shift = (capacity as u64 + 128).leading_zeros();
51        let num_buckets = (u64::MAX >> shift) as usize;
52        let buckets = (0..num_buckets.saturating_add(1)).map(|_| None).collect();
53        Self {
54            // A fixed seed to ensure deterministic behaviour
55            state: RandomState::with_seeds(0, 0, 0, 0),
56            buckets,
57            shift,
58        }
59    }
60
61    fn intern<F: FnOnce() -> Result<V, E>, E>(
62        &mut self,
63        new: Option<&'a [u8]>,
64        f: F,
65    ) -> Result<&V, E> {
66        let hash = self.state.hash_one(new);
67        let bucket_idx = hash >> self.shift;
68        Ok(match &mut self.buckets[bucket_idx as usize] {
69            Some((current, v)) => {
70                if *current != new {
71                    *v = f()?;
72                    *current = new;
73                }
74                v
75            }
76            slot => &slot.insert((new, f()?)).1,
77        })
78    }
79}
80
81pub struct MergedDictionaries<K: ArrowDictionaryKeyType> {
82    /// Provides `key_mappings[`array_idx`][`old_key`] -> new_key`
83    pub key_mappings: Vec<Vec<K::Native>>,
84    /// The new values
85    pub values: ArrayRef,
86}
87
88/// Performs a cheap, pointer-based comparison of two byte array
89///
90/// See [`ScalarBuffer::ptr_eq`]
91fn bytes_ptr_eq<T: ByteArrayType>(a: &dyn Array, b: &dyn Array) -> bool {
92    match (a.as_bytes_opt::<T>(), b.as_bytes_opt::<T>()) {
93        (Some(a), Some(b)) => {
94            let values_eq = a.values().ptr_eq(b.values()) && a.offsets().ptr_eq(b.offsets());
95            match (a.nulls(), b.nulls()) {
96                (Some(a), Some(b)) => values_eq && a.inner().ptr_eq(b.inner()),
97                (None, None) => values_eq,
98                _ => false,
99            }
100        }
101        _ => false,
102    }
103}
104
105/// A type-erased function that compares two array for pointer equality
106type PtrEq = fn(&dyn Array, &dyn Array) -> bool;
107
108/// A weak heuristic of whether to merge dictionary values that aims to only
109/// perform the expensive merge computation when it is likely to yield at least
110/// some return over the naive approach used by MutableArrayData
111///
112/// `len` is the total length of the merged output
113pub fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
114    dictionaries: &[&DictionaryArray<K>],
115    len: usize,
116) -> bool {
117    use DataType::*;
118    let first_values = dictionaries[0].values().as_ref();
119    let ptr_eq: PtrEq = match first_values.data_type() {
120        Utf8 => bytes_ptr_eq::<Utf8Type>,
121        LargeUtf8 => bytes_ptr_eq::<LargeUtf8Type>,
122        Binary => bytes_ptr_eq::<BinaryType>,
123        LargeBinary => bytes_ptr_eq::<LargeBinaryType>,
124        dt => {
125            if !dt.is_primitive() {
126                return false;
127            }
128            |a, b| a.to_data().ptr_eq(&b.to_data())
129        }
130    };
131
132    let mut single_dictionary = true;
133    let mut total_values = first_values.len();
134    for dict in dictionaries.iter().skip(1) {
135        let values = dict.values().as_ref();
136        total_values += values.len();
137        if single_dictionary {
138            single_dictionary = ptr_eq(first_values, values)
139        }
140    }
141
142    let overflow = K::Native::from_usize(total_values).is_none();
143    let values_exceed_length = total_values >= len;
144
145    !single_dictionary && (overflow || values_exceed_length)
146}
147
148/// Given an array of dictionaries and an optional key mask compute a values array
149/// containing referenced values, along with mappings from the [`DictionaryArray`]
150/// keys to the new keys within this values array. Best-effort will be made to ensure
151/// that the dictionary values are unique
152///
153/// This method is meant to be very fast and the output dictionary values
154/// may not be unique, unlike `GenericByteDictionaryBuilder` which is slower
155/// but produces unique values
156pub fn merge_dictionary_values<K: ArrowDictionaryKeyType>(
157    dictionaries: &[&DictionaryArray<K>],
158    masks: Option<&[BooleanBuffer]>,
159) -> Result<MergedDictionaries<K>, ArrowError> {
160    let mut num_values = 0;
161
162    let mut values_arrays = Vec::with_capacity(dictionaries.len());
163    let mut value_slices = Vec::with_capacity(dictionaries.len());
164
165    for (idx, dictionary) in dictionaries.iter().enumerate() {
166        let mask = masks.and_then(|m| m.get(idx));
167        let key_mask_owned;
168        let key_mask = match (dictionary.nulls(), mask) {
169            (Some(n), None) => Some(n.inner()),
170            (None, Some(n)) => Some(n),
171            (Some(n), Some(m)) => {
172                key_mask_owned = n.inner() & m;
173                Some(&key_mask_owned)
174            }
175            (None, None) => None,
176        };
177        let keys = dictionary.keys().values();
178        let values = dictionary.values().as_ref();
179        let values_mask = compute_values_mask(keys, key_mask, values.len());
180
181        let masked_values = get_masked_values(values, &values_mask);
182        num_values += masked_values.len();
183        value_slices.push(masked_values);
184        values_arrays.push(values)
185    }
186
187    // Map from value to new index
188    let mut interner = Interner::new(num_values);
189    // Interleave indices for new values array
190    let mut indices = Vec::with_capacity(num_values);
191
192    // Compute the mapping for each dictionary
193    let key_mappings = dictionaries
194        .iter()
195        .enumerate()
196        .zip(value_slices)
197        .map(|((dictionary_idx, dictionary), values)| {
198            let zero = K::Native::from_usize(0).unwrap();
199            let mut mapping = vec![zero; dictionary.values().len()];
200
201            for (value_idx, value) in values {
202                mapping[value_idx] =
203                    *interner.intern(value, || match K::Native::from_usize(indices.len()) {
204                        Some(idx) => {
205                            indices.push((dictionary_idx, value_idx));
206                            Ok(idx)
207                        }
208                        None => Err(ArrowError::DictionaryKeyOverflowError),
209                    })?;
210            }
211            Ok(mapping)
212        })
213        .collect::<Result<Vec<_>, ArrowError>>()?;
214
215    Ok(MergedDictionaries {
216        key_mappings,
217        values: interleave(&values_arrays, &indices)?,
218    })
219}
220
221/// Return a mask identifying the values that are referenced by keys in `dictionary`
222/// at the positions indicated by `selection`
223fn compute_values_mask<K: ArrowNativeType>(
224    keys: &ScalarBuffer<K>,
225    mask: Option<&BooleanBuffer>,
226    max_key: usize,
227) -> BooleanBuffer {
228    let mut builder = BooleanBufferBuilder::new(max_key);
229    builder.advance(max_key);
230
231    match mask {
232        Some(n) => n
233            .set_indices()
234            .for_each(|idx| builder.set_bit(keys[idx].as_usize(), true)),
235        None => keys
236            .iter()
237            .for_each(|k| builder.set_bit(k.as_usize(), true)),
238    }
239    builder.finish()
240}
241
242/// Process primitive array values to bytes
243fn masked_primitives_to_bytes<'a, T: ArrowPrimitiveType>(
244    array: &'a PrimitiveArray<T>,
245    mask: &BooleanBuffer,
246) -> Vec<(usize, Option<&'a [u8]>)>
247where
248    T::Native: ToByteSlice,
249{
250    let mut out = Vec::with_capacity(mask.count_set_bits());
251    let values = array.values();
252    for idx in mask.set_indices() {
253        out.push((
254            idx,
255            array.is_valid(idx).then_some(values[idx].to_byte_slice()),
256        ))
257    }
258    out
259}
260
/// Forwards to [`masked_primitives_to_bytes`] for the primitive type given as the
/// first argument, downcasting the array expression with `as_primitive`.
///
/// Used as the callback handed to `downcast_primitive!` in `get_masked_values`.
macro_rules! masked_primitive_to_bytes_helper {
    ($primitive:ty, $values:expr, $selection:expr) => {
        masked_primitives_to_bytes::<$primitive>($values.as_primitive(), $selection)
    };
}
266
267/// Return a Vec containing for each set index in `mask`, the index and byte value of that index
268fn get_masked_values<'a>(
269    array: &'a dyn Array,
270    mask: &BooleanBuffer,
271) -> Vec<(usize, Option<&'a [u8]>)> {
272    downcast_primitive! {
273        array.data_type() => (masked_primitive_to_bytes_helper, array, mask),
274        DataType::Utf8 => masked_bytes(array.as_string::<i32>(), mask),
275        DataType::LargeUtf8 => masked_bytes(array.as_string::<i64>(), mask),
276        DataType::Binary => masked_bytes(array.as_binary::<i32>(), mask),
277        DataType::LargeBinary => masked_bytes(array.as_binary::<i64>(), mask),
278        _ => unimplemented!("Dictionary merging for type {} is not implemented", array.data_type()),
279    }
280}
281
282/// Compute [`get_masked_values`] for a [`GenericByteArray`]
283///
284/// Note: this does not check the null mask and will return values contained in null slots
285fn masked_bytes<'a, T: ByteArrayType>(
286    array: &'a GenericByteArray<T>,
287    mask: &BooleanBuffer,
288) -> Vec<(usize, Option<&'a [u8]>)> {
289    let mut out = Vec::with_capacity(mask.count_set_bits());
290    for idx in mask.set_indices() {
291        out.push((
292            idx,
293            array.is_valid(idx).then_some(array.value(idx).as_ref()),
294        ))
295    }
296    out
297}
298
#[cfg(test)]
mod tests {
    use crate::dictionary::merge_dictionary_values;
    use arrow_array::cast::as_string_array;
    use arrow_array::types::Int32Type;
    use arrow_array::{DictionaryArray, Int32Array, StringArray};
    use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer, OffsetBuffer};
    use std::sync::Arc;

    /// Merging string dictionaries: whole arrays, a sliced input and a masked input
    #[test]
    fn test_merge_strings() {
        let a = DictionaryArray::<Int32Type>::from_iter(["a", "b", "a", "b", "d", "c", "e"]);
        let b = DictionaryArray::<Int32Type>::from_iter(["c", "f", "c", "d", "a", "d"]);

        // Both dictionaries in full
        let full = merge_dictionary_values(&[&a, &b], None).unwrap();
        let full_values: Vec<_> = as_string_array(full.values.as_ref())
            .iter()
            .map(|v| v.unwrap())
            .collect();
        assert_eq!(full_values, ["a", "b", "d", "c", "e", "f"]);
        assert_eq!(
            full.key_mappings,
            vec![vec![0, 1, 2, 3, 4], vec![3, 5, 2, 0]]
        );

        // A slice of `a` only references "a", "b" and "d"
        let a_slice = a.slice(1, 4);
        let sliced = merge_dictionary_values(&[&a_slice, &b], None).unwrap();
        let sliced_values: Vec<_> = as_string_array(sliced.values.as_ref())
            .iter()
            .map(|v| v.unwrap())
            .collect();
        assert_eq!(sliced_values, ["a", "b", "d", "c", "f"]);
        assert_eq!(
            sliced.key_mappings,
            vec![vec![0, 1, 2, 0, 0], vec![3, 4, 2, 0]]
        );

        // Mask out only ["b", "b", "d"] from a
        let a_mask = BooleanBuffer::from_iter([false, true, false, true, true, false, false]);
        let b_mask = BooleanBuffer::new_set(b.len());
        let masked = merge_dictionary_values(&[&a, &b], Some(&[a_mask, b_mask])).unwrap();
        let masked_values: Vec<_> = as_string_array(masked.values.as_ref())
            .iter()
            .map(|v| v.unwrap())
            .collect();
        assert_eq!(masked_values, ["b", "d", "c", "f", "a"]);
        assert_eq!(
            masked.key_mappings,
            vec![vec![0, 0, 1, 0, 0], vec![2, 3, 1, 4]]
        );
    }

    /// Null keys are ignored while a referenced null value is kept in the output
    #[test]
    fn test_merge_nulls() {
        // Values: ["hello", NULL, "bingo", "hello", "world"]
        let values = StringArray::new(
            OffsetBuffer::from_lengths([5, 5, 5, 5, 5]),
            Buffer::from(b"helloworldbingohelloworld"),
            Some(NullBuffer::from(vec![true, false, true, true, true])),
        );

        let key_nulls = NullBuffer::from(vec![true, true, false, true, false, true, true]);
        let keys = Int32Array::new(vec![1, 2, 3, 1, 8, 2, 3].into(), Some(key_nulls));
        // [NULL, "bingo", NULL, NULL, NULL, "bingo", "hello"]
        let a = DictionaryArray::new(keys, Arc::new(values));

        // A dictionary consisting solely of null keys over an empty values array
        let b = DictionaryArray::new(Int32Array::new_null(10), Arc::new(StringArray::new_null(0)));

        let merged = merge_dictionary_values(&[&a, &b], None).unwrap();
        let expected = StringArray::from(vec![None, Some("bingo"), Some("hello")]);
        assert_eq!(merged.values.as_ref(), &expected);
        assert_eq!(merged.key_mappings, vec![vec![0, 0, 1, 2, 0], vec![]]);
    }

    /// Values that no key refers to are dropped from the merged output
    #[test]
    fn test_merge_keys_smaller() {
        let a = DictionaryArray::new(
            Int32Array::from_iter_values([1]),
            Arc::new(StringArray::from_iter_values(["a", "b"])),
        );

        let merged = merge_dictionary_values(&[&a], None).unwrap();
        let expected = StringArray::from(vec!["b"]);
        assert_eq!(merged.values.as_ref(), &expected);
    }
}