Skip to main content

arrow_string/
regexp.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Defines kernels to match strings against regular expressions and to
//! extract substrings from a \[Large\]StringArray or StringViewArray
20
21use crate::like::StringArrayType;
22
23use arrow_array::builder::{
24    BooleanBufferBuilder, GenericStringBuilder, ListBuilder, StringViewBuilder,
25};
26use arrow_array::cast::AsArray;
27use arrow_array::*;
28use arrow_buffer::{BooleanBuffer, NullBuffer};
29use arrow_data::ArrayDataBuilder;
30use arrow_schema::{ArrowError, DataType, Field};
31use regex::Regex;
32
33use std::collections::HashMap;
34use std::sync::Arc;
35
36/// Return BooleanArray indicating which strings in an array match an array of
37/// regular expressions.
38///
39/// This is equivalent to the SQL `array ~ regex_array`, supporting
40/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`].
41///
42/// If `regex_array` element has an empty value, the corresponding result value is always true.
43///
44/// `flags_array` are optional [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] flag,
45/// which allow special search modes, such as case-insensitive and multi-line mode.
46/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags)
47/// for more information.
48///
49/// # See Also
50/// * [`regexp_is_match_scalar`] for matching a single regular expression against an array of strings
51/// * [`regexp_match`] for extracting groups from a string array based on a regular expression
52///
53/// # Example
54/// ```
55/// # use arrow_array::{StringArray, BooleanArray};
56/// # use arrow_string::regexp::regexp_is_match;
57/// // First array is the array of strings to match
58/// let array = StringArray::from(vec!["Foo", "Bar", "FooBar", "Baz"]);
59/// // Second array is the array of regular expressions to match against
60/// let regex_array = StringArray::from(vec!["^Foo", "^Foo", "Bar$", "Baz"]);
61/// // Third array is the array of flags to use for each regular expression, if desired
62/// // (the type must be provided to satisfy type inference for the third parameter)
63/// let flags_array: Option<&StringArray> = None;
64/// // The result is a BooleanArray indicating when each string in `array`
65/// // matches the corresponding regular expression in `regex_array`
66/// let result = regexp_is_match(&array, &regex_array, flags_array).unwrap();
67/// assert_eq!(result, BooleanArray::from(vec![true, false, true, true]));
68/// ```
69pub fn regexp_is_match<'a, S1, S2, S3>(
70    array: &'a S1,
71    regex_array: &'a S2,
72    flags_array: Option<&'a S3>,
73) -> Result<BooleanArray, ArrowError>
74where
75    &'a S1: StringArrayType<'a>,
76    &'a S2: StringArrayType<'a>,
77    &'a S3: StringArrayType<'a>,
78{
79    if array.len() != regex_array.len() {
80        return Err(ArrowError::ComputeError(
81            "Cannot perform comparison operation on arrays of different length".to_string(),
82        ));
83    }
84
85    let nulls = NullBuffer::union(array.nulls(), regex_array.nulls());
86
87    let mut patterns: HashMap<String, Regex> = HashMap::new();
88    let mut result = BooleanBufferBuilder::new(array.len());
89
90    let complete_pattern = match flags_array {
91        Some(flags) => Box::new(
92            regex_array
93                .iter()
94                .zip(flags.iter())
95                .map(|(pattern, flags)| {
96                    pattern.map(|pattern| match flags {
97                        Some(flag) => format!("(?{flag}){pattern}"),
98                        None => pattern.to_string(),
99                    })
100                }),
101        ) as Box<dyn Iterator<Item = Option<String>>>,
102        None => Box::new(
103            regex_array
104                .iter()
105                .map(|pattern| pattern.map(|pattern| pattern.to_string())),
106        ),
107    };
108
109    array
110        .iter()
111        .zip(complete_pattern)
112        .map(|(value, pattern)| {
113            match (value, pattern) {
114                // Required for Postgres compatibility:
115                // SELECT 'foobarbequebaz' ~ ''); = true
116                (Some(_), Some(pattern)) if pattern == *"" => {
117                    result.append(true);
118                }
119                (Some(value), Some(pattern)) => {
120                    let existing_pattern = patterns.get(&pattern);
121                    let re = match existing_pattern {
122                        Some(re) => re,
123                        None => {
124                            let re = Regex::new(pattern.as_str()).map_err(|e| {
125                                ArrowError::ComputeError(format!(
126                                    "Regular expression did not compile: {e:?}"
127                                ))
128                            })?;
129                            patterns.entry(pattern).or_insert(re)
130                        }
131                    };
132                    result.append(re.is_match(value));
133                }
134                _ => result.append(false),
135            }
136            Ok(())
137        })
138        .collect::<Result<Vec<()>, ArrowError>>()?;
139
140    let data = unsafe {
141        ArrayDataBuilder::new(DataType::Boolean)
142            .len(array.len())
143            .buffers(vec![result.into()])
144            .nulls(nulls)
145            .build_unchecked()
146    };
147
148    Ok(BooleanArray::from(data))
149}
150
151/// Return BooleanArray indicating which strings in an array match a single regular expression.
152///
153/// This is equivalent to the SQL `array ~ regex_array`, supporting
154/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] and a scalar.
155///
156/// See the documentation on [`regexp_is_match`] for more details on arguments
157///
158/// # See Also
159/// * [`regexp_is_match`] for matching an array of regular expression against an array of strings
160/// * [`regexp_match`] for extracting groups from a string array based on a regular expression
161///
162/// # Example
163/// ```
164/// # use arrow_array::{StringArray, BooleanArray};
165/// # use arrow_string::regexp::regexp_is_match_scalar;
166/// // array of strings to match
167/// let array = StringArray::from(vec!["Foo", "Bar", "FooBar", "Baz"]);
168/// let regexp = "^Foo"; // regular expression to match against
169/// let flags: Option<&str> = None;  // flags can control the matching behavior
170/// // The result is a BooleanArray indicating when each string in `array`
171/// // matches the regular expression `regexp`
172/// let result = regexp_is_match_scalar(&array, regexp, None).unwrap();
173/// assert_eq!(result, BooleanArray::from(vec![true, false, true, false]));
174/// ```
175pub fn regexp_is_match_scalar<'a, S>(
176    array: &'a S,
177    regex: &str,
178    flag: Option<&str>,
179) -> Result<BooleanArray, ArrowError>
180where
181    &'a S: StringArrayType<'a>,
182{
183    let mut result = BooleanBufferBuilder::new(array.len());
184
185    let pattern = match flag {
186        Some(flag) => format!("(?{flag}){regex}"),
187        None => regex.to_string(),
188    };
189
190    if pattern.is_empty() {
191        result.append_n(array.len(), true);
192    } else {
193        let re = Regex::new(pattern.as_str()).map_err(|e| {
194            ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}"))
195        })?;
196        for i in 0..array.len() {
197            let value = array.value(i);
198            result.append(re.is_match(value));
199        }
200    }
201
202    let values = BooleanBuffer::from(result);
203    let nulls = array
204        .nulls()
205        .map(|n| n.inner().sliced())
206        .and_then(|b| NullBuffer::from_unsliced_buffer(b, array.len()));
207    Ok(BooleanArray::new(values, nulls))
208}
209
// Expands to the row-by-row matching loop used when the pattern (and optional flags)
// are arrays. For each pair of string and pattern it appends one entry to
// `$list_builder`:
//   * null when the string or the pattern is null, or when the pattern does not match
//   * `[""]` when the complete pattern is empty
//   * otherwise the captured groups (or the whole match if the pattern has no groups)
// Iteration stops at the shortest of the zipped inputs. A pattern that fails to compile
// makes the enclosing function return an `ArrowError::ComputeError`.
macro_rules! process_regexp_array_match {
    ($array:expr, $regex_array:expr, $flags_array:expr, $list_builder:expr) => {
        // Compiled expressions keyed by the complete (flag-prefixed) pattern text.
        let mut compiled: HashMap<String, Regex> = HashMap::new();

        let full_patterns = match $flags_array {
            Some(flags) => Box::new($regex_array.iter().zip(flags.iter()).map(
                |(pattern, flag)| {
                    pattern.map(|pattern| match flag {
                        Some(flag) => format!("(?{flag}){pattern}"),
                        None => pattern.to_string(),
                    })
                },
            )) as Box<dyn Iterator<Item = Option<String>>>,
            None => Box::new(
                $regex_array
                    .iter()
                    .map(|pattern| pattern.map(str::to_string)),
            ),
        };

        for (value, pattern) in $array.iter().zip(full_patterns) {
            let (value, pattern) = match (value, pattern) {
                (Some(value), Some(pattern)) => (value, pattern),
                _ => {
                    $list_builder.append(false);
                    continue;
                }
            };

            // Required for Postgres compatibility:
            // SELECT regexp_match('foobarbequebaz', ''); = {""}
            if pattern.is_empty() {
                $list_builder.values().append_value("");
                $list_builder.append(true);
                continue;
            }

            let re = match compiled.entry(pattern) {
                std::collections::hash_map::Entry::Occupied(slot) => slot.into_mut(),
                std::collections::hash_map::Entry::Vacant(slot) => {
                    let re = Regex::new(slot.key()).map_err(|e| {
                        ArrowError::ComputeError(format!(
                            "Regular expression did not compile: {e:?}"
                        ))
                    })?;
                    slot.insert(re)
                }
            };

            if let Some(caps) = re.captures(value) {
                // Group 0 (the whole match) is only reported when the pattern has no
                // capture groups of its own.
                let leading = usize::from(caps.len() > 1);
                // `flatten` drops groups that did not participate in the match.
                for group in caps.iter().skip(leading).flatten() {
                    $list_builder.values().append_value(group.as_str());
                }
                $list_builder.append(true);
            } else {
                $list_builder.append(false);
            }
        }
    };
}
276
277fn regexp_array_match<OffsetSize: OffsetSizeTrait>(
278    array: &GenericStringArray<OffsetSize>,
279    regex_array: &GenericStringArray<OffsetSize>,
280    flags_array: Option<&GenericStringArray<OffsetSize>>,
281) -> Result<ArrayRef, ArrowError> {
282    let builder: GenericStringBuilder<OffsetSize> = GenericStringBuilder::with_capacity(0, 0);
283    let mut list_builder = ListBuilder::new(builder);
284
285    process_regexp_array_match!(array, regex_array, flags_array, list_builder);
286
287    Ok(Arc::new(list_builder.finish()))
288}
289
290fn regexp_array_match_utf8view(
291    array: &StringViewArray,
292    regex_array: &StringViewArray,
293    flags_array: Option<&StringViewArray>,
294) -> Result<ArrayRef, ArrowError> {
295    let builder = StringViewBuilder::with_capacity(0);
296    let mut list_builder = ListBuilder::new(builder);
297
298    process_regexp_array_match!(array, regex_array, flags_array, list_builder);
299
300    Ok(Arc::new(list_builder.finish()))
301}
302
303fn get_scalar_pattern_flag<'a, OffsetSize: OffsetSizeTrait>(
304    regex_array: &'a dyn Array,
305    flag_array: Option<&'a dyn Array>,
306) -> (Option<&'a str>, Option<&'a str>) {
307    let regex = regex_array.as_string::<OffsetSize>();
308    let regex = regex.is_valid(0).then(|| regex.value(0));
309
310    if let Some(flag_array) = flag_array {
311        let flag = flag_array.as_string::<OffsetSize>();
312        (regex, flag.is_valid(0).then(|| flag.value(0)))
313    } else {
314        (regex, None)
315    }
316}
317
318fn get_scalar_pattern_flag_utf8view<'a>(
319    regex_array: &'a dyn Array,
320    flag_array: Option<&'a dyn Array>,
321) -> (Option<&'a str>, Option<&'a str>) {
322    let regex = regex_array.as_string_view();
323    let regex = regex.is_valid(0).then(|| regex.value(0));
324
325    if let Some(flag_array) = flag_array {
326        let flag = flag_array.as_string_view();
327        (regex, flag.is_valid(0).then(|| flag.value(0)))
328    } else {
329        (regex, None)
330    }
331}
332
// Expands to the matching loop used when a single, already compiled `$regex` is applied
// to every string of `$array`. For each string it appends one entry to `$list_builder`:
//   * null when the string is null or the pattern does not match
//   * `[""]` when the pattern text is empty
//   * otherwise the captured groups (or the whole match if the pattern has no groups)
macro_rules! process_regexp_match {
    ($array:expr, $regex:expr, $list_builder:expr) => {
        for value in $array.iter() {
            match value {
                None => $list_builder.append(false),
                // Required for Postgres compatibility:
                // SELECT regexp_match('foobarbequebaz', ''); = {""}
                Some(_) if $regex.as_str().is_empty() => {
                    $list_builder.values().append_value("");
                    $list_builder.append(true);
                }
                Some(text) => {
                    if let Some(caps) = $regex.captures(text) {
                        // Group 0 (the whole match) is only reported when the pattern
                        // has no capture groups of its own.
                        let leading = usize::from(caps.len() > 1);
                        // `flatten` drops groups that did not participate in the match.
                        for group in caps.iter().skip(leading).flatten() {
                            $list_builder.values().append_value(group.as_str());
                        }
                        $list_builder.append(true);
                    } else {
                        $list_builder.append(false);
                    }
                }
            }
        }
    };
}
365
366fn regexp_scalar_match<OffsetSize: OffsetSizeTrait>(
367    array: &GenericStringArray<OffsetSize>,
368    regex: &Regex,
369) -> Result<ArrayRef, ArrowError> {
370    let builder: GenericStringBuilder<OffsetSize> = GenericStringBuilder::with_capacity(0, 0);
371    let mut list_builder = ListBuilder::new(builder);
372
373    process_regexp_match!(array, regex, list_builder);
374
375    Ok(Arc::new(list_builder.finish()))
376}
377
378fn regexp_scalar_match_utf8view(
379    array: &StringViewArray,
380    regex: &Regex,
381) -> Result<ArrayRef, ArrowError> {
382    let builder = StringViewBuilder::with_capacity(0);
383    let mut list_builder = ListBuilder::new(builder);
384
385    process_regexp_match!(array, regex, list_builder);
386
387    Ok(Arc::new(list_builder.finish()))
388}
389
390/// Extract all groups matched by a regular expression for a given String array.
391///
392/// Modelled after the Postgres [regexp_match].
393///
394/// Returns a ListArray of [`GenericStringArray`] with each element containing the leftmost-first
395/// match of the corresponding index in `regex_array` to string in `array`
396///
397/// If there is no match, the list element is NULL.
398///
399/// If a match is found, and the pattern contains no capturing parenthesized subexpressions,
400/// then the list element is a single-element [`GenericStringArray`] containing the substring
401/// matching the whole pattern.
402///
403/// If a match is found, and the pattern contains capturing parenthesized subexpressions, then the
404/// list element is a [`GenericStringArray`] whose n'th element is the substring matching
405/// the n'th capturing parenthesized subexpression of the pattern.
406///
407/// The flags parameter is an optional text string containing zero or more single-letter flags
408/// that change the function's behavior.
409///
410/// # See Also
411/// * [`regexp_is_match`] for matching (rather than extracting) a regular expression against an array of strings
412///
413/// [regexp_match]: https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP
414pub fn regexp_match(
415    array: &dyn Array,
416    regex_array: &dyn Datum,
417    flags_array: Option<&dyn Datum>,
418) -> Result<ArrayRef, ArrowError> {
419    let (rhs, is_rhs_scalar) = regex_array.get();
420
421    if array.data_type() != rhs.data_type() {
422        return Err(ArrowError::ComputeError(
423            "regexp_match() requires both array and pattern to be either Utf8, Utf8View or LargeUtf8"
424                .to_string(),
425        ));
426    }
427
428    let (flags, is_flags_scalar) = match flags_array {
429        Some(flags) => {
430            let (flags, is_flags_scalar) = flags.get();
431            (Some(flags), Some(is_flags_scalar))
432        }
433        None => (None, None),
434    };
435
436    if is_flags_scalar.is_some() && is_rhs_scalar != is_flags_scalar.unwrap() {
437        return Err(ArrowError::ComputeError(
438            "regexp_match() requires both pattern and flags to be either scalar or array"
439                .to_string(),
440        ));
441    }
442
443    if flags_array.is_some() && rhs.data_type() != flags.unwrap().data_type() {
444        return Err(ArrowError::ComputeError(
445            "regexp_match() requires both pattern and flags to be either Utf8, Utf8View or LargeUtf8"
446                .to_string(),
447        ));
448    }
449
450    if is_rhs_scalar {
451        // Regex and flag is scalars
452        let (regex, flag) = match rhs.data_type() {
453            DataType::Utf8View => get_scalar_pattern_flag_utf8view(rhs, flags),
454            DataType::Utf8 => get_scalar_pattern_flag::<i32>(rhs, flags),
455            DataType::LargeUtf8 => get_scalar_pattern_flag::<i64>(rhs, flags),
456            _ => {
457                return Err(ArrowError::ComputeError(
458                    "regexp_match() requires pattern to be either Utf8, Utf8View or LargeUtf8"
459                        .to_string(),
460                ));
461            }
462        };
463
464        if regex.is_none() {
465            return Ok(new_null_array(
466                &DataType::List(Arc::new(Field::new_list_field(
467                    array.data_type().clone(),
468                    true,
469                ))),
470                array.len(),
471            ));
472        }
473
474        let regex = regex.unwrap();
475
476        let pattern = if let Some(flag) = flag {
477            format!("(?{flag}){regex}")
478        } else {
479            regex.to_string()
480        };
481
482        let re = Regex::new(pattern.as_str()).map_err(|e| {
483            ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}"))
484        })?;
485
486        match array.data_type() {
487            DataType::Utf8View => regexp_scalar_match_utf8view(array.as_string_view(), &re),
488            DataType::Utf8 => regexp_scalar_match(array.as_string::<i32>(), &re),
489            DataType::LargeUtf8 => regexp_scalar_match(array.as_string::<i64>(), &re),
490            _ => Err(ArrowError::ComputeError(
491                "regexp_match() requires array to be either Utf8, Utf8View or LargeUtf8"
492                    .to_string(),
493            )),
494        }
495    } else {
496        match array.data_type() {
497            DataType::Utf8View => {
498                let regex_array = rhs.as_string_view();
499                let flags_array = flags.map(|flags| flags.as_string_view());
500                regexp_array_match_utf8view(array.as_string_view(), regex_array, flags_array)
501            }
502            DataType::Utf8 => {
503                let regex_array = rhs.as_string();
504                let flags_array = flags.map(|flags| flags.as_string());
505                regexp_array_match(array.as_string::<i32>(), regex_array, flags_array)
506            }
507            DataType::LargeUtf8 => {
508                let regex_array = rhs.as_string();
509                let flags_array = flags.map(|flags| flags.as_string());
510                regexp_array_match(array.as_string::<i64>(), regex_array, flags_array)
511            }
512            _ => Err(ArrowError::ComputeError(
513                "regexp_match() requires array to be either Utf8, Utf8View or LargeUtf8"
514                    .to_string(),
515            )),
516        }
517    }
518}
519
520#[cfg(test)]
521mod tests {
522    use super::*;
523
524    macro_rules! test_match_single_group {
525        ($test_name:ident, $values:expr, $patterns:expr, $arr_type:ty, $builder_type:ty, $expected:expr) => {
526            #[test]
527            fn $test_name() {
528                let array: $arr_type = <$arr_type>::from($values);
529                let pattern: $arr_type = <$arr_type>::from($patterns);
530
531                let actual = regexp_match(&array, &pattern, None).unwrap();
532
533                let elem_builder: $builder_type = <$builder_type>::new();
534                let mut expected_builder = ListBuilder::new(elem_builder);
535
536                for val in $expected {
537                    match val {
538                        Some(v) => {
539                            expected_builder.values().append_value(v);
540                            expected_builder.append(true);
541                        }
542                        None => expected_builder.append(false),
543                    }
544                }
545
546                let expected = expected_builder.finish();
547                let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
548                assert_eq!(&expected, result);
549            }
550        };
551    }
552
553    test_match_single_group!(
554        match_single_group_string,
555        vec![
556            Some("abc-005-def"),
557            Some("X-7-5"),
558            Some("X545"),
559            None,
560            Some("foobarbequebaz"),
561            Some("foobarbequebaz"),
562        ],
563        vec![
564            r".*-(\d*)-.*",
565            r".*-(\d*)-.*",
566            r".*-(\d*)-.*",
567            r".*-(\d*)-.*",
568            r"(bar)(bequ1e)",
569            ""
570        ],
571        StringArray,
572        GenericStringBuilder<i32>,
573        [Some("005"), Some("7"), None, None, None, Some("")]
574    );
575    test_match_single_group!(
576        match_single_group_string_view,
577        vec![
578            Some("abc-005-def"),
579            Some("X-7-5"),
580            Some("X545"),
581            None,
582            Some("foobarbequebaz"),
583            Some("foobarbequebaz"),
584        ],
585        vec![
586            r".*-(\d*)-.*",
587            r".*-(\d*)-.*",
588            r".*-(\d*)-.*",
589            r".*-(\d*)-.*",
590            r"(bar)(bequ1e)",
591            ""
592        ],
593        StringViewArray,
594        StringViewBuilder,
595        [Some("005"), Some("7"), None, None, None, Some("")]
596    );
597
598    macro_rules! test_match_single_group_with_flags {
599        ($test_name:ident, $values:expr, $patterns:expr, $flags:expr, $array_type:ty, $builder_type:ty, $expected:expr) => {
600            #[test]
601            fn $test_name() {
602                let array: $array_type = <$array_type>::from($values);
603                let pattern: $array_type = <$array_type>::from($patterns);
604                let flags: $array_type = <$array_type>::from($flags);
605
606                let actual = regexp_match(&array, &pattern, Some(&flags)).unwrap();
607
608                let elem_builder: $builder_type = <$builder_type>::new();
609                let mut expected_builder = ListBuilder::new(elem_builder);
610
611                for val in $expected {
612                    match val {
613                        Some(v) => {
614                            expected_builder.values().append_value(v);
615                            expected_builder.append(true);
616                        }
617                        None => {
618                            expected_builder.append(false);
619                        }
620                    }
621                }
622
623                let expected = expected_builder.finish();
624                let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
625                assert_eq!(&expected, result);
626            }
627        };
628    }
629
630    test_match_single_group_with_flags!(
631        match_single_group_with_flags_string,
632        vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None],
633        vec![r"x.*-(\d*)-.*"; 4],
634        vec!["i"; 4],
635        StringArray,
636        GenericStringBuilder<i32>,
637        [None, Some("7"), None, None]
638    );
639    test_match_single_group_with_flags!(
640        match_single_group_with_flags_stringview,
641        vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None],
642        vec![r"x.*-(\d*)-.*"; 4],
643        vec!["i"; 4],
644        StringViewArray,
645        StringViewBuilder,
646        [None, Some("7"), None, None]
647    );
648
649    macro_rules! test_match_scalar_pattern {
650        ($test_name:ident, $values:expr, $pattern:expr, $flag:expr, $array_type:ty, $builder_type:ty, $expected:expr) => {
651            #[test]
652            fn $test_name() {
653                let array: $array_type = <$array_type>::from($values);
654
655                let pattern_scalar = Scalar::new(<$array_type>::from(vec![$pattern; 1]));
656                let flag_scalar = Scalar::new(<$array_type>::from(vec![$flag; 1]));
657
658                let actual = regexp_match(&array, &pattern_scalar, Some(&flag_scalar)).unwrap();
659
660                let elem_builder: $builder_type = <$builder_type>::new();
661                let mut expected_builder = ListBuilder::new(elem_builder);
662
663                for val in $expected {
664                    match val {
665                        Some(v) => {
666                            expected_builder.values().append_value(v);
667                            expected_builder.append(true);
668                        }
669                        None => expected_builder.append(false),
670                    }
671                }
672
673                let expected = expected_builder.finish();
674                let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
675                assert_eq!(&expected, result);
676            }
677        };
678    }
679
680    test_match_scalar_pattern!(
681        match_scalar_pattern_string_with_flags,
682        vec![
683            Some("abc-005-def"),
684            Some("x-7-5"),
685            Some("X-0-Y"),
686            Some("X545"),
687            None
688        ],
689        r"x.*-(\d*)-.*",
690        Some("i"),
691        StringArray,
692        GenericStringBuilder<i32>,
693        [None, Some("7"), Some("0"), None, None]
694    );
695    test_match_scalar_pattern!(
696        match_scalar_pattern_stringview_with_flags,
697        vec![
698            Some("abc-005-def"),
699            Some("x-7-5"),
700            Some("X-0-Y"),
701            Some("X545"),
702            None
703        ],
704        r"x.*-(\d*)-.*",
705        Some("i"),
706        StringViewArray,
707        StringViewBuilder,
708        [None, Some("7"), Some("0"), None, None]
709    );
710
711    test_match_scalar_pattern!(
712        match_scalar_pattern_string_no_flags,
713        vec![
714            Some("abc-005-def"),
715            Some("x-7-5"),
716            Some("X-0-Y"),
717            Some("X545"),
718            None
719        ],
720        r"x.*-(\d*)-.*",
721        None::<&str>,
722        StringArray,
723        GenericStringBuilder<i32>,
724        [None, Some("7"), None, None, None]
725    );
726    test_match_scalar_pattern!(
727        match_scalar_pattern_stringview_no_flags,
728        vec![
729            Some("abc-005-def"),
730            Some("x-7-5"),
731            Some("X-0-Y"),
732            Some("X545"),
733            None
734        ],
735        r"x.*-(\d*)-.*",
736        None::<&str>,
737        StringViewArray,
738        StringViewBuilder,
739        [None, Some("7"), None, None, None]
740    );
741
742    macro_rules! test_match_scalar_no_pattern {
743        ($test_name:ident, $values:expr, $array_type:ty, $pattern_type:expr, $builder_type:ty, $expected:expr) => {
744            #[test]
745            fn $test_name() {
746                let array: $array_type = <$array_type>::from($values);
747                let pattern = Scalar::new(new_null_array(&$pattern_type, 1));
748
749                let actual = regexp_match(&array, &pattern, None).unwrap();
750
751                let elem_builder: $builder_type = <$builder_type>::new();
752                let mut expected_builder = ListBuilder::new(elem_builder);
753
754                for val in $expected {
755                    match val {
756                        Some(v) => {
757                            expected_builder.values().append_value(v);
758                            expected_builder.append(true);
759                        }
760                        None => expected_builder.append(false),
761                    }
762                }
763
764                let expected = expected_builder.finish();
765                let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
766                assert_eq!(&expected, result);
767            }
768        };
769    }
770
771    test_match_scalar_no_pattern!(
772        match_scalar_no_pattern_string,
773        vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None],
774        StringArray,
775        DataType::Utf8,
776        GenericStringBuilder<i32>,
777        [None::<&str>, None, None, None]
778    );
779    test_match_scalar_no_pattern!(
780        match_scalar_no_pattern_stringview,
781        vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None],
782        StringViewArray,
783        DataType::Utf8View,
784        StringViewBuilder,
785        [None::<&str>, None, None, None]
786    );
787
// Generates a `#[test]` that calls `regexp_match` with a single-element pattern
// array and no flags, then compares the result against a `ListArray` assembled
// from `$expected`.
//
// * `$values`       - input convertible into `$array_type`
// * `$pattern`      - the one pattern, wrapped into a single-element `$array_type`
// * `$builder_type` - element builder backing the expected list values
// * `$expected`     - iterable of `Option`s: `Some(v)` yields a valid list entry
//                     holding `v`, `None` yields a null list entry
macro_rules! test_match_single_group_not_skip {
    ($test_name:ident, $values:expr, $pattern:expr, $array_type:ty, $builder_type:ty, $expected:expr) => {
        #[test]
        fn $test_name() {
            let array = <$array_type>::from($values);
            let pattern = <$array_type>::from(vec![$pattern]);

            let actual = regexp_match(&array, &pattern, None).unwrap();

            let mut expected_builder = ListBuilder::new(<$builder_type>::new());
            for val in $expected {
                if let Some(v) = val {
                    expected_builder.values().append_value(v);
                    expected_builder.append(true);
                } else {
                    expected_builder.append(false);
                }
            }
            let expected = expected_builder.finish();

            let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
            assert_eq!(&expected, result);
        }
    };
}
816
817    test_match_single_group_not_skip!(
818        match_single_group_not_skip_string,
819        vec![Some("foo"), Some("bar")],
820        r"foo",
821        StringArray,
822        GenericStringBuilder<i32>,
823        [Some("foo")]
824    );
825    test_match_single_group_not_skip!(
826        match_single_group_not_skip_stringview,
827        vec![Some("foo"), Some("bar")],
828        r"foo",
829        StringViewArray,
830        StringViewBuilder,
831        [Some("foo")]
832    );
833
// Generates a `#[test]` for an array-vs-array kernel `$op`.
//
// The five-argument arm calls `$op(&left, &right, None)`; the six-argument arm
// additionally passes a reference to `$flag` as the third argument. Both arms
// require the result to have the same length as `$expected` and every
// `res.value(i)` to equal the i-th expected entry.
macro_rules! test_flag_utf8 {
    ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => {
        #[test]
        fn $test_name() {
            let left = $left;
            let right = $right;
            let res = $op(&left, &right, None).unwrap();
            let expected = $expected;
            assert_eq!(expected.len(), res.len());
            // Lengths are equal at this point, so walking `expected` visits every row.
            for (i, want) in expected.iter().enumerate() {
                assert_eq!(res.value(i), *want);
            }
        }
    };
    ($test_name:ident, $left:expr, $right:expr, $flag:expr, $op:expr, $expected:expr) => {
        #[test]
        fn $test_name() {
            let left = $left;
            let right = $right;
            let flag = $flag;
            let res = $op(&left, &right, Some(&flag)).unwrap();
            let expected = $expected;
            assert_eq!(expected.len(), res.len());
            // Lengths are equal at this point, so walking `expected` visits every row.
            for (i, want) in expected.iter().enumerate() {
                assert_eq!(res.value(i), *want);
            }
        }
    };
}
865
// Generates a `#[test]` for an array-vs-scalar kernel `$op`.
//
// The five-argument arm calls `$op(&left, $right, None)`; the six-argument arm
// passes `Some($flag)` as the third argument. Both arms require the result to
// have the same length as `$expected` and compare it row by row, reporting the
// input value, the position and the scalar pattern on mismatch.
macro_rules! test_flag_utf8_scalar {
    ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => {
        #[test]
        fn $test_name() {
            let left = $left;
            let res = $op(&left, $right, None).unwrap();
            let expected = $expected;
            assert_eq!(expected.len(), res.len());
            // Lengths are equal at this point, so walking `expected` visits every row.
            for (i, want) in expected.iter().enumerate() {
                assert_eq!(
                    res.value(i),
                    *want,
                    "unexpected result when comparing {} at position {} to {} ",
                    left.value(i),
                    i,
                    $right
                );
            }
        }
    };
    ($test_name:ident, $left:expr, $right:expr, $flag:expr, $op:expr, $expected:expr) => {
        #[test]
        fn $test_name() {
            let left = $left;
            let flag = $flag;
            let res = $op(&left, $right, Some(flag)).unwrap();
            let expected = $expected;
            assert_eq!(expected.len(), res.len());
            // Lengths are equal at this point, so walking `expected` visits every row.
            for (i, want) in expected.iter().enumerate() {
                assert_eq!(
                    res.value(i),
                    *want,
                    "unexpected result when comparing {} at position {} to {} ",
                    left.value(i),
                    i,
                    $right
                );
            }
        }
    };
}
909
910    test_flag_utf8!(
911        test_array_regexp_is_match_utf8,
912        StringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
913        StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
914        regexp_is_match::<StringArray, StringArray, StringArray>,
915        [true, false, true, false, false, true]
916    );
917    test_flag_utf8!(
918        test_array_regexp_is_match_utf8_insensitive,
919        StringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
920        StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
921        StringArray::from(vec!["i"; 6]),
922        regexp_is_match,
923        [true, true, true, true, false, true]
924    );
925
926    test_flag_utf8_scalar!(
927        test_array_regexp_is_match_utf8_scalar,
928        StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]),
929        "^ar",
930        regexp_is_match_scalar,
931        [true, false, false, false]
932    );
933    test_flag_utf8_scalar!(
934        test_array_regexp_is_match_utf8_scalar_empty,
935        StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]),
936        "",
937        regexp_is_match_scalar,
938        [true, true, true, true]
939    );
940    test_flag_utf8_scalar!(
941        test_array_regexp_is_match_utf8_scalar_insensitive,
942        StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]),
943        "^ar",
944        "i",
945        regexp_is_match_scalar,
946        [true, true, false, false]
947    );
948
949    test_flag_utf8!(
950        tes_array_regexp_is_match,
951        StringViewArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
952        StringViewArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
953        regexp_is_match::<StringViewArray, StringViewArray, StringViewArray>,
954        [true, false, true, false, false, true]
955    );
956    test_flag_utf8!(
957        test_array_regexp_is_match_2,
958        StringViewArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
959        StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
960        regexp_is_match::<StringViewArray, GenericStringArray<i32>, GenericStringArray<i32>>,
961        [true, false, true, false, false, true]
962    );
963    test_flag_utf8!(
964        test_array_regexp_is_match_insensitive,
965        StringViewArray::from(vec![
966            "Official Rust implementation of Apache Arrow",
967            "apache/arrow-rs",
968            "apache/arrow-rs",
969            "parquet",
970            "parquet",
971            "row",
972            "row",
973        ]),
974        StringViewArray::from(vec![
975            ".*rust implement.*",
976            "^ap",
977            "^AP",
978            "et$",
979            "ET$",
980            "foo",
981            ""
982        ]),
983        StringViewArray::from(vec!["i"; 7]),
984        regexp_is_match::<StringViewArray, StringViewArray, StringViewArray>,
985        [true, true, true, true, true, false, true]
986    );
987    test_flag_utf8!(
988        test_array_regexp_is_match_insensitive_2,
989        LargeStringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
990        StringViewArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
991        StringArray::from(vec!["i"; 6]),
992        regexp_is_match::<GenericStringArray<i64>, StringViewArray, GenericStringArray<i32>>,
993        [true, true, true, true, false, true]
994    );
995
996    test_flag_utf8_scalar!(
997        test_array_regexp_is_match_scalar,
998        StringViewArray::from(vec![
999            "apache/arrow-rs",
1000            "APACHE/ARROW-RS",
1001            "parquet",
1002            "PARQUET",
1003        ]),
1004        "^ap",
1005        regexp_is_match_scalar::<StringViewArray>,
1006        [true, false, false, false]
1007    );
1008    test_flag_utf8_scalar!(
1009        test_array_regexp_is_match_scalar_empty,
1010        StringViewArray::from(vec![
1011            "apache/arrow-rs",
1012            "APACHE/ARROW-RS",
1013            "parquet",
1014            "PARQUET",
1015        ]),
1016        "",
1017        regexp_is_match_scalar::<StringViewArray>,
1018        [true, true, true, true]
1019    );
1020    test_flag_utf8_scalar!(
1021        test_array_regexp_is_match_scalar_insensitive,
1022        StringViewArray::from(vec![
1023            "apache/arrow-rs",
1024            "APACHE/ARROW-RS",
1025            "parquet",
1026            "PARQUET",
1027        ]),
1028        "^ap",
1029        "i",
1030        regexp_is_match_scalar::<StringViewArray>,
1031        [true, true, false, false]
1032    );
1033}