1use base64::prelude::BASE64_STANDARD;
21use base64::Engine;
22use bytes::Bytes;
23use std::collections::HashMap;
24use std::str::FromStr;
25use tonic::metadata::AsciiMetadataKey;
26
27use crate::decode::FlightRecordBatchStream;
28use crate::encode::FlightDataEncoderBuilder;
29use crate::error::FlightError;
30use crate::flight_service_client::FlightServiceClient;
31use crate::sql::gen::action_end_transaction_request::EndTransaction;
32use crate::sql::server::{
33 BEGIN_TRANSACTION, CLOSE_PREPARED_STATEMENT, CREATE_PREPARED_STATEMENT, END_TRANSACTION,
34};
35use crate::sql::{
36 ActionBeginTransactionRequest, ActionBeginTransactionResult,
37 ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
38 ActionCreatePreparedStatementResult, ActionEndTransactionRequest, Any, CommandGetCatalogs,
39 CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys,
40 CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
41 CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate,
42 CommandStatementIngest, CommandStatementQuery, CommandStatementUpdate,
43 DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo,
44};
45use crate::streams::FallibleRequestStream;
46use crate::trailers::extract_lazy_trailers;
47use crate::{
48 Action, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse,
49 IpcMessage, PutResult, Ticket,
50};
51use arrow_array::RecordBatch;
52use arrow_buffer::Buffer;
53use arrow_ipc::convert::fb_to_schema;
54use arrow_ipc::reader::read_record_batch;
55use arrow_ipc::{root_as_message, MessageHeader};
56use arrow_schema::{ArrowError, Schema, SchemaRef};
57use futures::{stream, Stream, TryStreamExt};
58use prost::Message;
59use tonic::transport::Channel;
60use tonic::{IntoRequest, IntoStreamingRequest, Streaming};
61
62#[derive(Debug, Clone)]
65pub struct FlightSqlServiceClient<T> {
66 token: Option<String>,
67 headers: HashMap<String, String>,
68 flight_client: FlightServiceClient<T>,
69}
70
71impl FlightSqlServiceClient<Channel> {
75 pub fn new(channel: Channel) -> Self {
77 Self::new_from_inner(FlightServiceClient::new(channel))
78 }
79
80 pub fn new_from_inner(inner: FlightServiceClient<Channel>) -> Self {
82 Self {
83 token: None,
84 flight_client: inner,
85 headers: HashMap::default(),
86 }
87 }
88
89 pub fn inner(&self) -> &FlightServiceClient<Channel> {
91 &self.flight_client
92 }
93
94 pub fn inner_mut(&mut self) -> &mut FlightServiceClient<Channel> {
96 &mut self.flight_client
97 }
98
99 pub fn into_inner(self) -> FlightServiceClient<Channel> {
101 self.flight_client
102 }
103
104 pub fn set_token(&mut self, token: String) {
106 self.token = Some(token);
107 }
108
109 pub fn clear_token(&mut self) {
111 self.token = None;
112 }
113
114 pub fn token(&self) -> Option<&String> {
116 self.token.as_ref()
117 }
118
119 pub fn set_header(&mut self, key: impl Into<String>, value: impl Into<String>) {
121 let key: String = key.into();
122 let value: String = value.into();
123 self.headers.insert(key, value);
124 }
125
126 async fn get_flight_info_for_command<M: ProstMessageExt>(
127 &mut self,
128 cmd: M,
129 ) -> Result<FlightInfo, ArrowError> {
130 let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
131 let req = self.set_request_headers(descriptor.into_request())?;
132 let fi = self
133 .flight_client
134 .get_flight_info(req)
135 .await
136 .map_err(status_to_arrow_error)?
137 .into_inner();
138 Ok(fi)
139 }
140
141 pub async fn execute(
143 &mut self,
144 query: String,
145 transaction_id: Option<Bytes>,
146 ) -> Result<FlightInfo, ArrowError> {
147 let cmd = CommandStatementQuery {
148 query,
149 transaction_id,
150 };
151 self.get_flight_info_for_command(cmd).await
152 }
153
154 pub async fn handshake(&mut self, username: &str, password: &str) -> Result<Bytes, ArrowError> {
160 let cmd = HandshakeRequest {
161 protocol_version: 0,
162 payload: Default::default(),
163 };
164 let mut req = tonic::Request::new(stream::iter(vec![cmd]));
165 let val = BASE64_STANDARD.encode(format!("{username}:{password}"));
166 let val = format!("Basic {val}")
167 .parse()
168 .map_err(|_| ArrowError::ParseError("Cannot parse header".to_string()))?;
169 req.metadata_mut().insert("authorization", val);
170 let req = self.set_request_headers(req)?;
171 let resp = self
172 .flight_client
173 .handshake(req)
174 .await
175 .map_err(|e| ArrowError::IpcError(format!("Can't handshake {e}")))?;
176 if let Some(auth) = resp.metadata().get("authorization") {
177 let auth = auth
178 .to_str()
179 .map_err(|_| ArrowError::ParseError("Can't read auth header".to_string()))?;
180 let bearer = "Bearer ";
181 if !auth.starts_with(bearer) {
182 Err(ArrowError::ParseError("Invalid auth header!".to_string()))?;
183 }
184 let auth = auth[bearer.len()..].to_string();
185 self.token = Some(auth);
186 }
187 let responses: Vec<HandshakeResponse> = resp
188 .into_inner()
189 .try_collect()
190 .await
191 .map_err(|_| ArrowError::ParseError("Can't collect responses".to_string()))?;
192 let resp = match responses.as_slice() {
193 [resp] => resp.payload.clone(),
194 [] => Bytes::new(),
195 _ => Err(ArrowError::ParseError(
196 "Multiple handshake responses".to_string(),
197 ))?,
198 };
199 Ok(resp)
200 }
201
202 pub async fn execute_update(
204 &mut self,
205 query: String,
206 transaction_id: Option<Bytes>,
207 ) -> Result<i64, ArrowError> {
208 let cmd = CommandStatementUpdate {
209 query,
210 transaction_id,
211 };
212 let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
213 let req = self.set_request_headers(
214 stream::iter(vec![FlightData {
215 flight_descriptor: Some(descriptor),
216 ..Default::default()
217 }])
218 .into_request(),
219 )?;
220 let mut result = self
221 .flight_client
222 .do_put(req)
223 .await
224 .map_err(status_to_arrow_error)?
225 .into_inner();
226 let result = result
227 .message()
228 .await
229 .map_err(status_to_arrow_error)?
230 .unwrap();
231 let result: DoPutUpdateResult =
232 Message::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
233 Ok(result.record_count)
234 }
235
236 pub async fn execute_ingest<S>(
238 &mut self,
239 command: CommandStatementIngest,
240 stream: S,
241 ) -> Result<i64, ArrowError>
242 where
243 S: Stream<Item = crate::error::Result<RecordBatch>> + Send + 'static,
244 {
245 let (sender, receiver) = futures::channel::oneshot::channel();
246
247 let descriptor = FlightDescriptor::new_cmd(command.as_any().encode_to_vec());
248 let flight_data = FlightDataEncoderBuilder::new()
249 .with_flight_descriptor(Some(descriptor))
250 .build(stream);
251
252 let flight_data = Box::pin(flight_data);
254 let flight_data: FallibleRequestStream<FlightData, FlightError> =
255 FallibleRequestStream::new(sender, flight_data);
256
257 let req = self.set_request_headers(flight_data.into_streaming_request())?;
258 let mut result = self
259 .flight_client
260 .do_put(req)
261 .await
262 .map_err(status_to_arrow_error)?
263 .into_inner();
264
265 if let Ok(msg) = receiver.await {
269 return Err(ArrowError::ExternalError(Box::new(msg)));
270 }
271
272 let result = result
273 .message()
274 .await
275 .map_err(status_to_arrow_error)?
276 .unwrap();
277 let result: DoPutUpdateResult =
278 Message::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
279 Ok(result.record_count)
280 }
281
282 pub async fn get_catalogs(&mut self) -> Result<FlightInfo, ArrowError> {
284 self.get_flight_info_for_command(CommandGetCatalogs {})
285 .await
286 }
287
288 pub async fn get_db_schemas(
290 &mut self,
291 request: CommandGetDbSchemas,
292 ) -> Result<FlightInfo, ArrowError> {
293 self.get_flight_info_for_command(request).await
294 }
295
296 pub async fn do_get(
298 &mut self,
299 ticket: impl IntoRequest<Ticket>,
300 ) -> Result<FlightRecordBatchStream, ArrowError> {
301 let req = self.set_request_headers(ticket.into_request())?;
302
303 let (md, response_stream, _ext) = self
304 .flight_client
305 .do_get(req)
306 .await
307 .map_err(status_to_arrow_error)?
308 .into_parts();
309 let (response_stream, trailers) = extract_lazy_trailers(response_stream);
310
311 Ok(FlightRecordBatchStream::new_from_flight_data(
312 response_stream.map_err(FlightError::Tonic),
313 )
314 .with_headers(md)
315 .with_trailers(trailers))
316 }
317
318 pub async fn do_put(
320 &mut self,
321 request: impl tonic::IntoStreamingRequest<Message = FlightData>,
322 ) -> Result<Streaming<PutResult>, ArrowError> {
323 let req = self.set_request_headers(request.into_streaming_request())?;
324 Ok(self
325 .flight_client
326 .do_put(req)
327 .await
328 .map_err(status_to_arrow_error)?
329 .into_inner())
330 }
331
332 pub async fn do_action(
334 &mut self,
335 request: impl IntoRequest<Action>,
336 ) -> Result<Streaming<crate::Result>, ArrowError> {
337 let req = self.set_request_headers(request.into_request())?;
338 Ok(self
339 .flight_client
340 .do_action(req)
341 .await
342 .map_err(status_to_arrow_error)?
343 .into_inner())
344 }
345
346 pub async fn get_tables(
348 &mut self,
349 request: CommandGetTables,
350 ) -> Result<FlightInfo, ArrowError> {
351 self.get_flight_info_for_command(request).await
352 }
353
354 pub async fn get_primary_keys(
356 &mut self,
357 request: CommandGetPrimaryKeys,
358 ) -> Result<FlightInfo, ArrowError> {
359 self.get_flight_info_for_command(request).await
360 }
361
362 pub async fn get_exported_keys(
365 &mut self,
366 request: CommandGetExportedKeys,
367 ) -> Result<FlightInfo, ArrowError> {
368 self.get_flight_info_for_command(request).await
369 }
370
371 pub async fn get_imported_keys(
373 &mut self,
374 request: CommandGetImportedKeys,
375 ) -> Result<FlightInfo, ArrowError> {
376 self.get_flight_info_for_command(request).await
377 }
378
379 pub async fn get_cross_reference(
383 &mut self,
384 request: CommandGetCrossReference,
385 ) -> Result<FlightInfo, ArrowError> {
386 self.get_flight_info_for_command(request).await
387 }
388
389 pub async fn get_table_types(&mut self) -> Result<FlightInfo, ArrowError> {
391 self.get_flight_info_for_command(CommandGetTableTypes {})
392 .await
393 }
394
395 pub async fn get_sql_info(
397 &mut self,
398 sql_infos: Vec<SqlInfo>,
399 ) -> Result<FlightInfo, ArrowError> {
400 let request = CommandGetSqlInfo {
401 info: sql_infos.iter().map(|sql_info| *sql_info as u32).collect(),
402 };
403 self.get_flight_info_for_command(request).await
404 }
405
406 pub async fn get_xdbc_type_info(
408 &mut self,
409 request: CommandGetXdbcTypeInfo,
410 ) -> Result<FlightInfo, ArrowError> {
411 self.get_flight_info_for_command(request).await
412 }
413
414 pub async fn prepare(
416 &mut self,
417 query: String,
418 transaction_id: Option<Bytes>,
419 ) -> Result<PreparedStatement<Channel>, ArrowError> {
420 let cmd = ActionCreatePreparedStatementRequest {
421 query,
422 transaction_id,
423 };
424 let action = Action {
425 r#type: CREATE_PREPARED_STATEMENT.to_string(),
426 body: cmd.as_any().encode_to_vec().into(),
427 };
428 let req = self.set_request_headers(action.into_request())?;
429 let mut result = self
430 .flight_client
431 .do_action(req)
432 .await
433 .map_err(status_to_arrow_error)?
434 .into_inner();
435 let result = result
436 .message()
437 .await
438 .map_err(status_to_arrow_error)?
439 .unwrap();
440 let any = Any::decode(&*result.body).map_err(decode_error_to_arrow_error)?;
441 let prepared_result: ActionCreatePreparedStatementResult = any.unpack()?.unwrap();
442 let dataset_schema = match prepared_result.dataset_schema.len() {
443 0 => Schema::empty(),
444 _ => Schema::try_from(IpcMessage(prepared_result.dataset_schema))?,
445 };
446 let parameter_schema = match prepared_result.parameter_schema.len() {
447 0 => Schema::empty(),
448 _ => Schema::try_from(IpcMessage(prepared_result.parameter_schema))?,
449 };
450 Ok(PreparedStatement::new(
451 self.clone(),
452 prepared_result.prepared_statement_handle,
453 dataset_schema,
454 parameter_schema,
455 ))
456 }
457
458 pub async fn begin_transaction(&mut self) -> Result<Bytes, ArrowError> {
460 let cmd = ActionBeginTransactionRequest {};
461 let action = Action {
462 r#type: BEGIN_TRANSACTION.to_string(),
463 body: cmd.as_any().encode_to_vec().into(),
464 };
465 let req = self.set_request_headers(action.into_request())?;
466 let mut result = self
467 .flight_client
468 .do_action(req)
469 .await
470 .map_err(status_to_arrow_error)?
471 .into_inner();
472 let result = result
473 .message()
474 .await
475 .map_err(status_to_arrow_error)?
476 .unwrap();
477 let any = Any::decode(&*result.body).map_err(decode_error_to_arrow_error)?;
478 let begin_result: ActionBeginTransactionResult = any.unpack()?.unwrap();
479 Ok(begin_result.transaction_id)
480 }
481
482 pub async fn end_transaction(
484 &mut self,
485 transaction_id: Bytes,
486 action: EndTransaction,
487 ) -> Result<(), ArrowError> {
488 let cmd = ActionEndTransactionRequest {
489 transaction_id,
490 action: action as i32,
491 };
492 let action = Action {
493 r#type: END_TRANSACTION.to_string(),
494 body: cmd.as_any().encode_to_vec().into(),
495 };
496 let req = self.set_request_headers(action.into_request())?;
497 let _ = self
498 .flight_client
499 .do_action(req)
500 .await
501 .map_err(status_to_arrow_error)?
502 .into_inner();
503 Ok(())
504 }
505
506 pub async fn close(&mut self) -> Result<(), ArrowError> {
508 Ok(())
510 }
511
512 fn set_request_headers<T>(
513 &self,
514 mut req: tonic::Request<T>,
515 ) -> Result<tonic::Request<T>, ArrowError> {
516 for (k, v) in &self.headers {
517 let k = AsciiMetadataKey::from_str(k.as_str()).map_err(|e| {
518 ArrowError::ParseError(format!("Cannot convert header key \"{k}\": {e}"))
519 })?;
520 let v = v.parse().map_err(|e| {
521 ArrowError::ParseError(format!("Cannot convert header value \"{v}\": {e}"))
522 })?;
523 req.metadata_mut().insert(k, v);
524 }
525 if let Some(token) = &self.token {
526 let val = format!("Bearer {token}").parse().map_err(|e| {
527 ArrowError::ParseError(format!("Cannot convert token to header value: {e}"))
528 })?;
529 req.metadata_mut().insert("authorization", val);
530 }
531 Ok(req)
532 }
533}
534
535#[derive(Debug, Clone)]
537pub struct PreparedStatement<T> {
538 flight_sql_client: FlightSqlServiceClient<T>,
539 parameter_binding: Option<RecordBatch>,
540 handle: Bytes,
541 dataset_schema: Schema,
542 parameter_schema: Schema,
543}
544
545impl PreparedStatement<Channel> {
546 pub(crate) fn new(
547 flight_client: FlightSqlServiceClient<Channel>,
548 handle: impl Into<Bytes>,
549 dataset_schema: Schema,
550 parameter_schema: Schema,
551 ) -> Self {
552 PreparedStatement {
553 flight_sql_client: flight_client,
554 parameter_binding: None,
555 handle: handle.into(),
556 dataset_schema,
557 parameter_schema,
558 }
559 }
560
561 pub async fn execute(&mut self) -> Result<FlightInfo, ArrowError> {
563 self.write_bind_params().await?;
564
565 let cmd = CommandPreparedStatementQuery {
566 prepared_statement_handle: self.handle.clone(),
567 };
568
569 let result = self
570 .flight_sql_client
571 .get_flight_info_for_command(cmd)
572 .await?;
573 Ok(result)
574 }
575
576 pub async fn execute_update(&mut self) -> Result<i64, ArrowError> {
578 self.write_bind_params().await?;
579
580 let cmd = CommandPreparedStatementUpdate {
581 prepared_statement_handle: self.handle.clone(),
582 };
583 let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
584 let mut result = self
585 .flight_sql_client
586 .do_put(stream::iter(vec![FlightData {
587 flight_descriptor: Some(descriptor),
588 ..Default::default()
589 }]))
590 .await?;
591 let result = result
592 .message()
593 .await
594 .map_err(status_to_arrow_error)?
595 .unwrap();
596 let result: DoPutUpdateResult =
597 Message::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
598 Ok(result.record_count)
599 }
600
601 pub fn parameter_schema(&self) -> Result<&Schema, ArrowError> {
603 Ok(&self.parameter_schema)
604 }
605
606 pub fn dataset_schema(&self) -> Result<&Schema, ArrowError> {
608 Ok(&self.dataset_schema)
609 }
610
611 pub fn set_parameters(&mut self, parameter_binding: RecordBatch) -> Result<(), ArrowError> {
613 self.parameter_binding = Some(parameter_binding);
614 Ok(())
615 }
616
617 async fn write_bind_params(&mut self) -> Result<(), ArrowError> {
620 if let Some(ref params_batch) = self.parameter_binding {
621 let cmd = CommandPreparedStatementQuery {
622 prepared_statement_handle: self.handle.clone(),
623 };
624
625 let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
626 let flight_stream_builder = FlightDataEncoderBuilder::new()
627 .with_flight_descriptor(Some(descriptor))
628 .with_schema(params_batch.schema());
629 let flight_data = flight_stream_builder
630 .build(futures::stream::iter(
631 self.parameter_binding.clone().map(Ok),
632 ))
633 .try_collect::<Vec<_>>()
634 .await
635 .map_err(flight_error_to_arrow_error)?;
636
637 if let Some(result) = self
641 .flight_sql_client
642 .do_put(stream::iter(flight_data))
643 .await?
644 .message()
645 .await
646 .map_err(status_to_arrow_error)?
647 {
648 if let Some(handle) = self.unpack_prepared_statement_handle(&result)? {
649 self.handle = handle;
650 }
651 }
652 }
653 Ok(())
654 }
655
656 fn unpack_prepared_statement_handle(
660 &self,
661 put_result: &PutResult,
662 ) -> Result<Option<Bytes>, ArrowError> {
663 let result: DoPutPreparedStatementResult =
664 Message::decode(&*put_result.app_metadata).map_err(decode_error_to_arrow_error)?;
665 Ok(result.prepared_statement_handle)
666 }
667
668 pub async fn close(mut self) -> Result<(), ArrowError> {
671 let cmd = ActionClosePreparedStatementRequest {
672 prepared_statement_handle: self.handle.clone(),
673 };
674 let action = Action {
675 r#type: CLOSE_PREPARED_STATEMENT.to_string(),
676 body: cmd.as_any().encode_to_vec().into(),
677 };
678 let _ = self.flight_sql_client.do_action(action).await?;
679 Ok(())
680 }
681}
682
683fn decode_error_to_arrow_error(err: prost::DecodeError) -> ArrowError {
684 ArrowError::IpcError(err.to_string())
685}
686
687fn status_to_arrow_error(status: tonic::Status) -> ArrowError {
688 ArrowError::IpcError(format!("{status:?}"))
689}
690
691fn flight_error_to_arrow_error(err: FlightError) -> ArrowError {
692 match err {
693 FlightError::Arrow(e) => e,
694 e => ArrowError::ExternalError(Box::new(e)),
695 }
696}
697
698pub enum ArrowFlightData {
700 RecordBatch(RecordBatch),
702 Schema(Schema),
704}
705
706pub fn arrow_data_from_flight_data(
708 flight_data: FlightData,
709 arrow_schema_ref: &SchemaRef,
710) -> Result<ArrowFlightData, ArrowError> {
711 let ipc_message = root_as_message(&flight_data.data_header[..])
712 .map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}")))?;
713
714 match ipc_message.header_type() {
715 MessageHeader::RecordBatch => {
716 let ipc_record_batch = ipc_message.header_as_record_batch().ok_or_else(|| {
717 ArrowError::ComputeError(
718 "Unable to convert flight data header to a record batch".to_string(),
719 )
720 })?;
721
722 let dictionaries_by_field = HashMap::new();
723 let record_batch = read_record_batch(
724 &Buffer::from(flight_data.data_body),
725 ipc_record_batch,
726 arrow_schema_ref.clone(),
727 &dictionaries_by_field,
728 None,
729 &ipc_message.version(),
730 )?;
731 Ok(ArrowFlightData::RecordBatch(record_batch))
732 }
733 MessageHeader::Schema => {
734 let ipc_schema = ipc_message.header_as_schema().ok_or_else(|| {
735 ArrowError::ComputeError(
736 "Unable to convert flight data header to a schema".to_string(),
737 )
738 })?;
739
740 let arrow_schema = fb_to_schema(ipc_schema);
741 Ok(ArrowFlightData::Schema(arrow_schema))
742 }
743 MessageHeader::DictionaryBatch => {
744 let _ = ipc_message.header_as_dictionary_batch().ok_or_else(|| {
745 ArrowError::ComputeError(
746 "Unable to convert flight data header to a dictionary batch".to_string(),
747 )
748 })?;
749 Err(ArrowError::NotYetImplemented(
750 "no idea on how to convert an ipc dictionary batch to an arrow type".to_string(),
751 ))
752 }
753 MessageHeader::Tensor => {
754 let _ = ipc_message.header_as_tensor().ok_or_else(|| {
755 ArrowError::ComputeError(
756 "Unable to convert flight data header to a tensor".to_string(),
757 )
758 })?;
759 Err(ArrowError::NotYetImplemented(
760 "no idea on how to convert an ipc tensor to an arrow type".to_string(),
761 ))
762 }
763 MessageHeader::SparseTensor => {
764 let _ = ipc_message.header_as_sparse_tensor().ok_or_else(|| {
765 ArrowError::ComputeError(
766 "Unable to convert flight data header to a sparse tensor".to_string(),
767 )
768 })?;
769 Err(ArrowError::NotYetImplemented(
770 "no idea on how to convert an ipc sparse tensor to an arrow type".to_string(),
771 ))
772 }
773 _ => Err(ArrowError::ComputeError(format!(
774 "Unable to convert message with header_type: '{:?}' to arrow data",
775 ipc_message.header_type()
776 ))),
777 }
778}