arrow_integration_testing/flight_server_scenarios/
integration_test.rs1use core::str;
21use std::collections::HashMap;
22use std::pin::Pin;
23use std::sync::Arc;
24
25use arrow::{
26 array::ArrayRef,
27 buffer::Buffer,
28 datatypes::Schema,
29 datatypes::SchemaRef,
30 ipc::{self, reader, writer},
31 record_batch::RecordBatch,
32};
33use arrow_flight::{
34 flight_descriptor::DescriptorType, flight_service_server::FlightService,
35 flight_service_server::FlightServiceServer, Action, ActionType, Criteria, Empty, FlightData,
36 FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse, IpcMessage,
37 PollInfo, PutResult, SchemaAsIpc, SchemaResult, Ticket,
38};
39use futures::{channel::mpsc, sink::SinkExt, Stream, StreamExt};
40use tokio::sync::Mutex;
41use tonic::{transport::Server, Request, Response, Status, Streaming};
42
43type TonicStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync + 'static>>;
44
45type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
46type Result<T = (), E = Error> = std::result::Result<T, E>;
47
48pub async fn scenario_setup(port: u16) -> Result {
50 let addr = super::listen_on(port).await?;
51 let resolved_port = addr.port();
52
53 let service = FlightServiceImpl {
54 server_location: format!("grpc+tcp://localhost:{resolved_port}"),
58 ..Default::default()
59 };
60 let svc = FlightServiceServer::new(service);
61
62 let server = Server::builder().add_service(svc).serve(addr);
63
64 println!("Server listening on localhost:{}", addr.port());
66 server.await?;
67 Ok(())
68}
69
70#[derive(Debug, Clone)]
71struct IntegrationDataset {
72 schema: Schema,
73 chunks: Vec<RecordBatch>,
74}
75
76#[derive(Clone, Default)]
78pub struct FlightServiceImpl {
79 server_location: String,
80 uploaded_chunks: Arc<Mutex<HashMap<String, IntegrationDataset>>>,
81}
82
83impl FlightServiceImpl {
84 fn endpoint_from_path(&self, path: &str) -> FlightEndpoint {
85 super::endpoint(path, &self.server_location)
86 }
87}
88
89#[tonic::async_trait]
90impl FlightService for FlightServiceImpl {
91 type HandshakeStream = TonicStream<Result<HandshakeResponse, Status>>;
92 type ListFlightsStream = TonicStream<Result<FlightInfo, Status>>;
93 type DoGetStream = TonicStream<Result<FlightData, Status>>;
94 type DoPutStream = TonicStream<Result<PutResult, Status>>;
95 type DoActionStream = TonicStream<Result<arrow_flight::Result, Status>>;
96 type ListActionsStream = TonicStream<Result<ActionType, Status>>;
97 type DoExchangeStream = TonicStream<Result<FlightData, Status>>;
98
99 async fn get_schema(
100 &self,
101 _request: Request<FlightDescriptor>,
102 ) -> Result<Response<SchemaResult>, Status> {
103 Err(Status::unimplemented("Not yet implemented"))
104 }
105
106 async fn do_get(
107 &self,
108 request: Request<Ticket>,
109 ) -> Result<Response<Self::DoGetStream>, Status> {
110 let ticket = request.into_inner();
111
112 let key = str::from_utf8(&ticket.ticket)
113 .map_err(|e| Status::invalid_argument(format!("Invalid ticket: {e:?}")))?;
114
115 let uploaded_chunks = self.uploaded_chunks.lock().await;
116
117 let flight = uploaded_chunks
118 .get(key)
119 .ok_or_else(|| Status::not_found(format!("Could not find flight. {key}")))?;
120
121 let options = arrow::ipc::writer::IpcWriteOptions::default();
122 #[allow(deprecated)]
123 let mut dictionary_tracker =
124 writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id());
125 let data_gen = writer::IpcDataGenerator::default();
126 let data = IpcMessage(
127 data_gen
128 .schema_to_bytes_with_dictionary_tracker(
129 &flight.schema,
130 &mut dictionary_tracker,
131 &options,
132 )
133 .ipc_message
134 .into(),
135 );
136 let schema_flight_data = FlightData {
137 data_header: data.0,
138 ..Default::default()
139 };
140
141 let schema = std::iter::once(Ok(schema_flight_data));
142
143 let batches = flight
144 .chunks
145 .iter()
146 .enumerate()
147 .flat_map(|(counter, batch)| {
148 let (encoded_dictionaries, encoded_batch) = data_gen
149 .encoded_batch(batch, &mut dictionary_tracker, &options)
150 .expect("DictionaryTracker configured above to not error on replacement");
151
152 let dictionary_flight_data = encoded_dictionaries.into_iter().map(Into::into);
153 let mut batch_flight_data: FlightData = encoded_batch.into();
154
155 let metadata = counter.to_string().into();
157 batch_flight_data.app_metadata = metadata;
158
159 dictionary_flight_data
160 .chain(std::iter::once(batch_flight_data))
161 .map(Ok)
162 });
163
164 let output = futures::stream::iter(schema.chain(batches).collect::<Vec<_>>());
165
166 Ok(Response::new(Box::pin(output) as Self::DoGetStream))
167 }
168
169 async fn handshake(
170 &self,
171 _request: Request<Streaming<HandshakeRequest>>,
172 ) -> Result<Response<Self::HandshakeStream>, Status> {
173 Err(Status::unimplemented("Not yet implemented"))
174 }
175
176 async fn list_flights(
177 &self,
178 _request: Request<Criteria>,
179 ) -> Result<Response<Self::ListFlightsStream>, Status> {
180 Err(Status::unimplemented("Not yet implemented"))
181 }
182
183 async fn get_flight_info(
184 &self,
185 request: Request<FlightDescriptor>,
186 ) -> Result<Response<FlightInfo>, Status> {
187 let descriptor = request.into_inner();
188
189 match descriptor.r#type {
190 t if t == DescriptorType::Path as i32 => {
191 let path = &descriptor.path;
192 if path.is_empty() {
193 return Err(Status::invalid_argument("Invalid path"));
194 }
195
196 let uploaded_chunks = self.uploaded_chunks.lock().await;
197 let flight = uploaded_chunks.get(&path[0]).ok_or_else(|| {
198 Status::not_found(format!("Could not find flight. {}", path[0]))
199 })?;
200
201 let endpoint = self.endpoint_from_path(&path[0]);
202
203 let total_records: usize = flight.chunks.iter().map(|chunk| chunk.num_rows()).sum();
204
205 let options = arrow::ipc::writer::IpcWriteOptions::default();
206 let message = SchemaAsIpc::new(&flight.schema, &options)
207 .try_into()
208 .expect(
209 "Could not generate schema bytes from schema stored by a DoPut; \
210 this should be impossible",
211 );
212 let IpcMessage(schema) = message;
213
214 let info = FlightInfo {
215 schema,
216 flight_descriptor: Some(descriptor.clone()),
217 endpoint: vec![endpoint],
218 total_records: total_records as i64,
219 total_bytes: -1,
220 ordered: false,
221 app_metadata: vec![].into(),
222 };
223
224 Ok(Response::new(info))
225 }
226 other => Err(Status::unimplemented(format!("Request type: {other}"))),
227 }
228 }
229
230 async fn poll_flight_info(
231 &self,
232 _request: Request<FlightDescriptor>,
233 ) -> Result<Response<PollInfo>, Status> {
234 Err(Status::unimplemented("Not yet implemented"))
235 }
236
237 async fn do_put(
238 &self,
239 request: Request<Streaming<FlightData>>,
240 ) -> Result<Response<Self::DoPutStream>, Status> {
241 let mut input_stream = request.into_inner();
242 let flight_data = input_stream
243 .message()
244 .await?
245 .ok_or_else(|| Status::invalid_argument("Must send some FlightData"))?;
246
247 let descriptor = flight_data
248 .flight_descriptor
249 .clone()
250 .ok_or_else(|| Status::invalid_argument("Must have a descriptor"))?;
251
252 if descriptor.r#type != DescriptorType::Path as i32 || descriptor.path.is_empty() {
253 return Err(Status::invalid_argument("Must specify a path"));
254 }
255
256 let key = descriptor.path[0].clone();
257
258 let schema = Schema::try_from(&flight_data)
259 .map_err(|e| Status::invalid_argument(format!("Invalid schema: {e:?}")))?;
260 let schema_ref = Arc::new(schema.clone());
261
262 let (response_tx, response_rx) = mpsc::channel(10);
263
264 let uploaded_chunks = self.uploaded_chunks.clone();
265
266 tokio::spawn(async {
267 let mut error_tx = response_tx.clone();
268 if let Err(e) = save_uploaded_chunks(
269 uploaded_chunks,
270 schema_ref,
271 input_stream,
272 response_tx,
273 schema,
274 key,
275 )
276 .await
277 {
278 error_tx.send(Err(e)).await.expect("Error sending error")
279 }
280 });
281
282 Ok(Response::new(Box::pin(response_rx) as Self::DoPutStream))
283 }
284
285 async fn do_action(
286 &self,
287 _request: Request<Action>,
288 ) -> Result<Response<Self::DoActionStream>, Status> {
289 Err(Status::unimplemented("Not yet implemented"))
290 }
291
292 async fn list_actions(
293 &self,
294 _request: Request<Empty>,
295 ) -> Result<Response<Self::ListActionsStream>, Status> {
296 Err(Status::unimplemented("Not yet implemented"))
297 }
298
299 async fn do_exchange(
300 &self,
301 _request: Request<Streaming<FlightData>>,
302 ) -> Result<Response<Self::DoExchangeStream>, Status> {
303 Err(Status::unimplemented("Not yet implemented"))
304 }
305}
306
307async fn send_app_metadata(
308 tx: &mut mpsc::Sender<Result<PutResult, Status>>,
309 app_metadata: &[u8],
310) -> Result<(), Status> {
311 tx.send(Ok(PutResult {
312 app_metadata: app_metadata.to_vec().into(),
313 }))
314 .await
315 .map_err(|e| Status::internal(format!("Could not send PutResult: {e:?}")))
316}
317
318async fn record_batch_from_message(
319 message: ipc::Message<'_>,
320 data_body: &Buffer,
321 schema_ref: SchemaRef,
322 dictionaries_by_id: &HashMap<i64, ArrayRef>,
323) -> Result<RecordBatch, Status> {
324 let ipc_batch = message
325 .header_as_record_batch()
326 .ok_or_else(|| Status::internal("Could not parse message header as record batch"))?;
327
328 let arrow_batch_result = reader::read_record_batch(
329 data_body,
330 ipc_batch,
331 schema_ref,
332 dictionaries_by_id,
333 None,
334 &message.version(),
335 );
336
337 arrow_batch_result
338 .map_err(|e| Status::internal(format!("Could not convert to RecordBatch: {e:?}")))
339}
340
341async fn dictionary_from_message(
342 message: ipc::Message<'_>,
343 data_body: &Buffer,
344 schema_ref: SchemaRef,
345 dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
346) -> Result<(), Status> {
347 let ipc_batch = message
348 .header_as_dictionary_batch()
349 .ok_or_else(|| Status::internal("Could not parse message header as dictionary batch"))?;
350
351 let dictionary_batch_result = reader::read_dictionary(
352 data_body,
353 ipc_batch,
354 &schema_ref,
355 dictionaries_by_id,
356 &message.version(),
357 );
358 dictionary_batch_result
359 .map_err(|e| Status::internal(format!("Could not convert to Dictionary: {e:?}")))
360}
361
362async fn save_uploaded_chunks(
363 uploaded_chunks: Arc<Mutex<HashMap<String, IntegrationDataset>>>,
364 schema_ref: Arc<Schema>,
365 mut input_stream: Streaming<FlightData>,
366 mut response_tx: mpsc::Sender<Result<PutResult, Status>>,
367 schema: Schema,
368 key: String,
369) -> Result<(), Status> {
370 let mut chunks = vec![];
371 let mut uploaded_chunks = uploaded_chunks.lock().await;
372
373 let mut dictionaries_by_id = HashMap::new();
374
375 while let Some(Ok(data)) = input_stream.next().await {
376 let message = arrow::ipc::root_as_message(&data.data_header[..])
377 .map_err(|e| Status::internal(format!("Could not parse message: {e:?}")))?;
378
379 match message.header_type() {
380 ipc::MessageHeader::Schema => {
381 return Err(Status::internal(
382 "Not expecting a schema when messages are read",
383 ))
384 }
385 ipc::MessageHeader::RecordBatch => {
386 send_app_metadata(&mut response_tx, &data.app_metadata).await?;
387
388 let batch = record_batch_from_message(
389 message,
390 &Buffer::from(data.data_body.as_ref()),
391 schema_ref.clone(),
392 &dictionaries_by_id,
393 )
394 .await?;
395
396 chunks.push(batch);
397 }
398 ipc::MessageHeader::DictionaryBatch => {
399 dictionary_from_message(
400 message,
401 &Buffer::from(data.data_body.as_ref()),
402 schema_ref.clone(),
403 &mut dictionaries_by_id,
404 )
405 .await?;
406 }
407 t => {
408 return Err(Status::internal(format!(
409 "Reading types other than record batches not yet supported, \
410 unable to read {t:?}"
411 )));
412 }
413 }
414 }
415
416 let dataset = IntegrationDataset { schema, chunks };
417 uploaded_chunks.insert(key, dataset);
418
419 Ok(())
420}