1mod records;
127
128use arrow_array::builder::{NullBuilder, PrimitiveBuilder};
129use arrow_array::types::*;
130use arrow_array::*;
131use arrow_cast::parse::{Parser, parse_decimal, string_to_datetime};
132use arrow_schema::*;
133use chrono::{TimeZone, Utc};
134use csv::StringRecord;
135use regex::{Regex, RegexSet};
136use std::fmt::{self, Debug};
137use std::fs::File;
138use std::io::{BufRead, BufReader as StdBufReader, Read};
139use std::sync::{Arc, LazyLock};
140
141use crate::map_csv_error;
142use crate::reader::records::{RecordDecoder, StringRecords};
143use arrow_array::timezone::Tz;
144
145static REGEX_SET: LazyLock<RegexSet> = LazyLock::new(|| {
147 RegexSet::new([
148 r"(?i)^(true)$|^(false)$(?-i)", r"^-?(\d+)$", r"^-?((\d*\.\d+|\d+\.\d*)([eE][-+]?\d+)?|\d+([eE][-+]?\d+))$", r"^\d{4}-\d\d-\d\d$", r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d(?:[^\d\.].*)?$", r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,3}(?:[^\d].*)?$", r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,6}(?:[^\d].*)?$", r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,9}(?:[^\d].*)?$", ])
157 .unwrap()
158});
159
160#[derive(Debug, Clone, Default)]
162struct NullRegex(Option<Regex>);
163
164impl NullRegex {
165 #[inline]
168 fn is_null(&self, s: &str) -> bool {
169 match &self.0 {
170 Some(r) => r.is_match(s),
171 None => s.is_empty(),
172 }
173 }
174}
175
/// Bitset accumulating every candidate type observed for one column during
/// schema inference.
#[derive(Default, Copy, Clone)]
struct InferredDataType {
    // One bit per candidate, matching the pattern order of `REGEX_SET`:
    // 0 = boolean, 1 = integer, 2 = float, 3 = date,
    // 4..=7 = timestamps (second/milli/micro/nano), 8 = fallback string.
    packed: u16,
}
191
192impl InferredDataType {
193 fn get(&self) -> DataType {
195 match self.packed {
196 0 => DataType::Null,
197 1 => DataType::Boolean,
198 2 => DataType::Int64,
199 4 | 6 => DataType::Float64, b if b != 0 && (b & !0b11111000) == 0 => match b.leading_zeros() {
201 8 => DataType::Timestamp(TimeUnit::Nanosecond, None),
203 9 => DataType::Timestamp(TimeUnit::Microsecond, None),
204 10 => DataType::Timestamp(TimeUnit::Millisecond, None),
205 11 => DataType::Timestamp(TimeUnit::Second, None),
206 12 => DataType::Date32,
207 _ => unreachable!(),
208 },
209 _ => DataType::Utf8,
210 }
211 }
212
213 fn update(&mut self, string: &str) {
215 self.packed |= if string.starts_with('"') {
216 1 << 8 } else if let Some(m) = REGEX_SET.matches(string).into_iter().next() {
218 if m == 1 && string.len() >= 19 && string.parse::<i64>().is_err() {
219 1 << 8
221 } else {
222 1 << m
223 }
224 } else if string == "NaN" || string == "nan" || string == "inf" || string == "-inf" {
225 1 << 2 } else {
227 1 << 8 }
229 }
230}
231
/// CSV parsing options shared by the readers and schema inference.
#[derive(Debug, Clone, Default)]
pub struct Format {
    // Treat the first row as a header
    header: bool,
    // Field delimiter override passed to the csv reader
    delimiter: Option<u8>,
    // Escape character override passed to the csv reader
    escape: Option<u8>,
    // Quote character override passed to the csv reader
    quote: Option<u8>,
    // Record terminator override passed to the csv reader
    terminator: Option<u8>,
    // Comment marker override passed to the csv reader
    comment: Option<u8>,
    // Regex deciding which values decode as null (empty string when unset)
    null_regex: NullRegex,
    // Allow rows with fewer fields than the schema
    truncated_rows: bool,
}
244
245impl Format {
246 pub fn with_header(mut self, has_header: bool) -> Self {
250 self.header = has_header;
251 self
252 }
253
254 pub fn with_delimiter(mut self, delimiter: u8) -> Self {
256 self.delimiter = Some(delimiter);
257 self
258 }
259
260 pub fn with_escape(mut self, escape: u8) -> Self {
262 self.escape = Some(escape);
263 self
264 }
265
266 pub fn with_quote(mut self, quote: u8) -> Self {
268 self.quote = Some(quote);
269 self
270 }
271
272 pub fn with_terminator(mut self, terminator: u8) -> Self {
274 self.terminator = Some(terminator);
275 self
276 }
277
278 pub fn with_comment(mut self, comment: u8) -> Self {
282 self.comment = Some(comment);
283 self
284 }
285
286 pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
288 self.null_regex = NullRegex(Some(null_regex));
289 self
290 }
291
292 pub fn with_truncated_rows(mut self, allow: bool) -> Self {
299 self.truncated_rows = allow;
300 self
301 }
302
303 pub fn infer_schema<R: Read>(
310 &self,
311 reader: R,
312 max_records: Option<usize>,
313 ) -> Result<(Schema, usize), ArrowError> {
314 let mut csv_reader = self.build_reader(reader);
315
316 let headers: Vec<String> = if self.header {
319 let headers = &csv_reader.headers().map_err(map_csv_error)?.clone();
320 headers.iter().map(|s| s.to_string()).collect()
321 } else {
322 let first_record_count = &csv_reader.headers().map_err(map_csv_error)?.len();
323 (0..*first_record_count)
324 .map(|i| format!("column_{}", i + 1))
325 .collect()
326 };
327
328 let header_length = headers.len();
329 let mut column_types: Vec<InferredDataType> = vec![Default::default(); header_length];
331
332 let mut records_count = 0;
333
334 let mut record = StringRecord::new();
335 let max_records = max_records.unwrap_or(usize::MAX);
336 while records_count < max_records {
337 if !csv_reader.read_record(&mut record).map_err(map_csv_error)? {
338 break;
339 }
340 records_count += 1;
341
342 for (i, column_type) in column_types.iter_mut().enumerate().take(header_length) {
345 if let Some(string) = record.get(i) {
346 if !self.null_regex.is_null(string) {
347 column_type.update(string)
348 }
349 }
350 }
351 }
352
353 let fields: Fields = column_types
355 .iter()
356 .zip(&headers)
357 .map(|(inferred, field_name)| Field::new(field_name, inferred.get(), true))
358 .collect();
359
360 Ok((Schema::new(fields), records_count))
361 }
362
363 fn build_reader<R: Read>(&self, reader: R) -> csv::Reader<R> {
365 let mut builder = csv::ReaderBuilder::new();
366 builder.has_headers(self.header);
367 builder.flexible(self.truncated_rows);
368
369 if let Some(c) = self.delimiter {
370 builder.delimiter(c);
371 }
372 builder.escape(self.escape);
373 if let Some(c) = self.quote {
374 builder.quote(c);
375 }
376 if let Some(t) = self.terminator {
377 builder.terminator(csv::Terminator::Any(t));
378 }
379 if let Some(comment) = self.comment {
380 builder.comment(Some(comment));
381 }
382 builder.from_reader(reader)
383 }
384
385 fn build_parser(&self) -> csv_core::Reader {
387 let mut builder = csv_core::ReaderBuilder::new();
388 builder.escape(self.escape);
389 builder.comment(self.comment);
390
391 if let Some(c) = self.delimiter {
392 builder.delimiter(c);
393 }
394 if let Some(c) = self.quote {
395 builder.quote(c);
396 }
397 if let Some(t) = self.terminator {
398 builder.terminator(csv_core::Terminator::Any(t));
399 }
400 builder.build()
401 }
402}
403
404pub fn infer_schema_from_files(
411 files: &[String],
412 delimiter: u8,
413 max_read_records: Option<usize>,
414 has_header: bool,
415) -> Result<Schema, ArrowError> {
416 let mut schemas = vec![];
417 let mut records_to_read = max_read_records.unwrap_or(usize::MAX);
418 let format = Format {
419 delimiter: Some(delimiter),
420 header: has_header,
421 ..Default::default()
422 };
423
424 for fname in files.iter() {
425 let f = File::open(fname)?;
426 let (schema, records_read) = format.infer_schema(f, Some(records_to_read))?;
427 if records_read == 0 {
428 continue;
429 }
430 schemas.push(schema.clone());
431 records_to_read -= records_read;
432 if records_to_read == 0 {
433 break;
434 }
435 }
436
437 Schema::try_merge(schemas)
438}
439
// Optional `(start, end)` row bounds, applied after any header row
type Bounds = Option<(usize, usize)>;

/// CSV file reader over an unbuffered source, wrapped in [`std::io::BufReader`]
pub type Reader<R> = BufReader<StdBufReader<R>>;
445
/// CSV reader over any buffered byte source
pub struct BufReader<R> {
    // Underlying buffered byte stream
    reader: R,

    // Push-based decoder turning raw bytes into RecordBatches
    decoder: Decoder,
}
454
455impl<R> fmt::Debug for BufReader<R>
456where
457 R: BufRead,
458{
459 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
460 f.debug_struct("Reader")
461 .field("decoder", &self.decoder)
462 .finish()
463 }
464}
465
466impl<R: Read> Reader<R> {
467 pub fn schema(&self) -> SchemaRef {
470 match &self.decoder.projection {
471 Some(projection) => {
472 let fields = self.decoder.schema.fields();
473 let projected = projection.iter().map(|i| fields[*i].clone());
474 Arc::new(Schema::new(projected.collect::<Fields>()))
475 }
476 None => self.decoder.schema.clone(),
477 }
478 }
479}
480
impl<R: BufRead> BufReader<R> {
    /// Feeds buffered bytes into the decoder until a batch is full or the
    /// input is exhausted, then flushes the buffered rows as a [`RecordBatch`].
    fn read(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        loop {
            let buf = self.reader.fill_buf()?;
            let decoded = self.decoder.decode(buf)?;
            // Only mark as consumed what the decoder actually used
            self.reader.consume(decoded);
            // `decoded == 0` signals end of input; zero remaining capacity
            // signals the current batch is full — flush in either case
            if decoded == 0 || self.decoder.capacity() == 0 {
                break;
            }
        }

        self.decoder.flush()
    }
}
500
501impl<R: BufRead> Iterator for BufReader<R> {
502 type Item = Result<RecordBatch, ArrowError>;
503
504 fn next(&mut self) -> Option<Self::Item> {
505 self.read().transpose()
506 }
507}
508
509impl<R: BufRead> RecordBatchReader for BufReader<R> {
510 fn schema(&self) -> SchemaRef {
511 self.decoder.schema.clone()
512 }
513}
514
/// A push-based decoder turning CSV bytes into [`RecordBatch`]es
#[derive(Debug)]
pub struct Decoder {
    // Schema of the decoded data
    schema: SchemaRef,

    // Optional column indices to project from decoded rows
    projection: Option<Vec<usize>>,

    // Maximum number of rows per emitted RecordBatch
    batch_size: usize,

    // Rows still to skip before decoding begins (header + bound start)
    to_skip: usize,

