arrow_flight/
utils.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Utilities to assist with reading and writing Arrow data as Flight messages

20use crate::{FlightData, SchemaAsIpc};
21use std::collections::HashMap;
22use std::sync::Arc;
23
24use arrow_array::{ArrayRef, RecordBatch};
25use arrow_buffer::Buffer;
26use arrow_ipc::convert::fb_to_schema;
27use arrow_ipc::{reader, root_as_message, writer, writer::IpcWriteOptions};
28use arrow_schema::{ArrowError, Schema, SchemaRef};
29
30/// Convert a slice of wire protocol `FlightData`s into a vector of `RecordBatch`es
31pub fn flight_data_to_batches(flight_data: &[FlightData]) -> Result<Vec<RecordBatch>, ArrowError> {
32    let schema = flight_data.first().ok_or_else(|| {
33        ArrowError::CastError("Need at least one FlightData for schema".to_string())
34    })?;
35    let message = root_as_message(&schema.data_header[..])
36        .map_err(|_| ArrowError::CastError("Cannot get root as message".to_string()))?;
37
38    let ipc_schema: arrow_ipc::Schema = message
39        .header_as_schema()
40        .ok_or_else(|| ArrowError::CastError("Cannot get header as Schema".to_string()))?;
41    let schema = fb_to_schema(ipc_schema);
42    let schema = Arc::new(schema);
43
44    let mut batches = vec![];
45    let dictionaries_by_id = HashMap::new();
46    for datum in flight_data[1..].iter() {
47        let batch = flight_data_to_arrow_batch(datum, schema.clone(), &dictionaries_by_id)?;
48        batches.push(batch);
49    }
50    Ok(batches)
51}
52
53/// Convert `FlightData` (with supplied schema and dictionaries) to an arrow `RecordBatch`.
54pub fn flight_data_to_arrow_batch(
55    data: &FlightData,
56    schema: SchemaRef,
57    dictionaries_by_id: &HashMap<i64, ArrayRef>,
58) -> Result<RecordBatch, ArrowError> {
59    // check that the data_header is a record batch message
60    let message = arrow_ipc::root_as_message(&data.data_header[..])
61        .map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}")))?;
62
63    message
64        .header_as_record_batch()
65        .ok_or_else(|| {
66            ArrowError::ParseError(
67                "Unable to convert flight data header to a record batch".to_string(),
68            )
69        })
70        .map(|batch| {
71            reader::read_record_batch(
72                &Buffer::from(data.data_body.as_ref()),
73                batch,
74                schema,
75                dictionaries_by_id,
76                None,
77                &message.version(),
78            )
79        })?
80}
81
82/// Convert `RecordBatch`es to wire protocol `FlightData`s
83pub fn batches_to_flight_data(
84    schema: &Schema,
85    batches: Vec<RecordBatch>,
86) -> Result<Vec<FlightData>, ArrowError> {
87    let options = IpcWriteOptions::default();
88    let schema_flight_data: FlightData = SchemaAsIpc::new(schema, &options).into();
89    let mut dictionaries = vec![];
90    let mut flight_data = vec![];
91
92    let data_gen = writer::IpcDataGenerator::default();
93    #[allow(deprecated)]
94    let mut dictionary_tracker =
95        writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id());
96
97    for batch in batches.iter() {
98        let (encoded_dictionaries, encoded_batch) =
99            data_gen.encoded_batch(batch, &mut dictionary_tracker, &options)?;
100
101        dictionaries.extend(encoded_dictionaries.into_iter().map(Into::into));
102        flight_data.push(encoded_batch.into());
103    }
104
105    let mut stream = Vec::with_capacity(1 + dictionaries.len() + flight_data.len());
106
107    stream.push(schema_flight_data);
108    stream.extend(dictionaries);
109    stream.extend(flight_data);
110    let flight_data = stream;
111    Ok(flight_data)
112}