arrow_flight/client.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::{
19 Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo,
20 HandshakeRequest, PollInfo, PutResult, Ticket,
21 decode::FlightRecordBatchStream,
22 flight_service_client::FlightServiceClient,
23 r#gen::{CancelFlightInfoRequest, CancelFlightInfoResult, RenewFlightEndpointRequest},
24 trailers::extract_lazy_trailers,
25};
26use arrow_schema::Schema;
27use bytes::Bytes;
28use futures::{
29 Stream, StreamExt, TryStreamExt,
30 future::ready,
31 stream::{self, BoxStream},
32};
33use prost::Message;
34use tonic::codegen::{Body, StdError};
35use tonic::{metadata::MetadataMap, transport::Channel};
36
37use crate::error::{FlightError, Result};
38use crate::streams::{FallibleRequestStream, FallibleTonicResponseStream};
39
/// A "Mid level" [Apache Arrow Flight](https://arrow.apache.org/docs/format/Flight.html) client.
///
/// [`FlightClient`] is intended as a convenience for interactions
/// with Arrow Flight servers. For more direct control, such as access
/// to the response headers, use [`FlightServiceClient`] directly
/// via methods such as [`Self::inner`] or [`Self::into_inner`].
///
/// Any gRPC metadata configured on the client (see
/// [`Self::add_header`] and [`Self::metadata_mut`]) is copied onto
/// every request the client sends.
///
/// The transport type `T` defaults to a tonic [`Channel`], but any
/// compatible tonic `GrpcService` (for example a channel wrapped in an
/// interceptor) can be used.
///
/// # Example:
/// ```no_run
/// # async fn run() {
/// # use arrow_flight::FlightClient;
/// # use bytes::Bytes;
/// use tonic::transport::Channel;
/// let channel = Channel::from_static("http://localhost:1234")
///     .connect()
///     .await
///     .expect("error connecting");
///
/// let mut client = FlightClient::new(channel);
///
/// // Send 'Hi' bytes as the handshake request to the server
/// let response = client
///     .handshake(Bytes::from("Hi"))
///     .await
///     .expect("error handshaking");
///
/// // Expect the server responded with 'Ho'
/// assert_eq!(response, Bytes::from("Ho"));
/// # }
/// ```
#[derive(Debug)]
pub struct FlightClient<T = Channel> {
    /// Optional grpc header metadata to include with each request
    /// (cloned onto every outgoing request)
    metadata: MetadataMap,

