arrow_integration_testing/flight_client_scenarios/
auth_basic_proto.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Scenario for testing basic auth.
19
20use crate::{AUTH_PASSWORD, AUTH_USERNAME};
21
22use arrow_flight::{flight_service_client::FlightServiceClient, BasicAuth, HandshakeRequest};
23use futures::{stream, StreamExt};
24use prost::Message;
25use tonic::{metadata::MetadataValue, transport::Endpoint, Request, Status};
26
27type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
28type Result<T = (), E = Error> = std::result::Result<T, E>;
29
30type Client = FlightServiceClient<tonic::transport::Channel>;
31
32/// Run a scenario that tests basic auth.
33pub async fn run_scenario(host: &str, port: u16) -> Result {
34    let url = format!("http://{host}:{port}");
35    let endpoint = Endpoint::new(url)?;
36    let channel = endpoint.connect().await?;
37    let mut client = FlightServiceClient::new(channel);
38
39    let action = arrow_flight::Action::default();
40
41    let resp = client.do_action(Request::new(action.clone())).await;
42    // This client is unauthenticated and should fail.
43    match resp {
44        Err(e) => {
45            if e.code() != tonic::Code::Unauthenticated {
46                return Err(Box::new(Status::internal(format!(
47                    "Expected UNAUTHENTICATED but got {e:?}"
48                ))));
49            }
50        }
51        Ok(other) => {
52            return Err(Box::new(Status::internal(format!(
53                "Expected UNAUTHENTICATED but got {other:?}"
54            ))));
55        }
56    }
57
58    let token = authenticate(&mut client, AUTH_USERNAME, AUTH_PASSWORD)
59        .await
60        .expect("must respond successfully from handshake");
61
62    let mut request = Request::new(action);
63    let metadata = request.metadata_mut();
64    metadata.insert_bin(
65        "auth-token-bin",
66        MetadataValue::from_bytes(token.as_bytes()),
67    );
68
69    let resp = client.do_action(request).await?;
70    let mut resp = resp.into_inner();
71
72    let r = resp
73        .next()
74        .await
75        .expect("No response received")
76        .expect("Invalid response received");
77
78    let body = std::str::from_utf8(&r.body).unwrap();
79    assert_eq!(body, AUTH_USERNAME);
80
81    Ok(())
82}
83
84async fn authenticate(client: &mut Client, username: &str, password: &str) -> Result<String> {
85    let auth = BasicAuth {
86        username: username.into(),
87        password: password.into(),
88    };
89    let mut payload = vec![];
90    auth.encode(&mut payload)?;
91
92    let req = stream::once(async {
93        HandshakeRequest {
94            payload: payload.into(),
95            ..HandshakeRequest::default()
96        }
97    });
98
99    let rx = client.handshake(Request::new(req)).await?;
100    let mut rx = rx.into_inner();
101
102    let r = rx.next().await.expect("must respond from handshake")?;
103    assert!(rx.next().await.is_none(), "must not respond a second time");
104
105    Ok(std::str::from_utf8(&r.payload).unwrap().into())
106}