arrow_flight/sql/server.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Helper trait [`FlightSqlService`] for implementing a [`FlightService`] that implements FlightSQL.
19
20use std::fmt::{Display, Formatter};
21use std::pin::Pin;
22
23use super::{
24    ActionBeginSavepointRequest, ActionBeginSavepointResult, ActionBeginTransactionRequest,
25    ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult,
26    ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
27    ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest,
28    ActionEndSavepointRequest, ActionEndTransactionRequest, Any, Command, CommandGetCatalogs,
29    CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys,
30    CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
31    CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate,
32    CommandStatementIngest, CommandStatementQuery, CommandStatementSubstraitPlan,
33    CommandStatementUpdate, DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt,
34    SqlInfo, TicketStatementQuery,
35};
36use crate::{
37    flight_service_server::FlightService, gen::PollInfo, Action, ActionType, Criteria, Empty,
38    FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, PutResult,
39    SchemaResult, Ticket,
40};
41use futures::{stream::Peekable, Stream, StreamExt};
42use prost::Message;
43use tonic::{Request, Response, Status, Streaming};
44
// Flight SQL action type names, as reported in the `type` field of the `ActionType`
// entries returned by `list_actions`. These are plain compile-time string values, so
// they are `const` items (usable in patterns and const contexts) rather than `static`s.

