1use std::collections::HashMap;
19use std::sync::Arc;
20
21use arrow_array::builder::BooleanBufferBuilder;
22use arrow_array::{Array, ArrayRef, StructArray};
23use arrow_buffer::buffer::NullBuffer;
24use arrow_schema::{ArrowError, DataType, Fields};
25
26use crate::reader::tape::{Tape, TapeElement};
27use crate::reader::{ArrayDecoder, DecoderContext, StructMode};
28
29struct FieldTapePositions {
32 data: Vec<u32>,
33 row_count: usize,
34}
35
36impl FieldTapePositions {
37 fn new() -> Self {
38 Self {
39 data: Vec::new(),
40 row_count: 0,
41 }
42 }
43
44 fn resize(&mut self, field_count: usize, row_count: usize) -> Result<(), ArrowError> {
45 let total_len = field_count.checked_mul(row_count).ok_or_else(|| {
46 ArrowError::JsonError(format!(
47 "FieldTapePositions buffer size overflow for rows={row_count} fields={field_count}"
48 ))
49 })?;
50 self.data.clear();
51 self.data.resize(total_len, 0);
52 self.row_count = row_count;
53 Ok(())
54 }
55
56 fn try_set(&mut self, field_idx: usize, row_idx: usize, pos: u32) -> Option<()> {
57 let idx = field_idx
58 .checked_mul(self.row_count)?
59 .checked_add(row_idx)?;
60 *self.data.get_mut(idx)? = pos;
61 Some(())
62 }
63
64 fn set(&mut self, field_idx: usize, row_idx: usize, pos: u32) {
65 self.data[field_idx * self.row_count + row_idx] = pos;
66 }
67
68 fn field_positions(&self, field_idx: usize) -> &[u32] {
69 let start = field_idx * self.row_count;
70 &self.data[start..start + self.row_count]
71 }
72}
73
74pub struct StructArrayDecoder {
75 data_type: DataType,
76 decoders: Vec<Box<dyn ArrayDecoder>>,
77 strict_mode: bool,
78 is_nullable: bool,
79 struct_mode: StructMode,
80 field_name_to_index: Option<HashMap<String, usize>>,
81 field_tape_positions: FieldTapePositions,
82}
83
84impl StructArrayDecoder {
85 pub fn new(
86 ctx: &DecoderContext,
87 data_type: &DataType,
88 is_nullable: bool,
89 ) -> Result<Self, ArrowError> {
90 let fields = struct_fields(data_type);
91 let decoders = fields
92 .iter()
93 .map(|f| {
94 let nullable = f.is_nullable() || is_nullable;
98 ctx.make_decoder(f.data_type(), nullable)
99 })
100 .collect::<Result<Vec<_>, ArrowError>>()?;
101
102 let struct_mode = ctx.struct_mode();
103 let field_name_to_index = if struct_mode == StructMode::ObjectOnly {
104 build_field_index(fields)
105 } else {
106 None
107 };
108
109 Ok(Self {
110 data_type: data_type.clone(),
111 decoders,
112 strict_mode: ctx.strict_mode(),
113 is_nullable,
114 struct_mode,
115 field_name_to_index,
116 field_tape_positions: FieldTapePositions::new(),
117 })
118 }
119}
120
121impl ArrayDecoder for StructArrayDecoder {
122 fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayRef, ArrowError> {
123 let fields = struct_fields(&self.data_type);
124 let row_count = pos.len();
125 let field_count = fields.len();
126 self.field_tape_positions.resize(field_count, row_count)?;
127 let mut nulls = self
128 .is_nullable
129 .then(|| BooleanBufferBuilder::new(pos.len()));
130
131 {
132 match self.struct_mode {
135 StructMode::ObjectOnly => {
136 for (row, p) in pos.iter().enumerate() {
137 let end_idx = match (tape.get(*p), nulls.as_mut()) {
138 (TapeElement::StartObject(end_idx), None) => end_idx,
139 (TapeElement::StartObject(end_idx), Some(nulls)) => {
140 nulls.append(true);
141 end_idx
142 }
143 (TapeElement::Null, Some(nulls)) => {
144 nulls.append(false);
145 continue;
146 }
147 (_, _) => return Err(tape.error(*p, "{")),
148 };
149
150 let mut cur_idx = *p + 1;
151 while cur_idx < end_idx {
152 let field_name = match tape.get(cur_idx) {
154 TapeElement::String(s) => tape.get_string(s),
155 _ => return Err(tape.error(cur_idx, "field name")),
156 };
157
158 let field_idx = match &self.field_name_to_index {
160 Some(map) => map.get(field_name).copied(),
161 None => fields.iter().position(|x| x.name() == field_name),
162 };
163 match field_idx {
164 Some(field_idx) => {
165 self.field_tape_positions.set(field_idx, row, cur_idx + 1);
166 }
167 None => {
168 if self.strict_mode {
169 return Err(ArrowError::JsonError(format!(
170 "column '{field_name}' missing from schema",
171 )));
172 }
173 }
174 }
175 cur_idx = tape.next(cur_idx + 1, "field value")?;
177 }
178 }
179 }
180 StructMode::ListOnly => {
181 for (row, p) in pos.iter().enumerate() {
182 let end_idx = match (tape.get(*p), nulls.as_mut()) {
183 (TapeElement::StartList(end_idx), None) => end_idx,
184 (TapeElement::StartList(end_idx), Some(nulls)) => {
185 nulls.append(true);
186 end_idx
187 }
188 (TapeElement::Null, Some(nulls)) => {
189 nulls.append(false);
190 continue;
191 }
192 (_, _) => return Err(tape.error(*p, "[")),
193 };
194
195 let mut cur_idx = *p + 1;
196 let mut entry_idx = 0;
197 while cur_idx < end_idx {
198 self.field_tape_positions
199 .try_set(entry_idx, row, cur_idx)
200 .ok_or_else(|| {
201 ArrowError::JsonError(format!(
202 "found extra columns for {} fields",
203 fields.len()
204 ))
205 })?;
206 entry_idx += 1;
207 cur_idx = tape.next(cur_idx, "field value")?;
209 }
210 if entry_idx != fields.len() {
211 return Err(ArrowError::JsonError(format!(
212 "found {} columns for {} fields",
213 entry_idx,
214 fields.len()
215 )));
216 }
217 }
218 }
219 }
220 }
221
222 let child_arrays = self
223 .decoders
224 .iter_mut()
225 .enumerate()
226 .zip(fields)
227 .map(|((field_idx, d), f)| {
228 let pos = self.field_tape_positions.field_positions(field_idx);
229 d.decode(tape, pos).map_err(|e| match e {
230 ArrowError::JsonError(s) => {
231 ArrowError::JsonError(format!("whilst decoding field '{}': {s}", f.name()))
232 }
233 e => e,
234 })
235 })
236 .collect::<Result<Vec<_>, ArrowError>>()?;
237
238 let nulls = nulls.as_mut().map(|x| NullBuffer::new(x.finish()));
239
240 for (c, f) in child_arrays.iter().zip(fields) {
241 assert_eq!(c.len(), pos.len());
243 if let Some(a) = c.nulls() {
244 let nulls_valid =
245 f.is_nullable() || nulls.as_ref().map(|n| n.contains(a)).unwrap_or_default();
246
247 if !nulls_valid {
248 return Err(ArrowError::JsonError(format!(
249 "Encountered unmasked nulls in non-nullable StructArray child: {f}"
250 )));
251 }
252 }
253 }
254
255 let array = unsafe {
257 StructArray::new_unchecked_with_length(fields.clone(), child_arrays, nulls, row_count)
258 };
259 Ok(Arc::new(array))
260 }
261}
262
263fn struct_fields(data_type: &DataType) -> &Fields {
264 match &data_type {
265 DataType::Struct(f) => f,
266 _ => unreachable!(),
267 }
268}
269
270fn build_field_index(fields: &Fields) -> Option<HashMap<String, usize>> {
271 const FIELD_INDEX_LINEAR_THRESHOLD: usize = 16;
273 if fields.len() < FIELD_INDEX_LINEAR_THRESHOLD {
274 return None;
275 }
276
277 let mut map = HashMap::with_capacity(fields.len());
278 for (idx, field) in fields.iter().enumerate() {
279 let name = field.name();
280 if !map.contains_key(name) {
281 map.insert(name.to_string(), idx);
282 }
283 }
284 Some(map)
285}