1use std::{sync::Arc, time::Duration};
19
20use anyhow::{bail, Context, Result};
21use arrow_array::{ArrayRef, Datum, RecordBatch, StringArray};
22use arrow_cast::{cast_with_options, pretty::pretty_format_batches, CastOptions};
23use arrow_flight::{
24 sql::{client::FlightSqlServiceClient, CommandGetDbSchemas, CommandGetTables},
25 FlightInfo,
26};
27use arrow_schema::Schema;
28use clap::{Parser, Subcommand};
29use core::str;
30use futures::TryStreamExt;
31use tonic::{
32 metadata::MetadataMap,
33 transport::{Channel, ClientTlsConfig, Endpoint},
34};
35use tracing_log::log::info;
36
/// Logging CLI config.
#[derive(Debug, Parser)]
pub struct LoggingArgs {
    /// Log verbosity.
    ///
    /// Defaults to "warn". Use `-v` for "info", `-vv` for "debug" and `-vvv` (or more)
    /// for "trace".
    #[clap(
        short = 'v',
        long = "verbose",
        action = clap::ArgAction::Count,
    )]
    log_verbose_count: u8,
}
55
/// Connection and authentication settings for the Flight SQL server.
#[derive(Debug, Parser)]
struct ClientArgs {
    /// Additional headers to set on the client.
    ///
    /// Given as `KEY=VALUE` (split at the first `=`); can be repeated.
    #[clap(long = "header", short = 'H', value_parser = parse_key_val)]
    headers: Vec<(String, String)>,

    /// Username for a handshake with the server.
    ///
    /// Requires `--password`.
    #[clap(long, requires = "password")]
    username: Option<String>,

    /// Password for a handshake with the server.
    ///
    /// Requires `--username`.
    #[clap(long, requires = "username")]
    password: Option<String>,

    /// Auth token to set on the client.
    #[clap(long)]
    token: Option<String>,

    /// Connect over TLS (`https`); without this flag plain `http` is used.
    #[clap(long)]
    tls: bool,

    /// Server host.
    #[clap(long)]
    host: String,

    /// Server port.
    ///
    /// Defaults to 443 with `--tls` and to 80 without it.
    #[clap(long)]
    port: Option<u16>,
}
100
/// Command-line client for an Arrow Flight SQL server.
#[derive(Debug, Parser)]
struct Args {
    // Logging settings (`-v` verbosity), merged into the top-level CLI.
    #[clap(flatten)]
    logging_args: LoggingArgs,

    // Connection/authentication settings, merged into the top-level CLI.
    #[clap(flatten)]
    client_args: ClientArgs,

    // The Flight SQL request to issue.
    #[clap(subcommand)]
    cmd: Command,
}
114
/// Flight SQL request to issue.
#[derive(Debug, Subcommand)]
enum Command {
    /// Get catalogs.
    Catalogs,
    /// Get db schemas for a catalog.
    DbSchemas {
        /// Name of a catalog.
        catalog: String,
        /// Optional filter pattern for db schema names.
        ///
        /// Passed to the server as the request's `db_schema_filter_pattern`.
        #[clap(short, long)]
        db_schema_filter: Option<String>,
    },
    /// Get tables for a catalog.
    Tables {
        /// Name of a catalog.
        catalog: String,
        /// Optional filter pattern for db schema names.
        ///
        /// Passed to the server as the request's `db_schema_filter_pattern`.
        #[clap(short, long)]
        db_schema_filter: Option<String>,
        /// Optional filter pattern for table names.
        ///
        /// Passed to the server as the request's `table_name_filter_pattern`.
        #[clap(short, long)]
        table_filter: Option<String>,
        /// Table types to restrict the result to; can be repeated.
        #[clap(long)]
        table_types: Vec<String>,
    },
    /// Get table types.
    TableTypes,

    /// Execute given statement.
    StatementQuery {
        /// SQL query.
        query: String,
    },

    /// Prepare given statement and then execute it.
    PreparedStatementQuery {
        /// SQL query; may contain parameter placeholders whose values are given with `-p`.
        query: String,

        /// Parameter values as `NAME=VALUE`; can be repeated.
        ///
        /// Each value is converted to the type of the same-named field in the parameter
        /// schema the server reports for the prepared statement.
        #[clap(short, value_parser = parse_key_val)]
        params: Vec<(String, String)>,
    },
}
192
193#[tokio::main]
194async fn main() -> Result<()> {
195 let args = Args::parse();
196 setup_logging(args.logging_args)?;
197 let mut client = setup_client(args.client_args)
198 .await
199 .context("setup client")?;
200
201 let flight_info = match args.cmd {
202 Command::Catalogs => client.get_catalogs().await.context("get catalogs")?,
203 Command::DbSchemas {
204 catalog,
205 db_schema_filter,
206 } => client
207 .get_db_schemas(CommandGetDbSchemas {
208 catalog: Some(catalog),
209 db_schema_filter_pattern: db_schema_filter,
210 })
211 .await
212 .context("get db schemas")?,
213 Command::Tables {
214 catalog,
215 db_schema_filter,
216 table_filter,
217 table_types,
218 } => client
219 .get_tables(CommandGetTables {
220 catalog: Some(catalog),
221 db_schema_filter_pattern: db_schema_filter,
222 table_name_filter_pattern: table_filter,
223 table_types,
224 include_schema: false,
228 })
229 .await
230 .context("get tables")?,
231 Command::TableTypes => client.get_table_types().await.context("get table types")?,
232 Command::StatementQuery { query } => client
233 .execute(query, None)
234 .await
235 .context("execute statement")?,
236 Command::PreparedStatementQuery { query, params } => {
237 let mut prepared_stmt = client
238 .prepare(query, None)
239 .await
240 .context("prepare statement")?;
241
242 if !params.is_empty() {
243 prepared_stmt
244 .set_parameters(
245 construct_record_batch_from_params(
246 ¶ms,
247 prepared_stmt
248 .parameter_schema()
249 .context("get parameter schema")?,
250 )
251 .context("construct parameters")?,
252 )
253 .context("bind parameters")?;
254 }
255
256 prepared_stmt
257 .execute()
258 .await
259 .context("execute prepared statement")?
260 }
261 };
262
263 let batches = execute_flight(&mut client, flight_info)
264 .await
265 .context("read flight data")?;
266
267 let res = pretty_format_batches(batches.as_slice()).context("format results")?;
268 println!("{res}");
269
270 Ok(())
271}
272
273async fn execute_flight(
274 client: &mut FlightSqlServiceClient<Channel>,
275 info: FlightInfo,
276) -> Result<Vec<RecordBatch>> {
277 let schema = Arc::new(Schema::try_from(info.clone()).context("valid schema")?);
278 let mut batches = Vec::with_capacity(info.endpoint.len() + 1);
279 batches.push(RecordBatch::new_empty(schema));
280 info!("decoded schema");
281
282 for endpoint in info.endpoint {
283 let Some(ticket) = &endpoint.ticket else {
284 bail!("did not get ticket");
285 };
286
287 let mut flight_data = client.do_get(ticket.clone()).await.context("do get")?;
288 log_metadata(flight_data.headers(), "header");
289
290 let mut endpoint_batches: Vec<_> = (&mut flight_data)
291 .try_collect()
292 .await
293 .context("collect data stream")?;
294 batches.append(&mut endpoint_batches);
295
296 if let Some(trailers) = flight_data.trailers() {
297 log_metadata(&trailers, "trailer");
298 }
299 }
300 info!("received data");
301
302 Ok(batches)
303}
304
305fn construct_record_batch_from_params(
306 params: &[(String, String)],
307 parameter_schema: &Schema,
308) -> Result<RecordBatch> {
309 let mut items = Vec::<(&String, ArrayRef)>::new();
310
311 for (name, value) in params {
312 let field = parameter_schema.field_with_name(name)?;
313 let value_as_array = StringArray::new_scalar(value);
314 let casted = cast_with_options(
315 value_as_array.get().0,
316 field.data_type(),
317 &CastOptions::default(),
318 )?;
319 items.push((name, casted))
320 }
321
322 Ok(RecordBatch::try_from_iter(items)?)
323}
324
325fn setup_logging(args: LoggingArgs) -> Result<()> {
326 use tracing_subscriber::{util::SubscriberInitExt, EnvFilter, FmtSubscriber};
327
328 tracing_log::LogTracer::init().context("tracing log init")?;
329
330 let filter = match args.log_verbose_count {
331 0 => "warn",
332 1 => "info",
333 2 => "debug",
334 _ => "trace",
335 };
336 let filter = EnvFilter::try_new(filter).context("set up log env filter")?;
337
338 let subscriber = FmtSubscriber::builder().with_env_filter(filter).finish();
339 subscriber.try_init().context("init logging subscriber")?;
340
341 Ok(())
342}
343
344async fn setup_client(args: ClientArgs) -> Result<FlightSqlServiceClient<Channel>> {
345 let port = args.port.unwrap_or(if args.tls { 443 } else { 80 });
346
347 let protocol = if args.tls { "https" } else { "http" };
348
349 let mut endpoint = Endpoint::new(format!("{}://{}:{}", protocol, args.host, port))
350 .context("create endpoint")?
351 .connect_timeout(Duration::from_secs(20))
352 .timeout(Duration::from_secs(20))
353 .tcp_nodelay(true) .tcp_keepalive(Option::Some(Duration::from_secs(3600)))
355 .http2_keep_alive_interval(Duration::from_secs(300))
356 .keep_alive_timeout(Duration::from_secs(20))
357 .keep_alive_while_idle(true);
358
359 if args.tls {
360 let tls_config = ClientTlsConfig::new().with_enabled_roots();
361 endpoint = endpoint
362 .tls_config(tls_config)
363 .context("create TLS endpoint")?;
364 }
365
366 let channel = endpoint.connect().await.context("connect to endpoint")?;
367
368 let mut client = FlightSqlServiceClient::new(channel);
369 info!("connected");
370
371 for (k, v) in args.headers {
372 client.set_header(k, v);
373 }
374
375 if let Some(token) = args.token {
376 client.set_token(token);
377 info!("token set");
378 }
379
380 match (args.username, args.password) {
381 (None, None) => {}
382 (Some(username), Some(password)) => {
383 client
384 .handshake(&username, &password)
385 .await
386 .context("handshake")?;
387 info!("performed handshake");
388 }
389 (Some(_), None) => {
390 bail!("when username is set, you also need to set a password")
391 }
392 (None, Some(_)) => {
393 bail!("when password is set, you also need to set a username")
394 }
395 }
396
397 Ok(client)
398}
399
/// Parse a `KEY=VALUE` command-line argument into an owned `(key, value)` pair.
///
/// Splits at the first `=`, so the value may itself contain `=`; either side may be empty.
/// Used as a clap `value_parser`, so the `Err` string is shown to the user.
fn parse_key_val(s: &str) -> Result<(String, String), String> {
    match s.split_once('=') {
        Some((key, value)) => Ok((key.to_owned(), value.to_owned())),
        None => Err(format!("invalid KEY=value: no `=` found in `{s}`")),
    }
}
407
408fn log_metadata(map: &MetadataMap, what: &'static str) {
410 for k_v in map.iter() {
411 match k_v {
412 tonic::metadata::KeyAndValueRef::Ascii(k, v) => {
413 info!(
414 "{}: {}={}",
415 what,
416 k.as_str(),
417 v.to_str().unwrap_or("<invalid>"),
418 );
419 }
420 tonic::metadata::KeyAndValueRef::Binary(k, v) => {
421 info!(
422 "{}: {}={}",
423 what,
424 k.as_str(),
425 String::from_utf8_lossy(v.as_ref()),
426 );
427 }
428 }
429 }
430}