1use std::pin::Pin;
21
22use futures::{stream::Peekable, Stream, StreamExt};
23use prost::Message;
24use tonic::{Request, Response, Status, Streaming};
25
26use super::{
27 ActionBeginSavepointRequest, ActionBeginSavepointResult, ActionBeginTransactionRequest,
28 ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult,
29 ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
30 ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest,
31 ActionEndSavepointRequest, ActionEndTransactionRequest, Any, Command, CommandGetCatalogs,
32 CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys,
33 CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
34 CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate,
35 CommandStatementIngest, CommandStatementQuery, CommandStatementSubstraitPlan,
36 CommandStatementUpdate, DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt,
37 SqlInfo, TicketStatementQuery,
38};
39use crate::{
40 flight_service_server::FlightService, gen::PollInfo, Action, ActionType, Criteria, Empty,
41 FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, PutResult,
42 SchemaResult, Ticket,
43};
44
/// `Action::type` string matched by `do_action` for creating a prepared statement.
pub(crate) static CREATE_PREPARED_STATEMENT: &str = "CreatePreparedStatement";

/// `Action::type` string matched by `do_action` for closing a prepared statement.
pub(crate) static CLOSE_PREPARED_STATEMENT: &str = "ClosePreparedStatement";

/// `Action::type` string matched by `do_action` for preparing a Substrait plan.
pub(crate) static CREATE_PREPARED_SUBSTRAIT_PLAN: &str = "CreatePreparedSubstraitPlan";

/// `Action::type` string matched by `do_action` for beginning a transaction.
pub(crate) static BEGIN_TRANSACTION: &str = "BeginTransaction";

/// `Action::type` string matched by `do_action` for ending a transaction.
pub(crate) static END_TRANSACTION: &str = "EndTransaction";

/// `Action::type` string matched by `do_action` for beginning a savepoint.
pub(crate) static BEGIN_SAVEPOINT: &str = "BeginSavepoint";

/// `Action::type` string matched by `do_action` for ending a savepoint.
pub(crate) static END_SAVEPOINT: &str = "EndSavepoint";

