arrow_select/
merge.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`merge`] and [`merge_n`]: Combine values from two or more arrays
19
20use crate::filter::{SlicesIterator, prep_null_mask_filter};
21use crate::zip::zip;
22use arrow_array::{Array, ArrayRef, BooleanArray, Datum, make_array, new_empty_array};
23use arrow_data::ArrayData;
24use arrow_data::transform::MutableArrayData;
25use arrow_schema::ArrowError;
26
27/// An index for the [merge_n] function.
28///
29/// This trait allows the indices argument for [merge_n] to be stored using a more
30/// compact representation than `usize` when the number of input arrays is small.
31/// If the number of input arrays is less than 256 for instance, the indices can be stored as `u8`.
32///
33/// Implementations must ensure that all values which return `None` from [MergeIndex::index] are
34/// considered equal by the [PartialEq] and [Eq] implementations. [merge_n] groups consecutive indices that compare equal into a single run and only calls [MergeIndex::index] on the first element of each run.
35pub trait MergeIndex: PartialEq + Eq + Copy {
36    /// Returns the index value as an `Option<usize>`.
37    ///
38    /// `None` values returned by this function indicate holes in the index array and will result
39    /// in null values in the array created by [merge_n].
40    fn index(&self) -> Option<usize>;
41}
42
43impl MergeIndex for usize {
44    fn index(&self) -> Option<usize> {
45        Some(*self)
46    }
47}
48
49impl MergeIndex for Option<usize> {
50    fn index(&self) -> Option<usize> {
51        *self
52    }
53}
54
55/// Merges elements by index from a list of [`Array`], creating a new [`Array`] from
56/// those values.
57///
58/// Each element in `indices` is the index of an array in `values`. The `indices` array is processed
59/// sequentially. The first occurrence of index value `n` will be mapped to the first
60/// value of the array at index `n`. The second occurrence to the second value, and so on.
61/// An index value where `MergeIndex::index` returns `None` is interpreted as a null value.
62///
63/// # Implementation notes
64///
65/// This algorithm is similar in nature to both [zip] and
66/// [interleave](crate::interleave::interleave), but there are some important differences.
67///
68/// In contrast to [zip], this function supports multiple input arrays. Instead of
69/// a boolean selection vector, an index array is to take values from the input arrays, and a special
70/// marker values can be used to indicate null values.
71///
72/// In contrast to [interleave](crate::interleave::interleave), this function does not use pairs of
73/// indices. The values in `indices` serve the same purpose as the first value in the pairs passed
74/// to `interleave`.
75/// The index in the array is implicit and is derived from the number of times a particular array
76/// index occurs.
77/// The more constrained indexing mechanism used by this algorithm makes it easier to copy values
78/// in contiguous slices. In the example below, the two subsequent elements from array `2` can be
79/// copied in a single operation from the source array instead of copying them one by one.
80/// Long spans of null values are also especially cheap because they do not need to be represented
81/// in an input array.
82///
83/// # Panics
84///
85/// This function does not check that the number of occurrences of any particular array index matches
86/// the length of the corresponding input array. If an array contains more values than required, the
87/// spurious values will be ignored. If an array contains fewer values than necessary, this function
88/// will panic.
89///
90/// # Example
91///
92/// ```text
93/// ┌───────────┐  ┌─────────┐                             ┌─────────┐
94/// │┌─────────┐│  │   None  │                             │   NULL  │
95/// ││    A    ││  ├─────────┤                             ├─────────┤
96/// │└─────────┘│  │    1    │                             │    B    │
97/// │┌─────────┐│  ├─────────┤                             ├─────────┤
98/// ││    B    ││  │    0    │    merge(values, indices)   │    A    │
99/// │└─────────┘│  ├─────────┤  ─────────────────────────▶ ├─────────┤
100/// │┌─────────┐│  │   None  │                             │   NULL  │
101/// ││    C    ││  ├─────────┤                             ├─────────┤
102/// │├─────────┤│  │    2    │                             │    C    │
103/// ││    D    ││  ├─────────┤                             ├─────────┤
104/// │└─────────┘│  │    2    │                             │    D    │
105/// └───────────┘  └─────────┘                             └─────────┘
106///    values        indices                                  result
107///
108/// ```
109pub fn merge_n(values: &[&dyn Array], indices: &[impl MergeIndex]) -> Result<ArrayRef, ArrowError> {
110    if values.is_empty() {
111        return Err(ArrowError::InvalidArgumentError(
112            "merge_n requires at least one value array".to_string(),
113        ));
114    }
115
116    let data_type = values[0].data_type();
117
118    for array in values.iter().skip(1) {
119        if array.data_type() != data_type {
120            return Err(ArrowError::InvalidArgumentError(format!(
121                "It is not possible to merge arrays of different data types ({} and {})",
122                data_type,
123                array.data_type()
124            )));
125        }
126    }
127
128    if indices.is_empty() {
129        return Ok(new_empty_array(data_type));
130    }
131
132    #[cfg(debug_assertions)]
133    for ix in indices {
134        if let Some(index) = ix.index() {
135            assert!(
136                index < values.len(),
137                "Index out of bounds: {} >= {}",
138                index,
139                values.len()
140            );
141        }
142    }
143
144    let data: Vec<ArrayData> = values.iter().map(|a| a.to_data()).collect();
145    let data_refs = data.iter().collect();
146
147    let mut mutable = MutableArrayData::new(data_refs, true, indices.len());
148
149    // This loop extends the mutable array by taking slices from the partial results.
150    //
151    // take_offsets keeps track of how many values have been taken from each array.
152    let mut take_offsets = vec![0; values.len() + 1];
153    let mut start_row_ix = 0;
154    loop {
155        let array_ix = indices[start_row_ix];
156
157        // Determine the length of the slice to take.
158        let mut end_row_ix = start_row_ix + 1;
159        while end_row_ix < indices.len() && indices[end_row_ix] == array_ix {
160            end_row_ix += 1;
161        }
162        let slice_length = end_row_ix - start_row_ix;
163
164        // Extend mutable with either nulls or with values from the array.
165        match array_ix.index() {
166            None => mutable.extend_nulls(slice_length),
167            Some(index) => {
168                let start_offset = take_offsets[index];
169                let end_offset = start_offset + slice_length;
170                mutable.extend(index, start_offset, end_offset);
171                take_offsets[index] = end_offset;
172            }
173        }
174
175        if end_row_ix == indices.len() {
176            break;
177        } else {
178            // Set the start_row_ix for the next slice.
179            start_row_ix = end_row_ix;
180        }
181    }
182
183    Ok(make_array(mutable.freeze()))
184}
185
186/// Merges two arrays in the order specified by a boolean mask.
187///
188/// This algorithm is a variant of [zip] that does not require the truthy and
189/// falsy arrays to have the same length.
190///
191/// When truthy of falsy are [Scalar](arrow_array::Scalar), the single
192/// scalar value is repeated whenever the mask array contains true or false respectively.
193///
194/// # Example
195///
196/// ```text
197///  truthy
198/// ┌─────────┐  mask
199/// │    A    │  ┌─────────┐                             ┌─────────┐
200/// ├─────────┤  │  true   │                             │    A    │
201/// │    C    │  ├─────────┤                             ├─────────┤
202/// ├─────────┤  │  true   │                             │    C    │
203/// │   NULL  │  ├─────────┤                             ├─────────┤
204/// ├─────────┤  │  false  │  merge(mask, truthy, falsy) │    B    │
205/// │    D    │  ├─────────┤  ─────────────────────────▶ ├─────────┤
206/// └─────────┘  │  true   │                             │   NULL  │
207///  falsy       ├─────────┤                             ├─────────┤
208/// ┌─────────┐  │  false  │                             │    E    │
209/// │    B    │  ├─────────┤                             ├─────────┤
210/// ├─────────┤  │  true   │                             │    D    │
211/// │    E    │  └─────────┘                             └─────────┘
212/// └─────────┘
213/// ```
214pub fn merge(
215    mask: &BooleanArray,
216    truthy: &dyn Datum,
217    falsy: &dyn Datum,
218) -> Result<ArrayRef, ArrowError> {
219    let (truthy_array, truthy_is_scalar) = truthy.get();
220    let (falsy_array, falsy_is_scalar) = falsy.get();
221
222    if truthy_is_scalar && falsy_is_scalar {
223        // When both truthy and falsy are scalars, we can use `zip` since the result is the same
224        // and zip has optimized code for scalars.
225        return zip(mask, truthy, falsy);
226    }
227
228    if truthy_array.data_type() != falsy_array.data_type() {
229        return Err(ArrowError::InvalidArgumentError(
230            "arguments need to have the same data type".into(),
231        ));
232    }
233
234    if truthy_is_scalar && truthy_array.len() != 1 {
235        return Err(ArrowError::InvalidArgumentError(
236            "scalar arrays must have 1 element".into(),
237        ));
238    }
239    if falsy_is_scalar && falsy_array.len() != 1 {
240        return Err(ArrowError::InvalidArgumentError(
241            "scalar arrays must have 1 element".into(),
242        ));
243    }
244
245    let falsy = falsy_array.to_data();
246    let truthy = truthy_array.to_data();
247
248    let mut mutable = MutableArrayData::new(vec![&truthy, &falsy], false, mask.len());
249
250    // the SlicesIterator slices only the true values. So the gaps left by this iterator we need to
251    // fill with falsy values
252
253    // keep track of how much is filled
254    let mut filled = 0;
255    let mut falsy_offset = 0;
256    let mut truthy_offset = 0;
257
258    // Ensure nulls are treated as false
259    let mask_buffer = match mask.null_count() {
260        0 => mask.values().clone(),
261        _ => prep_null_mask_filter(mask).into_parts().0,
262    };
263
264    SlicesIterator::from(&mask_buffer).for_each(|(start, end)| {
265        // the gap needs to be filled with falsy values
266        if start > filled {
267            if falsy_is_scalar {
268                for _ in filled..start {
269                    // Copy the first item from the 'falsy' array into the output buffer.
270                    mutable.extend(1, 0, 1);
271                }
272            } else {
273                let falsy_length = start - filled;
274                let falsy_end = falsy_offset + falsy_length;
275                mutable.extend(1, falsy_offset, falsy_end);
276                falsy_offset = falsy_end;
277            }
278        }
279        // fill with truthy values
280        if truthy_is_scalar {
281            for _ in start..end {
282                // Copy the first item from the 'truthy' array into the output buffer.
283                mutable.extend(0, 0, 1);
284            }
285        } else {
286            let truthy_length = end - start;
287            let truthy_end = truthy_offset + truthy_length;
288            mutable.extend(0, truthy_offset, truthy_end);
289            truthy_offset = truthy_end;
290        }
291        filled = end;
292    });
293    // the remaining part is falsy
294    if filled < mask.len() {
295        if falsy_is_scalar {
296            for _ in filled..mask.len() {
297                // Copy the first item from the 'falsy' array into the output buffer.
298                mutable.extend(1, 0, 1);
299            }
300        } else {
301            let falsy_length = mask.len() - filled;
302            let falsy_end = falsy_offset + falsy_length;
303            mutable.extend(1, falsy_offset, falsy_end);
304        }
305    }
306
307    let data = mutable.freeze();
308    Ok(make_array(data))
309}
310
#[cfg(test)]
mod tests {
    use crate::merge::{MergeIndex, merge, merge_n};
    use arrow_array::cast::AsArray;
    use arrow_array::{Array, BooleanArray, Datum, Int32Array, Scalar, StringArray, UInt64Array};
    use arrow_schema::ArrowError::InvalidArgumentError;

