Arrow Flight

Recipes for working with the Arrow Flight protocol.

Simple Parquet storage service with Arrow Flight

Suppose you want to implement a service that can store, send and receive Parquet files using the Arrow Flight protocol. pyarrow provides an implementation framework in pyarrow.flight, in particular through the pyarrow.flight.FlightServerBase class.

import pathlib

import pyarrow as pa
import pyarrow.flight
import pyarrow.parquet


class FlightServer(pa.flight.FlightServerBase):

    def __init__(self, location="grpc://0.0.0.0:8815",
                repo=pathlib.Path("./datasets"), **kwargs):
        super(FlightServer, self).__init__(location, **kwargs)
        self._location = location
        self._repo = repo

    def _make_flight_info(self, dataset):
        dataset_path = self._repo / dataset
        schema = pa.parquet.read_schema(dataset_path)
        metadata = pa.parquet.read_metadata(dataset_path)
        descriptor = pa.flight.FlightDescriptor.for_path(
            dataset.encode('utf-8')
        )
        endpoints = [pa.flight.FlightEndpoint(dataset, [self._location])]
        return pyarrow.flight.FlightInfo(schema,
                                        descriptor,
                                        endpoints,
                                        metadata.num_rows,
                                        metadata.serialized_size)

    def list_flights(self, context, criteria):
        for dataset in self._repo.iterdir():
            yield self._make_flight_info(dataset.name)

    def get_flight_info(self, context, descriptor):
        return self._make_flight_info(descriptor.path[0].decode('utf-8'))

    def do_put(self, context, descriptor, reader, writer):
        dataset = descriptor.path[0].decode('utf-8')
        dataset_path = self._repo / dataset
        data_table = reader.read_all()
        pa.parquet.write_table(data_table, dataset_path)

    def do_get(self, context, ticket):
        dataset = ticket.ticket.decode('utf-8')
        dataset_path = self._repo / dataset
        return pa.flight.RecordBatchStream(pa.parquet.read_table(dataset_path))

    def list_actions(self, context):
        return [
            ("drop_dataset", "Delete a dataset."),
        ]

    def do_action(self, context, action):
        if action.type == "drop_dataset":
            self.do_drop_dataset(action.body.to_pybytes().decode('utf-8'))
        else:
            raise NotImplementedError

    def do_drop_dataset(self, dataset):
        dataset_path = self._repo / dataset
        dataset_path.unlink()

The example server implements pyarrow.flight.FlightServerBase.list_flights(), the method responsible for returning the list of data streams available for fetching.

Likewise, pyarrow.flight.FlightServerBase.get_flight_info() provides the information regarding a single specific data stream.

Then we implement pyarrow.flight.FlightServerBase.do_get(), which is responsible for fetching the exposed data streams and sending them to the client.

Listing and downloading data streams would be of little use without a way to create them. That is the responsibility of pyarrow.flight.FlightServerBase.do_put(), which receives new data from the client and handles it (in this case, by saving it to a Parquet file).

These are the most common Arrow Flight requests. If we need more functionality, we can add it using custom actions.

In the previous example a drop_dataset custom action is added. All custom actions are executed through the pyarrow.flight.FlightServerBase.do_action() method, thus it’s up to the server subclass to dispatch them properly. In this case we invoke the do_drop_dataset method when the action.type is the one we expect.
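The dispatching described above can also be expressed as a lookup table instead of an if/else chain. The following is a minimal sketch in plain Python, not part of pyarrow: the ActionDispatcher class and its method names are hypothetical, and a real server would call dispatch() from its do_action() method.

```python
class ActionDispatcher:
    """Hypothetical helper that maps action type strings to handlers."""

    def __init__(self):
        # action type -> (description, handler callable)
        self._handlers = {}

    def register(self, action_type, description, handler):
        self._handlers[action_type] = (description, handler)

    def list_actions(self):
        # Same (type, description) tuples the server returns from list_actions()
        return [(name, desc) for name, (desc, _) in self._handlers.items()]

    def dispatch(self, action_type, body):
        try:
            _, handler = self._handlers[action_type]
        except KeyError:
            raise NotImplementedError(f"Unknown action: {action_type}") from None
        return handler(body)


# Stand-in for do_drop_dataset: record the name instead of deleting a file.
dropped = []
dispatcher = ActionDispatcher()
dispatcher.register("drop_dataset", "Delete a dataset.",
                    lambda body: dropped.append(body.decode("utf-8")))
dispatcher.dispatch("drop_dataset", b"uploaded.parquet")
print(dropped)
```

With this layout, adding a new custom action only requires one more register() call, and list_actions() stays in sync with what do_action() can handle.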

Our server can then be started with pyarrow.flight.FlightServerBase.serve():

if __name__ == '__main__':
    server = FlightServer()
    server._repo.mkdir(exist_ok=True)
    server.serve()

Once the server is started, we can build a client to perform requests to it:

import pyarrow as pa
import pyarrow.flight

client = pa.flight.connect("grpc://0.0.0.0:8815")

We can create a new table and upload it so that it gets stored in a new Parquet file:

# Upload a new dataset
data_table = pa.table(
    [["Mario", "Luigi", "Peach"]],
    names=["Character"]
)
upload_descriptor = pa.flight.FlightDescriptor.for_path("uploaded.parquet")
writer, _ = client.do_put(upload_descriptor, data_table.schema)
writer.write_table(data_table)
writer.close()

Once uploaded we should be able to retrieve the metadata for our newly uploaded table:

# Retrieve metadata of newly uploaded dataset
flight = client.get_flight_info(upload_descriptor)
descriptor = flight.descriptor
print("Path:", descriptor.path[0].decode('utf-8'), "Rows:", flight.total_records, "Size:", flight.total_bytes)
print("=== Schema ===")
print(flight.schema)
print("==============")
Path: uploaded.parquet Rows: 3 Size: ...
=== Schema ===
Character: string
==============

And we can fetch the content of the dataset:

# Read content of the dataset
reader = client.do_get(flight.endpoints[0].ticket)
read_table = reader.read_all()
print(read_table.to_pandas().head())
  Character
0     Mario
1     Luigi
2     Peach

Once we have finished, we can invoke our custom action to delete the dataset we just uploaded:

# Drop the newly uploaded dataset
client.do_action(pa.flight.Action("drop_dataset", "uploaded.parquet".encode('utf-8')))

To confirm that our dataset was deleted, we can list all Parquet files currently stored by the server:

# List existing datasets.
for flight in client.list_flights():
    descriptor = flight.descriptor
    print("Path:", descriptor.path[0].decode('utf-8'), "Rows:", flight.total_records, "Size:", flight.total_bytes)
    print("=== Schema ===")
    print(flight.schema)
    print("==============")
    print("")
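One caveat about the example server: it joins client-supplied dataset names directly onto the repository path, so a name such as "../outside.parquet" would point outside the repository. The following is a minimal sketch of one way to validate names before using them, assuming a flat repository with no subdirectories. The helper resolve_dataset_path is hypothetical and not part of pyarrow.

```python
import pathlib
import tempfile


def resolve_dataset_path(repo, dataset):
    """Return the path of `dataset` inside `repo`, or raise ValueError.

    Hypothetical helper: accepts only names that resolve to a direct
    child of the repository directory.
    """
    repo = pathlib.Path(repo).resolve()
    candidate = (repo / dataset).resolve()
    if candidate.parent != repo:
        raise ValueError(f"Invalid dataset name: {dataset!r}")
    return candidate


# Try a few names against a temporary repository directory.
results = {}
with tempfile.TemporaryDirectory() as tmp:
    for name in ["uploaded.parquet", "../outside.parquet",
                 "nested/file.parquet", ""]:
        try:
            resolve_dataset_path(tmp, name)
            results[name] = "accepted"
        except ValueError:
            results[name] = "rejected"
print(results)
```

In the server, such a helper could replace the `self._repo / dataset` expressions in do_put, do_get and do_drop_dataset.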

Streaming Parquet Storage Service

We can improve the Parquet storage service and avoid holding entire datasets in memory by streaming data. Flight readers and writers, like others in PyArrow, can be iterated through, so let’s update the server from before to take advantage of this:

import pathlib

import pyarrow as pa
import pyarrow.flight
import pyarrow.parquet


class FlightServer(pa.flight.FlightServerBase):

    def __init__(self, location="grpc://0.0.0.0:8815",
                repo=pathlib.Path("./datasets"), **kwargs):
        super(FlightServer, self).__init__(location, **kwargs)
        self._location = location
        self._repo = repo

    def _make_flight_info(self, dataset):
        dataset_path = self._repo / dataset
        schema = pa.parquet.read_schema(dataset_path)
        metadata = pa.parquet.read_metadata(dataset_path)
        descriptor = pa.flight.FlightDescriptor.for_path(
            dataset.encode('utf-8')
        )
        endpoints = [pa.flight.FlightEndpoint(dataset, [self._location])]
        return pyarrow.flight.FlightInfo(schema,
                                        descriptor,
                                        endpoints,
                                        metadata.num_rows,
                                        metadata.serialized_size)

    def list_flights(self, context, criteria):
        for dataset in self._repo.iterdir():
            yield self._make_flight_info(dataset.name)

    def get_flight_info(self, context, descriptor):
        return self._make_flight_info(descriptor.path[0].decode('utf-8'))

    def do_put(self, context, descriptor, reader, writer):
        dataset = descriptor.path[0].decode('utf-8')
        dataset_path = self._repo / dataset
        # Read the uploaded data and write to Parquet incrementally.
        # The Parquet writer gets its own name so that it does not shadow
        # the Flight `writer` argument.
        with dataset_path.open("wb") as sink:
            with pa.parquet.ParquetWriter(sink, reader.schema) as parquet_writer:
                for chunk in reader:
                    parquet_writer.write_table(pa.Table.from_batches([chunk.data]))

    def do_get(self, context, ticket):
        dataset = ticket.ticket.decode('utf-8')
        # Stream data from a file
        dataset_path = self._repo / dataset
        reader = pa.parquet.ParquetFile(dataset_path)
        return pa.flight.GeneratorStream(
            reader.schema_arrow, reader.iter_batches())

    def list_actions(self, context):
        return [
            ("drop_dataset", "Delete a dataset."),
        ]

    def do_action(self, context, action):
        if action.type == "drop_dataset":
            self.do_drop_dataset(action.body.to_pybytes().decode('utf-8'))
        else:
            raise NotImplementedError

    def do_drop_dataset(self, dataset):
        dataset_path = self._repo / dataset
        dataset_path.unlink()

First, we’ve modified pyarrow.flight.FlightServerBase.do_put(). Instead of reading all the uploaded data into a pyarrow.Table before writing, we iterate through each batch as it arrives and add it to a Parquet file.

Then, we’ve modified pyarrow.flight.FlightServerBase.do_get() to stream data to the client. This uses pyarrow.flight.GeneratorStream, which takes a schema and any iterable or iterator. Flight then iterates through and sends each record batch to the client, allowing us to handle even large Parquet files that don’t fit into memory.

While GeneratorStream has the advantage that it can stream data, that means Flight must call back into Python for each record batch to send. In contrast, RecordBatchStream requires that all data is in-memory up front, but once created, all data transfer is handled purely in C++, without needing to call Python code.
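The reason streaming avoids holding everything in memory is that an iterator produces each batch only when the consumer asks for it. The sketch below illustrates this with a plain Python generator and lists of integers standing in for record batches. It does not use Flight, and the make_batches function is hypothetical.

```python
produced = []  # records which batches have actually been built so far


def make_batches(num_batches, rows_per_batch):
    """Yield one list of ints per batch, building each one on demand."""
    for index in range(num_batches):
        produced.append(index)
        yield list(range(rows_per_batch))


stream = make_batches(num_batches=3, rows_per_batch=4)
built_before = list(produced)        # nothing is built until we iterate

first = next(stream)                 # the consumer asks for one batch
built_after_first = list(produced)   # only the first batch exists

# Drain the rest, keeping a running total instead of the batches themselves.
remaining_rows = sum(len(batch) for batch in stream)
print(built_before, built_after_first, remaining_rows)
```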

Let’s give the server a spin. As before, we’ll start the server:

if __name__ == '__main__':
    server = FlightServer()
    server._repo.mkdir(exist_ok=True)
    server.serve()

We create a client, and this time, we’ll write batches to the writer, as if we had a stream of data instead of a table in memory:

import pyarrow as pa
import pyarrow.flight

client = pa.flight.connect("grpc://0.0.0.0:8815")

# Upload a new dataset
NUM_BATCHES = 1024
ROWS_PER_BATCH = 4096
upload_descriptor = pa.flight.FlightDescriptor.for_path("streamed.parquet")
batch = pa.record_batch([
    pa.array(range(ROWS_PER_BATCH)),
], names=["ints"])
writer, _ = client.do_put(upload_descriptor, batch.schema)
with writer:
    for _ in range(NUM_BATCHES):
        writer.write_batch(batch)

As before, we can then read it back. Again, we’ll read each batch from the stream as it arrives, instead of reading them all into a table:

# Read content of the dataset
flight = client.get_flight_info(upload_descriptor)
reader = client.do_get(flight.endpoints[0].ticket)
total_rows = 0
for chunk in reader:
    total_rows += chunk.data.num_rows
print("Got", total_rows, "rows total, expected", NUM_BATCHES * ROWS_PER_BATCH)
Got 4194304 rows total, expected 4194304
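To get a rough sense of what streaming saves here, we can estimate the size of the data involved. This is back-of-the-envelope arithmetic only: it assumes the column is stored as 64-bit integers (8 bytes per value) and ignores validity bitmaps, metadata and any Parquet compression.

```python
NUM_BATCHES = 1024
ROWS_PER_BATCH = 4096
BYTES_PER_VALUE = 8  # assumption: int64 column

# Size of the values in a single batch versus the whole dataset
batch_bytes = ROWS_PER_BATCH * BYTES_PER_VALUE
total_bytes = NUM_BATCHES * batch_bytes

batch_kib = batch_bytes // 1024
total_mib = total_bytes // 2**20
print(f"One batch: {batch_kib} KiB, whole dataset: {total_mib} MiB")
```

Under these assumptions, a streaming reader or writer only needs to hold about one batch of values at a time rather than the whole dataset.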

Authentication with user/password

Often, services need a way to authenticate the user and identify who they are. Flight provides several ways to implement authentication; the simplest uses a user-password scheme. At startup, the client authenticates itself with the server using a username and password. The server returns an authorization token to include on future requests.

Warning

Authentication should only be used over a secure encrypted channel, i.e. TLS should be enabled.

Note

While the scheme is described as “(HTTP) basic authentication”, it does not actually implement HTTP authentication (RFC 7235) per se.

While Flight provides some interfaces to implement such a scheme, the server must provide the actual implementation, as demonstrated below. The implementation here is not secure and is provided as a minimal example only.

import base64
import secrets

import pyarrow as pa
import pyarrow.flight


class EchoServer(pa.flight.FlightServerBase):
    """A simple server that just echoes any requests from DoAction."""

    def do_action(self, context, action):
        return [action.type.encode("utf-8"), action.body]


class BasicAuthServerMiddlewareFactory(pa.flight.ServerMiddlewareFactory):
    """
    Middleware that implements username-password authentication.

    Parameters
    ----------
    creds: Dict[str, str]
        A dictionary of username-password values to accept.
    """

    def __init__(self, creds):
        self.creds = creds
        # Map generated bearer tokens to users
        self.tokens = {}

    def start_call(self, info, headers):
        """Validate credentials at the start of every call."""
        # Search for the authentication header (case-insensitive)
        auth_header = None
        for header in headers:
            if header.lower() == "authorization":
                auth_header = headers[header][0]
                break

        if not auth_header:
            raise pa.flight.FlightUnauthenticatedError("No credentials supplied")

        # The header has the structure "AuthType TokenValue", e.g.
        # "Basic <encoded username+password>" or "Bearer <random token>".
        auth_type, _, value = auth_header.partition(" ")

        if auth_type == "Basic":
            # Initial "login". The user provided a username/password
            # combination encoded in the same way as HTTP Basic Auth.
            decoded = base64.b64decode(value).decode("utf-8")
            username, _, password = decoded.partition(':')
            if not password or password != self.creds.get(username):
                raise pa.flight.FlightUnauthenticatedError("Unknown user or invalid password")
            # Generate a secret, random bearer token for future calls.
            token = secrets.token_urlsafe(32)
            self.tokens[token] = username
            return BasicAuthServerMiddleware(token)
        elif auth_type == "Bearer":
            # An actual call. Validate the bearer token.
            username = self.tokens.get(value)
            if username is None:
                raise pa.flight.FlightUnauthenticatedError("Invalid token")
            return BasicAuthServerMiddleware(value)

        raise pa.flight.FlightUnauthenticatedError("No credentials supplied")


class BasicAuthServerMiddleware(pa.flight.ServerMiddleware):
    """Middleware that implements username-password authentication."""

    def __init__(self, token):
        self.token = token

    def sending_headers(self):
        """Return the authentication token to the client."""
        return {"authorization": f"Bearer {self.token}"}


class NoOpAuthHandler(pa.flight.ServerAuthHandler):
    """
    A no-op handler that lets the server respond to the Handshake RPC.

    This is required only so that the server will respond to the internal
    Handshake RPC call, which the client calls when authenticate_basic_token
    is called. Otherwise, it should be a no-op as the actual authentication is
    implemented in middleware.
    """

    def authenticate(self, outgoing, incoming):
        pass

    def is_valid(self, token):
        return ""
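The "Basic" branch of the middleware can be exercised on its own, without a running server. The sketch below shows how such a header value is built and checked using only the standard library. The helpers encode_basic and check_basic are hypothetical. Unlike the server above, this sketch compares passwords with secrets.compare_digest, which is one common way to avoid leaking information through comparison timing.

```python
import base64
import secrets


def encode_basic(username, password):
    """Build a header value of the form 'Basic <base64(user:password)>'."""
    raw = f"{username}:{password}".encode("utf-8")
    return "Basic " + base64.b64encode(raw).decode("ascii")


def check_basic(header_value, creds):
    """Return the username if the header holds valid credentials, else None."""
    auth_type, _, value = header_value.partition(" ")
    if auth_type != "Basic":
        return None
    try:
        decoded = base64.b64decode(value, validate=True).decode("utf-8")
    except ValueError:  # covers invalid base64 and invalid UTF-8
        return None
    username, _, password = decoded.partition(":")
    expected = creds.get(username)
    if expected is None:
        return None
    if secrets.compare_digest(password.encode("utf-8"),
                              expected.encode("utf-8")):
        return username
    return None


creds = {"test": "password"}
good = check_basic(encode_basic("test", "password"), creds)
bad_password = check_basic(encode_basic("test", "wrong"), creds)
unknown_user = check_basic(encode_basic("invalid", "password"), creds)
not_basic = check_basic("Bearer some-token", creds)
malformed = check_basic("Basic !!!not-base64!!!", creds)
print(good, bad_password, unknown_user, not_basic, malformed)
```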

We can then start the server:

if __name__ == '__main__':
    server = EchoServer(
        auth_handler=NoOpAuthHandler(),
        location="grpc://0.0.0.0:8816",
        middleware={
            "basic": BasicAuthServerMiddlewareFactory({
                "test": "password",
            })
        },
    )
    server.serve()

Then, we can make a client and log in:

import pyarrow as pa
import pyarrow.flight

client = pa.flight.connect("grpc://0.0.0.0:8816")

token_pair = client.authenticate_basic_token(b'test', b'password')
print(token_pair)
(b'authorization', b'Bearer ...')

For future calls, we include the authentication token with the call:

action = pa.flight.Action("echo", b"Hello, world!")
options = pa.flight.FlightCallOptions(headers=[token_pair])
for response in client.do_action(action=action, options=options):
    print(response.body.to_pybytes())
b'echo'
b'Hello, world!'

If we fail to do so, we get an authentication error:

try:
    list(client.do_action(action=action))
except pa.flight.FlightUnauthenticatedError as e:
    print("Unauthenticated:", e)
else:
    raise RuntimeError("Expected call to fail")
Unauthenticated: No credentials supplied. Detail: Unauthenticated

Or if we use the wrong credentials on login, we also get an error:

try:
    client.authenticate_basic_token(b'invalid', b'password')
except pa.flight.FlightUnauthenticatedError as e:
    print("Unauthenticated:", e)
else:
    raise RuntimeError("Expected call to fail")
Unauthenticated: Unknown user or invalid password. Detail: Unauthenticated

Securing connections with TLS

Building on the previous scenario, where access to the server is controlled with a username and password, TLS adds a further layer of security by encrypting messages between the client and server. This is achieved using certificates. During development, the easiest approach is to use self-signed certificates. At startup, the server loads its certificate and private key, and the client authenticates the server using the TLS root certificate.

Note

In production environments it is recommended to make use of a certificate signed by a certificate authority.

Step 1 - Generating the Self-Signed Certificate

Generate a self-signed certificate using dotnet on Windows, or openssl on Linux or macOS. Alternatively, the self-signed certificate from the Arrow testing data repository can be used. Depending on the file generated, you may need to convert it to the .crt and .key files that the Arrow server requires. One way to do this is with openssl; please visit this IBM article for more info.
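Before starting the server, it can help to check that the converted files look like PEM text rather than a binary format. The sketch below is a rough heuristic only, assuming the server expects PEM-encoded data as gRPC-based TLS setups generally do. It looks for PEM header markers and does not validate the certificate or key. The helper load_tls_pair is hypothetical, and the file contents in the demo are placeholders, not real certificates.

```python
import pathlib
import tempfile

CERT_MARKER = b"-----BEGIN CERTIFICATE-----"
KEY_MARKER = b"PRIVATE KEY-----"  # matches RSA, EC and PKCS#8 key headers


def load_tls_pair(cert_path, key_path):
    """Read a (certificate, key) pair, rejecting files without PEM markers."""
    cert = pathlib.Path(cert_path).read_bytes()
    key = pathlib.Path(key_path).read_bytes()
    if CERT_MARKER not in cert:
        raise ValueError(f"{cert_path} does not look like a PEM certificate")
    if b"-----BEGIN " not in key or KEY_MARKER not in key:
        raise ValueError(f"{key_path} does not look like a PEM private key")
    return cert, key


with tempfile.TemporaryDirectory() as tmp:
    tmp = pathlib.Path(tmp)
    # Placeholder contents, for demonstrating the check only.
    (tmp / "server.crt").write_bytes(
        CERT_MARKER + b"\nplaceholder\n-----END CERTIFICATE-----\n")
    (tmp / "server.key").write_bytes(
        b"-----BEGIN PRIVATE KEY-----\nplaceholder\n-----END PRIVATE KEY-----\n")
    (tmp / "server.der").write_bytes(b"\x30\x82\x01\x0a")

    # Same shape as the tls_certificates list passed to the server
    tls_certificates = [load_tls_pair(tmp / "server.crt", tmp / "server.key")]
    try:
        load_tls_pair(tmp / "server.der", tmp / "server.key")
        der_rejected = False
    except ValueError:
        der_rejected = True
print(len(tls_certificates), der_rejected)
```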

Step 2 - Running a server with TLS enabled

The code below is a minimal working example of an Arrow server used to receive data with TLS.

import argparse
import pyarrow
import pyarrow.flight


class FlightServer(pyarrow.flight.FlightServerBase):
    def __init__(self, host="localhost", location=None,
                 tls_certificates=None, verify_client=False,
                 root_certificates=None, auth_handler=None):
        super(FlightServer, self).__init__(
            location, auth_handler, tls_certificates, verify_client,
            root_certificates)
        self.flights = {}

    @classmethod
    def descriptor_to_key(cls, descriptor):
        return (descriptor.descriptor_type.value, descriptor.command,
                tuple(descriptor.path or tuple()))

    def do_put(self, context, descriptor, reader, writer):
        key = FlightServer.descriptor_to_key(descriptor)
        print(key)
        self.flights[key] = reader.read_all()
        print(self.flights[key])


def main():
    parser = argparse.ArgumentParser()
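    # --tls is required: the server below always reads both files.
    parser.add_argument("--tls", nargs=2, required=True,
                        metavar=('CERTFILE', 'KEYFILE'),
                        help='TLS certificate and private key files')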
    args = parser.parse_args()
    tls_certificates = []

    scheme = "grpc+tls"
    host = "localhost"
    port = "5005"

    with open(args.tls[0], "rb") as cert_file:
        tls_cert_chain = cert_file.read()
    with open(args.tls[1], "rb") as key_file:
        tls_private_key = key_file.read()

    tls_certificates.append((tls_cert_chain, tls_private_key))

    location = "{}://{}:{}".format(scheme, host, port)

    server = FlightServer(host, location,
                          tls_certificates=tls_certificates)
    print("Serving on", location)
    server.serve()


if __name__ == '__main__':
    main()

Running the server, you should see Serving on grpc+tls://localhost:5005.

Step 3 - Securely Connecting to the Server

Suppose we want to connect to the server and push some data to it. The following code securely sends information to the server using TLS encryption.

import argparse
import pyarrow
import pyarrow.flight
import pandas as pd

# Assumes the incoming data object is a pandas DataFrame
def push_to_server(name, data, client):
    object_to_send = pyarrow.Table.from_pandas(data)
    writer, _ = client.do_put(pyarrow.flight.FlightDescriptor.for_path(name), object_to_send.schema)
    writer.write_table(object_to_send)
    writer.close()

def main():
    parser = argparse.ArgumentParser()

    # --tls-roots is required: the client below always reads this file.
    parser.add_argument('--tls-roots', required=True,
                        help='Path to trusted TLS certificate(s)')
    parser.add_argument('--host', default="localhost",
                        help='Host endpoint')
    parser.add_argument('--port', type=int, default=5005,
                        help='Host port')
    args = parser.parse_args()
    kwargs = {}

    with open(args.tls_roots, "rb") as root_certs:
        kwargs["tls_root_certs"] = root_certs.read()

    client = pyarrow.flight.FlightClient(f"grpc+tls://{args.host}:{args.port}", **kwargs)
    data = {'Animal': ['Dog', 'Cat', 'Mouse'], 'Size': ['Big', 'Small', 'Tiny']}
    df = pd.DataFrame(data, columns=['Animal', 'Size'])
    push_to_server("AnimalData", df, client)

if __name__ == '__main__':
    try:
        main()
    except Exception as e:
        print(e)

Running the client script, you should see the server printing out information about the data it just received.

Propagating OpenTelemetry Traces

Distributed tracing with OpenTelemetry allows collecting call-level performance measurements across a Flight service. In order to correlate spans across a Flight client and server, trace context must be passed between the two. This can be passed manually through headers in pyarrow.flight.FlightCallOptions, or can be automatically propagated using middleware.

This example shows how to accomplish trace propagation through middleware. The client middleware needs to inject the trace context into the call headers. The server middleware needs to extract the trace context from the headers and pass the context into a new span. Optionally, the client middleware can also create a new span to time the client-side call.
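It can help to see what the trace context looks like on the wire. With OpenTelemetry's default propagators, the injected headers typically include a W3C Trace Context traceparent header of the form version-traceid-spanid-flags. The sketch below builds and parses such a value using only the standard library. It is an illustration of the header format, not a replacement for inject and extract, and the helper names are hypothetical.

```python
import re

# version (2 hex) - trace id (32 hex) - span id (16 hex) - flags (2 hex)
TRACEPARENT_RE = re.compile(
    r"([0-9a-f]{2})-([0-9a-f]{32})-([0-9a-f]{16})-([0-9a-f]{2})")


def make_traceparent(trace_id, span_id, sampled=True):
    """Format integer ids as a version-00 traceparent header value."""
    flags = 1 if sampled else 0
    return f"00-{trace_id:032x}-{span_id:016x}-{flags:02x}"


def parse_traceparent(value):
    """Split a traceparent header value into its fields."""
    match = TRACEPARENT_RE.fullmatch(value)
    if match is None:
        raise ValueError(f"Malformed traceparent: {value!r}")
    _version, trace_id, span_id, flags = match.groups()
    return {
        "trace_id": int(trace_id, 16),
        "span_id": int(span_id, 16),
        "sampled": bool(int(flags, 16) & 1),
    }


# Example value taken from the W3C Trace Context specification
header = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"
fields = parse_traceparent(header)
rebuilt = make_traceparent(fields["trace_id"], fields["span_id"],
                           fields["sampled"])
print(rebuilt == header)
```

The server-side span joins the client's trace because both sides agree on the trace id carried in this header.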

Step 1: define the client middleware:

import pyarrow.flight as flight
from opentelemetry import trace
from opentelemetry.propagate import inject
from opentelemetry.trace.status import StatusCode

class ClientTracingMiddlewareFactory(flight.ClientMiddlewareFactory):
    def __init__(self):
        self._tracer = trace.get_tracer(__name__)

    def start_call(self, info):
        span = self._tracer.start_span(f"client.{info.method}")
        return ClientTracingMiddleware(span)

class ClientTracingMiddleware(flight.ClientMiddleware):
    def __init__(self, span):
        self._span = span

    def sending_headers(self):
        ctx = trace.set_span_in_context(self._span)
        carrier = {}
        inject(carrier=carrier, context=ctx)
        return carrier

    def call_completed(self, exception):
        if exception:
            self._span.record_exception(exception)
            self._span.set_status(StatusCode.ERROR)
            print(exception)
        else:
            self._span.set_status(StatusCode.OK)
        self._span.end()

Step 2: define the server middleware:

import pyarrow.flight as flight
from opentelemetry import trace
from opentelemetry.propagate import extract
from opentelemetry.trace.status import StatusCode

class ServerTracingMiddlewareFactory(flight.ServerMiddlewareFactory):
    def __init__(self):
        self._tracer = trace.get_tracer(__name__)

    def start_call(self, info, headers):
        context = extract(headers)
        span = self._tracer.start_span(f"server.{info.method}", context=context)
        return ServerTracingMiddleware(span)

class ServerTracingMiddleware(flight.ServerMiddleware):
    def __init__(self, span):
        self._span = span

    def call_completed(self, exception):
        if exception:
            self._span.record_exception(exception)
            self._span.set_status(StatusCode.ERROR)
            print(exception)
        else:
            self._span.set_status(StatusCode.OK)
        self._span.end()

Step 3: configure the trace exporter, processor, and provider:

Both the server and client will need to be configured with the OpenTelemetry SDK to record spans and export them somewhere. For the sake of the example, we’ll collect the spans into a Python list, but this is normally where you would set them up to be exported to some service like Jaeger. See other examples of exporters at OpenTelemetry Exporters.

As part of this, you will need to define the resource where spans are running. At a minimum this is the service name, but it could include other information like a hostname, process id, service version, and operating system.

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult

class TestSpanExporter(SpanExporter):
    def __init__(self):
        self.spans = []

    def export(self, spans):
        self.spans.extend(spans)
        return SpanExportResult.SUCCESS

def configure_tracing():
    # Service name is required for most backends,
    # and although it's not necessary for console export,
    # it's good to set service name anyways.
    resource = Resource(attributes={
        SERVICE_NAME: "my-service"
    })
    exporter = TestSpanExporter()
    provider = TracerProvider(resource=resource)
    processor = SimpleSpanProcessor(exporter)
    provider.add_span_processor(processor)
    trace.set_tracer_provider(provider)
    return exporter

Step 4: add the middleware to the server:

We can use the middleware now in our EchoServer from earlier.

if __name__ == '__main__':
    exporter = configure_tracing()
    server = EchoServer(
        location="grpc://0.0.0.0:8816",
        middleware={
            "tracing": ServerTracingMiddlewareFactory()
        },
    )
    server.serve()

Step 5: add the middleware to the client:

client = pa.flight.connect(
    "grpc://0.0.0.0:8816",
    middleware=[ClientTracingMiddlewareFactory()],
)

Step 6: use the client within active spans:

When we make a call with our client within an OpenTelemetry span, our client middleware will create a child span for the client-side Flight call and then propagate the span context to the server. Our server middleware will pick up that trace context and create another child span.

from opentelemetry import trace

# Client would normally also need to configure tracing, but for this example
# the client and server are running in the same Python process.
# exporter = configure_tracing()

tracer = trace.get_tracer(__name__)

with tracer.start_as_current_span("hello_world") as span:
    action = pa.flight.Action("echo", b"Hello, world!")
    # Call list() on do_action to drain all results.
    list(client.do_action(action=action))

print(f"There are {len(exporter.spans)} spans.")
print(f"The span names are:\n  {list(span.name for span in exporter.spans)}.")
print(f"The span status codes are:\n  "
      f"{list(span.status.status_code for span in exporter.spans)}.")
There are 3 spans.
The span names are:
  ['server.FlightMethod.DO_ACTION', 'client.FlightMethod.DO_ACTION', 'hello_world'].
The span status codes are:
  [<StatusCode.OK: 1>, <StatusCode.OK: 1>, <StatusCode.UNSET: 0>].

As expected, we have three spans: one in our client code, one in the client middleware, and one in the server middleware.