arrow_select/
coalesce.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`BatchCoalescer`]  concatenates multiple [`RecordBatch`]es after
19//! operations such as [`filter`] and [`take`].
20//!
21//! [`filter`]: crate::filter::filter
22//! [`take`]: crate::take::take
23use crate::filter::filter_record_batch;
24use arrow_array::types::{BinaryViewType, StringViewType};
25use arrow_array::{downcast_primitive, Array, ArrayRef, BooleanArray, RecordBatch};
26use arrow_schema::{ArrowError, DataType, SchemaRef};
27use std::collections::VecDeque;
28use std::sync::Arc;
29// Originally From DataFusion's coalesce module:
30// https://github.com/apache/datafusion/blob/9d2f04996604e709ee440b65f41e7b882f50b788/datafusion/physical-plan/src/coalesce/mod.rs#L26-L25
31
32mod byte_view;
33mod generic;
34mod primitive;
35
36use byte_view::InProgressByteViewArray;
37use generic::GenericInProgressArray;
38use primitive::InProgressPrimitiveArray;
39
40/// Concatenate multiple [`RecordBatch`]es
41///
42/// Implements the common pattern of incrementally creating output
43/// [`RecordBatch`]es of a specific size from an input stream of
44/// [`RecordBatch`]es.
45///
46/// This is useful after operations such as [`filter`] and [`take`] that produce
47/// smaller batches, and we want to coalesce them into larger batches for
48/// further processing.
49///
50/// # Motivation
51///
52/// If we use [`concat_batches`] to implement the same functionality, there are 2 potential issues:
53/// 1. At least 2x peak memory (holding the input and output of concat)
54/// 2. 2 copies of the data (to create the output of filter and then create the output of concat)
55///
56/// See: <https://github.com/apache/arrow-rs/issues/6692> for more discussions
57/// about the motivation.
58///
59/// [`filter`]: crate::filter::filter
60/// [`take`]: crate::take::take
61/// [`concat_batches`]: crate::concat::concat_batches
62///
63/// # Example
64/// ```
65/// use arrow_array::record_batch;
66/// use arrow_select::coalesce::{BatchCoalescer};
67/// let batch1 = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
68/// let batch2 = record_batch!(("a", Int32, [4, 5])).unwrap();
69///
/// // Create a `BatchCoalescer` that will produce batches of exactly 4 rows
71/// let target_batch_size = 4;
72/// let mut coalescer = BatchCoalescer::new(batch1.schema(), 4);
73///
74/// // push the batches
75/// coalescer.push_batch(batch1).unwrap();
/// // only pushed 3 rows (fewer than the 4 needed to produce a batch)
77/// assert!(coalescer.next_completed_batch().is_none());
78/// coalescer.push_batch(batch2).unwrap();
79/// // now we have 5 rows, so we can produce a batch
80/// let finished = coalescer.next_completed_batch().unwrap();
81/// // 4 rows came out (target batch size is 4)
82/// let expected = record_batch!(("a", Int32, [1, 2, 3, 4])).unwrap();
83/// assert_eq!(finished, expected);
84///
85/// // Have no more input, but still have an in-progress batch
86/// assert!(coalescer.next_completed_batch().is_none());
87/// // We can finish the batch, which will produce the remaining rows
88/// coalescer.finish_buffered_batch().unwrap();
89/// let expected = record_batch!(("a", Int32, [5])).unwrap();
90/// assert_eq!(coalescer.next_completed_batch().unwrap(), expected);
91///
92/// // The coalescer is now empty
93/// assert!(coalescer.next_completed_batch().is_none());
94/// ```
95///
96/// # Background
97///
98/// Generally speaking, larger [`RecordBatch`]es are more efficient to process
99/// than smaller [`RecordBatch`]es (until the CPU cache is exceeded) because
100/// there is fixed processing overhead per batch. This coalescer builds up these
101/// larger batches incrementally.
102///
103/// ```text
104/// ┌────────────────────┐
105/// │    RecordBatch     │
106/// │   num_rows = 100   │
107/// └────────────────────┘                 ┌────────────────────┐
108///                                        │                    │
109/// ┌────────────────────┐     Coalesce    │                    │
110/// │                    │      Batches    │                    │
111/// │    RecordBatch     │                 │                    │
112/// │   num_rows = 200   │  ─ ─ ─ ─ ─ ─ ▶  │                    │
113/// │                    │                 │    RecordBatch     │
114/// │                    │                 │   num_rows = 400   │
115/// └────────────────────┘                 │                    │
116///                                        │                    │
117/// ┌────────────────────┐                 │                    │
118/// │                    │                 │                    │
119/// │    RecordBatch     │                 │                    │
120/// │   num_rows = 100   │                 └────────────────────┘
121/// │                    │
122/// └────────────────────┘
123/// ```
124///
125/// # Notes:
126///
127/// 1. Output rows are produced in the same order as the input rows
128///
129/// 2. The output is a sequence of batches, with all but the last being at exactly
130///    `target_batch_size` rows.
#[derive(Debug)]
pub struct BatchCoalescer {
    /// The schema of the input batches, which is also attached to every
    /// output batch
    schema: SchemaRef,
    /// The target batch size (and thus size for views allocation). This is a
    /// hard limit: the output batch will be exactly `target_batch_size`,
    /// rather than possibly being slightly above.
    target_batch_size: usize,
    /// In-progress arrays: one per field of `schema`, in field order
    in_progress_arrays: Vec<Box<dyn InProgressArray>>,
    /// Number of rows currently held in `in_progress_arrays`. Reset to zero
    /// each time a batch is completed, which happens as soon as it reaches
    /// `target_batch_size`
    buffered_rows: usize,
    /// Completed batches, in the order they were finished, waiting to be
    /// retrieved by the caller
    completed: VecDeque<RecordBatch>,
}
146
147impl BatchCoalescer {
148    /// Create a new `BatchCoalescer`
149    ///
150    /// # Arguments
151    /// - `schema` - the schema of the output batches
152    /// - `target_batch_size` - the number of rows in each output batch.
153    ///   Typical values are `4096` or `8192` rows.
154    ///
155    pub fn new(schema: SchemaRef, target_batch_size: usize) -> Self {
156        let in_progress_arrays = schema
157            .fields()
158            .iter()
159            .map(|field| create_in_progress_array(field.data_type(), target_batch_size))
160            .collect::<Vec<_>>();
161
162        Self {
163            schema,
164            target_batch_size,
165            in_progress_arrays,
166            // We will for sure store at least one completed batch
167            completed: VecDeque::with_capacity(1),
168            buffered_rows: 0,
169        }
170    }
171
172    /// Return the schema of the output batches
173    pub fn schema(&self) -> SchemaRef {
174        Arc::clone(&self.schema)
175    }
176
177    /// Push a batch into the Coalescer after applying a filter
178    ///
179    /// This is semantically equivalent of calling [`Self::push_batch`]
180    /// with the results from  [`filter_record_batch`]
181    ///
182    /// # Example
183    /// ```
184    /// # use arrow_array::{record_batch, BooleanArray};
185    /// # use arrow_select::coalesce::BatchCoalescer;
186    /// let batch1 = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
187    /// let batch2 = record_batch!(("a", Int32, [4, 5, 6])).unwrap();
188    /// // Apply a filter to each batch to pick the first and last row
189    /// let filter = BooleanArray::from(vec![true, false, true]);
190    /// // create a new Coalescer that targets creating 1000 row batches
191    /// let mut coalescer = BatchCoalescer::new(batch1.schema(), 1000);
192    /// coalescer.push_batch_with_filter(batch1, &filter);
193    /// coalescer.push_batch_with_filter(batch2, &filter);
194    /// // finsh and retrieve the created batch
195    /// coalescer.finish_buffered_batch().unwrap();
196    /// let completed_batch = coalescer.next_completed_batch().unwrap();
197    /// // filtered out 2 and 5:
198    /// let expected_batch = record_batch!(("a", Int32, [1, 3, 4, 6])).unwrap();
199    /// assert_eq!(completed_batch, expected_batch);
200    /// ```
201    pub fn push_batch_with_filter(
202        &mut self,
203        batch: RecordBatch,
204        filter: &BooleanArray,
205    ) -> Result<(), ArrowError> {
206        // TODO: optimize this to avoid materializing (copying the results
207        // of filter to a new batch)
208        let filtered_batch = filter_record_batch(&batch, filter)?;
209        self.push_batch(filtered_batch)
210    }
211
212    /// Push all the rows from `batch` into the Coalescer
213    ///
214    /// When buffered data plus incoming rows reach `target_batch_size` ,
215    /// completed batches are generated eagerly and can be retrieved via
216    /// [`Self::next_completed_batch()`].
217    /// Output batches contain exactly `target_batch_size` rows, so the tail of
218    /// the input batch may remain buffered.
219    /// Remaining partial data either waits for future input batches or can be
220    /// materialized immediately by calling [`Self::finish_buffered_batch()`].
221    ///
222    /// # Example
223    /// ```
224    /// # use arrow_array::record_batch;
225    /// # use arrow_select::coalesce::BatchCoalescer;
226    /// let batch1 = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
227    /// let batch2 = record_batch!(("a", Int32, [4, 5, 6])).unwrap();
228    /// // create a new Coalescer that targets creating 1000 row batches
229    /// let mut coalescer = BatchCoalescer::new(batch1.schema(), 1000);
230    /// coalescer.push_batch(batch1);
231    /// coalescer.push_batch(batch2);
232    /// // finsh and retrieve the created batch
233    /// coalescer.finish_buffered_batch().unwrap();
234    /// let completed_batch = coalescer.next_completed_batch().unwrap();
235    /// let expected_batch = record_batch!(("a", Int32, [1, 2, 3, 4, 5, 6])).unwrap();
236    /// assert_eq!(completed_batch, expected_batch);
237    /// ```
238    pub fn push_batch(&mut self, batch: RecordBatch) -> Result<(), ArrowError> {
239        let (_schema, arrays, mut num_rows) = batch.into_parts();
240        if num_rows == 0 {
241            return Ok(());
242        }
243
244        // setup input rows
245        assert_eq!(arrays.len(), self.in_progress_arrays.len());
246        self.in_progress_arrays
247            .iter_mut()
248            .zip(arrays)
249            .for_each(|(in_progress, array)| {
250                in_progress.set_source(Some(array));
251            });
252
253        // If pushing this batch would exceed the target batch size,
254        // finish the current batch and start a new one
255        let mut offset = 0;
256        while num_rows > (self.target_batch_size - self.buffered_rows) {
257            let remaining_rows = self.target_batch_size - self.buffered_rows;
258            debug_assert!(remaining_rows > 0);
259
260            // Copy remaining_rows from each array
261            for in_progress in self.in_progress_arrays.iter_mut() {
262                in_progress.copy_rows(offset, remaining_rows)?;
263            }
264
265            self.buffered_rows += remaining_rows;
266            offset += remaining_rows;
267            num_rows -= remaining_rows;
268
269            self.finish_buffered_batch()?;
270        }
271
272        // Add any the remaining rows to the buffer
273        self.buffered_rows += num_rows;
274        if num_rows > 0 {
275            for in_progress in self.in_progress_arrays.iter_mut() {
276                in_progress.copy_rows(offset, num_rows)?;
277            }
278        }
279
280        // If we have reached the target batch size, finalize the buffered batch
281        if self.buffered_rows >= self.target_batch_size {
282            self.finish_buffered_batch()?;
283        }
284
285        // clear in progress sources (to allow the memory to be freed)
286        for in_progress in self.in_progress_arrays.iter_mut() {
287            in_progress.set_source(None);
288        }
289
290        Ok(())
291    }
292
293    /// Concatenates any buffered batches into a single `RecordBatch` and
294    /// clears any output buffers
295    ///
296    /// Normally this is called when the input stream is exhausted, and
297    /// we want to finalize the last batch of rows.
298    ///
299    /// See [`Self::next_completed_batch()`] for the completed batches.
300    pub fn finish_buffered_batch(&mut self) -> Result<(), ArrowError> {
301        if self.buffered_rows == 0 {
302            return Ok(());
303        }
304        let new_arrays = self
305            .in_progress_arrays
306            .iter_mut()
307            .map(|array| array.finish())
308            .collect::<Result<Vec<_>, ArrowError>>()?;
309
310        for (array, field) in new_arrays.iter().zip(self.schema.fields().iter()) {
311            debug_assert_eq!(array.data_type(), field.data_type());
312            debug_assert_eq!(array.len(), self.buffered_rows);
313        }
314
315        // SAFETY: each array was created of the correct type and length.
316        let batch = unsafe {
317            RecordBatch::new_unchecked(Arc::clone(&self.schema), new_arrays, self.buffered_rows)
318        };
319
320        self.buffered_rows = 0;
321        self.completed.push_back(batch);
322        Ok(())
323    }
324
325    /// Returns true if there is any buffered data
326    pub fn is_empty(&self) -> bool {
327        self.buffered_rows == 0 && self.completed.is_empty()
328    }
329
330    /// Returns true if there are any completed batches
331    pub fn has_completed_batch(&self) -> bool {
332        !self.completed.is_empty()
333    }
334
335    /// Removes and returns the next completed batch, if any.
336    pub fn next_completed_batch(&mut self) -> Option<RecordBatch> {
337        self.completed.pop_front()
338    }
339}
340
341/// Return a new `InProgressArray` for the given data type
342fn create_in_progress_array(data_type: &DataType, batch_size: usize) -> Box<dyn InProgressArray> {
343    macro_rules! instantiate_primitive {
344        ($t:ty) => {
345            Box::new(InProgressPrimitiveArray::<$t>::new(
346                batch_size,
347                data_type.clone(),
348            ))
349        };
350    }
351
352    downcast_primitive! {
353        // Instantiate InProgressPrimitiveArray for each primitive type
354        data_type => (instantiate_primitive),
355        DataType::Utf8View => Box::new(InProgressByteViewArray::<StringViewType>::new(batch_size)),
356        DataType::BinaryView => {
357            Box::new(InProgressByteViewArray::<BinaryViewType>::new(batch_size))
358        }
359        _ => Box::new(GenericInProgressArray::new()),
360    }
361}
362
/// Incrementally builds up arrays
///
/// [`GenericInProgressArray`] is the default implementation: it buffers
/// arrays and uses other kernels to concatenate them when finished.
///
/// Some array types have specialized implementations (e.g.,
/// [`StringViewArray`], etc.).
///
/// [`StringViewArray`]: arrow_array::StringViewArray
trait InProgressArray: std::fmt::Debug + Send + Sync {
    /// Set the source array.
    ///
    /// Calls to [`Self::copy_rows`] will copy rows from this array into the
    /// current in-progress array. Passing `None` clears the source.
    fn set_source(&mut self, source: Option<ArrayRef>);

