Skip to main content

arrow_json/reader/
map_array.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::reader::tape::{Tape, TapeElement};
19use crate::reader::{ArrayDecoder, DecoderContext};
20use arrow_array::builder::{BooleanBufferBuilder, BufferBuilder};
21use arrow_buffer::ArrowNativeType;
22use arrow_buffer::buffer::NullBuffer;
23use arrow_data::{ArrayData, ArrayDataBuilder};
24use arrow_schema::{ArrowError, DataType};
25
26pub struct MapArrayDecoder {
27    data_type: DataType,
28    keys: Box<dyn ArrayDecoder>,
29    values: Box<dyn ArrayDecoder>,
30    is_nullable: bool,
31}
32
33impl MapArrayDecoder {
34    pub fn new(
35        ctx: &DecoderContext,
36        data_type: &DataType,
37        is_nullable: bool,
38    ) -> Result<Self, ArrowError> {
39        let fields = match data_type {
40            DataType::Map(_, true) => {
41                return Err(ArrowError::NotYetImplemented(
42                    "Decoding MapArray with sorted fields".to_string(),
43                ));
44            }
45            DataType::Map(f, _) => match f.data_type() {
46                DataType::Struct(fields) if fields.len() == 2 => fields,
47                d => {
48                    return Err(ArrowError::InvalidArgumentError(format!(
49                        "MapArray must contain struct with two fields, got {d}"
50                    )));
51                }
52            },
53            _ => unreachable!(),
54        };
55
56        let keys = ctx.make_decoder(fields[0].data_type(), fields[0].is_nullable())?;
57        let values = ctx.make_decoder(fields[1].data_type(), fields[1].is_nullable())?;
58
59        Ok(Self {
60            data_type: data_type.clone(),
61            keys,
62            values,
63            is_nullable,
64        })
65    }
66}
67
68impl ArrayDecoder for MapArrayDecoder {
69    fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayData, ArrowError> {
70        let s = match &self.data_type {
71            DataType::Map(f, _) => match f.data_type() {
72                s @ DataType::Struct(_) => s,
73                _ => unreachable!(),
74            },
75            _ => unreachable!(),
76        };
77
78        let mut offsets = BufferBuilder::<i32>::new(pos.len() + 1);
79        offsets.append(0);
80
81        let mut key_pos = Vec::with_capacity(pos.len());
82        let mut value_pos = Vec::with_capacity(pos.len());
83
84        let mut nulls = self
85            .is_nullable
86            .then(|| BooleanBufferBuilder::new(pos.len()));
87
88        for p in pos.iter().copied() {
89            let end_idx = match (tape.get(p), nulls.as_mut()) {
90                (TapeElement::StartObject(end_idx), None) => end_idx,
91                (TapeElement::StartObject(end_idx), Some(nulls)) => {
92                    nulls.append(true);
93                    end_idx
94                }
95                (TapeElement::Null, Some(nulls)) => {
96                    nulls.append(false);
97                    p + 1
98                }
99                _ => return Err(tape.error(p, "{")),
100            };
101
102            let mut cur_idx = p + 1;
103            while cur_idx < end_idx {
104                let key = cur_idx;
105                let value = tape.next(key, "map key")?;
106                cur_idx = tape.next(value, "map value")?;
107
108                key_pos.push(key);
109                value_pos.push(value);
110            }
111
112            let offset = i32::from_usize(key_pos.len()).ok_or_else(|| {
113                ArrowError::JsonError(format!("offset overflow decoding {}", self.data_type))
114            })?;
115            offsets.append(offset)
116        }
117
118        assert_eq!(key_pos.len(), value_pos.len());
119
120        let key_data = self.keys.decode(tape, &key_pos)?;
121        let value_data = self.values.decode(tape, &value_pos)?;
122
123        let struct_data = ArrayDataBuilder::new(s.clone())
124            .len(key_pos.len())
125            .child_data(vec![key_data, value_data]);
126
127        // Safety:
128        // Valid by construction
129        let struct_data = unsafe { struct_data.build_unchecked() };
130
131        let nulls = nulls.as_mut().map(|x| NullBuffer::new(x.finish()));
132
133        let builder = ArrayDataBuilder::new(self.data_type.clone())
134            .len(pos.len())
135            .buffers(vec![offsets.finish()])
136            .nulls(nulls)
137            .child_data(vec![struct_data]);
138
139        // Safety:
140        // Valid by construction
141        Ok(unsafe { builder.build_unchecked() })
142    }
143}