flight_sql_client/
flight_sql_client.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::{sync::Arc, time::Duration};
19
20use anyhow::{Context, Result, bail};
21use arrow_array::{ArrayRef, Datum, RecordBatch, StringArray};
22use arrow_cast::{CastOptions, cast_with_options, pretty::pretty_format_batches};
23use arrow_flight::{
24    FlightInfo,
25    flight_service_client::FlightServiceClient,
26    sql::{CommandGetDbSchemas, CommandGetTables, client::FlightSqlServiceClient},
27};
28use arrow_schema::Schema;
29use clap::{Parser, Subcommand, ValueEnum};
30use core::str;
31use futures::TryStreamExt;
32use tonic::{
33    metadata::MetadataMap,
34    transport::{Channel, ClientTlsConfig, Endpoint},
35};
36use tracing_log::log::info;
37
/// Logging CLI config.
// NOTE: the `///` comments on the field below are the long-form `--help` text
// (see the clap derive reference), so they are kept verbatim; maintainer notes
// use plain `//` comments instead.
#[derive(Debug, Parser)]
pub struct LoggingArgs {
    /// Log verbosity.
    ///
    /// Defaults to "warn".
    ///
    /// Use `-v` for "info", `-vv` for "debug", `-vvv` for "trace".
    ///
    /// Note you can also set logging level using `RUST_LOG` environment variable:
    /// `RUST_LOG=debug`.
    // `ArgAction::Count` stores how many times `-v`/`--verbose` was given;
    // `setup_logging` maps that count (0, 1, 2, 3+) to a filter level.
    #[clap(
        short = 'v',
        long = "verbose",
        action = clap::ArgAction::Count,
    )]
    log_verbose_count: u8,
}
56
/// gRPC/HTTP compression algorithms.
// CLI-facing mirror of `tonic::codec::CompressionEncoding`: deriving `ValueEnum`
// lets clap parse these variants for `ClientArgs::accept_compression` and
// `ClientArgs::send_compression`. The `From` impl in this file converts a value
// into the tonic type, so a variant added here needs a matching arm there.
#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
pub enum CompressionEncoding {
    Gzip,
    Deflate,
    Zstd,
}
64
65impl From<CompressionEncoding> for tonic::codec::CompressionEncoding {
66    fn from(encoding: CompressionEncoding) -> Self {
67        match encoding {
68            CompressionEncoding::Gzip => Self::Gzip,
69            CompressionEncoding::Deflate => Self::Deflate,
70            CompressionEncoding::Zstd => Self::Zstd,
71        }
72    }
73}
74
// Connection-related CLI options, consumed by `setup_client`.
//
// NOTE: the `///` comments on the fields are the long-form `--help` text (see
// the clap derive reference), so they are kept verbatim; maintainer notes use
// plain `//` comments instead.
#[derive(Debug, Parser)]
struct ClientArgs {
    /// Additional headers.
    ///
    /// Can be given multiple times. Headers and values are separated by '='.
    ///
    /// Example: `-H foo=bar -H baz=42`
    // Each occurrence is parsed by `parse_key_val`, which splits at the first `=`.
    #[clap(long = "header", short = 'H', value_parser = parse_key_val)]
    headers: Vec<(String, String)>,

    /// Username.
    ///
    /// Optional. If given, `password` must also be set.
    // The pairing is declared to clap via `requires` and is checked again in
    // `setup_client`, which bails if only one of the two is present.
    #[clap(long, requires = "password")]
    username: Option<String>,

    /// Password.
    ///
    /// Optional. If given, `username` must also be set.
    #[clap(long, requires = "username")]
    password: Option<String>,

    /// Auth token.
    // Passed to `set_token` on the client in `setup_client` when present.
    #[clap(long)]
    token: Option<String>,

    /// Use TLS.
    ///
    /// If not provided, use cleartext connection.
    // Also selects the URI scheme (`https` vs `http`) and the default port in
    // `setup_client`.
    #[clap(long)]
    tls: bool,

    /// Dump TLS key log.
    ///
    /// The target file is specified by the `SSLKEYLOGFILE` environment variable.
    ///
    /// Requires `--tls`.
    #[clap(long, requires = "tls")]
    key_log: bool,

    /// Server host.
    ///
    /// Required.
    #[clap(long)]
    host: String,

