arrow_integration_testing/flight_server_scenarios/
auth_basic_proto.rs1use std::pin::Pin;
21use std::sync::Arc;
22
23use arrow_flight::{
24 flight_service_server::FlightService, flight_service_server::FlightServiceServer, Action,
25 ActionType, BasicAuth, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
26 HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
27};
28use futures::{channel::mpsc, sink::SinkExt, Stream, StreamExt};
29use tokio::sync::Mutex;
30use tonic::{metadata::MetadataMap, transport::Server, Request, Response, Status, Streaming};
31type TonicStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync + 'static>>;
32
33type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
34type Result<T = (), E = Error> = std::result::Result<T, E>;
35
36use prost::Message;
37
38use crate::{AUTH_PASSWORD, AUTH_USERNAME};
39
40pub async fn scenario_setup(port: u16) -> Result {
42 let service = AuthBasicProtoScenarioImpl {
43 username: AUTH_USERNAME.into(),
44 password: AUTH_PASSWORD.into(),
45 peer_identity: Arc::new(Mutex::new(None)),
46 };
47 let addr = super::listen_on(port).await?;
48 let svc = FlightServiceServer::new(service);
49
50 let server = Server::builder().add_service(svc).serve(addr);
51
52 println!("Server listening on localhost:{}", addr.port());
54 server.await?;
55 Ok(())
56}
57
58#[derive(Clone)]
60pub struct AuthBasicProtoScenarioImpl {
61 username: Arc<str>,
62 password: Arc<str>,
63 #[allow(dead_code)]
64 peer_identity: Arc<Mutex<Option<String>>>,
65}
66
67impl AuthBasicProtoScenarioImpl {
68 async fn check_auth(&self, metadata: &MetadataMap) -> Result<GrpcServerCallContext, Status> {
69 let token = metadata
70 .get_bin("auth-token-bin")
71 .and_then(|v| v.to_bytes().ok())
72 .and_then(|b| String::from_utf8(b.to_vec()).ok());
73 self.is_valid(token).await
74 }
75
76 async fn is_valid(&self, token: Option<String>) -> Result<GrpcServerCallContext, Status> {
77 match token {
78 Some(t) if t == *self.username => Ok(GrpcServerCallContext {
79 peer_identity: self.username.to_string(),
80 }),
81 _ => Err(Status::unauthenticated("Invalid token")),
82 }
83 }
84}
85
/// Result of a successful token check: who the authenticated caller is.
struct GrpcServerCallContext {
    /// Caller identity. In this scenario it is always the configured username.
    peer_identity: String,
}
89
90impl GrpcServerCallContext {
91 pub fn peer_identity(&self) -> &str {
92 &self.peer_identity
93 }
94}
95
96#[tonic::async_trait]
97impl FlightService for AuthBasicProtoScenarioImpl {
98 type HandshakeStream = TonicStream<Result<HandshakeResponse, Status>>;
99 type ListFlightsStream = TonicStream<Result<FlightInfo, Status>>;
100 type DoGetStream = TonicStream<Result<FlightData, Status>>;
101 type DoPutStream = TonicStream<Result<PutResult, Status>>;
102 type DoActionStream = TonicStream<Result<arrow_flight::Result, Status>>;
103 type ListActionsStream = TonicStream<Result<ActionType, Status>>;
104 type DoExchangeStream = TonicStream<Result<FlightData, Status>>;
105
106 async fn get_schema(
107 &self,
108 request: Request<FlightDescriptor>,
109 ) -> Result<Response<SchemaResult>, Status> {
110 self.check_auth(request.metadata()).await?;
111 Err(Status::unimplemented("Not yet implemented"))
112 }
113
114 async fn do_get(
115 &self,
116 request: Request<Ticket>,
117 ) -> Result<Response<Self::DoGetStream>, Status> {
118 self.check_auth(request.metadata()).await?;
119 Err(Status::unimplemented("Not yet implemented"))
120 }
121
122 async fn handshake(
123 &self,
124 request: Request<Streaming<HandshakeRequest>>,
125 ) -> Result<Response<Self::HandshakeStream>, Status> {
126 let (tx, rx) = mpsc::channel(10);
127
128 tokio::spawn({
129 let username = self.username.clone();
130 let password = self.password.clone();
131
132 async move {
133 let requests = request.into_inner();
134
135 requests
136 .for_each(move |req| {
137 let mut tx = tx.clone();
138 let req = req.expect("Error reading handshake request");
139 let HandshakeRequest { payload, .. } = req;
140
141 let auth =
142 BasicAuth::decode(&*payload).expect("Error parsing handshake request");
143
144 let resp = if *auth.username == *username && *auth.password == *password {
145 Ok(HandshakeResponse {
146 payload: username.as_bytes().to_vec().into(),
147 ..HandshakeResponse::default()
148 })
149 } else {
150 Err(Status::unauthenticated(format!(
151 "Don't know user {}",
152 auth.username
153 )))
154 };
155
156 async move {
157 tx.send(resp)
158 .await
159 .expect("Error sending handshake response");
160 }
161 })
162 .await;
163 }
164 });
165
166 Ok(Response::new(Box::pin(rx)))
167 }
168
169 async fn list_flights(
170 &self,
171 request: Request<Criteria>,
172 ) -> Result<Response<Self::ListFlightsStream>, Status> {
173 self.check_auth(request.metadata()).await?;
174 Err(Status::unimplemented("Not yet implemented"))
175 }
176
177 async fn get_flight_info(
178 &self,
179 request: Request<FlightDescriptor>,
180 ) -> Result<Response<FlightInfo>, Status> {
181 self.check_auth(request.metadata()).await?;
182 Err(Status::unimplemented("Not yet implemented"))
183 }
184
185 async fn poll_flight_info(
186 &self,
187 request: Request<FlightDescriptor>,
188 ) -> Result<Response<PollInfo>, Status> {
189 self.check_auth(request.metadata()).await?;
190 Err(Status::unimplemented("Not yet implemented"))
191 }
192
193 async fn do_put(
194 &self,
195 request: Request<Streaming<FlightData>>,
196 ) -> Result<Response<Self::DoPutStream>, Status> {
197 let metadata = request.metadata();
198 self.check_auth(metadata).await?;
199 Err(Status::unimplemented("Not yet implemented"))
200 }
201
202 async fn do_action(
203 &self,
204 request: Request<Action>,
205 ) -> Result<Response<Self::DoActionStream>, Status> {
206 let flight_context = self.check_auth(request.metadata()).await?;
207 let buf = flight_context.peer_identity().as_bytes().to_vec().into();
209 let result = arrow_flight::Result { body: buf };
210 let output = futures::stream::once(async { Ok(result) });
211 Ok(Response::new(Box::pin(output) as Self::DoActionStream))
212 }
213
214 async fn list_actions(
215 &self,
216 request: Request<Empty>,
217 ) -> Result<Response<Self::ListActionsStream>, Status> {
218 self.check_auth(request.metadata()).await?;
219 Err(Status::unimplemented("Not yet implemented"))
220 }
221
222 async fn do_exchange(
223 &self,
224 request: Request<Streaming<FlightData>>,
225 ) -> Result<Response<Self::DoExchangeStream>, Status> {
226 let metadata = request.metadata();
227 self.check_auth(metadata).await?;
228 Err(Status::unimplemented("Not yet implemented"))
229 }
230}