1use arrow_schema::ArrowError;
19use csv_core::{ReadRecordResult, Reader};
20
/// Rough guess at the number of bytes in a single decoded field, used to
/// estimate how much room the data buffer needs before each parsing pass.
const AVERAGE_FIELD_SIZE: usize = 8;

/// Minimum number of spare bytes made available in the data buffer before
/// each parsing pass, regardless of the per-field estimate.
const MIN_CAPACITY: usize = 1 << 10;
26
27#[derive(Debug)]
29pub struct RecordDecoder {
30 delimiter: Reader,
31
32 num_columns: usize,
34
35 line_number: usize,
37
38 offsets: Vec<usize>,
40
41 offsets_len: usize,
45
46 current_field: usize,
48
49 num_rows: usize,
51
52 data: Vec<u8>,
54
55 data_len: usize,
59
60 truncated_rows: bool,
65}
66
67impl RecordDecoder {
68 pub fn new(delimiter: Reader, num_columns: usize, truncated_rows: bool) -> Self {
69 Self {
70 delimiter,
71 num_columns,
72 line_number: 1,
73 offsets: vec![],
74 offsets_len: 1, current_field: 0,
76 data_len: 0,
77 data: vec![],
78 num_rows: 0,
79 truncated_rows,
80 }
81 }
82
83 pub fn decode(&mut self, input: &[u8], to_read: usize) -> Result<(usize, usize), ArrowError> {
87 if to_read == 0 {
88 return Ok((0, 0));
89 }
90
91 self.offsets
93 .resize(self.offsets_len + to_read * self.num_columns, 0);
94
95 let mut input_offset = 0;
97
98 let mut read = 0;
100
101 loop {
102 let remaining_rows = to_read - read;
104 let capacity = remaining_rows * self.num_columns * AVERAGE_FIELD_SIZE;
105 let estimated_data = capacity.max(MIN_CAPACITY);
106 self.data.resize(self.data_len + estimated_data, 0);
107
108 loop {
110 let (result, bytes_read, bytes_written, end_positions) =
111 self.delimiter.read_record(
112 &input[input_offset..],
113 &mut self.data[self.data_len..],
114 &mut self.offsets[self.offsets_len..],
115 );
116
117 self.current_field += end_positions;
118 self.offsets_len += end_positions;
119 input_offset += bytes_read;
120 self.data_len += bytes_written;
121
122 match result {
123 ReadRecordResult::End | ReadRecordResult::InputEmpty => {
124 return Ok((read, input_offset));
126 }
127 ReadRecordResult::OutputFull => break,
129 ReadRecordResult::OutputEndsFull => {
130 return Err(ArrowError::CsvError(format!(
131 "incorrect number of fields for line {}, expected {} got more than {}",
132 self.line_number, self.num_columns, self.current_field
133 )));
134 }
135 ReadRecordResult::Record => {
136 if self.current_field != self.num_columns {
137 if self.truncated_rows && self.current_field < self.num_columns {
138 let fill_count = self.num_columns - self.current_field;
140 let fill_value = self.offsets[self.offsets_len - 1];
141 self.offsets[self.offsets_len..self.offsets_len + fill_count]
142 .fill(fill_value);
143 self.offsets_len += fill_count;
144 } else {
145 return Err(ArrowError::CsvError(format!(
146 "incorrect number of fields for line {}, expected {} got {}",
147 self.line_number, self.num_columns, self.current_field
148 )));
149 }
150 }
151 read += 1;
152 self.current_field = 0;
153 self.line_number += 1;
154 self.num_rows += 1;
155
156 if read == to_read {
157 return Ok((read, input_offset));
159 }
160
161 if input.len() == input_offset {
162 return Ok((read, input_offset));
166 }
167 }
168 }
169 }
170 }
171 }
172
173 pub fn len(&self) -> usize {
175 self.num_rows
176 }
177
178 pub fn is_empty(&self) -> bool {
180 self.num_rows == 0
181 }
182
183 pub fn clear(&mut self) {
185 self.offsets_len = 1;
187 self.data_len = 0;
188 self.num_rows = 0;
189 }
190
191 pub fn flush(&mut self) -> Result<StringRecords<'_>, ArrowError> {
193 if self.current_field != 0 {
194 return Err(ArrowError::CsvError(
195 "Cannot flush part way through record".to_string(),
196 ));
197 }
198
199 let mut row_offset: usize = 0;
202 self.offsets[1..self.offsets_len]
203 .chunks_exact_mut(self.num_columns)
204 .try_for_each(|row| -> Result<(), ArrowError> {
205 let offset = row_offset;
206 row.iter_mut().try_for_each(|x| -> Result<(), ArrowError> {
207 *x = x.checked_add(offset).ok_or_else(|| {
208 ArrowError::CsvError(
209 "CSV record offsets overflowed usize while flushing".to_string(),
210 )
211 })?;
212 row_offset = *x;
213 Ok(())
214 })
215 })?;
216
217 let data = std::str::from_utf8(&self.data[..self.data_len]).map_err(|e| {
219 let valid_up_to = e.valid_up_to();
220
221 let idx = self.offsets[..self.offsets_len]
223 .iter()
224 .rposition(|x| *x <= valid_up_to)
225 .unwrap();
226
227 let field = idx % self.num_columns + 1;
228 let line_offset = self.line_number - self.num_rows;
229 let line = line_offset + idx / self.num_columns;
230
231 ArrowError::CsvError(format!(
232 "Encountered invalid UTF-8 data for line {line} and field {field}"
233 ))
234 })?;
235
236 let offsets = &self.offsets[..self.offsets_len];
237 let num_rows = self.num_rows;
238
239 self.offsets_len = 1;
241 self.data_len = 0;
242 self.num_rows = 0;
243
244 Ok(StringRecords {
245 num_rows,
246 num_columns: self.num_columns,
247 offsets,
248 data,
249 })
250 }
251}
252
253#[derive(Debug)]
255pub struct StringRecords<'a> {
256 num_columns: usize,
257 num_rows: usize,
258 offsets: &'a [usize],
259 data: &'a str,
260}
261
262impl<'a> StringRecords<'a> {
263 fn get(&self, index: usize) -> StringRecord<'a> {
264 let field_idx = index * self.num_columns;
265 StringRecord {
266 data: self.data,
267 offsets: &self.offsets[field_idx..field_idx + self.num_columns + 1],
268 }
269 }
270
271 pub fn len(&self) -> usize {
272 self.num_rows
273 }
274
275 pub fn iter(&self) -> impl Iterator<Item = StringRecord<'a>> + '_ {
276 (0..self.num_rows).map(|x| self.get(x))
277 }
278}
279
/// A single decoded row, viewing its fields inside the batch's shared text.
#[derive(Debug, Clone, Copy)]
pub struct StringRecord<'a> {
    /// Concatenated text of every field in the batch
    data: &'a str,
    /// `num_fields + 1` byte positions into `data`; field `i` spans
    /// `offsets[i]..offsets[i + 1]`
    offsets: &'a [usize],
}