    /// Server port.
    ///
    /// Defaults to `443` if `tls` is set, otherwise defaults to `80`.
    // The default is not expressed in clap; `setup_client` resolves `None`.
    #[clap(long)]
    port: Option<u16>,

    /// Compression accepted by the client for responses sent by the server.
    ///
    /// The client will send this information to the server as part of the request. The server is free to pick an
    /// algorithm from that list or use no compression (called "identity" encoding).
    ///
    /// You may define multiple algorithms by using a comma-separated list.
    #[clap(long, value_delimiter = ',')]
    accept_compression: Vec<CompressionEncoding>,

    /// Compression of requests sent by the client to the server.
    ///
    /// Since the client needs to decide on the compression before sending the request, there is no client<->server
    /// negotiation. If the server does NOT support the chosen compression, it will respond with an error a la:
    ///
    /// ```
    /// Ipc error: Status {
    ///     code: Unimplemented,
    ///     message: "Content is compressed with `zstd` which isn't supported",
    ///     metadata: MetadataMap { headers: {"grpc-accept-encoding": "identity", ...} },
    ///     ...
    /// }
    /// ```
    ///
    /// Based on the algorithms listed in the `grpc-accept-encoding` header, you may make a more educated guess for
    /// your next request. Note that `identity` is a synonym for "no compression".
    #[clap(long)]
    send_compression: Option<CompressionEncoding>,
}
155
// Root of the command line: `main` calls `Args::parse()` and then hands
// `logging_args` to `setup_logging`, `client_args` to `setup_client`, and
// dispatches on `cmd`.
//
// A plain `//` comment is used here on purpose: a `///` comment on the root
// parser could end up in the generated help output (see the clap derive
// reference).
#[derive(Debug, Parser)]
struct Args {
    /// Logging args.
    // `flatten` inlines the options of `LoggingArgs` into this parser.
    #[clap(flatten)]
    logging_args: LoggingArgs,

    /// Client args.
    // `flatten` inlines the options of `ClientArgs` into this parser.
    #[clap(flatten)]
    client_args: ClientArgs,

