Skip to main content

arrow_json/reader/
struct_array.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::collections::HashMap;
19use std::sync::Arc;
20
21use arrow_array::{Array, ArrayRef, StructArray};
22use arrow_buffer::NullBufferBuilder;
23use arrow_schema::{ArrowError, DataType, Fields};
24
25use crate::reader::tape::{Tape, TapeElement};
26use crate::reader::{ArrayDecoder, DecoderContext, StructMode};
27
28/// Reusable buffer for tape positions, indexed by (field_idx, row_idx).
29/// A value of 0 indicates the field is absent for that row.
30struct FieldTapePositions {
31    data: Vec<u32>,
32    row_count: usize,
33}
34
35impl FieldTapePositions {
36    fn new() -> Self {
37        Self {
38            data: Vec::new(),
39            row_count: 0,
40        }
41    }
42
43    fn resize(&mut self, field_count: usize, row_count: usize) -> Result<(), ArrowError> {
44        let total_len = field_count.checked_mul(row_count).ok_or_else(|| {
45            ArrowError::JsonError(format!(
46                "FieldTapePositions buffer size overflow for rows={row_count} fields={field_count}"
47            ))
48        })?;
49        self.data.clear();
50        self.data.resize(total_len, 0);
51        self.row_count = row_count;
52        Ok(())
53    }
54
55    fn try_set(&mut self, field_idx: usize, row_idx: usize, pos: u32) -> Option<()> {
56        let idx = field_idx
57            .checked_mul(self.row_count)?
58            .checked_add(row_idx)?;
59        *self.data.get_mut(idx)? = pos;
60        Some(())
61    }
62
63    fn set(&mut self, field_idx: usize, row_idx: usize, pos: u32) {
64        self.data[field_idx * self.row_count + row_idx] = pos;
65    }
66
67    fn field_positions(&self, field_idx: usize) -> &[u32] {
68        let start = field_idx * self.row_count;
69        &self.data[start..start + self.row_count]
70    }
71}
72
73pub struct StructArrayDecoder {
74    data_type: DataType,
75    decoders: Vec<Box<dyn ArrayDecoder>>,
76    strict_mode: bool,
77    ignore_type_conflicts: bool,
78    is_nullable: bool,
79    struct_mode: StructMode,
80    field_name_to_index: Option<HashMap<String, usize>>,
81    field_tape_positions: FieldTapePositions,
82}
83
84impl StructArrayDecoder {
85    pub fn new(
86        ctx: &DecoderContext,
87        data_type: &DataType,
88        is_nullable: bool,
89    ) -> Result<Self, ArrowError> {
90        let fields = struct_fields(data_type);
91        let decoders = fields
92            .iter()
93            .map(|f| {
94                // If this struct nullable, need to permit nullability in child array
95                // StructArrayDecoder::decode verifies that if the child is not nullable
96                // it doesn't contain any nulls not masked by its parent
97                let nullable = f.is_nullable() || is_nullable;
98                ctx.make_decoder(f.data_type(), nullable)
99            })
100            .collect::<Result<Vec<_>, ArrowError>>()?;
101
102        let struct_mode = ctx.struct_mode();
103        let field_name_to_index = if struct_mode == StructMode::ObjectOnly {
104            build_field_index(fields)
105        } else {
106            None
107        };
108
109        Ok(Self {
110            data_type: data_type.clone(),
111            decoders,
112            strict_mode: ctx.strict_mode(),
113            ignore_type_conflicts: ctx.ignore_type_conflicts(),
114            is_nullable,
115            struct_mode,
116            field_name_to_index,
117            field_tape_positions: FieldTapePositions::new(),
118        })
119    }
120}
121
122impl ArrayDecoder for StructArrayDecoder {
123    fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayRef, ArrowError> {
124        let fields = struct_fields(&self.data_type);
125        let row_count = pos.len();
126        let field_count = fields.len();
127        self.field_tape_positions.resize(field_count, row_count)?;
128        let mut nulls = self.is_nullable.then(|| NullBufferBuilder::new(pos.len()));
129
130        {
131            // We avoid having the match on self.struct_mode inside the hot loop for performance
132            // TODO: Investigate how to extract duplicated logic.
133            match self.struct_mode {
134                StructMode::ObjectOnly => {
135                    for (row, p) in pos.iter().enumerate() {
136                        let end_idx = match (tape.get(*p), nulls.as_mut()) {
137                            (TapeElement::StartObject(end_idx), None) => end_idx,
138                            (TapeElement::StartObject(end_idx), Some(nulls)) => {
139                                nulls.append_non_null();
140                                end_idx
141                            }
142                            (TapeElement::Null, Some(nulls)) => {
143                                nulls.append_null();
144                                continue;
145                            }
146                            (_, Some(nulls)) if self.ignore_type_conflicts => {
147                                nulls.append_null();
148                                continue;
149                            }
150                            (_, _) => return Err(tape.error(*p, "{")),
151                        };
152
153                        let mut cur_idx = *p + 1;
154                        while cur_idx < end_idx {
155                            // Read field name
156                            let field_name = match tape.get(cur_idx) {
157                                TapeElement::String(s) => tape.get_string(s),
158                                _ => return Err(tape.error(cur_idx, "field name")),
159                            };
160
161                            // Update child pos if match found
162                            let field_idx = match &self.field_name_to_index {
163                                Some(map) => map.get(field_name).copied(),
164                                None => fields.iter().position(|x| x.name() == field_name),
165                            };
166                            match field_idx {
167                                Some(field_idx) => {
168                                    self.field_tape_positions.set(field_idx, row, cur_idx + 1);
169                                }
170                                None => {
171                                    if self.strict_mode {
172                                        return Err(ArrowError::JsonError(format!(
173                                            "column '{field_name}' missing from schema",
174                                        )));
175                                    }
176                                }
177                            }
178                            // Advance to next field
179                            cur_idx = tape.next(cur_idx + 1, "field value")?;
180                        }
181                    }
182                }
183                StructMode::ListOnly => {
184                    for (row, p) in pos.iter().enumerate() {
185                        let end_idx = match (tape.get(*p), nulls.as_mut()) {
186                            (TapeElement::StartList(end_idx), None) => end_idx,
187                            (TapeElement::StartList(end_idx), Some(nulls)) => {
188                                nulls.append_non_null();
189                                end_idx
190                            }
191                            (TapeElement::Null, Some(nulls)) => {
192                                nulls.append_null();
193                                continue;
194                            }
195                            (_, Some(nulls)) if self.ignore_type_conflicts => {
196                                nulls.append_null();
197                                continue;
198                            }
199                            (_, _) => return Err(tape.error(*p, "[")),
200                        };
201
202                        let mut cur_idx = *p + 1;
203                        let mut entry_idx = 0;
204                        while cur_idx < end_idx {
205                            self.field_tape_positions
206                                .try_set(entry_idx, row, cur_idx)
207                                .ok_or_else(|| {
208                                    ArrowError::JsonError(format!(
209                                        "found extra columns for {} fields",
210                                        fields.len()
211                                    ))
212                                })?;
213                            entry_idx += 1;
214                            // Advance to next field
215                            cur_idx = tape.next(cur_idx, "field value")?;
216                        }
217                        if entry_idx != fields.len() {
218                            return Err(ArrowError::JsonError(format!(
219                                "found {} columns for {} fields",
220                                entry_idx,
221                                fields.len()
222                            )));
223                        }
224                    }
225                }
226            }
227        }
228
229        let child_arrays = self
230            .decoders
231            .iter_mut()
232            .enumerate()
233            .zip(fields)
234            .map(|((field_idx, d), f)| {
235                let pos = self.field_tape_positions.field_positions(field_idx);
236                d.decode(tape, pos).map_err(|e| match e {
237                    ArrowError::JsonError(s) => {
238                        ArrowError::JsonError(format!("whilst decoding field '{}': {s}", f.name()))
239                    }
240                    e => e,
241                })
242            })
243            .collect::<Result<Vec<_>, ArrowError>>()?;
244
245        let nulls = nulls.as_mut().and_then(|x| x.finish());
246
247        for (c, f) in child_arrays.iter().zip(fields) {
248            // Sanity check
249            assert_eq!(c.len(), pos.len());
250            if let Some(a) = c.nulls() {
251                let nulls_valid =
252                    f.is_nullable() || nulls.as_ref().map(|n| n.contains(a)).unwrap_or_default();
253
254                if !nulls_valid {
255                    return Err(ArrowError::JsonError(format!(
256                        "Encountered unmasked nulls in non-nullable StructArray child: {f}"
257                    )));
258                }
259            }
260        }
261
262        // SAFETY: fields, child array lengths, and nullability are validated above
263        let array = unsafe {
264            StructArray::new_unchecked_with_length(fields.clone(), child_arrays, nulls, row_count)
265        };
266        Ok(Arc::new(array))
267    }
268}
269
270fn struct_fields(data_type: &DataType) -> &Fields {
271    match &data_type {
272        DataType::Struct(f) => f,
273        _ => unreachable!(),
274    }
275}
276
277fn build_field_index(fields: &Fields) -> Option<HashMap<String, usize>> {
278    // Heuristic threshold: for small field counts, linear scan avoids HashMap overhead.
279    const FIELD_INDEX_LINEAR_THRESHOLD: usize = 16;
280    if fields.len() < FIELD_INDEX_LINEAR_THRESHOLD {
281        return None;
282    }
283
284    let mut map = HashMap::with_capacity(fields.len());
285    for (idx, field) in fields.iter().enumerate() {
286        let name = field.name();
287        if !map.contains_key(name) {
288            map.insert(name.to_string(), idx);
289        }
290    }
291    Some(map)
292}