1use arrow_array::cast::AsArray;
21use arrow_array::types::*;
22use arrow_array::*;
23use arrow_buffer::{ArrowNativeType, NullBuffer};
24use arrow_schema::{ArrowError, SortOptions};
25use std::cmp::Ordering;
26
/// A boxed, thread-safe comparison function over two `usize` indices.
///
/// The comparators built in this module are called as `cmp(i, j)`, where `i` indexes the
/// left array and `j` indexes the right array that were passed to the constructor, and
/// return the [`Ordering`] of those two elements.
pub type DynComparator = Box<dyn Fn(usize, usize) -> Ordering + Send + Sync>;
29
30fn child_opts(opts: SortOptions) -> SortOptions {
33 SortOptions {
34 descending: false,
35 nulls_first: opts.nulls_first != opts.descending,
36 }
37}
38
39fn compare<A, F>(l: &A, r: &A, opts: SortOptions, cmp: F) -> DynComparator
40where
41 A: Array + Clone,
42 F: Fn(usize, usize) -> Ordering + Send + Sync + 'static,
43{
44 let l = l.logical_nulls().filter(|x| x.null_count() > 0);
45 let r = r.logical_nulls().filter(|x| x.null_count() > 0);
46 match (opts.nulls_first, opts.descending) {
47 (true, true) => compare_impl::<true, true, _>(l, r, cmp),
48 (true, false) => compare_impl::<true, false, _>(l, r, cmp),
49 (false, true) => compare_impl::<false, true, _>(l, r, cmp),
50 (false, false) => compare_impl::<false, false, _>(l, r, cmp),
51 }
52}
53
54fn compare_impl<const NULLS_FIRST: bool, const DESCENDING: bool, F>(
55 l: Option<NullBuffer>,
56 r: Option<NullBuffer>,
57 cmp: F,
58) -> DynComparator
59where
60 F: Fn(usize, usize) -> Ordering + Send + Sync + 'static,
61{
62 let cmp = move |i, j| match DESCENDING {
63 true => cmp(i, j).reverse(),
64 false => cmp(i, j),
65 };
66
67 let (left_null, right_null) = match NULLS_FIRST {
68 true => (Ordering::Less, Ordering::Greater),
69 false => (Ordering::Greater, Ordering::Less),
70 };
71
72 match (l, r) {
73 (None, None) => Box::new(cmp),
74 (Some(l), None) => Box::new(move |i, j| match l.is_null(i) {
75 true => left_null,
76 false => cmp(i, j),
77 }),
78 (None, Some(r)) => Box::new(move |i, j| match r.is_null(j) {
79 true => right_null,
80 false => cmp(i, j),
81 }),
82 (Some(l), Some(r)) => Box::new(move |i, j| match (l.is_null(i), r.is_null(j)) {
83 (true, true) => Ordering::Equal,
84 (true, false) => left_null,
85 (false, true) => right_null,
86 (false, false) => cmp(i, j),
87 }),
88 }
89}
90
91fn compare_primitive<T: ArrowPrimitiveType>(
92 left: &dyn Array,
93 right: &dyn Array,
94 opts: SortOptions,
95) -> DynComparator
96where
97 T::Native: ArrowNativeTypeOp,
98{
99 let left = left.as_primitive::<T>();
100 let right = right.as_primitive::<T>();
101 let l_values = left.values().clone();
102 let r_values = right.values().clone();
103
104 compare(&left, &right, opts, move |i, j| {
105 l_values[i].compare(r_values[j])
106 })
107}
108
109fn compare_boolean(left: &dyn Array, right: &dyn Array, opts: SortOptions) -> DynComparator {
110 let left = left.as_boolean();
111 let right = right.as_boolean();
112
113 let l_values = left.values().clone();
114 let r_values = right.values().clone();
115
116 compare(left, right, opts, move |i, j| {
117 l_values.value(i).cmp(&r_values.value(j))
118 })
119}
120
121fn compare_bytes<T: ByteArrayType>(
122 left: &dyn Array,
123 right: &dyn Array,
124 opts: SortOptions,
125) -> DynComparator {
126 let left = left.as_bytes::<T>();
127 let right = right.as_bytes::<T>();
128
129 let l = left.clone();
130 let r = right.clone();
131 compare(left, right, opts, move |i, j| {
132 let l: &[u8] = l.value(i).as_ref();
133 let r: &[u8] = r.value(j).as_ref();
134 l.cmp(r)
135 })
136}
137
138fn compare_byte_view<T: ByteViewType>(
139 left: &dyn Array,
140 right: &dyn Array,
141 opts: SortOptions,
142) -> DynComparator {
143 let left = left.as_byte_view::<T>();
144 let right = right.as_byte_view::<T>();
145
146 let l = left.clone();
147 let r = right.clone();
148 compare(left, right, opts, move |i, j| {
149 crate::cmp::compare_byte_view(&l, i, &r, j)
150 })
151}
152
153fn compare_dict<K: ArrowDictionaryKeyType>(
154 left: &dyn Array,
155 right: &dyn Array,
156 opts: SortOptions,
157) -> Result<DynComparator, ArrowError> {
158 let left = left.as_dictionary::<K>();
159 let right = right.as_dictionary::<K>();
160
161 let c_opts = child_opts(opts);
162 let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(), c_opts)?;
163 let left_keys = left.keys().values().clone();
164 let right_keys = right.keys().values().clone();
165
166 let f = compare(left, right, opts, move |i, j| {
167 let l = left_keys[i].as_usize();
168 let r = right_keys[j].as_usize();
169 cmp(l, r)
170 });
171 Ok(f)
172}
173
174fn compare_list<O: OffsetSizeTrait>(
175 left: &dyn Array,
176 right: &dyn Array,
177 opts: SortOptions,
178) -> Result<DynComparator, ArrowError> {
179 let left = left.as_list::<O>();
180 let right = right.as_list::<O>();
181
182 let c_opts = child_opts(opts);
183 let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(), c_opts)?;
184
185 let l_o = left.offsets().clone();
186 let r_o = right.offsets().clone();
187 let f = compare(left, right, opts, move |i, j| {
188 let l_end = l_o[i + 1].as_usize();
189 let l_start = l_o[i].as_usize();
190
191 let r_end = r_o[j + 1].as_usize();
192 let r_start = r_o[j].as_usize();
193
194 for (i, j) in (l_start..l_end).zip(r_start..r_end) {
195 match cmp(i, j) {
196 Ordering::Equal => continue,
197 r => return r,
198 }
199 }
200 (l_end - l_start).cmp(&(r_end - r_start))
201 });
202 Ok(f)
203}
204
205fn compare_fixed_list(
206 left: &dyn Array,
207 right: &dyn Array,
208 opts: SortOptions,
209) -> Result<DynComparator, ArrowError> {
210 let left = left.as_fixed_size_list();
211 let right = right.as_fixed_size_list();
212
213 let c_opts = child_opts(opts);
214 let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(), c_opts)?;
215
216 let l_size = left.value_length().to_usize().unwrap();
217 let r_size = right.value_length().to_usize().unwrap();
218 let size_cmp = l_size.cmp(&r_size);
219
220 let f = compare(left, right, opts, move |i, j| {
221 let l_start = i * l_size;
222 let l_end = l_start + l_size;
223 let r_start = j * r_size;
224 let r_end = r_start + r_size;
225 for (i, j) in (l_start..l_end).zip(r_start..r_end) {
226 match cmp(i, j) {
227 Ordering::Equal => continue,
228 r => return r,
229 }
230 }
231 size_cmp
232 });
233 Ok(f)
234}
235
236fn compare_map(
237 left: &dyn Array,
238 right: &dyn Array,
239 opts: SortOptions,
240) -> Result<DynComparator, ArrowError> {
241 let left = left.as_map();
242 let right = right.as_map();
243
244 let c_opts = child_opts(opts);
245 let cmp = make_comparator(left.entries(), right.entries(), c_opts)?;
246
247 let l_o = left.offsets().clone();
248 let r_o = right.offsets().clone();
249 let f = compare(left, right, opts, move |i, j| {
250 let l_end = l_o[i + 1].as_usize();
251 let l_start = l_o[i].as_usize();
252
253 let r_end = r_o[j + 1].as_usize();
254 let r_start = r_o[j].as_usize();
255
256 for (i, j) in (l_start..l_end).zip(r_start..r_end) {
257 match cmp(i, j) {
258 Ordering::Equal => continue,
259 r => return r,
260 }
261 }
262 (l_end - l_start).cmp(&(r_end - r_start))
263 });
264 Ok(f)
265}
266
267fn compare_struct(
268 left: &dyn Array,
269 right: &dyn Array,
270 opts: SortOptions,
271) -> Result<DynComparator, ArrowError> {
272 let left = left.as_struct();
273 let right = right.as_struct();
274
275 if left.columns().len() != right.columns().len() {
276 return Err(ArrowError::InvalidArgumentError(
277 "Cannot compare StructArray with different number of columns".to_string(),
278 ));
279 }
280
281 let c_opts = child_opts(opts);
282 let columns = left.columns().iter().zip(right.columns());
283 let comparators = columns
284 .map(|(l, r)| make_comparator(l, r, c_opts))
285 .collect::<Result<Vec<_>, _>>()?;
286
287 let f = compare(left, right, opts, move |i, j| {
288 for cmp in &comparators {
289 match cmp(i, j) {
290 Ordering::Equal => continue,
291 r => return r,
292 }
293 }
294 Ordering::Equal
295 });
296 Ok(f)
297}
298
299pub fn make_comparator(
369 left: &dyn Array,
370 right: &dyn Array,
371 opts: SortOptions,
372) -> Result<DynComparator, ArrowError> {
373 use arrow_schema::DataType::*;
374
375 macro_rules! primitive_helper {
376 ($t:ty, $left:expr, $right:expr, $nulls_first:expr) => {
377 Ok(compare_primitive::<$t>($left, $right, $nulls_first))
378 };
379 }
380 downcast_primitive! {
381 left.data_type(), right.data_type() => (primitive_helper, left, right, opts),
382 (Boolean, Boolean) => Ok(compare_boolean(left, right, opts)),
383 (Utf8, Utf8) => Ok(compare_bytes::<Utf8Type>(left, right, opts)),
384 (LargeUtf8, LargeUtf8) => Ok(compare_bytes::<LargeUtf8Type>(left, right, opts)),
385 (Utf8View, Utf8View) => Ok(compare_byte_view::<StringViewType>(left, right, opts)),
386 (Binary, Binary) => Ok(compare_bytes::<BinaryType>(left, right, opts)),
387 (LargeBinary, LargeBinary) => Ok(compare_bytes::<LargeBinaryType>(left, right, opts)),
388 (BinaryView, BinaryView) => Ok(compare_byte_view::<BinaryViewType>(left, right, opts)),
389 (FixedSizeBinary(_), FixedSizeBinary(_)) => {
390 let left = left.as_fixed_size_binary();
391 let right = right.as_fixed_size_binary();
392
393 let l = left.clone();
394 let r = right.clone();
395 Ok(compare(left, right, opts, move |i, j| {
396 l.value(i).cmp(r.value(j))
397 }))
398 },
399 (List(_), List(_)) => compare_list::<i32>(left, right, opts),
400 (LargeList(_), LargeList(_)) => compare_list::<i64>(left, right, opts),
401 (FixedSizeList(_, _), FixedSizeList(_, _)) => compare_fixed_list(left, right, opts),
402 (Struct(_), Struct(_)) => compare_struct(left, right, opts),
403 (Dictionary(l_key, _), Dictionary(r_key, _)) => {
404 macro_rules! dict_helper {
405 ($t:ty, $left:expr, $right:expr, $opts: expr) => {
406 compare_dict::<$t>($left, $right, $opts)
407 };
408 }
409 downcast_integer! {
410 l_key.as_ref(), r_key.as_ref() => (dict_helper, left, right, opts),
411 _ => unreachable!()
412 }
413 },
414 (Map(_, _), Map(_, _)) => compare_map(left, right, opts),
415 (lhs, rhs) => Err(ArrowError::InvalidArgumentError(match lhs == rhs {
416 true => format!("The data type type {lhs:?} has no natural order"),
417 false => "Can't compare arrays of different types".to_string(),
418 }))
419 }
420}
421
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::builder::{Int32Builder, ListBuilder, MapBuilder, StringBuilder};
    use arrow_buffer::{IntervalDayTime, OffsetBuffer, i256};
    use arrow_schema::{DataType, Field, Fields};
    use half::f16;
    use std::sync::Arc;

    /// Builds a comparator over `left` and `right` with the default [`SortOptions`].
    fn default_cmp(left: &dyn Array, right: &dyn Array) -> DynComparator {
        make_comparator(left, right, SortOptions::default()).unwrap()
    }

    /// Shorthand for constructing [`SortOptions`].
    fn sort_opts(descending: bool, nulls_first: bool) -> SortOptions {
        SortOptions {
            descending,
            nulls_first,
        }
    }

    /// Builds a comparator for `opts` and checks every `((i, j), expected)` entry in order.
    fn assert_orderings(
        left: &dyn Array,
        right: &dyn Array,
        opts: SortOptions,
        expected: &[((usize, usize), Ordering)],
    ) {
        let cmp = make_comparator(left, right, opts).unwrap();
        for &((i, j), want) in expected {
            assert_eq!(cmp(i, j), want, "cmp({i}, {j}) with {opts:?}");
        }
    }

    /// Wraps `values` in an `Int8`-keyed dictionary array using `keys`.
    fn dict_of(keys: [i8; 4], values: impl Array + 'static) -> DictionaryArray<Int8Type> {
        DictionaryArray::new(Int8Array::from_iter_values(keys), Arc::new(values))
    }

    /// Assertions shared by the dictionary tests whose fixtures use keys `[0, 0, 1, 3]`
    /// (left) and `[0, 1, 1, 3]` (right).
    fn assert_dict_ordering(left: &dyn Array, right: &dyn Array) {
        let cmp = default_cmp(left, right);
        assert_eq!(Ordering::Less, cmp(0, 0));
        assert_eq!(Ordering::Less, cmp(0, 3));
        assert_eq!(Ordering::Equal, cmp(3, 3));
        assert_eq!(Ordering::Greater, cmp(3, 1));
        assert_eq!(Ordering::Greater, cmp(3, 2));
    }

    /// Assertions shared by the decimal tests, whose arrays all hold `[5, 2, 3]`.
    fn assert_decimal_ordering(array: &dyn Array) {
        let cmp = default_cmp(array, array);
        assert_eq!(Ordering::Less, cmp(1, 0));
        assert_eq!(Ordering::Greater, cmp(0, 2));
    }

    /// Builds a `Map<Utf8, Int32>` array: `None` rows are null maps and `Some(vec![])`
    /// rows are valid but empty maps.
    fn string_int_map(rows: Vec<Option<Vec<(&str, i32)>>>) -> MapArray {
        let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());
        for row in rows {
            let is_valid = row.is_some();
            for (key, value) in row.into_iter().flatten() {
                builder.keys().append_value(key);
                builder.values().append_value(value);
            }
            builder.append(is_valid).unwrap();
        }
        builder.finish()
    }

    #[test]
    fn test_fixed_size_binary() {
        let rows = vec![vec![1u8], vec![2u8]];
        let array = FixedSizeBinaryArray::try_from_iter(rows.into_iter()).unwrap();

        let cmp = default_cmp(&array, &array);

        assert_eq!(Ordering::Less, cmp(0, 1));
    }

    #[test]
    fn test_fixed_size_binary_fixed_size_binary() {
        let lhs = FixedSizeBinaryArray::try_from_iter(vec![vec![1u8]].into_iter()).unwrap();
        let rhs = FixedSizeBinaryArray::try_from_iter(vec![vec![2u8]].into_iter()).unwrap();

        let cmp = default_cmp(&lhs, &rhs);

        assert_eq!(Ordering::Less, cmp(0, 0));
    }

    #[test]
    fn test_i32() {
        let array = Int32Array::from(vec![1, 2]);

        let cmp = default_cmp(&array, &array);

        assert_eq!(Ordering::Less, (cmp)(0, 1));
    }

    #[test]
    fn test_i32_i32() {
        let lhs = Int32Array::from(vec![1]);
        let rhs = Int32Array::from(vec![2]);

        let cmp = default_cmp(&lhs, &rhs);

        assert_eq!(Ordering::Less, cmp(0, 0));
    }

    #[test]
    fn test_f16() {
        let array = Float16Array::from(vec![f16::from_f32(1.0), f16::from_f32(2.0)]);

        let cmp = default_cmp(&array, &array);

        assert_eq!(Ordering::Less, cmp(0, 1));
    }

    #[test]
    fn test_f64() {
        let array = Float64Array::from(vec![1.0, 2.0]);

        let cmp = default_cmp(&array, &array);

        assert_eq!(Ordering::Less, cmp(0, 1));
    }

    #[test]
    fn test_f64_nan() {
        let array = Float64Array::from(vec![1.0, f64::NAN]);

        let cmp = default_cmp(&array, &array);

        // NaN sorts after 1.0 and compares equal to itself.
        assert_eq!(Ordering::Less, cmp(0, 1));
        assert_eq!(Ordering::Equal, cmp(1, 1));
    }

    #[test]
    fn test_f64_zeros() {
        let array = Float64Array::from(vec![-0.0, 0.0]);

        let cmp = default_cmp(&array, &array);

        // -0.0 and 0.0 are distinguished: negative zero orders first.
        assert_eq!(Ordering::Less, cmp(0, 1));
        assert_eq!(Ordering::Greater, cmp(1, 0));
    }

    #[test]
    fn test_interval_day_time() {
        let array = IntervalDayTimeArray::from(vec![
            IntervalDayTimeType::make_value(0, 1000),
            IntervalDayTimeType::make_value(1, 2),
            IntervalDayTimeType::make_value(0, 90_000_000),
        ]);

        let cmp = default_cmp(&array, &array);

        assert_eq!(Ordering::Less, cmp(0, 1));
        assert_eq!(Ordering::Greater, cmp(1, 0));

        // `(1, 2)` is pinned to order after `(0, 90_000_000)`.
        assert_eq!(Ordering::Greater, cmp(1, 2));
        assert_eq!(Ordering::Less, cmp(2, 1));
    }

    #[test]
    fn test_interval_year_month() {
        let array = IntervalYearMonthArray::from(vec![
            IntervalYearMonthType::make_value(1, 0),
            IntervalYearMonthType::make_value(0, 13),
            IntervalYearMonthType::make_value(1, 1),
        ]);

        let cmp = default_cmp(&array, &array);

        assert_eq!(Ordering::Less, cmp(0, 1));
        assert_eq!(Ordering::Greater, cmp(1, 0));

        // `(0, 13)` and `(1, 1)` are pinned to compare as equal.
        assert_eq!(Ordering::Equal, cmp(1, 2));
        assert_eq!(Ordering::Equal, cmp(2, 1));
    }

    #[test]
    fn test_interval_month_day_nano() {
        let array = IntervalMonthDayNanoArray::from(vec![
            IntervalMonthDayNanoType::make_value(0, 100, 0),
            IntervalMonthDayNanoType::make_value(1, 0, 0),
            IntervalMonthDayNanoType::make_value(0, 100, 2),
        ]);

        let cmp = default_cmp(&array, &array);

        assert_eq!(Ordering::Less, cmp(0, 1));
        assert_eq!(Ordering::Greater, cmp(1, 0));

        // `(1, 0, 0)` is pinned to order after `(0, 100, 2)`.
        assert_eq!(Ordering::Greater, cmp(1, 2));
        assert_eq!(Ordering::Less, cmp(2, 1));
    }

    #[test]
    fn test_decimali32() {
        let array = vec![Some(5_i32), Some(2_i32), Some(3_i32)]
            .into_iter()
            .collect::<Decimal32Array>()
            .with_precision_and_scale(8, 6)
            .unwrap();

        assert_decimal_ordering(&array);
    }

    #[test]
    fn test_decimali64() {
        let array = vec![Some(5_i64), Some(2_i64), Some(3_i64)]
            .into_iter()
            .collect::<Decimal64Array>()
            .with_precision_and_scale(16, 6)
            .unwrap();

        assert_decimal_ordering(&array);
    }

    #[test]
    fn test_decimali128() {
        let array = vec![Some(5_i128), Some(2_i128), Some(3_i128)]
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(23, 6)
            .unwrap();

        assert_decimal_ordering(&array);
    }

    #[test]
    fn test_decimali256() {
        let array = vec![
            Some(i256::from_i128(5_i128)),
            Some(i256::from_i128(2_i128)),
            Some(i256::from_i128(3_i128)),
        ]
        .into_iter()
        .collect::<Decimal256Array>()
        .with_precision_and_scale(53, 6)
        .unwrap();

        assert_decimal_ordering(&array);
    }

    #[test]
    fn test_dict() {
        let words = vec!["a", "b", "c", "a", "a", "c", "c"];
        let array = words.into_iter().collect::<DictionaryArray<Int16Type>>();

        let cmp = default_cmp(&array, &array);

        assert_eq!(Ordering::Less, cmp(0, 1));
        assert_eq!(Ordering::Equal, cmp(3, 4));
        assert_eq!(Ordering::Greater, cmp(2, 3));
    }

    #[test]
    fn test_multiple_dict() {
        let lhs = vec!["a", "b", "c", "d"]
            .into_iter()
            .collect::<DictionaryArray<Int16Type>>();
        let rhs = vec!["e", "f", "g", "a"]
            .into_iter()
            .collect::<DictionaryArray<Int16Type>>();

        let cmp = default_cmp(&lhs, &rhs);

        assert_eq!(Ordering::Less, cmp(0, 0));
        assert_eq!(Ordering::Equal, cmp(0, 3));
        assert_eq!(Ordering::Greater, cmp(1, 3));
    }

    #[test]
    fn test_primitive_dict() {
        let lhs = dict_of([0, 0, 1, 3], Int32Array::from(vec![1_i32, 0, 2, 5]));
        let rhs = dict_of([0, 1, 1, 3], Int32Array::from(vec![2_i32, 3, 4, 5]));

        assert_dict_ordering(&lhs, &rhs);
    }

    #[test]
    fn test_float_dict() {
        let lhs = dict_of([0, 0, 1, 3], Float32Array::from(vec![1.0, 0.5, 2.1, 5.5]));
        let rhs = dict_of([0, 1, 1, 3], Float32Array::from(vec![1.2, 3.2, 4.0, 5.5]));

        assert_dict_ordering(&lhs, &rhs);
    }

    #[test]
    fn test_timestamp_dict() {
        let lhs = dict_of([0, 0, 1, 3], TimestampSecondArray::from(vec![1, 0, 2, 5]));
        let rhs = dict_of([0, 1, 1, 3], TimestampSecondArray::from(vec![2, 3, 4, 5]));

        assert_dict_ordering(&lhs, &rhs);
    }

    #[test]
    fn test_interval_dict() {
        let v1 = IntervalDayTime::new(0, 1);
        let v2 = IntervalDayTime::new(0, 2);
        let v3 = IntervalDayTime::new(12, 2);

        let lhs = dict_of(
            [0, 0, 1, 3],
            IntervalDayTimeArray::from(vec![Some(v1), Some(v2), None, Some(v3)]),
        );
        let rhs = dict_of(
            [0, 1, 1, 3],
            IntervalDayTimeArray::from(vec![Some(v3), Some(v2), None, Some(v1)]),
        );

        let cmp = default_cmp(&lhs, &rhs);

        assert_eq!(Ordering::Less, cmp(0, 0)); // v1 vs v3
        assert_eq!(Ordering::Equal, cmp(0, 3)); // v1 vs v1
        assert_eq!(Ordering::Greater, cmp(3, 3)); // v3 vs v1
        assert_eq!(Ordering::Greater, cmp(3, 1)); // v3 vs v2
        assert_eq!(Ordering::Greater, cmp(3, 2)); // v3 vs v2
    }

    #[test]
    fn test_duration_dict() {
        let lhs = dict_of([0, 0, 1, 3], DurationSecondArray::from(vec![1, 0, 2, 5]));
        let rhs = dict_of([0, 1, 1, 3], DurationSecondArray::from(vec![2, 3, 4, 5]));

        assert_dict_ordering(&lhs, &rhs);
    }

    #[test]
    fn test_decimal_dict() {
        let lhs = dict_of([0, 0, 1, 3], Decimal128Array::from(vec![1, 0, 2, 5]));
        let rhs = dict_of([0, 1, 1, 3], Decimal128Array::from(vec![2, 3, 4, 5]));

        assert_dict_ordering(&lhs, &rhs);
    }

    #[test]
    fn test_decimal256_dict() {
        let lhs = dict_of(
            [0, 0, 1, 3],
            Decimal256Array::from(vec![
                i256::from_i128(1),
                i256::from_i128(0),
                i256::from_i128(2),
                i256::from_i128(5),
            ]),
        );
        let rhs = dict_of(
            [0, 1, 1, 3],
            Decimal256Array::from(vec![
                i256::from_i128(2),
                i256::from_i128(3),
                i256::from_i128(4),
                i256::from_i128(5),
            ]),
        );

        assert_dict_ordering(&lhs, &rhs);
    }

    /// Checks byte-wise ordering of the elements `"abc"`, `"def"`, `"a"` for byte array type `T`.
    fn test_bytes_impl<T: ByteArrayType>() {
        let offsets = OffsetBuffer::from_lengths([3, 3, 1]);
        let array = GenericByteArray::<T>::new(offsets, b"abcdefa".into(), None);
        let cmp = default_cmp(&array, &array);

        assert_eq!(Ordering::Less, cmp(0, 1));
        assert_eq!(Ordering::Greater, cmp(0, 2));
        assert_eq!(Ordering::Equal, cmp(1, 1));
    }

    #[test]
    fn test_bytes() {
        test_bytes_impl::<Utf8Type>();
        test_bytes_impl::<LargeUtf8Type>();
        test_bytes_impl::<BinaryType>();
        test_bytes_impl::<LargeBinaryType>();
    }

    #[test]
    fn test_lists() {
        use Ordering::{Equal, Greater, Less};

        let mut left = ListBuilder::new(ListBuilder::new(Int32Builder::new()));
        left.extend([
            Some(vec![Some(vec![Some(1), Some(2), None]), Some(vec![None])]),
            Some(vec![
                Some(vec![Some(1), Some(2), Some(3)]),
                Some(vec![Some(1)]),
            ]),
            Some(vec![]),
        ]);
        let left = left.finish();

        let mut right = ListBuilder::new(ListBuilder::new(Int32Builder::new()));
        right.extend([
            Some(vec![Some(vec![Some(1), Some(2), None]), Some(vec![None])]),
            Some(vec![
                Some(vec![Some(1), Some(2), None]),
                Some(vec![Some(1)]),
            ]),
            Some(vec![
                Some(vec![Some(1), Some(2), Some(3), Some(4)]),
                Some(vec![Some(1)]),
            ]),
            None,
        ]);
        let right = right.finish();

        // Ascending, nulls first.
        assert_orderings(
            &left,
            &right,
            sort_opts(false, true),
            &[
                ((0, 0), Equal),
                ((0, 1), Less),
                ((0, 2), Less),
                ((1, 2), Less),
                ((1, 3), Greater),
                ((2, 0), Less),
            ],
        );

        // Descending, nulls first.
        assert_orderings(
            &left,
            &right,
            sort_opts(true, true),
            &[
                ((0, 0), Equal),
                ((0, 1), Less),
                ((0, 2), Less),
                ((1, 2), Greater),
                ((1, 3), Greater),
                ((2, 0), Greater),
            ],
        );

        // Descending, nulls last.
        assert_orderings(
            &left,
            &right,
            sort_opts(true, false),
            &[
                ((0, 0), Equal),
                ((0, 1), Greater),
                ((0, 2), Greater),
                ((1, 2), Greater),
                ((1, 3), Less),
                ((2, 0), Greater),
            ],
        );

        // Ascending, nulls last.
        assert_orderings(
            &left,
            &right,
            sort_opts(false, false),
            &[
                ((0, 0), Equal),
                ((0, 1), Greater),
                ((0, 2), Greater),
                ((1, 2), Less),
                ((1, 3), Less),
                ((2, 0), Less),
            ],
        );
    }

    #[test]
    fn test_struct() {
        use Ordering::{Equal, Greater, Less};

        let fields = Fields::from(vec![
            Field::new("a", DataType::Int32, true),
            Field::new_list("b", Field::new_list_field(DataType::Int32, true), true),
        ]);

        // s1 rows: (1, [1, 2]), (2, [None]), (None, None), and a null struct.
        let ints = Int32Array::from(vec![Some(1), Some(2), None, None]);
        let mut lists = ListBuilder::new(Int32Builder::new());
        lists.extend([Some(vec![Some(1), Some(2)]), Some(vec![None]), None, None]);
        let lists = lists.finish();

        let validity = Some(NullBuffer::from_iter([true, true, true, false]));
        let columns = vec![Arc::new(ints) as ArrayRef, Arc::new(lists) as ArrayRef];
        let s1 = StructArray::new(fields.clone(), columns, validity);

        // s2 rows: (None, None), (2, None), (None, []).
        let ints = Int32Array::from(vec![None, Some(2), None]);
        let mut lists = ListBuilder::new(Int32Builder::new());
        lists.extend([None, None, Some(vec![])]);
        let lists = lists.finish();

        let columns = vec![Arc::new(ints) as ArrayRef, Arc::new(lists) as ArrayRef];
        let s2 = StructArray::new(fields.clone(), columns, None);

        // Ascending, nulls first.
        assert_orderings(
            &s1,
            &s2,
            sort_opts(false, true),
            &[
                ((0, 1), Less),
                ((0, 0), Greater),
                ((1, 1), Greater),
                ((2, 2), Less),
                ((3, 0), Less),
                ((2, 0), Equal),
                ((3, 0), Less),
            ],
        );

        // Descending, nulls first.
        assert_orderings(
            &s1,
            &s2,
            sort_opts(true, true),
            &[
                ((0, 1), Greater),
                ((0, 0), Greater),
                ((1, 1), Greater),
                ((2, 2), Less),
                ((3, 0), Less),
                ((2, 0), Equal),
                ((3, 0), Less),
            ],
        );

        // Descending, nulls last.
        assert_orderings(
            &s1,
            &s2,
            sort_opts(true, false),
            &[
                ((0, 1), Greater),
                ((0, 0), Less),
                ((1, 1), Less),
                ((2, 2), Greater),
                ((3, 0), Greater),
                ((2, 0), Equal),
                ((3, 0), Greater),
            ],
        );

        // Ascending, nulls last.
        assert_orderings(
            &s1,
            &s2,
            sort_opts(false, false),
            &[
                ((0, 1), Less),
                ((0, 0), Less),
                ((1, 1), Less),
                ((2, 2), Greater),
                ((3, 0), Greater),
                ((2, 0), Equal),
                ((3, 0), Greater),
            ],
        );
    }

    #[test]
    fn test_map() {
        use Ordering::{Greater, Less};

        let map1 = string_int_map(vec![
            Some(vec![("a", 100), ("b", 1)]),
            Some(vec![("b", 999), ("c", 1)]),
            Some(vec![]),
            Some(vec![("x", 1)]),
        ]);

        let map2 = string_int_map(vec![
            Some(vec![("a", 1), ("c", 999)]),
            Some(vec![("b", 1), ("d", 999)]),
            Some(vec![("a", 1)]),
            None,
        ]);

        // Ascending, nulls first.
        assert_orderings(
            &map1,
            &map2,
            sort_opts(false, true),
            &[
                ((0, 0), Greater),
                ((1, 1), Greater),
                ((0, 1), Less),
                ((2, 2), Less),
                ((3, 3), Greater),
                ((3, 0), Greater),
                ((2, 0), Less),
            ],
        );

        // Descending, nulls first.
        assert_orderings(
            &map1,
            &map2,
            sort_opts(true, true),
            &[
                ((0, 0), Less),
                ((1, 1), Less),
                ((0, 1), Greater),
                ((3, 3), Greater),
                ((2, 2), Greater),
            ],
        );

        // Ascending, nulls last.
        assert_orderings(
            &map1,
            &map2,
            sort_opts(false, false),
            &[
                ((0, 0), Greater),
                ((1, 1), Greater),
                ((3, 3), Less),
                ((2, 2), Less),
            ],
        );
    }

    #[test]
    fn test_map_vs_list_consistency() {
        let map1 = string_int_map(vec![
            Some(vec![("a", 1), ("b", 2)]),
            Some(vec![("x", 10)]),
            Some(vec![]),
            Some(vec![("c", 3)]),
        ]);

        let map2 = string_int_map(vec![
            Some(vec![("a", 1), ("b", 2)]),
            Some(vec![("y", 20)]),
            Some(vec![("d", 4)]),
            None,
        ]);

        // The same data viewed as lists of entries must order identically.
        let list1: ListArray = map1.clone().into();
        let list2: ListArray = map2.clone().into();

        let test_cases = [
            sort_opts(false, true),
            sort_opts(true, true),
            sort_opts(false, false),
            sort_opts(true, false),
        ];

        for opts in test_cases {
            let map_cmp = make_comparator(&map1, &map2, opts).unwrap();
            let list_cmp = make_comparator(&list1, &list2, opts).unwrap();

            for i in 0..map1.len() {
                for j in 0..map2.len() {
                    let map_result = map_cmp(i, j);
                    let list_result = list_cmp(i, j);
                    assert_eq!(
                        map_result, list_result,
                        "Map comparison and List comparison should be equal for indices ({i}, {j}) with opts {opts:?}. Map: {map_result:?}, List: {list_result:?}"
                    );
                }
            }
        }
    }
}