1use crate::filter::{SlicesIterator, prep_null_mask_filter};
21use crate::zip::zip;
22use arrow_array::{Array, ArrayRef, BooleanArray, Datum, make_array, new_empty_array};
23use arrow_data::ArrayData;
24use arrow_data::transform::MutableArrayData;
25use arrow_schema::ArrowError;
26
/// An index into the list of value arrays passed to [`merge_n`].
///
/// Abstracting over the index type lets callers keep their indices in whatever
/// representation suits them, for instance a single byte with one value reserved
/// to mean "no array", instead of always paying for a `usize` or `Option<usize>`.
///
/// [`merge_n`] groups consecutive indices with `==` and only asks the first index
/// of each group for its [`MergeIndex::index`]. Implementations must therefore make
/// sure that values which compare equal also report the same `index()`.
pub trait MergeIndex: Copy + PartialEq + Eq {
    /// Returns the position of the value array this index refers to.
    ///
    /// `None` means the index refers to no array at all; [`merge_n`] emits a null
    /// row for such an index.
    fn index(&self) -> Option<usize>;
}
42
43impl MergeIndex for usize {
44 fn index(&self) -> Option<usize> {
45 Some(*self)
46 }
47}
48
49impl MergeIndex for Option<usize> {
50 fn index(&self) -> Option<usize> {
51 *self
52 }
53}
54
55pub fn merge_n(values: &[&dyn Array], indices: &[impl MergeIndex]) -> Result<ArrayRef, ArrowError> {
110 if values.is_empty() {
111 return Err(ArrowError::InvalidArgumentError(
112 "merge_n requires at least one value array".to_string(),
113 ));
114 }
115
116 let data_type = values[0].data_type();
117
118 for array in values.iter().skip(1) {
119 if array.data_type() != data_type {
120 return Err(ArrowError::InvalidArgumentError(format!(
121 "It is not possible to merge arrays of different data types ({} and {})",
122 data_type,
123 array.data_type()
124 )));
125 }
126 }
127
128 if indices.is_empty() {
129 return Ok(new_empty_array(data_type));
130 }
131
132 #[cfg(debug_assertions)]
133 for ix in indices {
134 if let Some(index) = ix.index() {
135 assert!(
136 index < values.len(),
137 "Index out of bounds: {} >= {}",
138 index,
139 values.len()
140 );
141 }
142 }
143
144 let data: Vec<ArrayData> = values.iter().map(|a| a.to_data()).collect();
145 let data_refs = data.iter().collect();
146
147 let mut mutable = MutableArrayData::new(data_refs, true, indices.len());
148
149 let mut take_offsets = vec![0; values.len() + 1];
153 let mut start_row_ix = 0;
154 loop {
155 let array_ix = indices[start_row_ix];
156
157 let mut end_row_ix = start_row_ix + 1;
159 while end_row_ix < indices.len() && indices[end_row_ix] == array_ix {
160 end_row_ix += 1;
161 }
162 let slice_length = end_row_ix - start_row_ix;
163
164 match array_ix.index() {
166 None => mutable.extend_nulls(slice_length),
167 Some(index) => {
168 let start_offset = take_offsets[index];
169 let end_offset = start_offset + slice_length;
170 mutable.extend(index, start_offset, end_offset);
171 take_offsets[index] = end_offset;
172 }
173 }
174
175 if end_row_ix == indices.len() {
176 break;
177 } else {
178 start_row_ix = end_row_ix;
180 }
181 }
182
183 Ok(make_array(mutable.freeze()))
184}
185
186pub fn merge(
215 mask: &BooleanArray,
216 truthy: &dyn Datum,
217 falsy: &dyn Datum,
218) -> Result<ArrayRef, ArrowError> {
219 let (truthy_array, truthy_is_scalar) = truthy.get();
220 let (falsy_array, falsy_is_scalar) = falsy.get();
221
222 if truthy_is_scalar && falsy_is_scalar {
223 return zip(mask, truthy, falsy);
226 }
227
228 if truthy_array.data_type() != falsy_array.data_type() {
229 return Err(ArrowError::InvalidArgumentError(
230 "arguments need to have the same data type".into(),
231 ));
232 }
233
234 if truthy_is_scalar && truthy_array.len() != 1 {
235 return Err(ArrowError::InvalidArgumentError(
236 "scalar arrays must have 1 element".into(),
237 ));
238 }
239 if falsy_is_scalar && falsy_array.len() != 1 {
240 return Err(ArrowError::InvalidArgumentError(
241 "scalar arrays must have 1 element".into(),
242 ));
243 }
244
245 let falsy = falsy_array.to_data();
246 let truthy = truthy_array.to_data();
247
248 let mut mutable = MutableArrayData::new(vec![&truthy, &falsy], false, mask.len());
249
250 let mut filled = 0;
255 let mut falsy_offset = 0;
256 let mut truthy_offset = 0;
257
258 let mask_buffer = match mask.null_count() {
260 0 => mask.values().clone(),
261 _ => prep_null_mask_filter(mask).into_parts().0,
262 };
263
264 SlicesIterator::from(&mask_buffer).for_each(|(start, end)| {
265 if start > filled {
267 if falsy_is_scalar {
268 for _ in filled..start {
269 mutable.extend(1, 0, 1);
271 }
272 } else {
273 let falsy_length = start - filled;
274 let falsy_end = falsy_offset + falsy_length;
275 mutable.extend(1, falsy_offset, falsy_end);
276 falsy_offset = falsy_end;
277 }
278 }
279 if truthy_is_scalar {
281 for _ in start..end {
282 mutable.extend(0, 0, 1);
284 }
285 } else {
286 let truthy_length = end - start;
287 let truthy_end = truthy_offset + truthy_length;
288 mutable.extend(0, truthy_offset, truthy_end);
289 truthy_offset = truthy_end;
290 }
291 filled = end;
292 });
293 if filled < mask.len() {
295 if falsy_is_scalar {
296 for _ in filled..mask.len() {
297 mutable.extend(1, 0, 1);
299 }
300 } else {
301 let falsy_length = mask.len() - filled;
302 let falsy_end = falsy_offset + falsy_length;
303 mutable.extend(1, falsy_offset, falsy_end);
304 }
305 }
306
307 let data = mutable.freeze();
308 Ok(make_array(data))
309}
310
#[cfg(test)]
mod tests {
    use crate::merge::{MergeIndex, merge, merge_n};
    use arrow_array::cast::AsArray;
    use arrow_array::{
        Array, ArrayRef, BooleanArray, Datum, Int32Array, Scalar, StringArray, UInt64Array,
    };
    use arrow_schema::ArrowError::InvalidArgumentError;

