// arrow_flight/sql/client.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! A FlightSQL Client [`FlightSqlServiceClient`]

20use base64::Engine;
21use base64::prelude::BASE64_STANDARD;
22use bytes::Bytes;
23use std::collections::HashMap;
24use std::str::FromStr;
25use tonic::metadata::AsciiMetadataKey;
26
27use crate::decode::FlightRecordBatchStream;
28use crate::encode::FlightDataEncoderBuilder;
29use crate::error::FlightError;
30use crate::flight_service_client::FlightServiceClient;
31use crate::sql::r#gen::action_end_transaction_request::EndTransaction;
32use crate::sql::server::{
33    BEGIN_TRANSACTION, CLOSE_PREPARED_STATEMENT, CREATE_PREPARED_STATEMENT, END_TRANSACTION,
34};
35use crate::sql::{
36    ActionBeginTransactionRequest, ActionBeginTransactionResult,
37    ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
38    ActionCreatePreparedStatementResult, ActionEndTransactionRequest, Any, CommandGetCatalogs,
39    CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys,
40    CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
41    CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate,
42    CommandStatementIngest, CommandStatementQuery, CommandStatementUpdate,
43    DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo,
44};
45use crate::streams::FallibleRequestStream;
46use crate::trailers::extract_lazy_trailers;
47use crate::{
48    Action, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse,
49    IpcMessage, PutResult, Ticket,
50};
51use arrow_array::RecordBatch;
52use arrow_buffer::Buffer;
53use arrow_ipc::convert::fb_to_schema;
54use arrow_ipc::reader::read_record_batch;
55use arrow_ipc::{MessageHeader, root_as_message};
56use arrow_schema::{ArrowError, Schema, SchemaRef};
57use futures::{Stream, TryStreamExt, stream};
58use prost::Message;
59use tonic::codegen::{Body, StdError};
60use tonic::{IntoRequest, IntoStreamingRequest, Streaming};
61
/// A FlightSQLServiceClient is an endpoint for retrieving or storing Arrow data
/// by FlightSQL protocol.
///
/// It wraps a lower-level [`FlightServiceClient`]. Requests issued through its
/// methods carry the configured custom headers and, when set, a
/// `Bearer` token in the `authorization` header.
#[derive(Debug)]
pub struct FlightSqlServiceClient<T> {
    /// Bearer token sent as `authorization: Bearer <token>`; `None` sends no token.
    token: Option<String>,
    /// Extra request metadata (header name -> value) attached to each request.
    headers: HashMap<String, String>,
    /// The underlying Arrow Flight gRPC client that performs the RPCs.
    flight_client: FlightServiceClient<T>,
}