    // Line number of the next row to decode; used for error reporting and
    // for enforcing `end`
    line_number: usize,

    // Rows at or beyond this line are not decoded (usize::MAX when unbounded)
    end: usize,

    // Low-level record decoder
    record_decoder: RecordDecoder,

    // Regex deciding which values decode as null
    null_regex: NullRegex,
}
580
impl Decoder {
    /// Decode records from `buf`, returning the number of bytes consumed.
    ///
    /// Call repeatedly with successive chunks of input; a return of `Ok(0)`
    /// or a [`Self::capacity`] of 0 means the caller should [`Self::flush`].
    pub fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
        if self.to_skip != 0 {
            // Skip in steps of at most batch_size to bound buffer growth
            let to_skip = self.to_skip.min(self.batch_size);
            let (skipped, bytes) = self.record_decoder.decode(buf, to_skip)?;
            self.to_skip -= skipped;
            // Discard the skipped records instead of emitting them
            self.record_decoder.clear();
            return Ok(bytes);
        }

        // Never buffer past the configured end bound or the batch size
        let to_read = self.batch_size.min(self.end - self.line_number) - self.record_decoder.len();
        let (_, bytes) = self.record_decoder.decode(buf, to_read)?;
        Ok(bytes)
    }

    /// Flushes the currently buffered rows as a [`RecordBatch`], returning
    /// `None` when nothing is buffered.
    pub fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        if self.record_decoder.is_empty() {
            return Ok(None);
        }

        let rows = self.record_decoder.flush()?;
        let batch = parse(
            &rows,
            self.schema.fields(),
            Some(self.schema.metadata.clone()),
            self.projection.as_ref(),
            self.line_number,
            &self.null_regex,
        )?;
        // Advance past the rows just emitted so error messages stay accurate
        self.line_number += rows.len();
        Ok(Some(batch))
    }

    /// Returns the number of rows that can still be buffered before the
    /// current batch is full
    pub fn capacity(&self) -> usize {
        self.batch_size - self.record_decoder.len()
    }
}
635
636fn parse(
638 rows: &StringRecords<'_>,
639 fields: &Fields,
640 metadata: Option<std::collections::HashMap<String, String>>,
641 projection: Option<&Vec<usize>>,
642 line_number: usize,
643 null_regex: &NullRegex,
644) -> Result<RecordBatch, ArrowError> {
645 let projection: Vec<usize> = match projection {
646 Some(v) => v.clone(),
647 None => fields.iter().enumerate().map(|(i, _)| i).collect(),
648 };
649
650 let arrays: Result<Vec<ArrayRef>, _> = projection
651 .iter()
652 .map(|i| {
653 let i = *i;
654 let field = &fields[i];
655 match field.data_type() {
656 DataType::Boolean => build_boolean_array(line_number, rows, i, null_regex),
657 DataType::Decimal32(precision, scale) => build_decimal_array::<Decimal32Type>(
658 line_number,
659 rows,
660 i,
661 *precision,
662 *scale,
663 null_regex,
664 ),
665 DataType::Decimal64(precision, scale) => build_decimal_array::<Decimal64Type>(
666 line_number,
667 rows,
668 i,
669 *precision,
670 *scale,
671 null_regex,
672 ),
673 DataType::Decimal128(precision, scale) => build_decimal_array::<Decimal128Type>(
674 line_number,
675 rows,
676 i,
677 *precision,
678 *scale,
679 null_regex,
680 ),
681 DataType::Decimal256(precision, scale) => build_decimal_array::<Decimal256Type>(
682 line_number,
683 rows,
684 i,
685 *precision,
686 *scale,
687 null_regex,
688 ),
689 DataType::Int8 => {
690 build_primitive_array::<Int8Type>(line_number, rows, i, null_regex)
691 }
692 DataType::Int16 => {
693 build_primitive_array::<Int16Type>(line_number, rows, i, null_regex)
694 }
695 DataType::Int32 => {
696 build_primitive_array::<Int32Type>(line_number, rows, i, null_regex)
697 }
698 DataType::Int64 => {
699 build_primitive_array::<Int64Type>(line_number, rows, i, null_regex)
700 }
701 DataType::UInt8 => {
702 build_primitive_array::<UInt8Type>(line_number, rows, i, null_regex)
703 }
704 DataType::UInt16 => {
705 build_primitive_array::<UInt16Type>(line_number, rows, i, null_regex)
706 }
707 DataType::UInt32 => {
708 build_primitive_array::<UInt32Type>(line_number, rows, i, null_regex)
709 }
710 DataType::UInt64 => {
711 build_primitive_array::<UInt64Type>(line_number, rows, i, null_regex)
712 }
713 DataType::Float32 => {
714 build_primitive_array::<Float32Type>(line_number, rows, i, null_regex)
715 }
716 DataType::Float64 => {
717 build_primitive_array::<Float64Type>(line_number, rows, i, null_regex)
718 }
719 DataType::Date32 => {
720 build_primitive_array::<Date32Type>(line_number, rows, i, null_regex)
721 }
722 DataType::Date64 => {
723 build_primitive_array::<Date64Type>(line_number, rows, i, null_regex)
724 }
725 DataType::Time32(TimeUnit::Second) => {
726 build_primitive_array::<Time32SecondType>(line_number, rows, i, null_regex)
727 }
728 DataType::Time32(TimeUnit::Millisecond) => {
729 build_primitive_array::<Time32MillisecondType>(line_number, rows, i, null_regex)
730 }
731 DataType::Time64(TimeUnit::Microsecond) => {
732 build_primitive_array::<Time64MicrosecondType>(line_number, rows, i, null_regex)
733 }
734 DataType::Time64(TimeUnit::Nanosecond) => {
735 build_primitive_array::<Time64NanosecondType>(line_number, rows, i, null_regex)
736 }
737 DataType::Timestamp(TimeUnit::Second, tz) => {
738 build_timestamp_array::<TimestampSecondType>(
739 line_number,
740 rows,
741 i,
742 tz.as_deref(),
743 null_regex,
744 )
745 }
746 DataType::Timestamp(TimeUnit::Millisecond, tz) => {
747 build_timestamp_array::<TimestampMillisecondType>(
748 line_number,
749 rows,
750 i,
751 tz.as_deref(),
752 null_regex,
753 )
754 }
755 DataType::Timestamp(TimeUnit::Microsecond, tz) => {
756 build_timestamp_array::<TimestampMicrosecondType>(
757 line_number,
758 rows,
759 i,
760 tz.as_deref(),
761 null_regex,
762 )
763 }
764 DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
765 build_timestamp_array::<TimestampNanosecondType>(
766 line_number,
767 rows,
768 i,
769 tz.as_deref(),
770 null_regex,
771 )
772 }
773 DataType::Null => Ok(Arc::new({
774 let mut builder = NullBuilder::new();
775 builder.append_nulls(rows.len());
776 builder.finish()
777 }) as ArrayRef),
778 DataType::Utf8 => Ok(Arc::new(
779 rows.iter()
780 .map(|row| {
781 let s = row.get(i);
782 (!null_regex.is_null(s)).then_some(s)
783 })
784 .collect::<StringArray>(),
785 ) as ArrayRef),
786 DataType::Utf8View => Ok(Arc::new(
787 rows.iter()
788 .map(|row| {
789 let s = row.get(i);
790 (!null_regex.is_null(s)).then_some(s)
791 })
792 .collect::<StringViewArray>(),
793 ) as ArrayRef),
794 DataType::Dictionary(key_type, value_type)
795 if value_type.as_ref() == &DataType::Utf8 =>
796 {
797 match key_type.as_ref() {
798 DataType::Int8 => Ok(Arc::new(
799 rows.iter()
800 .map(|row| {
801 let s = row.get(i);
802 (!null_regex.is_null(s)).then_some(s)
803 })
804 .collect::<DictionaryArray<Int8Type>>(),
805 ) as ArrayRef),
806 DataType::Int16 => Ok(Arc::new(
807 rows.iter()
808 .map(|row| {
809 let s = row.get(i);
810 (!null_regex.is_null(s)).then_some(s)
811 })
812 .collect::<DictionaryArray<Int16Type>>(),
813 ) as ArrayRef),
814 DataType::Int32 => Ok(Arc::new(
815 rows.iter()
816 .map(|row| {
817 let s = row.get(i);
818 (!null_regex.is_null(s)).then_some(s)
819 })
820 .collect::<DictionaryArray<Int32Type>>(),
821 ) as ArrayRef),
822 DataType::Int64 => Ok(Arc::new(
823 rows.iter()
824 .map(|row| {
825 let s = row.get(i);
826 (!null_regex.is_null(s)).then_some(s)
827 })
828 .collect::<DictionaryArray<Int64Type>>(),
829 ) as ArrayRef),
830 DataType::UInt8 => Ok(Arc::new(
831 rows.iter()
832 .map(|row| {
833 let s = row.get(i);
834 (!null_regex.is_null(s)).then_some(s)
835 })
836 .collect::<DictionaryArray<UInt8Type>>(),
837 ) as ArrayRef),
838 DataType::UInt16 => Ok(Arc::new(
839 rows.iter()
840 .map(|row| {
841 let s = row.get(i);
842 (!null_regex.is_null(s)).then_some(s)
843 })
844 .collect::<DictionaryArray<UInt16Type>>(),
845 ) as ArrayRef),
846 DataType::UInt32 => Ok(Arc::new(
847 rows.iter()
848 .map(|row| {
849 let s = row.get(i);
850 (!null_regex.is_null(s)).then_some(s)
851 })
852 .collect::<DictionaryArray<UInt32Type>>(),
853 ) as ArrayRef),
854 DataType::UInt64 => Ok(Arc::new(
855 rows.iter()
856 .map(|row| {
857 let s = row.get(i);
858 (!null_regex.is_null(s)).then_some(s)
859 })
860 .collect::<DictionaryArray<UInt64Type>>(),
861 ) as ArrayRef),
862 _ => Err(ArrowError::ParseError(format!(
863 "Unsupported dictionary key type {key_type}"
864 ))),
865 }
866 }
867 other => Err(ArrowError::ParseError(format!(
868 "Unsupported data type {other:?}"
869 ))),
870 }
871 })
872 .collect();
873
874 let projected_fields: Fields = projection.iter().map(|i| fields[*i].clone()).collect();
875
876 let projected_schema = Arc::new(match metadata {
877 None => Schema::new(projected_fields),
878 Some(metadata) => Schema::new_with_metadata(projected_fields, metadata),
879 });
880
881 arrays.and_then(|arr| {
882 RecordBatch::try_new_with_options(
883 projected_schema,
884 arr,
885 &RecordBatchOptions::new()
886 .with_match_field_names(true)
887 .with_row_count(Some(rows.len())),
888 )
889 })
890}
891
/// Parses `string` as a boolean, accepting `true`/`false` in any ASCII case.
fn parse_bool(string: &str) -> Option<bool> {
    match string {
        s if s.eq_ignore_ascii_case("true") => Some(true),
        s if s.eq_ignore_ascii_case("false") => Some(false),
        _ => None,
    }
}
901
902fn build_decimal_array<T: DecimalType>(
904 _line_number: usize,
905 rows: &StringRecords<'_>,
906 col_idx: usize,
907 precision: u8,
908 scale: i8,
909 null_regex: &NullRegex,
910) -> Result<ArrayRef, ArrowError> {
911 let mut decimal_builder = PrimitiveBuilder::<T>::with_capacity(rows.len());
912 for row in rows.iter() {
913 let s = row.get(col_idx);
914 if null_regex.is_null(s) {
915 decimal_builder.append_null();
917 } else {
918 let decimal_value: Result<T::Native, _> = parse_decimal::<T>(s, precision, scale);
919 match decimal_value {
920 Ok(v) => {
921 decimal_builder.append_value(v);
922 }
923 Err(e) => {
924 return Err(e);
925 }
926 }
927 }
928 }
929 Ok(Arc::new(
930 decimal_builder
931 .finish()
932 .with_precision_and_scale(precision, scale)?,
933 ))
934}
935
936fn build_primitive_array<T: ArrowPrimitiveType + Parser>(
938 line_number: usize,
939 rows: &StringRecords<'_>,
940 col_idx: usize,
941 null_regex: &NullRegex,
942) -> Result<ArrayRef, ArrowError> {
943 rows.iter()
944 .enumerate()
945 .map(|(row_index, row)| {
946 let s = row.get(col_idx);
947 if null_regex.is_null(s) {
948 return Ok(None);
949 }
950
951 match T::parse(s) {
952 Some(e) => Ok(Some(e)),
953 None => Err(ArrowError::ParseError(format!(
954 "Error while parsing value '{}' as type '{}' for column {} at line {}. Row data: '{}'",
956 s,
957 T::DATA_TYPE,
958 col_idx,
959 line_number + row_index,
960 row
961 ))),
962 }
963 })
964 .collect::<Result<PrimitiveArray<T>, ArrowError>>()
965 .map(|e| Arc::new(e) as ArrayRef)
966}
967
968fn build_timestamp_array<T: ArrowTimestampType>(
969 line_number: usize,
970 rows: &StringRecords<'_>,
971 col_idx: usize,
972 timezone: Option<&str>,
973 null_regex: &NullRegex,
974) -> Result<ArrayRef, ArrowError> {
975 Ok(Arc::new(match timezone {
976 Some(timezone) => {
977 let tz: Tz = timezone.parse()?;
978 build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &tz, null_regex)?
979 .with_timezone(timezone)
980 }
981 None => build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &Utc, null_regex)?,
982 }))
983}
984
/// Parses the values of column `col_idx` into a timestamp
/// [`PrimitiveArray<T>`], resolving zone-less values against `timezone`.
///
/// Values matching `null_regex` become nulls; unparseable values yield a
/// `ParseError` carrying the offending line number.
fn build_timestamp_array_impl<T: ArrowTimestampType, Tz: TimeZone>(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    timezone: &Tz,
    null_regex: &NullRegex,
) -> Result<PrimitiveArray<T>, ArrowError> {
    rows.iter()
        .enumerate()
        .map(|(row_index, row)| {
            let s = row.get(col_idx);
            if null_regex.is_null(s) {
                return Ok(None);
            }

            let date = string_to_datetime(timezone, s)
                // Convert to the integer representation for the target unit
                .and_then(|date| match T::UNIT {
                    TimeUnit::Second => Ok(date.timestamp()),
                    TimeUnit::Millisecond => Ok(date.timestamp_millis()),
                    TimeUnit::Microsecond => Ok(date.timestamp_micros()),
                    // Nanoseconds can overflow i64, so surface that explicitly
                    TimeUnit::Nanosecond => date.timestamp_nanos_opt().ok_or_else(|| {
                        ArrowError::ParseError(format!(
                            "{} would overflow 64-bit signed nanoseconds",
                            date.to_rfc3339(),
                        ))
                    }),
                })
                // Attach column/line context to whatever error occurred above
                .map_err(|e| {
                    ArrowError::ParseError(format!(
                        "Error parsing column {col_idx} at line {}: {}",
                        line_number + row_index,
                        e
                    ))
                })?;
            Ok(Some(date))
        })
        .collect()
}
1023
1024fn build_boolean_array(
1026 line_number: usize,
1027 rows: &StringRecords<'_>,
1028 col_idx: usize,
1029 null_regex: &NullRegex,
1030) -> Result<ArrayRef, ArrowError> {
1031 rows.iter()
1032 .enumerate()
1033 .map(|(row_index, row)| {
1034 let s = row.get(col_idx);
1035 if null_regex.is_null(s) {
1036 return Ok(None);
1037 }
1038 let parsed = parse_bool(s);
1039 match parsed {
1040 Some(e) => Ok(Some(e)),
1041 None => Err(ArrowError::ParseError(format!(
1042 "Error while parsing value '{}' as type '{}' for column {} at line {}. Row data: '{}'",
1044 s,
1045 "Boolean",
1046 col_idx,
1047 line_number + row_index,
1048 row
1049 ))),
1050 }
1051 })
1052 .collect::<Result<BooleanArray, _>>()
1053 .map(|e| Arc::new(e) as ArrayRef)
1054}
1055
/// Builds a CSV [`Reader`], [`BufReader`] or [`Decoder`] from configuration options
#[derive(Debug)]
pub struct ReaderBuilder {
    // Schema of the CSV file being read
    schema: SchemaRef,
    // Parsing options
    format: Format,
    // Maximum number of rows per RecordBatch (defaults to 1024)
    batch_size: usize,
    // Optional `(start, end)` row bounds, applied after any header row
    bounds: Bounds,
    // Optional column indices to project
    projection: Option<Vec<usize>>,
}
1072
impl ReaderBuilder {
    /// Create a new builder for `schema`, with a default batch size of 1024
    /// rows and no header, bounds or projection configured.
    pub fn new(schema: SchemaRef) -> ReaderBuilder {
        Self {
            schema,
            format: Format::default(),
            batch_size: 1024,
            bounds: None,
            projection: None,
        }
    }

