diff --git a/src/substrait/sql/extended_expression.py b/src/substrait/sql/extended_expression.py index 06d621a..44d62ae 100644 --- a/src/substrait/sql/extended_expression.py +++ b/src/substrait/sql/extended_expression.py @@ -6,29 +6,29 @@ from .utils import DispatchRegistry -SQL_UNARY_FUNCTIONS = {"not": "not"} -SQL_BINARY_FUNCTIONS = { +SQL_FUNCTIONS = { # Arithmetic - "add": "add", - "div": "div", - "mul": "mul", - "sub": "sub", - "mod": "modulus", - "bitwiseand": "bitwise_and", - "bitwiseor": "bitwise_or", - "bitwisexor": "bitwise_xor", - "bitwiseor": "bitwise_or", + sqlglot.expressions.Add: "add", + sqlglot.expressions.Div: "div", + sqlglot.expressions.Mul: "mul", + sqlglot.expressions.Sub: "sub", + sqlglot.expressions.Mod: "modulus", + sqlglot.expressions.BitwiseAnd: "bitwise_and", + sqlglot.expressions.BitwiseOr: "bitwise_or", + sqlglot.expressions.BitwiseXor: "bitwise_xor", + sqlglot.expressions.BitwiseNot: "bitwise_not", # Comparisons - "eq": "equal", - "nullsafeeq": "is_not_distinct_from", - "neq": "not_equal", - "gt": "gt", - "gte": "gte", - "lt": "lt", - "lte": "lte", + sqlglot.expressions.EQ: "equal", + sqlglot.expressions.NullSafeEQ: "is_not_distinct_from", + sqlglot.expressions.NEQ: "not_equal", + sqlglot.expressions.GT: "gt", + sqlglot.expressions.GTE: "gte", + sqlglot.expressions.LT: "lt", + sqlglot.expressions.LTE: "lte", # logical - "and": "and", - "or": "or", + sqlglot.expressions.And: "and", + sqlglot.expressions.Or: "or", + sqlglot.expressions.Not: "not", } @@ -151,7 +151,7 @@ def _parse_Alias(self, expr): def _parser_Binary(self, expr): left_parsed_expr = self._parse_expression(expr.left) right_parsed_expr = self._parse_expression(expr.right) - function_name = SQL_BINARY_FUNCTIONS[expr.key] + function_name = SQL_FUNCTIONS[type(expr)] signature, result_type, function_expression = self._parse_function_invokation( function_name, left_parsed_expr, right_parsed_expr ) @@ -168,7 +168,7 @@ def _parser_Binary(self, expr): @DISPATCH_REGISTRY.register(sqlglot.expressions.Unary) def _parse_Unary(self, expr): argument_parsed_expr = self._parse_expression(expr.this) - function_name = SQL_UNARY_FUNCTIONS[expr.key] + function_name = SQL_FUNCTIONS[type(expr)] signature, result_type, function_expression = self._parse_function_invokation( function_name, argument_parsed_expr )