From fb8d23735f29805c7f6640882ccfcacfe0ab3714 Mon Sep 17 00:00:00 2001 From: Enoch Tang Date: Mon, 22 Jan 2024 15:16:30 -0500 Subject: [PATCH] feat(on-call): Add illegal aggregate function in conditions entity validator (#5435) * add illegal aggregate entity validator * clean up * fix typing * add negative test case and move killswitch * fix typing --- snuba/datasets/pluggable_entity.py | 4 +- snuba/query/validation/validators.py | 87 +++++++++- ...illegal_aggregate_conditions_validation.py | 155 ++++++++++++++++++ 3 files changed, 244 insertions(+), 2 deletions(-) create mode 100644 tests/datasets/validation/test_illegal_aggregate_conditions_validation.py diff --git a/snuba/datasets/pluggable_entity.py b/snuba/datasets/pluggable_entity.py index 7776015886..485ca68e6a 100644 --- a/snuba/datasets/pluggable_entity.py +++ b/snuba/datasets/pluggable_entity.py @@ -27,6 +27,7 @@ from snuba.query.validation.validators import ( ColumnValidationMode, EntityContainsColumnsValidator, + IllegalAggregateInConditionValidator, QueryValidator, ) from snuba.utils.schemas import SchemaModifiers @@ -70,7 +71,8 @@ def _get_builtin_validators(self) -> Sequence[QueryValidator]: EntityColumnSet(self.columns), mappers, self.validate_data_model or ColumnValidationMode.ERROR, - ) + ), + IllegalAggregateInConditionValidator(), ] def get_query_processors(self) -> Sequence[LogicalQueryProcessor]: diff --git a/snuba/query/validation/validators.py b/snuba/query/validation/validators.py index bf4c207440..70218625e0 100644 --- a/snuba/query/validation/validators.py +++ b/snuba/query/validation/validators.py @@ -6,6 +6,7 @@ from enum import Enum from typing import Optional, Sequence, Type, cast +from snuba import state from snuba.clickhouse.translators.snuba.allowed import DefaultNoneColumnMapper from snuba.clickhouse.translators.snuba.function_call_mappers import ( AggregateCurriedFunctionMapper, @@ -20,6 +21,8 @@ from snuba.environment import metrics as environment_metrics from snuba.query import Query from snuba.query.conditions import ( + OPERATOR_TO_FUNCTION, + BooleanFunctions, ConditionFunctions, build_match, get_first_level_and_conditions, @@ -29,7 +32,9 @@ from snuba.query.expressions import Column, Expression, FunctionCall, Literal from snuba.query.expressions import SubscriptableReference as SubscriptableReferenceExpr from snuba.query.logical import Query as LogicalQuery -from snuba.query.matchers import Or +from snuba.query.matchers import AnyExpression +from snuba.query.matchers import FunctionCall as FunctionCallMatcher +from snuba.query.matchers import MatchResult, Or, Param, String from snuba.utils.metrics.wrapper import MetricsWrapper from snuba.utils.registered_class import RegisteredClass from snuba.utils.schemas import ColumnSet, Date, DateTime @@ -434,3 +439,83 @@ def validate(self, query: Query, alias: Optional[str] = None) -> None: logger.warning( f"{lhs} requires datetime conditions: '{param.value}' is not a valid datetime" ) + + +class IllegalAggregateInConditionValidator(QueryValidator): + """ + Ensures that aggregate functions are not used in WHERE clause of query. + """ + + def __init__(self) -> None: + # This is not an exhaustive list of all aggregate functions, + # but should be sufficient for catching most invalid queries + common_aggregate_functions = [ + "count", + "min", + "max", + "sum", + "avg", + "last", + "uniq", + ] + self.aggregate_function_names = [ + String(func_name) for func_name in common_aggregate_functions + ] + self.aggregate_function_names.extend( + [String(f"{func_name}If") for func_name in common_aggregate_functions] + ) + + def validate(self, query: Query, alias: Optional[str] = None) -> None: + def find_illegal_aggregate_functions( + expression: Expression, + ) -> list[MatchResult]: + matches: list[MatchResult] = [] + match = FunctionCallMatcher( + function_name=Param( + "aggregate", + Or(self.aggregate_function_names), + ), + with_optionals=True, + ).match(expression) + if match is not None: + matches.append(match) + + match = FunctionCallMatcher( + Param( + "operator", + Or( + [ + String(BooleanFunctions.AND), + String(BooleanFunctions.OR), + String(OPERATOR_TO_FUNCTION["="]), + String(OPERATOR_TO_FUNCTION[">"]), + String(OPERATOR_TO_FUNCTION["<"]), + String(OPERATOR_TO_FUNCTION[">="]), + String(OPERATOR_TO_FUNCTION["<="]), + String(OPERATOR_TO_FUNCTION["!="]), + ] + ), + ), + (Param("lhs", AnyExpression()), Param("rhs", AnyExpression())), + ).match(expression) + if match is not None: + matches.extend( + find_illegal_aggregate_functions(match.expression("lhs")) + ) + matches.extend( + find_illegal_aggregate_functions(match.expression("rhs")) + ) + return matches + + enable_illegal_aggregate_validator = state.get_config( + "enable_illegal_aggregate_in_condition_validator", 0 + ) + assert enable_illegal_aggregate_validator is not None + if int(enable_illegal_aggregate_validator) == 1: + conditions = query.get_condition() + if conditions: + matches = find_illegal_aggregate_functions(conditions) + if len(matches) > 0: + raise InvalidQueryException( + "Aggregate function found in WHERE clause of query" + ) diff --git a/tests/datasets/validation/test_illegal_aggregate_conditions_validation.py b/tests/datasets/validation/test_illegal_aggregate_conditions_validation.py new file mode 100644 index 0000000000..bebd5fdbb9 --- /dev/null +++ b/tests/datasets/validation/test_illegal_aggregate_conditions_validation.py @@ -0,0 +1,155 @@ +import datetime +from typing import Optional + +import pytest + +from snuba import state +from snuba.datasets.entities.entity_key import EntityKey +from snuba.datasets.entities.factory import get_entity +from snuba.query import SelectedExpression +from snuba.query.conditions import binary_condition +from snuba.query.data_source.simple import Entity as QueryEntity +from snuba.query.exceptions import InvalidQueryException +from snuba.query.expressions import Column, Expression, FunctionCall, Literal +from snuba.query.logical import Query as LogicalQuery +from snuba.query.validation.validators import IllegalAggregateInConditionValidator + +tests = [ + pytest.param( + FunctionCall( + None, + "greater", + ( + FunctionCall( + None, + "sumIf", + ( + Column(None, None, "duration"), + FunctionCall( + None, + "equals", + (Column(None, None, "is_segment"), Literal(None, 1)), + ), + ), + ), + Literal(None, 42), + ), + ), + None, + InvalidQueryException, + id="aggregate function in top level", + ), + pytest.param( + binary_condition( + "and", + binary_condition( + "equals", + Column("_snuba_received", None, "received"), + Literal(None, datetime.datetime(2023, 1, 25, 20, 3, 13)), + ), + FunctionCall( + None, + "less", + ( + FunctionCall( + None, + "countIf", + ( + Column(None, None, "duration"), + FunctionCall( + None, + "equals", + (Column(None, None, "is_segment"), Literal(None, 1)), + ), + ), + ), + Literal(None, 42), + ), + ), + ), + None, + InvalidQueryException, + id="aggregate function in nested condition 1", + ), + pytest.param( + binary_condition( + "and", + binary_condition( + "equals", + Column("_snuba_received", None, "received"), + Literal(None, datetime.datetime(2023, 1, 25, 20, 3, 13)), + ), + FunctionCall( + None, + "less", + ( + FunctionCall( + None, + "max", + (Column(None, None, "duration"),), + ), + Literal(None, 42), + ), + ), + ), + None, + InvalidQueryException, + id="aggregate function in nested condition 2", + ), + pytest.param( + binary_condition( + "and", + binary_condition( + "equals", + Column("_snuba_received", None, "received"), + Literal(None, datetime.datetime(2023, 1, 25, 20, 3, 13)), + ), + FunctionCall( + None, + "less", + ( + Column("_snuba_group_id", None, "group_id"), + Literal(None, 2), + ), + ), + ), + FunctionCall( + None, + "less", + ( + FunctionCall( + None, + "max", + (Column(None, None, "duration"),), + ), + Literal(None, 42), + ), + ), + None, + id="no aggregate function in where clause but in having clause", + ), +] + + +@pytest.mark.parametrize("condition, having, exception", tests) +@pytest.mark.redis_db +def test_illegal_aggregate_in_condition_validator( + condition: Optional[Expression], + having: Optional[Expression], + exception: Exception, +) -> None: + state.set_config("enable_illegal_aggregate_in_condition_validator", 1) + query = LogicalQuery( + QueryEntity(EntityKey.EVENTS, get_entity(EntityKey.EVENTS).get_data_model()), + selected_columns=[ + SelectedExpression("time", Column("_snuba_timestamp", None, "timestamp")), + ], + condition=condition, + having=having, + ) + validator = IllegalAggregateInConditionValidator() + if exception: + with pytest.raises(exception): + validator.validate(query) + else: + validator.validate(query)