1use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit, UnionMode};
19use arrow::error::{ArrowError, Result};
20use std::sync::Arc;
21
22pub fn data_type_from_json(json: &serde_json::Value) -> Result<DataType> {
24 use serde_json::Value;
25 let default_field = Arc::new(Field::new("", DataType::Boolean, true));
26 match *json {
27 Value::Object(ref map) => match map.get("name") {
28 Some(s) if s == "null" => Ok(DataType::Null),
29 Some(s) if s == "bool" => Ok(DataType::Boolean),
30 Some(s) if s == "binary" => Ok(DataType::Binary),
31 Some(s) if s == "largebinary" => Ok(DataType::LargeBinary),
32 Some(s) if s == "utf8" => Ok(DataType::Utf8),
33 Some(s) if s == "largeutf8" => Ok(DataType::LargeUtf8),
34 Some(s) if s == "fixedsizebinary" => {
35 if let Some(Value::Number(size)) = map.get("byteWidth") {
37 Ok(DataType::FixedSizeBinary(size.as_i64().unwrap() as i32))
38 } else {
39 Err(ArrowError::ParseError(
40 "Expecting a byteWidth for fixedsizebinary".to_string(),
41 ))
42 }
43 }
44 Some(s) if s == "decimal" => {
45 let precision = match map.get("precision") {
47 Some(p) => Ok(p.as_u64().unwrap().try_into().unwrap()),
48 None => Err(ArrowError::ParseError(
49 "Expecting a precision for decimal".to_string(),
50 )),
51 }?;
52 let scale = match map.get("scale") {
53 Some(s) => Ok(s.as_u64().unwrap().try_into().unwrap()),
54 _ => Err(ArrowError::ParseError(
55 "Expecting a scale for decimal".to_string(),
56 )),
57 }?;
58 let bit_width: usize = match map.get("bitWidth") {
59 Some(b) => b.as_u64().unwrap() as usize,
60 _ => 128, };
62
63 match bit_width {
64 32 => Ok(DataType::Decimal32(precision, scale)),
65 64 => Ok(DataType::Decimal64(precision, scale)),
66 128 => Ok(DataType::Decimal128(precision, scale)),
67 256 => Ok(DataType::Decimal256(precision, scale)),
68 _ => Err(ArrowError::ParseError(
69 "Decimal bit_width invalid".to_string(),
70 )),
71 }
72 }
73 Some(s) if s == "floatingpoint" => match map.get("precision") {
74 Some(p) if p == "HALF" => Ok(DataType::Float16),
75 Some(p) if p == "SINGLE" => Ok(DataType::Float32),
76 Some(p) if p == "DOUBLE" => Ok(DataType::Float64),
77 _ => Err(ArrowError::ParseError(
78 "floatingpoint precision missing or invalid".to_string(),
79 )),
80 },
81 Some(s) if s == "timestamp" => {
82 let unit = match map.get("unit") {
83 Some(p) if p == "SECOND" => Ok(TimeUnit::Second),
84 Some(p) if p == "MILLISECOND" => Ok(TimeUnit::Millisecond),
85 Some(p) if p == "MICROSECOND" => Ok(TimeUnit::Microsecond),
86 Some(p) if p == "NANOSECOND" => Ok(TimeUnit::Nanosecond),
87 _ => Err(ArrowError::ParseError(
88 "timestamp unit missing or invalid".to_string(),
89 )),
90 };
91 let tz = match map.get("timezone") {
92 None => Ok(None),
93 Some(Value::String(tz)) => Ok(Some(tz.as_str().into())),
94 _ => Err(ArrowError::ParseError(
95 "timezone must be a string".to_string(),
96 )),
97 };
98 Ok(DataType::Timestamp(unit?, tz?))
99 }
100 Some(s) if s == "date" => match map.get("unit") {
101 Some(p) if p == "DAY" => Ok(DataType::Date32),
102 Some(p) if p == "MILLISECOND" => Ok(DataType::Date64),
103 _ => Err(ArrowError::ParseError(
104 "date unit missing or invalid".to_string(),
105 )),
106 },
107 Some(s) if s == "time" => {
108 let unit = match map.get("unit") {
109 Some(p) if p == "SECOND" => Ok(TimeUnit::Second),
110 Some(p) if p == "MILLISECOND" => Ok(TimeUnit::Millisecond),
111 Some(p) if p == "MICROSECOND" => Ok(TimeUnit::Microsecond),
112 Some(p) if p == "NANOSECOND" => Ok(TimeUnit::Nanosecond),
113 _ => Err(ArrowError::ParseError(
114 "time unit missing or invalid".to_string(),
115 )),
116 };
117 match map.get("bitWidth") {
118 Some(p) if p == 32 => Ok(DataType::Time32(unit?)),
119 Some(p) if p == 64 => Ok(DataType::Time64(unit?)),
120 _ => Err(ArrowError::ParseError(
121 "time bitWidth missing or invalid".to_string(),
122 )),
123 }
124 }
125 Some(s) if s == "duration" => match map.get("unit") {
126 Some(p) if p == "SECOND" => Ok(DataType::Duration(TimeUnit::Second)),
127 Some(p) if p == "MILLISECOND" => Ok(DataType::Duration(TimeUnit::Millisecond)),
128 Some(p) if p == "MICROSECOND" => Ok(DataType::Duration(TimeUnit::Microsecond)),
129 Some(p) if p == "NANOSECOND" => Ok(DataType::Duration(TimeUnit::Nanosecond)),
130 _ => Err(ArrowError::ParseError(
131 "time unit missing or invalid".to_string(),
132 )),
133 },
134 Some(s) if s == "interval" => match map.get("unit") {
135 Some(p) if p == "DAY_TIME" => Ok(DataType::Interval(IntervalUnit::DayTime)),
136 Some(p) if p == "YEAR_MONTH" => Ok(DataType::Interval(IntervalUnit::YearMonth)),
137 Some(p) if p == "MONTH_DAY_NANO" => {
138 Ok(DataType::Interval(IntervalUnit::MonthDayNano))
139 }
140 _ => Err(ArrowError::ParseError(
141 "interval unit missing or invalid".to_string(),
142 )),
143 },
144 Some(s) if s == "int" => match map.get("isSigned") {
145 Some(&Value::Bool(true)) => match map.get("bitWidth") {
146 Some(Value::Number(n)) => match n.as_u64() {
147 Some(8) => Ok(DataType::Int8),
148 Some(16) => Ok(DataType::Int16),
149 Some(32) => Ok(DataType::Int32),
150 Some(64) => Ok(DataType::Int64),
151 _ => Err(ArrowError::ParseError(
152 "int bitWidth missing or invalid".to_string(),
153 )),
154 },
155 _ => Err(ArrowError::ParseError(
156 "int bitWidth missing or invalid".to_string(),
157 )),
158 },
159 Some(&Value::Bool(false)) => match map.get("bitWidth") {
160 Some(Value::Number(n)) => match n.as_u64() {
161 Some(8) => Ok(DataType::UInt8),
162 Some(16) => Ok(DataType::UInt16),
163 Some(32) => Ok(DataType::UInt32),
164 Some(64) => Ok(DataType::UInt64),
165 _ => Err(ArrowError::ParseError(
166 "int bitWidth missing or invalid".to_string(),
167 )),
168 },
169 _ => Err(ArrowError::ParseError(
170 "int bitWidth missing or invalid".to_string(),
171 )),
172 },
173 _ => Err(ArrowError::ParseError(
174 "int signed missing or invalid".to_string(),
175 )),
176 },
177 Some(s) if s == "list" => {
178 Ok(DataType::List(default_field))
180 }
181 Some(s) if s == "largelist" => {
182 Ok(DataType::LargeList(default_field))
184 }
185 Some(s) if s == "fixedsizelist" => {
186 if let Some(Value::Number(size)) = map.get("listSize") {
188 Ok(DataType::FixedSizeList(
189 default_field,
190 size.as_i64().unwrap() as i32,
191 ))
192 } else {
193 Err(ArrowError::ParseError(
194 "Expecting a listSize for fixedsizelist".to_string(),
195 ))
196 }
197 }
198 Some(s) if s == "struct" => {
199 Ok(DataType::Struct(Fields::empty()))
201 }
202 Some(s) if s == "map" => {
203 if let Some(Value::Bool(keys_sorted)) = map.get("keysSorted") {
204 Ok(DataType::Map(default_field, *keys_sorted))
206 } else {
207 Err(ArrowError::ParseError(
208 "Expecting a keysSorted for map".to_string(),
209 ))
210 }
211 }
212 Some(s) if s == "union" => {
213 if let Some(Value::String(mode)) = map.get("mode") {
214 let union_mode = if mode == "SPARSE" {
215 UnionMode::Sparse
216 } else if mode == "DENSE" {
217 UnionMode::Dense
218 } else {
219 return Err(ArrowError::ParseError(format!(
220 "Unknown union mode {mode:?} for union"
221 )));
222 };
223 if let Some(values) = map.get("typeIds") {
224 let values = values.as_array().unwrap();
225 let fields = values
226 .iter()
227 .map(|t| (t.as_i64().unwrap() as i8, default_field.clone()))
228 .collect();
229
230 Ok(DataType::Union(fields, union_mode))
231 } else {
232 Err(ArrowError::ParseError(
233 "Expecting a typeIds for union ".to_string(),
234 ))
235 }
236 } else {
237 Err(ArrowError::ParseError(
238 "Expecting a mode for union".to_string(),
239 ))
240 }
241 }
242 Some(other) => Err(ArrowError::ParseError(format!(
243 "invalid or unsupported type name: {other} in {json:?}"
244 ))),
245 None => Err(ArrowError::ParseError("type name missing".to_string())),
246 },
247 _ => Err(ArrowError::ParseError(
248 "invalid json value type".to_string(),
249 )),
250 }
251}
252
253pub fn data_type_to_json(data_type: &DataType) -> serde_json::Value {
255 use serde_json::json;
256 match data_type {
257 DataType::Null => json!({"name": "null"}),
258 DataType::Boolean => json!({"name": "bool"}),
259 DataType::Int8 => json!({"name": "int", "bitWidth": 8, "isSigned": true}),
260 DataType::Int16 => json!({"name": "int", "bitWidth": 16, "isSigned": true}),
261 DataType::Int32 => json!({"name": "int", "bitWidth": 32, "isSigned": true}),
262 DataType::Int64 => json!({"name": "int", "bitWidth": 64, "isSigned": true}),
263 DataType::UInt8 => json!({"name": "int", "bitWidth": 8, "isSigned": false}),
264 DataType::UInt16 => json!({"name": "int", "bitWidth": 16, "isSigned": false}),
265 DataType::UInt32 => json!({"name": "int", "bitWidth": 32, "isSigned": false}),
266 DataType::UInt64 => json!({"name": "int", "bitWidth": 64, "isSigned": false}),
267 DataType::Float16 => json!({"name": "floatingpoint", "precision": "HALF"}),
268 DataType::Float32 => json!({"name": "floatingpoint", "precision": "SINGLE"}),
269 DataType::Float64 => json!({"name": "floatingpoint", "precision": "DOUBLE"}),
270 DataType::Utf8 => json!({"name": "utf8"}),
271 DataType::LargeUtf8 => json!({"name": "largeutf8"}),
272 DataType::Binary => json!({"name": "binary"}),
273 DataType::LargeBinary => json!({"name": "largebinary"}),
274 DataType::BinaryView | DataType::Utf8View => {
275 unimplemented!("BinaryView/Utf8View not implemented")
276 }
277 DataType::FixedSizeBinary(byte_width) => {
278 json!({"name": "fixedsizebinary", "byteWidth": byte_width})
279 }
280 DataType::Struct(_) => json!({"name": "struct"}),
281 DataType::Union(_, _) => json!({"name": "union"}),
282 DataType::List(_) => json!({ "name": "list"}),
283 DataType::LargeList(_) => json!({ "name": "largelist"}),
284 DataType::ListView(_) | DataType::LargeListView(_) => {
285 unimplemented!("ListView/LargeListView not implemented")
286 }
287 DataType::FixedSizeList(_, length) => {
288 json!({"name":"fixedsizelist", "listSize": length})
289 }
290 DataType::Time32(unit) => {
291 json!({"name": "time", "bitWidth": 32, "unit": match unit {
292 TimeUnit::Second => "SECOND",
293 TimeUnit::Millisecond => "MILLISECOND",
294 TimeUnit::Microsecond => "MICROSECOND",
295 TimeUnit::Nanosecond => "NANOSECOND",
296 }})
297 }
298 DataType::Time64(unit) => {
299 json!({"name": "time", "bitWidth": 64, "unit": match unit {
300 TimeUnit::Second => "SECOND",
301 TimeUnit::Millisecond => "MILLISECOND",
302 TimeUnit::Microsecond => "MICROSECOND",
303 TimeUnit::Nanosecond => "NANOSECOND",
304 }})
305 }
306 DataType::Date32 => {
307 json!({"name": "date", "unit": "DAY"})
308 }
309 DataType::Date64 => {
310 json!({"name": "date", "unit": "MILLISECOND"})
311 }
312 DataType::Timestamp(unit, None) => {
313 json!({"name": "timestamp", "unit": match unit {
314 TimeUnit::Second => "SECOND",
315 TimeUnit::Millisecond => "MILLISECOND",
316 TimeUnit::Microsecond => "MICROSECOND",
317 TimeUnit::Nanosecond => "NANOSECOND",
318 }})
319 }
320 DataType::Timestamp(unit, Some(tz)) => {
321 json!({"name": "timestamp", "unit": match unit {
322 TimeUnit::Second => "SECOND",
323 TimeUnit::Millisecond => "MILLISECOND",
324 TimeUnit::Microsecond => "MICROSECOND",
325 TimeUnit::Nanosecond => "NANOSECOND",
326 }, "timezone": tz})
327 }
328 DataType::Interval(unit) => json!({"name": "interval", "unit": match unit {
329 IntervalUnit::YearMonth => "YEAR_MONTH",
330 IntervalUnit::DayTime => "DAY_TIME",
331 IntervalUnit::MonthDayNano => "MONTH_DAY_NANO",
332 }}),
333 DataType::Duration(unit) => json!({"name": "duration", "unit": match unit {
334 TimeUnit::Second => "SECOND",
335 TimeUnit::Millisecond => "MILLISECOND",
336 TimeUnit::Microsecond => "MICROSECOND",
337 TimeUnit::Nanosecond => "NANOSECOND",
338 }}),
339 DataType::Dictionary(_, _) => json!({ "name": "dictionary"}),
340 DataType::Decimal32(precision, scale) => {
341 json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 32})
342 }
343 DataType::Decimal64(precision, scale) => {
344 json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 64})
345 }
346 DataType::Decimal128(precision, scale) => {
347 json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 128})
348 }
349 DataType::Decimal256(precision, scale) => {
350 json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 256})
351 }
352 DataType::Map(_, keys_sorted) => {
353 json!({"name": "map", "keysSorted": keys_sorted})
354 }
355 DataType::RunEndEncoded(_, _) => todo!(),
356 }
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;

    /// Parses `text` as JSON and converts it with `data_type_from_json`,
    /// panicking if either step fails.
    fn parse_data_type(text: &str) -> DataType {
        let parsed: Value = serde_json::from_str(text).unwrap();
        data_type_from_json(&parsed).unwrap()
    }

    #[test]
    fn parse_utf8_from_json() {
        let actual = parse_data_type(r#"{"name":"utf8"}"#);
        assert_eq!(DataType::Utf8, actual);
    }

    #[test]
    fn parse_int32_from_json() {
        let actual = parse_data_type(r#"{"name": "int", "isSigned": true, "bitWidth": 32}"#);
        assert_eq!(DataType::Int32, actual);
    }
}