1use arrow_array::{Array, ArrayAccessor, BooleanArray, StringViewArray};
19use arrow_buffer::BooleanBuffer;
20use arrow_schema::ArrowError;
21use memchr::memchr3;
22use memchr::memmem::Finder;
23use regex::{Regex, RegexBuilder};
24use std::iter::zip;
25
/// A string predicate that is evaluated against a haystack string.
///
/// The borrowed variants hold the needle they compare against for the lifetime `'a`.
// NOTE(review): the lint is presumably allowed because `Finder` and `Regex` are much larger
// than the `&str` variants — confirm before boxing them.
#[allow(clippy::large_enum_variant)]
pub(crate) enum Predicate<'a> {
    /// The haystack must be exactly equal to the string
    Eq(&'a str),
    /// The haystack must contain the needle the `Finder` was built from
    Contains(Finder<'a>),
    /// The haystack must start with the string
    StartsWith(&'a str),
    /// The haystack must end with the string
    EndsWith(&'a str),

    /// Equality ignoring ASCII case
    IEqAscii(&'a str),
    /// Starts with, ignoring ASCII case
    IStartsWithAscii(&'a str),
    /// Ends with, ignoring ASCII case
    IEndsWithAscii(&'a str),

    /// The haystack must match the regular expression
    Regex(Regex),
}
43
44impl<'a> Predicate<'a> {
45 pub(crate) fn like(pattern: &'a str) -> Result<Self, ArrowError> {
47 if !contains_like_pattern(pattern) {
48 Ok(Self::Eq(pattern))
49 } else if pattern.ends_with('%') && !contains_like_pattern(&pattern[..pattern.len() - 1]) {
50 Ok(Self::StartsWith(&pattern[..pattern.len() - 1]))
51 } else if pattern.starts_with('%') && !contains_like_pattern(&pattern[1..]) {
52 Ok(Self::EndsWith(&pattern[1..]))
53 } else if pattern.starts_with('%')
54 && pattern.ends_with('%')
55 && !contains_like_pattern(&pattern[1..pattern.len() - 1])
56 {
57 Ok(Self::contains(&pattern[1..pattern.len() - 1]))
58 } else {
59 Ok(Self::Regex(regex_like(pattern, false)?))
60 }
61 }
62
63 pub(crate) fn contains(needle: &'a str) -> Self {
64 Self::Contains(Finder::new(needle.as_bytes()))
65 }
66
67 pub(crate) fn ilike(pattern: &'a str, is_ascii: bool) -> Result<Self, ArrowError> {
69 if is_ascii && pattern.is_ascii() {
70 if !contains_like_pattern(pattern) {
71 return Ok(Self::IEqAscii(pattern));
72 } else if pattern.ends_with('%')
73 && !pattern.ends_with("\\%")
74 && !contains_like_pattern(&pattern[..pattern.len() - 1])
75 {
76 return Ok(Self::IStartsWithAscii(&pattern[..pattern.len() - 1]));
77 } else if pattern.starts_with('%') && !contains_like_pattern(&pattern[1..]) {
78 return Ok(Self::IEndsWithAscii(&pattern[1..]));
79 }
80 }
81 Ok(Self::Regex(regex_like(pattern, true)?))
82 }
83
84 pub(crate) fn evaluate(&self, haystack: &str) -> bool {
86 match self {
87 Predicate::Eq(v) => *v == haystack,
88 Predicate::IEqAscii(v) => haystack.eq_ignore_ascii_case(v),
89 Predicate::Contains(finder) => finder.find(haystack.as_bytes()).is_some(),
90 Predicate::StartsWith(v) => starts_with(haystack, v, equals_kernel),
91 Predicate::IStartsWithAscii(v) => {
92 starts_with(haystack, v, equals_ignore_ascii_case_kernel)
93 }
94 Predicate::EndsWith(v) => ends_with(haystack, v, equals_kernel),
95 Predicate::IEndsWithAscii(v) => ends_with(haystack, v, equals_ignore_ascii_case_kernel),
96 Predicate::Regex(v) => v.is_match(haystack),
97 }
98 }
99
100 #[inline(never)]
104 pub(crate) fn evaluate_array<'i, T>(&self, array: T, negate: bool) -> BooleanArray
105 where
106 T: ArrayAccessor<Item = &'i str>,
107 {
108 match self {
109 Predicate::Eq(v) => BooleanArray::from_unary(array, |haystack| {
110 (haystack.len() == v.len() && haystack == *v) != negate
111 }),
112 Predicate::IEqAscii(v) => BooleanArray::from_unary(array, |haystack| {
113 haystack.eq_ignore_ascii_case(v) != negate
114 }),
115 Predicate::Contains(finder) => BooleanArray::from_unary(array, |haystack| {
116 finder.find(haystack.as_bytes()).is_some() != negate
117 }),
118 Predicate::StartsWith(v) => {
119 if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
120 let nulls = string_view_array.logical_nulls();
121 let values = BooleanBuffer::from(
122 string_view_array
123 .prefix_bytes_iter(v.len())
124 .map(|haystack| {
125 equals_bytes(haystack, v.as_bytes(), equals_kernel) != negate
126 })
127 .collect::<Vec<_>>(),
128 );
129 BooleanArray::new(values, nulls)
130 } else {
131 BooleanArray::from_unary(array, |haystack| {
132 starts_with(haystack, v, equals_kernel) != negate
133 })
134 }
135 }
136 Predicate::IStartsWithAscii(v) => {
137 if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
138 let nulls = string_view_array.logical_nulls();
139 let values = BooleanBuffer::from(
140 string_view_array
141 .prefix_bytes_iter(v.len())
142 .map(|haystack| {
143 equals_bytes(
144 haystack,
145 v.as_bytes(),
146 equals_ignore_ascii_case_kernel,
147 ) != negate
148 })
149 .collect::<Vec<_>>(),
150 );
151 BooleanArray::new(values, nulls)
152 } else {
153 BooleanArray::from_unary(array, |haystack| {
154 starts_with(haystack, v, equals_ignore_ascii_case_kernel) != negate
155 })
156 }
157 }
158 Predicate::EndsWith(v) => {
159 if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
160 let nulls = string_view_array.logical_nulls();
161 let values = BooleanBuffer::from(
162 string_view_array
163 .suffix_bytes_iter(v.len())
164 .map(|haystack| {
165 equals_bytes(haystack, v.as_bytes(), equals_kernel) != negate
166 })
167 .collect::<Vec<_>>(),
168 );
169 BooleanArray::new(values, nulls)
170 } else {
171 BooleanArray::from_unary(array, |haystack| {
172 ends_with(haystack, v, equals_kernel) != negate
173 })
174 }
175 }
176 Predicate::IEndsWithAscii(v) => {
177 if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
178 let nulls = string_view_array.logical_nulls();
179 let values = BooleanBuffer::from(
180 string_view_array
181 .suffix_bytes_iter(v.len())
182 .map(|haystack| {
183 equals_bytes(
184 haystack,
185 v.as_bytes(),
186 equals_ignore_ascii_case_kernel,
187 ) != negate
188 })
189 .collect::<Vec<_>>(),
190 );
191 BooleanArray::new(values, nulls)
192 } else {
193 BooleanArray::from_unary(array, |haystack| {
194 ends_with(haystack, v, equals_ignore_ascii_case_kernel) != negate
195 })
196 }
197 }
198 Predicate::Regex(v) => {
199 BooleanArray::from_unary(array, |haystack| v.is_match(haystack) != negate)
200 }
201 }
202 }
203}
204
/// Returns true if `lhs` and `rhs` have the same length and `byte_eq_kernel` accepts every
/// pair of bytes found at the same position.
fn equals_bytes(lhs: &[u8], rhs: &[u8], byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    if lhs.len() != rhs.len() {
        return false;
    }
    zip(lhs, rhs).all(byte_eq_kernel)
}
208
/// Returns true if the leading bytes of `haystack` match `needle` according to
/// `byte_eq_kernel`, which receives `(haystack byte, needle byte)` pairs.
///
/// A `needle` longer than `haystack` never matches; an empty `needle` always matches.
fn starts_with(haystack: &str, needle: &str, byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    // `get` yields `None` exactly when `needle` is longer than `haystack`
    match haystack.as_bytes().get(..needle.len()) {
        Some(head) => zip(head, needle.as_bytes()).all(byte_eq_kernel),
        None => false,
    }
}
/// Returns true if the trailing bytes of `haystack` match `needle` according to
/// `byte_eq_kernel`, which receives `(haystack byte, needle byte)` pairs.
///
/// Bytes are compared from the end towards the start. A `needle` longer than `haystack`
/// never matches; an empty `needle` always matches.
fn ends_with(haystack: &str, needle: &str, byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    let hay = haystack.as_bytes();
    let pat = needle.as_bytes();
    // `checked_sub` yields `None` exactly when `needle` is longer than `haystack`
    match hay.len().checked_sub(pat.len()) {
        Some(tail_start) => zip(hay[tail_start..].iter().rev(), pat.iter().rev()).all(byte_eq_kernel),
        None => false,
    }
}
231
/// Byte comparison kernel: true if both bytes are identical.
fn equals_kernel((lhs, rhs): (&u8, &u8)) -> bool {
    *lhs == *rhs
}
235
/// Byte comparison kernel: true if both bytes are identical once ASCII letters are
/// lowercased. Non-ASCII bytes only match themselves.
fn equals_ignore_ascii_case_kernel((lhs, rhs): (&u8, &u8)) -> bool {
    lhs.to_ascii_lowercase() == rhs.to_ascii_lowercase()
}
239
240fn regex_like(pattern: &str, case_insensitive: bool) -> Result<Regex, ArrowError> {
248 let mut result = String::with_capacity(pattern.len() * 2);
249 let mut chars_iter = pattern.chars().peekable();
250 match chars_iter.peek() {
251 Some('%') => {
253 chars_iter.next();
254 }
255 _ => result.push('^'),
256 };
257
258 while let Some(c) = chars_iter.next() {
259 match c {
260 '\\' => {
261 match chars_iter.peek() {
262 Some(&next) => {
263 if regex_syntax::is_meta_character(next) {
264 result.push('\\');
265 }
266 result.push(next);
267 chars_iter.next();
269 }
270 None => {
271 result.push('\\');
273 result.push('\\');
274 }
275 }
276 }
277 '%' => result.push_str(".*"),
278 '_' => result.push('.'),
279 c => {
280 if regex_syntax::is_meta_character(c) {
281 result.push('\\');
282 }
283 result.push(c);
284 }
285 }
286 }
287 if result.ends_with(".*") {
289 result.pop();
290 result.pop();
291 } else {
292 result.push('$');
293 }
294 RegexBuilder::new(&result)
295 .case_insensitive(case_insensitive)
296 .dot_matches_new_line(true)
297 .build()
298 .map_err(|e| {
299 ArrowError::InvalidArgumentError(format!(
300 "Unable to build regex from LIKE pattern: {e}"
301 ))
302 })
303}
304
305fn contains_like_pattern(pattern: &str) -> bool {
306 memchr3(b'%', b'_', b'\\', pattern.as_bytes()).is_some()
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a predicate from every needle with `make` and asserts that evaluating it
    /// against the paired haystack yields `expected`.
    fn check<'a>(
        make: impl Fn(&'a str) -> Predicate<'a>,
        cases: &[(&'a str, &str)],
        expected: bool,
    ) {
        for &(needle, haystack) in cases {
            assert_eq!(
                make(needle).evaluate(haystack),
                expected,
                "needle={needle:?} haystack={haystack:?}"
            );
        }
    }

    #[test]
    fn test_regex_like() {
        // (LIKE pattern, expected regex source)
        let test_cases = [
            // Leading and trailing `%` leave the regex unanchored
            (r"%foobar%", r"foobar"),
            // `%` in the middle
            (r"foo%bar", r"^foo.*bar$"),
            // `_` wildcard
            (r"foo_bar", r"^foo.bar$"),
            // Escaped wildcards
            (r"\%\_", r"^%_$"),
            // Escaped character that is not special
            (r"\a", r"^a$"),
            // Escaped backslash followed by a wildcard
            (r"\\%", r"^\\"),
            // Escaped backslash followed by a literal
            (r"\\a", r"^\\a$"),
            // Regex meta characters are escaped
            (r".", r"^\.$"),
            (r"$", r"^\$$"),
            (r"\\", r"^\\$"),
        ];

        for (like_pattern, expected_regexp) in test_cases {
            let regex = regex_like(like_pattern, false).unwrap();
            assert_eq!(regex.as_str(), expected_regexp);
        }
    }

    #[test]
    fn test_contains() {
        let matching = [
            ("hay", "haystack"),
            ("haystack", "haystack"),
            ("h", "haystack"),
            ("k", "haystack"),
            ("stack", "haystack"),
            ("sta", "haystack"),
            ("stack", "hay£stack"),
            ("y£s", "hay£stack"),
            ("£", "hay£stack"),
            ("a", "a"),
        ];
        let not_matching = [
            ("hy", "haystack"),
            ("stackx", "haystack"),
            ("x", "haystack"),
            ("haystack haystack", "haystack"),
        ];

        check(Predicate::contains, &matching, true);
        check(Predicate::contains, &not_matching, false);
    }

    #[test]
    fn test_starts_with() {
        let matching = [
            ("hay", "haystack"),
            ("h£ay", "h£aystack"),
            ("haystack", "haystack"),
            ("ha", "haystack"),
            ("h", "haystack"),
            ("", "haystack"),
        ];
        let not_matching = [
            ("stack", "haystack"),
            ("haystacks", "haystack"),
            ("HAY", "haystack"),
            ("h£ay", "haystack"),
            ("hay", "h£aystack"),
        ];

        check(Predicate::StartsWith, &matching, true);
        check(Predicate::StartsWith, &not_matching, false);
    }

    #[test]
    fn test_ends_with() {
        let matching = [
            ("stack", "haystack"),
            ("st£ack", "hayst£ack"),
            ("haystack", "haystack"),
            ("ck", "haystack"),
            ("k", "haystack"),
            ("", "haystack"),
        ];
        let not_matching = [
            ("hay", "haystack"),
            ("STACK", "haystack"),
            ("haystacks", "haystack"),
            ("xhaystack", "haystack"),
            ("st£ack", "haystack"),
            ("stack", "hayst£ack"),
        ];

        check(Predicate::EndsWith, &matching, true);
        check(Predicate::EndsWith, &not_matching, false);
    }

    #[test]
    fn test_istarts_with() {
        let matching = [
            ("hay", "haystack"),
            ("hay", "HAYSTACK"),
            ("HAY", "haystack"),
            ("HaY", "haystack"),
            ("hay", "HaYsTaCk"),
            ("HAY", "HaYsTaCk"),
            ("haystack", "HaYsTaCk"),
            ("HaYsTaCk", "HaYsTaCk"),
            ("", "HaYsTaCk"),
        ];
        let not_matching = [
            ("stack", "haystack"),
            ("haystacks", "haystack"),
            ("h.ay", "haystack"),
            ("hay", "h£aystack"),
        ];

        check(Predicate::IStartsWithAscii, &matching, true);
        check(Predicate::IStartsWithAscii, &not_matching, false);
    }

    #[test]
    fn test_iends_with() {
        let matching = [
            ("stack", "haystack"),
            ("STACK", "haystack"),
            ("StAcK", "haystack"),
            ("stack", "HAYSTACK"),
            ("STACK", "HAYSTACK"),
            ("StAcK", "HAYSTACK"),
            ("stack", "HAYsTaCk"),
            ("STACK", "HAYsTaCk"),
            ("StAcK", "HAYsTaCk"),
            ("haystack", "haystack"),
            ("HAYSTACK", "haystack"),
            ("haystack", "HAYSTACK"),
            ("ck", "haystack"),
            ("cK", "haystack"),
            ("ck", "haystacK"),
            ("", "haystack"),
        ];
        let not_matching = [
            ("hay", "haystack"),
            ("stac", "HAYSTACK"),
            ("haystacks", "haystack"),
            ("stack", "haystac£k"),
            ("xhaystack", "haystack"),
        ];

        check(Predicate::IEndsWithAscii, &matching, true);
        check(Predicate::IEndsWithAscii, &not_matching, false);
    }
}
437}