// arrow_json/reader/struct_array.rs

use arrow_array::builder::BooleanBufferBuilder;
use arrow_buffer::buffer::NullBuffer;
use arrow_data::{ArrayData, ArrayDataBuilder};
use arrow_schema::{ArrowError, DataType, Fields};

use crate::reader::tape::{Tape, TapeElement};
use crate::reader::{make_decoder, ArrayDecoder, StructMode};

25pub struct StructArrayDecoder {
26 data_type: DataType,
27 decoders: Vec<Box<dyn ArrayDecoder>>,
28 strict_mode: bool,
29 is_nullable: bool,
30 struct_mode: StructMode,
31}
32
33impl StructArrayDecoder {
34 pub fn new(
35 data_type: DataType,
36 coerce_primitive: bool,
37 strict_mode: bool,
38 is_nullable: bool,
39 struct_mode: StructMode,
40 ) -> Result<Self, ArrowError> {
41 let decoders = struct_fields(&data_type)
42 .iter()
43 .map(|f| {
44 let nullable = f.is_nullable() || is_nullable;
48 make_decoder(
49 f.data_type().clone(),
50 coerce_primitive,
51 strict_mode,
52 nullable,
53 struct_mode,
54 )
55 })
56 .collect::<Result<Vec<_>, ArrowError>>()?;
57
58 Ok(Self {
59 data_type,
60 decoders,
61 strict_mode,
62 is_nullable,
63 struct_mode,
64 })
65 }
66}
67
68impl ArrayDecoder for StructArrayDecoder {
69 fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayData, ArrowError> {
70 let fields = struct_fields(&self.data_type);
71 let mut child_pos: Vec<_> = (0..fields.len()).map(|_| vec![0; pos.len()]).collect();
72
73 let mut nulls = self
74 .is_nullable
75 .then(|| BooleanBufferBuilder::new(pos.len()));
76
77 match self.struct_mode {
80 StructMode::ObjectOnly => {
81 for (row, p) in pos.iter().enumerate() {
82 let end_idx = match (tape.get(*p), nulls.as_mut()) {
83 (TapeElement::StartObject(end_idx), None) => end_idx,
84 (TapeElement::StartObject(end_idx), Some(nulls)) => {
85 nulls.append(true);
86 end_idx
87 }
88 (TapeElement::Null, Some(nulls)) => {
89 nulls.append(false);
90 continue;
91 }
92 (_, _) => return Err(tape.error(*p, "{")),
93 };
94
95 let mut cur_idx = *p + 1;
96 while cur_idx < end_idx {
97 let field_name = match tape.get(cur_idx) {
99 TapeElement::String(s) => tape.get_string(s),
100 _ => return Err(tape.error(cur_idx, "field name")),
101 };
102
103 match fields.iter().position(|x| x.name() == field_name) {
105 Some(field_idx) => child_pos[field_idx][row] = cur_idx + 1,
106 None => {
107 if self.strict_mode {
108 return Err(ArrowError::JsonError(format!(
109 "column '{}' missing from schema",
110 field_name
111 )));
112 }
113 }
114 }
115 cur_idx = tape.next(cur_idx + 1, "field value")?;
117 }
118 }
119 }
120 StructMode::ListOnly => {
121 for (row, p) in pos.iter().enumerate() {
122 let end_idx = match (tape.get(*p), nulls.as_mut()) {
123 (TapeElement::StartList(end_idx), None) => end_idx,
124 (TapeElement::StartList(end_idx), Some(nulls)) => {
125 nulls.append(true);
126 end_idx
127 }
128 (TapeElement::Null, Some(nulls)) => {
129 nulls.append(false);
130 continue;
131 }
132 (_, _) => return Err(tape.error(*p, "[")),
133 };
134
135 let mut cur_idx = *p + 1;
136 let mut entry_idx = 0;
137 while cur_idx < end_idx {
138 if entry_idx >= fields.len() {
139 return Err(ArrowError::JsonError(format!(
140 "found extra columns for {} fields",
141 fields.len()
142 )));
143 }
144 child_pos[entry_idx][row] = cur_idx;
145 entry_idx += 1;
146 cur_idx = tape.next(cur_idx, "field value")?;
148 }
149 if entry_idx != fields.len() {
150 return Err(ArrowError::JsonError(format!(
151 "found {} columns for {} fields",
152 entry_idx,
153 fields.len()
154 )));
155 }
156 }
157 }
158 }
159
160 let child_data = self
161 .decoders
162 .iter_mut()
163 .zip(child_pos)
164 .zip(fields)
165 .map(|((d, pos), f)| {
166 d.decode(tape, &pos).map_err(|e| match e {
167 ArrowError::JsonError(s) => {
168 ArrowError::JsonError(format!("whilst decoding field '{}': {s}", f.name()))
169 }
170 e => e,
171 })
172 })
173 .collect::<Result<Vec<_>, ArrowError>>()?;
174
175 let nulls = nulls.as_mut().map(|x| NullBuffer::new(x.finish()));
176
177 for (c, f) in child_data.iter().zip(fields) {
178 assert_eq!(c.len(), pos.len());
180 if let Some(a) = c.nulls() {
181 let nulls_valid =
182 f.is_nullable() || nulls.as_ref().map(|n| n.contains(a)).unwrap_or_default();
183
184 if !nulls_valid {
185 return Err(ArrowError::JsonError(format!(
186 "Encountered unmasked nulls in non-nullable StructArray child: {f}"
187 )));
188 }
189 }
190 }
191
192 let data = ArrayDataBuilder::new(self.data_type.clone())
193 .len(pos.len())
194 .nulls(nulls)
195 .child_data(child_data);
196
197 Ok(unsafe { data.build_unchecked() })
200 }
201}
202
203fn struct_fields(data_type: &DataType) -> &Fields {
204 match &data_type {
205 DataType::Struct(f) => f,
206 _ => unreachable!(),
207 }
208}