    /// Set whether the CSV file has a header row
    pub fn with_header(mut self, has_header: bool) -> Self {
        self.format.header = has_header;
        self
    }

    /// Replace all parsing options with the given [`Format`]
    pub fn with_format(mut self, format: Format) -> Self {
        self.format = format;
        self
    }

    /// Set the field delimiter
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.format.delimiter = Some(delimiter);
        self
    }

    /// Set the escape character
    pub fn with_escape(mut self, escape: u8) -> Self {
        self.format.escape = Some(escape);
        self
    }

    /// Set the quote character
    pub fn with_quote(mut self, quote: u8) -> Self {
        self.format.quote = Some(quote);
        self
    }

    /// Set the record terminator
    pub fn with_terminator(mut self, terminator: u8) -> Self {
        self.format.terminator = Some(terminator);
        self
    }

    /// Set a comment character; lines starting with it are ignored
    pub fn with_comment(mut self, comment: u8) -> Self {
        self.format.comment = Some(comment);
        self
    }

    /// Set a regex deciding which values should decode as null
    pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
        self.format.null_regex = NullRegex(Some(null_regex));
        self
    }

    /// Set the maximum number of rows per emitted [`RecordBatch`]
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Restrict reading to the rows in `[start, end)`, counted after any header
    pub fn with_bounds(mut self, start: usize, end: usize) -> Self {
        self.bounds = Some((start, end));
        self
    }

    /// Project only the columns at the given indices
    pub fn with_projection(mut self, projection: Vec<usize>) -> Self {
        self.projection = Some(projection);
        self
    }

    /// Allow rows with fewer fields than the schema (missing fields decode as null)
    pub fn with_truncated_rows(mut self, allow: bool) -> Self {
        self.format.truncated_rows = allow;
        self
    }

    /// Build a [`Reader`], wrapping `reader` in a [`std::io::BufReader`]
    pub fn build<R: Read>(self, reader: R) -> Result<Reader<R>, ArrowError> {
        self.build_buffered(StdBufReader::new(reader))
    }

    /// Build a [`BufReader`] from an already-buffered source
    pub fn build_buffered<R: BufRead>(self, reader: R) -> Result<BufReader<R>, ArrowError> {
        Ok(BufReader {
            reader,
            decoder: self.build_decoder(),
        })
    }

    /// Build a push-based [`Decoder`] from this configuration
    pub fn build_decoder(self) -> Decoder {
        let delimiter = self.format.build_parser();
        let record_decoder = RecordDecoder::new(
            delimiter,
            self.schema.fields().len(),
            self.format.truncated_rows,
        );

        // Shift the bounds by one row when a header row must be skipped
        let header = self.format.header as usize;

        let (start, end) = match self.bounds {
            Some((start, end)) => (start + header, end + header),
            None => (header, usize::MAX),
        };

        Decoder {
            schema: self.schema,
            to_skip: start,
            record_decoder,
            line_number: start,
            end,
            projection: self.projection,
            batch_size: self.batch_size,
            null_regex: self.format.null_regex,
        }
    }
}
1227
1228#[cfg(test)]
1229mod tests {
1230 use super::*;
1231
1232 use std::io::{Cursor, Seek, SeekFrom, Write};
1233 use tempfile::NamedTempFile;
1234
1235 use arrow_array::cast::AsArray;
1236
1237 #[test]
1238 fn test_csv() {
1239 let schema = Arc::new(Schema::new(vec![
1240 Field::new("city", DataType::Utf8, false),
1241 Field::new("lat", DataType::Float64, false),
1242 Field::new("lng", DataType::Float64, false),
1243 ]));
1244
1245 let file = File::open("test/data/uk_cities.csv").unwrap();
1246 let mut csv = ReaderBuilder::new(schema.clone()).build(file).unwrap();
1247 assert_eq!(schema, csv.schema());
1248 let batch = csv.next().unwrap().unwrap();
1249 assert_eq!(37, batch.num_rows());
1250 assert_eq!(3, batch.num_columns());
1251
1252 let lat = batch.column(1).as_primitive::<Float64Type>();
1254 assert_eq!(57.653484, lat.value(0));
1255
1256 let city = batch.column(0).as_string::<i32>();
1258
1259 assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
1260 }
1261
1262 #[test]
1263 fn test_csv_schema_metadata() {
1264 let mut metadata = std::collections::HashMap::new();
1265 metadata.insert("foo".to_owned(), "bar".to_owned());
1266 let schema = Arc::new(Schema::new_with_metadata(
1267 vec![
1268 Field::new("city", DataType::Utf8, false),
1269 Field::new("lat", DataType::Float64, false),
1270 Field::new("lng", DataType::Float64, false),
1271 ],
1272 metadata.clone(),
1273 ));
1274
1275 let file = File::open("test/data/uk_cities.csv").unwrap();
1276
1277 let mut csv = ReaderBuilder::new(schema.clone()).build(file).unwrap();
1278 assert_eq!(schema, csv.schema());
1279 let batch = csv.next().unwrap().unwrap();
1280 assert_eq!(37, batch.num_rows());
1281 assert_eq!(3, batch.num_columns());
1282
1283 assert_eq!(&metadata, batch.schema().metadata());
1284 }
1285
1286 #[test]
1287 fn test_csv_reader_with_decimal() {
1288 let schema = Arc::new(Schema::new(vec![
1289 Field::new("city", DataType::Utf8, false),
1290 Field::new("lat", DataType::Decimal128(38, 6), false),
1291 Field::new("lng", DataType::Decimal256(76, 6), false),
1292 ]));
1293
1294 let file = File::open("test/data/decimal_test.csv").unwrap();
1295
1296 let mut csv = ReaderBuilder::new(schema).build(file).unwrap();
1297 let batch = csv.next().unwrap().unwrap();
1298 let lat = batch
1300 .column(1)
1301 .as_any()
1302 .downcast_ref::<Decimal128Array>()
1303 .unwrap();
1304
1305 assert_eq!("57.653484", lat.value_as_string(0));
1306 assert_eq!("53.002666", lat.value_as_string(1));
1307 assert_eq!("52.412811", lat.value_as_string(2));
1308 assert_eq!("51.481583", lat.value_as_string(3));
1309 assert_eq!("12.123456", lat.value_as_string(4));
1310 assert_eq!("50.760000", lat.value_as_string(5));
1311 assert_eq!("0.123000", lat.value_as_string(6));
1312 assert_eq!("123.000000", lat.value_as_string(7));
1313 assert_eq!("123.000000", lat.value_as_string(8));
1314 assert_eq!("-50.760000", lat.value_as_string(9));
1315
1316 let lng = batch
1317 .column(2)
1318 .as_any()
1319 .downcast_ref::<Decimal256Array>()
1320 .unwrap();
1321
1322 assert_eq!("-3.335724", lng.value_as_string(0));
1323 assert_eq!("-2.179404", lng.value_as_string(1));
1324 assert_eq!("-1.778197", lng.value_as_string(2));
1325 assert_eq!("-3.179090", lng.value_as_string(3));
1326 assert_eq!("-3.179090", lng.value_as_string(4));
1327 assert_eq!("0.290472", lng.value_as_string(5));
1328 assert_eq!("0.290472", lng.value_as_string(6));
1329 assert_eq!("0.290472", lng.value_as_string(7));
1330 assert_eq!("0.290472", lng.value_as_string(8));
1331 assert_eq!("0.290472", lng.value_as_string(9));
1332 }
1333
1334 #[test]
1335 fn test_csv_reader_with_decimal_3264() {
1336 let schema = Arc::new(Schema::new(vec![
1337 Field::new("city", DataType::Utf8, false),
1338 Field::new("lat", DataType::Decimal32(9, 6), false),
1339 Field::new("lng", DataType::Decimal64(16, 6), false),
1340 ]));
1341
1342 let file = File::open("test/data/decimal_test.csv").unwrap();
1343
1344 let mut csv = ReaderBuilder::new(schema).build(file).unwrap();
1345 let batch = csv.next().unwrap().unwrap();
1346 let lat = batch
1348 .column(1)
1349 .as_any()
1350 .downcast_ref::<Decimal32Array>()
1351 .unwrap();
1352
1353 assert_eq!("57.653484", lat.value_as_string(0));
1354 assert_eq!("53.002666", lat.value_as_string(1));
1355 assert_eq!("52.412811", lat.value_as_string(2));
1356 assert_eq!("51.481583", lat.value_as_string(3));
1357 assert_eq!("12.123456", lat.value_as_string(4));
1358 assert_eq!("50.760000", lat.value_as_string(5));
1359 assert_eq!("0.123000", lat.value_as_string(6));
1360 assert_eq!("123.000000", lat.value_as_string(7));
1361 assert_eq!("123.000000", lat.value_as_string(8));
1362 assert_eq!("-50.760000", lat.value_as_string(9));
1363
1364 let lng = batch
1365 .column(2)
1366 .as_any()
1367 .downcast_ref::<Decimal64Array>()
1368 .unwrap();
1369
1370 assert_eq!("-3.335724", lng.value_as_string(0));
1371 assert_eq!("-2.179404", lng.value_as_string(1));
1372 assert_eq!("-1.778197", lng.value_as_string(2));
1373 assert_eq!("-3.179090", lng.value_as_string(3));
1374 assert_eq!("-3.179090", lng.value_as_string(4));
1375 assert_eq!("0.290472", lng.value_as_string(5));
1376 assert_eq!("0.290472", lng.value_as_string(6));
1377 assert_eq!("0.290472", lng.value_as_string(7));
1378 assert_eq!("0.290472", lng.value_as_string(8));
1379 assert_eq!("0.290472", lng.value_as_string(9));
1380 }
1381
1382 #[test]
1383 fn test_csv_from_buf_reader() {
1384 let schema = Schema::new(vec![
1385 Field::new("city", DataType::Utf8, false),
1386 Field::new("lat", DataType::Float64, false),
1387 Field::new("lng", DataType::Float64, false),
1388 ]);
1389
1390 let file_with_headers = File::open("test/data/uk_cities_with_headers.csv").unwrap();
1391 let file_without_headers = File::open("test/data/uk_cities.csv").unwrap();
1392 let both_files = file_with_headers
1393 .chain(Cursor::new("\n".to_string()))
1394 .chain(file_without_headers);
1395 let mut csv = ReaderBuilder::new(Arc::new(schema))
1396 .with_header(true)
1397 .build(both_files)
1398 .unwrap();
1399 let batch = csv.next().unwrap().unwrap();
1400 assert_eq!(74, batch.num_rows());
1401 assert_eq!(3, batch.num_columns());
1402 }
1403
1404 #[test]
1405 fn test_csv_with_schema_inference() {
1406 let mut file = File::open("test/data/uk_cities_with_headers.csv").unwrap();
1407
1408 let (schema, _) = Format::default()
1409 .with_header(true)
1410 .infer_schema(&mut file, None)
1411 .unwrap();
1412
1413 file.rewind().unwrap();
1414 let builder = ReaderBuilder::new(Arc::new(schema)).with_header(true);
1415
1416 let mut csv = builder.build(file).unwrap();
1417 let expected_schema = Schema::new(vec![
1418 Field::new("city", DataType::Utf8, true),
1419 Field::new("lat", DataType::Float64, true),
1420 Field::new("lng", DataType::Float64, true),
1421 ]);
1422 assert_eq!(Arc::new(expected_schema), csv.schema());
1423 let batch = csv.next().unwrap().unwrap();
1424 assert_eq!(37, batch.num_rows());
1425 assert_eq!(3, batch.num_columns());
1426
1427 let lat = batch
1429 .column(1)
1430 .as_any()
1431 .downcast_ref::<Float64Array>()
1432 .unwrap();
1433 assert_eq!(57.653484, lat.value(0));
1434
1435 let city = batch
1437 .column(0)
1438 .as_any()
1439 .downcast_ref::<StringArray>()
1440 .unwrap();
1441
1442 assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
1443 }
1444
    #[test]
    fn test_csv_with_schema_inference_no_headers() {
        // Infer a schema from a header-less CSV and verify the reader
        // synthesizes `column_N` field names and reads all rows.
        let mut file = File::open("test/data/uk_cities.csv").unwrap();

        let (schema, _) = Format::default().infer_schema(&mut file, None).unwrap();
        // Inference consumed the file; rewind so the reader starts at byte 0.
        file.rewind().unwrap();

        let mut csv = ReaderBuilder::new(Arc::new(schema)).build(file).unwrap();

        let schema = csv.schema();
        // Without headers, fields get generated one-based names.
        assert_eq!("column_1", schema.field(0).name());
        assert_eq!("column_2", schema.field(1).name());
        assert_eq!("column_3", schema.field(2).name());
        let batch = csv.next().unwrap().unwrap();
        let batch_schema = batch.schema();

        // The batch must carry the same schema the reader reports.
        assert_eq!(schema, batch_schema);
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        // access data from a primitive array
        let lat = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(57.653484, lat.value(0));

        // access data from a string array (ListArray<u8>)
        let city = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
    }
1483
    #[test]
    fn test_csv_builder_with_bounds() {
        // `with_bounds(0, 2)` should limit the reader to the first two rows.
        let mut file = File::open("test/data/uk_cities.csv").unwrap();

        let (schema, _) = Format::default().infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();
        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_bounds(0, 2)
            .build(file)
            .unwrap();
        let batch = csv.next().unwrap().unwrap();

        // access data from a string array (ListArray<u8>)
        let city = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        assert_eq!("Elgin, Scotland, the UK", city.value(0));

        // Row 13 exists in the file but lies outside the bounds, so indexing
        // it must panic — catch_unwind turns that panic into an Err.
        let result = std::panic::catch_unwind(|| city.value(13));
        assert!(result.is_err());
    }
1512
1513 #[test]
1514 fn test_csv_with_projection() {
1515 let schema = Arc::new(Schema::new(vec![
1516 Field::new("city", DataType::Utf8, false),
1517 Field::new("lat", DataType::Float64, false),
1518 Field::new("lng", DataType::Float64, false),
1519 ]));
1520
1521 let file = File::open("test/data/uk_cities.csv").unwrap();
1522
1523 let mut csv = ReaderBuilder::new(schema)
1524 .with_projection(vec![0, 1])
1525 .build(file)
1526 .unwrap();
1527
1528 let projected_schema = Arc::new(Schema::new(vec![
1529 Field::new("city", DataType::Utf8, false),
1530 Field::new("lat", DataType::Float64, false),
1531 ]));
1532 assert_eq!(projected_schema, csv.schema());
1533 let batch = csv.next().unwrap().unwrap();
1534 assert_eq!(projected_schema, batch.schema());
1535 assert_eq!(37, batch.num_rows());
1536 assert_eq!(2, batch.num_columns());
1537 }
1538
    #[test]
    fn test_csv_with_dictionary() {
        // A Utf8 CSV column can be read directly into a dictionary-encoded
        // array; verify projection and the decoded string values.
        let schema = Arc::new(Schema::new(vec![
            Field::new_dictionary("city", DataType::Int32, DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]));

        let file = File::open("test/data/uk_cities.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_projection(vec![0, 1])
            .build(file)
            .unwrap();

        let projected_schema = Arc::new(Schema::new(vec![
            Field::new_dictionary("city", DataType::Int32, DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
        ]));
        assert_eq!(projected_schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(projected_schema, batch.schema());
        assert_eq!(37, batch.num_rows());
        assert_eq!(2, batch.num_columns());

        // Cast the dictionary column back to plain strings to check values.
        let strings = arrow_cast::cast(batch.column(0), &DataType::Utf8).unwrap();
        let strings = strings.as_string::<i32>();

        assert_eq!(strings.value(0), "Elgin, Scotland, the UK");
        assert_eq!(strings.value(4), "Eastbourne, East Sussex, UK");
        assert_eq!(strings.value(29), "Uckfield, East Sussex, UK");
    }
1571
1572 #[test]
1573 fn test_csv_with_nullable_dictionary() {
1574 let offset_type = vec![
1575 DataType::Int8,
1576 DataType::Int16,
1577 DataType::Int32,
1578 DataType::Int64,
1579 DataType::UInt8,
1580 DataType::UInt16,
1581 DataType::UInt32,
1582 DataType::UInt64,
1583 ];
1584 for data_type in offset_type {
1585 let file = File::open("test/data/dictionary_nullable_test.csv").unwrap();
1586 let dictionary_type =
1587 DataType::Dictionary(Box::new(data_type), Box::new(DataType::Utf8));
1588 let schema = Arc::new(Schema::new(vec![
1589 Field::new("id", DataType::Utf8, false),
1590 Field::new("name", dictionary_type.clone(), true),
1591 ]));
1592
1593 let mut csv = ReaderBuilder::new(schema)
1594 .build(file.try_clone().unwrap())
1595 .unwrap();
1596
1597 let batch = csv.next().unwrap().unwrap();
1598 assert_eq!(3, batch.num_rows());
1599 assert_eq!(2, batch.num_columns());
1600
1601 let names = arrow_cast::cast(batch.column(1), &dictionary_type).unwrap();
1602 assert!(!names.is_null(2));
1603 assert!(names.is_null(1));
1604 }
1605 }
    #[test]
    fn test_nulls() {
        // Empty CSV fields become nulls in nullable columns; column 1
        // (c_float) of null_test.csv has a null only at row 2.
        let schema = Arc::new(Schema::new(vec![
            Field::new("c_int", DataType::UInt64, false),
            Field::new("c_float", DataType::Float32, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, false),
        ]));

        let file = File::open("test/data/null_test.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_header(true)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        assert!(!batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }
1630
    #[test]
    fn test_init_nulls() {
        // Like test_nulls, but the very first value of the column is null
        // (exercises null handling at builder initialization), and a fully
        // null `DataType::Null` column is present.
        let schema = Arc::new(Schema::new(vec![
            Field::new("c_int", DataType::UInt64, true),
            Field::new("c_float", DataType::Float32, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
            Field::new("c_null", DataType::Null, true),
        ]));
        let file = File::open("test/data/init_null_test.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_header(true)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        assert!(batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }
1655
    #[test]
    fn test_init_nulls_with_inference() {
        // Schema inference on a file whose first data row contains nulls:
        // the all-null column must be inferred as DataType::Null and every
        // field must be nullable.
        let format = Format::default().with_header(true).with_delimiter(b',');

        let mut file = File::open("test/data/init_null_test.csv").unwrap();
        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let expected_schema = Schema::new(vec![
            Field::new("c_int", DataType::Int64, true),
            Field::new("c_float", DataType::Float64, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
            Field::new("c_null", DataType::Null, true),
        ]);
        assert_eq!(schema, expected_schema);

        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        // Null positions in column 1 (c_float) match the fixture.
        assert!(batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }
1686
    #[test]
    fn test_custom_nulls() {
        // A custom null regex ("nil") replaces the default empty-string null
        // marker; fields matching it decode as nulls in any column type.
        let schema = Arc::new(Schema::new(vec![
            Field::new("c_int", DataType::UInt64, true),
            Field::new("c_float", DataType::Float32, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
        ]));

        let file = File::open("test/data/custom_null_test.csv").unwrap();

        let null_regex = Regex::new("^nil$").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_header(true)
            .with_null_regex(null_regex)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        // "nil" values at various (column, row) positions become null.
        assert!(batch.column(0).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(batch.column(3).is_null(4));
        assert!(batch.column(2).is_null(3));
        assert!(!batch.column(2).is_null(4));
    }
1715
    #[test]
    fn test_nulls_with_inference() {
        // Infer a schema from a '|'-delimited file containing several types
        // plus nulls; verify inferred types, names, nullability, and the
        // null positions in the float column.
        let mut file = File::open("test/data/various_types.csv").unwrap();
        let format = Format::default().with_header(true).with_delimiter(b'|');

        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3, 4, 5]);

        let mut csv = builder.build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();

        assert_eq!(10, batch.num_rows());
        assert_eq!(6, batch.num_columns());

        let schema = batch.schema();

        // Inferred column types, including date and second-precision
        // timestamp detection.
        assert_eq!(&DataType::Int64, schema.field(0).data_type());
        assert_eq!(&DataType::Float64, schema.field(1).data_type());
        assert_eq!(&DataType::Float64, schema.field(2).data_type());
        assert_eq!(&DataType::Boolean, schema.field(3).data_type());
        assert_eq!(&DataType::Date32, schema.field(4).data_type());
        assert_eq!(
            &DataType::Timestamp(TimeUnit::Second, None),
            schema.field(5).data_type()
        );

        let names: Vec<&str> = schema.fields().iter().map(|x| x.name().as_str()).collect();
        assert_eq!(
            names,
            vec![
                "c_int",
                "c_float",
                "c_string",
                "c_bool",
                "c_date",
                "c_datetime"
            ]
        );

        // Inference always produces nullable fields.
        assert!(schema.field(0).is_nullable());
        assert!(schema.field(1).is_nullable());
        assert!(schema.field(2).is_nullable());
        assert!(schema.field(3).is_nullable());
        assert!(schema.field(4).is_nullable());
        assert!(schema.field(5).is_nullable());

        assert!(!batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }
1773
    #[test]
    fn test_custom_nulls_with_inference() {
        // Schema inference must honor a custom null regex: "nil" fields are
        // treated as nulls rather than forcing columns to Utf8.
        let mut file = File::open("test/data/custom_null_test.csv").unwrap();

        let null_regex = Regex::new("^nil$").unwrap();

        let format = Format::default()
            .with_header(true)
            .with_null_regex(null_regex);

        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let expected_schema = Schema::new(vec![
            Field::new("c_int", DataType::Int64, true),
            Field::new("c_float", DataType::Float64, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
        ]);

        assert_eq!(schema, expected_schema);

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3]);

        let mut csv = builder.build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();

        assert_eq!(5, batch.num_rows());
        assert_eq!(4, batch.num_columns());

        assert_eq!(batch.schema().as_ref(), &expected_schema);
    }
1809
    #[test]
    fn test_scientific_notation_with_inference() {
        // Values in scientific notation (e.g. 1e3) must infer as Float64.
        let mut file = File::open("test/data/scientific_notation_test.csv").unwrap();
        let format = Format::default().with_header(false).with_delimiter(b',');

        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .with_batch_size(512)
            .with_projection(vec![0, 1]);

        let mut csv = builder.build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();

        let schema = batch.schema();

        assert_eq!(&DataType::Float64, schema.field(0).data_type());
    }
1830
    /// Reads the given '|'-delimited CSV file with a fixed four-column
    /// schema and returns the error message produced by the first batch.
    ///
    /// Used by the test_parse_invalid_csv_* tests to check parse-error
    /// reporting for each column type.
    fn invalid_csv_helper(file_name: &str) -> String {
        let file = File::open(file_name).unwrap();
        let schema = Schema::new(vec![
            Field::new("c_int", DataType::UInt64, false),
            Field::new("c_float", DataType::Float32, false),
            Field::new("c_string", DataType::Utf8, false),
            Field::new("c_bool", DataType::Boolean, false),
        ]);

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(true)
            .with_delimiter(b'|')
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3]);

        let mut csv = builder.build(file).unwrap();

        // The fixture is expected to fail parsing; surface the message.
        csv.next().unwrap().unwrap_err().to_string()
    }
1850
    #[test]
    fn test_parse_invalid_csv_float() {
        // A malformed float value must report the value, target type,
        // column, line number, and the offending row.
        let file_name = "test/data/various_invalid_types/invalid_float.csv";

        let error = invalid_csv_helper(file_name);
        assert_eq!(
            "Parser error: Error while parsing value '4.x4' as type 'Float32' for column 1 at line 4. Row data: '[4,4.x4,,false]'",
            error
        );
    }
1861
    #[test]
    fn test_parse_invalid_csv_int() {
        // A float value in an unsigned-integer column must produce a
        // descriptive parse error.
        let file_name = "test/data/various_invalid_types/invalid_int.csv";

        let error = invalid_csv_helper(file_name);
        assert_eq!(
            "Parser error: Error while parsing value '2.3' as type 'UInt64' for column 0 at line 2. Row data: '[2.3,2.2,2.22,false]'",
            error
        );
    }
1872
    #[test]
    fn test_parse_invalid_csv_bool() {
        // A non-boolean token in a Boolean column must produce a
        // descriptive parse error.
        let file_name = "test/data/various_invalid_types/invalid_bool.csv";

        let error = invalid_csv_helper(file_name);
        assert_eq!(
            "Parser error: Error while parsing value 'none' as type 'Boolean' for column 3 at line 2. Row data: '[2,2.2,2.22,none]'",
            error
        );
    }
1883
1884 fn infer_field_schema(string: &str) -> DataType {
1886 let mut v = InferredDataType::default();
1887 v.update(string);
1888 v.get()
1889 }
1890
    #[test]
    fn test_infer_field_schema() {
        // Single-value type inference: strings, ints, floats (including
        // NaN/inf and scientific notation), booleans, dates and timestamps
        // at varying sub-second precision.
        assert_eq!(infer_field_schema("A"), DataType::Utf8);
        // Quoted numbers stay strings.
        assert_eq!(infer_field_schema("\"123\""), DataType::Utf8);
        assert_eq!(infer_field_schema("10"), DataType::Int64);
        assert_eq!(infer_field_schema("10.2"), DataType::Float64);
        assert_eq!(infer_field_schema(".2"), DataType::Float64);
        assert_eq!(infer_field_schema("2."), DataType::Float64);
        assert_eq!(infer_field_schema("NaN"), DataType::Float64);
        assert_eq!(infer_field_schema("nan"), DataType::Float64);
        assert_eq!(infer_field_schema("inf"), DataType::Float64);
        assert_eq!(infer_field_schema("-inf"), DataType::Float64);
        assert_eq!(infer_field_schema("true"), DataType::Boolean);
        // Boolean matching is case-insensitive.
        assert_eq!(infer_field_schema("trUe"), DataType::Boolean);
        assert_eq!(infer_field_schema("false"), DataType::Boolean);
        assert_eq!(infer_field_schema("2020-11-08"), DataType::Date32);
        assert_eq!(
            infer_field_schema("2020-11-08T14:20:01"),
            DataType::Timestamp(TimeUnit::Second, None)
        );
        // Both 'T' and ' ' date/time separators are accepted.
        assert_eq!(
            infer_field_schema("2020-11-08 14:20:01"),
            DataType::Timestamp(TimeUnit::Second, None)
        );
        assert_eq!(
            infer_field_schema("2020-11-08 14:20:01"),
            DataType::Timestamp(TimeUnit::Second, None)
        );
        assert_eq!(infer_field_schema("-5.13"), DataType::Float64);
        assert_eq!(infer_field_schema("0.1300"), DataType::Float64);
        // Fractional-second digits select the timestamp precision.
        assert_eq!(
            infer_field_schema("2021-12-19 13:12:30.921"),
            DataType::Timestamp(TimeUnit::Millisecond, None)
        );
        assert_eq!(
            infer_field_schema("2021-12-19T13:12:30.123456789"),
            DataType::Timestamp(TimeUnit::Nanosecond, None)
        );
        // First value starts with an en-dash (U+2013), not an ASCII minus,
        // so it cannot be numeric; the second overflows i64 — both fall
        // back to Utf8.
        assert_eq!(infer_field_schema("–9223372036854775809"), DataType::Utf8);
        assert_eq!(infer_field_schema("9223372036854775808"), DataType::Utf8);
    }
1932
    #[test]
    fn parse_date32() {
        // Date32 is days since the Unix epoch; negative before 1970.
        assert_eq!(Date32Type::parse("1970-01-01").unwrap(), 0);
        assert_eq!(Date32Type::parse("2020-03-15").unwrap(), 18336);
        assert_eq!(Date32Type::parse("1945-05-08").unwrap(), -9004);
    }
1939
    #[test]
    fn parse_time() {
        // 12-hour clock with AM/PM (any case) parses into time-since-midnight
        // at the unit of each Time32/Time64 type.
        assert_eq!(
            Time64NanosecondType::parse("12:10:01.123456789 AM"),
            Some(601_123_456_789)
        );
        assert_eq!(
            Time64MicrosecondType::parse("12:10:01.123456 am"),
            Some(601_123_456)
        );
        assert_eq!(
            Time32MillisecondType::parse("2:10:01.12 PM"),
            Some(51_001_120)
        );
        assert_eq!(Time32SecondType::parse("2:10:01 pm"), Some(51_001));
    }
1956
    #[test]
    fn parse_date64() {
        // Date64 is milliseconds since the Unix epoch.
        assert_eq!(Date64Type::parse("1970-01-01T00:00:00").unwrap(), 0);
        assert_eq!(
            Date64Type::parse("2018-11-13T17:11:10").unwrap(),
            1542129070000
        );
        assert_eq!(
            Date64Type::parse("2018-11-13T17:11:10.011").unwrap(),
            1542129070011
        );
        assert_eq!(
            Date64Type::parse("1900-02-28T12:34:56").unwrap(),
            -2203932304000
        );
        // parse_formatted takes an explicit chrono format string; a %z
        // offset shifts the result toward UTC.
        assert_eq!(
            Date64Type::parse_formatted("1900-02-28 12:34:56", "%Y-%m-%d %H:%M:%S").unwrap(),
            -2203932304000
        );
        assert_eq!(
            Date64Type::parse_formatted("1900-02-28 12:34:56+0030", "%Y-%m-%d %H:%M:%S%z").unwrap(),
            -2203932304000 - (30 * 60 * 1000)
        );
    }
1981
    /// Decodes three timestamp strings (no offset, 'Z', and '+02:00') into a
    /// `Timestamp(T::UNIT, timezone)` column via the push-based `Decoder`,
    /// asserting the raw values match `expected`.
    fn test_parse_timestamp_impl<T: ArrowTimestampType>(
        timezone: Option<Arc<str>>,
        expected: &[i64],
    ) {
        let csv = [
            "1970-01-01T00:00:00",
            "1970-01-01T00:00:00Z",
            "1970-01-01T00:00:00+02:00",
        ]
        .join("\n");
        let schema = Arc::new(Schema::new(vec![Field::new(
            "field",
            DataType::Timestamp(T::UNIT, timezone.clone()),
            true,
        )]));

        let mut decoder = ReaderBuilder::new(schema).build_decoder();

        // All bytes should be consumed in one call.
        let decoded = decoder.decode(csv.as_bytes()).unwrap();
        assert_eq!(decoded, csv.len());
        // An empty slice signals end-of-input to the decoder.
        decoder.decode(&[]).unwrap();

        let batch = decoder.flush().unwrap().unwrap();
        assert_eq!(batch.num_columns(), 1);
        assert_eq!(batch.num_rows(), 3);
        let col = batch.column(0).as_primitive::<T>();
        assert_eq!(col.values(), expected);
        // The column keeps the requested timezone in its type.
        assert_eq!(col.data_type(), &DataType::Timestamp(T::UNIT, timezone));
    }
2011
    #[test]
    fn test_parse_timestamp() {
        // Offset-less strings are interpreted in the column timezone, so the
        // first expected value shifts with it; explicit 'Z'/'+02:00' inputs
        // always normalize to the same instant.
        test_parse_timestamp_impl::<TimestampNanosecondType>(None, &[0, 0, -7_200_000_000_000]);
        test_parse_timestamp_impl::<TimestampNanosecondType>(
            Some("+00:00".into()),
            &[0, 0, -7_200_000_000_000],
        );
        test_parse_timestamp_impl::<TimestampNanosecondType>(
            Some("-05:00".into()),
            &[18_000_000_000_000, 0, -7_200_000_000_000],
        );
        test_parse_timestamp_impl::<TimestampMicrosecondType>(
            Some("-03".into()),
            &[10_800_000_000, 0, -7_200_000_000],
        );
        test_parse_timestamp_impl::<TimestampMillisecondType>(
            Some("-03".into()),
            &[10_800_000, 0, -7_200_000],
        );
        test_parse_timestamp_impl::<TimestampSecondType>(Some("-03".into()), &[10_800, 0, -7_200]);
    }
2033
    #[test]
    fn test_infer_schema_from_multiple_files() {
        // Merging schemas across files: csv3 is empty, csv2 adds a fourth
        // column, and c3 mixes ints/floats — the union must be 4 nullable
        // columns with promoted types.
        let mut csv1 = NamedTempFile::new().unwrap();
        let mut csv2 = NamedTempFile::new().unwrap();
        let csv3 = NamedTempFile::new().unwrap(); let mut csv4 = NamedTempFile::new().unwrap();
        writeln!(csv1, "c1,c2,c3").unwrap();
        writeln!(csv1, "1,\"foo\",0.5").unwrap();
        writeln!(csv1, "3,\"bar\",1").unwrap();
        writeln!(csv1, "3,\"bar\",2e-06").unwrap();
        writeln!(csv2, "c1,c2,c3,c4").unwrap();
        writeln!(csv2, "10,,3.14,true").unwrap();
        // csv4 has a trailing empty field in c3.
        writeln!(csv4, "c1,c2,c3").unwrap();
        writeln!(csv4, "10,\"foo\",").unwrap();

        let schema = infer_schema_from_files(
            &[
                csv3.path().to_str().unwrap().to_string(),
                csv1.path().to_str().unwrap().to_string(),
                csv2.path().to_str().unwrap().to_string(),
                csv4.path().to_str().unwrap().to_string(),
            ],
            b',',
            // Read at most 4 records per file; files have headers.
            Some(4), true,
        )
        .unwrap();

        assert_eq!(schema.fields().len(), 4);
        assert!(schema.field(0).is_nullable());
        assert!(schema.field(1).is_nullable());
        assert!(schema.field(2).is_nullable());
        assert!(schema.field(3).is_nullable());

        // c3 mixes integer and float values, so it promotes to Float64.
        assert_eq!(&DataType::Int64, schema.field(0).data_type());
        assert_eq!(&DataType::Utf8, schema.field(1).data_type());
        assert_eq!(&DataType::Float64, schema.field(2).data_type());
        assert_eq!(&DataType::Boolean, schema.field(3).data_type());
    }
2075
2076 #[test]
2077 fn test_bounded() {
2078 let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]);
2079 let data = [
2080 vec!["0"],
2081 vec!["1"],
2082 vec!["2"],
2083 vec!["3"],
2084 vec!["4"],
2085 vec!["5"],
2086 vec!["6"],
2087 ];
2088
2089 let data = data
2090 .iter()
2091 .map(|x| x.join(","))
2092 .collect::<Vec<_>>()
2093 .join("\n");
2094 let data = data.as_bytes();
2095
2096 let reader = std::io::Cursor::new(data);
2097
2098 let mut csv = ReaderBuilder::new(Arc::new(schema))
2099 .with_batch_size(2)
2100 .with_projection(vec![0])
2101 .with_bounds(2, 6)
2102 .build_buffered(reader)
2103 .unwrap();
2104
2105 let batch = csv.next().unwrap().unwrap();
2106 let a = batch.column(0);
2107 let a = a.as_any().downcast_ref::<UInt32Array>().unwrap();
2108 assert_eq!(a, &UInt32Array::from(vec![2, 3]));
2109
2110 let batch = csv.next().unwrap().unwrap();
2111 let a = batch.column(0);
2112 let a = a.as_any().downcast_ref::<UInt32Array>().unwrap();
2113 assert_eq!(a, &UInt32Array::from(vec![4, 5]));
2114
2115 assert!(csv.next().is_none());
2116 }
2117
2118 #[test]
2119 fn test_empty_projection() {
2120 let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]);
2121 let data = [vec!["0"], vec!["1"]];
2122
2123 let data = data
2124 .iter()
2125 .map(|x| x.join(","))
2126 .collect::<Vec<_>>()
2127 .join("\n");
2128
2129 let mut csv = ReaderBuilder::new(Arc::new(schema))
2130 .with_batch_size(2)
2131 .with_projection(vec![])
2132 .build_buffered(Cursor::new(data.as_bytes()))
2133 .unwrap();
2134
2135 let batch = csv.next().unwrap().unwrap();
2136 assert_eq!(batch.columns().len(), 0);
2137 assert_eq!(batch.num_rows(), 2);
2138
2139 assert!(csv.next().is_none());
2140 }
2141
2142 #[test]
2143 fn test_parsing_bool() {
2144 assert_eq!(Some(true), parse_bool("true"));
2146 assert_eq!(Some(true), parse_bool("tRUe"));
2147 assert_eq!(Some(true), parse_bool("True"));
2148 assert_eq!(Some(true), parse_bool("TRUE"));
2149 assert_eq!(None, parse_bool("t"));
2150 assert_eq!(None, parse_bool("T"));
2151 assert_eq!(None, parse_bool(""));
2152
2153 assert_eq!(Some(false), parse_bool("false"));
2154 assert_eq!(Some(false), parse_bool("fALse"));
2155 assert_eq!(Some(false), parse_bool("False"));
2156 assert_eq!(Some(false), parse_bool("FALSE"));
2157 assert_eq!(None, parse_bool("f"));
2158 assert_eq!(None, parse_bool("F"));
2159 assert_eq!(None, parse_bool(""));
2160 }
2161
    #[test]
    fn test_parsing_float() {
        // Float parsing: plain decimals, bare integers, leading/trailing
        // dots, NaN/inf in either case, and rejection of malformed input.
        assert_eq!(Some(12.34), Float64Type::parse("12.34"));
        assert_eq!(Some(-12.34), Float64Type::parse("-12.34"));
        assert_eq!(Some(12.0), Float64Type::parse("12"));
        assert_eq!(Some(0.0), Float64Type::parse("0"));
        assert_eq!(Some(2.0), Float64Type::parse("2."));
        assert_eq!(Some(0.2), Float64Type::parse(".2"));
        assert!(Float64Type::parse("nan").unwrap().is_nan());
        assert!(Float64Type::parse("NaN").unwrap().is_nan());
        assert!(Float64Type::parse("inf").unwrap().is_infinite());
        assert!(Float64Type::parse("inf").unwrap().is_sign_positive());
        assert!(Float64Type::parse("-inf").unwrap().is_infinite());
        assert!(Float64Type::parse("-inf").unwrap().is_sign_negative());
        assert_eq!(None, Float64Type::parse(""));
        assert_eq!(None, Float64Type::parse("dd"));
        assert_eq!(None, Float64Type::parse("12.34.56"));
    }
2180
    #[test]
    fn test_non_std_quote() {
        // A non-standard quote character ('~') must be honored when parsing
        // quoted fields.
        let schema = Schema::new(vec![
            Field::new("text1", DataType::Utf8, false),
            Field::new("text2", DataType::Utf8, false),
        ]);
        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(false)
            .with_quote(b'~'); let mut csv_text = Vec::new();
        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
        // Generate 10 rows of '~'-quoted fields with CRLF line endings.
        for index in 0..10 {
            let text1 = format!("id{index:}");
            let text2 = format!("value{index:}");
            csv_writer
                .write_fmt(format_args!("~{text1}~,~{text2}~\r\n"))
                .unwrap();
        }
        let mut csv_reader = std::io::Cursor::new(&csv_text);
        let mut reader = builder.build(&mut csv_reader).unwrap();
        let batch = reader.next().unwrap().unwrap();
        let col0 = batch.column(0);
        assert_eq!(col0.len(), 10);
        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
        // Quotes are stripped from the decoded values.
        assert_eq!(col0_arr.value(0), "id0");
        let col1 = batch.column(1);
        assert_eq!(col1.len(), 10);
        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col1_arr.value(5), "value5");
    }
2212
    #[test]
    fn test_non_std_escape() {
        // An explicit escape character ('\') inside quoted fields must be
        // resolved, yielding the escaped quote in the decoded value.
        let schema = Schema::new(vec![
            Field::new("text1", DataType::Utf8, false),
            Field::new("text2", DataType::Utf8, false),
        ]);
        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(false)
            .with_escape(b'\\'); let mut csv_text = Vec::new();
        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
        // Each text2 value contains an escaped double quote: value\"N.
        for index in 0..10 {
            let text1 = format!("id{index:}");
            let text2 = format!("value\\\"{index:}");
            csv_writer
                .write_fmt(format_args!("\"{text1}\",\"{text2}\"\r\n"))
                .unwrap();
        }
        let mut csv_reader = std::io::Cursor::new(&csv_text);
        let mut reader = builder.build(&mut csv_reader).unwrap();
        let batch = reader.next().unwrap().unwrap();
        let col0 = batch.column(0);
        assert_eq!(col0.len(), 10);
        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col0_arr.value(0), "id0");
        let col1 = batch.column(1);
        assert_eq!(col1.len(), 10);
        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
        // The escape is consumed; the literal quote remains.
        assert_eq!(col1_arr.value(5), "value\"5");
    }
2244
    #[test]
    fn test_non_std_terminator() {
        // An explicit record terminator ('\n' instead of the default CRLF
        // handling) must still parse LF-terminated rows.
        let schema = Schema::new(vec![
            Field::new("text1", DataType::Utf8, false),
            Field::new("text2", DataType::Utf8, false),
        ]);
        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(false)
            .with_terminator(b'\n'); let mut csv_text = Vec::new();
        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
        // Generate 10 LF-terminated rows of quoted fields.
        for index in 0..10 {
            let text1 = format!("id{index:}");
            let text2 = format!("value{index:}");
            csv_writer
                .write_fmt(format_args!("\"{text1}\",\"{text2}\"\n"))
                .unwrap();
        }
        let mut csv_reader = std::io::Cursor::new(&csv_text);
        let mut reader = builder.build(&mut csv_reader).unwrap();
        let batch = reader.next().unwrap().unwrap();
        let col0 = batch.column(0);
        assert_eq!(col0.len(), 10);
        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col0_arr.value(0), "id0");
        let col1 = batch.column(1);
        assert_eq!(col1.len(), 10);
        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col1_arr.value(5), "value5");
    }
2276
    #[test]
    fn test_header_bounds() {
        // Interaction of row bounds with the header flag: each case is
        // (bounds, has_header, expected row count) over a 5-line input.
        let csv = "a,b\na,b\na,b\na,b\na,b\n";
        let tests = [
            (None, false, 5),
            (None, true, 4),
            (Some((0, 4)), false, 4),
            (Some((1, 4)), false, 3),
            (Some((0, 4)), true, 4),
            (Some((1, 4)), true, 3),
        ];
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("a", DataType::Utf8, false),
        ]));

        for (idx, (bounds, has_header, expected)) in tests.into_iter().enumerate() {
            let mut reader = ReaderBuilder::new(schema.clone()).with_header(has_header);
            if let Some((start, end)) = bounds {
                reader = reader.with_bounds(start, end);
            }
            let b = reader
                .build_buffered(Cursor::new(csv.as_bytes()))
                .unwrap()
                .next()
                .unwrap()
                .unwrap();
            // Include the case index in the failure message.
            assert_eq!(b.num_rows(), expected, "{idx}");
        }
    }
