Skip to main content

arrow_ord/
cmp.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Comparison kernels for `Array`s.
19//!
20//! These kernels can leverage SIMD if available on your system.  Currently no runtime
21//! detection is provided, you should enable the specific SIMD intrinsics using
22//! `RUSTFLAGS="-C target-feature=+avx2"` for example.  See the documentation
23//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
24//!
25
26use arrow_array::cast::AsArray;
27use arrow_array::types::{ByteArrayType, ByteViewType};
28use arrow_array::{
29    AnyDictionaryArray, Array, ArrowNativeTypeOp, BooleanArray, Datum, FixedSizeBinaryArray,
30    GenericByteArray, GenericByteViewArray, downcast_primitive_array,
31};
32use arrow_buffer::bit_util::ceil;
33use arrow_buffer::{BooleanBuffer, NullBuffer};
34use arrow_schema::ArrowError;
35use arrow_select::take::take;
36use std::cmp::Ordering;
37use std::ops::Not;
38
/// Comparison operators understood by the kernels in this module
#[derive(Clone, Copy, Debug)]
enum Op {
    /// `==`
    Equal,
    /// `!=`
    NotEqual,
    /// `<`
    Less,
    /// `<=`
    LessEqual,
    /// `>`
    Greater,
    /// `>=`
    GreaterEqual,
    /// `IS DISTINCT FROM`: like `!=` but null-aware, never producing null
    Distinct,
    /// `IS NOT DISTINCT FROM`: like `==` but null-aware, never producing null
    NotDistinct,
}
50
51impl std::fmt::Display for Op {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        match self {
54            Op::Equal => write!(f, "=="),
55            Op::NotEqual => write!(f, "!="),
56            Op::Less => write!(f, "<"),
57            Op::LessEqual => write!(f, "<="),
58            Op::Greater => write!(f, ">"),
59            Op::GreaterEqual => write!(f, ">="),
60            Op::Distinct => write!(f, "IS DISTINCT FROM"),
61            Op::NotDistinct => write!(f, "IS NOT DISTINCT FROM"),
62        }
63    }
64}
65
66/// Perform `left == right` operation on two [`Datum`].
67///
68/// Comparing null values on either side will yield a null in the corresponding
69/// slot of the resulting [`BooleanArray`].
70///
71/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
72/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
73/// Note that totalOrder treats positive and negative zeros as different. If it is necessary
74/// to treat them as equal, please normalize zeros before calling this kernel. See
75/// [`f32::total_cmp`] and [`f64::total_cmp`].
76///
77/// Nested types, such as lists, are not supported as the null semantics are not well-defined.
78/// For comparisons involving nested types see [`crate::ord::make_comparator`]
79pub fn eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
80    compare_op(Op::Equal, lhs, rhs)
81}
82
83/// Perform `left != right` operation on two [`Datum`].
84///
85/// Comparing null values on either side will yield a null in the corresponding
86/// slot of the resulting [`BooleanArray`].
87///
88/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
89/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
90/// Note that totalOrder treats positive and negative zeros as different. If it is necessary
91/// to treat them as equal, please normalize zeros before calling this kernel. See
92/// [`f32::total_cmp`] and [`f64::total_cmp`].
93///
94/// Nested types, such as lists, are not supported as the null semantics are not well-defined.
95/// For comparisons involving nested types see [`crate::ord::make_comparator`]
96pub fn neq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
97    compare_op(Op::NotEqual, lhs, rhs)
98}
99
100/// Perform `left < right` operation on two [`Datum`].
101///
102/// Comparing null values on either side will yield a null in the corresponding
103/// slot of the resulting [`BooleanArray`].
104///
105/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
106/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
107/// Note that totalOrder treats positive and negative zeros as different. If it is necessary
108/// to treat them as equal, please normalize zeros before calling this kernel. See
109/// [`f32::total_cmp`] and [`f64::total_cmp`].
110///
111/// Nested types, such as lists, are not supported as the null semantics are not well-defined.
112/// For comparisons involving nested types see [`crate::ord::make_comparator`]
113pub fn lt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
114    compare_op(Op::Less, lhs, rhs)
115}
116
117/// Perform `left <= right` operation on two [`Datum`].
118///
119/// Comparing null values on either side will yield a null in the corresponding
120/// slot of the resulting [`BooleanArray`].
121///
122/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
123/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
124/// Note that totalOrder treats positive and negative zeros as different. If it is necessary
125/// to treat them as equal, please normalize zeros before calling this kernel. See
126/// [`f32::total_cmp`] and [`f64::total_cmp`].
127///
128/// Nested types, such as lists, are not supported as the null semantics are not well-defined.
129/// For comparisons involving nested types see [`crate::ord::make_comparator`]
130pub fn lt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
131    compare_op(Op::LessEqual, lhs, rhs)
132}
133
134/// Perform `left > right` operation on two [`Datum`].
135///
136/// Comparing null values on either side will yield a null in the corresponding
137/// slot of the resulting [`BooleanArray`].
138///
139/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
140/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
141/// Note that totalOrder treats positive and negative zeros as different. If it is necessary
142/// to treat them as equal, please normalize zeros before calling this kernel. See
143/// [`f32::total_cmp`] and [`f64::total_cmp`].
144///
145/// Nested types, such as lists, are not supported as the null semantics are not well-defined.
146/// For comparisons involving nested types see [`crate::ord::make_comparator`]
147pub fn gt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
148    compare_op(Op::Greater, lhs, rhs)
149}
150
151/// Perform `left >= right` operation on two [`Datum`].
152///
153/// Comparing null values on either side will yield a null in the corresponding
154/// slot of the resulting [`BooleanArray`].
155///
156/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
157/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
158/// Note that totalOrder treats positive and negative zeros as different. If it is necessary
159/// to treat them as equal, please normalize zeros before calling this kernel. See
160/// [`f32::total_cmp`] and [`f64::total_cmp`].
161///
162/// Nested types, such as lists, are not supported as the null semantics are not well-defined.
163/// For comparisons involving nested types see [`crate::ord::make_comparator`]
164pub fn gt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
165    compare_op(Op::GreaterEqual, lhs, rhs)
166}
167
168/// Perform `left IS DISTINCT FROM right` operation on two [`Datum`]
169///
170/// [`distinct`] is similar to [`neq`], only differing in null handling. In particular, two
171/// operands are considered DISTINCT if they have a different value or if one of them is NULL
172/// and the other isn't. The result of [`distinct`] is never NULL.
173///
174/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
175/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
176/// Note that totalOrder treats positive and negative zeros as different. If it is necessary
177/// to treat them as equal, please normalize zeros before calling this kernel. See
178/// [`f32::total_cmp`] and [`f64::total_cmp`].
179///
180/// Nested types, such as lists, are not supported as the null semantics are not well-defined.
181/// For comparisons involving nested types see [`crate::ord::make_comparator`]
182pub fn distinct(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
183    compare_op(Op::Distinct, lhs, rhs)
184}
185
186/// Perform `left IS NOT DISTINCT FROM right` operation on two [`Datum`]
187///
188/// [`not_distinct`] is similar to [`eq`], only differing in null handling. In particular, two
189/// operands are considered `NOT DISTINCT` if they have the same value or if both of them
190/// is NULL. The result of [`not_distinct`] is never NULL.
191///
192/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
193/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
194/// Note that totalOrder treats positive and negative zeros as different. If it is necessary
195/// to treat them as equal, please normalize zeros before calling this kernel. See
196/// [`f32::total_cmp`] and [`f64::total_cmp`].
197///
198/// Nested types, such as lists, are not supported as the null semantics are not well-defined.
199/// For comparisons involving nested types see [`crate::ord::make_comparator`]
200pub fn not_distinct(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
201    compare_op(Op::NotDistinct, lhs, rhs)
202}
203
/// Perform `op` on the provided `Datum`
///
/// Shared implementation behind every public comparison kernel in this module. Dictionary
/// encoded inputs are compared through their values arrays (the dictionaries are handed to
/// `apply` so keys can be resolved), and the validity of the result is derived from the
/// logical nulls of both inputs according to `op`.
///
/// Returns `ArrowError::InvalidArgumentError` if both inputs are non-scalar with different
/// lengths, if either (dictionary-unwrapped) type is nested, or if the two types differ.
#[inline(never)]
fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
    use arrow_schema::DataType::*;
    // `l_s` / `r_s` flag whether the corresponding side is a scalar
    let (l, l_s) = lhs.get();
    let (r, r_s) = rhs.get();

    let l_len = l.len();
    let r_len = r.len();

    // Length mismatch is only an error when neither side is a scalar
    if l_len != r_len && !l_s && !r_s {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Cannot compare arrays of different lengths, got {l_len} vs {r_len}"
        )));
    }

    // A scalar left side is broadcast, so the output length follows the right side
    let len = match l_s {
        true => r_len,
        false => l_len,
    };

    // Nulls are taken from the original (possibly dictionary) arrays, before unwrapping below
    let l_nulls = l.logical_nulls();
    let r_nulls = r.logical_nulls();

    // For dictionaries, compare the values arrays; `l_v` / `r_v` keep the dictionary
    // so `apply` can map keys to values
    let l_v = l.as_any_dictionary_opt();
    let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l);
    let l_t = l.data_type();

    let r_v = r.as_any_dictionary_opt();
    let r = r_v.map(|x| x.values().as_ref()).unwrap_or(r);
    let r_t = r.data_type();

    if r_t.is_nested() || l_t.is_nested() {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Nested comparison: {l_t} {op} {r_t} (hint: use make_comparator instead)"
        )));
    } else if l_t != r_t {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Invalid comparison operation: {l_t} {op} {r_t}"
        )));
    }

    // Defer computation as may not be necessary
    //
    // `apply` returns `None` when either side has no values to compare (and `Null` types
    // never produce values); in that case an all-false buffer of `len` bits is used.
    //
    // NOTE(review): any equal, non-nested type pair not matched by the arms below reaches
    // `unreachable!()` and panics instead of returning an error (possibly e.g. run-end
    // encoded arrays, depending on `DataType::is_nested`) — confirm and consider an error.
    let values = || -> BooleanBuffer {
        let d = downcast_primitive_array! {
            (l, r) => apply(op, l.values().as_ref(), l_s, l_v, r.values().as_ref(), r_s, r_v),
            (Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v, r.as_boolean(), r_s, r_v),
            (Utf8, Utf8) => apply(op, l.as_string::<i32>(), l_s, l_v, r.as_string::<i32>(), r_s, r_v),
            (Utf8View, Utf8View) => apply(op, l.as_string_view(), l_s, l_v, r.as_string_view(), r_s, r_v),
            (LargeUtf8, LargeUtf8) => apply(op, l.as_string::<i64>(), l_s, l_v, r.as_string::<i64>(), r_s, r_v),
            (Binary, Binary) => apply(op, l.as_binary::<i32>(), l_s, l_v, r.as_binary::<i32>(), r_s, r_v),
            (BinaryView, BinaryView) => apply(op, l.as_binary_view(), l_s, l_v, r.as_binary_view(), r_s, r_v),
            (LargeBinary, LargeBinary) => apply(op, l.as_binary::<i64>(), l_s, l_v, r.as_binary::<i64>(), r_s, r_v),
            (FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op, l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v),
            (Null, Null) => None,
            _ => unreachable!(),
        };
        d.unwrap_or_else(|| BooleanBuffer::new_unset(len))
    };

    // Treat null buffers without any nulls as absent so they take the cheaper paths below
    let l_nulls = l_nulls.filter(|n| n.null_count() > 0);
    let r_nulls = r_nulls.filter(|n| n.null_count() > 0);
    Ok(match (l_nulls, l_s, r_nulls, r_s) {
        (Some(l), true, Some(r), true) | (Some(l), false, Some(r), false) => {
            // Either both sides are scalar or neither side is scalar
            //
            // In the validity bitmaps a set bit marks a valid slot; the distinct variants
            // combine them with the value comparison 64 bits at a time.
            match op {
                Op::Distinct => {
                    let values = values();
                    let l = l.inner().bit_chunks().iter_padded();
                    let r = r.inner().bit_chunks().iter_padded();
                    let ne = values.bit_chunks().iter_padded();

                    // Distinct if exactly one side is null, or both valid and not equal
                    let c = |((l, r), n)| (l ^ r) | (l & r & n);
                    let buffer = l.zip(r).zip(ne).map(c).collect();
                    BooleanBuffer::new(buffer, 0, len).into()
                }
                Op::NotDistinct => {
                    let values = values();
                    let l = l.inner().bit_chunks().iter_padded();
                    let r = r.inner().bit_chunks().iter_padded();
                    let e = values.bit_chunks().iter_padded();

                    // Not distinct if both sides are null, or both valid and equal
                    let c = |((l, r), e)| u64::not(l | r) | (l & r & e);
                    let buffer = l.zip(r).zip(e).map(c).collect();
                    BooleanBuffer::new(buffer, 0, len).into()
                }
                _ => BooleanArray::new(values(), NullBuffer::union(Some(&l), Some(&r))),
            }
        }
        (Some(_), true, Some(a), false) | (Some(a), false, Some(_), true) => {
            // Scalar is null, other side is non-scalar and nullable
            match op {
                // Distinct exactly where the array side is valid
                Op::Distinct => a.into_inner().into(),
                // Not distinct exactly where the array side is null
                Op::NotDistinct => a.into_inner().not().into(),
                _ => BooleanArray::new_null(len),
            }
        }
        (Some(nulls), is_scalar, None, _) | (None, _, Some(nulls), is_scalar) => {
            // Only one side is nullable
            match is_scalar {
                true => match op {
                    // Scalar is null, other side is not nullable
                    Op::Distinct => BooleanBuffer::new_set(len).into(),
                    Op::NotDistinct => BooleanBuffer::new_unset(len).into(),
                    _ => BooleanArray::new_null(len),
                },
                false => match op {
                    Op::Distinct => {
                        let values = values();
                        let l = nulls.inner().bit_chunks().iter_padded();
                        let ne = values.bit_chunks().iter_padded();
                        // Distinct where the nullable side is null, or values differ
                        let c = |(l, n)| u64::not(l) | n;
                        let buffer = l.zip(ne).map(c).collect();
                        BooleanBuffer::new(buffer, 0, len).into()
                    }
                    // Not distinct only where the nullable side is valid and values are equal
                    Op::NotDistinct => (nulls.inner() & &values()).into(),
                    _ => BooleanArray::new(values(), Some(nulls)),
                },
            }
        }
        // Neither side is nullable
        (None, _, None, _) => BooleanArray::new(values(), None),
    })
}
328
329/// Perform a potentially vectored `op` on the provided `ArrayOrd`
330fn apply<T: ArrayOrd>(
331    op: Op,
332    l: T,
333    l_s: bool,
334    l_v: Option<&dyn AnyDictionaryArray>,
335    r: T,
336    r_s: bool,
337    r_v: Option<&dyn AnyDictionaryArray>,
338) -> Option<BooleanBuffer> {
339    if l.len() == 0 || r.len() == 0 {
340        return None; // Handle empty dictionaries
341    }
342
343    if !l_s && !r_s && (l_v.is_some() || r_v.is_some()) {
344        // Not scalar and at least one side has a dictionary, need to perform vectored comparison
345        let l_v = l_v
346            .map(|x| x.normalized_keys())
347            .unwrap_or_else(|| (0..l.len()).collect());
348
349        let r_v = r_v
350            .map(|x| x.normalized_keys())
351            .unwrap_or_else(|| (0..r.len()).collect());
352
353        assert_eq!(l_v.len(), r_v.len()); // Sanity check
354
355        Some(match op {
356            Op::Equal | Op::NotDistinct => apply_op_vectored(l, &l_v, r, &r_v, false, T::is_eq),
357            Op::NotEqual | Op::Distinct => apply_op_vectored(l, &l_v, r, &r_v, true, T::is_eq),
358            Op::Less => apply_op_vectored(l, &l_v, r, &r_v, false, T::is_lt),
359            Op::LessEqual => apply_op_vectored(r, &r_v, l, &l_v, true, T::is_lt),
360            Op::Greater => apply_op_vectored(r, &r_v, l, &l_v, false, T::is_lt),
361            Op::GreaterEqual => apply_op_vectored(l, &l_v, r, &r_v, true, T::is_lt),
362        })
363    } else {
364        let l_s = l_s.then(|| l_v.map(|x| x.normalized_keys()[0]).unwrap_or_default());
365        let r_s = r_s.then(|| r_v.map(|x| x.normalized_keys()[0]).unwrap_or_default());
366
367        let buffer = match op {
368            Op::Equal | Op::NotDistinct => apply_op(l, l_s, r, r_s, false, T::is_eq),
369            Op::NotEqual | Op::Distinct => apply_op(l, l_s, r, r_s, true, T::is_eq),
370            Op::Less => apply_op(l, l_s, r, r_s, false, T::is_lt),
371            Op::LessEqual => apply_op(r, r_s, l, l_s, true, T::is_lt),
372            Op::Greater => apply_op(r, r_s, l, l_s, false, T::is_lt),
373            Op::GreaterEqual => apply_op(l, l_s, r, r_s, true, T::is_lt),
374        };
375
376        // If a side had a dictionary, and was not scalar, we need to materialize this
377        Some(match (l_v, r_v) {
378            (Some(l_v), _) if l_s.is_none() => take_bits(l_v, buffer),
379            (_, Some(r_v)) if r_s.is_none() => take_bits(r_v, buffer),
380            _ => buffer,
381        })
382    }
383}
384
385/// Perform a take operation on `buffer` with the given dictionary
386fn take_bits(v: &dyn AnyDictionaryArray, buffer: BooleanBuffer) -> BooleanBuffer {
387    let array = take(&BooleanArray::new(buffer, None), v.keys(), None).unwrap();
388    array.as_boolean().values().clone()
389}
390
391/// Invokes `f` with values `0..len` collecting the boolean results into a new `BooleanBuffer`
392///
393/// This is similar to [`arrow_buffer::MutableBuffer::collect_bool`] but with
394/// the option to efficiently negate the result
395fn collect_bool(len: usize, neg: bool, f: impl Fn(usize) -> bool) -> BooleanBuffer {
396    let mut buffer = Vec::with_capacity(ceil(len, 64));
397
398    let chunks = len / 64;
399    let remainder = len % 64;
400    buffer.extend((0..chunks).map(|chunk| {
401        let mut packed = 0;
402        for bit_idx in 0..64 {
403            let i = bit_idx + chunk * 64;
404            packed |= (f(i) as u64) << bit_idx;
405        }
406        if neg {
407            packed = !packed
408        }
409
410        packed
411    }));
412
413    if remainder != 0 {
414        let mut packed = 0;
415        for bit_idx in 0..remainder {
416            let i = bit_idx + chunks * 64;
417            packed |= (f(i) as u64) << bit_idx;
418        }
419        if neg {
420            packed = !packed
421        }
422
423        buffer.push(packed);
424    }
425    BooleanBuffer::new(buffer.into(), 0, len)
426}
427
428/// Applies `op` to possibly scalar `ArrayOrd`
429///
430/// If l is scalar `l_s` will be `Some(idx)` where `idx` is the index of the scalar value in `l`
431/// If r is scalar `r_s` will be `Some(idx)` where `idx` is the index of the scalar value in `r`
432///
433/// If `neg` is true the result of `op` will be negated
434fn apply_op<T: ArrayOrd>(
435    l: T,
436    l_s: Option<usize>,
437    r: T,
438    r_s: Option<usize>,
439    neg: bool,
440    op: impl Fn(T::Item, T::Item) -> bool,
441) -> BooleanBuffer {
442    match (l_s, r_s) {
443        (None, None) => {
444            assert_eq!(l.len(), r.len());
445            collect_bool(l.len(), neg, |idx| unsafe {
446                op(l.value_unchecked(idx), r.value_unchecked(idx))
447            })
448        }
449        (Some(l_s), Some(r_s)) => {
450            let a = l.value(l_s);
451            let b = r.value(r_s);
452            std::iter::once(op(a, b) ^ neg).collect()
453        }
454        (Some(l_s), None) => {
455            let v = l.value(l_s);
456            collect_bool(r.len(), neg, |idx| op(v, unsafe { r.value_unchecked(idx) }))
457        }
458        (None, Some(r_s)) => {
459            let v = r.value(r_s);
460            collect_bool(l.len(), neg, |idx| op(unsafe { l.value_unchecked(idx) }, v))
461        }
462    }
463}
464
465/// Applies `op` to possibly scalar `ArrayOrd` with the given indices
466fn apply_op_vectored<T: ArrayOrd>(
467    l: T,
468    l_v: &[usize],
469    r: T,
470    r_v: &[usize],
471    neg: bool,
472    op: impl Fn(T::Item, T::Item) -> bool,
473) -> BooleanBuffer {
474    assert_eq!(l_v.len(), r_v.len());
475    collect_bool(l_v.len(), neg, |idx| unsafe {
476        let l_idx = *l_v.get_unchecked(idx);
477        let r_idx = *r_v.get_unchecked(idx);
478        op(l.value_unchecked(l_idx), r.value_unchecked(r_idx))
479    })
480}
481
/// Random-access view over the values of an array that the comparison kernels can consume
///
/// Implementations provide only two primitive comparisons, `is_eq` and `is_lt`; the other
/// operators are derived from them in `apply` by swapping operands and/or negating results.
trait ArrayOrd {
    /// Value handed to `is_eq` / `is_lt`; must be `Copy` so it can be reused per comparison
    type Item: Copy;

