Arrow Flight¶
Recipes related to leveraging the Arrow Flight protocol
Simple Parquet storage service with Arrow Flight¶
Suppose you want to implement a service that can store, send and receive
Parquet files using the Arrow Flight protocol,
pyarrow
provides an implementation framework in pyarrow.flight
and particularly through the pyarrow.flight.FlightServerBase
class.
import pathlib
import pyarrow as pa
import pyarrow.flight
import pyarrow.parquet
class FlightServer(pa.flight.FlightServerBase):
    """Flight service that stores every dataset as one Parquet file.

    Each dataset lives directly inside the repository directory and is
    addressed by its file name. That name is used as the Flight descriptor
    path and as the ticket handed back to clients.

    Parameters
    ----------
    location : str
        URI the server listens on; also advertised in every endpoint.
    repo : pathlib.Path
        Directory that holds the Parquet files.
    **kwargs
        Forwarded unchanged to ``FlightServerBase``.
    """

    def __init__(self, location="grpc://0.0.0.0:8815",
                 repo=pathlib.Path("./datasets"), **kwargs):
        super().__init__(location, **kwargs)
        self._location = location
        self._repo = repo

    def _dataset_path(self, dataset):
        """Return the path of ``dataset`` inside the repository.

        Dataset names come from remote clients. Anything that is not a plain
        file name (empty, ``.``, ``..`` or containing a path separator) is
        rejected so that reads, writes and deletes cannot escape the
        repository directory.

        Raises
        ------
        ValueError
            If ``dataset`` is not a plain file name.
        """
        name = pathlib.PurePath(dataset).name
        if name != dataset or name in ("", ".", ".."):
            raise ValueError(f"Invalid dataset name: {dataset!r}")
        return self._repo / name

    def _make_flight_info(self, dataset):
        """Build the FlightInfo describing the stored dataset ``dataset``."""
        dataset_path = self._dataset_path(dataset)
        schema = pa.parquet.read_schema(dataset_path)
        metadata = pa.parquet.read_metadata(dataset_path)
        descriptor = pa.flight.FlightDescriptor.for_path(
            dataset.encode('utf-8')
        )
        # The dataset name is also the ticket that do_get() decodes.
        endpoints = [pa.flight.FlightEndpoint(dataset, [self._location])]
        return pa.flight.FlightInfo(schema,
                                    descriptor,
                                    endpoints,
                                    metadata.num_rows,
                                    metadata.serialized_size)

    def list_flights(self, context, criteria):
        """Yield a FlightInfo for every file stored in the repository."""
        for dataset in self._repo.iterdir():
            # do_put() only ever writes plain files, so skip directories.
            if dataset.is_file():
                yield self._make_flight_info(dataset.name)

    def get_flight_info(self, context, descriptor):
        """Return the FlightInfo of the dataset named by the descriptor."""
        return self._make_flight_info(descriptor.path[0].decode('utf-8'))

    def do_put(self, context, descriptor, reader, writer):
        """Store the uploaded stream as a Parquet file named by the descriptor."""
        dataset = descriptor.path[0].decode('utf-8')
        dataset_path = self._dataset_path(dataset)
        data_table = reader.read_all()
        pa.parquet.write_table(data_table, dataset_path)

    def do_get(self, context, ticket):
        """Send the dataset named by the ticket back to the client."""
        dataset = ticket.ticket.decode('utf-8')
        dataset_path = self._dataset_path(dataset)
        return pa.flight.RecordBatchStream(pa.parquet.read_table(dataset_path))

    def list_actions(self, context):
        """Advertise the custom actions understood by do_action()."""
        return [
            ("drop_dataset", "Delete a dataset."),
        ]

    def do_action(self, context, action):
        """Dispatch a custom action; unknown types raise NotImplementedError."""
        if action.type == "drop_dataset":
            self.do_drop_dataset(action.body.to_pybytes().decode('utf-8'))
        else:
            raise NotImplementedError

    def do_drop_dataset(self, dataset):
        """Delete the Parquet file backing ``dataset``."""
        dataset_path = self._dataset_path(dataset)
        dataset_path.unlink()