2307
    #[test]
    fn test_null_boolean() {
        // Boolean parsing is case-insensitive and empty fields become nulls
        // in nullable Boolean columns.
        let csv = "true,false\nFalse,True\n,True\nFalse,";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Boolean, true),
            Field::new("a", DataType::Boolean, true),
        ]));

        let b = ReaderBuilder::new(schema)
            .build_buffered(Cursor::new(csv.as_bytes()))
            .unwrap()
            .next()
            .unwrap()
            .unwrap();

        assert_eq!(b.num_rows(), 4);
        assert_eq!(b.num_columns(), 2);

        // Column 0: null at row 2 (empty field).
        let c = b.column(0).as_boolean();
        assert_eq!(c.null_count(), 1);
        assert!(c.value(0));
        assert!(!c.value(1));
        assert!(c.is_null(2));
        assert!(!c.value(3));

        // Column 1: null at row 3 (trailing empty field).
        let c = b.column(1).as_boolean();
        assert_eq!(c.null_count(), 1);
        assert!(!c.value(0));
        assert!(c.value(1));
        assert!(c.value(2));
        assert!(c.is_null(3));
    }
2340
    #[test]
    fn test_truncated_rows() {
        // With truncated_rows(true), a short row ("4,5") is padded with
        // nulls; with it disabled, the same input must error out.
        let data = "a,b,c\n1,2,3\n4,5\n\n6,7,8";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
            Field::new("c", DataType::Int32, true),
        ]));

        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(true)
            .build(Cursor::new(data))
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(batches.is_ok());
        let batch = batches.unwrap().into_iter().next().unwrap();
        // Three data rows survive (the blank line is skipped).
        assert_eq!(batch.num_rows(), 3);

        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(false)
            .build(Cursor::new(data))
            .unwrap();

        // Strict mode reports a field-count mismatch as a CSV error.
        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(match batches {
            Err(ArrowError::CsvError(e)) => e.to_string().contains("incorrect number of fields"),
            _ => false,
        });
    }
