From 1a5fcc761db60e165bf4091663a63f4735417052 Mon Sep 17 00:00:00 2001 From: Alessandro Molina Date: Fri, 19 Apr 2024 16:26:43 +0200 Subject: [PATCH] Tweak dynamic dispatch and handle variadic and, or etc... --- src/substrait/sql/extended_expression.py | 36 ++++++++---------------- src/substrait/sql/functions_catalog.py | 5 ++++ src/substrait/sql/utils.py | 30 +++++++++++++++++--- 3 files changed, 43 insertions(+), 28 deletions(-) diff --git a/src/substrait/sql/extended_expression.py b/src/substrait/sql/extended_expression.py index 74ec3c4..fe78e1a 100644 --- a/src/substrait/sql/extended_expression.py +++ b/src/substrait/sql/extended_expression.py @@ -92,37 +92,25 @@ def __init__(self, functions_catalog, schema): self._schema = schema self._counter = itertools.count() + self._parse_expression = self.DISPATCH_REGISTRY.bind(self) + def expression_from_sqlglot(self, sqlglot_node): """Parse a SQLGlot expression into a Substrait Expression.""" return self._parse_expression(sqlglot_node) - def _parse_expression(self, expr): - """Parse a SQLGlot node and return a Substrait expression. - - This is the internal implementation, expected to be - invoked in a recursive manner to parse the whole - expression tree. 
- """ - expr_class = expr.__class__ - return self.DISPATCH_REGISTRY[expr_class](self, expr) - @DISPATCH_REGISTRY.register(sqlglot.expressions.Literal) def _parse_Literal(self, expr): if expr.is_string: return ParsedSubstraitExpression( f"literal${next(self._counter)}", proto.Type(string=proto.Type.String()), - proto.Expression( - literal=proto.Expression.Literal(string=expr.text) - ), + proto.Expression(literal=proto.Expression.Literal(string=expr.text)), ) elif expr.is_int: return ParsedSubstraitExpression( f"literal${next(self._counter)}", proto.Type(i32=proto.Type.I32()), - proto.Expression( - literal=proto.Expression.Literal(i32=int(expr.name)) - ), + proto.Expression(literal=proto.Expression.Literal(i32=int(expr.name))), ) elif sqlglot.helper.is_float(expr.name): return ParsedSubstraitExpression( @@ -134,7 +122,7 @@ def _parse_Literal(self, expr): ) else: raise ValueError(f"Unsupporter literal: {expr.text}") - + @DISPATCH_REGISTRY.register(sqlglot.expressions.Column) def _parse_Column(self, expr): column_name = expr.output_name @@ -164,10 +152,8 @@ def _parser_Binary(self, expr): left_parsed_expr = self._parse_expression(expr.left) right_parsed_expr = self._parse_expression(expr.right) function_name = SQL_BINARY_FUNCTIONS[expr.key] - signature, result_type, function_expression = ( - self._parse_function_invokation( - function_name, left_parsed_expr, right_parsed_expr - ) + signature, result_type, function_expression = self._parse_function_invokation( + function_name, left_parsed_expr, right_parsed_expr ) result_name = f"{function_name}_{left_parsed_expr.output_name}_{right_parsed_expr.output_name}_{next(self._counter)}" return ParsedSubstraitExpression( @@ -183,10 +169,12 @@ def _parser_Binary(self, expr): def _parse_Unary(self, expr): argument_parsed_expr = self._parse_expression(expr.this) function_name = SQL_UNARY_FUNCTIONS[expr.key] - signature, result_type, function_expression = ( - self._parse_function_invokation(function_name, argument_parsed_expr) + 
signature, result_type, function_expression = self._parse_function_invokation( + function_name, argument_parsed_expr + ) + result_name = ( + f"{function_name}_{argument_parsed_expr.output_name}_{next(self._counter)}" ) - result_name = f"{function_name}_{argument_parsed_expr.output_name}_{next(self._counter)}" return ParsedSubstraitExpression( result_name, result_type, diff --git a/src/substrait/sql/functions_catalog.py b/src/substrait/sql/functions_catalog.py index 8d72871..0430e8e 100644 --- a/src/substrait/sql/functions_catalog.py +++ b/src/substrait/sql/functions_catalog.py @@ -56,11 +56,16 @@ def load(self, dirpath, filename): t.get("value", "unknown").strip("?") for t in impl.get("args", []) ] + if impl.get("variadic", False): + # TODO: Variadic functions. + argtypes *= 2 + if not argtypes: signature = function_name else: signature = f"{function_name}:{'_'.join(argtypes)}" loaded_functions.add(signature) + print("Loaded function", signature) functions_return_type[signature] = self._type_from_name( impl["return"] ) diff --git a/src/substrait/sql/utils.py b/src/substrait/sql/utils.py index 13eeeeb..c73a531 100644 --- a/src/substrait/sql/utils.py +++ b/src/substrait/sql/utils.py @@ -1,4 +1,19 @@ +import types + + class DispatchRegistry: + """Dispatch a function based on the class of the argument. + + This class allows to register a function to execute for a specific class + and expose this as a method of an object which will be dispatched + based on the argument. 
+ + It is similar to functools.singledispatch but it allows more + customization in case the dispatch rules grow in complexity + and works for class methods as well + (singledispatch supports methods only in more recent versions) + """ + def __init__(self): self._registry = {} @@ -6,11 +21,18 @@ def register(self, cls): def decorator(func): self._registry[cls] = func return func + return decorator - - def __getitem__(self, cls): + + def bind(self, obj): + return types.MethodType(self, obj) + + def __getitem__(self, argument): for dispatch_cls, func in self._registry.items(): - if issubclass(cls, dispatch_cls): + if isinstance(argument, dispatch_cls): return func else: - raise ValueError(f"Unsupported SQL Node type: {cls}") \ No newline at end of file + raise ValueError(f"Unsupported SQL Node type: {type(argument)}") + + def __call__(self, obj, dispatch_argument, *args, **kwargs): + return self[dispatch_argument](obj, dispatch_argument, *args, **kwargs)