1use std::{collections::VecDeque, fmt::Debug, pin::Pin, sync::Arc, task::Poll};
19
20use crate::{error::Result, FlightData, FlightDescriptor, SchemaAsIpc};
21
22use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions, UnionArray};
23use arrow_ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
24
25use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef, UnionMode};
26use bytes::Bytes;
27use futures::{ready, stream::BoxStream, Stream, StreamExt};
28
29#[derive(Debug)]
145pub struct FlightDataEncoderBuilder {
146 max_flight_data_size: usize,
149 options: IpcWriteOptions,
151 app_metadata: Bytes,
153 schema: Option<SchemaRef>,
155 descriptor: Option<FlightDescriptor>,
157 dictionary_handling: DictionaryHandling,
160}
161
/// Default value of the builder's `max_flight_data_size`: 2 MiB.
pub const GRPC_TARGET_MAX_FLIGHT_SIZE_BYTES: usize = 2 * 1024 * 1024;
167
168impl Default for FlightDataEncoderBuilder {
169 fn default() -> Self {
170 Self {
171 max_flight_data_size: GRPC_TARGET_MAX_FLIGHT_SIZE_BYTES,
172 options: IpcWriteOptions::default(),
173 app_metadata: Bytes::new(),
174 schema: None,
175 descriptor: None,
176 dictionary_handling: DictionaryHandling::Hydrate,
177 }
178 }
179}
180
181impl FlightDataEncoderBuilder {
182 pub fn new() -> Self {
184 Self::default()
185 }
186
187 pub fn with_max_flight_data_size(mut self, max_flight_data_size: usize) -> Self {
198 self.max_flight_data_size = max_flight_data_size;
199 self
200 }
201
202 pub fn with_dictionary_handling(mut self, dictionary_handling: DictionaryHandling) -> Self {
204 self.dictionary_handling = dictionary_handling;
205 self
206 }
207
208 pub fn with_metadata(mut self, app_metadata: Bytes) -> Self {
212 self.app_metadata = app_metadata;
213 self
214 }
215
216 pub fn with_options(mut self, options: IpcWriteOptions) -> Self {
218 self.options = options;
219 self
220 }
221
222 pub fn with_schema(mut self, schema: SchemaRef) -> Self {
227 self.schema = Some(schema);
228 self
229 }
230
231 pub fn with_flight_descriptor(mut self, descriptor: Option<FlightDescriptor>) -> Self {
233 self.descriptor = descriptor;
234 self
235 }
236
237 pub fn build<S>(self, input: S) -> FlightDataEncoder
242 where
243 S: Stream<Item = Result<RecordBatch>> + Send + 'static,
244 {
245 let Self {
246 max_flight_data_size,
247 options,
248 app_metadata,
249 schema,
250 descriptor,
251 dictionary_handling,
252 } = self;
253
254 FlightDataEncoder::new(
255 input.boxed(),
256 schema,
257 max_flight_data_size,
258 options,
259 app_metadata,
260 descriptor,
261 dictionary_handling,
262 )
263 }
264}
265
266pub struct FlightDataEncoder {
270 inner: BoxStream<'static, Result<RecordBatch>>,
272 schema: Option<SchemaRef>,
274 max_flight_data_size: usize,
277 encoder: FlightIpcEncoder,
279 app_metadata: Option<Bytes>,
281 queue: VecDeque<FlightData>,
283 done: bool,
285 descriptor: Option<FlightDescriptor>,
287 dictionary_handling: DictionaryHandling,
290}
291
292impl FlightDataEncoder {
293 fn new(
294 inner: BoxStream<'static, Result<RecordBatch>>,
295 schema: Option<SchemaRef>,
296 max_flight_data_size: usize,
297 options: IpcWriteOptions,
298 app_metadata: Bytes,
299 descriptor: Option<FlightDescriptor>,
300 dictionary_handling: DictionaryHandling,
301 ) -> Self {
302 let mut encoder = Self {
303 inner,
304 schema: None,
305 max_flight_data_size,
306 encoder: FlightIpcEncoder::new(
307 options,
308 dictionary_handling != DictionaryHandling::Resend,
309 ),
310 app_metadata: Some(app_metadata),
311 queue: VecDeque::new(),
312 done: false,
313 descriptor,
314 dictionary_handling,
315 };
316
317 if let Some(schema) = schema {
319 encoder.encode_schema(&schema);
320 }
321
322 encoder
323 }
324
325 pub fn known_schema(&self) -> Option<SchemaRef> {
328 self.schema.clone()
329 }
330
331 fn queue_message(&mut self, mut data: FlightData) {
333 if let Some(descriptor) = self.descriptor.take() {
334 data.flight_descriptor = Some(descriptor);
335 }
336 self.queue.push_back(data);
337 }
338
339 fn queue_messages(&mut self, datas: impl IntoIterator<Item = FlightData>) {
341 for data in datas {
342 self.queue_message(data)
343 }
344 }
345
346 fn encode_schema(&mut self, schema: &SchemaRef) -> SchemaRef {
349 let send_dictionaries = self.dictionary_handling == DictionaryHandling::Resend;
352 let schema = Arc::new(prepare_schema_for_flight(
353 schema,
354 &mut self.encoder.dictionary_tracker,
355 send_dictionaries,
356 ));
357 let mut schema_flight_data = self.encoder.encode_schema(&schema);
358
359 if let Some(app_metadata) = self.app_metadata.take() {
361 schema_flight_data.app_metadata = app_metadata;
362 }
363 self.queue_message(schema_flight_data);
364 self.schema = Some(schema.clone());
366 schema
367 }
368
369 fn encode_batch(&mut self, batch: RecordBatch) -> Result<()> {
371 let schema = match &self.schema {
372 Some(schema) => schema.clone(),
373 None => self.encode_schema(batch.schema_ref()),
375 };
376
377 let batch = match self.dictionary_handling {
378 DictionaryHandling::Resend => batch,
379 DictionaryHandling::Hydrate => hydrate_dictionaries(&batch, schema)?,
380 };
381
382 for batch in split_batch_for_grpc_response(batch, self.max_flight_data_size) {
383 let (flight_dictionaries, flight_batch) = self.encoder.encode_batch(&batch)?;
384
385 self.queue_messages(flight_dictionaries);
386 self.queue_message(flight_batch);
387 }
388
389 Ok(())
390 }
391}
392
393impl Stream for FlightDataEncoder {
394 type Item = Result<FlightData>;
395
396 fn poll_next(
397 mut self: Pin<&mut Self>,
398 cx: &mut std::task::Context<'_>,
399 ) -> Poll<Option<Self::Item>> {
400 loop {
401 if self.done && self.queue.is_empty() {
402 return Poll::Ready(None);
403 }
404
405 if let Some(data) = self.queue.pop_front() {
407 return Poll::Ready(Some(Ok(data)));
408 }
409
410 let batch = ready!(self.inner.poll_next_unpin(cx));
412
413 match batch {
414 None => {
415 self.done = true;
417 assert!(self.queue.is_empty());
419 return Poll::Ready(None);
420 }
421 Some(Err(e)) => {
422 self.done = true;
424 self.queue.clear();
425 return Poll::Ready(Some(Err(e)));
426 }
427 Some(Ok(batch)) => {
428 if let Err(e) = self.encode_batch(batch) {
430 self.done = true;
431 self.queue.clear();
432 return Poll::Ready(Some(Err(e)));
433 }
434 }
435 }
436 }
437 }
438}
439
/// Selects how dictionary-encoded columns are sent by the encoder.
///
/// The enum carries no data, so it is `Copy`; configuration values can be
/// stored and reused freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DictionaryHandling {
    /// Replace dictionary fields by their value type in the schema and cast
    /// each batch to that schema, so no dictionary messages are sent.
    Hydrate,
    /// Keep dictionary fields in the schema and send the batches as they are,
    /// queuing the dictionary messages produced for each batch ahead of it.
    Resend,
}
489
490fn prepare_field_for_flight(
491 field: &FieldRef,
492 dictionary_tracker: &mut DictionaryTracker,
493 send_dictionaries: bool,
494) -> Field {
495 match field.data_type() {
496 DataType::List(inner) => Field::new_list(
497 field.name(),
498 prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries),
499 field.is_nullable(),
500 )
501 .with_metadata(field.metadata().clone()),
502 DataType::LargeList(inner) => Field::new_list(
503 field.name(),
504 prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries),
505 field.is_nullable(),
506 )
507 .with_metadata(field.metadata().clone()),
508 DataType::Struct(fields) => {
509 let new_fields: Vec<Field> = fields
510 .iter()
511 .map(|f| prepare_field_for_flight(f, dictionary_tracker, send_dictionaries))
512 .collect();
513 Field::new_struct(field.name(), new_fields, field.is_nullable())
514 .with_metadata(field.metadata().clone())
515 }
516 DataType::Union(fields, mode) => {
517 let (type_ids, new_fields): (Vec<i8>, Vec<Field>) = fields
518 .iter()
519 .map(|(type_id, f)| {
520 (
521 type_id,
522 prepare_field_for_flight(f, dictionary_tracker, send_dictionaries),
523 )
524 })
525 .unzip();
526
527 Field::new_union(field.name(), type_ids, new_fields, *mode)
528 }
529 DataType::Dictionary(_, value_type) => {
530 if !send_dictionaries {
531 Field::new(
532 field.name(),
533 value_type.as_ref().clone(),
534 field.is_nullable(),
535 )
536 .with_metadata(field.metadata().clone())
537 } else {
538 #[allow(deprecated)]
539 let dict_id = dictionary_tracker.set_dict_id(field.as_ref());
540
541 #[allow(deprecated)]
542 Field::new_dict(
543 field.name(),
544 field.data_type().clone(),
545 field.is_nullable(),
546 dict_id,
547 field.dict_is_ordered().unwrap_or_default(),
548 )
549 .with_metadata(field.metadata().clone())
550 }
551 }
552 DataType::Map(inner, sorted) => Field::new(
553 field.name(),
554 DataType::Map(
555 prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries).into(),
556 *sorted,
557 ),
558 field.is_nullable(),
559 )
560 .with_metadata(field.metadata().clone()),
561 _ => field.as_ref().clone(),
562 }
563}
564
565fn prepare_schema_for_flight(
571 schema: &Schema,
572 dictionary_tracker: &mut DictionaryTracker,
573 send_dictionaries: bool,
574) -> Schema {
575 let fields: Fields = schema
576 .fields()
577 .iter()
578 .map(|field| match field.data_type() {
579 DataType::Dictionary(_, value_type) => {
580 if !send_dictionaries {
581 Field::new(
582 field.name(),
583 value_type.as_ref().clone(),
584 field.is_nullable(),
585 )
586 .with_metadata(field.metadata().clone())
587 } else {
588 #[allow(deprecated)]
589 let dict_id = dictionary_tracker.set_dict_id(field.as_ref());
590 #[allow(deprecated)]
591 Field::new_dict(
592 field.name(),
593 field.data_type().clone(),
594 field.is_nullable(),
595 dict_id,
596 field.dict_is_ordered().unwrap_or_default(),
597 )
598 .with_metadata(field.metadata().clone())
599 }
600 }
601 tpe if tpe.is_nested() => {
602 prepare_field_for_flight(field, dictionary_tracker, send_dictionaries)
603 }
604 _ => field.as_ref().clone(),
605 })
606 .collect();
607
608 Schema::new(fields).with_metadata(schema.metadata().clone())
609}
610
611fn split_batch_for_grpc_response(
618 batch: RecordBatch,
619 max_flight_data_size: usize,
620) -> Vec<RecordBatch> {
621 let size = batch
622 .columns()
623 .iter()
624 .map(|col| col.get_buffer_memory_size())
625 .sum::<usize>();
626
627 let n_batches =
628 (size / max_flight_data_size + usize::from(size % max_flight_data_size != 0)).max(1);
629 let rows_per_batch = (batch.num_rows() / n_batches).max(1);
630 let mut out = Vec::with_capacity(n_batches + 1);
631
632 let mut offset = 0;
633 while offset < batch.num_rows() {
634 let length = (rows_per_batch).min(batch.num_rows() - offset);
635 out.push(batch.slice(offset, length));
636
637 offset += length;
638 }
639
640 out
641}
642
643struct FlightIpcEncoder {
650 options: IpcWriteOptions,
651 data_gen: IpcDataGenerator,
652 dictionary_tracker: DictionaryTracker,
653}
654
655impl FlightIpcEncoder {
656 fn new(options: IpcWriteOptions, error_on_replacement: bool) -> Self {
657 #[allow(deprecated)]
658 let preserve_dict_id = options.preserve_dict_id();
659 Self {
660 options,
661 data_gen: IpcDataGenerator::default(),
662 #[allow(deprecated)]
663 dictionary_tracker: DictionaryTracker::new_with_preserve_dict_id(
664 error_on_replacement,
665 preserve_dict_id,
666 ),
667 }
668 }
669
670 fn encode_schema(&self, schema: &Schema) -> FlightData {
672 SchemaAsIpc::new(schema, &self.options).into()
673 }
674
675 fn encode_batch(&mut self, batch: &RecordBatch) -> Result<(Vec<FlightData>, FlightData)> {
678 let (encoded_dictionaries, encoded_batch) =
679 self.data_gen
680 .encoded_batch(batch, &mut self.dictionary_tracker, &self.options)?;
681
682 let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect();
683 let flight_batch = encoded_batch.into();
684
685 Ok((flight_dictionaries, flight_batch))
686 }
687}
688
689fn hydrate_dictionaries(batch: &RecordBatch, schema: SchemaRef) -> Result<RecordBatch> {
692 let columns = schema
693 .fields()
694 .iter()
695 .zip(batch.columns())
696 .map(|(field, c)| hydrate_dictionary(c, field.data_type()))
697 .collect::<Result<Vec<_>>>()?;
698
699 let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
700
701 Ok(RecordBatch::try_new_with_options(
702 schema, columns, &options,
703 )?)
704}
705
706fn hydrate_dictionary(array: &ArrayRef, data_type: &DataType) -> Result<ArrayRef> {
708 let arr = match (array.data_type(), data_type) {
709 (DataType::Union(_, UnionMode::Sparse), DataType::Union(fields, UnionMode::Sparse)) => {
710 let union_arr = array.as_any().downcast_ref::<UnionArray>().unwrap();
711
712 Arc::new(UnionArray::try_new(
713 fields.clone(),
714 union_arr.type_ids().clone(),
715 None,
716 fields
717 .iter()
718 .map(|(type_id, field)| {
719 Ok(arrow_cast::cast(
720 union_arr.child(type_id),
721 field.data_type(),
722 )?)
723 })
724 .collect::<Result<Vec<_>>>()?,
725 )?)
726 }
727 (_, data_type) => arrow_cast::cast(array, data_type)?,
728 };
729 Ok(arr)
730}
731
732#[cfg(test)]
733mod tests {
734 use crate::decode::{DecodedPayload, FlightDataDecoder};
735 use arrow_array::builder::{
736 GenericByteDictionaryBuilder, ListBuilder, StringDictionaryBuilder, StructBuilder,
737 };
738 use arrow_array::*;
739 use arrow_array::{cast::downcast_array, types::*};
740 use arrow_buffer::ScalarBuffer;
741 use arrow_cast::pretty::pretty_format_batches;
742 use arrow_ipc::MetadataVersion;
743 use arrow_schema::{UnionFields, UnionMode};
744 use builder::{GenericStringBuilder, MapBuilder};
745 use std::collections::HashMap;
746
747 use super::*;
748
749 #[test]
750 fn test_encode_flight_data() {
753 let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap();
755 let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
756
757 let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c1) as ArrayRef)])
758 .expect("cannot create record batch");
759 let schema = batch.schema_ref();
760
761 let (_, baseline_flight_batch) = make_flight_data(&batch, &options);
762
763 let big_batch = batch.slice(0, batch.num_rows() - 1);
764 let optimized_big_batch =
765 hydrate_dictionaries(&big_batch, Arc::clone(schema)).expect("failed to optimize");
766 let (_, optimized_big_flight_batch) = make_flight_data(&optimized_big_batch, &options);
767
768 assert_eq!(
769 baseline_flight_batch.data_body.len(),
770 optimized_big_flight_batch.data_body.len()
771 );
772
773 let small_batch = batch.slice(0, 1);
774 let optimized_small_batch =
775 hydrate_dictionaries(&small_batch, Arc::clone(schema)).expect("failed to optimize");
776 let (_, optimized_small_flight_batch) = make_flight_data(&optimized_small_batch, &options);
777
778 assert!(
779 baseline_flight_batch.data_body.len() > optimized_small_flight_batch.data_body.len()
780 );
781 }
782
783 #[tokio::test]
784 async fn test_dictionary_hydration() {
785 let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
786 let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();
787
788 let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
789 "dict",
790 DataType::UInt16,
791 DataType::Utf8,
792 false,
793 )]));
794 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
795 let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();
796
797 let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
798
799 let encoder = FlightDataEncoderBuilder::default().build(stream);
800 let mut decoder = FlightDataDecoder::new(encoder);
801 let expected_schema = Schema::new(vec![Field::new("dict", DataType::Utf8, false)]);
802 let expected_schema = Arc::new(expected_schema);
803 let mut expected_arrays = vec![
804 StringArray::from(vec!["a", "a", "b"]),
805 StringArray::from(vec!["c", "c", "d"]),
806 ]
807 .into_iter();
808 while let Some(decoded) = decoder.next().await {
809 let decoded = decoded.unwrap();
810 match decoded.payload {
811 DecodedPayload::None => {}
812 DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
813 DecodedPayload::RecordBatch(b) => {
814 assert_eq!(b.schema(), expected_schema);
815 let expected_array = expected_arrays.next().unwrap();
816 let actual_array = b.column_by_name("dict").unwrap();
817 let actual_array = downcast_array::<StringArray>(actual_array);
818
819 assert_eq!(actual_array, expected_array);
820 }
821 }
822 }
823 }
824
825 #[tokio::test]
826 async fn test_dictionary_resend() {
827 let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
828 let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();
829
830 let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
831 "dict",
832 DataType::UInt16,
833 DataType::Utf8,
834 false,
835 )]));
836 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
837 let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();
838
839 verify_flight_round_trip(vec![batch1, batch2]).await;
840 }
841
842 #[tokio::test]
843 async fn test_dictionary_hydration_known_schema() {
844 let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
845 let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();
846
847 let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
848 "dict",
849 DataType::UInt16,
850 DataType::Utf8,
851 false,
852 )]));
853 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
854 let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
855
856 let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
857
858 let encoder = FlightDataEncoderBuilder::default()
859 .with_schema(schema)
860 .build(stream);
861 let expected_schema =
862 Arc::new(Schema::new(vec![Field::new("dict", DataType::Utf8, false)]));
863 assert_eq!(Some(expected_schema), encoder.known_schema())
864 }
865
866 #[tokio::test]
867 async fn test_dictionary_resend_known_schema() {
868 let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
869 let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();
870
871 let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
872 "dict",
873 DataType::UInt16,
874 DataType::Utf8,
875 false,
876 )]));
877 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
878 let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
879
880 let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
881
882 let encoder = FlightDataEncoderBuilder::default()
883 .with_dictionary_handling(DictionaryHandling::Resend)
884 .with_schema(schema.clone())
885 .build(stream);
886 assert_eq!(Some(schema), encoder.known_schema())
887 }
888
889 #[tokio::test]
890 async fn test_multiple_dictionaries_resend() {
891 let schema = Arc::new(Schema::new(vec![
893 Field::new_dictionary("dict_1", DataType::UInt16, DataType::Utf8, false),
894 Field::new_dictionary("dict_2", DataType::UInt16, DataType::Utf8, false),
895 ]));
896
897 let arr_one_1: Arc<DictionaryArray<UInt16Type>> =
898 Arc::new(vec!["a", "a", "b"].into_iter().collect());
899 let arr_one_2: Arc<DictionaryArray<UInt16Type>> =
900 Arc::new(vec!["c", "c", "d"].into_iter().collect());
901 let arr_two_1: Arc<DictionaryArray<UInt16Type>> =
902 Arc::new(vec!["b", "a", "c"].into_iter().collect());
903 let arr_two_2: Arc<DictionaryArray<UInt16Type>> =
904 Arc::new(vec!["k", "d", "e"].into_iter().collect());
905 let batch1 =
906 RecordBatch::try_new(schema.clone(), vec![arr_one_1.clone(), arr_one_2.clone()])
907 .unwrap();
908 let batch2 =
909 RecordBatch::try_new(schema.clone(), vec![arr_two_1.clone(), arr_two_2.clone()])
910 .unwrap();
911
912 verify_flight_round_trip(vec![batch1, batch2]).await;
913 }
914
915 #[tokio::test]
916 async fn test_dictionary_list_hydration() {
917 let mut builder = ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());
918
919 builder.append_value(vec![Some("a"), None, Some("b")]);
920
921 let arr1 = builder.finish();
922
923 builder.append_value(vec![Some("c"), None, Some("d")]);
924
925 let arr2 = builder.finish();
926
927 let schema = Arc::new(Schema::new(vec![Field::new_list(
928 "dict_list",
929 Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
930 true,
931 )]));
932
933 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
934 let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
935
936 let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
937
938 let encoder = FlightDataEncoderBuilder::default().build(stream);
939
940 let mut decoder = FlightDataDecoder::new(encoder);
941 let expected_schema = Schema::new(vec![Field::new_list(
942 "dict_list",
943 Field::new_list_field(DataType::Utf8, true),
944 true,
945 )]);
946
947 let expected_schema = Arc::new(expected_schema);
948
949 let mut expected_arrays = vec![
950 StringArray::from_iter(vec![Some("a"), None, Some("b")]),
951 StringArray::from_iter(vec![Some("c"), None, Some("d")]),
952 ]
953 .into_iter();
954
955 while let Some(decoded) = decoder.next().await {
956 let decoded = decoded.unwrap();
957 match decoded.payload {
958 DecodedPayload::None => {}
959 DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
960 DecodedPayload::RecordBatch(b) => {
961 assert_eq!(b.schema(), expected_schema);
962 let expected_array = expected_arrays.next().unwrap();
963 let list_array =
964 downcast_array::<ListArray>(b.column_by_name("dict_list").unwrap());
965 let elem_array = downcast_array::<StringArray>(list_array.value(0).as_ref());
966
967 assert_eq!(elem_array, expected_array);
968 }
969 }
970 }
971 }
972
973 #[tokio::test]
974 async fn test_dictionary_list_resend() {
975 let mut builder = ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());
976
977 builder.append_value(vec![Some("a"), None, Some("b")]);
978
979 let arr1 = builder.finish();
980
981 builder.append_value(vec![Some("c"), None, Some("d")]);
982
983 let arr2 = builder.finish();
984
985 let schema = Arc::new(Schema::new(vec![Field::new_list(
986 "dict_list",
987 Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
988 true,
989 )]));
990
991 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
992 let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
993
994 verify_flight_round_trip(vec![batch1, batch2]).await;
995 }
996
997 #[tokio::test]
998 async fn test_dictionary_struct_hydration() {
999 let struct_fields = vec![Field::new_list(
1000 "dict_list",
1001 Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1002 true,
1003 )];
1004
1005 let mut struct_builder = StructBuilder::new(
1006 struct_fields.clone(),
1007 vec![Box::new(builder::ListBuilder::new(
1008 StringDictionaryBuilder::<UInt16Type>::new(),
1009 ))],
1010 );
1011
1012 struct_builder
1013 .field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type,GenericStringType<i32>>>>(0)
1014 .unwrap()
1015 .append_value(vec![Some("a"), None, Some("b")]);
1016
1017 struct_builder.append(true);
1018
1019 let arr1 = struct_builder.finish();
1020
1021 struct_builder
1022 .field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type,GenericStringType<i32>>>>(0)
1023 .unwrap()
1024 .append_value(vec![Some("c"), None, Some("d")]);
1025 struct_builder.append(true);
1026
1027 let arr2 = struct_builder.finish();
1028
1029 let schema = Arc::new(Schema::new(vec![Field::new_struct(
1030 "struct",
1031 struct_fields,
1032 true,
1033 )]));
1034
1035 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
1036 let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();
1037
1038 let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
1039
1040 let encoder = FlightDataEncoderBuilder::default().build(stream);
1041
1042 let mut decoder = FlightDataDecoder::new(encoder);
1043 let expected_schema = Schema::new(vec![Field::new_struct(
1044 "struct",
1045 vec![Field::new_list(
1046 "dict_list",
1047 Field::new_list_field(DataType::Utf8, true),
1048 true,
1049 )],
1050 true,
1051 )]);
1052
1053 let expected_schema = Arc::new(expected_schema);
1054
1055 let mut expected_arrays = vec![
1056 StringArray::from_iter(vec![Some("a"), None, Some("b")]),
1057 StringArray::from_iter(vec![Some("c"), None, Some("d")]),
1058 ]
1059 .into_iter();
1060
1061 while let Some(decoded) = decoder.next().await {
1062 let decoded = decoded.unwrap();
1063 match decoded.payload {
1064 DecodedPayload::None => {}
1065 DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
1066 DecodedPayload::RecordBatch(b) => {
1067 assert_eq!(b.schema(), expected_schema);
1068 let expected_array = expected_arrays.next().unwrap();
1069 let struct_array =
1070 downcast_array::<StructArray>(b.column_by_name("struct").unwrap());
1071 let list_array = downcast_array::<ListArray>(struct_array.column(0));
1072
1073 let elem_array = downcast_array::<StringArray>(list_array.value(0).as_ref());
1074
1075 assert_eq!(elem_array, expected_array);
1076 }
1077 }
1078 }
1079 }
1080
1081 #[tokio::test]
1082 async fn test_dictionary_struct_resend() {
1083 let struct_fields = vec![Field::new_list(
1084 "dict_list",
1085 Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1086 true,
1087 )];
1088
1089 let mut struct_builder = StructBuilder::new(
1090 struct_fields.clone(),
1091 vec![Box::new(builder::ListBuilder::new(
1092 StringDictionaryBuilder::<UInt16Type>::new(),
1093 ))],
1094 );
1095
1096 struct_builder.field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type,GenericStringType<i32>>>>(0)
1097 .unwrap()
1098 .append_value(vec![Some("a"), None, Some("b")]);
1099 struct_builder.append(true);
1100
1101 let arr1 = struct_builder.finish();
1102
1103 struct_builder.field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type,GenericStringType<i32>>>>(0)
1104 .unwrap()
1105 .append_value(vec![Some("c"), None, Some("d")]);
1106 struct_builder.append(true);
1107
1108 let arr2 = struct_builder.finish();
1109
1110 let schema = Arc::new(Schema::new(vec![Field::new_struct(
1111 "struct",
1112 struct_fields,
1113 true,
1114 )]));
1115
1116 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
1117 let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();
1118
1119 verify_flight_round_trip(vec![batch1, batch2]).await;
1120 }
1121
1122 #[tokio::test]
1123 async fn test_dictionary_union_hydration() {
1124 let struct_fields = vec![Field::new_list(
1125 "dict_list",
1126 Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1127 true,
1128 )];
1129
1130 let union_fields = [
1131 (
1132 0,
1133 Arc::new(Field::new_list(
1134 "dict_list",
1135 Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1136 true,
1137 )),
1138 ),
1139 (
1140 1,
1141 Arc::new(Field::new_struct("struct", struct_fields.clone(), true)),
1142 ),
1143 (2, Arc::new(Field::new("string", DataType::Utf8, true))),
1144 ]
1145 .into_iter()
1146 .collect::<UnionFields>();
1147
1148 let struct_fields = vec![Field::new_list(
1149 "dict_list",
1150 Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1151 true,
1152 )];
1153
1154 let mut builder = builder::ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());
1155
1156 builder.append_value(vec![Some("a"), None, Some("b")]);
1157
1158 let arr1 = builder.finish();
1159
1160 let type_id_buffer = [0].into_iter().collect::<ScalarBuffer<i8>>();
1161 let arr1 = UnionArray::try_new(
1162 union_fields.clone(),
1163 type_id_buffer,
1164 None,
1165 vec![
1166 Arc::new(arr1) as Arc<dyn Array>,
1167 new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1),
1168 new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1),
1169 ],
1170 )
1171 .unwrap();
1172
1173 builder.append_value(vec![Some("c"), None, Some("d")]);
1174
1175 let arr2 = Arc::new(builder.finish());
1176 let arr2 = StructArray::new(struct_fields.clone().into(), vec![arr2], None);
1177
1178 let type_id_buffer = [1].into_iter().collect::<ScalarBuffer<i8>>();
1179 let arr2 = UnionArray::try_new(
1180 union_fields.clone(),
1181 type_id_buffer,
1182 None,
1183 vec![
1184 new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1),
1185 Arc::new(arr2),
1186 new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1),
1187 ],
1188 )
1189 .unwrap();
1190
1191 let type_id_buffer = [2].into_iter().collect::<ScalarBuffer<i8>>();
1192 let arr3 = UnionArray::try_new(
1193 union_fields.clone(),
1194 type_id_buffer,
1195 None,
1196 vec![
1197 new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1),
1198 new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1),
1199 Arc::new(StringArray::from(vec!["e"])),
1200 ],
1201 )
1202 .unwrap();
1203
1204 let (type_ids, union_fields): (Vec<_>, Vec<_>) = union_fields
1205 .iter()
1206 .map(|(type_id, field_ref)| (type_id, (*Arc::clone(field_ref)).clone()))
1207 .unzip();
1208 let schema = Arc::new(Schema::new(vec![Field::new_union(
1209 "union",
1210 type_ids.clone(),
1211 union_fields.clone(),
1212 UnionMode::Sparse,
1213 )]));
1214
1215 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
1216 let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
1217 let batch3 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr3)]).unwrap();
1218
1219 let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2), Ok(batch3)]);
1220
1221 let encoder = FlightDataEncoderBuilder::default().build(stream);
1222
1223 let mut decoder = FlightDataDecoder::new(encoder);
1224
1225 let hydrated_struct_fields = vec![Field::new_list(
1226 "dict_list",
1227 Field::new_list_field(DataType::Utf8, true),
1228 true,
1229 )];
1230
1231 let hydrated_union_fields = vec![
1232 Field::new_list(
1233 "dict_list",
1234 Field::new_list_field(DataType::Utf8, true),
1235 true,
1236 ),
1237 Field::new_struct("struct", hydrated_struct_fields.clone(), true),
1238 Field::new("string", DataType::Utf8, true),
1239 ];
1240
1241 let expected_schema = Schema::new(vec![Field::new_union(
1242 "union",
1243 type_ids.clone(),
1244 hydrated_union_fields,
1245 UnionMode::Sparse,
1246 )]);
1247
1248 let expected_schema = Arc::new(expected_schema);
1249
1250 let mut expected_arrays = vec![
1251 StringArray::from_iter(vec![Some("a"), None, Some("b")]),
1252 StringArray::from_iter(vec![Some("c"), None, Some("d")]),
1253 StringArray::from(vec!["e"]),
1254 ]
1255 .into_iter();
1256
1257 let mut batch = 0;
1258 while let Some(decoded) = decoder.next().await {
1259 let decoded = decoded.unwrap();
1260 match decoded.payload {
1261 DecodedPayload::None => {}
1262 DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
1263 DecodedPayload::RecordBatch(b) => {
1264 assert_eq!(b.schema(), expected_schema);
1265 let expected_array = expected_arrays.next().unwrap();
1266 let union_arr =
1267 downcast_array::<UnionArray>(b.column_by_name("union").unwrap());
1268
1269 let elem_array = match batch {
1270 0 => {
1271 let list_array = downcast_array::<ListArray>(union_arr.child(0));
1272 downcast_array::<StringArray>(list_array.value(0).as_ref())
1273 }
1274 1 => {
1275 let struct_array = downcast_array::<StructArray>(union_arr.child(1));
1276 let list_array = downcast_array::<ListArray>(struct_array.column(0));
1277
1278 downcast_array::<StringArray>(list_array.value(0).as_ref())
1279 }
1280 _ => downcast_array::<StringArray>(union_arr.child(2)),
1281 };
1282
1283 batch += 1;
1284
1285 assert_eq!(elem_array, expected_array);
1286 }
1287 }
1288 }
1289 }
1290
1291 #[tokio::test]
1292 async fn test_dictionary_union_resend() {
1293 let struct_fields = vec![Field::new_list(
1294 "dict_list",
1295 Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1296 true,
1297 )];
1298
1299 let union_fields = [
1300 (
1301 0,
1302 Arc::new(Field::new_list(
1303 "dict_list",
1304 Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1305 true,
1306 )),
1307 ),
1308 (
1309 1,
1310 Arc::new(Field::new_struct("struct", struct_fields.clone(), true)),
1311 ),
1312 (2, Arc::new(Field::new("string", DataType::Utf8, true))),
1313 ]
1314 .into_iter()
1315 .collect::<UnionFields>();
1316
1317 let mut field_types = union_fields.iter().map(|(_, field)| field.data_type());
1318 let dict_list_ty = field_types.next().unwrap();
1319 let struct_ty = field_types.next().unwrap();
1320 let string_ty = field_types.next().unwrap();
1321
1322 let struct_fields = vec![Field::new_list(
1323 "dict_list",
1324 Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1325 true,
1326 )];
1327
1328 let mut builder = builder::ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());
1329
1330 builder.append_value(vec![Some("a"), None, Some("b")]);
1331
1332 let arr1 = builder.finish();
1333
1334 let type_id_buffer = [0].into_iter().collect::<ScalarBuffer<i8>>();
1335 let arr1 = UnionArray::try_new(
1336 union_fields.clone(),
1337 type_id_buffer,
1338 None,
1339 vec![
1340 Arc::new(arr1),
1341 new_null_array(struct_ty, 1),
1342 new_null_array(string_ty, 1),
1343 ],
1344 )
1345 .unwrap();
1346
1347 builder.append_value(vec![Some("c"), None, Some("d")]);
1348
1349 let arr2 = Arc::new(builder.finish());
1350 let arr2 = StructArray::new(struct_fields.clone().into(), vec![arr2], None);
1351
1352 let type_id_buffer = [1].into_iter().collect::<ScalarBuffer<i8>>();
1353 let arr2 = UnionArray::try_new(
1354 union_fields.clone(),
1355 type_id_buffer,
1356 None,
1357 vec![
1358 new_null_array(dict_list_ty, 1),
1359 Arc::new(arr2),
1360 new_null_array(string_ty, 1),
1361 ],
1362 )
1363 .unwrap();
1364
1365 let type_id_buffer = [2].into_iter().collect::<ScalarBuffer<i8>>();
1366 let arr3 = UnionArray::try_new(
1367 union_fields.clone(),
1368 type_id_buffer,
1369 None,
1370 vec![
1371 new_null_array(dict_list_ty, 1),
1372 new_null_array(struct_ty, 1),
1373 Arc::new(StringArray::from(vec!["e"])),
1374 ],
1375 )
1376 .unwrap();
1377
1378 let (type_ids, union_fields): (Vec<_>, Vec<_>) = union_fields
1379 .iter()
1380 .map(|(type_id, field_ref)| (type_id, (*Arc::clone(field_ref)).clone()))
1381 .unzip();
1382 let schema = Arc::new(Schema::new(vec![Field::new_union(
1383 "union",
1384 type_ids.clone(),
1385 union_fields.clone(),
1386 UnionMode::Sparse,
1387 )]));
1388
1389 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
1390 let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
1391 let batch3 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr3)]).unwrap();
1392
1393 verify_flight_round_trip(vec![batch1, batch2, batch3]).await;
1394 }
1395
1396 #[tokio::test]
1397 async fn test_dictionary_map_hydration() {
1398 let mut builder = MapBuilder::new(
1399 None,
1400 StringDictionaryBuilder::<UInt16Type>::new(),
1401 StringDictionaryBuilder::<UInt16Type>::new(),
1402 );
1403
1404 builder.keys().append_value("k1");
1406 builder.values().append_value("a");
1407 builder.keys().append_value("k2");
1408 builder.values().append_null();
1409 builder.keys().append_value("k3");
1410 builder.values().append_value("b");
1411 builder.append(true).unwrap();
1412
1413 let arr1 = builder.finish();
1414
1415 builder.keys().append_value("k1");
1417 builder.values().append_value("c");
1418 builder.keys().append_value("k2");
1419 builder.values().append_null();
1420 builder.keys().append_value("k3");
1421 builder.values().append_value("d");
1422 builder.append(true).unwrap();
1423
1424 let arr2 = builder.finish();
1425
1426 let schema = Arc::new(Schema::new(vec![Field::new_map(
1427 "dict_map",
1428 "entries",
1429 Field::new_dictionary("keys", DataType::UInt16, DataType::Utf8, false),
1430 Field::new_dictionary("values", DataType::UInt16, DataType::Utf8, true),
1431 false,
1432 false,
1433 )]));
1434
1435 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
1436 let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
1437
1438 let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
1439
1440 let encoder = FlightDataEncoderBuilder::default().build(stream);
1441
1442 let mut decoder = FlightDataDecoder::new(encoder);
1443 let expected_schema = Schema::new(vec![Field::new_map(
1444 "dict_map",
1445 "entries",
1446 Field::new("keys", DataType::Utf8, false),
1447 Field::new("values", DataType::Utf8, true),
1448 false,
1449 false,
1450 )]);
1451
1452 let expected_schema = Arc::new(expected_schema);
1453
1454 let mut builder = MapBuilder::new(
1456 None,
1457 GenericStringBuilder::<i32>::new(),
1458 GenericStringBuilder::<i32>::new(),
1459 );
1460
1461 builder.keys().append_value("k1");
1463 builder.values().append_value("a");
1464 builder.keys().append_value("k2");
1465 builder.values().append_null();
1466 builder.keys().append_value("k3");
1467 builder.values().append_value("b");
1468 builder.append(true).unwrap();
1469
1470 let arr1 = builder.finish();
1471
1472 builder.keys().append_value("k1");
1474 builder.values().append_value("c");
1475 builder.keys().append_value("k2");
1476 builder.values().append_null();
1477 builder.keys().append_value("k3");
1478 builder.values().append_value("d");
1479 builder.append(true).unwrap();
1480
1481 let arr2 = builder.finish();
1482
1483 let mut expected_arrays = vec![arr1, arr2].into_iter();
1484
1485 while let Some(decoded) = decoder.next().await {
1486 let decoded = decoded.unwrap();
1487 match decoded.payload {
1488 DecodedPayload::None => {}
1489 DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
1490 DecodedPayload::RecordBatch(b) => {
1491 assert_eq!(b.schema(), expected_schema);
1492 let expected_array = expected_arrays.next().unwrap();
1493 let map_array =
1494 downcast_array::<MapArray>(b.column_by_name("dict_map").unwrap());
1495
1496 assert_eq!(map_array, expected_array);
1497 }
1498 }
1499 }
1500 }
1501
1502 #[tokio::test]
1503 async fn test_dictionary_map_resend() {
1504 let mut builder = MapBuilder::new(
1505 None,
1506 StringDictionaryBuilder::<UInt16Type>::new(),
1507 StringDictionaryBuilder::<UInt16Type>::new(),
1508 );
1509
1510 builder.keys().append_value("k1");
1512 builder.values().append_value("a");
1513 builder.keys().append_value("k2");
1514 builder.values().append_null();
1515 builder.keys().append_value("k3");
1516 builder.values().append_value("b");
1517 builder.append(true).unwrap();
1518
1519 let arr1 = builder.finish();
1520
1521 builder.keys().append_value("k1");
1523 builder.values().append_value("c");
1524 builder.keys().append_value("k2");
1525 builder.values().append_null();
1526 builder.keys().append_value("k3");
1527 builder.values().append_value("d");
1528 builder.append(true).unwrap();
1529
1530 let arr2 = builder.finish();
1531
1532 let schema = Arc::new(Schema::new(vec![Field::new_map(
1533 "dict_map",
1534 "entries",
1535 Field::new_dictionary("keys", DataType::UInt16, DataType::Utf8, false),
1536 Field::new_dictionary("values", DataType::UInt16, DataType::Utf8, true),
1537 false,
1538 false,
1539 )]));
1540
1541 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
1542 let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
1543
1544 verify_flight_round_trip(vec![batch1, batch2]).await;
1545 }
1546
1547 async fn verify_flight_round_trip(mut batches: Vec<RecordBatch>) {
1548 let expected_schema = batches.first().unwrap().schema();
1549
1550 #[allow(deprecated)]
1551 let encoder = FlightDataEncoderBuilder::default()
1552 .with_options(IpcWriteOptions::default().with_preserve_dict_id(false))
1553 .with_dictionary_handling(DictionaryHandling::Resend)
1554 .build(futures::stream::iter(batches.clone().into_iter().map(Ok)));
1555
1556 let mut expected_batches = batches.drain(..);
1557
1558 let mut decoder = FlightDataDecoder::new(encoder);
1559 while let Some(decoded) = decoder.next().await {
1560 let decoded = decoded.unwrap();
1561 match decoded.payload {
1562 DecodedPayload::None => {}
1563 DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
1564 DecodedPayload::RecordBatch(b) => {
1565 let expected_batch = expected_batches.next().unwrap();
1566 assert_eq!(b, expected_batch);
1567 }
1568 }
1569 }
1570 }
1571
/// Checks that schema-level metadata is still present on the schema returned
/// by `prepare_schema_for_flight`.
#[test]
fn test_schema_metadata_encoded() {
    let metadata = HashMap::from([("some_key".to_owned(), "some_value".to_owned())]);
    let fields = vec![Field::new("data", DataType::Int32, false)];
    let schema = Schema::new(fields).with_metadata(metadata);

    #[allow(deprecated)]
    let mut tracker = DictionaryTracker::new_with_preserve_dict_id(false, true);

    let prepared = prepare_schema_for_flight(&schema, &mut tracker, false);
    assert!(prepared.metadata().contains_key("some_key"));
}
1584
/// Checks that `hydrate_dictionaries` succeeds on a batch that has a row
/// count of 10 but no columns.
#[test]
fn test_encode_no_column_batch() {
    let options = RecordBatchOptions::new().with_row_count(Some(10));
    let empty_schema = Arc::new(Schema::empty());

    let batch = RecordBatch::try_new_with_options(empty_schema, vec![], &options)
        .expect("cannot create record batch");

    hydrate_dictionaries(&batch, batch.schema()).expect("failed to optimize");
}
1596
1597 fn make_flight_data(
1598 batch: &RecordBatch,
1599 options: &IpcWriteOptions,
1600 ) -> (Vec<FlightData>, FlightData) {
1601 flight_data_from_arrow_batch(batch, options)
1602 }
1603
1604 fn flight_data_from_arrow_batch(
1605 batch: &RecordBatch,
1606 options: &IpcWriteOptions,
1607 ) -> (Vec<FlightData>, FlightData) {
1608 let data_gen = IpcDataGenerator::default();
1609 #[allow(deprecated)]
1610 let mut dictionary_tracker =
1611 DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id());
1612
1613 let (encoded_dictionaries, encoded_batch) = data_gen
1614 .encoded_batch(batch, &mut dictionary_tracker, options)
1615 .expect("DictionaryTracker configured above to not error on replacement");
1616
1617 let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect();
1618 let flight_batch = encoded_batch.into();
1619
1620 (flight_dictionaries, flight_batch)
1621 }
1622
/// Exercises `split_batch_for_grpc_response` with a 1024 byte limit: a six
/// row batch must come back as one identical batch, while a 1025 row `UInt8`
/// batch must be split into three batches that together hold the same rows.
#[test]
fn test_split_batch_for_grpc_response() {
    let max_flight_data_size = 1024;

    // Small input: expected to be returned as a single, unchanged batch.
    let small_column = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
    let small_batch = RecordBatch::try_from_iter([("a", Arc::new(small_column) as ArrayRef)])
        .expect("cannot create record batch");
    let small_split = split_batch_for_grpc_response(small_batch.clone(), max_flight_data_size);
    assert_eq!(small_split.len(), 1);
    assert_eq!(small_batch, small_split[0]);

    // Larger input: one row more than the byte limit.
    let n_rows = max_flight_data_size + 1;
    assert!(n_rows % 2 == 1, "should be an odd number");
    let bytes: Vec<u8> = (0..n_rows).map(|i| (i % 256) as u8).collect();
    let large_column = UInt8Array::from(bytes);
    let large_batch = RecordBatch::try_from_iter([("a", Arc::new(large_column) as ArrayRef)])
        .expect("cannot create record batch");
    let large_split = split_batch_for_grpc_response(large_batch.clone(), max_flight_data_size);
    assert_eq!(large_split.len(), 3);

    let total_rows: usize = large_split.iter().map(RecordBatch::num_rows).sum();
    assert_eq!(total_rows, n_rows);

    // The pieces, printed back to back, must render exactly like the input.
    let actual = pretty_format_batches(&large_split).unwrap().to_string();
    let expected = pretty_format_batches(&[large_batch]).unwrap().to_string();
    assert_eq!(actual, expected);
}
1651
/// Table-driven check of the row counts produced by
/// `split_batch_for_grpc_response` for several input sizes and byte limits.
#[test]
fn test_split_batch_for_grpc_response_sizes() {
    // (number of input rows, maximum size in bytes, expected rows per piece)
    let cases: [(u64, usize, Vec<usize>); 5] = [
        (2000, 2 * 1024, vec![250; 8]),
        (2000, 4 * 1024, vec![500; 4]),
        (2023, 3 * 1024, vec![337, 337, 337, 337, 337, 337, 1]),
        (10, 1, vec![1; 10]),
        (10, 1024, vec![10]),
    ];

    for (num_input_rows, max_bytes, expected_sizes) in cases {
        verify_split(num_input_rows, max_bytes, expected_sizes);
    }
}
1669
1670 fn verify_split(
1674 num_input_rows: u64,
1675 max_flight_data_size_bytes: usize,
1676 expected_sizes: Vec<usize>,
1677 ) {
1678 let array: UInt64Array = (0..num_input_rows).collect();
1679
1680 let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(array) as ArrayRef)])
1681 .expect("cannot create record batch");
1682
1683 let input_rows = batch.num_rows();
1684
1685 let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size_bytes);
1686 let sizes: Vec<_> = split.iter().map(RecordBatch::num_rows).collect();
1687 let output_rows: usize = sizes.iter().sum();
1688
1689 assert_eq!(sizes, expected_sizes, "mismatch for {batch:?}");
1690 assert_eq!(input_rows, output_rows, "mismatch for {batch:?}");
1691 }
1692
1693 #[tokio::test]
1697 async fn flight_data_size_even() {
1698 let s1 = StringArray::from_iter_values(std::iter::repeat(".10 bytes.").take(1024));
1699 let i1 = Int16Array::from_iter_values(0..1024);
1700 let s2 = StringArray::from_iter_values(std::iter::repeat("6bytes").take(1024));
1701 let i2 = Int64Array::from_iter_values(0..1024);
1702
1703 let batch = RecordBatch::try_from_iter(vec![
1704 ("s1", Arc::new(s1) as _),
1705 ("i1", Arc::new(i1) as _),
1706 ("s2", Arc::new(s2) as _),
1707 ("i2", Arc::new(i2) as _),
1708 ])
1709 .unwrap();
1710
1711 verify_encoded_split(batch, 120).await;
1712 }
1713
1714 #[tokio::test]
1715 async fn flight_data_size_uneven_variable_lengths() {
1716 let array = StringArray::from_iter_values((0..1024).map(|i| "*".repeat(i)));
1718 let batch = RecordBatch::try_from_iter(vec![("data", Arc::new(array) as _)]).unwrap();
1719
1720 verify_encoded_split(batch, 4312).await;
1723 }
1724
1725 #[tokio::test]
1726 async fn flight_data_size_large_row() {
1727 let array1 = StringArray::from_iter_values(vec![
1729 "*".repeat(500),
1730 "*".repeat(500),
1731 "*".repeat(500),
1732 "*".repeat(500),
1733 ]);
1734 let array2 = StringArray::from_iter_values(vec![
1735 "*".to_string(),
1736 "*".repeat(1000),
1737 "*".repeat(2000),
1738 "*".repeat(4000),
1739 ]);
1740
1741 let array3 = StringArray::from_iter_values(vec![
1742 "*".to_string(),
1743 "*".to_string(),
1744 "*".repeat(1000),
1745 "*".repeat(2000),
1746 ]);
1747
1748 let batch = RecordBatch::try_from_iter(vec![
1749 ("a1", Arc::new(array1) as _),
1750 ("a2", Arc::new(array2) as _),
1751 ("a3", Arc::new(array3) as _),
1752 ])
1753 .unwrap();
1754
1755 verify_encoded_split(batch, 5808).await;
1759 }
1760
1761 #[tokio::test]
1762 async fn flight_data_size_string_dictionary() {
1763 let array: DictionaryArray<Int32Type> = (1..1024)
1765 .map(|i| match i % 3 {
1766 0 => Some("value0"),
1767 1 => Some("value1"),
1768 _ => None,
1769 })
1770 .collect();
1771
1772 let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();
1773
1774 verify_encoded_split(batch, 56).await;
1775 }
1776
1777 #[tokio::test]
1778 async fn flight_data_size_large_dictionary() {
1779 let values: Vec<_> = (1..1024).map(|i| "**".repeat(i)).collect();
1781
1782 let array: DictionaryArray<Int32Type> = values.iter().map(|s| Some(s.as_str())).collect();
1783
1784 let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();
1785
1786 verify_encoded_split(batch, 3336).await;
1789 }
1790
1791 #[tokio::test]
1792 async fn flight_data_size_large_dictionary_repeated_non_uniform() {
1793 let values = StringArray::from_iter_values((0..1024).map(|i| "******".repeat(i)));
1795 let keys = Int32Array::from_iter_values((0..3000).map(|i| (3000 - i) % 1024));
1796 let array = DictionaryArray::new(keys, Arc::new(values));
1797
1798 let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();
1799
1800 verify_encoded_split(batch, 5288).await;
1803 }
1804
1805 #[tokio::test]
1806 async fn flight_data_size_multiple_dictionaries() {
1807 let values1: Vec<_> = (1..1024).map(|i| "**".repeat(i)).collect();
1809 let values2: Vec<_> = (1..1024).map(|i| "**".repeat(i % 10)).collect();
1811 let values3: Vec<_> = (1..1024).map(|i| "**".repeat(i % 100)).collect();
1813
1814 let array1: DictionaryArray<Int32Type> = values1.iter().map(|s| Some(s.as_str())).collect();
1815 let array2: DictionaryArray<Int32Type> = values2.iter().map(|s| Some(s.as_str())).collect();
1816 let array3: DictionaryArray<Int32Type> = values3.iter().map(|s| Some(s.as_str())).collect();
1817
1818 let batch = RecordBatch::try_from_iter(vec![
1819 ("a1", Arc::new(array1) as _),
1820 ("a2", Arc::new(array2) as _),
1821 ("a3", Arc::new(array3) as _),
1822 ])
1823 .unwrap();
1824
1825 verify_encoded_split(batch, 4136).await;
1828 }
1829
1830 fn flight_data_size(d: &FlightData) -> usize {
1832 let flight_descriptor_size = d
1833 .flight_descriptor
1834 .as_ref()
1835 .map(|descriptor| {
1836 let path_len: usize = descriptor.path.iter().map(|p| p.len()).sum();
1837
1838 std::mem::size_of_val(descriptor) + descriptor.cmd.len() + path_len
1839 })
1840 .unwrap_or(0);
1841
1842 flight_descriptor_size + d.app_metadata.len() + d.data_body.len() + d.data_header.len()
1843 }
1844
1845 async fn verify_encoded_split(batch: RecordBatch, allowed_overage: usize) {
1861 let num_rows = batch.num_rows();
1862
1863 let mut max_overage_seen = 0;
1865
1866 for max_flight_data_size in [1024, 2021, 5000] {
1867 println!("Encoding {num_rows} with a maximum size of {max_flight_data_size}");
1868
1869 let mut stream = FlightDataEncoderBuilder::new()
1870 .with_max_flight_data_size(max_flight_data_size)
1871 .with_options(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap())
1873 .build(futures::stream::iter([Ok(batch.clone())]));
1874
1875 let mut i = 0;
1876 while let Some(data) = stream.next().await.transpose().unwrap() {
1877 let actual_data_size = flight_data_size(&data);
1878
1879 let actual_overage = actual_data_size.saturating_sub(max_flight_data_size);
1880
1881 assert!(
1882 actual_overage <= allowed_overage,
1883 "encoded data[{i}]: actual size {actual_data_size}, \
1884 actual_overage: {actual_overage} \
1885 allowed_overage: {allowed_overage}"
1886 );
1887
1888 i += 1;
1889
1890 max_overage_seen = max_overage_seen.max(actual_overage)
1891 }
1892 }
1893
1894 assert_eq!(
1898 allowed_overage, max_overage_seen,
1899 "Specified overage was too high"
1900 );
1901 }
1902}