arrow_integration_testing/flight_client_scenarios/
integration_test.rs1use crate::open_json_file;
21use std::collections::HashMap;
22
23use arrow::{
24 array::ArrayRef,
25 buffer::Buffer,
26 datatypes::SchemaRef,
27 ipc::{self, reader, writer},
28 record_batch::RecordBatch,
29};
30use arrow_flight::{
31 flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient,
32 utils::flight_data_to_arrow_batch, FlightData, FlightDescriptor, IpcMessage, Location, Ticket,
33};
34use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt};
35use tonic::{Request, Streaming};
36
37use arrow::datatypes::Schema;
38use std::sync::Arc;
39
/// Boxed dynamic error that is `Send + Sync + 'static`; the default error type of [`Result`].
type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
/// `Result` alias whose success type defaults to `()` and whose error type defaults to [`Error`].
type Result<T = (), E = Error> = std::result::Result<T, E>;

/// Flight service client running over a tonic transport channel.
type Client = FlightServiceClient<tonic::transport::Channel>;
44
45pub async fn run_scenario(host: &str, port: u16, path: &str) -> Result {
47 let url = format!("http://{host}:{port}");
48
49 let client = FlightServiceClient::connect(url).await?;
50
51 let json_file = open_json_file(path)?;
52
53 let batches = json_file.read_batches()?;
54 let schema = Arc::new(json_file.schema);
55
56 let mut descriptor = FlightDescriptor::default();
57 descriptor.set_type(DescriptorType::Path);
58 descriptor.path = vec![path.to_string()];
59
60 upload_data(client.clone(), schema, descriptor.clone(), batches.clone()).await?;
61 verify_data(client, descriptor, &batches).await?;
62
63 Ok(())
64}
65
66async fn upload_data(
67 mut client: Client,
68 schema: SchemaRef,
69 descriptor: FlightDescriptor,
70 original_data: Vec<RecordBatch>,
71) -> Result {
72 let (mut upload_tx, upload_rx) = mpsc::channel(10);
73
74 let options = arrow::ipc::writer::IpcWriteOptions::default();
75 let mut dict_tracker = writer::DictionaryTracker::new(false);
76 let data_gen = writer::IpcDataGenerator::default();
77 let data = IpcMessage(
78 data_gen
79 .schema_to_bytes_with_dictionary_tracker(&schema, &mut dict_tracker, &options)
80 .ipc_message
81 .into(),
82 );
83 let mut schema_flight_data = FlightData {
84 data_header: data.0,
85 ..Default::default()
86 };
87 schema_flight_data.flight_descriptor = Some(descriptor.clone());
89 upload_tx.send(schema_flight_data).await?;
90
91 let mut original_data_iter = original_data.iter().enumerate();
92
93 if let Some((counter, first_batch)) = original_data_iter.next() {
94 let metadata = counter.to_string().into_bytes();
95 send_batch(
97 &mut upload_tx,
98 &metadata,
99 first_batch,
100 &options,
101 &mut dict_tracker,
102 )
103 .await?;
104
105 let outer = client.do_put(Request::new(upload_rx)).await?;
106 let mut inner = outer.into_inner();
107
108 let r = inner
109 .next()
110 .await
111 .expect("No response received")
112 .expect("Invalid response received");
113 assert_eq!(metadata, r.app_metadata);
114
115 for (counter, batch) in original_data_iter {
117 let metadata = counter.to_string().into_bytes();
118 send_batch(
119 &mut upload_tx,
120 &metadata,
121 batch,
122 &options,
123 &mut dict_tracker,
124 )
125 .await?;
126
127 let r = inner
128 .next()
129 .await
130 .expect("No response received")
131 .expect("Invalid response received");
132 assert_eq!(metadata, r.app_metadata);
133 }
134 drop(upload_tx);
135 assert!(
136 inner.next().await.is_none(),
137 "Should not receive more results"
138 );
139 } else {
140 drop(upload_tx);
141 client.do_put(Request::new(upload_rx)).await?;
142 }
143
144 Ok(())
145}
146
147async fn send_batch(
148 upload_tx: &mut mpsc::Sender<FlightData>,
149 metadata: &[u8],
150 batch: &RecordBatch,
151 options: &writer::IpcWriteOptions,
152 dictionary_tracker: &mut writer::DictionaryTracker,
153) -> Result {
154 let data_gen = writer::IpcDataGenerator::default();
155
156 let (encoded_dictionaries, encoded_batch) = data_gen
157 .encoded_batch(batch, dictionary_tracker, options)
158 .expect("DictionaryTracker configured above to not error on replacement");
159
160 let dictionary_flight_data: Vec<FlightData> =
161 encoded_dictionaries.into_iter().map(Into::into).collect();
162 let mut batch_flight_data: FlightData = encoded_batch.into();
163
164 upload_tx
165 .send_all(&mut stream::iter(dictionary_flight_data).map(Ok))
166 .await?;
167
168 batch_flight_data.app_metadata = metadata.to_vec().into();
170 upload_tx.send(batch_flight_data).await?;
171 Ok(())
172}
173
174async fn verify_data(
175 mut client: Client,
176 descriptor: FlightDescriptor,
177 expected_data: &[RecordBatch],
178) -> Result {
179 let resp = client.get_flight_info(Request::new(descriptor)).await?;
180 let info = resp.into_inner();
181
182 assert!(
183 !info.endpoint.is_empty(),
184 "No endpoints returned from Flight server",
185 );
186 for endpoint in info.endpoint {
187 let ticket = endpoint
188 .ticket
189 .expect("No ticket returned from Flight server");
190
191 assert!(
192 !endpoint.location.is_empty(),
193 "No locations returned from Flight server",
194 );
195 for location in endpoint.location {
196 consume_flight_location(location, ticket.clone(), expected_data).await?;
197 }
198 }
199
200 Ok(())
201}
202
203async fn consume_flight_location(
204 location: Location,
205 ticket: Ticket,
206 expected_data: &[RecordBatch],
207) -> Result {
208 let mut location = location;
209 location.uri = location.uri.replace("grpc+tcp://", "http://");
213
214 let mut client = FlightServiceClient::connect(location.uri).await?;
215 let resp = client.do_get(ticket).await?;
216 let mut resp = resp.into_inner();
217
218 let flight_schema = receive_schema_flight_data(&mut resp)
219 .await
220 .unwrap_or_else(|| panic!("Failed to receive flight schema"));
221 let actual_schema = Arc::new(flight_schema);
222
223 let mut dictionaries_by_id = HashMap::new();
224
225 for (counter, expected_batch) in expected_data.iter().enumerate() {
226 let data =
227 receive_batch_flight_data(&mut resp, actual_schema.clone(), &mut dictionaries_by_id)
228 .await
229 .unwrap_or_else(|| {
230 panic!(
231 "Got fewer batches than expected, received so far: {} expected: {}",
232 counter,
233 expected_data.len(),
234 )
235 });
236
237 let metadata = counter.to_string().into_bytes();
238 assert_eq!(metadata, data.app_metadata);
239
240 let actual_batch =
241 flight_data_to_arrow_batch(&data, actual_schema.clone(), &dictionaries_by_id)
242 .expect("Unable to convert flight data to Arrow batch");
243
244 assert_eq!(actual_schema, actual_batch.schema());
245 assert_eq!(expected_batch.num_columns(), actual_batch.num_columns());
246 assert_eq!(expected_batch.num_rows(), actual_batch.num_rows());
247 let schema = expected_batch.schema();
248 for i in 0..expected_batch.num_columns() {
249 let field = schema.field(i);
250 let field_name = field.name();
251
252 let expected_data = expected_batch.column(i).as_ref();
253 let actual_data = actual_batch.column(i).as_ref();
254
255 assert_eq!(expected_data, actual_data, "Data for field {field_name}");
256 }
257 }
258
259 assert!(
260 resp.next().await.is_none(),
261 "Got more batches than the expected: {}",
262 expected_data.len(),
263 );
264
265 Ok(())
266}
267
268async fn receive_schema_flight_data(resp: &mut Streaming<FlightData>) -> Option<Schema> {
269 let data = resp.next().await?.ok()?;
270 let message =
271 arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing message");
272
273 let ipc_schema: ipc::Schema = message
275 .header_as_schema()
276 .expect("Unable to read IPC message as schema");
277 let schema = ipc::convert::fb_to_schema(ipc_schema);
278
279 Some(schema)
280}
281
282async fn receive_batch_flight_data(
283 resp: &mut Streaming<FlightData>,
284 schema: SchemaRef,
285 dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
286) -> Option<FlightData> {
287 let mut data = resp.next().await?.ok()?;
288 let mut message =
289 arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing first message");
290
291 while message.header_type() == ipc::MessageHeader::DictionaryBatch {
292 reader::read_dictionary(
293 &Buffer::from(data.data_body.as_ref()),
294 message
295 .header_as_dictionary_batch()
296 .expect("Error parsing dictionary"),
297 &schema,
298 dictionaries_by_id,
299 &message.version(),
300 )
301 .expect("Error reading dictionary");
302
303 data = resp.next().await?.ok()?;
304 message =
305 arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing message");
306 }
307
308 Some(data)
309}