use arrow_array::{Array, ArrayAccessor, BooleanArray, StringViewArray};
use arrow_buffer::BooleanBuffer;
use arrow_schema::ArrowError;
use memchr::memchr3;
use memchr::memmem::Finder;
use regex::{Regex, RegexBuilder};
use std::iter::zip;
/// A string-matching predicate compiled once from a pattern and then applied
/// to individual strings (`evaluate`) or whole arrays (`evaluate_array`).
///
/// The lifetime `'a` ties the predicate to the pattern text it borrows.
pub enum Predicate<'a> {
/// Matches when the haystack is exactly equal to the given string.
Eq(&'a str),
/// Matches when the haystack contains the needle held by the `Finder`
/// (byte-wise substring search).
Contains(Finder<'a>),
/// Matches when the haystack begins with the given bytes (case-sensitive).
StartsWith(&'a str),
/// Matches when the haystack ends with the given bytes (case-sensitive).
EndsWith(&'a str),
/// Matches when the haystack equals the given string, ignoring ASCII case.
IEqAscii(&'a str),
/// Matches when the haystack begins with the given bytes, ignoring ASCII case.
IStartsWithAscii(&'a str),
/// Matches when the haystack ends with the given bytes, ignoring ASCII case.
IEndsWithAscii(&'a str),
/// Fallback: matches when the compiled regular expression matches the haystack.
Regex(Regex),
}
impl<'a> Predicate<'a> {
pub fn like(pattern: &'a str) -> Result<Self, ArrowError> {
if !contains_like_pattern(pattern) {
Ok(Self::Eq(pattern))
} else if pattern.ends_with('%') && !contains_like_pattern(&pattern[..pattern.len() - 1]) {
Ok(Self::StartsWith(&pattern[..pattern.len() - 1]))
} else if pattern.starts_with('%') && !contains_like_pattern(&pattern[1..]) {
Ok(Self::EndsWith(&pattern[1..]))
} else if pattern.starts_with('%')
&& pattern.ends_with('%')
&& !contains_like_pattern(&pattern[1..pattern.len() - 1])
{
Ok(Self::contains(&pattern[1..pattern.len() - 1]))
} else {
Ok(Self::Regex(regex_like(pattern, false)?))
}
}
pub fn contains(needle: &'a str) -> Self {
Self::Contains(Finder::new(needle.as_bytes()))
}
pub fn ilike(pattern: &'a str, is_ascii: bool) -> Result<Self, ArrowError> {
if is_ascii && pattern.is_ascii() {
if !contains_like_pattern(pattern) {
return Ok(Self::IEqAscii(pattern));
} else if pattern.ends_with('%')
&& !pattern.ends_with("\\%")
&& !contains_like_pattern(&pattern[..pattern.len() - 1])
{
return Ok(Self::IStartsWithAscii(&pattern[..pattern.len() - 1]));
} else if pattern.starts_with('%') && !contains_like_pattern(&pattern[1..]) {
return Ok(Self::IEndsWithAscii(&pattern[1..]));
}
}
Ok(Self::Regex(regex_like(pattern, true)?))
}
pub fn evaluate(&self, haystack: &str) -> bool {
match self {
Predicate::Eq(v) => *v == haystack,
Predicate::IEqAscii(v) => haystack.eq_ignore_ascii_case(v),
Predicate::Contains(finder) => finder.find(haystack.as_bytes()).is_some(),
Predicate::StartsWith(v) => starts_with(haystack, v, equals_kernel),
Predicate::IStartsWithAscii(v) => {
starts_with(haystack, v, equals_ignore_ascii_case_kernel)
}
Predicate::EndsWith(v) => ends_with(haystack, v, equals_kernel),
Predicate::IEndsWithAscii(v) => ends_with(haystack, v, equals_ignore_ascii_case_kernel),
Predicate::Regex(v) => v.is_match(haystack),
}
}
#[inline(never)]
pub fn evaluate_array<'i, T>(&self, array: T, negate: bool) -> BooleanArray
where
T: ArrayAccessor<Item = &'i str>,
{
match self {
Predicate::Eq(v) => BooleanArray::from_unary(array, |haystack| {
(haystack.len() == v.len() && haystack == *v) != negate
}),
Predicate::IEqAscii(v) => BooleanArray::from_unary(array, |haystack| {
haystack.eq_ignore_ascii_case(v) != negate
}),
Predicate::Contains(finder) => BooleanArray::from_unary(array, |haystack| {
finder.find(haystack.as_bytes()).is_some() != negate
}),
Predicate::StartsWith(v) => {
if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
let nulls = string_view_array.logical_nulls();
let values = BooleanBuffer::from(
string_view_array
.prefix_bytes_iter(v.len())
.map(|haystack| {
equals_bytes(haystack, v.as_bytes(), equals_kernel) != negate
})
.collect::<Vec<_>>(),
);
BooleanArray::new(values, nulls)
} else {
BooleanArray::from_unary(array, |haystack| {
starts_with(haystack, v, equals_kernel) != negate
})
}
}
Predicate::IStartsWithAscii(v) => {
if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
let nulls = string_view_array.logical_nulls();
let values = BooleanBuffer::from(
string_view_array
.prefix_bytes_iter(v.len())
.map(|haystack| {
equals_bytes(
haystack,
v.as_bytes(),
equals_ignore_ascii_case_kernel,
) != negate
})
.collect::<Vec<_>>(),
);
BooleanArray::new(values, nulls)
} else {
BooleanArray::from_unary(array, |haystack| {
starts_with(haystack, v, equals_ignore_ascii_case_kernel) != negate
})
}
}
Predicate::EndsWith(v) => {
if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
let nulls = string_view_array.logical_nulls();
let values = BooleanBuffer::from(
string_view_array
.suffix_bytes_iter(v.len())
.map(|haystack| {
equals_bytes(haystack, v.as_bytes(), equals_kernel) != negate
})
.collect::<Vec<_>>(),
);
BooleanArray::new(values, nulls)
} else {
BooleanArray::from_unary(array, |haystack| {
ends_with(haystack, v, equals_kernel) != negate
})
}
}
Predicate::IEndsWithAscii(v) => {
if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
let nulls = string_view_array.logical_nulls();
let values = BooleanBuffer::from(
string_view_array
.suffix_bytes_iter(v.len())
.map(|haystack| {
equals_bytes(
haystack,
v.as_bytes(),
equals_ignore_ascii_case_kernel,
) != negate
})
.collect::<Vec<_>>(),
);
BooleanArray::new(values, nulls)
} else {
BooleanArray::from_unary(array, |haystack| {
ends_with(haystack, v, equals_ignore_ascii_case_kernel) != negate
})
}
}
Predicate::Regex(v) => {
BooleanArray::from_unary(array, |haystack| v.is_match(haystack) != negate)
}
}
}
}
/// Returns `true` when `lhs` and `rhs` have the same length and every pair of
/// corresponding bytes satisfies `byte_eq_kernel`.
fn equals_bytes(lhs: &[u8], rhs: &[u8], byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    // Slices of different length can never be equal; bail out before
    // touching any bytes.
    if lhs.len() != rhs.len() {
        return false;
    }
    zip(lhs, rhs).all(byte_eq_kernel)
}
/// Returns `true` when the leading bytes of `haystack` match `needle`
/// byte-for-byte according to `byte_eq_kernel`.
///
/// The kernel receives `(haystack_byte, needle_byte)` pairs. A needle longer
/// than the haystack never matches; an empty needle always matches.
fn starts_with(haystack: &str, needle: &str, byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    // `get(..len)` yields `None` exactly when the needle is longer than the
    // haystack.
    match haystack.as_bytes().get(..needle.len()) {
        Some(head) => zip(head, needle.as_bytes()).all(byte_eq_kernel),
        None => false,
    }
}
/// Returns `true` when the trailing bytes of `haystack` match `needle`
/// byte-for-byte according to `byte_eq_kernel`.
///
/// Bytes are compared from the end backwards; the kernel receives
/// `(haystack_byte, needle_byte)` pairs. A needle longer than the haystack
/// never matches; an empty needle always matches.
fn ends_with(haystack: &str, needle: &str, byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    let (hay, pat) = (haystack.as_bytes(), needle.as_bytes());
    // `checked_sub` fails exactly when the needle is longer than the haystack.
    match hay.len().checked_sub(pat.len()) {
        Some(tail_start) => {
            zip(hay[tail_start..].iter().rev(), pat.iter().rev()).all(byte_eq_kernel)
        }
        None => false,
    }
}
/// Byte comparison kernel: the two bytes must be identical.
fn equals_kernel(pair: (&u8, &u8)) -> bool {
    *pair.0 == *pair.1
}
/// Byte comparison kernel: the two bytes must be equal once ASCII letters are
/// folded to lower case. Non-ASCII bytes are compared unchanged.
fn equals_ignore_ascii_case_kernel(pair: (&u8, &u8)) -> bool {
    let (lhs, rhs) = pair;
    lhs.to_ascii_lowercase() == rhs.to_ascii_lowercase()
}
/// Translates a SQL `LIKE` pattern into a compiled regular expression.
///
/// * `%` becomes `.*` and `_` becomes `.`.
/// * `\x` emits `x` literally (regex-escaped when `x` is a regex meta
///   character); a trailing lone `\` is emitted as a literal backslash.
/// * Every other regex meta character is escaped.
/// * A leading `%` simply omits the `^` anchor and a trailing `.*` is dropped
///   instead of appending `$`; both forms match the same strings as the
///   anchored equivalents.
///
/// The regex is built with `.` matching newlines, and is case-insensitive
/// when `case_insensitive` is `true`.
///
/// # Errors
///
/// Returns `ArrowError::InvalidArgumentError` when the generated expression
/// is rejected by the regex builder.
fn regex_like(pattern: &str, case_insensitive: bool) -> Result<Regex, ArrowError> {
    /// Appends `ch` as a literal, escaping it when it is a regex meta character.
    fn push_literal(out: &mut String, ch: char) {
        if regex_syntax::is_meta_character(ch) {
            out.push('\\');
        }
        out.push(ch);
    }

    let mut regex_src = String::with_capacity(pattern.len() * 2);

    // Anchor at the start unless the pattern opens with `%`, in which case the
    // wildcard is consumed and the regex is left unanchored.
    let body = match pattern.strip_prefix('%') {
        Some(rest) => rest,
        None => {
            regex_src.push('^');
            pattern
        }
    };

    let mut remaining = body.chars();
    while let Some(ch) = remaining.next() {
        match ch {
            // A backslash consumes the following character and emits it
            // literally; with nothing left to escape it stands for itself.
            '\\' => match remaining.next() {
                Some(escaped) => push_literal(&mut regex_src, escaped),
                None => regex_src.push_str("\\\\"),
            },
            '%' => regex_src.push_str(".*"),
            '_' => regex_src.push('.'),
            other => push_literal(&mut regex_src, other),
        }
    }

    // A trailing `.*` can only come from an unescaped `%` (literal `*` is
    // always escaped), so drop it rather than anchoring with `$`.
    if regex_src.ends_with(".*") {
        regex_src.truncate(regex_src.len() - 2);
    } else {
        regex_src.push('$');
    }

    let mut builder = RegexBuilder::new(&regex_src);
    builder
        .case_insensitive(case_insensitive)
        .dot_matches_new_line(true);
    builder.build().map_err(|e| {
        ArrowError::InvalidArgumentError(format!(
            "Unable to build regex from LIKE pattern: {e}"
        ))
    })
}
/// Returns `true` when `pattern` contains any byte with special meaning in a
/// `LIKE` pattern: `%`, `_` or the escape character `\`.
fn contains_like_pattern(pattern: &str) -> bool {
    let first_special = memchr3(b'%', b'_', b'\\', pattern.as_bytes());
    first_special.is_some()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a predicate for each `(needle, haystack)` pair with `build` and
    /// asserts that evaluating it on the haystack yields `expected`.
    fn assert_all(
        build: impl Fn(&'static str) -> Predicate<'static>,
        cases: &[(&'static str, &str)],
        expected: bool,
    ) {
        for &(needle, haystack) in cases {
            assert_eq!(
                build(needle).evaluate(haystack),
                expected,
                "needle {needle:?} against haystack {haystack:?}"
            );
        }
    }

    /// The regex source generated for a LIKE pattern must match exactly.
    #[test]
    fn test_regex_like() {
        let expectations = [
            (r"%foobar%", r"foobar"),
            (r"foo%bar", r"^foo.*bar$"),
            (r"foo_bar", r"^foo.bar$"),
            (r"\%\_", r"^%_$"),
            (r"\a", r"^a$"),
            (r"\\%", r"^\\"),
            (r"\\a", r"^\\a$"),
            (r".", r"^\.$"),
            (r"$", r"^\$$"),
            (r"\\", r"^\\$"),
        ];
        for (like_pattern, expected_regexp) in expectations {
            let compiled = regex_like(like_pattern, false).unwrap();
            assert_eq!(compiled.to_string(), expected_regexp);
        }
    }

    #[test]
    fn test_contains() {
        let matching = [
            ("hay", "haystack"),
            ("haystack", "haystack"),
            ("h", "haystack"),
            ("k", "haystack"),
            ("stack", "haystack"),
            ("sta", "haystack"),
            ("stack", "hay£stack"),
            ("y£s", "hay£stack"),
            ("£", "hay£stack"),
            ("a", "a"),
        ];
        let non_matching = [
            ("hy", "haystack"),
            ("stackx", "haystack"),
            ("x", "haystack"),
            ("haystack haystack", "haystack"),
        ];
        assert_all(Predicate::contains, &matching, true);
        assert_all(Predicate::contains, &non_matching, false);
    }

    #[test]
    fn test_starts_with() {
        let matching = [
            ("hay", "haystack"),
            ("h£ay", "h£aystack"),
            ("haystack", "haystack"),
            ("ha", "haystack"),
            ("h", "haystack"),
            ("", "haystack"),
        ];
        let non_matching = [
            ("stack", "haystack"),
            ("haystacks", "haystack"),
            ("HAY", "haystack"),
            ("h£ay", "haystack"),
            ("hay", "h£aystack"),
        ];
        assert_all(Predicate::StartsWith, &matching, true);
        assert_all(Predicate::StartsWith, &non_matching, false);
    }

    #[test]
    fn test_ends_with() {
        let matching = [
            ("stack", "haystack"),
            ("st£ack", "hayst£ack"),
            ("haystack", "haystack"),
            ("ck", "haystack"),
            ("k", "haystack"),
            ("", "haystack"),
        ];
        let non_matching = [
            ("hay", "haystack"),
            ("STACK", "haystack"),
            ("haystacks", "haystack"),
            ("xhaystack", "haystack"),
            ("st£ack", "haystack"),
            ("stack", "hayst£ack"),
        ];
        assert_all(Predicate::EndsWith, &matching, true);
        assert_all(Predicate::EndsWith, &non_matching, false);
    }

    #[test]
    fn test_istarts_with() {
        let matching = [
            ("hay", "haystack"),
            ("hay", "HAYSTACK"),
            ("HAY", "haystack"),
            ("HaY", "haystack"),
            ("hay", "HaYsTaCk"),
            ("HAY", "HaYsTaCk"),
            ("haystack", "HaYsTaCk"),
            ("HaYsTaCk", "HaYsTaCk"),
            ("", "HaYsTaCk"),
        ];
        let non_matching = [
            ("stack", "haystack"),
            ("haystacks", "haystack"),
            ("h.ay", "haystack"),
            ("hay", "h£aystack"),
        ];
        assert_all(Predicate::IStartsWithAscii, &matching, true);
        assert_all(Predicate::IStartsWithAscii, &non_matching, false);
    }

    #[test]
    fn test_iends_with() {
        let matching = [
            ("stack", "haystack"),
            ("STACK", "haystack"),
            ("StAcK", "haystack"),
            ("stack", "HAYSTACK"),
            ("STACK", "HAYSTACK"),
            ("StAcK", "HAYSTACK"),
            ("stack", "HAYsTaCk"),
            ("STACK", "HAYsTaCk"),
            ("StAcK", "HAYsTaCk"),
            ("haystack", "haystack"),
            ("HAYSTACK", "haystack"),
            ("haystack", "HAYSTACK"),
            ("ck", "haystack"),
            ("cK", "haystack"),
            ("ck", "haystacK"),
            ("", "haystack"),
        ];
        let non_matching = [
            ("hay", "haystack"),
            ("stac", "HAYSTACK"),
            ("haystacks", "haystack"),
            ("stack", "haystac£k"),
            ("xhaystack", "haystack"),
        ];
        assert_all(Predicate::IEndsWithAscii, &matching, true);
        assert_all(Predicate::IEndsWithAscii, &non_matching, false);
    }
}