arrow_flight/decode.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::{trailers::LazyTrailers, utils::flight_data_to_arrow_batch, FlightData};
19use arrow_array::{ArrayRef, RecordBatch};
20use arrow_buffer::Buffer;
21use arrow_schema::{Schema, SchemaRef};
22use bytes::Bytes;
23use futures::{ready, stream::BoxStream, Stream, StreamExt};
24use std::{collections::HashMap, fmt::Debug, pin::Pin, sync::Arc, task::Poll};
25use tonic::metadata::MetadataMap;
26
27use crate::error::{FlightError, Result};
28
29/// Decodes a [Stream] of [`FlightData`] back into
30/// [`RecordBatch`]es. This can be used to decode the response from an
31/// Arrow Flight server
32///
33/// # Note
34/// To access the lower level Flight messages (e.g. to access
35/// [`FlightData::app_metadata`]), you can call [`Self::into_inner`]
36/// and use the [`FlightDataDecoder`] directly.
37///
38/// # Example:
39/// ```no_run
40/// # async fn f() -> Result<(), arrow_flight::error::FlightError>{
41/// # use bytes::Bytes;
42/// // make a do_get request
43/// use arrow_flight::{
44/// error::Result,
45/// decode::FlightRecordBatchStream,
46/// Ticket,
47/// flight_service_client::FlightServiceClient
48/// };
49/// use tonic::transport::Channel;
50/// use futures::stream::{StreamExt, TryStreamExt};
51///
52/// let client: FlightServiceClient<Channel> = // make client..
53/// # unimplemented!();
54///
55/// let request = tonic::Request::new(
56/// Ticket { ticket: Bytes::new() }
57/// );
58///
59/// // Get a stream of FlightData;
60/// let flight_data_stream = client
61/// .do_get(request)
62/// .await?
63/// .into_inner();
64///
65/// // Decode stream of FlightData to RecordBatches
66/// let record_batch_stream = FlightRecordBatchStream::new_from_flight_data(
67/// // convert tonic::Status to FlightError
68/// flight_data_stream.map_err(|e| e.into())
69/// );
70///
71/// // Read back RecordBatches
72/// while let Some(batch) = record_batch_stream.next().await {
73/// match batch {
74/// Ok(batch) => { /* process batch */ },
75/// Err(e) => { /* handle error */ },
76/// };
77/// }
78///
79/// # Ok(())
80/// # }
81/// ```
82#[derive(Debug)]
83pub struct FlightRecordBatchStream {
84 /// Optional grpc header metadata.
85 headers: MetadataMap,
86
87 /// Optional grpc trailer metadata.
88 trailers: Option<LazyTrailers>,
89
90 inner: FlightDataDecoder,
91}
92
93impl FlightRecordBatchStream {
94 /// Create a new [`FlightRecordBatchStream`] from a decoded stream
95 pub fn new(inner: FlightDataDecoder) -> Self {
96 Self {
97 inner,
98 headers: MetadataMap::default(),
99 trailers: None,
100 }
101 }
102
103 /// Create a new [`FlightRecordBatchStream`] from a stream of [`FlightData`]
104 pub fn new_from_flight_data<S>(inner: S) -> Self
105 where
106 S: Stream<Item = Result<FlightData>> + Send + 'static,
107 {
108 Self {
109 inner: FlightDataDecoder::new(inner),
110 headers: MetadataMap::default(),
111 trailers: None,
112 }
113 }
114
115 /// Record response headers.
116 pub fn with_headers(self, headers: MetadataMap) -> Self {
117 Self { headers, ..self }
118 }
119
120 /// Record response trailers.
121 pub fn with_trailers(self, trailers: LazyTrailers) -> Self {
122 Self {
123 trailers: Some(trailers),
124 ..self
125 }
126 }
127
128 /// Headers attached to this stream.
129 pub fn headers(&self) -> &MetadataMap {
130 &self.headers
131 }
132
133 /// Trailers attached to this stream.
134 ///
135 /// Note that this will return `None` until the entire stream is consumed.
136 /// Only after calling `next()` returns `None`, might any available trailers be returned.
137 pub fn trailers(&self) -> Option<MetadataMap> {
138 self.trailers.as_ref().and_then(|trailers| trailers.get())
139 }
140
141 /// Has a message defining the schema been received yet?
142 #[deprecated = "use schema().is_some() instead"]
143 pub fn got_schema(&self) -> bool {
144 self.schema().is_some()
145 }
146
147 /// Return schema for the stream, if it has been received
148 pub fn schema(&self) -> Option<&SchemaRef> {
149 self.inner.schema()
150 }
151
152 /// Consume self and return the wrapped [`FlightDataDecoder`]
153 pub fn into_inner(self) -> FlightDataDecoder {
154 self.inner
155 }
156}
157
158impl futures::Stream for FlightRecordBatchStream {
159 type Item = Result<RecordBatch>;
160
161 /// Returns the next [`RecordBatch`] available in this stream, or `None` if
162 /// there are no further results available.
163 fn poll_next(
164 mut self: Pin<&mut Self>,
165 cx: &mut std::task::Context<'_>,
166 ) -> Poll<Option<Result<RecordBatch>>> {
167 loop {
168 let had_schema = self.schema().is_some();
169 let res = ready!(self.inner.poll_next_unpin(cx));
170 match res {
171 // Inner exhausted
172 None => {
173 return Poll::Ready(None);
174 }
175 Some(Err(e)) => {
176 return Poll::Ready(Some(Err(e)));
177 }
178 // translate data
179 Some(Ok(data)) => match data.payload {
180 DecodedPayload::Schema(_) if had_schema => {
181 return Poll::Ready(Some(Err(FlightError::protocol(
182 "Unexpectedly saw multiple Schema messages in FlightData stream",
183 ))));
184 }
185 DecodedPayload::Schema(_) => {
186 // Need next message, poll inner again
187 }
188 DecodedPayload::RecordBatch(batch) => {
189 return Poll::Ready(Some(Ok(batch)));
190 }
191 DecodedPayload::None => {
192 // Need next message
193 }
194 },
195 }
196 }
197 }
198}
199
200/// Wrapper around a stream of [`FlightData`] that handles the details
201/// of decoding low level Flight messages into [`Schema`] and
202/// [`RecordBatch`]es, including details such as dictionaries.
203///
204/// # Protocol Details
205///
206/// The client handles flight messages as followes:
207///
208/// - **None:** This message has no effect. This is useful to
209/// transmit metadata without any actual payload.
210///
211/// - **Schema:** The schema is (re-)set. Dictionaries are cleared and
212/// the decoded schema is returned.
213///
214/// - **Dictionary Batch:** A new dictionary for a given column is registered. An existing
215/// dictionary for the same column will be overwritten. This
216/// message is NOT visible.
217///
218/// - **Record Batch:** Record batch is created based on the current
219/// schema and dictionaries. This fails if no schema was transmitted
220/// yet.
221///
222/// All other message types (at the time of writing: e.g. tensor and
223/// sparse tensor) lead to an error.
224///
225/// Example usecases
226///
227/// 1. Using this low level stream it is possible to receive a steam
228/// of RecordBatches in FlightData that have different schemas by
229/// handling multiple schema messages separately.
230pub struct FlightDataDecoder {
231 /// Underlying data stream
232 response: BoxStream<'static, Result<FlightData>>,
233 /// Decoding state
234 state: Option<FlightStreamState>,
235 /// Seen the end of the inner stream?
236 done: bool,
237}
238
239impl Debug for FlightDataDecoder {
240 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241 f.debug_struct("FlightDataDecoder")
242 .field("response", &"<stream>")
243 .field("state", &self.state)
244 .field("done", &self.done)
245 .finish()
246 }
247}
248
249impl FlightDataDecoder {
250 /// Create a new wrapper around the stream of [`FlightData`]
251 pub fn new<S>(response: S) -> Self
252 where
253 S: Stream<Item = Result<FlightData>> + Send + 'static,
254 {
255 Self {
256 state: None,
257 response: response.boxed(),
258 done: false,
259 }
260 }
261
262 /// Returns the current schema for this stream
263 pub fn schema(&self) -> Option<&SchemaRef> {
264 self.state.as_ref().map(|state| &state.schema)
265 }
266
267 /// Extracts flight data from the next message, updating decoding
268 /// state as necessary.
269 fn extract_message(&mut self, data: FlightData) -> Result<Option<DecodedFlightData>> {
270 use arrow_ipc::MessageHeader;
271 let message = arrow_ipc::root_as_message(&data.data_header[..])
272 .map_err(|e| FlightError::DecodeError(format!("Error decoding root message: {e}")))?;
273
274 match message.header_type() {
275 MessageHeader::NONE => Ok(Some(DecodedFlightData::new_none(data))),
276 MessageHeader::Schema => {
277 let schema = Schema::try_from(&data)
278 .map_err(|e| FlightError::DecodeError(format!("Error decoding schema: {e}")))?;
279
280 let schema = Arc::new(schema);
281 let dictionaries_by_field = HashMap::new();
282
283 self.state = Some(FlightStreamState {
284 schema: Arc::clone(&schema),
285 dictionaries_by_field,
286 });
287 Ok(Some(DecodedFlightData::new_schema(data, schema)))
288 }
289 MessageHeader::DictionaryBatch => {
290 let state = if let Some(state) = self.state.as_mut() {
291 state
292 } else {
293 return Err(FlightError::protocol(
294 "Received DictionaryBatch prior to Schema",
295 ));
296 };
297
298 let buffer = Buffer::from(data.data_body);
299 let dictionary_batch = message.header_as_dictionary_batch().ok_or_else(|| {
300 FlightError::protocol(
301 "Could not get dictionary batch from DictionaryBatch message",
302 )
303 })?;
304
305 arrow_ipc::reader::read_dictionary(
306 &buffer,
307 dictionary_batch,
308 &state.schema,
309 &mut state.dictionaries_by_field,
310 &message.version(),
311 )
312 .map_err(|e| {
313 FlightError::DecodeError(format!("Error decoding ipc dictionary: {e}"))
314 })?;
315
316 // Updated internal state, but no decoded message
317 Ok(None)
318 }
319 MessageHeader::RecordBatch => {
320 let state = if let Some(state) = self.state.as_ref() {
321 state
322 } else {
323 return Err(FlightError::protocol(
324 "Received RecordBatch prior to Schema",
325 ));
326 };
327
328 let batch = flight_data_to_arrow_batch(
329 &data,
330 Arc::clone(&state.schema),
331 &state.dictionaries_by_field,
332 )
333 .map_err(|e| {
334 FlightError::DecodeError(format!("Error decoding ipc RecordBatch: {e}"))
335 })?;
336
337 Ok(Some(DecodedFlightData::new_record_batch(data, batch)))
338 }
339 other => {
340 let name = other.variant_name().unwrap_or("UNKNOWN");
341 Err(FlightError::protocol(format!("Unexpected message: {name}")))
342 }
343 }
344 }
345}
346
347impl futures::Stream for FlightDataDecoder {
348 type Item = Result<DecodedFlightData>;
349 /// Returns the result of decoding the next [`FlightData`] message
350 /// from the server, or `None` if there are no further results
351 /// available.
352 fn poll_next(
353 mut self: Pin<&mut Self>,
354 cx: &mut std::task::Context<'_>,
355 ) -> Poll<Option<Self::Item>> {
356 if self.done {
357 return Poll::Ready(None);
358 }
359 loop {
360 let res = ready!(self.response.poll_next_unpin(cx));
361
362 return Poll::Ready(match res {
363 None => {
364 self.done = true;
365 None // inner is exhausted
366 }
367 Some(data) => Some(match data {
368 Err(e) => Err(e),
369 Ok(data) => match self.extract_message(data) {
370 Ok(Some(extracted)) => Ok(extracted),
371 Ok(None) => continue, // Need next input message
372 Err(e) => Err(e),
373 },
374 }),
375 });
376 }
377 }
378}
379
380/// tracks the state needed to reconstruct [`RecordBatch`]es from a
381/// streaming flight response.
382#[derive(Debug)]
383struct FlightStreamState {
384 schema: SchemaRef,
385 dictionaries_by_field: HashMap<i64, ArrayRef>,
386}
387
388/// FlightData and the decoded payload (Schema, RecordBatch), if any
389#[derive(Debug)]
390pub struct DecodedFlightData {
391 /// The original FlightData message
392 pub inner: FlightData,
393 /// The decoded payload
394 pub payload: DecodedPayload,
395}
396
397impl DecodedFlightData {
398 /// Create a new DecodedFlightData with no payload
399 pub fn new_none(inner: FlightData) -> Self {
400 Self {
401 inner,
402 payload: DecodedPayload::None,
403 }
404 }
405
406 /// Create a new DecodedFlightData with a [`Schema`] payload
407 pub fn new_schema(inner: FlightData, schema: SchemaRef) -> Self {
408 Self {
409 inner,
410 payload: DecodedPayload::Schema(schema),
411 }
412 }
413
414 /// Create a new [`DecodedFlightData`] with a [`RecordBatch`] payload
415 pub fn new_record_batch(inner: FlightData, batch: RecordBatch) -> Self {
416 Self {
417 inner,
418 payload: DecodedPayload::RecordBatch(batch),
419 }
420 }
421
422 /// Return the metadata field of the inner flight data
423 pub fn app_metadata(&self) -> Bytes {
424 self.inner.app_metadata.clone()
425 }
426}
427
428/// The result of decoding [`FlightData`]
429#[derive(Debug)]
430pub enum DecodedPayload {
431 /// None (no data was sent in the corresponding FlightData)
432 None,
433
434 /// A decoded Schema message
435 Schema(SchemaRef),
436
437 /// A decoded Record batch.
438 RecordBatch(RecordBatch),
439}