Skip to main content

arrow_json/reader/
struct_array.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::collections::HashMap;
19use std::sync::Arc;
20
21use arrow_array::builder::BooleanBufferBuilder;
22use arrow_array::{Array, ArrayRef, StructArray};
23use arrow_buffer::buffer::NullBuffer;
24use arrow_schema::{ArrowError, DataType, Fields};
25
26use crate::reader::tape::{Tape, TapeElement};
27use crate::reader::{ArrayDecoder, DecoderContext, StructMode};
28
29/// Reusable buffer for tape positions, indexed by (field_idx, row_idx).
30/// A value of 0 indicates the field is absent for that row.
31struct FieldTapePositions {
32    data: Vec<u32>,
33    row_count: usize,
34}
35
36impl FieldTapePositions {
37    fn new() -> Self {
38        Self {
39            data: Vec::new(),
40            row_count: 0,
41        }
42    }
43
44    fn resize(&mut self, field_count: usize, row_count: usize) -> Result<(), ArrowError> {
45        let total_len = field_count.checked_mul(row_count).ok_or_else(|| {
46            ArrowError::JsonError(format!(
47                "FieldTapePositions buffer size overflow for rows={row_count} fields={field_count}"
48            ))
49        })?;
50        self.data.clear();
51        self.data.resize(total_len, 0);
52        self.row_count = row_count;
53        Ok(())
54    }
55
56    fn try_set(&mut self, field_idx: usize, row_idx: usize, pos: u32) -> Option<()> {
57        let idx = field_idx
58            .checked_mul(self.row_count)?
59            .checked_add(row_idx)?;
60        *self.data.get_mut(idx)? = pos;
61        Some(())
62    }
63
64    fn set(&mut self, field_idx: usize, row_idx: usize, pos: u32) {
65        self.data[field_idx * self.row_count + row_idx] = pos;
66    }
67
68    fn field_positions(&self, field_idx: usize) -> &[u32] {
69        let start = field_idx * self.row_count;
70        &self.data[start..start + self.row_count]
71    }
72}
73
74pub struct StructArrayDecoder {
75    data_type: DataType,
76    decoders: Vec<Box<dyn ArrayDecoder>>,
77    strict_mode: bool,
78    is_nullable: bool,
79    struct_mode: StructMode,
80    field_name_to_index: Option<HashMap<String, usize>>,
81    field_tape_positions: FieldTapePositions,
82}
83
84impl StructArrayDecoder {
85    pub fn new(
86        ctx: &DecoderContext,
87        data_type: &DataType,
88        is_nullable: bool,
89    ) -> Result<Self, ArrowError> {
90        let fields = struct_fields(data_type);
91        let decoders = fields
92            .iter()
93            .map(|f| {
94                // If this struct nullable, need to permit nullability in child array
95                // StructArrayDecoder::decode verifies that if the child is not nullable
96                // it doesn't contain any nulls not masked by its parent
97                let nullable = f.is_nullable() || is_nullable;
98                ctx.make_decoder(f.data_type(), nullable)
99            })
100            .collect::<Result<Vec<_>, ArrowError>>()?;
101
102        let struct_mode = ctx.struct_mode();
103        let field_name_to_index = if struct_mode == StructMode::ObjectOnly {
104            build_field_index(fields)
105        } else {
106            None
107        };
108
109        Ok(Self {
110            data_type: data_type.clone(),
111            decoders,
112            strict_mode: ctx.strict_mode(),
113            is_nullable,
114            struct_mode,
115            field_name_to_index,
116            field_tape_positions: FieldTapePositions::new(),
117        })
118    }
119}
120
121impl ArrayDecoder for StructArrayDecoder {
122    fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayRef, ArrowError> {
123        let fields = struct_fields(&self.data_type);
124        let row_count = pos.len();
125        let field_count = fields.len();
126        self.field_tape_positions.resize(field_count, row_count)?;
127        let mut nulls = self
128            .is_nullable
129            .then(|| BooleanBufferBuilder::new(pos.len()));
130
131        {
132            // We avoid having the match on self.struct_mode inside the hot loop for performance
133            // TODO: Investigate how to extract duplicated logic.
134            match self.struct_mode {
135                StructMode::ObjectOnly => {
136                    for (row, p) in pos.iter().enumerate() {
137                        let end_idx = match (tape.get(*p), nulls.as_mut()) {
138                            (TapeElement::StartObject(end_idx), None) => end_idx,
139                            (TapeElement::StartObject(end_idx), Some(nulls)) => {
140                                nulls.append(true);
141                                end_idx
142                            }
143                            (TapeElement::Null, Some(nulls)) => {
144                                nulls.append(false);
145                                continue;
146                            }
147                            (_, _) => return Err(tape.error(*p, "{")),
148                        };
149
150                        let mut cur_idx = *p + 1;
151                        while cur_idx < end_idx {
152                            // Read field name
153                            let field_name = match tape.get(cur_idx) {
154                                TapeElement::String(s) => tape.get_string(s),
155                                _ => return Err(tape.error(cur_idx, "field name")),
156                            };
157
158                            // Update child pos if match found
159                            let field_idx = match &self.field_name_to_index {
160                                Some(map) => map.get(field_name).copied(),
161                                None => fields.iter().position(|x| x.name() == field_name),
162                            };
163                            match field_idx {
164                                Some(field_idx) => {
165                                    self.field_tape_positions.set(field_idx, row, cur_idx + 1);
166                                }
167                                None => {
168                                    if self.strict_mode {
169                                        return Err(ArrowError::JsonError(format!(
170                                            "column '{field_name}' missing from schema",
171                                        )));
172                                    }
173                                }
174                            }
175                            // Advance to next field
176                            cur_idx = tape.next(cur_idx + 1, "field value")?;
177                        }
178                    }
179                }
180                StructMode::ListOnly => {
181                    for (row, p) in pos.iter().enumerate() {
182                        let end_idx = match (tape.get(*p), nulls.as_mut()) {
183                            (TapeElement::StartList(end_idx), None) => end_idx,
184                            (TapeElement::StartList(end_idx), Some(nulls)) => {
185                                nulls.append(true);
186                                end_idx
187                            }
188                            (TapeElement::Null, Some(nulls)) => {
189                                nulls.append(false);
190                                continue;
191                            }
192                            (_, _) => return Err(tape.error(*p, "[")),
193                        };
194
195                        let mut cur_idx = *p + 1;
196                        let mut entry_idx = 0;
197                        while cur_idx < end_idx {
198                            self.field_tape_positions
199                                .try_set(entry_idx, row, cur_idx)
200                                .ok_or_else(|| {
201                                    ArrowError::JsonError(format!(
202                                        "found extra columns for {} fields",
203                                        fields.len()
204                                    ))
205                                })?;
206                            entry_idx += 1;
207                            // Advance to next field
208                            cur_idx = tape.next(cur_idx, "field value")?;
209                        }
210                        if entry_idx != fields.len() {
211                            return Err(ArrowError::JsonError(format!(
212                                "found {} columns for {} fields",
213                                entry_idx,
214                                fields.len()
215                            )));
216                        }
217                    }
218                }
219            }
220        }
221
222        let child_arrays = self
223            .decoders
224            .iter_mut()
225            .enumerate()
226            .zip(fields)
227            .map(|((field_idx, d), f)| {
228                let pos = self.field_tape_positions.field_positions(field_idx);
229                d.decode(tape, pos).map_err(|e| match e {
230                    ArrowError::JsonError(s) => {
231                        ArrowError::JsonError(format!("whilst decoding field '{}': {s}", f.name()))
232                    }
233                    e => e,
234                })
235            })
236            .collect::<Result<Vec<_>, ArrowError>>()?;
237
238        let nulls = nulls.as_mut().map(|x| NullBuffer::new(x.finish()));
239
240        for (c, f) in child_arrays.iter().zip(fields) {
241            // Sanity check
242            assert_eq!(c.len(), pos.len());
243            if let Some(a) = c.nulls() {
244                let nulls_valid =
245                    f.is_nullable() || nulls.as_ref().map(|n| n.contains(a)).unwrap_or_default();
246
247                if !nulls_valid {
248                    return Err(ArrowError::JsonError(format!(
249                        "Encountered unmasked nulls in non-nullable StructArray child: {f}"
250                    )));
251                }
252            }
253        }
254
255        // SAFETY: fields, child array lengths, and nullability are validated above
256        let array = unsafe {
257            StructArray::new_unchecked_with_length(fields.clone(), child_arrays, nulls, row_count)
258        };
259        Ok(Arc::new(array))
260    }
261}
262
263fn struct_fields(data_type: &DataType) -> &Fields {
264    match &data_type {
265        DataType::Struct(f) => f,
266        _ => unreachable!(),
267    }
268}
269
270fn build_field_index(fields: &Fields) -> Option<HashMap<String, usize>> {
271    // Heuristic threshold: for small field counts, linear scan avoids HashMap overhead.
272    const FIELD_INDEX_LINEAR_THRESHOLD: usize = 16;
273    if fields.len() < FIELD_INDEX_LINEAR_THRESHOLD {
274        return None;
275    }
276
277    let mut map = HashMap::with_capacity(fields.len());
278    for (idx, field) in fields.iter().enumerate() {
279        let name = field.name();
280        if !map.contains_key(name) {
281            map.insert(name.to_string(), idx);
282        }
283    }
284    Some(map)
285}