mod records;

use arrow_array::builder::{NullBuilder, PrimitiveBuilder};
use arrow_array::types::*;
use arrow_array::*;
use arrow_cast::parse::{parse_decimal, string_to_datetime, Parser};
use arrow_schema::*;
use chrono::{TimeZone, Utc};
use csv::StringRecord;
use regex::{Regex, RegexSet};
use std::fmt::{self, Debug};
use std::fs::File;
use std::io::{BufRead, BufReader as StdBufReader, Read};
use std::sync::{Arc, LazyLock};

use crate::map_csv_error;
use crate::reader::records::{RecordDecoder, StringRecords};
use arrow_array::timezone::Tz;
/// Patterns used for type inference; the index order must match the bit
/// layout of [`InferredDataType::packed`]
static REGEX_SET: LazyLock<RegexSet> = LazyLock::new(|| {
    RegexSet::new([
        r"(?i)^(true)$|^(false)$(?-i)", // BOOLEAN
        r"^-?(\d+)$",                   // INTEGER
        r"^-?((\d*\.\d+|\d+\.\d*)([eE][-+]?\d+)?|\d+([eE][-+]?\d+))$", // DECIMAL
        r"^\d{4}-\d\d-\d\d$",           // DATE32
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d(?:[^\d\.].*)?$", // Timestamp(Second)
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,3}(?:[^\d].*)?$", // Timestamp(Millisecond)
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,6}(?:[^\d].*)?$", // Timestamp(Microsecond)
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,9}(?:[^\d].*)?$", // Timestamp(Nanosecond)
    ])
    .unwrap()
});

/// A wrapper over `Option<Regex>` used to check whether a value should be
/// treated as null
#[derive(Debug, Clone, Default)]
struct NullRegex(Option<Regex>);

impl NullRegex {
    /// Returns true if the value should be considered `NULL`, i.e. it matches
    /// the provided regex, or is empty when no regex was provided
    #[inline]
    fn is_null(&self, s: &str) -> bool {
        match &self.0 {
            Some(r) => r.is_match(s),
            None => s.is_empty(),
        }
    }
}

#[derive(Default, Copy, Clone)]
struct InferredDataType {
    /// Packed booleans indicating the types observed so far, one bit per type:
    ///
    /// 0 - Boolean
    /// 1 - Integer
    /// 2 - Float64
    /// 3 - Date32
    /// 4 - Timestamp(Second)
    /// 5 - Timestamp(Millisecond)
    /// 6 - Timestamp(Microsecond)
    /// 7 - Timestamp(Nanosecond)
    /// 8 - Utf8
    packed: u16,
}

impl InferredDataType {
    /// Returns the inferred data type
    fn get(&self) -> DataType {
        match self.packed {
            0 => DataType::Null,
            1 => DataType::Boolean,
            2 => DataType::Int64,
            4 | 6 => DataType::Float64, // Promote Int64 to Float64
            b if b != 0 && (b & !0b11111000) == 0 => match b.leading_zeros() {
                // Promote to the highest precision temporal type observed
                8 => DataType::Timestamp(TimeUnit::Nanosecond, None),
                9 => DataType::Timestamp(TimeUnit::Microsecond, None),
                10 => DataType::Timestamp(TimeUnit::Millisecond, None),
                11 => DataType::Timestamp(TimeUnit::Second, None),
                12 => DataType::Date32,
                _ => unreachable!(),
            },
            _ => DataType::Utf8,
        }
    }

    /// Updates the [`InferredDataType`] with the given string
    fn update(&mut self, string: &str) {
        self.packed |= if string.starts_with('"') {
            1 << 8 // Utf8
        } else if let Some(m) = REGEX_SET.matches(string).into_iter().next() {
            if m == 1 && string.len() >= 19 && string.parse::<i64>().is_err() {
                // Values that look like integers but overflow i64 are treated as Utf8
                1 << 8 // Utf8
            } else {
                1 << m
            }
        } else if string == "NaN" || string == "nan" || string == "inf" || string == "-inf" {
            1 << 2 // Float64
        } else {
            1 << 8 // Utf8
        }
    }
}
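
// Illustrative sketch (not part of the original source): repeated calls to
// `update` accumulate candidate types as bits in `packed`, and `get` resolves
// them, e.g. promoting Int64 to Float64 once a decimal value is observed.
#[cfg(test)]
#[allow(dead_code)]
fn example_type_promotion() {
    let mut inferred = InferredDataType::default();
    inferred.update("1"); // matches the integer pattern => Int64
    assert_eq!(inferred.get(), DataType::Int64);
    inferred.update("2.5"); // matches the decimal pattern => packed = 0b110
    assert_eq!(inferred.get(), DataType::Float64);
}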
/// The format specification for the CSV file
#[derive(Debug, Clone, Default)]
pub struct Format {
    header: bool,
    delimiter: Option<u8>,
    escape: Option<u8>,
    quote: Option<u8>,
    terminator: Option<u8>,
    comment: Option<u8>,
    null_regex: NullRegex,
    truncated_rows: bool,
}

impl Format {
    /// Specify whether the CSV file has a header row, defaults to `false`
    pub fn with_header(mut self, has_header: bool) -> Self {
        self.header = has_header;
        self
    }

    /// Specify a custom delimiter character, defaults to comma `','`
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.delimiter = Some(delimiter);
        self
    }

    /// Specify an escape character
    pub fn with_escape(mut self, escape: u8) -> Self {
        self.escape = Some(escape);
        self
    }

    /// Specify a custom quote character, defaults to double quote `'"'`
    pub fn with_quote(mut self, quote: u8) -> Self {
        self.quote = Some(quote);
        self
    }

    /// Specify a custom record terminator character, defaults to CRLF
    pub fn with_terminator(mut self, terminator: u8) -> Self {
        self.terminator = Some(terminator);
        self
    }

    /// Specify a comment character; lines starting with this character are ignored
    pub fn with_comment(mut self, comment: u8) -> Self {
        self.comment = Some(comment);
        self
    }

    /// Provide a regex to match null values, defaults to treating empty values as null
    pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
        self.null_regex = NullRegex(Some(null_regex));
        self
    }

    /// Whether to allow truncated rows when parsing.
    ///
    /// By default this is set to `false` and will error if the CSV rows have
    /// different lengths. When set to `true` it will allow records with fewer
    /// than the expected number of columns and fill the missing columns with
    /// nulls. If the record's schema is not nullable, it will still return an
    /// error.
    pub fn with_truncated_rows(mut self, allow: bool) -> Self {
        self.truncated_rows = allow;
        self
    }

    /// Infer the schema of a CSV file by reading through the first
    /// `max_records` records
    ///
    /// Returns the inferred schema and the number of records read
    pub fn infer_schema<R: Read>(
        &self,
        reader: R,
        max_records: Option<usize>,
    ) -> Result<(Schema, usize), ArrowError> {
        let mut csv_reader = self.build_reader(reader);

        // Get or create header names; when `header` is false, creates default
        // column names prefixed with "column_"
        let headers: Vec<String> = if self.header {
            let headers = &csv_reader.headers().map_err(map_csv_error)?.clone();
            headers.iter().map(|s| s.to_string()).collect()
        } else {
            let first_record_count = &csv_reader.headers().map_err(map_csv_error)?.len();
            (0..*first_record_count)
                .map(|i| format!("column_{}", i + 1))
                .collect()
        };

        let header_length = headers.len();
        // Keep track of the inferred type for each column
        let mut column_types: Vec<InferredDataType> = vec![Default::default(); header_length];

        let mut records_count = 0;

        let mut record = StringRecord::new();
        let max_records = max_records.unwrap_or(usize::MAX);
        while records_count < max_records {
            if !csv_reader.read_record(&mut record).map_err(map_csv_error)? {
                break;
            }
            records_count += 1;

            // Note: fields beyond `header_length` are ignored
            for (i, column_type) in column_types.iter_mut().enumerate().take(header_length) {
                if let Some(string) = record.get(i) {
                    if !self.null_regex.is_null(string) {
                        column_type.update(string)
                    }
                }
            }
        }

        // Build a schema from the inference results
        let fields: Fields = column_types
            .iter()
            .zip(&headers)
            .map(|(inferred, field_name)| Field::new(field_name, inferred.get(), true))
            .collect();

        Ok((Schema::new(fields), records_count))
    }

    /// Build a [`csv::Reader`] for this [`Format`]
    fn build_reader<R: Read>(&self, reader: R) -> csv::Reader<R> {
        let mut builder = csv::ReaderBuilder::new();
        builder.has_headers(self.header);
        builder.flexible(self.truncated_rows);

        if let Some(c) = self.delimiter {
            builder.delimiter(c);
        }
        builder.escape(self.escape);
        if let Some(c) = self.quote {
            builder.quote(c);
        }
        if let Some(t) = self.terminator {
            builder.terminator(csv::Terminator::Any(t));
        }
        if let Some(comment) = self.comment {
            builder.comment(Some(comment));
        }
        builder.from_reader(reader)
    }

    /// Build a [`csv_core::Reader`] for this [`Format`]
    fn build_parser(&self) -> csv_core::Reader {
        let mut builder = csv_core::ReaderBuilder::new();
        builder.escape(self.escape);
        builder.comment(self.comment);

        if let Some(c) = self.delimiter {
            builder.delimiter(c);
        }
        if let Some(c) = self.quote {
            builder.quote(c);
        }
        if let Some(t) = self.terminator {
            builder.terminator(csv_core::Terminator::Any(t));
        }
        builder.build()
    }
}
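
// A minimal sketch (not part of the original source) showing how `Format`
// drives schema inference: infer the schema in a first pass over the data,
// then read record batches with the same format settings in a second pass.
#[cfg(test)]
#[allow(dead_code)]
fn example_infer_then_read() -> Result<(), ArrowError> {
    use std::io::Cursor;

    let csv = "a,b\n1,2.5\n3,4.0\n";
    let format = Format::default().with_header(true);

    // First pass: infer the column types from the raw bytes
    let (schema, _records_read) = format.infer_schema(Cursor::new(csv), None)?;

    // Second pass: decode record batches using the inferred schema
    let mut reader = ReaderBuilder::new(Arc::new(schema))
        .with_format(format)
        .build(Cursor::new(csv))?;
    let batch = reader.next().transpose()?;
    assert_eq!(batch.map(|b| b.num_rows()), Some(2));
    Ok(())
}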
/// Infer the schema of a list of CSV files by reading through the first
/// `max_read_records` records, reading the files in order until that many
/// records have been seen.
///
/// If `max_read_records` is `None`, all files are read fully to infer the schema.
pub fn infer_schema_from_files(
    files: &[String],
    delimiter: u8,
    max_read_records: Option<usize>,
    has_header: bool,
) -> Result<Schema, ArrowError> {
    let mut schemas = vec![];
    let mut records_to_read = max_read_records.unwrap_or(usize::MAX);
    let format = Format {
        delimiter: Some(delimiter),
        header: has_header,
        ..Default::default()
    };

    for fname in files.iter() {
        let f = File::open(fname)?;
        let (schema, records_read) = format.infer_schema(f, Some(records_to_read))?;
        if records_read == 0 {
            continue;
        }
        schemas.push(schema.clone());
        records_to_read -= records_read;
        if records_to_read == 0 {
            break;
        }
    }

    Schema::try_merge(schemas)
}
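
// Sketch (not part of the original source): inferring a single merged schema
// across multiple CSV files. Uses `tempfile`, a dev-dependency of the tests
// below, to create a throwaway input.
#[cfg(test)]
#[allow(dead_code)]
fn example_infer_from_files() -> Result<(), ArrowError> {
    use std::io::Write;

    let mut f = tempfile::NamedTempFile::new().unwrap();
    writeln!(f, "c1,c2").unwrap();
    writeln!(f, "1,foo").unwrap();

    let path = f.path().to_str().unwrap().to_string();
    let schema = infer_schema_from_files(&[path], b',', None, true)?;
    assert_eq!(schema.fields().len(), 2);
    Ok(())
}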
/// The optional bounds of rows to read: `(start, end)`
type Bounds = Option<(usize, usize)>;

/// CSV file reader using [`std::io::BufReader`]
pub type Reader<R> = BufReader<StdBufReader<R>>;

/// CSV file reader
pub struct BufReader<R> {
    /// File reader
    reader: R,

    /// The decoder
    decoder: Decoder,
}

impl<R> fmt::Debug for BufReader<R>
where
    R: BufRead,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Reader")
            .field("decoder", &self.decoder)
            .finish()
    }
}

impl<R: Read> Reader<R> {
    /// Returns the schema of the reader, useful for getting the schema without
    /// reading record batches
    pub fn schema(&self) -> SchemaRef {
        match &self.decoder.projection {
            Some(projection) => {
                let fields = self.decoder.schema.fields();
                let projected = projection.iter().map(|i| fields[*i].clone());
                Arc::new(Schema::new(projected.collect::<Fields>()))
            }
            None => self.decoder.schema.clone(),
        }
    }
}

impl<R: BufRead> BufReader<R> {
    fn read(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        loop {
            let buf = self.reader.fill_buf()?;
            let decoded = self.decoder.decode(buf)?;
            self.reader.consume(decoded);
            // Yield if no bytes were decoded or the decoder is full; the
            // capacity check avoids looping around and potentially blocking in
            // fill_buf on data that isn't needed to flush the next batch
            if decoded == 0 || self.decoder.capacity() == 0 {
                break;
            }
        }

        self.decoder.flush()
    }
}

impl<R: BufRead> Iterator for BufReader<R> {
    type Item = Result<RecordBatch, ArrowError>;

    fn next(&mut self) -> Option<Self::Item> {
        self.read().transpose()
    }
}

impl<R: BufRead> RecordBatchReader for BufReader<R> {
    fn schema(&self) -> SchemaRef {
        self.decoder.schema.clone()
    }
}

/// A push-based interface for decoding CSV data from an arbitrary byte stream
///
/// See [`Reader`] for a higher-level interface suited to typical use cases
#[derive(Debug)]
pub struct Decoder {
    /// The schema of the decoded data
    schema: SchemaRef,

    /// Optional projection
    projection: Option<Vec<usize>>,

    /// Number of records per batch
    batch_size: usize,

    /// Rows to skip
    to_skip: usize,

    /// Current line number
    line_number: usize,

    /// End line number
    end: usize,

    /// A decoder for [`StringRecords`]
    record_decoder: RecordDecoder,

    /// Matches values that should be decoded as `NULL`
    null_regex: NullRegex,
}

impl Decoder {
    /// Decode records from `buf`, returning the number of bytes read
    ///
    /// This method returns once `batch_size` rows have been decoded, or `buf`
    /// is exhausted; any remaining bytes should be included in the next call
    /// to [`Self::decode`]. There is no requirement that `buf` contains a
    /// whole number of records, facilitating integration with arbitrary byte
    /// streams.
    pub fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
        if self.to_skip != 0 {
            // Skip in units of batch_size to avoid over-allocating buffers
            let to_skip = self.to_skip.min(self.batch_size);
            let (skipped, bytes) = self.record_decoder.decode(buf, to_skip)?;
            self.to_skip -= skipped;
            self.record_decoder.clear();
            return Ok(bytes);
        }

        let to_read = self.batch_size.min(self.end - self.line_number) - self.record_decoder.len();
        let (_, bytes) = self.record_decoder.decode(buf, to_read)?;
        Ok(bytes)
    }

    /// Flushes the currently buffered data to a [`RecordBatch`]
    ///
    /// This should only be called after [`Self::decode`] has returned `Ok(0)`,
    /// otherwise it may error if part way through decoding a record
    pub fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        if self.record_decoder.is_empty() {
            return Ok(None);
        }

        let rows = self.record_decoder.flush()?;
        let batch = parse(
            &rows,
            self.schema.fields(),
            Some(self.schema.metadata.clone()),
            self.projection.as_ref(),
            self.line_number,
            &self.null_regex,
        )?;
        self.line_number += rows.len();
        Ok(Some(batch))
    }

    /// Returns the number of rows that can be read before a call to
    /// [`Self::flush`] is required
    pub fn capacity(&self) -> usize {
        self.batch_size - self.record_decoder.len()
    }
}
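
// A sketch (not part of the original source) of the push-based decode loop
// that `Decoder` is designed for: feed buffers with `decode`, honouring the
// consumed byte count it reports, and call `flush` to produce batches.
#[cfg(test)]
#[allow(dead_code)]
fn example_decode_all<R: BufRead>(
    mut decoder: Decoder,
    mut reader: R,
) -> Result<Vec<RecordBatch>, ArrowError> {
    let mut batches = Vec::new();
    loop {
        let buf = reader.fill_buf()?;
        if buf.is_empty() {
            break; // input exhausted
        }
        let decoded = decoder.decode(buf)?;
        reader.consume(decoded);
        if decoded == 0 {
            // The decoder takes no more input until flushed: either a full
            // batch is buffered or the configured row bound has been reached
            match decoder.flush()? {
                Some(batch) => batches.push(batch),
                None => break, // bound reached with nothing buffered
            }
        }
    }
    // Flush any final partial batch
    if let Some(batch) = decoder.flush()? {
        batches.push(batch);
    }
    Ok(batches)
}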

/// Parses a slice of [`StringRecords`] into a [`RecordBatch`]
fn parse(
    rows: &StringRecords<'_>,
    fields: &Fields,
    metadata: Option<std::collections::HashMap<String, String>>,
    projection: Option<&Vec<usize>>,
    line_number: usize,
    null_regex: &NullRegex,
) -> Result<RecordBatch, ArrowError> {
    let projection: Vec<usize> = match projection {
        Some(v) => v.clone(),
        None => fields.iter().enumerate().map(|(i, _)| i).collect(),
    };

    let arrays: Result<Vec<ArrayRef>, _> = projection
        .iter()
        .map(|i| {
            let i = *i;
            let field = &fields[i];
            match field.data_type() {
                DataType::Boolean => build_boolean_array(line_number, rows, i, null_regex),
                DataType::Decimal32(precision, scale) => build_decimal_array::<Decimal32Type>(
                    line_number,
                    rows,
                    i,
                    *precision,
                    *scale,
                    null_regex,
                ),
                DataType::Decimal64(precision, scale) => build_decimal_array::<Decimal64Type>(
                    line_number,
                    rows,
                    i,
                    *precision,
                    *scale,
                    null_regex,
                ),
                DataType::Decimal128(precision, scale) => build_decimal_array::<Decimal128Type>(
                    line_number,
                    rows,
                    i,
                    *precision,
                    *scale,
                    null_regex,
                ),
                DataType::Decimal256(precision, scale) => build_decimal_array::<Decimal256Type>(
                    line_number,
                    rows,
                    i,
                    *precision,
                    *scale,
                    null_regex,
                ),
                DataType::Int8 => {
                    build_primitive_array::<Int8Type>(line_number, rows, i, null_regex)
                }
                DataType::Int16 => {
                    build_primitive_array::<Int16Type>(line_number, rows, i, null_regex)
                }
                DataType::Int32 => {
                    build_primitive_array::<Int32Type>(line_number, rows, i, null_regex)
                }
                DataType::Int64 => {
                    build_primitive_array::<Int64Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt8 => {
                    build_primitive_array::<UInt8Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt16 => {
                    build_primitive_array::<UInt16Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt32 => {
                    build_primitive_array::<UInt32Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt64 => {
                    build_primitive_array::<UInt64Type>(line_number, rows, i, null_regex)
                }
                DataType::Float32 => {
                    build_primitive_array::<Float32Type>(line_number, rows, i, null_regex)
                }
                DataType::Float64 => {
                    build_primitive_array::<Float64Type>(line_number, rows, i, null_regex)
                }
                DataType::Date32 => {
                    build_primitive_array::<Date32Type>(line_number, rows, i, null_regex)
                }
                DataType::Date64 => {
                    build_primitive_array::<Date64Type>(line_number, rows, i, null_regex)
                }
                DataType::Time32(TimeUnit::Second) => {
                    build_primitive_array::<Time32SecondType>(line_number, rows, i, null_regex)
                }
                DataType::Time32(TimeUnit::Millisecond) => {
                    build_primitive_array::<Time32MillisecondType>(line_number, rows, i, null_regex)
                }
                DataType::Time64(TimeUnit::Microsecond) => {
                    build_primitive_array::<Time64MicrosecondType>(line_number, rows, i, null_regex)
                }
                DataType::Time64(TimeUnit::Nanosecond) => {
                    build_primitive_array::<Time64NanosecondType>(line_number, rows, i, null_regex)
                }
                DataType::Timestamp(TimeUnit::Second, tz) => {
                    build_timestamp_array::<TimestampSecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Timestamp(TimeUnit::Millisecond, tz) => {
                    build_timestamp_array::<TimestampMillisecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Timestamp(TimeUnit::Microsecond, tz) => {
                    build_timestamp_array::<TimestampMicrosecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
                    build_timestamp_array::<TimestampNanosecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Null => Ok(Arc::new({
                    let mut builder = NullBuilder::new();
                    builder.append_nulls(rows.len());
                    builder.finish()
                }) as ArrayRef),
                DataType::Utf8 => Ok(Arc::new(
                    rows.iter()
                        .map(|row| {
                            let s = row.get(i);
                            (!null_regex.is_null(s)).then_some(s)
                        })
                        .collect::<StringArray>(),
                ) as ArrayRef),
                DataType::Utf8View => Ok(Arc::new(
                    rows.iter()
                        .map(|row| {
                            let s = row.get(i);
                            (!null_regex.is_null(s)).then_some(s)
                        })
                        .collect::<StringViewArray>(),
                ) as ArrayRef),
                DataType::Dictionary(key_type, value_type)
                    if value_type.as_ref() == &DataType::Utf8 =>
                {
                    match key_type.as_ref() {
                        DataType::Int8 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int8Type>>(),
                        ) as ArrayRef),
                        DataType::Int16 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int16Type>>(),
                        ) as ArrayRef),
                        DataType::Int32 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int32Type>>(),
                        ) as ArrayRef),
                        DataType::Int64 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int64Type>>(),
                        ) as ArrayRef),
                        DataType::UInt8 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt8Type>>(),
                        ) as ArrayRef),
                        DataType::UInt16 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt16Type>>(),
                        ) as ArrayRef),
                        DataType::UInt32 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt32Type>>(),
                        ) as ArrayRef),
                        DataType::UInt64 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt64Type>>(),
                        ) as ArrayRef),
                        _ => Err(ArrowError::ParseError(format!(
                            "Unsupported dictionary key type {key_type:?}"
                        ))),
                    }
                }
                other => Err(ArrowError::ParseError(format!(
                    "Unsupported data type {other:?}"
                ))),
            }
        })
        .collect();

    let projected_fields: Fields = projection.iter().map(|i| fields[*i].clone()).collect();

    let projected_schema = Arc::new(match metadata {
        None => Schema::new(projected_fields),
        Some(metadata) => Schema::new_with_metadata(projected_fields, metadata),
    });

    arrays.and_then(|arr| {
        RecordBatch::try_new_with_options(
            projected_schema,
            arr,
            &RecordBatchOptions::new()
                .with_match_field_names(true)
                .with_row_count(Some(rows.len())),
        )
    })
}

/// Parses a string as a boolean, case-insensitively
fn parse_bool(string: &str) -> Option<bool> {
    if string.eq_ignore_ascii_case("false") {
        Some(false)
    } else if string.eq_ignore_ascii_case("true") {
        Some(true)
    } else {
        None
    }
}

/// Parses a specific column (col_idx) into an Arrow decimal array of the
/// given precision and scale
fn build_decimal_array<T: DecimalType>(
    _line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    precision: u8,
    scale: i8,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    let mut decimal_builder = PrimitiveBuilder::<T>::with_capacity(rows.len());
    for row in rows.iter() {
        let s = row.get(col_idx);
        if null_regex.is_null(s) {
            // null is encoded in the string
            decimal_builder.append_null();
        } else {
            let decimal_value: Result<T::Native, _> = parse_decimal::<T>(s, precision, scale);
            match decimal_value {
                Ok(v) => {
                    decimal_builder.append_value(v);
                }
                Err(e) => {
                    return Err(e);
                }
            }
        }
    }
    Ok(Arc::new(
        decimal_builder
            .finish()
            .with_precision_and_scale(precision, scale)?,
    ))
}

/// Parses a specific column (col_idx) into an Arrow primitive array
fn build_primitive_array<T: ArrowPrimitiveType + Parser>(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    rows.iter()
        .enumerate()
        .map(|(row_index, row)| {
            let s = row.get(col_idx);
            if null_regex.is_null(s) {
                return Ok(None);
            }

            match T::parse(s) {
                Some(e) => Ok(Some(e)),
                None => Err(ArrowError::ParseError(format!(
                    "Error while parsing value '{}' as type '{}' for column {} at line {}. Row data: '{}'",
                    s,
                    T::DATA_TYPE,
                    col_idx,
                    line_number + row_index,
                    row
                ))),
            }
        })
        .collect::<Result<PrimitiveArray<T>, ArrowError>>()
        .map(|e| Arc::new(e) as ArrayRef)
}

/// Parses a specific column (col_idx) into an Arrow timestamp array with an
/// optional timezone
fn build_timestamp_array<T: ArrowTimestampType>(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    timezone: Option<&str>,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    Ok(Arc::new(match timezone {
        Some(timezone) => {
            let tz: Tz = timezone.parse()?;
            build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &tz, null_regex)?
                .with_timezone(timezone)
        }
        None => build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &Utc, null_regex)?,
    }))
}

fn build_timestamp_array_impl<T: ArrowTimestampType, Tz: TimeZone>(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    timezone: &Tz,
    null_regex: &NullRegex,
) -> Result<PrimitiveArray<T>, ArrowError> {
    rows.iter()
        .enumerate()
        .map(|(row_index, row)| {
            let s = row.get(col_idx);
            if null_regex.is_null(s) {
                return Ok(None);
            }

            let date = string_to_datetime(timezone, s)
                .and_then(|date| match T::UNIT {
                    TimeUnit::Second => Ok(date.timestamp()),
                    TimeUnit::Millisecond => Ok(date.timestamp_millis()),
                    TimeUnit::Microsecond => Ok(date.timestamp_micros()),
                    TimeUnit::Nanosecond => date.timestamp_nanos_opt().ok_or_else(|| {
                        ArrowError::ParseError(format!(
                            "{} would overflow 64-bit signed nanoseconds",
                            date.to_rfc3339(),
                        ))
                    }),
                })
                .map_err(|e| {
                    ArrowError::ParseError(format!(
                        "Error parsing column {col_idx} at line {}: {}",
                        line_number + row_index,
                        e
                    ))
                })?;
            Ok(Some(date))
        })
        .collect()
}

/// Parses a specific column (col_idx) into an Arrow boolean array
fn build_boolean_array(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    rows.iter()
        .enumerate()
        .map(|(row_index, row)| {
            let s = row.get(col_idx);
            if null_regex.is_null(s) {
                return Ok(None);
            }
            let parsed = parse_bool(s);
            match parsed {
                Some(e) => Ok(Some(e)),
                None => Err(ArrowError::ParseError(format!(
                    "Error while parsing value '{}' as type '{}' for column {} at line {}. Row data: '{}'",
                    s,
                    "Boolean",
                    col_idx,
                    line_number + row_index,
                    row
                ))),
            }
        })
        .collect::<Result<BooleanArray, _>>()
        .map(|e| Arc::new(e) as ArrayRef)
}

/// CSV file reader builder
#[derive(Debug)]
pub struct ReaderBuilder {
    /// Schema of the CSV file
    schema: SchemaRef,
    /// Format of the CSV file
    format: Format,
    /// Batch size (number of records to load each time), defaults to 1024
    batch_size: usize,
    /// The bounds over which to scan the reader; `None` starts from 0 and runs until EOF
    bounds: Bounds,
    /// Optional projection (zero-based column indices) for which columns to load
    projection: Option<Vec<usize>>,
}

impl ReaderBuilder {
    /// Create a new builder for configuring CSV parsing options.
    ///
    /// To convert a builder into a reader, call [`Self::build`]
    pub fn new(schema: SchemaRef) -> ReaderBuilder {
        Self {
            schema,
            format: Format::default(),
            batch_size: 1024,
            bounds: None,
            projection: None,
        }
    }

    /// Set whether the CSV file has a header row
    pub fn with_header(mut self, has_header: bool) -> Self {
        self.format.header = has_header;
        self
    }

    /// Overrides the [`Format`] of this [`ReaderBuilder`]
    pub fn with_format(mut self, format: Format) -> Self {
        self.format = format;
        self
    }

    /// Set the CSV file's column delimiter as a byte character
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.format.delimiter = Some(delimiter);
        self
    }

    /// Set the CSV escape character
    pub fn with_escape(mut self, escape: u8) -> Self {
        self.format.escape = Some(escape);
        self
    }

    /// Set the CSV quote character
    pub fn with_quote(mut self, quote: u8) -> Self {
        self.format.quote = Some(quote);
        self
    }

    /// Set the CSV record terminator
    pub fn with_terminator(mut self, terminator: u8) -> Self {
        self.format.terminator = Some(terminator);
        self
    }

    /// Set the CSV comment character; lines starting with this character are ignored
    pub fn with_comment(mut self, comment: u8) -> Self {
        self.format.comment = Some(comment);
        self
    }

    /// Provide a regex to match null values, defaults to treating empty values as null
    pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
        self.format.null_regex = NullRegex(Some(null_regex));
        self
    }

    /// Set the batch size (number of records to load at one time)
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Set the bounds over which to scan the reader; `start` and `end` are row offsets
    pub fn with_bounds(mut self, start: usize, end: usize) -> Self {
        self.bounds = Some((start, end));
        self
    }

    /// Set the reader's column projection
    pub fn with_projection(mut self, projection: Vec<usize>) -> Self {
        self.projection = Some(projection);
        self
    }

    /// Whether to allow truncated rows when parsing.
    ///
    /// By default this is set to `false` and will error if the CSV rows have
    /// different lengths. When set to `true` it will allow records with fewer
    /// than the expected number of columns and fill the missing columns with
    /// nulls. If the record's schema is not nullable, it will still return an
    /// error.
    pub fn with_truncated_rows(mut self, allow: bool) -> Self {
        self.format.truncated_rows = allow;
        self
    }

    /// Create a new `Reader` from a non-buffered reader
    ///
    /// If `R: BufRead`, consider using [`Self::build_buffered`] to avoid
    /// unnecessary additional buffering, as this method wraps `reader` in
    /// [`std::io::BufReader`]
    pub fn build<R: Read>(self, reader: R) -> Result<Reader<R>, ArrowError> {
        self.build_buffered(StdBufReader::new(reader))
    }

    /// Create a new `BufReader` from a buffered reader
    pub fn build_buffered<R: BufRead>(self, reader: R) -> Result<BufReader<R>, ArrowError> {
        Ok(BufReader {
            reader,
            decoder: self.build_decoder(),
        })
    }

    /// Builds a decoder that can be used to decode CSV from an arbitrary byte stream
    pub fn build_decoder(self) -> Decoder {
        let delimiter = self.format.build_parser();
        let record_decoder = RecordDecoder::new(
            delimiter,
            self.schema.fields().len(),
            self.format.truncated_rows,
        );

        let header = self.format.header as usize;

        let (start, end) = match self.bounds {
            Some((start, end)) => (start + header, end + header),
            None => (header, usize::MAX),
        };

        Decoder {
            schema: self.schema,
            to_skip: start,
            record_decoder,
            line_number: start,
            end,
            projection: self.projection,
            batch_size: self.batch_size,
            null_regex: self.format.null_regex,
        }
    }
}
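
// A minimal usage sketch (not part of the original source): configure a
// `ReaderBuilder` and stream record batches from any `Read` source.
#[cfg(test)]
#[allow(dead_code)]
fn example_reader_builder() -> Result<(), ArrowError> {
    use std::io::Cursor;

    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int64, true),
        Field::new("b", DataType::Utf8, true),
    ]));
    let csv = "a,b\n1,hello\n2,world\n";

    let reader = ReaderBuilder::new(schema)
        .with_header(true)
        .with_batch_size(1024)
        .build(Cursor::new(csv))?;

    // The reader is an iterator of `Result<RecordBatch, ArrowError>`
    for batch in reader {
        let batch = batch?;
        assert_eq!(batch.num_columns(), 2);
    }
    Ok(())
}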

#[cfg(test)]
mod tests {
    use super::*;

    use std::io::{Cursor, Seek, SeekFrom, Write};
    use tempfile::NamedTempFile;

    use arrow_array::cast::AsArray;
    #[test]
    fn test_csv() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]));

        let file = File::open("test/data/uk_cities.csv").unwrap();
        let mut csv = ReaderBuilder::new(schema.clone()).build(file).unwrap();
        assert_eq!(schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        // access data from a primitive array
        let lat = batch.column(1).as_primitive::<Float64Type>();
        assert_eq!(57.653484, lat.value(0));

        // access data from a string array
        let city = batch.column(0).as_string::<i32>();

        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
    }

    #[test]
    fn test_csv_schema_metadata() {
        let mut metadata = std::collections::HashMap::new();
        metadata.insert("foo".to_owned(), "bar".to_owned());
        let schema = Arc::new(Schema::new_with_metadata(
            vec![
                Field::new("city", DataType::Utf8, false),
                Field::new("lat", DataType::Float64, false),
                Field::new("lng", DataType::Float64, false),
            ],
            metadata.clone(),
        ));

        let file = File::open("test/data/uk_cities.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema.clone()).build(file).unwrap();
        assert_eq!(schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        assert_eq!(&metadata, batch.schema().metadata());
    }

    #[test]
    fn test_csv_reader_with_decimal() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Decimal128(38, 6), false),
            Field::new("lng", DataType::Decimal256(76, 6), false),
        ]));

        let file = File::open("test/data/decimal_test.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema).build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();
        let lat = batch
            .column(1)
            .as_any()
            .downcast_ref::<Decimal128Array>()
            .unwrap();

        assert_eq!("57.653484", lat.value_as_string(0));
        assert_eq!("53.002666", lat.value_as_string(1));
        assert_eq!("52.412811", lat.value_as_string(2));
        assert_eq!("51.481583", lat.value_as_string(3));
        assert_eq!("12.123456", lat.value_as_string(4));
        assert_eq!("50.760000", lat.value_as_string(5));
        assert_eq!("0.123000", lat.value_as_string(6));
        assert_eq!("123.000000", lat.value_as_string(7));
        assert_eq!("123.000000", lat.value_as_string(8));
        assert_eq!("-50.760000", lat.value_as_string(9));

        let lng = batch
            .column(2)
            .as_any()
            .downcast_ref::<Decimal256Array>()
            .unwrap();

        assert_eq!("-3.335724", lng.value_as_string(0));
        assert_eq!("-2.179404", lng.value_as_string(1));
        assert_eq!("-1.778197", lng.value_as_string(2));
        assert_eq!("-3.179090", lng.value_as_string(3));
        assert_eq!("-3.179090", lng.value_as_string(4));
        assert_eq!("0.290472", lng.value_as_string(5));
        assert_eq!("0.290472", lng.value_as_string(6));
        assert_eq!("0.290472", lng.value_as_string(7));
        assert_eq!("0.290472", lng.value_as_string(8));
        assert_eq!("0.290472", lng.value_as_string(9));
    }

    #[test]
    fn test_csv_reader_with_decimal_3264() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Decimal32(9, 6), false),
            Field::new("lng", DataType::Decimal64(16, 6), false),
        ]));

        let file = File::open("test/data/decimal_test.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema).build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();
        let lat = batch
            .column(1)
            .as_any()
            .downcast_ref::<Decimal32Array>()
            .unwrap();

        assert_eq!("57.653484", lat.value_as_string(0));
        assert_eq!("53.002666", lat.value_as_string(1));
        assert_eq!("52.412811", lat.value_as_string(2));
        assert_eq!("51.481583", lat.value_as_string(3));
        assert_eq!("12.123456", lat.value_as_string(4));
        assert_eq!("50.760000", lat.value_as_string(5));
        assert_eq!("0.123000", lat.value_as_string(6));
        assert_eq!("123.000000", lat.value_as_string(7));
        assert_eq!("123.000000", lat.value_as_string(8));
        assert_eq!("-50.760000", lat.value_as_string(9));

        let lng = batch
            .column(2)
            .as_any()
            .downcast_ref::<Decimal64Array>()
            .unwrap();

        assert_eq!("-3.335724", lng.value_as_string(0));
        assert_eq!("-2.179404", lng.value_as_string(1));
        assert_eq!("-1.778197", lng.value_as_string(2));
        assert_eq!("-3.179090", lng.value_as_string(3));
        assert_eq!("-3.179090", lng.value_as_string(4));
        assert_eq!("0.290472", lng.value_as_string(5));
        assert_eq!("0.290472", lng.value_as_string(6));
        assert_eq!("0.290472", lng.value_as_string(7));
        assert_eq!("0.290472", lng.value_as_string(8));
        assert_eq!("0.290472", lng.value_as_string(9));
    }

    #[test]
    fn test_csv_from_buf_reader() {
        let schema = Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]);

        let file_with_headers = File::open("test/data/uk_cities_with_headers.csv").unwrap();
        let file_without_headers = File::open("test/data/uk_cities.csv").unwrap();
        let both_files = file_with_headers
            .chain(Cursor::new("\n".to_string()))
            .chain(file_without_headers);
        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_header(true)
            .build(both_files)
            .unwrap();
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(74, batch.num_rows());
        assert_eq!(3, batch.num_columns());
    }

    #[test]
    fn test_csv_with_schema_inference() {
        let mut file = File::open("test/data/uk_cities_with_headers.csv").unwrap();

        let (schema, _) = Format::default()
            .with_header(true)
            .infer_schema(&mut file, None)
            .unwrap();

        file.rewind().unwrap();
        let builder = ReaderBuilder::new(Arc::new(schema)).with_header(true);

        let mut csv = builder.build(file).unwrap();
        let expected_schema = Schema::new(vec![
            Field::new("city", DataType::Utf8, true),
            Field::new("lat", DataType::Float64, true),
            Field::new("lng", DataType::Float64, true),
        ]);
        assert_eq!(Arc::new(expected_schema), csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        // access data from a primitive array
        let lat = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(57.653484, lat.value(0));

        // access data from a string array
        let city = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
    }

    #[test]
    fn test_csv_with_schema_inference_no_headers() {
        let mut file = File::open("test/data/uk_cities.csv").unwrap();

        let (schema, _) = Format::default().infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let mut csv = ReaderBuilder::new(Arc::new(schema)).build(file).unwrap();

        // inferred field names should be "column_{number}"
        let schema = csv.schema();
        assert_eq!("column_1", schema.field(0).name());
        assert_eq!("column_2", schema.field(1).name());
        assert_eq!("column_3", schema.field(2).name());
        let batch = csv.next().unwrap().unwrap();
        let batch_schema = batch.schema();

        assert_eq!(schema, batch_schema);
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        // access data from a primitive array
        let lat = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(57.653484, lat.value(0));

        // access data from a string array
        let city = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
    }

    #[test]
    fn test_csv_builder_with_bounds() {
        let mut file = File::open("test/data/uk_cities.csv").unwrap();

        let (schema, _) = Format::default().infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();
        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_bounds(0, 2)
            .build(file)
            .unwrap();
        let batch = csv.next().unwrap().unwrap();

        let city = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        assert_eq!("Elgin, Scotland, the UK", city.value(0));

        // Only the bounded rows were read, so index 13 is out of bounds and panics
        let result = std::panic::catch_unwind(|| city.value(13));
        assert!(result.is_err());
    }

    #[test]
    fn test_csv_with_projection() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]));

        let file = File::open("test/data/uk_cities.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_projection(vec![0, 1])
            .build(file)
            .unwrap();

        let projected_schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
        ]));
        assert_eq!(projected_schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(projected_schema, batch.schema());
        assert_eq!(37, batch.num_rows());
        assert_eq!(2, batch.num_columns());
    }

    #[test]
    fn test_csv_with_dictionary() {
        let schema = Arc::new(Schema::new(vec![
            Field::new_dictionary("city", DataType::Int32, DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]));

        let file = File::open("test/data/uk_cities.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_projection(vec![0, 1])
            .build(file)
            .unwrap();

        let projected_schema = Arc::new(Schema::new(vec![
            Field::new_dictionary("city", DataType::Int32, DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
        ]));
        assert_eq!(projected_schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(projected_schema, batch.schema());
        assert_eq!(37, batch.num_rows());
        assert_eq!(2, batch.num_columns());

        let strings = arrow_cast::cast(batch.column(0), &DataType::Utf8).unwrap();
        let strings = strings.as_string::<i32>();

        assert_eq!(strings.value(0), "Elgin, Scotland, the UK");
        assert_eq!(strings.value(4), "Eastbourne, East Sussex, UK");
        assert_eq!(strings.value(29), "Uckfield, East Sussex, UK");
    }

    #[test]
    fn test_csv_with_nullable_dictionary() {
        let offset_type = vec![
            DataType::Int8,
            DataType::Int16,
            DataType::Int32,
            DataType::Int64,
            DataType::UInt8,
            DataType::UInt16,
            DataType::UInt32,
            DataType::UInt64,
        ];
        for data_type in offset_type {
            let file = File::open("test/data/dictionary_nullable_test.csv").unwrap();
            let dictionary_type =
                DataType::Dictionary(Box::new(data_type), Box::new(DataType::Utf8));
            let schema = Arc::new(Schema::new(vec![
                Field::new("id", DataType::Utf8, false),
                Field::new("name", dictionary_type.clone(), true),
            ]));

            let mut csv = ReaderBuilder::new(schema)
                .build(file.try_clone().unwrap())
                .unwrap();

            let batch = csv.next().unwrap().unwrap();
            assert_eq!(3, batch.num_rows());
            assert_eq!(2, batch.num_columns());

            let names = arrow_cast::cast(batch.column(1), &dictionary_type).unwrap();
            assert!(!names.is_null(2));
            assert!(names.is_null(1));
        }
    }

    #[test]
    fn test_nulls() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c_int", DataType::UInt64, false),
            Field::new("c_float", DataType::Float32, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, false),
        ]));

        let file = File::open("test/data/null_test.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_header(true)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        assert!(!batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }

    #[test]
    fn test_init_nulls() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c_int", DataType::UInt64, true),
            Field::new("c_float", DataType::Float32, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
            Field::new("c_null", DataType::Null, true),
        ]));
        let file = File::open("test/data/init_null_test.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_header(true)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        assert!(batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }

    #[test]
    fn test_init_nulls_with_inference() {
        let format = Format::default().with_header(true).with_delimiter(b',');

        let mut file = File::open("test/data/init_null_test.csv").unwrap();
        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let expected_schema = Schema::new(vec![
            Field::new("c_int", DataType::Int64, true),
            Field::new("c_float", DataType::Float64, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
            Field::new("c_null", DataType::Null, true),
        ]);
        assert_eq!(schema, expected_schema);

        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        assert!(batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }

    #[test]
    fn test_custom_nulls() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c_int", DataType::UInt64, true),
            Field::new("c_float", DataType::Float32, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
        ]));

        let file = File::open("test/data/custom_null_test.csv").unwrap();

        let null_regex = Regex::new("^nil$").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_header(true)
            .with_null_regex(null_regex)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        // Values matching the `nil` pattern are decoded as null
        assert!(batch.column(0).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(batch.column(3).is_null(4));
        assert!(batch.column(2).is_null(3));
        assert!(!batch.column(2).is_null(4));
    }

    #[test]
    fn test_nulls_with_inference() {
        let mut file = File::open("test/data/various_types.csv").unwrap();
        let format = Format::default().with_header(true).with_delimiter(b'|');

        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3, 4, 5]);

        let mut csv = builder.build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();

        assert_eq!(10, batch.num_rows());
        assert_eq!(6, batch.num_columns());

        let schema = batch.schema();

        assert_eq!(&DataType::Int64, schema.field(0).data_type());
        assert_eq!(&DataType::Float64, schema.field(1).data_type());
        assert_eq!(&DataType::Float64, schema.field(2).data_type());
        assert_eq!(&DataType::Boolean, schema.field(3).data_type());
        assert_eq!(&DataType::Date32, schema.field(4).data_type());
        assert_eq!(
            &DataType::Timestamp(TimeUnit::Second, None),
            schema.field(5).data_type()
        );

        let names: Vec<&str> = schema.fields().iter().map(|x| x.name().as_str()).collect();
        assert_eq!(
            names,
            vec![
                "c_int",
                "c_float",
                "c_string",
                "c_bool",
                "c_date",
                "c_datetime"
            ]
        );

        assert!(schema.field(0).is_nullable());
        assert!(schema.field(1).is_nullable());
        assert!(schema.field(2).is_nullable());
        assert!(schema.field(3).is_nullable());
        assert!(schema.field(4).is_nullable());
        assert!(schema.field(5).is_nullable());

        assert!(!batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }

    #[test]
    fn test_custom_nulls_with_inference() {
        let mut file = File::open("test/data/custom_null_test.csv").unwrap();

        let null_regex = Regex::new("^nil$").unwrap();

        let format = Format::default()
            .with_header(true)
            .with_null_regex(null_regex);

        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let expected_schema = Schema::new(vec![
            Field::new("c_int", DataType::Int64, true),
            Field::new("c_float", DataType::Float64, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
        ]);

        assert_eq!(schema, expected_schema);

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3]);

        let mut csv = builder.build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();

        assert_eq!(5, batch.num_rows());
        assert_eq!(4, batch.num_columns());

        assert_eq!(batch.schema().as_ref(), &expected_schema);
    }

    #[test]
    fn test_scientific_notation_with_inference() {
        let mut file = File::open("test/data/scientific_notation_test.csv").unwrap();
        let format = Format::default().with_header(false).with_delimiter(b',');

        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .with_batch_size(512)
            .with_projection(vec![0, 1]);

        let mut csv = builder.build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();

        let schema = batch.schema();

        assert_eq!(&DataType::Float64, schema.field(0).data_type());
    }

    /// Reads a CSV file containing an invalid value with a fixed schema and
    /// returns the resulting error message
    fn invalid_csv_helper(file_name: &str) -> String {
        let file = File::open(file_name).unwrap();
        let schema = Schema::new(vec![
            Field::new("c_int", DataType::UInt64, false),
            Field::new("c_float", DataType::Float32, false),
            Field::new("c_string", DataType::Utf8, false),
            Field::new("c_bool", DataType::Boolean, false),
        ]);

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(true)
            .with_delimiter(b'|')
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3]);

        let mut csv = builder.build(file).unwrap();

        csv.next().unwrap().unwrap_err().to_string()
    }

    #[test]
    fn test_parse_invalid_csv_float() {
        let file_name = "test/data/various_invalid_types/invalid_float.csv";

        let error = invalid_csv_helper(file_name);
        assert_eq!("Parser error: Error while parsing value '4.x4' as type 'Float32' for column 1 at line 4. Row data: '[4,4.x4,,false]'", error);
    }

    #[test]
    fn test_parse_invalid_csv_int() {
        let file_name = "test/data/various_invalid_types/invalid_int.csv";

        let error = invalid_csv_helper(file_name);
        assert_eq!("Parser error: Error while parsing value '2.3' as type 'UInt64' for column 0 at line 2. Row data: '[2.3,2.2,2.22,false]'", error);
    }

    #[test]
    fn test_parse_invalid_csv_bool() {
        let file_name = "test/data/various_invalid_types/invalid_bool.csv";

        let error = invalid_csv_helper(file_name);
        assert_eq!("Parser error: Error while parsing value 'none' as type 'Boolean' for column 3 at line 2. Row data: '[2,2.2,2.22,none]'", error);
    }

    /// Infers the data type of a single string value
    fn infer_field_schema(string: &str) -> DataType {
        let mut v = InferredDataType::default();
        v.update(string);
        v.get()
    }

    #[test]
    fn test_infer_field_schema() {
        assert_eq!(infer_field_schema("A"), DataType::Utf8);
        assert_eq!(infer_field_schema("\"123\""), DataType::Utf8);
        assert_eq!(infer_field_schema("10"), DataType::Int64);
        assert_eq!(infer_field_schema("10.2"), DataType::Float64);
        assert_eq!(infer_field_schema(".2"), DataType::Float64);
        assert_eq!(infer_field_schema("2."), DataType::Float64);
        assert_eq!(infer_field_schema("NaN"), DataType::Float64);
        assert_eq!(infer_field_schema("nan"), DataType::Float64);
        assert_eq!(infer_field_schema("inf"), DataType::Float64);
        assert_eq!(infer_field_schema("-inf"), DataType::Float64);
        assert_eq!(infer_field_schema("true"), DataType::Boolean);
        assert_eq!(infer_field_schema("trUe"), DataType::Boolean);
        assert_eq!(infer_field_schema("false"), DataType::Boolean);
        assert_eq!(infer_field_schema("2020-11-08"), DataType::Date32);
        assert_eq!(
            infer_field_schema("2020-11-08T14:20:01"),
            DataType::Timestamp(TimeUnit::Second, None)
        );
        assert_eq!(
            infer_field_schema("2020-11-08 14:20:01"),
            DataType::Timestamp(TimeUnit::Second, None)
        );
        assert_eq!(infer_field_schema("-5.13"), DataType::Float64);
        assert_eq!(infer_field_schema("0.1300"), DataType::Float64);
        assert_eq!(
            infer_field_schema("2021-12-19 13:12:30.921"),
            DataType::Timestamp(TimeUnit::Millisecond, None)
        );
        assert_eq!(
            infer_field_schema("2021-12-19T13:12:30.123456789"),
            DataType::Timestamp(TimeUnit::Nanosecond, None)
        );
        // Integers that would overflow i64 fall back to Utf8
        assert_eq!(infer_field_schema("-9223372036854775809"), DataType::Utf8);
        assert_eq!(infer_field_schema("9223372036854775808"), DataType::Utf8);
    }

    #[test]
    fn parse_date32() {
        assert_eq!(Date32Type::parse("1970-01-01").unwrap(), 0);
        assert_eq!(Date32Type::parse("2020-03-15").unwrap(), 18336);
        assert_eq!(Date32Type::parse("1945-05-08").unwrap(), -9004);
    }

    #[test]
    fn parse_time() {
        assert_eq!(
            Time64NanosecondType::parse("12:10:01.123456789 AM"),
            Some(601_123_456_789)
        );
        assert_eq!(
            Time64MicrosecondType::parse("12:10:01.123456 am"),
            Some(601_123_456)
        );
        assert_eq!(
            Time32MillisecondType::parse("2:10:01.12 PM"),
            Some(51_001_120)
        );
        assert_eq!(Time32SecondType::parse("2:10:01 pm"), Some(51_001));
    }

    #[test]
    fn parse_date64() {
        assert_eq!(Date64Type::parse("1970-01-01T00:00:00").unwrap(), 0);
        assert_eq!(
            Date64Type::parse("2018-11-13T17:11:10").unwrap(),
            1542129070000
        );
        assert_eq!(
            Date64Type::parse("2018-11-13T17:11:10.011").unwrap(),
            1542129070011
        );
        assert_eq!(
            Date64Type::parse("1900-02-28T12:34:56").unwrap(),
            -2203932304000
        );
        assert_eq!(
            Date64Type::parse_formatted("1900-02-28 12:34:56", "%Y-%m-%d %H:%M:%S").unwrap(),
            -2203932304000
        );
        assert_eq!(
            Date64Type::parse_formatted("1900-02-28 12:34:56+0030", "%Y-%m-%d %H:%M:%S%z")
                .unwrap(),
            -2203932304000 - (30 * 60 * 1000)
        );
    }

    fn test_parse_timestamp_impl<T: ArrowTimestampType>(
        timezone: Option<Arc<str>>,
        expected: &[i64],
    ) {
        let csv = [
            "1970-01-01T00:00:00",
            "1970-01-01T00:00:00Z",
            "1970-01-01T00:00:00+02:00",
        ]
        .join("\n");
        let schema = Arc::new(Schema::new(vec![Field::new(
            "field",
            DataType::Timestamp(T::UNIT, timezone.clone()),
            true,
        )]));

        let mut decoder = ReaderBuilder::new(schema).build_decoder();

        let decoded = decoder.decode(csv.as_bytes()).unwrap();
        assert_eq!(decoded, csv.len());
        decoder.decode(&[]).unwrap();

        let batch = decoder.flush().unwrap().unwrap();
        assert_eq!(batch.num_columns(), 1);
        assert_eq!(batch.num_rows(), 3);
        let col = batch.column(0).as_primitive::<T>();
        assert_eq!(col.values(), expected);
        assert_eq!(col.data_type(), &DataType::Timestamp(T::UNIT, timezone));
    }

    #[test]
    fn test_parse_timestamp() {
        test_parse_timestamp_impl::<TimestampNanosecondType>(None, &[0, 0, -7_200_000_000_000]);
        test_parse_timestamp_impl::<TimestampNanosecondType>(
            Some("+00:00".into()),
            &[0, 0, -7_200_000_000_000],
        );
        test_parse_timestamp_impl::<TimestampNanosecondType>(
            Some("-05:00".into()),
            &[18_000_000_000_000, 0, -7_200_000_000_000],
        );
        test_parse_timestamp_impl::<TimestampMicrosecondType>(
            Some("-03".into()),
            &[10_800_000_000, 0, -7_200_000_000],
        );
        test_parse_timestamp_impl::<TimestampMillisecondType>(
            Some("-03".into()),
            &[10_800_000, 0, -7_200_000],
        );
        test_parse_timestamp_impl::<TimestampSecondType>(Some("-03".into()), &[10_800, 0, -7_200]);
    }

    #[test]
    fn test_infer_schema_from_multiple_files() {
        let mut csv1 = NamedTempFile::new().unwrap();
        let mut csv2 = NamedTempFile::new().unwrap();
        let csv3 = NamedTempFile::new().unwrap(); // empty csv file should be skipped
        let mut csv4 = NamedTempFile::new().unwrap();
        writeln!(csv1, "c1,c2,c3").unwrap();
        writeln!(csv1, "1,\"foo\",0.5").unwrap();
        writeln!(csv1, "3,\"bar\",1").unwrap();
        writeln!(csv1, "3,\"bar\",2e-06").unwrap();
        writeln!(csv2, "c1,c2,c3,c4").unwrap();
        writeln!(csv2, "10,,3.14,true").unwrap();
        writeln!(csv4, "c1,c2,c3").unwrap();
        writeln!(csv4, "10,\"foo\",").unwrap();

        let schema = infer_schema_from_files(
            &[
                csv3.path().to_str().unwrap().to_string(),
                csv1.path().to_str().unwrap().to_string(),
                csv2.path().to_str().unwrap().to_string(),
                csv4.path().to_str().unwrap().to_string(),
            ],
            b',',
            Some(4), // only csv1 and csv2 should be read
            true,
        )
        .unwrap();

        assert_eq!(schema.fields().len(), 4);
        assert!(schema.field(0).is_nullable());
        assert!(schema.field(1).is_nullable());
        assert!(schema.field(2).is_nullable());
        assert!(schema.field(3).is_nullable());

        assert_eq!(&DataType::Int64, schema.field(0).data_type());
        assert_eq!(&DataType::Utf8, schema.field(1).data_type());
        assert_eq!(&DataType::Float64, schema.field(2).data_type());
        assert_eq!(&DataType::Boolean, schema.field(3).data_type());
    }

    #[test]
    fn test_bounded() {
        let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]);
        let data = [
            vec!["0"],
            vec!["1"],
            vec!["2"],
            vec!["3"],
            vec!["4"],
            vec!["5"],
            vec!["6"],
        ];

        let data = data
            .iter()
            .map(|x| x.join(","))
            .collect::<Vec<_>>()
            .join("\n");
        let data = data.as_bytes();

        let reader = std::io::Cursor::new(data);

        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_batch_size(2)
            .with_projection(vec![0])
            .with_bounds(2, 6)
            .build_buffered(reader)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();
        let a = batch.column(0);
        let a = a.as_any().downcast_ref::<UInt32Array>().unwrap();
        assert_eq!(a, &UInt32Array::from(vec![2, 3]));

        let batch = csv.next().unwrap().unwrap();
        let a = batch.column(0);
        let a = a.as_any().downcast_ref::<UInt32Array>().unwrap();
        assert_eq!(a, &UInt32Array::from(vec![4, 5]));

        assert!(csv.next().is_none());
    }

    #[test]
    fn test_empty_projection() {
        let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]);
        let data = [vec!["0"], vec!["1"]];

        let data = data
            .iter()
            .map(|x| x.join(","))
            .collect::<Vec<_>>()
            .join("\n");

        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_batch_size(2)
            .with_projection(vec![])
            .build_buffered(Cursor::new(data.as_bytes()))
            .unwrap();

        let batch = csv.next().unwrap().unwrap();
        assert_eq!(batch.columns().len(), 0);
        assert_eq!(batch.num_rows(), 2);

        assert!(csv.next().is_none());
    }

    #[test]
    fn test_parsing_bool() {
        assert_eq!(Some(true), parse_bool("true"));
        assert_eq!(Some(true), parse_bool("tRUe"));
        assert_eq!(Some(true), parse_bool("True"));
        assert_eq!(Some(true), parse_bool("TRUE"));
        assert_eq!(None, parse_bool("t"));
        assert_eq!(None, parse_bool("T"));
        assert_eq!(None, parse_bool(""));

        assert_eq!(Some(false), parse_bool("false"));
        assert_eq!(Some(false), parse_bool("fALse"));
        assert_eq!(Some(false), parse_bool("False"));
        assert_eq!(Some(false), parse_bool("FALSE"));
        assert_eq!(None, parse_bool("f"));
        assert_eq!(None, parse_bool("F"));
        assert_eq!(None, parse_bool(""));
    }

    #[test]
    fn test_parsing_float() {
        assert_eq!(Some(12.34), Float64Type::parse("12.34"));
        assert_eq!(Some(-12.34), Float64Type::parse("-12.34"));
        assert_eq!(Some(12.0), Float64Type::parse("12"));
        assert_eq!(Some(0.0), Float64Type::parse("0"));
        assert_eq!(Some(2.0), Float64Type::parse("2."));
        assert_eq!(Some(0.2), Float64Type::parse(".2"));
        assert!(Float64Type::parse("nan").unwrap().is_nan());
        assert!(Float64Type::parse("NaN").unwrap().is_nan());
        assert!(Float64Type::parse("inf").unwrap().is_infinite());
        assert!(Float64Type::parse("inf").unwrap().is_sign_positive());
        assert!(Float64Type::parse("-inf").unwrap().is_infinite());
        assert!(Float64Type::parse("-inf").unwrap().is_sign_negative());
        assert_eq!(None, Float64Type::parse(""));
        assert_eq!(None, Float64Type::parse("dd"));
        assert_eq!(None, Float64Type::parse("12.34.56"));
    }

    #[test]
    fn test_non_std_quote() {
        let schema = Schema::new(vec![
            Field::new("text1", DataType::Utf8, false),
            Field::new("text2", DataType::Utf8, false),
        ]);
        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(false)
            .with_quote(b'~'); // default is ", change to ~

        let mut csv_text = Vec::new();
        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
        for index in 0..10 {
            let text1 = format!("id{index:}");
            let text2 = format!("value{index:}");
            csv_writer
                .write_fmt(format_args!("~{text1}~,~{text2}~\r\n"))
                .unwrap();
        }
        let mut csv_reader = std::io::Cursor::new(&csv_text);
        let mut reader = builder.build(&mut csv_reader).unwrap();
        let batch = reader.next().unwrap().unwrap();
        let col0 = batch.column(0);
        assert_eq!(col0.len(), 10);
        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col0_arr.value(0), "id0");
        let col1 = batch.column(1);
        assert_eq!(col1.len(), 10);
        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col1_arr.value(5), "value5");
    }

    #[test]
    fn test_non_std_escape() {
        let schema = Schema::new(vec![
            Field::new("text1", DataType::Utf8, false),
            Field::new("text2", DataType::Utf8, false),
        ]);
        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(false)
            .with_escape(b'\\'); // default is None, set to \

        let mut csv_text = Vec::new();
        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
        for index in 0..10 {
            let text1 = format!("id{index:}");
            let text2 = format!("value\\\"{index:}");
            csv_writer
                .write_fmt(format_args!("\"{text1}\",\"{text2}\"\r\n"))
                .unwrap();
        }
        let mut csv_reader = std::io::Cursor::new(&csv_text);
        let mut reader = builder.build(&mut csv_reader).unwrap();
        let batch = reader.next().unwrap().unwrap();
        let col0 = batch.column(0);
        assert_eq!(col0.len(), 10);
        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col0_arr.value(0), "id0");
        let col1 = batch.column(1);
        assert_eq!(col1.len(), 10);
        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col1_arr.value(5), "value\"5");
    }

    #[test]
    fn test_non_std_terminator() {
        let schema = Schema::new(vec![
            Field::new("text1", DataType::Utf8, false),
            Field::new("text2", DataType::Utf8, false),
        ]);
        // Terminate records with a bare LF rather than the default CRLF
        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(false)
            .with_terminator(b'\n');

        let mut csv_text = Vec::new();
        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
        for index in 0..10 {
            let text1 = format!("id{index:}");
            let text2 = format!("value{index:}");
            csv_writer
                .write_fmt(format_args!("\"{text1}\",\"{text2}\"\n"))
                .unwrap();
        }
        let mut csv_reader = std::io::Cursor::new(&csv_text);
        let mut reader = builder.build(&mut csv_reader).unwrap();
        let batch = reader.next().unwrap().unwrap();
        let col0 = batch.column(0);
        assert_eq!(col0.len(), 10);
        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col0_arr.value(0), "id0");
        let col1 = batch.column(1);
        assert_eq!(col1.len(), 10);
        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col1_arr.value(5), "value5");
    }

    #[test]
    fn test_header_bounds() {
        let csv = "a,b\na,b\na,b\na,b\na,b\n";
        // (bounds, has_header, expected rows); bounds select a range of data
        // rows, applied after any header line has been skipped
        let tests = [
            (None, false, 5),
            (None, true, 4),
            (Some((0, 4)), false, 4),
            (Some((1, 4)), false, 3),
            (Some((0, 4)), true, 4),
            (Some((1, 4)), true, 3),
        ];
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("a", DataType::Utf8, false),
        ]));

        for (idx, (bounds, has_header, expected)) in tests.into_iter().enumerate() {
            let mut reader = ReaderBuilder::new(schema.clone()).with_header(has_header);
            if let Some((start, end)) = bounds {
                reader = reader.with_bounds(start, end);
            }
            let b = reader
                .build_buffered(Cursor::new(csv.as_bytes()))
                .unwrap()
                .next()
                .unwrap()
                .unwrap();
            assert_eq!(b.num_rows(), expected, "{idx}");
        }
    }

    #[test]
    fn test_null_boolean() {
        let csv = "true,false\nFalse,True\n,True\nFalse,";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Boolean, true),
            Field::new("a", DataType::Boolean, true),
        ]));

        let b = ReaderBuilder::new(schema)
            .build_buffered(Cursor::new(csv.as_bytes()))
            .unwrap()
            .next()
            .unwrap()
            .unwrap();

        assert_eq!(b.num_rows(), 4);
        assert_eq!(b.num_columns(), 2);

        let c = b.column(0).as_boolean();
        assert_eq!(c.null_count(), 1);
        assert!(c.value(0));
        assert!(!c.value(1));
        assert!(c.is_null(2));
        assert!(!c.value(3));

        let c = b.column(1).as_boolean();
        assert_eq!(c.null_count(), 1);
        assert!(!c.value(0));
        assert!(c.value(1));
        assert!(c.value(2));
        assert!(c.is_null(3));
    }

    #[test]
    fn test_truncated_rows() {
        let data = "a,b,c\n1,2,3\n4,5\n\n6,7,8";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
            Field::new("c", DataType::Int32, true),
        ]));

        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(true)
            .build(Cursor::new(data))
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(batches.is_ok());
        let batch = batches.unwrap().into_iter().next().unwrap();
        assert_eq!(batch.num_rows(), 3);

        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(false)
            .build(Cursor::new(data))
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(match batches {
            Err(ArrowError::CsvError(e)) => e.to_string().contains("incorrect number of fields"),
            _ => false,
        });
    }

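    // Not part of the original suite: a compact sketch of what allowing
    // truncated rows means for the decoded values, namely that missing
    // trailing fields surface as nulls (the affected fields must be
    // nullable, as the following test demonstrates in the error case).
    #[test]
    fn test_truncated_rows_pad_nulls_sketch() {
        let data = "1,2,3\n4,5";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
            Field::new("c", DataType::Int32, true),
        ]));
        let reader = ReaderBuilder::new(schema)
            .with_truncated_rows(true)
            .build(Cursor::new(data))
            .unwrap();
        let batch = reader.collect::<Result<Vec<_>, _>>().unwrap().remove(0);
        let c = batch.column(2).as_primitive::<Int32Type>();
        assert_eq!(c.value(0), 3);
        // the second row was truncated after "4,5", so column c is null
        assert!(c.is_null(1));
    }
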
    #[test]
    fn test_truncated_rows_csv() {
        let file = File::open("test/data/truncated_rows.csv").unwrap();
        let schema = Arc::new(Schema::new(vec![
            Field::new("Name", DataType::Utf8, true),
            Field::new("Age", DataType::UInt32, true),
            Field::new("Occupation", DataType::Utf8, true),
            Field::new("DOB", DataType::Date32, true),
        ]));
        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_batch_size(24)
            .with_truncated_rows(true);
        let csv = reader.build(file).unwrap();
        let batches = csv.collect::<Result<Vec<_>, _>>().unwrap();

        assert_eq!(batches.len(), 1);
        let batch = &batches[0];
        assert_eq!(batch.num_rows(), 6);
        assert_eq!(batch.num_columns(), 4);
        let name = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let age = batch
            .column(1)
            .as_any()
            .downcast_ref::<UInt32Array>()
            .unwrap();
        let occupation = batch
            .column(2)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let dob = batch
            .column(3)
            .as_any()
            .downcast_ref::<Date32Array>()
            .unwrap();

        assert_eq!(name.value(0), "A1");
        assert_eq!(name.value(1), "B2");
        assert!(name.is_null(2));
        assert_eq!(name.value(3), "C3");
        assert_eq!(name.value(4), "D4");
        assert_eq!(name.value(5), "E5");

        assert_eq!(age.value(0), 34);
        assert_eq!(age.value(1), 29);
        assert!(age.is_null(2));
        assert_eq!(age.value(3), 45);
        assert!(age.is_null(4));
        assert_eq!(age.value(5), 31);

        assert_eq!(occupation.value(0), "Engineer");
        assert_eq!(occupation.value(1), "Doctor");
        assert!(occupation.is_null(2));
        assert_eq!(occupation.value(3), "Artist");
        assert!(occupation.is_null(4));
        assert!(occupation.is_null(5));

        assert_eq!(dob.value(0), 5675);
        assert!(dob.is_null(1));
        assert!(dob.is_null(2));
        assert_eq!(dob.value(3), -1858);
        assert!(dob.is_null(4));
        assert!(dob.is_null(5));
    }

    #[test]
    fn test_truncated_rows_not_nullable_error() {
        let data = "a,b,c\n1,2,3\n4,5";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));

        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(true)
            .build(Cursor::new(data))
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(match batches {
            Err(ArrowError::InvalidArgumentError(e)) =>
                e.to_string().contains("contains null values"),
            _ => false,
        });
    }

    #[test]
    fn test_buffered() {
        let tests = [
            ("test/data/uk_cities.csv", false, 37),
            ("test/data/various_types.csv", true, 10),
            ("test/data/decimal_test.csv", false, 10),
        ];

        for (path, has_header, expected_rows) in tests {
            let (schema, _) = Format::default()
                .infer_schema(File::open(path).unwrap(), None)
                .unwrap();
            let schema = Arc::new(schema);

            for batch_size in [1, 4] {
                for capacity in [1, 3, 7, 100] {
                    let reader = ReaderBuilder::new(schema.clone())
                        .with_batch_size(batch_size)
                        .with_header(has_header)
                        .build(File::open(path).unwrap())
                        .unwrap();

                    let expected = reader.collect::<Result<Vec<_>, _>>().unwrap();

                    assert_eq!(
                        expected.iter().map(|x| x.num_rows()).sum::<usize>(),
                        expected_rows
                    );

                    let buffered =
                        std::io::BufReader::with_capacity(capacity, File::open(path).unwrap());

                    let reader = ReaderBuilder::new(schema.clone())
                        .with_batch_size(batch_size)
                        .with_header(has_header)
                        .build_buffered(buffered)
                        .unwrap();

                    let actual = reader.collect::<Result<Vec<_>, _>>().unwrap();
                    assert_eq!(expected, actual)
                }
            }
        }
    }

    // Assert that decoding `csv` fails with the `expected` error message,
    // for both `Utf8` and `Utf8View` output columns
    fn err_test(csv: &[u8], expected: &str) {
        fn err_test_with_schema(csv: &[u8], expected: &str, schema: Arc<Schema>) {
            let buffer = std::io::BufReader::with_capacity(2, Cursor::new(csv));
            let b = ReaderBuilder::new(schema)
                .with_batch_size(2)
                .build_buffered(buffer)
                .unwrap();
            let err = b.collect::<Result<Vec<_>, _>>().unwrap_err().to_string();
            assert_eq!(err, expected)
        }

        let schema_utf8 = Arc::new(Schema::new(vec![
            Field::new("text1", DataType::Utf8, true),
            Field::new("text2", DataType::Utf8, true),
        ]));
        err_test_with_schema(csv, expected, schema_utf8);

        let schema_utf8view = Arc::new(Schema::new(vec![
            Field::new("text1", DataType::Utf8View, true),
            Field::new("text2", DataType::Utf8View, true),
        ]));
        err_test_with_schema(csv, expected, schema_utf8view);
    }

    #[test]
    fn test_invalid_utf8() {
        err_test(
            b"sdf,dsfg\ndfd,hgh\xFFue\n,sds\nFalhghse,",
            "Csv error: Encountered invalid UTF-8 data for line 2 and field 2",
        );

        err_test(
            b"sdf,dsfg\ndksdk,jf\nd\xFFfd,hghue\n,sds\nFalhghse,",
            "Csv error: Encountered invalid UTF-8 data for line 3 and field 1",
        );

        err_test(
            b"sdf,dsfg\ndksdk,jf\ndsdsfd,hghue\n,sds\nFalhghse,\xFF",
            "Csv error: Encountered invalid UTF-8 data for line 5 and field 2",
        );

        err_test(
            b"\xFFsdf,dsfg\ndksdk,jf\ndsdsfd,hghue\n,sds\nFalhghse,\xFF",
            "Csv error: Encountered invalid UTF-8 data for line 1 and field 1",
        );
    }

    // Wraps a reader, recording the number and sizes of `fill_buf` calls so
    // tests can observe the buffering behaviour of the CSV reader
    struct InstrumentedRead<R> {
        r: R,
        fill_count: usize,
        fill_sizes: Vec<usize>,
    }

    impl<R> InstrumentedRead<R> {
        fn new(r: R) -> Self {
            Self {
                r,
                fill_count: 0,
                fill_sizes: vec![],
            }
        }
    }

    impl<R: Seek> Seek for InstrumentedRead<R> {
        fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
            self.r.seek(pos)
        }
    }

    impl<R: BufRead> Read for InstrumentedRead<R> {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            self.r.read(buf)
        }
    }

    impl<R: BufRead> BufRead for InstrumentedRead<R> {
        fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
            self.fill_count += 1;
            let buf = self.r.fill_buf()?;
            self.fill_sizes.push(buf.len());
            Ok(buf)
        }

        fn consume(&mut self, amt: usize) {
            self.r.consume(amt)
        }
    }

    #[test]
    fn test_io() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("b", DataType::Utf8, false),
        ]));
        let csv = "foo,bar\nbaz,foo\na,b\nc,d";
        let mut read = InstrumentedRead::new(Cursor::new(csv.as_bytes()));
        let reader = ReaderBuilder::new(schema)
            .with_batch_size(3)
            .build_buffered(&mut read)
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].num_rows(), 3);
        assert_eq!(batches[1].num_rows(), 1);

        // The entire 23-byte input is returned by the first fill, the
        // remaining 3 bytes ("c,d") by the second, and the final two
        // zero-length fills confirm EOF
        assert_eq!(&read.fill_sizes, &[23, 3, 0, 0]);
        assert_eq!(read.fill_count, 4);
    }

    #[test]
    fn test_inference() {
        let cases: &[(&[&str], DataType)] = &[
            (&[], DataType::Null),
            (&["false", "12"], DataType::Utf8),
            (&["12", "cupcakes"], DataType::Utf8),
            (&["12", "12.4"], DataType::Float64),
            (&["14050", "24332"], DataType::Int64),
            (&["14050.0", "true"], DataType::Utf8),
            (&["14050", "2020-03-19 00:00:00"], DataType::Utf8),
            (&["14050", "2340.0", "2020-03-19 00:00:00"], DataType::Utf8),
            (
                &["2020-03-19 02:00:00", "2020-03-19 00:00:00"],
                DataType::Timestamp(TimeUnit::Second, None),
            ),
            (&["2020-03-19", "2020-03-20"], DataType::Date32),
            (
                &["2020-03-19", "2020-03-19 02:00:00", "2020-03-19 00:00:00"],
                DataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                &[
                    "2020-03-19",
                    "2020-03-19 02:00:00",
                    "2020-03-19 00:00:00.000",
                ],
                DataType::Timestamp(TimeUnit::Millisecond, None),
            ),
            (
                &[
                    "2020-03-19",
                    "2020-03-19 02:00:00",
                    "2020-03-19 00:00:00.000000",
                ],
                DataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                &["2020-03-19 02:00:00+02:00", "2020-03-19 02:00:00Z"],
                DataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                &[
                    "2020-03-19",
                    "2020-03-19 02:00:00+02:00",
                    "2020-03-19 02:00:00Z",
                    "2020-03-19 02:00:00.12Z",
                ],
                DataType::Timestamp(TimeUnit::Millisecond, None),
            ),
            (
                &[
                    "2020-03-19",
                    "2020-03-19 02:00:00.000000000",
                    "2020-03-19 00:00:00.000000",
                ],
                DataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
        ];

        for (values, expected) in cases {
            let mut t = InferredDataType::default();
            for v in *values {
                t.update(v)
            }
            assert_eq!(&t.get(), expected, "{values:?}")
        }
    }

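    // Not part of the original suite: a sketch of the "leading quote" rule in
    // `InferredDataType::update`, whereby a value that begins with `"` is
    // always inferred as `Utf8`, even if its contents would otherwise look
    // numeric.
    #[test]
    fn test_inference_quoted_forces_utf8_sketch() {
        let mut t = InferredDataType::default();
        t.update("\"123\"");
        assert_eq!(t.get(), DataType::Utf8);
    }
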
    #[test]
    fn test_record_length_mismatch() {
        let csv = "\
        a,b,c\n\
        1,2,3\n\
        4,5\n\
        6,7,8";
        let mut read = Cursor::new(csv.as_bytes());
        let result = Format::default()
            .with_header(true)
            .infer_schema(&mut read, None);
        assert!(result.is_err());
        assert_eq!(result.err().unwrap().to_string(), "Csv error: Encountered unequal lengths between records on CSV file. Expected 3 records, found 2 records at line 3");
    }

    #[test]
    fn test_comment() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int8, false),
            Field::new("b", DataType::Int8, false),
        ]);

        let csv = "# comment1 \n1,2\n#comment2\n11,22";
        let mut read = Cursor::new(csv.as_bytes());
        let reader = ReaderBuilder::new(Arc::new(schema))
            .with_comment(b'#')
            .build(&mut read)
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(batches.len(), 1);
        let b = batches.first().unwrap();
        assert_eq!(b.num_columns(), 2);
        assert_eq!(
            b.column(0)
                .as_any()
                .downcast_ref::<Int8Array>()
                .unwrap()
                .values(),
            &vec![1, 11]
        );
        assert_eq!(
            b.column(1)
                .as_any()
                .downcast_ref::<Int8Array>()
                .unwrap()
                .values(),
            &vec![2, 22]
        );
    }

    #[test]
    fn test_parse_string_view_single_column() {
        // Views inline strings of up to 12 bytes; the middle value is longer,
        // so it must be spilled to a data buffer
        let csv = ["foo", "something_cannot_be_inlined", "foobar"].join("\n");
        let schema = Arc::new(Schema::new(vec![Field::new(
            "c1",
            DataType::Utf8View,
            true,
        )]));

        let mut decoder = ReaderBuilder::new(schema).build_decoder();

        let decoded = decoder.decode(csv.as_bytes()).unwrap();
        assert_eq!(decoded, csv.len());
        decoder.decode(&[]).unwrap();

        let batch = decoder.flush().unwrap().unwrap();
        assert_eq!(batch.num_columns(), 1);
        assert_eq!(batch.num_rows(), 3);
        let col = batch.column(0).as_string_view();
        assert_eq!(col.data_type(), &DataType::Utf8View);
        assert_eq!(col.value(0), "foo");
        assert_eq!(col.value(1), "something_cannot_be_inlined");
        assert_eq!(col.value(2), "foobar");
    }

    #[test]
    fn test_parse_string_view_multi_column() {
        let csv = ["foo,", ",something_cannot_be_inlined", "foobarfoobar,bar"].join("\n");
        let schema = Arc::new(Schema::new(vec![
            Field::new("c1", DataType::Utf8View, true),
            Field::new("c2", DataType::Utf8View, true),
        ]));

        let mut decoder = ReaderBuilder::new(schema).build_decoder();

        let decoded = decoder.decode(csv.as_bytes()).unwrap();
        assert_eq!(decoded, csv.len());
        decoder.decode(&[]).unwrap();

        let batch = decoder.flush().unwrap().unwrap();
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.num_rows(), 3);
        let c1 = batch.column(0).as_string_view();
        let c2 = batch.column(1).as_string_view();
        assert_eq!(c1.data_type(), &DataType::Utf8View);
        assert_eq!(c2.data_type(), &DataType::Utf8View);

        assert!(!c1.is_null(0));
        assert!(c1.is_null(1));
        assert!(!c1.is_null(2));
        assert_eq!(c1.value(0), "foo");
        assert_eq!(c1.value(2), "foobarfoobar");

        assert!(c2.is_null(0));
        assert!(!c2.is_null(1));
        assert!(!c2.is_null(2));
        assert_eq!(c2.value(1), "something_cannot_be_inlined");
        assert_eq!(c2.value(2), "bar");
    }
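
    // Not part of the original suite: a sketch of incremental decoding with
    // `Decoder`, feeding the input in arbitrary chunks as the two tests above
    // do in a single call. `decode` returns the number of bytes it consumed,
    // an empty buffer signals EOF, and `flush` yields the completed batch.
    #[test]
    fn test_decoder_incremental_sketch() {
        let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Utf8, true)]));
        let mut decoder = ReaderBuilder::new(schema).build_decoder();
        for chunk in ["fo", "o\nba", "r\n"] {
            let mut buf = chunk.as_bytes();
            // keep feeding until the decoder has consumed the whole chunk
            while !buf.is_empty() {
                let read = decoder.decode(buf).unwrap();
                buf = &buf[read..];
            }
        }
        decoder.decode(&[]).unwrap(); // EOF
        let batch = decoder.flush().unwrap().unwrap();
        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.column(0).as_string::<i32>().value(1), "bar");
    }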
}