1use arrow_schema::ArrowError;
42use bytes::Bytes;
43use paste::paste;
44use prost::Message;
45
46#[allow(clippy::all)]
47mod r#gen {
48    #![allow(missing_docs)]
50    include!("arrow.flight.protocol.sql.rs");
51}
52
53pub use r#gen::ActionBeginSavepointRequest;
54pub use r#gen::ActionBeginSavepointResult;
55pub use r#gen::ActionBeginTransactionRequest;
56pub use r#gen::ActionBeginTransactionResult;
57pub use r#gen::ActionCancelQueryRequest;
58pub use r#gen::ActionCancelQueryResult;
59pub use r#gen::ActionClosePreparedStatementRequest;
60pub use r#gen::ActionCreatePreparedStatementRequest;
61pub use r#gen::ActionCreatePreparedStatementResult;
62pub use r#gen::ActionCreatePreparedSubstraitPlanRequest;
63pub use r#gen::ActionEndSavepointRequest;
64pub use r#gen::ActionEndTransactionRequest;
65pub use r#gen::CommandGetCatalogs;
66pub use r#gen::CommandGetCrossReference;
67pub use r#gen::CommandGetDbSchemas;
68pub use r#gen::CommandGetExportedKeys;
69pub use r#gen::CommandGetImportedKeys;
70pub use r#gen::CommandGetPrimaryKeys;
71pub use r#gen::CommandGetSqlInfo;
72pub use r#gen::CommandGetTableTypes;
73pub use r#gen::CommandGetTables;
74pub use r#gen::CommandGetXdbcTypeInfo;
75pub use r#gen::CommandPreparedStatementQuery;
76pub use r#gen::CommandPreparedStatementUpdate;
77pub use r#gen::CommandStatementIngest;
78pub use r#gen::CommandStatementQuery;
79pub use r#gen::CommandStatementSubstraitPlan;
80pub use r#gen::CommandStatementUpdate;
81pub use r#gen::DoPutPreparedStatementResult;
82pub use r#gen::DoPutUpdateResult;
83pub use r#gen::Nullable;
84pub use r#gen::Searchable;
85pub use r#gen::SqlInfo;
86pub use r#gen::SqlNullOrdering;
87pub use r#gen::SqlOuterJoinsSupportLevel;
88pub use r#gen::SqlSupportedCaseSensitivity;
89pub use r#gen::SqlSupportedElementActions;
90pub use r#gen::SqlSupportedGroupBy;
91pub use r#gen::SqlSupportedPositionedCommands;
92pub use r#gen::SqlSupportedResultSetConcurrency;
93pub use r#gen::SqlSupportedResultSetType;
94pub use r#gen::SqlSupportedSubqueries;
95pub use r#gen::SqlSupportedTransaction;
96pub use r#gen::SqlSupportedTransactions;
97pub use r#gen::SqlSupportedUnions;
98pub use r#gen::SqlSupportsConvert;
99pub use r#gen::SqlTransactionIsolationLevel;
100pub use r#gen::SubstraitPlan;
101pub use r#gen::SupportedSqlGrammar;
102pub use r#gen::TicketStatementQuery;
103pub use r#gen::UpdateDeleteRules;
104pub use r#gen::XdbcDataType;
105pub use r#gen::XdbcDatetimeSubcode;
106pub use r#gen::action_end_transaction_request::EndTransaction;
107pub use r#gen::command_statement_ingest::TableDefinitionOptions;
108pub use r#gen::command_statement_ingest::table_definition_options::{
109    TableExistsOption, TableNotExistOption,
110};
111
// Public submodules of the Flight SQL API; each is defined in its own source file.
pub mod client;
pub mod metadata;
pub mod server;
115
116pub use crate::streams::FallibleRequestStream;
117
118pub trait ProstMessageExt: prost::Message + Default {
120    fn type_url() -> &'static str;
122
123    fn as_any(&self) -> Any;
125}
126
/// Passes a single item through unchanged.
///
/// Matching the input as an `item` fragment forces the tokens to be parsed as a
/// complete item. This module uses it to emit the `Command` enum from inside
/// `prost_message_ext!`.
macro_rules! as_item {
    ($item:item) => {
        $item
    };
}
136
/// Generates `Any` packing/unpacking support for a list of Flight SQL message types.
///
/// Each entry in the list must be followed by a comma. For every `$name` this expands to:
/// - a private constant `<NAME>_TYPE_URL` (upper snake case) equal to
///   `"type.googleapis.com/arrow.flight.protocol.sql.<Name>"`;
/// - a `Command::<Name>` variant holding the decoded message;
/// - an `impl ProstMessageExt for <Name>` that reports that type URL and packs the
///   message into an `Any`.
///
/// It also defines `Command::into_any`, `Command::type_url` and
/// `impl TryFrom<Any> for Command`.
macro_rules! prost_message_ext {
    ($($name:tt,)*) => {
        paste! {
            $(
            // References in `const` items are implicitly `'static`.
            const [<$name:snake:upper _TYPE_URL>]: &str = concat!("type.googleapis.com/arrow.flight.protocol.sql.", stringify!($name));
            )*

            as_item! {
                /// A Flight SQL protocol message decoded from an `Any`, selected by type URL.
                ///
                /// Messages whose type URL is not one of the generated variants are kept,
                /// still encoded, in `Command::Unknown`.
                #[derive(Clone, Debug, PartialEq)]
                pub enum Command {
                    $(
                        #[doc = concat!(stringify!($name), " variant")]
                        $name($name),
                    )*

                    /// A message with an unrecognised type URL, preserved as received.
                    Unknown(Any),
                }
            }

            impl Command {
                /// Converts the command back into an `Any`.
                ///
                /// Known variants are re-encoded via `ProstMessageExt::as_any`;
                /// `Unknown` returns its stored `Any` unchanged.
                pub fn into_any(self) -> Any {
                    match self {
                        $(
                        Self::$name(cmd) => cmd.as_any(),
                        )*
                        Self::Unknown(any) => any,
                    }
                }

                /// Returns the type URL identifying the wrapped message.
                ///
                /// For `Unknown` this is the URL carried by the stored `Any`.
                pub fn type_url(&self) -> &str {
                    match self {
                        $(
                        Self::$name(_) => [<$name:snake:upper _TYPE_URL>],
                        )*
                        Self::Unknown(any) => any.type_url.as_str(),
                    }
                }
            }

            impl TryFrom<Any> for Command {
                type Error = ArrowError;

                /// Decodes `any` into the variant matching its type URL.
                ///
                /// Unrecognised type URLs yield `Command::Unknown(any)` rather than an error.
                ///
                /// # Errors
                ///
                /// Returns `ArrowError::ParseError` if the type URL is recognised but the
                /// payload cannot be decoded as that message.
                fn try_from(any: Any) -> Result<Self, Self::Error> {
                    match any.type_url.as_str() {
                        $(
                        [<$name:snake:upper _TYPE_URL>]
                            => {
                                let m: $name = Message::decode(&*any.value).map_err(|err| {
                                    ArrowError::ParseError(format!("Unable to decode Any value: {err}"))
                                })?;
                                Ok(Self::$name(m))
                            }
                        )*
                        _ => Ok(Self::Unknown(any)),
                    }
                }
            }

            $(
                impl ProstMessageExt for $name {
                    fn type_url() -> &'static str {
                        [<$name:snake:upper _TYPE_URL>]
                    }

                    fn as_any(&self) -> Any {
                        Any {
                            type_url: <$name>::type_url().to_string(),
                            value: self.encode_to_vec().into(),
                        }
                    }
                }
            )*
        }
    };
}
236
// Generate the type-URL constants, the `Command` enum variants and the
// `ProstMessageExt` impls for every Flight SQL message listed below.
// The list order is the declaration order of the `Command` variants, and every
// entry must end with a comma because the macro pattern is `$($name:tt,)*`.
prost_message_ext!(
    ActionBeginSavepointRequest,
    ActionBeginSavepointResult,
    ActionBeginTransactionRequest,
    ActionBeginTransactionResult,
    ActionCancelQueryRequest,
    ActionCancelQueryResult,
    ActionClosePreparedStatementRequest,
    ActionCreatePreparedStatementRequest,
    ActionCreatePreparedStatementResult,
    ActionCreatePreparedSubstraitPlanRequest,
    ActionEndSavepointRequest,
    ActionEndTransactionRequest,
    CommandGetCatalogs,
    CommandGetCrossReference,
    CommandGetDbSchemas,
    CommandGetExportedKeys,
    CommandGetImportedKeys,
    CommandGetPrimaryKeys,
    CommandGetSqlInfo,
    CommandGetTableTypes,
    CommandGetTables,
    CommandGetXdbcTypeInfo,
    CommandPreparedStatementQuery,
    CommandPreparedStatementUpdate,
    CommandStatementIngest,
    CommandStatementQuery,
    CommandStatementSubstraitPlan,
    CommandStatementUpdate,
    DoPutPreparedStatementResult,
    DoPutUpdateResult,
    TicketStatementQuery,
);
271
/// An encoded protobuf message paired with the type URL that identifies it.
///
/// The field numbering (tag 1 `type_url` string, tag 2 `value` bytes) mirrors
/// `google.protobuf.Any`; the payload is held as [`Bytes`].
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Any {
    /// Type URL of the packed message, e.g.
    /// `type.googleapis.com/arrow.flight.protocol.sql.CommandStatementQuery`.
    #[prost(string, tag = "1")]
    pub type_url: String,
    /// Protobuf-encoded bytes of the packed message.
    #[prost(bytes = "bytes", tag = "2")]
    pub value: Bytes,
}
303
304impl Any {
305    pub fn is<M: ProstMessageExt>(&self) -> bool {
307        M::type_url() == self.type_url
308    }
309
310    pub fn unpack<M: ProstMessageExt>(&self) -> Result<Option<M>, ArrowError> {
312        if !self.is::<M>() {
313            return Ok(None);
314        }
315        let m = Message::decode(&*self.value)
316            .map_err(|err| ArrowError::ParseError(format!("Unable to decode Any value: {err}")))?;
317        Ok(Some(m))
318    }
319
320    pub fn pack<M: ProstMessageExt>(message: &M) -> Result<Any, ArrowError> {
322        Ok(message.as_any())
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the query message shared by several tests.
    fn sample_query() -> CommandStatementQuery {
        CommandStatementQuery {
            query: "select 1".to_string(),
            transaction_id: None,
        }
    }

    #[test]
    fn test_type_url() {
        assert_eq!(
            TicketStatementQuery::type_url(),
            "type.googleapis.com/arrow.flight.protocol.sql.TicketStatementQuery"
        );
        assert_eq!(
            CommandStatementQuery::type_url(),
            "type.googleapis.com/arrow.flight.protocol.sql.CommandStatementQuery"
        );
    }

    #[test]
    fn test_prost_any_pack_unpack() {
        let query = sample_query();
        let any = Any::pack(&query).unwrap();
        assert!(any.is::<CommandStatementQuery>());
        let unpack_query: CommandStatementQuery = any.unpack().unwrap().unwrap();
        assert_eq!(query, unpack_query);
    }

    #[test]
    fn test_unpack_other_type_returns_none() {
        let any = Any::pack(&sample_query()).unwrap();
        assert!(!any.is::<TicketStatementQuery>());
        assert!(any.unpack::<TicketStatementQuery>().unwrap().is_none());
    }

    #[test]
    fn test_command() {
        let any = Any::pack(&sample_query()).unwrap();
        let cmd: Command = any.try_into().unwrap();

        assert!(matches!(cmd, Command::CommandStatementQuery(_)));
        assert_eq!(cmd.type_url(), COMMAND_STATEMENT_QUERY_TYPE_URL);

        let any = Any {
            type_url: "fake_url".to_string(),
            value: Default::default(),
        };

        let cmd: Command = any.try_into().unwrap();
        assert!(matches!(cmd, Command::Unknown(_)));
        assert_eq!(cmd.type_url(), "fake_url");
    }

    #[test]
    fn test_command_into_any_round_trip() {
        // A known message is re-encoded to the same envelope it was decoded from.
        let any = Any::pack(&sample_query()).unwrap();
        let cmd = Command::try_from(any.clone()).unwrap();
        assert_eq!(cmd.into_any(), any);

        // An unknown message is handed back untouched, payload included.
        let unknown = Any {
            type_url: "fake_url".to_string(),
            value: Bytes::from_static(b"opaque"),
        };
        let cmd = Command::try_from(unknown.clone()).unwrap();
        assert_eq!(cmd.into_any(), unknown);
    }

    #[test]
    fn test_decode_error() {
        // A lone 0xff byte is a truncated varint, so decoding must fail.
        let any = Any {
            type_url: COMMAND_STATEMENT_QUERY_TYPE_URL.to_string(),
            value: Bytes::from_static(&[0xff]),
        };

        match any.unpack::<CommandStatementQuery>().unwrap_err() {
            ArrowError::ParseError(msg) => {
                assert!(msg.starts_with("Unable to decode Any value"), "{msg}")
            }
            other => panic!("unexpected error: {other}"),
        }

        match Command::try_from(any).unwrap_err() {
            ArrowError::ParseError(msg) => {
                assert!(msg.starts_with("Unable to decode Any value"), "{msg}")
            }
            other => panic!("unexpected error: {other}"),
        }
    }
}