    /// Single-byte merge index; `u8::MAX` is reserved to mean "no array".
    #[derive(PartialEq, Eq, Copy, Clone)]
    struct CompactMergeIndex {
        index: u8,
    }

    impl MergeIndex for CompactMergeIndex {
        fn index(&self) -> Option<usize> {
            match self.index {
                u8::MAX => None,
                slot => Some(usize::from(slot)),
            }
        }
    }

    /// Builds compact indices from a readable description; `None` becomes the
    /// `u8::MAX` sentinel.
    fn compact_indices(slots: &[Option<u8>]) -> Vec<CompactMergeIndex> {
        slots
            .iter()
            .map(|slot| CompactMergeIndex {
                index: slot.unwrap_or(u8::MAX),
            })
            .collect()
    }

    /// Asserts that `actual` is a UTF-8 array with exactly the rows in `expected`,
    /// where `None` stands for a null row.
    fn assert_utf8_rows(actual: &ArrayRef, expected: &[Option<&str>]) {
        let actual = actual.as_string::<i32>();
        assert_eq!(actual.len(), expected.len());
        for (row, want) in expected.iter().enumerate() {
            match want {
                Some(text) => {
                    assert!(actual.is_valid(row));
                    assert_eq!(actual.value(row), *text);
                }
                None => assert!(!actual.is_valid(row)),
            }
        }
    }

