arrow_flight/encode.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use std::{collections::VecDeque, fmt::Debug, pin::Pin, sync::Arc, task::Poll};

use crate::{error::Result, FlightData, FlightDescriptor, SchemaAsIpc};

use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions, UnionArray};
use arrow_ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};

use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef, UnionMode};
use bytes::Bytes;
use futures::{ready, stream::BoxStream, Stream, StreamExt};

/// Creates a [`Stream`] of [`FlightData`]s from a
/// `Stream` of [`Result`]<[`RecordBatch`], [`FlightError`]>.
///
/// This can be used to implement [`FlightService::do_get`] in an
/// Arrow Flight implementation.
///
/// This structure encodes a stream of `Result`s rather than `RecordBatch`es to
/// propagate errors from streaming execution, where the generation of the
/// `RecordBatch`es is incremental, and an error may occur even after
/// several have already been successfully produced.
///
/// # Caveats
/// 1. When [`DictionaryHandling`] is [`DictionaryHandling::Hydrate`],
///    [`DictionaryArray`]s are converted to their underlying types prior to
///    transport.
///    When [`DictionaryHandling`] is [`DictionaryHandling::Resend`], Dictionary [`FlightData`] is sent with every
///    [`RecordBatch`] that contains a [`DictionaryArray`](arrow_array::array::DictionaryArray).
///    See <https://github.com/apache/arrow-rs/issues/3389>.
///
/// [`DictionaryArray`]: arrow_array::array::DictionaryArray
///
/// # Example
/// ```no_run
/// # use std::sync::Arc;
/// # use arrow_array::{ArrayRef, RecordBatch, UInt32Array};
/// # async fn f() {
/// # let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
/// # let batch = RecordBatch::try_from_iter(vec![
/// #      ("a", Arc::new(c1) as ArrayRef)
/// #   ])
/// #   .expect("cannot create record batch");
/// use arrow_flight::encode::FlightDataEncoderBuilder;
///
/// // Get an input stream of Result<RecordBatch, FlightError>
/// let input_stream = futures::stream::iter(vec![Ok(batch)]);
///
/// // Build a stream of `Result<FlightData>` (e.g. to return for do_get)
/// let flight_data_stream = FlightDataEncoderBuilder::new()
///  .build(input_stream);
///
/// // Create a tonic `Response` that can be returned from a Flight server
/// let response = tonic::Response::new(flight_data_stream);
/// # }
/// ```
///
/// # Example: Sending `Vec<RecordBatch>`
///
/// You can create a [`Stream`] to pass to [`Self::build`] from an existing
/// `Vec` of `RecordBatch`es like this:
///
/// ```
/// # use std::sync::Arc;
/// # use arrow_array::{ArrayRef, RecordBatch, UInt32Array};
/// # async fn f() {
/// # fn make_batches() -> Vec<RecordBatch> {
/// #   let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
/// #   let batch = RecordBatch::try_from_iter(vec![
/// #      ("a", Arc::new(c1) as ArrayRef)
/// #   ])
/// #   .expect("cannot create record batch");
/// #   vec![batch.clone(), batch.clone()]
/// # }
/// use arrow_flight::encode::FlightDataEncoderBuilder;
///
/// // Get batches that you want to send via Flight
/// let batches: Vec<RecordBatch> = make_batches();
///
/// // Create an input stream of Result<RecordBatch, FlightError>
/// let input_stream = futures::stream::iter(
///   batches.into_iter().map(Ok)
/// );
///
/// // Build a stream of `Result<FlightData>` (e.g. to return for do_get)
/// let flight_data_stream = FlightDataEncoderBuilder::new()
///  .build(input_stream);
/// # }
/// ```
///
/// # Example: Determining schema of encoded data
///
/// Encoding flight data may hydrate dictionaries, see [`DictionaryHandling`] for more information,
/// which changes the schema of the encoded data compared to the input record batches.
/// The fully hydrated schema can be accessed using the [`FlightDataEncoder::known_schema`] method,
/// after explicitly informing the builder of the schema using [`FlightDataEncoderBuilder::with_schema`].
///
/// ```
/// # use std::sync::Arc;
/// # use arrow_array::{ArrayRef, RecordBatch, UInt32Array};
/// # async fn f() {
/// # let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
/// # let batch = RecordBatch::try_from_iter(vec![
/// #      ("a", Arc::new(c1) as ArrayRef)
/// #   ])
/// #   .expect("cannot create record batch");
/// use arrow_flight::encode::FlightDataEncoderBuilder;
///
/// // Get the schema of the input stream
/// let schema = batch.schema();
///
/// // Get an input stream of Result<RecordBatch, FlightError>
/// let input_stream = futures::stream::iter(vec![Ok(batch)]);
///
/// // Build a stream of `Result<FlightData>` (e.g. to return for do_get)
/// let flight_data_stream = FlightDataEncoderBuilder::new()
///  // Inform the builder of the input stream schema
///  .with_schema(schema)
///  .build(input_stream);
///
/// // Retrieve the schema of the encoded data
/// let encoded_schema = flight_data_stream.known_schema();
/// # }
/// ```
///
/// [`FlightService::do_get`]: crate::flight_service_server::FlightService::do_get
/// [`FlightError`]: crate::error::FlightError
#[derive(Debug)]
pub struct FlightDataEncoderBuilder {
    /// The maximum approximate target message size in bytes
    /// (see details on [`Self::with_max_flight_data_size`]).
    max_flight_data_size: usize,
    /// Ipc writer options
    options: IpcWriteOptions,
    /// Metadata to add to the schema message
    app_metadata: Bytes,
    /// Optional schema, if known before data.
    schema: Option<SchemaRef>,
    /// Optional flight descriptor, if known before data.
    descriptor: Option<FlightDescriptor>,
    /// Determines how `DictionaryArray`s are encoded for transport.
    /// See [`DictionaryHandling`] for more information.
    dictionary_handling: DictionaryHandling,
}

/// Default target size for encoded [`FlightData`].
///
/// Note this value would normally be 4MB, but the size calculation is
/// somewhat inexact, so we set it to 2MB.
pub const GRPC_TARGET_MAX_FLIGHT_SIZE_BYTES: usize = 2097152;

impl Default for FlightDataEncoderBuilder {
    fn default() -> Self {
        Self {
            max_flight_data_size: GRPC_TARGET_MAX_FLIGHT_SIZE_BYTES,
            options: IpcWriteOptions::default(),
            app_metadata: Bytes::new(),
            schema: None,
            descriptor: None,
            dictionary_handling: DictionaryHandling::Hydrate,
        }
    }
}

impl FlightDataEncoderBuilder {
    /// Create a new [`FlightDataEncoderBuilder`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the (approximate) maximum size, in bytes, of the
    /// [`FlightData`] produced by this encoder. Defaults to 2MB.
    ///
    /// Since there is often a maximum message size for gRPC messages
    /// (typically around 4MB), this encoder splits up [`RecordBatch`]s
    /// (preserving order) into multiple [`FlightData`] objects to
    /// limit the size of individual messages sent via gRPC.
    ///
    /// The size is approximate because of the additional encoding
    /// overhead on top of the underlying data buffers themselves.
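    ///
    /// For example, to target messages of roughly 1MB (a minimal sketch;
    /// actual message sizes also include encoding overhead):
    ///
    /// ```
    /// use arrow_flight::encode::FlightDataEncoderBuilder;
    ///
    /// let builder = FlightDataEncoderBuilder::new()
    ///     .with_max_flight_data_size(1024 * 1024);
    /// ```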
    pub fn with_max_flight_data_size(mut self, max_flight_data_size: usize) -> Self {
        self.max_flight_data_size = max_flight_data_size;
        self
    }

    /// Set [`DictionaryHandling`] for encoder
    pub fn with_dictionary_handling(mut self, dictionary_handling: DictionaryHandling) -> Self {
        self.dictionary_handling = dictionary_handling;
        self
    }

    /// Specify application specific metadata included in the
    /// [`FlightData::app_metadata`] field of the first Schema
    /// message
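    ///
    /// For example (a sketch, using placeholder metadata bytes):
    ///
    /// ```
    /// use bytes::Bytes;
    /// use arrow_flight::encode::FlightDataEncoderBuilder;
    ///
    /// let builder = FlightDataEncoderBuilder::new()
    ///     .with_metadata(Bytes::from("my-app-metadata"));
    /// ```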
    pub fn with_metadata(mut self, app_metadata: Bytes) -> Self {
        self.app_metadata = app_metadata;
        self
    }

    /// Set the [`IpcWriteOptions`] used to encode the [`RecordBatch`]es for transport.
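    ///
    /// For example, passing the default options explicitly (a sketch; any
    /// [`IpcWriteOptions`] can be supplied):
    ///
    /// ```
    /// use arrow_ipc::writer::IpcWriteOptions;
    /// use arrow_flight::encode::FlightDataEncoderBuilder;
    ///
    /// let builder = FlightDataEncoderBuilder::new()
    ///     .with_options(IpcWriteOptions::default());
    /// ```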
    pub fn with_options(mut self, options: IpcWriteOptions) -> Self {
        self.options = options;
        self
    }

    /// Specify a schema for the RecordBatches being sent. If a schema
    /// is not specified, an encoded Schema message will be sent when
    /// the first [`RecordBatch`], if any, is encoded. Some clients
    /// expect a Schema message even if there is no data sent.
    pub fn with_schema(mut self, schema: SchemaRef) -> Self {
        self.schema = Some(schema);
        self
    }

    /// Specify a flight descriptor in the first FlightData message.
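    ///
    /// For example (a sketch, using a hypothetical dataset path):
    ///
    /// ```
    /// use arrow_flight::FlightDescriptor;
    /// use arrow_flight::encode::FlightDataEncoderBuilder;
    ///
    /// let descriptor = FlightDescriptor::new_path(vec!["my_dataset".to_string()]);
    /// let builder = FlightDataEncoderBuilder::new()
    ///     .with_flight_descriptor(Some(descriptor));
    /// ```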
    pub fn with_flight_descriptor(mut self, descriptor: Option<FlightDescriptor>) -> Self {
        self.descriptor = descriptor;
        self
    }

    /// Takes a [`Stream`] of [`Result<RecordBatch>`] and returns a [`Stream`]
    /// of [`FlightData`], consuming self.
    ///
    /// See example on [`Self`] and [`FlightDataEncoder`] for more details
    pub fn build<S>(self, input: S) -> FlightDataEncoder
    where
        S: Stream<Item = Result<RecordBatch>> + Send + 'static,
    {
        let Self {
            max_flight_data_size,
            options,
            app_metadata,
            schema,
            descriptor,
            dictionary_handling,
        } = self;

        FlightDataEncoder::new(
            input.boxed(),
            schema,
            max_flight_data_size,
            options,
            app_metadata,
            descriptor,
            dictionary_handling,
        )
    }
}

/// Stream that encodes a stream of record batches to flight data.
///
/// See [`FlightDataEncoderBuilder`] for details and example.
pub struct FlightDataEncoder {
    /// Input stream
    inner: BoxStream<'static, Result<RecordBatch>>,
    /// schema, set after the first batch
    schema: Option<SchemaRef>,
    /// Target maximum size of flight data
    /// (see details on [`FlightDataEncoderBuilder::with_max_flight_data_size`]).
    max_flight_data_size: usize,
    /// do the encoding / tracking of dictionaries
    encoder: FlightIpcEncoder,
    /// optional metadata to add to schema FlightData
    app_metadata: Option<Bytes>,
    /// data queued up to send but not yet sent
    queue: VecDeque<FlightData>,
    /// Is this stream done (inner is empty or errored)
    done: bool,
    /// cleared after the first FlightData message is sent
    descriptor: Option<FlightDescriptor>,
    /// Determines how `DictionaryArray`s are encoded for transport.
    /// See [`DictionaryHandling`] for more information.
    dictionary_handling: DictionaryHandling,
}

impl FlightDataEncoder {
    fn new(
        inner: BoxStream<'static, Result<RecordBatch>>,
        schema: Option<SchemaRef>,
        max_flight_data_size: usize,
        options: IpcWriteOptions,
        app_metadata: Bytes,
        descriptor: Option<FlightDescriptor>,
        dictionary_handling: DictionaryHandling,
    ) -> Self {
        let mut encoder = Self {
            inner,
            schema: None,
            max_flight_data_size,
            encoder: FlightIpcEncoder::new(
                options,
                dictionary_handling != DictionaryHandling::Resend,
            ),
            app_metadata: Some(app_metadata),
            queue: VecDeque::new(),
            done: false,
            descriptor,
            dictionary_handling,
        };

        // If schema is known up front, enqueue it immediately
        if let Some(schema) = schema {
            encoder.encode_schema(&schema);
        }

        encoder
    }

    /// Report the schema of the encoded data when known.
    /// A schema is known when provided via the [`FlightDataEncoderBuilder::with_schema`] method.
    pub fn known_schema(&self) -> Option<SchemaRef> {
        self.schema.clone()
    }

    /// Place the `FlightData` in the queue to send
    fn queue_message(&mut self, mut data: FlightData) {
        if let Some(descriptor) = self.descriptor.take() {
            data.flight_descriptor = Some(descriptor);
        }
        self.queue.push_back(data);
    }

    /// Place each `FlightData` in the queue to send
    fn queue_messages(&mut self, datas: impl IntoIterator<Item = FlightData>) {
        for data in datas {
            self.queue_message(data)
        }
    }

    /// Encodes schema as a [`FlightData`] in self.queue.
    /// Updates `self.schema` and returns the new schema
    fn encode_schema(&mut self, schema: &SchemaRef) -> SchemaRef {
        // The first message is the schema message, and all
        // batches have the same schema
        let send_dictionaries = self.dictionary_handling == DictionaryHandling::Resend;
        let schema = Arc::new(prepare_schema_for_flight(
            schema,
            &mut self.encoder.dictionary_tracker,
            send_dictionaries,
        ));
        let mut schema_flight_data = self.encoder.encode_schema(&schema);

        // attach any metadata requested
        if let Some(app_metadata) = self.app_metadata.take() {
            schema_flight_data.app_metadata = app_metadata;
        }
        self.queue_message(schema_flight_data);
        // remember schema
        self.schema = Some(schema.clone());
        schema
    }

    /// Encodes batch into one or more `FlightData` messages in self.queue
    fn encode_batch(&mut self, batch: RecordBatch) -> Result<()> {
        let schema = match &self.schema {
            Some(schema) => schema.clone(),
            // encode the schema if this is the first time we have seen it
            None => self.encode_schema(batch.schema_ref()),
        };

        let batch = match self.dictionary_handling {
            DictionaryHandling::Resend => batch,
            DictionaryHandling::Hydrate => hydrate_dictionaries(&batch, schema)?,
        };

        for batch in split_batch_for_grpc_response(batch, self.max_flight_data_size) {
            let (flight_dictionaries, flight_batch) = self.encoder.encode_batch(&batch)?;

            self.queue_messages(flight_dictionaries);
            self.queue_message(flight_batch);
        }

        Ok(())
    }
}

impl Stream for FlightDataEncoder {
    type Item = Result<FlightData>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        loop {
            if self.done && self.queue.is_empty() {
                return Poll::Ready(None);
            }

            // Any messages queued to send?
            if let Some(data) = self.queue.pop_front() {
                return Poll::Ready(Some(Ok(data)));
            }

            // Get next batch
            let batch = ready!(self.inner.poll_next_unpin(cx));

            match batch {
                None => {
                    // inner is done
                    self.done = true;
                    // queue must also be empty so we are done
                    assert!(self.queue.is_empty());
                    return Poll::Ready(None);
                }
                Some(Err(e)) => {
                    // error from inner
                    self.done = true;
                    self.queue.clear();
                    return Poll::Ready(Some(Err(e)));
                }
                Some(Ok(batch)) => {
                    // had data, encode into the queue
                    if let Err(e) = self.encode_batch(batch) {
                        self.done = true;
                        self.queue.clear();
                        return Poll::Ready(Some(Err(e)));
                    }
                }
            }
        }
    }
}

/// Defines how a [`FlightDataEncoder`] encodes [`DictionaryArray`]s
///
/// [`DictionaryArray`]: arrow_array::DictionaryArray
///
/// In the Arrow Flight protocol, dictionary values and keys are sent as two separate messages.
/// When a sender is encoding a [`RecordBatch`] containing [`DictionaryArray`] columns, it will
/// first send a dictionary batch (a batch with header `MessageHeader::DictionaryBatch`) containing
/// the dictionary values. The receiver is responsible for reading this batch and maintaining state that associates
/// those dictionary values with the corresponding array using the `dict_id` as a key.
///
/// After sending the dictionary batch the sender will send the array data in a batch with header `MessageHeader::RecordBatch`.
/// For any dictionary array batches in this message, the encoded flight message will only contain the dictionary keys. The receiver
/// is then responsible for rebuilding the `DictionaryArray` on the client side using the dictionary values from the DictionaryBatch message
/// and the keys from the RecordBatch message.
///
/// For example, if we have a batch with a `TypedDictionaryArray<'_, UInt32Type, Utf8Type>` (a dictionary array where the keys are `u32` and the
/// values are `String`), then the DictionaryBatch will contain a `StringArray` and the RecordBatch will contain a `UInt32Array`.
///
/// Note that since `dict_id` defined in the `Schema` is used as a key to associate dictionary values to their arrays it is required that each
/// `DictionaryArray` in a `RecordBatch` have a unique `dict_id`.
///
/// The current implementation does not support "delta" dictionaries so a new dictionary batch will be sent each time the encoder sees a
/// dictionary which is not pointer-equal to the previously observed dictionary for a given `dict_id`.
///
/// For clients which may not support `DictionaryEncoding`, the `DictionaryHandling::Hydrate` method will bypass the process defined above
/// and "hydrate" any `DictionaryArray` in the batch to their underlying value type (e.g. `TypedDictionaryArray<'_, UInt32Type, Utf8Type>` will
/// be sent as a `StringArray`). With this method all data will be sent in `MessageHeader::RecordBatch` messages and the batch schema
/// will be adjusted so that all dictionary encoded fields are changed to fields of the dictionary value type.
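///
/// The example below (a minimal sketch) builds an encoder that resends
/// dictionaries instead of hydrating them:
///
/// ```
/// use arrow_flight::encode::{DictionaryHandling, FlightDataEncoderBuilder};
///
/// let builder = FlightDataEncoderBuilder::new()
///     .with_dictionary_handling(DictionaryHandling::Resend);
/// ```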
#[derive(Debug, PartialEq)]
pub enum DictionaryHandling {
    /// Expands to the underlying type (default). This likely sends more data
    /// over the network but requires less memory (dictionaries are not tracked)
    /// and is more compatible with other arrow flight client implementations
    /// that may not support `DictionaryEncoding`
    ///
    /// See also:
    /// * <https://github.com/apache/arrow-rs/issues/1206>
    Hydrate,
    /// Send dictionary FlightData with every RecordBatch that contains a
    /// [`DictionaryArray`]. See [`Self::Hydrate`] for more tradeoffs. No
    /// attempt is made to skip sending the same (logical) dictionary values
    /// twice.
    ///
    /// [`DictionaryArray`]: arrow_array::DictionaryArray
    ///
    /// This requires identifying the different dictionaries in use and assigning
    /// them unique IDs.
    Resend,
}

fn prepare_field_for_flight(
    field: &FieldRef,
    dictionary_tracker: &mut DictionaryTracker,
    send_dictionaries: bool,
) -> Field {
    match field.data_type() {
        DataType::List(inner) => Field::new_list(
            field.name(),
            prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries),
            field.is_nullable(),
        )
        .with_metadata(field.metadata().clone()),
        DataType::LargeList(inner) => Field::new_list(
            field.name(),
            prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries),
            field.is_nullable(),
        )
        .with_metadata(field.metadata().clone()),
        DataType::Struct(fields) => {
            let new_fields: Vec<Field> = fields
                .iter()
                .map(|f| prepare_field_for_flight(f, dictionary_tracker, send_dictionaries))
                .collect();
            Field::new_struct(field.name(), new_fields, field.is_nullable())
                .with_metadata(field.metadata().clone())
        }
        DataType::Union(fields, mode) => {
            let (type_ids, new_fields): (Vec<i8>, Vec<Field>) = fields
                .iter()
                .map(|(type_id, f)| {
                    (
                        type_id,
                        prepare_field_for_flight(f, dictionary_tracker, send_dictionaries),
                    )
                })
                .unzip();

            Field::new_union(field.name(), type_ids, new_fields, *mode)
        }
        DataType::Dictionary(_, value_type) => {
            if !send_dictionaries {
                Field::new(
                    field.name(),
                    value_type.as_ref().clone(),
                    field.is_nullable(),
                )
                .with_metadata(field.metadata().clone())
            } else {
                #[allow(deprecated)]
                let dict_id = dictionary_tracker.set_dict_id(field.as_ref());

                #[allow(deprecated)]
                Field::new_dict(
                    field.name(),
                    field.data_type().clone(),
                    field.is_nullable(),
                    dict_id,
                    field.dict_is_ordered().unwrap_or_default(),
                )
                .with_metadata(field.metadata().clone())
            }
        }
        DataType::Map(inner, sorted) => Field::new(
            field.name(),
            DataType::Map(
                prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries).into(),
                *sorted,
            ),
            field.is_nullable(),
        )
        .with_metadata(field.metadata().clone()),
        _ => field.as_ref().clone(),
    }
}

/// Prepare an arrow Schema for transport over the Arrow Flight protocol
///
/// Convert dictionary types to underlying types
///
/// See hydrate_dictionary for more information
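///
/// For example, when dictionaries are hydrated, a field of type
/// `Dictionary(UInt16, Utf8)` becomes a plain `Utf8` field.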
fn prepare_schema_for_flight(
    schema: &Schema,
    dictionary_tracker: &mut DictionaryTracker,
    send_dictionaries: bool,
) -> Schema {
    let fields: Fields = schema
        .fields()
        .iter()
        .map(|field| match field.data_type() {
            DataType::Dictionary(_, value_type) => {
                if !send_dictionaries {
                    Field::new(
                        field.name(),
                        value_type.as_ref().clone(),
                        field.is_nullable(),
                    )
                    .with_metadata(field.metadata().clone())
                } else {
                    #[allow(deprecated)]
                    let dict_id = dictionary_tracker.set_dict_id(field.as_ref());
                    #[allow(deprecated)]
                    Field::new_dict(
                        field.name(),
                        field.data_type().clone(),
                        field.is_nullable(),
                        dict_id,
                        field.dict_is_ordered().unwrap_or_default(),
                    )
                    .with_metadata(field.metadata().clone())
                }
            }
            tpe if tpe.is_nested() => {
                prepare_field_for_flight(field, dictionary_tracker, send_dictionaries)
            }
            _ => field.as_ref().clone(),
        })
        .collect();

    Schema::new(fields).with_metadata(schema.metadata().clone())
}

/// Split [`RecordBatch`] so it hopefully fits into a gRPC response.
///
/// Data is zero-copy sliced into batches.
///
/// Note: this method does not take into account already sliced
/// arrays: <https://github.com/apache/arrow-rs/issues/3407>
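///
/// For example (a rough sketch of the arithmetic): a batch whose buffers
/// total 5MB with a 2MB target is split into `5 / 2 + 1 = 3` slices, each
/// covering about a third of the rows.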
fn split_batch_for_grpc_response(
    batch: RecordBatch,
    max_flight_data_size: usize,
) -> Vec<RecordBatch> {
    let size = batch
        .columns()
        .iter()
        .map(|col| col.get_buffer_memory_size())
        .sum::<usize>();

    let n_batches =
        (size / max_flight_data_size + usize::from(size % max_flight_data_size != 0)).max(1);
    let rows_per_batch = (batch.num_rows() / n_batches).max(1);
    let mut out = Vec::with_capacity(n_batches + 1);

    let mut offset = 0;
    while offset < batch.num_rows() {
        let length = (rows_per_batch).min(batch.num_rows() - offset);
        out.push(batch.slice(offset, length));

        offset += length;
    }

    out
}

/// The data needed to encode a stream of flight data, holding on to
/// shared Dictionaries.
///
/// TODO: allow dictionaries to be flushed / avoid building them
///
/// TODO: limit on the number of dictionaries?
struct FlightIpcEncoder {
    options: IpcWriteOptions,
    data_gen: IpcDataGenerator,
    dictionary_tracker: DictionaryTracker,
}

impl FlightIpcEncoder {
    fn new(options: IpcWriteOptions, error_on_replacement: bool) -> Self {
        #[allow(deprecated)]
        let preserve_dict_id = options.preserve_dict_id();
        Self {
            options,
            data_gen: IpcDataGenerator::default(),
            #[allow(deprecated)]
            dictionary_tracker: DictionaryTracker::new_with_preserve_dict_id(
                error_on_replacement,
                preserve_dict_id,
            ),
        }
    }

    /// Encode a schema as a FlightData
    fn encode_schema(&self, schema: &Schema) -> FlightData {
        SchemaAsIpc::new(schema, &self.options).into()
    }

    /// Convert a `RecordBatch` to a Vec of `FlightData` representing
    /// dictionaries and a `FlightData` representing the batch
    fn encode_batch(&mut self, batch: &RecordBatch) -> Result<(Vec<FlightData>, FlightData)> {
        let (encoded_dictionaries, encoded_batch) =
            self.data_gen
                .encoded_batch(batch, &mut self.dictionary_tracker, &self.options)?;

        let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect();
        let flight_batch = encoded_batch.into();

        Ok((flight_dictionaries, flight_batch))
    }
}

/// Hydrates any dictionary arrays in `batch` to their underlying types. See
/// `hydrate_dictionary` for more information.
fn hydrate_dictionaries(batch: &RecordBatch, schema: SchemaRef) -> Result<RecordBatch> {
    let columns = schema
        .fields()
        .iter()
        .zip(batch.columns())
        .map(|(field, c)| hydrate_dictionary(c, field.data_type()))
        .collect::<Result<Vec<_>>>()?;

    let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));

    Ok(RecordBatch::try_new_with_options(
        schema, columns, &options,
    )?)
}

/// Hydrates a dictionary to its underlying type.
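///
/// For example (a sketch): a `DictionaryArray<UInt16Type>` with values
/// `["a", "a", "b"]` is cast to the equivalent `StringArray`
/// `["a", "a", "b"]` via `arrow_cast::cast`.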
707fn hydrate_dictionary(array: &ArrayRef, data_type: &DataType) -> Result<ArrayRef> {
708    let arr = match (array.data_type(), data_type) {
709        (DataType::Union(_, UnionMode::Sparse), DataType::Union(fields, UnionMode::Sparse)) => {
710            let union_arr = array.as_any().downcast_ref::<UnionArray>().unwrap();
711
712            Arc::new(UnionArray::try_new(
713                fields.clone(),
714                union_arr.type_ids().clone(),
715                None,
716                fields
717                    .iter()
718                    .map(|(type_id, field)| {
719                        Ok(arrow_cast::cast(
720                            union_arr.child(type_id),
721                            field.data_type(),
722                        )?)
723                    })
724                    .collect::<Result<Vec<_>>>()?,
725            )?)
726        }
727        (_, data_type) => arrow_cast::cast(array, data_type)?,
728    };
729    Ok(arr)
730}
731
732#[cfg(test)]
733mod tests {
734    use crate::decode::{DecodedPayload, FlightDataDecoder};
735    use arrow_array::builder::{
736        GenericByteDictionaryBuilder, ListBuilder, StringDictionaryBuilder, StructBuilder,
737    };
738    use arrow_array::*;
739    use arrow_array::{cast::downcast_array, types::*};
740    use arrow_buffer::ScalarBuffer;
741    use arrow_cast::pretty::pretty_format_batches;
742    use arrow_ipc::MetadataVersion;
743    use arrow_schema::{UnionFields, UnionMode};
744    use builder::{GenericStringBuilder, MapBuilder};
745    use std::collections::HashMap;
746
747    use super::*;
748
749    #[test]
750    /// ensure only the batch's used data (not the allocated data) is sent
751    /// <https://github.com/apache/arrow-rs/issues/208>
752    fn test_encode_flight_data() {
753        // use 8-byte alignment - default alignment is 64 which produces bigger ipc data
754        let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap();
755        let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
756
757        let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c1) as ArrayRef)])
758            .expect("cannot create record batch");
759        let schema = batch.schema_ref();
760
761        let (_, baseline_flight_batch) = make_flight_data(&batch, &options);
762
763        let big_batch = batch.slice(0, batch.num_rows() - 1);
764        let optimized_big_batch =
765            hydrate_dictionaries(&big_batch, Arc::clone(schema)).expect("failed to optimize");
766        let (_, optimized_big_flight_batch) = make_flight_data(&optimized_big_batch, &options);
767
768        assert_eq!(
769            baseline_flight_batch.data_body.len(),
770            optimized_big_flight_batch.data_body.len()
771        );
772
773        let small_batch = batch.slice(0, 1);
774        let optimized_small_batch =
775            hydrate_dictionaries(&small_batch, Arc::clone(schema)).expect("failed to optimize");
776        let (_, optimized_small_flight_batch) = make_flight_data(&optimized_small_batch, &options);
777
778        assert!(
779            baseline_flight_batch.data_body.len() > optimized_small_flight_batch.data_body.len()
780        );
781    }
782
783    #[tokio::test]
784    async fn test_dictionary_hydration() {
785        let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
786        let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();
787
788        let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
789            "dict",
790            DataType::UInt16,
791            DataType::Utf8,
792            false,
793        )]));
794        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
795        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();
796
797        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
798
799        let encoder = FlightDataEncoderBuilder::default().build(stream);
800        let mut decoder = FlightDataDecoder::new(encoder);
801        let expected_schema = Schema::new(vec![Field::new("dict", DataType::Utf8, false)]);
802        let expected_schema = Arc::new(expected_schema);
803        let mut expected_arrays = vec![
804            StringArray::from(vec!["a", "a", "b"]),
805            StringArray::from(vec!["c", "c", "d"]),
806        ]
807        .into_iter();
808        while let Some(decoded) = decoder.next().await {
809            let decoded = decoded.unwrap();
810            match decoded.payload {
811                DecodedPayload::None => {}
812                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
813                DecodedPayload::RecordBatch(b) => {
814                    assert_eq!(b.schema(), expected_schema);
815                    let expected_array = expected_arrays.next().unwrap();
816                    let actual_array = b.column_by_name("dict").unwrap();
817                    let actual_array = downcast_array::<StringArray>(actual_array);
818
819                    assert_eq!(actual_array, expected_array);
820                }
821            }
822        }
823    }
824
825    #[tokio::test]
826    async fn test_dictionary_resend() {
827        let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
828        let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();
829
830        let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
831            "dict",
832            DataType::UInt16,
833            DataType::Utf8,
834            false,
835        )]));
836        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
837        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();
838
839        verify_flight_round_trip(vec![batch1, batch2]).await;
840    }
841
842    #[tokio::test]
843    async fn test_dictionary_hydration_known_schema() {
844        let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
845        let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();
846
847        let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
848            "dict",
849            DataType::UInt16,
850            DataType::Utf8,
851            false,
852        )]));
853        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
854        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
855
856        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
857
858        let encoder = FlightDataEncoderBuilder::default()
859            .with_schema(schema)
860            .build(stream);
861        let expected_schema =
862            Arc::new(Schema::new(vec![Field::new("dict", DataType::Utf8, false)]));
863        assert_eq!(Some(expected_schema), encoder.known_schema())
864    }
865
866    #[tokio::test]
867    async fn test_dictionary_resend_known_schema() {
868        let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
869        let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();
870
871        let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
872            "dict",
873            DataType::UInt16,
874            DataType::Utf8,
875            false,
876        )]));
877        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
878        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
879
880        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
881
882        let encoder = FlightDataEncoderBuilder::default()
883            .with_dictionary_handling(DictionaryHandling::Resend)
884            .with_schema(schema.clone())
885            .build(stream);
886        assert_eq!(Some(schema), encoder.known_schema())
887    }
888
889    #[tokio::test]
890    async fn test_multiple_dictionaries_resend() {
891        // Create a schema with two dictionary fields that have the same dict ID
892        let schema = Arc::new(Schema::new(vec![
893            Field::new_dictionary("dict_1", DataType::UInt16, DataType::Utf8, false),
894            Field::new_dictionary("dict_2", DataType::UInt16, DataType::Utf8, false),
895        ]));
896
897        let arr_one_1: Arc<DictionaryArray<UInt16Type>> =
898            Arc::new(vec!["a", "a", "b"].into_iter().collect());
899        let arr_one_2: Arc<DictionaryArray<UInt16Type>> =
900            Arc::new(vec!["c", "c", "d"].into_iter().collect());
901        let arr_two_1: Arc<DictionaryArray<UInt16Type>> =
902            Arc::new(vec!["b", "a", "c"].into_iter().collect());
903        let arr_two_2: Arc<DictionaryArray<UInt16Type>> =
904            Arc::new(vec!["k", "d", "e"].into_iter().collect());
905        let batch1 =
906            RecordBatch::try_new(schema.clone(), vec![arr_one_1.clone(), arr_one_2.clone()])
907                .unwrap();
908        let batch2 =
909            RecordBatch::try_new(schema.clone(), vec![arr_two_1.clone(), arr_two_2.clone()])
910                .unwrap();
911
912        verify_flight_round_trip(vec![batch1, batch2]).await;
913    }
914
915    #[tokio::test]
916    async fn test_dictionary_list_hydration() {
917        let mut builder = ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());
918
919        builder.append_value(vec![Some("a"), None, Some("b")]);
920
921        let arr1 = builder.finish();
922
923        builder.append_value(vec![Some("c"), None, Some("d")]);
924
925        let arr2 = builder.finish();
926
927        let schema = Arc::new(Schema::new(vec![Field::new_list(
928            "dict_list",
929            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
930            true,
931        )]));
932
933        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
934        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
935
936        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
937
938        let encoder = FlightDataEncoderBuilder::default().build(stream);
939
940        let mut decoder = FlightDataDecoder::new(encoder);
941        let expected_schema = Schema::new(vec![Field::new_list(
942            "dict_list",
943            Field::new_list_field(DataType::Utf8, true),
944            true,
945        )]);
946
947        let expected_schema = Arc::new(expected_schema);
948
949        let mut expected_arrays = vec![
950            StringArray::from_iter(vec![Some("a"), None, Some("b")]),
951            StringArray::from_iter(vec![Some("c"), None, Some("d")]),
952        ]
953        .into_iter();
954
955        while let Some(decoded) = decoder.next().await {
956            let decoded = decoded.unwrap();
957            match decoded.payload {
958                DecodedPayload::None => {}
959                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
960                DecodedPayload::RecordBatch(b) => {
961                    assert_eq!(b.schema(), expected_schema);
962                    let expected_array = expected_arrays.next().unwrap();
963                    let list_array =
964                        downcast_array::<ListArray>(b.column_by_name("dict_list").unwrap());
965                    let elem_array = downcast_array::<StringArray>(list_array.value(0).as_ref());
966
967                    assert_eq!(elem_array, expected_array);
968                }
969            }
970        }
971    }
972
973    #[tokio::test]
974    async fn test_dictionary_list_resend() {
975        let mut builder = ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());
976
977        builder.append_value(vec![Some("a"), None, Some("b")]);
978
979        let arr1 = builder.finish();
980
981        builder.append_value(vec![Some("c"), None, Some("d")]);
982
983        let arr2 = builder.finish();
984
985        let schema = Arc::new(Schema::new(vec![Field::new_list(
986            "dict_list",
987            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
988            true,
989        )]));
990
991        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
992        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
993
994        verify_flight_round_trip(vec![batch1, batch2]).await;
995    }
996
997    #[tokio::test]
998    async fn test_dictionary_struct_hydration() {
999        let struct_fields = vec![Field::new_list(
1000            "dict_list",
1001            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1002            true,
1003        )];
1004
1005        let mut struct_builder = StructBuilder::new(
1006            struct_fields.clone(),
1007            vec![Box::new(builder::ListBuilder::new(
1008                StringDictionaryBuilder::<UInt16Type>::new(),
1009            ))],
1010        );
1011
1012        struct_builder
1013            .field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type,GenericStringType<i32>>>>(0)
1014            .unwrap()
1015            .append_value(vec![Some("a"), None, Some("b")]);
1016
1017        struct_builder.append(true);
1018
1019        let arr1 = struct_builder.finish();
1020
1021        struct_builder
1022            .field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type,GenericStringType<i32>>>>(0)
1023            .unwrap()
1024            .append_value(vec![Some("c"), None, Some("d")]);
1025        struct_builder.append(true);
1026
1027        let arr2 = struct_builder.finish();
1028
1029        let schema = Arc::new(Schema::new(vec![Field::new_struct(
1030            "struct",
1031            struct_fields,
1032            true,
1033        )]));
1034
1035        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
1036        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();
1037
1038        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
1039
1040        let encoder = FlightDataEncoderBuilder::default().build(stream);
1041
1042        let mut decoder = FlightDataDecoder::new(encoder);
1043        let expected_schema = Schema::new(vec![Field::new_struct(
1044            "struct",
1045            vec![Field::new_list(
1046                "dict_list",
1047                Field::new_list_field(DataType::Utf8, true),
1048                true,
1049            )],
1050            true,
1051        )]);
1052
1053        let expected_schema = Arc::new(expected_schema);
1054
1055        let mut expected_arrays = vec![
1056            StringArray::from_iter(vec![Some("a"), None, Some("b")]),
1057            StringArray::from_iter(vec![Some("c"), None, Some("d")]),
1058        ]
1059        .into_iter();
1060
1061        while let Some(decoded) = decoder.next().await {
1062            let decoded = decoded.unwrap();
1063            match decoded.payload {
1064                DecodedPayload::None => {}
1065                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
1066                DecodedPayload::RecordBatch(b) => {
1067                    assert_eq!(b.schema(), expected_schema);
1068                    let expected_array = expected_arrays.next().unwrap();
1069                    let struct_array =
1070                        downcast_array::<StructArray>(b.column_by_name("struct").unwrap());
1071                    let list_array = downcast_array::<ListArray>(struct_array.column(0));
1072
1073                    let elem_array = downcast_array::<StringArray>(list_array.value(0).as_ref());
1074
1075                    assert_eq!(elem_array, expected_array);
1076                }
1077            }
1078        }
1079    }
1080
1081    #[tokio::test]
1082    async fn test_dictionary_struct_resend() {
1083        let struct_fields = vec![Field::new_list(
1084            "dict_list",
1085            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1086            true,
1087        )];
1088
1089        let mut struct_builder = StructBuilder::new(
1090            struct_fields.clone(),
1091            vec![Box::new(builder::ListBuilder::new(
1092                StringDictionaryBuilder::<UInt16Type>::new(),
1093            ))],
1094        );
1095
1096        struct_builder.field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type,GenericStringType<i32>>>>(0)
1097            .unwrap()
1098            .append_value(vec![Some("a"), None, Some("b")]);
1099        struct_builder.append(true);
1100
1101        let arr1 = struct_builder.finish();
1102
1103        struct_builder.field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type,GenericStringType<i32>>>>(0)
1104            .unwrap()
1105            .append_value(vec![Some("c"), None, Some("d")]);
1106        struct_builder.append(true);
1107
1108        let arr2 = struct_builder.finish();
1109
1110        let schema = Arc::new(Schema::new(vec![Field::new_struct(
1111            "struct",
1112            struct_fields,
1113            true,
1114        )]));
1115
1116        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
1117        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();
1118
1119        verify_flight_round_trip(vec![batch1, batch2]).await;
1120    }
1121
1122    #[tokio::test]
1123    async fn test_dictionary_union_hydration() {
1124        let struct_fields = vec![Field::new_list(
1125            "dict_list",
1126            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1127            true,
1128        )];
1129
1130        let union_fields = [
1131            (
1132                0,
1133                Arc::new(Field::new_list(
1134                    "dict_list",
1135                    Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1136                    true,
1137                )),
1138            ),
1139            (
1140                1,
1141                Arc::new(Field::new_struct("struct", struct_fields.clone(), true)),
1142            ),
1143            (2, Arc::new(Field::new("string", DataType::Utf8, true))),
1144        ]
1145        .into_iter()
1146        .collect::<UnionFields>();
1147
1148        let struct_fields = vec![Field::new_list(
1149            "dict_list",
1150            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1151            true,
1152        )];
1153
1154        let mut builder = builder::ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());
1155
1156        builder.append_value(vec![Some("a"), None, Some("b")]);
1157
1158        let arr1 = builder.finish();
1159
1160        let type_id_buffer = [0].into_iter().collect::<ScalarBuffer<i8>>();
1161        let arr1 = UnionArray::try_new(
1162            union_fields.clone(),
1163            type_id_buffer,
1164            None,
1165            vec![
1166                Arc::new(arr1) as Arc<dyn Array>,
1167                new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1),
1168                new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1),
1169            ],
1170        )
1171        .unwrap();
1172
1173        builder.append_value(vec![Some("c"), None, Some("d")]);
1174
1175        let arr2 = Arc::new(builder.finish());
1176        let arr2 = StructArray::new(struct_fields.clone().into(), vec![arr2], None);
1177
1178        let type_id_buffer = [1].into_iter().collect::<ScalarBuffer<i8>>();
1179        let arr2 = UnionArray::try_new(
1180            union_fields.clone(),
1181            type_id_buffer,
1182            None,
1183            vec![
1184                new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1),
1185                Arc::new(arr2),
1186                new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1),
1187            ],
1188        )
1189        .unwrap();
1190
1191        let type_id_buffer = [2].into_iter().collect::<ScalarBuffer<i8>>();
1192        let arr3 = UnionArray::try_new(
1193            union_fields.clone(),
1194            type_id_buffer,
1195            None,
1196            vec![
1197                new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1),
1198                new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1),
1199                Arc::new(StringArray::from(vec!["e"])),
1200            ],
1201        )
1202        .unwrap();
1203
1204        let (type_ids, union_fields): (Vec<_>, Vec<_>) = union_fields
1205            .iter()
1206            .map(|(type_id, field_ref)| (type_id, (*Arc::clone(field_ref)).clone()))
1207            .unzip();
1208        let schema = Arc::new(Schema::new(vec![Field::new_union(
1209            "union",
1210            type_ids.clone(),
1211            union_fields.clone(),
1212            UnionMode::Sparse,
1213        )]));
1214
1215        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
1216        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
1217        let batch3 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr3)]).unwrap();
1218
1219        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2), Ok(batch3)]);
1220
1221        let encoder = FlightDataEncoderBuilder::default().build(stream);
1222
1223        let mut decoder = FlightDataDecoder::new(encoder);
1224
1225        let hydrated_struct_fields = vec![Field::new_list(
1226            "dict_list",
1227            Field::new_list_field(DataType::Utf8, true),
1228            true,
1229        )];
1230
1231        let hydrated_union_fields = vec![
1232            Field::new_list(
1233                "dict_list",
1234                Field::new_list_field(DataType::Utf8, true),
1235                true,
1236            ),
1237            Field::new_struct("struct", hydrated_struct_fields.clone(), true),
1238            Field::new("string", DataType::Utf8, true),
1239        ];
1240
1241        let expected_schema = Schema::new(vec![Field::new_union(
1242            "union",
1243            type_ids.clone(),
1244            hydrated_union_fields,
1245            UnionMode::Sparse,
1246        )]);
1247
1248        let expected_schema = Arc::new(expected_schema);
1249
1250        let mut expected_arrays = vec![
1251            StringArray::from_iter(vec![Some("a"), None, Some("b")]),
1252            StringArray::from_iter(vec![Some("c"), None, Some("d")]),
1253            StringArray::from(vec!["e"]),
1254        ]
1255        .into_iter();
1256
1257        let mut batch = 0;
1258        while let Some(decoded) = decoder.next().await {
1259            let decoded = decoded.unwrap();
1260            match decoded.payload {
1261                DecodedPayload::None => {}
1262                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
1263                DecodedPayload::RecordBatch(b) => {
1264                    assert_eq!(b.schema(), expected_schema);
1265                    let expected_array = expected_arrays.next().unwrap();
1266                    let union_arr =
1267                        downcast_array::<UnionArray>(b.column_by_name("union").unwrap());
1268
1269                    let elem_array = match batch {
1270                        0 => {
1271                            let list_array = downcast_array::<ListArray>(union_arr.child(0));
1272                            downcast_array::<StringArray>(list_array.value(0).as_ref())
1273                        }
1274                        1 => {
1275                            let struct_array = downcast_array::<StructArray>(union_arr.child(1));
1276                            let list_array = downcast_array::<ListArray>(struct_array.column(0));
1277
1278                            downcast_array::<StringArray>(list_array.value(0).as_ref())
1279                        }
1280                        _ => downcast_array::<StringArray>(union_arr.child(2)),
1281                    };
1282
1283                    batch += 1;
1284
1285                    assert_eq!(elem_array, expected_array);
1286                }
1287            }
1288        }
1289    }
1290
1291    #[tokio::test]
1292    async fn test_dictionary_union_resend() {
1293        let struct_fields = vec![Field::new_list(
1294            "dict_list",
1295            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1296            true,
1297        )];
1298
1299        let union_fields = [
1300            (
1301                0,
1302                Arc::new(Field::new_list(
1303                    "dict_list",
1304                    Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1305                    true,
1306                )),
1307            ),
1308            (
1309                1,
1310                Arc::new(Field::new_struct("struct", struct_fields.clone(), true)),
1311            ),
1312            (2, Arc::new(Field::new("string", DataType::Utf8, true))),
1313        ]
1314        .into_iter()
1315        .collect::<UnionFields>();
1316
1317        let mut field_types = union_fields.iter().map(|(_, field)| field.data_type());
1318        let dict_list_ty = field_types.next().unwrap();
1319        let struct_ty = field_types.next().unwrap();
1320        let string_ty = field_types.next().unwrap();
1321
1322        let struct_fields = vec![Field::new_list(
1323            "dict_list",
1324            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1325            true,
1326        )];
1327
1328        let mut builder = builder::ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());
1329
1330        builder.append_value(vec![Some("a"), None, Some("b")]);
1331
1332        let arr1 = builder.finish();
1333
1334        let type_id_buffer = [0].into_iter().collect::<ScalarBuffer<i8>>();
1335        let arr1 = UnionArray::try_new(
1336            union_fields.clone(),
1337            type_id_buffer,
1338            None,
1339            vec![
1340                Arc::new(arr1),
1341                new_null_array(struct_ty, 1),
1342                new_null_array(string_ty, 1),
1343            ],
1344        )
1345        .unwrap();
1346
1347        builder.append_value(vec![Some("c"), None, Some("d")]);
1348
1349        let arr2 = Arc::new(builder.finish());
1350        let arr2 = StructArray::new(struct_fields.clone().into(), vec![arr2], None);
1351
1352        let type_id_buffer = [1].into_iter().collect::<ScalarBuffer<i8>>();
1353        let arr2 = UnionArray::try_new(
1354            union_fields.clone(),
1355            type_id_buffer,
1356            None,
1357            vec![
1358                new_null_array(dict_list_ty, 1),
1359                Arc::new(arr2),
1360                new_null_array(string_ty, 1),
1361            ],
1362        )
1363        .unwrap();
1364
1365        let type_id_buffer = [2].into_iter().collect::<ScalarBuffer<i8>>();
1366        let arr3 = UnionArray::try_new(
1367            union_fields.clone(),
1368            type_id_buffer,
1369            None,
1370            vec![
1371                new_null_array(dict_list_ty, 1),
1372                new_null_array(struct_ty, 1),
1373                Arc::new(StringArray::from(vec!["e"])),
1374            ],
1375        )
1376        .unwrap();
1377
        let (type_ids, union_fields): (Vec<_>, Vec<_>) = union_fields
            .iter()
            .map(|(type_id, field_ref)| (type_id, field_ref.as_ref().clone()))
            .unzip();
        let schema = Arc::new(Schema::new(vec![Field::new_union(
            "union",
            type_ids.clone(),
            union_fields.clone(),
            UnionMode::Sparse,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
        let batch3 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr3)]).unwrap();

        verify_flight_round_trip(vec![batch1, batch2, batch3]).await;
    }

    #[tokio::test]
    async fn test_dictionary_map_hydration() {
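        // Encode map arrays whose keys and values are dictionary-encoded,
        // then check that the default (hydrating) encoder delivers plain
        // Utf8 keys and values on the decoded side.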
        let mut builder = MapBuilder::new(
            None,
            StringDictionaryBuilder::<UInt16Type>::new(),
            StringDictionaryBuilder::<UInt16Type>::new(),
        );

        // {"k1":"a","k2":null,"k3":"b"}
        builder.keys().append_value("k1");
        builder.values().append_value("a");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("b");
        builder.append(true).unwrap();

        let arr1 = builder.finish();

        // {"k1":"c","k2":null,"k3":"d"}
        builder.keys().append_value("k1");
        builder.values().append_value("c");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("d");
        builder.append(true).unwrap();

        let arr2 = builder.finish();

        let schema = Arc::new(Schema::new(vec![Field::new_map(
            "dict_map",
            "entries",
            Field::new_dictionary("keys", DataType::UInt16, DataType::Utf8, false),
            Field::new_dictionary("values", DataType::UInt16, DataType::Utf8, true),
            false,
            false,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

        let encoder = FlightDataEncoderBuilder::default().build(stream);

        let mut decoder = FlightDataDecoder::new(encoder);
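        // Hydration should strip the dictionary encoding from both the
        // keys and the values in the decoded schema.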
        let expected_schema = Schema::new(vec![Field::new_map(
            "dict_map",
            "entries",
            Field::new("keys", DataType::Utf8, false),
            Field::new("values", DataType::Utf8, true),
            false,
            false,
        )]);

        let expected_schema = Arc::new(expected_schema);

        // Builder without dictionary fields
        let mut builder = MapBuilder::new(
            None,
            GenericStringBuilder::<i32>::new(),
            GenericStringBuilder::<i32>::new(),
        );

        // {"k1":"a","k2":null,"k3":"b"}
        builder.keys().append_value("k1");
        builder.values().append_value("a");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("b");
        builder.append(true).unwrap();

        let arr1 = builder.finish();

        // {"k1":"c","k2":null,"k3":"d"}
        builder.keys().append_value("k1");
        builder.values().append_value("c");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("d");
        builder.append(true).unwrap();

        let arr2 = builder.finish();

        let mut expected_arrays = vec![arr1, arr2].into_iter();

        while let Some(decoded) = decoder.next().await {
            let decoded = decoded.unwrap();
            match decoded.payload {
                DecodedPayload::None => {}
                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
                DecodedPayload::RecordBatch(b) => {
                    assert_eq!(b.schema(), expected_schema);
                    let expected_array = expected_arrays.next().unwrap();
                    let map_array =
                        downcast_array::<MapArray>(b.column_by_name("dict_map").unwrap());

                    assert_eq!(map_array, expected_array);
                }
            }
        }
    }

    #[tokio::test]
    async fn test_dictionary_map_resend() {
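        // Same map data as the hydration test above, but round-tripped
        // with DictionaryHandling::Resend via verify_flight_round_trip.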
        let mut builder = MapBuilder::new(
            None,
            StringDictionaryBuilder::<UInt16Type>::new(),
            StringDictionaryBuilder::<UInt16Type>::new(),
        );

        // {"k1":"a","k2":null,"k3":"b"}
        builder.keys().append_value("k1");
        builder.values().append_value("a");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("b");
        builder.append(true).unwrap();

        let arr1 = builder.finish();

        // {"k1":"c","k2":null,"k3":"d"}
        builder.keys().append_value("k1");
        builder.values().append_value("c");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("d");
        builder.append(true).unwrap();

        let arr2 = builder.finish();

        let schema = Arc::new(Schema::new(vec![Field::new_map(
            "dict_map",
            "entries",
            Field::new_dictionary("keys", DataType::UInt16, DataType::Utf8, false),
            Field::new_dictionary("values", DataType::UInt16, DataType::Utf8, true),
            false,
            false,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

        verify_flight_round_trip(vec![batch1, batch2]).await;
    }

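    /// Encodes the batches with `DictionaryHandling::Resend`, decodes the
    /// resulting stream, and asserts that the schema and every batch
    /// survive the round trip unchanged.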
    async fn verify_flight_round_trip(mut batches: Vec<RecordBatch>) {
        let expected_schema = batches.first().unwrap().schema();

        #[allow(deprecated)]
        let encoder = FlightDataEncoderBuilder::default()
            .with_options(IpcWriteOptions::default().with_preserve_dict_id(false))
            .with_dictionary_handling(DictionaryHandling::Resend)
            .build(futures::stream::iter(batches.clone().into_iter().map(Ok)));

        let mut expected_batches = batches.drain(..);

        let mut decoder = FlightDataDecoder::new(encoder);
        while let Some(decoded) = decoder.next().await {
            let decoded = decoded.unwrap();
            match decoded.payload {
                DecodedPayload::None => {}
                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
                DecodedPayload::RecordBatch(b) => {
                    let expected_batch = expected_batches.next().unwrap();
                    assert_eq!(b, expected_batch);
                }
            }
        }
    }

    #[test]
    fn test_schema_metadata_encoded() {
        let schema = Schema::new(vec![Field::new("data", DataType::Int32, false)]).with_metadata(
            HashMap::from([("some_key".to_owned(), "some_value".to_owned())]),
        );

        #[allow(deprecated)]
        let mut dictionary_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true);

        let got = prepare_schema_for_flight(&schema, &mut dictionary_tracker, false);
        assert!(got.metadata().contains_key("some_key"));
    }

    #[test]
    fn test_encode_no_column_batch() {
        let batch = RecordBatch::try_new_with_options(
            Arc::new(Schema::empty()),
            vec![],
            &RecordBatchOptions::new().with_row_count(Some(10)),
        )
        .expect("cannot create record batch");

        hydrate_dictionaries(&batch, batch.schema()).expect("failed to optimize");
    }

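    /// Test helper wrapping `flight_data_from_arrow_batch` below.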
    fn make_flight_data(
        batch: &RecordBatch,
        options: &IpcWriteOptions,
    ) -> (Vec<FlightData>, FlightData) {
        flight_data_from_arrow_batch(batch, options)
    }

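    /// Encodes a `RecordBatch` directly with `IpcDataGenerator`, returning
    /// the dictionary messages and the batch message as `FlightData`.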
    fn flight_data_from_arrow_batch(
        batch: &RecordBatch,
        options: &IpcWriteOptions,
    ) -> (Vec<FlightData>, FlightData) {
        let data_gen = IpcDataGenerator::default();
        #[allow(deprecated)]
        let mut dictionary_tracker =
            DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id());

        let (encoded_dictionaries, encoded_batch) = data_gen
            .encoded_batch(batch, &mut dictionary_tracker, options)
            .expect("DictionaryTracker configured above to not error on replacement");

        let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect();
        let flight_batch = encoded_batch.into();

        (flight_dictionaries, flight_batch)
    }

    #[test]
    fn test_split_batch_for_grpc_response() {
        let max_flight_data_size = 1024;

        // no split
        let c = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
        let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)])
            .expect("cannot create record batch");
        let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size);
        assert_eq!(split.len(), 1);
        assert_eq!(batch, split[0]);

        // split into multiple pieces: 1025 one-byte rows exceed the 1024 byte limit
        let n_rows = max_flight_data_size + 1;
        assert!(n_rows % 2 == 1, "should be an odd number");
        let c = UInt8Array::from((0..n_rows).map(|i| (i % 256) as u8).collect::<Vec<_>>());
        let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)])
            .expect("cannot create record batch");
        let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size);
        assert_eq!(split.len(), 3);
        assert_eq!(
            split.iter().map(|batch| batch.num_rows()).sum::<usize>(),
            n_rows
        );
        let a = pretty_format_batches(&split).unwrap().to_string();
        let b = pretty_format_batches(&[batch]).unwrap().to_string();
        assert_eq!(a, b);
    }

    #[test]
    fn test_split_batch_for_grpc_response_sizes() {
        // 2000 8 byte entries into 2k pieces: 8 chunks of 250 rows
        verify_split(2000, 2 * 1024, vec![250, 250, 250, 250, 250, 250, 250, 250]);

        // 2000 8 byte entries into 4k pieces: 4 chunks of 500 rows
        verify_split(2000, 4 * 1024, vec![500, 500, 500, 500]);

        // 2023 8 byte entries into 3k pieces does not divide evenly
        verify_split(2023, 3 * 1024, vec![337, 337, 337, 337, 337, 337, 1]);

        // 10 8 byte entries into 1 byte pieces means each row gets its own
        verify_split(10, 1, vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1]);

        // 10 8 byte entries into 1k byte pieces means one piece
        verify_split(10, 1024, vec![10]);
    }

    /// Creates a UInt64Array of 8 byte integers with `num_input_rows` rows,
    /// splits it into `max_flight_data_size_bytes` pieces, and verifies the
    /// row counts in those pieces
    fn verify_split(
        num_input_rows: u64,
        max_flight_data_size_bytes: usize,
        expected_sizes: Vec<usize>,
    ) {
        let array: UInt64Array = (0..num_input_rows).collect();

        let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(array) as ArrayRef)])
            .expect("cannot create record batch");

        let input_rows = batch.num_rows();

        let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size_bytes);
        let sizes: Vec<_> = split.iter().map(RecordBatch::num_rows).collect();
        let output_rows: usize = sizes.iter().sum();

        assert_eq!(sizes, expected_sizes, "mismatch for {batch:?}");
        assert_eq!(input_rows, output_rows, "mismatch for {batch:?}");
    }

    // TODO: test sending record batches
    // TODO: test sending record batches with multiple different dictionaries

    #[tokio::test]
    async fn flight_data_size_even() {
        let s1 = StringArray::from_iter_values(std::iter::repeat(".10 bytes.").take(1024));
        let i1 = Int16Array::from_iter_values(0..1024);
        let s2 = StringArray::from_iter_values(std::iter::repeat("6bytes").take(1024));
        let i2 = Int64Array::from_iter_values(0..1024);

        let batch = RecordBatch::try_from_iter(vec![
            ("s1", Arc::new(s1) as _),
            ("i1", Arc::new(i1) as _),
            ("s2", Arc::new(s2) as _),
            ("i2", Arc::new(i2) as _),
        ])
        .unwrap();

        verify_encoded_split(batch, 120).await;
    }

    #[tokio::test]
    async fn flight_data_size_uneven_variable_lengths() {
        // each row has a longer string than the last with increasing lengths 0 --> 1024
        let array = StringArray::from_iter_values((0..1024).map(|i| "*".repeat(i)));
        let batch = RecordBatch::try_from_iter(vec![("data", Arc::new(array) as _)]).unwrap();

        // overage is much higher than ideal
        // https://github.com/apache/arrow-rs/issues/3478
        verify_encoded_split(batch, 4312).await;
    }

    #[tokio::test]
    async fn flight_data_size_large_row() {
        // batch with individual rows that can each exceed the batch size
        let array1 = StringArray::from_iter_values(vec![
            "*".repeat(500),
            "*".repeat(500),
            "*".repeat(500),
            "*".repeat(500),
        ]);
        let array2 = StringArray::from_iter_values(vec![
            "*".to_string(),
            "*".repeat(1000),
            "*".repeat(2000),
            "*".repeat(4000),
        ]);

        let array3 = StringArray::from_iter_values(vec![
            "*".to_string(),
            "*".to_string(),
            "*".repeat(1000),
            "*".repeat(2000),
        ]);

        let batch = RecordBatch::try_from_iter(vec![
            ("a1", Arc::new(array1) as _),
            ("a2", Arc::new(array2) as _),
            ("a3", Arc::new(array3) as _),
        ])
        .unwrap();

        // the allowed overage (5808) is larger than the 5k limit itself
        // overage is much higher than ideal
        // https://github.com/apache/arrow-rs/issues/3478
        verify_encoded_split(batch, 5808).await;
    }

    #[tokio::test]
    async fn flight_data_size_string_dictionary() {
        // Small dictionary (only 2 distinct values ==> 2 entries in dictionary)
        let array: DictionaryArray<Int32Type> = (1..1024)
            .map(|i| match i % 3 {
                0 => Some("value0"),
                1 => Some("value1"),
                _ => None,
            })
            .collect();

        let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();

        verify_encoded_split(batch, 56).await;
    }

    #[tokio::test]
    async fn flight_data_size_large_dictionary() {
        // large dictionary (all distinct values ==> 1023 entries in dictionary)
        let values: Vec<_> = (1..1024).map(|i| "**".repeat(i)).collect();

        let array: DictionaryArray<Int32Type> = values.iter().map(|s| Some(s.as_str())).collect();

        let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();

        // overage is much higher than ideal
        // https://github.com/apache/arrow-rs/issues/3478
        verify_encoded_split(batch, 3336).await;
    }

    #[tokio::test]
    async fn flight_data_size_large_dictionary_repeated_non_uniform() {
        // large dictionary (1024 distinct values) that are used throughout the array
        let values = StringArray::from_iter_values((0..1024).map(|i| "******".repeat(i)));
        let keys = Int32Array::from_iter_values((0..3000).map(|i| (3000 - i) % 1024));
        let array = DictionaryArray::new(keys, Arc::new(values));

        let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();

        // overage is much higher than ideal
        // https://github.com/apache/arrow-rs/issues/3478
        verify_encoded_split(batch, 5288).await;
    }

    #[tokio::test]
    async fn flight_data_size_multiple_dictionaries() {
        // high cardinality (1023 distinct values)
        let values1: Vec<_> = (1..1024).map(|i| "**".repeat(i)).collect();
        // low cardinality (10 distinct values)
        let values2: Vec<_> = (1..1024).map(|i| "**".repeat(i % 10)).collect();
        // medium cardinality (100 distinct values)
        let values3: Vec<_> = (1..1024).map(|i| "**".repeat(i % 100)).collect();

        let array1: DictionaryArray<Int32Type> = values1.iter().map(|s| Some(s.as_str())).collect();
        let array2: DictionaryArray<Int32Type> = values2.iter().map(|s| Some(s.as_str())).collect();
        let array3: DictionaryArray<Int32Type> = values3.iter().map(|s| Some(s.as_str())).collect();

        let batch = RecordBatch::try_from_iter(vec![
            ("a1", Arc::new(array1) as _),
            ("a2", Arc::new(array2) as _),
            ("a3", Arc::new(array3) as _),
        ])
        .unwrap();

        // overage is much higher than ideal
        // https://github.com/apache/arrow-rs/issues/3478
        verify_encoded_split(batch, 4136).await;
    }

    /// Return the in-memory size of the given [`FlightData`]
    fn flight_data_size(d: &FlightData) -> usize {
        let flight_descriptor_size = d
            .flight_descriptor
            .as_ref()
            .map(|descriptor| {
                let path_len: usize = descriptor.path.iter().map(|p| p.len()).sum();

                std::mem::size_of_val(descriptor) + descriptor.cmd.len() + path_len
            })
            .unwrap_or(0);

        flight_descriptor_size + d.app_metadata.len() + d.data_body.len() + d.data_header.len()
    }

    /// Coverage for <https://github.com/apache/arrow-rs/issues/3478>
    ///
    /// Encodes the specified batch using several values of
    /// `max_flight_data_size` between 1K and 5K and ensures that the
    /// resulting size of the flight data stays within the limit
    /// + `allowed_overage`
    ///
    /// `allowed_overage` is how far off the actual data encoding is
    /// from the target limit that was set. A decrease in
    /// `allowed_overage` is an improvement.
    ///
    /// Note this overhead will likely always be greater than zero to
    /// account for encoding overhead such as IPC headers and padding.
    async fn verify_encoded_split(batch: RecordBatch, allowed_overage: usize) {
        let num_rows = batch.num_rows();

        // Track the overall required maximum overage
        let mut max_overage_seen = 0;

        for max_flight_data_size in [1024, 2021, 5000] {
            println!("Encoding {num_rows} with a maximum size of {max_flight_data_size}");

            let mut stream = FlightDataEncoderBuilder::new()
                .with_max_flight_data_size(max_flight_data_size)
                // use 8-byte alignment - default alignment is 64 which produces bigger ipc data
                .with_options(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap())
                .build(futures::stream::iter([Ok(batch.clone())]));

            let mut i = 0;
            while let Some(data) = stream.next().await.transpose().unwrap() {
                let actual_data_size = flight_data_size(&data);

                let actual_overage = actual_data_size.saturating_sub(max_flight_data_size);

                assert!(
                    actual_overage <= allowed_overage,
                    "encoded data[{i}]: actual size {actual_data_size}, \
                         actual_overage: {actual_overage} \
                         allowed_overage: {allowed_overage}"
                );

                i += 1;

                max_overage_seen = max_overage_seen.max(actual_overage);
            }
        }

        // ensure that the specified overage is exactly the maximum so
        // that when the splitting logic improves, the tests must be
        // updated to reflect the better logic
        assert_eq!(
            allowed_overage, max_overage_seen,
            "Specified overage was too high"
        );
    }
}