mod records;

use arrow_array::builder::{NullBuilder, PrimitiveBuilder};
use arrow_array::types::*;
use arrow_array::*;
use arrow_cast::parse::{parse_decimal, string_to_datetime, Parser};
use arrow_schema::*;
use chrono::{TimeZone, Utc};
use csv::StringRecord;
use lazy_static::lazy_static;
use regex::{Regex, RegexSet};
use std::fmt::{self, Debug};
use std::fs::File;
use std::io::{BufRead, BufReader as StdBufReader, Read};
use std::sync::Arc;

use crate::map_csv_error;
use crate::reader::records::{RecordDecoder, StringRecords};
use arrow_array::timezone::Tz;

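// The order of entries in this set matters: the index of the first matching
// regex is the bit set in `InferredDataType::packed` (boolean, integer,
// decimal, date, then timestamps at second / millisecond / microsecond /
// nanosecond precision).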
lazy_static! {
    static ref REGEX_SET: RegexSet = RegexSet::new([
        r"(?i)^(true)$|^(false)$(?-i)",
        r"^-?(\d+)$",
        r"^-?((\d*\.\d+|\d+\.\d*)([eE][-+]?\d+)?|\d+([eE][-+]?\d+))$",
        r"^\d{4}-\d\d-\d\d$",
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d(?:[^\d\.].*)?$",
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,3}(?:[^\d].*)?$",
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,6}(?:[^\d].*)?$",
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,9}(?:[^\d].*)?$",
    ])
    .unwrap();
}

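/// Specifies the strings that should be decoded as null values when parsing;
/// when no regex is provided, only the empty string is treated as null.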
#[derive(Debug, Clone, Default)]
struct NullRegex(Option<Regex>);

impl NullRegex {
    /// Returns true if the given string should be considered null
    #[inline]
    fn is_null(&self, s: &str) -> bool {
        match &self.0 {
            Some(r) => r.is_match(s),
            None => s.is_empty(),
        }
    }
}

#[derive(Default, Copy, Clone)]
struct InferredDataType {
    /// Packed booleans indicating the type candidates seen so far
    ///
    /// 0: Boolean
    /// 1: Integer
    /// 2: Float64
    /// 3: Date32
    /// 4: Timestamp(Second)
    /// 5: Timestamp(Millisecond)
    /// 6: Timestamp(Microsecond)
    /// 7: Timestamp(Nanosecond)
    /// 8: Utf8
    packed: u16,
}

impl InferredDataType {
    /// Returns the inferred data type
    fn get(&self) -> DataType {
        match self.packed {
            0 => DataType::Null,
            1 => DataType::Boolean,
            2 => DataType::Int64,
            4 | 6 => DataType::Float64, // Promote Int64 to Float64
            b if b != 0 && (b & !0b11111000) == 0 => match b.leading_zeros() {
                // Promote to the highest precision temporal type seen
                8 => DataType::Timestamp(TimeUnit::Nanosecond, None),
                9 => DataType::Timestamp(TimeUnit::Microsecond, None),
                10 => DataType::Timestamp(TimeUnit::Millisecond, None),
                11 => DataType::Timestamp(TimeUnit::Second, None),
                12 => DataType::Date32,
                _ => unreachable!(),
            },
            _ => DataType::Utf8,
        }
    }

    /// Updates the [`InferredDataType`] with the given string
    fn update(&mut self, string: &str) {
        self.packed |= if string.starts_with('"') {
            1 << 8 // Utf8
        } else if let Some(m) = REGEX_SET.matches(string).into_iter().next() {
            if m == 1 && string.len() >= 19 && string.parse::<i64>().is_err() {
                // Integer-like strings of 19+ digits that overflow i64 fall back to Utf8
                1 << 8
            } else {
                1 << m
            }
        } else if string == "NaN" || string == "nan" || string == "inf" || string == "-inf" {
            1 << 2 // Float64
        } else {
            1 << 8 // Utf8
        }
    }
}

/// The format specification of a CSV file
#[derive(Debug, Clone, Default)]
pub struct Format {
    header: bool,
    delimiter: Option<u8>,
    escape: Option<u8>,
    quote: Option<u8>,
    terminator: Option<u8>,
    comment: Option<u8>,
    null_regex: NullRegex,
    truncated_rows: bool,
}

impl Format {
    /// Specify whether the CSV file has a header, defaults to `false`
    pub fn with_header(mut self, has_header: bool) -> Self {
        self.header = has_header;
        self
    }

    /// Specify a custom delimiter character, defaults to comma `','`
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.delimiter = Some(delimiter);
        self
    }

    /// Specify an escape character
    pub fn with_escape(mut self, escape: u8) -> Self {
        self.escape = Some(escape);
        self
    }

    /// Specify a custom quote character, defaults to double quote `'"'`
    pub fn with_quote(mut self, quote: u8) -> Self {
        self.quote = Some(quote);
        self
    }

    /// Specify a custom terminator character, defaults to CRLF
    pub fn with_terminator(mut self, terminator: u8) -> Self {
        self.terminator = Some(terminator);
        self
    }

    /// Specify a comment character; lines starting with this character will be ignored
    pub fn with_comment(mut self, comment: u8) -> Self {
        self.comment = Some(comment);
        self
    }

    /// Provide a regex to match null values, defaults to treating only the empty string as null
    pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
        self.null_regex = NullRegex(Some(null_regex));
        self
    }

    /// Whether to allow truncated rows when parsing
    ///
    /// By default this is set to `false` and will error if the CSV rows have different lengths
    pub fn with_truncated_rows(mut self, allow: bool) -> Self {
        self.truncated_rows = allow;
        self
    }

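    /// Infer the schema of `reader`, returning the inferred schema and the
    /// number of records read.
    ///
    /// A minimal sketch of intended usage, assuming a local file `data.csv`
    /// (the path is illustrative):
    ///
    /// ```ignore
    /// let mut file = std::fs::File::open("data.csv")?;
    /// let format = Format::default().with_header(true);
    /// // Sample at most 100 records for inference
    /// let (schema, _count) = format.infer_schema(&mut file, Some(100))?;
    /// ```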
    pub fn infer_schema<R: Read>(
        &self,
        reader: R,
        max_records: Option<usize>,
    ) -> Result<(Schema, usize), ArrowError> {
        let mut csv_reader = self.build_reader(reader);

        // get or create header names
        // when `header` is false, creates default column names with column_ prefix
        let headers: Vec<String> = if self.header {
            let headers = csv_reader.headers().map_err(map_csv_error)?.clone();
            headers.iter().map(|s| s.to_string()).collect()
        } else {
            let first_record_count = csv_reader.headers().map_err(map_csv_error)?.len();
            (0..first_record_count)
                .map(|i| format!("column_{}", i + 1))
                .collect()
        };

        let header_length = headers.len();
        // keep track of inferred field types
        let mut column_types: Vec<InferredDataType> = vec![Default::default(); header_length];

        let mut records_count = 0;

        let mut record = StringRecord::new();
        let max_records = max_records.unwrap_or(usize::MAX);
        while records_count < max_records {
            if !csv_reader.read_record(&mut record).map_err(map_csv_error)? {
                break;
            }
            records_count += 1;

            for (i, column_type) in column_types.iter_mut().enumerate().take(header_length) {
                if let Some(string) = record.get(i) {
                    if !self.null_regex.is_null(string) {
                        column_type.update(string)
                    }
                }
            }
        }

        // build the schema; fields are nullable since only a sample of the data was inspected
        let fields: Fields = column_types
            .iter()
            .zip(&headers)
            .map(|(inferred, field_name)| Field::new(field_name, inferred.get(), true))
            .collect();

        Ok((Schema::new(fields), records_count))
    }

    fn build_reader<R: Read>(&self, reader: R) -> csv::Reader<R> {
        let mut builder = csv::ReaderBuilder::new();
        builder.has_headers(self.header);
        builder.flexible(self.truncated_rows);

        if let Some(c) = self.delimiter {
            builder.delimiter(c);
        }
        builder.escape(self.escape);
        if let Some(c) = self.quote {
            builder.quote(c);
        }
        if let Some(t) = self.terminator {
            builder.terminator(csv::Terminator::Any(t));
        }
        if let Some(comment) = self.comment {
            builder.comment(Some(comment));
        }
        builder.from_reader(reader)
    }

    fn build_parser(&self) -> csv_core::Reader {
        let mut builder = csv_core::ReaderBuilder::new();
        builder.escape(self.escape);
        builder.comment(self.comment);

        if let Some(c) = self.delimiter {
            builder.delimiter(c);
        }
        if let Some(c) = self.quote {
            builder.quote(c);
        }
        if let Some(t) = self.terminator {
            builder.terminator(csv_core::Terminator::Any(t));
        }
        builder.build()
    }
}

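/// Infer the schema of a CSV file by reading through the first
/// `max_read_records` records across `files`, merging the per-file schemas
/// with [`Schema::try_merge`]. Inferred fields are nullable; empty files are
/// skipped.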
pub fn infer_schema_from_files(
    files: &[String],
    delimiter: u8,
    max_read_records: Option<usize>,
    has_header: bool,
) -> Result<Schema, ArrowError> {
    let mut schemas = vec![];
    let mut records_to_read = max_read_records.unwrap_or(usize::MAX);
    let format = Format {
        delimiter: Some(delimiter),
        header: has_header,
        ..Default::default()
    };

    for fname in files.iter() {
        let f = File::open(fname)?;
        let (schema, records_read) = format.infer_schema(f, Some(records_to_read))?;
        if records_read == 0 {
            continue;
        }
        schemas.push(schema);
        records_to_read -= records_read;
        if records_to_read == 0 {
            break;
        }
    }

    Schema::try_merge(schemas)
}

/// The bounds of a range of rows to read, as `(start, end)` line numbers
type Bounds = Option<(usize, usize)>;

/// CSV file reader using [`std::io::BufReader`]
pub type Reader<R> = BufReader<StdBufReader<R>>;

/// CSV file reader
pub struct BufReader<R> {
    /// File reader
    reader: R,

    /// The decoder that converts the read bytes into [`RecordBatch`]es
    decoder: Decoder,
}
454
455impl<R> fmt::Debug for BufReader<R>
456where
457 R: BufRead,
458{
459 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
460 f.debug_struct("Reader")
461 .field("decoder", &self.decoder)
462 .finish()
463 }
464}
465
466impl<R: Read> Reader<R> {
467 pub fn schema(&self) -> SchemaRef {
470 match &self.decoder.projection {
471 Some(projection) => {
472 let fields = self.decoder.schema.fields();
473 let projected = projection.iter().map(|i| fields[*i].clone());
474 Arc::new(Schema::new(projected.collect::<Fields>()))
475 }
476 None => self.decoder.schema.clone(),
477 }
478 }
479}
480
impl<R: BufRead> BufReader<R> {
    fn read(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        loop {
            let buf = self.reader.fill_buf()?;
            let decoded = self.decoder.decode(buf)?;
            self.reader.consume(decoded);
            // Yield if decoded no bytes or the decoder is full
            //
            // The capacity check avoids looping around and potentially
            // blocking in fill_buf to read data that isn't needed
            // to flush the next batch
            if decoded == 0 || self.decoder.capacity() == 0 {
                break;
            }
        }

        self.decoder.flush()
    }
}

impl<R: BufRead> Iterator for BufReader<R> {
    type Item = Result<RecordBatch, ArrowError>;

    fn next(&mut self) -> Option<Self::Item> {
        self.read().transpose()
    }
}

impl<R: BufRead> RecordBatchReader for BufReader<R> {
    fn schema(&self) -> SchemaRef {
        self.decoder.schema.clone()
    }
}

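/// A push-based interface for decoding CSV data from an arbitrary byte stream,
/// decoupling the decoding from the source of the bytes
///
/// See [`Reader`] for a higher-level interface that pulls from a [`Read`].
///
/// A minimal sketch of the intended decode loop, assuming some iterator of
/// byte chunks (e.g. network packets) as input:
///
/// ```ignore
/// fn decode_stream(
///     mut decoder: Decoder,
///     chunks: impl Iterator<Item = Vec<u8>>,
/// ) -> Result<Vec<RecordBatch>, ArrowError> {
///     let mut batches = Vec::new();
///     for chunk in chunks {
///         let mut buf = chunk.as_slice();
///         while !buf.is_empty() {
///             // `decode` returns the number of bytes it consumed
///             let read = decoder.decode(buf)?;
///             buf = &buf[read..];
///             // A full batch is buffered once the decoder's capacity is exhausted
///             if decoder.capacity() == 0 {
///                 if let Some(batch) = decoder.flush()? {
///                     batches.push(batch);
///                 }
///             }
///         }
///     }
///     // Flush any remaining partially-filled batch
///     if let Some(batch) = decoder.flush()? {
///         batches.push(batch);
///     }
///     Ok(batches)
/// }
/// ```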
#[derive(Debug)]
pub struct Decoder {
    /// The schema of the returned batches
    schema: SchemaRef,

    /// Optional projection of the columns to return
    projection: Option<Vec<usize>>,

    /// The maximum number of records to read in a batch
    batch_size: usize,

    /// The number of records to skip before returning any
    to_skip: usize,

    /// The current line number, used for error reporting
    line_number: usize,

    /// The bound on the last line to read, exclusive
    end: usize,

    /// The low-level record decoder
    record_decoder: RecordDecoder,

    /// The regex used to identify null values
    null_regex: NullRegex,
}

impl Decoder {
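    /// Decode records from `buf`, returning the number of bytes consumed
    ///
    /// This method returns once `batch_size` records have been read since the
    /// last call to [`Self::flush`], or `buf` is exhausted. Any remaining bytes
    /// should be included in the next call to [`Self::decode`].
    ///
    /// There is no requirement that `buf` contains a whole number of records,
    /// facilitating incremental parsing of streaming data.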
    pub fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
        if self.to_skip != 0 {
            // Skip in units of `batch_size` to avoid over-allocating
            let to_skip = self.to_skip.min(self.batch_size);
            let (skipped, bytes) = self.record_decoder.decode(buf, to_skip)?;
            self.to_skip -= skipped;
            self.record_decoder.clear();
            return Ok(bytes);
        }

        let to_read = self.batch_size.min(self.end - self.line_number) - self.record_decoder.len();
        let (_, bytes) = self.record_decoder.decode(buf, to_read)?;
        Ok(bytes)
    }

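    /// Flushes the currently buffered data to a [`RecordBatch`], returning
    /// `Ok(None)` if no data is buffered
    ///
    /// This should only be called after [`Self::decode`] has returned `Ok(0)`,
    /// otherwise it may return an error if part way through decoding a record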
    pub fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        if self.record_decoder.is_empty() {
            return Ok(None);
        }

        let rows = self.record_decoder.flush()?;
        let batch = parse(
            &rows,
            self.schema.fields(),
            Some(self.schema.metadata.clone()),
            self.projection.as_ref(),
            self.line_number,
            &self.null_regex,
        )?;
        self.line_number += rows.len();
        Ok(Some(batch))
    }

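    /// Returns the number of records that can be read before this decoder
    /// requires a call to [`Self::flush`]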
    pub fn capacity(&self) -> usize {
        self.batch_size - self.record_decoder.len()
    }
}

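/// Parses a slice of [`StringRecords`] into a [`RecordBatch`], applying the
/// optional `projection` and using `line_number` for error reporting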
fn parse(
    rows: &StringRecords<'_>,
    fields: &Fields,
    metadata: Option<std::collections::HashMap<String, String>>,
    projection: Option<&Vec<usize>>,
    line_number: usize,
    null_regex: &NullRegex,
) -> Result<RecordBatch, ArrowError> {
    let projection: Vec<usize> = match projection {
        Some(v) => v.clone(),
        None => fields.iter().enumerate().map(|(i, _)| i).collect(),
    };

    let arrays: Result<Vec<ArrayRef>, _> = projection
        .iter()
        .map(|i| {
            let i = *i;
            let field = &fields[i];
            match field.data_type() {
                DataType::Boolean => build_boolean_array(line_number, rows, i, null_regex),
                DataType::Decimal128(precision, scale) => build_decimal_array::<Decimal128Type>(
                    line_number,
                    rows,
                    i,
                    *precision,
                    *scale,
                    null_regex,
                ),
                DataType::Decimal256(precision, scale) => build_decimal_array::<Decimal256Type>(
                    line_number,
                    rows,
                    i,
                    *precision,
                    *scale,
                    null_regex,
                ),
                DataType::Int8 => {
                    build_primitive_array::<Int8Type>(line_number, rows, i, null_regex)
                }
                DataType::Int16 => {
                    build_primitive_array::<Int16Type>(line_number, rows, i, null_regex)
                }
                DataType::Int32 => {
                    build_primitive_array::<Int32Type>(line_number, rows, i, null_regex)
                }
                DataType::Int64 => {
                    build_primitive_array::<Int64Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt8 => {
                    build_primitive_array::<UInt8Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt16 => {
                    build_primitive_array::<UInt16Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt32 => {
                    build_primitive_array::<UInt32Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt64 => {
                    build_primitive_array::<UInt64Type>(line_number, rows, i, null_regex)
                }
                DataType::Float32 => {
                    build_primitive_array::<Float32Type>(line_number, rows, i, null_regex)
                }
                DataType::Float64 => {
                    build_primitive_array::<Float64Type>(line_number, rows, i, null_regex)
                }
                DataType::Date32 => {
                    build_primitive_array::<Date32Type>(line_number, rows, i, null_regex)
                }
                DataType::Date64 => {
                    build_primitive_array::<Date64Type>(line_number, rows, i, null_regex)
                }
                DataType::Time32(TimeUnit::Second) => {
                    build_primitive_array::<Time32SecondType>(line_number, rows, i, null_regex)
                }
                DataType::Time32(TimeUnit::Millisecond) => {
                    build_primitive_array::<Time32MillisecondType>(line_number, rows, i, null_regex)
                }
                DataType::Time64(TimeUnit::Microsecond) => {
                    build_primitive_array::<Time64MicrosecondType>(line_number, rows, i, null_regex)
                }
                DataType::Time64(TimeUnit::Nanosecond) => {
                    build_primitive_array::<Time64NanosecondType>(line_number, rows, i, null_regex)
                }
                DataType::Timestamp(TimeUnit::Second, tz) => {
                    build_timestamp_array::<TimestampSecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Timestamp(TimeUnit::Millisecond, tz) => {
                    build_timestamp_array::<TimestampMillisecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Timestamp(TimeUnit::Microsecond, tz) => {
                    build_timestamp_array::<TimestampMicrosecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
                    build_timestamp_array::<TimestampNanosecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Null => Ok(Arc::new({
                    let mut builder = NullBuilder::new();
                    builder.append_nulls(rows.len());
                    builder.finish()
                }) as ArrayRef),
                DataType::Utf8 => Ok(Arc::new(
                    rows.iter()
                        .map(|row| {
                            let s = row.get(i);
                            (!null_regex.is_null(s)).then_some(s)
                        })
                        .collect::<StringArray>(),
                ) as ArrayRef),
                DataType::Utf8View => Ok(Arc::new(
                    rows.iter()
                        .map(|row| {
                            let s = row.get(i);
                            (!null_regex.is_null(s)).then_some(s)
                        })
                        .collect::<StringViewArray>(),
                ) as ArrayRef),
                DataType::Dictionary(key_type, value_type)
                    if value_type.as_ref() == &DataType::Utf8 =>
                {
                    match key_type.as_ref() {
                        DataType::Int8 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int8Type>>(),
                        ) as ArrayRef),
                        DataType::Int16 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int16Type>>(),
                        ) as ArrayRef),
                        DataType::Int32 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int32Type>>(),
                        ) as ArrayRef),
                        DataType::Int64 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int64Type>>(),
                        ) as ArrayRef),
                        DataType::UInt8 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt8Type>>(),
                        ) as ArrayRef),
                        DataType::UInt16 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt16Type>>(),
                        ) as ArrayRef),
                        DataType::UInt32 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt32Type>>(),
                        ) as ArrayRef),
                        DataType::UInt64 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt64Type>>(),
                        ) as ArrayRef),
                        _ => Err(ArrowError::ParseError(format!(
                            "Unsupported dictionary key type {key_type:?}"
                        ))),
                    }
                }
                other => Err(ArrowError::ParseError(format!(
                    "Unsupported data type {other:?}"
                ))),
            }
        })
        .collect();

    let projected_fields: Fields = projection.iter().map(|i| fields[*i].clone()).collect();

    let projected_schema = Arc::new(match metadata {
        None => Schema::new(projected_fields),
        Some(metadata) => Schema::new_with_metadata(projected_fields, metadata),
    });

    arrays.and_then(|arr| {
        RecordBatch::try_new_with_options(
            projected_schema,
            arr,
            &RecordBatchOptions::new()
                .with_match_field_names(true)
                .with_row_count(Some(rows.len())),
        )
    })
}

fn parse_bool(string: &str) -> Option<bool> {
    if string.eq_ignore_ascii_case("false") {
        Some(false)
    } else if string.eq_ignore_ascii_case("true") {
        Some(true)
    } else {
        None
    }
}

// parse the column string to an Arrow Array
fn build_decimal_array<T: DecimalType>(
    _line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    precision: u8,
    scale: i8,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    let mut decimal_builder = PrimitiveBuilder::<T>::with_capacity(rows.len());
    for row in rows.iter() {
        let s = row.get(col_idx);
        if null_regex.is_null(s) {
            // append null
            decimal_builder.append_null();
        } else {
            // append the decimal value converted from the string
            let decimal_value: Result<T::Native, _> = parse_decimal::<T>(s, precision, scale);
            match decimal_value {
                Ok(v) => {
                    decimal_builder.append_value(v);
                }
                Err(e) => {
                    return Err(e);
                }
            }
        }
    }
    Ok(Arc::new(
        decimal_builder
            .finish()
            .with_precision_and_scale(precision, scale)?,
    ))
}

// parses a specific column (col_idx) into an Arrow Array
fn build_primitive_array<T: ArrowPrimitiveType + Parser>(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    rows.iter()
        .enumerate()
        .map(|(row_index, row)| {
            let s = row.get(col_idx);
            if null_regex.is_null(s) {
                return Ok(None);
            }

            match T::parse(s) {
                Some(e) => Ok(Some(e)),
                None => Err(ArrowError::ParseError(format!(
                    "Error while parsing value '{}' as type '{}' for column {} at line {}. Row data: '{}'",
                    s,
                    T::DATA_TYPE,
                    col_idx,
                    line_number + row_index,
                    row
                ))),
            }
        })
        .collect::<Result<PrimitiveArray<T>, ArrowError>>()
        .map(|e| Arc::new(e) as ArrayRef)
}

fn build_timestamp_array<T: ArrowTimestampType>(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    timezone: Option<&str>,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    Ok(Arc::new(match timezone {
        Some(timezone) => {
            let tz: Tz = timezone.parse()?;
            build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &tz, null_regex)?
                .with_timezone(timezone)
        }
        None => build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &Utc, null_regex)?,
    }))
}

fn build_timestamp_array_impl<T: ArrowTimestampType, Tz: TimeZone>(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    timezone: &Tz,
    null_regex: &NullRegex,
) -> Result<PrimitiveArray<T>, ArrowError> {
    rows.iter()
        .enumerate()
        .map(|(row_index, row)| {
            let s = row.get(col_idx);
            if null_regex.is_null(s) {
                return Ok(None);
            }

            let date = string_to_datetime(timezone, s)
                .and_then(|date| match T::UNIT {
                    TimeUnit::Second => Ok(date.timestamp()),
                    TimeUnit::Millisecond => Ok(date.timestamp_millis()),
                    TimeUnit::Microsecond => Ok(date.timestamp_micros()),
                    TimeUnit::Nanosecond => date.timestamp_nanos_opt().ok_or_else(|| {
                        ArrowError::ParseError(format!(
                            "{} would overflow 64-bit signed nanoseconds",
                            date.to_rfc3339(),
                        ))
                    }),
                })
                .map_err(|e| {
                    ArrowError::ParseError(format!(
                        "Error parsing column {col_idx} at line {}: {}",
                        line_number + row_index,
                        e
                    ))
                })?;
            Ok(Some(date))
        })
        .collect()
}

// parses a specific column (col_idx) into an Arrow Array
fn build_boolean_array(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    rows.iter()
        .enumerate()
        .map(|(row_index, row)| {
            let s = row.get(col_idx);
            if null_regex.is_null(s) {
                return Ok(None);
            }
            let parsed = parse_bool(s);
            match parsed {
                Some(e) => Ok(Some(e)),
                None => Err(ArrowError::ParseError(format!(
                    "Error while parsing value '{}' as type '{}' for column {} at line {}. Row data: '{}'",
                    s,
                    "Boolean",
                    col_idx,
                    line_number + row_index,
                    row
                ))),
            }
        })
        .collect::<Result<BooleanArray, _>>()
        .map(|e| Arc::new(e) as ArrayRef)
}

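/// CSV file reader builder
///
/// A minimal sketch of typical usage, assuming a local file `example.csv`
/// (the path is illustrative) whose schema is inferred before reading:
///
/// ```ignore
/// use std::fs::File;
/// use std::io::Seek;
/// use std::sync::Arc;
///
/// let mut file = File::open("example.csv").unwrap();
/// // Infer the schema from the first 100 records
/// let (schema, _) = Format::default().infer_schema(&mut file, Some(100)).unwrap();
/// file.rewind().unwrap();
///
/// let reader = ReaderBuilder::new(Arc::new(schema)).build(file).unwrap();
/// for batch in reader {
///     println!("{:?}", batch.unwrap());
/// }
/// ```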
#[derive(Debug)]
pub struct ReaderBuilder {
    /// Schema of the CSV file
    schema: SchemaRef,
    /// Format of the CSV file
    format: Format,
    /// Batch size (number of records to load each time), defaults to 1024
    batch_size: usize,
    /// The optional `(start, end)` line bounds to read
    bounds: Bounds,
    /// Optional projection for which columns to load (zero-based column indices)
    projection: Option<Vec<usize>>,
}

impl ReaderBuilder {
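    /// Create a new builder for the given schema, with all options at their defaults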
    pub fn new(schema: SchemaRef) -> ReaderBuilder {
        Self {
            schema,
            format: Format::default(),
            batch_size: 1024,
            bounds: None,
            projection: None,
        }
    }

    /// Set whether the CSV file has a header
    pub fn with_header(mut self, has_header: bool) -> Self {
        self.format.header = has_header;
        self
    }

    /// Overrides the [`Format`] of this [`ReaderBuilder`]
    pub fn with_format(mut self, format: Format) -> Self {
        self.format = format;
        self
    }

    /// Set the CSV file's column delimiter as a byte character
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.format.delimiter = Some(delimiter);
        self
    }

    /// Set the CSV escape character
    pub fn with_escape(mut self, escape: u8) -> Self {
        self.format.escape = Some(escape);
        self
    }

    /// Set the CSV quote character
    pub fn with_quote(mut self, quote: u8) -> Self {
        self.format.quote = Some(quote);
        self
    }

    /// Set the CSV line terminator character
    pub fn with_terminator(mut self, terminator: u8) -> Self {
        self.format.terminator = Some(terminator);
        self
    }

    /// Set the CSV comment character
    pub fn with_comment(mut self, comment: u8) -> Self {
        self.format.comment = Some(comment);
        self
    }

    /// Provide a regex to match null values
    pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
        self.format.null_regex = NullRegex(Some(null_regex));
        self
    }

    /// Set the batch size (number of records to load at one time)
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Set the bounds over which to scan the reader.
    /// `start` and `end` are line numbers, with `end` exclusive
    pub fn with_bounds(mut self, start: usize, end: usize) -> Self {
        self.bounds = Some((start, end));
        self
    }

    /// Set the reader's column projection (zero-based column indices)
    pub fn with_projection(mut self, projection: Vec<usize>) -> Self {
        self.projection = Some(projection);
        self
    }

    /// Whether to allow truncated rows when parsing
    ///
    /// By default this is set to `false` and will error if the CSV rows have different lengths
    pub fn with_truncated_rows(mut self, allow: bool) -> Self {
        self.format.truncated_rows = allow;
        self
    }

    /// Create a new `Reader` from a non-buffered reader
    ///
    /// If `R: BufRead`, consider using [`Self::build_buffered`] to avoid unnecessary additional
    /// buffering, as this method wraps `reader` in a [`std::io::BufReader`]
    pub fn build<R: Read>(self, reader: R) -> Result<Reader<R>, ArrowError> {
        self.build_buffered(StdBufReader::new(reader))
    }

    /// Create a new `BufReader` from a buffered reader
    pub fn build_buffered<R: BufRead>(self, reader: R) -> Result<BufReader<R>, ArrowError> {
        Ok(BufReader {
            reader,
            decoder: self.build_decoder(),
        })
    }

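    /// Builds a decoder that can be used to decode CSV from an arbitrary byte stream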
    pub fn build_decoder(self) -> Decoder {
        let delimiter = self.format.build_parser();
        let record_decoder = RecordDecoder::new(
            delimiter,
            self.schema.fields().len(),
            self.format.truncated_rows,
        );

        let header = self.format.header as usize;

        let (start, end) = match self.bounds {
            Some((start, end)) => (start + header, end + header),
            None => (header, usize::MAX),
        };

        Decoder {
            schema: self.schema,
            to_skip: start,
            record_decoder,
            line_number: start,
            end,
            projection: self.projection,
            batch_size: self.batch_size,
            null_regex: self.format.null_regex,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use std::io::{Cursor, Seek, SeekFrom, Write};
    use tempfile::NamedTempFile;

    use arrow_array::cast::AsArray;

    #[test]
    fn test_csv() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]));

        let file = File::open("test/data/uk_cities.csv").unwrap();
        let mut csv = ReaderBuilder::new(schema.clone()).build(file).unwrap();
        assert_eq!(schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        // access data from a primitive array
        let lat = batch.column(1).as_primitive::<Float64Type>();
        assert_eq!(57.653484, lat.value(0));

        // access data from a string array
        let city = batch.column(0).as_string::<i32>();

        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
    }

    #[test]
    fn test_csv_schema_metadata() {
        let mut metadata = std::collections::HashMap::new();
        metadata.insert("foo".to_owned(), "bar".to_owned());
        let schema = Arc::new(Schema::new_with_metadata(
            vec![
                Field::new("city", DataType::Utf8, false),
                Field::new("lat", DataType::Float64, false),
                Field::new("lng", DataType::Float64, false),
            ],
            metadata.clone(),
        ));

        let file = File::open("test/data/uk_cities.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema.clone()).build(file).unwrap();
        assert_eq!(schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        assert_eq!(&metadata, batch.schema().metadata());
    }

    #[test]
    fn test_csv_reader_with_decimal() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Decimal128(38, 6), false),
            Field::new("lng", DataType::Decimal256(76, 6), false),
        ]));

        let file = File::open("test/data/decimal_test.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema).build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();
        let lat = batch
            .column(1)
            .as_any()
            .downcast_ref::<Decimal128Array>()
            .unwrap();

        assert_eq!("57.653484", lat.value_as_string(0));
        assert_eq!("53.002666", lat.value_as_string(1));
        assert_eq!("52.412811", lat.value_as_string(2));
        assert_eq!("51.481583", lat.value_as_string(3));
        assert_eq!("12.123456", lat.value_as_string(4));
        assert_eq!("50.760000", lat.value_as_string(5));
        assert_eq!("0.123000", lat.value_as_string(6));
        assert_eq!("123.000000", lat.value_as_string(7));
        assert_eq!("123.000000", lat.value_as_string(8));
        assert_eq!("-50.760000", lat.value_as_string(9));

        let lng = batch
            .column(2)
            .as_any()
            .downcast_ref::<Decimal256Array>()
            .unwrap();

        assert_eq!("-3.335724", lng.value_as_string(0));
        assert_eq!("-2.179404", lng.value_as_string(1));
        assert_eq!("-1.778197", lng.value_as_string(2));
        assert_eq!("-3.179090", lng.value_as_string(3));
        assert_eq!("-3.179090", lng.value_as_string(4));
        assert_eq!("0.290472", lng.value_as_string(5));
        assert_eq!("0.290472", lng.value_as_string(6));
        assert_eq!("0.290472", lng.value_as_string(7));
        assert_eq!("0.290472", lng.value_as_string(8));
        assert_eq!("0.290472", lng.value_as_string(9));
    }

    #[test]
    fn test_csv_from_buf_reader() {
        let schema = Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]);

        let file_with_headers = File::open("test/data/uk_cities_with_headers.csv").unwrap();
        let file_without_headers = File::open("test/data/uk_cities.csv").unwrap();
        let both_files = file_with_headers
            .chain(Cursor::new("\n".to_string()))
            .chain(file_without_headers);
        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_header(true)
            .build(both_files)
            .unwrap();
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(74, batch.num_rows());
        assert_eq!(3, batch.num_columns());
    }

    #[test]
    fn test_csv_with_schema_inference() {
        let mut file = File::open("test/data/uk_cities_with_headers.csv").unwrap();

        let (schema, _) = Format::default()
            .with_header(true)
            .infer_schema(&mut file, None)
            .unwrap();

        file.rewind().unwrap();
        let builder = ReaderBuilder::new(Arc::new(schema)).with_header(true);

        let mut csv = builder.build(file).unwrap();
        let expected_schema = Schema::new(vec![
            Field::new("city", DataType::Utf8, true),
            Field::new("lat", DataType::Float64, true),
            Field::new("lng", DataType::Float64, true),
        ]);
        assert_eq!(Arc::new(expected_schema), csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        // access data from a primitive array
        let lat = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(57.653484, lat.value(0));

        // access data from a string array
        let city = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
    }

    #[test]
    fn test_csv_with_schema_inference_no_headers() {
        let mut file = File::open("test/data/uk_cities.csv").unwrap();

        let (schema, _) = Format::default().infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let mut csv = ReaderBuilder::new(Arc::new(schema)).build(file).unwrap();

        // csv field names should be 'column_{number}'
        let schema = csv.schema();
        assert_eq!("column_1", schema.field(0).name());
        assert_eq!("column_2", schema.field(1).name());
        assert_eq!("column_3", schema.field(2).name());
        let batch = csv.next().unwrap().unwrap();
        let batch_schema = batch.schema();

        assert_eq!(schema, batch_schema);
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        // access data from a primitive array
        let lat = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(57.653484, lat.value(0));

        // access data from a string array
        let city = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
    }

    #[test]
    fn test_csv_builder_with_bounds() {
        let mut file = File::open("test/data/uk_cities.csv").unwrap();

        let (schema, _) = Format::default().infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();
        // Read only rows 0 and 1 of the file
        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_bounds(0, 2)
            .build(file)
            .unwrap();
        let batch = csv.next().unwrap().unwrap();

        // access data from a string array
        let city = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        // The value on line 0 is within the bounds
        assert_eq!("Elgin, Scotland, the UK", city.value(0));

        // The value on line 13 is outside of the bounds; only two rows were
        // read, so accessing index 13 panics
        let result = std::panic::catch_unwind(|| city.value(13));
        assert!(result.is_err());
    }

    #[test]
    fn test_csv_with_projection() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]));

        let file = File::open("test/data/uk_cities.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_projection(vec![0, 1])
            .build(file)
            .unwrap();

        let projected_schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
        ]));
        assert_eq!(projected_schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(projected_schema, batch.schema());
        assert_eq!(37, batch.num_rows());
        assert_eq!(2, batch.num_columns());
    }

    #[test]
    fn test_csv_with_dictionary() {
        let schema = Arc::new(Schema::new(vec![
            Field::new_dictionary("city", DataType::Int32, DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]));

        let file = File::open("test/data/uk_cities.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_projection(vec![0, 1])
            .build(file)
            .unwrap();

        let projected_schema = Arc::new(Schema::new(vec![
            Field::new_dictionary("city", DataType::Int32, DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
        ]));
        assert_eq!(projected_schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(projected_schema, batch.schema());
        assert_eq!(37, batch.num_rows());
        assert_eq!(2, batch.num_columns());

        let strings = arrow_cast::cast(batch.column(0), &DataType::Utf8).unwrap();
        let strings = strings.as_string::<i32>();

        assert_eq!(strings.value(0), "Elgin, Scotland, the UK");
        assert_eq!(strings.value(4), "Eastbourne, East Sussex, UK");
        assert_eq!(strings.value(29), "Uckfield, East Sussex, UK");
    }

    #[test]
    fn test_csv_with_nullable_dictionary() {
        let offset_type = vec![
            DataType::Int8,
            DataType::Int16,
            DataType::Int32,
            DataType::Int64,
            DataType::UInt8,
            DataType::UInt16,
            DataType::UInt32,
            DataType::UInt64,
        ];
        for data_type in offset_type {
            let file = File::open("test/data/dictionary_nullable_test.csv").unwrap();
            let dictionary_type =
                DataType::Dictionary(Box::new(data_type), Box::new(DataType::Utf8));
            let schema = Arc::new(Schema::new(vec![
                Field::new("id", DataType::Utf8, false),
                Field::new("name", dictionary_type.clone(), true),
            ]));

            let mut csv = ReaderBuilder::new(schema)
                .build(file.try_clone().unwrap())
                .unwrap();

            let batch = csv.next().unwrap().unwrap();
            assert_eq!(3, batch.num_rows());
            assert_eq!(2, batch.num_columns());

            let names = arrow_cast::cast(batch.column(1), &dictionary_type).unwrap();
            assert!(!names.is_null(2));
            assert!(names.is_null(1));
        }
    }

    #[test]
    fn test_nulls() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c_int", DataType::UInt64, false),
            Field::new("c_float", DataType::Float32, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, false),
        ]));

        let file = File::open("test/data/null_test.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_header(true)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        assert!(!batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }

    #[test]
    fn test_init_nulls() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c_int", DataType::UInt64, true),
            Field::new("c_float", DataType::Float32, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
            Field::new("c_null", DataType::Null, true),
        ]));
        let file = File::open("test/data/init_null_test.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_header(true)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        assert!(batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }

    #[test]
    fn test_init_nulls_with_inference() {
        let format = Format::default().with_header(true).with_delimiter(b',');

        let mut file = File::open("test/data/init_null_test.csv").unwrap();
        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let expected_schema = Schema::new(vec![
            Field::new("c_int", DataType::Int64, true),
            Field::new("c_float", DataType::Float64, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
            Field::new("c_null", DataType::Null, true),
        ]);
        assert_eq!(schema, expected_schema);

        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        assert!(batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }

    #[test]
    fn test_custom_nulls() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c_int", DataType::UInt64, true),
            Field::new("c_float", DataType::Float32, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
        ]));

        let file = File::open("test/data/custom_null_test.csv").unwrap();

        let null_regex = Regex::new("^nil$").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_header(true)
            .with_null_regex(null_regex)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        assert!(batch.column(0).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(batch.column(3).is_null(4));
        assert!(batch.column(2).is_null(3));
        assert!(!batch.column(2).is_null(4));
    }

    #[test]
    fn test_nulls_with_inference() {
        let mut file = File::open("test/data/various_types.csv").unwrap();
        let format = Format::default().with_header(true).with_delimiter(b'|');

        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3, 4, 5]);

        let mut csv = builder.build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();

        assert_eq!(10, batch.num_rows());
        assert_eq!(6, batch.num_columns());

        let schema = batch.schema();

        assert_eq!(&DataType::Int64, schema.field(0).data_type());
        assert_eq!(&DataType::Float64, schema.field(1).data_type());
        assert_eq!(&DataType::Float64, schema.field(2).data_type());
        assert_eq!(&DataType::Boolean, schema.field(3).data_type());
        assert_eq!(&DataType::Date32, schema.field(4).data_type());
        assert_eq!(
            &DataType::Timestamp(TimeUnit::Second, None),
            schema.field(5).data_type()
        );

        let names: Vec<&str> = schema.fields().iter().map(|x| x.name().as_str()).collect();
        assert_eq!(
            names,
            vec![
                "c_int",
                "c_float",
                "c_string",
                "c_bool",
                "c_date",
                "c_datetime"
            ]
        );

        assert!(schema.field(0).is_nullable());
        assert!(schema.field(1).is_nullable());
        assert!(schema.field(2).is_nullable());
        assert!(schema.field(3).is_nullable());
        assert!(schema.field(4).is_nullable());
        assert!(schema.field(5).is_nullable());

        assert!(!batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }

    #[test]
    fn test_custom_nulls_with_inference() {
        let mut file = File::open("test/data/custom_null_test.csv").unwrap();

        let null_regex = Regex::new("^nil$").unwrap();

        let format = Format::default()
            .with_header(true)
            .with_null_regex(null_regex);

        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let expected_schema = Schema::new(vec![
            Field::new("c_int", DataType::Int64, true),
            Field::new("c_float", DataType::Float64, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
        ]);

        assert_eq!(schema, expected_schema);

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3]);

        let mut csv = builder.build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();

        assert_eq!(5, batch.num_rows());
        assert_eq!(4, batch.num_columns());

        assert_eq!(batch.schema().as_ref(), &expected_schema);
    }

    #[test]
    fn test_scientific_notation_with_inference() {
        let mut file = File::open("test/data/scientific_notation_test.csv").unwrap();
        let format = Format::default().with_header(false).with_delimiter(b',');

        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .with_batch_size(512)
            .with_projection(vec![0, 1]);

        let mut csv = builder.build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();

        let schema = batch.schema();

        assert_eq!(&DataType::Float64, schema.field(0).data_type());
    }

    fn invalid_csv_helper(file_name: &str) -> String {
        let file = File::open(file_name).unwrap();
        let schema = Schema::new(vec![
            Field::new("c_int", DataType::UInt64, false),
            Field::new("c_float", DataType::Float32, false),
            Field::new("c_string", DataType::Utf8, false),
            Field::new("c_bool", DataType::Boolean, false),
        ]);

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(true)
            .with_delimiter(b'|')
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3]);

        let mut csv = builder.build(file).unwrap();

        csv.next().unwrap().unwrap_err().to_string()
    }

    #[test]
    fn test_parse_invalid_csv_float() {
        let file_name = "test/data/various_invalid_types/invalid_float.csv";

        let error = invalid_csv_helper(file_name);
        assert_eq!("Parser error: Error while parsing value '4.x4' as type 'Float32' for column 1 at line 4. Row data: '[4,4.x4,,false]'", error);
    }

    #[test]
    fn test_parse_invalid_csv_int() {
        let file_name = "test/data/various_invalid_types/invalid_int.csv";

        let error = invalid_csv_helper(file_name);
        assert_eq!("Parser error: Error while parsing value '2.3' as type 'UInt64' for column 0 at line 2. Row data: '[2.3,2.2,2.22,false]'", error);
    }

    #[test]
    fn test_parse_invalid_csv_bool() {
        let file_name = "test/data/various_invalid_types/invalid_bool.csv";

        let error = invalid_csv_helper(file_name);
        assert_eq!("Parser error: Error while parsing value 'none' as type 'Boolean' for column 3 at line 2. Row data: '[2,2.2,2.22,none]'", error);
    }

    /// Infer the data type of a single string
    fn infer_field_schema(string: &str) -> DataType {
        let mut v = InferredDataType::default();
        v.update(string);
        v.get()
    }

    #[test]
    fn test_infer_field_schema() {
        assert_eq!(infer_field_schema("A"), DataType::Utf8);
        assert_eq!(infer_field_schema("\"123\""), DataType::Utf8);
        assert_eq!(infer_field_schema("10"), DataType::Int64);
        assert_eq!(infer_field_schema("10.2"), DataType::Float64);
        assert_eq!(infer_field_schema(".2"), DataType::Float64);
        assert_eq!(infer_field_schema("2."), DataType::Float64);
        assert_eq!(infer_field_schema("NaN"), DataType::Float64);
        assert_eq!(infer_field_schema("nan"), DataType::Float64);
        assert_eq!(infer_field_schema("inf"), DataType::Float64);
        assert_eq!(infer_field_schema("-inf"), DataType::Float64);
        assert_eq!(infer_field_schema("true"), DataType::Boolean);
        assert_eq!(infer_field_schema("trUe"), DataType::Boolean);
        assert_eq!(infer_field_schema("false"), DataType::Boolean);
        assert_eq!(infer_field_schema("2020-11-08"), DataType::Date32);
        assert_eq!(
            infer_field_schema("2020-11-08T14:20:01"),
            DataType::Timestamp(TimeUnit::Second, None)
        );
        assert_eq!(
            infer_field_schema("2020-11-08 14:20:01"),
            DataType::Timestamp(TimeUnit::Second, None)
        );
        assert_eq!(
            infer_field_schema("2020-11-08 14:20:01"),
            DataType::Timestamp(TimeUnit::Second, None)
        );
        assert_eq!(infer_field_schema("-5.13"), DataType::Float64);
        assert_eq!(infer_field_schema("0.1300"), DataType::Float64);
        assert_eq!(
            infer_field_schema("2021-12-19 13:12:30.921"),
            DataType::Timestamp(TimeUnit::Millisecond, None)
        );
        assert_eq!(
            infer_field_schema("2021-12-19T13:12:30.123456789"),
            DataType::Timestamp(TimeUnit::Nanosecond, None)
        );
        // integer strings that overflow i64 fall back to Utf8
        assert_eq!(infer_field_schema("-9223372036854775809"), DataType::Utf8);
        assert_eq!(infer_field_schema("9223372036854775808"), DataType::Utf8);
    }

    #[test]
    fn parse_date32() {
        assert_eq!(Date32Type::parse("1970-01-01").unwrap(), 0);
        assert_eq!(Date32Type::parse("2020-03-15").unwrap(), 18336);
        assert_eq!(Date32Type::parse("1945-05-08").unwrap(), -9004);
    }

    #[test]
    fn parse_time() {
        assert_eq!(
            Time64NanosecondType::parse("12:10:01.123456789 AM"),
            Some(601_123_456_789)
        );
        assert_eq!(
            Time64MicrosecondType::parse("12:10:01.123456 am"),
            Some(601_123_456)
        );
        assert_eq!(
            Time32MillisecondType::parse("2:10:01.12 PM"),
            Some(51_001_120)
        );
        assert_eq!(Time32SecondType::parse("2:10:01 pm"), Some(51_001));
    }

    #[test]
    fn parse_date64() {
        assert_eq!(Date64Type::parse("1970-01-01T00:00:00").unwrap(), 0);
        assert_eq!(
            Date64Type::parse("2018-11-13T17:11:10").unwrap(),
            1542129070000
        );
        assert_eq!(
            Date64Type::parse("2018-11-13T17:11:10.011").unwrap(),
            1542129070011
        );
        assert_eq!(
            Date64Type::parse("1900-02-28T12:34:56").unwrap(),
            -2203932304000
        );
        assert_eq!(
            Date64Type::parse_formatted("1900-02-28 12:34:56", "%Y-%m-%d %H:%M:%S").unwrap(),
            -2203932304000
        );
        assert_eq!(
            Date64Type::parse_formatted("1900-02-28 12:34:56+0030", "%Y-%m-%d %H:%M:%S%z")
                .unwrap(),
            -2203932304000 - (30 * 60 * 1000)
        );
    }

    fn test_parse_timestamp_impl<T: ArrowTimestampType>(
        timezone: Option<Arc<str>>,
        expected: &[i64],
    ) {
        let csv = [
            "1970-01-01T00:00:00",
            "1970-01-01T00:00:00Z",
            "1970-01-01T00:00:00+02:00",
        ]
        .join("\n");
        let schema = Arc::new(Schema::new(vec![Field::new(
            "field",
            DataType::Timestamp(T::UNIT, timezone.clone()),
            true,
        )]));

        let mut decoder = ReaderBuilder::new(schema).build_decoder();

        let decoded = decoder.decode(csv.as_bytes()).unwrap();
        assert_eq!(decoded, csv.len());
        decoder.decode(&[]).unwrap();

        let batch = decoder.flush().unwrap().unwrap();
        assert_eq!(batch.num_columns(), 1);
        assert_eq!(batch.num_rows(), 3);
        let col = batch.column(0).as_primitive::<T>();
        assert_eq!(col.values(), expected);
        assert_eq!(col.data_type(), &DataType::Timestamp(T::UNIT, timezone));
    }

    #[test]
    fn test_parse_timestamp() {
        test_parse_timestamp_impl::<TimestampNanosecondType>(None, &[0, 0, -7_200_000_000_000]);
        test_parse_timestamp_impl::<TimestampNanosecondType>(
            Some("+00:00".into()),
            &[0, 0, -7_200_000_000_000],
        );
        test_parse_timestamp_impl::<TimestampNanosecondType>(
            Some("-05:00".into()),
            &[18_000_000_000_000, 0, -7_200_000_000_000],
        );
        test_parse_timestamp_impl::<TimestampMicrosecondType>(
            Some("-03".into()),
            &[10_800_000_000, 0, -7_200_000_000],
        );
        test_parse_timestamp_impl::<TimestampMillisecondType>(
            Some("-03".into()),
            &[10_800_000, 0, -7_200_000],
        );
        test_parse_timestamp_impl::<TimestampSecondType>(Some("-03".into()), &[10_800, 0, -7_200]);
    }

    #[test]
    fn test_infer_schema_from_multiple_files() {
        let mut csv1 = NamedTempFile::new().unwrap();
        let mut csv2 = NamedTempFile::new().unwrap();
        // csv3 stays empty; files with no records are skipped during inference
        let csv3 = NamedTempFile::new().unwrap();
        let mut csv4 = NamedTempFile::new().unwrap();
        writeln!(csv1, "c1,c2,c3").unwrap();
        writeln!(csv1, "1,\"foo\",0.5").unwrap();
        writeln!(csv1, "3,\"bar\",1").unwrap();
        writeln!(csv1, "3,\"bar\",2e-06").unwrap();
        writeln!(csv2, "c1,c2,c3,c4").unwrap();
        writeln!(csv2, "10,,3.14,true").unwrap();
        writeln!(csv4, "c1,c2,c3").unwrap();
        writeln!(csv4, "10,\"foo\",").unwrap();

        let schema = infer_schema_from_files(
            &[
                csv3.path().to_str().unwrap().to_string(),
                csv1.path().to_str().unwrap().to_string(),
                csv2.path().to_str().unwrap().to_string(),
                csv4.path().to_str().unwrap().to_string(),
            ],
            b',',
            Some(4), // only csv1 and csv2 should be read
            true,
        )
        .unwrap();

        assert_eq!(schema.fields().len(), 4);
        assert!(schema.field(0).is_nullable());
        assert!(schema.field(1).is_nullable());
        assert!(schema.field(2).is_nullable());
        assert!(schema.field(3).is_nullable());

        assert_eq!(&DataType::Int64, schema.field(0).data_type());
        assert_eq!(&DataType::Utf8, schema.field(1).data_type());
        assert_eq!(&DataType::Float64, schema.field(2).data_type());
        assert_eq!(&DataType::Boolean, schema.field(3).data_type());
    }

    #[test]
    fn test_bounded() {
        let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]);
        let data = [
            vec!["0"],
            vec!["1"],
            vec!["2"],
            vec!["3"],
            vec!["4"],
            vec!["5"],
            vec!["6"],
        ];

        let data = data
            .iter()
            .map(|x| x.join(","))
            .collect::<Vec<_>>()
            .join("\n");
        let data = data.as_bytes();

        let reader = std::io::Cursor::new(data);

        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_batch_size(2)
            .with_projection(vec![0])
            .with_bounds(2, 6)
            .build_buffered(reader)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();
        let a = batch.column(0);
        let a = a.as_any().downcast_ref::<UInt32Array>().unwrap();
        assert_eq!(a, &UInt32Array::from(vec![2, 3]));

        let batch = csv.next().unwrap().unwrap();
        let a = batch.column(0);
        let a = a.as_any().downcast_ref::<UInt32Array>().unwrap();
        assert_eq!(a, &UInt32Array::from(vec![4, 5]));

        assert!(csv.next().is_none());
    }

    #[test]
    fn test_empty_projection() {
        let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]);
        let data = [vec!["0"], vec!["1"]];

        let data = data
            .iter()
            .map(|x| x.join(","))
            .collect::<Vec<_>>()
            .join("\n");

        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_batch_size(2)
            .with_projection(vec![])
            .build_buffered(Cursor::new(data.as_bytes()))
            .unwrap();

        let batch = csv.next().unwrap().unwrap();
        assert_eq!(batch.columns().len(), 0);
        assert_eq!(batch.num_rows(), 2);

        assert!(csv.next().is_none());
    }

    #[test]
    fn test_parsing_bool() {
        // Encode the expected behavior of boolean parsing
        assert_eq!(Some(true), parse_bool("true"));
        assert_eq!(Some(true), parse_bool("tRUe"));
        assert_eq!(Some(true), parse_bool("True"));
        assert_eq!(Some(true), parse_bool("TRUE"));
        assert_eq!(None, parse_bool("t"));
        assert_eq!(None, parse_bool("T"));
        assert_eq!(None, parse_bool(""));

        assert_eq!(Some(false), parse_bool("false"));
        assert_eq!(Some(false), parse_bool("fALse"));
        assert_eq!(Some(false), parse_bool("False"));
        assert_eq!(Some(false), parse_bool("FALSE"));
        assert_eq!(None, parse_bool("f"));
        assert_eq!(None, parse_bool("F"));
        assert_eq!(None, parse_bool(""));
    }

    #[test]
    fn test_parsing_float() {
        assert_eq!(Some(12.34), Float64Type::parse("12.34"));
        assert_eq!(Some(-12.34), Float64Type::parse("-12.34"));
        assert_eq!(Some(12.0), Float64Type::parse("12"));
        assert_eq!(Some(0.0), Float64Type::parse("0"));
        assert_eq!(Some(2.0), Float64Type::parse("2."));
        assert_eq!(Some(0.2), Float64Type::parse(".2"));
        assert!(Float64Type::parse("nan").unwrap().is_nan());
        assert!(Float64Type::parse("NaN").unwrap().is_nan());
        assert!(Float64Type::parse("inf").unwrap().is_infinite());
        assert!(Float64Type::parse("inf").unwrap().is_sign_positive());
        assert!(Float64Type::parse("-inf").unwrap().is_infinite());
        assert!(Float64Type::parse("-inf").unwrap().is_sign_negative());
        assert_eq!(None, Float64Type::parse(""));
        assert_eq!(None, Float64Type::parse("dd"));
        assert_eq!(None, Float64Type::parse("12.34.56"));
    }

    #[test]
    fn test_non_std_quote() {
        let schema = Schema::new(vec![
            Field::new("text1", DataType::Utf8, false),
            Field::new("text2", DataType::Utf8, false),
        ]);
        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(false)
            .with_quote(b'~'); // default is ", change to ~

        let mut csv_text = Vec::new();
        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
        for index in 0..10 {
            let text1 = format!("id{index:}");
            let text2 = format!("value{index:}");
            csv_writer
                .write_fmt(format_args!("~{text1}~,~{text2}~\r\n"))
                .unwrap();
        }
        let mut csv_reader = std::io::Cursor::new(&csv_text);
        let mut reader = builder.build(&mut csv_reader).unwrap();
        let batch = reader.next().unwrap().unwrap();
        let col0 = batch.column(0);
        assert_eq!(col0.len(), 10);
        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col0_arr.value(0), "id0");
        let col1 = batch.column(1);
        assert_eq!(col1.len(), 10);
        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col1_arr.value(5), "value5");
    }

    #[test]
    fn test_non_std_escape() {
        let schema = Schema::new(vec![
            Field::new("text1", DataType::Utf8, false),
            Field::new("text2", DataType::Utf8, false),
        ]);
        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(false)
            .with_escape(b'\\'); // default is None, set to \

        let mut csv_text = Vec::new();
        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
        for index in 0..10 {
            let text1 = format!("id{index:}");
            let text2 = format!("value\\\"{index:}");
            csv_writer
                .write_fmt(format_args!("\"{text1}\",\"{text2}\"\r\n"))
                .unwrap();
        }
        let mut csv_reader = std::io::Cursor::new(&csv_text);
        let mut reader = builder.build(&mut csv_reader).unwrap();
        let batch = reader.next().unwrap().unwrap();
        let col0 = batch.column(0);
        assert_eq!(col0.len(), 10);
        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col0_arr.value(0), "id0");
        let col1 = batch.column(1);
        assert_eq!(col1.len(), 10);
        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col1_arr.value(5), "value\"5");
    }
2171
2172 #[test]
2173 fn test_non_std_terminator() {
2174 let schema = Schema::new(vec![
2175 Field::new("text1", DataType::Utf8, false),
2176 Field::new("text2", DataType::Utf8, false),
2177 ]);
2178 let builder = ReaderBuilder::new(Arc::new(schema))
2179 .with_header(false)
2180 .with_terminator(b'\n'); let mut csv_text = Vec::new();
2183 let mut csv_writer = std::io::Cursor::new(&mut csv_text);
2184 for index in 0..10 {
2185 let text1 = format!("id{index:}");
2186 let text2 = format!("value{index:}");
2187 csv_writer
2188 .write_fmt(format_args!("\"{text1}\",\"{text2}\"\n"))
2189 .unwrap();
2190 }
2191 let mut csv_reader = std::io::Cursor::new(&csv_text);
2192 let mut reader = builder.build(&mut csv_reader).unwrap();
2193 let batch = reader.next().unwrap().unwrap();
2194 let col0 = batch.column(0);
2195 assert_eq!(col0.len(), 10);
2196 let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
2197 assert_eq!(col0_arr.value(0), "id0");
2198 let col1 = batch.column(1);
2199 assert_eq!(col1.len(), 10);
2200 let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
2201 assert_eq!(col1_arr.value(5), "value5");
2202 }
2203
    #[test]
    fn test_header_bounds() {
        let csv = "a,b\na,b\na,b\na,b\na,b\n";
        let tests = [
            (None, false, 5),
            (None, true, 4),
            (Some((0, 4)), false, 4),
            (Some((1, 4)), false, 3),
            (Some((0, 4)), true, 4),
            (Some((1, 4)), true, 3),
        ];
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("a", DataType::Utf8, false),
        ]));

        for (idx, (bounds, has_header, expected)) in tests.into_iter().enumerate() {
            let mut reader = ReaderBuilder::new(schema.clone()).with_header(has_header);
            if let Some((start, end)) = bounds {
                reader = reader.with_bounds(start, end);
            }
            let b = reader
                .build_buffered(Cursor::new(csv.as_bytes()))
                .unwrap()
                .next()
                .unwrap()
                .unwrap();
            assert_eq!(b.num_rows(), expected, "{idx}");
        }
    }

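    // With no null regex configured, only empty fields decode as null, and
    // boolean parsing is case-insensitive ("False"/"True" both parse).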
    #[test]
    fn test_null_boolean() {
        let csv = "true,false\nFalse,True\n,True\nFalse,";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Boolean, true),
            Field::new("a", DataType::Boolean, true),
        ]));

        let b = ReaderBuilder::new(schema)
            .build_buffered(Cursor::new(csv.as_bytes()))
            .unwrap()
            .next()
            .unwrap()
            .unwrap();

        assert_eq!(b.num_rows(), 4);
        assert_eq!(b.num_columns(), 2);

        let c = b.column(0).as_boolean();
        assert_eq!(c.null_count(), 1);
        assert!(c.value(0));
        assert!(!c.value(1));
        assert!(c.is_null(2));
        assert!(!c.value(3));

        let c = b.column(1).as_boolean();
        assert_eq!(c.null_count(), 1);
        assert!(!c.value(0));
        assert!(c.value(1));
        assert!(c.value(2));
        assert!(c.is_null(3));
    }

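    // With truncated rows allowed, the short row `4,5` is padded with a null
    // for the missing column and the blank line is skipped; with them
    // disallowed, the same input is an error.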
    #[test]
    fn test_truncated_rows() {
        let data = "a,b,c\n1,2,3\n4,5\n\n6,7,8";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
            Field::new("c", DataType::Int32, true),
        ]));

        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(true)
            .build(Cursor::new(data))
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(batches.is_ok());
        let batch = batches.unwrap().into_iter().next().unwrap();
        assert_eq!(batch.num_rows(), 3);

        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(false)
            .build(Cursor::new(data))
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(match batches {
            Err(ArrowError::CsvError(e)) => e.to_string().contains("incorrect number of fields"),
            _ => false,
        });
    }

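    // Exercises truncated-row handling against a file fixture. The Date32
    // expectations (5675, -1858) are day counts relative to the UNIX epoch
    // (1970-01-01).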
    #[test]
    fn test_truncated_rows_csv() {
        let file = File::open("test/data/truncated_rows.csv").unwrap();
        let schema = Arc::new(Schema::new(vec![
            Field::new("Name", DataType::Utf8, true),
            Field::new("Age", DataType::UInt32, true),
            Field::new("Occupation", DataType::Utf8, true),
            Field::new("DOB", DataType::Date32, true),
        ]));
        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_batch_size(24)
            .with_truncated_rows(true);
        let csv = reader.build(file).unwrap();
        let batches = csv.collect::<Result<Vec<_>, _>>().unwrap();

        assert_eq!(batches.len(), 1);
        let batch = &batches[0];
        assert_eq!(batch.num_rows(), 6);
        assert_eq!(batch.num_columns(), 4);
        let name = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let age = batch
            .column(1)
            .as_any()
            .downcast_ref::<UInt32Array>()
            .unwrap();
        let occupation = batch
            .column(2)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let dob = batch
            .column(3)
            .as_any()
            .downcast_ref::<Date32Array>()
            .unwrap();

        assert_eq!(name.value(0), "A1");
        assert_eq!(name.value(1), "B2");
        assert!(name.is_null(2));
        assert_eq!(name.value(3), "C3");
        assert_eq!(name.value(4), "D4");
        assert_eq!(name.value(5), "E5");

        assert_eq!(age.value(0), 34);
        assert_eq!(age.value(1), 29);
        assert!(age.is_null(2));
        assert_eq!(age.value(3), 45);
        assert!(age.is_null(4));
        assert_eq!(age.value(5), 31);

        assert_eq!(occupation.value(0), "Engineer");
        assert_eq!(occupation.value(1), "Doctor");
        assert!(occupation.is_null(2));
        assert_eq!(occupation.value(3), "Artist");
        assert!(occupation.is_null(4));
        assert!(occupation.is_null(5));

        assert_eq!(dob.value(0), 5675);
        assert!(dob.is_null(1));
        assert!(dob.is_null(2));
        assert_eq!(dob.value(3), -1858);
        assert!(dob.is_null(4));
        assert!(dob.is_null(5));
    }

    #[test]
    fn test_truncated_rows_not_nullable_error() {
        let data = "a,b,c\n1,2,3\n4,5";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));

        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(true)
            .build(Cursor::new(data))
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(match batches {
            Err(ArrowError::InvalidArgumentError(e)) =>
                e.to_string().contains("contains null values"),
            _ => false,
        });
    }

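    // The buffered reader must produce batches identical to the unbuffered
    // one, regardless of batch size or how small the underlying BufReader's
    // capacity is.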
    #[test]
    fn test_buffered() {
        let tests = [
            ("test/data/uk_cities.csv", false, 37),
            ("test/data/various_types.csv", true, 10),
            ("test/data/decimal_test.csv", false, 10),
        ];

        for (path, has_header, expected_rows) in tests {
            let (schema, _) = Format::default()
                .infer_schema(File::open(path).unwrap(), None)
                .unwrap();
            let schema = Arc::new(schema);

            for batch_size in [1, 4] {
                for capacity in [1, 3, 7, 100] {
                    let reader = ReaderBuilder::new(schema.clone())
                        .with_batch_size(batch_size)
                        .with_header(has_header)
                        .build(File::open(path).unwrap())
                        .unwrap();

                    let expected = reader.collect::<Result<Vec<_>, _>>().unwrap();

                    assert_eq!(
                        expected.iter().map(|x| x.num_rows()).sum::<usize>(),
                        expected_rows
                    );

                    let buffered =
                        std::io::BufReader::with_capacity(capacity, File::open(path).unwrap());

                    let reader = ReaderBuilder::new(schema.clone())
                        .with_batch_size(batch_size)
                        .with_header(has_header)
                        .build_buffered(buffered)
                        .unwrap();

                    let actual = reader.collect::<Result<Vec<_>, _>>().unwrap();
                    assert_eq!(expected, actual)
                }
            }
        }
    }

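    // Runs the same expected error against both Utf8 and Utf8View schemas,
    // with a tiny buffer and batch size to force incremental decoding.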
    fn err_test(csv: &[u8], expected: &str) {
        fn err_test_with_schema(csv: &[u8], expected: &str, schema: Arc<Schema>) {
            let buffer = std::io::BufReader::with_capacity(2, Cursor::new(csv));
            let b = ReaderBuilder::new(schema)
                .with_batch_size(2)
                .build_buffered(buffer)
                .unwrap();
            let err = b.collect::<Result<Vec<_>, _>>().unwrap_err().to_string();
            assert_eq!(err, expected)
        }

        let schema_utf8 = Arc::new(Schema::new(vec![
            Field::new("text1", DataType::Utf8, true),
            Field::new("text2", DataType::Utf8, true),
        ]));
        err_test_with_schema(csv, expected, schema_utf8);

        let schema_utf8view = Arc::new(Schema::new(vec![
            Field::new("text1", DataType::Utf8View, true),
            Field::new("text2", DataType::Utf8View, true),
        ]));
        err_test_with_schema(csv, expected, schema_utf8view);
    }

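    // Error positions are reported with 1-based line and field numbers.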
    #[test]
    fn test_invalid_utf8() {
        err_test(
            b"sdf,dsfg\ndfd,hgh\xFFue\n,sds\nFalhghse,",
            "Csv error: Encountered invalid UTF-8 data for line 2 and field 2",
        );

        err_test(
            b"sdf,dsfg\ndksdk,jf\nd\xFFfd,hghue\n,sds\nFalhghse,",
            "Csv error: Encountered invalid UTF-8 data for line 3 and field 1",
        );

        err_test(
            b"sdf,dsfg\ndksdk,jf\ndsdsfd,hghue\n,sds\nFalhghse,\xFF",
            "Csv error: Encountered invalid UTF-8 data for line 5 and field 2",
        );

        err_test(
            b"\xFFsdf,dsfg\ndksdk,jf\ndsdsfd,hghue\n,sds\nFalhghse,\xFF",
            "Csv error: Encountered invalid UTF-8 data for line 1 and field 1",
        );
    }

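    /// Wraps a reader and records every call to `fill_buf`, so tests can
    /// assert on the exact IO pattern of the buffered reader.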
    struct InstrumentedRead<R> {
        r: R,
        fill_count: usize,
        fill_sizes: Vec<usize>,
    }

    impl<R> InstrumentedRead<R> {
        fn new(r: R) -> Self {
            Self {
                r,
                fill_count: 0,
                fill_sizes: vec![],
            }
        }
    }

    impl<R: Seek> Seek for InstrumentedRead<R> {
        fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
            self.r.seek(pos)
        }
    }

    impl<R: BufRead> Read for InstrumentedRead<R> {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            self.r.read(buf)
        }
    }

    impl<R: BufRead> BufRead for InstrumentedRead<R> {
        fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
            self.fill_count += 1;
            let buf = self.r.fill_buf()?;
            self.fill_sizes.push(buf.len());
            Ok(buf)
        }

        fn consume(&mut self, amt: usize) {
            self.r.consume(amt)
        }
    }

    #[test]
    fn test_io() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("b", DataType::Utf8, false),
        ]));
        let csv = "foo,bar\nbaz,foo\na,b\nc,d";
        let mut read = InstrumentedRead::new(Cursor::new(csv.as_bytes()));
        let reader = ReaderBuilder::new(schema)
            .with_batch_size(3)
            .build_buffered(&mut read)
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].num_rows(), 3);
        assert_eq!(batches[1].num_rows(), 1);

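        // Four fills are expected for the 23-byte input: the first returns
        // everything the cursor has, the second the 3 bytes ("c,d") left
        // after the first batch is consumed, and two empty fills then
        // presumably delimit the final unterminated record and confirm EOF.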
        assert_eq!(&read.fill_sizes, &[23, 3, 0, 0]);
        assert_eq!(read.fill_count, 4);
    }

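    // Each case feeds the strings through InferredDataType::update and checks
    // the resolved type: integers widen to Float64 when mixed with floats,
    // dates widen to timestamps, timestamps take the finest sub-second
    // precision seen, and irreconcilable mixtures fall back to Utf8.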
    #[test]
    fn test_inference() {
        let cases: &[(&[&str], DataType)] = &[
            (&[], DataType::Null),
            (&["false", "12"], DataType::Utf8),
            (&["12", "cupcakes"], DataType::Utf8),
            (&["12", "12.4"], DataType::Float64),
            (&["14050", "24332"], DataType::Int64),
            (&["14050.0", "true"], DataType::Utf8),
            (&["14050", "2020-03-19 00:00:00"], DataType::Utf8),
            (&["14050", "2340.0", "2020-03-19 00:00:00"], DataType::Utf8),
            (
                &["2020-03-19 02:00:00", "2020-03-19 00:00:00"],
                DataType::Timestamp(TimeUnit::Second, None),
            ),
            (&["2020-03-19", "2020-03-20"], DataType::Date32),
            (
                &["2020-03-19", "2020-03-19 02:00:00", "2020-03-19 00:00:00"],
                DataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                &[
                    "2020-03-19",
                    "2020-03-19 02:00:00",
                    "2020-03-19 00:00:00.000",
                ],
                DataType::Timestamp(TimeUnit::Millisecond, None),
            ),
            (
                &[
                    "2020-03-19",
                    "2020-03-19 02:00:00",
                    "2020-03-19 00:00:00.000000",
                ],
                DataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                &["2020-03-19 02:00:00+02:00", "2020-03-19 02:00:00Z"],
                DataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                &[
                    "2020-03-19",
                    "2020-03-19 02:00:00+02:00",
                    "2020-03-19 02:00:00Z",
                    "2020-03-19 02:00:00.12Z",
                ],
                DataType::Timestamp(TimeUnit::Millisecond, None),
            ),
            (
                &[
                    "2020-03-19",
                    "2020-03-19 02:00:00.000000000",
                    "2020-03-19 00:00:00.000000",
                ],
                DataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
        ];

        for (values, expected) in cases {
            let mut t = InferredDataType::default();
            for v in *values {
                t.update(v)
            }
            assert_eq!(&t.get(), expected, "{values:?}")
        }
    }

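    // Schema inference rejects ragged input outright rather than guessing a
    // column count.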
    #[test]
    fn test_record_length_mismatch() {
        let csv = "\
        a,b,c\n\
        1,2,3\n\
        4,5\n\
        6,7,8";
        let mut read = Cursor::new(csv.as_bytes());
        let result = Format::default()
            .with_header(true)
            .infer_schema(&mut read, None);
        assert!(result.is_err());
        assert_eq!(result.err().unwrap().to_string(), "Csv error: Encountered unequal lengths between records on CSV file. Expected 2 records, found 3 records at line 3");
    }

    #[test]
    fn test_comment() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int8, false),
            Field::new("b", DataType::Int8, false),
        ]);

        let csv = "# comment1 \n1,2\n#comment2\n11,22";
        let mut read = Cursor::new(csv.as_bytes());
        let reader = ReaderBuilder::new(Arc::new(schema))
            .with_comment(b'#')
            .build(&mut read)
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(batches.len(), 1);
        let b = batches.first().unwrap();
        assert_eq!(b.num_columns(), 2);
        assert_eq!(
            b.column(0)
                .as_any()
                .downcast_ref::<Int8Array>()
                .unwrap()
                .values(),
            &vec![1, 11]
        );
        assert_eq!(
            b.column(1)
                .as_any()
                .downcast_ref::<Int8Array>()
                .unwrap()
                .values(),
            &vec![2, 22]
        );
    }

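    // Utf8View stores strings longer than 12 bytes out-of-line in a data
    // buffer; "something_cannot_be_inlined" exercises that path, while the
    // short values stay inlined in the views themselves.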
    #[test]
    fn test_parse_string_view_single_column() {
        let csv = ["foo", "something_cannot_be_inlined", "foobar"].join("\n");
        let schema = Arc::new(Schema::new(vec![Field::new(
            "c1",
            DataType::Utf8View,
            true,
        )]));

        let mut decoder = ReaderBuilder::new(schema).build_decoder();

        let decoded = decoder.decode(csv.as_bytes()).unwrap();
        assert_eq!(decoded, csv.len());
        decoder.decode(&[]).unwrap();

        let batch = decoder.flush().unwrap().unwrap();
        assert_eq!(batch.num_columns(), 1);
        assert_eq!(batch.num_rows(), 3);
        let col = batch.column(0).as_string_view();
        assert_eq!(col.data_type(), &DataType::Utf8View);
        assert_eq!(col.value(0), "foo");
        assert_eq!(col.value(1), "something_cannot_be_inlined");
        assert_eq!(col.value(2), "foobar");
    }

    #[test]
    fn test_parse_string_view_multi_column() {
        let csv = ["foo,", ",something_cannot_be_inlined", "foobarfoobar,bar"].join("\n");
        let schema = Arc::new(Schema::new(vec![
            Field::new("c1", DataType::Utf8View, true),
            Field::new("c2", DataType::Utf8View, true),
        ]));

        let mut decoder = ReaderBuilder::new(schema).build_decoder();

        let decoded = decoder.decode(csv.as_bytes()).unwrap();
        assert_eq!(decoded, csv.len());
        decoder.decode(&[]).unwrap();

        let batch = decoder.flush().unwrap().unwrap();
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.num_rows(), 3);
        let c1 = batch.column(0).as_string_view();
        let c2 = batch.column(1).as_string_view();
        assert_eq!(c1.data_type(), &DataType::Utf8View);
        assert_eq!(c2.data_type(), &DataType::Utf8View);

        assert!(!c1.is_null(0));
        assert!(c1.is_null(1));
        assert!(!c1.is_null(2));
        assert_eq!(c1.value(0), "foo");
        assert_eq!(c1.value(2), "foobarfoobar");

        assert!(c2.is_null(0));
        assert!(!c2.is_null(1));
        assert!(!c2.is_null(2));
        assert_eq!(c2.value(1), "something_cannot_be_inlined");
        assert_eq!(c2.value(2), "bar");
    }
}