Skip to content

Commit

Permalink
ref(hc): Improve test transaction utilities. (#52401)
Browse files Browse the repository at this point in the history
  • Loading branch information
corps authored Jul 9, 2023
1 parent 4dd5e15 commit 82e0ba1
Show file tree
Hide file tree
Showing 9 changed files with 224 additions and 68 deletions.
4 changes: 4 additions & 0 deletions src/sentry/db/postgres/roles.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from django.db.transaction import get_connection

from sentry.silo.patches.silo_aware_transaction_patch import determine_using_by_silo_mode


@contextlib.contextmanager
def in_test_psql_role_override(role_name: str, using: str | None = None):
Expand All @@ -18,6 +20,8 @@ def in_test_psql_role_override(role_name: str, using: str | None = None):
yield
return

using = determine_using_by_silo_mode(using)

with get_connection(using).cursor() as conn:
conn.execute("SELECT user")
(cur,) = conn.fetchone()
Expand Down
32 changes: 20 additions & 12 deletions src/sentry/db/postgres/transactions.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

import contextlib
import sys
import threading

from django.conf import settings
from django.db import transaction


@contextlib.contextmanager
def django_test_transaction_water_mark(using: str = "default"):
def django_test_transaction_water_mark(using: str | None = None):
"""
Hybrid cloud outbox flushing depends heavily on transaction.on_commit logic, but our tests do not follow
production in terms of isolation (TestCase users two outer transactions, and stubbed RPCs cannot simulate
Expand All @@ -20,23 +23,27 @@ def django_test_transaction_water_mark(using: str = "default"):
yield
return

from sentry.testutils import hybrid_cloud

# No need to manage the watermark unless conftest has configured a watermark
if using not in hybrid_cloud.simulated_transaction_watermarks.state:
yield
if using is None:
with contextlib.ExitStack() as stack:
for db_name in settings.DATABASES: # type: ignore
stack.enter_context(django_test_transaction_water_mark(db_name))
yield
return

from sentry.testutils import hybrid_cloud

connection = transaction.get_connection(using)

prev = hybrid_cloud.simulated_transaction_watermarks.state[using]
hybrid_cloud.simulated_transaction_watermarks.state[using] = len(connection.savepoint_ids)
prev = hybrid_cloud.simulated_transaction_watermarks.state.get(using, 0)
hybrid_cloud.simulated_transaction_watermarks.state[
using
] = hybrid_cloud.simulated_transaction_watermarks.get_transaction_depth(connection)
try:
connection.maybe_flush_commit_hooks()
yield
finally:
hybrid_cloud.simulated_transaction_watermarks.state[using] = min(
len(connection.savepoint_ids), prev
hybrid_cloud.simulated_transaction_watermarks.get_transaction_depth(connection), prev
)


Expand Down Expand Up @@ -80,6 +87,7 @@ def in_test_assert_no_transaction(msg: str):

from sentry.testutils import hybrid_cloud

for using, watermark in hybrid_cloud.simulated_transaction_watermarks.state.items():
conn = transaction.get_connection(using)
assert len(conn.savepoint_ids) <= watermark, msg
for using in settings.DATABASES: # type: ignore
assert not hybrid_cloud.simulated_transaction_watermarks.connection_above_watermark(
using
), msg
16 changes: 11 additions & 5 deletions src/sentry/models/outbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@
sane_repr,
)
from sentry.db.postgres.roles import in_test_psql_role_override
from sentry.db.postgres.transactions import django_test_transaction_water_mark
from sentry.db.postgres.transactions import (
django_test_transaction_water_mark,
in_test_assert_no_transaction,
)
from sentry.services.hybrid_cloud import REGION_NAME_LENGTH
from sentry.silo import SiloMode
from sentry.utils import metrics
Expand Down Expand Up @@ -199,7 +202,7 @@ def next_schedule(self, now: datetime.datetime) -> datetime.datetime:

def save(self, **kwds: Any):
if _outbox_context.flushing_enabled:
transaction.on_commit(lambda: self.drain_shard())
transaction.on_commit(lambda: self.drain_shard(), using=router.db_for_write(type(self)))

tags = {"category": OutboxCategory(self.category).name}
metrics.incr("outbox.saved", 1, tags=tags)
Expand Down Expand Up @@ -271,6 +274,9 @@ def send_signal(self) -> None:
def drain_shard(
self, flush_all: bool = False, _test_processing_barrier: threading.Barrier | None = None
) -> None:
in_test_assert_no_transaction(
"drain_shard should only be called outside of any active transaction!"
)
# When we are flushing in a local context, we don't care about outboxes created concurrently --
# at best our logic depends on previously created outboxes.
latest_shard_row: OutboxBase | None = None
Expand All @@ -292,8 +298,8 @@ def drain_shard(

shard_row.process()

if _test_processing_barrier:
_test_processing_barrier.wait()
if _test_processing_barrier:
_test_processing_barrier.wait()


# Outboxes bound from region silo -> control silo
Expand Down Expand Up @@ -469,7 +475,7 @@ def outbox_context(inner: Atomic | None = None, flush: bool | None = None) -> Co
original = _outbox_context.flushing_enabled

if inner:
with in_test_psql_role_override("postgres"), inner:
with in_test_psql_role_override("postgres", using=inner.using), inner:
_outbox_context.flushing_enabled = flush
try:
yield
Expand Down
6 changes: 6 additions & 0 deletions src/sentry/silo/patches/silo_aware_transaction_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,18 @@ def determine_using_by_silo_mode(using):


def patch_silo_aware_atomic():
global _default_on_commit, _default_get_connection, _default_atomic_impl

current_django_version = get_version()
assert current_django_version.startswith("2.2."), (
"Newer versions of Django have an additional 'durable' parameter in atomic,"
+ " verify the signature before updating the version check."
)

_default_atomic_impl = transaction.atomic
_default_on_commit = transaction.on_commit
_default_get_connection = transaction.get_connection

transaction.atomic = siloed_atomic # type:ignore
transaction.on_commit = siloed_on_commit
transaction.get_connection = siloed_get_connection
88 changes: 49 additions & 39 deletions src/sentry/testutils/hybrid_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
from types import TracebackType
from typing import Any, Callable, Generator, List, Mapping, Optional, Sequence, Tuple, Type

from django.db import connections, transaction
from django.db.backends.base.base import BaseDatabaseWrapper

from sentry.models.organizationmember import OrganizationMember
from sentry.models.organizationmembermapping import OrganizationMemberMapping
from sentry.services.hybrid_cloud import DelegatedBySiloMode, hc_test_stub
from sentry.silo import SiloMode
from sentry.testutils.silo import exempt_from_silo_limits
from sentry.testutils.silo import assume_test_silo_mode


class use_real_service:
Expand Down Expand Up @@ -92,7 +95,7 @@ def cb(service: Any, method_name: str, *args: Sequence[Any], **kwds: Mapping[str


class HybridCloudTestMixin:
@exempt_from_silo_limits()
@assume_test_silo_mode(SiloMode.CONTROL)
def assert_org_member_mapping(self, org_member: OrganizationMember, expected=None):
org_member.refresh_from_db()
org_member_mapping_query = OrganizationMemberMapping.objects.filter(
Expand All @@ -108,15 +111,6 @@ def assert_org_member_mapping(self, org_member: OrganizationMember, expected=Non
# only either user_id or email should have a value, but not both.
assert (email is None and user_id) or (email and user_id is None)

assert (
OrganizationMember.objects.filter(
organization_id=org_member.organization_id,
user_id=user_id,
email=email,
).count()
== 1
)

assert org_member_mapping.role == org_member.role
if org_member.inviter_id:
assert org_member_mapping.inviter_id == org_member.inviter_id
Expand All @@ -127,7 +121,7 @@ def assert_org_member_mapping(self, org_member: OrganizationMember, expected=Non
for key, expected_value in expected.items():
assert getattr(org_member_mapping, key) == expected_value

@exempt_from_silo_limits()
@assume_test_silo_mode(SiloMode.CONTROL)
def assert_org_member_mapping_not_exists(self, org_member: OrganizationMember):
email = org_member.email
user_id = org_member.user_id
Expand All @@ -143,6 +137,20 @@ def assert_org_member_mapping_not_exists(self, org_member: OrganizationMember):
class SimulatedTransactionWatermarks(threading.local):
state: dict[str, int] = {}

@staticmethod
def get_transaction_depth(connection: BaseDatabaseWrapper) -> int:
total = len(connection.savepoint_ids)
if connection.in_atomic_block:
total += 1
return total

def connection_above_watermark(
self, using: str | None = None, connection: BaseDatabaseWrapper | None = None
) -> bool:
if connection is None:
connection = transaction.get_connection(using)
return self.get_transaction_depth(connection) > self.state.get(connection.alias, 0)


simulated_transaction_watermarks = SimulatedTransactionWatermarks()

Expand All @@ -156,37 +164,31 @@ def simulate_on_commit(request: Any):
outbox processing) to correctly detect which savepoint should call the `on_commit` hook.
"""

from django.conf import settings
from django.db import transaction
from django.db.backends.base.base import BaseDatabaseWrapper
from django.test import TestCase as DjangoTestCase

request_node_cls = request.node.cls
is_django_test_case = request_node_cls is not None and issubclass(
request_node_cls, DjangoTestCase
)
simulated_transaction_watermarks.state = {}

if request_node_cls is None or not issubclass(request_node_cls, DjangoTestCase):
yield
return

_old_atomic_exit = transaction.Atomic.__exit__
_old_transaction_on_commit = transaction.on_commit

def maybe_flush_commit_hooks(connection):
if (
connection.in_atomic_block
and len(connection.savepoint_ids)
<= simulated_transaction_watermarks.state[connection.alias or "default"]
and not connection.closed_in_transaction
and not connection.needs_rollback
):
old_validate = connection.validate_no_atomic_block
connection.validate_no_atomic_block = lambda: None
try:
connection.run_and_clear_commit_hooks()
finally:
connection.validate_no_atomic_block = old_validate
elif not connection.in_atomic_block or not connection.savepoint_ids:
assert not connection.run_on_commit, "Incidental run_on_commits detected!"
def maybe_flush_commit_hooks(connection: BaseDatabaseWrapper):
if connection.closed_in_transaction or connection.needs_rollback:
return

if simulated_transaction_watermarks.connection_above_watermark(connection=connection):
return

old_validate = connection.validate_no_atomic_block
connection.validate_no_atomic_block = lambda: None # type: ignore
try:
connection.run_and_clear_commit_hooks()
finally:
connection.validate_no_atomic_block = old_validate # type: ignore

def new_atomic_exit(self, exc_type, *args, **kwds):
_old_atomic_exit(self, exc_type, *args, **kwds)
Expand All @@ -199,19 +201,27 @@ def new_atomic_on_commit(func, using=None):
_old_transaction_on_commit(func, using)
maybe_flush_commit_hooks(transaction.get_connection(using))

for conn in connections.all():
# This value happens to match the number of outer transactions in
# a django test case. Unfortunately, the timing of when setup is called
# vs when that final outer transaction is added makes it impossible to
# sample the value directly -- we just have to specify it here.
# That said, there are tests that would fail if this number were wrong.
if is_django_test_case:
simulated_transaction_watermarks.state[conn.alias] = 2
else:
simulated_transaction_watermarks.state[
conn.alias
] = simulated_transaction_watermarks.get_transaction_depth(conn)

functools.update_wrapper(new_atomic_exit, _old_atomic_exit)
functools.update_wrapper(new_atomic_on_commit, _old_transaction_on_commit)
transaction.Atomic.__exit__ = new_atomic_exit # type: ignore
transaction.on_commit = new_atomic_on_commit
setattr(BaseDatabaseWrapper, "maybe_flush_commit_hooks", maybe_flush_commit_hooks)

# django tests start inside two transactions
for db_name in settings.DATABASES:
simulated_transaction_watermarks.state[db_name] = 1
try:
yield
finally:
transaction.Atomic.__exit__ = _old_atomic_exit # type: ignore
transaction.on_commit = _old_transaction_on_commit
simulated_transaction_watermarks.state.clear()
delattr(BaseDatabaseWrapper, "maybe_flush_commit_hooks")
9 changes: 9 additions & 0 deletions tests/sentry/api/endpoints/test_relay_projectconfigs_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,22 @@
from django.urls import reverse
from sentry_relay.auth import generate_key_pair

from sentry.db.postgres.transactions import in_test_hide_transaction_boundary
from sentry.models.relay import Relay
from sentry.relay.config import ProjectConfig
from sentry.tasks.relay import build_project_config
from sentry.testutils.hybrid_cloud import simulated_transaction_watermarks
from sentry.utils import json
from sentry.utils.pytest.fixtures import django_db_all


@pytest.fixture(autouse=True)
def disable_auto_on_commit():
simulated_transaction_watermarks.state["default"] = -1
with in_test_hide_transaction_boundary():
yield


@pytest.fixture
def key_pair():
return generate_key_pair()
Expand Down
Loading

0 comments on commit 82e0ba1

Please sign in to comment.