From d05cf6d5e74e79ddcacaa4a68bddaba230b0f163 Mon Sep 17 00:00:00 2001 From: Tai Le Manh <49281946+tlm365@users.noreply.github.com> Date: Sat, 21 Sep 2024 13:44:45 -0700 Subject: [PATCH] Implement native support StringViewArray for `regexp_is_match` and `regexp_is_match_scalar` function, deprecate `regexp_is_match_utf8` and `regexp_is_match_utf8_scalar` (#6376) * Implement native support StringViewArray for regex_is_match function * Update test cases cover StringViewArray length more then 12 bytes * Add StringView benchmark for regexp_is_match Signed-off-by: Tai Le Manh * Implement native support StringViewArray for regex_is_match function Signed-off-by: Tai Le Manh * Remove duplicate implementation, fix clippy, add docs more --------- Signed-off-by: Tai Le Manh Co-authored-by: Andrew Lamb --- arrow-string/src/like.rs | 2 +- arrow-string/src/regexp.rs | 228 +++++++++++++++++++++++++--- arrow/benches/comparison_kernels.rs | 67 ++++++-- arrow/src/compute/kernels.rs | 3 + 4 files changed, 261 insertions(+), 39 deletions(-) diff --git a/arrow-string/src/like.rs b/arrow-string/src/like.rs index 4626be1362e9..4a6c5bab90e6 100644 --- a/arrow-string/src/like.rs +++ b/arrow-string/src/like.rs @@ -155,7 +155,7 @@ fn like_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result: ArrayAccessor + Sized { +pub trait StringArrayType<'a>: ArrayAccessor + Sized { fn is_ascii(&self) -> bool; fn iter(&self) -> ArrayIter; } diff --git a/arrow-string/src/regexp.rs b/arrow-string/src/regexp.rs index f79eff4b6ea8..5ad452a17b12 100644 --- a/arrow-string/src/regexp.rs +++ b/arrow-string/src/regexp.rs @@ -18,6 +18,8 @@ //! Defines kernel to extract substrings based on a regular //! expression of a \[Large\]StringArray +use crate::like::StringArrayType; + use arrow_array::builder::{BooleanBufferBuilder, GenericStringBuilder, ListBuilder}; use arrow_array::cast::AsArray; use arrow_array::*; @@ -25,6 +27,7 @@ use arrow_buffer::NullBuffer; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType, Field}; use regex::Regex; + use std::collections::HashMap; use std::sync::Arc; @@ -35,16 +38,64 @@ use std::sync::Arc; /// special search modes, such as case insensitive and multi-line mode. /// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags) /// for more information. +#[deprecated(since = "54.0.0", note = "please use `regex_is_match` instead")] pub fn regexp_is_match_utf8( array: &GenericStringArray, regex_array: &GenericStringArray, flags_array: Option<&GenericStringArray>, ) -> Result { + regexp_is_match(array, regex_array, flags_array) +} + +/// Return BooleanArray indicating which strings in an array match an array of +/// regular expressions. +/// +/// This is equivalent to the SQL `array ~ regex_array`, supporting +/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`]. +/// +/// If `regex_array` element has an empty value, the corresponding result value is always true. +/// +/// `flags_array` are optional [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] flag, +/// which allow special search modes, such as case-insensitive and multi-line mode. +/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags) +/// for more information. +/// +/// # See Also +/// * [`regexp_is_match_scalar`] for matching a single regular expression against an array of strings +/// * [`regexp_match`] for extracting groups from a string array based on a regular expression +/// +/// # Example +/// ``` +/// # use arrow_array::{StringArray, BooleanArray}; +/// # use arrow_string::regexp::regexp_is_match; +/// // First array is the array of strings to match +/// let array = StringArray::from(vec!["Foo", "Bar", "FooBar", "Baz"]); +/// // Second array is the array of regular expressions to match against +/// let regex_array = StringArray::from(vec!["^Foo", "^Foo", "Bar$", "Baz"]); +/// // Third array is the array of flags to use for each regular expression, if desired +/// // (the type must be provided to satisfy type inference for the third parameter) +/// let flags_array: Option<&StringArray> = None; +/// // The result is a BooleanArray indicating when each string in `array` +/// // matches the corresponding regular expression in `regex_array` +/// let result = regexp_is_match(&array, ®ex_array, flags_array).unwrap(); +/// assert_eq!(result, BooleanArray::from(vec![true, false, true, true])); +/// ``` +pub fn regexp_is_match<'a, S1, S2, S3>( + array: &'a S1, + regex_array: &'a S2, + flags_array: Option<&'a S3>, +) -> Result +where + &'a S1: StringArrayType<'a>, + &'a S2: StringArrayType<'a>, + &'a S3: StringArrayType<'a>, +{ if array.len() != regex_array.len() { return Err(ArrowError::ComputeError( "Cannot perform comparison operation on arrays of different length".to_string(), )); } + let nulls = NullBuffer::union(array.nulls(), regex_array.nulls()); let mut patterns: HashMap = HashMap::new(); @@ -107,6 +158,7 @@ pub fn regexp_is_match_utf8( .nulls(nulls) .build_unchecked() }; + Ok(BooleanArray::from(data)) } @@ -114,11 +166,47 @@ pub fn regexp_is_match_utf8( /// [`LargeStringArray`] and a scalar. /// /// See the documentation on [`regexp_is_match_utf8`] for more details. +#[deprecated(since = "54.0.0", note = "please use `regex_is_match_scalar` instead")] pub fn regexp_is_match_utf8_scalar( array: &GenericStringArray, regex: &str, flag: Option<&str>, ) -> Result { + regexp_is_match_scalar(array, regex, flag) +} + +/// Return BooleanArray indicating which strings in an array match a single regular expression. +/// +/// This is equivalent to the SQL `array ~ regex_array`, supporting +/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] and a scalar. +/// +/// See the documentation on [`regexp_is_match`] for more details on arguments +/// +/// # See Also +/// * [`regexp_is_match`] for matching an array of regular expression against an array of strings +/// * [`regexp_match`] for extracting groups from a string array based on a regular expression +/// +/// # Example +/// ``` +/// # use arrow_array::{StringArray, BooleanArray}; +/// # use arrow_string::regexp::regexp_is_match_scalar; +/// // array of strings to match +/// let array = StringArray::from(vec!["Foo", "Bar", "FooBar", "Baz"]); +/// let regexp = "^Foo"; // regular expression to match against +/// let flags: Option<&str> = None; // flags can control the matching behavior +/// // The result is a BooleanArray indicating when each string in `array` +/// // matches the regular expression `regexp` +/// let result = regexp_is_match_scalar(&array, regexp, None).unwrap(); +/// assert_eq!(result, BooleanArray::from(vec![true, false, true, false])); +/// ``` +pub fn regexp_is_match_scalar<'a, S>( + array: &'a S, + regex: &str, + flag: Option<&str>, +) -> Result +where + &'a S: StringArrayType<'a>, +{ let null_bit_buffer = array.nulls().map(|x| x.inner().sliced()); let mut result = BooleanBufferBuilder::new(array.len()); @@ -126,6 +214,7 @@ pub fn regexp_is_match_utf8_scalar( Some(flag) => format!("(?{flag}){regex}"), None => regex.to_string(), }; + if pattern.is_empty() { result.append_n(array.len(), true); } else { @@ -150,6 +239,7 @@ pub fn regexp_is_match_utf8_scalar( vec![], ) }; + Ok(BooleanArray::from(data)) } @@ -303,6 +393,9 @@ fn regexp_scalar_match( /// The flags parameter is an optional text string containing zero or more single-letter flags /// that change the function's behavior. /// +/// # See Also +/// * [`regexp_is_match`] for matching (rather than extracting) a regular expression against an array of strings +/// /// [regexp_match]: https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP pub fn regexp_match( array: &dyn Array, @@ -517,8 +610,8 @@ mod tests { ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => { #[test] fn $test_name() { - let left = StringArray::from($left); - let right = StringArray::from($right); + let left = $left; + let right = $right; let res = $op(&left, &right, None).unwrap(); let expected = $expected; assert_eq!(expected.len(), res.len()); @@ -531,9 +624,9 @@ mod tests { ($test_name:ident, $left:expr, $right:expr, $flag:expr, $op:expr, $expected:expr) => { #[test] fn $test_name() { - let left = StringArray::from($left); - let right = StringArray::from($right); - let flag = Some(StringArray::from($flag)); + let left = $left; + let right = $right; + let flag = Some($flag); let res = $op(&left, &right, flag.as_ref()).unwrap(); let expected = $expected; assert_eq!(expected.len(), res.len()); @@ -549,7 +642,7 @@ mod tests { ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => { #[test] fn $test_name() { - let left = StringArray::from($left); + let left = $left; let res = $op(&left, $right, None).unwrap(); let expected = $expected; assert_eq!(expected.len(), res.len()); @@ -569,7 +662,7 @@ mod tests { ($test_name:ident, $left:expr, $right:expr, $flag:expr, $op:expr, $expected:expr) => { #[test] fn $test_name() { - let left = StringArray::from($left); + let left = $left; let flag = Some($flag); let res = $op(&left, $right, flag).unwrap(); let expected = $expected; @@ -590,41 +683,126 @@ mod tests { } test_flag_utf8!( - test_utf8_array_regexp_is_match, - vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"], - vec!["^ar", "^AR", "ow$", "OW$", "foo", ""], - regexp_is_match_utf8, + test_array_regexp_is_match_utf8, + StringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]), + StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]), + regexp_is_match::, [true, false, true, false, false, true] ); test_flag_utf8!( - test_utf8_array_regexp_is_match_insensitive, - vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"], - vec!["^ar", "^AR", "ow$", "OW$", "foo", ""], - vec!["i"; 6], - regexp_is_match_utf8, + test_array_regexp_is_match_utf8_insensitive, + StringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]), + StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]), + StringArray::from(vec!["i"; 6]), + regexp_is_match, [true, true, true, true, false, true] ); test_flag_utf8_scalar!( - test_utf8_array_regexp_is_match_scalar, - vec!["arrow", "ARROW", "parquet", "PARQUET"], + test_array_regexp_is_match_utf8_scalar, + StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]), "^ar", - regexp_is_match_utf8_scalar, + regexp_is_match_scalar, [true, false, false, false] ); test_flag_utf8_scalar!( - test_utf8_array_regexp_is_match_empty_scalar, - vec!["arrow", "ARROW", "parquet", "PARQUET"], + test_array_regexp_is_match_utf8_scalar_empty, + StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]), "", - regexp_is_match_utf8_scalar, + regexp_is_match_scalar, [true, true, true, true] ); test_flag_utf8_scalar!( - test_utf8_array_regexp_is_match_insensitive_scalar, - vec!["arrow", "ARROW", "parquet", "PARQUET"], + test_array_regexp_is_match_utf8_scalar_insensitive, + StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]), "^ar", "i", - regexp_is_match_utf8_scalar, + regexp_is_match_scalar, + [true, true, false, false] + ); + + test_flag_utf8!( + tes_array_regexp_is_match, + StringViewArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]), + StringViewArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]), + regexp_is_match::, + [true, false, true, false, false, true] + ); + test_flag_utf8!( + test_array_regexp_is_match_2, + StringViewArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]), + StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]), + regexp_is_match::, GenericStringArray>, + [true, false, true, false, false, true] + ); + test_flag_utf8!( + test_array_regexp_is_match_insensitive, + StringViewArray::from(vec![ + "Official Rust implementation of Apache Arrow", + "apache/arrow-rs", + "apache/arrow-rs", + "parquet", + "parquet", + "row", + "row", + ]), + StringViewArray::from(vec![ + ".*rust implement.*", + "^ap", + "^AP", + "et$", + "ET$", + "foo", + "" + ]), + StringViewArray::from(vec!["i"; 7]), + regexp_is_match::, + [true, true, true, true, true, false, true] + ); + test_flag_utf8!( + test_array_regexp_is_match_insensitive_2, + LargeStringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]), + StringViewArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]), + StringArray::from(vec!["i"; 6]), + regexp_is_match::, StringViewArray, GenericStringArray>, + [true, true, true, true, false, true] + ); + + test_flag_utf8_scalar!( + test_array_regexp_is_match_scalar, + StringViewArray::from(vec![ + "apache/arrow-rs", + "APACHE/ARROW-RS", + "parquet", + "PARQUET", + ]), + "^ap", + regexp_is_match_scalar::, + [true, false, false, false] + ); + test_flag_utf8_scalar!( + test_array_regexp_is_match_scalar_empty, + StringViewArray::from(vec![ + "apache/arrow-rs", + "APACHE/ARROW-RS", + "parquet", + "PARQUET", + ]), + "", + regexp_is_match_scalar::, + [true, true, true, true] + ); + test_flag_utf8_scalar!( + test_array_regexp_is_match_scalar_insensitive, + StringViewArray::from(vec![ + "apache/arrow-rs", + "APACHE/ARROW-RS", + "parquet", + "PARQUET", + ]), + "^ap", + "i", + regexp_is_match_scalar::, [true, true, false, false] ); } diff --git a/arrow/benches/comparison_kernels.rs b/arrow/benches/comparison_kernels.rs index c8aa7dfcf530..4c4a63a775a7 100644 --- a/arrow/benches/comparison_kernels.rs +++ b/arrow/benches/comparison_kernels.rs @@ -15,19 +15,18 @@ // specific language governing permissions and limitations // under the License. +extern crate arrow; #[macro_use] extern crate criterion; -use arrow::util::test_util::seedable_rng; -use criterion::Criterion; - -extern crate arrow; use arrow::compute::kernels::cmp::*; use arrow::util::bench_util::*; +use arrow::util::test_util::seedable_rng; use arrow::{array::*, datatypes::Float32Type, datatypes::Int32Type}; use arrow_buffer::IntervalMonthDayNano; use arrow_string::like::*; -use arrow_string::regexp::regexp_is_match_utf8_scalar; +use arrow_string::regexp::regexp_is_match_scalar; +use criterion::Criterion; use rand::rngs::StdRng; use rand::Rng; @@ -53,8 +52,17 @@ fn bench_nilike_utf8_scalar(arr_a: &StringArray, value_b: &str) { nilike(arr_a, &StringArray::new_scalar(value_b)).unwrap(); } -fn bench_regexp_is_match_utf8_scalar(arr_a: &StringArray, value_b: &str) { - regexp_is_match_utf8_scalar( +fn bench_stringview_regexp_is_match_scalar(arr_a: &StringViewArray, value_b: &str) { + regexp_is_match_scalar( + criterion::black_box(arr_a), + criterion::black_box(value_b), + None, + ) + .unwrap(); +} + +fn bench_string_regexp_is_match_scalar(arr_a: &StringArray, value_b: &str) { + regexp_is_match_scalar( criterion::black_box(arr_a), criterion::black_box(value_b), None, @@ -78,6 +86,7 @@ fn add_benchmark(c: &mut Criterion) { let arr_month_day_nano_b = create_month_day_nano_array_with_seed(SIZE, 0.0, 43); let arr_string = create_string_array::(SIZE, 0.0); + let arr_string_view = create_string_view_array(SIZE, 0.0); let scalar = Float32Array::from(vec![1.0]); @@ -343,13 +352,45 @@ fn add_benchmark(c: &mut Criterion) { b.iter(|| bench_nilike_utf8_scalar(&arr_string, "%xx_xX%xXX")) }); - c.bench_function("regexp_matches_utf8 scalar starts with", |b| { - b.iter(|| bench_regexp_is_match_utf8_scalar(&arr_string, "^xx")) - }); + // StringArray: regexp_matches_utf8 scalar benchmarks + let mut group = + c.benchmark_group("StringArray: regexp_matches_utf8 scalar benchmarks".to_string()); - c.bench_function("regexp_matches_utf8 scalar ends with", |b| { - b.iter(|| bench_regexp_is_match_utf8_scalar(&arr_string, "xx$")) - }); + group + .bench_function("regexp_matches_utf8 scalar starts with", |b| { + b.iter(|| bench_string_regexp_is_match_scalar(&arr_string, "^xx")) + }) + .bench_function("regexp_matches_utf8 scalar contains", |b| { + b.iter(|| bench_string_regexp_is_match_scalar(&arr_string, ".*xxXX.*")) + }) + .bench_function("regexp_matches_utf8 scalar ends with", |b| { + b.iter(|| bench_string_regexp_is_match_scalar(&arr_string, "xx$")) + }) + .bench_function("regexp_matches_utf8 scalar complex", |b| { + b.iter(|| bench_string_regexp_is_match_scalar(&arr_string, ".*x{2}.xX.*xXX")) + }); + + group.finish(); + + // StringViewArray: regexp_matches_utf8view scalar benchmarks + group = + c.benchmark_group("StringViewArray: regexp_matches_utf8view scalar benchmarks".to_string()); + + group + .bench_function("regexp_matches_utf8view scalar starts with", |b| { + b.iter(|| bench_stringview_regexp_is_match_scalar(&arr_string_view, "^xx")) + }) + .bench_function("regexp_matches_utf8view scalar contains", |b| { + b.iter(|| bench_stringview_regexp_is_match_scalar(&arr_string_view, ".*xxXX.*")) + }) + .bench_function("regexp_matches_utf8view scalar ends with", |b| { + b.iter(|| bench_stringview_regexp_is_match_scalar(&arr_string_view, "xx$")) + }) + .bench_function("regexp_matches_utf8view scalar complex", |b| { + b.iter(|| bench_stringview_regexp_is_match_scalar(&arr_string_view, ".*x{2}.xX.*xXX")) + }); + + group.finish(); // DictionaryArray benchmarks diff --git a/arrow/src/compute/kernels.rs b/arrow/src/compute/kernels.rs index 4eeb5892c97c..426952ebb5d4 100644 --- a/arrow/src/compute/kernels.rs +++ b/arrow/src/compute/kernels.rs @@ -28,5 +28,8 @@ pub use arrow_string::{concat_elements, length, regexp, substring}; pub mod comparison { pub use arrow_ord::comparison::*; pub use arrow_string::like::*; + // continue to export deprecated methods until they are removed + pub use arrow_string::regexp::{regexp_is_match, regexp_is_match_scalar}; + #[allow(deprecated)] pub use arrow_string::regexp::{regexp_is_match_utf8, regexp_is_match_utf8_scalar}; }