// arrow_integration_testing/flight_client_scenarios/middleware.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Scenario for testing middleware.

20use arrow_flight::{
21    flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient, FlightDescriptor,
22};
23use prost::bytes::Bytes;
24use tonic::{Request, Status};
25
/// Boxed, thread-safe error type returned by this scenario.
type Error = Box<dyn std::error::Error + Send + Sync>;
/// Result alias that defaults to a unit success value and the boxed [`Error`].
type Result<T = (), E = Error> = core::result::Result<T, E>;

29/// Run a scenario that tests middleware.
30pub async fn run_scenario(host: &str, port: u16) -> Result {
31    let url = format!("http://{host}:{port}");
32    let conn = tonic::transport::Endpoint::new(url)?.connect().await?;
33    let mut client = FlightServiceClient::with_interceptor(conn, middleware_interceptor);
34
35    let mut descriptor = FlightDescriptor::default();
36    descriptor.set_type(DescriptorType::Cmd);
37    descriptor.cmd = Bytes::from_static(b"");
38
39    // This call is expected to fail.
40    match client
41        .get_flight_info(Request::new(descriptor.clone()))
42        .await
43    {
44        Ok(_) => return Err(Box::new(Status::internal("Expected call to fail"))),
45        Err(e) => {
46            let headers = e.metadata();
47            let middleware_header = headers.get("x-middleware");
48            let value = middleware_header.map(|v| v.to_str().unwrap()).unwrap_or("");
49
50            if value != "expected value" {
51                let msg = format!(
52                    "On failing call: Expected to receive header 'x-middleware: expected value', \
53                     but instead got: '{value}'"
54                );
55                return Err(Box::new(Status::internal(msg)));
56            }
57        }
58    }
59
60    // This call should succeed
61    descriptor.cmd = Bytes::from_static(b"success");
62    let resp = client.get_flight_info(Request::new(descriptor)).await?;
63
64    let headers = resp.metadata();
65    let middleware_header = headers.get("x-middleware");
66    let value = middleware_header.map(|v| v.to_str().unwrap()).unwrap_or("");
67
68    if value != "expected value" {
69        let msg = format!(
70            "On success call: Expected to receive header 'x-middleware: expected value', \
71            but instead got: '{value}'"
72        );
73        return Err(Box::new(Status::internal(msg)));
74    }
75
76    Ok(())
77}
78
79#[allow(clippy::unnecessary_wraps)]
80fn middleware_interceptor(mut req: Request<()>) -> Result<Request<()>, Status> {
81    let metadata = req.metadata_mut();
82    metadata.insert("x-middleware", "expected value".parse().unwrap());
83    Ok(req)
84}