1use arrow_schema::ArrowError;
19use csv_core::{ReadRecordResult, Reader};
20
/// Estimated number of bytes per field, used to pre-size the decoded `data`
/// buffer before each decoding pass.
const AVERAGE_FIELD_SIZE: usize = 8;

/// Minimum number of bytes the decoded `data` buffer is grown by per pass.
const MIN_CAPACITY: usize = 1_024;
26
/// Incrementally splits CSV input into records with a fixed number of fields,
/// buffering the decoded field bytes and their end offsets until they are
/// handed out by `flush`.
#[derive(Debug)]
pub struct RecordDecoder {
    /// The csv_core reader that performs the field splitting.
    delimiter: Reader,

    /// Expected number of fields per record.
    num_columns: usize,

    /// 1-based number of the next record to be decoded, used in error
    /// messages; incremented once per decoded record.
    line_number: usize,

    /// Field end offsets written by `delimiter`. These are relative to the
    /// start of their row until `flush` rebases them onto the start of `data`.
    /// Index 0 is never written by decoding and stands for the start of the
    /// first field.
    offsets: Vec<usize>,

    /// Number of valid entries in `offsets`, including the leading entry.
    offsets_len: usize,

    /// Number of fields decoded so far for the record currently in progress.
    current_field: usize,

    /// Number of complete records buffered since the last flush or clear.
    num_rows: usize,

    /// Decoded field contents; only the first `data_len` bytes are valid.
    data: Vec<u8>,

    /// Number of valid bytes in `data`.
    data_len: usize,

