Skip to content

Commit

Permalink
Bug fix: treat metric_time and agg_time the same when both are inclu…
Browse files Browse the repository at this point in the history
…ed in JoinToTimeSpineNode

We had previously made a product decision about the behavior of this node that we later decided was not correct. If metric_time or the agg_time_dimension were requested on their own in the JoinToTimeSpineNode, they would each be treated the same. But if both metric_time and the agg_time_dimension were requested, we would select metric_time from time spine, then treat the agg_time_dimension like any other dimension and select it from the parent. We later decided this behavior was inconsistent. This fixes that, treating them the same and selecting both from the time spine.
  • Loading branch information
courtneyholcomb committed Nov 20, 2024
1 parent 7a1a3fe commit af2375c
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 51 deletions.
6 changes: 6 additions & 0 deletions metricflow-semantics/metricflow_semantics/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,12 @@ def with_entity_prefix(
spec=transformed_spec,
)

def with_new_defined_from(self, defined_from: Tuple[SemanticModelElementReference, ...]) -> TimeDimensionInstance:
    """Create a copy of this instance, swapping in the given `defined_from` references.

    All other fields (spec, associated columns) are carried over unchanged.
    """
    return TimeDimensionInstance(
        spec=self.spec,
        associated_columns=self.associated_columns,
        defined_from=defined_from,
    )


@dataclass(frozen=True)
class EntityInstance(LinkableInstance[EntitySpec], SemanticModelElementInstance): # noqa: D101
Expand Down
83 changes: 32 additions & 51 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from typing import Callable, Dict, FrozenSet, List, Optional, Sequence, Set, Tuple, TypeVar, Union

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME
from dbt_semantic_interfaces.protocols.metric import MetricInputMeasure, MetricType
from dbt_semantic_interfaces.references import EntityReference, MetricModelReference, SemanticModelElementReference
from dbt_semantic_interfaces.references import MetricModelReference, SemanticModelElementReference
from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType
from dbt_semantic_interfaces.type_enums.conversion_calculation_type import ConversionCalculationType
from dbt_semantic_interfaces.type_enums.period_agg import PeriodAggregation
Expand Down Expand Up @@ -469,11 +468,10 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDat
parent_data_set = node.parent_node.accept(self)
parent_data_set_alias = self._next_unique_table_alias()

# For the purposes of this node, use base grains. Custom grains will be joined later in the dataflow plan.
agg_time_dimension_specs = tuple({spec.with_base_grain() for spec in node.queried_agg_time_dimension_specs})

