Skip to main content

arrow_csv/reader/
records.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use arrow_schema::ArrowError;
19use csv_core::{ReadRecordResult, Reader};
20
/// Rough guess at the decoded size of a single field, in bytes; `decode`
/// multiplies it by the number of outstanding fields to size its output buffer
const AVERAGE_FIELD_SIZE: usize = 8;

/// Floor, in bytes, on the output space `decode` reserves before each read
const MIN_CAPACITY: usize = 1 << 10;
26
27/// [`RecordDecoder`] provides a push-based interface to decoder [`StringRecords`]
28#[derive(Debug)]
29pub struct RecordDecoder {
30    delimiter: Reader,
31
32    /// The expected number of fields per row
33    num_columns: usize,
34
35    /// The current line number
36    line_number: usize,
37
38    /// Offsets delimiting field start positions
39    offsets: Vec<usize>,
40
41    /// The current offset into `self.offsets`
42    ///
43    /// We track this independently of Vec to avoid re-zeroing memory
44    offsets_len: usize,
45
46    /// The number of fields read for the current record
47    current_field: usize,
48
49    /// The number of rows buffered
50    num_rows: usize,
51
52    /// Decoded field data
53    data: Vec<u8>,
54
55    /// Offsets into data
56    ///
57    /// We track this independently of Vec to avoid re-zeroing memory
58    data_len: usize,
59
60    /// Whether rows with less than expected columns are considered valid
61    ///
62    /// Default value is false
63    /// When enabled fills in missing columns with null
64    truncated_rows: bool,
65}
66
67impl RecordDecoder {
68    pub fn new(delimiter: Reader, num_columns: usize, truncated_rows: bool) -> Self {
69        Self {
70            delimiter,
71            num_columns,
72            line_number: 1,
73            offsets: vec![],
74            offsets_len: 1, // The first offset is always 0
75            current_field: 0,
76            data_len: 0,
77            data: vec![],
78            num_rows: 0,
79            truncated_rows,
80        }
81    }
82
83    /// Decodes records from `input` returning the number of records and bytes read
84    ///
85    /// Note: this expects to be called with an empty `input` to signal EOF
86    pub fn decode(&mut self, input: &[u8], to_read: usize) -> Result<(usize, usize), ArrowError> {
87        if to_read == 0 {
88            return Ok((0, 0));
89        }
90
91        // Reserve sufficient capacity in offsets
92        self.offsets
93            .resize(self.offsets_len + to_read * self.num_columns, 0);
94
95        // The current offset into `input`
96        let mut input_offset = 0;
97
98        // The number of rows decoded in this pass
99        let mut read = 0;
100
101        loop {
102            // Reserve necessary space in output data based on best estimate
103            let remaining_rows = to_read - read;
104            let capacity = remaining_rows * self.num_columns * AVERAGE_FIELD_SIZE;
105            let estimated_data = capacity.max(MIN_CAPACITY);
106            self.data.resize(self.data_len + estimated_data, 0);
107
108            // Try to read a record
109            loop {
110                let (result, bytes_read, bytes_written, end_positions) =
111                    self.delimiter.read_record(
112                        &input[input_offset..],
113                        &mut self.data[self.data_len..],
114                        &mut self.offsets[self.offsets_len..],
115                    );
116
117                self.current_field += end_positions;
118                self.offsets_len += end_positions;
119                input_offset += bytes_read;
120                self.data_len += bytes_written;
121
122                match result {
123                    ReadRecordResult::End | ReadRecordResult::InputEmpty => {
124                        // Reached end of input
125                        return Ok((read, input_offset));
126                    }
127                    // Need to allocate more capacity
128                    ReadRecordResult::OutputFull => break,
129                    ReadRecordResult::OutputEndsFull => {
130                        return Err(ArrowError::CsvError(format!(
131                            "incorrect number of fields for line {}, expected {} got more than {}",
132                            self.line_number, self.num_columns, self.current_field
133                        )));
134                    }
135                    ReadRecordResult::Record => {
136                        if self.current_field != self.num_columns {
137                            if self.truncated_rows && self.current_field < self.num_columns {
138                                // If the number of fields is less than expected, pad with nulls
139                                let fill_count = self.num_columns - self.current_field;
140                                let fill_value = self.offsets[self.offsets_len - 1];
141                                self.offsets[self.offsets_len..self.offsets_len + fill_count]
142                                    .fill(fill_value);
143                                self.offsets_len += fill_count;
144                            } else {
145                                return Err(ArrowError::CsvError(format!(
146                                    "incorrect number of fields for line {}, expected {} got {}",
147                                    self.line_number, self.num_columns, self.current_field
148                                )));
149                            }
150                        }
151                        read += 1;
152                        self.current_field = 0;
153                        self.line_number += 1;
154                        self.num_rows += 1;
155
156                        if read == to_read {
157                            // Read sufficient rows
158                            return Ok((read, input_offset));
159                        }
160
161                        if input.len() == input_offset {
162                            // Input exhausted, need to read more
163                            // Without this read_record will interpret the empty input
164                            // byte array as indicating the end of the file
165                            return Ok((read, input_offset));
166                        }
167                    }
168                }
169            }
170        }
171    }
172
173    /// Returns the current number of buffered records
174    pub fn len(&self) -> usize {
175        self.num_rows
176    }
177
178    /// Returns true if the decoder is empty
179    pub fn is_empty(&self) -> bool {
180        self.num_rows == 0
181    }
182
183    /// Clears the current contents of the decoder
184    pub fn clear(&mut self) {
185        // This does not reset current_field to allow clearing part way through a record
186        self.offsets_len = 1;
187        self.data_len = 0;
188        self.num_rows = 0;
189    }
190
191    /// Flushes the current contents of the reader
192    pub fn flush(&mut self) -> Result<StringRecords<'_>, ArrowError> {
193        if self.current_field != 0 {
194            return Err(ArrowError::CsvError(
195                "Cannot flush part way through record".to_string(),
196            ));
197        }
198
199        // csv_core::Reader writes end offsets relative to the start of the row
200        // Therefore scan through and offset these based on the cumulative row offsets
201        let mut row_offset: usize = 0;
202        self.offsets[1..self.offsets_len]
203            .chunks_exact_mut(self.num_columns)
204            .try_for_each(|row| -> Result<(), ArrowError> {
205                let offset = row_offset;
206                row.iter_mut().try_for_each(|x| -> Result<(), ArrowError> {
207                    *x = x.checked_add(offset).ok_or_else(|| {
208                        ArrowError::CsvError(
209                            "CSV record offsets overflowed usize while flushing".to_string(),
210                        )
211                    })?;
212                    row_offset = *x;
213                    Ok(())
214                })
215            })?;
216
217        // Need to truncate data t1o the actual amount of data read
218        let data = std::str::from_utf8(&self.data[..self.data_len]).map_err(|e| {
219            let valid_up_to = e.valid_up_to();
220
221            // We can't use binary search because of empty fields
222            let idx = self.offsets[..self.offsets_len]
223                .iter()
224                .rposition(|x| *x <= valid_up_to)
225                .unwrap();
226
227            let field = idx % self.num_columns + 1;
228            let line_offset = self.line_number - self.num_rows;
229            let line = line_offset + idx / self.num_columns;
230
231            ArrowError::CsvError(format!(
232                "Encountered invalid UTF-8 data for line {line} and field {field}"
233            ))
234        })?;
235
236        let offsets = &self.offsets[..self.offsets_len];
237        let num_rows = self.num_rows;
238
239        // Reset state
240        self.offsets_len = 1;
241        self.data_len = 0;
242        self.num_rows = 0;
243
244        Ok(StringRecords {
245            num_rows,
246            num_columns: self.num_columns,
247            offsets,
248            data,
249        })
250    }
251}
252
/// A batch of parsed CSV records whose field data is UTF-8
#[derive(Debug)]
pub struct StringRecords<'a> {
    num_columns: usize,
    num_rows: usize,
    offsets: &'a [usize],
    data: &'a str,
}

