Skip to main content

arrow_array/builder/
map_builder.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::builder::ArrayBuilder;
19use crate::{Array, ArrayRef, MapArray, StructArray};
20use arrow_buffer::Buffer;
21use arrow_buffer::{NullBuffer, NullBufferBuilder};
22use arrow_data::ArrayData;
23use arrow_schema::{ArrowError, DataType, Field, FieldRef};
24use std::any::Any;
25use std::sync::Arc;
26
27/// Builder for [`MapArray`]
28///
29/// ```
30/// # use arrow_array::builder::{Int32Builder, MapBuilder, StringBuilder};
31/// # use arrow_array::{Int32Array, StringArray};
32///
33/// let string_builder = StringBuilder::new();
34/// let int_builder = Int32Builder::with_capacity(4);
35///
36/// // Construct `[{"joe": 1}, {"blogs": 2, "foo": 4}, {}, null]`
37/// let mut builder = MapBuilder::new(None, string_builder, int_builder);
38///
39/// builder.keys().append_value("joe");
40/// builder.values().append_value(1);
41/// builder.append(true).unwrap();
42///
43/// builder.keys().append_value("blogs");
44/// builder.values().append_value(2);
45/// builder.keys().append_value("foo");
46/// builder.values().append_value(4);
47/// builder.append(true).unwrap();
48/// builder.append(true).unwrap();
49/// builder.append(false).unwrap();
50///
51/// let array = builder.finish();
52/// assert_eq!(array.value_offsets(), &[0, 1, 3, 3, 3]);
53/// assert_eq!(array.values().as_ref(), &Int32Array::from(vec![1, 2, 4]));
54/// assert_eq!(array.keys().as_ref(), &StringArray::from(vec!["joe", "blogs", "foo"]));
55///
56/// ```
#[derive(Debug)]
pub struct MapBuilder<K: ArrayBuilder, V: ArrayBuilder> {
    /// Offsets delimiting the map slots; created holding a single `0`, after which
    /// every appended slot pushes the key builder's length at that moment
    offsets_builder: Vec<i32>,
    /// Validity builder receiving one entry per appended map slot
    null_buffer_builder: NullBufferBuilder,
    /// Field names: `entry` always names the entries field, while `key` and `value`
    /// are only used when no explicit key / value field has been set
    field_names: MapFieldNames,
    /// Builder for the keys child array
    key_builder: K,
    /// Builder for the values child array
    value_builder: V,
    /// Explicit key field, if one was provided through `with_keys_field`
    key_field: Option<FieldRef>,
    /// Explicit value field, if one was provided through `with_values_field`
    value_field: Option<FieldRef>,
}
67
/// Names given to the [`Field`]s that make up a [`MapArray`]
#[derive(Debug, Clone)]
pub struct MapFieldNames {
    /// Name of the [`Field`] holding the map entries
    pub entry: String,
    /// Name of the [`Field`] holding the map keys
    pub key: String,
    /// Name of the [`Field`] holding the map values
    pub value: String,
}