The example server exposes pyarrow.flight.FlightServerBase.list_flights()
which is the method in charge of returning the list of data streams available
for fetching.
Likewise, pyarrow.flight.FlightServerBase.get_flight_info()
provides
the information regarding a single specific data stream.
Then we expose pyarrow.flight.FlightServerBase.do_get()
which is in charge
of actually fetching the exposed data streams and sending them to the client.
Being able to list and download data streams would be pretty useless if we didn’t
expose a way to create them. This is the responsibility of
pyarrow.flight.FlightServerBase.do_put()
which is in charge of receiving
new data from the client and dealing with it (in this case saving it
into a Parquet file).
These are the most common Arrow Flight requests. If we need to add more functionality, we can do so using custom actions.
In the previous example a drop_dataset
custom action is added.
All custom actions are executed through the
pyarrow.flight.FlightServerBase.do_action()
method, thus it’s up to
the server subclass to dispatch them properly. In this case we invoke
the do_drop_dataset method when the action.type is the one we expect.
Our server can then be started with
pyarrow.flight.FlightServerBase.serve()
if __name__ == '__main__':
    # Start the storage service when this file is run as a script.
    server = FlightServer()
    # Create the dataset directory if it is missing; an existing one is kept.
    server._repo.mkdir(exist_ok=True)
    server.serve()
Once the server is started we can build a client to perform requests to it
import pyarrow as pa
import pyarrow.flight
# Connect to the Flight server started above (same URI as its default location).
client = pa.flight.connect("grpc://0.0.0.0:8815")
We can create a new table and upload it so that it gets stored in a new parquet file:
# Upload a new dataset
data_table = pa.table(
    [["Mario", "Luigi", "Peach"]],
    names=["Character"]
)
# The descriptor path is the file name the server stores the data under.
upload_descriptor = pa.flight.FlightDescriptor.for_path("uploaded.parquet")
writer, _ = client.do_put(upload_descriptor, data_table.schema)
# Use the writer as a context manager so the stream is closed even if
# writing the table fails.
with writer:
    writer.write_table(data_table)
Once uploaded we should be able to retrieve the metadata for our newly uploaded table:
# Retrieve metadata of newly uploaded dataset
flight = client.get_flight_info(upload_descriptor)
descriptor = flight.descriptor
# The first path component is the dataset name, sent as UTF-8 bytes.
print("Path:", descriptor.path[0].decode('utf-8'), "Rows:", flight.total_records, "Size:", flight.total_bytes)
print("=== Schema ===")
print(flight.schema)
print("==============")
Path: uploaded.parquet Rows: 3 Size: ...
=== Schema ===
Character: string
==============
And we can fetch the content of the dataset:
# Read content of the dataset
# The ticket of the first endpoint identifies the stream to fetch.
reader = client.do_get(flight.endpoints[0].ticket)
# read_all() collects the whole stream into a single table.
read_table = reader.read_all()
print(read_table.to_pandas().head())
Character
0 Mario
1 Luigi
2 Peach
Once we have finished, we can invoke our custom action to delete the dataset we just uploaded:
# Drop the newly uploaded dataset
# The action body carries the dataset name encoded as UTF-8 bytes.
client.do_action(pa.flight.Action("drop_dataset", "uploaded.parquet".encode('utf-8')))
To confirm our dataset was deleted, we might list all parquet files that are currently stored by the server:
# List existing datasets.
# NOTE: the loop variable rebinds the module-level name `flight`.
for flight in client.list_flights():
    descriptor = flight.descriptor
    print("Path:", descriptor.path[0].decode('utf-8'), "Rows:", flight.total_records, "Size:", flight.total_bytes)
    print("=== Schema ===")
    print(flight.schema)
    print("==============")
    print("")