    /// The inner (lower level, generated) tonic client
    inner: FlightServiceClient<T>,
}
78
79impl<T> FlightClient<T>
80where
81 T: tonic::client::GrpcService<tonic::body::Body>,
82 T::Error: Into<StdError>,
83 T::ResponseBody: Body<Data = Bytes> + std::marker::Send + 'static,
84 <T::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
85{
86 /// Creates a client with the provided transport
87 pub fn new(inner: T) -> Self {
88 Self::new_from_inner(FlightServiceClient::new(inner))
89 }
90
91 /// Creates a new higher level client with the provided lower level client
92 pub fn new_from_inner(inner: FlightServiceClient<T>) -> Self {
93 Self {
94 metadata: MetadataMap::new(),
95 inner,
96 }
97 }
98
99 /// Return a reference to gRPC metadata included with each request
100 pub fn metadata(&self) -> &MetadataMap {
101 &self.metadata
102 }
103
104 /// Return a reference to gRPC metadata included with each request
105 ///
106 /// These headers can be used, for example, to include
107 /// authorization or other application specific headers.
108 pub fn metadata_mut(&mut self) -> &mut MetadataMap {
109 &mut self.metadata
110 }
111
112 /// Add the specified header with value to all subsequent
113 /// requests. See [`Self::metadata_mut`] for fine grained control.
114 pub fn add_header(&mut self, key: &str, value: &str) -> Result<()> {
115 let key = tonic::metadata::MetadataKey::<_>::from_bytes(key.as_bytes())
116 .map_err(|e| FlightError::ExternalError(Box::new(e)))?;
117
118 let value = value
119 .parse()
120 .map_err(|e| FlightError::ExternalError(Box::new(e)))?;
121
122 // ignore previous value
123 self.metadata.insert(key, value);
124
125 Ok(())
126 }
127
128 /// Return a reference to the underlying tonic
129 /// [`FlightServiceClient`]
130 pub fn inner(&self) -> &FlightServiceClient<T> {
131 &self.inner
132 }
133
134 /// Return a mutable reference to the underlying tonic
135 /// [`FlightServiceClient`]
136 pub fn inner_mut(&mut self) -> &mut FlightServiceClient<T> {
137 &mut self.inner
138 }
139
140 /// Consume this client and return the underlying tonic
141 /// [`FlightServiceClient`]
142 pub fn into_inner(self) -> FlightServiceClient<T> {
143 self.inner
144 }
145
146 /// Perform an Arrow Flight handshake with the server, sending
147 /// `payload` as the [`HandshakeRequest`] payload and returning
148 /// the [`HandshakeResponse`](crate::HandshakeResponse)
149 /// bytes returned from the server
150 ///
151 /// See [`FlightClient`] docs for an example.
152 pub async fn handshake(&mut self, payload: impl Into<Bytes>) -> Result<Bytes> {
153 let request = HandshakeRequest {
154 protocol_version: 0,
155 payload: payload.into(),
156 };
157
158 // apply headers, etc
159 let request = self.make_request(stream::once(ready(request)));
160
161 let mut response_stream = self.inner.handshake(request).await?.into_inner();
162
163 if let Some(response) = response_stream.next().await.transpose()? {
164 // check if there is another response
165 if response_stream.next().await.is_some() {
166 return Err(FlightError::protocol(
167 "Got unexpected second response from handshake",
168 ));
169 }
170
171 Ok(response.payload)
172 } else {
173 Err(FlightError::protocol("No response from handshake"))
174 }
175 }
176
177 /// Make a `DoGet` call to the server with the provided ticket,
178 /// returning a [`FlightRecordBatchStream`] for reading
179 /// [`RecordBatch`](arrow_array::RecordBatch)es.
180 ///
181 /// # Note
182 ///
183 /// To access the returned [`FlightData`] use
184 /// [`FlightRecordBatchStream::into_inner()`]
185 ///
186 /// # Example:
187 /// ```no_run
188 /// # async fn run() {
189 /// # use bytes::Bytes;
190 /// # use arrow_flight::FlightClient;
191 /// # use arrow_flight::Ticket;
192 /// # use arrow_array::RecordBatch;
193 /// # use futures::stream::TryStreamExt;
194 /// # let channel: tonic::transport::Channel = unimplemented!();
195 /// # let ticket = Ticket { ticket: Bytes::from("foo") };
196 /// let mut client = FlightClient::new(channel);
197 ///
198 /// // Invoke a do_get request on the server with a previously
199 /// // received Ticket
200 ///
201 /// let response = client
202 /// .do_get(ticket)
203 /// .await
204 /// .expect("error invoking do_get");
205 ///
206 /// // Use try_collect to get the RecordBatches from the server
207 /// let batches: Vec<RecordBatch> = response
208 /// .try_collect()
209 /// .await
210 /// .expect("no stream errors");
211 /// # }
212 /// ```
213 pub async fn do_get(&mut self, ticket: Ticket) -> Result<FlightRecordBatchStream> {
214 let request = self.make_request(ticket);
215
216 let (md, response_stream, _ext) = self.inner.do_get(request).await?.into_parts();
217 let (response_stream, trailers) = extract_lazy_trailers(response_stream);
218
219 Ok(FlightRecordBatchStream::new_from_flight_data(
220 response_stream.map_err(|status| status.into()),
221 )
222 .with_headers(md)
223 .with_trailers(trailers))
224 }
225
226 /// Make a `GetFlightInfo` call to the server with the provided
227 /// [`FlightDescriptor`] and return the [`FlightInfo`] from the
228 /// server. The [`FlightInfo`] can be used with [`Self::do_get`]
229 /// to retrieve the requested batches.
230 ///
231 /// # Example:
232 /// ```no_run
233 /// # async fn run() {
234 /// # use arrow_flight::FlightClient;
235 /// # use arrow_flight::FlightDescriptor;
236 /// # let channel: tonic::transport::Channel = unimplemented!();
237 /// let mut client = FlightClient::new(channel);
238 ///
239 /// // Send a 'CMD' request to the server
240 /// let request = FlightDescriptor::new_cmd(b"MOAR DATA".to_vec());
241 /// let flight_info = client
242 /// .get_flight_info(request)
243 /// .await
244 /// .expect("error handshaking");
245 ///
246 /// // retrieve the first endpoint from the returned flight info
247 /// let ticket = flight_info
248 /// .endpoint[0]
249 /// // Extract the ticket
250 /// .ticket
251 /// .clone()
252 /// .expect("expected ticket");
253 ///
254 /// // Retrieve the corresponding RecordBatch stream with do_get
255 /// let data = client
256 /// .do_get(ticket)
257 /// .await
258 /// .expect("error fetching data");
259 /// # }
260 /// ```
261 pub async fn get_flight_info(&mut self, descriptor: FlightDescriptor) -> Result<FlightInfo> {
262 let request = self.make_request(descriptor);
263
264 let response = self.inner.get_flight_info(request).await?.into_inner();
265 Ok(response)
266 }
267
268 /// Make a `PollFlightInfo` call to the server with the provided
269 /// [`FlightDescriptor`] and return the [`PollInfo`] from the
270 /// server.
271 ///
272 /// The `info` field of the [`PollInfo`] can be used with
273 /// [`Self::do_get`] to retrieve the requested batches.
274 ///
275 /// If the `flight_descriptor` field of the [`PollInfo`] is
276 /// `None` then the `info` field represents the complete results.
277 ///
278 /// If the `flight_descriptor` field is some [`FlightDescriptor`]
279 /// then the `info` field has incomplete results, and the client
280 /// should call this method again with the new `flight_descriptor`
281 /// to get the updated status.
282 ///
283 /// The `expiration_time`, if set, represents the expiration time
284 /// of the `flight_descriptor`, after which the server may not accept
285 /// this retry descriptor and may cancel the query.
286 ///
287 /// # Example:
288 /// ```no_run
289 /// # async fn run() {
290 /// # use arrow_flight::FlightClient;
291 /// # use arrow_flight::FlightDescriptor;
292 /// # let channel: tonic::transport::Channel = unimplemented!();
293 /// let mut client = FlightClient::new(channel);
294 ///
295 /// // Send a 'CMD' request to the server
296 /// let request = FlightDescriptor::new_cmd(b"MOAR DATA".to_vec());
297 /// let poll_info = client
298 /// .poll_flight_info(request)
299 /// .await
300 /// .expect("error handshaking");
301 ///
302 /// // retrieve the first endpoint from the returned poll info
303 /// let ticket = poll_info
304 /// .info
305 /// .expect("expected flight info")
306 /// .endpoint[0]
307 /// // Extract the ticket
308 /// .ticket
309 /// .clone()
310 /// .expect("expected ticket");
311 ///
312 /// // Retrieve the corresponding RecordBatch stream with do_get
313 /// let data = client
314 /// .do_get(ticket)
315 /// .await
316 /// .expect("error fetching data");
317 /// # }
318 /// ```
319 pub async fn poll_flight_info(&mut self, descriptor: FlightDescriptor) -> Result<PollInfo> {
320 let request = self.make_request(descriptor);
321
322 let response = self.inner.poll_flight_info(request).await?.into_inner();
323 Ok(response)
324 }
325
326 /// Make a `DoPut` call to the server with the provided
327 /// [`Stream`] of [`FlightData`] and returning a
328 /// stream of [`PutResult`].
329 ///
330 /// # Note
331 ///
332 /// The input stream is [`Result`] so that this can be connected
333 /// to a streaming data source, such as [`FlightDataEncoder`](crate::encode::FlightDataEncoder),
334 /// without having to buffer. If the input stream returns an error
335 /// that error will not be sent to the server, instead it will be
336 /// placed into the result stream and the server connection
337 /// terminated.
338 ///
339 /// # Example:
340 /// ```no_run
341 /// # async fn run() {
342 /// # use futures::{TryStreamExt, StreamExt};
343 /// # use std::sync::Arc;
344 /// # use arrow_array::UInt64Array;
345 /// # use arrow_array::RecordBatch;
346 /// # use arrow_flight::{FlightClient, FlightDescriptor, PutResult};
347 /// # use arrow_flight::encode::FlightDataEncoderBuilder;
348 /// # let batch = RecordBatch::try_from_iter(vec![
349 /// # ("col2", Arc::new(UInt64Array::from_iter([10, 23, 33])) as _)
350 /// # ]).unwrap();
351 /// # let channel: tonic::transport::Channel = unimplemented!();
352 /// let mut client = FlightClient::new(channel);
353 ///
354 /// // encode the batch as a stream of `FlightData`
355 /// let flight_data_stream = FlightDataEncoderBuilder::new()
356 /// .build(futures::stream::iter(vec![Ok(batch)]));
357 ///
358 /// // send the stream and get the results as `PutResult`
359 /// let response: Vec<PutResult>= client
360 /// .do_put(flight_data_stream)
361 /// .await
362 /// .unwrap()
363 /// .try_collect() // use TryStreamExt to collect stream
364 /// .await
365 /// .expect("error calling do_put");
366 /// # }
367 /// ```
368 pub async fn do_put<S: Stream<Item = Result<FlightData>> + Send + 'static>(
369 &mut self,
370 request: S,
371 ) -> Result<BoxStream<'static, Result<PutResult>>> {
372 let (sender, receiver) = futures::channel::oneshot::channel();
373
374 // Intercepts client errors and sends them to the oneshot channel above
375 let request = Box::pin(request); // Pin to heap
376 let request_stream = FallibleRequestStream::new(sender, request);
377
378 let request = self.make_request(request_stream);
379 let response_stream = self.inner.do_put(request).await?.into_inner();
380
381 // Forwards errors from the error oneshot with priority over responses from server
382 let response_stream = Box::pin(response_stream);
383 let error_stream = FallibleTonicResponseStream::new(receiver, response_stream);
384
385 // combine the response from the server and any error from the client
386 Ok(error_stream.boxed())
387 }
388
389 /// Make a `DoExchange` call to the server with the provided
390 /// [`Stream`] of [`FlightData`] and returning a
391 /// stream of [`FlightData`].
392 ///
393 /// # Example:
394 /// ```no_run
395 /// # async fn run() {
396 /// # use futures::{TryStreamExt, StreamExt};
397 /// # use std::sync::Arc;
398 /// # use arrow_array::UInt64Array;
399 /// # use arrow_array::RecordBatch;
400 /// # use arrow_flight::{FlightClient, FlightDescriptor, PutResult};
401 /// # use arrow_flight::encode::FlightDataEncoderBuilder;
402 /// # let batch = RecordBatch::try_from_iter(vec![
403 /// # ("col2", Arc::new(UInt64Array::from_iter([10, 23, 33])) as _)
404 /// # ]).unwrap();
405 /// # let channel: tonic::transport::Channel = unimplemented!();
406 /// let mut client = FlightClient::new(channel);
407 ///
408 /// // encode the batch as a stream of `FlightData`
409 /// let flight_data_stream = FlightDataEncoderBuilder::new()
410 /// .build(futures::stream::iter(vec![Ok(batch)]));
411 ///
412 /// // send the stream and get the results as `RecordBatches`
413 /// let response: Vec<RecordBatch> = client
414 /// .do_exchange(flight_data_stream)
415 /// .await
416 /// .unwrap()
417 /// .try_collect() // use TryStreamExt to collect stream
418 /// .await
419 /// .expect("error calling do_exchange");
420 /// # }
421 /// ```
422 pub async fn do_exchange<S: Stream<Item = Result<FlightData>> + Send + 'static>(
423 &mut self,
424 request: S,
425 ) -> Result<FlightRecordBatchStream> {
426 let (sender, receiver) = futures::channel::oneshot::channel();
427
428 let request = Box::pin(request);
429 // Intercepts client errors and sends them to the oneshot channel above
430 let request_stream = FallibleRequestStream::new(sender, request);
431
432 let request = self.make_request(request_stream);
433 let response_stream = self.inner.do_exchange(request).await?.into_inner();
434
435 let response_stream = Box::pin(response_stream);
436 let error_stream = FallibleTonicResponseStream::new(receiver, response_stream);
437
438 // combine the response from the server and any error from the client
439 Ok(FlightRecordBatchStream::new_from_flight_data(error_stream))
440 }
441
442 /// Make a `ListFlights` call to the server with the provided
443 /// criteria and returning a [`Stream`] of [`FlightInfo`].
444 ///
445 /// # Example:
446 /// ```no_run
447 /// # async fn run() {
448 /// # use futures::TryStreamExt;
449 /// # use bytes::Bytes;
450 /// # use arrow_flight::{FlightInfo, FlightClient};
451 /// # let channel: tonic::transport::Channel = unimplemented!();
452 /// let mut client = FlightClient::new(channel);
453 ///
454 /// // Send 'Name=Foo' bytes as the "expression" to the server
455 /// // and gather the returned FlightInfo
456 /// let responses: Vec<FlightInfo> = client
457 /// .list_flights(Bytes::from("Name=Foo"))
458 /// .await
459 /// .expect("error listing flights")
460 /// .try_collect() // use TryStreamExt to collect stream
461 /// .await
462 /// .expect("error gathering flights");
463 /// # }
464 /// ```
465 pub async fn list_flights(
466 &mut self,
467 expression: impl Into<Bytes>,
468 ) -> Result<BoxStream<'static, Result<FlightInfo>>> {
469 let request = Criteria {
470 expression: expression.into(),
471 };
472
473 let request = self.make_request(request);
474
475 let response = self
476 .inner
477 .list_flights(request)
478 .await?
479 .into_inner()
480 .map_err(|status| status.into());
481
482 Ok(response.boxed())
483 }
484
485 /// Make a `GetSchema` call to the server with the provided
486 /// [`FlightDescriptor`] and returning the associated [`Schema`].
487 ///
488 /// # Example:
489 /// ```no_run
490 /// # async fn run() {
491 /// # use bytes::Bytes;
492 /// # use arrow_flight::{FlightDescriptor, FlightClient};
493 /// # use arrow_schema::Schema;
494 /// # let channel: tonic::transport::Channel = unimplemented!();
495 /// let mut client = FlightClient::new(channel);
496 ///
497 /// // Request the schema result of a 'CMD' request to the server
498 /// let request = FlightDescriptor::new_cmd(b"MOAR DATA".to_vec());
499 ///
500 /// let schema: Schema = client
501 /// .get_schema(request)
502 /// .await
503 /// .expect("error making request");
504 /// # }
505 /// ```
506 pub async fn get_schema(&mut self, flight_descriptor: FlightDescriptor) -> Result<Schema> {
507 let request = self.make_request(flight_descriptor);
508
509 let schema_result = self.inner.get_schema(request).await?.into_inner();
510
511 // attempt decode from IPC
512 let schema: Schema = schema_result.try_into()?;
513
514 Ok(schema)
515 }
516
517 /// Make a `ListActions` call to the server and returning a
518 /// [`Stream`] of [`ActionType`].
519 ///
520 /// # Example:
521 /// ```no_run
522 /// # async fn run() {
523 /// # use futures::TryStreamExt;
524 /// # use arrow_flight::{ActionType, FlightClient};
525 /// # use arrow_schema::Schema;
526 /// # let channel: tonic::transport::Channel = unimplemented!();
527 /// let mut client = FlightClient::new(channel);
528 ///
529 /// // List available actions on the server:
530 /// let actions: Vec<ActionType> = client
531 /// .list_actions()
532 /// .await
533 /// .expect("error listing actions")
534 /// .try_collect() // use TryStreamExt to collect stream
535 /// .await
536 /// .expect("error gathering actions");
537 /// # }
538 /// ```
539 pub async fn list_actions(&mut self) -> Result<BoxStream<'static, Result<ActionType>>> {
540 let request = self.make_request(Empty {});
541
542 let action_stream = self
543 .inner
544 .list_actions(request)
545 .await?
546 .into_inner()
547 .map_err(|status| status.into());
548
549 Ok(action_stream.boxed())
550 }
551
552 /// Make a `DoAction` call to the server and returning a
553 /// [`Stream`] of opaque [`Bytes`].
554 ///
555 /// # Example:
556 /// ```no_run
557 /// # async fn run() {
558 /// # use bytes::Bytes;
559 /// # use futures::TryStreamExt;
560 /// # use arrow_flight::{Action, FlightClient};
561 /// # use arrow_schema::Schema;
562 /// # let channel: tonic::transport::Channel = unimplemented!();
563 /// let mut client = FlightClient::new(channel);
564 ///
565 /// let request = Action::new("my_action", "the body");
566 ///
567 /// // Make a request to run the action on the server
568 /// let results: Vec<Bytes> = client
569 /// .do_action(request)
570 /// .await
571 /// .expect("error executing acton")
572 /// .try_collect() // use TryStreamExt to collect stream
573 /// .await
574 /// .expect("error gathering action results");
575 /// # }
576 /// ```
577 pub async fn do_action(&mut self, action: Action) -> Result<BoxStream<'static, Result<Bytes>>> {
578 let request = self.make_request(action);
579
580 let result_stream = self
581 .inner
582 .do_action(request)
583 .await?
584 .into_inner()
585 .map_err(|status| status.into())
586 .map(|r| {
587 r.map(|r| {
588 // unwrap inner bytes
589 let crate::Result { body } = r;
590 body
591 })
592 });
593
594 Ok(result_stream.boxed())
595 }
596
597 /// Make a `CancelFlightInfo` call to the server and return
598 /// a [`CancelFlightInfoResult`].
599 ///
600 /// # Example:
601 /// ```no_run
602 /// # async fn run() {
603 /// # use arrow_flight::{CancelFlightInfoRequest, FlightClient, FlightDescriptor};
604 /// # let channel: tonic::transport::Channel = unimplemented!();
605 /// let mut client = FlightClient::new(channel);
606 ///
607 /// // Send a 'CMD' request to the server
608 /// let request = FlightDescriptor::new_cmd(b"MOAR DATA".to_vec());
609 /// let flight_info = client
610 /// .get_flight_info(request)
611 /// .await
612 /// .expect("error handshaking");
613 ///
614 /// // Cancel the query
615 /// let request = CancelFlightInfoRequest::new(flight_info);
616 /// let result = client
617 /// .cancel_flight_info(request)
618 /// .await
619 /// .expect("error cancelling");
620 /// # }
621 /// ```
622 pub async fn cancel_flight_info(
623 &mut self,
624 request: CancelFlightInfoRequest,
625 ) -> Result<CancelFlightInfoResult> {
626 let action = Action::new("CancelFlightInfo", request.encode_to_vec());
627 let response = self.do_action(action).await?.try_next().await?;
628 let response = response.ok_or(FlightError::protocol(
629 "Received no response for cancel_flight_info call",
630 ))?;
631 CancelFlightInfoResult::decode(response)
632 .map_err(|e| FlightError::DecodeError(e.to_string()))
633 }
634
635 /// Make a `RenewFlightEndpoint` call to the server and return
636 /// the renewed [`FlightEndpoint`].
637 ///
638 /// # Example:
639 /// ```no_run
640 /// # async fn run() {
641 /// # use arrow_flight::{FlightClient, FlightDescriptor, RenewFlightEndpointRequest};
642 /// # let channel: tonic::transport::Channel = unimplemented!();
643 /// let mut client = FlightClient::new(channel);
644 ///
645 /// // Send a 'CMD' request to the server
646 /// let request = FlightDescriptor::new_cmd(b"MOAR DATA".to_vec());
647 /// let flight_endpoint = client
648 /// .get_flight_info(request)
649 /// .await
650 /// .expect("error handshaking")
651 /// .endpoint[0];
652 ///
653 /// // Renew the endpoint
654 /// let request = RenewFlightEndpointRequest::new(flight_endpoint);
655 /// let flight_endpoint = client
656 /// .renew_flight_endpoint(request)
657 /// .await
658 /// .expect("error renewing");
659 /// # }
660 /// ```
661 pub async fn renew_flight_endpoint(
662 &mut self,
663 request: RenewFlightEndpointRequest,
664 ) -> Result<FlightEndpoint> {
665 let action = Action::new("RenewFlightEndpoint", request.encode_to_vec());
666 let response = self.do_action(action).await?.try_next().await?;
667 let response = response.ok_or(FlightError::protocol(
668 "Received no response for renew_flight_endpoint call",
669 ))?;
670 FlightEndpoint::decode(response).map_err(|e| FlightError::DecodeError(e.to_string()))
671 }
672
673 /// return a Request, adding any configured metadata
674 fn make_request<R>(&self, t: R) -> tonic::Request<R> {
675 // Pass along metadata
676 let mut request = tonic::Request::new(t);
677 *request.metadata_mut() = self.metadata.clone();
678 request
679 }
680}
681
#[cfg(test)]
mod tests {
    use super::FlightClient;
    use crate::encode::FlightDataEncoderBuilder;
    use crate::flight_service_server::{FlightService, FlightServiceServer};
    use crate::{
        Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
        HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
    };
    use arrow_array::{RecordBatch, UInt64Array};
    use bytes::Bytes;
    use futures::{StreamExt, TryStreamExt, stream::BoxStream};
    use std::net::SocketAddr;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;
    use tokio::net::TcpListener;
    use tokio::task::JoinHandle;
    use tonic::metadata::MetadataMap;
    use tonic::service::interceptor::InterceptedService;
    use tonic::transport::Channel;
    use tonic::{Request, Response, Status, Streaming};
    use uuid::Uuid;

