arrow_data/equal/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Module containing functionality to compute array equality.
19//! This module uses [ArrayData] and does not
20//! depend on dynamic casting of `Array`.
21
22use crate::data::ArrayData;
23use arrow_buffer::i256;
24use arrow_schema::{DataType, IntervalUnit};
25use half::f16;
26
27mod boolean;
28mod byte_view;
29mod dictionary;
30mod fixed_binary;
31mod fixed_list;
32mod list;
33mod null;
34mod primitive;
35mod run;
36mod structure;
37mod union;
38mod utils;
39mod variable_size;
40
41// these methods assume the same type, len and null count.
42// For this reason, they are not exposed and are instead used
43// to build the generic functions below (`equal_range` and `equal`).
44use boolean::boolean_equal;
45use byte_view::byte_view_equal;
46use dictionary::dictionary_equal;
47use fixed_binary::fixed_binary_equal;
48use fixed_list::fixed_list_equal;
49use list::list_equal;
50use null::null_equal;
51use primitive::primitive_equal;
52use structure::struct_equal;
53use union::union_equal;
54use variable_size::variable_sized_equal;
55
56use self::run::run_equal;
57
58/// Compares the values of two [ArrayData] starting at `lhs_start` and `rhs_start` respectively
59/// for `len` slots.
60#[inline]
61fn equal_values(
62    lhs: &ArrayData,
63    rhs: &ArrayData,
64    lhs_start: usize,
65    rhs_start: usize,
66    len: usize,
67) -> bool {
68    match lhs.data_type() {
69        DataType::Null => null_equal(lhs, rhs, lhs_start, rhs_start, len),
70        DataType::Boolean => boolean_equal(lhs, rhs, lhs_start, rhs_start, len),
71        DataType::UInt8 => primitive_equal::<u8>(lhs, rhs, lhs_start, rhs_start, len),
72        DataType::UInt16 => primitive_equal::<u16>(lhs, rhs, lhs_start, rhs_start, len),
73        DataType::UInt32 => primitive_equal::<u32>(lhs, rhs, lhs_start, rhs_start, len),
74        DataType::UInt64 => primitive_equal::<u64>(lhs, rhs, lhs_start, rhs_start, len),
75        DataType::Int8 => primitive_equal::<i8>(lhs, rhs, lhs_start, rhs_start, len),
76        DataType::Int16 => primitive_equal::<i16>(lhs, rhs, lhs_start, rhs_start, len),
77        DataType::Int32 => primitive_equal::<i32>(lhs, rhs, lhs_start, rhs_start, len),
78        DataType::Int64 => primitive_equal::<i64>(lhs, rhs, lhs_start, rhs_start, len),
79        DataType::Float32 => primitive_equal::<f32>(lhs, rhs, lhs_start, rhs_start, len),
80        DataType::Float64 => primitive_equal::<f64>(lhs, rhs, lhs_start, rhs_start, len),
81        DataType::Decimal32(_, _) => primitive_equal::<i32>(lhs, rhs, lhs_start, rhs_start, len),
82        DataType::Decimal64(_, _) => primitive_equal::<i64>(lhs, rhs, lhs_start, rhs_start, len),
83        DataType::Decimal128(_, _) => primitive_equal::<i128>(lhs, rhs, lhs_start, rhs_start, len),
84        DataType::Decimal256(_, _) => primitive_equal::<i256>(lhs, rhs, lhs_start, rhs_start, len),
85        DataType::Date32 | DataType::Time32(_) | DataType::Interval(IntervalUnit::YearMonth) => {
86            primitive_equal::<i32>(lhs, rhs, lhs_start, rhs_start, len)
87        }
88        DataType::Date64
89        | DataType::Interval(IntervalUnit::DayTime)
90        | DataType::Time64(_)
91        | DataType::Timestamp(_, _)
92        | DataType::Duration(_) => primitive_equal::<i64>(lhs, rhs, lhs_start, rhs_start, len),
93        DataType::Interval(IntervalUnit::MonthDayNano) => {
94            primitive_equal::<i128>(lhs, rhs, lhs_start, rhs_start, len)
95        }
96        DataType::Utf8 | DataType::Binary => {
97            variable_sized_equal::<i32>(lhs, rhs, lhs_start, rhs_start, len)
98        }
99        DataType::LargeUtf8 | DataType::LargeBinary => {
100            variable_sized_equal::<i64>(lhs, rhs, lhs_start, rhs_start, len)
101        }
102        DataType::FixedSizeBinary(_) => fixed_binary_equal(lhs, rhs, lhs_start, rhs_start, len),
103        DataType::BinaryView | DataType::Utf8View => {
104            byte_view_equal(lhs, rhs, lhs_start, rhs_start, len)
105        }
106        DataType::List(_) => list_equal::<i32>(lhs, rhs, lhs_start, rhs_start, len),
107        DataType::ListView(_) | DataType::LargeListView(_) => {
108            unimplemented!("ListView/LargeListView not yet implemented")
109        }
110        DataType::LargeList(_) => list_equal::<i64>(lhs, rhs, lhs_start, rhs_start, len),
111        DataType::FixedSizeList(_, _) => fixed_list_equal(lhs, rhs, lhs_start, rhs_start, len),
112        DataType::Struct(_) => struct_equal(lhs, rhs, lhs_start, rhs_start, len),
113        DataType::Union(_, _) => union_equal(lhs, rhs, lhs_start, rhs_start, len),
114        DataType::Dictionary(data_type, _) => match data_type.as_ref() {
115            DataType::Int8 => dictionary_equal::<i8>(lhs, rhs, lhs_start, rhs_start, len),
116            DataType::Int16 => dictionary_equal::<i16>(lhs, rhs, lhs_start, rhs_start, len),
117            DataType::Int32 => dictionary_equal::<i32>(lhs, rhs, lhs_start, rhs_start, len),
118            DataType::Int64 => dictionary_equal::<i64>(lhs, rhs, lhs_start, rhs_start, len),
119            DataType::UInt8 => dictionary_equal::<u8>(lhs, rhs, lhs_start, rhs_start, len),
120            DataType::UInt16 => dictionary_equal::<u16>(lhs, rhs, lhs_start, rhs_start, len),
121            DataType::UInt32 => dictionary_equal::<u32>(lhs, rhs, lhs_start, rhs_start, len),
122            DataType::UInt64 => dictionary_equal::<u64>(lhs, rhs, lhs_start, rhs_start, len),
123            _ => unreachable!(),
124        },
125        DataType::Float16 => primitive_equal::<f16>(lhs, rhs, lhs_start, rhs_start, len),
126        DataType::Map(_, _) => list_equal::<i32>(lhs, rhs, lhs_start, rhs_start, len),
127        DataType::RunEndEncoded(_, _) => run_equal(lhs, rhs, lhs_start, rhs_start, len),
128    }
129}
130
131fn equal_range(
132    lhs: &ArrayData,
133    rhs: &ArrayData,
134    lhs_start: usize,
135    rhs_start: usize,
136    len: usize,
137) -> bool {
138    utils::equal_nulls(lhs, rhs, lhs_start, rhs_start, len)
139        && equal_values(lhs, rhs, lhs_start, rhs_start, len)
140}
141
142/// Logically compares two [ArrayData].
143///
144/// Two arrays are logically equal if and only if:
145/// * their data types are equal
146/// * their lengths are equal
147/// * their null counts are equal
148/// * their null bitmaps are equal
149/// * each of their items are equal
150///
151/// Two items are equal when their in-memory representation is physically equal
152/// (i.e. has the same bit content).
153///
154/// The physical comparison depend on the data type.
155///
156/// # Panics
157///
158/// This function may panic whenever any of the [ArrayData] does not follow the
159/// Arrow specification. (e.g. wrong number of buffers, buffer `len` does not
160/// correspond to the declared `len`)
161pub fn equal(lhs: &ArrayData, rhs: &ArrayData) -> bool {
162    utils::base_equal(lhs, rhs)
163        && lhs.null_count() == rhs.null_count()
164        && utils::equal_nulls(lhs, rhs, 0, 0, lhs.len())
165        && equal_values(lhs, rhs, 0, 0, lhs.len())
166}
167
168// See arrow/tests/array_equal.rs for tests