    // The operation to run; one variant of `Command` per subcommand.
    #[clap(subcommand)]
    cmd: Command,
}
169
170/// Different available commands.
171#[derive(Debug, Subcommand)]
172enum Command {
173    /// Get catalogs.
174    Catalogs,
175    /// Get db schemas for a catalog.
176    DbSchemas {
177        /// Name of a catalog.
178        ///
179        /// Required.
180        catalog: String,
181        /// Specifies a filter pattern for schemas to search for.
182        /// When no schema_filter is provided, the pattern will not be used to narrow the search.
183        /// In the pattern string, two special characters can be used to denote matching rules:
184        ///     - "%" means to match any substring with 0 or more characters.
185        ///     - "_" means to match any one character.
186        #[clap(short, long)]
187        db_schema_filter: Option<String>,
188    },
189    /// Get tables for a catalog.
190    Tables {
191        /// Name of a catalog.
192        ///
193        /// Required.
194        catalog: String,
195        /// Specifies a filter pattern for schemas to search for.
196        /// When no schema_filter is provided, the pattern will not be used to narrow the search.
197        /// In the pattern string, two special characters can be used to denote matching rules:
198        ///     - "%" means to match any substring with 0 or more characters.
199        ///     - "_" means to match any one character.
200        #[clap(short, long)]
201        db_schema_filter: Option<String>,
202        /// Specifies a filter pattern for tables to search for.
203        /// When no table_filter is provided, all tables matching other filters are searched.
204        /// In the pattern string, two special characters can be used to denote matching rules:
205        ///     - "%" means to match any substring with 0 or more characters.
206        ///     - "_" means to match any one character.
207        #[clap(short, long)]
208        table_filter: Option<String>,
209        /// Specifies a filter of table types which must match.
210        /// The table types depend on vendor/implementation. It is usually used to separate tables from views or system tables.
211        /// TABLE, VIEW, and SYSTEM TABLE are commonly supported.
212        #[clap(long)]
213        table_types: Vec<String>,
214    },
215    /// Get table types.
216    TableTypes,
217
218    /// Execute given statement.
219    StatementQuery {
220        /// SQL query.
221        ///
222        /// Required.
223        query: String,
224    },
225
226    /// Prepare given statement and then execute it.
227    PreparedStatementQuery {
228        /// SQL query.
229        ///
230        /// Required.
231        ///
232        /// Can contains placeholders like `$1`.
233        ///
234        /// Example: `SELECT * FROM t WHERE x = $1`
235        query: String,
236
237        /// Additional parameters.
238        ///
239        /// Can be given multiple times. Names and values are separated by '='. Values will be
240        /// converted to the type that the server reported for the prepared statement.
241        ///
242        /// Example: `-p $1=42`
243        #[clap(short, value_parser = parse_key_val)]
244        params: Vec<(String, String)>,
245    },
246}
247
248#[tokio::main]
249async fn main() -> Result<()> {
250    let args = Args::parse();
251    setup_logging(args.logging_args)?;
252    let mut client = setup_client(args.client_args)
253        .await
254        .context("setup client")?;
255
256    let flight_info = match args.cmd {
257        Command::Catalogs => client.get_catalogs().await.context("get catalogs")?,
258        Command::DbSchemas {
259            catalog,
260            db_schema_filter,
261        } => client
262            .get_db_schemas(CommandGetDbSchemas {
263                catalog: Some(catalog),
264                db_schema_filter_pattern: db_schema_filter,
265            })
266            .await
267            .context("get db schemas")?,
268        Command::Tables {
269            catalog,
270            db_schema_filter,
271            table_filter,
272            table_types,
273        } => client
274            .get_tables(CommandGetTables {
275                catalog: Some(catalog),
276                db_schema_filter_pattern: db_schema_filter,
277                table_name_filter_pattern: table_filter,
278                table_types,
279                // Schema is returned as ipc encoded bytes.
280                // We do not support returning the schema as there is no trivial mechanism
281                // to display the information to the user.
282                include_schema: false,
283            })
284            .await
285            .context("get tables")?,
286        Command::TableTypes => client.get_table_types().await.context("get table types")?,
287        Command::StatementQuery { query } => client
288            .execute(query, None)
289            .await
290            .context("execute statement")?,
291        Command::PreparedStatementQuery { query, params } => {
292            let mut prepared_stmt = client
293                .prepare(query, None)
294                .await
295                .context("prepare statement")?;
296
297            if !params.is_empty() {
298                prepared_stmt
299                    .set_parameters(
300                        construct_record_batch_from_params(
301                            &params,
302                            prepared_stmt
303                                .parameter_schema()
304                                .context("get parameter schema")?,
305                        )
306                        .context("construct parameters")?,
307                    )
308                    .context("bind parameters")?;
309            }
310
311            prepared_stmt
312                .execute()
313                .await
314                .context("execute prepared statement")?
315        }
316    };
317
318    let batches = execute_flight(&mut client, flight_info)
319        .await
320        .context("read flight data")?;
321
322    let res = pretty_format_batches(batches.as_slice()).context("format results")?;
323    println!("{res}");
324
325    Ok(())
326}
327
328async fn execute_flight(
329    client: &mut FlightSqlServiceClient<Channel>,
330    info: FlightInfo,
331) -> Result<Vec<RecordBatch>> {
332    let schema = Arc::new(Schema::try_from(info.clone()).context("valid schema")?);
333    let mut batches = Vec::with_capacity(info.endpoint.len() + 1);
334    batches.push(RecordBatch::new_empty(schema));
335    info!("decoded schema");
336
337    for endpoint in info.endpoint {
338        let Some(ticket) = &endpoint.ticket else {
339            bail!("did not get ticket");
340        };
341
342        let mut flight_data = client.do_get(ticket.clone()).await.context("do get")?;
343        log_metadata(flight_data.headers(), "header");
344
345        let mut endpoint_batches: Vec<_> = (&mut flight_data)
346            .try_collect()
347            .await
348            .context("collect data stream")?;
349        batches.append(&mut endpoint_batches);
350
351        if let Some(trailers) = flight_data.trailers() {
352            log_metadata(&trailers, "trailer");
353        }
354    }
355    info!("received data");
356
357    Ok(batches)
358}
359
360fn construct_record_batch_from_params(
361    params: &[(String, String)],
362    parameter_schema: &Schema,
363) -> Result<RecordBatch> {
364    let mut items = Vec::<(&String, ArrayRef)>::new();
365
366    for (name, value) in params {
367        let field = parameter_schema.field_with_name(name)?;
368        let value_as_array = StringArray::new_scalar(value);
369        let casted = cast_with_options(
370            value_as_array.get().0,
371            field.data_type(),
372            &CastOptions::default(),
373        )?;
374        items.push((name, casted))
375    }
376
377    Ok(RecordBatch::try_from_iter(items)?)
378}
379
380fn setup_logging(args: LoggingArgs) -> Result<()> {
381    use tracing_subscriber::{EnvFilter, FmtSubscriber, util::SubscriberInitExt};
382
383    tracing_log::LogTracer::init().context("tracing log init")?;
384
385    let filter = match args.log_verbose_count {
386        0 => "warn",
387        1 => "info",
388        2 => "debug",
389        _ => "trace",
390    };
391    let filter = EnvFilter::try_new(filter).context("set up log env filter")?;
392
393    let subscriber = FmtSubscriber::builder().with_env_filter(filter).finish();
394    subscriber.try_init().context("init logging subscriber")?;
395
396    Ok(())
397}
398
399async fn setup_client(args: ClientArgs) -> Result<FlightSqlServiceClient<Channel>> {
400    let port = args.port.unwrap_or(if args.tls { 443 } else { 80 });
401
402    let protocol = if args.tls { "https" } else { "http" };
403
404    let mut endpoint = Endpoint::new(format!("{}://{}:{}", protocol, args.host, port))
405        .context("create endpoint")?
406        .connect_timeout(Duration::from_secs(20))
407        .timeout(Duration::from_secs(20))
408        .tcp_nodelay(true) // Disable Nagle's Algorithm since we don't want packets to wait
409        .tcp_keepalive(Option::Some(Duration::from_secs(3600)))
410        .http2_keep_alive_interval(Duration::from_secs(300))
411        .keep_alive_timeout(Duration::from_secs(20))
412        .keep_alive_while_idle(true);
413
414    if args.tls {
415        let mut tls_config = ClientTlsConfig::new().with_enabled_roots();
416        if args.key_log {
417            tls_config = tls_config.use_key_log();
418        }
419
420        endpoint = endpoint
421            .tls_config(tls_config)
422            .context("create TLS endpoint")?;
423    }
424
425    let channel = endpoint.connect().await.context("connect to endpoint")?;
426
427    let mut client = FlightServiceClient::new(channel);
428    for encoding in args.accept_compression {
429        client = client.accept_compressed(encoding.into());
430    }
431    if let Some(encoding) = args.send_compression {
432        client = client.send_compressed(encoding.into());
433    }
434    let mut client = FlightSqlServiceClient::new_from_inner(client);
435    info!("connected");
436
437    for (k, v) in args.headers {
438        client.set_header(k, v);
439    }
440
441    if let Some(token) = args.token {
442        client.set_token(token);
443        info!("token set");
444    }
445
446    match (args.username, args.password) {
447        (None, None) => {}
448        (Some(username), Some(password)) => {
449            client
450                .handshake(&username, &password)
451                .await
452                .context("handshake")?;
453            info!("performed handshake");
454        }
455        (Some(_), None) => {
456            bail!("when username is set, you also need to set a password")
457        }
458        (None, Some(_)) => {
459            bail!("when password is set, you also need to set a username")
460        }
461    }
462
463    Ok(client)
464}
465
/// Parse a single `KEY=value` pair.
///
/// The input is split at the first `=`; everything before it is the key and
/// everything after it (including any further `=`) is the value. Either part
/// may be empty.
///
/// # Errors
///
/// Returns a descriptive message if `s` contains no `=`.
fn parse_key_val(s: &str) -> Result<(String, String), String> {
    match s.split_once('=') {
        Some((key, value)) => Ok((key.to_owned(), value.to_owned())),
        None => Err(format!("invalid KEY=value: no `=` found in `{s}`")),
    }
}
473
474/// Log headers/trailers.
475fn log_metadata(map: &MetadataMap, what: &'static str) {
476    for k_v in map.iter() {
477        match k_v {
478            tonic::metadata::KeyAndValueRef::Ascii(k, v) => {
479                info!(
480                    "{}: {}={}",
481                    what,
482                    k.as_str(),
483                    v.to_str().unwrap_or("<invalid>"),
484                );
485            }
486            tonic::metadata::KeyAndValueRef::Binary(k, v) => {
487                info!(
488                    "{}: {}={}",
489                    what,
490                    k.as_str(),
491                    String::from_utf8_lossy(v.as_ref()),
492                );
493            }
494        }
495    }
496}