arrow_integration_testing/flight_client_scenarios/integration_test.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Integration tests for the Flight client.

use crate::open_json_file;

use std::collections::HashMap;
use std::sync::Arc;

use arrow::{
    array::ArrayRef,
    buffer::Buffer,
    datatypes::{Schema, SchemaRef},
    ipc::{
        self, reader,
        writer::{self, CompressionContext},
    },
    record_batch::RecordBatch,
};
use arrow_flight::{
    flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient,
    utils::flight_data_to_arrow_batch, FlightData, FlightDescriptor, IpcMessage, Location, Ticket,
};
use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt};
use tonic::{transport::Endpoint, Request, Streaming};

type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
type Result<T = (), E = Error> = std::result::Result<T, E>;

type Client = FlightServiceClient<tonic::transport::Channel>;

/// Run a scenario that uploads data to a Flight server and then downloads it back
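///
/// The server address is assembled as `http://{host}:{port}`, and `path`
/// both locates the integration JSON file on disk and doubles as the flight
/// descriptor path. An illustrative call (example values, not defaults):
/// `run_scenario("localhost", 31337, "generated_primitive.json")`.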
pub async fn run_scenario(host: &str, port: u16, path: &str) -> Result {
    let url = format!("http://{host}:{port}");

    let endpoint = Endpoint::new(url)?;
    let channel = endpoint.connect().await?;
    let client = FlightServiceClient::new(channel);

    let json_file = open_json_file(path)?;

    let batches = json_file.read_batches()?;
    let schema = Arc::new(json_file.schema);

    // Identify the uploaded data by the JSON file's path
    let mut descriptor = FlightDescriptor::default();
    descriptor.set_type(DescriptorType::Path);
    descriptor.path = vec![path.to_string()];

    // Round-trip: upload the batches, then fetch them back and compare
    upload_data(client.clone(), schema, descriptor.clone(), batches.clone()).await?;
    verify_data(client, descriptor, &batches).await?;

    Ok(())
}

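/// Upload `original_data` to the server via `DoPut`, sending the schema
/// message first and then each batch (preceded by any dictionary batches),
/// and check that the server echoes each batch's index back as
/// `app_metadata`.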
async fn upload_data(
    mut client: Client,
    schema: SchemaRef,
    descriptor: FlightDescriptor,
    original_data: Vec<RecordBatch>,
) -> Result {
    let (mut upload_tx, upload_rx) = mpsc::channel(10);

    let options = arrow::ipc::writer::IpcWriteOptions::default();
    let mut dict_tracker = writer::DictionaryTracker::new(false);
    let data_gen = writer::IpcDataGenerator::default();
    let data = IpcMessage(
        data_gen
            .schema_to_bytes_with_dictionary_tracker(&schema, &mut dict_tracker, &options)
            .ipc_message
            .into(),
    );
    // The first message on the stream carries only the schema and the flight
    // descriptor; it has no data body.
    let mut schema_flight_data = FlightData {
        data_header: data.0,
        ..Default::default()
    };
    schema_flight_data.flight_descriptor = Some(descriptor.clone());
    upload_tx.send(schema_flight_data).await?;

    let mut original_data_iter = original_data.iter().enumerate();

    // A single compression context is reused for every batch in the stream
    let mut compression_context = CompressionContext::default();

    if let Some((counter, first_batch)) = original_data_iter.next() {
        let metadata = counter.to_string().into_bytes();
        // Preload the first batch into the channel before starting the request
        send_batch(
            &mut upload_tx,
            &metadata,
            first_batch,
            &options,
            &mut dict_tracker,
            &mut compression_context,
        )
        .await?;

        let outer = client.do_put(Request::new(upload_rx)).await?;
        let mut inner = outer.into_inner();

        let r = inner
            .next()
            .await
            .expect("No response received")
            .expect("Invalid response received");
        assert_eq!(metadata, r.app_metadata);

        // Stream the rest of the batches
        for (counter, batch) in original_data_iter {
            let metadata = counter.to_string().into_bytes();
            send_batch(
                &mut upload_tx,
                &metadata,
                batch,
                &options,
                &mut dict_tracker,
                &mut compression_context,
            )
            .await?;

            let r = inner
                .next()
                .await
                .expect("No response received")
                .expect("Invalid response received");
            assert_eq!(metadata, r.app_metadata);
        }
        // Closing the sender ends the request stream
        drop(upload_tx);
        assert!(
            inner.next().await.is_none(),
            "Should not receive more results"
        );
    } else {
        // No batches to upload: send only the schema message
        drop(upload_tx);
        client.do_put(Request::new(upload_rx)).await?;
    }

    Ok(())
}

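/// Encode `batch` as Arrow IPC and send it on `upload_tx`: any dictionary
/// batches go first, then the record batch itself with `metadata` attached
/// as its `app_metadata`.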
async fn send_batch(
    upload_tx: &mut mpsc::Sender<FlightData>,
    metadata: &[u8],
    batch: &RecordBatch,
    options: &writer::IpcWriteOptions,
    dictionary_tracker: &mut writer::DictionaryTracker,
    compression_context: &mut CompressionContext,
) -> Result {
    let data_gen = writer::IpcDataGenerator::default();

    let (encoded_dictionaries, encoded_batch) = data_gen
        .encode(batch, dictionary_tracker, options, compression_context)
        .expect("DictionaryTracker configured above to not error on replacement");

    let dictionary_flight_data: Vec<FlightData> =
        encoded_dictionaries.into_iter().map(Into::into).collect();
    let mut batch_flight_data: FlightData = encoded_batch.into();

    upload_tx
        .send_all(&mut stream::iter(dictionary_flight_data).map(Ok))
        .await?;

    // Only the record batch's FlightData gets app_metadata
    batch_flight_data.app_metadata = metadata.to_vec().into();
    upload_tx.send(batch_flight_data).await?;
    Ok(())
}

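/// Request flight info for `descriptor`, then download the data behind
/// every location of every returned endpoint and compare it against
/// `expected_data`.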
async fn verify_data(
    mut client: Client,
    descriptor: FlightDescriptor,
    expected_data: &[RecordBatch],
) -> Result {
    let resp = client.get_flight_info(Request::new(descriptor)).await?;
    let info = resp.into_inner();

    assert!(
        !info.endpoint.is_empty(),
        "No endpoints returned from Flight server",
    );
    for endpoint in info.endpoint {
        let ticket = endpoint
            .ticket
            .expect("No ticket returned from Flight server");

        assert!(
            !endpoint.location.is_empty(),
            "No locations returned from Flight server",
        );
        for location in endpoint.location {
            consume_flight_location(location, ticket.clone(), expected_data).await?;
        }
    }

    Ok(())
}

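/// Fetch the stream behind `ticket` from `location` via `DoGet` and verify
/// that the schema, batch count, per-batch `app_metadata`, and column data
/// all match `expected_data`.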
async fn consume_flight_location(
    location: Location,
    ticket: Ticket,
    expected_data: &[RecordBatch],
) -> Result {
    let mut location = location;
    // The other Flight implementations use the `grpc+tcp` scheme, but the Rust
    // http libs don't recognize it as valid.
    // More details: https://github.com/apache/arrow-rs/issues/1398
    location.uri = location.uri.replace("grpc+tcp://", "http://");

    let endpoint = Endpoint::new(location.uri)?;
    let channel = endpoint.connect().await?;
    let mut client = FlightServiceClient::new(channel);
    let resp = client.do_get(ticket).await?;
    let mut resp = resp.into_inner();

    // The first message on the stream must be the schema
    let flight_schema = receive_schema_flight_data(&mut resp)
        .await
        .expect("Failed to receive flight schema");
    let actual_schema = Arc::new(flight_schema);

    let mut dictionaries_by_id = HashMap::new();

    for (counter, expected_batch) in expected_data.iter().enumerate() {
        let data =
            receive_batch_flight_data(&mut resp, actual_schema.clone(), &mut dictionaries_by_id)
                .await
                .unwrap_or_else(|| {
                    panic!(
                        "Got fewer batches than expected; received so far: {}, expected: {}",
                        counter,
                        expected_data.len(),
                    )
                });

        // Each batch's app_metadata must echo its index in the upload
        let metadata = counter.to_string().into_bytes();
        assert_eq!(metadata, data.app_metadata);

        let actual_batch =
            flight_data_to_arrow_batch(&data, actual_schema.clone(), &dictionaries_by_id)
                .expect("Unable to convert flight data to Arrow batch");

        assert_eq!(actual_schema, actual_batch.schema());
        assert_eq!(expected_batch.num_columns(), actual_batch.num_columns());
        assert_eq!(expected_batch.num_rows(), actual_batch.num_rows());
        let schema = expected_batch.schema();
        for i in 0..expected_batch.num_columns() {
            let field = schema.field(i);
            let field_name = field.name();

            let expected_data = expected_batch.column(i).as_ref();
            let actual_data = actual_batch.column(i).as_ref();

            assert_eq!(expected_data, actual_data, "Data for field {field_name}");
        }
    }

    assert!(
        resp.next().await.is_none(),
        "Got more batches than expected: {}",
        expected_data.len(),
    );

    Ok(())
}

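/// Read the next message from `resp` and decode its IPC header as a
/// [`Schema`], panicking if the header is any other message type. Returns
/// `None` once the stream ends or yields an error.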
async fn receive_schema_flight_data(resp: &mut Streaming<FlightData>) -> Option<Schema> {
    let data = resp.next().await?.ok()?;
    let message =
        arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing message");

    // message header is a Schema, so read it
    let ipc_schema: ipc::Schema = message
        .header_as_schema()
        .expect("Unable to read IPC message as schema");
    let schema = ipc::convert::fb_to_schema(ipc_schema);

    Some(schema)
}

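/// Read messages from `resp`, registering each dictionary batch in
/// `dictionaries_by_id`, until a non-dictionary message (the record batch)
/// arrives, and return that message's raw [`FlightData`]. Returns `None`
/// once the stream ends or yields an error.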
async fn receive_batch_flight_data(
    resp: &mut Streaming<FlightData>,
    schema: SchemaRef,
    dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
) -> Option<FlightData> {
    let mut data = resp.next().await?.ok()?;
    let mut message =
        arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing first message");

    while message.header_type() == ipc::MessageHeader::DictionaryBatch {
        reader::read_dictionary(
            &Buffer::from(data.data_body.as_ref()),
            message
                .header_as_dictionary_batch()
                .expect("Error parsing dictionary"),
            &schema,
            dictionaries_by_id,
            &message.version(),
        )
        .expect("Error reading dictionary");

        data = resp.next().await?.ok()?;
        message =
            arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing message");
    }

    Some(data)
}