1use std::{sync::Arc, time::Duration};
19
20use anyhow::{Context, Result, bail};
21use arrow_array::{ArrayRef, Datum, RecordBatch, StringArray};
22use arrow_cast::{CastOptions, cast_with_options, pretty::pretty_format_batches};
23use arrow_flight::{
24 FlightInfo,
25 flight_service_client::FlightServiceClient,
26 sql::{CommandGetDbSchemas, CommandGetTables, client::FlightSqlServiceClient},
27};
28use arrow_schema::Schema;
29use clap::{Parser, Subcommand, ValueEnum};
30use core::str;
31use futures::TryStreamExt;
32use tonic::{
33 metadata::MetadataMap,
34 transport::{Channel, ClientTlsConfig, Endpoint},
35};
36use tracing_log::log::info;
37
// Logging-related command-line options, consumed by `setup_logging`.
// Plain `//` comments are used deliberately: `///` doc comments on clap-derived
// items would become part of the generated `--help` text.
#[derive(Debug, Parser)]
pub struct LoggingArgs {
    // Number of times `-v` / `--verbose` was given. `setup_logging` maps this to a
    // log filter: 0 => "warn", 1 => "info", 2 => "debug", 3 or more => "trace".
    #[clap(
        short = 'v',
        long = "verbose",
        action = clap::ArgAction::Count,
    )]
    log_verbose_count: u8,
}
56
// Compression encodings selectable on the command line. The variant names are the
// values clap accepts. Each variant is converted into the
// `tonic::codec::CompressionEncoding` of the same name by the `From` impl below it.
// (`//` instead of `///` so the clap-generated help text is left untouched.)
#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
pub enum CompressionEncoding {
    Gzip,
    Deflate,
    Zstd,
}
64
65impl From<CompressionEncoding> for tonic::codec::CompressionEncoding {
66 fn from(encoding: CompressionEncoding) -> Self {
67 match encoding {
68 CompressionEncoding::Gzip => Self::Gzip,
69 CompressionEncoding::Deflate => Self::Deflate,
70 CompressionEncoding::Zstd => Self::Zstd,
71 }
72 }
73}
74
// Connection and authentication options, consumed by `setup_client`.
// Plain `//` comments are used deliberately: `///` doc comments on clap-derived
// fields would become part of the generated `--help` text.
#[derive(Debug, Parser)]
struct ClientArgs {
    // Extra request headers given as `KEY=value` (the flag may be repeated).
    // Parsed by `parse_key_val`; each pair is attached with `set_header`.
    #[clap(long = "header", short = 'H', value_parser = parse_key_val)]
    headers: Vec<(String, String)>,

    // Username for the handshake performed in `setup_client`.
    // clap rejects it unless `--password` is also given.
    #[clap(long, requires = "password")]
    username: Option<String>,

    // Password for the handshake; clap rejects it unless `--username` is also given.
    #[clap(long, requires = "username")]
    password: Option<String>,

    // Token handed to the client via `set_token`.
    #[clap(long)]
    token: Option<String>,

    // Connect via `https` using a TLS config built with `with_enabled_roots()`.
    // This also changes the default port from 80 to 443.
    #[clap(long)]
    tls: bool,

    // Enable TLS key logging (`use_key_log`); only accepted together with `--tls`.
    // NOTE(review): the key-log destination is decided by tonic/rustls, presumably
    // the `SSLKEYLOGFILE` environment variable — confirm.
    #[clap(long, requires = "tls")]
    key_log: bool,

    // Server host, inserted into the endpoint URI `<protocol>://<host>:<port>`.
    #[clap(long)]
    host: String,

    // Server port; when omitted, 443 is used with `--tls` and 80 otherwise.
    #[clap(long)]
    port: Option<u16>,

    // Compression encodings to accept in responses, given as a comma-separated list.
    // Each one is registered with `accept_compressed`.
    #[clap(long, value_delimiter = ',')]
    accept_compression: Vec<CompressionEncoding>,

    // Compression encoding for outgoing requests, registered with `send_compressed`.
    #[clap(long)]
    send_compression: Option<CompressionEncoding>,
}
155
// Top-level command line: logging options, connection options and the subcommand
// to run. (`//` instead of `///` so the clap-generated about/help text is unchanged.)
#[derive(Debug, Parser)]
struct Args {
    // Verbosity flags, flattened into the top-level argument list.
    #[clap(flatten)]
    logging_args: LoggingArgs,

    // Connection/authentication flags, flattened into the top-level argument list.
    #[clap(flatten)]
    client_args: ClientArgs,