2374
    #[test]
    fn test_truncated_rows_csv() {
        // Read a fixture whose rows have varying lengths with
        // truncated_rows(true): missing trailing fields become nulls.
        let file = File::open("test/data/truncated_rows.csv").unwrap();
        let schema = Arc::new(Schema::new(vec![
            Field::new("Name", DataType::Utf8, true),
            Field::new("Age", DataType::UInt32, true),
            Field::new("Occupation", DataType::Utf8, true),
            Field::new("DOB", DataType::Date32, true),
        ]));
        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_batch_size(24)
            .with_truncated_rows(true);
        let csv = reader.build(file).unwrap();
        let batches = csv.collect::<Result<Vec<_>, _>>().unwrap();

        // Batch size 24 exceeds the row count, so all rows land in one batch.
        assert_eq!(batches.len(), 1);
        let batch = &batches[0];
        assert_eq!(batch.num_rows(), 6);
        assert_eq!(batch.num_columns(), 4);
        let name = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let age = batch
            .column(1)
            .as_any()
            .downcast_ref::<UInt32Array>()
            .unwrap();
        let occupation = batch
            .column(2)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let dob = batch
            .column(3)
            .as_any()
            .downcast_ref::<Date32Array>()
            .unwrap();

        // Row 2 is entirely empty; later rows are progressively truncated.
        assert_eq!(name.value(0), "A1");
        assert_eq!(name.value(1), "B2");
        assert!(name.is_null(2));
        assert_eq!(name.value(3), "C3");
        assert_eq!(name.value(4), "D4");
        assert_eq!(name.value(5), "E5");

        assert_eq!(age.value(0), 34);
        assert_eq!(age.value(1), 29);
        assert!(age.is_null(2));
        assert_eq!(age.value(3), 45);
        assert!(age.is_null(4));
        assert_eq!(age.value(5), 31);

        assert_eq!(occupation.value(0), "Engineer");
        assert_eq!(occupation.value(1), "Doctor");
        assert!(occupation.is_null(2));
        assert_eq!(occupation.value(3), "Artist");
        assert!(occupation.is_null(4));
        assert!(occupation.is_null(5));

        // Date32 values are days since epoch; truncated fields are null.
        assert_eq!(dob.value(0), 5675);
        assert!(dob.is_null(1));
        assert!(dob.is_null(2));
        assert_eq!(dob.value(3), -1858);
        assert!(dob.is_null(4));
        assert!(dob.is_null(5));
    }
