Skip to content

Commit

Permalink
Migrate function resolution from keys to classes
Browse files Browse the repository at this point in the history
  • Loading branch information
amol- committed May 8, 2024
1 parent 3c9f6eb commit e9508a4
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions src/substrait/sql/extended_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,29 @@
from .utils import DispatchRegistry


SQL_UNARY_FUNCTIONS = {"not": "not"}
SQL_BINARY_FUNCTIONS = {
SQL_FUNCTIONS = {
# Arithmetic
"add": "add",
"div": "div",
"mul": "mul",
"sub": "sub",
"mod": "modulus",
"bitwiseand": "bitwise_and",
"bitwiseor": "bitwise_or",
"bitwisexor": "bitwise_xor",
"bitwiseor": "bitwise_or",
sqlglot.expressions.Add: "add",
sqlglot.expressions.Div: "div",
sqlglot.expressions.Mul: "mul",
sqlglot.expressions.Sub: "sub",
sqlglot.expressions.Mod: "modulus",
sqlglot.expressions.BitwiseAnd: "bitwise_and",
sqlglot.expressions.BitwiseOr: "bitwise_or",
sqlglot.expressions.BitwiseXor: "bitwise_xor",
sqlglot.expressions.BitwiseNot: "bitwise_not",
# Comparisons
"eq": "equal",
"nullsafeeq": "is_not_distinct_from",
"neq": "not_equal",
"gt": "gt",
"gte": "gte",
"lt": "lt",
"lte": "lte",
sqlglot.expressions.EQ: "equal",
sqlglot.expressions.NullSafeEQ: "is_not_distinct_from",
sqlglot.expressions.NEQ: "not_equal",
sqlglot.expressions.GT: "gt",
sqlglot.expressions.GTE: "gte",
sqlglot.expressions.LT: "lt",
sqlglot.expressions.LTE: "lte",
# logical
"and": "and",
"or": "or",
sqlglot.expressions.And: "and",
sqlglot.expressions.Or: "or",
sqlglot.expressions.Not: "not",
}


Expand Down Expand Up @@ -151,7 +151,7 @@ def _parse_Alias(self, expr):
def _parser_Binary(self, expr):
left_parsed_expr = self._parse_expression(expr.left)
right_parsed_expr = self._parse_expression(expr.right)
function_name = SQL_BINARY_FUNCTIONS[expr.key]
function_name = SQL_FUNCTIONS[type(expr)]
signature, result_type, function_expression = self._parse_function_invokation(
function_name, left_parsed_expr, right_parsed_expr
)
Expand All @@ -168,7 +168,7 @@ def _parser_Binary(self, expr):
@DISPATCH_REGISTRY.register(sqlglot.expressions.Unary)
def _parse_Unary(self, expr):
argument_parsed_expr = self._parse_expression(expr.this)
function_name = SQL_UNARY_FUNCTIONS[expr.key]
function_name = SQL_FUNCTIONS[type(expr)]
signature, result_type, function_expression = self._parse_function_invokation(
function_name, argument_parsed_expr
)
Expand Down

0 comments on commit e9508a4

Please sign in to comment.