Skip to content

Commit

Permalink
Add method to figure out the common branches in a dataflow plan.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Nov 13, 2024
1 parent ec7c29b commit 93704a3
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 1 deletion.
2 changes: 1 addition & 1 deletion metricflow-semantics/metricflow_semantics/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABC
from typing import TypeVar

# NOTE(review): these two assignments appear to be the removed and added sides of the one-line change to this file
# (the diff's +/- markers are not present here). Presumably only the covariant declaration is the post-commit
# state -- confirm against the repository before relying on either line.
VisitorOutputT = TypeVar("VisitorOutputT")
VisitorOutputT = TypeVar("VisitorOutputT", covariant=True)


class Visitable(ABC):
Expand Down
79 changes: 79 additions & 0 deletions metricflow/dataflow/dataflow_plan_analyzer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from __future__ import annotations

from collections import defaultdict
from typing import Dict, FrozenSet, Mapping, Sequence, Set

from typing_extensions import override

from metricflow.dataflow.dataflow_plan import DataflowPlan, DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitorWithDefaultHandler


class DataflowPlanAnalyzer:
    """Helper for computing more involved properties of a dataflow plan.

    The same logic could be exposed as member methods of the dataflow plan, but that requires resolving some
    circular dependency issues to break out the functionality into separate files.
    """

    @staticmethod
    def find_common_branches(dataflow_plan: DataflowPlan) -> Sequence[DataflowPlanNode]:
        """Find the common branches of the DAG, starting the search at the plan's sink node.

        A node is a candidate when the counting visitor reports a count greater than one for it. The candidates
        are then handed to a second visitor that picks out the branches to return.

        Args:
            dataflow_plan: The plan whose sink node is used as the starting point of both traversals.

        Returns:
            A tuple of the nodes found by the second traversal, sorted for reproducibility.
        """
        # First pass: tally how often each node is reached from the sink node.
        occurrence_counter = _CountDataflowNodeVisitor()
        dataflow_plan.sink_node.accept(occurrence_counter)

        # Only nodes that were reached more than once can be part of a common branch.
        repeated_nodes = frozenset(
            plan_node for plan_node, occurrences in occurrence_counter.get_node_counts().items() if occurrences > 1
        )

        # Second pass: collect the common branches using the set of repeated nodes.
        branch_finder = _FindLargestCommonBranchesVisitor(repeated_nodes)
        largest_common_branches = dataflow_plan.sink_node.accept(branch_finder)
        return tuple(sorted(largest_common_branches))


class _CountDataflowNodeVisitor(DataflowPlanNodeVisitorWithDefaultHandler[None]):
    """Helper visitor that records, for each node, how many times the traversal reaches it.

    Assumes the base visitor routes every node type to `_default_handler` -- confirm against the base class.
    """

    def __init__(self) -> None:
        # Missing nodes start at zero so the handler can increment without checking for presence first.
        self._node_to_count: Dict[DataflowPlanNode, int] = defaultdict(int)

    def get_node_counts(self) -> Mapping[DataflowPlanNode, int]:
        """Return the mapping from each visited node to the number of times it was reached."""
        return self._node_to_count

    @override
    def _default_handler(self, node: DataflowPlanNode) -> None:
        # Walk every parent first, then record this node.
        for upstream_node in node.parent_nodes:
            upstream_node.accept(self)
        self._node_to_count[node] = self._node_to_count[node] + 1


class _FindLargestCommonBranchesVisitor(DataflowPlanNodeVisitorWithDefaultHandler[FrozenSet[DataflowPlanNode]]):
    """Find the common branches, given the nodes already known to appear in the DAG multiple times.

    Several overlapping branches can qualify as common: for `A -> B -> C -> D` and `B -> C -> D`, both
    `B -> C -> D` and `C -> D` are common branches. Only the largest one is wanted, so this uses a preorder
    traversal and returns the first common node that is seen on each path.
    """

    def __init__(self, common_nodes: FrozenSet[DataflowPlanNode]) -> None:
        # Nodes that are known to occur more than once in the DAG.
        self._common_nodes = common_nodes

    @override
    def _default_handler(self, node: DataflowPlanNode) -> FrozenSet[DataflowPlanNode]:
        if node not in self._common_nodes:
            # Not a common node itself, so gather whatever is found further along each parent branch.
            found_in_parents: Set[DataflowPlanNode] = set()
            for upstream_node in node.parent_nodes:
                found_in_parents.update(upstream_node.accept(self))
            return frozenset(found_in_parents)

        # The membership check happens before descending into the parents, so the traversal stops at the first
        # common node on a path. Smaller common branches contained in this one are therefore never returned.
        return frozenset({node})
54 changes: 54 additions & 0 deletions tests_metricflow/sql/test_common_dataflow_branches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from __future__ import annotations

import logging

from _pytest.fixtures import FixtureRequest
from metricflow_semantics.mf_logging.pretty_print import mf_pformat_dict
from metricflow_semantics.query.query_parser import MetricFlowQueryParser
from metricflow_semantics.specs.column_assoc import ColumnAssociationResolver
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration
from metricflow_semantics.test_helpers.snapshot_helpers import (
assert_str_snapshot_equal,
)

from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder
from metricflow.dataflow.dataflow_plan_analyzer import DataflowPlanAnalyzer

logger = logging.getLogger(__name__)


def test_shared_metric_query(
    request: FixtureRequest,
    mf_test_configuration: MetricFlowTestConfiguration,
    column_association_resolver: ColumnAssociationResolver,
    dataflow_plan_builder: DataflowPlanBuilder,
    query_parser: MetricFlowQueryParser,
) -> None:
    """Check that, for a known case, a metric computation node is reported as a common branch.

    When `bookings` and `bookings_per_booker` are queried together, the computation for `bookings` should show up
    as a common branch in the dataflow plan.
    """
    query_spec = query_parser.parse_and_validate_query(
        metric_names=("bookings", "bookings_per_booker"),
        group_by_names=("metric_time",),
    ).query_spec
    dataflow_plan = dataflow_plan_builder.build_plan(query_spec)

    # The snapshot holds the full plan followed by one entry per common branch.
    snapshot_sections = {"dataflow_plan": dataflow_plan.structure_text()}

    ordered_branches = sorted(DataflowPlanAnalyzer.find_common_branches(dataflow_plan))
    snapshot_sections.update(
        (f"common_branch_{index}", branch_node.structure_text()) for index, branch_node in enumerate(ordered_branches)
    )

    snapshot_text = mf_pformat_dict(
        obj_dict=snapshot_sections,
        preserve_raw_strings=True,
    )
    assert_str_snapshot_equal(
        request=request,
        mf_test_configuration=mf_test_configuration,
        snapshot_id="result",
        snapshot_str=snapshot_text,
    )

0 comments on commit 93704a3

Please sign in to comment.