arrow_flight/
decode.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::{trailers::LazyTrailers, utils::flight_data_to_arrow_batch, FlightData};
19use arrow_array::{ArrayRef, RecordBatch};
20use arrow_buffer::Buffer;
21use arrow_schema::{Schema, SchemaRef};
22use bytes::Bytes;
23use futures::{ready, stream::BoxStream, Stream, StreamExt};
24use std::{collections::HashMap, fmt::Debug, pin::Pin, sync::Arc, task::Poll};
25use tonic::metadata::MetadataMap;
26
27use crate::error::{FlightError, Result};
28
/// Decodes a [Stream] of [`FlightData`] back into
/// [`RecordBatch`]es. This can be used to decode the response from an
/// Arrow Flight server.
///
/// Schema messages and messages without a payload are consumed
/// internally and are not yielded; only [`RecordBatch`]es (or errors)
/// are produced by this stream.
///
/// # Errors
/// A second Schema message, arriving after a schema has already been
/// received, is reported as a protocol error. Errors produced by the
/// wrapped [`FlightDataDecoder`] are passed through unchanged.
///
/// # Note
/// To access the lower level Flight messages (e.g. to access
/// [`FlightData::app_metadata`]), you can call [`Self::into_inner`]
/// and use the [`FlightDataDecoder`] directly.
///
/// # Example:
/// ```no_run
/// # async fn f() -> Result<(), arrow_flight::error::FlightError>{
/// # use bytes::Bytes;
/// // make a do_get request
/// use arrow_flight::{
///   error::Result,
///   decode::FlightRecordBatchStream,
///   Ticket,
///   flight_service_client::FlightServiceClient
/// };
/// use tonic::transport::Channel;
/// use futures::stream::{StreamExt, TryStreamExt};
///
/// let mut client: FlightServiceClient<Channel> = // make client..
/// # unimplemented!();
///
/// let request = tonic::Request::new(
///   Ticket { ticket: Bytes::new() }
/// );
///
/// // Get a stream of FlightData;
/// let flight_data_stream = client
///   .do_get(request)
///   .await?
///   .into_inner();
///
/// // Decode stream of FlightData to RecordBatches
/// let mut record_batch_stream = FlightRecordBatchStream::new_from_flight_data(
///   // convert tonic::Status to FlightError
///   flight_data_stream.map_err(|e| e.into())
/// );
///
/// // Read back RecordBatches
/// while let Some(batch) = record_batch_stream.next().await {
///   match batch {
///     Ok(batch) => { /* process batch */ },
///     Err(e) => { /* handle error */ },
///   };
/// }
///
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
pub struct FlightRecordBatchStream {
    /// Optional grpc header metadata. Empty unless replaced via
    /// [`Self::with_headers`].
    headers: MetadataMap,

    /// Optional grpc trailer metadata. `None` unless set via
    /// [`Self::with_trailers`].
    trailers: Option<LazyTrailers>,

