Skip to main content

arrow_integration_test/
datatype.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit, UnionMode};
19use arrow::error::{ArrowError, Result};
20use std::sync::Arc;
21
22/// Parse a data type from a JSON representation.
23pub fn data_type_from_json(json: &serde_json::Value) -> Result<DataType> {
24    use serde_json::Value;
25    let default_field = Arc::new(Field::new("", DataType::Boolean, true));
26    match *json {
27        Value::Object(ref map) => match map.get("name") {
28            Some(s) if s == "null" => Ok(DataType::Null),
29            Some(s) if s == "bool" => Ok(DataType::Boolean),
30            Some(s) if s == "binary" => Ok(DataType::Binary),
31            Some(s) if s == "largebinary" => Ok(DataType::LargeBinary),
32            Some(s) if s == "binaryview" => Ok(DataType::BinaryView),
33            Some(s) if s == "utf8view" => Ok(DataType::Utf8View),
34            Some(s) if s == "utf8" => Ok(DataType::Utf8),
35            Some(s) if s == "largeutf8" => Ok(DataType::LargeUtf8),
36            Some(s) if s == "fixedsizebinary" => {
37                // return a list with any type as its child isn't defined in the map
38                if let Some(Value::Number(size)) = map.get("byteWidth") {
39                    Ok(DataType::FixedSizeBinary(size.as_i64().unwrap() as i32))
40                } else {
41                    Err(ArrowError::ParseError(
42                        "Expecting a byteWidth for fixedsizebinary".to_string(),
43                    ))
44                }
45            }
46            Some(s) if s == "decimal" => {
47                // return a list with any type as its child isn't defined in the map
48                let precision = match map.get("precision") {
49                    Some(p) => Ok(p.as_u64().unwrap().try_into().unwrap()),
50                    None => Err(ArrowError::ParseError(
51                        "Expecting a precision for decimal".to_string(),
52                    )),
53                }?;
54                let scale = match map.get("scale") {
55                    Some(s) => Ok(s.as_u64().unwrap().try_into().unwrap()),
56                    _ => Err(ArrowError::ParseError(
57                        "Expecting a scale for decimal".to_string(),
58                    )),
59                }?;
60                let bit_width: usize = match map.get("bitWidth") {
61                    Some(b) => b.as_u64().unwrap() as usize,
62                    _ => 128, // Default bit width
63                };
64
65                match bit_width {
66                    32 => Ok(DataType::Decimal32(precision, scale)),
67                    64 => Ok(DataType::Decimal64(precision, scale)),
68                    128 => Ok(DataType::Decimal128(precision, scale)),
69                    256 => Ok(DataType::Decimal256(precision, scale)),
70                    _ => Err(ArrowError::ParseError(
71                        "Decimal bit_width invalid".to_string(),
72                    )),
73                }
74            }
75            Some(s) if s == "floatingpoint" => match map.get("precision") {
76                Some(p) if p == "HALF" => Ok(DataType::Float16),
77                Some(p) if p == "SINGLE" => Ok(DataType::Float32),
78                Some(p) if p == "DOUBLE" => Ok(DataType::Float64),
79                _ => Err(ArrowError::ParseError(
80                    "floatingpoint precision missing or invalid".to_string(),
81                )),
82            },
83            Some(s) if s == "timestamp" => {
84                let unit = match map.get("unit") {
85                    Some(p) if p == "SECOND" => Ok(TimeUnit::Second),
86                    Some(p) if p == "MILLISECOND" => Ok(TimeUnit::Millisecond),
87                    Some(p) if p == "MICROSECOND" => Ok(TimeUnit::Microsecond),
88                    Some(p) if p == "NANOSECOND" => Ok(TimeUnit::Nanosecond),
89                    _ => Err(ArrowError::ParseError(
90                        "timestamp unit missing or invalid".to_string(),
91                    )),
92                };
93                let tz = match map.get("timezone") {
94                    None => Ok(None),
95                    Some(Value::String(tz)) => Ok(Some(tz.as_str().into())),
96                    _ => Err(ArrowError::ParseError(
97                        "timezone must be a string".to_string(),
98                    )),
99                };
100                Ok(DataType::Timestamp(unit?, tz?))
101            }
102            Some(s) if s == "date" => match map.get("unit") {
103                Some(p) if p == "DAY" => Ok(DataType::Date32),
104                Some(p) if p == "MILLISECOND" => Ok(DataType::Date64),
105                _ => Err(ArrowError::ParseError(
106                    "date unit missing or invalid".to_string(),
107                )),
108            },
109            Some(s) if s == "time" => {
110                let unit = match map.get("unit") {
111                    Some(p) if p == "SECOND" => Ok(TimeUnit::Second),
112                    Some(p) if p == "MILLISECOND" => Ok(TimeUnit::Millisecond),
113                    Some(p) if p == "MICROSECOND" => Ok(TimeUnit::Microsecond),
114                    Some(p) if p == "NANOSECOND" => Ok(TimeUnit::Nanosecond),
115                    _ => Err(ArrowError::ParseError(
116                        "time unit missing or invalid".to_string(),
117                    )),
118                };
119                match map.get("bitWidth") {
120                    Some(p) if p == 32 => Ok(DataType::Time32(unit?)),
121                    Some(p) if p == 64 => Ok(DataType::Time64(unit?)),
122                    _ => Err(ArrowError::ParseError(
123                        "time bitWidth missing or invalid".to_string(),
124                    )),
125                }
126            }
127            Some(s) if s == "duration" => match map.get("unit") {
128                Some(p) if p == "SECOND" => Ok(DataType::Duration(TimeUnit::Second)),
129                Some(p) if p == "MILLISECOND" => Ok(DataType::Duration(TimeUnit::Millisecond)),
130                Some(p) if p == "MICROSECOND" => Ok(DataType::Duration(TimeUnit::Microsecond)),
131                Some(p) if p == "NANOSECOND" => Ok(DataType::Duration(TimeUnit::Nanosecond)),
132                _ => Err(ArrowError::ParseError(
133                    "time unit missing or invalid".to_string(),
134                )),
135            },
136            Some(s) if s == "interval" => match map.get("unit") {
137                Some(p) if p == "DAY_TIME" => Ok(DataType::Interval(IntervalUnit::DayTime)),
138                Some(p) if p == "YEAR_MONTH" => Ok(DataType::Interval(IntervalUnit::YearMonth)),
139                Some(p) if p == "MONTH_DAY_NANO" => {
140                    Ok(DataType::Interval(IntervalUnit::MonthDayNano))
141                }
142                _ => Err(ArrowError::ParseError(
143                    "interval unit missing or invalid".to_string(),
144                )),
145            },
146            Some(s) if s == "int" => match map.get("isSigned") {
147                Some(&Value::Bool(true)) => match map.get("bitWidth") {
148                    Some(Value::Number(n)) => match n.as_u64() {
149                        Some(8) => Ok(DataType::Int8),
150                        Some(16) => Ok(DataType::Int16),
151                        Some(32) => Ok(DataType::Int32),
152                        Some(64) => Ok(DataType::Int64),
153                        _ => Err(ArrowError::ParseError(
154                            "int bitWidth missing or invalid".to_string(),
155                        )),
156                    },
157                    _ => Err(ArrowError::ParseError(
158                        "int bitWidth missing or invalid".to_string(),
159                    )),
160                },
161                Some(&Value::Bool(false)) => match map.get("bitWidth") {
162                    Some(Value::Number(n)) => match n.as_u64() {
163                        Some(8) => Ok(DataType::UInt8),
164                        Some(16) => Ok(DataType::UInt16),
165                        Some(32) => Ok(DataType::UInt32),
166                        Some(64) => Ok(DataType::UInt64),
167                        _ => Err(ArrowError::ParseError(
168                            "int bitWidth missing or invalid".to_string(),
169                        )),
170                    },
171                    _ => Err(ArrowError::ParseError(
172                        "int bitWidth missing or invalid".to_string(),
173                    )),
174                },
175                _ => Err(ArrowError::ParseError(
176                    "int signed missing or invalid".to_string(),
177                )),
178            },
179            Some(s) if s == "list" => {
180                // return a list with any type as its child isn't defined in the map
181                Ok(DataType::List(default_field))
182            }
183            Some(s) if s == "largelist" => {
184                // return a largelist with any type as its child isn't defined in the map
185                Ok(DataType::LargeList(default_field))
186            }
187            Some(s) if s == "listview" => {
188                // return a listview with any type as its child isn't defined in the map
189                Ok(DataType::ListView(default_field))
190            }
191            Some(s) if s == "largelistview" => {
192                // return a large listview with any type as its child isn't defined in the map
193                Ok(DataType::LargeListView(default_field))
194            }
195            Some(s) if s == "fixedsizelist" => {
196                // return a list with any type as its child isn't defined in the map
197                if let Some(Value::Number(size)) = map.get("listSize") {
198                    Ok(DataType::FixedSizeList(
199                        default_field,
200                        size.as_i64().unwrap() as i32,
201                    ))
202                } else {
203                    Err(ArrowError::ParseError(
204                        "Expecting a listSize for fixedsizelist".to_string(),
205                    ))
206                }
207            }
208            Some(s) if s == "struct" => {
209                // return an empty `struct` type as its children aren't defined in the map
210                Ok(DataType::Struct(Fields::empty()))
211            }
212            Some(s) if s == "runendencoded" => {
213                // return a run end encoded with placeholder types as children aren't defined in the map
214                Ok(DataType::RunEndEncoded(
215                    Arc::new(Field::new("run_ends", DataType::Int32, false)),
216                    default_field,
217                ))
218            }
219            Some(s) if s == "map" => {
220                if let Some(Value::Bool(keys_sorted)) = map.get("keysSorted") {
221                    // Return a map with an empty type as its children aren't defined in the map
222                    Ok(DataType::Map(default_field, *keys_sorted))
223                } else {
224                    Err(ArrowError::ParseError(
225                        "Expecting a keysSorted for map".to_string(),
226                    ))
227                }
228            }
229            Some(s) if s == "union" => {
230                if let Some(Value::String(mode)) = map.get("mode") {
231                    let union_mode = if mode == "SPARSE" {
232                        UnionMode::Sparse
233                    } else if mode == "DENSE" {
234                        UnionMode::Dense
235                    } else {
236                        return Err(ArrowError::ParseError(format!(
237                            "Unknown union mode {mode:?} for union"
238                        )));
239                    };
240                    if let Some(values) = map.get("typeIds") {
241                        let values = values.as_array().unwrap();
242                        let fields = values
243                            .iter()
244                            .map(|t| (t.as_i64().unwrap() as i8, default_field.clone()))
245                            .collect();
246
247                        Ok(DataType::Union(fields, union_mode))
248                    } else {
249                        Err(ArrowError::ParseError(
250                            "Expecting a typeIds for union ".to_string(),
251                        ))
252                    }
253                } else {
254                    Err(ArrowError::ParseError(
255                        "Expecting a mode for union".to_string(),
256                    ))
257                }
258            }
259            Some(other) => Err(ArrowError::ParseError(format!(
260                "invalid or unsupported type name: {other} in {json:?}"
261            ))),
262            None => Err(ArrowError::ParseError("type name missing".to_string())),
263        },
264        _ => Err(ArrowError::ParseError(
265            "invalid json value type".to_string(),
266        )),
267    }
268}
269
270/// Generate a JSON representation of the data type.
271pub fn data_type_to_json(data_type: &DataType) -> serde_json::Value {
272    use serde_json::json;
273    match data_type {
274        DataType::Null => json!({"name": "null"}),
275        DataType::Boolean => json!({"name": "bool"}),
276        DataType::Int8 => json!({"name": "int", "bitWidth": 8, "isSigned": true}),
277        DataType::Int16 => json!({"name": "int", "bitWidth": 16, "isSigned": true}),
278        DataType::Int32 => json!({"name": "int", "bitWidth": 32, "isSigned": true}),
279        DataType::Int64 => json!({"name": "int", "bitWidth": 64, "isSigned": true}),
280        DataType::UInt8 => json!({"name": "int", "bitWidth": 8, "isSigned": false}),
281        DataType::UInt16 => json!({"name": "int", "bitWidth": 16, "isSigned": false}),
282        DataType::UInt32 => json!({"name": "int", "bitWidth": 32, "isSigned": false}),
283        DataType::UInt64 => json!({"name": "int", "bitWidth": 64, "isSigned": false}),
284        DataType::Float16 => json!({"name": "floatingpoint", "precision": "HALF"}),
285        DataType::Float32 => json!({"name": "floatingpoint", "precision": "SINGLE"}),
286        DataType::Float64 => json!({"name": "floatingpoint", "precision": "DOUBLE"}),
287        DataType::Utf8 => json!({"name": "utf8"}),
288        DataType::LargeUtf8 => json!({"name": "largeutf8"}),
289        DataType::Binary => json!({"name": "binary"}),
290        DataType::LargeBinary => json!({"name": "largebinary"}),
291        DataType::BinaryView => json!({"name": "binaryview"}),
292        DataType::Utf8View => json!({"name": "utf8view"}),
293        DataType::FixedSizeBinary(byte_width) => {
294            json!({"name": "fixedsizebinary", "byteWidth": byte_width})
295        }
296        DataType::Struct(_) => json!({"name": "struct"}),
297        DataType::Union(_, _) => json!({"name": "union"}),
298        DataType::List(_) => json!({ "name": "list"}),
299        DataType::LargeList(_) => json!({ "name": "largelist"}),
300        DataType::ListView(_) => json!({ "name": "listview"}),
301        DataType::LargeListView(_) => json!({ "name": "largelistview"}),
302        DataType::FixedSizeList(_, length) => {
303            json!({"name":"fixedsizelist", "listSize": length})
304        }
305        DataType::Time32(unit) => {
306            json!({"name": "time", "bitWidth": 32, "unit": match unit {
307                TimeUnit::Second => "SECOND",
308                TimeUnit::Millisecond => "MILLISECOND",
309                TimeUnit::Microsecond => "MICROSECOND",
310                TimeUnit::Nanosecond => "NANOSECOND",
311            }})
312        }
313        DataType::Time64(unit) => {
314            json!({"name": "time", "bitWidth": 64, "unit": match unit {
315                TimeUnit::Second => "SECOND",
316                TimeUnit::Millisecond => "MILLISECOND",
317                TimeUnit::Microsecond => "MICROSECOND",
318                TimeUnit::Nanosecond => "NANOSECOND",
319            }})
320        }
321        DataType::Date32 => {
322            json!({"name": "date", "unit": "DAY"})
323        }
324        DataType::Date64 => {
325            json!({"name": "date", "unit": "MILLISECOND"})
326        }
327        DataType::Timestamp(unit, None) => {
328            json!({"name": "timestamp", "unit": match unit {
329                TimeUnit::Second => "SECOND",
330                TimeUnit::Millisecond => "MILLISECOND",
331                TimeUnit::Microsecond => "MICROSECOND",
332                TimeUnit::Nanosecond => "NANOSECOND",
333            }})
334        }
335        DataType::Timestamp(unit, Some(tz)) => {
336            json!({"name": "timestamp", "unit": match unit {
337                    TimeUnit::Second => "SECOND",
338                    TimeUnit::Millisecond => "MILLISECOND",
339                    TimeUnit::Microsecond => "MICROSECOND",
340                    TimeUnit::Nanosecond => "NANOSECOND",
341                }, "timezone": tz})
342        }
343        DataType::Interval(unit) => json!({"name": "interval", "unit": match unit {
344            IntervalUnit::YearMonth => "YEAR_MONTH",
345            IntervalUnit::DayTime => "DAY_TIME",
346            IntervalUnit::MonthDayNano => "MONTH_DAY_NANO",
347        }}),
348        DataType::Duration(unit) => json!({"name": "duration", "unit": match unit {
349            TimeUnit::Second => "SECOND",
350            TimeUnit::Millisecond => "MILLISECOND",
351            TimeUnit::Microsecond => "MICROSECOND",
352            TimeUnit::Nanosecond => "NANOSECOND",
353        }}),
354        DataType::Dictionary(_, _) => json!({ "name": "dictionary"}),
355        DataType::Decimal32(precision, scale) => {
356            json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 32})
357        }
358        DataType::Decimal64(precision, scale) => {
359            json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 64})
360        }
361        DataType::Decimal128(precision, scale) => {
362            json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 128})
363        }
364        DataType::Decimal256(precision, scale) => {
365            json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 256})
366        }
367        DataType::Map(_, keys_sorted) => {
368            json!({"name": "map", "keysSorted": keys_sorted})
369        }
370        DataType::RunEndEncoded(_, _) => json!({"name": "runendencoded"}),
371    }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Decode `text` as JSON and convert it to a `DataType`, panicking if
    /// either step fails.
    fn parse(text: &str) -> DataType {
        let value = serde_json::from_str::<serde_json::Value>(text).unwrap();
        data_type_from_json(&value).unwrap()
    }

    #[test]
    fn parse_utf8_from_json() {
        // A type with no parameters: the name alone selects it.
        let parsed = parse(r#"{"name":"utf8"}"#);
        assert_eq!(DataType::Utf8, parsed);
    }

    #[test]
    fn parse_int32_from_json() {
        // Signedness and bit width together select the integer type.
        let parsed = parse(r#"{"name": "int", "isSigned": true, "bitWidth": 32}"#);
        assert_eq!(DataType::Int32, parsed);
    }
}