mod records;

use arrow_array::builder::{NullBuilder, PrimitiveBuilder};
use arrow_array::types::*;
use arrow_array::*;
use arrow_cast::parse::{Parser, parse_decimal, string_to_datetime};
use arrow_schema::*;
use chrono::{TimeZone, Utc};
use csv::StringRecord;
use regex::{Regex, RegexSet};
use std::fmt::{self, Debug};
use std::fs::File;
use std::io::{BufRead, BufReader as StdBufReader, Read};
use std::sync::{Arc, LazyLock};

use crate::map_csv_error;
use crate::reader::records::{RecordDecoder, StringRecords};
use arrow_array::timezone::Tz;

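/// Regexes used during type inference. The index of the first matching
/// pattern maps to a bit in [`InferredDataType::packed`]: boolean, integer,
/// decimal, date, then timestamps of increasing fractional precision.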
static REGEX_SET: LazyLock<RegexSet> = LazyLock::new(|| {
    RegexSet::new([
        r"(?i)^(true)$|^(false)$(?-i)", // BOOLEAN
        r"^-?(\d+)$",                   // INTEGER
        r"^-?((\d*\.\d+|\d+\.\d*)([eE][-+]?\d+)?|\d+([eE][-+]?\d+))$", // DECIMAL
        r"^\d{4}-\d\d-\d\d$",           // DATE32
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d(?:[^\d\.].*)?$", // Timestamp(Second)
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,3}(?:[^\d].*)?$", // Timestamp(Millisecond)
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,6}(?:[^\d].*)?$", // Timestamp(Microsecond)
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,9}(?:[^\d].*)?$", // Timestamp(Nanosecond)
    ])
    .unwrap()
});

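/// A wrapper over `Option<Regex>` to check if a value matches the
/// user-supplied null pattern, falling back to treating only empty strings
/// as null.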
#[derive(Debug, Clone, Default)]
struct NullRegex(Option<Regex>);

impl NullRegex {
    #[inline]
    fn is_null(&self, s: &str) -> bool {
        match &self.0 {
            Some(r) => r.is_match(s),
            None => s.is_empty(),
        }
    }
}

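/// The set of data types a column could be, accumulated from the string
/// values seen so far during schema inference.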
#[derive(Default, Copy, Clone)]
struct InferredDataType {
    /// Packed booleans indicating the potential types of this column, one bit
    /// per [`REGEX_SET`] pattern (bit 0 = Boolean through bit 7 = nanosecond
    /// timestamp), with bit 8 indicating Utf8
    packed: u16,
}

impl InferredDataType {
    /// Returns the inferred data type
    fn get(&self) -> DataType {
        match self.packed {
            0 => DataType::Null,
            1 => DataType::Boolean,
            2 => DataType::Int64,
            4 | 6 => DataType::Float64, // Promote Int64 to Float64
            b if b != 0 && (b & !0b11111000) == 0 => match b.leading_zeros() {
                // Promote to the highest precision temporal type
                8 => DataType::Timestamp(TimeUnit::Nanosecond, None),
                9 => DataType::Timestamp(TimeUnit::Microsecond, None),
                10 => DataType::Timestamp(TimeUnit::Millisecond, None),
                11 => DataType::Timestamp(TimeUnit::Second, None),
                12 => DataType::Date32,
                _ => unreachable!(),
            },
            _ => DataType::Utf8,
        }
    }

    /// Updates the [`InferredDataType`] with the given string
    fn update(&mut self, string: &str) {
        self.packed |= if string.starts_with('"') {
            1 << 8 // Utf8
        } else if let Some(m) = REGEX_SET.matches(string).into_iter().next() {
            if m == 1 && string.len() >= 19 && string.parse::<i64>().is_err() {
                // Integer-like values that overflow i64 fall back to Utf8
                1 << 8
            } else {
                1 << m
            }
        } else if string == "NaN" || string == "nan" || string == "inf" || string == "-inf" {
            1 << 2 // Float64
        } else {
            1 << 8 // Utf8
        }
    }
}

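/// The format specification of a CSV file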
#[derive(Debug, Clone, Default)]
pub struct Format {
    header: bool,
    delimiter: Option<u8>,
    escape: Option<u8>,
    quote: Option<u8>,
    terminator: Option<u8>,
    comment: Option<u8>,
    null_regex: NullRegex,
    truncated_rows: bool,
}

impl Format {
    /// Specify whether the CSV file has a header row, defaults to `false`
    pub fn with_header(mut self, has_header: bool) -> Self {
        self.header = has_header;
        self
    }

    /// Specify a custom delimiter character, defaults to comma `','`
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.delimiter = Some(delimiter);
        self
    }

    /// Specify an escape character
    pub fn with_escape(mut self, escape: u8) -> Self {
        self.escape = Some(escape);
        self
    }

    /// Specify a custom quote character, defaults to double quote `'"'`
    pub fn with_quote(mut self, quote: u8) -> Self {
        self.quote = Some(quote);
        self
    }

    /// Specify a custom terminator character, defaults to CRLF
    pub fn with_terminator(mut self, terminator: u8) -> Self {
        self.terminator = Some(terminator);
        self
    }

    /// Specify a comment character; lines starting with this character are ignored
    pub fn with_comment(mut self, comment: u8) -> Self {
        self.comment = Some(comment);
        self
    }

    /// Provide a regex to match null values, defaults to treating empty strings as null
    pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
        self.null_regex = NullRegex(Some(null_regex));
        self
    }

    /// Whether to allow truncated rows when parsing
    pub fn with_truncated_rows(mut self, allow: bool) -> Self {
        self.truncated_rows = allow;
        self
    }

    /// Infer the schema of the CSV data in `reader`, inspecting up to
    /// `max_records` records, returning the inferred schema and the number
    /// of records read
    pub fn infer_schema<R: Read>(
        &self,
        reader: R,
        max_records: Option<usize>,
    ) -> Result<(Schema, usize), ArrowError> {
        let mut csv_reader = self.build_reader(reader);

        // Get the header names, or generate `column_{n}` names if no header is present
        let headers: Vec<String> = if self.header {
            let headers = &csv_reader.headers().map_err(map_csv_error)?.clone();
            headers.iter().map(|s| s.to_string()).collect()
        } else {
            let first_record_count = &csv_reader.headers().map_err(map_csv_error)?.len();
            (0..*first_record_count)
                .map(|i| format!("column_{}", i + 1))
                .collect()
        };

        let header_length = headers.len();
        // Keep track of the inferred data type for each column
        let mut column_types: Vec<InferredDataType> = vec![Default::default(); header_length];

        let mut records_count = 0;

        let mut record = StringRecord::new();
        let max_records = max_records.unwrap_or(usize::MAX);
        while records_count < max_records {
            if !csv_reader.read_record(&mut record).map_err(map_csv_error)? {
                break;
            }
            records_count += 1;

            for (i, column_type) in column_types.iter_mut().enumerate().take(header_length) {
                if let Some(string) = record.get(i) {
                    if !self.null_regex.is_null(string) {
                        column_type.update(string)
                    }
                }
            }
        }

        // Build the schema, assuming any column could contain nulls since we may
        // only have seen a sample of the data
        let fields: Fields = column_types
            .iter()
            .zip(&headers)
            .map(|(inferred, field_name)| Field::new(field_name, inferred.get(), true))
            .collect();

        Ok((Schema::new(fields), records_count))
    }

    /// Build a [`csv::Reader`] for this [`Format`]
    fn build_reader<R: Read>(&self, reader: R) -> csv::Reader<R> {
        let mut builder = csv::ReaderBuilder::new();
        builder.has_headers(self.header);
        builder.flexible(self.truncated_rows);

        if let Some(c) = self.delimiter {
            builder.delimiter(c);
        }
        builder.escape(self.escape);
        if let Some(c) = self.quote {
            builder.quote(c);
        }
        if let Some(t) = self.terminator {
            builder.terminator(csv::Terminator::Any(t));
        }
        if let Some(comment) = self.comment {
            builder.comment(Some(comment));
        }
        builder.from_reader(reader)
    }

    /// Build a [`csv_core::Reader`] for this [`Format`]
    fn build_parser(&self) -> csv_core::Reader {
        let mut builder = csv_core::ReaderBuilder::new();
        builder.escape(self.escape);
        builder.comment(self.comment);

        if let Some(c) = self.delimiter {
            builder.delimiter(c);
        }
        if let Some(c) = self.quote {
            builder.quote(c);
        }
        if let Some(t) = self.terminator {
            builder.terminator(csv_core::Terminator::Any(t));
        }
        builder.build()
    }
}

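/// Infer the schema from a list of CSV files by reading up to
/// `max_read_records` records in total, then merging the per-file schemas.
/// Files that contain no records contribute nothing to the merged schema.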
pub fn infer_schema_from_files(
    files: &[String],
    delimiter: u8,
    max_read_records: Option<usize>,
    has_header: bool,
) -> Result<Schema, ArrowError> {
    let mut schemas = vec![];
    let mut records_to_read = max_read_records.unwrap_or(usize::MAX);
    let format = Format {
        delimiter: Some(delimiter),
        header: has_header,
        ..Default::default()
    };

    for fname in files.iter() {
        let f = File::open(fname)?;
        let (schema, records_read) = format.infer_schema(f, Some(records_to_read))?;
        if records_read == 0 {
            continue;
        }
        schemas.push(schema.clone());
        records_to_read -= records_read;
        if records_to_read == 0 {
            break;
        }
    }

    Schema::try_merge(schemas)
}

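/// The (start, end) bounds of row indices to read, if any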
type Bounds = Option<(usize, usize)>;

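/// CSV file reader using [`std::io::BufReader`]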
pub type Reader<R> = BufReader<StdBufReader<R>>;

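/// CSV file reader over any [`BufRead`]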
pub struct BufReader<R> {
    /// The underlying buffered reader
    reader: R,
    /// The decoder that converts raw bytes into [`RecordBatch`]es
    decoder: Decoder,
}

impl<R> fmt::Debug for BufReader<R>
where
    R: BufRead,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Reader")
            .field("decoder", &self.decoder)
            .finish()
    }
}

impl<R: Read> Reader<R> {
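    /// Returns the schema of the reader, taking any projection into account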
    pub fn schema(&self) -> SchemaRef {
        match &self.decoder.projection {
            Some(projection) => {
                let fields = self.decoder.schema.fields();
                let projected = projection.iter().map(|i| fields[*i].clone());
                Arc::new(Schema::new(projected.collect::<Fields>()))
            }
            None => self.decoder.schema.clone(),
        }
    }
}

impl<R: BufRead> BufReader<R> {
    fn read(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        loop {
            let buf = self.reader.fill_buf()?;
            let decoded = self.decoder.decode(buf)?;
            self.reader.consume(decoded);
            // Yield if no bytes were decoded or the decoder is full; the
            // capacity check avoids blocking in fill_buf for data that isn't
            // needed to flush the next batch
            if decoded == 0 || self.decoder.capacity() == 0 {
                break;
            }
        }

        self.decoder.flush()
    }
}

impl<R: BufRead> Iterator for BufReader<R> {
    type Item = Result<RecordBatch, ArrowError>;

    fn next(&mut self) -> Option<Self::Item> {
        self.read().transpose()
    }
}

impl<R: BufRead> RecordBatchReader for BufReader<R> {
    fn schema(&self) -> SchemaRef {
        self.decoder.schema.clone()
    }
}

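/// A push-based decoder for CSV data: feed it chunks of bytes with
/// [`Decoder::decode`] and call [`Decoder::flush`] to produce the decoded
/// [`RecordBatch`]es.
///
/// A minimal sketch of the intended calling pattern, mirroring
/// [`BufReader::read`] above (assuming `reader` is any [`BufRead`]):
///
/// ```ignore
/// loop {
///     let buf = reader.fill_buf()?;
///     let decoded = decoder.decode(buf)?;
///     reader.consume(decoded);
///     // Stop once the input is exhausted or a full batch is buffered
///     if decoded == 0 || decoder.capacity() == 0 {
///         break;
///     }
/// }
/// let batch = decoder.flush()?;
/// ```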
#[derive(Debug)]
pub struct Decoder {
    /// The schema of the decoder
    schema: SchemaRef,

    /// An optional projection of column indices
    projection: Option<Vec<usize>>,

    /// The maximum number of records per batch
    batch_size: usize,

    /// The number of rows still to skip before decoding
    to_skip: usize,

    /// The current line number
    line_number: usize,

    /// The line number at which to stop decoding
    end: usize,

    /// The decoder for [`StringRecords`]
    record_decoder: RecordDecoder,

    /// The pattern used to identify null values
    null_regex: NullRegex,
}

impl Decoder {
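    /// Decode records from `buf`, returning the number of bytes consumed.
    ///
    /// Decoding stops once `batch_size` records are buffered (see
    /// [`Decoder::capacity`]) or `buf` is exhausted.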
    pub fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
        if self.to_skip != 0 {
            // Skip in units of at most `batch_size` to avoid over-allocating buffers
            let to_skip = self.to_skip.min(self.batch_size);
            let (skipped, bytes) = self.record_decoder.decode(buf, to_skip)?;
            self.to_skip -= skipped;
            self.record_decoder.clear();
            return Ok(bytes);
        }

        let to_read = self.batch_size.min(self.end - self.line_number) - self.record_decoder.len();
        let (_, bytes) = self.record_decoder.decode(buf, to_read)?;
        Ok(bytes)
    }

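    /// Flushes the currently buffered data to a [`RecordBatch`], returning
    /// `None` if no rows are buffered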
    pub fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        if self.record_decoder.is_empty() {
            return Ok(None);
        }

        let rows = self.record_decoder.flush()?;
        let batch = parse(
            &rows,
            self.schema.fields(),
            Some(self.schema.metadata.clone()),
            self.projection.as_ref(),
            self.line_number,
            &self.null_regex,
        )?;
        self.line_number += rows.len();
        Ok(Some(batch))
    }

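    /// Returns the number of records that can be decoded before requiring a
    /// call to [`Decoder::flush`]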
    pub fn capacity(&self) -> usize {
        self.batch_size - self.record_decoder.len()
    }
}

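/// Parses a slice of [`StringRecords`] into a [`RecordBatch`]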
fn parse(
    rows: &StringRecords<'_>,
    fields: &Fields,
    metadata: Option<std::collections::HashMap<String, String>>,
    projection: Option<&Vec<usize>>,
    line_number: usize,
    null_regex: &NullRegex,
) -> Result<RecordBatch, ArrowError> {
    let projection: Vec<usize> = match projection {
        Some(v) => v.clone(),
        None => fields.iter().enumerate().map(|(i, _)| i).collect(),
    };

    let arrays: Result<Vec<ArrayRef>, _> = projection
        .iter()
        .map(|i| {
            let i = *i;
            let field = &fields[i];
            match field.data_type() {
                DataType::Boolean => build_boolean_array(line_number, rows, i, null_regex),
                DataType::Decimal32(precision, scale) => build_decimal_array::<Decimal32Type>(
                    line_number,
                    rows,
                    i,
                    *precision,
                    *scale,
                    null_regex,
                ),
                DataType::Decimal64(precision, scale) => build_decimal_array::<Decimal64Type>(
                    line_number,
                    rows,
                    i,
                    *precision,
                    *scale,
                    null_regex,
                ),
                DataType::Decimal128(precision, scale) => build_decimal_array::<Decimal128Type>(
                    line_number,
                    rows,
                    i,
                    *precision,
                    *scale,
                    null_regex,
                ),
                DataType::Decimal256(precision, scale) => build_decimal_array::<Decimal256Type>(
                    line_number,
                    rows,
                    i,
                    *precision,
                    *scale,
                    null_regex,
                ),
                DataType::Int8 => {
                    build_primitive_array::<Int8Type>(line_number, rows, i, null_regex)
                }
                DataType::Int16 => {
                    build_primitive_array::<Int16Type>(line_number, rows, i, null_regex)
                }
                DataType::Int32 => {
                    build_primitive_array::<Int32Type>(line_number, rows, i, null_regex)
                }
                DataType::Int64 => {
                    build_primitive_array::<Int64Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt8 => {
                    build_primitive_array::<UInt8Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt16 => {
                    build_primitive_array::<UInt16Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt32 => {
                    build_primitive_array::<UInt32Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt64 => {
                    build_primitive_array::<UInt64Type>(line_number, rows, i, null_regex)
                }
                DataType::Float32 => {
                    build_primitive_array::<Float32Type>(line_number, rows, i, null_regex)
                }
                DataType::Float64 => {
                    build_primitive_array::<Float64Type>(line_number, rows, i, null_regex)
                }
                DataType::Date32 => {
                    build_primitive_array::<Date32Type>(line_number, rows, i, null_regex)
                }
                DataType::Date64 => {
                    build_primitive_array::<Date64Type>(line_number, rows, i, null_regex)
                }
                DataType::Time32(TimeUnit::Second) => {
                    build_primitive_array::<Time32SecondType>(line_number, rows, i, null_regex)
                }
                DataType::Time32(TimeUnit::Millisecond) => {
                    build_primitive_array::<Time32MillisecondType>(line_number, rows, i, null_regex)
                }
                DataType::Time64(TimeUnit::Microsecond) => {
                    build_primitive_array::<Time64MicrosecondType>(line_number, rows, i, null_regex)
                }
                DataType::Time64(TimeUnit::Nanosecond) => {
                    build_primitive_array::<Time64NanosecondType>(line_number, rows, i, null_regex)
                }
                DataType::Timestamp(TimeUnit::Second, tz) => {
                    build_timestamp_array::<TimestampSecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Timestamp(TimeUnit::Millisecond, tz) => {
                    build_timestamp_array::<TimestampMillisecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Timestamp(TimeUnit::Microsecond, tz) => {
                    build_timestamp_array::<TimestampMicrosecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
                    build_timestamp_array::<TimestampNanosecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Null => Ok(Arc::new({
                    let mut builder = NullBuilder::new();
                    builder.append_nulls(rows.len());
                    builder.finish()
                }) as ArrayRef),
                DataType::Utf8 => Ok(Arc::new(
                    rows.iter()
                        .map(|row| {
                            let s = row.get(i);
                            (!null_regex.is_null(s)).then_some(s)
                        })
                        .collect::<StringArray>(),
                ) as ArrayRef),
                DataType::Utf8View => Ok(Arc::new(
                    rows.iter()
                        .map(|row| {
                            let s = row.get(i);
                            (!null_regex.is_null(s)).then_some(s)
                        })
                        .collect::<StringViewArray>(),
                ) as ArrayRef),
                DataType::Dictionary(key_type, value_type)
                    if value_type.as_ref() == &DataType::Utf8 =>
                {
                    match key_type.as_ref() {
                        DataType::Int8 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int8Type>>(),
                        ) as ArrayRef),
                        DataType::Int16 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int16Type>>(),
                        ) as ArrayRef),
                        DataType::Int32 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int32Type>>(),
                        ) as ArrayRef),
                        DataType::Int64 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int64Type>>(),
                        ) as ArrayRef),
                        DataType::UInt8 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt8Type>>(),
                        ) as ArrayRef),
                        DataType::UInt16 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt16Type>>(),
                        ) as ArrayRef),
                        DataType::UInt32 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt32Type>>(),
                        ) as ArrayRef),
                        DataType::UInt64 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt64Type>>(),
                        ) as ArrayRef),
                        _ => Err(ArrowError::ParseError(format!(
                            "Unsupported dictionary key type {key_type}"
                        ))),
                    }
                }
                other => Err(ArrowError::ParseError(format!(
                    "Unsupported data type {other:?}"
                ))),
            }
        })
        .collect();

    let projected_fields: Fields = projection.iter().map(|i| fields[*i].clone()).collect();

    let projected_schema = Arc::new(match metadata {
        None => Schema::new(projected_fields),
        Some(metadata) => Schema::new_with_metadata(projected_fields, metadata),
    });

    arrays.and_then(|arr| {
        RecordBatch::try_new_with_options(
            projected_schema,
            arr,
            &RecordBatchOptions::new()
                .with_match_field_names(true)
                .with_row_count(Some(rows.len())),
        )
    })
}

fn parse_bool(string: &str) -> Option<bool> {
    if string.eq_ignore_ascii_case("false") {
        Some(false)
    } else if string.eq_ignore_ascii_case("true") {
        Some(true)
    } else {
        None
    }
}

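/// Parses a column of [`StringRecords`] into a decimal array of the given
/// precision and scale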
fn build_decimal_array<T: DecimalType>(
    _line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    precision: u8,
    scale: i8,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    let mut decimal_builder = PrimitiveBuilder::<T>::with_capacity(rows.len());
    for row in rows.iter() {
        let s = row.get(col_idx);
        if null_regex.is_null(s) {
            decimal_builder.append_null();
        } else {
            let decimal_value: Result<T::Native, _> = parse_decimal::<T>(s, precision, scale);
            match decimal_value {
                Ok(v) => {
                    decimal_builder.append_value(v);
                }
                Err(e) => {
                    return Err(e);
                }
            }
        }
    }
    Ok(Arc::new(
        decimal_builder
            .finish()
            .with_precision_and_scale(precision, scale)?,
    ))
}

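/// Parses a column of [`StringRecords`] into a [`PrimitiveArray`], reporting
/// the line number and offending value on parse failure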
fn build_primitive_array<T: ArrowPrimitiveType + Parser>(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    rows.iter()
        .enumerate()
        .map(|(row_index, row)| {
            let s = row.get(col_idx);
            if null_regex.is_null(s) {
                return Ok(None);
            }

            match T::parse(s) {
                Some(e) => Ok(Some(e)),
                None => Err(ArrowError::ParseError(format!(
                    "Error while parsing value '{}' as type '{}' for column {} at line {}. Row data: '{}'",
                    s,
                    T::DATA_TYPE,
                    col_idx,
                    line_number + row_index,
                    row
                ))),
            }
        })
        .collect::<Result<PrimitiveArray<T>, ArrowError>>()
        .map(|e| Arc::new(e) as ArrayRef)
}

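/// Parses a column of [`StringRecords`] into a timestamp array, applying the
/// optional timezone to both parsing and the resulting array type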
fn build_timestamp_array<T: ArrowTimestampType>(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    timezone: Option<&str>,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    Ok(Arc::new(match timezone {
        Some(timezone) => {
            let tz: Tz = timezone.parse()?;
            build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &tz, null_regex)?
                .with_timezone(timezone)
        }
        None => build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &Utc, null_regex)?,
    }))
}

fn build_timestamp_array_impl<T: ArrowTimestampType, Tz: TimeZone>(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    timezone: &Tz,
    null_regex: &NullRegex,
) -> Result<PrimitiveArray<T>, ArrowError> {
    rows.iter()
        .enumerate()
        .map(|(row_index, row)| {
            let s = row.get(col_idx);
            if null_regex.is_null(s) {
                return Ok(None);
            }

            let date = string_to_datetime(timezone, s)
                .and_then(|date| match T::UNIT {
                    TimeUnit::Second => Ok(date.timestamp()),
                    TimeUnit::Millisecond => Ok(date.timestamp_millis()),
                    TimeUnit::Microsecond => Ok(date.timestamp_micros()),
                    TimeUnit::Nanosecond => date.timestamp_nanos_opt().ok_or_else(|| {
                        ArrowError::ParseError(format!(
                            "{} would overflow 64-bit signed nanoseconds",
                            date.to_rfc3339(),
                        ))
                    }),
                })
                .map_err(|e| {
                    ArrowError::ParseError(format!(
                        "Error parsing column {col_idx} at line {}: {}",
                        line_number + row_index,
                        e
                    ))
                })?;
            Ok(Some(date))
        })
        .collect()
}

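/// Parses a column of [`StringRecords`] into a [`BooleanArray`]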
fn build_boolean_array(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    rows.iter()
        .enumerate()
        .map(|(row_index, row)| {
            let s = row.get(col_idx);
            if null_regex.is_null(s) {
                return Ok(None);
            }
            let parsed = parse_bool(s);
            match parsed {
                Some(e) => Ok(Some(e)),
                None => Err(ArrowError::ParseError(format!(
                    "Error while parsing value '{}' as type '{}' for column {} at line {}. Row data: '{}'",
                    s,
                    "Boolean",
                    col_idx,
                    line_number + row_index,
                    row
                ))),
            }
        })
        .collect::<Result<BooleanArray, _>>()
        .map(|e| Arc::new(e) as ArrayRef)
}

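/// CSV file reader builder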
#[derive(Debug)]
pub struct ReaderBuilder {
    /// The schema of the CSV file
    schema: SchemaRef,
    /// The format specification
    format: Format,
    /// The batch size (number of records to load each time), defaults to 1024
    batch_size: usize,
    /// The bounds over which to scan the reader
    bounds: Bounds,
    /// Optional projection for which columns to load (zero-based column indices)
    projection: Option<Vec<usize>>,
}

impl ReaderBuilder {
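    /// Create a new builder for the given schema, with a default batch size
    /// of 1024 records and no projection or bounds.
    ///
    /// A minimal usage sketch (the path `example.csv` is a placeholder):
    ///
    /// ```ignore
    /// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
    /// let file = File::open("example.csv").unwrap();
    /// let mut reader = ReaderBuilder::new(schema).with_header(true).build(file).unwrap();
    /// let batch = reader.next().unwrap().unwrap();
    /// ```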
    pub fn new(schema: SchemaRef) -> ReaderBuilder {
        Self {
            schema,
            format: Format::default(),
            batch_size: 1024,
            bounds: None,
            projection: None,
        }
    }

    /// Set whether the CSV file has a header
    pub fn with_header(mut self, has_header: bool) -> Self {
        self.format.header = has_header;
        self
    }

    /// Overrides the [`Format`] of this [`ReaderBuilder`]
    pub fn with_format(mut self, format: Format) -> Self {
        self.format = format;
        self
    }

    /// Set the CSV file's column delimiter as a byte character
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.format.delimiter = Some(delimiter);
        self
    }

    /// Set the given character as the CSV file's escape character
    pub fn with_escape(mut self, escape: u8) -> Self {
        self.format.escape = Some(escape);
        self
    }

    /// Set the given character as the CSV file's quote character
    pub fn with_quote(mut self, quote: u8) -> Self {
        self.format.quote = Some(quote);
        self
    }

    /// Set the given character as the CSV file's terminator character
    pub fn with_terminator(mut self, terminator: u8) -> Self {
        self.format.terminator = Some(terminator);
        self
    }

    /// Set the given character as the CSV file's comment character
    pub fn with_comment(mut self, comment: u8) -> Self {
        self.format.comment = Some(comment);
        self
    }

    /// Provide a regex to match null values
    pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
        self.format.null_regex = NullRegex(Some(null_regex));
        self
    }

    /// Set the batch size (number of records to load each time)
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Set the bounds over which to scan the reader; `start` and `end` are
    /// zero-based row indices, exclusive of the header row if present
    pub fn with_bounds(mut self, start: usize, end: usize) -> Self {
        self.bounds = Some((start, end));
        self
    }

    /// Set the reader's column projection (zero-based column indices)
    pub fn with_projection(mut self, projection: Vec<usize>) -> Self {
        self.projection = Some(projection);
        self
    }

    /// Whether to allow truncated rows when parsing
    pub fn with_truncated_rows(mut self, allow: bool) -> Self {
        self.format.truncated_rows = allow;
        self
    }

    /// Create a new `Reader` from a non-buffered reader
    pub fn build<R: Read>(self, reader: R) -> Result<Reader<R>, ArrowError> {
        self.build_buffered(StdBufReader::new(reader))
    }

    /// Create a new `BufReader` from a buffered reader
    pub fn build_buffered<R: BufRead>(self, reader: R) -> Result<BufReader<R>, ArrowError> {
        Ok(BufReader {
            reader,
            decoder: self.build_decoder(),
        })
    }

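    /// Builds a [`Decoder`] that can be used with streams of CSV data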
    pub fn build_decoder(self) -> Decoder {
        let delimiter = self.format.build_parser();
        let record_decoder = RecordDecoder::new(
            delimiter,
            self.schema.fields().len(),
            self.format.truncated_rows,
        );

        let header = self.format.header as usize;

        let (start, end) = match self.bounds {
            Some((start, end)) => (start + header, end + header),
            None => (header, usize::MAX),
        };

        Decoder {
            schema: self.schema,
            to_skip: start,
            record_decoder,
            line_number: start,
            end,
            projection: self.projection,
            batch_size: self.batch_size,
            null_regex: self.format.null_regex,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use std::io::{Cursor, Seek, SeekFrom, Write};
    use tempfile::NamedTempFile;

    use arrow_array::cast::AsArray;

    #[test]
    fn test_csv() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]));

        let file = File::open("test/data/uk_cities.csv").unwrap();
        let mut csv = ReaderBuilder::new(schema.clone()).build(file).unwrap();
        assert_eq!(schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        // access data from a primitive array
        let lat = batch.column(1).as_primitive::<Float64Type>();
        assert_eq!(57.653484, lat.value(0));

        // access data from a string array
        let city = batch.column(0).as_string::<i32>();

        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
    }

    #[test]
    fn test_csv_schema_metadata() {
        let mut metadata = std::collections::HashMap::new();
        metadata.insert("foo".to_owned(), "bar".to_owned());
        let schema = Arc::new(Schema::new_with_metadata(
            vec![
                Field::new("city", DataType::Utf8, false),
                Field::new("lat", DataType::Float64, false),
                Field::new("lng", DataType::Float64, false),
            ],
            metadata.clone(),
        ));

        let file = File::open("test/data/uk_cities.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema.clone()).build(file).unwrap();
        assert_eq!(schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        assert_eq!(&metadata, batch.schema().metadata());
    }

    #[test]
    fn test_csv_reader_with_decimal() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Decimal128(38, 6), false),
            Field::new("lng", DataType::Decimal256(76, 6), false),
        ]));

        let file = File::open("test/data/decimal_test.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema).build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();
        let lat = batch
            .column(1)
            .as_any()
            .downcast_ref::<Decimal128Array>()
            .unwrap();

        assert_eq!("57.653484", lat.value_as_string(0));
        assert_eq!("53.002666", lat.value_as_string(1));
        assert_eq!("52.412811", lat.value_as_string(2));
        assert_eq!("51.481583", lat.value_as_string(3));
        assert_eq!("12.123456", lat.value_as_string(4));
        assert_eq!("50.760000", lat.value_as_string(5));
        assert_eq!("0.123000", lat.value_as_string(6));
        assert_eq!("123.000000", lat.value_as_string(7));
        assert_eq!("123.000000", lat.value_as_string(8));
        assert_eq!("-50.760000", lat.value_as_string(9));

        let lng = batch
            .column(2)
            .as_any()
            .downcast_ref::<Decimal256Array>()
            .unwrap();

        assert_eq!("-3.335724", lng.value_as_string(0));
        assert_eq!("-2.179404", lng.value_as_string(1));
        assert_eq!("-1.778197", lng.value_as_string(2));
        assert_eq!("-3.179090", lng.value_as_string(3));
        assert_eq!("-3.179090", lng.value_as_string(4));
        assert_eq!("0.290472", lng.value_as_string(5));
        assert_eq!("0.290472", lng.value_as_string(6));
        assert_eq!("0.290472", lng.value_as_string(7));
        assert_eq!("0.290472", lng.value_as_string(8));
        assert_eq!("0.290472", lng.value_as_string(9));
    }

    #[test]
    fn test_csv_reader_with_decimal_3264() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Decimal32(9, 6), false),
            Field::new("lng", DataType::Decimal64(16, 6), false),
        ]));

        let file = File::open("test/data/decimal_test.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema).build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();
        let lat = batch
            .column(1)
            .as_any()
            .downcast_ref::<Decimal32Array>()
            .unwrap();

        assert_eq!("57.653484", lat.value_as_string(0));
        assert_eq!("53.002666", lat.value_as_string(1));
        assert_eq!("52.412811", lat.value_as_string(2));
        assert_eq!("51.481583", lat.value_as_string(3));
        assert_eq!("12.123456", lat.value_as_string(4));
        assert_eq!("50.760000", lat.value_as_string(5));
        assert_eq!("0.123000", lat.value_as_string(6));
        assert_eq!("123.000000", lat.value_as_string(7));
        assert_eq!("123.000000", lat.value_as_string(8));
        assert_eq!("-50.760000", lat.value_as_string(9));

        let lng = batch
            .column(2)
            .as_any()
            .downcast_ref::<Decimal64Array>()
            .unwrap();

        assert_eq!("-3.335724", lng.value_as_string(0));
        assert_eq!("-2.179404", lng.value_as_string(1));
        assert_eq!("-1.778197", lng.value_as_string(2));
        assert_eq!("-3.179090", lng.value_as_string(3));
        assert_eq!("-3.179090", lng.value_as_string(4));
        assert_eq!("0.290472", lng.value_as_string(5));
        assert_eq!("0.290472", lng.value_as_string(6));
        assert_eq!("0.290472", lng.value_as_string(7));
        assert_eq!("0.290472", lng.value_as_string(8));
        assert_eq!("0.290472", lng.value_as_string(9));
    }

    #[test]
    fn test_csv_from_buf_reader() {
        let schema = Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]);

        let file_with_headers = File::open("test/data/uk_cities_with_headers.csv").unwrap();
        let file_without_headers = File::open("test/data/uk_cities.csv").unwrap();
        let both_files = file_with_headers
            .chain(Cursor::new("\n".to_string()))
            .chain(file_without_headers);
        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_header(true)
            .build(both_files)
            .unwrap();
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(74, batch.num_rows());
        assert_eq!(3, batch.num_columns());
    }

    #[test]
    fn test_csv_with_schema_inference() {
        let mut file = File::open("test/data/uk_cities_with_headers.csv").unwrap();

        let (schema, _) = Format::default()
            .with_header(true)
            .infer_schema(&mut file, None)
            .unwrap();

        file.rewind().unwrap();
        let builder = ReaderBuilder::new(Arc::new(schema)).with_header(true);

        let mut csv = builder.build(file).unwrap();
        let expected_schema = Schema::new(vec![
            Field::new("city", DataType::Utf8, true),
            Field::new("lat", DataType::Float64, true),
            Field::new("lng", DataType::Float64, true),
        ]);
        assert_eq!(Arc::new(expected_schema), csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        // access data from a primitive array
        let lat = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(57.653484, lat.value(0));

        // access data from a string array
        let city = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
    }

    #[test]
    fn test_csv_with_schema_inference_no_headers() {
        let mut file = File::open("test/data/uk_cities.csv").unwrap();

        let (schema, _) = Format::default().infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let mut csv = ReaderBuilder::new(Arc::new(schema)).build(file).unwrap();

        // inferred field names should be `column_{number}`
        let schema = csv.schema();
        assert_eq!("column_1", schema.field(0).name());
        assert_eq!("column_2", schema.field(1).name());
        assert_eq!("column_3", schema.field(2).name());
        let batch = csv.next().unwrap().unwrap();
        let batch_schema = batch.schema();

        assert_eq!(schema, batch_schema);
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        // access data from a primitive array
        let lat = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(57.653484, lat.value(0));

        // access data from a string array
        let city = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
    }

    #[test]
    fn test_csv_builder_with_bounds() {
        let mut file = File::open("test/data/uk_cities.csv").unwrap();

        let (schema, _) = Format::default().infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();
        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_bounds(0, 2)
            .build(file)
            .unwrap();
        let batch = csv.next().unwrap().unwrap();

        let city = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        // The value on line 0 is within the bounds
        assert_eq!("Elgin, Scotland, the UK", city.value(0));

        // The value on line 13 is outside of the bounds, so the call to
        // .value() will panic
        let result = std::panic::catch_unwind(|| city.value(13));
        assert!(result.is_err());
    }

    #[test]
    fn test_csv_with_projection() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]));

        let file = File::open("test/data/uk_cities.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_projection(vec![0, 1])
            .build(file)
            .unwrap();

        let projected_schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
        ]));
        assert_eq!(projected_schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(projected_schema, batch.schema());
        assert_eq!(37, batch.num_rows());
        assert_eq!(2, batch.num_columns());
    }

    #[test]
    fn test_csv_with_dictionary() {
        let schema = Arc::new(Schema::new(vec![
            Field::new_dictionary("city", DataType::Int32, DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]));

        let file = File::open("test/data/uk_cities.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_projection(vec![0, 1])
            .build(file)
            .unwrap();

        let projected_schema = Arc::new(Schema::new(vec![
            Field::new_dictionary("city", DataType::Int32, DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
        ]));
        assert_eq!(projected_schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(projected_schema, batch.schema());
        assert_eq!(37, batch.num_rows());
        assert_eq!(2, batch.num_columns());

        let strings = arrow_cast::cast(batch.column(0), &DataType::Utf8).unwrap();
        let strings = strings.as_string::<i32>();

        assert_eq!(strings.value(0), "Elgin, Scotland, the UK");
        assert_eq!(strings.value(4), "Eastbourne, East Sussex, UK");
        assert_eq!(strings.value(29), "Uckfield, East Sussex, UK");
    }

    #[test]
    fn test_csv_with_nullable_dictionary() {
        let offset_type = vec![
            DataType::Int8,
            DataType::Int16,
            DataType::Int32,
            DataType::Int64,
            DataType::UInt8,
            DataType::UInt16,
            DataType::UInt32,
            DataType::UInt64,
        ];
        for data_type in offset_type {
            let file = File::open("test/data/dictionary_nullable_test.csv").unwrap();
            let dictionary_type =
                DataType::Dictionary(Box::new(data_type), Box::new(DataType::Utf8));
            let schema = Arc::new(Schema::new(vec![
                Field::new("id", DataType::Utf8, false),
                Field::new("name", dictionary_type.clone(), true),
            ]));

            let mut csv = ReaderBuilder::new(schema)
                .build(file.try_clone().unwrap())
                .unwrap();

            let batch = csv.next().unwrap().unwrap();
            assert_eq!(3, batch.num_rows());
            assert_eq!(2, batch.num_columns());

            let names = arrow_cast::cast(batch.column(1), &dictionary_type).unwrap();
            assert!(!names.is_null(2));
            assert!(names.is_null(1));
        }
    }

    #[test]
    fn test_nulls() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c_int", DataType::UInt64, false),
            Field::new("c_float", DataType::Float32, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, false),
        ]));

        let file = File::open("test/data/null_test.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_header(true)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        assert!(!batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }

    #[test]
    fn test_init_nulls() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c_int", DataType::UInt64, true),
            Field::new("c_float", DataType::Float32, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
            Field::new("c_null", DataType::Null, true),
        ]));
        let file = File::open("test/data/init_null_test.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_header(true)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        assert!(batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }

    #[test]
    fn test_init_nulls_with_inference() {
        let format = Format::default().with_header(true).with_delimiter(b',');

        let mut file = File::open("test/data/init_null_test.csv").unwrap();
        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let expected_schema = Schema::new(vec![
            Field::new("c_int", DataType::Int64, true),
            Field::new("c_float", DataType::Float64, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
            Field::new("c_null", DataType::Null, true),
        ]);
        assert_eq!(schema, expected_schema);

        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        assert!(batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }

    #[test]
    fn test_custom_nulls() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c_int", DataType::UInt64, true),
            Field::new("c_float", DataType::Float32, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
        ]));

        let file = File::open("test/data/custom_null_test.csv").unwrap();

        let null_regex = Regex::new("^nil$").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_header(true)
            .with_null_regex(null_regex)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        assert!(batch.column(0).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(batch.column(3).is_null(4));
        assert!(batch.column(2).is_null(3));
        assert!(!batch.column(2).is_null(4));
    }

    #[test]
    fn test_nulls_with_inference() {
        let mut file = File::open("test/data/various_types.csv").unwrap();
        let format = Format::default().with_header(true).with_delimiter(b'|');

        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3, 4, 5]);

        let mut csv = builder.build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();

        assert_eq!(10, batch.num_rows());
        assert_eq!(6, batch.num_columns());

        let schema = batch.schema();

        assert_eq!(&DataType::Int64, schema.field(0).data_type());
        assert_eq!(&DataType::Float64, schema.field(1).data_type());
        assert_eq!(&DataType::Float64, schema.field(2).data_type());
        assert_eq!(&DataType::Boolean, schema.field(3).data_type());
        assert_eq!(&DataType::Date32, schema.field(4).data_type());
        assert_eq!(
            &DataType::Timestamp(TimeUnit::Second, None),
            schema.field(5).data_type()
        );

        let names: Vec<&str> = schema.fields().iter().map(|x| x.name().as_str()).collect();
        assert_eq!(
            names,
            vec![
                "c_int",
                "c_float",
                "c_string",
                "c_bool",
                "c_date",
                "c_datetime"
            ]
        );

        assert!(schema.field(0).is_nullable());
        assert!(schema.field(1).is_nullable());
        assert!(schema.field(2).is_nullable());
        assert!(schema.field(3).is_nullable());
        assert!(schema.field(4).is_nullable());
        assert!(schema.field(5).is_nullable());

        assert!(!batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }

    #[test]
    fn test_custom_nulls_with_inference() {
        let mut file = File::open("test/data/custom_null_test.csv").unwrap();

        let null_regex = Regex::new("^nil$").unwrap();

        let format = Format::default()
            .with_header(true)
            .with_null_regex(null_regex);

        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let expected_schema = Schema::new(vec![
            Field::new("c_int", DataType::Int64, true),
            Field::new("c_float", DataType::Float64, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
        ]);

        assert_eq!(schema, expected_schema);

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3]);

        let mut csv = builder.build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();

        assert_eq!(5, batch.num_rows());
        assert_eq!(4, batch.num_columns());

        assert_eq!(batch.schema().as_ref(), &expected_schema);
    }

    #[test]
    fn test_scientific_notation_with_inference() {
        let mut file = File::open("test/data/scientific_notation_test.csv").unwrap();
        let format = Format::default().with_header(false).with_delimiter(b',');

        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .with_batch_size(512)
            .with_projection(vec![0, 1]);

        let mut csv = builder.build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();

        let schema = batch.schema();

        assert_eq!(&DataType::Float64, schema.field(0).data_type());
    }

    fn invalid_csv_helper(file_name: &str) -> String {
        let file = File::open(file_name).unwrap();
        let schema = Schema::new(vec![
            Field::new("c_int", DataType::UInt64, false),
            Field::new("c_float", DataType::Float32, false),
            Field::new("c_string", DataType::Utf8, false),
            Field::new("c_bool", DataType::Boolean, false),
        ]);

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(true)
            .with_delimiter(b'|')
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3]);

        let mut csv = builder.build(file).unwrap();

        csv.next().unwrap().unwrap_err().to_string()
    }

    #[test]
    fn test_parse_invalid_csv_float() {
        let file_name = "test/data/various_invalid_types/invalid_float.csv";

        let error = invalid_csv_helper(file_name);
        assert_eq!(
            "Parser error: Error while parsing value '4.x4' as type 'Float32' for column 1 at line 4. Row data: '[4,4.x4,,false]'",
            error
        );
    }

    #[test]
    fn test_parse_invalid_csv_int() {
        let file_name = "test/data/various_invalid_types/invalid_int.csv";

        let error = invalid_csv_helper(file_name);
        assert_eq!(
            "Parser error: Error while parsing value '2.3' as type 'UInt64' for column 0 at line 2. Row data: '[2.3,2.2,2.22,false]'",
            error
        );
    }

    #[test]
    fn test_parse_invalid_csv_bool() {
        let file_name = "test/data/various_invalid_types/invalid_bool.csv";

        let error = invalid_csv_helper(file_name);
        assert_eq!(
            "Parser error: Error while parsing value 'none' as type 'Boolean' for column 3 at line 2. Row data: '[2,2.2,2.22,none]'",
            error
        );
    }

    /// Infer the data type of a single string value
    fn infer_field_schema(string: &str) -> DataType {
        let mut v = InferredDataType::default();
        v.update(string);
        v.get()
    }

    #[test]
    fn test_infer_field_schema() {
        assert_eq!(infer_field_schema("A"), DataType::Utf8);
        assert_eq!(infer_field_schema("\"123\""), DataType::Utf8);
        assert_eq!(infer_field_schema("10"), DataType::Int64);
        assert_eq!(infer_field_schema("10.2"), DataType::Float64);
        assert_eq!(infer_field_schema(".2"), DataType::Float64);
        assert_eq!(infer_field_schema("2."), DataType::Float64);
        assert_eq!(infer_field_schema("NaN"), DataType::Float64);
        assert_eq!(infer_field_schema("nan"), DataType::Float64);
        assert_eq!(infer_field_schema("inf"), DataType::Float64);
        assert_eq!(infer_field_schema("-inf"), DataType::Float64);
        assert_eq!(infer_field_schema("true"), DataType::Boolean);
        assert_eq!(infer_field_schema("trUe"), DataType::Boolean);
        assert_eq!(infer_field_schema("false"), DataType::Boolean);
        assert_eq!(infer_field_schema("2020-11-08"), DataType::Date32);
        assert_eq!(
            infer_field_schema("2020-11-08T14:20:01"),
            DataType::Timestamp(TimeUnit::Second, None)
        );
        assert_eq!(
            infer_field_schema("2020-11-08 14:20:01"),
            DataType::Timestamp(TimeUnit::Second, None)
        );
        assert_eq!(infer_field_schema("-5.13"), DataType::Float64);
        assert_eq!(infer_field_schema("0.1300"), DataType::Float64);
        assert_eq!(
            infer_field_schema("2021-12-19 13:12:30.921"),
            DataType::Timestamp(TimeUnit::Millisecond, None)
        );
        assert_eq!(
            infer_field_schema("2021-12-19T13:12:30.123456789"),
            DataType::Timestamp(TimeUnit::Nanosecond, None)
        );
        assert_eq!(infer_field_schema("-9223372036854775809"), DataType::Utf8);
        assert_eq!(infer_field_schema("9223372036854775808"), DataType::Utf8);
    }

    #[test]
    fn parse_date32() {
        assert_eq!(Date32Type::parse("1970-01-01").unwrap(), 0);
        assert_eq!(Date32Type::parse("2020-03-15").unwrap(), 18336);
        assert_eq!(Date32Type::parse("1945-05-08").unwrap(), -9004);
    }

    #[test]
    fn parse_time() {
        assert_eq!(
            Time64NanosecondType::parse("12:10:01.123456789 AM"),
            Some(601_123_456_789)
        );
        assert_eq!(
            Time64MicrosecondType::parse("12:10:01.123456 am"),
            Some(601_123_456)
        );
        assert_eq!(
            Time32MillisecondType::parse("2:10:01.12 PM"),
            Some(51_001_120)
        );
        assert_eq!(Time32SecondType::parse("2:10:01 pm"), Some(51_001));
    }

    #[test]
    fn parse_date64() {
        assert_eq!(Date64Type::parse("1970-01-01T00:00:00").unwrap(), 0);
        assert_eq!(
            Date64Type::parse("2018-11-13T17:11:10").unwrap(),
            1542129070000
        );
        assert_eq!(
            Date64Type::parse("2018-11-13T17:11:10.011").unwrap(),
            1542129070011
        );
        assert_eq!(
            Date64Type::parse("1900-02-28T12:34:56").unwrap(),
            -2203932304000
        );
        assert_eq!(
            Date64Type::parse_formatted("1900-02-28 12:34:56", "%Y-%m-%d %H:%M:%S").unwrap(),
            -2203932304000
        );
        assert_eq!(
            Date64Type::parse_formatted("1900-02-28 12:34:56+0030", "%Y-%m-%d %H:%M:%S%z").unwrap(),
            -2203932304000 - (30 * 60 * 1000)
        );
    }

    fn test_parse_timestamp_impl<T: ArrowTimestampType>(
        timezone: Option<Arc<str>>,
        expected: &[i64],
    ) {
        let csv = [
            "1970-01-01T00:00:00",
            "1970-01-01T00:00:00Z",
            "1970-01-01T00:00:00+02:00",
        ]
        .join("\n");
        let schema = Arc::new(Schema::new(vec![Field::new(
            "field",
            DataType::Timestamp(T::UNIT, timezone.clone()),
            true,
        )]));

        let mut decoder = ReaderBuilder::new(schema).build_decoder();

        let decoded = decoder.decode(csv.as_bytes()).unwrap();
        assert_eq!(decoded, csv.len());
        decoder.decode(&[]).unwrap();

        let batch = decoder.flush().unwrap().unwrap();
        assert_eq!(batch.num_columns(), 1);
        assert_eq!(batch.num_rows(), 3);
        let col = batch.column(0).as_primitive::<T>();
        assert_eq!(col.values(), expected);
        assert_eq!(col.data_type(), &DataType::Timestamp(T::UNIT, timezone));
    }

    #[test]
    fn test_parse_timestamp() {
        test_parse_timestamp_impl::<TimestampNanosecondType>(None, &[0, 0, -7_200_000_000_000]);
        test_parse_timestamp_impl::<TimestampNanosecondType>(
            Some("+00:00".into()),
            &[0, 0, -7_200_000_000_000],
        );
        test_parse_timestamp_impl::<TimestampNanosecondType>(
            Some("-05:00".into()),
            &[18_000_000_000_000, 0, -7_200_000_000_000],
        );
        test_parse_timestamp_impl::<TimestampMicrosecondType>(
            Some("-03".into()),
            &[10_800_000_000, 0, -7_200_000_000],
        );
        test_parse_timestamp_impl::<TimestampMillisecondType>(
            Some("-03".into()),
            &[10_800_000, 0, -7_200_000],
        );
        test_parse_timestamp_impl::<TimestampSecondType>(Some("-03".into()), &[10_800, 0, -7_200]);
    }

    #[test]
    fn test_infer_schema_from_multiple_files() {
        let mut csv1 = NamedTempFile::new().unwrap();
        let mut csv2 = NamedTempFile::new().unwrap();
        let csv3 = NamedTempFile::new().unwrap(); // empty csv file should be skipped
        let mut csv4 = NamedTempFile::new().unwrap();
        writeln!(csv1, "c1,c2,c3").unwrap();
        writeln!(csv1, "1,\"foo\",0.5").unwrap();
        writeln!(csv1, "3,\"bar\",1").unwrap();
        writeln!(csv1, "3,\"bar\",2e-06").unwrap();
        writeln!(csv2, "c1,c2,c3,c4").unwrap();
        writeln!(csv2, "10,,3.14,true").unwrap();
        writeln!(csv4, "c1,c2,c3").unwrap();
        writeln!(csv4, "10,\"foo\",").unwrap();

        let schema = infer_schema_from_files(
            &[
                csv3.path().to_str().unwrap().to_string(),
                csv1.path().to_str().unwrap().to_string(),
                csv2.path().to_str().unwrap().to_string(),
                csv4.path().to_str().unwrap().to_string(),
            ],
            b',',
            Some(4), // csv1 has 3 records and csv2 has 1, so csv4 is never read
            true,
        )
        .unwrap();

        assert_eq!(schema.fields().len(), 4);
        assert!(schema.field(0).is_nullable());
        assert!(schema.field(1).is_nullable());
        assert!(schema.field(2).is_nullable());
        assert!(schema.field(3).is_nullable());

        assert_eq!(&DataType::Int64, schema.field(0).data_type());
        assert_eq!(&DataType::Utf8, schema.field(1).data_type());
        assert_eq!(&DataType::Float64, schema.field(2).data_type());
        assert_eq!(&DataType::Boolean, schema.field(3).data_type());
    }

    #[test]
    fn test_bounded() {
        let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]);
        let data = [
            vec!["0"],
            vec!["1"],
            vec!["2"],
            vec!["3"],
            vec!["4"],
            vec!["5"],
            vec!["6"],
        ];

        let data = data
            .iter()
            .map(|x| x.join(","))
            .collect::<Vec<_>>()
            .join("\n");
        let data = data.as_bytes();

        let reader = std::io::Cursor::new(data);

        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_batch_size(2)
            .with_projection(vec![0])
            .with_bounds(2, 6)
            .build_buffered(reader)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();
        let a = batch.column(0);
        let a = a.as_any().downcast_ref::<UInt32Array>().unwrap();
        assert_eq!(a, &UInt32Array::from(vec![2, 3]));

        let batch = csv.next().unwrap().unwrap();
        let a = batch.column(0);
        let a = a.as_any().downcast_ref::<UInt32Array>().unwrap();
        assert_eq!(a, &UInt32Array::from(vec![4, 5]));

        assert!(csv.next().is_none());
    }

    #[test]
    fn test_empty_projection() {
        let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]);
        let data = [vec!["0"], vec!["1"]];

        let data = data
            .iter()
            .map(|x| x.join(","))
            .collect::<Vec<_>>()
            .join("\n");

        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_batch_size(2)
            .with_projection(vec![])
            .build_buffered(Cursor::new(data.as_bytes()))
            .unwrap();

        let batch = csv.next().unwrap().unwrap();
        assert_eq!(batch.columns().len(), 0);
        assert_eq!(batch.num_rows(), 2);

        assert!(csv.next().is_none());
    }

    #[test]
    fn test_parsing_bool() {
        assert_eq!(Some(true), parse_bool("true"));
        assert_eq!(Some(true), parse_bool("tRUe"));
        assert_eq!(Some(true), parse_bool("True"));
        assert_eq!(Some(true), parse_bool("TRUE"));
        assert_eq!(None, parse_bool("t"));
        assert_eq!(None, parse_bool("T"));
        assert_eq!(None, parse_bool(""));

        assert_eq!(Some(false), parse_bool("false"));
        assert_eq!(Some(false), parse_bool("fALse"));
        assert_eq!(Some(false), parse_bool("False"));
        assert_eq!(Some(false), parse_bool("FALSE"));
        assert_eq!(None, parse_bool("f"));
        assert_eq!(None, parse_bool("F"));
        assert_eq!(None, parse_bool(""));
    }

    #[test]
    fn test_parsing_float() {
        assert_eq!(Some(12.34), Float64Type::parse("12.34"));
        assert_eq!(Some(-12.34), Float64Type::parse("-12.34"));
        assert_eq!(Some(12.0), Float64Type::parse("12"));
        assert_eq!(Some(0.0), Float64Type::parse("0"));
        assert_eq!(Some(2.0), Float64Type::parse("2."));
        assert_eq!(Some(0.2), Float64Type::parse(".2"));
        assert!(Float64Type::parse("nan").unwrap().is_nan());
        assert!(Float64Type::parse("NaN").unwrap().is_nan());
        assert!(Float64Type::parse("inf").unwrap().is_infinite());
        assert!(Float64Type::parse("inf").unwrap().is_sign_positive());
        assert!(Float64Type::parse("-inf").unwrap().is_infinite());
        assert!(Float64Type::parse("-inf").unwrap().is_sign_negative());
        assert_eq!(None, Float64Type::parse(""));
        assert_eq!(None, Float64Type::parse("dd"));
        assert_eq!(None, Float64Type::parse("12.34.56"));
    }

    #[test]
    fn test_non_std_quote() {
        let schema = Schema::new(vec![
            Field::new("text1", DataType::Utf8, false),
            Field::new("text2", DataType::Utf8, false),
        ]);
        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(false)
            .with_quote(b'~'); // default is ", change to ~

        let mut csv_text = Vec::new();
        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
        for index in 0..10 {
            let text1 = format!("id{index:}");
            let text2 = format!("value{index:}");
            csv_writer
                .write_fmt(format_args!("~{text1}~,~{text2}~\r\n"))
                .unwrap();
        }
        let mut csv_reader = std::io::Cursor::new(&csv_text);
        let mut reader = builder.build(&mut csv_reader).unwrap();
        let batch = reader.next().unwrap().unwrap();
        let col0 = batch.column(0);
        assert_eq!(col0.len(), 10);
        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col0_arr.value(0), "id0");
        let col1 = batch.column(1);
        assert_eq!(col1.len(), 10);
        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col1_arr.value(5), "value5");
    }

    #[test]
    fn test_non_std_escape() {
        let schema = Schema::new(vec![
            Field::new("text1", DataType::Utf8, false),
            Field::new("text2", DataType::Utf8, false),
        ]);
        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(false)
            .with_escape(b'\\'); // default is None, change to \

        let mut csv_text = Vec::new();
2270 let mut csv_writer = std::io::Cursor::new(&mut csv_text);
2271 for index in 0..10 {
2272 let text1 = format!("id{index:}");
2273 let text2 = format!("value\\\"{index:}");
2274 csv_writer
2275 .write_fmt(format_args!("\"{text1}\",\"{text2}\"\r\n"))
2276 .unwrap();
2277 }
2278 let mut csv_reader = std::io::Cursor::new(&csv_text);
2279 let mut reader = builder.build(&mut csv_reader).unwrap();
2280 let batch = reader.next().unwrap().unwrap();
2281 let col0 = batch.column(0);
2282 assert_eq!(col0.len(), 10);
2283 let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
2284 assert_eq!(col0_arr.value(0), "id0");
2285 let col1 = batch.column(1);
2286 assert_eq!(col1.len(), 10);
2287 let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
2288 assert_eq!(col1_arr.value(5), "value\"5");
2289 }
2290
    #[test]
    fn test_non_std_terminator() {
        let schema = Schema::new(vec![
            Field::new("text1", DataType::Utf8, false),
            Field::new("text2", DataType::Utf8, false),
        ]);
        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(false)
            .with_terminator(b'\n'); // use LF alone as the record terminator

        let mut csv_text = Vec::new();
        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
        for index in 0..10 {
            let text1 = format!("id{index}");
            let text2 = format!("value{index}");
            csv_writer
                .write_fmt(format_args!("\"{text1}\",\"{text2}\"\n"))
                .unwrap();
        }
        let mut csv_reader = std::io::Cursor::new(&csv_text);
        let mut reader = builder.build(&mut csv_reader).unwrap();
        let batch = reader.next().unwrap().unwrap();
        let col0 = batch.column(0);
        assert_eq!(col0.len(), 10);
        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col0_arr.value(0), "id0");
        let col1 = batch.column(1);
        assert_eq!(col1.len(), 10);
        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col1_arr.value(5), "value5");
    }

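    // Each case below is (bounds, has_header, expected row count); bounds
    // restrict which records are decoded via `with_bounds(start, end)`.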
    #[test]
    fn test_header_bounds() {
        let csv = "a,b\na,b\na,b\na,b\na,b\n";
        let tests = [
            (None, false, 5),
            (None, true, 4),
            (Some((0, 4)), false, 4),
            (Some((1, 4)), false, 3),
            (Some((0, 4)), true, 4),
            (Some((1, 4)), true, 3),
        ];
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("a", DataType::Utf8, false),
        ]));

        for (idx, (bounds, has_header, expected)) in tests.into_iter().enumerate() {
            let mut reader = ReaderBuilder::new(schema.clone()).with_header(has_header);
            if let Some((start, end)) = bounds {
                reader = reader.with_bounds(start, end);
            }
            let b = reader
                .build_buffered(Cursor::new(csv.as_bytes()))
                .unwrap()
                .next()
                .unwrap()
                .unwrap();
            assert_eq!(b.num_rows(), expected, "{idx}");
        }
    }

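    // Empty fields in a nullable Boolean column should decode as nulls,
    // while mixed-case literals still parse as booleans.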
    #[test]
    fn test_null_boolean() {
        let csv = "true,false\nFalse,True\n,True\nFalse,";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Boolean, true),
            Field::new("a", DataType::Boolean, true),
        ]));

        let b = ReaderBuilder::new(schema)
            .build_buffered(Cursor::new(csv.as_bytes()))
            .unwrap()
            .next()
            .unwrap()
            .unwrap();

        assert_eq!(b.num_rows(), 4);
        assert_eq!(b.num_columns(), 2);

        let c = b.column(0).as_boolean();
        assert_eq!(c.null_count(), 1);
        assert!(c.value(0));
        assert!(!c.value(1));
        assert!(c.is_null(2));
        assert!(!c.value(3));

        let c = b.column(1).as_boolean();
        assert_eq!(c.null_count(), 1);
        assert!(!c.value(0));
        assert!(c.value(1));
        assert!(c.value(2));
        assert!(c.is_null(3));
    }

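    // With `with_truncated_rows(true)` short rows are padded with nulls;
    // with it disabled, the same input must fail.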
    #[test]
    fn test_truncated_rows() {
        let data = "a,b,c\n1,2,3\n4,5\n\n6,7,8";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
            Field::new("c", DataType::Int32, true),
        ]));

        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(true)
            .build(Cursor::new(data))
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(batches.is_ok());
        let batch = batches.unwrap().into_iter().next().unwrap();
        assert_eq!(batch.num_rows(), 3);

        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(false)
            .build(Cursor::new(data))
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(match batches {
            Err(ArrowError::CsvError(e)) => e.to_string().contains("incorrect number of fields"),
            _ => false,
        });
    }

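    // The same truncated-row handling, exercised against a file fixture
    // with nulls scattered across every column.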
    #[test]
    fn test_truncated_rows_csv() {
        let file = File::open("test/data/truncated_rows.csv").unwrap();
        let schema = Arc::new(Schema::new(vec![
            Field::new("Name", DataType::Utf8, true),
            Field::new("Age", DataType::UInt32, true),
            Field::new("Occupation", DataType::Utf8, true),
            Field::new("DOB", DataType::Date32, true),
        ]));
        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_batch_size(24)
            .with_truncated_rows(true);
        let csv = reader.build(file).unwrap();
        let batches = csv.collect::<Result<Vec<_>, _>>().unwrap();

        assert_eq!(batches.len(), 1);
        let batch = &batches[0];
        assert_eq!(batch.num_rows(), 6);
        assert_eq!(batch.num_columns(), 4);
        let name = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let age = batch
            .column(1)
            .as_any()
            .downcast_ref::<UInt32Array>()
            .unwrap();
        let occupation = batch
            .column(2)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let dob = batch
            .column(3)
            .as_any()
            .downcast_ref::<Date32Array>()
            .unwrap();

        assert_eq!(name.value(0), "A1");
        assert_eq!(name.value(1), "B2");
        assert!(name.is_null(2));
        assert_eq!(name.value(3), "C3");
        assert_eq!(name.value(4), "D4");
        assert_eq!(name.value(5), "E5");

        assert_eq!(age.value(0), 34);
        assert_eq!(age.value(1), 29);
        assert!(age.is_null(2));
        assert_eq!(age.value(3), 45);
        assert!(age.is_null(4));
        assert_eq!(age.value(5), 31);

        assert_eq!(occupation.value(0), "Engineer");
        assert_eq!(occupation.value(1), "Doctor");
        assert!(occupation.is_null(2));
        assert_eq!(occupation.value(3), "Artist");
        assert!(occupation.is_null(4));
        assert!(occupation.is_null(5));

        assert_eq!(dob.value(0), 5675);
        assert!(dob.is_null(1));
        assert!(dob.is_null(2));
        assert_eq!(dob.value(3), -1858);
        assert!(dob.is_null(4));
        assert!(dob.is_null(5));
    }

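    // Truncated rows must still error when the implied nulls would land in
    // non-nullable columns.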
    #[test]
    fn test_truncated_rows_not_nullable_error() {
        let data = "a,b,c\n1,2,3\n4,5";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));

        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(true)
            .build(Cursor::new(data))
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(match batches {
            Err(ArrowError::InvalidArgumentError(e)) =>
                e.to_string().contains("contains null values"),
            _ => false,
        });
    }

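    // Buffered and unbuffered reads should produce identical batches for a
    // range of batch sizes and buffer capacities.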
    #[test]
    fn test_buffered() {
        let tests = [
            ("test/data/uk_cities.csv", false, 37),
            ("test/data/various_types.csv", true, 10),
            ("test/data/decimal_test.csv", false, 10),
        ];

        for (path, has_header, expected_rows) in tests {
            let (schema, _) = Format::default()
                .infer_schema(File::open(path).unwrap(), None)
                .unwrap();
            let schema = Arc::new(schema);

            for batch_size in [1, 4] {
                for capacity in [1, 3, 7, 100] {
                    let reader = ReaderBuilder::new(schema.clone())
                        .with_batch_size(batch_size)
                        .with_header(has_header)
                        .build(File::open(path).unwrap())
                        .unwrap();

                    let expected = reader.collect::<Result<Vec<_>, _>>().unwrap();

                    assert_eq!(
                        expected.iter().map(|x| x.num_rows()).sum::<usize>(),
                        expected_rows
                    );

                    let buffered =
                        std::io::BufReader::with_capacity(capacity, File::open(path).unwrap());

                    let reader = ReaderBuilder::new(schema.clone())
                        .with_batch_size(batch_size)
                        .with_header(has_header)
                        .build_buffered(buffered)
                        .unwrap();

                    let actual = reader.collect::<Result<Vec<_>, _>>().unwrap();
                    assert_eq!(expected, actual)
                }
            }
        }
    }

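    // Asserts that decoding `csv` fails with the `expected` message, for
    // both Utf8 and Utf8View output columns.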
    fn err_test(csv: &[u8], expected: &str) {
        fn err_test_with_schema(csv: &[u8], expected: &str, schema: Arc<Schema>) {
            let buffer = std::io::BufReader::with_capacity(2, Cursor::new(csv));
            let b = ReaderBuilder::new(schema)
                .with_batch_size(2)
                .build_buffered(buffer)
                .unwrap();
            let err = b.collect::<Result<Vec<_>, _>>().unwrap_err().to_string();
            assert_eq!(err, expected)
        }

        let schema_utf8 = Arc::new(Schema::new(vec![
            Field::new("text1", DataType::Utf8, true),
            Field::new("text2", DataType::Utf8, true),
        ]));
        err_test_with_schema(csv, expected, schema_utf8);

        let schema_utf8view = Arc::new(Schema::new(vec![
            Field::new("text1", DataType::Utf8View, true),
            Field::new("text2", DataType::Utf8View, true),
        ]));
        err_test_with_schema(csv, expected, schema_utf8view);
    }

    #[test]
    fn test_invalid_utf8() {
        err_test(
            b"sdf,dsfg\ndfd,hgh\xFFue\n,sds\nFalhghse,",
            "Csv error: Encountered invalid UTF-8 data for line 2 and field 2",
        );

        err_test(
            b"sdf,dsfg\ndksdk,jf\nd\xFFfd,hghue\n,sds\nFalhghse,",
            "Csv error: Encountered invalid UTF-8 data for line 3 and field 1",
        );

        err_test(
            b"sdf,dsfg\ndksdk,jf\ndsdsfd,hghue\n,sds\nFalhghse,\xFF",
            "Csv error: Encountered invalid UTF-8 data for line 5 and field 2",
        );

        err_test(
            b"\xFFsdf,dsfg\ndksdk,jf\ndsdsfd,hghue\n,sds\nFalhghse,\xFF",
            "Csv error: Encountered invalid UTF-8 data for line 1 and field 1",
        );
    }

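    // Read adapter that counts calls to `fill_buf` and records how many
    // bytes each call exposed, so tests can observe the reader's I/O
    // pattern.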
    struct InstrumentedRead<R> {
        r: R,
        fill_count: usize,
        fill_sizes: Vec<usize>,
    }

    impl<R> InstrumentedRead<R> {
        fn new(r: R) -> Self {
            Self {
                r,
                fill_count: 0,
                fill_sizes: vec![],
            }
        }
    }

    impl<R: Seek> Seek for InstrumentedRead<R> {
        fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
            self.r.seek(pos)
        }
    }

    impl<R: BufRead> Read for InstrumentedRead<R> {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            self.r.read(buf)
        }
    }

    impl<R: BufRead> BufRead for InstrumentedRead<R> {
        fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
            self.fill_count += 1;
            let buf = self.r.fill_buf()?;
            self.fill_sizes.push(buf.len());
            Ok(buf)
        }

        fn consume(&mut self, amt: usize) {
            self.r.consume(amt)
        }
    }

    #[test]
    fn test_io() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("b", DataType::Utf8, false),
        ]));
        let csv = "foo,bar\nbaz,foo\na,b\nc,d";
        let mut read = InstrumentedRead::new(Cursor::new(csv.as_bytes()));
        let reader = ReaderBuilder::new(schema)
            .with_batch_size(3)
            .build_buffered(&mut read)
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].num_rows(), 3);
        assert_eq!(batches[1].num_rows(), 1);

        // Four fills are expected: the cursor first exposes all 23 bytes,
        // the first batch consumes the leading three rows (20 bytes), a
        // second fill exposes the remaining 3-byte row, and two empty
        // fills confirm end-of-input.
        assert_eq!(&read.fill_sizes, &[23, 3, 0, 0]);
        assert_eq!(read.fill_count, 4);
    }

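    // Inference promotes across all values seen for a column: integers
    // mixed with floats widen to Float64, dates mixed with timestamps widen
    // to the finest timestamp unit observed, and incompatible mixtures fall
    // back to Utf8.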
    #[test]
    fn test_inference() {
        let cases: &[(&[&str], DataType)] = &[
            (&[], DataType::Null),
            (&["false", "12"], DataType::Utf8),
            (&["12", "cupcakes"], DataType::Utf8),
            (&["12", "12.4"], DataType::Float64),
            (&["14050", "24332"], DataType::Int64),
            (&["14050.0", "true"], DataType::Utf8),
            (&["14050", "2020-03-19 00:00:00"], DataType::Utf8),
            (&["14050", "2340.0", "2020-03-19 00:00:00"], DataType::Utf8),
            (
                &["2020-03-19 02:00:00", "2020-03-19 00:00:00"],
                DataType::Timestamp(TimeUnit::Second, None),
            ),
            (&["2020-03-19", "2020-03-20"], DataType::Date32),
            (
                &["2020-03-19", "2020-03-19 02:00:00", "2020-03-19 00:00:00"],
                DataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                &[
                    "2020-03-19",
                    "2020-03-19 02:00:00",
                    "2020-03-19 00:00:00.000",
                ],
                DataType::Timestamp(TimeUnit::Millisecond, None),
            ),
            (
                &[
                    "2020-03-19",
                    "2020-03-19 02:00:00",
                    "2020-03-19 00:00:00.000000",
                ],
                DataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                &["2020-03-19 02:00:00+02:00", "2020-03-19 02:00:00Z"],
                DataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                &[
                    "2020-03-19",
                    "2020-03-19 02:00:00+02:00",
                    "2020-03-19 02:00:00Z",
                    "2020-03-19 02:00:00.12Z",
                ],
                DataType::Timestamp(TimeUnit::Millisecond, None),
            ),
            (
                &[
                    "2020-03-19",
                    "2020-03-19 02:00:00.000000000",
                    "2020-03-19 00:00:00.000000",
                ],
                DataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
        ];

        for (values, expected) in cases {
            let mut t = InferredDataType::default();
            for v in *values {
                t.update(v)
            }
            assert_eq!(&t.get(), expected, "{values:?}")
        }
    }

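    // Schema inference rejects rows whose field count differs from the
    // header, reporting the offending line number.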
    #[test]
    fn test_record_length_mismatch() {
        let csv = "\
        a,b,c\n\
        1,2,3\n\
        4,5\n\
        6,7,8";
        let mut read = Cursor::new(csv.as_bytes());
        let result = Format::default()
            .with_header(true)
            .infer_schema(&mut read, None);
        assert!(result.is_err());
        // The short row is reported with its line number
        assert_eq!(
            result.err().unwrap().to_string(),
            "Csv error: Encountered unequal lengths between records on CSV file. Expected 3 records, found 2 records at line 3"
        );
    }

    #[test]
    fn test_comment() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int8, false),
            Field::new("b", DataType::Int8, false),
        ]);

        let csv = "# comment1 \n1,2\n#comment2\n11,22";
        let mut read = Cursor::new(csv.as_bytes());
        let reader = ReaderBuilder::new(Arc::new(schema))
            .with_comment(b'#')
            .build(&mut read)
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(batches.len(), 1);
        let b = batches.first().unwrap();
        assert_eq!(b.num_columns(), 2);
        assert_eq!(
            b.column(0)
                .as_any()
                .downcast_ref::<Int8Array>()
                .unwrap()
                .values(),
            &vec![1, 11]
        );
        assert_eq!(
            b.column(1)
                .as_any()
                .downcast_ref::<Int8Array>()
                .unwrap()
                .values(),
            &vec![2, 22]
        );
    }

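    // "something_cannot_be_inlined" exceeds the 12 bytes a view can store
    // inline, so the value must round-trip through a separate data buffer.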
    #[test]
    fn test_parse_string_view_single_column() {
        let csv = ["foo", "something_cannot_be_inlined", "foobar"].join("\n");
        let schema = Arc::new(Schema::new(vec![Field::new(
            "c1",
            DataType::Utf8View,
            true,
        )]));

        let mut decoder = ReaderBuilder::new(schema).build_decoder();

        let decoded = decoder.decode(csv.as_bytes()).unwrap();
        assert_eq!(decoded, csv.len());
        decoder.decode(&[]).unwrap();

        let batch = decoder.flush().unwrap().unwrap();
        assert_eq!(batch.num_columns(), 1);
        assert_eq!(batch.num_rows(), 3);
        let col = batch.column(0).as_string_view();
        assert_eq!(col.data_type(), &DataType::Utf8View);
        assert_eq!(col.value(0), "foo");
        assert_eq!(col.value(1), "something_cannot_be_inlined");
        assert_eq!(col.value(2), "foobar");
    }

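    // As above, but with two view columns and interleaved nulls from empty
    // fields.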
    #[test]
    fn test_parse_string_view_multi_column() {
        let csv = ["foo,", ",something_cannot_be_inlined", "foobarfoobar,bar"].join("\n");
        let schema = Arc::new(Schema::new(vec![
            Field::new("c1", DataType::Utf8View, true),
            Field::new("c2", DataType::Utf8View, true),
        ]));

        let mut decoder = ReaderBuilder::new(schema).build_decoder();

        let decoded = decoder.decode(csv.as_bytes()).unwrap();
        assert_eq!(decoded, csv.len());
        decoder.decode(&[]).unwrap();

        let batch = decoder.flush().unwrap().unwrap();
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.num_rows(), 3);
        let c1 = batch.column(0).as_string_view();
        let c2 = batch.column(1).as_string_view();
        assert_eq!(c1.data_type(), &DataType::Utf8View);
        assert_eq!(c2.data_type(), &DataType::Utf8View);

        assert!(!c1.is_null(0));
        assert!(c1.is_null(1));
        assert!(!c1.is_null(2));
        assert_eq!(c1.value(0), "foo");
        assert_eq!(c1.value(2), "foobarfoobar");

        assert!(c2.is_null(0));
        assert!(!c2.is_null(1));
        assert!(!c2.is_null(2));
        assert_eq!(c2.value(1), "something_cannot_be_inlined");
        assert_eq!(c2.value(2), "bar");
    }
}