From e6395e21d923caf5b7cd643fc5d4418642d3bb3a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 1 Jan 2024 03:21:53 -0800 Subject: [PATCH] Make regexp_match take scalar pattern and flag (#5245) * Make regexp_match take Datum pattern input * Add more tests * More * Update benchmark * Fix clippy * For review * Fix clippy * Don't expose utility function --- arrow-string/src/regexp.rs | 242 +++++++++++++++++++++++++++++--- arrow/benches/regexp_kernels.rs | 9 +- 2 files changed, 227 insertions(+), 24 deletions(-) diff --git a/arrow-string/src/regexp.rs b/arrow-string/src/regexp.rs index 25c712d20f08..5e539b91b492 100644 --- a/arrow-string/src/regexp.rs +++ b/arrow-string/src/regexp.rs @@ -19,10 +19,11 @@ //! expression of a \[Large\]StringArray use arrow_array::builder::{BooleanBufferBuilder, GenericStringBuilder, ListBuilder}; +use arrow_array::cast::AsArray; use arrow_array::*; use arrow_buffer::NullBuffer; use arrow_data::{ArrayData, ArrayDataBuilder}; -use arrow_schema::{ArrowError, DataType}; +use arrow_schema::{ArrowError, DataType, Field}; use regex::Regex; use std::collections::HashMap; use std::sync::Arc; @@ -152,28 +153,7 @@ pub fn regexp_is_match_utf8_scalar( Ok(BooleanArray::from(data)) } -/// Extract all groups matched by a regular expression for a given String array. -/// -/// Modelled after the Postgres [regexp_match]. -/// -/// Returns a ListArray of [`GenericStringArray`] with each element containing the leftmost-first -/// match of the corresponding index in `regex_array` to string in `array` -/// -/// If there is no match, the list element is NULL. -/// -/// If a match is found, and the pattern contains no capturing parenthesized subexpressions, -/// then the list element is a single-element [`GenericStringArray`] containing the substring -/// matching the whole pattern. -/// -/// If a match is found, and the pattern contains capturing parenthesized subexpressions, then the -/// list element is a [`GenericStringArray`] whose n'th element is the substring matching -/// the n'th capturing parenthesized subexpression of the pattern. -/// -/// The flags parameter is an optional text string containing zero or more single-letter flags -/// that change the function's behavior. -/// -/// [regexp_match]: https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP -pub fn regexp_match( +fn regexp_array_match( array: &GenericStringArray, regex_array: &GenericStringArray, flags_array: Option<&GenericStringArray>, @@ -248,6 +228,179 @@ pub fn regexp_match( Ok(Arc::new(list_builder.finish())) } +fn get_scalar_pattern_flag<'a, OffsetSize: OffsetSizeTrait>( + regex_array: &'a dyn Array, + flag_array: Option<&'a dyn Array>, +) -> (Option<&'a str>, Option<&'a str>) { + let regex = regex_array.as_string::(); + let regex = regex.is_valid(0).then(|| regex.value(0)); + + if let Some(flag_array) = flag_array { + let flag = flag_array.as_string::(); + (regex, flag.is_valid(0).then(|| flag.value(0))) + } else { + (regex, None) + } +} + +fn regexp_scalar_match( + array: &GenericStringArray, + regex: &Regex, +) -> Result { + let builder: GenericStringBuilder = GenericStringBuilder::with_capacity(0, 0); + let mut list_builder = ListBuilder::new(builder); + + array + .iter() + .map(|value| { + match value { + // Required for Postgres compatibility: + // SELECT regexp_match('foobarbequebaz', ''); = {""} + Some(_) if regex.as_str() == "" => { + list_builder.values().append_value(""); + list_builder.append(true); + } + Some(value) => match regex.captures(value) { + Some(caps) => { + let mut iter = caps.iter(); + if caps.len() > 1 { + iter.next(); + } + for m in iter.flatten() { + list_builder.values().append_value(m.as_str()); + } + + list_builder.append(true); + } + None => list_builder.append(false), + }, + _ => list_builder.append(false), + } + Ok(()) + }) + .collect::, ArrowError>>()?; + + Ok(Arc::new(list_builder.finish())) +} + +/// Extract all groups matched by a regular expression for a given String array. +/// +/// Modelled after the Postgres [regexp_match]. +/// +/// Returns a ListArray of [`GenericStringArray`] with each element containing the leftmost-first +/// match of the corresponding index in `regex_array` to string in `array` +/// +/// If there is no match, the list element is NULL. +/// +/// If a match is found, and the pattern contains no capturing parenthesized subexpressions, +/// then the list element is a single-element [`GenericStringArray`] containing the substring +/// matching the whole pattern. +/// +/// If a match is found, and the pattern contains capturing parenthesized subexpressions, then the +/// list element is a [`GenericStringArray`] whose n'th element is the substring matching +/// the n'th capturing parenthesized subexpression of the pattern. +/// +/// The flags parameter is an optional text string containing zero or more single-letter flags +/// that change the function's behavior. +/// +/// [regexp_match]: https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP +pub fn regexp_match( + array: &dyn Array, + regex_array: &dyn Datum, + flags_array: Option<&dyn Datum>, +) -> Result { + let (rhs, is_rhs_scalar) = regex_array.get(); + + if array.data_type() != rhs.data_type() { + return Err(ArrowError::ComputeError( + "regexp_match() requires both array and pattern to be either Utf8 or LargeUtf8" + .to_string(), + )); + } + + let (flags, is_flags_scalar) = match flags_array { + Some(flags) => { + let (flags, is_flags_scalar) = flags.get(); + (Some(flags), Some(is_flags_scalar)) + } + None => (None, None), + }; + + if is_flags_scalar.is_some() && is_rhs_scalar != is_flags_scalar.unwrap() { + return Err(ArrowError::ComputeError( + "regexp_match() requires both pattern and flags to be either scalar or array" + .to_string(), + )); + } + + if flags_array.is_some() && rhs.data_type() != flags.unwrap().data_type() { + return Err(ArrowError::ComputeError( + "regexp_match() requires both pattern and flags to be either string or largestring" + .to_string(), + )); + } + + if is_rhs_scalar { + // Regex and flag is scalars + let (regex, flag) = match rhs.data_type() { + DataType::Utf8 => get_scalar_pattern_flag::(rhs, flags), + DataType::LargeUtf8 => get_scalar_pattern_flag::(rhs, flags), + _ => { + return Err(ArrowError::ComputeError( + "regexp_match() requires pattern to be either Utf8 or LargeUtf8".to_string(), + )); + } + }; + + if regex.is_none() { + return Ok(new_null_array( + &DataType::List(Arc::new(Field::new( + "item", + array.data_type().clone(), + true, + ))), + array.len(), + )); + } + + let regex = regex.unwrap(); + + let pattern = if let Some(flag) = flag { + format!("(?{flag}){regex}") + } else { + regex.to_string() + }; + + let re = Regex::new(pattern.as_str()).map_err(|e| { + ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}")) + })?; + + match array.data_type() { + DataType::Utf8 => regexp_scalar_match(array.as_string::(), &re), + DataType::LargeUtf8 => regexp_scalar_match(array.as_string::(), &re), + _ => Err(ArrowError::ComputeError( + "regexp_match() requires array to be either Utf8 or LargeUtf8".to_string(), + )), + } + } else { + match array.data_type() { + DataType::Utf8 => { + let regex_array = rhs.as_string(); + let flags_array = flags.map(|flags| flags.as_string()); + regexp_array_match(array.as_string::(), regex_array, flags_array) + } + DataType::LargeUtf8 => { + let regex_array = rhs.as_string(); + let flags_array = flags.map(|flags| flags.as_string()); + regexp_array_match(array.as_string::(), regex_array, flags_array) + } + _ => Err(ArrowError::ComputeError( + "regexp_match() requires array to be either Utf8 or LargeUtf8".to_string(), + )), + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -304,6 +457,49 @@ mod tests { assert_eq!(&expected, result); } + #[test] + fn match_scalar_pattern() { + let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None]; + let array = StringArray::from(values); + let pattern = Scalar::new(StringArray::from(vec![r"x.*-(\d*)-.*"; 1])); + let flags = Scalar::new(StringArray::from(vec!["i"; 1])); + let actual = regexp_match(&array, &pattern, Some(&flags)).unwrap(); + let elem_builder: GenericStringBuilder = GenericStringBuilder::with_capacity(0, 0); + let mut expected_builder = ListBuilder::new(elem_builder); + expected_builder.append(false); + expected_builder.values().append_value("7"); + expected_builder.append(true); + expected_builder.append(false); + expected_builder.append(false); + let expected = expected_builder.finish(); + let result = actual.as_any().downcast_ref::().unwrap(); + assert_eq!(&expected, result); + + // No flag + let values = vec![Some("abc-005-def"), Some("x-7-5"), Some("X545"), None]; + let array = StringArray::from(values); + let actual = regexp_match(&array, &pattern, None).unwrap(); + let result = actual.as_any().downcast_ref::().unwrap(); + assert_eq!(&expected, result); + } + + #[test] + fn match_scalar_no_pattern() { + let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None]; + let array = StringArray::from(values); + let pattern = Scalar::new(new_null_array(&DataType::Utf8, 1)); + let actual = regexp_match(&array, &pattern, None).unwrap(); + let elem_builder: GenericStringBuilder = GenericStringBuilder::with_capacity(0, 0); + let mut expected_builder = ListBuilder::new(elem_builder); + expected_builder.append(false); + expected_builder.append(false); + expected_builder.append(false); + expected_builder.append(false); + let expected = expected_builder.finish(); + let result = actual.as_any().downcast_ref::().unwrap(); + assert_eq!(&expected, result); + } + #[test] fn test_single_group_not_skip_match() { let array = StringArray::from(vec![Some("foo"), Some("bar")]); diff --git a/arrow/benches/regexp_kernels.rs b/arrow/benches/regexp_kernels.rs index eb38ba6783bc..d5ffbcb997ff 100644 --- a/arrow/benches/regexp_kernels.rs +++ b/arrow/benches/regexp_kernels.rs @@ -25,7 +25,7 @@ use arrow::array::*; use arrow::compute::kernels::regexp::*; use arrow::util::bench_util::*; -fn bench_regexp(arr: &GenericStringArray, regex_array: &GenericStringArray) { +fn bench_regexp(arr: &GenericStringArray, regex_array: &dyn Datum) { regexp_match(criterion::black_box(arr), regex_array, None).unwrap(); } @@ -38,6 +38,13 @@ fn add_benchmark(c: &mut Criterion) { let pattern = GenericStringArray::::from(pattern_values); c.bench_function("regexp", |b| b.iter(|| bench_regexp(&arr_string, &pattern))); + + let pattern_values = vec![r".*-(\d*)-.*"]; + let pattern = Scalar::new(GenericStringArray::::from(pattern_values)); + + c.bench_function("regexp scalar", |b| { + b.iter(|| bench_regexp(&arr_string, &pattern)) + }); } criterion_group!(benches, add_benchmark);