Skip to content

Commit

Permalink
Implement Mergeable for SqlBindParameterSet (#1460)
Browse files Browse the repository at this point in the history
`SqlBindParameterSet` implements `combine()`, which is similar to the
interface for `Mergeable` objects. This PR updates `SqlBindParameterSet`
to be a subclass of `Mergeable` for consistency and the ability to use
convenience methods for `Mergeable` objects in later PRs.
  • Loading branch information
plypaul authored Oct 17, 2024
1 parent d56c603 commit 29e2a10
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def merge(self, other: WhereFilterSpec) -> WhereFilterSpec: # noqa: D102

return WhereFilterSpec(
where_sql=f"({self.where_sql}) AND ({other.where_sql})",
bind_parameters=self.bind_parameters.combine(other.bind_parameters),
bind_parameters=self.bind_parameters.merge(other.bind_parameters),
linkable_spec_set=self.linkable_spec_set.merge(other.linkable_spec_set).dedupe(),
linkable_element_unions=ordered_dedupe(self.linkable_element_unions, other.linkable_element_unions),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from typing import Any, Mapping, Optional, Tuple

from dbt_semantic_interfaces.dataclass_serialization import SerializableDataclass
from typing_extensions import override

from metricflow_semantics.assert_one_arg import assert_exactly_one_arg_set
from metricflow_semantics.collection_helpers.merger import Mergeable
from metricflow_semantics.sql.sql_column_type import SqlColumnType


Expand Down Expand Up @@ -74,7 +76,7 @@ class SqlBindParameter(SerializableDataclass): # noqa: D101


@dataclass(frozen=True)
class SqlBindParameterSet(SerializableDataclass):
class SqlBindParameterSet(SerializableDataclass, Mergeable):
"""Helps to build execution parameters during SQL query rendering.
These can be used as per https://docs.sqlalchemy.org/en/14/core/tutorial.html#using-textual-sql
Expand All @@ -83,16 +85,22 @@ class SqlBindParameterSet(SerializableDataclass):
# Using tuples for immutability as dicts are not.
param_items: Tuple[SqlBindParameter, ...] = ()

def combine(self, additional_params: SqlBindParameterSet) -> SqlBindParameterSet:
"""Create a new set of bind parameters that includes parameters from this and additional_params."""
@classmethod
@override
def empty_instance(cls) -> SqlBindParameterSet:
return SqlBindParameterSet()

@override
def merge(self, other: SqlBindParameterSet) -> SqlBindParameterSet:
"""Create a new set of bind parameters that includes parameters from this and other."""
if len(self.param_items) == 0:
return additional_params
return other

if len(additional_params.param_items) == 0:
if len(other.param_items) == 0:
return self

self_dict = {item.key: item.value for item in self.param_items}
other_dict = {item.key: item.value for item in additional_params.param_items}
other_dict = {item.key: item.value for item in other.param_items}

for key, value in other_dict.items():
if key in self_dict and self_dict[key] != value:
Expand All @@ -102,7 +110,7 @@ def combine(self, additional_params: SqlBindParameterSet) -> SqlBindParameterSet
)
new_items = list(self.param_items)
included_keys = set(item.key for item in new_items)
for item in additional_params.param_items:
for item in other.param_items:
if item.key in included_keys:
continue
new_items.append(item)
Expand Down
20 changes: 10 additions & 10 deletions metricflow/sql/render/expr_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,10 @@ def visit_comparison_expr(self, node: SqlComparisonExpression) -> SqlExpressionR
combined_params = SqlBindParameterSet()

left_expr_rendered = self.render_sql_expr(node.left_expr)
combined_params = combined_params.combine(left_expr_rendered.bind_parameter_set)
combined_params = combined_params.merge(left_expr_rendered.bind_parameter_set)

right_expr_rendered = self.render_sql_expr(node.right_expr)
combined_params = combined_params.combine(right_expr_rendered.bind_parameter_set)
combined_params = combined_params.merge(right_expr_rendered.bind_parameter_set)

# To avoid issues with operator precedence, use parenthesis to group the left / right expressions if they
# contain operators.
Expand All @@ -165,7 +165,7 @@ def visit_function_expr(self, node: SqlAggregateFunctionExpression) -> SqlExpres
args_rendered = [self.render_sql_expr(x) for x in node.sql_function_args]
combined_params = SqlBindParameterSet()
for arg_rendered in args_rendered:
combined_params = combined_params.combine(arg_rendered.bind_parameter_set)
combined_params = combined_params.merge(arg_rendered.bind_parameter_set)

distinct_prefix = "DISTINCT " if SqlFunction.is_distinct_aggregation(node.sql_function) else ""
args_string = ", ".join([x.sql for x in args_rendered])
Expand Down Expand Up @@ -202,7 +202,7 @@ def visit_logical_expr(self, node: SqlLogicalExpression) -> SqlExpressionRenderR
can_be_rendered_in_one_line = sum(len(x.expr.sql) for x in args_rendered) < 60

for arg_rendered in args_rendered:
combined_parameters.combine(arg_rendered.expr.bind_parameter_set)
combined_parameters.merge(arg_rendered.expr.bind_parameter_set)
arg_sql = self._render_logical_arg(
arg_rendered.expr, arg_rendered.requires_parenthesis, render_in_one_line=can_be_rendered_in_one_line
)
Expand Down Expand Up @@ -329,8 +329,8 @@ def visit_ratio_computation_expr(self, node: SqlRatioComputationExpression) -> S
denominator_sql = f"CAST(NULLIF({rendered_denominator.sql}, 0) AS {self.double_data_type})"

bind_parameter_set = SqlBindParameterSet()
bind_parameter_set = bind_parameter_set.combine(rendered_numerator.bind_parameter_set)
bind_parameter_set = bind_parameter_set.combine(rendered_denominator.bind_parameter_set)
bind_parameter_set = bind_parameter_set.merge(rendered_numerator.bind_parameter_set)
bind_parameter_set = bind_parameter_set.merge(rendered_denominator.bind_parameter_set)

return SqlExpressionRenderResult(
sql=f"{numerator_sql} / {denominator_sql}",
Expand All @@ -343,9 +343,9 @@ def visit_between_expr(self, node: SqlBetweenExpression) -> SqlExpressionRenderR
rendered_end_expr = self.render_sql_expr(node.end_expr)

bind_parameter_set = SqlBindParameterSet()
bind_parameter_set = bind_parameter_set.combine(rendered_column_arg.bind_parameter_set)
bind_parameter_set = bind_parameter_set.combine(rendered_start_expr.bind_parameter_set)
bind_parameter_set = bind_parameter_set.combine(rendered_end_expr.bind_parameter_set)
bind_parameter_set = bind_parameter_set.merge(rendered_column_arg.bind_parameter_set)
bind_parameter_set = bind_parameter_set.merge(rendered_start_expr.bind_parameter_set)
bind_parameter_set = bind_parameter_set.merge(rendered_end_expr.bind_parameter_set)

return SqlExpressionRenderResult(
sql=f"{rendered_column_arg.sql} BETWEEN {rendered_start_expr.sql} AND {rendered_end_expr.sql}",
Expand All @@ -366,7 +366,7 @@ def visit_window_function_expr(self, node: SqlWindowFunctionExpression) -> SqlEx
if order_by_args_rendered:
args_rendered.extend(list(order_by_args_rendered.keys()))
for arg_rendered in args_rendered:
combined_params = combined_params.combine(arg_rendered.bind_parameter_set)
combined_params = combined_params.merge(arg_rendered.bind_parameter_set)

sql_function_args_string = ", ".join([x.sql for x in sql_function_args_rendered])
window_string_lines: List[str] = []
Expand Down
20 changes: 10 additions & 10 deletions metricflow/sql/render/sql_plan_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _render_select_columns_section(
for select_column in select_columns:
expr_rendered = self.EXPR_RENDERER.render_sql_expr(select_column.expr)
# Merge all execution parameters together. Similar pattern follows below.
params = params.combine(expr_rendered.bind_parameter_set)
params = params.merge(expr_rendered.bind_parameter_set)

column_select_str = f"{expr_rendered.sql} AS {select_column.column_alias}"

Expand Down Expand Up @@ -165,13 +165,13 @@ def _render_joins_section(self, join_descriptions: Sequence[SqlJoinDescription])
for join_description in join_descriptions:
# Render the source for the join
right_source_rendered = self._render_node(join_description.right_source)
params = params.combine(right_source_rendered.bind_parameter_set)
params = params.merge(right_source_rendered.bind_parameter_set)

# Render the on condition for the join
on_condition_rendered: Optional[SqlExpressionRenderResult] = None
if join_description.on_condition:
on_condition_rendered = self.EXPR_RENDERER.render_sql_expr(join_description.on_condition)
params = params.combine(on_condition_rendered.bind_parameter_set)
params = params.merge(on_condition_rendered.bind_parameter_set)

if join_description.right_source.is_table:
join_section_lines.append(join_description.join_type.value)
Expand Down Expand Up @@ -210,7 +210,7 @@ def _render_group_by_section(self, group_by_columns: Sequence[SqlSelectColumn])
first = True
for group_by_column in group_by_columns:
group_by_expr_rendered = self.EXPR_RENDERER.render_group_by_expr(group_by_column)
params = params.combine(group_by_expr_rendered.bind_parameter_set)
params = params.merge(group_by_expr_rendered.bind_parameter_set)
if first:
first = False
group_by_section_lines.append("GROUP BY")
Expand All @@ -235,25 +235,25 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanRe
select_section, select_params = self._render_select_columns_section(
node.select_columns, len(node.parent_nodes), node.distinct
)
combined_params = combined_params.combine(select_params)
combined_params = combined_params.merge(select_params)

# Render "FROM" section
from_section, from_params = self._render_from_section(node.from_source, node.from_source_alias)
combined_params = combined_params.combine(from_params)
combined_params = combined_params.merge(from_params)

# Render "JOIN" section
join_section, join_params = self._render_joins_section(node.join_descs)
combined_params = combined_params.combine(join_params)
combined_params = combined_params.merge(join_params)

# Render "GROUP BY" section
group_by_section, group_by_params = self._render_group_by_section(node.group_bys)
combined_params = combined_params.combine(group_by_params)
combined_params = combined_params.merge(group_by_params)

# Render "WHERE" section
where_section = None
if node.where:
where_render_result = self.EXPR_RENDERER.render_sql_expr(node.where)
combined_params = combined_params.combine(where_render_result.bind_parameter_set)
combined_params = combined_params.merge(where_render_result.bind_parameter_set)
where_section = f"WHERE {where_render_result.sql}"

# Render "ORDER BY" section
Expand All @@ -263,7 +263,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanRe
for order_by in node.order_bys:
order_by_render_result = self.EXPR_RENDERER.render_sql_expr(order_by.expr)
order_by_items.append(order_by_render_result.sql + (" DESC" if order_by.desc else ""))
combined_params = combined_params.combine(order_by_render_result.bind_parameter_set)
combined_params = combined_params.merge(order_by_render_result.bind_parameter_set)

order_by_section = "ORDER BY " + ", ".join(order_by_items)

Expand Down
6 changes: 3 additions & 3 deletions metricflow/sql/render/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ def visit_between_expr(self, node: SqlBetweenExpression) -> SqlExpressionRenderR
rendered_end_expr = self.render_sql_expr(node.end_expr)

bind_parameter_set = SqlBindParameterSet()
bind_parameter_set = bind_parameter_set.combine(rendered_column_arg.bind_parameter_set)
bind_parameter_set = bind_parameter_set.combine(rendered_start_expr.bind_parameter_set)
bind_parameter_set = bind_parameter_set.combine(rendered_end_expr.bind_parameter_set)
bind_parameter_set = bind_parameter_set.merge(rendered_column_arg.bind_parameter_set)
bind_parameter_set = bind_parameter_set.merge(rendered_start_expr.bind_parameter_set)
bind_parameter_set = bind_parameter_set.merge(rendered_end_expr.bind_parameter_set)

# Handle timestamp literals differently.
if parse(rendered_start_expr.sql):
Expand Down
4 changes: 2 additions & 2 deletions tests_metricflow/sql_clients/test_sql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,12 @@ def test_update_params_with_same_item() -> None: # noqa: D103
bind_params0 = SqlBindParameterSet.create_from_dict({"key": "value"})
bind_params1 = SqlBindParameterSet.create_from_dict({"key": "value"})

bind_params0.combine(bind_params1)
bind_params0.merge(bind_params1)


def test_update_params_with_same_key_different_values() -> None: # noqa: D103
bind_params0 = SqlBindParameterSet.create_from_dict(({"key": "value0"}))
bind_params1 = SqlBindParameterSet.create_from_dict(({"key": "value1"}))

with pytest.raises(RuntimeError):
bind_params0.combine(bind_params1)
bind_params0.merge(bind_params1)

0 comments on commit 29e2a10

Please sign in to comment.