arrow_flight/
streams.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`FallibleRequestStream`] and [`FallibleTonicResponseStream`] adapters
19
20use crate::error::FlightError;
21use futures::{
22    channel::oneshot::{Receiver, Sender},
23    FutureExt, Stream, StreamExt,
24};
25use std::pin::Pin;
26use std::task::{ready, Poll};
27
28/// Wrapper around a fallible stream (one that returns errors) that makes it infallible.
29///
30/// Any errors encountered in the stream are ignored are sent to the provided
31/// oneshot sender.
32///
33/// This can be used to accept a stream of `Result<_>` from a client API and send
34/// them to the remote server that wants only the successful results.
35pub(crate) struct FallibleRequestStream<T, E> {
36    /// sender to notify error
37    sender: Option<Sender<E>>,
38    /// fallible stream
39    fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> + Send + 'static>>,
40}
41
42impl<T, E> FallibleRequestStream<T, E> {
43    pub(crate) fn new(
44        sender: Sender<E>,
45        fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> + Send + 'static>>,
46    ) -> Self {
47        Self {
48            sender: Some(sender),
49            fallible_stream,
50        }
51    }
52}
53
54impl<T, E> Stream for FallibleRequestStream<T, E> {
55    type Item = T;
56
57    fn poll_next(
58        self: std::pin::Pin<&mut Self>,
59        cx: &mut std::task::Context<'_>,
60    ) -> std::task::Poll<Option<Self::Item>> {
61        let pinned = self.get_mut();
62        let mut request_streams = pinned.fallible_stream.as_mut();
63        match ready!(request_streams.poll_next_unpin(cx)) {
64            Some(Ok(data)) => Poll::Ready(Some(data)),
65            Some(Err(e)) => {
66                // in theory this should only ever be called once
67                // as this stream should not be polled again after returning
68                // None, however we still check for None to be safe
69                if let Some(sender) = pinned.sender.take() {
70                    // an error means the other end of the channel is not around
71                    // to receive the error, so ignore it
72                    let _ = sender.send(e);
73                }
74                Poll::Ready(None)
75            }
76            None => Poll::Ready(None),
77        }
78    }
79}
80
81/// Wrapper for a tonic response stream that maps errors to `FlightError` and
82/// returns errors from a oneshot channel into the stream.
83///
84/// The user of this stream can inject an error into the response stream using
85/// the one shot receiver. This is used to propagate errors in
86/// [`FlightClient::do_put`] and [`FlightClient::do_exchange`] from the client
87/// provided input stream to the response stream.
88///
89/// # Error Priority
90/// Error from the receiver are prioritised over the response stream.
91///
92/// [`FlightClient::do_put`]: crate::FlightClient::do_put
93/// [`FlightClient::do_exchange`]: crate::FlightClient::do_exchange
94pub(crate) struct FallibleTonicResponseStream<T> {
95    /// Receiver for FlightError
96    receiver: Receiver<FlightError>,
97    /// Tonic response stream
98    response_stream: Pin<Box<dyn Stream<Item = Result<T, tonic::Status>> + Send + 'static>>,
99}
100
101impl<T> FallibleTonicResponseStream<T> {
102    pub(crate) fn new(
103        receiver: Receiver<FlightError>,
104        response_stream: Pin<Box<dyn Stream<Item = Result<T, tonic::Status>> + Send + 'static>>,
105    ) -> Self {
106        Self {
107            receiver,
108            response_stream,
109        }
110    }
111}
112
113impl<T> Stream for FallibleTonicResponseStream<T> {
114    type Item = Result<T, FlightError>;
115
116    fn poll_next(
117        self: Pin<&mut Self>,
118        cx: &mut std::task::Context<'_>,
119    ) -> Poll<Option<Self::Item>> {
120        let pinned = self.get_mut();
121        let receiver = &mut pinned.receiver;
122        // Prioritise sending the error that's been notified over
123        // polling the response_stream
124        if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
125            return Poll::Ready(Some(Err(err)));
126        };
127
128        match ready!(pinned.response_stream.poll_next_unpin(cx)) {
129            Some(Ok(res)) => Poll::Ready(Some(Ok(res))),
130            Some(Err(status)) => Poll::Ready(Some(Err(FlightError::Tonic(status)))),
131            None => Poll::Ready(None),
132        }
133    }
134}