arrow_csv/reader/mod.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! CSV Reader
//!
//! # Basic Usage
//!
//! This CSV reader allows CSV files to be read into the Arrow memory model. Records are
//! loaded in batches and are then converted from row-based data to columnar data.
//!
//! Example:
//!
//! ```
//! # use arrow_schema::*;
//! # use arrow_csv::{Reader, ReaderBuilder};
//! # use std::fs::File;
//! # use std::sync::Arc;
//!
//! let schema = Schema::new(vec![
//!     Field::new("city", DataType::Utf8, false),
//!     Field::new("lat", DataType::Float64, false),
//!     Field::new("lng", DataType::Float64, false),
//! ]);
//!
//! let file = File::open("test/data/uk_cities.csv").unwrap();
//!
//! let mut csv = ReaderBuilder::new(Arc::new(schema)).build(file).unwrap();
//! let batch = csv.next().unwrap().unwrap();
//! ```
//!
//! # Async Usage
//!
//! The lower-level [`Decoder`] can be integrated with various forms of async data streams,
//! and is designed to be agnostic to the different kinds of async IO primitives found
//! within the Rust ecosystem.
//!
//! For example, see below for how it can be used with an arbitrary `Stream` of `Bytes`
//!
//! ```
//! # use std::task::{Poll, ready};
//! # use bytes::{Buf, Bytes};
//! # use arrow_schema::ArrowError;
//! # use futures::stream::{Stream, StreamExt};
//! # use arrow_array::RecordBatch;
//! # use arrow_csv::reader::Decoder;
//! #
//! fn decode_stream<S: Stream<Item = Bytes> + Unpin>(
//!     mut decoder: Decoder,
//!     mut input: S,
//! ) -> impl Stream<Item = Result<RecordBatch, ArrowError>> {
//!     let mut buffered = Bytes::new();
//!     futures::stream::poll_fn(move |cx| {
//!         loop {
//!             if buffered.is_empty() {
//!                 if let Some(b) = ready!(input.poll_next_unpin(cx)) {
//!                     buffered = b;
//!                 }
//!                 // Note: don't break on `None` as the decoder needs
//!                 // to be called with an empty array to delimit the
//!                 // final record
//!             }
//!             let decoded = match decoder.decode(buffered.as_ref()) {
//!                 Ok(0) => break,
//!                 Ok(decoded) => decoded,
//!                 Err(e) => return Poll::Ready(Some(Err(e))),
//!             };
//!             buffered.advance(decoded);
//!         }
//!
//!         Poll::Ready(decoder.flush().transpose())
//!     })
//! }
//!
//! ```
//!
//! In a similar vein, it can also be used with tokio-based IO primitives
//!
//! ```
//! # use std::pin::Pin;
//! # use std::task::{Poll, ready};
//! # use futures::Stream;
//! # use tokio::io::AsyncBufRead;
//! # use arrow_array::RecordBatch;
//! # use arrow_csv::reader::Decoder;
//! # use arrow_schema::ArrowError;
//! fn decode_stream<R: AsyncBufRead + Unpin>(
//!     mut decoder: Decoder,
//!     mut reader: R,
//! ) -> impl Stream<Item = Result<RecordBatch, ArrowError>> {
//!     futures::stream::poll_fn(move |cx| {
//!         loop {
//!             let b = match ready!(Pin::new(&mut reader).poll_fill_buf(cx)) {
//!                 Ok(b) => b,
//!                 Err(e) => return Poll::Ready(Some(Err(e.into()))),
//!             };
//!             let decoded = match decoder.decode(b) {
//!                 // Note: the decoder needs to be called with an empty
//!                 // array to delimit the final record
//!                 Ok(0) => break,
//!                 Ok(decoded) => decoded,
//!                 Err(e) => return Poll::Ready(Some(Err(e))),
//!             };
//!             Pin::new(&mut reader).consume(decoded);
//!         }
//!
//!         Poll::Ready(decoder.flush().transpose())
//!     })
//! }
//! ```
//!

mod records;

use arrow_array::builder::{NullBuilder, PrimitiveBuilder};
use arrow_array::types::*;
use arrow_array::*;
use arrow_cast::parse::{parse_decimal, string_to_datetime, Parser};
use arrow_schema::*;
use chrono::{TimeZone, Utc};
use csv::StringRecord;
use regex::{Regex, RegexSet};
use std::fmt::{self, Debug};
use std::fs::File;
use std::io::{BufRead, BufReader as StdBufReader, Read};
use std::sync::{Arc, LazyLock};

use crate::map_csv_error;
use crate::reader::records::{RecordDecoder, StringRecords};
use arrow_array::timezone::Tz;

/// Order should match [`InferredDataType`]
static REGEX_SET: LazyLock<RegexSet> = LazyLock::new(|| {
    RegexSet::new([
        r"(?i)^(true)$|^(false)$(?-i)", //BOOLEAN
        r"^-?(\d+)$",                   //INTEGER
        r"^-?((\d*\.\d+|\d+\.\d*)([eE][-+]?\d+)?|\d+([eE][-+]?\d+))$", //DECIMAL
        r"^\d{4}-\d\d-\d\d$",           //DATE32
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d(?:[^\d\.].*)?$", //Timestamp(Second)
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,3}(?:[^\d].*)?$", //Timestamp(Millisecond)
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,6}(?:[^\d].*)?$", //Timestamp(Microsecond)
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,9}(?:[^\d].*)?$", //Timestamp(Nanosecond)
    ])
    .unwrap()
});

/// A wrapper over `Option<Regex>` to check if the value is `NULL`.
#[derive(Debug, Clone, Default)]
struct NullRegex(Option<Regex>);

impl NullRegex {
    /// Returns true if the value should be considered as `NULL` according to
    /// the provided regular expression.
    #[inline]
    fn is_null(&self, s: &str) -> bool {
        match &self.0 {
            Some(r) => r.is_match(s),
            None => s.is_empty(),
        }
    }
}
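
// For example, with the default `NullRegex(None)` only the empty string is
// considered `NULL`, whereas a pattern such as `(?i)^(null)$` would also
// treat "null" and "NULL" as `NULL` values.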

#[derive(Default, Copy, Clone)]
struct InferredDataType {
    /// Packed booleans indicating type
    ///
    /// 0 - Boolean
    /// 1 - Integer
    /// 2 - Float64
    /// 3 - Date32
    /// 4 - Timestamp(Second)
    /// 5 - Timestamp(Millisecond)
    /// 6 - Timestamp(Microsecond)
    /// 7 - Timestamp(Nanosecond)
    /// 8 - Utf8
    packed: u16,
}

impl InferredDataType {
    /// Returns the inferred data type
    fn get(&self) -> DataType {
        match self.packed {
            0 => DataType::Null,
            1 => DataType::Boolean,
            2 => DataType::Int64,
            4 | 6 => DataType::Float64, // Promote Int64 to Float64
            b if b != 0 && (b & !0b11111000) == 0 => match b.leading_zeros() {
                // Promote to highest precision temporal type
                8 => DataType::Timestamp(TimeUnit::Nanosecond, None),
                9 => DataType::Timestamp(TimeUnit::Microsecond, None),
                10 => DataType::Timestamp(TimeUnit::Millisecond, None),
                11 => DataType::Timestamp(TimeUnit::Second, None),
                12 => DataType::Date32,
                _ => unreachable!(),
            },
            _ => DataType::Utf8,
        }
    }

    /// Updates the [`InferredDataType`] with the given string
    fn update(&mut self, string: &str) {
        self.packed |= if string.starts_with('"') {
            1 << 8 // Utf8
        } else if let Some(m) = REGEX_SET.matches(string).into_iter().next() {
            if m == 1 && string.len() >= 19 && string.parse::<i64>().is_err() {
                // if the value overflows i64, fall back to Utf8
                1 << 8
            } else {
                1 << m
            }
        } else if string == "NaN" || string == "nan" || string == "inf" || string == "-inf" {
            1 << 2 // Float64
        } else {
            1 << 8 // Utf8
        }
    }
}
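
// For example: a column containing "1" and then "1.5" sets bits 1 (Integer) and
// 2 (Float64), so `packed == 0b110 == 6` and `get` resolves it to Float64 via
// the `4 | 6` arm above.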

/// The format specification for the CSV file
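///
/// # Example
///
/// A minimal sketch of configuring a [`Format`] (the options shown are
/// illustrative, not required):
///
/// ```
/// # use arrow_csv::reader::Format;
/// let format = Format::default()
///     .with_header(true)
///     .with_delimiter(b';')
///     .with_comment(b'#');
/// ```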
#[derive(Debug, Clone, Default)]
pub struct Format {
    header: bool,
    delimiter: Option<u8>,
    escape: Option<u8>,
    quote: Option<u8>,
    terminator: Option<u8>,
    comment: Option<u8>,
    null_regex: NullRegex,
    truncated_rows: bool,
}

impl Format {
    /// Specify whether the CSV file has a header, defaults to `false`
    ///
    /// When `true`, the first row of the CSV file is treated as a header row
    pub fn with_header(mut self, has_header: bool) -> Self {
        self.header = has_header;
        self
    }

    /// Specify a custom delimiter character, defaults to comma `','`
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.delimiter = Some(delimiter);
        self
    }

    /// Specify an escape character, defaults to `None`
    pub fn with_escape(mut self, escape: u8) -> Self {
        self.escape = Some(escape);
        self
    }

    /// Specify a custom quote character, defaults to double quote `'"'`
    pub fn with_quote(mut self, quote: u8) -> Self {
        self.quote = Some(quote);
        self
    }

    /// Specify a custom terminator character, defaults to CRLF
    pub fn with_terminator(mut self, terminator: u8) -> Self {
        self.terminator = Some(terminator);
        self
    }

    /// Specify a comment character, defaults to `None`
    ///
    /// Lines starting with this character will be ignored
    pub fn with_comment(mut self, comment: u8) -> Self {
        self.comment = Some(comment);
        self
    }

    /// Provide a regex to match null values, defaults to `^$`
    pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
        self.null_regex = NullRegex(Some(null_regex));
        self
    }

    /// Whether to allow truncated rows when parsing.
    ///
    /// By default this is `false`, and parsing errors if the CSV rows have different lengths.
    /// When set to `true`, records with fewer than the expected number of columns are allowed,
    /// and the missing columns are filled with nulls. If the corresponding fields in the schema
    /// are not nullable, an error is still returned.
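    ///
    /// A minimal sketch (the inline CSV and schema are illustrative); the second
    /// row is missing column `b`, which is filled with null:
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow_schema::{DataType, Field, Schema};
    /// # use arrow_csv::ReaderBuilder;
    /// # use arrow_csv::reader::Format;
    /// let schema = Arc::new(Schema::new(vec![
    ///     Field::new("a", DataType::Int64, true),
    ///     Field::new("b", DataType::Int64, true),
    /// ]));
    /// let format = Format::default().with_truncated_rows(true);
    /// let mut reader = ReaderBuilder::new(schema)
    ///     .with_format(format)
    ///     .build("1,2\n3\n".as_bytes())
    ///     .unwrap();
    /// let batch = reader.next().unwrap().unwrap();
    /// assert_eq!(batch.num_rows(), 2);
    /// ```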
    pub fn with_truncated_rows(mut self, allow: bool) -> Self {
        self.truncated_rows = allow;
        self
    }

    /// Infer the schema of CSV records from the provided `reader`
    ///
    /// If `max_records` is `None`, all records will be read, otherwise up to `max_records`
    /// records are read to infer the schema
    ///
    /// Returns the inferred schema and the number of records read
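    ///
    /// # Example
    ///
    /// A minimal sketch using in-memory data (the sample CSV is illustrative):
    ///
    /// ```
    /// # use arrow_schema::DataType;
    /// # use arrow_csv::reader::Format;
    /// let csv = "city,lat\nOslo,59.91\nBergen,60.39\n";
    /// let (schema, read) = Format::default()
    ///     .with_header(true)
    ///     .infer_schema(csv.as_bytes(), None)
    ///     .unwrap();
    /// assert_eq!(read, 2);
    /// assert_eq!(schema.field(1).data_type(), &DataType::Float64);
    /// ```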
    pub fn infer_schema<R: Read>(
        &self,
        reader: R,
        max_records: Option<usize>,
    ) -> Result<(Schema, usize), ArrowError> {
        let mut csv_reader = self.build_reader(reader);

        // get or create header names
        // when has_header is false, creates default column names with a `column_` prefix
        let headers: Vec<String> = if self.header {
            let headers = &csv_reader.headers().map_err(map_csv_error)?.clone();
            headers.iter().map(|s| s.to_string()).collect()
        } else {
            let first_record_count = &csv_reader.headers().map_err(map_csv_error)?.len();
            (0..*first_record_count)
                .map(|i| format!("column_{}", i + 1))
                .collect()
        };

        let header_length = headers.len();
        // keep track of inferred field types
        let mut column_types: Vec<InferredDataType> = vec![Default::default(); header_length];

        let mut records_count = 0;

        let mut record = StringRecord::new();
        let max_records = max_records.unwrap_or(usize::MAX);
        while records_count < max_records {
            if !csv_reader.read_record(&mut record).map_err(map_csv_error)? {
                break;
            }
            records_count += 1;

            // Note: since we may be looking at a sample of the data, we make the safe
            // assumption that all columns could be nullable
            for (i, column_type) in column_types.iter_mut().enumerate().take(header_length) {
                if let Some(string) = record.get(i) {
                    if !self.null_regex.is_null(string) {
                        column_type.update(string)
                    }
                }
            }
        }

        // build schema from inference results
        let fields: Fields = column_types
            .iter()
            .zip(&headers)
            .map(|(inferred, field_name)| Field::new(field_name, inferred.get(), true))
            .collect();

        Ok((Schema::new(fields), records_count))
    }

    /// Build a [`csv::Reader`] for this [`Format`]
    fn build_reader<R: Read>(&self, reader: R) -> csv::Reader<R> {
        let mut builder = csv::ReaderBuilder::new();
        builder.has_headers(self.header);
        builder.flexible(self.truncated_rows);

        if let Some(c) = self.delimiter {
            builder.delimiter(c);
        }
        builder.escape(self.escape);
        if let Some(c) = self.quote {
            builder.quote(c);
        }
        if let Some(t) = self.terminator {
            builder.terminator(csv::Terminator::Any(t));
        }
        if let Some(comment) = self.comment {
            builder.comment(Some(comment));
        }
        builder.from_reader(reader)
    }

    /// Build a [`csv_core::Reader`] for this [`Format`]
    fn build_parser(&self) -> csv_core::Reader {
        let mut builder = csv_core::ReaderBuilder::new();
        builder.escape(self.escape);
        builder.comment(self.comment);

        if let Some(c) = self.delimiter {
            builder.delimiter(c);
        }
        if let Some(c) = self.quote {
            builder.quote(c);
        }
        if let Some(t) = self.terminator {
            builder.terminator(csv_core::Terminator::Any(t));
        }
        builder.build()
    }
}

/// Infer the schema from a list of CSV files by reading up to `max_read_records`
/// records in total across the files.
///
/// Files will be read in the given order until `max_read_records` records have been read.
///
/// If `max_read_records` is not set, all files will be read fully to infer the schema.
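///
/// # Example
///
/// A sketch of merging schemas across several files (the paths are illustrative,
/// hence `no_run`):
///
/// ```no_run
/// # use arrow_csv::reader::infer_schema_from_files;
/// let files = vec!["a.csv".to_string(), "b.csv".to_string()];
/// let schema = infer_schema_from_files(&files, b',', Some(100), true).unwrap();
/// ```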
pub fn infer_schema_from_files(
    files: &[String],
    delimiter: u8,
    max_read_records: Option<usize>,
    has_header: bool,
) -> Result<Schema, ArrowError> {
    let mut schemas = vec![];
    let mut records_to_read = max_read_records.unwrap_or(usize::MAX);
    let format = Format {
        delimiter: Some(delimiter),
        header: has_header,
        ..Default::default()
    };

    for fname in files.iter() {
        let f = File::open(fname)?;
        let (schema, records_read) = format.infer_schema(f, Some(records_to_read))?;
        if records_read == 0 {
            continue;
        }
        schemas.push(schema.clone());
        records_to_read -= records_read;
        if records_to_read == 0 {
            break;
        }
    }

    Schema::try_merge(schemas)
}

// Optional bounds of the reader, of the form (min line, max line).
type Bounds = Option<(usize, usize)>;

/// CSV file reader using [`std::io::BufReader`]
pub type Reader<R> = BufReader<StdBufReader<R>>;

/// CSV file reader
pub struct BufReader<R> {
    /// File reader
    reader: R,

    /// The decoder
    decoder: Decoder,
}

impl<R> fmt::Debug for BufReader<R>
where
    R: BufRead,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Reader")
            .field("decoder", &self.decoder)
            .finish()
    }
}

impl<R: Read> Reader<R> {
    /// Returns the schema of the reader, useful for getting the schema without reading
    /// record batches
    pub fn schema(&self) -> SchemaRef {
        match &self.decoder.projection {
            Some(projection) => {
                let fields = self.decoder.schema.fields();
                let projected = projection.iter().map(|i| fields[*i].clone());
                Arc::new(Schema::new(projected.collect::<Fields>()))
            }
            None => self.decoder.schema.clone(),
        }
    }
}

impl<R: BufRead> BufReader<R> {
    fn read(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        loop {
            let buf = self.reader.fill_buf()?;
            let decoded = self.decoder.decode(buf)?;
            self.reader.consume(decoded);
            // Yield if decoded no bytes or the decoder is full
            //
            // The capacity check avoids looping around and potentially
            // blocking reading data in fill_buf that isn't needed
            // to flush the next batch
            if decoded == 0 || self.decoder.capacity() == 0 {
                break;
            }
        }

        self.decoder.flush()
    }
}

impl<R: BufRead> Iterator for BufReader<R> {
    type Item = Result<RecordBatch, ArrowError>;

    fn next(&mut self) -> Option<Self::Item> {
        self.read().transpose()
    }
}

impl<R: BufRead> RecordBatchReader for BufReader<R> {
    fn schema(&self) -> SchemaRef {
        self.decoder.schema.clone()
    }
}

/// A push-based interface for decoding CSV data from an arbitrary byte stream
///
/// See [`Reader`] for a higher-level interface for interfacing with [`Read`]
///
/// The push-based interface facilitates integration with sources that yield arbitrarily
/// delimited byte ranges, such as [`BufRead`], or a chunked byte stream received from
/// object storage
///
/// ```
/// # use std::io::BufRead;
/// # use arrow_array::RecordBatch;
/// # use arrow_csv::ReaderBuilder;
/// # use arrow_schema::{ArrowError, SchemaRef};
/// #
/// fn read_from_csv<R: BufRead>(
///     mut reader: R,
///     schema: SchemaRef,
///     batch_size: usize,
/// ) -> Result<impl Iterator<Item = Result<RecordBatch, ArrowError>>, ArrowError> {
///     let mut decoder = ReaderBuilder::new(schema)
///         .with_batch_size(batch_size)
///         .build_decoder();
///
///     let mut next = move || {
///         loop {
///             let buf = reader.fill_buf()?;
///             let decoded = decoder.decode(buf)?;
///             if decoded == 0 {
///                 break;
///             }
///
///             // Consume the number of bytes read
///             reader.consume(decoded);
///         }
///         decoder.flush()
///     };
///     Ok(std::iter::from_fn(move || next().transpose()))
/// }
/// ```
#[derive(Debug)]
pub struct Decoder {
    /// Explicit schema for the CSV file
    schema: SchemaRef,

    /// Optional projection for which columns to load (zero-based column indices)
    projection: Option<Vec<usize>>,

    /// Number of records per batch
    batch_size: usize,

    /// Rows to skip
    to_skip: usize,

    /// Current line number
    line_number: usize,

    /// End line number
    end: usize,

    /// A decoder for [`StringRecords`]
    record_decoder: RecordDecoder,

    /// Check if the string matches this pattern for `NULL`.
    null_regex: NullRegex,
}

impl Decoder {
    /// Decode records from `buf` returning the number of bytes read
    ///
    /// This method returns once `batch_size` records have been parsed since the
    /// last call to [`Self::flush`], or `buf` is exhausted. Any remaining bytes
    /// should be included in the next call to [`Self::decode`]
    ///
    /// There is no requirement that `buf` contains a whole number of records, facilitating
    /// integration with arbitrary byte streams, such as those yielded by [`BufRead`] or
    /// network sources such as object storage
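    ///
    /// A minimal sketch of driving the decoder by hand (the chunk boundaries are
    /// arbitrary; a real caller would feed whatever its source yields):
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow_schema::{DataType, Field, Schema};
    /// # use arrow_csv::ReaderBuilder;
    /// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
    /// let mut decoder = ReaderBuilder::new(schema).build_decoder();
    /// // Feed the input in two arbitrary chunks, then an empty slice to signal EOF
    /// decoder.decode(b"1\n2\n3").unwrap();
    /// decoder.decode(b"\n").unwrap();
    /// decoder.decode(&[]).unwrap();
    /// let batch = decoder.flush().unwrap().unwrap();
    /// assert_eq!(batch.num_rows(), 3);
    /// ```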
    pub fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
        if self.to_skip != 0 {
            // Skip in chunks of at most `batch_size` to avoid over-allocating buffers
            let to_skip = self.to_skip.min(self.batch_size);
            let (skipped, bytes) = self.record_decoder.decode(buf, to_skip)?;
            self.to_skip -= skipped;
            self.record_decoder.clear();
            return Ok(bytes);
        }

        let to_read = self.batch_size.min(self.end - self.line_number) - self.record_decoder.len();
        let (_, bytes) = self.record_decoder.decode(buf, to_read)?;
        Ok(bytes)
    }

    /// Flushes the currently buffered data to a [`RecordBatch`]
    ///
    /// This should only be called after [`Self::decode`] has returned `Ok(0)`,
    /// otherwise it may return an error if called part way through decoding a record
    ///
    /// Returns `Ok(None)` if there is no buffered data
    pub fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        if self.record_decoder.is_empty() {
            return Ok(None);
        }

        let rows = self.record_decoder.flush()?;
        let batch = parse(
            &rows,
            self.schema.fields(),
            Some(self.schema.metadata.clone()),
            self.projection.as_ref(),
            self.line_number,
            &self.null_regex,
        )?;
        self.line_number += rows.len();
        Ok(Some(batch))
    }

    /// Returns the number of records that can be read before requiring a call to [`Self::flush`]
    pub fn capacity(&self) -> usize {
        self.batch_size - self.record_decoder.len()
    }
}

/// Parses [`StringRecords`] into a [`RecordBatch`]
fn parse(
    rows: &StringRecords<'_>,
    fields: &Fields,
    metadata: Option<std::collections::HashMap<String, String>>,
    projection: Option<&Vec<usize>>,
    line_number: usize,
    null_regex: &NullRegex,
) -> Result<RecordBatch, ArrowError> {
    let projection: Vec<usize> = match projection {
        Some(v) => v.clone(),
        None => fields.iter().enumerate().map(|(i, _)| i).collect(),
    };

    let arrays: Result<Vec<ArrayRef>, _> = projection
        .iter()
        .map(|i| {
            let i = *i;
            let field = &fields[i];
            match field.data_type() {
                DataType::Boolean => build_boolean_array(line_number, rows, i, null_regex),
                DataType::Decimal32(precision, scale) => build_decimal_array::<Decimal32Type>(
                    line_number,
                    rows,
                    i,
                    *precision,
                    *scale,
                    null_regex,
                ),
                DataType::Decimal64(precision, scale) => build_decimal_array::<Decimal64Type>(
                    line_number,
                    rows,
                    i,
                    *precision,
                    *scale,
                    null_regex,
                ),
                DataType::Decimal128(precision, scale) => build_decimal_array::<Decimal128Type>(
                    line_number,
                    rows,
                    i,
                    *precision,
                    *scale,
                    null_regex,
                ),
                DataType::Decimal256(precision, scale) => build_decimal_array::<Decimal256Type>(
                    line_number,
                    rows,
                    i,
                    *precision,
                    *scale,
                    null_regex,
                ),
                DataType::Int8 => {
                    build_primitive_array::<Int8Type>(line_number, rows, i, null_regex)
                }
                DataType::Int16 => {
                    build_primitive_array::<Int16Type>(line_number, rows, i, null_regex)
                }
                DataType::Int32 => {
                    build_primitive_array::<Int32Type>(line_number, rows, i, null_regex)
                }
                DataType::Int64 => {
                    build_primitive_array::<Int64Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt8 => {
                    build_primitive_array::<UInt8Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt16 => {
                    build_primitive_array::<UInt16Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt32 => {
                    build_primitive_array::<UInt32Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt64 => {
                    build_primitive_array::<UInt64Type>(line_number, rows, i, null_regex)
                }
                DataType::Float32 => {
                    build_primitive_array::<Float32Type>(line_number, rows, i, null_regex)
                }
                DataType::Float64 => {
                    build_primitive_array::<Float64Type>(line_number, rows, i, null_regex)
                }
                DataType::Date32 => {
                    build_primitive_array::<Date32Type>(line_number, rows, i, null_regex)
                }
                DataType::Date64 => {
                    build_primitive_array::<Date64Type>(line_number, rows, i, null_regex)
                }
                DataType::Time32(TimeUnit::Second) => {
                    build_primitive_array::<Time32SecondType>(line_number, rows, i, null_regex)
                }
                DataType::Time32(TimeUnit::Millisecond) => {
                    build_primitive_array::<Time32MillisecondType>(line_number, rows, i, null_regex)
                }
                DataType::Time64(TimeUnit::Microsecond) => {
                    build_primitive_array::<Time64MicrosecondType>(line_number, rows, i, null_regex)
                }
                DataType::Time64(TimeUnit::Nanosecond) => {
                    build_primitive_array::<Time64NanosecondType>(line_number, rows, i, null_regex)
                }
                DataType::Timestamp(TimeUnit::Second, tz) => {
                    build_timestamp_array::<TimestampSecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Timestamp(TimeUnit::Millisecond, tz) => {
                    build_timestamp_array::<TimestampMillisecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Timestamp(TimeUnit::Microsecond, tz) => {
                    build_timestamp_array::<TimestampMicrosecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
                    build_timestamp_array::<TimestampNanosecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Null => Ok(Arc::new({
                    let mut builder = NullBuilder::new();
                    builder.append_nulls(rows.len());
                    builder.finish()
                }) as ArrayRef),
                DataType::Utf8 => Ok(Arc::new(
                    rows.iter()
                        .map(|row| {
                            let s = row.get(i);
                            (!null_regex.is_null(s)).then_some(s)
                        })
                        .collect::<StringArray>(),
                ) as ArrayRef),
                DataType::Utf8View => Ok(Arc::new(
                    rows.iter()
                        .map(|row| {
                            let s = row.get(i);
                            (!null_regex.is_null(s)).then_some(s)
                        })
                        .collect::<StringViewArray>(),
                ) as ArrayRef),
                DataType::Dictionary(key_type, value_type)
                    if value_type.as_ref() == &DataType::Utf8 =>
                {
                    match key_type.as_ref() {
                        DataType::Int8 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int8Type>>(),
                        ) as ArrayRef),
                        DataType::Int16 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int16Type>>(),
                        ) as ArrayRef),
                        DataType::Int32 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int32Type>>(),
                        ) as ArrayRef),
                        DataType::Int64 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int64Type>>(),
                        ) as ArrayRef),
                        DataType::UInt8 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt8Type>>(),
                        ) as ArrayRef),
                        DataType::UInt16 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt16Type>>(),
                        ) as ArrayRef),
                        DataType::UInt32 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt32Type>>(),
                        ) as ArrayRef),
                        DataType::UInt64 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt64Type>>(),
                        ) as ArrayRef),
                        _ => Err(ArrowError::ParseError(format!(
                            "Unsupported dictionary key type {key_type:?}"
                        ))),
                    }
                }
                other => Err(ArrowError::ParseError(format!(
                    "Unsupported data type {other:?}"
                ))),
            }
        })
        .collect();

    let projected_fields: Fields = projection.iter().map(|i| fields[*i].clone()).collect();

    let projected_schema = Arc::new(match metadata {
        None => Schema::new(projected_fields),
        Some(metadata) => Schema::new_with_metadata(projected_fields, metadata),
    });

    arrays.and_then(|arr| {
        RecordBatch::try_new_with_options(
            projected_schema,
            arr,
            &RecordBatchOptions::new()
                .with_match_field_names(true)
                .with_row_count(Some(rows.len())),
        )
    })
}

fn parse_bool(string: &str) -> Option<bool> {
    if string.eq_ignore_ascii_case("false") {
        Some(false)
    } else if string.eq_ignore_ascii_case("true") {
        Some(true)
    } else {
        None
    }
}

// Parses a specific column (col_idx) of decimal strings into a decimal Arrow Array.
fn build_decimal_array<T: DecimalType>(
    _line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    precision: u8,
    scale: i8,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    let mut decimal_builder = PrimitiveBuilder::<T>::with_capacity(rows.len());
    for row in rows.iter() {
        let s = row.get(col_idx);
        if null_regex.is_null(s) {
            // append null
            decimal_builder.append_null();
        } else {
            let decimal_value: Result<T::Native, _> = parse_decimal::<T>(s, precision, scale);
            match decimal_value {
                Ok(v) => {
                    decimal_builder.append_value(v);
                }
                Err(e) => {
                    return Err(e);
                }
            }
        }
    }
    Ok(Arc::new(
        decimal_builder
            .finish()
            .with_precision_and_scale(precision, scale)?,
    ))
}

// Parses a specific column (col_idx) into an Arrow Array.
fn build_primitive_array<T: ArrowPrimitiveType + Parser>(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    rows.iter()
        .enumerate()
        .map(|(row_index, row)| {
            let s = row.get(col_idx);
            if null_regex.is_null(s) {
                return Ok(None);
            }

            match T::parse(s) {
                Some(e) => Ok(Some(e)),
                None => Err(ArrowError::ParseError(format!(
                    // TODO: we should surface the underlying error here.
                    "Error while parsing value '{}' as type '{}' for column {} at line {}. Row data: '{}'",
                    s,
                    T::DATA_TYPE,
                    col_idx,
                    line_number + row_index,
                    row
                ))),
            }
        })
        .collect::<Result<PrimitiveArray<T>, ArrowError>>()
        .map(|e| Arc::new(e) as ArrayRef)
}

fn build_timestamp_array<T: ArrowTimestampType>(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    timezone: Option<&str>,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    Ok(Arc::new(match timezone {
        Some(timezone) => {
            let tz: Tz = timezone.parse()?;
            build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &tz, null_regex)?
                .with_timezone(timezone)
        }
        None => build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &Utc, null_regex)?,
    }))
}

fn build_timestamp_array_impl<T: ArrowTimestampType, Tz: TimeZone>(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    timezone: &Tz,
    null_regex: &NullRegex,
) -> Result<PrimitiveArray<T>, ArrowError> {
    rows.iter()
        .enumerate()
        .map(|(row_index, row)| {
            let s = row.get(col_idx);
            if null_regex.is_null(s) {
                return Ok(None);
            }

            let date = string_to_datetime(timezone, s)
                .and_then(|date| match T::UNIT {
                    TimeUnit::Second => Ok(date.timestamp()),
                    TimeUnit::Millisecond => Ok(date.timestamp_millis()),
                    TimeUnit::Microsecond => Ok(date.timestamp_micros()),
                    TimeUnit::Nanosecond => date.timestamp_nanos_opt().ok_or_else(|| {
                        ArrowError::ParseError(format!(
                            "{} would overflow 64-bit signed nanoseconds",
                            date.to_rfc3339(),
                        ))
                    }),
                })
                .map_err(|e| {
                    ArrowError::ParseError(format!(
                        "Error parsing column {col_idx} at line {}: {}",
                        line_number + row_index,
                        e
                    ))
                })?;
            Ok(Some(date))
        })
        .collect()
}

// Parses a specific column (col_idx) into an Arrow Array.
fn build_boolean_array(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    rows.iter()
        .enumerate()
        .map(|(row_index, row)| {
            let s = row.get(col_idx);
            if null_regex.is_null(s) {
                return Ok(None);
            }
            let parsed = parse_bool(s);
            match parsed {
                Some(e) => Ok(Some(e)),
                None => Err(ArrowError::ParseError(format!(
                    // TODO: we should surface the underlying error here.
                    "Error while parsing value '{}' as type '{}' for column {} at line {}. Row data: '{}'",
                    s,
                    "Boolean",
                    col_idx,
                    line_number + row_index,
                    row
                ))),
            }
        })
        .collect::<Result<BooleanArray, _>>()
        .map(|e| Arc::new(e) as ArrayRef)
}

/// CSV file reader builder
#[derive(Debug)]
pub struct ReaderBuilder {
    /// Schema of the CSV file
    schema: SchemaRef,
    /// Format of the CSV file
    format: Format,
    /// Batch size (number of records to load each time)
    ///
    /// The default batch size when using the `ReaderBuilder` is 1024 records
    batch_size: usize,
    /// The bounds over which to scan the reader. `None` starts from 0 and runs until EOF.
    bounds: Bounds,
    /// Optional projection for which columns to load (zero-based column indices)
    projection: Option<Vec<usize>>,
}

impl ReaderBuilder {
    /// Create a new builder for configuring CSV parsing options.
    ///
    /// To convert a builder into a reader, call `ReaderBuilder::build`
    ///
    /// # Example
    ///
    /// ```
    /// # use arrow_csv::{Reader, ReaderBuilder};
    /// # use std::fs::File;
    /// # use std::io::Seek;
    /// # use std::sync::Arc;
    /// # use arrow_csv::reader::Format;
    /// #
    /// let mut file = File::open("test/data/uk_cities_with_headers.csv").unwrap();
    /// // Infer the schema with the first 100 records
    /// let (schema, _) = Format::default().infer_schema(&mut file, Some(100)).unwrap();
    /// file.rewind().unwrap();
    ///
    /// // create a builder
    /// ReaderBuilder::new(Arc::new(schema)).build(file).unwrap();
    /// ```
    pub fn new(schema: SchemaRef) -> ReaderBuilder {
        Self {
            schema,
            format: Format::default(),
            batch_size: 1024,
            bounds: None,
            projection: None,
        }
    }

    /// Set whether the CSV file has a header
    pub fn with_header(mut self, has_header: bool) -> Self {
        self.format.header = has_header;
        self
    }

    /// Overrides the [Format] of this [ReaderBuilder]
    pub fn with_format(mut self, format: Format) -> Self {
        self.format = format;
        self
    }

    /// Set the CSV file's column delimiter as a byte character
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.format.delimiter = Some(delimiter);
        self
    }

    /// Set the given character as the CSV file's escape character
    pub fn with_escape(mut self, escape: u8) -> Self {
        self.format.escape = Some(escape);
        self
    }

    /// Set the given character as the CSV file's quote character, by default it is double quote
    pub fn with_quote(mut self, quote: u8) -> Self {
        self.format.quote = Some(quote);
        self
    }

    /// Provide a custom terminator character, defaults to CRLF
    pub fn with_terminator(mut self, terminator: u8) -> Self {
        self.format.terminator = Some(terminator);
        self
    }

    /// Provide a comment character, lines starting with this character will be ignored
    pub fn with_comment(mut self, comment: u8) -> Self {
        self.format.comment = Some(comment);
        self
    }

    /// Provide a regex to match null values, defaults to `^$`
    pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
        self.format.null_regex = NullRegex(Some(null_regex));
        self
    }

    /// Set the batch size (number of records to load at one time)
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Set the bounds over which to scan the reader.
    /// `start` and `end` are line numbers.
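    ///
    /// A sketch: skip the first 100 data lines and stop at line 200 (bounds are
    /// applied after any header row; the schema here is illustrative):
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow_schema::{DataType, Field, Schema};
    /// # use arrow_csv::ReaderBuilder;
    /// # let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
    /// let builder = ReaderBuilder::new(schema).with_bounds(100, 200);
    /// ```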
    pub fn with_bounds(mut self, start: usize, end: usize) -> Self {
        self.bounds = Some((start, end));
        self
    }

    /// Set the reader's column projection
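    ///
    /// A sketch: load only the first and third columns (indices are zero-based;
    /// the schema is illustrative):
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow_schema::{DataType, Field, Schema};
    /// # use arrow_csv::ReaderBuilder;
    /// # let schema = Arc::new(Schema::new(vec![
    /// #     Field::new("a", DataType::Utf8, true),
    /// #     Field::new("b", DataType::Utf8, true),
    /// #     Field::new("c", DataType::Utf8, true),
    /// # ]));
    /// let builder = ReaderBuilder::new(schema).with_projection(vec![0, 2]);
    /// ```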
    pub fn with_projection(mut self, projection: Vec<usize>) -> Self {
        self.projection = Some(projection);
        self
    }

    /// Whether to allow truncated rows when parsing.
    ///
    /// By default this is `false`, and parsing errors if the CSV rows have different lengths.
    /// When set to `true`, records with fewer than the expected number of columns are allowed,
    /// and the missing columns are filled with nulls. If the corresponding fields in the schema
    /// are not nullable, an error is still returned.
    pub fn with_truncated_rows(mut self, allow: bool) -> Self {
        self.format.truncated_rows = allow;
        self
    }

    /// Create a new `Reader` from a non-buffered reader
    ///
    /// If `R: BufRead` consider using [`Self::build_buffered`] to avoid unnecessary additional
    /// buffering, as internally this method wraps `reader` in [`std::io::BufReader`]
    pub fn build<R: Read>(self, reader: R) -> Result<Reader<R>, ArrowError> {
        self.build_buffered(StdBufReader::new(reader))
    }

    /// Create a new `BufReader` from a buffered reader
    pub fn build_buffered<R: BufRead>(self, reader: R) -> Result<BufReader<R>, ArrowError> {
        Ok(BufReader {
            reader,
            decoder: self.build_decoder(),
        })
    }

    /// Builds a decoder that can be used to decode CSV from an arbitrary byte stream
    pub fn build_decoder(self) -> Decoder {
        let delimiter = self.format.build_parser();
        let record_decoder = RecordDecoder::new(
            delimiter,
            self.schema.fields().len(),
            self.format.truncated_rows,
        );

        let header = self.format.header as usize;

        let (start, end) = match self.bounds {
            Some((start, end)) => (start + header, end + header),
            None => (header, usize::MAX),
        };

        Decoder {
            schema: self.schema,
            to_skip: start,
            record_decoder,
            line_number: start,
            end,
            projection: self.projection,
            batch_size: self.batch_size,
            null_regex: self.format.null_regex,
        }
    }
}
1227
1228#[cfg(test)]
1229mod tests {
1230    use super::*;
1231
1232    use std::io::{Cursor, Seek, SeekFrom, Write};
1233    use tempfile::NamedTempFile;
1234
1235    use arrow_array::cast::AsArray;
1236
1237    #[test]
1238    fn test_csv() {
1239        let schema = Arc::new(Schema::new(vec![
1240            Field::new("city", DataType::Utf8, false),
1241            Field::new("lat", DataType::Float64, false),
1242            Field::new("lng", DataType::Float64, false),
1243        ]));
1244
1245        let file = File::open("test/data/uk_cities.csv").unwrap();
1246        let mut csv = ReaderBuilder::new(schema.clone()).build(file).unwrap();
1247        assert_eq!(schema, csv.schema());
1248        let batch = csv.next().unwrap().unwrap();
1249        assert_eq!(37, batch.num_rows());
1250        assert_eq!(3, batch.num_columns());
1251
1252        // access data from a primitive array
1253        let lat = batch.column(1).as_primitive::<Float64Type>();
1254        assert_eq!(57.653484, lat.value(0));
1255
1256        // access data from a string array (ListArray<u8>)
1257        let city = batch.column(0).as_string::<i32>();
1258
1259        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
1260    }
1261
1262    #[test]
1263    fn test_csv_schema_metadata() {
1264        let mut metadata = std::collections::HashMap::new();
1265        metadata.insert("foo".to_owned(), "bar".to_owned());
1266        let schema = Arc::new(Schema::new_with_metadata(
1267            vec![
1268                Field::new("city", DataType::Utf8, false),
1269                Field::new("lat", DataType::Float64, false),
1270                Field::new("lng", DataType::Float64, false),
1271            ],
1272            metadata.clone(),
1273        ));
1274
1275        let file = File::open("test/data/uk_cities.csv").unwrap();
1276
1277        let mut csv = ReaderBuilder::new(schema.clone()).build(file).unwrap();
1278        assert_eq!(schema, csv.schema());
1279        let batch = csv.next().unwrap().unwrap();
1280        assert_eq!(37, batch.num_rows());
1281        assert_eq!(3, batch.num_columns());
1282
1283        assert_eq!(&metadata, batch.schema().metadata());
1284    }
1285
1286    #[test]
1287    fn test_csv_reader_with_decimal() {
1288        let schema = Arc::new(Schema::new(vec![
1289            Field::new("city", DataType::Utf8, false),
1290            Field::new("lat", DataType::Decimal128(38, 6), false),
1291            Field::new("lng", DataType::Decimal256(76, 6), false),
1292        ]));
1293
1294        let file = File::open("test/data/decimal_test.csv").unwrap();
1295
1296        let mut csv = ReaderBuilder::new(schema).build(file).unwrap();
1297        let batch = csv.next().unwrap().unwrap();
1298        // access data from a primitive array
1299        let lat = batch
1300            .column(1)
1301            .as_any()
1302            .downcast_ref::<Decimal128Array>()
1303            .unwrap();
1304
1305        assert_eq!("57.653484", lat.value_as_string(0));
1306        assert_eq!("53.002666", lat.value_as_string(1));
1307        assert_eq!("52.412811", lat.value_as_string(2));
1308        assert_eq!("51.481583", lat.value_as_string(3));
1309        assert_eq!("12.123456", lat.value_as_string(4));
1310        assert_eq!("50.760000", lat.value_as_string(5));
1311        assert_eq!("0.123000", lat.value_as_string(6));
1312        assert_eq!("123.000000", lat.value_as_string(7));
1313        assert_eq!("123.000000", lat.value_as_string(8));
1314        assert_eq!("-50.760000", lat.value_as_string(9));
1315
1316        let lng = batch
1317            .column(2)
1318            .as_any()
1319            .downcast_ref::<Decimal256Array>()
1320            .unwrap();
1321
1322        assert_eq!("-3.335724", lng.value_as_string(0));
1323        assert_eq!("-2.179404", lng.value_as_string(1));
1324        assert_eq!("-1.778197", lng.value_as_string(2));
1325        assert_eq!("-3.179090", lng.value_as_string(3));
1326        assert_eq!("-3.179090", lng.value_as_string(4));
1327        assert_eq!("0.290472", lng.value_as_string(5));
1328        assert_eq!("0.290472", lng.value_as_string(6));
1329        assert_eq!("0.290472", lng.value_as_string(7));
1330        assert_eq!("0.290472", lng.value_as_string(8));
1331        assert_eq!("0.290472", lng.value_as_string(9));
1332    }
1333
1334    #[test]
1335    fn test_csv_reader_with_decimal_3264() {
1336        let schema = Arc::new(Schema::new(vec![
1337            Field::new("city", DataType::Utf8, false),
1338            Field::new("lat", DataType::Decimal32(9, 6), false),
1339            Field::new("lng", DataType::Decimal64(16, 6), false),
1340        ]));
1341
1342        let file = File::open("test/data/decimal_test.csv").unwrap();
1343
1344        let mut csv = ReaderBuilder::new(schema).build(file).unwrap();
1345        let batch = csv.next().unwrap().unwrap();
1346        // access data from a primitive array
1347        let lat = batch
1348            .column(1)
1349            .as_any()
1350            .downcast_ref::<Decimal32Array>()
1351            .unwrap();
1352
1353        assert_eq!("57.653484", lat.value_as_string(0));
1354        assert_eq!("53.002666", lat.value_as_string(1));
1355        assert_eq!("52.412811", lat.value_as_string(2));
1356        assert_eq!("51.481583", lat.value_as_string(3));
1357        assert_eq!("12.123456", lat.value_as_string(4));
1358        assert_eq!("50.760000", lat.value_as_string(5));
1359        assert_eq!("0.123000", lat.value_as_string(6));
1360        assert_eq!("123.000000", lat.value_as_string(7));
1361        assert_eq!("123.000000", lat.value_as_string(8));
1362        assert_eq!("-50.760000", lat.value_as_string(9));
1363
1364        let lng = batch
1365            .column(2)
1366            .as_any()
1367            .downcast_ref::<Decimal64Array>()
1368            .unwrap();
1369
1370        assert_eq!("-3.335724", lng.value_as_string(0));
1371        assert_eq!("-2.179404", lng.value_as_string(1));
1372        assert_eq!("-1.778197", lng.value_as_string(2));
1373        assert_eq!("-3.179090", lng.value_as_string(3));
1374        assert_eq!("-3.179090", lng.value_as_string(4));
1375        assert_eq!("0.290472", lng.value_as_string(5));
1376        assert_eq!("0.290472", lng.value_as_string(6));
1377        assert_eq!("0.290472", lng.value_as_string(7));
1378        assert_eq!("0.290472", lng.value_as_string(8));
1379        assert_eq!("0.290472", lng.value_as_string(9));
1380    }
1381
1382    #[test]
1383    fn test_csv_from_buf_reader() {
1384        let schema = Schema::new(vec![
1385            Field::new("city", DataType::Utf8, false),
1386            Field::new("lat", DataType::Float64, false),
1387            Field::new("lng", DataType::Float64, false),
1388        ]);
1389
1390        let file_with_headers = File::open("test/data/uk_cities_with_headers.csv").unwrap();
1391        let file_without_headers = File::open("test/data/uk_cities.csv").unwrap();
1392        let both_files = file_with_headers
1393            .chain(Cursor::new("\n".to_string()))
1394            .chain(file_without_headers);
1395        let mut csv = ReaderBuilder::new(Arc::new(schema))
1396            .with_header(true)
1397            .build(both_files)
1398            .unwrap();
1399        let batch = csv.next().unwrap().unwrap();
1400        assert_eq!(74, batch.num_rows());
1401        assert_eq!(3, batch.num_columns());
1402    }
1403
1404    #[test]
1405    fn test_csv_with_schema_inference() {
1406        let mut file = File::open("test/data/uk_cities_with_headers.csv").unwrap();
1407
1408        let (schema, _) = Format::default()
1409            .with_header(true)
1410            .infer_schema(&mut file, None)
1411            .unwrap();
1412
1413        file.rewind().unwrap();
1414        let builder = ReaderBuilder::new(Arc::new(schema)).with_header(true);
1415
1416        let mut csv = builder.build(file).unwrap();
1417        let expected_schema = Schema::new(vec![
1418            Field::new("city", DataType::Utf8, true),
1419            Field::new("lat", DataType::Float64, true),
1420            Field::new("lng", DataType::Float64, true),
1421        ]);
1422        assert_eq!(Arc::new(expected_schema), csv.schema());
1423        let batch = csv.next().unwrap().unwrap();
1424        assert_eq!(37, batch.num_rows());
1425        assert_eq!(3, batch.num_columns());
1426
1427        // access data from a primitive array
1428        let lat = batch
1429            .column(1)
1430            .as_any()
1431            .downcast_ref::<Float64Array>()
1432            .unwrap();
1433        assert_eq!(57.653484, lat.value(0));
1434
1435        // access data from a string array (ListArray<u8>)
1436        let city = batch
1437            .column(0)
1438            .as_any()
1439            .downcast_ref::<StringArray>()
1440            .unwrap();
1441
1442        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
1443    }
1444
    #[test]
    fn test_csv_with_schema_inference_no_headers() {
        let mut file = File::open("test/data/uk_cities.csv").unwrap();

        let (schema, _) = Format::default().infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let mut csv = ReaderBuilder::new(Arc::new(schema)).build(file).unwrap();

        // csv field names should be 'column_{number}'
        let schema = csv.schema();
        assert_eq!("column_1", schema.field(0).name());
        assert_eq!("column_2", schema.field(1).name());
        assert_eq!("column_3", schema.field(2).name());
        let batch = csv.next().unwrap().unwrap();
        let batch_schema = batch.schema();

        assert_eq!(schema, batch_schema);
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        // access data from a primitive array
        let lat = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(57.653484, lat.value(0));

        // access data from a string array
        let city = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
    }

    #[test]
    fn test_csv_builder_with_bounds() {
        let mut file = File::open("test/data/uk_cities.csv").unwrap();

        // Restrict the reader to rows 0 and 1 (the end bound is exclusive)
        let (schema, _) = Format::default().infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();
        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_bounds(0, 2)
            .build(file)
            .unwrap();
        let batch = csv.next().unwrap().unwrap();

        // access data from a string array
        let city = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        // The value on line 0 is within the bounds
        assert_eq!("Elgin, Scotland, the UK", city.value(0));

        // The value on line 13 is outside of the bounds. Therefore
        // the call to .value() will panic.
        let result = std::panic::catch_unwind(|| city.value(13));
        assert!(result.is_err());
    }

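    // Projection selects a subset of the schema's columns by index; both the
    // reader's reported schema and the decoded batches should reflect it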
    #[test]
    fn test_csv_with_projection() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]));

        let file = File::open("test/data/uk_cities.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_projection(vec![0, 1])
            .build(file)
            .unwrap();

        let projected_schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
        ]));
        assert_eq!(projected_schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(projected_schema, batch.schema());
        assert_eq!(37, batch.num_rows());
        assert_eq!(2, batch.num_columns());
    }

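    // String columns can be read directly into dictionary-encoded arrays;
    // casting the dictionary column back to Utf8 recovers the original values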
    #[test]
    fn test_csv_with_dictionary() {
        let schema = Arc::new(Schema::new(vec![
            Field::new_dictionary("city", DataType::Int32, DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]));

        let file = File::open("test/data/uk_cities.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_projection(vec![0, 1])
            .build(file)
            .unwrap();

        let projected_schema = Arc::new(Schema::new(vec![
            Field::new_dictionary("city", DataType::Int32, DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
        ]));
        assert_eq!(projected_schema, csv.schema());
        let batch = csv.next().unwrap().unwrap();
        assert_eq!(projected_schema, batch.schema());
        assert_eq!(37, batch.num_rows());
        assert_eq!(2, batch.num_columns());

        let strings = arrow_cast::cast(batch.column(0), &DataType::Utf8).unwrap();
        let strings = strings.as_string::<i32>();

        assert_eq!(strings.value(0), "Elgin, Scotland, the UK");
        assert_eq!(strings.value(4), "Eastbourne, East Sussex, UK");
        assert_eq!(strings.value(29), "Uckfield, East Sussex, UK");
    }

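    // Nulls should be preserved in dictionary columns regardless of the
    // integer type used for the dictionary keys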
    #[test]
    fn test_csv_with_nullable_dictionary() {
        let offset_type = vec![
            DataType::Int8,
            DataType::Int16,
            DataType::Int32,
            DataType::Int64,
            DataType::UInt8,
            DataType::UInt16,
            DataType::UInt32,
            DataType::UInt64,
        ];
        for data_type in offset_type {
            let file = File::open("test/data/dictionary_nullable_test.csv").unwrap();
            let dictionary_type =
                DataType::Dictionary(Box::new(data_type), Box::new(DataType::Utf8));
            let schema = Arc::new(Schema::new(vec![
                Field::new("id", DataType::Utf8, false),
                Field::new("name", dictionary_type.clone(), true),
            ]));

            let mut csv = ReaderBuilder::new(schema)
                .build(file.try_clone().unwrap())
                .unwrap();

            let batch = csv.next().unwrap().unwrap();
            assert_eq!(3, batch.num_rows());
            assert_eq!(2, batch.num_columns());

            let names = arrow_cast::cast(batch.column(1), &dictionary_type).unwrap();
            assert!(!names.is_null(2));
            assert!(names.is_null(1));
        }
    }

    #[test]
    fn test_nulls() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c_int", DataType::UInt64, false),
            Field::new("c_float", DataType::Float32, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, false),
        ]));

        let file = File::open("test/data/null_test.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_header(true)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        assert!(!batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }

    #[test]
    fn test_init_nulls() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c_int", DataType::UInt64, true),
            Field::new("c_float", DataType::Float32, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
            Field::new("c_null", DataType::Null, true),
        ]));
        let file = File::open("test/data/init_null_test.csv").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_header(true)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        assert!(batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }

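    // Inference over the same file should pick the widest numeric types
    // (Int64, Float64) and infer an entirely-null column as DataType::Null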
    #[test]
    fn test_init_nulls_with_inference() {
        let format = Format::default().with_header(true).with_delimiter(b',');

        let mut file = File::open("test/data/init_null_test.csv").unwrap();
        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let expected_schema = Schema::new(vec![
            Field::new("c_int", DataType::Int64, true),
            Field::new("c_float", DataType::Float64, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
            Field::new("c_null", DataType::Null, true),
        ]);
        assert_eq!(schema, expected_schema);

        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        assert!(batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }

    #[test]
    fn test_custom_nulls() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c_int", DataType::UInt64, true),
            Field::new("c_float", DataType::Float32, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
        ]));

        let file = File::open("test/data/custom_null_test.csv").unwrap();

        let null_regex = Regex::new("^nil$").unwrap();

        let mut csv = ReaderBuilder::new(schema)
            .with_header(true)
            .with_null_regex(null_regex)
            .build(file)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();

        // "nil"s should be NULL
        assert!(batch.column(0).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(batch.column(3).is_null(4));
        assert!(batch.column(2).is_null(3));
        assert!(!batch.column(2).is_null(4));
    }

    #[test]
    fn test_nulls_with_inference() {
        let mut file = File::open("test/data/various_types.csv").unwrap();
        let format = Format::default().with_header(true).with_delimiter(b'|');

        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3, 4, 5]);

        let mut csv = builder.build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();

        assert_eq!(10, batch.num_rows());
        assert_eq!(6, batch.num_columns());

        let schema = batch.schema();

        assert_eq!(&DataType::Int64, schema.field(0).data_type());
        assert_eq!(&DataType::Float64, schema.field(1).data_type());
        assert_eq!(&DataType::Float64, schema.field(2).data_type());
        assert_eq!(&DataType::Boolean, schema.field(3).data_type());
        assert_eq!(&DataType::Date32, schema.field(4).data_type());
        assert_eq!(
            &DataType::Timestamp(TimeUnit::Second, None),
            schema.field(5).data_type()
        );

        let names: Vec<&str> = schema.fields().iter().map(|x| x.name().as_str()).collect();
        assert_eq!(
            names,
            vec![
                "c_int",
                "c_float",
                "c_string",
                "c_bool",
                "c_date",
                "c_datetime"
            ]
        );

        assert!(schema.field(0).is_nullable());
        assert!(schema.field(1).is_nullable());
        assert!(schema.field(2).is_nullable());
        assert!(schema.field(3).is_nullable());
        assert!(schema.field(4).is_nullable());
        assert!(schema.field(5).is_nullable());

        assert!(!batch.column(1).is_null(0));
        assert!(!batch.column(1).is_null(1));
        assert!(batch.column(1).is_null(2));
        assert!(!batch.column(1).is_null(3));
        assert!(!batch.column(1).is_null(4));
    }

    #[test]
    fn test_custom_nulls_with_inference() {
        let mut file = File::open("test/data/custom_null_test.csv").unwrap();

        let null_regex = Regex::new("^nil$").unwrap();

        let format = Format::default()
            .with_header(true)
            .with_null_regex(null_regex);

        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let expected_schema = Schema::new(vec![
            Field::new("c_int", DataType::Int64, true),
            Field::new("c_float", DataType::Float64, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
        ]);

        assert_eq!(schema, expected_schema);

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3]);

        let mut csv = builder.build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();

        assert_eq!(5, batch.num_rows());
        assert_eq!(4, batch.num_columns());

        assert_eq!(batch.schema().as_ref(), &expected_schema);
    }

    #[test]
    fn test_scientific_notation_with_inference() {
        let mut file = File::open("test/data/scientific_notation_test.csv").unwrap();
        let format = Format::default().with_header(false).with_delimiter(b',');

        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
        file.rewind().unwrap();

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_format(format)
            .with_batch_size(512)
            .with_projection(vec![0, 1]);

        let mut csv = builder.build(file).unwrap();
        let batch = csv.next().unwrap().unwrap();

        let schema = batch.schema();

        assert_eq!(&DataType::Float64, schema.field(0).data_type());
    }

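    /// Reads the given file against a fixed schema and returns the resulting
    /// parse error as a string, for the invalid-input tests below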
    fn invalid_csv_helper(file_name: &str) -> String {
        let file = File::open(file_name).unwrap();
        let schema = Schema::new(vec![
            Field::new("c_int", DataType::UInt64, false),
            Field::new("c_float", DataType::Float32, false),
            Field::new("c_string", DataType::Utf8, false),
            Field::new("c_bool", DataType::Boolean, false),
        ]);

        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(true)
            .with_delimiter(b'|')
            .with_batch_size(512)
            .with_projection(vec![0, 1, 2, 3]);

        let mut csv = builder.build(file).unwrap();

        csv.next().unwrap().unwrap_err().to_string()
    }

    #[test]
    fn test_parse_invalid_csv_float() {
        let file_name = "test/data/various_invalid_types/invalid_float.csv";

        let error = invalid_csv_helper(file_name);
        assert_eq!("Parser error: Error while parsing value '4.x4' as type 'Float32' for column 1 at line 4. Row data: '[4,4.x4,,false]'", error);
    }

    #[test]
    fn test_parse_invalid_csv_int() {
        let file_name = "test/data/various_invalid_types/invalid_int.csv";

        let error = invalid_csv_helper(file_name);
        assert_eq!("Parser error: Error while parsing value '2.3' as type 'UInt64' for column 0 at line 2. Row data: '[2.3,2.2,2.22,false]'", error);
    }

    #[test]
    fn test_parse_invalid_csv_bool() {
        let file_name = "test/data/various_invalid_types/invalid_bool.csv";

        let error = invalid_csv_helper(file_name);
        assert_eq!("Parser error: Error while parsing value 'none' as type 'Boolean' for column 3 at line 2. Row data: '[2,2.2,2.22,none]'", error);
    }

    /// Infer the data type of a record
    fn infer_field_schema(string: &str) -> DataType {
        let mut v = InferredDataType::default();
        v.update(string);
        v.get()
    }

    #[test]
    fn test_infer_field_schema() {
        assert_eq!(infer_field_schema("A"), DataType::Utf8);
        assert_eq!(infer_field_schema("\"123\""), DataType::Utf8);
        assert_eq!(infer_field_schema("10"), DataType::Int64);
        assert_eq!(infer_field_schema("10.2"), DataType::Float64);
        assert_eq!(infer_field_schema(".2"), DataType::Float64);
        assert_eq!(infer_field_schema("2."), DataType::Float64);
        assert_eq!(infer_field_schema("NaN"), DataType::Float64);
        assert_eq!(infer_field_schema("nan"), DataType::Float64);
        assert_eq!(infer_field_schema("inf"), DataType::Float64);
        assert_eq!(infer_field_schema("-inf"), DataType::Float64);
        assert_eq!(infer_field_schema("true"), DataType::Boolean);
        assert_eq!(infer_field_schema("trUe"), DataType::Boolean);
        assert_eq!(infer_field_schema("false"), DataType::Boolean);
        assert_eq!(infer_field_schema("2020-11-08"), DataType::Date32);
        assert_eq!(
            infer_field_schema("2020-11-08T14:20:01"),
            DataType::Timestamp(TimeUnit::Second, None)
        );
        assert_eq!(
            infer_field_schema("2020-11-08 14:20:01"),
            DataType::Timestamp(TimeUnit::Second, None)
        );
        assert_eq!(infer_field_schema("-5.13"), DataType::Float64);
        assert_eq!(infer_field_schema("0.1300"), DataType::Float64);
        assert_eq!(
            infer_field_schema("2021-12-19 13:12:30.921"),
            DataType::Timestamp(TimeUnit::Millisecond, None)
        );
        assert_eq!(
            infer_field_schema("2021-12-19T13:12:30.123456789"),
            DataType::Timestamp(TimeUnit::Nanosecond, None)
        );
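        // Integers outside the i64 range cannot be parsed and fall back to Utf8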
        assert_eq!(infer_field_schema("-9223372036854775809"), DataType::Utf8);
        assert_eq!(infer_field_schema("9223372036854775808"), DataType::Utf8);
    }

    #[test]
    fn parse_date32() {
        assert_eq!(Date32Type::parse("1970-01-01").unwrap(), 0);
        assert_eq!(Date32Type::parse("2020-03-15").unwrap(), 18336);
        assert_eq!(Date32Type::parse("1945-05-08").unwrap(), -9004);
    }

    #[test]
    fn parse_time() {
        assert_eq!(
            Time64NanosecondType::parse("12:10:01.123456789 AM"),
            Some(601_123_456_789)
        );
        assert_eq!(
            Time64MicrosecondType::parse("12:10:01.123456 am"),
            Some(601_123_456)
        );
        assert_eq!(
            Time32MillisecondType::parse("2:10:01.12 PM"),
            Some(51_001_120)
        );
        assert_eq!(Time32SecondType::parse("2:10:01 pm"), Some(51_001));
    }

    #[test]
    fn parse_date64() {
        assert_eq!(Date64Type::parse("1970-01-01T00:00:00").unwrap(), 0);
        assert_eq!(
            Date64Type::parse("2018-11-13T17:11:10").unwrap(),
            1542129070000
        );
        assert_eq!(
            Date64Type::parse("2018-11-13T17:11:10.011").unwrap(),
            1542129070011
        );
        assert_eq!(
            Date64Type::parse("1900-02-28T12:34:56").unwrap(),
            -2203932304000
        );
        assert_eq!(
            Date64Type::parse_formatted("1900-02-28 12:34:56", "%Y-%m-%d %H:%M:%S").unwrap(),
            -2203932304000
        );
        assert_eq!(
            Date64Type::parse_formatted("1900-02-28 12:34:56+0030", "%Y-%m-%d %H:%M:%S%z").unwrap(),
            -2203932304000 - (30 * 60 * 1000)
        );
    }

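    // Decodes three timestamp strings with the lower-level `Decoder`: the
    // trailing `decode(&[])` signals end-of-input so the final record is
    // delimited before `flush` produces the batch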
    fn test_parse_timestamp_impl<T: ArrowTimestampType>(
        timezone: Option<Arc<str>>,
        expected: &[i64],
    ) {
        let csv = [
            "1970-01-01T00:00:00",
            "1970-01-01T00:00:00Z",
            "1970-01-01T00:00:00+02:00",
        ]
        .join("\n");
        let schema = Arc::new(Schema::new(vec![Field::new(
            "field",
            DataType::Timestamp(T::UNIT, timezone.clone()),
            true,
        )]));

        let mut decoder = ReaderBuilder::new(schema).build_decoder();

        let decoded = decoder.decode(csv.as_bytes()).unwrap();
        assert_eq!(decoded, csv.len());
        decoder.decode(&[]).unwrap();

        let batch = decoder.flush().unwrap().unwrap();
        assert_eq!(batch.num_columns(), 1);
        assert_eq!(batch.num_rows(), 3);
        let col = batch.column(0).as_primitive::<T>();
        assert_eq!(col.values(), expected);
        assert_eq!(col.data_type(), &DataType::Timestamp(T::UNIT, timezone));
    }

    #[test]
    fn test_parse_timestamp() {
        test_parse_timestamp_impl::<TimestampNanosecondType>(None, &[0, 0, -7_200_000_000_000]);
        test_parse_timestamp_impl::<TimestampNanosecondType>(
            Some("+00:00".into()),
            &[0, 0, -7_200_000_000_000],
        );
        test_parse_timestamp_impl::<TimestampNanosecondType>(
            Some("-05:00".into()),
            &[18_000_000_000_000, 0, -7_200_000_000_000],
        );
        test_parse_timestamp_impl::<TimestampMicrosecondType>(
            Some("-03".into()),
            &[10_800_000_000, 0, -7_200_000_000],
        );
        test_parse_timestamp_impl::<TimestampMillisecondType>(
            Some("-03".into()),
            &[10_800_000, 0, -7_200_000],
        );
        test_parse_timestamp_impl::<TimestampSecondType>(Some("-03".into()), &[10_800, 0, -7_200]);
    }

    #[test]
    fn test_infer_schema_from_multiple_files() {
        let mut csv1 = NamedTempFile::new().unwrap();
        let mut csv2 = NamedTempFile::new().unwrap();
        let csv3 = NamedTempFile::new().unwrap(); // empty csv file should be skipped
        let mut csv4 = NamedTempFile::new().unwrap();
        writeln!(csv1, "c1,c2,c3").unwrap();
        writeln!(csv1, "1,\"foo\",0.5").unwrap();
        writeln!(csv1, "3,\"bar\",1").unwrap();
        writeln!(csv1, "3,\"bar\",2e-06").unwrap();
        // reading csv2 will set c2 to optional
        writeln!(csv2, "c1,c2,c3,c4").unwrap();
        writeln!(csv2, "10,,3.14,true").unwrap();
        // reading csv4 will set c3 to optional
        writeln!(csv4, "c1,c2,c3").unwrap();
        writeln!(csv4, "10,\"foo\",").unwrap();

        let schema = infer_schema_from_files(
            &[
                csv3.path().to_str().unwrap().to_string(),
                csv1.path().to_str().unwrap().to_string(),
                csv2.path().to_str().unwrap().to_string(),
                csv4.path().to_str().unwrap().to_string(),
            ],
            b',',
            Some(4), // only csv1 and csv2 should be read
            true,
        )
        .unwrap();

        assert_eq!(schema.fields().len(), 4);
        assert!(schema.field(0).is_nullable());
        assert!(schema.field(1).is_nullable());
        assert!(schema.field(2).is_nullable());
        assert!(schema.field(3).is_nullable());

        assert_eq!(&DataType::Int64, schema.field(0).data_type());
        assert_eq!(&DataType::Utf8, schema.field(1).data_type());
        assert_eq!(&DataType::Float64, schema.field(2).data_type());
        assert_eq!(&DataType::Boolean, schema.field(3).data_type());
    }

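    // Bounds are half-open row ranges: with_bounds(2, 6) yields rows 2 through
    // 5, split across batches according to the batch size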
    #[test]
    fn test_bounded() {
        let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]);
        let data = [
            vec!["0"],
            vec!["1"],
            vec!["2"],
            vec!["3"],
            vec!["4"],
            vec!["5"],
            vec!["6"],
        ];

        let data = data
            .iter()
            .map(|x| x.join(","))
            .collect::<Vec<_>>()
            .join("\n");
        let data = data.as_bytes();

        let reader = std::io::Cursor::new(data);

        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_batch_size(2)
            .with_projection(vec![0])
            .with_bounds(2, 6)
            .build_buffered(reader)
            .unwrap();

        let batch = csv.next().unwrap().unwrap();
        let a = batch.column(0);
        let a = a.as_any().downcast_ref::<UInt32Array>().unwrap();
        assert_eq!(a, &UInt32Array::from(vec![2, 3]));

        let batch = csv.next().unwrap().unwrap();
        let a = batch.column(0);
        let a = a.as_any().downcast_ref::<UInt32Array>().unwrap();
        assert_eq!(a, &UInt32Array::from(vec![4, 5]));

        assert!(csv.next().is_none());
    }

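    // An empty projection produces batches with no columns but still reports
    // the correct row count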
    #[test]
    fn test_empty_projection() {
        let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]);
        let data = [vec!["0"], vec!["1"]];

        let data = data
            .iter()
            .map(|x| x.join(","))
            .collect::<Vec<_>>()
            .join("\n");

        let mut csv = ReaderBuilder::new(Arc::new(schema))
            .with_batch_size(2)
            .with_projection(vec![])
            .build_buffered(Cursor::new(data.as_bytes()))
            .unwrap();

        let batch = csv.next().unwrap().unwrap();
        assert_eq!(batch.columns().len(), 0);
        assert_eq!(batch.num_rows(), 2);

        assert!(csv.next().is_none());
    }

    #[test]
    fn test_parsing_bool() {
        // Encode the expected behavior of boolean parsing
        assert_eq!(Some(true), parse_bool("true"));
        assert_eq!(Some(true), parse_bool("tRUe"));
        assert_eq!(Some(true), parse_bool("True"));
        assert_eq!(Some(true), parse_bool("TRUE"));
        assert_eq!(None, parse_bool("t"));
        assert_eq!(None, parse_bool("T"));
        assert_eq!(None, parse_bool(""));

        assert_eq!(Some(false), parse_bool("false"));
        assert_eq!(Some(false), parse_bool("fALse"));
        assert_eq!(Some(false), parse_bool("False"));
        assert_eq!(Some(false), parse_bool("FALSE"));
        assert_eq!(None, parse_bool("f"));
        assert_eq!(None, parse_bool("F"));
        assert_eq!(None, parse_bool(""));
    }

    #[test]
    fn test_parsing_float() {
        assert_eq!(Some(12.34), Float64Type::parse("12.34"));
        assert_eq!(Some(-12.34), Float64Type::parse("-12.34"));
        assert_eq!(Some(12.0), Float64Type::parse("12"));
        assert_eq!(Some(0.0), Float64Type::parse("0"));
        assert_eq!(Some(2.0), Float64Type::parse("2."));
        assert_eq!(Some(0.2), Float64Type::parse(".2"));
        assert!(Float64Type::parse("nan").unwrap().is_nan());
        assert!(Float64Type::parse("NaN").unwrap().is_nan());
        assert!(Float64Type::parse("inf").unwrap().is_infinite());
        assert!(Float64Type::parse("inf").unwrap().is_sign_positive());
        assert!(Float64Type::parse("-inf").unwrap().is_infinite());
        assert!(Float64Type::parse("-inf").unwrap().is_sign_negative());
        assert_eq!(None, Float64Type::parse(""));
        assert_eq!(None, Float64Type::parse("dd"));
        assert_eq!(None, Float64Type::parse("12.34.56"));
    }

    #[test]
    fn test_non_std_quote() {
        let schema = Schema::new(vec![
            Field::new("text1", DataType::Utf8, false),
            Field::new("text2", DataType::Utf8, false),
        ]);
        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(false)
            .with_quote(b'~'); // default is ", change to ~

        let mut csv_text = Vec::new();
        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
        for index in 0..10 {
            let text1 = format!("id{index:}");
            let text2 = format!("value{index:}");
            csv_writer
                .write_fmt(format_args!("~{text1}~,~{text2}~\r\n"))
                .unwrap();
        }
        let mut csv_reader = std::io::Cursor::new(&csv_text);
        let mut reader = builder.build(&mut csv_reader).unwrap();
        let batch = reader.next().unwrap().unwrap();
        let col0 = batch.column(0);
        assert_eq!(col0.len(), 10);
        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col0_arr.value(0), "id0");
        let col1 = batch.column(1);
        assert_eq!(col1.len(), 10);
        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col1_arr.value(5), "value5");
    }

    #[test]
    fn test_non_std_escape() {
        let schema = Schema::new(vec![
            Field::new("text1", DataType::Utf8, false),
            Field::new("text2", DataType::Utf8, false),
        ]);
        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(false)
            .with_escape(b'\\'); // default is None, change to \

        let mut csv_text = Vec::new();
        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
        for index in 0..10 {
            let text1 = format!("id{index:}");
            let text2 = format!("value\\\"{index:}");
            csv_writer
                .write_fmt(format_args!("\"{text1}\",\"{text2}\"\r\n"))
                .unwrap();
        }
        let mut csv_reader = std::io::Cursor::new(&csv_text);
        let mut reader = builder.build(&mut csv_reader).unwrap();
        let batch = reader.next().unwrap().unwrap();
        let col0 = batch.column(0);
        assert_eq!(col0.len(), 10);
        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col0_arr.value(0), "id0");
        let col1 = batch.column(1);
        assert_eq!(col1.len(), 10);
        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col1_arr.value(5), "value\"5");
    }

    #[test]
    fn test_non_std_terminator() {
        let schema = Schema::new(vec![
            Field::new("text1", DataType::Utf8, false),
            Field::new("text2", DataType::Utf8, false),
        ]);
        let builder = ReaderBuilder::new(Arc::new(schema))
            .with_header(false)
            .with_terminator(b'\n'); // default is CRLF, change to LF

        let mut csv_text = Vec::new();
        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
        for index in 0..10 {
            let text1 = format!("id{index:}");
            let text2 = format!("value{index:}");
            csv_writer
                .write_fmt(format_args!("\"{text1}\",\"{text2}\"\n"))
                .unwrap();
        }
        let mut csv_reader = std::io::Cursor::new(&csv_text);
        let mut reader = builder.build(&mut csv_reader).unwrap();
        let batch = reader.next().unwrap().unwrap();
        let col0 = batch.column(0);
        assert_eq!(col0.len(), 10);
        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col0_arr.value(0), "id0");
        let col1 = batch.column(1);
        assert_eq!(col1.len(), 10);
        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(col1_arr.value(5), "value5");
    }

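    // Exercises the interaction between with_header and with_bounds: the
    // header row is consumed first, then the bounds apply to the data rows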
    #[test]
    fn test_header_bounds() {
        let csv = "a,b\na,b\na,b\na,b\na,b\n";
        let tests = [
            (None, false, 5),
            (None, true, 4),
            (Some((0, 4)), false, 4),
            (Some((1, 4)), false, 3),
            (Some((0, 4)), true, 4),
            (Some((1, 4)), true, 3),
        ];
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("a", DataType::Utf8, false),
        ]));

        for (idx, (bounds, has_header, expected)) in tests.into_iter().enumerate() {
            let mut reader = ReaderBuilder::new(schema.clone()).with_header(has_header);
            if let Some((start, end)) = bounds {
                reader = reader.with_bounds(start, end);
            }
            let b = reader
                .build_buffered(Cursor::new(csv.as_bytes()))
                .unwrap()
                .next()
                .unwrap()
                .unwrap();
            assert_eq!(b.num_rows(), expected, "{idx}");
        }
    }

    #[test]
    fn test_null_boolean() {
        let csv = "true,false\nFalse,True\n,True\nFalse,";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Boolean, true),
            Field::new("a", DataType::Boolean, true),
        ]));

        let b = ReaderBuilder::new(schema)
            .build_buffered(Cursor::new(csv.as_bytes()))
            .unwrap()
            .next()
            .unwrap()
            .unwrap();

        assert_eq!(b.num_rows(), 4);
        assert_eq!(b.num_columns(), 2);

        let c = b.column(0).as_boolean();
        assert_eq!(c.null_count(), 1);
        assert!(c.value(0));
        assert!(!c.value(1));
        assert!(c.is_null(2));
        assert!(!c.value(3));

        let c = b.column(1).as_boolean();
        assert_eq!(c.null_count(), 1);
        assert!(!c.value(0));
        assert!(c.value(1));
        assert!(c.value(2));
        assert!(c.is_null(3));
    }

    #[test]
    fn test_truncated_rows() {
        let data = "a,b,c\n1,2,3\n4,5\n\n6,7,8";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
            Field::new("c", DataType::Int32, true),
        ]));

        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(true)
            .build(Cursor::new(data))
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(batches.is_ok());
        let batch = batches.unwrap().into_iter().next().unwrap();
        // Empty rows are skipped by the underlying csv parser
        assert_eq!(batch.num_rows(), 3);

        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(false)
            .build(Cursor::new(data))
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(match batches {
            Err(ArrowError::CsvError(e)) => e.to_string().contains("incorrect number of fields"),
            _ => false,
        });
    }

    #[test]
    fn test_truncated_rows_csv() {
        let file = File::open("test/data/truncated_rows.csv").unwrap();
        let schema = Arc::new(Schema::new(vec![
            Field::new("Name", DataType::Utf8, true),
            Field::new("Age", DataType::UInt32, true),
            Field::new("Occupation", DataType::Utf8, true),
            Field::new("DOB", DataType::Date32, true),
        ]));
        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_batch_size(24)
            .with_truncated_rows(true);
        let csv = reader.build(file).unwrap();
        let batches = csv.collect::<Result<Vec<_>, _>>().unwrap();

        assert_eq!(batches.len(), 1);
        let batch = &batches[0];
        assert_eq!(batch.num_rows(), 6);
        assert_eq!(batch.num_columns(), 4);
        let name = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let age = batch
            .column(1)
            .as_any()
            .downcast_ref::<UInt32Array>()
            .unwrap();
        let occupation = batch
            .column(2)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let dob = batch
            .column(3)
            .as_any()
            .downcast_ref::<Date32Array>()
            .unwrap();

        assert_eq!(name.value(0), "A1");
        assert_eq!(name.value(1), "B2");
        assert!(name.is_null(2));
        assert_eq!(name.value(3), "C3");
        assert_eq!(name.value(4), "D4");
        assert_eq!(name.value(5), "E5");

        assert_eq!(age.value(0), 34);
        assert_eq!(age.value(1), 29);
        assert!(age.is_null(2));
        assert_eq!(age.value(3), 45);
        assert!(age.is_null(4));
        assert_eq!(age.value(5), 31);

        assert_eq!(occupation.value(0), "Engineer");
        assert_eq!(occupation.value(1), "Doctor");
        assert!(occupation.is_null(2));
        assert_eq!(occupation.value(3), "Artist");
        assert!(occupation.is_null(4));
        assert!(occupation.is_null(5));

        assert_eq!(dob.value(0), 5675);
        assert!(dob.is_null(1));
        assert!(dob.is_null(2));
        assert_eq!(dob.value(3), -1858);
        assert!(dob.is_null(4));
        assert!(dob.is_null(5));
    }

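    // Truncated fields are filled with nulls, so reading them into a
    // non-nullable schema must surface an error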
    #[test]
    fn test_truncated_rows_not_nullable_error() {
        let data = "a,b,c\n1,2,3\n4,5";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));

        let reader = ReaderBuilder::new(schema.clone())
            .with_header(true)
            .with_truncated_rows(true)
            .build(Cursor::new(data))
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>();
        assert!(match batches {
            Err(ArrowError::InvalidArgumentError(e)) =>
                e.to_string().contains("contains null values"),
            _ => false,
        });
    }

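    // Reading through a BufReader, even with pathologically small capacities,
    // must produce exactly the same batches as reading the file directly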
    #[test]
    fn test_buffered() {
        let tests = [
            ("test/data/uk_cities.csv", false, 37),
            ("test/data/various_types.csv", true, 10),
            ("test/data/decimal_test.csv", false, 10),
        ];

        for (path, has_header, expected_rows) in tests {
            let (schema, _) = Format::default()
                .infer_schema(File::open(path).unwrap(), None)
                .unwrap();
            let schema = Arc::new(schema);

            for batch_size in [1, 4] {
                for capacity in [1, 3, 7, 100] {
                    let reader = ReaderBuilder::new(schema.clone())
                        .with_batch_size(batch_size)
                        .with_header(has_header)
                        .build(File::open(path).unwrap())
                        .unwrap();

                    let expected = reader.collect::<Result<Vec<_>, _>>().unwrap();

                    assert_eq!(
                        expected.iter().map(|x| x.num_rows()).sum::<usize>(),
                        expected_rows
                    );

                    let buffered =
                        std::io::BufReader::with_capacity(capacity, File::open(path).unwrap());

                    let reader = ReaderBuilder::new(schema.clone())
                        .with_batch_size(batch_size)
                        .with_header(has_header)
                        .build_buffered(buffered)
                        .unwrap();

                    let actual = reader.collect::<Result<Vec<_>, _>>().unwrap();
                    assert_eq!(expected, actual)
                }
            }
        }
    }

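    /// Asserts that parsing `csv` fails with the `expected` error message,
    /// for both Utf8 and Utf8View output columns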
    fn err_test(csv: &[u8], expected: &str) {
        fn err_test_with_schema(csv: &[u8], expected: &str, schema: Arc<Schema>) {
            let buffer = std::io::BufReader::with_capacity(2, Cursor::new(csv));
            let b = ReaderBuilder::new(schema)
                .with_batch_size(2)
                .build_buffered(buffer)
                .unwrap();
            let err = b.collect::<Result<Vec<_>, _>>().unwrap_err().to_string();
            assert_eq!(err, expected)
        }

        let schema_utf8 = Arc::new(Schema::new(vec![
            Field::new("text1", DataType::Utf8, true),
            Field::new("text2", DataType::Utf8, true),
        ]));
        err_test_with_schema(csv, expected, schema_utf8);

        let schema_utf8view = Arc::new(Schema::new(vec![
            Field::new("text1", DataType::Utf8View, true),
            Field::new("text2", DataType::Utf8View, true),
        ]));
        err_test_with_schema(csv, expected, schema_utf8view);
    }

    #[test]
    fn test_invalid_utf8() {
        err_test(
            b"sdf,dsfg\ndfd,hgh\xFFue\n,sds\nFalhghse,",
            "Csv error: Encountered invalid UTF-8 data for line 2 and field 2",
        );

        err_test(
            b"sdf,dsfg\ndksdk,jf\nd\xFFfd,hghue\n,sds\nFalhghse,",
            "Csv error: Encountered invalid UTF-8 data for line 3 and field 1",
        );

        err_test(
            b"sdf,dsfg\ndksdk,jf\ndsdsfd,hghue\n,sds\nFalhghse,\xFF",
            "Csv error: Encountered invalid UTF-8 data for line 5 and field 2",
        );

        err_test(
            b"\xFFsdf,dsfg\ndksdk,jf\ndsdsfd,hghue\n,sds\nFalhghse,\xFF",
            "Csv error: Encountered invalid UTF-8 data for line 1 and field 1",
        );
    }

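    /// A `BufRead` wrapper that records every call to `fill_buf` and the size
    /// of the buffer it returned, used to verify the reader's IO pattern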
    struct InstrumentedRead<R> {
        r: R,
        fill_count: usize,
        fill_sizes: Vec<usize>,
    }

    impl<R> InstrumentedRead<R> {
        fn new(r: R) -> Self {
            Self {
                r,
                fill_count: 0,
                fill_sizes: vec![],
            }
        }
    }

    impl<R: Seek> Seek for InstrumentedRead<R> {
        fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
            self.r.seek(pos)
        }
    }

    impl<R: BufRead> Read for InstrumentedRead<R> {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            self.r.read(buf)
        }
    }

    impl<R: BufRead> BufRead for InstrumentedRead<R> {
        fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
            self.fill_count += 1;
            let buf = self.r.fill_buf()?;
            self.fill_sizes.push(buf.len());
            Ok(buf)
        }

        fn consume(&mut self, amt: usize) {
            self.r.consume(amt)
        }
    }

    #[test]
    fn test_io() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("b", DataType::Utf8, false),
        ]));
        let csv = "foo,bar\nbaz,foo\na,b\nc,d";
        let mut read = InstrumentedRead::new(Cursor::new(csv.as_bytes()));
        let reader = ReaderBuilder::new(schema)
            .with_batch_size(3)
            .build_buffered(&mut read)
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].num_rows(), 3);
        assert_eq!(batches[1].num_rows(), 1);

        // Expect 4 calls to fill_buf
        // 1. Read first 3 rows
        // 2. Read final row
        // 3. Delimit and flush final row
        // 4. Iterator finished
        assert_eq!(&read.fill_sizes, &[23, 3, 0, 0]);
        assert_eq!(read.fill_count, 4);
    }

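    // Feeding several values into InferredDataType should resolve to the most
    // general type that can represent all of them, falling back to Utf8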
    #[test]
    fn test_inference() {
        let cases: &[(&[&str], DataType)] = &[
            (&[], DataType::Null),
            (&["false", "12"], DataType::Utf8),
            (&["12", "cupcakes"], DataType::Utf8),
            (&["12", "12.4"], DataType::Float64),
            (&["14050", "24332"], DataType::Int64),
            (&["14050.0", "true"], DataType::Utf8),
            (&["14050", "2020-03-19 00:00:00"], DataType::Utf8),
            (&["14050", "2340.0", "2020-03-19 00:00:00"], DataType::Utf8),
            (
                &["2020-03-19 02:00:00", "2020-03-19 00:00:00"],
                DataType::Timestamp(TimeUnit::Second, None),
            ),
            (&["2020-03-19", "2020-03-20"], DataType::Date32),
            (
                &["2020-03-19", "2020-03-19 02:00:00", "2020-03-19 00:00:00"],
                DataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                &[
                    "2020-03-19",
                    "2020-03-19 02:00:00",
                    "2020-03-19 00:00:00.000",
                ],
                DataType::Timestamp(TimeUnit::Millisecond, None),
            ),
            (
                &[
                    "2020-03-19",
                    "2020-03-19 02:00:00",
                    "2020-03-19 00:00:00.000000",
                ],
                DataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                &["2020-03-19 02:00:00+02:00", "2020-03-19 02:00:00Z"],
                DataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                &[
                    "2020-03-19",
                    "2020-03-19 02:00:00+02:00",
                    "2020-03-19 02:00:00Z",
                    "2020-03-19 02:00:00.12Z",
                ],
                DataType::Timestamp(TimeUnit::Millisecond, None),
            ),
            (
                &[
                    "2020-03-19",
                    "2020-03-19 02:00:00.000000000",
                    "2020-03-19 00:00:00.000000",
                ],
                DataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
        ];

        for (values, expected) in cases {
            let mut t = InferredDataType::default();
            for v in *values {
                t.update(v)
            }
            assert_eq!(&t.get(), expected, "{values:?}")
        }
    }

    #[test]
    fn test_record_length_mismatch() {
        let csv = "\
        a,b,c\n\
        1,2,3\n\
        4,5\n\
        6,7,8";
        let mut read = Cursor::new(csv.as_bytes());
        let result = Format::default()
            .with_header(true)
            .infer_schema(&mut read, None);
        assert!(result.is_err());
        // The error message should include the line number to help locate and fix the issue
        assert_eq!(result.err().unwrap().to_string(), "Csv error: Encountered unequal lengths between records on CSV file. Expected 3 records, found 2 records at line 3");
    }

    #[test]
    fn test_comment() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int8, false),
            Field::new("b", DataType::Int8, false),
        ]);

        let csv = "# comment1 \n1,2\n#comment2\n11,22";
        let mut read = Cursor::new(csv.as_bytes());
        let reader = ReaderBuilder::new(Arc::new(schema))
            .with_comment(b'#')
            .build(&mut read)
            .unwrap();

        let batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(batches.len(), 1);
        let b = batches.first().unwrap();
        assert_eq!(b.num_columns(), 2);
        assert_eq!(
            b.column(0)
                .as_any()
                .downcast_ref::<Int8Array>()
                .unwrap()
                .values(),
            &vec![1, 11]
        );
        assert_eq!(
            b.column(1)
                .as_any()
                .downcast_ref::<Int8Array>()
                .unwrap()
                .values(),
            &vec![2, 22]
        );
    }

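    // Utf8View stores strings of at most 12 bytes inline in the view; the
    // longer values below exercise the out-of-line buffer path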
    #[test]
    fn test_parse_string_view_single_column() {
        let csv = ["foo", "something_cannot_be_inlined", "foobar"].join("\n");
        let schema = Arc::new(Schema::new(vec![Field::new(
            "c1",
            DataType::Utf8View,
            true,
        )]));

        let mut decoder = ReaderBuilder::new(schema).build_decoder();

        let decoded = decoder.decode(csv.as_bytes()).unwrap();
        assert_eq!(decoded, csv.len());
        decoder.decode(&[]).unwrap();

        let batch = decoder.flush().unwrap().unwrap();
        assert_eq!(batch.num_columns(), 1);
        assert_eq!(batch.num_rows(), 3);
        let col = batch.column(0).as_string_view();
        assert_eq!(col.data_type(), &DataType::Utf8View);
        assert_eq!(col.value(0), "foo");
        assert_eq!(col.value(1), "something_cannot_be_inlined");
        assert_eq!(col.value(2), "foobar");
    }

    #[test]
    fn test_parse_string_view_multi_column() {
        let csv = ["foo,", ",something_cannot_be_inlined", "foobarfoobar,bar"].join("\n");
        let schema = Arc::new(Schema::new(vec![
            Field::new("c1", DataType::Utf8View, true),
            Field::new("c2", DataType::Utf8View, true),
        ]));

        let mut decoder = ReaderBuilder::new(schema).build_decoder();

        let decoded = decoder.decode(csv.as_bytes()).unwrap();
        assert_eq!(decoded, csv.len());
        decoder.decode(&[]).unwrap();

        let batch = decoder.flush().unwrap().unwrap();
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.num_rows(), 3);
        let c1 = batch.column(0).as_string_view();
        let c2 = batch.column(1).as_string_view();
        assert_eq!(c1.data_type(), &DataType::Utf8View);
        assert_eq!(c2.data_type(), &DataType::Utf8View);

        assert!(!c1.is_null(0));
        assert!(c1.is_null(1));
        assert!(!c1.is_null(2));
        assert_eq!(c1.value(0), "foo");
        assert_eq!(c1.value(2), "foobarfoobar");

        assert!(c2.is_null(0));
        assert!(!c2.is_null(1));
        assert!(!c2.is_null(2));
        assert_eq!(c2.value(1), "something_cannot_be_inlined");
        assert_eq!(c2.value(2), "bar");
    }
}