arrow_flight/sql/
server.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Helper trait [`FlightSqlService`] for implementing a [`FlightService`] that implements FlightSQL.
19
20use std::pin::Pin;
21
22use futures::{stream::Peekable, Stream, StreamExt};
23use prost::Message;
24use tonic::{Request, Response, Status, Streaming};
25
26use super::{
27    ActionBeginSavepointRequest, ActionBeginSavepointResult, ActionBeginTransactionRequest,
28    ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult,
29    ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
30    ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest,
31    ActionEndSavepointRequest, ActionEndTransactionRequest, Any, Command, CommandGetCatalogs,
32    CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys,
33    CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
34    CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate,
35    CommandStatementIngest, CommandStatementQuery, CommandStatementSubstraitPlan,
36    CommandStatementUpdate, DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt,
37    SqlInfo, TicketStatementQuery,
38};
39use crate::{
40    flight_service_server::FlightService, gen::PollInfo, Action, ActionType, Criteria, Empty,
41    FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, PutResult,
42    SchemaResult, Ticket,
43};
44
45pub(crate) static CREATE_PREPARED_STATEMENT: &str = "CreatePreparedStatement";
46pub(crate) static CLOSE_PREPARED_STATEMENT: &str = "ClosePreparedStatement";
47pub(crate) static CREATE_PREPARED_SUBSTRAIT_PLAN: &str = "CreatePreparedSubstraitPlan";
48pub(crate) static BEGIN_TRANSACTION: &str = "BeginTransaction";
49pub(crate) static END_TRANSACTION: &str = "EndTransaction";
50pub(crate) static BEGIN_SAVEPOINT: &str = "BeginSavepoint";
51pub(crate) static END_SAVEPOINT: &str = "EndSavepoint";
52pub(crate) static CANCEL_QUERY: &str = "CancelQuery";
53
54/// Implements FlightSqlService to handle the flight sql protocol
55#[tonic::async_trait]
56pub trait FlightSqlService: Sync + Send + Sized + 'static {
57    /// When impl FlightSqlService, you can always set FlightService to Self
58    type FlightService: FlightService;
59
60    /// Accept authentication and return a token
61    /// <https://arrow.apache.org/docs/format/Flight.html#authentication>
62    async fn do_handshake(
63        &self,
64        _request: Request<Streaming<HandshakeRequest>>,
65    ) -> Result<
66        Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> + Send>>>,
67        Status,
68    > {
69        Err(Status::unimplemented(
70            "Handshake has no default implementation",
71        ))
72    }
73
74    /// Implementors may override to handle additional calls to do_get()
75    async fn do_get_fallback(
76        &self,
77        _request: Request<Ticket>,
78        message: Any,
79    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
80        Err(Status::unimplemented(format!(
81            "do_get: The defined request is invalid: {}",
82            message.type_url
83        )))
84    }
85
86    /// Get a FlightInfo for executing a SQL query.
87    async fn get_flight_info_statement(
88        &self,
89        _query: CommandStatementQuery,
90        _request: Request<FlightDescriptor>,
91    ) -> Result<Response<FlightInfo>, Status> {
92        Err(Status::unimplemented(
93            "get_flight_info_statement has no default implementation",
94        ))
95    }
96
97    /// Get a FlightInfo for executing a substrait plan.
98    async fn get_flight_info_substrait_plan(
99        &self,
100        _query: CommandStatementSubstraitPlan,
101        _request: Request<FlightDescriptor>,
102    ) -> Result<Response<FlightInfo>, Status> {
103        Err(Status::unimplemented(
104            "get_flight_info_substrait_plan has no default implementation",
105        ))
106    }
107
108    /// Get a FlightInfo for executing an already created prepared statement.
109    async fn get_flight_info_prepared_statement(
110        &self,
111        _query: CommandPreparedStatementQuery,
112        _request: Request<FlightDescriptor>,
113    ) -> Result<Response<FlightInfo>, Status> {
114        Err(Status::unimplemented(
115            "get_flight_info_prepared_statement has no default implementation",
116        ))
117    }
118
119    /// Get a FlightInfo for listing catalogs.
120    async fn get_flight_info_catalogs(
121        &self,
122        _query: CommandGetCatalogs,
123        _request: Request<FlightDescriptor>,
124    ) -> Result<Response<FlightInfo>, Status> {
125        Err(Status::unimplemented(
126            "get_flight_info_catalogs has no default implementation",
127        ))
128    }
129
130    /// Get a FlightInfo for listing schemas.
131    async fn get_flight_info_schemas(
132        &self,
133        _query: CommandGetDbSchemas,
134        _request: Request<FlightDescriptor>,
135    ) -> Result<Response<FlightInfo>, Status> {
136        Err(Status::unimplemented(
137            "get_flight_info_schemas has no default implementation",
138        ))
139    }
140
141    /// Get a FlightInfo for listing tables.
142    async fn get_flight_info_tables(
143        &self,
144        _query: CommandGetTables,
145        _request: Request<FlightDescriptor>,
146    ) -> Result<Response<FlightInfo>, Status> {
147        Err(Status::unimplemented(
148            "get_flight_info_tables has no default implementation",
149        ))
150    }
151
152    /// Get a FlightInfo to extract information about the table types.
153    async fn get_flight_info_table_types(
154        &self,
155        _query: CommandGetTableTypes,
156        _request: Request<FlightDescriptor>,
157    ) -> Result<Response<FlightInfo>, Status> {
158        Err(Status::unimplemented(
159            "get_flight_info_table_types has no default implementation",
160        ))
161    }
162
163    /// Get a FlightInfo for retrieving other information (See SqlInfo).
164    async fn get_flight_info_sql_info(
165        &self,
166        _query: CommandGetSqlInfo,
167        _request: Request<FlightDescriptor>,
168    ) -> Result<Response<FlightInfo>, Status> {
169        Err(Status::unimplemented(
170            "get_flight_info_sql_info has no default implementation",
171        ))
172    }
173
174    /// Get a FlightInfo to extract information about primary and foreign keys.
175    async fn get_flight_info_primary_keys(
176        &self,
177        _query: CommandGetPrimaryKeys,
178        _request: Request<FlightDescriptor>,
179    ) -> Result<Response<FlightInfo>, Status> {
180        Err(Status::unimplemented(
181            "get_flight_info_primary_keys has no default implementation",
182        ))
183    }
184
185    /// Get a FlightInfo to extract information about exported keys.
186    async fn get_flight_info_exported_keys(
187        &self,
188        _query: CommandGetExportedKeys,
189        _request: Request<FlightDescriptor>,
190    ) -> Result<Response<FlightInfo>, Status> {
191        Err(Status::unimplemented(
192            "get_flight_info_exported_keys has no default implementation",
193        ))
194    }
195
196    /// Get a FlightInfo to extract information about imported keys.
197    async fn get_flight_info_imported_keys(
198        &self,
199        _query: CommandGetImportedKeys,
200        _request: Request<FlightDescriptor>,
201    ) -> Result<Response<FlightInfo>, Status> {
202        Err(Status::unimplemented(
203            "get_flight_info_imported_keys has no default implementation",
204        ))
205    }
206
207    /// Get a FlightInfo to extract information about cross reference.
208    async fn get_flight_info_cross_reference(
209        &self,
210        _query: CommandGetCrossReference,
211        _request: Request<FlightDescriptor>,
212    ) -> Result<Response<FlightInfo>, Status> {
213        Err(Status::unimplemented(
214            "get_flight_info_cross_reference has no default implementation",
215        ))
216    }
217
218    /// Get a FlightInfo to extract information about the supported XDBC types.
219    async fn get_flight_info_xdbc_type_info(
220        &self,
221        _query: CommandGetXdbcTypeInfo,
222        _request: Request<FlightDescriptor>,
223    ) -> Result<Response<FlightInfo>, Status> {
224        Err(Status::unimplemented(
225            "get_flight_info_xdbc_type_info has no default implementation",
226        ))
227    }
228
229    /// Implementors may override to handle additional calls to get_flight_info()
230    async fn get_flight_info_fallback(
231        &self,
232        cmd: Command,
233        _request: Request<FlightDescriptor>,
234    ) -> Result<Response<FlightInfo>, Status> {
235        Err(Status::unimplemented(format!(
236            "get_flight_info: The defined request is invalid: {}",
237            cmd.type_url()
238        )))
239    }
240
241    // do_get
242
243    /// Get a FlightDataStream containing the query results.
244    async fn do_get_statement(
245        &self,
246        _ticket: TicketStatementQuery,
247        _request: Request<Ticket>,
248    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
249        Err(Status::unimplemented(
250            "do_get_statement has no default implementation",
251        ))
252    }
253
254    /// Get a FlightDataStream containing the prepared statement query results.
255    async fn do_get_prepared_statement(
256        &self,
257        _query: CommandPreparedStatementQuery,
258        _request: Request<Ticket>,
259    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
260        Err(Status::unimplemented(
261            "do_get_prepared_statement has no default implementation",
262        ))
263    }
264
265    /// Get a FlightDataStream containing the list of catalogs.
266    async fn do_get_catalogs(
267        &self,
268        _query: CommandGetCatalogs,
269        _request: Request<Ticket>,
270    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
271        Err(Status::unimplemented(
272            "do_get_catalogs has no default implementation",
273        ))
274    }
275
276    /// Get a FlightDataStream containing the list of schemas.
277    async fn do_get_schemas(
278        &self,
279        _query: CommandGetDbSchemas,
280        _request: Request<Ticket>,
281    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
282        Err(Status::unimplemented(
283            "do_get_schemas has no default implementation",
284        ))
285    }
286
287    /// Get a FlightDataStream containing the list of tables.
288    async fn do_get_tables(
289        &self,
290        _query: CommandGetTables,
291        _request: Request<Ticket>,
292    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
293        Err(Status::unimplemented(
294            "do_get_tables has no default implementation",
295        ))
296    }
297
298    /// Get a FlightDataStream containing the data related to the table types.
299    async fn do_get_table_types(
300        &self,
301        _query: CommandGetTableTypes,
302        _request: Request<Ticket>,
303    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
304        Err(Status::unimplemented(
305            "do_get_table_types has no default implementation",
306        ))
307    }
308
309    /// Get a FlightDataStream containing the list of SqlInfo results.
310    async fn do_get_sql_info(
311        &self,
312        _query: CommandGetSqlInfo,
313        _request: Request<Ticket>,
314    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
315        Err(Status::unimplemented(
316            "do_get_sql_info has no default implementation",
317        ))
318    }
319
320    /// Get a FlightDataStream containing the data related to the primary and foreign keys.
321    async fn do_get_primary_keys(
322        &self,
323        _query: CommandGetPrimaryKeys,
324        _request: Request<Ticket>,
325    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
326        Err(Status::unimplemented(
327            "do_get_primary_keys has no default implementation",
328        ))
329    }
330
331    /// Get a FlightDataStream containing the data related to the exported keys.
332    async fn do_get_exported_keys(
333        &self,
334        _query: CommandGetExportedKeys,
335        _request: Request<Ticket>,
336    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
337        Err(Status::unimplemented(
338            "do_get_exported_keys has no default implementation",
339        ))
340    }
341
342    /// Get a FlightDataStream containing the data related to the imported keys.
343    async fn do_get_imported_keys(
344        &self,
345        _query: CommandGetImportedKeys,
346        _request: Request<Ticket>,
347    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
348        Err(Status::unimplemented(
349            "do_get_imported_keys has no default implementation",
350        ))
351    }
352
353    /// Get a FlightDataStream containing the data related to the cross reference.
354    async fn do_get_cross_reference(
355        &self,
356        _query: CommandGetCrossReference,
357        _request: Request<Ticket>,
358    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
359        Err(Status::unimplemented(
360            "do_get_cross_reference has no default implementation",
361        ))
362    }
363
364    /// Get a FlightDataStream containing the data related to the supported XDBC types.
365    async fn do_get_xdbc_type_info(
366        &self,
367        _query: CommandGetXdbcTypeInfo,
368        _request: Request<Ticket>,
369    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
370        Err(Status::unimplemented(
371            "do_get_xdbc_type_info has no default implementation",
372        ))
373    }
374
375    // do_put
376
377    /// Implementors may override to handle additional calls to do_put()
378    async fn do_put_fallback(
379        &self,
380        _request: Request<PeekableFlightDataStream>,
381        message: Any,
382    ) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
383        Err(Status::unimplemented(format!(
384            "do_put: The defined request is invalid: {}",
385            message.type_url
386        )))
387    }
388
389    /// Execute an update SQL statement.
390    async fn do_put_statement_update(
391        &self,
392        _ticket: CommandStatementUpdate,
393        _request: Request<PeekableFlightDataStream>,
394    ) -> Result<i64, Status> {
395        Err(Status::unimplemented(
396            "do_put_statement_update has no default implementation",
397        ))
398    }
399
400    /// Execute a bulk ingestion.
401    async fn do_put_statement_ingest(
402        &self,
403        _ticket: CommandStatementIngest,
404        _request: Request<PeekableFlightDataStream>,
405    ) -> Result<i64, Status> {
406        Err(Status::unimplemented(
407            "do_put_statement_ingest has no default implementation",
408        ))
409    }
410
411    /// Bind parameters to given prepared statement.
412    ///
413    /// Returns an opaque handle that the client should pass
414    /// back to the server during subsequent requests with this
415    /// prepared statement.
416    async fn do_put_prepared_statement_query(
417        &self,
418        _query: CommandPreparedStatementQuery,
419        _request: Request<PeekableFlightDataStream>,
420    ) -> Result<DoPutPreparedStatementResult, Status> {
421        Err(Status::unimplemented(
422            "do_put_prepared_statement_query has no default implementation",
423        ))
424    }
425
426    /// Execute an update SQL prepared statement.
427    async fn do_put_prepared_statement_update(
428        &self,
429        _query: CommandPreparedStatementUpdate,
430        _request: Request<PeekableFlightDataStream>,
431    ) -> Result<i64, Status> {
432        Err(Status::unimplemented(
433            "do_put_prepared_statement_update has no default implementation",
434        ))
435    }
436
437    /// Execute a substrait plan
438    async fn do_put_substrait_plan(
439        &self,
440        _query: CommandStatementSubstraitPlan,
441        _request: Request<PeekableFlightDataStream>,
442    ) -> Result<i64, Status> {
443        Err(Status::unimplemented(
444            "do_put_substrait_plan has no default implementation",
445        ))
446    }
447
448    // do_action
449
450    /// Implementors may override to handle additional calls to do_action()
451    async fn do_action_fallback(
452        &self,
453        request: Request<Action>,
454    ) -> Result<Response<<Self as FlightService>::DoActionStream>, Status> {
455        Err(Status::invalid_argument(format!(
456            "do_action: The defined request is invalid: {:?}",
457            request.get_ref().r#type
458        )))
459    }
460
461    /// Add custom actions to list_actions() result
462    async fn list_custom_actions(&self) -> Option<Vec<Result<ActionType, Status>>> {
463        None
464    }
465
466    /// Create a prepared statement from given SQL statement.
467    async fn do_action_create_prepared_statement(
468        &self,
469        _query: ActionCreatePreparedStatementRequest,
470        _request: Request<Action>,
471    ) -> Result<ActionCreatePreparedStatementResult, Status> {
472        Err(Status::unimplemented(
473            "do_action_create_prepared_statement has no default implementation",
474        ))
475    }
476
477    /// Close a prepared statement.
478    async fn do_action_close_prepared_statement(
479        &self,
480        _query: ActionClosePreparedStatementRequest,
481        _request: Request<Action>,
482    ) -> Result<(), Status> {
483        Err(Status::unimplemented(
484            "do_action_close_prepared_statement has no default implementation",
485        ))
486    }
487
488    /// Create a prepared substrait plan.
489    async fn do_action_create_prepared_substrait_plan(
490        &self,
491        _query: ActionCreatePreparedSubstraitPlanRequest,
492        _request: Request<Action>,
493    ) -> Result<ActionCreatePreparedStatementResult, Status> {
494        Err(Status::unimplemented(
495            "do_action_create_prepared_substrait_plan has no default implementation",
496        ))
497    }
498
499    /// Begin a transaction
500    async fn do_action_begin_transaction(
501        &self,
502        _query: ActionBeginTransactionRequest,
503        _request: Request<Action>,
504    ) -> Result<ActionBeginTransactionResult, Status> {
505        Err(Status::unimplemented(
506            "do_action_begin_transaction has no default implementation",
507        ))
508    }
509
510    /// End a transaction
511    async fn do_action_end_transaction(
512        &self,
513        _query: ActionEndTransactionRequest,
514        _request: Request<Action>,
515    ) -> Result<(), Status> {
516        Err(Status::unimplemented(
517            "do_action_end_transaction has no default implementation",
518        ))
519    }
520
521    /// Begin a savepoint
522    async fn do_action_begin_savepoint(
523        &self,
524        _query: ActionBeginSavepointRequest,
525        _request: Request<Action>,
526    ) -> Result<ActionBeginSavepointResult, Status> {
527        Err(Status::unimplemented(
528            "do_action_begin_savepoint has no default implementation",
529        ))
530    }
531
532    /// End a savepoint
533    async fn do_action_end_savepoint(
534        &self,
535        _query: ActionEndSavepointRequest,
536        _request: Request<Action>,
537    ) -> Result<(), Status> {
538        Err(Status::unimplemented(
539            "do_action_end_savepoint has no default implementation",
540        ))
541    }
542
543    /// Cancel a query
544    async fn do_action_cancel_query(
545        &self,
546        _query: ActionCancelQueryRequest,
547        _request: Request<Action>,
548    ) -> Result<ActionCancelQueryResult, Status> {
549        Err(Status::unimplemented(
550            "do_action_cancel_query has no default implementation",
551        ))
552    }
553
554    /// do_exchange
555    /// Implementors may override to handle additional calls to do_exchange()
556    async fn do_exchange_fallback(
557        &self,
558        _request: Request<Streaming<FlightData>>,
559    ) -> Result<Response<<Self as FlightService>::DoExchangeStream>, Status> {
560        Err(Status::unimplemented("Not yet implemented"))
561    }
562
563    /// Register a new SqlInfo result, making it available when calling GetSqlInfo.
564    async fn register_sql_info(&self, id: i32, result: &SqlInfo);
565}
566
567/// Implements the lower level interface to handle FlightSQL
568#[tonic::async_trait]
569impl<T: 'static> FlightService for T
570where
571    T: FlightSqlService + Send,
572{
573    type HandshakeStream =
574        Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> + Send + 'static>>;
575    type ListFlightsStream =
576        Pin<Box<dyn Stream<Item = Result<FlightInfo, Status>> + Send + 'static>>;
577    type DoGetStream = Pin<Box<dyn Stream<Item = Result<FlightData, Status>> + Send + 'static>>;
578    type DoPutStream = Pin<Box<dyn Stream<Item = Result<PutResult, Status>> + Send + 'static>>;
579    type DoActionStream =
580        Pin<Box<dyn Stream<Item = Result<super::super::Result, Status>> + Send + 'static>>;
581    type ListActionsStream =
582        Pin<Box<dyn Stream<Item = Result<ActionType, Status>> + Send + 'static>>;
583    type DoExchangeStream =
584        Pin<Box<dyn Stream<Item = Result<FlightData, Status>> + Send + 'static>>;
585
586    async fn handshake(
587        &self,
588        request: Request<Streaming<HandshakeRequest>>,
589    ) -> Result<Response<Self::HandshakeStream>, Status> {
590        let res = self.do_handshake(request).await?;
591        Ok(res)
592    }
593
594    async fn list_flights(
595        &self,
596        _request: Request<Criteria>,
597    ) -> Result<Response<Self::ListFlightsStream>, Status> {
598        Err(Status::unimplemented("Not yet implemented"))
599    }
600
601    async fn get_flight_info(
602        &self,
603        request: Request<FlightDescriptor>,
604    ) -> Result<Response<FlightInfo>, Status> {
605        let message = Any::decode(&*request.get_ref().cmd).map_err(decode_error_to_status)?;
606
607        match Command::try_from(message).map_err(arrow_error_to_status)? {
608            Command::CommandStatementQuery(token) => {
609                self.get_flight_info_statement(token, request).await
610            }
611            Command::CommandPreparedStatementQuery(handle) => {
612                self.get_flight_info_prepared_statement(handle, request)
613                    .await
614            }
615            Command::CommandStatementSubstraitPlan(handle) => {
616                self.get_flight_info_substrait_plan(handle, request).await
617            }
618            Command::CommandGetCatalogs(token) => {
619                self.get_flight_info_catalogs(token, request).await
620            }
621            Command::CommandGetDbSchemas(token) => {
622                return self.get_flight_info_schemas(token, request).await
623            }
624            Command::CommandGetTables(token) => self.get_flight_info_tables(token, request).await,
625            Command::CommandGetTableTypes(token) => {
626                self.get_flight_info_table_types(token, request).await
627            }
628            Command::CommandGetSqlInfo(token) => {
629                self.get_flight_info_sql_info(token, request).await
630            }
631            Command::CommandGetPrimaryKeys(token) => {
632                self.get_flight_info_primary_keys(token, request).await
633            }
634            Command::CommandGetExportedKeys(token) => {
635                self.get_flight_info_exported_keys(token, request).await
636            }
637            Command::CommandGetImportedKeys(token) => {
638                self.get_flight_info_imported_keys(token, request).await
639            }
640            Command::CommandGetCrossReference(token) => {
641                self.get_flight_info_cross_reference(token, request).await
642            }
643            Command::CommandGetXdbcTypeInfo(token) => {
644                self.get_flight_info_xdbc_type_info(token, request).await
645            }
646            cmd => self.get_flight_info_fallback(cmd, request).await,
647        }
648    }
649
650    async fn poll_flight_info(
651        &self,
652        _request: Request<FlightDescriptor>,
653    ) -> Result<Response<PollInfo>, Status> {
654        Err(Status::unimplemented("Not yet implemented"))
655    }
656
657    async fn get_schema(
658        &self,
659        _request: Request<FlightDescriptor>,
660    ) -> Result<Response<SchemaResult>, Status> {
661        Err(Status::unimplemented("Not yet implemented"))
662    }
663
664    async fn do_get(
665        &self,
666        request: Request<Ticket>,
667    ) -> Result<Response<Self::DoGetStream>, Status> {
668        let msg: Any =
669            Message::decode(&*request.get_ref().ticket).map_err(decode_error_to_status)?;
670
671        match Command::try_from(msg).map_err(arrow_error_to_status)? {
672            Command::TicketStatementQuery(command) => self.do_get_statement(command, request).await,
673            Command::CommandPreparedStatementQuery(command) => {
674                self.do_get_prepared_statement(command, request).await
675            }
676            Command::CommandGetCatalogs(command) => self.do_get_catalogs(command, request).await,
677            Command::CommandGetDbSchemas(command) => self.do_get_schemas(command, request).await,
678            Command::CommandGetTables(command) => self.do_get_tables(command, request).await,
679            Command::CommandGetTableTypes(command) => {
680                self.do_get_table_types(command, request).await
681            }
682            Command::CommandGetSqlInfo(command) => self.do_get_sql_info(command, request).await,
683            Command::CommandGetPrimaryKeys(command) => {
684                self.do_get_primary_keys(command, request).await
685            }
686            Command::CommandGetExportedKeys(command) => {
687                self.do_get_exported_keys(command, request).await
688            }
689            Command::CommandGetImportedKeys(command) => {
690                self.do_get_imported_keys(command, request).await
691            }
692            Command::CommandGetCrossReference(command) => {
693                self.do_get_cross_reference(command, request).await
694            }
695            Command::CommandGetXdbcTypeInfo(command) => {
696                self.do_get_xdbc_type_info(command, request).await
697            }
698            cmd => self.do_get_fallback(request, cmd.into_any()).await,
699        }
700    }
701
702    async fn do_put(
703        &self,
704        request: Request<Streaming<FlightData>>,
705    ) -> Result<Response<Self::DoPutStream>, Status> {
706        // See issue #4658: https://github.com/apache/arrow-rs/issues/4658
707        // To dispatch to the correct `do_put` method, we cannot discard the first message,
708        // as it may contain the Arrow schema, which the `do_put` handler may need.
709        // To allow the first message to be reused by the `do_put` handler,
710        // we wrap this stream in a `Peekable` one, which allows us to peek at
711        // the first message without discarding it.
712        let mut request = request.map(PeekableFlightDataStream::new);
713        let cmd = Pin::new(request.get_mut()).peek().await.unwrap().clone()?;
714
715        let message =
716            Any::decode(&*cmd.flight_descriptor.unwrap().cmd).map_err(decode_error_to_status)?;
717        match Command::try_from(message).map_err(arrow_error_to_status)? {
718            Command::CommandStatementUpdate(command) => {
719                let record_count = self.do_put_statement_update(command, request).await?;
720                let result = DoPutUpdateResult { record_count };
721                let output = futures::stream::iter(vec![Ok(PutResult {
722                    app_metadata: result.encode_to_vec().into(),
723                })]);
724                Ok(Response::new(Box::pin(output)))
725            }
726            Command::CommandStatementIngest(command) => {
727                let record_count = self.do_put_statement_ingest(command, request).await?;
728                let result = DoPutUpdateResult { record_count };
729                let output = futures::stream::iter(vec![Ok(PutResult {
730                    app_metadata: result.encode_to_vec().into(),
731                })]);
732                Ok(Response::new(Box::pin(output)))
733            }
734            Command::CommandPreparedStatementQuery(command) => {
735                let result = self
736                    .do_put_prepared_statement_query(command, request)
737                    .await?;
738                let output = futures::stream::iter(vec![Ok(PutResult {
739                    app_metadata: result.encode_to_vec().into(),
740                })]);
741                Ok(Response::new(Box::pin(output)))
742            }
743            Command::CommandStatementSubstraitPlan(command) => {
744                let record_count = self.do_put_substrait_plan(command, request).await?;
745                let result = DoPutUpdateResult { record_count };
746                let output = futures::stream::iter(vec![Ok(PutResult {
747                    app_metadata: result.encode_to_vec().into(),
748                })]);
749                Ok(Response::new(Box::pin(output)))
750            }
751            Command::CommandPreparedStatementUpdate(command) => {
752                let record_count = self
753                    .do_put_prepared_statement_update(command, request)
754                    .await?;
755                let result = DoPutUpdateResult { record_count };
756                let output = futures::stream::iter(vec![Ok(PutResult {
757                    app_metadata: result.encode_to_vec().into(),
758                })]);
759                Ok(Response::new(Box::pin(output)))
760            }
761            cmd => self.do_put_fallback(request, cmd.into_any()).await,
762        }
763    }
764
765    async fn list_actions(
766        &self,
767        _request: Request<Empty>,
768    ) -> Result<Response<Self::ListActionsStream>, Status> {
769        let create_prepared_statement_action_type = ActionType {
770            r#type: CREATE_PREPARED_STATEMENT.to_string(),
771            description: "Creates a reusable prepared statement resource on the server.\n
772                Request Message: ActionCreatePreparedStatementRequest\n
773                Response Message: ActionCreatePreparedStatementResult"
774                .into(),
775        };
776        let close_prepared_statement_action_type = ActionType {
777            r#type: CLOSE_PREPARED_STATEMENT.to_string(),
778            description: "Closes a reusable prepared statement resource on the server.\n
779                Request Message: ActionClosePreparedStatementRequest\n
780                Response Message: N/A"
781                .into(),
782        };
783        let create_prepared_substrait_plan_action_type = ActionType {
784            r#type: CREATE_PREPARED_SUBSTRAIT_PLAN.to_string(),
785            description: "Creates a reusable prepared substrait plan resource on the server.\n
786                Request Message: ActionCreatePreparedSubstraitPlanRequest\n
787                Response Message: ActionCreatePreparedStatementResult"
788                .into(),
789        };
790        let begin_transaction_action_type = ActionType {
791            r#type: BEGIN_TRANSACTION.to_string(),
792            description: "Begins a transaction.\n
793                Request Message: ActionBeginTransactionRequest\n
794                Response Message: ActionBeginTransactionResult"
795                .into(),
796        };
797        let end_transaction_action_type = ActionType {
798            r#type: END_TRANSACTION.to_string(),
799            description: "Ends a transaction\n
800                Request Message: ActionEndTransactionRequest\n
801                Response Message: N/A"
802                .into(),
803        };
804        let begin_savepoint_action_type = ActionType {
805            r#type: BEGIN_SAVEPOINT.to_string(),
806            description: "Begins a savepoint.\n
807                Request Message: ActionBeginSavepointRequest\n
808                Response Message: ActionBeginSavepointResult"
809                .into(),
810        };
811        let end_savepoint_action_type = ActionType {
812            r#type: END_SAVEPOINT.to_string(),
813            description: "Ends a savepoint\n
814                Request Message: ActionEndSavepointRequest\n
815                Response Message: N/A"
816                .into(),
817        };
818        let cancel_query_action_type = ActionType {
819            r#type: CANCEL_QUERY.to_string(),
820            description: "Cancels a query\n
821                Request Message: ActionCancelQueryRequest\n
822                Response Message: ActionCancelQueryResult"
823                .into(),
824        };
825        let mut actions: Vec<Result<ActionType, Status>> = vec![
826            Ok(create_prepared_statement_action_type),
827            Ok(close_prepared_statement_action_type),
828            Ok(create_prepared_substrait_plan_action_type),
829            Ok(begin_transaction_action_type),
830            Ok(end_transaction_action_type),
831            Ok(begin_savepoint_action_type),
832            Ok(end_savepoint_action_type),
833            Ok(cancel_query_action_type),
834        ];
835
836        if let Some(mut custom_actions) = self.list_custom_actions().await {
837            actions.append(&mut custom_actions);
838        }
839
840        let output = futures::stream::iter(actions);
841        Ok(Response::new(Box::pin(output) as Self::ListActionsStream))
842    }
843
844    async fn do_action(
845        &self,
846        request: Request<Action>,
847    ) -> Result<Response<Self::DoActionStream>, Status> {
848        if request.get_ref().r#type == CREATE_PREPARED_STATEMENT {
849            let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
850
851            let cmd: ActionCreatePreparedStatementRequest = any
852                .unpack()
853                .map_err(arrow_error_to_status)?
854                .ok_or_else(|| {
855                    Status::invalid_argument(
856                        "Unable to unpack ActionCreatePreparedStatementRequest.",
857                    )
858                })?;
859            let stmt = self
860                .do_action_create_prepared_statement(cmd, request)
861                .await?;
862            let output = futures::stream::iter(vec![Ok(super::super::gen::Result {
863                body: stmt.as_any().encode_to_vec().into(),
864            })]);
865            return Ok(Response::new(Box::pin(output)));
866        } else if request.get_ref().r#type == CLOSE_PREPARED_STATEMENT {
867            let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
868
869            let cmd: ActionClosePreparedStatementRequest = any
870                .unpack()
871                .map_err(arrow_error_to_status)?
872                .ok_or_else(|| {
873                    Status::invalid_argument(
874                        "Unable to unpack ActionClosePreparedStatementRequest.",
875                    )
876                })?;
877            self.do_action_close_prepared_statement(cmd, request)
878                .await?;
879            return Ok(Response::new(Box::pin(futures::stream::empty())));
880        } else if request.get_ref().r#type == CREATE_PREPARED_SUBSTRAIT_PLAN {
881            let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
882
883            let cmd: ActionCreatePreparedSubstraitPlanRequest = any
884                .unpack()
885                .map_err(arrow_error_to_status)?
886                .ok_or_else(|| {
887                    Status::invalid_argument(
888                        "Unable to unpack ActionCreatePreparedSubstraitPlanRequest.",
889                    )
890                })?;
891            self.do_action_create_prepared_substrait_plan(cmd, request)
892                .await?;
893            return Ok(Response::new(Box::pin(futures::stream::empty())));
894        } else if request.get_ref().r#type == BEGIN_TRANSACTION {
895            let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
896
897            let cmd: ActionBeginTransactionRequest = any
898                .unpack()
899                .map_err(arrow_error_to_status)?
900                .ok_or_else(|| {
901                Status::invalid_argument("Unable to unpack ActionBeginTransactionRequest.")
902            })?;
903            let stmt = self.do_action_begin_transaction(cmd, request).await?;
904            let output = futures::stream::iter(vec![Ok(super::super::gen::Result {
905                body: stmt.as_any().encode_to_vec().into(),
906            })]);
907            return Ok(Response::new(Box::pin(output)));
908        } else if request.get_ref().r#type == END_TRANSACTION {
909            let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
910
911            let cmd: ActionEndTransactionRequest = any
912                .unpack()
913                .map_err(arrow_error_to_status)?
914                .ok_or_else(|| {
915                    Status::invalid_argument("Unable to unpack ActionEndTransactionRequest.")
916                })?;
917            self.do_action_end_transaction(cmd, request).await?;
918            return Ok(Response::new(Box::pin(futures::stream::empty())));
919        } else if request.get_ref().r#type == BEGIN_SAVEPOINT {
920            let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
921
922            let cmd: ActionBeginSavepointRequest = any
923                .unpack()
924                .map_err(arrow_error_to_status)?
925                .ok_or_else(|| {
926                    Status::invalid_argument("Unable to unpack ActionBeginSavepointRequest.")
927                })?;
928            let stmt = self.do_action_begin_savepoint(cmd, request).await?;
929            let output = futures::stream::iter(vec![Ok(super::super::gen::Result {
930                body: stmt.as_any().encode_to_vec().into(),
931            })]);
932            return Ok(Response::new(Box::pin(output)));
933        } else if request.get_ref().r#type == END_SAVEPOINT {
934            let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
935
936            let cmd: ActionEndSavepointRequest = any
937                .unpack()
938                .map_err(arrow_error_to_status)?
939                .ok_or_else(|| {
940                    Status::invalid_argument("Unable to unpack ActionEndSavepointRequest.")
941                })?;
942            self.do_action_end_savepoint(cmd, request).await?;
943            return Ok(Response::new(Box::pin(futures::stream::empty())));
944        } else if request.get_ref().r#type == CANCEL_QUERY {
945            let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
946
947            let cmd: ActionCancelQueryRequest = any
948                .unpack()
949                .map_err(arrow_error_to_status)?
950                .ok_or_else(|| {
951                    Status::invalid_argument("Unable to unpack ActionCancelQueryRequest.")
952                })?;
953            let stmt = self.do_action_cancel_query(cmd, request).await?;
954            let output = futures::stream::iter(vec![Ok(super::super::gen::Result {
955                body: stmt.as_any().encode_to_vec().into(),
956            })]);
957            return Ok(Response::new(Box::pin(output)));
958        }
959
960        self.do_action_fallback(request).await
961    }
962
963    async fn do_exchange(
964        &self,
965        request: Request<Streaming<FlightData>>,
966    ) -> Result<Response<Self::DoExchangeStream>, Status> {
967        self.do_exchange_fallback(request).await
968    }
969}
970
971fn decode_error_to_status(err: prost::DecodeError) -> Status {
972    Status::invalid_argument(format!("{err:?}"))
973}
974
975fn arrow_error_to_status(err: arrow_schema::ArrowError) -> Status {
976    Status::internal(format!("{err:?}"))
977}
978
979/// A wrapper around [`Streaming<FlightData>`] that allows "peeking" at the
980/// message at the front of the stream without consuming it.
981///
982/// This is needed because sometimes the first message in the stream will contain
983/// a [`FlightDescriptor`] in addition to potentially any data, and the dispatch logic
984/// must inspect this information.
985///
986/// # Example
987///
988/// [`PeekableFlightDataStream::peek`] can be used to peek at the first message without
989/// discarding it; otherwise, `PeekableFlightDataStream` can be used as a regular stream.
990/// See the following example:
991///
992/// ```no_run
993/// use arrow_array::RecordBatch;
994/// use arrow_flight::decode::FlightRecordBatchStream;
995/// use arrow_flight::FlightDescriptor;
996/// use arrow_flight::error::FlightError;
997/// use arrow_flight::sql::server::PeekableFlightDataStream;
998/// use tonic::{Request, Status};
999/// use futures::TryStreamExt;
1000///
1001/// #[tokio::main]
1002/// async fn main() -> Result<(), Status> {
1003///     let request: Request<PeekableFlightDataStream> = todo!();
1004///     let stream: PeekableFlightDataStream = request.into_inner();
1005///
1006///     // The first message contains the flight descriptor and the schema.
1007///     // Read the flight descriptor without discarding the schema:
1008///     let flight_descriptor: FlightDescriptor = stream
1009///         .peek()
1010///         .await
1011///         .cloned()
1012///         .transpose()?
1013///         .and_then(|data| data.flight_descriptor)
1014///         .expect("first message should contain flight descriptor");
1015///
1016///     // Pass the stream through a decoder
1017///     let batches: Vec<RecordBatch> = FlightRecordBatchStream::new_from_flight_data(
1018///         request.into_inner().map_err(|e| e.into()),
1019///     )
1020///     .try_collect()
1021///     .await?;
1022/// }
1023/// ```
1024pub struct PeekableFlightDataStream {
1025    inner: Peekable<Streaming<FlightData>>,
1026}
1027
1028impl PeekableFlightDataStream {
1029    fn new(stream: Streaming<FlightData>) -> Self {
1030        Self {
1031            inner: stream.peekable(),
1032        }
1033    }
1034
1035    /// Convert this stream into a `Streaming<FlightData>`.
1036    /// Any messages observed through [`Self::peek`] will be lost
1037    /// after the conversion.
1038    pub fn into_inner(self) -> Streaming<FlightData> {
1039        self.inner.into_inner()
1040    }
1041
1042    /// Convert this stream into a `Peekable<Streaming<FlightData>>`.
1043    /// Preserves the state of the stream, so that calls to [`Self::peek`]
1044    /// and [`Self::poll_next`] are the same.
1045    pub fn into_peekable(self) -> Peekable<Streaming<FlightData>> {
1046        self.inner
1047    }
1048
1049    /// Peek at the head of this stream without advancing it.
1050    pub async fn peek(&mut self) -> Option<&Result<FlightData, Status>> {
1051        Pin::new(&mut self.inner).peek().await
1052    }
1053}
1054
1055impl Stream for PeekableFlightDataStream {
1056    type Item = Result<FlightData, Status>;
1057
1058    fn poll_next(
1059        mut self: Pin<&mut Self>,
1060        cx: &mut std::task::Context<'_>,
1061    ) -> std::task::Poll<Option<Self::Item>> {
1062        self.inner.poll_next_unpin(cx)
1063    }
1064}