arrow_flight/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Utilities to assist with reading and writing Arrow data as Flight messages
19
20use crate::{FlightData, SchemaAsIpc};
21use std::collections::HashMap;
22use std::sync::Arc;
23
24use arrow_array::{ArrayRef, RecordBatch};
25use arrow_buffer::Buffer;
26use arrow_ipc::convert::fb_to_schema;
27use arrow_ipc::writer::CompressionContext;
28use arrow_ipc::{reader, root_as_message, writer, writer::IpcWriteOptions};
29use arrow_schema::{ArrowError, Schema, SchemaRef};
30
31/// Convert a slice of wire protocol `FlightData`s into a vector of `RecordBatch`es
32pub fn flight_data_to_batches(flight_data: &[FlightData]) -> Result<Vec<RecordBatch>, ArrowError> {
33    let schema = flight_data.first().ok_or_else(|| {
34        ArrowError::CastError("Need at least one FlightData for schema".to_string())
35    })?;
36    let message = root_as_message(&schema.data_header[..])
37        .map_err(|_| ArrowError::CastError("Cannot get root as message".to_string()))?;
38
39    let ipc_schema: arrow_ipc::Schema = message
40        .header_as_schema()
41        .ok_or_else(|| ArrowError::CastError("Cannot get header as Schema".to_string()))?;
42    let schema = fb_to_schema(ipc_schema);
43    let schema = Arc::new(schema);
44
45    let mut batches = vec![];
46    let dictionaries_by_id = HashMap::new();
47    for datum in flight_data[1..].iter() {
48        let batch = flight_data_to_arrow_batch(datum, schema.clone(), &dictionaries_by_id)?;
49        batches.push(batch);
50    }
51    Ok(batches)
52}
53
54/// Convert `FlightData` (with supplied schema and dictionaries) to an arrow `RecordBatch`.
55pub fn flight_data_to_arrow_batch(
56    data: &FlightData,
57    schema: SchemaRef,
58    dictionaries_by_id: &HashMap<i64, ArrayRef>,
59) -> Result<RecordBatch, ArrowError> {
60    // check that the data_header is a record batch message
61    let message = arrow_ipc::root_as_message(&data.data_header[..])
62        .map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}")))?;
63
64    message
65        .header_as_record_batch()
66        .ok_or_else(|| {
67            ArrowError::ParseError(
68                "Unable to convert flight data header to a record batch".to_string(),
69            )
70        })
71        .map(|batch| {
72            reader::read_record_batch(
73                &Buffer::from(data.data_body.as_ref()),
74                batch,
75                schema,
76                dictionaries_by_id,
77                None,
78                &message.version(),
79            )
80        })?
81}
82
83/// Convert `RecordBatch`es to wire protocol `FlightData`s
84pub fn batches_to_flight_data(
85    schema: &Schema,
86    batches: Vec<RecordBatch>,
87) -> Result<Vec<FlightData>, ArrowError> {
88    let options = IpcWriteOptions::default();
89    let schema_flight_data: FlightData = SchemaAsIpc::new(schema, &options).into();
90    let mut dictionaries = vec![];
91    let mut flight_data = vec![];
92
93    let data_gen = writer::IpcDataGenerator::default();
94    let mut dictionary_tracker = writer::DictionaryTracker::new(false);
95    let mut compression_context = CompressionContext::default();
96
97    for batch in batches.iter() {
98        let (encoded_dictionaries, encoded_batch) = data_gen.encode(
99            batch,
100            &mut dictionary_tracker,
101            &options,
102            &mut compression_context,
103        )?;
104
105        dictionaries.extend(encoded_dictionaries.into_iter().map(Into::into));
106        flight_data.push(encoded_batch.into());
107    }
108
109    let mut stream = Vec::with_capacity(1 + dictionaries.len() + flight_data.len());
110
111    stream.push(schema_flight_data);
112    stream.extend(dictionaries);
113    stream.extend(flight_data);
114    let flight_data = stream;
115    Ok(flight_data)
116}