1use crate::filter::filter_record_batch;
24use arrow_array::types::{BinaryViewType, StringViewType};
25use arrow_array::{downcast_primitive, Array, ArrayRef, BooleanArray, RecordBatch};
26use arrow_schema::{ArrowError, DataType, SchemaRef};
27use std::collections::VecDeque;
28use std::sync::Arc;
29mod byte_view;
33mod generic;
34mod primitive;
35
36use byte_view::InProgressByteViewArray;
37use generic::GenericInProgressArray;
38use primitive::InProgressPrimitiveArray;
39
/// Concatenates rows from pushed [`RecordBatch`]es into output batches of
/// `target_batch_size` rows.
///
/// Rows are copied column by column into in-progress arrays; whenever
/// `target_batch_size` rows have accumulated a completed batch is queued and
/// can be retrieved with `next_completed_batch`.
#[derive(Debug)]
pub struct BatchCoalescer {
    /// Schema attached to every output batch
    schema: SchemaRef,
    /// Number of rows at which buffered rows are emitted as a completed batch
    target_batch_size: usize,
    /// One in-progress array per schema field, in field order
    in_progress_arrays: Vec<Box<dyn InProgressArray>>,
    /// Number of rows currently held in `in_progress_arrays`
    buffered_rows: usize,
    /// Completed batches waiting to be handed out, oldest first
    completed: VecDeque<RecordBatch>,
}
146
147impl BatchCoalescer {
148 pub fn new(schema: SchemaRef, target_batch_size: usize) -> Self {
156 let in_progress_arrays = schema
157 .fields()
158 .iter()
159 .map(|field| create_in_progress_array(field.data_type(), target_batch_size))
160 .collect::<Vec<_>>();
161
162 Self {
163 schema,
164 target_batch_size,
165 in_progress_arrays,
166 completed: VecDeque::with_capacity(1),
168 buffered_rows: 0,
169 }
170 }
171
172 pub fn schema(&self) -> SchemaRef {
174 Arc::clone(&self.schema)
175 }
176
177 pub fn push_batch_with_filter(
202 &mut self,
203 batch: RecordBatch,
204 filter: &BooleanArray,
205 ) -> Result<(), ArrowError> {
206 let filtered_batch = filter_record_batch(&batch, filter)?;
209 self.push_batch(filtered_batch)
210 }
211
212 pub fn push_batch(&mut self, batch: RecordBatch) -> Result<(), ArrowError> {
239 let (_schema, arrays, mut num_rows) = batch.into_parts();
240 if num_rows == 0 {
241 return Ok(());
242 }
243
244 assert_eq!(arrays.len(), self.in_progress_arrays.len());
246 self.in_progress_arrays
247 .iter_mut()
248 .zip(arrays)
249 .for_each(|(in_progress, array)| {
250 in_progress.set_source(Some(array));
251 });
252
253 let mut offset = 0;
256 while num_rows > (self.target_batch_size - self.buffered_rows) {
257 let remaining_rows = self.target_batch_size - self.buffered_rows;
258 debug_assert!(remaining_rows > 0);
259
260 for in_progress in self.in_progress_arrays.iter_mut() {
262 in_progress.copy_rows(offset, remaining_rows)?;
263 }
264
265 self.buffered_rows += remaining_rows;
266 offset += remaining_rows;
267 num_rows -= remaining_rows;
268
269 self.finish_buffered_batch()?;
270 }
271
272 self.buffered_rows += num_rows;
274 if num_rows > 0 {
275 for in_progress in self.in_progress_arrays.iter_mut() {
276 in_progress.copy_rows(offset, num_rows)?;
277 }
278 }
279
280 if self.buffered_rows >= self.target_batch_size {
282 self.finish_buffered_batch()?;
283 }
284
285 for in_progress in self.in_progress_arrays.iter_mut() {
287 in_progress.set_source(None);
288 }
289
290 Ok(())
291 }
292
293 pub fn finish_buffered_batch(&mut self) -> Result<(), ArrowError> {
301 if self.buffered_rows == 0 {
302 return Ok(());
303 }
304 let new_arrays = self
305 .in_progress_arrays
306 .iter_mut()
307 .map(|array| array.finish())
308 .collect::<Result<Vec<_>, ArrowError>>()?;
309
310 for (array, field) in new_arrays.iter().zip(self.schema.fields().iter()) {
311 debug_assert_eq!(array.data_type(), field.data_type());
312 debug_assert_eq!(array.len(), self.buffered_rows);
313 }
314
315 let batch = unsafe {
317 RecordBatch::new_unchecked(Arc::clone(&self.schema), new_arrays, self.buffered_rows)
318 };
319
320 self.buffered_rows = 0;
321 self.completed.push_back(batch);
322 Ok(())
323 }
324
325 pub fn is_empty(&self) -> bool {
327 self.buffered_rows == 0 && self.completed.is_empty()
328 }
329
330 pub fn has_completed_batch(&self) -> bool {
332 !self.completed.is_empty()
333 }
334
335 pub fn next_completed_batch(&mut self) -> Option<RecordBatch> {
337 self.completed.pop_front()
338 }
339}
340
341fn create_in_progress_array(data_type: &DataType, batch_size: usize) -> Box<dyn InProgressArray> {
343 macro_rules! instantiate_primitive {
344 ($t:ty) => {
345 Box::new(InProgressPrimitiveArray::<$t>::new(
346 batch_size,
347 data_type.clone(),
348 ))
349 };
350 }
351
352 downcast_primitive! {
353 data_type => (instantiate_primitive),
355 DataType::Utf8View => Box::new(InProgressByteViewArray::<StringViewType>::new(batch_size)),
356 DataType::BinaryView => {
357 Box::new(InProgressByteViewArray::<BinaryViewType>::new(batch_size))
358 }
359 _ => Box::new(GenericInProgressArray::new()),
360 }
361}
362
/// Incrementally builds one output column of a [`BatchCoalescer`].
///
/// The coalescer installs a source array with `set_source`, copies ranges of
/// rows out of it with `copy_rows`, and calls `finish` to obtain the array for
/// a completed batch.
trait InProgressArray: std::fmt::Debug + Send + Sync {
    /// Set (or, with `None`, clear) the array that subsequent `copy_rows`
    /// calls read from.
    fn set_source(&mut self, source: Option<ArrayRef>);