Streaming Parquet Storage Service¶
We can improve the Parquet storage service and avoid holding entire datasets in memory by streaming data. Flight readers and writers, like others in PyArrow, can be iterated through, so let’s update the server from before to take advantage of this:
import pathlib
import pyarrow as pa
import pyarrow.flight
import pyarrow.parquet
class FlightServer(pa.flight.FlightServerBase):
    """Flight service that streams datasets to and from Parquet files.

    Each dataset lives directly inside the repository directory and is
    addressed by its file name. Uploads are written batch by batch and
    downloads are served batch by batch, so a whole dataset never has to be
    held in memory.

    Parameters
    ----------
    location : str
        URI the server listens on; also advertised in every endpoint.
    repo : pathlib.Path
        Directory that holds the Parquet files.
    **kwargs
        Forwarded unchanged to ``FlightServerBase``.
    """

    def __init__(self, location="grpc://0.0.0.0:8815",
                 repo=pathlib.Path("./datasets"), **kwargs):
        super().__init__(location, **kwargs)
        self._location = location
        self._repo = repo

    def _dataset_path(self, dataset):
        """Return the path of ``dataset`` inside the repository.

        Dataset names come from remote clients. Anything that is not a plain
        file name (empty, ``.``, ``..`` or containing a path separator) is
        rejected so that reads, writes and deletes cannot escape the
        repository directory.

        Raises
        ------
        ValueError
            If ``dataset`` is not a plain file name.
        """
        name = pathlib.PurePath(dataset).name
        if name != dataset or name in ("", ".", ".."):
            raise ValueError(f"Invalid dataset name: {dataset!r}")
        return self._repo / name

    def _make_flight_info(self, dataset):
        """Build the FlightInfo describing the stored dataset ``dataset``."""
        dataset_path = self._dataset_path(dataset)
        schema = pa.parquet.read_schema(dataset_path)
        metadata = pa.parquet.read_metadata(dataset_path)
        descriptor = pa.flight.FlightDescriptor.for_path(
            dataset.encode('utf-8')
        )
        # The dataset name is also the ticket that do_get() decodes.
        endpoints = [pa.flight.FlightEndpoint(dataset, [self._location])]
        return pa.flight.FlightInfo(schema,
                                    descriptor,
                                    endpoints,
                                    metadata.num_rows,
                                    metadata.serialized_size)

    def list_flights(self, context, criteria):
        """Yield a FlightInfo for every file stored in the repository."""
        for dataset in self._repo.iterdir():
            # do_put() only ever writes plain files, so skip directories.
            if dataset.is_file():
                yield self._make_flight_info(dataset.name)

    def get_flight_info(self, context, descriptor):
        """Return the FlightInfo of the dataset named by the descriptor."""
        return self._make_flight_info(descriptor.path[0].decode('utf-8'))

    def do_put(self, context, descriptor, reader, writer):
        """Write the uploaded stream to a Parquet file, one batch at a time."""
        dataset = descriptor.path[0].decode('utf-8')
        dataset_path = self._dataset_path(dataset)
        # Read the uploaded data and write to Parquet incrementally.
        # The Parquet writer gets its own name so that it does not shadow the
        # ``writer`` argument of this method.
        with dataset_path.open("wb") as sink:
            with pa.parquet.ParquetWriter(sink, reader.schema) as parquet_writer:
                for chunk in reader:
                    parquet_writer.write_table(
                        pa.Table.from_batches([chunk.data]))

    def do_get(self, context, ticket):
        """Stream the dataset named by the ticket back to the client."""
        dataset = ticket.ticket.decode('utf-8')
        # Stream data from a file
        dataset_path = self._dataset_path(dataset)
        reader = pa.parquet.ParquetFile(dataset_path)
        return pa.flight.GeneratorStream(
            reader.schema_arrow, reader.iter_batches())

    def list_actions(self, context):
        """Advertise the custom actions understood by do_action()."""
        return [
            ("drop_dataset", "Delete a dataset."),
        ]

    def do_action(self, context, action):
        """Dispatch a custom action; unknown types raise NotImplementedError."""
        if action.type == "drop_dataset":
            self.do_drop_dataset(action.body.to_pybytes().decode('utf-8'))
        else:
            raise NotImplementedError

    def do_drop_dataset(self, dataset):
        """Delete the Parquet file backing ``dataset``."""
        dataset_path = self._dataset_path(dataset)
        dataset_path.unlink()
