arrow_integration_testing/flight_server_scenarios/
auth_basic_proto.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Basic auth test for the Flight server.
19
20use std::pin::Pin;
21use std::sync::Arc;
22
23use arrow_flight::{
24    flight_service_server::FlightService, flight_service_server::FlightServiceServer, Action,
25    ActionType, BasicAuth, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
26    HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
27};
28use futures::{channel::mpsc, sink::SinkExt, Stream, StreamExt};
29use tokio::sync::Mutex;
30use tonic::{metadata::MetadataMap, transport::Server, Request, Response, Status, Streaming};
/// Pinned, boxed, type-erased stream (`Send + Sync + 'static`); used in this file as the
/// associated stream types of the `FlightService` implementation.
31type TonicStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync + 'static>>;
32
/// Boxed, thread-safe error type used as the default error of this module's `Result` alias.
33type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
/// `Result` alias whose success type defaults to `()` and whose error type defaults to `Error`.
34type Result<T = (), E = Error> = std::result::Result<T, E>;
35
36use prost::Message;
37
38use crate::{AUTH_PASSWORD, AUTH_USERNAME};
39
40/// Run a scenario that tests basic auth.
41pub async fn scenario_setup(port: u16) -> Result {
42    let service = AuthBasicProtoScenarioImpl {
43        username: AUTH_USERNAME.into(),
44        password: AUTH_PASSWORD.into(),
45        peer_identity: Arc::new(Mutex::new(None)),
46    };
47    let addr = super::listen_on(port).await?;
48    let svc = FlightServiceServer::new(service);
49
50    let server = Server::builder().add_service(svc).serve(addr);
51
52    // NOTE: Log output used in tests to signal server is ready
53    println!("Server listening on localhost:{}", addr.port());
54    server.await?;
55    Ok(())
56}
57
58/// Scenario for testing basic auth.
59#[derive(Clone)]
60pub struct AuthBasicProtoScenarioImpl {
61    username: Arc<str>,
62    password: Arc<str>,
63    #[allow(dead_code)]
64    peer_identity: Arc<Mutex<Option<String>>>,
65}
66
67impl AuthBasicProtoScenarioImpl {
68    async fn check_auth(&self, metadata: &MetadataMap) -> Result<GrpcServerCallContext, Status> {
69        let token = metadata
70            .get_bin("auth-token-bin")
71            .and_then(|v| v.to_bytes().ok())
72            .and_then(|b| String::from_utf8(b.to_vec()).ok());
73        self.is_valid(token).await
74    }
75
76    async fn is_valid(&self, token: Option<String>) -> Result<GrpcServerCallContext, Status> {
77        match token {
78            Some(t) if t == *self.username => Ok(GrpcServerCallContext {
79                peer_identity: self.username.to_string(),
80            }),
81            _ => Err(Status::unauthenticated("Invalid token")),
82        }
83    }
84}
85
/// Per-call context produced by a successful authentication check.
struct GrpcServerCallContext {
    /// Identity of the authenticated caller.
    peer_identity: String,
}

