Skip to content

Commit

Permalink
Refactor passing around info about parsed expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
amol- committed Apr 18, 2024
1 parent 1d1cbf8 commit a214297
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 41 deletions.
113 changes: 73 additions & 40 deletions src/substrait/sql/extended_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@


def parse_sql_extended_expression(catalog, schema, sql):
"""Parse a SQL SELECT statement into an ExtendedExpression.
Only supports SELECT statements with projections and WHERE clauses.
"""
select = sqlglot.parse_one(sql)
if not isinstance(select, sqlglot.expressions.Select):
raise ValueError("a SELECT statement was expected")
Expand All @@ -41,12 +45,13 @@ def parse_sql_extended_expression(catalog, schema, sql):
project_expressions = []
projection_invoked_functions = set()
for sqlexpr in select.expressions:
invoked_functions, output_name, expr = sqlglot_parser.expression_from_sqlglot(
sqlexpr
)
projection_invoked_functions.update(invoked_functions)
parsed_expr = sqlglot_parser.expression_from_sqlglot(sqlexpr)
projection_invoked_functions.update(parsed_expr.invoked_functions)
project_expressions.append(
proto.ExpressionReference(expression=expr, output_names=[output_name])
proto.ExpressionReference(
expression=parsed_expr.expression,
output_names=[parsed_expr.output_name],
)
)
extension_uris, extensions = catalog.extensions_for_functions(
projection_invoked_functions
Expand All @@ -59,17 +64,19 @@ def parse_sql_extended_expression(catalog, schema, sql):
)

