Skip to main content

arrow_json/reader/
map_array.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

18use std::sync::Arc;
19
20use arrow_array::builder::BufferBuilder;
21use arrow_array::{ArrayRef, MapArray, StructArray};
22use arrow_buffer::{ArrowNativeType, NullBufferBuilder, OffsetBuffer, ScalarBuffer};
23use arrow_schema::{ArrowError, DataType, FieldRef, Fields};
24
25use crate::reader::tape::{Tape, TapeElement};
26use crate::reader::{ArrayDecoder, DecoderContext};
27
28pub struct MapArrayDecoder {
29    entries_field: FieldRef,
30    key_value_fields: Fields,
31    ordered: bool,
32    keys: Box<dyn ArrayDecoder>,
33    values: Box<dyn ArrayDecoder>,
34    ignore_type_conflicts: bool,
35    is_nullable: bool,
36}
37
38impl MapArrayDecoder {
39    pub fn new(
40        ctx: &DecoderContext,
41        data_type: &DataType,
42        is_nullable: bool,
43    ) -> Result<Self, ArrowError> {
44        let (entries_field, ordered) = match data_type {
45            DataType::Map(_, true) => {
46                return Err(ArrowError::NotYetImplemented(
47                    "Decoding MapArray with sorted fields".to_string(),
48                ));
49            }
50            DataType::Map(f, ordered) => (f.clone(), *ordered),
51            _ => unreachable!(),
52        };
53
54        let key_value_fields = match entries_field.data_type() {
55            DataType::Struct(fields) if fields.len() == 2 => fields.clone(),
56            d => {
57                return Err(ArrowError::InvalidArgumentError(format!(
58                    "MapArray must contain struct with two fields, got {d}"
59                )));
60            }
61        };
62
63        let keys = ctx.make_decoder(
64            key_value_fields[0].data_type(),
65            key_value_fields[0].is_nullable(),
66        )?;
67        let values = ctx.make_decoder(
68            key_value_fields[1].data_type(),
69            key_value_fields[1].is_nullable(),
70        )?;
71
72        Ok(Self {
73            entries_field,
74            key_value_fields,
75            ordered,
76            keys,
77            values,
78            ignore_type_conflicts: ctx.ignore_type_conflicts(),
79            is_nullable,
80        })
81    }
82}
83
84impl ArrayDecoder for MapArrayDecoder {
85    fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayRef, ArrowError> {
86        let mut offsets = BufferBuilder::<i32>::new(pos.len() + 1);
87        offsets.append(0);
88
89        let mut key_pos = Vec::with_capacity(pos.len());
90        let mut value_pos = Vec::with_capacity(pos.len());
91
92        let mut nulls = self.is_nullable.then(|| NullBufferBuilder::new(pos.len()));
93
94        for p in pos.iter().copied() {
95            let end_idx = match (tape.get(p), nulls.as_mut()) {
96                (TapeElement::StartObject(end_idx), None) => end_idx,
97                (TapeElement::StartObject(end_idx), Some(nulls)) => {
98                    nulls.append_non_null();
99                    end_idx
100                }
101                (TapeElement::Null, Some(nulls)) => {
102                    nulls.append_null();
103                    p + 1
104                }
105                (_, Some(nulls)) if self.ignore_type_conflicts => {
106                    nulls.append_null();
107                    p + 1
108                }
109                _ => return Err(tape.error(p, "{")),
110            };
111
112            let mut cur_idx = p + 1;
113            while cur_idx < end_idx {
114                let key = cur_idx;
115                let value = tape.next(key, "map key")?;
116                cur_idx = tape.next(value, "map value")?;
117
118                key_pos.push(key);
119                value_pos.push(value);
120            }
121
122            let offset = i32::from_usize(key_pos.len()).ok_or_else(|| {
123                ArrowError::JsonError("offset overflow decoding MapArray".to_string())
124            })?;
125            offsets.append(offset)
126        }
127
128        assert_eq!(key_pos.len(), value_pos.len());
129
130        let key_array = self.keys.decode(tape, &key_pos)?;
131        let value_array = self.values.decode(tape, &value_pos)?;
132
133        // SAFETY: fields/arrays match the schema, lengths are equal, no nulls
134        let entries = unsafe {
135            StructArray::new_unchecked_with_length(
136                self.key_value_fields.clone(),
137                vec![key_array, value_array],
138                None,
139                key_pos.len(),
140            )
141        };
142
143        let nulls = nulls.as_mut().and_then(|x| x.finish());
144        // SAFETY: offsets are built monotonically starting from 0
145        let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets.finish())) };
146
147        let array = MapArray::try_new(
148            self.entries_field.clone(),
149            offsets,
150            entries,
151            nulls,
152            self.ordered,
153        )?;
154        Ok(Arc::new(array))
155    }
156}