From ce8470829bc683345695bcd3ef26492757727c14 Mon Sep 17 00:00:00 2001
From: Lyn Nagara
Date: Fri, 23 Jun 2023 14:08:35 -0700
Subject: [PATCH] feat: A small step towards schema enforcement (#4405)

Devserver now enforces the schema for the events topic
---
 snuba/cli/consumer.py                    | 9 +++++++++
 snuba/cli/devserver.py                   | 1 +
 snuba/cli/dlq_consumer.py                | 1 +
 snuba/consumers/consumer.py              | 5 +++--
 snuba/consumers/consumer_builder.py      | 3 +++
 snuba/web/views.py                       | 1 +
 tests/consumers/test_consumer_builder.py | 7 ++++---
 tests/test_consumer.py                   | 7 +++++--
 8 files changed, 27 insertions(+), 7 deletions(-)

diff --git a/snuba/cli/consumer.py b/snuba/cli/consumer.py
index d77bba7da8..d6e242a268 100644
--- a/snuba/cli/consumer.py
+++ b/snuba/cli/consumer.py
@@ -116,6 +116,13 @@
     type=int,
 )
 @click.option("--join-timeout", type=int, help="Join timeout in seconds.", default=5)
+@click.option(
+    "--enforce-schema",
+    type=bool,
+    is_flag=True,
+    default=False,
+    help="Enforce schema on the raw events topic.",
+)
 @click.option(
     "--profile-path", type=click.Path(dir_okay=True, file_okay=False, exists=True)
 )
@@ -144,6 +151,7 @@ def consumer(
     input_block_size: Optional[int],
     output_block_size: Optional[int],
     join_timeout: int = 5,
+    enforce_schema: bool = False,
     log_level: Optional[str] = None,
     profile_path: Optional[str] = None,
     max_poll_interval_ms: Optional[int] = None,
@@ -201,6 +209,7 @@ def consumer(
         slice_id=slice_id,
         join_timeout=join_timeout,
         max_poll_interval_ms=max_poll_interval_ms,
+        enforce_schema=enforce_schema,
     )
 
     consumer = consumer_builder.build_base_consumer()
diff --git a/snuba/cli/devserver.py b/snuba/cli/devserver.py
index 15c7280600..2e7e59514d 100644
--- a/snuba/cli/devserver.py
+++ b/snuba/cli/devserver.py
@@ -81,6 +81,7 @@ def devserver(*, bootstrap: bool, workers: bool) -> None:
                 "--no-strict-offset-reset",
                 "--log-level=debug",
                 "--storage=errors",
+                "--enforce-schema",
             ],
         ),
         (
diff --git a/snuba/cli/dlq_consumer.py b/snuba/cli/dlq_consumer.py
index 10afcd2786..c6c12c065c 100644
--- a/snuba/cli/dlq_consumer.py
+++ b/snuba/cli/dlq_consumer.py
@@ -169,6 +169,7 @@ def handler(signum: int, frame: Any) -> None:
         metrics=metrics,
         slice_id=instruction.slice_id,
         join_timeout=None,
+        enforce_schema=False,
     )
 
     consumer = consumer_builder.build_dlq_consumer(instruction)
diff --git a/snuba/consumers/consumer.py b/snuba/consumers/consumer.py
index e8137bfb7b..f8215721e2 100644
--- a/snuba/consumers/consumer.py
+++ b/snuba/consumers/consumer.py
@@ -523,6 +523,7 @@ def process_message(
     processor: MessageProcessor,
     consumer_group: str,
     snuba_logical_topic: SnubaTopic,
+    enforce_schema: bool,
     message: Message[KafkaPayload],
 ) -> Union[None, BytesInsertBatch, ReplacementBatch]:
     local_metrics = MetricsWrapper(
@@ -573,6 +574,8 @@
             _LAST_INVALID_MESSAGE[snuba_logical_topic.name] = start
             sentry_sdk.set_tag("invalid_message_schema", "true")
             logger.warning(err, exc_info=True)
+            if enforce_schema:
+                raise
 
     # TODO: this is not the most efficient place to emit a metric, but
     # as long as should_validate is behind a sample rate it should be
@@ -604,8 +607,6 @@
         value = message.value
         raise InvalidMessage(value.partition, value.offset) from err
 
-    return None
-
     if isinstance(result, InsertBatch):
         return BytesInsertBatch(
             [json_row_encoder.encode(row) for row in result.rows],
diff --git a/snuba/consumers/consumer_builder.py b/snuba/consumers/consumer_builder.py
index 0f95a0c309..77f98fb6b6 100644
--- a/snuba/consumers/consumer_builder.py
+++ b/snuba/consumers/consumer_builder.py
@@ -71,6 +71,7 @@ def __init__(
         metrics: MetricsBackend,
         slice_id: Optional[int],
         join_timeout: Optional[float],
+        enforce_schema: bool,
         profile_path: Optional[str] = None,
         max_poll_interval_ms: Optional[int] = None,
     ) -> None:
@@ -83,6 +84,7 @@ def __init__(
         self.__consumer_config = consumer_config
         self.__kafka_params = kafka_params
         self.consumer_group = kafka_params.group_id
+        self.__enforce_schema = enforce_schema
 
         broker_config = build_kafka_consumer_configuration(
             self.__consumer_config.raw_topic.broker_config,
@@ -213,6 +215,7 @@ def build_streaming_strategy_factory(
                 processor,
                 self.consumer_group,
                 logical_topic,
+                self.__enforce_schema,
             ),
             collector=build_batch_writer(
                 table_writer,
diff --git a/snuba/web/views.py b/snuba/web/views.py
index aac3a1d23a..2b30836bf2 100644
--- a/snuba/web/views.py
+++ b/snuba/web/views.py
@@ -644,6 +644,7 @@ def commit(
             stream_loader.get_processor(),
             "consumer_grouup",
             stream_loader.get_default_topic_spec().topic,
+            False,
         ),
         build_batch_writer(table_writer, metrics=metrics),
         max_batch_size=1,
diff --git a/tests/consumers/test_consumer_builder.py b/tests/consumers/test_consumer_builder.py
index c76084c006..c8d9c1d0cb 100644
--- a/tests/consumers/test_consumer_builder.py
+++ b/tests/consumers/test_consumer_builder.py
@@ -18,7 +18,7 @@
 from snuba.datasets.storages.storage_key import StorageKey
 from snuba.utils.metrics.backends.abstract import MetricsBackend
 from snuba.utils.metrics.wrapper import MetricsWrapper
-from tests.fixtures import get_raw_event
+from tests.fixtures import get_raw_error_message
 from tests.test_consumer import get_row_count
 
 test_storage_key = StorageKey("errors")
@@ -61,6 +61,7 @@
     ),
     slice_id=None,
     join_timeout=5,
+    enforce_schema=True,
 )
 
 optional_consumer_config = resolve_consumer_config(
@@ -104,6 +105,7 @@
     ),
     slice_id=None,
     join_timeout=5,
+    enforce_schema=True,
 )
 
 
@@ -160,8 +162,7 @@ def test_run_processing_strategy() -> None:
     strategy_factory = consumer_builder.build_streaming_strategy_factory()
     strategy = strategy_factory.create_with_partitions(commit, partitions)
 
-    raw_message = get_raw_event()
-    json_string = json.dumps([2, "insert", raw_message, []])
+    json_string = json.dumps(get_raw_error_message())
 
     message = Message(
         BrokerValue(
diff --git a/tests/test_consumer.py b/tests/test_consumer.py
index 582c8eb5fd..dd7fa9e99f 100644
--- a/tests/test_consumer.py
+++ b/tests/test_consumer.py
@@ -31,13 +31,16 @@
 from snuba.utils.streams.topics import Topic as SnubaTopic
 from tests.assertions import assert_changes
 from tests.backends.metrics import TestingMetricsBackend, Timing
+from tests.fixtures import get_raw_error_message
 
 
 def test_streaming_consumer_strategy() -> None:
     messages = (
         Message(
             BrokerValue(
-                KafkaPayload(None, b"{}", []),
+                KafkaPayload(
+                    None, json.dumps(get_raw_error_message()).encode("utf-8"), []
+                ),
                 Partition(Topic("events"), 0),
                 i,
                 datetime.now(),
@@ -72,7 +75,7 @@ def write_step() -> ProcessedMessageBatchWriter:
     factory = KafkaConsumerStrategyFactory(
         None,
         functools.partial(
-            process_message, processor, "consumer_group", SnubaTopic.EVENTS
+            process_message, processor, "consumer_group", SnubaTopic.EVENTS, True
         ),
         write_step,
         max_batch_size=10,
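
Note for reviewers: the behavioral core of this patch is the new enforce_schema flag threaded from the CLI down to process_message: a schema violation is still logged and tagged either way, but the exception is only re-raised when enforcement is on. The standalone sketch below mirrors that control flow; validate_event and handle_message are hypothetical stand-ins for illustration, not Snuba APIs, and the real check in process_message validates against the logical topic's schema rather than a hard-coded field list.

import json
import logging

logger = logging.getLogger(__name__)


class InvalidSchema(Exception):
    """Raised when a payload does not match the expected event schema."""


def validate_event(payload: dict) -> None:
    # Hypothetical stand-in for the real schema check: require a couple
    # of fields an error event is expected to carry.
    for key in ("event_id", "project_id"):
        if key not in payload:
            raise InvalidSchema(f"missing required field: {key}")


def handle_message(raw: bytes, enforce_schema: bool) -> dict:
    payload = json.loads(raw)
    try:
        validate_event(payload)
    except InvalidSchema as err:
        # Mirrors the patch: the violation is always logged...
        logger.warning(err, exc_info=True)
        if enforce_schema:
            # ...but only fails the message when enforcement is enabled,
            # as the devserver consumer now does for the events topic.
            raise
    return payload

With the flag unset, which remains the default (note the explicit enforce_schema=False in the DLQ consumer and the False passed in snuba/web/views.py), an invalid message is logged and processing continues; passing --enforce-schema, as the devserver invocation now does for the errors storage, turns the same violation into a hard failure.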