    /// Lower level decoder that yields [`DecodedFlightData`] messages;
    /// it also tracks the most recently received schema.
    inner: FlightDataDecoder,
}
92
93impl FlightRecordBatchStream {
94    /// Create a new [`FlightRecordBatchStream`] from a decoded stream
95    pub fn new(inner: FlightDataDecoder) -> Self {
96        Self {
97            inner,
98            headers: MetadataMap::default(),
99            trailers: None,
100        }
101    }
102
103    /// Create a new [`FlightRecordBatchStream`] from a stream of [`FlightData`]
104    pub fn new_from_flight_data<S>(inner: S) -> Self
105    where
106        S: Stream<Item = Result<FlightData>> + Send + 'static,
107    {
108        Self {
109            inner: FlightDataDecoder::new(inner),
110            headers: MetadataMap::default(),
111            trailers: None,
112        }
113    }
114
115    /// Record response headers.
116    pub fn with_headers(self, headers: MetadataMap) -> Self {
117        Self { headers, ..self }
118    }
119
120    /// Record response trailers.
121    pub fn with_trailers(self, trailers: LazyTrailers) -> Self {
122        Self {
123            trailers: Some(trailers),
124            ..self
125        }
126    }
127
128    /// Headers attached to this stream.
129    pub fn headers(&self) -> &MetadataMap {
130        &self.headers
131    }
132
133    /// Trailers attached to this stream.
134    ///
135    /// Note that this will return `None` until the entire stream is consumed.
136    /// Only after calling `next()` returns `None`, might any available trailers be returned.
137    pub fn trailers(&self) -> Option<MetadataMap> {
138        self.trailers.as_ref().and_then(|trailers| trailers.get())
139    }
140
141    /// Return schema for the stream, if it has been received
142    pub fn schema(&self) -> Option<&SchemaRef> {
143        self.inner.schema()
144    }
145
146    /// Consume self and return the wrapped [`FlightDataDecoder`]
147    pub fn into_inner(self) -> FlightDataDecoder {
148        self.inner
149    }
150}
151
152impl futures::Stream for FlightRecordBatchStream {
153    type Item = Result<RecordBatch>;
154
155    /// Returns the next [`RecordBatch`] available in this stream, or `None` if
156    /// there are no further results available.
157    fn poll_next(
158        mut self: Pin<&mut Self>,
159        cx: &mut std::task::Context<'_>,
160    ) -> Poll<Option<Result<RecordBatch>>> {
161        loop {
162            let had_schema = self.schema().is_some();
163            let res = ready!(self.inner.poll_next_unpin(cx));
164            match res {
165                // Inner exhausted
166                None => {
167                    return Poll::Ready(None);
168                }
169                Some(Err(e)) => {
170                    return Poll::Ready(Some(Err(e)));
171                }
172                // translate data
173                Some(Ok(data)) => match data.payload {
174                    DecodedPayload::Schema(_) if had_schema => {
175                        return Poll::Ready(Some(Err(FlightError::protocol(
176                            "Unexpectedly saw multiple Schema messages in FlightData stream",
177                        ))));
178                    }
179                    DecodedPayload::Schema(_) => {
180                        // Need next message, poll inner again
181                    }
182                    DecodedPayload::RecordBatch(batch) => {
183                        return Poll::Ready(Some(Ok(batch)));
184                    }
185                    DecodedPayload::None => {
186                        // Need next message
187                    }
188                },
189            }
190        }
191    }
192}
193
/// Wrapper around a stream of [`FlightData`] that handles the details
/// of decoding low level Flight messages into [`Schema`] and
/// [`RecordBatch`]es, including details such as dictionaries.
///
/// # Protocol Details
///
/// The client handles flight messages as follows:
///
/// - **None:** This message has no effect. This is useful to
///   transmit metadata without any actual payload.
///
/// - **Schema:** The schema is (re-)set. Dictionaries are cleared and
///   the decoded schema is returned.
///
/// - **Dictionary Batch:** A new dictionary for a given column is registered. An existing
///   dictionary for the same column will be overwritten. This
///   message is NOT visible.
///
/// - **Record Batch:** Record batch is created based on the current
///   schema and dictionaries. This fails if no schema was transmitted
///   yet.
///
/// All other message types (at the time of writing: e.g. tensor and
/// sparse tensor) lead to an error.
///
/// # Example use cases
///
/// 1. Using this low level stream it is possible to receive a stream
///    of RecordBatches in FlightData that have different schemas by
///    handling multiple schema messages separately.
pub struct FlightDataDecoder {
    /// Underlying data stream
    response: BoxStream<'static, Result<FlightData>>,
    /// Decoding state (current schema and dictionaries); `None` until
    /// the first Schema message has been decoded
    state: Option<FlightStreamState>,
    /// Set once the inner stream has returned `None`; later polls then
    /// return `None` without polling `response` again
    done: bool,
}
232
233impl Debug for FlightDataDecoder {
234    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235        f.debug_struct("FlightDataDecoder")
236            .field("response", &"<stream>")
237            .field("state", &self.state)
238            .field("done", &self.done)
239            .finish()
240    }
241}
242
243impl FlightDataDecoder {
244    /// Create a new wrapper around the stream of [`FlightData`]
245    pub fn new<S>(response: S) -> Self
246    where
247        S: Stream<Item = Result<FlightData>> + Send + 'static,
248    {
249        Self {
250            state: None,
251            response: response.boxed(),
252            done: false,
253        }
254    }
255
256    /// Returns the current schema for this stream
257    pub fn schema(&self) -> Option<&SchemaRef> {
258        self.state.as_ref().map(|state| &state.schema)
259    }
260
261    /// Extracts flight data from the next message, updating decoding
262    /// state as necessary.
263    fn extract_message(&mut self, data: FlightData) -> Result<Option<DecodedFlightData>> {
264        use arrow_ipc::MessageHeader;
265        let message = arrow_ipc::root_as_message(&data.data_header[..])
266            .map_err(|e| FlightError::DecodeError(format!("Error decoding root message: {e}")))?;
267
268        match message.header_type() {
269            MessageHeader::NONE => Ok(Some(DecodedFlightData::new_none(data))),
270            MessageHeader::Schema => {
271                let schema = Schema::try_from(&data)
272                    .map_err(|e| FlightError::DecodeError(format!("Error decoding schema: {e}")))?;
273
274                let schema = Arc::new(schema);
275                let dictionaries_by_field = HashMap::new();
276
277                self.state = Some(FlightStreamState {
278                    schema: Arc::clone(&schema),
279                    dictionaries_by_field,
280                });
281                Ok(Some(DecodedFlightData::new_schema(data, schema)))
282            }
283            MessageHeader::DictionaryBatch => {
284                let state = if let Some(state) = self.state.as_mut() {
285                    state
286                } else {
287                    return Err(FlightError::protocol(
288                        "Received DictionaryBatch prior to Schema",
289                    ));
290                };
291
292                let buffer = Buffer::from(data.data_body);
293                let dictionary_batch = message.header_as_dictionary_batch().ok_or_else(|| {
294                    FlightError::protocol(
295                        "Could not get dictionary batch from DictionaryBatch message",
296                    )
297                })?;
298
299                arrow_ipc::reader::read_dictionary(
300                    &buffer,
301                    dictionary_batch,
302                    &state.schema,
303                    &mut state.dictionaries_by_field,
304                    &message.version(),
305                )
306                .map_err(|e| {
307                    FlightError::DecodeError(format!("Error decoding ipc dictionary: {e}"))
308                })?;
309
310                // Updated internal state, but no decoded message
311                Ok(None)
312            }
313            MessageHeader::RecordBatch => {
314                let state = if let Some(state) = self.state.as_ref() {
315                    state
316                } else {
317                    return Err(FlightError::protocol(
318                        "Received RecordBatch prior to Schema",
319                    ));
320                };
321
322                let batch = flight_data_to_arrow_batch(
323                    &data,
324                    Arc::clone(&state.schema),
325                    &state.dictionaries_by_field,
326                )
327                .map_err(|e| {
328                    FlightError::DecodeError(format!("Error decoding ipc RecordBatch: {e}"))
329                })?;
330
331                Ok(Some(DecodedFlightData::new_record_batch(data, batch)))
332            }
333            other => {
334                let name = other.variant_name().unwrap_or("UNKNOWN");
335                Err(FlightError::protocol(format!("Unexpected message: {name}")))
336            }
337        }
338    }
339}
340
341impl futures::Stream for FlightDataDecoder {
342    type Item = Result<DecodedFlightData>;
343    /// Returns the result of decoding the next [`FlightData`] message
344    /// from the server, or `None` if there are no further results
345    /// available.
346    fn poll_next(
347        mut self: Pin<&mut Self>,
348        cx: &mut std::task::Context<'_>,
349    ) -> Poll<Option<Self::Item>> {
350        if self.done {
351            return Poll::Ready(None);
352        }
353        loop {
354            let res = ready!(self.response.poll_next_unpin(cx));
355
356            return Poll::Ready(match res {
357                None => {
358                    self.done = true;
359                    None // inner is exhausted
360                }
361                Some(data) => Some(match data {
362                    Err(e) => Err(e),
363                    Ok(data) => match self.extract_message(data) {
364                        Ok(Some(extracted)) => Ok(extracted),
365                        Ok(None) => continue, // Need next input message
366                        Err(e) => Err(e),
367                    },
368                }),
369            });
370        }
371    }
372}
373
/// Tracks the state needed to reconstruct [`RecordBatch`]es from a
/// streaming flight response.
///
/// A fresh instance (with no dictionaries) is created every time a
/// Schema message is decoded.
#[derive(Debug)]
struct FlightStreamState {
    /// Schema from the most recent Schema message
    schema: SchemaRef,
    /// Dictionaries registered by DictionaryBatch messages since the last
    /// Schema message.
    // NOTE(review): the i64 key is whatever `arrow_ipc::reader::read_dictionary`
    // inserts -- presumably the IPC dictionary id rather than a field index;
    // confirm against arrow_ipc.
    dictionaries_by_field: HashMap<i64, ArrayRef>,
}
381
/// FlightData and the decoded payload (Schema, RecordBatch), if any
///
/// The original message is kept alongside the payload so that fields
/// such as `app_metadata` remain accessible after decoding.
#[derive(Debug)]
pub struct DecodedFlightData {
    /// The original FlightData message
    pub inner: FlightData,
    /// The decoded payload ([`DecodedPayload::None`] if the message
    /// carried no payload)
    pub payload: DecodedPayload,
}
390
391impl DecodedFlightData {
392    /// Create a new DecodedFlightData with no payload
393    pub fn new_none(inner: FlightData) -> Self {
394        Self {
395            inner,
396            payload: DecodedPayload::None,
397        }
398    }
399
400    /// Create a new DecodedFlightData with a [`Schema`] payload
401    pub fn new_schema(inner: FlightData, schema: SchemaRef) -> Self {
402        Self {
403            inner,
404            payload: DecodedPayload::Schema(schema),
405        }
406    }
407
408    /// Create a new [`DecodedFlightData`] with a [`RecordBatch`] payload
409    pub fn new_record_batch(inner: FlightData, batch: RecordBatch) -> Self {
410        Self {
411            inner,
412            payload: DecodedPayload::RecordBatch(batch),
413        }
414    }
415
416    /// Return the metadata field of the inner flight data
417    pub fn app_metadata(&self) -> Bytes {
418        self.inner.app_metadata.clone()
419    }
420}
421
/// The result of decoding [`FlightData`]
///
/// Dictionary batches have no variant here: the decoder consumes them
/// internally and never yields them.
#[derive(Debug)]
pub enum DecodedPayload {
    /// None (no data was sent in the corresponding FlightData)
    None,

    /// A decoded Schema message
    Schema(SchemaRef),

    /// A decoded Record batch.
    RecordBatch(RecordBatch),
}