1use arrow_schema::{ArrowError, DataType, Field, Fields, Schema};
19use indexmap::map::IndexMap as HashMap;
20use indexmap::set::IndexSet as HashSet;
21use serde_json::Value;
22use std::borrow::Borrow;
23use std::io::{BufRead, Seek};
24use std::sync::Arc;
25
26#[derive(Debug, Clone)]
27enum InferredType {
28 Scalar(HashSet<DataType>),
29 Array(Box<InferredType>),
30 Object(HashMap<String, InferredType>),
31 Any,
32}
33
34impl InferredType {
35 fn merge(&mut self, other: InferredType) -> Result<(), ArrowError> {
36 match (self, other) {
37 (InferredType::Array(s), InferredType::Array(o)) => {
38 s.merge(*o)?;
39 }
40 (InferredType::Scalar(self_hs), InferredType::Scalar(other_hs)) => {
41 other_hs.into_iter().for_each(|v| {
42 self_hs.insert(v);
43 });
44 }
45 (InferredType::Object(self_map), InferredType::Object(other_map)) => {
46 for (k, v) in other_map {
47 self_map.entry(k).or_insert(InferredType::Any).merge(v)?;
48 }
49 }
50 (s @ InferredType::Any, v) => {
51 *s = v;
52 }
53 (_, InferredType::Any) => {}
54 (InferredType::Array(self_inner_type), other_scalar @ InferredType::Scalar(_)) => {
56 self_inner_type.merge(other_scalar)?;
57 }
58 (s @ InferredType::Scalar(_), InferredType::Array(mut other_inner_type)) => {
59 other_inner_type.merge(s.clone())?;
60 *s = InferredType::Array(other_inner_type);
61 }
62 (s, o) => {
64 return Err(ArrowError::JsonError(format!(
65 "Incompatible type found during schema inference: {s:?} v.s. {o:?}",
66 )));
67 }
68 }
69
70 Ok(())
71 }
72
73 fn is_none_or_any(ty: Option<&Self>) -> bool {
74 matches!(ty, Some(Self::Any) | None)
75 }
76}
77
78fn list_type_of(ty: DataType) -> DataType {
80 DataType::List(Arc::new(Field::new_list_field(ty, true)))
81}
82
83fn coerce_data_type(dt: Vec<&DataType>) -> DataType {
89 let mut dt_iter = dt.into_iter().cloned();
90 let dt_init = dt_iter.next().unwrap_or(DataType::Utf8);
91
92 dt_iter.fold(dt_init, |l, r| match (l, r) {
93 (DataType::Null, o) | (o, DataType::Null) => o,
94 (DataType::Boolean, DataType::Boolean) => DataType::Boolean,
95 (DataType::Int64, DataType::Int64) => DataType::Int64,
96 (DataType::Float64, DataType::Float64)
97 | (DataType::Float64, DataType::Int64)
98 | (DataType::Int64, DataType::Float64) => DataType::Float64,
99 (DataType::List(l), DataType::List(r)) => {
100 list_type_of(coerce_data_type(vec![l.data_type(), r.data_type()]))
101 }
102 (DataType::List(e), not_list) | (not_list, DataType::List(e)) => {
104 list_type_of(coerce_data_type(vec![e.data_type(), ¬_list]))
105 }
106 _ => DataType::Utf8,
107 })
108}
109
110fn generate_datatype(t: &InferredType) -> Result<DataType, ArrowError> {
111 Ok(match t {
112 InferredType::Scalar(hs) => coerce_data_type(hs.iter().collect()),
113 InferredType::Object(spec) => DataType::Struct(generate_fields(spec)?),
114 InferredType::Array(ele_type) => list_type_of(generate_datatype(ele_type)?),
115 InferredType::Any => DataType::Null,
116 })
117}
118
119fn generate_fields(spec: &HashMap<String, InferredType>) -> Result<Fields, ArrowError> {
120 spec.iter()
121 .map(|(k, types)| Ok(Field::new(k, generate_datatype(types)?, true)))
122 .collect()
123}
124
125fn generate_schema(spec: HashMap<String, InferredType>) -> Result<Schema, ArrowError> {
127 Ok(Schema::new(generate_fields(&spec)?))
128}
129
/// Iterator over JSON values read from a line-delimited text source.
///
/// Each non-blank line of the reader is parsed as one JSON value. When
/// `max_read_records` is `Some(n)`, iteration stops after `n` records.
#[derive(Debug)]
pub struct ValueIter<R: BufRead> {
    /// Source of newline-delimited JSON text.
    reader: R,
    /// Optional cap on the number of records yielded.
    max_read_records: Option<usize>,
    /// Number of records yielded so far (blank lines are not counted).
    record_count: usize,
    /// Buffer reused for every line read.
    line_buf: String,
}