    /// Whether records with too few fields are padded with empty fields
    /// instead of being rejected.
    truncated_rows: bool,
}
66
67impl RecordDecoder {
68 pub fn new(delimiter: Reader, num_columns: usize, truncated_rows: bool) -> Self {
69 Self {
70 delimiter,
71 num_columns,
72 line_number: 1,
73 offsets: vec![],
74 offsets_len: 1, current_field: 0,
76 data_len: 0,
77 data: vec![],
78 num_rows: 0,
79 truncated_rows,
80 }
81 }
82
83 pub fn decode(&mut self, input: &[u8], to_read: usize) -> Result<(usize, usize), ArrowError> {
87 if to_read == 0 {
88 return Ok((0, 0));
89 }
90
91 self.offsets
93 .resize(self.offsets_len + to_read * self.num_columns, 0);
94
95 let mut input_offset = 0;
97
98 let mut read = 0;
100
101 loop {
102 let remaining_rows = to_read - read;
104 let capacity = remaining_rows * self.num_columns * AVERAGE_FIELD_SIZE;
105 let estimated_data = capacity.max(MIN_CAPACITY);
106 self.data.resize(self.data_len + estimated_data, 0);
107
108 loop {
110 let (result, bytes_read, bytes_written, end_positions) =
111 self.delimiter.read_record(
112 &input[input_offset..],
113 &mut self.data[self.data_len..],
114 &mut self.offsets[self.offsets_len..],
115 );
116
117 self.current_field += end_positions;
118 self.offsets_len += end_positions;
119 input_offset += bytes_read;
120 self.data_len += bytes_written;
121
122 match result {
123 ReadRecordResult::End | ReadRecordResult::InputEmpty => {
124 return Ok((read, input_offset));
126 }
127 ReadRecordResult::OutputFull => break,
129 ReadRecordResult::OutputEndsFull => {
130 return Err(ArrowError::CsvError(format!(
131 "incorrect number of fields for line {}, expected {} got more than {}",
132 self.line_number, self.num_columns, self.current_field
133 )));
134 }
135 ReadRecordResult::Record => {
136 if self.current_field != self.num_columns {
137 if self.truncated_rows && self.current_field < self.num_columns {
138 let fill_count = self.num_columns - self.current_field;
140 let fill_value = self.offsets[self.offsets_len - 1];
141 self.offsets[self.offsets_len..self.offsets_len + fill_count]
142 .fill(fill_value);
143 self.offsets_len += fill_count;
144 } else {
145 return Err(ArrowError::CsvError(format!(
146 "incorrect number of fields for line {}, expected {} got {}",
147 self.line_number, self.num_columns, self.current_field
148 )));
149 }
150 }
151 read += 1;
152 self.current_field = 0;
153 self.line_number += 1;
154 self.num_rows += 1;
155
156 if read == to_read {
157 return Ok((read, input_offset));
159 }
160
161 if input.len() == input_offset {
162 return Ok((read, input_offset));
166 }
167 }
168 }
169 }
170 }
171 }
172
173 pub fn len(&self) -> usize {
175 self.num_rows
176 }
177
178 pub fn is_empty(&self) -> bool {
180 self.num_rows == 0
181 }
182
183 pub fn clear(&mut self) {
185 self.offsets_len = 1;
187 self.data_len = 0;
188 self.num_rows = 0;
189 }
190
191 pub fn flush(&mut self) -> Result<StringRecords<'_>, ArrowError> {
193 if self.current_field != 0 {
194 return Err(ArrowError::CsvError(
195 "Cannot flush part way through record".to_string(),
196 ));
197 }
198
199 let mut row_offset = 0;
202 self.offsets[1..self.offsets_len]
203 .chunks_exact_mut(self.num_columns)
204 .for_each(|row| {
205 let offset = row_offset;
206 row.iter_mut().for_each(|x| {
207 *x += offset;
208 row_offset = *x;
209 });
210 });
211
212 let data = std::str::from_utf8(&self.data[..self.data_len]).map_err(|e| {
214 let valid_up_to = e.valid_up_to();
215
216 let idx = self.offsets[..self.offsets_len]
218 .iter()
219 .rposition(|x| *x <= valid_up_to)
220 .unwrap();
221
222 let field = idx % self.num_columns + 1;
223 let line_offset = self.line_number - self.num_rows;
224 let line = line_offset + idx / self.num_columns;
225
226 ArrowError::CsvError(format!(
227 "Encountered invalid UTF-8 data for line {line} and field {field}"
228 ))
229 })?;
230
231 let offsets = &self.offsets[..self.offsets_len];
232 let num_rows = self.num_rows;
233
234 self.offsets_len = 1;
236 self.data_len = 0;
237 self.num_rows = 0;
238
239 Ok(StringRecords {
240 num_rows,
241 num_columns: self.num_columns,
242 offsets,
243 data,
244 })
245 }
246}
247
248#[derive(Debug)]
250pub struct StringRecords<'a> {
251 num_columns: usize,
252 num_rows: usize,
253 offsets: &'a [usize],
254 data: &'a str,
255}
256
257impl<'a> StringRecords<'a> {
258 fn get(&self, index: usize) -> StringRecord<'a> {
259 let field_idx = index * self.num_columns;
260 StringRecord {
261 data: self.data,
262 offsets: &self.offsets[field_idx..field_idx + self.num_columns + 1],
263 }
264 }
265
266 pub fn len(&self) -> usize {
267 self.num_rows
268 }
269
270 pub fn iter(&self) -> impl Iterator<Item = StringRecord<'a>> + '_ {
271 (0..self.num_rows).map(|x| self.get(x))
272 }
273}
274
/// A single record borrowed from a `StringRecords` batch.
#[derive(Debug, Clone, Copy)]
pub struct StringRecord<'a> {
    /// String holding the contents of every field in the batch.
    data: &'a str,
    /// Boundaries of this record's fields within `data`: field `i` spans
    /// `offsets[i]..offsets[i + 1]`.
    offsets: &'a [usize],
}

impl<'a> StringRecord<'a> {
    /// Returns the contents of the field at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index + 1` is out of bounds for this record's offsets.
    pub fn get(&self, index: usize) -> &'a str {
        let end = self.offsets[index + 1];
        let start = self.offsets[index];

