1use std::collections::HashMap;
19use std::sync::Arc;
20
21use arrow_array::{Array, ArrayRef, StructArray};
22use arrow_buffer::NullBufferBuilder;
23use arrow_schema::{ArrowError, DataType, Fields};
24
25use crate::reader::tape::{Tape, TapeElement};
26use crate::reader::{ArrayDecoder, DecoderContext, StructMode};
27
28struct FieldTapePositions {
31 data: Vec<u32>,
32 row_count: usize,
33}
34
35impl FieldTapePositions {
36 fn new() -> Self {
37 Self {
38 data: Vec::new(),
39 row_count: 0,
40 }
41 }
42
43 fn resize(&mut self, field_count: usize, row_count: usize) -> Result<(), ArrowError> {
44 let total_len = field_count.checked_mul(row_count).ok_or_else(|| {
45 ArrowError::JsonError(format!(
46 "FieldTapePositions buffer size overflow for rows={row_count} fields={field_count}"
47 ))
48 })?;
49 self.data.clear();
50 self.data.resize(total_len, 0);
51 self.row_count = row_count;
52 Ok(())
53 }
54
55 fn try_set(&mut self, field_idx: usize, row_idx: usize, pos: u32) -> Option<()> {
56 let idx = field_idx
57 .checked_mul(self.row_count)?
58 .checked_add(row_idx)?;
59 *self.data.get_mut(idx)? = pos;
60 Some(())
61 }
62
63 fn set(&mut self, field_idx: usize, row_idx: usize, pos: u32) {
64 self.data[field_idx * self.row_count + row_idx] = pos;
65 }
66
67 fn field_positions(&self, field_idx: usize) -> &[u32] {
68 let start = field_idx * self.row_count;
69 &self.data[start..start + self.row_count]
70 }
71}
72
/// Decodes JSON values from a [`Tape`] into a [`StructArray`].
pub struct StructArrayDecoder {
    // Always `DataType::Struct` (see `struct_fields`, which panics otherwise).
    data_type: DataType,
    // One child decoder per struct field, in field order.
    decoders: Vec<Box<dyn ArrayDecoder>>,
    // When true, a JSON key not present in the schema is an error.
    strict_mode: bool,
    // When true, a value that is not the expected object/list is decoded as
    // null instead of erroring (requires a null buffer, i.e. `is_nullable`).
    ignore_type_conflicts: bool,
    // Whether this struct column itself may be null.
    is_nullable: bool,
    // Whether rows are JSON objects (`ObjectOnly`) or JSON lists (`ListOnly`).
    struct_mode: StructMode,
    // Name -> field-index map, built only in object mode and only for wide
    // schemas (see `build_field_index`); `None` means use a linear scan.
    field_name_to_index: Option<HashMap<String, usize>>,
    // Reusable scratch buffer of per-field value positions (see
    // `FieldTapePositions`).
    field_tape_positions: FieldTapePositions,
}
83
84impl StructArrayDecoder {
85 pub fn new(
86 ctx: &DecoderContext,
87 data_type: &DataType,
88 is_nullable: bool,
89 ) -> Result<Self, ArrowError> {
90 let fields = struct_fields(data_type);
91 let decoders = fields
92 .iter()
93 .map(|f| {
94 let nullable = f.is_nullable() || is_nullable;
98 ctx.make_decoder(f.data_type(), nullable)
99 })
100 .collect::<Result<Vec<_>, ArrowError>>()?;
101
102 let struct_mode = ctx.struct_mode();
103 let field_name_to_index = if struct_mode == StructMode::ObjectOnly {
104 build_field_index(fields)
105 } else {
106 None
107 };
108
109 Ok(Self {
110 data_type: data_type.clone(),
111 decoders,
112 strict_mode: ctx.strict_mode(),
113 ignore_type_conflicts: ctx.ignore_type_conflicts(),
114 is_nullable,
115 struct_mode,
116 field_name_to_index,
117 field_tape_positions: FieldTapePositions::new(),
118 })
119 }
120}
121
impl ArrayDecoder for StructArrayDecoder {
    /// Decodes the tape elements at `pos` (one per output row) into a
    /// [`StructArray`], dispatching each field's values to its child decoder.
    ///
    /// First pass: walk each row's object/list on the tape, recording the
    /// tape position of every field value in `field_tape_positions`.
    /// Second pass: hand each child decoder its per-field position slice.
    /// Positions never written stay `0` from `resize` (presumably decoded as
    /// null by the child decoder — confirm against the tape layout).
    fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayRef, ArrowError> {
        let fields = struct_fields(&self.data_type);
        let row_count = pos.len();
        let field_count = fields.len();
        // Errors only on field_count * row_count overflow.
        self.field_tape_positions.resize(field_count, row_count)?;
        // Null buffer is only tracked when the struct column is nullable.
        let mut nulls = self.is_nullable.then(|| NullBufferBuilder::new(pos.len()));

        {
            match self.struct_mode {
                // Rows are JSON objects: match keys against schema field names.
                StructMode::ObjectOnly => {
                    for (row, p) in pos.iter().enumerate() {
                        // Classify the row's tape element, appending to the
                        // null buffer (when present) as we go.
                        let end_idx = match (tape.get(*p), nulls.as_mut()) {
                            (TapeElement::StartObject(end_idx), None) => end_idx,
                            (TapeElement::StartObject(end_idx), Some(nulls)) => {
                                nulls.append_non_null();
                                end_idx
                            }
                            (TapeElement::Null, Some(nulls)) => {
                                nulls.append_null();
                                continue;
                            }
                            // Wrong-typed value tolerated as null when
                            // configured (only possible for nullable columns).
                            (_, Some(nulls)) if self.ignore_type_conflicts => {
                                nulls.append_null();
                                continue;
                            }
                            (_, _) => return Err(tape.error(*p, "{")),
                        };

                        // Walk alternating key/value elements until the
                        // object's end marker.
                        let mut cur_idx = *p + 1;
                        while cur_idx < end_idx {
                            let field_name = match tape.get(cur_idx) {
                                TapeElement::String(s) => tape.get_string(s),
                                _ => return Err(tape.error(cur_idx, "field name")),
                            };

                            // Wide schemas use the prebuilt hash map; narrow
                            // ones fall back to a linear scan.
                            let field_idx = match &self.field_name_to_index {
                                Some(map) => map.get(field_name).copied(),
                                None => fields.iter().position(|x| x.name() == field_name),
                            };
                            match field_idx {
                                Some(field_idx) => {
                                    // cur_idx + 1 is the value element for
                                    // this key.
                                    self.field_tape_positions.set(field_idx, row, cur_idx + 1);
                                }
                                None => {
                                    // Unknown keys: error in strict mode,
                                    // otherwise silently skipped.
                                    if self.strict_mode {
                                        return Err(ArrowError::JsonError(format!(
                                            "column '{field_name}' missing from schema",
                                        )));
                                    }
                                }
                            }
                            // Skip over the (possibly nested) value to the
                            // next key.
                            cur_idx = tape.next(cur_idx + 1, "field value")?;
                        }
                    }
                }
                // Rows are JSON lists: entries map to fields positionally and
                // the count must match the schema exactly.
                StructMode::ListOnly => {
                    for (row, p) in pos.iter().enumerate() {
                        let end_idx = match (tape.get(*p), nulls.as_mut()) {
                            (TapeElement::StartList(end_idx), None) => end_idx,
                            (TapeElement::StartList(end_idx), Some(nulls)) => {
                                nulls.append_non_null();
                                end_idx
                            }
                            (TapeElement::Null, Some(nulls)) => {
                                nulls.append_null();
                                continue;
                            }
                            (_, Some(nulls)) if self.ignore_type_conflicts => {
                                nulls.append_null();
                                continue;
                            }
                            (_, _) => return Err(tape.error(*p, "[")),
                        };

                        let mut cur_idx = *p + 1;
                        let mut entry_idx = 0;
                        while cur_idx < end_idx {
                            // try_set (rather than set) turns an
                            // out-of-bounds entry_idx — i.e. more list
                            // entries than schema fields — into an error.
                            self.field_tape_positions
                                .try_set(entry_idx, row, cur_idx)
                                .ok_or_else(|| {
                                    ArrowError::JsonError(format!(
                                        "found extra columns for {} fields",
                                        fields.len()
                                    ))
                                })?;
                            entry_idx += 1;
                            cur_idx = tape.next(cur_idx, "field value")?;
                        }
                        // Too few entries is also an error: positional rows
                        // must cover every field.
                        if entry_idx != fields.len() {
                            return Err(ArrowError::JsonError(format!(
                                "found {} columns for {} fields",
                                entry_idx,
                                fields.len()
                            )));
                        }
                    }
                }
            }
        }

        // Second pass: decode each field's column from its recorded
        // positions, prefixing any JSON error with the field name.
        let child_arrays = self
            .decoders
            .iter_mut()
            .enumerate()
            .zip(fields)
            .map(|((field_idx, d), f)| {
                let pos = self.field_tape_positions.field_positions(field_idx);
                d.decode(tape, pos).map_err(|e| match e {
                    ArrowError::JsonError(s) => {
                        ArrowError::JsonError(format!("whilst decoding field '{}': {s}", f.name()))
                    }
                    e => e,
                })
            })
            .collect::<Result<Vec<_>, ArrowError>>()?;

        let nulls = nulls.as_mut().and_then(|x| x.finish());

        // A non-nullable child may still contain nulls, but only where the
        // parent struct is itself null (i.e. the parent mask covers them).
        for (c, f) in child_arrays.iter().zip(fields) {
            // Each child decoder was given exactly `row_count` positions.
            assert_eq!(c.len(), pos.len());
            if let Some(a) = c.nulls() {
                let nulls_valid =
                    f.is_nullable() || nulls.as_ref().map(|n| n.contains(a)).unwrap_or_default();

                if !nulls_valid {
                    return Err(ArrowError::JsonError(format!(
                        "Encountered unmasked nulls in non-nullable StructArray child: {f}"
                    )));
                }
            }
        }

        // SAFETY: lengths were validated above (every child asserted equal to
        // pos.len(), and the null buffer was built with one entry per row);
        // child/field type agreement comes from constructing each decoder
        // from its field's data type — NOTE(review): relies on child decoders
        // honoring their requested type; confirm invariants of
        // `new_unchecked_with_length`.
        let array = unsafe {
            StructArray::new_unchecked_with_length(fields.clone(), child_arrays, nulls, row_count)
        };
        Ok(Arc::new(array))
    }
}
269
270fn struct_fields(data_type: &DataType) -> &Fields {
271 match &data_type {
272 DataType::Struct(f) => f,
273 _ => unreachable!(),
274 }
275}
276
277fn build_field_index(fields: &Fields) -> Option<HashMap<String, usize>> {
278 const FIELD_INDEX_LINEAR_THRESHOLD: usize = 16;
280 if fields.len() < FIELD_INDEX_LINEAR_THRESHOLD {
281 return None;
282 }
283
284 let mut map = HashMap::with_capacity(fields.len());
285 for (idx, field) in fields.iter().enumerate() {
286 let name = field.name();
287 if !map.contains_key(name) {
288 map.insert(name.to_string(), idx);
289 }
290 }
291 Some(map)
292}