diff --git a/snuba/query/mql/parser.py b/snuba/query/mql/parser.py index 06495e1eed2..2642b69cc5f 100644 --- a/snuba/query/mql/parser.py +++ b/snuba/query/mql/parser.py @@ -1067,7 +1067,7 @@ def transform(exp: Expression) -> Expression: def optimize_filter_in_select( query: CompositeQuery[QueryEntity] | LogicalQuery, ) -> None: - FilterInSelectOptimizer().process_mql_query(query) + FilterInSelectOptimizer().process_mql_query2(query) CustomProcessors = Sequence[ diff --git a/snuba/query/processors/logical/filter_in_select_optimizer.py b/snuba/query/processors/logical/filter_in_select_optimizer.py index 3118a27fb3f..a88967b3a4d 100644 --- a/snuba/query/processors/logical/filter_in_select_optimizer.py +++ b/snuba/query/processors/logical/filter_in_select_optimizer.py @@ -92,6 +92,64 @@ class FilterInSelectOptimizer: WHERE metric_id in (1,2,3,4) and status=200 """ + def process_mql_query2( + self, query: LogicalQuery | CompositeQuery[QueryEntity] + ) -> None: + feat_flag = get_int_config("enable_filter_in_select_optimizer", default=0) + if feat_flag: + try: + new_condition = self.get_select_filter(query) + if new_condition is not None: + query.add_condition_to_ast(new_condition) + except ValueError: + raise + + def get_select_filter( + self, + query: LogicalQuery | CompositeQuery[QueryEntity], + ) -> FunctionCall | None: + """ + Given a query, grabs all the conditions from conditional aggregates and lifts into + one condition. + + ex: SELECT sumIf(value, metric_id in (1,2,3,4) and status=200), + avgIf(value, metric_id in (11,12) and status=400), + returns or((metric_id in (1,2,3,4) and status=200), (metric_id in (11,12) and status=400)) + """ + # find and grab all the conditional aggregate functions + cond_agg_functions: list[FunctionCall | CurriedFunctionCall] = [] + for selected_exp in query.get_selected_columns(): + exp = selected_exp.expression + finder = FindConditionalAggregateFunctionsVisitor() + exp.accept(finder) + found = finder.get_matches() + if len(found) > 0: + if len(cond_agg_functions) > 0: + raise ValueError( + "expected only one selected column to contain conditional aggregate functions but found multiple" + ) + else: + cond_agg_functions = found + + if len(cond_agg_functions) == 0: + return None + + # validate the functions, and lift their conditions into new_condition, return it + new_condition = None + for func in cond_agg_functions: + if len(func.parameters) != 2 or not isinstance( + func.parameters[1], FunctionCall + ): + raise ValueError("unexpected form of function") + + if new_condition is None: + new_condition = deepcopy(func.parameters[1]) + else: + new_condition = binary_condition( + "or", deepcopy(func.parameters[1]), new_condition + ) + return new_condition + def process_mql_query( self, query: LogicalQuery | CompositeQuery[QueryEntity] ) -> None: @@ -130,6 +188,37 @@ def process_mql_query( query.add_condition_to_ast(domain_filter) metrics.increment("kyles_optimizer_optimized") + def get_select_filter_old( + self, + query: LogicalQuery | CompositeQuery[QueryEntity], + ) -> FunctionCall | None: + domain = self.get_domain_of_mql_query(query) + + if not domain: + return None + + # make the condition + domain_filter = None + for key, value in domain.items(): + clause = binary_condition( + "in", + key, + FunctionCall( + alias=None, + function_name="array", + parameters=tuple(value), + ), + ) + if not domain_filter: + domain_filter = clause + else: + domain_filter = binary_condition( + "and", + domain_filter, + clause, + ) + return domain_filter + def get_domain_of_mql_query( self, query: LogicalQuery | CompositeQuery[QueryEntity] ) -> Domain: diff --git a/tests/query/processors/test_filter_in_select_optimizer.py b/tests/query/processors/test_filter_in_select_optimizer.py index 57879f352dc..6f028a9ca3c 100644 --- a/tests/query/processors/test_filter_in_select_optimizer.py +++ b/tests/query/processors/test_filter_in_select_optimizer.py @@ -1,7 +1,14 @@ import pytest from snuba.datasets.factory import get_dataset -from snuba.query.expressions import Column, Literal, SubscriptableReference +from snuba.query.conditions import binary_condition +from snuba.query.expressions import ( + Column, + Expression, + FunctionCall, + Literal, + SubscriptableReference, +) from snuba.query.logical import Query from snuba.query.mql.parser import parse_mql_query from snuba.query.processors.logical.filter_in_select_optimizer import ( @@ -38,6 +45,9 @@ "limit": None, "offset": None, } +assert isinstance( + mql_context["indexer_mappings"], dict +) # oh mypy, my oh mypy, how you check types """ TEST CASES """ @@ -147,6 +157,239 @@ def subscriptable_reference(name: str, key: str) -> SubscriptableReference: ), ] + +def equals(lhs: Expression, rhs: Expression) -> FunctionCall: + return binary_condition("equals", lhs, rhs) + + +def _and(lhs: FunctionCall, rhs: FunctionCall) -> FunctionCall: + return binary_condition("and", lhs, rhs) + + +def _or(lhs: FunctionCall, rhs: FunctionCall) -> FunctionCall: + return binary_condition("or", lhs, rhs) + + +def _in(lhs: Expression, rhs: Expression) -> FunctionCall: + return binary_condition("in", lhs, rhs) + + +new_mql_test_cases: list[tuple[str, FunctionCall]] = [ + ( + "sum(`d:transactions/duration@millisecond`){status_code:200} / sum(`d:transactions/duration@second`)", + _or( + equals( + Column("_snuba_metric_id", None, "metric_id"), + Literal(None, 123456), + ), + _and( + equals( + subscriptable_reference( + "tags_raw", str(mql_context["indexer_mappings"]["status_code"]) + ), + Literal(None, "200"), + ), + equals( + Column("_snuba_metric_id", None, "metric_id"), + Literal(None, 123456), + ), + ), + ), + ), + ( + "sum(`d:transactions/duration@millisecond`){status_code:200} by transaction / sum(`d:transactions/duration@millisecond`) by transaction", + _or( + equals( + Column("_snuba_metric_id", None, "metric_id"), + Literal(None, 123456), + ), + _and( + equals( + subscriptable_reference( + "tags_raw", str(mql_context["indexer_mappings"]["status_code"]) + ), + Literal(None, "200"), + ), + equals( + Column("_snuba_metric_id", None, "metric_id"), + Literal(None, 123456), + ), + ), + ), + ), + ( + "quantiles(0.5)(`d:transactions/duration@millisecond`){status_code:200} by transaction / sum(`d:transactions/duration@millisecond`) by transaction", + _or( + equals( + Column("_snuba_metric_id", None, "metric_id"), + Literal(None, 123456), + ), + _and( + equals( + subscriptable_reference( + "tags_raw", str(mql_context["indexer_mappings"]["status_code"]) + ), + Literal(None, "200"), + ), + equals( + Column("_snuba_metric_id", None, "metric_id"), + Literal(None, 123456), + ), + ), + ), + ), + ( + "sum(`d:transactions/duration@millisecond`) / ((max(`d:transactions/duration@millisecond`) + avg(`d:transactions/duration@millisecond`)) * min(`d:transactions/duration@millisecond`))", + _or( + equals( + Column("_snuba_metric_id", None, "metric_id"), + Literal(None, 123456), + ), + _or( + equals( + Column("_snuba_metric_id", None, "metric_id"), + Literal(None, 123456), + ), + _or( + equals( + Column("_snuba_metric_id", None, "metric_id"), + Literal(None, 123456), + ), + equals( + Column("_snuba_metric_id", None, "metric_id"), + Literal(None, 123456), + ), + ), + ), + ), + ), + ( + "(sum(`d:transactions/duration@millisecond`) / max(`d:transactions/duration@millisecond`)){status_code:200}", + _or( + _and( + equals( + subscriptable_reference( + "tags_raw", str(mql_context["indexer_mappings"]["status_code"]) + ), + Literal(None, "200"), + ), + equals( + Column("_snuba_metric_id", None, "metric_id"), + Literal(None, 123456), + ), + ), + _and( + equals( + subscriptable_reference( + "tags_raw", str(mql_context["indexer_mappings"]["status_code"]) + ), + Literal(None, "200"), + ), + equals( + Column("_snuba_metric_id", None, "metric_id"), + Literal(None, 123456), + ), + ), + ), + ), + ( + "(sum(`d:transactions/duration@millisecond`) / max(`d:transactions/duration@millisecond`)){status_code:[400,404,500,501]}", + _or( + _and( + _in( + subscriptable_reference( + "tags_raw", str(mql_context["indexer_mappings"]["status_code"]) + ), + FunctionCall( + None, + "array", + ( + Literal(None, "400"), + Literal(None, "404"), + Literal(None, "500"), + Literal(None, "501"), + ), + ), + ), + equals( + Column("_snuba_metric_id", None, "metric_id"), + Literal(None, 123456), + ), + ), + _and( + _in( + subscriptable_reference( + "tags_raw", str(mql_context["indexer_mappings"]["status_code"]) + ), + FunctionCall( + None, + "array", + ( + Literal(None, "400"), + Literal(None, "404"), + Literal(None, "500"), + Literal(None, "501"), + ), + ), + ), + equals( + Column("_snuba_metric_id", None, "metric_id"), + Literal(None, 123456), + ), + ), + ), + ), + ( + "(sum(`d:transactions/duration@millisecond`) / max(`d:transactions/duration@millisecond`)){status_code:200} by transaction", + _or( + _and( + equals( + subscriptable_reference( + "tags_raw", str(mql_context["indexer_mappings"]["status_code"]) + ), + Literal(None, "200"), + ), + equals( + Column("_snuba_metric_id", None, "metric_id"), + Literal(None, 123456), + ), + ), + _and( + equals( + subscriptable_reference( + "tags_raw", str(mql_context["indexer_mappings"]["status_code"]) + ), + Literal(None, "200"), + ), + equals( + Column("_snuba_metric_id", None, "metric_id"), + Literal(None, 123456), + ), + ), + ), + ), + ( + "(sum(`d:transactions/duration@millisecond`) / sum(`d:transactions/duration@millisecond`)) + 100", + _or( + equals( + Column("_snuba_metric_id", None, "metric_id"), + Literal(None, 123456), + ), + equals( + Column("_snuba_metric_id", None, "metric_id"), + Literal(None, 123456), + ), + ), + ), + ( + "sum(`d:transactions/duration@millisecond`) * 1000", + equals( + Column("_snuba_metric_id", None, "metric_id"), + Literal(None, 123456), + ), + ), +] + """ TESTING """ optimizer = FilterInSelectOptimizer() @@ -183,3 +426,17 @@ def test_searcher(mql_query: str, expected_domain: set[int]) -> None: selected_expression.expression.accept(v) newres = len(v.get_matches()) > 0 assert newres == oldres + + +@pytest.mark.parametrize( + "mql_query, expected_condition", + new_mql_test_cases, +) +def test_new_pipeline(mql_query: str, expected_condition: FunctionCall) -> None: + logical_query, _ = parse_mql_query(str(mql_query), mql_context, generic_metrics) + assert isinstance(logical_query, Query) + + opt = FilterInSelectOptimizer() + actual = opt.get_select_filter(logical_query) + if actual != expected_condition: + assert actual != expected_condition