    /// Copy `len` rows starting at row `offset` of the current source.
    ///
    /// NOTE(review): presumably an error (or panic) results when no source is
    /// set; confirm against the implementations.
    fn copy_rows(&mut self, offset: usize, len: usize) -> Result<(), ArrowError>;

    /// Return the rows copied so far as an array.
    ///
    /// The coalescer keeps using the same instance for the next output batch,
    /// so implementations presumably reset themselves here; confirm against
    /// the implementations.
    fn finish(&mut self) -> Result<ArrayRef, ArrowError>;
}
389
390#[cfg(test)]
391mod tests {
392 use super::*;
393 use crate::concat::concat_batches;
394 use arrow_array::builder::StringViewBuilder;
395 use arrow_array::cast::AsArray;
396 use arrow_array::{
397 BinaryViewArray, Int64Array, RecordBatchOptions, StringArray, StringViewArray,
398 TimestampNanosecondArray, UInt32Array,
399 };
400 use arrow_schema::{DataType, Field, Schema};
401 use rand::{Rng, SeedableRng};
402 use std::ops::Range;
403
404 #[test]
405 fn test_coalesce() {
406 let batch = uint32_batch(0..8);
407 Test::new()
408 .with_batches(std::iter::repeat_n(batch, 10))
409 .with_batch_size(21)
411 .with_expected_output_sizes(vec![21, 21, 21, 17])
412 .run();
413 }
414
415 #[test]
416 fn test_coalesce_one_by_one() {
417 let batch = uint32_batch(0..1); Test::new()
419 .with_batches(std::iter::repeat_n(batch, 97))
420 .with_batch_size(20)
422 .with_expected_output_sizes(vec![20, 20, 20, 20, 17])
423 .run();
424 }
425
426 #[test]
427 fn test_coalesce_empty() {
428 let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));
429
430 Test::new()
431 .with_batches(vec![])
432 .with_schema(schema)
433 .with_batch_size(21)
434 .with_expected_output_sizes(vec![])
435 .run();
436 }
437
438 #[test]
439 fn test_single_large_batch_greater_than_target() {
440 let batch = uint32_batch(0..4096);
442 Test::new()
443 .with_batch(batch)
444 .with_batch_size(1000)
445 .with_expected_output_sizes(vec![1000, 1000, 1000, 1000, 96])
446 .run();
447 }
448
449 #[test]
450 fn test_single_large_batch_smaller_than_target() {
451 let batch = uint32_batch(0..4096);
453 Test::new()
454 .with_batch(batch)
455 .with_batch_size(8192)
456 .with_expected_output_sizes(vec![4096])
457 .run();
458 }
459
460 #[test]
461 fn test_single_large_batch_equal_to_target() {
462 let batch = uint32_batch(0..4096);
464 Test::new()
465 .with_batch(batch)
466 .with_batch_size(4096)
467 .with_expected_output_sizes(vec![4096])
468 .run();
469 }
470
471 #[test]
472 fn test_single_large_batch_equally_divisible_in_target() {
473 let batch = uint32_batch(0..4096);
475 Test::new()
476 .with_batch(batch)
477 .with_batch_size(1024)
478 .with_expected_output_sizes(vec![1024, 1024, 1024, 1024])
479 .run();
480 }
481
482 #[test]
483 fn test_empty_schema() {
484 let schema = Schema::empty();
485 let batch = RecordBatch::new_empty(schema.into());
486 Test::new()
487 .with_batch(batch)
488 .with_expected_output_sizes(vec![])
489 .run();
490 }
491
492 #[test]
494 fn test_coalesce_filtered_001() {
495 let mut filter_builder = RandomFilterBuilder {
496 num_rows: 8000,
497 selectivity: 0.001,
498 seed: 0,
499 };
500
501 let mut test = Test::new();
505 for _ in 0..10 {
506 test = test
507 .with_batch(multi_column_batch(0..8000))
508 .with_filter(filter_builder.next_filter())
509 }
510 test.with_batch_size(15)
511 .with_expected_output_sizes(vec![15, 15, 15, 13])
512 .run();
513 }
514
515 #[test]
517 fn test_coalesce_filtered_01() {
518 let mut filter_builder = RandomFilterBuilder {
519 num_rows: 8000,
520 selectivity: 0.01,
521 seed: 0,
522 };
523
524 let mut test = Test::new();
528 for _ in 0..10 {
529 test = test
530 .with_batch(multi_column_batch(0..8000))
531 .with_filter(filter_builder.next_filter())
532 }
533 test.with_batch_size(128)
534 .with_expected_output_sizes(vec![128, 128, 128, 128, 128, 128, 15])
535 .run();
536 }
537
538 #[test]
540 fn test_coalesce_filtered_1() {
541 let mut filter_builder = RandomFilterBuilder {
542 num_rows: 8000,
543 selectivity: 0.1,
544 seed: 0,
545 };
546
547 let mut test = Test::new();
551 for _ in 0..10 {
552 test = test
553 .with_batch(multi_column_batch(0..8000))
554 .with_filter(filter_builder.next_filter())
555 }
556 test.with_batch_size(1024)
557 .with_expected_output_sizes(vec![1024, 1024, 1024, 1024, 1024, 1024, 1024, 840])
558 .run();
559 }
560
561 #[test]
563 fn test_coalesce_filtered_90() {
564 let mut filter_builder = RandomFilterBuilder {
565 num_rows: 800,
566 selectivity: 0.90,
567 seed: 0,
568 };
569
570 let mut test = Test::new();
574 for _ in 0..10 {
575 test = test
576 .with_batch(multi_column_batch(0..800))
577 .with_filter(filter_builder.next_filter())
578 }
579 test.with_batch_size(1024)
580 .with_expected_output_sizes(vec![1024, 1024, 1024, 1024, 1024, 1024, 1024, 13])
581 .run();
582 }
583
584 #[test]
585 fn test_coalesce_non_null() {
586 Test::new()
587 .with_batch(uint32_batch_non_null(0..3000))
589 .with_batch(uint32_batch_non_null(0..1040))
590 .with_batch_size(1024)
591 .with_expected_output_sizes(vec![1024, 1024, 1024, 968])
592 .run();
593 }
594 #[test]
595 fn test_utf8_split() {
596 Test::new()
597 .with_batch(utf8_batch(0..3000))
599 .with_batch(utf8_batch(0..1040))
600 .with_batch_size(1024)
601 .with_expected_output_sizes(vec![1024, 1024, 1024, 968])
602 .run();
603 }
604
605 #[test]
606 fn test_string_view_no_views() {
607 let output_batches = Test::new()
608 .with_batch(stringview_batch([Some("foo"), Some("bar")]))
610 .with_batch(stringview_batch([Some("baz"), Some("qux")]))
611 .with_expected_output_sizes(vec![4])
612 .run();
613
614 expect_buffer_layout(
615 col_as_string_view("c0", output_batches.first().unwrap()),
616 vec![],
617 );
618 }
619
620 #[test]
621 fn test_string_view_batch_small_no_compact() {
622 let batch = stringview_batch_repeated(1000, [Some("a"), Some("b"), Some("c")]);
624 let output_batches = Test::new()
625 .with_batch(batch.clone())
626 .with_expected_output_sizes(vec![1000])
627 .run();
628
629 let array = col_as_string_view("c0", &batch);
630 let gc_array = col_as_string_view("c0", output_batches.first().unwrap());
631 assert_eq!(array.data_buffers().len(), 0);
632 assert_eq!(array.data_buffers().len(), gc_array.data_buffers().len()); expect_buffer_layout(gc_array, vec![]);
635 }
636
637 #[test]
638 fn test_string_view_batch_large_no_compact() {
639 let batch = stringview_batch_repeated(1000, [Some("This string is longer than 12 bytes")]);
641 let output_batches = Test::new()
642 .with_batch(batch.clone())
643 .with_batch_size(1000)
644 .with_expected_output_sizes(vec![1000])
645 .run();
646
647 let array = col_as_string_view("c0", &batch);
648 let gc_array = col_as_string_view("c0", output_batches.first().unwrap());
649 assert_eq!(array.data_buffers().len(), 5);
650 assert_eq!(array.data_buffers().len(), gc_array.data_buffers().len()); expect_buffer_layout(
653 gc_array,
654 vec![
655 ExpectedLayout {
656 len: 8190,
657 capacity: 8192,
658 },
659 ExpectedLayout {
660 len: 8190,
661 capacity: 8192,
662 },
663 ExpectedLayout {
664 len: 8190,
665 capacity: 8192,
666 },
667 ExpectedLayout {
668 len: 8190,
669 capacity: 8192,
670 },
671 ExpectedLayout {
672 len: 2240,
673 capacity: 8192,
674 },
675 ],
676 );
677 }
678
679 #[test]
680 fn test_string_view_batch_small_with_buffers_no_compact() {
681 let short_strings = std::iter::repeat(Some("SmallString"));
683 let long_strings = std::iter::once(Some("This string is longer than 12 bytes"));
684 let values = short_strings.take(20).chain(long_strings);
686 let batch = stringview_batch_repeated(1000, values)
687 .slice(5, 10);
689 let output_batches = Test::new()
690 .with_batch(batch.clone())
691 .with_batch_size(1000)
692 .with_expected_output_sizes(vec![10])
693 .run();
694
695 let array = col_as_string_view("c0", &batch);
696 let gc_array = col_as_string_view("c0", output_batches.first().unwrap());
697 assert_eq!(array.data_buffers().len(), 1); assert_eq!(gc_array.data_buffers().len(), 0); }
700
701 #[test]
702 fn test_string_view_batch_large_slice_compact() {
703 let batch = stringview_batch_repeated(1000, [Some("This string is longer than 12 bytes")])
705 .slice(11, 22);
707
708 let output_batches = Test::new()
709 .with_batch(batch.clone())
710 .with_batch_size(1000)
711 .with_expected_output_sizes(vec![22])
712 .run();
713
714 let array = col_as_string_view("c0", &batch);
715 let gc_array = col_as_string_view("c0", output_batches.first().unwrap());
716 assert_eq!(array.data_buffers().len(), 5);
717
718 expect_buffer_layout(
719 gc_array,
720 vec![ExpectedLayout {
721 len: 770,
722 capacity: 8192,
723 }],
724 );
725 }
726
727 #[test]
728 fn test_string_view_mixed() {
729 let large_view_batch =
730 stringview_batch_repeated(1000, [Some("This string is longer than 12 bytes")]);
731 let small_view_batch = stringview_batch_repeated(1000, [Some("SmallString")]);
732 let mixed_batch = stringview_batch_repeated(
733 1000,
734 [Some("This string is longer than 12 bytes"), Some("Small")],
735 );
736 let mixed_batch_nulls = stringview_batch_repeated(
737 1000,
738 [
739 Some("This string is longer than 12 bytes"),
740 Some("Small"),
741 None,
742 ],
743 );
744
745 let output_batches = Test::new()
748 .with_batch(large_view_batch.clone())
749 .with_batch(small_view_batch)
750 .with_batch(large_view_batch.slice(10, 20))
752 .with_batch(mixed_batch_nulls)
753 .with_batch(large_view_batch.slice(10, 20))
755 .with_batch(mixed_batch)
756 .with_expected_output_sizes(vec![1024, 1024, 1024, 968])
757 .run();
758
759 expect_buffer_layout(
760 col_as_string_view("c0", output_batches.first().unwrap()),
761 vec![
762 ExpectedLayout {
763 len: 8190,
764 capacity: 8192,
765 },
766 ExpectedLayout {
767 len: 8190,
768 capacity: 8192,
769 },
770 ExpectedLayout {
771 len: 8190,
772 capacity: 8192,
773 },
774 ExpectedLayout {
775 len: 8190,
776 capacity: 8192,
777 },
778 ExpectedLayout {
779 len: 2240,
780 capacity: 8192,
781 },
782 ],
783 );
784 }
785
786 #[test]
787 fn test_string_view_many_small_compact() {
788 let batch = stringview_batch_repeated(
791 200,
792 [Some("This string is 28 bytes long"), Some("small string")],
793 );
794 let output_batches = Test::new()
795 .with_batch(batch.clone())
798 .with_batch(batch.clone())
799 .with_batch(batch.clone())
800 .with_batch(batch.clone())
801 .with_batch(batch.clone())
802 .with_batch(batch.clone())
803 .with_batch(batch.clone())
804 .with_batch(batch.clone())
805 .with_batch(batch.clone())
806 .with_batch(batch.clone())
807 .with_batch_size(8000)
808 .with_expected_output_sizes(vec![2000]) .run();
810
811 expect_buffer_layout(
813 col_as_string_view("c0", output_batches.first().unwrap()),
814 vec![
815 ExpectedLayout {
816 len: 8176,
817 capacity: 8192,
818 },
819 ExpectedLayout {
820 len: 16380,
821 capacity: 16384,
822 },
823 ExpectedLayout {
824 len: 3444,
825 capacity: 32768,
826 },
827 ],
828 );
829 }
830
831 #[test]
832 fn test_string_view_many_small_boundary() {
833 let batch = stringview_batch_repeated(100, [Some("This string is a power of two=32")]);
835 let output_batches = Test::new()
836 .with_batches(std::iter::repeat_n(batch, 20))
837 .with_batch_size(900)
838 .with_expected_output_sizes(vec![900, 900, 200])
839 .run();
840
841 expect_buffer_layout(
843 col_as_string_view("c0", output_batches.first().unwrap()),
844 vec![
845 ExpectedLayout {
846 len: 8192,
847 capacity: 8192,
848 },
849 ExpectedLayout {
850 len: 16384,
851 capacity: 16384,
852 },
853 ExpectedLayout {
854 len: 4224,
855 capacity: 32768,
856 },
857 ],
858 );
859 }
860
861 #[test]
862 fn test_string_view_large_small() {
863 let mixed_batch = stringview_batch_repeated(
865 200,
866 [Some("This string is 28 bytes long"), Some("small string")],
867 );
868 let all_large = stringview_batch_repeated(
870 50,
871 [Some(
872 "This buffer has only large strings in it so there are no buffer copies",
873 )],
874 );
875
876 let output_batches = Test::new()
877 .with_batch(mixed_batch.clone())
880 .with_batch(mixed_batch.clone())
881 .with_batch(all_large.clone())
882 .with_batch(mixed_batch.clone())
883 .with_batch(all_large.clone())
884 .with_batch(mixed_batch.clone())
885 .with_batch(mixed_batch.clone())
886 .with_batch(all_large.clone())
887 .with_batch(mixed_batch.clone())
888 .with_batch(all_large.clone())
889 .with_batch_size(8000)
890 .with_expected_output_sizes(vec![1400])
891 .run();
892
893 expect_buffer_layout(
894 col_as_string_view("c0", output_batches.first().unwrap()),
895 vec![
896 ExpectedLayout {
897 len: 8190,
898 capacity: 8192,
899 },
900 ExpectedLayout {
901 len: 16366,
902 capacity: 16384,
903 },
904 ExpectedLayout {
905 len: 6244,
906 capacity: 32768,
907 },
908 ],
909 );
910 }
911
912 #[test]
913 fn test_binary_view() {
914 let values: Vec<Option<&[u8]>> = vec![
915 Some(b"foo"),
916 None,
917 Some(b"A longer string that is more than 12 bytes"),
918 ];
919
920 let binary_view =
921 BinaryViewArray::from_iter(std::iter::repeat(values.iter()).flatten().take(1000));
922 let batch =
923 RecordBatch::try_from_iter(vec![("c0", Arc::new(binary_view) as ArrayRef)]).unwrap();
924
925 Test::new()
926 .with_batch(batch.clone())
927 .with_batch(batch.clone())
928 .with_batch_size(512)
929 .with_expected_output_sizes(vec![512, 512, 512, 464])
930 .run();
931 }
932
    /// Length and capacity of one data buffer of a `StringViewArray`, as
    /// compared by `expect_buffer_layout`.
    #[derive(Debug, Clone, PartialEq)]
    struct ExpectedLayout {
        /// Value reported by the buffer's `len()`
        len: usize,
        /// Value reported by the buffer's `capacity()`
        capacity: usize,
    }
