arrow_integration_test/
datatype.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit, UnionMode};
19use arrow::error::{ArrowError, Result};
20use std::sync::Arc;
21
22/// Parse a data type from a JSON representation.
23pub fn data_type_from_json(json: &serde_json::Value) -> Result<DataType> {
24    use serde_json::Value;
25    let default_field = Arc::new(Field::new("", DataType::Boolean, true));
26    match *json {
27        Value::Object(ref map) => match map.get("name") {
28            Some(s) if s == "null" => Ok(DataType::Null),
29            Some(s) if s == "bool" => Ok(DataType::Boolean),
30            Some(s) if s == "binary" => Ok(DataType::Binary),
31            Some(s) if s == "largebinary" => Ok(DataType::LargeBinary),
32            Some(s) if s == "utf8" => Ok(DataType::Utf8),
33            Some(s) if s == "largeutf8" => Ok(DataType::LargeUtf8),
34            Some(s) if s == "fixedsizebinary" => {
35                // return a list with any type as its child isn't defined in the map
36                if let Some(Value::Number(size)) = map.get("byteWidth") {
37                    Ok(DataType::FixedSizeBinary(size.as_i64().unwrap() as i32))
38                } else {
39                    Err(ArrowError::ParseError(
40                        "Expecting a byteWidth for fixedsizebinary".to_string(),
41                    ))
42                }
43            }
44            Some(s) if s == "decimal" => {
45                // return a list with any type as its child isn't defined in the map
46                let precision = match map.get("precision") {
47                    Some(p) => Ok(p.as_u64().unwrap().try_into().unwrap()),
48                    None => Err(ArrowError::ParseError(
49                        "Expecting a precision for decimal".to_string(),
50                    )),
51                }?;
52                let scale = match map.get("scale") {
53                    Some(s) => Ok(s.as_u64().unwrap().try_into().unwrap()),
54                    _ => Err(ArrowError::ParseError(
55                        "Expecting a scale for decimal".to_string(),
56                    )),
57                }?;
58                let bit_width: usize = match map.get("bitWidth") {
59                    Some(b) => b.as_u64().unwrap() as usize,
60                    _ => 128, // Default bit width
61                };
62
63                match bit_width {
64                    128 => Ok(DataType::Decimal128(precision, scale)),
65                    256 => Ok(DataType::Decimal256(precision, scale)),
66                    _ => Err(ArrowError::ParseError(
67                        "Decimal bit_width invalid".to_string(),
68                    )),
69                }
70            }
71            Some(s) if s == "floatingpoint" => match map.get("precision") {
72                Some(p) if p == "HALF" => Ok(DataType::Float16),
73                Some(p) if p == "SINGLE" => Ok(DataType::Float32),
74                Some(p) if p == "DOUBLE" => Ok(DataType::Float64),
75                _ => Err(ArrowError::ParseError(
76                    "floatingpoint precision missing or invalid".to_string(),
77                )),
78            },
79            Some(s) if s == "timestamp" => {
80                let unit = match map.get("unit") {
81                    Some(p) if p == "SECOND" => Ok(TimeUnit::Second),
82                    Some(p) if p == "MILLISECOND" => Ok(TimeUnit::Millisecond),
83                    Some(p) if p == "MICROSECOND" => Ok(TimeUnit::Microsecond),
84                    Some(p) if p == "NANOSECOND" => Ok(TimeUnit::Nanosecond),
85                    _ => Err(ArrowError::ParseError(
86                        "timestamp unit missing or invalid".to_string(),
87                    )),
88                };
89                let tz = match map.get("timezone") {
90                    None => Ok(None),
91                    Some(Value::String(tz)) => Ok(Some(tz.as_str().into())),
92                    _ => Err(ArrowError::ParseError(
93                        "timezone must be a string".to_string(),
94                    )),
95                };
96                Ok(DataType::Timestamp(unit?, tz?))
97            }
98            Some(s) if s == "date" => match map.get("unit") {
99                Some(p) if p == "DAY" => Ok(DataType::Date32),
100                Some(p) if p == "MILLISECOND" => Ok(DataType::Date64),
101                _ => Err(ArrowError::ParseError(
102                    "date unit missing or invalid".to_string(),
103                )),
104            },
105            Some(s) if s == "time" => {
106                let unit = match map.get("unit") {
107                    Some(p) if p == "SECOND" => Ok(TimeUnit::Second),
108                    Some(p) if p == "MILLISECOND" => Ok(TimeUnit::Millisecond),
109                    Some(p) if p == "MICROSECOND" => Ok(TimeUnit::Microsecond),
110                    Some(p) if p == "NANOSECOND" => Ok(TimeUnit::Nanosecond),
111                    _ => Err(ArrowError::ParseError(
112                        "time unit missing or invalid".to_string(),
113                    )),
114                };
115                match map.get("bitWidth") {
116                    Some(p) if p == 32 => Ok(DataType::Time32(unit?)),
117                    Some(p) if p == 64 => Ok(DataType::Time64(unit?)),
118                    _ => Err(ArrowError::ParseError(
119                        "time bitWidth missing or invalid".to_string(),
120                    )),
121                }
122            }
123            Some(s) if s == "duration" => match map.get("unit") {
124                Some(p) if p == "SECOND" => Ok(DataType::Duration(TimeUnit::Second)),
125                Some(p) if p == "MILLISECOND" => Ok(DataType::Duration(TimeUnit::Millisecond)),
126                Some(p) if p == "MICROSECOND" => Ok(DataType::Duration(TimeUnit::Microsecond)),
127                Some(p) if p == "NANOSECOND" => Ok(DataType::Duration(TimeUnit::Nanosecond)),
128                _ => Err(ArrowError::ParseError(
129                    "time unit missing or invalid".to_string(),
130                )),
131            },
132            Some(s) if s == "interval" => match map.get("unit") {
133                Some(p) if p == "DAY_TIME" => Ok(DataType::Interval(IntervalUnit::DayTime)),
134                Some(p) if p == "YEAR_MONTH" => Ok(DataType::Interval(IntervalUnit::YearMonth)),
135                Some(p) if p == "MONTH_DAY_NANO" => {
136                    Ok(DataType::Interval(IntervalUnit::MonthDayNano))
137                }
138                _ => Err(ArrowError::ParseError(
139                    "interval unit missing or invalid".to_string(),
140                )),
141            },
142            Some(s) if s == "int" => match map.get("isSigned") {
143                Some(&Value::Bool(true)) => match map.get("bitWidth") {
144                    Some(Value::Number(n)) => match n.as_u64() {
145                        Some(8) => Ok(DataType::Int8),
146                        Some(16) => Ok(DataType::Int16),
147                        Some(32) => Ok(DataType::Int32),
148                        Some(64) => Ok(DataType::Int64),
149                        _ => Err(ArrowError::ParseError(
150                            "int bitWidth missing or invalid".to_string(),
151                        )),
152                    },
153                    _ => Err(ArrowError::ParseError(
154                        "int bitWidth missing or invalid".to_string(),
155                    )),
156                },
157                Some(&Value::Bool(false)) => match map.get("bitWidth") {
158                    Some(Value::Number(n)) => match n.as_u64() {
159                        Some(8) => Ok(DataType::UInt8),
160                        Some(16) => Ok(DataType::UInt16),
161                        Some(32) => Ok(DataType::UInt32),
162                        Some(64) => Ok(DataType::UInt64),
163                        _ => Err(ArrowError::ParseError(
164                            "int bitWidth missing or invalid".to_string(),
165                        )),
166                    },
167                    _ => Err(ArrowError::ParseError(
168                        "int bitWidth missing or invalid".to_string(),
169                    )),
170                },
171                _ => Err(ArrowError::ParseError(
172                    "int signed missing or invalid".to_string(),
173                )),
174            },
175            Some(s) if s == "list" => {
176                // return a list with any type as its child isn't defined in the map
177                Ok(DataType::List(default_field))
178            }
179            Some(s) if s == "largelist" => {
180                // return a largelist with any type as its child isn't defined in the map
181                Ok(DataType::LargeList(default_field))
182            }
183            Some(s) if s == "fixedsizelist" => {
184                // return a list with any type as its child isn't defined in the map
185                if let Some(Value::Number(size)) = map.get("listSize") {
186                    Ok(DataType::FixedSizeList(
187                        default_field,
188                        size.as_i64().unwrap() as i32,
189                    ))
190                } else {
191                    Err(ArrowError::ParseError(
192                        "Expecting a listSize for fixedsizelist".to_string(),
193                    ))
194                }
195            }
196            Some(s) if s == "struct" => {
197                // return an empty `struct` type as its children aren't defined in the map
198                Ok(DataType::Struct(Fields::empty()))
199            }
200            Some(s) if s == "map" => {
201                if let Some(Value::Bool(keys_sorted)) = map.get("keysSorted") {
202                    // Return a map with an empty type as its children aren't defined in the map
203                    Ok(DataType::Map(default_field, *keys_sorted))
204                } else {
205                    Err(ArrowError::ParseError(
206                        "Expecting a keysSorted for map".to_string(),
207                    ))
208                }
209            }
210            Some(s) if s == "union" => {
211                if let Some(Value::String(mode)) = map.get("mode") {
212                    let union_mode = if mode == "SPARSE" {
213                        UnionMode::Sparse
214                    } else if mode == "DENSE" {
215                        UnionMode::Dense
216                    } else {
217                        return Err(ArrowError::ParseError(format!(
218                            "Unknown union mode {mode:?} for union"
219                        )));
220                    };
221                    if let Some(values) = map.get("typeIds") {
222                        let values = values.as_array().unwrap();
223                        let fields = values
224                            .iter()
225                            .map(|t| (t.as_i64().unwrap() as i8, default_field.clone()))
226                            .collect();
227
228                        Ok(DataType::Union(fields, union_mode))
229                    } else {
230                        Err(ArrowError::ParseError(
231                            "Expecting a typeIds for union ".to_string(),
232                        ))
233                    }
234                } else {
235                    Err(ArrowError::ParseError(
236                        "Expecting a mode for union".to_string(),
237                    ))
238                }
239            }
240            Some(other) => Err(ArrowError::ParseError(format!(
241                "invalid or unsupported type name: {other} in {json:?}"
242            ))),
243            None => Err(ArrowError::ParseError("type name missing".to_string())),
244        },
245        _ => Err(ArrowError::ParseError(
246            "invalid json value type".to_string(),
247        )),
248    }
249}
250
251/// Generate a JSON representation of the data type.
252pub fn data_type_to_json(data_type: &DataType) -> serde_json::Value {
253    use serde_json::json;
254    match data_type {
255        DataType::Null => json!({"name": "null"}),
256        DataType::Boolean => json!({"name": "bool"}),
257        DataType::Int8 => json!({"name": "int", "bitWidth": 8, "isSigned": true}),
258        DataType::Int16 => json!({"name": "int", "bitWidth": 16, "isSigned": true}),
259        DataType::Int32 => json!({"name": "int", "bitWidth": 32, "isSigned": true}),
260        DataType::Int64 => json!({"name": "int", "bitWidth": 64, "isSigned": true}),
261        DataType::UInt8 => json!({"name": "int", "bitWidth": 8, "isSigned": false}),
262        DataType::UInt16 => json!({"name": "int", "bitWidth": 16, "isSigned": false}),
263        DataType::UInt32 => json!({"name": "int", "bitWidth": 32, "isSigned": false}),
264        DataType::UInt64 => json!({"name": "int", "bitWidth": 64, "isSigned": false}),
265        DataType::Float16 => json!({"name": "floatingpoint", "precision": "HALF"}),
266        DataType::Float32 => json!({"name": "floatingpoint", "precision": "SINGLE"}),
267        DataType::Float64 => json!({"name": "floatingpoint", "precision": "DOUBLE"}),
268        DataType::Utf8 => json!({"name": "utf8"}),
269        DataType::LargeUtf8 => json!({"name": "largeutf8"}),
270        DataType::Binary => json!({"name": "binary"}),
271        DataType::LargeBinary => json!({"name": "largebinary"}),
272        DataType::BinaryView | DataType::Utf8View => {
273            unimplemented!("BinaryView/Utf8View not implemented")
274        }
275        DataType::FixedSizeBinary(byte_width) => {
276            json!({"name": "fixedsizebinary", "byteWidth": byte_width})
277        }
278        DataType::Struct(_) => json!({"name": "struct"}),
279        DataType::Union(_, _) => json!({"name": "union"}),
280        DataType::List(_) => json!({ "name": "list"}),
281        DataType::LargeList(_) => json!({ "name": "largelist"}),
282        DataType::ListView(_) | DataType::LargeListView(_) => {
283            unimplemented!("ListView/LargeListView not implemented")
284        }
285        DataType::FixedSizeList(_, length) => {
286            json!({"name":"fixedsizelist", "listSize": length})
287        }
288        DataType::Time32(unit) => {
289            json!({"name": "time", "bitWidth": 32, "unit": match unit {
290                TimeUnit::Second => "SECOND",
291                TimeUnit::Millisecond => "MILLISECOND",
292                TimeUnit::Microsecond => "MICROSECOND",
293                TimeUnit::Nanosecond => "NANOSECOND",
294            }})
295        }
296        DataType::Time64(unit) => {
297            json!({"name": "time", "bitWidth": 64, "unit": match unit {
298                TimeUnit::Second => "SECOND",
299                TimeUnit::Millisecond => "MILLISECOND",
300                TimeUnit::Microsecond => "MICROSECOND",
301                TimeUnit::Nanosecond => "NANOSECOND",
302            }})
303        }
304        DataType::Date32 => {
305            json!({"name": "date", "unit": "DAY"})
306        }
307        DataType::Date64 => {
308            json!({"name": "date", "unit": "MILLISECOND"})
309        }
310        DataType::Timestamp(unit, None) => {
311            json!({"name": "timestamp", "unit": match unit {
312                TimeUnit::Second => "SECOND",
313                TimeUnit::Millisecond => "MILLISECOND",
314                TimeUnit::Microsecond => "MICROSECOND",
315                TimeUnit::Nanosecond => "NANOSECOND",
316            }})
317        }
318        DataType::Timestamp(unit, Some(tz)) => {
319            json!({"name": "timestamp", "unit": match unit {
320                    TimeUnit::Second => "SECOND",
321                    TimeUnit::Millisecond => "MILLISECOND",
322                    TimeUnit::Microsecond => "MICROSECOND",
323                    TimeUnit::Nanosecond => "NANOSECOND",
324                }, "timezone": tz})
325        }
326        DataType::Interval(unit) => json!({"name": "interval", "unit": match unit {
327            IntervalUnit::YearMonth => "YEAR_MONTH",
328            IntervalUnit::DayTime => "DAY_TIME",
329            IntervalUnit::MonthDayNano => "MONTH_DAY_NANO",
330        }}),
331        DataType::Duration(unit) => json!({"name": "duration", "unit": match unit {
332            TimeUnit::Second => "SECOND",
333            TimeUnit::Millisecond => "MILLISECOND",
334            TimeUnit::Microsecond => "MICROSECOND",
335            TimeUnit::Nanosecond => "NANOSECOND",
336        }}),
337        DataType::Dictionary(_, _) => json!({ "name": "dictionary"}),
338        DataType::Decimal128(precision, scale) => {
339            json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 128})
340        }
341        DataType::Decimal256(precision, scale) => {
342            json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 256})
343        }
344        DataType::Map(_, keys_sorted) => {
345            json!({"name": "map", "keysSorted": keys_sorted})
346        }
347        DataType::RunEndEncoded(_, _) => todo!(),
348    }
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;

    /// Deserializes `text` and decodes it with `data_type_from_json`,
    /// panicking if either step fails.
    fn parse(text: &str) -> DataType {
        let value: Value = serde_json::from_str(text).unwrap();
        data_type_from_json(&value).unwrap()
    }

    #[test]
    fn parse_utf8_from_json() {
        assert_eq!(parse(r#"{"name":"utf8"}"#), DataType::Utf8);
    }

    #[test]
    fn parse_int32_from_json() {
        let parsed = parse(r#"{"name": "int", "isSigned": true, "bitWidth": 32}"#);
        assert_eq!(parsed, DataType::Int32);
    }
}