Skip to content

Commit

Permalink
feat(on-call): Add illegal aggregate function in conditions entity validator (#5435)
Browse files Browse the repository at this point in the history

* add illegal aggregate entity validator

* clean up

* fix typing

* add negative test case and move killswitch

* fix typing
  • Loading branch information
enochtangg committed Jan 22, 2024
1 parent c7e88cd commit fb8d237
Show file tree
Hide file tree
Showing 3 changed files with 244 additions and 2 deletions.
4 changes: 3 additions & 1 deletion snuba/datasets/pluggable_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from snuba.query.validation.validators import (
ColumnValidationMode,
EntityContainsColumnsValidator,
IllegalAggregateInConditionValidator,
QueryValidator,
)
from snuba.utils.schemas import SchemaModifiers
Expand Down Expand Up @@ -70,7 +71,8 @@ def _get_builtin_validators(self) -> Sequence[QueryValidator]:
EntityColumnSet(self.columns),
mappers,
self.validate_data_model or ColumnValidationMode.ERROR,
)
),
IllegalAggregateInConditionValidator(),
]

def get_query_processors(self) -> Sequence[LogicalQueryProcessor]:
Expand Down
87 changes: 86 additions & 1 deletion snuba/query/validation/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from enum import Enum
from typing import Optional, Sequence, Type, cast

