1use arrow_schema::ArrowError;
42use bytes::Bytes;
43use paste::paste;
44use prost::Message;
45
// Message types included verbatim from `arrow.flight.protocol.sql.rs`.
// NOTE(review): this file is presumably prost-generated from the FlightSQL
// `.proto` definitions at build time — confirm against the build script. Lints
// and missing-docs warnings are silenced because the included code is not
// hand-written; the public types are re-exported individually below.
#[allow(clippy::all)]
mod gen {
    #![allow(missing_docs)]
    include!("arrow.flight.protocol.sql.rs");
}
52
53pub use gen::action_end_transaction_request::EndTransaction;
54pub use gen::command_statement_ingest::table_definition_options::{
55 TableExistsOption, TableNotExistOption,
56};
57pub use gen::command_statement_ingest::TableDefinitionOptions;
58pub use gen::ActionBeginSavepointRequest;
59pub use gen::ActionBeginSavepointResult;
60pub use gen::ActionBeginTransactionRequest;
61pub use gen::ActionBeginTransactionResult;
62pub use gen::ActionCancelQueryRequest;
63pub use gen::ActionCancelQueryResult;
64pub use gen::ActionClosePreparedStatementRequest;
65pub use gen::ActionCreatePreparedStatementRequest;
66pub use gen::ActionCreatePreparedStatementResult;
67pub use gen::ActionCreatePreparedSubstraitPlanRequest;
68pub use gen::ActionEndSavepointRequest;
69pub use gen::ActionEndTransactionRequest;
70pub use gen::CommandGetCatalogs;
71pub use gen::CommandGetCrossReference;
72pub use gen::CommandGetDbSchemas;
73pub use gen::CommandGetExportedKeys;
74pub use gen::CommandGetImportedKeys;
75pub use gen::CommandGetPrimaryKeys;
76pub use gen::CommandGetSqlInfo;
77pub use gen::CommandGetTableTypes;
78pub use gen::CommandGetTables;
79pub use gen::CommandGetXdbcTypeInfo;
80pub use gen::CommandPreparedStatementQuery;
81pub use gen::CommandPreparedStatementUpdate;
82pub use gen::CommandStatementIngest;
83pub use gen::CommandStatementQuery;
84pub use gen::CommandStatementSubstraitPlan;
85pub use gen::CommandStatementUpdate;
86pub use gen::DoPutPreparedStatementResult;
87pub use gen::DoPutUpdateResult;
88pub use gen::Nullable;
89pub use gen::Searchable;
90pub use gen::SqlInfo;
91pub use gen::SqlNullOrdering;
92pub use gen::SqlOuterJoinsSupportLevel;
93pub use gen::SqlSupportedCaseSensitivity;
94pub use gen::SqlSupportedElementActions;
95pub use gen::SqlSupportedGroupBy;
96pub use gen::SqlSupportedPositionedCommands;
97pub use gen::SqlSupportedResultSetConcurrency;
98pub use gen::SqlSupportedResultSetType;
99pub use gen::SqlSupportedSubqueries;
100pub use gen::SqlSupportedTransaction;
101pub use gen::SqlSupportedTransactions;
102pub use gen::SqlSupportedUnions;
103pub use gen::SqlSupportsConvert;
104pub use gen::SqlTransactionIsolationLevel;
105pub use gen::SubstraitPlan;
106pub use gen::SupportedSqlGrammar;
107pub use gen::TicketStatementQuery;
108pub use gen::UpdateDeleteRules;
109pub use gen::XdbcDataType;
110pub use gen::XdbcDatetimeSubcode;
111
112pub mod client;
113pub mod metadata;
114pub mod server;
115
116pub use crate::streams::FallibleRequestStream;
117
/// Extension trait for prost messages that can be wrapped in, and identified
/// inside, a protobuf [`Any`] by a fixed type URL.
pub trait ProstMessageExt: prost::Message + Default {
    /// Returns the fully-qualified type URL that identifies this message type
    /// inside an [`Any`].
    fn type_url() -> &'static str;

