From 29e2a10841a6c2a2820366ab0f3fdc356919a605 Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Thu, 17 Oct 2024 09:05:16 -0700 Subject: [PATCH] Implement `Mergeable` for `SqlBindParameterSet` (#1460) `SqlBindParameterSet` implements `combine()`, which is similar to the interface for `Mergeable` objects. This PR updates `SqlBindParameterSet` to be a subclass of `Mergeable` for consistency and the ability to use convenience methods for `Mergeable` objects in later PRs. --- .../specs/where_filter/where_filter_spec.py | 2 +- .../sql/sql_bind_parameters.py | 22 +++++++++++++------ metricflow/sql/render/expr_renderer.py | 20 ++++++++--------- metricflow/sql/render/sql_plan_renderer.py | 20 ++++++++--------- metricflow/sql/render/trino.py | 6 ++--- .../sql_clients/test_sql_client.py | 4 ++-- 6 files changed, 41 insertions(+), 33 deletions(-) diff --git a/metricflow-semantics/metricflow_semantics/specs/where_filter/where_filter_spec.py b/metricflow-semantics/metricflow_semantics/specs/where_filter/where_filter_spec.py index 7b36234c07..694c15ec8a 100644 --- a/metricflow-semantics/metricflow_semantics/specs/where_filter/where_filter_spec.py +++ b/metricflow-semantics/metricflow_semantics/specs/where_filter/where_filter_spec.py @@ -69,7 +69,7 @@ def merge(self, other: WhereFilterSpec) -> WhereFilterSpec: # noqa: D102 return WhereFilterSpec( where_sql=f"({self.where_sql}) AND ({other.where_sql})", - bind_parameters=self.bind_parameters.combine(other.bind_parameters), + bind_parameters=self.bind_parameters.merge(other.bind_parameters), linkable_spec_set=self.linkable_spec_set.merge(other.linkable_spec_set).dedupe(), linkable_element_unions=ordered_dedupe(self.linkable_element_unions, other.linkable_element_unions), ) diff --git a/metricflow-semantics/metricflow_semantics/sql/sql_bind_parameters.py b/metricflow-semantics/metricflow_semantics/sql/sql_bind_parameters.py index 346a8ca826..6b3d932dd7 100644 --- a/metricflow-semantics/metricflow_semantics/sql/sql_bind_parameters.py +++ b/metricflow-semantics/metricflow_semantics/sql/sql_bind_parameters.py @@ -6,8 +6,10 @@ from typing import Any, Mapping, Optional, Tuple from dbt_semantic_interfaces.dataclass_serialization import SerializableDataclass +from typing_extensions import override from metricflow_semantics.assert_one_arg import assert_exactly_one_arg_set +from metricflow_semantics.collection_helpers.merger import Mergeable from metricflow_semantics.sql.sql_column_type import SqlColumnType @@ -74,7 +76,7 @@ class SqlBindParameter(SerializableDataclass): # noqa: D101 @dataclass(frozen=True) -class SqlBindParameterSet(SerializableDataclass): +class SqlBindParameterSet(SerializableDataclass, Mergeable): """Helps to build execution parameters during SQL query rendering. These can be used as per https://docs.sqlalchemy.org/en/14/core/tutorial.html#using-textual-sql @@ -83,16 +85,22 @@ class SqlBindParameterSet(SerializableDataclass): # Using tuples for immutability as dicts are not. param_items: Tuple[SqlBindParameter, ...] = () - def combine(self, additional_params: SqlBindParameterSet) -> SqlBindParameterSet: - """Create a new set of bind parameters that includes parameters from this and additional_params.""" + @classmethod + @override + def empty_instance(cls) -> SqlBindParameterSet: + return SqlBindParameterSet() + + @override + def merge(self, other: SqlBindParameterSet) -> SqlBindParameterSet: + """Create a new set of bind parameters that includes parameters from this and other.""" if len(self.param_items) == 0: - return additional_params + return other - if len(additional_params.param_items) == 0: + if len(other.param_items) == 0: return self self_dict = {item.key: item.value for item in self.param_items} - other_dict = {item.key: item.value for item in additional_params.param_items} + other_dict = {item.key: item.value for item in other.param_items} for key, value in other_dict.items(): if key in self_dict and self_dict[key] != value: @@ -102,7 +110,7 @@ def combine(self, additional_params: SqlBindParameterSet) -> SqlBindParameterSet ) new_items = list(self.param_items) included_keys = set(item.key for item in new_items) - for item in additional_params.param_items: + for item in other.param_items: if item.key in included_keys: continue new_items.append(item) diff --git a/metricflow/sql/render/expr_renderer.py b/metricflow/sql/render/expr_renderer.py index c6cd709360..0fd9e81e86 100644 --- a/metricflow/sql/render/expr_renderer.py +++ b/metricflow/sql/render/expr_renderer.py @@ -143,10 +143,10 @@ def visit_comparison_expr(self, node: SqlComparisonExpression) -> SqlExpressionR combined_params = SqlBindParameterSet() left_expr_rendered = self.render_sql_expr(node.left_expr) - combined_params = combined_params.combine(left_expr_rendered.bind_parameter_set) + combined_params = combined_params.merge(left_expr_rendered.bind_parameter_set) right_expr_rendered = self.render_sql_expr(node.right_expr) - combined_params = combined_params.combine(right_expr_rendered.bind_parameter_set) + combined_params = combined_params.merge(right_expr_rendered.bind_parameter_set) # To avoid issues with operator precedence, use parenthesis to group the left / right expressions if they # contain operators. @@ -165,7 +165,7 @@ def visit_function_expr(self, node: SqlAggregateFunctionExpression) -> SqlExpres args_rendered = [self.render_sql_expr(x) for x in node.sql_function_args] combined_params = SqlBindParameterSet() for arg_rendered in args_rendered: - combined_params = combined_params.combine(arg_rendered.bind_parameter_set) + combined_params = combined_params.merge(arg_rendered.bind_parameter_set) distinct_prefix = "DISTINCT " if SqlFunction.is_distinct_aggregation(node.sql_function) else "" args_string = ", ".join([x.sql for x in args_rendered]) @@ -202,7 +202,7 @@ def visit_logical_expr(self, node: SqlLogicalExpression) -> SqlExpressionRenderR can_be_rendered_in_one_line = sum(len(x.expr.sql) for x in args_rendered) < 60 for arg_rendered in args_rendered: - combined_parameters.combine(arg_rendered.expr.bind_parameter_set) + combined_parameters.merge(arg_rendered.expr.bind_parameter_set) arg_sql = self._render_logical_arg( arg_rendered.expr, arg_rendered.requires_parenthesis, render_in_one_line=can_be_rendered_in_one_line ) @@ -329,8 +329,8 @@ def visit_ratio_computation_expr(self, node: SqlRatioComputationExpression) -> S denominator_sql = f"CAST(NULLIF({rendered_denominator.sql}, 0) AS {self.double_data_type})" bind_parameter_set = SqlBindParameterSet() - bind_parameter_set = bind_parameter_set.combine(rendered_numerator.bind_parameter_set) - bind_parameter_set = bind_parameter_set.combine(rendered_denominator.bind_parameter_set) + bind_parameter_set = bind_parameter_set.merge(rendered_numerator.bind_parameter_set) + bind_parameter_set = bind_parameter_set.merge(rendered_denominator.bind_parameter_set) return SqlExpressionRenderResult( sql=f"{numerator_sql} / {denominator_sql}", @@ -343,9 +343,9 @@ def visit_between_expr(self, node: SqlBetweenExpression) -> SqlExpressionRenderR rendered_end_expr = self.render_sql_expr(node.end_expr) bind_parameter_set = SqlBindParameterSet() - bind_parameter_set = bind_parameter_set.combine(rendered_column_arg.bind_parameter_set) - bind_parameter_set = bind_parameter_set.combine(rendered_start_expr.bind_parameter_set) - bind_parameter_set = bind_parameter_set.combine(rendered_end_expr.bind_parameter_set) + bind_parameter_set = bind_parameter_set.merge(rendered_column_arg.bind_parameter_set) + bind_parameter_set = bind_parameter_set.merge(rendered_start_expr.bind_parameter_set) + bind_parameter_set = bind_parameter_set.merge(rendered_end_expr.bind_parameter_set) return SqlExpressionRenderResult( sql=f"{rendered_column_arg.sql} BETWEEN {rendered_start_expr.sql} AND {rendered_end_expr.sql}", @@ -366,7 +366,7 @@ def visit_window_function_expr(self, node: SqlWindowFunctionExpression) -> SqlEx if order_by_args_rendered: args_rendered.extend(list(order_by_args_rendered.keys())) for arg_rendered in args_rendered: - combined_params = combined_params.combine(arg_rendered.bind_parameter_set) + combined_params = combined_params.merge(arg_rendered.bind_parameter_set) sql_function_args_string = ", ".join([x.sql for x in sql_function_args_rendered]) window_string_lines: List[str] = [] diff --git a/metricflow/sql/render/sql_plan_renderer.py b/metricflow/sql/render/sql_plan_renderer.py index 39523a7f8a..b493df346a 100644 --- a/metricflow/sql/render/sql_plan_renderer.py +++ b/metricflow/sql/render/sql_plan_renderer.py @@ -93,7 +93,7 @@ def _render_select_columns_section( for select_column in select_columns: expr_rendered = self.EXPR_RENDERER.render_sql_expr(select_column.expr) # Merge all execution parameters together. Similar pattern follows below. - params = params.combine(expr_rendered.bind_parameter_set) + params = params.merge(expr_rendered.bind_parameter_set) column_select_str = f"{expr_rendered.sql} AS {select_column.column_alias}" @@ -165,13 +165,13 @@ def _render_joins_section(self, join_descriptions: Sequence[SqlJoinDescription]) for join_description in join_descriptions: # Render the source for the join right_source_rendered = self._render_node(join_description.right_source) - params = params.combine(right_source_rendered.bind_parameter_set) + params = params.merge(right_source_rendered.bind_parameter_set) # Render the on condition for the join on_condition_rendered: Optional[SqlExpressionRenderResult] = None if join_description.on_condition: on_condition_rendered = self.EXPR_RENDERER.render_sql_expr(join_description.on_condition) - params = params.combine(on_condition_rendered.bind_parameter_set) + params = params.merge(on_condition_rendered.bind_parameter_set) if join_description.right_source.is_table: join_section_lines.append(join_description.join_type.value) @@ -210,7 +210,7 @@ def _render_group_by_section(self, group_by_columns: Sequence[SqlSelectColumn]) first = True for group_by_column in group_by_columns: group_by_expr_rendered = self.EXPR_RENDERER.render_group_by_expr(group_by_column) - params = params.combine(group_by_expr_rendered.bind_parameter_set) + params = params.merge(group_by_expr_rendered.bind_parameter_set) if first: first = False group_by_section_lines.append("GROUP BY") @@ -235,25 +235,25 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanRe select_section, select_params = self._render_select_columns_section( node.select_columns, len(node.parent_nodes), node.distinct ) - combined_params = combined_params.combine(select_params) + combined_params = combined_params.merge(select_params) # Render "FROM" section from_section, from_params = self._render_from_section(node.from_source, node.from_source_alias) - combined_params = combined_params.combine(from_params) + combined_params = combined_params.merge(from_params) # Render "JOIN" section join_section, join_params = self._render_joins_section(node.join_descs) - combined_params = combined_params.combine(join_params) + combined_params = combined_params.merge(join_params) # Render "GROUP BY" section group_by_section, group_by_params = self._render_group_by_section(node.group_bys) - combined_params = combined_params.combine(group_by_params) + combined_params = combined_params.merge(group_by_params) # Render "WHERE" section where_section = None if node.where: where_render_result = self.EXPR_RENDERER.render_sql_expr(node.where) - combined_params = combined_params.combine(where_render_result.bind_parameter_set) + combined_params = combined_params.merge(where_render_result.bind_parameter_set) where_section = f"WHERE {where_render_result.sql}" # Render "ORDER BY" section @@ -263,7 +263,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanRe for order_by in node.order_bys: order_by_render_result = self.EXPR_RENDERER.render_sql_expr(order_by.expr) order_by_items.append(order_by_render_result.sql + (" DESC" if order_by.desc else "")) - combined_params = combined_params.combine(order_by_render_result.bind_parameter_set) + combined_params = combined_params.merge(order_by_render_result.bind_parameter_set) order_by_section = "ORDER BY " + ", ".join(order_by_items) diff --git a/metricflow/sql/render/trino.py b/metricflow/sql/render/trino.py index 4ecc72282b..23bd65ab66 100644 --- a/metricflow/sql/render/trino.py +++ b/metricflow/sql/render/trino.py @@ -91,9 +91,9 @@ def visit_between_expr(self, node: SqlBetweenExpression) -> SqlExpressionRenderR rendered_end_expr = self.render_sql_expr(node.end_expr) bind_parameter_set = SqlBindParameterSet() - bind_parameter_set = bind_parameter_set.combine(rendered_column_arg.bind_parameter_set) - bind_parameter_set = bind_parameter_set.combine(rendered_start_expr.bind_parameter_set) - bind_parameter_set = bind_parameter_set.combine(rendered_end_expr.bind_parameter_set) + bind_parameter_set = bind_parameter_set.merge(rendered_column_arg.bind_parameter_set) + bind_parameter_set = bind_parameter_set.merge(rendered_start_expr.bind_parameter_set) + bind_parameter_set = bind_parameter_set.merge(rendered_end_expr.bind_parameter_set) # Handle timestamp literals differently. if parse(rendered_start_expr.sql): diff --git a/tests_metricflow/sql_clients/test_sql_client.py b/tests_metricflow/sql_clients/test_sql_client.py index 068e2a111c..dccc9ea4d3 100644 --- a/tests_metricflow/sql_clients/test_sql_client.py +++ b/tests_metricflow/sql_clients/test_sql_client.py @@ -125,7 +125,7 @@ def test_update_params_with_same_item() -> None: # noqa: D103 bind_params0 = SqlBindParameterSet.create_from_dict({"key": "value"}) bind_params1 = SqlBindParameterSet.create_from_dict({"key": "value"}) - bind_params0.combine(bind_params1) + bind_params0.merge(bind_params1) def test_update_params_with_same_key_different_values() -> None: # noqa: D103 @@ -133,4 +133,4 @@ def test_update_params_with_same_key_different_values() -> None: # noqa: D103 bind_params1 = SqlBindParameterSet.create_from_dict(({"key": "value1"})) with pytest.raises(RuntimeError): - bind_params0.combine(bind_params1) + bind_params0.merge(bind_params1)