First, we’ve modified pyarrow.flight.FlightServerBase.do_put()
. Instead
of reading all the uploaded data into a pyarrow.Table
before writing,
we instead iterate through each batch as it comes and add it to a Parquet file.
Then, we’ve modified pyarrow.flight.FlightServerBase.do_get()
to stream
data to the client. This uses pyarrow.flight.GeneratorStream
, which
takes a schema and any iterable or iterator. Flight then iterates through and
sends each record batch to the client, allowing us to handle even large Parquet
files that don’t fit into memory.
While GeneratorStream has the advantage that it can stream data, that means Flight must call back into Python for each record batch to send. In contrast, RecordBatchStream requires that all data is in-memory up front, but once created, all data transfer is handled purely in C++, without needing to call Python code.
Let’s give the server a spin. As before, we’ll start the server:
if __name__ == '__main__':
    # Start the streaming storage service when this file is run as a script.
    server = FlightServer()
    # Create the dataset directory if it is missing; an existing one is kept.
    server._repo.mkdir(exist_ok=True)
    server.serve()
We create a client, and this time, we’ll write batches to the writer, as if we had a stream of data instead of a table in memory:
import pyarrow as pa
import pyarrow.flight
client = pa.flight.connect("grpc://0.0.0.0:8815")
# Upload a new dataset
NUM_BATCHES = 1024
ROWS_PER_BATCH = 4096
upload_descriptor = pa.flight.FlightDescriptor.for_path("streamed.parquet")
# One batch of ROWS_PER_BATCH integers; the same batch is sent NUM_BATCHES times.
batch = pa.record_batch([
    pa.array(range(ROWS_PER_BATCH)),
], names=["ints"])
writer, _ = client.do_put(upload_descriptor, batch.schema)
# The writer is used as a context manager, so it is released when the block exits.
with writer:
    for _ in range(NUM_BATCHES):
        writer.write_batch(batch)
As before, we can then read it back. Again, we’ll read each batch from the stream as it arrives, instead of reading them all into a table:
# Read content of the dataset
flight = client.get_flight_info(upload_descriptor)
reader = client.do_get(flight.endpoints[0].ticket)
# Count rows chunk by chunk instead of collecting the stream into a table.
total_rows = 0
for chunk in reader:
    total_rows += chunk.data.num_rows
print("Got", total_rows, "rows total, expected", NUM_BATCHES * ROWS_PER_BATCH)
Got 4194304 rows total, expected 4194304
Authentication with user/password¶
Often, services need a way to authenticate the user and identify who they are. Flight provides several ways to implement authentication; the simplest uses a user-password scheme. At startup, the client authenticates itself with the server using a username and password. The server returns an authorization token to include on future requests.
Warning
Authentication should only be used over a secure encrypted channel, i.e. TLS should be enabled.
Note
While the scheme is described as “(HTTP) basic authentication”, it does not actually implement HTTP authentication (RFC 7235) per se.
While Flight provides some interfaces to implement such a scheme, the server must provide the actual implementation, as demonstrated below. The implementation here is not secure and is provided as a minimal example only.
import base64
import secrets
import pyarrow as pa
import pyarrow.flight
class EchoServer(pa.flight.FlightServerBase):
    """Minimal Flight server whose DoAction sends the request straight back."""

    def do_action(self, context, action):
        """Return the action type (as UTF-8 bytes) followed by the action body."""
        encoded_type = action.type.encode("utf-8")
        echoed = [encoded_type, action.body]
        return echoed
