1use arrow_buffer::Buffer;
21use arrow_ipc::MessageHeader;
22use arrow_ipc::convert::fb_to_schema;
23use arrow_ipc::reader::read_record_batch;
24use arrow_ipc::root_as_message;
25use arrow_schema::SchemaRef;
26use base64::Engine;
27use base64::prelude::BASE64_STANDARD;
28use bytes::Bytes;
29use std::collections::HashMap;
30use std::str::FromStr;
31use tonic::metadata::AsciiMetadataKey;
32
33use crate::decode::FlightRecordBatchStream;
34use crate::encode::FlightDataEncoderBuilder;
35use crate::error::FlightError;
36use crate::error::Result;
37use crate::flight_service_client::FlightServiceClient;
38use crate::sql::r#gen::action_end_transaction_request::EndTransaction;
39use crate::sql::server::{
40 BEGIN_TRANSACTION, CLOSE_PREPARED_STATEMENT, CREATE_PREPARED_STATEMENT, END_TRANSACTION,
41};
42use crate::sql::{
43 ActionBeginTransactionRequest, ActionBeginTransactionResult,
44 ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
45 ActionCreatePreparedStatementResult, ActionEndTransactionRequest, Any, CommandGetCatalogs,
46 CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys,
47 CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
48 CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate,
49 CommandStatementIngest, CommandStatementQuery, CommandStatementUpdate,
50 DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo,
51};
52use crate::streams::FallibleRequestStream;
53use crate::trailers::extract_lazy_trailers;
54use crate::{
55 Action, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse,
56 IpcMessage, PutResult, Ticket,
57};
58use arrow_array::RecordBatch;
59use arrow_schema::{ArrowError, Schema};
60use futures::{Stream, TryStreamExt, stream};
61use prost::Message;
62use tonic::codegen::{Body, StdError};
63use tonic::{IntoRequest, IntoStreamingRequest, Streaming};
64
/// A Flight SQL client layered on top of a [`FlightServiceClient`].
///
/// Besides the underlying Flight client it stores an optional bearer token and a map of
/// extra headers; both are attached to the requests sent through this client's methods.
#[derive(Debug)]
pub struct FlightSqlServiceClient<T> {
    /// Bearer token sent as `authorization: Bearer <token>`, if set.
    token: Option<String>,
    /// Extra request headers (key -> value) added to each request.
    headers: HashMap<String, String>,
    /// The underlying Arrow Flight client that actually issues the RPCs.
    flight_client: FlightServiceClient<T>,
}
73
74impl<T> FlightSqlServiceClient<T>
78where
79 T: tonic::client::GrpcService<tonic::body::Body>,
80 T::Error: Into<StdError>,
81 T::ResponseBody: Body<Data = Bytes> + Send + 'static,
82 <T::ResponseBody as Body>::Error: Into<StdError> + Send,
83{
84 pub fn new(channel: T) -> Self {
86 Self::new_from_inner(FlightServiceClient::new(channel))
87 }
88
89 pub fn new_from_inner(inner: FlightServiceClient<T>) -> Self {
91 Self {
92 token: None,
93 flight_client: inner,
94 headers: HashMap::default(),
95 }
96 }
97
98 pub fn inner(&self) -> &FlightServiceClient<T> {
100 &self.flight_client
101 }
102
103 pub fn inner_mut(&mut self) -> &mut FlightServiceClient<T> {
105 &mut self.flight_client
106 }
107
108 pub fn into_inner(self) -> FlightServiceClient<T> {
110 self.flight_client
111 }
112
113 pub fn set_token(&mut self, token: String) {
115 self.token = Some(token);
116 }
117
118 pub fn clear_token(&mut self) {
120 self.token = None;
121 }
122
123 pub fn token(&self) -> Option<&String> {
125 self.token.as_ref()
126 }
127
128 pub fn set_header(&mut self, key: impl Into<String>, value: impl Into<String>) {
130 let key: String = key.into();
131 let value: String = value.into();
132 self.headers.insert(key, value);
133 }
134
135 async fn get_flight_info_for_command<M: ProstMessageExt>(
136 &mut self,
137 cmd: M,
138 ) -> Result<FlightInfo> {
139 let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
140 let req = self.set_request_headers(descriptor.into_request())?;
141 let fi = self.flight_client.get_flight_info(req).await?.into_inner();
142 Ok(fi)
143 }
144
145 pub async fn execute(
147 &mut self,
148 query: String,
149 transaction_id: Option<Bytes>,
150 ) -> Result<FlightInfo> {
151 let cmd = CommandStatementQuery {
152 query,
153 transaction_id,
154 };
155 self.get_flight_info_for_command(cmd).await
156 }
157
158 pub async fn handshake(&mut self, username: &str, password: &str) -> Result<Bytes> {
164 let cmd = HandshakeRequest {
165 protocol_version: 0,
166 payload: Default::default(),
167 };
168 let mut req = tonic::Request::new(stream::iter(vec![cmd]));
169 let val = BASE64_STANDARD.encode(format!("{username}:{password}"));
170 let val = format!("Basic {val}")
171 .parse()
172 .map_err(|_| ArrowError::ParseError("Cannot parse header".to_string()))?;
173 req.metadata_mut().insert("authorization", val);
174 let req = self.set_request_headers(req)?;
175 let resp = self
176 .flight_client
177 .handshake(req)
178 .await
179 .map_err(|e| ArrowError::IpcError(format!("Can't handshake {e}")))?;
180 if let Some(auth) = resp.metadata().get("authorization") {
181 let auth = auth
182 .to_str()
183 .map_err(|_| ArrowError::ParseError("Can't read auth header".to_string()))?;
184 let bearer = "Bearer ";
185 if !auth.starts_with(bearer) {
186 return Err(ArrowError::ParseError("Invalid auth header!".to_string()))?;
187 }
188 let auth = auth[bearer.len()..].to_string();
189 self.token = Some(auth);
190 }
191 let responses: Vec<HandshakeResponse> = resp
192 .into_inner()
193 .try_collect()
194 .await
195 .map_err(|_| ArrowError::ParseError("Can't collect responses".to_string()))?;
196 let resp = match responses.as_slice() {
197 [resp] => resp.payload.clone(),
198 [] => Bytes::new(),
199 _ => Err(ArrowError::ParseError(
200 "Multiple handshake responses".to_string(),
201 ))?,
202 };
203 Ok(resp)
204 }
205
206 pub async fn execute_update(
208 &mut self,
209 query: String,
210 transaction_id: Option<Bytes>,
211 ) -> Result<i64> {
212 let cmd = CommandStatementUpdate {
213 query,
214 transaction_id,
215 };
216 let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
217 let req = self.set_request_headers(
218 stream::iter(vec![FlightData {
219 flight_descriptor: Some(descriptor),
220 ..Default::default()
221 }])
222 .into_request(),
223 )?;
224 let mut result = self.flight_client.do_put(req).await?.into_inner();
225 let result = result.message().await?.unwrap();
226 let result: DoPutUpdateResult = Message::decode(&*result.app_metadata)?;
227 Ok(result.record_count)
228 }
229
230 pub async fn execute_ingest<S>(
232 &mut self,
233 command: CommandStatementIngest,
234 stream: S,
235 ) -> Result<i64>
236 where
237 S: Stream<Item = crate::error::Result<RecordBatch>> + Send + 'static,
238 {
239 let (sender, receiver) = futures::channel::oneshot::channel();
240
241 let descriptor = FlightDescriptor::new_cmd(command.as_any().encode_to_vec());
242 let flight_data = FlightDataEncoderBuilder::new()
243 .with_flight_descriptor(Some(descriptor))
244 .build(stream);
245
246 let flight_data = Box::pin(flight_data);
248 let flight_data: FallibleRequestStream<FlightData, FlightError> =
249 FallibleRequestStream::new(sender, flight_data);
250
251 let req = self.set_request_headers(flight_data.into_streaming_request())?;
252 let mut result = self.flight_client.do_put(req).await?.into_inner();
253
254 if let Ok(msg) = receiver.await {
258 return Err(FlightError::ExternalError(Box::new(msg)));
259 }
260
261 let result = result.message().await?.unwrap();
262 let result: DoPutUpdateResult = Message::decode(&*result.app_metadata)?;
263 Ok(result.record_count)
264 }
265
266 pub async fn get_catalogs(&mut self) -> Result<FlightInfo> {
268 self.get_flight_info_for_command(CommandGetCatalogs {})
269 .await
270 }
271
272 pub async fn get_db_schemas(&mut self, request: CommandGetDbSchemas) -> Result<FlightInfo> {
274 self.get_flight_info_for_command(request).await
275 }
276
277 pub async fn do_get(
279 &mut self,
280 ticket: impl IntoRequest<Ticket>,
281 ) -> Result<FlightRecordBatchStream> {
282 let req = self.set_request_headers(ticket.into_request())?;
283
284 let (md, response_stream, _ext) = self.flight_client.do_get(req).await?.into_parts();
285 let (response_stream, trailers) = extract_lazy_trailers(response_stream);
286
287 Ok(FlightRecordBatchStream::new_from_flight_data(
288 response_stream.map_err(|status| status.into()),
289 )
290 .with_headers(md)
291 .with_trailers(trailers))
292 }
293
294 pub async fn do_put(
296 &mut self,
297 request: impl tonic::IntoStreamingRequest<Message = FlightData>,
298 ) -> Result<Streaming<PutResult>> {
299 let req = self.set_request_headers(request.into_streaming_request())?;
300 Ok(self.flight_client.do_put(req).await?.into_inner())
301 }
302
303 pub async fn do_action(
305 &mut self,
306 request: impl IntoRequest<Action>,
307 ) -> Result<Streaming<crate::Result>> {
308 let req = self.set_request_headers(request.into_request())?;
309 Ok(self.flight_client.do_action(req).await?.into_inner())
310 }
311
312 pub async fn get_tables(&mut self, request: CommandGetTables) -> Result<FlightInfo> {
314 self.get_flight_info_for_command(request).await
315 }
316
317 pub async fn get_primary_keys(&mut self, request: CommandGetPrimaryKeys) -> Result<FlightInfo> {
319 self.get_flight_info_for_command(request).await
320 }
321
322 pub async fn get_exported_keys(
325 &mut self,
326 request: CommandGetExportedKeys,
327 ) -> Result<FlightInfo> {
328 self.get_flight_info_for_command(request).await
329 }
330
331 pub async fn get_imported_keys(
333 &mut self,
334 request: CommandGetImportedKeys,
335 ) -> Result<FlightInfo> {
336 self.get_flight_info_for_command(request).await
337 }
338
339 pub async fn get_cross_reference(
343 &mut self,
344 request: CommandGetCrossReference,
345 ) -> Result<FlightInfo> {
346 self.get_flight_info_for_command(request).await
347 }
348
349 pub async fn get_table_types(&mut self) -> Result<FlightInfo> {
351 self.get_flight_info_for_command(CommandGetTableTypes {})
352 .await
353 }
354
355 pub async fn get_sql_info(&mut self, sql_infos: Vec<SqlInfo>) -> Result<FlightInfo> {
357 let request = CommandGetSqlInfo {
358 info: sql_infos.iter().map(|sql_info| *sql_info as u32).collect(),
359 };
360 self.get_flight_info_for_command(request).await
361 }
362
363 pub async fn get_xdbc_type_info(
365 &mut self,
366 request: CommandGetXdbcTypeInfo,
367 ) -> Result<FlightInfo> {
368 self.get_flight_info_for_command(request).await
369 }
370
371 pub async fn prepare(
373 &mut self,
374 query: String,
375 transaction_id: Option<Bytes>,
376 ) -> Result<PreparedStatement<T>>
377 where
378 T: Clone,
379 {
380 let cmd = ActionCreatePreparedStatementRequest {
381 query,
382 transaction_id,
383 };
384 let action = Action {
385 r#type: CREATE_PREPARED_STATEMENT.to_string(),
386 body: cmd.as_any().encode_to_vec().into(),
387 };
388 let req = self.set_request_headers(action.into_request())?;
389 let mut result = self.flight_client.do_action(req).await?.into_inner();
390 let result = result.message().await?.unwrap();
391 let any = Any::decode(&*result.body)?;
392 let prepared_result: ActionCreatePreparedStatementResult = any.unpack()?.unwrap();
393 let dataset_schema = match prepared_result.dataset_schema.len() {
394 0 => Schema::empty(),
395 _ => Schema::try_from(IpcMessage(prepared_result.dataset_schema))?,
396 };
397 let parameter_schema = match prepared_result.parameter_schema.len() {
398 0 => Schema::empty(),
399 _ => Schema::try_from(IpcMessage(prepared_result.parameter_schema))?,
400 };
401 Ok(PreparedStatement::new(
402 self.clone(),
403 prepared_result.prepared_statement_handle,
404 dataset_schema,
405 parameter_schema,
406 ))
407 }
408
409 pub async fn begin_transaction(&mut self) -> Result<Bytes> {
411 let cmd = ActionBeginTransactionRequest {};
412 let action = Action {
413 r#type: BEGIN_TRANSACTION.to_string(),
414 body: cmd.as_any().encode_to_vec().into(),
415 };
416 let req = self.set_request_headers(action.into_request())?;
417 let mut result = self.flight_client.do_action(req).await?.into_inner();
418 let result = result.message().await?.unwrap();
419 let any = Any::decode(&*result.body)?;
420 let begin_result: ActionBeginTransactionResult = any.unpack()?.unwrap();
421 Ok(begin_result.transaction_id)
422 }
423
424 pub async fn end_transaction(
426 &mut self,
427 transaction_id: Bytes,
428 action: EndTransaction,
429 ) -> Result<()> {
430 let cmd = ActionEndTransactionRequest {
431 transaction_id,
432 action: action as i32,
433 };
434 let action = Action {
435 r#type: END_TRANSACTION.to_string(),
436 body: cmd.as_any().encode_to_vec().into(),
437 };
438 let req = self.set_request_headers(action.into_request())?;
439 let _ = self.flight_client.do_action(req).await?.into_inner();
440 Ok(())
441 }
442
443 pub async fn close(&mut self) -> Result<()> {
445 Ok(())
447 }
448
449 fn set_request_headers<M>(&self, mut req: tonic::Request<M>) -> Result<tonic::Request<M>> {
450 for (k, v) in &self.headers {
451 let k = AsciiMetadataKey::from_str(k.as_str()).map_err(|e| {
452 ArrowError::ParseError(format!("Cannot convert header key \"{k}\": {e}"))
453 })?;
454 let v = v.parse().map_err(|e| {
455 ArrowError::ParseError(format!("Cannot convert header value \"{v}\": {e}"))
456 })?;
457 req.metadata_mut().insert(k, v);
458 }
459 if let Some(token) = &self.token {
460 let val = format!("Bearer {token}").parse().map_err(|e| {
461 ArrowError::ParseError(format!("Cannot convert token to header value: {e}"))
462 })?;
463 req.metadata_mut().insert("authorization", val);
464 }
465 Ok(req)
466 }
467}
468
469impl<T: Clone> Clone for FlightSqlServiceClient<T> {
470 fn clone(&self) -> Self {
471 Self {
472 headers: self.headers.clone(),
473 token: self.token.clone(),
474 flight_client: self.flight_client.clone(),
475 }
476 }
477}
478
/// A server-side prepared statement, as returned by [`FlightSqlServiceClient::prepare`].
///
/// Holds its own copy of the client, the server-assigned statement handle, the dataset and
/// parameter schemas it was created with, and any parameters bound via `set_parameters`.
#[derive(Debug, Clone)]
pub struct PreparedStatement<T> {
    /// Client used to send this statement's requests.
    flight_sql_client: FlightSqlServiceClient<T>,
    /// Parameters sent to the server before each execution, if any.
    parameter_binding: Option<RecordBatch>,
    /// Opaque server handle for the statement; it may be replaced with a handle returned
    /// by the server when parameters are bound.
    handle: Bytes,
    /// Schema of the statement's result set (empty if the server reported none).
    dataset_schema: Schema,
    /// Schema of the statement's parameters (empty if the server reported none).
    parameter_schema: Schema,
}
488
489impl<T> PreparedStatement<T>
490where
491 T: tonic::client::GrpcService<tonic::body::Body>,
492 T::Error: Into<StdError>,
493 T::ResponseBody: Body<Data = Bytes> + Send + 'static,
494 <T::ResponseBody as Body>::Error: Into<StdError> + Send,
495{
496 pub(crate) fn new(
497 flight_client: FlightSqlServiceClient<T>,
498 handle: impl Into<Bytes>,
499 dataset_schema: Schema,
500 parameter_schema: Schema,
501 ) -> Self {
502 PreparedStatement {
503 flight_sql_client: flight_client,
504 parameter_binding: None,
505 handle: handle.into(),
506 dataset_schema,
507 parameter_schema,
508 }
509 }
510
511 pub async fn execute(&mut self) -> Result<FlightInfo> {
513 self.write_bind_params().await?;
514
515 let cmd = CommandPreparedStatementQuery {
516 prepared_statement_handle: self.handle.clone(),
517 };
518
519 let result = self
520 .flight_sql_client
521 .get_flight_info_for_command(cmd)
522 .await?;
523 Ok(result)
524 }
525
526 pub async fn execute_update(&mut self) -> Result<i64> {
528 self.write_bind_params().await?;
529
530 let cmd = CommandPreparedStatementUpdate {
531 prepared_statement_handle: self.handle.clone(),
532 };
533 let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
534 let mut result = self
535 .flight_sql_client
536 .do_put(stream::iter(vec![FlightData {
537 flight_descriptor: Some(descriptor),
538 ..Default::default()
539 }]))
540 .await?;
541 let result = result.message().await?.unwrap();
542 let result: DoPutUpdateResult = Message::decode(&*result.app_metadata)?;
543 Ok(result.record_count)
544 }
545
546 pub fn parameter_schema(&self) -> Result<&Schema> {
548 Ok(&self.parameter_schema)
549 }
550
551 pub fn dataset_schema(&self) -> Result<&Schema> {
553 Ok(&self.dataset_schema)
554 }
555
556 pub fn set_parameters(&mut self, parameter_binding: RecordBatch) -> Result<()> {
558 self.parameter_binding = Some(parameter_binding);
559 Ok(())
560 }
561
562 async fn write_bind_params(&mut self) -> Result<()> {
565 if let Some(ref params_batch) = self.parameter_binding {
566 let cmd = CommandPreparedStatementQuery {
567 prepared_statement_handle: self.handle.clone(),
568 };
569
570 let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
571 let flight_stream_builder = FlightDataEncoderBuilder::new()
572 .with_flight_descriptor(Some(descriptor))
573 .with_schema(params_batch.schema());
574 let flight_data = flight_stream_builder
575 .build(futures::stream::iter(
576 self.parameter_binding.clone().map(Ok),
577 ))
578 .try_collect::<Vec<_>>()
579 .await?;
580
581 if let Some(result) = self
585 .flight_sql_client
586 .do_put(stream::iter(flight_data))
587 .await?
588 .message()
589 .await?
590 {
591 if let Some(handle) = self.unpack_prepared_statement_handle(&result)? {
592 self.handle = handle;
593 }
594 }
595 }
596 Ok(())
597 }
598
599 fn unpack_prepared_statement_handle(&self, put_result: &PutResult) -> Result<Option<Bytes>> {
603 let result: DoPutPreparedStatementResult = Message::decode(&*put_result.app_metadata)?;
604 Ok(result.prepared_statement_handle)
605 }
606
607 pub async fn close(mut self) -> Result<()> {
610 let cmd = ActionClosePreparedStatementRequest {
611 prepared_statement_handle: self.handle.clone(),
612 };
613 let action = Action {
614 r#type: CLOSE_PREPARED_STATEMENT.to_string(),
615 body: cmd.as_any().encode_to_vec().into(),
616 };
617 let _ = self.flight_sql_client.do_action(action).await?;
618 Ok(())
619 }
620}
621
622pub enum ArrowFlightData {
624 RecordBatch(RecordBatch),
626 Schema(Schema),
628}
629
630pub fn arrow_data_from_flight_data(
632 flight_data: FlightData,
633 arrow_schema_ref: &SchemaRef,
634) -> std::result::Result<ArrowFlightData, ArrowError> {
635 let ipc_message = root_as_message(&flight_data.data_header[..])
636 .map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}")))?;
637
638 match ipc_message.header_type() {
639 MessageHeader::RecordBatch => {
640 let ipc_record_batch = ipc_message.header_as_record_batch().ok_or_else(|| {
641 ArrowError::ComputeError(
642 "Unable to convert flight data header to a record batch".to_string(),
643 )
644 })?;
645
646 let dictionaries_by_field = HashMap::new();
647 let record_batch = read_record_batch(
648 &Buffer::from(flight_data.data_body),
649 ipc_record_batch,
650 arrow_schema_ref.clone(),
651 &dictionaries_by_field,
652 None,
653 &ipc_message.version(),
654 )?;
655 Ok(ArrowFlightData::RecordBatch(record_batch))
656 }
657 MessageHeader::Schema => {
658 let ipc_schema = ipc_message.header_as_schema().ok_or_else(|| {
659 ArrowError::ComputeError(
660 "Unable to convert flight data header to a schema".to_string(),
661 )
662 })?;
663
664 let arrow_schema = fb_to_schema(ipc_schema);
665 Ok(ArrowFlightData::Schema(arrow_schema))
666 }
667 MessageHeader::DictionaryBatch => {
668 let _ = ipc_message.header_as_dictionary_batch().ok_or_else(|| {
669 ArrowError::ComputeError(
670 "Unable to convert flight data header to a dictionary batch".to_string(),
671 )
672 })?;
673 Err(ArrowError::NotYetImplemented(
674 "no idea on how to convert an ipc dictionary batch to an arrow type".to_string(),
675 ))
676 }
677 MessageHeader::Tensor => {
678 let _ = ipc_message.header_as_tensor().ok_or_else(|| {
679 ArrowError::ComputeError(
680 "Unable to convert flight data header to a tensor".to_string(),
681 )
682 })?;
683 Err(ArrowError::NotYetImplemented(
684 "no idea on how to convert an ipc tensor to an arrow type".to_string(),
685 ))
686 }
687 MessageHeader::SparseTensor => {
688 let _ = ipc_message.header_as_sparse_tensor().ok_or_else(|| {
689 ArrowError::ComputeError(
690 "Unable to convert flight data header to a sparse tensor".to_string(),
691 )
692 })?;
693 Err(ArrowError::NotYetImplemented(
694 "no idea on how to convert an ipc sparse tensor to an arrow type".to_string(),
695 ))
696 }
697 _ => Err(ArrowError::ComputeError(format!(
698 "Unable to convert message with header_type: '{:?}' to arrow data",
699 ipc_message.header_type()
700 ))),
701 }
702}