Skip to main content

arrow_ord/
ord.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Contains functions and function factories to compare arrays.
19
20use arrow_array::cast::AsArray;
21use arrow_array::types::*;
22use arrow_array::*;
23use arrow_buffer::{ArrowNativeType, NullBuffer};
24use arrow_schema::{ArrowError, DataType, SortOptions};
25use std::{cmp::Ordering, collections::HashMap};
26
/// Compare the values at two arbitrary indices in two arrays.
///
/// The first argument is an index into the left array and the second an index into the
/// right array. The boxed closure is `Send + Sync`, so it may be shared across threads.
pub type DynComparator = Box<dyn Fn(usize, usize) -> Ordering + Sync + Send>;
29
30/// If parent sort order is descending we need to invert the value of nulls_first so that
31/// when the parent is sorted based on the produced ranks, nulls are still ordered correctly
32fn child_opts(opts: SortOptions) -> SortOptions {
33    SortOptions {
34        descending: false,
35        nulls_first: opts.nulls_first != opts.descending,
36    }
37}
38
39fn compare<A, F>(l: &A, r: &A, opts: SortOptions, cmp: F) -> DynComparator
40where
41    A: Array + Clone,
42    F: Fn(usize, usize) -> Ordering + Send + Sync + 'static,
43{
44    let l = l.logical_nulls().filter(|x| x.null_count() > 0);
45    let r = r.logical_nulls().filter(|x| x.null_count() > 0);
46    match (opts.nulls_first, opts.descending) {
47        (true, true) => compare_impl::<true, true, _>(l, r, cmp),
48        (true, false) => compare_impl::<true, false, _>(l, r, cmp),
49        (false, true) => compare_impl::<false, true, _>(l, r, cmp),
50        (false, false) => compare_impl::<false, false, _>(l, r, cmp),
51    }
52}
53
54fn compare_impl<const NULLS_FIRST: bool, const DESCENDING: bool, F>(
55    l: Option<NullBuffer>,
56    r: Option<NullBuffer>,
57    cmp: F,
58) -> DynComparator
59where
60    F: Fn(usize, usize) -> Ordering + Send + Sync + 'static,
61{
62    let cmp = move |i, j| match DESCENDING {
63        true => cmp(i, j).reverse(),
64        false => cmp(i, j),
65    };
66
67    let (left_null, right_null) = match NULLS_FIRST {
68        true => (Ordering::Less, Ordering::Greater),
69        false => (Ordering::Greater, Ordering::Less),
70    };
71
72    match (l, r) {
73        (None, None) => Box::new(cmp),
74        (Some(l), None) => Box::new(move |i, j| match l.is_null(i) {
75            true => left_null,
76            false => cmp(i, j),
77        }),
78        (None, Some(r)) => Box::new(move |i, j| match r.is_null(j) {
79            true => right_null,
80            false => cmp(i, j),
81        }),
82        (Some(l), Some(r)) => Box::new(move |i, j| match (l.is_null(i), r.is_null(j)) {
83            (true, true) => Ordering::Equal,
84            (true, false) => left_null,
85            (false, true) => right_null,
86            (false, false) => cmp(i, j),
87        }),
88    }
89}
90
91fn compare_primitive<T: ArrowPrimitiveType>(
92    left: &dyn Array,
93    right: &dyn Array,
94    opts: SortOptions,
95) -> DynComparator
96where
97    T::Native: ArrowNativeTypeOp,
98{
99    let left = left.as_primitive::<T>();
100    let right = right.as_primitive::<T>();
101    let l_values = left.values().clone();
102    let r_values = right.values().clone();
103
104    compare(&left, &right, opts, move |i, j| {
105        l_values[i].compare(r_values[j])
106    })
107}
108
109fn compare_boolean(left: &dyn Array, right: &dyn Array, opts: SortOptions) -> DynComparator {
110    let left = left.as_boolean();
111    let right = right.as_boolean();
112
113    let l_values = left.values().clone();
114    let r_values = right.values().clone();
115
116    compare(left, right, opts, move |i, j| {
117        l_values.value(i).cmp(&r_values.value(j))
118    })
119}
120
121fn compare_bytes<T: ByteArrayType>(
122    left: &dyn Array,
123    right: &dyn Array,
124    opts: SortOptions,
125) -> DynComparator {
126    let left = left.as_bytes::<T>();
127    let right = right.as_bytes::<T>();
128
129    let l = left.clone();
130    let r = right.clone();
131    compare(left, right, opts, move |i, j| {
132        let l: &[u8] = l.value(i).as_ref();
133        let r: &[u8] = r.value(j).as_ref();
134        l.cmp(r)
135    })
136}
137
138fn compare_byte_view<T: ByteViewType>(
139    left: &dyn Array,
140    right: &dyn Array,
141    opts: SortOptions,
142) -> DynComparator {
143    let left = left.as_byte_view::<T>();
144    let right = right.as_byte_view::<T>();
145
146    let l = left.clone();
147    let r = right.clone();
148    compare(left, right, opts, move |i, j| {
149        crate::cmp::compare_byte_view(&l, i, &r, j)
150    })
151}
152
153fn compare_dict<K: ArrowDictionaryKeyType>(
154    left: &dyn Array,
155    right: &dyn Array,
156    opts: SortOptions,
157) -> Result<DynComparator, ArrowError> {
158    let left = left.as_dictionary::<K>();
159    let right = right.as_dictionary::<K>();
160
161    let c_opts = child_opts(opts);
162    let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(), c_opts)?;
163    let left_keys = left.keys().values().clone();
164    let right_keys = right.keys().values().clone();
165
166    let f = compare(left, right, opts, move |i, j| {
167        let l = left_keys[i].as_usize();
168        let r = right_keys[j].as_usize();
169        cmp(l, r)
170    });
171    Ok(f)
172}
173
174fn compare_list<O: OffsetSizeTrait>(
175    left: &dyn Array,
176    right: &dyn Array,
177    opts: SortOptions,
178) -> Result<DynComparator, ArrowError> {
179    let left = left.as_list::<O>();
180    let right = right.as_list::<O>();
181
182    let c_opts = child_opts(opts);
183    let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(), c_opts)?;
184
185    let l_o = left.offsets().clone();
186    let r_o = right.offsets().clone();
187    let f = compare(left, right, opts, move |i, j| {
188        let l_end = l_o[i + 1].as_usize();
189        let l_start = l_o[i].as_usize();
190
191        let r_end = r_o[j + 1].as_usize();
192        let r_start = r_o[j].as_usize();
193
194        for (i, j) in (l_start..l_end).zip(r_start..r_end) {
195            match cmp(i, j) {
196                Ordering::Equal => continue,
197                r => return r,
198            }
199        }
200        (l_end - l_start).cmp(&(r_end - r_start))
201    });
202    Ok(f)
203}
204
205fn compare_fixed_list(
206    left: &dyn Array,
207    right: &dyn Array,
208    opts: SortOptions,
209) -> Result<DynComparator, ArrowError> {
210    let left = left.as_fixed_size_list();
211    let right = right.as_fixed_size_list();
212
213    let c_opts = child_opts(opts);
214    let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(), c_opts)?;
215
216    let l_size = left.value_length().to_usize().unwrap();
217    let r_size = right.value_length().to_usize().unwrap();
218    let size_cmp = l_size.cmp(&r_size);
219
220    let f = compare(left, right, opts, move |i, j| {
221        let l_start = i * l_size;
222        let l_end = l_start + l_size;
223        let r_start = j * r_size;
224        let r_end = r_start + r_size;
225        for (i, j) in (l_start..l_end).zip(r_start..r_end) {
226            match cmp(i, j) {
227                Ordering::Equal => continue,
228                r => return r,
229            }
230        }
231        size_cmp
232    });
233    Ok(f)
234}
235
236fn compare_list_view<O: OffsetSizeTrait>(
237    left: &dyn Array,
238    right: &dyn Array,
239    opts: SortOptions,
240) -> Result<DynComparator, ArrowError> {
241    let left = left.as_list_view::<O>();
242    let right = right.as_list_view::<O>();
243
244    let c_opts = child_opts(opts);
245    let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(), c_opts)?;
246
247    let l_offsets = left.offsets().clone();
248    let l_sizes = left.sizes().clone();
249    let r_offsets = right.offsets().clone();
250    let r_sizes = right.sizes().clone();
251
252    let f = compare(left, right, opts, move |i, j| {
253        let l_start = l_offsets[i].as_usize();
254        let l_len = l_sizes[i].as_usize();
255        let l_end = l_start + l_len;
256
257        let r_start = r_offsets[j].as_usize();
258        let r_len = r_sizes[j].as_usize();
259        let r_end = r_start + r_len;
260
261        for (i, j) in (l_start..l_end).zip(r_start..r_end) {
262            match cmp(i, j) {
263                Ordering::Equal => continue,
264                r => return r,
265            }
266        }
267        l_len.cmp(&r_len)
268    });
269    Ok(f)
270}
271
272fn compare_map(
273    left: &dyn Array,
274    right: &dyn Array,
275    opts: SortOptions,
276) -> Result<DynComparator, ArrowError> {
277    let left = left.as_map();
278    let right = right.as_map();
279
280    let c_opts = child_opts(opts);
281    let cmp = make_comparator(left.entries(), right.entries(), c_opts)?;
282
283    let l_o = left.offsets().clone();
284    let r_o = right.offsets().clone();
285    let f = compare(left, right, opts, move |i, j| {
286        let l_end = l_o[i + 1].as_usize();
287        let l_start = l_o[i].as_usize();
288
289        let r_end = r_o[j + 1].as_usize();
290        let r_start = r_o[j].as_usize();
291
292        for (i, j) in (l_start..l_end).zip(r_start..r_end) {
293            match cmp(i, j) {
294                Ordering::Equal => continue,
295                r => return r,
296            }
297        }
298        (l_end - l_start).cmp(&(r_end - r_start))
299    });
300    Ok(f)
301}
302
303fn compare_struct(
304    left: &dyn Array,
305    right: &dyn Array,
306    opts: SortOptions,
307) -> Result<DynComparator, ArrowError> {
308    let left = left.as_struct();
309    let right = right.as_struct();
310
311    if left.columns().len() != right.columns().len() {
312        return Err(ArrowError::InvalidArgumentError(
313            "Cannot compare StructArray with different number of columns".to_string(),
314        ));
315    }
316
317    let c_opts = child_opts(opts);
318    let columns = left.columns().iter().zip(right.columns());
319    let comparators = columns
320        .map(|(l, r)| make_comparator(l, r, c_opts))
321        .collect::<Result<Vec<_>, _>>()?;
322
323    let f = compare(left, right, opts, move |i, j| {
324        for cmp in &comparators {
325            match cmp(i, j) {
326                Ordering::Equal => continue,
327                r => return r,
328            }
329        }
330        Ordering::Equal
331    });
332    Ok(f)
333}
334
335fn compare_union(
336    left: &dyn Array,
337    right: &dyn Array,
338    opts: SortOptions,
339) -> Result<DynComparator, ArrowError> {
340    let left = left.as_union();
341    let right = right.as_union();
342
343    let (left_fields, left_mode) = match left.data_type() {
344        DataType::Union(fields, mode) => (fields, mode),
345        _ => unreachable!(),
346    };
347    let (right_fields, right_mode) = match right.data_type() {
348        DataType::Union(fields, mode) => (fields, mode),
349        _ => unreachable!(),
350    };
351
352    if left_fields != right_fields {
353        return Err(ArrowError::InvalidArgumentError(format!(
354            "Cannot compare UnionArrays with different fields: left={:?}, right={:?}",
355            left_fields, right_fields
356        )));
357    }
358
359    if left_mode != right_mode {
360        return Err(ArrowError::InvalidArgumentError(format!(
361            "Cannot compare UnionArrays with different modes: left={:?}, right={:?}",
362            left_mode, right_mode
363        )));
364    }
365
366    let c_opts = child_opts(opts);
367
368    let mut field_comparators = HashMap::with_capacity(left_fields.len());
369
370    for (type_id, _field) in left_fields.iter() {
371        let left_child = left.child(type_id);
372        let right_child = right.child(type_id);
373        let cmp = make_comparator(left_child.as_ref(), right_child.as_ref(), c_opts)?;
374
375        field_comparators.insert(type_id, cmp);
376    }
377
378    let left_type_ids = left.type_ids().clone();
379    let right_type_ids = right.type_ids().clone();
380
381    let left_offsets = left.offsets().cloned();
382    let right_offsets = right.offsets().cloned();
383
384    let f = compare(left, right, opts, move |i, j| {
385        let left_type_id = left_type_ids[i];
386        let right_type_id = right_type_ids[j];
387
388        // first, compare by type_id
389        match left_type_id.cmp(&right_type_id) {
390            Ordering::Equal => {
391                // second, compare by values
392                let left_offset = left_offsets.as_ref().map(|o| o[i] as usize).unwrap_or(i);
393                let right_offset = right_offsets.as_ref().map(|o| o[j] as usize).unwrap_or(j);
394
395                let cmp = field_comparators
396                    .get(&left_type_id)
397                    .expect("type id not found in field_comparators");
398
399                cmp(left_offset, right_offset)
400            }
401            other => other,
402        }
403    });
404    Ok(f)
405}
406
407/// Returns a comparison function that compares two values at two different positions
408/// between the two arrays.
409///
410/// For comparing arrays element-wise, see also the vectorised kernels in [`crate::cmp`].
411///
412/// If `nulls_first` is true `NULL` values will be considered less than any non-null value,
413/// otherwise they will be considered greater.
414///
415/// # Basic Usage
416///
417/// ```
418/// # use std::cmp::Ordering;
419/// # use arrow_array::Int32Array;
420/// # use arrow_ord::ord::make_comparator;
421/// # use arrow_schema::SortOptions;
422/// #
423/// let array1 = Int32Array::from(vec![1, 2]);
424/// let array2 = Int32Array::from(vec![3, 4]);
425///
426/// let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
427/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
428/// assert_eq!(cmp(0, 1), Ordering::Less);
429///
430/// let array1 = Int32Array::from(vec![Some(1), None]);
431/// let array2 = Int32Array::from(vec![None, Some(2)]);
432/// let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
433///
434/// assert_eq!(cmp(0, 1), Ordering::Less); // Some(1) vs Some(2)
435/// assert_eq!(cmp(1, 1), Ordering::Less); // None vs Some(2)
436/// assert_eq!(cmp(1, 0), Ordering::Equal); // None vs None
437/// assert_eq!(cmp(0, 0), Ordering::Greater); // Some(1) vs None
438/// ```
439///
440/// # Postgres-compatible Nested Comparison
441///
442/// Whilst SQL prescribes ternary logic for nulls, that is comparing a value against a NULL yields
443/// a NULL, many systems, including postgres, instead apply a total ordering to comparison of
444/// nested nulls. That is nulls within nested types are either greater than any value (postgres),
445/// or less than any value (Spark).
446///
447/// In particular
448///
449/// ```ignore
450/// { a: 1, b: null } == { a: 1, b: null } => true
451/// { a: 1, b: null } == { a: 1, b: 1 } => false
452/// { a: 1, b: null } == null => null
453/// null == null => null
454/// ```
455///
456/// This could be implemented as below
457///
458/// ```
459/// # use arrow_array::{Array, BooleanArray};
460/// # use arrow_buffer::NullBuffer;
461/// # use arrow_ord::cmp;
462/// # use arrow_ord::ord::make_comparator;
463/// # use arrow_schema::{ArrowError, SortOptions};
464/// fn eq(a: &dyn Array, b: &dyn Array) -> Result<BooleanArray, ArrowError> {
465///     if !a.data_type().is_nested() {
466///         return cmp::eq(&a, &b); // Use faster vectorised kernel
467///     }
468///
469///     let cmp = make_comparator(a, b, SortOptions::default())?;
470///     let len = a.len().min(b.len());
471///     let values = (0..len).map(|i| cmp(i, i).is_eq()).collect();
472///     let nulls = NullBuffer::union(a.nulls(), b.nulls());
473///     Ok(BooleanArray::new(values, nulls))
474/// }
475/// ````
476pub fn make_comparator(
477    left: &dyn Array,
478    right: &dyn Array,
479    opts: SortOptions,
480) -> Result<DynComparator, ArrowError> {
481    use arrow_schema::DataType::*;
482
483    macro_rules! primitive_helper {
484        ($t:ty, $left:expr, $right:expr, $nulls_first:expr) => {
485            Ok(compare_primitive::<$t>($left, $right, $nulls_first))
486        };
487    }
488    downcast_primitive! {
489        left.data_type(), right.data_type() => (primitive_helper, left, right, opts),
490        (Boolean, Boolean) => Ok(compare_boolean(left, right, opts)),
491        (Utf8, Utf8) => Ok(compare_bytes::<Utf8Type>(left, right, opts)),
492        (LargeUtf8, LargeUtf8) => Ok(compare_bytes::<LargeUtf8Type>(left, right, opts)),
493        (Utf8View, Utf8View) => Ok(compare_byte_view::<StringViewType>(left, right, opts)),
494        (Binary, Binary) => Ok(compare_bytes::<BinaryType>(left, right, opts)),
495        (LargeBinary, LargeBinary) => Ok(compare_bytes::<LargeBinaryType>(left, right, opts)),
496        (BinaryView, BinaryView) => Ok(compare_byte_view::<BinaryViewType>(left, right, opts)),
497        (FixedSizeBinary(_), FixedSizeBinary(_)) => {
498            let left = left.as_fixed_size_binary();
499            let right = right.as_fixed_size_binary();
500
501            let l = left.clone();
502            let r = right.clone();
503            Ok(compare(left, right, opts, move |i, j| {
504                l.value(i).cmp(r.value(j))
505            }))
506        },
507        (List(_), List(_)) => compare_list::<i32>(left, right, opts),
508        (LargeList(_), LargeList(_)) => compare_list::<i64>(left, right, opts),
509        (ListView(_), ListView(_)) => compare_list_view::<i32>(left, right, opts),
510        (LargeListView(_), LargeListView(_)) => compare_list_view::<i64>(left, right, opts),
511        (FixedSizeList(_, _), FixedSizeList(_, _)) => compare_fixed_list(left, right, opts),
512        (Struct(_), Struct(_)) => compare_struct(left, right, opts),
513        (Dictionary(l_key, _), Dictionary(r_key, _)) => {
514             macro_rules! dict_helper {
515                ($t:ty, $left:expr, $right:expr, $opts: expr) => {
516                     compare_dict::<$t>($left, $right, $opts)
517                 };
518             }
519            downcast_integer! {
520                 l_key.as_ref(), r_key.as_ref() => (dict_helper, left, right, opts),
521                 _ => unreachable!()
522             }
523        },
524        (Map(_, _), Map(_, _)) => compare_map(left, right, opts),
525        (Null, Null) => Ok(Box::new(|_, _| Ordering::Equal)),
526        (Union(_, _), Union(_, _)) => compare_union(left, right, opts),
527        (lhs, rhs) => Err(ArrowError::InvalidArgumentError(match lhs == rhs {
528            true => format!("The data type type {lhs:?} has no natural order"),
529            false => "Can't compare arrays of different types".to_string(),
530        }))
531    }
532}
533
534#[cfg(test)]
535mod tests {
536    use super::*;
537    use arrow_array::builder::{Int32Builder, ListBuilder, MapBuilder, StringBuilder};
538    use arrow_buffer::{IntervalDayTime, NullBuffer, OffsetBuffer, ScalarBuffer, i256};
539    use arrow_schema::{DataType, Field, Fields, UnionFields};
540    use half::f16;
541    use std::sync::Arc;
542
543    #[test]
544    fn test_fixed_size_binary() {
545        let items = vec![vec![1u8], vec![2u8]];
546        let array = FixedSizeBinaryArray::try_from_iter(items.into_iter()).unwrap();
547
548        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
549
550        assert_eq!(Ordering::Less, cmp(0, 1));
551    }
552
553    #[test]
554    fn test_fixed_size_binary_fixed_size_binary() {
555        let items = vec![vec![1u8]];
556        let array1 = FixedSizeBinaryArray::try_from_iter(items.into_iter()).unwrap();
557        let items = vec![vec![2u8]];
558        let array2 = FixedSizeBinaryArray::try_from_iter(items.into_iter()).unwrap();
559
560        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
561
562        assert_eq!(Ordering::Less, cmp(0, 0));
563    }
564
565    #[test]
566    fn test_i32() {
567        let array = Int32Array::from(vec![1, 2]);
568
569        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
570
571        assert_eq!(Ordering::Less, (cmp)(0, 1));
572    }
573
574    #[test]
575    fn test_i32_i32() {
576        let array1 = Int32Array::from(vec![1]);
577        let array2 = Int32Array::from(vec![2]);
578
579        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
580
581        assert_eq!(Ordering::Less, cmp(0, 0));
582    }
583
584    #[test]
585    fn test_f16() {
586        let array = Float16Array::from(vec![f16::from_f32(1.0), f16::from_f32(2.0)]);
587
588        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
589
590        assert_eq!(Ordering::Less, cmp(0, 1));
591    }
592
593    #[test]
594    fn test_f64() {
595        let array = Float64Array::from(vec![1.0, 2.0]);
596
597        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
598
599        assert_eq!(Ordering::Less, cmp(0, 1));
600    }
601
602    #[test]
603    fn test_f64_nan() {
604        let array = Float64Array::from(vec![1.0, f64::NAN]);
605
606        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
607
608        assert_eq!(Ordering::Less, cmp(0, 1));
609        assert_eq!(Ordering::Equal, cmp(1, 1));
610    }
611
612    #[test]
613    fn test_f64_zeros() {
614        let array = Float64Array::from(vec![-0.0, 0.0]);
615
616        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
617
618        assert_eq!(Ordering::Less, cmp(0, 1));
619        assert_eq!(Ordering::Greater, cmp(1, 0));
620    }
621
622    #[test]
623    fn test_interval_day_time() {
624        let array = IntervalDayTimeArray::from(vec![
625            // 0 days, 1 second
626            IntervalDayTimeType::make_value(0, 1000),
627            // 1 day, 2 milliseconds
628            IntervalDayTimeType::make_value(1, 2),
629            // 90M milliseconds (which is more than is in 1 day)
630            IntervalDayTimeType::make_value(0, 90_000_000),
631        ]);
632
633        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
634
635        assert_eq!(Ordering::Less, cmp(0, 1));
636        assert_eq!(Ordering::Greater, cmp(1, 0));
637
638        // somewhat confusingly, while 90M milliseconds is more than 1 day,
639        // it will compare less as the comparison is done on the underlying
640        // values not field by field
641        assert_eq!(Ordering::Greater, cmp(1, 2));
642        assert_eq!(Ordering::Less, cmp(2, 1));
643    }
644
645    #[test]
646    fn test_interval_year_month() {
647        let array = IntervalYearMonthArray::from(vec![
648            // 1 year, 0 months
649            IntervalYearMonthType::make_value(1, 0),
650            // 0 years, 13 months
651            IntervalYearMonthType::make_value(0, 13),
652            // 1 year, 1 month
653            IntervalYearMonthType::make_value(1, 1),
654        ]);
655
656        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
657
658        assert_eq!(Ordering::Less, cmp(0, 1));
659        assert_eq!(Ordering::Greater, cmp(1, 0));
660
661        // the underlying representation is months, so both quantities are the same
662        assert_eq!(Ordering::Equal, cmp(1, 2));
663        assert_eq!(Ordering::Equal, cmp(2, 1));
664    }
665
666    #[test]
667    fn test_interval_month_day_nano() {
668        let array = IntervalMonthDayNanoArray::from(vec![
669            // 100 days
670            IntervalMonthDayNanoType::make_value(0, 100, 0),
671            // 1 month
672            IntervalMonthDayNanoType::make_value(1, 0, 0),
673            // 100 day, 1 nanoseconds
674            IntervalMonthDayNanoType::make_value(0, 100, 2),
675        ]);
676
677        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
678
679        assert_eq!(Ordering::Less, cmp(0, 1));
680        assert_eq!(Ordering::Greater, cmp(1, 0));
681
682        // somewhat confusingly, while 100 days is more than 1 month in all cases
683        // it will compare less as the comparison is done on the underlying
684        // values not field by field
685        assert_eq!(Ordering::Greater, cmp(1, 2));
686        assert_eq!(Ordering::Less, cmp(2, 1));
687    }
688
689    #[test]
690    fn test_decimali32() {
691        let array = vec![Some(5_i32), Some(2_i32), Some(3_i32)]
692            .into_iter()
693            .collect::<Decimal32Array>()
694            .with_precision_and_scale(8, 6)
695            .unwrap();
696
697        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
698        assert_eq!(Ordering::Less, cmp(1, 0));
699        assert_eq!(Ordering::Greater, cmp(0, 2));
700    }
701
702    #[test]
703    fn test_decimali64() {
704        let array = vec![Some(5_i64), Some(2_i64), Some(3_i64)]
705            .into_iter()
706            .collect::<Decimal64Array>()
707            .with_precision_and_scale(16, 6)
708            .unwrap();
709
710        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
711        assert_eq!(Ordering::Less, cmp(1, 0));
712        assert_eq!(Ordering::Greater, cmp(0, 2));
713    }
714
715    #[test]
716    fn test_decimali128() {
717        let array = vec![Some(5_i128), Some(2_i128), Some(3_i128)]
718            .into_iter()
719            .collect::<Decimal128Array>()
720            .with_precision_and_scale(23, 6)
721            .unwrap();
722
723        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
724        assert_eq!(Ordering::Less, cmp(1, 0));
725        assert_eq!(Ordering::Greater, cmp(0, 2));
726    }
727
728    #[test]
729    fn test_decimali256() {
730        let array = vec![
731            Some(i256::from_i128(5_i128)),
732            Some(i256::from_i128(2_i128)),
733            Some(i256::from_i128(3_i128)),
734        ]
735        .into_iter()
736        .collect::<Decimal256Array>()
737        .with_precision_and_scale(53, 6)
738        .unwrap();
739
740        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
741        assert_eq!(Ordering::Less, cmp(1, 0));
742        assert_eq!(Ordering::Greater, cmp(0, 2));
743    }
744
745    #[test]
746    fn test_dict() {
747        let data = vec!["a", "b", "c", "a", "a", "c", "c"];
748        let array = data.into_iter().collect::<DictionaryArray<Int16Type>>();
749
750        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
751
752        assert_eq!(Ordering::Less, cmp(0, 1));
753        assert_eq!(Ordering::Equal, cmp(3, 4));
754        assert_eq!(Ordering::Greater, cmp(2, 3));
755    }
756
757    #[test]
758    fn test_multiple_dict() {
759        let d1 = vec!["a", "b", "c", "d"];
760        let a1 = d1.into_iter().collect::<DictionaryArray<Int16Type>>();
761        let d2 = vec!["e", "f", "g", "a"];
762        let a2 = d2.into_iter().collect::<DictionaryArray<Int16Type>>();
763
764        let cmp = make_comparator(&a1, &a2, SortOptions::default()).unwrap();
765
766        assert_eq!(Ordering::Less, cmp(0, 0));
767        assert_eq!(Ordering::Equal, cmp(0, 3));
768        assert_eq!(Ordering::Greater, cmp(1, 3));
769    }
770
771    #[test]
772    fn test_primitive_dict() {
773        let values = Int32Array::from(vec![1_i32, 0, 2, 5]);
774        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
775        let array1 = DictionaryArray::new(keys, Arc::new(values));
776
777        let values = Int32Array::from(vec![2_i32, 3, 4, 5]);
778        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
779        let array2 = DictionaryArray::new(keys, Arc::new(values));
780
781        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
782
783        assert_eq!(Ordering::Less, cmp(0, 0));
784        assert_eq!(Ordering::Less, cmp(0, 3));
785        assert_eq!(Ordering::Equal, cmp(3, 3));
786        assert_eq!(Ordering::Greater, cmp(3, 1));
787        assert_eq!(Ordering::Greater, cmp(3, 2));
788    }
789
790    #[test]
791    fn test_float_dict() {
792        let values = Float32Array::from(vec![1.0, 0.5, 2.1, 5.5]);
793        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
794        let array1 = DictionaryArray::try_new(keys, Arc::new(values)).unwrap();
795
796        let values = Float32Array::from(vec![1.2, 3.2, 4.0, 5.5]);
797        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
798        let array2 = DictionaryArray::new(keys, Arc::new(values));
799
800        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
801
802        assert_eq!(Ordering::Less, cmp(0, 0));
803        assert_eq!(Ordering::Less, cmp(0, 3));
804        assert_eq!(Ordering::Equal, cmp(3, 3));
805        assert_eq!(Ordering::Greater, cmp(3, 1));
806        assert_eq!(Ordering::Greater, cmp(3, 2));
807    }
808
809    #[test]
810    fn test_timestamp_dict() {
811        let values = TimestampSecondArray::from(vec![1, 0, 2, 5]);
812        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
813        let array1 = DictionaryArray::new(keys, Arc::new(values));
814
815        let values = TimestampSecondArray::from(vec![2, 3, 4, 5]);
816        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
817        let array2 = DictionaryArray::new(keys, Arc::new(values));
818
819        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
820
821        assert_eq!(Ordering::Less, cmp(0, 0));
822        assert_eq!(Ordering::Less, cmp(0, 3));
823        assert_eq!(Ordering::Equal, cmp(3, 3));
824        assert_eq!(Ordering::Greater, cmp(3, 1));
825        assert_eq!(Ordering::Greater, cmp(3, 2));
826    }
827
828    #[test]
829    fn test_interval_dict() {
830        let v1 = IntervalDayTime::new(0, 1);
831        let v2 = IntervalDayTime::new(0, 2);
832        let v3 = IntervalDayTime::new(12, 2);
833
834        let values = IntervalDayTimeArray::from(vec![Some(v1), Some(v2), None, Some(v3)]);
835        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
836        let array1 = DictionaryArray::new(keys, Arc::new(values));
837
838        let values = IntervalDayTimeArray::from(vec![Some(v3), Some(v2), None, Some(v1)]);
839        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
840        let array2 = DictionaryArray::new(keys, Arc::new(values));
841
842        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
843
844        assert_eq!(Ordering::Less, cmp(0, 0)); // v1 vs v3
845        assert_eq!(Ordering::Equal, cmp(0, 3)); // v1 vs v1
846        assert_eq!(Ordering::Greater, cmp(3, 3)); // v3 vs v1
847        assert_eq!(Ordering::Greater, cmp(3, 1)); // v3 vs v2
848        assert_eq!(Ordering::Greater, cmp(3, 2)); // v3 vs v2
849    }
850
851    #[test]
852    fn test_duration_dict() {
853        let values = DurationSecondArray::from(vec![1, 0, 2, 5]);
854        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
855        let array1 = DictionaryArray::new(keys, Arc::new(values));
856
857        let values = DurationSecondArray::from(vec![2, 3, 4, 5]);
858        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
859        let array2 = DictionaryArray::new(keys, Arc::new(values));
860
861        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
862
863        assert_eq!(Ordering::Less, cmp(0, 0));
864        assert_eq!(Ordering::Less, cmp(0, 3));
865        assert_eq!(Ordering::Equal, cmp(3, 3));
866        assert_eq!(Ordering::Greater, cmp(3, 1));
867        assert_eq!(Ordering::Greater, cmp(3, 2));
868    }
869
870    #[test]
871    fn test_decimal_dict() {
872        let values = Decimal128Array::from(vec![1, 0, 2, 5]);
873        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
874        let array1 = DictionaryArray::new(keys, Arc::new(values));
875
876        let values = Decimal128Array::from(vec![2, 3, 4, 5]);
877        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
878        let array2 = DictionaryArray::new(keys, Arc::new(values));
879
880        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
881
882        assert_eq!(Ordering::Less, cmp(0, 0));
883        assert_eq!(Ordering::Less, cmp(0, 3));
884        assert_eq!(Ordering::Equal, cmp(3, 3));
885        assert_eq!(Ordering::Greater, cmp(3, 1));
886        assert_eq!(Ordering::Greater, cmp(3, 2));
887    }
888
889    #[test]
890    fn test_decimal256_dict() {
891        let values = Decimal256Array::from(vec![
892            i256::from_i128(1),
893            i256::from_i128(0),
894            i256::from_i128(2),
895            i256::from_i128(5),
896        ]);
897        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
898        let array1 = DictionaryArray::new(keys, Arc::new(values));
899
900        let values = Decimal256Array::from(vec![
901            i256::from_i128(2),
902            i256::from_i128(3),
903            i256::from_i128(4),
904            i256::from_i128(5),
905        ]);
906        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
907        let array2 = DictionaryArray::new(keys, Arc::new(values));
908
909        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
910
911        assert_eq!(Ordering::Less, cmp(0, 0));
912        assert_eq!(Ordering::Less, cmp(0, 3));
913        assert_eq!(Ordering::Equal, cmp(3, 3));
914        assert_eq!(Ordering::Greater, cmp(3, 1));
915        assert_eq!(Ordering::Greater, cmp(3, 2));
916    }
917
918    fn test_bytes_impl<T: ByteArrayType>() {
919        let offsets = OffsetBuffer::from_lengths([3, 3, 1]);
920        let a = GenericByteArray::<T>::new(offsets, b"abcdefa".into(), None);
921        let cmp = make_comparator(&a, &a, SortOptions::default()).unwrap();
922
923        assert_eq!(Ordering::Less, cmp(0, 1));
924        assert_eq!(Ordering::Greater, cmp(0, 2));
925        assert_eq!(Ordering::Equal, cmp(1, 1));
926    }
927
928    #[test]
929    fn test_bytes() {
930        test_bytes_impl::<Utf8Type>();
931        test_bytes_impl::<LargeUtf8Type>();
932        test_bytes_impl::<BinaryType>();
933        test_bytes_impl::<LargeBinaryType>();
934    }
935
936    fn assert_cmp_cases<A: Array>(
937        array1: &A,
938        array2: &A,
939        opts: SortOptions,
940        cases: &[(usize, usize, Ordering)],
941    ) {
942        let cmp = make_comparator(array1, array2, opts).unwrap();
943        for (left, right, expected) in cases {
944            assert_eq!(cmp(*left, *right), *expected);
945        }
946    }
947
948    #[test]
949    fn test_lists() {
950        let mut a = ListBuilder::new(ListBuilder::new(Int32Builder::new()));
951        a.extend([
952            Some(vec![Some(vec![Some(1), Some(2), None]), Some(vec![None])]),
953            Some(vec![
954                Some(vec![Some(1), Some(2), Some(3)]),
955                Some(vec![Some(1)]),
956            ]),
957            Some(vec![]),
958        ]);
959        let a = a.finish();
960        let mut b = ListBuilder::new(ListBuilder::new(Int32Builder::new()));
961        b.extend([
962            Some(vec![Some(vec![Some(1), Some(2), None]), Some(vec![None])]),
963            Some(vec![
964                Some(vec![Some(1), Some(2), None]),
965                Some(vec![Some(1)]),
966            ]),
967            Some(vec![
968                Some(vec![Some(1), Some(2), Some(3), Some(4)]),
969                Some(vec![Some(1)]),
970            ]),
971            None,
972        ]);
973        let b = b.finish();
974
975        // Ascending with nulls first.
976        assert_cmp_cases(
977            &a,
978            &b,
979            SortOptions {
980                descending: false,
981                nulls_first: true,
982            },
983            &[
984                (0, 0, Ordering::Equal),
985                (0, 1, Ordering::Less),
986                (0, 2, Ordering::Less),
987                (1, 2, Ordering::Less),
988                (1, 3, Ordering::Greater),
989                (2, 0, Ordering::Less),
990            ],
991        );
992
993        // Descending with nulls first.
994        assert_cmp_cases(
995            &a,
996            &b,
997            SortOptions {
998                descending: true,
999                nulls_first: true,
1000            },
1001            &[
1002                (0, 0, Ordering::Equal),
1003                (0, 1, Ordering::Less),
1004                (0, 2, Ordering::Less),
1005                (1, 2, Ordering::Greater),
1006                (1, 3, Ordering::Greater),
1007                (2, 0, Ordering::Greater),
1008            ],
1009        );
1010
1011        // Descending with nulls last.
1012        assert_cmp_cases(
1013            &a,
1014            &b,
1015            SortOptions {
1016                descending: true,
1017                nulls_first: false,
1018            },
1019            &[
1020                (0, 0, Ordering::Equal),
1021                (0, 1, Ordering::Greater),
1022                (0, 2, Ordering::Greater),
1023                (1, 2, Ordering::Greater),
1024                (1, 3, Ordering::Less),
1025                (2, 0, Ordering::Greater),
1026            ],
1027        );
1028
1029        // Ascending with nulls last.
1030        assert_cmp_cases(
1031            &a,
1032            &b,
1033            SortOptions {
1034                descending: false,
1035                nulls_first: false,
1036            },
1037            &[
1038                (0, 0, Ordering::Equal),
1039                (0, 1, Ordering::Greater),
1040                (0, 2, Ordering::Greater),
1041                (1, 2, Ordering::Less),
1042                (1, 3, Ordering::Less),
1043                (2, 0, Ordering::Less),
1044            ],
1045        );
1046    }
1047
1048    fn list_view_array<O: OffsetSizeTrait>(
1049        values: Vec<i32>,
1050        offsets: &[usize],
1051        sizes: &[usize],
1052        valid: Option<&[bool]>,
1053    ) -> GenericListViewArray<O> {
1054        let offsets = offsets
1055            .iter()
1056            .map(|v| O::from_usize(*v).unwrap())
1057            .collect::<ScalarBuffer<O>>();
1058        let sizes = sizes
1059            .iter()
1060            .map(|v| O::from_usize(*v).unwrap())
1061            .collect::<ScalarBuffer<O>>();
1062        let field = Arc::new(Field::new_list_field(DataType::Int32, true));
1063        let values = Int32Array::from(values);
1064        let nulls = valid.map(NullBuffer::from);
1065        GenericListViewArray::new(field, offsets, sizes, Arc::new(values), nulls)
1066    }
1067
1068    fn test_list_view_comparisons<O: OffsetSizeTrait>() {
1069        let array = list_view_array::<O>(
1070            vec![1, 2, 3, 4, 5],
1071            &[0, 2, 1, 0, 3],
1072            &[2, 2, 2, 0, 2],
1073            Some(&[true, true, true, true, false]),
1074        );
1075
1076        // Ascending with nulls first (non-monotonic offsets and empty list).
1077        assert_cmp_cases(
1078            &array,
1079            &array,
1080            SortOptions {
1081                descending: false,
1082                nulls_first: true,
1083            },
1084            &[
1085                (0, 2, Ordering::Less),    // [1,2] < [2,3]
1086                (1, 2, Ordering::Greater), // [3,4] > [2,3]
1087                (3, 0, Ordering::Less),    // [] < [1,2]
1088                (4, 0, Ordering::Less),    // null < [1,2]
1089            ],
1090        );
1091
1092        // Ascending with nulls last.
1093        assert_cmp_cases(
1094            &array,
1095            &array,
1096            SortOptions {
1097                descending: false,
1098                nulls_first: false,
1099            },
1100            &[
1101                (0, 2, Ordering::Less),
1102                (1, 2, Ordering::Greater),
1103                (3, 0, Ordering::Less),
1104                (4, 0, Ordering::Greater), // null last
1105            ],
1106        );
1107
1108        // Descending with nulls first.
1109        assert_cmp_cases(
1110            &array,
1111            &array,
1112            SortOptions {
1113                descending: true,
1114                nulls_first: true,
1115            },
1116            &[
1117                (0, 2, Ordering::Greater),
1118                (1, 2, Ordering::Less),
1119                (3, 0, Ordering::Greater),
1120                (4, 0, Ordering::Less),
1121            ],
1122        );
1123
1124        // Descending with nulls last.
1125        assert_cmp_cases(
1126            &array,
1127            &array,
1128            SortOptions {
1129                descending: true,
1130                nulls_first: false,
1131            },
1132            &[
1133                (0, 2, Ordering::Greater),
1134                (1, 2, Ordering::Less),
1135                (3, 0, Ordering::Greater),
1136                (4, 0, Ordering::Greater),
1137            ],
1138        );
1139    }
1140
1141    #[test]
1142    fn test_list_view() {
1143        test_list_view_comparisons::<i32>();
1144    }
1145
1146    #[test]
1147    fn test_large_list_view() {
1148        test_list_view_comparisons::<i64>();
1149    }
1150
1151    #[test]
1152    fn test_struct() {
1153        let fields = Fields::from(vec![
1154            Field::new("a", DataType::Int32, true),
1155            Field::new_list("b", Field::new_list_field(DataType::Int32, true), true),
1156        ]);
1157
1158        let a = Int32Array::from(vec![Some(1), Some(2), None, None]);
1159        let mut b = ListBuilder::new(Int32Builder::new());
1160        b.extend([Some(vec![Some(1), Some(2)]), Some(vec![None]), None, None]);
1161        let b = b.finish();
1162
1163        let nulls = Some(NullBuffer::from_iter([true, true, true, false]));
1164        let values = vec![Arc::new(a) as _, Arc::new(b) as _];
1165        let s1 = StructArray::new(fields.clone(), values, nulls);
1166
1167        let a = Int32Array::from(vec![None, Some(2), None]);
1168        let mut b = ListBuilder::new(Int32Builder::new());
1169        b.extend([None, None, Some(vec![])]);
1170        let b = b.finish();
1171
1172        let values = vec![Arc::new(a) as _, Arc::new(b) as _];
1173        let s2 = StructArray::new(fields.clone(), values, None);
1174
1175        let opts = SortOptions {
1176            descending: false,
1177            nulls_first: true,
1178        };
1179        let cmp = make_comparator(&s1, &s2, opts).unwrap();
1180        assert_eq!(cmp(0, 1), Ordering::Less); // (1, [1, 2]) cmp (2, None)
1181        assert_eq!(cmp(0, 0), Ordering::Greater); // (1, [1, 2]) cmp (None, None)
1182        assert_eq!(cmp(1, 1), Ordering::Greater); // (2, [None]) cmp (2, None)
1183        assert_eq!(cmp(2, 2), Ordering::Less); // (None, None) cmp (None, [])
1184        assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, [])
1185        assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None)
1186        assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, None)
1187
1188        let opts = SortOptions {
1189            descending: true,
1190            nulls_first: true,
1191        };
1192        let cmp = make_comparator(&s1, &s2, opts).unwrap();
1193        assert_eq!(cmp(0, 1), Ordering::Greater); // (1, [1, 2]) cmp (2, None)
1194        assert_eq!(cmp(0, 0), Ordering::Greater); // (1, [1, 2]) cmp (None, None)
1195        assert_eq!(cmp(1, 1), Ordering::Greater); // (2, [None]) cmp (2, None)
1196        assert_eq!(cmp(2, 2), Ordering::Less); // (None, None) cmp (None, [])
1197        assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, [])
1198        assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None)
1199        assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, None)
1200
1201        let opts = SortOptions {
1202            descending: true,
1203            nulls_first: false,
1204        };
1205        let cmp = make_comparator(&s1, &s2, opts).unwrap();
1206        assert_eq!(cmp(0, 1), Ordering::Greater); // (1, [1, 2]) cmp (2, None)
1207        assert_eq!(cmp(0, 0), Ordering::Less); // (1, [1, 2]) cmp (None, None)
1208        assert_eq!(cmp(1, 1), Ordering::Less); // (2, [None]) cmp (2, None)
1209        assert_eq!(cmp(2, 2), Ordering::Greater); // (None, None) cmp (None, [])
1210        assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, [])
1211        assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None)
1212        assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, None)
1213
1214        let opts = SortOptions {
1215            descending: false,
1216            nulls_first: false,
1217        };
1218        let cmp = make_comparator(&s1, &s2, opts).unwrap();
1219        assert_eq!(cmp(0, 1), Ordering::Less); // (1, [1, 2]) cmp (2, None)
1220        assert_eq!(cmp(0, 0), Ordering::Less); // (1, [1, 2]) cmp (None, None)
1221        assert_eq!(cmp(1, 1), Ordering::Less); // (2, [None]) cmp (2, None)
1222        assert_eq!(cmp(2, 2), Ordering::Greater); // (None, None) cmp (None, [])
1223        assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, [])
1224        assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None)
1225        assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, None)
1226    }
1227
1228    #[test]
1229    fn test_map() {
1230        // Create first map array demonstrating key priority over values:
1231        // [{"a": 100, "b": 1}, {"b": 999, "c": 1}, {}, {"x": 1}]
1232        let string_builder = StringBuilder::new();
1233        let int_builder = Int32Builder::new();
1234        let mut map1_builder = MapBuilder::new(None, string_builder, int_builder);
1235
1236        // {"a": 100, "b": 1} - high value for "a", low value for "b"
1237        map1_builder.keys().append_value("a");
1238        map1_builder.values().append_value(100);
1239        map1_builder.keys().append_value("b");
1240        map1_builder.values().append_value(1);
1241        map1_builder.append(true).unwrap();
1242
1243        // {"b": 999, "c": 1} - very high value for "b", low value for "c"
1244        map1_builder.keys().append_value("b");
1245        map1_builder.values().append_value(999);
1246        map1_builder.keys().append_value("c");
1247        map1_builder.values().append_value(1);
1248        map1_builder.append(true).unwrap();
1249
1250        // {}
1251        map1_builder.append(true).unwrap();
1252
1253        // {"x": 1}
1254        map1_builder.keys().append_value("x");
1255        map1_builder.values().append_value(1);
1256        map1_builder.append(true).unwrap();
1257
1258        let map1 = map1_builder.finish();
1259
1260        // Create second map array:
1261        // [{"a": 1, "c": 999}, {"b": 1, "d": 999}, {"a": 1}, None]
1262        let string_builder = StringBuilder::new();
1263        let int_builder = Int32Builder::new();
1264        let mut map2_builder = MapBuilder::new(None, string_builder, int_builder);
1265
1266        // {"a": 1, "c": 999} - low value for "a", high value for "c"
1267        map2_builder.keys().append_value("a");
1268        map2_builder.values().append_value(1);
1269        map2_builder.keys().append_value("c");
1270        map2_builder.values().append_value(999);
1271        map2_builder.append(true).unwrap();
1272
1273        // {"b": 1, "d": 999} - low value for "b", high value for "d"
1274        map2_builder.keys().append_value("b");
1275        map2_builder.values().append_value(1);
1276        map2_builder.keys().append_value("d");
1277        map2_builder.values().append_value(999);
1278        map2_builder.append(true).unwrap();
1279
1280        // {"a": 1}
1281        map2_builder.keys().append_value("a");
1282        map2_builder.values().append_value(1);
1283        map2_builder.append(true).unwrap();
1284
1285        // None
1286        map2_builder.append(false).unwrap();
1287
1288        let map2 = map2_builder.finish();
1289
1290        let opts = SortOptions {
1291            descending: false,
1292            nulls_first: true,
1293        };
1294        let cmp = make_comparator(&map1, &map2, opts).unwrap();
1295
1296        // Test that keys have priority over values:
1297        // {"a": 100, "b": 1} vs {"a": 1, "c": 999}
1298        // First entries match (a:100 vs a:1), but 100 > 1, so Greater
1299        assert_eq!(cmp(0, 0), Ordering::Greater);
1300
1301        // {"b": 999, "c": 1} vs {"b": 1, "d": 999}
1302        // First entries match (b:999 vs b:1), but 999 > 1, so Greater
1303        assert_eq!(cmp(1, 1), Ordering::Greater);
1304
1305        // Key comparison: "a" < "b", so {"a": 100, "b": 1} < {"b": 999, "c": 1}
1306        assert_eq!(cmp(0, 1), Ordering::Less);
1307
1308        // Empty map vs non-empty
1309        assert_eq!(cmp(2, 2), Ordering::Less); // {} < {"a": 1}
1310
1311        // Non-null vs null
1312        assert_eq!(cmp(3, 3), Ordering::Greater); // {"x": 1} > None
1313
1314        // Key priority test: "x" > "a", regardless of values
1315        assert_eq!(cmp(3, 0), Ordering::Greater); // {"x": 1} > {"a": 1, "c": 999}
1316
1317        // Empty vs non-empty
1318        assert_eq!(cmp(2, 0), Ordering::Less); // {} < {"a": 1, "c": 999}
1319
1320        let opts = SortOptions {
1321            descending: true,
1322            nulls_first: true,
1323        };
1324        let cmp = make_comparator(&map1, &map2, opts).unwrap();
1325
1326        // With descending=true, value comparison is reversed
1327        assert_eq!(cmp(0, 0), Ordering::Less); // {"a": 100, "b": 1} vs {"a": 1, "c": 999} (reversed)
1328        assert_eq!(cmp(1, 1), Ordering::Less); // {"b": 999, "c": 1} vs {"b": 1, "d": 999} (reversed)
1329        assert_eq!(cmp(0, 1), Ordering::Greater); // {"a": 100, "b": 1} vs {"b": 999, "c": 1} (key order reversed)
1330        assert_eq!(cmp(3, 3), Ordering::Greater); // {"x": 1} > None
1331        assert_eq!(cmp(2, 2), Ordering::Greater); // {} > {"a": 1} (reversed)
1332
1333        let opts = SortOptions {
1334            descending: false,
1335            nulls_first: false,
1336        };
1337        let cmp = make_comparator(&map1, &map2, opts).unwrap();
1338
1339        // Same key priority behavior with nulls_first=false
1340        assert_eq!(cmp(0, 0), Ordering::Greater); // {"a": 100, "b": 1} vs {"a": 1, "c": 999}
1341        assert_eq!(cmp(1, 1), Ordering::Greater); // {"b": 999, "c": 1} vs {"b": 1, "d": 999}
1342        assert_eq!(cmp(3, 3), Ordering::Less); // {"x": 1} < None (nulls last)
1343        assert_eq!(cmp(2, 2), Ordering::Less); // {} < {"a": 1}
1344    }
1345
1346    #[test]
1347    fn test_map_vs_list_consistency() {
1348        // Create map arrays and convert them to list arrays to verify comparison consistency
1349        // Map arrays: [{"a": 1, "b": 2}, {"x": 10}, {}, {"c": 3}]
1350        let string_builder = StringBuilder::new();
1351        let int_builder = Int32Builder::new();
1352        let mut map1_builder = MapBuilder::new(None, string_builder, int_builder);
1353
1354        // {"a": 1, "b": 2}
1355        map1_builder.keys().append_value("a");
1356        map1_builder.values().append_value(1);
1357        map1_builder.keys().append_value("b");
1358        map1_builder.values().append_value(2);
1359        map1_builder.append(true).unwrap();
1360
1361        // {"x": 10}
1362        map1_builder.keys().append_value("x");
1363        map1_builder.values().append_value(10);
1364        map1_builder.append(true).unwrap();
1365
1366        // {}
1367        map1_builder.append(true).unwrap();
1368
1369        // {"c": 3}
1370        map1_builder.keys().append_value("c");
1371        map1_builder.values().append_value(3);
1372        map1_builder.append(true).unwrap();
1373
1374        let map1 = map1_builder.finish();
1375
1376        // Second map array: [{"a": 1, "b": 2}, {"y": 20}, {"d": 4}, None]
1377        let string_builder = StringBuilder::new();
1378        let int_builder = Int32Builder::new();
1379        let mut map2_builder = MapBuilder::new(None, string_builder, int_builder);
1380
1381        // {"a": 1, "b": 2}
1382        map2_builder.keys().append_value("a");
1383        map2_builder.values().append_value(1);
1384        map2_builder.keys().append_value("b");
1385        map2_builder.values().append_value(2);
1386        map2_builder.append(true).unwrap();
1387
1388        // {"y": 20}
1389        map2_builder.keys().append_value("y");
1390        map2_builder.values().append_value(20);
1391        map2_builder.append(true).unwrap();
1392
1393        // {"d": 4}
1394        map2_builder.keys().append_value("d");
1395        map2_builder.values().append_value(4);
1396        map2_builder.append(true).unwrap();
1397
1398        // None
1399        map2_builder.append(false).unwrap();
1400
1401        let map2 = map2_builder.finish();
1402
1403        // Convert map arrays to list arrays (Map entries are struct arrays with key-value pairs)
1404        let list1: ListArray = map1.clone().into();
1405        let list2: ListArray = map2.clone().into();
1406
1407        let test_cases = [
1408            SortOptions {
1409                descending: false,
1410                nulls_first: true,
1411            },
1412            SortOptions {
1413                descending: true,
1414                nulls_first: true,
1415            },
1416            SortOptions {
1417                descending: false,
1418                nulls_first: false,
1419            },
1420            SortOptions {
1421                descending: true,
1422                nulls_first: false,
1423            },
1424        ];
1425
1426        for opts in test_cases {
1427            let map_cmp = make_comparator(&map1, &map2, opts).unwrap();
1428            let list_cmp = make_comparator(&list1, &list2, opts).unwrap();
1429
1430            // Test all possible index combinations
1431            for i in 0..map1.len() {
1432                for j in 0..map2.len() {
1433                    let map_result = map_cmp(i, j);
1434                    let list_result = list_cmp(i, j);
1435                    assert_eq!(
1436                        map_result, list_result,
1437                        "Map comparison and List comparison should be equal for indices ({i}, {j}) with opts {opts:?}. Map: {map_result:?}, List: {list_result:?}"
1438                    );
1439                }
1440            }
1441        }
1442    }
1443
1444    #[test]
1445    fn test_dense_union() {
1446        // create a dense union array with Int32 (type_id = 0) and Utf8 (type_id=1)
1447        // the values are: [1, "b", 2, "a", 3]
1448        //  type_ids are: [0,  1,  0,  1,  0]
1449        //   offsets are: [0, 0, 1, 1, 2] from [1, 2, 3] and ["b", "a"]
1450        let int_array = Int32Array::from(vec![1, 2, 3]);
1451        let str_array = StringArray::from(vec!["b", "a"]);
1452
1453        let type_ids = [0, 1, 0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
1454        let offsets = [0, 0, 1, 1, 2].into_iter().collect::<ScalarBuffer<i32>>();
1455
1456        let union_fields = [
1457            (0, Arc::new(Field::new("A", DataType::Int32, false))),
1458            (1, Arc::new(Field::new("B", DataType::Utf8, false))),
1459        ]
1460        .into_iter()
1461        .collect::<UnionFields>();
1462
1463        let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];
1464
1465        let array1 =
1466            UnionArray::try_new(union_fields.clone(), type_ids, Some(offsets), children).unwrap();
1467
1468        // create a second array: [2, "a", 1, "c"]
1469        //          type ids are: [0,  1,  0,  1]
1470        //           offsets are: [0, 0, 1, 1] from [2, 1] and ["a", "c"]
1471        let int_array2 = Int32Array::from(vec![2, 1]);
1472        let str_array2 = StringArray::from(vec!["a", "c"]);
1473        let type_ids2 = [0, 1, 0, 1].into_iter().collect::<ScalarBuffer<i8>>();
1474        let offsets2 = [0, 0, 1, 1].into_iter().collect::<ScalarBuffer<i32>>();
1475
1476        let children2 = vec![Arc::new(int_array2) as ArrayRef, Arc::new(str_array2)];
1477
1478        let array2 =
1479            UnionArray::try_new(union_fields, type_ids2, Some(offsets2), children2).unwrap();
1480
1481        let opts = SortOptions {
1482            descending: false,
1483            nulls_first: true,
1484        };
1485
1486        // comparing
1487        // [1, "b", 2, "a", 3]
1488        // [2, "a", 1, "c"]
1489        let cmp = make_comparator(&array1, &array2, opts).unwrap();
1490
1491        // array1[0] = (type_id=0, value=1)
1492        // array2[0] = (type_id=0, value=2)
1493        assert_eq!(cmp(0, 0), Ordering::Less); // 1 < 2
1494
1495        // array1[0] = (type_id=0, value=1)
1496        // array2[1] = (type_id=1, value="a")
1497        assert_eq!(cmp(0, 1), Ordering::Less); // type_id 0 < 1
1498
1499        // array1[1] = (type_id=1, value="b")
1500        // array2[1] = (type_id=1, value="a")
1501        assert_eq!(cmp(1, 1), Ordering::Greater); // "b" > "a"
1502
1503        // array1[2] = (type_id=0, value=2)
1504        // array2[0] = (type_id=0, value=2)
1505        assert_eq!(cmp(2, 0), Ordering::Equal); // 2 == 2
1506
1507        // array1[3] = (type_id=1, value="a")
1508        // array2[1] = (type_id=1, value="a")
1509        assert_eq!(cmp(3, 1), Ordering::Equal); // "a" == "a"
1510
1511        // array1[1] = (type_id=1, value="b")
1512        // array2[3] = (type_id=1, value="c")
1513        assert_eq!(cmp(1, 3), Ordering::Less); // "b" < "c"
1514
1515        let opts_desc = SortOptions {
1516            descending: true,
1517            nulls_first: true,
1518        };
1519        let cmp_desc = make_comparator(&array1, &array2, opts_desc).unwrap();
1520
1521        assert_eq!(cmp_desc(0, 0), Ordering::Greater); // 1 > 2 (reversed)
1522        assert_eq!(cmp_desc(0, 1), Ordering::Greater); // type_id 0 < 1, reversed to Greater
1523        assert_eq!(cmp_desc(1, 1), Ordering::Less); // "b" < "a" (reversed)
1524    }
1525
1526    #[test]
1527    fn test_sparse_union() {
1528        // create a sparse union array with Int32 (type_id=0) and Utf8 (type_id=1)
1529        // values: [1, "b", 3]
1530        // note, in sparse unions, child arrays have the same length as the union
1531        let int_array = Int32Array::from(vec![Some(1), None, Some(3)]);
1532        let str_array = StringArray::from(vec![None, Some("b"), None]);
1533        let type_ids = [0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
1534
1535        let union_fields = [
1536            (0, Arc::new(Field::new("a", DataType::Int32, false))),
1537            (1, Arc::new(Field::new("b", DataType::Utf8, false))),
1538        ]
1539        .into_iter()
1540        .collect::<UnionFields>();
1541
1542        let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];
1543
1544        let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();
1545
1546        let opts = SortOptions::default();
1547        let cmp = make_comparator(&array, &array, opts).unwrap();
1548
1549        // array[0] = (type_id=0, value=1), array[2] = (type_id=0, value=3)
1550        assert_eq!(cmp(0, 2), Ordering::Less); // 1 < 3
1551        // array[0] = (type_id=0, value=1), array[1] = (type_id=1, value="b")
1552        assert_eq!(cmp(0, 1), Ordering::Less); // type_id 0 < 1
1553    }
1554
1555    #[test]
1556    #[should_panic(expected = "index out of bounds")]
1557    fn test_union_out_of_bounds() {
1558        // create a dense union array with 3 elements
1559        let int_array = Int32Array::from(vec![1, 2]);
1560        let str_array = StringArray::from(vec!["a"]);
1561
1562        let type_ids = [0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
1563        let offsets = [0, 0, 1].into_iter().collect::<ScalarBuffer<i32>>();
1564
1565        let union_fields = [
1566            (0, Arc::new(Field::new("A", DataType::Int32, false))),
1567            (1, Arc::new(Field::new("B", DataType::Utf8, false))),
1568        ]
1569        .into_iter()
1570        .collect::<UnionFields>();
1571
1572        let children = vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)];
1573
1574        let array = UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap();
1575
1576        let opts = SortOptions::default();
1577        let cmp = make_comparator(&array, &array, opts).unwrap();
1578
1579        // oob
1580        cmp(0, 3);
1581    }
1582
1583    #[test]
1584    fn test_union_incompatible_fields() {
1585        // create first union with Int32 and Utf8
1586        let int_array1 = Int32Array::from(vec![1, 2]);
1587        let str_array1 = StringArray::from(vec!["a", "b"]);
1588
1589        let type_ids1 = [0, 1].into_iter().collect::<ScalarBuffer<i8>>();
1590        let offsets1 = [0, 0].into_iter().collect::<ScalarBuffer<i32>>();
1591
1592        let union_fields1 = [
1593            (0, Arc::new(Field::new("A", DataType::Int32, false))),
1594            (1, Arc::new(Field::new("B", DataType::Utf8, false))),
1595        ]
1596        .into_iter()
1597        .collect::<UnionFields>();
1598
1599        let children1 = vec![Arc::new(int_array1) as ArrayRef, Arc::new(str_array1)];
1600
1601        let array1 =
1602            UnionArray::try_new(union_fields1, type_ids1, Some(offsets1), children1).unwrap();
1603
1604        // create second union with Int32 and Float64 (incompatible with first)
1605        let int_array2 = Int32Array::from(vec![3, 4]);
1606        let float_array2 = Float64Array::from(vec![1.0, 2.0]);
1607
1608        let type_ids2 = [0, 1].into_iter().collect::<ScalarBuffer<i8>>();
1609        let offsets2 = [0, 0].into_iter().collect::<ScalarBuffer<i32>>();
1610
1611        let union_fields2 = [
1612            (0, Arc::new(Field::new("A", DataType::Int32, false))),
1613            (1, Arc::new(Field::new("C", DataType::Float64, false))),
1614        ]
1615        .into_iter()
1616        .collect::<UnionFields>();
1617
1618        let children2 = vec![Arc::new(int_array2) as ArrayRef, Arc::new(float_array2)];
1619
1620        let array2 =
1621            UnionArray::try_new(union_fields2, type_ids2, Some(offsets2), children2).unwrap();
1622
1623        let opts = SortOptions::default();
1624
1625        let Result::Err(ArrowError::InvalidArgumentError(out)) =
1626            make_comparator(&array1, &array2, opts)
1627        else {
1628            panic!("expected error when making comparator of incompatible union arrays");
1629        };
1630
1631        assert_eq!(
1632            &out,
1633            "Cannot compare UnionArrays with different fields: left=[(0, Field { name: \"A\", data_type: Int32 }), (1, Field { name: \"B\", data_type: Utf8 })], right=[(0, Field { name: \"A\", data_type: Int32 }), (1, Field { name: \"C\", data_type: Float64 })]"
1634        );
1635    }
1636
1637    #[test]
1638    fn test_union_incompatible_modes() {
1639        // create first union as Dense with Int32 and Utf8
1640        let int_array1 = Int32Array::from(vec![1, 2]);
1641        let str_array1 = StringArray::from(vec!["a", "b"]);
1642
1643        let type_ids1 = [0, 1].into_iter().collect::<ScalarBuffer<i8>>();
1644        let offsets1 = [0, 0].into_iter().collect::<ScalarBuffer<i32>>();
1645
1646        let union_fields1 = [
1647            (0, Arc::new(Field::new("A", DataType::Int32, false))),
1648            (1, Arc::new(Field::new("B", DataType::Utf8, false))),
1649        ]
1650        .into_iter()
1651        .collect::<UnionFields>();
1652
1653        let children1 = vec![Arc::new(int_array1) as ArrayRef, Arc::new(str_array1)];
1654
1655        let array1 =
1656            UnionArray::try_new(union_fields1.clone(), type_ids1, Some(offsets1), children1)
1657                .unwrap();
1658
1659        // create second union as Sparse with same fields (Int32 and Utf8)
1660        let int_array2 = Int32Array::from(vec![Some(3), None]);
1661        let str_array2 = StringArray::from(vec![None, Some("c")]);
1662
1663        let type_ids2 = [0, 1].into_iter().collect::<ScalarBuffer<i8>>();
1664
1665        let children2 = vec![Arc::new(int_array2) as ArrayRef, Arc::new(str_array2)];
1666
1667        let array2 = UnionArray::try_new(union_fields1, type_ids2, None, children2).unwrap();
1668
1669        let opts = SortOptions::default();
1670
1671        let Result::Err(ArrowError::InvalidArgumentError(out)) =
1672            make_comparator(&array1, &array2, opts)
1673        else {
1674            panic!("expected error when making comparator of union arrays with different modes");
1675        };
1676
1677        assert_eq!(
1678            &out,
1679            "Cannot compare UnionArrays with different modes: left=Dense, right=Sparse"
1680        );
1681    }
1682
1683    #[test]
1684    fn test_null_array_cmp() {
1685        let a = NullArray::new(3);
1686        let b = NullArray::new(3);
1687        let cmp = make_comparator(&a, &b, SortOptions::default()).unwrap();
1688
1689        assert_eq!(cmp(0, 0), Ordering::Equal);
1690        assert_eq!(cmp(0, 1), Ordering::Equal);
1691        assert_eq!(cmp(2, 0), Ordering::Equal);
1692    }
1693}