class BasicAuthServerMiddlewareFactory(pa.flight.ServerMiddlewareFactory):
    """
    Middleware that implements username-password authentication.

    A call carrying ``Basic`` credentials is treated as a login and is
    answered with a freshly generated bearer token; a call carrying a
    ``Bearer`` token is accepted only if that token was issued earlier.

    Parameters
    ----------
    creds: Dict[str, str]
        A dictionary of username-password values to accept.
    """

    def __init__(self, creds):
        self.creds = creds
        # Map generated bearer tokens to users.
        # NOTE: tokens are never expired or removed, so this grows with every
        # login; acceptable for an example, not for production.
        self.tokens = {}

    def start_call(self, info, headers):
        """Validate credentials at the start of every call.

        Raises
        ------
        pyarrow.flight.FlightUnauthenticatedError
            If no usable ``authorization`` header is present, the Basic
            credentials are malformed or wrong, or the bearer token is
            unknown.
        """
        # Search for the authentication header (case-insensitive)
        auth_header = None
        for header in headers:
            if header.lower() == "authorization":
                auth_header = headers[header][0]
                break

        if not auth_header:
            raise pa.flight.FlightUnauthenticatedError("No credentials supplied")

        # The header has the structure "AuthType TokenValue", e.g.
        # "Basic <encoded username+password>" or "Bearer <random token>".
        auth_type, _, value = auth_header.partition(" ")

        if auth_type == "Basic":
            # Initial "login". The user provided a username/password
            # combination encoded in the same way as HTTP Basic Auth.
            try:
                decoded = base64.b64decode(value).decode("utf-8")
            except ValueError as exc:
                # binascii.Error and UnicodeDecodeError both derive from
                # ValueError; report bad input as an authentication failure
                # instead of letting it escape as an unrelated error.
                raise pa.flight.FlightUnauthenticatedError(
                    "Malformed Basic credentials") from exc
            username, _, password = decoded.partition(':')
            expected = self.creds.get(username)
            # compare_digest avoids leaking, through timing, how much of the
            # password matched. Bytes are compared because the str form only
            # accepts ASCII.
            if (not password or expected is None
                    or not secrets.compare_digest(password.encode("utf-8"),
                                                  expected.encode("utf-8"))):
                raise pa.flight.FlightUnauthenticatedError("Unknown user or invalid password")
            # Generate a secret, random bearer token for future calls.
            token = secrets.token_urlsafe(32)
            self.tokens[token] = username
            return BasicAuthServerMiddleware(token)
        elif auth_type == "Bearer":
            # An actual call. Validate the bearer token.
            username = self.tokens.get(value)
            if username is None:
                raise pa.flight.FlightUnauthenticatedError("Invalid token")
            return BasicAuthServerMiddleware(value)

        # Any other authentication scheme is not supported.
        raise pa.flight.FlightUnauthenticatedError("No credentials supplied")
class BasicAuthServerMiddleware(pa.flight.ServerMiddleware):
    """Per-call middleware that hands the bearer token back to the client."""

    def __init__(self, token):
        # Token to advertise in the response headers of this call.
        self.token = token

    def sending_headers(self):
        """Return the response headers carrying the bearer token."""
        header_value = f"Bearer {self.token}"
        response_headers = {"authorization": header_value}
        return response_headers
class NoOpAuthHandler(pa.flight.ServerAuthHandler):
    """
    A placeholder auth handler that performs no authentication itself.

    This is required only so that the server will respond to the internal
    Handshake RPC call, which the client calls when authenticate_basic_token
    is called. Otherwise, it should be a no-op as the actual authentication is
    implemented in middleware.
    """

    def authenticate(self, outgoing, incoming):
        # Nothing to do: credentials are checked by the middleware.
        pass

    def is_valid(self, token):
        # Always returns an empty string; no token is inspected here.
        return ""
We can then start the server:
if __name__ == '__main__':
    server = EchoServer(
        # Needed so the server answers the Handshake RPC; the credential
        # check itself is done by the middleware below.
        auth_handler=NoOpAuthHandler(),
        location="grpc://0.0.0.0:8816",
        middleware={
            "basic": BasicAuthServerMiddlewareFactory({
                # username -> password accepted by the middleware
                "test": "password",
            })
        },
    )
    server.serve()
Then, we can make a client and log in:
import pyarrow as pa
import pyarrow.flight
client = pa.flight.connect("grpc://0.0.0.0:8816")
# Log in; the returned pair is the header to attach to later calls.
token_pair = client.authenticate_basic_token(b'test', b'password')
print(token_pair)
(b'authorization', b'Bearer ...')
For future calls, we include the authentication token with the call:
action = pa.flight.Action("echo", b"Hello, world!")
# Send the authorization header obtained at login along with this call.
options = pa.flight.FlightCallOptions(headers=[token_pair])
for response in client.do_action(action=action, options=options):
    print(response.body.to_pybytes())
b'echo'
b'Hello, world!'
If we fail to do so, we get an authentication error:
# No call options are passed, so no authorization header is sent.
try:
    list(client.do_action(action=action))
except pa.flight.FlightUnauthenticatedError as e:
    print("Unauthenticated:", e)
else:
    # Reaching this branch means the server accepted the call.
    raise RuntimeError("Expected call to fail")