impl<'a> StringRecords<'a> {
    /// Returns the record at row `index`.
    ///
    /// Row `index` uses the `num_columns + 1` offsets starting at
    /// `index * num_columns`; adjacent rows share their boundary offset.
    fn get(&self, index: usize) -> StringRecord<'a> {
        let first = index * self.num_columns;
        let last = first + self.num_columns;
        StringRecord {
            data: self.data,
            offsets: &self.offsets[first..=last],
        }
    }

    /// Returns the number of records in the batch
    pub fn len(&self) -> usize {
        self.num_rows
    }

    /// Returns an iterator over the records, in order
    pub fn iter(&self) -> impl Iterator<Item = StringRecord<'a>> + '_ {
        (0..self.len()).map(move |row| self.get(row))
    }
}

/// One parsed CSV record, borrowing its fields from a [`StringRecords`] batch
#[derive(Debug, Clone, Copy)]
pub struct StringRecord<'a> {
    data: &'a str,
    offsets: &'a [usize],
}

impl<'a> StringRecord<'a> {
    /// Returns field `index` of this record.
    ///
    /// # Panics
    ///
    /// Panics if `index + 1` is not a valid position in this record's offsets.
    pub fn get(&self, index: usize) -> &'a str {
        let (end, start) = (self.offsets[index + 1], self.offsets[index]);

        // SAFETY:
        // `StringRecords`, and therefore every `StringRecord`, is only built by
        // `RecordDecoder::flush`, which must hand out offsets that are
        // non-decreasing, no greater than `data.len()` and on `char` boundaries
        unsafe { self.data.get_unchecked(start..end) }
    }
}

impl std::fmt::Display for StringRecord<'_> {
    // Renders the fields comma-separated inside square brackets, without any
    // quoting or escaping, e.g. `[a,bc]`
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let field_count = self.offsets.len() - 1;
        f.write_str("[")?;
        for (position, field) in (0..field_count).map(|i| self.get(i)).enumerate() {
            if position > 0 {
                f.write_str(",")?;
            }
            f.write_str(field)?;
        }
        f.write_str("]")
    }
}
312
#[cfg(test)]
mod tests {
    use crate::reader::records::RecordDecoder;
    use csv_core::Reader;
    use std::io::{BufRead, BufReader, Cursor};

