Skip to content

Commit

Permalink
Simplify dataflow to SQL logic for JoinOverTimeRangeNode
Browse files Browse the repository at this point in the history
There should be no functional changes in this commit, only cleanup and readability improvements. Mostly involves moving complex logic to helper functions.
  • Loading branch information
courtneyholcomb committed Nov 20, 2024
1 parent 570e7f0 commit 8dae910
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 70 deletions.
6 changes: 3 additions & 3 deletions metricflow/dataflow/nodes/join_over_time.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Sequence
from typing import Optional, Sequence, Tuple

from dbt_semantic_interfaces.protocols import MetricTimeWindow
from dbt_semantic_interfaces.type_enums import TimeGranularity
Expand All @@ -26,7 +26,7 @@ class JoinOverTimeRangeNode(DataflowPlanNode):
time_range_constraint: Time range to aggregate over.
"""

queried_agg_time_dimension_specs: Sequence[TimeDimensionSpec]
queried_agg_time_dimension_specs: Tuple[TimeDimensionSpec, ...]
window: Optional[MetricTimeWindow]
grain_to_date: Optional[TimeGranularity]
time_range_constraint: Optional[TimeRangeConstraint]
Expand All @@ -38,7 +38,7 @@ def __post_init__(self) -> None: # noqa: D105
@staticmethod
def create( # noqa: D102
parent_node: DataflowPlanNode,
queried_agg_time_dimension_specs: Sequence[TimeDimensionSpec],
queried_agg_time_dimension_specs: Tuple[TimeDimensionSpec, ...],
window: Optional[MetricTimeWindow] = None,
grain_to_date: Optional[TimeGranularity] = None,
time_range_constraint: Optional[TimeRangeConstraint] = None,
Expand Down
23 changes: 23 additions & 0 deletions metricflow/dataset/sql_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple

from dbt_semantic_interfaces.references import SemanticModelReference
Expand Down Expand Up @@ -154,3 +155,25 @@ def column_association_for_time_dimension(self, time_dimension_spec: TimeDimensi
@override
def semantic_model_reference(self) -> Optional[SemanticModelReference]:
return None

def annotate(self, alias: str, metric_time_spec: TimeDimensionSpec) -> AnnotatedSqlDataSet:
"""Convert to an AnnotatedSqlDataSet with specified metadata."""
metric_time_column_name = self.column_association_for_time_dimension(metric_time_spec).column_name
return AnnotatedSqlDataSet(data_set=self, alias=alias, _metric_time_column_name=metric_time_column_name)


@dataclass(frozen=True)
class AnnotatedSqlDataSet:
"""Class to bind a DataSet to transient properties associated with it at a given point in the SqlQueryPlan."""

data_set: SqlDataSet
alias: str
_metric_time_column_name: Optional[str] = None

@property
def metric_time_column_name(self) -> str:
"""Direct accessor for the optional metric time name, only safe to call when we know that value is set."""
assert (
self._metric_time_column_name
), "Expected a valid metric time dimension name to be associated with this dataset, but did not get one!"
return self._metric_time_column_name
84 changes: 35 additions & 49 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,70 +466,41 @@ def visit_source_node(self, node: ReadSqlSourceNode) -> SqlDataSet:
def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDataSet:
"""Generate time range join SQL."""
table_alias_to_instance_set: OrderedDict[str, InstanceSet] = OrderedDict()
input_data_set = node.parent_node.accept(self)
input_data_set_alias = self._next_unique_table_alias()
parent_data_set = node.parent_node.accept(self)
parent_data_set_alias = self._next_unique_table_alias()

# Find requested agg_time_dimensions in parent instance set.
# Will use instance with the smallest base granularity in time spine join.
agg_time_dimension_instance_for_join: Optional[TimeDimensionInstance] = None
requested_agg_time_dimension_instances: Tuple[TimeDimensionInstance, ...] = ()
for instance in input_data_set.instance_set.time_dimension_instances:
if instance.spec in node.queried_agg_time_dimension_specs:
requested_agg_time_dimension_instances += (instance,)
if not agg_time_dimension_instance_for_join or (
instance.spec.time_granularity.base_granularity.to_int()
< agg_time_dimension_instance_for_join.spec.time_granularity.base_granularity.to_int()
):
agg_time_dimension_instance_for_join = instance
assert (
agg_time_dimension_instance_for_join
), "Specified metric time spec not found in parent data set. This should have been caught by validations."
# For the purposes of this node, use base grains. Custom grains will be joined later in the dataflow plan.
agg_time_dimension_specs = tuple({spec.with_base_grain() for spec in node.queried_agg_time_dimension_specs})

# Assemble time_spine dataset with a column for each agg_time_dimension requested.
agg_time_dimension_instances = parent_data_set.instances_for_time_dimensions(agg_time_dimension_specs)
time_spine_data_set_alias = self._next_unique_table_alias()

# Assemble time_spine dataset with requested agg time dimension instances selected.
time_spine_data_set = self._make_time_spine_data_set(
agg_time_dimension_instances=requested_agg_time_dimension_instances,
time_range_constraint=node.time_range_constraint,
agg_time_dimension_instances=agg_time_dimension_instances, time_range_constraint=node.time_range_constraint
)
table_alias_to_instance_set[time_spine_data_set_alias] = time_spine_data_set.instance_set

# Build the join description.
join_spec = self._choose_instance_for_time_spine_join(agg_time_dimension_instances).spec
annotated_parent = parent_data_set.annotate(alias=parent_data_set_alias, metric_time_spec=join_spec)
annotated_time_spine = time_spine_data_set.annotate(alias=time_spine_data_set_alias, metric_time_spec=join_spec)
join_desc = SqlQueryPlanJoinBuilder.make_cumulative_metric_time_range_join_description(
node=node,
metric_data_set=AnnotatedSqlDataSet(
data_set=input_data_set,
alias=input_data_set_alias,
_metric_time_column_name=input_data_set.column_association_for_time_dimension(
agg_time_dimension_instance_for_join.spec
).column_name,
),
time_spine_data_set=AnnotatedSqlDataSet(
data_set=time_spine_data_set,
alias=time_spine_data_set_alias,
_metric_time_column_name=time_spine_data_set.column_association_for_time_dimension(
agg_time_dimension_instance_for_join.spec
).column_name,
),
node=node, metric_data_set=annotated_parent, time_spine_data_set=annotated_time_spine
)

# Remove instances of agg_time_dimension from input data set. They'll be replaced with time spine instances.
agg_time_dimension_specs = tuple(dim.spec for dim in requested_agg_time_dimension_instances)
modified_input_instance_set = input_data_set.instance_set.transform(
# Build select columns, replacing agg_time_dimensions from the parent node with columns from the time spine.
table_alias_to_instance_set[time_spine_data_set_alias] = time_spine_data_set.instance_set
table_alias_to_instance_set[parent_data_set_alias] = parent_data_set.instance_set.transform(
FilterElements(exclude_specs=InstanceSpecSet(time_dimension_specs=agg_time_dimension_specs))
)
table_alias_to_instance_set[input_data_set_alias] = modified_input_instance_set

# The output instances are the same as the input instances.
output_instance_set = ChangeAssociatedColumns(self._column_association_resolver).transform(
input_data_set.instance_set
select_columns = create_simple_select_columns_for_instance_sets(
column_resolver=self._column_association_resolver, table_alias_to_instance_set=table_alias_to_instance_set
)

return SqlDataSet(
instance_set=output_instance_set,
instance_set=parent_data_set.instance_set, # The output instances are the same as the input instances.
sql_select_node=SqlSelectStatementNode.create(
description=node.description,
select_columns=create_simple_select_columns_for_instance_sets(
self._column_association_resolver, table_alias_to_instance_set
),
select_columns=select_columns,
from_source=time_spine_data_set.checked_sql_select_node,
from_source_alias=time_spine_data_set_alias,
join_descs=(join_desc,),
Expand Down Expand Up @@ -1390,6 +1361,21 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe
),
)

def _choose_instance_for_time_spine_join(
self, agg_time_dimension_instances: Sequence[TimeDimensionInstance]
) -> TimeDimensionInstance:
"""Find the agg_time_dimension instance with the smallest grain to use for the time spine join."""
# We can't use a date part spec to join to the time spine, so filter those out.
agg_time_dimension_instances = [
instance for instance in agg_time_dimension_instances if not instance.spec.date_part
]
assert len(agg_time_dimension_instances) > 0, (
"No appropriate agg_time_dimension was found to join to the time spine. "
"This indicates that the dataflow plan was configured incorrectly."
)
agg_time_dimension_instances.sort(key=lambda instance: instance.spec.time_granularity.base_granularity.to_int())
return agg_time_dimension_instances[0]

def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet: # noqa: D102
parent_data_set = node.parent_node.accept(self)
parent_alias = self._next_unique_table_alias()
Expand Down
1 change: 1 addition & 0 deletions metricflow/plan_conversion/instance_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,7 @@ def transform(self, instance_set: InstanceSet) -> SelectColumnSet: # noqa: D102
)


# TODO: delete this class & all uses. It doesn't do anything.
class ChangeAssociatedColumns(InstanceSetTransform[InstanceSet]):
"""Change the columns associated with instances to the one specified by the resolver.
Expand Down
19 changes: 1 addition & 18 deletions metricflow/plan_conversion/sql_join_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
from metricflow.dataflow.nodes.join_to_base import JoinDescription
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataset.sql_dataset import SqlDataSet
from metricflow.dataset.sql_dataset import AnnotatedSqlDataSet
from metricflow.plan_conversion.sql_expression_builders import make_coalesced_expr
from metricflow.sql.sql_exprs import (
SqlColumnReference,
Expand Down Expand Up @@ -45,23 +45,6 @@ class ColumnEqualityDescription:
treat_nulls_as_equal: bool = False


@dataclass(frozen=True)
class AnnotatedSqlDataSet:
"""Class to bind a DataSet to transient properties associated with it at a given point in the SqlQueryPlan."""

data_set: SqlDataSet
alias: str
_metric_time_column_name: Optional[str] = None

@property
def metric_time_column_name(self) -> str:
"""Direct accessor for the optional metric time name, only safe to call when we know that value is set."""
assert (
self._metric_time_column_name
), "Expected a valid metric time dimension name to be associated with this dataset, but did not get one!"
return self._metric_time_column_name


class SqlQueryPlanJoinBuilder:
"""Helper class for constructing various join components in a SqlQueryPlan."""

Expand Down

0 comments on commit 8dae910

Please sign in to comment.