From d041a7295c79d97c72ffacf63536a5dd695539d1 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Tue, 25 Jun 2024 11:16:27 +0100 Subject: [PATCH] json_find: make implementation much shorter (#18) Co-authored-by: Samuel Colvin --- Cargo.toml | 7 +++++ benches/main.rs | 38 +++++++++++++++++++++++ src/common.rs | 82 +++++++++++++++++-------------------------------- src/lib.rs | 11 +++++++ 4 files changed, 85 insertions(+), 53 deletions(-) create mode 100644 benches/main.rs diff --git a/Cargo.toml b/Cargo.toml index 7b4d201..aa4521a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,10 @@ log = "0.4" datafusion-execution = "39" [dev-dependencies] +codspeed-criterion-compat = "2.3" +criterion = "0.5.1" datafusion = "39" +clap = "~4.4" # for testing on MSRV 1.73 tokio = { version = "1.37", features = ["full"] } [lints.clippy] @@ -33,3 +36,7 @@ print_stdout = "deny" pedantic = { level = "deny", priority = -1 } missing_errors_doc = "allow" cast_possible_truncation = "allow" + +[[bench]] +name = "main" +harness = false diff --git a/benches/main.rs b/benches/main.rs new file mode 100644 index 0000000..92ec395 --- /dev/null +++ b/benches/main.rs @@ -0,0 +1,38 @@ +use codspeed_criterion_compat::{criterion_group, criterion_main, Bencher, Criterion}; + +use datafusion_common::ScalarValue; +use datafusion_expr::ColumnarValue; +use datafusion_functions_json::udfs::{json_contains_udf, json_get_str_udf}; + +fn bench_json_contains(b: &mut Bencher) { + let json_contains = json_contains_udf(); + let args = &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some( + r#"{"a": {"aa": "x", "ab: "y"}, "b": []}"#.to_string(), + ))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("a".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))), + ]; + + b.iter(|| json_contains.invoke(args).unwrap()); +} + +fn bench_json_get_str(b: &mut Bencher) { + let json_get_str = json_get_str_udf(); + let args = &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some( + r#"{"a": {"aa": "x", "ab: "y"}, "b": []}"#.to_string(), + ))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("a".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))), + ]; + + b.iter(|| json_get_str.invoke(args).unwrap()); +} +fn criterion_benchmark(c: &mut Criterion) { + c.bench_function("json_contains", bench_json_contains); + c.bench_function("json_get_str", bench_json_get_str); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/src/common.rs b/src/common.rs index 3bf1944..a630915 100644 --- a/src/common.rs +++ b/src/common.rs @@ -158,16 +158,39 @@ fn scalar_apply_iter<'j, C: FromIterator>, I>( } pub fn jiter_json_find<'j>(opt_json: Option<&'j str>, path: &[JsonPath]) -> Option<(Jiter<'j>, Peek)> { - if let Some(json_str) = opt_json { - let mut jiter = Jiter::new(json_str.as_bytes()); - if let Ok(peek) = jiter.peek() { - if let Ok(peek_found) = jiter_json_find_step(&mut jiter, peek, path) { - return Some((jiter, peek_found)); + let json_str = opt_json?; + let mut jiter = Jiter::new(json_str.as_bytes()); + let mut peek = jiter.peek().ok()?; + for element in path { + match element { + JsonPath::Key(key) if peek == Peek::Object => { + let mut next_key = jiter.known_object().ok()??; + + while next_key != *key { + jiter.next_skip().ok()?; + next_key = jiter.next_key().ok()??; + } + + peek = jiter.peek().ok()?; + } + JsonPath::Index(index) if peek == Peek::Array => { + let mut array_item = jiter.known_array().ok()??; + + for _ in 0..*index { + jiter.known_skip(array_item).ok()?; + array_item = jiter.array_step().ok()??; + } + + peek = array_item; + } + _ => { + return None; } } } - None + Some((jiter, peek)) } + macro_rules! get_err { () => { Err(GetError) @@ -175,53 +198,6 @@ macro_rules! get_err { } pub(crate) use get_err; -fn jiter_json_find_step(jiter: &mut Jiter, peek: Peek, path: &[JsonPath]) -> Result { - let Some((first, rest)) = path.split_first() else { - return Ok(peek); - }; - let next_peek = match peek { - Peek::Array => match first { - JsonPath::Index(index) => jiter_array_get(jiter, *index), - _ => get_err!(), - }, - Peek::Object => match first { - JsonPath::Key(key) => jiter_object_get(jiter, key), - _ => get_err!(), - }, - _ => get_err!(), - }?; - jiter_json_find_step(jiter, next_peek, rest) -} - -fn jiter_array_get(jiter: &mut Jiter, find_key: usize) -> Result { - let mut peek_opt = jiter.known_array()?; - - let mut index: usize = 0; - while let Some(peek) = peek_opt { - if index == find_key { - return Ok(peek); - } - jiter.known_skip(peek)?; - index += 1; - peek_opt = jiter.array_step()?; - } - get_err!() -} - -fn jiter_object_get(jiter: &mut Jiter, find_key: &str) -> Result { - let mut opt_key = jiter.known_object()?; - - while let Some(key) = opt_key { - if key == find_key { - let value_peek = jiter.peek()?; - return Ok(value_peek); - } - jiter.next_skip()?; - opt_key = jiter.next_key()?; - } - get_err!() -} - pub struct GetError; impl From for GetError { diff --git a/src/lib.rs b/src/lib.rs index e5dd590..ebd9c8c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,6 +28,17 @@ pub mod functions { pub use crate::json_length::json_length; } +pub mod udfs { + pub use crate::json_contains::json_contains_udf; + pub use crate::json_get::json_get_udf; + pub use crate::json_get_bool::json_get_bool_udf; + pub use crate::json_get_float::json_get_float_udf; + pub use crate::json_get_int::json_get_int_udf; + pub use crate::json_get_json::json_get_json_udf; + pub use crate::json_get_str::json_get_str_udf; + pub use crate::json_length::json_length_udf; +} + /// Register all JSON UDFs pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { let functions: Vec> = vec![