Skip to main content

arrow_json/reader/
struct_array.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::reader::tape::{Tape, TapeElement};
19use crate::reader::{ArrayDecoder, DecoderContext, StructMode};
20use arrow_array::builder::BooleanBufferBuilder;
21use arrow_buffer::buffer::NullBuffer;
22use arrow_data::{ArrayData, ArrayDataBuilder};
23use arrow_schema::{ArrowError, DataType, Fields};
24use std::collections::HashMap;
25
/// Scratch table of tape positions holding one slot per (field, row) pair.
///
/// The table is kept in a single flat vector laid out field-major: the slot
/// for `(field_idx, row_idx)` lives at `field_idx * row_count + row_idx`.
/// A slot containing 0 means the field is absent for that row.
struct FieldTapePositions {
    /// Flat backing storage for every slot of the table.
    data: Vec<u32>,
    /// Number of rows stored per field; the stride between consecutive fields.
    row_count: usize,
}
32
33impl FieldTapePositions {
34    fn new() -> Self {
35        Self {
36            data: Vec::new(),
37            row_count: 0,
38        }
39    }
40
41    fn resize(&mut self, field_count: usize, row_count: usize) -> Result<(), ArrowError> {
42        let total_len = field_count.checked_mul(row_count).ok_or_else(|| {
43            ArrowError::JsonError(format!(
44                "FieldTapePositions buffer size overflow for rows={row_count} fields={field_count}"
45            ))
46        })?;
47        self.data.clear();
48        self.data.resize(total_len, 0);
49        self.row_count = row_count;
50        Ok(())
51    }
52
53    fn try_set(&mut self, field_idx: usize, row_idx: usize, pos: u32) -> Option<()> {
54        let idx = field_idx
55            .checked_mul(self.row_count)?
56            .checked_add(row_idx)?;
57        *self.data.get_mut(idx)? = pos;
58        Some(())
59    }
60
61    fn set(&mut self, field_idx: usize, row_idx: usize, pos: u32) {
62        self.data[field_idx * self.row_count + row_idx] = pos;
63    }
64
65    fn field_positions(&self, field_idx: usize) -> &[u32] {
66        let start = field_idx * self.row_count;
67        &self.data[start..start + self.row_count]
68    }
69}
70
71pub struct StructArrayDecoder {
72    data_type: DataType,
73    decoders: Vec<Box<dyn ArrayDecoder>>,
74    strict_mode: bool,
75    is_nullable: bool,
76    struct_mode: StructMode,
77    field_name_to_index: Option<HashMap<String, usize>>,
78    field_tape_positions: FieldTapePositions,
79}
80
81impl StructArrayDecoder {
82    pub fn new(
83        ctx: &DecoderContext,
84        data_type: &DataType,
85        is_nullable: bool,
86    ) -> Result<Self, ArrowError> {
87        let fields = struct_fields(data_type);
88        let decoders = fields
89            .iter()
90            .map(|f| {
91                // If this struct nullable, need to permit nullability in child array
92                // StructArrayDecoder::decode verifies that if the child is not nullable
93                // it doesn't contain any nulls not masked by its parent
94                let nullable = f.is_nullable() || is_nullable;
95                ctx.make_decoder(f.data_type(), nullable)
96            })
97            .collect::<Result<Vec<_>, ArrowError>>()?;
98
99        let struct_mode = ctx.struct_mode();
100        let field_name_to_index = if struct_mode == StructMode::ObjectOnly {
101            build_field_index(fields)
102        } else {
103            None
104        };
105
106        Ok(Self {
107            data_type: data_type.clone(),
108            decoders,
109            strict_mode: ctx.strict_mode(),
110            is_nullable,
111            struct_mode,
112            field_name_to_index,
113            field_tape_positions: FieldTapePositions::new(),
114        })
115    }
116}
117
118impl ArrayDecoder for StructArrayDecoder {
119    fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayData, ArrowError> {
120        let fields = struct_fields(&self.data_type);
121        let row_count = pos.len();
122        let field_count = fields.len();
123        self.field_tape_positions.resize(field_count, row_count)?;
124        let mut nulls = self
125            .is_nullable
126            .then(|| BooleanBufferBuilder::new(pos.len()));
127
128        {
129            // We avoid having the match on self.struct_mode inside the hot loop for performance
130            // TODO: Investigate how to extract duplicated logic.
131            match self.struct_mode {
132                StructMode::ObjectOnly => {
133                    for (row, p) in pos.iter().enumerate() {
134                        let end_idx = match (tape.get(*p), nulls.as_mut()) {
135                            (TapeElement::StartObject(end_idx), None) => end_idx,
136                            (TapeElement::StartObject(end_idx), Some(nulls)) => {
137                                nulls.append(true);
138                                end_idx
139                            }
140                            (TapeElement::Null, Some(nulls)) => {
141                                nulls.append(false);
142                                continue;
143                            }
144                            (_, _) => return Err(tape.error(*p, "{")),
145                        };
146
147                        let mut cur_idx = *p + 1;
148                        while cur_idx < end_idx {
149                            // Read field name
150                            let field_name = match tape.get(cur_idx) {
151                                TapeElement::String(s) => tape.get_string(s),
152                                _ => return Err(tape.error(cur_idx, "field name")),
153                            };
154
155                            // Update child pos if match found
156                            let field_idx = match &self.field_name_to_index {
157                                Some(map) => map.get(field_name).copied(),
158                                None => fields.iter().position(|x| x.name() == field_name),
159                            };
160                            match field_idx {
161                                Some(field_idx) => {
162                                    self.field_tape_positions.set(field_idx, row, cur_idx + 1);
163                                }
164                                None => {
165                                    if self.strict_mode {
166                                        return Err(ArrowError::JsonError(format!(
167                                            "column '{field_name}' missing from schema",
168                                        )));
169                                    }
170                                }
171                            }
172                            // Advance to next field
173                            cur_idx = tape.next(cur_idx + 1, "field value")?;
174                        }
175                    }
176                }
177                StructMode::ListOnly => {
178                    for (row, p) in pos.iter().enumerate() {
179                        let end_idx = match (tape.get(*p), nulls.as_mut()) {
180                            (TapeElement::StartList(end_idx), None) => end_idx,
181                            (TapeElement::StartList(end_idx), Some(nulls)) => {
182                                nulls.append(true);
183                                end_idx
184                            }
185                            (TapeElement::Null, Some(nulls)) => {
186                                nulls.append(false);
187                                continue;
188                            }
189                            (_, _) => return Err(tape.error(*p, "[")),
190                        };
191
192                        let mut cur_idx = *p + 1;
193                        let mut entry_idx = 0;
194                        while cur_idx < end_idx {
195                            self.field_tape_positions
196                                .try_set(entry_idx, row, cur_idx)
197                                .ok_or_else(|| {
198                                    ArrowError::JsonError(format!(
199                                        "found extra columns for {} fields",
200                                        fields.len()
201                                    ))
202                                })?;
203                            entry_idx += 1;
204                            // Advance to next field
205                            cur_idx = tape.next(cur_idx, "field value")?;
206                        }
207                        if entry_idx != fields.len() {
208                            return Err(ArrowError::JsonError(format!(
209                                "found {} columns for {} fields",
210                                entry_idx,
211                                fields.len()
212                            )));
213                        }
214                    }
215                }
216            }
217        }
218
219        let child_data = self
220            .decoders
221            .iter_mut()
222            .enumerate()
223            .zip(fields)
224            .map(|((field_idx, d), f)| {
225                let pos = self.field_tape_positions.field_positions(field_idx);
226                d.decode(tape, pos).map_err(|e| match e {
227                    ArrowError::JsonError(s) => {
228                        ArrowError::JsonError(format!("whilst decoding field '{}': {s}", f.name()))
229                    }
230                    e => e,
231                })
232            })
233            .collect::<Result<Vec<_>, ArrowError>>()?;
234
235        let nulls = nulls.as_mut().map(|x| NullBuffer::new(x.finish()));
236
237        for (c, f) in child_data.iter().zip(fields) {
238            // Sanity check
239            assert_eq!(c.len(), pos.len());
240            if let Some(a) = c.nulls() {
241                let nulls_valid =
242                    f.is_nullable() || nulls.as_ref().map(|n| n.contains(a)).unwrap_or_default();
243
244                if !nulls_valid {
245                    return Err(ArrowError::JsonError(format!(
246                        "Encountered unmasked nulls in non-nullable StructArray child: {f}"
247                    )));
248                }
249            }
250        }
251
252        let data = ArrayDataBuilder::new(self.data_type.clone())
253            .len(pos.len())
254            .nulls(nulls)
255            .child_data(child_data);
256
257        // Safety
258        // Validated lengths above
259        Ok(unsafe { data.build_unchecked() })
260    }
261}
262
263fn struct_fields(data_type: &DataType) -> &Fields {
264    match &data_type {
265        DataType::Struct(f) => f,
266        _ => unreachable!(),
267    }
268}
269
270fn build_field_index(fields: &Fields) -> Option<HashMap<String, usize>> {
271    // Heuristic threshold: for small field counts, linear scan avoids HashMap overhead.
272    const FIELD_INDEX_LINEAR_THRESHOLD: usize = 16;
273    if fields.len() < FIELD_INDEX_LINEAR_THRESHOLD {
274        return None;
275    }
276
277    let mut map = HashMap::with_capacity(fields.len());
278    for (idx, field) in fields.iter().enumerate() {
279        let name = field.name();
280        if !map.contains_key(name) {
281            map.insert(name.to_string(), idx);
282        }
283    }
284    Some(map)
285}