arrow_flight/
streams.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`FallibleRequestStream`] and [`FallibleTonicResponseStream`] adapters
19
20use crate::error::FlightError;
21use futures::{
22    channel::oneshot::{Receiver, Sender},
23    FutureExt, Stream, StreamExt,
24};
25use std::pin::Pin;
26use std::task::{ready, Poll};
27
28/// Wrapper around a fallible stream (one that returns errors) that makes it infallible.
29///
30/// Any errors encountered in the stream are ignored are sent to the provided
31/// oneshot sender.
32///
33/// This can be used to accept a stream of `Result<_>` from a client API and send
34/// them to the remote server that wants only the successful results.
35pub struct FallibleRequestStream<T, E> {
36    /// sender to notify error
37    sender: Option<Sender<E>>,
38    /// fallible stream
39    fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> + Send + 'static>>,
40}
41
42impl<T, E> FallibleRequestStream<T, E> {
43    /// Create a FallibleRequestStream
44    pub fn new(
45        sender: Sender<E>,
46        fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> + Send + 'static>>,
47    ) -> Self {
48        Self {
49            sender: Some(sender),
50            fallible_stream,
51        }
52    }
53}
54
55impl<T, E> Stream for FallibleRequestStream<T, E> {
56    type Item = T;
57
58    fn poll_next(
59        self: std::pin::Pin<&mut Self>,
60        cx: &mut std::task::Context<'_>,
61    ) -> std::task::Poll<Option<Self::Item>> {
62        let pinned = self.get_mut();
63        let mut request_streams = pinned.fallible_stream.as_mut();
64        match ready!(request_streams.poll_next_unpin(cx)) {
65            Some(Ok(data)) => Poll::Ready(Some(data)),
66            Some(Err(e)) => {
67                // in theory this should only ever be called once
68                // as this stream should not be polled again after returning
69                // None, however we still check for None to be safe
70                if let Some(sender) = pinned.sender.take() {
71                    // an error means the other end of the channel is not around
72                    // to receive the error, so ignore it
73                    let _ = sender.send(e);
74                }
75                Poll::Ready(None)
76            }
77            None => Poll::Ready(None),
78        }
79    }
80}
81
82/// Wrapper for a tonic response stream that maps errors to `FlightError` and
83/// returns errors from a oneshot channel into the stream.
84///
85/// The user of this stream can inject an error into the response stream using
86/// the one shot receiver. This is used to propagate errors in
87/// [`FlightClient::do_put`] and [`FlightClient::do_exchange`] from the client
88/// provided input stream to the response stream.
89///
90/// # Error Priority
91/// Error from the receiver are prioritised over the response stream.
92///
93/// [`FlightClient::do_put`]: crate::FlightClient::do_put
94/// [`FlightClient::do_exchange`]: crate::FlightClient::do_exchange
95pub(crate) struct FallibleTonicResponseStream<T> {
96    /// Receiver for FlightError
97    receiver: Receiver<FlightError>,
98    /// Tonic response stream
99    response_stream: Pin<Box<dyn Stream<Item = Result<T, tonic::Status>> + Send + 'static>>,
100}
101
102impl<T> FallibleTonicResponseStream<T> {
103    pub(crate) fn new(
104        receiver: Receiver<FlightError>,
105        response_stream: Pin<Box<dyn Stream<Item = Result<T, tonic::Status>> + Send + 'static>>,
106    ) -> Self {
107        Self {
108            receiver,
109            response_stream,
110        }
111    }
112}
113
114impl<T> Stream for FallibleTonicResponseStream<T> {
115    type Item = Result<T, FlightError>;
116
117    fn poll_next(
118        self: Pin<&mut Self>,
119        cx: &mut std::task::Context<'_>,
120    ) -> Poll<Option<Self::Item>> {
121        let pinned = self.get_mut();
122        let receiver = &mut pinned.receiver;
123        // Prioritise sending the error that's been notified over
124        // polling the response_stream
125        if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
126            return Poll::Ready(Some(Err(err)));
127        };
128
129        match ready!(pinned.response_stream.poll_next_unpin(cx)) {
130            Some(Ok(res)) => Poll::Ready(Some(Ok(res))),
131            Some(Err(status)) => Poll::Ready(Some(Err(status.into()))),
132            None => Poll::Ready(None),
133        }
134    }
135}