arrow_ord/
ord.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Contains functions and function factories to compare arrays.
19
20use arrow_array::cast::AsArray;
21use arrow_array::types::*;
22use arrow_array::*;
23use arrow_buffer::{ArrowNativeType, NullBuffer};
24use arrow_schema::{ArrowError, SortOptions};
25use std::cmp::Ordering;
26
/// A boxed comparison closure, as returned by [`make_comparator`].
///
/// Called with two indices `(i, j)`, it yields the [`Ordering`] of the value at
/// index `i` of the left array relative to the value at index `j` of the right array.
pub type DynComparator = Box<dyn Fn(usize, usize) -> Ordering + Send + Sync>;
29
30/// If parent sort order is descending we need to invert the value of nulls_first so that
31/// when the parent is sorted based on the produced ranks, nulls are still ordered correctly
32fn child_opts(opts: SortOptions) -> SortOptions {
33    SortOptions {
34        descending: false,
35        nulls_first: opts.nulls_first != opts.descending,
36    }
37}
38
39fn compare<A, F>(l: &A, r: &A, opts: SortOptions, cmp: F) -> DynComparator
40where
41    A: Array + Clone,
42    F: Fn(usize, usize) -> Ordering + Send + Sync + 'static,
43{
44    let l = l.logical_nulls().filter(|x| x.null_count() > 0);
45    let r = r.logical_nulls().filter(|x| x.null_count() > 0);
46    match (opts.nulls_first, opts.descending) {
47        (true, true) => compare_impl::<true, true, _>(l, r, cmp),
48        (true, false) => compare_impl::<true, false, _>(l, r, cmp),
49        (false, true) => compare_impl::<false, true, _>(l, r, cmp),
50        (false, false) => compare_impl::<false, false, _>(l, r, cmp),
51    }
52}
53
54fn compare_impl<const NULLS_FIRST: bool, const DESCENDING: bool, F>(
55    l: Option<NullBuffer>,
56    r: Option<NullBuffer>,
57    cmp: F,
58) -> DynComparator
59where
60    F: Fn(usize, usize) -> Ordering + Send + Sync + 'static,
61{
62    let cmp = move |i, j| match DESCENDING {
63        true => cmp(i, j).reverse(),
64        false => cmp(i, j),
65    };
66
67    let (left_null, right_null) = match NULLS_FIRST {
68        true => (Ordering::Less, Ordering::Greater),
69        false => (Ordering::Greater, Ordering::Less),
70    };
71
72    match (l, r) {
73        (None, None) => Box::new(cmp),
74        (Some(l), None) => Box::new(move |i, j| match l.is_null(i) {
75            true => left_null,
76            false => cmp(i, j),
77        }),
78        (None, Some(r)) => Box::new(move |i, j| match r.is_null(j) {
79            true => right_null,
80            false => cmp(i, j),
81        }),
82        (Some(l), Some(r)) => Box::new(move |i, j| match (l.is_null(i), r.is_null(j)) {
83            (true, true) => Ordering::Equal,
84            (true, false) => left_null,
85            (false, true) => right_null,
86            (false, false) => cmp(i, j),
87        }),
88    }
89}
90
91fn compare_primitive<T: ArrowPrimitiveType>(
92    left: &dyn Array,
93    right: &dyn Array,
94    opts: SortOptions,
95) -> DynComparator
96where
97    T::Native: ArrowNativeTypeOp,
98{
99    let left = left.as_primitive::<T>();
100    let right = right.as_primitive::<T>();
101    let l_values = left.values().clone();
102    let r_values = right.values().clone();
103
104    compare(&left, &right, opts, move |i, j| {
105        l_values[i].compare(r_values[j])
106    })
107}
108
109fn compare_boolean(left: &dyn Array, right: &dyn Array, opts: SortOptions) -> DynComparator {
110    let left = left.as_boolean();
111    let right = right.as_boolean();
112
113    let l_values = left.values().clone();
114    let r_values = right.values().clone();
115
116    compare(left, right, opts, move |i, j| {
117        l_values.value(i).cmp(&r_values.value(j))
118    })
119}
120
121fn compare_bytes<T: ByteArrayType>(
122    left: &dyn Array,
123    right: &dyn Array,
124    opts: SortOptions,
125) -> DynComparator {
126    let left = left.as_bytes::<T>();
127    let right = right.as_bytes::<T>();
128
129    let l = left.clone();
130    let r = right.clone();
131    compare(left, right, opts, move |i, j| {
132        let l: &[u8] = l.value(i).as_ref();
133        let r: &[u8] = r.value(j).as_ref();
134        l.cmp(r)
135    })
136}
137
138fn compare_byte_view<T: ByteViewType>(
139    left: &dyn Array,
140    right: &dyn Array,
141    opts: SortOptions,
142) -> DynComparator {
143    let left = left.as_byte_view::<T>();
144    let right = right.as_byte_view::<T>();
145
146    let l = left.clone();
147    let r = right.clone();
148    compare(left, right, opts, move |i, j| {
149        crate::cmp::compare_byte_view(&l, i, &r, j)
150    })
151}
152
153fn compare_dict<K: ArrowDictionaryKeyType>(
154    left: &dyn Array,
155    right: &dyn Array,
156    opts: SortOptions,
157) -> Result<DynComparator, ArrowError> {
158    let left = left.as_dictionary::<K>();
159    let right = right.as_dictionary::<K>();
160
161    let c_opts = child_opts(opts);
162    let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(), c_opts)?;
163    let left_keys = left.keys().values().clone();
164    let right_keys = right.keys().values().clone();
165
166    let f = compare(left, right, opts, move |i, j| {
167        let l = left_keys[i].as_usize();
168        let r = right_keys[j].as_usize();
169        cmp(l, r)
170    });
171    Ok(f)
172}
173
174fn compare_list<O: OffsetSizeTrait>(
175    left: &dyn Array,
176    right: &dyn Array,
177    opts: SortOptions,
178) -> Result<DynComparator, ArrowError> {
179    let left = left.as_list::<O>();
180    let right = right.as_list::<O>();
181
182    let c_opts = child_opts(opts);
183    let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(), c_opts)?;
184
185    let l_o = left.offsets().clone();
186    let r_o = right.offsets().clone();
187    let f = compare(left, right, opts, move |i, j| {
188        let l_end = l_o[i + 1].as_usize();
189        let l_start = l_o[i].as_usize();
190
191        let r_end = r_o[j + 1].as_usize();
192        let r_start = r_o[j].as_usize();
193
194        for (i, j) in (l_start..l_end).zip(r_start..r_end) {
195            match cmp(i, j) {
196                Ordering::Equal => continue,
197                r => return r,
198            }
199        }
200        (l_end - l_start).cmp(&(r_end - r_start))
201    });
202    Ok(f)
203}
204
205fn compare_fixed_list(
206    left: &dyn Array,
207    right: &dyn Array,
208    opts: SortOptions,
209) -> Result<DynComparator, ArrowError> {
210    let left = left.as_fixed_size_list();
211    let right = right.as_fixed_size_list();
212
213    let c_opts = child_opts(opts);
214    let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(), c_opts)?;
215
216    let l_size = left.value_length().to_usize().unwrap();
217    let r_size = right.value_length().to_usize().unwrap();
218    let size_cmp = l_size.cmp(&r_size);
219
220    let f = compare(left, right, opts, move |i, j| {
221        let l_start = i * l_size;
222        let l_end = l_start + l_size;
223        let r_start = j * r_size;
224        let r_end = r_start + r_size;
225        for (i, j) in (l_start..l_end).zip(r_start..r_end) {
226            match cmp(i, j) {
227                Ordering::Equal => continue,
228                r => return r,
229            }
230        }
231        size_cmp
232    });
233    Ok(f)
234}
235
236fn compare_struct(
237    left: &dyn Array,
238    right: &dyn Array,
239    opts: SortOptions,
240) -> Result<DynComparator, ArrowError> {
241    let left = left.as_struct();
242    let right = right.as_struct();
243
244    if left.columns().len() != right.columns().len() {
245        return Err(ArrowError::InvalidArgumentError(
246            "Cannot compare StructArray with different number of columns".to_string(),
247        ));
248    }
249
250    let c_opts = child_opts(opts);
251    let columns = left.columns().iter().zip(right.columns());
252    let comparators = columns
253        .map(|(l, r)| make_comparator(l, r, c_opts))
254        .collect::<Result<Vec<_>, _>>()?;
255
256    let f = compare(left, right, opts, move |i, j| {
257        for cmp in &comparators {
258            match cmp(i, j) {
259                Ordering::Equal => continue,
260                r => return r,
261            }
262        }
263        Ordering::Equal
264    });
265    Ok(f)
266}
267
268/// Returns a comparison function that compares two values at two different positions
269/// between the two arrays.
270///
271/// For comparing arrays element-wise, see also the vectorised kernels in [`crate::cmp`].
272///
273/// If `nulls_first` is true `NULL` values will be considered less than any non-null value,
274/// otherwise they will be considered greater.
275///
276/// # Basic Usage
277///
278/// ```
279/// # use std::cmp::Ordering;
280/// # use arrow_array::Int32Array;
281/// # use arrow_ord::ord::make_comparator;
282/// # use arrow_schema::SortOptions;
283/// #
284/// let array1 = Int32Array::from(vec![1, 2]);
285/// let array2 = Int32Array::from(vec![3, 4]);
286///
287/// let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
288/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
289/// assert_eq!(cmp(0, 1), Ordering::Less);
290///
291/// let array1 = Int32Array::from(vec![Some(1), None]);
292/// let array2 = Int32Array::from(vec![None, Some(2)]);
293/// let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
294///
295/// assert_eq!(cmp(0, 1), Ordering::Less); // Some(1) vs Some(2)
296/// assert_eq!(cmp(1, 1), Ordering::Less); // None vs Some(2)
297/// assert_eq!(cmp(1, 0), Ordering::Equal); // None vs None
298/// assert_eq!(cmp(0, 0), Ordering::Greater); // Some(1) vs None
299/// ```
300///
301/// # Postgres-compatible Nested Comparison
302///
303/// Whilst SQL prescribes ternary logic for nulls, that is comparing a value against a NULL yields
304/// a NULL, many systems, including postgres, instead apply a total ordering to comparison of
305/// nested nulls. That is nulls within nested types are either greater than any value (postgres),
306/// or less than any value (Spark).
307///
308/// In particular
309///
310/// ```ignore
311/// { a: 1, b: null } == { a: 1, b: null } => true
312/// { a: 1, b: null } == { a: 1, b: 1 } => false
313/// { a: 1, b: null } == null => null
314/// null == null => null
315/// ```
316///
317/// This could be implemented as below
318///
319/// ```
320/// # use arrow_array::{Array, BooleanArray};
321/// # use arrow_buffer::NullBuffer;
322/// # use arrow_ord::cmp;
323/// # use arrow_ord::ord::make_comparator;
324/// # use arrow_schema::{ArrowError, SortOptions};
325/// fn eq(a: &dyn Array, b: &dyn Array) -> Result<BooleanArray, ArrowError> {
326///     if !a.data_type().is_nested() {
327///         return cmp::eq(&a, &b); // Use faster vectorised kernel
328///     }
329///
330///     let cmp = make_comparator(a, b, SortOptions::default())?;
331///     let len = a.len().min(b.len());
332///     let values = (0..len).map(|i| cmp(i, i).is_eq()).collect();
333///     let nulls = NullBuffer::union(a.nulls(), b.nulls());
334///     Ok(BooleanArray::new(values, nulls))
335/// }
336/// ````
337pub fn make_comparator(
338    left: &dyn Array,
339    right: &dyn Array,
340    opts: SortOptions,
341) -> Result<DynComparator, ArrowError> {
342    use arrow_schema::DataType::*;
343
344    macro_rules! primitive_helper {
345        ($t:ty, $left:expr, $right:expr, $nulls_first:expr) => {
346            Ok(compare_primitive::<$t>($left, $right, $nulls_first))
347        };
348    }
349    downcast_primitive! {
350        left.data_type(), right.data_type() => (primitive_helper, left, right, opts),
351        (Boolean, Boolean) => Ok(compare_boolean(left, right, opts)),
352        (Utf8, Utf8) => Ok(compare_bytes::<Utf8Type>(left, right, opts)),
353        (LargeUtf8, LargeUtf8) => Ok(compare_bytes::<LargeUtf8Type>(left, right, opts)),
354        (Utf8View, Utf8View) => Ok(compare_byte_view::<StringViewType>(left, right, opts)),
355        (Binary, Binary) => Ok(compare_bytes::<BinaryType>(left, right, opts)),
356        (LargeBinary, LargeBinary) => Ok(compare_bytes::<LargeBinaryType>(left, right, opts)),
357        (BinaryView, BinaryView) => Ok(compare_byte_view::<BinaryViewType>(left, right, opts)),
358        (FixedSizeBinary(_), FixedSizeBinary(_)) => {
359            let left = left.as_fixed_size_binary();
360            let right = right.as_fixed_size_binary();
361
362            let l = left.clone();
363            let r = right.clone();
364            Ok(compare(left, right, opts, move |i, j| {
365                l.value(i).cmp(r.value(j))
366            }))
367        },
368        (List(_), List(_)) => compare_list::<i32>(left, right, opts),
369        (LargeList(_), LargeList(_)) => compare_list::<i64>(left, right, opts),
370        (FixedSizeList(_, _), FixedSizeList(_, _)) => compare_fixed_list(left, right, opts),
371        (Struct(_), Struct(_)) => compare_struct(left, right, opts),
372        (Dictionary(l_key, _), Dictionary(r_key, _)) => {
373             macro_rules! dict_helper {
374                ($t:ty, $left:expr, $right:expr, $opts: expr) => {
375                     compare_dict::<$t>($left, $right, $opts)
376                 };
377             }
378            downcast_integer! {
379                 l_key.as_ref(), r_key.as_ref() => (dict_helper, left, right, opts),
380                 _ => unreachable!()
381             }
382        },
383        (lhs, rhs) => Err(ArrowError::InvalidArgumentError(match lhs == rhs {
384            true => format!("The data type type {lhs:?} has no natural order"),
385            false => "Can't compare arrays of different types".to_string(),
386        }))
387    }
388}
389
390#[cfg(test)]
391mod tests {
392    use super::*;
393    use arrow_array::builder::{Int32Builder, ListBuilder};
394    use arrow_buffer::{i256, IntervalDayTime, OffsetBuffer};
395    use arrow_schema::{DataType, Field, Fields};
396    use half::f16;
397    use std::sync::Arc;
398
399    #[test]
400    fn test_fixed_size_binary() {
401        let items = vec![vec![1u8], vec![2u8]];
402        let array = FixedSizeBinaryArray::try_from_iter(items.into_iter()).unwrap();
403
404        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
405
406        assert_eq!(Ordering::Less, cmp(0, 1));
407    }
408
409    #[test]
410    fn test_fixed_size_binary_fixed_size_binary() {
411        let items = vec![vec![1u8]];
412        let array1 = FixedSizeBinaryArray::try_from_iter(items.into_iter()).unwrap();
413        let items = vec![vec![2u8]];
414        let array2 = FixedSizeBinaryArray::try_from_iter(items.into_iter()).unwrap();
415
416        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
417
418        assert_eq!(Ordering::Less, cmp(0, 0));
419    }
420
421    #[test]
422    fn test_i32() {
423        let array = Int32Array::from(vec![1, 2]);
424
425        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
426
427        assert_eq!(Ordering::Less, (cmp)(0, 1));
428    }
429
430    #[test]
431    fn test_i32_i32() {
432        let array1 = Int32Array::from(vec![1]);
433        let array2 = Int32Array::from(vec![2]);
434
435        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
436
437        assert_eq!(Ordering::Less, cmp(0, 0));
438    }
439
440    #[test]
441    fn test_f16() {
442        let array = Float16Array::from(vec![f16::from_f32(1.0), f16::from_f32(2.0)]);
443
444        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
445
446        assert_eq!(Ordering::Less, cmp(0, 1));
447    }
448
449    #[test]
450    fn test_f64() {
451        let array = Float64Array::from(vec![1.0, 2.0]);
452
453        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
454
455        assert_eq!(Ordering::Less, cmp(0, 1));
456    }
457
458    #[test]
459    fn test_f64_nan() {
460        let array = Float64Array::from(vec![1.0, f64::NAN]);
461
462        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
463
464        assert_eq!(Ordering::Less, cmp(0, 1));
465        assert_eq!(Ordering::Equal, cmp(1, 1));
466    }
467
468    #[test]
469    fn test_f64_zeros() {
470        let array = Float64Array::from(vec![-0.0, 0.0]);
471
472        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
473
474        assert_eq!(Ordering::Less, cmp(0, 1));
475        assert_eq!(Ordering::Greater, cmp(1, 0));
476    }
477
478    #[test]
479    fn test_interval_day_time() {
480        let array = IntervalDayTimeArray::from(vec![
481            // 0 days, 1 second
482            IntervalDayTimeType::make_value(0, 1000),
483            // 1 day, 2 milliseconds
484            IntervalDayTimeType::make_value(1, 2),
485            // 90M milliseconds (which is more than is in 1 day)
486            IntervalDayTimeType::make_value(0, 90_000_000),
487        ]);
488
489        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
490
491        assert_eq!(Ordering::Less, cmp(0, 1));
492        assert_eq!(Ordering::Greater, cmp(1, 0));
493
494        // somewhat confusingly, while 90M milliseconds is more than 1 day,
495        // it will compare less as the comparison is done on the underlying
496        // values not field by field
497        assert_eq!(Ordering::Greater, cmp(1, 2));
498        assert_eq!(Ordering::Less, cmp(2, 1));
499    }
500
501    #[test]
502    fn test_interval_year_month() {
503        let array = IntervalYearMonthArray::from(vec![
504            // 1 year, 0 months
505            IntervalYearMonthType::make_value(1, 0),
506            // 0 years, 13 months
507            IntervalYearMonthType::make_value(0, 13),
508            // 1 year, 1 month
509            IntervalYearMonthType::make_value(1, 1),
510        ]);
511
512        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
513
514        assert_eq!(Ordering::Less, cmp(0, 1));
515        assert_eq!(Ordering::Greater, cmp(1, 0));
516
517        // the underlying representation is months, so both quantities are the same
518        assert_eq!(Ordering::Equal, cmp(1, 2));
519        assert_eq!(Ordering::Equal, cmp(2, 1));
520    }
521
522    #[test]
523    fn test_interval_month_day_nano() {
524        let array = IntervalMonthDayNanoArray::from(vec![
525            // 100 days
526            IntervalMonthDayNanoType::make_value(0, 100, 0),
527            // 1 month
528            IntervalMonthDayNanoType::make_value(1, 0, 0),
529            // 100 day, 1 nanoseconds
530            IntervalMonthDayNanoType::make_value(0, 100, 2),
531        ]);
532
533        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
534
535        assert_eq!(Ordering::Less, cmp(0, 1));
536        assert_eq!(Ordering::Greater, cmp(1, 0));
537
538        // somewhat confusingly, while 100 days is more than 1 month in all cases
539        // it will compare less as the comparison is done on the underlying
540        // values not field by field
541        assert_eq!(Ordering::Greater, cmp(1, 2));
542        assert_eq!(Ordering::Less, cmp(2, 1));
543    }
544
545    #[test]
546    fn test_decimal() {
547        let array = vec![Some(5_i128), Some(2_i128), Some(3_i128)]
548            .into_iter()
549            .collect::<Decimal128Array>()
550            .with_precision_and_scale(23, 6)
551            .unwrap();
552
553        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
554        assert_eq!(Ordering::Less, cmp(1, 0));
555        assert_eq!(Ordering::Greater, cmp(0, 2));
556    }
557
558    #[test]
559    fn test_decimali256() {
560        let array = vec![
561            Some(i256::from_i128(5_i128)),
562            Some(i256::from_i128(2_i128)),
563            Some(i256::from_i128(3_i128)),
564        ]
565        .into_iter()
566        .collect::<Decimal256Array>()
567        .with_precision_and_scale(53, 6)
568        .unwrap();
569
570        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
571        assert_eq!(Ordering::Less, cmp(1, 0));
572        assert_eq!(Ordering::Greater, cmp(0, 2));
573    }
574
575    #[test]
576    fn test_dict() {
577        let data = vec!["a", "b", "c", "a", "a", "c", "c"];
578        let array = data.into_iter().collect::<DictionaryArray<Int16Type>>();
579
580        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
581
582        assert_eq!(Ordering::Less, cmp(0, 1));
583        assert_eq!(Ordering::Equal, cmp(3, 4));
584        assert_eq!(Ordering::Greater, cmp(2, 3));
585    }
586
587    #[test]
588    fn test_multiple_dict() {
589        let d1 = vec!["a", "b", "c", "d"];
590        let a1 = d1.into_iter().collect::<DictionaryArray<Int16Type>>();
591        let d2 = vec!["e", "f", "g", "a"];
592        let a2 = d2.into_iter().collect::<DictionaryArray<Int16Type>>();
593
594        let cmp = make_comparator(&a1, &a2, SortOptions::default()).unwrap();
595
596        assert_eq!(Ordering::Less, cmp(0, 0));
597        assert_eq!(Ordering::Equal, cmp(0, 3));
598        assert_eq!(Ordering::Greater, cmp(1, 3));
599    }
600
601    #[test]
602    fn test_primitive_dict() {
603        let values = Int32Array::from(vec![1_i32, 0, 2, 5]);
604        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
605        let array1 = DictionaryArray::new(keys, Arc::new(values));
606
607        let values = Int32Array::from(vec![2_i32, 3, 4, 5]);
608        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
609        let array2 = DictionaryArray::new(keys, Arc::new(values));
610
611        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
612
613        assert_eq!(Ordering::Less, cmp(0, 0));
614        assert_eq!(Ordering::Less, cmp(0, 3));
615        assert_eq!(Ordering::Equal, cmp(3, 3));
616        assert_eq!(Ordering::Greater, cmp(3, 1));
617        assert_eq!(Ordering::Greater, cmp(3, 2));
618    }
619
620    #[test]
621    fn test_float_dict() {
622        let values = Float32Array::from(vec![1.0, 0.5, 2.1, 5.5]);
623        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
624        let array1 = DictionaryArray::try_new(keys, Arc::new(values)).unwrap();
625
626        let values = Float32Array::from(vec![1.2, 3.2, 4.0, 5.5]);
627        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
628        let array2 = DictionaryArray::new(keys, Arc::new(values));
629
630        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
631
632        assert_eq!(Ordering::Less, cmp(0, 0));
633        assert_eq!(Ordering::Less, cmp(0, 3));
634        assert_eq!(Ordering::Equal, cmp(3, 3));
635        assert_eq!(Ordering::Greater, cmp(3, 1));
636        assert_eq!(Ordering::Greater, cmp(3, 2));
637    }
638
639    #[test]
640    fn test_timestamp_dict() {
641        let values = TimestampSecondArray::from(vec![1, 0, 2, 5]);
642        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
643        let array1 = DictionaryArray::new(keys, Arc::new(values));
644
645        let values = TimestampSecondArray::from(vec![2, 3, 4, 5]);
646        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
647        let array2 = DictionaryArray::new(keys, Arc::new(values));
648
649        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
650
651        assert_eq!(Ordering::Less, cmp(0, 0));
652        assert_eq!(Ordering::Less, cmp(0, 3));
653        assert_eq!(Ordering::Equal, cmp(3, 3));
654        assert_eq!(Ordering::Greater, cmp(3, 1));
655        assert_eq!(Ordering::Greater, cmp(3, 2));
656    }
657
658    #[test]
659    fn test_interval_dict() {
660        let v1 = IntervalDayTime::new(0, 1);
661        let v2 = IntervalDayTime::new(0, 2);
662        let v3 = IntervalDayTime::new(12, 2);
663
664        let values = IntervalDayTimeArray::from(vec![Some(v1), Some(v2), None, Some(v3)]);
665        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
666        let array1 = DictionaryArray::new(keys, Arc::new(values));
667
668        let values = IntervalDayTimeArray::from(vec![Some(v3), Some(v2), None, Some(v1)]);
669        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
670        let array2 = DictionaryArray::new(keys, Arc::new(values));
671
672        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
673
674        assert_eq!(Ordering::Less, cmp(0, 0)); // v1 vs v3
675        assert_eq!(Ordering::Equal, cmp(0, 3)); // v1 vs v1
676        assert_eq!(Ordering::Greater, cmp(3, 3)); // v3 vs v1
677        assert_eq!(Ordering::Greater, cmp(3, 1)); // v3 vs v2
678        assert_eq!(Ordering::Greater, cmp(3, 2)); // v3 vs v2
679    }
680
681    #[test]
682    fn test_duration_dict() {
683        let values = DurationSecondArray::from(vec![1, 0, 2, 5]);
684        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
685        let array1 = DictionaryArray::new(keys, Arc::new(values));
686
687        let values = DurationSecondArray::from(vec![2, 3, 4, 5]);
688        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
689        let array2 = DictionaryArray::new(keys, Arc::new(values));
690
691        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
692
693        assert_eq!(Ordering::Less, cmp(0, 0));
694        assert_eq!(Ordering::Less, cmp(0, 3));
695        assert_eq!(Ordering::Equal, cmp(3, 3));
696        assert_eq!(Ordering::Greater, cmp(3, 1));
697        assert_eq!(Ordering::Greater, cmp(3, 2));
698    }
699
700    #[test]
701    fn test_decimal_dict() {
702        let values = Decimal128Array::from(vec![1, 0, 2, 5]);
703        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
704        let array1 = DictionaryArray::new(keys, Arc::new(values));
705
706        let values = Decimal128Array::from(vec![2, 3, 4, 5]);
707        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
708        let array2 = DictionaryArray::new(keys, Arc::new(values));
709
710        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
711
712        assert_eq!(Ordering::Less, cmp(0, 0));
713        assert_eq!(Ordering::Less, cmp(0, 3));
714        assert_eq!(Ordering::Equal, cmp(3, 3));
715        assert_eq!(Ordering::Greater, cmp(3, 1));
716        assert_eq!(Ordering::Greater, cmp(3, 2));
717    }
718
719    #[test]
720    fn test_decimal256_dict() {
721        let values = Decimal256Array::from(vec![
722            i256::from_i128(1),
723            i256::from_i128(0),
724            i256::from_i128(2),
725            i256::from_i128(5),
726        ]);
727        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
728        let array1 = DictionaryArray::new(keys, Arc::new(values));
729
730        let values = Decimal256Array::from(vec![
731            i256::from_i128(2),
732            i256::from_i128(3),
733            i256::from_i128(4),
734            i256::from_i128(5),
735        ]);
736        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
737        let array2 = DictionaryArray::new(keys, Arc::new(values));
738
739        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
740
741        assert_eq!(Ordering::Less, cmp(0, 0));
742        assert_eq!(Ordering::Less, cmp(0, 3));
743        assert_eq!(Ordering::Equal, cmp(3, 3));
744        assert_eq!(Ordering::Greater, cmp(3, 1));
745        assert_eq!(Ordering::Greater, cmp(3, 2));
746    }
747
748    fn test_bytes_impl<T: ByteArrayType>() {
749        let offsets = OffsetBuffer::from_lengths([3, 3, 1]);
750        let a = GenericByteArray::<T>::new(offsets, b"abcdefa".into(), None);
751        let cmp = make_comparator(&a, &a, SortOptions::default()).unwrap();
752
753        assert_eq!(Ordering::Less, cmp(0, 1));
754        assert_eq!(Ordering::Greater, cmp(0, 2));
755        assert_eq!(Ordering::Equal, cmp(1, 1));
756    }
757
758    #[test]
759    fn test_bytes() {
760        test_bytes_impl::<Utf8Type>();
761        test_bytes_impl::<LargeUtf8Type>();
762        test_bytes_impl::<BinaryType>();
763        test_bytes_impl::<LargeBinaryType>();
764    }
765
766    #[test]
767    fn test_lists() {
768        let mut a = ListBuilder::new(ListBuilder::new(Int32Builder::new()));
769        a.extend([
770            Some(vec![Some(vec![Some(1), Some(2), None]), Some(vec![None])]),
771            Some(vec![
772                Some(vec![Some(1), Some(2), Some(3)]),
773                Some(vec![Some(1)]),
774            ]),
775            Some(vec![]),
776        ]);
777        let a = a.finish();
778        let mut b = ListBuilder::new(ListBuilder::new(Int32Builder::new()));
779        b.extend([
780            Some(vec![Some(vec![Some(1), Some(2), None]), Some(vec![None])]),
781            Some(vec![
782                Some(vec![Some(1), Some(2), None]),
783                Some(vec![Some(1)]),
784            ]),
785            Some(vec![
786                Some(vec![Some(1), Some(2), Some(3), Some(4)]),
787                Some(vec![Some(1)]),
788            ]),
789            None,
790        ]);
791        let b = b.finish();
792
793        let opts = SortOptions {
794            descending: false,
795            nulls_first: true,
796        };
797        let cmp = make_comparator(&a, &b, opts).unwrap();
798        assert_eq!(cmp(0, 0), Ordering::Equal);
799        assert_eq!(cmp(0, 1), Ordering::Less);
800        assert_eq!(cmp(0, 2), Ordering::Less);
801        assert_eq!(cmp(1, 2), Ordering::Less);
802        assert_eq!(cmp(1, 3), Ordering::Greater);
803        assert_eq!(cmp(2, 0), Ordering::Less);
804
805        let opts = SortOptions {
806            descending: true,
807            nulls_first: true,
808        };
809        let cmp = make_comparator(&a, &b, opts).unwrap();
810        assert_eq!(cmp(0, 0), Ordering::Equal);
811        assert_eq!(cmp(0, 1), Ordering::Less);
812        assert_eq!(cmp(0, 2), Ordering::Less);
813        assert_eq!(cmp(1, 2), Ordering::Greater);
814        assert_eq!(cmp(1, 3), Ordering::Greater);
815        assert_eq!(cmp(2, 0), Ordering::Greater);
816
817        let opts = SortOptions {
818            descending: true,
819            nulls_first: false,
820        };
821        let cmp = make_comparator(&a, &b, opts).unwrap();
822        assert_eq!(cmp(0, 0), Ordering::Equal);
823        assert_eq!(cmp(0, 1), Ordering::Greater);
824        assert_eq!(cmp(0, 2), Ordering::Greater);
825        assert_eq!(cmp(1, 2), Ordering::Greater);
826        assert_eq!(cmp(1, 3), Ordering::Less);
827        assert_eq!(cmp(2, 0), Ordering::Greater);
828
829        let opts = SortOptions {
830            descending: false,
831            nulls_first: false,
832        };
833        let cmp = make_comparator(&a, &b, opts).unwrap();
834        assert_eq!(cmp(0, 0), Ordering::Equal);
835        assert_eq!(cmp(0, 1), Ordering::Greater);
836        assert_eq!(cmp(0, 2), Ordering::Greater);
837        assert_eq!(cmp(1, 2), Ordering::Less);
838        assert_eq!(cmp(1, 3), Ordering::Less);
839        assert_eq!(cmp(2, 0), Ordering::Less);
840    }
841
842    #[test]
843    fn test_struct() {
844        let fields = Fields::from(vec![
845            Field::new("a", DataType::Int32, true),
846            Field::new_list("b", Field::new_list_field(DataType::Int32, true), true),
847        ]);
848
849        let a = Int32Array::from(vec![Some(1), Some(2), None, None]);
850        let mut b = ListBuilder::new(Int32Builder::new());
851        b.extend([Some(vec![Some(1), Some(2)]), Some(vec![None]), None, None]);
852        let b = b.finish();
853
854        let nulls = Some(NullBuffer::from_iter([true, true, true, false]));
855        let values = vec![Arc::new(a) as _, Arc::new(b) as _];
856        let s1 = StructArray::new(fields.clone(), values, nulls);
857
858        let a = Int32Array::from(vec![None, Some(2), None]);
859        let mut b = ListBuilder::new(Int32Builder::new());
860        b.extend([None, None, Some(vec![])]);
861        let b = b.finish();
862
863        let values = vec![Arc::new(a) as _, Arc::new(b) as _];
864        let s2 = StructArray::new(fields.clone(), values, None);
865
866        let opts = SortOptions {
867            descending: false,
868            nulls_first: true,
869        };
870        let cmp = make_comparator(&s1, &s2, opts).unwrap();
871        assert_eq!(cmp(0, 1), Ordering::Less); // (1, [1, 2]) cmp (2, None)
872        assert_eq!(cmp(0, 0), Ordering::Greater); // (1, [1, 2]) cmp (None, None)
873        assert_eq!(cmp(1, 1), Ordering::Greater); // (2, [None]) cmp (2, None)
874        assert_eq!(cmp(2, 2), Ordering::Less); // (None, None) cmp (None, [])
875        assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, [])
876        assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None)
877        assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, None)
878
879        let opts = SortOptions {
880            descending: true,
881            nulls_first: true,
882        };
883        let cmp = make_comparator(&s1, &s2, opts).unwrap();
884        assert_eq!(cmp(0, 1), Ordering::Greater); // (1, [1, 2]) cmp (2, None)
885        assert_eq!(cmp(0, 0), Ordering::Greater); // (1, [1, 2]) cmp (None, None)
886        assert_eq!(cmp(1, 1), Ordering::Greater); // (2, [None]) cmp (2, None)
887        assert_eq!(cmp(2, 2), Ordering::Less); // (None, None) cmp (None, [])
888        assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, [])
889        assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None)
890        assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, None)
891
892        let opts = SortOptions {
893            descending: true,
894            nulls_first: false,
895        };
896        let cmp = make_comparator(&s1, &s2, opts).unwrap();
897        assert_eq!(cmp(0, 1), Ordering::Greater); // (1, [1, 2]) cmp (2, None)
898        assert_eq!(cmp(0, 0), Ordering::Less); // (1, [1, 2]) cmp (None, None)
899        assert_eq!(cmp(1, 1), Ordering::Less); // (2, [None]) cmp (2, None)
900        assert_eq!(cmp(2, 2), Ordering::Greater); // (None, None) cmp (None, [])
901        assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, [])
902        assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None)
903        assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, None)
904
905        let opts = SortOptions {
906            descending: false,
907            nulls_first: false,
908        };
909        let cmp = make_comparator(&s1, &s2, opts).unwrap();
910        assert_eq!(cmp(0, 1), Ordering::Less); // (1, [1, 2]) cmp (2, None)
911        assert_eq!(cmp(0, 0), Ordering::Less); // (1, [1, 2]) cmp (None, None)
912        assert_eq!(cmp(1, 1), Ordering::Less); // (2, [None]) cmp (2, None)
913        assert_eq!(cmp(2, 2), Ordering::Greater); // (None, None) cmp (None, [])
914        assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, [])
915        assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None)
916        assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, None)
917    }
918}