diff --git a/src/substrait/sql/extended_expression.py b/src/substrait/sql/extended_expression.py index 7d04eeb..89965f0 100644 --- a/src/substrait/sql/extended_expression.py +++ b/src/substrait/sql/extended_expression.py @@ -31,6 +31,10 @@ def parse_sql_extended_expression(catalog, schema, sql): + """Parse a SQL SELECT statement into an ExtendedExpression. + + Only supports SELECT statements with projections and WHERE clauses. + """ select = sqlglot.parse_one(sql) if not isinstance(select, sqlglot.expressions.Select): raise ValueError("a SELECT statement was expected") @@ -41,12 +45,13 @@ def parse_sql_extended_expression(catalog, schema, sql): project_expressions = [] projection_invoked_functions = set() for sqlexpr in select.expressions: - invoked_functions, output_name, expr = sqlglot_parser.expression_from_sqlglot( - sqlexpr - ) - projection_invoked_functions.update(invoked_functions) + parsed_expr = sqlglot_parser.expression_from_sqlglot(sqlexpr) + projection_invoked_functions.update(parsed_expr.invoked_functions) project_expressions.append( - proto.ExpressionReference(expression=expr, output_names=[output_name]) + proto.ExpressionReference( + expression=parsed_expr.expression, + output_names=[parsed_expr.output_name], + ) ) extension_uris, extensions = catalog.extensions_for_functions( projection_invoked_functions @@ -59,17 +64,19 @@ def parse_sql_extended_expression(catalog, schema, sql): ) # Handle WHERE clause in the SELECT statement. - invoked_functions_filter, _, filter_expr = sqlglot_parser.expression_from_sqlglot( + filter_parsed_expr = sqlglot_parser.expression_from_sqlglot( select.find(sqlglot.expressions.Where).this ) extension_uris, extensions = catalog.extensions_for_functions( - invoked_functions_filter + filter_parsed_expr.invoked_functions ) filter_extended_expr = proto.ExtendedExpression( extension_uris=extension_uris, extensions=extensions, base_schema=schema, - referred_expr=[proto.ExpressionReference(expression=filter_expr)], + referred_expr=[ + proto.ExpressionReference(expression=filter_parsed_expr.expression) + ], ) return projection_extended_expr, filter_extended_expr @@ -82,16 +89,13 @@ def __init__(self, functions_catalog, schema): self._counter = itertools.count() def expression_from_sqlglot(self, sqlglot_node): - invoked_functions = set() - output_name, _, substrait_expr = self._parse_expression( - sqlglot_node, invoked_functions - ) - return invoked_functions, output_name, substrait_expr + """Parse a SQLGlot expression into a Substrait Expression.""" + return self._parse_expression(sqlglot_node) - def _parse_expression(self, expr, invoked_functions): + def _parse_expression(self, expr): if isinstance(expr, sqlglot.expressions.Literal): if expr.is_string: - return ( + return ParsedSubstraitExpression( f"literal_{next(self._counter)}", proto.Type(string=proto.Type.String()), proto.Expression( @@ -99,7 +103,7 @@ def _parse_expression(self, expr, invoked_functions): ), ) elif expr.is_int: - return ( + return ParsedSubstraitExpression( f"literal_{next(self._counter)}", proto.Type(i32=proto.Type.I32()), proto.Expression( @@ -107,7 +111,7 @@ def _parse_expression(self, expr, invoked_functions): ), ) elif sqlglot.helper.is_float(expr.name): - return ( + return ParsedSubstraitExpression( f"literal_{next(self._counter)}", proto.Type(fp32=proto.Type.FP32()), proto.Expression( @@ -120,7 +124,7 @@ def _parse_expression(self, expr, invoked_functions): column_name = expr.output_name schema_field = list(self._schema.names).index(column_name) schema_type = self._schema.struct.types[schema_field] - return ( + return ParsedSubstraitExpression( column_name, schema_type, proto.Expression( @@ -134,39 +138,47 @@ def _parse_expression(self, expr, invoked_functions): ), ) elif isinstance(expr, sqlglot.expressions.Alias): - _, aliased_type, aliased_expr = self._parse_expression( - expr.this, invoked_functions - ) - return expr.output_name, aliased_type, aliased_expr + parsed_expression = self._parse_expression(expr.this) + return parsed_expression.duplicate(output_name=expr.output_name) elif expr.key in SQL_UNARY_FUNCTIONS: - argument_name, argument_type, argument = self._parse_expression( - expr.this, invoked_functions - ) + argument_parsed_expr = self._parse_expression(expr.this) function_name = SQL_UNARY_FUNCTIONS[expr.key] signature, result_type, function_expression = ( - self._parse_function_invokation(function_name, argument_type, argument) - ) - invoked_functions.add(signature) - result_name = f"{function_name}_{argument_name}_{next(self._counter)}" - return result_name, result_type, function_expression - elif expr.key in SQL_BINARY_FUNCTIONS: - left_name, left_type, left = self._parse_expression( - expr.left, invoked_functions + self._parse_function_invokation( + function_name, + argument_parsed_expr.type, + argument_parsed_expr.expression, + ) ) - right_name, right_type, right = self._parse_expression( - expr.right, invoked_functions + result_name = f"{function_name}_{argument_parsed_expr.output_name}_{next(self._counter)}" + return ParsedSubstraitExpression( + result_name, + result_type, + function_expression, + argument_parsed_expr.invoked_functions | {signature}, ) + elif expr.key in SQL_BINARY_FUNCTIONS: + left_parsed_expr = self._parse_expression(expr.left) + right_parsed_expr = self._parse_expression(expr.right) function_name = SQL_BINARY_FUNCTIONS[expr.key] signature, result_type, function_expression = ( self._parse_function_invokation( - function_name, left_type, left, right_type, right + function_name, + left_parsed_expr.type, + left_parsed_expr.expression, + right_parsed_expr.type, + right_parsed_expr.expression, ) ) - invoked_functions.add(signature) - result_name = ( - f"{left_name}_{function_name}_{right_name}_{next(self._counter)}" + result_name = f"{function_name}_{left_parsed_expr.output_name}_{right_parsed_expr.output_name}_{next(self._counter)}" + return ParsedSubstraitExpression( + result_name, + result_type, + function_expression, + left_parsed_expr.invoked_functions + | right_parsed_expr.invoked_functions + | {signature}, ) - return result_name, result_type, function_expression else: raise ValueError( f"Unsupported expression in ExtendedExpression: '{expr.key}' -> {expr}" @@ -215,3 +227,24 @@ def _parse_function_invokation( ) ), ) + + +class ParsedSubstraitExpression: + def __init__(self, output_name, type, expression, invoked_functions=None): + self.expression = expression + self.output_name = output_name + self.type = type + + if invoked_functions is None: + invoked_functions = set() + self.invoked_functions = invoked_functions + + def duplicate( + self, output_name=None, type=None, expression=None, invoked_functions=None + ): + return ParsedSubstraitExpression( + output_name or self.output_name, + type or self.type, + expression or self.expression, + invoked_functions or self.invoked_functions, + ) diff --git a/src/substrait/sql/functions_catalog.py b/src/substrait/sql/functions_catalog.py index 5f6b0bb..8d72871 100644 --- a/src/substrait/sql/functions_catalog.py +++ b/src/substrait/sql/functions_catalog.py @@ -52,7 +52,10 @@ def load(self, dirpath, filename): for impl in function.get("impls", []): # TODO: There seem to be some functions that have arguments without type. What to do? # TODO: improve support complext type like LIST? - argtypes = [t.get("value", "unknown").strip("?") for t in impl.get("args", [])] + argtypes = [ + t.get("value", "unknown").strip("?") + for t in impl.get("args", []) + ] if not argtypes: signature = function_name else: