arrow_data/equal/
union.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::data::ArrayData;
19use arrow_schema::{DataType, UnionFields, UnionMode};
20
21use super::equal_range;
22
23#[allow(clippy::too_many_arguments)]
24fn equal_dense(
25    lhs: &ArrayData,
26    rhs: &ArrayData,
27    lhs_type_ids: &[i8],
28    rhs_type_ids: &[i8],
29    lhs_offsets: &[i32],
30    rhs_offsets: &[i32],
31    lhs_fields: &UnionFields,
32    rhs_fields: &UnionFields,
33) -> bool {
34    let offsets = lhs_offsets.iter().zip(rhs_offsets.iter());
35
36    lhs_type_ids
37        .iter()
38        .zip(rhs_type_ids.iter())
39        .zip(offsets)
40        .all(|((l_type_id, r_type_id), (l_offset, r_offset))| {
41            let lhs_child_index = lhs_fields
42                .iter()
43                .position(|(r, _)| r == *l_type_id)
44                .unwrap();
45            let rhs_child_index = rhs_fields
46                .iter()
47                .position(|(r, _)| r == *r_type_id)
48                .unwrap();
49            let lhs_values = &lhs.child_data()[lhs_child_index];
50            let rhs_values = &rhs.child_data()[rhs_child_index];
51
52            equal_range(
53                lhs_values,
54                rhs_values,
55                *l_offset as usize,
56                *r_offset as usize,
57                1,
58            )
59        })
60}
61
62fn equal_sparse(
63    lhs: &ArrayData,
64    rhs: &ArrayData,
65    lhs_start: usize,
66    rhs_start: usize,
67    len: usize,
68) -> bool {
69    lhs.child_data()
70        .iter()
71        .zip(rhs.child_data())
72        .all(|(lhs_values, rhs_values)| {
73            equal_range(
74                lhs_values,
75                rhs_values,
76                lhs_start + lhs.offset(),
77                rhs_start + rhs.offset(),
78                len,
79            )
80        })
81}
82
83pub(super) fn union_equal(
84    lhs: &ArrayData,
85    rhs: &ArrayData,
86    lhs_start: usize,
87    rhs_start: usize,
88    len: usize,
89) -> bool {
90    let lhs_type_ids = lhs.buffer::<i8>(0);
91    let rhs_type_ids = rhs.buffer::<i8>(0);
92
93    let lhs_type_id_range = &lhs_type_ids[lhs_start..lhs_start + len];
94    let rhs_type_id_range = &rhs_type_ids[rhs_start..rhs_start + len];
95
96    match (lhs.data_type(), rhs.data_type()) {
97        (
98            DataType::Union(lhs_fields, UnionMode::Dense),
99            DataType::Union(rhs_fields, UnionMode::Dense),
100        ) => {
101            let lhs_offsets = lhs.buffer::<i32>(1);
102            let rhs_offsets = rhs.buffer::<i32>(1);
103
104            let lhs_offsets_range = &lhs_offsets[lhs_start..lhs_start + len];
105            let rhs_offsets_range = &rhs_offsets[rhs_start..rhs_start + len];
106
107            lhs_type_id_range == rhs_type_id_range
108                && equal_dense(
109                    lhs,
110                    rhs,
111                    lhs_type_id_range,
112                    rhs_type_id_range,
113                    lhs_offsets_range,
114                    rhs_offsets_range,
115                    lhs_fields,
116                    rhs_fields,
117                )
118        }
119        (DataType::Union(_, UnionMode::Sparse), DataType::Union(_, UnionMode::Sparse)) => {
120            lhs_type_id_range == rhs_type_id_range
121                && equal_sparse(lhs, rhs, lhs_start, rhs_start, len)
122        }
123        _ => unimplemented!(
124            "Logical equality not yet implemented between dense and sparse union arrays"
125        ),
126    }
127}