mod records;

use arrow_array::builder::{NullBuilder, PrimitiveBuilder};
use arrow_array::types::*;
use arrow_array::*;
use arrow_cast::parse::{parse_decimal, string_to_datetime, Parser};
use arrow_schema::*;
use chrono::{TimeZone, Utc};
use csv::StringRecord;
use lazy_static::lazy_static;
use regex::{Regex, RegexSet};
use std::fmt::{self, Debug};
use std::fs::File;
use std::io::{BufRead, BufReader as StdBufReader, Read};
use std::sync::Arc;

use crate::map_csv_error;
use crate::reader::records::{RecordDecoder, StringRecords};
use arrow_array::timezone::Tz;
lazy_static! {
    // The order of these patterns must match the bit positions used by
    // `InferredDataType::packed` below
    static ref REGEX_SET: RegexSet = RegexSet::new([
        r"(?i)^(true)$|^(false)$(?-i)",                                // Boolean
        r"^-?(\d+)$",                                                  // Int64
        r"^-?((\d*\.\d+|\d+\.\d*)([eE][-+]?\d+)?|\d+([eE][-+]?\d+))$", // Float64
        r"^\d{4}-\d\d-\d\d$",                                          // Date32
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d(?:[^\d\.].*)?$",          // Timestamp(Second)
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,3}(?:[^\d].*)?$",   // Timestamp(Millisecond)
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,6}(?:[^\d].*)?$",   // Timestamp(Microsecond)
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,9}(?:[^\d].*)?$",   // Timestamp(Nanosecond)
    ]).unwrap();
}

/// A wrapper over `Option<Regex>` used to decide whether a value is `NULL`.
#[derive(Debug, Clone, Default)]
struct NullRegex(Option<Regex>);

impl NullRegex {
    /// Returns true if the value should be considered `NULL`, either because
    /// it matches the provided regex, or, with no regex, because it is empty.
    #[inline]
    fn is_null(&self, s: &str) -> bool {
        match &self.0 {
            Some(r) => r.is_match(s),
            None => s.is_empty(),
        }
    }
}

#[derive(Default, Copy, Clone)]
struct InferredDataType {
    /// Packed booleans indicating the candidate types, one bit per pattern in
    /// [`REGEX_SET`]:
    ///
    /// 0 - Boolean
    /// 1 - Int64
    /// 2 - Float64
    /// 3 - Date32
    /// 4 - Timestamp(Second)
    /// 5 - Timestamp(Millisecond)
    /// 6 - Timestamp(Microsecond)
    /// 7 - Timestamp(Nanosecond)
    /// 8 - Utf8
    packed: u16,
}

impl InferredDataType {
    /// Returns the inferred data type
    fn get(&self) -> DataType {
        match self.packed {
            0 => DataType::Null,
            1 => DataType::Boolean,
            2 => DataType::Int64,
            4 | 6 => DataType::Float64, // Promote Int64 to Float64
            b if b != 0 && (b & !0b11111000) == 0 => match b.leading_zeros() {
                // Promote to the highest-precision temporal type observed
                8 => DataType::Timestamp(TimeUnit::Nanosecond, None),
                9 => DataType::Timestamp(TimeUnit::Microsecond, None),
                10 => DataType::Timestamp(TimeUnit::Millisecond, None),
                11 => DataType::Timestamp(TimeUnit::Second, None),
                12 => DataType::Date32,
                _ => unreachable!(),
            },
            _ => DataType::Utf8,
        }
    }

    /// Updates the [`InferredDataType`] with the given string
    fn update(&mut self, string: &str) {
        self.packed |= if string.starts_with('"') {
            1 << 8 // Utf8
        } else if let Some(m) = REGEX_SET.matches(string).into_iter().next() {
            if m == 1 && string.len() >= 19 && string.parse::<i64>().is_err() {
                // Values that would overflow an i64 fall back to Utf8
                1 << 8
            } else {
                1 << m
            }
        } else if string == "NaN" || string == "nan" || string == "inf" || string == "-inf" {
            1 << 2 // Float64
        } else {
            1 << 8 // Utf8
        }
    }
}
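
// An illustrative sketch (not part of the API surface): each call to `update`
// ORs in the bit for the first pattern in `REGEX_SET` that the value matches,
// and `get` then resolves the accumulated bits to a single type:
//
//     let mut t = InferredDataType::default();
//     t.update("1");   // sets bit 1 (Int64)
//     t.update("2.5"); // sets bit 2 (Float64)
//     assert_eq!(t.get(), DataType::Float64); // Int64 | Float64 => Float64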

/// The format specification for a CSV file
#[derive(Debug, Clone, Default)]
pub struct Format {
    header: bool,
    delimiter: Option<u8>,
    escape: Option<u8>,
    quote: Option<u8>,
    terminator: Option<u8>,
    comment: Option<u8>,
    null_regex: NullRegex,
    truncated_rows: bool,
}
impl Format {
    /// Specify whether the CSV file has a header, defaults to `false`
    pub fn with_header(mut self, has_header: bool) -> Self {
        self.header = has_header;
        self
    }

    /// Specify a custom delimiter character, defaults to comma `','`
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.delimiter = Some(delimiter);
        self
    }

    /// Specify an escape character
    pub fn with_escape(mut self, escape: u8) -> Self {
        self.escape = Some(escape);
        self
    }

    /// Specify a custom quote character, defaults to double quote `'"'`
    pub fn with_quote(mut self, quote: u8) -> Self {
        self.quote = Some(quote);
        self
    }

    /// Specify a custom terminator character, defaults to CRLF
    pub fn with_terminator(mut self, terminator: u8) -> Self {
        self.terminator = Some(terminator);
        self
    }

    /// Specify a comment character; lines starting with this character are ignored
    pub fn with_comment(mut self, comment: u8) -> Self {
        self.comment = Some(comment);
        self
    }

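    /// Provide a regex matching values that should be treated as null.
    ///
    /// Without a regex, only the empty string is treated as null. A minimal
    /// sketch, matching the `nil` marker also used by `test_custom_nulls` below:
    ///
    /// ```ignore
    /// let format = Format::default().with_null_regex(Regex::new("^nil$").unwrap());
    /// ```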
    pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
        self.null_regex = NullRegex(Some(null_regex));
        self
    }

    /// Whether to allow truncated rows when parsing.
    ///
    /// By default this is `false` and parsing errors if rows have differing
    /// lengths. When `true`, rows with fewer than the expected number of
    /// columns are padded with nulls; non-nullable fields still produce an error.
    pub fn with_truncated_rows(mut self, allow: bool) -> Self {
        self.truncated_rows = allow;
        self
    }

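    /// Infer the schema of the CSV data in `reader`, reading up to `max_records`
    /// records (all records if `None`), and return the inferred [`Schema`]
    /// together with the number of records read.
    ///
    /// All inferred fields are nullable. A minimal usage sketch (the CSV
    /// content here is made up):
    ///
    /// ```ignore
    /// use std::io::Cursor;
    ///
    /// let csv = "city,lat\nOslo,59.91\n";
    /// let (schema, read) = Format::default()
    ///     .with_header(true)
    ///     .infer_schema(Cursor::new(csv), None)
    ///     .unwrap();
    /// assert_eq!(read, 1); // one record after the header
    /// ```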
    pub fn infer_schema<R: Read>(
        &self,
        reader: R,
        max_records: Option<usize>,
    ) -> Result<(Schema, usize), ArrowError> {
        let mut csv_reader = self.build_reader(reader);

        // Get the header; without one, the first record determines the column
        // count and default `column_{i}` names are generated
        let headers: Vec<String> = if self.header {
            let headers = &csv_reader.headers().map_err(map_csv_error)?.clone();
            headers.iter().map(|s| s.to_string()).collect()
        } else {
            let first_record_count = &csv_reader.headers().map_err(map_csv_error)?.len();
            (0..*first_record_count)
                .map(|i| format!("column_{}", i + 1))
                .collect()
        };

        let header_length = headers.len();
        // Keep track of the inferred data type for each column
        let mut column_types: Vec<InferredDataType> = vec![Default::default(); header_length];

        let mut records_count = 0;

        let mut record = StringRecord::new();
        let max_records = max_records.unwrap_or(usize::MAX);
        while records_count < max_records {
            if !csv_reader.read_record(&mut record).map_err(map_csv_error)? {
                break;
            }
            records_count += 1;

            // Update the type of each column based on its non-null values
            for (i, column_type) in column_types.iter_mut().enumerate().take(header_length) {
                if let Some(string) = record.get(i) {
                    if !self.null_regex.is_null(string) {
                        column_type.update(string)
                    }
                }
            }
        }

        // All inferred fields are marked nullable
        let fields: Fields = column_types
            .iter()
            .zip(&headers)
            .map(|(inferred, field_name)| Field::new(field_name, inferred.get(), true))
            .collect();

        Ok((Schema::new(fields), records_count))
    }

    /// Build a [`csv::Reader`] for this [`Format`]
    fn build_reader<R: Read>(&self, reader: R) -> csv::Reader<R> {
        let mut builder = csv::ReaderBuilder::new();
        builder.has_headers(self.header);
        builder.flexible(self.truncated_rows);

        if let Some(c) = self.delimiter {
            builder.delimiter(c);
        }
        builder.escape(self.escape);
        if let Some(c) = self.quote {
            builder.quote(c);
        }
        if let Some(t) = self.terminator {
            builder.terminator(csv::Terminator::Any(t));
        }
        if let Some(comment) = self.comment {
            builder.comment(Some(comment));
        }
        builder.from_reader(reader)
    }

    /// Build a [`csv_core::Reader`] for this [`Format`]
    fn build_parser(&self) -> csv_core::Reader {
        let mut builder = csv_core::ReaderBuilder::new();
        builder.escape(self.escape);
        builder.comment(self.comment);

        if let Some(c) = self.delimiter {
            builder.delimiter(c);
        }
        if let Some(c) = self.quote {
            builder.quote(c);
        }
        if let Some(t) = self.terminator {
            builder.terminator(csv_core::Terminator::Any(t));
        }
        builder.build()
    }
}

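/// Infer the schema from a list of CSV files by reading through the first
/// `max_read_records` records in total, in the given file order, and merging
/// the per-file schemas with [`Schema::try_merge`]. Files that yield no
/// records are skipped.
///
/// A sketch of a call (the file paths are hypothetical):
///
/// ```ignore
/// let schema = infer_schema_from_files(
///     &["part-0.csv".to_string(), "part-1.csv".to_string()],
///     b',',
///     Some(100), // read at most 100 records across all files
///     true,      // the files have headers
/// )?;
/// ```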
pub fn infer_schema_from_files(
    files: &[String],
    delimiter: u8,
    max_read_records: Option<usize>,
    has_header: bool,
) -> Result<Schema, ArrowError> {
    let mut schemas = vec![];
    let mut records_to_read = max_read_records.unwrap_or(usize::MAX);
    let format = Format {
        delimiter: Some(delimiter),
        header: has_header,
        ..Default::default()
    };

    for fname in files.iter() {
        let f = File::open(fname)?;
        let (schema, records_read) = format.infer_schema(f, Some(records_to_read))?;
        if records_read == 0 {
            continue;
        }
        schemas.push(schema.clone());
        records_to_read -= records_read;
        if records_to_read == 0 {
            break;
        }
    }

    Schema::try_merge(schemas)
}

/// Bounds on the rows to read, as an optional `(start, end)` pair
type Bounds = Option<(usize, usize)>;

/// A CSV reader over a [`Read`], created by a [`ReaderBuilder`]
pub type Reader<R> = BufReader<StdBufReader<R>>;

/// A CSV reader over a [`BufRead`], created by a [`ReaderBuilder`]
pub struct BufReader<R> {
    /// The underlying buffered reader supplying the CSV bytes
    reader: R,

    /// The decoder that turns those bytes into [`RecordBatch`]es
    decoder: Decoder,
}

impl<R> fmt::Debug for BufReader<R>
where
    R: BufRead,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Reader")
            .field("decoder", &self.decoder)
            .finish()
    }
}

impl<R: Read> Reader<R> {
    /// Returns the schema of the reader, taking any projection into account;
    /// useful for getting the schema without reading record batches
    pub fn schema(&self) -> SchemaRef {
        match &self.decoder.projection {
            Some(projection) => {
                let fields = self.decoder.schema.fields();
                let projected = projection.iter().map(|i| fields[*i].clone());
                Arc::new(Schema::new(projected.collect::<Fields>()))
            }
            None => self.decoder.schema.clone(),
        }
    }
}

impl<R: BufRead> BufReader<R> {
    fn read(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        loop {
            let buf = self.reader.fill_buf()?;
            let decoded = self.decoder.decode(buf)?;
            self.reader.consume(decoded);
            // Yield a batch once the decoder is full or the input is exhausted
            if decoded == 0 || self.decoder.capacity() == 0 {
                break;
            }
        }

        self.decoder.flush()
    }
}

impl<R: BufRead> Iterator for BufReader<R> {
    type Item = Result<RecordBatch, ArrowError>;

    fn next(&mut self) -> Option<Self::Item> {
        self.read().transpose()
    }
}

impl<R: BufRead> RecordBatchReader for BufReader<R> {
    fn schema(&self) -> SchemaRef {
        self.decoder.schema.clone()
    }
}

/// A push-based interface for decoding CSV data from an arbitrary byte stream
#[derive(Debug)]
pub struct Decoder {
    /// The schema of the returned records
    schema: SchemaRef,

    /// An optional projection of column indices
    projection: Option<Vec<usize>>,

    /// The maximum number of records to read in a single batch
    batch_size: usize,

    /// The number of records to skip before yielding any
    to_skip: usize,

    /// The current line number, used for error reporting
    line_number: usize,

    /// The line number at which to stop reading
    end: usize,

    /// The decoder for individual records
    record_decoder: RecordDecoder,

    /// A regex matching values that should be treated as null
    null_regex: NullRegex,
}
impl Decoder {
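    /// Decode records from `buf`, returning the number of bytes consumed.
    ///
    /// This is a push-based API: the caller feeds arbitrarily chunked bytes and
    /// calls [`Self::flush`] to obtain a [`RecordBatch`]. There is no
    /// requirement that `buf` contain a whole number of records. A sketch of a
    /// typical decode loop over some `reader: impl BufRead` (this mirrors what
    /// [`BufReader::read`] above does):
    ///
    /// ```ignore
    /// loop {
    ///     let buf = reader.fill_buf()?;
    ///     let consumed = decoder.decode(buf)?;
    ///     reader.consume(consumed);
    ///     // A batch is ready once the decoder is full or the input is exhausted
    ///     if consumed == 0 || decoder.capacity() == 0 {
    ///         break;
    ///     }
    /// }
    /// let batch = decoder.flush()?;
    /// ```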
    pub fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
        if self.to_skip != 0 {
            // Skip in units of at most `batch_size` to avoid over-allocating buffers
            let to_skip = self.to_skip.min(self.batch_size);
            let (skipped, bytes) = self.record_decoder.decode(buf, to_skip)?;
            self.to_skip -= skipped;
            self.record_decoder.clear();
            return Ok(bytes);
        }

        let to_read = self.batch_size.min(self.end - self.line_number) - self.record_decoder.len();
        let (_, bytes) = self.record_decoder.decode(buf, to_read)?;
        Ok(bytes)
    }

    /// Flushes the currently buffered data to a [`RecordBatch`], returning
    /// `Ok(None)` if no rows are buffered
    pub fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        if self.record_decoder.is_empty() {
            return Ok(None);
        }

        let rows = self.record_decoder.flush()?;
        let batch = parse(
            &rows,
            self.schema.fields(),
            Some(self.schema.metadata.clone()),
            self.projection.as_ref(),
            self.line_number,
            &self.null_regex,
        )?;
        self.line_number += rows.len();
        Ok(Some(batch))
    }

    /// Returns the number of records that can be read before this [`Decoder`]
    /// is full and [`Self::flush`] should be called
    pub fn capacity(&self) -> usize {
        self.batch_size - self.record_decoder.len()
    }
}

/// Parses a slice of [`StringRecords`] into a [`RecordBatch`]
fn parse(
    rows: &StringRecords<'_>,
    fields: &Fields,
    metadata: Option<std::collections::HashMap<String, String>>,
    projection: Option<&Vec<usize>>,
    line_number: usize,
    null_regex: &NullRegex,
) -> Result<RecordBatch, ArrowError> {
    let projection: Vec<usize> = match projection {
        Some(v) => v.clone(),
        None => fields.iter().enumerate().map(|(i, _)| i).collect(),
    };

    let arrays: Result<Vec<ArrayRef>, _> = projection
        .iter()
        .map(|i| {
            let i = *i;
            let field = &fields[i];
            match field.data_type() {
                DataType::Boolean => build_boolean_array(line_number, rows, i, null_regex),
                DataType::Decimal128(precision, scale) => build_decimal_array::<Decimal128Type>(
                    line_number,
                    rows,
                    i,
                    *precision,
                    *scale,
                    null_regex,
                ),
                DataType::Decimal256(precision, scale) => build_decimal_array::<Decimal256Type>(
                    line_number,
                    rows,
                    i,
                    *precision,
                    *scale,
                    null_regex,
                ),
                DataType::Int8 => {
                    build_primitive_array::<Int8Type>(line_number, rows, i, null_regex)
                }
                DataType::Int16 => {
                    build_primitive_array::<Int16Type>(line_number, rows, i, null_regex)
                }
                DataType::Int32 => {
                    build_primitive_array::<Int32Type>(line_number, rows, i, null_regex)
                }
                DataType::Int64 => {
                    build_primitive_array::<Int64Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt8 => {
                    build_primitive_array::<UInt8Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt16 => {
                    build_primitive_array::<UInt16Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt32 => {
                    build_primitive_array::<UInt32Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt64 => {
                    build_primitive_array::<UInt64Type>(line_number, rows, i, null_regex)
                }
                DataType::Float32 => {
                    build_primitive_array::<Float32Type>(line_number, rows, i, null_regex)
                }
                DataType::Float64 => {
                    build_primitive_array::<Float64Type>(line_number, rows, i, null_regex)
                }
                DataType::Date32 => {
                    build_primitive_array::<Date32Type>(line_number, rows, i, null_regex)
                }
                DataType::Date64 => {
                    build_primitive_array::<Date64Type>(line_number, rows, i, null_regex)
                }
                DataType::Time32(TimeUnit::Second) => {
                    build_primitive_array::<Time32SecondType>(line_number, rows, i, null_regex)
                }
                DataType::Time32(TimeUnit::Millisecond) => {
                    build_primitive_array::<Time32MillisecondType>(line_number, rows, i, null_regex)
                }
                DataType::Time64(TimeUnit::Microsecond) => {
                    build_primitive_array::<Time64MicrosecondType>(line_number, rows, i, null_regex)
                }
                DataType::Time64(TimeUnit::Nanosecond) => {
                    build_primitive_array::<Time64NanosecondType>(line_number, rows, i, null_regex)
                }
                DataType::Timestamp(TimeUnit::Second, tz) => {
                    build_timestamp_array::<TimestampSecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Timestamp(TimeUnit::Millisecond, tz) => {
                    build_timestamp_array::<TimestampMillisecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Timestamp(TimeUnit::Microsecond, tz) => {
                    build_timestamp_array::<TimestampMicrosecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
                    build_timestamp_array::<TimestampNanosecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Null => Ok(Arc::new({
                    let mut builder = NullBuilder::new();
                    builder.append_nulls(rows.len());
                    builder.finish()
                }) as ArrayRef),
                DataType::Utf8 => Ok(Arc::new(
                    rows.iter()
                        .map(|row| {
                            let s = row.get(i);
                            (!null_regex.is_null(s)).then_some(s)
                        })
                        .collect::<StringArray>(),
                ) as ArrayRef),
                DataType::Utf8View => Ok(Arc::new(
                    rows.iter()
                        .map(|row| {
                            let s = row.get(i);
                            (!null_regex.is_null(s)).then_some(s)
                        })
                        .collect::<StringViewArray>(),
                ) as ArrayRef),
                DataType::Dictionary(key_type, value_type)
                    if value_type.as_ref() == &DataType::Utf8 =>
                {
                    match key_type.as_ref() {
                        DataType::Int8 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int8Type>>(),
                        ) as ArrayRef),
                        DataType::Int16 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int16Type>>(),
                        ) as ArrayRef),
                        DataType::Int32 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int32Type>>(),
                        ) as ArrayRef),
                        DataType::Int64 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int64Type>>(),
                        ) as ArrayRef),
                        DataType::UInt8 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt8Type>>(),
                        ) as ArrayRef),
                        DataType::UInt16 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt16Type>>(),
                        ) as ArrayRef),
                        DataType::UInt32 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt32Type>>(),
                        ) as ArrayRef),
                        DataType::UInt64 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt64Type>>(),
                        ) as ArrayRef),
                        _ => Err(ArrowError::ParseError(format!(
                            "Unsupported dictionary key type {key_type:?}"
                        ))),
                    }
                }
                other => Err(ArrowError::ParseError(format!(
                    "Unsupported data type {other:?}"
                ))),
            }
        })
        .collect();

    let projected_fields: Fields = projection.iter().map(|i| fields[*i].clone()).collect();

    let projected_schema = Arc::new(match metadata {
        None => Schema::new(projected_fields),
        Some(metadata) => Schema::new_with_metadata(projected_fields, metadata),
    });

    arrays.and_then(|arr| {
        RecordBatch::try_new_with_options(
            projected_schema,
            arr,
            &RecordBatchOptions::new()
                .with_match_field_names(true)
                .with_row_count(Some(rows.len())),
        )
    })
}

fn parse_bool(string: &str) -> Option<bool> {
    if string.eq_ignore_ascii_case("false") {
        Some(false)
    } else if string.eq_ignore_ascii_case("true") {
        Some(true)
    } else {
        None
    }
}

/// Parses a column of strings into a decimal array of type `T`
fn build_decimal_array<T: DecimalType>(
    _line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    precision: u8,
    scale: i8,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    let mut decimal_builder = PrimitiveBuilder::<T>::with_capacity(rows.len());
    for row in rows.iter() {
        let s = row.get(col_idx);
        if null_regex.is_null(s) {
            decimal_builder.append_null();
        } else {
            let decimal_value: Result<T::Native, _> = parse_decimal::<T>(s, precision, scale);
            match decimal_value {
                Ok(v) => {
                    decimal_builder.append_value(v);
                }
                Err(e) => {
                    return Err(e);
                }
            }
        }
    }
    Ok(Arc::new(
        decimal_builder
            .finish()
            .with_precision_and_scale(precision, scale)?,
    ))
}

/// Parses a column of strings into a [`PrimitiveArray`] of type `T`
fn build_primitive_array<T: ArrowPrimitiveType + Parser>(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    rows.iter()
        .enumerate()
        .map(|(row_index, row)| {
            let s = row.get(col_idx);
            if null_regex.is_null(s) {
                return Ok(None);
            }

            match T::parse(s) {
                Some(e) => Ok(Some(e)),
                None => Err(ArrowError::ParseError(format!(
                    "Error while parsing value {} for column {} at line {}",
                    s,
                    col_idx,
                    line_number + row_index
                ))),
            }
        })
        .collect::<Result<PrimitiveArray<T>, ArrowError>>()
        .map(|e| Arc::new(e) as ArrayRef)
}

fn build_timestamp_array<T: ArrowTimestampType>(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    timezone: Option<&str>,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    Ok(Arc::new(match timezone {
        Some(timezone) => {
            let tz: Tz = timezone.parse()?;
            build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &tz, null_regex)?
                .with_timezone(timezone)
        }
        None => build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &Utc, null_regex)?,
    }))
}

fn build_timestamp_array_impl<T: ArrowTimestampType, Tz: TimeZone>(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    timezone: &Tz,
    null_regex: &NullRegex,
) -> Result<PrimitiveArray<T>, ArrowError> {
    rows.iter()
        .enumerate()
        .map(|(row_index, row)| {
            let s = row.get(col_idx);
            if null_regex.is_null(s) {
                return Ok(None);
            }

            let date = string_to_datetime(timezone, s)
                .and_then(|date| match T::UNIT {
                    TimeUnit::Second => Ok(date.timestamp()),
                    TimeUnit::Millisecond => Ok(date.timestamp_millis()),
                    TimeUnit::Microsecond => Ok(date.timestamp_micros()),
                    TimeUnit::Nanosecond => date.timestamp_nanos_opt().ok_or_else(|| {
                        ArrowError::ParseError(format!(
                            "{} would overflow 64-bit signed nanoseconds",
                            date.to_rfc3339(),
                        ))
                    }),
                })
                .map_err(|e| {
                    ArrowError::ParseError(format!(
                        "Error parsing column {col_idx} at line {}: {}",
                        line_number + row_index,
                        e
                    ))
                })?;
            Ok(Some(date))
        })
        .collect()
}

fn build_boolean_array(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    rows.iter()
        .enumerate()
        .map(|(row_index, row)| {
            let s = row.get(col_idx);
            if null_regex.is_null(s) {
                return Ok(None);
            }
            let parsed = parse_bool(s);
            match parsed {
                Some(e) => Ok(Some(e)),
                None => Err(ArrowError::ParseError(format!(
                    "Error while parsing value {} for column {} at line {}",
                    s,
                    col_idx,
                    line_number + row_index
                ))),
            }
        })
        .collect::<Result<BooleanArray, _>>()
        .map(|e| Arc::new(e) as ArrayRef)
}

/// CSV file reader builder
#[derive(Debug)]
pub struct ReaderBuilder {
    /// Schema of the CSV file
    schema: SchemaRef,
    /// Format of the CSV file
    format: Format,
    /// Batch size (number of records to load at a time)
    batch_size: usize,
    /// The bounds over which to scan the reader; `None` reads from 0 to EOF
    bounds: Bounds,
    /// Optional projection for which columns to load (zero-based column indices)
    projection: Option<Vec<usize>>,
}

impl ReaderBuilder {
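    /// Create a new [`ReaderBuilder`] for the given schema, with default
    /// options (no header, comma delimiter, batches of 1024 rows).
    ///
    /// A minimal sketch over in-memory data (the CSV content is made up):
    ///
    /// ```ignore
    /// use std::io::Cursor;
    /// use std::sync::Arc;
    ///
    /// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
    /// let mut reader = ReaderBuilder::new(schema)
    ///     .build_buffered(Cursor::new("1\n2\n"))
    ///     .unwrap();
    /// let batch = reader.next().unwrap().unwrap();
    /// assert_eq!(batch.num_rows(), 2);
    /// ```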
    pub fn new(schema: SchemaRef) -> ReaderBuilder {
        Self {
            schema,
            format: Format::default(),
            batch_size: 1024,
            bounds: None,
            projection: None,
        }
    }

    /// Set whether the CSV file has a header
    pub fn with_header(mut self, has_header: bool) -> Self {
        self.format.header = has_header;
        self
    }

    /// Overrides the [`Format`] of this [`ReaderBuilder`]
    pub fn with_format(mut self, format: Format) -> Self {
        self.format = format;
        self
    }

    /// Set the CSV file's column delimiter as a byte character
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.format.delimiter = Some(delimiter);
        self
    }

    /// Set the CSV escape character
    pub fn with_escape(mut self, escape: u8) -> Self {
        self.format.escape = Some(escape);
        self
    }

    /// Set the CSV quote character
    pub fn with_quote(mut self, quote: u8) -> Self {
        self.format.quote = Some(quote);
        self
    }

    /// Set the CSV line terminator
    pub fn with_terminator(mut self, terminator: u8) -> Self {
        self.format.terminator = Some(terminator);
        self
    }

    /// Set the CSV comment character
    pub fn with_comment(mut self, comment: u8) -> Self {
        self.format.comment = Some(comment);
        self
    }

    /// Provide a regex matching values that should be treated as null
    pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
        self.format.null_regex = NullRegex(Some(null_regex));
        self
    }

    /// Set the batch size (number of records to load at a time)
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

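    /// Provide the bounds on the rows to read as a half-open interval
    /// `[start, end)` of record indices, not counting the header (if any).
    ///
    /// A sketch (mirrors `test_bounded` below): with bounds `(2, 6)` and a
    /// batch size of 2, a seven-row file yields the batches `[2, 3]` and `[4, 5]`:
    ///
    /// ```ignore
    /// let mut csv = ReaderBuilder::new(schema)
    ///     .with_batch_size(2)
    ///     .with_bounds(2, 6)
    ///     .build_buffered(reader)
    ///     .unwrap();
    /// ```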
    pub fn with_bounds(mut self, start: usize, end: usize) -> Self {
        self.bounds = Some((start, end));
        self
    }

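    /// Project only the given columns, by zero-based index, in the given order.
    ///
    /// A sketch (as in `test_csv_with_projection` below): keep only the first
    /// two columns of a three-column schema:
    ///
    /// ```ignore
    /// let builder = ReaderBuilder::new(schema).with_projection(vec![0, 1]);
    /// ```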
    pub fn with_projection(mut self, projection: Vec<usize>) -> Self {
        self.projection = Some(projection);
        self
    }

    /// Whether to allow truncated rows when parsing.
    ///
    /// See [`Format::with_truncated_rows`] for details.
    pub fn with_truncated_rows(mut self, allow: bool) -> Self {
        self.format.truncated_rows = allow;
        self
    }

    /// Create a new [`Reader`] from a non-buffered reader.
    ///
    /// If `R: BufRead`, consider using [`Self::build_buffered`] to avoid
    /// unnecessary additional buffering.
    pub fn build<R: Read>(self, reader: R) -> Result<Reader<R>, ArrowError> {
        self.build_buffered(StdBufReader::new(reader))
    }

    /// Create a new [`BufReader`] from a buffered reader
    pub fn build_buffered<R: BufRead>(self, reader: R) -> Result<BufReader<R>, ArrowError> {
        Ok(BufReader {
            reader,
            decoder: self.build_decoder(),
        })
    }

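    /// Build a [`Decoder`] for push-based decoding of CSV bytes.
    ///
    /// A sketch, feeding the whole input in one call (chunked feeding also
    /// works, as in the loop shown on [`Decoder::decode`]); this mirrors the
    /// `test_parse_timestamp_impl` helper below:
    ///
    /// ```ignore
    /// let mut decoder = ReaderBuilder::new(schema).build_decoder();
    /// decoder.decode(b"1,2\n3,4\n").unwrap();
    /// decoder.decode(&[]).unwrap(); // no more input
    /// let batch = decoder.flush().unwrap().unwrap();
    /// ```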
    pub fn build_decoder(self) -> Decoder {
        let delimiter = self.format.build_parser();
        let record_decoder = RecordDecoder::new(
            delimiter,
            self.schema.fields().len(),
            self.format.truncated_rows,
        );

        let header = self.format.header as usize;

        let (start, end) = match self.bounds {
            Some((start, end)) => (start + header, end + header),
            None => (header, usize::MAX),
        };

        Decoder {
            schema: self.schema,
            to_skip: start,
            record_decoder,
            line_number: start,
            end,
            projection: self.projection,
            batch_size: self.batch_size,
            null_regex: self.format.null_regex,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use std::io::{Cursor, Seek, SeekFrom, Write};
    use tempfile::NamedTempFile;

    use arrow_array::cast::AsArray;

    #[test]
    fn test_csv() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]));

        let file = File::open("test/data/uk_cities.csv").unwrap();
        let mut csv = ReaderBuilder::new(schema.clone()).build(file).unwrap();
        assert_eq!(schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        let lat = batch.column(1).as_primitive::<Float64Type>();
        assert_eq!(57.653484, lat.value(0));

        let city = batch.column(0).as_string::<i32>();

        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
    }

    #[test]
    fn test_csv_schema_metadata() {
        let mut metadata = std::collections::HashMap::new();
        metadata.insert("foo".to_owned(), "bar".to_owned());
        let schema = Arc::new(Schema::new_with_metadata(
            vec![
                Field::new("city", DataType::Utf8, false),
                Field::new("lat", DataType::Float64, false),
                Field::new("lng", DataType::Float64, false),
            ],
            metadata.clone(),
        ));

        let file = File::open("test/data/uk_cities.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema.clone()).build(file).unwrap();
        assert_eq!(schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        assert_eq!(&metadata, batch.schema().metadata());
    }

    #[test]
    fn test_csv_reader_with_decimal() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Decimal128(38, 6), false),
            Field::new("lng", DataType::Decimal256(76, 6), false),
        ]));

        let file = File::open("test/data/decimal_test.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema).build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();
        let lat = batch
            .column(1)
            .as_any()
            .downcast_ref::<Decimal128Array>()
            .unwrap();

        assert_eq!("57.653484", lat.value_as_string(0));
        assert_eq!("53.002666", lat.value_as_string(1));
        assert_eq!("52.412811", lat.value_as_string(2));
        assert_eq!("51.481583", lat.value_as_string(3));
        assert_eq!("12.123456", lat.value_as_string(4));
        assert_eq!("50.760000", lat.value_as_string(5));
        assert_eq!("0.123000", lat.value_as_string(6));
        assert_eq!("123.000000", lat.value_as_string(7));
        assert_eq!("123.000000", lat.value_as_string(8));
        assert_eq!("-50.760000", lat.value_as_string(9));

        let lng = batch
            .column(2)
            .as_any()
            .downcast_ref::<Decimal256Array>()
            .unwrap();

        assert_eq!("-3.335724", lng.value_as_string(0));
        assert_eq!("-2.179404", lng.value_as_string(1));
        assert_eq!("-1.778197", lng.value_as_string(2));
        assert_eq!("-3.179090", lng.value_as_string(3));
        assert_eq!("-3.179090", lng.value_as_string(4));
        assert_eq!("0.290472", lng.value_as_string(5));
        assert_eq!("0.290472", lng.value_as_string(6));
        assert_eq!("0.290472", lng.value_as_string(7));
        assert_eq!("0.290472", lng.value_as_string(8));
        assert_eq!("0.290472", lng.value_as_string(9));
    }

    #[test]
    fn test_csv_from_buf_reader() {
        let schema = Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]);

        let file_with_headers = File::open("test/data/uk_cities_with_headers.csv").unwrap();
        let file_without_headers = File::open("test/data/uk_cities.csv").unwrap();
        let both_files = file_with_headers
            .chain(Cursor::new("\n".to_string()))
            .chain(file_without_headers);
        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_header(true)
            .build(both_files)
            .unwrap();
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(74, batch.num_rows());
        assert_eq!(3, batch.num_columns());
    }

    #[test]
    fn test_csv_with_schema_inference() {
        let mut file = File::open("test/data/uk_cities_with_headers.csv").unwrap();

        let (schema, _) = Format::default()
            .with_header(true)
            .infer_schema(&mut file, None)
            .unwrap();

        file.rewind().unwrap();
        let builder = ReaderBuilder::new(Arc::new(schema)).with_header(true);

        let mut csv = builder.build(file).unwrap();
        let expected_schema = Schema::new(vec![
            Field::new("city", DataType::Utf8, true),
            Field::new("lat", DataType::Float64, true),
            Field::new("lng", DataType::Float64, true),
        ]);
        assert_eq!(Arc::new(expected_schema), csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        let lat = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(57.653484, lat.value(0));

        let city = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
    }

    #[test]
    fn test_csv_with_schema_inference_no_headers() {
        let mut file = File::open("test/data/uk_cities.csv").unwrap();

        let (schema, _) = Format::default().infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let mut csv = ReaderBuilder::new(Arc::new(schema)).build(file).unwrap();

        let schema = csv.schema();
        assert_eq!("column_1", schema.field(0).name());
        assert_eq!("column_2", schema.field(1).name());
        assert_eq!("column_3", schema.field(2).name());
        let batch = csv.next().unwrap().unwrap();
        let batch_schema = batch.schema();

        assert_eq!(schema, batch_schema);
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        let lat = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(57.653484, lat.value(0));

        let city = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
    }

    #[test]
    fn test_csv_builder_with_bounds() {
        let mut file = File::open("test/data/uk_cities.csv").unwrap();

        let (schema, _) = Format::default().infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();
        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_bounds(0, 2)
            .build(file)
            .unwrap();
        let batch = csv.next().unwrap().unwrap();

        let city = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        // The value on line 0 is within the bounds
        assert_eq!("Elgin, Scotland, the UK", city.value(0));

        // Line 13 is outside the bounds, so it was never read and accessing it panics
        let result = std::panic::catch_unwind(|| city.value(13));
        assert!(result.is_err());
    }

    #[test]
    fn test_csv_with_projection() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]));

        let file = File::open("test/data/uk_cities.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_projection(vec![0, 1])
            .build(file)
            .unwrap();

        let projected_schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
        ]));
        assert_eq!(projected_schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(projected_schema, batch.schema());
        assert_eq!(37, batch.num_rows());
        assert_eq!(2, batch.num_columns());
    }

    #[test]
    fn test_csv_with_dictionary() {
        let schema = Arc::new(Schema::new(vec![
            Field::new_dictionary("city", DataType::Int32, DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]));

        let file = File::open("test/data/uk_cities.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_projection(vec![0, 1])
            .build(file)
            .unwrap();

        let projected_schema = Arc::new(Schema::new(vec![
            Field::new_dictionary("city", DataType::Int32, DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
        ]));
        assert_eq!(projected_schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(projected_schema, batch.schema());
        assert_eq!(37, batch.num_rows());
        assert_eq!(2, batch.num_columns());

        let strings = arrow_cast::cast(batch.column(0), &DataType::Utf8).unwrap();
        let strings = strings.as_string::<i32>();

        assert_eq!(strings.value(0), "Elgin, Scotland, the UK");
        assert_eq!(strings.value(4), "Eastbourne, East Sussex, UK");
        assert_eq!(strings.value(29), "Uckfield, East Sussex, UK");
    }

    #[test]
    fn test_csv_with_nullable_dictionary() {
        let offset_type = vec![
            DataType::Int8,
            DataType::Int16,
            DataType::Int32,
            DataType::Int64,
            DataType::UInt8,
            DataType::UInt16,
            DataType::UInt32,
            DataType::UInt64,
        ];
        for data_type in offset_type {
            let file = File::open("test/data/dictionary_nullable_test.csv").unwrap();
            let dictionary_type =
                DataType::Dictionary(Box::new(data_type), Box::new(DataType::Utf8));
            let schema = Arc::new(Schema::new(vec![
                Field::new("id", DataType::Utf8, false),
                Field::new("name", dictionary_type.clone(), true),
            ]));

            let mut csv = ReaderBuilder::new(schema)
                .build(file.try_clone().unwrap())
                .unwrap();

            let batch = csv.next().unwrap().unwrap();
            assert_eq!(3, batch.num_rows());
            assert_eq!(2, batch.num_columns());

            let names = arrow_cast::cast(batch.column(1), &dictionary_type).unwrap();
            assert!(!names.is_null(2));
            assert!(names.is_null(1));
        }
    }

    #[test]
    fn test_nulls() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c_int", DataType::UInt64, false),
            Field::new("c_float", DataType::Float32, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, false),
        ]));

        let file = File::open("test/data/null_test.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_header(true)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        assert!(!batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }

    #[test]
    fn test_init_nulls() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c_int", DataType::UInt64, true),
            Field::new("c_float", DataType::Float32, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
            Field::new("c_null", DataType::Null, true),
        ]));
        let file = File::open("test/data/init_null_test.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_header(true)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        assert!(batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }

    #[test]
    fn test_init_nulls_with_inference() {
        let format = Format::default().with_header(true).with_delimiter(b',');

        let mut file = File::open("test/data/init_null_test.csv").unwrap();
        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let expected_schema = Schema::new(vec![
            Field::new("c_int", DataType::Int64, true),
            Field::new("c_float", DataType::Float64, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
            Field::new("c_null", DataType::Null, true),
        ]);
        assert_eq!(schema, expected_schema);

        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        assert!(batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }

    #[test]
    fn test_custom_nulls() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c_int", DataType::UInt64, true),
            Field::new("c_float", DataType::Float32, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
        ]));

        let file = File::open("test/data/custom_null_test.csv").unwrap();

        let null_regex = Regex::new("^nil$").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_header(true)
            .with_null_regex(null_regex)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        // Values matching the "nil" regex should be null
        assert!(batch.column(0).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(batch.column(3).is_null(4));
        assert!(batch.column(2).is_null(3));
        assert!(!batch.column(2).is_null(4));
    }

    #[test]
    fn test_nulls_with_inference() {
        let mut file = File::open("test/data/various_types.csv").unwrap();
        let format = Format::default().with_header(true).with_delimiter(b'|');

        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3, 4, 5]);

        let mut csv = builder.build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();

        assert_eq!(10, batch.num_rows());
        assert_eq!(6, batch.num_columns());

        let schema = batch.schema();

        assert_eq!(&DataType::Int64, schema.field(0).data_type());
        assert_eq!(&DataType::Float64, schema.field(1).data_type());
        assert_eq!(&DataType::Float64, schema.field(2).data_type());
        assert_eq!(&DataType::Boolean, schema.field(3).data_type());
        assert_eq!(&DataType::Date32, schema.field(4).data_type());
        assert_eq!(
            &DataType::Timestamp(TimeUnit::Second, None),
            schema.field(5).data_type()
        );

        let names: Vec<&str> = schema.fields().iter().map(|x| x.name().as_str()).collect();
        assert_eq!(
            names,
            vec![
                "c_int",
                "c_float",
                "c_string",
                "c_bool",
                "c_date",
                "c_datetime"
            ]
        );

        assert!(schema.field(0).is_nullable());
        assert!(schema.field(1).is_nullable());
        assert!(schema.field(2).is_nullable());
        assert!(schema.field(3).is_nullable());
        assert!(schema.field(4).is_nullable());
        assert!(schema.field(5).is_nullable());

        assert!(!batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }

    #[test]
    fn test_custom_nulls_with_inference() {
        let mut file = File::open("test/data/custom_null_test.csv").unwrap();

        let null_regex = Regex::new("^nil$").unwrap();

        let format = Format::default()
            .with_header(true)
            .with_null_regex(null_regex);

        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let expected_schema = Schema::new(vec![
            Field::new("c_int", DataType::Int64, true),
            Field::new("c_float", DataType::Float64, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
        ]);

        assert_eq!(schema, expected_schema);

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3]);

        let mut csv = builder.build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();

        assert_eq!(5, batch.num_rows());
        assert_eq!(4, batch.num_columns());

        assert_eq!(batch.schema().as_ref(), &expected_schema);
    }

    #[test]
    fn test_scientific_notation_with_inference() {
        let mut file = File::open("test/data/scientific_notation_test.csv").unwrap();
        let format = Format::default().with_header(false).with_delimiter(b',');

        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .with_batch_size(512)
            .with_projection(vec![0, 1]);

        let mut csv = builder.build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();

        let schema = batch.schema();

        assert_eq!(&DataType::Float64, schema.field(0).data_type());
    }

    #[test]
    fn test_parse_invalid_csv() {
        let file = File::open("test/data/various_types_invalid.csv").unwrap();

        let schema = Schema::new(vec![
            Field::new("c_int", DataType::UInt64, false),
            Field::new("c_float", DataType::Float32, false),
            Field::new("c_string", DataType::Utf8, false),
            Field::new("c_bool", DataType::Boolean, false),
        ]);

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(true)
            .with_delimiter(b'|')
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3]);

        let mut csv = builder.build(file).unwrap();
        match csv.next() {
            Some(e) => match e {
                Err(e) => assert_eq!(
                    "ParseError(\"Error while parsing value 4.x4 for column 1 at line 4\")",
                    format!("{e:?}")
                ),
                Ok(_) => panic!("should have failed"),
            },
            None => panic!("should have failed"),
        }
    }

    fn infer_field_schema(string: &str) -> DataType {
        let mut v = InferredDataType::default();
        v.update(string);
        v.get()
    }

    #[test]
    fn test_infer_field_schema() {
        assert_eq!(infer_field_schema("A"), DataType::Utf8);
        assert_eq!(infer_field_schema("\"123\""), DataType::Utf8);
        assert_eq!(infer_field_schema("10"), DataType::Int64);
        assert_eq!(infer_field_schema("10.2"), DataType::Float64);
        assert_eq!(infer_field_schema(".2"), DataType::Float64);
        assert_eq!(infer_field_schema("2."), DataType::Float64);
        assert_eq!(infer_field_schema("NaN"), DataType::Float64);
        assert_eq!(infer_field_schema("nan"), DataType::Float64);
        assert_eq!(infer_field_schema("inf"), DataType::Float64);
        assert_eq!(infer_field_schema("-inf"), DataType::Float64);
        assert_eq!(infer_field_schema("true"), DataType::Boolean);
        assert_eq!(infer_field_schema("trUe"), DataType::Boolean);
        assert_eq!(infer_field_schema("false"), DataType::Boolean);
        assert_eq!(infer_field_schema("2020-11-08"), DataType::Date32);
        assert_eq!(
            infer_field_schema("2020-11-08T14:20:01"),
            DataType::Timestamp(TimeUnit::Second, None)
        );
        assert_eq!(
            infer_field_schema("2020-11-08 14:20:01"),
            DataType::Timestamp(TimeUnit::Second, None)
        );
        assert_eq!(infer_field_schema("-5.13"), DataType::Float64);
        assert_eq!(infer_field_schema("0.1300"), DataType::Float64);
        assert_eq!(
            infer_field_schema("2021-12-19 13:12:30.921"),
            DataType::Timestamp(TimeUnit::Millisecond, None)
        );
        assert_eq!(
            infer_field_schema("2021-12-19T13:12:30.123456789"),
            DataType::Timestamp(TimeUnit::Nanosecond, None)
        );
        // Values overflowing i64 in either direction fall back to Utf8
        assert_eq!(infer_field_schema("-9223372036854775809"), DataType::Utf8);
        assert_eq!(infer_field_schema("9223372036854775808"), DataType::Utf8);
    }

    #[test]
    fn parse_date32() {
        assert_eq!(Date32Type::parse("1970-01-01").unwrap(), 0);
        assert_eq!(Date32Type::parse("2020-03-15").unwrap(), 18336);
        assert_eq!(Date32Type::parse("1945-05-08").unwrap(), -9004);
    }

    #[test]
    fn parse_time() {
        assert_eq!(
            Time64NanosecondType::parse("12:10:01.123456789 AM"),
            Some(601_123_456_789)
        );
        assert_eq!(
            Time64MicrosecondType::parse("12:10:01.123456 am"),
            Some(601_123_456)
        );
        assert_eq!(
            Time32MillisecondType::parse("2:10:01.12 PM"),
            Some(51_001_120)
        );
        assert_eq!(Time32SecondType::parse("2:10:01 pm"), Some(51_001));
    }

    #[test]
    fn parse_date64() {
        assert_eq!(Date64Type::parse("1970-01-01T00:00:00").unwrap(), 0);
        assert_eq!(
            Date64Type::parse("2018-11-13T17:11:10").unwrap(),
            1542129070000
        );
        assert_eq!(
            Date64Type::parse("2018-11-13T17:11:10.011").unwrap(),
            1542129070011
        );
        assert_eq!(
            Date64Type::parse("1900-02-28T12:34:56").unwrap(),
            -2203932304000
        );
        assert_eq!(
            Date64Type::parse_formatted("1900-02-28 12:34:56", "%Y-%m-%d %H:%M:%S").unwrap(),
            -2203932304000
        );
        assert_eq!(
            Date64Type::parse_formatted("1900-02-28 12:34:56+0030", "%Y-%m-%d %H:%M:%S%z").unwrap(),
            -2203932304000 - (30 * 60 * 1000)
        );
    }

    fn test_parse_timestamp_impl<T: ArrowTimestampType>(
        timezone: Option<Arc<str>>,
        expected: &[i64],
    ) {
        let csv = [
            "1970-01-01T00:00:00",
            "1970-01-01T00:00:00Z",
            "1970-01-01T00:00:00+02:00",
        ]
        .join("\n");
        let schema = Arc::new(Schema::new(vec![Field::new(
            "field",
            DataType::Timestamp(T::UNIT, timezone.clone()),
            true,
        )]));

        let mut decoder = ReaderBuilder::new(schema).build_decoder();

        let decoded = decoder.decode(csv.as_bytes()).unwrap();
        assert_eq!(decoded, csv.len());
        decoder.decode(&[]).unwrap();

        let batch = decoder.flush().unwrap().unwrap();
        assert_eq!(batch.num_columns(), 1);
        assert_eq!(batch.num_rows(), 3);
        let col = batch.column(0).as_primitive::<T>();
        assert_eq!(col.values(), expected);
        assert_eq!(col.data_type(), &DataType::Timestamp(T::UNIT, timezone));
    }

    #[test]
    fn test_parse_timestamp() {
        test_parse_timestamp_impl::<TimestampNanosecondType>(None, &[0, 0, -7_200_000_000_000]);
        test_parse_timestamp_impl::<TimestampNanosecondType>(
            Some("+00:00".into()),
            &[0, 0, -7_200_000_000_000],
        );
        test_parse_timestamp_impl::<TimestampNanosecondType>(
            Some("-05:00".into()),
            &[18_000_000_000_000, 0, -7_200_000_000_000],
        );
        test_parse_timestamp_impl::<TimestampMicrosecondType>(
            Some("-03".into()),
            &[10_800_000_000, 0, -7_200_000_000],
        );
        test_parse_timestamp_impl::<TimestampMillisecondType>(
            Some("-03".into()),
            &[10_800_000, 0, -7_200_000],
        );
        test_parse_timestamp_impl::<TimestampSecondType>(Some("-03".into()), &[10_800, 0, -7_200]);
    }

    #[test]
    fn test_infer_schema_from_multiple_files() {
        let mut csv1 = NamedTempFile::new().unwrap();
        let mut csv2 = NamedTempFile::new().unwrap();
        let csv3 = NamedTempFile::new().unwrap(); // empty file, should be skipped
        let mut csv4 = NamedTempFile::new().unwrap();
        writeln!(csv1, "c1,c2,c3").unwrap();
        writeln!(csv1, "1,\"foo\",0.5").unwrap();
        writeln!(csv1, "3,\"bar\",1").unwrap();
        writeln!(csv1, "3,\"bar\",2e-06").unwrap();
        writeln!(csv2, "c1,c2,c3,c4").unwrap();
        writeln!(csv2, "10,,3.14,true").unwrap();
        writeln!(csv4, "c1,c2,c3").unwrap();
        writeln!(csv4, "10,\"foo\",").unwrap();

        let schema = infer_schema_from_files(
            &[
                csv3.path().to_str().unwrap().to_string(),
                csv1.path().to_str().unwrap().to_string(),
                csv2.path().to_str().unwrap().to_string(),
                csv4.path().to_str().unwrap().to_string(),
            ],
            b',',
            Some(4), // the limit is reached after csv1 and csv2, so csv4 is never read
            true,
        )
        .unwrap();

        assert_eq!(schema.fields().len(), 4);
        assert!(schema.field(0).is_nullable());
        assert!(schema.field(1).is_nullable());
        assert!(schema.field(2).is_nullable());
        assert!(schema.field(3).is_nullable());

        assert_eq!(&DataType::Int64, schema.field(0).data_type());
        assert_eq!(&DataType::Utf8, schema.field(1).data_type());
        assert_eq!(&DataType::Float64, schema.field(2).data_type());
        assert_eq!(&DataType::Boolean, schema.field(3).data_type());
    }

    #[test]
    fn test_bounded() {
        let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]);
        let data = [
            vec!["0"],
            vec!["1"],
            vec!["2"],
            vec!["3"],
            vec!["4"],
            vec!["5"],
            vec!["6"],
        ];

        let data = data
            .iter()
            .map(|x| x.join(","))
            .collect::<Vec<_>>()
            .join("\n");
        let data = data.as_bytes();

        let reader = std::io::Cursor::new(data);

        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_batch_size(2)
            .with_projection(vec![0])
            .with_bounds(2, 6)
            .build_buffered(reader)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();
        let a = batch.column(0);
        let a = a.as_any().downcast_ref::<UInt32Array>().unwrap();
        assert_eq!(a, &UInt32Array::from(vec![2, 3]));

        let batch = csv.next().unwrap().unwrap();
        let a = batch.column(0);
        let a = a.as_any().downcast_ref::<UInt32Array>().unwrap();
        assert_eq!(a, &UInt32Array::from(vec![4, 5]));

        assert!(csv.next().is_none());
    }

    #[test]
    fn test_empty_projection() {
        let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]);
        let data = [vec!["0"], vec!["1"]];

        let data = data
            .iter()
            .map(|x| x.join(","))
            .collect::<Vec<_>>()
            .join("\n");

        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_batch_size(2)
            .with_projection(vec![])
            .build_buffered(Cursor::new(data.as_bytes()))
            .unwrap();

        let batch = csv.next().unwrap().unwrap();
        assert_eq!(batch.columns().len(), 0);
        assert_eq!(batch.num_rows(), 2);

        assert!(csv.next().is_none());
    }

    #[test]
    fn test_parsing_bool() {
        assert_eq!(Some(true), parse_bool("true"));
        assert_eq!(Some(true), parse_bool("tRUe"));
        assert_eq!(Some(true), parse_bool("True"));
        assert_eq!(Some(true), parse_bool("TRUE"));
        assert_eq!(None, parse_bool("t"));
        assert_eq!(None, parse_bool("T"));
        assert_eq!(None, parse_bool(""));

        assert_eq!(Some(false), parse_bool("false"));
        assert_eq!(Some(false), parse_bool("fALse"));
        assert_eq!(Some(false), parse_bool("False"));
        assert_eq!(Some(false), parse_bool("FALSE"));
        assert_eq!(None, parse_bool("f"));
        assert_eq!(None, parse_bool("F"));
        assert_eq!(None, parse_bool(""));
    }

    #[test]
    fn test_parsing_float() {
        assert_eq!(Some(12.34), Float64Type::parse("12.34"));
        assert_eq!(Some(-12.34), Float64Type::parse("-12.34"));
        assert_eq!(Some(12.0), Float64Type::parse("12"));
        assert_eq!(Some(0.0), Float64Type::parse("0"));
        assert_eq!(Some(2.0), Float64Type::parse("2."));
        assert_eq!(Some(0.2), Float64Type::parse(".2"));
        assert!(Float64Type::parse("nan").unwrap().is_nan());
        assert!(Float64Type::parse("NaN").unwrap().is_nan());
        assert!(Float64Type::parse("inf").unwrap().is_infinite());
        assert!(Float64Type::parse("inf").unwrap().is_sign_positive());
        assert!(Float64Type::parse("-inf").unwrap().is_infinite());
        assert!(Float64Type::parse("-inf").unwrap().is_sign_negative());
        assert_eq!(None, Float64Type::parse(""));
        assert_eq!(None, Float64Type::parse("dd"));
        assert_eq!(None, Float64Type::parse("12.34.56"));
    }

    #[test]
    fn test_non_std_quote() {
        let schema = Schema::new(vec![
            Field::new("text1", DataType::Utf8, false),
            Field::new("text2", DataType::Utf8, false),
        ]);
        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(false)
            .with_quote(b'~'); // default is ", change to ~

        let mut csv_text = Vec::new();
        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
        for index in 0..10 {
            let text1 = format!("id{index:}");
            let text2 = format!("value{index:}");
            csv_writer
                .write_fmt(format_args!("~{text1}~,~{text2}~\r\n"))
                .unwrap();
        }
        let mut csv_reader = std::io::Cursor::new(&csv_text);
        let mut reader = builder.build(&mut csv_reader).unwrap();
        let batch = reader.next().unwrap().unwrap();
        let col0 = batch.column(0);
        assert_eq!(col0.len(), 10);
        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col0_arr.value(0), "id0");
        let col1 = batch.column(1);
        assert_eq!(col1.len(), 10);
        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col1_arr.value(5), "value5");
    }

    #[test]
    fn test_non_std_escape() {
        let schema = Schema::new(vec![
            Field::new("text1", DataType::Utf8, false),
            Field::new("text2", DataType::Utf8, false),
        ]);
        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(false)
            .with_escape(b'\\'); // default is None, set to \

        let mut csv_text = Vec::new();
        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
        for index in 0..10 {
            let text1 = format!("id{index:}");
            let text2 = format!("value\\\"{index:}");
            csv_writer
                .write_fmt(format_args!("\"{text1}\",\"{text2}\"\r\n"))
                .unwrap();
        }
        let mut csv_reader = std::io::Cursor::new(&csv_text);
        let mut reader = builder.build(&mut csv_reader).unwrap();
        let batch = reader.next().unwrap().unwrap();
        let col0 = batch.column(0);
        assert_eq!(col0.len(), 10);
        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col0_arr.value(0), "id0");
        let col1 = batch.column(1);
        assert_eq!(col1.len(), 10);
        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col1_arr.value(5), "value\"5");
    }

    #[test]
    fn test_non_std_terminator() {
        let schema = Schema::new(vec![
            Field::new("text1", DataType::Utf8, false),
            Field::new("text2", DataType::Utf8, false),
        ]);
        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(false)
            .with_terminator(b'\n'); // default is CRLF, change to LF

        let mut csv_text = Vec::new();
        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
        for index in 0..10 {
            let text1 = format!("id{index:}");
            let text2 = format!("value{index:}");
            csv_writer
                .write_fmt(format_args!("\"{text1}\",\"{text2}\"\n"))
                .unwrap();
        }
        let mut csv_reader = std::io::Cursor::new(&csv_text);
        let mut reader = builder.build(&mut csv_reader).unwrap();
        let batch = reader.next().unwrap().unwrap();
        let col0 = batch.column(0);
        assert_eq!(col0.len(), 10);
        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col0_arr.value(0), "id0");
        let col1 = batch.column(1);
        assert_eq!(col1.len(), 10);
        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col1_arr.value(5), "value5");
    }

    #[test]
    fn test_header_bounds() {
        let csv = "a,b\na,b\na,b\na,b\na,b\n";
        let tests = [
            (None, false, 5),
            (None, true, 4),
            (Some((0, 4)), false, 4),
            (Some((1, 4)), false, 3),
            (Some((0, 4)), true, 4),
            (Some((1, 4)), true, 3),
        ];
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("a", DataType::Utf8, false),
        ]));

        for (idx, (bounds, has_header, expected)) in tests.into_iter().enumerate() {
            let mut reader = ReaderBuilder::new(schema.clone()).with_header(has_header);
            if let Some((start, end)) = bounds {
                reader = reader.with_bounds(start, end);
            }
            let b = reader
                .build_buffered(Cursor::new(csv.as_bytes()))
                .unwrap()
                .next()
                .unwrap()
                .unwrap();
            assert_eq!(b.num_rows(), expected, "{idx}");
        }
    }

    #[test]
    fn test_null_boolean() {
        let csv = "true,false\nFalse,True\n,True\nFalse,";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Boolean, true),
            Field::new("a", DataType::Boolean, true),
        ]));

        let b = ReaderBuilder::new(schema)
            .build_buffered(Cursor::new(csv.as_bytes()))
            .unwrap()
            .next()
            .unwrap()
            .unwrap();

        assert_eq!(b.num_rows(), 4);
        assert_eq!(b.num_columns(), 2);

        let c = b.column(0).as_boolean();
        assert_eq!(c.null_count(), 1);
        assert!(c.value(0));
        assert!(!c.value(1));
        assert!(c.is_null(2));
        assert!(!c.value(3));

        let c = b.column(1).as_boolean();
        assert_eq!(c.null_count(), 1);
        assert!(!c.value(0));
        assert!(c.value(1));
        assert!(c.value(2));
        assert!(c.is_null(3));
    }
2249
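    // With `with_truncated_rows(true)` short rows are accepted and the missing
    // trailing fields become nulls; with it disabled the same input is an error.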
    #[test]
    fn test_truncated_rows() {
        let data = "a,b,c\n1,2,3\n4,5\n\n6,7,8";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
            Field::new("c", DataType::Int32, true),
        ]));

        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(true)
            .build(Cursor::new(data))
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(batches.is_ok());
        let batch = batches.unwrap().into_iter().next().unwrap();
        // The empty line is skipped; the truncated "4,5" row is padded with a null
        assert_eq!(batch.num_rows(), 3);

        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(false)
            .build(Cursor::new(data))
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(match batches {
            Err(ArrowError::CsvError(e)) => e.to_string().contains("incorrect number of fields"),
            _ => false,
        });
    }

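    // The same null-padding behaviour, exercised against a checked-in file whose
    // rows are missing varying numbers of trailing fields.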
    #[test]
    fn test_truncated_rows_csv() {
        let file = File::open("test/data/truncated_rows.csv").unwrap();
        let schema = Arc::new(Schema::new(vec![
            Field::new("Name", DataType::Utf8, true),
            Field::new("Age", DataType::UInt32, true),
            Field::new("Occupation", DataType::Utf8, true),
            Field::new("DOB", DataType::Date32, true),
        ]));
        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_batch_size(24)
            .with_truncated_rows(true);
        let csv = reader.build(file).unwrap();
        let batches = csv.collect::<Result<Vec<_>, _>>().unwrap();

        assert_eq!(batches.len(), 1);
        let batch = &batches[0];
        assert_eq!(batch.num_rows(), 6);
        assert_eq!(batch.num_columns(), 4);
        let name = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let age = batch
            .column(1)
            .as_any()
            .downcast_ref::<UInt32Array>()
            .unwrap();
        let occupation = batch
            .column(2)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let dob = batch
            .column(3)
            .as_any()
            .downcast_ref::<Date32Array>()
            .unwrap();

        assert_eq!(name.value(0), "A1");
        assert_eq!(name.value(1), "B2");
        assert!(name.is_null(2));
        assert_eq!(name.value(3), "C3");
        assert_eq!(name.value(4), "D4");
        assert_eq!(name.value(5), "E5");

        assert_eq!(age.value(0), 34);
        assert_eq!(age.value(1), 29);
        assert!(age.is_null(2));
        assert_eq!(age.value(3), 45);
        assert!(age.is_null(4));
        assert_eq!(age.value(5), 31);

        assert_eq!(occupation.value(0), "Engineer");
        assert_eq!(occupation.value(1), "Doctor");
        assert!(occupation.is_null(2));
        assert_eq!(occupation.value(3), "Artist");
        assert!(occupation.is_null(4));
        assert!(occupation.is_null(5));

        // Date32 counts days from the Unix epoch: 5675 = 1985-07-16, -1858 = 1964-11-30
        assert_eq!(dob.value(0), 5675);
        assert!(dob.is_null(1));
        assert!(dob.is_null(2));
        assert_eq!(dob.value(3), -1858);
        assert!(dob.is_null(4));
        assert!(dob.is_null(5));
    }

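    // Null-padding only applies to nullable fields: a truncated row that would
    // leave a non-nullable column empty must produce an error.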
    #[test]
    fn test_truncated_rows_not_nullable_error() {
        let data = "a,b,c\n1,2,3\n4,5";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));

        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(true)
            .build(Cursor::new(data))
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(match batches {
            Err(ArrowError::InvalidArgumentError(e)) => e.to_string().contains("contains null values"),
            _ => false,
        });
    }

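    // Reading via `build_buffered` should yield batches identical to the plain
    // reader for every combination of batch size and (deliberately tiny) buffer
    // capacity.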
    #[test]
    fn test_buffered() {
        let tests = [
            ("test/data/uk_cities.csv", false, 37),
            ("test/data/various_types.csv", true, 10),
            ("test/data/decimal_test.csv", false, 10),
        ];

        for (path, has_header, expected_rows) in tests {
            let (schema, _) = Format::default()
                .infer_schema(File::open(path).unwrap(), None)
                .unwrap();
            let schema = Arc::new(schema);

            for batch_size in [1, 4] {
                for capacity in [1, 3, 7, 100] {
                    let reader = ReaderBuilder::new(schema.clone())
                        .with_batch_size(batch_size)
                        .with_header(has_header)
                        .build(File::open(path).unwrap())
                        .unwrap();

                    let expected = reader.collect::<Result<Vec<_>, _>>().unwrap();

                    assert_eq!(
                        expected.iter().map(|x| x.num_rows()).sum::<usize>(),
                        expected_rows
                    );

                    let buffered =
                        std::io::BufReader::with_capacity(capacity, File::open(path).unwrap());

                    let reader = ReaderBuilder::new(schema.clone())
                        .with_batch_size(batch_size)
                        .with_header(has_header)
                        .build_buffered(buffered)
                        .unwrap();

                    let actual = reader.collect::<Result<Vec<_>, _>>().unwrap();
                    assert_eq!(expected, actual)
                }
            }
        }
    }

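    // Helper: feeds malformed CSV through both Utf8 and Utf8View schemas and
    // asserts the exact error message in each case.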
    fn err_test(csv: &[u8], expected: &str) {
        fn err_test_with_schema(csv: &[u8], expected: &str, schema: Arc<Schema>) {
            let buffer = std::io::BufReader::with_capacity(2, Cursor::new(csv));
            let b = ReaderBuilder::new(schema)
                .with_batch_size(2)
                .build_buffered(buffer)
                .unwrap();
            let err = b.collect::<Result<Vec<_>, _>>().unwrap_err().to_string();
            assert_eq!(err, expected)
        }

        let schema_utf8 = Arc::new(Schema::new(vec![
            Field::new("text1", DataType::Utf8, true),
            Field::new("text2", DataType::Utf8, true),
        ]));
        err_test_with_schema(csv, expected, schema_utf8);

        let schema_utf8view = Arc::new(Schema::new(vec![
            Field::new("text1", DataType::Utf8View, true),
            Field::new("text2", DataType::Utf8View, true),
        ]));
        err_test_with_schema(csv, expected, schema_utf8view);
    }

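    // 0xFF never occurs in well-formed UTF-8, so the reader should report the
    // 1-based line and field of the first offending byte.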
    #[test]
    fn test_invalid_utf8() {
        err_test(
            b"sdf,dsfg\ndfd,hgh\xFFue\n,sds\nFalhghse,",
            "Csv error: Encountered invalid UTF-8 data for line 2 and field 2",
        );

        err_test(
            b"sdf,dsfg\ndksdk,jf\nd\xFFfd,hghue\n,sds\nFalhghse,",
            "Csv error: Encountered invalid UTF-8 data for line 3 and field 1",
        );

        err_test(
            b"sdf,dsfg\ndksdk,jf\ndsdsfd,hghue\n,sds\nFalhghse,\xFF",
            "Csv error: Encountered invalid UTF-8 data for line 5 and field 2",
        );

        err_test(
            b"\xFFsdf,dsfg\ndksdk,jf\ndsdsfd,hghue\n,sds\nFalhghse,\xFF",
            "Csv error: Encountered invalid UTF-8 data for line 1 and field 1",
        );
    }

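    /// Wraps a `BufRead`, recording how many times `fill_buf` is called and how
    /// many bytes each call exposed, so tests can observe the reader's IO pattern.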
    struct InstrumentedRead<R> {
        r: R,
        fill_count: usize,
        fill_sizes: Vec<usize>,
    }

    impl<R> InstrumentedRead<R> {
        fn new(r: R) -> Self {
            Self {
                r,
                fill_count: 0,
                fill_sizes: vec![],
            }
        }
    }

    impl<R: Seek> Seek for InstrumentedRead<R> {
        fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
            self.r.seek(pos)
        }
    }

    impl<R: BufRead> Read for InstrumentedRead<R> {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            self.r.read(buf)
        }
    }

    impl<R: BufRead> BufRead for InstrumentedRead<R> {
        fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
            self.fill_count += 1;
            let buf = self.r.fill_buf()?;
            self.fill_sizes.push(buf.len());
            Ok(buf)
        }

        fn consume(&mut self, amt: usize) {
            self.r.consume(amt)
        }
    }

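    // Four records with a batch size of 3 should produce two batches (3 + 1 rows)
    // while consuming the input in a predictable sequence of buffer fills.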
    #[test]
    fn test_io() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("b", DataType::Utf8, false),
        ]));
        let csv = "foo,bar\nbaz,foo\na,b\nc,d";
        let mut read = InstrumentedRead::new(Cursor::new(csv.as_bytes()));
        let reader = ReaderBuilder::new(schema)
            .with_batch_size(3)
            .build_buffered(&mut read)
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].num_rows(), 3);
        assert_eq!(batches[1].num_rows(), 1);

        // Expect exactly four calls to fill_buf:
        // 1. the full 23-byte input
        // 2. the 3 bytes of "c,d" left after the first three records are consumed
        // 3. an empty read observing EOF, flushing the final record
        // 4. a final empty read confirming EOF
        assert_eq!(&read.fill_sizes, &[23, 3, 0, 0]);
        assert_eq!(read.fill_count, 4);
    }

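    // Types are inferred per column by combining the match for each sample:
    // integers widen to floats, dates widen to the finest timestamp precision
    // seen, and incompatible mixtures fall back to Utf8.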
    #[test]
    fn test_inference() {
        let cases: &[(&[&str], DataType)] = &[
            (&[], DataType::Null),
            (&["false", "12"], DataType::Utf8),
            (&["12", "cupcakes"], DataType::Utf8),
            (&["12", "12.4"], DataType::Float64),
            (&["14050", "24332"], DataType::Int64),
            (&["14050.0", "true"], DataType::Utf8),
            (&["14050", "2020-03-19 00:00:00"], DataType::Utf8),
            (&["14050", "2340.0", "2020-03-19 00:00:00"], DataType::Utf8),
            (
                &["2020-03-19 02:00:00", "2020-03-19 00:00:00"],
                DataType::Timestamp(TimeUnit::Second, None),
            ),
            (&["2020-03-19", "2020-03-20"], DataType::Date32),
            (
                &["2020-03-19", "2020-03-19 02:00:00", "2020-03-19 00:00:00"],
                DataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                &[
                    "2020-03-19",
                    "2020-03-19 02:00:00",
                    "2020-03-19 00:00:00.000",
                ],
                DataType::Timestamp(TimeUnit::Millisecond, None),
            ),
            (
                &[
                    "2020-03-19",
                    "2020-03-19 02:00:00",
                    "2020-03-19 00:00:00.000000",
                ],
                DataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                &["2020-03-19 02:00:00+02:00", "2020-03-19 02:00:00Z"],
                DataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                &[
                    "2020-03-19",
                    "2020-03-19 02:00:00+02:00",
                    "2020-03-19 02:00:00Z",
                    "2020-03-19 02:00:00.12Z",
                ],
                DataType::Timestamp(TimeUnit::Millisecond, None),
            ),
            (
                &[
                    "2020-03-19",
                    "2020-03-19 02:00:00.000000000",
                    "2020-03-19 00:00:00.000000",
                ],
                DataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
        ];

        for (values, expected) in cases {
            let mut t = InferredDataType::default();
            for v in *values {
                t.update(v)
            }
            assert_eq!(&t.get(), expected, "{values:?}")
        }
    }

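    // Schema inference should fail fast when records carry unequal field counts.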
    #[test]
    fn test_record_length_mismatch() {
        let csv = "\
        a,b,c\n\
        1,2,3\n\
        4,5\n\
        6,7,8";
        let mut read = Cursor::new(csv.as_bytes());
        let result = Format::default()
            .with_header(true)
            .infer_schema(&mut read, None);
        assert!(result.is_err());
        assert_eq!(
            result.err().unwrap().to_string(),
            "Csv error: Encountered unequal lengths between records on CSV file. Expected 2 records, found 3 records at line 3"
        );
    }

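    // Lines beginning with the configured comment byte should be skipped entirely.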
    #[test]
    fn test_comment() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int8, false),
            Field::new("b", DataType::Int8, false),
        ]);

        let csv = "# comment1 \n1,2\n#comment2\n11,22";
        let mut read = Cursor::new(csv.as_bytes());
        let reader = ReaderBuilder::new(Arc::new(schema))
            .with_comment(b'#')
            .build(&mut read)
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(batches.len(), 1);
        let b = batches.first().unwrap();
        assert_eq!(b.num_columns(), 2);
        assert_eq!(
            b.column(0)
                .as_any()
                .downcast_ref::<Int8Array>()
                .unwrap()
                .values(),
            &vec![1, 11]
        );
        assert_eq!(
            b.column(1)
                .as_any()
                .downcast_ref::<Int8Array>()
                .unwrap()
                .values(),
            &vec![2, 22]
        );
    }

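    // Utf8View inlines strings of at most 12 bytes directly in the view struct;
    // "something_cannot_be_inlined" (27 bytes) must instead reference a data
    // buffer. Both representations should survive decoding.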
    #[test]
    fn test_parse_string_view_single_column() {
        let csv = ["foo", "something_cannot_be_inlined", "foobar"].join("\n");
        let schema = Arc::new(Schema::new(vec![Field::new(
            "c1",
            DataType::Utf8View,
            true,
        )]));

        let mut decoder = ReaderBuilder::new(schema).build_decoder();

        let decoded = decoder.decode(csv.as_bytes()).unwrap();
        assert_eq!(decoded, csv.len());
        // Decoding an empty slice signals end of input, completing the final
        // record, which has no trailing newline
        decoder.decode(&[]).unwrap();

        let batch = decoder.flush().unwrap().unwrap();
        assert_eq!(batch.num_columns(), 1);
        assert_eq!(batch.num_rows(), 3);
        let col = batch.column(0).as_string_view();
        assert_eq!(col.data_type(), &DataType::Utf8View);
        assert_eq!(col.value(0), "foo");
        assert_eq!(col.value(1), "something_cannot_be_inlined");
        assert_eq!(col.value(2), "foobar");
    }

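    // As above, but across two Utf8View columns, with empty fields decoding as
    // nulls.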
    #[test]
    fn test_parse_string_view_multi_column() {
        let csv = ["foo,", ",something_cannot_be_inlined", "foobarfoobar,bar"].join("\n");
        let schema = Arc::new(Schema::new(vec![
            Field::new("c1", DataType::Utf8View, true),
            Field::new("c2", DataType::Utf8View, true),
        ]));

        let mut decoder = ReaderBuilder::new(schema).build_decoder();

        let decoded = decoder.decode(csv.as_bytes()).unwrap();
        assert_eq!(decoded, csv.len());
        decoder.decode(&[]).unwrap();

        let batch = decoder.flush().unwrap().unwrap();
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.num_rows(), 3);
        let c1 = batch.column(0).as_string_view();
        let c2 = batch.column(1).as_string_view();
        assert_eq!(c1.data_type(), &DataType::Utf8View);
        assert_eq!(c2.data_type(), &DataType::Utf8View);

        assert!(!c1.is_null(0));
        assert!(c1.is_null(1));
        assert!(!c1.is_null(2));
        assert_eq!(c1.value(0), "foo");
        assert_eq!(c1.value(2), "foobarfoobar");

        assert!(c2.is_null(0));
        assert!(!c2.is_null(1));
        assert!(!c2.is_null(2));
        assert_eq!(c2.value(1), "something_cannot_be_inlined");
        assert_eq!(c2.value(2), "bar");
    }
}