diff --git a/datafusion/functions-nested/src/map.rs b/datafusion/functions-nested/src/map.rs index 73aad10a8e26..1211945a8b9d 100644 --- a/datafusion/functions-nested/src/map.rs +++ b/datafusion/functions-nested/src/map.rs @@ -214,9 +214,9 @@ impl ScalarUDFImpl for MapFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types.len() % 2 != 0 { + if arg_types.len() != 2 { return exec_err!( - "map requires an even number of arguments, got {} instead", + "map requires exactly 2 arguments, got {} instead", arg_types.len() ); } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index ae2607de00a2..6660c425b4ee 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -464,6 +464,7 @@ impl Unparser<'_> { "array_element" => self.array_element_to_sql(args), "named_struct" => self.named_struct_to_sql(args), "get_field" => self.get_field_to_sql(args), + "map" => self.map_to_sql(args), // TODO: support for the construct and access functions of the `map` type _ => self.scalar_function_to_sql_internal(func_name, args), } @@ -567,6 +568,39 @@ impl Unparser<'_> { Ok(ast::Expr::CompoundIdentifier(id)) } + fn map_to_sql(&self, args: &[Expr]) -> Result { + if args.len() != 2 { + return internal_err!("map must have exactly 2 arguments"); + } + + let ast::Expr::Array(Array { elem: keys, .. }) = self.expr_to_sql(&args[0])? + else { + return internal_err!( + "map expects first argument to be an array, but received: {:?}", + &args[0] + ); + }; + + let ast::Expr::Array(Array { elem: values, .. }) = self.expr_to_sql(&args[1])? + else { + return internal_err!( + "map expects second argument to be an array, but received: {:?}", + &args[1] + ); + }; + + let entries = keys + .into_iter() + .zip(values) + .map(|(key, value)| ast::MapEntry { + key: Box::new(key), + value: Box::new(value), + }) + .collect(); + + Ok(ast::Expr::Map(ast::Map { entries })) + } + pub fn sort_to_sql(&self, sort: &Sort) -> Result { let Sort { expr, @@ -1581,6 +1615,7 @@ mod tests { use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::expr_fn::sum; use datafusion_functions_nested::expr_fn::{array_element, make_array}; + use datafusion_functions_nested::map::map; use datafusion_functions_window::row_number::row_number_udwf; use crate::unparser::dialect::{ @@ -1996,6 +2031,10 @@ mod tests { "{a: '1', b: 2}", ), (get_field(col("a.b"), "c"), "a.b.c"), + ( + map(vec![lit("a"), lit("b")], vec![lit(1), lit(2)]), + "MAP {'a': 1, 'b': 2}", + ), ]; for (expr, expected) in tests { diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 58d99549de31..eee79399701c 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -27,6 +27,7 @@ use datafusion_expr::{col, lit, table_scan, wildcard, LogicalPlanBuilder}; use datafusion_functions::unicode; use datafusion_functions_aggregate::grouping::grouping_udaf; use datafusion_functions_nested::make_array::make_array_udf; +use datafusion_functions_nested::map::map_udf; use datafusion_functions_window::rank::rank_udwf; use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_sql::unparser::dialect::{ @@ -190,7 +191,8 @@ fn roundtrip_statement() -> Result<()> { "SELECT [1, 2, 3][1]", "SELECT left[1] FROM array", "SELECT {a:1, b:2}", - "SELECT s.a FROM (SELECT {a:1, b:2} AS s)" + "SELECT s.a FROM (SELECT {a:1, b:2} AS s)", + "SELECT MAP {'a': 1, 'b': 2}" ]; // For each test sql string, we transform as follows: @@ -206,6 +208,7 @@ fn roundtrip_statement() -> Result<()> { let state = MockSessionState::default() .with_scalar_function(make_array_udf()) .with_scalar_function(array_element_udf()) + .with_scalar_function(map_udf()) .with_aggregate_function(sum_udaf()) .with_aggregate_function(count_udaf()) .with_aggregate_function(max_udaf()) diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index 10ca3ae881bf..28fc2f4b0b80 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -185,7 +185,7 @@ SELECT MAP([[1,2], [3,4]], ['a', 'b']); query error SELECT MAP() -query error DataFusion error: Execution error: map requires an even number of arguments, got 1 instead +query error DataFusion error: Execution error: map requires exactly 2 arguments, got 1 instead SELECT MAP(['POST', 'HEAD']) query error DataFusion error: Execution error: Expected list, large_list or fixed_size_list, got Null