1use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit, UnionMode};
19use arrow::error::{ArrowError, Result};
20use std::sync::Arc;
21
22pub fn data_type_from_json(json: &serde_json::Value) -> Result<DataType> {
24 use serde_json::Value;
25 let default_field = Arc::new(Field::new("", DataType::Boolean, true));
26 match *json {
27 Value::Object(ref map) => match map.get("name") {
28 Some(s) if s == "null" => Ok(DataType::Null),
29 Some(s) if s == "bool" => Ok(DataType::Boolean),
30 Some(s) if s == "binary" => Ok(DataType::Binary),
31 Some(s) if s == "largebinary" => Ok(DataType::LargeBinary),
32 Some(s) if s == "binaryview" => Ok(DataType::BinaryView),
33 Some(s) if s == "utf8view" => Ok(DataType::Utf8View),
34 Some(s) if s == "utf8" => Ok(DataType::Utf8),
35 Some(s) if s == "largeutf8" => Ok(DataType::LargeUtf8),
36 Some(s) if s == "fixedsizebinary" => {
37 if let Some(Value::Number(size)) = map.get("byteWidth") {
39 Ok(DataType::FixedSizeBinary(size.as_i64().unwrap() as i32))
40 } else {
41 Err(ArrowError::ParseError(
42 "Expecting a byteWidth for fixedsizebinary".to_string(),
43 ))
44 }
45 }
46 Some(s) if s == "decimal" => {
47 let precision = match map.get("precision") {
49 Some(p) => Ok(p.as_u64().unwrap().try_into().unwrap()),
50 None => Err(ArrowError::ParseError(
51 "Expecting a precision for decimal".to_string(),
52 )),
53 }?;
54 let scale = match map.get("scale") {
55 Some(s) => Ok(s.as_u64().unwrap().try_into().unwrap()),
56 _ => Err(ArrowError::ParseError(
57 "Expecting a scale for decimal".to_string(),
58 )),
59 }?;
60 let bit_width: usize = match map.get("bitWidth") {
61 Some(b) => b.as_u64().unwrap() as usize,
62 _ => 128, };
64
65 match bit_width {
66 32 => Ok(DataType::Decimal32(precision, scale)),
67 64 => Ok(DataType::Decimal64(precision, scale)),
68 128 => Ok(DataType::Decimal128(precision, scale)),
69 256 => Ok(DataType::Decimal256(precision, scale)),
70 _ => Err(ArrowError::ParseError(
71 "Decimal bit_width invalid".to_string(),
72 )),
73 }
74 }
75 Some(s) if s == "floatingpoint" => match map.get("precision") {
76 Some(p) if p == "HALF" => Ok(DataType::Float16),
77 Some(p) if p == "SINGLE" => Ok(DataType::Float32),
78 Some(p) if p == "DOUBLE" => Ok(DataType::Float64),
79 _ => Err(ArrowError::ParseError(
80 "floatingpoint precision missing or invalid".to_string(),
81 )),
82 },
83 Some(s) if s == "timestamp" => {
84 let unit = match map.get("unit") {
85 Some(p) if p == "SECOND" => Ok(TimeUnit::Second),
86 Some(p) if p == "MILLISECOND" => Ok(TimeUnit::Millisecond),
87 Some(p) if p == "MICROSECOND" => Ok(TimeUnit::Microsecond),
88 Some(p) if p == "NANOSECOND" => Ok(TimeUnit::Nanosecond),
89 _ => Err(ArrowError::ParseError(
90 "timestamp unit missing or invalid".to_string(),
91 )),
92 };
93 let tz = match map.get("timezone") {
94 None => Ok(None),
95 Some(Value::String(tz)) => Ok(Some(tz.as_str().into())),
96 _ => Err(ArrowError::ParseError(
97 "timezone must be a string".to_string(),
98 )),
99 };
100 Ok(DataType::Timestamp(unit?, tz?))
101 }
102 Some(s) if s == "date" => match map.get("unit") {
103 Some(p) if p == "DAY" => Ok(DataType::Date32),
104 Some(p) if p == "MILLISECOND" => Ok(DataType::Date64),
105 _ => Err(ArrowError::ParseError(
106 "date unit missing or invalid".to_string(),
107 )),
108 },
109 Some(s) if s == "time" => {
110 let unit = match map.get("unit") {
111 Some(p) if p == "SECOND" => Ok(TimeUnit::Second),
112 Some(p) if p == "MILLISECOND" => Ok(TimeUnit::Millisecond),
113 Some(p) if p == "MICROSECOND" => Ok(TimeUnit::Microsecond),
114 Some(p) if p == "NANOSECOND" => Ok(TimeUnit::Nanosecond),
115 _ => Err(ArrowError::ParseError(
116 "time unit missing or invalid".to_string(),
117 )),
118 };
119 match map.get("bitWidth") {
120 Some(p) if p == 32 => Ok(DataType::Time32(unit?)),
121 Some(p) if p == 64 => Ok(DataType::Time64(unit?)),
122 _ => Err(ArrowError::ParseError(
123 "time bitWidth missing or invalid".to_string(),
124 )),
125 }
126 }
127 Some(s) if s == "duration" => match map.get("unit") {
128 Some(p) if p == "SECOND" => Ok(DataType::Duration(TimeUnit::Second)),
129 Some(p) if p == "MILLISECOND" => Ok(DataType::Duration(TimeUnit::Millisecond)),
130 Some(p) if p == "MICROSECOND" => Ok(DataType::Duration(TimeUnit::Microsecond)),
131 Some(p) if p == "NANOSECOND" => Ok(DataType::Duration(TimeUnit::Nanosecond)),
132 _ => Err(ArrowError::ParseError(
133 "time unit missing or invalid".to_string(),
134 )),
135 },
136 Some(s) if s == "interval" => match map.get("unit") {
137 Some(p) if p == "DAY_TIME" => Ok(DataType::Interval(IntervalUnit::DayTime)),
138 Some(p) if p == "YEAR_MONTH" => Ok(DataType::Interval(IntervalUnit::YearMonth)),
139 Some(p) if p == "MONTH_DAY_NANO" => {
140 Ok(DataType::Interval(IntervalUnit::MonthDayNano))
141 }
142 _ => Err(ArrowError::ParseError(
143 "interval unit missing or invalid".to_string(),
144 )),
145 },
146 Some(s) if s == "int" => match map.get("isSigned") {
147 Some(&Value::Bool(true)) => match map.get("bitWidth") {
148 Some(Value::Number(n)) => match n.as_u64() {
149 Some(8) => Ok(DataType::Int8),
150 Some(16) => Ok(DataType::Int16),
151 Some(32) => Ok(DataType::Int32),
152 Some(64) => Ok(DataType::Int64),
153 _ => Err(ArrowError::ParseError(
154 "int bitWidth missing or invalid".to_string(),
155 )),
156 },
157 _ => Err(ArrowError::ParseError(
158 "int bitWidth missing or invalid".to_string(),
159 )),
160 },
161 Some(&Value::Bool(false)) => match map.get("bitWidth") {
162 Some(Value::Number(n)) => match n.as_u64() {
163 Some(8) => Ok(DataType::UInt8),
164 Some(16) => Ok(DataType::UInt16),
165 Some(32) => Ok(DataType::UInt32),
166 Some(64) => Ok(DataType::UInt64),
167 _ => Err(ArrowError::ParseError(
168 "int bitWidth missing or invalid".to_string(),
169 )),
170 },
171 _ => Err(ArrowError::ParseError(
172 "int bitWidth missing or invalid".to_string(),
173 )),
174 },
175 _ => Err(ArrowError::ParseError(
176 "int signed missing or invalid".to_string(),
177 )),
178 },
179 Some(s) if s == "list" => {
180 Ok(DataType::List(default_field))
182 }
183 Some(s) if s == "largelist" => {
184 Ok(DataType::LargeList(default_field))
186 }
187 Some(s) if s == "listview" => {
188 Ok(DataType::ListView(default_field))
190 }
191 Some(s) if s == "largelistview" => {
192 Ok(DataType::LargeListView(default_field))
194 }
195 Some(s) if s == "fixedsizelist" => {
196 if let Some(Value::Number(size)) = map.get("listSize") {
198 Ok(DataType::FixedSizeList(
199 default_field,
200 size.as_i64().unwrap() as i32,
201 ))
202 } else {
203 Err(ArrowError::ParseError(
204 "Expecting a listSize for fixedsizelist".to_string(),
205 ))
206 }
207 }
208 Some(s) if s == "struct" => {
209 Ok(DataType::Struct(Fields::empty()))
211 }
212 Some(s) if s == "runendencoded" => {
213 Ok(DataType::RunEndEncoded(
215 Arc::new(Field::new("run_ends", DataType::Int32, false)),
216 default_field,
217 ))
218 }
219 Some(s) if s == "map" => {
220 if let Some(Value::Bool(keys_sorted)) = map.get("keysSorted") {
221 Ok(DataType::Map(default_field, *keys_sorted))
223 } else {
224 Err(ArrowError::ParseError(
225 "Expecting a keysSorted for map".to_string(),
226 ))
227 }
228 }
229 Some(s) if s == "union" => {
230 if let Some(Value::String(mode)) = map.get("mode") {
231 let union_mode = if mode == "SPARSE" {
232 UnionMode::Sparse
233 } else if mode == "DENSE" {
234 UnionMode::Dense
235 } else {
236 return Err(ArrowError::ParseError(format!(
237 "Unknown union mode {mode:?} for union"
238 )));
239 };
240 if let Some(values) = map.get("typeIds") {
241 let values = values.as_array().unwrap();
242 let fields = values
243 .iter()
244 .map(|t| (t.as_i64().unwrap() as i8, default_field.clone()))
245 .collect();
246
247 Ok(DataType::Union(fields, union_mode))
248 } else {
249 Err(ArrowError::ParseError(
250 "Expecting a typeIds for union ".to_string(),
251 ))
252 }
253 } else {
254 Err(ArrowError::ParseError(
255 "Expecting a mode for union".to_string(),
256 ))
257 }
258 }
259 Some(other) => Err(ArrowError::ParseError(format!(
260 "invalid or unsupported type name: {other} in {json:?}"
261 ))),
262 None => Err(ArrowError::ParseError("type name missing".to_string())),
263 },
264 _ => Err(ArrowError::ParseError(
265 "invalid json value type".to_string(),
266 )),
267 }
268}
269
270pub fn data_type_to_json(data_type: &DataType) -> serde_json::Value {
272 use serde_json::json;
273 match data_type {
274 DataType::Null => json!({"name": "null"}),
275 DataType::Boolean => json!({"name": "bool"}),
276 DataType::Int8 => json!({"name": "int", "bitWidth": 8, "isSigned": true}),
277 DataType::Int16 => json!({"name": "int", "bitWidth": 16, "isSigned": true}),
278 DataType::Int32 => json!({"name": "int", "bitWidth": 32, "isSigned": true}),
279 DataType::Int64 => json!({"name": "int", "bitWidth": 64, "isSigned": true}),
280 DataType::UInt8 => json!({"name": "int", "bitWidth": 8, "isSigned": false}),
281 DataType::UInt16 => json!({"name": "int", "bitWidth": 16, "isSigned": false}),
282 DataType::UInt32 => json!({"name": "int", "bitWidth": 32, "isSigned": false}),
283 DataType::UInt64 => json!({"name": "int", "bitWidth": 64, "isSigned": false}),
284 DataType::Float16 => json!({"name": "floatingpoint", "precision": "HALF"}),
285 DataType::Float32 => json!({"name": "floatingpoint", "precision": "SINGLE"}),
286 DataType::Float64 => json!({"name": "floatingpoint", "precision": "DOUBLE"}),
287 DataType::Utf8 => json!({"name": "utf8"}),
288 DataType::LargeUtf8 => json!({"name": "largeutf8"}),
289 DataType::Binary => json!({"name": "binary"}),
290 DataType::LargeBinary => json!({"name": "largebinary"}),
291 DataType::BinaryView => json!({"name": "binaryview"}),
292 DataType::Utf8View => json!({"name": "utf8view"}),
293 DataType::FixedSizeBinary(byte_width) => {
294 json!({"name": "fixedsizebinary", "byteWidth": byte_width})
295 }
296 DataType::Struct(_) => json!({"name": "struct"}),
297 DataType::Union(_, _) => json!({"name": "union"}),
298 DataType::List(_) => json!({ "name": "list"}),
299 DataType::LargeList(_) => json!({ "name": "largelist"}),
300 DataType::ListView(_) => json!({ "name": "listview"}),
301 DataType::LargeListView(_) => json!({ "name": "largelistview"}),
302 DataType::FixedSizeList(_, length) => {
303 json!({"name":"fixedsizelist", "listSize": length})
304 }
305 DataType::Time32(unit) => {
306 json!({"name": "time", "bitWidth": 32, "unit": match unit {
307 TimeUnit::Second => "SECOND",
308 TimeUnit::Millisecond => "MILLISECOND",
309 TimeUnit::Microsecond => "MICROSECOND",
310 TimeUnit::Nanosecond => "NANOSECOND",
311 }})
312 }
313 DataType::Time64(unit) => {
314 json!({"name": "time", "bitWidth": 64, "unit": match unit {
315 TimeUnit::Second => "SECOND",
316 TimeUnit::Millisecond => "MILLISECOND",
317 TimeUnit::Microsecond => "MICROSECOND",
318 TimeUnit::Nanosecond => "NANOSECOND",
319 }})
320 }
321 DataType::Date32 => {
322 json!({"name": "date", "unit": "DAY"})
323 }
324 DataType::Date64 => {
325 json!({"name": "date", "unit": "MILLISECOND"})
326 }
327 DataType::Timestamp(unit, None) => {
328 json!({"name": "timestamp", "unit": match unit {
329 TimeUnit::Second => "SECOND",
330 TimeUnit::Millisecond => "MILLISECOND",
331 TimeUnit::Microsecond => "MICROSECOND",
332 TimeUnit::Nanosecond => "NANOSECOND",
333 }})
334 }
335 DataType::Timestamp(unit, Some(tz)) => {
336 json!({"name": "timestamp", "unit": match unit {
337 TimeUnit::Second => "SECOND",
338 TimeUnit::Millisecond => "MILLISECOND",
339 TimeUnit::Microsecond => "MICROSECOND",
340 TimeUnit::Nanosecond => "NANOSECOND",
341 }, "timezone": tz})
342 }
343 DataType::Interval(unit) => json!({"name": "interval", "unit": match unit {
344 IntervalUnit::YearMonth => "YEAR_MONTH",
345 IntervalUnit::DayTime => "DAY_TIME",
346 IntervalUnit::MonthDayNano => "MONTH_DAY_NANO",
347 }}),
348 DataType::Duration(unit) => json!({"name": "duration", "unit": match unit {
349 TimeUnit::Second => "SECOND",
350 TimeUnit::Millisecond => "MILLISECOND",
351 TimeUnit::Microsecond => "MICROSECOND",
352 TimeUnit::Nanosecond => "NANOSECOND",
353 }}),
354 DataType::Dictionary(_, _) => json!({ "name": "dictionary"}),
355 DataType::Decimal32(precision, scale) => {
356 json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 32})
357 }
358 DataType::Decimal64(precision, scale) => {
359 json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 64})
360 }
361 DataType::Decimal128(precision, scale) => {
362 json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 128})
363 }
364 DataType::Decimal256(precision, scale) => {
365 json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 256})
366 }
367 DataType::Map(_, keys_sorted) => {
368 json!({"name": "map", "keysSorted": keys_sorted})
369 }
370 DataType::RunEndEncoded(_, _) => json!({"name": "runendencoded"}),
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;

    /// Parses `json` text into a JSON value and converts it to a [`DataType`].
    fn parse(json: &str) -> Result<DataType> {
        let value: Value = serde_json::from_str(json).unwrap();
        data_type_from_json(&value)
    }

    #[test]
    fn parse_utf8_from_json() {
        assert_eq!(DataType::Utf8, parse(r#"{"name":"utf8"}"#).unwrap());
    }

    #[test]
    fn parse_int32_from_json() {
        let dt = parse(r#"{"name": "int", "isSigned": true, "bitWidth": 32}"#).unwrap();
        assert_eq!(DataType::Int32, dt);
    }

    #[test]
    fn round_trip_non_nested_types() {
        let types = vec![
            DataType::Null,
            DataType::Boolean,
            DataType::Int8,
            DataType::Int16,
            DataType::Int32,
            DataType::Int64,
            DataType::UInt8,
            DataType::UInt16,
            DataType::UInt32,
            DataType::UInt64,
            DataType::Float16,
            DataType::Float32,
            DataType::Float64,
            DataType::Utf8,
            DataType::LargeUtf8,
            DataType::Utf8View,
            DataType::Binary,
            DataType::LargeBinary,
            DataType::BinaryView,
            DataType::FixedSizeBinary(16),
            DataType::Date32,
            DataType::Date64,
            DataType::Time32(TimeUnit::Second),
            DataType::Time32(TimeUnit::Millisecond),
            DataType::Time64(TimeUnit::Microsecond),
            DataType::Time64(TimeUnit::Nanosecond),
            DataType::Timestamp(TimeUnit::Millisecond, None),
            DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
            DataType::Duration(TimeUnit::Second),
            DataType::Duration(TimeUnit::Nanosecond),
            DataType::Interval(IntervalUnit::YearMonth),
            DataType::Interval(IntervalUnit::DayTime),
            DataType::Interval(IntervalUnit::MonthDayNano),
            DataType::Decimal32(9, 2),
            DataType::Decimal64(18, 3),
            DataType::Decimal128(38, 10),
            DataType::Decimal256(76, 5),
        ];
        for data_type in types {
            let json = data_type_to_json(&data_type);
            let parsed = data_type_from_json(&json).unwrap();
            assert_eq!(data_type, parsed, "round trip through {json}");
        }
    }

    #[test]
    fn rejects_malformed_type_objects() {
        assert!(parse("[]").is_err());
        assert!(parse(r#"{"bitWidth": 32}"#).is_err());
        assert!(parse(r#"{"name": "no-such-type"}"#).is_err());
        assert!(parse(r#"{"name": "int", "isSigned": true, "bitWidth": 12}"#).is_err());
        assert!(parse(r#"{"name": "int", "bitWidth": 32}"#).is_err());
        assert!(parse(r#"{"name": "timestamp", "unit": "FORTNIGHT"}"#).is_err());
        assert!(parse(r#"{"name": "time", "unit": "SECOND", "bitWidth": 16}"#).is_err());
        assert!(parse(r#"{"name": "union", "mode": "OTHER", "typeIds": []}"#).is_err());
    }
}