diff --git a/src/substrait/sql/extended_expression.py b/src/substrait/sql/extended_expression.py index fe78e1a..06d621a 100644 --- a/src/substrait/sql/extended_expression.py +++ b/src/substrait/sql/extended_expression.py @@ -194,30 +194,20 @@ def _parse_function_invokation( invokation expression itself. """ arguments = [argument_parsed_expr] + list(additional_arguments) - signature = self._functions_catalog.signature( + signature = self._functions_catalog.make_signature( function_name, proto_argtypes=[arg.type for arg in arguments] ) - try: - function_anchor = self._functions_catalog.function_anchor(signature) - except KeyError: - # No function found with the exact types, try any1_any1 version - # TODO: What about cases like i32_any1? What about any instead of any1? - # TODO: What about optional arguments? IE: "i32_i32?" - signature = f"{function_name}:{'_'.join(['any1']*len(arguments))}" - function_anchor = self._functions_catalog.function_anchor(signature) - - function_return_type = self._functions_catalog.function_return_type(signature) - if function_return_type is None: - print("No return type for", signature) - # TODO: Is this the right way to handle this? - function_return_type = left_type + registered_function = self._functions_catalog.lookup_function(signature) + if registered_function is None: + raise KeyError(f"Function not found: {signature}") + return ( - signature, - function_return_type, + registered_function.signature, + registered_function.return_type, proto.Expression( scalar_function=proto.Expression.ScalarFunction( - function_reference=function_anchor, + function_reference=registered_function.function_anchor, arguments=[ proto.FunctionArgument(value=arg.expression) for arg in arguments @@ -255,3 +245,6 @@ def duplicate( expression or self.expression, invoked_functions or self.invoked_functions, ) + + def __repr__(self): + return f"" diff --git a/src/substrait/sql/functions_catalog.py b/src/substrait/sql/functions_catalog.py index 0430e8e..d6b175d 100644 --- a/src/substrait/sql/functions_catalog.py +++ b/src/substrait/sql/functions_catalog.py @@ -1,8 +1,80 @@ +import os import pathlib +from collections.abc import Iterable import yaml -from substrait import proto +from substrait.gen.proto.type_pb2 import Type as SubstraitType +from substrait.gen.proto.extensions.extensions_pb2 import ( + SimpleExtensionURI, + SimpleExtensionDeclaration, +) + + +class RegisteredSubstraitFunction: + """A Substrait function loaded from an extension file. + + The FunctionsCatalog will keep a collection of RegisteredSubstraitFunction + and will use them to generate the necessary extension URIs and extensions. + """ + + def __init__(self, signature: str, function_anchor: int | None, impl: dict): + self.signature = signature + self.function_anchor = function_anchor + self.variadic = impl.get("variadic", False) + + if "return" in impl: + self.return_type = self._type_from_name(impl["return"]) + else: + # We do always need a return type + # to know which type to propagate up to the invoker + _, argtypes = FunctionsCatalog.parse_signature(signature) + # TODO: Is this the right way to handle this? + self.return_type = self._type_from_name(argtypes[0]) + + @property + def name(self) -> str: + name, _ = FunctionsCatalog.parse_signature(self.signature) + return name + + @property + def arguments(self) -> list[str]: + _, argtypes = FunctionsCatalog.parse_signature(self.signature) + return argtypes + + @property + def arguments_type(self) -> list[SubstraitType | None]: + return [self._type_from_name(arg) for arg in self.arguments] + + def _type_from_name(self, typename: str) -> SubstraitType | None: + nullable = False + if typename.endswith("?"): + nullable = True + + typename = typename.strip("?") + if typename in ("any", "any1"): + return None + + if typename == "boolean": + # For some reason boolean is an exception to the naming convention + typename = "bool" + + try: + type_descriptor = SubstraitType.DESCRIPTOR.fields_by_name[ + typename + ].message_type + except KeyError: + # TODO: improve resolution of complext type like LIST? + print("Unsupported type", typename) + return None + + type_class = getattr(SubstraitType, type_descriptor.name) + nullability = ( + SubstraitType.Nullability.NULLABILITY_REQUIRED + if not nullable + else SubstraitType.Nullability.NULLABILITY_NULLABLE + ) + return SubstraitType(**{typename: type_class(nullability=nullability)}) class FunctionsCatalog: @@ -32,20 +104,21 @@ class FunctionsCatalog: ) def __init__(self): - self._registered_extensions = {} + self._substrait_extension_uris = {} + self._substrait_extension_functions = {} self._functions = {} - self._functions_return_type = {} - def load_standard_extensions(self, dirpath): + def load_standard_extensions(self, dirpath: str | os.PathLike): + """Load all standard substrait extensions from the target directory.""" for ext in self.STANDARD_EXTENSIONS: self.load(dirpath, ext) - def load(self, dirpath, filename): + def load(self, dirpath: str | os.PathLike, filename: str): + """Load an extension from a YAML file in a target directory.""" with open(pathlib.Path(dirpath) / filename.strip("/")) as f: sections = yaml.safe_load(f) - loaded_functions = set() - functions_return_type = {} + loaded_functions = {} for functions in sections.values(): for function in functions: function_name = function["name"] @@ -56,100 +129,80 @@ def load(self, dirpath, filename): t.get("value", "unknown").strip("?") for t in impl.get("args", []) ] - if impl.get("variadic", False): - # TODO: Variadic functions. - argtypes *= 2 - if not argtypes: signature = function_name else: signature = f"{function_name}:{'_'.join(argtypes)}" - loaded_functions.add(signature) - print("Loaded function", signature) - functions_return_type[signature] = self._type_from_name( - impl["return"] + loaded_functions[signature] = RegisteredSubstraitFunction( + signature, None, impl ) - self._register_extensions(filename, loaded_functions, functions_return_type) + self._register_extensions(filename, loaded_functions) def _register_extensions( - self, extension_uri, loaded_functions, functions_return_type + self, + extension_uri: str, + loaded_functions: dict[str, RegisteredSubstraitFunction], ): - if extension_uri not in self._registered_extensions: - ext_anchor_id = len(self._registered_extensions) + 1 - self._registered_extensions[extension_uri] = proto.SimpleExtensionURI( + if extension_uri not in self._substrait_extension_uris: + ext_anchor_id = len(self._substrait_extension_uris) + 1 + self._substrait_extension_uris[extension_uri] = SimpleExtensionURI( extension_uri_anchor=ext_anchor_id, uri=extension_uri ) - for function in loaded_functions: - if function in self._functions: + for signature, registered_function in loaded_functions.items(): + if signature in self._substrait_extension_functions: extensions_by_anchor = self.extension_uris_by_anchor - existing_function = self._functions[function] + existing_function = self._substrait_extension_functions[signature] function_extension = extensions_by_anchor[ existing_function.extension_uri_reference ].uri raise ValueError( f"Duplicate function definition: {existing_function.name} from {extension_uri}, already loaded from {function_extension}" ) - extension_anchor = self._registered_extensions[ + extension_anchor = self._substrait_extension_uris[ extension_uri ].extension_uri_anchor - function_anchor = len(self._functions) + 1 - self._functions[function] = ( - proto.SimpleExtensionDeclaration.ExtensionFunction( + function_anchor = len(self._substrait_extension_functions) + 1 + self._substrait_extension_functions[signature] = ( + SimpleExtensionDeclaration.ExtensionFunction( extension_uri_reference=extension_anchor, - name=function, + name=signature, function_anchor=function_anchor, ) ) - self._functions_return_type[function] = functions_return_type[function] - - def _type_from_name(self, typename): - nullable = False - if typename.endswith("?"): - nullable = True - - typename = typename.strip("?") - if typename in ("any", "any1"): - return None - - if typename == "boolean": - # For some reason boolean is an exception to the naming convention - typename = "bool" - - try: - type_descriptor = proto.Type.DESCRIPTOR.fields_by_name[ - typename - ].message_type - except KeyError: - # TODO: improve resolution of complext type like LIST? - print("Unsupported type", typename) - return None - - type_class = getattr(proto.Type, type_descriptor.name) - nullability = ( - proto.Type.Nullability.NULLABILITY_REQUIRED - if not nullable - else proto.Type.Nullability.NULLABILITY_NULLABLE - ) - return proto.Type(**{typename: type_class(nullability=nullability)}) + registered_function.function_anchor = function_anchor + self._functions.setdefault(registered_function.name, []).append( + registered_function + ) @property - def extension_uris_by_anchor(self): + def extension_uris_by_anchor(self) -> dict[int, SimpleExtensionURI]: return { ext.extension_uri_anchor: ext - for ext in self._registered_extensions.values() + for ext in self._substrait_extension_uris.values() } @property - def extension_uris(self): - return list(self._registered_extensions.values()) + def extension_uris(self) -> list[SimpleExtensionURI]: + return list(self._substrait_extension_uris.values()) @property - def extensions(self): - return list(self._functions.values()) + def extensions_functions( + self, + ) -> list[SimpleExtensionDeclaration.ExtensionFunction]: + return list(self._substrait_extension_functions.values()) + + @classmethod + def make_signature( + cls, function_name: str, proto_argtypes: Iterable[SubstraitType] + ): + """Create a function signature from a function name and substrait types. + + The signature is generated according to Function Signature Compound Names + as described in the Substrait documentation. + """ - def signature(self, function_name, proto_argtypes): def _normalize_arg_types(argtypes): for argtype in argtypes: kind = argtype.WhichOneof("kind") @@ -160,23 +213,73 @@ def _normalize_arg_types(argtypes): return f"{function_name}:{'_'.join(_normalize_arg_types(proto_argtypes))}" - def function_anchor(self, function): - return self._functions[function].function_anchor + @classmethod + def parse_signature(cls, signature: str) -> tuple[str, list[str]]: + """Parse a function signature and returns name and type names""" + try: + function_name, signature_args = signature.split(":") + except ValueError: + function_name = signature + argtypes = [] + else: + argtypes = signature_args.split("_") + return function_name, argtypes - def function_return_type(self, function): - return self._functions_return_type[function] + def extensions_for_functions( + self, function_signatures: Iterable[str] + ) -> tuple[list[SimpleExtensionURI], list[SimpleExtensionDeclaration]]: + """Given a set of function signatures, return the necessary extensions. - def extensions_for_functions(self, functions): + The function will return the URIs of the extensions and the extension + that have to be declared in the plan to use the functions. + """ uris_anchors = set() extensions = [] - for f in functions: - ext = self._functions[f] - if not ext.extension_uri_reference: - # Built-in function - continue + for f in function_signatures: + ext = self._substrait_extension_functions[f] uris_anchors.add(ext.extension_uri_reference) - extensions.append(proto.SimpleExtensionDeclaration(extension_function=ext)) + extensions.append(SimpleExtensionDeclaration(extension_function=ext)) uris_by_anchor = self.extension_uris_by_anchor extension_uris = [uris_by_anchor[uri_anchor] for uri_anchor in uris_anchors] return extension_uris, extensions + + def lookup_function(self, signature: str) -> RegisteredSubstraitFunction | None: + """Given the signature of a function invocation, return the matching function.""" + function_name, invocation_argtypes = self.parse_signature(signature) + + functions = self._functions.get(function_name) + if not functions: + # No function with such a name at all. + return None + + is_variadic = functions[0].variadic + if is_variadic: + # If it's variadic we care about only the first parameter. + invocation_argtypes = invocation_argtypes[:1] + + found_function = None + for function in functions: + accepted_function_arguments = function.arguments + for argidx, argtype in enumerate(invocation_argtypes): + try: + accepted_argument = accepted_function_arguments[argidx] + except IndexError: + # More arguments than available were provided + break + if accepted_argument != argtype and accepted_argument not in ( + "any", + "any1", + ): + break + else: + if argidx < len(accepted_function_arguments) - 1: + # Not enough arguments were provided + remainder = accepted_function_arguments[argidx + 1 :] + if all(arg.endswith("?") for arg in remainder): + # All remaining arguments are optional + found_function = function + else: + found_function = function + + return found_function diff --git a/src/substrait/sql/utils.py b/src/substrait/sql/utils.py index c73a531..9ffad36 100644 --- a/src/substrait/sql/utils.py +++ b/src/substrait/sql/utils.py @@ -32,7 +32,9 @@ def __getitem__(self, argument): if isinstance(argument, dispatch_cls): return func else: - raise ValueError(f"Unsupported SQL Node type: {cls}") + raise ValueError( + f"Unsupported SQL Node type: {argument.__class__.__name__} -> {argument}" + ) def __call__(self, obj, dispatch_argument, *args, **kwargs): return self[dispatch_argument](obj, dispatch_argument, *args, **kwargs)