1use arrow_schema::ArrowError;
42use bytes::Bytes;
43use paste::paste;
44use prost::Message;
45
// Private module holding the message and enum types pulled in verbatim via
// `include!("arrow.flight.protocol.sql.rs")`; the types callers need are
// re-exported from this module with the `pub use gen::…` lines that follow.
// NOTE(review): the included file is presumably generated from the Flight SQL
// protobuf definitions, which is why clippy, rustdoc markdown and missing-docs
// lints are all relaxed for it rather than fixed at the source.
#[allow(clippy::all)]
mod gen {
    // Inner attributes: they must precede the included items and apply to the
    // whole module body.
    #![allow(rustdoc::unportable_markdown)]
    #![allow(missing_docs)]
    include!("arrow.flight.protocol.sql.rs");
}
53
54pub use gen::action_end_transaction_request::EndTransaction;
55pub use gen::command_statement_ingest::table_definition_options::{
56 TableExistsOption, TableNotExistOption,
57};
58pub use gen::command_statement_ingest::TableDefinitionOptions;
59pub use gen::ActionBeginSavepointRequest;
60pub use gen::ActionBeginSavepointResult;
61pub use gen::ActionBeginTransactionRequest;
62pub use gen::ActionBeginTransactionResult;
63pub use gen::ActionCancelQueryRequest;
64pub use gen::ActionCancelQueryResult;
65pub use gen::ActionClosePreparedStatementRequest;
66pub use gen::ActionCreatePreparedStatementRequest;
67pub use gen::ActionCreatePreparedStatementResult;
68pub use gen::ActionCreatePreparedSubstraitPlanRequest;
69pub use gen::ActionEndSavepointRequest;
70pub use gen::ActionEndTransactionRequest;
71pub use gen::CommandGetCatalogs;
72pub use gen::CommandGetCrossReference;
73pub use gen::CommandGetDbSchemas;
74pub use gen::CommandGetExportedKeys;
75pub use gen::CommandGetImportedKeys;
76pub use gen::CommandGetPrimaryKeys;
77pub use gen::CommandGetSqlInfo;
78pub use gen::CommandGetTableTypes;
79pub use gen::CommandGetTables;
80pub use gen::CommandGetXdbcTypeInfo;
81pub use gen::CommandPreparedStatementQuery;
82pub use gen::CommandPreparedStatementUpdate;
83pub use gen::CommandStatementIngest;
84pub use gen::CommandStatementQuery;
85pub use gen::CommandStatementSubstraitPlan;
86pub use gen::CommandStatementUpdate;
87pub use gen::DoPutPreparedStatementResult;
88pub use gen::DoPutUpdateResult;
89pub use gen::Nullable;
90pub use gen::Searchable;
91pub use gen::SqlInfo;
92pub use gen::SqlNullOrdering;
93pub use gen::SqlOuterJoinsSupportLevel;
94pub use gen::SqlSupportedCaseSensitivity;
95pub use gen::SqlSupportedElementActions;
96pub use gen::SqlSupportedGroupBy;
97pub use gen::SqlSupportedPositionedCommands;
98pub use gen::SqlSupportedResultSetConcurrency;
99pub use gen::SqlSupportedResultSetType;
100pub use gen::SqlSupportedSubqueries;
101pub use gen::SqlSupportedTransaction;
102pub use gen::SqlSupportedTransactions;
103pub use gen::SqlSupportedUnions;
104pub use gen::SqlSupportsConvert;
105pub use gen::SqlTransactionIsolationLevel;
106pub use gen::SubstraitPlan;
107pub use gen::SupportedSqlGrammar;
108pub use gen::TicketStatementQuery;
109pub use gen::UpdateDeleteRules;
110pub use gen::XdbcDataType;
111pub use gen::XdbcDatetimeSubcode;
112
// Public submodules of the Flight SQL support; each is implemented in its own
// source file next to this one (`client`, `metadata`, `server`).
pub mod client;
pub mod metadata;
pub mod server;
116
117pub trait ProstMessageExt: prost::Message + Default {
119 fn type_url() -> &'static str;
121
122 fn as_any(&self) -> Any;
124}
125
/// Accepts exactly one item and re-emits it unchanged.
///
/// `prost_message_ext!` routes the generated `Command` enum through this macro so
/// the repeated tokens it builds are parsed as a single complete item.
macro_rules! as_item {
    { $item:item } => {
        $item
    };
}
135
/// Generates the [`Any`] plumbing for a list of Flight SQL prost message types.
///
/// For every `$name` in the list this expands to:
/// * a private `<NAME>_TYPE_URL` constant equal to
///   `"type.googleapis.com/arrow.flight.protocol.sql."` followed by the type name;
/// * a `Command::$name($name)` variant on the public `Command` enum;
/// * an implementation of [`ProstMessageExt`] for `$name`.
///
/// It additionally emits `Command::into_any`, `Command::type_url` and
/// `TryFrom<Any> for Command`. Every name in the invocation must be followed by a
/// comma, because the matcher is `$($name:tt,)*`.
macro_rules! prost_message_ext {
    ($($name:tt,)*) => {
        paste! {
            $(
                // `const` items are implicitly `'static`, so the lifetime is not spelled out.
                const [<$name:snake:upper _TYPE_URL>]: &str = concat!(
                    "type.googleapis.com/arrow.flight.protocol.sql.",
                    stringify!($name)
                );
            )*

            as_item! {
                /// A Flight SQL message decoded from an [`Any`] envelope.
                ///
                /// Build one with `Command::try_from(any)` and turn it back into an
                /// envelope with [`Command::into_any`]. Envelopes whose type URL is not
                /// recognised are preserved as [`Command::Unknown`].
                #[derive(Clone, Debug, PartialEq)]
                pub enum Command {
                    $(
                        // Leading space so rustdoc shows "Foo variant", not "Foovariant".
                        #[doc = concat!(stringify!($name), " variant")]
                        $name($name),
                    )*

                    /// A message whose type URL does not match any known variant;
                    /// the original [`Any`] is kept unchanged.
                    Unknown(Any),
                }
            }

            impl Command {
                /// Converts the command back into an [`Any`] envelope.
                ///
                /// Known variants are re-encoded with [`ProstMessageExt::as_any`];
                /// [`Command::Unknown`] hands back the envelope it wraps.
                pub fn into_any(self) -> Any {
                    match self {
                        $(
                            Self::$name(cmd) => cmd.as_any(),
                        )*
                        Self::Unknown(any) => any,
                    }
                }

                /// Returns the type URL that identifies this command.
                ///
                /// For [`Command::Unknown`] this is the URL carried by the wrapped [`Any`].
                pub fn type_url(&self) -> &str {
                    match self {
                        $(
                            Self::$name(_) => [<$name:snake:upper _TYPE_URL>],
                        )*
                        Self::Unknown(any) => any.type_url.as_str(),
                    }
                }
            }

            impl TryFrom<Any> for Command {
                type Error = ArrowError;

                /// Decodes `any` into the variant selected by its type URL.
                ///
                /// An unrecognised type URL yields `Ok(Command::Unknown(any))`. A
                /// recognised URL whose payload fails to decode yields
                /// `ArrowError::ParseError`.
                fn try_from(any: Any) -> Result<Self, Self::Error> {
                    match any.type_url.as_str() {
                        $(
                            [<$name:snake:upper _TYPE_URL>]
                            => {
                                let m: $name = Message::decode(&*any.value).map_err(|err| {
                                    ArrowError::ParseError(format!("Unable to decode Any value: {err}"))
                                })?;
                                Ok(Self::$name(m))
                            }
                        )*
                        _ => Ok(Self::Unknown(any)),
                    }
                }
            }

            $(
                impl ProstMessageExt for $name {
                    fn type_url() -> &'static str {
                        [<$name:snake:upper _TYPE_URL>]
                    }

                    fn as_any(&self) -> Any {
                        Any {
                            type_url: <$name>::type_url().to_string(),
                            value: self.encode_to_vec().into(),
                        }
                    }
                }
            )*
        }
    };
}
235
236prost_message_ext!(
238 ActionBeginSavepointRequest,
239 ActionBeginSavepointResult,
240 ActionBeginTransactionRequest,
241 ActionBeginTransactionResult,
242 ActionCancelQueryRequest,
243 ActionCancelQueryResult,
244 ActionClosePreparedStatementRequest,
245 ActionCreatePreparedStatementRequest,
246 ActionCreatePreparedStatementResult,
247 ActionCreatePreparedSubstraitPlanRequest,
248 ActionEndSavepointRequest,
249 ActionEndTransactionRequest,
250 CommandGetCatalogs,
251 CommandGetCrossReference,
252 CommandGetDbSchemas,
253 CommandGetExportedKeys,
254 CommandGetImportedKeys,
255 CommandGetPrimaryKeys,
256 CommandGetSqlInfo,
257 CommandGetTableTypes,
258 CommandGetTables,
259 CommandGetXdbcTypeInfo,
260 CommandPreparedStatementQuery,
261 CommandPreparedStatementUpdate,
262 CommandStatementIngest,
263 CommandStatementQuery,
264 CommandStatementSubstraitPlan,
265 CommandStatementUpdate,
266 DoPutPreparedStatementResult,
267 DoPutUpdateResult,
268 TicketStatementQuery,
269);
270
/// A protobuf `Any`-style envelope: a type URL naming the message type, plus that
/// message's encoded bytes.
///
/// The field numbers match `google.protobuf.Any` (`type_url` = 1, `value` = 2).
/// Prefer [`Any::pack`] and [`Any::unpack`] (or `Command::try_from`) over filling
/// in the fields by hand.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Any {
    /// Identifies the type of the encoded message. For the messages in this module it
    /// has the form `type.googleapis.com/arrow.flight.protocol.sql.<MessageName>`.
    #[prost(string, tag = "1")]
    pub type_url: String,
    /// The encoded message bytes.
    #[prost(bytes = "bytes", tag = "2")]
    pub value: Bytes,
}
302
303impl Any {
304 pub fn is<M: ProstMessageExt>(&self) -> bool {
306 M::type_url() == self.type_url
307 }
308
309 pub fn unpack<M: ProstMessageExt>(&self) -> Result<Option<M>, ArrowError> {
311 if !self.is::<M>() {
312 return Ok(None);
313 }
314 let m = Message::decode(&*self.value)
315 .map_err(|err| ArrowError::ParseError(format!("Unable to decode Any value: {err}")))?;
316 Ok(Some(m))
317 }
318
319 pub fn pack<M: ProstMessageExt>(message: &M) -> Result<Any, ArrowError> {
321 Ok(message.as_any())
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    /// The `select 1` query shared by several tests.
    fn select_one() -> CommandStatementQuery {
        CommandStatementQuery {
            query: "select 1".to_string(),
            transaction_id: None,
        }
    }

    #[test]
    fn test_type_url() {
        assert_eq!(
            TicketStatementQuery::type_url(),
            "type.googleapis.com/arrow.flight.protocol.sql.TicketStatementQuery"
        );
        assert_eq!(
            CommandStatementQuery::type_url(),
            "type.googleapis.com/arrow.flight.protocol.sql.CommandStatementQuery"
        );
    }

    #[test]
    fn test_prost_any_pack_unpack() {
        let query = select_one();
        let any = Any::pack(&query).unwrap();
        assert!(any.is::<CommandStatementQuery>());
        let unpack_query: CommandStatementQuery = any.unpack().unwrap().unwrap();
        assert_eq!(query, unpack_query);
    }

    #[test]
    fn test_unpack_other_type_is_none() {
        // A type-URL mismatch is not an error: `unpack` reports `Ok(None)`.
        let any = Any::pack(&select_one()).unwrap();
        assert!(!any.is::<CommandGetCatalogs>());
        assert!(any.unpack::<CommandGetCatalogs>().unwrap().is_none());
    }

    #[test]
    fn test_decode_error() {
        // 0xff starts a varint that never terminates, so decoding must fail even
        // though the type URL is recognised.
        let any = Any {
            type_url: COMMAND_STATEMENT_QUERY_TYPE_URL.to_string(),
            value: Bytes::from_static(&[0xff]),
        };
        assert!(matches!(
            any.unpack::<CommandStatementQuery>(),
            Err(ArrowError::ParseError(_))
        ));
        assert!(matches!(
            Command::try_from(any),
            Err(ArrowError::ParseError(_))
        ));
    }

    #[test]
    fn test_command() {
        let query = select_one();
        let any = Any::pack(&query).unwrap();
        let cmd: Command = any.clone().try_into().unwrap();

        assert!(matches!(cmd, Command::CommandStatementQuery(_)));
        assert_eq!(cmd.type_url(), COMMAND_STATEMENT_QUERY_TYPE_URL);
        // Re-encoding a decoded command reproduces the original envelope.
        assert_eq!(cmd.into_any(), any);

        let any = Any {
            type_url: "fake_url".to_string(),
            value: Default::default(),
        };

        let cmd: Command = any.clone().try_into().unwrap();
        assert!(matches!(cmd, Command::Unknown(_)));
        assert_eq!(cmd.type_url(), "fake_url");
        // Unknown commands hand back the exact envelope they were built from.
        assert_eq!(cmd.into_any(), any);
    }
}
376}