arrow_flight/sql/
client.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! A FlightSQL Client [`FlightSqlServiceClient`]
19
20use base64::prelude::BASE64_STANDARD;
21use base64::Engine;
22use bytes::Bytes;
23use std::collections::HashMap;
24use std::str::FromStr;
25use tonic::metadata::AsciiMetadataKey;
26
27use crate::decode::FlightRecordBatchStream;
28use crate::encode::FlightDataEncoderBuilder;
29use crate::error::FlightError;
30use crate::flight_service_client::FlightServiceClient;
31use crate::sql::gen::action_end_transaction_request::EndTransaction;
32use crate::sql::server::{
33    BEGIN_TRANSACTION, CLOSE_PREPARED_STATEMENT, CREATE_PREPARED_STATEMENT, END_TRANSACTION,
34};
35use crate::sql::{
36    ActionBeginTransactionRequest, ActionBeginTransactionResult,
37    ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
38    ActionCreatePreparedStatementResult, ActionEndTransactionRequest, Any, CommandGetCatalogs,
39    CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys,
40    CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
41    CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate,
42    CommandStatementIngest, CommandStatementQuery, CommandStatementUpdate,
43    DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo,
44};
45use crate::streams::FallibleRequestStream;
46use crate::trailers::extract_lazy_trailers;
47use crate::{
48    Action, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse,
49    IpcMessage, PutResult, Ticket,
50};
51use arrow_array::RecordBatch;
52use arrow_buffer::Buffer;
53use arrow_ipc::convert::fb_to_schema;
54use arrow_ipc::reader::read_record_batch;
55use arrow_ipc::{root_as_message, MessageHeader};
56use arrow_schema::{ArrowError, Schema, SchemaRef};
57use futures::{stream, Stream, TryStreamExt};
58use prost::Message;
59use tonic::transport::Channel;
60use tonic::{IntoRequest, IntoStreamingRequest, Streaming};
61
62/// A FlightSQLServiceClient is an endpoint for retrieving or storing Arrow data
63/// by FlightSQL protocol.
64#[derive(Debug, Clone)]
65pub struct FlightSqlServiceClient<T> {
66    token: Option<String>,
67    headers: HashMap<String, String>,
68    flight_client: FlightServiceClient<T>,
69}
70
/// A FlightSql protocol client that can run queries against FlightSql servers.
///
/// This client is in the "experimental" stage. It is not guaranteed to follow the spec in all instances.
/// GitHub issues are welcome.
74impl FlightSqlServiceClient<Channel> {
75    /// Creates a new FlightSql client that connects to a server over an arbitrary tonic `Channel`
76    pub fn new(channel: Channel) -> Self {
77        Self::new_from_inner(FlightServiceClient::new(channel))
78    }
79
80    /// Creates a new higher level client with the provided lower level client
81    pub fn new_from_inner(inner: FlightServiceClient<Channel>) -> Self {
82        Self {
83            token: None,
84            flight_client: inner,
85            headers: HashMap::default(),
86        }
87    }
88
89    /// Return a reference to the underlying [`FlightServiceClient`]
90    pub fn inner(&self) -> &FlightServiceClient<Channel> {
91        &self.flight_client
92    }
93
94    /// Return a mutable reference to the underlying [`FlightServiceClient`]
95    pub fn inner_mut(&mut self) -> &mut FlightServiceClient<Channel> {
96        &mut self.flight_client
97    }
98
99    /// Consume this client and return the underlying [`FlightServiceClient`]
100    pub fn into_inner(self) -> FlightServiceClient<Channel> {
101        self.flight_client
102    }
103
104    /// Set auth token to the given value.
105    pub fn set_token(&mut self, token: String) {
106        self.token = Some(token);
107    }
108
109    /// Clear the auth token.
110    pub fn clear_token(&mut self) {
111        self.token = None;
112    }
113
114    /// Share the bearer token with potentially different `DoGet` clients
115    pub fn token(&self) -> Option<&String> {
116        self.token.as_ref()
117    }
118
119    /// Set header value.
120    pub fn set_header(&mut self, key: impl Into<String>, value: impl Into<String>) {
121        let key: String = key.into();
122        let value: String = value.into();
123        self.headers.insert(key, value);
124    }
125
126    async fn get_flight_info_for_command<M: ProstMessageExt>(
127        &mut self,
128        cmd: M,
129    ) -> Result<FlightInfo, ArrowError> {
130        let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
131        let req = self.set_request_headers(descriptor.into_request())?;
132        let fi = self
133            .flight_client
134            .get_flight_info(req)
135            .await
136            .map_err(status_to_arrow_error)?
137            .into_inner();
138        Ok(fi)
139    }
140
141    /// Execute a query on the server.
142    pub async fn execute(
143        &mut self,
144        query: String,
145        transaction_id: Option<Bytes>,
146    ) -> Result<FlightInfo, ArrowError> {
147        let cmd = CommandStatementQuery {
148            query,
149            transaction_id,
150        };
151        self.get_flight_info_for_command(cmd).await
152    }
153
154    /// Perform a `handshake` with the server, passing credentials and establishing a session.
155    ///
156    /// If the server returns an "authorization" header, it is automatically parsed and set as
157    /// a token for future requests. Any other data returned by the server in the handshake
158    /// response is returned as a binary blob.
159    pub async fn handshake(&mut self, username: &str, password: &str) -> Result<Bytes, ArrowError> {
160        let cmd = HandshakeRequest {
161            protocol_version: 0,
162            payload: Default::default(),
163        };
164        let mut req = tonic::Request::new(stream::iter(vec![cmd]));
165        let val = BASE64_STANDARD.encode(format!("{username}:{password}"));
166        let val = format!("Basic {val}")
167            .parse()
168            .map_err(|_| ArrowError::ParseError("Cannot parse header".to_string()))?;
169        req.metadata_mut().insert("authorization", val);
170        let req = self.set_request_headers(req)?;
171        let resp = self
172            .flight_client
173            .handshake(req)
174            .await
175            .map_err(|e| ArrowError::IpcError(format!("Can't handshake {e}")))?;
176        if let Some(auth) = resp.metadata().get("authorization") {
177            let auth = auth
178                .to_str()
179                .map_err(|_| ArrowError::ParseError("Can't read auth header".to_string()))?;
180            let bearer = "Bearer ";
181            if !auth.starts_with(bearer) {
182                Err(ArrowError::ParseError("Invalid auth header!".to_string()))?;
183            }
184            let auth = auth[bearer.len()..].to_string();
185            self.token = Some(auth);
186        }
187        let responses: Vec<HandshakeResponse> = resp
188            .into_inner()
189            .try_collect()
190            .await
191            .map_err(|_| ArrowError::ParseError("Can't collect responses".to_string()))?;
192        let resp = match responses.as_slice() {
193            [resp] => resp.payload.clone(),
194            [] => Bytes::new(),
195            _ => Err(ArrowError::ParseError(
196                "Multiple handshake responses".to_string(),
197            ))?,
198        };
199        Ok(resp)
200    }
201
202    /// Execute a update query on the server, and return the number of records affected
203    pub async fn execute_update(
204        &mut self,
205        query: String,
206        transaction_id: Option<Bytes>,
207    ) -> Result<i64, ArrowError> {
208        let cmd = CommandStatementUpdate {
209            query,
210            transaction_id,
211        };
212        let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
213        let req = self.set_request_headers(
214            stream::iter(vec![FlightData {
215                flight_descriptor: Some(descriptor),
216                ..Default::default()
217            }])
218            .into_request(),
219        )?;
220        let mut result = self
221            .flight_client
222            .do_put(req)
223            .await
224            .map_err(status_to_arrow_error)?
225            .into_inner();
226        let result = result
227            .message()
228            .await
229            .map_err(status_to_arrow_error)?
230            .unwrap();
231        let result: DoPutUpdateResult =
232            Message::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
233        Ok(result.record_count)
234    }
235
236    /// Execute a bulk ingest on the server and return the number of records added
237    pub async fn execute_ingest<S>(
238        &mut self,
239        command: CommandStatementIngest,
240        stream: S,
241    ) -> Result<i64, ArrowError>
242    where
243        S: Stream<Item = crate::error::Result<RecordBatch>> + Send + 'static,
244    {
245        let (sender, receiver) = futures::channel::oneshot::channel();
246
247        let descriptor = FlightDescriptor::new_cmd(command.as_any().encode_to_vec());
248        let flight_data = FlightDataEncoderBuilder::new()
249            .with_flight_descriptor(Some(descriptor))
250            .build(stream);
251
252        // Intercept client errors and send them to the one shot channel above
253        let flight_data = Box::pin(flight_data);
254        let flight_data: FallibleRequestStream<FlightData, FlightError> =
255            FallibleRequestStream::new(sender, flight_data);
256
257        let req = self.set_request_headers(flight_data.into_streaming_request())?;
258        let mut result = self
259            .flight_client
260            .do_put(req)
261            .await
262            .map_err(status_to_arrow_error)?
263            .into_inner();
264
265        // check if the there were any errors in the input stream provided note
266        // if receiver.await fails, it means the sender was dropped and there is
267        // no message to return.
268        if let Ok(msg) = receiver.await {
269            return Err(ArrowError::ExternalError(Box::new(msg)));
270        }
271
272        let result = result
273            .message()
274            .await
275            .map_err(status_to_arrow_error)?
276            .unwrap();
277        let result: DoPutUpdateResult =
278            Message::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
279        Ok(result.record_count)
280    }
281
282    /// Request a list of catalogs as tabular FlightInfo results
283    pub async fn get_catalogs(&mut self) -> Result<FlightInfo, ArrowError> {
284        self.get_flight_info_for_command(CommandGetCatalogs {})
285            .await
286    }
287
288    /// Request a list of database schemas as tabular FlightInfo results
289    pub async fn get_db_schemas(
290        &mut self,
291        request: CommandGetDbSchemas,
292    ) -> Result<FlightInfo, ArrowError> {
293        self.get_flight_info_for_command(request).await
294    }
295
296    /// Given a flight ticket, request to be sent the stream. Returns record batch stream reader
297    pub async fn do_get(
298        &mut self,
299        ticket: impl IntoRequest<Ticket>,
300    ) -> Result<FlightRecordBatchStream, ArrowError> {
301        let req = self.set_request_headers(ticket.into_request())?;
302
303        let (md, response_stream, _ext) = self
304            .flight_client
305            .do_get(req)
306            .await
307            .map_err(status_to_arrow_error)?
308            .into_parts();
309        let (response_stream, trailers) = extract_lazy_trailers(response_stream);
310
311        Ok(FlightRecordBatchStream::new_from_flight_data(
312            response_stream.map_err(FlightError::Tonic),
313        )
314        .with_headers(md)
315        .with_trailers(trailers))
316    }
317
318    /// Push a stream to the flight service associated with a particular flight stream.
319    pub async fn do_put(
320        &mut self,
321        request: impl tonic::IntoStreamingRequest<Message = FlightData>,
322    ) -> Result<Streaming<PutResult>, ArrowError> {
323        let req = self.set_request_headers(request.into_streaming_request())?;
324        Ok(self
325            .flight_client
326            .do_put(req)
327            .await
328            .map_err(status_to_arrow_error)?
329            .into_inner())
330    }
331
332    /// DoAction allows a flight client to do a specific action against a flight service
333    pub async fn do_action(
334        &mut self,
335        request: impl IntoRequest<Action>,
336    ) -> Result<Streaming<crate::Result>, ArrowError> {
337        let req = self.set_request_headers(request.into_request())?;
338        Ok(self
339            .flight_client
340            .do_action(req)
341            .await
342            .map_err(status_to_arrow_error)?
343            .into_inner())
344    }
345
346    /// Request a list of tables.
347    pub async fn get_tables(
348        &mut self,
349        request: CommandGetTables,
350    ) -> Result<FlightInfo, ArrowError> {
351        self.get_flight_info_for_command(request).await
352    }
353
354    /// Request the primary keys for a table.
355    pub async fn get_primary_keys(
356        &mut self,
357        request: CommandGetPrimaryKeys,
358    ) -> Result<FlightInfo, ArrowError> {
359        self.get_flight_info_for_command(request).await
360    }
361
362    /// Retrieves a description about the foreign key columns that reference the
363    /// primary key columns of the given table.
364    pub async fn get_exported_keys(
365        &mut self,
366        request: CommandGetExportedKeys,
367    ) -> Result<FlightInfo, ArrowError> {
368        self.get_flight_info_for_command(request).await
369    }
370
371    /// Retrieves the foreign key columns for the given table.
372    pub async fn get_imported_keys(
373        &mut self,
374        request: CommandGetImportedKeys,
375    ) -> Result<FlightInfo, ArrowError> {
376        self.get_flight_info_for_command(request).await
377    }
378
379    /// Retrieves a description of the foreign key columns in the given foreign key
380    /// table that reference the primary key or the columns representing a unique
381    /// constraint of the parent table (could be the same or a different table).
382    pub async fn get_cross_reference(
383        &mut self,
384        request: CommandGetCrossReference,
385    ) -> Result<FlightInfo, ArrowError> {
386        self.get_flight_info_for_command(request).await
387    }
388
389    /// Request a list of table types.
390    pub async fn get_table_types(&mut self) -> Result<FlightInfo, ArrowError> {
391        self.get_flight_info_for_command(CommandGetTableTypes {})
392            .await
393    }
394
395    /// Request a list of SQL information.
396    pub async fn get_sql_info(
397        &mut self,
398        sql_infos: Vec<SqlInfo>,
399    ) -> Result<FlightInfo, ArrowError> {
400        let request = CommandGetSqlInfo {
401            info: sql_infos.iter().map(|sql_info| *sql_info as u32).collect(),
402        };
403        self.get_flight_info_for_command(request).await
404    }
405
406    /// Request XDBC SQL information.
407    pub async fn get_xdbc_type_info(
408        &mut self,
409        request: CommandGetXdbcTypeInfo,
410    ) -> Result<FlightInfo, ArrowError> {
411        self.get_flight_info_for_command(request).await
412    }
413
414    /// Create a prepared statement object.
415    pub async fn prepare(
416        &mut self,
417        query: String,
418        transaction_id: Option<Bytes>,
419    ) -> Result<PreparedStatement<Channel>, ArrowError> {
420        let cmd = ActionCreatePreparedStatementRequest {
421            query,
422            transaction_id,
423        };
424        let action = Action {
425            r#type: CREATE_PREPARED_STATEMENT.to_string(),
426            body: cmd.as_any().encode_to_vec().into(),
427        };
428        let req = self.set_request_headers(action.into_request())?;
429        let mut result = self
430            .flight_client
431            .do_action(req)
432            .await
433            .map_err(status_to_arrow_error)?
434            .into_inner();
435        let result = result
436            .message()
437            .await
438            .map_err(status_to_arrow_error)?
439            .unwrap();
440        let any = Any::decode(&*result.body).map_err(decode_error_to_arrow_error)?;
441        let prepared_result: ActionCreatePreparedStatementResult = any.unpack()?.unwrap();
442        let dataset_schema = match prepared_result.dataset_schema.len() {
443            0 => Schema::empty(),
444            _ => Schema::try_from(IpcMessage(prepared_result.dataset_schema))?,
445        };
446        let parameter_schema = match prepared_result.parameter_schema.len() {
447            0 => Schema::empty(),
448            _ => Schema::try_from(IpcMessage(prepared_result.parameter_schema))?,
449        };
450        Ok(PreparedStatement::new(
451            self.clone(),
452            prepared_result.prepared_statement_handle,
453            dataset_schema,
454            parameter_schema,
455        ))
456    }
457
458    /// Request to begin a transaction.
459    pub async fn begin_transaction(&mut self) -> Result<Bytes, ArrowError> {
460        let cmd = ActionBeginTransactionRequest {};
461        let action = Action {
462            r#type: BEGIN_TRANSACTION.to_string(),
463            body: cmd.as_any().encode_to_vec().into(),
464        };
465        let req = self.set_request_headers(action.into_request())?;
466        let mut result = self
467            .flight_client
468            .do_action(req)
469            .await
470            .map_err(status_to_arrow_error)?
471            .into_inner();
472        let result = result
473            .message()
474            .await
475            .map_err(status_to_arrow_error)?
476            .unwrap();
477        let any = Any::decode(&*result.body).map_err(decode_error_to_arrow_error)?;
478        let begin_result: ActionBeginTransactionResult = any.unpack()?.unwrap();
479        Ok(begin_result.transaction_id)
480    }
481
482    /// Request to commit/rollback a transaction.
483    pub async fn end_transaction(
484        &mut self,
485        transaction_id: Bytes,
486        action: EndTransaction,
487    ) -> Result<(), ArrowError> {
488        let cmd = ActionEndTransactionRequest {
489            transaction_id,
490            action: action as i32,
491        };
492        let action = Action {
493            r#type: END_TRANSACTION.to_string(),
494            body: cmd.as_any().encode_to_vec().into(),
495        };
496        let req = self.set_request_headers(action.into_request())?;
497        let _ = self
498            .flight_client
499            .do_action(req)
500            .await
501            .map_err(status_to_arrow_error)?
502            .into_inner();
503        Ok(())
504    }
505
506    /// Explicitly shut down and clean up the client.
507    pub async fn close(&mut self) -> Result<(), ArrowError> {
508        // TODO: consume self instead of &mut self to explicitly prevent reuse?
509        Ok(())
510    }
511
512    fn set_request_headers<T>(
513        &self,
514        mut req: tonic::Request<T>,
515    ) -> Result<tonic::Request<T>, ArrowError> {
516        for (k, v) in &self.headers {
517            let k = AsciiMetadataKey::from_str(k.as_str()).map_err(|e| {
518                ArrowError::ParseError(format!("Cannot convert header key \"{k}\": {e}"))
519            })?;
520            let v = v.parse().map_err(|e| {
521                ArrowError::ParseError(format!("Cannot convert header value \"{v}\": {e}"))
522            })?;
523            req.metadata_mut().insert(k, v);
524        }
525        if let Some(token) = &self.token {
526            let val = format!("Bearer {token}").parse().map_err(|e| {
527                ArrowError::ParseError(format!("Cannot convert token to header value: {e}"))
528            })?;
529            req.metadata_mut().insert("authorization", val);
530        }
531        Ok(req)
532    }
533}
534
535/// A PreparedStatement
536#[derive(Debug, Clone)]
537pub struct PreparedStatement<T> {
538    flight_sql_client: FlightSqlServiceClient<T>,
539    parameter_binding: Option<RecordBatch>,
540    handle: Bytes,
541    dataset_schema: Schema,
542    parameter_schema: Schema,
543}
544
545impl PreparedStatement<Channel> {
546    pub(crate) fn new(
547        flight_client: FlightSqlServiceClient<Channel>,
548        handle: impl Into<Bytes>,
549        dataset_schema: Schema,
550        parameter_schema: Schema,
551    ) -> Self {
552        PreparedStatement {
553            flight_sql_client: flight_client,
554            parameter_binding: None,
555            handle: handle.into(),
556            dataset_schema,
557            parameter_schema,
558        }
559    }
560
561    /// Executes the prepared statement query on the server.
562    pub async fn execute(&mut self) -> Result<FlightInfo, ArrowError> {
563        self.write_bind_params().await?;
564
565        let cmd = CommandPreparedStatementQuery {
566            prepared_statement_handle: self.handle.clone(),
567        };
568
569        let result = self
570            .flight_sql_client
571            .get_flight_info_for_command(cmd)
572            .await?;
573        Ok(result)
574    }
575
576    /// Executes the prepared statement update query on the server.
577    pub async fn execute_update(&mut self) -> Result<i64, ArrowError> {
578        self.write_bind_params().await?;
579
580        let cmd = CommandPreparedStatementUpdate {
581            prepared_statement_handle: self.handle.clone(),
582        };
583        let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
584        let mut result = self
585            .flight_sql_client
586            .do_put(stream::iter(vec![FlightData {
587                flight_descriptor: Some(descriptor),
588                ..Default::default()
589            }]))
590            .await?;
591        let result = result
592            .message()
593            .await
594            .map_err(status_to_arrow_error)?
595            .unwrap();
596        let result: DoPutUpdateResult =
597            Message::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
598        Ok(result.record_count)
599    }
600
601    /// Retrieve the parameter schema from the query.
602    pub fn parameter_schema(&self) -> Result<&Schema, ArrowError> {
603        Ok(&self.parameter_schema)
604    }
605
606    /// Retrieve the ResultSet schema from the query.
607    pub fn dataset_schema(&self) -> Result<&Schema, ArrowError> {
608        Ok(&self.dataset_schema)
609    }
610
611    /// Set a RecordBatch that contains the parameters that will be bind.
612    pub fn set_parameters(&mut self, parameter_binding: RecordBatch) -> Result<(), ArrowError> {
613        self.parameter_binding = Some(parameter_binding);
614        Ok(())
615    }
616
617    /// Submit parameters to the server, if any have been set on this prepared statement instance
618    /// Updates our stored prepared statement handle with the handle given by the server response.
619    async fn write_bind_params(&mut self) -> Result<(), ArrowError> {
620        if let Some(ref params_batch) = self.parameter_binding {
621            let cmd = CommandPreparedStatementQuery {
622                prepared_statement_handle: self.handle.clone(),
623            };
624
625            let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
626            let flight_stream_builder = FlightDataEncoderBuilder::new()
627                .with_flight_descriptor(Some(descriptor))
628                .with_schema(params_batch.schema());
629            let flight_data = flight_stream_builder
630                .build(futures::stream::iter(
631                    self.parameter_binding.clone().map(Ok),
632                ))
633                .try_collect::<Vec<_>>()
634                .await
635                .map_err(flight_error_to_arrow_error)?;
636
637            // Attempt to update the stored handle with any updated handle in the DoPut result.
638            // Older servers do not respond with a result for DoPut, so skip this step when
639            // the stream closes with no response.
640            if let Some(result) = self
641                .flight_sql_client
642                .do_put(stream::iter(flight_data))
643                .await?
644                .message()
645                .await
646                .map_err(status_to_arrow_error)?
647            {
648                if let Some(handle) = self.unpack_prepared_statement_handle(&result)? {
649                    self.handle = handle;
650                }
651            }
652        }
653        Ok(())
654    }
655
656    /// Decodes the app_metadata stored in a [`PutResult`] as a
657    /// [`DoPutPreparedStatementResult`] and then returns
658    /// the inner prepared statement handle as [`Bytes`]
659    fn unpack_prepared_statement_handle(
660        &self,
661        put_result: &PutResult,
662    ) -> Result<Option<Bytes>, ArrowError> {
663        let result: DoPutPreparedStatementResult =
664            Message::decode(&*put_result.app_metadata).map_err(decode_error_to_arrow_error)?;
665        Ok(result.prepared_statement_handle)
666    }
667
668    /// Close the prepared statement, so that this PreparedStatement can not used
669    /// anymore and server can free up any resources.
670    pub async fn close(mut self) -> Result<(), ArrowError> {
671        let cmd = ActionClosePreparedStatementRequest {
672            prepared_statement_handle: self.handle.clone(),
673        };
674        let action = Action {
675            r#type: CLOSE_PREPARED_STATEMENT.to_string(),
676            body: cmd.as_any().encode_to_vec().into(),
677        };
678        let _ = self.flight_sql_client.do_action(action).await?;
679        Ok(())
680    }
681}
682
683fn decode_error_to_arrow_error(err: prost::DecodeError) -> ArrowError {
684    ArrowError::IpcError(err.to_string())
685}
686
687fn status_to_arrow_error(status: tonic::Status) -> ArrowError {
688    ArrowError::IpcError(format!("{status:?}"))
689}
690
691fn flight_error_to_arrow_error(err: FlightError) -> ArrowError {
692    match err {
693        FlightError::Arrow(e) => e,
694        e => ArrowError::ExternalError(Box::new(e)),
695    }
696}
697
698/// A polymorphic structure to natively represent different types of data contained in `FlightData`
699pub enum ArrowFlightData {
700    /// A record batch
701    RecordBatch(RecordBatch),
702    /// A schema
703    Schema(Schema),
704}
705
706/// Extract `Schema` or `RecordBatch`es from the `FlightData` wire representation
707pub fn arrow_data_from_flight_data(
708    flight_data: FlightData,
709    arrow_schema_ref: &SchemaRef,
710) -> Result<ArrowFlightData, ArrowError> {
711    let ipc_message = root_as_message(&flight_data.data_header[..])
712        .map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}")))?;
713
714    match ipc_message.header_type() {
715        MessageHeader::RecordBatch => {
716            let ipc_record_batch = ipc_message.header_as_record_batch().ok_or_else(|| {
717                ArrowError::ComputeError(
718                    "Unable to convert flight data header to a record batch".to_string(),
719                )
720            })?;
721
722            let dictionaries_by_field = HashMap::new();
723            let record_batch = read_record_batch(
724                &Buffer::from(flight_data.data_body),
725                ipc_record_batch,
726                arrow_schema_ref.clone(),
727                &dictionaries_by_field,
728                None,
729                &ipc_message.version(),
730            )?;
731            Ok(ArrowFlightData::RecordBatch(record_batch))
732        }
733        MessageHeader::Schema => {
734            let ipc_schema = ipc_message.header_as_schema().ok_or_else(|| {
735                ArrowError::ComputeError(
736                    "Unable to convert flight data header to a schema".to_string(),
737                )
738            })?;
739
740            let arrow_schema = fb_to_schema(ipc_schema);
741            Ok(ArrowFlightData::Schema(arrow_schema))
742        }
743        MessageHeader::DictionaryBatch => {
744            let _ = ipc_message.header_as_dictionary_batch().ok_or_else(|| {
745                ArrowError::ComputeError(
746                    "Unable to convert flight data header to a dictionary batch".to_string(),
747                )
748            })?;
749            Err(ArrowError::NotYetImplemented(
750                "no idea on how to convert an ipc dictionary batch to an arrow type".to_string(),
751            ))
752        }
753        MessageHeader::Tensor => {
754            let _ = ipc_message.header_as_tensor().ok_or_else(|| {
755                ArrowError::ComputeError(
756                    "Unable to convert flight data header to a tensor".to_string(),
757                )
758            })?;
759            Err(ArrowError::NotYetImplemented(
760                "no idea on how to convert an ipc tensor to an arrow type".to_string(),
761            ))
762        }
763        MessageHeader::SparseTensor => {
764            let _ = ipc_message.header_as_sparse_tensor().ok_or_else(|| {
765                ArrowError::ComputeError(
766                    "Unable to convert flight data header to a sparse tensor".to_string(),
767                )
768            })?;
769            Err(ArrowError::NotYetImplemented(
770                "no idea on how to convert an ipc sparse tensor to an arrow type".to_string(),
771            ))
772        }
773        _ => Err(ArrowError::ComputeError(format!(
774            "Unable to convert message with header_type: '{:?}' to arrow data",
775            ipc_message.header_type()
776        ))),
777    }
778}