Skip to content

Commit

Permalink
Dynamic dispatch of parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
amol- committed Apr 18, 2024
1 parent ffbaf59 commit de51c99
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 72 deletions.
156 changes: 84 additions & 72 deletions src/substrait/sql/extended_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import sqlglot

from substrait import proto
from .utils import DispatchRegistry


SQL_UNARY_FUNCTIONS = {"not": "not"}
SQL_BINARY_FUNCTIONS = {
Expand Down Expand Up @@ -83,6 +85,8 @@ def parse_sql_extended_expression(catalog, schema, sql):


class SQLGlotParser:
DISPATCH_REGISTRY = DispatchRegistry()

def __init__(self, functions_catalog, schema):
self._functions_catalog = functions_catalog
self._schema = schema
Expand All @@ -99,88 +103,96 @@ def _parse_expression(self, expr):
invoked in a recursive manner to parse the whole
expression tree.
"""
if isinstance(expr, sqlglot.expressions.Literal):
if expr.is_string:
return ParsedSubstraitExpression(
f"literal${next(self._counter)}",
proto.Type(string=proto.Type.String()),
proto.Expression(
literal=proto.Expression.Literal(string=expr.text)
),
)
elif expr.is_int:
return ParsedSubstraitExpression(
f"literal${next(self._counter)}",
proto.Type(i32=proto.Type.I32()),
proto.Expression(
literal=proto.Expression.Literal(i32=int(expr.name))
),
)
elif sqlglot.helper.is_float(expr.name):
return ParsedSubstraitExpression(
f"literal${next(self._counter)}",
proto.Type(fp32=proto.Type.FP32()),
proto.Expression(
literal=proto.Expression.Literal(float=float(expr.name))
),
)
else:
raise ValueError(f"Unsupporter literal: {expr.text}")
elif isinstance(expr, sqlglot.expressions.Column):
column_name = expr.output_name
schema_field = list(self._schema.names).index(column_name)
schema_type = self._schema.struct.types[schema_field]
expr_class = expr.__class__
return self.DISPATCH_REGISTRY[expr_class](self, expr)

@DISPATCH_REGISTRY.register(sqlglot.expressions.Literal)
def _parse_Literal(self, expr):
if expr.is_string:
return ParsedSubstraitExpression(
column_name,
schema_type,
f"literal${next(self._counter)}",
proto.Type(string=proto.Type.String()),
proto.Expression(
selection=proto.Expression.FieldReference(
direct_reference=proto.Expression.ReferenceSegment(
struct_field=proto.Expression.ReferenceSegment.StructField(
field=schema_field
)
)
)
literal=proto.Expression.Literal(string=expr.text)
),
)
elif isinstance(expr, sqlglot.expressions.Alias):
parsed_expression = self._parse_expression(expr.this)
return parsed_expression.duplicate(output_name=expr.output_name)
elif expr.key in SQL_UNARY_FUNCTIONS:
argument_parsed_expr = self._parse_expression(expr.this)
function_name = SQL_UNARY_FUNCTIONS[expr.key]
signature, result_type, function_expression = (
self._parse_function_invokation(function_name, argument_parsed_expr)
)
result_name = f"{function_name}_{argument_parsed_expr.output_name}_{next(self._counter)}"
elif expr.is_int:
return ParsedSubstraitExpression(
result_name,
result_type,
function_expression,
argument_parsed_expr.invoked_functions | {signature},
)
elif expr.key in SQL_BINARY_FUNCTIONS:
left_parsed_expr = self._parse_expression(expr.left)
right_parsed_expr = self._parse_expression(expr.right)
function_name = SQL_BINARY_FUNCTIONS[expr.key]
signature, result_type, function_expression = (
self._parse_function_invokation(
function_name, left_parsed_expr, right_parsed_expr
)
f"literal${next(self._counter)}",
proto.Type(i32=proto.Type.I32()),
proto.Expression(
literal=proto.Expression.Literal(i32=int(expr.name))
),
)
result_name = f"{function_name}_{left_parsed_expr.output_name}_{right_parsed_expr.output_name}_{next(self._counter)}"
elif sqlglot.helper.is_float(expr.name):
return ParsedSubstraitExpression(
result_name,
result_type,
function_expression,
left_parsed_expr.invoked_functions
| right_parsed_expr.invoked_functions
| {signature},
f"literal${next(self._counter)}",
proto.Type(fp32=proto.Type.FP32()),
proto.Expression(
literal=proto.Expression.Literal(float=float(expr.name))
),
)
else:
raise ValueError(
f"Unsupported expression in ExtendedExpression: '{expr.key}' -> {expr}"
raise ValueError(f"Unsupporter literal: {expr.text}")

@DISPATCH_REGISTRY.register(sqlglot.expressions.Column)
def _parse_Column(self, expr):
column_name = expr.output_name
schema_field = list(self._schema.names).index(column_name)
schema_type = self._schema.struct.types[schema_field]
return ParsedSubstraitExpression(
column_name,
schema_type,
proto.Expression(
selection=proto.Expression.FieldReference(
direct_reference=proto.Expression.ReferenceSegment(
struct_field=proto.Expression.ReferenceSegment.StructField(
field=schema_field
)
)
)
),
)

@DISPATCH_REGISTRY.register(sqlglot.expressions.Alias)
def _parse_Alias(self, expr):
parsed_expression = self._parse_expression(expr.this)
return parsed_expression.duplicate(output_name=expr.output_name)

@DISPATCH_REGISTRY.register(sqlglot.expressions.Binary)
def _parser_Binary(self, expr):
left_parsed_expr = self._parse_expression(expr.left)
right_parsed_expr = self._parse_expression(expr.right)
function_name = SQL_BINARY_FUNCTIONS[expr.key]
signature, result_type, function_expression = (
self._parse_function_invokation(
function_name, left_parsed_expr, right_parsed_expr
)
)
result_name = f"{function_name}_{left_parsed_expr.output_name}_{right_parsed_expr.output_name}_{next(self._counter)}"
return ParsedSubstraitExpression(
result_name,
result_type,
function_expression,
left_parsed_expr.invoked_functions
| right_parsed_expr.invoked_functions
| {signature},
)

@DISPATCH_REGISTRY.register(sqlglot.expressions.Unary)
def _parse_Unary(self, expr):
argument_parsed_expr = self._parse_expression(expr.this)
function_name = SQL_UNARY_FUNCTIONS[expr.key]
signature, result_type, function_expression = (
self._parse_function_invokation(function_name, argument_parsed_expr)
)
result_name = f"{function_name}_{argument_parsed_expr.output_name}_{next(self._counter)}"
return ParsedSubstraitExpression(
result_name,
result_type,
function_expression,
argument_parsed_expr.invoked_functions | {signature},
)

def _parse_function_invokation(
self, function_name, argument_parsed_expr, *additional_arguments
Expand Down
16 changes: 16 additions & 0 deletions src/substrait/sql/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
class DispatchRegistry:
def __init__(self):
self._registry = {}

def register(self, cls):
def decorator(func):
self._registry[cls] = func
return func
return decorator

def __getitem__(self, cls):
for dispatch_cls, func in self._registry.items():
if issubclass(cls, dispatch_cls):
return func
else:
raise ValueError(f"Unsupported SQL Node type: {cls}")

0 comments on commit de51c99

Please sign in to comment.