Skip to content

Commit

Permalink
Rename to NodeToColumnAliasMapping.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Nov 9, 2024
1 parent e06ab95 commit dcd8fcf
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 127 deletions.
44 changes: 26 additions & 18 deletions metricflow/sql/optimizer/column_pruner.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

import logging
from typing import FrozenSet, Mapping

from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
from typing_extensions import override

from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer
from metricflow.sql.optimizer.tag_column_aliases import TaggedColumnAliasSet
from metricflow.sql.optimizer.tag_required_column_aliases import SqlTagRequiredColumnAliasesVisitor
from metricflow.sql.optimizer.tag_column_aliases import NodeToColumnAliasMapping
from metricflow.sql.optimizer.tag_required_column_aliases import SqlMapRequiredColumnAliasesVisitor
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
Expand All @@ -29,7 +29,7 @@ class SqlColumnPrunerVisitor(SqlQueryPlanNodeVisitor[SqlQueryPlanNode]):

def __init__(
self,
required_alias_mapping: Mapping[SqlQueryPlanNode, FrozenSet[str]],
required_alias_mapping: NodeToColumnAliasMapping,
) -> None:
"""Constructor.
Expand All @@ -42,7 +42,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
# Remove columns that are not needed from this SELECT statement because the parent SELECT statement doesn't
# need them. However, keep columns that are in group bys because that changes the meaning of the query.
# Similarly, if this node is a distinct select node, keep all columns as it may return a different result set.
required_column_aliases = self._required_alias_mapping.get(node)
required_column_aliases = self._required_alias_mapping.get_aliases(node)
if required_column_aliases is None:
logger.error(
f"Did not find {node.node_id=} in the required alias mapping. Returning the non-pruned version "
Expand Down Expand Up @@ -100,23 +100,31 @@ def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:


class SqlColumnPrunerOptimizer(SqlQueryPlanOptimizer):
"""Removes unnecessary columns in the SELECT clauses."""
"""Removes unnecessary columns in the SELECT statements."""

def optimize(self, node: SqlQueryPlanNode) -> SqlQueryPlanNode: # noqa: D102
# Can't prune columns without knowing the structure of the query.
if not node.as_select_node:
# ALl columns in the nearest SELECT node need to be kept as otherwise, the meaning of the query changes.
required_select_columns = node.nearest_select_columns({})

# Can't prune without knowing the structure of the query.
if required_select_columns is None:
logger.debug(
LazyFormat(
"The columns required at this node can't be determined, so skipping column pruning",
node=node.structure_text(),
required_select_columns=required_select_columns,
)
)
return node

# Figure out which columns in which nodes are required.
tagged_column_alias_set = TaggedColumnAliasSet()
tagged_column_alias_set.tag_all_aliases_in_node(node.as_select_node)
tag_required_column_alias_visitor = SqlTagRequiredColumnAliasesVisitor(
tagged_column_alias_set=tagged_column_alias_set,
map_required_column_aliases_visitor = SqlMapRequiredColumnAliasesVisitor(
start_node=node,
required_column_aliases_in_start_node=frozenset(
[select_column.column_alias for select_column in required_select_columns]
),
)
node.accept(tag_required_column_alias_visitor)
node.accept(map_required_column_aliases_visitor)

# Re-write the query, pruning columns in the SELECT that are not needed.
pruning_visitor = SqlColumnPrunerVisitor(
required_alias_mapping=tagged_column_alias_set.get_mapping(),
)
# Re-write the query, removing unnecessary columns in the SELECT statements.
pruning_visitor = SqlColumnPrunerVisitor(map_required_column_aliases_visitor.required_column_alias_mapping)
return node.accept(pruning_visitor)
83 changes: 10 additions & 73 deletions metricflow/sql/optimizer/tag_column_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,95 +2,32 @@

import logging
from collections import defaultdict
from typing import Dict, FrozenSet, Iterable, Mapping, Set

from typing_extensions import override
from typing import Dict, FrozenSet, Iterable, Set

from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlQueryPlanNode,
SqlQueryPlanNodeVisitor,
SqlSelectQueryFromClauseNode,
SqlSelectStatementNode,
SqlTableNode,
)

logger = logging.getLogger(__name__)


class TaggedColumnAliasSet:
"""Keep track of column aliases in SELECT statements that have been tagged.
The main use case for this class is to keep track of column aliases / columns that are required so that unnecessary
columns can be pruned.
For example, in this query:
SELECT source_0.col_0 AS col_0
FROM (
SELECT
example_table.col_0
example_table.col_1
FROM example_table
) source_0
class NodeToColumnAliasMapping:
"""Mutable class for mapping a SQL node to an arbitrary set of column aliases for that node.
this class can be used to tag `example_table.col_0` but not tag `example_table.col_1` since it's not needed for the
query to run correctly.
* Alternatively, this can be described as mapping a location in the SQL query plan to a set of column aliases.
* See `SqlMapRequiredColumnAliasesVisitor` for the main use case for this class.
* This is a thin wrapper over a dict to aid readability.
"""

