1use crate::reader::tape::{Tape, TapeElement};
19use crate::reader::{ArrayDecoder, DecoderContext, StructMode};
20use arrow_array::builder::BooleanBufferBuilder;
21use arrow_buffer::buffer::NullBuffer;
22use arrow_data::{ArrayData, ArrayDataBuilder};
23use arrow_schema::{ArrowError, DataType, Fields};
24use std::collections::HashMap;
25
26struct FieldTapePositions {
29 data: Vec<u32>,
30 row_count: usize,
31}
32
33impl FieldTapePositions {
34 fn new() -> Self {
35 Self {
36 data: Vec::new(),
37 row_count: 0,
38 }
39 }
40
41 fn resize(&mut self, field_count: usize, row_count: usize) -> Result<(), ArrowError> {
42 let total_len = field_count.checked_mul(row_count).ok_or_else(|| {
43 ArrowError::JsonError(format!(
44 "FieldTapePositions buffer size overflow for rows={row_count} fields={field_count}"
45 ))
46 })?;
47 self.data.clear();
48 self.data.resize(total_len, 0);
49 self.row_count = row_count;
50 Ok(())
51 }
52
53 fn try_set(&mut self, field_idx: usize, row_idx: usize, pos: u32) -> Option<()> {
54 let idx = field_idx
55 .checked_mul(self.row_count)?
56 .checked_add(row_idx)?;
57 *self.data.get_mut(idx)? = pos;
58 Some(())
59 }
60
61 fn set(&mut self, field_idx: usize, row_idx: usize, pos: u32) {
62 self.data[field_idx * self.row_count + row_idx] = pos;
63 }
64
65 fn field_positions(&self, field_idx: usize) -> &[u32] {
66 let start = field_idx * self.row_count;
67 &self.data[start..start + self.row_count]
68 }
69}
70
71pub struct StructArrayDecoder {
72 data_type: DataType,
73 decoders: Vec<Box<dyn ArrayDecoder>>,
74 strict_mode: bool,
75 is_nullable: bool,
76 struct_mode: StructMode,
77 field_name_to_index: Option<HashMap<String, usize>>,
78 field_tape_positions: FieldTapePositions,
79}
80
81impl StructArrayDecoder {
82 pub fn new(
83 ctx: &DecoderContext,
84 data_type: &DataType,
85 is_nullable: bool,
86 ) -> Result<Self, ArrowError> {
87 let fields = struct_fields(data_type);
88 let decoders = fields
89 .iter()
90 .map(|f| {
91 let nullable = f.is_nullable() || is_nullable;
95 ctx.make_decoder(f.data_type(), nullable)
96 })
97 .collect::<Result<Vec<_>, ArrowError>>()?;
98
99 let struct_mode = ctx.struct_mode();
100 let field_name_to_index = if struct_mode == StructMode::ObjectOnly {
101 build_field_index(fields)
102 } else {
103 None
104 };
105
106 Ok(Self {
107 data_type: data_type.clone(),
108 decoders,
109 strict_mode: ctx.strict_mode(),
110 is_nullable,
111 struct_mode,
112 field_name_to_index,
113 field_tape_positions: FieldTapePositions::new(),
114 })
115 }
116}
117
118impl ArrayDecoder for StructArrayDecoder {
119 fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayData, ArrowError> {
120 let fields = struct_fields(&self.data_type);
121 let row_count = pos.len();
122 let field_count = fields.len();
123 self.field_tape_positions.resize(field_count, row_count)?;
124 let mut nulls = self
125 .is_nullable
126 .then(|| BooleanBufferBuilder::new(pos.len()));
127
128 {
129 match self.struct_mode {
132 StructMode::ObjectOnly => {
133 for (row, p) in pos.iter().enumerate() {
134 let end_idx = match (tape.get(*p), nulls.as_mut()) {
135 (TapeElement::StartObject(end_idx), None) => end_idx,
136 (TapeElement::StartObject(end_idx), Some(nulls)) => {
137 nulls.append(true);
138 end_idx
139 }
140 (TapeElement::Null, Some(nulls)) => {
141 nulls.append(false);
142 continue;
143 }
144 (_, _) => return Err(tape.error(*p, "{")),
145 };
146
147 let mut cur_idx = *p + 1;
148 while cur_idx < end_idx {
149 let field_name = match tape.get(cur_idx) {
151 TapeElement::String(s) => tape.get_string(s),
152 _ => return Err(tape.error(cur_idx, "field name")),
153 };
154
155 let field_idx = match &self.field_name_to_index {
157 Some(map) => map.get(field_name).copied(),
158 None => fields.iter().position(|x| x.name() == field_name),
159 };
160 match field_idx {
161 Some(field_idx) => {
162 self.field_tape_positions.set(field_idx, row, cur_idx + 1);
163 }
164 None => {
165 if self.strict_mode {
166 return Err(ArrowError::JsonError(format!(
167 "column '{field_name}' missing from schema",
168 )));
169 }
170 }
171 }
172 cur_idx = tape.next(cur_idx + 1, "field value")?;
174 }
175 }
176 }
177 StructMode::ListOnly => {
178 for (row, p) in pos.iter().enumerate() {
179 let end_idx = match (tape.get(*p), nulls.as_mut()) {
180 (TapeElement::StartList(end_idx), None) => end_idx,
181 (TapeElement::StartList(end_idx), Some(nulls)) => {
182 nulls.append(true);
183 end_idx
184 }
185 (TapeElement::Null, Some(nulls)) => {
186 nulls.append(false);
187 continue;
188 }
189 (_, _) => return Err(tape.error(*p, "[")),
190 };
191
192 let mut cur_idx = *p + 1;
193 let mut entry_idx = 0;
194 while cur_idx < end_idx {
195 self.field_tape_positions
196 .try_set(entry_idx, row, cur_idx)
197 .ok_or_else(|| {
198 ArrowError::JsonError(format!(
199 "found extra columns for {} fields",
200 fields.len()
201 ))
202 })?;
203 entry_idx += 1;
204 cur_idx = tape.next(cur_idx, "field value")?;
206 }
207 if entry_idx != fields.len() {
208 return Err(ArrowError::JsonError(format!(
209 "found {} columns for {} fields",
210 entry_idx,
211 fields.len()
212 )));
213 }
214 }
215 }
216 }
217 }
218
219 let child_data = self
220 .decoders
221 .iter_mut()
222 .enumerate()
223 .zip(fields)
224 .map(|((field_idx, d), f)| {
225 let pos = self.field_tape_positions.field_positions(field_idx);
226 d.decode(tape, pos).map_err(|e| match e {
227 ArrowError::JsonError(s) => {
228 ArrowError::JsonError(format!("whilst decoding field '{}': {s}", f.name()))
229 }
230 e => e,
231 })
232 })
233 .collect::<Result<Vec<_>, ArrowError>>()?;
234
235 let nulls = nulls.as_mut().map(|x| NullBuffer::new(x.finish()));
236
237 for (c, f) in child_data.iter().zip(fields) {
238 assert_eq!(c.len(), pos.len());
240 if let Some(a) = c.nulls() {
241 let nulls_valid =
242 f.is_nullable() || nulls.as_ref().map(|n| n.contains(a)).unwrap_or_default();
243
244 if !nulls_valid {
245 return Err(ArrowError::JsonError(format!(
246 "Encountered unmasked nulls in non-nullable StructArray child: {f}"
247 )));
248 }
249 }
250 }
251
252 let data = ArrayDataBuilder::new(self.data_type.clone())
253 .len(pos.len())
254 .nulls(nulls)
255 .child_data(child_data);
256
257 Ok(unsafe { data.build_unchecked() })
260 }
261}
262
263fn struct_fields(data_type: &DataType) -> &Fields {
264 match &data_type {
265 DataType::Struct(f) => f,
266 _ => unreachable!(),
267 }
268}
269
270fn build_field_index(fields: &Fields) -> Option<HashMap<String, usize>> {
271 const FIELD_INDEX_LINEAR_THRESHOLD: usize = 16;
273 if fields.len() < FIELD_INDEX_LINEAR_THRESHOLD {
274 return None;
275 }
276
277 let mut map = HashMap::with_capacity(fields.len());
278 for (idx, field) in fields.iter().enumerate() {
279 let name = field.name();
280 if !map.contains_key(name) {
281 map.insert(name.to_string(), idx);
282 }
283 }
284 Some(map)
285}