arrow_ipc/reader/
stream.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::collections::HashMap;
19use std::fmt::Debug;
20use std::sync::Arc;
21
22use arrow_array::{ArrayRef, RecordBatch};
23use arrow_buffer::{Buffer, MutableBuffer};
24use arrow_data::UnsafeFlag;
25use arrow_schema::{ArrowError, SchemaRef};
26
27use crate::convert::MessageBuffer;
28use crate::reader::{read_dictionary_impl, RecordBatchDecoder};
29use crate::{MessageHeader, CONTINUATION_MARKER};
30
/// A low-level interface for reading [`RecordBatch`] data from a stream of bytes
///
/// Bytes are pushed into the decoder via `decode`, which keeps enough state between
/// calls to resume when a message is split across multiple input buffers.
///
/// See [StreamReader](crate::reader::StreamReader) for a higher-level interface
#[derive(Debug, Default)]
pub struct StreamDecoder {
    /// The schema of this decoder, if read
    ///
    /// `None` until a schema message has been decoded.
    schema: Option<SchemaRef>,
    /// Lookup table for dictionaries by ID
    ///
    /// Populated when dictionary batch messages are decoded and consulted when
    /// decoding record batches.
    dictionaries: HashMap<i64, ArrayRef>,
    /// The decoder state
    ///
    /// Tracks which part of a message (header, flatbuffer, or body) is expected next.
    state: DecoderState,
    /// A scratch buffer when a read is split across multiple `Buffer`
    ///
    /// Left empty whenever a message flatbuffer or body can be sliced directly
    /// out of the input buffer.
    buf: MutableBuffer,
    /// Whether or not array data in input buffers are required to be aligned
    require_alignment: bool,
    /// Should validation be skipped when reading data? Defaults to false.
    ///
    /// This flag is passed to `read_dictionary_impl` when decoding dictionary batches.
    ///
    /// See [`FileDecoder::with_skip_validation`] for details.
    ///
    /// [`FileDecoder::with_skip_validation`]: crate::reader::FileDecoder::with_skip_validation
    skip_validation: UnsafeFlag,
}
53
/// The position of a [`StreamDecoder`] within the message currently being decoded
#[derive(Debug)]
enum DecoderState {
    /// Decoding the message header
    ///
    /// The header is a 4-byte value: either the continuation marker or the
    /// little-endian size of the message flatbuffer that follows.
    Header {
        /// Temporary buffer accumulating the 4 header bytes
        buf: [u8; 4],
        /// Number of bytes read into buf
        read: u8,
        /// If we have read a continuation token
        continuation: bool,
    },
    /// Decoding the message flatbuffer
    Message {
        /// The size of the message flatbuffer
        size: u32,
    },
    /// Decoding the message body
    Body {
        /// The message flatbuffer
        message: MessageBuffer,
    },
    /// Reached the end of the stream
    ///
    /// Entered when a header with a size of zero is read.
    Finished,
}
78
79impl Default for DecoderState {
80    fn default() -> Self {
81        Self::Header {
82            buf: [0; 4],
83            read: 0,
84            continuation: false,
85        }
86    }
87}
88
89impl StreamDecoder {
90    /// Create a new [`StreamDecoder`]
91    pub fn new() -> Self {
92        Self::default()
93    }
94
95    /// Specifies whether or not array data in input buffers is required to be properly aligned.
96    ///
97    /// If `require_alignment` is true, this decoder will return an error if any array data in the
98    /// input `buf` is not properly aligned.
99    /// Under the hood it will use [`arrow_data::ArrayDataBuilder::build`] to construct
100    /// [`arrow_data::ArrayData`].
101    ///
102    /// If `require_alignment` is false (the default), this decoder will automatically allocate a
103    /// new aligned buffer and copy over the data if any array data in the input `buf` is not
104    /// properly aligned. (Properly aligned array data will remain zero-copy.)
105    /// Under the hood it will use [`arrow_data::ArrayDataBuilder::build_aligned`] to construct
106    /// [`arrow_data::ArrayData`].
107    pub fn with_require_alignment(mut self, require_alignment: bool) -> Self {
108        self.require_alignment = require_alignment;
109        self
110    }
111
112    /// Return the schema if decoded, else None.
113    pub fn schema(&self) -> Option<SchemaRef> {
114        self.schema.as_ref().map(|schema| schema.clone())
115    }
116
117    /// Try to read the next [`RecordBatch`] from the provided [`Buffer`]
118    ///
119    /// [`Buffer::advance`] will be called on `buffer` for any consumed bytes.
120    ///
121    /// The push-based interface facilitates integration with sources that yield arbitrarily
122    /// delimited bytes ranges, such as a chunked byte stream received from object storage
123    ///
124    /// ```
125    /// # use arrow_array::RecordBatch;
126    /// # use arrow_buffer::Buffer;
127    /// # use arrow_ipc::reader::StreamDecoder;
128    /// # use arrow_schema::ArrowError;
129    /// #
130    /// fn print_stream<I>(src: impl Iterator<Item = Buffer>) -> Result<(), ArrowError> {
131    ///     let mut decoder = StreamDecoder::new();
132    ///     for mut x in src {
133    ///         while !x.is_empty() {
134    ///             if let Some(x) = decoder.decode(&mut x)? {
135    ///                 println!("{x:?}");
136    ///             }
137    ///             if let Some(schema) = decoder.schema() {
138    ///                 println!("Schema: {schema:?}");
139    ///             }
140    ///         }
141    ///     }
142    ///     decoder.finish().unwrap();
143    ///     Ok(())
144    /// }
145    /// ```
146    pub fn decode(&mut self, buffer: &mut Buffer) -> Result<Option<RecordBatch>, ArrowError> {
147        while !buffer.is_empty() {
148            match &mut self.state {
149                DecoderState::Header {
150                    buf,
151                    read,
152                    continuation,
153                } => {
154                    let offset_buf = &mut buf[*read as usize..];
155                    let to_read = buffer.len().min(offset_buf.len());
156                    offset_buf[..to_read].copy_from_slice(&buffer[..to_read]);
157                    *read += to_read as u8;
158                    buffer.advance(to_read);
159                    if *read == 4 {
160                        if !*continuation && buf == &CONTINUATION_MARKER {
161                            *continuation = true;
162                            *read = 0;
163                            continue;
164                        }
165                        let size = u32::from_le_bytes(*buf);
166
167                        if size == 0 {
168                            self.state = DecoderState::Finished;
169                            continue;
170                        }
171                        self.state = DecoderState::Message { size };
172                    }
173                }
174                DecoderState::Message { size } => {
175                    let len = *size as usize;
176                    if self.buf.is_empty() && buffer.len() > len {
177                        let message = MessageBuffer::try_new(buffer.slice_with_length(0, len))?;
178                        self.state = DecoderState::Body { message };
179                        buffer.advance(len);
180                        continue;
181                    }
182
183                    let to_read = buffer.len().min(len - self.buf.len());
184                    self.buf.extend_from_slice(&buffer[..to_read]);
185                    buffer.advance(to_read);
186                    if self.buf.len() == len {
187                        let message = MessageBuffer::try_new(std::mem::take(&mut self.buf).into())?;
188                        self.state = DecoderState::Body { message };
189                    }
190                }
191                DecoderState::Body { message } => {
192                    let message = message.as_ref();
193                    let body_length = message.bodyLength() as usize;
194
195                    let body = if self.buf.is_empty() && buffer.len() >= body_length {
196                        let body = buffer.slice_with_length(0, body_length);
197                        buffer.advance(body_length);
198                        body
199                    } else {
200                        let to_read = buffer.len().min(body_length - self.buf.len());
201                        self.buf.extend_from_slice(&buffer[..to_read]);
202                        buffer.advance(to_read);
203
204                        if self.buf.len() != body_length {
205                            continue;
206                        }
207                        std::mem::take(&mut self.buf).into()
208                    };
209
210                    let version = message.version();
211                    match message.header_type() {
212                        MessageHeader::Schema => {
213                            if self.schema.is_some() {
214                                return Err(ArrowError::IpcError(
215                                    "Not expecting a schema when messages are read".to_string(),
216                                ));
217                            }
218
219                            let ipc_schema = message.header_as_schema().unwrap();
220                            let schema = crate::convert::fb_to_schema(ipc_schema);
221                            self.state = DecoderState::default();
222                            self.schema = Some(Arc::new(schema));
223                        }
224                        MessageHeader::RecordBatch => {
225                            let batch = message.header_as_record_batch().unwrap();
226                            let schema = self.schema.clone().ok_or_else(|| {
227                                ArrowError::IpcError("Missing schema".to_string())
228                            })?;
229                            let batch = RecordBatchDecoder::try_new(
230                                &body,
231                                batch,
232                                schema,
233                                &self.dictionaries,
234                                &version,
235                            )?
236                            .with_require_alignment(self.require_alignment)
237                            .read_record_batch()?;
238                            self.state = DecoderState::default();
239                            return Ok(Some(batch));
240                        }
241                        MessageHeader::DictionaryBatch => {
242                            let dictionary = message.header_as_dictionary_batch().unwrap();
243                            let schema = self.schema.as_deref().ok_or_else(|| {
244                                ArrowError::IpcError("Missing schema".to_string())
245                            })?;
246                            read_dictionary_impl(
247                                &body,
248                                dictionary,
249                                schema,
250                                &mut self.dictionaries,
251                                &version,
252                                self.require_alignment,
253                                self.skip_validation.clone(),
254                            )?;
255                            self.state = DecoderState::default();
256                        }
257                        MessageHeader::NONE => {
258                            self.state = DecoderState::default();
259                        }
260                        t => {
261                            return Err(ArrowError::IpcError(format!(
262                                "Message type unsupported by StreamDecoder: {t:?}"
263                            )))
264                        }
265                    }
266                }
267                DecoderState::Finished => {
268                    return Err(ArrowError::IpcError("Unexpected EOS".to_string()))
269                }
270            }
271        }
272        Ok(None)
273    }
274
275    /// Signal the end of stream
276    ///
277    /// Returns an error if any partial data remains in the stream
278    pub fn finish(&mut self) -> Result<(), ArrowError> {
279        match self.state {
280            DecoderState::Finished
281            | DecoderState::Header {
282                read: 0,
283                continuation: false,
284                ..
285            } => Ok(()),
286            _ => Err(ArrowError::IpcError("Unexpected End of Stream".to_string())),
287        }
288    }
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::writer::{IpcWriteOptions, StreamWriter};
    use arrow_array::{
        types::Int32Type, DictionaryArray, Int32Array, Int64Array, RecordBatch, RunArray,
    };
    use arrow_schema::{DataType, Field, Schema};

    // Further tests in arrow-integration-testing/tests/ipc_reader.rs

    /// A stream whose end-of-stream marker is truncated must decode its batch
    /// but fail on `finish`
    #[test]
    fn test_eos() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("int32", DataType::Int32, false),
            Field::new("int64", DataType::Int64, false),
        ]));

        let input = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])) as _,
                Arc::new(Int64Array::from(vec![1, 2, 3])) as _,
            ],
        )
        .unwrap();

        let mut buf = Vec::with_capacity(1024);
        let mut s = StreamWriter::try_new(&mut buf, &schema).unwrap();
        s.write(&input).unwrap();
        s.finish().unwrap();
        drop(s);

        let buffer = Buffer::from_vec(buf);

        // Drop the final byte so that the end-of-stream marker is incomplete
        let mut b = buffer.slice_with_length(0, buffer.len() - 1);
        let mut decoder = StreamDecoder::new();
        let output = decoder.decode(&mut b).unwrap().unwrap();
        assert_eq!(output, input);
        assert_eq!(b.len(), 7); // 8 byte EOS truncated by 1 byte
        assert!(decoder.decode(&mut b).unwrap().is_none());

        let err = decoder.finish().unwrap_err().to_string();
        assert_eq!(err, "Ipc error: Unexpected End of Stream");
    }

    /// The schema must be available after decoding a stream containing no batches
    #[test]
    fn test_schema() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("int32", DataType::Int32, false),
            Field::new("int64", DataType::Int64, false),
        ]));

        let mut buf = Vec::with_capacity(1024);
        let mut s = StreamWriter::try_new(&mut buf, &schema).unwrap();
        s.finish().unwrap();
        drop(s);

        let buffer = Buffer::from_vec(buf);

        // Drop the final byte so that the end-of-stream marker is incomplete
        let mut b = buffer.slice_with_length(0, buffer.len() - 1);
        let mut decoder = StreamDecoder::new();
        let output = decoder.decode(&mut b).unwrap();
        assert!(output.is_none());
        let decoded_schema = decoder.schema().unwrap();
        assert_eq!(schema, decoded_schema);

        let err = decoder.finish().unwrap_err().to_string();
        assert_eq!(err, "Ipc error: Unexpected End of Stream");
    }

    /// A run-end encoded column with dictionary encoded values must round-trip
    /// through the stream writer and decoder
    #[test]
    fn test_read_ree_dict_record_batches_from_buffer() {
        let schema = Schema::new(vec![Field::new(
            "test1",
            DataType::RunEndEncoded(
                Arc::new(Field::new("run_ends".to_string(), DataType::Int32, false)),
                #[allow(deprecated)]
                Arc::new(Field::new_dict(
                    "values".to_string(),
                    DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
                    true,
                    0,
                    false,
                )),
            ),
            true,
        )]);
        let batch = RecordBatch::try_new(
            schema.clone().into(),
            vec![Arc::new(
                RunArray::try_new(
                    &Int32Array::from(vec![1, 2, 3]),
                    &vec![Some("a"), None, Some("a")]
                        .into_iter()
                        .collect::<DictionaryArray<Int32Type>>(),
                )
                .expect("Failed to create RunArray"),
            )],
        )
        .expect("Failed to create RecordBatch");

        let mut buffer = vec![];
        {
            let mut writer = StreamWriter::try_new_with_options(
                &mut buffer,
                &schema,
                #[allow(deprecated)]
                IpcWriteOptions::default().with_preserve_dict_id(false),
            )
            .expect("Failed to create StreamWriter");
            writer.write(&batch).expect("Failed to write RecordBatch");
            writer.finish().expect("Failed to finish StreamWriter");
        }

        let mut decoder = StreamDecoder::new();
        let buf = &mut Buffer::from(buffer.as_slice());
        // Use a distinct name for the decoded batch so that it is compared against
        // the batch that was written, rather than shadowing it and comparing the
        // decoded batch with itself
        let mut num_batches = 0;
        while let Some(decoded) = decoder
            .decode(buf)
            .expect("Failed to decode record batch")
        {
            assert_eq!(decoded, batch);
            num_batches += 1;
        }
        // Exactly one batch was written, so exactly one must be decoded
        assert_eq!(num_batches, 1);

        decoder.finish().expect("Failed to finish decoder");
    }
}