    /// Minimal `FlightService` that records request metadata and serves a
    /// configured `do_get` response. Other RPCs return `Unimplemented`.
    ///
    /// Clones share the same `Arc`-wrapped state, so the test keeps one
    /// clone for setup/inspection while the running server owns another.
    #[derive(Debug, Clone, Default)]
    struct InterceptorTestServer {
        state: Arc<Mutex<InterceptorTestState>>,
    }

    /// State shared between the test body and the running server.
    #[derive(Debug, Default)]
    struct InterceptorTestState {
        // Ticket received by the most recent `do_get` call
        do_get_request: Option<Ticket>,
        // Response for the next `do_get` call; taken (consumed) when served
        do_get_response: Option<Vec<Result<RecordBatch, Status>>>,
        // Metadata of the most recent request passed to `save_metadata`
        last_request_metadata: Option<MetadataMap>,
    }

    impl InterceptorTestServer {
        /// Record a copy of `request`'s metadata as the last seen metadata.
        fn save_metadata<T>(&self, request: &Request<T>) {
            self.state.lock().unwrap().last_request_metadata = Some(request.metadata().clone());
        }

        /// Configure the batches (or errors) returned by the next `do_get`.
        fn set_do_get_response(&self, response: Vec<Result<RecordBatch, Status>>) {
            self.state.lock().unwrap().do_get_response = Some(response);
        }

