Skip to content

Commit

Permalink
json_find: make implementation much shorter (#18)
Browse files Browse the repository at this point in the history
Co-authored-by: Samuel Colvin <s@muelcolvin.com>
  • Loading branch information
davidhewitt and samuelcolvin authored Jun 25, 2024
1 parent b1e0922 commit d041a72
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 53 deletions.
7 changes: 7 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ log = "0.4"
datafusion-execution = "39"

[dev-dependencies]
codspeed-criterion-compat = "2.3"
criterion = "0.5.1"
datafusion = "39"
clap = "~4.4" # for testing on MSRV 1.73
tokio = { version = "1.37", features = ["full"] }

[lints.clippy]
Expand All @@ -33,3 +36,7 @@ print_stdout = "deny"
pedantic = { level = "deny", priority = -1 }
missing_errors_doc = "allow"
cast_possible_truncation = "allow"

[[bench]]
name = "main"
harness = false
38 changes: 38 additions & 0 deletions benches/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use codspeed_criterion_compat::{criterion_group, criterion_main, Bencher, Criterion};

use datafusion_common::ScalarValue;
use datafusion_expr::ColumnarValue;
use datafusion_functions_json::udfs::{json_contains_udf, json_get_str_udf};

/// Benchmarks the `json_contains` UDF on a two-element path lookup (`"a"`, then `"aa"`).
///
/// All three arguments are passed as scalar `Utf8` values: the JSON document first,
/// followed by the path elements. Panics if the UDF invocation returns an error.
fn bench_json_contains(b: &mut Bencher) {
    let json_contains = json_contains_udf();
    let args = &[
        // The fixture must be well-formed JSON: the key `"ab"` previously lacked its
        // closing quote, which made everything after `"aa": "x"` unparseable.
        ColumnarValue::Scalar(ScalarValue::Utf8(Some(
            r#"{"a": {"aa": "x", "ab": "y"}, "b": []}"#.to_string(),
        ))),
        ColumnarValue::Scalar(ScalarValue::Utf8(Some("a".to_string()))),
        ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))),
    ];

    b.iter(|| json_contains.invoke(args).unwrap());
}

/// Benchmarks the `json_get_str` UDF on a two-element path lookup (`"a"`, then `"aa"`).
///
/// All three arguments are passed as scalar `Utf8` values: the JSON document first,
/// followed by the path elements. Panics if the UDF invocation returns an error.
fn bench_json_get_str(b: &mut Bencher) {
    let json_get_str = json_get_str_udf();
    let args = &[
        // The fixture must be well-formed JSON: the key `"ab"` previously lacked its
        // closing quote, which made everything after `"aa": "x"` unparseable.
        ColumnarValue::Scalar(ScalarValue::Utf8(Some(
            r#"{"a": {"aa": "x", "ab": "y"}, "b": []}"#.to_string(),
        ))),
        ColumnarValue::Scalar(ScalarValue::Utf8(Some("a".to_string()))),
        ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))),
    ];

    b.iter(|| json_get_str.invoke(args).unwrap());
}
/// Registers both JSON benchmarks with the given Criterion instance,
/// under the ids `json_contains` and `json_get_str`.
fn criterion_benchmark(c: &mut Criterion) {
    c.bench_function("json_contains", bench_json_contains)
        .bench_function("json_get_str", bench_json_get_str);
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
82 changes: 29 additions & 53 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,70 +158,46 @@ fn scalar_apply_iter<'j, C: FromIterator<Option<I>>, I>(
}