    #[test]
    fn test_basic() {
        let csv = [
            "foo,bar,baz",
            "a,b,c",
            "12,3,5",
            "\"asda\"\"asas\",\"sdffsnsd\", as",
        ]
        .join("\n");

        let mut expected = vec![
            vec!["foo", "bar", "baz"],
            vec!["a", "b", "c"],
            vec!["12", "3", "5"],
            vec!["asda\"asas", "sdffsnsd", " as"],
        ]
        .into_iter();

        // A tiny buffer forces records to be split across `decode` calls
        let mut reader = BufReader::with_capacity(3, Cursor::new(csv.as_bytes()));
        let mut decoder = RecordDecoder::new(Reader::new(), 3, false);

        // Total number of records returned across all flushes
        let mut total = 0;
        loop {
            let to_read = 3;
            let mut read = 0;
            loop {
                let buf = reader.fill_buf().unwrap();
                let (records, bytes) = decoder.decode(buf, to_read - read).unwrap();

                reader.consume(bytes);
                read += records;

                if read == to_read || bytes == 0 {
                    break;
                }
            }
            if read == 0 {
                break;
            }

            let b = decoder.flush().unwrap();
            total += b.len();
            b.iter().zip(&mut expected).for_each(|(record, expected)| {
                let actual = (0..3)
                    .map(|field_idx| record.get(field_idx))
                    .collect::<Vec<_>>();
                assert_eq!(actual, expected)
            });
        }
        // `zip` stops silently on extra records, so check the count explicitly
        assert_eq!(total, 4);
        assert!(expected.next().is_none());
    }

    #[test]
    fn test_invalid_fields() {
        let csv = "a,b\nb,c\na\n";
        let mut decoder = RecordDecoder::new(Reader::new(), 2, false);
        let err = decoder.decode(csv.as_bytes(), 4).unwrap_err().to_string();

        let expected = "Csv error: incorrect number of fields for line 3, expected 2 got 1";

        assert_eq!(err, expected);

        // Test with initial skip
        let mut decoder = RecordDecoder::new(Reader::new(), 2, false);
        let (skipped, bytes) = decoder.decode(csv.as_bytes(), 1).unwrap();
        assert_eq!(skipped, 1);
        decoder.clear();

        let remaining = &csv.as_bytes()[bytes..];
        let err = decoder.decode(remaining, 3).unwrap_err().to_string();
        assert_eq!(err, expected);
    }

    #[test]
    fn test_skip_insufficient_rows() {
        let csv = "a\nv\n";
        let mut decoder = RecordDecoder::new(Reader::new(), 1, false);
        let (read, bytes) = decoder.decode(csv.as_bytes(), 3).unwrap();
        assert_eq!(read, 2);
        assert_eq!(bytes, csv.len());
    }

    #[test]
    fn test_truncated_rows() {
        let csv = "a,b\nv\n,1\n,2\n,3\n";
        let mut decoder = RecordDecoder::new(Reader::new(), 2, true);
        let (read, bytes) = decoder.decode(csv.as_bytes(), 5).unwrap();
        assert_eq!(read, 5);
        assert_eq!(bytes, csv.len());

        // The short second row must be padded with an empty trailing field
        let records = decoder.flush().unwrap();
        let actual: Vec<Vec<&str>> = records
            .iter()
            .map(|record| (0..2).map(|idx| record.get(idx)).collect())
            .collect();
        assert_eq!(
            actual,
            vec![
                vec!["a", "b"],
                vec!["v", ""],
                vec!["", "1"],
                vec!["", "2"],
                vec!["", "3"],
            ]
        );
    }

    #[test]
    fn test_invalid_utf8() {
        let csv: &[u8] = b"a,b\nc,\xff\n";
        let mut decoder = RecordDecoder::new(Reader::new(), 2, false);
        let (read, bytes) = decoder.decode(csv, 2).unwrap();
        assert_eq!(read, 2);
        assert_eq!(bytes, csv.len());

        let err = decoder.flush().unwrap_err().to_string();
        assert_eq!(
            err,
            "Csv error: Encountered invalid UTF-8 data for line 2 and field 2"
        );
    }

    /// Regression test for an overflow path found by the `arrow-csv`
    /// cargo-fuzz harness being prototyped for #5332. Stages the
    /// `RecordDecoder` state directly so that rebasing the second row's
    /// end offset overflows `usize`. With the previous `*x += offset` this
    /// panicked with `attempt to add with overflow`; the patched code
    /// surfaces the condition as `ArrowError::CsvError`.
    #[test]
    fn test_flush_offset_overflow_returns_csv_error() {
        let mut decoder = RecordDecoder::new(Reader::new(), 1, false);
        decoder.offsets = vec![0, usize::MAX, 1];
        decoder.offsets_len = 3;
        decoder.num_rows = 2;
        let err = decoder.flush().unwrap_err();
        assert_eq!(
            err.to_string(),
            "Csv error: CSV record offsets overflowed usize while flushing"
        );
    }
}