def __init__(self) -> None: # noqa: D107
self._node_to_tagged_aliases: Dict[SqlQueryPlanNode, Set[str]] = defaultdict(set)

def get_tagged_aliases(self, node: SqlQueryPlanNode) -> FrozenSet[str]:
"""Return the given tagged column aliases associated with the given SQL node."""
def get_aliases(self, node: SqlQueryPlanNode) -> FrozenSet[str]:
"""Return the column aliases added for the given SQL node."""
return frozenset(self._node_to_tagged_aliases[node])

def tag_alias(self, node: SqlQueryPlanNode, column_alias: str) -> None: # noqa: D102
def add_alias(self, node: SqlQueryPlanNode, column_alias: str) -> None: # noqa: D102
return self._node_to_tagged_aliases[node].add(column_alias)

def tag_aliases(self, node: SqlQueryPlanNode, column_aliases: Iterable[str]) -> None: # noqa: D102
def add_aliases(self, node: SqlQueryPlanNode, column_aliases: Iterable[str]) -> None: # noqa: D102
self._node_to_tagged_aliases[node].update(column_aliases)

def tag_all_aliases_in_node(self, node: SqlQueryPlanNode) -> None:
"""Convenience method that tags all column aliases in the given SQL node, where appropriate."""
node.accept(_TagAllColumnAliasesInNodeVisitor(self))

def get_mapping(self) -> Mapping[SqlQueryPlanNode, FrozenSet[str]]:
"""Return mapping view that associates a given SQL node with the tagged column aliases in that node."""
return {node: frozenset(tagged_aliases) for node, tagged_aliases in self._node_to_tagged_aliases.items()}


class _TagAllColumnAliasesInNodeVisitor(SqlQueryPlanNodeVisitor[None]):
"""Visitor to help implement `TaggedColumnAliasSet.tag_all_aliases_in_node`."""

def __init__(self, required_column_alias_collector: TaggedColumnAliasSet) -> None:
self._required_column_alias_collector = required_column_alias_collector

@override
def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:
for select_column in node.select_columns:
self._required_column_alias_collector.tag_alias(
node=node,
column_alias=select_column.column_alias,
)

@override
def visit_table_node(self, node: SqlTableNode) -> None:
"""Columns in a SQL table are not represented."""
pass

@override
def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> None:
"""Columns in an arbitrary SQL query are not represented."""
pass

@override
def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> None:
for parent_node in node.parent_nodes:
parent_node.accept(self)

@override
def visit_cte_node(self, node: SqlCteNode) -> None:
for parent_node in node.parent_nodes:
parent_node.accept(self)
112 changes: 77 additions & 35 deletions metricflow/sql/optimizer/tag_required_column_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

import logging
from collections import defaultdict
from typing import Dict, List, Set, Tuple
from typing import Dict, FrozenSet, List, Set, Tuple

from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
from typing_extensions import override

from metricflow.sql.optimizer.tag_column_aliases import TaggedColumnAliasSet
from metricflow.sql.optimizer.tag_column_aliases import NodeToColumnAliasMapping
from metricflow.sql.sql_exprs import SqlExpressionTreeLineage
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
Expand All @@ -23,21 +23,29 @@
logger = logging.getLogger(__name__)


