arrow_integration_testing/flight_client_scenarios/integration_test.rs

use crate::open_json_file;
use std::collections::HashMap;
use std::sync::Arc;

use arrow::{
    array::ArrayRef,
    buffer::Buffer,
    datatypes::{Schema, SchemaRef},
    ipc::{
        self, reader,
        writer::{self, CompressionContext},
    },
    record_batch::RecordBatch,
};
use arrow_flight::{
    flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient,
    utils::flight_data_to_arrow_batch, FlightData, FlightDescriptor, IpcMessage, Location, Ticket,
};
use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt};
use tonic::{transport::Endpoint, Request, Streaming};

type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
type Result<T = (), E = Error> = std::result::Result<T, E>;

type Client = FlightServiceClient<tonic::transport::Channel>;

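/// Runs the "integration_test" client scenario: reads record batches from the
/// Arrow JSON file at `path`, uploads them via `DoPut` to the Flight server at
/// `host:port`, then fetches them back with `GetFlightInfo`/`DoGet` and checks
/// that the data round-trips unchanged.
///
/// A minimal usage sketch (host, port, and file name are placeholder values):
///
/// ```ignore
/// run_scenario("localhost", 31337, "generated_primitive.json").await?;
/// ```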
pub async fn run_scenario(host: &str, port: u16, path: &str) -> Result {
    let url = format!("http://{host}:{port}");

    let endpoint = Endpoint::new(url)?;
    let channel = endpoint.connect().await?;
    let client = FlightServiceClient::new(channel);

    let json_file = open_json_file(path)?;

    let batches = json_file.read_batches()?;
    let schema = Arc::new(json_file.schema);

    let mut descriptor = FlightDescriptor::default();
    descriptor.set_type(DescriptorType::Path);
    descriptor.path = vec![path.to_string()];

    upload_data(client.clone(), schema, descriptor.clone(), batches.clone()).await?;
    verify_data(client, descriptor, &batches).await?;

    Ok(())
}

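/// Uploads `original_data` to the server in a single `DoPut` call: first the
/// schema message (carrying the flight descriptor), then each batch with its
/// index serialized into `app_metadata`.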
async fn upload_data(
    mut client: Client,
    schema: SchemaRef,
    descriptor: FlightDescriptor,
    original_data: Vec<RecordBatch>,
) -> Result {
    let (mut upload_tx, upload_rx) = mpsc::channel(10);

    let options = arrow::ipc::writer::IpcWriteOptions::default();
    let mut dict_tracker = writer::DictionaryTracker::new(false);
    let data_gen = writer::IpcDataGenerator::default();
    let data = IpcMessage(
        data_gen
            .schema_to_bytes_with_dictionary_tracker(&schema, &mut dict_tracker, &options)
            .ipc_message
            .into(),
    );
    let mut schema_flight_data = FlightData {
        data_header: data.0,
        ..Default::default()
    };
    schema_flight_data.flight_descriptor = Some(descriptor.clone());
    upload_tx.send(schema_flight_data).await?;

    let mut original_data_iter = original_data.iter().enumerate();

    let mut compression_context = CompressionContext::default();

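    // Queue the first batch before opening the DoPut request, then read one
    // acknowledgement back per batch as the remainder are streamed.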
    if let Some((counter, first_batch)) = original_data_iter.next() {
        let metadata = counter.to_string().into_bytes();
        send_batch(
            &mut upload_tx,
            &metadata,
            first_batch,
            &options,
            &mut dict_tracker,
            &mut compression_context,
        )
        .await?;

        let outer = client.do_put(Request::new(upload_rx)).await?;
        let mut inner = outer.into_inner();

        let r = inner
            .next()
            .await
            .expect("No response received")
            .expect("Invalid response received");
        assert_eq!(metadata, r.app_metadata);

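        // Stream the remaining batches, asserting that each PutResult echoes
        // the batch's app_metadata.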
        for (counter, batch) in original_data_iter {
            let metadata = counter.to_string().into_bytes();
            send_batch(
                &mut upload_tx,
                &metadata,
                batch,
                &options,
                &mut dict_tracker,
                &mut compression_context,
            )
            .await?;

            let r = inner
                .next()
                .await
                .expect("No response received")
                .expect("Invalid response received");
            assert_eq!(metadata, r.app_metadata);
        }
        drop(upload_tx);
        assert!(
            inner.next().await.is_none(),
            "Should not receive more results"
        );
    } else {
        drop(upload_tx);
        client.do_put(Request::new(upload_rx)).await?;
    }

    Ok(())
}

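/// Encodes `batch` as IPC messages and sends it on `upload_tx`: any newly
/// encoded dictionary batches go first, followed by the record batch tagged
/// with `metadata`.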
async fn send_batch(
    upload_tx: &mut mpsc::Sender<FlightData>,
    metadata: &[u8],
    batch: &RecordBatch,
    options: &writer::IpcWriteOptions,
    dictionary_tracker: &mut writer::DictionaryTracker,
    compression_context: &mut CompressionContext,
) -> Result {
    let data_gen = writer::IpcDataGenerator::default();

    let (encoded_dictionaries, encoded_batch) = data_gen
        .encode(batch, dictionary_tracker, options, compression_context)
        .expect("DictionaryTracker configured by the caller to not error on replacement");

    let dictionary_flight_data: Vec<FlightData> =
        encoded_dictionaries.into_iter().map(Into::into).collect();
    let mut batch_flight_data: FlightData = encoded_batch.into();

    upload_tx
        .send_all(&mut stream::iter(dictionary_flight_data).map(Ok))
        .await?;

    batch_flight_data.app_metadata = metadata.to_vec().into();
    upload_tx.send(batch_flight_data).await?;
    Ok(())
}

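/// Asks the server for the `FlightInfo` of `descriptor` and verifies that
/// every advertised endpoint and location serves back `expected_data`.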
async fn verify_data(
    mut client: Client,
    descriptor: FlightDescriptor,
    expected_data: &[RecordBatch],
) -> Result {
    let resp = client.get_flight_info(Request::new(descriptor)).await?;
    let info = resp.into_inner();

    assert!(
        !info.endpoint.is_empty(),
        "No endpoints returned from Flight server",
    );
    for endpoint in info.endpoint {
        let ticket = endpoint
            .ticket
            .expect("No ticket returned from Flight server");

        assert!(
            !endpoint.location.is_empty(),
            "No locations returned from Flight server",
        );
        for location in endpoint.location {
            consume_flight_location(location, ticket.clone(), expected_data).await?;
        }
    }

    Ok(())
}

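/// Connects to `location`, issues `DoGet` with `ticket`, and compares the
/// returned schema, per-batch metadata, and column data against
/// `expected_data`.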
async fn consume_flight_location(
    location: Location,
    ticket: Ticket,
    expected_data: &[RecordBatch],
) -> Result {
    let mut location = location;
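    // Other Flight implementations advertise `grpc+tcp://` URIs, which
    // tonic's transport does not accept; rewrite the scheme so the endpoint
    // parses.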
    location.uri = location.uri.replace("grpc+tcp://", "http://");

    let endpoint = Endpoint::new(location.uri)?;
    let channel = endpoint.connect().await?;
    let mut client = FlightServiceClient::new(channel);
    let resp = client.do_get(ticket).await?;
    let mut resp = resp.into_inner();

    let flight_schema = receive_schema_flight_data(&mut resp)
        .await
        .expect("Failed to receive flight schema");
    let actual_schema = Arc::new(flight_schema);

    let mut dictionaries_by_id = HashMap::new();

    for (counter, expected_batch) in expected_data.iter().enumerate() {
        let data =
            receive_batch_flight_data(&mut resp, actual_schema.clone(), &mut dictionaries_by_id)
                .await
                .unwrap_or_else(|| {
                    panic!(
                        "Got fewer batches than expected, received so far: {} expected: {}",
                        counter,
                        expected_data.len(),
                    )
                });

        let metadata = counter.to_string().into_bytes();
        assert_eq!(metadata, data.app_metadata);

        let actual_batch =
            flight_data_to_arrow_batch(&data, actual_schema.clone(), &dictionaries_by_id)
                .expect("Unable to convert flight data to Arrow batch");

        assert_eq!(actual_schema, actual_batch.schema());
        assert_eq!(expected_batch.num_columns(), actual_batch.num_columns());
        assert_eq!(expected_batch.num_rows(), actual_batch.num_rows());
        let schema = expected_batch.schema();
        for i in 0..expected_batch.num_columns() {
            let field = schema.field(i);
            let field_name = field.name();

            let expected_data = expected_batch.column(i).as_ref();
            let actual_data = actual_batch.column(i).as_ref();

            assert_eq!(expected_data, actual_data, "Data for field {field_name}");
        }
    }

    assert!(
        resp.next().await.is_none(),
        "Got more batches than expected: {}",
        expected_data.len(),
    );

    Ok(())
}

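/// Reads the next message from `resp` and decodes its header as an IPC
/// schema; returns `None` if the stream is exhausted or yields an error.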
async fn receive_schema_flight_data(resp: &mut Streaming<FlightData>) -> Option<Schema> {
    let data = resp.next().await?.ok()?;
    let message =
        arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing message");

    let ipc_schema: ipc::Schema = message
        .header_as_schema()
        .expect("Unable to read IPC message as schema");
    let schema = ipc::convert::fb_to_schema(ipc_schema);

    Some(schema)
}

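/// Reads messages from `resp`, registering any dictionary batches in
/// `dictionaries_by_id`, until a non-dictionary message arrives; returns that
/// record-batch `FlightData`, or `None` if the stream ends first.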
async fn receive_batch_flight_data(
    resp: &mut Streaming<FlightData>,
    schema: SchemaRef,
    dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
) -> Option<FlightData> {
    let mut data = resp.next().await?.ok()?;
    let mut message =
        arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing first message");

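    // Dictionary batches precede the record batch that references them;
    // consume and register each one before handing back the data message.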
    while message.header_type() == ipc::MessageHeader::DictionaryBatch {
        reader::read_dictionary(
            &Buffer::from(data.data_body.as_ref()),
            message
                .header_as_dictionary_batch()
                .expect("Error parsing dictionary"),
            &schema,
            dictionaries_by_id,
            &message.version(),
        )
        .expect("Error reading dictionary");

        data = resp.next().await?.ok()?;
        message =
            arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing message");
    }

    Some(data)
}