flight_sql_client/
flight_sql_client.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::{sync::Arc, time::Duration};
19
20use anyhow::{bail, Context, Result};
21use arrow_array::{ArrayRef, Datum, RecordBatch, StringArray};
22use arrow_cast::{cast_with_options, pretty::pretty_format_batches, CastOptions};
23use arrow_flight::{
24    sql::{client::FlightSqlServiceClient, CommandGetDbSchemas, CommandGetTables},
25    FlightInfo,
26};
27use arrow_schema::Schema;
28use clap::{Parser, Subcommand};
29use core::str;
30use futures::TryStreamExt;
31use tonic::{
32    metadata::MetadataMap,
33    transport::{Channel, ClientTlsConfig, Endpoint},
34};
35use tracing_log::log::info;
36
/// Logging CLI config.
// NOTE: clap derive renders `///` doc comments on this struct and its fields
// as `--help` text, so maintainer notes here use plain `//` comments.
#[derive(Debug, Parser)]
pub struct LoggingArgs {
    /// Log verbosity.
    ///
    /// Defaults to "warn".
    ///
    /// Use `-v` for "info", `-vv` for "debug", `-vvv` for "trace".
    ///
    /// Note you can also set logging level using `RUST_LOG` environment variable:
    /// `RUST_LOG=debug`.
    //
    // `ArgAction::Count` stores how many times `-v`/`--verbose` was repeated;
    // `setup_logging` maps that count to a filter level.
    #[clap(
        short = 'v',
        long = "verbose",
        action = clap::ArgAction::Count,
    )]
    log_verbose_count: u8,
}
55
// Connection and authentication options for the Flight SQL server.
//
// NOTE: clap derive renders the `///` doc comments on these fields as
// `--help` text, so maintainer notes use plain `//` comments instead.
#[derive(Debug, Parser)]
struct ClientArgs {
    /// Additional headers.
    ///
    /// Can be given multiple times. Headers and values are separated by '='.
    ///
    /// Example: `-H foo=bar -H baz=42`
    // Each occurrence is split at its first '=' by `parse_key_val`; the pairs
    // are later registered on the client with `set_header` in `setup_client`.
    #[clap(long = "header", short = 'H', value_parser = parse_key_val)]
    headers: Vec<(String, String)>,

    /// Username.
    ///
    /// Optional. If given, `password` must also be set.
    // `requires` makes clap itself reject `--username` without `--password`.
    #[clap(long, requires = "password")]
    username: Option<String>,

    /// Password.
    ///
    /// Optional. If given, `username` must also be set.
    // `requires` makes clap itself reject `--password` without `--username`.
    #[clap(long, requires = "username")]
    password: Option<String>,

    /// Auth token.
    // Passed to `FlightSqlServiceClient::set_token` in `setup_client`.
    #[clap(long)]
    token: Option<String>,

    /// Use TLS.
    ///
    /// If not provided, use cleartext connection.
    #[clap(long)]
    tls: bool,

    /// Server host.
    ///
    /// Required.
    #[clap(long)]
    host: String,

    /// Server port.
    ///
    /// Defaults to `443` if `tls` is set, otherwise defaults to `80`.
    // The default is applied in `setup_client`, not by clap, which is why
    // this field is an `Option`.
    #[clap(long)]
    port: Option<u16>,
}
100
// Top-level command-line arguments, parsed once in `main`.
//
// Plain `//` comments are used on purpose: clap derive would turn a `///`
// doc comment on this struct into the program's `--help` description.
#[derive(Debug, Parser)]
struct Args {
    /// Logging args.
    // `flatten` merges the `LoggingArgs` flags into this command's own flags.
    #[clap(flatten)]
    logging_args: LoggingArgs,

    /// Client args.
    // `flatten` merges the `ClientArgs` flags into this command's own flags.
    #[clap(flatten)]
    client_args: ClientArgs,

