1use arrow_array::{Array, ArrayAccessor, BooleanArray, StringViewArray};
19use arrow_buffer::BooleanBuffer;
20use arrow_schema::ArrowError;
21use memchr::memchr3;
22use memchr::memmem::Finder;
23use regex::{Regex, RegexBuilder};
24use std::iter::zip;
25
/// A pre-analysed string matching rule that can be applied to a single
/// haystack (`evaluate`) or to every element of an array (`evaluate_array`).
///
/// Variants holding a `&'a str` borrow the needle for the lifetime `'a`.
pub(crate) enum Predicate<'a> {
    /// The haystack must be exactly equal to the string (case sensitive).
    Eq(&'a str),
    /// The haystack must contain the needle the `Finder` was built from.
    Contains(Finder<'a>),
    /// The haystack must begin with the string (byte-wise, case sensitive).
    StartsWith(&'a str),
    /// The haystack must end with the string (byte-wise, case sensitive).
    EndsWith(&'a str),

    /// Equality ignoring ASCII case (`str::eq_ignore_ascii_case`).
    IEqAscii(&'a str),
    /// Prefix match where bytes are compared ignoring ASCII case.
    IStartsWithAscii(&'a str),
    /// Suffix match where bytes are compared ignoring ASCII case.
    IEndsWithAscii(&'a str),

    /// Fallback: the haystack must match the compiled regular expression.
    Regex(Regex),
}
42
43impl<'a> Predicate<'a> {
44 pub(crate) fn like(pattern: &'a str) -> Result<Self, ArrowError> {
46 if !contains_like_pattern(pattern) {
47 Ok(Self::Eq(pattern))
48 } else if pattern.ends_with('%') && !contains_like_pattern(&pattern[..pattern.len() - 1]) {
49 Ok(Self::StartsWith(&pattern[..pattern.len() - 1]))
50 } else if pattern.starts_with('%') && !contains_like_pattern(&pattern[1..]) {
51 Ok(Self::EndsWith(&pattern[1..]))
52 } else if pattern.starts_with('%')
53 && pattern.ends_with('%')
54 && !contains_like_pattern(&pattern[1..pattern.len() - 1])
55 {
56 Ok(Self::contains(&pattern[1..pattern.len() - 1]))
57 } else {
58 Ok(Self::Regex(regex_like(pattern, false)?))
59 }
60 }
61
62 pub(crate) fn contains(needle: &'a str) -> Self {
63 Self::Contains(Finder::new(needle.as_bytes()))
64 }
65
66 pub(crate) fn ilike(pattern: &'a str, is_ascii: bool) -> Result<Self, ArrowError> {
68 if is_ascii && pattern.is_ascii() {
69 if !contains_like_pattern(pattern) {
70 return Ok(Self::IEqAscii(pattern));
71 } else if pattern.ends_with('%')
72 && !pattern.ends_with("\\%")
73 && !contains_like_pattern(&pattern[..pattern.len() - 1])
74 {
75 return Ok(Self::IStartsWithAscii(&pattern[..pattern.len() - 1]));
76 } else if pattern.starts_with('%') && !contains_like_pattern(&pattern[1..]) {
77 return Ok(Self::IEndsWithAscii(&pattern[1..]));
78 }
79 }
80 Ok(Self::Regex(regex_like(pattern, true)?))
81 }
82
83 pub(crate) fn evaluate(&self, haystack: &str) -> bool {
85 match self {
86 Predicate::Eq(v) => *v == haystack,
87 Predicate::IEqAscii(v) => haystack.eq_ignore_ascii_case(v),
88 Predicate::Contains(finder) => finder.find(haystack.as_bytes()).is_some(),
89 Predicate::StartsWith(v) => starts_with(haystack, v, equals_kernel),
90 Predicate::IStartsWithAscii(v) => {
91 starts_with(haystack, v, equals_ignore_ascii_case_kernel)
92 }
93 Predicate::EndsWith(v) => ends_with(haystack, v, equals_kernel),
94 Predicate::IEndsWithAscii(v) => ends_with(haystack, v, equals_ignore_ascii_case_kernel),
95 Predicate::Regex(v) => v.is_match(haystack),
96 }
97 }
98
99 #[inline(never)]
103 pub(crate) fn evaluate_array<'i, T>(&self, array: T, negate: bool) -> BooleanArray
104 where
105 T: ArrayAccessor<Item = &'i str>,
106 {
107 match self {
108 Predicate::Eq(v) => BooleanArray::from_unary(array, |haystack| {
109 (haystack.len() == v.len() && haystack == *v) != negate
110 }),
111 Predicate::IEqAscii(v) => BooleanArray::from_unary(array, |haystack| {
112 haystack.eq_ignore_ascii_case(v) != negate
113 }),
114 Predicate::Contains(finder) => BooleanArray::from_unary(array, |haystack| {
115 finder.find(haystack.as_bytes()).is_some() != negate
116 }),
117 Predicate::StartsWith(v) => {
118 if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
119 let nulls = string_view_array.logical_nulls();
120 let values = BooleanBuffer::from(
121 string_view_array
122 .prefix_bytes_iter(v.len())
123 .map(|haystack| {
124 equals_bytes(haystack, v.as_bytes(), equals_kernel) != negate
125 })
126 .collect::<Vec<_>>(),
127 );
128 BooleanArray::new(values, nulls)
129 } else {
130 BooleanArray::from_unary(array, |haystack| {
131 starts_with(haystack, v, equals_kernel) != negate
132 })
133 }
134 }
135 Predicate::IStartsWithAscii(v) => {
136 if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
137 let nulls = string_view_array.logical_nulls();
138 let values = BooleanBuffer::from(
139 string_view_array
140 .prefix_bytes_iter(v.len())
141 .map(|haystack| {
142 equals_bytes(
143 haystack,
144 v.as_bytes(),
145 equals_ignore_ascii_case_kernel,
146 ) != negate
147 })
148 .collect::<Vec<_>>(),
149 );
150 BooleanArray::new(values, nulls)
151 } else {
152 BooleanArray::from_unary(array, |haystack| {
153 starts_with(haystack, v, equals_ignore_ascii_case_kernel) != negate
154 })
155 }
156 }
157 Predicate::EndsWith(v) => {
158 if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
159 let nulls = string_view_array.logical_nulls();
160 let values = BooleanBuffer::from(
161 string_view_array
162 .suffix_bytes_iter(v.len())
163 .map(|haystack| {
164 equals_bytes(haystack, v.as_bytes(), equals_kernel) != negate
165 })
166 .collect::<Vec<_>>(),
167 );
168 BooleanArray::new(values, nulls)
169 } else {
170 BooleanArray::from_unary(array, |haystack| {
171 ends_with(haystack, v, equals_kernel) != negate
172 })
173 }
174 }
175 Predicate::IEndsWithAscii(v) => {
176 if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
177 let nulls = string_view_array.logical_nulls();
178 let values = BooleanBuffer::from(
179 string_view_array
180 .suffix_bytes_iter(v.len())
181 .map(|haystack| {
182 equals_bytes(
183 haystack,
184 v.as_bytes(),
185 equals_ignore_ascii_case_kernel,
186 ) != negate
187 })
188 .collect::<Vec<_>>(),
189 );
190 BooleanArray::new(values, nulls)
191 } else {
192 BooleanArray::from_unary(array, |haystack| {
193 ends_with(haystack, v, equals_ignore_ascii_case_kernel) != negate
194 })
195 }
196 }
197 Predicate::Regex(v) => {
198 BooleanArray::from_unary(array, |haystack| v.is_match(haystack) != negate)
199 }
200 }
201 }
202}
203
/// Returns `true` when `lhs` and `rhs` have the same length and
/// `byte_eq_kernel` accepts every pair of bytes at the same position.
fn equals_bytes(lhs: &[u8], rhs: &[u8], byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    // Slices of different length can never be equal; skip the byte walk.
    if lhs.len() != rhs.len() {
        return false;
    }
    zip(lhs, rhs).all(byte_eq_kernel)
}
207
/// Returns `true` when the leading bytes of `haystack` match all bytes of
/// `needle` according to `byte_eq_kernel`.
///
/// An empty `needle` always matches.
fn starts_with(haystack: &str, needle: &str, byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    // `zip` stops at the shorter input, so a needle longer than the haystack
    // has to be rejected explicitly before comparing.
    haystack.len() >= needle.len()
        && zip(haystack.as_bytes(), needle.as_bytes()).all(byte_eq_kernel)
}
/// Returns `true` when the trailing bytes of `haystack` match all bytes of
/// `needle` according to `byte_eq_kernel`.
///
/// Bytes are compared from the end towards the start. An empty `needle`
/// always matches.
fn ends_with(haystack: &str, needle: &str, byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    // A needle longer than the haystack cannot be a suffix of it.
    if haystack.len() < needle.len() {
        return false;
    }
    let haystack_tail = haystack.as_bytes().iter().rev();
    let needle_tail = needle.as_bytes().iter().rev();
    zip(haystack_tail, needle_tail).all(byte_eq_kernel)
}
230
/// Byte comparison kernel: the two bytes must be identical.
fn equals_kernel(pair: (&u8, &u8)) -> bool {
    let (lhs, rhs) = pair;
    *lhs == *rhs
}
234
/// Byte comparison kernel: the two bytes must be equal once ASCII letters
/// are folded to lower case. Non-ASCII bytes are compared unchanged.
fn equals_ignore_ascii_case_kernel(pair: (&u8, &u8)) -> bool {
    let (lhs, rhs) = pair;
    lhs.to_ascii_lowercase() == rhs.to_ascii_lowercase()
}
238
239fn regex_like(pattern: &str, case_insensitive: bool) -> Result<Regex, ArrowError> {
247 let mut result = String::with_capacity(pattern.len() * 2);
248 let mut chars_iter = pattern.chars().peekable();
249 match chars_iter.peek() {
250 Some('%') => {
252 chars_iter.next();
253 }
254 _ => result.push('^'),
255 };
256
257 while let Some(c) = chars_iter.next() {
258 match c {
259 '\\' => {
260 match chars_iter.peek() {
261 Some(&next) => {
262 if regex_syntax::is_meta_character(next) {
263 result.push('\\');
264 }
265 result.push(next);
266 chars_iter.next();
268 }
269 None => {
270 result.push('\\');
272 result.push('\\');
273 }
274 }
275 }
276 '%' => result.push_str(".*"),
277 '_' => result.push('.'),
278 c => {
279 if regex_syntax::is_meta_character(c) {
280 result.push('\\');
281 }
282 result.push(c);
283 }
284 }
285 }
286 if result.ends_with(".*") {
288 result.pop();
289 result.pop();
290 } else {
291 result.push('$');
292 }
293 RegexBuilder::new(&result)
294 .case_insensitive(case_insensitive)
295 .dot_matches_new_line(true)
296 .build()
297 .map_err(|e| {
298 ArrowError::InvalidArgumentError(format!(
299 "Unable to build regex from LIKE pattern: {e}"
300 ))
301 })
302}
303
304fn contains_like_pattern(pattern: &str) -> bool {
305 memchr3(b'%', b'_', b'\\', pattern.as_bytes()).is_some()
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a predicate from each needle with `build` and asserts that
    /// evaluating it against the paired haystack yields `expected`.
    fn check<'a>(
        build: impl Fn(&'a str) -> Predicate<'a>,
        cases: &[(&'a str, &'a str)],
        expected: bool,
    ) {
        for &(needle, haystack) in cases {
            assert_eq!(
                build(needle).evaluate(haystack),
                expected,
                "needle {needle:?} against haystack {haystack:?}"
            );
        }
    }

    #[test]
    fn test_regex_like() {
        // (LIKE pattern, expected regex source)
        let test_cases = [
            // leading and trailing `%` leave the regex unanchored
            (r"%foobar%", r"foobar"),
            // `%` in the middle becomes `.*`
            (r"foo%bar", r"^foo.*bar$"),
            // `_` becomes `.`
            (r"foo_bar", r"^foo.bar$"),
            // escaped wildcards are literals
            (r"\%\_", r"^%_$"),
            // escaping an ordinary character yields that character
            (r"\a", r"^a$"),
            // escaped backslash followed by a wildcard
            (r"\\%", r"^\\"),
            // escaped backslash followed by an ordinary character
            (r"\\a", r"^\\a$"),
            // regex meta characters are escaped
            (r".", r"^\.$"),
            (r"$", r"^\$$"),
            (r"\\", r"^\\$"),
        ];

        for (like_pattern, expected_regexp) in test_cases {
            let compiled = regex_like(like_pattern, false).unwrap();
            assert_eq!(
                compiled.as_str(),
                expected_regexp,
                "LIKE pattern {like_pattern:?}"
            );
        }
    }

    #[test]
    fn test_contains() {
        check(
            Predicate::contains,
            &[
                ("hay", "haystack"),
                ("haystack", "haystack"),
                ("h", "haystack"),
                ("k", "haystack"),
                ("stack", "haystack"),
                ("sta", "haystack"),
                ("stack", "hay£stack"),
                ("y£s", "hay£stack"),
                ("£", "hay£stack"),
                ("a", "a"),
            ],
            true,
        );
        check(
            Predicate::contains,
            &[
                ("hy", "haystack"),
                ("stackx", "haystack"),
                ("x", "haystack"),
                ("haystack haystack", "haystack"),
            ],
            false,
        );
    }

    #[test]
    fn test_starts_with() {
        check(
            Predicate::StartsWith,
            &[
                ("hay", "haystack"),
                ("h£ay", "h£aystack"),
                ("haystack", "haystack"),
                ("ha", "haystack"),
                ("h", "haystack"),
                ("", "haystack"),
            ],
            true,
        );
        check(
            Predicate::StartsWith,
            &[
                ("stack", "haystack"),
                ("haystacks", "haystack"),
                ("HAY", "haystack"),
                ("h£ay", "haystack"),
                ("hay", "h£aystack"),
            ],
            false,
        );
    }

    #[test]
    fn test_ends_with() {
        check(
            Predicate::EndsWith,
            &[
                ("stack", "haystack"),
                ("st£ack", "hayst£ack"),
                ("haystack", "haystack"),
                ("ck", "haystack"),
                ("k", "haystack"),
                ("", "haystack"),
            ],
            true,
        );
        check(
            Predicate::EndsWith,
            &[
                ("hay", "haystack"),
                ("STACK", "haystack"),
                ("haystacks", "haystack"),
                ("xhaystack", "haystack"),
                ("st£ack", "haystack"),
                ("stack", "hayst£ack"),
            ],
            false,
        );
    }

    #[test]
    fn test_istarts_with() {
        check(
            Predicate::IStartsWithAscii,
            &[
                ("hay", "haystack"),
                ("hay", "HAYSTACK"),
                ("HAY", "haystack"),
                ("HaY", "haystack"),
                ("hay", "HaYsTaCk"),
                ("HAY", "HaYsTaCk"),
                ("haystack", "HaYsTaCk"),
                ("HaYsTaCk", "HaYsTaCk"),
                ("", "HaYsTaCk"),
            ],
            true,
        );
        check(
            Predicate::IStartsWithAscii,
            &[
                ("stack", "haystack"),
                ("haystacks", "haystack"),
                ("h.ay", "haystack"),
                ("hay", "h£aystack"),
            ],
            false,
        );
    }

    #[test]
    fn test_iends_with() {
        check(
            Predicate::IEndsWithAscii,
            &[
                ("stack", "haystack"),
                ("STACK", "haystack"),
                ("StAcK", "haystack"),
                ("stack", "HAYSTACK"),
                ("STACK", "HAYSTACK"),
                ("StAcK", "HAYSTACK"),
                ("stack", "HAYsTaCk"),
                ("STACK", "HAYsTaCk"),
                ("StAcK", "HAYsTaCk"),
                ("haystack", "haystack"),
                ("HAYSTACK", "haystack"),
                ("haystack", "HAYSTACK"),
                ("ck", "haystack"),
                ("cK", "haystack"),
                ("ck", "haystacK"),
                ("", "haystack"),
            ],
            true,
        );
        check(
            Predicate::IEndsWithAscii,
            &[
                ("hay", "haystack"),
                ("stac", "HAYSTACK"),
                ("haystacks", "haystack"),
                ("stack", "haystac£k"),
                ("xhaystack", "haystack"),
            ],
            false,
        );
    }
}