impl Default for MapFieldNames {
    /// Returns the names `entries`, `keys` and `values`
    fn default() -> Self {
        let entry = String::from("entries");
        let key = String::from("keys");
        let value = String::from("values");
        Self { entry, key, value }
    }
}
88
89impl<K: ArrayBuilder, V: ArrayBuilder> MapBuilder<K, V> {
90    /// Creates a new `MapBuilder`
91    pub fn new(field_names: Option<MapFieldNames>, key_builder: K, value_builder: V) -> Self {
92        let capacity = key_builder.len();
93        Self::with_capacity(field_names, key_builder, value_builder, capacity)
94    }
95
96    /// Creates a new `MapBuilder` with capacity
97    pub fn with_capacity(
98        field_names: Option<MapFieldNames>,
99        key_builder: K,
100        value_builder: V,
101        capacity: usize,
102    ) -> Self {
103        let mut offsets_builder = Vec::with_capacity(capacity + 1);
104        offsets_builder.push(0);
105        Self {
106            offsets_builder,
107            null_buffer_builder: NullBufferBuilder::new(capacity),
108            field_names: field_names.unwrap_or_default(),
109            key_builder,
110            value_builder,
111            key_field: None,
112            value_field: None,
113        }
114    }
115
116    /// Override the field passed to [`MapBuilder::new`]
117    ///
118    /// By default, a non-nullable field is created with the name `keys`
119    ///
120    /// Note: [`Self::finish`] and [`Self::finish_cloned`] will panic if the
121    /// field's data type does not match that of `K` or the field is nullable
122    pub fn with_keys_field(self, field: impl Into<FieldRef>) -> Self {
123        Self {
124            key_field: Some(field.into()),
125            ..self
126        }
127    }
128
129    /// Override the field passed to [`MapBuilder::new`]
130    ///
131    /// By default, a nullable field is created with the name `values`
132    ///
133    /// Note: [`Self::finish`] and [`Self::finish_cloned`] will panic if the
134    /// field's data type does not match that of `V`
135    pub fn with_values_field(self, field: impl Into<FieldRef>) -> Self {
136        Self {
137            value_field: Some(field.into()),
138            ..self
139        }
140    }
141
142    /// Returns the key array builder of the map
143    pub fn keys(&mut self) -> &mut K {
144        &mut self.key_builder
145    }
146
147    /// Returns the value array builder of the map
148    pub fn values(&mut self) -> &mut V {
149        &mut self.value_builder
150    }
151
152    /// Returns both the key and value array builders of the map
153    pub fn entries(&mut self) -> (&mut K, &mut V) {
154        (&mut self.key_builder, &mut self.value_builder)
155    }
156
157    /// Validates that key and value builders have equal lengths.
158    #[inline]
159    fn validate_equal_lengths(&self) -> Result<(), ArrowError> {
160        if self.key_builder.len() != self.value_builder.len() {
161            return Err(ArrowError::InvalidArgumentError(format!(
162                "Cannot append to a map builder when its keys and values have unequal lengths of {} and {}",
163                self.key_builder.len(),
164                self.value_builder.len()
165            )));
166        }
167        Ok(())
168    }
169
170    /// Finish the current map array slot
171    ///
172    /// Returns an error if the key and values builders are in an inconsistent state.
173    #[inline]
174    pub fn append(&mut self, is_valid: bool) -> Result<(), ArrowError> {
175        self.validate_equal_lengths()?;
176        self.offsets_builder.push(self.key_builder.len() as i32);
177        self.null_buffer_builder.append(is_valid);
178        Ok(())
179    }
180
181    /// Append `n` nulls to this [`MapBuilder`]
182    ///
183    /// Returns an error if the key and values builders are in an inconsistent state.
184    #[inline]
185    pub fn append_nulls(&mut self, n: usize) -> Result<(), ArrowError> {
186        self.validate_equal_lengths()?;
187        let offset = self.key_builder.len() as i32;
188        self.offsets_builder.extend(std::iter::repeat_n(offset, n));
189        self.null_buffer_builder.append_n_nulls(n);
190        Ok(())
191    }
192
193    /// Builds the [`MapArray`]
194    pub fn finish(&mut self) -> MapArray {
195        let len = self.len();
196        // Build the keys
197        let keys_arr = self.key_builder.finish();
198        let values_arr = self.value_builder.finish();
199        let offset_buffer = Buffer::from_vec(std::mem::take(&mut self.offsets_builder));
200        self.offsets_builder.push(0);
201        let null_bit_buffer = self.null_buffer_builder.finish();
202
203        self.finish_helper(keys_arr, values_arr, offset_buffer, null_bit_buffer, len)
204    }
205
206    /// Builds the [`MapArray`] without resetting the builder.
207    pub fn finish_cloned(&self) -> MapArray {
208        let len = self.len();
209        // Build the keys
210        let keys_arr = self.key_builder.finish_cloned();
211        let values_arr = self.value_builder.finish_cloned();
212        let offset_buffer = Buffer::from_slice_ref(self.offsets_builder.as_slice());
213        let nulls = self.null_buffer_builder.finish_cloned();
214        self.finish_helper(keys_arr, values_arr, offset_buffer, nulls, len)
215    }
216
217    fn finish_helper(
218        &self,
219        keys_arr: Arc<dyn Array>,
220        values_arr: Arc<dyn Array>,
221        offset_buffer: Buffer,
222        nulls: Option<NullBuffer>,
223        len: usize,
224    ) -> MapArray {
225        assert!(
226            keys_arr.null_count() == 0,
227            "Keys array must have no null values, found {} null value(s)",
228            keys_arr.null_count()
229        );
230
231        let keys_field = match &self.key_field {
232            Some(f) => {
233                assert!(!f.is_nullable(), "Keys field must not be nullable");
234                f.clone()
235            }
236            None => Arc::new(Field::new(
237                self.field_names.key.as_str(),
238                keys_arr.data_type().clone(),
239                false, // always non-nullable
240            )),
241        };
242        let values_field = match &self.value_field {
243            Some(f) => f.clone(),
244            None => Arc::new(Field::new(
245                self.field_names.value.as_str(),
246                values_arr.data_type().clone(),
247                true,
248            )),
249        };
250
251        let struct_array =
252            StructArray::from(vec![(keys_field, keys_arr), (values_field, values_arr)]);
253
254        let map_field = Arc::new(Field::new(
255            self.field_names.entry.as_str(),
256            struct_array.data_type().clone(),
257            false, // always non-nullable
258        ));
259        let array_data = ArrayData::builder(DataType::Map(map_field, false)) // TODO: support sorted keys
260            .len(len)
261            .add_buffer(offset_buffer)
262            .add_child_data(struct_array.into_data())
263            .nulls(nulls);
264
265        let array_data = unsafe { array_data.build_unchecked() };
266
267        MapArray::from(array_data)
268    }
269
270    /// Returns the current null buffer as a slice
271    pub fn validity_slice(&self) -> Option<&[u8]> {
272        self.null_buffer_builder.as_slice()
273    }
274}
275
276impl<K: ArrayBuilder, V: ArrayBuilder> ArrayBuilder for MapBuilder<K, V> {
277    fn len(&self) -> usize {
278        self.null_buffer_builder.len()
279    }
280
281    fn finish(&mut self) -> ArrayRef {
282        Arc::new(self.finish())
283    }
284
285    /// Builds the array without resetting the builder.
286    fn finish_cloned(&self) -> ArrayRef {
287        Arc::new(self.finish_cloned())
288    }
289
290    fn as_any(&self) -> &dyn Any {
291        self
292    }
293
294    fn as_any_mut(&mut self) -> &mut dyn Any {
295        self
296    }
297
298    fn into_box_any(self: Box<Self>) -> Box<dyn Any> {
299        self
300    }
301}
302
303#[cfg(test)]
304mod tests {
305    use super::*;
306    use crate::builder::{Int32Builder, StringBuilder, make_builder};
307    use crate::{Int32Array, StringArray};
308    use std::collections::HashMap;
309
310    #[test]
311    #[should_panic(expected = "Keys array must have no null values, found 1 null value(s)")]
312    fn test_map_builder_with_null_keys_panics() {
313        let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());
314        builder.keys().append_null();
315        builder.values().append_value(42);
316        builder.append(true).unwrap();
317
318        builder.finish();
319    }
320
321    #[test]
322    fn test_boxed_map_builder() {
323        let keys_builder = make_builder(&DataType::Utf8, 5);
324        let values_builder = make_builder(&DataType::Int32, 5);
325
326        let mut builder = MapBuilder::new(None, keys_builder, values_builder);
327        builder
328            .keys()
329            .as_any_mut()
330            .downcast_mut::<StringBuilder>()
331            .expect("should be an StringBuilder")
332            .append_value("1");
333        builder
334            .values()
335            .as_any_mut()
336            .downcast_mut::<Int32Builder>()
337            .expect("should be an Int32Builder")
338            .append_value(42);
339        builder.append(true).unwrap();
340
341        let map_array = builder.finish();
342
343        assert_eq!(
344            map_array
345                .keys()
346                .as_any()
347                .downcast_ref::<StringArray>()
348                .expect("should be an StringArray")
349                .value(0),
350            "1"
351        );
352        assert_eq!(
353            map_array
354                .values()
355                .as_any()
356                .downcast_ref::<Int32Array>()
357                .expect("should be an Int32Array")
358                .value(0),
359            42
360        );
361    }
362
363    #[test]
364    fn test_with_values_field() {
365        let value_field = Arc::new(Field::new("bars", DataType::Int32, false));
366        let mut builder = MapBuilder::new(None, Int32Builder::new(), Int32Builder::new())
367            .with_values_field(value_field.clone());
368        builder.keys().append_value(1);
369        builder.values().append_value(2);
370        builder.append(true).unwrap();
371        builder.append(false).unwrap(); // This is fine as nullability refers to nullability of values
372        builder.keys().append_value(3);
373        builder.values().append_value(4);
374        builder.append(true).unwrap();
375        let map = builder.finish();
376
377        assert_eq!(map.len(), 3);
378        assert_eq!(
379            map.data_type(),
380            &DataType::Map(
381                Arc::new(Field::new(
382                    "entries",
383                    DataType::Struct(
384                        vec![
385                            Arc::new(Field::new("keys", DataType::Int32, false)),
386                            value_field.clone()
387                        ]
388                        .into()
389                    ),
390                    false,
391                )),
392                false
393            )
394        );
395
396        builder.keys().append_value(5);
397        builder.values().append_value(6);
398        builder.append(true).unwrap();
399        let map = builder.finish();
400
401        assert_eq!(map.len(), 1);
402        assert_eq!(
403            map.data_type(),
404            &DataType::Map(
405                Arc::new(Field::new(
406                    "entries",
407                    DataType::Struct(
408                        vec![
409                            Arc::new(Field::new("keys", DataType::Int32, false)),
410                            value_field
411                        ]
412                        .into()
413                    ),
414                    false,
415                )),
416                false
417            )
418        );
419    }
420
421    #[test]
422    fn test_with_keys_field() {
423        let mut key_metadata = HashMap::new();
424        key_metadata.insert("foo".to_string(), "bar".to_string());
425        let key_field = Arc::new(
426            Field::new("keys", DataType::Int32, false).with_metadata(key_metadata.clone()),
427        );
428        let mut builder = MapBuilder::new(None, Int32Builder::new(), Int32Builder::new())
429            .with_keys_field(key_field.clone());
430        builder.keys().append_value(1);
431        builder.values().append_value(2);
432        builder.append(true).unwrap();
433        let map = builder.finish();
434
435        assert_eq!(map.len(), 1);
436        assert_eq!(
437            map.data_type(),
438            &DataType::Map(
439                Arc::new(Field::new(
440                    "entries",
441                    DataType::Struct(
442                        vec![
443                            Arc::new(
444                                Field::new("keys", DataType::Int32, false)
445                                    .with_metadata(key_metadata)
446                            ),
447                            Arc::new(Field::new("values", DataType::Int32, true))
448                        ]
449                        .into()
450                    ),
451                    false,
452                )),
453                false
454            )
455        );
456    }
457
458    #[test]
459    fn test_append_nulls() {
460        let mut builder = MapBuilder::new(None, Int32Builder::new(), Int32Builder::new());
461
462        builder.keys().append_value(1);
463        builder.values().append_value(100);
464        builder.append(true).unwrap();
465
466        builder.append_nulls(3).unwrap();
467
468        builder.keys().append_value(2);
469        builder.values().append_value(200);
470        builder.append(true).unwrap();
471
472        let map = builder.finish();
473        assert_eq!(map.len(), 5);
474        assert_eq!(map.null_count(), 3);
475        assert!(map.is_valid(0));
476        assert!(map.is_null(1));
477        assert!(map.is_null(2));
478        assert!(map.is_null(3));
479        assert!(map.is_valid(4));
480        assert_eq!(map.value_offsets(), &[0, 1, 1, 1, 1, 2]);
481    }
482
483    #[test]
484    fn test_append_nulls_inconsistent_state() {
485        let mut builder = MapBuilder::new(None, Int32Builder::new(), Int32Builder::new());
486        // Add a key without a matching value
487        builder.keys().append_value(1);
488
489        let result = builder.append_nulls(2);
490        assert!(result.is_err());
491        assert!(result.unwrap_err().to_string().contains("unequal lengths"));
492    }
493
494    #[test]
495    #[should_panic(expected = "Keys field must not be nullable")]
496    fn test_with_nullable_keys_field() {
497        let mut builder = MapBuilder::new(None, Int32Builder::new(), Int32Builder::new())
498            .with_keys_field(Arc::new(Field::new("keys", DataType::Int32, true)));
499
500        builder.keys().append_value(1);
501        builder.values().append_value(2);
502        builder.append(true).unwrap();
503
504        builder.finish();
505    }
506
507    #[test]
508    #[should_panic(expected = "Incorrect datatype")]
509    fn test_keys_field_type_mismatch() {
510        let mut builder = MapBuilder::new(None, Int32Builder::new(), Int32Builder::new())
511            .with_keys_field(Arc::new(Field::new("keys", DataType::Utf8, false)));
512
513        builder.keys().append_value(1);
514        builder.values().append_value(2);
515        builder.append(true).unwrap();
516
517        builder.finish();
518    }
519}