1use crate::arrow::buffer::offset_buffer::OffsetBuffer;
19use crate::arrow::record_reader::buffer::ValuesBuffer;
20use crate::errors::{ParquetError, Result};
21use arrow_array::{make_array, Array, ArrayRef, OffsetSizeTrait};
22use arrow_buffer::{ArrowNativeType, Buffer};
23use arrow_data::ArrayDataBuilder;
24use arrow_schema::DataType as ArrowType;
25use std::sync::Arc;
26
27pub enum DictionaryBuffer<K: ArrowNativeType, V: OffsetSizeTrait> {
30 Dict { keys: Vec<K>, values: ArrayRef },
31 Values { values: OffsetBuffer<V> },
32}
33
34impl<K: ArrowNativeType, V: OffsetSizeTrait> Default for DictionaryBuffer<K, V> {
35 fn default() -> Self {
36 Self::Values {
37 values: Default::default(),
38 }
39 }
40}
41
42impl<K: ArrowNativeType + Ord, V: OffsetSizeTrait> DictionaryBuffer<K, V> {
43 #[allow(unused)]
44 pub fn len(&self) -> usize {
45 match self {
46 Self::Dict { keys, .. } => keys.len(),
47 Self::Values { values } => values.len(),
48 }
49 }
50
51 pub fn as_keys(&mut self, dictionary: &ArrayRef) -> Option<&mut Vec<K>> {
59 assert!(K::from_usize(dictionary.len()).is_some());
60
61 match self {
62 Self::Dict { keys, values } => {
63 let values_ptr = values.as_ref() as *const _ as *const ();
67 let dict_ptr = dictionary.as_ref() as *const _ as *const ();
68 if values_ptr == dict_ptr {
69 Some(keys)
70 } else if keys.is_empty() {
71 *values = Arc::clone(dictionary);
72 Some(keys)
73 } else {
74 None
75 }
76 }
77 Self::Values { values } if values.is_empty() => {
78 *self = Self::Dict {
79 keys: Default::default(),
80 values: Arc::clone(dictionary),
81 };
82 match self {
83 Self::Dict { keys, .. } => Some(keys),
84 _ => unreachable!(),
85 }
86 }
87 _ => None,
88 }
89 }
90
91 pub fn spill_values(&mut self) -> Result<&mut OffsetBuffer<V>> {
96 match self {
97 Self::Values { values } => Ok(values),
98 Self::Dict { keys, values } => {
99 let mut spilled = OffsetBuffer::default();
100 let data = values.to_data();
101 let dict_buffers = data.buffers();
102 let dict_offsets = dict_buffers[0].typed_data::<V>();
103 let dict_values = dict_buffers[1].as_slice();
104
105 if values.is_empty() {
106 spilled.offsets.resize(keys.len() + 1, V::default());
108 } else {
109 spilled.extend_from_dictionary(keys.as_slice(), dict_offsets, dict_values)?;
115 }
116
117 *self = Self::Values { values: spilled };
118 match self {
119 Self::Values { values } => Ok(values),
120 _ => unreachable!(),
121 }
122 }
123 }
124 }
125
126 pub fn into_array(
128 self,
129 null_buffer: Option<Buffer>,
130 data_type: &ArrowType,
131 ) -> Result<ArrayRef> {
132 assert!(matches!(data_type, ArrowType::Dictionary(_, _)));
133
134 match self {
135 Self::Dict { keys, values } => {
136 if !values.is_empty() {
138 let min = K::from_usize(0).unwrap();
139 let max = K::from_usize(values.len()).unwrap();
140
141 if !keys
145 .as_slice()
146 .iter()
147 .copied()
148 .fold(true, |a, x| a && x >= min && x < max)
149 {
150 return Err(general_err!(
151 "dictionary key beyond bounds of dictionary: 0..{}",
152 values.len()
153 ));
154 }
155 }
156
157 let ArrowType::Dictionary(_, value_type) = data_type else {
158 unreachable!()
159 };
160 let values = if let ArrowType::FixedSizeBinary(size) = **value_type {
161 arrow_cast::cast(&values, &ArrowType::FixedSizeBinary(size)).unwrap()
162 } else {
163 values
164 };
165
166 let builder = ArrayDataBuilder::new(data_type.clone())
167 .len(keys.len())
168 .add_buffer(Buffer::from_vec(keys))
169 .add_child_data(values.into_data())
170 .null_bit_buffer(null_buffer);
171
172 let data = match cfg!(debug_assertions) {
173 true => builder.build().unwrap(),
174 false => unsafe { builder.build_unchecked() },
175 };
176
177 Ok(make_array(data))
178 }
179 Self::Values { values } => {
180 let value_type = match data_type {
181 ArrowType::Dictionary(_, v) => v.as_ref().clone(),
182 _ => unreachable!(),
183 };
184
185 let array =
187 arrow_cast::cast(&values.into_array(null_buffer, value_type), data_type)
188 .expect("cast should be infallible");
189
190 Ok(array)
191 }
192 }
193 }
194}
195
196impl<K: ArrowNativeType, V: OffsetSizeTrait> ValuesBuffer for DictionaryBuffer<K, V> {
197 fn pad_nulls(
198 &mut self,
199 read_offset: usize,
200 values_read: usize,
201 levels_read: usize,
202 valid_mask: &[u8],
203 ) {
204 match self {
205 Self::Dict { keys, .. } => {
206 keys.resize(read_offset + levels_read, K::default());
207 keys.pad_nulls(read_offset, values_read, levels_read, valid_mask)
208 }
209 Self::Values { values, .. } => {
210 values.pad_nulls(read_offset, values_read, levels_read, valid_mask)
211 }
212 }
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::compute::cast;
    use arrow_array::StringArray;

    /// Walks one buffer through its state transitions:
    /// 1. dictionary keys, then null padding;
    /// 2. spilling to values, then appending raw values;
    /// 3. converting to a dictionary array;
    /// 4. reusing the taken (default) buffer with new dictionaries.
    #[test]
    fn test_dictionary_buffer() {
        let dict_type =
            ArrowType::Dictionary(Box::new(ArrowType::Int32), Box::new(ArrowType::Utf8));

        let d1: ArrayRef = Arc::new(StringArray::from(vec!["hello", "world", "", "a", "b"]));

        let mut buffer = DictionaryBuffer::<i32, i32>::default();

        // A fresh buffer accepts a dictionary; append five keys into `d1`.
        let values = &[1, 0, 3, 2, 4];
        buffer.as_keys(&d1).unwrap().extend_from_slice(values);

        // 8 levels: the five `true` slots receive the keys above, in order.
        let mut valid = vec![false, false, true, true, false, true, true, true];
        let valid_buffer = Buffer::from_iter(valid.iter().cloned());
        buffer.pad_nulls(0, values.len(), valid.len(), valid_buffer.as_slice());

        // Materialize the keys into values, then append two raw values after
        // the existing entries.
        let values = buffer.spill_values().unwrap();
        let read_offset = values.len();
        values.try_push("bingo".as_bytes(), false).unwrap();
        values.try_push("bongo".as_bytes(), false).unwrap();

        // 5 more levels; the two `true` slots receive the pushed values.
        valid.extend_from_slice(&[false, false, true, false, true]);
        let null_buffer = Buffer::from_iter(valid.iter().cloned());
        buffer.pad_nulls(read_offset, 2, 5, null_buffer.as_slice());

        // 8 + 5 levels in total.
        assert_eq!(buffer.len(), 13);
        let split = std::mem::take(&mut buffer);

        let array = split.into_array(Some(null_buffer), &dict_type).unwrap();
        assert_eq!(array.data_type(), &dict_type);

        let strings = cast(&array, &ArrowType::Utf8).unwrap();
        let strings = strings.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(
            strings.iter().collect::<Vec<_>>(),
            vec![
                None,
                None,
                Some("world"),
                Some("hello"),
                None,
                Some("a"),
                Some(""),
                Some("b"),
                None,
                None,
                Some("bingo"),
                None,
                Some("bongo")
            ]
        );

        // `mem::take` left a default buffer behind: `Values` state, empty.
        assert!(matches!(&buffer, DictionaryBuffer::Values { .. }));
        assert_eq!(buffer.len(), 0);
        // An empty `Values` buffer can switch back to dictionary keys.
        let d2 = Arc::new(StringArray::from(vec!["bingo", ""])) as ArrayRef;
        buffer
            .as_keys(&d2)
            .unwrap()
            .extend_from_slice(&[0, 1, 0, 1]);

        let array = std::mem::take(&mut buffer)
            .into_array(None, &dict_type)
            .unwrap();
        assert_eq!(array.data_type(), &dict_type);

        let strings = cast(&array, &ArrowType::Utf8).unwrap();
        let strings = strings.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(
            strings.iter().collect::<Vec<_>>(),
            vec![Some("bingo"), Some(""), Some("bingo"), Some("")]
        );

        // Again reset to the default state by `mem::take`.
        assert!(matches!(&buffer, DictionaryBuffer::Values { .. }));
        assert_eq!(buffer.len(), 0);
        let d3 = Arc::new(StringArray::from(vec!["bongo"])) as ArrayRef;
        buffer.as_keys(&d3).unwrap().extend_from_slice(&[0, 0]);

        // Keys already reference `d3`, so a different dictionary is rejected.
        let d4 = Arc::new(StringArray::from(vec!["bananas"])) as ArrayRef;
        assert!(buffer.as_keys(&d4).is_none());
    }

    /// Out-of-range keys are reported as errors by both `into_array` and
    /// `spill_values`, naming the dictionary's bounds.
    #[test]
    fn test_validates_keys() {
        let dict_type =
            ArrowType::Dictionary(Box::new(ArrowType::Int32), Box::new(ArrowType::Utf8));

        // Key 2 is out of range for a two-entry dictionary.
        let mut buffer = DictionaryBuffer::<i32, i32>::default();
        let d = Arc::new(StringArray::from(vec!["", "f"])) as ArrayRef;
        buffer.as_keys(&d).unwrap().extend_from_slice(&[0, 2, 0]);

        let err = buffer.into_array(None, &dict_type).unwrap_err().to_string();
        assert!(
            err.contains("dictionary key beyond bounds of dictionary: 0..2"),
            "{}",
            err
        );

        // Key 1 is out of range for a one-entry dictionary.
        let mut buffer = DictionaryBuffer::<i32, i32>::default();
        let d = Arc::new(StringArray::from(vec![""])) as ArrayRef;
        buffer.as_keys(&d).unwrap().extend_from_slice(&[0, 1, 0]);

        let err = buffer.spill_values().unwrap_err().to_string();
        assert!(
            err.contains("dictionary key beyond bounds of dictionary: 0..1"),
            "{}",
            err
        );
    }
}