Unauthenticated: No credentials supplied. Detail: Unauthenticated
Or if we use the wrong credentials on login, we also get an error:
# b'invalid' is not one of the usernames the example server was started with.
try:
    client.authenticate_basic_token(b'invalid', b'password')
except pa.flight.FlightUnauthenticatedError as e:
    print("Unauthenticated:", e)
else:
    # Reaching this branch means the login was accepted.
    raise RuntimeError("Expected call to fail")
Unauthenticated: Unknown user or invalid password. Detail: Unauthenticated
Securing connections with TLS¶
Following on from the previous scenario where traffic to the server is managed via a username and password, HTTPS (more specifically TLS) communication allows an additional layer of security by encrypting messages between the client and server. This is achieved using certificates. During development, the easiest approach is developing with self-signed certificates. At startup, the server loads the public and private key and the client authenticates the server with the TLS root certificate.
Note
In production environments it is recommended to make use of a certificate signed by a certificate authority.
Step 1 - Generating the Self Signed Certificate
Generate a self-signed certificate by using dotnet on Windows, or openssl on Linux or MacOS. Alternatively, the self-signed certificate from the Arrow testing data repository can be used. Depending on the file generated, you may need to convert it to a .crt and .key file as required for the Arrow server. One method to achieve this is openssl, please visit this IBM article for more info.
Step 2 - Running a server with TLS enabled
The code below is a minimal working example of an Arrow server used to receive data with TLS.
import argparse
import pyarrow
import pyarrow.flight
class FlightServer(pyarrow.flight.FlightServerBase):
    """Flight server that keeps every uploaded table in memory.

    Parameters
    ----------
    host : str
        Accepted for call compatibility; not used by this class.
    location : str, optional
        URI the server listens on.
    tls_certificates : list, optional
        (certificate chain, private key) pairs that enable TLS.
    verify_client : bool
        Forwarded to ``FlightServerBase``.
    root_certificates : bytes, optional
        Forwarded to ``FlightServerBase``.
    auth_handler : optional
        Forwarded to ``FlightServerBase``.
    """

    def __init__(self, host="localhost", location=None,
                 tls_certificates=None, verify_client=False,
                 root_certificates=None, auth_handler=None):
        super().__init__(
            location=location,
            auth_handler=auth_handler,
            tls_certificates=tls_certificates,
            verify_client=verify_client,
            root_certificates=root_certificates)
        # Uploaded tables, indexed by descriptor_to_key().
        self.flights = {}

    @classmethod
    def descriptor_to_key(cls, descriptor):
        """Turn a descriptor into a hashable (type, command, path) tuple."""
        return (descriptor.descriptor_type.value, descriptor.command,
                tuple(descriptor.path or tuple()))

    def do_put(self, context, descriptor, reader, writer):
        """Read the whole upload, store it under its key and print both."""
        key = self.descriptor_to_key(descriptor)
        print(key)
        self.flights[key] = reader.read_all()
        print(self.flights[key])
