"""Demo entry point for the ``substrait.sql`` package (``src/substrait/sql/__main__.py``).

Loads the standard Substrait extension declarations, parses one sample SQL
statement against a small hard-coded schema and prints the resulting
projection and filter expressions.

The extension files are YAML, so PyYAML must be installed; pyproject.toml
lists it next to ``protobuf`` and ``sqlglot`` in the project dependencies.
"""
import pathlib

from substrait import proto
from .functions_catalog import FunctionsCatalog
from .extended_expression import parse_sql_extended_expression

# Four directory levels above this file is the repository root, where the
# vendored Substrait extension YAML files are expected.
_EXTENSIONS_DIR = (
    pathlib.Path(__file__).parents[3] / "third_party" / "substrait" / "extensions"
)

catalog = FunctionsCatalog()
catalog.load_standard_extensions(_EXTENSIONS_DIR)

# Every column of the sample schema is declared as non-nullable.
_REQUIRED = proto.Type.Nullability.NULLABILITY_REQUIRED
_COLUMN_TYPES = [
    proto.Type(string=proto.Type.String(nullability=_REQUIRED)),
    proto.Type(string=proto.Type.String(nullability=_REQUIRED)),
    proto.Type(i32=proto.Type.I32(nullability=_REQUIRED)),
]
schema = proto.NamedStruct(
    names=["first_name", "surname", "age"],
    struct=proto.Type.Struct(types=_COLUMN_TYPES),
)

sql = "SELECT surname, age + 1 as next_birthday WHERE age = 32"
projection_expr, filter_expr = parse_sql_extended_expression(catalog, schema, sql)

# Print each section as a banner line followed by its payload.
for _banner, _payload in (
    ("---- SQL INPUT ----", sql),
    ("---- PROJECTION ----", projection_expr),
    ("---- FILTER ----", filter_expr),
):
    print(_banner)
    print(_payload)
