Bug fix: Generate new select columns for instances with new entity link - commit needs cleanup
courtneyholcomb committed Nov 2, 2024
1 parent 6514e81 commit 1dafb8f
Showing 62 changed files with 1,012 additions and 1,056 deletions.
2 changes: 2 additions & 0 deletions metricflow-semantics/metricflow_semantics/instances.py
@@ -38,6 +38,8 @@ class MdoInstance(ABC, Generic[SpecT]):
"""

# The columns associated with this instance.
# TODO: if possible, remove this and instead add a method that resolves this from the spec + column association resolver
# (ensure we're using consistent logic everywhere so this bug doesn't happen again)
associated_columns: Tuple[ColumnAssociation, ...]
# The spec that describes this instance.
spec: SpecT
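A minimal sketch of what that TODO might look like: derive the column from the spec on demand so that every call site shares one resolution path. This reuses `ColumnAssociationResolver.resolve_spec`, the resolver call used later in this commit; the method name `column_association` is an assumption, not part of the commit.

```python
# Hypothetical sketch of the TODO above; not part of this commit.
# Deriving the column from the spec on demand keeps every call site on the
# same resolution logic instead of trusting a stored tuple.
from abc import ABC
from typing import Generic, TypeVar

SpecT = TypeVar("SpecT", bound="InstanceSpec")


class MdoInstance(ABC, Generic[SpecT]):
    # The spec that describes this instance.
    spec: SpecT

    def column_association(self, resolver: "ColumnAssociationResolver") -> "ColumnAssociation":
        """Resolve this instance's column from its spec (hypothetical method name)."""
        return resolver.resolve_spec(self.spec)
```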
192 changes: 133 additions & 59 deletions metricflow/plan_conversion/dataflow_to_sql.py
@@ -19,6 +19,8 @@
from metricflow_semantics.dag.sequential_id import SequentialIdGenerator
from metricflow_semantics.filters.time_constraint import TimeRangeConstraint
from metricflow_semantics.instances import (
DimensionInstance,
EntityInstance,
GroupByMetricInstance,
InstanceSet,
MdoInstance,
@@ -29,11 +31,9 @@
from metricflow_semantics.mf_logging.formatting import indent
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow_semantics.specs.column_assoc import (
ColumnAssociationResolver,
)
from metricflow_semantics.specs.column_assoc import ColumnAssociation, ColumnAssociationResolver
from metricflow_semantics.specs.group_by_metric_spec import GroupByMetricSpec
from metricflow_semantics.specs.instance_spec import InstanceSpec
from metricflow_semantics.specs.instance_spec import InstanceSpec, LinkableInstanceSpec
from metricflow_semantics.specs.measure_spec import MeasureSpec
from metricflow_semantics.specs.metadata_spec import MetadataSpec
from metricflow_semantics.specs.metric_spec import MetricSpec
@@ -72,7 +72,6 @@
from metricflow.plan_conversion.convert_to_sql_plan import ConvertToSqlPlanResult
from metricflow.plan_conversion.instance_converters import (
AddGroupByMetric,
AddLinkToLinkableElements,
AddMetadata,
AddMetrics,
AliasAggregatedMeasures,
@@ -229,7 +228,7 @@ def convert_to_sql_query_plan(
sql_node = optimizer.optimize(sql_node)
logger.debug(
LazyFormat(
lambda: f"After applying {optimizer.__class__.__name__}, the SQL query plan is:\n"
lambda: f"After applying optimizer {optimizer.__class__.__name__}, the SQL query plan is:\n"
f"{indent(sql_node.structure_text())}"
)
)
@@ -456,31 +455,58 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDataSet:
)

def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> SqlDataSet:
"""Generates the query that realizes the behavior of the JoinToStandardOutputNode."""
# Keep a mapping between the table aliases that would be used in the query and the MDO instances in that source.
# e.g. when building "FROM from_table a JOIN right_table b", the value for key "a" would be the instances in
# "from_table"
table_alias_to_instance_set: OrderedDict[str, InstanceSet] = OrderedDict()

# Convert the dataflow from the left node to a DataSet and add context for it to table_alias_to_instance_set
# A DataSet is a bundle of the SQL query (in object form) and the MDO instances that the SQL query contains.
"""Generates the query that realizes the behavior of the JoinOnEntitiesNode."""
from_data_set = node.left_node.accept(self)
from_data_set_alias = self._next_unique_table_alias()
table_alias_to_instance_set[from_data_set_alias] = from_data_set.instance_set

# Build the join descriptions for the SqlQueryPlan - different from node.join_descriptions which are the join
# descriptions from the dataflow plan.
sql_join_descs: List[SqlJoinDescription] = []
# TODO: make prettier
def build_columns(spec: LinkableInstanceSpec) -> Tuple[ColumnAssociation]:
return (self._column_association_resolver.resolve_spec(spec),)

def build_select_column(
table_alias: str, original_instance: MdoInstance, new_instance: MdoInstance
) -> SqlSelectColumn:
"""Build new select column using the old column name as the expr and the new column name as the alias.
Example: "country AS user_id__country"
"""
return SqlSelectColumn(
expr=SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=table_alias, column_name=original_instance.associated_column.column_name
),
column_alias=new_instance.associated_column.column_name,
)

# Change the aggregation state for the measures to be partially aggregated if it was previously aggregated
# since we removed the entities and added the dimensions. The dimensions could have the same value for
# multiple rows, so we'll need to re-aggregate.
from_data_set_output_instance_set = from_data_set.instance_set.transform(
# TODO: is this filter doing anything? It appears to be a no-op.
FilterElements(include_specs=from_data_set.instance_set.spec_set)
).transform(
ChangeMeasureAggregationState(
{
AggregationState.NON_AGGREGATED: AggregationState.NON_AGGREGATED,
AggregationState.COMPLETE: AggregationState.PARTIAL,
AggregationState.PARTIAL: AggregationState.PARTIAL,
}
)
)
instances_to_build_simple_select_columns_for = OrderedDict(
{from_data_set_alias: from_data_set_output_instance_set}
)

# The dataflow plan describes how the data sets coming from the parent nodes should be joined together. Use
# those descriptions to convert them to join descriptions for the SQL query plan.
output_instance_set = from_data_set_output_instance_set
select_columns: Tuple[SqlSelectColumn, ...] = ()
# Build join description, instance set, and select columns for each join target.
sql_join_descs: List[SqlJoinDescription] = []
for join_description in node.join_targets:
join_on_entity = join_description.join_on_entity

right_node_to_join: DataflowPlanNode = join_description.join_node
right_node_to_join = join_description.join_node
right_data_set: SqlDataSet = right_node_to_join.accept(self)
right_data_set_alias = self._next_unique_table_alias()

# Build join description.
sql_join_desc = SqlQueryPlanJoinBuilder.make_base_output_join_description(
left_data_set=AnnotatedSqlDataSet(data_set=from_data_set, alias=from_data_set_alias),
right_data_set=AnnotatedSqlDataSet(data_set=right_data_set, alias=right_data_set_alias),
@@ -489,58 +515,106 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> SqlDataSet:
sql_join_descs.append(sql_join_desc)

if join_on_entity:
# Build instance set that will be available after join.
# Remove the linkable instances with the join_on_entity as the leading link as the next step adds the
# link. This is to avoid cases where there is a primary entity and a dimension in the data set, and we
# create an instance in the next step that has the same entity link.
# e.g. a data set has the dimension "listing__country_latest" and "listing" is a primary entity in the
# data set. The next step would create an instance like "listing__listing__country_latest" without this
# filter.
right_data_set_instance_set_filtered = FilterLinkableInstancesWithLeadingLink(
entity_link=join_on_entity,
).transform(right_data_set.instance_set)

# After the right data set is joined to the "from" data set, we need to change the links for some of the
# instances that represent the right data set. For example, if the "from" data set contains the "bookings"
# measure instance and the right dataset contains the "country" dimension instance, then after the join,
# the output data set should have the "country" dimension instance with the "user_id" entity link
# (if "user_id" equality was the join condition). "country" -> "user_id__country"
right_data_set_instance_set_after_join = right_data_set_instance_set_filtered.transform(
AddLinkToLinkableElements(
join_on_entity=join_on_entity, column_association_resolver=self._column_association_resolver
# TODO: test if this transformation is necessary and remove it if not. This adds a lot of clutter to the function.
right_instance_set_filtered = FilterLinkableInstancesWithLeadingLink(join_on_entity).transform(
right_data_set.instance_set
)

# After the right data set is joined, we need to change the links to indicate that a join was used to
# satisfy them. For example, if the right dataset contains the "country" dimension, and "user_id" is the
# join_on_entity, then the joined data set should have the "user__country" dimension.
transformed_spec: LinkableInstanceSpec
original_instance: MdoInstance
new_instance: MdoInstance
# TODO: the four loops below share a lot of boilerplate; figure out how to dedupe them.
entity_instances: Tuple[EntityInstance, ...] = ()
for original_instance in right_instance_set_filtered.entity_instances:
# TODO: is this check necessary, and does it even work given the differing spec types here?
if original_instance.spec == join_on_entity:
continue
transformed_spec = original_instance.spec.with_entity_prefix(join_on_entity.reference)
new_instance = EntityInstance(
associated_columns=build_columns(transformed_spec),
defined_from=original_instance.defined_from,
spec=transformed_spec,
)
entity_instances += (new_instance,)
select_column = build_select_column(
table_alias=right_data_set_alias, original_instance=original_instance, new_instance=new_instance
)
select_columns += (select_column,)

dimension_instances: Tuple[DimensionInstance, ...] = ()
for original_instance in right_instance_set_filtered.dimension_instances:
transformed_spec = original_instance.spec.with_entity_prefix(join_on_entity.reference)
new_instance = DimensionInstance(
associated_columns=build_columns(transformed_spec),
defined_from=original_instance.defined_from,
spec=transformed_spec,
)
dimension_instances += (new_instance,)
select_column = build_select_column(
table_alias=right_data_set_alias, original_instance=original_instance, new_instance=new_instance
)
select_columns += (select_column,)

time_dimension_instances: Tuple[TimeDimensionInstance, ...] = ()
for original_instance in right_instance_set_filtered.time_dimension_instances:
transformed_spec = original_instance.spec.with_entity_prefix(join_on_entity.reference)
new_instance = TimeDimensionInstance(
associated_columns=build_columns(transformed_spec),
defined_from=original_instance.defined_from,
spec=transformed_spec,
)
time_dimension_instances += (new_instance,)
select_column = build_select_column(
table_alias=right_data_set_alias, original_instance=original_instance, new_instance=new_instance
)
select_columns += (select_column,)

group_by_metric_instances: Tuple[GroupByMetricInstance, ...] = ()
for original_instance in right_instance_set_filtered.group_by_metric_instances:
transformed_spec = original_instance.spec.with_entity_prefix(join_on_entity.reference)
new_instance = GroupByMetricInstance(
associated_columns=build_columns(transformed_spec),
defined_from=original_instance.defined_from,
spec=transformed_spec,
)
group_by_metric_instances += (new_instance,)
select_column = build_select_column(
table_alias=right_data_set_alias, original_instance=original_instance, new_instance=new_instance
)
select_columns += (select_column,)

right_instance_set_after_join = InstanceSet(
dimension_instances=dimension_instances,
entity_instances=entity_instances,
time_dimension_instances=time_dimension_instances,
group_by_metric_instances=group_by_metric_instances,
)
else:
right_data_set_instance_set_after_join = right_data_set.instance_set
table_alias_to_instance_set[right_data_set_alias] = right_data_set_instance_set_after_join
right_instance_set_after_join = right_data_set.instance_set
instances_to_build_simple_select_columns_for[right_data_set_alias] = right_instance_set_after_join

from_data_set_output_instance_set = from_data_set.instance_set.transform(
FilterElements(include_specs=from_data_set.instance_set.spec_set)
)
output_instance_set = InstanceSet.merge([output_instance_set, right_instance_set_after_join])

# Change the aggregation state for the measures to be partially aggregated if it was previously aggregated
# since we removed the entities and added the dimensions. The dimensions could have the same value for
# multiple rows, so we'll need to re-aggregate.
from_data_set_output_instance_set = from_data_set_output_instance_set.transform(
ChangeMeasureAggregationState(
{
AggregationState.NON_AGGREGATED: AggregationState.NON_AGGREGATED,
AggregationState.COMPLETE: AggregationState.PARTIAL,
AggregationState.PARTIAL: AggregationState.PARTIAL,
}
)
select_columns += create_select_columns_for_instance_sets(
column_resolver=self._column_association_resolver,
table_alias_to_instance_set=instances_to_build_simple_select_columns_for,
)

table_alias_to_instance_set[from_data_set_alias] = from_data_set_output_instance_set

# Construct the data set that contains the updated instances and the SQL nodes that should go in the various
# clauses.
return SqlDataSet(
instance_set=InstanceSet.merge(list(table_alias_to_instance_set.values())),
# TODO: Should SqlDataSet have a map like {instance: column}? Matching them up is a pain.
instance_set=output_instance_set,
sql_select_node=SqlSelectStatementNode.create(
description=node.description,
select_columns=create_select_columns_for_instance_sets(
self._column_association_resolver, table_alias_to_instance_set
),
select_columns=select_columns,
from_source=from_data_set.checked_sql_select_node,
from_source_alias=from_data_set_alias,
join_descs=tuple(sql_join_descs),
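On the dedupe TODO in visit_join_on_entities_node: the four per-type loops differ only in the instance class and the source tuple, so they could collapse into one helper nested in the method, where it can close over join_on_entity, build_columns, build_select_column, and right_data_set_alias. A sketch under that assumption; the helper name _relink_instances and its signature are hypothetical, not part of this commit.

```python
from typing import Callable, Optional, Sequence, Tuple

# Hypothetical helper, not part of this commit. It would be nested inside
# visit_join_on_entities_node so it can close over join_on_entity,
# build_columns, build_select_column, and right_data_set_alias.
def _relink_instances(
    original_instances: Sequence[MdoInstance],
    instance_class: Callable[..., MdoInstance],
    skip_spec: Optional[LinkableInstanceSpec] = None,
) -> Tuple[Tuple[MdoInstance, ...], Tuple[SqlSelectColumn, ...]]:
    """Rebuild each instance with the join_on_entity prefix and build its select column."""
    new_instances: Tuple[MdoInstance, ...] = ()
    new_select_columns: Tuple[SqlSelectColumn, ...] = ()
    for original_instance in original_instances:
        # Skip the entity we joined on so it is not re-linked to itself.
        if skip_spec is not None and original_instance.spec == skip_spec:
            continue
        transformed_spec = original_instance.spec.with_entity_prefix(join_on_entity.reference)
        new_instance = instance_class(
            associated_columns=build_columns(transformed_spec),
            defined_from=original_instance.defined_from,
            spec=transformed_spec,
        )
        new_instances += (new_instance,)
        new_select_columns += (
            build_select_column(
                table_alias=right_data_set_alias,
                original_instance=original_instance,
                new_instance=new_instance,
            ),
        )
    return new_instances, new_select_columns
```

Each per-type block would then reduce to a single call, e.g.:

```python
entity_instances, entity_select_columns = _relink_instances(
    right_instance_set_filtered.entity_instances, EntityInstance, skip_spec=join_on_entity
)
dimension_instances, dimension_select_columns = _relink_instances(
    right_instance_set_filtered.dimension_instances, DimensionInstance
)
```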