Skip to main content

arrow_flight/
client.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::{
19    Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo,
20    HandshakeRequest, PollInfo, PutResult, Ticket,
21    decode::FlightRecordBatchStream,
22    flight_service_client::FlightServiceClient,
23    r#gen::{CancelFlightInfoRequest, CancelFlightInfoResult, RenewFlightEndpointRequest},
24    trailers::extract_lazy_trailers,
25};
26use arrow_schema::Schema;
27use bytes::Bytes;
28use futures::{
29    Stream, StreamExt, TryStreamExt,
30    future::ready,
31    stream::{self, BoxStream},
32};
33use prost::Message;
34use tonic::codegen::{Body, StdError};
35use tonic::{metadata::MetadataMap, transport::Channel};
36
37use crate::error::{FlightError, Result};
38use crate::streams::{FallibleRequestStream, FallibleTonicResponseStream};
39
40/// A "Mid level" [Apache Arrow Flight](https://arrow.apache.org/docs/format/Flight.html) client.
41///
42/// [`FlightClient`] is intended as a convenience for interactions
43/// with Arrow Flight servers. For more direct control, such as access
44/// to the response headers, use  [`FlightServiceClient`] directly
45/// via methods such as [`Self::inner`] or [`Self::into_inner`].
46///
47/// # Example:
48/// ```no_run
49/// # async fn run() {
50/// # use arrow_flight::FlightClient;
51/// # use bytes::Bytes;
52/// use tonic::transport::Channel;
53/// let channel = Channel::from_static("http://localhost:1234")
54///   .connect()
55///   .await
56///   .expect("error connecting");
57///
58/// let mut client = FlightClient::new(channel);
59///
60/// // Send 'Hi' bytes as the handshake request to the server
61/// let response = client
62///   .handshake(Bytes::from("Hi"))
63///   .await
64///   .expect("error handshaking");
65///
66/// // Expect the server responded with 'Ho'
67/// assert_eq!(response, Bytes::from("Ho"));
68/// # }
69/// ```
70#[derive(Debug)]
71pub struct FlightClient<T = Channel> {
72    /// Optional grpc header metadata to include with each request
73    metadata: MetadataMap,
74
75    /// The inner client
76    inner: FlightServiceClient<T>,
77}
78
79impl<T> FlightClient<T>
80where
81    T: tonic::client::GrpcService<tonic::body::Body>,
82    T::Error: Into<StdError>,
83    T::ResponseBody: Body<Data = Bytes> + std::marker::Send + 'static,
84    <T::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
85{
86    /// Creates a client with the provided transport
87    pub fn new(inner: T) -> Self {
88        Self::new_from_inner(FlightServiceClient::new(inner))
89    }
90
91    /// Creates a new higher level client with the provided lower level client
92    pub fn new_from_inner(inner: FlightServiceClient<T>) -> Self {
93        Self {
94            metadata: MetadataMap::new(),
95            inner,
96        }
97    }
98
99    /// Return a reference to gRPC metadata included with each request
100    pub fn metadata(&self) -> &MetadataMap {
101        &self.metadata
102    }
103
104    /// Return a reference to gRPC metadata included with each request
105    ///
106    /// These headers can be used, for example, to include
107    /// authorization or other application specific headers.
108    pub fn metadata_mut(&mut self) -> &mut MetadataMap {
109        &mut self.metadata
110    }
111
112    /// Add the specified header with value to all subsequent
113    /// requests. See [`Self::metadata_mut`] for fine grained control.
114    pub fn add_header(&mut self, key: &str, value: &str) -> Result<()> {
115        let key = tonic::metadata::MetadataKey::<_>::from_bytes(key.as_bytes())
116            .map_err(|e| FlightError::ExternalError(Box::new(e)))?;
117
118        let value = value
119            .parse()
120            .map_err(|e| FlightError::ExternalError(Box::new(e)))?;
121
122        // ignore previous value
123        self.metadata.insert(key, value);
124
125        Ok(())
126    }
127
128    /// Return a reference to the underlying tonic
129    /// [`FlightServiceClient`]
130    pub fn inner(&self) -> &FlightServiceClient<T> {
131        &self.inner
132    }
133
134    /// Return a mutable reference to the underlying tonic
135    /// [`FlightServiceClient`]
136    pub fn inner_mut(&mut self) -> &mut FlightServiceClient<T> {
137        &mut self.inner
138    }
139
140    /// Consume this client and return the underlying tonic
141    /// [`FlightServiceClient`]
142    pub fn into_inner(self) -> FlightServiceClient<T> {
143        self.inner
144    }
145
146    /// Perform an Arrow Flight handshake with the server, sending
147    /// `payload` as the [`HandshakeRequest`] payload and returning
148    /// the [`HandshakeResponse`](crate::HandshakeResponse)
149    /// bytes returned from the server
150    ///
151    /// See [`FlightClient`] docs for an example.
152    pub async fn handshake(&mut self, payload: impl Into<Bytes>) -> Result<Bytes> {
153        let request = HandshakeRequest {
154            protocol_version: 0,
155            payload: payload.into(),
156        };
157
158        // apply headers, etc
159        let request = self.make_request(stream::once(ready(request)));
160
161        let mut response_stream = self.inner.handshake(request).await?.into_inner();
162
163        if let Some(response) = response_stream.next().await.transpose()? {
164            // check if there is another response
165            if response_stream.next().await.is_some() {
166                return Err(FlightError::protocol(
167                    "Got unexpected second response from handshake",
168                ));
169            }
170
171            Ok(response.payload)
172        } else {
173            Err(FlightError::protocol("No response from handshake"))
174        }
175    }
176
177    /// Make a `DoGet` call to the server with the provided ticket,
178    /// returning a [`FlightRecordBatchStream`] for reading
179    /// [`RecordBatch`](arrow_array::RecordBatch)es.
180    ///
181    /// # Note
182    ///
183    /// To access the returned [`FlightData`] use
184    /// [`FlightRecordBatchStream::into_inner()`]
185    ///
186    /// # Example:
187    /// ```no_run
188    /// # async fn run() {
189    /// # use bytes::Bytes;
190    /// # use arrow_flight::FlightClient;
191    /// # use arrow_flight::Ticket;
192    /// # use arrow_array::RecordBatch;
193    /// # use futures::stream::TryStreamExt;
194    /// # let channel: tonic::transport::Channel = unimplemented!();
195    /// # let ticket = Ticket { ticket: Bytes::from("foo") };
196    /// let mut client = FlightClient::new(channel);
197    ///
198    /// // Invoke a do_get request on the server with a previously
199    /// // received Ticket
200    ///
201    /// let response = client
202    ///    .do_get(ticket)
203    ///    .await
204    ///    .expect("error invoking do_get");
205    ///
206    /// // Use try_collect to get the RecordBatches from the server
207    /// let batches: Vec<RecordBatch> = response
208    ///    .try_collect()
209    ///    .await
210    ///    .expect("no stream errors");
211    /// # }
212    /// ```
213    pub async fn do_get(&mut self, ticket: Ticket) -> Result<FlightRecordBatchStream> {
214        let request = self.make_request(ticket);
215
216        let (md, response_stream, _ext) = self.inner.do_get(request).await?.into_parts();
217        let (response_stream, trailers) = extract_lazy_trailers(response_stream);
218
219        Ok(FlightRecordBatchStream::new_from_flight_data(
220            response_stream.map_err(|status| status.into()),
221        )
222        .with_headers(md)
223        .with_trailers(trailers))
224    }
225
226    /// Make a `GetFlightInfo` call to the server with the provided
227    /// [`FlightDescriptor`] and return the [`FlightInfo`] from the
228    /// server. The [`FlightInfo`] can be used with [`Self::do_get`]
229    /// to retrieve the requested batches.
230    ///
231    /// # Example:
232    /// ```no_run
233    /// # async fn run() {
234    /// # use arrow_flight::FlightClient;
235    /// # use arrow_flight::FlightDescriptor;
236    /// # let channel: tonic::transport::Channel = unimplemented!();
237    /// let mut client = FlightClient::new(channel);
238    ///
239    /// // Send a 'CMD' request to the server
240    /// let request = FlightDescriptor::new_cmd(b"MOAR DATA".to_vec());
241    /// let flight_info = client
242    ///   .get_flight_info(request)
243    ///   .await
244    ///   .expect("error handshaking");
245    ///
246    /// // retrieve the first endpoint from the returned flight info
247    /// let ticket = flight_info
248    ///   .endpoint[0]
249    ///   // Extract the ticket
250    ///   .ticket
251    ///   .clone()
252    ///   .expect("expected ticket");
253    ///
254    /// // Retrieve the corresponding RecordBatch stream with do_get
255    /// let data = client
256    ///   .do_get(ticket)
257    ///   .await
258    ///   .expect("error fetching data");
259    /// # }
260    /// ```
261    pub async fn get_flight_info(&mut self, descriptor: FlightDescriptor) -> Result<FlightInfo> {
262        let request = self.make_request(descriptor);
263
264        let response = self.inner.get_flight_info(request).await?.into_inner();
265        Ok(response)
266    }
267
268    /// Make a `PollFlightInfo` call to the server with the provided
269    /// [`FlightDescriptor`] and return the [`PollInfo`] from the
270    /// server.
271    ///
272    /// The `info` field of the [`PollInfo`] can be used with
273    /// [`Self::do_get`] to retrieve the requested batches.
274    ///
275    /// If the `flight_descriptor` field of the [`PollInfo`] is
276    /// `None` then the `info` field represents the complete results.
277    ///
278    /// If the `flight_descriptor` field is some [`FlightDescriptor`]
279    /// then the `info` field has incomplete results, and the client
280    /// should call this method again with the new `flight_descriptor`
281    /// to get the updated status.
282    ///
283    /// The `expiration_time`, if set, represents the expiration time
284    /// of the `flight_descriptor`, after which the server may not accept
285    /// this retry descriptor and may cancel the query.
286    ///
287    /// # Example:
288    /// ```no_run
289    /// # async fn run() {
290    /// # use arrow_flight::FlightClient;
291    /// # use arrow_flight::FlightDescriptor;
292    /// # let channel: tonic::transport::Channel = unimplemented!();
293    /// let mut client = FlightClient::new(channel);
294    ///
295    /// // Send a 'CMD' request to the server
296    /// let request = FlightDescriptor::new_cmd(b"MOAR DATA".to_vec());
297    /// let poll_info = client
298    ///   .poll_flight_info(request)
299    ///   .await
300    ///   .expect("error handshaking");
301    ///
302    /// // retrieve the first endpoint from the returned poll info
303    /// let ticket = poll_info
304    ///   .info
305    ///   .expect("expected flight info")
306    ///   .endpoint[0]
307    ///   // Extract the ticket
308    ///   .ticket
309    ///   .clone()
310    ///   .expect("expected ticket");
311    ///
312    /// // Retrieve the corresponding RecordBatch stream with do_get
313    /// let data = client
314    ///   .do_get(ticket)
315    ///   .await
316    ///   .expect("error fetching data");
317    /// # }
318    /// ```
319    pub async fn poll_flight_info(&mut self, descriptor: FlightDescriptor) -> Result<PollInfo> {
320        let request = self.make_request(descriptor);
321
322        let response = self.inner.poll_flight_info(request).await?.into_inner();
323        Ok(response)
324    }
325
326    /// Make a `DoPut` call to the server with the provided
327    /// [`Stream`] of [`FlightData`] and returning a
328    /// stream of [`PutResult`].
329    ///
330    /// # Note
331    ///
332    /// The input stream is [`Result`] so that this can be connected
333    /// to a streaming data source, such as [`FlightDataEncoder`](crate::encode::FlightDataEncoder),
334    /// without having to buffer. If the input stream returns an error
335    /// that error will not be sent to the server, instead it will be
336    /// placed into the result stream and the server connection
337    /// terminated.
338    ///
339    /// # Example:
340    /// ```no_run
341    /// # async fn run() {
342    /// # use futures::{TryStreamExt, StreamExt};
343    /// # use std::sync::Arc;
344    /// # use arrow_array::UInt64Array;
345    /// # use arrow_array::RecordBatch;
346    /// # use arrow_flight::{FlightClient, FlightDescriptor, PutResult};
347    /// # use arrow_flight::encode::FlightDataEncoderBuilder;
348    /// # let batch = RecordBatch::try_from_iter(vec![
349    /// #  ("col2", Arc::new(UInt64Array::from_iter([10, 23, 33])) as _)
350    /// # ]).unwrap();
351    /// # let channel: tonic::transport::Channel = unimplemented!();
352    /// let mut client = FlightClient::new(channel);
353    ///
354    /// // encode the batch as a stream of `FlightData`
355    /// let flight_data_stream = FlightDataEncoderBuilder::new()
356    ///   .build(futures::stream::iter(vec![Ok(batch)]));
357    ///
358    /// // send the stream and get the results as `PutResult`
359    /// let response: Vec<PutResult>= client
360    ///   .do_put(flight_data_stream)
361    ///   .await
362    ///   .unwrap()
363    ///   .try_collect() // use TryStreamExt to collect stream
364    ///   .await
365    ///   .expect("error calling do_put");
366    /// # }
367    /// ```
368    pub async fn do_put<S: Stream<Item = Result<FlightData>> + Send + 'static>(
369        &mut self,
370        request: S,
371    ) -> Result<BoxStream<'static, Result<PutResult>>> {
372        let (sender, receiver) = futures::channel::oneshot::channel();
373
374        // Intercepts client errors and sends them to the oneshot channel above
375        let request = Box::pin(request); // Pin to heap
376        let request_stream = FallibleRequestStream::new(sender, request);
377
378        let request = self.make_request(request_stream);
379        let response_stream = self.inner.do_put(request).await?.into_inner();
380
381        // Forwards errors from the error oneshot with priority over responses from server
382        let response_stream = Box::pin(response_stream);
383        let error_stream = FallibleTonicResponseStream::new(receiver, response_stream);
384
385        // combine the response from the server and any error from the client
386        Ok(error_stream.boxed())
387    }
388
389    /// Make a `DoExchange` call to the server with the provided
390    /// [`Stream`] of [`FlightData`] and returning a
391    /// stream of [`FlightData`].
392    ///
393    /// # Example:
394    /// ```no_run
395    /// # async fn run() {
396    /// # use futures::{TryStreamExt, StreamExt};
397    /// # use std::sync::Arc;
398    /// # use arrow_array::UInt64Array;
399    /// # use arrow_array::RecordBatch;
400    /// # use arrow_flight::{FlightClient, FlightDescriptor, PutResult};
401    /// # use arrow_flight::encode::FlightDataEncoderBuilder;
402    /// # let batch = RecordBatch::try_from_iter(vec![
403    /// #  ("col2", Arc::new(UInt64Array::from_iter([10, 23, 33])) as _)
404    /// # ]).unwrap();
405    /// # let channel: tonic::transport::Channel = unimplemented!();
406    /// let mut client = FlightClient::new(channel);
407    ///
408    /// // encode the batch as a stream of `FlightData`
409    /// let flight_data_stream = FlightDataEncoderBuilder::new()
410    ///   .build(futures::stream::iter(vec![Ok(batch)]));
411    ///
412    /// // send the stream and get the results as `RecordBatches`
413    /// let response: Vec<RecordBatch> = client
414    ///   .do_exchange(flight_data_stream)
415    ///   .await
416    ///   .unwrap()
417    ///   .try_collect() // use TryStreamExt to collect stream
418    ///   .await
419    ///   .expect("error calling do_exchange");
420    /// # }
421    /// ```
422    pub async fn do_exchange<S: Stream<Item = Result<FlightData>> + Send + 'static>(
423        &mut self,
424        request: S,
425    ) -> Result<FlightRecordBatchStream> {
426        let (sender, receiver) = futures::channel::oneshot::channel();
427
428        let request = Box::pin(request);
429        // Intercepts client errors and sends them to the oneshot channel above
430        let request_stream = FallibleRequestStream::new(sender, request);
431
432        let request = self.make_request(request_stream);
433        let response_stream = self.inner.do_exchange(request).await?.into_inner();
434
435        let response_stream = Box::pin(response_stream);
436        let error_stream = FallibleTonicResponseStream::new(receiver, response_stream);
437
438        // combine the response from the server and any error from the client
439        Ok(FlightRecordBatchStream::new_from_flight_data(error_stream))
440    }
441
442    /// Make a `ListFlights` call to the server with the provided
443    /// criteria and returning a [`Stream`] of [`FlightInfo`].
444    ///
445    /// # Example:
446    /// ```no_run
447    /// # async fn run() {
448    /// # use futures::TryStreamExt;
449    /// # use bytes::Bytes;
450    /// # use arrow_flight::{FlightInfo, FlightClient};
451    /// # let channel: tonic::transport::Channel = unimplemented!();
452    /// let mut client = FlightClient::new(channel);
453    ///
454    /// // Send 'Name=Foo' bytes as the "expression" to the server
455    /// // and gather the returned FlightInfo
456    /// let responses: Vec<FlightInfo> = client
457    ///   .list_flights(Bytes::from("Name=Foo"))
458    ///   .await
459    ///   .expect("error listing flights")
460    ///   .try_collect() // use TryStreamExt to collect stream
461    ///   .await
462    ///   .expect("error gathering flights");
463    /// # }
464    /// ```
465    pub async fn list_flights(
466        &mut self,
467        expression: impl Into<Bytes>,
468    ) -> Result<BoxStream<'static, Result<FlightInfo>>> {
469        let request = Criteria {
470            expression: expression.into(),
471        };
472
473        let request = self.make_request(request);
474
475        let response = self
476            .inner
477            .list_flights(request)
478            .await?
479            .into_inner()
480            .map_err(|status| status.into());
481
482        Ok(response.boxed())
483    }
484
485    /// Make a `GetSchema` call to the server with the provided
486    /// [`FlightDescriptor`] and returning the associated [`Schema`].
487    ///
488    /// # Example:
489    /// ```no_run
490    /// # async fn run() {
491    /// # use bytes::Bytes;
492    /// # use arrow_flight::{FlightDescriptor, FlightClient};
493    /// # use arrow_schema::Schema;
494    /// # let channel: tonic::transport::Channel = unimplemented!();
495    /// let mut client = FlightClient::new(channel);
496    ///
497    /// // Request the schema result of a 'CMD' request to the server
498    /// let request = FlightDescriptor::new_cmd(b"MOAR DATA".to_vec());
499    ///
500    /// let schema: Schema = client
501    ///   .get_schema(request)
502    ///   .await
503    ///   .expect("error making request");
504    /// # }
505    /// ```
506    pub async fn get_schema(&mut self, flight_descriptor: FlightDescriptor) -> Result<Schema> {
507        let request = self.make_request(flight_descriptor);
508
509        let schema_result = self.inner.get_schema(request).await?.into_inner();
510
511        // attempt decode from IPC
512        let schema: Schema = schema_result.try_into()?;
513
514        Ok(schema)
515    }
516
517    /// Make a `ListActions` call to the server and returning a
518    /// [`Stream`] of [`ActionType`].
519    ///
520    /// # Example:
521    /// ```no_run
522    /// # async fn run() {
523    /// # use futures::TryStreamExt;
524    /// # use arrow_flight::{ActionType, FlightClient};
525    /// # use arrow_schema::Schema;
526    /// # let channel: tonic::transport::Channel = unimplemented!();
527    /// let mut client = FlightClient::new(channel);
528    ///
529    /// // List available actions on the server:
530    /// let actions: Vec<ActionType> = client
531    ///   .list_actions()
532    ///   .await
533    ///   .expect("error listing actions")
534    ///   .try_collect() // use TryStreamExt to collect stream
535    ///   .await
536    ///   .expect("error gathering actions");
537    /// # }
538    /// ```
539    pub async fn list_actions(&mut self) -> Result<BoxStream<'static, Result<ActionType>>> {
540        let request = self.make_request(Empty {});
541
542        let action_stream = self
543            .inner
544            .list_actions(request)
545            .await?
546            .into_inner()
547            .map_err(|status| status.into());
548
549        Ok(action_stream.boxed())
550    }
551
552    /// Make a `DoAction` call to the server and returning a
553    /// [`Stream`] of opaque [`Bytes`].
554    ///
555    /// # Example:
556    /// ```no_run
557    /// # async fn run() {
558    /// # use bytes::Bytes;
559    /// # use futures::TryStreamExt;
560    /// # use arrow_flight::{Action, FlightClient};
561    /// # use arrow_schema::Schema;
562    /// # let channel: tonic::transport::Channel = unimplemented!();
563    /// let mut client = FlightClient::new(channel);
564    ///
565    /// let request = Action::new("my_action", "the body");
566    ///
567    /// // Make a request to run the action on the server
568    /// let results: Vec<Bytes> = client
569    ///   .do_action(request)
570    ///   .await
571    ///   .expect("error executing acton")
572    ///   .try_collect() // use TryStreamExt to collect stream
573    ///   .await
574    ///   .expect("error gathering action results");
575    /// # }
576    /// ```
577    pub async fn do_action(&mut self, action: Action) -> Result<BoxStream<'static, Result<Bytes>>> {
578        let request = self.make_request(action);
579
580        let result_stream = self
581            .inner
582            .do_action(request)
583            .await?
584            .into_inner()
585            .map_err(|status| status.into())
586            .map(|r| {
587                r.map(|r| {
588                    // unwrap inner bytes
589                    let crate::Result { body } = r;
590                    body
591                })
592            });
593
594        Ok(result_stream.boxed())
595    }
596
597    /// Make a `CancelFlightInfo` call to the server and return
598    /// a [`CancelFlightInfoResult`].
599    ///
600    /// # Example:
601    /// ```no_run
602    /// # async fn run() {
603    /// # use arrow_flight::{CancelFlightInfoRequest, FlightClient, FlightDescriptor};
604    /// # let channel: tonic::transport::Channel = unimplemented!();
605    /// let mut client = FlightClient::new(channel);
606    ///
607    /// // Send a 'CMD' request to the server
608    /// let request = FlightDescriptor::new_cmd(b"MOAR DATA".to_vec());
609    /// let flight_info = client
610    ///   .get_flight_info(request)
611    ///   .await
612    ///   .expect("error handshaking");
613    ///
614    /// // Cancel the query
615    /// let request = CancelFlightInfoRequest::new(flight_info);
616    /// let result = client
617    ///   .cancel_flight_info(request)
618    ///   .await
619    ///   .expect("error cancelling");
620    /// # }
621    /// ```
622    pub async fn cancel_flight_info(
623        &mut self,
624        request: CancelFlightInfoRequest,
625    ) -> Result<CancelFlightInfoResult> {
626        let action = Action::new("CancelFlightInfo", request.encode_to_vec());
627        let response = self.do_action(action).await?.try_next().await?;
628        let response = response.ok_or(FlightError::protocol(
629            "Received no response for cancel_flight_info call",
630        ))?;
631        CancelFlightInfoResult::decode(response)
632            .map_err(|e| FlightError::DecodeError(e.to_string()))
633    }
634
635    /// Make a `RenewFlightEndpoint` call to the server and return
636    /// the renewed [`FlightEndpoint`].
637    ///
638    /// # Example:
639    /// ```no_run
640    /// # async fn run() {
641    /// # use arrow_flight::{FlightClient, FlightDescriptor, RenewFlightEndpointRequest};
642    /// # let channel: tonic::transport::Channel = unimplemented!();
643    /// let mut client = FlightClient::new(channel);
644    ///
645    /// // Send a 'CMD' request to the server
646    /// let request = FlightDescriptor::new_cmd(b"MOAR DATA".to_vec());
647    /// let flight_endpoint = client
648    ///   .get_flight_info(request)
649    ///   .await
650    ///   .expect("error handshaking")
651    ///   .endpoint[0];
652    ///
653    /// // Renew the endpoint
654    /// let request = RenewFlightEndpointRequest::new(flight_endpoint);
655    /// let flight_endpoint = client
656    ///   .renew_flight_endpoint(request)
657    ///   .await
658    ///   .expect("error renewing");
659    /// # }
660    /// ```
661    pub async fn renew_flight_endpoint(
662        &mut self,
663        request: RenewFlightEndpointRequest,
664    ) -> Result<FlightEndpoint> {
665        let action = Action::new("RenewFlightEndpoint", request.encode_to_vec());
666        let response = self.do_action(action).await?.try_next().await?;
667        let response = response.ok_or(FlightError::protocol(
668            "Received no response for renew_flight_endpoint call",
669        ))?;
670        FlightEndpoint::decode(response).map_err(|e| FlightError::DecodeError(e.to_string()))
671    }
672
673    /// return a Request, adding any configured metadata
674    fn make_request<R>(&self, t: R) -> tonic::Request<R> {
675        // Pass along metadata
676        let mut request = tonic::Request::new(t);
677        *request.metadata_mut() = self.metadata.clone();
678        request
679    }
680}
681
682#[cfg(test)]
683mod tests {
684    use super::FlightClient;
685    use crate::encode::FlightDataEncoderBuilder;
686    use crate::flight_service_server::{FlightService, FlightServiceServer};
687    use crate::{
688        Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
689        HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
690    };
691    use arrow_array::{RecordBatch, UInt64Array};
692    use bytes::Bytes;
693    use futures::{StreamExt, TryStreamExt, stream::BoxStream};
694    use std::net::SocketAddr;
695    use std::sync::{Arc, Mutex};
696    use std::time::Duration;
697    use tokio::net::TcpListener;
698    use tokio::task::JoinHandle;
699    use tonic::metadata::MetadataMap;
700    use tonic::service::interceptor::InterceptedService;
701    use tonic::transport::Channel;
702    use tonic::{Request, Response, Status, Streaming};
703    use uuid::Uuid;
704
705    /// Minimal `FlightService` that records request metadata and serves a
706    /// configured `do_get` response. Other RPCs return `Unimplemented`.
707    #[derive(Debug, Clone, Default)]
708    struct InterceptorTestServer {
709        state: Arc<Mutex<InterceptorTestState>>,
710    }
711
712    #[derive(Debug, Default)]
713    struct InterceptorTestState {
714        do_get_request: Option<Ticket>,
715        do_get_response: Option<Vec<Result<RecordBatch, Status>>>,
716        last_request_metadata: Option<MetadataMap>,
717    }
718
719    impl InterceptorTestServer {
720        fn save_metadata<T>(&self, request: &Request<T>) {
721            self.state.lock().unwrap().last_request_metadata = Some(request.metadata().clone());
722        }
723
724        fn set_do_get_response(&self, response: Vec<Result<RecordBatch, Status>>) {
725            self.state.lock().unwrap().do_get_response = Some(response);
726        }
727
728        fn take_do_get_request(&self) -> Option<Ticket> {
729            self.state.lock().unwrap().do_get_request.take()
730        }
731
732        fn take_last_request_metadata(&self) -> Option<MetadataMap> {
733            self.state.lock().unwrap().last_request_metadata.take()
734        }
735    }
736
737    #[tonic::async_trait]
738    impl FlightService for InterceptorTestServer {
739        type HandshakeStream = BoxStream<'static, Result<HandshakeResponse, Status>>;
740        type ListFlightsStream = BoxStream<'static, Result<FlightInfo, Status>>;
741        type DoGetStream = BoxStream<'static, Result<FlightData, Status>>;
742        type DoPutStream = BoxStream<'static, Result<PutResult, Status>>;
743        type DoActionStream = BoxStream<'static, Result<crate::Result, Status>>;
744        type ListActionsStream = BoxStream<'static, Result<ActionType, Status>>;
745        type DoExchangeStream = BoxStream<'static, Result<FlightData, Status>>;
746
747        async fn do_get(
748            &self,
749            request: Request<Ticket>,
750        ) -> Result<Response<Self::DoGetStream>, Status> {
751            self.save_metadata(&request);
752            let mut state = self.state.lock().unwrap();
753            state.do_get_request = Some(request.into_inner());
754
755            let batches = state
756                .do_get_response
757                .take()
758                .ok_or_else(|| Status::internal("no do_get response configured"))?;
759            let batch_stream = futures::stream::iter(batches).map_err(Into::into);
760            let stream = FlightDataEncoderBuilder::new()
761                .build(batch_stream)
762                .map_err(Into::into);
763            Ok(Response::new(stream.boxed()))
764        }
765
766        async fn handshake(
767            &self,
768            _: Request<Streaming<HandshakeRequest>>,
769        ) -> Result<Response<Self::HandshakeStream>, Status> {
770            Err(Status::unimplemented(""))
771        }
772        async fn list_flights(
773            &self,
774            _: Request<Criteria>,
775        ) -> Result<Response<Self::ListFlightsStream>, Status> {
776            Err(Status::unimplemented(""))
777        }
778        async fn get_flight_info(
779            &self,
780            _: Request<FlightDescriptor>,
781        ) -> Result<Response<FlightInfo>, Status> {
782            Err(Status::unimplemented(""))
783        }
784        async fn poll_flight_info(
785            &self,
786            _: Request<FlightDescriptor>,
787        ) -> Result<Response<PollInfo>, Status> {
788            Err(Status::unimplemented(""))
789        }
790        async fn get_schema(
791            &self,
792            _: Request<FlightDescriptor>,
793        ) -> Result<Response<SchemaResult>, Status> {
794            Err(Status::unimplemented(""))
795        }
796        async fn do_put(
797            &self,
798            _: Request<Streaming<FlightData>>,
799        ) -> Result<Response<Self::DoPutStream>, Status> {
800            Err(Status::unimplemented(""))
801        }
802        async fn do_action(
803            &self,
804            _: Request<Action>,
805        ) -> Result<Response<Self::DoActionStream>, Status> {
806            Err(Status::unimplemented(""))
807        }
808        async fn list_actions(
809            &self,
810            _: Request<Empty>,
811        ) -> Result<Response<Self::ListActionsStream>, Status> {
812            Err(Status::unimplemented(""))
813        }
814        async fn do_exchange(
815            &self,
816            _: Request<Streaming<FlightData>>,
817        ) -> Result<Response<Self::DoExchangeStream>, Status> {
818            Err(Status::unimplemented(""))
819        }
820    }
821
    /// Spawns the test server on a background task and exposes a connected channel.
    ///
    /// Typical use: build it with `InterceptorTestFixture::new`, open client channels
    /// with `channel`, and finish with `shutdown_and_wait`. That final call stops the
    /// server and surfaces any error it returned while serving.
    struct InterceptorTestFixture {
        /// One-shot sender that triggers the server's graceful shutdown.
        /// It is taken (left as `None`) once the signal has been sent.
        shutdown: Option<tokio::sync::oneshot::Sender<()>>,
        /// Local address of the listener the server was started on.
        addr: SocketAddr,
        /// Handle to the spawned server task, yielding the serve result.
        /// It is taken (left as `None`) once the task has been awaited.
        handle: Option<JoinHandle<Result<(), tonic::transport::Error>>>,
    }
828
829    impl InterceptorTestFixture {
830        async fn new(server: FlightServiceServer<InterceptorTestServer>) -> Self {
831            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
832            let addr = listener.local_addr().unwrap();
833            let (tx, rx) = tokio::sync::oneshot::channel();
834            let shutdown_future = async move {
835                rx.await.ok();
836            };
837            let serve = tonic::transport::Server::builder()
838                .timeout(Duration::from_secs(30))
839                .add_service(server)
840                .serve_with_incoming_shutdown(
841                    tokio_stream::wrappers::TcpListenerStream::new(listener),
842                    shutdown_future,
843                );
844            let handle = tokio::task::spawn(serve);
845            Self {
846                shutdown: Some(tx),
847                addr,
848                handle: Some(handle),
849            }
850        }
851
852        async fn channel(&self) -> Channel {
853            let url = format!("http://{}", self.addr);
854            tonic::transport::Endpoint::from_shared(url)
855                .expect("valid endpoint")
856                .timeout(Duration::from_secs(30))
857                .connect()
858                .await
859                .expect("error connecting to server")
860        }
861
862        async fn shutdown_and_wait(mut self) {
863            if let Some(tx) = self.shutdown.take() {
864                tx.send(()).expect("server quit early");
865            }
866            if let Some(handle) = self.handle.take() {
867                handle
868                    .await
869                    .expect("task join error (panic?)")
870                    .expect("server error at shutdown");
871            }
872        }
873    }
874
875    /// Integration test: a tonic [`Channel`] wrapped in an [`InterceptedService`]
876    /// that injects a custom header is passed to [`FlightClient`], and the server
877    /// observes the header on the request.
878    #[tokio::test]
879    async fn test_flight_client_with_intercepted_channel_passes_custom_header() {
880        let test_server = InterceptorTestServer::default();
881        let fixture =
882            InterceptorTestFixture::new(FlightServiceServer::new(test_server.clone())).await;
883
884        let channel = fixture.channel().await;
885
886        let header_name = "x-random-header";
887        let header_value = format!("random-{}", Uuid::new_v4());
888        let header_value_for_interceptor = header_value.clone();
889
890        let interceptor = move |mut req: Request<()>| -> Result<Request<()>, Status> {
891            req.metadata_mut().insert(
892                header_name,
893                header_value_for_interceptor
894                    .parse()
895                    .expect("valid metadata value"),
896            );
897            Ok(req)
898        };
899
900        let intercepted = InterceptedService::new(channel, interceptor);
901        let mut client = FlightClient::new(intercepted);
902
903        let ticket = Ticket {
904            ticket: Bytes::from("dummy-ticket"),
905        };
906
907        let batch = RecordBatch::try_from_iter(vec![(
908            "col",
909            Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _,
910        )])
911        .unwrap();
912
913        test_server.set_do_get_response(vec![Ok(batch.clone())]);
914
915        let response_stream = client
916            .do_get(ticket.clone())
917            .await
918            .expect("error making do_get request");
919
920        let response: Vec<RecordBatch> = response_stream
921            .try_collect()
922            .await
923            .expect("error streaming data");
924
925        assert_eq!(response, vec![batch]);
926        assert_eq!(test_server.take_do_get_request(), Some(ticket));
927
928        let metadata = test_server
929            .take_last_request_metadata()
930            .expect("server received headers")
931            .into_headers();
932
933        let received = metadata
934            .get(header_name)
935            .expect("interceptor header missing on server")
936            .to_str()
937            .expect("ascii header value");
938        assert_eq!(received, header_value);
939
940        fixture.shutdown_and_wait().await;
941    }
942}