    /// Copy `len` rows, starting at row `offset`, from the current source
    /// array into the in-progress array
    ///
    /// The source array is set by [`Self::set_source`].
    ///
    /// Return an error if the source array is not set
    fn copy_rows(&mut self, offset: usize, len: usize) -> Result<(), ArrowError>;

    /// Finish the currently in-progress array and return it as an `ArrayRef`
    fn finish(&mut self) -> Result<ArrayRef, ArrowError>;
}
389
390#[cfg(test)]
391mod tests {
392    use super::*;
393    use crate::concat::concat_batches;
394    use arrow_array::builder::StringViewBuilder;
395    use arrow_array::cast::AsArray;
396    use arrow_array::{
397        BinaryViewArray, Int64Array, RecordBatchOptions, StringArray, StringViewArray,
398        TimestampNanosecondArray, UInt32Array,
399    };
400    use arrow_schema::{DataType, Field, Schema};
401    use rand::{Rng, SeedableRng};
402    use std::ops::Range;
403
404    #[test]
405    fn test_coalesce() {
406        let batch = uint32_batch(0..8);
407        Test::new()
408            .with_batches(std::iter::repeat_n(batch, 10))
409            // expected output is exactly 21 rows (except for the final batch)
410            .with_batch_size(21)
411            .with_expected_output_sizes(vec![21, 21, 21, 17])
412            .run();
413    }
414
415    #[test]
416    fn test_coalesce_one_by_one() {
417        let batch = uint32_batch(0..1); // single row input
418        Test::new()
419            .with_batches(std::iter::repeat_n(batch, 97))
420            // expected output is exactly 20 rows (except for the final batch)
421            .with_batch_size(20)
422            .with_expected_output_sizes(vec![20, 20, 20, 20, 17])
423            .run();
424    }
425
426    #[test]
427    fn test_coalesce_empty() {
428        let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));
429
430        Test::new()
431            .with_batches(vec![])
432            .with_schema(schema)
433            .with_batch_size(21)
434            .with_expected_output_sizes(vec![])
435            .run();
436    }
437
438    #[test]
439    fn test_single_large_batch_greater_than_target() {
440        // test a single large batch
441        let batch = uint32_batch(0..4096);
442        Test::new()
443            .with_batch(batch)
444            .with_batch_size(1000)
445            .with_expected_output_sizes(vec![1000, 1000, 1000, 1000, 96])
446            .run();
447    }
448
449    #[test]
450    fn test_single_large_batch_smaller_than_target() {
451        // test a single large batch
452        let batch = uint32_batch(0..4096);
453        Test::new()
454            .with_batch(batch)
455            .with_batch_size(8192)
456            .with_expected_output_sizes(vec![4096])
457            .run();
458    }
459
460    #[test]
461    fn test_single_large_batch_equal_to_target() {
462        // test a single large batch
463        let batch = uint32_batch(0..4096);
464        Test::new()
465            .with_batch(batch)
466            .with_batch_size(4096)
467            .with_expected_output_sizes(vec![4096])
468            .run();
469    }
470
471    #[test]
472    fn test_single_large_batch_equally_divisible_in_target() {
473        // test a single large batch
474        let batch = uint32_batch(0..4096);
475        Test::new()
476            .with_batch(batch)
477            .with_batch_size(1024)
478            .with_expected_output_sizes(vec![1024, 1024, 1024, 1024])
479            .run();
480    }
481
482    #[test]
483    fn test_empty_schema() {
484        let schema = Schema::empty();
485        let batch = RecordBatch::new_empty(schema.into());
486        Test::new()
487            .with_batch(batch)
488            .with_expected_output_sizes(vec![])
489            .run();
490    }
491
492    /// Coalesce multiple batches, 80k rows, with a 0.1% selectivity filter
493    #[test]
494    fn test_coalesce_filtered_001() {
495        let mut filter_builder = RandomFilterBuilder {
496            num_rows: 8000,
497            selectivity: 0.001,
498            seed: 0,
499        };
500
501        // add 10 batches of 8000 rows each
502        // 80k rows, selecting 0.1% means 80 rows
503        // not exactly 80 as the rows are random;
504        let mut test = Test::new();
505        for _ in 0..10 {
506            test = test
507                .with_batch(multi_column_batch(0..8000))
508                .with_filter(filter_builder.next_filter())
509        }
510        test.with_batch_size(15)
511            .with_expected_output_sizes(vec![15, 15, 15, 13])
512            .run();
513    }
514
515    /// Coalesce multiple batches, 80k rows, with a 1% selectivity filter
516    #[test]
517    fn test_coalesce_filtered_01() {
518        let mut filter_builder = RandomFilterBuilder {
519            num_rows: 8000,
520            selectivity: 0.01,
521            seed: 0,
522        };
523
524        // add 10 batches of 8000 rows each
525        // 80k rows, selecting 1% means 800 rows
526        // not exactly 800 as the rows are random;
527        let mut test = Test::new();
528        for _ in 0..10 {
529            test = test
530                .with_batch(multi_column_batch(0..8000))
531                .with_filter(filter_builder.next_filter())
532        }
533        test.with_batch_size(128)
534            .with_expected_output_sizes(vec![128, 128, 128, 128, 128, 128, 15])
535            .run();
536    }
537
538    /// Coalesce multiple batches, 80k rows, with a 10% selectivity filter
539    #[test]
540    fn test_coalesce_filtered_1() {
541        let mut filter_builder = RandomFilterBuilder {
542            num_rows: 8000,
543            selectivity: 0.1,
544            seed: 0,
545        };
546
547        // add 10 batches of 8000 rows each
548        // 80k rows, selecting 10% means 8000 rows
549        // not exactly 800 as the rows are random;
550        let mut test = Test::new();
551        for _ in 0..10 {
552            test = test
553                .with_batch(multi_column_batch(0..8000))
554                .with_filter(filter_builder.next_filter())
555        }
556        test.with_batch_size(1024)
557            .with_expected_output_sizes(vec![1024, 1024, 1024, 1024, 1024, 1024, 1024, 840])
558            .run();
559    }
560
561    /// Coalesce multiple batches, 8k rows, with a 90% selectivity filter
562    #[test]
563    fn test_coalesce_filtered_90() {
564        let mut filter_builder = RandomFilterBuilder {
565            num_rows: 800,
566            selectivity: 0.90,
567            seed: 0,
568        };
569
570        // add 10 batches of 800 rows each
571        // 8k rows, selecting 99% means 7200 rows
572        // not exactly 7200 as the rows are random;
573        let mut test = Test::new();
574        for _ in 0..10 {
575            test = test
576                .with_batch(multi_column_batch(0..800))
577                .with_filter(filter_builder.next_filter())
578        }
579        test.with_batch_size(1024)
580            .with_expected_output_sizes(vec![1024, 1024, 1024, 1024, 1024, 1024, 1024, 13])
581            .run();
582    }
583
584    #[test]
585    fn test_coalesce_non_null() {
586        Test::new()
587            // 4040 rows of unit32
588            .with_batch(uint32_batch_non_null(0..3000))
589            .with_batch(uint32_batch_non_null(0..1040))
590            .with_batch_size(1024)
591            .with_expected_output_sizes(vec![1024, 1024, 1024, 968])
592            .run();
593    }
594    #[test]
595    fn test_utf8_split() {
596        Test::new()
597            // 4040 rows of utf8 strings in total, split into batches of 1024
598            .with_batch(utf8_batch(0..3000))
599            .with_batch(utf8_batch(0..1040))
600            .with_batch_size(1024)
601            .with_expected_output_sizes(vec![1024, 1024, 1024, 968])
602            .run();
603    }
604
605    #[test]
606    fn test_string_view_no_views() {
607        let output_batches = Test::new()
608            // both input batches have no views, so no need to compact
609            .with_batch(stringview_batch([Some("foo"), Some("bar")]))
610            .with_batch(stringview_batch([Some("baz"), Some("qux")]))
611            .with_expected_output_sizes(vec![4])
612            .run();
613
614        expect_buffer_layout(
615            col_as_string_view("c0", output_batches.first().unwrap()),
616            vec![],
617        );
618    }
619
620    #[test]
621    fn test_string_view_batch_small_no_compact() {
622        // view with only short strings (no buffers) --> no need to compact
623        let batch = stringview_batch_repeated(1000, [Some("a"), Some("b"), Some("c")]);
624        let output_batches = Test::new()
625            .with_batch(batch.clone())
626            .with_expected_output_sizes(vec![1000])
627            .run();
628
629        let array = col_as_string_view("c0", &batch);
630        let gc_array = col_as_string_view("c0", output_batches.first().unwrap());
631        assert_eq!(array.data_buffers().len(), 0);
632        assert_eq!(array.data_buffers().len(), gc_array.data_buffers().len()); // no compaction
633
634        expect_buffer_layout(gc_array, vec![]);
635    }
636
637    #[test]
638    fn test_string_view_batch_large_no_compact() {
639        // view with large strings (has buffers) but full --> no need to compact
640        let batch = stringview_batch_repeated(1000, [Some("This string is longer than 12 bytes")]);
641        let output_batches = Test::new()
642            .with_batch(batch.clone())
643            .with_batch_size(1000)
644            .with_expected_output_sizes(vec![1000])
645            .run();
646
647        let array = col_as_string_view("c0", &batch);
648        let gc_array = col_as_string_view("c0", output_batches.first().unwrap());
649        assert_eq!(array.data_buffers().len(), 5);
650        assert_eq!(array.data_buffers().len(), gc_array.data_buffers().len()); // no compaction
651
652        expect_buffer_layout(
653            gc_array,
654            vec![
655                ExpectedLayout {
656                    len: 8190,
657                    capacity: 8192,
658                },
659                ExpectedLayout {
660                    len: 8190,
661                    capacity: 8192,
662                },
663                ExpectedLayout {
664                    len: 8190,
665                    capacity: 8192,
666                },
667                ExpectedLayout {
668                    len: 8190,
669                    capacity: 8192,
670                },
671                ExpectedLayout {
672                    len: 2240,
673                    capacity: 8192,
674                },
675            ],
676        );
677    }
678
679    #[test]
680    fn test_string_view_batch_small_with_buffers_no_compact() {
681        // view with buffers but only short views
682        let short_strings = std::iter::repeat(Some("SmallString"));
683        let long_strings = std::iter::once(Some("This string is longer than 12 bytes"));
684        // 20 short strings, then a long ones
685        let values = short_strings.take(20).chain(long_strings);
686        let batch = stringview_batch_repeated(1000, values)
687            // take only 10 short strings (no long ones)
688            .slice(5, 10);
689        let output_batches = Test::new()
690            .with_batch(batch.clone())
691            .with_batch_size(1000)
692            .with_expected_output_sizes(vec![10])
693            .run();
694
695        let array = col_as_string_view("c0", &batch);
696        let gc_array = col_as_string_view("c0", output_batches.first().unwrap());
697        assert_eq!(array.data_buffers().len(), 1); // input has one buffer
698        assert_eq!(gc_array.data_buffers().len(), 0); // output has no buffers as only short strings
699    }
700
701    #[test]
702    fn test_string_view_batch_large_slice_compact() {
703        // view with large strings (has buffers) and only partially used  --> no need to compact
704        let batch = stringview_batch_repeated(1000, [Some("This string is longer than 12 bytes")])
705            // slice only 22 rows, so most of the buffer is not used
706            .slice(11, 22);
707
708        let output_batches = Test::new()
709            .with_batch(batch.clone())
710            .with_batch_size(1000)
711            .with_expected_output_sizes(vec![22])
712            .run();
713
714        let array = col_as_string_view("c0", &batch);
715        let gc_array = col_as_string_view("c0", output_batches.first().unwrap());
716        assert_eq!(array.data_buffers().len(), 5);
717
718        expect_buffer_layout(
719            gc_array,
720            vec![ExpectedLayout {
721                len: 770,
722                capacity: 8192,
723            }],
724        );
725    }
726
727    #[test]
728    fn test_string_view_mixed() {
729        let large_view_batch =
730            stringview_batch_repeated(1000, [Some("This string is longer than 12 bytes")]);
731        let small_view_batch = stringview_batch_repeated(1000, [Some("SmallString")]);
732        let mixed_batch = stringview_batch_repeated(
733            1000,
734            [Some("This string is longer than 12 bytes"), Some("Small")],
735        );
736        let mixed_batch_nulls = stringview_batch_repeated(
737            1000,
738            [
739                Some("This string is longer than 12 bytes"),
740                Some("Small"),
741                None,
742            ],
743        );
744
745        // Several batches with mixed inline / non inline
746        // 4k rows in
747        let output_batches = Test::new()
748            .with_batch(large_view_batch.clone())
749            .with_batch(small_view_batch)
750            // this batch needs to be compacted (less than 1/2 full)
751            .with_batch(large_view_batch.slice(10, 20))
752            .with_batch(mixed_batch_nulls)
753            // this batch needs to be compacted (less than 1/2 full)
754            .with_batch(large_view_batch.slice(10, 20))
755            .with_batch(mixed_batch)
756            .with_expected_output_sizes(vec![1024, 1024, 1024, 968])
757            .run();
758
759        expect_buffer_layout(
760            col_as_string_view("c0", output_batches.first().unwrap()),
761            vec![
762                ExpectedLayout {
763                    len: 8190,
764                    capacity: 8192,
765                },
766                ExpectedLayout {
767                    len: 8190,
768                    capacity: 8192,
769                },
770                ExpectedLayout {
771                    len: 8190,
772                    capacity: 8192,
773                },
774                ExpectedLayout {
775                    len: 8190,
776                    capacity: 8192,
777                },
778                ExpectedLayout {
779                    len: 2240,
780                    capacity: 8192,
781                },
782            ],
783        );
784    }
785
786    #[test]
787    fn test_string_view_many_small_compact() {
788        // 200 rows alternating long (28) and short (≤12) strings.
789        // Only the 100 long strings go into data buffers: 100 × 28 = 2800.
790        let batch = stringview_batch_repeated(
791            200,
792            [Some("This string is 28 bytes long"), Some("small string")],
793        );
794        let output_batches = Test::new()
795            // First allocated buffer is 8kb.
796            // Appending 10 batches of 2800 bytes will use 2800 * 10 = 14kb (8kb, an 16kb and 32kbkb)
797            .with_batch(batch.clone())
798            .with_batch(batch.clone())
799            .with_batch(batch.clone())
800            .with_batch(batch.clone())
801            .with_batch(batch.clone())
802            .with_batch(batch.clone())
803            .with_batch(batch.clone())
804            .with_batch(batch.clone())
805            .with_batch(batch.clone())
806            .with_batch(batch.clone())
807            .with_batch_size(8000)
808            .with_expected_output_sizes(vec![2000]) // only 1000 rows total
809            .run();
810
811        // expect a nice even distribution of buffers
812        expect_buffer_layout(
813            col_as_string_view("c0", output_batches.first().unwrap()),
814            vec![
815                ExpectedLayout {
816                    len: 8176,
817                    capacity: 8192,
818                },
819                ExpectedLayout {
820                    len: 16380,
821                    capacity: 16384,
822                },
823                ExpectedLayout {
824                    len: 3444,
825                    capacity: 32768,
826                },
827            ],
828        );
829    }
830
831    #[test]
832    fn test_string_view_many_small_boundary() {
833        // The strings are designed to exactly fit into buffers that are powers of 2 long
834        let batch = stringview_batch_repeated(100, [Some("This string is a power of two=32")]);
835        let output_batches = Test::new()
836            .with_batches(std::iter::repeat_n(batch, 20))
837            .with_batch_size(900)
838            .with_expected_output_sizes(vec![900, 900, 200])
839            .run();
840
841        // expect each buffer to be entirely full except the last one
842        expect_buffer_layout(
843            col_as_string_view("c0", output_batches.first().unwrap()),
844            vec![
845                ExpectedLayout {
846                    len: 8192,
847                    capacity: 8192,
848                },
849                ExpectedLayout {
850                    len: 16384,
851                    capacity: 16384,
852                },
853                ExpectedLayout {
854                    len: 4224,
855                    capacity: 32768,
856                },
857            ],
858        );
859    }
860
861    #[test]
862    fn test_string_view_large_small() {
863        // The strings are 37 bytes long, so each batch has 100 * 28 = 2800 bytes
864        let mixed_batch = stringview_batch_repeated(
865            200,
866            [Some("This string is 28 bytes long"), Some("small string")],
867        );
868        // These strings aren't copied, this array has an 8k buffer
869        let all_large = stringview_batch_repeated(
870            50,
871            [Some(
872                "This buffer has only large strings in it so there are no buffer copies",
873            )],
874        );
875
876        let output_batches = Test::new()
877            // First allocated buffer is 8kb.
878            // Appending five batches of 2800 bytes will use 2800 * 10 = 28kb (8kb, an 16kb and 32kbkb)
879            .with_batch(mixed_batch.clone())
880            .with_batch(mixed_batch.clone())
881            .with_batch(all_large.clone())
882            .with_batch(mixed_batch.clone())
883            .with_batch(all_large.clone())
884            .with_batch(mixed_batch.clone())
885            .with_batch(mixed_batch.clone())
886            .with_batch(all_large.clone())
887            .with_batch(mixed_batch.clone())
888            .with_batch(all_large.clone())
889            .with_batch_size(8000)
890            .with_expected_output_sizes(vec![1400])
891            .run();
892
893        expect_buffer_layout(
894            col_as_string_view("c0", output_batches.first().unwrap()),
895            vec![
896                ExpectedLayout {
897                    len: 8190,
898                    capacity: 8192,
899                },
900                ExpectedLayout {
901                    len: 16366,
902                    capacity: 16384,
903                },
904                ExpectedLayout {
905                    len: 6244,
906                    capacity: 32768,
907                },
908            ],
909        );
910    }
911
912    #[test]
913    fn test_binary_view() {
914        let values: Vec<Option<&[u8]>> = vec![
915            Some(b"foo"),
916            None,
917            Some(b"A longer string that is more than 12 bytes"),
918        ];
919
920        let binary_view =
921            BinaryViewArray::from_iter(std::iter::repeat(values.iter()).flatten().take(1000));
922        let batch =
923            RecordBatch::try_from_iter(vec![("c0", Arc::new(binary_view) as ArrayRef)]).unwrap();
924
925        Test::new()
926            .with_batch(batch.clone())
927            .with_batch(batch.clone())
928            .with_batch_size(512)
929            .with_expected_output_sizes(vec![512, 512, 512, 464])
930            .run();
931    }
932
933    #[derive(Debug, Clone, PartialEq)]
934    struct ExpectedLayout {
935        len: usize,
936        capacity: usize,
937    }
938
939    /// Asserts that the buffer layout of the specified StringViewArray matches the expected layout
940    fn expect_buffer_layout(array: &StringViewArray, expected: Vec<ExpectedLayout>) {
941        let actual = array
942            .data_buffers()
943            .iter()
944            .map(|b| ExpectedLayout {
945                len: b.len(),
946                capacity: b.capacity(),
947            })
948            .collect::<Vec<_>>();
949
950        assert_eq!(
951            actual, expected,
952            "Expected buffer layout {expected:#?} but got {actual:#?}"
953        );
954    }
955
956    /// Test for [`BatchCoalescer`]
957    ///
958    /// Pushes the input batches to the coalescer and verifies that the resulting
959    /// batches have the expected number of rows and contents.
960    #[derive(Debug, Clone)]
961    struct Test {
962        /// Batches to feed to the coalescer.
963        input_batches: Vec<RecordBatch>,
964        /// Filters to apply to the corresponding input batches.
965        ///
966        /// If there are no filters for the input batches, the batch will be
967        /// pushed as is.
968        filters: Vec<BooleanArray>,
969        /// The schema. If not provided, the first batch's schema is used.
970        schema: Option<SchemaRef>,
971        /// Expected output sizes of the resulting batches
972        expected_output_sizes: Vec<usize>,
973        /// target batch size (default to 1024)
974        target_batch_size: usize,
975    }
976
977    impl Default for Test {
978        fn default() -> Self {
979            Self {
980                input_batches: vec![],
981                filters: vec![],
982                schema: None,
983                expected_output_sizes: vec![],
984                target_batch_size: 1024,
985            }
986        }
987    }
988
989    impl Test {
990        fn new() -> Self {
991            Self::default()
992        }
993
994        /// Set the target batch size
995        fn with_batch_size(mut self, target_batch_size: usize) -> Self {
996            self.target_batch_size = target_batch_size;
997            self
998        }
999
1000        /// Extend the input batches with `batch`
1001        fn with_batch(mut self, batch: RecordBatch) -> Self {
1002            self.input_batches.push(batch);
1003            self
1004        }
1005
1006        /// Extend the filters with `filter`
1007        fn with_filter(mut self, filter: BooleanArray) -> Self {
1008            self.filters.push(filter);
1009            self
1010        }
1011
1012        /// Extends the input batches with `batches`
1013        fn with_batches(mut self, batches: impl IntoIterator<Item = RecordBatch>) -> Self {
1014            self.input_batches.extend(batches);
1015            self
1016        }
1017
1018        /// Specifies the schema for the test
1019        fn with_schema(mut self, schema: SchemaRef) -> Self {
1020            self.schema = Some(schema);
1021            self
1022        }
1023
1024        /// Extends `sizes` to expected output sizes
1025        fn with_expected_output_sizes(mut self, sizes: impl IntoIterator<Item = usize>) -> Self {
1026            self.expected_output_sizes.extend(sizes);
1027            self
1028        }
1029
1030        /// Runs the test -- see documentation on [`Test`] for details
1031        ///
1032        /// Returns the resulting output batches
1033        fn run(self) -> Vec<RecordBatch> {
1034            let expected_output = self.expected_output();
1035            let schema = self.schema();
1036
1037            let Self {
1038                input_batches,
1039                filters,
1040                schema: _,
1041                target_batch_size,
1042                expected_output_sizes,
1043            } = self;
1044
1045            let had_input = input_batches.iter().any(|b| b.num_rows() > 0);
1046
1047            let mut coalescer = BatchCoalescer::new(Arc::clone(&schema), target_batch_size);
1048
1049            // feed input batches and filters to the coalescer
1050            let mut filters = filters.into_iter();
1051            for batch in input_batches {
1052                if let Some(filter) = filters.next() {
1053                    coalescer.push_batch_with_filter(batch, &filter).unwrap();
1054                } else {
1055                    coalescer.push_batch(batch).unwrap();
1056                }
1057            }
1058            assert_eq!(schema, coalescer.schema());
1059
1060            if had_input {
1061                assert!(!coalescer.is_empty(), "Coalescer should not be empty");
1062            } else {
1063                assert!(coalescer.is_empty(), "Coalescer should be empty");
1064            }
1065
1066            coalescer.finish_buffered_batch().unwrap();
1067            if had_input {
1068                assert!(
1069                    coalescer.has_completed_batch(),
1070                    "Coalescer should have completed batches"
1071                );
1072            }
1073
1074            let mut output_batches = vec![];
1075            while let Some(batch) = coalescer.next_completed_batch() {
1076                output_batches.push(batch);
1077            }
1078
1079            // make sure we got the expected number of output batches and content
1080            let mut starting_idx = 0;
1081            let actual_output_sizes: Vec<usize> =
1082                output_batches.iter().map(|b| b.num_rows()).collect();
1083            assert_eq!(
1084                expected_output_sizes, actual_output_sizes,
1085                "Unexpected number of rows in output batches\n\
1086                Expected\n{expected_output_sizes:#?}\nActual:{actual_output_sizes:#?}"
1087            );
1088            let iter = expected_output_sizes
1089                .iter()
1090                .zip(output_batches.iter())
1091                .enumerate();
1092
1093            for (i, (expected_size, batch)) in iter {
1094                // compare the contents of the batch after normalization (using
1095                // `==` compares the underlying memory layout too)
1096                let expected_batch = expected_output.slice(starting_idx, *expected_size);
1097                let expected_batch = normalize_batch(expected_batch);
1098                let batch = normalize_batch(batch.clone());
1099                assert_eq!(
1100                    expected_batch, batch,
1101                    "Unexpected content in batch {i}:\
1102                    \n\nExpected:\n{expected_batch:#?}\n\nActual:\n{batch:#?}"
1103                );
1104                starting_idx += *expected_size;
1105            }
1106            output_batches
1107        }
1108
1109        /// Return the expected output schema. If not overridden by `with_schema`, it
1110        /// returns the schema of the first input batch.
1111        fn schema(&self) -> SchemaRef {
1112            self.schema
1113                .clone()
1114                .unwrap_or_else(|| Arc::clone(&self.input_batches[0].schema()))
1115        }
1116
1117        /// Returns the expected output as a single `RecordBatch`
1118        fn expected_output(&self) -> RecordBatch {
1119            let schema = self.schema();
1120            if self.filters.is_empty() {
1121                return concat_batches(&schema, &self.input_batches).unwrap();
1122            }
1123
1124            let mut filters = self.filters.iter();
1125            let filtered_batches = self
1126                .input_batches
1127                .iter()
1128                .map(|batch| {
1129                    if let Some(filter) = filters.next() {
1130                        filter_record_batch(batch, filter).unwrap()
1131                    } else {
1132                        batch.clone()
1133                    }
1134                })
1135                .collect::<Vec<_>>();
1136            concat_batches(&schema, &filtered_batches).unwrap()
1137        }
1138    }
1139
1140    /// Return a RecordBatch with a UInt32Array with the specified range and
1141    /// every third value is null.
1142    fn uint32_batch(range: Range<u32>) -> RecordBatch {
1143        let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, true)]));
1144
1145        let array = UInt32Array::from_iter(range.map(|i| if i % 3 == 0 { None } else { Some(i) }));
1146        RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
1147    }
1148
1149    /// Return a RecordBatch with a UInt32Array with no nulls specified range
1150    fn uint32_batch_non_null(range: Range<u32>) -> RecordBatch {
1151        let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));
1152
1153        let array = UInt32Array::from_iter_values(range);
1154        RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
1155    }
1156
1157    /// Return a RecordBatch with a StringArrary with values `value0`, `value1`, ...
1158    /// and every third value is `None`.
1159    fn utf8_batch(range: Range<u32>) -> RecordBatch {
1160        let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::Utf8, true)]));
1161
1162        let array = StringArray::from_iter(range.map(|i| {
1163            if i % 3 == 0 {
1164                None
1165            } else {
1166                Some(format!("value{i}"))
1167            }
1168        }));
1169
1170        RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
1171    }
1172
1173    /// Return a RecordBatch with a StringViewArray with (only) the specified values
1174    fn stringview_batch<'a>(values: impl IntoIterator<Item = Option<&'a str>>) -> RecordBatch {
1175        let schema = Arc::new(Schema::new(vec![Field::new(
1176            "c0",
1177            DataType::Utf8View,
1178            false,
1179        )]));
1180
1181        let array = StringViewArray::from_iter(values);
1182        RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
1183    }
1184
1185    /// Return a RecordBatch with a StringViewArray with num_rows by repeating
1186    /// values over and over.
1187    fn stringview_batch_repeated<'a>(
1188        num_rows: usize,
1189        values: impl IntoIterator<Item = Option<&'a str>>,
1190    ) -> RecordBatch {
1191        let schema = Arc::new(Schema::new(vec![Field::new(
1192            "c0",
1193            DataType::Utf8View,
1194            true,
1195        )]));
1196
1197        // Repeat the values to a total of num_rows
1198        let values: Vec<_> = values.into_iter().collect();
1199        let values_iter = std::iter::repeat(values.iter())
1200            .flatten()
1201            .cloned()
1202            .take(num_rows);
1203
1204        let mut builder = StringViewBuilder::with_capacity(100).with_fixed_block_size(8192);
1205        for val in values_iter {
1206            builder.append_option(val);
1207        }
1208
1209        let array = builder.finish();
1210        RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
1211    }
1212
1213    /// Return a RecordBatch of 100 rows
1214    fn multi_column_batch(range: Range<i32>) -> RecordBatch {
1215        let int64_array = Int64Array::from_iter(range.clone().map(|v| {
1216            if v % 5 == 0 {
1217                None
1218            } else {
1219                Some(v as i64)
1220            }
1221        }));
1222        let string_view_array = StringViewArray::from_iter(range.clone().map(|v| {
1223            if v % 5 == 0 {
1224                None
1225            } else if v % 7 == 0 {
1226                Some(format!("This is a string longer than 12 bytes{v}"))
1227            } else {
1228                Some(format!("Short {v}"))
1229            }
1230        }));
1231        let string_array = StringArray::from_iter(range.clone().map(|v| {
1232            if v % 11 == 0 {
1233                None
1234            } else {
1235                Some(format!("Value {v}"))
1236            }
1237        }));
1238        let timestamp_array = TimestampNanosecondArray::from_iter(range.map(|v| {
1239            if v % 3 == 0 {
1240                None
1241            } else {
1242                Some(v as i64 * 1000) // simulate a timestamp in milliseconds
1243            }
1244        }))
1245        .with_timezone("America/New_York");
1246
1247        RecordBatch::try_from_iter(vec![
1248            ("int64", Arc::new(int64_array) as ArrayRef),
1249            ("stringview", Arc::new(string_view_array) as ArrayRef),
1250            ("string", Arc::new(string_array) as ArrayRef),
1251            ("timestamp", Arc::new(timestamp_array) as ArrayRef),
1252        ])
1253        .unwrap()
1254    }
1255
1256    /// Return a boolean array that filters out randomly selected rows
1257    /// from the input batch with a `selectivity`.
1258    ///
1259    /// For example a `selectivity` of 0.1 will filter out
1260    /// 90% of the rows.
1261    #[derive(Debug)]
1262    struct RandomFilterBuilder {
1263        num_rows: usize,
1264        selectivity: f64,
1265        /// seed for random number generator, increases by one each time
1266        /// `next_filter` is called
1267        seed: u64,
1268    }
1269    impl RandomFilterBuilder {
1270        /// Build the next filter with the current seed and increment the seed
1271        /// by one.
1272        fn next_filter(&mut self) -> BooleanArray {
1273            assert!(self.selectivity >= 0.0 && self.selectivity <= 1.0);
1274            let mut rng = rand::rngs::StdRng::seed_from_u64(self.seed);
1275            self.seed += 1;
1276            BooleanArray::from_iter(
1277                (0..self.num_rows)
1278                    .map(|_| rng.random_bool(self.selectivity))
1279                    .map(Some),
1280            )
1281        }
1282    }
1283
1284    /// Returns the named column as a StringViewArray
1285    fn col_as_string_view<'b>(name: &str, batch: &'b RecordBatch) -> &'b StringViewArray {
1286        batch
1287            .column_by_name(name)
1288            .expect("column not found")
1289            .as_string_view_opt()
1290            .expect("column is not a string view")
1291    }
1292
1293    /// Normalize the `RecordBatch` so that the memory layout is consistent
1294    /// (e.g. StringArray is compacted).
1295    fn normalize_batch(batch: RecordBatch) -> RecordBatch {
1296        // Only need to normalize StringViews (as == also tests for memory layout)
1297        let (schema, mut columns, row_count) = batch.into_parts();
1298
1299        for column in columns.iter_mut() {
1300            let Some(string_view) = column.as_string_view_opt() else {
1301                continue;
1302            };
1303
1304            // Re-create the StringViewArray to ensure memory layout is
1305            // consistent
1306            let mut builder = StringViewBuilder::new();
1307            for s in string_view.iter() {
1308                builder.append_option(s);
1309            }
1310            // Update the column with the new StringViewArray
1311            *column = Arc::new(builder.finish());
1312        }
1313
1314        let options = RecordBatchOptions::new().with_row_count(Some(row_count));
1315        RecordBatch::try_new_with_options(schema, columns, &options).unwrap()
1316    }
1317}