arrow_json/reader/
struct_array.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::reader::tape::{Tape, TapeElement};
19use crate::reader::{make_decoder, ArrayDecoder, StructMode};
20use arrow_array::builder::BooleanBufferBuilder;
21use arrow_buffer::buffer::NullBuffer;
22use arrow_data::{ArrayData, ArrayDataBuilder};
23use arrow_schema::{ArrowError, DataType, Fields};
24
25pub struct StructArrayDecoder {
26    data_type: DataType,
27    decoders: Vec<Box<dyn ArrayDecoder>>,
28    strict_mode: bool,
29    is_nullable: bool,
30    struct_mode: StructMode,
31}
32
33impl StructArrayDecoder {
34    pub fn new(
35        data_type: DataType,
36        coerce_primitive: bool,
37        strict_mode: bool,
38        is_nullable: bool,
39        struct_mode: StructMode,
40    ) -> Result<Self, ArrowError> {
41        let decoders = struct_fields(&data_type)
42            .iter()
43            .map(|f| {
44                // If this struct nullable, need to permit nullability in child array
45                // StructArrayDecoder::decode verifies that if the child is not nullable
46                // it doesn't contain any nulls not masked by its parent
47                let nullable = f.is_nullable() || is_nullable;
48                make_decoder(
49                    f.data_type().clone(),
50                    coerce_primitive,
51                    strict_mode,
52                    nullable,
53                    struct_mode,
54                )
55            })
56            .collect::<Result<Vec<_>, ArrowError>>()?;
57
58        Ok(Self {
59            data_type,
60            decoders,
61            strict_mode,
62            is_nullable,
63            struct_mode,
64        })
65    }
66}
67
68impl ArrayDecoder for StructArrayDecoder {
69    fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayData, ArrowError> {
70        let fields = struct_fields(&self.data_type);
71        let mut child_pos: Vec<_> = (0..fields.len()).map(|_| vec![0; pos.len()]).collect();
72
73        let mut nulls = self
74            .is_nullable
75            .then(|| BooleanBufferBuilder::new(pos.len()));
76
77        // We avoid having the match on self.struct_mode inside the hot loop for performance
78        // TODO: Investigate how to extract duplicated logic.
79        match self.struct_mode {
80            StructMode::ObjectOnly => {
81                for (row, p) in pos.iter().enumerate() {
82                    let end_idx = match (tape.get(*p), nulls.as_mut()) {
83                        (TapeElement::StartObject(end_idx), None) => end_idx,
84                        (TapeElement::StartObject(end_idx), Some(nulls)) => {
85                            nulls.append(true);
86                            end_idx
87                        }
88                        (TapeElement::Null, Some(nulls)) => {
89                            nulls.append(false);
90                            continue;
91                        }
92                        (_, _) => return Err(tape.error(*p, "{")),
93                    };
94
95                    let mut cur_idx = *p + 1;
96                    while cur_idx < end_idx {
97                        // Read field name
98                        let field_name = match tape.get(cur_idx) {
99                            TapeElement::String(s) => tape.get_string(s),
100                            _ => return Err(tape.error(cur_idx, "field name")),
101                        };
102
103                        // Update child pos if match found
104                        match fields.iter().position(|x| x.name() == field_name) {
105                            Some(field_idx) => child_pos[field_idx][row] = cur_idx + 1,
106                            None => {
107                                if self.strict_mode {
108                                    return Err(ArrowError::JsonError(format!(
109                                        "column '{field_name}' missing from schema",
110                                    )));
111                                }
112                            }
113                        }
114                        // Advance to next field
115                        cur_idx = tape.next(cur_idx + 1, "field value")?;
116                    }
117                }
118            }
119            StructMode::ListOnly => {
120                for (row, p) in pos.iter().enumerate() {
121                    let end_idx = match (tape.get(*p), nulls.as_mut()) {
122                        (TapeElement::StartList(end_idx), None) => end_idx,
123                        (TapeElement::StartList(end_idx), Some(nulls)) => {
124                            nulls.append(true);
125                            end_idx
126                        }
127                        (TapeElement::Null, Some(nulls)) => {
128                            nulls.append(false);
129                            continue;
130                        }
131                        (_, _) => return Err(tape.error(*p, "[")),
132                    };
133
134                    let mut cur_idx = *p + 1;
135                    let mut entry_idx = 0;
136                    while cur_idx < end_idx {
137                        if entry_idx >= fields.len() {
138                            return Err(ArrowError::JsonError(format!(
139                                "found extra columns for {} fields",
140                                fields.len()
141                            )));
142                        }
143                        child_pos[entry_idx][row] = cur_idx;
144                        entry_idx += 1;
145                        // Advance to next field
146                        cur_idx = tape.next(cur_idx, "field value")?;
147                    }
148                    if entry_idx != fields.len() {
149                        return Err(ArrowError::JsonError(format!(
150                            "found {} columns for {} fields",
151                            entry_idx,
152                            fields.len()
153                        )));
154                    }
155                }
156            }
157        }
158
159        let child_data = self
160            .decoders
161            .iter_mut()
162            .zip(child_pos)
163            .zip(fields)
164            .map(|((d, pos), f)| {
165                d.decode(tape, &pos).map_err(|e| match e {
166                    ArrowError::JsonError(s) => {
167                        ArrowError::JsonError(format!("whilst decoding field '{}': {s}", f.name()))
168                    }
169                    e => e,
170                })
171            })
172            .collect::<Result<Vec<_>, ArrowError>>()?;
173
174        let nulls = nulls.as_mut().map(|x| NullBuffer::new(x.finish()));
175
176        for (c, f) in child_data.iter().zip(fields) {
177            // Sanity check
178            assert_eq!(c.len(), pos.len());
179            if let Some(a) = c.nulls() {
180                let nulls_valid =
181                    f.is_nullable() || nulls.as_ref().map(|n| n.contains(a)).unwrap_or_default();
182
183                if !nulls_valid {
184                    return Err(ArrowError::JsonError(format!(
185                        "Encountered unmasked nulls in non-nullable StructArray child: {f}"
186                    )));
187                }
188            }
189        }
190
191        let data = ArrayDataBuilder::new(self.data_type.clone())
192            .len(pos.len())
193            .nulls(nulls)
194            .child_data(child_data);
195
196        // Safety
197        // Validated lengths above
198        Ok(unsafe { data.build_unchecked() })
199    }
200}
201
202fn struct_fields(data_type: &DataType) -> &Fields {
203    match &data_type {
204        DataType::Struct(f) => f,
205        _ => unreachable!(),
206    }
207}