1use std::collections::HashMap;
19use std::fmt::Debug;
20use std::sync::Arc;
21
22use arrow_array::{ArrayRef, RecordBatch};
23use arrow_buffer::{Buffer, MutableBuffer};
24use arrow_data::UnsafeFlag;
25use arrow_schema::{ArrowError, SchemaRef};
26
27use crate::convert::MessageBuffer;
28use crate::reader::{read_dictionary_impl, RecordBatchDecoder};
29use crate::{MessageHeader, CONTINUATION_MARKER};
30
/// Incrementally decodes [`RecordBatch`]es from an Arrow IPC stream that is supplied as a
/// sequence of [`Buffer`]s via [`StreamDecoder::decode`].
///
/// Messages may be split arbitrarily across input buffers; partially received data is kept
/// in internal state until the rest of the message arrives.
#[derive(Debug, Default)]
pub struct StreamDecoder {
    /// The schema decoded from the stream's schema message, once one has been read
    schema: Option<SchemaRef>,
    /// Dictionaries decoded so far, updated by dictionary batch messages and consulted when
    /// decoding record batches (presumably keyed by dictionary id — confirm in `read_dictionary_impl`)
    dictionaries: HashMap<i64, ArrayRef>,
    /// Current position in the message-decoding state machine
    state: DecoderState,
    /// Scratch buffer accumulating message metadata or body bytes when they are split
    /// across multiple input `Buffer`s
    buf: MutableBuffer,
    /// Forwarded to the record batch and dictionary decoders (defaults to `false`)
    require_alignment: bool,
    /// Forwarded to `read_dictionary_impl` when decoding dictionary batches.
    /// NOTE(review): no setter for this flag is visible in this file, so it stays at its
    /// `Default` value; it is also not forwarded to `RecordBatchDecoder` — confirm intent.
    skip_validation: UnsafeFlag,
}
53
/// States of the message-decoding state machine driven by [`StreamDecoder::decode`].
#[derive(Debug)]
enum DecoderState {
    /// Reading a 4-byte header word: an optional continuation marker followed by the
    /// little-endian length of the message metadata
    Header {
        /// Bytes of the current 4-byte word received so far
        buf: [u8; 4],
        /// Number of bytes of `buf` that have been filled
        read: u8,
        /// Whether a continuation marker has already been consumed for this message
        continuation: bool,
    },
    /// Reading the message metadata
    Message {
        /// Length of the message metadata in bytes, taken from the header word
        size: u32,
    },
    /// Reading the message body, whose length is given by the decoded metadata
    Body {
        /// The decoded message metadata
        message: MessageBuffer,
    },
    /// A zero-length header (end-of-stream marker) has been read; any further input is an error
    Finished,
}
78
79impl Default for DecoderState {
80 fn default() -> Self {
81 Self::Header {
82 buf: [0; 4],
83 read: 0,
84 continuation: false,
85 }
86 }
87}
88
89impl StreamDecoder {
90 pub fn new() -> Self {
92 Self::default()
93 }
94
95 pub fn with_require_alignment(mut self, require_alignment: bool) -> Self {
108 self.require_alignment = require_alignment;
109 self
110 }
111
112 pub fn schema(&self) -> Option<SchemaRef> {
114 self.schema.as_ref().map(|schema| schema.clone())
115 }
116
117 pub fn decode(&mut self, buffer: &mut Buffer) -> Result<Option<RecordBatch>, ArrowError> {
147 while !buffer.is_empty() {
148 match &mut self.state {
149 DecoderState::Header {
150 buf,
151 read,
152 continuation,
153 } => {
154 let offset_buf = &mut buf[*read as usize..];
155 let to_read = buffer.len().min(offset_buf.len());
156 offset_buf[..to_read].copy_from_slice(&buffer[..to_read]);
157 *read += to_read as u8;
158 buffer.advance(to_read);
159 if *read == 4 {
160 if !*continuation && buf == &CONTINUATION_MARKER {
161 *continuation = true;
162 *read = 0;
163 continue;
164 }
165 let size = u32::from_le_bytes(*buf);
166
167 if size == 0 {
168 self.state = DecoderState::Finished;
169 continue;
170 }
171 self.state = DecoderState::Message { size };
172 }
173 }
174 DecoderState::Message { size } => {
175 let len = *size as usize;
176 if self.buf.is_empty() && buffer.len() > len {
177 let message = MessageBuffer::try_new(buffer.slice_with_length(0, len))?;
178 self.state = DecoderState::Body { message };
179 buffer.advance(len);
180 continue;
181 }
182
183 let to_read = buffer.len().min(len - self.buf.len());
184 self.buf.extend_from_slice(&buffer[..to_read]);
185 buffer.advance(to_read);
186 if self.buf.len() == len {
187 let message = MessageBuffer::try_new(std::mem::take(&mut self.buf).into())?;
188 self.state = DecoderState::Body { message };
189 }
190 }
191 DecoderState::Body { message } => {
192 let message = message.as_ref();
193 let body_length = message.bodyLength() as usize;
194
195 let body = if self.buf.is_empty() && buffer.len() >= body_length {
196 let body = buffer.slice_with_length(0, body_length);
197 buffer.advance(body_length);
198 body
199 } else {
200 let to_read = buffer.len().min(body_length - self.buf.len());
201 self.buf.extend_from_slice(&buffer[..to_read]);
202 buffer.advance(to_read);
203
204 if self.buf.len() != body_length {
205 continue;
206 }
207 std::mem::take(&mut self.buf).into()
208 };
209
210 let version = message.version();
211 match message.header_type() {
212 MessageHeader::Schema => {
213 if self.schema.is_some() {
214 return Err(ArrowError::IpcError(
215 "Not expecting a schema when messages are read".to_string(),
216 ));
217 }
218
219 let ipc_schema = message.header_as_schema().unwrap();
220 let schema = crate::convert::fb_to_schema(ipc_schema);
221 self.state = DecoderState::default();
222 self.schema = Some(Arc::new(schema));
223 }
224 MessageHeader::RecordBatch => {
225 let batch = message.header_as_record_batch().unwrap();
226 let schema = self.schema.clone().ok_or_else(|| {
227 ArrowError::IpcError("Missing schema".to_string())
228 })?;
229 let batch = RecordBatchDecoder::try_new(
230 &body,
231 batch,
232 schema,
233 &self.dictionaries,
234 &version,
235 )?
236 .with_require_alignment(self.require_alignment)
237 .read_record_batch()?;
238 self.state = DecoderState::default();
239 return Ok(Some(batch));
240 }
241 MessageHeader::DictionaryBatch => {
242 let dictionary = message.header_as_dictionary_batch().unwrap();
243 let schema = self.schema.as_deref().ok_or_else(|| {
244 ArrowError::IpcError("Missing schema".to_string())
245 })?;
246 read_dictionary_impl(
247 &body,
248 dictionary,
249 schema,
250 &mut self.dictionaries,
251 &version,
252 self.require_alignment,
253 self.skip_validation.clone(),
254 )?;
255 self.state = DecoderState::default();
256 }
257 MessageHeader::NONE => {
258 self.state = DecoderState::default();
259 }
260 t => {
261 return Err(ArrowError::IpcError(format!(
262 "Message type unsupported by StreamDecoder: {t:?}"
263 )))
264 }
265 }
266 }
267 DecoderState::Finished => {
268 return Err(ArrowError::IpcError("Unexpected EOS".to_string()))
269 }
270 }
271 }
272 Ok(None)
273 }
274
275 pub fn finish(&mut self) -> Result<(), ArrowError> {
279 match self.state {
280 DecoderState::Finished
281 | DecoderState::Header {
282 read: 0,
283 continuation: false,
284 ..
285 } => Ok(()),
286 _ => Err(ArrowError::IpcError("Unexpected End of Stream".to_string())),
287 }
288 }
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::writer::{IpcWriteOptions, StreamWriter};
    use arrow_array::{
        types::Int32Type, DictionaryArray, Int32Array, Int64Array, RecordBatch, RunArray,
    };
    use arrow_schema::{DataType, Field, Schema};

    /// A stream truncated inside the end-of-stream marker still yields its batch, but
    /// `finish` reports the truncation.
    #[test]
    fn test_eos() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("int32", DataType::Int32, false),
            Field::new("int64", DataType::Int64, false),
        ]));

        let input = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])) as _,
                Arc::new(Int64Array::from(vec![1, 2, 3])) as _,
            ],
        )
        .unwrap();

        let mut buf = Vec::with_capacity(1024);
        let mut s = StreamWriter::try_new(&mut buf, &schema).unwrap();
        s.write(&input).unwrap();
        s.finish().unwrap();
        drop(s);

        let buffer = Buffer::from_vec(buf);

        let mut b = buffer.slice_with_length(0, buffer.len() - 1);
        let mut decoder = StreamDecoder::new();
        let output = decoder.decode(&mut b).unwrap().unwrap();
        assert_eq!(output, input);
        // The 8 byte end-of-stream marker, truncated by 1 byte, is left unconsumed
        assert_eq!(b.len(), 7);
        assert!(decoder.decode(&mut b).unwrap().is_none());

        let err = decoder.finish().unwrap_err().to_string();
        assert_eq!(err, "Ipc error: Unexpected End of Stream");
    }

    /// A schema-only stream exposes its schema, and a truncated end-of-stream marker is
    /// reported by `finish`.
    #[test]
    fn test_schema() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("int32", DataType::Int32, false),
            Field::new("int64", DataType::Int64, false),
        ]));

        let mut buf = Vec::with_capacity(1024);
        let mut s = StreamWriter::try_new(&mut buf, &schema).unwrap();
        s.finish().unwrap();
        drop(s);

        let buffer = Buffer::from_vec(buf);

        let mut b = buffer.slice_with_length(0, buffer.len() - 1);
        let mut decoder = StreamDecoder::new();
        let output = decoder.decode(&mut b).unwrap();
        assert!(output.is_none());
        let decoded_schema = decoder.schema().unwrap();
        assert_eq!(schema, decoded_schema);

        let err = decoder.finish().unwrap_err().to_string();
        assert_eq!(err, "Ipc error: Unexpected End of Stream");
    }

    /// Run-end encoded arrays with dictionary-encoded values round-trip through the
    /// stream writer and decoder.
    #[test]
    fn test_read_ree_dict_record_batches_from_buffer() {
        let schema = Schema::new(vec![Field::new(
            "test1",
            DataType::RunEndEncoded(
                Arc::new(Field::new("run_ends".to_string(), DataType::Int32, false)),
                #[allow(deprecated)]
                Arc::new(Field::new_dict(
                    "values".to_string(),
                    DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                    true,
                    0,
                    false,
                )),
            ),
            true,
        )]);
        let batch = RecordBatch::try_new(
            schema.clone().into(),
            vec![Arc::new(
                RunArray::try_new(
                    &Int32Array::from(vec![1, 2, 3]),
                    &vec![Some("a"), None, Some("a")]
                        .into_iter()
                        .collect::<DictionaryArray<Int32Type>>(),
                )
                .expect("Failed to create RunArray"),
            )],
        )
        .expect("Failed to create RecordBatch");

        let mut buffer = vec![];
        {
            let mut writer = StreamWriter::try_new_with_options(
                &mut buffer,
                &schema,
                #[allow(deprecated)]
                IpcWriteOptions::default().with_preserve_dict_id(false),
            )
            .expect("Failed to create StreamWriter");
            writer.write(&batch).expect("Failed to write RecordBatch");
            writer.finish().expect("Failed to finish StreamWriter");
        }

        let mut decoder = StreamDecoder::new();
        let buf = &mut Buffer::from(buffer.as_slice());
        let mut decoded_count = 0;
        while let Some(decoded) = decoder
            .decode(buf)
            .map_err(|e| {
                ArrowError::ExternalError(format!("Failed to decode record batch: {}", e).into())
            })
            .expect("Failed to decode record batch")
        {
            // Compare against the batch that was written, not the decoded batch with itself
            assert_eq!(decoded, batch);
            decoded_count += 1;
        }
        assert_eq!(decoded_count, 1);

        decoder.finish().expect("Failed to finish decoder");
    }
}