Skip to content

Commit

Permalink
bug(mql): Fix MQL formula queries with totals and no groupby (#6136)
Browse files Browse the repository at this point in the history
* fix totals queries

* clean up

* fix test

* typing

* add test

* fix

* add test

* fix test

* fix parser tests
  • Loading branch information
enochtangg authored Jul 24, 2024
1 parent ac18814 commit 57f2059
Show file tree
Hide file tree
Showing 10 changed files with 588 additions and 180 deletions.
15 changes: 14 additions & 1 deletion snuba/clickhouse/formatter/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
from snuba.query import ProcessableQuery
from snuba.query import Query as AbstractQuery
from snuba.query.composite import CompositeQuery
from snuba.query.data_source.join import IndividualNode, JoinClause, JoinVisitor
from snuba.query.data_source.join import (
IndividualNode,
JoinClause,
JoinType,
JoinVisitor,
)
from snuba.query.data_source.simple import Table
from snuba.query.data_source.visitor import DataSourceVisitor
from snuba.query.expressions import Expression, ExpressionVisitor
Expand Down Expand Up @@ -271,6 +276,14 @@ def visit_individual_node(self, node: IndividualNode[Table]) -> FormattedNode:
def visit_join_clause(self, node: JoinClause[Table]) -> FormattedNode:
join_type = f"{node.join_type.value} " if node.join_type else ""
modifier = f"{node.join_modifier.value} " if node.join_modifier else ""
if node.join_type == JoinType.CROSS:
return SequenceNode(
[
node.left_node.accept(self),
StringNode(f"{modifier}{join_type}JOIN"),
node.right_node.accept(self),
]
)
return SequenceNode(
[
node.left_node.accept(self),
Expand Down
1 change: 1 addition & 0 deletions snuba/query/data_source/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
class JoinType(Enum):
    """SQL join types that a JoinClause can carry.

    The enum value is the exact keyword emitted into the formatted SQL.
    """

    INNER = "INNER"
    LEFT = "LEFT"
    # CROSS joins carry no ON conditions: formatters that see this type
    # emit only "<left> CROSS JOIN <right>" with no join keys.
    CROSS = "CROSS"


class JoinModifier(Enum):
Expand Down
48 changes: 29 additions & 19 deletions snuba/query/formatters/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

from snuba.query import ProcessableQuery
from snuba.query.composite import CompositeQuery
from snuba.query.data_source.join import IndividualNode, JoinClause, JoinVisitor
from snuba.query.data_source.join import (
IndividualNode,
JoinClause,
JoinType,
JoinVisitor,
)
from snuba.query.data_source.simple import SimpleDataSource
from snuba.query.data_source.visitor import DataSourceVisitor
from snuba.query.expressions import StringifyVisitor
Expand Down Expand Up @@ -135,22 +140,27 @@ def visit_individual_node(
return [f"{self.visit(node.data_source)} AS `{node.alias}`"]

def visit_join_clause(self, node: JoinClause[SimpleDataSource]) -> List[str]:
    """Render a join clause as an indented list of tracing strings.

    CROSS joins have no join conditions, so only the two joined nodes and
    the JOIN keyword are emitted. Every other join type additionally emits
    an ON section built from the join keys.

    NOTE: this span was reconstructed from garbled diff residue — the
    scraped source interleaved the pre- and post-change bodies; this is
    the post-change (CROSS-join-aware) implementation.
    """
    if node.join_type == JoinType.CROSS:
        return [
            *_indent_str_list(node.left_node.accept(self), 1),
            f"{node.join_type.name.upper()} JOIN",
            *_indent_str_list(node.right_node.accept(self), 1),
        ]
    # NOTE(review): only the first join condition is rendered — assumes a
    # single key pair in the ON clause; confirm against JoinClause usage.
    on_list = [
        [
            f"{c.left.table_alias}.{c.left.column}",
            f"{c.right.table_alias}.{c.right.column}",
        ]
        for c in node.keys
    ][0]
    return [
        *_indent_str_list(node.left_node.accept(self), 1),
        f"{node.join_type.name.upper()} JOIN",
        *_indent_str_list(node.right_node.accept(self), 1),
        "ON",
        *_indent_str_list(
            on_list,
            1,
        ),
    ]
14 changes: 2 additions & 12 deletions snuba/query/joins/metrics_subquery_generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from dataclasses import replace
from typing import Generator, Mapping

from snuba.query import ProcessableQuery, SelectedExpression
Expand Down Expand Up @@ -257,17 +256,8 @@ def generate_metrics_subqueries(query: CompositeQuery[Entity]) -> None:

for e in query.get_groupby():
_process_root_groupby(query, e, subqueries, alias_generator)

# Since groupbys are pushed down, we don't need them in the outer query.
query.set_ast_groupby([])

query.set_ast_orderby(
[
replace(
orderby,
expression=_process_root(
orderby.expression, subqueries, alias_generator
),
)
for orderby in query.get_orderby()
]
)
query.set_from_clause(SubqueriesReplacer(subqueries).visit_join_clause(from_clause))
40 changes: 23 additions & 17 deletions snuba/query/mql/context_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def scope_conditions(

def rollup_expressions(
mql_context: MQLContext, table_name: str | None = None
) -> tuple[Expression, bool, OrderBy | None, SelectedExpression]:
) -> tuple[Expression, bool, OrderBy | None, SelectedExpression | None]:
"""
This function returns four values based on the rollup field in the MQL context:
- granularity_condition: an expression that filters the granularity column based on the granularity in the MQL context
Expand All @@ -114,6 +114,12 @@ def rollup_expressions(
Literal(None, rollup.granularity),
)

# Validate totals/interval
if rollup.interval is None and rollup.with_totals in (None, "False"):
raise ParsingException(
"either interval or with_totals must be specified in rollup"
)

# Validate totals/orderby
if rollup.with_totals is not None and rollup.with_totals not in ("True", "False"):
raise ParsingException("with_totals must be a string, either 'True' or 'False'")
Expand All @@ -130,27 +136,27 @@ def rollup_expressions(
)

with_totals = rollup.with_totals == "True"
selected_time = None
orderby = None

prefix = "" if not table_name else f"{table_name}."
time_expression = FunctionCall(
f"{prefix}time",
"toStartOfInterval",
parameters=(
Column(None, table_name, "timestamp"),
FunctionCall(
None,
"toIntervalSecond",
(Literal(None, rollup.interval),),
),
Literal(None, "Universal"),
),
)
selected_time = SelectedExpression("time", time_expression)

if rollup.interval:
# If an interval is specified, then we need to group the time by that interval,
# return the time in the select, and order the results by that time.
prefix = "" if not table_name else f"{table_name}."
time_expression = FunctionCall(
f"{prefix}time",
"toStartOfInterval",
parameters=(
Column(None, table_name, "timestamp"),
FunctionCall(
None,
"toIntervalSecond",
(Literal(None, rollup.interval),),
),
Literal(None, "Universal"),
),
)
selected_time = SelectedExpression("time", time_expression)
orderby = OrderBy(OrderByDirection.ASC, time_expression)
elif rollup.orderby is not None:
direction = (
Expand Down
63 changes: 49 additions & 14 deletions snuba/query/mql/parser_supported_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -1304,22 +1304,20 @@ def populate_query_from_mql_context(
query.set_totals(with_totals)
if orderby:
query.set_ast_orderby([orderby])
query.set_ast_selected_columns(
list(query.get_selected_columns()) + [selected_time]
)

groupby = query.get_groupby()
if groupby:
query.set_ast_groupby(list(groupby) + [selected_time.expression])
else:
query.set_ast_groupby([selected_time.expression])
if selected_time:
query.set_ast_selected_columns(
list(query.get_selected_columns()) + [selected_time]
)
groupby = query.get_groupby()
if groupby:
query.set_ast_groupby(list(groupby) + [selected_time.expression])
else:
query.set_ast_groupby([selected_time.expression])

if isinstance(query, CompositeQuery):
# If the query is grouping by time, that needs to be added to the JoinClause keys to
# ensure we correctly join the subqueries. The column names will be the same for all the
# subqueries, so we just need to map all the table aliases.

def add_join_keys(join_clause: JoinClause[Any]) -> str:
def add_time_join_keys(join_clause: JoinClause[Any]) -> str:
match (join_clause.left_node, join_clause.right_node):
case (
IndividualNode(alias=left),
Expand All @@ -1340,7 +1338,7 @@ def add_join_keys(join_clause: JoinClause[Any]) -> str:
JoinClause() as inner_join_clause,
IndividualNode(alias=right),
):
left_alias = add_join_keys(inner_join_clause)
left_alias = add_time_join_keys(inner_join_clause)
join_clause.keys.append(
JoinCondition(
left=JoinConditionExpression(
Expand All @@ -1354,7 +1352,44 @@ def add_join_keys(join_clause: JoinClause[Any]) -> str:
)
return right

add_join_keys(join_clause)
def convert_to_cross_join(join_clause: JoinClause[Any]) -> JoinClause[Any]:
match (join_clause.left_node, join_clause.right_node):
case (
IndividualNode(),
IndividualNode(),
):
join_clause = replace(join_clause, join_type=JoinType.CROSS)
case (
JoinClause() as inner_join_clause,
IndividualNode(),
):
new_inner_join_clause = add_time_join_keys(inner_join_clause)
join_clause = replace(join_clause, left_node=new_inner_join_clause)
return join_clause

# Check if groupby is empty or has a one-sided groupby on the formula
number_of_joins = len(alias_node_map.keys())
number_of_groupbys = len(query.get_groupby())

no_groupby_or_one_sided_groupby = False
if number_of_groupbys == 0:
no_groupby_or_one_sided_groupby = True
elif number_of_groupbys % number_of_joins != 0:
no_groupby_or_one_sided_groupby = True

if selected_time:
# If the query is grouping by time, that needs to be added to the JoinClause keys to
# ensure we correctly join the subqueries. The column names will be the same for all the
# subqueries, so we just need to map all the table aliases.
add_time_join_keys(join_clause)
elif query.has_totals() and no_groupby_or_one_sided_groupby:
# If a formula query has no interval and either no group by or a one-sided
# group by, but does have totals, we need to convert the join type to a
# CROSS join. This is because without a group by, each sub-query returns a
# single row with a single value column. To combine those results in the
# outer query we must cross join the single-row results, since there are
# no conditions to join on.
join_clause = convert_to_cross_join(join_clause)
query.set_from_clause(join_clause)

limit = limit_value(mql_context)
offset = offset_value(mql_context)
Expand Down
7 changes: 6 additions & 1 deletion tests/query/joins/test_metrics_subqueries.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,12 @@ def test_subquery_generator_metrics() -> None:
expected_outer_query_orderby = [
OrderBy(
direction=OrderByDirection.ASC,
expression=column("_snuba_d0.time", "d0", "_snuba_d0.time"),
expression=f.toStartOfInterval(
column("timestamp", "d0", "_snuba_timestamp"),
f.toIntervalSecond(literal(60)),
literal("Universal"),
alias="_snuba_d0.time",
),
)
]
assert original_query.get_orderby() == expected_outer_query_orderby
Expand Down
58 changes: 16 additions & 42 deletions tests/query/parser/test_formula_mql_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1254,7 +1254,7 @@ def test_curried_aggregate_formula() -> None:
assert eq, reason


def test_formula_with_totals() -> None:
def test_formula_no_groupby_no_interval_with_totals() -> None:
mql_context_new = deepcopy(mql_context)
mql_context_new["rollup"]["with_totals"] = "True"
mql_context_new["rollup"]["interval"] = None
Expand Down Expand Up @@ -1286,13 +1286,8 @@ def test_formula_with_totals() -> None:
alias="d0",
data_source=from_distributions,
),
keys=[
JoinCondition(
left=JoinConditionExpression(table_alias="d1", column="d1.time"),
right=JoinConditionExpression(table_alias="d0", column="d0.time"),
)
],
join_type=JoinType.INNER,
keys=[],
join_type=JoinType.CROSS,
join_modifier=None,
)

Expand All @@ -1309,18 +1304,8 @@ def test_formula_with_totals() -> None:

expected = CompositeQuery(
from_clause=join_clause,
selected_columns=[
expected_selected,
SelectedExpression(
"time",
time_expression("d1", None),
),
SelectedExpression(
"time",
time_expression("d0", None),
),
],
groupby=[time_expression("d1", None), time_expression("d0", None)],
selected_columns=[expected_selected],
groupby=[],
condition=formula_condition,
order_by=[],
limit=1000,
Expand All @@ -1336,10 +1321,11 @@ def test_formula_with_totals() -> None:
assert eq, reason


def test_formula_with_totals_and_interval() -> None:
def test_formula_onesided_groupby_no_interval_with_totals() -> None:
mql_context_new = deepcopy(mql_context)
mql_context_new["rollup"]["with_totals"] = "True"
query_body = "sum(`d:transactions/duration@millisecond`){status_code:200} / sum(`d:transactions/duration@millisecond`)"
mql_context_new["rollup"]["interval"] = None
query_body = "sum(`d:transactions/duration@millisecond`){status_code:200} by transaction / sum(`d:transactions/duration@millisecond`)"

expected_selected = SelectedExpression(
"aggregate_value",
Expand Down Expand Up @@ -1367,13 +1353,8 @@ def test_formula_with_totals_and_interval() -> None:
alias="d0",
data_source=from_distributions,
),
keys=[
JoinCondition(
left=JoinConditionExpression(table_alias="d1", column="d1.time"),
right=JoinConditionExpression(table_alias="d0", column="d0.time"),
)
],
join_type=JoinType.INNER,
keys=[],
join_type=JoinType.CROSS,
join_modifier=None,
)

Expand All @@ -1393,22 +1374,15 @@ def test_formula_with_totals_and_interval() -> None:
selected_columns=[
expected_selected,
SelectedExpression(
"time",
time_expression("d1"),
),
SelectedExpression(
"time",
time_expression("d0"),
"transaction",
subscriptable_expression("333333", "d0"),
),
],
groupby=[time_expression("d1"), time_expression("d0")],
condition=formula_condition,
order_by=[
OrderBy(
direction=OrderByDirection.ASC,
expression=time_expression("d0"),
),
groupby=[
subscriptable_expression("333333", "d0"),
],
condition=formula_condition,
order_by=[],
limit=1000,
offset=0,
totals=True,
Expand Down
Loading

0 comments on commit 57f2059

Please sign in to comment.