From 4f20c553116d36e20291c0a74e98f0518207d63c Mon Sep 17 00:00:00 2001 From: Enoch Tang Date: Tue, 30 Apr 2024 13:12:05 -0400 Subject: [PATCH] feat(dev-exp): Move entity validators from parser to entity processing stage (#5826) * move entity validators * update tests * fix test * remove duplicate assert --- snuba/datasets/plans/entity_validation.py | 88 +++++++++++++++++++ snuba/pipeline/stages/query_processing.py | 5 ++ snuba/query/parser/validation/__init__.py | 36 -------- snuba/query/parser/validation/functions.py | 17 +++- snuba/query/snql/parser.py | 3 - tests/clickhouse/query_dsl/test_project_id.py | 8 +- .../test_entity_processing_stage_composite.py | 4 +- tests/subscriptions/test_subscription.py | 5 +- tests/web/test_query_cache.py | 4 +- tests/web/test_transform_names.py | 2 + 10 files changed, 124 insertions(+), 48 deletions(-) create mode 100644 snuba/datasets/plans/entity_validation.py diff --git a/snuba/datasets/plans/entity_validation.py b/snuba/datasets/plans/entity_validation.py new file mode 100644 index 0000000000..18125c0566 --- /dev/null +++ b/snuba/datasets/plans/entity_validation.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from typing import Union + +import sentry_sdk + +from snuba.datasets.entities.factory import get_entity +from snuba.query import Query +from snuba.query.composite import CompositeQuery +from snuba.query.data_source.join import JoinClause +from snuba.query.data_source.simple import Entity as QueryEntity +from snuba.query.exceptions import InvalidQueryException, ValidationException +from snuba.query.logical import Query as LogicalQuery +from snuba.query.parser.validation.functions import FunctionCallsValidator +from snuba.query.query_settings import QuerySettings +from snuba.state import explain_meta + +EXPRESSION_VALIDATORS = [FunctionCallsValidator()] + + +def _validate_query(query: Query) -> None: + """ + Applies all the expression validators in one pass over the AST. 
+ """ + + for exp in query.get_all_expressions(): + for v in EXPRESSION_VALIDATORS: + v.validate(exp, query.get_from_clause()) + + +def _validate_entities_with_query( + query: Union[CompositeQuery[QueryEntity], LogicalQuery] +) -> None: + """ + Applies all validators defined on the entities in the query + """ + if isinstance(query, LogicalQuery): + entity = get_entity(query.get_from_clause().key) + try: + for v in entity.get_validators(): + v.validate(query) + except InvalidQueryException as e: + raise ValidationException( + f"validation failed for entity {query.get_from_clause().key.value}: {e}", + should_report=e.should_report, + ) + else: + from_clause = query.get_from_clause() + if isinstance(from_clause, JoinClause): + alias_map = from_clause.get_alias_node_map() + for alias, node in alias_map.items(): + assert isinstance(node.data_source, QueryEntity) # mypy + entity = get_entity(node.data_source.key) + try: + for v in entity.get_validators(): + v.validate(query, alias) + except InvalidQueryException as e: + raise ValidationException( + f"validation failed for entity {node.data_source.key.value}: {e}", + should_report=e.should_report, + ) + + +VALIDATORS = [_validate_query, _validate_entities_with_query] + + +def run_entity_validators( + query: Union[CompositeQuery[QueryEntity], LogicalQuery], + settings: QuerySettings | None = None, +) -> None: + """ + Main function for applying all validators associated with an entity + """ + for validator_func in VALIDATORS: + description = getattr(validator_func, "__name__", "custom") + with sentry_sdk.start_span(op="validator", description=description): + if settings and settings.get_dry_run(): + with explain_meta.with_query_differ( + "entity_validator", description, query + ): + validator_func(query) + else: + validator_func(query) + + if isinstance(query, CompositeQuery): + from_clause = query.get_from_clause() + if isinstance(from_clause, (LogicalQuery, CompositeQuery)): + run_entity_validators(from_clause, settings)
diff --git a/snuba/pipeline/stages/query_processing.py b/snuba/pipeline/stages/query_processing.py index 097719d880..34d245a2d0 100644 --- a/snuba/pipeline/stages/query_processing.py +++ b/snuba/pipeline/stages/query_processing.py @@ -1,5 +1,6 @@ from snuba.clickhouse.query import Query as ClickhouseQuery from snuba.datasets.plans.entity_processing import run_entity_processing_executor +from snuba.datasets.plans.entity_validation import run_entity_validators from snuba.datasets.plans.storage_processing import ( apply_storage_processors, build_best_plan, @@ -22,6 +23,10 @@ class EntityProcessingStage( def _process_data( self, pipe_input: QueryPipelineData[Request] ) -> ClickhouseQuery | CompositeQuery[Table]: + # Execute entity validators on logical query + run_entity_validators(pipe_input.data.query, pipe_input.query_settings) + + # Run entity processors on the query and transform logical query into a physical query if isinstance(pipe_input.data.query, LogicalQuery): return run_entity_processing_executor( pipe_input.data.query, pipe_input.query_settings diff --git a/snuba/query/parser/validation/__init__.py b/snuba/query/parser/validation/__init__.py index 7e43eca255..e69de29bb2 100644 --- a/snuba/query/parser/validation/__init__.py +++ b/snuba/query/parser/validation/__init__.py @@ -1,36 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Sequence - -from snuba.query import Query -from snuba.query.data_source import DataSource -from snuba.query.expressions import Expression - - -class ExpressionValidator(ABC): - """ - Validates an individual expression in a Snuba logical query. 
- """ - - @abstractmethod - def validate(self, exp: Expression, data_source: DataSource) -> None: - """ - If the expression is valid according to this validator it - returns, otherwise it raises a subclass of - snuba.query.parser.ValidationException - """ - raise NotImplementedError - - -from snuba.query.parser.validation.functions import FunctionCallsValidator - -validators: Sequence[ExpressionValidator] = [FunctionCallsValidator()] - - -def validate_query(query: Query) -> None: - """ - Applies all the expression validators in one pass over the AST. - """ - - for exp in query.get_all_expressions(): - for v in validators: - v.validate(exp, query.get_from_clause()) diff --git a/snuba/query/parser/validation/functions.py b/snuba/query/parser/validation/functions.py index 064ea2610f..d8bb992a69 100644 --- a/snuba/query/parser/validation/functions.py +++ b/snuba/query/parser/validation/functions.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from typing import List, Mapping from snuba.clickhouse.columns import Array, DateTime, String @@ -10,7 +11,6 @@ from snuba.query.data_source.visitor import DataSourceVisitor from snuba.query.exceptions import InvalidExpressionException from snuba.query.expressions import Expression, FunctionCall -from snuba.query.parser.validation import ExpressionValidator from snuba.query.validation import FunctionCallValidator, InvalidFunctionCall from snuba.query.validation.functions import AllowedFunctionValidator from snuba.query.validation.signature import Any, Column, SignatureValidator @@ -90,6 +90,21 @@ def visit_join_clause(self, node: JoinClause[QueryEntity]) -> List[QueryEntity]: return node.right_node.accept(self) + node.left_node.accept(self) +class ExpressionValidator(ABC): + """ + Validates an individual expression in a Snuba logical query. 
+ """ + + @abstractmethod + def validate(self, exp: Expression, data_source: DataSource) -> None: + """ + If the expression is valid according to this validator it + returns, otherwise it raises a subclass of + snuba.query.parser.ValidationException + """ + raise NotImplementedError + + class FunctionCallsValidator(ExpressionValidator): """ Applies all function validators on the provided expression. diff --git a/snuba/query/snql/parser.py b/snuba/query/snql/parser.py index 47a078bf0d..0ae88cf5a3 100644 --- a/snuba/query/snql/parser.py +++ b/snuba/query/snql/parser.py @@ -72,7 +72,6 @@ validate_aliases, ) from snuba.query.parser.exceptions import ParsingException, PostProcessingError -from snuba.query.parser.validation import validate_query from snuba.query.query_settings import QuerySettings from snuba.query.schema import POSITIVE_OPERATORS from snuba.query.snql.anonymize import format_snql_anonymized @@ -1491,8 +1490,6 @@ def _post_process( VALIDATORS = [ validate_identifiers_in_lambda, - validate_query, - validate_entities_with_query, ] diff --git a/tests/clickhouse/query_dsl/test_project_id.py b/tests/clickhouse/query_dsl/test_project_id.py index 964c0b43c8..cf285baf10 100644 --- a/tests/clickhouse/query_dsl/test_project_id.py +++ b/tests/clickhouse/query_dsl/test_project_id.py @@ -5,9 +5,11 @@ from snuba.clickhouse.query_dsl.accessors import get_object_ids_in_query_ast from snuba.datasets.factory import get_dataset +from snuba.datasets.plans.entity_validation import run_entity_validators from snuba.datasets.plans.translator.query import identity_translate +from snuba.query.exceptions import ValidationException from snuba.query.logical import Query -from snuba.query.parser.exceptions import ParsingException +from snuba.query.query_settings import HTTPQuerySettings from snuba.query.snql.parser import parse_snql_query test_cases: Sequence[Tuple[Mapping[str, Any], Optional[Set[int]]]] = [ @@ -176,17 +178,19 @@ def test_find_projects( ) -> None: events = 
get_dataset("events") if expected_projects is None: - with pytest.raises(ParsingException): + with pytest.raises(ValidationException): request = json_to_snql(query_body, "events") request.validate() query, _ = parse_snql_query(str(request.query), events) assert isinstance(query, Query) + run_entity_validators(query, HTTPQuerySettings()) identity_translate(query) else: request = json_to_snql(query_body, "events") request.validate() query, _ = parse_snql_query(str(request.query), events) assert isinstance(query, Query) + run_entity_validators(query, HTTPQuerySettings()) translated_query = identity_translate(query) project_ids_ast = get_object_ids_in_query_ast(translated_query, "project_id") assert project_ids_ast == expected_projects diff --git a/tests/pipeline/test_entity_processing_stage_composite.py b/tests/pipeline/test_entity_processing_stage_composite.py index 9b967fea93..5c5e367986 100644 --- a/tests/pipeline/test_entity_processing_stage_composite.py +++ b/tests/pipeline/test_entity_processing_stage_composite.py @@ -328,7 +328,7 @@ ), SelectedExpression( "_snuba_right", - Column("_snuba_right", "groups", "right_col"), + Column("_snuba_right", "groups", "status"), ), ], condition=binary_condition( @@ -398,7 +398,7 @@ ), SelectedExpression( "_snuba_right", - Column("_snuba_right", None, "right_col"), + Column("_snuba_right", None, "status"), ), ], ), diff --git a/tests/subscriptions/test_subscription.py b/tests/subscriptions/test_subscription.py index c5cfec6fd1..be9818aa8d 100644 --- a/tests/subscriptions/test_subscription.py +++ b/tests/subscriptions/test_subscription.py @@ -8,8 +8,7 @@ from snuba.datasets.entities.factory import get_entity from snuba.datasets.entity_subscriptions.validators import InvalidSubscriptionError from snuba.datasets.factory import get_dataset -from snuba.query.exceptions import InvalidQueryException -from snuba.query.parser.exceptions import ParsingException +from snuba.query.exceptions import InvalidQueryException, 
ValidationException from snuba.query.validation.validators import ColumnValidationMode from snuba.redis import RedisClientKey, get_redis_client from snuba.subscriptions.data import SubscriptionData @@ -86,7 +85,7 @@ def test(self, subscription: SubscriptionData) -> None: def test_invalid_condition_column(self, subscription: SubscriptionData) -> None: override_entity_column_validator(EntityKey.EVENTS, ColumnValidationMode.ERROR) creator = SubscriptionCreator(self.dataset, EntityKey.EVENTS) - with raises(ParsingException): + with raises(ValidationException): creator.create( subscription, self.timer, diff --git a/tests/web/test_query_cache.py b/tests/web/test_query_cache.py index 9800f9a761..bcff178af4 100644 --- a/tests/web/test_query_cache.py +++ b/tests/web/test_query_cache.py @@ -10,8 +10,9 @@ from snuba.datasets.entities.factory import get_entity from snuba.datasets.factory import get_dataset from snuba.query import SelectedExpression +from snuba.query.conditions import in_condition from snuba.query.data_source.simple import Entity -from snuba.query.expressions import Column +from snuba.query.expressions import Column, Literal from snuba.query.logical import Query from snuba.query.query_settings import HTTPQuerySettings from snuba.request import Request @@ -29,6 +30,7 @@ def run_query() -> None: selected_columns=[ SelectedExpression("event_id", Column("_snuba_event_id", None, "event_id")), ], + condition=in_condition(Column(None, None, "project_id"), [Literal(None, 123)]), ) query_settings = HTTPQuerySettings(referrer="asd") diff --git a/tests/web/test_transform_names.py b/tests/web/test_transform_names.py index 301c9d322c..a7688889a3 100644 --- a/tests/web/test_transform_names.py +++ b/tests/web/test_transform_names.py @@ -12,6 +12,7 @@ from snuba.datasets.factory import get_dataset from snuba.processor import InsertEvent from snuba.query import SelectedExpression +from snuba.query.conditions import in_condition from snuba.query.data_source.simple import Entity 
from snuba.query.expressions import Column, FunctionCall, Literal from snuba.query.logical import Query @@ -75,6 +76,7 @@ def test_transform_column_names() -> None: ), ), ], + condition=in_condition(Column(None, None, "project_id"), [Literal(None, 1)]), ) query_settings = HTTPQuerySettings(referrer="asd")