    /// Wraps `self` in an [`Any`].
    ///
    /// The implementations generated in this module set `type_url` to
    /// [`ProstMessageExt::type_url`] and `value` to the prost encoding of `self`.
    fn as_any(&self) -> Any;
}
126
// Re-emits its single argument, parsed as an item, unchanged.
//
// Wrapping generated code in this macro forces the parser to treat the
// tokens as an item; it is used below to emit the `Command` enum from inside
// another macro's expansion.
macro_rules! as_item {
    ($item:item) => {
        $item
    };
}
136
/// Generates the FlightSQL message glue for every type listed.
///
/// For each `$name` this produces:
///
/// * a private `<NAME>_TYPE_URL` constant holding
///   `type.googleapis.com/arrow.flight.protocol.sql.<Name>`,
/// * a [`ProstMessageExt`] implementation, and
/// * a variant of the [`Command`] enum, plus `Any` <-> `Command` conversions.
///
/// The argument list must end with a trailing comma, as required by the
/// `($($name:tt,)*)` pattern.
macro_rules! prost_message_ext {
    ($($name:tt,)*) => {
        paste! {
            $(
            const [<$name:snake:upper _TYPE_URL>]: &str = concat!("type.googleapis.com/arrow.flight.protocol.sql.", stringify!($name));
            )*

            as_item! {
                /// A decoded FlightSQL message, or the raw [`Any`] when its type URL
                /// does not name any known FlightSQL message.
                ///
                /// Build one from an [`Any`] with `TryFrom`/`try_into`, and turn it
                /// back into an [`Any`] with [`Command::into_any`].
                #[derive(Clone, Debug, PartialEq)]
                pub enum Command {
                    $(
                        #[doc = concat!(stringify!($name), " variant")]
                        $name($name),
                    )*

                    /// A message whose type URL is not a known FlightSQL message;
                    /// the original [`Any`] is kept unchanged.
                    Unknown(Any),
                }
            }

            impl Command {
                /// Converts this command back into a protobuf [`Any`].
                ///
                /// Known variants are re-encoded via [`ProstMessageExt::as_any`];
                /// [`Command::Unknown`] returns the wrapped [`Any`] as-is.
                pub fn into_any(self) -> Any {
                    match self {
                        $(
                            Self::$name(cmd) => cmd.as_any(),
                        )*
                        Self::Unknown(any) => any,
                    }
                }

                /// Returns the type URL identifying this command.
                ///
                /// For [`Command::Unknown`] this is the `type_url` of the wrapped [`Any`].
                pub fn type_url(&self) -> &str {
                    match self {
                        $(
                            Self::$name(_) => [<$name:snake:upper _TYPE_URL>],
                        )*
                        Self::Unknown(any) => any.type_url.as_str(),
                    }
                }
            }

            impl TryFrom<Any> for Command {
                type Error = ArrowError;

                /// Decodes `any` into the variant matching its `type_url`.
                ///
                /// An unrecognised type URL yields [`Command::Unknown`]. A recognised
                /// URL whose payload fails to decode yields [`ArrowError::ParseError`].
                fn try_from(any: Any) -> Result<Self, Self::Error> {
                    match any.type_url.as_str() {
                        $(
                            [<$name:snake:upper _TYPE_URL>] => {
                                let m: $name = Message::decode(&*any.value).map_err(|err| {
                                    ArrowError::ParseError(format!("Unable to decode Any value: {err}"))
                                })?;
                                Ok(Self::$name(m))
                            }
                        )*
                        _ => Ok(Self::Unknown(any)),
                    }
                }
            }

            $(
                impl ProstMessageExt for $name {
                    fn type_url() -> &'static str {
                        [<$name:snake:upper _TYPE_URL>]
                    }

                    fn as_any(&self) -> Any {
                        Any {
                            type_url: <$name>::type_url().to_string(),
                            value: self.encode_to_vec().into(),
                        }
                    }
                }
            )*
        }
    };
}
236
// Each message listed here gets a `<NAME>_TYPE_URL` constant, a
// `ProstMessageExt` impl and a `Command` variant (in this order). The
// trailing comma after the last entry is required by the macro's
// `($($name:tt,)*)` pattern.
prost_message_ext!(
    ActionBeginSavepointRequest,
    ActionBeginSavepointResult,
    ActionBeginTransactionRequest,
    ActionBeginTransactionResult,
    ActionCancelQueryRequest,
    ActionCancelQueryResult,
    ActionClosePreparedStatementRequest,
    ActionCreatePreparedStatementRequest,
    ActionCreatePreparedStatementResult,
    ActionCreatePreparedSubstraitPlanRequest,
    ActionEndSavepointRequest,
    ActionEndTransactionRequest,
    CommandGetCatalogs,
    CommandGetCrossReference,
    CommandGetDbSchemas,
    CommandGetExportedKeys,
    CommandGetImportedKeys,
    CommandGetPrimaryKeys,
    CommandGetSqlInfo,
    CommandGetTableTypes,
    CommandGetTables,
    CommandGetXdbcTypeInfo,
    CommandPreparedStatementQuery,
    CommandPreparedStatementUpdate,
    CommandStatementIngest,
    CommandStatementQuery,
    CommandStatementSubstraitPlan,
    CommandStatementUpdate,
    DoPutPreparedStatementResult,
    DoPutUpdateResult,
    TicketStatementQuery,
);
271
/// A protobuf `Any`: an encoded message paired with a URL identifying its type.
///
/// Field tags 1 (`type_url`) and 2 (`value`) follow the wire layout of
/// `google.protobuf.Any`.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Any {
    /// Identifies the type of the encoded message, e.g.
    /// `type.googleapis.com/arrow.flight.protocol.sql.CommandStatementQuery`.
    #[prost(string, tag = "1")]
    pub type_url: String,
    /// The protobuf-encoded bytes of the message named by `type_url`.
    #[prost(bytes = "bytes", tag = "2")]
    pub value: Bytes,
}
303
304impl Any {
305 pub fn is<M: ProstMessageExt>(&self) -> bool {
307 M::type_url() == self.type_url
308 }
309
310 pub fn unpack<M: ProstMessageExt>(&self) -> Result<Option<M>, ArrowError> {
312 if !self.is::<M>() {
313 return Ok(None);
314 }
315 let m = Message::decode(&*self.value)
316 .map_err(|err| ArrowError::ParseError(format!("Unable to decode Any value: {err}")))?;
317 Ok(Some(m))
318 }
319
320 pub fn pack<M: ProstMessageExt>(message: &M) -> Result<Any, ArrowError> {
322 Ok(message.as_any())
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `CommandStatementQuery` used throughout these tests.
    fn select_one() -> CommandStatementQuery {
        CommandStatementQuery {
            query: "select 1".to_string(),
            transaction_id: None,
        }
    }

    #[test]
    fn test_type_url() {
        assert_eq!(
            TicketStatementQuery::type_url(),
            "type.googleapis.com/arrow.flight.protocol.sql.TicketStatementQuery"
        );
        assert_eq!(
            CommandStatementQuery::type_url(),
            "type.googleapis.com/arrow.flight.protocol.sql.CommandStatementQuery"
        );
    }

    #[test]
    fn test_prost_any_pack_unpack() {
        let query = select_one();
        let any = Any::pack(&query).unwrap();
        assert!(any.is::<CommandStatementQuery>());
        let unpack_query: CommandStatementQuery = any.unpack().unwrap().unwrap();
        assert_eq!(query, unpack_query);
    }

    #[test]
    fn test_unpack_mismatched_type() {
        let any = Any::pack(&select_one()).unwrap();
        // A different message type is reported as "not this type", not as an error.
        assert!(!any.is::<TicketStatementQuery>());
        assert!(any.unpack::<TicketStatementQuery>().unwrap().is_none());
    }

    #[test]
    fn test_corrupt_payload_is_error() {
        // Correct type URL but a truncated varint as payload: decoding must fail
        // both through `unpack` and through the `Command` conversion.
        let any = Any {
            type_url: COMMAND_STATEMENT_QUERY_TYPE_URL.to_string(),
            value: vec![0xff_u8].into(),
        };
        assert!(any.unpack::<CommandStatementQuery>().is_err());
        let cmd: Result<Command, _> = any.try_into();
        assert!(cmd.is_err());
    }

    #[test]
    fn test_command() {
        let any = Any::pack(&select_one()).unwrap();
        let cmd: Command = any.clone().try_into().unwrap();

        assert!(matches!(cmd, Command::CommandStatementQuery(_)));
        assert_eq!(cmd.type_url(), COMMAND_STATEMENT_QUERY_TYPE_URL);
        // Converting back must reproduce the original `Any` exactly.
        assert_eq!(cmd.into_any(), any);

        let any = Any {
            type_url: "fake_url".to_string(),
            value: Default::default(),
        };

        let cmd: Command = any.clone().try_into().unwrap();
        assert!(matches!(cmd, Command::Unknown(_)));
        assert_eq!(cmd.type_url(), "fake_url");
        // Unknown messages are passed through untouched.
        assert_eq!(cmd.into_any(), any);
    }
}