arrow_flight/streams.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`FallibleRequestStream`] and [`FallibleTonicResponseStream`] adapters
19
20use crate::error::FlightError;
21use futures::{
22 channel::oneshot::{Receiver, Sender},
23 FutureExt, Stream, StreamExt,
24};
25use std::pin::Pin;
26use std::task::{ready, Poll};
27
28/// Wrapper around a fallible stream (one that returns errors) that makes it infallible.
29///
30/// Any errors encountered in the stream are ignored are sent to the provided
31/// oneshot sender.
32///
33/// This can be used to accept a stream of `Result<_>` from a client API and send
34/// them to the remote server that wants only the successful results.
35pub(crate) struct FallibleRequestStream<T, E> {
36 /// sender to notify error
37 sender: Option<Sender<E>>,
38 /// fallible stream
39 fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> + Send + 'static>>,
40}
41
42impl<T, E> FallibleRequestStream<T, E> {
43 pub(crate) fn new(
44 sender: Sender<E>,
45 fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> + Send + 'static>>,
46 ) -> Self {
47 Self {
48 sender: Some(sender),
49 fallible_stream,
50 }
51 }
52}
53
54impl<T, E> Stream for FallibleRequestStream<T, E> {
55 type Item = T;
56
57 fn poll_next(
58 self: std::pin::Pin<&mut Self>,
59 cx: &mut std::task::Context<'_>,
60 ) -> std::task::Poll<Option<Self::Item>> {
61 let pinned = self.get_mut();
62 let mut request_streams = pinned.fallible_stream.as_mut();
63 match ready!(request_streams.poll_next_unpin(cx)) {
64 Some(Ok(data)) => Poll::Ready(Some(data)),
65 Some(Err(e)) => {
66 // in theory this should only ever be called once
67 // as this stream should not be polled again after returning
68 // None, however we still check for None to be safe
69 if let Some(sender) = pinned.sender.take() {
70 // an error means the other end of the channel is not around
71 // to receive the error, so ignore it
72 let _ = sender.send(e);
73 }
74 Poll::Ready(None)
75 }
76 None => Poll::Ready(None),
77 }
78 }
79}
80
81/// Wrapper for a tonic response stream that maps errors to `FlightError` and
82/// returns errors from a oneshot channel into the stream.
83///
84/// The user of this stream can inject an error into the response stream using
85/// the one shot receiver. This is used to propagate errors in
86/// [`FlightClient::do_put`] and [`FlightClient::do_exchange`] from the client
87/// provided input stream to the response stream.
88///
89/// # Error Priority
90/// Error from the receiver are prioritised over the response stream.
91///
92/// [`FlightClient::do_put`]: crate::FlightClient::do_put
93/// [`FlightClient::do_exchange`]: crate::FlightClient::do_exchange
94pub(crate) struct FallibleTonicResponseStream<T> {
95 /// Receiver for FlightError
96 receiver: Receiver<FlightError>,
97 /// Tonic response stream
98 response_stream: Pin<Box<dyn Stream<Item = Result<T, tonic::Status>> + Send + 'static>>,
99}
100
101impl<T> FallibleTonicResponseStream<T> {
102 pub(crate) fn new(
103 receiver: Receiver<FlightError>,
104 response_stream: Pin<Box<dyn Stream<Item = Result<T, tonic::Status>> + Send + 'static>>,
105 ) -> Self {
106 Self {
107 receiver,
108 response_stream,
109 }
110 }
111}
112
113impl<T> Stream for FallibleTonicResponseStream<T> {
114 type Item = Result<T, FlightError>;
115
116 fn poll_next(
117 self: Pin<&mut Self>,
118 cx: &mut std::task::Context<'_>,
119 ) -> Poll<Option<Self::Item>> {
120 let pinned = self.get_mut();
121 let receiver = &mut pinned.receiver;
122 // Prioritise sending the error that's been notified over
123 // polling the response_stream
124 if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
125 return Poll::Ready(Some(Err(err)));
126 };
127
128 match ready!(pinned.response_stream.poll_next_unpin(cx)) {
129 Some(Ok(res)) => Poll::Ready(Some(Ok(res))),
130 Some(Err(status)) => Poll::Ready(Some(Err(FlightError::Tonic(status)))),
131 None => Poll::Ready(None),
132 }
133 }
134}