arrow_flight/streams.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`FallibleRequestStream`] and [`FallibleTonicResponseStream`] adapters
19
20use crate::error::FlightError;
21use futures::{
22 channel::oneshot::{Receiver, Sender},
23 FutureExt, Stream, StreamExt,
24};
25use std::pin::Pin;
26use std::task::{ready, Poll};
27
28/// Wrapper around a fallible stream (one that returns errors) that makes it infallible.
29///
30/// Any errors encountered in the stream are ignored are sent to the provided
31/// oneshot sender.
32///
33/// This can be used to accept a stream of `Result<_>` from a client API and send
34/// them to the remote server that wants only the successful results.
35pub struct FallibleRequestStream<T, E> {
36 /// sender to notify error
37 sender: Option<Sender<E>>,
38 /// fallible stream
39 fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> + Send + 'static>>,
40}
41
42impl<T, E> FallibleRequestStream<T, E> {
43 /// Create a FallibleRequestStream
44 pub fn new(
45 sender: Sender<E>,
46 fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> + Send + 'static>>,
47 ) -> Self {
48 Self {
49 sender: Some(sender),
50 fallible_stream,
51 }
52 }
53}
54
55impl<T, E> Stream for FallibleRequestStream<T, E> {
56 type Item = T;
57
58 fn poll_next(
59 self: std::pin::Pin<&mut Self>,
60 cx: &mut std::task::Context<'_>,
61 ) -> std::task::Poll<Option<Self::Item>> {
62 let pinned = self.get_mut();
63 let mut request_streams = pinned.fallible_stream.as_mut();
64 match ready!(request_streams.poll_next_unpin(cx)) {
65 Some(Ok(data)) => Poll::Ready(Some(data)),
66 Some(Err(e)) => {
67 // in theory this should only ever be called once
68 // as this stream should not be polled again after returning
69 // None, however we still check for None to be safe
70 if let Some(sender) = pinned.sender.take() {
71 // an error means the other end of the channel is not around
72 // to receive the error, so ignore it
73 let _ = sender.send(e);
74 }
75 Poll::Ready(None)
76 }
77 None => Poll::Ready(None),
78 }
79 }
80}
81
82/// Wrapper for a tonic response stream that maps errors to `FlightError` and
83/// returns errors from a oneshot channel into the stream.
84///
85/// The user of this stream can inject an error into the response stream using
86/// the one shot receiver. This is used to propagate errors in
87/// [`FlightClient::do_put`] and [`FlightClient::do_exchange`] from the client
88/// provided input stream to the response stream.
89///
90/// # Error Priority
91/// Error from the receiver are prioritised over the response stream.
92///
93/// [`FlightClient::do_put`]: crate::FlightClient::do_put
94/// [`FlightClient::do_exchange`]: crate::FlightClient::do_exchange
95pub(crate) struct FallibleTonicResponseStream<T> {
96 /// Receiver for FlightError
97 receiver: Receiver<FlightError>,
98 /// Tonic response stream
99 response_stream: Pin<Box<dyn Stream<Item = Result<T, tonic::Status>> + Send + 'static>>,
100}
101
102impl<T> FallibleTonicResponseStream<T> {
103 pub(crate) fn new(
104 receiver: Receiver<FlightError>,
105 response_stream: Pin<Box<dyn Stream<Item = Result<T, tonic::Status>> + Send + 'static>>,
106 ) -> Self {
107 Self {
108 receiver,
109 response_stream,
110 }
111 }
112}
113
114impl<T> Stream for FallibleTonicResponseStream<T> {
115 type Item = Result<T, FlightError>;
116
117 fn poll_next(
118 self: Pin<&mut Self>,
119 cx: &mut std::task::Context<'_>,
120 ) -> Poll<Option<Self::Item>> {
121 let pinned = self.get_mut();
122 let receiver = &mut pinned.receiver;
123 // Prioritise sending the error that's been notified over
124 // polling the response_stream
125 if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
126 return Poll::Ready(Some(Err(err)));
127 };
128
129 match ready!(pinned.response_stream.poll_next_unpin(cx)) {
130 Some(Ok(res)) => Poll::Ready(Some(Ok(res))),
131 Some(Err(status)) => Poll::Ready(Some(Err(status.into()))),
132 None => Poll::Ready(None),
133 }
134 }
135}