71/// A FlightSql protocol client that can run queries against FlightSql servers
72/// This client is in the "experimental" stage. It is not guaranteed to follow the spec in all instances.
73/// Github issues are welcomed.
74impl<T> FlightSqlServiceClient<T>
75where
76    T: tonic::client::GrpcService<tonic::body::Body>,
77    T::Error: Into<StdError>,
78    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
79    <T::ResponseBody as Body>::Error: Into<StdError> + Send,
80{
81    /// Creates a new FlightSql client that connects to a server over an arbitrary tonic `Channel`
82    pub fn new(channel: T) -> Self {
83        Self::new_from_inner(FlightServiceClient::new(channel))
84    }
85
86    /// Creates a new higher level client with the provided lower level client
87    pub fn new_from_inner(inner: FlightServiceClient<T>) -> Self {
88        Self {
89            token: None,
90            flight_client: inner,
91            headers: HashMap::default(),
92        }
93    }
94
95    /// Return a reference to the underlying [`FlightServiceClient`]
96    pub fn inner(&self) -> &FlightServiceClient<T> {
97        &self.flight_client
98    }
99
100    /// Return a mutable reference to the underlying [`FlightServiceClient`]
101    pub fn inner_mut(&mut self) -> &mut FlightServiceClient<T> {
102        &mut self.flight_client
103    }
104
105    /// Consume this client and return the underlying [`FlightServiceClient`]
106    pub fn into_inner(self) -> FlightServiceClient<T> {
107        self.flight_client
108    }
109
110    /// Set auth token to the given value.
111    pub fn set_token(&mut self, token: String) {
112        self.token = Some(token);
113    }
114
115    /// Clear the auth token.
116    pub fn clear_token(&mut self) {
117        self.token = None;
118    }
119
120    /// Share the bearer token with potentially different `DoGet` clients
121    pub fn token(&self) -> Option<&String> {
122        self.token.as_ref()
123    }
124
125    /// Set header value.
126    pub fn set_header(&mut self, key: impl Into<String>, value: impl Into<String>) {
127        let key: String = key.into();
128        let value: String = value.into();
129        self.headers.insert(key, value);
130    }
131
132    async fn get_flight_info_for_command<M: ProstMessageExt>(
133        &mut self,
134        cmd: M,
135    ) -> Result<FlightInfo, ArrowError> {
136        let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
137        let req = self.set_request_headers(descriptor.into_request())?;
138        let fi = self
139            .flight_client
140            .get_flight_info(req)
141            .await
142            .map_err(status_to_arrow_error)?
143            .into_inner();
144        Ok(fi)
145    }
146
147    /// Execute a query on the server.
148    pub async fn execute(
149        &mut self,
150        query: String,
151        transaction_id: Option<Bytes>,
152    ) -> Result<FlightInfo, ArrowError> {
153        let cmd = CommandStatementQuery {
154            query,
155            transaction_id,
156        };
157        self.get_flight_info_for_command(cmd).await
158    }
159
160    /// Perform a `handshake` with the server, passing credentials and establishing a session.
161    ///
162    /// If the server returns an "authorization" header, it is automatically parsed and set as
163    /// a token for future requests. Any other data returned by the server in the handshake
164    /// response is returned as a binary blob.
165    pub async fn handshake(&mut self, username: &str, password: &str) -> Result<Bytes, ArrowError> {
166        let cmd = HandshakeRequest {
167            protocol_version: 0,
168            payload: Default::default(),
169        };
170        let mut req = tonic::Request::new(stream::iter(vec![cmd]));
171        let val = BASE64_STANDARD.encode(format!("{username}:{password}"));
172        let val = format!("Basic {val}")
173            .parse()
174            .map_err(|_| ArrowError::ParseError("Cannot parse header".to_string()))?;
175        req.metadata_mut().insert("authorization", val);
176        let req = self.set_request_headers(req)?;
177        let resp = self
178            .flight_client
179            .handshake(req)
180            .await
181            .map_err(|e| ArrowError::IpcError(format!("Can't handshake {e}")))?;
182        if let Some(auth) = resp.metadata().get("authorization") {
183            let auth = auth
184                .to_str()
185                .map_err(|_| ArrowError::ParseError("Can't read auth header".to_string()))?;
186            let bearer = "Bearer ";
187            if !auth.starts_with(bearer) {
188                Err(ArrowError::ParseError("Invalid auth header!".to_string()))?;
189            }
190            let auth = auth[bearer.len()..].to_string();
191            self.token = Some(auth);
192        }
193        let responses: Vec<HandshakeResponse> = resp
194            .into_inner()
195            .try_collect()
196            .await
197            .map_err(|_| ArrowError::ParseError("Can't collect responses".to_string()))?;
198        let resp = match responses.as_slice() {
199            [resp] => resp.payload.clone(),
200            [] => Bytes::new(),
201            _ => Err(ArrowError::ParseError(
202                "Multiple handshake responses".to_string(),
203            ))?,
204        };
205        Ok(resp)
206    }
207
208    /// Execute a update query on the server, and return the number of records affected
209    pub async fn execute_update(
210        &mut self,
211        query: String,
212        transaction_id: Option<Bytes>,
213    ) -> Result<i64, ArrowError> {
214        let cmd = CommandStatementUpdate {
215            query,
216            transaction_id,
217        };
218        let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
219        let req = self.set_request_headers(
220            stream::iter(vec![FlightData {
221                flight_descriptor: Some(descriptor),
222                ..Default::default()
223            }])
224            .into_request(),
225        )?;
226        let mut result = self
227            .flight_client
228            .do_put(req)
229            .await
230            .map_err(status_to_arrow_error)?
231            .into_inner();
232        let result = result
233            .message()
234            .await
235            .map_err(status_to_arrow_error)?
236            .unwrap();
237        let result: DoPutUpdateResult =
238            Message::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
239        Ok(result.record_count)
240    }
241
242    /// Execute a bulk ingest on the server and return the number of records added
243    pub async fn execute_ingest<S>(
244        &mut self,
245        command: CommandStatementIngest,
246        stream: S,
247    ) -> Result<i64, ArrowError>
248    where
249        S: Stream<Item = crate::error::Result<RecordBatch>> + Send + 'static,
250    {
251        let (sender, receiver) = futures::channel::oneshot::channel();
252
253        let descriptor = FlightDescriptor::new_cmd(command.as_any().encode_to_vec());
254        let flight_data = FlightDataEncoderBuilder::new()
255            .with_flight_descriptor(Some(descriptor))
256            .build(stream);
257
258        // Intercept client errors and send them to the one shot channel above
259        let flight_data = Box::pin(flight_data);
260        let flight_data: FallibleRequestStream<FlightData, FlightError> =
261            FallibleRequestStream::new(sender, flight_data);
262
263        let req = self.set_request_headers(flight_data.into_streaming_request())?;
264        let mut result = self
265            .flight_client
266            .do_put(req)
267            .await
268            .map_err(status_to_arrow_error)?
269            .into_inner();
270
271        // check if the there were any errors in the input stream provided note
272        // if receiver.await fails, it means the sender was dropped and there is
273        // no message to return.
274        if let Ok(msg) = receiver.await {
275            return Err(ArrowError::ExternalError(Box::new(msg)));
276        }
277
278        let result = result
279            .message()
280            .await
281            .map_err(status_to_arrow_error)?
282            .unwrap();
283        let result: DoPutUpdateResult =
284            Message::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
285        Ok(result.record_count)
286    }
287
288    /// Request a list of catalogs as tabular FlightInfo results
289    pub async fn get_catalogs(&mut self) -> Result<FlightInfo, ArrowError> {
290        self.get_flight_info_for_command(CommandGetCatalogs {})
291            .await
292    }
293
294    /// Request a list of database schemas as tabular FlightInfo results
295    pub async fn get_db_schemas(
296        &mut self,
297        request: CommandGetDbSchemas,
298    ) -> Result<FlightInfo, ArrowError> {
299        self.get_flight_info_for_command(request).await
300    }
301
302    /// Given a flight ticket, request to be sent the stream. Returns record batch stream reader
303    pub async fn do_get(
304        &mut self,
305        ticket: impl IntoRequest<Ticket>,
306    ) -> Result<FlightRecordBatchStream, ArrowError> {
307        let req = self.set_request_headers(ticket.into_request())?;
308
309        let (md, response_stream, _ext) = self
310            .flight_client
311            .do_get(req)
312            .await
313            .map_err(status_to_arrow_error)?
314            .into_parts();
315        let (response_stream, trailers) = extract_lazy_trailers(response_stream);
316
317        Ok(FlightRecordBatchStream::new_from_flight_data(
318            response_stream.map_err(|status| status.into()),
319        )
320        .with_headers(md)
321        .with_trailers(trailers))
322    }
323
324    /// Push a stream to the flight service associated with a particular flight stream.
325    pub async fn do_put(
326        &mut self,
327        request: impl tonic::IntoStreamingRequest<Message = FlightData>,
328    ) -> Result<Streaming<PutResult>, ArrowError> {
329        let req = self.set_request_headers(request.into_streaming_request())?;
330        Ok(self
331            .flight_client
332            .do_put(req)
333            .await
334            .map_err(status_to_arrow_error)?
335            .into_inner())
336    }
337
338    /// DoAction allows a flight client to do a specific action against a flight service
339    pub async fn do_action(
340        &mut self,
341        request: impl IntoRequest<Action>,
342    ) -> Result<Streaming<crate::Result>, ArrowError> {
343        let req = self.set_request_headers(request.into_request())?;
344        Ok(self
345            .flight_client
346            .do_action(req)
347            .await
348            .map_err(status_to_arrow_error)?
349            .into_inner())
350    }
351
352    /// Request a list of tables.
353    pub async fn get_tables(
354        &mut self,
355        request: CommandGetTables,
356    ) -> Result<FlightInfo, ArrowError> {
357        self.get_flight_info_for_command(request).await
358    }
359
360    /// Request the primary keys for a table.
361    pub async fn get_primary_keys(
362        &mut self,
363        request: CommandGetPrimaryKeys,
364    ) -> Result<FlightInfo, ArrowError> {
365        self.get_flight_info_for_command(request).await
366    }
367
368    /// Retrieves a description about the foreign key columns that reference the
369    /// primary key columns of the given table.
370    pub async fn get_exported_keys(
371        &mut self,
372        request: CommandGetExportedKeys,
373    ) -> Result<FlightInfo, ArrowError> {
374        self.get_flight_info_for_command(request).await
375    }
376
377    /// Retrieves the foreign key columns for the given table.
378    pub async fn get_imported_keys(
379        &mut self,
380        request: CommandGetImportedKeys,
381    ) -> Result<FlightInfo, ArrowError> {
382        self.get_flight_info_for_command(request).await
383    }
384
385    /// Retrieves a description of the foreign key columns in the given foreign key
386    /// table that reference the primary key or the columns representing a unique
387    /// constraint of the parent table (could be the same or a different table).
388    pub async fn get_cross_reference(
389        &mut self,
390        request: CommandGetCrossReference,
391    ) -> Result<FlightInfo, ArrowError> {
392        self.get_flight_info_for_command(request).await
393    }
394
395    /// Request a list of table types.
396    pub async fn get_table_types(&mut self) -> Result<FlightInfo, ArrowError> {
397        self.get_flight_info_for_command(CommandGetTableTypes {})
398            .await
399    }
400
401    /// Request a list of SQL information.
402    pub async fn get_sql_info(
403        &mut self,
404        sql_infos: Vec<SqlInfo>,
405    ) -> Result<FlightInfo, ArrowError> {
406        let request = CommandGetSqlInfo {
407            info: sql_infos.iter().map(|sql_info| *sql_info as u32).collect(),
408        };
409        self.get_flight_info_for_command(request).await
410    }
411
412    /// Request XDBC SQL information.
413    pub async fn get_xdbc_type_info(
414        &mut self,
415        request: CommandGetXdbcTypeInfo,
416    ) -> Result<FlightInfo, ArrowError> {
417        self.get_flight_info_for_command(request).await
418    }
419
420    /// Create a prepared statement object.
421    pub async fn prepare(
422        &mut self,
423        query: String,
424        transaction_id: Option<Bytes>,
425    ) -> Result<PreparedStatement<T>, ArrowError>
426    where
427        T: Clone,
428    {
429        let cmd = ActionCreatePreparedStatementRequest {
430            query,
431            transaction_id,
432        };
433        let action = Action {
434            r#type: CREATE_PREPARED_STATEMENT.to_string(),
435            body: cmd.as_any().encode_to_vec().into(),
436        };
437        let req = self.set_request_headers(action.into_request())?;
438        let mut result = self
439            .flight_client
440            .do_action(req)
441            .await
442            .map_err(status_to_arrow_error)?
443            .into_inner();
444        let result = result
445            .message()
446            .await
447            .map_err(status_to_arrow_error)?
448            .unwrap();
449        let any = Any::decode(&*result.body).map_err(decode_error_to_arrow_error)?;
450        let prepared_result: ActionCreatePreparedStatementResult = any.unpack()?.unwrap();
451        let dataset_schema = match prepared_result.dataset_schema.len() {
452            0 => Schema::empty(),
453            _ => Schema::try_from(IpcMessage(prepared_result.dataset_schema))?,
454        };
455        let parameter_schema = match prepared_result.parameter_schema.len() {
456            0 => Schema::empty(),
457            _ => Schema::try_from(IpcMessage(prepared_result.parameter_schema))?,
458        };
459        Ok(PreparedStatement::new(
460            self.clone(),
461            prepared_result.prepared_statement_handle,
462            dataset_schema,
463            parameter_schema,
464        ))
465    }
466
467    /// Request to begin a transaction.
468    pub async fn begin_transaction(&mut self) -> Result<Bytes, ArrowError> {
469        let cmd = ActionBeginTransactionRequest {};
470        let action = Action {
471            r#type: BEGIN_TRANSACTION.to_string(),
472            body: cmd.as_any().encode_to_vec().into(),
473        };
474        let req = self.set_request_headers(action.into_request())?;
475        let mut result = self
476            .flight_client
477            .do_action(req)
478            .await
479            .map_err(status_to_arrow_error)?
480            .into_inner();
481        let result = result
482            .message()
483            .await
484            .map_err(status_to_arrow_error)?
485            .unwrap();
486        let any = Any::decode(&*result.body).map_err(decode_error_to_arrow_error)?;
487        let begin_result: ActionBeginTransactionResult = any.unpack()?.unwrap();
488        Ok(begin_result.transaction_id)
489    }
490
491    /// Request to commit/rollback a transaction.
492    pub async fn end_transaction(
493        &mut self,
494        transaction_id: Bytes,
495        action: EndTransaction,
496    ) -> Result<(), ArrowError> {
497        let cmd = ActionEndTransactionRequest {
498            transaction_id,
499            action: action as i32,
500        };
501        let action = Action {
502            r#type: END_TRANSACTION.to_string(),
503            body: cmd.as_any().encode_to_vec().into(),
504        };
505        let req = self.set_request_headers(action.into_request())?;
506        let _ = self
507            .flight_client
508            .do_action(req)
509            .await
510            .map_err(status_to_arrow_error)?
511            .into_inner();
512        Ok(())
513    }
514
515    /// Explicitly shut down and clean up the client.
516    pub async fn close(&mut self) -> Result<(), ArrowError> {
517        // TODO: consume self instead of &mut self to explicitly prevent reuse?
518        Ok(())
519    }
520
521    fn set_request_headers<M>(
522        &self,
523        mut req: tonic::Request<M>,
524    ) -> Result<tonic::Request<M>, ArrowError> {
525        for (k, v) in &self.headers {
526            let k = AsciiMetadataKey::from_str(k.as_str()).map_err(|e| {
527                ArrowError::ParseError(format!("Cannot convert header key \"{k}\": {e}"))
528            })?;
529            let v = v.parse().map_err(|e| {
530                ArrowError::ParseError(format!("Cannot convert header value \"{v}\": {e}"))
531            })?;
532            req.metadata_mut().insert(k, v);
533        }
534        if let Some(token) = &self.token {
535            let val = format!("Bearer {token}").parse().map_err(|e| {
536                ArrowError::ParseError(format!("Cannot convert token to header value: {e}"))
537            })?;
538            req.metadata_mut().insert("authorization", val);
539        }
540        Ok(req)
541    }
542}
543
544impl<T: Clone> Clone for FlightSqlServiceClient<T> {
545    fn clone(&self) -> Self {
546        Self {
547            headers: self.headers.clone(),
548            token: self.token.clone(),
549            flight_client: self.flight_client.clone(),
550        }
551    }
552}
553
/// A prepared statement, as returned by [`FlightSqlServiceClient::prepare`].
///
/// It owns its own copy of the client, the server-issued handle, and the dataset and
/// parameter schemas decoded from the server's response. Parameters set via
/// `set_parameters` are sent to the server before each execution.
#[derive(Debug, Clone)]
pub struct PreparedStatement<T> {
    /// Client used to issue requests for this statement.
    flight_sql_client: FlightSqlServiceClient<T>,
    /// Parameters to bind before execution; `None` when none have been set.
    parameter_binding: Option<RecordBatch>,
    /// Opaque server handle; may be replaced by the server when parameters are bound.
    handle: Bytes,
    /// Schema of the result set (empty if the server did not provide one).
    dataset_schema: Schema,
    /// Schema of the expected parameters (empty if the server did not provide one).
    parameter_schema: Schema,
}

