1use std::{collections::VecDeque, fmt::Debug, pin::Pin, sync::Arc, task::Poll};
19
20use crate::{FlightData, FlightDescriptor, SchemaAsIpc, error::Result};
21
22use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions, UnionArray};
23use arrow_ipc::writer::{CompressionContext, DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
24
25use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef, UnionMode};
26use bytes::Bytes;
27use futures::{Stream, StreamExt, ready, stream::BoxStream};
28
29#[derive(Debug)]
145pub struct FlightDataEncoderBuilder {
146 max_flight_data_size: usize,
149 options: IpcWriteOptions,
151 app_metadata: Bytes,
153 schema: Option<SchemaRef>,
155 descriptor: Option<FlightDescriptor>,
157 dictionary_handling: DictionaryHandling,
160}
161
/// Default `max_flight_data_size` used by [`FlightDataEncoderBuilder`]: 2 MiB.
pub const GRPC_TARGET_MAX_FLIGHT_SIZE_BYTES: usize = 2 * 1024 * 1024;
167
168impl Default for FlightDataEncoderBuilder {
169 fn default() -> Self {
170 Self {
171 max_flight_data_size: GRPC_TARGET_MAX_FLIGHT_SIZE_BYTES,
172 options: IpcWriteOptions::default(),
173 app_metadata: Bytes::new(),
174 schema: None,
175 descriptor: None,
176 dictionary_handling: DictionaryHandling::Hydrate,
177 }
178 }
179}
180
181impl FlightDataEncoderBuilder {
182 pub fn new() -> Self {
184 Self::default()
185 }
186
187 pub fn with_max_flight_data_size(mut self, max_flight_data_size: usize) -> Self {
198 self.max_flight_data_size = max_flight_data_size;
199 self
200 }
201
202 pub fn with_dictionary_handling(mut self, dictionary_handling: DictionaryHandling) -> Self {
204 self.dictionary_handling = dictionary_handling;
205 self
206 }
207
208 pub fn with_metadata(mut self, app_metadata: Bytes) -> Self {
212 self.app_metadata = app_metadata;
213 self
214 }
215
216 pub fn with_options(mut self, options: IpcWriteOptions) -> Self {
218 self.options = options;
219 self
220 }
221
222 pub fn with_schema(mut self, schema: SchemaRef) -> Self {
227 self.schema = Some(schema);
228 self
229 }
230
231 pub fn with_flight_descriptor(mut self, descriptor: Option<FlightDescriptor>) -> Self {
233 self.descriptor = descriptor;
234 self
235 }
236
237 pub fn build<S>(self, input: S) -> FlightDataEncoder
242 where
243 S: Stream<Item = Result<RecordBatch>> + Send + 'static,
244 {
245 let Self {
246 max_flight_data_size,
247 options,
248 app_metadata,
249 schema,
250 descriptor,
251 dictionary_handling,
252 } = self;
253
254 FlightDataEncoder::new(
255 input.boxed(),
256 schema,
257 max_flight_data_size,
258 options,
259 app_metadata,
260 descriptor,
261 dictionary_handling,
262 )
263 }
264}
265
266pub struct FlightDataEncoder {
270 inner: BoxStream<'static, Result<RecordBatch>>,
272 schema: Option<SchemaRef>,
274 max_flight_data_size: usize,
277 encoder: FlightIpcEncoder,
279 app_metadata: Option<Bytes>,
281 queue: VecDeque<FlightData>,
283 done: bool,
285 descriptor: Option<FlightDescriptor>,
287 dictionary_handling: DictionaryHandling,
290}
291
292impl FlightDataEncoder {
293 fn new(
294 inner: BoxStream<'static, Result<RecordBatch>>,
295 schema: Option<SchemaRef>,
296 max_flight_data_size: usize,
297 options: IpcWriteOptions,
298 app_metadata: Bytes,
299 descriptor: Option<FlightDescriptor>,
300 dictionary_handling: DictionaryHandling,
301 ) -> Self {
302 let mut encoder = Self {
303 inner,
304 schema: None,
305 max_flight_data_size,
306 encoder: FlightIpcEncoder::new(
307 options,
308 dictionary_handling != DictionaryHandling::Resend,
309 ),
310 app_metadata: Some(app_metadata),
311 queue: VecDeque::new(),
312 done: false,
313 descriptor,
314 dictionary_handling,
315 };
316
317 if let Some(schema) = schema {
319 encoder.encode_schema(&schema);
320 }
321
322 encoder
323 }
324
325 pub fn known_schema(&self) -> Option<SchemaRef> {
328 self.schema.clone()
329 }
330
331 fn queue_message(&mut self, mut data: FlightData) {
333 if let Some(descriptor) = self.descriptor.take() {
334 data.flight_descriptor = Some(descriptor);
335 }
336 self.queue.push_back(data);
337 }
338
339 fn queue_messages(&mut self, datas: impl IntoIterator<Item = FlightData>) {
341 for data in datas {
342 self.queue_message(data)
343 }
344 }
345
346 fn encode_schema(&mut self, schema: &SchemaRef) -> SchemaRef {
349 let send_dictionaries = self.dictionary_handling == DictionaryHandling::Resend;
352 let schema = Arc::new(prepare_schema_for_flight(
353 schema,
354 &mut self.encoder.dictionary_tracker,
355 send_dictionaries,
356 ));
357 let mut schema_flight_data = self.encoder.encode_schema(&schema);
358
359 if let Some(app_metadata) = self.app_metadata.take() {
361 schema_flight_data.app_metadata = app_metadata;
362 }
363 self.queue_message(schema_flight_data);
364 self.schema = Some(schema.clone());
366 schema
367 }
368
369 fn encode_batch(&mut self, batch: RecordBatch) -> Result<()> {
371 let schema = match &self.schema {
372 Some(schema) => schema.clone(),
373 None => self.encode_schema(batch.schema_ref()),
375 };
376
377 let batch = match self.dictionary_handling {
378 DictionaryHandling::Resend => batch,
379 DictionaryHandling::Hydrate => hydrate_dictionaries(&batch, schema)?,
380 };
381
382 for batch in split_batch_for_grpc_response(batch, self.max_flight_data_size) {
383 let (flight_dictionaries, flight_batch) = self.encoder.encode_batch(&batch)?;
384
385 self.queue_messages(flight_dictionaries);
386 self.queue_message(flight_batch);
387 }
388
389 Ok(())
390 }
391}
392
393impl Stream for FlightDataEncoder {
394 type Item = Result<FlightData>;
395
396 fn poll_next(
397 mut self: Pin<&mut Self>,
398 cx: &mut std::task::Context<'_>,
399 ) -> Poll<Option<Self::Item>> {
400 loop {
401 if self.done && self.queue.is_empty() {
402 return Poll::Ready(None);
403 }
404
405 if let Some(data) = self.queue.pop_front() {
407 return Poll::Ready(Some(Ok(data)));
408 }
409
410 let batch = ready!(self.inner.poll_next_unpin(cx));
412
413 match batch {
414 None => {
415 self.done = true;
417 assert!(self.queue.is_empty());
419 return Poll::Ready(None);
420 }
421 Some(Err(e)) => {
422 self.done = true;
424 self.queue.clear();
425 return Poll::Ready(Some(Err(e)));
426 }
427 Some(Ok(batch)) => {
428 if let Err(e) = self.encode_batch(batch) {
430 self.done = true;
431 self.queue.clear();
432 return Poll::Ready(Some(Err(e)));
433 }
434 }
435 }
436 }
437 }
438}
439
/// Selects how dictionary-encoded columns are treated by the encoder.
///
/// The enum is a plain, fieldless selector, so it is `Copy` and `Eq` in
/// addition to `Debug` and `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DictionaryHandling {
    /// Dictionary fields are replaced by their value type in the schema and
    /// the column data is cast to match before encoding.
    Hydrate,
    /// Dictionary fields are kept in the schema and batches are encoded
    /// without being cast; the dictionary tracker is created so that a
    /// replaced dictionary is not treated as an error.
    Resend,
}
489
490fn prepare_field_for_flight(
491 field: &FieldRef,
492 dictionary_tracker: &mut DictionaryTracker,
493 send_dictionaries: bool,
494) -> Field {
495 match field.data_type() {
496 DataType::List(inner) => Field::new_list(
497 field.name(),
498 prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries),
499 field.is_nullable(),
500 )
501 .with_metadata(field.metadata().clone()),
502 DataType::LargeList(inner) => Field::new_list(
503 field.name(),
504 prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries),
505 field.is_nullable(),
506 )
507 .with_metadata(field.metadata().clone()),
508 DataType::Struct(fields) => {
509 let new_fields: Vec<Field> = fields
510 .iter()
511 .map(|f| prepare_field_for_flight(f, dictionary_tracker, send_dictionaries))
512 .collect();
513 Field::new_struct(field.name(), new_fields, field.is_nullable())
514 .with_metadata(field.metadata().clone())
515 }
516 DataType::Union(fields, mode) => {
517 let (type_ids, new_fields): (Vec<i8>, Vec<Field>) = fields
518 .iter()
519 .map(|(type_id, f)| {
520 (
521 type_id,
522 prepare_field_for_flight(f, dictionary_tracker, send_dictionaries),
523 )
524 })
525 .unzip();
526
527 Field::new_union(field.name(), type_ids, new_fields, *mode)
528 }
529 DataType::Dictionary(_, value_type) => {
530 if !send_dictionaries {
531 Field::new(
532 field.name(),
533 value_type.as_ref().clone(),
534 field.is_nullable(),
535 )
536 .with_metadata(field.metadata().clone())
537 } else {
538 dictionary_tracker.next_dict_id();
539 #[allow(deprecated)]
540 Field::new_dict(
541 field.name(),
542 field.data_type().clone(),
543 field.is_nullable(),
544 0,
545 field.dict_is_ordered().unwrap_or_default(),
546 )
547 .with_metadata(field.metadata().clone())
548 }
549 }
550 DataType::Map(inner, sorted) => Field::new(
551 field.name(),
552 DataType::Map(
553 prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries).into(),
554 *sorted,
555 ),
556 field.is_nullable(),
557 )
558 .with_metadata(field.metadata().clone()),
559 _ => field.as_ref().clone(),
560 }
561}
562
563fn prepare_schema_for_flight(
569 schema: &Schema,
570 dictionary_tracker: &mut DictionaryTracker,
571 send_dictionaries: bool,
572) -> Schema {
573 let fields: Fields = schema
574 .fields()
575 .iter()
576 .map(|field| match field.data_type() {
577 DataType::Dictionary(_, value_type) => {
578 if !send_dictionaries {
579 Field::new(
580 field.name(),
581 value_type.as_ref().clone(),
582 field.is_nullable(),
583 )
584 .with_metadata(field.metadata().clone())
585 } else {
586 dictionary_tracker.next_dict_id();
587 #[allow(deprecated)]
588 Field::new_dict(
589 field.name(),
590 field.data_type().clone(),
591 field.is_nullable(),
592 0,
593 field.dict_is_ordered().unwrap_or_default(),
594 )
595 .with_metadata(field.metadata().clone())
596 }
597 }
598 tpe if tpe.is_nested() => {
599 prepare_field_for_flight(field, dictionary_tracker, send_dictionaries)
600 }
601 _ => field.as_ref().clone(),
602 })
603 .collect();
604
605 Schema::new(fields).with_metadata(schema.metadata().clone())
606}
607
608fn split_batch_for_grpc_response(
615 batch: RecordBatch,
616 max_flight_data_size: usize,
617) -> Vec<RecordBatch> {
618 let size = batch
619 .columns()
620 .iter()
621 .map(|col| col.get_buffer_memory_size())
622 .sum::<usize>();
623
624 let n_batches =
625 (size / max_flight_data_size + usize::from(size % max_flight_data_size != 0)).max(1);
626 let rows_per_batch = (batch.num_rows() / n_batches).max(1);
627 let mut out = Vec::with_capacity(n_batches + 1);
628
629 let mut offset = 0;
630 while offset < batch.num_rows() {
631 let length = (rows_per_batch).min(batch.num_rows() - offset);
632 out.push(batch.slice(offset, length));
633
634 offset += length;
635 }
636
637 out
638}
639
640struct FlightIpcEncoder {
647 options: IpcWriteOptions,
648 data_gen: IpcDataGenerator,
649 dictionary_tracker: DictionaryTracker,
650 compression_context: CompressionContext,
651}
652
653impl FlightIpcEncoder {
654 fn new(options: IpcWriteOptions, error_on_replacement: bool) -> Self {
655 Self {
656 options,
657 data_gen: IpcDataGenerator::default(),
658 dictionary_tracker: DictionaryTracker::new(error_on_replacement),
659 compression_context: CompressionContext::default(),
660 }
661 }
662
663 fn encode_schema(&self, schema: &Schema) -> FlightData {
665 SchemaAsIpc::new(schema, &self.options).into()
666 }
667
668 fn encode_batch(&mut self, batch: &RecordBatch) -> Result<(Vec<FlightData>, FlightData)> {
671 let (encoded_dictionaries, encoded_batch) = self.data_gen.encode(
672 batch,
673 &mut self.dictionary_tracker,
674 &self.options,
675 &mut self.compression_context,
676 )?;
677
678 let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect();
679 let flight_batch = encoded_batch.into();
680
681 Ok((flight_dictionaries, flight_batch))
682 }
683}
684
685fn hydrate_dictionaries(batch: &RecordBatch, schema: SchemaRef) -> Result<RecordBatch> {
688 let columns = schema
689 .fields()
690 .iter()
691 .zip(batch.columns())
692 .map(|(field, c)| hydrate_dictionary(c, field.data_type()))
693 .collect::<Result<Vec<_>>>()?;
694
695 let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
696
697 Ok(RecordBatch::try_new_with_options(
698 schema, columns, &options,
699 )?)
700}
701
702fn hydrate_dictionary(array: &ArrayRef, data_type: &DataType) -> Result<ArrayRef> {
704 let arr = match (array.data_type(), data_type) {
705 (DataType::Union(_, UnionMode::Sparse), DataType::Union(fields, UnionMode::Sparse)) => {
706 let union_arr = array.as_any().downcast_ref::<UnionArray>().unwrap();
707
708 Arc::new(UnionArray::try_new(
709 fields.clone(),
710 union_arr.type_ids().clone(),
711 None,
712 fields
713 .iter()
714 .map(|(type_id, field)| {
715 Ok(arrow_cast::cast(
716 union_arr.child(type_id),
717 field.data_type(),
718 )?)
719 })
720 .collect::<Result<Vec<_>>>()?,
721 )?)
722 }
723 (_, data_type) => arrow_cast::cast(array, data_type)?,
724 };
725 Ok(arr)
726}
727
728#[cfg(test)]
729mod tests {
730 use crate::decode::{DecodedPayload, FlightDataDecoder};
731 use arrow_array::builder::{
732 GenericByteDictionaryBuilder, ListBuilder, StringDictionaryBuilder, StructBuilder,
733 };
734 use arrow_array::*;
735 use arrow_array::{cast::downcast_array, types::*};
736 use arrow_buffer::ScalarBuffer;
737 use arrow_cast::pretty::pretty_format_batches;
738 use arrow_ipc::MetadataVersion;
739 use arrow_schema::{UnionFields, UnionMode};
740 use builder::{GenericStringBuilder, MapBuilder};
741 use std::collections::HashMap;
742
743 use super::*;
744
745 #[test]
746 fn test_encode_flight_data() {
749 let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap();
751 let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
752
753 let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c1) as ArrayRef)])
754 .expect("cannot create record batch");
755 let schema = batch.schema_ref();
756
757 let (_, baseline_flight_batch) = make_flight_data(&batch, &options);
758
759 let big_batch = batch.slice(0, batch.num_rows() - 1);
760 let optimized_big_batch =
761 hydrate_dictionaries(&big_batch, Arc::clone(schema)).expect("failed to optimize");
762 let (_, optimized_big_flight_batch) = make_flight_data(&optimized_big_batch, &options);
763
764 assert_eq!(
765 baseline_flight_batch.data_body.len(),
766 optimized_big_flight_batch.data_body.len()
767 );
768
769 let small_batch = batch.slice(0, 1);
770 let optimized_small_batch =
771 hydrate_dictionaries(&small_batch, Arc::clone(schema)).expect("failed to optimize");
772 let (_, optimized_small_flight_batch) = make_flight_data(&optimized_small_batch, &options);
773
774 assert!(
775 baseline_flight_batch.data_body.len() > optimized_small_flight_batch.data_body.len()
776 );
777 }
778
779 #[tokio::test]
780 async fn test_dictionary_hydration() {
781 let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
782 let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();
783
784 let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
785 "dict",
786 DataType::UInt16,
787 DataType::Utf8,
788 false,
789 )]));
790 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
791 let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();
792
793 let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
794
795 let encoder = FlightDataEncoderBuilder::default().build(stream);
796 let mut decoder = FlightDataDecoder::new(encoder);
797 let expected_schema = Schema::new(vec![Field::new("dict", DataType::Utf8, false)]);
798 let expected_schema = Arc::new(expected_schema);
799 let mut expected_arrays = vec![
800 StringArray::from(vec!["a", "a", "b"]),
801 StringArray::from(vec!["c", "c", "d"]),
802 ]
803 .into_iter();
804 while let Some(decoded) = decoder.next().await {
805 let decoded = decoded.unwrap();
806 match decoded.payload {
807 DecodedPayload::None => {}
808 DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
809 DecodedPayload::RecordBatch(b) => {
810 assert_eq!(b.schema(), expected_schema);
811 let expected_array = expected_arrays.next().unwrap();
812 let actual_array = b.column_by_name("dict").unwrap();
813 let actual_array = downcast_array::<StringArray>(actual_array);
814
815 assert_eq!(actual_array, expected_array);
816 }
817 }
818 }
819 }
820
821 #[tokio::test]
822 async fn test_dictionary_resend() {
823 let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
824 let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();
825
826 let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
827 "dict",
828 DataType::UInt16,
829 DataType::Utf8,
830 false,
831 )]));
832 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
833 let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();
834
835 verify_flight_round_trip(vec![batch1, batch2]).await;
836 }
837
838 #[tokio::test]
839 async fn test_dictionary_hydration_known_schema() {
840 let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
841 let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();
842
843 let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
844 "dict",
845 DataType::UInt16,
846 DataType::Utf8,
847 false,
848 )]));
849 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
850 let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
851
852 let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
853
854 let encoder = FlightDataEncoderBuilder::default()
855 .with_schema(schema)
856 .build(stream);
857 let expected_schema =
858 Arc::new(Schema::new(vec![Field::new("dict", DataType::Utf8, false)]));
859 assert_eq!(Some(expected_schema), encoder.known_schema())
860 }
861
862 #[tokio::test]
863 async fn test_dictionary_resend_known_schema() {
864 let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
865 let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();
866
867 let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
868 "dict",
869 DataType::UInt16,
870 DataType::Utf8,
871 false,
872 )]));
873 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
874 let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
875
876 let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
877
878 let encoder = FlightDataEncoderBuilder::default()
879 .with_dictionary_handling(DictionaryHandling::Resend)
880 .with_schema(schema.clone())
881 .build(stream);
882 assert_eq!(Some(schema), encoder.known_schema())
883 }
884
885 #[tokio::test]
886 async fn test_multiple_dictionaries_resend() {
887 let schema = Arc::new(Schema::new(vec![
889 Field::new_dictionary("dict_1", DataType::UInt16, DataType::Utf8, false),
890 Field::new_dictionary("dict_2", DataType::UInt16, DataType::Utf8, false),
891 ]));
892
893 let arr_one_1: Arc<DictionaryArray<UInt16Type>> =
894 Arc::new(vec!["a", "a", "b"].into_iter().collect());
895 let arr_one_2: Arc<DictionaryArray<UInt16Type>> =
896 Arc::new(vec!["c", "c", "d"].into_iter().collect());
897 let arr_two_1: Arc<DictionaryArray<UInt16Type>> =
898 Arc::new(vec!["b", "a", "c"].into_iter().collect());
899 let arr_two_2: Arc<DictionaryArray<UInt16Type>> =
900 Arc::new(vec!["k", "d", "e"].into_iter().collect());
901 let batch1 =
902 RecordBatch::try_new(schema.clone(), vec![arr_one_1.clone(), arr_one_2.clone()])
903 .unwrap();
904 let batch2 =
905 RecordBatch::try_new(schema.clone(), vec![arr_two_1.clone(), arr_two_2.clone()])
906 .unwrap();
907
908 verify_flight_round_trip(vec![batch1, batch2]).await;
909 }
910
911 #[tokio::test]
912 async fn test_dictionary_list_hydration() {
913 let mut builder = ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());
914
915 builder.append_value(vec![Some("a"), None, Some("b")]);
916
917 let arr1 = builder.finish();
918
919 builder.append_value(vec![Some("c"), None, Some("d")]);
920
921 let arr2 = builder.finish();
922
923 let schema = Arc::new(Schema::new(vec![Field::new_list(
924 "dict_list",
925 Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
926 true,
927 )]));
928
929 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
930 let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
931
932 let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
933
934 let encoder = FlightDataEncoderBuilder::default().build(stream);
935
936 let mut decoder = FlightDataDecoder::new(encoder);
937 let expected_schema = Schema::new(vec![Field::new_list(
938 "dict_list",
939 Field::new_list_field(DataType::Utf8, true),
940 true,
941 )]);
942
943 let expected_schema = Arc::new(expected_schema);
944
945 let mut expected_arrays = vec![
946 StringArray::from_iter(vec![Some("a"), None, Some("b")]),
947 StringArray::from_iter(vec![Some("c"), None, Some("d")]),
948 ]
949 .into_iter();
950
951 while let Some(decoded) = decoder.next().await {
952 let decoded = decoded.unwrap();
953 match decoded.payload {
954 DecodedPayload::None => {}
955 DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
956 DecodedPayload::RecordBatch(b) => {
957 assert_eq!(b.schema(), expected_schema);
958 let expected_array = expected_arrays.next().unwrap();
959 let list_array =
960 downcast_array::<ListArray>(b.column_by_name("dict_list").unwrap());
961 let elem_array = downcast_array::<StringArray>(list_array.value(0).as_ref());
962
963 assert_eq!(elem_array, expected_array);
964 }
965 }
966 }
967 }
968
969 #[tokio::test]
970 async fn test_dictionary_list_resend() {
971 let mut builder = ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());
972
973 builder.append_value(vec![Some("a"), None, Some("b")]);
974
975 let arr1 = builder.finish();
976
977 builder.append_value(vec![Some("c"), None, Some("d")]);
978
979 let arr2 = builder.finish();
980
981 let schema = Arc::new(Schema::new(vec![Field::new_list(
982 "dict_list",
983 Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
984 true,
985 )]));
986
987 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
988 let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
989
990 verify_flight_round_trip(vec![batch1, batch2]).await;
991 }
992
993 #[tokio::test]
994 async fn test_dictionary_struct_hydration() {
995 let struct_fields = vec![Field::new_list(
996 "dict_list",
997 Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
998 true,
999 )];
1000
1001 let mut struct_builder = StructBuilder::new(
1002 struct_fields.clone(),
1003 vec![Box::new(builder::ListBuilder::new(
1004 StringDictionaryBuilder::<UInt16Type>::new(),
1005 ))],
1006 );
1007
1008 struct_builder
1009 .field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type,GenericStringType<i32>>>>(0)
1010 .unwrap()
1011 .append_value(vec![Some("a"), None, Some("b")]);
1012
1013 struct_builder.append(true);
1014
1015 let arr1 = struct_builder.finish();
1016
1017 struct_builder
1018 .field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type,GenericStringType<i32>>>>(0)
1019 .unwrap()
1020 .append_value(vec![Some("c"), None, Some("d")]);
1021 struct_builder.append(true);
1022
1023 let arr2 = struct_builder.finish();
1024
1025 let schema = Arc::new(Schema::new(vec![Field::new_struct(
1026 "struct",
1027 struct_fields,
1028 true,
1029 )]));
1030
1031 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
1032 let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();
1033
1034 let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
1035
1036 let encoder = FlightDataEncoderBuilder::default().build(stream);
1037
1038 let mut decoder = FlightDataDecoder::new(encoder);
1039 let expected_schema = Schema::new(vec![Field::new_struct(
1040 "struct",
1041 vec![Field::new_list(
1042 "dict_list",
1043 Field::new_list_field(DataType::Utf8, true),
1044 true,
1045 )],
1046 true,
1047 )]);
1048
1049 let expected_schema = Arc::new(expected_schema);
1050
1051 let mut expected_arrays = vec![
1052 StringArray::from_iter(vec![Some("a"), None, Some("b")]),
1053 StringArray::from_iter(vec![Some("c"), None, Some("d")]),
1054 ]
1055 .into_iter();
1056
1057 while let Some(decoded) = decoder.next().await {
1058 let decoded = decoded.unwrap();
1059 match decoded.payload {
1060 DecodedPayload::None => {}
1061 DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
1062 DecodedPayload::RecordBatch(b) => {
1063 assert_eq!(b.schema(), expected_schema);
1064 let expected_array = expected_arrays.next().unwrap();
1065 let struct_array =
1066 downcast_array::<StructArray>(b.column_by_name("struct").unwrap());
1067 let list_array = downcast_array::<ListArray>(struct_array.column(0));
1068
1069 let elem_array = downcast_array::<StringArray>(list_array.value(0).as_ref());
1070
1071 assert_eq!(elem_array, expected_array);
1072 }
1073 }
1074 }
1075 }
1076
1077 #[tokio::test]
1078 async fn test_dictionary_struct_resend() {
1079 let struct_fields = vec![Field::new_list(
1080 "dict_list",
1081 Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1082 true,
1083 )];
1084
1085 let mut struct_builder = StructBuilder::new(
1086 struct_fields.clone(),
1087 vec![Box::new(builder::ListBuilder::new(
1088 StringDictionaryBuilder::<UInt16Type>::new(),
1089 ))],
1090 );
1091
1092 struct_builder.field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type,GenericStringType<i32>>>>(0)
1093 .unwrap()
1094 .append_value(vec![Some("a"), None, Some("b")]);
1095 struct_builder.append(true);
1096
1097 let arr1 = struct_builder.finish();
1098
1099 struct_builder.field_builder::<ListBuilder<GenericByteDictionaryBuilder<UInt16Type,GenericStringType<i32>>>>(0)
1100 .unwrap()
1101 .append_value(vec![Some("c"), None, Some("d")]);
1102 struct_builder.append(true);
1103
1104 let arr2 = struct_builder.finish();
1105
1106 let schema = Arc::new(Schema::new(vec![Field::new_struct(
1107 "struct",
1108 struct_fields,
1109 true,
1110 )]));
1111
1112 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
1113 let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();
1114
1115 verify_flight_round_trip(vec![batch1, batch2]).await;
1116 }
1117
1118 #[tokio::test]
1119 async fn test_dictionary_union_hydration() {
1120 let struct_fields = vec![Field::new_list(
1121 "dict_list",
1122 Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1123 true,
1124 )];
1125
1126 let union_fields = [
1127 (
1128 0,
1129 Arc::new(Field::new_list(
1130 "dict_list",
1131 Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1132 true,
1133 )),
1134 ),
1135 (
1136 1,
1137 Arc::new(Field::new_struct("struct", struct_fields.clone(), true)),
1138 ),
1139 (2, Arc::new(Field::new("string", DataType::Utf8, true))),
1140 ]
1141 .into_iter()
1142 .collect::<UnionFields>();
1143
1144 let struct_fields = vec![Field::new_list(
1145 "dict_list",
1146 Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1147 true,
1148 )];
1149
1150 let mut builder = builder::ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());
1151
1152 builder.append_value(vec![Some("a"), None, Some("b")]);
1153
1154 let arr1 = builder.finish();
1155
1156 let type_id_buffer = [0].into_iter().collect::<ScalarBuffer<i8>>();
1157 let arr1 = UnionArray::try_new(
1158 union_fields.clone(),
1159 type_id_buffer,
1160 None,
1161 vec![
1162 Arc::new(arr1) as Arc<dyn Array>,
1163 new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1),
1164 new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1),
1165 ],
1166 )
1167 .unwrap();
1168
1169 builder.append_value(vec![Some("c"), None, Some("d")]);
1170
1171 let arr2 = Arc::new(builder.finish());
1172 let arr2 = StructArray::new(struct_fields.clone().into(), vec![arr2], None);
1173
1174 let type_id_buffer = [1].into_iter().collect::<ScalarBuffer<i8>>();
1175 let arr2 = UnionArray::try_new(
1176 union_fields.clone(),
1177 type_id_buffer,
1178 None,
1179 vec![
1180 new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1),
1181 Arc::new(arr2),
1182 new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1),
1183 ],
1184 )
1185 .unwrap();
1186
1187 let type_id_buffer = [2].into_iter().collect::<ScalarBuffer<i8>>();
1188 let arr3 = UnionArray::try_new(
1189 union_fields.clone(),
1190 type_id_buffer,
1191 None,
1192 vec![
1193 new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1),
1194 new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1),
1195 Arc::new(StringArray::from(vec!["e"])),
1196 ],
1197 )
1198 .unwrap();
1199
1200 let (type_ids, union_fields): (Vec<_>, Vec<_>) = union_fields
1201 .iter()
1202 .map(|(type_id, field_ref)| (type_id, (*Arc::clone(field_ref)).clone()))
1203 .unzip();
1204 let schema = Arc::new(Schema::new(vec![Field::new_union(
1205 "union",
1206 type_ids.clone(),
1207 union_fields.clone(),
1208 UnionMode::Sparse,
1209 )]));
1210
1211 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
1212 let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
1213 let batch3 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr3)]).unwrap();
1214
1215 let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2), Ok(batch3)]);
1216
1217 let encoder = FlightDataEncoderBuilder::default().build(stream);
1218
1219 let mut decoder = FlightDataDecoder::new(encoder);
1220
1221 let hydrated_struct_fields = vec![Field::new_list(
1222 "dict_list",
1223 Field::new_list_field(DataType::Utf8, true),
1224 true,
1225 )];
1226
1227 let hydrated_union_fields = vec![
1228 Field::new_list(
1229 "dict_list",
1230 Field::new_list_field(DataType::Utf8, true),
1231 true,
1232 ),
1233 Field::new_struct("struct", hydrated_struct_fields.clone(), true),
1234 Field::new("string", DataType::Utf8, true),
1235 ];
1236
1237 let expected_schema = Schema::new(vec![Field::new_union(
1238 "union",
1239 type_ids.clone(),
1240 hydrated_union_fields,
1241 UnionMode::Sparse,
1242 )]);
1243
1244 let expected_schema = Arc::new(expected_schema);
1245
1246 let mut expected_arrays = vec![
1247 StringArray::from_iter(vec![Some("a"), None, Some("b")]),
1248 StringArray::from_iter(vec![Some("c"), None, Some("d")]),
1249 StringArray::from(vec!["e"]),
1250 ]
1251 .into_iter();
1252
1253 let mut batch = 0;
1254 while let Some(decoded) = decoder.next().await {
1255 let decoded = decoded.unwrap();
1256 match decoded.payload {
1257 DecodedPayload::None => {}
1258 DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
1259 DecodedPayload::RecordBatch(b) => {
1260 assert_eq!(b.schema(), expected_schema);
1261 let expected_array = expected_arrays.next().unwrap();
1262 let union_arr =
1263 downcast_array::<UnionArray>(b.column_by_name("union").unwrap());
1264
1265 let elem_array = match batch {
1266 0 => {
1267 let list_array = downcast_array::<ListArray>(union_arr.child(0));
1268 downcast_array::<StringArray>(list_array.value(0).as_ref())
1269 }
1270 1 => {
1271 let struct_array = downcast_array::<StructArray>(union_arr.child(1));
1272 let list_array = downcast_array::<ListArray>(struct_array.column(0));
1273
1274 downcast_array::<StringArray>(list_array.value(0).as_ref())
1275 }
1276 _ => downcast_array::<StringArray>(union_arr.child(2)),
1277 };
1278
1279 batch += 1;
1280
1281 assert_eq!(elem_array, expected_array);
1282 }
1283 }
1284 }
1285 }
1286
1287 #[tokio::test]
1288 async fn test_dictionary_union_resend() {
1289 let struct_fields = vec![Field::new_list(
1290 "dict_list",
1291 Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1292 true,
1293 )];
1294
1295 let union_fields = [
1296 (
1297 0,
1298 Arc::new(Field::new_list(
1299 "dict_list",
1300 Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1301 true,
1302 )),
1303 ),
1304 (
1305 1,
1306 Arc::new(Field::new_struct("struct", struct_fields.clone(), true)),
1307 ),
1308 (2, Arc::new(Field::new("string", DataType::Utf8, true))),
1309 ]
1310 .into_iter()
1311 .collect::<UnionFields>();
1312
1313 let mut field_types = union_fields.iter().map(|(_, field)| field.data_type());
1314 let dict_list_ty = field_types.next().unwrap();
1315 let struct_ty = field_types.next().unwrap();
1316 let string_ty = field_types.next().unwrap();
1317
1318 let struct_fields = vec![Field::new_list(
1319 "dict_list",
1320 Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
1321 true,
1322 )];
1323
1324 let mut builder = builder::ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());
1325
1326 builder.append_value(vec![Some("a"), None, Some("b")]);
1327
1328 let arr1 = builder.finish();
1329
1330 let type_id_buffer = [0].into_iter().collect::<ScalarBuffer<i8>>();
1331 let arr1 = UnionArray::try_new(
1332 union_fields.clone(),
1333 type_id_buffer,
1334 None,
1335 vec![
1336 Arc::new(arr1),
1337 new_null_array(struct_ty, 1),
1338 new_null_array(string_ty, 1),
1339 ],
1340 )
1341 .unwrap();
1342
1343 builder.append_value(vec![Some("c"), None, Some("d")]);
1344
1345 let arr2 = Arc::new(builder.finish());
1346 let arr2 = StructArray::new(struct_fields.clone().into(), vec![arr2], None);
1347
1348 let type_id_buffer = [1].into_iter().collect::<ScalarBuffer<i8>>();
1349 let arr2 = UnionArray::try_new(
1350 union_fields.clone(),
1351 type_id_buffer,
1352 None,
1353 vec![
1354 new_null_array(dict_list_ty, 1),
1355 Arc::new(arr2),
1356 new_null_array(string_ty, 1),
1357 ],
1358 )
1359 .unwrap();
1360
1361 let type_id_buffer = [2].into_iter().collect::<ScalarBuffer<i8>>();
1362 let arr3 = UnionArray::try_new(
1363 union_fields.clone(),
1364 type_id_buffer,
1365 None,
1366 vec![
1367 new_null_array(dict_list_ty, 1),
1368 new_null_array(struct_ty, 1),
1369 Arc::new(StringArray::from(vec!["e"])),
1370 ],
1371 )
1372 .unwrap();
1373
1374 let (type_ids, union_fields): (Vec<_>, Vec<_>) = union_fields
1375 .iter()
1376 .map(|(type_id, field_ref)| (type_id, (*Arc::clone(field_ref)).clone()))
1377 .unzip();
1378 let schema = Arc::new(Schema::new(vec![Field::new_union(
1379 "union",
1380 type_ids.clone(),
1381 union_fields.clone(),
1382 UnionMode::Sparse,
1383 )]));
1384
1385 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
1386 let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
1387 let batch3 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr3)]).unwrap();
1388
1389 verify_flight_round_trip(vec![batch1, batch2, batch3]).await;
1390 }
1391
1392 #[tokio::test]
1393 async fn test_dictionary_map_hydration() {
1394 let mut builder = MapBuilder::new(
1395 None,
1396 StringDictionaryBuilder::<UInt16Type>::new(),
1397 StringDictionaryBuilder::<UInt16Type>::new(),
1398 );
1399
1400 builder.keys().append_value("k1");
1402 builder.values().append_value("a");
1403 builder.keys().append_value("k2");
1404 builder.values().append_null();
1405 builder.keys().append_value("k3");
1406 builder.values().append_value("b");
1407 builder.append(true).unwrap();
1408
1409 let arr1 = builder.finish();
1410
1411 builder.keys().append_value("k1");
1413 builder.values().append_value("c");
1414 builder.keys().append_value("k2");
1415 builder.values().append_null();
1416 builder.keys().append_value("k3");
1417 builder.values().append_value("d");
1418 builder.append(true).unwrap();
1419
1420 let arr2 = builder.finish();
1421
1422 let schema = Arc::new(Schema::new(vec![Field::new_map(
1423 "dict_map",
1424 "entries",
1425 Field::new_dictionary("keys", DataType::UInt16, DataType::Utf8, false),
1426 Field::new_dictionary("values", DataType::UInt16, DataType::Utf8, true),
1427 false,
1428 false,
1429 )]));
1430
1431 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
1432 let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
1433
1434 let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
1435
1436 let encoder = FlightDataEncoderBuilder::default().build(stream);
1437
1438 let mut decoder = FlightDataDecoder::new(encoder);
1439 let expected_schema = Schema::new(vec![Field::new_map(
1440 "dict_map",
1441 "entries",
1442 Field::new("keys", DataType::Utf8, false),
1443 Field::new("values", DataType::Utf8, true),
1444 false,
1445 false,
1446 )]);
1447
1448 let expected_schema = Arc::new(expected_schema);
1449
1450 let mut builder = MapBuilder::new(
1452 None,
1453 GenericStringBuilder::<i32>::new(),
1454 GenericStringBuilder::<i32>::new(),
1455 );
1456
1457 builder.keys().append_value("k1");
1459 builder.values().append_value("a");
1460 builder.keys().append_value("k2");
1461 builder.values().append_null();
1462 builder.keys().append_value("k3");
1463 builder.values().append_value("b");
1464 builder.append(true).unwrap();
1465
1466 let arr1 = builder.finish();
1467
1468 builder.keys().append_value("k1");
1470 builder.values().append_value("c");
1471 builder.keys().append_value("k2");
1472 builder.values().append_null();
1473 builder.keys().append_value("k3");
1474 builder.values().append_value("d");
1475 builder.append(true).unwrap();
1476
1477 let arr2 = builder.finish();
1478
1479 let mut expected_arrays = vec![arr1, arr2].into_iter();
1480
1481 while let Some(decoded) = decoder.next().await {
1482 let decoded = decoded.unwrap();
1483 match decoded.payload {
1484 DecodedPayload::None => {}
1485 DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
1486 DecodedPayload::RecordBatch(b) => {
1487 assert_eq!(b.schema(), expected_schema);
1488 let expected_array = expected_arrays.next().unwrap();
1489 let map_array =
1490 downcast_array::<MapArray>(b.column_by_name("dict_map").unwrap());
1491
1492 assert_eq!(map_array, expected_array);
1493 }
1494 }
1495 }
1496 }
1497
1498 #[tokio::test]
1499 async fn test_dictionary_map_resend() {
1500 let mut builder = MapBuilder::new(
1501 None,
1502 StringDictionaryBuilder::<UInt16Type>::new(),
1503 StringDictionaryBuilder::<UInt16Type>::new(),
1504 );
1505
1506 builder.keys().append_value("k1");
1508 builder.values().append_value("a");
1509 builder.keys().append_value("k2");
1510 builder.values().append_null();
1511 builder.keys().append_value("k3");
1512 builder.values().append_value("b");
1513 builder.append(true).unwrap();
1514
1515 let arr1 = builder.finish();
1516
1517 builder.keys().append_value("k1");
1519 builder.values().append_value("c");
1520 builder.keys().append_value("k2");
1521 builder.values().append_null();
1522 builder.keys().append_value("k3");
1523 builder.values().append_value("d");
1524 builder.append(true).unwrap();
1525
1526 let arr2 = builder.finish();
1527
1528 let schema = Arc::new(Schema::new(vec![Field::new_map(
1529 "dict_map",
1530 "entries",
1531 Field::new_dictionary("keys", DataType::UInt16, DataType::Utf8, false),
1532 Field::new_dictionary("values", DataType::UInt16, DataType::Utf8, true),
1533 false,
1534 false,
1535 )]));
1536
1537 let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
1538 let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
1539
1540 verify_flight_round_trip(vec![batch1, batch2]).await;
1541 }
1542
1543 async fn verify_flight_round_trip(mut batches: Vec<RecordBatch>) {
1544 let expected_schema = batches.first().unwrap().schema();
1545
1546 let encoder = FlightDataEncoderBuilder::default()
1547 .with_options(IpcWriteOptions::default())
1548 .with_dictionary_handling(DictionaryHandling::Resend)
1549 .build(futures::stream::iter(batches.clone().into_iter().map(Ok)));
1550
1551 let mut expected_batches = batches.drain(..);
1552
1553 let mut decoder = FlightDataDecoder::new(encoder);
1554 while let Some(decoded) = decoder.next().await {
1555 let decoded = decoded.unwrap();
1556 match decoded.payload {
1557 DecodedPayload::None => {}
1558 DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
1559 DecodedPayload::RecordBatch(b) => {
1560 let expected_batch = expected_batches.next().unwrap();
1561 assert_eq!(b, expected_batch);
1562 }
1563 }
1564 }
1565 }
1566
/// Checks that a schema-level metadata key is still present on the schema
/// returned by `prepare_schema_for_flight`.
#[test]
fn test_schema_metadata_encoded() {
    let metadata = HashMap::from([("some_key".to_owned(), "some_value".to_owned())]);
    let fields = vec![Field::new("data", DataType::Int32, false)];
    let schema = Schema::new(fields).with_metadata(metadata);

    let mut tracker = DictionaryTracker::new(false);
    let prepared = prepare_schema_for_flight(&schema, &mut tracker, false);

    assert!(prepared.metadata().contains_key("some_key"));
}
1578
/// Checks that `hydrate_dictionaries` succeeds on a batch that has no columns
/// but an explicit row count of 10.
#[test]
fn test_encode_no_column_batch() {
    let options = RecordBatchOptions::new().with_row_count(Some(10));
    let empty_schema = Arc::new(Schema::empty());
    let batch = RecordBatch::try_new_with_options(empty_schema, vec![], &options)
        .expect("cannot create record batch");

    let schema = batch.schema();
    hydrate_dictionaries(&batch, schema).expect("failed to optimize");
}
1590
1591 fn make_flight_data(
1592 batch: &RecordBatch,
1593 options: &IpcWriteOptions,
1594 ) -> (Vec<FlightData>, FlightData) {
1595 flight_data_from_arrow_batch(batch, options)
1596 }
1597
1598 fn flight_data_from_arrow_batch(
1599 batch: &RecordBatch,
1600 options: &IpcWriteOptions,
1601 ) -> (Vec<FlightData>, FlightData) {
1602 let data_gen = IpcDataGenerator::default();
1603 let mut dictionary_tracker = DictionaryTracker::new(false);
1604 let mut compression_context = CompressionContext::default();
1605
1606 let (encoded_dictionaries, encoded_batch) = data_gen
1607 .encode(
1608 batch,
1609 &mut dictionary_tracker,
1610 options,
1611 &mut compression_context,
1612 )
1613 .expect("DictionaryTracker configured above to not error on replacement");
1614
1615 let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect();
1616 let flight_batch = encoded_batch.into();
1617
1618 (flight_dictionaries, flight_batch)
1619 }
1620
/// Exercises `split_batch_for_grpc_response` with a 1024 byte limit: a six-row
/// `UInt32` batch must come back as a single identical batch, and a 1025-row
/// `UInt8` batch must come back as three batches that together hold the same
/// rows in the same order.
#[test]
fn test_split_batch_for_grpc_response() {
    let max_flight_data_size = 1024;

    // Six u32 values: expected to stay in one piece.
    let small_column: ArrayRef = Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6]));
    let small_batch = RecordBatch::try_from_iter(vec![("a", small_column)])
        .expect("cannot create record batch");
    let small_split = split_batch_for_grpc_response(small_batch.clone(), max_flight_data_size);
    assert_eq!(small_split.len(), 1);
    assert_eq!(small_batch, small_split[0]);

    // One row more than the limit, one byte per row.
    let n_rows = max_flight_data_size + 1;
    assert!(n_rows % 2 == 1, "should be an odd number");
    let bytes: Vec<u8> = (0..n_rows).map(|i| (i % 256) as u8).collect();
    let large_column: ArrayRef = Arc::new(UInt8Array::from(bytes));
    let large_batch = RecordBatch::try_from_iter(vec![("a", large_column)])
        .expect("cannot create record batch");
    let large_split = split_batch_for_grpc_response(large_batch.clone(), max_flight_data_size);
    assert_eq!(large_split.len(), 3);

    let total_rows: usize = large_split.iter().map(RecordBatch::num_rows).sum();
    assert_eq!(total_rows, n_rows);

    // The pieces, printed together, must render exactly like the input.
    let rendered_split = pretty_format_batches(&large_split).unwrap().to_string();
    let rendered_input = pretty_format_batches(&[large_batch]).unwrap().to_string();
    assert_eq!(rendered_split, rendered_input);
}
1649
/// Table-driven check of the row counts produced by splitting a single
/// `UInt64` column (see `verify_split`).
#[test]
fn test_split_batch_for_grpc_response_sizes() {
    // (number of input rows, size limit in bytes, expected rows per piece)
    let cases: [(u64, usize, Vec<usize>); 5] = [
        (2000, 2 * 1024, vec![250; 8]),
        (2000, 4 * 1024, vec![500; 4]),
        (2023, 3 * 1024, vec![337, 337, 337, 337, 337, 337, 1]),
        (10, 1, vec![1; 10]),
        (10, 1024, vec![10]),
    ];

    for (num_input_rows, max_flight_data_size_bytes, expected_sizes) in cases {
        verify_split(num_input_rows, max_flight_data_size_bytes, expected_sizes);
    }
}
1667
1668 fn verify_split(
1672 num_input_rows: u64,
1673 max_flight_data_size_bytes: usize,
1674 expected_sizes: Vec<usize>,
1675 ) {
1676 let array: UInt64Array = (0..num_input_rows).collect();
1677
1678 let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(array) as ArrayRef)])
1679 .expect("cannot create record batch");
1680
1681 let input_rows = batch.num_rows();
1682
1683 let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size_bytes);
1684 let sizes: Vec<_> = split.iter().map(RecordBatch::num_rows).collect();
1685 let output_rows: usize = sizes.iter().sum();
1686
1687 assert_eq!(sizes, expected_sizes, "mismatch for {batch:?}");
1688 assert_eq!(input_rows, output_rows, "mismatch for {batch:?}");
1689 }
1690
1691 #[tokio::test]
1695 async fn flight_data_size_even() {
1696 let s1 = StringArray::from_iter_values(std::iter::repeat_n(".10 bytes.", 1024));
1697 let i1 = Int16Array::from_iter_values(0..1024);
1698 let s2 = StringArray::from_iter_values(std::iter::repeat_n("6bytes", 1024));
1699 let i2 = Int64Array::from_iter_values(0..1024);
1700
1701 let batch = RecordBatch::try_from_iter(vec![
1702 ("s1", Arc::new(s1) as _),
1703 ("i1", Arc::new(i1) as _),
1704 ("s2", Arc::new(s2) as _),
1705 ("i2", Arc::new(i2) as _),
1706 ])
1707 .unwrap();
1708
1709 verify_encoded_split(batch, 120).await;
1710 }
1711
1712 #[tokio::test]
1713 async fn flight_data_size_uneven_variable_lengths() {
1714 let array = StringArray::from_iter_values((0..1024).map(|i| "*".repeat(i)));
1716 let batch = RecordBatch::try_from_iter(vec![("data", Arc::new(array) as _)]).unwrap();
1717
1718 verify_encoded_split(batch, 4312).await;
1721 }
1722
1723 #[tokio::test]
1724 async fn flight_data_size_large_row() {
1725 let array1 = StringArray::from_iter_values(vec![
1727 "*".repeat(500),
1728 "*".repeat(500),
1729 "*".repeat(500),
1730 "*".repeat(500),
1731 ]);
1732 let array2 = StringArray::from_iter_values(vec![
1733 "*".to_string(),
1734 "*".repeat(1000),
1735 "*".repeat(2000),
1736 "*".repeat(4000),
1737 ]);
1738
1739 let array3 = StringArray::from_iter_values(vec![
1740 "*".to_string(),
1741 "*".to_string(),
1742 "*".repeat(1000),
1743 "*".repeat(2000),
1744 ]);
1745
1746 let batch = RecordBatch::try_from_iter(vec![
1747 ("a1", Arc::new(array1) as _),
1748 ("a2", Arc::new(array2) as _),
1749 ("a3", Arc::new(array3) as _),
1750 ])
1751 .unwrap();
1752
1753 verify_encoded_split(batch, 5808).await;
1757 }
1758
1759 #[tokio::test]
1760 async fn flight_data_size_string_dictionary() {
1761 let array: DictionaryArray<Int32Type> = (1..1024)
1763 .map(|i| match i % 3 {
1764 0 => Some("value0"),
1765 1 => Some("value1"),
1766 _ => None,
1767 })
1768 .collect();
1769
1770 let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();
1771
1772 verify_encoded_split(batch, 56).await;
1773 }
1774
1775 #[tokio::test]
1776 async fn flight_data_size_large_dictionary() {
1777 let values: Vec<_> = (1..1024).map(|i| "**".repeat(i)).collect();
1779
1780 let array: DictionaryArray<Int32Type> = values.iter().map(|s| Some(s.as_str())).collect();
1781
1782 let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();
1783
1784 verify_encoded_split(batch, 3336).await;
1787 }
1788
1789 #[tokio::test]
1790 async fn flight_data_size_large_dictionary_repeated_non_uniform() {
1791 let values = StringArray::from_iter_values((0..1024).map(|i| "******".repeat(i)));
1793 let keys = Int32Array::from_iter_values((0..3000).map(|i| (3000 - i) % 1024));
1794 let array = DictionaryArray::new(keys, Arc::new(values));
1795
1796 let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();
1797
1798 verify_encoded_split(batch, 5288).await;
1801 }
1802
1803 #[tokio::test]
1804 async fn flight_data_size_multiple_dictionaries() {
1805 let values1: Vec<_> = (1..1024).map(|i| "**".repeat(i)).collect();
1807 let values2: Vec<_> = (1..1024).map(|i| "**".repeat(i % 10)).collect();
1809 let values3: Vec<_> = (1..1024).map(|i| "**".repeat(i % 100)).collect();
1811
1812 let array1: DictionaryArray<Int32Type> = values1.iter().map(|s| Some(s.as_str())).collect();
1813 let array2: DictionaryArray<Int32Type> = values2.iter().map(|s| Some(s.as_str())).collect();
1814 let array3: DictionaryArray<Int32Type> = values3.iter().map(|s| Some(s.as_str())).collect();
1815
1816 let batch = RecordBatch::try_from_iter(vec![
1817 ("a1", Arc::new(array1) as _),
1818 ("a2", Arc::new(array2) as _),
1819 ("a3", Arc::new(array3) as _),
1820 ])
1821 .unwrap();
1822
1823 verify_encoded_split(batch, 4136).await;
1826 }
1827
1828 fn flight_data_size(d: &FlightData) -> usize {
1830 let flight_descriptor_size = d
1831 .flight_descriptor
1832 .as_ref()
1833 .map(|descriptor| {
1834 let path_len: usize = descriptor.path.iter().map(|p| p.len()).sum();
1835
1836 std::mem::size_of_val(descriptor) + descriptor.cmd.len() + path_len
1837 })
1838 .unwrap_or(0);
1839
1840 flight_descriptor_size + d.app_metadata.len() + d.data_body.len() + d.data_header.len()
1841 }
1842
1843 async fn verify_encoded_split(batch: RecordBatch, allowed_overage: usize) {
1859 let num_rows = batch.num_rows();
1860
1861 let mut max_overage_seen = 0;
1863
1864 for max_flight_data_size in [1024, 2021, 5000] {
1865 println!("Encoding {num_rows} with a maximum size of {max_flight_data_size}");
1866
1867 let mut stream = FlightDataEncoderBuilder::new()
1868 .with_max_flight_data_size(max_flight_data_size)
1869 .with_options(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap())
1871 .build(futures::stream::iter([Ok(batch.clone())]));
1872
1873 let mut i = 0;
1874 while let Some(data) = stream.next().await.transpose().unwrap() {
1875 let actual_data_size = flight_data_size(&data);
1876
1877 let actual_overage = actual_data_size.saturating_sub(max_flight_data_size);
1878
1879 assert!(
1880 actual_overage <= allowed_overage,
1881 "encoded data[{i}]: actual size {actual_data_size}, \
1882 actual_overage: {actual_overage} \
1883 allowed_overage: {allowed_overage}"
1884 );
1885
1886 i += 1;
1887
1888 max_overage_seen = max_overage_seen.max(actual_overage)
1889 }
1890 }
1891
1892 assert_eq!(
1896 allowed_overage, max_overage_seen,
1897 "Specified overage was too high"
1898 );
1899 }
1900}