1use std::fmt::{Display, Formatter};
21use std::pin::Pin;
22
23use super::{
24 ActionBeginSavepointRequest, ActionBeginSavepointResult, ActionBeginTransactionRequest,
25 ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult,
26 ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
27 ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest,
28 ActionEndSavepointRequest, ActionEndTransactionRequest, Any, Command, CommandGetCatalogs,
29 CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys,
30 CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
31 CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate,
32 CommandStatementIngest, CommandStatementQuery, CommandStatementSubstraitPlan,
33 CommandStatementUpdate, DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt,
34 SqlInfo, TicketStatementQuery,
35};
36use crate::{
37 flight_service_server::FlightService, gen::PollInfo, Action, ActionType, Criteria, Empty,
38 FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, PutResult,
39 SchemaResult, Ticket,
40};
41use futures::{stream::Peekable, Stream, StreamExt};
42use prost::Message;
43use tonic::{Request, Response, Status, Streaming};
44
/// `Action::type` of the action that creates a prepared statement; matched in `do_action`
/// and advertised by `list_actions`.
pub(crate) static CREATE_PREPARED_STATEMENT: &str = "CreatePreparedStatement";
/// `Action::type` of the action that closes a prepared statement.
pub(crate) static CLOSE_PREPARED_STATEMENT: &str = "ClosePreparedStatement";
/// `Action::type` of the action that creates a prepared Substrait plan.
pub(crate) static CREATE_PREPARED_SUBSTRAIT_PLAN: &str = "CreatePreparedSubstraitPlan";
/// `Action::type` of the action that begins a transaction.
pub(crate) static BEGIN_TRANSACTION: &str = "BeginTransaction";
/// `Action::type` of the action that ends a transaction.
pub(crate) static END_TRANSACTION: &str = "EndTransaction";
/// `Action::type` of the action that begins a savepoint.
pub(crate) static BEGIN_SAVEPOINT: &str = "BeginSavepoint";
/// `Action::type` of the action that ends a savepoint.
pub(crate) static END_SAVEPOINT: &str = "EndSavepoint";
/// `Action::type` of the action that cancels a running query.
pub(crate) static CANCEL_QUERY: &str = "CancelQuery";
53
/// Server-side hooks for the Arrow Flight SQL protocol.
///
/// Every `FlightSqlService` also implements [`FlightService`] through the blanket impl in this
/// module. That impl decodes each incoming Flight request and routes it to the matching method
/// below. Every method except [`FlightSqlService::register_sql_info`] has a default
/// implementation. Most defaults return a `Status::unimplemented` error, so implementors only
/// override the commands they support.
#[tonic::async_trait]
pub trait FlightSqlService: Sync + Send + Sized + 'static {
    /// Associated Flight service type.
    ///
    /// NOTE(review): not referenced by any default method or by the blanket `FlightService`
    /// impl in this file. Confirm its intended use.
    type FlightService: FlightService;

    /// Handles the Flight `Handshake` call; `FlightService::handshake` delegates here.
    ///
    /// Default: `Unimplemented`.
    async fn do_handshake(
        &self,
        _request: Request<Streaming<HandshakeRequest>>,
    ) -> Result<
        Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> + Send>>>,
        Status,
    > {
        Err(Status::unimplemented(
            "Handshake has no default implementation",
        ))
    }

    /// Called by `do_get` when the ticket's command has no dedicated `do_get_*` handler.
    ///
    /// `message` is the decoded command re-packed as `Any`. The default returns
    /// `Unimplemented` naming its `type_url`.
    async fn do_get_fallback(
        &self,
        _request: Request<Ticket>,
        message: Any,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
        Err(Status::unimplemented(format!(
            "do_get: The defined request is invalid: {}",
            message.type_url
        )))
    }

    /// Returns the `FlightInfo` for a `CommandStatementQuery` descriptor.
    ///
    /// Called by `get_flight_info`. Default: `Unimplemented`.
    async fn get_flight_info_statement(
        &self,
        _query: CommandStatementQuery,
        _request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        Err(Status::unimplemented(
            "get_flight_info_statement has no default implementation",
        ))
    }

    /// Returns the `FlightInfo` for a `CommandStatementSubstraitPlan` descriptor.
    ///
    /// Called by `get_flight_info`. Default: `Unimplemented`.
    async fn get_flight_info_substrait_plan(
        &self,
        _query: CommandStatementSubstraitPlan,
        _request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        Err(Status::unimplemented(
            "get_flight_info_substrait_plan has no default implementation",
        ))
    }

    /// Returns the `FlightInfo` for a `CommandPreparedStatementQuery` descriptor.
    ///
    /// Called by `get_flight_info`. Default: `Unimplemented`.
    async fn get_flight_info_prepared_statement(
        &self,
        _query: CommandPreparedStatementQuery,
        _request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        Err(Status::unimplemented(
            "get_flight_info_prepared_statement has no default implementation",
        ))
    }

    /// Returns the `FlightInfo` for a `CommandGetCatalogs` descriptor.
    ///
    /// Called by `get_flight_info`. Default: `Unimplemented`.
    async fn get_flight_info_catalogs(
        &self,
        _query: CommandGetCatalogs,
        _request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        Err(Status::unimplemented(
            "get_flight_info_catalogs has no default implementation",
        ))
    }

    /// Returns the `FlightInfo` for a `CommandGetDbSchemas` descriptor.
    ///
    /// Called by `get_flight_info`. Default: `Unimplemented`.
    async fn get_flight_info_schemas(
        &self,
        _query: CommandGetDbSchemas,
        _request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        Err(Status::unimplemented(
            "get_flight_info_schemas has no default implementation",
        ))
    }

    /// Returns the `FlightInfo` for a `CommandGetTables` descriptor.
    ///
    /// Called by `get_flight_info`. Default: `Unimplemented`.
    async fn get_flight_info_tables(
        &self,
        _query: CommandGetTables,
        _request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        Err(Status::unimplemented(
            "get_flight_info_tables has no default implementation",
        ))
    }

    /// Returns the `FlightInfo` for a `CommandGetTableTypes` descriptor.
    ///
    /// Called by `get_flight_info`. Default: `Unimplemented`.
    async fn get_flight_info_table_types(
        &self,
        _query: CommandGetTableTypes,
        _request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        Err(Status::unimplemented(
            "get_flight_info_table_types has no default implementation",
        ))
    }

    /// Returns the `FlightInfo` for a `CommandGetSqlInfo` descriptor.
    ///
    /// Called by `get_flight_info`. Default: `Unimplemented`.
    async fn get_flight_info_sql_info(
        &self,
        _query: CommandGetSqlInfo,
        _request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        Err(Status::unimplemented(
            "get_flight_info_sql_info has no default implementation",
        ))
    }

    /// Returns the `FlightInfo` for a `CommandGetPrimaryKeys` descriptor.
    ///
    /// Called by `get_flight_info`. Default: `Unimplemented`.
    async fn get_flight_info_primary_keys(
        &self,
        _query: CommandGetPrimaryKeys,
        _request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        Err(Status::unimplemented(
            "get_flight_info_primary_keys has no default implementation",
        ))
    }

    /// Returns the `FlightInfo` for a `CommandGetExportedKeys` descriptor.
    ///
    /// Called by `get_flight_info`. Default: `Unimplemented`.
    async fn get_flight_info_exported_keys(
        &self,
        _query: CommandGetExportedKeys,
        _request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        Err(Status::unimplemented(
            "get_flight_info_exported_keys has no default implementation",
        ))
    }

    /// Returns the `FlightInfo` for a `CommandGetImportedKeys` descriptor.
    ///
    /// Called by `get_flight_info`. Default: `Unimplemented`.
    async fn get_flight_info_imported_keys(
        &self,
        _query: CommandGetImportedKeys,
        _request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        Err(Status::unimplemented(
            "get_flight_info_imported_keys has no default implementation",
        ))
    }

    /// Returns the `FlightInfo` for a `CommandGetCrossReference` descriptor.
    ///
    /// Called by `get_flight_info`. Default: `Unimplemented`.
    async fn get_flight_info_cross_reference(
        &self,
        _query: CommandGetCrossReference,
        _request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        Err(Status::unimplemented(
            "get_flight_info_cross_reference has no default implementation",
        ))
    }

    /// Returns the `FlightInfo` for a `CommandGetXdbcTypeInfo` descriptor.
    ///
    /// Called by `get_flight_info`. Default: `Unimplemented`.
    async fn get_flight_info_xdbc_type_info(
        &self,
        _query: CommandGetXdbcTypeInfo,
        _request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        Err(Status::unimplemented(
            "get_flight_info_xdbc_type_info has no default implementation",
        ))
    }

    /// Called by `get_flight_info` for descriptor commands without a dedicated handler.
    ///
    /// Default: `Unimplemented`, naming `cmd.type_url()`.
    async fn get_flight_info_fallback(
        &self,
        cmd: Command,
        _request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        Err(Status::unimplemented(format!(
            "get_flight_info: The defined request is invalid: {}",
            cmd.type_url()
        )))
    }

    /// Streams the results for a `TicketStatementQuery` ticket.
    ///
    /// Called by `do_get`. Default: `Unimplemented`.
    async fn do_get_statement(
        &self,
        _ticket: TicketStatementQuery,
        _request: Request<Ticket>,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
        Err(Status::unimplemented(
            "do_get_statement has no default implementation",
        ))
    }

    /// Streams the results for a `CommandPreparedStatementQuery` ticket.
    ///
    /// Called by `do_get`. Default: `Unimplemented`.
    async fn do_get_prepared_statement(
        &self,
        _query: CommandPreparedStatementQuery,
        _request: Request<Ticket>,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
        Err(Status::unimplemented(
            "do_get_prepared_statement has no default implementation",
        ))
    }

    /// Streams the results for a `CommandGetCatalogs` ticket.
    ///
    /// Called by `do_get`. Default: `Unimplemented`.
    async fn do_get_catalogs(
        &self,
        _query: CommandGetCatalogs,
        _request: Request<Ticket>,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
        Err(Status::unimplemented(
            "do_get_catalogs has no default implementation",
        ))
    }

    /// Streams the results for a `CommandGetDbSchemas` ticket.
    ///
    /// Called by `do_get`. Default: `Unimplemented`.
    async fn do_get_schemas(
        &self,
        _query: CommandGetDbSchemas,
        _request: Request<Ticket>,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
        Err(Status::unimplemented(
            "do_get_schemas has no default implementation",
        ))
    }

    /// Streams the results for a `CommandGetTables` ticket.
    ///
    /// Called by `do_get`. Default: `Unimplemented`.
    async fn do_get_tables(
        &self,
        _query: CommandGetTables,
        _request: Request<Ticket>,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
        Err(Status::unimplemented(
            "do_get_tables has no default implementation",
        ))
    }

    /// Streams the results for a `CommandGetTableTypes` ticket.
    ///
    /// Called by `do_get`. Default: `Unimplemented`.
    async fn do_get_table_types(
        &self,
        _query: CommandGetTableTypes,
        _request: Request<Ticket>,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
        Err(Status::unimplemented(
            "do_get_table_types has no default implementation",
        ))
    }

    /// Streams the results for a `CommandGetSqlInfo` ticket.
    ///
    /// Called by `do_get`. Default: `Unimplemented`.
    async fn do_get_sql_info(
        &self,
        _query: CommandGetSqlInfo,
        _request: Request<Ticket>,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
        Err(Status::unimplemented(
            "do_get_sql_info has no default implementation",
        ))
    }

    /// Streams the results for a `CommandGetPrimaryKeys` ticket.
    ///
    /// Called by `do_get`. Default: `Unimplemented`.
    async fn do_get_primary_keys(
        &self,
        _query: CommandGetPrimaryKeys,
        _request: Request<Ticket>,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
        Err(Status::unimplemented(
            "do_get_primary_keys has no default implementation",
        ))
    }

    /// Streams the results for a `CommandGetExportedKeys` ticket.
    ///
    /// Called by `do_get`. Default: `Unimplemented`.
    async fn do_get_exported_keys(
        &self,
        _query: CommandGetExportedKeys,
        _request: Request<Ticket>,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
        Err(Status::unimplemented(
            "do_get_exported_keys has no default implementation",
        ))
    }

    /// Streams the results for a `CommandGetImportedKeys` ticket.
    ///
    /// Called by `do_get`. Default: `Unimplemented`.
    async fn do_get_imported_keys(
        &self,
        _query: CommandGetImportedKeys,
        _request: Request<Ticket>,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
        Err(Status::unimplemented(
            "do_get_imported_keys has no default implementation",
        ))
    }

    /// Streams the results for a `CommandGetCrossReference` ticket.
    ///
    /// Called by `do_get`. Default: `Unimplemented`.
    async fn do_get_cross_reference(
        &self,
        _query: CommandGetCrossReference,
        _request: Request<Ticket>,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
        Err(Status::unimplemented(
            "do_get_cross_reference has no default implementation",
        ))
    }

    /// Streams the results for a `CommandGetXdbcTypeInfo` ticket.
    ///
    /// Called by `do_get`. Default: `Unimplemented`.
    async fn do_get_xdbc_type_info(
        &self,
        _query: CommandGetXdbcTypeInfo,
        _request: Request<Ticket>,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
        Err(Status::unimplemented(
            "do_get_xdbc_type_info has no default implementation",
        ))
    }

    /// Called by `do_put` for descriptor commands without a dedicated `do_put_*` handler.
    ///
    /// `message` is the decoded command re-packed as `Any`. The default returns
    /// `Unimplemented` naming its `type_url`.
    async fn do_put_fallback(
        &self,
        _request: Request<PeekableFlightDataStream>,
        message: Any,
    ) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
        Err(Status::unimplemented(format!(
            "do_put: The defined request is invalid: {}",
            message.type_url
        )))
    }

    /// Called by `do_put` when the stream cannot be dispatched.
    ///
    /// This happens when the stream yields no message (`DoPutError::MissingCommand`) or when
    /// its first message has no flight descriptor (`DoPutError::MissingFlightDescriptor`).
    /// Default: `Unimplemented` carrying the error's `Display` text.
    async fn do_put_error_callback(
        &self,
        _request: Request<PeekableFlightDataStream>,
        error: DoPutError,
    ) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
        Err(Status::unimplemented(format!("Unhandled Error: {}", error)))
    }

    /// Executes a `CommandStatementUpdate` and returns the affected record count.
    ///
    /// `do_put` sends the count back to the client as a `DoPutUpdateResult`.
    /// Default: `Unimplemented`.
    async fn do_put_statement_update(
        &self,
        _ticket: CommandStatementUpdate,
        _request: Request<PeekableFlightDataStream>,
    ) -> Result<i64, Status> {
        Err(Status::unimplemented(
            "do_put_statement_update has no default implementation",
        ))
    }

    /// Ingests the uploaded data for a `CommandStatementIngest` and returns the record count.
    ///
    /// `do_put` sends the count back to the client as a `DoPutUpdateResult`.
    /// Default: `Unimplemented`.
    async fn do_put_statement_ingest(
        &self,
        _ticket: CommandStatementIngest,
        _request: Request<PeekableFlightDataStream>,
    ) -> Result<i64, Status> {
        Err(Status::unimplemented(
            "do_put_statement_ingest has no default implementation",
        ))
    }

    /// Receives bound parameters for a `CommandPreparedStatementQuery`.
    ///
    /// `do_put` encodes the returned `DoPutPreparedStatementResult` into the single
    /// `PutResult::app_metadata` it sends back. Default: `Unimplemented`.
    async fn do_put_prepared_statement_query(
        &self,
        _query: CommandPreparedStatementQuery,
        _request: Request<PeekableFlightDataStream>,
    ) -> Result<DoPutPreparedStatementResult, Status> {
        Err(Status::unimplemented(
            "do_put_prepared_statement_query has no default implementation",
        ))
    }

    /// Executes a `CommandPreparedStatementUpdate` and returns the affected record count.
    ///
    /// `do_put` sends the count back to the client as a `DoPutUpdateResult`.
    /// Default: `Unimplemented`.
    async fn do_put_prepared_statement_update(
        &self,
        _query: CommandPreparedStatementUpdate,
        _request: Request<PeekableFlightDataStream>,
    ) -> Result<i64, Status> {
        Err(Status::unimplemented(
            "do_put_prepared_statement_update has no default implementation",
        ))
    }

    /// Executes a `CommandStatementSubstraitPlan` received via `DoPut` and returns the
    /// affected record count.
    ///
    /// `do_put` sends the count back to the client as a `DoPutUpdateResult`.
    /// Default: `Unimplemented`.
    async fn do_put_substrait_plan(
        &self,
        _query: CommandStatementSubstraitPlan,
        _request: Request<PeekableFlightDataStream>,
    ) -> Result<i64, Status> {
        Err(Status::unimplemented(
            "do_put_substrait_plan has no default implementation",
        ))
    }

    /// Called by `do_action` for action types that are not built-in Flight SQL actions.
    ///
    /// Unlike most defaults here, this returns `InvalidArgument` (not `Unimplemented`) and
    /// names the action type.
    async fn do_action_fallback(
        &self,
        request: Request<Action>,
    ) -> Result<Response<<Self as FlightService>::DoActionStream>, Status> {
        Err(Status::invalid_argument(format!(
            "do_action: The defined request is invalid: {:?}",
            request.get_ref().r#type
        )))
    }

    /// Extra action types that `list_actions` appends after the built-in Flight SQL actions.
    ///
    /// Default: `None` (no custom actions).
    async fn list_custom_actions(&self) -> Option<Vec<Result<ActionType, Status>>> {
        None
    }

    /// Handles the `CreatePreparedStatement` action.
    ///
    /// `do_action` returns the result to the client packed as `Any`. Default: `Unimplemented`.
    async fn do_action_create_prepared_statement(
        &self,
        _query: ActionCreatePreparedStatementRequest,
        _request: Request<Action>,
    ) -> Result<ActionCreatePreparedStatementResult, Status> {
        Err(Status::unimplemented(
            "do_action_create_prepared_statement has no default implementation",
        ))
    }

    /// Handles the `ClosePreparedStatement` action.
    ///
    /// On success `do_action` replies with an empty stream. Default: `Unimplemented`.
    async fn do_action_close_prepared_statement(
        &self,
        _query: ActionClosePreparedStatementRequest,
        _request: Request<Action>,
    ) -> Result<(), Status> {
        Err(Status::unimplemented(
            "do_action_close_prepared_statement has no default implementation",
        ))
    }

    /// Handles the `CreatePreparedSubstraitPlan` action. Default: `Unimplemented`.
    async fn do_action_create_prepared_substrait_plan(
        &self,
        _query: ActionCreatePreparedSubstraitPlanRequest,
        _request: Request<Action>,
    ) -> Result<ActionCreatePreparedStatementResult, Status> {
        Err(Status::unimplemented(
            "do_action_create_prepared_substrait_plan has no default implementation",
        ))
    }

    /// Handles the `BeginTransaction` action.
    ///
    /// `do_action` returns the result to the client packed as `Any`. Default: `Unimplemented`.
    async fn do_action_begin_transaction(
        &self,
        _query: ActionBeginTransactionRequest,
        _request: Request<Action>,
    ) -> Result<ActionBeginTransactionResult, Status> {
        Err(Status::unimplemented(
            "do_action_begin_transaction has no default implementation",
        ))
    }

    /// Handles the `EndTransaction` action.
    ///
    /// On success `do_action` replies with an empty stream. Default: `Unimplemented`.
    async fn do_action_end_transaction(
        &self,
        _query: ActionEndTransactionRequest,
        _request: Request<Action>,
    ) -> Result<(), Status> {
        Err(Status::unimplemented(
            "do_action_end_transaction has no default implementation",
        ))
    }

    /// Handles the `BeginSavepoint` action.
    ///
    /// `do_action` returns the result to the client packed as `Any`. Default: `Unimplemented`.
    async fn do_action_begin_savepoint(
        &self,
        _query: ActionBeginSavepointRequest,
        _request: Request<Action>,
    ) -> Result<ActionBeginSavepointResult, Status> {
        Err(Status::unimplemented(
            "do_action_begin_savepoint has no default implementation",
        ))
    }

    /// Handles the `EndSavepoint` action.
    ///
    /// On success `do_action` replies with an empty stream. Default: `Unimplemented`.
    async fn do_action_end_savepoint(
        &self,
        _query: ActionEndSavepointRequest,
        _request: Request<Action>,
    ) -> Result<(), Status> {
        Err(Status::unimplemented(
            "do_action_end_savepoint has no default implementation",
        ))
    }

    /// Handles the `CancelQuery` action.
    ///
    /// `do_action` returns the result to the client packed as `Any`. Default: `Unimplemented`.
    async fn do_action_cancel_query(
        &self,
        _query: ActionCancelQueryRequest,
        _request: Request<Action>,
    ) -> Result<ActionCancelQueryResult, Status> {
        Err(Status::unimplemented(
            "do_action_cancel_query has no default implementation",
        ))
    }

    /// Handles `DoExchange`; `FlightService::do_exchange` delegates here unconditionally.
    ///
    /// Default: `Unimplemented`.
    async fn do_exchange_fallback(
        &self,
        _request: Request<Streaming<FlightData>>,
    ) -> Result<Response<<Self as FlightService>::DoExchangeStream>, Status> {
        Err(Status::unimplemented("Not yet implemented"))
    }

    /// Required method with no default implementation.
    ///
    /// NOTE(review): presumably registers the `SqlInfo` value `result` under the numeric
    /// info id `id`. No caller is visible in this file, so confirm the contract with callers.
    async fn register_sql_info(&self, id: i32, result: &SqlInfo);
}
575
576#[tonic::async_trait]
578impl<T: 'static> FlightService for T
579where
580 T: FlightSqlService + Send,
581{
582 type HandshakeStream =
583 Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> + Send + 'static>>;
584 type ListFlightsStream =
585 Pin<Box<dyn Stream<Item = Result<FlightInfo, Status>> + Send + 'static>>;
586 type DoGetStream = Pin<Box<dyn Stream<Item = Result<FlightData, Status>> + Send + 'static>>;
587 type DoPutStream = Pin<Box<dyn Stream<Item = Result<PutResult, Status>> + Send + 'static>>;
588 type DoActionStream =
589 Pin<Box<dyn Stream<Item = Result<super::super::Result, Status>> + Send + 'static>>;
590 type ListActionsStream =
591 Pin<Box<dyn Stream<Item = Result<ActionType, Status>> + Send + 'static>>;
592 type DoExchangeStream =
593 Pin<Box<dyn Stream<Item = Result<FlightData, Status>> + Send + 'static>>;
594
595 async fn handshake(
596 &self,
597 request: Request<Streaming<HandshakeRequest>>,
598 ) -> Result<Response<Self::HandshakeStream>, Status> {
599 let res = self.do_handshake(request).await?;
600 Ok(res)
601 }
602
603 async fn list_flights(
604 &self,
605 _request: Request<Criteria>,
606 ) -> Result<Response<Self::ListFlightsStream>, Status> {
607 Err(Status::unimplemented("Not yet implemented"))
608 }
609
610 async fn get_flight_info(
611 &self,
612 request: Request<FlightDescriptor>,
613 ) -> Result<Response<FlightInfo>, Status> {
614 let message = Any::decode(&*request.get_ref().cmd).map_err(decode_error_to_status)?;
615
616 match Command::try_from(message).map_err(arrow_error_to_status)? {
617 Command::CommandStatementQuery(token) => {
618 self.get_flight_info_statement(token, request).await
619 }
620 Command::CommandPreparedStatementQuery(handle) => {
621 self.get_flight_info_prepared_statement(handle, request)
622 .await
623 }
624 Command::CommandStatementSubstraitPlan(handle) => {
625 self.get_flight_info_substrait_plan(handle, request).await
626 }
627 Command::CommandGetCatalogs(token) => {
628 self.get_flight_info_catalogs(token, request).await
629 }
630 Command::CommandGetDbSchemas(token) => {
631 return self.get_flight_info_schemas(token, request).await
632 }
633 Command::CommandGetTables(token) => self.get_flight_info_tables(token, request).await,
634 Command::CommandGetTableTypes(token) => {
635 self.get_flight_info_table_types(token, request).await
636 }
637 Command::CommandGetSqlInfo(token) => {
638 self.get_flight_info_sql_info(token, request).await
639 }
640 Command::CommandGetPrimaryKeys(token) => {
641 self.get_flight_info_primary_keys(token, request).await
642 }
643 Command::CommandGetExportedKeys(token) => {
644 self.get_flight_info_exported_keys(token, request).await
645 }
646 Command::CommandGetImportedKeys(token) => {
647 self.get_flight_info_imported_keys(token, request).await
648 }
649 Command::CommandGetCrossReference(token) => {
650 self.get_flight_info_cross_reference(token, request).await
651 }
652 Command::CommandGetXdbcTypeInfo(token) => {
653 self.get_flight_info_xdbc_type_info(token, request).await
654 }
655 cmd => self.get_flight_info_fallback(cmd, request).await,
656 }
657 }
658
659 async fn poll_flight_info(
660 &self,
661 _request: Request<FlightDescriptor>,
662 ) -> Result<Response<PollInfo>, Status> {
663 Err(Status::unimplemented("Not yet implemented"))
664 }
665
666 async fn get_schema(
667 &self,
668 _request: Request<FlightDescriptor>,
669 ) -> Result<Response<SchemaResult>, Status> {
670 Err(Status::unimplemented("Not yet implemented"))
671 }
672
673 async fn do_get(
674 &self,
675 request: Request<Ticket>,
676 ) -> Result<Response<Self::DoGetStream>, Status> {
677 let msg: Any =
678 Message::decode(&*request.get_ref().ticket).map_err(decode_error_to_status)?;
679
680 match Command::try_from(msg).map_err(arrow_error_to_status)? {
681 Command::TicketStatementQuery(command) => self.do_get_statement(command, request).await,
682 Command::CommandPreparedStatementQuery(command) => {
683 self.do_get_prepared_statement(command, request).await
684 }
685 Command::CommandGetCatalogs(command) => self.do_get_catalogs(command, request).await,
686 Command::CommandGetDbSchemas(command) => self.do_get_schemas(command, request).await,
687 Command::CommandGetTables(command) => self.do_get_tables(command, request).await,
688 Command::CommandGetTableTypes(command) => {
689 self.do_get_table_types(command, request).await
690 }
691 Command::CommandGetSqlInfo(command) => self.do_get_sql_info(command, request).await,
692 Command::CommandGetPrimaryKeys(command) => {
693 self.do_get_primary_keys(command, request).await
694 }
695 Command::CommandGetExportedKeys(command) => {
696 self.do_get_exported_keys(command, request).await
697 }
698 Command::CommandGetImportedKeys(command) => {
699 self.do_get_imported_keys(command, request).await
700 }
701 Command::CommandGetCrossReference(command) => {
702 self.do_get_cross_reference(command, request).await
703 }
704 Command::CommandGetXdbcTypeInfo(command) => {
705 self.do_get_xdbc_type_info(command, request).await
706 }
707 cmd => self.do_get_fallback(request, cmd.into_any()).await,
708 }
709 }
710
711 async fn do_put(
712 &self,
713 request: Request<Streaming<FlightData>>,
714 ) -> Result<Response<Self::DoPutStream>, Status> {
715 let mut request = request.map(PeekableFlightDataStream::new);
722 let mut stream = Pin::new(request.get_mut());
723
724 let peeked_item = stream.peek().await.cloned();
725 let Some(cmd) = peeked_item else {
726 return self
727 .do_put_error_callback(request, DoPutError::MissingCommand)
728 .await;
729 };
730
731 let Some(flight_descriptor) = cmd?.flight_descriptor else {
732 return self
733 .do_put_error_callback(request, DoPutError::MissingFlightDescriptor)
734 .await;
735 };
736 let message = Any::decode(flight_descriptor.cmd).map_err(decode_error_to_status)?;
737 match Command::try_from(message).map_err(arrow_error_to_status)? {
738 Command::CommandStatementUpdate(command) => {
739 let record_count = self.do_put_statement_update(command, request).await?;
740 let result = DoPutUpdateResult { record_count };
741 let output = futures::stream::iter(vec![Ok(PutResult {
742 app_metadata: result.encode_to_vec().into(),
743 })]);
744 Ok(Response::new(Box::pin(output)))
745 }
746 Command::CommandStatementIngest(command) => {
747 let record_count = self.do_put_statement_ingest(command, request).await?;
748 let result = DoPutUpdateResult { record_count };
749 let output = futures::stream::iter(vec![Ok(PutResult {
750 app_metadata: result.encode_to_vec().into(),
751 })]);
752 Ok(Response::new(Box::pin(output)))
753 }
754 Command::CommandPreparedStatementQuery(command) => {
755 let result = self
756 .do_put_prepared_statement_query(command, request)
757 .await?;
758 let output = futures::stream::iter(vec![Ok(PutResult {
759 app_metadata: result.encode_to_vec().into(),
760 })]);
761 Ok(Response::new(Box::pin(output)))
762 }
763 Command::CommandStatementSubstraitPlan(command) => {
764 let record_count = self.do_put_substrait_plan(command, request).await?;
765 let result = DoPutUpdateResult { record_count };
766 let output = futures::stream::iter(vec![Ok(PutResult {
767 app_metadata: result.encode_to_vec().into(),
768 })]);
769 Ok(Response::new(Box::pin(output)))
770 }
771 Command::CommandPreparedStatementUpdate(command) => {
772 let record_count = self
773 .do_put_prepared_statement_update(command, request)
774 .await?;
775 let result = DoPutUpdateResult { record_count };
776 let output = futures::stream::iter(vec![Ok(PutResult {
777 app_metadata: result.encode_to_vec().into(),
778 })]);
779 Ok(Response::new(Box::pin(output)))
780 }
781 cmd => self.do_put_fallback(request, cmd.into_any()).await,
782 }
783 }
784
785 async fn list_actions(
786 &self,
787 _request: Request<Empty>,
788 ) -> Result<Response<Self::ListActionsStream>, Status> {
789 let create_prepared_statement_action_type = ActionType {
790 r#type: CREATE_PREPARED_STATEMENT.to_string(),
791 description: "Creates a reusable prepared statement resource on the server.\n
792 Request Message: ActionCreatePreparedStatementRequest\n
793 Response Message: ActionCreatePreparedStatementResult"
794 .into(),
795 };
796 let close_prepared_statement_action_type = ActionType {
797 r#type: CLOSE_PREPARED_STATEMENT.to_string(),
798 description: "Closes a reusable prepared statement resource on the server.\n
799 Request Message: ActionClosePreparedStatementRequest\n
800 Response Message: N/A"
801 .into(),
802 };
803 let create_prepared_substrait_plan_action_type = ActionType {
804 r#type: CREATE_PREPARED_SUBSTRAIT_PLAN.to_string(),
805 description: "Creates a reusable prepared substrait plan resource on the server.\n
806 Request Message: ActionCreatePreparedSubstraitPlanRequest\n
807 Response Message: ActionCreatePreparedStatementResult"
808 .into(),
809 };
810 let begin_transaction_action_type = ActionType {
811 r#type: BEGIN_TRANSACTION.to_string(),
812 description: "Begins a transaction.\n
813 Request Message: ActionBeginTransactionRequest\n
814 Response Message: ActionBeginTransactionResult"
815 .into(),
816 };
817 let end_transaction_action_type = ActionType {
818 r#type: END_TRANSACTION.to_string(),
819 description: "Ends a transaction\n
820 Request Message: ActionEndTransactionRequest\n
821 Response Message: N/A"
822 .into(),
823 };
824 let begin_savepoint_action_type = ActionType {
825 r#type: BEGIN_SAVEPOINT.to_string(),
826 description: "Begins a savepoint.\n
827 Request Message: ActionBeginSavepointRequest\n
828 Response Message: ActionBeginSavepointResult"
829 .into(),
830 };
831 let end_savepoint_action_type = ActionType {
832 r#type: END_SAVEPOINT.to_string(),
833 description: "Ends a savepoint\n
834 Request Message: ActionEndSavepointRequest\n
835 Response Message: N/A"
836 .into(),
837 };
838 let cancel_query_action_type = ActionType {
839 r#type: CANCEL_QUERY.to_string(),
840 description: "Cancels a query\n
841 Request Message: ActionCancelQueryRequest\n
842 Response Message: ActionCancelQueryResult"
843 .into(),
844 };
845 let mut actions: Vec<Result<ActionType, Status>> = vec![
846 Ok(create_prepared_statement_action_type),
847 Ok(close_prepared_statement_action_type),
848 Ok(create_prepared_substrait_plan_action_type),
849 Ok(begin_transaction_action_type),
850 Ok(end_transaction_action_type),
851 Ok(begin_savepoint_action_type),
852 Ok(end_savepoint_action_type),
853 Ok(cancel_query_action_type),
854 ];
855
856 if let Some(mut custom_actions) = self.list_custom_actions().await {
857 actions.append(&mut custom_actions);
858 }
859
860 let output = futures::stream::iter(actions);
861 Ok(Response::new(Box::pin(output) as Self::ListActionsStream))
862 }
863
864 async fn do_action(
865 &self,
866 request: Request<Action>,
867 ) -> Result<Response<Self::DoActionStream>, Status> {
868 if request.get_ref().r#type == CREATE_PREPARED_STATEMENT {
869 let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
870
871 let cmd: ActionCreatePreparedStatementRequest = any
872 .unpack()
873 .map_err(arrow_error_to_status)?
874 .ok_or_else(|| {
875 Status::invalid_argument(
876 "Unable to unpack ActionCreatePreparedStatementRequest.",
877 )
878 })?;
879 let stmt = self
880 .do_action_create_prepared_statement(cmd, request)
881 .await?;
882 let output = futures::stream::iter(vec![Ok(super::super::gen::Result {
883 body: stmt.as_any().encode_to_vec().into(),
884 })]);
885 return Ok(Response::new(Box::pin(output)));
886 } else if request.get_ref().r#type == CLOSE_PREPARED_STATEMENT {
887 let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
888
889 let cmd: ActionClosePreparedStatementRequest = any
890 .unpack()
891 .map_err(arrow_error_to_status)?
892 .ok_or_else(|| {
893 Status::invalid_argument(
894 "Unable to unpack ActionClosePreparedStatementRequest.",
895 )
896 })?;
897 self.do_action_close_prepared_statement(cmd, request)
898 .await?;
899 return Ok(Response::new(Box::pin(futures::stream::empty())));
900 } else if request.get_ref().r#type == CREATE_PREPARED_SUBSTRAIT_PLAN {
901 let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
902
903 let cmd: ActionCreatePreparedSubstraitPlanRequest = any
904 .unpack()
905 .map_err(arrow_error_to_status)?
906 .ok_or_else(|| {
907 Status::invalid_argument(
908 "Unable to unpack ActionCreatePreparedSubstraitPlanRequest.",
909 )
910 })?;
911 self.do_action_create_prepared_substrait_plan(cmd, request)
912 .await?;
913 return Ok(Response::new(Box::pin(futures::stream::empty())));
914 } else if request.get_ref().r#type == BEGIN_TRANSACTION {
915 let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
916
917 let cmd: ActionBeginTransactionRequest = any
918 .unpack()
919 .map_err(arrow_error_to_status)?
920 .ok_or_else(|| {
921 Status::invalid_argument("Unable to unpack ActionBeginTransactionRequest.")
922 })?;
923 let stmt = self.do_action_begin_transaction(cmd, request).await?;
924 let output = futures::stream::iter(vec![Ok(super::super::gen::Result {
925 body: stmt.as_any().encode_to_vec().into(),
926 })]);
927 return Ok(Response::new(Box::pin(output)));
928 } else if request.get_ref().r#type == END_TRANSACTION {
929 let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
930
931 let cmd: ActionEndTransactionRequest = any
932 .unpack()
933 .map_err(arrow_error_to_status)?
934 .ok_or_else(|| {
935 Status::invalid_argument("Unable to unpack ActionEndTransactionRequest.")
936 })?;
937 self.do_action_end_transaction(cmd, request).await?;
938 return Ok(Response::new(Box::pin(futures::stream::empty())));
939 } else if request.get_ref().r#type == BEGIN_SAVEPOINT {
940 let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
941
942 let cmd: ActionBeginSavepointRequest = any
943 .unpack()
944 .map_err(arrow_error_to_status)?
945 .ok_or_else(|| {
946 Status::invalid_argument("Unable to unpack ActionBeginSavepointRequest.")
947 })?;
948 let stmt = self.do_action_begin_savepoint(cmd, request).await?;
949 let output = futures::stream::iter(vec![Ok(super::super::gen::Result {
950 body: stmt.as_any().encode_to_vec().into(),
951 })]);
952 return Ok(Response::new(Box::pin(output)));
953 } else if request.get_ref().r#type == END_SAVEPOINT {
954 let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
955
956 let cmd: ActionEndSavepointRequest = any
957 .unpack()
958 .map_err(arrow_error_to_status)?
959 .ok_or_else(|| {
960 Status::invalid_argument("Unable to unpack ActionEndSavepointRequest.")
961 })?;
962 self.do_action_end_savepoint(cmd, request).await?;
963 return Ok(Response::new(Box::pin(futures::stream::empty())));
964 } else if request.get_ref().r#type == CANCEL_QUERY {
965 let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
966
967 let cmd: ActionCancelQueryRequest = any
968 .unpack()
969 .map_err(arrow_error_to_status)?
970 .ok_or_else(|| {
971 Status::invalid_argument("Unable to unpack ActionCancelQueryRequest.")
972 })?;
973 let stmt = self.do_action_cancel_query(cmd, request).await?;
974 let output = futures::stream::iter(vec![Ok(super::super::gen::Result {
975 body: stmt.as_any().encode_to_vec().into(),
976 })]);
977 return Ok(Response::new(Box::pin(output)));
978 }
979
980 self.do_action_fallback(request).await
981 }
982
983 async fn do_exchange(
984 &self,
985 request: Request<Streaming<FlightData>>,
986 ) -> Result<Response<Self::DoExchangeStream>, Status> {
987 self.do_exchange_fallback(request).await
988 }
989}
990
/// Reasons `do_put` cannot dispatch an incoming stream to a Flight SQL handler.
///
/// Passed to `FlightSqlService::do_put_error_callback`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DoPutError {
    /// The `DoPut` stream ended before yielding any message.
    MissingCommand,
    /// The first message of the `DoPut` stream carried no `FlightDescriptor`.
    MissingFlightDescriptor,
}