# Handle WHERE clause in the SELECT statement.
invoked_functions_filter, _, filter_expr = sqlglot_parser.expression_from_sqlglot(
filter_parsed_expr = sqlglot_parser.expression_from_sqlglot(
select.find(sqlglot.expressions.Where).this
)
extension_uris, extensions = catalog.extensions_for_functions(
invoked_functions_filter
filter_parsed_expr.invoked_functions
)
filter_extended_expr = proto.ExtendedExpression(
extension_uris=extension_uris,
extensions=extensions,
base_schema=schema,
referred_expr=[proto.ExpressionReference(expression=filter_expr)],
referred_expr=[
proto.ExpressionReference(expression=filter_parsed_expr.expression)
],
)

return projection_extended_expr, filter_extended_expr
Expand All @@ -82,32 +89,29 @@ def __init__(self, functions_catalog, schema):
self._counter = itertools.count()

def expression_from_sqlglot(self, sqlglot_node):
invoked_functions = set()
output_name, _, substrait_expr = self._parse_expression(
sqlglot_node, invoked_functions
)
return invoked_functions, output_name, substrait_expr
"""Parse a SQLGlot expression into a Substrait Expression."""
return self._parse_expression(sqlglot_node)

def _parse_expression(self, expr, invoked_functions):
def _parse_expression(self, expr):
if isinstance(expr, sqlglot.expressions.Literal):
if expr.is_string:
return (
return ParsedSubstraitExpression(
f"literal_{next(self._counter)}",
proto.Type(string=proto.Type.String()),
proto.Expression(
literal=proto.Expression.Literal(string=expr.text)
),
)
elif expr.is_int:
return (
return ParsedSubstraitExpression(
f"literal_{next(self._counter)}",
proto.Type(i32=proto.Type.I32()),
proto.Expression(
literal=proto.Expression.Literal(i32=int(expr.name))
),
)
elif sqlglot.helper.is_float(expr.name):
return (
return ParsedSubstraitExpression(
f"literal_{next(self._counter)}",
proto.Type(fp32=proto.Type.FP32()),
proto.Expression(
Expand All @@ -120,7 +124,7 @@ def _parse_expression(self, expr, invoked_functions):
column_name = expr.output_name
schema_field = list(self._schema.names).index(column_name)
schema_type = self._schema.struct.types[schema_field]
return (
return ParsedSubstraitExpression(
column_name,
schema_type,
proto.Expression(
Expand All @@ -134,39 +138,47 @@ def _parse_expression(self, expr, invoked_functions):
),
)
elif isinstance(expr, sqlglot.expressions.Alias):
_, aliased_type, aliased_expr = self._parse_expression(
expr.this, invoked_functions
)
return expr.output_name, aliased_type, aliased_expr
parsed_expression = self._parse_expression(expr.this)
return parsed_expression.duplicate(output_name=expr.output_name)
elif expr.key in SQL_UNARY_FUNCTIONS:
argument_name, argument_type, argument = self._parse_expression(
expr.this, invoked_functions
)
argument_parsed_expr = self._parse_expression(expr.this)
function_name = SQL_UNARY_FUNCTIONS[expr.key]
signature, result_type, function_expression = (
self._parse_function_invokation(function_name, argument_type, argument)
)
invoked_functions.add(signature)
result_name = f"{function_name}_{argument_name}_{next(self._counter)}"
return result_name, result_type, function_expression
elif expr.key in SQL_BINARY_FUNCTIONS:
left_name, left_type, left = self._parse_expression(
expr.left, invoked_functions
self._parse_function_invokation(
function_name,
argument_parsed_expr.type,
argument_parsed_expr.expression,
)
)
right_name, right_type, right = self._parse_expression(
expr.right, invoked_functions
result_name = f"{function_name}_{argument_parsed_expr.output_name}_{next(self._counter)}"
return ParsedSubstraitExpression(
result_name,
result_type,
function_expression,
argument_parsed_expr.invoked_functions | {signature},
)
elif expr.key in SQL_BINARY_FUNCTIONS:
left_parsed_expr = self._parse_expression(expr.left)
right_parsed_expr = self._parse_expression(expr.right)
function_name = SQL_BINARY_FUNCTIONS[expr.key]
signature, result_type, function_expression = (
self._parse_function_invokation(
function_name, left_type, left, right_type, right
function_name,
left_parsed_expr.type,
left_parsed_expr.expression,
right_parsed_expr.type,
right_parsed_expr.expression,
)
)
invoked_functions.add(signature)
result_name = (
f"{left_name}_{function_name}_{right_name}_{next(self._counter)}"
result_name = f"{function_name}_{left_parsed_expr.output_name}_{right_parsed_expr.output_name}_{next(self._counter)}"
return ParsedSubstraitExpression(
result_name,
result_type,
function_expression,
left_parsed_expr.invoked_functions
| right_parsed_expr.invoked_functions
| {signature},
)
return result_name, result_type, function_expression
else:
raise ValueError(
f"Unsupported expression in ExtendedExpression: '{expr.key}' -> {expr}"
Expand Down Expand Up @@ -215,3 +227,24 @@ def _parse_function_invokation(
)
),
)


class ParsedSubstraitExpression:
def __init__(self, output_name, type, expression, invoked_functions=None):
self.expression = expression
self.output_name = output_name
self.type = type

if invoked_functions is None:
invoked_functions = set()
self.invoked_functions = invoked_functions

def duplicate(
self, output_name=None, type=None, expression=None, invoked_functions=None
):
return ParsedSubstraitExpression(
output_name or self.output_name,
type or self.type,
expression or self.expression,
invoked_functions or self.invoked_functions,
)
5 changes: 4 additions & 1 deletion src/substrait/sql/functions_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ def load(self, dirpath, filename):
for impl in function.get("impls", []):
# TODO: There seem to be some functions that have arguments without type. What to do?
# TODO: improve support complext type like LIST?<any>
argtypes = [t.get("value", "unknown").strip("?") for t in impl.get("args", [])]
argtypes = [
t.get("value", "unknown").strip("?")
for t in impl.get("args", [])
]
if not argtypes:
signature = function_name
else:
Expand Down

0 comments on commit a214297

Please sign in to comment.