diff --git a/src/substrait/sql/extended_expression.py b/src/substrait/sql/extended_expression.py index 89965f0..57b8eb3 100644 --- a/src/substrait/sql/extended_expression.py +++ b/src/substrait/sql/extended_expression.py @@ -93,10 +93,16 @@ def expression_from_sqlglot(self, sqlglot_node): return self._parse_expression(sqlglot_node) def _parse_expression(self, expr): + """Parse a SQLGlot node and return a Substrait expression. + + This is the internal implementation, expected to be + invoked in a recursive manner to parse the whole + expression tree. + """ if isinstance(expr, sqlglot.expressions.Literal): if expr.is_string: return ParsedSubstraitExpression( - f"literal_{next(self._counter)}", + f"literal${next(self._counter)}", proto.Type(string=proto.Type.String()), proto.Expression( literal=proto.Expression.Literal(string=expr.text) @@ -104,7 +110,7 @@ def _parse_expression(self, expr): ) elif expr.is_int: return ParsedSubstraitExpression( - f"literal_{next(self._counter)}", + f"literal${next(self._counter)}", proto.Type(i32=proto.Type.I32()), proto.Expression( literal=proto.Expression.Literal(i32=int(expr.name)) @@ -112,7 +118,7 @@ def _parse_expression(self, expr): ) elif sqlglot.helper.is_float(expr.name): return ParsedSubstraitExpression( - f"literal_{next(self._counter)}", + f"literal${next(self._counter)}", proto.Type(fp32=proto.Type.FP32()), proto.Expression( literal=proto.Expression.Literal(float=float(expr.name)) @@ -144,11 +150,7 @@ def _parse_expression(self, expr): argument_parsed_expr = self._parse_expression(expr.this) function_name = SQL_UNARY_FUNCTIONS[expr.key] signature, result_type, function_expression = ( - self._parse_function_invokation( - function_name, - argument_parsed_expr.type, - argument_parsed_expr.expression, - ) + self._parse_function_invokation(function_name, argument_parsed_expr) ) result_name = f"{function_name}_{argument_parsed_expr.output_name}_{next(self._counter)}" return ParsedSubstraitExpression( @@ -163,11 +165,7 @@ def _parse_expression(self, expr): function_name = SQL_BINARY_FUNCTIONS[expr.key] signature, result_type, function_expression = ( self._parse_function_invokation( - function_name, - left_parsed_expr.type, - left_parsed_expr.expression, - right_parsed_expr.type, - right_parsed_expr.expression, + function_name, left_parsed_expr, right_parsed_expr ) ) result_name = f"{function_name}_{left_parsed_expr.output_name}_{right_parsed_expr.output_name}_{next(self._counter)}" @@ -185,24 +183,27 @@ def _parse_expression(self, expr): ) def _parse_function_invokation( - self, function_name, left_type, left, right_type=None, right=None + self, function_name, argument_parsed_expr, *additional_arguments ): - binary = False - argtypes = [left_type] - if right_type or right: - binary = True - argtypes.append(right_type) - signature = self._functions_catalog.signature(function_name, argtypes) + """Generates a Substrait function invokation expression. + + The function invocation will be generated from the function name + and the arguments as ParsedSubstraitExpression. + + Returns the function signature, the return type and the + invokation expression itself. + """ + arguments = [argument_parsed_expr] + list(additional_arguments) + signature = self._functions_catalog.signature( + function_name, proto_argtypes=[arg.type for arg in arguments] + ) try: function_anchor = self._functions_catalog.function_anchor(signature) except KeyError: # No function found with the exact types, try any1_any1 version # TODO: What about cases like i32_any1? What about any instead of any1? - if binary: - signature = f"{function_name}:any1_any1" - else: - signature = f"{function_name}:any1" + signature = f"{function_name}:{'_'.join(['any1']*len(arguments))}" function_anchor = self._functions_catalog.function_anchor(signature) function_return_type = self._functions_catalog.function_return_type(signature) @@ -216,20 +217,25 @@ def _parse_function_invokation( proto.Expression( scalar_function=proto.Expression.ScalarFunction( function_reference=function_anchor, - arguments=( - [ - proto.FunctionArgument(value=left), - proto.FunctionArgument(value=right), - ] - if binary - else [proto.FunctionArgument(value=left)] - ), + arguments=[ + proto.FunctionArgument(value=arg.expression) + for arg in arguments + ], ) ), ) class ParsedSubstraitExpression: + """A Substrait expression that was parsed from a SQLGlot node. + + This stores the expression itself, with an associated output name + in case it is required to emit projections. + + It also stores the type of the expression (i64, string, boolean, etc...) + and the functions that the expression in going to invoke. + """ + def __init__(self, output_name, type, expression, invoked_functions=None): self.expression = expression self.output_name = output_name