    // The Flight SQL operation to perform.
    #[clap(subcommand)]
    cmd: Command,
}
169
// Subcommands, each mapped onto one Flight SQL client call in `main`.
// Fields without `short`/`long` are positional arguments. Comments are `//` so the
// clap-generated help text is unchanged.
#[derive(Debug, Subcommand)]
enum Command {
    // List catalogs (`get_catalogs`).
    Catalogs,
    // List database schemas of `catalog` (`get_db_schemas`).
    DbSchemas {
        // Catalog name (positional).
        catalog: String,
        // Optional schema filter pattern, forwarded as `db_schema_filter_pattern`.
        #[clap(short, long)]
        db_schema_filter: Option<String>,
    },
    // List tables of `catalog` (`get_tables`).
    Tables {
        // Catalog name (positional).
        catalog: String,
        // Optional schema filter pattern, forwarded as `db_schema_filter_pattern`.
        #[clap(short, long)]
        db_schema_filter: Option<String>,
        // Optional table-name filter pattern, forwarded as `table_name_filter_pattern`.
        #[clap(short, long)]
        table_filter: Option<String>,
        // Table types to include; forwarded unchanged as `table_types`.
        #[clap(long)]
        table_types: Vec<String>,
    },
    // List table types (`get_table_types`).
    TableTypes,

    // Execute an ad-hoc SQL statement (`execute`).
    StatementQuery {
        // SQL text (positional).
        query: String,
    },

    // Prepare a statement, optionally bind parameters, then execute it.
    PreparedStatementQuery {
        // SQL text (positional).
        query: String,

        // Parameters as `NAME=value` (flag may be repeated). Values are cast to the
        // types in the statement's parameter schema by
        // `construct_record_batch_from_params`.
        #[clap(short, value_parser = parse_key_val)]
        params: Vec<(String, String)>,
    },
}
247
248#[tokio::main]
249async fn main() -> Result<()> {
250 let args = Args::parse();
251 setup_logging(args.logging_args)?;
252 let mut client = setup_client(args.client_args)
253 .await
254 .context("setup client")?;
255
256 let flight_info = match args.cmd {
257 Command::Catalogs => client.get_catalogs().await.context("get catalogs")?,
258 Command::DbSchemas {
259 catalog,
260 db_schema_filter,
261 } => client
262 .get_db_schemas(CommandGetDbSchemas {
263 catalog: Some(catalog),
264 db_schema_filter_pattern: db_schema_filter,
265 })
266 .await
267 .context("get db schemas")?,
268 Command::Tables {
269 catalog,
270 db_schema_filter,
271 table_filter,
272 table_types,
273 } => client
274 .get_tables(CommandGetTables {
275 catalog: Some(catalog),
276 db_schema_filter_pattern: db_schema_filter,
277 table_name_filter_pattern: table_filter,
278 table_types,
279 include_schema: false,
283 })
284 .await
285 .context("get tables")?,
286 Command::TableTypes => client.get_table_types().await.context("get table types")?,
287 Command::StatementQuery { query } => client
288 .execute(query, None)
289 .await
290 .context("execute statement")?,
291 Command::PreparedStatementQuery { query, params } => {
292 let mut prepared_stmt = client
293 .prepare(query, None)
294 .await
295 .context("prepare statement")?;
296
297 if !params.is_empty() {
298 prepared_stmt
299 .set_parameters(
300 construct_record_batch_from_params(
301 ¶ms,
302 prepared_stmt
303 .parameter_schema()
304 .context("get parameter schema")?,
305 )
306 .context("construct parameters")?,
307 )
308 .context("bind parameters")?;
309 }
310
311 prepared_stmt
312 .execute()
313 .await
314 .context("execute prepared statement")?
315 }
316 };
317
318 let batches = execute_flight(&mut client, flight_info)
319 .await
320 .context("read flight data")?;
321
322 let res = pretty_format_batches(batches.as_slice()).context("format results")?;
323 println!("{res}");
324
325 Ok(())
326}
327
328async fn execute_flight(
329 client: &mut FlightSqlServiceClient<Channel>,
330 info: FlightInfo,
331) -> Result<Vec<RecordBatch>> {
332 let schema = Arc::new(Schema::try_from(info.clone()).context("valid schema")?);
333 let mut batches = Vec::with_capacity(info.endpoint.len() + 1);
334 batches.push(RecordBatch::new_empty(schema));
335 info!("decoded schema");
336
337 for endpoint in info.endpoint {
338 let Some(ticket) = &endpoint.ticket else {
339 bail!("did not get ticket");
340 };
341
342 let mut flight_data = client.do_get(ticket.clone()).await.context("do get")?;
343 log_metadata(flight_data.headers(), "header");
344
345 let mut endpoint_batches: Vec<_> = (&mut flight_data)
346 .try_collect()
347 .await
348 .context("collect data stream")?;
349 batches.append(&mut endpoint_batches);
350
351 if let Some(trailers) = flight_data.trailers() {
352 log_metadata(&trailers, "trailer");
353 }
354 }
355 info!("received data");
356
357 Ok(batches)
358}
359
360fn construct_record_batch_from_params(
361 params: &[(String, String)],
362 parameter_schema: &Schema,
363) -> Result<RecordBatch> {
364 let mut items = Vec::<(&String, ArrayRef)>::new();
365
366 for (name, value) in params {
367 let field = parameter_schema.field_with_name(name)?;
368 let value_as_array = StringArray::new_scalar(value);
369 let casted = cast_with_options(
370 value_as_array.get().0,
371 field.data_type(),
372 &CastOptions::default(),
373 )?;
374 items.push((name, casted))
375 }
376
377 Ok(RecordBatch::try_from_iter(items)?)
378}
379
380fn setup_logging(args: LoggingArgs) -> Result<()> {
381 use tracing_subscriber::{EnvFilter, FmtSubscriber, util::SubscriberInitExt};
382
383 tracing_log::LogTracer::init().context("tracing log init")?;
384
385 let filter = match args.log_verbose_count {
386 0 => "warn",
387 1 => "info",
388 2 => "debug",
389 _ => "trace",
390 };
391 let filter = EnvFilter::try_new(filter).context("set up log env filter")?;
392
393 let subscriber = FmtSubscriber::builder().with_env_filter(filter).finish();
394 subscriber.try_init().context("init logging subscriber")?;
395
396 Ok(())
397}
398
399async fn setup_client(args: ClientArgs) -> Result<FlightSqlServiceClient<Channel>> {
400 let port = args.port.unwrap_or(if args.tls { 443 } else { 80 });
401
402 let protocol = if args.tls { "https" } else { "http" };
403
404 let mut endpoint = Endpoint::new(format!("{}://{}:{}", protocol, args.host, port))
405 .context("create endpoint")?
406 .connect_timeout(Duration::from_secs(20))
407 .timeout(Duration::from_secs(20))
408 .tcp_nodelay(true) .tcp_keepalive(Option::Some(Duration::from_secs(3600)))
410 .http2_keep_alive_interval(Duration::from_secs(300))
411 .keep_alive_timeout(Duration::from_secs(20))
412 .keep_alive_while_idle(true);
413
414 if args.tls {
415 let mut tls_config = ClientTlsConfig::new().with_enabled_roots();
416 if args.key_log {
417 tls_config = tls_config.use_key_log();
418 }
419
420 endpoint = endpoint
421 .tls_config(tls_config)
422 .context("create TLS endpoint")?;
423 }
424
425 let channel = endpoint.connect().await.context("connect to endpoint")?;
426
427 let mut client = FlightServiceClient::new(channel);
428 for encoding in args.accept_compression {
429 client = client.accept_compressed(encoding.into());
430 }
431 if let Some(encoding) = args.send_compression {
432 client = client.send_compressed(encoding.into());
433 }
434 let mut client = FlightSqlServiceClient::new_from_inner(client);
435 info!("connected");
436
437 for (k, v) in args.headers {
438 client.set_header(k, v);
439 }
440
441 if let Some(token) = args.token {
442 client.set_token(token);
443 info!("token set");
444 }
445
446 match (args.username, args.password) {
447 (None, None) => {}
448 (Some(username), Some(password)) => {
449 client
450 .handshake(&username, &password)
451 .await
452 .context("handshake")?;
453 info!("performed handshake");
454 }
455 (Some(_), None) => {
456 bail!("when username is set, you also need to set a password")
457 }
458 (None, Some(_)) => {
459 bail!("when password is set, you also need to set a username")
460 }
461 }
462
463 Ok(client)
464}
465
/// Splits a `KEY=value` command-line argument at its first `=`.
///
/// The value keeps any further `=` characters, and either side may be empty.
/// Returns a human-readable error when the input contains no `=` at all.
fn parse_key_val(s: &str) -> Result<(String, String), String> {
    match s.split_once('=') {
        Some((key, value)) => Ok((key.to_owned(), value.to_owned())),
        None => Err(format!("invalid KEY=value: no `=` found in `{s}`")),
    }
}
473
474fn log_metadata(map: &MetadataMap, what: &'static str) {
476 for k_v in map.iter() {
477 match k_v {
478 tonic::metadata::KeyAndValueRef::Ascii(k, v) => {
479 info!(
480 "{}: {}={}",
481 what,
482 k.as_str(),
483 v.to_str().unwrap_or("<invalid>"),
484 );
485 }
486 tonic::metadata::KeyAndValueRef::Binary(k, v) => {
487 info!(
488 "{}: {}={}",
489 what,
490 k.as_str(),
491 String::from_utf8_lossy(v.as_ref()),
492 );
493 }
494 }
495 }
496}