arrow_integration_testing/flight_client_scenarios/
integration_test.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Integration tests for the Flight client.
19
20use crate::open_json_file;
21use std::collections::HashMap;
22
23use arrow::{
24    array::ArrayRef,
25    buffer::Buffer,
26    datatypes::SchemaRef,
27    ipc::{self, reader, writer},
28    record_batch::RecordBatch,
29};
30use arrow_flight::{
31    flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient,
32    utils::flight_data_to_arrow_batch, FlightData, FlightDescriptor, IpcMessage, Location, Ticket,
33};
34use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt};
35use tonic::{Request, Streaming};
36
37use arrow::datatypes::Schema;
38use std::sync::Arc;
39
/// Boxed error type (`Send + Sync + 'static`) returned by the functions below.
type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
/// `Result` alias whose success type defaults to `()` and whose error type
/// defaults to [`Error`].
type Result<T = (), E = Error> = std::result::Result<T, E>;

/// Flight service client running over a `tonic` transport channel.
type Client = FlightServiceClient<tonic::transport::Channel>;
44
45/// Run a scenario that uploads data to a Flight server and then downloads it back
46pub async fn run_scenario(host: &str, port: u16, path: &str) -> Result {
47    let url = format!("http://{host}:{port}");
48
49    let client = FlightServiceClient::connect(url).await?;
50
51    let json_file = open_json_file(path)?;
52
53    let batches = json_file.read_batches()?;
54    let schema = Arc::new(json_file.schema);
55
56    let mut descriptor = FlightDescriptor::default();
57    descriptor.set_type(DescriptorType::Path);
58    descriptor.path = vec![path.to_string()];
59
60    upload_data(client.clone(), schema, descriptor.clone(), batches.clone()).await?;
61    verify_data(client, descriptor, &batches).await?;
62
63    Ok(())
64}
65
66async fn upload_data(
67    mut client: Client,
68    schema: SchemaRef,
69    descriptor: FlightDescriptor,
70    original_data: Vec<RecordBatch>,
71) -> Result {
72    let (mut upload_tx, upload_rx) = mpsc::channel(10);
73
74    let options = arrow::ipc::writer::IpcWriteOptions::default();
75    let mut dict_tracker = writer::DictionaryTracker::new(false);
76    let data_gen = writer::IpcDataGenerator::default();
77    let data = IpcMessage(
78        data_gen
79            .schema_to_bytes_with_dictionary_tracker(&schema, &mut dict_tracker, &options)
80            .ipc_message
81            .into(),
82    );
83    let mut schema_flight_data = FlightData {
84        data_header: data.0,
85        ..Default::default()
86    };
87    // arrow_flight::utils::flight_data_from_arrow_schema(&schema, &options);
88    schema_flight_data.flight_descriptor = Some(descriptor.clone());
89    upload_tx.send(schema_flight_data).await?;
90
91    let mut original_data_iter = original_data.iter().enumerate();
92
93    if let Some((counter, first_batch)) = original_data_iter.next() {
94        let metadata = counter.to_string().into_bytes();
95        // Preload the first batch into the channel before starting the request
96        send_batch(
97            &mut upload_tx,
98            &metadata,
99            first_batch,
100            &options,
101            &mut dict_tracker,
102        )
103        .await?;
104
105        let outer = client.do_put(Request::new(upload_rx)).await?;
106        let mut inner = outer.into_inner();
107
108        let r = inner
109            .next()
110            .await
111            .expect("No response received")
112            .expect("Invalid response received");
113        assert_eq!(metadata, r.app_metadata);
114
115        // Stream the rest of the batches
116        for (counter, batch) in original_data_iter {
117            let metadata = counter.to_string().into_bytes();
118            send_batch(
119                &mut upload_tx,
120                &metadata,
121                batch,
122                &options,
123                &mut dict_tracker,
124            )
125            .await?;
126
127            let r = inner
128                .next()
129                .await
130                .expect("No response received")
131                .expect("Invalid response received");
132            assert_eq!(metadata, r.app_metadata);
133        }
134        drop(upload_tx);
135        assert!(
136            inner.next().await.is_none(),
137            "Should not receive more results"
138        );
139    } else {
140        drop(upload_tx);
141        client.do_put(Request::new(upload_rx)).await?;
142    }
143
144    Ok(())
145}
146
147async fn send_batch(
148    upload_tx: &mut mpsc::Sender<FlightData>,
149    metadata: &[u8],
150    batch: &RecordBatch,
151    options: &writer::IpcWriteOptions,
152    dictionary_tracker: &mut writer::DictionaryTracker,
153) -> Result {
154    let data_gen = writer::IpcDataGenerator::default();
155
156    let (encoded_dictionaries, encoded_batch) = data_gen
157        .encoded_batch(batch, dictionary_tracker, options)
158        .expect("DictionaryTracker configured above to not error on replacement");
159
160    let dictionary_flight_data: Vec<FlightData> =
161        encoded_dictionaries.into_iter().map(Into::into).collect();
162    let mut batch_flight_data: FlightData = encoded_batch.into();
163
164    upload_tx
165        .send_all(&mut stream::iter(dictionary_flight_data).map(Ok))
166        .await?;
167
168    // Only the record batch's FlightData gets app_metadata
169    batch_flight_data.app_metadata = metadata.to_vec().into();
170    upload_tx.send(batch_flight_data).await?;
171    Ok(())
172}
173
174async fn verify_data(
175    mut client: Client,
176    descriptor: FlightDescriptor,
177    expected_data: &[RecordBatch],
178) -> Result {
179    let resp = client.get_flight_info(Request::new(descriptor)).await?;
180    let info = resp.into_inner();
181
182    assert!(
183        !info.endpoint.is_empty(),
184        "No endpoints returned from Flight server",
185    );
186    for endpoint in info.endpoint {
187        let ticket = endpoint
188            .ticket
189            .expect("No ticket returned from Flight server");
190
191        assert!(
192            !endpoint.location.is_empty(),
193            "No locations returned from Flight server",
194        );
195        for location in endpoint.location {
196            consume_flight_location(location, ticket.clone(), expected_data).await?;
197        }
198    }
199
200    Ok(())
201}
202
203async fn consume_flight_location(
204    location: Location,
205    ticket: Ticket,
206    expected_data: &[RecordBatch],
207) -> Result {
208    let mut location = location;
209    // The other Flight implementations use the `grpc+tcp` scheme, but the Rust http libs
210    // don't recognize this as valid.
211    // more details: https://github.com/apache/arrow-rs/issues/1398
212    location.uri = location.uri.replace("grpc+tcp://", "http://");
213
214    let mut client = FlightServiceClient::connect(location.uri).await?;
215    let resp = client.do_get(ticket).await?;
216    let mut resp = resp.into_inner();
217
218    let flight_schema = receive_schema_flight_data(&mut resp)
219        .await
220        .unwrap_or_else(|| panic!("Failed to receive flight schema"));
221    let actual_schema = Arc::new(flight_schema);
222
223    let mut dictionaries_by_id = HashMap::new();
224
225    for (counter, expected_batch) in expected_data.iter().enumerate() {
226        let data =
227            receive_batch_flight_data(&mut resp, actual_schema.clone(), &mut dictionaries_by_id)
228                .await
229                .unwrap_or_else(|| {
230                    panic!(
231                        "Got fewer batches than expected, received so far: {} expected: {}",
232                        counter,
233                        expected_data.len(),
234                    )
235                });
236
237        let metadata = counter.to_string().into_bytes();
238        assert_eq!(metadata, data.app_metadata);
239
240        let actual_batch =
241            flight_data_to_arrow_batch(&data, actual_schema.clone(), &dictionaries_by_id)
242                .expect("Unable to convert flight data to Arrow batch");
243
244        assert_eq!(actual_schema, actual_batch.schema());
245        assert_eq!(expected_batch.num_columns(), actual_batch.num_columns());
246        assert_eq!(expected_batch.num_rows(), actual_batch.num_rows());
247        let schema = expected_batch.schema();
248        for i in 0..expected_batch.num_columns() {
249            let field = schema.field(i);
250            let field_name = field.name();
251
252            let expected_data = expected_batch.column(i).as_ref();
253            let actual_data = actual_batch.column(i).as_ref();
254
255            assert_eq!(expected_data, actual_data, "Data for field {field_name}");
256        }
257    }
258
259    assert!(
260        resp.next().await.is_none(),
261        "Got more batches than the expected: {}",
262        expected_data.len(),
263    );
264
265    Ok(())
266}
267
268async fn receive_schema_flight_data(resp: &mut Streaming<FlightData>) -> Option<Schema> {
269    let data = resp.next().await?.ok()?;
270    let message =
271        arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing message");
272
273    // message header is a Schema, so read it
274    let ipc_schema: ipc::Schema = message
275        .header_as_schema()
276        .expect("Unable to read IPC message as schema");
277    let schema = ipc::convert::fb_to_schema(ipc_schema);
278
279    Some(schema)
280}
281
282async fn receive_batch_flight_data(
283    resp: &mut Streaming<FlightData>,
284    schema: SchemaRef,
285    dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
286) -> Option<FlightData> {
287    let mut data = resp.next().await?.ok()?;
288    let mut message =
289        arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing first message");
290
291    while message.header_type() == ipc::MessageHeader::DictionaryBatch {
292        reader::read_dictionary(
293            &Buffer::from(data.data_body.as_ref()),
294            message
295                .header_as_dictionary_batch()
296                .expect("Error parsing dictionary"),
297            &schema,
298            dictionaries_by_id,
299            &message.version(),
300        )
301        .expect("Error reading dictionary");
302
303        data = resp.next().await?.ok()?;
304        message =
305            arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing message");
306    }
307
308    Some(data)
309}