Skip to content

Commit

Permalink
Unparse map to sql (#13532)
Browse files Browse the repository at this point in the history
* map to sql

* fix sqllogictest for map arg len error

* match array expr concisely
  • Loading branch information
delamarch3 authored Nov 23, 2024
1 parent c0ca4b4 commit eaf51ba
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 4 deletions.
4 changes: 2 additions & 2 deletions datafusion/functions-nested/src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,9 @@ impl ScalarUDFImpl for MapFunc {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if arg_types.len() % 2 != 0 {
if arg_types.len() != 2 {
return exec_err!(
"map requires an even number of arguments, got {} instead",
"map requires exactly 2 arguments, got {} instead",
arg_types.len()
);
}
Expand Down
39 changes: 39 additions & 0 deletions datafusion/sql/src/unparser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ impl Unparser<'_> {
"array_element" => self.array_element_to_sql(args),
"named_struct" => self.named_struct_to_sql(args),
"get_field" => self.get_field_to_sql(args),
"map" => self.map_to_sql(args),
// TODO: support for the construct and access functions of the `map` type
_ => self.scalar_function_to_sql_internal(func_name, args),
}
Expand Down Expand Up @@ -567,6 +568,39 @@ impl Unparser<'_> {
Ok(ast::Expr::CompoundIdentifier(id))
}

fn map_to_sql(&self, args: &[Expr]) -> Result<ast::Expr> {
if args.len() != 2 {
return internal_err!("map must have exactly 2 arguments");
}

let ast::Expr::Array(Array { elem: keys, .. }) = self.expr_to_sql(&args[0])?
else {
return internal_err!(
"map expects first argument to be an array, but received: {:?}",
&args[0]
);
};

let ast::Expr::Array(Array { elem: values, .. }) = self.expr_to_sql(&args[1])?
else {
return internal_err!(
"map expects second argument to be an array, but received: {:?}",
&args[1]
);
};

let entries = keys
.into_iter()
.zip(values)
.map(|(key, value)| ast::MapEntry {
key: Box::new(key),
value: Box::new(value),
})
.collect();

Ok(ast::Expr::Map(ast::Map { entries }))
}

pub fn sort_to_sql(&self, sort: &Sort) -> Result<ast::OrderByExpr> {
let Sort {
expr,
Expand Down Expand Up @@ -1581,6 +1615,7 @@ mod tests {
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::expr_fn::sum;
use datafusion_functions_nested::expr_fn::{array_element, make_array};
use datafusion_functions_nested::map::map;
use datafusion_functions_window::row_number::row_number_udwf;

use crate::unparser::dialect::{
Expand Down Expand Up @@ -1996,6 +2031,10 @@ mod tests {
"{a: '1', b: 2}",
),
(get_field(col("a.b"), "c"), "a.b.c"),
(
map(vec![lit("a"), lit("b")], vec![lit(1), lit(2)]),
"MAP {'a': 1, 'b': 2}",
),
];

for (expr, expected) in tests {
Expand Down
5 changes: 4 additions & 1 deletion datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use datafusion_expr::{col, lit, table_scan, wildcard, LogicalPlanBuilder};
use datafusion_functions::unicode;
use datafusion_functions_aggregate::grouping::grouping_udaf;
use datafusion_functions_nested::make_array::make_array_udf;
use datafusion_functions_nested::map::map_udf;
use datafusion_functions_window::rank::rank_udwf;
use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel};
use datafusion_sql::unparser::dialect::{
Expand Down Expand Up @@ -190,7 +191,8 @@ fn roundtrip_statement() -> Result<()> {
"SELECT [1, 2, 3][1]",
"SELECT left[1] FROM array",
"SELECT {a:1, b:2}",
"SELECT s.a FROM (SELECT {a:1, b:2} AS s)"
"SELECT s.a FROM (SELECT {a:1, b:2} AS s)",
"SELECT MAP {'a': 1, 'b': 2}"
];

// For each test sql string, we transform as follows:
Expand All @@ -206,6 +208,7 @@ fn roundtrip_statement() -> Result<()> {
let state = MockSessionState::default()
.with_scalar_function(make_array_udf())
.with_scalar_function(array_element_udf())
.with_scalar_function(map_udf())
.with_aggregate_function(sum_udaf())
.with_aggregate_function(count_udaf())
.with_aggregate_function(max_udaf())
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/map.slt
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ SELECT MAP([[1,2], [3,4]], ['a', 'b']);
query error
SELECT MAP()

query error DataFusion error: Execution error: map requires an even number of arguments, got 1 instead
query error DataFusion error: Execution error: map requires exactly 2 arguments, got 1 instead
SELECT MAP(['POST', 'HEAD'])

query error DataFusion error: Execution error: Expected list, large_list or fixed_size_list, got Null
Expand Down

0 comments on commit eaf51ba

Please sign in to comment.