/// Action type name for creating a prepared statement.
pub(crate) const CREATE_PREPARED_STATEMENT: &str = "CreatePreparedStatement";
/// Action type name for closing a prepared statement.
pub(crate) const CLOSE_PREPARED_STATEMENT: &str = "ClosePreparedStatement";
/// Action type name for creating a prepared substrait plan.
pub(crate) const CREATE_PREPARED_SUBSTRAIT_PLAN: &str = "CreatePreparedSubstraitPlan";
/// Action type name for beginning a transaction.
pub(crate) const BEGIN_TRANSACTION: &str = "BeginTransaction";
/// Action type name for ending a transaction.
pub(crate) const END_TRANSACTION: &str = "EndTransaction";
/// Action type name for beginning a savepoint.
pub(crate) const BEGIN_SAVEPOINT: &str = "BeginSavepoint";
/// Action type name for ending a savepoint.
pub(crate) const END_SAVEPOINT: &str = "EndSavepoint";
/// Action type name for cancelling a query.
pub(crate) const CANCEL_QUERY: &str = "CancelQuery";
53
54/// Implements FlightSqlService to handle the flight sql protocol
55#[tonic::async_trait]
56pub trait FlightSqlService: Sync + Send + Sized + 'static {
57    /// When impl FlightSqlService, you can always set FlightService to Self
58    type FlightService: FlightService;
59
60    /// Accept authentication and return a token
61    /// <https://arrow.apache.org/docs/format/Flight.html#authentication>
62    async fn do_handshake(
63        &self,
64        _request: Request<Streaming<HandshakeRequest>>,
65    ) -> Result<
66        Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> + Send>>>,
67        Status,
68    > {
69        Err(Status::unimplemented(
70            "Handshake has no default implementation",
71        ))
72    }
73
74    /// Implementors may override to handle additional calls to do_get()
75    async fn do_get_fallback(
76        &self,
77        _request: Request<Ticket>,
78        message: Any,
79    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
80        Err(Status::unimplemented(format!(
81            "do_get: The defined request is invalid: {}",
82            message.type_url
83        )))
84    }
85
86    /// Get a FlightInfo for executing a SQL query.
87    async fn get_flight_info_statement(
88        &self,
89        _query: CommandStatementQuery,
90        _request: Request<FlightDescriptor>,
91    ) -> Result<Response<FlightInfo>, Status> {
92        Err(Status::unimplemented(
93            "get_flight_info_statement has no default implementation",
94        ))
95    }
96
97    /// Get a FlightInfo for executing a substrait plan.
98    async fn get_flight_info_substrait_plan(
99        &self,
100        _query: CommandStatementSubstraitPlan,
101        _request: Request<FlightDescriptor>,
102    ) -> Result<Response<FlightInfo>, Status> {
103        Err(Status::unimplemented(
104            "get_flight_info_substrait_plan has no default implementation",
105        ))
106    }
107
108    /// Get a FlightInfo for executing an already created prepared statement.
109    async fn get_flight_info_prepared_statement(
110        &self,
111        _query: CommandPreparedStatementQuery,
112        _request: Request<FlightDescriptor>,
113    ) -> Result<Response<FlightInfo>, Status> {
114        Err(Status::unimplemented(
115            "get_flight_info_prepared_statement has no default implementation",
116        ))
117    }
118
119    /// Get a FlightInfo for listing catalogs.
120    async fn get_flight_info_catalogs(
121        &self,
122        _query: CommandGetCatalogs,
123        _request: Request<FlightDescriptor>,
124    ) -> Result<Response<FlightInfo>, Status> {
125        Err(Status::unimplemented(
126            "get_flight_info_catalogs has no default implementation",
127        ))
128    }
129
130    /// Get a FlightInfo for listing schemas.
131    async fn get_flight_info_schemas(
132        &self,
133        _query: CommandGetDbSchemas,
134        _request: Request<FlightDescriptor>,
135    ) -> Result<Response<FlightInfo>, Status> {
136        Err(Status::unimplemented(
137            "get_flight_info_schemas has no default implementation",
138        ))
139    }
140
141    /// Get a FlightInfo for listing tables.
142    async fn get_flight_info_tables(
143        &self,
144        _query: CommandGetTables,
145        _request: Request<FlightDescriptor>,
146    ) -> Result<Response<FlightInfo>, Status> {
147        Err(Status::unimplemented(
148            "get_flight_info_tables has no default implementation",
149        ))
150    }
151
152    /// Get a FlightInfo to extract information about the table types.
153    async fn get_flight_info_table_types(
154        &self,
155        _query: CommandGetTableTypes,
156        _request: Request<FlightDescriptor>,
157    ) -> Result<Response<FlightInfo>, Status> {
158        Err(Status::unimplemented(
159            "get_flight_info_table_types has no default implementation",
160        ))
161    }
162
163    /// Get a FlightInfo for retrieving other information (See SqlInfo).
164    async fn get_flight_info_sql_info(
165        &self,
166        _query: CommandGetSqlInfo,
167        _request: Request<FlightDescriptor>,
168    ) -> Result<Response<FlightInfo>, Status> {
169        Err(Status::unimplemented(
170            "get_flight_info_sql_info has no default implementation",
171        ))
172    }
173
174    /// Get a FlightInfo to extract information about primary and foreign keys.
175    async fn get_flight_info_primary_keys(
176        &self,
177        _query: CommandGetPrimaryKeys,
178        _request: Request<FlightDescriptor>,
179    ) -> Result<Response<FlightInfo>, Status> {
180        Err(Status::unimplemented(
181            "get_flight_info_primary_keys has no default implementation",
182        ))
183    }
184
185    /// Get a FlightInfo to extract information about exported keys.
186    async fn get_flight_info_exported_keys(
187        &self,
188        _query: CommandGetExportedKeys,
189        _request: Request<FlightDescriptor>,
190    ) -> Result<Response<FlightInfo>, Status> {
191        Err(Status::unimplemented(
192            "get_flight_info_exported_keys has no default implementation",
193        ))
194    }
195
196    /// Get a FlightInfo to extract information about imported keys.
197    async fn get_flight_info_imported_keys(
198        &self,
199        _query: CommandGetImportedKeys,
200        _request: Request<FlightDescriptor>,
201    ) -> Result<Response<FlightInfo>, Status> {
202        Err(Status::unimplemented(
203            "get_flight_info_imported_keys has no default implementation",
204        ))
205    }
206
207    /// Get a FlightInfo to extract information about cross reference.
208    async fn get_flight_info_cross_reference(
209        &self,
210        _query: CommandGetCrossReference,
211        _request: Request<FlightDescriptor>,
212    ) -> Result<Response<FlightInfo>, Status> {
213        Err(Status::unimplemented(
214            "get_flight_info_cross_reference has no default implementation",
215        ))
216    }
217
218    /// Get a FlightInfo to extract information about the supported XDBC types.
219    async fn get_flight_info_xdbc_type_info(
220        &self,
221        _query: CommandGetXdbcTypeInfo,
222        _request: Request<FlightDescriptor>,
223    ) -> Result<Response<FlightInfo>, Status> {
224        Err(Status::unimplemented(
225            "get_flight_info_xdbc_type_info has no default implementation",
226        ))
227    }
228
229    /// Implementors may override to handle additional calls to get_flight_info()
230    async fn get_flight_info_fallback(
231        &self,
232        cmd: Command,
233        _request: Request<FlightDescriptor>,
234    ) -> Result<Response<FlightInfo>, Status> {
235        Err(Status::unimplemented(format!(
236            "get_flight_info: The defined request is invalid: {}",
237            cmd.type_url()
238        )))
239    }
240
241    // do_get
242
243    /// Get a FlightDataStream containing the query results.
244    async fn do_get_statement(
245        &self,
246        _ticket: TicketStatementQuery,
247        _request: Request<Ticket>,
248    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
249        Err(Status::unimplemented(
250            "do_get_statement has no default implementation",
251        ))
252    }
253
254    /// Get a FlightDataStream containing the prepared statement query results.
255    async fn do_get_prepared_statement(
256        &self,
257        _query: CommandPreparedStatementQuery,
258        _request: Request<Ticket>,
259    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
260        Err(Status::unimplemented(
261            "do_get_prepared_statement has no default implementation",
262        ))
263    }
264
265    /// Get a FlightDataStream containing the list of catalogs.
266    async fn do_get_catalogs(
267        &self,
268        _query: CommandGetCatalogs,
269        _request: Request<Ticket>,
270    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
271        Err(Status::unimplemented(
272            "do_get_catalogs has no default implementation",
273        ))
274    }
275
276    /// Get a FlightDataStream containing the list of schemas.
277    async fn do_get_schemas(
278        &self,
279        _query: CommandGetDbSchemas,
280        _request: Request<Ticket>,
281    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
282        Err(Status::unimplemented(
283            "do_get_schemas has no default implementation",
284        ))
285    }
286
287    /// Get a FlightDataStream containing the list of tables.
288    async fn do_get_tables(
289        &self,
290        _query: CommandGetTables,
291        _request: Request<Ticket>,
292    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
293        Err(Status::unimplemented(
294            "do_get_tables has no default implementation",
295        ))
296    }
297
298    /// Get a FlightDataStream containing the data related to the table types.
299    async fn do_get_table_types(
300        &self,
301        _query: CommandGetTableTypes,
302        _request: Request<Ticket>,
303    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
304        Err(Status::unimplemented(
305            "do_get_table_types has no default implementation",
306        ))
307    }
308
309    /// Get a FlightDataStream containing the list of SqlInfo results.
310    async fn do_get_sql_info(
311        &self,
312        _query: CommandGetSqlInfo,
313        _request: Request<Ticket>,
314    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
315        Err(Status::unimplemented(
316            "do_get_sql_info has no default implementation",
317        ))
318    }
319
320    /// Get a FlightDataStream containing the data related to the primary and foreign keys.
321    async fn do_get_primary_keys(
322        &self,
323        _query: CommandGetPrimaryKeys,
324        _request: Request<Ticket>,
325    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
326        Err(Status::unimplemented(
327            "do_get_primary_keys has no default implementation",
328        ))
329    }
330
331    /// Get a FlightDataStream containing the data related to the exported keys.
332    async fn do_get_exported_keys(
333        &self,
334        _query: CommandGetExportedKeys,
335        _request: Request<Ticket>,
336    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
337        Err(Status::unimplemented(
338            "do_get_exported_keys has no default implementation",
339        ))
340    }
341
342    /// Get a FlightDataStream containing the data related to the imported keys.
343    async fn do_get_imported_keys(
344        &self,
345        _query: CommandGetImportedKeys,
346        _request: Request<Ticket>,
347    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
348        Err(Status::unimplemented(
349            "do_get_imported_keys has no default implementation",
350        ))
351    }
352
353    /// Get a FlightDataStream containing the data related to the cross reference.
354    async fn do_get_cross_reference(
355        &self,
356        _query: CommandGetCrossReference,
357        _request: Request<Ticket>,
358    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
359        Err(Status::unimplemented(
360            "do_get_cross_reference has no default implementation",
361        ))
362    }
363
364    /// Get a FlightDataStream containing the data related to the supported XDBC types.
365    async fn do_get_xdbc_type_info(
366        &self,
367        _query: CommandGetXdbcTypeInfo,
368        _request: Request<Ticket>,
369    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
370        Err(Status::unimplemented(
371            "do_get_xdbc_type_info has no default implementation",
372        ))
373    }
374
375    // do_put
376
377    /// Implementors may override to handle additional calls to do_put()
378    async fn do_put_fallback(
379        &self,
380        _request: Request<PeekableFlightDataStream>,
381        message: Any,
382    ) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
383        Err(Status::unimplemented(format!(
384            "do_put: The defined request is invalid: {}",
385            message.type_url
386        )))
387    }
388
389    /// Implementors may override to handle do_put errors
390    async fn do_put_error_callback(
391        &self,
392        _request: Request<PeekableFlightDataStream>,
393        error: DoPutError,
394    ) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
395        Err(Status::unimplemented(format!("Unhandled Error: {}", error)))
396    }
397
398    /// Execute an update SQL statement.
399    async fn do_put_statement_update(
400        &self,
401        _ticket: CommandStatementUpdate,
402        _request: Request<PeekableFlightDataStream>,
403    ) -> Result<i64, Status> {
404        Err(Status::unimplemented(
405            "do_put_statement_update has no default implementation",
406        ))
407    }
408
409    /// Execute a bulk ingestion.
410    async fn do_put_statement_ingest(
411        &self,
412        _ticket: CommandStatementIngest,
413        _request: Request<PeekableFlightDataStream>,
414    ) -> Result<i64, Status> {
415        Err(Status::unimplemented(
416            "do_put_statement_ingest has no default implementation",
417        ))
418    }
419
420    /// Bind parameters to given prepared statement.
421    ///
422    /// Returns an opaque handle that the client should pass
423    /// back to the server during subsequent requests with this
424    /// prepared statement.
425    async fn do_put_prepared_statement_query(
426        &self,
427        _query: CommandPreparedStatementQuery,
428        _request: Request<PeekableFlightDataStream>,
429    ) -> Result<DoPutPreparedStatementResult, Status> {
430        Err(Status::unimplemented(
431            "do_put_prepared_statement_query has no default implementation",
432        ))
433    }
434
435    /// Execute an update SQL prepared statement.
436    async fn do_put_prepared_statement_update(
437        &self,
438        _query: CommandPreparedStatementUpdate,
439        _request: Request<PeekableFlightDataStream>,
440    ) -> Result<i64, Status> {
441        Err(Status::unimplemented(
442            "do_put_prepared_statement_update has no default implementation",
443        ))
444    }
445
446    /// Execute a substrait plan
447    async fn do_put_substrait_plan(
448        &self,
449        _query: CommandStatementSubstraitPlan,
450        _request: Request<PeekableFlightDataStream>,
451    ) -> Result<i64, Status> {
452        Err(Status::unimplemented(
453            "do_put_substrait_plan has no default implementation",
454        ))
455    }
456
457    // do_action
458
459    /// Implementors may override to handle additional calls to do_action()
460    async fn do_action_fallback(
461        &self,
462        request: Request<Action>,
463    ) -> Result<Response<<Self as FlightService>::DoActionStream>, Status> {
464        Err(Status::invalid_argument(format!(
465            "do_action: The defined request is invalid: {:?}",
466            request.get_ref().r#type
467        )))
468    }
469
470    /// Add custom actions to list_actions() result
471    async fn list_custom_actions(&self) -> Option<Vec<Result<ActionType, Status>>> {
472        None
473    }
474
475    /// Create a prepared statement from given SQL statement.
476    async fn do_action_create_prepared_statement(
477        &self,
478        _query: ActionCreatePreparedStatementRequest,
479        _request: Request<Action>,
480    ) -> Result<ActionCreatePreparedStatementResult, Status> {
481        Err(Status::unimplemented(
482            "do_action_create_prepared_statement has no default implementation",
483        ))
484    }
485
486    /// Close a prepared statement.
487    async fn do_action_close_prepared_statement(
488        &self,
489        _query: ActionClosePreparedStatementRequest,
490        _request: Request<Action>,
491    ) -> Result<(), Status> {
492        Err(Status::unimplemented(
493            "do_action_close_prepared_statement has no default implementation",
494        ))
495    }
496
497    /// Create a prepared substrait plan.
498    async fn do_action_create_prepared_substrait_plan(
499        &self,
500        _query: ActionCreatePreparedSubstraitPlanRequest,
501        _request: Request<Action>,
502    ) -> Result<ActionCreatePreparedStatementResult, Status> {
503        Err(Status::unimplemented(
504            "do_action_create_prepared_substrait_plan has no default implementation",
505        ))
506    }
507
508    /// Begin a transaction
509    async fn do_action_begin_transaction(
510        &self,
511        _query: ActionBeginTransactionRequest,
512        _request: Request<Action>,
513    ) -> Result<ActionBeginTransactionResult, Status> {
514        Err(Status::unimplemented(
515            "do_action_begin_transaction has no default implementation",
516        ))
517    }
518
519    /// End a transaction
520    async fn do_action_end_transaction(
521        &self,
522        _query: ActionEndTransactionRequest,
523        _request: Request<Action>,
524    ) -> Result<(), Status> {
525        Err(Status::unimplemented(
526            "do_action_end_transaction has no default implementation",
527        ))
528    }
529
530    /// Begin a savepoint
531    async fn do_action_begin_savepoint(
532        &self,
533        _query: ActionBeginSavepointRequest,
534        _request: Request<Action>,
535    ) -> Result<ActionBeginSavepointResult, Status> {
536        Err(Status::unimplemented(
537            "do_action_begin_savepoint has no default implementation",
538        ))
539    }
540
541    /// End a savepoint
542    async fn do_action_end_savepoint(
543        &self,
544        _query: ActionEndSavepointRequest,
545        _request: Request<Action>,
546    ) -> Result<(), Status> {
547        Err(Status::unimplemented(
548            "do_action_end_savepoint has no default implementation",
549        ))
550    }
551
552    /// Cancel a query
553    async fn do_action_cancel_query(
554        &self,
555        _query: ActionCancelQueryRequest,
556        _request: Request<Action>,
557    ) -> Result<ActionCancelQueryResult, Status> {
558        Err(Status::unimplemented(
559            "do_action_cancel_query has no default implementation",
560        ))
561    }
562
563    /// do_exchange
564    /// Implementors may override to handle additional calls to do_exchange()
565    async fn do_exchange_fallback(
566        &self,
567        _request: Request<Streaming<FlightData>>,
568    ) -> Result<Response<<Self as FlightService>::DoExchangeStream>, Status> {
569        Err(Status::unimplemented("Not yet implemented"))
570    }
571
572    /// Register a new SqlInfo result, making it available when calling GetSqlInfo.
573    async fn register_sql_info(&self, id: i32, result: &SqlInfo);
574}
575
576/// Implements the lower level interface to handle FlightSQL
577#[tonic::async_trait]
578impl<T: 'static> FlightService for T
579where
580    T: FlightSqlService + Send,
581{
582    type HandshakeStream =
583        Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> + Send + 'static>>;
584    type ListFlightsStream =
585        Pin<Box<dyn Stream<Item = Result<FlightInfo, Status>> + Send + 'static>>;
586    type DoGetStream = Pin<Box<dyn Stream<Item = Result<FlightData, Status>> + Send + 'static>>;
587    type DoPutStream = Pin<Box<dyn Stream<Item = Result<PutResult, Status>> + Send + 'static>>;
588    type DoActionStream =
589        Pin<Box<dyn Stream<Item = Result<super::super::Result, Status>> + Send + 'static>>;
590    type ListActionsStream =
591        Pin<Box<dyn Stream<Item = Result<ActionType, Status>> + Send + 'static>>;
592    type DoExchangeStream =
593        Pin<Box<dyn Stream<Item = Result<FlightData, Status>> + Send + 'static>>;
594
595    async fn handshake(
596        &self,
597        request: Request<Streaming<HandshakeRequest>>,
598    ) -> Result<Response<Self::HandshakeStream>, Status> {
599        let res = self.do_handshake(request).await?;
600        Ok(res)
601    }
602
603    async fn list_flights(
604        &self,
605        _request: Request<Criteria>,
606    ) -> Result<Response<Self::ListFlightsStream>, Status> {
607        Err(Status::unimplemented("Not yet implemented"))
608    }
609
610    async fn get_flight_info(
611        &self,
612        request: Request<FlightDescriptor>,
613    ) -> Result<Response<FlightInfo>, Status> {
614        let message = Any::decode(&*request.get_ref().cmd).map_err(decode_error_to_status)?;
615
616        match Command::try_from(message).map_err(arrow_error_to_status)? {
617            Command::CommandStatementQuery(token) => {
618                self.get_flight_info_statement(token, request).await
619            }
620            Command::CommandPreparedStatementQuery(handle) => {
621                self.get_flight_info_prepared_statement(handle, request)
622                    .await
623            }
624            Command::CommandStatementSubstraitPlan(handle) => {
625                self.get_flight_info_substrait_plan(handle, request).await
626            }
627            Command::CommandGetCatalogs(token) => {
628                self.get_flight_info_catalogs(token, request).await
629            }
630            Command::CommandGetDbSchemas(token) => {
631                return self.get_flight_info_schemas(token, request).await
632            }
633            Command::CommandGetTables(token) => self.get_flight_info_tables(token, request).await,
634            Command::CommandGetTableTypes(token) => {
635                self.get_flight_info_table_types(token, request).await
636            }
637            Command::CommandGetSqlInfo(token) => {
638                self.get_flight_info_sql_info(token, request).await
639            }
640            Command::CommandGetPrimaryKeys(token) => {
641                self.get_flight_info_primary_keys(token, request).await
642            }
643            Command::CommandGetExportedKeys(token) => {
644                self.get_flight_info_exported_keys(token, request).await
645            }
646            Command::CommandGetImportedKeys(token) => {
647                self.get_flight_info_imported_keys(token, request).await
648            }
649            Command::CommandGetCrossReference(token) => {
650                self.get_flight_info_cross_reference(token, request).await
651            }
652            Command::CommandGetXdbcTypeInfo(token) => {
653                self.get_flight_info_xdbc_type_info(token, request).await
654            }
655            cmd => self.get_flight_info_fallback(cmd, request).await,
656        }
657    }
658
659    async fn poll_flight_info(
660        &self,
661        _request: Request<FlightDescriptor>,
662    ) -> Result<Response<PollInfo>, Status> {
663        Err(Status::unimplemented("Not yet implemented"))
664    }
665
666    async fn get_schema(
667        &self,
668        _request: Request<FlightDescriptor>,
669    ) -> Result<Response<SchemaResult>, Status> {
670        Err(Status::unimplemented("Not yet implemented"))
671    }
672
673    async fn do_get(
674        &self,
675        request: Request<Ticket>,
676    ) -> Result<Response<Self::DoGetStream>, Status> {
677        let msg: Any =
678            Message::decode(&*request.get_ref().ticket).map_err(decode_error_to_status)?;
679
680        match Command::try_from(msg).map_err(arrow_error_to_status)? {
681            Command::TicketStatementQuery(command) => self.do_get_statement(command, request).await,
682            Command::CommandPreparedStatementQuery(command) => {
683                self.do_get_prepared_statement(command, request).await
684            }
685            Command::CommandGetCatalogs(command) => self.do_get_catalogs(command, request).await,
686            Command::CommandGetDbSchemas(command) => self.do_get_schemas(command, request).await,
687            Command::CommandGetTables(command) => self.do_get_tables(command, request).await,
688            Command::CommandGetTableTypes(command) => {
689                self.do_get_table_types(command, request).await
690            }
691            Command::CommandGetSqlInfo(command) => self.do_get_sql_info(command, request).await,
692            Command::CommandGetPrimaryKeys(command) => {
693                self.do_get_primary_keys(command, request).await
694            }
695            Command::CommandGetExportedKeys(command) => {
696                self.do_get_exported_keys(command, request).await
697            }
698            Command::CommandGetImportedKeys(command) => {
699                self.do_get_imported_keys(command, request).await
700            }
701            Command::CommandGetCrossReference(command) => {
702                self.do_get_cross_reference(command, request).await
703            }
704            Command::CommandGetXdbcTypeInfo(command) => {
705                self.do_get_xdbc_type_info(command, request).await
706            }
707            cmd => self.do_get_fallback(request, cmd.into_any()).await,
708        }
709    }
710
711    async fn do_put(
712        &self,
713        request: Request<Streaming<FlightData>>,
714    ) -> Result<Response<Self::DoPutStream>, Status> {
715        // See issue #4658: https://github.com/apache/arrow-rs/issues/4658
716        // To dispatch to the correct `do_put` method, we cannot discard the first message,
717        // as it may contain the Arrow schema, which the `do_put` handler may need.
718        // To allow the first message to be reused by the `do_put` handler,
719        // we wrap this stream in a `Peekable` one, which allows us to peek at
720        // the first message without discarding it.
721        let mut request = request.map(PeekableFlightDataStream::new);
722        let mut stream = Pin::new(request.get_mut());
723
724        let peeked_item = stream.peek().await.cloned();
725        let Some(cmd) = peeked_item else {
726            return self
727                .do_put_error_callback(request, DoPutError::MissingCommand)
728                .await;
729        };
730
731        let Some(flight_descriptor) = cmd?.flight_descriptor else {
732            return self
733                .do_put_error_callback(request, DoPutError::MissingFlightDescriptor)
734                .await;
735        };
736        let message = Any::decode(flight_descriptor.cmd).map_err(decode_error_to_status)?;
737        match Command::try_from(message).map_err(arrow_error_to_status)? {
738            Command::CommandStatementUpdate(command) => {
739                let record_count = self.do_put_statement_update(command, request).await?;
740                let result = DoPutUpdateResult { record_count };
741                let output = futures::stream::iter(vec![Ok(PutResult {
742                    app_metadata: result.encode_to_vec().into(),
743                })]);
744                Ok(Response::new(Box::pin(output)))
745            }
746            Command::CommandStatementIngest(command) => {
747                let record_count = self.do_put_statement_ingest(command, request).await?;
748                let result = DoPutUpdateResult { record_count };
749                let output = futures::stream::iter(vec![Ok(PutResult {
750                    app_metadata: result.encode_to_vec().into(),
751                })]);
752                Ok(Response::new(Box::pin(output)))
753            }
754            Command::CommandPreparedStatementQuery(command) => {
755                let result = self
756                    .do_put_prepared_statement_query(command, request)
757                    .await?;
758                let output = futures::stream::iter(vec![Ok(PutResult {
759                    app_metadata: result.encode_to_vec().into(),
760                })]);
761                Ok(Response::new(Box::pin(output)))
762            }
763            Command::CommandStatementSubstraitPlan(command) => {
764                let record_count = self.do_put_substrait_plan(command, request).await?;
765                let result = DoPutUpdateResult { record_count };
766                let output = futures::stream::iter(vec![Ok(PutResult {
767                    app_metadata: result.encode_to_vec().into(),
768                })]);
769                Ok(Response::new(Box::pin(output)))
770            }
771            Command::CommandPreparedStatementUpdate(command) => {
772                let record_count = self
773                    .do_put_prepared_statement_update(command, request)
774                    .await?;
775                let result = DoPutUpdateResult { record_count };
776                let output = futures::stream::iter(vec![Ok(PutResult {
777                    app_metadata: result.encode_to_vec().into(),
778                })]);
779                Ok(Response::new(Box::pin(output)))
780            }
781            cmd => self.do_put_fallback(request, cmd.into_any()).await,
782        }
783    }
784
785    async fn list_actions(
786        &self,
787        _request: Request<Empty>,
788    ) -> Result<Response<Self::ListActionsStream>, Status> {
789        let create_prepared_statement_action_type = ActionType {
790            r#type: CREATE_PREPARED_STATEMENT.to_string(),
791            description: "Creates a reusable prepared statement resource on the server.\n
792                Request Message: ActionCreatePreparedStatementRequest\n
793                Response Message: ActionCreatePreparedStatementResult"
794                .into(),
795        };
796        let close_prepared_statement_action_type = ActionType {
797            r#type: CLOSE_PREPARED_STATEMENT.to_string(),
798            description: "Closes a reusable prepared statement resource on the server.\n
799                Request Message: ActionClosePreparedStatementRequest\n
800                Response Message: N/A"
801                .into(),
802        };
803        let create_prepared_substrait_plan_action_type = ActionType {
804            r#type: CREATE_PREPARED_SUBSTRAIT_PLAN.to_string(),
805            description: "Creates a reusable prepared substrait plan resource on the server.\n
806                Request Message: ActionCreatePreparedSubstraitPlanRequest\n
807                Response Message: ActionCreatePreparedStatementResult"
808                .into(),
809        };
810        let begin_transaction_action_type = ActionType {
811            r#type: BEGIN_TRANSACTION.to_string(),
812            description: "Begins a transaction.\n
813                Request Message: ActionBeginTransactionRequest\n
814                Response Message: ActionBeginTransactionResult"
815                .into(),
816        };
817        let end_transaction_action_type = ActionType {
818            r#type: END_TRANSACTION.to_string(),
819            description: "Ends a transaction\n
820                Request Message: ActionEndTransactionRequest\n
821                Response Message: N/A"
822                .into(),
823        };
824        let begin_savepoint_action_type = ActionType {
825            r#type: BEGIN_SAVEPOINT.to_string(),
826            description: "Begins a savepoint.\n
827                Request Message: ActionBeginSavepointRequest\n
828                Response Message: ActionBeginSavepointResult"
829                .into(),
830        };
831        let end_savepoint_action_type = ActionType {
832            r#type: END_SAVEPOINT.to_string(),
833            description: "Ends a savepoint\n
834                Request Message: ActionEndSavepointRequest\n
835                Response Message: N/A"
836                .into(),
837        };
838        let cancel_query_action_type = ActionType {
839            r#type: CANCEL_QUERY.to_string(),
840            description: "Cancels a query\n
841                Request Message: ActionCancelQueryRequest\n
842                Response Message: ActionCancelQueryResult"
843                .into(),
844        };
845        let mut actions: Vec<Result<ActionType, Status>> = vec![
846            Ok(create_prepared_statement_action_type),
847            Ok(close_prepared_statement_action_type),
848            Ok(create_prepared_substrait_plan_action_type),
849            Ok(begin_transaction_action_type),
850            Ok(end_transaction_action_type),
851            Ok(begin_savepoint_action_type),
852            Ok(end_savepoint_action_type),
853            Ok(cancel_query_action_type),
854        ];
855
856        if let Some(mut custom_actions) = self.list_custom_actions().await {
857            actions.append(&mut custom_actions);
858        }
859
860        let output = futures::stream::iter(actions);
861        Ok(Response::new(Box::pin(output) as Self::ListActionsStream))
862    }
863
864    async fn do_action(
865        &self,
866        request: Request<Action>,
867    ) -> Result<Response<Self::DoActionStream>, Status> {
868        if request.get_ref().r#type == CREATE_PREPARED_STATEMENT {
869            let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
870
871            let cmd: ActionCreatePreparedStatementRequest = any
872                .unpack()
873                .map_err(arrow_error_to_status)?
874                .ok_or_else(|| {
875                    Status::invalid_argument(
876                        "Unable to unpack ActionCreatePreparedStatementRequest.",
877                    )
878                })?;
879            let stmt = self
880                .do_action_create_prepared_statement(cmd, request)
881                .await?;
882            let output = futures::stream::iter(vec![Ok(super::super::gen::Result {
883                body: stmt.as_any().encode_to_vec().into(),
884            })]);
885            return Ok(Response::new(Box::pin(output)));
886        } else if request.get_ref().r#type == CLOSE_PREPARED_STATEMENT {
887            let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
888
889            let cmd: ActionClosePreparedStatementRequest = any
890                .unpack()
891                .map_err(arrow_error_to_status)?
892                .ok_or_else(|| {
893                    Status::invalid_argument(
894                        "Unable to unpack ActionClosePreparedStatementRequest.",
895                    )
896                })?;
897            self.do_action_close_prepared_statement(cmd, request)
898                .await?;
899            return Ok(Response::new(Box::pin(futures::stream::empty())));
900        } else if request.get_ref().r#type == CREATE_PREPARED_SUBSTRAIT_PLAN {
901            let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
902
903            let cmd: ActionCreatePreparedSubstraitPlanRequest = any
904                .unpack()
905                .map_err(arrow_error_to_status)?
906                .ok_or_else(|| {
907                    Status::invalid_argument(
908                        "Unable to unpack ActionCreatePreparedSubstraitPlanRequest.",
909                    )
910                })?;
911            self.do_action_create_prepared_substrait_plan(cmd, request)
912                .await?;
913            return Ok(Response::new(Box::pin(futures::stream::empty())));
914        } else if request.get_ref().r#type == BEGIN_TRANSACTION {
915            let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
916
917            let cmd: ActionBeginTransactionRequest = any
918                .unpack()
919                .map_err(arrow_error_to_status)?
920                .ok_or_else(|| {
921                Status::invalid_argument("Unable to unpack ActionBeginTransactionRequest.")
922            })?;
923            let stmt = self.do_action_begin_transaction(cmd, request).await?;
924            let output = futures::stream::iter(vec![Ok(super::super::gen::Result {
925                body: stmt.as_any().encode_to_vec().into(),
926            })]);
927            return Ok(Response::new(Box::pin(output)));
928        } else if request.get_ref().r#type == END_TRANSACTION {
929            let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
930
931            let cmd: ActionEndTransactionRequest = any
932                .unpack()
933                .map_err(arrow_error_to_status)?
934                .ok_or_else(|| {
935                    Status::invalid_argument("Unable to unpack ActionEndTransactionRequest.")
936                })?;
937            self.do_action_end_transaction(cmd, request).await?;
938            return Ok(Response::new(Box::pin(futures::stream::empty())));
939        } else if request.get_ref().r#type == BEGIN_SAVEPOINT {
940            let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
941
942            let cmd: ActionBeginSavepointRequest = any
943                .unpack()
944                .map_err(arrow_error_to_status)?
945                .ok_or_else(|| {
946                    Status::invalid_argument("Unable to unpack ActionBeginSavepointRequest.")
947                })?;
948            let stmt = self.do_action_begin_savepoint(cmd, request).await?;
949            let output = futures::stream::iter(vec![Ok(super::super::gen::Result {
950                body: stmt.as_any().encode_to_vec().into(),
951            })]);
952            return Ok(Response::new(Box::pin(output)));
953        } else if request.get_ref().r#type == END_SAVEPOINT {
954            let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
955
956            let cmd: ActionEndSavepointRequest = any
957                .unpack()
958                .map_err(arrow_error_to_status)?
959                .ok_or_else(|| {
960                    Status::invalid_argument("Unable to unpack ActionEndSavepointRequest.")
961                })?;
962            self.do_action_end_savepoint(cmd, request).await?;
963            return Ok(Response::new(Box::pin(futures::stream::empty())));
964        } else if request.get_ref().r#type == CANCEL_QUERY {
965            let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
966
967            let cmd: ActionCancelQueryRequest = any
968                .unpack()
969                .map_err(arrow_error_to_status)?
970                .ok_or_else(|| {
971                    Status::invalid_argument("Unable to unpack ActionCancelQueryRequest.")
972                })?;
973            let stmt = self.do_action_cancel_query(cmd, request).await?;
974            let output = futures::stream::iter(vec![Ok(super::super::gen::Result {
975                body: stmt.as_any().encode_to_vec().into(),
976            })]);
977            return Ok(Response::new(Box::pin(output)));
978        }
979
980        self.do_action_fallback(request).await
981    }
982
983    async fn do_exchange(
984        &self,
985        request: Request<Streaming<FlightData>>,
986    ) -> Result<Response<Self::DoExchangeStream>, Status> {
987        self.do_exchange_fallback(request).await
988    }
989}
990
/// Unrecoverable errors associated with `do_put` requests.
///
/// The human-readable message for each variant comes from its `Display` impl;
/// `Debug` is derived so the error can be logged, used with `{:?}`, and compared
/// in tests like any other public error type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DoPutError {
    /// The first element in the request stream is missing the command
    MissingCommand,
    /// The first element in the request stream is missing the flight descriptor
    MissingFlightDescriptor,
}
998impl Display for DoPutError {
999    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1000        match self {
1001            DoPutError::MissingCommand => {
1002                write!(f, "Command is missing.")
1003            }
1004            DoPutError::MissingFlightDescriptor => {
1005                write!(f, "Flight descriptor is missing.")
1006            }
1007        }
1008    }
1009}
1010
1011fn decode_error_to_status(err: prost::DecodeError) -> Status {
1012    Status::invalid_argument(format!("{err:?}"))
1013}
1014
1015fn arrow_error_to_status(err: arrow_schema::ArrowError) -> Status {
1016    Status::internal(format!("{err:?}"))
1017}
1018
1019/// A wrapper around [`Streaming<FlightData>`] that allows "peeking" at the
1020/// message at the front of the stream without consuming it.
1021///
1022/// This is needed because sometimes the first message in the stream will contain
1023/// a [`FlightDescriptor`] in addition to potentially any data, and the dispatch logic
1024/// must inspect this information.
1025///
1026/// # Example
1027///
1028/// [`PeekableFlightDataStream::peek`] can be used to peek at the first message without
1029/// discarding it; otherwise, `PeekableFlightDataStream` can be used as a regular stream.
1030/// See the following example:
1031///
1032/// ```no_run
1033/// use arrow_array::RecordBatch;
1034/// use arrow_flight::decode::FlightRecordBatchStream;
1035/// use arrow_flight::FlightDescriptor;
1036/// use arrow_flight::error::FlightError;
1037/// use arrow_flight::sql::server::PeekableFlightDataStream;
1038/// use tonic::{Request, Status};
1039/// use futures::TryStreamExt;
1040///
1041/// #[tokio::main]
1042/// async fn main() -> Result<(), Status> {
1043///     let request: Request<PeekableFlightDataStream> = todo!();
1044///     let stream: PeekableFlightDataStream = request.into_inner();
1045///
1046///     // The first message contains the flight descriptor and the schema.
1047///     // Read the flight descriptor without discarding the schema:
1048///     let flight_descriptor: FlightDescriptor = stream
1049///         .peek()
1050///         .await
1051///         .cloned()
1052///         .transpose()?
1053///         .and_then(|data| data.flight_descriptor)
1054///         .expect("first message should contain flight descriptor");
1055///
1056///     // Pass the stream through a decoder
1057///     let batches: Vec<RecordBatch> = FlightRecordBatchStream::new_from_flight_data(
1058///         request.into_inner().map_err(|e| e.into()),
1059///     )
1060///     .try_collect()
1061///     .await?;
1062/// }
1063/// ```
1064pub struct PeekableFlightDataStream {
1065    inner: Peekable<Streaming<FlightData>>,
1066}
1067
1068impl PeekableFlightDataStream {
1069    fn new(stream: Streaming<FlightData>) -> Self {
1070        Self {
1071            inner: stream.peekable(),
1072        }
1073    }
1074
1075    /// Convert this stream into a `Streaming<FlightData>`.
1076    /// Any messages observed through [`Self::peek`] will be lost
1077    /// after the conversion.
1078    pub fn into_inner(self) -> Streaming<FlightData> {
1079        self.inner.into_inner()
1080    }
1081
1082    /// Convert this stream into a `Peekable<Streaming<FlightData>>`.
1083    /// Preserves the state of the stream, so that calls to [`Self::peek`]
1084    /// and [`Self::poll_next`] are the same.
1085    pub fn into_peekable(self) -> Peekable<Streaming<FlightData>> {
1086        self.inner
1087    }
1088
1089    /// Peek at the head of this stream without advancing it.
1090    pub async fn peek(&mut self) -> Option<&Result<FlightData, Status>> {
1091        Pin::new(&mut self.inner).peek().await
1092    }
1093}
1094
1095impl Stream for PeekableFlightDataStream {
1096    type Item = Result<FlightData, Status>;
1097
1098    fn poll_next(
1099        mut self: Pin<&mut Self>,
1100        cx: &mut std::task::Context<'_>,
1101    ) -> std::task::Poll<Option<Self::Item>> {
1102        self.inner.poll_next_unpin(cx)
1103    }
1104}