arrow_csv/reader/mod.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! CSV Reader
//!
//! # Basic Usage
//!
//! This CSV reader allows CSV files to be read into the Arrow memory model. Records are
//! loaded in batches and are then converted from row-based data to columnar data.
//!
//! Example:
//!
//! ```
//! # use arrow_schema::*;
//! # use arrow_csv::{Reader, ReaderBuilder};
//! # use std::fs::File;
//! # use std::sync::Arc;
//!
//! let schema = Schema::new(vec![
//!     Field::new("city", DataType::Utf8, false),
//!     Field::new("lat", DataType::Float64, false),
//!     Field::new("lng", DataType::Float64, false),
//! ]);
//!
//! let file = File::open("test/data/uk_cities.csv").unwrap();
//!
//! let mut csv = ReaderBuilder::new(Arc::new(schema)).build(file).unwrap();
//! let batch = csv.next().unwrap().unwrap();
//! ```
//!
//! # Async Usage
//!
//! The lower-level [`Decoder`] can be integrated with various forms of async data streams,
//! and is designed to be agnostic to the various different kinds of async IO primitives found
//! within the Rust ecosystem.
//!
//! For example, see below for how it can be used with an arbitrary `Stream` of `Bytes`
//!
//! ```
//! # use std::task::{Poll, ready};
//! # use bytes::{Buf, Bytes};
//! # use arrow_schema::ArrowError;
//! # use futures::stream::{Stream, StreamExt};
//! # use arrow_array::RecordBatch;
//! # use arrow_csv::reader::Decoder;
//! #
//! fn decode_stream<S: Stream<Item = Bytes> + Unpin>(
//!     mut decoder: Decoder,
//!     mut input: S,
//! ) -> impl Stream<Item = Result<RecordBatch, ArrowError>> {
//!     let mut buffered = Bytes::new();
//!     futures::stream::poll_fn(move |cx| {
//!         loop {
//!             if buffered.is_empty() {
//!                 if let Some(b) = ready!(input.poll_next_unpin(cx)) {
//!                     buffered = b;
//!                 }
//!                 // Note: don't break on `None` as the decoder needs
//!                 // to be called with an empty array to delimit the
//!                 // final record
//!             }
//!             let decoded = match decoder.decode(buffered.as_ref()) {
//!                 Ok(0) => break,
//!                 Ok(decoded) => decoded,
//!                 Err(e) => return Poll::Ready(Some(Err(e))),
//!             };
//!             buffered.advance(decoded);
//!         }
//!
//!         Poll::Ready(decoder.flush().transpose())
//!     })
//! }
//!
//! ```
//!
//! In a similar vein, it can also be used with tokio-based IO primitives
//!
//! ```
//! # use std::pin::Pin;
//! # use std::task::{Poll, ready};
//! # use futures::Stream;
//! # use tokio::io::AsyncBufRead;
//! # use arrow_array::RecordBatch;
//! # use arrow_csv::reader::Decoder;
//! # use arrow_schema::ArrowError;
//! fn decode_stream<R: AsyncBufRead + Unpin>(
//!     mut decoder: Decoder,
//!     mut reader: R,
//! ) -> impl Stream<Item = Result<RecordBatch, ArrowError>> {
//!     futures::stream::poll_fn(move |cx| {
//!         loop {
//!             let b = match ready!(Pin::new(&mut reader).poll_fill_buf(cx)) {
//!                 Ok(b) => b,
//!                 Err(e) => return Poll::Ready(Some(Err(e.into()))),
//!             };
//!             let decoded = match decoder.decode(b) {
//!                 // Note: the decoder needs to be called with an empty
//!                 // array to delimit the final record
//!                 Ok(0) => break,
//!                 Ok(decoded) => decoded,
//!                 Err(e) => return Poll::Ready(Some(Err(e))),
//!             };
//!             Pin::new(&mut reader).consume(decoded);
//!         }
//!
//!         Poll::Ready(decoder.flush().transpose())
//!     })
//! }
//! ```
//!

mod records;

use arrow_array::builder::{NullBuilder, PrimitiveBuilder};
use arrow_array::types::*;
use arrow_array::*;
use arrow_cast::parse::{parse_decimal, string_to_datetime, Parser};
use arrow_schema::*;
use chrono::{TimeZone, Utc};
use csv::StringRecord;
use lazy_static::lazy_static;
use regex::{Regex, RegexSet};
use std::fmt::{self, Debug};
use std::fs::File;
use std::io::{BufRead, BufReader as StdBufReader, Read};
use std::sync::Arc;

use crate::map_csv_error;
use crate::reader::records::{RecordDecoder, StringRecords};
use arrow_array::timezone::Tz;

lazy_static! {
    /// Order should match [`InferredDataType`]
    static ref REGEX_SET: RegexSet = RegexSet::new([
        r"(?i)^(true)$|^(false)$(?-i)", //BOOLEAN
        r"^-?(\d+)$", //INTEGER
        r"^-?((\d*\.\d+|\d+\.\d*)([eE][-+]?\d+)?|\d+([eE][-+]?\d+))$", //DECIMAL
        r"^\d{4}-\d\d-\d\d$", //DATE32
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d(?:[^\d\.].*)?$", //Timestamp(Second)
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,3}(?:[^\d].*)?$", //Timestamp(Millisecond)
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,6}(?:[^\d].*)?$", //Timestamp(Microsecond)
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,9}(?:[^\d].*)?$", //Timestamp(Nanosecond)
    ]).unwrap();
}

/// A wrapper over `Option<Regex>` to check if the value is `NULL`.
#[derive(Debug, Clone, Default)]
struct NullRegex(Option<Regex>);

impl NullRegex {
    /// Returns true if the value should be considered as `NULL` according to
    /// the provided regular expression.
    #[inline]
    fn is_null(&self, s: &str) -> bool {
        match &self.0 {
            Some(r) => r.is_match(s),
            None => s.is_empty(),
        }
    }
}

#[derive(Default, Copy, Clone)]
struct InferredDataType {
    /// Packed booleans indicating type
    ///
    /// 0 - Boolean
    /// 1 - Integer
    /// 2 - Float64
    /// 3 - Date32
    /// 4 - Timestamp(Second)
    /// 5 - Timestamp(Millisecond)
    /// 6 - Timestamp(Microsecond)
    /// 7 - Timestamp(Nanosecond)
    /// 8 - Utf8
    packed: u16,
}

impl InferredDataType {
    /// Returns the inferred data type
    fn get(&self) -> DataType {
        match self.packed {
            0 => DataType::Null,
            1 => DataType::Boolean,
            2 => DataType::Int64,
            4 | 6 => DataType::Float64, // Promote Int64 to Float64
            b if b != 0 && (b & !0b11111000) == 0 => match b.leading_zeros() {
                // Promote to highest precision temporal type
                8 => DataType::Timestamp(TimeUnit::Nanosecond, None),
                9 => DataType::Timestamp(TimeUnit::Microsecond, None),
                10 => DataType::Timestamp(TimeUnit::Millisecond, None),
                11 => DataType::Timestamp(TimeUnit::Second, None),
                12 => DataType::Date32,
                _ => unreachable!(),
            },
            _ => DataType::Utf8,
        }
    }

    /// Updates the [`InferredDataType`] with the given string
    fn update(&mut self, string: &str) {
        self.packed |= if string.starts_with('"') {
            1 << 8 // Utf8
        } else if let Some(m) = REGEX_SET.matches(string).into_iter().next() {
            if m == 1 && string.len() >= 19 && string.parse::<i64>().is_err() {
                // if overflow i64, fallback to utf8
                1 << 8
            } else {
                1 << m
            }
        } else if string == "NaN" || string == "nan" || string == "inf" || string == "-inf" {
            1 << 2 // Float64
        } else {
            1 << 8 // Utf8
        }
    }
}

/// The format specification for the CSV file
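///
/// # Example
///
/// A minimal sketch of configuring a format via the builder-style setters
/// below (the header flag and delimiter are illustrative values):
///
/// ```
/// # use arrow_csv::reader::Format;
/// // Semicolon-delimited CSV with a header row
/// let format = Format::default().with_header(true).with_delimiter(b';');
/// ```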
#[derive(Debug, Clone, Default)]
pub struct Format {
    header: bool,
    delimiter: Option<u8>,
    escape: Option<u8>,
    quote: Option<u8>,
    terminator: Option<u8>,
    comment: Option<u8>,
    null_regex: NullRegex,
    truncated_rows: bool,
}

impl Format {
    /// Specify whether the CSV file has a header, defaults to `false`
    ///
    /// When `true`, the first row of the CSV file is treated as a header row
    pub fn with_header(mut self, has_header: bool) -> Self {
        self.header = has_header;
        self
    }

    /// Specify a custom delimiter character, defaults to comma `','`
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.delimiter = Some(delimiter);
        self
    }

    /// Specify an escape character, defaults to `None`
    pub fn with_escape(mut self, escape: u8) -> Self {
        self.escape = Some(escape);
        self
    }

    /// Specify a custom quote character, defaults to double quote `'"'`
    pub fn with_quote(mut self, quote: u8) -> Self {
        self.quote = Some(quote);
        self
    }

    /// Specify a custom terminator character, defaults to CRLF
    pub fn with_terminator(mut self, terminator: u8) -> Self {
        self.terminator = Some(terminator);
        self
    }

    /// Specify a comment character, defaults to `None`
    ///
    /// Lines starting with this character will be ignored
    pub fn with_comment(mut self, comment: u8) -> Self {
        self.comment = Some(comment);
        self
    }

    /// Provide a regex to match null values, defaults to `^$`
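    ///
    /// # Example
    ///
    /// A minimal sketch treating the literal string `NULL` as a null value
    /// (the pattern here is illustrative):
    ///
    /// ```
    /// # use arrow_csv::reader::Format;
    /// # use regex::Regex;
    /// // Values matching ^NULL$ will be decoded as nulls
    /// let format = Format::default().with_null_regex(Regex::new("^NULL$").unwrap());
    /// ```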
    pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
        self.null_regex = NullRegex(Some(null_regex));
        self
    }

    /// Whether to allow truncated rows when parsing.
    ///
    /// By default this is set to `false` and will error if the CSV rows have different lengths.
    /// When set to `true`, records with fewer than the expected number of columns are accepted,
    /// and the missing columns are filled with nulls. If the missing fields are not nullable,
    /// an error is still returned.
    pub fn with_truncated_rows(mut self, allow: bool) -> Self {
        self.truncated_rows = allow;
        self
    }

    /// Infer schema of CSV records from the provided `reader`
    ///
    /// If `max_records` is `None`, all records will be read, otherwise up to `max_records`
    /// records are read to infer the schema
    ///
    /// Returns the inferred schema and the number of records read
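    ///
    /// # Example
    ///
    /// A minimal sketch inferring a schema from in-memory CSV data (the
    /// sample data is illustrative):
    ///
    /// ```
    /// # use arrow_csv::reader::Format;
    /// let csv = "city,lat\nLondon,51.5\n";
    /// let (schema, n_records) = Format::default()
    ///     .with_header(true)
    ///     .infer_schema(csv.as_bytes(), None)
    ///     .unwrap();
    /// // One data record follows the header; "lat" is inferred as Float64
    /// assert_eq!(n_records, 1);
    /// assert_eq!(schema.fields().len(), 2);
    /// ```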
    pub fn infer_schema<R: Read>(
        &self,
        reader: R,
        max_records: Option<usize>,
    ) -> Result<(Schema, usize), ArrowError> {
        let mut csv_reader = self.build_reader(reader);

        // Get or create the header names.
        // When `header` is false, default column names with a `column_` prefix are created.
        let headers: Vec<String> = if self.header {
            let headers = csv_reader.headers().map_err(map_csv_error)?;
            headers.iter().map(|s| s.to_string()).collect()
        } else {
            let first_record_count = csv_reader.headers().map_err(map_csv_error)?.len();
            (0..first_record_count)
                .map(|i| format!("column_{}", i + 1))
                .collect()
        };

        let header_length = headers.len();
        // keep track of inferred field types
        let mut column_types: Vec<InferredDataType> = vec![Default::default(); header_length];

        let mut records_count = 0;

        let mut record = StringRecord::new();
        let max_records = max_records.unwrap_or(usize::MAX);
        while records_count < max_records {
            if !csv_reader.read_record(&mut record).map_err(map_csv_error)? {
                break;
            }
            records_count += 1;

            // Note: since we may be looking at a sample of the data, we make the safe
            // assumption that the columns could be nullable
            for (i, column_type) in column_types.iter_mut().enumerate().take(header_length) {
                if let Some(string) = record.get(i) {
                    if !self.null_regex.is_null(string) {
                        column_type.update(string)
                    }
                }
            }
        }

        // build schema from inference results
        let fields: Fields = column_types
            .iter()
            .zip(&headers)
            .map(|(inferred, field_name)| Field::new(field_name, inferred.get(), true))
            .collect();

        Ok((Schema::new(fields), records_count))
    }

    /// Build a [`csv::Reader`] for this [`Format`]
    fn build_reader<R: Read>(&self, reader: R) -> csv::Reader<R> {
        let mut builder = csv::ReaderBuilder::new();
        builder.has_headers(self.header);
        builder.flexible(self.truncated_rows);

        if let Some(c) = self.delimiter {
            builder.delimiter(c);
        }
        builder.escape(self.escape);
        if let Some(c) = self.quote {
            builder.quote(c);
        }
        if let Some(t) = self.terminator {
            builder.terminator(csv::Terminator::Any(t));
        }
        if let Some(comment) = self.comment {
            builder.comment(Some(comment));
        }
        builder.from_reader(reader)
    }

    /// Build a [`csv_core::Reader`] for this [`Format`]
    fn build_parser(&self) -> csv_core::Reader {
        let mut builder = csv_core::ReaderBuilder::new();
        builder.escape(self.escape);
        builder.comment(self.comment);

        if let Some(c) = self.delimiter {
            builder.delimiter(c);
        }
        if let Some(c) = self.quote {
            builder.quote(c);
        }
        if let Some(t) = self.terminator {
            builder.terminator(csv_core::Terminator::Any(t));
        }
        builder.build()
    }
}

/// Infer schema from a list of CSV files by reading through the first n records
/// with `max_read_records` controlling the maximum number of records to read.
///
/// Files will be read in the given order until n records have been reached.
///
/// If `max_read_records` is not set, all files will be read fully to infer the schema.
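///
/// # Example
///
/// A minimal sketch; the file paths below are placeholders for real CSV files:
///
/// ```no_run
/// # use arrow_csv::reader::infer_schema_from_files;
/// let files = vec!["data/part-0.csv".to_string(), "data/part-1.csv".to_string()];
/// // Scan at most 100 records across the files; the files have header rows
/// let schema = infer_schema_from_files(&files, b',', Some(100), true).unwrap();
/// ```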
pub fn infer_schema_from_files(
    files: &[String],
    delimiter: u8,
    max_read_records: Option<usize>,
    has_header: bool,
) -> Result<Schema, ArrowError> {
    let mut schemas = vec![];
    let mut records_to_read = max_read_records.unwrap_or(usize::MAX);
    let format = Format {
        delimiter: Some(delimiter),
        header: has_header,
        ..Default::default()
    };

    for fname in files.iter() {
        let f = File::open(fname)?;
        let (schema, records_read) = format.infer_schema(f, Some(records_to_read))?;
        if records_read == 0 {
            continue;
        }
        schemas.push(schema.clone());
        records_to_read -= records_read;
        if records_to_read == 0 {
            break;
        }
    }

    Schema::try_merge(schemas)
}

// Optional bounds of the reader, of the form (min line, max line).
type Bounds = Option<(usize, usize)>;

/// CSV file reader using [`std::io::BufReader`]
pub type Reader<R> = BufReader<StdBufReader<R>>;

/// CSV file reader
pub struct BufReader<R> {
    /// File reader
    reader: R,

    /// The decoder
    decoder: Decoder,
}

impl<R> fmt::Debug for BufReader<R>
where
    R: BufRead,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Reader")
            .field("decoder", &self.decoder)
            .finish()
    }
}

impl<R: Read> Reader<R> {
    /// Returns the schema of the reader, useful for getting the schema without reading
    /// record batches
    pub fn schema(&self) -> SchemaRef {
        match &self.decoder.projection {
            Some(projection) => {
                let fields = self.decoder.schema.fields();
                let projected = projection.iter().map(|i| fields[*i].clone());
                Arc::new(Schema::new(projected.collect::<Fields>()))
            }
            None => self.decoder.schema.clone(),
        }
    }
}

impl<R: BufRead> BufReader<R> {
    fn read(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        loop {
            let buf = self.reader.fill_buf()?;
            let decoded = self.decoder.decode(buf)?;
            self.reader.consume(decoded);
            // Yield if decoded no bytes or the decoder is full
            //
            // The capacity check avoids looping around and potentially
            // blocking reading data in fill_buf that isn't needed
            // to flush the next batch
            if decoded == 0 || self.decoder.capacity() == 0 {
                break;
            }
        }

        self.decoder.flush()
    }
}

impl<R: BufRead> Iterator for BufReader<R> {
    type Item = Result<RecordBatch, ArrowError>;

    fn next(&mut self) -> Option<Self::Item> {
        self.read().transpose()
    }
}

impl<R: BufRead> RecordBatchReader for BufReader<R> {
    fn schema(&self) -> SchemaRef {
        self.decoder.schema.clone()
    }
}

/// A push-based interface for decoding CSV data from an arbitrary byte stream
///
/// See [`Reader`] for a higher-level interface for interfacing with [`Read`]
///
/// The push-based interface facilitates integration with sources that yield arbitrarily
/// delimited byte ranges, such as [`BufRead`], or a chunked byte stream received from
/// object storage
///
/// ```
/// # use std::io::BufRead;
/// # use arrow_array::RecordBatch;
/// # use arrow_csv::ReaderBuilder;
/// # use arrow_schema::{ArrowError, SchemaRef};
/// #
/// fn read_from_csv<R: BufRead>(
///     mut reader: R,
///     schema: SchemaRef,
///     batch_size: usize,
/// ) -> Result<impl Iterator<Item = Result<RecordBatch, ArrowError>>, ArrowError> {
///     let mut decoder = ReaderBuilder::new(schema)
///         .with_batch_size(batch_size)
///         .build_decoder();
///
///     let mut next = move || {
///         loop {
///             let buf = reader.fill_buf()?;
///             let decoded = decoder.decode(buf)?;
///             if decoded == 0 {
///                 break;
///             }
///
///             // Consume the number of bytes read
///             reader.consume(decoded);
///         }
///         decoder.flush()
///     };
///     Ok(std::iter::from_fn(move || next().transpose()))
/// }
/// ```
#[derive(Debug)]
pub struct Decoder {
    /// Explicit schema for the CSV file
    schema: SchemaRef,

    /// Optional projection for which columns to load (zero-based column indices)
    projection: Option<Vec<usize>>,

    /// Number of records per batch
    batch_size: usize,

    /// Rows to skip
    to_skip: usize,

    /// Current line number
    line_number: usize,

    /// End line number
    end: usize,

    /// A decoder for [`StringRecords`]
    record_decoder: RecordDecoder,

    /// Check if the string matches this pattern for `NULL`.
    null_regex: NullRegex,
}

impl Decoder {
    /// Decode records from `buf` returning the number of bytes read
    ///
    /// This method returns once `batch_size` records have been parsed since the
    /// last call to [`Self::flush`], or `buf` is exhausted. Any remaining bytes
    /// should be included in the next call to [`Self::decode`]
    ///
    /// There is no requirement that `buf` contains a whole number of records, facilitating
    /// integration with arbitrary byte streams, such as that yielded by [`BufRead`] or
    /// network sources such as object storage
    pub fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
        if self.to_skip != 0 {
            // Skip in units of `to_read` to avoid over-allocating buffers
            let to_skip = self.to_skip.min(self.batch_size);
            let (skipped, bytes) = self.record_decoder.decode(buf, to_skip)?;
            self.to_skip -= skipped;
            self.record_decoder.clear();
            return Ok(bytes);
        }

        let to_read = self.batch_size.min(self.end - self.line_number) - self.record_decoder.len();
        let (_, bytes) = self.record_decoder.decode(buf, to_read)?;
        Ok(bytes)
    }

    /// Flushes the currently buffered data to a [`RecordBatch`]
    ///
    /// This should only be called after [`Self::decode`] has returned `Ok(0)`,
    /// otherwise it may return an error if called partway through decoding a record
    ///
    /// Returns `Ok(None)` if there is no buffered data
    pub fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        if self.record_decoder.is_empty() {
            return Ok(None);
        }

        let rows = self.record_decoder.flush()?;
        let batch = parse(
            &rows,
            self.schema.fields(),
            Some(self.schema.metadata.clone()),
            self.projection.as_ref(),
            self.line_number,
            &self.null_regex,
        )?;
        self.line_number += rows.len();
        Ok(Some(batch))
    }

    /// Returns the number of records that can be read before requiring a call to [`Self::flush`]
    pub fn capacity(&self) -> usize {
        self.batch_size - self.record_decoder.len()
    }
}

/// Parses [`StringRecords`] into a [`RecordBatch`]
fn parse(
    rows: &StringRecords<'_>,
    fields: &Fields,
    metadata: Option<std::collections::HashMap<String, String>>,
    projection: Option<&Vec<usize>>,
    line_number: usize,
    null_regex: &NullRegex,
) -> Result<RecordBatch, ArrowError> {
    let projection: Vec<usize> = match projection {
        Some(v) => v.clone(),
        None => fields.iter().enumerate().map(|(i, _)| i).collect(),
    };

    let arrays: Result<Vec<ArrayRef>, _> = projection
        .iter()
        .map(|i| {
            let i = *i;
            let field = &fields[i];
            match field.data_type() {
                DataType::Boolean => build_boolean_array(line_number, rows, i, null_regex),
                DataType::Decimal128(precision, scale) => build_decimal_array::<Decimal128Type>(
                    line_number,
                    rows,
                    i,
                    *precision,
                    *scale,
                    null_regex,
                ),
                DataType::Decimal256(precision, scale) => build_decimal_array::<Decimal256Type>(
                    line_number,
                    rows,
                    i,
                    *precision,
                    *scale,
                    null_regex,
                ),
                DataType::Int8 => {
                    build_primitive_array::<Int8Type>(line_number, rows, i, null_regex)
                }
                DataType::Int16 => {
                    build_primitive_array::<Int16Type>(line_number, rows, i, null_regex)
                }
                DataType::Int32 => {
                    build_primitive_array::<Int32Type>(line_number, rows, i, null_regex)
                }
                DataType::Int64 => {
                    build_primitive_array::<Int64Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt8 => {
                    build_primitive_array::<UInt8Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt16 => {
                    build_primitive_array::<UInt16Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt32 => {
                    build_primitive_array::<UInt32Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt64 => {
                    build_primitive_array::<UInt64Type>(line_number, rows, i, null_regex)
                }
                DataType::Float32 => {
                    build_primitive_array::<Float32Type>(line_number, rows, i, null_regex)
                }
                DataType::Float64 => {
                    build_primitive_array::<Float64Type>(line_number, rows, i, null_regex)
                }
                DataType::Date32 => {
                    build_primitive_array::<Date32Type>(line_number, rows, i, null_regex)
                }
                DataType::Date64 => {
                    build_primitive_array::<Date64Type>(line_number, rows, i, null_regex)
                }
                DataType::Time32(TimeUnit::Second) => {
                    build_primitive_array::<Time32SecondType>(line_number, rows, i, null_regex)
                }
                DataType::Time32(TimeUnit::Millisecond) => {
                    build_primitive_array::<Time32MillisecondType>(line_number, rows, i, null_regex)
                }
                DataType::Time64(TimeUnit::Microsecond) => {
                    build_primitive_array::<Time64MicrosecondType>(line_number, rows, i, null_regex)
                }
                DataType::Time64(TimeUnit::Nanosecond) => {
                    build_primitive_array::<Time64NanosecondType>(line_number, rows, i, null_regex)
                }
                DataType::Timestamp(TimeUnit::Second, tz) => {
                    build_timestamp_array::<TimestampSecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Timestamp(TimeUnit::Millisecond, tz) => {
                    build_timestamp_array::<TimestampMillisecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Timestamp(TimeUnit::Microsecond, tz) => {
                    build_timestamp_array::<TimestampMicrosecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
                    build_timestamp_array::<TimestampNanosecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Null => Ok(Arc::new({
                    let mut builder = NullBuilder::new();
                    builder.append_nulls(rows.len());
                    builder.finish()
                }) as ArrayRef),
                DataType::Utf8 => Ok(Arc::new(
                    rows.iter()
                        .map(|row| {
                            let s = row.get(i);
                            (!null_regex.is_null(s)).then_some(s)
                        })
                        .collect::<StringArray>(),
                ) as ArrayRef),
                DataType::Utf8View => Ok(Arc::new(
                    rows.iter()
                        .map(|row| {
                            let s = row.get(i);
                            (!null_regex.is_null(s)).then_some(s)
                        })
                        .collect::<StringViewArray>(),
                ) as ArrayRef),
                DataType::Dictionary(key_type, value_type)
                    if value_type.as_ref() == &DataType::Utf8 =>
                {
                    match key_type.as_ref() {
                        DataType::Int8 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int8Type>>(),
                        ) as ArrayRef),
                        DataType::Int16 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int16Type>>(),
                        ) as ArrayRef),
                        DataType::Int32 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int32Type>>(),
                        ) as ArrayRef),
                        DataType::Int64 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int64Type>>(),
                        ) as ArrayRef),
                        DataType::UInt8 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt8Type>>(),
                        ) as ArrayRef),
                        DataType::UInt16 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt16Type>>(),
                        ) as ArrayRef),
                        DataType::UInt32 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt32Type>>(),
                        ) as ArrayRef),
                        DataType::UInt64 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt64Type>>(),
                        ) as ArrayRef),
                        _ => Err(ArrowError::ParseError(format!(
                            "Unsupported dictionary key type {key_type:?}"
                        ))),
                    }
                }
                other => Err(ArrowError::ParseError(format!(
                    "Unsupported data type {other:?}"
                ))),
            }
        })
        .collect();

    let projected_fields: Fields = projection.iter().map(|i| fields[*i].clone()).collect();

    let projected_schema = Arc::new(match metadata {
        None => Schema::new(projected_fields),
        Some(metadata) => Schema::new_with_metadata(projected_fields, metadata),
    });

    arrays.and_then(|arr| {
        RecordBatch::try_new_with_options(
            projected_schema,
            arr,
            &RecordBatchOptions::new()
                .with_match_field_names(true)
                .with_row_count(Some(rows.len())),
        )
    })
}

fn parse_bool(string: &str) -> Option<bool> {
    if string.eq_ignore_ascii_case("false") {
        Some(false)
    } else if string.eq_ignore_ascii_case("true") {
        Some(true)
    } else {
        None
    }
}

// Parses a column of CSV strings (col_idx) into a decimal array.
fn build_decimal_array<T: DecimalType>(
    _line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    precision: u8,
    scale: i8,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    let mut decimal_builder = PrimitiveBuilder::<T>::with_capacity(rows.len());
    for row in rows.iter() {
        let s = row.get(col_idx);
        if null_regex.is_null(s) {
            decimal_builder.append_null();
        } else {
            decimal_builder.append_value(parse_decimal::<T>(s, precision, scale)?);
        }
    }
    Ok(Arc::new(
        decimal_builder
            .finish()
            .with_precision_and_scale(precision, scale)?,
    ))
}

// parses a specific column (col_idx) into an Arrow Array.
fn build_primitive_array<T: ArrowPrimitiveType + Parser>(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    rows.iter()
        .enumerate()
        .map(|(row_index, row)| {
            let s = row.get(col_idx);
            if null_regex.is_null(s) {
                return Ok(None);
            }

            match T::parse(s) {
                Some(e) => Ok(Some(e)),
                None => Err(ArrowError::ParseError(format!(
                    // TODO: we should surface the underlying error here.
                    "Error while parsing value '{}' as type '{}' for column {} at line {}. Row data: '{}'",
                    s,
                    T::DATA_TYPE,
                    col_idx,
                    line_number + row_index,
                    row
                ))),
            }
        })
        .collect::<Result<PrimitiveArray<T>, ArrowError>>()
        .map(|e| Arc::new(e) as ArrayRef)
}

fn build_timestamp_array<T: ArrowTimestampType>(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    timezone: Option<&str>,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    Ok(Arc::new(match timezone {
        Some(timezone) => {
            let tz: Tz = timezone.parse()?;
            build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &tz, null_regex)?
                .with_timezone(timezone)
        }
        None => build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &Utc, null_regex)?,
    }))
}

fn build_timestamp_array_impl<T: ArrowTimestampType, Tz: TimeZone>(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    timezone: &Tz,
    null_regex: &NullRegex,
) -> Result<PrimitiveArray<T>, ArrowError> {
    rows.iter()
        .enumerate()
        .map(|(row_index, row)| {
            let s = row.get(col_idx);
            if null_regex.is_null(s) {
                return Ok(None);
            }

            let date = string_to_datetime(timezone, s)
                .and_then(|date| match T::UNIT {
                    TimeUnit::Second => Ok(date.timestamp()),
                    TimeUnit::Millisecond => Ok(date.timestamp_millis()),
                    TimeUnit::Microsecond => Ok(date.timestamp_micros()),
                    TimeUnit::Nanosecond => date.timestamp_nanos_opt().ok_or_else(|| {
                        ArrowError::ParseError(format!(
                            "{} would overflow 64-bit signed nanoseconds",
                            date.to_rfc3339(),
                        ))
                    }),
                })
                .map_err(|e| {
                    ArrowError::ParseError(format!(
                        "Error parsing column {col_idx} at line {}: {}",
                        line_number + row_index,
                        e
                    ))
                })?;
            Ok(Some(date))
        })
        .collect()
}

// parses a specific column (col_idx) into an Arrow Array.
fn build_boolean_array(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    rows.iter()
        .enumerate()
        .map(|(row_index, row)| {
            let s = row.get(col_idx);
            if null_regex.is_null(s) {
                return Ok(None);
            }
            let parsed = parse_bool(s);
            match parsed {
                Some(e) => Ok(Some(e)),
                None => Err(ArrowError::ParseError(format!(
                    // TODO: we should surface the underlying error here.
                    "Error while parsing value '{}' as type '{}' for column {} at line {}. Row data: '{}'",
                    s,
                    "Boolean",
                    col_idx,
                    line_number + row_index,
                    row
                ))),
            }
        })
        .collect::<Result<BooleanArray, _>>()
        .map(|e| Arc::new(e) as ArrayRef)
}

/// CSV file reader builder
#[derive(Debug)]
pub struct ReaderBuilder {
    /// Schema of the CSV file
    schema: SchemaRef,
    /// Format of the CSV file
    format: Format,
    /// Batch size (number of records to load each time)
    ///
    /// The default batch size when using the `ReaderBuilder` is 1024 records
    batch_size: usize,
    /// The bounds over which to scan the reader. `None` starts from 0 and runs until EOF.
    bounds: Bounds,
    /// Optional projection for which columns to load (zero-based column indices)
    projection: Option<Vec<usize>>,
}

impl ReaderBuilder {
    /// Create a new builder for configuring CSV parsing options.
    ///
    /// To convert a builder into a reader, call `ReaderBuilder::build`
    ///
    /// # Example
    ///
    /// ```
    /// # use arrow_csv::{Reader, ReaderBuilder};
    /// # use std::fs::File;
    /// # use std::io::Seek;
    /// # use std::sync::Arc;
    /// # use arrow_csv::reader::Format;
    /// #
    /// let mut file = File::open("test/data/uk_cities_with_headers.csv").unwrap();
    /// // Infer the schema with the first 100 records
    /// let (schema, _) = Format::default().infer_schema(&mut file, Some(100)).unwrap();
    /// file.rewind().unwrap();
    ///
    /// // create a builder
    /// ReaderBuilder::new(Arc::new(schema)).build(file).unwrap();
    /// ```
    pub fn new(schema: SchemaRef) -> ReaderBuilder {
        Self {
            schema,
            format: Format::default(),
            batch_size: 1024,
            bounds: None,
            projection: None,
        }
    }

    /// Set whether the CSV file has a header
    pub fn with_header(mut self, has_header: bool) -> Self {
        self.format.header = has_header;
        self
    }

    /// Overrides the [Format] of this [ReaderBuilder]
    pub fn with_format(mut self, format: Format) -> Self {
        self.format = format;
        self
    }

    /// Set the CSV file's column delimiter as a byte character
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.format.delimiter = Some(delimiter);
        self
    }

    /// Set the given character as the CSV file's escape character
    pub fn with_escape(mut self, escape: u8) -> Self {
        self.format.escape = Some(escape);
        self
    }

    /// Set the given character as the CSV file's quote character, by default it is double quote
    pub fn with_quote(mut self, quote: u8) -> Self {
        self.format.quote = Some(quote);
        self
    }

    /// Provide a custom terminator character, defaults to CRLF
    pub fn with_terminator(mut self, terminator: u8) -> Self {
        self.format.terminator = Some(terminator);
        self
    }

    /// Provide a comment character, lines starting with this character will be ignored
    pub fn with_comment(mut self, comment: u8) -> Self {
        self.format.comment = Some(comment);
        self
    }

    /// Provide a regex to match null values, defaults to `^$`
    pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
        self.format.null_regex = NullRegex(Some(null_regex));
        self
    }

    /// Set the batch size (number of records to load at one time)
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Set the bounds over which to scan the reader.
    /// `start` and `end` are line numbers.
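    ///
    /// # Example
    ///
    /// A minimal sketch reading only the first two lines of some in-memory
    /// sample data:
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow_csv::ReaderBuilder;
    /// # use arrow_schema::{DataType, Field, Schema};
    /// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
    /// let mut reader = ReaderBuilder::new(schema)
    ///     .with_bounds(0, 2)
    ///     .build("x\ny\nz\n".as_bytes())
    ///     .unwrap();
    /// // Only lines 0 and 1 are decoded; line 2 falls outside the bounds
    /// let batch = reader.next().unwrap().unwrap();
    /// assert_eq!(batch.num_rows(), 2);
    /// ```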
    pub fn with_bounds(mut self, start: usize, end: usize) -> Self {
        self.bounds = Some((start, end));
        self
    }

    /// Set the reader's column projection
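    ///
    /// # Example
    ///
    /// A minimal sketch selecting only the first and third columns of some
    /// in-memory sample data (indices are zero-based):
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow_csv::ReaderBuilder;
    /// # use arrow_schema::{DataType, Field, Schema};
    /// let schema = Arc::new(Schema::new(vec![
    ///     Field::new("a", DataType::Utf8, true),
    ///     Field::new("b", DataType::Utf8, true),
    ///     Field::new("c", DataType::Utf8, true),
    /// ]));
    /// let mut reader = ReaderBuilder::new(schema)
    ///     .with_projection(vec![0, 2])
    ///     .build("x,y,z\n".as_bytes())
    ///     .unwrap();
    /// let batch = reader.next().unwrap().unwrap();
    /// // Only columns "a" and "c" are loaded
    /// assert_eq!(batch.num_columns(), 2);
    /// ```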
    pub fn with_projection(mut self, projection: Vec<usize>) -> Self {
        self.projection = Some(projection);
        self
    }

    /// Whether to allow truncated rows when parsing.
    ///
    /// By default this is set to `false` and will error if the CSV rows have different lengths.
    /// When set to `true`, records with fewer than the expected number of columns are accepted,
    /// and the missing columns are filled with nulls. If the missing fields are not nullable,
    /// an error is still returned.
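    ///
    /// # Example
    ///
    /// A minimal sketch with in-memory sample data; the second row is missing
    /// column "b", which is filled with a null:
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow_array::Array;
    /// # use arrow_csv::ReaderBuilder;
    /// # use arrow_schema::{DataType, Field, Schema};
    /// let schema = Arc::new(Schema::new(vec![
    ///     Field::new("a", DataType::Utf8, true),
    ///     Field::new("b", DataType::Utf8, true),
    /// ]));
    /// let mut reader = ReaderBuilder::new(schema)
    ///     .with_truncated_rows(true)
    ///     .build("x,y\nz\n".as_bytes())
    ///     .unwrap();
    /// let batch = reader.next().unwrap().unwrap();
    /// assert_eq!(batch.num_rows(), 2);
    /// // The missing field in the second row is null
    /// assert!(batch.column(1).is_null(1));
    /// ```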
1162    pub fn with_truncated_rows(mut self, allow: bool) -> Self {
1163        self.format.truncated_rows = allow;
1164        self
1165    }
1166
1167    /// Create a new `Reader` from a non-buffered reader
1168    ///
1169    /// If `R: BufRead` consider using [`Self::build_buffered`] to avoid unnecessary additional
1170    /// buffering, as internally this method wraps `reader` in [`std::io::BufReader`]
1171    pub fn build<R: Read>(self, reader: R) -> Result<Reader<R>, ArrowError> {
1172        self.build_buffered(StdBufReader::new(reader))
1173    }
1174
1175    /// Create a new `BufReader` from a buffered reader
1176    pub fn build_buffered<R: BufRead>(self, reader: R) -> Result<BufReader<R>, ArrowError> {
1177        Ok(BufReader {
1178            reader,
1179            decoder: self.build_decoder(),
1180        })
1181    }
1182
1183    /// Builds a decoder that can be used to decode CSV from an arbitrary byte stream
1184    pub fn build_decoder(self) -> Decoder {
1185        let delimiter = self.format.build_parser();
1186        let record_decoder = RecordDecoder::new(
1187            delimiter,
1188            self.schema.fields().len(),
1189            self.format.truncated_rows,
1190        );
1191
1192        let header = self.format.header as usize;
1193
1194        let (start, end) = match self.bounds {
1195            Some((start, end)) => (start + header, end + header),
1196            None => (header, usize::MAX),
1197        };
1198
1199        Decoder {
1200            schema: self.schema,
1201            to_skip: start,
1202            record_decoder,
1203            line_number: start,
1204            end,
1205            projection: self.projection,
1206            batch_size: self.batch_size,
1207            null_regex: self.format.null_regex,
1208        }
1209    }
1210}
1211
1212#[cfg(test)]
1213mod tests {
1214    use super::*;
1215
1216    use std::io::{Cursor, Seek, SeekFrom, Write};
1217    use tempfile::NamedTempFile;
1218
1219    use arrow_array::cast::AsArray;
1220
1221    #[test]
1222    fn test_csv() {
1223        let schema = Arc::new(Schema::new(vec![
1224            Field::new("city", DataType::Utf8, false),
1225            Field::new("lat", DataType::Float64, false),
1226            Field::new("lng", DataType::Float64, false),
1227        ]));
1228
1229        let file = File::open("test/data/uk_cities.csv").unwrap();
1230        let mut csv = ReaderBuilder::new(schema.clone()).build(file).unwrap();
1231        assert_eq!(schema, csv.schema());
1232        let batch = csv.next().unwrap().unwrap();
1233        assert_eq!(37, batch.num_rows());
1234        assert_eq!(3, batch.num_columns());
1235
1236        // access data from a primitive array
1237        let lat = batch.column(1).as_primitive::<Float64Type>();
1238        assert_eq!(57.653484, lat.value(0));
1239
1240        // access data from a string array (ListArray<u8>)
1241        let city = batch.column(0).as_string::<i32>();
1242
1243        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
1244    }
1245
1246    #[test]
1247    fn test_csv_schema_metadata() {
1248        let mut metadata = std::collections::HashMap::new();
1249        metadata.insert("foo".to_owned(), "bar".to_owned());
1250        let schema = Arc::new(Schema::new_with_metadata(
1251            vec![
1252                Field::new("city", DataType::Utf8, false),
1253                Field::new("lat", DataType::Float64, false),
1254                Field::new("lng", DataType::Float64, false),
1255            ],
1256            metadata.clone(),
1257        ));
1258
1259        let file = File::open("test/data/uk_cities.csv").unwrap();
1260
1261        let mut csv = ReaderBuilder::new(schema.clone()).build(file).unwrap();
1262        assert_eq!(schema, csv.schema());
1263        let batch = csv.next().unwrap().unwrap();
1264        assert_eq!(37, batch.num_rows());
1265        assert_eq!(3, batch.num_columns());
1266
1267        assert_eq!(&metadata, batch.schema().metadata());
1268    }
1269
1270    #[test]
1271    fn test_csv_reader_with_decimal() {
1272        let schema = Arc::new(Schema::new(vec![
1273            Field::new("city", DataType::Utf8, false),
1274            Field::new("lat", DataType::Decimal128(38, 6), false),
1275            Field::new("lng", DataType::Decimal256(76, 6), false),
1276        ]));
1277
1278        let file = File::open("test/data/decimal_test.csv").unwrap();
1279
1280        let mut csv = ReaderBuilder::new(schema).build(file).unwrap();
1281        let batch = csv.next().unwrap().unwrap();
1282        // access data from a primitive array
1283        let lat = batch
1284            .column(1)
1285            .as_any()
1286            .downcast_ref::<Decimal128Array>()
1287            .unwrap();
1288
1289        assert_eq!("57.653484", lat.value_as_string(0));
1290        assert_eq!("53.002666", lat.value_as_string(1));
1291        assert_eq!("52.412811", lat.value_as_string(2));
1292        assert_eq!("51.481583", lat.value_as_string(3));
1293        assert_eq!("12.123456", lat.value_as_string(4));
1294        assert_eq!("50.760000", lat.value_as_string(5));
1295        assert_eq!("0.123000", lat.value_as_string(6));
1296        assert_eq!("123.000000", lat.value_as_string(7));
1297        assert_eq!("123.000000", lat.value_as_string(8));
1298        assert_eq!("-50.760000", lat.value_as_string(9));
1299
1300        let lng = batch
1301            .column(2)
1302            .as_any()
1303            .downcast_ref::<Decimal256Array>()
1304            .unwrap();
1305
1306        assert_eq!("-3.335724", lng.value_as_string(0));
1307        assert_eq!("-2.179404", lng.value_as_string(1));
1308        assert_eq!("-1.778197", lng.value_as_string(2));
1309        assert_eq!("-3.179090", lng.value_as_string(3));
1310        assert_eq!("-3.179090", lng.value_as_string(4));
1311        assert_eq!("0.290472", lng.value_as_string(5));
1312        assert_eq!("0.290472", lng.value_as_string(6));
1313        assert_eq!("0.290472", lng.value_as_string(7));
1314        assert_eq!("0.290472", lng.value_as_string(8));
1315        assert_eq!("0.290472", lng.value_as_string(9));
1316    }
1317
1318    #[test]
1319    fn test_csv_from_buf_reader() {
1320        let schema = Schema::new(vec![
1321            Field::new("city", DataType::Utf8, false),
1322            Field::new("lat", DataType::Float64, false),
1323            Field::new("lng", DataType::Float64, false),
1324        ]);
1325
1326        let file_with_headers = File::open("test/data/uk_cities_with_headers.csv").unwrap();
1327        let file_without_headers = File::open("test/data/uk_cities.csv").unwrap();
1328        let both_files = file_with_headers
1329            .chain(Cursor::new("\n".to_string()))
1330            .chain(file_without_headers);
1331        let mut csv = ReaderBuilder::new(Arc::new(schema))
1332            .with_header(true)
1333            .build(both_files)
1334            .unwrap();
1335        let batch = csv.next().unwrap().unwrap();
1336        assert_eq!(74, batch.num_rows());
1337        assert_eq!(3, batch.num_columns());
1338    }
1339
1340    #[test]
1341    fn test_csv_with_schema_inference() {
1342        let mut file = File::open("test/data/uk_cities_with_headers.csv").unwrap();
1343
1344        let (schema, _) = Format::default()
1345            .with_header(true)
1346            .infer_schema(&mut file, None)
1347            .unwrap();
1348
1349        file.rewind().unwrap();
1350        let builder = ReaderBuilder::new(Arc::new(schema)).with_header(true);
1351
1352        let mut csv = builder.build(file).unwrap();
1353        let expected_schema = Schema::new(vec![
1354            Field::new("city", DataType::Utf8, true),
1355            Field::new("lat", DataType::Float64, true),
1356            Field::new("lng", DataType::Float64, true),
1357        ]);
1358        assert_eq!(Arc::new(expected_schema), csv.schema());
1359        let batch = csv.next().unwrap().unwrap();
1360        assert_eq!(37, batch.num_rows());
1361        assert_eq!(3, batch.num_columns());
1362
1363        // access data from a primitive array
1364        let lat = batch
1365            .column(1)
1366            .as_any()
1367            .downcast_ref::<Float64Array>()
1368            .unwrap();
1369        assert_eq!(57.653484, lat.value(0));
1370
1371        // access data from a string array (ListArray<u8>)
1372        let city = batch
1373            .column(0)
1374            .as_any()
1375            .downcast_ref::<StringArray>()
1376            .unwrap();
1377
1378        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
1379    }
1380
1381    #[test]
1382    fn test_csv_with_schema_inference_no_headers() {
1383        let mut file = File::open("test/data/uk_cities.csv").unwrap();
1384
1385        let (schema, _) = Format::default().infer_schema(&mut file, None).unwrap();
1386        file.rewind().unwrap();
1387
1388        let mut csv = ReaderBuilder::new(Arc::new(schema)).build(file).unwrap();
1389
1390        // csv field names should be 'column_{number}'
1391        let schema = csv.schema();
1392        assert_eq!("column_1", schema.field(0).name());
1393        assert_eq!("column_2", schema.field(1).name());
1394        assert_eq!("column_3", schema.field(2).name());
1395        let batch = csv.next().unwrap().unwrap();
1396        let batch_schema = batch.schema();
1397
1398        assert_eq!(schema, batch_schema);
1399        assert_eq!(37, batch.num_rows());
1400        assert_eq!(3, batch.num_columns());
1401
1402        // access data from a primitive array
1403        let lat = batch
1404            .column(1)
1405            .as_any()
1406            .downcast_ref::<Float64Array>()
1407            .unwrap();
1408        assert_eq!(57.653484, lat.value(0));
1409
1410        // access data from a string array (ListArray<u8>)
1411        let city = batch
1412            .column(0)
1413            .as_any()
1414            .downcast_ref::<StringArray>()
1415            .unwrap();
1416
1417        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
1418    }
1419
1420    #[test]
1421    fn test_csv_builder_with_bounds() {
1422        let mut file = File::open("test/data/uk_cities.csv").unwrap();
1423
1424        // Set the bounds to the lines 0, 1 and 2.
1425        let (schema, _) = Format::default().infer_schema(&mut file, None).unwrap();
1426        file.rewind().unwrap();
1427        let mut csv = ReaderBuilder::new(Arc::new(schema))
1428            .with_bounds(0, 2)
1429            .build(file)
1430            .unwrap();
1431        let batch = csv.next().unwrap().unwrap();
1432
1433        // access data from a string array (ListArray<u8>)
1434        let city = batch
1435            .column(0)
1436            .as_any()
1437            .downcast_ref::<StringArray>()
1438            .unwrap();
1439
1440        // The value on line 0 is within the bounds
1441        assert_eq!("Elgin, Scotland, the UK", city.value(0));
1442
1443        // The value on line 13 is outside of the bounds. Therefore
1444        // the call to .value() will panic.
1445        let result = std::panic::catch_unwind(|| city.value(13));
1446        assert!(result.is_err());
1447    }
1448
    #[test]
    fn test_csv_with_projection() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]));

        let file = File::open("test/data/uk_cities.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_projection(vec![0, 1])
            .build(file)
            .unwrap();

        let projected_schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
        ]));
        assert_eq!(projected_schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(projected_schema, batch.schema());
        assert_eq!(37, batch.num_rows());
        assert_eq!(2, batch.num_columns());
    }

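    // Dictionary-encoded columns are built directly while reading; cast back
    // to Utf8 to compare the decoded values.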
    #[test]
    fn test_csv_with_dictionary() {
        let schema = Arc::new(Schema::new(vec![
            Field::new_dictionary("city", DataType::Int32, DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]));

        let file = File::open("test/data/uk_cities.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_projection(vec![0, 1])
            .build(file)
            .unwrap();

        let projected_schema = Arc::new(Schema::new(vec![
            Field::new_dictionary("city", DataType::Int32, DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
        ]));
        assert_eq!(projected_schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(projected_schema, batch.schema());
        assert_eq!(37, batch.num_rows());
        assert_eq!(2, batch.num_columns());

        let strings = arrow_cast::cast(batch.column(0), &DataType::Utf8).unwrap();
        let strings = strings.as_string::<i32>();

        assert_eq!(strings.value(0), "Elgin, Scotland, the UK");
        assert_eq!(strings.value(4), "Eastbourne, East Sussex, UK");
        assert_eq!(strings.value(29), "Uckfield, East Sussex, UK");
    }

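    // Dictionary columns should work with every integer key type, and empty
    // fields should come back as nulls.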
    #[test]
    fn test_csv_with_nullable_dictionary() {
        let offset_type = vec![
            DataType::Int8,
            DataType::Int16,
            DataType::Int32,
            DataType::Int64,
            DataType::UInt8,
            DataType::UInt16,
            DataType::UInt32,
            DataType::UInt64,
        ];
        for data_type in offset_type {
            let file = File::open("test/data/dictionary_nullable_test.csv").unwrap();
            let dictionary_type =
                DataType::Dictionary(Box::new(data_type), Box::new(DataType::Utf8));
            let schema = Arc::new(Schema::new(vec![
                Field::new("id", DataType::Utf8, false),
                Field::new("name", dictionary_type.clone(), true),
            ]));

            let mut csv = ReaderBuilder::new(schema)
                .build(file.try_clone().unwrap())
                .unwrap();

            let batch = csv.next().unwrap().unwrap();
            assert_eq!(3, batch.num_rows());
            assert_eq!(2, batch.num_columns());

            let names = arrow_cast::cast(batch.column(1), &dictionary_type).unwrap();
            assert!(!names.is_null(2));
            assert!(names.is_null(1));
        }
    }
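
    // Empty fields in nullable columns are read back as nulls.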
    #[test]
    fn test_nulls() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c_int", DataType::UInt64, false),
            Field::new("c_float", DataType::Float32, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, false),
        ]));

        let file = File::open("test/data/null_test.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_header(true)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        assert!(!batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }

    #[test]
    fn test_init_nulls() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c_int", DataType::UInt64, true),
            Field::new("c_float", DataType::Float32, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
            Field::new("c_null", DataType::Null, true),
        ]));
        let file = File::open("test/data/init_null_test.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_header(true)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        assert!(batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }

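    // Inference marks every column nullable, widens numeric columns to
    // Int64/Float64, and infers Null for a column with no values at all.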
    #[test]
    fn test_init_nulls_with_inference() {
        let format = Format::default().with_header(true).with_delimiter(b',');

        let mut file = File::open("test/data/init_null_test.csv").unwrap();
        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let expected_schema = Schema::new(vec![
            Field::new("c_int", DataType::Int64, true),
            Field::new("c_float", DataType::Float64, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
            Field::new("c_null", DataType::Null, true),
        ]);
        assert_eq!(schema, expected_schema);

        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        assert!(batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }

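    // A custom null regex treats matching strings (here "nil") as nulls.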
    #[test]
    fn test_custom_nulls() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c_int", DataType::UInt64, true),
            Field::new("c_float", DataType::Float32, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
        ]));

        let file = File::open("test/data/custom_null_test.csv").unwrap();

        let null_regex = Regex::new("^nil$").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_header(true)
            .with_null_regex(null_regex)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        // "nil"s should be NULL
        assert!(batch.column(0).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(batch.column(3).is_null(4));
        assert!(batch.column(2).is_null(3));
        assert!(!batch.column(2).is_null(4));
    }

    #[test]
    fn test_nulls_with_inference() {
        let mut file = File::open("test/data/various_types.csv").unwrap();
        let format = Format::default().with_header(true).with_delimiter(b'|');

        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3, 4, 5]);

        let mut csv = builder.build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();

        assert_eq!(10, batch.num_rows());
        assert_eq!(6, batch.num_columns());

        let schema = batch.schema();

        assert_eq!(&DataType::Int64, schema.field(0).data_type());
        assert_eq!(&DataType::Float64, schema.field(1).data_type());
        assert_eq!(&DataType::Float64, schema.field(2).data_type());
        assert_eq!(&DataType::Boolean, schema.field(3).data_type());
        assert_eq!(&DataType::Date32, schema.field(4).data_type());
        assert_eq!(
            &DataType::Timestamp(TimeUnit::Second, None),
            schema.field(5).data_type()
        );

        let names: Vec<&str> = schema.fields().iter().map(|x| x.name().as_str()).collect();
        assert_eq!(
            names,
            vec![
                "c_int",
                "c_float",
                "c_string",
                "c_bool",
                "c_date",
                "c_datetime"
            ]
        );

        assert!(schema.field(0).is_nullable());
        assert!(schema.field(1).is_nullable());
        assert!(schema.field(2).is_nullable());
        assert!(schema.field(3).is_nullable());
        assert!(schema.field(4).is_nullable());
        assert!(schema.field(5).is_nullable());

        assert!(!batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }

    #[test]
    fn test_custom_nulls_with_inference() {
        let mut file = File::open("test/data/custom_null_test.csv").unwrap();

        let null_regex = Regex::new("^nil$").unwrap();

        let format = Format::default()
            .with_header(true)
            .with_null_regex(null_regex);

        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let expected_schema = Schema::new(vec![
            Field::new("c_int", DataType::Int64, true),
            Field::new("c_float", DataType::Float64, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
        ]);

        assert_eq!(schema, expected_schema);

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3]);

        let mut csv = builder.build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();

        assert_eq!(5, batch.num_rows());
        assert_eq!(4, batch.num_columns());

        assert_eq!(batch.schema().as_ref(), &expected_schema);
    }

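    // Values written in scientific notation should infer as Float64.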
    #[test]
    fn test_scientific_notation_with_inference() {
        let mut file = File::open("test/data/scientific_notation_test.csv").unwrap();
        let format = Format::default().with_header(false).with_delimiter(b',');

        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .with_batch_size(512)
            .with_projection(vec![0, 1]);

        let mut csv = builder.build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();

        let schema = batch.schema();

        assert_eq!(&DataType::Float64, schema.field(0).data_type());
    }

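    /// Reads `file_name` against a fixed UInt64/Float32/Utf8/Boolean schema
    /// and returns the parse error produced by the first batch as a string.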
    fn invalid_csv_helper(file_name: &str) -> String {
        let file = File::open(file_name).unwrap();
        let schema = Schema::new(vec![
            Field::new("c_int", DataType::UInt64, false),
            Field::new("c_float", DataType::Float32, false),
            Field::new("c_string", DataType::Utf8, false),
            Field::new("c_bool", DataType::Boolean, false),
        ]);

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(true)
            .with_delimiter(b'|')
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3]);

        let mut csv = builder.build(file).unwrap();

        csv.next().unwrap().unwrap_err().to_string()
    }

    #[test]
    fn test_parse_invalid_csv_float() {
        let file_name = "test/data/various_invalid_types/invalid_float.csv";

        let error = invalid_csv_helper(file_name);
        assert_eq!("Parser error: Error while parsing value '4.x4' as type 'Float32' for column 1 at line 4. Row data: '[4,4.x4,,false]'", error);
    }

    #[test]
    fn test_parse_invalid_csv_int() {
        let file_name = "test/data/various_invalid_types/invalid_int.csv";

        let error = invalid_csv_helper(file_name);
        assert_eq!("Parser error: Error while parsing value '2.3' as type 'UInt64' for column 0 at line 2. Row data: '[2.3,2.2,2.22,false]'", error);
    }

    #[test]
    fn test_parse_invalid_csv_bool() {
        let file_name = "test/data/various_invalid_types/invalid_bool.csv";

        let error = invalid_csv_helper(file_name);
        assert_eq!("Parser error: Error while parsing value 'none' as type 'Boolean' for column 3 at line 2. Row data: '[2,2.2,2.22,none]'", error);
    }

    /// Infer the data type of a single CSV field value
    fn infer_field_schema(string: &str) -> DataType {
        let mut v = InferredDataType::default();
        v.update(string);
        v.get()
    }

    #[test]
    fn test_infer_field_schema() {
        assert_eq!(infer_field_schema("A"), DataType::Utf8);
        assert_eq!(infer_field_schema("\"123\""), DataType::Utf8);
        assert_eq!(infer_field_schema("10"), DataType::Int64);
        assert_eq!(infer_field_schema("10.2"), DataType::Float64);
        assert_eq!(infer_field_schema(".2"), DataType::Float64);
        assert_eq!(infer_field_schema("2."), DataType::Float64);
        assert_eq!(infer_field_schema("NaN"), DataType::Float64);
        assert_eq!(infer_field_schema("nan"), DataType::Float64);
        assert_eq!(infer_field_schema("inf"), DataType::Float64);
        assert_eq!(infer_field_schema("-inf"), DataType::Float64);
        assert_eq!(infer_field_schema("true"), DataType::Boolean);
        assert_eq!(infer_field_schema("trUe"), DataType::Boolean);
        assert_eq!(infer_field_schema("false"), DataType::Boolean);
        assert_eq!(infer_field_schema("2020-11-08"), DataType::Date32);
        assert_eq!(
            infer_field_schema("2020-11-08T14:20:01"),
            DataType::Timestamp(TimeUnit::Second, None)
        );
        assert_eq!(
            infer_field_schema("2020-11-08 14:20:01"),
            DataType::Timestamp(TimeUnit::Second, None)
        );
        assert_eq!(infer_field_schema("-5.13"), DataType::Float64);
        assert_eq!(infer_field_schema("0.1300"), DataType::Float64);
        assert_eq!(
            infer_field_schema("2021-12-19 13:12:30.921"),
            DataType::Timestamp(TimeUnit::Millisecond, None)
        );
        assert_eq!(
            infer_field_schema("2021-12-19T13:12:30.123456789"),
            DataType::Timestamp(TimeUnit::Nanosecond, None)
        );
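        // Integers outside the i64 range cannot be parsed and fall back to Utf8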
        assert_eq!(infer_field_schema("-9223372036854775809"), DataType::Utf8);
        assert_eq!(infer_field_schema("9223372036854775808"), DataType::Utf8);
    }

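    // Spot-check the `Parser` implementations used for date and time columns.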
    #[test]
    fn parse_date32() {
        assert_eq!(Date32Type::parse("1970-01-01").unwrap(), 0);
        assert_eq!(Date32Type::parse("2020-03-15").unwrap(), 18336);
        assert_eq!(Date32Type::parse("1945-05-08").unwrap(), -9004);
    }

    #[test]
    fn parse_time() {
        assert_eq!(
            Time64NanosecondType::parse("12:10:01.123456789 AM"),
            Some(601_123_456_789)
        );
        assert_eq!(
            Time64MicrosecondType::parse("12:10:01.123456 am"),
            Some(601_123_456)
        );
        assert_eq!(
            Time32MillisecondType::parse("2:10:01.12 PM"),
            Some(51_001_120)
        );
        assert_eq!(Time32SecondType::parse("2:10:01 pm"), Some(51_001));
    }

    #[test]
    fn parse_date64() {
        assert_eq!(Date64Type::parse("1970-01-01T00:00:00").unwrap(), 0);
        assert_eq!(
            Date64Type::parse("2018-11-13T17:11:10").unwrap(),
            1542129070000
        );
        assert_eq!(
            Date64Type::parse("2018-11-13T17:11:10.011").unwrap(),
            1542129070011
        );
        assert_eq!(
            Date64Type::parse("1900-02-28T12:34:56").unwrap(),
            -2203932304000
        );
        assert_eq!(
            Date64Type::parse_formatted("1900-02-28 12:34:56", "%Y-%m-%d %H:%M:%S").unwrap(),
            -2203932304000
        );
        assert_eq!(
            Date64Type::parse_formatted("1900-02-28 12:34:56+0030", "%Y-%m-%d %H:%M:%S%z").unwrap(),
            -2203932304000 - (30 * 60 * 1000)
        );
    }

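    /// Decodes three fixed timestamps (naive, UTC, and +02:00) as a single
    /// timestamp column and compares the stored values against `expected`.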
    fn test_parse_timestamp_impl<T: ArrowTimestampType>(
        timezone: Option<Arc<str>>,
        expected: &[i64],
    ) {
        let csv = [
            "1970-01-01T00:00:00",
            "1970-01-01T00:00:00Z",
            "1970-01-01T00:00:00+02:00",
        ]
        .join("\n");
        let schema = Arc::new(Schema::new(vec![Field::new(
            "field",
            DataType::Timestamp(T::UNIT, timezone.clone()),
            true,
        )]));

        let mut decoder = ReaderBuilder::new(schema).build_decoder();

        let decoded = decoder.decode(csv.as_bytes()).unwrap();
        assert_eq!(decoded, csv.len());
        decoder.decode(&[]).unwrap();

        let batch = decoder.flush().unwrap().unwrap();
        assert_eq!(batch.num_columns(), 1);
        assert_eq!(batch.num_rows(), 3);
        let col = batch.column(0).as_primitive::<T>();
        assert_eq!(col.values(), expected);
        assert_eq!(col.data_type(), &DataType::Timestamp(T::UNIT, timezone));
    }

    #[test]
    fn test_parse_timestamp() {
        test_parse_timestamp_impl::<TimestampNanosecondType>(None, &[0, 0, -7_200_000_000_000]);
        test_parse_timestamp_impl::<TimestampNanosecondType>(
            Some("+00:00".into()),
            &[0, 0, -7_200_000_000_000],
        );
        test_parse_timestamp_impl::<TimestampNanosecondType>(
            Some("-05:00".into()),
            &[18_000_000_000_000, 0, -7_200_000_000_000],
        );
        test_parse_timestamp_impl::<TimestampMicrosecondType>(
            Some("-03".into()),
            &[10_800_000_000, 0, -7_200_000_000],
        );
        test_parse_timestamp_impl::<TimestampMillisecondType>(
            Some("-03".into()),
            &[10_800_000, 0, -7_200_000],
        );
        test_parse_timestamp_impl::<TimestampSecondType>(Some("-03".into()), &[10_800, 0, -7_200]);
    }

    #[test]
    fn test_infer_schema_from_multiple_files() {
        let mut csv1 = NamedTempFile::new().unwrap();
        let mut csv2 = NamedTempFile::new().unwrap();
        let csv3 = NamedTempFile::new().unwrap(); // empty csv file should be skipped
        let mut csv4 = NamedTempFile::new().unwrap();
        writeln!(csv1, "c1,c2,c3").unwrap();
        writeln!(csv1, "1,\"foo\",0.5").unwrap();
        writeln!(csv1, "3,\"bar\",1").unwrap();
        writeln!(csv1, "3,\"bar\",2e-06").unwrap();
        // reading csv2 will set c2 to optional
        writeln!(csv2, "c1,c2,c3,c4").unwrap();
        writeln!(csv2, "10,,3.14,true").unwrap();
        // reading csv4 will set c3 to optional
        writeln!(csv4, "c1,c2,c3").unwrap();
        writeln!(csv4, "10,\"foo\",").unwrap();

        let schema = infer_schema_from_files(
            &[
                csv3.path().to_str().unwrap().to_string(),
                csv1.path().to_str().unwrap().to_string(),
                csv2.path().to_str().unwrap().to_string(),
                csv4.path().to_str().unwrap().to_string(),
            ],
            b',',
            Some(4), // only csv1 and csv2 should be read
            true,
        )
        .unwrap();

        assert_eq!(schema.fields().len(), 4);
        assert!(schema.field(0).is_nullable());
        assert!(schema.field(1).is_nullable());
        assert!(schema.field(2).is_nullable());
        assert!(schema.field(3).is_nullable());

        assert_eq!(&DataType::Int64, schema.field(0).data_type());
        assert_eq!(&DataType::Utf8, schema.field(1).data_type());
        assert_eq!(&DataType::Float64, schema.field(2).data_type());
        assert_eq!(&DataType::Boolean, schema.field(3).data_type());
    }

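    // `with_bounds(start, end)` selects the half-open row range [start, end),
    // applied across batch boundaries.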
    #[test]
    fn test_bounded() {
        let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]);
        let data = [
            vec!["0"],
            vec!["1"],
            vec!["2"],
            vec!["3"],
            vec!["4"],
            vec!["5"],
            vec!["6"],
        ];

        let data = data
            .iter()
            .map(|x| x.join(","))
            .collect::<Vec<_>>()
            .join("\n");
        let data = data.as_bytes();

        let reader = std::io::Cursor::new(data);

        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_batch_size(2)
            .with_projection(vec![0])
            .with_bounds(2, 6)
            .build_buffered(reader)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();
        let a = batch.column(0);
        let a = a.as_any().downcast_ref::<UInt32Array>().unwrap();
        assert_eq!(a, &UInt32Array::from(vec![2, 3]));

        let batch = csv.next().unwrap().unwrap();
        let a = batch.column(0);
        let a = a.as_any().downcast_ref::<UInt32Array>().unwrap();
        assert_eq!(a, &UInt32Array::from(vec![4, 5]));

        assert!(csv.next().is_none());
    }

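    // An empty projection still yields batches with the correct row count,
    // just with no columns.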
    #[test]
    fn test_empty_projection() {
        let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]);
        let data = [vec!["0"], vec!["1"]];

        let data = data
            .iter()
            .map(|x| x.join(","))
            .collect::<Vec<_>>()
            .join("\n");

        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_batch_size(2)
            .with_projection(vec![])
            .build_buffered(Cursor::new(data.as_bytes()))
            .unwrap();

        let batch = csv.next().unwrap().unwrap();
        assert_eq!(batch.columns().len(), 0);
        assert_eq!(batch.num_rows(), 2);

        assert!(csv.next().is_none());
    }

    #[test]
    fn test_parsing_bool() {
        // Encode the expected behavior of boolean parsing
        assert_eq!(Some(true), parse_bool("true"));
        assert_eq!(Some(true), parse_bool("tRUe"));
        assert_eq!(Some(true), parse_bool("True"));
        assert_eq!(Some(true), parse_bool("TRUE"));
        assert_eq!(None, parse_bool("t"));
        assert_eq!(None, parse_bool("T"));
        assert_eq!(None, parse_bool(""));

        assert_eq!(Some(false), parse_bool("false"));
        assert_eq!(Some(false), parse_bool("fALse"));
        assert_eq!(Some(false), parse_bool("False"));
        assert_eq!(Some(false), parse_bool("FALSE"));
        assert_eq!(None, parse_bool("f"));
        assert_eq!(None, parse_bool("F"));
        assert_eq!(None, parse_bool(""));
    }

    #[test]
    fn test_parsing_float() {
        assert_eq!(Some(12.34), Float64Type::parse("12.34"));
        assert_eq!(Some(-12.34), Float64Type::parse("-12.34"));
        assert_eq!(Some(12.0), Float64Type::parse("12"));
        assert_eq!(Some(0.0), Float64Type::parse("0"));
        assert_eq!(Some(2.0), Float64Type::parse("2."));
        assert_eq!(Some(0.2), Float64Type::parse(".2"));
        assert!(Float64Type::parse("nan").unwrap().is_nan());
        assert!(Float64Type::parse("NaN").unwrap().is_nan());
        assert!(Float64Type::parse("inf").unwrap().is_infinite());
        assert!(Float64Type::parse("inf").unwrap().is_sign_positive());
        assert!(Float64Type::parse("-inf").unwrap().is_infinite());
        assert!(Float64Type::parse("-inf").unwrap().is_sign_negative());
        assert_eq!(None, Float64Type::parse(""));
        assert_eq!(None, Float64Type::parse("dd"));
        assert_eq!(None, Float64Type::parse("12.34.56"));
    }

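    // The next three tests override the quote, escape, and terminator
    // characters respectively.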
    #[test]
    fn test_non_std_quote() {
        let schema = Schema::new(vec![
            Field::new("text1", DataType::Utf8, false),
            Field::new("text2", DataType::Utf8, false),
        ]);
        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(false)
            .with_quote(b'~'); // default is ", change to ~

        let mut csv_text = Vec::new();
        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
        for index in 0..10 {
            let text1 = format!("id{index}");
            let text2 = format!("value{index}");
            csv_writer
                .write_fmt(format_args!("~{text1}~,~{text2}~\r\n"))
                .unwrap();
        }
        let mut csv_reader = std::io::Cursor::new(&csv_text);
        let mut reader = builder.build(&mut csv_reader).unwrap();
        let batch = reader.next().unwrap().unwrap();
        let col0 = batch.column(0);
        assert_eq!(col0.len(), 10);
        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col0_arr.value(0), "id0");
        let col1 = batch.column(1);
        assert_eq!(col1.len(), 10);
        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col1_arr.value(5), "value5");
    }

    #[test]
    fn test_non_std_escape() {
        let schema = Schema::new(vec![
            Field::new("text1", DataType::Utf8, false),
            Field::new("text2", DataType::Utf8, false),
        ]);
        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(false)
            .with_escape(b'\\'); // default is None, change to \

        let mut csv_text = Vec::new();
        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
        for index in 0..10 {
            let text1 = format!("id{index}");
            let text2 = format!("value\\\"{index}");
            csv_writer
                .write_fmt(format_args!("\"{text1}\",\"{text2}\"\r\n"))
                .unwrap();
        }
        let mut csv_reader = std::io::Cursor::new(&csv_text);
        let mut reader = builder.build(&mut csv_reader).unwrap();
        let batch = reader.next().unwrap().unwrap();
        let col0 = batch.column(0);
        assert_eq!(col0.len(), 10);
        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col0_arr.value(0), "id0");
        let col1 = batch.column(1);
        assert_eq!(col1.len(), 10);
        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col1_arr.value(5), "value\"5");
    }

    #[test]
    fn test_non_std_terminator() {
        let schema = Schema::new(vec![
            Field::new("text1", DataType::Utf8, false),
            Field::new("text2", DataType::Utf8, false),
        ]);
        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(false)
            .with_terminator(b'\n'); // default is CRLF, change to LF

        let mut csv_text = Vec::new();
        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
        for index in 0..10 {
            let text1 = format!("id{index}");
            let text2 = format!("value{index}");
            csv_writer
                .write_fmt(format_args!("\"{text1}\",\"{text2}\"\n"))
                .unwrap();
        }
        let mut csv_reader = std::io::Cursor::new(&csv_text);
        let mut reader = builder.build(&mut csv_reader).unwrap();
        let batch = reader.next().unwrap().unwrap();
        let col0 = batch.column(0);
        assert_eq!(col0.len(), 10);
        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col0_arr.value(0), "id0");
        let col1 = batch.column(1);
        assert_eq!(col1.len(), 10);
        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col1_arr.value(5), "value5");
    }

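    // Bounds are counted in data rows, i.e. after any header line is skipped.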
    #[test]
    fn test_header_bounds() {
        let csv = "a,b\na,b\na,b\na,b\na,b\n";
        let tests = [
            (None, false, 5),
            (None, true, 4),
            (Some((0, 4)), false, 4),
            (Some((1, 4)), false, 3),
            (Some((0, 4)), true, 4),
            (Some((1, 4)), true, 3),
        ];
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("a", DataType::Utf8, false),
        ]));

        for (idx, (bounds, has_header, expected)) in tests.into_iter().enumerate() {
            let mut reader = ReaderBuilder::new(schema.clone()).with_header(has_header);
            if let Some((start, end)) = bounds {
                reader = reader.with_bounds(start, end);
            }
            let b = reader
                .build_buffered(Cursor::new(csv.as_bytes()))
                .unwrap()
                .next()
                .unwrap()
                .unwrap();
            assert_eq!(b.num_rows(), expected, "{idx}");
        }
    }

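    // Empty fields in a Boolean column are read back as nulls.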
    #[test]
    fn test_null_boolean() {
        let csv = "true,false\nFalse,True\n,True\nFalse,";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Boolean, true),
            Field::new("a", DataType::Boolean, true),
        ]));

        let b = ReaderBuilder::new(schema)
            .build_buffered(Cursor::new(csv.as_bytes()))
            .unwrap()
            .next()
            .unwrap()
            .unwrap();

        assert_eq!(b.num_rows(), 4);
        assert_eq!(b.num_columns(), 2);

        let c = b.column(0).as_boolean();
        assert_eq!(c.null_count(), 1);
        assert!(c.value(0));
        assert!(!c.value(1));
        assert!(c.is_null(2));
        assert!(!c.value(3));

        let c = b.column(1).as_boolean();
        assert_eq!(c.null_count(), 1);
        assert!(!c.value(0));
        assert!(c.value(1));
        assert!(c.value(2));
        assert!(c.is_null(3));
    }

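    // With `with_truncated_rows(true)` short rows are padded with nulls;
    // with `false` they are reported as an error.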
    #[test]
    fn test_truncated_rows() {
        let data = "a,b,c\n1,2,3\n4,5\n\n6,7,8";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
            Field::new("c", DataType::Int32, true),
        ]));

        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(true)
            .build(Cursor::new(data))
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(batches.is_ok());
        let batch = batches.unwrap().into_iter().next().unwrap();
        // Empty rows are skipped by the underlying csv parser
        assert_eq!(batch.num_rows(), 3);

        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(false)
            .build(Cursor::new(data))
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(match batches {
            Err(ArrowError::CsvError(e)) => e.to_string().contains("incorrect number of fields"),
            _ => false,
        });
    }

    #[test]
    fn test_truncated_rows_csv() {
        let file = File::open("test/data/truncated_rows.csv").unwrap();
        let schema = Arc::new(Schema::new(vec![
            Field::new("Name", DataType::Utf8, true),
            Field::new("Age", DataType::UInt32, true),
            Field::new("Occupation", DataType::Utf8, true),
            Field::new("DOB", DataType::Date32, true),
        ]));
        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_batch_size(24)
            .with_truncated_rows(true);
        let csv = reader.build(file).unwrap();
        let batches = csv.collect::<Result<Vec<_>, _>>().unwrap();

        assert_eq!(batches.len(), 1);
        let batch = &batches[0];
        assert_eq!(batch.num_rows(), 6);
        assert_eq!(batch.num_columns(), 4);
        let name = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let age = batch
            .column(1)
            .as_any()
            .downcast_ref::<UInt32Array>()
            .unwrap();
        let occupation = batch
            .column(2)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let dob = batch
            .column(3)
            .as_any()
            .downcast_ref::<Date32Array>()
            .unwrap();

        assert_eq!(name.value(0), "A1");
        assert_eq!(name.value(1), "B2");
        assert!(name.is_null(2));
        assert_eq!(name.value(3), "C3");
        assert_eq!(name.value(4), "D4");
        assert_eq!(name.value(5), "E5");

        assert_eq!(age.value(0), 34);
        assert_eq!(age.value(1), 29);
        assert!(age.is_null(2));
        assert_eq!(age.value(3), 45);
        assert!(age.is_null(4));
        assert_eq!(age.value(5), 31);

        assert_eq!(occupation.value(0), "Engineer");
        assert_eq!(occupation.value(1), "Doctor");
        assert!(occupation.is_null(2));
        assert_eq!(occupation.value(3), "Artist");
        assert!(occupation.is_null(4));
        assert!(occupation.is_null(5));

        assert_eq!(dob.value(0), 5675);
        assert!(dob.is_null(1));
        assert!(dob.is_null(2));
        assert_eq!(dob.value(3), -1858);
        assert!(dob.is_null(4));
        assert!(dob.is_null(5));
    }

    #[test]
    fn test_truncated_rows_not_nullable_error() {
        let data = "a,b,c\n1,2,3\n4,5";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));

        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(true)
            .build(Cursor::new(data))
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(match batches {
            Err(ArrowError::InvalidArgumentError(e)) =>
                e.to_string().contains("contains null values"),
            _ => false,
        });
    }

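    // The buffered reader must produce the same batches as the unbuffered
    // one, regardless of the underlying `BufReader` capacity.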
    #[test]
    fn test_buffered() {
        let tests = [
            ("test/data/uk_cities.csv", false, 37),
            ("test/data/various_types.csv", true, 10),
            ("test/data/decimal_test.csv", false, 10),
        ];

        for (path, has_header, expected_rows) in tests {
            let (schema, _) = Format::default()
                .infer_schema(File::open(path).unwrap(), None)
                .unwrap();
            let schema = Arc::new(schema);

            for batch_size in [1, 4] {
                for capacity in [1, 3, 7, 100] {
                    let reader = ReaderBuilder::new(schema.clone())
                        .with_batch_size(batch_size)
                        .with_header(has_header)
                        .build(File::open(path).unwrap())
                        .unwrap();

                    let expected = reader.collect::<Result<Vec<_>, _>>().unwrap();

                    assert_eq!(
                        expected.iter().map(|x| x.num_rows()).sum::<usize>(),
                        expected_rows
                    );

                    let buffered =
                        std::io::BufReader::with_capacity(capacity, File::open(path).unwrap());

                    let reader = ReaderBuilder::new(schema.clone())
                        .with_batch_size(batch_size)
                        .with_header(has_header)
                        .build_buffered(buffered)
                        .unwrap();

                    let actual = reader.collect::<Result<Vec<_>, _>>().unwrap();
                    assert_eq!(expected, actual)
                }
            }
        }
    }

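    /// Asserts that decoding `csv` fails with the `expected` message, for
    /// both `Utf8` and `Utf8View` column types.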
    fn err_test(csv: &[u8], expected: &str) {
        fn err_test_with_schema(csv: &[u8], expected: &str, schema: Arc<Schema>) {
            let buffer = std::io::BufReader::with_capacity(2, Cursor::new(csv));
            let b = ReaderBuilder::new(schema)
                .with_batch_size(2)
                .build_buffered(buffer)
                .unwrap();
            let err = b.collect::<Result<Vec<_>, _>>().unwrap_err().to_string();
            assert_eq!(err, expected)
        }

        let schema_utf8 = Arc::new(Schema::new(vec![
            Field::new("text1", DataType::Utf8, true),
            Field::new("text2", DataType::Utf8, true),
        ]));
        err_test_with_schema(csv, expected, schema_utf8);

        let schema_utf8view = Arc::new(Schema::new(vec![
            Field::new("text1", DataType::Utf8View, true),
            Field::new("text2", DataType::Utf8View, true),
        ]));
        err_test_with_schema(csv, expected, schema_utf8view);
    }

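    // Invalid UTF-8 errors report the 1-based line and field of the
    // offending value.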
    #[test]
    fn test_invalid_utf8() {
        err_test(
            b"sdf,dsfg\ndfd,hgh\xFFue\n,sds\nFalhghse,",
            "Csv error: Encountered invalid UTF-8 data for line 2 and field 2",
        );

        err_test(
            b"sdf,dsfg\ndksdk,jf\nd\xFFfd,hghue\n,sds\nFalhghse,",
            "Csv error: Encountered invalid UTF-8 data for line 3 and field 1",
        );

        err_test(
            b"sdf,dsfg\ndksdk,jf\ndsdsfd,hghue\n,sds\nFalhghse,\xFF",
            "Csv error: Encountered invalid UTF-8 data for line 5 and field 2",
        );

        err_test(
            b"\xFFsdf,dsfg\ndksdk,jf\ndsdsfd,hghue\n,sds\nFalhghse,\xFF",
            "Csv error: Encountered invalid UTF-8 data for line 1 and field 1",
        );
    }

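    /// A `BufRead` wrapper that records the size of every `fill_buf` call,
    /// used to verify the reader's IO access pattern.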
    struct InstrumentedRead<R> {
        r: R,
        fill_count: usize,
        fill_sizes: Vec<usize>,
    }

    impl<R> InstrumentedRead<R> {
        fn new(r: R) -> Self {
            Self {
                r,
                fill_count: 0,
                fill_sizes: vec![],
            }
        }
    }

    impl<R: Seek> Seek for InstrumentedRead<R> {
        fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
            self.r.seek(pos)
        }
    }

    impl<R: BufRead> Read for InstrumentedRead<R> {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            self.r.read(buf)
        }
    }

    impl<R: BufRead> BufRead for InstrumentedRead<R> {
        fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
            self.fill_count += 1;
            let buf = self.r.fill_buf()?;
            self.fill_sizes.push(buf.len());
            Ok(buf)
        }

        fn consume(&mut self, amt: usize) {
            self.r.consume(amt)
        }
    }

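    // End-to-end check of the buffered reader's IO pattern using
    // `InstrumentedRead`.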
    #[test]
    fn test_io() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("b", DataType::Utf8, false),
        ]));
        let csv = "foo,bar\nbaz,foo\na,b\nc,d";
        let mut read = InstrumentedRead::new(Cursor::new(csv.as_bytes()));
        let reader = ReaderBuilder::new(schema)
            .with_batch_size(3)
            .build_buffered(&mut read)
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].num_rows(), 3);
        assert_eq!(batches[1].num_rows(), 1);

        // Expect 4 calls to fill_buf
        // 1. Read first 3 rows
        // 2. Read final row
        // 3. Delimit and flush final row
        // 4. Iterator finished
        assert_eq!(&read.fill_sizes, &[23, 3, 0, 0]);
        assert_eq!(read.fill_count, 4);
    }

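    // Exercises `InferredDataType` over sequences of values: types widen
    // (e.g. Int64 -> Float64, Date32 -> Timestamp) and fall back to Utf8
    // when no common type exists.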
    #[test]
    fn test_inference() {
        let cases: &[(&[&str], DataType)] = &[
            (&[], DataType::Null),
            (&["false", "12"], DataType::Utf8),
            (&["12", "cupcakes"], DataType::Utf8),
            (&["12", "12.4"], DataType::Float64),
            (&["14050", "24332"], DataType::Int64),
            (&["14050.0", "true"], DataType::Utf8),
            (&["14050", "2020-03-19 00:00:00"], DataType::Utf8),
            (&["14050", "2340.0", "2020-03-19 00:00:00"], DataType::Utf8),
            (
                &["2020-03-19 02:00:00", "2020-03-19 00:00:00"],
                DataType::Timestamp(TimeUnit::Second, None),
            ),
            (&["2020-03-19", "2020-03-20"], DataType::Date32),
            (
                &["2020-03-19", "2020-03-19 02:00:00", "2020-03-19 00:00:00"],
                DataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                &[
                    "2020-03-19",
                    "2020-03-19 02:00:00",
                    "2020-03-19 00:00:00.000",
                ],
                DataType::Timestamp(TimeUnit::Millisecond, None),
            ),
            (
                &[
                    "2020-03-19",
                    "2020-03-19 02:00:00",
                    "2020-03-19 00:00:00.000000",
                ],
                DataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                &["2020-03-19 02:00:00+02:00", "2020-03-19 02:00:00Z"],
                DataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                &[
                    "2020-03-19",
                    "2020-03-19 02:00:00+02:00",
                    "2020-03-19 02:00:00Z",
                    "2020-03-19 02:00:00.12Z",
                ],
                DataType::Timestamp(TimeUnit::Millisecond, None),
            ),
            (
                &[
                    "2020-03-19",
                    "2020-03-19 02:00:00.000000000",
                    "2020-03-19 00:00:00.000000",
                ],
                DataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
        ];

        for (values, expected) in cases {
            let mut t = InferredDataType::default();
            for v in *values {
                t.update(v)
            }
            assert_eq!(&t.get(), expected, "{values:?}")
        }
    }

    #[test]
    fn test_record_length_mismatch() {
        let csv = "\
        a,b,c\n\
        1,2,3\n\
        4,5\n\
        6,7,8";
        let mut read = Cursor::new(csv.as_bytes());
        let result = Format::default()
            .with_header(true)
            .infer_schema(&mut read, None);
        assert!(result.is_err());
        // The error message should include the line number to help locate and fix the issue
        assert_eq!(result.err().unwrap().to_string(), "Csv error: Encountered unequal lengths between records on CSV file. Expected 2 records, found 3 records at line 3");
    }

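    // Lines starting with the configured comment character are skipped
    // entirely.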
    #[test]
    fn test_comment() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int8, false),
            Field::new("b", DataType::Int8, false),
        ]);

        let csv = "# comment1 \n1,2\n#comment2\n11,22";
        let mut read = Cursor::new(csv.as_bytes());
        let reader = ReaderBuilder::new(Arc::new(schema))
            .with_comment(b'#')
            .build(&mut read)
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(batches.len(), 1);
        let b = batches.first().unwrap();
        assert_eq!(b.num_columns(), 2);
        assert_eq!(
            b.column(0)
                .as_any()
                .downcast_ref::<Int8Array>()
                .unwrap()
                .values(),
            &vec![1, 11]
        );
        assert_eq!(
            b.column(1)
                .as_any()
                .downcast_ref::<Int8Array>()
                .unwrap()
                .values(),
            &vec![2, 22]
        );
    }

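    // Utf8View columns store short values inline and longer values in a
    // separate buffer; both paths are exercised here.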
    #[test]
    fn test_parse_string_view_single_column() {
        let csv = ["foo", "something_cannot_be_inlined", "foobar"].join("\n");
        let schema = Arc::new(Schema::new(vec![Field::new(
            "c1",
            DataType::Utf8View,
            true,
        )]));

        let mut decoder = ReaderBuilder::new(schema).build_decoder();

        let decoded = decoder.decode(csv.as_bytes()).unwrap();
        assert_eq!(decoded, csv.len());
        decoder.decode(&[]).unwrap();

        let batch = decoder.flush().unwrap().unwrap();
        assert_eq!(batch.num_columns(), 1);
        assert_eq!(batch.num_rows(), 3);
        let col = batch.column(0).as_string_view();
        assert_eq!(col.data_type(), &DataType::Utf8View);
        assert_eq!(col.value(0), "foo");
        assert_eq!(col.value(1), "something_cannot_be_inlined");
        assert_eq!(col.value(2), "foobar");
    }

    #[test]
    fn test_parse_string_view_multi_column() {
        let csv = ["foo,", ",something_cannot_be_inlined", "foobarfoobar,bar"].join("\n");
        let schema = Arc::new(Schema::new(vec![
            Field::new("c1", DataType::Utf8View, true),
            Field::new("c2", DataType::Utf8View, true),
        ]));

        let mut decoder = ReaderBuilder::new(schema).build_decoder();

        let decoded = decoder.decode(csv.as_bytes()).unwrap();
        assert_eq!(decoded, csv.len());
        decoder.decode(&[]).unwrap();

        let batch = decoder.flush().unwrap().unwrap();
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.num_rows(), 3);
        let c1 = batch.column(0).as_string_view();
        let c2 = batch.column(1).as_string_view();
        assert_eq!(c1.data_type(), &DataType::Utf8View);
        assert_eq!(c2.data_type(), &DataType::Utf8View);

        assert!(!c1.is_null(0));
        assert!(c1.is_null(1));
        assert!(!c1.is_null(2));
        assert_eq!(c1.value(0), "foo");
        assert_eq!(c1.value(2), "foobarfoobar");

        assert!(c2.is_null(0));
        assert!(!c2.is_null(1));
        assert!(!c2.is_null(2));
        assert_eq!(c2.value(1), "something_cannot_be_inlined");
        assert_eq!(c2.value(2), "bar");
    }
}