arrow_csv/reader/records.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow_schema::ArrowError;
use csv_core::{ReadRecordResult, Reader};

/// The estimated length of a field in bytes
const AVERAGE_FIELD_SIZE: usize = 8;

/// The minimum amount of data in a single read
const MIN_CAPACITY: usize = 1024;

/// [`RecordDecoder`] provides a push-based interface to decode [`StringRecords`]
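///
/// Input is pushed in via [`Self::decode`], and buffered records are
/// retrieved with [`Self::flush`]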
#[derive(Debug)]
pub struct RecordDecoder {
    delimiter: Reader,

    /// The expected number of fields per row
    num_columns: usize,

    /// The current line number
    line_number: usize,

    /// Offsets delimiting field start positions
    offsets: Vec<usize>,

    /// The current offset into `self.offsets`
    ///
    /// We track this independently of Vec to avoid re-zeroing memory
    offsets_len: usize,

    /// The number of fields read for the current record
    current_field: usize,

    /// The number of rows buffered
    num_rows: usize,

    /// Decoded field data
    data: Vec<u8>,

    /// Offsets into data
    ///
    /// We track this independently of Vec to avoid re-zeroing memory
    data_len: usize,

    /// Whether rows with fewer than the expected number of columns are
    /// considered valid
    ///
    /// Defaults to `false`. When enabled, missing columns are filled in
    /// with null
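    ///
    /// For example, with `num_columns = 3` the row `a,b` is padded with a
    /// third, null column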
    truncated_rows: bool,
}

impl RecordDecoder {
    pub fn new(delimiter: Reader, num_columns: usize, truncated_rows: bool) -> Self {
        Self {
            delimiter,
            num_columns,
            line_number: 1,
            offsets: vec![],
            offsets_len: 1, // The first offset is always 0
            current_field: 0,
            data_len: 0,
            data: vec![],
            num_rows: 0,
            truncated_rows,
        }
    }

    /// Decodes records from `input` returning the number of records and bytes read
    ///
    /// Note: this expects to be called with an empty `input` to signal EOF
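    ///
    /// A sketch of the expected call pattern, assuming a hypothetical
    /// `source: impl BufRead` (see `test_basic` below for a concrete version):
    ///
    /// ```ignore
    /// let mut decoder = RecordDecoder::new(Reader::new(), 3, false);
    /// loop {
    ///     let buf = source.fill_buf()?;
    ///     let is_eof = buf.is_empty(); // empty input signals EOF
    ///     let (_records, bytes) = decoder.decode(buf, 1024)?;
    ///     source.consume(bytes);
    ///     if is_eof {
    ///         break;
    ///     }
    /// }
    /// let records = decoder.flush()?;
    /// ```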
    pub fn decode(&mut self, input: &[u8], to_read: usize) -> Result<(usize, usize), ArrowError> {
        if to_read == 0 {
            return Ok((0, 0));
        }

        // Reserve sufficient capacity in offsets
        self.offsets
            .resize(self.offsets_len + to_read * self.num_columns, 0);

        // The current offset into `input`
        let mut input_offset = 0;

        // The number of rows decoded in this pass
        let mut read = 0;

        loop {
            // Reserve necessary space in output data based on best estimate
            let remaining_rows = to_read - read;
            let capacity = remaining_rows * self.num_columns * AVERAGE_FIELD_SIZE;
            let estimated_data = capacity.max(MIN_CAPACITY);
            self.data.resize(self.data_len + estimated_data, 0);

            // Try to read a record
            loop {
                let (result, bytes_read, bytes_written, end_positions) =
                    self.delimiter.read_record(
                        &input[input_offset..],
                        &mut self.data[self.data_len..],
                        &mut self.offsets[self.offsets_len..],
                    );

                self.current_field += end_positions;
                self.offsets_len += end_positions;
                input_offset += bytes_read;
                self.data_len += bytes_written;

                match result {
                    ReadRecordResult::End | ReadRecordResult::InputEmpty => {
                        // Reached end of input
                        return Ok((read, input_offset));
                    }
                    // Need to allocate more capacity
                    ReadRecordResult::OutputFull => break,
                    ReadRecordResult::OutputEndsFull => {
                        return Err(ArrowError::CsvError(format!(
                            "incorrect number of fields for line {}, expected {} got more than {}",
                            self.line_number, self.num_columns, self.current_field
                        )));
                    }
                    ReadRecordResult::Record => {
                        if self.current_field != self.num_columns {
                            if self.truncated_rows && self.current_field < self.num_columns {
                                // If the number of fields is less than expected, pad with nulls
                                let fill_count = self.num_columns - self.current_field;
                                let fill_value = self.offsets[self.offsets_len - 1];
                                self.offsets[self.offsets_len..self.offsets_len + fill_count]
                                    .fill(fill_value);
                                self.offsets_len += fill_count;
                            } else {
                                return Err(ArrowError::CsvError(format!(
                                    "incorrect number of fields for line {}, expected {} got {}",
                                    self.line_number, self.num_columns, self.current_field
                                )));
                            }
                        }
                        read += 1;
                        self.current_field = 0;
                        self.line_number += 1;
                        self.num_rows += 1;

                        if read == to_read {
                            // Read sufficient rows
                            return Ok((read, input_offset));
                        }

                        if input.len() == input_offset {
                            // Input exhausted, need to read more. Without this,
                            // read_record would interpret the empty input byte
                            // array as indicating the end of the file
                            return Ok((read, input_offset));
                        }
                    }
                }
            }
        }
    }

    /// Returns the current number of buffered records
    pub fn len(&self) -> usize {
        self.num_rows
    }

    /// Returns true if the decoder is empty
    pub fn is_empty(&self) -> bool {
        self.num_rows == 0
    }

    /// Clears the current contents of the decoder
    pub fn clear(&mut self) {
        // This does not reset current_field to allow clearing part way through a record
        self.offsets_len = 1;
        self.data_len = 0;
        self.num_rows = 0;
    }

    /// Flushes the current contents of the reader
    pub fn flush(&mut self) -> Result<StringRecords<'_>, ArrowError> {
        if self.current_field != 0 {
            return Err(ArrowError::CsvError(
                "Cannot flush part way through record".to_string(),
            ));
        }

        // csv_core::Reader writes end offsets relative to the start of the row
        // Therefore scan through and offset these based on the cumulative row offsets
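        //
        // For example, two 2-column rows with relative end offsets [1, 3] and
        // [2, 4] become absolute offsets [1, 3, 5, 7]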
        let mut row_offset = 0;
        self.offsets[1..self.offsets_len]
            .chunks_exact_mut(self.num_columns)
            .for_each(|row| {
                let offset = row_offset;
                row.iter_mut().for_each(|x| {
                    *x += offset;
                    row_offset = *x;
                });
            });

        // Need to truncate data to the actual amount of data read
        let data = std::str::from_utf8(&self.data[..self.data_len]).map_err(|e| {
            let valid_up_to = e.valid_up_to();

            // We can't use binary search because of empty fields
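            // (empty fields produce duplicate offsets, and binary_search may
            // return any one of several equal values rather than the last)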
            let idx = self.offsets[..self.offsets_len]
                .iter()
                .rposition(|x| *x <= valid_up_to)
                .unwrap();

            let field = idx % self.num_columns + 1;
            let line_offset = self.line_number - self.num_rows;
            let line = line_offset + idx / self.num_columns;

            ArrowError::CsvError(format!(
                "Encountered invalid UTF-8 data for line {line} and field {field}"
            ))
        })?;

        let offsets = &self.offsets[..self.offsets_len];
        let num_rows = self.num_rows;

        // Reset state
        self.offsets_len = 1;
        self.data_len = 0;
        self.num_rows = 0;

        Ok(StringRecords {
            num_rows,
            num_columns: self.num_columns,
            offsets,
            data,
        })
    }
}

/// A collection of parsed, UTF-8 CSV records
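///
/// Fields are stored contiguously in `data`; `offsets` holds
/// `num_rows * num_columns + 1` field boundaries, and row `i` reads its
/// boundaries from `offsets[i * num_columns..=(i + 1) * num_columns]`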
#[derive(Debug)]
pub struct StringRecords<'a> {
    num_columns: usize,
    num_rows: usize,
    offsets: &'a [usize],
    data: &'a str,
}

impl<'a> StringRecords<'a> {
    fn get(&self, index: usize) -> StringRecord<'a> {
        let field_idx = index * self.num_columns;
        StringRecord {
            data: self.data,
            offsets: &self.offsets[field_idx..field_idx + self.num_columns + 1],
        }
    }

    pub fn len(&self) -> usize {
        self.num_rows
    }

    pub fn iter(&self) -> impl Iterator<Item = StringRecord<'a>> + '_ {
        (0..self.num_rows).map(|x| self.get(x))
    }
}

/// A single parsed, UTF-8 CSV record
#[derive(Debug, Clone, Copy)]
pub struct StringRecord<'a> {
    data: &'a str,
    offsets: &'a [usize],
}

impl<'a> StringRecord<'a> {
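    /// Returns the field at `index` as a string slice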
    pub fn get(&self, index: usize) -> &'a str {
        let end = self.offsets[index + 1];
        let start = self.offsets[index];

        // SAFETY:
        // Parsing produces offsets at valid byte boundaries
        unsafe { self.data.get_unchecked(start..end) }
    }
}

impl std::fmt::Display for StringRecord<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let num_fields = self.offsets.len() - 1;
        write!(f, "[")?;
        for i in 0..num_fields {
            if i > 0 {
                write!(f, ",")?;
            }
            write!(f, "{}", self.get(i))?;
        }
        write!(f, "]")?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use crate::reader::records::RecordDecoder;
    use csv_core::Reader;
    use std::io::{BufRead, BufReader, Cursor};

    #[test]
    fn test_basic() {
        let csv = [
            "foo,bar,baz",
            "a,b,c",
            "12,3,5",
            "\"asda\"\"asas\",\"sdffsnsd\", as",
        ]
        .join("\n");

        let mut expected = vec![
            vec!["foo", "bar", "baz"],
            vec!["a", "b", "c"],
            vec!["12", "3", "5"],
            vec!["asda\"asas", "sdffsnsd", " as"],
        ]
        .into_iter();

        let mut reader = BufReader::with_capacity(3, Cursor::new(csv.as_bytes()));
        let mut decoder = RecordDecoder::new(Reader::new(), 3, false);

        loop {
            let to_read = 3;
            let mut read = 0;
            loop {
                let buf = reader.fill_buf().unwrap();
                let (records, bytes) = decoder.decode(buf, to_read - read).unwrap();

                reader.consume(bytes);
                read += records;

                if read == to_read || bytes == 0 {
                    break;
                }
            }
            if read == 0 {
                break;
            }

            let b = decoder.flush().unwrap();
            b.iter().zip(&mut expected).for_each(|(record, expected)| {
                let actual = (0..3)
                    .map(|field_idx| record.get(field_idx))
                    .collect::<Vec<_>>();
                assert_eq!(actual, expected)
            });
        }
        assert!(expected.next().is_none());
    }

    #[test]
    fn test_invalid_fields() {
        let csv = "a,b\nb,c\na\n";
        let mut decoder = RecordDecoder::new(Reader::new(), 2, false);
        let err = decoder.decode(csv.as_bytes(), 4).unwrap_err().to_string();

        let expected = "Csv error: incorrect number of fields for line 3, expected 2 got 1";

        assert_eq!(err, expected);

        // Test with initial skip
        let mut decoder = RecordDecoder::new(Reader::new(), 2, false);
        let (skipped, bytes) = decoder.decode(csv.as_bytes(), 1).unwrap();
        assert_eq!(skipped, 1);
        decoder.clear();

        let remaining = &csv.as_bytes()[bytes..];
        let err = decoder.decode(remaining, 3).unwrap_err().to_string();
        assert_eq!(err, expected);
    }

    #[test]
    fn test_skip_insufficient_rows() {
        let csv = "a\nv\n";
        let mut decoder = RecordDecoder::new(Reader::new(), 1, false);
        let (read, bytes) = decoder.decode(csv.as_bytes(), 3).unwrap();
        assert_eq!(read, 2);
        assert_eq!(bytes, csv.len());
    }

    #[test]
    fn test_truncated_rows() {
        let csv = "a,b\nv\n,1\n,2\n,3\n";
        let mut decoder = RecordDecoder::new(Reader::new(), 2, true);
        let (read, bytes) = decoder.decode(csv.as_bytes(), 5).unwrap();
        assert_eq!(read, 5);
        assert_eq!(bytes, csv.len());
    }
}