1use crate::arrow::buffer::offset_buffer::OffsetBuffer;
19use crate::arrow::record_reader::buffer::ValuesBuffer;
20use crate::errors::{ParquetError, Result};
21use arrow_array::{make_array, Array, ArrayRef, OffsetSizeTrait};
22use arrow_buffer::{ArrowNativeType, Buffer};
23use arrow_data::ArrayDataBuilder;
24use arrow_schema::DataType as ArrowType;
25use std::sync::Arc;
26
27pub enum DictionaryBuffer<K: ArrowNativeType, V: OffsetSizeTrait> {
30 Dict { keys: Vec<K>, values: ArrayRef },
31 Values { values: OffsetBuffer<V> },
32}
33
34impl<K: ArrowNativeType, V: OffsetSizeTrait> Default for DictionaryBuffer<K, V> {
35 fn default() -> Self {
36 Self::Values {
37 values: Default::default(),
38 }
39 }
40}
41
42impl<K: ArrowNativeType + Ord, V: OffsetSizeTrait> DictionaryBuffer<K, V> {
43 #[allow(unused)]
44 pub fn len(&self) -> usize {
45 match self {
46 Self::Dict { keys, .. } => keys.len(),
47 Self::Values { values } => values.len(),
48 }
49 }
50
51 pub fn as_keys(&mut self, dictionary: &ArrayRef) -> Option<&mut Vec<K>> {
59 assert!(K::from_usize(dictionary.len()).is_some());
60
61 match self {
62 Self::Dict { keys, values } => {
63 let values_ptr = values.as_ref() as *const _ as *const ();
67 let dict_ptr = dictionary.as_ref() as *const _ as *const ();
68 if values_ptr == dict_ptr {
69 Some(keys)
70 } else if keys.is_empty() {
71 *values = Arc::clone(dictionary);
72 Some(keys)
73 } else {
74 None
75 }
76 }
77 Self::Values { values } if values.is_empty() => {
78 *self = Self::Dict {
79 keys: Default::default(),
80 values: Arc::clone(dictionary),
81 };
82 match self {
83 Self::Dict { keys, .. } => Some(keys),
84 _ => unreachable!(),
85 }
86 }
87 _ => None,
88 }
89 }
90
91 pub fn spill_values(&mut self) -> Result<&mut OffsetBuffer<V>> {
96 match self {
97 Self::Values { values } => Ok(values),
98 Self::Dict { keys, values } => {
99 let mut spilled = OffsetBuffer::default();
100 let data = values.to_data();
101 let dict_buffers = data.buffers();
102 let dict_offsets = dict_buffers[0].typed_data::<V>();
103 let dict_values = dict_buffers[1].as_slice();
104
105 if values.is_empty() {
106 spilled.offsets.resize(keys.len() + 1, V::default());
108 } else {
109 spilled.extend_from_dictionary(keys.as_slice(), dict_offsets, dict_values)?;
115 }
116
117 *self = Self::Values { values: spilled };
118 match self {
119 Self::Values { values } => Ok(values),
120 _ => unreachable!(),
121 }
122 }
123 }
124 }
125
126 pub fn into_array(
128 self,
129 null_buffer: Option<Buffer>,
130 data_type: &ArrowType,
131 ) -> Result<ArrayRef> {
132 assert!(matches!(data_type, ArrowType::Dictionary(_, _)));
133
134 match self {
135 Self::Dict { keys, values } => {
136 if !values.is_empty() {
138 let min = K::from_usize(0).unwrap();
139 let max = K::from_usize(values.len()).unwrap();
140
141 if !keys
145 .as_slice()
146 .iter()
147 .copied()
148 .fold(true, |a, x| a && x >= min && x < max)
149 {
150 return Err(general_err!(
151 "dictionary key beyond bounds of dictionary: 0..{}",
152 values.len()
153 ));
154 }
155 }
156
157 let builder = ArrayDataBuilder::new(data_type.clone())
158 .len(keys.len())
159 .add_buffer(Buffer::from_vec(keys))
160 .add_child_data(values.into_data())
161 .null_bit_buffer(null_buffer);
162
163 let data = match cfg!(debug_assertions) {
164 true => builder.build().unwrap(),
165 false => unsafe { builder.build_unchecked() },
166 };
167
168 Ok(make_array(data))
169 }
170 Self::Values { values } => {
171 let value_type = match data_type {
172 ArrowType::Dictionary(_, v) => v.as_ref().clone(),
173 _ => unreachable!(),
174 };
175
176 let array =
178 arrow_cast::cast(&values.into_array(null_buffer, value_type), data_type)
179 .expect("cast should be infallible");
180
181 Ok(array)
182 }
183 }
184 }
185}
186
187impl<K: ArrowNativeType, V: OffsetSizeTrait> ValuesBuffer for DictionaryBuffer<K, V> {
188 fn pad_nulls(
189 &mut self,
190 read_offset: usize,
191 values_read: usize,
192 levels_read: usize,
193 valid_mask: &[u8],
194 ) {
195 match self {
196 Self::Dict { keys, .. } => {
197 keys.resize(read_offset + levels_read, K::default());
198 keys.pad_nulls(read_offset, values_read, levels_read, valid_mask)
199 }
200 Self::Values { values, .. } => {
201 values.pad_nulls(read_offset, values_read, levels_read, valid_mask)
202 }
203 }
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::compute::cast;
    use arrow_array::StringArray;

    /// Exercises the Dict -> Values transition, dictionary reuse after
    /// `std::mem::take`, and rejection of a different dictionary once keys exist.
    #[test]
    fn test_dictionary_buffer() {
        let dict_type =
            ArrowType::Dictionary(Box::new(ArrowType::Int32), Box::new(ArrowType::Utf8));

        let d1: ArrayRef = Arc::new(StringArray::from(vec!["hello", "world", "", "a", "b"]));

        let mut buffer = DictionaryBuffer::<i32, i32>::default();

        // Read 5 keys referencing the first dictionary
        let values = &[1, 0, 3, 2, 4];
        buffer.as_keys(&d1).unwrap().extend_from_slice(values);

        // Spread the 5 keys across 8 slots; slots where `valid` is false become null
        let mut valid = vec![false, false, true, true, false, true, true, true];
        let valid_buffer = Buffer::from_iter(valid.iter().cloned());
        buffer.pad_nulls(0, values.len(), valid.len(), valid_buffer.as_slice());

        // Spill the keys to plain values, then append two plain values directly
        let values = buffer.spill_values().unwrap();
        let read_offset = values.len();
        values.try_push("bingo".as_bytes(), false).unwrap();
        values.try_push("bongo".as_bytes(), false).unwrap();

        // Spread the 2 new values across 5 more slots
        valid.extend_from_slice(&[false, false, true, false, true]);
        let null_buffer = Buffer::from_iter(valid.iter().cloned());
        buffer.pad_nulls(read_offset, 2, 5, null_buffer.as_slice());

        assert_eq!(buffer.len(), 13);
        let split = std::mem::take(&mut buffer);

        let array = split.into_array(Some(null_buffer), &dict_type).unwrap();
        assert_eq!(array.data_type(), &dict_type);

        let strings = cast(&array, &ArrowType::Utf8).unwrap();
        let strings = strings.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(
            strings.iter().collect::<Vec<_>>(),
            vec![
                None,
                None,
                Some("world"),
                Some("hello"),
                None,
                Some("a"),
                Some(""),
                Some("b"),
                None,
                None,
                Some("bingo"),
                None,
                Some("bongo")
            ]
        );

        // `take` left an empty `Values` buffer behind, which can adopt a new dictionary
        assert!(matches!(&buffer, DictionaryBuffer::Values { .. }));
        assert_eq!(buffer.len(), 0);
        let d2 = Arc::new(StringArray::from(vec!["bingo", ""])) as ArrayRef;
        buffer
            .as_keys(&d2)
            .unwrap()
            .extend_from_slice(&[0, 1, 0, 1]);

        let array = std::mem::take(&mut buffer)
            .into_array(None, &dict_type)
            .unwrap();
        assert_eq!(array.data_type(), &dict_type);

        let strings = cast(&array, &ArrowType::Utf8).unwrap();
        let strings = strings.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(
            strings.iter().collect::<Vec<_>>(),
            vec![Some("bingo"), Some(""), Some("bingo"), Some("")]
        );

        // Again empty after `take`: attach a third dictionary and read keys into it
        assert!(matches!(&buffer, DictionaryBuffer::Values { .. }));
        assert_eq!(buffer.len(), 0);
        let d3 = Arc::new(StringArray::from(vec!["bongo"])) as ArrayRef;
        buffer.as_keys(&d3).unwrap().extend_from_slice(&[0, 0]);

        // Keys now reference d3, so a different dictionary cannot be attached
        let d4 = Arc::new(StringArray::from(vec!["bananas"])) as ArrayRef;
        assert!(buffer.as_keys(&d4).is_none());
    }

    /// Out-of-range keys must be reported by both `into_array` and `spill_values`.
    #[test]
    fn test_validates_keys() {
        let dict_type =
            ArrowType::Dictionary(Box::new(ArrowType::Int32), Box::new(ArrowType::Utf8));

        // Key 2 is out of bounds for a 2-entry dictionary
        let mut buffer = DictionaryBuffer::<i32, i32>::default();
        let d = Arc::new(StringArray::from(vec!["", "f"])) as ArrayRef;
        buffer.as_keys(&d).unwrap().extend_from_slice(&[0, 2, 0]);

        let err = buffer.into_array(None, &dict_type).unwrap_err().to_string();
        assert!(
            err.contains("dictionary key beyond bounds of dictionary: 0..2"),
            "{}",
            err
        );

        // Key 1 is out of bounds for a 1-entry dictionary
        let mut buffer = DictionaryBuffer::<i32, i32>::default();
        let d = Arc::new(StringArray::from(vec![""])) as ArrayRef;
        buffer.as_keys(&d).unwrap().extend_from_slice(&[0, 1, 0]);

        let err = buffer.spill_values().unwrap_err().to_string();
        assert!(
            err.contains("dictionary key beyond bounds of dictionary: 0..1"),
            "{}",
            err
        );
    }
}