impl<R: BufRead> ValueIter<R> {
    /// Creates an iterator over `reader` that yields at most `max_read_records`
    /// values when that limit is `Some`.
    pub fn new(reader: R, max_read_records: Option<usize>) -> Self {
        let line_buf = String::new();
        let record_count = 0;
        Self {
            line_buf,
            record_count,
            max_read_records,
            reader,
        }
    }
}
166
167impl<R: BufRead> Iterator for ValueIter<R> {
168 type Item = Result<Value, ArrowError>;
169
170 fn next(&mut self) -> Option<Self::Item> {
171 if let Some(max) = self.max_read_records {
172 if self.record_count >= max {
173 return None;
174 }
175 }
176
177 loop {
178 self.line_buf.truncate(0);
179 match self.reader.read_line(&mut self.line_buf) {
180 Ok(0) => {
181 return None;
183 }
184 Err(e) => {
185 return Some(Err(ArrowError::JsonError(format!(
186 "Failed to read JSON record: {e}"
187 ))));
188 }
189 _ => {
190 let trimmed_s = self.line_buf.trim();
191 if trimmed_s.is_empty() {
192 continue;
194 }
195
196 self.record_count += 1;
197 return Some(
198 serde_json::from_str(trimmed_s)
199 .map_err(|e| ArrowError::JsonError(format!("Not valid JSON: {e}"))),
200 );
201 }
202 }
203 }
204 }
205}
206
207pub fn infer_json_schema_from_seekable<R: BufRead + Seek>(
232 mut reader: R,
233 max_read_records: Option<usize>,
234) -> Result<(Schema, usize), ArrowError> {
235 let schema = infer_json_schema(&mut reader, max_read_records);
236 reader.rewind()?;
238
239 schema
240}
241
242pub fn infer_json_schema<R: BufRead>(
280 reader: R,
281 max_read_records: Option<usize>,
282) -> Result<(Schema, usize), ArrowError> {
283 let mut values = ValueIter::new(reader, max_read_records);
284 let schema = infer_json_schema_from_iterator(&mut values)?;
285 Ok((schema, values.record_count))
286}
287
288fn set_object_scalar_field_type(
289 field_types: &mut HashMap<String, InferredType>,
290 key: &str,
291 ftype: DataType,
292) -> Result<(), ArrowError> {
293 if InferredType::is_none_or_any(field_types.get(key)) {
294 field_types.insert(key.to_string(), InferredType::Scalar(HashSet::new()));
295 }
296
297 match field_types.get_mut(key).unwrap() {
298 InferredType::Scalar(hs) => {
299 hs.insert(ftype);
300 Ok(())
301 }
302 scalar_array @ InferredType::Array(_) => {
305 let mut hs = HashSet::new();
306 hs.insert(ftype);
307 scalar_array.merge(InferredType::Scalar(hs))?;
308 Ok(())
309 }
310 t => Err(ArrowError::JsonError(format!(
311 "Expected scalar or scalar array JSON type, found: {t:?}",
312 ))),
313 }
314}
315
316fn infer_scalar_array_type(array: &[Value]) -> Result<InferredType, ArrowError> {
317 let mut hs = HashSet::new();
318
319 for v in array {
320 match v {
321 Value::Null => {}
322 Value::Number(n) => {
323 if n.is_i64() {
324 hs.insert(DataType::Int64);
325 } else {
326 hs.insert(DataType::Float64);
327 }
328 }
329 Value::Bool(_) => {
330 hs.insert(DataType::Boolean);
331 }
332 Value::String(_) => {
333 hs.insert(DataType::Utf8);
334 }
335 Value::Array(_) | Value::Object(_) => {
336 return Err(ArrowError::JsonError(format!(
337 "Expected scalar value for scalar array, got: {v:?}"
338 )));
339 }
340 }
341 }
342
343 Ok(InferredType::Scalar(hs))
344}
345
346fn infer_nested_array_type(array: &[Value]) -> Result<InferredType, ArrowError> {
347 let mut inner_ele_type = InferredType::Any;
348
349 for v in array {
350 match v {
351 Value::Array(inner_array) => {
352 inner_ele_type.merge(infer_array_element_type(inner_array)?)?;
353 }
354 x => {
355 return Err(ArrowError::JsonError(format!(
356 "Got non array element in nested array: {x:?}"
357 )));
358 }
359 }
360 }
361
362 Ok(InferredType::Array(Box::new(inner_ele_type)))
363}
364
365fn infer_struct_array_type(array: &[Value]) -> Result<InferredType, ArrowError> {
366 let mut field_types = HashMap::new();
367
368 for v in array {
369 match v {
370 Value::Object(map) => {
371 collect_field_types_from_object(&mut field_types, map)?;
372 }
373 _ => {
374 return Err(ArrowError::JsonError(format!(
375 "Expected struct value for struct array, got: {v:?}"
376 )));
377 }
378 }
379 }
380
381 Ok(InferredType::Object(field_types))
382}
383
384fn infer_array_element_type(array: &[Value]) -> Result<InferredType, ArrowError> {
385 match array.iter().take(1).next() {
386 None => Ok(InferredType::Any), Some(a) => match a {
388 Value::Array(_) => infer_nested_array_type(array),
389 Value::Object(_) => infer_struct_array_type(array),
390 _ => infer_scalar_array_type(array),
391 },
392 }
393}
394
395fn collect_field_types_from_object(
396 field_types: &mut HashMap<String, InferredType>,
397 map: &serde_json::map::Map<String, Value>,
398) -> Result<(), ArrowError> {
399 for (k, v) in map {
400 match v {
401 Value::Array(array) => {
402 let ele_type = infer_array_element_type(array)?;
403
404 if InferredType::is_none_or_any(field_types.get(k)) {
405 match ele_type {
406 InferredType::Scalar(_) => {
407 field_types.insert(
408 k.to_string(),
409 InferredType::Array(Box::new(InferredType::Scalar(HashSet::new()))),
410 );
411 }
412 InferredType::Object(_) => {
413 field_types.insert(
414 k.to_string(),
415 InferredType::Array(Box::new(InferredType::Object(HashMap::new()))),
416 );
417 }
418 InferredType::Any | InferredType::Array(_) => {
419 field_types.insert(
422 k.to_string(),
423 InferredType::Array(Box::new(InferredType::Any)),
424 );
425 }
426 }
427 }
428
429 match field_types.get_mut(k).unwrap() {
430 InferredType::Array(inner_type) => {
431 inner_type.merge(ele_type)?;
432 }
433 field_type @ InferredType::Scalar(_) => {
436 field_type.merge(ele_type)?;
437 *field_type = InferredType::Array(Box::new(field_type.clone()));
438 }
439 t => {
440 return Err(ArrowError::JsonError(format!(
441 "Expected array json type, found: {t:?}",
442 )));
443 }
444 }
445 }
446 Value::Bool(_) => {
447 set_object_scalar_field_type(field_types, k, DataType::Boolean)?;
448 }
449 Value::Null => {
450 if !field_types.contains_key(k) {
453 field_types.insert(k.to_string(), InferredType::Any);
454 }
455 }
456 Value::Number(n) => {
457 if n.is_i64() {
458 set_object_scalar_field_type(field_types, k, DataType::Int64)?;
459 } else {
460 set_object_scalar_field_type(field_types, k, DataType::Float64)?;
461 }
462 }
463 Value::String(_) => {
464 set_object_scalar_field_type(field_types, k, DataType::Utf8)?;
465 }
466 Value::Object(inner_map) => {
467 if let InferredType::Any = field_types.get(k).unwrap_or(&InferredType::Any) {
468 field_types.insert(k.to_string(), InferredType::Object(HashMap::new()));
469 }
470 match field_types.get_mut(k).unwrap() {
471 InferredType::Object(inner_field_types) => {
472 collect_field_types_from_object(inner_field_types, inner_map)?;
473 }
474 t => {
475 return Err(ArrowError::JsonError(format!(
476 "Expected object json type, found: {t:?}",
477 )));
478 }
479 }
480 }
481 }
482 }
483
484 Ok(())
485}
486
487pub fn infer_json_schema_from_iterator<I, V>(value_iter: I) -> Result<Schema, ArrowError>
501where
502 I: Iterator<Item = Result<V, ArrowError>>,
503 V: Borrow<Value>,
504{
505 let mut field_types: HashMap<String, InferredType> = HashMap::new();
506
507 for record in value_iter {
508 match record?.borrow() {
509 Value::Object(map) => {
510 collect_field_types_from_object(&mut field_types, map)?;
511 }
512 value => {
513 return Err(ArrowError::JsonError(format!(
514 "Expected JSON record to be an object, found {value:?}"
515 )));
516 }
517 };
518 }
519
520 generate_schema(field_types)
521}
522
// Unit tests for schema inference. Some tests read fixtures under `test/data/`
// via relative paths — NOTE(review): assumes `cargo test` runs from the crate root.
#[cfg(test)]
mod tests {
    use super::*;
    use flate2::read::GzDecoder;
    use std::fs::File;
    use std::io::{BufReader, Cursor};