        /// Take the ticket recorded by the last `do_get`, leaving `None`.
        fn take_do_get_request(&self) -> Option<Ticket> {
            self.state.lock().unwrap().do_get_request.take()
        }

        /// Take the last recorded request metadata, leaving `None`.
        fn take_last_request_metadata(&self) -> Option<MetadataMap> {
            self.state.lock().unwrap().last_request_metadata.take()
        }
    }

    #[tonic::async_trait]
    impl FlightService for InterceptorTestServer {
        type HandshakeStream = BoxStream<'static, Result<HandshakeResponse, Status>>;
        type ListFlightsStream = BoxStream<'static, Result<FlightInfo, Status>>;
        type DoGetStream = BoxStream<'static, Result<FlightData, Status>>;
        type DoPutStream = BoxStream<'static, Result<PutResult, Status>>;
        type DoActionStream = BoxStream<'static, Result<crate::Result, Status>>;
        type ListActionsStream = BoxStream<'static, Result<ActionType, Status>>;
        type DoExchangeStream = BoxStream<'static, Result<FlightData, Status>>;

        /// Records the request metadata and ticket, then streams the
        /// configured batches back encoded as `FlightData`. Fails with
        /// `Status::internal` if no response was configured (the ticket is
        /// still recorded in that case).
        async fn do_get(
            &self,
            request: Request<Ticket>,
        ) -> Result<Response<Self::DoGetStream>, Status> {
            self.save_metadata(&request);
            // The std Mutex guard is never held across an `.await` here
            let mut state = self.state.lock().unwrap();
            state.do_get_request = Some(request.into_inner());

            let batches = state
                .do_get_response
                .take()
                .ok_or_else(|| Status::internal("no do_get response configured"))?;
            // Convert `Status` errors into the encoder's error type on the way
            // in, and the encoder's errors back into `Status` on the way out
            let batch_stream = futures::stream::iter(batches).map_err(Into::into);
            let stream = FlightDataEncoderBuilder::new()
                .build(batch_stream)
                .map_err(Into::into);
            Ok(Response::new(stream.boxed()))
        }

