1use arrow_schema::ArrowError;
42use bytes::Bytes;
43use paste::paste;
44use prost::Message;
45
// Protobuf message and enum types for the `arrow.flight.protocol.sql`
// package, pulled in verbatim from `arrow.flight.protocol.sql.rs`.
// NOTE(review): the included file is presumably prost-build output, which is
// why clippy and `missing_docs` are silenced here rather than fixed — confirm.
#[allow(clippy::all)]
mod r#gen {
    #![allow(missing_docs)]

    include!("arrow.flight.protocol.sql.rs");
}
52
53pub use r#gen::ActionBeginSavepointRequest;
54pub use r#gen::ActionBeginSavepointResult;
55pub use r#gen::ActionBeginTransactionRequest;
56pub use r#gen::ActionBeginTransactionResult;
57pub use r#gen::ActionCancelQueryRequest;
58pub use r#gen::ActionCancelQueryResult;
59pub use r#gen::ActionClosePreparedStatementRequest;
60pub use r#gen::ActionCreatePreparedStatementRequest;
61pub use r#gen::ActionCreatePreparedStatementResult;
62pub use r#gen::ActionCreatePreparedSubstraitPlanRequest;
63pub use r#gen::ActionEndSavepointRequest;
64pub use r#gen::ActionEndTransactionRequest;
65pub use r#gen::CommandGetCatalogs;
66pub use r#gen::CommandGetCrossReference;
67pub use r#gen::CommandGetDbSchemas;
68pub use r#gen::CommandGetExportedKeys;
69pub use r#gen::CommandGetImportedKeys;
70pub use r#gen::CommandGetPrimaryKeys;
71pub use r#gen::CommandGetSqlInfo;
72pub use r#gen::CommandGetTableTypes;
73pub use r#gen::CommandGetTables;
74pub use r#gen::CommandGetXdbcTypeInfo;
75pub use r#gen::CommandPreparedStatementQuery;
76pub use r#gen::CommandPreparedStatementUpdate;
77pub use r#gen::CommandStatementIngest;
78pub use r#gen::CommandStatementQuery;
79pub use r#gen::CommandStatementSubstraitPlan;
80pub use r#gen::CommandStatementUpdate;
81pub use r#gen::DoPutPreparedStatementResult;
82pub use r#gen::DoPutUpdateResult;
83pub use r#gen::Nullable;
84pub use r#gen::Searchable;
85pub use r#gen::SqlInfo;
86pub use r#gen::SqlNullOrdering;
87pub use r#gen::SqlOuterJoinsSupportLevel;
88pub use r#gen::SqlSupportedCaseSensitivity;
89pub use r#gen::SqlSupportedElementActions;
90pub use r#gen::SqlSupportedGroupBy;
91pub use r#gen::SqlSupportedPositionedCommands;
92pub use r#gen::SqlSupportedResultSetConcurrency;
93pub use r#gen::SqlSupportedResultSetType;
94pub use r#gen::SqlSupportedSubqueries;
95pub use r#gen::SqlSupportedTransaction;
96pub use r#gen::SqlSupportedTransactions;
97pub use r#gen::SqlSupportedUnions;
98pub use r#gen::SqlSupportsConvert;
99pub use r#gen::SqlTransactionIsolationLevel;
100pub use r#gen::SubstraitPlan;
101pub use r#gen::SupportedSqlGrammar;
102pub use r#gen::TicketStatementQuery;
103pub use r#gen::UpdateDeleteRules;
104pub use r#gen::XdbcDataType;
105pub use r#gen::XdbcDatetimeSubcode;
106pub use r#gen::action_end_transaction_request::EndTransaction;
107pub use r#gen::command_statement_ingest::TableDefinitionOptions;
108pub use r#gen::command_statement_ingest::table_definition_options::{
109 TableExistsOption, TableNotExistOption,
110};
111
112pub mod client;
113pub mod metadata;
114pub mod server;
115
116pub use crate::streams::FallibleRequestStream;
117
118pub trait ProstMessageExt: prost::Message + Default {
120 fn type_url() -> &'static str;
122
123 fn as_any(&self) -> Any;
125}
126
/// Re-emits a single item unchanged.
///
/// `prost_message_ext!` wraps the generated `Command` enum in this macro.
/// NOTE(review): presumably a workaround that forces the `paste!`-produced
/// tokens to be parsed in item position — confirm before removing.
macro_rules! as_item {
    ($item:item) => {
        $item
    };
}
136
/// Generates the Flight SQL command plumbing for a list of message types.
///
/// For every `$name` passed in (each one followed by a comma), this expands to:
/// - a private `<NAME>_TYPE_URL` constant equal to
///   `"type.googleapis.com/arrow.flight.protocol.sql.<Name>"`,
/// - a `Command::<Name>` variant wrapping the message,
/// - an implementation of [`ProstMessageExt`] for `$name`.
///
/// It also emits `Command::into_any`, `Command::type_url` and
/// `TryFrom<Any> for Command`, all of which dispatch on those constants.
macro_rules! prost_message_ext {
    ($($name:tt,)*) => {
        paste! {
            $(
                // `const` items are implicitly `'static`, so `&str` suffices.
                const [<$name:snake:upper _TYPE_URL>]: &str = concat!(
                    "type.googleapis.com/arrow.flight.protocol.sql.",
                    stringify!($name)
                );
            )*

            as_item! {
                /// A Flight SQL message decoded from an [`Any`] by its type URL.
                ///
                /// Build one with `Command::try_from(any)`. Type URLs that do not
                /// belong to a known message are preserved in [`Command::Unknown`].
                #[derive(Clone, Debug, PartialEq)]
                pub enum Command {
                    $(
                        // The space keeps rustdoc from rendering e.g.
                        // "CommandStatementQueryvariant".
                        #[doc = concat!(stringify!($name), " variant")]
                        $name($name),
                    )*

                    /// An [`Any`] whose type URL matched none of the known messages.
                    Unknown(Any),
                }
            }

            impl Command {
                /// Converts this command back into an [`Any`].
                ///
                /// Known variants are re-encoded via [`ProstMessageExt::as_any`].
                /// [`Command::Unknown`] returns its wrapped [`Any`] unchanged.
                pub fn into_any(self) -> Any {
                    match self {
                        $(
                            Self::$name(cmd) => cmd.as_any(),
                        )*
                        Self::Unknown(any) => any,
                    }
                }

                /// Returns the type URL identifying this command.
                ///
                /// For [`Command::Unknown`] this is the wrapped [`Any`]'s own type URL.
                pub fn type_url(&self) -> &str {
                    match self {
                        $(
                            Self::$name(_) => [<$name:snake:upper _TYPE_URL>],
                        )*
                        Self::Unknown(any) => any.type_url.as_str(),
                    }
                }
            }

            impl TryFrom<Any> for Command {
                type Error = ArrowError;

                /// Decodes `any` into the variant whose type URL it carries.
                ///
                /// An unrecognised type URL yields `Ok(Command::Unknown(any))`. A
                /// recognised type URL whose bytes fail to decode yields
                /// [`ArrowError::ParseError`].
                fn try_from(any: Any) -> Result<Self, Self::Error> {
                    match any.type_url.as_str() {
                        $(
                            [<$name:snake:upper _TYPE_URL>] => {
                                let m: $name = Message::decode(&*any.value).map_err(|err| {
                                    ArrowError::ParseError(format!("Unable to decode Any value: {err}"))
                                })?;
                                Ok(Self::$name(m))
                            }
                        )*
                        _ => Ok(Self::Unknown(any)),
                    }
                }
            }

            $(
                impl ProstMessageExt for $name {
                    fn type_url() -> &'static str {
                        [<$name:snake:upper _TYPE_URL>]
                    }

                    fn as_any(&self) -> Any {
                        Any {
                            type_url: <$name>::type_url().to_string(),
                            value: self.encode_to_vec().into(),
                        }
                    }
                }
            )*
        }
    };
}
236
// Every Flight SQL message listed here receives a `<NAME>_TYPE_URL` constant,
// a `Command` variant and a `ProstMessageExt` impl via `prost_message_ext!`.
// Each name must be followed by a comma: the macro's matcher is
// `$($name:tt,)*`, so the trailing comma on the last entry is required.
prost_message_ext!(
    ActionBeginSavepointRequest,
    ActionBeginSavepointResult,
    ActionBeginTransactionRequest,
    ActionBeginTransactionResult,
    ActionCancelQueryRequest,
    ActionCancelQueryResult,
    ActionClosePreparedStatementRequest,
    ActionCreatePreparedStatementRequest,
    ActionCreatePreparedStatementResult,
    ActionCreatePreparedSubstraitPlanRequest,
    ActionEndSavepointRequest,
    ActionEndTransactionRequest,
    CommandGetCatalogs,
    CommandGetCrossReference,
    CommandGetDbSchemas,
    CommandGetExportedKeys,
    CommandGetImportedKeys,
    CommandGetPrimaryKeys,
    CommandGetSqlInfo,
    CommandGetTableTypes,
    CommandGetTables,
    CommandGetXdbcTypeInfo,
    CommandPreparedStatementQuery,
    CommandPreparedStatementUpdate,
    CommandStatementIngest,
    CommandStatementQuery,
    CommandStatementSubstraitPlan,
    CommandStatementUpdate,
    DoPutPreparedStatementResult,
    DoPutUpdateResult,
    TicketStatementQuery,
);
271
/// An envelope pairing protobuf-encoded message bytes with a type URL that
/// identifies the message type.
///
/// NOTE(review): the field layout (tag 1 `type_url`, tag 2 `value`) presumably
/// mirrors `google.protobuf.Any` so it is wire-compatible with it — confirm.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Any {
    /// Type URL of the encoded message, e.g.
    /// `type.googleapis.com/arrow.flight.protocol.sql.CommandStatementQuery`.
    #[prost(string, tag = "1")]
    pub type_url: String,
    /// The protobuf-encoded bytes of the message identified by `type_url`.
    #[prost(bytes = "bytes", tag = "2")]
    pub value: Bytes,
}
303
304impl Any {
305 pub fn is<M: ProstMessageExt>(&self) -> bool {
307 M::type_url() == self.type_url
308 }
309
310 pub fn unpack<M: ProstMessageExt>(&self) -> Result<Option<M>, ArrowError> {
312 if !self.is::<M>() {
313 return Ok(None);
314 }
315 let m = Message::decode(&*self.value)
316 .map_err(|err| ArrowError::ParseError(format!("Unable to decode Any value: {err}")))?;
317 Ok(Some(m))
318 }
319
320 pub fn pack<M: ProstMessageExt>(message: &M) -> Result<Any, ArrowError> {
322 Ok(message.as_any())
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the query message shared by several tests.
    fn sample_query() -> CommandStatementQuery {
        CommandStatementQuery {
            query: "select 1".to_string(),
            transaction_id: None,
        }
    }

    #[test]
    fn test_type_url() {
        assert_eq!(
            TicketStatementQuery::type_url(),
            "type.googleapis.com/arrow.flight.protocol.sql.TicketStatementQuery"
        );
        assert_eq!(
            CommandStatementQuery::type_url(),
            "type.googleapis.com/arrow.flight.protocol.sql.CommandStatementQuery"
        );
    }

    #[test]
    fn test_prost_any_pack_unpack() {
        let query = sample_query();
        let any = Any::pack(&query).unwrap();
        assert!(any.is::<CommandStatementQuery>());
        let unpack_query: CommandStatementQuery = any.unpack().unwrap().unwrap();
        assert_eq!(query, unpack_query);
    }

    /// `unpack` must not attempt to decode when the type URL names another message.
    #[test]
    fn test_unpack_other_type_returns_none() {
        let any = Any::pack(&sample_query()).unwrap();
        assert!(!any.is::<CommandStatementUpdate>());
        assert!(any.unpack::<CommandStatementUpdate>().unwrap().is_none());
    }

    /// A known type URL with undecodable bytes surfaces as a `ParseError`.
    #[test]
    fn test_invalid_bytes_are_parse_errors() {
        let any = Any {
            type_url: COMMAND_STATEMENT_QUERY_TYPE_URL.to_string(),
            // 0xff begins a varint that is never terminated, so decoding fails.
            value: Bytes::from_static(&[0xff]),
        };
        assert!(matches!(
            any.unpack::<CommandStatementQuery>(),
            Err(ArrowError::ParseError(_))
        ));

        let result: Result<Command, ArrowError> = any.try_into();
        assert!(matches!(result, Err(ArrowError::ParseError(_))));
    }

    #[test]
    fn test_command() {
        let any = Any::pack(&sample_query()).unwrap();
        let cmd: Command = any.clone().try_into().unwrap();

        assert!(matches!(cmd, Command::CommandStatementQuery(_)));
        assert_eq!(cmd.type_url(), COMMAND_STATEMENT_QUERY_TYPE_URL);
        // Re-encoding a known command reproduces the original envelope.
        assert_eq!(cmd.into_any(), any);

        let any = Any {
            type_url: "fake_url".to_string(),
            value: Default::default(),
        };

        let cmd: Command = any.clone().try_into().unwrap();
        assert!(matches!(cmd, Command::Unknown(_)));
        assert_eq!(cmd.type_url(), "fake_url");
        // Unknown commands hand back the wrapped envelope untouched.
        assert_eq!(cmd.into_any(), any);
    }
}