/// `Action::type` string matched by `do_action` for cancelling a query.
pub(crate) static CANCEL_QUERY: &str = "CancelQuery";
53
54#[tonic::async_trait]
56pub trait FlightSqlService: Sync + Send + Sized + 'static {
57 type FlightService: FlightService;
59
60 async fn do_handshake(
63 &self,
64 _request: Request<Streaming<HandshakeRequest>>,
65 ) -> Result<
66 Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> + Send>>>,
67 Status,
68 > {
69 Err(Status::unimplemented(
70 "Handshake has no default implementation",
71 ))
72 }
73
74 async fn do_get_fallback(
76 &self,
77 _request: Request<Ticket>,
78 message: Any,
79 ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
80 Err(Status::unimplemented(format!(
81 "do_get: The defined request is invalid: {}",
82 message.type_url
83 )))
84 }
85
86 async fn get_flight_info_statement(
88 &self,
89 _query: CommandStatementQuery,
90 _request: Request<FlightDescriptor>,
91 ) -> Result<Response<FlightInfo>, Status> {
92 Err(Status::unimplemented(
93 "get_flight_info_statement has no default implementation",
94 ))
95 }
96
97 async fn get_flight_info_substrait_plan(
99 &self,
100 _query: CommandStatementSubstraitPlan,
101 _request: Request<FlightDescriptor>,
102 ) -> Result<Response<FlightInfo>, Status> {
103 Err(Status::unimplemented(
104 "get_flight_info_substrait_plan has no default implementation",
105 ))
106 }
107
108 async fn get_flight_info_prepared_statement(
110 &self,
111 _query: CommandPreparedStatementQuery,
112 _request: Request<FlightDescriptor>,
113 ) -> Result<Response<FlightInfo>, Status> {
114 Err(Status::unimplemented(
115 "get_flight_info_prepared_statement has no default implementation",
116 ))
117 }
118
119 async fn get_flight_info_catalogs(
121 &self,
122 _query: CommandGetCatalogs,
123 _request: Request<FlightDescriptor>,
124 ) -> Result<Response<FlightInfo>, Status> {
125 Err(Status::unimplemented(
126 "get_flight_info_catalogs has no default implementation",
127 ))
128 }
129
130 async fn get_flight_info_schemas(
132 &self,
133 _query: CommandGetDbSchemas,
134 _request: Request<FlightDescriptor>,
135 ) -> Result<Response<FlightInfo>, Status> {
136 Err(Status::unimplemented(
137 "get_flight_info_schemas has no default implementation",
138 ))
139 }
140
141 async fn get_flight_info_tables(
143 &self,
144 _query: CommandGetTables,
145 _request: Request<FlightDescriptor>,
146 ) -> Result<Response<FlightInfo>, Status> {
147 Err(Status::unimplemented(
148 "get_flight_info_tables has no default implementation",
149 ))
150 }
151
152 async fn get_flight_info_table_types(
154 &self,
155 _query: CommandGetTableTypes,
156 _request: Request<FlightDescriptor>,
157 ) -> Result<Response<FlightInfo>, Status> {
158 Err(Status::unimplemented(
159 "get_flight_info_table_types has no default implementation",
160 ))
161 }
162
163 async fn get_flight_info_sql_info(
165 &self,
166 _query: CommandGetSqlInfo,
167 _request: Request<FlightDescriptor>,
168 ) -> Result<Response<FlightInfo>, Status> {
169 Err(Status::unimplemented(
170 "get_flight_info_sql_info has no default implementation",
171 ))
172 }
173
174 async fn get_flight_info_primary_keys(
176 &self,
177 _query: CommandGetPrimaryKeys,
178 _request: Request<FlightDescriptor>,
179 ) -> Result<Response<FlightInfo>, Status> {
180 Err(Status::unimplemented(
181 "get_flight_info_primary_keys has no default implementation",
182 ))
183 }
184
185 async fn get_flight_info_exported_keys(
187 &self,
188 _query: CommandGetExportedKeys,
189 _request: Request<FlightDescriptor>,
190 ) -> Result<Response<FlightInfo>, Status> {
191 Err(Status::unimplemented(
192 "get_flight_info_exported_keys has no default implementation",
193 ))
194 }
195
196 async fn get_flight_info_imported_keys(
198 &self,
199 _query: CommandGetImportedKeys,
200 _request: Request<FlightDescriptor>,
201 ) -> Result<Response<FlightInfo>, Status> {
202 Err(Status::unimplemented(
203 "get_flight_info_imported_keys has no default implementation",
204 ))
205 }
206
207 async fn get_flight_info_cross_reference(
209 &self,
210 _query: CommandGetCrossReference,
211 _request: Request<FlightDescriptor>,
212 ) -> Result<Response<FlightInfo>, Status> {
213 Err(Status::unimplemented(
214 "get_flight_info_cross_reference has no default implementation",
215 ))
216 }
217
218 async fn get_flight_info_xdbc_type_info(
220 &self,
221 _query: CommandGetXdbcTypeInfo,
222 _request: Request<FlightDescriptor>,
223 ) -> Result<Response<FlightInfo>, Status> {
224 Err(Status::unimplemented(
225 "get_flight_info_xdbc_type_info has no default implementation",
226 ))
227 }
228
229 async fn get_flight_info_fallback(
231 &self,
232 cmd: Command,
233 _request: Request<FlightDescriptor>,
234 ) -> Result<Response<FlightInfo>, Status> {
235 Err(Status::unimplemented(format!(
236 "get_flight_info: The defined request is invalid: {}",
237 cmd.type_url()
238 )))
239 }
240
241 async fn do_get_statement(
245 &self,
246 _ticket: TicketStatementQuery,
247 _request: Request<Ticket>,
248 ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
249 Err(Status::unimplemented(
250 "do_get_statement has no default implementation",
251 ))
252 }
253
254 async fn do_get_prepared_statement(
256 &self,
257 _query: CommandPreparedStatementQuery,
258 _request: Request<Ticket>,
259 ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
260 Err(Status::unimplemented(
261 "do_get_prepared_statement has no default implementation",
262 ))
263 }
264
265 async fn do_get_catalogs(
267 &self,
268 _query: CommandGetCatalogs,
269 _request: Request<Ticket>,
270 ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
271 Err(Status::unimplemented(
272 "do_get_catalogs has no default implementation",
273 ))
274 }
275
276 async fn do_get_schemas(
278 &self,
279 _query: CommandGetDbSchemas,
280 _request: Request<Ticket>,
281 ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
282 Err(Status::unimplemented(
283 "do_get_schemas has no default implementation",
284 ))
285 }
286
287 async fn do_get_tables(
289 &self,
290 _query: CommandGetTables,
291 _request: Request<Ticket>,
292 ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
293 Err(Status::unimplemented(
294 "do_get_tables has no default implementation",
295 ))
296 }
297
298 async fn do_get_table_types(
300 &self,
301 _query: CommandGetTableTypes,
302 _request: Request<Ticket>,
303 ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
304 Err(Status::unimplemented(
305 "do_get_table_types has no default implementation",
306 ))
307 }
308
309 async fn do_get_sql_info(
311 &self,
312 _query: CommandGetSqlInfo,
313 _request: Request<Ticket>,
314 ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
315 Err(Status::unimplemented(
316 "do_get_sql_info has no default implementation",
317 ))
318 }
319
320 async fn do_get_primary_keys(
322 &self,
323 _query: CommandGetPrimaryKeys,
324 _request: Request<Ticket>,
325 ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
326 Err(Status::unimplemented(
327 "do_get_primary_keys has no default implementation",
328 ))
329 }
330
331 async fn do_get_exported_keys(
333 &self,
334 _query: CommandGetExportedKeys,
335 _request: Request<Ticket>,
336 ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
337 Err(Status::unimplemented(
338 "do_get_exported_keys has no default implementation",
339 ))
340 }
341
342 async fn do_get_imported_keys(
344 &self,
345 _query: CommandGetImportedKeys,
346 _request: Request<Ticket>,
347 ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
348 Err(Status::unimplemented(
349 "do_get_imported_keys has no default implementation",
350 ))
351 }
352
353 async fn do_get_cross_reference(
355 &self,
356 _query: CommandGetCrossReference,
357 _request: Request<Ticket>,
358 ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
359 Err(Status::unimplemented(
360 "do_get_cross_reference has no default implementation",
361 ))
362 }
363
364 async fn do_get_xdbc_type_info(
366 &self,
367 _query: CommandGetXdbcTypeInfo,
368 _request: Request<Ticket>,
369 ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
370 Err(Status::unimplemented(
371 "do_get_xdbc_type_info has no default implementation",
372 ))
373 }
374
375 async fn do_put_fallback(
379 &self,
380 _request: Request<PeekableFlightDataStream>,
381 message: Any,
382 ) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
383 Err(Status::unimplemented(format!(
384 "do_put: The defined request is invalid: {}",
385 message.type_url
386 )))
387 }
388
389 async fn do_put_statement_update(
391 &self,
392 _ticket: CommandStatementUpdate,
393 _request: Request<PeekableFlightDataStream>,
394 ) -> Result<i64, Status> {
395 Err(Status::unimplemented(
396 "do_put_statement_update has no default implementation",
397 ))
398 }
399
400 async fn do_put_statement_ingest(
402 &self,
403 _ticket: CommandStatementIngest,
404 _request: Request<PeekableFlightDataStream>,
405 ) -> Result<i64, Status> {
406 Err(Status::unimplemented(
407 "do_put_statement_ingest has no default implementation",
408 ))
409 }
410
411 async fn do_put_prepared_statement_query(
417 &self,
418 _query: CommandPreparedStatementQuery,
419 _request: Request<PeekableFlightDataStream>,
420 ) -> Result<DoPutPreparedStatementResult, Status> {
421 Err(Status::unimplemented(
422 "do_put_prepared_statement_query has no default implementation",
423 ))
424 }
425
426 async fn do_put_prepared_statement_update(
428 &self,
429 _query: CommandPreparedStatementUpdate,
430 _request: Request<PeekableFlightDataStream>,
431 ) -> Result<i64, Status> {
432 Err(Status::unimplemented(
433 "do_put_prepared_statement_update has no default implementation",
434 ))
435 }
436
437 async fn do_put_substrait_plan(
439 &self,
440 _query: CommandStatementSubstraitPlan,
441 _request: Request<PeekableFlightDataStream>,
442 ) -> Result<i64, Status> {
443 Err(Status::unimplemented(
444 "do_put_substrait_plan has no default implementation",
445 ))
446 }
447
448 async fn do_action_fallback(
452 &self,
453 request: Request<Action>,
454 ) -> Result<Response<<Self as FlightService>::DoActionStream>, Status> {
455 Err(Status::invalid_argument(format!(
456 "do_action: The defined request is invalid: {:?}",
457 request.get_ref().r#type
458 )))
459 }
460
461 async fn list_custom_actions(&self) -> Option<Vec<Result<ActionType, Status>>> {
463 None
464 }
465
466 async fn do_action_create_prepared_statement(
468 &self,
469 _query: ActionCreatePreparedStatementRequest,
470 _request: Request<Action>,
471 ) -> Result<ActionCreatePreparedStatementResult, Status> {
472 Err(Status::unimplemented(
473 "do_action_create_prepared_statement has no default implementation",
474 ))
475 }
476
477 async fn do_action_close_prepared_statement(
479 &self,
480 _query: ActionClosePreparedStatementRequest,
481 _request: Request<Action>,
482 ) -> Result<(), Status> {
483 Err(Status::unimplemented(
484 "do_action_close_prepared_statement has no default implementation",
485 ))
486 }
487
488 async fn do_action_create_prepared_substrait_plan(
490 &self,
491 _query: ActionCreatePreparedSubstraitPlanRequest,
492 _request: Request<Action>,
493 ) -> Result<ActionCreatePreparedStatementResult, Status> {
494 Err(Status::unimplemented(
495 "do_action_create_prepared_substrait_plan has no default implementation",
496 ))
497 }
498
499 async fn do_action_begin_transaction(
501 &self,
502 _query: ActionBeginTransactionRequest,
503 _request: Request<Action>,
504 ) -> Result<ActionBeginTransactionResult, Status> {
505 Err(Status::unimplemented(
506 "do_action_begin_transaction has no default implementation",
507 ))
508 }
509
510 async fn do_action_end_transaction(
512 &self,
513 _query: ActionEndTransactionRequest,
514 _request: Request<Action>,
515 ) -> Result<(), Status> {
516 Err(Status::unimplemented(
517 "do_action_end_transaction has no default implementation",
518 ))
519 }
520
521 async fn do_action_begin_savepoint(
523 &self,
524 _query: ActionBeginSavepointRequest,
525 _request: Request<Action>,
526 ) -> Result<ActionBeginSavepointResult, Status> {
527 Err(Status::unimplemented(
528 "do_action_begin_savepoint has no default implementation",
529 ))
530 }
531
532 async fn do_action_end_savepoint(
534 &self,
535 _query: ActionEndSavepointRequest,
536 _request: Request<Action>,
537 ) -> Result<(), Status> {
538 Err(Status::unimplemented(
539 "do_action_end_savepoint has no default implementation",
540 ))
541 }
542
543 async fn do_action_cancel_query(
545 &self,
546 _query: ActionCancelQueryRequest,
547 _request: Request<Action>,
548 ) -> Result<ActionCancelQueryResult, Status> {
549 Err(Status::unimplemented(
550 "do_action_cancel_query has no default implementation",
551 ))
552 }
553
554 async fn do_exchange_fallback(
557 &self,
558 _request: Request<Streaming<FlightData>>,
559 ) -> Result<Response<<Self as FlightService>::DoExchangeStream>, Status> {
560 Err(Status::unimplemented("Not yet implemented"))
561 }
562
563 async fn register_sql_info(&self, id: i32, result: &SqlInfo);
565}
566
567#[tonic::async_trait]
569impl<T: 'static> FlightService for T
570where
571 T: FlightSqlService + Send,
572{
573 type HandshakeStream =
574 Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> + Send + 'static>>;
575 type ListFlightsStream =
576 Pin<Box<dyn Stream<Item = Result<FlightInfo, Status>> + Send + 'static>>;
577 type DoGetStream = Pin<Box<dyn Stream<Item = Result<FlightData, Status>> + Send + 'static>>;
578 type DoPutStream = Pin<Box<dyn Stream<Item = Result<PutResult, Status>> + Send + 'static>>;
579 type DoActionStream =
580 Pin<Box<dyn Stream<Item = Result<super::super::Result, Status>> + Send + 'static>>;
581 type ListActionsStream =
582 Pin<Box<dyn Stream<Item = Result<ActionType, Status>> + Send + 'static>>;
583 type DoExchangeStream =
584 Pin<Box<dyn Stream<Item = Result<FlightData, Status>> + Send + 'static>>;
585
586 async fn handshake(
587 &self,
588 request: Request<Streaming<HandshakeRequest>>,
589 ) -> Result<Response<Self::HandshakeStream>, Status> {
590 let res = self.do_handshake(request).await?;
591 Ok(res)
592 }
593
594 async fn list_flights(
595 &self,
596 _request: Request<Criteria>,
597 ) -> Result<Response<Self::ListFlightsStream>, Status> {
598 Err(Status::unimplemented("Not yet implemented"))
599 }
600
601 async fn get_flight_info(
602 &self,
603 request: Request<FlightDescriptor>,
604 ) -> Result<Response<FlightInfo>, Status> {
605 let message = Any::decode(&*request.get_ref().cmd).map_err(decode_error_to_status)?;
606
607 match Command::try_from(message).map_err(arrow_error_to_status)? {
608 Command::CommandStatementQuery(token) => {
609 self.get_flight_info_statement(token, request).await
610 }
611 Command::CommandPreparedStatementQuery(handle) => {
612 self.get_flight_info_prepared_statement(handle, request)
613 .await
614 }
615 Command::CommandStatementSubstraitPlan(handle) => {
616 self.get_flight_info_substrait_plan(handle, request).await
617 }
618 Command::CommandGetCatalogs(token) => {
619 self.get_flight_info_catalogs(token, request).await
620 }
621 Command::CommandGetDbSchemas(token) => {
622 return self.get_flight_info_schemas(token, request).await
623 }
624 Command::CommandGetTables(token) => self.get_flight_info_tables(token, request).await,
625 Command::CommandGetTableTypes(token) => {
626 self.get_flight_info_table_types(token, request).await
627 }
628 Command::CommandGetSqlInfo(token) => {
629 self.get_flight_info_sql_info(token, request).await
630 }
631 Command::CommandGetPrimaryKeys(token) => {
632 self.get_flight_info_primary_keys(token, request).await
633 }
634 Command::CommandGetExportedKeys(token) => {
635 self.get_flight_info_exported_keys(token, request).await
636 }
637 Command::CommandGetImportedKeys(token) => {
638 self.get_flight_info_imported_keys(token, request).await
639 }
640 Command::CommandGetCrossReference(token) => {
641 self.get_flight_info_cross_reference(token, request).await
642 }
643 Command::CommandGetXdbcTypeInfo(token) => {
644 self.get_flight_info_xdbc_type_info(token, request).await
645 }
646 cmd => self.get_flight_info_fallback(cmd, request).await,
647 }
648 }
649
650 async fn poll_flight_info(
651 &self,
652 _request: Request<FlightDescriptor>,
653 ) -> Result<Response<PollInfo>, Status> {
654 Err(Status::unimplemented("Not yet implemented"))
655 }
656
657 async fn get_schema(
658 &self,
659 _request: Request<FlightDescriptor>,
660 ) -> Result<Response<SchemaResult>, Status> {
661 Err(Status::unimplemented("Not yet implemented"))
662 }
663
664 async fn do_get(
665 &self,
666 request: Request<Ticket>,
667 ) -> Result<Response<Self::DoGetStream>, Status> {
668 let msg: Any =
669 Message::decode(&*request.get_ref().ticket).map_err(decode_error_to_status)?;
670
671 match Command::try_from(msg).map_err(arrow_error_to_status)? {
672 Command::TicketStatementQuery(command) => self.do_get_statement(command, request).await,
673 Command::CommandPreparedStatementQuery(command) => {
674 self.do_get_prepared_statement(command, request).await
675 }
676 Command::CommandGetCatalogs(command) => self.do_get_catalogs(command, request).await,
677 Command::CommandGetDbSchemas(command) => self.do_get_schemas(command, request).await,
678 Command::CommandGetTables(command) => self.do_get_tables(command, request).await,
679 Command::CommandGetTableTypes(command) => {
680 self.do_get_table_types(command, request).await
681 }
682 Command::CommandGetSqlInfo(command) => self.do_get_sql_info(command, request).await,
683 Command::CommandGetPrimaryKeys(command) => {
684 self.do_get_primary_keys(command, request).await
685 }
686 Command::CommandGetExportedKeys(command) => {
687 self.do_get_exported_keys(command, request).await
688 }
689 Command::CommandGetImportedKeys(command) => {
690 self.do_get_imported_keys(command, request).await
691 }
692 Command::CommandGetCrossReference(command) => {
693 self.do_get_cross_reference(command, request).await
694 }
695 Command::CommandGetXdbcTypeInfo(command) => {
696 self.do_get_xdbc_type_info(command, request).await
697 }
698 cmd => self.do_get_fallback(request, cmd.into_any()).await,
699 }
700 }
701
702 async fn do_put(
703 &self,
704 request: Request<Streaming<FlightData>>,
705 ) -> Result<Response<Self::DoPutStream>, Status> {
706 let mut request = request.map(PeekableFlightDataStream::new);
713 let cmd = Pin::new(request.get_mut()).peek().await.unwrap().clone()?;
714
715 let message =
716 Any::decode(&*cmd.flight_descriptor.unwrap().cmd).map_err(decode_error_to_status)?;
717 match Command::try_from(message).map_err(arrow_error_to_status)? {
718 Command::CommandStatementUpdate(command) => {
719 let record_count = self.do_put_statement_update(command, request).await?;
720 let result = DoPutUpdateResult { record_count };
721 let output = futures::stream::iter(vec![Ok(PutResult {
722 app_metadata: result.encode_to_vec().into(),
723 })]);
724 Ok(Response::new(Box::pin(output)))
725 }
726 Command::CommandStatementIngest(command) => {
727 let record_count = self.do_put_statement_ingest(command, request).await?;
728 let result = DoPutUpdateResult { record_count };
729 let output = futures::stream::iter(vec![Ok(PutResult {
730 app_metadata: result.encode_to_vec().into(),
731 })]);
732 Ok(Response::new(Box::pin(output)))
733 }
734 Command::CommandPreparedStatementQuery(command) => {
735 let result = self
736 .do_put_prepared_statement_query(command, request)
737 .await?;
738 let output = futures::stream::iter(vec![Ok(PutResult {
739 app_metadata: result.encode_to_vec().into(),
740 })]);
741 Ok(Response::new(Box::pin(output)))
742 }
743 Command::CommandStatementSubstraitPlan(command) => {
744 let record_count = self.do_put_substrait_plan(command, request).await?;
745 let result = DoPutUpdateResult { record_count };
746 let output = futures::stream::iter(vec![Ok(PutResult {
747 app_metadata: result.encode_to_vec().into(),
748 })]);
749 Ok(Response::new(Box::pin(output)))
750 }
751 Command::CommandPreparedStatementUpdate(command) => {
752 let record_count = self
753 .do_put_prepared_statement_update(command, request)
754 .await?;
755 let result = DoPutUpdateResult { record_count };
756 let output = futures::stream::iter(vec![Ok(PutResult {
757 app_metadata: result.encode_to_vec().into(),
758 })]);
759 Ok(Response::new(Box::pin(output)))
760 }
761 cmd => self.do_put_fallback(request, cmd.into_any()).await,
762 }
763 }
764
765 async fn list_actions(
766 &self,
767 _request: Request<Empty>,
768 ) -> Result<Response<Self::ListActionsStream>, Status> {
769 let create_prepared_statement_action_type = ActionType {
770 r#type: CREATE_PREPARED_STATEMENT.to_string(),
771 description: "Creates a reusable prepared statement resource on the server.\n
772 Request Message: ActionCreatePreparedStatementRequest\n
773 Response Message: ActionCreatePreparedStatementResult"
774 .into(),
775 };
776 let close_prepared_statement_action_type = ActionType {
777 r#type: CLOSE_PREPARED_STATEMENT.to_string(),
778 description: "Closes a reusable prepared statement resource on the server.\n
779 Request Message: ActionClosePreparedStatementRequest\n
780 Response Message: N/A"
781 .into(),
782 };
783 let create_prepared_substrait_plan_action_type = ActionType {
784 r#type: CREATE_PREPARED_SUBSTRAIT_PLAN.to_string(),
785 description: "Creates a reusable prepared substrait plan resource on the server.\n
786 Request Message: ActionCreatePreparedSubstraitPlanRequest\n
787 Response Message: ActionCreatePreparedStatementResult"
788 .into(),
789 };
790 let begin_transaction_action_type = ActionType {
791 r#type: BEGIN_TRANSACTION.to_string(),
792 description: "Begins a transaction.\n
793 Request Message: ActionBeginTransactionRequest\n
794 Response Message: ActionBeginTransactionResult"
795 .into(),
796 };
797 let end_transaction_action_type = ActionType {
798 r#type: END_TRANSACTION.to_string(),
799 description: "Ends a transaction\n
800 Request Message: ActionEndTransactionRequest\n
801 Response Message: N/A"
802 .into(),
803 };
804 let begin_savepoint_action_type = ActionType {
805 r#type: BEGIN_SAVEPOINT.to_string(),
806 description: "Begins a savepoint.\n
807 Request Message: ActionBeginSavepointRequest\n
808 Response Message: ActionBeginSavepointResult"
809 .into(),
810 };
811 let end_savepoint_action_type = ActionType {
812 r#type: END_SAVEPOINT.to_string(),
813 description: "Ends a savepoint\n
814 Request Message: ActionEndSavepointRequest\n
815 Response Message: N/A"
816 .into(),
817 };
818 let cancel_query_action_type = ActionType {
819 r#type: CANCEL_QUERY.to_string(),
820 description: "Cancels a query\n
821 Request Message: ActionCancelQueryRequest\n
822 Response Message: ActionCancelQueryResult"
823 .into(),
824 };
825 let mut actions: Vec<Result<ActionType, Status>> = vec![
826 Ok(create_prepared_statement_action_type),
827 Ok(close_prepared_statement_action_type),
828 Ok(create_prepared_substrait_plan_action_type),
829 Ok(begin_transaction_action_type),
830 Ok(end_transaction_action_type),
831 Ok(begin_savepoint_action_type),
832 Ok(end_savepoint_action_type),
833 Ok(cancel_query_action_type),
834 ];
835
836 if let Some(mut custom_actions) = self.list_custom_actions().await {
837 actions.append(&mut custom_actions);
838 }
839
840 let output = futures::stream::iter(actions);
841 Ok(Response::new(Box::pin(output) as Self::ListActionsStream))
842 }
843
844 async fn do_action(
845 &self,
846 request: Request<Action>,
847 ) -> Result<Response<Self::DoActionStream>, Status> {
848 if request.get_ref().r#type == CREATE_PREPARED_STATEMENT {
849 let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
850
851 let cmd: ActionCreatePreparedStatementRequest = any
852 .unpack()
853 .map_err(arrow_error_to_status)?
854 .ok_or_else(|| {
855 Status::invalid_argument(
856 "Unable to unpack ActionCreatePreparedStatementRequest.",
857 )
858 })?;
859 let stmt = self
860 .do_action_create_prepared_statement(cmd, request)
861 .await?;
862 let output = futures::stream::iter(vec![Ok(super::super::gen::Result {
863 body: stmt.as_any().encode_to_vec().into(),
864 })]);
865 return Ok(Response::new(Box::pin(output)));
866 } else if request.get_ref().r#type == CLOSE_PREPARED_STATEMENT {
867 let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
868
869 let cmd: ActionClosePreparedStatementRequest = any
870 .unpack()
871 .map_err(arrow_error_to_status)?
872 .ok_or_else(|| {
873 Status::invalid_argument(
874 "Unable to unpack ActionClosePreparedStatementRequest.",
875 )
876 })?;
877 self.do_action_close_prepared_statement(cmd, request)
878 .await?;
879 return Ok(Response::new(Box::pin(futures::stream::empty())));
880 } else if request.get_ref().r#type == CREATE_PREPARED_SUBSTRAIT_PLAN {
881 let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
882
883 let cmd: ActionCreatePreparedSubstraitPlanRequest = any
884 .unpack()
885 .map_err(arrow_error_to_status)?
886 .ok_or_else(|| {
887 Status::invalid_argument(
888 "Unable to unpack ActionCreatePreparedSubstraitPlanRequest.",
889 )
890 })?;
891 self.do_action_create_prepared_substrait_plan(cmd, request)
892 .await?;
893 return Ok(Response::new(Box::pin(futures::stream::empty())));
894 } else if request.get_ref().r#type == BEGIN_TRANSACTION {
895 let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
896
897 let cmd: ActionBeginTransactionRequest = any
898 .unpack()
899 .map_err(arrow_error_to_status)?
900 .ok_or_else(|| {
901 Status::invalid_argument("Unable to unpack ActionBeginTransactionRequest.")
902 })?;
903 let stmt = self.do_action_begin_transaction(cmd, request).await?;
904 let output = futures::stream::iter(vec![Ok(super::super::gen::Result {
905 body: stmt.as_any().encode_to_vec().into(),
906 })]);
907 return Ok(Response::new(Box::pin(output)));
908 } else if request.get_ref().r#type == END_TRANSACTION {
909 let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
910
911 let cmd: ActionEndTransactionRequest = any
912 .unpack()
913 .map_err(arrow_error_to_status)?
914 .ok_or_else(|| {
915 Status::invalid_argument("Unable to unpack ActionEndTransactionRequest.")
916 })?;
917 self.do_action_end_transaction(cmd, request).await?;
918 return Ok(Response::new(Box::pin(futures::stream::empty())));
919 } else if request.get_ref().r#type == BEGIN_SAVEPOINT {
920 let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
921
922 let cmd: ActionBeginSavepointRequest = any
923 .unpack()
924 .map_err(arrow_error_to_status)?
925 .ok_or_else(|| {
926 Status::invalid_argument("Unable to unpack ActionBeginSavepointRequest.")
927 })?;
928 let stmt = self.do_action_begin_savepoint(cmd, request).await?;
929 let output = futures::stream::iter(vec![Ok(super::super::gen::Result {
930 body: stmt.as_any().encode_to_vec().into(),
931 })]);
932 return Ok(Response::new(Box::pin(output)));
933 } else if request.get_ref().r#type == END_SAVEPOINT {
934 let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
935
936 let cmd: ActionEndSavepointRequest = any
937 .unpack()
938 .map_err(arrow_error_to_status)?
939 .ok_or_else(|| {
940 Status::invalid_argument("Unable to unpack ActionEndSavepointRequest.")
941 })?;
942 self.do_action_end_savepoint(cmd, request).await?;
943 return Ok(Response::new(Box::pin(futures::stream::empty())));
944 } else if request.get_ref().r#type == CANCEL_QUERY {
945 let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
946
947 let cmd: ActionCancelQueryRequest = any
948 .unpack()
949 .map_err(arrow_error_to_status)?
950 .ok_or_else(|| {
951 Status::invalid_argument("Unable to unpack ActionCancelQueryRequest.")
952 })?;
953 let stmt = self.do_action_cancel_query(cmd, request).await?;
954 let output = futures::stream::iter(vec![Ok(super::super::gen::Result {
955 body: stmt.as_any().encode_to_vec().into(),
956 })]);
957 return Ok(Response::new(Box::pin(output)));
958 }
959
960 self.do_action_fallback(request).await
961 }
962
963 async fn do_exchange(
964 &self,
965 request: Request<Streaming<FlightData>>,
966 ) -> Result<Response<Self::DoExchangeStream>, Status> {
967 self.do_exchange_fallback(request).await
968 }
969}
970
971fn decode_error_to_status(err: prost::DecodeError) -> Status {
972 Status::invalid_argument(format!("{err:?}"))
973}
974
975fn arrow_error_to_status(err: arrow_schema::ArrowError) -> Status {
976 Status::internal(format!("{err:?}"))
977}
978
/// A `Streaming<FlightData>` wrapped in a [`Peekable`] adapter, so the next message
/// can be inspected via `peek` without being consumed.
///
/// `do_put` in this module peeks the first message to read its `FlightDescriptor`
/// before handing the whole stream to a `do_put_*` hook.
pub struct PeekableFlightDataStream {
    // The wrapped request stream; any peeked message is buffered by the adapter.
    inner: Peekable<Streaming<FlightData>>,
}
1027
1028impl PeekableFlightDataStream {
1029 fn new(stream: Streaming<FlightData>) -> Self {
1030 Self {
1031 inner: stream.peekable(),
1032 }
1033 }
1034
1035 pub fn into_inner(self) -> Streaming<FlightData> {
1039 self.inner.into_inner()
1040 }
1041
1042 pub fn into_peekable(self) -> Peekable<Streaming<FlightData>> {
1046 self.inner
1047 }
1048
1049 pub async fn peek(&mut self) -> Option<&Result<FlightData, Status>> {
1051 Pin::new(&mut self.inner).peek().await
1052 }
1053}
1054
1055impl Stream for PeekableFlightDataStream {
1056 type Item = Result<FlightData, Status>;
1057
1058 fn poll_next(
1059 mut self: Pin<&mut Self>,
1060 cx: &mut std::task::Context<'_>,
1061 ) -> std::task::Poll<Option<Self::Item>> {
1062 self.inner.poll_next_unpin(cx)
1063 }
1064}