Skip to content

Commit

Permalink
Command line tool to test queries
Browse files Browse the repository at this point in the history
  • Loading branch information
amol- committed Apr 18, 2024
1 parent df44c9b commit 01d22b9
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 60 deletions.
98 changes: 60 additions & 38 deletions src/substrait/sql/__main__.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,67 @@
import pathlib
import argparse

from substrait import proto
from .functions_catalog import FunctionsCatalog
from .extended_expression import parse_sql_extended_expression

catalog = FunctionsCatalog()
catalog.load_standard_extensions(
pathlib.Path(__file__).parent.parent.parent.parent / "third_party" / "substrait" / "extensions",
)

# TODO: Turn this into a command line tool to test more queries.
# We can probably have a quick way to declare schema using command line args.
# like first_name=String,surname=String,age=I32 etc...
schema = proto.NamedStruct(
names=["first_name", "surname", "age"],
struct=proto.Type.Struct(
types=[
proto.Type(
string=proto.Type.String(
nullability=proto.Type.Nullability.NULLABILITY_REQUIRED
)
),
proto.Type(
string=proto.Type.String(
nullability=proto.Type.Nullability.NULLABILITY_REQUIRED
)
),

def main():
"""Commandline tool to test the SQL to ExtendedExpression parser.
Run as python -m substrait.sql first_name=String,surname=String,age=I32 "SELECT surname, age + 1 as next_birthday, age + 2 WHERE age = 32"
"""
parser = argparse.ArgumentParser(
description="Convert a SQL SELECT statement to an ExtendedExpression"
)
parser.add_argument("schema", type=str, help="Schema of the input data")
parser.add_argument("sql", type=str, help="SQL SELECT statement")
args = parser.parse_args()

catalog = FunctionsCatalog()
catalog.load_standard_extensions(
pathlib.Path(__file__).parent.parent.parent.parent
/ "third_party"
/ "substrait"
/ "extensions",
)
schema = parse_schema(args.schema)
projection_expr, filter_expr = parse_sql_extended_expression(
catalog, schema, args.sql
)

print("---- SQL INPUT ----")
print(args.sql)
print("---- PROJECTION ----")
print(projection_expr)
print("---- FILTER ----")
print(filter_expr)


def parse_schema(schema_string):
"""Parse Schema from a comma separated string of fieldname=fieldtype pairs.
For example: "first_name=String,surname=String,age=I32"
"""
types = []
names = []

fields = schema_string.split(",")
for field in fields:
fieldname, fieldtype = field.split("=")
proto_type = getattr(proto.Type, fieldtype)
names.append(fieldname)
types.append(
proto.Type(
i32=proto.Type.I32(
nullability=proto.Type.Nullability.NULLABILITY_REQUIRED
)
),
]
),
)

sql = "SELECT surname, age + 1 as next_birthday WHERE age = 32"
projection_expr, filter_expr = parse_sql_extended_expression(catalog, schema, sql)
print("---- SQL INPUT ----")
print(sql)
print("---- PROJECTION ----")
print(projection_expr)
print("---- FILTER ----")
print(filter_expr)
**{
fieldtype.lower(): proto_type(
nullability=proto.Type.Nullability.NULLABILITY_REQUIRED
)
}
)
)
return proto.NamedStruct(names=names, struct=proto.Type.Struct(types=types))