2444
    #[test]
    fn test_truncated_rows_not_nullable_error() {
        // Truncated rows are padded with nulls, so a non-nullable schema
        // must turn the padding into an InvalidArgumentError.
        let data = "a,b,c\n1,2,3\n4,5";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));

        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(true)
            .build(Cursor::new(data))
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(match batches {
            Err(ArrowError::InvalidArgumentError(e)) =>
                e.to_string().contains("contains null values"),
            _ => false,
        });
    }
2467
    #[test]
    fn test_buffered() {
        // The buffered reader must produce byte-identical batches to the
        // unbuffered reader across several files, batch sizes, and buffer
        // capacities (including a 1-byte buffer).
        let tests = [
            ("test/data/uk_cities.csv", false, 37),
            ("test/data/various_types.csv", true, 10),
            ("test/data/decimal_test.csv", false, 10),
        ];

        for (path, has_header, expected_rows) in tests {
            let (schema, _) = Format::default()
                .infer_schema(File::open(path).unwrap(), None)
                .unwrap();
            let schema = Arc::new(schema);

            for batch_size in [1, 4] {
                for capacity in [1, 3, 7, 100] {
                    // Baseline: unbuffered read of the whole file.
                    let reader = ReaderBuilder::new(schema.clone())
                        .with_batch_size(batch_size)
                        .with_header(has_header)
                        .build(File::open(path).unwrap())
                        .unwrap();

                    let expected = reader.collect::<Result<Vec<_>, _>>().unwrap();

                    assert_eq!(
                        expected.iter().map(|x| x.num_rows()).sum::<usize>(),
                        expected_rows
                    );

                    // Same read through a BufReader with a tiny capacity to
                    // force many partial fills.
                    let buffered =
                        std::io::BufReader::with_capacity(capacity, File::open(path).unwrap());

                    let reader = ReaderBuilder::new(schema.clone())
                        .with_batch_size(batch_size)
                        .with_header(has_header)
                        .build_buffered(buffered)
                        .unwrap();

                    let actual = reader.collect::<Result<Vec<_>, _>>().unwrap();
                    assert_eq!(expected, actual)
                }
            }
        }
    }
