From 309796299ba29b6418b5db875affd85c7a0645db Mon Sep 17 00:00:00 2001 From: Alessandro Molina Date: Thu, 18 Apr 2024 15:15:53 +0200 Subject: [PATCH] Register builtin functions and handle return types --- src/substrait/sql/extended_expression.py | 62 +++++++++++++++--- src/substrait/sql/functions_catalog.py | 83 +++++++++++++++++++++--- 2 files changed, 126 insertions(+), 19 deletions(-) diff --git a/src/substrait/sql/extended_expression.py b/src/substrait/sql/extended_expression.py index d53f00f..0e82abd 100644 --- a/src/substrait/sql/extended_expression.py +++ b/src/substrait/sql/extended_expression.py @@ -4,15 +4,29 @@ from substrait import proto - +SQL_UNARY_FUNCTIONS = {"not": "not"} SQL_BINARY_FUNCTIONS = { # Arithmetic "add": "add", "div": "div", "mul": "mul", "sub": "sub", + "mod": "modulus", + "bitwiseand": "bitwise_and", + "bitwiseor": "bitwise_or", + "bitwisexor": "bitwise_xor", + "bitwiseor": "bitwise_or", # Comparisons "eq": "equal", + "nullsafeeq": "is_not_distinct_from", + "new": "not_equal", + "gt": "gt", + "gte": "gte", + "lt": "lt", + "lte": "lte", + # logical + "and": "and", + "or": "or", } @@ -124,6 +138,17 @@ def _parse_expression(self, expr, invoked_functions): expr.this, invoked_functions ) return expr.output_name, aliased_type, aliased_expr + elif expr.key in SQL_UNARY_FUNCTIONS: + argument_name, argument_type, argument = self._parse_expression( + expr.this, invoked_functions + ) + function_name = SQL_UNARY_FUNCTIONS[expr.key] + signature, result_type, function_expression = ( + self._parse_function_invokation(function_name, argument_type, argument) + ) + invoked_functions.add(signature) + result_name = f"{function_name}_{argument_name}_{next(self._counter)}" + return result_name, result_type, function_expression elif expr.key in SQL_BINARY_FUNCTIONS: left_name, left_type, left = self._parse_expression( expr.left, invoked_functions @@ -148,26 +173,45 @@ def _parse_expression(self, expr, invoked_functions): ) def _parse_function_invokation( - self, function_name, left_type, left, right_type, right + self, function_name, left_type, left, right_type=None, right=None ): - signature = f"{function_name}:{left_type.WhichOneof('kind')}_{right_type.WhichOneof('kind')}" + binary = False + argtypes = [left_type] + if right_type or right: + binary = True + argtypes.append(right_type) + signature = self._functions_catalog.signature(function_name, argtypes) + try: function_anchor = self._functions_catalog.function_anchor(signature) except KeyError: # No function found with the exact types, try any1_any1 version # TODO: What about cases like i32_any1? What about any instead of any1? - signature = f"{function_name}:any1_any1" + if binary: + signature = f"{function_name}:any1_any1" + else: + signature = f"{function_name}:any1" function_anchor = self._functions_catalog.function_anchor(signature) + + function_return_type = self._functions_catalog.function_return_type(signature) + if function_return_type is None: + print("No return type for", signature) + # TODO: Is this the right way to handle this? + function_return_type = left_type return ( signature, - left_type, # TODO: Get the actually returned type from the functions catalog. + function_return_type, proto.Expression( scalar_function=proto.Expression.ScalarFunction( function_reference=function_anchor, - arguments=[ - proto.FunctionArgument(value=left), - proto.FunctionArgument(value=right), - ], + arguments=( + [ + proto.FunctionArgument(value=left), + proto.FunctionArgument(value=right), + ] + if binary + else [proto.FunctionArgument(value=left)] + ), ) ), ) diff --git a/src/substrait/sql/functions_catalog.py b/src/substrait/sql/functions_catalog.py index e591523..4bd214d 100644 --- a/src/substrait/sql/functions_catalog.py +++ b/src/substrait/sql/functions_catalog.py @@ -23,7 +23,7 @@ class FunctionsCatalog: "/functions_arithmetic_decimal.yaml", "/functions_boolean.yaml", "/functions_comparison.yaml", - "/functions_datetime.yaml", + # "/functions_datetime.yaml", for now skip, it has duplicated functions "/functions_geometry.yaml", "/functions_logarithmic.yaml", "/functions_rounding.yaml", @@ -32,9 +32,10 @@ class FunctionsCatalog: ) def __init__(self): - self._declarations = {} self._registered_extensions = {} self._functions = {} + self._functions_return_type = {} + self._register_builtins() def load_standard_extensions(self, dirpath): for ext in self.STANDARD_EXTENSIONS: @@ -45,6 +46,7 @@ def load(self, dirpath, filename): sections = yaml.safe_load(f) loaded_functions = set() + functions_return_type = {} for functions in sections.values(): for function in functions: function_name = function["name"] @@ -55,12 +57,16 @@ def load(self, dirpath, filename): signature = function_name else: signature = f"{function_name}:{'_'.join(argtypes)}" - self._declarations[signature] = filename loaded_functions.add(signature) + functions_return_type[signature] = self._type_from_name( + impl["return"] + ) - self._register_extensions(filename, loaded_functions) + self._register_extensions(filename, loaded_functions, functions_return_type) - def _register_extensions(self, extension_uri, loaded_functions): + def _register_extensions( + self, extension_uri, loaded_functions, functions_return_type + ): if extension_uri not in self._registered_extensions: ext_anchor_id = len(self._registered_extensions) + 1 self._registered_extensions[extension_uri] = proto.SimpleExtensionURI( @@ -70,14 +76,12 @@ def _register_extensions(self, extension_uri, loaded_functions): for function in loaded_functions: if function in self._functions: extensions_by_anchor = self.extension_uris_by_anchor - function = self._functions[function] + existing_function = self._functions[function] function_extension = extensions_by_anchor[ - function.extension_uri_reference + existing_function.extension_uri_reference ].uri - # TODO: Support overloading of functions from different extensionUris. - continue raise ValueError( - f"Duplicate function definition: {function.name} from {extension_uri}, already loaded from {function_extension}" + f"Duplicate function definition: {existing_function.name} from {extension_uri}, already loaded from {function_extension}" ) extension_anchor = self._registered_extensions[ extension_uri @@ -90,6 +94,48 @@ def _register_extensions(self, extension_uri, loaded_functions): function_anchor=function_anchor, ) ) + self._functions_return_type[function] = functions_return_type[function] + + def _register_builtins(self): + self._functions["not:boolean"] = ( + proto.SimpleExtensionDeclaration.ExtensionFunction( + name="not", + function_anchor=len(self._functions) + 1, + ) + ) + self._functions_return_type["not:boolean"] = proto.Type( + bool=proto.Type.Boolean() + ) + + def _type_from_name(self, typename): + nullable = False + if typename.endswith("?"): + nullable = True + + typename = typename.strip("?") + if typename in ("any", "any1"): + return None + + if typename == "boolean": + # For some reason boolean is an exception to the naming convention + typename = "bool" + + try: + type_descriptor = proto.Type.DESCRIPTOR.fields_by_name[ + typename + ].message_type + except KeyError: + # TODO: improve resolution of complext type like LIST? + print("Unsupported type", typename) + return None + + type_class = getattr(proto.Type, type_descriptor.name) + nullability = ( + proto.Type.Nullability.NULLABILITY_REQUIRED + if not nullable + else proto.Type.Nullability.NULLABILITY_NULLABLE + ) + return proto.Type(**{typename: type_class(nullability=nullability)}) @property def extension_uris_by_anchor(self): @@ -106,14 +152,31 @@ def extension_uris(self): def extensions(self): return list(self._functions.values()) + def signature(self, function_name, proto_argtypes): + def _normalize_arg_types(argtypes): + for argtype in argtypes: + kind = argtype.WhichOneof("kind") + if kind == "bool": + yield "boolean" + else: + yield kind + + return f"{function_name}:{'_'.join(_normalize_arg_types(proto_argtypes))}" + def function_anchor(self, function): return self._functions[function].function_anchor + def function_return_type(self, function): + return self._functions_return_type[function] + def extensions_for_functions(self, functions): uris_anchors = set() extensions = [] for f in functions: ext = self._functions[f] + if not ext.extension_uri_reference: + # Built-in function + continue uris_anchors.add(ext.extension_uri_reference) extensions.append(proto.SimpleExtensionDeclaration(extension_function=ext))