arrow_integration_testing/flight_client_scenarios/integration_test.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Integration tests for the Flight client.
19
20use crate::open_json_file;
21use std::collections::HashMap;
22
23use arrow::{
24    array::ArrayRef,
25    buffer::Buffer,
26    datatypes::SchemaRef,
27    ipc::{self, reader, writer},
28    record_batch::RecordBatch,
29};
30use arrow_flight::{
31    flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient,
32    utils::flight_data_to_arrow_batch, FlightData, FlightDescriptor, IpcMessage, Location, Ticket,
33};
34use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt};
35use tonic::{Request, Streaming};
36
37use arrow::datatypes::Schema;
38use std::sync::Arc;
39
/// Boxed, thread-safe error type returned by every fallible step of the scenario.
type Error = Box<dyn std::error::Error + Send + Sync>;
/// Shorthand result whose success type defaults to `()` and error type to [`Error`].
type Result<T = (), E = Error> = core::result::Result<T, E>;
42
43type Client = FlightServiceClient<tonic::transport::Channel>;
44
45/// Run a scenario that uploads data to a Flight server and then downloads it back
46pub async fn run_scenario(host: &str, port: u16, path: &str) -> Result {
47    let url = format!("http://{host}:{port}");
48
49    let client = FlightServiceClient::connect(url).await?;
50
51    let json_file = open_json_file(path)?;
52
53    let batches = json_file.read_batches()?;
54    let schema = Arc::new(json_file.schema);
55
56    let mut descriptor = FlightDescriptor::default();
57    descriptor.set_type(DescriptorType::Path);
58    descriptor.path = vec![path.to_string()];
59
60    upload_data(client.clone(), schema, descriptor.clone(), batches.clone()).await?;
61    verify_data(client, descriptor, &batches).await?;
62
63    Ok(())
64}
65
66async fn upload_data(
67    mut client: Client,
68    schema: SchemaRef,
69    descriptor: FlightDescriptor,
70    original_data: Vec<RecordBatch>,
71) -> Result {
72    let (mut upload_tx, upload_rx) = mpsc::channel(10);
73
74    let options = arrow::ipc::writer::IpcWriteOptions::default();
75    #[allow(deprecated)]
76    let mut dict_tracker =
77        writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id());
78    let data_gen = writer::IpcDataGenerator::default();
79    let data = IpcMessage(
80        data_gen
81            .schema_to_bytes_with_dictionary_tracker(&schema, &mut dict_tracker, &options)
82            .ipc_message
83            .into(),
84    );
85    let mut schema_flight_data = FlightData {
86        data_header: data.0,
87        ..Default::default()
88    };
89    // arrow_flight::utils::flight_data_from_arrow_schema(&schema, &options);
90    schema_flight_data.flight_descriptor = Some(descriptor.clone());
91    upload_tx.send(schema_flight_data).await?;
92
93    let mut original_data_iter = original_data.iter().enumerate();
94
95    if let Some((counter, first_batch)) = original_data_iter.next() {
96        let metadata = counter.to_string().into_bytes();
97        // Preload the first batch into the channel before starting the request
98        send_batch(
99            &mut upload_tx,
100            &metadata,
101            first_batch,
102            &options,
103            &mut dict_tracker,
104        )
105        .await?;
106
107        let outer = client.do_put(Request::new(upload_rx)).await?;
108        let mut inner = outer.into_inner();
109
110        let r = inner
111            .next()
112            .await
113            .expect("No response received")
114            .expect("Invalid response received");
115        assert_eq!(metadata, r.app_metadata);
116
117        // Stream the rest of the batches
118        for (counter, batch) in original_data_iter {
119            let metadata = counter.to_string().into_bytes();
120            send_batch(
121                &mut upload_tx,
122                &metadata,
123                batch,
124                &options,
125                &mut dict_tracker,
126            )
127            .await?;
128
129            let r = inner
130                .next()
131                .await
132                .expect("No response received")
133                .expect("Invalid response received");
134            assert_eq!(metadata, r.app_metadata);
135        }
136        drop(upload_tx);
137        assert!(
138            inner.next().await.is_none(),
139            "Should not receive more results"
140        );
141    } else {
142        drop(upload_tx);
143        client.do_put(Request::new(upload_rx)).await?;
144    }
145
146    Ok(())
147}
148
149async fn send_batch(
150    upload_tx: &mut mpsc::Sender<FlightData>,
151    metadata: &[u8],
152    batch: &RecordBatch,
153    options: &writer::IpcWriteOptions,
154    dictionary_tracker: &mut writer::DictionaryTracker,
155) -> Result {
156    let data_gen = writer::IpcDataGenerator::default();
157
158    let (encoded_dictionaries, encoded_batch) = data_gen
159        .encoded_batch(batch, dictionary_tracker, options)
160        .expect("DictionaryTracker configured above to not error on replacement");
161
162    let dictionary_flight_data: Vec<FlightData> =
163        encoded_dictionaries.into_iter().map(Into::into).collect();
164    let mut batch_flight_data: FlightData = encoded_batch.into();
165
166    upload_tx
167        .send_all(&mut stream::iter(dictionary_flight_data).map(Ok))
168        .await?;
169
170    // Only the record batch's FlightData gets app_metadata
171    batch_flight_data.app_metadata = metadata.to_vec().into();
172    upload_tx.send(batch_flight_data).await?;
173    Ok(())
174}
175
176async fn verify_data(
177    mut client: Client,
178    descriptor: FlightDescriptor,
179    expected_data: &[RecordBatch],
180) -> Result {
181    let resp = client.get_flight_info(Request::new(descriptor)).await?;
182    let info = resp.into_inner();
183
184    assert!(
185        !info.endpoint.is_empty(),
186        "No endpoints returned from Flight server",
187    );
188    for endpoint in info.endpoint {
189        let ticket = endpoint
190            .ticket
191            .expect("No ticket returned from Flight server");
192
193        assert!(
194            !endpoint.location.is_empty(),
195            "No locations returned from Flight server",
196        );
197        for location in endpoint.location {
198            consume_flight_location(location, ticket.clone(), expected_data).await?;
199        }
200    }
201
202    Ok(())
203}
204
205async fn consume_flight_location(
206    location: Location,
207    ticket: Ticket,
208    expected_data: &[RecordBatch],
209) -> Result {
210    let mut location = location;
211    // The other Flight implementations use the `grpc+tcp` scheme, but the Rust http libs
212    // don't recognize this as valid.
213    // more details: https://github.com/apache/arrow-rs/issues/1398
214    location.uri = location.uri.replace("grpc+tcp://", "http://");
215
216    let mut client = FlightServiceClient::connect(location.uri).await?;
217    let resp = client.do_get(ticket).await?;
218    let mut resp = resp.into_inner();
219
220    let flight_schema = receive_schema_flight_data(&mut resp)
221        .await
222        .unwrap_or_else(|| panic!("Failed to receive flight schema"));
223    let actual_schema = Arc::new(flight_schema);
224
225    let mut dictionaries_by_id = HashMap::new();
226
227    for (counter, expected_batch) in expected_data.iter().enumerate() {
228        let data =
229            receive_batch_flight_data(&mut resp, actual_schema.clone(), &mut dictionaries_by_id)
230                .await
231                .unwrap_or_else(|| {
232                    panic!(
233                        "Got fewer batches than expected, received so far: {} expected: {}",
234                        counter,
235                        expected_data.len(),
236                    )
237                });
238
239        let metadata = counter.to_string().into_bytes();
240        assert_eq!(metadata, data.app_metadata);
241
242        let actual_batch =
243            flight_data_to_arrow_batch(&data, actual_schema.clone(), &dictionaries_by_id)
244                .expect("Unable to convert flight data to Arrow batch");
245
246        assert_eq!(actual_schema, actual_batch.schema());
247        assert_eq!(expected_batch.num_columns(), actual_batch.num_columns());
248        assert_eq!(expected_batch.num_rows(), actual_batch.num_rows());
249        let schema = expected_batch.schema();
250        for i in 0..expected_batch.num_columns() {
251            let field = schema.field(i);
252            let field_name = field.name();
253
254            let expected_data = expected_batch.column(i).as_ref();
255            let actual_data = actual_batch.column(i).as_ref();
256
257            assert_eq!(expected_data, actual_data, "Data for field {field_name}");
258        }
259    }
260
261    assert!(
262        resp.next().await.is_none(),
263        "Got more batches than the expected: {}",
264        expected_data.len(),
265    );
266
267    Ok(())
268}
269
270async fn receive_schema_flight_data(resp: &mut Streaming<FlightData>) -> Option<Schema> {
271    let data = resp.next().await?.ok()?;
272    let message =
273        arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing message");
274
275    // message header is a Schema, so read it
276    let ipc_schema: ipc::Schema = message
277        .header_as_schema()
278        .expect("Unable to read IPC message as schema");
279    let schema = ipc::convert::fb_to_schema(ipc_schema);
280
281    Some(schema)
282}
283
284async fn receive_batch_flight_data(
285    resp: &mut Streaming<FlightData>,
286    schema: SchemaRef,
287    dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
288) -> Option<FlightData> {
289    let mut data = resp.next().await?.ok()?;
290    let mut message =
291        arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing first message");
292
293    while message.header_type() == ipc::MessageHeader::DictionaryBatch {
294        reader::read_dictionary(
295            &Buffer::from(data.data_body.as_ref()),
296            message
297                .header_as_dictionary_batch()
298                .expect("Error parsing dictionary"),
299            &schema,
300            dictionaries_by_id,
301            &message.version(),
302        )
303        .expect("Error reading dictionary");
304
305        data = resp.next().await?.ok()?;
306        message =
307            arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing message");
308    }
309
310    Some(data)
311}