def main(argv=None):
    """Parse the command line and serve Flight over TLS on localhost:5005.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments; ``None`` makes argparse read ``sys.argv``.

    The ``--tls CERTFILE KEYFILE`` option is mandatory: the server always
    listens with the ``grpc+tls`` scheme, so a missing certificate is
    reported as a usage error instead of failing later with a TypeError.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--tls", nargs=2, required=True,
                        metavar=('CERTFILE', 'KEYFILE'),
                        help="TLS certificate chain and private key files")
    args = parser.parse_args(argv)
    cert_path, key_path = args.tls

    scheme = "grpc+tls"
    host = "localhost"
    port = "5005"

    with open(cert_path, "rb") as cert_file:
        tls_cert_chain = cert_file.read()
    with open(key_path, "rb") as key_file:
        tls_private_key = key_file.read()
    tls_certificates = [(tls_cert_chain, tls_private_key)]

    location = "{}://{}:{}".format(scheme, host, port)
    server = FlightServer(host, location,
                          tls_certificates=tls_certificates)
    print("Serving on", location)
    server.serve()


if __name__ == '__main__':
    main()
Running the server, you should see Serving on grpc+tls://localhost:5005
.
Step 3 - Securely Connecting to the Server
Suppose we want to connect to the server and push some data to it. The following code securely sends information to the server using TLS encryption.
import argparse
import pyarrow
import pyarrow.flight
import pandas as pd
# Assumes incoming data object is a Pandas Dataframe
def push_to_server(name, data, client):
    """Upload a pandas DataFrame to the Flight server under the path ``name``.

    Parameters
    ----------
    name : str
        Path used to build the Flight descriptor.
    data : pandas.DataFrame
        Data to convert to an Arrow table and send.
    client : pyarrow.flight.FlightClient
        Connected client used for the DoPut call.
    """
    object_to_send = pyarrow.Table.from_pandas(data)
    descriptor = pyarrow.flight.FlightDescriptor.for_path(name)
    writer, _ = client.do_put(descriptor, object_to_send.schema)
    # Use the writer as a context manager so the stream is closed even if
    # writing the table fails.
    with writer:
        writer.write_table(object_to_send)
def main(argv=None):
    """Connect to a TLS-enabled Flight server and upload a small table.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments; ``None`` makes argparse read ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--tls-roots', default=None,
                        help='Path to trusted TLS certificate(s)')
    parser.add_argument('--host', default="localhost",
                        help='Host endpoint')
    # type=int rejects a non-numeric port at parse time.
    parser.add_argument('--port', type=int, default=5005,
                        help='Host port')
    args = parser.parse_args(argv)

    kwargs = {}
    # Only read the root certificates when a path was given; otherwise the
    # client is created without custom roots instead of failing on open(None).
    if args.tls_roots:
        with open(args.tls_roots, "rb") as root_certs:
            kwargs["tls_root_certs"] = root_certs.read()

    client = pyarrow.flight.FlightClient(f"grpc+tls://{args.host}:{args.port}", **kwargs)
    data = {'Animal': ['Dog', 'Cat', 'Mouse'], 'Size': ['Big', 'Small', 'Tiny']}
    df = pd.DataFrame(data, columns=['Animal', 'Size'])
    push_to_server("AnimalData", df, client)


if __name__ == '__main__':
    try:
        main()
    except Exception as e:
        # Report the failure on stderr and exit with a non-zero status so
        # that callers can tell the upload did not succeed.
        raise SystemExit(str(e))
Running the client script, you should see the server printing out information about the data it just received.
Propagating OpenTelemetry Traces¶
Distributed tracing with OpenTelemetry allows collecting call-level performance
measurements across a Flight service. In order to correlate spans across a Flight
client and server, trace context must be passed between the two. This can be passed
manually through headers in pyarrow.flight.FlightCallOptions
, or can
be automatically propagated using middleware.
This example shows how to accomplish trace propagation through middleware. The client middleware needs to inject the trace context into the call headers. The server middleware needs to extract the trace context from the headers and pass the context into a new span. Optionally, the client middleware can also create a new span to time the client-side call.
Step 1: define the client middleware:
import pyarrow.flight as flight
from opentelemetry import trace
from opentelemetry.propagate import inject
from opentelemetry.trace.status import StatusCode
class ClientTracingMiddlewareFactory(flight.ClientMiddlewareFactory):
    """Create a tracing middleware, with its own span, for each client call."""

    def __init__(self):
        self._tracer = trace.get_tracer(__name__)

    def start_call(self, info):
        """Start a span named after the Flight method and wrap it in middleware."""
        call_span = self._tracer.start_span(f"client.{info.method}")
        middleware = ClientTracingMiddleware(call_span)
        return middleware
class ClientTracingMiddleware(flight.ClientMiddleware):
    """Client middleware that injects its span's context into the call headers."""

    def __init__(self, span):
        self._span = span

    def sending_headers(self):
        """Return headers carrying the trace context of this call's span."""
        headers = {}
        inject(carrier=headers, context=trace.set_span_in_context(self._span))
        return headers

    def call_completed(self, exception):
        """Record the outcome of the call on the span, then end the span."""
        if not exception:
            self._span.set_status(StatusCode.OK)
        else:
            self._span.record_exception(exception)
            self._span.set_status(StatusCode.ERROR)
            print(exception)
        self._span.end()
