arrow_integration_testing/flight_client_scenarios/integration_test.rs

use crate::open_json_file;
use std::collections::HashMap;

use arrow::{
    array::ArrayRef,
    buffer::Buffer,
    datatypes::SchemaRef,
    ipc::{self, reader, writer},
    record_batch::RecordBatch,
};
use arrow_flight::{
    flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient,
    utils::flight_data_to_arrow_batch, FlightData, FlightDescriptor, IpcMessage, Location, Ticket,
};
use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt};
use tonic::{Request, Streaming};

use arrow::datatypes::Schema;
use std::sync::Arc;

type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
type Result<T = (), E = Error> = std::result::Result<T, E>;

type Client = FlightServiceClient<tonic::transport::Channel>;

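/// Runs the client side of the integration-test scenario: reads the Arrow JSON
/// file at `path`, uploads its record batches to the Flight server at
/// `host:port` via `DoPut`, then reads them back via `GetFlightInfo`/`DoGet`
/// and checks that the data round-trips unchanged.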
pub async fn run_scenario(host: &str, port: u16, path: &str) -> Result {
    let url = format!("http://{host}:{port}");

    let client = FlightServiceClient::connect(url).await?;

    let json_file = open_json_file(path)?;

    let batches = json_file.read_batches()?;
    let schema = Arc::new(json_file.schema);

    let mut descriptor = FlightDescriptor::default();
    descriptor.set_type(DescriptorType::Path);
    descriptor.path = vec![path.to_string()];

    upload_data(client.clone(), schema, descriptor.clone(), batches.clone()).await?;
    verify_data(client, descriptor, &batches).await?;

    Ok(())
}

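/// Uploads `original_data` to the Flight server with `DoPut`, sending the IPC
/// schema first and then one encoded batch per `FlightData` message, and
/// asserts that the server acknowledges each batch by echoing its
/// `app_metadata` back in the corresponding `PutResult`.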
async fn upload_data(
    mut client: Client,
    schema: SchemaRef,
    descriptor: FlightDescriptor,
    original_data: Vec<RecordBatch>,
) -> Result {
    let (mut upload_tx, upload_rx) = mpsc::channel(10);

    let options = arrow::ipc::writer::IpcWriteOptions::default();
    #[allow(deprecated)]
    let mut dict_tracker =
        writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id());
    let data_gen = writer::IpcDataGenerator::default();
    let data = IpcMessage(
        data_gen
            .schema_to_bytes_with_dictionary_tracker(&schema, &mut dict_tracker, &options)
            .ipc_message
            .into(),
    );
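    // The first message on the DoPut stream carries the encoded IPC schema,
    // together with the flight descriptor identifying the upload.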
    let mut schema_flight_data = FlightData {
        data_header: data.0,
        ..Default::default()
    };
    schema_flight_data.flight_descriptor = Some(descriptor.clone());
    upload_tx.send(schema_flight_data).await?;

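    // Each record batch is tagged with its index as `app_metadata`; the server
    // is expected to echo that metadata back in the matching `PutResult`.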
    let mut original_data_iter = original_data.iter().enumerate();

    if let Some((counter, first_batch)) = original_data_iter.next() {
        let metadata = counter.to_string().into_bytes();
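        // The schema message and the first batch are queued into the channel
        // before the DoPut call below is opened.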
        send_batch(
            &mut upload_tx,
            &metadata,
            first_batch,
            &options,
            &mut dict_tracker,
        )
        .await?;

        let outer = client.do_put(Request::new(upload_rx)).await?;
        let mut inner = outer.into_inner();

        let r = inner
            .next()
            .await
            .expect("No response received")
            .expect("Invalid response received");
        assert_eq!(metadata, r.app_metadata);

        for (counter, batch) in original_data_iter {
            let metadata = counter.to_string().into_bytes();
            send_batch(
                &mut upload_tx,
                &metadata,
                batch,
                &options,
                &mut dict_tracker,
            )
            .await?;

            let r = inner
                .next()
                .await
                .expect("No response received")
                .expect("Invalid response received");
            assert_eq!(metadata, r.app_metadata);
        }
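        // Dropping the sender closes the upload stream; the server should not
        // produce any further results after that.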
        drop(upload_tx);
        assert!(
            inner.next().await.is_none(),
            "Should not receive more results"
        );
    } else {
        drop(upload_tx);
        client.do_put(Request::new(upload_rx)).await?;
    }

    Ok(())
}

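/// Encodes `batch` (plus any dictionaries it requires) as IPC `FlightData`
/// messages and sends them on `upload_tx`, attaching `metadata` to the record
/// batch message as `app_metadata`.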
async fn send_batch(
    upload_tx: &mut mpsc::Sender<FlightData>,
    metadata: &[u8],
    batch: &RecordBatch,
    options: &writer::IpcWriteOptions,
    dictionary_tracker: &mut writer::DictionaryTracker,
) -> Result {
    let data_gen = writer::IpcDataGenerator::default();

    let (encoded_dictionaries, encoded_batch) = data_gen
        .encoded_batch(batch, dictionary_tracker, options)
        .expect("DictionaryTracker configured above to not error on replacement");

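    // Dictionary batches must be sent before the record batch that references
    // them.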
    let dictionary_flight_data: Vec<FlightData> =
        encoded_dictionaries.into_iter().map(Into::into).collect();
    let mut batch_flight_data: FlightData = encoded_batch.into();

    upload_tx
        .send_all(&mut stream::iter(dictionary_flight_data).map(Ok))
        .await?;

    batch_flight_data.app_metadata = metadata.to_vec().into();
    upload_tx.send(batch_flight_data).await?;
    Ok(())
}

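/// Asks the server for `FlightInfo` matching `descriptor` and reads the data
/// back from every endpoint and location it advertises, comparing the result
/// against `expected_data`.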
async fn verify_data(
    mut client: Client,
    descriptor: FlightDescriptor,
    expected_data: &[RecordBatch],
) -> Result {
    let resp = client.get_flight_info(Request::new(descriptor)).await?;
    let info = resp.into_inner();

    assert!(
        !info.endpoint.is_empty(),
        "No endpoints returned from Flight server",
    );
    for endpoint in info.endpoint {
        let ticket = endpoint
            .ticket
            .expect("No ticket returned from Flight server");

        assert!(
            !endpoint.location.is_empty(),
            "No locations returned from Flight server",
        );
        for location in endpoint.location {
            consume_flight_location(location, ticket.clone(), expected_data).await?;
        }
    }

    Ok(())
}

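/// Connects to `location`, issues `DoGet` with `ticket`, and checks that the
/// returned schema, batches, and per-batch `app_metadata` match
/// `expected_data`, with no extra batches at the end of the stream.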
async fn consume_flight_location(
    location: Location,
    ticket: Ticket,
    expected_data: &[RecordBatch],
) -> Result {
    let mut location = location;
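    // Other Flight implementations advertise `grpc+tcp://` locations, a scheme
    // the Rust HTTP/gRPC stack does not accept, so rewrite it to `http://`
    // before connecting.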
    location.uri = location.uri.replace("grpc+tcp://", "http://");

    let mut client = FlightServiceClient::connect(location.uri).await?;
    let resp = client.do_get(ticket).await?;
    let mut resp = resp.into_inner();

    let flight_schema = receive_schema_flight_data(&mut resp)
        .await
        .unwrap_or_else(|| panic!("Failed to receive flight schema"));
    let actual_schema = Arc::new(flight_schema);

    let mut dictionaries_by_id = HashMap::new();

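    // Every expected batch must arrive, carry its index in `app_metadata`, and
    // match the original batch column for column.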
    for (counter, expected_batch) in expected_data.iter().enumerate() {
        let data =
            receive_batch_flight_data(&mut resp, actual_schema.clone(), &mut dictionaries_by_id)
                .await
                .unwrap_or_else(|| {
                    panic!(
                        "Got fewer batches than expected, received so far: {} expected: {}",
                        counter,
                        expected_data.len(),
                    )
                });

        let metadata = counter.to_string().into_bytes();
        assert_eq!(metadata, data.app_metadata);

        let actual_batch =
            flight_data_to_arrow_batch(&data, actual_schema.clone(), &dictionaries_by_id)
                .expect("Unable to convert flight data to Arrow batch");

        assert_eq!(actual_schema, actual_batch.schema());
        assert_eq!(expected_batch.num_columns(), actual_batch.num_columns());
        assert_eq!(expected_batch.num_rows(), actual_batch.num_rows());
        let schema = expected_batch.schema();
        for i in 0..expected_batch.num_columns() {
            let field = schema.field(i);
            let field_name = field.name();

            let expected_data = expected_batch.column(i).as_ref();
            let actual_data = actual_batch.column(i).as_ref();

            assert_eq!(expected_data, actual_data, "Data for field {field_name}");
        }
    }

    assert!(
        resp.next().await.is_none(),
        "Got more batches than the expected: {}",
        expected_data.len(),
    );

    Ok(())
}

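/// Reads the first message from a `DoGet` stream and decodes its header as an
/// IPC schema, returning `None` if the stream ends or yields an error first.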
async fn receive_schema_flight_data(resp: &mut Streaming<FlightData>) -> Option<Schema> {
    let data = resp.next().await?.ok()?;
    let message =
        arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing message");

    let ipc_schema: ipc::Schema = message
        .header_as_schema()
        .expect("Unable to read IPC message as schema");
    let schema = ipc::convert::fb_to_schema(ipc_schema);

    Some(schema)
}

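/// Reads messages from a `DoGet` stream, registering any dictionary batches in
/// `dictionaries_by_id`, and returns the next record batch message, or `None`
/// if the stream ends first.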
async fn receive_batch_flight_data(
    resp: &mut Streaming<FlightData>,
    schema: SchemaRef,
    dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
) -> Option<FlightData> {
    let mut data = resp.next().await?.ok()?;
    let mut message =
        arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing first message");

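    // Consume any dictionary batches that precede the record batch, storing
    // them by dictionary id so the batch can be decoded later.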
    while message.header_type() == ipc::MessageHeader::DictionaryBatch {
        reader::read_dictionary(
            &Buffer::from(data.data_body.as_ref()),
            message
                .header_as_dictionary_batch()
                .expect("Error parsing dictionary"),
            &schema,
            dictionaries_by_id,
            &message.version(),
        )
        .expect("Error reading dictionary");

        data = resp.next().await?.ok()?;
        message =
            arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing message");
    }

    Some(data)
}