impl GrpcServerCallContext {
    /// Borrows the identity of the authenticated caller.
    pub fn peer_identity(&self) -> &str {
        self.peer_identity.as_str()
    }
}
95
96#[tonic::async_trait]
97impl FlightService for AuthBasicProtoScenarioImpl {
98    type HandshakeStream = TonicStream<Result<HandshakeResponse, Status>>;
99    type ListFlightsStream = TonicStream<Result<FlightInfo, Status>>;
100    type DoGetStream = TonicStream<Result<FlightData, Status>>;
101    type DoPutStream = TonicStream<Result<PutResult, Status>>;
102    type DoActionStream = TonicStream<Result<arrow_flight::Result, Status>>;
103    type ListActionsStream = TonicStream<Result<ActionType, Status>>;
104    type DoExchangeStream = TonicStream<Result<FlightData, Status>>;
105
106    async fn get_schema(
107        &self,
108        request: Request<FlightDescriptor>,
109    ) -> Result<Response<SchemaResult>, Status> {
110        self.check_auth(request.metadata()).await?;
111        Err(Status::unimplemented("Not yet implemented"))
112    }
113
114    async fn do_get(
115        &self,
116        request: Request<Ticket>,
117    ) -> Result<Response<Self::DoGetStream>, Status> {
118        self.check_auth(request.metadata()).await?;
119        Err(Status::unimplemented("Not yet implemented"))
120    }
121
122    async fn handshake(
123        &self,
124        request: Request<Streaming<HandshakeRequest>>,
125    ) -> Result<Response<Self::HandshakeStream>, Status> {
126        let (tx, rx) = mpsc::channel(10);
127
128        tokio::spawn({
129            let username = self.username.clone();
130            let password = self.password.clone();
131
132            async move {
133                let requests = request.into_inner();
134
135                requests
136                    .for_each(move |req| {
137                        let mut tx = tx.clone();
138                        let req = req.expect("Error reading handshake request");
139                        let HandshakeRequest { payload, .. } = req;
140
141                        let auth =
142                            BasicAuth::decode(&*payload).expect("Error parsing handshake request");
143
144                        let resp = if *auth.username == *username && *auth.password == *password {
145                            Ok(HandshakeResponse {
146                                payload: username.as_bytes().to_vec().into(),
147                                ..HandshakeResponse::default()
148                            })
149                        } else {
150                            Err(Status::unauthenticated(format!(
151                                "Don't know user {}",
152                                auth.username
153                            )))
154                        };
155
156                        async move {
157                            tx.send(resp)
158                                .await
159                                .expect("Error sending handshake response");
160                        }
161                    })
162                    .await;
163            }
164        });
165
166        Ok(Response::new(Box::pin(rx)))
167    }
168
169    async fn list_flights(
170        &self,
171        request: Request<Criteria>,
172    ) -> Result<Response<Self::ListFlightsStream>, Status> {
173        self.check_auth(request.metadata()).await?;
174        Err(Status::unimplemented("Not yet implemented"))
175    }
176
177    async fn get_flight_info(
178        &self,
179        request: Request<FlightDescriptor>,
180    ) -> Result<Response<FlightInfo>, Status> {
181        self.check_auth(request.metadata()).await?;
182        Err(Status::unimplemented("Not yet implemented"))
183    }
184
185    async fn poll_flight_info(
186        &self,
187        request: Request<FlightDescriptor>,
188    ) -> Result<Response<PollInfo>, Status> {
189        self.check_auth(request.metadata()).await?;
190        Err(Status::unimplemented("Not yet implemented"))
191    }
192
193    async fn do_put(
194        &self,
195        request: Request<Streaming<FlightData>>,
196    ) -> Result<Response<Self::DoPutStream>, Status> {
197        let metadata = request.metadata();
198        self.check_auth(metadata).await?;
199        Err(Status::unimplemented("Not yet implemented"))
200    }
201
202    async fn do_action(
203        &self,
204        request: Request<Action>,
205    ) -> Result<Response<Self::DoActionStream>, Status> {
206        let flight_context = self.check_auth(request.metadata()).await?;
207        // Respond with the authenticated username.
208        let buf = flight_context.peer_identity().as_bytes().to_vec().into();
209        let result = arrow_flight::Result { body: buf };
210        let output = futures::stream::once(async { Ok(result) });
211        Ok(Response::new(Box::pin(output) as Self::DoActionStream))
212    }
213
214    async fn list_actions(
215        &self,
216        request: Request<Empty>,
217    ) -> Result<Response<Self::ListActionsStream>, Status> {
218        self.check_auth(request.metadata()).await?;
219        Err(Status::unimplemented("Not yet implemented"))
220    }
221
222    async fn do_exchange(
223        &self,
224        request: Request<Streaming<FlightData>>,
225    ) -> Result<Response<Self::DoExchangeStream>, Status> {
226        let metadata = request.metadata();
227        self.check_auth(metadata).await?;
228        Err(Status::unimplemented("Not yet implemented"))
229    }
230}