    // The same fixture read plainly (seekable entry point) and gzip-compressed
    // (plain `infer_json_schema`) must yield the same schema and record count.
    #[test]
    fn test_json_infer_schema() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", list_type_of(DataType::Float64), true),
            Field::new("c", list_type_of(DataType::Boolean), true),
            Field::new("d", list_type_of(DataType::Utf8), true),
        ]);

        let mut reader = BufReader::new(File::open("test/data/mixed_arrays.json").unwrap());
        let (inferred_schema, n_rows) = infer_json_schema_from_seekable(&mut reader, None).unwrap();

        assert_eq!(inferred_schema, schema);
        assert_eq!(n_rows, 4);

        let file = File::open("test/data/mixed_arrays.json.gz").unwrap();
        let mut reader = BufReader::new(GzDecoder::new(&file));
        let (inferred_schema, n_rows) = infer_json_schema(&mut reader, None).unwrap();

        assert_eq!(inferred_schema, schema);
        assert_eq!(n_rows, 4);
    }

    // `max_read_records` caps the number of records read. The second call starts
    // from the beginning again because the seekable variant rewinds the reader.
    #[test]
    fn test_row_limit() {
        let mut reader = BufReader::new(File::open("test/data/basic.json").unwrap());

        let (_, n_rows) = infer_json_schema_from_seekable(&mut reader, None).unwrap();
        assert_eq!(n_rows, 12);

        let (_, n_rows) = infer_json_schema_from_seekable(&mut reader, Some(5)).unwrap();
        assert_eq!(n_rows, 5);
    }

    // Nested structs merge across records: a null `b` does not erase the struct
    // type seen elsewhere, and a key present in only one record (`c3`) still appears.
    #[test]
    fn test_json_infer_schema_nested_structs() {
        let schema = Schema::new(vec![
            Field::new(
                "c1",
                DataType::Struct(Fields::from(vec![
                    Field::new("a", DataType::Boolean, true),
                    Field::new(
                        "b",
                        DataType::Struct(vec![Field::new("c", DataType::Utf8, true)].into()),
                        true,
                    ),
                ])),
                true,
            ),
            Field::new("c2", DataType::Int64, true),
            Field::new("c3", DataType::Utf8, true),
        ]);

        let inferred_schema = infer_json_schema_from_iterator(
            vec![
                Ok(serde_json::json!({"c1": {"a": true, "b": {"c": "text"}}, "c2": 1})),
                Ok(serde_json::json!({"c1": {"a": false, "b": null}, "c2": 0})),
                Ok(serde_json::json!({"c1": {"a": true, "b": {"c": "text"}}, "c3": "ok"})),
            ]
            .into_iter(),
        )
        .unwrap();

        assert_eq!(inferred_schema, schema);
    }

    // A list of structs gets the union of keys seen across all of its elements;
    // integer then float values (`c2`) widen to Float64.
    #[test]
    fn test_json_infer_schema_struct_in_list() {
        let schema = Schema::new(vec![
            Field::new(
                "c1",
                list_type_of(DataType::Struct(Fields::from(vec![
                    Field::new("a", DataType::Utf8, true),
                    Field::new("b", DataType::Int64, true),
                    Field::new("c", DataType::Boolean, true),
                ]))),
                true,
            ),
            Field::new("c2", DataType::Float64, true),
            Field::new(
                "c3",
                // Only empty arrays were seen, so the element type stays Null.
                list_type_of(DataType::Null),
                true,
            ),
        ]);

        let inferred_schema = infer_json_schema_from_iterator(
            vec![
                Ok(serde_json::json!({
                    "c1": [{"a": "foo", "b": 100}], "c2": 1, "c3": [],
                })),
                Ok(serde_json::json!({
                    "c1": [{"a": "bar", "b": 2}, {"a": "foo", "c": true}], "c2": 0, "c3": [],
                })),
                Ok(serde_json::json!({"c1": [], "c2": 0.5, "c3": []})),
            ]
            .into_iter(),
        )
        .unwrap();

        assert_eq!(inferred_schema, schema);
    }

    // An initial empty list does not pin the nesting depth; later records with
    // lists of lists determine the final type.
    #[test]
    fn test_json_infer_schema_nested_list() {
        let schema = Schema::new(vec![
            Field::new("c1", list_type_of(list_type_of(DataType::Utf8)), true),
            Field::new("c2", DataType::Float64, true),
        ]);

        let inferred_schema = infer_json_schema_from_iterator(
            vec![
                Ok(serde_json::json!({
                    "c1": [],
                    "c2": 12,
                })),
                Ok(serde_json::json!({
                    "c1": [["a", "b"], ["c"]],
                })),
                Ok(serde_json::json!({
                    "c1": [["foo"]],
                    "c2": 0.11,
                })),
            ]
            .into_iter(),
        )
        .unwrap();

        assert_eq!(inferred_schema, schema);
    }

    // Integers outside the i64 range fail `is_i64` and are inferred as Float64.
    #[test]
    fn test_infer_json_schema_bigger_than_i64_max() {
        let bigger_than_i64_max = (i64::MAX as i128) + 1;
        let smaller_than_i64_min = (i64::MIN as i128) - 1;
        let json = format!(
            "{{ \"bigger_than_i64_max\": {bigger_than_i64_max}, \"smaller_than_i64_min\": {smaller_than_i64_min} }}",
        );
        let mut buf_reader = BufReader::new(json.as_bytes());
        let (inferred_schema, _) = infer_json_schema(&mut buf_reader, Some(1)).unwrap();
        let fields = inferred_schema.fields();

        let (_, big_field) = fields.find("bigger_than_i64_max").unwrap();
        assert_eq!(big_field.data_type(), &DataType::Float64);
        let (_, small_field) = fields.find("smaller_than_i64_min").unwrap();
        assert_eq!(small_field.data_type(), &DataType::Float64);
    }

    // A scalar mixed with a list is folded into the list's element type.
    #[test]
    fn test_coercion_scalar_and_list() {
        assert_eq!(
            list_type_of(DataType::Float64),
            coerce_data_type(vec![&DataType::Float64, &list_type_of(DataType::Float64)])
        );
        assert_eq!(
            list_type_of(DataType::Float64),
            coerce_data_type(vec![&DataType::Float64, &list_type_of(DataType::Int64)])
        );
        assert_eq!(
            list_type_of(DataType::Int64),
            coerce_data_type(vec![&DataType::Int64, &list_type_of(DataType::Int64)])
        );
        // Boolean and Float64 have no common type, so the element falls back to Utf8.
        assert_eq!(
            list_type_of(DataType::Utf8),
            coerce_data_type(vec![&DataType::Boolean, &list_type_of(DataType::Float64)])
        );
    }

    // Malformed input surfaces the serde_json parse error wrapped in a JsonError.
    #[test]
    fn test_invalid_json_infer_schema() {
        let re = infer_json_schema_from_seekable(Cursor::new(b"}"), None);
        assert_eq!(
            re.err().unwrap().to_string(),
            "Json error: Not valid JSON: expected value at line 1 column 1",
        );
    }

    // A null never overrides a concrete type seen in another record. Columns that
    // only ever hold null (or empty arrays) stay Null (or a list of Null).
    #[test]
    fn test_null_field_inferred_as_null() {
        let data = r#"
            {"in":1, "ni":null, "ns":null, "sn":"4", "n":null, "an":[], "na": null, "nas":null}
            {"in":null, "ni":2, "ns":"3", "sn":null, "n":null, "an":null, "na": [], "nas":["8"]}
            {"in":1, "ni":null, "ns":null, "sn":"4", "n":null, "an":[], "na": null, "nas":[]}
        "#;
        let (inferred_schema, _) =
            infer_json_schema_from_seekable(Cursor::new(data), None).expect("infer");
        let schema = Schema::new(vec![
            Field::new("an", list_type_of(DataType::Null), true),
            Field::new("in", DataType::Int64, true),
            Field::new("n", DataType::Null, true),
            Field::new("na", list_type_of(DataType::Null), true),
            Field::new("nas", list_type_of(DataType::Utf8), true),
            Field::new("ni", DataType::Int64, true),
            Field::new("ns", DataType::Utf8, true),
            Field::new("sn", DataType::Utf8, true),
        ]);
        assert_eq!(inferred_schema, schema);
    }

    // A leading null must not block a later object from defining the struct type.
    #[test]
    fn test_infer_from_null_then_object() {
        let data = r#"
            {"obj":null}
            {"obj":{"foo":1}}
        "#;
        let (inferred_schema, _) =
            infer_json_schema_from_seekable(Cursor::new(data), None).expect("infer");
        let schema = Schema::new(vec![Field::new(
            "obj",
            DataType::Struct(
                [Field::new("foo", DataType::Int64, true)]
                    .into_iter()
                    .collect(),
            ),
            true,
        )]);
        assert_eq!(inferred_schema, schema);
    }
}