Skip to content

Commit

Permalink
feat(dev-exp): Move entity validators from parser to entity processin…
Browse files Browse the repository at this point in the history
…g stage (#5826)

* move entity validators

* update tests

* fix test

* remove duplicate assert
  • Loading branch information
enochtangg authored Apr 30, 2024
1 parent aff21c6 commit 4f20c55
Show file tree
Hide file tree
Showing 10 changed files with 124 additions and 48 deletions.
88 changes: 88 additions & 0 deletions snuba/datasets/plans/entity_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from __future__ import annotations

from typing import Union

import sentry_sdk

from snuba.datasets.entities.factory import get_entity
from snuba.query import Query
from snuba.query.composite import CompositeQuery
from snuba.query.data_source.join import JoinClause
from snuba.query.data_source.simple import Entity as QueryEntity
from snuba.query.exceptions import InvalidQueryException, ValidationException
from snuba.query.logical import Query as LogicalQuery
from snuba.query.parser.validation.functions import FunctionCallsValidator
from snuba.query.query_settings import QuerySettings
from snuba.state import explain_meta

EXPRESSION_VALIDATORS = [FunctionCallsValidator()]


def _validate_query(query: Query) -> None:
"""
Applies all the expression validators in one pass over the AST.
"""

for exp in query.get_all_expressions():
for v in EXPRESSION_VALIDATORS:
v.validate(exp, query.get_from_clause())


def _validate_entities_with_query(
query: Union[CompositeQuery[QueryEntity], LogicalQuery]
) -> None:
"""
Applies all validator defined on the entities in the query
"""
if isinstance(query, LogicalQuery):
entity = get_entity(query.get_from_clause().key)
try:
for v in entity.get_validators():
v.validate(query)
except InvalidQueryException as e:
raise ValidationException(
f"validation failed for entity {query.get_from_clause().key.value}: {e}",
should_report=e.should_report,
)
else:
from_clause = query.get_from_clause()
if isinstance(from_clause, JoinClause):
alias_map = from_clause.get_alias_node_map()
for alias, node in alias_map.items():
assert isinstance(node.data_source, QueryEntity) # mypy
entity = get_entity(node.data_source.key)
try:
for v in entity.get_validators():
v.validate(query, alias)
except InvalidQueryException as e:
raise ValidationException(
f"validation failed for entity {node.data_source.key.value}: {e}",
should_report=e.should_report,
)


VALIDATORS = [_validate_query, _validate_entities_with_query]


def run_entity_validators(
query: Union[CompositeQuery[QueryEntity], LogicalQuery],
settings: QuerySettings | None = None,
) -> None:
"""
Main function for applying all validators associated with an entity
"""
for validator_func in VALIDATORS:
description = getattr(validator_func, "__name__", "custom")
with sentry_sdk.start_span(op="validator", description=description):
if settings and settings.get_dry_run():
with explain_meta.with_query_differ(
"entity_validator", description, query
):
validator_func(query)
else:
validator_func(query)

if isinstance(query, CompositeQuery):
from_clause = query.get_from_clause()
if isinstance(from_clause, (LogicalQuery, CompositeQuery)):
run_entity_validators(from_clause, settings)
5 changes: 5 additions & 0 deletions snuba/pipeline/stages/query_processing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from snuba.clickhouse.query import Query as ClickhouseQuery
from snuba.datasets.plans.entity_processing import run_entity_processing_executor
from snuba.datasets.plans.entity_validation import run_entity_validators
from snuba.datasets.plans.storage_processing import (
apply_storage_processors,
build_best_plan,
Expand All @@ -22,6 +23,10 @@ class EntityProcessingStage(
def _process_data(
self, pipe_input: QueryPipelineData[Request]
) -> ClickhouseQuery | CompositeQuery[Table]:
# Execute entity validators on logical query
run_entity_validators(pipe_input.data.query, pipe_input.query_settings)

# Run entity processors on the query and transform logical query into a physical query
if isinstance(pipe_input.data.query, LogicalQuery):
return run_entity_processing_executor(
pipe_input.data.query, pipe_input.query_settings
Expand Down
36 changes: 0 additions & 36 deletions snuba/query/parser/validation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,36 +0,0 @@
from abc import ABC, abstractmethod
from typing import Sequence

from snuba.query import Query
from snuba.query.data_source import DataSource
from snuba.query.expressions import Expression


class ExpressionValidator(ABC):
"""
Validates an individual expression in a Snuba logical query.
"""

@abstractmethod
def validate(self, exp: Expression, data_source: DataSource) -> None:
"""
If the expression is valid according to this validator it
returns, otherwise it raises a subclass of
snuba.query.parser.ValidationException
"""
raise NotImplementedError


from snuba.query.parser.validation.functions import FunctionCallsValidator

validators: Sequence[ExpressionValidator] = [FunctionCallsValidator()]


def validate_query(query: Query) -> None:
"""
Applies all the expression validators in one pass over the AST.
"""

for exp in query.get_all_expressions():
for v in validators:
v.validate(exp, query.get_from_clause())
17 changes: 16 additions & 1 deletion snuba/query/parser/validation/functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from abc import ABC, abstractmethod
from typing import List, Mapping

from snuba.clickhouse.columns import Array, DateTime, String
Expand All @@ -10,7 +11,6 @@
from snuba.query.data_source.visitor import DataSourceVisitor
from snuba.query.exceptions import InvalidExpressionException
from snuba.query.expressions import Expression, FunctionCall
from snuba.query.parser.validation import ExpressionValidator
from snuba.query.validation import FunctionCallValidator, InvalidFunctionCall
from snuba.query.validation.functions import AllowedFunctionValidator
from snuba.query.validation.signature import Any, Column, SignatureValidator
Expand Down Expand Up @@ -90,6 +90,21 @@ def visit_join_clause(self, node: JoinClause[QueryEntity]) -> List[QueryEntity]:
return node.right_node.accept(self) + node.left_node.accept(self)


class ExpressionValidator(ABC):
"""
Validates an individual expression in a Snuba logical query.
"""

@abstractmethod
def validate(self, exp: Expression, data_source: DataSource) -> None:
"""
If the expression is valid according to this validator it
returns, otherwise it raises a subclass of
snuba.query.parser.ValidationException
"""
raise NotImplementedError


class FunctionCallsValidator(ExpressionValidator):
"""
Applies all function validators on the provided expression.
Expand Down
3 changes: 0 additions & 3 deletions snuba/query/snql/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@
validate_aliases,
)
from snuba.query.parser.exceptions import ParsingException, PostProcessingError
from snuba.query.parser.validation import validate_query
from snuba.query.query_settings import QuerySettings
from snuba.query.schema import POSITIVE_OPERATORS
from snuba.query.snql.anonymize import format_snql_anonymized
Expand Down Expand Up @@ -1491,8 +1490,6 @@ def _post_process(

VALIDATORS = [
validate_identifiers_in_lambda,
validate_query,
validate_entities_with_query,
]


Expand Down
8 changes: 6 additions & 2 deletions tests/clickhouse/query_dsl/test_project_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

from snuba.clickhouse.query_dsl.accessors import get_object_ids_in_query_ast
from snuba.datasets.factory import get_dataset
from snuba.datasets.plans.entity_validation import run_entity_validators
from snuba.datasets.plans.translator.query import identity_translate
from snuba.query.exceptions import ValidationException
from snuba.query.logical import Query
from snuba.query.parser.exceptions import ParsingException
from snuba.query.query_settings import HTTPQuerySettings
from snuba.query.snql.parser import parse_snql_query

test_cases: Sequence[Tuple[Mapping[str, Any], Optional[Set[int]]]] = [
Expand Down Expand Up @@ -176,17 +178,19 @@ def test_find_projects(
) -> None:
events = get_dataset("events")
if expected_projects is None:
with pytest.raises(ParsingException):
with pytest.raises(ValidationException):
request = json_to_snql(query_body, "events")
request.validate()
query, _ = parse_snql_query(str(request.query), events)
assert isinstance(query, Query)
run_entity_validators(query, HTTPQuerySettings())
identity_translate(query)
else:
request = json_to_snql(query_body, "events")
request.validate()
query, _ = parse_snql_query(str(request.query), events)
assert isinstance(query, Query)
run_entity_validators(query, HTTPQuerySettings())
translated_query = identity_translate(query)
project_ids_ast = get_object_ids_in_query_ast(translated_query, "project_id")
assert project_ids_ast == expected_projects
4 changes: 2 additions & 2 deletions tests/pipeline/test_entity_processing_stage_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@
),
SelectedExpression(
"_snuba_right",
Column("_snuba_right", "groups", "right_col"),
Column("_snuba_right", "groups", "status"),
),
],
condition=binary_condition(
Expand Down Expand Up @@ -398,7 +398,7 @@
),
SelectedExpression(
"_snuba_right",
Column("_snuba_right", None, "right_col"),
Column("_snuba_right", None, "status"),
),
],
),
Expand Down
5 changes: 2 additions & 3 deletions tests/subscriptions/test_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from snuba.datasets.entities.factory import get_entity
from snuba.datasets.entity_subscriptions.validators import InvalidSubscriptionError
from snuba.datasets.factory import get_dataset
from snuba.query.exceptions import InvalidQueryException
from snuba.query.parser.exceptions import ParsingException
from snuba.query.exceptions import InvalidQueryException, ValidationException
from snuba.query.validation.validators import ColumnValidationMode
from snuba.redis import RedisClientKey, get_redis_client
from snuba.subscriptions.data import SubscriptionData
Expand Down Expand Up @@ -86,7 +85,7 @@ def test(self, subscription: SubscriptionData) -> None:
def test_invalid_condition_column(self, subscription: SubscriptionData) -> None:
override_entity_column_validator(EntityKey.EVENTS, ColumnValidationMode.ERROR)
creator = SubscriptionCreator(self.dataset, EntityKey.EVENTS)
with raises(ParsingException):
with raises(ValidationException):
creator.create(
subscription,
self.timer,
Expand Down
4 changes: 3 additions & 1 deletion tests/web/test_query_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from snuba.datasets.entities.factory import get_entity
from snuba.datasets.factory import get_dataset
from snuba.query import SelectedExpression
from snuba.query.conditions import in_condition
from snuba.query.data_source.simple import Entity
from snuba.query.expressions import Column
from snuba.query.expressions import Column, Literal
from snuba.query.logical import Query
from snuba.query.query_settings import HTTPQuerySettings
from snuba.request import Request
Expand All @@ -29,6 +30,7 @@ def run_query() -> None:
selected_columns=[
SelectedExpression("event_id", Column("_snuba_event_id", None, "event_id")),
],
condition=in_condition(Column(None, None, "project_id"), [Literal(None, 123)]),
)

query_settings = HTTPQuerySettings(referrer="asd")
Expand Down
2 changes: 2 additions & 0 deletions tests/web/test_transform_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from snuba.datasets.factory import get_dataset
from snuba.processor import InsertEvent
from snuba.query import SelectedExpression
from snuba.query.conditions import in_condition
from snuba.query.data_source.simple import Entity
from snuba.query.expressions import Column, FunctionCall, Literal
from snuba.query.logical import Query
Expand Down Expand Up @@ -75,6 +76,7 @@ def test_transform_column_names() -> None:
),
),
],
condition=in_condition(Column(None, None, "project_id"), [Literal(None, 1)]),
)
query_settings = HTTPQuerySettings(referrer="asd")

Expand Down

0 comments on commit 4f20c55

Please sign in to comment.