impl<'a> StringRecord<'a> {
    /// Returns the text of field `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not a valid field index of this record.
    pub fn get(&self, index: usize) -> &'a str {
        let end = self.offsets[index + 1];
        let start = self.offsets[index];

        debug_assert!(
            start <= end && self.data.is_char_boundary(start) && self.data.is_char_boundary(end),
            "field {index} has invalid bounds {start}..{end}"
        );
        // SAFETY: whoever builds this record must guarantee that the offsets are
        // non-decreasing and that each one is a char boundary of `data`. This also
        // implies they are in bounds, because `is_char_boundary` is false past the
        // end. The assertion above checks the contract in debug builds.
        unsafe { self.data.get_unchecked(start..end) }
    }
}

impl std::fmt::Display for StringRecord<'_> {
    /// Formats the record as `[field0,field1,...]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let num_fields = self.offsets.len() - 1;
        f.write_str("[")?;
        for i in 0..num_fields {
            if i > 0 {
                f.write_str(",")?;
            }
            f.write_str(self.get(i))?;
        }
        f.write_str("]")
    }
}
312
#[cfg(test)]
mod tests {
    use crate::reader::records::RecordDecoder;
    use csv_core::Reader;
    use std::io::{BufRead, BufReader, Cursor};

    /// Decodes through a 3-byte `BufReader` so records straddle buffer
    /// boundaries, flushing after every three rows.
    #[test]
    fn test_basic() {
        let csv = [
            "foo,bar,baz",
            "a,b,c",
            "12,3,5",
            "\"asda\"\"asas\",\"sdffsnsd\", as",
        ]
        .join("\n");

        let mut expected = vec![
            vec!["foo", "bar", "baz"],
            vec!["a", "b", "c"],
            vec!["12", "3", "5"],
            vec!["asda\"asas", "sdffsnsd", " as"],
        ]
        .into_iter();

        let batch_size = 3;
        let mut reader = BufReader::with_capacity(3, Cursor::new(csv.as_bytes()));
        let mut decoder = RecordDecoder::new(Reader::new(), 3, false);

        loop {
            // Feed buffers until a full batch is decoded or no bytes are consumed
            let mut rows = 0;
            while rows < batch_size {
                let chunk = reader.fill_buf().unwrap();
                let (decoded, consumed) = decoder.decode(chunk, batch_size - rows).unwrap();
                reader.consume(consumed);
                rows += decoded;
                if consumed == 0 {
                    break;
                }
            }

            if rows == 0 {
                break;
            }

            let batch = decoder.flush().unwrap();
            for (record, want) in batch.iter().zip(&mut expected) {
                let got: Vec<_> = (0..3).map(|i| record.get(i)).collect();
                assert_eq!(got, want);
            }
        }

        assert!(expected.next().is_none());
    }

