diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index 01b0778c0..5abffad8f 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -385,7 +385,9 @@ def _make_time_spine_data_set( def visit_source_node(self, node: ReadSqlSourceNode) -> SqlDataSet: """Generate the SQL to read from the source.""" return SqlDataSet( - sql_select_node=node.data_set.checked_sql_select_node, + # This visitor is assumed to create a unique SELECT node for each dataflow node, so create a copy. + # The column pruner relies on this assumption to keep track of what columns are required at each node. + sql_select_node=node.data_set.checked_sql_select_node.create_copy(), instance_set=node.data_set.instance_set, ) diff --git a/metricflow/sql/sql_plan.py b/metricflow/sql/sql_plan.py index 249ca5c73..18f2e2839 100644 --- a/metricflow/sql/sql_plan.py +++ b/metricflow/sql/sql_plan.py @@ -227,6 +227,21 @@ def nearest_select_columns( ) -> Optional[Sequence[SqlSelectColumn]]: return self.select_columns + def create_copy(self) -> SqlSelectStatementNode: # noqa: D102 + return SqlSelectStatementNode.create( + description=self.description, + select_columns=self.select_columns, + from_source=self.from_source, + from_source_alias=self.from_source_alias, + cte_sources=self.cte_sources, + join_descs=self.join_descs, + group_bys=self.group_bys, + order_bys=self.order_bys, + where=self.where, + limit=self.limit, + distinct=self.distinct, + ) + @dataclass(frozen=True, eq=False) class SqlTableNode(SqlQueryPlanNode):