diff --git a/src/substrait/sql/__main__.py b/src/substrait/sql/__main__.py index 76556a2..f135e4a 100644 --- a/src/substrait/sql/__main__.py +++ b/src/substrait/sql/__main__.py @@ -1,45 +1,67 @@ import pathlib +import argparse from substrait import proto from .functions_catalog import FunctionsCatalog from .extended_expression import parse_sql_extended_expression -catalog = FunctionsCatalog() -catalog.load_standard_extensions( - pathlib.Path(__file__).parent.parent.parent.parent / "third_party" / "substrait" / "extensions", -) - -# TODO: Turn this into a command line tool to test more queries. -# We can probably have a quick way to declare schema using command line args. -# like first_name=String,surname=String,age=I32 etc... -schema = proto.NamedStruct( - names=["first_name", "surname", "age"], - struct=proto.Type.Struct( - types=[ - proto.Type( - string=proto.Type.String( - nullability=proto.Type.Nullability.NULLABILITY_REQUIRED - ) - ), - proto.Type( - string=proto.Type.String( - nullability=proto.Type.Nullability.NULLABILITY_REQUIRED - ) - ), + +def main(): + """Commandline tool to test the SQL to ExtendedExpression parser. + + Run as python -m substrait.sql first_name=String,surname=String,age=I32 "SELECT surname, age + 1 as next_birthday, age + 2 WHERE age = 32" + """ + parser = argparse.ArgumentParser( + description="Convert a SQL SELECT statement to an ExtendedExpression" + ) + parser.add_argument("schema", type=str, help="Schema of the input data") + parser.add_argument("sql", type=str, help="SQL SELECT statement") + args = parser.parse_args() + + catalog = FunctionsCatalog() + catalog.load_standard_extensions( + pathlib.Path(__file__).parent.parent.parent.parent + / "third_party" + / "substrait" + / "extensions", + ) + schema = parse_schema(args.schema) + projection_expr, filter_expr = parse_sql_extended_expression( + catalog, schema, args.sql + ) + + print("---- SQL INPUT ----") + print(args.sql) + print("---- PROJECTION ----") + print(projection_expr) + print("---- FILTER ----") + print(filter_expr) + + +def parse_schema(schema_string): + """Parse Schema from a comma separated string of fieldname=fieldtype pairs. + + For example: "first_name=String,surname=String,age=I32" + """ + types = [] + names = [] + + fields = schema_string.split(",") + for field in fields: + fieldname, fieldtype = field.split("=") + proto_type = getattr(proto.Type, fieldtype) + names.append(fieldname) + types.append( proto.Type( - i32=proto.Type.I32( - nullability=proto.Type.Nullability.NULLABILITY_REQUIRED - ) - ), - ] - ), -) - -sql = "SELECT surname, age + 1 as next_birthday WHERE age = 32" -projection_expr, filter_expr = parse_sql_extended_expression(catalog, schema, sql) -print("---- SQL INPUT ----") -print(sql) -print("---- PROJECTION ----") -print(projection_expr) -print("---- FILTER ----") -print(filter_expr) \ No newline at end of file + **{ + fieldtype.lower(): proto_type( + nullability=proto.Type.Nullability.NULLABILITY_REQUIRED + ) + } + ) + ) + return proto.NamedStruct(names=names, struct=proto.Type.Struct(types=types)) + + +if __name__ == "__main__": + main() diff --git a/src/substrait/sql/extended_expression.py b/src/substrait/sql/extended_expression.py index 1e43f55..d53f00f 100644 --- a/src/substrait/sql/extended_expression.py +++ b/src/substrait/sql/extended_expression.py @@ -27,9 +27,13 @@ def parse_sql_extended_expression(catalog, schema, sql): project_expressions = [] projection_invoked_functions = set() for sqlexpr in select.expressions: - invoked_functions, output_name, expr = sqlglot_parser.expression_from_sqlglot(sqlexpr) + invoked_functions, output_name, expr = sqlglot_parser.expression_from_sqlglot( + sqlexpr + ) projection_invoked_functions.update(invoked_functions) - project_expressions.append(proto.ExpressionReference(expression=expr, output_names=[output_name])) + project_expressions.append( + proto.ExpressionReference(expression=expr, output_names=[output_name]) + ) extension_uris, extensions = catalog.extensions_for_functions( projection_invoked_functions ) @@ -73,16 +77,28 @@ def expression_from_sqlglot(self, sqlglot_node): def _parse_expression(self, expr, invoked_functions): if isinstance(expr, sqlglot.expressions.Literal): if expr.is_string: - return f"literal_{next(self._counter)}", proto.Type(string=proto.Type.String()), proto.Expression( - literal=proto.Expression.Literal(string=expr.text) + return ( + f"literal_{next(self._counter)}", + proto.Type(string=proto.Type.String()), + proto.Expression( + literal=proto.Expression.Literal(string=expr.text) + ), ) elif expr.is_int: - return f"literal_{next(self._counter)}", proto.Type(i32=proto.Type.I32()), proto.Expression( - literal=proto.Expression.Literal(i32=int(expr.name)) + return ( + f"literal_{next(self._counter)}", + proto.Type(i32=proto.Type.I32()), + proto.Expression( + literal=proto.Expression.Literal(i32=int(expr.name)) + ), ) elif sqlglot.helper.is_float(expr.name): - return f"literal_{next(self._counter)}", proto.Type(fp32=proto.Type.FP32()), proto.Expression( - literal=proto.Expression.Literal(float=float(expr.name)) + return ( + f"literal_{next(self._counter)}", + proto.Type(fp32=proto.Type.FP32()), + proto.Expression( + literal=proto.Expression.Literal(float=float(expr.name)) + ), ) else: raise ValueError(f"Unsupporter literal: {expr.text}") @@ -90,17 +106,23 @@ def _parse_expression(self, expr, invoked_functions): column_name = expr.output_name schema_field = list(self._schema.names).index(column_name) schema_type = self._schema.struct.types[schema_field] - return column_name, schema_type, proto.Expression( - selection=proto.Expression.FieldReference( - direct_reference=proto.Expression.ReferenceSegment( - struct_field=proto.Expression.ReferenceSegment.StructField( - field=schema_field + return ( + column_name, + schema_type, + proto.Expression( + selection=proto.Expression.FieldReference( + direct_reference=proto.Expression.ReferenceSegment( + struct_field=proto.Expression.ReferenceSegment.StructField( + field=schema_field + ) ) ) - ) + ), ) elif isinstance(expr, sqlglot.expressions.Alias): - _, aliased_type, aliased_expr = self._parse_expression(expr.this, invoked_functions) + _, aliased_type, aliased_expr = self._parse_expression( + expr.this, invoked_functions + ) return expr.output_name, aliased_type, aliased_expr elif expr.key in SQL_BINARY_FUNCTIONS: left_name, left_type, left = self._parse_expression( @@ -110,18 +132,24 @@ def _parse_expression(self, expr, invoked_functions): expr.right, invoked_functions ) function_name = SQL_BINARY_FUNCTIONS[expr.key] - signature, result_type, function_expression = self._parse_function_invokation( - function_name, left_type, left, right_type, right + signature, result_type, function_expression = ( + self._parse_function_invokation( + function_name, left_type, left, right_type, right + ) ) invoked_functions.add(signature) - result_name = f"{left_name}_{function_name}_{right_name}_{next(self._counter)}" + result_name = ( + f"{left_name}_{function_name}_{right_name}_{next(self._counter)}" + ) return result_name, result_type, function_expression else: raise ValueError( f"Unsupported expression in ExtendedExpression: '{expr.key}' -> {expr}" ) - def _parse_function_invokation(self, function_name, left_type, left, right_type, right): + def _parse_function_invokation( + self, function_name, left_type, left, right_type, right + ): signature = f"{function_name}:{left_type.WhichOneof('kind')}_{right_type.WhichOneof('kind')}" try: function_anchor = self._functions_catalog.function_anchor(signature) @@ -143,4 +171,3 @@ def _parse_function_invokation(self, function_name, left_type, left, right_type, ) ), ) - diff --git a/src/substrait/sql/functions_catalog.py b/src/substrait/sql/functions_catalog.py index 8f7286d..e591523 100644 --- a/src/substrait/sql/functions_catalog.py +++ b/src/substrait/sql/functions_catalog.py @@ -7,7 +7,7 @@ class FunctionsCatalog: """Catalog of Substrait functions and extensions. - + Loads extensions from YAML files and records the declared functions. Given a set of functions it can generate the necessary extension URIs and extensions to be included in an ExtendedExpression or Plan. @@ -49,6 +49,7 @@ def load(self, dirpath, filename): for function in functions: function_name = function["name"] for impl in function.get("impls", []): + # TODO: There seem to be some functions that have arguments without type. What to do? argtypes = [t.get("value", "unknown") for t in impl.get("args", [])] if not argtypes: signature = function_name @@ -118,4 +119,4 @@ def extensions_for_functions(self, functions): uris_by_anchor = self.extension_uris_by_anchor extension_uris = [uris_by_anchor[uri_anchor] for uri_anchor in uris_anchors] - return extension_uris, extensions \ No newline at end of file + return extension_uris, extensions