arrow_string/
regexp.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Defines kernels to match and extract substrings based on regular expressions
//! applied to `StringArray`, `LargeStringArray` and `StringViewArray` values.
20
21use crate::like::StringArrayType;
22
23use arrow_array::builder::{
24    BooleanBufferBuilder, GenericStringBuilder, ListBuilder, StringViewBuilder,
25};
26use arrow_array::cast::AsArray;
27use arrow_array::*;
28use arrow_buffer::NullBuffer;
29use arrow_data::{ArrayData, ArrayDataBuilder};
30use arrow_schema::{ArrowError, DataType, Field};
31use regex::Regex;
32
33use std::collections::HashMap;
34use std::sync::Arc;
35
36/// Return BooleanArray indicating which strings in an array match an array of
37/// regular expressions.
38///
39/// This is equivalent to the SQL `array ~ regex_array`, supporting
40/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`].
41///
42/// If `regex_array` element has an empty value, the corresponding result value is always true.
43///
44/// `flags_array` are optional [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] flag,
45/// which allow special search modes, such as case-insensitive and multi-line mode.
46/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags)
47/// for more information.
48///
49/// # See Also
50/// * [`regexp_is_match_scalar`] for matching a single regular expression against an array of strings
51/// * [`regexp_match`] for extracting groups from a string array based on a regular expression
52///
53/// # Example
54/// ```
55/// # use arrow_array::{StringArray, BooleanArray};
56/// # use arrow_string::regexp::regexp_is_match;
57/// // First array is the array of strings to match
58/// let array = StringArray::from(vec!["Foo", "Bar", "FooBar", "Baz"]);
59/// // Second array is the array of regular expressions to match against
60/// let regex_array = StringArray::from(vec!["^Foo", "^Foo", "Bar$", "Baz"]);
61/// // Third array is the array of flags to use for each regular expression, if desired
62/// // (the type must be provided to satisfy type inference for the third parameter)
63/// let flags_array: Option<&StringArray> = None;
64/// // The result is a BooleanArray indicating when each string in `array`
65/// // matches the corresponding regular expression in `regex_array`
66/// let result = regexp_is_match(&array, &regex_array, flags_array).unwrap();
67/// assert_eq!(result, BooleanArray::from(vec![true, false, true, true]));
68/// ```
69pub fn regexp_is_match<'a, S1, S2, S3>(
70    array: &'a S1,
71    regex_array: &'a S2,
72    flags_array: Option<&'a S3>,
73) -> Result<BooleanArray, ArrowError>
74where
75    &'a S1: StringArrayType<'a>,
76    &'a S2: StringArrayType<'a>,
77    &'a S3: StringArrayType<'a>,
78{
79    if array.len() != regex_array.len() {
80        return Err(ArrowError::ComputeError(
81            "Cannot perform comparison operation on arrays of different length".to_string(),
82        ));
83    }
84
85    let nulls = NullBuffer::union(array.nulls(), regex_array.nulls());
86
87    let mut patterns: HashMap<String, Regex> = HashMap::new();
88    let mut result = BooleanBufferBuilder::new(array.len());
89
90    let complete_pattern = match flags_array {
91        Some(flags) => Box::new(
92            regex_array
93                .iter()
94                .zip(flags.iter())
95                .map(|(pattern, flags)| {
96                    pattern.map(|pattern| match flags {
97                        Some(flag) => format!("(?{flag}){pattern}"),
98                        None => pattern.to_string(),
99                    })
100                }),
101        ) as Box<dyn Iterator<Item = Option<String>>>,
102        None => Box::new(
103            regex_array
104                .iter()
105                .map(|pattern| pattern.map(|pattern| pattern.to_string())),
106        ),
107    };
108
109    array
110        .iter()
111        .zip(complete_pattern)
112        .map(|(value, pattern)| {
113            match (value, pattern) {
114                // Required for Postgres compatibility:
115                // SELECT 'foobarbequebaz' ~ ''); = true
116                (Some(_), Some(pattern)) if pattern == *"" => {
117                    result.append(true);
118                }
119                (Some(value), Some(pattern)) => {
120                    let existing_pattern = patterns.get(&pattern);
121                    let re = match existing_pattern {
122                        Some(re) => re,
123                        None => {
124                            let re = Regex::new(pattern.as_str()).map_err(|e| {
125                                ArrowError::ComputeError(format!(
126                                    "Regular expression did not compile: {e:?}"
127                                ))
128                            })?;
129                            patterns.entry(pattern).or_insert(re)
130                        }
131                    };
132                    result.append(re.is_match(value));
133                }
134                _ => result.append(false),
135            }
136            Ok(())
137        })
138        .collect::<Result<Vec<()>, ArrowError>>()?;
139
140    let data = unsafe {
141        ArrayDataBuilder::new(DataType::Boolean)
142            .len(array.len())
143            .buffers(vec![result.into()])
144            .nulls(nulls)
145            .build_unchecked()
146    };
147
148    Ok(BooleanArray::from(data))
149}
150
151/// Return BooleanArray indicating which strings in an array match a single regular expression.
152///
153/// This is equivalent to the SQL `array ~ regex_array`, supporting
154/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] and a scalar.
155///
156/// See the documentation on [`regexp_is_match`] for more details on arguments
157///
158/// # See Also
159/// * [`regexp_is_match`] for matching an array of regular expression against an array of strings
160/// * [`regexp_match`] for extracting groups from a string array based on a regular expression
161///
162/// # Example
163/// ```
164/// # use arrow_array::{StringArray, BooleanArray};
165/// # use arrow_string::regexp::regexp_is_match_scalar;
166/// // array of strings to match
167/// let array = StringArray::from(vec!["Foo", "Bar", "FooBar", "Baz"]);
168/// let regexp = "^Foo"; // regular expression to match against
169/// let flags: Option<&str> = None;  // flags can control the matching behavior
170/// // The result is a BooleanArray indicating when each string in `array`
171/// // matches the regular expression `regexp`
172/// let result = regexp_is_match_scalar(&array, regexp, None).unwrap();
173/// assert_eq!(result, BooleanArray::from(vec![true, false, true, false]));
174/// ```
175pub fn regexp_is_match_scalar<'a, S>(
176    array: &'a S,
177    regex: &str,
178    flag: Option<&str>,
179) -> Result<BooleanArray, ArrowError>
180where
181    &'a S: StringArrayType<'a>,
182{
183    let null_bit_buffer = array.nulls().map(|x| x.inner().sliced());
184    let mut result = BooleanBufferBuilder::new(array.len());
185
186    let pattern = match flag {
187        Some(flag) => format!("(?{flag}){regex}"),
188        None => regex.to_string(),
189    };
190
191    if pattern.is_empty() {
192        result.append_n(array.len(), true);
193    } else {
194        let re = Regex::new(pattern.as_str()).map_err(|e| {
195            ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}"))
196        })?;
197        for i in 0..array.len() {
198            let value = array.value(i);
199            result.append(re.is_match(value));
200        }
201    }
202
203    let buffer = result.into();
204    let data = unsafe {
205        ArrayData::new_unchecked(
206            DataType::Boolean,
207            array.len(),
208            None,
209            null_bit_buffer,
210            0,
211            vec![buffer],
212            vec![],
213        )
214    };
215
216    Ok(BooleanArray::from(data))
217}
218
// Shared body of `regexp_array_match` and `regexp_array_match_utf8view`.
//
// For every row, looks up (compiling and caching on first use) the row's pattern,
// prefixed with an inline `(?flags)` group when that row has a flag. It then appends
// the capture groups of the first match to `$list_builder`. It must be expanded inside
// a function returning `Result<_, ArrowError>`, because a pattern that fails to compile
// returns early with a `ComputeError`.
macro_rules! process_regexp_array_match {
    ($array:expr, $regex_array:expr, $flags_array:expr, $list_builder:expr) => {
        // Cache of compiled regexes keyed by full pattern text (flags included)
        let mut compiled: HashMap<String, Regex> = HashMap::new();

        // Per-row pattern text with any flag folded in as an inline `(?flag)` group
        let full_patterns = match $flags_array {
            Some(flags) => Box::new($regex_array.iter().zip(flags.iter()).map(
                |(pattern, flag)| {
                    pattern.map(|pattern| match flag {
                        Some(flag) => format!("(?{flag}){pattern}"),
                        None => pattern.to_string(),
                    })
                },
            )) as Box<dyn Iterator<Item = Option<String>>>,
            None => Box::new($regex_array.iter().map(|pattern| pattern.map(str::to_string))),
        };

        for (row_value, row_pattern) in $array.iter().zip(full_patterns) {
            let (text, pattern) = match (row_value, row_pattern) {
                (Some(text), Some(pattern)) => (text, pattern),
                // A null input or a null pattern yields a null list entry
                _ => {
                    $list_builder.append(false);
                    continue;
                }
            };

            // Required for Postgres compatibility:
            // SELECT regexp_match('foobarbequebaz', ''); = {""}
            if pattern.is_empty() {
                $list_builder.values().append_value("");
                $list_builder.append(true);
                continue;
            }

            let cached = compiled.get(&pattern);
            let re = match cached {
                Some(re) => re,
                None => {
                    let fresh = Regex::new(pattern.as_str()).map_err(|e| {
                        ArrowError::ComputeError(format!(
                            "Regular expression did not compile: {e:?}"
                        ))
                    })?;
                    compiled.entry(pattern).or_insert(fresh)
                }
            };

            match re.captures(text) {
                None => $list_builder.append(false),
                Some(caps) => {
                    // Group 0 is the whole match; it is only reported when the
                    // pattern has no explicit capture groups.
                    let skip = usize::from(caps.len() > 1);
                    // Groups that did not participate in the match are omitted
                    for m in caps.iter().skip(skip).flatten() {
                        $list_builder.values().append_value(m.as_str());
                    }
                    $list_builder.append(true);
                }
            }
        }
    };
}
285
286fn regexp_array_match<OffsetSize: OffsetSizeTrait>(
287    array: &GenericStringArray<OffsetSize>,
288    regex_array: &GenericStringArray<OffsetSize>,
289    flags_array: Option<&GenericStringArray<OffsetSize>>,
290) -> Result<ArrayRef, ArrowError> {
291    let builder: GenericStringBuilder<OffsetSize> = GenericStringBuilder::with_capacity(0, 0);
292    let mut list_builder = ListBuilder::new(builder);
293
294    process_regexp_array_match!(array, regex_array, flags_array, list_builder);
295
296    Ok(Arc::new(list_builder.finish()))
297}
298
299fn regexp_array_match_utf8view(
300    array: &StringViewArray,
301    regex_array: &StringViewArray,
302    flags_array: Option<&StringViewArray>,
303) -> Result<ArrayRef, ArrowError> {
304    let builder = StringViewBuilder::with_capacity(0);
305    let mut list_builder = ListBuilder::new(builder);
306
307    process_regexp_array_match!(array, regex_array, flags_array, list_builder);
308
309    Ok(Arc::new(list_builder.finish()))
310}
311
312fn get_scalar_pattern_flag<'a, OffsetSize: OffsetSizeTrait>(
313    regex_array: &'a dyn Array,
314    flag_array: Option<&'a dyn Array>,
315) -> (Option<&'a str>, Option<&'a str>) {
316    let regex = regex_array.as_string::<OffsetSize>();
317    let regex = regex.is_valid(0).then(|| regex.value(0));
318
319    if let Some(flag_array) = flag_array {
320        let flag = flag_array.as_string::<OffsetSize>();
321        (regex, flag.is_valid(0).then(|| flag.value(0)))
322    } else {
323        (regex, None)
324    }
325}
326
327fn get_scalar_pattern_flag_utf8view<'a>(
328    regex_array: &'a dyn Array,
329    flag_array: Option<&'a dyn Array>,
330) -> (Option<&'a str>, Option<&'a str>) {
331    let regex = regex_array.as_string_view();
332    let regex = regex.is_valid(0).then(|| regex.value(0));
333
334    if let Some(flag_array) = flag_array {
335        let flag = flag_array.as_string_view();
336        (regex, flag.is_valid(0).then(|| flag.value(0)))
337    } else {
338        (regex, None)
339    }
340}
341
// Shared body of `regexp_scalar_match` and `regexp_scalar_match_utf8view`.
//
// Appends one entry to `$list_builder` per row of `$array`. The entry holds the capture
// groups of the first match of the precompiled `$regex`, or is null when the row is null
// or does not match.
macro_rules! process_regexp_match {
    ($array:expr, $regex:expr, $list_builder:expr) => {
        // Required for Postgres compatibility:
        // SELECT regexp_match('foobarbequebaz', ''); = {""}
        let pattern_is_empty = $regex.as_str().is_empty();

        for row in $array.iter() {
            match row {
                None => $list_builder.append(false),
                Some(_) if pattern_is_empty => {
                    $list_builder.values().append_value("");
                    $list_builder.append(true);
                }
                Some(text) => {
                    if let Some(caps) = $regex.captures(text) {
                        // Group 0 is the whole match; it is only reported when the
                        // pattern has no explicit capture groups.
                        let skip = usize::from(caps.len() > 1);
                        // Groups that did not participate in the match are omitted
                        for m in caps.iter().skip(skip).flatten() {
                            $list_builder.values().append_value(m.as_str());
                        }
                        $list_builder.append(true);
                    } else {
                        $list_builder.append(false);
                    }
                }
            }
        }
    };
}
374
375fn regexp_scalar_match<OffsetSize: OffsetSizeTrait>(
376    array: &GenericStringArray<OffsetSize>,
377    regex: &Regex,
378) -> Result<ArrayRef, ArrowError> {
379    let builder: GenericStringBuilder<OffsetSize> = GenericStringBuilder::with_capacity(0, 0);
380    let mut list_builder = ListBuilder::new(builder);
381
382    process_regexp_match!(array, regex, list_builder);
383
384    Ok(Arc::new(list_builder.finish()))
385}
386
387fn regexp_scalar_match_utf8view(
388    array: &StringViewArray,
389    regex: &Regex,
390) -> Result<ArrayRef, ArrowError> {
391    let builder = StringViewBuilder::with_capacity(0);
392    let mut list_builder = ListBuilder::new(builder);
393
394    process_regexp_match!(array, regex, list_builder);
395
396    Ok(Arc::new(list_builder.finish()))
397}
398
399/// Extract all groups matched by a regular expression for a given String array.
400///
401/// Modelled after the Postgres [regexp_match].
402///
403/// Returns a ListArray of [`GenericStringArray`] with each element containing the leftmost-first
404/// match of the corresponding index in `regex_array` to string in `array`
405///
406/// If there is no match, the list element is NULL.
407///
408/// If a match is found, and the pattern contains no capturing parenthesized subexpressions,
409/// then the list element is a single-element [`GenericStringArray`] containing the substring
410/// matching the whole pattern.
411///
412/// If a match is found, and the pattern contains capturing parenthesized subexpressions, then the
413/// list element is a [`GenericStringArray`] whose n'th element is the substring matching
414/// the n'th capturing parenthesized subexpression of the pattern.
415///
416/// The flags parameter is an optional text string containing zero or more single-letter flags
417/// that change the function's behavior.
418///
419/// # See Also
420/// * [`regexp_is_match`] for matching (rather than extracting) a regular expression against an array of strings
421///
422/// [regexp_match]: https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP
423pub fn regexp_match(
424    array: &dyn Array,
425    regex_array: &dyn Datum,
426    flags_array: Option<&dyn Datum>,
427) -> Result<ArrayRef, ArrowError> {
428    let (rhs, is_rhs_scalar) = regex_array.get();
429
430    if array.data_type() != rhs.data_type() {
431        return Err(ArrowError::ComputeError(
432            "regexp_match() requires both array and pattern to be either Utf8, Utf8View or LargeUtf8"
433                .to_string(),
434        ));
435    }
436
437    let (flags, is_flags_scalar) = match flags_array {
438        Some(flags) => {
439            let (flags, is_flags_scalar) = flags.get();
440            (Some(flags), Some(is_flags_scalar))
441        }
442        None => (None, None),
443    };
444
445    if is_flags_scalar.is_some() && is_rhs_scalar != is_flags_scalar.unwrap() {
446        return Err(ArrowError::ComputeError(
447            "regexp_match() requires both pattern and flags to be either scalar or array"
448                .to_string(),
449        ));
450    }
451
452    if flags_array.is_some() && rhs.data_type() != flags.unwrap().data_type() {
453        return Err(ArrowError::ComputeError(
454            "regexp_match() requires both pattern and flags to be either Utf8, Utf8View or LargeUtf8"
455                .to_string(),
456        ));
457    }
458
459    if is_rhs_scalar {
460        // Regex and flag is scalars
461        let (regex, flag) = match rhs.data_type() {
462            DataType::Utf8View => get_scalar_pattern_flag_utf8view(rhs, flags),
463            DataType::Utf8 => get_scalar_pattern_flag::<i32>(rhs, flags),
464            DataType::LargeUtf8 => get_scalar_pattern_flag::<i64>(rhs, flags),
465            _ => {
466                return Err(ArrowError::ComputeError(
467                    "regexp_match() requires pattern to be either Utf8, Utf8View or LargeUtf8"
468                        .to_string(),
469                ));
470            }
471        };
472
473        if regex.is_none() {
474            return Ok(new_null_array(
475                &DataType::List(Arc::new(Field::new_list_field(
476                    array.data_type().clone(),
477                    true,
478                ))),
479                array.len(),
480            ));
481        }
482
483        let regex = regex.unwrap();
484
485        let pattern = if let Some(flag) = flag {
486            format!("(?{flag}){regex}")
487        } else {
488            regex.to_string()
489        };
490
491        let re = Regex::new(pattern.as_str()).map_err(|e| {
492            ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}"))
493        })?;
494
495        match array.data_type() {
496            DataType::Utf8View => regexp_scalar_match_utf8view(array.as_string_view(), &re),
497            DataType::Utf8 => regexp_scalar_match(array.as_string::<i32>(), &re),
498            DataType::LargeUtf8 => regexp_scalar_match(array.as_string::<i64>(), &re),
499            _ => Err(ArrowError::ComputeError(
500                "regexp_match() requires array to be either Utf8, Utf8View or LargeUtf8"
501                    .to_string(),
502            )),
503        }
504    } else {
505        match array.data_type() {
506            DataType::Utf8View => {
507                let regex_array = rhs.as_string_view();
508                let flags_array = flags.map(|flags| flags.as_string_view());
509                regexp_array_match_utf8view(array.as_string_view(), regex_array, flags_array)
510            }
511            DataType::Utf8 => {
512                let regex_array = rhs.as_string();
513                let flags_array = flags.map(|flags| flags.as_string());
514                regexp_array_match(array.as_string::<i32>(), regex_array, flags_array)
515            }
516            DataType::LargeUtf8 => {
517                let regex_array = rhs.as_string();
518                let flags_array = flags.map(|flags| flags.as_string());
519                regexp_array_match(array.as_string::<i64>(), regex_array, flags_array)
520            }
521            _ => Err(ArrowError::ComputeError(
522                "regexp_match() requires array to be either Utf8, Utf8View or LargeUtf8"
523                    .to_string(),
524            )),
525        }
526    }
527}
528
529#[cfg(test)]
530mod tests {
531    use super::*;
532
533    macro_rules! test_match_single_group {
534        ($test_name:ident, $values:expr, $patterns:expr, $arr_type:ty, $builder_type:ty, $expected:expr) => {
535            #[test]
536            fn $test_name() {
537                let array: $arr_type = <$arr_type>::from($values);
538                let pattern: $arr_type = <$arr_type>::from($patterns);
539
540                let actual = regexp_match(&array, &pattern, None).unwrap();
541
542                let elem_builder: $builder_type = <$builder_type>::new();
543                let mut expected_builder = ListBuilder::new(elem_builder);
544
545                for val in $expected {
546                    match val {
547                        Some(v) => {
548                            expected_builder.values().append_value(v);
549                            expected_builder.append(true);
550                        }
551                        None => expected_builder.append(false),
552                    }
553                }
554
555                let expected = expected_builder.finish();
556                let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
557                assert_eq!(&expected, result);
558            }
559        };
560    }
561
562    test_match_single_group!(
563        match_single_group_string,
564        vec![
565            Some("abc-005-def"),
566            Some("X-7-5"),
567            Some("X545"),
568            None,
569            Some("foobarbequebaz"),
570            Some("foobarbequebaz"),
571        ],
572        vec![
573            r".*-(\d*)-.*",
574            r".*-(\d*)-.*",
575            r".*-(\d*)-.*",
576            r".*-(\d*)-.*",
577            r"(bar)(bequ1e)",
578            ""
579        ],
580        StringArray,
581        GenericStringBuilder<i32>,
582        [Some("005"), Some("7"), None, None, None, Some("")]
583    );
584    test_match_single_group!(
585        match_single_group_string_view,
586        vec![
587            Some("abc-005-def"),
588            Some("X-7-5"),
589            Some("X545"),
590            None,
591            Some("foobarbequebaz"),
592            Some("foobarbequebaz"),
593        ],
594        vec![
595            r".*-(\d*)-.*",
596            r".*-(\d*)-.*",
597            r".*-(\d*)-.*",
598            r".*-(\d*)-.*",
599            r"(bar)(bequ1e)",
600            ""
601        ],
602        StringViewArray,
603        StringViewBuilder,
604        [Some("005"), Some("7"), None, None, None, Some("")]
605    );
606
607    macro_rules! test_match_single_group_with_flags {
608        ($test_name:ident, $values:expr, $patterns:expr, $flags:expr, $array_type:ty, $builder_type:ty, $expected:expr) => {
609            #[test]
610            fn $test_name() {
611                let array: $array_type = <$array_type>::from($values);
612                let pattern: $array_type = <$array_type>::from($patterns);
613                let flags: $array_type = <$array_type>::from($flags);
614
615                let actual = regexp_match(&array, &pattern, Some(&flags)).unwrap();
616
617                let elem_builder: $builder_type = <$builder_type>::new();
618                let mut expected_builder = ListBuilder::new(elem_builder);
619
620                for val in $expected {
621                    match val {
622                        Some(v) => {
623                            expected_builder.values().append_value(v);
624                            expected_builder.append(true);
625                        }
626                        None => {
627                            expected_builder.append(false);
628                        }
629                    }
630                }
631
632                let expected = expected_builder.finish();
633                let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
634                assert_eq!(&expected, result);
635            }
636        };
637    }
638
639    test_match_single_group_with_flags!(
640        match_single_group_with_flags_string,
641        vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None],
642        vec![r"x.*-(\d*)-.*"; 4],
643        vec!["i"; 4],
644        StringArray,
645        GenericStringBuilder<i32>,
646        [None, Some("7"), None, None]
647    );
648    test_match_single_group_with_flags!(
649        match_single_group_with_flags_stringview,
650        vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None],
651        vec![r"x.*-(\d*)-.*"; 4],
652        vec!["i"; 4],
653        StringViewArray,
654        StringViewBuilder,
655        [None, Some("7"), None, None]
656    );
657
658    macro_rules! test_match_scalar_pattern {
659        ($test_name:ident, $values:expr, $pattern:expr, $flag:expr, $array_type:ty, $builder_type:ty, $expected:expr) => {
660            #[test]
661            fn $test_name() {
662                let array: $array_type = <$array_type>::from($values);
663
664                let pattern_scalar = Scalar::new(<$array_type>::from(vec![$pattern; 1]));
665                let flag_scalar = Scalar::new(<$array_type>::from(vec![$flag; 1]));
666
667                let actual = regexp_match(&array, &pattern_scalar, Some(&flag_scalar)).unwrap();
668
669                let elem_builder: $builder_type = <$builder_type>::new();
670                let mut expected_builder = ListBuilder::new(elem_builder);
671
672                for val in $expected {
673                    match val {
674                        Some(v) => {
675                            expected_builder.values().append_value(v);
676                            expected_builder.append(true);
677                        }
678                        None => expected_builder.append(false),
679                    }
680                }
681
682                let expected = expected_builder.finish();
683                let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
684                assert_eq!(&expected, result);
685            }
686        };
687    }
688
689    test_match_scalar_pattern!(
690        match_scalar_pattern_string_with_flags,
691        vec![
692            Some("abc-005-def"),
693            Some("x-7-5"),
694            Some("X-0-Y"),
695            Some("X545"),
696            None
697        ],
698        r"x.*-(\d*)-.*",
699        Some("i"),
700        StringArray,
701        GenericStringBuilder<i32>,
702        [None, Some("7"), Some("0"), None, None]
703    );
704    test_match_scalar_pattern!(
705        match_scalar_pattern_stringview_with_flags,
706        vec![
707            Some("abc-005-def"),
708            Some("x-7-5"),
709            Some("X-0-Y"),
710            Some("X545"),
711            None
712        ],
713        r"x.*-(\d*)-.*",
714        Some("i"),
715        StringViewArray,
716        StringViewBuilder,
717        [None, Some("7"), Some("0"), None, None]
718    );
719
720    test_match_scalar_pattern!(
721        match_scalar_pattern_string_no_flags,
722        vec![
723            Some("abc-005-def"),
724            Some("x-7-5"),
725            Some("X-0-Y"),
726            Some("X545"),
727            None
728        ],
729        r"x.*-(\d*)-.*",
730        None::<&str>,
731        StringArray,
732        GenericStringBuilder<i32>,
733        [None, Some("7"), None, None, None]
734    );
735    test_match_scalar_pattern!(
736        match_scalar_pattern_stringview_no_flags,
737        vec![
738            Some("abc-005-def"),
739            Some("x-7-5"),
740            Some("X-0-Y"),
741            Some("X545"),
742            None
743        ],
744        r"x.*-(\d*)-.*",
745        None::<&str>,
746        StringViewArray,
747        StringViewBuilder,
748        [None, Some("7"), None, None, None]
749    );
750
751    macro_rules! test_match_scalar_no_pattern {
752        ($test_name:ident, $values:expr, $array_type:ty, $pattern_type:expr, $builder_type:ty, $expected:expr) => {
753            #[test]
754            fn $test_name() {
755                let array: $array_type = <$array_type>::from($values);
756                let pattern = Scalar::new(new_null_array(&$pattern_type, 1));
757
758                let actual = regexp_match(&array, &pattern, None).unwrap();
759
760                let elem_builder: $builder_type = <$builder_type>::new();
761                let mut expected_builder = ListBuilder::new(elem_builder);
762
763                for val in $expected {
764                    match val {
765                        Some(v) => {
766                            expected_builder.values().append_value(v);
767                            expected_builder.append(true);
768                        }
769                        None => expected_builder.append(false),
770                    }
771                }
772
773                let expected = expected_builder.finish();
774                let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
775                assert_eq!(&expected, result);
776            }
777        };
778    }
779
780    test_match_scalar_no_pattern!(
781        match_scalar_no_pattern_string,
782        vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None],
783        StringArray,
784        DataType::Utf8,
785        GenericStringBuilder<i32>,
786        [None::<&str>, None, None, None]
787    );
788    test_match_scalar_no_pattern!(
789        match_scalar_no_pattern_stringview,
790        vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None],
791        StringViewArray,
792        DataType::Utf8View,
793        StringViewBuilder,
794        [None::<&str>, None, None, None]
795    );
796
797    macro_rules! test_match_single_group_not_skip {
798        ($test_name:ident, $values:expr, $pattern:expr, $array_type:ty, $builder_type:ty, $expected:expr) => {
799            #[test]
800            fn $test_name() {
801                let array: $array_type = <$array_type>::from($values);
802                let pattern: $array_type = <$array_type>::from(vec![$pattern]);
803
804                let actual = regexp_match(&array, &pattern, None).unwrap();
805
806                let elem_builder: $builder_type = <$builder_type>::new();
807                let mut expected_builder = ListBuilder::new(elem_builder);
808
809                for val in $expected {
810                    match val {
811                        Some(v) => {
812                            expected_builder.values().append_value(v);
813                            expected_builder.append(true);
814                        }
815                        None => expected_builder.append(false),
816                    }
817                }
818
819                let expected = expected_builder.finish();
820                let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
821                assert_eq!(&expected, result);
822            }
823        };
824    }
825
826    test_match_single_group_not_skip!(
827        match_single_group_not_skip_string,
828        vec![Some("foo"), Some("bar")],
829        r"foo",
830        StringArray,
831        GenericStringBuilder<i32>,
832        [Some("foo")]
833    );
834    test_match_single_group_not_skip!(
835        match_single_group_not_skip_stringview,
836        vec![Some("foo"), Some("bar")],
837        r"foo",
838        StringViewArray,
839        StringViewBuilder,
840        [Some("foo")]
841    );
842
843    macro_rules! test_flag_utf8 {
844        ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => {
845            #[test]
846            fn $test_name() {
847                let left = $left;
848                let right = $right;
849                let res = $op(&left, &right, None).unwrap();
850                let expected = $expected;
851                assert_eq!(expected.len(), res.len());
852                for i in 0..res.len() {
853                    let v = res.value(i);
854                    assert_eq!(v, expected[i]);
855                }
856            }
857        };
858        ($test_name:ident, $left:expr, $right:expr, $flag:expr, $op:expr, $expected:expr) => {
859            #[test]
860            fn $test_name() {
861                let left = $left;
862                let right = $right;
863                let flag = Some($flag);
864                let res = $op(&left, &right, flag.as_ref()).unwrap();
865                let expected = $expected;
866                assert_eq!(expected.len(), res.len());
867                for i in 0..res.len() {
868                    let v = res.value(i);
869                    assert_eq!(v, expected[i]);
870                }
871            }
872        };
873    }
874
875    macro_rules! test_flag_utf8_scalar {
876        ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => {
877            #[test]
878            fn $test_name() {
879                let left = $left;
880                let res = $op(&left, $right, None).unwrap();
881                let expected = $expected;
882                assert_eq!(expected.len(), res.len());
883                for i in 0..res.len() {
884                    let v = res.value(i);
885                    assert_eq!(
886                        v,
887                        expected[i],
888                        "unexpected result when comparing {} at position {} to {} ",
889                        left.value(i),
890                        i,
891                        $right
892                    );
893                }
894            }
895        };
896        ($test_name:ident, $left:expr, $right:expr, $flag:expr, $op:expr, $expected:expr) => {
897            #[test]
898            fn $test_name() {
899                let left = $left;
900                let flag = Some($flag);
901                let res = $op(&left, $right, flag).unwrap();
902                let expected = $expected;
903                assert_eq!(expected.len(), res.len());
904                for i in 0..res.len() {
905                    let v = res.value(i);
906                    assert_eq!(
907                        v,
908                        expected[i],
909                        "unexpected result when comparing {} at position {} to {} ",
910                        left.value(i),
911                        i,
912                        $right
913                    );
914                }
915            }
916        };
917    }
918
919    test_flag_utf8!(
920        test_array_regexp_is_match_utf8,
921        StringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
922        StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
923        regexp_is_match::<StringArray, StringArray, StringArray>,
924        [true, false, true, false, false, true]
925    );
926    test_flag_utf8!(
927        test_array_regexp_is_match_utf8_insensitive,
928        StringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
929        StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
930        StringArray::from(vec!["i"; 6]),
931        regexp_is_match,
932        [true, true, true, true, false, true]
933    );
934
935    test_flag_utf8_scalar!(
936        test_array_regexp_is_match_utf8_scalar,
937        StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]),
938        "^ar",
939        regexp_is_match_scalar,
940        [true, false, false, false]
941    );
942    test_flag_utf8_scalar!(
943        test_array_regexp_is_match_utf8_scalar_empty,
944        StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]),
945        "",
946        regexp_is_match_scalar,
947        [true, true, true, true]
948    );
949    test_flag_utf8_scalar!(
950        test_array_regexp_is_match_utf8_scalar_insensitive,
951        StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]),
952        "^ar",
953        "i",
954        regexp_is_match_scalar,
955        [true, true, false, false]
956    );
957
958    test_flag_utf8!(
959        tes_array_regexp_is_match,
960        StringViewArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
961        StringViewArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
962        regexp_is_match::<StringViewArray, StringViewArray, StringViewArray>,
963        [true, false, true, false, false, true]
964    );
965    test_flag_utf8!(
966        test_array_regexp_is_match_2,
967        StringViewArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
968        StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
969        regexp_is_match::<StringViewArray, GenericStringArray<i32>, GenericStringArray<i32>>,
970        [true, false, true, false, false, true]
971    );
972    test_flag_utf8!(
973        test_array_regexp_is_match_insensitive,
974        StringViewArray::from(vec![
975            "Official Rust implementation of Apache Arrow",
976            "apache/arrow-rs",
977            "apache/arrow-rs",
978            "parquet",
979            "parquet",
980            "row",
981            "row",
982        ]),
983        StringViewArray::from(vec![
984            ".*rust implement.*",
985            "^ap",
986            "^AP",
987            "et$",
988            "ET$",
989            "foo",
990            ""
991        ]),
992        StringViewArray::from(vec!["i"; 7]),
993        regexp_is_match::<StringViewArray, StringViewArray, StringViewArray>,
994        [true, true, true, true, true, false, true]
995    );
996    test_flag_utf8!(
997        test_array_regexp_is_match_insensitive_2,
998        LargeStringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
999        StringViewArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
1000        StringArray::from(vec!["i"; 6]),
1001        regexp_is_match::<GenericStringArray<i64>, StringViewArray, GenericStringArray<i32>>,
1002        [true, true, true, true, false, true]
1003    );
1004
1005    test_flag_utf8_scalar!(
1006        test_array_regexp_is_match_scalar,
1007        StringViewArray::from(vec![
1008            "apache/arrow-rs",
1009            "APACHE/ARROW-RS",
1010            "parquet",
1011            "PARQUET",
1012        ]),
1013        "^ap",
1014        regexp_is_match_scalar::<StringViewArray>,
1015        [true, false, false, false]
1016    );
1017    test_flag_utf8_scalar!(
1018        test_array_regexp_is_match_scalar_empty,
1019        StringViewArray::from(vec![
1020            "apache/arrow-rs",
1021            "APACHE/ARROW-RS",
1022            "parquet",
1023            "PARQUET",
1024        ]),
1025        "",
1026        regexp_is_match_scalar::<StringViewArray>,
1027        [true, true, true, true]
1028    );
1029    test_flag_utf8_scalar!(
1030        test_array_regexp_is_match_scalar_insensitive,
1031        StringViewArray::from(vec![
1032            "apache/arrow-rs",
1033            "APACHE/ARROW-RS",
1034            "parquet",
1035            "PARQUET",
1036        ]),
1037        "^ap",
1038        "i",
1039        regexp_is_match_scalar::<StringViewArray>,
1040        [true, true, false, false]
1041    );
1042}