From aa14b01fdde6999dd10a6a01a067b3441fdcb7a1 Mon Sep 17 00:00:00 2001 From: Kyle Mumma Date: Thu, 21 Mar 2024 10:33:23 -0700 Subject: [PATCH] removed old optimizer, on by default --- snuba/query/mql/parser.py | 2 +- .../logical/filter_in_select_optimizer.py | 240 +----------------- .../test_filter_in_select_optimizer.py | 135 +--------- 3 files changed, 9 insertions(+), 368 deletions(-) diff --git a/snuba/query/mql/parser.py b/snuba/query/mql/parser.py index 2642b69cc5f..06495e1eed2 100644 --- a/snuba/query/mql/parser.py +++ b/snuba/query/mql/parser.py @@ -1067,7 +1067,7 @@ def transform(exp: Expression) -> Expression: def optimize_filter_in_select( query: CompositeQuery[QueryEntity] | LogicalQuery, ) -> None: - FilterInSelectOptimizer().process_mql_query2(query) + FilterInSelectOptimizer().process_mql_query(query) CustomProcessors = Sequence[ diff --git a/snuba/query/processors/logical/filter_in_select_optimizer.py b/snuba/query/processors/logical/filter_in_select_optimizer.py index a88967b3a4d..1d31d8fa8d5 100644 --- a/snuba/query/processors/logical/filter_in_select_optimizer.py +++ b/snuba/query/processors/logical/filter_in_select_optimizer.py @@ -9,7 +9,6 @@ Argument, Column, CurriedFunctionCall, - Expression, ExpressionVisitor, FunctionCall, Lambda, @@ -20,18 +19,6 @@ from snuba.state import get_int_config from snuba.utils.metrics.wrapper import MetricsWrapper -""" -Domain maps from a property to the specific values that are being filtered for. 
Ex: - sumIf(value, metric_id in [1,2,3]) - Domain = {metric_id: {1,2,3}} - sumIf(value, metric_id=5 and status=200) / sumIf(value, metric_id=5 and status=400) - Domain = { - metric_id: {5}, - status: {200, 400} - } -""" -Domain = dict[Column | SubscriptableReference, set[Literal]] - metrics = MetricsWrapper(environment.metrics, "api") logger = logging.getLogger(__name__) @@ -79,7 +66,8 @@ def visit_lambda(self, exp: Lambda) -> None: class FilterInSelectOptimizer: """ - This optimizer takes queries that filter in the select clause (via conditional aggregate functions), + This optimizer lifts conditions in the select clause into the where clause, + that is, it finds the filters inside conditional aggregate functions and adds the equivalent conditions to the where clause. Example: SELECT sumIf(value, metric_id in (1,2,3,4) and status=200) @@ -92,10 +80,10 @@ class FilterInSelectOptimizer: WHERE metric_id in (1,2,3,4) and status=200 """ - def process_mql_query2( + def process_mql_query( self, query: LogicalQuery | CompositeQuery[QueryEntity] ) -> None: - feat_flag = get_int_config("enable_filter_in_select_optimizer", default=0) + feat_flag = get_int_config("enable_filter_in_select_optimizer", default=1) if feat_flag: try: new_condition = self.get_select_filter(query) @@ -113,7 +101,9 @@ def get_select_filter( one condition. ex: SELECT sumIf(value, metric_id in (1,2,3,4) and status=200), - avgIf(value, metric_id in (11,12) and status=400), + avgIf(value, metric_id in (11,12) and status=400), + ... 
+ returns or((metric_id in (1,2,3,4) and status=200), (metric_id in (11,12) and status=400)) """ # find and grab all the conditional aggregate functions @@ -149,219 +139,3 @@ def get_select_filter( "or", deepcopy(func.parameters[1]), new_condition ) return new_condition - - def process_mql_query( - self, query: LogicalQuery | CompositeQuery[QueryEntity] - ) -> None: - feat_flag = get_int_config("enable_filter_in_select_optimizer", default=0) - if feat_flag: - try: - domain = self.get_domain_of_mql_query(query) - except ValueError: - logger.warning( - "Failed getting domain", exc_info=True, extra={"query": query} - ) - domain = {} - - if domain: - # add domain to where clause - domain_filter = None - for key, value in domain.items(): - clause = binary_condition( - "in", - key, - FunctionCall( - alias=None, - function_name="array", - parameters=tuple(value), - ), - ) - if not domain_filter: - domain_filter = clause - else: - domain_filter = binary_condition( - "and", - domain_filter, - clause, - ) - assert domain_filter is not None - query.add_condition_to_ast(domain_filter) - metrics.increment("kyles_optimizer_optimized") - - def get_select_filter_old( - self, - query: LogicalQuery | CompositeQuery[QueryEntity], - ) -> FunctionCall | None: - domain = self.get_domain_of_mql_query(query) - - if not domain: - return None - - # make the condition - domain_filter = None - for key, value in domain.items(): - clause = binary_condition( - "in", - key, - FunctionCall( - alias=None, - function_name="array", - parameters=tuple(value), - ), - ) - if not domain_filter: - domain_filter = clause - else: - domain_filter = binary_condition( - "and", - domain_filter, - clause, - ) - return domain_filter - - def get_domain_of_mql_query( - self, query: LogicalQuery | CompositeQuery[QueryEntity] - ) -> Domain: - """ - This function returns the metric_id domain of the given query. 
- For a definition of metric_id domain, go to definition of the return type of this function ('Domain') - """ - expressions = map(lambda x: x.expression, query.get_selected_columns()) - target_exp = None - for exp in expressions: - if self._contains_conditional_aggregate(exp): - if target_exp is not None: - raise ValueError( - "Was expecting only 1 select expression to contain condition aggregate but found multiple" - ) - else: - target_exp = exp - - if target_exp is not None: - domains = self._get_conditional_domains(target_exp) - if len(domains) == 0: - raise ValueError("This shouldnt happen bc there is a target_exp") - - # find the intersect of keys, across the domains of all conditional aggregates - key_intersect = set(domains[0].keys()) - for i in range(1, len(domains)): - domain = domains[i] - key_intersect = key_intersect.intersection(set(domains[i].keys())) - - # union the domains - domain_union: Domain = {} - for key in key_intersect: - domain_union[key] = set() - for domain in domains: - domain_union[key] = domain_union[key].union(domain[key]) - - return domain_union - else: - return {} - - def _contains_conditional_aggregate(self, exp: Expression) -> bool: - if isinstance(exp, FunctionCall): - if exp.function_name[-2:] == "If": - return True - for param in exp.parameters: - if self._contains_conditional_aggregate(param): - return True - return False - elif isinstance(exp, CurriedFunctionCall): - if exp.internal_function.function_name[-2:] == "If": - return True - return False - else: - return False - - def _get_conditional_domains(self, exp: Expression) -> list[Domain]: - domains: list[Domain] = [] - self._get_conditional_domains_helper(exp, domains) - return domains - - def _get_conditional_domains_helper( - self, exp: Expression, domains: list[Domain] - ) -> None: - if isinstance(exp, FunctionCall): - # add domain of function call - if exp.function_name[-2:] == "If": - if len(exp.parameters) != 2 or not isinstance( - exp.parameters[1], FunctionCall - 
): - raise ValueError("unexpected form of function aggregate") - domains.append(self._get_domain_of_predicate(exp.parameters[1])) - else: - for param in exp.parameters: - self._get_conditional_domains_helper(param, domains) - elif isinstance(exp, CurriedFunctionCall): - # add domain of curried function - if exp.internal_function.function_name[-2:] == "If": - if len(exp.parameters) != 2 or not isinstance( - exp.parameters[1], FunctionCall - ): - raise ValueError("unexpected form of curried function aggregate") - - domains.append(self._get_domain_of_predicate(exp.parameters[1])) - - def _get_domain_of_predicate(self, p: FunctionCall) -> Domain: - domain: Domain = {} - self._get_domain_of_predicate_helper(p, domain) - return domain - - def _get_domain_of_predicate_helper( - self, - p: FunctionCall, - domain: Domain, - ) -> None: - if p.function_name == "equals": - # validate - if not len(p.parameters) == 2: - raise ValueError("unexpected form of 'equals' function in predicate") - lhs = p.parameters[0] - rhs = p.parameters[1] - if not isinstance(lhs, (Column, SubscriptableReference)) or not isinstance( - rhs, Literal - ): - raise ValueError("unexpected form of 'equals' function in predicate") - # if already there throw error, this was to protect against: and(field=1, field=2) - if lhs in domain: - raise ValueError("lhs of 'equals' was already seen (likely from and)") - - # add it to domain - domain[lhs] = {rhs} - elif p.function_name == "in": - # validate - if not len(p.parameters) == 2: - raise ValueError("unexpected form of 'in' function in predicate") - lhs = p.parameters[0] - rhs = p.parameters[1] - if not ( - isinstance(lhs, (Column, SubscriptableReference)) - and isinstance(rhs, FunctionCall) - and rhs.function_name in ("array", "tuple") - ): - raise ValueError("unexpected form of 'in' function in predicate") - # if already there throw error, this was to protect against: and(field=1, field=2) - if lhs in domain: - raise ValueError("lhs of 'in' was already seen 
(likely from and)") - - # add it to domain - values = set() - for e in rhs.parameters: - if not isinstance(e, Literal): - raise ValueError( - "expected rhs of 'in' to only contain Literal, but that was not the case" - ) - values.add(e) - domain[lhs] = values - elif p.function_name == "and": - if not ( - len(p.parameters) == 2 - and isinstance(p.parameters[0], FunctionCall) - and isinstance(p.parameters[1], FunctionCall) - ): - raise ValueError("unexpected form of 'and' function in predicate") - self._get_domain_of_predicate_helper(p.parameters[0], domain) - self._get_domain_of_predicate_helper(p.parameters[1], domain) - else: - raise ValueError("unexpected form of predicate") diff --git a/tests/query/processors/test_filter_in_select_optimizer.py b/tests/query/processors/test_filter_in_select_optimizer.py index 6f028a9ca3c..9ebc293e6b5 100644 --- a/tests/query/processors/test_filter_in_select_optimizer.py +++ b/tests/query/processors/test_filter_in_select_optimizer.py @@ -13,7 +13,6 @@ from snuba.query.mql.parser import parse_mql_query from snuba.query.processors.logical.filter_in_select_optimizer import ( FilterInSelectOptimizer, - FindConditionalAggregateFunctionsVisitor, ) """ CONFIG STUFF THAT DOESNT MATTER MUCH """ @@ -61,103 +60,6 @@ def subscriptable_reference(name: str, key: str) -> SubscriptableReference: ) -mql_test_cases: list[tuple[str, dict]] = [ - ( - "sum(`d:transactions/duration@millisecond`){status_code:200} / sum(`d:transactions/duration@millisecond`)", - { - Column("_snuba_metric_id", None, "metric_id"): { - Literal(None, 123456), - } - }, - ), - ( - "sum(`d:transactions/duration@millisecond`){status_code:200} / sum(`d:transactions/duration@second`)", - { - Column("_snuba_metric_id", None, "metric_id"): { - Literal(None, 123456), - Literal(None, 123457), - } - }, - ), - ( - "sum(`d:transactions/duration@millisecond`){status_code:200} by transaction / sum(`d:transactions/duration@millisecond`) by transaction", - { - Column("_snuba_metric_id", None, 
"metric_id"): { - Literal(None, 123456), - } - }, - ), - ( - "quantiles(0.5)(`d:transactions/duration@millisecond`){status_code:200} by transaction / sum(`d:transactions/duration@millisecond`) by transaction", - { - Column("_snuba_metric_id", None, "metric_id"): { - Literal(None, 123456), - } - }, - ), - ( - "sum(`d:transactions/duration@millisecond`) / ((max(`d:transactions/duration@millisecond`) + avg(`d:transactions/duration@millisecond`)) * min(`d:transactions/duration@millisecond`))", - { - Column("_snuba_metric_id", None, "metric_id"): { - Literal(None, 123456), - } - }, - ), - ( - "(sum(`d:transactions/duration@millisecond`) / max(`d:transactions/duration@millisecond`)){status_code:200}", - { - Column("_snuba_metric_id", None, "metric_id"): { - Literal(None, 123456), - }, - subscriptable_reference("tags_raw", "222222"): { - Literal(None, "200"), - }, - }, - ), - ( - "(sum(`d:transactions/duration@millisecond`) / max(`d:transactions/duration@millisecond`)){status_code:[400,404,500,501]}", - { - Column("_snuba_metric_id", None, "metric_id"): { - Literal(None, 123456), - }, - subscriptable_reference("tags_raw", "222222"): { - Literal(None, "400"), - Literal(None, "404"), - Literal(None, "500"), - Literal(None, "501"), - }, - }, - ), - ( - "(sum(`d:transactions/duration@millisecond`) / max(`d:transactions/duration@millisecond`)){status_code:200} by transaction", - { - Column("_snuba_metric_id", None, "metric_id"): { - Literal(None, 123456), - }, - subscriptable_reference("tags_raw", "222222"): { - Literal(None, "200"), - }, - }, - ), - ( - "(sum(`d:transactions/duration@millisecond`) / sum(`d:transactions/duration@millisecond`)) + 100", - { - Column("_snuba_metric_id", None, "metric_id"): { - Literal(None, 123456), - }, - }, - ), - ( - "sum(`d:transactions/duration@millisecond`) * 1000", - { - Column("_snuba_metric_id", None, "metric_id"): { - Literal(None, 123456), - } - }, - ), -] - - def equals(lhs: Expression, rhs: Expression) -> FunctionCall: return 
binary_condition("equals", lhs, rhs) @@ -392,47 +294,12 @@ def _in(lhs: Expression, rhs: Expression) -> FunctionCall: """ TESTING """ -optimizer = FilterInSelectOptimizer() - - -@pytest.mark.parametrize( - "mql_query, expected_domain", - mql_test_cases, -) -def test_get_domain_of_mql(mql_query: str, expected_domain: set[int]) -> None: - logical_query, _ = parse_mql_query(str(mql_query), mql_context, generic_metrics) - assert isinstance(logical_query, Query) - res = optimizer.get_domain_of_mql_query(logical_query) - if res != expected_domain: - raise - assert res == expected_domain - - -@pytest.mark.parametrize( - "mql_query, expected_domain", - mql_test_cases, -) -def test_searcher(mql_query: str, expected_domain: set[int]) -> None: - logical_query, _ = parse_mql_query(str(mql_query), mql_context, generic_metrics) - assert isinstance(logical_query, Query) - - opt = FilterInSelectOptimizer() - for selected_expression in logical_query.get_selected_columns(): - exp = selected_expression.expression - - oldres = opt._contains_conditional_aggregate(exp) - - v = FindConditionalAggregateFunctionsVisitor() - selected_expression.expression.accept(v) - newres = len(v.get_matches()) > 0 - assert newres == oldres - @pytest.mark.parametrize( "mql_query, expected_condition", new_mql_test_cases, ) -def test_new_pipeline(mql_query: str, expected_condition: FunctionCall) -> None: +def test_condition_generation(mql_query: str, expected_condition: FunctionCall) -> None: logical_query, _ = parse_mql_query(str(mql_query), mql_context, generic_metrics) assert isinstance(logical_query, Query)