arrow_csv/reader/mod.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! CSV Reading: [`Reader`] and [`ReaderBuilder`]
//!
//! # Basic Usage
//!
//! This CSV reader allows CSV files to be read into the Arrow memory model. Records are
//! loaded in batches and are then converted from row-based data to columnar data.
//!
//! Example:
//!
//! ```
//! # use arrow_schema::*;
//! # use arrow_csv::{Reader, ReaderBuilder};
//! # use std::fs::File;
//! # use std::sync::Arc;
//!
//! let schema = Schema::new(vec![
//!     Field::new("city", DataType::Utf8, false),
//!     Field::new("lat", DataType::Float64, false),
//!     Field::new("lng", DataType::Float64, false),
//! ]);
//!
//! let file = File::open("test/data/uk_cities.csv").unwrap();
//!
//! let mut csv = ReaderBuilder::new(Arc::new(schema)).build(file).unwrap();
//! let batch = csv.next().unwrap().unwrap();
//! ```
//!
//! # Example: Numeric calculations on CSV
//! This code finds the maximum value in column 0 of a CSV file containing
//! ```csv
//! c1,c2,c3,c4
//! 1,1.1,"hong kong",true
//! 3,323.12,"XiAn",false
//! 10,131323.12,"cheng du",false
//! ```
//!
//! ```
//! # use arrow_array::cast::AsArray;
//! # use arrow_array::types::Int16Type;
//! # use arrow_csv::ReaderBuilder;
//! # use arrow_schema::{DataType, Field, Schema};
//! # use std::fs::File;
//! # use std::sync::Arc;
//! // Open the example file
//! let file = File::open("test/data/example.csv").unwrap();
//! let csv_schema = Schema::new(vec![
//!     Field::new("c1", DataType::Int16, true),
//!     Field::new("c2", DataType::Float32, true),
//!     Field::new("c3", DataType::Utf8, true),
//!     Field::new("c4", DataType::Boolean, true),
//! ]);
//! let mut reader = ReaderBuilder::new(Arc::new(csv_schema))
//!     .with_header(true)
//!     .build(file)
//!     .unwrap();
//! // find the maximum value in column 0 across all batches
//! let mut max_c0 = 0;
//! while let Some(r) = reader.next() {
//!   let r = r.unwrap(); // handle error
//!   // get the max value in column(0) for this batch
//!   let col = r.column(0).as_primitive::<Int16Type>();
//!   let batch_max = col.iter().max().flatten().unwrap_or_default();
//!   max_c0 = max_c0.max(batch_max);
//! }
//! assert_eq!(max_c0, 10);
//! ```
//!
//! # Async Usage
//!
//! The lower-level [`Decoder`] can be integrated with various forms of async data streams,
//! and is designed to be agnostic to the various different kinds of async IO primitives found
//! within the Rust ecosystem.
//!
//! For example, see below for how it can be used with an arbitrary `Stream` of `Bytes`
//!
//! ```
//! # use std::task::{Poll, ready};
//! # use bytes::{Buf, Bytes};
//! # use arrow_schema::ArrowError;
//! # use futures::stream::{Stream, StreamExt};
//! # use arrow_array::RecordBatch;
//! # use arrow_csv::reader::Decoder;
//! #
//! fn decode_stream<S: Stream<Item = Bytes> + Unpin>(
//!     mut decoder: Decoder,
//!     mut input: S,
//! ) -> impl Stream<Item = Result<RecordBatch, ArrowError>> {
//!     let mut buffered = Bytes::new();
//!     futures::stream::poll_fn(move |cx| {
//!         loop {
//!             if buffered.is_empty() {
//!                 if let Some(b) = ready!(input.poll_next_unpin(cx)) {
//!                     buffered = b;
//!                 }
//!                 // Note: don't break on `None` as the decoder needs
//!                 // to be called with an empty array to delimit the
//!                 // final record
//!             }
//!             let decoded = match decoder.decode(buffered.as_ref()) {
//!                 Ok(0) => break,
//!                 Ok(decoded) => decoded,
//!                 Err(e) => return Poll::Ready(Some(Err(e))),
//!             };
//!             buffered.advance(decoded);
//!         }
//!
//!         Poll::Ready(decoder.flush().transpose())
//!     })
//! }
//!
//! ```
//!
//! In a similar vein, it can also be used with tokio-based IO primitives
//!
//! ```
//! # use std::pin::Pin;
//! # use std::task::{Poll, ready};
//! # use futures::Stream;
//! # use tokio::io::AsyncBufRead;
//! # use arrow_array::RecordBatch;
//! # use arrow_csv::reader::Decoder;
//! # use arrow_schema::ArrowError;
//! fn decode_stream<R: AsyncBufRead + Unpin>(
//!     mut decoder: Decoder,
//!     mut reader: R,
//! ) -> impl Stream<Item = Result<RecordBatch, ArrowError>> {
//!     futures::stream::poll_fn(move |cx| {
//!         loop {
//!             let b = match ready!(Pin::new(&mut reader).poll_fill_buf(cx)) {
//!                 Ok(b) => b,
//!                 Err(e) => return Poll::Ready(Some(Err(e.into()))),
//!             };
//!             let decoded = match decoder.decode(b) {
//!                 // Note: the decoder needs to be called with an empty
//!                 // array to delimit the final record
//!                 Ok(0) => break,
//!                 Ok(decoded) => decoded,
//!                 Err(e) => return Poll::Ready(Some(Err(e))),
//!             };
//!             Pin::new(&mut reader).consume(decoded);
//!         }
//!
//!         Poll::Ready(decoder.flush().transpose())
//!     })
//! }
//! ```
//!

mod records;

use arrow_array::builder::{NullBuilder, PrimitiveBuilder};
use arrow_array::types::*;
use arrow_array::*;
use arrow_cast::parse::{Parser, parse_decimal, string_to_datetime};
use arrow_schema::*;
use chrono::{TimeZone, Utc};
use csv::StringRecord;
use regex::{Regex, RegexSet};
use std::fmt::{self, Debug};
use std::fs::File;
use std::io::{BufRead, BufReader as StdBufReader, Read};
use std::sync::{Arc, LazyLock};

use crate::map_csv_error;
use crate::reader::records::{RecordDecoder, StringRecords};
use arrow_array::timezone::Tz;

/// Order should match [`InferredDataType`]
static REGEX_SET: LazyLock<RegexSet> = LazyLock::new(|| {
    RegexSet::new([
        r"(?i)^(true)$|^(false)$(?-i)", //BOOLEAN
        r"^-?(\d+)$",                   //INTEGER
        r"^-?((\d*\.\d+|\d+\.\d*)([eE][-+]?\d+)?|\d+([eE][-+]?\d+))$", //DECIMAL
        r"^\d{4}-\d\d-\d\d$",           //DATE32
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d(?:[^\d\.].*)?$", //Timestamp(Second)
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,3}(?:[^\d].*)?$", //Timestamp(Millisecond)
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,6}(?:[^\d].*)?$", //Timestamp(Microsecond)
        r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,9}(?:[^\d].*)?$", //Timestamp(Nanosecond)
    ])
    .unwrap()
});

/// A wrapper over `Option<Regex>` to check if the value is `NULL`.
#[derive(Debug, Clone, Default)]
struct NullRegex(Option<Regex>);

impl NullRegex {
    /// Returns true if the value should be considered as `NULL` according to
    /// the provided regular expression.
    #[inline]
    fn is_null(&self, s: &str) -> bool {
        match &self.0 {
            Some(r) => r.is_match(s),
            None => s.is_empty(),
        }
    }
}

#[derive(Default, Copy, Clone)]
struct InferredDataType {
    /// Packed booleans indicating type
    ///
    /// 0 - Boolean
    /// 1 - Integer
    /// 2 - Float64
    /// 3 - Date32
    /// 4 - Timestamp(Second)
    /// 5 - Timestamp(Millisecond)
    /// 6 - Timestamp(Microsecond)
    /// 7 - Timestamp(Nanosecond)
    /// 8 - Utf8
    packed: u16,
}

impl InferredDataType {
    /// Returns the inferred data type
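    ///
    /// The bitmask is interpreted as a whole: for example, a column that matched
    /// both the integer pattern (bit 1) and the decimal pattern (bit 2) has
    /// `packed == 6` and is promoted to `Float64`, while any combination of the
    /// temporal bits promotes to the highest-precision temporal type observed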
    fn get(&self) -> DataType {
        match self.packed {
            0 => DataType::Null,
            1 => DataType::Boolean,
            2 => DataType::Int64,
            4 | 6 => DataType::Float64, // Promote Int64 to Float64
            b if b != 0 && (b & !0b11111000) == 0 => match b.leading_zeros() {
                // Promote to highest precision temporal type
                8 => DataType::Timestamp(TimeUnit::Nanosecond, None),
                9 => DataType::Timestamp(TimeUnit::Microsecond, None),
                10 => DataType::Timestamp(TimeUnit::Millisecond, None),
                11 => DataType::Timestamp(TimeUnit::Second, None),
                12 => DataType::Date32,
                _ => unreachable!(),
            },
            _ => DataType::Utf8,
        }
    }

    /// Updates the [`InferredDataType`] with the given string
    fn update(&mut self, string: &str) {
        self.packed |= if string.starts_with('"') {
            1 << 8 // Utf8
        } else if let Some(m) = REGEX_SET.matches(string).into_iter().next() {
            if m == 1 && string.len() >= 19 && string.parse::<i64>().is_err() {
                // if the value overflows i64, fall back to Utf8
                1 << 8
            } else {
                1 << m
            }
        } else if string == "NaN" || string == "nan" || string == "inf" || string == "-inf" {
            1 << 2 // Float64
        } else {
            1 << 8 // Utf8
        }
    }
}

/// The format specification for the CSV file
#[derive(Debug, Clone, Default)]
pub struct Format {
    header: bool,
    delimiter: Option<u8>,
    escape: Option<u8>,
    quote: Option<u8>,
    terminator: Option<u8>,
    comment: Option<u8>,
    null_regex: NullRegex,
    truncated_rows: bool,
}

impl Format {
    /// Specify whether the CSV file has a header, defaults to `false`
    ///
    /// When `true`, the first row of the CSV file is treated as a header row
    pub fn with_header(mut self, has_header: bool) -> Self {
        self.header = has_header;
        self
    }

    /// Specify a custom delimiter character, defaults to comma `','`
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.delimiter = Some(delimiter);
        self
    }

    /// Specify an escape character, defaults to `None`
    pub fn with_escape(mut self, escape: u8) -> Self {
        self.escape = Some(escape);
        self
    }

    /// Specify a custom quote character, defaults to double quote `'"'`
    pub fn with_quote(mut self, quote: u8) -> Self {
        self.quote = Some(quote);
        self
    }

    /// Specify a custom terminator character, defaults to CRLF
    pub fn with_terminator(mut self, terminator: u8) -> Self {
        self.terminator = Some(terminator);
        self
    }

    /// Specify a comment character, defaults to `None`
    ///
    /// Lines starting with this character will be ignored
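    ///
    /// # Example
    ///
    /// A minimal sketch (the inline CSV is illustrative) showing that commented
    /// lines are skipped during schema inference:
    ///
    /// ```
    /// # use arrow_csv::reader::Format;
    /// let format = Format::default().with_comment(b'#');
    /// let csv = "# a comment\n1\n2\n";
    /// let (_, n_records) = format.infer_schema(csv.as_bytes(), None).unwrap();
    /// assert_eq!(n_records, 2);
    /// ```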
    pub fn with_comment(mut self, comment: u8) -> Self {
        self.comment = Some(comment);
        self
    }

    /// Provide a regex to match null values, defaults to `^$`
    pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
        self.null_regex = NullRegex(Some(null_regex));
        self
    }

    /// Whether to allow truncated rows when parsing.
    ///
    /// By default this is set to `false` and will error if the CSV rows have different lengths.
    /// When set to true then it will allow records with fewer than the expected number of columns
    /// and fill the missing columns with nulls. If the record's schema is not nullable, then it
    /// will still return an error.
    pub fn with_truncated_rows(mut self, allow: bool) -> Self {
        self.truncated_rows = allow;
        self
    }

    /// Infer schema of CSV records from the provided `reader`
    ///
    /// If `max_records` is `None`, all records will be read, otherwise up to `max_records`
    /// records are read to infer the schema
    ///
    /// Returns inferred schema and number of records read
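    ///
    /// # Example
    ///
    /// A brief sketch using in-memory CSV data (the columns shown are illustrative):
    ///
    /// ```
    /// # use arrow_csv::reader::Format;
    /// let csv = "city,min_temp\nBoston,-3\nOslo,-15\n";
    /// let (schema, n_records) = Format::default()
    ///     .with_header(true)
    ///     .infer_schema(csv.as_bytes(), None)
    ///     .unwrap();
    /// assert_eq!(n_records, 2);
    /// assert_eq!(schema.field(1).name(), "min_temp");
    /// ```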
    pub fn infer_schema<R: Read>(
        &self,
        reader: R,
        max_records: Option<usize>,
    ) -> Result<(Schema, usize), ArrowError> {
        let mut csv_reader = self.build_reader(reader);

        // get or create header names
        // when has_header is false, creates default column names with column_ prefix
        let headers: Vec<String> = if self.header {
            let headers = &csv_reader.headers().map_err(map_csv_error)?.clone();
            headers.iter().map(|s| s.to_string()).collect()
        } else {
            let first_record_count = &csv_reader.headers().map_err(map_csv_error)?.len();
            (0..*first_record_count)
                .map(|i| format!("column_{}", i + 1))
                .collect()
        };

        let header_length = headers.len();
        // keep track of inferred field types
        let mut column_types: Vec<InferredDataType> = vec![Default::default(); header_length];

        let mut records_count = 0;

        let mut record = StringRecord::new();
        let max_records = max_records.unwrap_or(usize::MAX);
        while records_count < max_records {
            if !csv_reader.read_record(&mut record).map_err(map_csv_error)? {
                break;
            }
            records_count += 1;

            // Note: since we may be looking at a sample of the data, we make the safe
            // assumption that all columns could be nullable
            for (i, column_type) in column_types.iter_mut().enumerate().take(header_length) {
                if let Some(string) = record.get(i) {
                    if !self.null_regex.is_null(string) {
                        column_type.update(string)
                    }
                }
            }
        }

        // build schema from inference results
        let fields: Fields = column_types
            .iter()
            .zip(&headers)
            .map(|(inferred, field_name)| Field::new(field_name, inferred.get(), true))
            .collect();

        Ok((Schema::new(fields), records_count))
    }

    /// Build a [`csv::Reader`] for this [`Format`]
    fn build_reader<R: Read>(&self, reader: R) -> csv::Reader<R> {
        let mut builder = csv::ReaderBuilder::new();
        builder.has_headers(self.header);
        builder.flexible(self.truncated_rows);

        if let Some(c) = self.delimiter {
            builder.delimiter(c);
        }
        builder.escape(self.escape);
        if let Some(c) = self.quote {
            builder.quote(c);
        }
        if let Some(t) = self.terminator {
            builder.terminator(csv::Terminator::Any(t));
        }
        if let Some(comment) = self.comment {
            builder.comment(Some(comment));
        }
        builder.from_reader(reader)
    }

    /// Build a [`csv_core::Reader`] for this [`Format`]
    fn build_parser(&self) -> csv_core::Reader {
        let mut builder = csv_core::ReaderBuilder::new();
        builder.escape(self.escape);
        builder.comment(self.comment);

        if let Some(c) = self.delimiter {
            builder.delimiter(c);
        }
        if let Some(c) = self.quote {
            builder.quote(c);
        }
        if let Some(t) = self.terminator {
            builder.terminator(csv_core::Terminator::Any(t));
        }
        builder.build()
    }
}

/// Infer schema from a list of CSV files by reading through the first n records
/// with `max_read_records` controlling the maximum number of records to read.
///
/// Files will be read in the given order until n records have been reached.
///
/// If `max_read_records` is not set, all files will be read fully to infer the schema.
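///
/// # Example
///
/// A hedged sketch; the file paths are illustrative and must exist on disk:
///
/// ```no_run
/// # use arrow_csv::reader::infer_schema_from_files;
/// let files = vec![
///     "data/part-0.csv".to_string(),
///     "data/part-1.csv".to_string(),
/// ];
/// let schema = infer_schema_from_files(&files, b',', Some(100), true).unwrap();
/// ```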
pub fn infer_schema_from_files(
    files: &[String],
    delimiter: u8,
    max_read_records: Option<usize>,
    has_header: bool,
) -> Result<Schema, ArrowError> {
    let mut schemas = vec![];
    let mut records_to_read = max_read_records.unwrap_or(usize::MAX);
    let format = Format {
        delimiter: Some(delimiter),
        header: has_header,
        ..Default::default()
    };

    for fname in files.iter() {
        let f = File::open(fname)?;
        let (schema, records_read) = format.infer_schema(f, Some(records_to_read))?;
        if records_read == 0 {
            continue;
        }
        schemas.push(schema.clone());
        records_to_read -= records_read;
        if records_to_read == 0 {
            break;
        }
    }

    Schema::try_merge(schemas)
}

// optional bounds of the reader, of the form (min line, max line).
type Bounds = Option<(usize, usize)>;

/// CSV file reader using [`std::io::BufReader`]
///
/// See [`ReaderBuilder`] to construct a CSV reader with options and the
/// [module-level documentation](crate::reader) for more details and examples
pub type Reader<R> = BufReader<StdBufReader<R>>;

/// CSV file reader implementation. See [`Reader`] for usage
///
/// Despite having the same name as [`std::io::BufReader`], this structure does
/// not buffer reads itself
pub struct BufReader<R> {
    /// File reader
    reader: R,
    /// The decoder
    decoder: Decoder,
}

impl<R> fmt::Debug for BufReader<R>
where
    R: BufRead,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Reader")
            .field("decoder", &self.decoder)
            .finish()
    }
}

impl<R: Read> Reader<R> {
    /// Returns the schema of the reader, useful for getting the schema without reading
    /// record batches
    pub fn schema(&self) -> SchemaRef {
        match &self.decoder.projection {
            Some(projection) => {
                let fields = self.decoder.schema.fields();
                let projected = projection.iter().map(|i| fields[*i].clone());
                Arc::new(Schema::new(projected.collect::<Fields>()))
            }
            None => self.decoder.schema.clone(),
        }
    }
}

impl<R: BufRead> BufReader<R> {
    fn read(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        loop {
            let buf = self.reader.fill_buf()?;
            let decoded = self.decoder.decode(buf)?;
            self.reader.consume(decoded);
            // Yield if decoded no bytes or the decoder is full
            //
            // The capacity check avoids looping around and potentially
            // blocking reading data in fill_buf that isn't needed
            // to flush the next batch
            if decoded == 0 || self.decoder.capacity() == 0 {
                break;
            }
        }

        self.decoder.flush()
    }
}

impl<R: BufRead> Iterator for BufReader<R> {
    type Item = Result<RecordBatch, ArrowError>;

    fn next(&mut self) -> Option<Self::Item> {
        self.read().transpose()
    }
}

impl<R: BufRead> RecordBatchReader for BufReader<R> {
    fn schema(&self) -> SchemaRef {
        self.decoder.schema.clone()
    }
}

/// A push-based interface for decoding CSV data from an arbitrary byte stream
///
/// See [`Reader`] for a higher-level interface for use with [`Read`]
///
/// The push-based interface facilitates integration with sources that yield arbitrarily
/// delimited byte ranges, such as [`BufRead`], or a chunked byte stream received from
/// object storage
///
/// ```
/// # use std::io::BufRead;
/// # use arrow_array::RecordBatch;
/// # use arrow_csv::ReaderBuilder;
/// # use arrow_schema::{ArrowError, SchemaRef};
/// #
/// fn read_from_csv<R: BufRead>(
///     mut reader: R,
///     schema: SchemaRef,
///     batch_size: usize,
/// ) -> Result<impl Iterator<Item = Result<RecordBatch, ArrowError>>, ArrowError> {
///     let mut decoder = ReaderBuilder::new(schema)
///         .with_batch_size(batch_size)
///         .build_decoder();
///
///     let mut next = move || {
///         loop {
///             let buf = reader.fill_buf()?;
///             let decoded = decoder.decode(buf)?;
///             if decoded == 0 {
///                 break;
///             }
///
///             // Consume the number of bytes read
///             reader.consume(decoded);
///         }
///         decoder.flush()
///     };
///     Ok(std::iter::from_fn(move || next().transpose()))
/// }
/// ```
#[derive(Debug)]
pub struct Decoder {
    /// Explicit schema for the CSV file
    schema: SchemaRef,

    /// Optional projection for which columns to load (zero-based column indices)
    projection: Option<Vec<usize>>,

    /// Number of records per batch
    batch_size: usize,

    /// Rows to skip
    to_skip: usize,

    /// Current line number
    line_number: usize,

    /// End line number
    end: usize,

    /// A decoder for [`StringRecords`]
    record_decoder: RecordDecoder,

    /// Check if the string matches this pattern for `NULL`.
    null_regex: NullRegex,
}

impl Decoder {
    /// Decode records from `buf` returning the number of bytes read
    ///
    /// This method returns once `batch_size` objects have been parsed since the
    /// last call to [`Self::flush`], or `buf` is exhausted. Any remaining bytes
    /// should be included in the next call to [`Self::decode`]
    ///
    /// There is no requirement that `buf` contains a whole number of records, facilitating
    /// integration with arbitrary byte streams, such as those yielded by [`BufRead`] or
    /// network sources such as object storage
    pub fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
        if self.to_skip != 0 {
            // Skip in chunks of at most `batch_size` records to avoid over-allocating buffers
            let to_skip = self.to_skip.min(self.batch_size);
            let (skipped, bytes) = self.record_decoder.decode(buf, to_skip)?;
            self.to_skip -= skipped;
            self.record_decoder.clear();
            return Ok(bytes);
        }

        let to_read = self.batch_size.min(self.end - self.line_number) - self.record_decoder.len();
        let (_, bytes) = self.record_decoder.decode(buf, to_read)?;
        Ok(bytes)
    }

    /// Flushes the currently buffered data to a [`RecordBatch`]
    ///
    /// This should only be called after [`Self::decode`] has returned `Ok(0)`,
    /// otherwise it may return an error if called partway through decoding a record
    ///
    /// Returns `Ok(None)` if there is no buffered data
    pub fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        if self.record_decoder.is_empty() {
            return Ok(None);
        }

        let rows = self.record_decoder.flush()?;
        let batch = parse(
            &rows,
            self.schema.fields(),
            Some(self.schema.metadata.clone()),
            self.projection.as_ref(),
            self.line_number,
            &self.null_regex,
        )?;
        self.line_number += rows.len();
        Ok(Some(batch))
    }

    /// Returns the number of records that can be read before requiring a call to [`Self::flush`]
    pub fn capacity(&self) -> usize {
        self.batch_size - self.record_decoder.len()
    }
}

/// Parses a slice of [`StringRecords`] into a [RecordBatch]
fn parse(
    rows: &StringRecords<'_>,
    fields: &Fields,
    metadata: Option<std::collections::HashMap<String, String>>,
    projection: Option<&Vec<usize>>,
    line_number: usize,
    null_regex: &NullRegex,
) -> Result<RecordBatch, ArrowError> {
    let projection: Vec<usize> = match projection {
        Some(v) => v.clone(),
        None => fields.iter().enumerate().map(|(i, _)| i).collect(),
    };

    let arrays: Result<Vec<ArrayRef>, _> = projection
        .iter()
        .map(|i| {
            let i = *i;
            let field = &fields[i];
            match field.data_type() {
                DataType::Boolean => build_boolean_array(line_number, rows, i, null_regex),
                DataType::Decimal32(precision, scale) => build_decimal_array::<Decimal32Type>(
                    line_number,
                    rows,
                    i,
                    *precision,
                    *scale,
                    null_regex,
                ),
                DataType::Decimal64(precision, scale) => build_decimal_array::<Decimal64Type>(
                    line_number,
                    rows,
                    i,
                    *precision,
                    *scale,
                    null_regex,
                ),
                DataType::Decimal128(precision, scale) => build_decimal_array::<Decimal128Type>(
                    line_number,
                    rows,
                    i,
                    *precision,
                    *scale,
                    null_regex,
                ),
                DataType::Decimal256(precision, scale) => build_decimal_array::<Decimal256Type>(
                    line_number,
                    rows,
                    i,
                    *precision,
                    *scale,
                    null_regex,
                ),
                DataType::Int8 => {
                    build_primitive_array::<Int8Type>(line_number, rows, i, null_regex)
                }
                DataType::Int16 => {
                    build_primitive_array::<Int16Type>(line_number, rows, i, null_regex)
                }
                DataType::Int32 => {
                    build_primitive_array::<Int32Type>(line_number, rows, i, null_regex)
                }
                DataType::Int64 => {
                    build_primitive_array::<Int64Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt8 => {
                    build_primitive_array::<UInt8Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt16 => {
                    build_primitive_array::<UInt16Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt32 => {
                    build_primitive_array::<UInt32Type>(line_number, rows, i, null_regex)
                }
                DataType::UInt64 => {
                    build_primitive_array::<UInt64Type>(line_number, rows, i, null_regex)
                }
                DataType::Float32 => {
                    build_primitive_array::<Float32Type>(line_number, rows, i, null_regex)
                }
                DataType::Float64 => {
                    build_primitive_array::<Float64Type>(line_number, rows, i, null_regex)
                }
                DataType::Date32 => {
                    build_primitive_array::<Date32Type>(line_number, rows, i, null_regex)
                }
                DataType::Date64 => {
                    build_primitive_array::<Date64Type>(line_number, rows, i, null_regex)
                }
                DataType::Time32(TimeUnit::Second) => {
                    build_primitive_array::<Time32SecondType>(line_number, rows, i, null_regex)
                }
                DataType::Time32(TimeUnit::Millisecond) => {
                    build_primitive_array::<Time32MillisecondType>(line_number, rows, i, null_regex)
                }
                DataType::Time64(TimeUnit::Microsecond) => {
                    build_primitive_array::<Time64MicrosecondType>(line_number, rows, i, null_regex)
                }
                DataType::Time64(TimeUnit::Nanosecond) => {
                    build_primitive_array::<Time64NanosecondType>(line_number, rows, i, null_regex)
                }
                DataType::Timestamp(TimeUnit::Second, tz) => {
                    build_timestamp_array::<TimestampSecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Timestamp(TimeUnit::Millisecond, tz) => {
                    build_timestamp_array::<TimestampMillisecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Timestamp(TimeUnit::Microsecond, tz) => {
                    build_timestamp_array::<TimestampMicrosecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
                    build_timestamp_array::<TimestampNanosecondType>(
                        line_number,
                        rows,
                        i,
                        tz.as_deref(),
                        null_regex,
                    )
                }
                DataType::Null => Ok(Arc::new({
                    let mut builder = NullBuilder::new();
                    builder.append_nulls(rows.len());
                    builder.finish()
                }) as ArrayRef),
                DataType::Utf8 => Ok(Arc::new(
                    rows.iter()
                        .map(|row| {
                            let s = row.get(i);
                            (!null_regex.is_null(s)).then_some(s)
                        })
                        .collect::<StringArray>(),
                ) as ArrayRef),
                DataType::Utf8View => Ok(Arc::new(
                    rows.iter()
                        .map(|row| {
                            let s = row.get(i);
                            (!null_regex.is_null(s)).then_some(s)
                        })
                        .collect::<StringViewArray>(),
                ) as ArrayRef),
                DataType::Dictionary(key_type, value_type)
                    if value_type.as_ref() == &DataType::Utf8 =>
                {
                    match key_type.as_ref() {
                        DataType::Int8 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int8Type>>(),
                        ) as ArrayRef),
                        DataType::Int16 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int16Type>>(),
                        ) as ArrayRef),
                        DataType::Int32 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int32Type>>(),
                        ) as ArrayRef),
                        DataType::Int64 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<Int64Type>>(),
                        ) as ArrayRef),
                        DataType::UInt8 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt8Type>>(),
                        ) as ArrayRef),
                        DataType::UInt16 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt16Type>>(),
                        ) as ArrayRef),
                        DataType::UInt32 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt32Type>>(),
                        ) as ArrayRef),
                        DataType::UInt64 => Ok(Arc::new(
                            rows.iter()
                                .map(|row| {
                                    let s = row.get(i);
                                    (!null_regex.is_null(s)).then_some(s)
                                })
                                .collect::<DictionaryArray<UInt64Type>>(),
                        ) as ArrayRef),
                        _ => Err(ArrowError::ParseError(format!(
                            "Unsupported dictionary key type {key_type}"
                        ))),
                    }
                }
                other => Err(ArrowError::ParseError(format!(
                    "Unsupported data type {other:?}"
                ))),
            }
        })
        .collect();

    let projected_fields: Fields = projection.iter().map(|i| fields[*i].clone()).collect();

    let projected_schema = Arc::new(match metadata {
        None => Schema::new(projected_fields),
        Some(metadata) => Schema::new_with_metadata(projected_fields, metadata),
    });

    arrays.and_then(|arr| {
        RecordBatch::try_new_with_options(
            projected_schema,
            arr,
            &RecordBatchOptions::new()
                .with_match_field_names(true)
                .with_row_count(Some(rows.len())),
        )
    })
}

fn parse_bool(string: &str) -> Option<bool> {
    if string.eq_ignore_ascii_case("false") {
        Some(false)
    } else if string.eq_ignore_ascii_case("true") {
        Some(true)
    } else {
        None
    }
}

// Parses the decimal strings in a specific column (col_idx) into an Arrow array
fn build_decimal_array<T: DecimalType>(
    _line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    precision: u8,
    scale: i8,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    let mut decimal_builder = PrimitiveBuilder::<T>::with_capacity(rows.len());
    for row in rows.iter() {
        let s = row.get(col_idx);
        if null_regex.is_null(s) {
            // append null
            decimal_builder.append_null();
        } else {
            let decimal_value: Result<T::Native, _> = parse_decimal::<T>(s, precision, scale);
            match decimal_value {
                Ok(v) => {
                    decimal_builder.append_value(v);
                }
                Err(e) => {
                    return Err(e);
                }
            }
        }
    }
    Ok(Arc::new(
        decimal_builder
            .finish()
            .with_precision_and_scale(precision, scale)?,
    ))
}

// parses a specific column (col_idx) into an Arrow Array.
fn build_primitive_array<T: ArrowPrimitiveType + Parser>(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    rows.iter()
        .enumerate()
        .map(|(row_index, row)| {
            let s = row.get(col_idx);
            if null_regex.is_null(s) {
                return Ok(None);
            }

            match T::parse(s) {
                Some(e) => Ok(Some(e)),
                None => Err(ArrowError::ParseError(format!(
                    // TODO: we should surface the underlying error here.
                    "Error while parsing value '{}' as type '{}' for column {} at line {}. Row data: '{}'",
                    s,
                    T::DATA_TYPE,
                    col_idx,
                    line_number + row_index,
                    row
                ))),
            }
        })
        .collect::<Result<PrimitiveArray<T>, ArrowError>>()
        .map(|e| Arc::new(e) as ArrayRef)
}

fn build_timestamp_array<T: ArrowTimestampType>(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    timezone: Option<&str>,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    Ok(Arc::new(match timezone {
        Some(timezone) => {
            let tz: Tz = timezone.parse()?;
            build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &tz, null_regex)?
                .with_timezone(timezone)
        }
        None => build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &Utc, null_regex)?,
    }))
}

fn build_timestamp_array_impl<T: ArrowTimestampType, Tz: TimeZone>(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    timezone: &Tz,
    null_regex: &NullRegex,
) -> Result<PrimitiveArray<T>, ArrowError> {
    rows.iter()
        .enumerate()
        .map(|(row_index, row)| {
            let s = row.get(col_idx);
            if null_regex.is_null(s) {
                return Ok(None);
            }

            let date = string_to_datetime(timezone, s)
                .and_then(|date| match T::UNIT {
                    TimeUnit::Second => Ok(date.timestamp()),
                    TimeUnit::Millisecond => Ok(date.timestamp_millis()),
                    TimeUnit::Microsecond => Ok(date.timestamp_micros()),
                    TimeUnit::Nanosecond => date.timestamp_nanos_opt().ok_or_else(|| {
                        ArrowError::ParseError(format!(
                            "{} would overflow 64-bit signed nanoseconds",
                            date.to_rfc3339(),
                        ))
                    }),
                })
                .map_err(|e| {
                    ArrowError::ParseError(format!(
                        "Error parsing column {col_idx} at line {}: {}",
                        line_number + row_index,
                        e
                    ))
                })?;
            Ok(Some(date))
        })
        .collect()
}

// parses a specific column (col_idx) into an Arrow Array.
fn build_boolean_array(
    line_number: usize,
    rows: &StringRecords<'_>,
    col_idx: usize,
    null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
    rows.iter()
        .enumerate()
        .map(|(row_index, row)| {
            let s = row.get(col_idx);
            if null_regex.is_null(s) {
                return Ok(None);
            }
            let parsed = parse_bool(s);
            match parsed {
                Some(e) => Ok(Some(e)),
                None => Err(ArrowError::ParseError(format!(
                    // TODO: we should surface the underlying error here.
                    "Error while parsing value '{}' as type '{}' for column {} at line {}. Row data: '{}'",
                    s,
                    "Boolean",
                    col_idx,
                    line_number + row_index,
                    row
                ))),
            }
        })
        .collect::<Result<BooleanArray, _>>()
        .map(|e| Arc::new(e) as ArrayRef)
}

/// Builder for CSV [`Reader`]s
#[derive(Debug)]
pub struct ReaderBuilder {
    /// Schema of the CSV file
    schema: SchemaRef,
    /// Format of the CSV file
    format: Format,
    /// Batch size (number of records to load each time)
    ///
    /// The default batch size when using the `ReaderBuilder` is 1024 records
    batch_size: usize,
    /// The bounds over which to scan the reader. `None` starts from 0 and runs until EOF.
    bounds: Bounds,
    /// Optional projection for which columns to load (zero-based column indices)
    projection: Option<Vec<usize>>,
}

impl ReaderBuilder {
    /// Create a new builder for configuring [`Reader`] CSV parsing options.
    ///
    /// To convert a builder into a reader, call [`ReaderBuilder::build`]. See
    /// the [module-level documentation](crate::reader) for more details and examples.
    ///
    /// # Example
    ///
    /// ```
    /// # use arrow_csv::{Reader, ReaderBuilder};
    /// # use std::fs::File;
    /// # use std::io::Seek;
    /// # use std::sync::Arc;
    /// # use arrow_csv::reader::Format;
    /// #
    /// let mut file = File::open("test/data/uk_cities_with_headers.csv").unwrap();
    /// // Infer the schema with the first 100 records
    /// let (schema, _) = Format::default().infer_schema(&mut file, Some(100)).unwrap();
    /// file.rewind().unwrap();
    ///
    /// // create a builder
    /// ReaderBuilder::new(Arc::new(schema)).build(file).unwrap();
    /// ```
    pub fn new(schema: SchemaRef) -> ReaderBuilder {
        Self {
            schema,
            format: Format::default(),
            batch_size: 1024,
            bounds: None,
            projection: None,
        }
    }

    /// Set whether the CSV file has a header
    pub fn with_header(mut self, has_header: bool) -> Self {
        self.format.header = has_header;
        self
    }

    /// Overrides the [Format] of this [ReaderBuilder]
    pub fn with_format(mut self, format: Format) -> Self {
        self.format = format;
        self
    }

    /// Set the CSV file's column delimiter as a byte character
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.format.delimiter = Some(delimiter);
        self
    }

    /// Set the given character as the CSV file's escape character
    pub fn with_escape(mut self, escape: u8) -> Self {
        self.format.escape = Some(escape);
        self
    }

    /// Set the given character as the CSV file's quote character, by default it is double quote
    pub fn with_quote(mut self, quote: u8) -> Self {
        self.format.quote = Some(quote);
        self
    }

    /// Provide a custom terminator character, defaults to CRLF
    pub fn with_terminator(mut self, terminator: u8) -> Self {
        self.format.terminator = Some(terminator);
        self
    }

    /// Provide a comment character, lines starting with this character will be ignored
    pub fn with_comment(mut self, comment: u8) -> Self {
        self.format.comment = Some(comment);
        self
    }

    /// Provide a regex to match null values, defaults to `^$`
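    ///
    /// # Example
    ///
    /// A minimal sketch (the column and values are illustrative) treating the
    /// literal string `NULL` as a null value:
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use regex::Regex;
    /// # use arrow_array::Array;
    /// # use arrow_schema::{DataType, Field, Schema};
    /// # use arrow_csv::ReaderBuilder;
    /// let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)]));
    /// let mut reader = ReaderBuilder::new(schema)
    ///     .with_null_regex(Regex::new("^NULL$").unwrap())
    ///     .build("1\nNULL\n3\n".as_bytes())
    ///     .unwrap();
    /// let batch = reader.next().unwrap().unwrap();
    /// assert_eq!(batch.column(0).null_count(), 1);
    /// ```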
    pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
        self.format.null_regex = NullRegex(Some(null_regex));
        self
    }

    /// Set the batch size (number of records to load at one time)
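    ///
    /// # Example
    ///
    /// A small sketch (with illustrative inline data) showing that a batch size
    /// of 1 yields one record per batch:
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow_schema::{DataType, Field, Schema};
    /// # use arrow_csv::ReaderBuilder;
    /// let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)]));
    /// let reader = ReaderBuilder::new(schema)
    ///     .with_batch_size(1)
    ///     .build("1\n2\n".as_bytes())
    ///     .unwrap();
    /// assert_eq!(reader.count(), 2); // two batches of one row each
    /// ```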
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Set the bounds over which to scan the reader.
    /// `start` and `end` are line numbers.
    pub fn with_bounds(mut self, start: usize, end: usize) -> Self {
        self.bounds = Some((start, end));
        self
    }

    /// Set the reader's column projection
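    ///
    /// # Example
    ///
    /// A brief sketch (with illustrative inline data) selecting only the first
    /// and third columns:
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow_schema::{DataType, Field, Schema};
    /// # use arrow_csv::ReaderBuilder;
    /// let schema = Arc::new(Schema::new(vec![
    ///     Field::new("a", DataType::Int32, true),
    ///     Field::new("b", DataType::Int32, true),
    ///     Field::new("c", DataType::Int32, true),
    /// ]));
    /// let mut reader = ReaderBuilder::new(schema)
    ///     .with_projection(vec![0, 2])
    ///     .build("1,2,3\n".as_bytes())
    ///     .unwrap();
    /// let batch = reader.next().unwrap().unwrap();
    /// assert_eq!(batch.num_columns(), 2);
    /// ```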
    pub fn with_projection(mut self, projection: Vec<usize>) -> Self {
        self.projection = Some(projection);
        self
    }

    /// Whether to allow truncated rows when parsing.
    ///
    /// By default this is set to `false` and will error if the CSV rows have different lengths.
    /// When set to true then it will allow records with fewer than the expected number of columns
    /// and fill the missing columns with nulls. If the record's schema is not nullable, then it
    /// will still return an error.
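    ///
    /// # Example
    ///
    /// A minimal sketch (with illustrative inline data) where the second row is
    /// missing a column and is padded with a null:
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow_array::Array;
    /// # use arrow_schema::{DataType, Field, Schema};
    /// # use arrow_csv::ReaderBuilder;
    /// let schema = Arc::new(Schema::new(vec![
    ///     Field::new("a", DataType::Int32, true),
    ///     Field::new("b", DataType::Int32, true),
    /// ]));
    /// let mut reader = ReaderBuilder::new(schema)
    ///     .with_truncated_rows(true)
    ///     .build("1,2\n3\n".as_bytes())
    ///     .unwrap();
    /// let batch = reader.next().unwrap().unwrap();
    /// assert_eq!(batch.num_rows(), 2);
    /// assert_eq!(batch.column(1).null_count(), 1);
    /// ```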
    pub fn with_truncated_rows(mut self, allow: bool) -> Self {
        self.format.truncated_rows = allow;
        self
    }

    /// Create a new `Reader` from a non-buffered reader
    ///
    /// If `R: BufRead` consider using [`Self::build_buffered`] to avoid unnecessary additional
    /// buffering, as internally this method wraps `reader` in [`std::io::BufReader`]
    pub fn build<R: Read>(self, reader: R) -> Result<Reader<R>, ArrowError> {
        self.build_buffered(StdBufReader::new(reader))
    }

    /// Create a new `BufReader` from a buffered reader
    pub fn build_buffered<R: BufRead>(self, reader: R) -> Result<BufReader<R>, ArrowError> {
        Ok(BufReader {
            reader,
            decoder: self.build_decoder(),
        })
    }

    /// Builds a decoder that can be used to decode CSV from an arbitrary byte stream
    pub fn build_decoder(self) -> Decoder {
        let delimiter = self.format.build_parser();
        let record_decoder = RecordDecoder::new(
            delimiter,
            self.schema.fields().len(),
            self.format.truncated_rows,
        );

        let header = self.format.header as usize;

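        // The bounds are given in data-row terms; offset by one when a header
        // row is present so that it is skipped rather than parsed as data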
1256        let (start, end) = match self.bounds {
1257            Some((start, end)) => (start + header, end + header),
1258            None => (header, usize::MAX),
1259        };
1260
1261        Decoder {
1262            schema: self.schema,
1263            to_skip: start,
1264            record_decoder,
1265            line_number: start,
1266            end,
1267            projection: self.projection,
1268            batch_size: self.batch_size,
1269            null_regex: self.format.null_regex,
1270        }
1271    }
1272}
1273
1274#[cfg(test)]
1275mod tests {
1276    use super::*;
1277
1278    use std::io::{Cursor, Seek, SeekFrom, Write};
1279    use tempfile::NamedTempFile;
1280
1281    use arrow_array::cast::AsArray;
1282
1283    #[test]
1284    fn test_csv() {
1285        let schema = Arc::new(Schema::new(vec![
1286            Field::new("city", DataType::Utf8, false),
1287            Field::new("lat", DataType::Float64, false),
1288            Field::new("lng", DataType::Float64, false),
1289        ]));
1290
1291        let file = File::open("test/data/uk_cities.csv").unwrap();
1292        let mut csv = ReaderBuilder::new(schema.clone()).build(file).unwrap();
1293        assert_eq!(schema, csv.schema());
1294        let batch = csv.next().unwrap().unwrap();
1295        assert_eq!(37, batch.num_rows());
1296        assert_eq!(3, batch.num_columns());
1297
1298        // access data from a primitive array
1299        let lat = batch.column(1).as_primitive::<Float64Type>();
1300        assert_eq!(57.653484, lat.value(0));
1301
1302        // access data from a string array (ListArray<u8>)
1303        let city = batch.column(0).as_string::<i32>();
1304
1305        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
1306    }
1307
1308    #[test]
1309    fn test_csv_schema_metadata() {
1310        let mut metadata = std::collections::HashMap::new();
1311        metadata.insert("foo".to_owned(), "bar".to_owned());
1312        let schema = Arc::new(Schema::new_with_metadata(
1313            vec![
1314                Field::new("city", DataType::Utf8, false),
1315                Field::new("lat", DataType::Float64, false),
1316                Field::new("lng", DataType::Float64, false),
1317            ],
1318            metadata.clone(),
1319        ));
1320
1321        let file = File::open("test/data/uk_cities.csv").unwrap();
1322
1323        let mut csv = ReaderBuilder::new(schema.clone()).build(file).unwrap();
1324        assert_eq!(schema, csv.schema());
1325        let batch = csv.next().unwrap().unwrap();
1326        assert_eq!(37, batch.num_rows());
1327        assert_eq!(3, batch.num_columns());
1328
1329        assert_eq!(&metadata, batch.schema().metadata());
1330    }
1331
1332    #[test]
1333    fn test_csv_reader_with_decimal() {
1334        let schema = Arc::new(Schema::new(vec![
1335            Field::new("city", DataType::Utf8, false),
1336            Field::new("lat", DataType::Decimal128(38, 6), false),
1337            Field::new("lng", DataType::Decimal256(76, 6), false),
1338        ]));
1339
1340        let file = File::open("test/data/decimal_test.csv").unwrap();
1341
1342        let mut csv = ReaderBuilder::new(schema).build(file).unwrap();
1343        let batch = csv.next().unwrap().unwrap();
1344        // access data from a primitive array
1345        let lat = batch
1346            .column(1)
1347            .as_any()
1348            .downcast_ref::<Decimal128Array>()
1349            .unwrap();
1350
1351        assert_eq!("57.653484", lat.value_as_string(0));
1352        assert_eq!("53.002666", lat.value_as_string(1));
1353        assert_eq!("52.412811", lat.value_as_string(2));
1354        assert_eq!("51.481583", lat.value_as_string(3));
1355        assert_eq!("12.123456", lat.value_as_string(4));
1356        assert_eq!("50.760000", lat.value_as_string(5));
1357        assert_eq!("0.123000", lat.value_as_string(6));
1358        assert_eq!("123.000000", lat.value_as_string(7));
1359        assert_eq!("123.000000", lat.value_as_string(8));
1360        assert_eq!("-50.760000", lat.value_as_string(9));
1361
1362        let lng = batch
1363            .column(2)
1364            .as_any()
1365            .downcast_ref::<Decimal256Array>()
1366            .unwrap();
1367
1368        assert_eq!("-3.335724", lng.value_as_string(0));
1369        assert_eq!("-2.179404", lng.value_as_string(1));
1370        assert_eq!("-1.778197", lng.value_as_string(2));
1371        assert_eq!("-3.179090", lng.value_as_string(3));
1372        assert_eq!("-3.179090", lng.value_as_string(4));
1373        assert_eq!("0.290472", lng.value_as_string(5));
1374        assert_eq!("0.290472", lng.value_as_string(6));
1375        assert_eq!("0.290472", lng.value_as_string(7));
1376        assert_eq!("0.290472", lng.value_as_string(8));
1377        assert_eq!("0.290472", lng.value_as_string(9));
1378    }
1379
1380    #[test]
1381    fn test_csv_reader_with_decimal_3264() {
1382        let schema = Arc::new(Schema::new(vec![
1383            Field::new("city", DataType::Utf8, false),
1384            Field::new("lat", DataType::Decimal32(9, 6), false),
1385            Field::new("lng", DataType::Decimal64(16, 6), false),
1386        ]));
1387
1388        let file = File::open("test/data/decimal_test.csv").unwrap();
1389
1390        let mut csv = ReaderBuilder::new(schema).build(file).unwrap();
1391        let batch = csv.next().unwrap().unwrap();
1392        // access data from a primitive array
1393        let lat = batch
1394            .column(1)
1395            .as_any()
1396            .downcast_ref::<Decimal32Array>()
1397            .unwrap();
1398
1399        assert_eq!("57.653484", lat.value_as_string(0));
1400        assert_eq!("53.002666", lat.value_as_string(1));
1401        assert_eq!("52.412811", lat.value_as_string(2));
1402        assert_eq!("51.481583", lat.value_as_string(3));
1403        assert_eq!("12.123456", lat.value_as_string(4));
1404        assert_eq!("50.760000", lat.value_as_string(5));
1405        assert_eq!("0.123000", lat.value_as_string(6));
1406        assert_eq!("123.000000", lat.value_as_string(7));
1407        assert_eq!("123.000000", lat.value_as_string(8));
1408        assert_eq!("-50.760000", lat.value_as_string(9));
1409
1410        let lng = batch
1411            .column(2)
1412            .as_any()
1413            .downcast_ref::<Decimal64Array>()
1414            .unwrap();
1415
1416        assert_eq!("-3.335724", lng.value_as_string(0));
1417        assert_eq!("-2.179404", lng.value_as_string(1));
1418        assert_eq!("-1.778197", lng.value_as_string(2));
1419        assert_eq!("-3.179090", lng.value_as_string(3));
1420        assert_eq!("-3.179090", lng.value_as_string(4));
1421        assert_eq!("0.290472", lng.value_as_string(5));
1422        assert_eq!("0.290472", lng.value_as_string(6));
1423        assert_eq!("0.290472", lng.value_as_string(7));
1424        assert_eq!("0.290472", lng.value_as_string(8));
1425        assert_eq!("0.290472", lng.value_as_string(9));
1426    }
1427
1428    #[test]
1429    fn test_csv_from_buf_reader() {
1430        let schema = Schema::new(vec![
1431            Field::new("city", DataType::Utf8, false),
1432            Field::new("lat", DataType::Float64, false),
1433            Field::new("lng", DataType::Float64, false),
1434        ]);
1435
1436        let file_with_headers = File::open("test/data/uk_cities_with_headers.csv").unwrap();
1437        let file_without_headers = File::open("test/data/uk_cities.csv").unwrap();
1438        let both_files = file_with_headers
1439            .chain(Cursor::new("\n".to_string()))
1440            .chain(file_without_headers);
1441        let mut csv = ReaderBuilder::new(Arc::new(schema))
1442            .with_header(true)
1443            .build(both_files)
1444            .unwrap();
1445        let batch = csv.next().unwrap().unwrap();
1446        assert_eq!(74, batch.num_rows());
1447        assert_eq!(3, batch.num_columns());
1448    }
1449
1450    #[test]
1451    fn test_csv_with_schema_inference() {
1452        let mut file = File::open("test/data/uk_cities_with_headers.csv").unwrap();
1453
1454        let (schema, _) = Format::default()
1455            .with_header(true)
1456            .infer_schema(&mut file, None)
1457            .unwrap();
1458
1459        file.rewind().unwrap();
1460        let builder = ReaderBuilder::new(Arc::new(schema)).with_header(true);
1461
1462        let mut csv = builder.build(file).unwrap();
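        // Inference marks every column nullable, hence `true` on each field below.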
1463        let expected_schema = Schema::new(vec![
1464            Field::new("city", DataType::Utf8, true),
1465            Field::new("lat", DataType::Float64, true),
1466            Field::new("lng", DataType::Float64, true),
1467        ]);
1468        assert_eq!(Arc::new(expected_schema), csv.schema());
1469        let batch = csv.next().unwrap().unwrap();
1470        assert_eq!(37, batch.num_rows());
1471        assert_eq!(3, batch.num_columns());
1472
1473        // access data from a primitive array
1474        let lat = batch
1475            .column(1)
1476            .as_any()
1477            .downcast_ref::<Float64Array>()
1478            .unwrap();
1479        assert_eq!(57.653484, lat.value(0));
1480
1481        // access data from a string array (StringArray)
1482        let city = batch
1483            .column(0)
1484            .as_any()
1485            .downcast_ref::<StringArray>()
1486            .unwrap();
1487
1488        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
1489    }
1490
1491    #[test]
1492    fn test_csv_with_schema_inference_no_headers() {
1493        let mut file = File::open("test/data/uk_cities.csv").unwrap();
1494
1495        let (schema, _) = Format::default().infer_schema(&mut file, None).unwrap();
1496        file.rewind().unwrap();
1497
1498        let mut csv = ReaderBuilder::new(Arc::new(schema)).build(file).unwrap();
1499
1500        // csv field names should be 'column_{number}'
1501        let schema = csv.schema();
1502        assert_eq!("column_1", schema.field(0).name());
1503        assert_eq!("column_2", schema.field(1).name());
1504        assert_eq!("column_3", schema.field(2).name());
1505        let batch = csv.next().unwrap().unwrap();
1506        let batch_schema = batch.schema();
1507
1508        assert_eq!(schema, batch_schema);
1509        assert_eq!(37, batch.num_rows());
1510        assert_eq!(3, batch.num_columns());
1511
1512        // access data from a primitive array
1513        let lat = batch
1514            .column(1)
1515            .as_any()
1516            .downcast_ref::<Float64Array>()
1517            .unwrap();
1518        assert_eq!(57.653484, lat.value(0));
1519
1520        // access data from a string array (StringArray)
1521        let city = batch
1522            .column(0)
1523            .as_any()
1524            .downcast_ref::<StringArray>()
1525            .unwrap();
1526
1527        assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
1528    }
1529
1530    #[test]
1531    fn test_csv_builder_with_bounds() {
1532        let mut file = File::open("test/data/uk_cities.csv").unwrap();
1533
1534        // Bound the read to lines 0 and 1 (the end passed to with_bounds is exclusive).
1535        let (schema, _) = Format::default().infer_schema(&mut file, None).unwrap();
1536        file.rewind().unwrap();
1537        let mut csv = ReaderBuilder::new(Arc::new(schema))
1538            .with_bounds(0, 2)
1539            .build(file)
1540            .unwrap();
1541        let batch = csv.next().unwrap().unwrap();
1542
1543        // access data from a string array (StringArray)
1544        let city = batch
1545            .column(0)
1546            .as_any()
1547            .downcast_ref::<StringArray>()
1548            .unwrap();
1549
1550        // The value on line 0 is within the bounds
1551        assert_eq!("Elgin, Scotland, the UK", city.value(0));
1552
1553        // The value on line 13 is outside of the bounds. Therefore
1554        // the call to .value() will panic.
1555        let result = std::panic::catch_unwind(|| city.value(13));
1556        assert!(result.is_err());
1557    }
1558
1559    #[test]
1560    fn test_csv_with_projection() {
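        // Project columns 0 and 1; the reader's reported schema and every batch
        // shrink to just the selected columns.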
1561        let schema = Arc::new(Schema::new(vec![
1562            Field::new("city", DataType::Utf8, false),
1563            Field::new("lat", DataType::Float64, false),
1564            Field::new("lng", DataType::Float64, false),
1565        ]));
1566
1567        let file = File::open("test/data/uk_cities.csv").unwrap();
1568
1569        let mut csv = ReaderBuilder::new(schema)
1570            .with_projection(vec![0, 1])
1571            .build(file)
1572            .unwrap();
1573
1574        let projected_schema = Arc::new(Schema::new(vec![
1575            Field::new("city", DataType::Utf8, false),
1576            Field::new("lat", DataType::Float64, false),
1577        ]));
1578        assert_eq!(projected_schema, csv.schema());
1579        let batch = csv.next().unwrap().unwrap();
1580        assert_eq!(projected_schema, batch.schema());
1581        assert_eq!(37, batch.num_rows());
1582        assert_eq!(2, batch.num_columns());
1583    }
1584
1585    #[test]
1586    fn test_csv_with_dictionary() {
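        // The city column is read directly as a dictionary-encoded array; it is cast
        // back to Utf8 below to compare the decoded string values.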
1587        let schema = Arc::new(Schema::new(vec![
1588            Field::new_dictionary("city", DataType::Int32, DataType::Utf8, false),
1589            Field::new("lat", DataType::Float64, false),
1590            Field::new("lng", DataType::Float64, false),
1591        ]));
1592
1593        let file = File::open("test/data/uk_cities.csv").unwrap();
1594
1595        let mut csv = ReaderBuilder::new(schema)
1596            .with_projection(vec![0, 1])
1597            .build(file)
1598            .unwrap();
1599
1600        let projected_schema = Arc::new(Schema::new(vec![
1601            Field::new_dictionary("city", DataType::Int32, DataType::Utf8, false),
1602            Field::new("lat", DataType::Float64, false),
1603        ]));
1604        assert_eq!(projected_schema, csv.schema());
1605        let batch = csv.next().unwrap().unwrap();
1606        assert_eq!(projected_schema, batch.schema());
1607        assert_eq!(37, batch.num_rows());
1608        assert_eq!(2, batch.num_columns());
1609
1610        let strings = arrow_cast::cast(batch.column(0), &DataType::Utf8).unwrap();
1611        let strings = strings.as_string::<i32>();
1612
1613        assert_eq!(strings.value(0), "Elgin, Scotland, the UK");
1614        assert_eq!(strings.value(4), "Eastbourne, East Sussex, UK");
1615        assert_eq!(strings.value(29), "Uckfield, East Sussex, UK");
1616    }
1617
1618    #[test]
1619    fn test_csv_with_nullable_dictionary() {
1620        let key_types = vec![
1621            DataType::Int8,
1622            DataType::Int16,
1623            DataType::Int32,
1624            DataType::Int64,
1625            DataType::UInt8,
1626            DataType::UInt16,
1627            DataType::UInt32,
1628            DataType::UInt64,
1629        ];
1630        for data_type in key_types {
1631            let file = File::open("test/data/dictionary_nullable_test.csv").unwrap();
1632            let dictionary_type =
1633                DataType::Dictionary(Box::new(data_type), Box::new(DataType::Utf8));
1634            let schema = Arc::new(Schema::new(vec![
1635                Field::new("id", DataType::Utf8, false),
1636                Field::new("name", dictionary_type.clone(), true),
1637            ]));
1638
1639            let mut csv = ReaderBuilder::new(schema)
1640                .build(file.try_clone().unwrap())
1641                .unwrap();
1642
1643            let batch = csv.next().unwrap().unwrap();
1644            assert_eq!(3, batch.num_rows());
1645            assert_eq!(2, batch.num_columns());
1646
1647            let names = arrow_cast::cast(batch.column(1), &dictionary_type).unwrap();
1648            assert!(!names.is_null(2));
1649            assert!(names.is_null(1));
1650        }
1651    }

1652    #[test]
1653    fn test_nulls() {
1654        let schema = Arc::new(Schema::new(vec![
1655            Field::new("c_int", DataType::UInt64, false),
1656            Field::new("c_float", DataType::Float32, true),
1657            Field::new("c_string", DataType::Utf8, true),
1658            Field::new("c_bool", DataType::Boolean, false),
1659        ]));
1660
1661        let file = File::open("test/data/null_test.csv").unwrap();
1662
1663        let mut csv = ReaderBuilder::new(schema)
1664            .with_header(true)
1665            .build(file)
1666            .unwrap();
1667
1668        let batch = csv.next().unwrap().unwrap();
1669
1670        assert!(!batch.column(1).is_null(0));
1671        assert!(!batch.column(1).is_null(1));
1672        assert!(batch.column(1).is_null(2));
1673        assert!(!batch.column(1).is_null(3));
1674        assert!(!batch.column(1).is_null(4));
1675    }
1676
1677    #[test]
1678    fn test_init_nulls() {
1679        let schema = Arc::new(Schema::new(vec![
1680            Field::new("c_int", DataType::UInt64, true),
1681            Field::new("c_float", DataType::Float32, true),
1682            Field::new("c_string", DataType::Utf8, true),
1683            Field::new("c_bool", DataType::Boolean, true),
1684            Field::new("c_null", DataType::Null, true),
1685        ]));
1686        let file = File::open("test/data/init_null_test.csv").unwrap();
1687
1688        let mut csv = ReaderBuilder::new(schema)
1689            .with_header(true)
1690            .build(file)
1691            .unwrap();
1692
1693        let batch = csv.next().unwrap().unwrap();
1694
1695        assert!(batch.column(1).is_null(0));
1696        assert!(!batch.column(1).is_null(1));
1697        assert!(batch.column(1).is_null(2));
1698        assert!(!batch.column(1).is_null(3));
1699        assert!(!batch.column(1).is_null(4));
1700    }
1701
1702    #[test]
1703    fn test_init_nulls_with_inference() {
1704        let format = Format::default().with_header(true).with_delimiter(b',');
1705
1706        let mut file = File::open("test/data/init_null_test.csv").unwrap();
1707        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
1708        file.rewind().unwrap();
1709
1710        let expected_schema = Schema::new(vec![
1711            Field::new("c_int", DataType::Int64, true),
1712            Field::new("c_float", DataType::Float64, true),
1713            Field::new("c_string", DataType::Utf8, true),
1714            Field::new("c_bool", DataType::Boolean, true),
1715            Field::new("c_null", DataType::Null, true),
1716        ]);
1717        assert_eq!(schema, expected_schema);
1718
1719        let mut csv = ReaderBuilder::new(Arc::new(schema))
1720            .with_format(format)
1721            .build(file)
1722            .unwrap();
1723
1724        let batch = csv.next().unwrap().unwrap();
1725
1726        assert!(batch.column(1).is_null(0));
1727        assert!(!batch.column(1).is_null(1));
1728        assert!(batch.column(1).is_null(2));
1729        assert!(!batch.column(1).is_null(3));
1730        assert!(!batch.column(1).is_null(4));
1731    }
1732
1733    #[test]
1734    fn test_custom_nulls() {
1735        let schema = Arc::new(Schema::new(vec![
1736            Field::new("c_int", DataType::UInt64, true),
1737            Field::new("c_float", DataType::Float32, true),
1738            Field::new("c_string", DataType::Utf8, true),
1739            Field::new("c_bool", DataType::Boolean, true),
1740        ]));
1741
1742        let file = File::open("test/data/custom_null_test.csv").unwrap();
1743
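        // Treat fields matching this anchored regex (the exact string "nil") as NULL.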
1744        let null_regex = Regex::new("^nil$").unwrap();
1745
1746        let mut csv = ReaderBuilder::new(schema)
1747            .with_header(true)
1748            .with_null_regex(null_regex)
1749            .build(file)
1750            .unwrap();
1751
1752        let batch = csv.next().unwrap().unwrap();
1753
1754        // "nil"s should be NULL
1755        assert!(batch.column(0).is_null(1));
1756        assert!(batch.column(1).is_null(2));
1757        assert!(batch.column(3).is_null(4));
1758        assert!(batch.column(2).is_null(3));
1759        assert!(!batch.column(2).is_null(4));
1760    }
1761
1762    #[test]
1763    fn test_nulls_with_inference() {
1764        let mut file = File::open("test/data/various_types.csv").unwrap();
1765        let format = Format::default().with_header(true).with_delimiter(b'|');
1766
1767        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
1768        file.rewind().unwrap();
1769
1770        let builder = ReaderBuilder::new(Arc::new(schema))
1771            .with_format(format)
1772            .with_batch_size(512)
1773            .with_projection(vec![0, 1, 2, 3, 4, 5]);
1774
1775        let mut csv = builder.build(file).unwrap();
1776        let batch = csv.next().unwrap().unwrap();
1777
1778        assert_eq!(10, batch.num_rows());
1779        assert_eq!(6, batch.num_columns());
1780
1781        let schema = batch.schema();
1782
1783        assert_eq!(&DataType::Int64, schema.field(0).data_type());
1784        assert_eq!(&DataType::Float64, schema.field(1).data_type());
1785        assert_eq!(&DataType::Float64, schema.field(2).data_type());
1786        assert_eq!(&DataType::Boolean, schema.field(3).data_type());
1787        assert_eq!(&DataType::Date32, schema.field(4).data_type());
1788        assert_eq!(
1789            &DataType::Timestamp(TimeUnit::Second, None),
1790            schema.field(5).data_type()
1791        );
1792
1793        let names: Vec<&str> = schema.fields().iter().map(|x| x.name().as_str()).collect();
1794        assert_eq!(
1795            names,
1796            vec![
1797                "c_int",
1798                "c_float",
1799                "c_string",
1800                "c_bool",
1801                "c_date",
1802                "c_datetime"
1803            ]
1804        );
1805
1806        assert!(schema.field(0).is_nullable());
1807        assert!(schema.field(1).is_nullable());
1808        assert!(schema.field(2).is_nullable());
1809        assert!(schema.field(3).is_nullable());
1810        assert!(schema.field(4).is_nullable());
1811        assert!(schema.field(5).is_nullable());
1812
1813        assert!(!batch.column(1).is_null(0));
1814        assert!(!batch.column(1).is_null(1));
1815        assert!(batch.column(1).is_null(2));
1816        assert!(!batch.column(1).is_null(3));
1817        assert!(!batch.column(1).is_null(4));
1818    }
1819
1820    #[test]
1821    fn test_custom_nulls_with_inference() {
1822        let mut file = File::open("test/data/custom_null_test.csv").unwrap();
1823
1824        let null_regex = Regex::new("^nil$").unwrap();
1825
1826        let format = Format::default()
1827            .with_header(true)
1828            .with_null_regex(null_regex);
1829
1830        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
1831        file.rewind().unwrap();
1832
1833        let expected_schema = Schema::new(vec![
1834            Field::new("c_int", DataType::Int64, true),
1835            Field::new("c_float", DataType::Float64, true),
1836            Field::new("c_string", DataType::Utf8, true),
1837            Field::new("c_bool", DataType::Boolean, true),
1838        ]);
1839
1840        assert_eq!(schema, expected_schema);
1841
1842        let builder = ReaderBuilder::new(Arc::new(schema))
1843            .with_format(format)
1844            .with_batch_size(512)
1845            .with_projection(vec![0, 1, 2, 3]);
1846
1847        let mut csv = builder.build(file).unwrap();
1848        let batch = csv.next().unwrap().unwrap();
1849
1850        assert_eq!(5, batch.num_rows());
1851        assert_eq!(4, batch.num_columns());
1852
1853        assert_eq!(batch.schema().as_ref(), &expected_schema);
1854    }
1855
1856    #[test]
1857    fn test_scientific_notation_with_inference() {
1858        let mut file = File::open("test/data/scientific_notation_test.csv").unwrap();
1859        let format = Format::default().with_header(false).with_delimiter(b',');
1860
1861        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
1862        file.rewind().unwrap();
1863
1864        let builder = ReaderBuilder::new(Arc::new(schema))
1865            .with_format(format)
1866            .with_batch_size(512)
1867            .with_projection(vec![0, 1]);
1868
1869        let mut csv = builder.build(file).unwrap();
1870        let batch = csv.next().unwrap().unwrap();
1871
1872        let schema = batch.schema();
1873
1874        assert_eq!(&DataType::Float64, schema.field(0).data_type());
1875    }
1876
1877    fn invalid_csv_helper(file_name: &str) -> String {
1878        let file = File::open(file_name).unwrap();
1879        let schema = Schema::new(vec![
1880            Field::new("c_int", DataType::UInt64, false),
1881            Field::new("c_float", DataType::Float32, false),
1882            Field::new("c_string", DataType::Utf8, false),
1883            Field::new("c_bool", DataType::Boolean, false),
1884        ]);
1885
1886        let builder = ReaderBuilder::new(Arc::new(schema))
1887            .with_header(true)
1888            .with_delimiter(b'|')
1889            .with_batch_size(512)
1890            .with_projection(vec![0, 1, 2, 3]);
1891
1892        let mut csv = builder.build(file).unwrap();
1893
1894        csv.next().unwrap().unwrap_err().to_string()
1895    }
1896
1897    #[test]
1898    fn test_parse_invalid_csv_float() {
1899        let file_name = "test/data/various_invalid_types/invalid_float.csv";
1900
1901        let error = invalid_csv_helper(file_name);
1902        assert_eq!(
1903            "Parser error: Error while parsing value '4.x4' as type 'Float32' for column 1 at line 4. Row data: '[4,4.x4,,false]'",
1904            error
1905        );
1906    }
1907
1908    #[test]
1909    fn test_parse_invalid_csv_int() {
1910        let file_name = "test/data/various_invalid_types/invalid_int.csv";
1911
1912        let error = invalid_csv_helper(file_name);
1913        assert_eq!(
1914            "Parser error: Error while parsing value '2.3' as type 'UInt64' for column 0 at line 2. Row data: '[2.3,2.2,2.22,false]'",
1915            error
1916        );
1917    }
1918
1919    #[test]
1920    fn test_parse_invalid_csv_bool() {
1921        let file_name = "test/data/various_invalid_types/invalid_bool.csv";
1922
1923        let error = invalid_csv_helper(file_name);
1924        assert_eq!(
1925            "Parser error: Error while parsing value 'none' as type 'Boolean' for column 3 at line 2. Row data: '[2,2.2,2.22,none]'",
1926            error
1927        );
1928    }
1929
1930    /// Infer the data type of a single field value
1931    fn infer_field_schema(string: &str) -> DataType {
1932        let mut v = InferredDataType::default();
1933        v.update(string);
1934        v.get()
1935    }
1936
1937    #[test]
1938    fn test_infer_field_schema() {
1939        assert_eq!(infer_field_schema("A"), DataType::Utf8);
1940        assert_eq!(infer_field_schema("\"123\""), DataType::Utf8);
1941        assert_eq!(infer_field_schema("10"), DataType::Int64);
1942        assert_eq!(infer_field_schema("10.2"), DataType::Float64);
1943        assert_eq!(infer_field_schema(".2"), DataType::Float64);
1944        assert_eq!(infer_field_schema("2."), DataType::Float64);
1945        assert_eq!(infer_field_schema("NaN"), DataType::Float64);
1946        assert_eq!(infer_field_schema("nan"), DataType::Float64);
1947        assert_eq!(infer_field_schema("inf"), DataType::Float64);
1948        assert_eq!(infer_field_schema("-inf"), DataType::Float64);
1949        assert_eq!(infer_field_schema("true"), DataType::Boolean);
1950        assert_eq!(infer_field_schema("trUe"), DataType::Boolean);
1951        assert_eq!(infer_field_schema("false"), DataType::Boolean);
1952        assert_eq!(infer_field_schema("2020-11-08"), DataType::Date32);
1953        assert_eq!(
1954            infer_field_schema("2020-11-08T14:20:01"),
1955            DataType::Timestamp(TimeUnit::Second, None)
1956        );
1957        assert_eq!(
1958            infer_field_schema("2020-11-08 14:20:01"),
1959            DataType::Timestamp(TimeUnit::Second, None)
1960        );
1965        assert_eq!(infer_field_schema("-5.13"), DataType::Float64);
1966        assert_eq!(infer_field_schema("0.1300"), DataType::Float64);
1967        assert_eq!(
1968            infer_field_schema("2021-12-19 13:12:30.921"),
1969            DataType::Timestamp(TimeUnit::Millisecond, None)
1970        );
1971        assert_eq!(
1972            infer_field_schema("2021-12-19T13:12:30.123456789"),
1973            DataType::Timestamp(TimeUnit::Nanosecond, None)
1974        );
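        // Magnitudes just beyond the i64 range cannot be inferred as Int64 and fall back to Utf8.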
1975        assert_eq!(infer_field_schema("–9223372036854775809"), DataType::Utf8);
1976        assert_eq!(infer_field_schema("9223372036854775808"), DataType::Utf8);
1977    }
1978
1979    #[test]
1980    fn parse_date32() {
1981        assert_eq!(Date32Type::parse("1970-01-01").unwrap(), 0);
1982        assert_eq!(Date32Type::parse("2020-03-15").unwrap(), 18336);
1983        assert_eq!(Date32Type::parse("1945-05-08").unwrap(), -9004);
1984    }
1985
1986    #[test]
1987    fn parse_time() {
1988        assert_eq!(
1989            Time64NanosecondType::parse("12:10:01.123456789 AM"),
1990            Some(601_123_456_789)
1991        );
1992        assert_eq!(
1993            Time64MicrosecondType::parse("12:10:01.123456 am"),
1994            Some(601_123_456)
1995        );
1996        assert_eq!(
1997            Time32MillisecondType::parse("2:10:01.12 PM"),
1998            Some(51_001_120)
1999        );
2000        assert_eq!(Time32SecondType::parse("2:10:01 pm"), Some(51_001));
2001    }
2002
2003    #[test]
2004    fn parse_date64() {
2005        assert_eq!(Date64Type::parse("1970-01-01T00:00:00").unwrap(), 0);
2006        assert_eq!(
2007            Date64Type::parse("2018-11-13T17:11:10").unwrap(),
2008            1542129070000
2009        );
2010        assert_eq!(
2011            Date64Type::parse("2018-11-13T17:11:10.011").unwrap(),
2012            1542129070011
2013        );
2014        assert_eq!(
2015            Date64Type::parse("1900-02-28T12:34:56").unwrap(),
2016            -2203932304000
2017        );
2018        assert_eq!(
2019            Date64Type::parse_formatted("1900-02-28 12:34:56", "%Y-%m-%d %H:%M:%S").unwrap(),
2020            -2203932304000
2021        );
2022        assert_eq!(
2023            Date64Type::parse_formatted("1900-02-28 12:34:56+0030", "%Y-%m-%d %H:%M:%S%z").unwrap(),
2024            -2203932304000 - (30 * 60 * 1000)
2025        );
2026    }
2027
2028    fn test_parse_timestamp_impl<T: ArrowTimestampType>(
2029        timezone: Option<Arc<str>>,
2030        expected: &[i64],
2031    ) {
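        // The same wall-clock string in three forms: no offset, explicit UTC ("Z"),
        // and a +02:00 offset. Values without an offset are interpreted in the
        // column's timezone (UTC when None); offset forms denote absolute instants.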
2032        let csv = [
2033            "1970-01-01T00:00:00",
2034            "1970-01-01T00:00:00Z",
2035            "1970-01-01T00:00:00+02:00",
2036        ]
2037        .join("\n");
2038        let schema = Arc::new(Schema::new(vec![Field::new(
2039            "field",
2040            DataType::Timestamp(T::UNIT, timezone.clone()),
2041            true,
2042        )]));
2043
2044        let mut decoder = ReaderBuilder::new(schema).build_decoder();
2045
2046        let decoded = decoder.decode(csv.as_bytes()).unwrap();
2047        assert_eq!(decoded, csv.len());
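        // Decoding an empty slice signals end of input, letting the decoder delimit
        // the final row, which has no trailing newline.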
2048        decoder.decode(&[]).unwrap();
2049
2050        let batch = decoder.flush().unwrap().unwrap();
2051        assert_eq!(batch.num_columns(), 1);
2052        assert_eq!(batch.num_rows(), 3);
2053        let col = batch.column(0).as_primitive::<T>();
2054        assert_eq!(col.values(), expected);
2055        assert_eq!(col.data_type(), &DataType::Timestamp(T::UNIT, timezone));
2056    }
2057
2058    #[test]
2059    fn test_parse_timestamp() {
2060        test_parse_timestamp_impl::<TimestampNanosecondType>(None, &[0, 0, -7_200_000_000_000]);
2061        test_parse_timestamp_impl::<TimestampNanosecondType>(
2062            Some("+00:00".into()),
2063            &[0, 0, -7_200_000_000_000],
2064        );
2065        test_parse_timestamp_impl::<TimestampNanosecondType>(
2066            Some("-05:00".into()),
2067            &[18_000_000_000_000, 0, -7_200_000_000_000],
2068        );
2069        test_parse_timestamp_impl::<TimestampMicrosecondType>(
2070            Some("-03".into()),
2071            &[10_800_000_000, 0, -7_200_000_000],
2072        );
2073        test_parse_timestamp_impl::<TimestampMillisecondType>(
2074            Some("-03".into()),
2075            &[10_800_000, 0, -7_200_000],
2076        );
2077        test_parse_timestamp_impl::<TimestampSecondType>(Some("-03".into()), &[10_800, 0, -7_200]);
2078    }
2079
2080    #[test]
2081    fn test_infer_schema_from_multiple_files() {
2082        let mut csv1 = NamedTempFile::new().unwrap();
2083        let mut csv2 = NamedTempFile::new().unwrap();
2084        let csv3 = NamedTempFile::new().unwrap(); // empty csv file should be skipped
2085        let mut csv4 = NamedTempFile::new().unwrap();
2086        writeln!(csv1, "c1,c2,c3").unwrap();
2087        writeln!(csv1, "1,\"foo\",0.5").unwrap();
2088        writeln!(csv1, "3,\"bar\",1").unwrap();
2089        writeln!(csv1, "3,\"bar\",2e-06").unwrap();
2090        // reading csv2 will set c2 to optional
2091        writeln!(csv2, "c1,c2,c3,c4").unwrap();
2092        writeln!(csv2, "10,,3.14,true").unwrap();
2093        // csv4 has an empty c3, but the record limit below stops reading before csv4
2094        writeln!(csv4, "c1,c2,c3").unwrap();
2095        writeln!(csv4, "10,\"foo\",").unwrap();
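        // The per-file schemas are merged: the union of columns is kept (c4 comes
        // only from csv2) and conflicting types widen (c3 unifies to Float64).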
2096
2097        let schema = infer_schema_from_files(
2098            &[
2099                csv3.path().to_str().unwrap().to_string(),
2100                csv1.path().to_str().unwrap().to_string(),
2101                csv2.path().to_str().unwrap().to_string(),
2102                csv4.path().to_str().unwrap().to_string(),
2103            ],
2104            b',',
2105            Some(4), // only csv1 and csv2 should be read
2106            true,
2107        )
2108        .unwrap();
2109
2110        assert_eq!(schema.fields().len(), 4);
2111        assert!(schema.field(0).is_nullable());
2112        assert!(schema.field(1).is_nullable());
2113        assert!(schema.field(2).is_nullable());
2114        assert!(schema.field(3).is_nullable());
2115
2116        assert_eq!(&DataType::Int64, schema.field(0).data_type());
2117        assert_eq!(&DataType::Utf8, schema.field(1).data_type());
2118        assert_eq!(&DataType::Float64, schema.field(2).data_type());
2119        assert_eq!(&DataType::Boolean, schema.field(3).data_type());
2120    }
2121
2122    #[test]
2123    fn test_bounded() {
2124        let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]);
2125        let data = [
2126            vec!["0"],
2127            vec!["1"],
2128            vec!["2"],
2129            vec!["3"],
2130            vec!["4"],
2131            vec!["5"],
2132            vec!["6"],
2133        ];
2134
2135        let data = data
2136            .iter()
2137            .map(|x| x.join(","))
2138            .collect::<Vec<_>>()
2139            .join("\n");
2140        let data = data.as_bytes();
2141
2142        let reader = std::io::Cursor::new(data);
2143
2144        let mut csv = ReaderBuilder::new(Arc::new(schema))
2145            .with_batch_size(2)
2146            .with_projection(vec![0])
2147            .with_bounds(2, 6)
2148            .build_buffered(reader)
2149            .unwrap();
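        // Bounds select rows 2..6 (end exclusive); with batch_size 2 they arrive
        // as two batches: [2, 3] and [4, 5].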
2150
2151        let batch = csv.next().unwrap().unwrap();
2152        let a = batch.column(0);
2153        let a = a.as_any().downcast_ref::<UInt32Array>().unwrap();
2154        assert_eq!(a, &UInt32Array::from(vec![2, 3]));
2155
2156        let batch = csv.next().unwrap().unwrap();
2157        let a = batch.column(0);
2158        let a = a.as_any().downcast_ref::<UInt32Array>().unwrap();
2159        assert_eq!(a, &UInt32Array::from(vec![4, 5]));
2160
2161        assert!(csv.next().is_none());
2162    }
2163
2164    #[test]
2165    fn test_empty_projection() {
2166        let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]);
2167        let data = [vec!["0"], vec!["1"]];
2168
2169        let data = data
2170            .iter()
2171            .map(|x| x.join(","))
2172            .collect::<Vec<_>>()
2173            .join("\n");
2174
2175        let mut csv = ReaderBuilder::new(Arc::new(schema))
2176            .with_batch_size(2)
2177            .with_projection(vec![])
2178            .build_buffered(Cursor::new(data.as_bytes()))
2179            .unwrap();
2180
2181        let batch = csv.next().unwrap().unwrap();
2182        assert_eq!(batch.columns().len(), 0);
2183        assert_eq!(batch.num_rows(), 2);
2184
2185        assert!(csv.next().is_none());
2186    }
2187
2188    #[test]
2189    fn test_parsing_bool() {
2190        // Encode the expected behavior of boolean parsing
2191        assert_eq!(Some(true), parse_bool("true"));
2192        assert_eq!(Some(true), parse_bool("tRUe"));
2193        assert_eq!(Some(true), parse_bool("True"));
2194        assert_eq!(Some(true), parse_bool("TRUE"));
2195        assert_eq!(None, parse_bool("t"));
2196        assert_eq!(None, parse_bool("T"));
2197        assert_eq!(None, parse_bool(""));
2198
2199        assert_eq!(Some(false), parse_bool("false"));
2200        assert_eq!(Some(false), parse_bool("fALse"));
2201        assert_eq!(Some(false), parse_bool("False"));
2202        assert_eq!(Some(false), parse_bool("FALSE"));
2203        assert_eq!(None, parse_bool("f"));
2204        assert_eq!(None, parse_bool("F"));
2205        assert_eq!(None, parse_bool(""));
2206    }
2207
2208    #[test]
2209    fn test_parsing_float() {
2210        assert_eq!(Some(12.34), Float64Type::parse("12.34"));
2211        assert_eq!(Some(-12.34), Float64Type::parse("-12.34"));
2212        assert_eq!(Some(12.0), Float64Type::parse("12"));
2213        assert_eq!(Some(0.0), Float64Type::parse("0"));
2214        assert_eq!(Some(2.0), Float64Type::parse("2."));
2215        assert_eq!(Some(0.2), Float64Type::parse(".2"));
2216        assert!(Float64Type::parse("nan").unwrap().is_nan());
2217        assert!(Float64Type::parse("NaN").unwrap().is_nan());
2218        assert!(Float64Type::parse("inf").unwrap().is_infinite());
2219        assert!(Float64Type::parse("inf").unwrap().is_sign_positive());
2220        assert!(Float64Type::parse("-inf").unwrap().is_infinite());
2221        assert!(Float64Type::parse("-inf").unwrap().is_sign_negative());
2222        assert_eq!(None, Float64Type::parse(""));
2223        assert_eq!(None, Float64Type::parse("dd"));
2224        assert_eq!(None, Float64Type::parse("12.34.56"));
2225    }
2226
2227    #[test]
2228    fn test_non_std_quote() {
2229        let schema = Schema::new(vec![
2230            Field::new("text1", DataType::Utf8, false),
2231            Field::new("text2", DataType::Utf8, false),
2232        ]);
2233        let builder = ReaderBuilder::new(Arc::new(schema))
2234            .with_header(false)
2235            .with_quote(b'~'); // default is ", change to ~
2236
2237        let mut csv_text = Vec::new();
2238        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
2239        for index in 0..10 {
2240            let text1 = format!("id{index:}");
2241            let text2 = format!("value{index:}");
2242            csv_writer
2243                .write_fmt(format_args!("~{text1}~,~{text2}~\r\n"))
2244                .unwrap();
2245        }
2246        let mut csv_reader = std::io::Cursor::new(&csv_text);
2247        let mut reader = builder.build(&mut csv_reader).unwrap();
2248        let batch = reader.next().unwrap().unwrap();
2249        let col0 = batch.column(0);
2250        assert_eq!(col0.len(), 10);
2251        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
2252        assert_eq!(col0_arr.value(0), "id0");
2253        let col1 = batch.column(1);
2254        assert_eq!(col1.len(), 10);
2255        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
2256        assert_eq!(col1_arr.value(5), "value5");
2257    }
2258
2259    #[test]
2260    fn test_non_std_escape() {
2261        let schema = Schema::new(vec![
2262            Field::new("text1", DataType::Utf8, false),
2263            Field::new("text2", DataType::Utf8, false),
2264        ]);
2265        let builder = ReaderBuilder::new(Arc::new(schema))
2266            .with_header(false)
2267            .with_escape(b'\\'); // default is None, change to \
2268
2269        let mut csv_text = Vec::new();
2270        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
2271        for index in 0..10 {
2272            let text1 = format!("id{index:}");
2273            let text2 = format!("value\\\"{index:}");
2274            csv_writer
2275                .write_fmt(format_args!("\"{text1}\",\"{text2}\"\r\n"))
2276                .unwrap();
2277        }
2278        let mut csv_reader = std::io::Cursor::new(&csv_text);
2279        let mut reader = builder.build(&mut csv_reader).unwrap();
2280        let batch = reader.next().unwrap().unwrap();
2281        let col0 = batch.column(0);
2282        assert_eq!(col0.len(), 10);
2283        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
2284        assert_eq!(col0_arr.value(0), "id0");
2285        let col1 = batch.column(1);
2286        assert_eq!(col1.len(), 10);
2287        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
2288        assert_eq!(col1_arr.value(5), "value\"5");
2289    }
2290
2291    #[test]
2292    fn test_non_std_terminator() {
2293        let schema = Schema::new(vec![
2294            Field::new("text1", DataType::Utf8, false),
2295            Field::new("text2", DataType::Utf8, false),
2296        ]);
2297        let builder = ReaderBuilder::new(Arc::new(schema))
2298            .with_header(false)
2299            .with_terminator(b'\n'); // default is CRLF, change to LF
2300
2301        let mut csv_text = Vec::new();
2302        let mut csv_writer = std::io::Cursor::new(&mut csv_text);
2303        for index in 0..10 {
2304            let text1 = format!("id{index:}");
2305            let text2 = format!("value{index:}");
2306            csv_writer
2307                .write_fmt(format_args!("\"{text1}\",\"{text2}\"\n"))
2308                .unwrap();
2309        }
2310        let mut csv_reader = std::io::Cursor::new(&csv_text);
2311        let mut reader = builder.build(&mut csv_reader).unwrap();
2312        let batch = reader.next().unwrap().unwrap();
2313        let col0 = batch.column(0);
2314        assert_eq!(col0.len(), 10);
2315        let col0_arr = col0.as_any().downcast_ref::<StringArray>().unwrap();
2316        assert_eq!(col0_arr.value(0), "id0");
2317        let col1 = batch.column(1);
2318        assert_eq!(col1.len(), 10);
2319        let col1_arr = col1.as_any().downcast_ref::<StringArray>().unwrap();
2320        assert_eq!(col1_arr.value(5), "value5");
2321    }
2322
2323    #[test]
2324    fn test_header_bounds() {
2325        let csv = "a,b\na,b\na,b\na,b\na,b\n";
2326        let tests = [
2327            (None, false, 5),
2328            (None, true, 4),
2329            (Some((0, 4)), false, 4),
2330            (Some((1, 4)), false, 3),
2331            (Some((0, 4)), true, 4),
2332            (Some((1, 4)), true, 3),
2333        ];
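        // Each case is (bounds, has_header, expected_rows); bounds index data rows,
        // counted after any header row has been skipped.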
2334        let schema = Arc::new(Schema::new(vec![
2335            Field::new("a", DataType::Utf8, false),
2336            Field::new("a", DataType::Utf8, false),
2337        ]));
2338
2339        for (idx, (bounds, has_header, expected)) in tests.into_iter().enumerate() {
2340            let mut reader = ReaderBuilder::new(schema.clone()).with_header(has_header);
2341            if let Some((start, end)) = bounds {
2342                reader = reader.with_bounds(start, end);
2343            }
2344            let b = reader
2345                .build_buffered(Cursor::new(csv.as_bytes()))
2346                .unwrap()
2347                .next()
2348                .unwrap()
2349                .unwrap();
2350            assert_eq!(b.num_rows(), expected, "{idx}");
2351        }
2352    }
2353
2354    #[test]
2355    fn test_null_boolean() {
2356        let csv = "true,false\nFalse,True\n,True\nFalse,";
2357        let schema = Arc::new(Schema::new(vec![
2358            Field::new("a", DataType::Boolean, true),
2359            Field::new("a", DataType::Boolean, true),
2360        ]));
2361
2362        let b = ReaderBuilder::new(schema)
2363            .build_buffered(Cursor::new(csv.as_bytes()))
2364            .unwrap()
2365            .next()
2366            .unwrap()
2367            .unwrap();
2368
2369        assert_eq!(b.num_rows(), 4);
2370        assert_eq!(b.num_columns(), 2);
2371
2372        let c = b.column(0).as_boolean();
2373        assert_eq!(c.null_count(), 1);
2374        assert!(c.value(0));
2375        assert!(!c.value(1));
2376        assert!(c.is_null(2));
2377        assert!(!c.value(3));
2378
2379        let c = b.column(1).as_boolean();
2380        assert_eq!(c.null_count(), 1);
2381        assert!(!c.value(0));
2382        assert!(c.value(1));
2383        assert!(c.value(2));
2384        assert!(c.is_null(3));
2385    }
2386
2387    #[test]
2388    fn test_truncated_rows() {
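        // "4,5" is missing its last field; with truncated_rows enabled the missing
        // value becomes null, while with it disabled the same input is a CSV error.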
2389        let data = "a,b,c\n1,2,3\n4,5\n\n6,7,8";
2390        let schema = Arc::new(Schema::new(vec![
2391            Field::new("a", DataType::Int32, true),
2392            Field::new("b", DataType::Int32, true),
2393            Field::new("c", DataType::Int32, true),
2394        ]));
2395
2396        let reader = ReaderBuilder::new(schema.clone())
2397            .with_header(true)
2398            .with_truncated_rows(true)
2399            .build(Cursor::new(data))
2400            .unwrap();
2401
2402        let batches = reader.collect::<Result<Vec<_>, _>>();
2403        assert!(batches.is_ok());
2404        let batch = batches.unwrap().into_iter().next().unwrap();
2405        // Empty rows are skipped by the underlying csv parser
2406        assert_eq!(batch.num_rows(), 3);
2407
2408        let reader = ReaderBuilder::new(schema.clone())
2409            .with_header(true)
2410            .with_truncated_rows(false)
2411            .build(Cursor::new(data))
2412            .unwrap();
2413
2414        let batches = reader.collect::<Result<Vec<_>, _>>();
2415        assert!(match batches {
2416            Err(ArrowError::CsvError(e)) => e.to_string().contains("incorrect number of fields"),
2417            _ => false,
2418        });
2419    }
2420
2421    #[test]
2422    fn test_truncated_rows_csv() {
2423        let file = File::open("test/data/truncated_rows.csv").unwrap();
2424        let schema = Arc::new(Schema::new(vec![
2425            Field::new("Name", DataType::Utf8, true),
2426            Field::new("Age", DataType::UInt32, true),
2427            Field::new("Occupation", DataType::Utf8, true),
2428            Field::new("DOB", DataType::Date32, true),
2429        ]));
2430        let reader = ReaderBuilder::new(schema.clone())
2431            .with_header(true)
2432            .with_batch_size(24)
2433            .with_truncated_rows(true);
2434        let csv = reader.build(file).unwrap();
2435        let batches = csv.collect::<Result<Vec<_>, _>>().unwrap();
2436
2437        assert_eq!(batches.len(), 1);
2438        let batch = &batches[0];
2439        assert_eq!(batch.num_rows(), 6);
2440        assert_eq!(batch.num_columns(), 4);
2441        let name = batch
2442            .column(0)
2443            .as_any()
2444            .downcast_ref::<StringArray>()
2445            .unwrap();
2446        let age = batch
2447            .column(1)
2448            .as_any()
2449            .downcast_ref::<UInt32Array>()
2450            .unwrap();
2451        let occupation = batch
2452            .column(2)
2453            .as_any()
2454            .downcast_ref::<StringArray>()
2455            .unwrap();
2456        let dob = batch
2457            .column(3)
2458            .as_any()
2459            .downcast_ref::<Date32Array>()
2460            .unwrap();
2461
2462        assert_eq!(name.value(0), "A1");
2463        assert_eq!(name.value(1), "B2");
2464        assert!(name.is_null(2));
2465        assert_eq!(name.value(3), "C3");
2466        assert_eq!(name.value(4), "D4");
2467        assert_eq!(name.value(5), "E5");
2468
2469        assert_eq!(age.value(0), 34);
2470        assert_eq!(age.value(1), 29);
2471        assert!(age.is_null(2));
2472        assert_eq!(age.value(3), 45);
2473        assert!(age.is_null(4));
2474        assert_eq!(age.value(5), 31);
2475
2476        assert_eq!(occupation.value(0), "Engineer");
2477        assert_eq!(occupation.value(1), "Doctor");
2478        assert!(occupation.is_null(2));
2479        assert_eq!(occupation.value(3), "Artist");
2480        assert!(occupation.is_null(4));
2481        assert!(occupation.is_null(5));
2482
2483        assert_eq!(dob.value(0), 5675);
2484        assert!(dob.is_null(1));
2485        assert!(dob.is_null(2));
2486        assert_eq!(dob.value(3), -1858);
2487        assert!(dob.is_null(4));
2488        assert!(dob.is_null(5));
2489    }
2490
2491    #[test]
2492    fn test_truncated_rows_not_nullable_error() {
2493        let data = "a,b,c\n1,2,3\n4,5";
2494        let schema = Arc::new(Schema::new(vec![
2495            Field::new("a", DataType::Int32, false),
2496            Field::new("b", DataType::Int32, false),
2497            Field::new("c", DataType::Int32, false),
2498        ]));
2499
2500        let reader = ReaderBuilder::new(schema.clone())
2501            .with_header(true)
2502            .with_truncated_rows(true)
2503            .build(Cursor::new(data))
2504            .unwrap();
2505
2506        let batches = reader.collect::<Result<Vec<_>, _>>();
2507        assert!(match batches {
2508            Err(ArrowError::InvalidArgumentError(e)) =>
2509                e.to_string().contains("contains null values"),
2510            _ => false,
2511        });
2512    }
2513
2514    #[test]
2515    fn test_buffered() {
2516        let tests = [
2517            ("test/data/uk_cities.csv", false, 37),
2518            ("test/data/various_types.csv", true, 10),
2519            ("test/data/decimal_test.csv", false, 10),
2520        ];
2521
2522        for (path, has_header, expected_rows) in tests {
2523            let (schema, _) = Format::default()
2524                .infer_schema(File::open(path).unwrap(), None)
2525                .unwrap();
2526            let schema = Arc::new(schema);
2527
2528            for batch_size in [1, 4] {
2529                for capacity in [1, 3, 7, 100] {
2530                    let reader = ReaderBuilder::new(schema.clone())
2531                        .with_batch_size(batch_size)
2532                        .with_header(has_header)
2533                        .build(File::open(path).unwrap())
2534                        .unwrap();
2535
2536                    let expected = reader.collect::<Result<Vec<_>, _>>().unwrap();
2537
2538                    assert_eq!(
2539                        expected.iter().map(|x| x.num_rows()).sum::<usize>(),
2540                        expected_rows
2541                    );
2542
2543                    let buffered =
2544                        std::io::BufReader::with_capacity(capacity, File::open(path).unwrap());
2545
2546                    let reader = ReaderBuilder::new(schema.clone())
2547                        .with_batch_size(batch_size)
2548                        .with_header(has_header)
2549                        .build_buffered(buffered)
2550                        .unwrap();
2551
2552                    let actual = reader.collect::<Result<Vec<_>, _>>().unwrap();
2553                    assert_eq!(expected, actual)
2554                }
2555            }
2556        }
2557    }
2558
2559    fn err_test(csv: &[u8], expected: &str) {
2560        fn err_test_with_schema(csv: &[u8], expected: &str, schema: Arc<Schema>) {
2561            let buffer = std::io::BufReader::with_capacity(2, Cursor::new(csv));
2562            let b = ReaderBuilder::new(schema)
2563                .with_batch_size(2)
2564                .build_buffered(buffer)
2565                .unwrap();
2566            let err = b.collect::<Result<Vec<_>, _>>().unwrap_err().to_string();
2567            assert_eq!(err, expected)
2568        }
2569
2570        let schema_utf8 = Arc::new(Schema::new(vec![
2571            Field::new("text1", DataType::Utf8, true),
2572            Field::new("text2", DataType::Utf8, true),
2573        ]));
2574        err_test_with_schema(csv, expected, schema_utf8);
2575
2576        let schema_utf8view = Arc::new(Schema::new(vec![
2577            Field::new("text1", DataType::Utf8View, true),
2578            Field::new("text2", DataType::Utf8View, true),
2579        ]));
2580        err_test_with_schema(csv, expected, schema_utf8view);
2581    }
2582
2583    #[test]
2584    fn test_invalid_utf8() {
2585        err_test(
2586            b"sdf,dsfg\ndfd,hgh\xFFue\n,sds\nFalhghse,",
2587            "Csv error: Encountered invalid UTF-8 data for line 2 and field 2",
2588        );
2589
2590        err_test(
2591            b"sdf,dsfg\ndksdk,jf\nd\xFFfd,hghue\n,sds\nFalhghse,",
2592            "Csv error: Encountered invalid UTF-8 data for line 3 and field 1",
2593        );
2594
2595        err_test(
2596            b"sdf,dsfg\ndksdk,jf\ndsdsfd,hghue\n,sds\nFalhghse,\xFF",
2597            "Csv error: Encountered invalid UTF-8 data for line 5 and field 2",
2598        );
2599
2600        err_test(
2601            b"\xFFsdf,dsfg\ndksdk,jf\ndsdsfd,hghue\n,sds\nFalhghse,\xFF",
2602            "Csv error: Encountered invalid UTF-8 data for line 1 and field 1",
2603        );
2604    }
2605
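    /// Wraps a reader and records every call to fill_buf (and the buffer sizes it
    /// returned) so tests can assert on the exact IO access pattern.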
2606    struct InstrumentedRead<R> {
2607        r: R,
2608        fill_count: usize,
2609        fill_sizes: Vec<usize>,
2610    }
2611
2612    impl<R> InstrumentedRead<R> {
2613        fn new(r: R) -> Self {
2614            Self {
2615                r,
2616                fill_count: 0,
2617                fill_sizes: vec![],
2618            }
2619        }
2620    }
2621
2622    impl<R: Seek> Seek for InstrumentedRead<R> {
2623        fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
2624            self.r.seek(pos)
2625        }
2626    }
2627
2628    impl<R: BufRead> Read for InstrumentedRead<R> {
2629        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
2630            self.r.read(buf)
2631        }
2632    }
2633
2634    impl<R: BufRead> BufRead for InstrumentedRead<R> {
2635        fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
2636            self.fill_count += 1;
2637            let buf = self.r.fill_buf()?;
2638            self.fill_sizes.push(buf.len());
2639            Ok(buf)
2640        }
2641
2642        fn consume(&mut self, amt: usize) {
2643            self.r.consume(amt)
2644        }
2645    }
2646
2647    #[test]
2648    fn test_io() {
2649        let schema = Arc::new(Schema::new(vec![
2650            Field::new("a", DataType::Utf8, false),
2651            Field::new("b", DataType::Utf8, false),
2652        ]));
2653        let csv = "foo,bar\nbaz,foo\na,b\nc,d";
2654        let mut read = InstrumentedRead::new(Cursor::new(csv.as_bytes()));
2655        let reader = ReaderBuilder::new(schema)
2656            .with_batch_size(3)
2657            .build_buffered(&mut read)
2658            .unwrap();
2659
2660        let batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
2661        assert_eq!(batches.len(), 2);
2662        assert_eq!(batches[0].num_rows(), 3);
2663        assert_eq!(batches[1].num_rows(), 1);
2664
2665        // Expect 4 calls to fill_buf
2666        // 1. Read first 3 rows
2667        // 2. Read final row
2668        // 3. Delimit and flush final row
2669        // 4. Iterator finished
2670        assert_eq!(&read.fill_sizes, &[23, 3, 0, 0]);
2671        assert_eq!(read.fill_count, 4);
2672    }
2673
2674    #[test]
2675    fn test_inference() {
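        // Each case pairs observed values with the type inference should settle on:
        // conflicting kinds fall back to Utf8, dates widen to timestamps, and
        // timestamps take the finest sub-second unit seen.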
2676        let cases: &[(&[&str], DataType)] = &[
2677            (&[], DataType::Null),
2678            (&["false", "12"], DataType::Utf8),
2679            (&["12", "cupcakes"], DataType::Utf8),
2680            (&["12", "12.4"], DataType::Float64),
2681            (&["14050", "24332"], DataType::Int64),
2682            (&["14050.0", "true"], DataType::Utf8),
2683            (&["14050", "2020-03-19 00:00:00"], DataType::Utf8),
2684            (&["14050", "2340.0", "2020-03-19 00:00:00"], DataType::Utf8),
2685            (
2686                &["2020-03-19 02:00:00", "2020-03-19 00:00:00"],
2687                DataType::Timestamp(TimeUnit::Second, None),
2688            ),
2689            (&["2020-03-19", "2020-03-20"], DataType::Date32),
2690            (
2691                &["2020-03-19", "2020-03-19 02:00:00", "2020-03-19 00:00:00"],
2692                DataType::Timestamp(TimeUnit::Second, None),
2693            ),
2694            (
2695                &[
2696                    "2020-03-19",
2697                    "2020-03-19 02:00:00",
2698                    "2020-03-19 00:00:00.000",
2699                ],
2700                DataType::Timestamp(TimeUnit::Millisecond, None),
2701            ),
2702            (
2703                &[
2704                    "2020-03-19",
2705                    "2020-03-19 02:00:00",
2706                    "2020-03-19 00:00:00.000000",
2707                ],
2708                DataType::Timestamp(TimeUnit::Microsecond, None),
2709            ),
2710            (
2711                &["2020-03-19 02:00:00+02:00", "2020-03-19 02:00:00Z"],
2712                DataType::Timestamp(TimeUnit::Second, None),
2713            ),
2714            (
2715                &[
2716                    "2020-03-19",
2717                    "2020-03-19 02:00:00+02:00",
2718                    "2020-03-19 02:00:00Z",
2719                    "2020-03-19 02:00:00.12Z",
2720                ],
2721                DataType::Timestamp(TimeUnit::Millisecond, None),
2722            ),
2723            (
2724                &[
2725                    "2020-03-19",
2726                    "2020-03-19 02:00:00.000000000",
2727                    "2020-03-19 00:00:00.000000",
2728                ],
2729                DataType::Timestamp(TimeUnit::Nanosecond, None),
2730            ),
2731        ];
2732
2733        for (values, expected) in cases {
2734            let mut t = InferredDataType::default();
2735            for v in *values {
2736                t.update(v)
2737            }
2738            assert_eq!(&t.get(), expected, "{values:?}")
2739        }
2740    }
2741
2742    #[test]
2743    fn test_record_length_mismatch() {
2744        let csv = "\
2745        a,b,c\n\
2746        1,2,3\n\
2747        4,5\n\
2748        6,7,8";
2749        let mut read = Cursor::new(csv.as_bytes());
2750        let result = Format::default()
2751            .with_header(true)
2752            .infer_schema(&mut read, None);
2753        assert!(result.is_err());
2754        // Include line number in the error message to help locate and fix the issue
2755        assert_eq!(
2756            result.err().unwrap().to_string(),
2757            "Csv error: Encountered unequal lengths between records on CSV file. Expected 3 records, found 2 records at line 3"
2758        );
2759    }
2760
2761    #[test]
2762    fn test_comment() {
2763        let schema = Schema::new(vec![
2764            Field::new("a", DataType::Int8, false),
2765            Field::new("b", DataType::Int8, false),
2766        ]);
2767
2768        let csv = "# comment1 \n1,2\n#comment2\n11,22";
2769        let mut read = Cursor::new(csv.as_bytes());
2770        let reader = ReaderBuilder::new(Arc::new(schema))
2771            .with_comment(b'#')
2772            .build(&mut read)
2773            .unwrap();
2774
2775        let batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
2776        assert_eq!(batches.len(), 1);
2777        let b = batches.first().unwrap();
2778        assert_eq!(b.num_columns(), 2);
2779        assert_eq!(
2780            b.column(0)
2781                .as_any()
2782                .downcast_ref::<Int8Array>()
2783                .unwrap()
2784                .values(),
2785            &vec![1, 11]
2786        );
2787        assert_eq!(
2788            b.column(1)
2789                .as_any()
2790                .downcast_ref::<Int8Array>()
2791                .unwrap()
2792                .values(),
2793            &vec![2, 22]
2794        );
2795    }
2796
2797    #[test]
2798    fn test_parse_string_view_single_column() {
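        // Strings of at most 12 bytes are stored inline in the view struct; the longer
        // value must be spilled to a data buffer, exercising both paths.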
2799        let csv = ["foo", "something_cannot_be_inlined", "foobar"].join("\n");
2800        let schema = Arc::new(Schema::new(vec![Field::new(
2801            "c1",
2802            DataType::Utf8View,
2803            true,
2804        )]));
2805
2806        let mut decoder = ReaderBuilder::new(schema).build_decoder();
2807
2808        let decoded = decoder.decode(csv.as_bytes()).unwrap();
2809        assert_eq!(decoded, csv.len());
2810        decoder.decode(&[]).unwrap();
2811
2812        let batch = decoder.flush().unwrap().unwrap();
2813        assert_eq!(batch.num_columns(), 1);
2814        assert_eq!(batch.num_rows(), 3);
2815        let col = batch.column(0).as_string_view();
2816        assert_eq!(col.data_type(), &DataType::Utf8View);
2817        assert_eq!(col.value(0), "foo");
2818        assert_eq!(col.value(1), "something_cannot_be_inlined");
2819        assert_eq!(col.value(2), "foobar");
2820    }
2821
2822    #[test]
2823    fn test_parse_string_view_multi_column() {
2824        let csv = ["foo,", ",something_cannot_be_inlined", "foobarfoobar,bar"].join("\n");
2825        let schema = Arc::new(Schema::new(vec![
2826            Field::new("c1", DataType::Utf8View, true),
2827            Field::new("c2", DataType::Utf8View, true),
2828        ]));
2829
2830        let mut decoder = ReaderBuilder::new(schema).build_decoder();
2831
2832        let decoded = decoder.decode(csv.as_bytes()).unwrap();
2833        assert_eq!(decoded, csv.len());
2834        decoder.decode(&[]).unwrap();
2835
2836        let batch = decoder.flush().unwrap().unwrap();
2837        assert_eq!(batch.num_columns(), 2);
2838        assert_eq!(batch.num_rows(), 3);
2839        let c1 = batch.column(0).as_string_view();
2840        let c2 = batch.column(1).as_string_view();
2841        assert_eq!(c1.data_type(), &DataType::Utf8View);
2842        assert_eq!(c2.data_type(), &DataType::Utf8View);
2843
2844        assert!(!c1.is_null(0));
2845        assert!(c1.is_null(1));
2846        assert!(!c1.is_null(2));
2847        assert_eq!(c1.value(0), "foo");
2848        assert_eq!(c1.value(2), "foobarfoobar");
2849
2850        assert!(c2.is_null(0));
2851        assert!(!c2.is_null(1));
2852        assert!(!c2.is_null(2));
2853        assert_eq!(c2.value(1), "something_cannot_be_inlined");
2854        assert_eq!(c2.value(2), "bar");
2855    }
2856}