from snuba import state
from snuba.clickhouse.translators.snuba.allowed import DefaultNoneColumnMapper
from snuba.clickhouse.translators.snuba.function_call_mappers import (
AggregateCurriedFunctionMapper,
Expand All @@ -20,6 +21,8 @@
from snuba.environment import metrics as environment_metrics
from snuba.query import Query
from snuba.query.conditions import (
OPERATOR_TO_FUNCTION,
BooleanFunctions,
ConditionFunctions,
build_match,
get_first_level_and_conditions,
Expand All @@ -29,7 +32,9 @@
from snuba.query.expressions import Column, Expression, FunctionCall, Literal
from snuba.query.expressions import SubscriptableReference as SubscriptableReferenceExpr
from snuba.query.logical import Query as LogicalQuery
from snuba.query.matchers import Or
from snuba.query.matchers import AnyExpression
from snuba.query.matchers import FunctionCall as FunctionCallMatcher
from snuba.query.matchers import MatchResult, Or, Param, String
from snuba.utils.metrics.wrapper import MetricsWrapper
from snuba.utils.registered_class import RegisteredClass
from snuba.utils.schemas import ColumnSet, Date, DateTime
Expand Down Expand Up @@ -434,3 +439,83 @@ def validate(self, query: Query, alias: Optional[str] = None) -> None:
logger.warning(
f"{lhs} requires datetime conditions: '{param.value}' is not a valid datetime"
)


class IllegalAggregateInConditionValidator(QueryValidator):
    """
    Ensures that aggregate functions are not used in WHERE clause of query.

    Only the expression returned by ``query.get_condition()`` is inspected.
    The check is gated by the ``enable_illegal_aggregate_in_condition_validator``
    runtime config and does nothing unless that config evaluates to 1.

    The search starts at the top of the condition and only descends through
    two-argument calls to ``BooleanFunctions.AND`` / ``BooleanFunctions.OR``
    and the ``=``, ``>``, ``<``, ``>=``, ``<=``, ``!=`` operator functions.
    An aggregate nested inside any other function call is not detected.
    """

    def __init__(self) -> None:
        # This is not an exhaustive list of all aggregate functions,
        # but should be sufficient for catching most invalid queries
        common_aggregate_functions = [
            "count",
            "min",
            "max",
            "sum",
            "avg",
            "last",
            "uniq",
        ]
        # Plain names first, then the same names with the "If" suffix.
        self.aggregate_function_names = [
            String(f"{func_name}{suffix}")
            for suffix in ("", "If")
            for func_name in common_aggregate_functions
        ]

    def validate(self, query: Query, alias: Optional[str] = None) -> None:
        """
        Reject the query when its condition contains a known aggregate function.

        :param query: the query whose condition is checked.
        :param alias: not used by this validator.
        :raises InvalidQueryException: if the validator is enabled and an
            aggregate function is found in the condition.
        """
        enable_illegal_aggregate_validator = state.get_config(
            "enable_illegal_aggregate_in_condition_validator", 0
        )
        # Killswitch. An explicit None check (rather than `assert`) keeps this
        # safe when Python runs with -O, where assert statements are removed.
        if (
            enable_illegal_aggregate_validator is None
            or int(enable_illegal_aggregate_validator) != 1
        ):
            return

        conditions = query.get_condition()
        if not conditions:
            return

        # The matchers do not depend on the node being visited, so build them
        # once per validation instead of once per recursion step.
        aggregate_matcher = FunctionCallMatcher(
            function_name=Param(
                "aggregate",
                Or(self.aggregate_function_names),
            ),
            with_optionals=True,
        )
        operator_matcher = FunctionCallMatcher(
            Param(
                "operator",
                Or(
                    [
                        String(BooleanFunctions.AND),
                        String(BooleanFunctions.OR),
                        String(OPERATOR_TO_FUNCTION["="]),
                        String(OPERATOR_TO_FUNCTION[">"]),
                        String(OPERATOR_TO_FUNCTION["<"]),
                        String(OPERATOR_TO_FUNCTION[">="]),
                        String(OPERATOR_TO_FUNCTION["<="]),
                        String(OPERATOR_TO_FUNCTION["!="]),
                    ]
                ),
            ),
            (Param("lhs", AnyExpression()), Param("rhs", AnyExpression())),
        )

        def contains_illegal_aggregate(expression: Expression) -> bool:
            """Return True as soon as an aggregate call is found."""
            if aggregate_matcher.match(expression) is not None:
                return True

            operands: Optional[MatchResult] = operator_matcher.match(expression)
            if operands is None:
                # Not a boolean/comparison node: the search stops here.
                return False
            return contains_illegal_aggregate(
                operands.expression("lhs")
            ) or contains_illegal_aggregate(operands.expression("rhs"))

        if contains_illegal_aggregate(conditions):
            raise InvalidQueryException(
                "Aggregate function found in WHERE clause of query"
            )
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import datetime
from typing import Optional, Type

import pytest

from snuba import state
from snuba.datasets.entities.entity_key import EntityKey
from snuba.datasets.entities.factory import get_entity
from snuba.query import SelectedExpression
from snuba.query.conditions import binary_condition
from snuba.query.data_source.simple import Entity as QueryEntity
from snuba.query.exceptions import InvalidQueryException
from snuba.query.expressions import Column, Expression, FunctionCall, Literal
from snuba.query.logical import Query as LogicalQuery
from snuba.query.validation.validators import IllegalAggregateInConditionValidator

# Small builders for the expression fragments repeated across the cases below.


def _duration_column() -> Column:
    return Column(None, None, "duration")


def _is_segment_filter() -> FunctionCall:
    # equals(is_segment, 1)
    return FunctionCall(
        None,
        "equals",
        (Column(None, None, "is_segment"), Literal(None, 1)),
    )


def _conditional_aggregate(function_name: str) -> FunctionCall:
    # <function_name>(duration, equals(is_segment, 1)), e.g. sumIf / countIf
    return FunctionCall(
        None,
        function_name,
        (_duration_column(), _is_segment_filter()),
    )


def _max_duration() -> FunctionCall:
    # max(duration)
    return FunctionCall(None, "max", (_duration_column(),))


def _compared_to_42(operator: str, lhs: Expression) -> FunctionCall:
    # <operator>(lhs, 42)
    return FunctionCall(None, operator, (lhs, Literal(None, 42)))


def _received_equals_fixed_time() -> Expression:
    # equals(received, 2023-01-25 20:03:13)
    return binary_condition(
        "equals",
        Column("_snuba_received", None, "received"),
        Literal(None, datetime.datetime(2023, 1, 25, 20, 3, 13)),
    )


# Each case is (condition, having, expected exception class or None).
tests = [
    pytest.param(
        _compared_to_42("greater", _conditional_aggregate("sumIf")),
        None,
        InvalidQueryException,
        id="aggregate function in top level",
    ),
    pytest.param(
        binary_condition(
            "and",
            _received_equals_fixed_time(),
            _compared_to_42("less", _conditional_aggregate("countIf")),
        ),
        None,
        InvalidQueryException,
        id="aggregate function in nested condition 1",
    ),
    pytest.param(
        binary_condition(
            "and",
            _received_equals_fixed_time(),
            _compared_to_42("less", _max_duration()),
        ),
        None,
        InvalidQueryException,
        id="aggregate function in nested condition 2",
    ),
    pytest.param(
        # The condition holds no aggregate; the aggregate sits in HAVING only.
        binary_condition(
            "and",
            _received_equals_fixed_time(),
            FunctionCall(
                None,
                "less",
                (
                    Column("_snuba_group_id", None, "group_id"),
                    Literal(None, 2),
                ),
            ),
        ),
        _compared_to_42("less", _max_duration()),
        None,
        id="no aggregate function in where clause but in having clause",
    ),
]


@pytest.mark.parametrize("condition, having, exception", tests)
@pytest.mark.redis_db
def test_illegal_aggregate_in_condition_validator(
    condition: Optional[Expression],
    having: Optional[Expression],
    exception: Optional[Type[Exception]],
) -> None:
    """
    Run the validator against a query built from ``condition`` and ``having``.

    :param condition: expression used as the query's condition.
    :param having: expression used as the query's HAVING clause.
    :param exception: exception class ``validate`` is expected to raise, or
        ``None`` when validation is expected to pass.
    """
    # The validator only acts when this config is set to 1.
    state.set_config("enable_illegal_aggregate_in_condition_validator", 1)
    query = LogicalQuery(
        QueryEntity(EntityKey.EVENTS, get_entity(EntityKey.EVENTS).get_data_model()),
        selected_columns=[
            SelectedExpression("time", Column("_snuba_timestamp", None, "timestamp")),
        ],
        condition=condition,
        having=having,
    )
    validator = IllegalAggregateInConditionValidator()

    if exception is None:
        # Passing case: validate must return without raising.
        validator.validate(query)
        return

    with pytest.raises(exception):
        validator.validate(query)

0 comments on commit fb8d237

Please sign in to comment.