class SqlTagRequiredColumnAliasesVisitor(SqlQueryPlanNodeVisitor[None]):
"""To aid column pruning, traverse the SQL-query representation DAG and tag all column aliases that are required.
class SqlMapRequiredColumnAliasesVisitor(SqlQueryPlanNodeVisitor[None]):
"""To aid column pruning, traverse the SQL-query representation DAG and map the SELECT columns needed at each node.
For example, for the query:
For example, the query:
-- SELECT node_id="select_0"
SELECT source_0.col_0 AS col_0_renamed
FROM (
-- SELECT node_id="select_1
SELECT
example_table.col_0
example_table.col_1
FROM example_table_0
) source_0
The top-level SQL node would have the column alias `col_0_renamed` tagged, and the SQL node associated with
`source_0` would have `col_0` tagged. Once tagged, the information can be used to prune the columns in the SELECT:
would generate the mapping:
{
"select_0": {"col_0"},
"select_1": {"col_0"),
}
The mapping can be later used to rewrite the query to:
SELECT source_0.col_0 AS col_0_renamed
FROM (
Expand All @@ -47,14 +55,26 @@ class SqlTagRequiredColumnAliasesVisitor(SqlQueryPlanNodeVisitor[None]):
) source_0
"""

def __init__(self, tagged_column_alias_set: TaggedColumnAliasSet) -> None:
def __init__(self, start_node: SqlQueryPlanNode, required_column_aliases_in_start_node: FrozenSet[str]) -> None:
"""Initializer.
Args:
tagged_column_alias_set: Stores the set of columns that are tagged. This will be updated as the visitor
traverses the SQL-query representation DAG.
start_node: The node where the traversal by this visitor will start.
required_column_aliases_in_start_node: The column aliases at the `start_node` that are required.
"""
self._column_alias_tagger = tagged_column_alias_set
# Stores the mapping of the SQL node to the required column aliases. This will be updated as the visitor
# traverses the SQL-query representation DAG.
self._current_required_column_alias_mapping = NodeToColumnAliasMapping()
self._current_required_column_alias_mapping.add_aliases(start_node, required_column_aliases_in_start_node)

# Helps lookup the CTE node associated with a given CTE alias. A member variable is needed as any node in the
# SQL DAG can reference a CTE.
start_node_as_select_node = start_node.as_select_node
self._cte_alias_to_cte_node: Dict[str, SqlCteNode] = (
{cte_source.cte_alias: cte_source for cte_source in start_node_as_select_node.cte_sources}
if start_node_as_select_node is not None
else {}
)

def _search_for_expressions(
self, select_node: SqlSelectStatementNode, pruned_select_columns: Tuple[SqlSelectColumn, ...]
Expand Down Expand Up @@ -94,11 +114,9 @@ def _visit_parents(self, node: SqlQueryPlanNode) -> None:
parent_node.accept(self)
return

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: # noqa: D102
# Based on column aliases that are tagged in this SELECT statement, tag corresponding column aliases in
# parent nodes.

initial_required_column_aliases_in_this_node = self._column_alias_tagger.get_tagged_aliases(node)
def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:
"""Based on required column aliases for this SELECT, figure out required column aliases in parents."""
initial_required_column_aliases_in_this_node = self._current_required_column_alias_mapping.get_aliases(node)

# If this SELECT statement uses DISTINCT, all columns are required as removing them would change the meaning of
# the query.
Expand All @@ -121,20 +139,17 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: #
)
)
# Since additional select columns could have been selected due to DISTINCT or GROUP BY, re-tag.
self._column_alias_tagger.tag_aliases(node, updated_required_column_aliases_in_this_node)
self._current_required_column_alias_mapping.add_aliases(node, updated_required_column_aliases_in_this_node)

required_select_columns_in_this_node = tuple(
select_column
for select_column in node.select_columns
if select_column.column_alias in updated_required_column_aliases_in_this_node
)

# TODO: don't prune columns used in join condition! Tricky to derive since the join condition can be any
# SqlExpressionNode.

if len(required_select_columns_in_this_node) == 0:
raise RuntimeError(
"No columns are required in this node - this indicates a bug in this collector or in the inputs."
"No columns are required in this node - this indicates a bug in this visitor or in the inputs."
)

# Based on the expressions in this select statement, figure out what column aliases are needed in the sources of
Expand All @@ -144,25 +159,43 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: #
# If any of the string expressions don't have context on what columns are used in the expression, then it's
# impossible to know what columns can be pruned from the parent sources. Tag all columns in parents as required.
if any([string_expr.used_columns is None for string_expr in exprs_used_in_this_node.string_exprs]):
for parent_node in node.parent_nodes:
self._column_alias_tagger.tag_all_aliases_in_node(parent_node)
nodes_to_retain_all_columns = [node.from_source]
for join_desc in node.join_descs:
nodes_to_retain_all_columns.append(join_desc.right_source)

for node_to_retain_all_columns in nodes_to_retain_all_columns:
nearest_select_columns = node_to_retain_all_columns.nearest_select_columns({})
for select_column in nearest_select_columns or ():
self._current_required_column_alias_mapping.add_alias(
node=node_to_retain_all_columns, column_alias=select_column.column_alias
)

self._visit_parents(node)
return

# Create a mapping from the source alias to the column aliases needed from the corresponding source.
source_alias_to_required_column_alias: Dict[str, Set[str]] = defaultdict(set)
source_alias_to_required_column_aliases: Dict[str, Set[str]] = defaultdict(set)
for column_reference_expr in exprs_used_in_this_node.column_reference_exprs:
column_reference = column_reference_expr.col_ref
source_alias_to_required_column_alias[column_reference.table_alias].add(column_reference.column_name)
source_alias_to_required_column_aliases[column_reference.table_alias].add(column_reference.column_name)

logger.debug(
LazyFormat(
"Collected required column names from sources",
source_alias_to_required_column_aliases=source_alias_to_required_column_aliases,
)
)
# Appropriately tag the columns required in the parent nodes.
if node.from_source_alias in source_alias_to_required_column_alias:
aliases_required_in_parent = source_alias_to_required_column_alias[node.from_source_alias]
self._column_alias_tagger.tag_aliases(node=node.from_source, column_aliases=aliases_required_in_parent)
if node.from_source_alias in source_alias_to_required_column_aliases:
aliases_required_in_parent = source_alias_to_required_column_aliases[node.from_source_alias]
self._current_required_column_alias_mapping.add_aliases(
node=node.from_source, column_aliases=aliases_required_in_parent
)

for join_desc in node.join_descs:
if join_desc.right_source_alias in source_alias_to_required_column_alias:
aliases_required_in_parent = source_alias_to_required_column_alias[join_desc.right_source_alias]
self._column_alias_tagger.tag_aliases(
if join_desc.right_source_alias in source_alias_to_required_column_aliases:
aliases_required_in_parent = source_alias_to_required_column_aliases[join_desc.right_source_alias]
self._current_required_column_alias_mapping.add_aliases(
node=join_desc.right_source, column_aliases=aliases_required_in_parent
)
# TODO: Handle CTEs parent nodes.
Expand All @@ -172,17 +205,21 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: #
for string_expr in exprs_used_in_this_node.string_exprs:
if string_expr.used_columns:
for column_alias in string_expr.used_columns:
for parent_node in node.parent_nodes:
self._column_alias_tagger.tag_alias(parent_node, column_alias)
for node_to_retain_all_columns in (node.from_source,) + tuple(
join_desc.right_source for join_desc in node.join_descs
):
self._current_required_column_alias_mapping.add_alias(node_to_retain_all_columns, column_alias)

# Same with unqualified column references - it's hard to tell which source it came from, so it's safest to say
# it's required from all parents.
# An unqualified column reference expression is like `SELECT col_0` whereas a qualified column reference
# expression is like `SELECT table_0.col_0`.
for unqualified_column_reference_expr in exprs_used_in_this_node.column_alias_reference_exprs:
column_alias = unqualified_column_reference_expr.column_alias
for parent_node in node.parent_nodes:
self._column_alias_tagger.tag_alias(parent_node, column_alias)
for node_to_retain_all_columns in (node.from_source,) + tuple(
join_desc.right_source for join_desc in node.join_descs
):
self._current_required_column_alias_mapping.add_alias(node_to_retain_all_columns, column_alias)

# Visit recursively.
self._visit_parents(node)
Expand All @@ -198,3 +235,8 @@ def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> No

def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> None: # noqa: D102
return self._visit_parents(node)

@property
def required_column_alias_mapping(self) -> NodeToColumnAliasMapping:
"""Return the column aliases required at each node as determined after traversal."""
return self._current_required_column_alias_mapping
Loading

0 comments on commit dcd8fcf

Please sign in to comment.