arrow_integration_testing/flight_server_scenarios/integration_test.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Integration tests for the Flight server.

use core::str;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;

use arrow::{
    array::ArrayRef,
    buffer::Buffer,
    datatypes::Schema,
    datatypes::SchemaRef,
    ipc::{self, reader, writer},
    record_batch::RecordBatch,
};
use arrow_flight::{
    flight_descriptor::DescriptorType, flight_service_server::FlightService,
    flight_service_server::FlightServiceServer, Action, ActionType, Criteria, Empty, FlightData,
    FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse, IpcMessage,
    PollInfo, PutResult, SchemaAsIpc, SchemaResult, Ticket,
};
use futures::{channel::mpsc, sink::SinkExt, Stream, StreamExt};
use tokio::sync::Mutex;
use tonic::{transport::Server, Request, Response, Status, Streaming};

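// All tonic response streams returned by this service share this shape: a
// pinned, boxed stream of `Result` items that can be driven from any thread.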
type TonicStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync + 'static>>;

type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
type Result<T = (), E = Error> = std::result::Result<T, E>;

/// Run the server side of the Flight integration-test scenario, listening on the given port.
pub async fn scenario_setup(port: u16) -> Result {
    let addr = super::listen_on(port).await?;
    let resolved_port = addr.port();

    let service = FlightServiceImpl {
        // See https://github.com/apache/arrow-rs/issues/6577:
        // C# had trouble resolving addresses like 0.0.0.0:port
        // server_location: format!("grpc+tcp://{addr}"),
        server_location: format!("grpc+tcp://localhost:{resolved_port}"),
        ..Default::default()
    };
    let svc = FlightServiceServer::new(service);

    let server = Server::builder().add_service(svc).serve(addr);

    // NOTE: Log output used in tests to signal server is ready
    println!("Server listening on localhost:{}", addr.port());
    server.await?;
    Ok(())
}
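
// A minimal usage sketch (hypothetical caller; the binary that drives these
// scenarios is not shown here). Passing port 0 lets the OS choose a free
// port, which `scenario_setup` then reports on stdout:
//
//     scenario_setup(0).await?;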

#[derive(Debug, Clone)]
struct IntegrationDataset {
    schema: Schema,
    chunks: Vec<RecordBatch>,
}

/// Flight service implementation for integration testing
#[derive(Clone, Default)]
pub struct FlightServiceImpl {
    server_location: String,
    uploaded_chunks: Arc<Mutex<HashMap<String, IntegrationDataset>>>,
}

impl FlightServiceImpl {
    fn endpoint_from_path(&self, path: &str) -> FlightEndpoint {
        super::endpoint(path, &self.server_location)
    }
}

#[tonic::async_trait]
impl FlightService for FlightServiceImpl {
    type HandshakeStream = TonicStream<Result<HandshakeResponse, Status>>;
    type ListFlightsStream = TonicStream<Result<FlightInfo, Status>>;
    type DoGetStream = TonicStream<Result<FlightData, Status>>;
    type DoPutStream = TonicStream<Result<PutResult, Status>>;
    type DoActionStream = TonicStream<Result<arrow_flight::Result, Status>>;
    type ListActionsStream = TonicStream<Result<ActionType, Status>>;
    type DoExchangeStream = TonicStream<Result<FlightData, Status>>;

    async fn get_schema(
        &self,
        _request: Request<FlightDescriptor>,
    ) -> Result<Response<SchemaResult>, Status> {
        Err(Status::unimplemented("Not yet implemented"))
    }

    async fn do_get(
        &self,
        request: Request<Ticket>,
    ) -> Result<Response<Self::DoGetStream>, Status> {
        let ticket = request.into_inner();

        let key = str::from_utf8(&ticket.ticket)
            .map_err(|e| Status::invalid_argument(format!("Invalid ticket: {e:?}")))?;

        let uploaded_chunks = self.uploaded_chunks.lock().await;

        let flight = uploaded_chunks
            .get(key)
            .ok_or_else(|| Status::not_found(format!("Could not find flight: {key}")))?;

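        // Per the Flight protocol, a DoGet stream begins with the schema,
        // encoded as an IPC schema message carried in `data_header`.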
        let options = arrow::ipc::writer::IpcWriteOptions::default();
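        // The tracker is created with `error_on_replacement: false`, which the
        // `expect` on `encoded_batch` below relies on. The `preserve_dict_id`
        // flag is forwarded from the write options; both APIs are deprecated
        // upstream, hence the `#[allow(deprecated)]`.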
        #[allow(deprecated)]
        let mut dictionary_tracker =
            writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id());
        let data_gen = writer::IpcDataGenerator::default();
        let data = IpcMessage(
            data_gen
                .schema_to_bytes_with_dictionary_tracker(
                    &flight.schema,
                    &mut dictionary_tracker,
                    &options,
                )
                .ipc_message
                .into(),
        );
        let schema_flight_data = FlightData {
            data_header: data.0,
            ..Default::default()
        };

        let schema = std::iter::once(Ok(schema_flight_data));

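        // Each uploaded chunk becomes one or more FlightData messages: any
        // dictionary batches it depends on are sent first, then the record
        // batch itself, tagged with its chunk index as app_metadata.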
        let batches = flight
            .chunks
            .iter()
            .enumerate()
            .flat_map(|(counter, batch)| {
                let (encoded_dictionaries, encoded_batch) = data_gen
                    .encoded_batch(batch, &mut dictionary_tracker, &options)
                    .expect("DictionaryTracker configured above to not error on replacement");

                let dictionary_flight_data = encoded_dictionaries.into_iter().map(Into::into);
                let mut batch_flight_data: FlightData = encoded_batch.into();

                // Only the record batch's FlightData gets app_metadata
                let metadata = counter.to_string().into();
                batch_flight_data.app_metadata = metadata;

                dictionary_flight_data
                    .chain(std::iter::once(batch_flight_data))
                    .map(Ok)
            });

        let output = futures::stream::iter(schema.chain(batches).collect::<Vec<_>>());

        Ok(Response::new(Box::pin(output) as Self::DoGetStream))
    }

    async fn handshake(
        &self,
        _request: Request<Streaming<HandshakeRequest>>,
    ) -> Result<Response<Self::HandshakeStream>, Status> {
        Err(Status::unimplemented("Not yet implemented"))
    }

    async fn list_flights(
        &self,
        _request: Request<Criteria>,
    ) -> Result<Response<Self::ListFlightsStream>, Status> {
        Err(Status::unimplemented("Not yet implemented"))
    }

    async fn get_flight_info(
        &self,
        request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        let descriptor = request.into_inner();

        match descriptor.r#type {
            t if t == DescriptorType::Path as i32 => {
                let path = &descriptor.path;
                if path.is_empty() {
                    return Err(Status::invalid_argument("Invalid path"));
                }

                let uploaded_chunks = self.uploaded_chunks.lock().await;
                let flight = uploaded_chunks.get(&path[0]).ok_or_else(|| {
                    Status::not_found(format!("Could not find flight: {}", path[0]))
                })?;

                let endpoint = self.endpoint_from_path(&path[0]);

                let total_records: usize = flight.chunks.iter().map(|chunk| chunk.num_rows()).sum();

                let options = arrow::ipc::writer::IpcWriteOptions::default();
                let message = SchemaAsIpc::new(&flight.schema, &options)
                    .try_into()
                    .expect(
                        "Could not generate schema bytes from schema stored by a DoPut; \
                         this should be impossible",
                    );
                let IpcMessage(schema) = message;

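                // -1 is the Flight convention for an unknown byte count.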
                let info = FlightInfo {
                    schema,
                    flight_descriptor: Some(descriptor.clone()),
                    endpoint: vec![endpoint],
                    total_records: total_records as i64,
                    total_bytes: -1,
                    ordered: false,
                    app_metadata: vec![].into(),
                };

                Ok(Response::new(info))
            }
            other => Err(Status::unimplemented(format!("Request type: {other}"))),
        }
    }

    async fn poll_flight_info(
        &self,
        _request: Request<FlightDescriptor>,
    ) -> Result<Response<PollInfo>, Status> {
        Err(Status::unimplemented("Not yet implemented"))
    }

    async fn do_put(
        &self,
        request: Request<Streaming<FlightData>>,
    ) -> Result<Response<Self::DoPutStream>, Status> {
        let mut input_stream = request.into_inner();
        let flight_data = input_stream
            .message()
            .await?
            .ok_or_else(|| Status::invalid_argument("Must send some FlightData"))?;

        let descriptor = flight_data
            .flight_descriptor
            .clone()
            .ok_or_else(|| Status::invalid_argument("Must have a descriptor"))?;

        if descriptor.r#type != DescriptorType::Path as i32 || descriptor.path.is_empty() {
            return Err(Status::invalid_argument("Must specify a path"));
        }

        let key = descriptor.path[0].clone();

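        // The first FlightData message of a DoPut carries the schema in its
        // data_header; decode it before reading any record batches.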
        let schema = Schema::try_from(&flight_data)
            .map_err(|e| Status::invalid_argument(format!("Invalid schema: {e:?}")))?;
        let schema_ref = Arc::new(schema.clone());

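        // Process the rest of the upload on a background task so PutResult
        // acknowledgements can stream back while FlightData is still arriving.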
        let (response_tx, response_rx) = mpsc::channel(10);

        let uploaded_chunks = self.uploaded_chunks.clone();

        tokio::spawn(async {
            let mut error_tx = response_tx.clone();
            if let Err(e) = save_uploaded_chunks(
                uploaded_chunks,
                schema_ref,
                input_stream,
                response_tx,
                schema,
                key,
            )
            .await
            {
                error_tx.send(Err(e)).await.expect("Error sending error")
            }
        });

        Ok(Response::new(Box::pin(response_rx) as Self::DoPutStream))
    }

    async fn do_action(
        &self,
        _request: Request<Action>,
    ) -> Result<Response<Self::DoActionStream>, Status> {
        Err(Status::unimplemented("Not yet implemented"))
    }

    async fn list_actions(
        &self,
        _request: Request<Empty>,
    ) -> Result<Response<Self::ListActionsStream>, Status> {
        Err(Status::unimplemented("Not yet implemented"))
    }

    async fn do_exchange(
        &self,
        _request: Request<Streaming<FlightData>>,
    ) -> Result<Response<Self::DoExchangeStream>, Status> {
        Err(Status::unimplemented("Not yet implemented"))
    }
}

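/// Stream a `PutResult` carrying `app_metadata` back to the DoPut client.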
async fn send_app_metadata(
    tx: &mut mpsc::Sender<Result<PutResult, Status>>,
    app_metadata: &[u8],
) -> Result<(), Status> {
    tx.send(Ok(PutResult {
        app_metadata: app_metadata.to_vec().into(),
    }))
    .await
    .map_err(|e| Status::internal(format!("Could not send PutResult: {e:?}")))
}

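/// Decode an IPC record batch message into a `RecordBatch`, resolving any
/// dictionary references against `dictionaries_by_id`.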
async fn record_batch_from_message(
    message: ipc::Message<'_>,
    data_body: &Buffer,
    schema_ref: SchemaRef,
    dictionaries_by_id: &HashMap<i64, ArrayRef>,
) -> Result<RecordBatch, Status> {
    let ipc_batch = message
        .header_as_record_batch()
        .ok_or_else(|| Status::internal("Could not parse message header as record batch"))?;

    let arrow_batch_result = reader::read_record_batch(
        data_body,
        ipc_batch,
        schema_ref,
        dictionaries_by_id,
        None,
        &message.version(),
    );

    arrow_batch_result
        .map_err(|e| Status::internal(format!("Could not convert to RecordBatch: {e:?}")))
}

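/// Decode an IPC dictionary batch and register it in `dictionaries_by_id` for
/// use by subsequent record batches.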
async fn dictionary_from_message(
    message: ipc::Message<'_>,
    data_body: &Buffer,
    schema_ref: SchemaRef,
    dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
) -> Result<(), Status> {
    let ipc_batch = message
        .header_as_dictionary_batch()
        .ok_or_else(|| Status::internal("Could not parse message header as dictionary batch"))?;

    let dictionary_batch_result = reader::read_dictionary(
        data_body,
        ipc_batch,
        &schema_ref,
        dictionaries_by_id,
        &message.version(),
    );
    dictionary_batch_result
        .map_err(|e| Status::internal(format!("Could not convert to Dictionary: {e:?}")))
}

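/// Drain the remainder of a DoPut stream, acknowledging each record batch and
/// storing the reassembled dataset under `key`.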
async fn save_uploaded_chunks(
    uploaded_chunks: Arc<Mutex<HashMap<String, IntegrationDataset>>>,
    schema_ref: Arc<Schema>,
    mut input_stream: Streaming<FlightData>,
    mut response_tx: mpsc::Sender<Result<PutResult, Status>>,
    schema: Schema,
    key: String,
) -> Result<(), Status> {
    let mut chunks = vec![];
    let mut uploaded_chunks = uploaded_chunks.lock().await;

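    // Dictionary batches may arrive at any point in the stream; collect them
    // by dictionary id so later record batches can resolve their references.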
    let mut dictionaries_by_id = HashMap::new();

    while let Some(Ok(data)) = input_stream.next().await {
        let message = arrow::ipc::root_as_message(&data.data_header[..])
            .map_err(|e| Status::internal(format!("Could not parse message: {e:?}")))?;

        match message.header_type() {
            ipc::MessageHeader::Schema => {
                return Err(Status::internal(
                    "Not expecting a schema when messages are read",
                ))
            }
            ipc::MessageHeader::RecordBatch => {
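                // Acknowledge each record batch by echoing its app_metadata
                // back in a PutResult, which the integration clients verify.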
                send_app_metadata(&mut response_tx, &data.app_metadata).await?;

                let batch = record_batch_from_message(
                    message,
                    &Buffer::from(data.data_body.as_ref()),
                    schema_ref.clone(),
                    &dictionaries_by_id,
                )
                .await?;

                chunks.push(batch);
            }
            ipc::MessageHeader::DictionaryBatch => {
                dictionary_from_message(
                    message,
                    &Buffer::from(data.data_body.as_ref()),
                    schema_ref.clone(),
                    &mut dictionaries_by_id,
                )
                .await?;
            }
            t => {
                return Err(Status::internal(format!(
                    "Reading types other than record batches not yet supported, \
                     unable to read {t:?}"
                )));
            }
        }
    }

    let dataset = IntegrationDataset { schema, chunks };
    uploaded_chunks.insert(key, dataset);

    Ok(())
}