arrow_flight/decode.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::{trailers::LazyTrailers, utils::flight_data_to_arrow_batch, FlightData};
19use arrow_array::{ArrayRef, RecordBatch};
20use arrow_buffer::Buffer;
21use arrow_schema::{Schema, SchemaRef};
22use bytes::Bytes;
23use futures::{ready, stream::BoxStream, Stream, StreamExt};
24use std::{collections::HashMap, fmt::Debug, pin::Pin, sync::Arc, task::Poll};
25use tonic::metadata::MetadataMap;
26
27use crate::error::{FlightError, Result};
28
29/// Decodes a [Stream] of [`FlightData`] back into
30/// [`RecordBatch`]es. This can be used to decode the response from an
31/// Arrow Flight server
32///
33/// # Note
34/// To access the lower level Flight messages (e.g. to access
35/// [`FlightData::app_metadata`]), you can call [`Self::into_inner`]
36/// and use the [`FlightDataDecoder`] directly.
37///
38/// # Example:
39/// ```no_run
40/// # async fn f() -> Result<(), arrow_flight::error::FlightError>{
41/// # use bytes::Bytes;
42/// // make a do_get request
43/// use arrow_flight::{
44/// error::Result,
45/// decode::FlightRecordBatchStream,
46/// Ticket,
47/// flight_service_client::FlightServiceClient
48/// };
49/// use tonic::transport::Channel;
50/// use futures::stream::{StreamExt, TryStreamExt};
51///
/// let mut client: FlightServiceClient<Channel> = // make client..
53/// # unimplemented!();
54///
55/// let request = tonic::Request::new(
56/// Ticket { ticket: Bytes::new() }
57/// );
58///
59/// // Get a stream of FlightData;
60/// let flight_data_stream = client
61/// .do_get(request)
62/// .await?
63/// .into_inner();
64///
65/// // Decode stream of FlightData to RecordBatches
/// let mut record_batch_stream = FlightRecordBatchStream::new_from_flight_data(
67/// // convert tonic::Status to FlightError
68/// flight_data_stream.map_err(|e| e.into())
69/// );
70///
71/// // Read back RecordBatches
72/// while let Some(batch) = record_batch_stream.next().await {
73/// match batch {
74/// Ok(batch) => { /* process batch */ },
75/// Err(e) => { /* handle error */ },
76/// };
77/// }
78///
79/// # Ok(())
80/// # }
81/// ```
82#[derive(Debug)]
83pub struct FlightRecordBatchStream {
84 /// Optional grpc header metadata.
85 headers: MetadataMap,
86
87 /// Optional grpc trailer metadata.
88 trailers: Option<LazyTrailers>,
89
90 inner: FlightDataDecoder,
91}
92
93impl FlightRecordBatchStream {
94 /// Create a new [`FlightRecordBatchStream`] from a decoded stream
95 pub fn new(inner: FlightDataDecoder) -> Self {
96 Self {
97 inner,
98 headers: MetadataMap::default(),
99 trailers: None,
100 }
101 }
102
103 /// Create a new [`FlightRecordBatchStream`] from a stream of [`FlightData`]
104 pub fn new_from_flight_data<S>(inner: S) -> Self
105 where
106 S: Stream<Item = Result<FlightData>> + Send + 'static,
107 {
108 Self {
109 inner: FlightDataDecoder::new(inner),
110 headers: MetadataMap::default(),
111 trailers: None,
112 }
113 }
114
115 /// Record response headers.
116 pub fn with_headers(self, headers: MetadataMap) -> Self {
117 Self { headers, ..self }
118 }
119
120 /// Record response trailers.
121 pub fn with_trailers(self, trailers: LazyTrailers) -> Self {
122 Self {
123 trailers: Some(trailers),
124 ..self
125 }
126 }
127
128 /// Headers attached to this stream.
129 pub fn headers(&self) -> &MetadataMap {
130 &self.headers
131 }
132
133 /// Trailers attached to this stream.
134 ///
135 /// Note that this will return `None` until the entire stream is consumed.
136 /// Only after calling `next()` returns `None`, might any available trailers be returned.
137 pub fn trailers(&self) -> Option<MetadataMap> {
138 self.trailers.as_ref().and_then(|trailers| trailers.get())
139 }
140
141 /// Return schema for the stream, if it has been received
142 pub fn schema(&self) -> Option<&SchemaRef> {
143 self.inner.schema()
144 }
145
146 /// Consume self and return the wrapped [`FlightDataDecoder`]
147 pub fn into_inner(self) -> FlightDataDecoder {
148 self.inner
149 }
150}
151
152impl futures::Stream for FlightRecordBatchStream {
153 type Item = Result<RecordBatch>;
154
155 /// Returns the next [`RecordBatch`] available in this stream, or `None` if
156 /// there are no further results available.
157 fn poll_next(
158 mut self: Pin<&mut Self>,
159 cx: &mut std::task::Context<'_>,
160 ) -> Poll<Option<Result<RecordBatch>>> {
161 loop {
162 let had_schema = self.schema().is_some();
163 let res = ready!(self.inner.poll_next_unpin(cx));
164 match res {
165 // Inner exhausted
166 None => {
167 return Poll::Ready(None);
168 }
169 Some(Err(e)) => {
170 return Poll::Ready(Some(Err(e)));
171 }
172 // translate data
173 Some(Ok(data)) => match data.payload {
174 DecodedPayload::Schema(_) if had_schema => {
175 return Poll::Ready(Some(Err(FlightError::protocol(
176 "Unexpectedly saw multiple Schema messages in FlightData stream",
177 ))));
178 }
179 DecodedPayload::Schema(_) => {
180 // Need next message, poll inner again
181 }
182 DecodedPayload::RecordBatch(batch) => {
183 return Poll::Ready(Some(Ok(batch)));
184 }
185 DecodedPayload::None => {
186 // Need next message
187 }
188 },
189 }
190 }
191 }
192}
193
194/// Wrapper around a stream of [`FlightData`] that handles the details
195/// of decoding low level Flight messages into [`Schema`] and
196/// [`RecordBatch`]es, including details such as dictionaries.
197///
198/// # Protocol Details
199///
/// The client handles flight messages as follows:
201///
202/// - **None:** This message has no effect. This is useful to
203/// transmit metadata without any actual payload.
204///
205/// - **Schema:** The schema is (re-)set. Dictionaries are cleared and
206/// the decoded schema is returned.
207///
208/// - **Dictionary Batch:** A new dictionary for a given column is registered. An existing
209/// dictionary for the same column will be overwritten. This
210/// message is NOT visible.
211///
212/// - **Record Batch:** Record batch is created based on the current
213/// schema and dictionaries. This fails if no schema was transmitted
214/// yet.
215///
216/// All other message types (at the time of writing: e.g. tensor and
217/// sparse tensor) lead to an error.
218///
/// Example use cases
///
/// 1. Using this low level stream it is possible to receive a stream
///    of RecordBatches in FlightData that have different schemas by
///    handling multiple schema messages separately.
224pub struct FlightDataDecoder {
225 /// Underlying data stream
226 response: BoxStream<'static, Result<FlightData>>,
227 /// Decoding state
228 state: Option<FlightStreamState>,
229 /// Seen the end of the inner stream?
230 done: bool,
231}
232
233impl Debug for FlightDataDecoder {
234 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235 f.debug_struct("FlightDataDecoder")
236 .field("response", &"<stream>")
237 .field("state", &self.state)
238 .field("done", &self.done)
239 .finish()
240 }
241}
242
243impl FlightDataDecoder {
244 /// Create a new wrapper around the stream of [`FlightData`]
245 pub fn new<S>(response: S) -> Self
246 where
247 S: Stream<Item = Result<FlightData>> + Send + 'static,
248 {
249 Self {
250 state: None,
251 response: response.boxed(),
252 done: false,
253 }
254 }
255
256 /// Returns the current schema for this stream
257 pub fn schema(&self) -> Option<&SchemaRef> {
258 self.state.as_ref().map(|state| &state.schema)
259 }
260
261 /// Extracts flight data from the next message, updating decoding
262 /// state as necessary.
263 fn extract_message(&mut self, data: FlightData) -> Result<Option<DecodedFlightData>> {
264 use arrow_ipc::MessageHeader;
265 let message = arrow_ipc::root_as_message(&data.data_header[..])
266 .map_err(|e| FlightError::DecodeError(format!("Error decoding root message: {e}")))?;
267
268 match message.header_type() {
269 MessageHeader::NONE => Ok(Some(DecodedFlightData::new_none(data))),
270 MessageHeader::Schema => {
271 let schema = Schema::try_from(&data)
272 .map_err(|e| FlightError::DecodeError(format!("Error decoding schema: {e}")))?;
273
274 let schema = Arc::new(schema);
275 let dictionaries_by_field = HashMap::new();
276
277 self.state = Some(FlightStreamState {
278 schema: Arc::clone(&schema),
279 dictionaries_by_field,
280 });
281 Ok(Some(DecodedFlightData::new_schema(data, schema)))
282 }
283 MessageHeader::DictionaryBatch => {
284 let state = if let Some(state) = self.state.as_mut() {
285 state
286 } else {
287 return Err(FlightError::protocol(
288 "Received DictionaryBatch prior to Schema",
289 ));
290 };
291
292 let buffer = Buffer::from(data.data_body);
293 let dictionary_batch = message.header_as_dictionary_batch().ok_or_else(|| {
294 FlightError::protocol(
295 "Could not get dictionary batch from DictionaryBatch message",
296 )
297 })?;
298
299 arrow_ipc::reader::read_dictionary(
300 &buffer,
301 dictionary_batch,
302 &state.schema,
303 &mut state.dictionaries_by_field,
304 &message.version(),
305 )
306 .map_err(|e| {
307 FlightError::DecodeError(format!("Error decoding ipc dictionary: {e}"))
308 })?;
309
310 // Updated internal state, but no decoded message
311 Ok(None)
312 }
313 MessageHeader::RecordBatch => {
314 let state = if let Some(state) = self.state.as_ref() {
315 state
316 } else {
317 return Err(FlightError::protocol(
318 "Received RecordBatch prior to Schema",
319 ));
320 };
321
322 let batch = flight_data_to_arrow_batch(
323 &data,
324 Arc::clone(&state.schema),
325 &state.dictionaries_by_field,
326 )
327 .map_err(|e| {
328 FlightError::DecodeError(format!("Error decoding ipc RecordBatch: {e}"))
329 })?;
330
331 Ok(Some(DecodedFlightData::new_record_batch(data, batch)))
332 }
333 other => {
334 let name = other.variant_name().unwrap_or("UNKNOWN");
335 Err(FlightError::protocol(format!("Unexpected message: {name}")))
336 }
337 }
338 }
339}
340
341impl futures::Stream for FlightDataDecoder {
342 type Item = Result<DecodedFlightData>;
343 /// Returns the result of decoding the next [`FlightData`] message
344 /// from the server, or `None` if there are no further results
345 /// available.
346 fn poll_next(
347 mut self: Pin<&mut Self>,
348 cx: &mut std::task::Context<'_>,
349 ) -> Poll<Option<Self::Item>> {
350 if self.done {
351 return Poll::Ready(None);
352 }
353 loop {
354 let res = ready!(self.response.poll_next_unpin(cx));
355
356 return Poll::Ready(match res {
357 None => {
358 self.done = true;
359 None // inner is exhausted
360 }
361 Some(data) => Some(match data {
362 Err(e) => Err(e),
363 Ok(data) => match self.extract_message(data) {
364 Ok(Some(extracted)) => Ok(extracted),
365 Ok(None) => continue, // Need next input message
366 Err(e) => Err(e),
367 },
368 }),
369 });
370 }
371 }
372}
373
374/// tracks the state needed to reconstruct [`RecordBatch`]es from a
375/// streaming flight response.
376#[derive(Debug)]
377struct FlightStreamState {
378 schema: SchemaRef,
379 dictionaries_by_field: HashMap<i64, ArrayRef>,
380}
381
382/// FlightData and the decoded payload (Schema, RecordBatch), if any
383#[derive(Debug)]
384pub struct DecodedFlightData {
385 /// The original FlightData message
386 pub inner: FlightData,
387 /// The decoded payload
388 pub payload: DecodedPayload,
389}
390
391impl DecodedFlightData {
392 /// Create a new DecodedFlightData with no payload
393 pub fn new_none(inner: FlightData) -> Self {
394 Self {
395 inner,
396 payload: DecodedPayload::None,
397 }
398 }
399
400 /// Create a new DecodedFlightData with a [`Schema`] payload
401 pub fn new_schema(inner: FlightData, schema: SchemaRef) -> Self {
402 Self {
403 inner,
404 payload: DecodedPayload::Schema(schema),
405 }
406 }
407
408 /// Create a new [`DecodedFlightData`] with a [`RecordBatch`] payload
409 pub fn new_record_batch(inner: FlightData, batch: RecordBatch) -> Self {
410 Self {
411 inner,
412 payload: DecodedPayload::RecordBatch(batch),
413 }
414 }
415
416 /// Return the metadata field of the inner flight data
417 pub fn app_metadata(&self) -> Bytes {
418 self.inner.app_metadata.clone()
419 }
420}
421
422/// The result of decoding [`FlightData`]
423#[derive(Debug)]
424pub enum DecodedPayload {
425 /// None (no data was sent in the corresponding FlightData)
426 None,
427
428 /// A decoded Schema message
429 Schema(SchemaRef),
430
431 /// A decoded Record batch.
432 RecordBatch(RecordBatch),
433}