2512
2513 fn err_test(csv: &[u8], expected: &str) {
2514 fn err_test_with_schema(csv: &[u8], expected: &str, schema: Arc<Schema>) {
2515 let buffer = std::io::BufReader::with_capacity(2, Cursor::new(csv));
2516 let b = ReaderBuilder::new(schema)
2517 .with_batch_size(2)
2518 .build_buffered(buffer)
2519 .unwrap();
2520 let err = b.collect::<Result<Vec<_>, _>>().unwrap_err().to_string();
2521 assert_eq!(err, expected)
2522 }
2523
2524 let schema_utf8 = Arc::new(Schema::new(vec![
2525 Field::new("text1", DataType::Utf8, true),
2526 Field::new("text2", DataType::Utf8, true),
2527 ]));
2528 err_test_with_schema(csv, expected, schema_utf8);
2529
2530 let schema_utf8view = Arc::new(Schema::new(vec![
2531 Field::new("text1", DataType::Utf8View, true),
2532 Field::new("text2", DataType::Utf8View, true),
2533 ]));
2534 err_test_with_schema(csv, expected, schema_utf8view);
2535 }
2536
2537 #[test]
2538 fn test_invalid_utf8() {
2539 err_test(
2540 b"sdf,dsfg\ndfd,hgh\xFFue\n,sds\nFalhghse,",
2541 "Csv error: Encountered invalid UTF-8 data for line 2 and field 2",
2542 );
2543
2544 err_test(
2545 b"sdf,dsfg\ndksdk,jf\nd\xFFfd,hghue\n,sds\nFalhghse,",
2546 "Csv error: Encountered invalid UTF-8 data for line 3 and field 1",
2547 );
2548
2549 err_test(
2550 b"sdf,dsfg\ndksdk,jf\ndsdsfd,hghue\n,sds\nFalhghse,\xFF",
2551 "Csv error: Encountered invalid UTF-8 data for line 5 and field 2",
2552 );
2553
2554 err_test(
2555 b"\xFFsdf,dsfg\ndksdk,jf\ndsdsfd,hghue\n,sds\nFalhghse,\xFF",
2556 "Csv error: Encountered invalid UTF-8 data for line 1 and field 1",
2557 );
2558 }
2559
    /// A reader wrapper that records the `fill_buf` access pattern of the
    /// wrapped reader, so tests can assert how the decoder performs I/O.
    struct InstrumentedRead<R> {
        r: R,
        // Number of `fill_buf` calls observed so far.
        fill_count: usize,
        // Buffer length reported by each `fill_buf` call, in order.
        fill_sizes: Vec<usize>,
    }