    #[test]
    fn test_merge() {
        let truthy = StringArray::from(vec![Some("A"), Some("B"), Some("E"), None]);
        let falsy = StringArray::from(vec![Some("C"), Some("D")]);
        let mask = BooleanArray::from(vec![true, false, true, false, true, true]);

        let merged = merge(&mask, &truthy, &falsy).unwrap();

        assert_eq!(merged.len(), mask.len());
        assert_utf8_rows(
            &merged,
            &[Some("A"), Some("C"), Some("B"), Some("D"), Some("E"), None],
        );
    }

    #[test]
    fn test_merge_null_is_false() {
        let truthy = StringArray::from(vec![Some("A"), Some("B"), Some("E"), None]);
        let falsy = StringArray::from(vec![Some("C"), Some("D")]);
        let mask = BooleanArray::from(vec![
            Some(true),
            None,
            Some(true),
            None,
            Some(true),
            Some(true),
        ]);

        let merged = merge(&mask, &truthy, &falsy).unwrap();

        assert_eq!(merged.len(), mask.len());
        assert_utf8_rows(
            &merged,
            &[Some("A"), Some("C"), Some("B"), Some("D"), Some("E"), None],
        );
    }

    #[test]
    fn test_merge_false_tail() {
        let truthy = StringArray::from(vec![Some("A"), Some("B"), Some("E"), None]);
        let falsy = StringArray::from(vec![Some("C"), Some("D"), None, Some("F")]);
        let mask = BooleanArray::from(vec![true, false, true, false, true, true, false, false]);

        let merged = merge(&mask, &truthy, &falsy).unwrap();

        assert_eq!(merged.len(), mask.len());
        assert_utf8_rows(
            &merged,
            &[
                Some("A"),
                Some("C"),
                Some("B"),
                Some("D"),
                Some("E"),
                None,
                None,
                Some("F"),
            ],
        );
    }

    #[test]
    fn test_merge_scalars() {
        let truthy = Scalar::new(StringArray::from(vec![Some("A")]));
        let falsy = Scalar::new(StringArray::from(vec![Some("B")]));
        let mask = BooleanArray::from(vec![true, false, false, true]);

        let merged = merge(&mask, &truthy, &falsy).unwrap();

        assert_eq!(merged.len(), mask.len());
        assert_utf8_rows(&merged, &[Some("A"), Some("B"), Some("B"), Some("A")]);
    }

    #[test]
    fn test_merge_scalar_and_array() {
        let truthy = Scalar::new(StringArray::from(vec![Some("A")]));
        let falsy = StringArray::from(vec![Some("B"), Some("C")]);
        let mask = BooleanArray::from(vec![true, false, false, true]);

        let merged = merge(&mask, &truthy, &falsy).unwrap();

        assert_eq!(merged.len(), mask.len());
        assert_utf8_rows(&merged, &[Some("A"), Some("B"), Some("C"), Some("A")]);
    }

    #[test]
    fn test_merge_array_and_scalar() {
        let truthy = StringArray::from(vec![Some("B"), Some("C")]);
        let falsy = Scalar::new(StringArray::from(vec![Some("A")]));
        let mask = BooleanArray::from(vec![true, false, false, true, false, false]);

        let merged = merge(&mask, &truthy, &falsy).unwrap();

        assert_eq!(merged.len(), mask.len());
        assert_utf8_rows(
            &merged,
            &[
                Some("B"),
                Some("A"),
                Some("A"),
                Some("C"),
                Some("A"),
                Some("A"),
            ],
        );
    }

    #[test]
    fn test_merge_empty_mask() {
        let truthy = StringArray::from(vec![Some("A")]);
        let falsy = StringArray::from(vec![Some("B")]);
        let mask = BooleanArray::from(Vec::<bool>::new());

        let result = merge(&mask, &truthy, &falsy).unwrap();

        assert_eq!(result.len(), 0);
    }

    /// A `Datum` that always reports itself as a scalar, whatever the length of
    /// the wrapped array; used to exercise the scalar length validation.
    #[derive(Debug, Copy, Clone)]
    pub struct UnsafeScalar<T: Array>(T);