938
939 fn expect_buffer_layout(array: &StringViewArray, expected: Vec<ExpectedLayout>) {
941 let actual = array
942 .data_buffers()
943 .iter()
944 .map(|b| ExpectedLayout {
945 len: b.len(),
946 capacity: b.capacity(),
947 })
948 .collect::<Vec<_>>();
949
950 assert_eq!(
951 actual, expected,
952 "Expected buffer layout {expected:#?} but got {actual:#?}"
953 );
954 }
955
    /// Test fixture: pushes `input_batches` through a `BatchCoalescer` and
    /// checks the sizes and contents of the batches that come out.
    #[derive(Debug, Clone)]
    struct Test {
        /// Batches pushed into the coalescer, in order
        input_batches: Vec<RecordBatch>,
        /// Filters paired with `input_batches` by position; batches beyond
        /// the last filter are pushed unfiltered
        filters: Vec<BooleanArray>,
        /// Schema for the coalescer; when `None`, the schema of the first
        /// input batch is used
        schema: Option<SchemaRef>,
        /// Expected number of rows in each output batch, in order
        expected_output_sizes: Vec<usize>,
        /// Target batch size handed to the coalescer
        target_batch_size: usize,
    }
976
977 impl Default for Test {
978 fn default() -> Self {
979 Self {
980 input_batches: vec![],
981 filters: vec![],
982 schema: None,
983 expected_output_sizes: vec![],
984 target_batch_size: 1024,
985 }
986 }
987 }
988
989 impl Test {
990 fn new() -> Self {
991 Self::default()
992 }
993
994 fn with_batch_size(mut self, target_batch_size: usize) -> Self {
996 self.target_batch_size = target_batch_size;
997 self
998 }
999
1000 fn with_batch(mut self, batch: RecordBatch) -> Self {
1002 self.input_batches.push(batch);
1003 self
1004 }
1005
1006 fn with_filter(mut self, filter: BooleanArray) -> Self {
1008 self.filters.push(filter);
1009 self
1010 }
1011
1012 fn with_batches(mut self, batches: impl IntoIterator<Item = RecordBatch>) -> Self {
1014 self.input_batches.extend(batches);
1015 self
1016 }
1017
1018 fn with_schema(mut self, schema: SchemaRef) -> Self {
1020 self.schema = Some(schema);
1021 self
1022 }
1023
1024 fn with_expected_output_sizes(mut self, sizes: impl IntoIterator<Item = usize>) -> Self {
1026 self.expected_output_sizes.extend(sizes);
1027 self
1028 }
1029
1030 fn run(self) -> Vec<RecordBatch> {
1034 let expected_output = self.expected_output();
1035 let schema = self.schema();
1036
1037 let Self {
1038 input_batches,
1039 filters,
1040 schema: _,
1041 target_batch_size,
1042 expected_output_sizes,
1043 } = self;
1044
1045 let had_input = input_batches.iter().any(|b| b.num_rows() > 0);
1046
1047 let mut coalescer = BatchCoalescer::new(Arc::clone(&schema), target_batch_size);
1048
1049 let mut filters = filters.into_iter();
1051 for batch in input_batches {
1052 if let Some(filter) = filters.next() {
1053 coalescer.push_batch_with_filter(batch, &filter).unwrap();
1054 } else {
1055 coalescer.push_batch(batch).unwrap();
1056 }
1057 }
1058 assert_eq!(schema, coalescer.schema());
1059
1060 if had_input {
1061 assert!(!coalescer.is_empty(), "Coalescer should not be empty");
1062 } else {
1063 assert!(coalescer.is_empty(), "Coalescer should be empty");
1064 }
1065
1066 coalescer.finish_buffered_batch().unwrap();
1067 if had_input {
1068 assert!(
1069 coalescer.has_completed_batch(),
1070 "Coalescer should have completed batches"
1071 );
1072 }
1073
1074 let mut output_batches = vec![];
1075 while let Some(batch) = coalescer.next_completed_batch() {
1076 output_batches.push(batch);
1077 }
1078
1079 let mut starting_idx = 0;
1081 let actual_output_sizes: Vec<usize> =
1082 output_batches.iter().map(|b| b.num_rows()).collect();
1083 assert_eq!(
1084 expected_output_sizes, actual_output_sizes,
1085 "Unexpected number of rows in output batches\n\
1086 Expected\n{expected_output_sizes:#?}\nActual:{actual_output_sizes:#?}"
1087 );
1088 let iter = expected_output_sizes
1089 .iter()
1090 .zip(output_batches.iter())
1091 .enumerate();
1092
1093 for (i, (expected_size, batch)) in iter {
1094 let expected_batch = expected_output.slice(starting_idx, *expected_size);
1097 let expected_batch = normalize_batch(expected_batch);
1098 let batch = normalize_batch(batch.clone());
1099 assert_eq!(
1100 expected_batch, batch,
1101 "Unexpected content in batch {i}:\
1102 \n\nExpected:\n{expected_batch:#?}\n\nActual:\n{batch:#?}"
1103 );
1104 starting_idx += *expected_size;
1105 }
1106 output_batches
1107 }
1108
1109 fn schema(&self) -> SchemaRef {
1112 self.schema
1113 .clone()
1114 .unwrap_or_else(|| Arc::clone(&self.input_batches[0].schema()))
1115 }
1116
1117 fn expected_output(&self) -> RecordBatch {
1119 let schema = self.schema();
1120 if self.filters.is_empty() {
1121 return concat_batches(&schema, &self.input_batches).unwrap();
1122 }
1123
1124 let mut filters = self.filters.iter();
1125 let filtered_batches = self
1126 .input_batches
1127 .iter()
1128 .map(|batch| {
1129 if let Some(filter) = filters.next() {
1130 filter_record_batch(batch, filter).unwrap()
1131 } else {
1132 batch.clone()
1133 }
1134 })
1135 .collect::<Vec<_>>();
1136 concat_batches(&schema, &filtered_batches).unwrap()
1137 }
1138 }
1139
1140 fn uint32_batch(range: Range<u32>) -> RecordBatch {
1143 let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, true)]));
1144
1145 let array = UInt32Array::from_iter(range.map(|i| if i % 3 == 0 { None } else { Some(i) }));
1146 RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
1147 }
1148
1149 fn uint32_batch_non_null(range: Range<u32>) -> RecordBatch {
1151 let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));
1152
1153 let array = UInt32Array::from_iter_values(range);
1154 RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
1155 }
1156
1157 fn utf8_batch(range: Range<u32>) -> RecordBatch {
1160 let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::Utf8, true)]));
1161
1162 let array = StringArray::from_iter(range.map(|i| {
1163 if i % 3 == 0 {
1164 None
1165 } else {
1166 Some(format!("value{i}"))
1167 }
1168 }));
1169
1170 RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
1171 }
1172
1173 fn stringview_batch<'a>(values: impl IntoIterator<Item = Option<&'a str>>) -> RecordBatch {
1175 let schema = Arc::new(Schema::new(vec![Field::new(
1176 "c0",
1177 DataType::Utf8View,
1178 false,
1179 )]));
1180
1181 let array = StringViewArray::from_iter(values);
1182 RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
1183 }
1184
1185 fn stringview_batch_repeated<'a>(
1188 num_rows: usize,
1189 values: impl IntoIterator<Item = Option<&'a str>>,
1190 ) -> RecordBatch {
1191 let schema = Arc::new(Schema::new(vec![Field::new(
1192 "c0",
1193 DataType::Utf8View,
1194 true,
1195 )]));
1196
1197 let values: Vec<_> = values.into_iter().collect();
1199 let values_iter = std::iter::repeat(values.iter())
1200 .flatten()
1201 .cloned()
1202 .take(num_rows);
1203
1204 let mut builder = StringViewBuilder::with_capacity(100).with_fixed_block_size(8192);
1205 for val in values_iter {
1206 builder.append_option(val);
1207 }
1208
1209 let array = builder.finish();
1210 RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
1211 }
1212
1213 fn multi_column_batch(range: Range<i32>) -> RecordBatch {
1215 let int64_array = Int64Array::from_iter(range.clone().map(|v| {
1216 if v % 5 == 0 {
1217 None
1218 } else {
1219 Some(v as i64)
1220 }
1221 }));
1222 let string_view_array = StringViewArray::from_iter(range.clone().map(|v| {
1223 if v % 5 == 0 {
1224 None
1225 } else if v % 7 == 0 {
1226 Some(format!("This is a string longer than 12 bytes{v}"))
1227 } else {
1228 Some(format!("Short {v}"))
1229 }
1230 }));
1231 let string_array = StringArray::from_iter(range.clone().map(|v| {
1232 if v % 11 == 0 {
1233 None
1234 } else {
1235 Some(format!("Value {v}"))
1236 }
1237 }));
1238 let timestamp_array = TimestampNanosecondArray::from_iter(range.map(|v| {
1239 if v % 3 == 0 {
1240 None
1241 } else {
1242 Some(v as i64 * 1000) }
1244 }))
1245 .with_timezone("America/New_York");
1246
1247 RecordBatch::try_from_iter(vec![
1248 ("int64", Arc::new(int64_array) as ArrayRef),
1249 ("stringview", Arc::new(string_view_array) as ArrayRef),
1250 ("string", Arc::new(string_array) as ArrayRef),
1251 ("timestamp", Arc::new(timestamp_array) as ArrayRef),
1252 ])
1253 .unwrap()
1254 }
1255
    /// Produces a reproducible sequence of pseudo-random boolean filters.
    #[derive(Debug)]
    struct RandomFilterBuilder {
        /// Number of rows in each generated filter
        num_rows: usize,
        /// Probability, in `0.0..=1.0`, that a filter row is `true`
        selectivity: f64,
        /// Seed for the next filter; incremented after each `next_filter` call
        seed: u64,
    }