2565
2566 impl<R> InstrumentedRead<R> {
2567 fn new(r: R) -> Self {
2568 Self {
2569 r,
2570 fill_count: 0,
2571 fill_sizes: vec![],
2572 }
2573 }
2574 }
2575
    // Seeking is delegated untouched — only buffer fills are instrumented.
    impl<R: Seek> Seek for InstrumentedRead<R> {
        fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
            self.r.seek(pos)
        }
    }
2581
    // Plain `read` bypasses instrumentation; only `fill_buf` is counted.
    impl<R: BufRead> Read for InstrumentedRead<R> {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            self.r.read(buf)
        }
    }
2587
    impl<R: BufRead> BufRead for InstrumentedRead<R> {
        fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
            // Record both that a fill happened and how many bytes the inner
            // reader had available at that moment.
            self.fill_count += 1;
            let buf = self.r.fill_buf()?;
            self.fill_sizes.push(buf.len());
            Ok(buf)
        }

        fn consume(&mut self, amt: usize) {
            self.r.consume(amt)
        }
    }
2600
    #[test]
    fn test_io() {
        // Verify the buffered reader's I/O access pattern: it should drain
        // the input in as few `fill_buf` calls as possible.
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("b", DataType::Utf8, false),
        ]));
        let csv = "foo,bar\nbaz,foo\na,b\nc,d";
        let mut read = InstrumentedRead::new(Cursor::new(csv.as_bytes()));
        let reader = ReaderBuilder::new(schema)
            .with_batch_size(3)
            .build_buffered(&mut read)
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].num_rows(), 3);
        assert_eq!(batches[1].num_rows(), 1);

        // NOTE(review): presumably 23 bytes on the first fill (whole input),
        // 3 leftover bytes for the final unterminated row, then two empty
        // fills confirming EOF — confirm against the decoder's read loop.
        assert_eq!(&read.fill_sizes, &[23, 3, 0, 0]);
        assert_eq!(read.fill_count, 4);
    }
2627
2628 #[test]
2629 fn test_inference() {
2630 let cases: &[(&[&str], DataType)] = &[
2631 (&[], DataType::Null),
2632 (&["false", "12"], DataType::Utf8),
2633 (&["12", "cupcakes"], DataType::Utf8),
2634 (&["12", "12.4"], DataType::Float64),
2635 (&["14050", "24332"], DataType::Int64),
2636 (&["14050.0", "true"], DataType::Utf8),
2637 (&["14050", "2020-03-19 00:00:00"], DataType::Utf8),
2638 (&["14050", "2340.0", "2020-03-19 00:00:00"], DataType::Utf8),
2639 (
2640 &["2020-03-19 02:00:00", "2020-03-19 00:00:00"],
2641 DataType::Timestamp(TimeUnit::Second, None),
2642 ),
2643 (&["2020-03-19", "2020-03-20"], DataType::Date32),
2644 (
2645 &["2020-03-19", "2020-03-19 02:00:00", "2020-03-19 00:00:00"],
2646 DataType::Timestamp(TimeUnit::Second, None),
2647 ),
2648 (
2649 &[
2650 "2020-03-19",
2651 "2020-03-19 02:00:00",
2652 "2020-03-19 00:00:00.000",
2653 ],
2654 DataType::Timestamp(TimeUnit::Millisecond, None),
2655 ),
2656 (
2657 &[
2658 "2020-03-19",
2659 "2020-03-19 02:00:00",
2660 "2020-03-19 00:00:00.000000",
2661 ],
2662 DataType::Timestamp(TimeUnit::Microsecond, None),
2663 ),
2664 (
2665 &["2020-03-19 02:00:00+02:00", "2020-03-19 02:00:00Z"],
2666 DataType::Timestamp(TimeUnit::Second, None),
2667 ),
2668 (
2669 &[
2670 "2020-03-19",
2671 "2020-03-19 02:00:00+02:00",
2672 "2020-03-19 02:00:00Z",
2673 "2020-03-19 02:00:00.12Z",
2674 ],
2675 DataType::Timestamp(TimeUnit::Millisecond, None),
2676 ),
2677 (
2678 &[
2679 "2020-03-19",
2680 "2020-03-19 02:00:00.000000000",
2681 "2020-03-19 00:00:00.000000",
2682 ],
2683 DataType::Timestamp(TimeUnit::Nanosecond, None),
2684 ),
2685 ];
2686
2687 for (values, expected) in cases {
2688 let mut t = InferredDataType::default();
2689 for v in *values {
2690 t.update(v)
2691 }
2692 assert_eq!(&t.get(), expected, "{values:?}")
2693 }
2694 }
2695
2696 #[test]
2697 fn test_record_length_mismatch() {
2698 let csv = "\
2699 a,b,c\n\
2700 1,2,3\n\
2701 4,5\n\
2702 6,7,8";
2703 let mut read = Cursor::new(csv.as_bytes());
2704 let result = Format::default()
2705 .with_header(true)
2706 .infer_schema(&mut read, None);
2707 assert!(result.is_err());
2708 assert_eq!(
2710 result.err().unwrap().to_string(),
2711 "Csv error: Encountered unequal lengths between records on CSV file. Expected 3 records, found 2 records at line 3"
2712 );
2713 }
2714
2715 #[test]
2716 fn test_comment() {
2717 let schema = Schema::new(vec![
2718 Field::new("a", DataType::Int8, false),
2719 Field::new("b", DataType::Int8, false),
2720 ]);
2721
2722 let csv = "# comment1 \n1,2\n#comment2\n11,22";
2723 let mut read = Cursor::new(csv.as_bytes());
2724 let reader = ReaderBuilder::new(Arc::new(schema))
2725 .with_comment(b'#')
2726 .build(&mut read)
2727 .unwrap();
2728
2729 let batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
2730 assert_eq!(batches.len(), 1);
2731 let b = batches.first().unwrap();
2732 assert_eq!(b.num_columns(), 2);
2733 assert_eq!(
2734 b.column(0)
2735 .as_any()
2736 .downcast_ref::<Int8Array>()
2737 .unwrap()
2738 .values(),
2739 &vec![1, 11]
2740 );
2741 assert_eq!(
2742 b.column(1)
2743 .as_any()
2744 .downcast_ref::<Int8Array>()
2745 .unwrap()
2746 .values(),
2747 &vec![2, 22]
2748 );
2749 }
2750
2751 #[test]
2752 fn test_parse_string_view_single_column() {
2753 let csv = ["foo", "something_cannot_be_inlined", "foobar"].join("\n");
2754 let schema = Arc::new(Schema::new(vec![Field::new(
2755 "c1",
2756 DataType::Utf8View,
2757 true,
2758 )]));
2759
2760 let mut decoder = ReaderBuilder::new(schema).build_decoder();
2761
2762 let decoded = decoder.decode(csv.as_bytes()).unwrap();
2763 assert_eq!(decoded, csv.len());
2764 decoder.decode(&[]).unwrap();
2765
2766 let batch = decoder.flush().unwrap().unwrap();
2767 assert_eq!(batch.num_columns(), 1);
2768 assert_eq!(batch.num_rows(), 3);
2769 let col = batch.column(0).as_string_view();
2770 assert_eq!(col.data_type(), &DataType::Utf8View);
2771 assert_eq!(col.value(0), "foo");
2772 assert_eq!(col.value(1), "something_cannot_be_inlined");
2773 assert_eq!(col.value(2), "foobar");
2774 }
2775
2776 #[test]
2777 fn test_parse_string_view_multi_column() {
2778 let csv = ["foo,", ",something_cannot_be_inlined", "foobarfoobar,bar"].join("\n");
2779 let schema = Arc::new(Schema::new(vec![
2780 Field::new("c1", DataType::Utf8View, true),
2781 Field::new("c2", DataType::Utf8View, true),
2782 ]));
2783
2784 let mut decoder = ReaderBuilder::new(schema).build_decoder();
2785
2786 let decoded = decoder.decode(csv.as_bytes()).unwrap();
2787 assert_eq!(decoded, csv.len());
2788 decoder.decode(&[]).unwrap();
2789
2790 let batch = decoder.flush().unwrap().unwrap();
2791 assert_eq!(batch.num_columns(), 2);
2792 assert_eq!(batch.num_rows(), 3);
2793 let c1 = batch.column(0).as_string_view();
2794 let c2 = batch.column(1).as_string_view();
2795 assert_eq!(c1.data_type(), &DataType::Utf8View);
2796 assert_eq!(c2.data_type(), &DataType::Utf8View);
2797
2798 assert!(!c1.is_null(0));
2799 assert!(c1.is_null(1));
2800 assert!(!c1.is_null(2));
2801 assert_eq!(c1.value(0), "foo");
2802 assert_eq!(c1.value(2), "foobarfoobar");
2803
2804 assert!(c2.is_null(0));
2805 assert!(!c2.is_null(1));
2806 assert!(!c2.is_null(2));
2807 assert_eq!(c2.value(1), "something_cannot_be_inlined");
2808 assert_eq!(c2.value(2), "bar");
2809 }
2810}