    impl<T: Array> Datum for UnsafeScalar<T> {
        fn get(&self) -> (&dyn Array, bool) {
            (&self.0, true)
        }
    }

    #[test]
    fn test_merge_invalid_truthy_scalar() {
        let truthy = UnsafeScalar(StringArray::from(vec![Some("A"), Some("C")]));
        let falsy = StringArray::from(vec![Some("B"), Some("D")]);
        let mask = BooleanArray::from(vec![true, false, true, false]);

        let merged = merge(&mask, &truthy, &falsy);

        assert!(matches!(merged, Err(InvalidArgumentError { .. })));
    }

    #[test]
    fn test_merge_invalid_falsy_scalar() {
        let truthy = StringArray::from(vec![Some("A"), Some("C")]);
        let falsy = UnsafeScalar(StringArray::from(vec![Some("B"), Some("D")]));
        let mask = BooleanArray::from(vec![true, false, true, false]);

        let merged = merge(&mask, &truthy, &falsy);

        assert!(matches!(merged, Err(InvalidArgumentError { .. })));
    }

    #[test]
    fn test_merge_incompatible_arrays() {
        let truthy = StringArray::from(vec![Some("A"), Some("B")]);
        let falsy = Int32Array::from(vec![1, 2]);
        let mask = BooleanArray::from(vec![true, false, true, false]);

        let merged = merge(&mask, &truthy, &falsy);

        assert!(matches!(merged, Err(InvalidArgumentError { .. })));
    }

    #[test]
    fn test_merge_n() {
        let a1 = StringArray::from(vec![Some("A")]);
        let a2 = StringArray::from(vec![Some("B"), None, None]);
        let a3 = StringArray::from(vec![Some("C"), Some("D")]);

        let indices = compact_indices(&[
            None,
            Some(1),
            Some(0),
            None,
            Some(2),
            Some(2),
            Some(1),
            Some(1),
        ]);

        let arrays: [&dyn Array; 3] = [&a1, &a2, &a3];
        let merged = merge_n(&arrays, &indices).unwrap();

        assert_eq!(merged.len(), indices.len());
        assert_utf8_rows(
            &merged,
            &[
                None,
                Some("B"),
                Some("A"),
                None,
                Some("C"),
                Some("D"),
                None,
                None,
            ],
        );
    }

    #[test]
    #[should_panic]
    fn test_merge_n_invalid_indices() {
        let a1 = StringArray::from(vec![Some("A")]);
        let indices = compact_indices(&[Some(99)]);

        let arrays: [&dyn Array; 1] = [&a1];
        let _ = merge_n(&arrays, &indices);
    }

    #[test]
    fn test_merge_n_empty_indices() {
        let a1 = StringArray::from(vec![Some("A")]);
        let a2 = StringArray::from(vec![Some("B"), None, None]);
        let a3 = StringArray::from(vec![Some("C"), Some("D")]);

        let indices = compact_indices(&[]);

        let arrays: [&dyn Array; 3] = [&a1, &a2, &a3];
        let merged = merge_n(&arrays, &indices).unwrap();

        assert_eq!(merged.len(), indices.len());
    }

    #[test]
    fn test_merge_n_empty_values() {
        let indices = compact_indices(&[]);
        let arrays: Vec<&dyn Array> = Vec::new();

        let merged = merge_n(&arrays, &indices);

        assert!(matches!(merged, Err(InvalidArgumentError { .. })));
    }

    #[test]
    fn test_merge_n_incompatible_arrays() {
        let strings = StringArray::from(vec![Some("A")]);
        let signed = Int32Array::from(vec![1, 2, 3]);
        let unsigned = UInt64Array::from(vec![42, 314]);

        let indices = compact_indices(&[]);

        let arrays: [&dyn Array; 3] = [&strings, &signed, &unsigned];
        let merged = merge_n(&arrays, &indices);

        assert!(matches!(merged, Err(InvalidArgumentError { .. })));
    }
}