Skip to content

Commit

Permalink
Tweak dynamic dispatch and handle variadic and, or etc...
Browse files Browse the repository at this point in the history
  • Loading branch information
amol- committed May 8, 2024
1 parent 7688fc5 commit 1a5fcc7
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 28 deletions.
36 changes: 12 additions & 24 deletions src/substrait/sql/extended_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,37 +92,25 @@ def __init__(self, functions_catalog, schema):
self._schema = schema
self._counter = itertools.count()

self._parse_expression = self.DISPATCH_REGISTRY.bind(self)

def expression_from_sqlglot(self, sqlglot_node):
"""Parse a SQLGlot expression into a Substrait Expression."""
return self._parse_expression(sqlglot_node)

def _parse_expression(self, expr):
"""Parse a SQLGlot node and return a Substrait expression.
This is the internal implementation, expected to be
invoked in a recursive manner to parse the whole
expression tree.
"""
expr_class = expr.__class__
return self.DISPATCH_REGISTRY[expr_class](self, expr)

@DISPATCH_REGISTRY.register(sqlglot.expressions.Literal)
def _parse_Literal(self, expr):
if expr.is_string:
return ParsedSubstraitExpression(
f"literal${next(self._counter)}",
proto.Type(string=proto.Type.String()),
proto.Expression(
literal=proto.Expression.Literal(string=expr.text)
),
proto.Expression(literal=proto.Expression.Literal(string=expr.text)),
)
elif expr.is_int:
return ParsedSubstraitExpression(
f"literal${next(self._counter)}",
proto.Type(i32=proto.Type.I32()),
proto.Expression(
literal=proto.Expression.Literal(i32=int(expr.name))
),
proto.Expression(literal=proto.Expression.Literal(i32=int(expr.name))),
)
elif sqlglot.helper.is_float(expr.name):
return ParsedSubstraitExpression(
Expand All @@ -134,7 +122,7 @@ def _parse_Literal(self, expr):
)
else:
raise ValueError(f"Unsupporter literal: {expr.text}")

@DISPATCH_REGISTRY.register(sqlglot.expressions.Column)
def _parse_Column(self, expr):
column_name = expr.output_name
Expand Down Expand Up @@ -164,10 +152,8 @@ def _parser_Binary(self, expr):
left_parsed_expr = self._parse_expression(expr.left)
right_parsed_expr = self._parse_expression(expr.right)
function_name = SQL_BINARY_FUNCTIONS[expr.key]
signature, result_type, function_expression = (
self._parse_function_invokation(
function_name, left_parsed_expr, right_parsed_expr
)
signature, result_type, function_expression = self._parse_function_invokation(
function_name, left_parsed_expr, right_parsed_expr
)
result_name = f"{function_name}_{left_parsed_expr.output_name}_{right_parsed_expr.output_name}_{next(self._counter)}"
return ParsedSubstraitExpression(
Expand All @@ -183,10 +169,12 @@ def _parser_Binary(self, expr):
def _parse_Unary(self, expr):
argument_parsed_expr = self._parse_expression(expr.this)
function_name = SQL_UNARY_FUNCTIONS[expr.key]
signature, result_type, function_expression = (
self._parse_function_invokation(function_name, argument_parsed_expr)
signature, result_type, function_expression = self._parse_function_invokation(
function_name, argument_parsed_expr
)
result_name = (
f"{function_name}_{argument_parsed_expr.output_name}_{next(self._counter)}"
)
result_name = f"{function_name}_{argument_parsed_expr.output_name}_{next(self._counter)}"
return ParsedSubstraitExpression(
result_name,
result_type,
Expand Down
5 changes: 5 additions & 0 deletions src/substrait/sql/functions_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,16 @@ def load(self, dirpath, filename):
t.get("value", "unknown").strip("?")
for t in impl.get("args", [])
]
if impl.get("variadic", False):
# TODO: Variadic functions.
argtypes *= 2

if not argtypes:
signature = function_name
else:
signature = f"{function_name}:{'_'.join(argtypes)}"
loaded_functions.add(signature)
print("Loaded function", signature)
functions_return_type[signature] = self._type_from_name(
impl["return"]
)
Expand Down
30 changes: 26 additions & 4 deletions src/substrait/sql/utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,38 @@
import types


class DispatchRegistry:
    """Dispatch a function based on the class of an argument.

    Functions are registered for a specific class through ``register`` and
    are later selected by checking the dispatch argument against every
    registered class with ``isinstance``. Registered classes are checked in
    registration order and the first match wins, so a function registered
    for a base class also handles instances of its subclasses unless a more
    specific class was registered earlier.

    The registry is callable as ``registry(obj, argument, *args, **kwargs)``
    and can be exposed as a method of an object through ``bind``.

    It is similar to ``functools.singledispatch`` but leaves room for more
    customization in case the dispatch rules grow in complexity, and it
    works for methods as well (``singledispatch`` supports methods only in
    more recent Python versions).
    """

    def __init__(self):
        # Maps a registered class to the function handling its instances.
        self._registry = {}

    def register(self, cls):
        """Return a decorator registering a function as the handler for ``cls``.

        The decorated function is returned unchanged, so it remains usable
        under its own name. Registering the same class again replaces the
        previously registered handler.
        """

        def decorator(func):
            self._registry[cls] = func
            return func

        return decorator

    def bind(self, obj):
        """Return the registry bound to ``obj`` as a method.

        Calling the result with ``(argument, *args, **kwargs)`` is the same
        as calling the registry with ``(obj, argument, *args, **kwargs)``.
        """
        return types.MethodType(self, obj)

    def __getitem__(self, argument):
        """Return the handler registered for the class of ``argument``.

        Raises:
            ValueError: if ``argument`` is not an instance of any
                registered class.
        """
        for dispatch_cls, func in self._registry.items():
            if isinstance(argument, dispatch_cls):
                return func
        # Report the class of the argument: the lookup receives an instance,
        # so there is no ``cls`` name in scope here.
        raise ValueError(f"Unsupported SQL Node type: {type(argument)}")

    def __call__(self, obj, dispatch_argument, *args, **kwargs):
        """Invoke the handler selected by ``dispatch_argument``.

        The handler receives ``obj`` first (the ``self`` of the object the
        registry is bound to), then ``dispatch_argument`` and any extra
        positional and keyword arguments.
        """
        return self[dispatch_argument](obj, dispatch_argument, *args, **kwargs)

0 comments on commit 1a5fcc7

Please sign in to comment.