diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a28eced0ca..b34896a7be 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -367,7 +367,6 @@ jobs: tests/sentry/event_manager \ tests/sentry/api/endpoints/test_organization_profiling_functions.py \ tests/snuba/api/endpoints/test_organization_events_stats_mep.py \ - tests/sentry/sentry_metrics/querying \ tests/snuba/test_snql_snuba.py \ tests/snuba/test_metrics_layer.py \ -vv --cov . --cov-report="xml:.artifacts/snuba.coverage.xml" diff --git a/snuba/query/conditions.py b/snuba/query/conditions.py index 0de509bc3f..11a0e129d7 100644 --- a/snuba/query/conditions.py +++ b/snuba/query/conditions.py @@ -155,10 +155,7 @@ def is_not_in_condition_pattern(lhs: Pattern[Expression]) -> FunctionCallPattern def binary_condition( function_name: str, lhs: Expression, rhs: Expression ) -> FunctionCall: - """This function is deprecated please use snuba.query.dsl.binary_condition""" - from snuba.query.dsl import binary_condition as dsl_binary_condition - - return dsl_binary_condition(function_name, lhs, rhs) + return FunctionCall(None, function_name, (lhs, rhs)) binary_condition_patterns = { diff --git a/snuba/query/dsl.py b/snuba/query/dsl.py index cfc7b3bad4..e56a46d219 100644 --- a/snuba/query/dsl.py +++ b/snuba/query/dsl.py @@ -55,29 +55,6 @@ def divide( return FunctionCall(alias, "divide", (lhs, rhs)) -# boolean functions -def binary_condition( - function_name: str, lhs: Expression, rhs: Expression -) -> FunctionCall: - return FunctionCall(None, function_name, (lhs, rhs)) - - -def equals(lhs: Expression, rhs: Expression) -> FunctionCall: - return binary_condition("equals", lhs, rhs) - - -def and_cond(lhs: FunctionCall, rhs: FunctionCall) -> FunctionCall: - return binary_condition("and", lhs, rhs) - - -def or_cond(lhs: FunctionCall, rhs: FunctionCall) -> FunctionCall: - return binary_condition("or", lhs, rhs) - - -def in_cond(lhs: Expression, rhs: Expression) -> FunctionCall: - return binary_condition("in", lhs, rhs) - - # aggregate functions def count(column: Optional[Column] = None, alias: Optional[str] = None) -> FunctionCall: return FunctionCall(alias, "count", (column,) if column else ()) diff --git a/snuba/query/mql/parser.py b/snuba/query/mql/parser.py index 7dc4b70cfb..06495e1eed 100644 --- a/snuba/query/mql/parser.py +++ b/snuba/query/mql/parser.py @@ -39,7 +39,7 @@ from snuba.query.processors.logical.filter_in_select_optimizer import ( FilterInSelectOptimizer, ) -from snuba.query.query_settings import HTTPQuerySettings, QuerySettings +from snuba.query.query_settings import QuerySettings from snuba.query.snql.anonymize import format_snql_anonymized from snuba.query.snql.parser import ( MAX_LIMIT, @@ -49,7 +49,7 @@ _replace_time_condition, _treeify_or_and_conditions, ) -from snuba.state import explain_meta, get_int_config +from snuba.state import explain_meta from snuba.util import parse_datetime from snuba.utils.constants import GRANULARITIES_AVAILABLE @@ -1064,12 +1064,19 @@ def transform(exp: Expression) -> Expression: query.transform_expressions(transform) +def optimize_filter_in_select( + query: CompositeQuery[QueryEntity] | LogicalQuery, +) -> None: + FilterInSelectOptimizer().process_mql_query(query) + + CustomProcessors = Sequence[ Callable[[Union[CompositeQuery[QueryEntity], LogicalQuery]], None] ] MQL_POST_PROCESSORS: CustomProcessors = POST_PROCESSORS + [ quantiles_to_quantile, + optimize_filter_in_select, ] @@ -1109,17 +1116,6 @@ def parse_mql_query( settings, ) - # Filter in select optimizer - feat_flag = get_int_config("enable_filter_in_select_optimizer", default=1) - if feat_flag: - with sentry_sdk.start_span( - op="processor", description="filter_in_select_optimize" - ): - if settings is None: - FilterInSelectOptimizer().process_query(query, HTTPQuerySettings()) - else: - FilterInSelectOptimizer().process_query(query, settings) - # Custom processing to tweak the AST before validation with sentry_sdk.start_span(op="processor", description="custom_processing"): if custom_processing is not None: diff --git a/snuba/query/processors/logical/filter_in_select_optimizer.py b/snuba/query/processors/logical/filter_in_select_optimizer.py index e5766af27d..02794742eb 100644 --- a/snuba/query/processors/logical/filter_in_select_optimizer.py +++ b/snuba/query/processors/logical/filter_in_select_optimizer.py @@ -5,152 +5,230 @@ from snuba.query.conditions import binary_condition from snuba.query.data_source.simple import Entity as QueryEntity from snuba.query.expressions import ( - Argument, Column, CurriedFunctionCall, - ExpressionVisitor, + Expression, FunctionCall, - Lambda, Literal, SubscriptableReference, ) from snuba.query.logical import Query as LogicalQuery -from snuba.query.processors.logical import LogicalQueryProcessor -from snuba.query.query_settings import QuerySettings +from snuba.state import get_int_config from snuba.utils.metrics.wrapper import MetricsWrapper +""" +Domain maps from a property to the specific values that are being filtered for. Ex: + sumIf(value, metric_id in [1,2,3]) + Domain = {metric_id: {1,2,3}} + sumIf(value, metric_id=5 and status=200) / sumIf(value, metric_id=5 and status=400) + Domain = { + metric_id: {5}, + status: {200, 400} + } +""" +Domain = dict[Column | SubscriptableReference, set[Literal]] + metrics = MetricsWrapper(environment.metrics, "api") logger = logging.getLogger(__name__) -class FindConditionalAggregateFunctionsVisitor( - ExpressionVisitor[list[FunctionCall | CurriedFunctionCall]] -): - """ - Visitor that searches an expression for all conditional aggregate functions. - Results are returned via get_matches function. - - Example: - myexp = add(divide(sumIf(...), avgIf(...)), 100) - visitor = FindConditionalAggregateFunctionsVisitor() - myexp.accept(visitor) - res = visitor.get_matches() - - Visitor implementation to find all conditional aggregate functions in an expression. - Usage: - >>> exp: Expression = add(divide(sumIf(...), avgIf(...)), 100) - >>> visitor = FindConditionalAggregateFunctionsVisitor() - >>> found = exp.accept(visitor) # found = [sumIf(...), avgIf(...)] +class FilterInSelectOptimizer: """ + This optimizer takes queries that filter by metric_id in the select clause (via conditional aggregate functions), + and adds the equivalent metric_id filtering to the where clause. Example: - def __init__(self) -> None: - self._matches: list[FunctionCall | CurriedFunctionCall] = [] - - def visit_literal(self, exp: Literal) -> list[FunctionCall | CurriedFunctionCall]: - return self._matches - - def visit_column(self, exp: Column) -> list[FunctionCall | CurriedFunctionCall]: - return self._matches - - def visit_subscriptable_reference( - self, exp: SubscriptableReference - ) -> list[FunctionCall | CurriedFunctionCall]: - return self._matches - - def visit_function_call( - self, exp: FunctionCall - ) -> list[FunctionCall | CurriedFunctionCall]: - if exp.function_name[-2:] == "If": - self._matches.append(exp) - else: - for param in exp.parameters: - param.accept(self) - return self._matches - - def visit_curried_function_call( - self, exp: CurriedFunctionCall - ) -> list[FunctionCall | CurriedFunctionCall]: - if exp.internal_function.function_name[-2:] == "If": - self._matches.append(exp) - return self._matches - - def visit_argument(self, exp: Argument) -> list[FunctionCall | CurriedFunctionCall]: - return self._matches - - def visit_lambda(self, exp: Lambda) -> list[FunctionCall | CurriedFunctionCall]: - return self._matches - - -class FilterInSelectOptimizer(LogicalQueryProcessor): - """ - This optimizer grabs all conditions from conditional aggregate functions in the select clause - and adds them into the where clause. Example: - - SELECT sumIf(value, metric_id in (1,2,3,4) and status=200) / sumIf(value, metric_id in (1,2,3,4)), - avgIf(value, metric_id in (3,4,5)) + SELECT sumIf(value, metric_id in (1,2,3,4)) FROM table becomes - SELECT sumIf(value, metric_id in (1,2,3,4) and status=200) / sumIf(value, metric_id in (1,2,3,4)), - avgIf(value, metric_id in (3,4,5)) + SELECT sumIf(value, metric_id in (1,2,3,4)) FROM table - WHERE (metric_id in (1,2,3,4) and status=200) or metric_id in (1,2,3,4) or metric_id in (3,4,5) + WHERE metric_id in (1,2,3,4) """ - def process_query(self, query: LogicalQuery, query_settings: QuerySettings) -> None: - try: - new_condition = self.get_select_filter(query) - except Exception: - logger.warning( - "Failed during optimization", exc_info=True, extra={"query": query} - ) - return - if new_condition is not None: - query.add_condition_to_ast(new_condition) - metrics.increment("filter_in_select_optimizer_optimized") - - def get_select_filter( - self, - query: LogicalQuery | CompositeQuery[QueryEntity], - ) -> FunctionCall | None: + def process_mql_query( + self, query: LogicalQuery | CompositeQuery[QueryEntity] + ) -> None: + feat_flag = get_int_config("enable_filter_in_select_optimizer", default=0) + if feat_flag: + try: + domain = self.get_domain_of_mql_query(query) + except ValueError: + logger.warning( + "Failed getting domain", exc_info=True, extra={"query": query} + ) + domain = {} + + if domain: + # add domain to where clause + domain_filter = None + for key, value in domain.items(): + clause = binary_condition( + "in", + key, + FunctionCall( + alias=None, + function_name="array", + parameters=tuple(value), + ), + ) + if not domain_filter: + domain_filter = clause + else: + domain_filter = binary_condition( + "and", + domain_filter, + clause, + ) + assert domain_filter is not None + query.add_condition_to_ast(domain_filter) + metrics.increment("kyles_optimizer_optimized") + + def get_domain_of_mql_query( + self, query: LogicalQuery | CompositeQuery[QueryEntity] + ) -> Domain: """ - Given a query, grabs all the conditions from conditional aggregates and lifts into - one condition. - - ex: SELECT sumIf(value, metric_id in (1,2,3,4) and status=200), - avgIf(value, metric_id in (11,12) and status=400), - ... - - returns or((metric_id in (1,2,3,4) and status=200), (metric_id in (11,12) and status=400)) + This function returns the metric_id domain of the given query. + For a definition of metric_id domain, go to definition of the return type of this function ('Domain') """ - # find and grab all the conditional aggregate functions - cond_agg_functions: list[FunctionCall | CurriedFunctionCall] = [] - for selected_exp in query.get_selected_columns(): - found = selected_exp.expression.accept( - FindConditionalAggregateFunctionsVisitor() - ) - cond_agg_functions += found - if len(cond_agg_functions) == 0: - return None - - # validate the functions, and lift their conditions into new_condition, return it - new_condition = None - for func in cond_agg_functions: - if len(func.parameters) != 2: - raise ValueError( - f"expected conditional function to be of the form funcIf(val, condition), but was given one with {len(func.parameters)} parameters" - ) - if not isinstance(func.parameters[1], FunctionCall): - raise ValueError( - f"expected conditional function to be of the form funcIf(val, condition), but the condition is type {type(func.parameters[1])} rather than FunctionCall" - ) + expressions = map(lambda x: x.expression, query.get_selected_columns()) + target_exp = None + for exp in expressions: + if self._contains_conditional_aggregate(exp): + if target_exp is not None: + raise ValueError( + "Was expecting only 1 select expression to contain condition aggregate but found multiple" + ) + else: + target_exp = exp + + if target_exp is not None: + domains = self._get_conditional_domains(target_exp) + if len(domains) == 0: + raise ValueError("This shouldnt happen bc there is a target_exp") + + # find the intersect of keys, across the domains of all conditional aggregates + key_intersect = set(domains[0].keys()) + for i in range(1, len(domains)): + domain = domains[i] + key_intersect = key_intersect.intersection(set(domains[i].keys())) + + # union the domains + domain_union: Domain = {} + for key in key_intersect: + domain_union[key] = set() + for domain in domains: + domain_union[key] = domain_union[key].union(domain[key]) + + return domain_union + else: + return {} - if new_condition is None: - new_condition = func.parameters[1] + def _contains_conditional_aggregate(self, exp: Expression) -> bool: + if isinstance(exp, FunctionCall): + if exp.function_name[-2:] == "If": + return True + for param in exp.parameters: + if self._contains_conditional_aggregate(param): + return True + return False + elif isinstance(exp, CurriedFunctionCall): + if exp.internal_function.function_name[-2:] == "If": + return True + return False + else: + return False + + def _get_conditional_domains(self, exp: Expression) -> list[Domain]: + domains: list[Domain] = [] + self._get_conditional_domains_helper(exp, domains) + return domains + + def _get_conditional_domains_helper( + self, exp: Expression, domains: list[Domain] + ) -> None: + if isinstance(exp, FunctionCall): + # add domain of function call + if exp.function_name[-2:] == "If": + if len(exp.parameters) != 2 or not isinstance( + exp.parameters[1], FunctionCall + ): + raise ValueError("unexpected form of function aggregate") + domains.append(self._get_domain_of_predicate(exp.parameters[1])) else: - new_condition = binary_condition( - "or", new_condition, func.parameters[1] - ) - return new_condition + for param in exp.parameters: + self._get_conditional_domains_helper(param, domains) + elif isinstance(exp, CurriedFunctionCall): + # add domain of curried function + if exp.internal_function.function_name[-2:] == "If": + if len(exp.parameters) != 2 or not isinstance( + exp.parameters[1], FunctionCall + ): + raise ValueError("unexpected form of curried function aggregate") + + domains.append(self._get_domain_of_predicate(exp.parameters[1])) + + def _get_domain_of_predicate(self, p: FunctionCall) -> Domain: + domain: Domain = {} + self._get_domain_of_predicate_helper(p, domain) + return domain + + def _get_domain_of_predicate_helper( + self, + p: FunctionCall, + domain: Domain, + ) -> None: + if p.function_name == "equals": + # validate + if not len(p.parameters) == 2: + raise ValueError("unexpected form of 'equals' function in predicate") + lhs = p.parameters[0] + rhs = p.parameters[1] + if not isinstance(lhs, (Column, SubscriptableReference)) or not isinstance( + rhs, Literal + ): + raise ValueError("unexpected form of 'equals' function in predicate") + # if already there throw error, this was to protect against: and(field=1, field=2) + if lhs in domain: + raise ValueError("lhs of 'equals' was already seen (likely from and)") + + # add it to domain + domain[lhs] = {rhs} + elif p.function_name == "in": + # validate + if not len(p.parameters) == 2: + raise ValueError("unexpected form of 'in' function in predicate") + lhs = p.parameters[0] + rhs = p.parameters[1] + if not ( + isinstance(lhs, (Column, SubscriptableReference)) + and isinstance(rhs, FunctionCall) + and rhs.function_name in ("array", "tuple") + ): + raise ValueError("unexpected form of 'in' function in predicate") + # if already there throw error, this was to protect against: and(field=1, field=2) + if lhs in domain: + raise ValueError("lhs of 'in' was already seen (likely from and)") + + # add it to domain + values = set() + for e in rhs.parameters: + if not isinstance(e, Literal): + raise ValueError( + "expected rhs of 'in' to only contain Literal, but that was not the case" + ) + values.add(e) + domain[lhs] = values + elif p.function_name == "and": + if not ( + len(p.parameters) == 2 + and isinstance(p.parameters[0], FunctionCall) + and isinstance(p.parameters[1], FunctionCall) + ): + raise ValueError("unexpected form of 'and' function in predicate") + self._get_domain_of_predicate_helper(p.parameters[0], domain) + self._get_domain_of_predicate_helper(p.parameters[1], domain) + else: + raise ValueError("unexpected form of predicate") diff --git a/tests/query/parser/test_formula_mql_query.py b/tests/query/parser/test_formula_mql_query.py index d36be56007..9bb19b024e 100644 --- a/tests/query/parser/test_formula_mql_query.py +++ b/tests/query/parser/test_formula_mql_query.py @@ -11,15 +11,7 @@ from snuba.query import OrderBy, OrderByDirection, SelectedExpression from snuba.query.conditions import binary_condition from snuba.query.data_source.simple import Entity as QueryEntity -from snuba.query.dsl import ( - and_cond, - arrayElement, - divide, - equals, - multiply, - or_cond, - plus, -) +from snuba.query.dsl import arrayElement, divide, multiply, plus from snuba.query.expressions import ( Column, CurriedFunctionCall, @@ -257,22 +249,6 @@ def test_simple_formula() -> None: "_snuba_aggregate_value", ), ) - filter_in_select_condition = or_cond( - and_cond( - equals( - tag_column("status_code"), - Literal(None, "200"), - ), - equals( - Column("_snuba_metric_id", None, "metric_id"), - Literal(None, 123456), - ), - ), - equals( - Column("_snuba_metric_id", None, "metric_id"), - Literal(None, 123456), - ), - ) expected = Query( from_distributions, selected_columns=[ @@ -283,11 +259,7 @@ def test_simple_formula() -> None: ), ], groupby=[time_expression], - condition=binary_condition( - "and", - filter_in_select_condition, - formula_condition, - ), + condition=formula_condition, order_by=[ OrderBy( direction=OrderByDirection.ASC, @@ -306,7 +278,6 @@ def test_simple_formula() -> None: assert eq, reason -@pytest.mark.xfail() def test_simple_formula_with_leading_literals() -> None: query_body = "1 + sum(`d:transactions/duration@millisecond`){status_code:200} / sum(`d:transactions/duration@millisecond`)" expected_selected = SelectedExpression( @@ -323,22 +294,6 @@ def test_simple_formula_with_leading_literals() -> None: "_snuba_aggregate_value", ), ) - filter_in_select_condition = or_cond( - and_cond( - equals( - tag_column("status_code"), - Literal(None, "200"), - ), - equals( - Column("_snuba_metric_id", None, "metric_id"), - Literal(None, 123456), - ), - ), - equals( - Column("_snuba_metric_id", None, "metric_id"), - Literal(None, 123456), - ), - ) expected = Query( from_distributions, selected_columns=[ @@ -349,11 +304,7 @@ def test_simple_formula_with_leading_literals() -> None: ), ], groupby=[time_expression], - condition=binary_condition( - "and", - filter_in_select_condition, - formula_condition, - ), + condition=formula_condition, order_by=[ OrderBy( direction=OrderByDirection.ASC, @@ -369,7 +320,6 @@ def test_simple_formula_with_leading_literals() -> None: ) query, _ = parse_mql_query(str(query_body), mql_context, generic_metrics) eq, reason = query.equals(expected) - assert eq, reason def test_groupby() -> None: @@ -388,22 +338,6 @@ def test_groupby() -> None: "_snuba_aggregate_value", ), ) - filter_in_select_condition = or_cond( - and_cond( - equals( - tag_column("status_code"), - Literal(None, "200"), - ), - equals( - Column("_snuba_metric_id", None, "metric_id"), - Literal(None, 123456), - ), - ), - equals( - Column("_snuba_metric_id", None, "metric_id"), - Literal(None, 123456), - ), - ) expected = Query( from_distributions, selected_columns=[ @@ -415,11 +349,7 @@ def test_groupby() -> None: ), ], groupby=[tag_column("transaction"), time_expression], - condition=binary_condition( - "and", - filter_in_select_condition, - formula_condition, - ), + condition=formula_condition, order_by=[ OrderBy( direction=OrderByDirection.ASC, @@ -483,22 +413,6 @@ def test_curried_aggregate() -> None: "_snuba_aggregate_value", ), ) - filter_in_select_condition = or_cond( - and_cond( - equals( - tag_column("status_code"), - Literal(None, "200"), - ), - equals( - Column("_snuba_metric_id", None, "metric_id"), - Literal(None, 123456), - ), - ), - equals( - Column("_snuba_metric_id", None, "metric_id"), - Literal(None, 123456), - ), - ) expected = Query( from_distributions, selected_columns=[ @@ -510,11 +424,7 @@ def test_curried_aggregate() -> None: ), ], groupby=[tag_column("transaction"), time_expression], - condition=binary_condition( - "and", - filter_in_select_condition, - formula_condition, - ), + condition=formula_condition, order_by=[ OrderBy( direction=OrderByDirection.ASC, @@ -549,28 +459,6 @@ def test_bracketing_rules() -> None: "_snuba_aggregate_value", ), ) - filter_in_select_condition = or_cond( - or_cond( - or_cond( - equals( - Column("_snuba_metric_id", None, "metric_id"), - Literal(None, 123456), - ), - equals( - Column("_snuba_metric_id", None, "metric_id"), - Literal(None, 123456), - ), - ), - equals( - Column("_snuba_metric_id", None, "metric_id"), - Literal(None, 123456), - ), - ), - equals( - Column("_snuba_metric_id", None, "metric_id"), - Literal(None, 123456), - ), - ) expected = Query( from_distributions, selected_columns=[ @@ -581,11 +469,7 @@ def test_bracketing_rules() -> None: ), ], groupby=[time_expression], - condition=binary_condition( - "and", - filter_in_select_condition, - formula_condition, - ), + condition=formula_condition, order_by=[ OrderBy( direction=OrderByDirection.ASC, @@ -639,28 +523,6 @@ def test_formula_filters() -> None: "_snuba_aggregate_value", ), ) - filter_in_select_condition = or_cond( - and_cond( - equals( - tag_column("status_code"), - Literal(None, "200"), - ), - equals( - Column("_snuba_metric_id", None, "metric_id"), - Literal(None, 123456), - ), - ), - and_cond( - equals( - tag_column("status_code"), - Literal(None, "200"), - ), - equals( - Column("_snuba_metric_id", None, "metric_id"), - Literal(None, 123456), - ), - ), - ) expected = Query( from_distributions, selected_columns=[ @@ -671,11 +533,7 @@ def test_formula_filters() -> None: ), ], groupby=[time_expression], - condition=binary_condition( - "and", - filter_in_select_condition, - formula_condition, - ), + condition=formula_condition, order_by=[ OrderBy( direction=OrderByDirection.ASC, @@ -716,28 +574,6 @@ def test_formula_groupby() -> None: "_snuba_aggregate_value", ), ) - filter_in_select_condition = or_cond( - and_cond( - equals( - tag_column("status_code"), - Literal(None, "200"), - ), - equals( - Column("_snuba_metric_id", None, "metric_id"), - Literal(None, 123456), - ), - ), - and_cond( - equals( - tag_column("status_code"), - Literal(None, "200"), - ), - equals( - Column("_snuba_metric_id", None, "metric_id"), - Literal(None, 123456), - ), - ), - ) expected = Query( from_distributions, selected_columns=[ @@ -752,11 +588,7 @@ def test_formula_groupby() -> None: ), ], groupby=[tag_column("transaction"), time_expression], - condition=binary_condition( - "and", - filter_in_select_condition, - formula_condition, - ), + condition=formula_condition, order_by=[ OrderBy( direction=OrderByDirection.ASC, @@ -788,16 +620,6 @@ def test_formula_scalar_value() -> None: "_snuba_aggregate_value", ), ) - filter_in_select_condition = or_cond( - equals( - Column("_snuba_metric_id", None, "metric_id"), - Literal(None, 123456), - ), - equals( - Column("_snuba_metric_id", None, "metric_id"), - Literal(None, 123456), - ), - ) expected = Query( from_distributions, selected_columns=[ @@ -808,11 +630,7 @@ def test_formula_scalar_value() -> None: ), ], groupby=[time_expression], - condition=binary_condition( - "and", - filter_in_select_condition, - formula_condition, - ), + condition=formula_condition, order_by=[ OrderBy( direction=OrderByDirection.ASC, diff --git a/tests/query/processors/test_filter_in_select_optimizer.py b/tests/query/processors/test_filter_in_select_optimizer.py index 961eefffc4..56f282bb4c 100644 --- a/tests/query/processors/test_filter_in_select_optimizer.py +++ b/tests/query/processors/test_filter_in_select_optimizer.py @@ -1,168 +1,156 @@ -from copy import copy +import pytest -from snuba.datasets.entities.entity_key import EntityKey -from snuba.datasets.entities.factory import get_entity -from snuba.query import SelectedExpression -from snuba.query.data_source.simple import Entity -from snuba.query.dsl import and_cond, divide, equals, multiply, or_cond, plus -from snuba.query.expressions import Column, CurriedFunctionCall, FunctionCall, Literal +from snuba.datasets.factory import get_dataset +from snuba.query.expressions import Column, Literal, SubscriptableReference from snuba.query.logical import Query +from snuba.query.mql.parser import parse_mql_query from snuba.query.processors.logical.filter_in_select_optimizer import ( FilterInSelectOptimizer, ) -from snuba.query.query_settings import HTTPQuerySettings +""" CONFIG STUFF THAT DOESNT MATTER MUCH """ -def _equals(col_name: str, value: str | int) -> FunctionCall: - return equals(Column(None, None, col_name), Literal(None, value)) - - -def _cond_agg(function_name: str, condition: FunctionCall) -> FunctionCall: - return FunctionCall(None, function_name, (Column(None, None, "value"), condition)) - - -optimizer = FilterInSelectOptimizer() -from_entity = Entity( - EntityKey.GENERIC_METRICS_DISTRIBUTIONS, - get_entity(EntityKey.GENERIC_METRICS_DISTRIBUTIONS).get_data_model(), +generic_metrics = get_dataset( + "generic_metrics", ) -settings = HTTPQuerySettings() - - -def test_simple_query() -> None: - input_query = Query( - from_clause=from_entity, - selected_columns=[ - SelectedExpression( - None, - divide( - _cond_agg( - "sumIf", - and_cond(_equals("metric_id", 1), _equals("status_code", 200)), - ), - _cond_agg("sumIf", _equals("metric_id", 1)), - ), - ) - ], - ) - expected_optimized_query = copy(input_query) - expected_optimized_query.set_ast_condition( - or_cond( - and_cond(_equals("metric_id", 1), _equals("status_code", 200)), - _equals("metric_id", 1), - ) - ) - optimizer.process_query(input_query, settings) - assert input_query == expected_optimized_query - - -def test_query_with_curried_function() -> None: - input_query = Query( - from_clause=from_entity, - selected_columns=[ - SelectedExpression( - None, - divide( - CurriedFunctionCall( - alias=None, - internal_function=FunctionCall( - None, "quantilesIf", (Literal(None, 0.5),) - ), - parameters=( - Column(None, None, "value"), - and_cond( - _equals("metric_id", 1), _equals("status_code", 200) - ), - ), - ), - _cond_agg("sumIf", _equals("metric_id", 1)), - ), - ) - ], - ) - expected_optimized_query = copy(input_query) - expected_optimized_query.set_ast_condition( - or_cond( - and_cond(_equals("metric_id", 1), _equals("status_code", 200)), - _equals("metric_id", 1), - ) +mql_context = { + "entity": "generic_metrics_distributions", + "start": "2023-11-23T18:30:00", + "end": "2023-11-23T22:30:00", + "rollup": { + "granularity": 60, + "interval": 60, + "with_totals": "False", + "orderby": None, + }, + "scope": { + "org_ids": [1], + "project_ids": [11], + "use_case_id": "transactions", + }, + "indexer_mappings": { + "d:transactions/duration@millisecond": 123456, + "d:transactions/duration@second": 123457, + "status_code": 222222, + "transaction": 333333, + }, + "limit": None, + "offset": None, +} + +""" TEST CASES """ + + +def subscriptable_reference(name: str, key: str) -> SubscriptableReference: + """Helper function to build a SubscriptableReference""" + return SubscriptableReference( + f"_snuba_{name}[{key}]", + Column(f"_snuba_{name}", None, name), + Literal(None, key), ) - optimizer.process_query(input_query, settings) - assert input_query == expected_optimized_query - -def test_query_with_many_nested_functions() -> None: - input_query = Query( - from_clause=from_entity, - selected_columns=[ - SelectedExpression( - None, - divide( - _cond_agg("sumIf", _equals("metric_id", 1)), - multiply( - plus( - _cond_agg("maxIf", _equals("metric_id", 1)), - _cond_agg("avgIf", _equals("metric_id", 1)), - ), - _cond_agg("minIf", _equals("metric_id", 1)), - ), - ), - ) - ], - ) - expected_optimized_query = copy(input_query) - expected_optimized_query.set_ast_condition( - or_cond( - or_cond( - or_cond(_equals("metric_id", 1), _equals("metric_id", 1)), - _equals("metric_id", 1), - ), - _equals("metric_id", 1), - ) - ) - optimizer.process_query(input_query, settings) - assert input_query == expected_optimized_query +mql_test_cases: list[tuple[str, dict]] = [ + ( + "sum(`d:transactions/duration@millisecond`){status_code:200} / sum(`d:transactions/duration@millisecond`)", + { + Column("_snuba_metric_id", None, "metric_id"): { + Literal(None, 123456), + } + }, + ), + ( + "sum(`d:transactions/duration@millisecond`){status_code:200} / sum(`d:transactions/duration@second`)", + { + Column("_snuba_metric_id", None, "metric_id"): { + Literal(None, 123456), + Literal(None, 123457), + } + }, + ), + ( + "sum(`d:transactions/duration@millisecond`){status_code:200} by transaction / sum(`d:transactions/duration@millisecond`) by transaction", + { + Column("_snuba_metric_id", None, "metric_id"): { + Literal(None, 123456), + } + }, + ), + ( + "quantiles(0.5)(`d:transactions/duration@millisecond`){status_code:200} by transaction / sum(`d:transactions/duration@millisecond`) by transaction", + { + Column("_snuba_metric_id", None, "metric_id"): { + Literal(None, 123456), + } + }, + ), + ( + "sum(`d:transactions/duration@millisecond`) / ((max(`d:transactions/duration@millisecond`) + avg(`d:transactions/duration@millisecond`)) * min(`d:transactions/duration@millisecond`))", + { + Column("_snuba_metric_id", None, "metric_id"): { + Literal(None, 123456), + } + }, + ), + ( + "(sum(`d:transactions/duration@millisecond`) / max(`d:transactions/duration@millisecond`)){status_code:200}", + { + Column("_snuba_metric_id", None, "metric_id"): { + Literal(None, 123456), + }, + subscriptable_reference("tags_raw", "222222"): { + Literal(None, "200"), + }, + }, + ), + ( + "(sum(`d:transactions/duration@millisecond`) / max(`d:transactions/duration@millisecond`)){status_code:[400,404,500,501]}", + { + Column("_snuba_metric_id", None, "metric_id"): { + Literal(None, 123456), + }, + subscriptable_reference("tags_raw", "222222"): { + Literal(None, "400"), + Literal(None, "404"), + Literal(None, "500"), + Literal(None, "501"), + }, + }, + ), + ( + "(sum(`d:transactions/duration@millisecond`) / max(`d:transactions/duration@millisecond`)){status_code:200} by transaction", + { + Column("_snuba_metric_id", None, "metric_id"): { + Literal(None, 123456), + }, + subscriptable_reference("tags_raw", "222222"): { + Literal(None, "200"), + }, + }, + ), + ( + "(sum(`d:transactions/duration@millisecond`) / sum(`d:transactions/duration@millisecond`)) + 100", + { + Column("_snuba_metric_id", None, "metric_id"): { + Literal(None, 123456), + }, + }, + ), +] + +""" TESTING """ -def test_query_with_literal_arithmetic_in_select() -> None: - input_query = Query( - from_clause=from_entity, - selected_columns=[ - SelectedExpression( - None, - plus(_cond_agg("sumIf", _equals("metric_id", 1)), Literal(None, 100.0)), - ) - ], - ) - expected_optimized_query = copy(input_query) - expected_optimized_query.set_ast_condition(_equals("metric_id", 1)) - optimizer.process_query(input_query, settings) - assert input_query == expected_optimized_query +optimizer = FilterInSelectOptimizer() -def test_query_with_multiple_aggregate_columns() -> None: - input_query = Query( - from_clause=from_entity, - selected_columns=[ - SelectedExpression( - None, - plus(_cond_agg("sumIf", _equals("metric_id", 1)), Literal(None, 100.0)), - ), - SelectedExpression( - None, - multiply( - _cond_agg("maxIf", _equals("metric_id", 2)), - _cond_agg("avgIf", _equals("metric_id", 2)), - ), - ), - ], - ) - expected_optimized_query = copy(input_query) - expected_optimized_query.set_ast_condition( - or_cond( - or_cond(_equals("metric_id", 1), _equals("metric_id", 2)), - _equals("metric_id", 2), - ) - ) - optimizer.process_query(input_query, settings) - assert input_query == expected_optimized_query +@pytest.mark.parametrize( + "mql_query, expected_domain", + mql_test_cases, +) +def test_get_domain_of_mql(mql_query: str, expected_domain: set[int]) -> None: + logical_query, _ = parse_mql_query(str(mql_query), mql_context, generic_metrics) + assert isinstance(logical_query, Query) + res = optimizer.get_domain_of_mql_query(logical_query) + if res != expected_domain: + raise + assert res == expected_domain diff --git a/tests/test_metrics_sdk_api.py b/tests/test_metrics_sdk_api.py index b41eda0ae0..584f99a790 100644 --- a/tests/test_metrics_sdk_api.py +++ b/tests/test_metrics_sdk_api.py @@ -1,5 +1,6 @@ from __future__ import annotations +import math from datetime import datetime, timedelta, timezone from typing import Any, Callable, Tuple, Union, cast @@ -358,7 +359,9 @@ def test_raw_mql_string(self, test_dataset: str, tag_column: str) -> None: assert response.status_code == 200, data rows = data["data"] - assert len(rows) == 0 + assert len(rows) >= 180, rows + + assert math.isnan(rows[0]["aggregate_value"]) # division by zero @pytest.mark.clickhouse_db