arrow_string/
binary_predicate.rs1use arrow_array::{Array, ArrayAccessor, BinaryViewArray, BooleanArray};
19use arrow_buffer::BooleanBuffer;
20use memchr::memmem::Finder;
21use std::iter::zip;
22
23#[allow(clippy::large_enum_variant)]
25pub enum BinaryPredicate<'a> {
26 Contains(Finder<'a>),
27 StartsWith(&'a [u8]),
28 EndsWith(&'a [u8]),
29}
30
31impl<'a> BinaryPredicate<'a> {
32 pub fn contains(needle: &'a [u8]) -> Self {
33 Self::Contains(Finder::new(needle))
34 }
35
36 pub fn evaluate(&self, haystack: &[u8]) -> bool {
38 match self {
39 Self::Contains(finder) => finder.find(haystack).is_some(),
40 Self::StartsWith(v) => starts_with(haystack, v, equals_kernel),
41 Self::EndsWith(v) => ends_with(haystack, v, equals_kernel),
42 }
43 }
44
45 #[inline(never)]
49 pub fn evaluate_array<'i, T>(&self, array: T, negate: bool) -> BooleanArray
50 where
51 T: ArrayAccessor<Item = &'i [u8]>,
52 {
53 match self {
54 Self::Contains(finder) => BooleanArray::from_unary(array, |haystack| {
55 finder.find(haystack).is_some() != negate
56 }),
57 Self::StartsWith(v) => {
58 if let Some(view_array) = array.as_any().downcast_ref::<BinaryViewArray>() {
59 let nulls = view_array.logical_nulls();
60 let values = BooleanBuffer::from(
61 view_array
62 .prefix_bytes_iter(v.len())
63 .map(|haystack| equals_bytes(haystack, v, equals_kernel) != negate)
64 .collect::<Vec<_>>(),
65 );
66 BooleanArray::new(values, nulls)
67 } else {
68 BooleanArray::from_unary(array, |haystack| {
69 starts_with(haystack, v, equals_kernel) != negate
70 })
71 }
72 }
73 Self::EndsWith(v) => {
74 if let Some(view_array) = array.as_any().downcast_ref::<BinaryViewArray>() {
75 let nulls = view_array.logical_nulls();
76 let values = BooleanBuffer::from(
77 view_array
78 .suffix_bytes_iter(v.len())
79 .map(|haystack| equals_bytes(haystack, v, equals_kernel) != negate)
80 .collect::<Vec<_>>(),
81 );
82 BooleanArray::new(values, nulls)
83 } else {
84 BooleanArray::from_unary(array, |haystack| {
85 ends_with(haystack, v, equals_kernel) != negate
86 })
87 }
88 }
89 }
90 }
91}
92
/// Returns `true` when `lhs` and `rhs` have the same length and every
/// position-wise byte pair satisfies `byte_eq_kernel`.
///
/// Pairs are passed to the kernel as `(lhs_byte, rhs_byte)`, front to back.
fn equals_bytes(lhs: &[u8], rhs: &[u8], byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    // Unequal lengths can never be equal; bail out before touching any bytes.
    if lhs.len() != rhs.len() {
        return false;
    }
    zip(lhs, rhs).all(byte_eq_kernel)
}
96
/// Returns `true` when the first `needle.len()` bytes of `haystack` match
/// `needle` byte-by-byte under `byte_eq_kernel`.
///
/// A needle longer than the haystack never matches; an empty needle always
/// does. Pairs reach the kernel as `(haystack_byte, needle_byte)`, front to back.
fn starts_with(
    haystack: &[u8],
    needle: &[u8],
    byte_eq_kernel: impl Fn((&u8, &u8)) -> bool,
) -> bool {
    // `get` yields `None` exactly when the needle is longer than the haystack.
    match haystack.get(..needle.len()) {
        Some(head) => zip(head, needle).all(byte_eq_kernel),
        None => false,
    }
}
/// Returns `true` when the last `needle.len()` bytes of `haystack` match
/// `needle` byte-by-byte under `byte_eq_kernel`.
///
/// A needle longer than the haystack never matches; an empty needle always
/// does. Pairs reach the kernel as `(haystack_byte, needle_byte)`, starting
/// from the final byte and walking backwards.
fn ends_with(haystack: &[u8], needle: &[u8], byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    // `checked_sub` fails exactly when the needle is longer than the haystack.
    match haystack.len().checked_sub(needle.len()) {
        Some(tail_start) => zip(haystack[tail_start..].iter().rev(), needle.iter().rev())
            .all(byte_eq_kernel),
        None => false,
    }
}
119
/// Exact byte-equality kernel for the byte-comparison helpers.
///
/// Receives one pair of bytes taken position-wise from the two sides of a
/// zipped comparison. The comparison is symmetric, so the order within the
/// pair does not matter (the bindings are deliberately neutral, not
/// needle/haystack).
fn equals_kernel((lhs, rhs): (&u8, &u8)) -> bool {
    lhs == rhs
}
123
#[cfg(test)]
mod tests {
    use super::BinaryPredicate;

    /// One case: `(needle, haystack, expected result of evaluate)`.
    type Case<'a> = (&'a [u8], &'a [u8], bool);

    /// Builds a predicate from each case's needle via `build` and asserts that
    /// evaluating it against the case's haystack yields the expected result.
    fn check<'a>(build: impl Fn(&'a [u8]) -> BinaryPredicate<'a>, cases: &[Case<'a>]) {
        for &(needle, haystack, expected) in cases {
            assert_eq!(build(needle).evaluate(haystack), expected);
        }
    }

    #[test]
    fn test_contains() {
        let cases: &[Case<'_>] = &[
            (b"hay", b"haystack", true),
            (b"haystack", b"haystack", true),
            (b"h", b"haystack", true),
            (b"k", b"haystack", true),
            (b"stack", b"haystack", true),
            (b"sta", b"haystack", true),
            (b"stack", b"hay\0stack", true),
            (b"\0s", b"hay\0stack", true),
            (b"\0", b"hay\0stack", true),
            (b"a", b"a", true),
            (b"hy", b"haystack", false),
            (b"stackx", b"haystack", false),
            (b"x", b"haystack", false),
            (b"haystack haystack", b"haystack", false),
        ];
        check(BinaryPredicate::contains, cases);
    }

    #[test]
    fn test_starts_with() {
        let cases: &[Case<'_>] = &[
            (b"hay", b"haystack", true),
            (b"h\0ay", b"h\0aystack", true),
            (b"haystack", b"haystack", true),
            (b"ha", b"haystack", true),
            (b"h", b"haystack", true),
            (b"", b"haystack", true),
            (b"stack", b"haystack", false),
            (b"haystacks", b"haystack", false),
            (b"HAY", b"haystack", false),
            (b"h\0ay", b"haystack", false),
            (b"hay", b"h\0aystack", false),
        ];
        check(BinaryPredicate::StartsWith, cases);
    }

    #[test]
    fn test_ends_with() {
        let cases: &[Case<'_>] = &[
            (b"stack", b"haystack", true),
            (b"st\0ack", b"hayst\0ack", true),
            (b"haystack", b"haystack", true),
            (b"ck", b"haystack", true),
            (b"k", b"haystack", true),
            (b"", b"haystack", true),
            (b"hay", b"haystack", false),
            (b"STACK", b"haystack", false),
            (b"haystacks", b"haystack", false),
            (b"xhaystack", b"haystack", false),
            (b"st\0ack", b"haystack", false),
            (b"stack", b"hayst\0ack", false),
        ];
        check(BinaryPredicate::EndsWith, cases);
    }
}