        // Surface violations of the precondition below in debug and test builds
        // instead of silently producing a malformed `&str`.
        debug_assert!(
            start <= end && self.data.is_char_boundary(start) && self.data.is_char_boundary(end),
            "field {index} offsets {start}..{end} do not delimit a valid str slice"
        );

        // SAFETY: `str::get_unchecked` requires `start <= end <= data.len()`
        // with both indices on UTF-8 char boundaries. Whoever builds this
        // record must guarantee every offset it supplies satisfies this.
        unsafe { self.data.get_unchecked(start..end) }
    }
}
292
#[cfg(test)]
mod tests {
    use crate::reader::records::RecordDecoder;
    use csv_core::Reader;
    use std::io::{BufRead, BufReader, Cursor};

    /// Decodes through a 3-byte buffered reader so records straddle buffer
    /// refills, then checks every flushed field against the expected rows.
    #[test]
    fn test_basic() {
        const BATCH_SIZE: usize = 3;

        let csv = [
            "foo,bar,baz",
            "a,b,c",
            "12,3,5",
            "\"asda\"\"asas\",\"sdffsnsd\", as",
        ]
        .join("\n");

        let mut expected = vec![
            vec!["foo", "bar", "baz"],
            vec!["a", "b", "c"],
            vec!["12", "3", "5"],
            vec!["asda\"asas", "sdffsnsd", " as"],
        ]
        .into_iter();

        let mut reader = BufReader::with_capacity(3, Cursor::new(csv.as_bytes()));
        let mut decoder = RecordDecoder::new(Reader::new(), 3, false);

        loop {
            // Keep feeding input until a full batch is decoded or input ends.
            let mut decoded = 0;
            while decoded != BATCH_SIZE {
                let chunk = reader.fill_buf().unwrap();
                let (rows, consumed) = decoder.decode(chunk, BATCH_SIZE - decoded).unwrap();
                reader.consume(consumed);
                decoded += rows;
                if consumed == 0 {
                    break;
                }
            }

            if decoded == 0 {
                break;
            }

            let batch = decoder.flush().unwrap();
            for (record, want) in batch.iter().zip(&mut expected) {
                let got: Vec<_> = (0..3).map(|i| record.get(i)).collect();
                assert_eq!(got, want);
            }
        }

        assert!(expected.next().is_none());
    }

    /// A short row is rejected with its line number, both when decoded
    /// directly and after earlier rows were skipped via `clear`.
    #[test]
    fn test_invalid_fields() {
        let csv = "a,b\nb,c\na\n";
        let expected = "Csv error: incorrect number of fields for line 3, expected 2 got 1";

        let mut decoder = RecordDecoder::new(Reader::new(), 2, false);
        let err = decoder.decode(csv.as_bytes(), 4).unwrap_err();
        assert_eq!(err.to_string(), expected);

        // Skip the first row, then decode the rest.
        let mut decoder = RecordDecoder::new(Reader::new(), 2, false);
        let (skipped, consumed) = decoder.decode(csv.as_bytes(), 1).unwrap();
        assert_eq!(skipped, 1);
        decoder.clear();

        let rest = &csv.as_bytes()[consumed..];
        let err = decoder.decode(rest, 3).unwrap_err();
        assert_eq!(err.to_string(), expected);
    }

    /// Requesting more rows than the input holds returns what is available.
    #[test]
    fn test_skip_insufficient_rows() {
        let csv = "a\nv\n";
        let mut decoder = RecordDecoder::new(Reader::new(), 1, false);
        assert_eq!(decoder.decode(csv.as_bytes(), 3).unwrap(), (2, csv.len()));
    }

    /// With truncated rows enabled, the one-field second row is padded.
    #[test]
    fn test_truncated_rows() {
        let csv = "a,b\nv\n,1\n,2\n,3\n";
        let mut decoder = RecordDecoder::new(Reader::new(), 2, true);
        assert_eq!(decoder.decode(csv.as_bytes(), 5).unwrap(), (5, csv.len()));
    }
}