    /// Index type that stores the array index in a single byte; `u8::MAX` marks a hole.
    #[derive(PartialEq, Eq, Copy, Clone)]
    struct CompactMergeIndex {
        index: u8,
    }

    impl MergeIndex for CompactMergeIndex {
        fn index(&self) -> Option<usize> {
            match self.index {
                u8::MAX => None,
                ix => Some(ix as usize),
            }
        }
    }

    /// Builds a list of [`CompactMergeIndex`] from raw bytes.
    fn compact_indices(raw: &[u8]) -> Vec<CompactMergeIndex> {
        raw.iter()
            .map(|&index| CompactMergeIndex { index })
            .collect()
    }

    /// Asserts that `actual` holds exactly the given sequence of (nullable) strings.
    fn assert_strings(actual: &StringArray, expected: &[Option<&str>]) {
        let actual: Vec<Option<&str>> = actual.iter().collect();
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_merge() {
        let truthy = StringArray::from(vec![Some("A"), Some("B"), Some("E"), None]);
        let falsy = StringArray::from(vec![Some("C"), Some("D")]);
        let mask = BooleanArray::from(vec![true, false, true, false, true, true]);

        let merged = merge(&mask, &truthy, &falsy).unwrap();
        let merged = merged.as_string::<i32>();

        assert_eq!(merged.len(), mask.len());
        assert_strings(
            merged,
            &[Some("A"), Some("C"), Some("B"), Some("D"), Some("E"), None],
        );
    }

    #[test]
    fn test_merge_null_is_false() {
        let truthy = StringArray::from(vec![Some("A"), Some("B"), Some("E"), None]);
        let falsy = StringArray::from(vec![Some("C"), Some("D")]);
        let mask = BooleanArray::from(vec![
            Some(true),
            None,
            Some(true),
            None,
            Some(true),
            Some(true),
        ]);

        let merged = merge(&mask, &truthy, &falsy).unwrap();
        let merged = merged.as_string::<i32>();

        assert_eq!(merged.len(), mask.len());
        assert_strings(
            merged,
            &[Some("A"), Some("C"), Some("B"), Some("D"), Some("E"), None],
        );
    }

    #[test]
    fn test_merge_false_tail() {
        let truthy = StringArray::from(vec![Some("A"), Some("B"), Some("E"), None]);
        let falsy = StringArray::from(vec![Some("C"), Some("D"), None, Some("F")]);
        let mask = BooleanArray::from(vec![true, false, true, false, true, true, false, false]);

        let merged = merge(&mask, &truthy, &falsy).unwrap();
        let merged = merged.as_string::<i32>();

        assert_eq!(merged.len(), mask.len());
        assert_strings(
            merged,
            &[
                Some("A"),
                Some("C"),
                Some("B"),
                Some("D"),
                Some("E"),
                None,
                None,
                Some("F"),
            ],
        );
    }

    #[test]
    fn test_merge_scalars() {
        let truthy = Scalar::new(StringArray::from(vec![Some("A")]));
        let falsy = Scalar::new(StringArray::from(vec![Some("B")]));
        let mask = BooleanArray::from(vec![true, false, false, true]);

        let merged = merge(&mask, &truthy, &falsy).unwrap();
        let merged = merged.as_string::<i32>();

        assert_eq!(merged.len(), mask.len());
        assert_strings(merged, &[Some("A"), Some("B"), Some("B"), Some("A")]);
    }

    #[test]
    fn test_merge_scalar_and_array() {
        let truthy = Scalar::new(StringArray::from(vec![Some("A")]));
        let falsy = StringArray::from(vec![Some("B"), Some("C")]);
        let mask = BooleanArray::from(vec![true, false, false, true]);

        let merged = merge(&mask, &truthy, &falsy).unwrap();
        let merged = merged.as_string::<i32>();

        assert_eq!(merged.len(), mask.len());
        assert_strings(merged, &[Some("A"), Some("B"), Some("C"), Some("A")]);
    }

    #[test]
    fn test_merge_array_and_scalar() {
        let truthy = StringArray::from(vec![Some("B"), Some("C")]);
        let falsy = Scalar::new(StringArray::from(vec![Some("A")]));
        let mask = BooleanArray::from(vec![true, false, false, true, false, false]);

        let merged = merge(&mask, &truthy, &falsy).unwrap();
        let merged = merged.as_string::<i32>();

        assert_eq!(merged.len(), mask.len());
        assert_strings(
            merged,
            &[
                Some("B"),
                Some("A"),
                Some("A"),
                Some("C"),
                Some("A"),
                Some("A"),
            ],
        );
    }

    #[test]
    fn test_merge_empty_mask() {
        let truthy = StringArray::from(vec![Some("A")]);
        let falsy = StringArray::from(vec![Some("B")]);
        let mask = BooleanArray::from(Vec::<bool>::new());

        let merged = merge(&mask, &truthy, &falsy).unwrap();

        assert_eq!(merged.len(), 0);
    }

    /// A [`Datum`] that always reports itself as a scalar, whatever the length of the array it
    /// wraps. Used to exercise the scalar length validation in `merge`.
    #[derive(Debug, Copy, Clone)]
    pub struct UnsafeScalar<T: Array>(T);

    impl<T: Array> Datum for UnsafeScalar<T> {
        fn get(&self) -> (&dyn Array, bool) {
            (&self.0, true)
        }
    }

    #[test]
    fn test_merge_invalid_truthy_scalar() {
        let truthy = UnsafeScalar(StringArray::from(vec![Some("A"), Some("C")]));
        let falsy = StringArray::from(vec![Some("B"), Some("D")]);
        let mask = BooleanArray::from(vec![true, false, true, false]);

        let outcome = merge(&mask, &truthy, &falsy);

        assert!(matches!(outcome, Err(InvalidArgumentError { .. })));
    }

    #[test]
    fn test_merge_invalid_falsy_scalar() {
        let truthy = StringArray::from(vec![Some("A"), Some("C")]);
        let falsy = UnsafeScalar(StringArray::from(vec![Some("B"), Some("D")]));
        let mask = BooleanArray::from(vec![true, false, true, false]);

        let outcome = merge(&mask, &truthy, &falsy);

        assert!(matches!(outcome, Err(InvalidArgumentError { .. })));
    }

    #[test]
    fn test_merge_incompatible_arrays() {
        let truthy = StringArray::from(vec![Some("A"), Some("B")]);
        let falsy = Int32Array::from(vec![1, 2]);
        let mask = BooleanArray::from(vec![true, false, true, false]);

        let outcome = merge(&mask, &truthy, &falsy);

        assert!(matches!(outcome, Err(InvalidArgumentError { .. })));
    }

    #[test]
    fn test_merge_n() {
        let arrays = [
            StringArray::from(vec![Some("A")]),
            StringArray::from(vec![Some("B"), None, None]),
            StringArray::from(vec![Some("C"), Some("D")]),
        ];
        let array_refs: Vec<&dyn Array> = arrays.iter().map(|a| a as &dyn Array).collect();
        let indices = compact_indices(&[u8::MAX, 1, 0, u8::MAX, 2, 2, 1, 1]);

        let merged = merge_n(&array_refs, &indices).unwrap();
        let merged = merged.as_string::<i32>();

        assert_eq!(merged.len(), indices.len());
        assert_strings(
            merged,
            &[
                None,
                Some("B"),
                Some("A"),
                None,
                Some("C"),
                Some("D"),
                None,
                None,
            ],
        );
    }

    #[test]
    #[should_panic]
    fn test_merge_n_invalid_indices() {
        let arrays = [StringArray::from(vec![Some("A")])];
        let array_refs: Vec<&dyn Array> = arrays.iter().map(|a| a as &dyn Array).collect();
        let indices = compact_indices(&[99]);

        let _ = merge_n(&array_refs, &indices);
    }

    #[test]
    fn test_merge_n_empty_indices() {
        let arrays = [
            StringArray::from(vec![Some("A")]),
            StringArray::from(vec![Some("B"), None, None]),
            StringArray::from(vec![Some("C"), Some("D")]),
        ];
        let array_refs: Vec<&dyn Array> = arrays.iter().map(|a| a as &dyn Array).collect();
        let indices: Vec<CompactMergeIndex> = Vec::new();

        let merged = merge_n(&array_refs, &indices).unwrap();

        assert_eq!(merged.len(), indices.len());
    }

    #[test]
    fn test_merge_n_empty_values() {
        let indices: Vec<CompactMergeIndex> = Vec::new();
        let arrays: Vec<&dyn Array> = Vec::new();

        let outcome = merge_n(&arrays, &indices);

        assert!(matches!(outcome, Err(InvalidArgumentError { .. })));
    }

    #[test]
    fn test_merge_n_incompatible_arrays() {
        let strings: Box<dyn Array> = Box::new(StringArray::from(vec![Some("A")]));
        let ints: Box<dyn Array> = Box::new(Int32Array::from(vec![1, 2, 3]));
        let unsigned: Box<dyn Array> = Box::new(UInt64Array::from(vec![42, 314]));
        let indices: Vec<CompactMergeIndex> = Vec::new();

        let arrays = [strings.as_ref(), ints.as_ref(), unsigned.as_ref()];
        let outcome = merge_n(&arrays, &indices);

        assert!(matches!(outcome, Err(InvalidArgumentError { .. })));
    }
}
616}