    /// A short row is rejected with its line number, even after earlier rows
    /// were skipped via decode + clear.
    #[test]
    fn test_invalid_fields() {
        let csv = "a,b\nb,c\na\n";
        let expected = "Csv error: incorrect number of fields for line 3, expected 2 got 1";

        let mut decoder = RecordDecoder::new(Reader::new(), 2, false);
        let err = decoder.decode(csv.as_bytes(), 4).unwrap_err();
        assert_eq!(err.to_string(), expected);

        let mut decoder = RecordDecoder::new(Reader::new(), 2, false);
        let (skipped, consumed) = decoder.decode(csv.as_bytes(), 1).unwrap();
        assert_eq!(skipped, 1);
        decoder.clear();

        let err = decoder.decode(&csv.as_bytes()[consumed..], 3).unwrap_err();
        assert_eq!(err.to_string(), expected);
    }

    /// Requesting more rows than exist returns the rows that are available.
    #[test]
    fn test_skip_insufficient_rows() {
        let csv = "a\nv\n";
        let mut decoder = RecordDecoder::new(Reader::new(), 1, false);
        let (read, consumed) = decoder.decode(csv.as_bytes(), 3).unwrap();
        assert_eq!((read, consumed), (2, csv.len()));
    }

    /// With `truncated_rows` enabled, short rows are accepted.
    #[test]
    fn test_truncated_rows() {
        let csv = "a,b\nv\n,1\n,2\n,3\n";
        let mut decoder = RecordDecoder::new(Reader::new(), 2, true);
        let (read, consumed) = decoder.decode(csv.as_bytes(), 5).unwrap();
        assert_eq!((read, consumed), (5, csv.len()));
    }

    /// Row end offsets whose running total exceeds `usize::MAX` are reported
    /// as an error rather than wrapping.
    #[test]
    fn test_flush_offset_overflow_returns_csv_error() {
        let mut decoder = RecordDecoder::new(Reader::new(), 1, false);
        decoder.offsets = vec![0, usize::MAX, 1];
        decoder.offsets_len = 3;
        decoder.num_rows = 2;
        let message = decoder.flush().unwrap_err().to_string();
        assert_eq!(
            message,
            "Csv error: CSV record offsets overflowed usize while flushing"
        );
    }
}