Skip to content

Commit

Permalink
Add tests for CTE generation.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Nov 13, 2024
1 parent e81879e commit a7d043d
Showing 1 changed file with 134 additions and 0 deletions.
134 changes: 134 additions & 0 deletions tests_metricflow/plan_conversion/dataflow_to_sql/test_cte_sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from __future__ import annotations

import logging
from typing import FrozenSet, Mapping

from _pytest.fixtures import FixtureRequest
from metricflow_semantics.mf_logging.formatting import indent
from metricflow_semantics.query.query_parser import MetricFlowQueryParser
from metricflow_semantics.specs.column_assoc import ColumnAssociationResolver
from metricflow_semantics.specs.measure_spec import MeasureSpec
from metricflow_semantics.specs.spec_set import InstanceSpecSet
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration
from metricflow_semantics.test_helpers.snapshot_helpers import (
assert_str_snapshot_equal,
make_schema_replacement_function,
)

from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder
from metricflow.dataflow.dataflow_plan import (
DataflowPlanNode,
)
from metricflow.dataflow.dataflow_plan_analyzer import DataflowPlanAnalyzer
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter
from metricflow.sql.optimizer.optimization_levels import SqlGenerationOptionSet, SqlQueryOptimizationLevel
from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer
from tests_metricflow.fixtures.manifest_fixtures import MetricFlowEngineTestFixture, SemanticManifestSetup

logger = logging.getLogger(__name__)


def convert_and_check(
request: FixtureRequest,
mf_test_configuration: MetricFlowTestConfiguration,
dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
node: DataflowPlanNode,
nodes_to_convert_to_cte: FrozenSet[DataflowPlanNode],
) -> None:
"""Convert the dataflow plan to SQL and compare with snapshots."""
# Generate without CTEs
optimizers = SqlGenerationOptionSet.options_for_level(
SqlQueryOptimizationLevel.O5, use_column_alias_in_group_by=False
).optimizers
conversion_result = dataflow_to_sql_converter.convert_using_specifics(
dataflow_plan_node=node,
sql_query_plan_id=None,
optimizers=optimizers,
nodes_to_convert_to_cte=frozenset(),
)
sql_plan_without_cte = conversion_result.sql_plan

# Generate with CTEs
conversion_result = dataflow_to_sql_converter.convert_using_specifics(
dataflow_plan_node=node,
sql_query_plan_id=None,
optimizers=optimizers,
nodes_to_convert_to_cte=nodes_to_convert_to_cte,
)
sql_plan_with_cte = conversion_result.sql_plan
renderer = DefaultSqlQueryPlanRenderer()

lines = [
"sql_without_cte:",
indent(renderer.render_sql_query_plan(sql_plan_without_cte).sql),
"\n",
"sql_with_cte:",
indent(renderer.render_sql_query_plan(sql_plan_with_cte).sql),
]

assert_str_snapshot_equal(
request=request,
mf_test_configuration=mf_test_configuration,
snapshot_id="result",
snapshot_str="\n".join(lines),
incomparable_strings_replacement_function=make_schema_replacement_function(
system_schema=mf_test_configuration.mf_system_schema, source_schema=mf_test_configuration.mf_source_schema
),
)


def test_cte_for_simple_dataflow_plan(
request: FixtureRequest,
mf_test_configuration: MetricFlowTestConfiguration,
dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
mf_engine_test_fixture_mapping: Mapping[SemanticManifestSetup, MetricFlowEngineTestFixture],
) -> None:
"""Test a simple case for generating a CTE for a specific dataflow plan node."""
measure_spec = MeasureSpec(
element_name="bookings",
)
source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[
"bookings_source"
]
filter_node = FilterElementsNode.create(
parent_node=source_node, include_specs=InstanceSpecSet(measure_specs=(measure_spec,))
)

convert_and_check(
request=request,
mf_test_configuration=mf_test_configuration,
dataflow_to_sql_converter=dataflow_to_sql_converter,
node=filter_node,
nodes_to_convert_to_cte=frozenset(
[
source_node,
]
),
)


def test_cte_for_shared_metrics(
request: FixtureRequest,
mf_test_configuration: MetricFlowTestConfiguration,
column_association_resolver: ColumnAssociationResolver,
dataflow_plan_builder: DataflowPlanBuilder,
query_parser: MetricFlowQueryParser,
dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
) -> None:
"""Check common branches in a query that uses derived metrics defined from metrics that are also in the query."""
parse_result = query_parser.parse_and_validate_query(
metric_names=("bookings", "views", "bookings_per_view"),
group_by_names=("metric_time", "listing__country_latest"),
)
dataflow_plan = dataflow_plan_builder.build_plan(parse_result.query_spec)
common_nodes = DataflowPlanAnalyzer.find_common_branches(dataflow_plan)
# metric_nodes = find_metric_nodes(dataflow_plan.sink_node)

convert_and_check(
request=request,
mf_test_configuration=mf_test_configuration,
dataflow_to_sql_converter=dataflow_to_sql_converter,
node=dataflow_plan.sink_node,
nodes_to_convert_to_cte=frozenset(common_nodes),
)

0 comments on commit a7d043d

Please sign in to comment.