Skip to content

Commit

Permalink
Clean up FilterLinkableInstancesWithLeadingLink logic
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Nov 2, 2024
1 parent ab38128 commit acb0295
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 12 deletions.
8 changes: 3 additions & 5 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,17 +498,15 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> SqlDataSet:
if join_on_entity:
# Remove any instances that already have the join_on_entity as the leading link. This will prevent a duplicate
# entity link when we add it in the next step.
right_instance_set_filtered = FilterLinkableInstancesWithLeadingLink(join_on_entity).transform(
right_data_set.instance_set
)
right_instance_set_filtered = FilterLinkableInstancesWithLeadingLink(
join_on_entity.reference
).transform(right_data_set.instance_set)

# After the right data set is joined, update the entity links to indicate that joining on the entity was
# required to reach the spec. If the "country" dimension was joined and "user_id" is the join_on_entity,
# then the joined data set should have the "user__country" dimension.
new_instances: Tuple[MdoInstance, ...] = ()
for original_instance in right_instance_set_filtered.linkable_instances:
if original_instance.spec == join_on_entity:
continue
new_instance = original_instance.with_entity_prefix(
join_on_entity.reference, column_association_resolver=self._column_association_resolver
)
Expand Down
12 changes: 5 additions & 7 deletions metricflow/plan_conversion/instance_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from itertools import chain
from typing import Dict, List, Optional, Sequence, Tuple

from dbt_semantic_interfaces.references import MetricReference, SemanticModelReference
from dbt_semantic_interfaces.references import EntityReference, MetricReference, SemanticModelReference
from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
Expand All @@ -29,7 +29,6 @@
from metricflow_semantics.model.semantics.metric_lookup import MetricLookup
from metricflow_semantics.model.semantics.semantic_model_lookup import SemanticModelLookup
from metricflow_semantics.specs.column_assoc import ColumnAssociationResolver
from metricflow_semantics.specs.entity_spec import LinklessEntitySpec
from metricflow_semantics.specs.instance_spec import InstanceSpec, LinkableInstanceSpec
from metricflow_semantics.specs.measure_spec import MeasureSpec, MetricInputMeasureSpec
from metricflow_semantics.specs.spec_set import InstanceSpecSet
Expand Down Expand Up @@ -390,15 +389,14 @@ class FilterLinkableInstancesWithLeadingLink(InstanceSetTransform[InstanceSet]):
e.g. Remove "listing__country" if the specified link is "listing".
"""

def __init__(self, entity_link: LinklessEntitySpec) -> None:
def __init__(self, entity_link: EntityReference) -> None:
"""Remove elements with this link as the first element in "entity_links"."""
self._entity_link = entity_link

def _should_pass(self, linkable_spec: LinkableInstanceSpec) -> bool:
return (
len(linkable_spec.entity_links) == 0
or LinklessEntitySpec.from_reference(linkable_spec.entity_links[0]) != self._entity_link
)
if len(linkable_spec.entity_links) == 0:
return not linkable_spec.reference == self._entity_link
return linkable_spec.entity_links[0] != self._entity_link

def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
# Normal to not filter anything if the instance set has no instances with links.
Expand Down

0 comments on commit acb0295

Please sign in to comment.