arrow_integration_testing/flight_server_scenarios/middleware.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Middleware test for the Flight server.
19
20use std::pin::Pin;
21
22use arrow_flight::{
23    flight_descriptor::DescriptorType, flight_service_server::FlightService,
24    flight_service_server::FlightServiceServer, Action, ActionType, Criteria, Empty, FlightData,
25    FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, PollInfo, PutResult,
26    SchemaResult, Ticket,
27};
28use futures::Stream;
29use tonic::{transport::Server, Request, Response, Status, Streaming};
30
31type TonicStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync + 'static>>;
32
/// Boxed, thread-safe error used by the scenario setup code.
///
/// With no explicit lifetime on the trait object, `Box<dyn …>` defaults to `'static`.
type Error = Box<dyn std::error::Error + Send + Sync>;

/// Result alias whose success type defaults to `()` and whose error type
/// defaults to the boxed [`Error`] above.
type Result<T = (), E = Error> = std::result::Result<T, E>;
35
36/// Run a scenario that tests middleware.
37pub async fn scenario_setup(port: u16) -> Result {
38    let service = MiddlewareScenarioImpl {};
39    let svc = FlightServiceServer::new(service);
40    let addr = super::listen_on(port).await?;
41
42    let server = Server::builder().add_service(svc).serve(addr);
43
44    // NOTE: Log output used in tests to signal server is ready
45    println!("Server listening on localhost:{}", addr.port());
46    server.await?;
47    Ok(())
48}
49
/// Flight service used by the middleware integration scenario.
///
/// It is a plain `FlightService` implementation rather than an interceptor.
/// `get_flight_info` echoes the `x-middleware` request header back to the
/// client, on both its success response and its error status. Every other
/// RPC returns `Unimplemented`.
#[derive(Clone, Debug, Default)]
pub struct MiddlewareScenarioImpl {}
53
54#[tonic::async_trait]
55impl FlightService for MiddlewareScenarioImpl {
56    type HandshakeStream = TonicStream<Result<HandshakeResponse, Status>>;
57    type ListFlightsStream = TonicStream<Result<FlightInfo, Status>>;
58    type DoGetStream = TonicStream<Result<FlightData, Status>>;
59    type DoPutStream = TonicStream<Result<PutResult, Status>>;
60    type DoActionStream = TonicStream<Result<arrow_flight::Result, Status>>;
61    type ListActionsStream = TonicStream<Result<ActionType, Status>>;
62    type DoExchangeStream = TonicStream<Result<FlightData, Status>>;
63
64    async fn get_schema(
65        &self,
66        _request: Request<FlightDescriptor>,
67    ) -> Result<Response<SchemaResult>, Status> {
68        Err(Status::unimplemented("Not yet implemented"))
69    }
70
71    async fn do_get(
72        &self,
73        _request: Request<Ticket>,
74    ) -> Result<Response<Self::DoGetStream>, Status> {
75        Err(Status::unimplemented("Not yet implemented"))
76    }
77
78    async fn handshake(
79        &self,
80        _request: Request<Streaming<HandshakeRequest>>,
81    ) -> Result<Response<Self::HandshakeStream>, Status> {
82        Err(Status::unimplemented("Not yet implemented"))
83    }
84
85    async fn list_flights(
86        &self,
87        _request: Request<Criteria>,
88    ) -> Result<Response<Self::ListFlightsStream>, Status> {
89        Err(Status::unimplemented("Not yet implemented"))
90    }
91
92    async fn get_flight_info(
93        &self,
94        request: Request<FlightDescriptor>,
95    ) -> Result<Response<FlightInfo>, Status> {
96        let middleware_header = request.metadata().get("x-middleware").cloned();
97
98        let descriptor = request.into_inner();
99
100        if descriptor.r#type == DescriptorType::Cmd as i32 && descriptor.cmd.as_ref() == b"success"
101        {
102            // Return a fake location - the test doesn't read it
103            let endpoint = super::endpoint("foo", "grpc+tcp://localhost:10010");
104
105            let info = FlightInfo {
106                flight_descriptor: Some(descriptor),
107                endpoint: vec![endpoint],
108                ..Default::default()
109            };
110
111            let mut response = Response::new(info);
112            if let Some(value) = middleware_header {
113                response.metadata_mut().insert("x-middleware", value);
114            }
115
116            return Ok(response);
117        }
118
119        let mut status = Status::unknown("Unknown");
120        if let Some(value) = middleware_header {
121            status.metadata_mut().insert("x-middleware", value);
122        }
123
124        Err(status)
125    }
126
127    async fn poll_flight_info(
128        &self,
129        _request: Request<FlightDescriptor>,
130    ) -> Result<Response<PollInfo>, Status> {
131        Err(Status::unimplemented("Not yet implemented"))
132    }
133
134    async fn do_put(
135        &self,
136        _request: Request<Streaming<FlightData>>,
137    ) -> Result<Response<Self::DoPutStream>, Status> {
138        Err(Status::unimplemented("Not yet implemented"))
139    }
140
141    async fn do_action(
142        &self,
143        _request: Request<Action>,
144    ) -> Result<Response<Self::DoActionStream>, Status> {
145        Err(Status::unimplemented("Not yet implemented"))
146    }
147
148    async fn list_actions(
149        &self,
150        _request: Request<Empty>,
151    ) -> Result<Response<Self::ListActionsStream>, Status> {
152        Err(Status::unimplemented("Not yet implemented"))
153    }
154
155    async fn do_exchange(
156        &self,
157        _request: Request<Streaming<FlightData>>,
158    ) -> Result<Response<Self::DoExchangeStream>, Status> {
159        Err(Status::unimplemented("Not yet implemented"))
160    }
161}