arrow_string/
binary_predicate.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use arrow_array::{Array, ArrayAccessor, BinaryViewArray, BooleanArray};
19use arrow_buffer::BooleanBuffer;
20use memchr::memmem::Finder;
21use std::iter::zip;
22
23/// A binary based predicate
24#[allow(clippy::large_enum_variant)]
25pub enum BinaryPredicate<'a> {
26    Contains(Finder<'a>),
27    StartsWith(&'a [u8]),
28    EndsWith(&'a [u8]),
29}
30
31impl<'a> BinaryPredicate<'a> {
32    pub fn contains(needle: &'a [u8]) -> Self {
33        Self::Contains(Finder::new(needle))
34    }
35
36    /// Evaluate this predicate against the given haystack
37    pub fn evaluate(&self, haystack: &[u8]) -> bool {
38        match self {
39            Self::Contains(finder) => finder.find(haystack).is_some(),
40            Self::StartsWith(v) => starts_with(haystack, v, equals_kernel),
41            Self::EndsWith(v) => ends_with(haystack, v, equals_kernel),
42        }
43    }
44
45    /// Evaluate this predicate against the elements of `array`
46    ///
47    /// If `negate` is true the result of the predicate will be negated
48    #[inline(never)]
49    pub fn evaluate_array<'i, T>(&self, array: T, negate: bool) -> BooleanArray
50    where
51        T: ArrayAccessor<Item = &'i [u8]>,
52    {
53        match self {
54            Self::Contains(finder) => BooleanArray::from_unary(array, |haystack| {
55                finder.find(haystack).is_some() != negate
56            }),
57            Self::StartsWith(v) => {
58                if let Some(view_array) = array.as_any().downcast_ref::<BinaryViewArray>() {
59                    let nulls = view_array.logical_nulls();
60                    let values = BooleanBuffer::from(
61                        view_array
62                            .prefix_bytes_iter(v.len())
63                            .map(|haystack| equals_bytes(haystack, v, equals_kernel) != negate)
64                            .collect::<Vec<_>>(),
65                    );
66                    BooleanArray::new(values, nulls)
67                } else {
68                    BooleanArray::from_unary(array, |haystack| {
69                        starts_with(haystack, v, equals_kernel) != negate
70                    })
71                }
72            }
73            Self::EndsWith(v) => {
74                if let Some(view_array) = array.as_any().downcast_ref::<BinaryViewArray>() {
75                    let nulls = view_array.logical_nulls();
76                    let values = BooleanBuffer::from(
77                        view_array
78                            .suffix_bytes_iter(v.len())
79                            .map(|haystack| equals_bytes(haystack, v, equals_kernel) != negate)
80                            .collect::<Vec<_>>(),
81                    );
82                    BooleanArray::new(values, nulls)
83                } else {
84                    BooleanArray::from_unary(array, |haystack| {
85                        ends_with(haystack, v, equals_kernel) != negate
86                    })
87                }
88            }
89        }
90    }
91}
92
/// Returns `true` when `lhs` and `rhs` have the same length and
/// `byte_eq_kernel` accepts every pair of bytes at matching positions.
fn equals_bytes(lhs: &[u8], rhs: &[u8], byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    // Slices of different length can never be equal; skip the byte walk.
    if lhs.len() != rhs.len() {
        return false;
    }
    let mut pairs = zip(lhs, rhs);
    pairs.all(byte_eq_kernel)
}
96
/// Returns `true` when `haystack` begins with `needle`, comparing bytes
/// pairwise with `byte_eq_kernel`.
///
/// This is faster than `[u8]::starts_with` for small slices.
/// See <https://github.com/apache/arrow-rs/issues/6107> for more details.
fn starts_with(
    haystack: &[u8],
    needle: &[u8],
    byte_eq_kernel: impl Fn((&u8, &u8)) -> bool,
) -> bool {
    // A needle longer than the haystack cannot be a prefix. Otherwise `zip`
    // stops after `needle.len()` pairs, i.e. exactly the prefix region.
    needle.len() <= haystack.len() && zip(haystack, needle).all(byte_eq_kernel)
}
/// Returns `true` when `haystack` ends with `needle`, comparing bytes
/// pairwise (walking both slices from the back) with `byte_eq_kernel`.
///
/// This is faster than `[u8]::ends_with` for small slices.
/// See <https://github.com/apache/arrow-rs/issues/6107> for more details.
fn ends_with(haystack: &[u8], needle: &[u8], byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    // A needle longer than the haystack cannot be a suffix. Otherwise the
    // reversed `zip` stops after `needle.len()` pairs, i.e. the suffix region.
    needle.len() <= haystack.len()
        && zip(haystack.iter().rev(), needle.iter().rev()).all(byte_eq_kernel)
}
119
/// Byte comparison kernel: returns `true` when both bytes of the pair are
/// identical.
fn equals_kernel((left, right): (&u8, &u8)) -> bool {
    *left == *right
}
123
#[cfg(test)]
mod tests {
    use super::BinaryPredicate;

    /// Evaluates a `Contains` predicate for `needle` against `haystack`.
    fn contains(needle: &[u8], haystack: &[u8]) -> bool {
        BinaryPredicate::contains(needle).evaluate(haystack)
    }

    /// Evaluates a `StartsWith` predicate for `needle` against `haystack`.
    fn starts(needle: &[u8], haystack: &[u8]) -> bool {
        BinaryPredicate::StartsWith(needle).evaluate(haystack)
    }

    /// Evaluates an `EndsWith` predicate for `needle` against `haystack`.
    fn ends(needle: &[u8], haystack: &[u8]) -> bool {
        BinaryPredicate::EndsWith(needle).evaluate(haystack)
    }

    #[test]
    fn test_contains() {
        // matching
        assert!(contains(b"hay", b"haystack"));
        assert!(contains(b"haystack", b"haystack"));
        assert!(contains(b"h", b"haystack"));
        assert!(contains(b"k", b"haystack"));
        assert!(contains(b"stack", b"haystack"));
        assert!(contains(b"sta", b"haystack"));
        assert!(contains(b"stack", b"hay\0stack"));
        assert!(contains(b"\0s", b"hay\0stack"));
        assert!(contains(b"\0", b"hay\0stack"));
        assert!(contains(b"a", b"a"));
        // not matching
        assert!(!contains(b"hy", b"haystack"));
        assert!(!contains(b"stackx", b"haystack"));
        assert!(!contains(b"x", b"haystack"));
        assert!(!contains(b"haystack haystack", b"haystack"));
    }

    #[test]
    fn test_starts_with() {
        // matching
        assert!(starts(b"hay", b"haystack"));
        assert!(starts(b"h\0ay", b"h\0aystack"));
        assert!(starts(b"haystack", b"haystack"));
        assert!(starts(b"ha", b"haystack"));
        assert!(starts(b"h", b"haystack"));
        assert!(starts(b"", b"haystack"));
        // not matching
        assert!(!starts(b"stack", b"haystack"));
        assert!(!starts(b"haystacks", b"haystack"));
        assert!(!starts(b"HAY", b"haystack"));
        assert!(!starts(b"h\0ay", b"haystack"));
        assert!(!starts(b"hay", b"h\0aystack"));
    }

    #[test]
    fn test_ends_with() {
        // matching
        assert!(ends(b"stack", b"haystack"));
        assert!(ends(b"st\0ack", b"hayst\0ack"));
        assert!(ends(b"haystack", b"haystack"));
        assert!(ends(b"ck", b"haystack"));
        assert!(ends(b"k", b"haystack"));
        assert!(ends(b"", b"haystack"));
        // not matching
        assert!(!ends(b"hay", b"haystack"));
        assert!(!ends(b"STACK", b"haystack"));
        assert!(!ends(b"haystacks", b"haystack"));
        assert!(!ends(b"xhaystack", b"haystack"));
        assert!(!ends(b"st\0ack", b"haystack"));
        assert!(!ends(b"stack", b"hayst\0ack"));
    }
}