        // The remaining RPCs are not exercised by these tests.
        async fn handshake(
            &self,
            _: Request<Streaming<HandshakeRequest>>,
        ) -> Result<Response<Self::HandshakeStream>, Status> {
            Err(Status::unimplemented(""))
        }
        async fn list_flights(
            &self,
            _: Request<Criteria>,
        ) -> Result<Response<Self::ListFlightsStream>, Status> {
            Err(Status::unimplemented(""))
        }
        async fn get_flight_info(
            &self,
            _: Request<FlightDescriptor>,
        ) -> Result<Response<FlightInfo>, Status> {
            Err(Status::unimplemented(""))
        }
        async fn poll_flight_info(
            &self,
            _: Request<FlightDescriptor>,
        ) -> Result<Response<PollInfo>, Status> {
            Err(Status::unimplemented(""))
        }
        async fn get_schema(
            &self,
            _: Request<FlightDescriptor>,
        ) -> Result<Response<SchemaResult>, Status> {
            Err(Status::unimplemented(""))
        }
        async fn do_put(
            &self,
            _: Request<Streaming<FlightData>>,
        ) -> Result<Response<Self::DoPutStream>, Status> {
            Err(Status::unimplemented(""))
        }
        async fn do_action(
            &self,
            _: Request<Action>,
        ) -> Result<Response<Self::DoActionStream>, Status> {
            Err(Status::unimplemented(""))
        }
        async fn list_actions(
            &self,
            _: Request<Empty>,
        ) -> Result<Response<Self::ListActionsStream>, Status> {
            Err(Status::unimplemented(""))
        }
        async fn do_exchange(
            &self,
            _: Request<Streaming<FlightData>>,
        ) -> Result<Response<Self::DoExchangeStream>, Status> {
            Err(Status::unimplemented(""))
        }
    }

