arrow_csv/reader/records.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow_schema::ArrowError;
use csv_core::{ReadRecordResult, Reader};

/// The estimated length of a field in bytes
const AVERAGE_FIELD_SIZE: usize = 8;

/// The minimum amount of data in a single read
const MIN_CAPACITY: usize = 1024;

/// [`RecordDecoder`] provides a push-based interface to decode [`StringRecords`]
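///
/// A minimal sketch of the intended push-based decode loop, mirroring
/// `test_basic` below; `reader` is assumed to be some [`std::io::BufRead`],
/// and `to_read` / `num_columns` are caller-provided:
///
/// ```ignore
/// let mut decoder = RecordDecoder::new(Reader::new(), num_columns, false);
/// let mut read = 0;
/// while read != to_read {
///     let buf = reader.fill_buf()?;
///     let (records, bytes) = decoder.decode(buf, to_read - read)?;
///     reader.consume(bytes);
///     read += records;
///     if bytes == 0 {
///         break; // An empty `buf` signals EOF
///     }
/// }
/// let flushed = decoder.flush()?;
/// ```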
#[derive(Debug)]
pub struct RecordDecoder {
    delimiter: Reader,

    /// The expected number of fields per row
    num_columns: usize,

    /// The current line number
    line_number: usize,

    /// Offsets delimiting field start positions
    offsets: Vec<usize>,

    /// The current offset into `self.offsets`
    ///
    /// We track this independently of Vec to avoid re-zeroing memory
    offsets_len: usize,

    /// The number of fields read for the current record
    current_field: usize,

    /// The number of rows buffered
    num_rows: usize,

    /// Decoded field data
    data: Vec<u8>,

    /// Offsets into data
    ///
    /// We track this independently of Vec to avoid re-zeroing memory
    data_len: usize,

    /// Whether rows with fewer than the expected number of columns are considered valid
    ///
    /// Defaults to false; when enabled, missing columns are filled in with null
    truncated_rows: bool,
}

impl RecordDecoder {
    pub fn new(delimiter: Reader, num_columns: usize, truncated_rows: bool) -> Self {
        Self {
            delimiter,
            num_columns,
            line_number: 1,
            offsets: vec![],
            offsets_len: 1, // The first offset is always 0
            current_field: 0,
            data_len: 0,
            data: vec![],
            num_rows: 0,
            truncated_rows,
        }
    }

    /// Decodes records from `input`, returning the number of records and bytes read
    ///
    /// Note: this expects to be called with an empty `input` to signal EOF
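    ///
    /// For example, once the underlying source is exhausted, one further call
    /// with an empty slice completes a final record that lacked a trailing
    /// newline (an illustrative sketch, not a doc-test):
    ///
    /// ```ignore
    /// let (records, _bytes) = decoder.decode(&[], 1)?;
    /// ```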
    pub fn decode(&mut self, input: &[u8], to_read: usize) -> Result<(usize, usize), ArrowError> {
        if to_read == 0 {
            return Ok((0, 0));
        }

        // Reserve sufficient capacity in offsets: each record contributes
        // `num_columns` end offsets
        self.offsets
            .resize(self.offsets_len + to_read * self.num_columns, 0);

        // The current offset into `input`
        let mut input_offset = 0;

        // The number of rows decoded in this pass
        let mut read = 0;

        loop {
            // Reserve necessary space in output data based on best estimate
            let remaining_rows = to_read - read;
            let capacity = remaining_rows * self.num_columns * AVERAGE_FIELD_SIZE;
            let estimated_data = capacity.max(MIN_CAPACITY);
            self.data.resize(self.data_len + estimated_data, 0);

            // Try to read a record
            loop {
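                // csv_core reports the parse status along with the number of
                // input bytes consumed, output bytes written, and field end
                // positions recorded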
                let (result, bytes_read, bytes_written, end_positions) =
                    self.delimiter.read_record(
                        &input[input_offset..],
                        &mut self.data[self.data_len..],
                        &mut self.offsets[self.offsets_len..],
                    );

                self.current_field += end_positions;
                self.offsets_len += end_positions;
                input_offset += bytes_read;
                self.data_len += bytes_written;

                match result {
                    ReadRecordResult::End | ReadRecordResult::InputEmpty => {
                        // Reached end of input
                        return Ok((read, input_offset));
                    }
                    // Need to allocate more capacity
                    ReadRecordResult::OutputFull => break,
                    ReadRecordResult::OutputEndsFull => {
                        return Err(ArrowError::CsvError(format!(
                            "incorrect number of fields for line {}, expected {} got more than {}",
                            self.line_number, self.num_columns, self.current_field
                        )));
                    }
                    ReadRecordResult::Record => {
                        if self.current_field != self.num_columns {
                            if self.truncated_rows && self.current_field < self.num_columns {
                                // If the number of fields is less than expected, pad with nulls
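                                // e.g. with `num_columns = 3`, the row "a\n"
                                // records one end offset; repeating it twice
                                // yields two empty trailing fields, which are
                                // filled in as null downstream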
                                let fill_count = self.num_columns - self.current_field;
                                let fill_value = self.offsets[self.offsets_len - 1];
                                self.offsets[self.offsets_len..self.offsets_len + fill_count]
                                    .fill(fill_value);
                                self.offsets_len += fill_count;
                            } else {
                                return Err(ArrowError::CsvError(format!(
                                    "incorrect number of fields for line {}, expected {} got {}",
                                    self.line_number, self.num_columns, self.current_field
                                )));
                            }
                        }
                        read += 1;
                        self.current_field = 0;
                        self.line_number += 1;
                        self.num_rows += 1;

                        if read == to_read {
                            // Read sufficient rows
                            return Ok((read, input_offset));
                        }

                        if input.len() == input_offset {
                            // Input exhausted, need to read more
                            // Without this read_record will interpret the empty input
                            // byte array as indicating the end of the file
                            return Ok((read, input_offset));
                        }
                    }
                }
            }
        }
    }

    /// Returns the current number of buffered records
    pub fn len(&self) -> usize {
        self.num_rows
    }

    /// Returns true if the decoder is empty
    pub fn is_empty(&self) -> bool {
        self.num_rows == 0
    }

    /// Clears the current contents of the decoder
    pub fn clear(&mut self) {
        // This does not reset current_field to allow clearing part way through a record
        self.offsets_len = 1;
        self.data_len = 0;
        self.num_rows = 0;
    }

    /// Flushes the current contents of the decoder
    pub fn flush(&mut self) -> Result<StringRecords<'_>, ArrowError> {
        if self.current_field != 0 {
            return Err(ArrowError::CsvError(
                "Cannot flush part way through record".to_string(),
            ));
        }

        // csv_core::Reader writes end offsets relative to the start of the row
        // Therefore scan through and offset these based on the cumulative row offsets
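        //
        // For example, the rows "ab,c" and "d,ef" are written to `data` as
        // "abcdef" with per-row end offsets [2, 3] and [1, 3]; after this
        // adjustment (and with the initial 0) the offsets become
        // [0, 2, 3, 4, 6], absolute field boundaries into `data`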
        let mut row_offset = 0;
        self.offsets[1..self.offsets_len]
            .chunks_exact_mut(self.num_columns)
            .for_each(|row| {
                let offset = row_offset;
                row.iter_mut().for_each(|x| {
                    *x += offset;
                    row_offset = *x;
                });
            });

        // Need to truncate data to the actual amount of data read
        let data = std::str::from_utf8(&self.data[..self.data_len]).map_err(|e| {
            let valid_up_to = e.valid_up_to();

            // We can't use binary search because of empty fields
            let idx = self.offsets[..self.offsets_len]
                .iter()
                .rposition(|x| *x <= valid_up_to)
                .unwrap();

            let field = idx % self.num_columns + 1;
            let line_offset = self.line_number - self.num_rows;
            let line = line_offset + idx / self.num_columns;

            ArrowError::CsvError(format!(
                "Encountered invalid UTF-8 data for line {line} and field {field}"
            ))
        })?;

        let offsets = &self.offsets[..self.offsets_len];
        let num_rows = self.num_rows;

        // Reset state
        self.offsets_len = 1;
        self.data_len = 0;
        self.num_rows = 0;

        Ok(StringRecords {
            num_rows,
            num_columns: self.num_columns,
            offsets,
            data,
        })
    }
}

/// A collection of parsed, UTF-8 CSV records
#[derive(Debug)]
pub struct StringRecords<'a> {
    num_columns: usize,
    num_rows: usize,
    offsets: &'a [usize],
    data: &'a str,
}

impl<'a> StringRecords<'a> {
    fn get(&self, index: usize) -> StringRecord<'a> {
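        // Row `index` is bounded by `num_columns + 1` consecutive offsets:
        // its first field starts where the previous row's last field ended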
        let field_idx = index * self.num_columns;
        StringRecord {
            data: self.data,
            offsets: &self.offsets[field_idx..field_idx + self.num_columns + 1],
        }
    }

    pub fn len(&self) -> usize {
        self.num_rows
    }

    pub fn iter(&self) -> impl Iterator<Item = StringRecord<'a>> + '_ {
        (0..self.num_rows).map(|x| self.get(x))
    }
}

/// A single parsed, UTF-8 CSV record
#[derive(Debug, Clone, Copy)]
pub struct StringRecord<'a> {
    data: &'a str,
    offsets: &'a [usize],
}

impl<'a> StringRecord<'a> {
    pub fn get(&self, index: usize) -> &'a str {
        let end = self.offsets[index + 1];
        let start = self.offsets[index];

        // SAFETY:
        // Parsing produces offsets at valid byte boundaries
        unsafe { self.data.get_unchecked(start..end) }
    }
}

#[cfg(test)]
mod tests {
    use crate::reader::records::RecordDecoder;
    use csv_core::Reader;
    use std::io::{BufRead, BufReader, Cursor};

    #[test]
    fn test_basic() {
        let csv = [
            "foo,bar,baz",
            "a,b,c",
            "12,3,5",
            "\"asda\"\"asas\",\"sdffsnsd\", as",
        ]
        .join("\n");

        let mut expected = vec![
            vec!["foo", "bar", "baz"],
            vec!["a", "b", "c"],
            vec!["12", "3", "5"],
            vec!["asda\"asas", "sdffsnsd", " as"],
        ]
        .into_iter();

        let mut reader = BufReader::with_capacity(3, Cursor::new(csv.as_bytes()));
        let mut decoder = RecordDecoder::new(Reader::new(), 3, false);

        loop {
            let to_read = 3;
            let mut read = 0;
            loop {
                let buf = reader.fill_buf().unwrap();
                let (records, bytes) = decoder.decode(buf, to_read - read).unwrap();

                reader.consume(bytes);
                read += records;

                if read == to_read || bytes == 0 {
                    break;
                }
            }
            if read == 0 {
                break;
            }

            let b = decoder.flush().unwrap();
            b.iter().zip(&mut expected).for_each(|(record, expected)| {
                let actual = (0..3)
                    .map(|field_idx| record.get(field_idx))
                    .collect::<Vec<_>>();
                assert_eq!(actual, expected)
            });
        }
        assert!(expected.next().is_none());
    }

    #[test]
    fn test_invalid_fields() {
        let csv = "a,b\nb,c\na\n";
        let mut decoder = RecordDecoder::new(Reader::new(), 2, false);
        let err = decoder.decode(csv.as_bytes(), 4).unwrap_err().to_string();

        let expected = "Csv error: incorrect number of fields for line 3, expected 2 got 1";

        assert_eq!(err, expected);

        // Test with initial skip
        let mut decoder = RecordDecoder::new(Reader::new(), 2, false);
        let (skipped, bytes) = decoder.decode(csv.as_bytes(), 1).unwrap();
        assert_eq!(skipped, 1);
        decoder.clear();

        let remaining = &csv.as_bytes()[bytes..];
        let err = decoder.decode(remaining, 3).unwrap_err().to_string();
        assert_eq!(err, expected);
    }

    #[test]
    fn test_skip_insufficient_rows() {
        let csv = "a\nv\n";
        let mut decoder = RecordDecoder::new(Reader::new(), 1, false);
        let (read, bytes) = decoder.decode(csv.as_bytes(), 3).unwrap();
        assert_eq!(read, 2);
        assert_eq!(bytes, csv.len());
    }

    #[test]
    fn test_truncated_rows() {
        let csv = "a,b\nv\n,1\n,2\n,3\n";
        let mut decoder = RecordDecoder::new(Reader::new(), 2, true);
        let (read, bytes) = decoder.decode(csv.as_bytes(), 5).unwrap();
        assert_eq!(read, 5);
        assert_eq!(bytes, csv.len());
    }
}
387}