    /// Number of values that can be accessed
    fn len(&self) -> usize;

    /// Returns the value at `idx`
    ///
    /// # Panics
    ///
    /// Panics if `idx >= self.len()`
    fn value(&self, idx: usize) -> Self::Item {
        assert!(idx < self.len());
        // SAFETY: bounds checked by the assert above
        unsafe { self.value_unchecked(idx) }
    }

    /// Returns the value at `idx` without bounds checking
    ///
    /// # Safety
    ///
    /// Safe if `idx < self.len()`
    unsafe fn value_unchecked(&self, idx: usize) -> Self::Item;

    /// Returns `true` if `l` equals `r`
    fn is_eq(l: Self::Item, r: Self::Item) -> bool;

    /// Returns `true` if `l` is strictly less than `r`
    fn is_lt(l: Self::Item, r: Self::Item) -> bool;
}
501
502impl ArrayOrd for &BooleanArray {
503    type Item = bool;
504
505    fn len(&self) -> usize {
506        Array::len(self)
507    }
508
509    unsafe fn value_unchecked(&self, idx: usize) -> Self::Item {
510        unsafe { BooleanArray::value_unchecked(self, idx) }
511    }
512
513    fn is_eq(l: Self::Item, r: Self::Item) -> bool {
514        l == r
515    }
516
517    fn is_lt(l: Self::Item, r: Self::Item) -> bool {
518        !l & r
519    }
520}
521
522impl<T: ArrowNativeTypeOp> ArrayOrd for &[T] {
523    type Item = T;
524
525    fn len(&self) -> usize {
526        (*self).len()
527    }
528
529    unsafe fn value_unchecked(&self, idx: usize) -> Self::Item {
530        unsafe { *self.get_unchecked(idx) }
531    }
532
533    fn is_eq(l: Self::Item, r: Self::Item) -> bool {
534        l.is_eq(r)
535    }
536
537    fn is_lt(l: Self::Item, r: Self::Item) -> bool {
538        l.is_lt(r)
539    }
540}
541
542impl<'a, T: ByteArrayType> ArrayOrd for &'a GenericByteArray<T> {
543    type Item = &'a [u8];
544
545    fn len(&self) -> usize {
546        Array::len(self)
547    }
548
549    unsafe fn value_unchecked(&self, idx: usize) -> Self::Item {
550        unsafe { GenericByteArray::value_unchecked(self, idx).as_ref() }
551    }
552
553    fn is_eq(l: Self::Item, r: Self::Item) -> bool {
554        l == r
555    }
556
557    fn is_lt(l: Self::Item, r: Self::Item) -> bool {
558        l < r
559    }
560}
561
562impl<'a, T: ByteViewType> ArrayOrd for &'a GenericByteViewArray<T> {
563    /// This is the item type for the GenericByteViewArray::compare
564    /// Item.0 is the array, Item.1 is the index
565    type Item = (&'a GenericByteViewArray<T>, usize);
566
567    #[inline(always)]
568    fn is_eq(l: Self::Item, r: Self::Item) -> bool {
569        let l_view = unsafe { l.0.views().get_unchecked(l.1) };
570        let r_view = unsafe { r.0.views().get_unchecked(r.1) };
571        if l.0.data_buffers().is_empty() && r.0.data_buffers().is_empty() {
572            // For eq case, we can directly compare the inlined bytes
573            return l_view == r_view;
574        }
575
576        // Fast path for same view (and both inlined)
577        if l_view == r_view && *l_view as u32 <= 12 {
578            return true;
579        }
580
581        let l_len = *l_view as u32;
582        let r_len = *r_view as u32;
583        // Lengths differ
584        if l_len != r_len {
585            return false;
586        }
587
588        // Both are empty
589        if l_len == 0 {
590            return true;
591        }
592
593        // Check prefix
594        if (*l_view >> 32) as u32 != (*r_view >> 32) as u32 {
595            return false;
596        }
597
598        // Both are inlined, and prefixes are equal (so they differ in rest of inlined bytes)
599        if l_len <= 12 {
600            return false;
601        }
602
603        // # Safety
604        // The index is within bounds as it is checked in value()
605        unsafe {
606            let l_buffer_idx = (*l_view >> 64) as u32;
607            let l_offset = (*l_view >> 96) as u32;
608            let r_buffer_idx = (*r_view >> 64) as u32;
609            let r_offset = (*r_view >> 96) as u32;
610
611            let l_data = l.0.data_buffers().get_unchecked(l_buffer_idx as usize);
612            let r_data = r.0.data_buffers().get_unchecked(r_buffer_idx as usize);
613
614            let l_slice = l_data
615                .as_slice()
616                .get_unchecked(l_offset as usize..(l_offset + l_len) as usize);
617            let r_slice = r_data
618                .as_slice()
619                .get_unchecked(r_offset as usize..(r_offset + r_len) as usize);
620            l_slice == r_slice
621        }
622    }
623
624    #[inline(always)]
625    fn is_lt(l: Self::Item, r: Self::Item) -> bool {
626        let l_view = unsafe { l.0.views().get_unchecked(l.1) };
627        let r_view = unsafe { r.0.views().get_unchecked(r.1) };
628
629        if l.0.data_buffers().is_empty() && r.0.data_buffers().is_empty() {
630            // For lt case, we can directly compare the inlined bytes
631            return GenericByteViewArray::<T>::inline_key_fast(*l_view)
632                < GenericByteViewArray::<T>::inline_key_fast(*r_view);
633        }
634
635        if (*l_view as u32) <= 12 && (*r_view as u32) <= 12 {
636            return GenericByteViewArray::<T>::inline_key_fast(*l_view)
637                < GenericByteViewArray::<T>::inline_key_fast(*r_view);
638        }
639
640        let l_prefix = (*l_view >> 32) as u32;
641        let r_prefix = (*r_view >> 32) as u32;
642        if l_prefix != r_prefix {
643            return l_prefix.swap_bytes() < r_prefix.swap_bytes();
644        }
645
646        // Fallback to the generic, unchecked comparison for mixed cases
647        // # Safety
648        // The index is within bounds as it is checked in value()
649        unsafe {
650            let l_data: &[u8] = l.0.value_unchecked(l.1).as_ref();
651            let r_data: &[u8] = r.0.value_unchecked(r.1).as_ref();
652            l_data < r_data
653        }
654    }
655
656    fn len(&self) -> usize {
657        Array::len(self)
658    }
659
660    unsafe fn value_unchecked(&self, idx: usize) -> Self::Item {
661        (self, idx)
662    }
663}
664
665impl<'a> ArrayOrd for &'a FixedSizeBinaryArray {
666    type Item = &'a [u8];
667
668    fn len(&self) -> usize {
669        Array::len(self)
670    }
671
672    unsafe fn value_unchecked(&self, idx: usize) -> Self::Item {
673        unsafe { FixedSizeBinaryArray::value_unchecked(self, idx) }
674    }
675
676    fn is_eq(l: Self::Item, r: Self::Item) -> bool {
677        l == r
678    }
679
680    fn is_lt(l: Self::Item, r: Self::Item) -> bool {
681        l < r
682    }
683}
684
685/// Compares two [`GenericByteViewArray`] at index `left_idx` and `right_idx`
686#[inline(always)]
687pub fn compare_byte_view<T: ByteViewType>(
688    left: &GenericByteViewArray<T>,
689    left_idx: usize,
690    right: &GenericByteViewArray<T>,
691    right_idx: usize,
692) -> Ordering {
693    assert!(left_idx < left.len());
694    assert!(right_idx < right.len());
695    if left.data_buffers().is_empty() && right.data_buffers().is_empty() {
696        let l_view = unsafe { left.views().get_unchecked(left_idx) };
697        let r_view = unsafe { right.views().get_unchecked(right_idx) };
698        return GenericByteViewArray::<T>::inline_key_fast(*l_view)
699            .cmp(&GenericByteViewArray::<T>::inline_key_fast(*r_view));
700    }
701    unsafe { GenericByteViewArray::compare_unchecked(left, left_idx, right, right_idx) }
702}
703
704#[cfg(test)]
705mod tests {
706    use std::sync::Arc;
707
708    use arrow_array::{DictionaryArray, Int32Array, Scalar, StringArray};
709    use arrow_buffer::{Buffer, ScalarBuffer};
710
711    use super::*;
712
713    #[test]
714    fn test_null_dict() {
715        let a = DictionaryArray::new(Int32Array::new_null(10), Arc::new(Int32Array::new_null(0)));
716        let r = eq(&a, &a).unwrap();
717        assert_eq!(r.null_count(), 10);
718
719        let a = DictionaryArray::new(
720            Int32Array::from(vec![1, 2, 3, 4, 5, 6]),
721            Arc::new(Int32Array::new_null(10)),
722        );
723        let r = eq(&a, &a).unwrap();
724        assert_eq!(r.null_count(), 6);
725
726        let scalar =
727            DictionaryArray::new(Int32Array::new_null(1), Arc::new(Int32Array::new_null(0)));
728        let r = eq(&a, &Scalar::new(&scalar)).unwrap();
729        assert_eq!(r.null_count(), 6);
730
731        let scalar =
732            DictionaryArray::new(Int32Array::new_null(1), Arc::new(Int32Array::new_null(0)));
733        let r = eq(&Scalar::new(&scalar), &Scalar::new(&scalar)).unwrap();
734        assert_eq!(r.null_count(), 1);
735
736        let a = DictionaryArray::new(
737            Int32Array::from(vec![0, 1, 2]),
738            Arc::new(Int32Array::from(vec![3, 2, 1])),
739        );
740        let r = eq(&a, &Scalar::new(&scalar)).unwrap();
741        assert_eq!(r.null_count(), 3);
742    }
743
744    #[test]
745    fn is_distinct_from_non_nulls() {
746        let left_int_array = Int32Array::from(vec![0, 1, 2, 3, 4]);
747        let right_int_array = Int32Array::from(vec![4, 3, 2, 1, 0]);
748
749        assert_eq!(
750            BooleanArray::from(vec![true, true, false, true, true,]),
751            distinct(&left_int_array, &right_int_array).unwrap()
752        );
753        assert_eq!(
754            BooleanArray::from(vec![false, false, true, false, false,]),
755            not_distinct(&left_int_array, &right_int_array).unwrap()
756        );
757    }
758
759    #[test]
760    fn is_distinct_from_nulls() {
761        // [0, 0, NULL, 0, 0, 0]
762        let left_int_array = Int32Array::new(
763            vec![0, 0, 1, 3, 0, 0].into(),
764            Some(NullBuffer::from(vec![true, true, false, true, true, true])),
765        );
766        // [0, NULL, NULL, NULL, 0, NULL]
767        let right_int_array = Int32Array::new(
768            vec![0; 6].into(),
769            Some(NullBuffer::from(vec![
770                true, false, false, false, true, false,
771            ])),
772        );
773
774        assert_eq!(
775            BooleanArray::from(vec![false, true, false, true, false, true,]),
776            distinct(&left_int_array, &right_int_array).unwrap()
777        );
778
779        assert_eq!(
780            BooleanArray::from(vec![true, false, true, false, true, false,]),
781            not_distinct(&left_int_array, &right_int_array).unwrap()
782        );
783    }
784
785    #[test]
786    fn test_distinct_scalar() {
787        let a = Int32Array::new_scalar(12);
788        let b = Int32Array::new_scalar(12);
789        assert!(!distinct(&a, &b).unwrap().value(0));
790        assert!(not_distinct(&a, &b).unwrap().value(0));
791
792        let a = Int32Array::new_scalar(12);
793        let b = Int32Array::new_null(1);
794        assert!(distinct(&a, &b).unwrap().value(0));
795        assert!(!not_distinct(&a, &b).unwrap().value(0));
796        assert!(distinct(&b, &a).unwrap().value(0));
797        assert!(!not_distinct(&b, &a).unwrap().value(0));
798
799        let b = Scalar::new(b);
800        assert!(distinct(&a, &b).unwrap().value(0));
801        assert!(!not_distinct(&a, &b).unwrap().value(0));
802
803        assert!(!distinct(&b, &b).unwrap().value(0));
804        assert!(not_distinct(&b, &b).unwrap().value(0));
805
806        let a = Int32Array::new(
807            vec![0, 1, 2, 3].into(),
808            Some(vec![false, false, true, true].into()),
809        );
810        let expected = BooleanArray::from(vec![false, false, true, true]);
811        assert_eq!(distinct(&a, &b).unwrap(), expected);
812        assert_eq!(distinct(&b, &a).unwrap(), expected);
813
814        let expected = BooleanArray::from(vec![true, true, false, false]);
815        assert_eq!(not_distinct(&a, &b).unwrap(), expected);
816        assert_eq!(not_distinct(&b, &a).unwrap(), expected);
817
818        let b = Int32Array::new_scalar(1);
819        let expected = BooleanArray::from(vec![true; 4]);
820        assert_eq!(distinct(&a, &b).unwrap(), expected);
821        assert_eq!(distinct(&b, &a).unwrap(), expected);
822        let expected = BooleanArray::from(vec![false; 4]);
823        assert_eq!(not_distinct(&a, &b).unwrap(), expected);
824        assert_eq!(not_distinct(&b, &a).unwrap(), expected);
825
826        let b = Int32Array::new_scalar(3);
827        let expected = BooleanArray::from(vec![true, true, true, false]);
828        assert_eq!(distinct(&a, &b).unwrap(), expected);
829        assert_eq!(distinct(&b, &a).unwrap(), expected);
830        let expected = BooleanArray::from(vec![false, false, false, true]);
831        assert_eq!(not_distinct(&a, &b).unwrap(), expected);
832        assert_eq!(not_distinct(&b, &a).unwrap(), expected);
833    }
834
835    #[test]
836    fn test_scalar_negation() {
837        let a = Int32Array::new_scalar(54);
838        let b = Int32Array::new_scalar(54);
839        let r = eq(&a, &b).unwrap();
840        assert!(r.value(0));
841
842        let r = neq(&a, &b).unwrap();
843        assert!(!r.value(0))
844    }
845
846    #[test]
847    fn test_scalar_empty() {
848        let a = Int32Array::new_null(0);
849        let b = Int32Array::new_scalar(23);
850        let r = eq(&a, &b).unwrap();
851        assert_eq!(r.len(), 0);
852        let r = eq(&b, &a).unwrap();
853        assert_eq!(r.len(), 0);
854    }
855
856    #[test]
857    fn test_dictionary_nulls() {
858        let values = StringArray::from(vec![Some("us-west"), Some("us-east")]);
859        let nulls = NullBuffer::from(vec![false, true, true]);
860
861        let key_values = vec![100i32, 1i32, 0i32].into();
862        let keys = Int32Array::new(key_values, Some(nulls));
863        let col = DictionaryArray::try_new(keys, Arc::new(values)).unwrap();
864
865        neq(&col.slice(0, col.len() - 1), &col.slice(1, col.len() - 1)).unwrap();
866    }
867
868    #[test]
869    fn test_string_view_mixed_lt() {
870        let a = arrow_array::StringViewArray::from(vec![
871            Some("apple"),
872            Some("apple"),
873            Some("apple_long_string"),
874        ]);
875        let b = arrow_array::StringViewArray::from(vec![
876            Some("apple_long_string"),
877            Some("appl"),
878            Some("apple"),
879        ]);
880        // "apple" < "apple_long_string" -> true
881        // "apple" < "appl" -> false
882        // "apple_long_string" < "apple" -> false
883        assert_eq!(
884            lt(&a, &b).unwrap(),
885            BooleanArray::from(vec![true, false, false])
886        );
887    }
888
889    #[test]
890    fn test_string_view_eq() {
891        let a = arrow_array::StringViewArray::from(vec![
892            Some("hello"),
893            Some("world"),
894            None,
895            Some("very long string exceeding 12 bytes"),
896        ]);
897        let b = arrow_array::StringViewArray::from(vec![
898            Some("hello"),
899            Some("world"),
900            None,
901            Some("very long string exceeding 12 bytes"),
902        ]);
903        assert_eq!(
904            eq(&a, &b).unwrap(),
905            BooleanArray::from(vec![Some(true), Some(true), None, Some(true)])
906        );
907
908        let c = arrow_array::StringViewArray::from(vec![
909            Some("hello"),
910            Some("world!"),
911            None,
912            Some("very long string exceeding 12 bytes!"),
913        ]);
914        assert_eq!(
915            eq(&a, &c).unwrap(),
916            BooleanArray::from(vec![Some(true), Some(false), None, Some(false)])
917        );
918    }
919
920    #[test]
921    fn test_string_view_lt() {
922        let a = arrow_array::StringViewArray::from(vec![
923            Some("apple"),
924            Some("banana"),
925            Some("very long apple exceeding 12 bytes"),
926            Some("very long banana exceeding 12 bytes"),
927        ]);
928        let b = arrow_array::StringViewArray::from(vec![
929            Some("banana"),
930            Some("apple"),
931            Some("very long banana exceeding 12 bytes"),
932            Some("very long apple exceeding 12 bytes"),
933        ]);
934        assert_eq!(
935            lt(&a, &b).unwrap(),
936            BooleanArray::from(vec![true, false, true, false])
937        );
938    }
939
940    #[test]
941    fn test_string_view_eq_prefix_mismatch() {
942        // Prefix mismatch should short-circuit equality for long values.
943        let a =
944            arrow_array::StringViewArray::from(vec![Some("very long apple exceeding 12 bytes")]);
945        let b =
946            arrow_array::StringViewArray::from(vec![Some("very long banana exceeding 12 bytes")]);
947        assert_eq!(eq(&a, &b).unwrap(), BooleanArray::from(vec![Some(false)]));
948    }
949
950    #[test]
951    fn test_string_view_lt_prefix_mismatch() {
952        // Prefix mismatch should decide ordering without full compare for long values.
953        let a =
954            arrow_array::StringViewArray::from(vec![Some("apple long string exceeding 12 bytes")]);
955        let b =
956            arrow_array::StringViewArray::from(vec![Some("banana long string exceeding 12 bytes")]);
957        assert_eq!(lt(&a, &b).unwrap(), BooleanArray::from(vec![true]));
958    }
959
960    #[test]
961    fn test_string_view_eq_inline_fast_path() {
962        // Inline-only arrays should compare by view equality fast path.
963        let a = arrow_array::StringViewArray::from(vec![Some("ab")]);
964        let b = arrow_array::StringViewArray::from(vec![Some("ab")]);
965        assert!(!has_buffers(&a));
966        assert!(!has_buffers(&b));
967        assert_eq!(eq(&a, &b).unwrap(), BooleanArray::from(vec![Some(true)]));
968    }
969
970    #[test]
971    fn test_string_view_eq_inline_prefix_mismatch_with_buffers() {
972        // Non-empty buffers force the prefix mismatch branch for inline values.
973        let a = arrow_array::StringViewArray::from(vec![
974            Some("ab"),
975            Some("long string to allocate buffers"),
976        ]);
977        let b = arrow_array::StringViewArray::from(vec![
978            Some("ac"),
979            Some("long string to allocate buffers"),
980        ]);
981        assert!(has_buffers(&a));
982        assert!(has_buffers(&b));
983        assert_eq!(
984            eq(&a, &b).unwrap(),
985            BooleanArray::from(vec![Some(false), Some(true)])
986        );
987    }
988
989    #[test]
990    fn test_string_view_eq_empty_len_branch() {
991        // Reach the zero-length branch by bypassing the inline fast path with a dummy buffer.
992        let raw_a = 0u128;
993        let raw_b = 1u128 << 96;
994        let views_a = ScalarBuffer::from(vec![raw_a]);
995        let views_b = ScalarBuffer::from(vec![raw_b]);
996        let buffers: Arc<[Buffer]> = Arc::from([Buffer::from_slice_ref([0u8])]);
997        let a =
998            unsafe { arrow_array::StringViewArray::new_unchecked(views_a, buffers.clone(), None) };
999        let b = unsafe { arrow_array::StringViewArray::new_unchecked(views_b, buffers, None) };
1000        assert!(has_buffers(&a));
1001        assert!(has_buffers(&b));
1002        assert!(<&arrow_array::StringViewArray as ArrayOrd>::is_eq(
1003            (&a, 0),
1004            (&b, 0)
1005        ));
1006    }
1007
1008    #[test]
1009    fn test_string_view_long_prefix_mismatch_array_ord() {
1010        // Long strings with differing prefixes should short-circuit on prefix ordering.
1011        let a =
1012            arrow_array::StringViewArray::from(vec![Some("apple long string exceeding 12 bytes")]);
1013        let b =
1014            arrow_array::StringViewArray::from(vec![Some("banana long string exceeding 12 bytes")]);
1015        assert!(has_buffers(&a));
1016        assert!(has_buffers(&b));
1017        assert!(<&arrow_array::StringViewArray as ArrayOrd>::is_lt(
1018            (&a, 0),
1019            (&b, 0)
1020        ));
1021    }
1022
1023    #[test]
1024    fn test_string_view_inline_mismatch_array_ord() {
1025        // Long strings with differing prefixes should short-circuit on prefix ordering.
1026        let a = arrow_array::StringViewArray::from(vec![Some("ap")]);
1027        let b = arrow_array::StringViewArray::from(vec![Some("ba")]);
1028        assert!(!has_buffers(&a));
1029        assert!(!has_buffers(&b));
1030        assert!(<&arrow_array::StringViewArray as ArrayOrd>::is_lt(
1031            (&a, 0),
1032            (&b, 0)
1033        ));
1034    }
1035    #[test]
1036    fn test_compare_byte_view_inline_fast_path() {
1037        // Inline-only views should compare via inline key in compare_byte_view.
1038        let a = arrow_array::StringViewArray::from(vec![Some("ab")]);
1039        let b = arrow_array::StringViewArray::from(vec![Some("ac")]);
1040        assert!(!has_buffers(&a));
1041        assert!(!has_buffers(&b));
1042        assert_eq!(compare_byte_view(&a, 0, &b, 0), Ordering::Less);
1043    }
1044
1045    fn has_buffers<T: ByteViewType>(array: &GenericByteViewArray<T>) -> bool {
1046        !array.data_buffers().is_empty()
1047    }
1048
1049    #[test]
1050    fn test_compare_byte_view() {
1051        let a = arrow_array::StringViewArray::from(vec![
1052            Some("apple"),
1053            Some("banana"),
1054            Some("very long apple exceeding 12 bytes"),
1055            Some("very long banana exceeding 12 bytes"),
1056        ]);
1057        let b = arrow_array::StringViewArray::from(vec![
1058            Some("apple"),
1059            Some("apple"),
1060            Some("very long apple exceeding 12 bytes"),
1061            Some("very long apple exceeding 12 bytes"),
1062        ]);
1063
1064        assert_eq!(compare_byte_view(&a, 0, &b, 0), Ordering::Equal);
1065        assert_eq!(compare_byte_view(&a, 1, &b, 1), Ordering::Greater);
1066        assert_eq!(compare_byte_view(&a, 2, &b, 2), Ordering::Equal);
1067        assert_eq!(compare_byte_view(&a, 3, &b, 3), Ordering::Greater);
1068    }
1069}