    /// Spawns the test server on a background task and exposes a connected channel.
    struct InterceptorTestFixture {
        // Signals the server to shut down gracefully
        shutdown: Option<tokio::sync::oneshot::Sender<()>>,
        // Address the server is listening on
        addr: SocketAddr,
        // Handle of the spawned server task
        handle: Option<JoinHandle<Result<(), tonic::transport::Error>>>,
    }

    impl InterceptorTestFixture {
        /// Bind an OS-assigned port on 127.0.0.1 and serve `server` on it
        /// in a spawned task until shutdown is signalled.
        async fn new(server: FlightServiceServer<InterceptorTestServer>) -> Self {
            // Bind before spawning so `addr` is known and accepting immediately
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();
            let (tx, rx) = tokio::sync::oneshot::channel();
            // Completes when `tx` sends or is dropped
            let shutdown_future = async move {
                rx.await.ok();
            };
            let serve = tonic::transport::Server::builder()
                .timeout(Duration::from_secs(30))
                .add_service(server)
                .serve_with_incoming_shutdown(
                    tokio_stream::wrappers::TcpListenerStream::new(listener),
                    shutdown_future,
                );
            let handle = tokio::task::spawn(serve);
            Self {
                shutdown: Some(tx),
                addr,
                handle: Some(handle),
            }
        }

