arrow_string/
predicate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow_array::{Array, ArrayAccessor, BooleanArray, StringViewArray};
19use arrow_buffer::BooleanBuffer;
20use arrow_schema::ArrowError;
21use memchr::memchr3;
22use memchr::memmem::Finder;
23use regex::{Regex, RegexBuilder};
24use std::iter::zip;
25
26/// A string based predicate
27pub(crate) enum Predicate<'a> {
28    Eq(&'a str),
29    Contains(Finder<'a>),
30    StartsWith(&'a str),
31    EndsWith(&'a str),
32
33    /// Equality ignoring ASCII case
34    IEqAscii(&'a str),
35    /// Starts with ignoring ASCII case
36    IStartsWithAscii(&'a str),
37    /// Ends with ignoring ASCII case
38    IEndsWithAscii(&'a str),
39
40    Regex(Regex),
41}
42
43impl<'a> Predicate<'a> {
44    /// Create a predicate for the given like pattern
45    pub(crate) fn like(pattern: &'a str) -> Result<Self, ArrowError> {
46        if !contains_like_pattern(pattern) {
47            Ok(Self::Eq(pattern))
48        } else if pattern.ends_with('%') && !contains_like_pattern(&pattern[..pattern.len() - 1]) {
49            Ok(Self::StartsWith(&pattern[..pattern.len() - 1]))
50        } else if pattern.starts_with('%') && !contains_like_pattern(&pattern[1..]) {
51            Ok(Self::EndsWith(&pattern[1..]))
52        } else if pattern.starts_with('%')
53            && pattern.ends_with('%')
54            && !contains_like_pattern(&pattern[1..pattern.len() - 1])
55        {
56            Ok(Self::contains(&pattern[1..pattern.len() - 1]))
57        } else {
58            Ok(Self::Regex(regex_like(pattern, false)?))
59        }
60    }
61
62    pub(crate) fn contains(needle: &'a str) -> Self {
63        Self::Contains(Finder::new(needle.as_bytes()))
64    }
65
66    /// Create a predicate for the given ilike pattern
67    pub(crate) fn ilike(pattern: &'a str, is_ascii: bool) -> Result<Self, ArrowError> {
68        if is_ascii && pattern.is_ascii() {
69            if !contains_like_pattern(pattern) {
70                return Ok(Self::IEqAscii(pattern));
71            } else if pattern.ends_with('%')
72                && !pattern.ends_with("\\%")
73                && !contains_like_pattern(&pattern[..pattern.len() - 1])
74            {
75                return Ok(Self::IStartsWithAscii(&pattern[..pattern.len() - 1]));
76            } else if pattern.starts_with('%') && !contains_like_pattern(&pattern[1..]) {
77                return Ok(Self::IEndsWithAscii(&pattern[1..]));
78            }
79        }
80        Ok(Self::Regex(regex_like(pattern, true)?))
81    }
82
83    /// Evaluate this predicate against the given haystack
84    pub(crate) fn evaluate(&self, haystack: &str) -> bool {
85        match self {
86            Predicate::Eq(v) => *v == haystack,
87            Predicate::IEqAscii(v) => haystack.eq_ignore_ascii_case(v),
88            Predicate::Contains(finder) => finder.find(haystack.as_bytes()).is_some(),
89            Predicate::StartsWith(v) => starts_with(haystack, v, equals_kernel),
90            Predicate::IStartsWithAscii(v) => {
91                starts_with(haystack, v, equals_ignore_ascii_case_kernel)
92            }
93            Predicate::EndsWith(v) => ends_with(haystack, v, equals_kernel),
94            Predicate::IEndsWithAscii(v) => ends_with(haystack, v, equals_ignore_ascii_case_kernel),
95            Predicate::Regex(v) => v.is_match(haystack),
96        }
97    }
98
99    /// Evaluate this predicate against the elements of `array`
100    ///
101    /// If `negate` is true the result of the predicate will be negated
102    #[inline(never)]
103    pub(crate) fn evaluate_array<'i, T>(&self, array: T, negate: bool) -> BooleanArray
104    where
105        T: ArrayAccessor<Item = &'i str>,
106    {
107        match self {
108            Predicate::Eq(v) => BooleanArray::from_unary(array, |haystack| {
109                (haystack.len() == v.len() && haystack == *v) != negate
110            }),
111            Predicate::IEqAscii(v) => BooleanArray::from_unary(array, |haystack| {
112                haystack.eq_ignore_ascii_case(v) != negate
113            }),
114            Predicate::Contains(finder) => BooleanArray::from_unary(array, |haystack| {
115                finder.find(haystack.as_bytes()).is_some() != negate
116            }),
117            Predicate::StartsWith(v) => {
118                if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
119                    let nulls = string_view_array.logical_nulls();
120                    let values = BooleanBuffer::from(
121                        string_view_array
122                            .prefix_bytes_iter(v.len())
123                            .map(|haystack| {
124                                equals_bytes(haystack, v.as_bytes(), equals_kernel) != negate
125                            })
126                            .collect::<Vec<_>>(),
127                    );
128                    BooleanArray::new(values, nulls)
129                } else {
130                    BooleanArray::from_unary(array, |haystack| {
131                        starts_with(haystack, v, equals_kernel) != negate
132                    })
133                }
134            }
135            Predicate::IStartsWithAscii(v) => {
136                if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
137                    let nulls = string_view_array.logical_nulls();
138                    let values = BooleanBuffer::from(
139                        string_view_array
140                            .prefix_bytes_iter(v.len())
141                            .map(|haystack| {
142                                equals_bytes(
143                                    haystack,
144                                    v.as_bytes(),
145                                    equals_ignore_ascii_case_kernel,
146                                ) != negate
147                            })
148                            .collect::<Vec<_>>(),
149                    );
150                    BooleanArray::new(values, nulls)
151                } else {
152                    BooleanArray::from_unary(array, |haystack| {
153                        starts_with(haystack, v, equals_ignore_ascii_case_kernel) != negate
154                    })
155                }
156            }
157            Predicate::EndsWith(v) => {
158                if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
159                    let nulls = string_view_array.logical_nulls();
160                    let values = BooleanBuffer::from(
161                        string_view_array
162                            .suffix_bytes_iter(v.len())
163                            .map(|haystack| {
164                                equals_bytes(haystack, v.as_bytes(), equals_kernel) != negate
165                            })
166                            .collect::<Vec<_>>(),
167                    );
168                    BooleanArray::new(values, nulls)
169                } else {
170                    BooleanArray::from_unary(array, |haystack| {
171                        ends_with(haystack, v, equals_kernel) != negate
172                    })
173                }
174            }
175            Predicate::IEndsWithAscii(v) => {
176                if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
177                    let nulls = string_view_array.logical_nulls();
178                    let values = BooleanBuffer::from(
179                        string_view_array
180                            .suffix_bytes_iter(v.len())
181                            .map(|haystack| {
182                                equals_bytes(
183                                    haystack,
184                                    v.as_bytes(),
185                                    equals_ignore_ascii_case_kernel,
186                                ) != negate
187                            })
188                            .collect::<Vec<_>>(),
189                    );
190                    BooleanArray::new(values, nulls)
191                } else {
192                    BooleanArray::from_unary(array, |haystack| {
193                        ends_with(haystack, v, equals_ignore_ascii_case_kernel) != negate
194                    })
195                }
196            }
197            Predicate::Regex(v) => {
198                BooleanArray::from_unary(array, |haystack| v.is_match(haystack) != negate)
199            }
200        }
201    }
202}
203
/// Returns `true` when `lhs` and `rhs` have the same length and every pair of
/// bytes at matching positions is accepted by `byte_eq_kernel`.
fn equals_bytes(lhs: &[u8], rhs: &[u8], byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    if lhs.len() != rhs.len() {
        return false;
    }
    lhs.iter().zip(rhs.iter()).all(byte_eq_kernel)
}
207
/// Returns whether `haystack` begins with `needle`, comparing bytes pairwise
/// (haystack byte first) with `byte_eq_kernel`.
///
/// This is faster than `str::starts_with` for small strings.
/// See <https://github.com/apache/arrow-rs/issues/6107> for more details.
fn starts_with(haystack: &str, needle: &str, byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    let (hay, pin) = (haystack.as_bytes(), needle.as_bytes());
    hay.len() >= pin.len() && hay.iter().zip(pin.iter()).all(byte_eq_kernel)
}
/// Returns whether `haystack` finishes with `needle`, walking both strings
/// backwards and comparing bytes pairwise (haystack byte first) with
/// `byte_eq_kernel`.
///
/// This is faster than `str::ends_with` for small strings.
/// See <https://github.com/apache/arrow-rs/issues/6107> for more details.
fn ends_with(haystack: &str, needle: &str, byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    let (hay, pin) = (haystack.as_bytes(), needle.as_bytes());
    hay.len() >= pin.len() && hay.iter().rev().zip(pin.iter().rev()).all(byte_eq_kernel)
}
230
/// Byte comparison kernel for exact (case-sensitive) matching.
fn equals_kernel((lhs, rhs): (&u8, &u8)) -> bool {
    *lhs == *rhs
}
234
/// Byte comparison kernel that treats ASCII letters of either case as equal;
/// all other bytes must match exactly.
fn equals_ignore_ascii_case_kernel((lhs, rhs): (&u8, &u8)) -> bool {
    u8::eq_ignore_ascii_case(lhs, rhs)
}
238
239/// Transforms a like `pattern` to a regex compatible pattern. To achieve that, it does:
240///
241/// 1. Replace `LIKE` multi-character wildcards `%` => `.*` (unless they're at the start or end of the pattern,
242///    where the regex is just truncated - e.g. `%foo%` => `foo` rather than `^.*foo.*$`)
243/// 2. Replace `LIKE` single-character wildcards `_` => `.`
244/// 3. Escape regex meta characters to match them and not be evaluated as regex special chars. e.g. `.` => `\\.`
245/// 4. Replace escaped `LIKE` wildcards removing the escape characters to be able to match it as a regex. e.g. `\\%` => `%`
246fn regex_like(pattern: &str, case_insensitive: bool) -> Result<Regex, ArrowError> {
247    let mut result = String::with_capacity(pattern.len() * 2);
248    let mut chars_iter = pattern.chars().peekable();
249    match chars_iter.peek() {
250        // if the pattern starts with `%`, we avoid starting the regex with a slow but meaningless `^.*`
251        Some('%') => {
252            chars_iter.next();
253        }
254        _ => result.push('^'),
255    };
256
257    while let Some(c) = chars_iter.next() {
258        match c {
259            '\\' => {
260                match chars_iter.peek() {
261                    Some(&next) => {
262                        if regex_syntax::is_meta_character(next) {
263                            result.push('\\');
264                        }
265                        result.push(next);
266                        // Skipping the next char as it is already appended
267                        chars_iter.next();
268                    }
269                    None => {
270                        // Trailing backslash in the pattern. E.g. PostgreSQL and Trino treat it as an error, but e.g. Snowflake treats it as a literal backslash
271                        result.push('\\');
272                        result.push('\\');
273                    }
274                }
275            }
276            '%' => result.push_str(".*"),
277            '_' => result.push('.'),
278            c => {
279                if regex_syntax::is_meta_character(c) {
280                    result.push('\\');
281                }
282                result.push(c);
283            }
284        }
285    }
286    // instead of ending the regex with `.*$` and making it needlessly slow, we just end the regex
287    if result.ends_with(".*") {
288        result.pop();
289        result.pop();
290    } else {
291        result.push('$');
292    }
293    RegexBuilder::new(&result)
294        .case_insensitive(case_insensitive)
295        .dot_matches_new_line(true)
296        .build()
297        .map_err(|e| {
298            ArrowError::InvalidArgumentError(format!(
299                "Unable to build regex from LIKE pattern: {e}"
300            ))
301        })
302}
303
304fn contains_like_pattern(pattern: &str) -> bool {
305    memchr3(b'%', b'_', b'\\', pattern.as_bytes()).is_some()
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_regex_like() {
        let test_cases = [
            // %..%
            (r"%foobar%", r"foobar"),
            // ..%..
            (r"foo%bar", r"^foo.*bar$"),
            // .._..
            (r"foo_bar", r"^foo.bar$"),
            // escaped wildcards
            (r"\%\_", r"^%_$"),
            // escaped non-wildcard
            (r"\a", r"^a$"),
            // escaped escape and wildcard
            (r"\\%", r"^\\"),
            // escaped escape and non-wildcard
            (r"\\a", r"^\\a$"),
            // regex meta character
            (r".", r"^\.$"),
            (r"$", r"^\$$"),
            (r"\\", r"^\\$"),
        ];

        for (like_pattern, expected_regexp) in test_cases {
            let r = regex_like(like_pattern, false).unwrap();
            assert_eq!(r.to_string(), expected_regexp);
        }
    }

    #[test]
    fn test_like_dispatch() {
        // Patterns are routed to the cheapest variant able to express them
        assert!(matches!(Predicate::like("foo").unwrap(), Predicate::Eq("foo")));
        assert!(matches!(Predicate::like("%").unwrap(), Predicate::StartsWith("")));
        assert!(matches!(Predicate::like("foo%").unwrap(), Predicate::StartsWith("foo")));
        assert!(matches!(Predicate::like("%foo").unwrap(), Predicate::EndsWith("foo")));
        assert!(matches!(Predicate::like("%foo%").unwrap(), Predicate::Contains(_)));
        assert!(matches!(Predicate::like("f_o").unwrap(), Predicate::Regex(_)));
        assert!(matches!(Predicate::like("f%o").unwrap(), Predicate::Regex(_)));
        // An escaped `%` is not a wildcard, so this is not a prefix match
        assert!(matches!(Predicate::like(r"foo\%").unwrap(), Predicate::Regex(_)));
    }

    #[test]
    fn test_ilike_dispatch() {
        assert!(matches!(Predicate::ilike("foo", true).unwrap(), Predicate::IEqAscii("foo")));
        assert!(matches!(
            Predicate::ilike("foo%", true).unwrap(),
            Predicate::IStartsWithAscii("foo")
        ));
        assert!(matches!(
            Predicate::ilike("%foo", true).unwrap(),
            Predicate::IEndsWithAscii("foo")
        ));
        // Escaped trailing wildcard, non-ASCII pattern, `is_ascii == false`, or a
        // pattern with no ASCII fast path all fall back to a regex
        assert!(matches!(Predicate::ilike(r"foo\%", true).unwrap(), Predicate::Regex(_)));
        assert!(matches!(Predicate::ilike("föo", true).unwrap(), Predicate::Regex(_)));
        assert!(matches!(Predicate::ilike("foo", false).unwrap(), Predicate::Regex(_)));
        assert!(matches!(Predicate::ilike("%foo%", true).unwrap(), Predicate::Regex(_)));
    }

    #[test]
    fn test_like_regex_evaluate() {
        // `_` matches exactly one character (not byte), including a newline
        let p = Predicate::like("a_c").unwrap();
        assert!(p.evaluate("abc"));
        assert!(p.evaluate("a£c"));
        assert!(p.evaluate("a\nc"));
        assert!(!p.evaluate("ac"));
        assert!(!p.evaluate("abbc"));

        // Escaped `%` matches a literal percent sign
        let p = Predicate::like(r"%\%").unwrap();
        assert!(p.evaluate("100%"));
        assert!(!p.evaluate("100"));

        // Case-insensitive regex path
        let p = Predicate::ilike("%B_D%", false).unwrap();
        assert!(p.evaluate("abcde"));
        assert!(!p.evaluate("abde"));
    }

    #[test]
    fn test_ieq() {
        assert!(Predicate::IEqAscii("HaYsTaCk").evaluate("haystack"));
        assert!(Predicate::IEqAscii("").evaluate(""));
        assert!(!Predicate::IEqAscii("hay").evaluate("haystack"));
        assert!(!Predicate::IEqAscii("hay").evaluate("h£y"));
    }

    #[test]
    fn test_contains() {
        assert!(Predicate::contains("hay").evaluate("haystack"));
        assert!(Predicate::contains("haystack").evaluate("haystack"));
        assert!(Predicate::contains("h").evaluate("haystack"));
        assert!(Predicate::contains("k").evaluate("haystack"));
        assert!(Predicate::contains("stack").evaluate("haystack"));
        assert!(Predicate::contains("sta").evaluate("haystack"));
        assert!(Predicate::contains("stack").evaluate("hay£stack"));
        assert!(Predicate::contains("y£s").evaluate("hay£stack"));
        assert!(Predicate::contains("£").evaluate("hay£stack"));
        assert!(Predicate::contains("a").evaluate("a"));
        // not matching
        assert!(!Predicate::contains("hy").evaluate("haystack"));
        assert!(!Predicate::contains("stackx").evaluate("haystack"));
        assert!(!Predicate::contains("x").evaluate("haystack"));
        assert!(!Predicate::contains("haystack haystack").evaluate("haystack"));
    }

    #[test]
    fn test_starts_with() {
        assert!(Predicate::StartsWith("hay").evaluate("haystack"));
        assert!(Predicate::StartsWith("h£ay").evaluate("h£aystack"));
        assert!(Predicate::StartsWith("haystack").evaluate("haystack"));
        assert!(Predicate::StartsWith("ha").evaluate("haystack"));
        assert!(Predicate::StartsWith("h").evaluate("haystack"));
        assert!(Predicate::StartsWith("").evaluate("haystack"));

        assert!(!Predicate::StartsWith("stack").evaluate("haystack"));
        assert!(!Predicate::StartsWith("haystacks").evaluate("haystack"));
        assert!(!Predicate::StartsWith("HAY").evaluate("haystack"));
        assert!(!Predicate::StartsWith("h£ay").evaluate("haystack"));
        assert!(!Predicate::StartsWith("hay").evaluate("h£aystack"));
    }

    #[test]
    fn test_ends_with() {
        assert!(Predicate::EndsWith("stack").evaluate("haystack"));
        assert!(Predicate::EndsWith("st£ack").evaluate("hayst£ack"));
        assert!(Predicate::EndsWith("haystack").evaluate("haystack"));
        assert!(Predicate::EndsWith("ck").evaluate("haystack"));
        assert!(Predicate::EndsWith("k").evaluate("haystack"));
        assert!(Predicate::EndsWith("").evaluate("haystack"));

        assert!(!Predicate::EndsWith("hay").evaluate("haystack"));
        assert!(!Predicate::EndsWith("STACK").evaluate("haystack"));
        assert!(!Predicate::EndsWith("haystacks").evaluate("haystack"));
        assert!(!Predicate::EndsWith("xhaystack").evaluate("haystack"));
        assert!(!Predicate::EndsWith("st£ack").evaluate("haystack"));
        assert!(!Predicate::EndsWith("stack").evaluate("hayst£ack"));
    }

    #[test]
    fn test_istarts_with() {
        assert!(Predicate::IStartsWithAscii("hay").evaluate("haystack"));
        assert!(Predicate::IStartsWithAscii("hay").evaluate("HAYSTACK"));
        assert!(Predicate::IStartsWithAscii("HAY").evaluate("haystack"));
        assert!(Predicate::IStartsWithAscii("HaY").evaluate("haystack"));
        assert!(Predicate::IStartsWithAscii("hay").evaluate("HaYsTaCk"));
        assert!(Predicate::IStartsWithAscii("HAY").evaluate("HaYsTaCk"));
        assert!(Predicate::IStartsWithAscii("haystack").evaluate("HaYsTaCk"));
        assert!(Predicate::IStartsWithAscii("HaYsTaCk").evaluate("HaYsTaCk"));
        assert!(Predicate::IStartsWithAscii("").evaluate("HaYsTaCk"));

        assert!(!Predicate::IStartsWithAscii("stack").evaluate("haystack"));
        assert!(!Predicate::IStartsWithAscii("haystacks").evaluate("haystack"));
        assert!(!Predicate::IStartsWithAscii("h.ay").evaluate("haystack"));
        assert!(!Predicate::IStartsWithAscii("hay").evaluate("h£aystack"));
    }

    #[test]
    fn test_iends_with() {
        assert!(Predicate::IEndsWithAscii("stack").evaluate("haystack"));
        assert!(Predicate::IEndsWithAscii("STACK").evaluate("haystack"));
        assert!(Predicate::IEndsWithAscii("StAcK").evaluate("haystack"));
        assert!(Predicate::IEndsWithAscii("stack").evaluate("HAYSTACK"));
        assert!(Predicate::IEndsWithAscii("STACK").evaluate("HAYSTACK"));
        assert!(Predicate::IEndsWithAscii("StAcK").evaluate("HAYSTACK"));
        assert!(Predicate::IEndsWithAscii("stack").evaluate("HAYsTaCk"));
        assert!(Predicate::IEndsWithAscii("STACK").evaluate("HAYsTaCk"));
        assert!(Predicate::IEndsWithAscii("StAcK").evaluate("HAYsTaCk"));
        assert!(Predicate::IEndsWithAscii("haystack").evaluate("haystack"));
        assert!(Predicate::IEndsWithAscii("HAYSTACK").evaluate("haystack"));
        assert!(Predicate::IEndsWithAscii("haystack").evaluate("HAYSTACK"));
        assert!(Predicate::IEndsWithAscii("ck").evaluate("haystack"));
        assert!(Predicate::IEndsWithAscii("cK").evaluate("haystack"));
        assert!(Predicate::IEndsWithAscii("ck").evaluate("haystacK"));
        assert!(Predicate::IEndsWithAscii("").evaluate("haystack"));

        assert!(!Predicate::IEndsWithAscii("hay").evaluate("haystack"));
        assert!(!Predicate::IEndsWithAscii("stac").evaluate("HAYSTACK"));
        assert!(!Predicate::IEndsWithAscii("haystacks").evaluate("haystack"));
        assert!(!Predicate::IEndsWithAscii("stack").evaluate("haystac£k"));
        assert!(!Predicate::IEndsWithAscii("xhaystack").evaluate("haystack"));
    }
}