# Assemble time_spine dataset with a column for each agg_time_dimension requested.
agg_time_dimension_instances = parent_data_set.instances_for_time_dimensions(agg_time_dimension_specs)
agg_time_dimension_instances = parent_data_set.instances_for_time_dimensions(
node.queried_agg_time_dimension_specs
)
time_spine_data_set_alias = self._next_unique_table_alias()
time_spine_data_set = self._make_time_spine_data_set(
agg_time_dimension_instances=agg_time_dimension_instances, time_range_constraint=node.time_range_constraint
Expand All @@ -490,7 +488,7 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDat
# Build select columns, replacing agg_time_dimensions from the parent node with columns from the time spine.
table_alias_to_instance_set[time_spine_data_set_alias] = time_spine_data_set.instance_set
table_alias_to_instance_set[parent_data_set_alias] = parent_data_set.instance_set.transform(
FilterElements(exclude_specs=InstanceSpecSet(time_dimension_specs=agg_time_dimension_specs))
FilterElements(exclude_specs=InstanceSpecSet(time_dimension_specs=node.queried_agg_time_dimension_specs))
)
select_columns = create_simple_select_columns_for_instance_sets(
column_resolver=self._column_association_resolver, table_alias_to_instance_set=table_alias_to_instance_set
Expand Down Expand Up @@ -1380,33 +1378,30 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
parent_data_set = node.parent_node.accept(self)
parent_alias = self._next_unique_table_alias()

if node.use_custom_agg_time_dimension:
agg_time_dimension = node.requested_agg_time_dimension_specs[0]
agg_time_element_name = agg_time_dimension.element_name
agg_time_entity_links: Tuple[EntityReference, ...] = agg_time_dimension.entity_links
else:
agg_time_element_name = METRIC_TIME_ELEMENT_NAME
agg_time_entity_links = ()
agg_time_dimension_instances = parent_data_set.instances_for_time_dimensions(
node.requested_agg_time_dimension_specs
)

# Find the time dimension instances in the parent data set that match the one we want to join with.
agg_time_dimension_instances: List[TimeDimensionInstance] = []
for instance in parent_data_set.instance_set.time_dimension_instances:
if (
instance.spec.date_part is None # Ensure we don't join using an instance with date part
and instance.spec.element_name == agg_time_element_name
and instance.spec.entity_links == agg_time_entity_links
):
agg_time_dimension_instances.append(instance)
# Select the dimension for the join from the parent node because it may not have been included in the request.
# Default to using metric_time for the join if it was requested, otherwise use the agg_time_dimension.
included_metric_time_instances = [
instance for instance in agg_time_dimension_instances if instance.spec.is_metric_time
]
if included_metric_time_instances:
join_on_time_dimension_sample = included_metric_time_instances[0].spec
else:
join_on_time_dimension_sample = agg_time_dimension_instances[0].spec

# Choose the instance with the smallest base granularity available.
agg_time_dimension_instances.sort(key=lambda instance: instance.spec.time_granularity.base_granularity.to_int())
assert len(agg_time_dimension_instances) > 0, (
"Couldn't find requested agg_time_dimension in parent data set. The dataflow plan may have been "
"configured incorrectly."
agg_time_dimension_instance_for_join = self._choose_instance_for_time_spine_join(
[
instance
for instance in parent_data_set.instance_set.time_dimension_instances
if instance.spec.element_name == join_on_time_dimension_sample.element_name
and instance.spec.entity_links == join_on_time_dimension_sample.entity_links
]
)
agg_time_dimension_instance_for_join = agg_time_dimension_instances[0]

# Build time spine data set using the requested agg_time_dimension name.
# Build time spine data set with just the agg_time_dimension instance needed for the join.
time_spine_alias = self._next_unique_table_alias()
time_spine_dataset = self._make_time_spine_data_set(
agg_time_dimension_instances=(agg_time_dimension_instance_for_join,),
Expand All @@ -1430,24 +1425,14 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
time_dimensions_to_select_from_parent: Tuple[TimeDimensionInstance, ...] = ()
time_dimensions_to_select_from_time_spine: Tuple[TimeDimensionInstance, ...] = ()
for time_dimension_instance in parent_data_set.instance_set.time_dimension_instances:
if (
time_dimension_instance.spec.element_name == agg_time_element_name
and time_dimension_instance.spec.entity_links == agg_time_entity_links
):
if time_dimension_instance in agg_time_dimension_instances:
time_dimensions_to_select_from_time_spine += (time_dimension_instance,)
else:
time_dimensions_to_select_from_parent += (time_dimension_instance,)
parent_instance_set = InstanceSet(
measure_instances=parent_data_set.instance_set.measure_instances,
dimension_instances=parent_data_set.instance_set.dimension_instances,
time_dimension_instances=tuple(
time_dimension_instance
for time_dimension_instance in parent_data_set.instance_set.time_dimension_instances
if not (
time_dimension_instance.spec.element_name == agg_time_element_name
and time_dimension_instance.spec.entity_links == agg_time_entity_links
)
),
time_dimension_instances=time_dimensions_to_select_from_parent,
entity_instances=parent_data_set.instance_set.entity_instances,
metric_instances=parent_data_set.instance_set.metric_instances,
metadata_instances=parent_data_set.instance_set.metadata_instances,
Expand Down Expand Up @@ -1479,8 +1464,8 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
)

# Add requested granularities (if different from time_spine) and date_parts to time spine column.
for time_dimension_instance in time_dimensions_to_select_from_time_spine:
time_dimension_spec = time_dimension_instance.spec
for parent_time_dimension_instance in time_dimensions_to_select_from_time_spine:
time_dimension_spec = parent_time_dimension_instance.spec
if (
time_dimension_spec.time_granularity.base_granularity.to_int()
< original_time_spine_dim_instance.spec.time_granularity.base_granularity.to_int()
Expand Down Expand Up @@ -1519,13 +1504,9 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
# Apply date_part to time spine column select expression.
if time_dimension_spec.date_part:
select_expr = SqlExtractExpression.create(date_part=time_dimension_spec.date_part, arg=select_expr)
time_dim_spec = original_time_spine_dim_instance.spec.with_grain_and_date_part(
time_granularity=time_dimension_spec.time_granularity, date_part=time_dimension_spec.date_part
)
time_spine_dim_instance = TimeDimensionInstance(
defined_from=original_time_spine_dim_instance.defined_from,
associated_columns=(self._column_association_resolver.resolve_spec(time_dim_spec),),
spec=time_dim_spec,

time_spine_dim_instance = parent_time_dimension_instance.with_new_defined_from(
original_time_spine_dim_instance.defined_from
)
time_spine_dim_instances.append(time_spine_dim_instance)
time_spine_select_columns.append(
Expand Down

0 comments on commit af2375c

Please sign in to comment.