1use crate::binary_predicate::BinaryPredicate;
21
22use arrow_array::cast::AsArray;
23use arrow_array::*;
24use arrow_schema::*;
25use arrow_select::take::take;
26
/// Substring-style predicates that can be evaluated on binary (byte) arrays.
///
/// Built from a `crate::like::Op` via `TryFrom`; only `Contains`, `StartsWith`
/// and `EndsWith` have a binary counterpart.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Op {
    /// The needle occurs anywhere inside the haystack.
    Contains,
    /// The haystack begins with the needle.
    StartsWith,
    /// The haystack ends with the needle.
    EndsWith,
}
33
34impl TryFrom<crate::like::Op> for Op {
35 type Error = ArrowError;
36
37 fn try_from(value: crate::like::Op) -> Result<Self, Self::Error> {
38 match value {
39 crate::like::Op::Contains => Ok(Op::Contains),
40 crate::like::Op::StartsWith => Ok(Op::StartsWith),
41 crate::like::Op::EndsWith => Ok(Op::EndsWith),
42 _ => Err(ArrowError::InvalidArgumentError(format!(
43 "Invalid binary operation: {value}"
44 ))),
45 }
46 }
47}
48
49impl std::fmt::Display for Op {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 match self {
52 Op::Contains => write!(f, "CONTAINS"),
53 Op::StartsWith => write!(f, "STARTS_WITH"),
54 Op::EndsWith => write!(f, "ENDS_WITH"),
55 }
56 }
57}
58
59pub(crate) fn binary_apply<'a, 'i, T: BinaryArrayType<'a> + 'a>(
60 op: Op,
61 l: T,
62 l_s: bool,
63 l_v: Option<&'a dyn AnyDictionaryArray>,
64 r: T,
65 r_s: bool,
66 r_v: Option<&'a dyn AnyDictionaryArray>,
67) -> Result<BooleanArray, ArrowError> {
68 let l_len = l_v.map(|l| l.len()).unwrap_or(l.len());
69 if r_s {
70 let idx = match r_v {
71 Some(dict) if dict.null_count() != 0 => return Ok(BooleanArray::new_null(l_len)),
72 Some(dict) => dict.normalized_keys()[0],
73 None => 0,
74 };
75 if r.is_null(idx) {
76 return Ok(BooleanArray::new_null(l_len));
77 }
78 op_scalar::<T>(op, l, l_v, r.value(idx))
79 } else {
80 match (l_s, l_v, r_v) {
81 (true, None, None) => {
82 let v = l.is_valid(0).then(|| l.value(0));
83 op_binary(op, std::iter::repeat(v), r.iter())
84 }
85 (true, Some(l_v), None) => {
86 let idx = l_v.is_valid(0).then(|| l_v.normalized_keys()[0]);
87 let v = idx.and_then(|idx| l.is_valid(idx).then(|| l.value(idx)));
88 op_binary(op, std::iter::repeat(v), r.iter())
89 }
90 (true, None, Some(r_v)) => {
91 let v = l.is_valid(0).then(|| l.value(0));
92 op_binary(op, std::iter::repeat(v), vectored_iter(r, r_v))
93 }
94 (true, Some(l_v), Some(r_v)) => {
95 let idx = l_v.is_valid(0).then(|| l_v.normalized_keys()[0]);
96 let v = idx.and_then(|idx| l.is_valid(idx).then(|| l.value(idx)));
97 op_binary(op, std::iter::repeat(v), vectored_iter(r, r_v))
98 }
99 (false, None, None) => op_binary(op, l.iter(), r.iter()),
100 (false, Some(l_v), None) => op_binary(op, vectored_iter(l, l_v), r.iter()),
101 (false, None, Some(r_v)) => op_binary(op, l.iter(), vectored_iter(r, r_v)),
102 (false, Some(l_v), Some(r_v)) => {
103 op_binary(op, vectored_iter(l, l_v), vectored_iter(r, r_v))
104 }
105 }
106 }
107}
108
109#[inline(never)]
110fn op_scalar<'a, T: BinaryArrayType<'a>>(
111 op: Op,
112 l: T,
113 l_v: Option<&dyn AnyDictionaryArray>,
114 r: &[u8],
115) -> Result<BooleanArray, ArrowError> {
116 let r = match op {
117 Op::Contains => BinaryPredicate::contains(r).evaluate_array(l, false),
118 Op::StartsWith => BinaryPredicate::StartsWith(r).evaluate_array(l, false),
119 Op::EndsWith => BinaryPredicate::EndsWith(r).evaluate_array(l, false),
120 };
121
122 Ok(match l_v {
123 Some(v) => take(&r, v.keys(), None)?.as_boolean().clone(),
124 None => r,
125 })
126}
127
128fn vectored_iter<'a, T: BinaryArrayType<'a> + 'a>(
129 a: T,
130 a_v: &'a dyn AnyDictionaryArray,
131) -> impl Iterator<Item = Option<&'a [u8]>> + 'a {
132 let nulls = a_v.nulls();
133 let keys = a_v.normalized_keys();
134 keys.into_iter().enumerate().map(move |(idx, key)| {
135 if nulls.map(|n| n.is_null(idx)).unwrap_or_default() || a.is_null(key) {
136 return None;
137 }
138 Some(a.value(key))
139 })
140}
141
142#[inline(never)]
143fn op_binary<'a>(
144 op: Op,
145 l: impl Iterator<Item = Option<&'a [u8]>>,
146 r: impl Iterator<Item = Option<&'a [u8]>>,
147) -> Result<BooleanArray, ArrowError> {
148 match op {
149 Op::Contains => Ok(l
150 .zip(r)
151 .map(|(l, r)| Some(bytes_contains(l?, r?)))
152 .collect()),
153 Op::StartsWith => Ok(l
154 .zip(r)
155 .map(|(l, r)| Some(BinaryPredicate::StartsWith(r?).evaluate(l?)))
156 .collect()),
157 Op::EndsWith => Ok(l
158 .zip(r)
159 .map(|(l, r)| Some(BinaryPredicate::EndsWith(r?).evaluate(l?)))
160 .collect()),
161 }
162}
163
164fn bytes_contains(haystack: &[u8], needle: &[u8]) -> bool {
165 memchr::memmem::find(haystack, needle).is_some()
166}