Step 2: define the server middleware:
import pyarrow.flight as flight
from opentelemetry import trace
from opentelemetry.propagate import extract
from opentelemetry.trace.status import StatusCode
class ServerTracingMiddlewareFactory(flight.ServerMiddlewareFactory):
    """Create a tracing middleware for each call, linked to the caller's trace."""

    def __init__(self):
        self._tracer = trace.get_tracer(__name__)

    def start_call(self, info, headers):
        """Start a server span whose context is extracted from the headers."""
        parent_context = extract(headers)
        call_span = self._tracer.start_span(
            f"server.{info.method}", context=parent_context)
        return ServerTracingMiddleware(call_span)
class ServerTracingMiddleware(flight.ServerMiddleware):
    """Server middleware that closes its span when the call completes."""

    def __init__(self, span):
        self._span = span

    def call_completed(self, exception):
        """Record the outcome of the call on the span, then end the span."""
        if not exception:
            self._span.set_status(StatusCode.OK)
        else:
            self._span.record_exception(exception)
            self._span.set_status(StatusCode.ERROR)
            print(exception)
        self._span.end()
Step 3: configure the trace exporter, processor, and provider:
Both the server and client will need to be configured with the OpenTelemetry SDK to record spans and export them somewhere. For the sake of the example, we’ll collect the spans into a Python list, but this is normally where you would set them up to be exported to some service like Jaeger. See other examples of exporters at OpenTelemetry Exporters.
As part of this, you will need to define the resource where spans are running. At a minimum this is the service name, but it could include other information like a hostname, process id, service version, and operating system.
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult
class TestSpanExporter(SpanExporter):
    """Span exporter that keeps every exported span in a Python list."""

    def __init__(self):
        # Spans received so far, in export order.
        self.spans = []

    def export(self, spans):
        """Append ``spans`` to the collected list and report success."""
        self.spans += spans
        return SpanExportResult.SUCCESS
def configure_tracing():
    """Install a tracer provider that collects spans in memory.

    Returns
    -------
    TestSpanExporter
        The exporter that receives every finished span.
    """
    exporter = TestSpanExporter()
    # Most tracing backends require a service name; the in-memory exporter
    # used here does not, but setting one is good practice anyway.
    provider = TracerProvider(
        resource=Resource(attributes={SERVICE_NAME: "my-service"}))
    provider.add_span_processor(SimpleSpanProcessor(exporter))
    trace.set_tracer_provider(provider)
    return exporter
Step 4: add the middleware to the server:
We can use the middleware now in our EchoServer from earlier.
if __name__ == '__main__':
    # Keep the exporter so the collected spans can be inspected later.
    exporter = configure_tracing()
    server = EchoServer(
        location="grpc://0.0.0.0:8816",
        middleware={
            # Register the server-side tracing middleware under the key "tracing".
            "tracing": ServerTracingMiddlewareFactory()
        },
    )
    server.serve()
Step 5: add the middleware to the client:
# Register the client-side tracing middleware factory with the connection.
client = pa.flight.connect(
    "grpc://0.0.0.0:8816",
    middleware=[ClientTracingMiddlewareFactory()],
)
Step 6: use the client within active spans:
When we make a call with our client within an OpenTelemetry span, our client middleware will create a child span for the client-side Flight call and then propagate the span context to the server. Our server middleware will pick up that trace context and create another child span.
from opentelemetry import trace
# Client would normally also need to configure tracing, but for this example
# the client and server are running in the same Python process.
# exporter = configure_tracing()
tracer = trace.get_tracer(__name__)
# Make the Flight call while the "hello_world" span is the current span.
with tracer.start_as_current_span("hello_world") as span:
    action = pa.flight.Action("echo", b"Hello, world!")
    # Call list() on do_action to drain all results.
    list(client.do_action(action=action))
# Report what the in-memory exporter has collected.
print(f"There are {len(exporter.spans)} spans.")
print(f"The span names are:\n {list(span.name for span in exporter.spans)}.")
print(f"The span status codes are:\n "
      f"{list(span.status.status_code for span in exporter.spans)}.")
There are 3 spans.
The span names are:
['server.FlightMethod.DO_ACTION', 'client.FlightMethod.DO_ACTION', 'hello_world'].
The span status codes are:
[<StatusCode.OK: 1>, <StatusCode.OK: 1>, <StatusCode.UNSET: 0>].
As expected, we have three spans: one in our client code, one in the client middleware, and one in the server middleware.