// arrow_json/reader/map_array.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

18use crate::reader::tape::{Tape, TapeElement};
19use crate::reader::{make_decoder, ArrayDecoder};
20use crate::StructMode;
21use arrow_array::builder::{BooleanBufferBuilder, BufferBuilder};
22use arrow_buffer::buffer::NullBuffer;
23use arrow_buffer::ArrowNativeType;
24use arrow_data::{ArrayData, ArrayDataBuilder};
25use arrow_schema::{ArrowError, DataType};
26
27pub struct MapArrayDecoder {
28    data_type: DataType,
29    keys: Box<dyn ArrayDecoder>,
30    values: Box<dyn ArrayDecoder>,
31    is_nullable: bool,
32}
33
34impl MapArrayDecoder {
35    pub fn new(
36        data_type: DataType,
37        coerce_primitive: bool,
38        strict_mode: bool,
39        is_nullable: bool,
40        struct_mode: StructMode,
41    ) -> Result<Self, ArrowError> {
42        let fields = match &data_type {
43            DataType::Map(_, true) => {
44                return Err(ArrowError::NotYetImplemented(
45                    "Decoding MapArray with sorted fields".to_string(),
46                ))
47            }
48            DataType::Map(f, _) => match f.data_type() {
49                DataType::Struct(fields) if fields.len() == 2 => fields,
50                d => {
51                    return Err(ArrowError::InvalidArgumentError(format!(
52                        "MapArray must contain struct with two fields, got {d}"
53                    )))
54                }
55            },
56            _ => unreachable!(),
57        };
58
59        let keys = make_decoder(
60            fields[0].data_type().clone(),
61            coerce_primitive,
62            strict_mode,
63            fields[0].is_nullable(),
64            struct_mode,
65        )?;
66        let values = make_decoder(
67            fields[1].data_type().clone(),
68            coerce_primitive,
69            strict_mode,
70            fields[1].is_nullable(),
71            struct_mode,
72        )?;
73
74        Ok(Self {
75            data_type,
76            keys,
77            values,
78            is_nullable,
79        })
80    }
81}
82
83impl ArrayDecoder for MapArrayDecoder {
84    fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayData, ArrowError> {
85        let s = match &self.data_type {
86            DataType::Map(f, _) => match f.data_type() {
87                s @ DataType::Struct(_) => s,
88                _ => unreachable!(),
89            },
90            _ => unreachable!(),
91        };
92
93        let mut offsets = BufferBuilder::<i32>::new(pos.len() + 1);
94        offsets.append(0);
95
96        let mut key_pos = Vec::with_capacity(pos.len());
97        let mut value_pos = Vec::with_capacity(pos.len());
98
99        let mut nulls = self
100            .is_nullable
101            .then(|| BooleanBufferBuilder::new(pos.len()));
102
103        for p in pos.iter().copied() {
104            let end_idx = match (tape.get(p), nulls.as_mut()) {
105                (TapeElement::StartObject(end_idx), None) => end_idx,
106                (TapeElement::StartObject(end_idx), Some(nulls)) => {
107                    nulls.append(true);
108                    end_idx
109                }
110                (TapeElement::Null, Some(nulls)) => {
111                    nulls.append(false);
112                    p + 1
113                }
114                _ => return Err(tape.error(p, "{")),
115            };
116
117            let mut cur_idx = p + 1;
118            while cur_idx < end_idx {
119                let key = cur_idx;
120                let value = tape.next(key, "map key")?;
121                cur_idx = tape.next(value, "map value")?;
122
123                key_pos.push(key);
124                value_pos.push(value);
125            }
126
127            let offset = i32::from_usize(key_pos.len()).ok_or_else(|| {
128                ArrowError::JsonError(format!("offset overflow decoding {}", self.data_type))
129            })?;
130            offsets.append(offset)
131        }
132
133        assert_eq!(key_pos.len(), value_pos.len());
134
135        let key_data = self.keys.decode(tape, &key_pos)?;
136        let value_data = self.values.decode(tape, &value_pos)?;
137
138        let struct_data = ArrayDataBuilder::new(s.clone())
139            .len(key_pos.len())
140            .child_data(vec![key_data, value_data]);
141
142        // Safety:
143        // Valid by construction
144        let struct_data = unsafe { struct_data.build_unchecked() };
145
146        let nulls = nulls.as_mut().map(|x| NullBuffer::new(x.finish()));
147
148        let builder = ArrayDataBuilder::new(self.data_type.clone())
149            .len(pos.len())
150            .buffers(vec![offsets.finish()])
151            .nulls(nulls)
152            .child_data(vec![struct_data]);
153
154        // Safety:
155        // Valid by construction
156        Ok(unsafe { builder.build_unchecked() })
157    }
158}