    // The subcommand; `main` matches on it to decide which Flight SQL
    // request to issue.
    #[clap(subcommand)]
    cmd: Command,
}
114
115/// Different available commands.
116#[derive(Debug, Subcommand)]
117enum Command {
118    /// Get catalogs.
119    Catalogs,
120    /// Get db schemas for a catalog.
121    DbSchemas {
122        /// Name of a catalog.
123        ///
124        /// Required.
125        catalog: String,
126        /// Specifies a filter pattern for schemas to search for.
127        /// When no schema_filter is provided, the pattern will not be used to narrow the search.
128        /// In the pattern string, two special characters can be used to denote matching rules:
129        ///     - "%" means to match any substring with 0 or more characters.
130        ///     - "_" means to match any one character.
131        #[clap(short, long)]
132        db_schema_filter: Option<String>,
133    },
134    /// Get tables for a catalog.
135    Tables {
136        /// Name of a catalog.
137        ///
138        /// Required.
139        catalog: String,
140        /// Specifies a filter pattern for schemas to search for.
141        /// When no schema_filter is provided, the pattern will not be used to narrow the search.
142        /// In the pattern string, two special characters can be used to denote matching rules:
143        ///     - "%" means to match any substring with 0 or more characters.
144        ///     - "_" means to match any one character.
145        #[clap(short, long)]
146        db_schema_filter: Option<String>,
147        /// Specifies a filter pattern for tables to search for.
148        /// When no table_filter is provided, all tables matching other filters are searched.
149        /// In the pattern string, two special characters can be used to denote matching rules:
150        ///     - "%" means to match any substring with 0 or more characters.
151        ///     - "_" means to match any one character.
152        #[clap(short, long)]
153        table_filter: Option<String>,
154        /// Specifies a filter of table types which must match.
155        /// The table types depend on vendor/implementation. It is usually used to separate tables from views or system tables.
156        /// TABLE, VIEW, and SYSTEM TABLE are commonly supported.
157        #[clap(long)]
158        table_types: Vec<String>,
159    },
160    /// Get table types.
161    TableTypes,
162
163    /// Execute given statement.
164    StatementQuery {
165        /// SQL query.
166        ///
167        /// Required.
168        query: String,
169    },
170
171    /// Prepare given statement and then execute it.
172    PreparedStatementQuery {
173        /// SQL query.
174        ///
175        /// Required.
176        ///
177        /// Can contains placeholders like `$1`.
178        ///
179        /// Example: `SELECT * FROM t WHERE x = $1`
180        query: String,
181
182        /// Additional parameters.
183        ///
184        /// Can be given multiple times. Names and values are separated by '='. Values will be
185        /// converted to the type that the server reported for the prepared statement.
186        ///
187        /// Example: `-p $1=42`
188        #[clap(short, value_parser = parse_key_val)]
189        params: Vec<(String, String)>,
190    },
191}
192
193#[tokio::main]
194async fn main() -> Result<()> {
195    let args = Args::parse();
196    setup_logging(args.logging_args)?;
197    let mut client = setup_client(args.client_args)
198        .await
199        .context("setup client")?;
200
201    let flight_info = match args.cmd {
202        Command::Catalogs => client.get_catalogs().await.context("get catalogs")?,
203        Command::DbSchemas {
204            catalog,
205            db_schema_filter,
206        } => client
207            .get_db_schemas(CommandGetDbSchemas {
208                catalog: Some(catalog),
209                db_schema_filter_pattern: db_schema_filter,
210            })
211            .await
212            .context("get db schemas")?,
213        Command::Tables {
214            catalog,
215            db_schema_filter,
216            table_filter,
217            table_types,
218        } => client
219            .get_tables(CommandGetTables {
220                catalog: Some(catalog),
221                db_schema_filter_pattern: db_schema_filter,
222                table_name_filter_pattern: table_filter,
223                table_types,
224                // Schema is returned as ipc encoded bytes.
225                // We do not support returning the schema as there is no trivial mechanism
226                // to display the information to the user.
227                include_schema: false,
228            })
229            .await
230            .context("get tables")?,
231        Command::TableTypes => client.get_table_types().await.context("get table types")?,
232        Command::StatementQuery { query } => client
233            .execute(query, None)
234            .await
235            .context("execute statement")?,
236        Command::PreparedStatementQuery { query, params } => {
237            let mut prepared_stmt = client
238                .prepare(query, None)
239                .await
240                .context("prepare statement")?;
241
242            if !params.is_empty() {
243                prepared_stmt
244                    .set_parameters(
245                        construct_record_batch_from_params(
246                            &params,
247                            prepared_stmt
248                                .parameter_schema()
249                                .context("get parameter schema")?,
250                        )
251                        .context("construct parameters")?,
252                    )
253                    .context("bind parameters")?;
254            }
255
256            prepared_stmt
257                .execute()
258                .await
259                .context("execute prepared statement")?
260        }
261    };
262
263    let batches = execute_flight(&mut client, flight_info)
264        .await
265        .context("read flight data")?;
266
267    let res = pretty_format_batches(batches.as_slice()).context("format results")?;
268    println!("{res}");
269
270    Ok(())
271}
272
273async fn execute_flight(
274    client: &mut FlightSqlServiceClient<Channel>,
275    info: FlightInfo,
276) -> Result<Vec<RecordBatch>> {
277    let schema = Arc::new(Schema::try_from(info.clone()).context("valid schema")?);
278    let mut batches = Vec::with_capacity(info.endpoint.len() + 1);
279    batches.push(RecordBatch::new_empty(schema));
280    info!("decoded schema");
281
282    for endpoint in info.endpoint {
283        let Some(ticket) = &endpoint.ticket else {
284            bail!("did not get ticket");
285        };
286
287        let mut flight_data = client.do_get(ticket.clone()).await.context("do get")?;
288        log_metadata(flight_data.headers(), "header");
289
290        let mut endpoint_batches: Vec<_> = (&mut flight_data)
291            .try_collect()
292            .await
293            .context("collect data stream")?;
294        batches.append(&mut endpoint_batches);
295
296        if let Some(trailers) = flight_data.trailers() {
297            log_metadata(&trailers, "trailer");
298        }
299    }
300    info!("received data");
301
302    Ok(batches)
303}
304
305fn construct_record_batch_from_params(
306    params: &[(String, String)],
307    parameter_schema: &Schema,
308) -> Result<RecordBatch> {
309    let mut items = Vec::<(&String, ArrayRef)>::new();
310
311    for (name, value) in params {
312        let field = parameter_schema.field_with_name(name)?;
313        let value_as_array = StringArray::new_scalar(value);
314        let casted = cast_with_options(
315            value_as_array.get().0,
316            field.data_type(),
317            &CastOptions::default(),
318        )?;
319        items.push((name, casted))
320    }
321
322    Ok(RecordBatch::try_from_iter(items)?)
323}
324
325fn setup_logging(args: LoggingArgs) -> Result<()> {
326    use tracing_subscriber::{util::SubscriberInitExt, EnvFilter, FmtSubscriber};
327
328    tracing_log::LogTracer::init().context("tracing log init")?;
329
330    let filter = match args.log_verbose_count {
331        0 => "warn",
332        1 => "info",
333        2 => "debug",
334        _ => "trace",
335    };
336    let filter = EnvFilter::try_new(filter).context("set up log env filter")?;
337
338    let subscriber = FmtSubscriber::builder().with_env_filter(filter).finish();
339    subscriber.try_init().context("init logging subscriber")?;
340
341    Ok(())
342}
343
344async fn setup_client(args: ClientArgs) -> Result<FlightSqlServiceClient<Channel>> {
345    let port = args.port.unwrap_or(if args.tls { 443 } else { 80 });
346
347    let protocol = if args.tls { "https" } else { "http" };
348
349    let mut endpoint = Endpoint::new(format!("{}://{}:{}", protocol, args.host, port))
350        .context("create endpoint")?
351        .connect_timeout(Duration::from_secs(20))
352        .timeout(Duration::from_secs(20))
353        .tcp_nodelay(true) // Disable Nagle's Algorithm since we don't want packets to wait
354        .tcp_keepalive(Option::Some(Duration::from_secs(3600)))
355        .http2_keep_alive_interval(Duration::from_secs(300))
356        .keep_alive_timeout(Duration::from_secs(20))
357        .keep_alive_while_idle(true);
358
359    if args.tls {
360        let tls_config = ClientTlsConfig::new().with_enabled_roots();
361        endpoint = endpoint
362            .tls_config(tls_config)
363            .context("create TLS endpoint")?;
364    }
365
366    let channel = endpoint.connect().await.context("connect to endpoint")?;
367
368    let mut client = FlightSqlServiceClient::new(channel);
369    info!("connected");
370
371    for (k, v) in args.headers {
372        client.set_header(k, v);
373    }
374
375    if let Some(token) = args.token {
376        client.set_token(token);
377        info!("token set");
378    }
379
380    match (args.username, args.password) {
381        (None, None) => {}
382        (Some(username), Some(password)) => {
383            client
384                .handshake(&username, &password)
385                .await
386                .context("handshake")?;
387            info!("performed handshake");
388        }
389        (Some(_), None) => {
390            bail!("when username is set, you also need to set a password")
391        }
392        (None, Some(_)) => {
393            bail!("when password is set, you also need to set a username")
394        }
395    }
396
397    Ok(client)
398}
399
/// Parse one `KEY=value` command-line argument into a `(key, value)` pair.
///
/// Only the first `=` separates key from value, so the value itself may
/// contain further `=` characters; either side may be empty.
///
/// # Errors
///
/// Returns a human-readable message (used by clap as the parse error) when
/// the argument contains no `=` at all.
fn parse_key_val(s: &str) -> Result<(String, String), String> {
    match s.split_once('=') {
        Some((key, val)) => Ok((key.to_string(), val.to_string())),
        None => Err(format!("invalid KEY=value: no `=` found in `{s}`")),
    }
}
407
408/// Log headers/trailers.
409fn log_metadata(map: &MetadataMap, what: &'static str) {
410    for k_v in map.iter() {
411        match k_v {
412            tonic::metadata::KeyAndValueRef::Ascii(k, v) => {
413                info!(
414                    "{}: {}={}",
415                    what,
416                    k.as_str(),
417                    v.to_str().unwrap_or("<invalid>"),
418                );
419            }
420            tonic::metadata::KeyAndValueRef::Binary(k, v) => {
421                info!(
422                    "{}: {}={}",
423                    what,
424                    k.as_str(),
425                    String::from_utf8_lossy(v.as_ref()),
426                );
427            }
428        }
429    }
430}