/// Locates the value addressed by `path` inside the JSON text `opt_json`.
///
/// On success returns the `Jiter` together with the `Peek` obtained for the located
/// value. With an empty `path` this is the peek of the top-level value.
///
/// Returns `None` when `opt_json` is `None`, when the parser reports an error, when the
/// parser yields no further key or array item before the target is reached, or when a
/// path element does not fit the current value (a key on a non-object, an index on a
/// non-array).
pub fn jiter_json_find<'j>(opt_json: Option<&'j str>, path: &[JsonPath]) -> Option<(Jiter<'j>, Peek)> {
if let Some(json_str) = opt_json {
let mut jiter = Jiter::new(json_str.as_bytes());
if let Ok(peek) = jiter.peek() {
if let Ok(peek_found) = jiter_json_find_step(&mut jiter, peek, path) {
return Some((jiter, peek_found));
let json_str = opt_json?;
let mut jiter = Jiter::new(json_str.as_bytes());
let mut peek = jiter.peek().ok()?;
for element in path {
match element {
JsonPath::Key(key) if peek == Peek::Object => {
// `??` unwraps both layers: a parser error and the absence of a key each yield `None`.
let mut next_key = jiter.known_object().ok()??;

// Walk the keys in order, skipping the value of every key that does not match.
while next_key != *key {
jiter.next_skip().ok()?;
next_key = jiter.next_key().ok()??;
}

peek = jiter.peek().ok()?;
}
JsonPath::Index(index) if peek == Peek::Array => {
let mut array_item = jiter.known_array().ok()??;

// Skip the first `index` items; the item reached afterwards is the target.
for _ in 0..*index {
jiter.known_skip(array_item).ok()?;
array_item = jiter.array_step().ok()??;
}

peek = array_item;
}
// The path element does not fit the kind of value currently peeked.
_ => {
return None;
}
}
}
}
None
Some((jiter, peek))
}

/// Expands to `Err(GetError)`, the failure value used by the JSON lookup helpers.
macro_rules! get_err {
() => {
Err(GetError)
};
}
pub(crate) use get_err;

/// Recursively resolves `path` against the value that `peek` describes.
///
/// An empty path yields `peek` unchanged. Otherwise the first element is resolved
/// (an index against an array, a key against an object) and the remainder of the
/// path is resolved against the value found. Any other pairing of value kind and
/// path element is a `GetError`.
fn jiter_json_find_step(jiter: &mut Jiter, peek: Peek, path: &[JsonPath]) -> Result<Peek, GetError> {
    match path {
        [] => Ok(peek),
        [head, tail @ ..] => {
            let found = match (peek, head) {
                (Peek::Array, JsonPath::Index(index)) => jiter_array_get(jiter, *index)?,
                (Peek::Object, JsonPath::Key(key)) => jiter_object_get(jiter, key)?,
                _ => return get_err!(),
            };
            jiter_json_find_step(jiter, found, tail)
        }
    }
}

fn jiter_array_get(jiter: &mut Jiter, find_key: usize) -> Result<Peek, GetError> {
let mut peek_opt = jiter.known_array()?;

let mut index: usize = 0;
while let Some(peek) = peek_opt {
if index == find_key {
return Ok(peek);
}
jiter.known_skip(peek)?;
index += 1;
peek_opt = jiter.array_step()?;
}
get_err!()
}

fn jiter_object_get(jiter: &mut Jiter, find_key: &str) -> Result<Peek, GetError> {
let mut opt_key = jiter.known_object()?;

while let Some(key) = opt_key {
if key == find_key {
let value_peek = jiter.peek()?;
return Ok(value_peek);
}
jiter.next_skip()?;
opt_key = jiter.next_key()?;
}
get_err!()
}

/// Unit error type for a failed JSON lookup; it carries no detail about the cause.
pub struct GetError;

impl From<JiterError> for GetError {
Expand Down
11 changes: 11 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@ pub mod functions {
pub use crate::json_length::json_length;
}

pub mod udfs {
pub use crate::json_contains::json_contains_udf;
pub use crate::json_get::json_get_udf;
pub use crate::json_get_bool::json_get_bool_udf;
pub use crate::json_get_float::json_get_float_udf;
pub use crate::json_get_int::json_get_int_udf;
pub use crate::json_get_json::json_get_json_udf;
pub use crate::json_get_str::json_get_str_udf;
pub use crate::json_length::json_length_udf;
}

/// Register all JSON UDFs
pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
let functions: Vec<Arc<ScalarUDF>> = vec![
Expand Down

0 comments on commit d041a72

Please sign in to comment.