Skip to content

Commit

Permalink
Fix json_length with only one array argument (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored May 5, 2024
1 parent 0a4493c commit 62743e3
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 6 deletions.
13 changes: 9 additions & 4 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,14 @@ pub fn invoke<C: FromIterator<Option<I>> + 'static, I>(
to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
to_scalar: impl Fn(Option<I>) -> ScalarValue,
) -> DataFusionResult<ColumnarValue> {
match &args[0] {
let Some(first_arg) = args.first() else {
// I think this can't happen, but I assumed the same about args[1] and I was wrong, so better to be safe
return exec_err!("expected at least one argument");
};
match first_arg {
ColumnarValue::Array(json_array) => {
let result_collect = match &args[1] {
ColumnarValue::Array(a) => {
let result_collect = match args.get(1) {
Some(ColumnarValue::Array(a)) => {
if let Some(str_path_array) = a.as_any().downcast_ref::<StringArray>() {
let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key));
zip_apply(json_array, paths, jiter_find)
Expand All @@ -84,7 +88,8 @@ pub fn invoke<C: FromIterator<Option<I>> + 'static, I>(
return exec_err!("unexpected second argument type, expected string or int array");
}
}
ColumnarValue::Scalar(_) => scalar_apply(json_array, &JsonPath::extract_path(args), jiter_find),
Some(ColumnarValue::Scalar(_)) => scalar_apply(json_array, &JsonPath::extract_path(args), jiter_find),
None => scalar_apply(json_array, &[], jiter_find),
};
to_array(result_collect?).map(ColumnarValue::from)
}
Expand Down
4 changes: 2 additions & 2 deletions src/json_length.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ make_udf_function!(
#[derive(Debug)]
pub(super) struct JsonLength {
signature: Signature,
aliases: [String; 1],
aliases: [String; 2],
}

impl Default for JsonLength {
fn default() -> Self {
Self {
signature: Signature::variadic_any(Volatility::Immutable),
aliases: ["json_length".to_string()],
aliases: ["json_length".to_string(), "json_len".to_string()],
}
}
}
Expand Down
32 changes: 32 additions & 0 deletions tests/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -429,3 +429,35 @@ async fn test_json_contains_large_both_params() {
let batches = run_query_params(sql, true, params).await.unwrap();
assert_batches_eq!(expected, &batches);
}

#[tokio::test]
async fn test_json_length_vec() {
let sql = r#"select name, json_len(json_data) as len from test"#;
let batches = run_query(sql).await.unwrap();

let expected = [
"+------------------+-----+",
"| name | len |",
"+------------------+-----+",
"| object_foo | 1 |",
"| object_foo_array | 1 |",
"| object_foo_obj | 1 |",
"| object_foo_null | 1 |",
"| object_bar | 1 |",
"| list_foo | 1 |",
"| invalid_json | |",
"+------------------+-----+",
];
assert_batches_eq!(expected, &batches);

let batches = run_query_large(sql).await.unwrap();
assert_batches_eq!(expected, &batches);
}

#[tokio::test]
async fn test_no_args() {
let err = run_query(r#"select json_len()"#).await.unwrap_err();
assert!(err
.to_string()
.contains("No function matches the given name and argument types 'json_length()'."));
}

0 comments on commit 62743e3

Please sign in to comment.