From 9665e09476072c44129a959d01895ffa67eea0bd Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 26 Dec 2024 20:33:34 -0500 Subject: [PATCH] Correct return type for initcap scalar function with utf8view (#13909) * Set utf8view as return type when input type is the same * Verify that the returned type from call to scalar function matches the return type specified in the return_type function * Match return type to utf8view --- datafusion/functions/src/unicode/initcap.rs | 18 +++++++++++------- datafusion/functions/src/utils.rs | 1 + 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/datafusion/functions/src/unicode/initcap.rs b/datafusion/functions/src/unicode/initcap.rs index e9f966b95868..c21fb77c9eca 100644 --- a/datafusion/functions/src/unicode/initcap.rs +++ b/datafusion/functions/src/unicode/initcap.rs @@ -63,7 +63,11 @@ impl ScalarUDFImpl for InitcapFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_to_str_type(&arg_types[0], "initcap") + if let DataType::Utf8View = arg_types[0] { + Ok(DataType::Utf8View) + } else { + utf8_to_str_type(&arg_types[0], "initcap") + } } fn invoke_batch( @@ -188,7 +192,7 @@ mod tests { use crate::unicode::initcap::InitcapFunc; use crate::utils::test::test_function; use arrow::array::{Array, StringArray, StringViewArray}; - use arrow::datatypes::DataType::Utf8; + use arrow::datatypes::DataType::{Utf8, Utf8View}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -247,7 +251,7 @@ mod tests { )))], Ok(Some("Hi Thomas")), &str, - Utf8, + Utf8View, StringViewArray ); test_function!( @@ -257,7 +261,7 @@ mod tests { )))], Ok(Some("Hi Thomas With M0re Than 12 Chars")), &str, - Utf8, + Utf8View, StringViewArray ); test_function!( @@ -270,7 +274,7 @@ mod tests { "Đẹp Đẽ Êm Ả Ñandú Árbol Олег Иванович Íslensku Þjóðarinnar Ελληνική" )), &str, - Utf8, + Utf8View, StringViewArray ); test_function!( @@ -280,7 +284,7 @@ mod tests { )))], Ok(Some("")), &str, - Utf8, + Utf8View, StringViewArray ); test_function!( @@ -288,7 +292,7 @@ mod tests { vec![ColumnarValue::Scalar(ScalarValue::Utf8View(None))], Ok(None), &str, - Utf8, + Utf8View, StringViewArray ); diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index 53f607492266..39d8aeeda460 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -154,6 +154,7 @@ pub mod test { let result = result.unwrap().to_array(cardinality).expect("Failed to convert to array"); let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().expect("Failed to convert to type"); + assert_eq!(result.data_type(), &$EXPECTED_DATA_TYPE); // value is correct match expected {