564impl<T> PreparedStatement<T>
565where
566    T: tonic::client::GrpcService<tonic::body::Body>,
567    T::Error: Into<StdError>,
568    T::ResponseBody: Body<Data = Bytes> + Send + 'static,
569    <T::ResponseBody as Body>::Error: Into<StdError> + Send,
570{
571    pub(crate) fn new(
572        flight_client: FlightSqlServiceClient<T>,
573        handle: impl Into<Bytes>,
574        dataset_schema: Schema,
575        parameter_schema: Schema,
576    ) -> Self {
577        PreparedStatement {
578            flight_sql_client: flight_client,
579            parameter_binding: None,
580            handle: handle.into(),
581            dataset_schema,
582            parameter_schema,
583        }
584    }
585
586    /// Executes the prepared statement query on the server.
587    pub async fn execute(&mut self) -> Result<FlightInfo, ArrowError> {
588        self.write_bind_params().await?;
589
590        let cmd = CommandPreparedStatementQuery {
591            prepared_statement_handle: self.handle.clone(),
592        };
593
594        let result = self
595            .flight_sql_client
596            .get_flight_info_for_command(cmd)
597            .await?;
598        Ok(result)
599    }
600
601    /// Executes the prepared statement update query on the server.
602    pub async fn execute_update(&mut self) -> Result<i64, ArrowError> {
603        self.write_bind_params().await?;
604
605        let cmd = CommandPreparedStatementUpdate {
606            prepared_statement_handle: self.handle.clone(),
607        };
608        let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
609        let mut result = self
610            .flight_sql_client
611            .do_put(stream::iter(vec![FlightData {
612                flight_descriptor: Some(descriptor),
613                ..Default::default()
614            }]))
615            .await?;
616        let result = result
617            .message()
618            .await
619            .map_err(status_to_arrow_error)?
620            .unwrap();
621        let result: DoPutUpdateResult =
622            Message::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
623        Ok(result.record_count)
624    }
625
626    /// Retrieve the parameter schema from the query.
627    pub fn parameter_schema(&self) -> Result<&Schema, ArrowError> {
628        Ok(&self.parameter_schema)
629    }
630
631    /// Retrieve the ResultSet schema from the query.
632    pub fn dataset_schema(&self) -> Result<&Schema, ArrowError> {
633        Ok(&self.dataset_schema)
634    }
635
636    /// Set a RecordBatch that contains the parameters that will be bind.
637    pub fn set_parameters(&mut self, parameter_binding: RecordBatch) -> Result<(), ArrowError> {
638        self.parameter_binding = Some(parameter_binding);
639        Ok(())
640    }
641
642    /// Submit parameters to the server, if any have been set on this prepared statement instance
643    /// Updates our stored prepared statement handle with the handle given by the server response.
644    async fn write_bind_params(&mut self) -> Result<(), ArrowError> {
645        if let Some(ref params_batch) = self.parameter_binding {
646            let cmd = CommandPreparedStatementQuery {
647                prepared_statement_handle: self.handle.clone(),
648            };
649
650            let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
651            let flight_stream_builder = FlightDataEncoderBuilder::new()
652                .with_flight_descriptor(Some(descriptor))
653                .with_schema(params_batch.schema());
654            let flight_data = flight_stream_builder
655                .build(futures::stream::iter(
656                    self.parameter_binding.clone().map(Ok),
657                ))
658                .try_collect::<Vec<_>>()
659                .await
660                .map_err(flight_error_to_arrow_error)?;
661
662            // Attempt to update the stored handle with any updated handle in the DoPut result.
663            // Older servers do not respond with a result for DoPut, so skip this step when
664            // the stream closes with no response.
665            if let Some(result) = self
666                .flight_sql_client
667                .do_put(stream::iter(flight_data))
668                .await?
669                .message()
670                .await
671                .map_err(status_to_arrow_error)?
672            {
673                if let Some(handle) = self.unpack_prepared_statement_handle(&result)? {
674                    self.handle = handle;
675                }
676            }
677        }
678        Ok(())
679    }
680
681    /// Decodes the app_metadata stored in a [`PutResult`] as a
682    /// [`DoPutPreparedStatementResult`] and then returns
683    /// the inner prepared statement handle as [`Bytes`]
684    fn unpack_prepared_statement_handle(
685        &self,
686        put_result: &PutResult,
687    ) -> Result<Option<Bytes>, ArrowError> {
688        let result: DoPutPreparedStatementResult =
689            Message::decode(&*put_result.app_metadata).map_err(decode_error_to_arrow_error)?;
690        Ok(result.prepared_statement_handle)
691    }
692
693    /// Close the prepared statement, so that this PreparedStatement can not used
694    /// anymore and server can free up any resources.
695    pub async fn close(mut self) -> Result<(), ArrowError> {
696        let cmd = ActionClosePreparedStatementRequest {
697            prepared_statement_handle: self.handle.clone(),
698        };
699        let action = Action {
700            r#type: CLOSE_PREPARED_STATEMENT.to_string(),
701            body: cmd.as_any().encode_to_vec().into(),
702        };
703        let _ = self.flight_sql_client.do_action(action).await?;
704        Ok(())
705    }
706}
707
708fn decode_error_to_arrow_error(err: prost::DecodeError) -> ArrowError {
709    ArrowError::IpcError(err.to_string())
710}
711
712fn status_to_arrow_error(status: tonic::Status) -> ArrowError {
713    ArrowError::IpcError(format!("{status:?}"))
714}
715
716fn flight_error_to_arrow_error(err: FlightError) -> ArrowError {
717    match err {
718        FlightError::Arrow(e) => e,
719        e => ArrowError::ExternalError(Box::new(e)),
720    }
721}
722
723/// A polymorphic structure to natively represent different types of data contained in `FlightData`
724pub enum ArrowFlightData {
725    /// A record batch
726    RecordBatch(RecordBatch),
727    /// A schema
728    Schema(Schema),
729}
730
731/// Extract `Schema` or `RecordBatch`es from the `FlightData` wire representation
732pub fn arrow_data_from_flight_data(
733    flight_data: FlightData,
734    arrow_schema_ref: &SchemaRef,
735) -> Result<ArrowFlightData, ArrowError> {
736    let ipc_message = root_as_message(&flight_data.data_header[..])
737        .map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}")))?;
738
739    match ipc_message.header_type() {
740        MessageHeader::RecordBatch => {
741            let ipc_record_batch = ipc_message.header_as_record_batch().ok_or_else(|| {
742                ArrowError::ComputeError(
743                    "Unable to convert flight data header to a record batch".to_string(),
744                )
745            })?;
746
747            let dictionaries_by_field = HashMap::new();
748            let record_batch = read_record_batch(
749                &Buffer::from(flight_data.data_body),
750                ipc_record_batch,
751                arrow_schema_ref.clone(),
752                &dictionaries_by_field,
753                None,
754                &ipc_message.version(),
755            )?;
756            Ok(ArrowFlightData::RecordBatch(record_batch))
757        }
758        MessageHeader::Schema => {
759            let ipc_schema = ipc_message.header_as_schema().ok_or_else(|| {
760                ArrowError::ComputeError(
761                    "Unable to convert flight data header to a schema".to_string(),
762                )
763            })?;
764
765            let arrow_schema = fb_to_schema(ipc_schema);
766            Ok(ArrowFlightData::Schema(arrow_schema))
767        }
768        MessageHeader::DictionaryBatch => {
769            let _ = ipc_message.header_as_dictionary_batch().ok_or_else(|| {
770                ArrowError::ComputeError(
771                    "Unable to convert flight data header to a dictionary batch".to_string(),
772                )
773            })?;
774            Err(ArrowError::NotYetImplemented(
775                "no idea on how to convert an ipc dictionary batch to an arrow type".to_string(),
776            ))
777        }
778        MessageHeader::Tensor => {
779            let _ = ipc_message.header_as_tensor().ok_or_else(|| {
780                ArrowError::ComputeError(
781                    "Unable to convert flight data header to a tensor".to_string(),
782                )
783            })?;
784            Err(ArrowError::NotYetImplemented(
785                "no idea on how to convert an ipc tensor to an arrow type".to_string(),
786            ))
787        }
788        MessageHeader::SparseTensor => {
789            let _ = ipc_message.header_as_sparse_tensor().ok_or_else(|| {
790                ArrowError::ComputeError(
791                    "Unable to convert flight data header to a sparse tensor".to_string(),
792                )
793            })?;
794            Err(ArrowError::NotYetImplemented(
795                "no idea on how to convert an ipc sparse tensor to an arrow type".to_string(),
796            ))
797        }
798        _ => Err(ArrowError::ComputeError(format!(
799            "Unable to convert message with header_type: '{:?}' to arrow data",
800            ipc_message.header_type()
801        ))),
802    }
803}