1269 impl RandomFilterBuilder {
1270 fn next_filter(&mut self) -> BooleanArray {
1273 assert!(self.selectivity >= 0.0 && self.selectivity <= 1.0);
1274 let mut rng = rand::rngs::StdRng::seed_from_u64(self.seed);
1275 self.seed += 1;
1276 BooleanArray::from_iter(
1277 (0..self.num_rows)
1278 .map(|_| rng.random_bool(self.selectivity))
1279 .map(Some),
1280 )
1281 }
1282 }
1283
1284 fn col_as_string_view<'b>(name: &str, batch: &'b RecordBatch) -> &'b StringViewArray {
1286 batch
1287 .column_by_name(name)
1288 .expect("column not found")
1289 .as_string_view_opt()
1290 .expect("column is not a string view")
1291 }
1292
1293 fn normalize_batch(batch: RecordBatch) -> RecordBatch {
1296 let (schema, mut columns, row_count) = batch.into_parts();
1298
1299 for column in columns.iter_mut() {
1300 let Some(string_view) = column.as_string_view_opt() else {
1301 continue;
1302 };
1303
1304 let mut builder = StringViewBuilder::new();
1307 for s in string_view.iter() {
1308 builder.append_option(s);
1309 }
1310 *column = Arc::new(builder.finish());
1312 }
1313
1314 let options = RecordBatchOptions::new().with_row_count(Some(row_count));
1315 RecordBatch::try_new_with_options(schema, columns, &options).unwrap()
1316 }
1317}