if __name__ == "__main__":
main()
67 changes: 47 additions & 20 deletions src/substrait/sql/extended_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,13 @@ def parse_sql_extended_expression(catalog, schema, sql):
project_expressions = []
projection_invoked_functions = set()
for sqlexpr in select.expressions:
invoked_functions, output_name, expr = sqlglot_parser.expression_from_sqlglot(sqlexpr)
invoked_functions, output_name, expr = sqlglot_parser.expression_from_sqlglot(
sqlexpr
)
projection_invoked_functions.update(invoked_functions)
project_expressions.append(proto.ExpressionReference(expression=expr, output_names=[output_name]))
project_expressions.append(
proto.ExpressionReference(expression=expr, output_names=[output_name])
)
extension_uris, extensions = catalog.extensions_for_functions(
projection_invoked_functions
)
Expand Down Expand Up @@ -73,34 +77,52 @@ def expression_from_sqlglot(self, sqlglot_node):
def _parse_expression(self, expr, invoked_functions):
if isinstance(expr, sqlglot.expressions.Literal):
if expr.is_string:
return f"literal_{next(self._counter)}", proto.Type(string=proto.Type.String()), proto.Expression(
literal=proto.Expression.Literal(string=expr.text)
return (
f"literal_{next(self._counter)}",
proto.Type(string=proto.Type.String()),
proto.Expression(
literal=proto.Expression.Literal(string=expr.text)
),
)
elif expr.is_int:
return f"literal_{next(self._counter)}", proto.Type(i32=proto.Type.I32()), proto.Expression(
literal=proto.Expression.Literal(i32=int(expr.name))
return (
f"literal_{next(self._counter)}",
proto.Type(i32=proto.Type.I32()),
proto.Expression(
literal=proto.Expression.Literal(i32=int(expr.name))
),
)
elif sqlglot.helper.is_float(expr.name):
return f"literal_{next(self._counter)}", proto.Type(fp32=proto.Type.FP32()), proto.Expression(
literal=proto.Expression.Literal(float=float(expr.name))
return (
f"literal_{next(self._counter)}",
proto.Type(fp32=proto.Type.FP32()),
proto.Expression(
literal=proto.Expression.Literal(float=float(expr.name))
),
)
else:
raise ValueError(f"Unsupporter literal: {expr.text}")
elif isinstance(expr, sqlglot.expressions.Column):
column_name = expr.output_name
schema_field = list(self._schema.names).index(column_name)
schema_type = self._schema.struct.types[schema_field]
return column_name, schema_type, proto.Expression(
selection=proto.Expression.FieldReference(
direct_reference=proto.Expression.ReferenceSegment(
struct_field=proto.Expression.ReferenceSegment.StructField(
field=schema_field
return (
column_name,
schema_type,
proto.Expression(
selection=proto.Expression.FieldReference(
direct_reference=proto.Expression.ReferenceSegment(
struct_field=proto.Expression.ReferenceSegment.StructField(
field=schema_field
)
)
)
)
),
)
elif isinstance(expr, sqlglot.expressions.Alias):
_, aliased_type, aliased_expr = self._parse_expression(expr.this, invoked_functions)
_, aliased_type, aliased_expr = self._parse_expression(
expr.this, invoked_functions
)
return expr.output_name, aliased_type, aliased_expr
elif expr.key in SQL_BINARY_FUNCTIONS:
left_name, left_type, left = self._parse_expression(
Expand All @@ -110,18 +132,24 @@ def _parse_expression(self, expr, invoked_functions):
expr.right, invoked_functions
)
function_name = SQL_BINARY_FUNCTIONS[expr.key]
signature, result_type, function_expression = self._parse_function_invokation(
function_name, left_type, left, right_type, right
signature, result_type, function_expression = (
self._parse_function_invokation(
function_name, left_type, left, right_type, right
)
)
invoked_functions.add(signature)
result_name = f"{left_name}_{function_name}_{right_name}_{next(self._counter)}"
result_name = (
f"{left_name}_{function_name}_{right_name}_{next(self._counter)}"
)
return result_name, result_type, function_expression
else:
raise ValueError(
f"Unsupported expression in ExtendedExpression: '{expr.key}' -> {expr}"
)

def _parse_function_invokation(self, function_name, left_type, left, right_type, right):
def _parse_function_invokation(
self, function_name, left_type, left, right_type, right
):
signature = f"{function_name}:{left_type.WhichOneof('kind')}_{right_type.WhichOneof('kind')}"
try:
function_anchor = self._functions_catalog.function_anchor(signature)
Expand All @@ -143,4 +171,3 @@ def _parse_function_invokation(self, function_name, left_type, left, right_type,
)
),
)

5 changes: 3 additions & 2 deletions src/substrait/sql/functions_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class FunctionsCatalog:
"""Catalog of Substrait functions and extensions.
Loads extensions from YAML files and records the declared functions.
Given a set of functions it can generate the necessary extension URIs
and extensions to be included in an ExtendedExpression or Plan.
Expand Down Expand Up @@ -49,6 +49,7 @@ def load(self, dirpath, filename):
for function in functions:
function_name = function["name"]
for impl in function.get("impls", []):
# TODO: There seem to be some functions that have arguments without type. What to do?
argtypes = [t.get("value", "unknown") for t in impl.get("args", [])]
if not argtypes:
signature = function_name
Expand Down Expand Up @@ -118,4 +119,4 @@ def extensions_for_functions(self, functions):

uris_by_anchor = self.extension_uris_by_anchor
extension_uris = [uris_by_anchor[uri_anchor] for uri_anchor in uris_anchors]
return extension_uris, extensions
return extension_uris, extensions

0 comments on commit 01d22b9

Please sign in to comment.