arrow_integration_test/
datatype.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit, UnionMode};
19use arrow::error::{ArrowError, Result};
20use std::sync::Arc;
21
22/// Parse a data type from a JSON representation.
23pub fn data_type_from_json(json: &serde_json::Value) -> Result<DataType> {
24    use serde_json::Value;
25    let default_field = Arc::new(Field::new("", DataType::Boolean, true));
26    match *json {
27        Value::Object(ref map) => match map.get("name") {
28            Some(s) if s == "null" => Ok(DataType::Null),
29            Some(s) if s == "bool" => Ok(DataType::Boolean),
30            Some(s) if s == "binary" => Ok(DataType::Binary),
31            Some(s) if s == "largebinary" => Ok(DataType::LargeBinary),
32            Some(s) if s == "utf8" => Ok(DataType::Utf8),
33            Some(s) if s == "largeutf8" => Ok(DataType::LargeUtf8),
34            Some(s) if s == "fixedsizebinary" => {
35                // return a list with any type as its child isn't defined in the map
36                if let Some(Value::Number(size)) = map.get("byteWidth") {
37                    Ok(DataType::FixedSizeBinary(size.as_i64().unwrap() as i32))
38                } else {
39                    Err(ArrowError::ParseError(
40                        "Expecting a byteWidth for fixedsizebinary".to_string(),
41                    ))
42                }
43            }
44            Some(s) if s == "decimal" => {
45                // return a list with any type as its child isn't defined in the map
46                let precision = match map.get("precision") {
47                    Some(p) => Ok(p.as_u64().unwrap().try_into().unwrap()),
48                    None => Err(ArrowError::ParseError(
49                        "Expecting a precision for decimal".to_string(),
50                    )),
51                }?;
52                let scale = match map.get("scale") {
53                    Some(s) => Ok(s.as_u64().unwrap().try_into().unwrap()),
54                    _ => Err(ArrowError::ParseError(
55                        "Expecting a scale for decimal".to_string(),
56                    )),
57                }?;
58                let bit_width: usize = match map.get("bitWidth") {
59                    Some(b) => b.as_u64().unwrap() as usize,
60                    _ => 128, // Default bit width
61                };
62
63                match bit_width {
64                    32 => Ok(DataType::Decimal32(precision, scale)),
65                    64 => Ok(DataType::Decimal64(precision, scale)),
66                    128 => Ok(DataType::Decimal128(precision, scale)),
67                    256 => Ok(DataType::Decimal256(precision, scale)),
68                    _ => Err(ArrowError::ParseError(
69                        "Decimal bit_width invalid".to_string(),
70                    )),
71                }
72            }
73            Some(s) if s == "floatingpoint" => match map.get("precision") {
74                Some(p) if p == "HALF" => Ok(DataType::Float16),
75                Some(p) if p == "SINGLE" => Ok(DataType::Float32),
76                Some(p) if p == "DOUBLE" => Ok(DataType::Float64),
77                _ => Err(ArrowError::ParseError(
78                    "floatingpoint precision missing or invalid".to_string(),
79                )),
80            },
81            Some(s) if s == "timestamp" => {
82                let unit = match map.get("unit") {
83                    Some(p) if p == "SECOND" => Ok(TimeUnit::Second),
84                    Some(p) if p == "MILLISECOND" => Ok(TimeUnit::Millisecond),
85                    Some(p) if p == "MICROSECOND" => Ok(TimeUnit::Microsecond),
86                    Some(p) if p == "NANOSECOND" => Ok(TimeUnit::Nanosecond),
87                    _ => Err(ArrowError::ParseError(
88                        "timestamp unit missing or invalid".to_string(),
89                    )),
90                };
91                let tz = match map.get("timezone") {
92                    None => Ok(None),
93                    Some(Value::String(tz)) => Ok(Some(tz.as_str().into())),
94                    _ => Err(ArrowError::ParseError(
95                        "timezone must be a string".to_string(),
96                    )),
97                };
98                Ok(DataType::Timestamp(unit?, tz?))
99            }
100            Some(s) if s == "date" => match map.get("unit") {
101                Some(p) if p == "DAY" => Ok(DataType::Date32),
102                Some(p) if p == "MILLISECOND" => Ok(DataType::Date64),
103                _ => Err(ArrowError::ParseError(
104                    "date unit missing or invalid".to_string(),
105                )),
106            },
107            Some(s) if s == "time" => {
108                let unit = match map.get("unit") {
109                    Some(p) if p == "SECOND" => Ok(TimeUnit::Second),
110                    Some(p) if p == "MILLISECOND" => Ok(TimeUnit::Millisecond),
111                    Some(p) if p == "MICROSECOND" => Ok(TimeUnit::Microsecond),
112                    Some(p) if p == "NANOSECOND" => Ok(TimeUnit::Nanosecond),
113                    _ => Err(ArrowError::ParseError(
114                        "time unit missing or invalid".to_string(),
115                    )),
116                };
117                match map.get("bitWidth") {
118                    Some(p) if p == 32 => Ok(DataType::Time32(unit?)),
119                    Some(p) if p == 64 => Ok(DataType::Time64(unit?)),
120                    _ => Err(ArrowError::ParseError(
121                        "time bitWidth missing or invalid".to_string(),
122                    )),
123                }
124            }
125            Some(s) if s == "duration" => match map.get("unit") {
126                Some(p) if p == "SECOND" => Ok(DataType::Duration(TimeUnit::Second)),
127                Some(p) if p == "MILLISECOND" => Ok(DataType::Duration(TimeUnit::Millisecond)),
128                Some(p) if p == "MICROSECOND" => Ok(DataType::Duration(TimeUnit::Microsecond)),
129                Some(p) if p == "NANOSECOND" => Ok(DataType::Duration(TimeUnit::Nanosecond)),
130                _ => Err(ArrowError::ParseError(
131                    "time unit missing or invalid".to_string(),
132                )),
133            },
134            Some(s) if s == "interval" => match map.get("unit") {
135                Some(p) if p == "DAY_TIME" => Ok(DataType::Interval(IntervalUnit::DayTime)),
136                Some(p) if p == "YEAR_MONTH" => Ok(DataType::Interval(IntervalUnit::YearMonth)),
137                Some(p) if p == "MONTH_DAY_NANO" => {
138                    Ok(DataType::Interval(IntervalUnit::MonthDayNano))
139                }
140                _ => Err(ArrowError::ParseError(
141                    "interval unit missing or invalid".to_string(),
142                )),
143            },
144            Some(s) if s == "int" => match map.get("isSigned") {
145                Some(&Value::Bool(true)) => match map.get("bitWidth") {
146                    Some(Value::Number(n)) => match n.as_u64() {
147                        Some(8) => Ok(DataType::Int8),
148                        Some(16) => Ok(DataType::Int16),
149                        Some(32) => Ok(DataType::Int32),
150                        Some(64) => Ok(DataType::Int64),
151                        _ => Err(ArrowError::ParseError(
152                            "int bitWidth missing or invalid".to_string(),
153                        )),
154                    },
155                    _ => Err(ArrowError::ParseError(
156                        "int bitWidth missing or invalid".to_string(),
157                    )),
158                },
159                Some(&Value::Bool(false)) => match map.get("bitWidth") {
160                    Some(Value::Number(n)) => match n.as_u64() {
161                        Some(8) => Ok(DataType::UInt8),
162                        Some(16) => Ok(DataType::UInt16),
163                        Some(32) => Ok(DataType::UInt32),
164                        Some(64) => Ok(DataType::UInt64),
165                        _ => Err(ArrowError::ParseError(
166                            "int bitWidth missing or invalid".to_string(),
167                        )),
168                    },
169                    _ => Err(ArrowError::ParseError(
170                        "int bitWidth missing or invalid".to_string(),
171                    )),
172                },
173                _ => Err(ArrowError::ParseError(
174                    "int signed missing or invalid".to_string(),
175                )),
176            },
177            Some(s) if s == "list" => {
178                // return a list with any type as its child isn't defined in the map
179                Ok(DataType::List(default_field))
180            }
181            Some(s) if s == "largelist" => {
182                // return a largelist with any type as its child isn't defined in the map
183                Ok(DataType::LargeList(default_field))
184            }
185            Some(s) if s == "fixedsizelist" => {
186                // return a list with any type as its child isn't defined in the map
187                if let Some(Value::Number(size)) = map.get("listSize") {
188                    Ok(DataType::FixedSizeList(
189                        default_field,
190                        size.as_i64().unwrap() as i32,
191                    ))
192                } else {
193                    Err(ArrowError::ParseError(
194                        "Expecting a listSize for fixedsizelist".to_string(),
195                    ))
196                }
197            }
198            Some(s) if s == "struct" => {
199                // return an empty `struct` type as its children aren't defined in the map
200                Ok(DataType::Struct(Fields::empty()))
201            }
202            Some(s) if s == "map" => {
203                if let Some(Value::Bool(keys_sorted)) = map.get("keysSorted") {
204                    // Return a map with an empty type as its children aren't defined in the map
205                    Ok(DataType::Map(default_field, *keys_sorted))
206                } else {
207                    Err(ArrowError::ParseError(
208                        "Expecting a keysSorted for map".to_string(),
209                    ))
210                }
211            }
212            Some(s) if s == "union" => {
213                if let Some(Value::String(mode)) = map.get("mode") {
214                    let union_mode = if mode == "SPARSE" {
215                        UnionMode::Sparse
216                    } else if mode == "DENSE" {
217                        UnionMode::Dense
218                    } else {
219                        return Err(ArrowError::ParseError(format!(
220                            "Unknown union mode {mode:?} for union"
221                        )));
222                    };
223                    if let Some(values) = map.get("typeIds") {
224                        let values = values.as_array().unwrap();
225                        let fields = values
226                            .iter()
227                            .map(|t| (t.as_i64().unwrap() as i8, default_field.clone()))
228                            .collect();
229
230                        Ok(DataType::Union(fields, union_mode))
231                    } else {
232                        Err(ArrowError::ParseError(
233                            "Expecting a typeIds for union ".to_string(),
234                        ))
235                    }
236                } else {
237                    Err(ArrowError::ParseError(
238                        "Expecting a mode for union".to_string(),
239                    ))
240                }
241            }
242            Some(other) => Err(ArrowError::ParseError(format!(
243                "invalid or unsupported type name: {other} in {json:?}"
244            ))),
245            None => Err(ArrowError::ParseError("type name missing".to_string())),
246        },
247        _ => Err(ArrowError::ParseError(
248            "invalid json value type".to_string(),
249        )),
250    }
251}
252
253/// Generate a JSON representation of the data type.
254pub fn data_type_to_json(data_type: &DataType) -> serde_json::Value {
255    use serde_json::json;
256    match data_type {
257        DataType::Null => json!({"name": "null"}),
258        DataType::Boolean => json!({"name": "bool"}),
259        DataType::Int8 => json!({"name": "int", "bitWidth": 8, "isSigned": true}),
260        DataType::Int16 => json!({"name": "int", "bitWidth": 16, "isSigned": true}),
261        DataType::Int32 => json!({"name": "int", "bitWidth": 32, "isSigned": true}),
262        DataType::Int64 => json!({"name": "int", "bitWidth": 64, "isSigned": true}),
263        DataType::UInt8 => json!({"name": "int", "bitWidth": 8, "isSigned": false}),
264        DataType::UInt16 => json!({"name": "int", "bitWidth": 16, "isSigned": false}),
265        DataType::UInt32 => json!({"name": "int", "bitWidth": 32, "isSigned": false}),
266        DataType::UInt64 => json!({"name": "int", "bitWidth": 64, "isSigned": false}),
267        DataType::Float16 => json!({"name": "floatingpoint", "precision": "HALF"}),
268        DataType::Float32 => json!({"name": "floatingpoint", "precision": "SINGLE"}),
269        DataType::Float64 => json!({"name": "floatingpoint", "precision": "DOUBLE"}),
270        DataType::Utf8 => json!({"name": "utf8"}),
271        DataType::LargeUtf8 => json!({"name": "largeutf8"}),
272        DataType::Binary => json!({"name": "binary"}),
273        DataType::LargeBinary => json!({"name": "largebinary"}),
274        DataType::BinaryView | DataType::Utf8View => {
275            unimplemented!("BinaryView/Utf8View not implemented")
276        }
277        DataType::FixedSizeBinary(byte_width) => {
278            json!({"name": "fixedsizebinary", "byteWidth": byte_width})
279        }
280        DataType::Struct(_) => json!({"name": "struct"}),
281        DataType::Union(_, _) => json!({"name": "union"}),
282        DataType::List(_) => json!({ "name": "list"}),
283        DataType::LargeList(_) => json!({ "name": "largelist"}),
284        DataType::ListView(_) | DataType::LargeListView(_) => {
285            unimplemented!("ListView/LargeListView not implemented")
286        }
287        DataType::FixedSizeList(_, length) => {
288            json!({"name":"fixedsizelist", "listSize": length})
289        }
290        DataType::Time32(unit) => {
291            json!({"name": "time", "bitWidth": 32, "unit": match unit {
292                TimeUnit::Second => "SECOND",
293                TimeUnit::Millisecond => "MILLISECOND",
294                TimeUnit::Microsecond => "MICROSECOND",
295                TimeUnit::Nanosecond => "NANOSECOND",
296            }})
297        }
298        DataType::Time64(unit) => {
299            json!({"name": "time", "bitWidth": 64, "unit": match unit {
300                TimeUnit::Second => "SECOND",
301                TimeUnit::Millisecond => "MILLISECOND",
302                TimeUnit::Microsecond => "MICROSECOND",
303                TimeUnit::Nanosecond => "NANOSECOND",
304            }})
305        }
306        DataType::Date32 => {
307            json!({"name": "date", "unit": "DAY"})
308        }
309        DataType::Date64 => {
310            json!({"name": "date", "unit": "MILLISECOND"})
311        }
312        DataType::Timestamp(unit, None) => {
313            json!({"name": "timestamp", "unit": match unit {
314                TimeUnit::Second => "SECOND",
315                TimeUnit::Millisecond => "MILLISECOND",
316                TimeUnit::Microsecond => "MICROSECOND",
317                TimeUnit::Nanosecond => "NANOSECOND",
318            }})
319        }
320        DataType::Timestamp(unit, Some(tz)) => {
321            json!({"name": "timestamp", "unit": match unit {
322                    TimeUnit::Second => "SECOND",
323                    TimeUnit::Millisecond => "MILLISECOND",
324                    TimeUnit::Microsecond => "MICROSECOND",
325                    TimeUnit::Nanosecond => "NANOSECOND",
326                }, "timezone": tz})
327        }
328        DataType::Interval(unit) => json!({"name": "interval", "unit": match unit {
329            IntervalUnit::YearMonth => "YEAR_MONTH",
330            IntervalUnit::DayTime => "DAY_TIME",
331            IntervalUnit::MonthDayNano => "MONTH_DAY_NANO",
332        }}),
333        DataType::Duration(unit) => json!({"name": "duration", "unit": match unit {
334            TimeUnit::Second => "SECOND",
335            TimeUnit::Millisecond => "MILLISECOND",
336            TimeUnit::Microsecond => "MICROSECOND",
337            TimeUnit::Nanosecond => "NANOSECOND",
338        }}),
339        DataType::Dictionary(_, _) => json!({ "name": "dictionary"}),
340        DataType::Decimal32(precision, scale) => {
341            json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 32})
342        }
343        DataType::Decimal64(precision, scale) => {
344            json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 64})
345        }
346        DataType::Decimal128(precision, scale) => {
347            json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 128})
348        }
349        DataType::Decimal256(precision, scale) => {
350            json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 256})
351        }
352        DataType::Map(_, keys_sorted) => {
353            json!({"name": "map", "keysSorted": keys_sorted})
354        }
355        DataType::RunEndEncoded(_, _) => todo!(),
356    }
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;

    /// Decodes `text` as JSON and converts it to a `DataType`, panicking if
    /// either step fails.
    fn parse_type(text: &str) -> DataType {
        let parsed: Value = serde_json::from_str(text).unwrap();
        data_type_from_json(&parsed).unwrap()
    }

    #[test]
    fn parse_utf8_from_json() {
        let parsed = parse_type(r#"{"name":"utf8"}"#);
        assert_eq!(DataType::Utf8, parsed);
    }

    #[test]
    fn parse_int32_from_json() {
        let parsed = parse_type(r#"{"name": "int", "isSigned": true, "bitWidth": 32}"#);
        assert_eq!(DataType::Int32, parsed);
    }
}