arrow_json/reader/struct_array.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::reader::tape::{Tape, TapeElement};
19use crate::reader::{make_decoder, ArrayDecoder, StructMode};
20use arrow_array::builder::BooleanBufferBuilder;
21use arrow_buffer::buffer::NullBuffer;
22use arrow_data::{ArrayData, ArrayDataBuilder};
23use arrow_schema::{ArrowError, DataType, Fields};
24
25pub struct StructArrayDecoder {
26    data_type: DataType,
27    decoders: Vec<Box<dyn ArrayDecoder>>,
28    strict_mode: bool,
29    is_nullable: bool,
30    struct_mode: StructMode,
31}
32
33impl StructArrayDecoder {
34    pub fn new(
35        data_type: DataType,
36        coerce_primitive: bool,
37        strict_mode: bool,
38        is_nullable: bool,
39        struct_mode: StructMode,
40    ) -> Result<Self, ArrowError> {
41        let decoders = struct_fields(&data_type)
42            .iter()
43            .map(|f| {
44                // If this struct nullable, need to permit nullability in child array
45                // StructArrayDecoder::decode verifies that if the child is not nullable
46                // it doesn't contain any nulls not masked by its parent
47                let nullable = f.is_nullable() || is_nullable;
48                make_decoder(
49                    f.data_type().clone(),
50                    coerce_primitive,
51                    strict_mode,
52                    nullable,
53                    struct_mode,
54                )
55            })
56            .collect::<Result<Vec<_>, ArrowError>>()?;
57
58        Ok(Self {
59            data_type,
60            decoders,
61            strict_mode,
62            is_nullable,
63            struct_mode,
64        })
65    }
66}
67
68impl ArrayDecoder for StructArrayDecoder {
69    fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayData, ArrowError> {
70        let fields = struct_fields(&self.data_type);
71        let mut child_pos: Vec<_> = (0..fields.len()).map(|_| vec![0; pos.len()]).collect();
72
73        let mut nulls = self
74            .is_nullable
75            .then(|| BooleanBufferBuilder::new(pos.len()));
76
77        // We avoid having the match on self.struct_mode inside the hot loop for performance
78        // TODO: Investigate how to extract duplicated logic.
79        match self.struct_mode {
80            StructMode::ObjectOnly => {
81                for (row, p) in pos.iter().enumerate() {
82                    let end_idx = match (tape.get(*p), nulls.as_mut()) {
83                        (TapeElement::StartObject(end_idx), None) => end_idx,
84                        (TapeElement::StartObject(end_idx), Some(nulls)) => {
85                            nulls.append(true);
86                            end_idx
87                        }
88                        (TapeElement::Null, Some(nulls)) => {
89                            nulls.append(false);
90                            continue;
91                        }
92                        (_, _) => return Err(tape.error(*p, "{")),
93                    };
94
95                    let mut cur_idx = *p + 1;
96                    while cur_idx < end_idx {
97                        // Read field name
98                        let field_name = match tape.get(cur_idx) {
99                            TapeElement::String(s) => tape.get_string(s),
100                            _ => return Err(tape.error(cur_idx, "field name")),
101                        };
102
103                        // Update child pos if match found
104                        match fields.iter().position(|x| x.name() == field_name) {
105                            Some(field_idx) => child_pos[field_idx][row] = cur_idx + 1,
106                            None => {
107                                if self.strict_mode {
108                                    return Err(ArrowError::JsonError(format!(
109                                        "column '{}' missing from schema",
110                                        field_name
111                                    )));
112                                }
113                            }
114                        }
115                        // Advance to next field
116                        cur_idx = tape.next(cur_idx + 1, "field value")?;
117                    }
118                }
119            }
120            StructMode::ListOnly => {
121                for (row, p) in pos.iter().enumerate() {
122                    let end_idx = match (tape.get(*p), nulls.as_mut()) {
123                        (TapeElement::StartList(end_idx), None) => end_idx,
124                        (TapeElement::StartList(end_idx), Some(nulls)) => {
125                            nulls.append(true);
126                            end_idx
127                        }
128                        (TapeElement::Null, Some(nulls)) => {
129                            nulls.append(false);
130                            continue;
131                        }
132                        (_, _) => return Err(tape.error(*p, "[")),
133                    };
134
135                    let mut cur_idx = *p + 1;
136                    let mut entry_idx = 0;
137                    while cur_idx < end_idx {
138                        if entry_idx >= fields.len() {
139                            return Err(ArrowError::JsonError(format!(
140                                "found extra columns for {} fields",
141                                fields.len()
142                            )));
143                        }
144                        child_pos[entry_idx][row] = cur_idx;
145                        entry_idx += 1;
146                        // Advance to next field
147                        cur_idx = tape.next(cur_idx, "field value")?;
148                    }
149                    if entry_idx != fields.len() {
150                        return Err(ArrowError::JsonError(format!(
151                            "found {} columns for {} fields",
152                            entry_idx,
153                            fields.len()
154                        )));
155                    }
156                }
157            }
158        }
159
160        let child_data = self
161            .decoders
162            .iter_mut()
163            .zip(child_pos)
164            .zip(fields)
165            .map(|((d, pos), f)| {
166                d.decode(tape, &pos).map_err(|e| match e {
167                    ArrowError::JsonError(s) => {
168                        ArrowError::JsonError(format!("whilst decoding field '{}': {s}", f.name()))
169                    }
170                    e => e,
171                })
172            })
173            .collect::<Result<Vec<_>, ArrowError>>()?;
174
175        let nulls = nulls.as_mut().map(|x| NullBuffer::new(x.finish()));
176
177        for (c, f) in child_data.iter().zip(fields) {
178            // Sanity check
179            assert_eq!(c.len(), pos.len());
180            if let Some(a) = c.nulls() {
181                let nulls_valid =
182                    f.is_nullable() || nulls.as_ref().map(|n| n.contains(a)).unwrap_or_default();
183
184                if !nulls_valid {
185                    return Err(ArrowError::JsonError(format!(
186                        "Encountered unmasked nulls in non-nullable StructArray child: {f}"
187                    )));
188                }
189            }
190        }
191
192        let data = ArrayDataBuilder::new(self.data_type.clone())
193            .len(pos.len())
194            .nulls(nulls)
195            .child_data(child_data);
196
197        // Safety
198        // Validated lengths above
199        Ok(unsafe { data.build_unchecked() })
200    }
201}
202
203fn struct_fields(data_type: &DataType) -> &Fields {
204    match &data_type {
205        DataType::Struct(f) => f,
206        _ => unreachable!(),
207    }
208}