arrow_flight/
decode.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::{trailers::LazyTrailers, utils::flight_data_to_arrow_batch, FlightData};
19use arrow_array::{ArrayRef, RecordBatch};
20use arrow_buffer::Buffer;
21use arrow_schema::{Schema, SchemaRef};
22use bytes::Bytes;
23use futures::{ready, stream::BoxStream, Stream, StreamExt};
24use std::{collections::HashMap, fmt::Debug, pin::Pin, sync::Arc, task::Poll};
25use tonic::metadata::MetadataMap;
26
27use crate::error::{FlightError, Result};
28
29/// Decodes a [Stream] of [`FlightData`] back into
30/// [`RecordBatch`]es. This can be used to decode the response from an
31/// Arrow Flight server
32///
33/// # Note
34/// To access the lower level Flight messages (e.g. to access
35/// [`FlightData::app_metadata`]), you can call [`Self::into_inner`]
36/// and use the [`FlightDataDecoder`] directly.
37///
38/// # Example:
39/// ```no_run
40/// # async fn f() -> Result<(), arrow_flight::error::FlightError>{
41/// # use bytes::Bytes;
42/// // make a do_get request
43/// use arrow_flight::{
44///   error::Result,
45///   decode::FlightRecordBatchStream,
46///   Ticket,
47///   flight_service_client::FlightServiceClient
48/// };
49/// use tonic::transport::Channel;
50/// use futures::stream::{StreamExt, TryStreamExt};
51///
52/// let client: FlightServiceClient<Channel> = // make client..
53/// # unimplemented!();
54///
55/// let request = tonic::Request::new(
56///   Ticket { ticket: Bytes::new() }
57/// );
58///
59/// // Get a stream of FlightData;
60/// let flight_data_stream = client
61///   .do_get(request)
62///   .await?
63///   .into_inner();
64///
65/// // Decode stream of FlightData to RecordBatches
66/// let record_batch_stream = FlightRecordBatchStream::new_from_flight_data(
67///   // convert tonic::Status to FlightError
68///   flight_data_stream.map_err(|e| e.into())
69/// );
70///
71/// // Read back RecordBatches
72/// while let Some(batch) = record_batch_stream.next().await {
73///   match batch {
74///     Ok(batch) => { /* process batch */ },
75///     Err(e) => { /* handle error */ },
76///   };
77/// }
78///
79/// # Ok(())
80/// # }
81/// ```
82#[derive(Debug)]
83pub struct FlightRecordBatchStream {
84    /// Optional grpc header metadata.
85    headers: MetadataMap,
86
87    /// Optional grpc trailer metadata.
88    trailers: Option<LazyTrailers>,
89
90    inner: FlightDataDecoder,
91}
92
93impl FlightRecordBatchStream {
94    /// Create a new [`FlightRecordBatchStream`] from a decoded stream
95    pub fn new(inner: FlightDataDecoder) -> Self {
96        Self {
97            inner,
98            headers: MetadataMap::default(),
99            trailers: None,
100        }
101    }
102
103    /// Create a new [`FlightRecordBatchStream`] from a stream of [`FlightData`]
104    pub fn new_from_flight_data<S>(inner: S) -> Self
105    where
106        S: Stream<Item = Result<FlightData>> + Send + 'static,
107    {
108        Self {
109            inner: FlightDataDecoder::new(inner),
110            headers: MetadataMap::default(),
111            trailers: None,
112        }
113    }
114
115    /// Record response headers.
116    pub fn with_headers(self, headers: MetadataMap) -> Self {
117        Self { headers, ..self }
118    }
119
120    /// Record response trailers.
121    pub fn with_trailers(self, trailers: LazyTrailers) -> Self {
122        Self {
123            trailers: Some(trailers),
124            ..self
125        }
126    }
127
128    /// Headers attached to this stream.
129    pub fn headers(&self) -> &MetadataMap {
130        &self.headers
131    }
132
133    /// Trailers attached to this stream.
134    ///
135    /// Note that this will return `None` until the entire stream is consumed.
136    /// Only after calling `next()` returns `None`, might any available trailers be returned.
137    pub fn trailers(&self) -> Option<MetadataMap> {
138        self.trailers.as_ref().and_then(|trailers| trailers.get())
139    }
140
141    /// Has a message defining the schema been received yet?
142    #[deprecated = "use schema().is_some() instead"]
143    pub fn got_schema(&self) -> bool {
144        self.schema().is_some()
145    }
146
147    /// Return schema for the stream, if it has been received
148    pub fn schema(&self) -> Option<&SchemaRef> {
149        self.inner.schema()
150    }
151
152    /// Consume self and return the wrapped [`FlightDataDecoder`]
153    pub fn into_inner(self) -> FlightDataDecoder {
154        self.inner
155    }
156}
157
158impl futures::Stream for FlightRecordBatchStream {
159    type Item = Result<RecordBatch>;
160
161    /// Returns the next [`RecordBatch`] available in this stream, or `None` if
162    /// there are no further results available.
163    fn poll_next(
164        mut self: Pin<&mut Self>,
165        cx: &mut std::task::Context<'_>,
166    ) -> Poll<Option<Result<RecordBatch>>> {
167        loop {
168            let had_schema = self.schema().is_some();
169            let res = ready!(self.inner.poll_next_unpin(cx));
170            match res {
171                // Inner exhausted
172                None => {
173                    return Poll::Ready(None);
174                }
175                Some(Err(e)) => {
176                    return Poll::Ready(Some(Err(e)));
177                }
178                // translate data
179                Some(Ok(data)) => match data.payload {
180                    DecodedPayload::Schema(_) if had_schema => {
181                        return Poll::Ready(Some(Err(FlightError::protocol(
182                            "Unexpectedly saw multiple Schema messages in FlightData stream",
183                        ))));
184                    }
185                    DecodedPayload::Schema(_) => {
186                        // Need next message, poll inner again
187                    }
188                    DecodedPayload::RecordBatch(batch) => {
189                        return Poll::Ready(Some(Ok(batch)));
190                    }
191                    DecodedPayload::None => {
192                        // Need next message
193                    }
194                },
195            }
196        }
197    }
198}
199
200/// Wrapper around a stream of [`FlightData`] that handles the details
201/// of decoding low level Flight messages into [`Schema`] and
202/// [`RecordBatch`]es, including details such as dictionaries.
203///
204/// # Protocol Details
205///
206/// The client handles flight messages as followes:
207///
208/// - **None:** This message has no effect. This is useful to
209///   transmit metadata without any actual payload.
210///
211/// - **Schema:** The schema is (re-)set. Dictionaries are cleared and
212///   the decoded schema is returned.
213///
214/// - **Dictionary Batch:** A new dictionary for a given column is registered. An existing
215///   dictionary for the same column will be overwritten. This
216///   message is NOT visible.
217///
218/// - **Record Batch:** Record batch is created based on the current
219///   schema and dictionaries. This fails if no schema was transmitted
220///   yet.
221///
222/// All other message types (at the time of writing: e.g. tensor and
223/// sparse tensor) lead to an error.
224///
225/// Example usecases
226///
227/// 1. Using this low level stream it is possible to receive a steam
228///    of RecordBatches in FlightData that have different schemas by
229///    handling multiple schema messages separately.
230pub struct FlightDataDecoder {
231    /// Underlying data stream
232    response: BoxStream<'static, Result<FlightData>>,
233    /// Decoding state
234    state: Option<FlightStreamState>,
235    /// Seen the end of the inner stream?
236    done: bool,
237}
238
239impl Debug for FlightDataDecoder {
240    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241        f.debug_struct("FlightDataDecoder")
242            .field("response", &"<stream>")
243            .field("state", &self.state)
244            .field("done", &self.done)
245            .finish()
246    }
247}
248
249impl FlightDataDecoder {
250    /// Create a new wrapper around the stream of [`FlightData`]
251    pub fn new<S>(response: S) -> Self
252    where
253        S: Stream<Item = Result<FlightData>> + Send + 'static,
254    {
255        Self {
256            state: None,
257            response: response.boxed(),
258            done: false,
259        }
260    }
261
262    /// Returns the current schema for this stream
263    pub fn schema(&self) -> Option<&SchemaRef> {
264        self.state.as_ref().map(|state| &state.schema)
265    }
266
267    /// Extracts flight data from the next message, updating decoding
268    /// state as necessary.
269    fn extract_message(&mut self, data: FlightData) -> Result<Option<DecodedFlightData>> {
270        use arrow_ipc::MessageHeader;
271        let message = arrow_ipc::root_as_message(&data.data_header[..])
272            .map_err(|e| FlightError::DecodeError(format!("Error decoding root message: {e}")))?;
273
274        match message.header_type() {
275            MessageHeader::NONE => Ok(Some(DecodedFlightData::new_none(data))),
276            MessageHeader::Schema => {
277                let schema = Schema::try_from(&data)
278                    .map_err(|e| FlightError::DecodeError(format!("Error decoding schema: {e}")))?;
279
280                let schema = Arc::new(schema);
281                let dictionaries_by_field = HashMap::new();
282
283                self.state = Some(FlightStreamState {
284                    schema: Arc::clone(&schema),
285                    dictionaries_by_field,
286                });
287                Ok(Some(DecodedFlightData::new_schema(data, schema)))
288            }
289            MessageHeader::DictionaryBatch => {
290                let state = if let Some(state) = self.state.as_mut() {
291                    state
292                } else {
293                    return Err(FlightError::protocol(
294                        "Received DictionaryBatch prior to Schema",
295                    ));
296                };
297
298                let buffer = Buffer::from(data.data_body);
299                let dictionary_batch = message.header_as_dictionary_batch().ok_or_else(|| {
300                    FlightError::protocol(
301                        "Could not get dictionary batch from DictionaryBatch message",
302                    )
303                })?;
304
305                arrow_ipc::reader::read_dictionary(
306                    &buffer,
307                    dictionary_batch,
308                    &state.schema,
309                    &mut state.dictionaries_by_field,
310                    &message.version(),
311                )
312                .map_err(|e| {
313                    FlightError::DecodeError(format!("Error decoding ipc dictionary: {e}"))
314                })?;
315
316                // Updated internal state, but no decoded message
317                Ok(None)
318            }
319            MessageHeader::RecordBatch => {
320                let state = if let Some(state) = self.state.as_ref() {
321                    state
322                } else {
323                    return Err(FlightError::protocol(
324                        "Received RecordBatch prior to Schema",
325                    ));
326                };
327
328                let batch = flight_data_to_arrow_batch(
329                    &data,
330                    Arc::clone(&state.schema),
331                    &state.dictionaries_by_field,
332                )
333                .map_err(|e| {
334                    FlightError::DecodeError(format!("Error decoding ipc RecordBatch: {e}"))
335                })?;
336
337                Ok(Some(DecodedFlightData::new_record_batch(data, batch)))
338            }
339            other => {
340                let name = other.variant_name().unwrap_or("UNKNOWN");
341                Err(FlightError::protocol(format!("Unexpected message: {name}")))
342            }
343        }
344    }
345}
346
347impl futures::Stream for FlightDataDecoder {
348    type Item = Result<DecodedFlightData>;
349    /// Returns the result of decoding the next [`FlightData`] message
350    /// from the server, or `None` if there are no further results
351    /// available.
352    fn poll_next(
353        mut self: Pin<&mut Self>,
354        cx: &mut std::task::Context<'_>,
355    ) -> Poll<Option<Self::Item>> {
356        if self.done {
357            return Poll::Ready(None);
358        }
359        loop {
360            let res = ready!(self.response.poll_next_unpin(cx));
361
362            return Poll::Ready(match res {
363                None => {
364                    self.done = true;
365                    None // inner is exhausted
366                }
367                Some(data) => Some(match data {
368                    Err(e) => Err(e),
369                    Ok(data) => match self.extract_message(data) {
370                        Ok(Some(extracted)) => Ok(extracted),
371                        Ok(None) => continue, // Need next input message
372                        Err(e) => Err(e),
373                    },
374                }),
375            });
376        }
377    }
378}
379
380/// tracks the state needed to reconstruct [`RecordBatch`]es from a
381/// streaming flight response.
382#[derive(Debug)]
383struct FlightStreamState {
384    schema: SchemaRef,
385    dictionaries_by_field: HashMap<i64, ArrayRef>,
386}
387
388/// FlightData and the decoded payload (Schema, RecordBatch), if any
389#[derive(Debug)]
390pub struct DecodedFlightData {
391    /// The original FlightData message
392    pub inner: FlightData,
393    /// The decoded payload
394    pub payload: DecodedPayload,
395}
396
397impl DecodedFlightData {
398    /// Create a new DecodedFlightData with no payload
399    pub fn new_none(inner: FlightData) -> Self {
400        Self {
401            inner,
402            payload: DecodedPayload::None,
403        }
404    }
405
406    /// Create a new DecodedFlightData with a [`Schema`] payload
407    pub fn new_schema(inner: FlightData, schema: SchemaRef) -> Self {
408        Self {
409            inner,
410            payload: DecodedPayload::Schema(schema),
411        }
412    }
413
414    /// Create a new [`DecodedFlightData`] with a [`RecordBatch`] payload
415    pub fn new_record_batch(inner: FlightData, batch: RecordBatch) -> Self {
416        Self {
417            inner,
418            payload: DecodedPayload::RecordBatch(batch),
419        }
420    }
421
422    /// Return the metadata field of the inner flight data
423    pub fn app_metadata(&self) -> Bytes {
424        self.inner.app_metadata.clone()
425    }
426}
427
428/// The result of decoding [`FlightData`]
429#[derive(Debug)]
430pub enum DecodedPayload {
431    /// None (no data was sent in the corresponding FlightData)
432    None,
433
434    /// A decoded Schema message
435    Schema(SchemaRef),
436
437    /// A decoded Record batch.
438    RecordBatch(RecordBatch),
439}