Skip to content

Commit 9d287bd

Browse files
alamb and Omega359 authored
Update REGEXP_MATCH scalar function to support Utf8View (#14449) (#14457)
* Update REGEXP_MATCH scalar function to support Utf8View
* Cargo fmt fix.

Co-authored-by: Bruce Ritchie <[email protected]>
1 parent 755b26a commit 9d287bd

File tree

4 files changed

+144
-43
lines changed

4 files changed

+144
-43
lines changed

datafusion/functions/benches/regx.rs

+57-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
extern crate criterion;
1919

2020
use arrow::array::builder::StringBuilder;
21-
use arrow::array::{ArrayRef, AsArray, Int64Array, StringArray};
21+
use arrow::array::{ArrayRef, AsArray, Int64Array, StringArray, StringViewArray};
2222
use arrow::compute::cast;
2323
use arrow::datatypes::DataType;
2424
use criterion::{black_box, criterion_group, criterion_main, Criterion};
@@ -141,6 +141,20 @@ fn criterion_benchmark(c: &mut Criterion) {
141141
})
142142
});
143143

144+
c.bench_function("regexp_like_1000 utf8view", |b| {
145+
let mut rng = rand::thread_rng();
146+
let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap();
147+
let regex = cast(&regex(&mut rng), &DataType::Utf8View).unwrap();
148+
let flags = cast(&flags(&mut rng), &DataType::Utf8View).unwrap();
149+
150+
b.iter(|| {
151+
black_box(
152+
regexp_like(&[Arc::clone(&data), Arc::clone(&regex), Arc::clone(&flags)])
153+
.expect("regexp_like should work on valid values"),
154+
)
155+
})
156+
});
157+
144158
c.bench_function("regexp_match_1000", |b| {
145159
let mut rng = rand::thread_rng();
146160
let data = Arc::new(data(&mut rng)) as ArrayRef;
@@ -149,7 +163,25 @@ fn criterion_benchmark(c: &mut Criterion) {
149163

150164
b.iter(|| {
151165
black_box(
152-
regexp_match::<i32>(&[
166+
regexp_match(&[
167+
Arc::clone(&data),
168+
Arc::clone(&regex),
169+
Arc::clone(&flags),
170+
])
171+
.expect("regexp_match should work on valid values"),
172+
)
173+
})
174+
});
175+
176+
c.bench_function("regexp_match_1000 utf8view", |b| {
177+
let mut rng = rand::thread_rng();
178+
let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap();
179+
let regex = cast(&regex(&mut rng), &DataType::Utf8View).unwrap();
180+
let flags = cast(&flags(&mut rng), &DataType::Utf8View).unwrap();
181+
182+
b.iter(|| {
183+
black_box(
184+
regexp_match(&[
153185
Arc::clone(&data),
154186
Arc::clone(&regex),
155187
Arc::clone(&flags),
@@ -180,6 +212,29 @@ fn criterion_benchmark(c: &mut Criterion) {
180212
)
181213
})
182214
});
215+
216+
c.bench_function("regexp_replace_1000 utf8view", |b| {
217+
let mut rng = rand::thread_rng();
218+
let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap();
219+
let regex = cast(&regex(&mut rng), &DataType::Utf8View).unwrap();
220+
// flags are not allowed to be utf8view according to the function
221+
let flags = Arc::new(flags(&mut rng)) as ArrayRef;
222+
let replacement = Arc::new(StringViewArray::from_iter_values(
223+
iter::repeat("XX").take(1000),
224+
));
225+
226+
b.iter(|| {
227+
black_box(
228+
regexp_replace::<i32, _, _>(
229+
data.as_string_view(),
230+
regex.as_string_view(),
231+
&replacement,
232+
Some(&flags),
233+
)
234+
.expect("regexp_replace should work on valid values"),
235+
)
236+
})
237+
});
183238
}
184239

185240
criterion_group!(benches, criterion_benchmark);

datafusion/functions/src/regex/regexpmatch.rs

+33-33
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,14 @@
1616
// under the License.
1717

1818
//! Regex expressions
19-
use arrow::array::{Array, ArrayRef, OffsetSizeTrait};
19+
use arrow::array::{Array, ArrayRef, AsArray};
2020
use arrow::compute::kernels::regexp;
2121
use arrow::datatypes::DataType;
2222
use arrow::datatypes::Field;
2323
use datafusion_common::exec_err;
2424
use datafusion_common::ScalarValue;
2525
use datafusion_common::{arrow_datafusion_err, plan_err};
26-
use datafusion_common::{
27-
cast::as_generic_string_array, internal_err, DataFusionError, Result,
28-
};
26+
use datafusion_common::{DataFusionError, Result};
2927
use datafusion_expr::{ColumnarValue, Documentation, TypeSignature};
3028
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
3129
use datafusion_macros::user_doc;
@@ -86,11 +84,12 @@ impl RegexpMatchFunc {
8684
signature: Signature::one_of(
8785
vec![
8886
// Planner attempts coercion to the target type starting with the most preferred candidate.
89-
// For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8, Utf8)`.
90-
// If that fails, it proceeds to `(LargeUtf8, Utf8)`.
91-
// TODO: Native support Utf8View for regexp_match.
87+
// For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8View, Utf8View)`.
88+
// If that fails, it proceeds to `(Utf8, Utf8)`.
89+
TypeSignature::Exact(vec![Utf8View, Utf8View]),
9290
TypeSignature::Exact(vec![Utf8, Utf8]),
9391
TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]),
92+
TypeSignature::Exact(vec![Utf8View, Utf8View, Utf8View]),
9493
TypeSignature::Exact(vec![Utf8, Utf8, Utf8]),
9594
TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]),
9695
],
@@ -138,7 +137,7 @@ impl ScalarUDFImpl for RegexpMatchFunc {
138137
.map(|arg| arg.to_array(inferred_length))
139138
.collect::<Result<Vec<_>>>()?;
140139

141-
let result = regexp_match_func(&args);
140+
let result = regexp_match(&args);
142141
if is_scalar {
143142
// If all inputs are scalar, keeps output as scalar
144143
let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
@@ -153,33 +152,35 @@ impl ScalarUDFImpl for RegexpMatchFunc {
153152
}
154153
}
155154

156-
fn regexp_match_func(args: &[ArrayRef]) -> Result<ArrayRef> {
157-
match args[0].data_type() {
158-
DataType::Utf8 => regexp_match::<i32>(args),
159-
DataType::LargeUtf8 => regexp_match::<i64>(args),
160-
other => {
161-
internal_err!("Unsupported data type {other:?} for function regexp_match")
162-
}
163-
}
164-
}
165-
pub fn regexp_match<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
155+
pub fn regexp_match(args: &[ArrayRef]) -> Result<ArrayRef> {
166156
match args.len() {
167157
2 => {
168-
let values = as_generic_string_array::<T>(&args[0])?;
169-
let regex = as_generic_string_array::<T>(&args[1])?;
170-
regexp::regexp_match(values, regex, None)
158+
regexp::regexp_match(&args[0], &args[1], None)
171159
.map_err(|e| arrow_datafusion_err!(e))
172160
}
173161
3 => {
174-
let values = as_generic_string_array::<T>(&args[0])?;
175-
let regex = as_generic_string_array::<T>(&args[1])?;
176-
let flags = as_generic_string_array::<T>(&args[2])?;
177-
178-
if flags.iter().any(|s| s == Some("g")) {
179-
return plan_err!("regexp_match() does not support the \"global\" option");
162+
match args[2].data_type() {
163+
DataType::Utf8View => {
164+
if args[2].as_string_view().iter().any(|s| s == Some("g")) {
165+
return plan_err!("regexp_match() does not support the \"global\" option");
166+
}
167+
}
168+
DataType::Utf8 => {
169+
if args[2].as_string::<i32>().iter().any(|s| s == Some("g")) {
170+
return plan_err!("regexp_match() does not support the \"global\" option");
171+
}
172+
}
173+
DataType::LargeUtf8 => {
174+
if args[2].as_string::<i64>().iter().any(|s| s == Some("g")) {
175+
return plan_err!("regexp_match() does not support the \"global\" option");
176+
}
177+
}
178+
e => {
179+
return plan_err!("regexp_match was called with unexpected data type {e:?}");
180+
}
180181
}
181182

182-
regexp::regexp_match(values, regex, Some(flags))
183+
regexp::regexp_match(&args[0], &args[1], Some(&args[2]))
183184
.map_err(|e| arrow_datafusion_err!(e))
184185
}
185186
other => exec_err!(
@@ -211,7 +212,7 @@ mod tests {
211212
expected_builder.append(false);
212213
let expected = expected_builder.finish();
213214

214-
let re = regexp_match::<i32>(&[Arc::new(values), Arc::new(patterns)]).unwrap();
215+
let re = regexp_match(&[Arc::new(values), Arc::new(patterns)]).unwrap();
215216

216217
assert_eq!(re.as_ref(), &expected);
217218
}
@@ -236,9 +237,8 @@ mod tests {
236237
expected_builder.append(false);
237238
let expected = expected_builder.finish();
238239

239-
let re =
240-
regexp_match::<i32>(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
241-
.unwrap();
240+
let re = regexp_match(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
241+
.unwrap();
242242

243243
assert_eq!(re.as_ref(), &expected);
244244
}
@@ -250,7 +250,7 @@ mod tests {
250250
let flags = StringArray::from(vec!["g"]);
251251

252252
let re_err =
253-
regexp_match::<i32>(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
253+
regexp_match(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
254254
.expect_err("unsupported flag should have failed");
255255

256256
assert_eq!(re_err.strip_backtrace(), "Error during planning: regexp_match() does not support the \"global\" option");

datafusion/sqllogictest/test_files/regexp.slt

+53-7
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,29 @@ NULL
193193
[Köln]
194194
[إسرائيل]
195195

196+
# test string view
197+
statement ok
198+
CREATE TABLE t_stringview AS
199+
SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(flags, 'Utf8View') as flags FROM t;
200+
201+
query ?
202+
SELECT regexp_match(str, pattern, flags) FROM t_stringview;
203+
----
204+
[a]
205+
[A]
206+
[B]
207+
NULL
208+
NULL
209+
NULL
210+
[010]
211+
[Düsseldorf]
212+
[Москва]
213+
[Köln]
214+
[إسرائيل]
215+
216+
statement ok
217+
DROP TABLE t_stringview;
218+
196219
query ?
197220
SELECT regexp_match('foobarbequebaz', '');
198221
----
@@ -354,6 +377,29 @@ X
354377
X
355378
X
356379

380+
# test string view
381+
statement ok
382+
CREATE TABLE t_stringview AS
383+
SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(flags, 'Utf8View') as flags FROM t;
384+
385+
query T
386+
SELECT regexp_replace(str, pattern, 'X', concat('g', flags)) FROM t_stringview;
387+
----
388+
Xbc
389+
X
390+
aXc
391+
AbC
392+
aBC
393+
4000
394+
X
395+
X
396+
X
397+
X
398+
X
399+
400+
statement ok
401+
DROP TABLE t_stringview;
402+
357403
query T
358404
SELECT regexp_replace('ABCabcABC', '(abc)', 'X', 'gi');
359405
----
@@ -621,7 +667,7 @@ CREATE TABLE t_stringview AS
621667
SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(start, 'Int64') as start, arrow_cast(flags, 'Utf8View') as flags FROM t;
622668

623669
query I
624-
SELECT regexp_count(str, '\w') from t;
670+
SELECT regexp_count(str, '\w') from t_stringview;
625671
----
626672
3
627673
3
@@ -636,7 +682,7 @@ SELECT regexp_count(str, '\w') from t;
636682
7
637683

638684
query I
639-
SELECT regexp_count(str, '\w{2}', start) from t;
685+
SELECT regexp_count(str, '\w{2}', start) from t_stringview;
640686
----
641687
1
642688
1
@@ -651,7 +697,7 @@ SELECT regexp_count(str, '\w{2}', start) from t;
651697
3
652698

653699
query I
654-
SELECT regexp_count(str, 'ab', 1, 'i') from t;
700+
SELECT regexp_count(str, 'ab', 1, 'i') from t_stringview;
655701
----
656702
1
657703
1
@@ -667,7 +713,7 @@ SELECT regexp_count(str, 'ab', 1, 'i') from t;
667713

668714

669715
query I
670-
SELECT regexp_count(str, pattern) from t;
716+
SELECT regexp_count(str, pattern) from t_stringview;
671717
----
672718
1
673719
1
@@ -682,7 +728,7 @@ SELECT regexp_count(str, pattern) from t;
682728
1
683729

684730
query I
685-
SELECT regexp_count(str, pattern, start) from t;
731+
SELECT regexp_count(str, pattern, start) from t_stringview;
686732
----
687733
1
688734
1
@@ -697,7 +743,7 @@ SELECT regexp_count(str, pattern, start) from t;
697743
1
698744

699745
query I
700-
SELECT regexp_count(str, pattern, start, flags) from t;
746+
SELECT regexp_count(str, pattern, start, flags) from t_stringview;
701747
----
702748
1
703749
1
@@ -713,7 +759,7 @@ SELECT regexp_count(str, pattern, start, flags) from t;
713759

714760
# test type coercion
715761
query I
716-
SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t;
762+
SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t_stringview;
717763
----
718764
1
719765
1

datafusion/sqllogictest/test_files/string/string_view.slt

+1-1
Original file line numberDiff line numberDiff line change
@@ -794,7 +794,7 @@ EXPLAIN SELECT
794794
FROM test;
795795
----
796796
logical_plan
797-
01)Projection: regexp_match(CAST(test.column1_utf8view AS Utf8), Utf8("^https?://(?:www\.)?([^/]+)/.*$")) AS k
797+
01)Projection: regexp_match(test.column1_utf8view, Utf8View("^https?://(?:www\.)?([^/]+)/.*$")) AS k
798798
02)--TableScan: test projection=[column1_utf8view]
799799

800800
## Ensure no casts for REGEXP_REPLACE

0 commit comments

Comments
 (0)