arrow_string/
predicate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow_array::{Array, ArrayAccessor, BooleanArray, StringViewArray};
19use arrow_buffer::BooleanBuffer;
20use arrow_schema::ArrowError;
21use memchr::memchr3;
22use memchr::memmem::Finder;
23use regex::{Regex, RegexBuilder};
24use std::iter::zip;
25
26/// A string based predicate
27#[allow(clippy::large_enum_variant)]
28pub(crate) enum Predicate<'a> {
29    Eq(&'a str),
30    Contains(Finder<'a>),
31    StartsWith(&'a str),
32    EndsWith(&'a str),
33
34    /// Equality ignoring ASCII case
35    IEqAscii(&'a str),
36    /// Starts with ignoring ASCII case
37    IStartsWithAscii(&'a str),
38    /// Ends with ignoring ASCII case
39    IEndsWithAscii(&'a str),
40
41    Regex(Regex),
42}
43
44impl<'a> Predicate<'a> {
45    /// Create a predicate for the given like pattern
46    pub(crate) fn like(pattern: &'a str) -> Result<Self, ArrowError> {
47        if !contains_like_pattern(pattern) {
48            Ok(Self::Eq(pattern))
49        } else if pattern.ends_with('%') && !contains_like_pattern(&pattern[..pattern.len() - 1]) {
50            Ok(Self::StartsWith(&pattern[..pattern.len() - 1]))
51        } else if pattern.starts_with('%') && !contains_like_pattern(&pattern[1..]) {
52            Ok(Self::EndsWith(&pattern[1..]))
53        } else if pattern.starts_with('%')
54            && pattern.ends_with('%')
55            && !contains_like_pattern(&pattern[1..pattern.len() - 1])
56        {
57            Ok(Self::contains(&pattern[1..pattern.len() - 1]))
58        } else {
59            Ok(Self::Regex(regex_like(pattern, false)?))
60        }
61    }
62
63    pub(crate) fn contains(needle: &'a str) -> Self {
64        Self::Contains(Finder::new(needle.as_bytes()))
65    }
66
67    /// Create a predicate for the given ilike pattern
68    pub(crate) fn ilike(pattern: &'a str, is_ascii: bool) -> Result<Self, ArrowError> {
69        if is_ascii && pattern.is_ascii() {
70            if !contains_like_pattern(pattern) {
71                return Ok(Self::IEqAscii(pattern));
72            } else if pattern.ends_with('%')
73                && !pattern.ends_with("\\%")
74                && !contains_like_pattern(&pattern[..pattern.len() - 1])
75            {
76                return Ok(Self::IStartsWithAscii(&pattern[..pattern.len() - 1]));
77            } else if pattern.starts_with('%') && !contains_like_pattern(&pattern[1..]) {
78                return Ok(Self::IEndsWithAscii(&pattern[1..]));
79            }
80        }
81        Ok(Self::Regex(regex_like(pattern, true)?))
82    }
83
84    /// Evaluate this predicate against the given haystack
85    pub(crate) fn evaluate(&self, haystack: &str) -> bool {
86        match self {
87            Predicate::Eq(v) => *v == haystack,
88            Predicate::IEqAscii(v) => haystack.eq_ignore_ascii_case(v),
89            Predicate::Contains(finder) => finder.find(haystack.as_bytes()).is_some(),
90            Predicate::StartsWith(v) => starts_with(haystack, v, equals_kernel),
91            Predicate::IStartsWithAscii(v) => {
92                starts_with(haystack, v, equals_ignore_ascii_case_kernel)
93            }
94            Predicate::EndsWith(v) => ends_with(haystack, v, equals_kernel),
95            Predicate::IEndsWithAscii(v) => ends_with(haystack, v, equals_ignore_ascii_case_kernel),
96            Predicate::Regex(v) => v.is_match(haystack),
97        }
98    }
99
100    /// Evaluate this predicate against the elements of `array`
101    ///
102    /// If `negate` is true the result of the predicate will be negated
103    #[inline(never)]
104    pub(crate) fn evaluate_array<'i, T>(&self, array: T, negate: bool) -> BooleanArray
105    where
106        T: ArrayAccessor<Item = &'i str>,
107    {
108        match self {
109            Predicate::Eq(v) => BooleanArray::from_unary(array, |haystack| {
110                (haystack.len() == v.len() && haystack == *v) != negate
111            }),
112            Predicate::IEqAscii(v) => BooleanArray::from_unary(array, |haystack| {
113                haystack.eq_ignore_ascii_case(v) != negate
114            }),
115            Predicate::Contains(finder) => BooleanArray::from_unary(array, |haystack| {
116                finder.find(haystack.as_bytes()).is_some() != negate
117            }),
118            Predicate::StartsWith(v) => {
119                if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
120                    let nulls = string_view_array.logical_nulls();
121                    let values = BooleanBuffer::from(
122                        string_view_array
123                            .prefix_bytes_iter(v.len())
124                            .map(|haystack| {
125                                equals_bytes(haystack, v.as_bytes(), equals_kernel) != negate
126                            })
127                            .collect::<Vec<_>>(),
128                    );
129                    BooleanArray::new(values, nulls)
130                } else {
131                    BooleanArray::from_unary(array, |haystack| {
132                        starts_with(haystack, v, equals_kernel) != negate
133                    })
134                }
135            }
136            Predicate::IStartsWithAscii(v) => {
137                if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
138                    let nulls = string_view_array.logical_nulls();
139                    let values = BooleanBuffer::from(
140                        string_view_array
141                            .prefix_bytes_iter(v.len())
142                            .map(|haystack| {
143                                equals_bytes(
144                                    haystack,
145                                    v.as_bytes(),
146                                    equals_ignore_ascii_case_kernel,
147                                ) != negate
148                            })
149                            .collect::<Vec<_>>(),
150                    );
151                    BooleanArray::new(values, nulls)
152                } else {
153                    BooleanArray::from_unary(array, |haystack| {
154                        starts_with(haystack, v, equals_ignore_ascii_case_kernel) != negate
155                    })
156                }
157            }
158            Predicate::EndsWith(v) => {
159                if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
160                    let nulls = string_view_array.logical_nulls();
161                    let values = BooleanBuffer::from(
162                        string_view_array
163                            .suffix_bytes_iter(v.len())
164                            .map(|haystack| {
165                                equals_bytes(haystack, v.as_bytes(), equals_kernel) != negate
166                            })
167                            .collect::<Vec<_>>(),
168                    );
169                    BooleanArray::new(values, nulls)
170                } else {
171                    BooleanArray::from_unary(array, |haystack| {
172                        ends_with(haystack, v, equals_kernel) != negate
173                    })
174                }
175            }
176            Predicate::IEndsWithAscii(v) => {
177                if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
178                    let nulls = string_view_array.logical_nulls();
179                    let values = BooleanBuffer::from(
180                        string_view_array
181                            .suffix_bytes_iter(v.len())
182                            .map(|haystack| {
183                                equals_bytes(
184                                    haystack,
185                                    v.as_bytes(),
186                                    equals_ignore_ascii_case_kernel,
187                                ) != negate
188                            })
189                            .collect::<Vec<_>>(),
190                    );
191                    BooleanArray::new(values, nulls)
192                } else {
193                    BooleanArray::from_unary(array, |haystack| {
194                        ends_with(haystack, v, equals_ignore_ascii_case_kernel) != negate
195                    })
196                }
197            }
198            Predicate::Regex(v) => {
199                BooleanArray::from_unary(array, |haystack| v.is_match(haystack) != negate)
200            }
201        }
202    }
203}
204
/// Returns true when `lhs` and `rhs` have the same length and `byte_eq_kernel`
/// accepts every pair of bytes at the same position.
///
/// The kernel receives each pair as `(lhs_byte, rhs_byte)`.
fn equals_bytes(lhs: &[u8], rhs: &[u8], byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    if lhs.len() != rhs.len() {
        return false;
    }
    zip(lhs, rhs).all(byte_eq_kernel)
}
208
/// Byte-wise prefix test: does `haystack` begin with `needle` according to `byte_eq_kernel`?
///
/// The kernel receives each pair as `(haystack_byte, needle_byte)`. A needle longer
/// than the haystack never matches; an empty needle always matches.
///
/// This is faster than `str::starts_with` for small strings.
/// See <https://github.com/apache/arrow-rs/issues/6107> for more details.
fn starts_with(haystack: &str, needle: &str, byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    let (text, wanted) = (haystack.as_bytes(), needle.as_bytes());
    // `zip` stops after `wanted.len()` pairs, so only the head of `text` is inspected
    wanted.len() <= text.len() && zip(text, wanted).all(byte_eq_kernel)
}
/// Byte-wise suffix test: does `haystack` end with `needle` according to `byte_eq_kernel`?
///
/// Bytes are compared back to front and the kernel receives each pair as
/// `(haystack_byte, needle_byte)`. A needle longer than the haystack never matches;
/// an empty needle always matches.
///
/// This is faster than `str::ends_with` for small strings.
/// See <https://github.com/apache/arrow-rs/issues/6107> for more details.
fn ends_with(haystack: &str, needle: &str, byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    if haystack.len() < needle.len() {
        return false;
    }
    // `zip` stops once the reversed needle is exhausted, so only the tail is inspected
    let text_rev = haystack.as_bytes().iter().rev();
    let wanted_rev = needle.as_bytes().iter().rev();
    zip(text_rev, wanted_rev).all(byte_eq_kernel)
}
231
/// Byte comparison kernel: accepts a pair only when both bytes are identical
fn equals_kernel((lhs, rhs): (&u8, &u8)) -> bool {
    *lhs == *rhs
}
235
/// Byte comparison kernel: accepts a pair when both bytes are equal after folding
/// ASCII letters to lowercase. Bytes outside `A-Z` / `a-z` must match exactly.
fn equals_ignore_ascii_case_kernel((lhs, rhs): (&u8, &u8)) -> bool {
    lhs.to_ascii_lowercase() == rhs.to_ascii_lowercase()
}
239
240/// Transforms a like `pattern` to a regex compatible pattern. To achieve that, it does:
241///
242/// 1. Replace `LIKE` multi-character wildcards `%` => `.*` (unless they're at the start or end of the pattern,
243///    where the regex is just truncated - e.g. `%foo%` => `foo` rather than `^.*foo.*$`)
244/// 2. Replace `LIKE` single-character wildcards `_` => `.`
245/// 3. Escape regex meta characters to match them and not be evaluated as regex special chars. e.g. `.` => `\\.`
246/// 4. Replace escaped `LIKE` wildcards removing the escape characters to be able to match it as a regex. e.g. `\\%` => `%`
247fn regex_like(pattern: &str, case_insensitive: bool) -> Result<Regex, ArrowError> {
248    let mut result = String::with_capacity(pattern.len() * 2);
249    let mut chars_iter = pattern.chars().peekable();
250    match chars_iter.peek() {
251        // if the pattern starts with `%`, we avoid starting the regex with a slow but meaningless `^.*`
252        Some('%') => {
253            chars_iter.next();
254        }
255        _ => result.push('^'),
256    };
257
258    while let Some(c) = chars_iter.next() {
259        match c {
260            '\\' => {
261                match chars_iter.peek() {
262                    Some(&next) => {
263                        if regex_syntax::is_meta_character(next) {
264                            result.push('\\');
265                        }
266                        result.push(next);
267                        // Skipping the next char as it is already appended
268                        chars_iter.next();
269                    }
270                    None => {
271                        // Trailing backslash in the pattern. E.g. PostgreSQL and Trino treat it as an error, but e.g. Snowflake treats it as a literal backslash
272                        result.push('\\');
273                        result.push('\\');
274                    }
275                }
276            }
277            '%' => result.push_str(".*"),
278            '_' => result.push('.'),
279            c => {
280                if regex_syntax::is_meta_character(c) {
281                    result.push('\\');
282                }
283                result.push(c);
284            }
285        }
286    }
287    // instead of ending the regex with `.*$` and making it needlessly slow, we just end the regex
288    if result.ends_with(".*") {
289        result.pop();
290        result.pop();
291    } else {
292        result.push('$');
293    }
294    RegexBuilder::new(&result)
295        .case_insensitive(case_insensitive)
296        .dot_matches_new_line(true)
297        .build()
298        .map_err(|e| {
299            ArrowError::InvalidArgumentError(format!(
300                "Unable to build regex from LIKE pattern: {e}"
301            ))
302        })
303}
304
305fn contains_like_pattern(pattern: &str) -> bool {
306    memchr3(b'%', b'_', b'\\', pattern.as_bytes()).is_some()
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a predicate from each needle with `build` and checks that the paired
    /// haystack is accepted (`matching`) or refused (`rejected`).
    fn check<'a>(
        build: impl Fn(&'a str) -> Predicate<'a>,
        matching: &[(&'a str, &str)],
        rejected: &[(&'a str, &str)],
    ) {
        for &(needle, haystack) in matching {
            assert!(
                build(needle).evaluate(haystack),
                "{needle:?} should match {haystack:?}"
            );
        }
        for &(needle, haystack) in rejected {
            assert!(
                !build(needle).evaluate(haystack),
                "{needle:?} should not match {haystack:?}"
            );
        }
    }

    #[test]
    fn test_regex_like() {
        // (LIKE pattern, regex it must be translated to)
        let test_cases = [
            // %..%
            (r"%foobar%", r"foobar"),
            // ..%..
            (r"foo%bar", r"^foo.*bar$"),
            // .._..
            (r"foo_bar", r"^foo.bar$"),
            // escaped wildcards
            (r"\%\_", r"^%_$"),
            // escaped non-wildcard
            (r"\a", r"^a$"),
            // escaped escape and wildcard
            (r"\\%", r"^\\"),
            // escaped escape and non-wildcard
            (r"\\a", r"^\\a$"),
            // regex meta character
            (r".", r"^\.$"),
            (r"$", r"^\$$"),
            (r"\\", r"^\\$"),
        ];

        for (like_pattern, expected_regexp) in test_cases {
            let regex = regex_like(like_pattern, false).unwrap();
            assert_eq!(regex.as_str(), expected_regexp);
        }
    }

    #[test]
    fn test_contains() {
        check(
            Predicate::contains,
            &[
                ("hay", "haystack"),
                ("haystack", "haystack"),
                ("h", "haystack"),
                ("k", "haystack"),
                ("stack", "haystack"),
                ("sta", "haystack"),
                ("stack", "hay£stack"),
                ("y£s", "hay£stack"),
                ("£", "hay£stack"),
                ("a", "a"),
            ],
            // not matching
            &[
                ("hy", "haystack"),
                ("stackx", "haystack"),
                ("x", "haystack"),
                ("haystack haystack", "haystack"),
            ],
        );
    }

    #[test]
    fn test_starts_with() {
        check(
            Predicate::StartsWith,
            &[
                ("hay", "haystack"),
                ("h£ay", "h£aystack"),
                ("haystack", "haystack"),
                ("ha", "haystack"),
                ("h", "haystack"),
                ("", "haystack"),
            ],
            &[
                ("stack", "haystack"),
                ("haystacks", "haystack"),
                ("HAY", "haystack"),
                ("h£ay", "haystack"),
                ("hay", "h£aystack"),
            ],
        );
    }

    #[test]
    fn test_ends_with() {
        check(
            Predicate::EndsWith,
            &[
                ("stack", "haystack"),
                ("st£ack", "hayst£ack"),
                ("haystack", "haystack"),
                ("ck", "haystack"),
                ("k", "haystack"),
                ("", "haystack"),
            ],
            &[
                ("hay", "haystack"),
                ("STACK", "haystack"),
                ("haystacks", "haystack"),
                ("xhaystack", "haystack"),
                ("st£ack", "haystack"),
                ("stack", "hayst£ack"),
            ],
        );
    }

    #[test]
    fn test_istarts_with() {
        check(
            Predicate::IStartsWithAscii,
            &[
                ("hay", "haystack"),
                ("hay", "HAYSTACK"),
                ("HAY", "haystack"),
                ("HaY", "haystack"),
                ("hay", "HaYsTaCk"),
                ("HAY", "HaYsTaCk"),
                ("haystack", "HaYsTaCk"),
                ("HaYsTaCk", "HaYsTaCk"),
                ("", "HaYsTaCk"),
            ],
            &[
                ("stack", "haystack"),
                ("haystacks", "haystack"),
                ("h.ay", "haystack"),
                ("hay", "h£aystack"),
            ],
        );
    }

    #[test]
    fn test_iends_with() {
        check(
            Predicate::IEndsWithAscii,
            &[
                ("stack", "haystack"),
                ("STACK", "haystack"),
                ("StAcK", "haystack"),
                ("stack", "HAYSTACK"),
                ("STACK", "HAYSTACK"),
                ("StAcK", "HAYSTACK"),
                ("stack", "HAYsTaCk"),
                ("STACK", "HAYsTaCk"),
                ("StAcK", "HAYsTaCk"),
                ("haystack", "haystack"),
                ("HAYSTACK", "haystack"),
                ("haystack", "HAYSTACK"),
                ("ck", "haystack"),
                ("cK", "haystack"),
                ("ck", "haystacK"),
                ("", "haystack"),
            ],
            &[
                ("hay", "haystack"),
                ("stac", "HAYSTACK"),
                ("haystacks", "haystack"),
                ("stack", "haystac£k"),
                ("xhaystack", "haystack"),
            ],
        );
    }
}