use std::{collections::VecDeque, fmt::Debug, pin::Pin, sync::Arc, task::Poll};

use crate::{error::Result, FlightData, FlightDescriptor, SchemaAsIpc};

use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions, UnionArray};
use arrow_ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};

use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef, UnionMode};
use bytes::Bytes;
use futures::{ready, stream::BoxStream, Stream, StreamExt};

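/// Builds a [`FlightDataEncoder`] that converts a stream of [`RecordBatch`]es
/// into a stream of [`FlightData`] messages.
///
/// Illustrative sketch (not compiled as a doctest); `input_stream` stands in
/// for any `Stream` of `Result<RecordBatch>`:
///
/// ```ignore
/// // Encode the record batches with the default settings
/// let flight_data_stream = FlightDataEncoderBuilder::new()
///     .build(input_stream);
/// ```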
#[derive(Debug)]
pub struct FlightDataEncoderBuilder {
    max_flight_data_size: usize,
    options: IpcWriteOptions,
    app_metadata: Bytes,
    schema: Option<SchemaRef>,
    descriptor: Option<FlightDescriptor>,
    dictionary_handling: DictionaryHandling,
}

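/// Default target size, in bytes, of each encoded [`FlightData`] message
/// (2 MiB), intended to leave headroom under typical gRPC per-message limits.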
pub const GRPC_TARGET_MAX_FLIGHT_SIZE_BYTES: usize = 2097152;

impl Default for FlightDataEncoderBuilder {
    fn default() -> Self {
        Self {
            max_flight_data_size: GRPC_TARGET_MAX_FLIGHT_SIZE_BYTES,
            options: IpcWriteOptions::default(),
            app_metadata: Bytes::new(),
            schema: None,
            descriptor: None,
            dictionary_handling: DictionaryHandling::Hydrate,
        }
    }
}

impl FlightDataEncoderBuilder {
    pub fn new() -> Self {
        Self::default()
    }

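    /// Set the (approximate) maximum size, in bytes, of each encoded
    /// [`FlightData`] message. Larger [`RecordBatch`]es are split into multiple
    /// messages; the limit is approximate because it is based on in-memory sizes.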
    pub fn with_max_flight_data_size(mut self, max_flight_data_size: usize) -> Self {
        self.max_flight_data_size = max_flight_data_size;
        self
    }

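    /// Set how dictionary-encoded arrays are handled; see [`DictionaryHandling`].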
    pub fn with_dictionary_handling(mut self, dictionary_handling: DictionaryHandling) -> Self {
        self.dictionary_handling = dictionary_handling;
        self
    }

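    /// Set `app_metadata` to be attached to the first [`FlightData`] message
    /// (the schema message).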
    pub fn with_metadata(mut self, app_metadata: Bytes) -> Self {
        self.app_metadata = app_metadata;
        self
    }

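    /// Specify the [`IpcWriteOptions`] used when encoding the [`RecordBatch`]es.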
    pub fn with_options(mut self, options: IpcWriteOptions) -> Self {
        self.options = options;
        self
    }

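    /// Provide the schema of the data up front, so the schema message can be
    /// encoded and sent even before the first [`RecordBatch`] arrives.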
    pub fn with_schema(mut self, schema: SchemaRef) -> Self {
        self.schema = Some(schema);
        self
    }

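    /// Set an optional [`FlightDescriptor`] to be attached to the first
    /// [`FlightData`] message.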
    pub fn with_flight_descriptor(mut self, descriptor: Option<FlightDescriptor>) -> Self {
        self.descriptor = descriptor;
        self
    }

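    /// Consume the builder, returning a [`FlightDataEncoder`] that encodes
    /// `input` into a stream of [`FlightData`] messages.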
    pub fn build<S>(self, input: S) -> FlightDataEncoder
    where
        S: Stream<Item = Result<RecordBatch>> + Send + 'static,
    {
        let Self {
            max_flight_data_size,
            options,
            app_metadata,
            schema,
            descriptor,
            dictionary_handling,
        } = self;

        FlightDataEncoder::new(
            input.boxed(),
            schema,
            max_flight_data_size,
            options,
            app_metadata,
            descriptor,
            dictionary_handling,
        )
    }
}

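/// A [`Stream`] that encodes a stream of [`RecordBatch`]es into a stream of
/// [`FlightData`] messages; constructed via [`FlightDataEncoderBuilder`].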
pub struct FlightDataEncoder {
    inner: BoxStream<'static, Result<RecordBatch>>,
    schema: Option<SchemaRef>,
    max_flight_data_size: usize,
    encoder: FlightIpcEncoder,
    app_metadata: Option<Bytes>,
    queue: VecDeque<FlightData>,
    done: bool,
    descriptor: Option<FlightDescriptor>,
    dictionary_handling: DictionaryHandling,
}

impl FlightDataEncoder {
    fn new(
        inner: BoxStream<'static, Result<RecordBatch>>,
        schema: Option<SchemaRef>,
        max_flight_data_size: usize,
        options: IpcWriteOptions,
        app_metadata: Bytes,
        descriptor: Option<FlightDescriptor>,
        dictionary_handling: DictionaryHandling,
    ) -> Self {
        let mut encoder = Self {
            inner,
            schema: None,
            max_flight_data_size,
            encoder: FlightIpcEncoder::new(
                options,
                dictionary_handling != DictionaryHandling::Resend,
            ),
            app_metadata: Some(app_metadata),
            queue: VecDeque::new(),
            done: false,
            descriptor,
            dictionary_handling,
        };

        if let Some(schema) = schema {
            encoder.encode_schema(&schema);
        }

        encoder
    }

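    /// Return the schema of the encoded [`FlightData`], once known (after a
    /// schema has been supplied up front or the first [`RecordBatch`] has been
    /// encoded).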
    pub fn known_schema(&self) -> Option<SchemaRef> {
        self.schema.clone()
    }

    fn queue_message(&mut self, mut data: FlightData) {
        // Attach the flight descriptor (if any) only to the first queued message.
        if let Some(descriptor) = self.descriptor.take() {
            data.flight_descriptor = Some(descriptor);
        }
        self.queue.push_back(data);
    }

    fn queue_messages(&mut self, datas: impl IntoIterator<Item = FlightData>) {
        for data in datas {
            self.queue_message(data)
        }
    }

    fn encode_schema(&mut self, schema: &SchemaRef) -> SchemaRef {
        let send_dictionaries = self.dictionary_handling == DictionaryHandling::Resend;
        let schema = Arc::new(prepare_schema_for_flight(
            schema,
            &mut self.encoder.dictionary_tracker,
            send_dictionaries,
        ));
        let mut schema_flight_data = self.encoder.encode_schema(&schema);

        // Any application metadata is sent on the schema message only.
        if let Some(app_metadata) = self.app_metadata.take() {
            schema_flight_data.app_metadata = app_metadata;
        }
        self.queue_message(schema_flight_data);
        self.schema = Some(schema.clone());
        schema
    }

    fn encode_batch(&mut self, batch: RecordBatch) -> Result<()> {
        let schema = match &self.schema {
            Some(schema) => schema.clone(),
            None => self.encode_schema(batch.schema_ref()),
        };

        let batch = match self.dictionary_handling {
            DictionaryHandling::Resend => batch,
            DictionaryHandling::Hydrate => hydrate_dictionaries(&batch, schema)?,
        };

        for batch in split_batch_for_grpc_response(batch, self.max_flight_data_size) {
            let (flight_dictionaries, flight_batch) = self.encoder.encode_batch(&batch)?;

            self.queue_messages(flight_dictionaries);
            self.queue_message(flight_batch);
        }

        Ok(())
    }
}

impl Stream for FlightDataEncoder {
    type Item = Result<FlightData>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        loop {
            if self.done && self.queue.is_empty() {
                return Poll::Ready(None);
            }

            // Drain any already-encoded messages before polling the input again.
            if let Some(data) = self.queue.pop_front() {
                return Poll::Ready(Some(Ok(data)));
            }

            let batch = ready!(self.inner.poll_next_unpin(cx));

            match batch {
                None => {
                    self.done = true;
                    assert!(self.queue.is_empty());
                    return Poll::Ready(None);
                }
                Some(Err(e)) => {
                    self.done = true;
                    self.queue.clear();
                    return Poll::Ready(Some(Err(e)));
                }
                Some(Ok(batch)) => {
                    if let Err(e) = self.encode_batch(batch) {
                        self.done = true;
                        self.queue.clear();
                        return Poll::Ready(Some(Err(e)));
                    }
                }
            }
        }
    }
}

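/// Controls how dictionary-encoded data is represented in the encoded
/// [`FlightData`] stream.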
#[derive(Debug, PartialEq)]
pub enum DictionaryHandling {
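    /// Expand (hydrate) dictionary arrays into their value type before sending,
    /// so no separate dictionary messages are required. This is the default.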
    Hydrate,
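    /// Keep dictionary encoding and send dictionaries as separate [`FlightData`]
    /// messages; in this mode the [`DictionaryTracker`] is configured to allow
    /// dictionary replacement, so changed dictionaries can be resent.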
    Resend,
}

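/// Recursively rewrite a field for transport: when dictionaries are not being
/// sent, dictionary fields (including those nested inside lists, structs,
/// unions and maps) are replaced by their value type; otherwise dictionary ids
/// are tracked and the dictionary field is kept.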
fn prepare_field_for_flight(
    field: &FieldRef,
    dictionary_tracker: &mut DictionaryTracker,
    send_dictionaries: bool,
) -> Field {
    match field.data_type() {
        DataType::List(inner) => Field::new_list(
            field.name(),
            prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries),
            field.is_nullable(),
        )
        .with_metadata(field.metadata().clone()),
        DataType::LargeList(inner) => Field::new_list(
            field.name(),
            prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries),
            field.is_nullable(),
        )
        .with_metadata(field.metadata().clone()),
        DataType::Struct(fields) => {
            let new_fields: Vec<Field> = fields
                .iter()
                .map(|f| prepare_field_for_flight(f, dictionary_tracker, send_dictionaries))
                .collect();
            Field::new_struct(field.name(), new_fields, field.is_nullable())
                .with_metadata(field.metadata().clone())
        }
        DataType::Union(fields, mode) => {
            let (type_ids, new_fields): (Vec<i8>, Vec<Field>) = fields
                .iter()
                .map(|(type_id, f)| {
                    (
                        type_id,
                        prepare_field_for_flight(f, dictionary_tracker, send_dictionaries),
                    )
                })
                .unzip();

            Field::new_union(field.name(), type_ids, new_fields, *mode)
        }
        DataType::Dictionary(_, value_type) => {
            if !send_dictionaries {
                Field::new(
                    field.name(),
                    value_type.as_ref().clone(),
                    field.is_nullable(),
                )
                .with_metadata(field.metadata().clone())
            } else {
                dictionary_tracker.next_dict_id();
                #[allow(deprecated)]
                Field::new_dict(
                    field.name(),
                    field.data_type().clone(),
                    field.is_nullable(),
                    0,
                    field.dict_is_ordered().unwrap_or_default(),
                )
                .with_metadata(field.metadata().clone())
            }
        }
        DataType::Map(inner, sorted) => Field::new(
            field.name(),
            DataType::Map(
                prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries).into(),
                *sorted,
            ),
            field.is_nullable(),
        )
        .with_metadata(field.metadata().clone()),
        _ => field.as_ref().clone(),
    }
}

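/// Prepare a [`Schema`] for transport over Arrow Flight, rewriting dictionary
/// fields to their value type when dictionaries are hydrated rather than sent,
/// and tracking dictionary ids otherwise. Schema-level metadata is preserved.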
fn prepare_schema_for_flight(
    schema: &Schema,
    dictionary_tracker: &mut DictionaryTracker,
    send_dictionaries: bool,
) -> Schema {
    let fields: Fields = schema
        .fields()
        .iter()
        .map(|field| match field.data_type() {
            DataType::Dictionary(_, value_type) => {
                if !send_dictionaries {
                    Field::new(
                        field.name(),
                        value_type.as_ref().clone(),
                        field.is_nullable(),
                    )
                    .with_metadata(field.metadata().clone())
                } else {
                    dictionary_tracker.next_dict_id();
                    #[allow(deprecated)]
                    Field::new_dict(
                        field.name(),
                        field.data_type().clone(),
                        field.is_nullable(),
                        0,
                        field.dict_is_ordered().unwrap_or_default(),
                    )
                    .with_metadata(field.metadata().clone())
                }
            }
            tpe if tpe.is_nested() => {
                prepare_field_for_flight(field, dictionary_tracker, send_dictionaries)
            }
            _ => field.as_ref().clone(),
        })
        .collect();

    Schema::new(fields).with_metadata(schema.metadata().clone())
}

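/// Split a `RecordBatch` into one or more smaller batches so each encoded
/// message stays near `max_flight_data_size` bytes. The split is based on the
/// in-memory buffer sizes of the arrays, so it is only an approximation of the
/// encoded size.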
fn split_batch_for_grpc_response(
    batch: RecordBatch,
    max_flight_data_size: usize,
) -> Vec<RecordBatch> {
    let size = batch
        .columns()
        .iter()
        .map(|col| col.get_buffer_memory_size())
        .sum::<usize>();

    let n_batches =
        (size / max_flight_data_size + usize::from(size % max_flight_data_size != 0)).max(1);
    let rows_per_batch = (batch.num_rows() / n_batches).max(1);
    let mut out = Vec::with_capacity(n_batches + 1);

    let mut offset = 0;
    while offset < batch.num_rows() {
        let length = (rows_per_batch).min(batch.num_rows() - offset);
        out.push(batch.slice(offset, length));

        offset += length;
    }

    out
}

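/// Wraps an [`IpcDataGenerator`] and [`DictionaryTracker`] to encode schemas
/// and record batches into [`FlightData`] messages using the Arrow IPC format.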
struct FlightIpcEncoder {
    options: IpcWriteOptions,
    data_gen: IpcDataGenerator,
    dictionary_tracker: DictionaryTracker,
}

impl FlightIpcEncoder {
    fn new(options: IpcWriteOptions, error_on_replacement: bool) -> Self {
        Self {
            options,
            data_gen: IpcDataGenerator::default(),
            dictionary_tracker: DictionaryTracker::new(error_on_replacement),
        }
    }

    fn encode_schema(&self, schema: &Schema) -> FlightData {
        SchemaAsIpc::new(schema, &self.options).into()
    }

    fn encode_batch(&mut self, batch: &RecordBatch) -> Result<(Vec<FlightData>, FlightData)> {
        let (encoded_dictionaries, encoded_batch) =
            self.data_gen
                .encoded_batch(batch, &mut self.dictionary_tracker, &self.options)?;

        let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect();
        let flight_batch = encoded_batch.into();

        Ok((flight_dictionaries, flight_batch))
    }
}

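/// Hydrate any dictionary arrays in `batch` by casting each column to the
/// corresponding (non-dictionary) data type in `schema`.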
fn hydrate_dictionaries(batch: &RecordBatch, schema: SchemaRef) -> Result<RecordBatch> {
    let columns = schema
        .fields()
        .iter()
        .zip(batch.columns())
        .map(|(field, c)| hydrate_dictionary(c, field.data_type()))
        .collect::<Result<Vec<_>>>()?;

    let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));

    Ok(RecordBatch::try_new_with_options(
        schema, columns, &options,
    )?)
}

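/// Hydrate a single array: sparse unions are rebuilt with each child cast to
/// the target child type; all other arrays are cast directly to `data_type`.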
fn hydrate_dictionary(array: &ArrayRef, data_type: &DataType) -> Result<ArrayRef> {
    let arr = match (array.data_type(), data_type) {
        (DataType::Union(_, UnionMode::Sparse), DataType::Union(fields, UnionMode::Sparse)) => {
            let union_arr = array.as_any().downcast_ref::<UnionArray>().unwrap();

            Arc::new(UnionArray::try_new(
                fields.clone(),
                union_arr.type_ids().clone(),
                None,
                fields
                    .iter()
                    .map(|(type_id, field)| {
                        Ok(arrow_cast::cast(
                            union_arr.child(type_id),
                            field.data_type(),
                        )?)
                    })
                    .collect::<Result<Vec<_>>>()?,
            )?)
        }
        (_, data_type) => arrow_cast::cast(array, data_type)?,
    };
    Ok(arr)
}

#[cfg(test)]
mod tests {
    use crate::decode::{DecodedPayload, FlightDataDecoder};
    use arrow_array::builder::{
        GenericByteDictionaryBuilder, ListBuilder, StringDictionaryBuilder, StructBuilder,
    };
    use arrow_array::*;
    use arrow_array::{cast::downcast_array, types::*};
    use arrow_buffer::ScalarBuffer;
    use arrow_cast::pretty::pretty_format_batches;
    use arrow_ipc::MetadataVersion;
    use arrow_schema::{UnionFields, UnionMode};
    use builder::{GenericStringBuilder, MapBuilder};
    use std::collections::HashMap;

    use super::*;

    #[test]
    fn test_encode_flight_data() {
        let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap();
        let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);

        let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c1) as ArrayRef)])
            .expect("cannot create record batch");
        let schema = batch.schema_ref();

        let (_, baseline_flight_batch) = make_flight_data(&batch, &options);

        let big_batch = batch.slice(0, batch.num_rows() - 1);
        let optimized_big_batch =
            hydrate_dictionaries(&big_batch, Arc::clone(schema)).expect("failed to optimize");
        let (_, optimized_big_flight_batch) = make_flight_data(&optimized_big_batch, &options);

        assert_eq!(
            baseline_flight_batch.data_body.len(),
            optimized_big_flight_batch.data_body.len()
        );

        let small_batch = batch.slice(0, 1);
        let optimized_small_batch =
            hydrate_dictionaries(&small_batch, Arc::clone(schema)).expect("failed to optimize");
        let (_, optimized_small_flight_batch) = make_flight_data(&optimized_small_batch, &options);

        assert!(
            baseline_flight_batch.data_body.len() > optimized_small_flight_batch.data_body.len()
        );
    }

    #[tokio::test]
    async fn test_dictionary_hydration() {
        let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
        let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();

        let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
            "dict",
            DataType::UInt16,
            DataType::Utf8,
            false,
        )]));
        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

        let encoder = FlightDataEncoderBuilder::default().build(stream);
        let mut decoder = FlightDataDecoder::new(encoder);
        let expected_schema = Schema::new(vec![Field::new("dict", DataType::Utf8, false)]);
        let expected_schema = Arc::new(expected_schema);
        let mut expected_arrays = vec![
            StringArray::from(vec!["a", "a", "b"]),
            StringArray::from(vec!["c", "c", "d"]),
        ]
        .into_iter();
        while let Some(decoded) = decoder.next().await {
            let decoded = decoded.unwrap();
            match decoded.payload {
                DecodedPayload::None => {}
                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
                DecodedPayload::RecordBatch(b) => {
                    assert_eq!(b.schema(), expected_schema);
                    let expected_array = expected_arrays.next().unwrap();
                    let actual_array = b.column_by_name("dict").unwrap();
                    let actual_array = downcast_array::<StringArray>(actual_array);

                    assert_eq!(actual_array, expected_array);
                }
            }
        }
    }

    #[tokio::test]
    async fn test_dictionary_resend() {
        let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
        let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();

        let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
            "dict",
            DataType::UInt16,
            DataType::Utf8,
            false,
        )]));
        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();

        verify_flight_round_trip(vec![batch1, batch2]).await;
    }

    #[tokio::test]
    async fn test_dictionary_hydration_known_schema() {
        let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
        let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();

        let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
            "dict",
            DataType::UInt16,
            DataType::Utf8,
            false,
        )]));
        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

        let encoder = FlightDataEncoderBuilder::default()
            .with_schema(schema)
            .build(stream);
        let expected_schema =
            Arc::new(Schema::new(vec![Field::new("dict", DataType::Utf8, false)]));
        assert_eq!(Some(expected_schema), encoder.known_schema())
    }

    #[tokio::test]
    async fn test_dictionary_resend_known_schema() {
        let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
        let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();

        let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
            "dict",
            DataType::UInt16,
            DataType::Utf8,
            false,
        )]));
        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

        let encoder = FlightDataEncoderBuilder::default()
            .with_dictionary_handling(DictionaryHandling::Resend)
            .with_schema(schema.clone())
            .build(stream);
        assert_eq!(Some(schema), encoder.known_schema())
    }

    #[tokio::test]
    async fn test_multiple_dictionaries_resend() {
        let schema = Arc::new(Schema::new(vec![
            Field::new_dictionary("dict_1", DataType::UInt16, DataType::Utf8, false),
            Field::new_dictionary("dict_2", DataType::UInt16, DataType::Utf8, false),
        ]));

        let arr_one_1: Arc<DictionaryArray<UInt16Type>> =
            Arc::new(vec!["a", "a", "b"].into_iter().collect());
        let arr_one_2: Arc<DictionaryArray<UInt16Type>> =
            Arc::new(vec!["c", "c", "d"].into_iter().collect());
        let arr_two_1: Arc<DictionaryArray<UInt16Type>> =
            Arc::new(vec!["b", "a", "c"].into_iter().collect());
        let arr_two_2: Arc<DictionaryArray<UInt16Type>> =
            Arc::new(vec!["k", "d", "e"].into_iter().collect());
        let batch1 =
            RecordBatch::try_new(schema.clone(), vec![arr_one_1.clone(), arr_one_2.clone()])
                .unwrap();
        let batch2 =
            RecordBatch::try_new(schema.clone(), vec![arr_two_1.clone(), arr_two_2.clone()])
                .unwrap();

        verify_flight_round_trip(vec![batch1, batch2]).await;
    }

    #[tokio::test]
    async fn test_dictionary_list_hydration() {
        let mut builder = ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());

        builder.append_value(vec![Some("a"), None, Some("b")]);

        let arr1 = builder.finish();

        builder.append_value(vec![Some("c"), None, Some("d")]);

        let arr2 = builder.finish();

        let schema = Arc::new(Schema::new(vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

        let encoder = FlightDataEncoderBuilder::default().build(stream);

        let mut decoder = FlightDataDecoder::new(encoder);
        let expected_schema = Schema::new(vec![Field::new_list(
            "dict_list",
            Field::new_list_field(DataType::Utf8, true),
            true,
        )]);

        let expected_schema = Arc::new(expected_schema);

        let mut expected_arrays = vec![
            StringArray::from_iter(vec![Some("a"), None, Some("b")]),
            StringArray::from_iter(vec![Some("c"), None, Some("d")]),
        ]
        .into_iter();

        while let Some(decoded) = decoder.next().await {
            let decoded = decoded.unwrap();
            match decoded.payload {
                DecodedPayload::None => {}
                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
                DecodedPayload::RecordBatch(b) => {
                    assert_eq!(b.schema(), expected_schema);
                    let expected_array = expected_arrays.next().unwrap();
                    let list_array =
                        downcast_array::<ListArray>(b.column_by_name("dict_list").unwrap());
                    let elem_array = downcast_array::<StringArray>(list_array.value(0).as_ref());

                    assert_eq!(elem_array, expected_array);
                }
            }
        }
    }

    #[tokio::test]
    async fn test_dictionary_list_resend() {
        let mut builder = ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());

        builder.append_value(vec![Some("a"), None, Some("b")]);

        let arr1 = builder.finish();

        builder.append_value(vec![Some("c"), None, Some("d")]);

        let arr2 = builder.finish();

        let schema = Arc::new(Schema::new(vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

        verify_flight_round_trip(vec![batch1, batch2]).await;
    }

    #[tokio::test]
    async fn test_dictionary_struct_hydration() {
        let struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )];

        let mut struct_builder = StructBuilder::new(
            struct_fields.clone(),
            vec![Box::new(builder::ListBuilder::new(
                StringDictionaryBuilder::<UInt16Type>::new(),
            ))],
        );

        struct_builder
            .field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type,GenericStringType<i32>>>>(0)
            .unwrap()
            .append_value(vec![Some("a"), None, Some("b")]);

        struct_builder.append(true);

        let arr1 = struct_builder.finish();

        struct_builder
            .field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type,GenericStringType<i32>>>>(0)
            .unwrap()
            .append_value(vec![Some("c"), None, Some("d")]);
        struct_builder.append(true);

        let arr2 = struct_builder.finish();

        let schema = Arc::new(Schema::new(vec![Field::new_struct(
            "struct",
            struct_fields,
            true,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

        let encoder = FlightDataEncoderBuilder::default().build(stream);

        let mut decoder = FlightDataDecoder::new(encoder);
        let expected_schema = Schema::new(vec![Field::new_struct(
            "struct",
            vec![Field::new_list(
                "dict_list",
                Field::new_list_field(DataType::Utf8, true),
                true,
            )],
            true,
        )]);

        let expected_schema = Arc::new(expected_schema);

        let mut expected_arrays = vec![
            StringArray::from_iter(vec![Some("a"), None, Some("b")]),
            StringArray::from_iter(vec![Some("c"), None, Some("d")]),
        ]
        .into_iter();

        while let Some(decoded) = decoder.next().await {
            let decoded = decoded.unwrap();
            match decoded.payload {
                DecodedPayload::None => {}
                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
                DecodedPayload::RecordBatch(b) => {
                    assert_eq!(b.schema(), expected_schema);
                    let expected_array = expected_arrays.next().unwrap();
                    let struct_array =
                        downcast_array::<StructArray>(b.column_by_name("struct").unwrap());
                    let list_array = downcast_array::<ListArray>(struct_array.column(0));

                    let elem_array = downcast_array::<StringArray>(list_array.value(0).as_ref());

                    assert_eq!(elem_array, expected_array);
                }
            }
        }
    }

    #[tokio::test]
    async fn test_dictionary_struct_resend() {
        let struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )];

        let mut struct_builder = StructBuilder::new(
            struct_fields.clone(),
            vec![Box::new(builder::ListBuilder::new(
                StringDictionaryBuilder::<UInt16Type>::new(),
            ))],
        );

        struct_builder.field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type,GenericStringType<i32>>>>(0)
            .unwrap()
            .append_value(vec![Some("a"), None, Some("b")]);
        struct_builder.append(true);

        let arr1 = struct_builder.finish();

        struct_builder.field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type,GenericStringType<i32>>>>(0)
            .unwrap()
            .append_value(vec![Some("c"), None, Some("d")]);
        struct_builder.append(true);

        let arr2 = struct_builder.finish();

        let schema = Arc::new(Schema::new(vec![Field::new_struct(
            "struct",
            struct_fields,
            true,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();

        verify_flight_round_trip(vec![batch1, batch2]).await;
    }

    #[tokio::test]
    async fn test_dictionary_union_hydration() {
        let struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )];

        let union_fields = [
            (
                0,
                Arc::new(Field::new_list(
                    "dict_list",
                    Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
                    true,
                )),
            ),
            (
                1,
                Arc::new(Field::new_struct("struct", struct_fields.clone(), true)),
            ),
            (2, Arc::new(Field::new("string", DataType::Utf8, true))),
        ]
        .into_iter()
        .collect::<UnionFields>();

        let struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )];

        let mut builder = builder::ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());

        builder.append_value(vec![Some("a"), None, Some("b")]);

        let arr1 = builder.finish();

        let type_id_buffer = [0].into_iter().collect::<ScalarBuffer<i8>>();
        let arr1 = UnionArray::try_new(
            union_fields.clone(),
            type_id_buffer,
            None,
            vec![
                Arc::new(arr1) as Arc<dyn Array>,
                new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1),
                new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1),
            ],
        )
        .unwrap();

        builder.append_value(vec![Some("c"), None, Some("d")]);

        let arr2 = Arc::new(builder.finish());
        let arr2 = StructArray::new(struct_fields.clone().into(), vec![arr2], None);

        let type_id_buffer = [1].into_iter().collect::<ScalarBuffer<i8>>();
        let arr2 = UnionArray::try_new(
            union_fields.clone(),
            type_id_buffer,
            None,
            vec![
                new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1),
                Arc::new(arr2),
                new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1),
            ],
        )
        .unwrap();

        let type_id_buffer = [2].into_iter().collect::<ScalarBuffer<i8>>();
        let arr3 = UnionArray::try_new(
            union_fields.clone(),
            type_id_buffer,
            None,
            vec![
                new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1),
                new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1),
                Arc::new(StringArray::from(vec!["e"])),
            ],
        )
        .unwrap();

        let (type_ids, union_fields): (Vec<_>, Vec<_>) = union_fields
            .iter()
            .map(|(type_id, field_ref)| (type_id, (*Arc::clone(field_ref)).clone()))
            .unzip();
        let schema = Arc::new(Schema::new(vec![Field::new_union(
            "union",
            type_ids.clone(),
            union_fields.clone(),
            UnionMode::Sparse,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
        let batch3 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr3)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2), Ok(batch3)]);

        let encoder = FlightDataEncoderBuilder::default().build(stream);

        let mut decoder = FlightDataDecoder::new(encoder);

        let hydrated_struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_list_field(DataType::Utf8, true),
            true,
        )];

        let hydrated_union_fields = vec![
            Field::new_list(
                "dict_list",
                Field::new_list_field(DataType::Utf8, true),
                true,
            ),
            Field::new_struct("struct", hydrated_struct_fields.clone(), true),
            Field::new("string", DataType::Utf8, true),
        ];

        let expected_schema = Schema::new(vec![Field::new_union(
            "union",
            type_ids.clone(),
            hydrated_union_fields,
            UnionMode::Sparse,
        )]);

        let expected_schema = Arc::new(expected_schema);

        let mut expected_arrays = vec![
            StringArray::from_iter(vec![Some("a"), None, Some("b")]),
            StringArray::from_iter(vec![Some("c"), None, Some("d")]),
            StringArray::from(vec!["e"]),
        ]
        .into_iter();

        let mut batch = 0;
        while let Some(decoded) = decoder.next().await {
            let decoded = decoded.unwrap();
            match decoded.payload {
                DecodedPayload::None => {}
                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
                DecodedPayload::RecordBatch(b) => {
                    assert_eq!(b.schema(), expected_schema);
                    let expected_array = expected_arrays.next().unwrap();
                    let union_arr =
                        downcast_array::<UnionArray>(b.column_by_name("union").unwrap());

                    let elem_array = match batch {
                        0 => {
                            let list_array = downcast_array::<ListArray>(union_arr.child(0));
                            downcast_array::<StringArray>(list_array.value(0).as_ref())
                        }
                        1 => {
                            let struct_array = downcast_array::<StructArray>(union_arr.child(1));
                            let list_array = downcast_array::<ListArray>(struct_array.column(0));

                            downcast_array::<StringArray>(list_array.value(0).as_ref())
                        }
                        _ => downcast_array::<StringArray>(union_arr.child(2)),
                    };

                    batch += 1;

                    assert_eq!(elem_array, expected_array);
                }
            }
        }
    }

    #[tokio::test]
    async fn test_dictionary_union_resend() {
        let struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )];

        let union_fields = [
            (
                0,
                Arc::new(Field::new_list(
                    "dict_list",
                    Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
                    true,
                )),
            ),
            (
                1,
                Arc::new(Field::new_struct("struct", struct_fields.clone(), true)),
            ),
            (2, Arc::new(Field::new("string", DataType::Utf8, true))),
        ]
        .into_iter()
        .collect::<UnionFields>();

        let mut field_types = union_fields.iter().map(|(_, field)| field.data_type());
        let dict_list_ty = field_types.next().unwrap();
        let struct_ty = field_types.next().unwrap();
        let string_ty = field_types.next().unwrap();

        let struct_fields = vec![Field::new_list(
            "dict_list",
            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
            true,
        )];

        let mut builder = builder::ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());

        builder.append_value(vec![Some("a"), None, Some("b")]);

        let arr1 = builder.finish();

        let type_id_buffer = [0].into_iter().collect::<ScalarBuffer<i8>>();
        let arr1 = UnionArray::try_new(
            union_fields.clone(),
            type_id_buffer,
            None,
            vec![
                Arc::new(arr1),
                new_null_array(struct_ty, 1),
                new_null_array(string_ty, 1),
            ],
        )
        .unwrap();

        builder.append_value(vec![Some("c"), None, Some("d")]);

        let arr2 = Arc::new(builder.finish());
        let arr2 = StructArray::new(struct_fields.clone().into(), vec![arr2], None);

        let type_id_buffer = [1].into_iter().collect::<ScalarBuffer<i8>>();
        let arr2 = UnionArray::try_new(
            union_fields.clone(),
            type_id_buffer,
            None,
            vec![
                new_null_array(dict_list_ty, 1),
                Arc::new(arr2),
                new_null_array(string_ty, 1),
            ],
        )
        .unwrap();

        let type_id_buffer = [2].into_iter().collect::<ScalarBuffer<i8>>();
        let arr3 = UnionArray::try_new(
            union_fields.clone(),
            type_id_buffer,
            None,
            vec![
                new_null_array(dict_list_ty, 1),
                new_null_array(struct_ty, 1),
                Arc::new(StringArray::from(vec!["e"])),
            ],
        )
        .unwrap();

        let (type_ids, union_fields): (Vec<_>, Vec<_>) = union_fields
            .iter()
            .map(|(type_id, field_ref)| (type_id, (*Arc::clone(field_ref)).clone()))
            .unzip();
        let schema = Arc::new(Schema::new(vec![Field::new_union(
            "union",
            type_ids.clone(),
            union_fields.clone(),
            UnionMode::Sparse,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
        let batch3 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr3)]).unwrap();

        verify_flight_round_trip(vec![batch1, batch2, batch3]).await;
    }

    #[tokio::test]
    async fn test_dictionary_map_hydration() {
        let mut builder = MapBuilder::new(
            None,
            StringDictionaryBuilder::<UInt16Type>::new(),
            StringDictionaryBuilder::<UInt16Type>::new(),
        );

        builder.keys().append_value("k1");
        builder.values().append_value("a");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("b");
        builder.append(true).unwrap();

        let arr1 = builder.finish();

        builder.keys().append_value("k1");
        builder.values().append_value("c");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("d");
        builder.append(true).unwrap();

        let arr2 = builder.finish();

        let schema = Arc::new(Schema::new(vec![Field::new_map(
            "dict_map",
            "entries",
            Field::new_dictionary("keys", DataType::UInt16, DataType::Utf8, false),
            Field::new_dictionary("values", DataType::UInt16, DataType::Utf8, true),
            false,
            false,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

        let encoder = FlightDataEncoderBuilder::default().build(stream);

        let mut decoder = FlightDataDecoder::new(encoder);
        let expected_schema = Schema::new(vec![Field::new_map(
            "dict_map",
            "entries",
            Field::new("keys", DataType::Utf8, false),
            Field::new("values", DataType::Utf8, true),
            false,
            false,
        )]);

        let expected_schema = Arc::new(expected_schema);

        let mut builder = MapBuilder::new(
            None,
            GenericStringBuilder::<i32>::new(),
            GenericStringBuilder::<i32>::new(),
        );

        builder.keys().append_value("k1");
        builder.values().append_value("a");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("b");
        builder.append(true).unwrap();

        let arr1 = builder.finish();

        builder.keys().append_value("k1");
        builder.values().append_value("c");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("d");
        builder.append(true).unwrap();

        let arr2 = builder.finish();

        let mut expected_arrays = vec![arr1, arr2].into_iter();

        while let Some(decoded) = decoder.next().await {
            let decoded = decoded.unwrap();
            match decoded.payload {
                DecodedPayload::None => {}
                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
                DecodedPayload::RecordBatch(b) => {
                    assert_eq!(b.schema(), expected_schema);
                    let expected_array = expected_arrays.next().unwrap();
                    let map_array =
                        downcast_array::<MapArray>(b.column_by_name("dict_map").unwrap());

                    assert_eq!(map_array, expected_array);
                }
            }
        }
    }

    #[tokio::test]
    async fn test_dictionary_map_resend() {
        let mut builder = MapBuilder::new(
            None,
            StringDictionaryBuilder::<UInt16Type>::new(),
            StringDictionaryBuilder::<UInt16Type>::new(),
        );

        builder.keys().append_value("k1");
        builder.values().append_value("a");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("b");
        builder.append(true).unwrap();

        let arr1 = builder.finish();

        builder.keys().append_value("k1");
        builder.values().append_value("c");
        builder.keys().append_value("k2");
        builder.values().append_null();
        builder.keys().append_value("k3");
        builder.values().append_value("d");
        builder.append(true).unwrap();

        let arr2 = builder.finish();

        let schema = Arc::new(Schema::new(vec![Field::new_map(
            "dict_map",
            "entries",
            Field::new_dictionary("keys", DataType::UInt16, DataType::Utf8, false),
            Field::new_dictionary("values", DataType::UInt16, DataType::Utf8, true),
            false,
            false,
        )]));

        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

        verify_flight_round_trip(vec![batch1, batch2]).await;
    }

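    /// Encode the given batches with `DictionaryHandling::Resend`, decode the
    /// resulting stream, and assert that schema and data round-trip unchanged.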
    async fn verify_flight_round_trip(mut batches: Vec<RecordBatch>) {
        let expected_schema = batches.first().unwrap().schema();

        let encoder = FlightDataEncoderBuilder::default()
            .with_options(IpcWriteOptions::default())
            .with_dictionary_handling(DictionaryHandling::Resend)
            .build(futures::stream::iter(batches.clone().into_iter().map(Ok)));

        let mut expected_batches = batches.drain(..);

        let mut decoder = FlightDataDecoder::new(encoder);
        while let Some(decoded) = decoder.next().await {
            let decoded = decoded.unwrap();
            match decoded.payload {
                DecodedPayload::None => {}
                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
                DecodedPayload::RecordBatch(b) => {
                    let expected_batch = expected_batches.next().unwrap();
                    assert_eq!(b, expected_batch);
                }
            }
        }
    }

    #[test]
    fn test_schema_metadata_encoded() {
        let schema = Schema::new(vec![Field::new("data", DataType::Int32, false)]).with_metadata(
            HashMap::from([("some_key".to_owned(), "some_value".to_owned())]),
        );

        let mut dictionary_tracker = DictionaryTracker::new(false);

        let got = prepare_schema_for_flight(&schema, &mut dictionary_tracker, false);
        assert!(got.metadata().contains_key("some_key"));
    }

    #[test]
    fn test_encode_no_column_batch() {
        let batch = RecordBatch::try_new_with_options(
            Arc::new(Schema::empty()),
            vec![],
            &RecordBatchOptions::new().with_row_count(Some(10)),
        )
        .expect("cannot create record batch");

        hydrate_dictionaries(&batch, batch.schema()).expect("failed to optimize");
    }

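    /// Test helper: encode a `RecordBatch` directly with an [`IpcDataGenerator`],
    /// returning the dictionary messages and the data message.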
    fn make_flight_data(
        batch: &RecordBatch,
        options: &IpcWriteOptions,
    ) -> (Vec<FlightData>, FlightData) {
        flight_data_from_arrow_batch(batch, options)
    }

    fn flight_data_from_arrow_batch(
        batch: &RecordBatch,
        options: &IpcWriteOptions,
    ) -> (Vec<FlightData>, FlightData) {
        let data_gen = IpcDataGenerator::default();
        let mut dictionary_tracker = DictionaryTracker::new(false);

        let (encoded_dictionaries, encoded_batch) = data_gen
            .encoded_batch(batch, &mut dictionary_tracker, options)
            .expect("DictionaryTracker configured above to not error on replacement");

        let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect();
        let flight_batch = encoded_batch.into();

        (flight_dictionaries, flight_batch)
    }

    #[test]
    fn test_split_batch_for_grpc_response() {
        let max_flight_data_size = 1024;

        let c = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
        let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)])
            .expect("cannot create record batch");
        let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size);
        assert_eq!(split.len(), 1);
        assert_eq!(batch, split[0]);

        let n_rows = max_flight_data_size + 1;
        assert!(n_rows % 2 == 1, "should be an odd number");
        let c = UInt8Array::from((0..n_rows).map(|i| (i % 256) as u8).collect::<Vec<_>>());
        let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)])
            .expect("cannot create record batch");
        let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size);
        assert_eq!(split.len(), 3);
        assert_eq!(
            split.iter().map(|batch| batch.num_rows()).sum::<usize>(),
            n_rows
        );
        let a = pretty_format_batches(&split).unwrap().to_string();
        let b = pretty_format_batches(&[batch]).unwrap().to_string();
        assert_eq!(a, b);
    }

    #[test]
    fn test_split_batch_for_grpc_response_sizes() {
        verify_split(2000, 2 * 1024, vec![250, 250, 250, 250, 250, 250, 250, 250]);

        verify_split(2000, 4 * 1024, vec![500, 500, 500, 500]);

        verify_split(2023, 3 * 1024, vec![337, 337, 337, 337, 337, 337, 1]);

        verify_split(10, 1, vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1]);

        verify_split(10, 1024, vec![10]);
    }

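    /// Build a single-column `UInt64` batch with `num_input_rows` rows, split it
    /// with the given size limit, and assert the resulting row counts match
    /// `expected_sizes`.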
    fn verify_split(
        num_input_rows: u64,
        max_flight_data_size_bytes: usize,
        expected_sizes: Vec<usize>,
    ) {
        let array: UInt64Array = (0..num_input_rows).collect();

        let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(array) as ArrayRef)])
            .expect("cannot create record batch");

        let input_rows = batch.num_rows();

        let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size_bytes);
        let sizes: Vec<_> = split.iter().map(RecordBatch::num_rows).collect();
        let output_rows: usize = sizes.iter().sum();

        assert_eq!(sizes, expected_sizes, "mismatch for {batch:?}");
        assert_eq!(input_rows, output_rows, "mismatch for {batch:?}");
    }

    #[tokio::test]
    async fn flight_data_size_even() {
        let s1 = StringArray::from_iter_values(std::iter::repeat_n(".10 bytes.", 1024));
        let i1 = Int16Array::from_iter_values(0..1024);
        let s2 = StringArray::from_iter_values(std::iter::repeat_n("6bytes", 1024));
        let i2 = Int64Array::from_iter_values(0..1024);

        let batch = RecordBatch::try_from_iter(vec![
            ("s1", Arc::new(s1) as _),
            ("i1", Arc::new(i1) as _),
            ("s2", Arc::new(s2) as _),
            ("i2", Arc::new(i2) as _),
        ])
        .unwrap();

        verify_encoded_split(batch, 120).await;
    }

    #[tokio::test]
    async fn flight_data_size_uneven_variable_lengths() {
        let array = StringArray::from_iter_values((0..1024).map(|i| "*".repeat(i)));
        let batch = RecordBatch::try_from_iter(vec![("data", Arc::new(array) as _)]).unwrap();

        verify_encoded_split(batch, 4312).await;
    }

    #[tokio::test]
    async fn flight_data_size_large_row() {
        let array1 = StringArray::from_iter_values(vec![
            "*".repeat(500),
            "*".repeat(500),
            "*".repeat(500),
            "*".repeat(500),
        ]);
        let array2 = StringArray::from_iter_values(vec![
            "*".to_string(),
            "*".repeat(1000),
            "*".repeat(2000),
            "*".repeat(4000),
        ]);

        let array3 = StringArray::from_iter_values(vec![
            "*".to_string(),
            "*".to_string(),
            "*".repeat(1000),
            "*".repeat(2000),
        ]);

        let batch = RecordBatch::try_from_iter(vec![
            ("a1", Arc::new(array1) as _),
            ("a2", Arc::new(array2) as _),
            ("a3", Arc::new(array3) as _),
        ])
        .unwrap();

        verify_encoded_split(batch, 5808).await;
    }

    #[tokio::test]
    async fn flight_data_size_string_dictionary() {
        let array: DictionaryArray<Int32Type> = (1..1024)
            .map(|i| match i % 3 {
                0 => Some("value0"),
                1 => Some("value1"),
                _ => None,
            })
            .collect();

        let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();

        verify_encoded_split(batch, 56).await;
    }

    #[tokio::test]
    async fn flight_data_size_large_dictionary() {
        let values: Vec<_> = (1..1024).map(|i| "**".repeat(i)).collect();

        let array: DictionaryArray<Int32Type> = values.iter().map(|s| Some(s.as_str())).collect();

        let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();

        verify_encoded_split(batch, 3336).await;
    }

    #[tokio::test]
    async fn flight_data_size_large_dictionary_repeated_non_uniform() {
        let values = StringArray::from_iter_values((0..1024).map(|i| "******".repeat(i)));
        let keys = Int32Array::from_iter_values((0..3000).map(|i| (3000 - i) % 1024));
        let array = DictionaryArray::new(keys, Arc::new(values));

        let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();

        verify_encoded_split(batch, 5288).await;
    }

    #[tokio::test]
    async fn flight_data_size_multiple_dictionaries() {
        let values1: Vec<_> = (1..1024).map(|i| "**".repeat(i)).collect();
        let values2: Vec<_> = (1..1024).map(|i| "**".repeat(i % 10)).collect();
        let values3: Vec<_> = (1..1024).map(|i| "**".repeat(i % 100)).collect();

        let array1: DictionaryArray<Int32Type> = values1.iter().map(|s| Some(s.as_str())).collect();
        let array2: DictionaryArray<Int32Type> = values2.iter().map(|s| Some(s.as_str())).collect();
        let array3: DictionaryArray<Int32Type> = values3.iter().map(|s| Some(s.as_str())).collect();

        let batch = RecordBatch::try_from_iter(vec![
            ("a1", Arc::new(array1) as _),
            ("a2", Arc::new(array2) as _),
            ("a3", Arc::new(array3) as _),
        ])
        .unwrap();

        verify_encoded_split(batch, 4136).await;
    }

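    /// Approximate in-memory size of a `FlightData` message, counting the
    /// descriptor, app metadata, data header and data body.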
    fn flight_data_size(d: &FlightData) -> usize {
        let flight_descriptor_size = d
            .flight_descriptor
            .as_ref()
            .map(|descriptor| {
                let path_len: usize = descriptor.path.iter().map(|p| p.len()).sum();

                std::mem::size_of_val(descriptor) + descriptor.cmd.len() + path_len
            })
            .unwrap_or(0);

        flight_descriptor_size + d.app_metadata.len() + d.data_body.len() + d.data_header.len()
    }

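    /// Encode `batch` with several `max_flight_data_size` settings and assert
    /// that no encoded message exceeds the target by more than `allowed_overage`
    /// bytes, and that `allowed_overage` is not larger than actually needed.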
    async fn verify_encoded_split(batch: RecordBatch, allowed_overage: usize) {
        let num_rows = batch.num_rows();

        let mut max_overage_seen = 0;

        for max_flight_data_size in [1024, 2021, 5000] {
            println!("Encoding {num_rows} with a maximum size of {max_flight_data_size}");

            let mut stream = FlightDataEncoderBuilder::new()
                .with_max_flight_data_size(max_flight_data_size)
                .with_options(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap())
                .build(futures::stream::iter([Ok(batch.clone())]));

            let mut i = 0;
            while let Some(data) = stream.next().await.transpose().unwrap() {
                let actual_data_size = flight_data_size(&data);

                let actual_overage = actual_data_size.saturating_sub(max_flight_data_size);

                assert!(
                    actual_overage <= allowed_overage,
                    "encoded data[{i}]: actual size {actual_data_size}, \
                        actual_overage: {actual_overage} \
                        allowed_overage: {allowed_overage}"
                );

                i += 1;

                max_overage_seen = max_overage_seen.max(actual_overage)
            }
        }

        assert_eq!(
            allowed_overage, max_overage_seen,
            "Specified overage was too high"
        );
    }
}