arrow_json/reader/
struct_array.rs1use crate::reader::tape::{Tape, TapeElement};
19use crate::reader::{make_decoder, ArrayDecoder, StructMode};
20use arrow_array::builder::BooleanBufferBuilder;
21use arrow_buffer::buffer::NullBuffer;
22use arrow_data::{ArrayData, ArrayDataBuilder};
23use arrow_schema::{ArrowError, DataType, Fields};
24
25pub struct StructArrayDecoder {
26 data_type: DataType,
27 decoders: Vec<Box<dyn ArrayDecoder>>,
28 strict_mode: bool,
29 is_nullable: bool,
30 struct_mode: StructMode,
31}
32
33impl StructArrayDecoder {
34 pub fn new(
35 data_type: DataType,
36 coerce_primitive: bool,
37 strict_mode: bool,
38 is_nullable: bool,
39 struct_mode: StructMode,
40 ) -> Result<Self, ArrowError> {
41 let decoders = struct_fields(&data_type)
42 .iter()
43 .map(|f| {
44 let nullable = f.is_nullable() || is_nullable;
48 make_decoder(
49 f.data_type().clone(),
50 coerce_primitive,
51 strict_mode,
52 nullable,
53 struct_mode,
54 )
55 })
56 .collect::<Result<Vec<_>, ArrowError>>()?;
57
58 Ok(Self {
59 data_type,
60 decoders,
61 strict_mode,
62 is_nullable,
63 struct_mode,
64 })
65 }
66}
67
68impl ArrayDecoder for StructArrayDecoder {
69 fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayData, ArrowError> {
70 let fields = struct_fields(&self.data_type);
71 let mut child_pos: Vec<_> = (0..fields.len()).map(|_| vec![0; pos.len()]).collect();
72
73 let mut nulls = self
74 .is_nullable
75 .then(|| BooleanBufferBuilder::new(pos.len()));
76
77 match self.struct_mode {
80 StructMode::ObjectOnly => {
81 for (row, p) in pos.iter().enumerate() {
82 let end_idx = match (tape.get(*p), nulls.as_mut()) {
83 (TapeElement::StartObject(end_idx), None) => end_idx,
84 (TapeElement::StartObject(end_idx), Some(nulls)) => {
85 nulls.append(true);
86 end_idx
87 }
88 (TapeElement::Null, Some(nulls)) => {
89 nulls.append(false);
90 continue;
91 }
92 (_, _) => return Err(tape.error(*p, "{")),
93 };
94
95 let mut cur_idx = *p + 1;
96 while cur_idx < end_idx {
97 let field_name = match tape.get(cur_idx) {
99 TapeElement::String(s) => tape.get_string(s),
100 _ => return Err(tape.error(cur_idx, "field name")),
101 };
102
103 match fields.iter().position(|x| x.name() == field_name) {
105 Some(field_idx) => child_pos[field_idx][row] = cur_idx + 1,
106 None => {
107 if self.strict_mode {
108 return Err(ArrowError::JsonError(format!(
109 "column '{field_name}' missing from schema",
110 )));
111 }
112 }
113 }
114 cur_idx = tape.next(cur_idx + 1, "field value")?;
116 }
117 }
118 }
119 StructMode::ListOnly => {
120 for (row, p) in pos.iter().enumerate() {
121 let end_idx = match (tape.get(*p), nulls.as_mut()) {
122 (TapeElement::StartList(end_idx), None) => end_idx,
123 (TapeElement::StartList(end_idx), Some(nulls)) => {
124 nulls.append(true);
125 end_idx
126 }
127 (TapeElement::Null, Some(nulls)) => {
128 nulls.append(false);
129 continue;
130 }
131 (_, _) => return Err(tape.error(*p, "[")),
132 };
133
134 let mut cur_idx = *p + 1;
135 let mut entry_idx = 0;
136 while cur_idx < end_idx {
137 if entry_idx >= fields.len() {
138 return Err(ArrowError::JsonError(format!(
139 "found extra columns for {} fields",
140 fields.len()
141 )));
142 }
143 child_pos[entry_idx][row] = cur_idx;
144 entry_idx += 1;
145 cur_idx = tape.next(cur_idx, "field value")?;
147 }
148 if entry_idx != fields.len() {
149 return Err(ArrowError::JsonError(format!(
150 "found {} columns for {} fields",
151 entry_idx,
152 fields.len()
153 )));
154 }
155 }
156 }
157 }
158
159 let child_data = self
160 .decoders
161 .iter_mut()
162 .zip(child_pos)
163 .zip(fields)
164 .map(|((d, pos), f)| {
165 d.decode(tape, &pos).map_err(|e| match e {
166 ArrowError::JsonError(s) => {
167 ArrowError::JsonError(format!("whilst decoding field '{}': {s}", f.name()))
168 }
169 e => e,
170 })
171 })
172 .collect::<Result<Vec<_>, ArrowError>>()?;
173
174 let nulls = nulls.as_mut().map(|x| NullBuffer::new(x.finish()));
175
176 for (c, f) in child_data.iter().zip(fields) {
177 assert_eq!(c.len(), pos.len());
179 if let Some(a) = c.nulls() {
180 let nulls_valid =
181 f.is_nullable() || nulls.as_ref().map(|n| n.contains(a)).unwrap_or_default();
182
183 if !nulls_valid {
184 return Err(ArrowError::JsonError(format!(
185 "Encountered unmasked nulls in non-nullable StructArray child: {f}"
186 )));
187 }
188 }
189 }
190
191 let data = ArrayDataBuilder::new(self.data_type.clone())
192 .len(pos.len())
193 .nulls(nulls)
194 .child_data(child_data);
195
196 Ok(unsafe { data.build_unchecked() })
199 }
200}
201
202fn struct_fields(data_type: &DataType) -> &Fields {
203 match &data_type {
204 DataType::Struct(f) => f,
205 _ => unreachable!(),
206 }
207}