b/src/substrait/sql/extended_expression.py @@ -1,7 +1,4 @@ -import pathlib - import sqlglot -import yaml from substrait import proto @@ -138,7 +135,7 @@ def _parse_expression(catalog, schema, expr, invoked_functions): ) function_name = SQL_BINARY_FUNCTIONS[expr.key] signature, result_type, function_expression = _parse_function_invokation( - function_name, left_type, left, right_type, right + catalog, function_name, left_type, left, right_type, right ) invoked_functions.add(signature) return result_type, function_expression @@ -148,7 +145,7 @@ def _parse_expression(catalog, schema, expr, invoked_functions): ) -def _parse_function_invokation(function_name, left_type, left, right_type, right): +def _parse_function_invokation(catalog, function_name, left_type, left, right_type, right): signature = f"{function_name}:{left_type.WhichOneof('kind')}_{right_type.WhichOneof('kind')}" try: function_anchor = catalog.function_anchor(signature) @@ -170,146 +167,3 @@ def _parse_function_invokation(function_name, left_type, left, right_type, right ), ) - -class FunctionsCatalog: - STANDARD_EXTENSIONS = ( - "/functions_aggregate_approx.yaml", - "/functions_aggregate_generic.yaml", - "/functions_arithmetic.yaml", - "/functions_arithmetic_decimal.yaml", - "/functions_boolean.yaml", - "/functions_comparison.yaml", - "/functions_datetime.yaml", - "/functions_geometry.yaml", - "/functions_logarithmic.yaml", - "/functions_rounding.yaml", - "/functions_set.yaml", - "/functions_string.yaml", - ) - - def __init__(self): - self._declarations = {} - self._registered_extensions = {} - self._functions = {} - - def load_standard_extensions(self, dirpath): - for ext in self.STANDARD_EXTENSIONS: - self.load(dirpath, ext) - - def load(self, dirpath, filename): - with open(pathlib.Path(dirpath) / filename.strip("/")) as f: - sections = yaml.safe_load(f) - - loaded_functions = set() - for functions in sections.values(): - for function in functions: - function_name = function["name"] - for impl in 
function.get("impls", []): - argtypes = [t.get("value", "unknown") for t in impl.get("args", [])] - if not argtypes: - signature = function_name - else: - signature = f"{function_name}:{'_'.join(argtypes)}" - self._declarations[signature] = filename - loaded_functions.add(signature) - - self._register_extensions(filename, loaded_functions) - - def _register_extensions(self, extension_uri, loaded_functions): - if extension_uri not in self._registered_extensions: - ext_anchor_id = len(self._registered_extensions) + 1 - self._registered_extensions[extension_uri] = proto.SimpleExtensionURI( - extension_uri_anchor=ext_anchor_id, uri=extension_uri - ) - - for function in loaded_functions: - if function in self._functions: - extensions_by_anchor = self.extension_uris_by_anchor - function = self._functions[function] - function_extension = extensions_by_anchor[ - function.extension_uri_reference - ].uri - continue - raise ValueError( - f"Duplicate function definition: {function} from {extension_uri}, already loaded from {function_extension}" - ) - extension_anchor = self._registered_extensions[ - extension_uri - ].extension_uri_anchor - function_anchor = len(self._functions) + 1 - self._functions[function] = ( - proto.SimpleExtensionDeclaration.ExtensionFunction( - extension_uri_reference=extension_anchor, - name=function, - function_anchor=function_anchor, - ) - ) - - @property - def extension_uris_by_anchor(self): - return { - ext.extension_uri_anchor: ext - for ext in self._registered_extensions.values() - } - - @property - def extension_uris(self): - return list(self._registered_extensions.values()) - - @property - def extensions(self): - return list(self._functions.values()) - - def function_anchor(self, function): - return self._functions[function].function_anchor - - def extensions_for_functions(self, functions): - uris_anchors = set() - extensions = [] - for f in functions: - ext = self._functions[f] - uris_anchors.add(ext.extension_uri_reference) - 
class FunctionsCatalog:
    """Catalog of Substrait functions and extensions.

    Loads extensions from YAML files and records the declared functions.
    Given a set of functions it can generate the necessary extension URIs
    and extensions to be included in an ExtendedExpression or Plan.

    Functions are keyed by signature: ``name`` for functions without
    arguments, otherwise ``name:argtype1_argtype2...``.
    """

    # TODO: Find a way to support standard extensions in released distribution.
    # IE: Include the standard extension yaml files in the package data and
    # update them when gen_proto is used.
    STANDARD_EXTENSIONS = (
        "/functions_aggregate_approx.yaml",
        "/functions_aggregate_generic.yaml",
        "/functions_arithmetic.yaml",
        "/functions_arithmetic_decimal.yaml",
        "/functions_boolean.yaml",
        "/functions_comparison.yaml",
        "/functions_datetime.yaml",
        "/functions_geometry.yaml",
        "/functions_logarithmic.yaml",
        "/functions_rounding.yaml",
        "/functions_set.yaml",
        "/functions_string.yaml",
    )

    def __init__(self):
        # signature -> filename that most recently declared it
        self._declarations = {}
        # extension uri (the filename as passed to ``load``) -> SimpleExtensionURI
        self._registered_extensions = {}
        # signature -> SimpleExtensionDeclaration.ExtensionFunction
        self._functions = {}

    def load_standard_extensions(self, dirpath):
        """Load every file listed in ``STANDARD_EXTENSIONS`` from ``dirpath``."""
        for ext in self.STANDARD_EXTENSIONS:
            self.load(dirpath, ext)

    def load(self, dirpath, filename):
        """Load the function declarations of one extension YAML file.

        ``filename`` is resolved relative to ``dirpath`` (surrounding ``/``
        are stripped for the lookup) and is also used, unmodified, as the
        extension URI under which the functions are registered.

        Raises:
            OSError: if the file cannot be opened.
            KeyError: if a declared function has no ``name`` entry.
        """
        path = pathlib.Path(dirpath) / filename.strip("/")
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(path, encoding="utf-8") as f:
            # An empty YAML document parses to None; treat it as "no sections".
            sections = yaml.safe_load(f) or {}

        # Keep signatures in declaration order (a plain set of strings would
        # iterate in hash-randomised order and make anchors differ per run).
        signatures = []
        for functions in sections.values():
            if not isinstance(functions, list):
                # Scalar/mapping sections cannot hold function declarations.
                continue
            for function in functions:
                function_name = function["name"]
                for impl in function.get("impls", []):
                    argtypes = [t.get("value", "unknown") for t in impl.get("args", [])]
                    if argtypes:
                        signature = f"{function_name}:{'_'.join(argtypes)}"
                    else:
                        signature = function_name
                    self._declarations[signature] = filename
                    signatures.append(signature)

        # dict.fromkeys drops duplicates while preserving first-seen order.
        self._register_extensions(filename, list(dict.fromkeys(signatures)))

    def _register_extensions(self, extension_uri, loaded_functions):
        """Register ``extension_uri`` and assign anchors to its new functions.

        A signature that is already known keeps its first registration; later
        declarations of the same signature are skipped.
        """
        if extension_uri not in self._registered_extensions:
            ext_anchor_id = len(self._registered_extensions) + 1
            self._registered_extensions[extension_uri] = proto.SimpleExtensionURI(
                extension_uri_anchor=ext_anchor_id, uri=extension_uri
            )
        extension_anchor = self._registered_extensions[
            extension_uri
        ].extension_uri_anchor

        for function in loaded_functions:
            if function in self._functions:
                # First declaration wins.
                continue
            # Anchors are 1-based and follow registration order.
            function_anchor = len(self._functions) + 1
            self._functions[function] = (
                proto.SimpleExtensionDeclaration.ExtensionFunction(
                    extension_uri_reference=extension_anchor,
                    name=function,
                    function_anchor=function_anchor,
                )
            )

    @property
    def extension_uris_by_anchor(self):
        """Mapping of extension URI anchor -> registered SimpleExtensionURI."""
        return {
            ext.extension_uri_anchor: ext
            for ext in self._registered_extensions.values()
        }

    @property
    def extension_uris(self):
        """List of all registered SimpleExtensionURI objects."""
        return list(self._registered_extensions.values())

    @property
    def extensions(self):
        """List of all registered ExtensionFunction declarations."""
        return list(self._functions.values())

    def function_anchor(self, function):
        """Return the anchor of the function with the given signature.

        Raises:
            KeyError: if the signature is not registered.
        """
        return self._functions[function].function_anchor

    def extensions_for_functions(self, functions):
        """Return ``(extension_uris, extensions)`` needed by ``functions``.

        ``extensions`` holds one SimpleExtensionDeclaration per requested
        signature, in iteration order; ``extension_uris`` holds each
        referenced SimpleExtensionURI once, ordered by anchor.

        Raises:
            KeyError: if a signature is not registered.
        """
        uris_anchors = set()
        extensions = []
        for f in functions:
            ext = self._functions[f]
            uris_anchors.add(ext.extension_uri_reference)
            extensions.append(proto.SimpleExtensionDeclaration(extension_function=ext))

        uris_by_anchor = self.extension_uris_by_anchor
        # Sort so the output order does not depend on set iteration order.
        extension_uris = [uris_by_anchor[anchor] for anchor in sorted(uris_anchors)]
        return extension_uris, extensions