1use arrow_schema::ArrowError;
42use bytes::Bytes;
43use paste::paste;
44use prost::Message;
45
// Flight SQL message types, pulled in verbatim from `arrow.flight.protocol.sql.rs`
// via `include!` (presumably prost-build output — TODO confirm against build script).
// Clippy lints and `missing_docs` are allowed for this module so the included code
// does not produce warnings; the items are exposed through the `pub use` lines below.
#[allow(clippy::all)]
mod gen {
    #![allow(missing_docs)]
    include!("arrow.flight.protocol.sql.rs");
}
52
53pub use gen::action_end_transaction_request::EndTransaction;
54pub use gen::command_statement_ingest::table_definition_options::{
55 TableExistsOption, TableNotExistOption,
56};
57pub use gen::command_statement_ingest::TableDefinitionOptions;
58pub use gen::ActionBeginSavepointRequest;
59pub use gen::ActionBeginSavepointResult;
60pub use gen::ActionBeginTransactionRequest;
61pub use gen::ActionBeginTransactionResult;
62pub use gen::ActionCancelQueryRequest;
63pub use gen::ActionCancelQueryResult;
64pub use gen::ActionClosePreparedStatementRequest;
65pub use gen::ActionCreatePreparedStatementRequest;
66pub use gen::ActionCreatePreparedStatementResult;
67pub use gen::ActionCreatePreparedSubstraitPlanRequest;
68pub use gen::ActionEndSavepointRequest;
69pub use gen::ActionEndTransactionRequest;
70pub use gen::CommandGetCatalogs;
71pub use gen::CommandGetCrossReference;
72pub use gen::CommandGetDbSchemas;
73pub use gen::CommandGetExportedKeys;
74pub use gen::CommandGetImportedKeys;
75pub use gen::CommandGetPrimaryKeys;
76pub use gen::CommandGetSqlInfo;
77pub use gen::CommandGetTableTypes;
78pub use gen::CommandGetTables;
79pub use gen::CommandGetXdbcTypeInfo;
80pub use gen::CommandPreparedStatementQuery;
81pub use gen::CommandPreparedStatementUpdate;
82pub use gen::CommandStatementIngest;
83pub use gen::CommandStatementQuery;
84pub use gen::CommandStatementSubstraitPlan;
85pub use gen::CommandStatementUpdate;
86pub use gen::DoPutPreparedStatementResult;
87pub use gen::DoPutUpdateResult;
88pub use gen::Nullable;
89pub use gen::Searchable;
90pub use gen::SqlInfo;
91pub use gen::SqlNullOrdering;
92pub use gen::SqlOuterJoinsSupportLevel;
93pub use gen::SqlSupportedCaseSensitivity;
94pub use gen::SqlSupportedElementActions;
95pub use gen::SqlSupportedGroupBy;
96pub use gen::SqlSupportedPositionedCommands;
97pub use gen::SqlSupportedResultSetConcurrency;
98pub use gen::SqlSupportedResultSetType;
99pub use gen::SqlSupportedSubqueries;
100pub use gen::SqlSupportedTransaction;
101pub use gen::SqlSupportedTransactions;
102pub use gen::SqlSupportedUnions;
103pub use gen::SqlSupportsConvert;
104pub use gen::SqlTransactionIsolationLevel;
105pub use gen::SubstraitPlan;
106pub use gen::SupportedSqlGrammar;
107pub use gen::TicketStatementQuery;
108pub use gen::UpdateDeleteRules;
109pub use gen::XdbcDataType;
110pub use gen::XdbcDatetimeSubcode;
111
112pub mod client;
113pub mod metadata;
114pub mod server;
115
116pub trait ProstMessageExt: prost::Message + Default {
118 fn type_url() -> &'static str;
120
121 fn as_any(&self) -> Any;
123}
124
/// Re-emits a single item unchanged.
///
/// Passing tokens through an `item` fragment forces them to be parsed as one item
/// declaration. This is needed when the item is assembled from macro repetitions,
/// as with the `Command` enum built by `prost_message_ext!`.
macro_rules! as_item {
    ($item:item) => ($item);
}
134
/// Generates the Flight SQL message plumbing for every listed message type.
///
/// For each `$name` it emits:
/// * a private `<NAME>_TYPE_URL` constant equal to
///   `"type.googleapis.com/arrow.flight.protocol.sql.<Name>"`,
/// * a `Command::<Name>` variant wrapping the message,
/// * a [`ProstMessageExt`] implementation for the message type.
///
/// It also emits the `Command` enum itself, plus `Command::into_any`,
/// `Command::type_url` and `TryFrom<Any> for Command`.
///
/// Every name in the invocation must be followed by a comma (the matcher is
/// `$($name:tt,)*`).
macro_rules! prost_message_ext {
    ($($name:tt,)*) => {
        paste! {
            $(
                const [<$name:snake:upper _TYPE_URL>]: &str = concat!("type.googleapis.com/arrow.flight.protocol.sql.", stringify!($name));
            )*

            as_item! {
                /// Helper to convert between a protobuf [`Any`] message and a specific
                /// Flight SQL message.
                ///
                /// Decode with `Command::try_from(any)`; type URLs that are not
                /// recognised produce [`Command::Unknown`]. Encode back with
                /// [`Command::into_any`].
                #[derive(Clone, Debug, PartialEq)]
                pub enum Command {
                    $(
                        #[doc = concat!(stringify!($name), " variant")]
                        $name($name),
                    )*

                    /// Any message whose type URL is not one of the known Flight SQL
                    /// messages, kept as-is.
                    Unknown(Any),
                }
            }

            impl Command {
                /// Encodes this command as a protobuf [`Any`].
                ///
                /// For [`Command::Unknown`] the wrapped [`Any`] is returned unchanged.
                pub fn into_any(self) -> Any {
                    match self {
                        $(
                            Self::$name(cmd) => cmd.as_any(),
                        )*
                        Self::Unknown(any) => any,
                    }
                }

                /// Returns the protobuf type URL identifying this command.
                ///
                /// For [`Command::Unknown`] this is the `type_url` of the wrapped [`Any`].
                pub fn type_url(&self) -> &str {
                    match self {
                        $(
                            Self::$name(_) => [<$name:snake:upper _TYPE_URL>],
                        )*
                        Self::Unknown(any) => any.type_url.as_str(),
                    }
                }
            }

            impl TryFrom<Any> for Command {
                type Error = ArrowError;

                /// Decodes `any` into the variant matching its `type_url`.
                ///
                /// Unrecognised type URLs yield [`Command::Unknown`]. A recognised URL
                /// whose payload fails to decode yields [`ArrowError::ParseError`].
                fn try_from(any: Any) -> Result<Self, Self::Error> {
                    match any.type_url.as_str() {
                        $(
                            [<$name:snake:upper _TYPE_URL>] => {
                                let m: $name = Message::decode(&*any.value).map_err(|err| {
                                    ArrowError::ParseError(format!("Unable to decode Any value: {err}"))
                                })?;
                                Ok(Self::$name(m))
                            }
                        )*
                        _ => Ok(Self::Unknown(any)),
                    }
                }
            }

            $(
                impl ProstMessageExt for $name {
                    fn type_url() -> &'static str {
                        [<$name:snake:upper _TYPE_URL>]
                    }

                    fn as_any(&self) -> Any {
                        Any {
                            type_url: <$name>::type_url().to_string(),
                            value: self.encode_to_vec().into(),
                        }
                    }
                }
            )*
        }
    };
}
234
// Each message type listed here gets:
// * a `<NAME>_TYPE_URL` constant,
// * a `Command::<Name>` variant,
// * a `ProstMessageExt` impl.
// An `Any` whose type URL matches none of these decodes as `Command::Unknown`.
// Every entry, including the last, must end with a comma, because the macro
// matches `$($name:tt,)*`.
prost_message_ext!(
    ActionBeginSavepointRequest,
    ActionBeginSavepointResult,
    ActionBeginTransactionRequest,
    ActionBeginTransactionResult,
    ActionCancelQueryRequest,
    ActionCancelQueryResult,
    ActionClosePreparedStatementRequest,
    ActionCreatePreparedStatementRequest,
    ActionCreatePreparedStatementResult,
    ActionCreatePreparedSubstraitPlanRequest,
    ActionEndSavepointRequest,
    ActionEndTransactionRequest,
    CommandGetCatalogs,
    CommandGetCrossReference,
    CommandGetDbSchemas,
    CommandGetExportedKeys,
    CommandGetImportedKeys,
    CommandGetPrimaryKeys,
    CommandGetSqlInfo,
    CommandGetTableTypes,
    CommandGetTables,
    CommandGetXdbcTypeInfo,
    CommandPreparedStatementQuery,
    CommandPreparedStatementUpdate,
    CommandStatementIngest,
    CommandStatementQuery,
    CommandStatementSubstraitPlan,
    CommandStatementUpdate,
    DoPutPreparedStatementResult,
    DoPutUpdateResult,
    TicketStatementQuery,
);
269
/// An implementation of the protobuf [`Any`] message type.
///
/// Encoded protobuf messages do not describe their own schema. `Any` therefore pairs
/// the encoded payload (`value`) with a `type_url` naming the payload's message type.
/// A receiver can inspect the URL before deciding how, or whether, to decode the
/// payload. For the Flight SQL messages in this module the URL has the form
/// `type.googleapis.com/arrow.flight.protocol.sql.<MessageName>`.
///
/// [`Any`]: https://protobuf.dev/programming-guides/proto3/#any
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Any {
    /// Identifies the type of the message encoded in `value`. [`Any::is`] compares it
    /// against [`ProstMessageExt::type_url`].
    #[prost(string, tag = "1")]
    pub type_url: String,
    /// The protobuf-encoded bytes of the message named by `type_url`.
    #[prost(bytes = "bytes", tag = "2")]
    pub value: Bytes,
}
301
302impl Any {
303 pub fn is<M: ProstMessageExt>(&self) -> bool {
305 M::type_url() == self.type_url
306 }
307
308 pub fn unpack<M: ProstMessageExt>(&self) -> Result<Option<M>, ArrowError> {
310 if !self.is::<M>() {
311 return Ok(None);
312 }
313 let m = Message::decode(&*self.value)
314 .map_err(|err| ArrowError::ParseError(format!("Unable to decode Any value: {err}")))?;
315 Ok(Some(m))
316 }
317
318 pub fn pack<M: ProstMessageExt>(message: &M) -> Result<Any, ArrowError> {
320 Ok(message.as_any())
321 }
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the query message shared by several tests.
    fn sample_query() -> CommandStatementQuery {
        CommandStatementQuery {
            query: "select 1".to_string(),
            transaction_id: None,
        }
    }

    #[test]
    fn test_type_url() {
        assert_eq!(
            TicketStatementQuery::type_url(),
            "type.googleapis.com/arrow.flight.protocol.sql.TicketStatementQuery"
        );
        assert_eq!(
            CommandStatementQuery::type_url(),
            "type.googleapis.com/arrow.flight.protocol.sql.CommandStatementQuery"
        );
    }

    #[test]
    fn test_prost_any_pack_unpack() {
        let query = sample_query();
        let any = Any::pack(&query).unwrap();
        assert!(any.is::<CommandStatementQuery>());
        let unpack_query: CommandStatementQuery = any.unpack().unwrap().unwrap();
        assert_eq!(query, unpack_query);
    }

    #[test]
    fn test_unpack_other_type() {
        let any = Any::pack(&sample_query()).unwrap();

        // A type URL mismatch is not an error: `unpack` reports "not this type".
        assert!(!any.is::<TicketStatementQuery>());
        assert!(any.unpack::<TicketStatementQuery>().unwrap().is_none());
    }

    #[test]
    fn test_decode_error() {
        // 0xff is a varint continuation byte with nothing following it, so the
        // payload cannot be decoded even though the type URL is recognised.
        let any = Any {
            type_url: COMMAND_STATEMENT_QUERY_TYPE_URL.to_string(),
            value: Bytes::from_static(&[0xff]),
        };

        assert!(matches!(
            any.unpack::<CommandStatementQuery>(),
            Err(ArrowError::ParseError(_))
        ));
        assert!(matches!(
            Command::try_from(any),
            Err(ArrowError::ParseError(_))
        ));
    }

    #[test]
    fn test_command() {
        let any = Any::pack(&sample_query()).unwrap();
        let cmd: Command = any.try_into().unwrap();

        assert!(matches!(cmd, Command::CommandStatementQuery(_)));
        assert_eq!(cmd.type_url(), COMMAND_STATEMENT_QUERY_TYPE_URL);

        let any = Any {
            type_url: "fake_url".to_string(),
            value: Default::default(),
        };

        let cmd: Command = any.try_into().unwrap();
        assert!(matches!(cmd, Command::Unknown(_)));
        assert_eq!(cmd.type_url(), "fake_url");
    }

    #[test]
    fn test_command_into_any_roundtrip() {
        // A known command re-encodes to the same `Any` it was decoded from.
        let any = Any::pack(&sample_query()).unwrap();
        let cmd = Command::try_from(any.clone()).unwrap();
        assert_eq!(cmd.into_any(), any);

        // An unknown command hands back the original `Any` untouched.
        let unknown = Any {
            type_url: "fake_url".to_string(),
            value: Bytes::from_static(b"opaque"),
        };
        let cmd = Command::try_from(unknown.clone()).unwrap();
        assert_eq!(cmd.into_any(), unknown);
    }
}