From 62743e3c928674dea6399b6d82612d84ea4f2b37 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 5 May 2024 10:04:10 +0100 Subject: [PATCH] Fix `json_length` with only one array argument (#8) --- src/common.rs | 13 +++++++++---- src/json_length.rs | 4 ++-- tests/main.rs | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 6 deletions(-) diff --git a/src/common.rs b/src/common.rs index 27d41e0..15fb3f1 100644 --- a/src/common.rs +++ b/src/common.rs @@ -64,10 +64,14 @@ pub fn invoke> + 'static, I>( to_array: impl Fn(C) -> DataFusionResult, to_scalar: impl Fn(Option) -> ScalarValue, ) -> DataFusionResult { - match &args[0] { + let Some(first_arg) = args.first() else { + // I think this can't happen, but I assumed the same about args[1] and I was wrong, so better to be safe + return exec_err!("expected at least one argument"); + }; + match first_arg { ColumnarValue::Array(json_array) => { - let result_collect = match &args[1] { - ColumnarValue::Array(a) => { + let result_collect = match args.get(1) { + Some(ColumnarValue::Array(a)) => { if let Some(str_path_array) = a.as_any().downcast_ref::() { let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key)); zip_apply(json_array, paths, jiter_find) @@ -84,7 +88,8 @@ pub fn invoke> + 'static, I>( return exec_err!("unexpected second argument type, expected string or int array"); } } - ColumnarValue::Scalar(_) => scalar_apply(json_array, &JsonPath::extract_path(args), jiter_find), + Some(ColumnarValue::Scalar(_)) => scalar_apply(json_array, &JsonPath::extract_path(args), jiter_find), + None => scalar_apply(json_array, &[], jiter_find), }; to_array(result_collect?).map(ColumnarValue::from) } diff --git a/src/json_length.rs b/src/json_length.rs index 2107979..b1bb900 100644 --- a/src/json_length.rs +++ b/src/json_length.rs @@ -20,14 +20,14 @@ make_udf_function!( #[derive(Debug)] pub(super) struct JsonLength { signature: Signature, - aliases: [String; 1], + aliases: [String; 2], } impl Default for JsonLength { fn default() -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), - aliases: ["json_length".to_string()], + aliases: ["json_length".to_string(), "json_len".to_string()], } } } diff --git a/tests/main.rs b/tests/main.rs index 263c2ef..7974ec9 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -429,3 +429,35 @@ async fn test_json_contains_large_both_params() { let batches = run_query_params(sql, true, params).await.unwrap(); assert_batches_eq!(expected, &batches); } + +#[tokio::test] +async fn test_json_length_vec() { + let sql = r#"select name, json_len(json_data) as len from test"#; + let batches = run_query(sql).await.unwrap(); + + let expected = [ + "+------------------+-----+", + "| name | len |", + "+------------------+-----+", + "| object_foo | 1 |", + "| object_foo_array | 1 |", + "| object_foo_obj | 1 |", + "| object_foo_null | 1 |", + "| object_bar | 1 |", + "| list_foo | 1 |", + "| invalid_json | |", + "+------------------+-----+", + ]; + assert_batches_eq!(expected, &batches); + + let batches = run_query_large(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_no_args() { + let err = run_query(r#"select json_len()"#).await.unwrap_err(); + assert!(err + .to_string() + .contains("No function matches the given name and argument types 'json_length()'.")); +}