impl Display for DoPutError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            DoPutError::MissingCommand => "Command is missing.",
            DoPutError::MissingFlightDescriptor => "Flight descriptor is missing.",
        };
        f.write_str(message)
    }
}

/// Lets `DoPutError` be used with `?` and boxed as `dyn std::error::Error`.
impl std::error::Error for DoPutError {}
1010
1011fn decode_error_to_status(err: prost::DecodeError) -> Status {
1012 Status::invalid_argument(format!("{err:?}"))
1013}
1014
1015fn arrow_error_to_status(err: arrow_schema::ArrowError) -> Status {
1016 Status::internal(format!("{err:?}"))
1017}
1018
1019pub struct PeekableFlightDataStream {
1065 inner: Peekable<Streaming<FlightData>>,
1066}
1067
1068impl PeekableFlightDataStream {
1069 fn new(stream: Streaming<FlightData>) -> Self {
1070 Self {
1071 inner: stream.peekable(),
1072 }
1073 }
1074
1075 pub fn into_inner(self) -> Streaming<FlightData> {
1079 self.inner.into_inner()
1080 }
1081
1082 pub fn into_peekable(self) -> Peekable<Streaming<FlightData>> {
1086 self.inner
1087 }
1088
1089 pub async fn peek(&mut self) -> Option<&Result<FlightData, Status>> {
1091 Pin::new(&mut self.inner).peek().await
1092 }
1093}
1094
1095impl Stream for PeekableFlightDataStream {
1096 type Item = Result<FlightData, Status>;
1097
1098 fn poll_next(
1099 mut self: Pin<&mut Self>,
1100 cx: &mut std::task::Context<'_>,
1101 ) -> std::task::Poll<Option<Self::Item>> {
1102 self.inner.poll_next_unpin(cx)
1103 }
1104}