Skip to content

Commit

Permalink
optimizer rewrite + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kylemumma committed Mar 21, 2024
1 parent a811de4 commit 5cd79aa
Show file tree
Hide file tree
Showing 3 changed files with 348 additions and 2 deletions.
2 changes: 1 addition & 1 deletion snuba/query/mql/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,7 @@ def transform(exp: Expression) -> Expression:
def optimize_filter_in_select(
    query: CompositeQuery[QueryEntity] | LogicalQuery,
) -> None:
    """Run the filter-in-select optimizer over ``query``, mutating it in place.

    Only the rewritten entry point (``process_mql_query2``) is invoked.
    Calling the legacy ``process_mql_query`` as well would run the optimizer
    over the same query twice.
    """
    FilterInSelectOptimizer().process_mql_query2(query)


CustomProcessors = Sequence[
Expand Down
89 changes: 89 additions & 0 deletions snuba/query/processors/logical/filter_in_select_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,64 @@ class FilterInSelectOptimizer:
WHERE metric_id in (1,2,3,4) and status=200
"""

def process_mql_query2(
    self, query: LogicalQuery | CompositeQuery[QueryEntity]
) -> None:
    """Add the lifted select-clause filter to ``query`` when the feature is on.

    Does nothing unless the ``enable_filter_in_select_optimizer`` runtime
    config is truthy. Otherwise the condition returned by
    ``get_select_filter`` (if any) is added to the query via
    ``add_condition_to_ast``.

    Raises:
        ValueError: propagated unchanged from ``get_select_filter`` when the
            select clause has a shape it does not accept.
    """
    # Guard clause: feature flag off means the query is left untouched.
    if not get_int_config("enable_filter_in_select_optimizer", default=0):
        return

    # No try/except here: a handler that only re-raises is a no-op, so the
    # ValueError is simply allowed to propagate to the caller.
    new_condition = self.get_select_filter(query)
    if new_condition is not None:
        query.add_condition_to_ast(new_condition)

def get_select_filter(
    self,
    query: LogicalQuery | CompositeQuery[QueryEntity],
) -> FunctionCall | None:
    """
    Collect the conditions of every conditional aggregate in the select
    clause and combine them into a single "or" condition.

    ex: SELECT sumIf(value, metric_id in (1,2,3,4) and status=200),
               avgIf(value, metric_id in (11,12) and status=400),
    returns or((metric_id in (1,2,3,4) and status=200), (metric_id in (11,12) and status=400))

    Returns None when no selected column contains a conditional aggregate.

    Raises ValueError when more than one selected column contains
    conditional aggregates, or when a matched function does not have exactly
    two parameters with a FunctionCall as the second one.
    """
    # Step 1: locate the (single) selected column holding conditional aggregates.
    matches: list[FunctionCall | CurriedFunctionCall] = []
    for selected in query.get_selected_columns():
        visitor = FindConditionalAggregateFunctionsVisitor()
        selected.expression.accept(visitor)
        hits = visitor.get_matches()
        if not hits:
            continue
        if matches:
            raise ValueError(
                "expected only one selected column to contain conditional aggregate functions but found multiple"
            )
        matches = hits

    if not matches:
        return None

    # Step 2: validate each aggregate and fold its condition into the result.
    lifted: FunctionCall | None = None
    for aggregate in matches:
        params = aggregate.parameters
        if not (len(params) == 2 and isinstance(params[1], FunctionCall)):
            raise ValueError("unexpected form of function")

        # Copy so the lifted filter does not share nodes with the select clause.
        predicate = deepcopy(params[1])
        # The newest predicate goes on the left of the accumulated "or" chain.
        lifted = (
            predicate
            if lifted is None
            else binary_condition("or", predicate, lifted)
        )
    return lifted

def process_mql_query(
self, query: LogicalQuery | CompositeQuery[QueryEntity]
) -> None:
Expand Down Expand Up @@ -130,6 +188,37 @@ def process_mql_query(
query.add_condition_to_ast(domain_filter)
metrics.increment("kyles_optimizer_optimized")

def get_select_filter_old(
    self,
    query: LogicalQuery | CompositeQuery[QueryEntity],
) -> FunctionCall | None:
    """Build a filter from the domain computed by ``get_domain_of_mql_query``.

    Each ``key -> values`` entry of the domain becomes ``key in array(values)``;
    the per-key clauses are chained together with "and" in the domain's
    iteration order. Returns None when the domain is empty.
    """
    domain = self.get_domain_of_mql_query(query)
    if not domain:
        return None

    # One "in" clause per domain entry.
    clauses = [
        binary_condition(
            "in",
            column,
            FunctionCall(
                alias=None,
                function_name="array",
                parameters=tuple(allowed_values),
            ),
        )
        for column, allowed_values in domain.items()
    ]

    # Left-fold the clauses: and(and(c0, c1), c2) ...
    combined = clauses[0]
    for clause in clauses[1:]:
        combined = binary_condition("and", combined, clause)
    return combined

def get_domain_of_mql_query(
self, query: LogicalQuery | CompositeQuery[QueryEntity]
) -> Domain:
Expand Down
259 changes: 258 additions & 1 deletion tests/query/processors/test_filter_in_select_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import pytest

from snuba.datasets.factory import get_dataset
from snuba.query.expressions import Column, Literal, SubscriptableReference
from snuba.query.conditions import binary_condition
from snuba.query.expressions import (
Column,
Expression,
FunctionCall,
Literal,
SubscriptableReference,
)
from snuba.query.logical import Query
from snuba.query.mql.parser import parse_mql_query
from snuba.query.processors.logical.filter_in_select_optimizer import (
Expand Down Expand Up @@ -38,6 +45,9 @@
"limit": None,
"offset": None,
}
# Type-narrowing assertion: it lets mypy treat mql_context["indexer_mappings"]
# as a dict so that later module-level lookups into it type-check. It also
# fails fast at import time if the fixture is ever changed to a non-dict.
assert isinstance(
    mql_context["indexer_mappings"], dict
)

""" TEST CASES """

Expand Down Expand Up @@ -147,6 +157,239 @@ def subscriptable_reference(name: str, key: str) -> SubscriptableReference:
),
]


def equals(lhs: Expression, rhs: Expression) -> FunctionCall:
    """Shorthand for the binary ``equals`` condition over ``lhs`` and ``rhs``."""
    condition = binary_condition("equals", lhs, rhs)
    return condition


def _and(lhs: FunctionCall, rhs: FunctionCall) -> FunctionCall:
    """Shorthand for the binary ``and`` condition over ``lhs`` and ``rhs``."""
    conjunction = binary_condition("and", lhs, rhs)
    return conjunction


def _or(lhs: FunctionCall, rhs: FunctionCall) -> FunctionCall:
    """Shorthand for the binary ``or`` condition over ``lhs`` and ``rhs``."""
    disjunction = binary_condition("or", lhs, rhs)
    return disjunction


def _in(lhs: Expression, rhs: Expression) -> FunctionCall:
    """Shorthand for the binary ``in`` condition over ``lhs`` and ``rhs``."""
    membership = binary_condition("in", lhs, rhs)
    return membership


# Tag key under which the test context maps "status_code". Resolved once at
# module level, in the same scope as the original inline lookups.
_STATUS_CODE_TAG_KEY = str(mql_context["indexer_mappings"]["status_code"])


def _metric_id_cond() -> FunctionCall:
    """Return ``metric_id = 123456``, the metric filter used by every case below."""
    return equals(
        Column("_snuba_metric_id", None, "metric_id"),
        Literal(None, 123456),
    )


def _status_code_ref() -> SubscriptableReference:
    """Return the ``tags_raw[...]`` reference for the status_code tag."""
    return subscriptable_reference("tags_raw", _STATUS_CODE_TAG_KEY)


def _status_equals_cond(status: str) -> FunctionCall:
    """Return ``status_code = <status> AND metric_id = 123456``."""
    return _and(
        equals(_status_code_ref(), Literal(None, status)),
        _metric_id_cond(),
    )


def _status_in_cond(*statuses: str) -> FunctionCall:
    """Return ``status_code IN array(<statuses>) AND metric_id = 123456``."""
    return _and(
        _in(
            _status_code_ref(),
            FunctionCall(
                None,
                "array",
                tuple(Literal(None, status) for status in statuses),
            ),
        ),
        _metric_id_cond(),
    )


# (mql query, condition expected from FilterInSelectOptimizer.get_select_filter).
# The builders above return a fresh expression on every call, so no expression
# object is shared between two test cases.
new_mql_test_cases: list[tuple[str, FunctionCall]] = [
    # A filter on only one operand of the formula.
    (
        "sum(`d:transactions/duration@millisecond`){status_code:200} / sum(`d:transactions/duration@second`)",
        _or(_metric_id_cond(), _status_equals_cond("200")),
    ),
    (
        "sum(`d:transactions/duration@millisecond`){status_code:200} by transaction / sum(`d:transactions/duration@millisecond`) by transaction",
        _or(_metric_id_cond(), _status_equals_cond("200")),
    ),
    (
        "quantiles(0.5)(`d:transactions/duration@millisecond`){status_code:200} by transaction / sum(`d:transactions/duration@millisecond`) by transaction",
        _or(_metric_id_cond(), _status_equals_cond("200")),
    ),
    # Four unfiltered aggregates: a right-nested chain of "or".
    (
        "sum(`d:transactions/duration@millisecond`) / ((max(`d:transactions/duration@millisecond`) + avg(`d:transactions/duration@millisecond`)) * min(`d:transactions/duration@millisecond`))",
        _or(
            _metric_id_cond(),
            _or(
                _metric_id_cond(),
                _or(_metric_id_cond(), _metric_id_cond()),
            ),
        ),
    ),
    # A filter applied to the whole formula shows up on both operands.
    (
        "(sum(`d:transactions/duration@millisecond`) / max(`d:transactions/duration@millisecond`)){status_code:200}",
        _or(_status_equals_cond("200"), _status_equals_cond("200")),
    ),
    (
        "(sum(`d:transactions/duration@millisecond`) / max(`d:transactions/duration@millisecond`)){status_code:[400,404,500,501]}",
        _or(
            _status_in_cond("400", "404", "500", "501"),
            _status_in_cond("400", "404", "500", "501"),
        ),
    ),
    (
        "(sum(`d:transactions/duration@millisecond`) / max(`d:transactions/duration@millisecond`)){status_code:200} by transaction",
        _or(_status_equals_cond("200"), _status_equals_cond("200")),
    ),
    # Formulas mixing aggregates with scalar literals.
    (
        "(sum(`d:transactions/duration@millisecond`) / sum(`d:transactions/duration@millisecond`)) + 100",
        _or(_metric_id_cond(), _metric_id_cond()),
    ),
    (
        "sum(`d:transactions/duration@millisecond`) * 1000",
        _metric_id_cond(),
    ),
]

""" TESTING """

optimizer = FilterInSelectOptimizer()
Expand Down Expand Up @@ -183,3 +426,17 @@ def test_searcher(mql_query: str, expected_domain: set[int]) -> None:
selected_expression.expression.accept(v)
newres = len(v.get_matches()) > 0
assert newres == oldres


@pytest.mark.parametrize(
    "mql_query, expected_condition",
    new_mql_test_cases,
)
def test_new_pipeline(mql_query: str, expected_condition: FunctionCall) -> None:
    """Check that ``get_select_filter`` lifts the expected condition from each query."""
    logical_query, _ = parse_mql_query(mql_query, mql_context, generic_metrics)
    assert isinstance(logical_query, Query)

    opt = FilterInSelectOptimizer()
    actual = opt.get_select_filter(logical_query)
    # The earlier form (`if actual != expected: assert actual != expected`)
    # could never fail; assert equality so a mismatch actually fails the test.
    assert actual == expected_condition

0 comments on commit 5cd79aa

Please sign in to comment.