diff --git a/src/sentry/runner/commands/backup.py b/src/sentry/runner/commands/backup.py
index 44a097515c625..764b3b666c055 100644
--- a/src/sentry/runner/commands/backup.py
+++ b/src/sentry/runner/commands/backup.py
@@ -1,11 +1,13 @@
 from __future__ import annotations
 
+from abc import ABC, abstractmethod
 from datetime import datetime, timedelta, timezone
 from difflib import unified_diff
 from io import StringIO
-from typing import NamedTuple, NewType
+from typing import Dict, List, NamedTuple, NewType
 
 import click
+from dateutil import parser
 from django.apps import apps
 from django.core import management, serializers
 from django.core.serializers import serialize
@@ -20,15 +22,15 @@
     default=better_default_encoder, indent=2, ignore_nan=True, sort_keys=True
 )
 
-ComparatorName = NewType("ComparatorName", str)
-ModelName = NewType("ModelName", str)
+ComparatorKind = NewType("ComparatorKind", str)
 
 
 # TODO(team-ospo/#155): Figure out if we are going to use `pk` as part of the identifier, or some other kind of sequence number internal to the JSON export instead.
 class InstanceID(NamedTuple):
-    """Every entry in the generated backup JSON file should have a unique model+pk combination, which serves as its identifier."""
+    """Every entry in the generated backup JSON file should have a unique model+pk combination,
+    which serves as its identifier."""
 
-    model: ModelName
+    model: str
     pk: int
 
     def pretty(self) -> str:
@@ -38,12 +40,12 @@ def pretty(self) -> str:
 class ComparatorFinding(NamedTuple):
     """Store all information about a single failed matching between expected and actual output."""
 
-    name: ComparatorName
+    kind: ComparatorKind
     on: InstanceID
     reason: str = ""
 
     def pretty(self) -> str:
-        return f"Finding(\n\tname: {self.name!r},\n\ton: {self.on.pretty()},\n\treason: {self.reason}\n)"
+        return f"Finding(\n\tkind: {self.kind!r},\n\ton: {self.on.pretty()},\n\treason: {self.reason}\n)"
 
 
 class ComparatorFindings:
@@ -59,8 +61,105 @@ def pretty(self) -> str:
         return "\n".join(f.pretty() for f in self.findings)
 
 
-def validate(expect: JSONData, actual: JSONData) -> ComparatorFindings:
-    """Ensures that originally imported data correctly matches actual outputted data, and produces a list of reasons why not when it doesn't"""
+class JSONScrubbingComparator(ABC):
+    """An abstract class that compares and then scrubs some set of fields that, by a more nuanced
+    definition than mere strict byte-for-byte equality, are expected to maintain some relation on
+    otherwise equivalent JSON instances of the same model.
+
+    Each class inheriting from `JSONScrubbingComparator` should override the abstract `compare`
+    method with its own comparison logic. The `scrub` method is universal (it merely moves the
+    compared fields from the `fields` dictionary to the non-diffed `scrubbed` dictionary).
+
+    If multiple comparators are used sequentially on a single model (see the `DEFAULT_COMPARATORS`
+    dict below for specific mappings), all of the `compare(...)` methods are called before any of
+    the `scrub(...)` methods are. This ensures that comparators that touch the same fields do not
+    have their inputs mangled by one another."""
+
+    def __init__(self, fields: list[str]):
+        self.fields = fields
+
+    def check(self, side: str, data: JSONData) -> None:
+        """Ensure that we have received valid JSON data at runtime."""
+
+        if "model" not in data or not isinstance(data["model"], str):
+            raise RuntimeError(f"The {side} input must have a `model` string assigned to it.")
+        if "pk" not in data or not isinstance(data["pk"], int):
+            raise RuntimeError(f"The {side} input must have a numerical `pk` entry.")
+        if "fields" not in data or not isinstance(data["fields"], dict):
+            raise RuntimeError(f"The {side} input must have a `fields` dictionary.")
+
+    @abstractmethod
+    def compare(self, on: InstanceID, left: JSONData, right: JSONData) -> ComparatorFinding | None:
+        """An abstract method signature, to be implemented by inheriting classes with their own
+        comparison logic. Implementations of this method MUST take care not to mutate the method's
+        inputs!"""
+
+        pass
+
+    def scrub(self, on: InstanceID, left: JSONData, right: JSONData) -> None:
+        """Removes all of the fields compared by this comparator from the `fields` dict, so that the
+        remaining fields may be compared for equality.
+
+        Parameters:
+        - on: An `InstanceID` that must be shared by both versions of the JSON model being
+            compared.
+        - left: One of the models being compared (usually the "before" version).
+        - right: The other model it is being compared against (usually the "after" or
+            post-processed version).
+        """
+
+        self.check("left", left)
+        self.check("right", right)
+        if "scrubbed" not in left:
+            left["scrubbed"] = {}
+        if "scrubbed" not in right:
+            right["scrubbed"] = {}
+        for field in self.fields:
+            del left["fields"][field]
+            left["scrubbed"][f"{self.get_kind()}::{field}"] = True
+            del right["fields"][field]
+            right["scrubbed"][f"{self.get_kind()}::{field}"] = True
+
+    def get_kind(self) -> ComparatorKind:
+        """A unique identifier for this particular derivation of JSONScrubbingComparator, which will
+        be bubbled up in ComparatorFindings when they are generated."""
+
+        return self.__class__.__name__
+
+
+class DateUpdatedComparator(JSONScrubbingComparator):
+    """Comparator that ensures that the specified field's value on the right input is an ISO-8601
+    date that is greater than (i.e., occurs after) the specified field's value on the left input."""
+
+    def __init__(self, field: str):
+        super().__init__([field])
+        self.field = field
+
+    def compare(self, on: InstanceID, left: JSONData, right: JSONData) -> ComparatorFinding | None:
+        left_date_updated = left["fields"][self.field]
+        right_date_updated = right["fields"][self.field]
+        if parser.parse(left_date_updated) > parser.parse(right_date_updated):
+            return ComparatorFinding(
+                kind=self.get_kind(),
+                on=on,
+                reason=f"""the left date_updated value on `{on}` ({left_date_updated}) was not less
                than or equal to the right ({right_date_updated})""",
+            )
+
+
+ComparatorList = List[JSONScrubbingComparator]
+ComparatorMap = Dict[str, ComparatorList]
+DEFAULT_COMPARATORS: ComparatorMap = {
+    "sentry.userrole": [DateUpdatedComparator("date_updated")],
+    "sentry.userroleuser": [DateUpdatedComparator("date_updated")],
+}
+
+
+def validate(
+    expect: JSONData, actual: JSONData, comparators: ComparatorMap = DEFAULT_COMPARATORS
+) -> ComparatorFindings:
+    """Ensures that originally imported data correctly matches actual outputted data, and produces a
+    list of reasons why not when it doesn't"""
 
     def json_lines(obj: JSONData) -> list[str]:
         """Take a JSONData object and pretty-print it as JSON."""
@@ -78,7 +177,7 @@ def json_lines(obj: JSONData) -> list[str]:
     for model in actual:
         id = InstanceID(model["model"], model["pk"])
         if id in act_models:
-            findings.append(ComparatorFinding("duplicate_entry", id))
+            findings.append(ComparatorFinding("DuplicateEntry", id))
         else:
             act_models[id] = model
 
@@ -87,20 +186,35 @@ def json_lines(obj: JSONData) -> list[str]:
     missing = sorted(exp_models.keys() - act_models.keys())
     for id in extra:
         del act_models[id]
-        findings.append(ComparatorFinding("unexpected_entry", id))
+        findings.append(ComparatorFinding("UnexpectedEntry", id))
     for id in missing:
         del exp_models[id]
-        findings.append(ComparatorFinding("missing_entry", id))
+        findings.append(ComparatorFinding("MissingEntry", id))
 
     # We only perform custom comparisons and JSON diffs on non-duplicate entries that exist in both
    # outputs.
     for id, act in act_models.items():
         exp = exp_models[id]
 
+        # Try comparators applicable for this specific model.
+        if id.model in comparators:
+            # We take care to run ALL of the `compare()` methods on each comparator before calling
+            # any `scrub()` methods. This ensures that, in cases where a single model uses multiple
+            # comparators that touch the same fields, one comparator does not accidentally scrub the
+            # inputs for its follower. If `compare()` functions are well-behaved (that is, they
+            # don't mutate their inputs), this should be sufficient to ensure that the order in
+            # which comparators are applied does not change the final output.
+            for cmp in comparators[id.model]:
+                res = cmp.compare(id, exp, act)
+                if res:
+                    findings.append(ComparatorFinding(cmp.get_kind(), id, res))
+            for cmp in comparators[id.model]:
+                cmp.scrub(id, exp, act)
+
+        # Finally, perform a diff on the remaining JSON.
         diff = list(unified_diff(json_lines(exp["fields"]), json_lines(act["fields"]), n=3))
         if diff:
-            findings.append(ComparatorFinding("json_diff", id, "\n " + "\n ".join(diff)))
+            findings.append(ComparatorFinding("UnequalJSON", id, "\n " + "\n ".join(diff)))
 
     return findings
diff --git a/tests/sentry/backup/test_comparators.py b/tests/sentry/backup/test_comparators.py
new file mode 100644
index 0000000000000..49e227c5477b0
--- /dev/null
+++ b/tests/sentry/backup/test_comparators.py
@@ -0,0 +1,44 @@
+from sentry.runner.commands.backup import DateUpdatedComparator, InstanceID
+
+
+def test_good_date_updated_comparator():
+    cmp = DateUpdatedComparator("my_date_field")
+    id = InstanceID("test", 1)
+    left = {
+        "model": "test",
+        "pk": 1,
+        "fields": {
+            "my_date_field": "2023-06-22T23:00:00.123Z",
+        },
+    }
+    right = {
+        "model": "test",
+        "pk": 1,
+        "fields": {
+            "my_date_field": "2023-06-22T23:00:00.123Z",
+        },
+    }
+    assert cmp.compare(id, left, right) is None
+
+
+def test_bad_date_updated_comparator():
+    cmp = DateUpdatedComparator("my_date_field")
+    id = InstanceID("test", 1)
+    left = {
+        "model": "test",
+        "pk": 1,
+        "fields": {
+            "my_date_field": "2023-06-22T23:12:34.567Z",
+        },
+    }
+    right = {
+        "model": "test",
+        "pk": 1,
+        "fields": {
+            "my_date_field": "2023-06-22T23:00:00.001Z",
+        },
+    }
+    res = cmp.compare(id, left, right)
+    assert res is not None
+    assert res.on == id
+    assert res.kind == "DateUpdatedComparator"
diff --git a/tests/sentry/backup/test_correctness.py b/tests/sentry/backup/test_correctness.py
index 45821d1892e1e..a326e5b7fa75c 100644
--- a/tests/sentry/backup/test_correctness.py
+++ b/tests/sentry/backup/test_correctness.py
@@ -2,17 +2,26 @@
 
 import pytest
 from click.testing import CliRunner
-from freezegun import freeze_time
 
-from sentry.runner.commands.backup import import_, validate
+from sentry.runner.commands.backup import (
+    DEFAULT_COMPARATORS,
+    ComparatorMap,
+    InstanceID,
+    import_,
+    validate,
+)
 from sentry.silo import unguarded_write
 from sentry.testutils.factories import get_fixture_path
 from sentry.utils import json
 from sentry.utils.pytest.fixtures import django_db_all
 from tests.sentry.backup import ValidationError, tmp_export_to_file
 
+EMPTY_COMPARATORS_FOR_TESTING: ComparatorMap = {}
 
-def import_export_then_validate(tmp_path: Path, fixture_file_name: str) -> None:
+
+def import_export_then_validate(
+    tmp_path: Path, fixture_file_name: str, map: ComparatorMap = EMPTY_COMPARATORS_FOR_TESTING
+) -> None:
     """Test helper that validates that data imported from a fixture `.json` file correctly matches
     the actual outputted export data."""
 
@@ -24,22 +33,31 @@ def import_export_then_validate(tmp_path: Path, fixture_file_name: str) -> None:
     rv = CliRunner().invoke(import_, [str(fixture_file_path)])
     assert rv.exit_code == 0, rv.output
 
-    res = validate(expect, tmp_export_to_file(tmp_path.joinpath("tmp_test_file.json")))
+    res = validate(
+        expect,
+        tmp_export_to_file(tmp_path.joinpath("tmp_test_file.json")),
+        map,
+    )
     if res.findings:
         raise ValidationError(res)
 
 
 @django_db_all(transaction=True, reset_sequences=True)
-@freeze_time("2023-06-22T23:00:00.123Z")
 def test_good_fresh_install_validation(tmp_path):
-    import_export_then_validate(tmp_path, "fresh-install.json")
+    import_export_then_validate(tmp_path, "fresh-install.json", DEFAULT_COMPARATORS)
 
 
 @django_db_all(transaction=True, reset_sequences=True)
 def test_bad_fresh_install_validation(tmp_path):
+
     with pytest.raises(ValidationError) as excinfo:
         import_export_then_validate(tmp_path, "fresh-install.json")
-        assert len(excinfo.value.info.findings) == 2
+    info = excinfo.value.info
+    assert len(info.findings) == 2
+    assert info.findings[0].kind == "UnequalJSON"
+    assert info.findings[0].on == InstanceID("sentry.userrole", 1)
+    assert info.findings[1].kind == "UnequalJSON"
+    assert info.findings[1].on == InstanceID("sentry.userroleuser", 1)
 
 
 @django_db_all(transaction=True, reset_sequences=True)
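
For reviewers, a minimal usage sketch of the comparator flow introduced above (not part of the change set). It assumes the module can be imported in a configured Sentry environment; the `sentry.userrole` payloads and the `name` field are toy data, whereas real inputs would come from export JSON as in the tests.

from sentry.runner.commands.backup import DateUpdatedComparator, validate

# Toy "before" and "after" export entries for a single model instance. Only the
# model/pk/fields shape matters here; the values are made up.
expect = [
    {
        "model": "sentry.userrole",
        "pk": 1,
        "fields": {"date_updated": "2023-06-22T23:00:00.123Z", "name": "admin"},
    }
]
actual = [
    {
        "model": "sentry.userrole",
        "pk": 1,
        "fields": {"date_updated": "2023-06-23T00:00:00.000Z", "name": "admin"},
    }
]

# Route a DateUpdatedComparator to the model by name. compare() accepts the later
# right-hand date, then scrub() moves `date_updated` into the `scrubbed` dict on both
# sides, so only the remaining fields (here just `name`) reach the byte-for-byte diff.
comparators = {"sentry.userrole": [DateUpdatedComparator("date_updated")]}
findings = validate(expect, actual, comparators)
assert not findings.findings  # no ComparatorFindings expected for this pair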