        /// Open a new client channel connected to the fixture's server.
        async fn channel(&self) -> Channel {
            let url = format!("http://{}", self.addr);
            tonic::transport::Endpoint::from_shared(url)
                .expect("valid endpoint")
                .timeout(Duration::from_secs(30))
                .connect()
                .await
                .expect("error connecting to server")
        }

        /// Signal shutdown and wait for the server task to finish.
        ///
        /// Panics if the server already stopped, the task panicked, or
        /// the server returned an error.
        async fn shutdown_and_wait(mut self) {
            if let Some(tx) = self.shutdown.take() {
                tx.send(()).expect("server quit early");
            }
            if let Some(handle) = self.handle.take() {
                handle
                    .await
                    .expect("task join error (panic?)")
                    .expect("server error at shutdown");
            }
        }
    }

    /// Integration test: a tonic [`Channel`] wrapped in an [`InterceptedService`]
    /// that injects a custom header is passed to [`FlightClient`], and the server
    /// observes the header on the request.
    #[tokio::test]
    async fn test_flight_client_with_intercepted_channel_passes_custom_header() {
        let test_server = InterceptorTestServer::default();
        let fixture =
            InterceptorTestFixture::new(FlightServiceServer::new(test_server.clone())).await;

        let channel = fixture.channel().await;

        // Use a unique value so the assertion cannot pass by coincidence
        let header_name = "x-random-header";
        let header_value = format!("random-{}", Uuid::new_v4());
        let header_value_for_interceptor = header_value.clone();

        let interceptor = move |mut req: Request<()>| -> Result<Request<()>, Status> {
            req.metadata_mut().insert(
                header_name,
                header_value_for_interceptor
                    .parse()
                    .expect("valid metadata value"),
            );
            Ok(req)
        };

        let intercepted = InterceptedService::new(channel, interceptor);
        let mut client = FlightClient::new(intercepted);

        let ticket = Ticket {
            ticket: Bytes::from("dummy-ticket"),
        };

        let batch = RecordBatch::try_from_iter(vec![(
            "col",
            Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _,
        )])
        .unwrap();

        test_server.set_do_get_response(vec![Ok(batch.clone())]);

        let response_stream = client
            .do_get(ticket.clone())
            .await
            .expect("error making do_get request");

        let response: Vec<RecordBatch> = response_stream
            .try_collect()
            .await
            .expect("error streaming data");

        // The data round-trips and the server saw the ticket we sent
        assert_eq!(response, vec![batch]);
        assert_eq!(test_server.take_do_get_request(), Some(ticket));

        let metadata = test_server
            .take_last_request_metadata()
            .expect("server received headers")
            .into_headers();

        // The header added by the interceptor reached the server
        let received = metadata
            .get(header_name)
            .expect("interceptor header missing on server")
            .to_str()
            .expect("ascii header value");
        assert_eq!(received, header_value);

        fixture.shutdown_and_wait().await;
    }
}