From 5d34fda57504c7ad7253d5612e2f2c9088446449 Mon Sep 17 00:00:00 2001 From: Brian Simpson Date: Mon, 22 Feb 2021 15:36:36 -0800 Subject: [PATCH 1/2] (squash) Copy code from graphql no changes have been made yet! --- baseplate/lib/circuit_breaker/__init__.py | 0 baseplate/lib/circuit_breaker/breaker.py | 103 ++++++++++++++++++ .../lib/circuit_breaker/breaker_client.py | 58 ++++++++++ .../lib/circuit_breaker/cassandra_context.py | 34 ++++++ baseplate/lib/circuit_breaker/errors.py | 4 + baseplate/lib/circuit_breaker/http_context.py | 30 +++++ baseplate/lib/circuit_breaker/observer.py | 30 +++++ .../lib/circuit_breaker/redis_context.py | 21 ++++ .../lib/circuit_breaker/thrift_context.py | 35 ++++++ tests/unit/lib/circuit_breaker_test.py | 102 +++++++++++++++++ 10 files changed, 417 insertions(+) create mode 100644 baseplate/lib/circuit_breaker/__init__.py create mode 100644 baseplate/lib/circuit_breaker/breaker.py create mode 100644 baseplate/lib/circuit_breaker/breaker_client.py create mode 100644 baseplate/lib/circuit_breaker/cassandra_context.py create mode 100644 baseplate/lib/circuit_breaker/errors.py create mode 100644 baseplate/lib/circuit_breaker/http_context.py create mode 100644 baseplate/lib/circuit_breaker/observer.py create mode 100644 baseplate/lib/circuit_breaker/redis_context.py create mode 100644 baseplate/lib/circuit_breaker/thrift_context.py create mode 100644 tests/unit/lib/circuit_breaker_test.py diff --git a/baseplate/lib/circuit_breaker/__init__.py b/baseplate/lib/circuit_breaker/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/baseplate/lib/circuit_breaker/breaker.py b/baseplate/lib/circuit_breaker/breaker.py new file mode 100644 index 000000000..1ce0ef10a --- /dev/null +++ b/baseplate/lib/circuit_breaker/breaker.py @@ -0,0 +1,103 @@ +from collections import deque +from datetime import datetime +from datetime import timedelta +from enum import Enum +from math import ceil +from random import random +from typing import 
Deque + + +class BreakerState(Enum): + WORKING = "working" + TRIPPED = "tripped" + # trip immediately after failure + TESTING = "testing" + + +class Breaker: + _state: BreakerState = BreakerState.WORKING + _is_bucket_full: bool = False + + def __init__( + self, + name: str, + samples: int = 20, + trip_failure_ratio: float = 0.5, + trip_for: timedelta = timedelta(minutes=1), + fuzz_ratio: float = 0.1, + ): + """ + * name: str - full name/path of the circuit breaker + * samples: int - number of previous results used to calculate the trip failure ratio + * trip_failure_percent: float - the minimum ratio of sampled failed results to trip the breaker + * trip_for: timedelta - how long to remain tripped before resetting the breaker + * fuzz_ratio: float - how much to randomly add/subtract to the trip_for time + """ + self.name = name + self.samples = samples + self.results_bucket: Deque = deque([], self.samples) + self.tripped_until: datetime = datetime.utcnow() + self.trip_threshold = ceil(trip_failure_ratio * samples) + self.trip_for = trip_for + self.fuzz_ratio = fuzz_ratio + self.reset() + + @property + def state(self) -> BreakerState: + if self._state == BreakerState.TRIPPED and (datetime.utcnow() >= self.tripped_until): + self.set_state(BreakerState.TESTING) + + return self._state + + def register_attempt(self, success: bool): + # This breaker has already tripped, so ignore the "late" registrations + if self.state == BreakerState.TRIPPED: + return + + if not success: + self.failures += 1 + + if self._is_bucket_full and not self.results_bucket[0]: + self.failures -= 1 + + self.results_bucket.append(success) + + if not self._is_bucket_full and (len(self.results_bucket) == self.samples): + self._is_bucket_full = True + + if success and (self.state == BreakerState.TESTING): + self.reset() + return + + if self.state == BreakerState.TESTING: + # failure in the TESTING state trips the breaker immediately + self.trip() + return + + if not self._is_bucket_full: + # no need 
to check anything if we haven't recorded enough samples + return + + # check for trip condition + if self.failures >= self.trip_threshold: + self.trip() + + def set_state(self, state: BreakerState): + self._state = state + + def trip(self): + if self.fuzz_ratio > 0.0: + fuzz_ratio = ((2 * random()) - 1.0) * self.fuzz_ratio + fuzz_ratio = 1 + fuzz_ratio + else: + fuzz_ratio = 1.0 + + self.tripped_until = datetime.utcnow() + (self.trip_for * fuzz_ratio) + self.set_state(BreakerState.TRIPPED) + + def reset(self): + self.results_bucket.clear() + self.failures = 0 + self._is_bucket_full = False + self.tripped_until = None + self.set_state(BreakerState.WORKING) diff --git a/baseplate/lib/circuit_breaker/breaker_client.py b/baseplate/lib/circuit_breaker/breaker_client.py new file mode 100644 index 000000000..4908c319d --- /dev/null +++ b/baseplate/lib/circuit_breaker/breaker_client.py @@ -0,0 +1,58 @@ +from typing import Any + +import baseplate + +from baseplate.clients import ContextFactory +from baseplate.lib import config + +from graphql_api.lib.circuit_breaker.breaker import Breaker + + +class CircuitBreakerFactory(ContextFactory): + def __init__(self, name, cfg): + self.breaker_box = CircuitBreakerBox(name.replace("_breaker", ""), cfg) + + def make_object_for_context(self, name: str, span: "baseplate.Span") -> Any: + return self.breaker_box + + @staticmethod + def get_breaker_cfg(app_config, default_prefix, cfg_prefix, cfg_spec): + cfg = config.parse_config(app_config, {cfg_prefix: cfg_spec}) + breaker_cfg = getattr(cfg, cfg_prefix) + default_cfg = config.parse_config(app_config, {default_prefix: cfg_spec}) + default_breaker_cfg = getattr(default_cfg, default_prefix) + + for k in cfg_spec: + if getattr(breaker_cfg, k) is None: + setattr(breaker_cfg, k, getattr(default_breaker_cfg, k)) + return breaker_cfg + + @classmethod + def from_config(cls, name, app_config, default_prefix, cfg_prefix, cfg_spec): + breaker_cfg = cls.get_breaker_cfg(app_config, default_prefix, 
cfg_prefix, cfg_spec) + return cls(name, breaker_cfg) + + +class CircuitBreakerBox: + def __init__(self, name, cfg): + self.name = name + self.cfg = cfg + self.breaker_box = {} + + def get_endpoint_breaker(self, endpoint=None): + if not endpoint: + # service breaker + endpoint = "service" + + # lazy add breaker into breaker box + if endpoint not in self.breaker_box: + breaker = Breaker( + name=f"{self.name}.{endpoint}", + samples=self.cfg.samples, + trip_failure_ratio=self.cfg.trip_failure_ratio, + trip_for=self.cfg.trip_for, + fuzz_ratio=self.cfg.fuzz_ratio, + ) + + self.breaker_box[endpoint] = breaker + return self.breaker_box[endpoint] diff --git a/baseplate/lib/circuit_breaker/cassandra_context.py b/baseplate/lib/circuit_breaker/cassandra_context.py new file mode 100644 index 000000000..d5f29929c --- /dev/null +++ b/baseplate/lib/circuit_breaker/cassandra_context.py @@ -0,0 +1,34 @@ +import sys + +from contextlib import contextmanager + +from cassandra.cluster import DriverException +from cassandra.cluster import NoHostAvailable + +from graphql_api.errors import raise_graphql_server_error +from graphql_api.lib.circuit_breaker.errors import BreakerTrippedError +from graphql_api.lib.circuit_breaker.observer import BreakerObserver +from graphql_api.lib.delegations import PLATFORM_SLACK + + +@contextmanager +def cassandra_circuit_breaker(context): + breaker = context.cassandra_breaker.get_endpoint_breaker() + breaker_observer = BreakerObserver(context, breaker) + + try: + breaker_observer.check_state() + except BreakerTrippedError: + raise_graphql_server_error( + context, "Cassandra connection failure", upstream_exc_info=sys.exc_info(), owner=PLATFORM_SLACK + ) + + success: bool = True + try: + yield + except (NoHostAvailable, DriverException): + # Errors of connection, timeout, etc. 
+ success = False + raise + finally: + breaker_observer.register_attempt(success) diff --git a/baseplate/lib/circuit_breaker/errors.py b/baseplate/lib/circuit_breaker/errors.py new file mode 100644 index 000000000..d2625496f --- /dev/null +++ b/baseplate/lib/circuit_breaker/errors.py @@ -0,0 +1,4 @@ +class BreakerTrippedError(Exception): + def __init__(self): + default_message = "Breaker tripped!" + super(BreakerTrippedError, self).__init__(default_message) diff --git a/baseplate/lib/circuit_breaker/http_context.py b/baseplate/lib/circuit_breaker/http_context.py new file mode 100644 index 000000000..9b653dfd3 --- /dev/null +++ b/baseplate/lib/circuit_breaker/http_context.py @@ -0,0 +1,30 @@ +from contextlib import contextmanager + +from requests.exceptions import ConnectionError + +from graphql_api.errors import GraphQLUpstreamHTTPRequestError +from graphql_api.http_adapter import HTTPRequestTimeout +from graphql_api.lib.circuit_breaker.observer import BreakerObserver + + +@contextmanager +def http_circuit_breaker(context, breaker): + breaker_observer = BreakerObserver(context, breaker) + breaker_observer.check_state() + + success: bool = True + + try: + yield + except (ConnectionError, HTTPRequestTimeout): + # ConnectionError can be caused by DNS issues + success = False + raise + except GraphQLUpstreamHTTPRequestError as e: + if e.code >= 500: + success = False + raise + except Exception: + raise + finally: + breaker_observer.register_attempt(success) diff --git a/baseplate/lib/circuit_breaker/observer.py b/baseplate/lib/circuit_breaker/observer.py new file mode 100644 index 000000000..c92eae7d4 --- /dev/null +++ b/baseplate/lib/circuit_breaker/observer.py @@ -0,0 +1,30 @@ +from graphql_api.lib.circuit_breaker.breaker import BreakerState +from graphql_api.lib.circuit_breaker.errors import BreakerTrippedError + +METRICS_PREFIX = "breakers" + + +class BreakerObserver: + def __init__(self, context, breaker): + self.context = context + self.breaker = breaker + 
self.name = breaker.name + + def on_fast_failed_request(self): + self.context.logger.debug(f"Circuit breaker '{self.name}' tripped; request failed fast") + self.context.trace.incr_tag(f"{METRICS_PREFIX}.{self.name}.request.fail_fast") + + def on_state_change(self, prev, curr): + self.context.trace.incr_tag(f"{METRICS_PREFIX}.{self.name}.state_change.{prev.value}.{curr.value}") + + def register_attempt(self, success): + prev_state = self.breaker.state + self.breaker.register_attempt(success) + curr_state = self.breaker.state + if prev_state != curr_state: + self.on_state_change(prev_state, curr_state) + + def check_state(self): + if self.breaker.state == BreakerState.TRIPPED: + self.on_fast_failed_request() + raise BreakerTrippedError() diff --git a/baseplate/lib/circuit_breaker/redis_context.py b/baseplate/lib/circuit_breaker/redis_context.py new file mode 100644 index 000000000..ccaff1510 --- /dev/null +++ b/baseplate/lib/circuit_breaker/redis_context.py @@ -0,0 +1,21 @@ +from contextlib import contextmanager + +from redis.exceptions import ConnectionError +from redis.exceptions import TimeoutError + +from graphql_api.lib.circuit_breaker.observer import BreakerObserver + + +@contextmanager +def redis_circuit_breaker(context, breaker): + breaker_observer = BreakerObserver(context, breaker) + breaker_observer.check_state() + + success: bool = True + try: + yield + except (ConnectionError, TimeoutError): + success = False + raise + finally: + breaker_observer.register_attempt(success) diff --git a/baseplate/lib/circuit_breaker/thrift_context.py b/baseplate/lib/circuit_breaker/thrift_context.py new file mode 100644 index 000000000..ef8b05b7c --- /dev/null +++ b/baseplate/lib/circuit_breaker/thrift_context.py @@ -0,0 +1,35 @@ +from contextlib import contextmanager + +from thrift.protocol.TProtocol import TProtocolException +from thrift.Thrift import TApplicationException +from thrift.Thrift import TException +from thrift.transport.TTransport import TTransportException 
+ +from graphql_api.lib.circuit_breaker.observer import BreakerObserver + + +@contextmanager +def thrift_circuit_breaker(context, breaker): + breaker_observer = BreakerObserver(context, breaker) + breaker_observer.check_state() + + success: bool = True + try: + yield + + except (TApplicationException, TTransportException, TProtocolException): + # Unhealthy errors: + # * Unknown thrift failure + # * DNS, socket, connection error + # * serialization error + success = False + raise + except TException: + # Healthy errors: known thrift exception, defined in the IDL + raise + except Exception: + # Any other + success = False + raise + finally: + breaker_observer.register_attempt(success) diff --git a/tests/unit/lib/circuit_breaker_test.py b/tests/unit/lib/circuit_breaker_test.py new file mode 100644 index 000000000..e3a13d099 --- /dev/null +++ b/tests/unit/lib/circuit_breaker_test.py @@ -0,0 +1,102 @@ +from datetime import datetime +from datetime import timedelta + +import pytest + +from graphql_api.lib.circuit_breaker.breaker import Breaker +from graphql_api.lib.circuit_breaker.breaker import BreakerState + + +@pytest.fixture +def breaker(): + return Breaker( + name="test", samples=4, trip_failure_ratio=0.5, trip_for=timedelta(seconds=60), fuzz_ratio=0.1 + ) + + +@pytest.fixture +def tripped_breaker(breaker): + for attempt in [True, True, False, False]: + breaker.register_attempt(attempt) + return breaker + + +@pytest.fixture +def tripped_exact_breaker(breaker): + breaker.fuzz_ratio = 0.0 + for attempt in [True, True, False, False]: + breaker.register_attempt(attempt) + return breaker + + +@pytest.fixture +def testing_breaker(tripped_breaker): + tripped_breaker.tripped_until = datetime.utcnow() + return tripped_breaker + + +@pytest.mark.parametrize( + "attempts,expected_state", + [ + ([], BreakerState.WORKING), + ([True], BreakerState.WORKING), + ([False], BreakerState.WORKING), + ([True, True], BreakerState.WORKING), + ([False, False], BreakerState.WORKING), + ([True, 
True, True], BreakerState.WORKING), + ([False, False, False], BreakerState.WORKING), + ([True, True, True, True], BreakerState.WORKING), + ([True, True, True, False], BreakerState.WORKING), + ([True, True, False, True], BreakerState.WORKING), + ([True, False, True, True], BreakerState.WORKING), + ([False, True, True, True], BreakerState.WORKING), + ([False, False, False, False], BreakerState.TRIPPED), + ([True, True, False, False], BreakerState.TRIPPED), + ([False, False, True, True], BreakerState.TRIPPED), + ([True, False, True, False], BreakerState.TRIPPED), + ([False, True, False, True], BreakerState.TRIPPED), + ], +) +def test_breaker_state(breaker, attempts, expected_state): + for attempt in attempts: + breaker.register_attempt(attempt) + assert breaker.state == expected_state + + +def test_testing_state(testing_breaker): + assert testing_breaker.state == BreakerState.TESTING + + +def test_trip_after_successful_test(testing_breaker): + testing_breaker.register_attempt(True) + assert testing_breaker.state == BreakerState.WORKING + + +def test_trip_after_failed_test(testing_breaker): + testing_breaker.register_attempt(False) + assert testing_breaker.state == BreakerState.TRIPPED + + +def test_late_register_success(tripped_breaker): + tripped_breaker.register_attempt(True) + assert tripped_breaker.state == BreakerState.TRIPPED + assert tripped_breaker.failures == 2 + + +def test_late_register_failure(tripped_breaker): + tripped_breaker.register_attempt(False) + assert tripped_breaker.state == BreakerState.TRIPPED + assert tripped_breaker.failures == 2 + + +def test_trip_for_exact(tripped_exact_breaker): + assert tripped_exact_breaker.fuzz_ratio == 0.0 + expected_tripped_until = datetime.utcnow() + timedelta(seconds=60) + assert tripped_exact_breaker.tripped_until <= expected_tripped_until + + +def test_trip_for_fuzzing(tripped_breaker): + assert tripped_breaker.fuzz_ratio == 0.1 + expected_tripped_until = datetime.utcnow() + timedelta(seconds=60) + delta = 
abs(tripped_breaker.tripped_until - expected_tripped_until) + assert delta <= timedelta(seconds=6, milliseconds=1) From 53f7105e44004e9c548e57546e81aa6662d76f0f Mon Sep 17 00:00:00 2001 From: Brian Simpson Date: Mon, 22 Feb 2021 16:02:22 -0800 Subject: [PATCH 2/2] Add circuit breakers --- baseplate/lib/circuit_breaker/__init__.py | 10 + baseplate/lib/circuit_breaker/breaker.py | 48 +++-- .../lib/circuit_breaker/breaker_client.py | 58 ------ .../lib/circuit_breaker/cassandra_context.py | 34 ---- baseplate/lib/circuit_breaker/errors.py | 4 +- baseplate/lib/circuit_breaker/factory.py | 164 +++++++++++++++ baseplate/lib/circuit_breaker/http_context.py | 30 --- baseplate/lib/circuit_breaker/observer.py | 30 --- .../lib/circuit_breaker/redis_context.py | 21 -- .../lib/circuit_breaker/thrift_context.py | 35 ---- tests/integration/circuit_breaker_tests.py | 192 ++++++++++++++++++ tests/unit/lib/circuit_breaker_test.py | 10 +- 12 files changed, 411 insertions(+), 225 deletions(-) delete mode 100644 baseplate/lib/circuit_breaker/breaker_client.py delete mode 100644 baseplate/lib/circuit_breaker/cassandra_context.py create mode 100644 baseplate/lib/circuit_breaker/factory.py delete mode 100644 baseplate/lib/circuit_breaker/http_context.py delete mode 100644 baseplate/lib/circuit_breaker/observer.py delete mode 100644 baseplate/lib/circuit_breaker/redis_context.py delete mode 100644 baseplate/lib/circuit_breaker/thrift_context.py create mode 100644 tests/integration/circuit_breaker_tests.py diff --git a/baseplate/lib/circuit_breaker/__init__.py b/baseplate/lib/circuit_breaker/__init__.py index e69de29bb..9c5bf37b0 100644 --- a/baseplate/lib/circuit_breaker/__init__.py +++ b/baseplate/lib/circuit_breaker/__init__.py @@ -0,0 +1,10 @@ +from baseplate.lib.circuit_breaker.errors import BreakerTrippedError +from baseplate.lib.circuit_breaker.factory import breaker_box_from_config +from baseplate.lib.circuit_breaker.factory import CircuitBreakerClientWrapperFactory + + +__all__ = [ + 
"breaker_box_from_config", + "BreakerTrippedError", + "CircuitBreakerClientWrapperFactory", +] diff --git a/baseplate/lib/circuit_breaker/breaker.py b/baseplate/lib/circuit_breaker/breaker.py index 1ce0ef10a..78ebec9be 100644 --- a/baseplate/lib/circuit_breaker/breaker.py +++ b/baseplate/lib/circuit_breaker/breaker.py @@ -15,6 +15,29 @@ class BreakerState(Enum): class Breaker: + """Circuit breaker. + + The circuit breaker has 3 states: + * WORKING (closed) + * TRIPPED (open) + * TESTING (half open) + + During normal operation the circuit breaker is in the WORKING state. + + When the number of failures reaches the threshold the breaker moves to the TRIPPED state. It + stays in this state for the timeout period. + + After the timeout period passes the breaker moves to the TESTING state. If the next attempt + is successful the breaker moves to the WORKING state. If the next attempt is a failure the + breaker moves back to the TRIPPED state. + + :param name: full name/path of the circuit breaker + :param samples: number of previous results used to calculate the trip failure ratio + :param trip_failure_ratio: the minimum ratio of sampled failed results to trip the breaker + :param trip_for: how long to remain tripped before resetting the breaker + :param fuzz_ratio: how much to randomly add/subtract to the trip_for time + """ + _state: BreakerState = BreakerState.WORKING _is_bucket_full: bool = False @@ -26,13 +49,6 @@ def __init__( trip_for: timedelta = timedelta(minutes=1), fuzz_ratio: float = 0.1, ): - """ - * name: str - full name/path of the circuit breaker - * samples: int - number of previous results used to calculate the trip failure ratio - * trip_failure_percent: float - the minimum ratio of sampled failed results to trip the breaker - * trip_for: timedelta - how long to remain tripped before resetting the breaker - * fuzz_ratio: float - how much to randomly add/subtract to the trip_for time - """ self.name = name self.samples = samples self.results_bucket:
Deque = deque([], self.samples) @@ -49,7 +65,13 @@ def state(self) -> BreakerState: return self._state - def register_attempt(self, success: bool): + def register_attempt(self, success: bool) -> None: + """Register a success or failure. + + This may cause the state to change. + + :param success: Whether the attempt was a success (not a failure). + """ # This breaker has already tripped, so ignore the "late" registrations if self.state == BreakerState.TRIPPED: return @@ -82,10 +104,11 @@ def register_attempt(self, success: bool): if self.failures >= self.trip_threshold: self.trip() - def set_state(self, state: BreakerState): + def set_state(self, state: BreakerState) -> None: self._state = state - def trip(self): + def trip(self) -> None: + """Change state to TRIPPED and set the timeout after which state will change to TESTING.""" if self.fuzz_ratio > 0.0: fuzz_ratio = ((2 * random()) - 1.0) * self.fuzz_ratio fuzz_ratio = 1 + fuzz_ratio @@ -95,9 +118,10 @@ def trip(self): self.tripped_until = datetime.utcnow() + (self.trip_for * fuzz_ratio) self.set_state(BreakerState.TRIPPED) - def reset(self): + def reset(self) -> None: + """Reset to freshly initialized WORKING state.""" self.results_bucket.clear() self.failures = 0 self._is_bucket_full = False - self.tripped_until = None + self.tripped_until = datetime.utcnow() self.set_state(BreakerState.WORKING) diff --git a/baseplate/lib/circuit_breaker/breaker_client.py b/baseplate/lib/circuit_breaker/breaker_client.py deleted file mode 100644 index 4908c319d..000000000 --- a/baseplate/lib/circuit_breaker/breaker_client.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import Any - -import baseplate - -from baseplate.clients import ContextFactory -from baseplate.lib import config - -from graphql_api.lib.circuit_breaker.breaker import Breaker - - -class CircuitBreakerFactory(ContextFactory): - def __init__(self, name, cfg): - self.breaker_box = CircuitBreakerBox(name.replace("_breaker", ""), cfg) - - def 
make_object_for_context(self, name: str, span: "baseplate.Span") -> Any: - return self.breaker_box - - @staticmethod - def get_breaker_cfg(app_config, default_prefix, cfg_prefix, cfg_spec): - cfg = config.parse_config(app_config, {cfg_prefix: cfg_spec}) - breaker_cfg = getattr(cfg, cfg_prefix) - default_cfg = config.parse_config(app_config, {default_prefix: cfg_spec}) - default_breaker_cfg = getattr(default_cfg, default_prefix) - - for k in cfg_spec: - if getattr(breaker_cfg, k) is None: - setattr(breaker_cfg, k, getattr(default_breaker_cfg, k)) - return breaker_cfg - - @classmethod - def from_config(cls, name, app_config, default_prefix, cfg_prefix, cfg_spec): - breaker_cfg = cls.get_breaker_cfg(app_config, default_prefix, cfg_prefix, cfg_spec) - return cls(name, breaker_cfg) - - -class CircuitBreakerBox: - def __init__(self, name, cfg): - self.name = name - self.cfg = cfg - self.breaker_box = {} - - def get_endpoint_breaker(self, endpoint=None): - if not endpoint: - # service breaker - endpoint = "service" - - # lazy add breaker into breaker box - if endpoint not in self.breaker_box: - breaker = Breaker( - name=f"{self.name}.{endpoint}", - samples=self.cfg.samples, - trip_failure_ratio=self.cfg.trip_failure_ratio, - trip_for=self.cfg.trip_for, - fuzz_ratio=self.cfg.fuzz_ratio, - ) - - self.breaker_box[endpoint] = breaker - return self.breaker_box[endpoint] diff --git a/baseplate/lib/circuit_breaker/cassandra_context.py b/baseplate/lib/circuit_breaker/cassandra_context.py deleted file mode 100644 index d5f29929c..000000000 --- a/baseplate/lib/circuit_breaker/cassandra_context.py +++ /dev/null @@ -1,34 +0,0 @@ -import sys - -from contextlib import contextmanager - -from cassandra.cluster import DriverException -from cassandra.cluster import NoHostAvailable - -from graphql_api.errors import raise_graphql_server_error -from graphql_api.lib.circuit_breaker.errors import BreakerTrippedError -from graphql_api.lib.circuit_breaker.observer import BreakerObserver -from 
graphql_api.lib.delegations import PLATFORM_SLACK - - -@contextmanager -def cassandra_circuit_breaker(context): - breaker = context.cassandra_breaker.get_endpoint_breaker() - breaker_observer = BreakerObserver(context, breaker) - - try: - breaker_observer.check_state() - except BreakerTrippedError: - raise_graphql_server_error( - context, "Cassandra connection failure", upstream_exc_info=sys.exc_info(), owner=PLATFORM_SLACK - ) - - success: bool = True - try: - yield - except (NoHostAvailable, DriverException): - # Errors of connection, timeout, etc. - success = False - raise - finally: - breaker_observer.register_attempt(success) diff --git a/baseplate/lib/circuit_breaker/errors.py b/baseplate/lib/circuit_breaker/errors.py index d2625496f..47ab658d3 100644 --- a/baseplate/lib/circuit_breaker/errors.py +++ b/baseplate/lib/circuit_breaker/errors.py @@ -1,4 +1,4 @@ class BreakerTrippedError(Exception): - def __init__(self): + def __init__(self) -> None: default_message = "Breaker tripped!" 
- super(BreakerTrippedError, self).__init__(default_message) + super().__init__(default_message) diff --git a/baseplate/lib/circuit_breaker/factory.py b/baseplate/lib/circuit_breaker/factory.py new file mode 100644 index 000000000..9082d0704 --- /dev/null +++ b/baseplate/lib/circuit_breaker/factory.py @@ -0,0 +1,164 @@ +import logging + +from contextlib import contextmanager +from datetime import timedelta +from typing import Any +from typing import Dict +from typing import Iterator +from typing import Tuple +from typing import Type + +from baseplate import Span +from baseplate.clients import ContextFactory +from baseplate.lib import config +from baseplate.lib.circuit_breaker.breaker import Breaker +from baseplate.lib.circuit_breaker.breaker import BreakerState +from baseplate.lib.circuit_breaker.errors import BreakerTrippedError + + +logger = logging.getLogger(__name__) + + +class CircuitBreakerClientWrapperFactory(ContextFactory): + """Provide an object combining a client and circuit breaker for use with the client. + + When attached to the baseplate `RequestContext` can be used like: + + ``` + breakable_exceptions = (...) 
# exceptions indicating the service is unhealthy + with context.breaker_wrapped_client.breaker_context("identifier", breakable_exceptions) as svc: + svc.get_something() + ``` + """ + + def __init__(self, client_factory: ContextFactory, breaker_box: "CircuitBreakerBox"): + self.client_factory = client_factory + self.breaker_box = breaker_box + + def make_object_for_context(self, name: str, span: Span) -> Any: + client = self.client_factory.make_object_for_context(name, span) + + return CircuitBreakerWrappedClient(span, self.breaker_box, client) + + +class CircuitBreakerWrappedClient: + def __init__(self, span: Span, breaker_box: "CircuitBreakerBox", client: Any): + self.span = span + self.breaker_box = breaker_box + self._client = client + + @property + def client(self) -> Any: + """Return the raw, undecorated client""" + return self._client + + @contextmanager + def breaker_context( + self, operation: str, breakable_exceptions: Tuple[Type[Exception]] + ) -> Iterator[Any]: + """Get a context manager to perform client operations within. + + Yields the client to use within the breaker context. + + The context manager manages the Breaker's state and registers + successes and failures. + + When the `Breaker` is in TRIPPED state all calls to this context + manager will raise a `BreakerTrippedError` exception. + + :param operation: The operation name, used to get a specific `Breaker`. + :param breakable_exceptions: Tuple of exceptions that count as failures + """ + breaker = self.breaker_box.get(operation) + + if breaker.state == BreakerState.TRIPPED: + logger.debug("Circuit breaker '%s' tripped; request failed fast", breaker.name) + self.span.incr_tag(f"breakers.{breaker.name}.request.fail_fast") + raise BreakerTrippedError() + + success: bool = True + try: + # yield to the application code that will use + # the client covered by this breaker. if this + # raises an exception we will catch it here. 
+ yield self._client + except breakable_exceptions: + # only known exceptions in `breakable_exceptions` should trigger + # opening the circuit. the client call may raise exceptions that + # are a meaningful response, like defined thrift IDL exceptions. + success = False + raise + finally: + prev = breaker.state + breaker.register_attempt(success) + final = breaker.state + if prev != final: + self.span.incr_tag( + f"breakers.{breaker.name}.state_change.{prev.value}.{final.value}" + ) + + +class CircuitBreakerBox: + """Container for a client's `Breaker`s. + + Will lazily create `Breaker`s for each operation as needed. There + is no global coordination across operations--each `Breaker` is + isolated and does not consider the state or failure rates in other + `Breaker`s. + + :param name: The base `Breaker` name. The full name is like "name.operation". + :param samples: See `Breaker` + :param trip_failure_ratio: See `Breaker` + :param trip_for: See `Breaker` + :param fuzz_ratio: See `Breaker` + """ + + def __init__( + self, + name: str, + samples: int, + trip_failure_ratio: float, + trip_for: timedelta, + fuzz_ratio: float, + ): + self.name = name + self.samples = samples + self.trip_failure_ratio = trip_failure_ratio + self.trip_for = trip_for + self.fuzz_ratio = fuzz_ratio + self.breaker_box: Dict[str, Breaker] = {} + + def get(self, operation: str) -> Breaker: + # lazy add breaker into breaker box + if operation not in self.breaker_box: + breaker = Breaker( + name=f"{self.name}.{operation}", + samples=self.samples, + trip_failure_ratio=self.trip_failure_ratio, + trip_for=self.trip_for, + fuzz_ratio=self.fuzz_ratio, + ) + self.breaker_box[operation] = breaker + return self.breaker_box[operation] + + +def breaker_box_from_config( + app_config: config.RawConfig, name: str, prefix: str = "breaker.", +) -> CircuitBreakerBox: + """Make a CircuitBreakerBox from a configuration dictionary.""" + # TODO: fix default handling here. 
if these are not set + # they will be None and passed through to the Breaker() constructor + # which will override the defaults set in Breaker() + assert prefix.endswith(".") + parser = config.SpecParser( + { + "samples": config.Optional(config.Integer), + "trip_failure_ratio": config.Optional(config.Float), + "trip_for": config.Optional(config.Timespan), + "fuzz_ratio": config.Optional(config.Float), + } + ) + options = parser.parse(prefix[:-1], app_config) + return CircuitBreakerBox( + name, options.samples, options.trip_failure_ratio, options.trip_for, options.fuzz_ratio + ) diff --git a/baseplate/lib/circuit_breaker/http_context.py b/baseplate/lib/circuit_breaker/http_context.py deleted file mode 100644 index 9b653dfd3..000000000 --- a/baseplate/lib/circuit_breaker/http_context.py +++ /dev/null @@ -1,30 +0,0 @@ -from contextlib import contextmanager - -from requests.exceptions import ConnectionError - -from graphql_api.errors import GraphQLUpstreamHTTPRequestError -from graphql_api.http_adapter import HTTPRequestTimeout -from graphql_api.lib.circuit_breaker.observer import BreakerObserver - - -@contextmanager -def http_circuit_breaker(context, breaker): - breaker_observer = BreakerObserver(context, breaker) - breaker_observer.check_state() - - success: bool = True - - try: - yield - except (ConnectionError, HTTPRequestTimeout): - # ConnectionError can be caused by DNS issues - success = False - raise - except GraphQLUpstreamHTTPRequestError as e: - if e.code >= 500: - success = False - raise - except Exception: - raise - finally: - breaker_observer.register_attempt(success) diff --git a/baseplate/lib/circuit_breaker/observer.py b/baseplate/lib/circuit_breaker/observer.py deleted file mode 100644 index c92eae7d4..000000000 --- a/baseplate/lib/circuit_breaker/observer.py +++ /dev/null @@ -1,30 +0,0 @@ -from graphql_api.lib.circuit_breaker.breaker import BreakerState -from graphql_api.lib.circuit_breaker.errors import BreakerTrippedError - -METRICS_PREFIX = 
"breakers" - - -class BreakerObserver: - def __init__(self, context, breaker): - self.context = context - self.breaker = breaker - self.name = breaker.name - - def on_fast_failed_request(self): - self.context.logger.debug(f"Circuit breaker '{self.name}' tripped; request failed fast") - self.context.trace.incr_tag(f"{METRICS_PREFIX}.{self.name}.request.fail_fast") - - def on_state_change(self, prev, curr): - self.context.trace.incr_tag(f"{METRICS_PREFIX}.{self.name}.state_change.{prev.value}.{curr.value}") - - def register_attempt(self, success): - prev_state = self.breaker.state - self.breaker.register_attempt(success) - curr_state = self.breaker.state - if prev_state != curr_state: - self.on_state_change(prev_state, curr_state) - - def check_state(self): - if self.breaker.state == BreakerState.TRIPPED: - self.on_fast_failed_request() - raise BreakerTrippedError() diff --git a/baseplate/lib/circuit_breaker/redis_context.py b/baseplate/lib/circuit_breaker/redis_context.py deleted file mode 100644 index ccaff1510..000000000 --- a/baseplate/lib/circuit_breaker/redis_context.py +++ /dev/null @@ -1,21 +0,0 @@ -from contextlib import contextmanager - -from redis.exceptions import ConnectionError -from redis.exceptions import TimeoutError - -from graphql_api.lib.circuit_breaker.observer import BreakerObserver - - -@contextmanager -def redis_circuit_breaker(context, breaker): - breaker_observer = BreakerObserver(context, breaker) - breaker_observer.check_state() - - success: bool = True - try: - yield - except (ConnectionError, TimeoutError): - success = False - raise - finally: - breaker_observer.register_attempt(success) diff --git a/baseplate/lib/circuit_breaker/thrift_context.py b/baseplate/lib/circuit_breaker/thrift_context.py deleted file mode 100644 index ef8b05b7c..000000000 --- a/baseplate/lib/circuit_breaker/thrift_context.py +++ /dev/null @@ -1,35 +0,0 @@ -from contextlib import contextmanager - -from thrift.protocol.TProtocol import TProtocolException -from 
thrift.Thrift import TApplicationException -from thrift.Thrift import TException -from thrift.transport.TTransport import TTransportException - -from graphql_api.lib.circuit_breaker.observer import BreakerObserver - - -@contextmanager -def thrift_circuit_breaker(context, breaker): - breaker_observer = BreakerObserver(context, breaker) - breaker_observer.check_state() - - success: bool = True - try: - yield - - except (TApplicationException, TTransportException, TProtocolException): - # Unhealthy errors: - # * Unknown thrift failure - # * DNS, socket, connection error - # * serialization error - success = False - raise - except TException: - # Healthy errors: known thrift exception, defined in the IDL - raise - except Exception: - # Any other - success = False - raise - finally: - breaker_observer.register_attempt(success) diff --git a/tests/integration/circuit_breaker_tests.py b/tests/integration/circuit_breaker_tests.py new file mode 100644 index 000000000..1df34b51f --- /dev/null +++ b/tests/integration/circuit_breaker_tests.py @@ -0,0 +1,192 @@ +import unittest + +from collections import deque +from datetime import datetime +from datetime import timedelta +from unittest import mock + +import pytest + +from pytz import UTC + +from baseplate import Baseplate +from baseplate.clients import ContextFactory +from baseplate.lib.circuit_breaker import breaker_box_from_config +from baseplate.lib.circuit_breaker import CircuitBreakerClientWrapperFactory +from baseplate.lib.circuit_breaker.breaker import BreakerState +from baseplate.lib.circuit_breaker.errors import BreakerTrippedError + +from . 
import TestBaseplateObserver + + +class TestClientFactory(ContextFactory): + def __init__(self, client): + self.client = client + + def make_object_for_context(self, name, span): + return self.client + + +class CircuitBreakerTests(unittest.TestCase): + def setUp(self): + self.breaker_box = breaker_box_from_config( + app_config={ + "brkr.samples": "4", + "brkr.trip_failure_ratio": "0.75", + "brkr.trip_for": "1 minute", + "brkr.fuzz_ratio": "0.1", + }, + name="test_breaker", + prefix="brkr.", + ) + + self.client = mock.Mock() + client_factory = TestClientFactory(self.client) + + wrapped_client_factory = CircuitBreakerClientWrapperFactory( + client_factory, self.breaker_box + ) + + self.baseplate_observer = TestBaseplateObserver() + + baseplate = Baseplate() + baseplate.register(self.baseplate_observer) + baseplate.add_to_context("wrapped_client", wrapped_client_factory) + + self.context = baseplate.make_context_object() + self.server_span = baseplate.make_server_span(self.context, "test") + + def test_breaker_box(self): + breaker_box = self.context.wrapped_client.breaker_box + assert breaker_box.name == "test_breaker" + assert breaker_box.samples == 4 + assert breaker_box.trip_failure_ratio == 0.75 + assert breaker_box.trip_for == timedelta(seconds=60) + assert breaker_box.fuzz_ratio == 0.1 + + @mock.patch("baseplate.lib.circuit_breaker.breaker.random") + @mock.patch("baseplate.lib.circuit_breaker.breaker.datetime") + def test_breaker_context(self, datetime_mock, random_mock): + self.client.get_something.side_effect = [None, AttributeError, KeyError, ValueError] + + datetime_mock.utcnow.side_effect = [ + datetime(2021, 2, 25, 0, 0, 1, tzinfo=UTC), + datetime(2021, 2, 25, 0, 0, 2, tzinfo=UTC), + ] + + random_mock.return_value = 0.1 + + def make_call(_server_span, _context, *args): + with _server_span: + breaker_context = _context.wrapped_client.breaker_context( + operation="get_something", breakable_exceptions=(KeyError, ValueError), + ) + + with breaker_context as 
client: + client.get_something(*args) + + # test a few calls + make_call(self.server_span, self.context, "a") + + with pytest.raises(AttributeError): + # note that this is not in `breakable_exceptions` so + # counts as a success + make_call(self.server_span, self.context, "b") + + with pytest.raises(KeyError): + make_call(self.server_span, self.context, "c") + + with pytest.raises(ValueError): + make_call(self.server_span, self.context, "d") + + self.client.get_something.assert_has_calls( + [mock.call("a"), mock.call("b"), mock.call("c"), mock.call("d")] + ) + + breaker = self.breaker_box.breaker_box["get_something"] + assert breaker.name == "test_breaker.get_something" + assert breaker.results_bucket == deque([True, True, False, False]) + assert breaker.state == BreakerState.WORKING + assert breaker.tripped_until == datetime(2021, 2, 25, 0, 0, 2, tzinfo=UTC) + + # push into failed state + datetime_mock.utcnow.side_effect = [ + datetime(2021, 2, 25, 0, 0, 3, tzinfo=UTC), + datetime(2021, 2, 25, 0, 0, 4, tzinfo=UTC), + datetime(2021, 2, 25, 0, 0, 5, tzinfo=UTC), + ] + + self.client.get_something.reset_mock() + self.client.get_something.side_effect = [ValueError] + + with pytest.raises(ValueError): + make_call(self.server_span, self.context, "e") + + self.client.get_something.assert_called_once_with("e") + + breaker = self.breaker_box.breaker_box["get_something"] + assert breaker.name == "test_breaker.get_something" + assert breaker.results_bucket == deque([True, False, False, False]) + assert breaker.state == BreakerState.TRIPPED + assert breaker.tripped_until == datetime(2021, 2, 25, 0, 0, 58, 200000, tzinfo=UTC) + + # call while in failed state + datetime_mock.utcnow.side_effect = [ + datetime(2021, 2, 25, 0, 0, 6, tzinfo=UTC), + datetime(2021, 2, 25, 0, 0, 7, tzinfo=UTC), + ] + + self.client.get_something.reset_mock() + self.client.get_something.return_value = None + + with pytest.raises(BreakerTrippedError): + make_call(self.server_span, self.context, "f") + + 
self.client.get_something.assert_not_called() + + breaker = self.breaker_box.breaker_box["get_something"] + assert breaker.name == "test_breaker.get_something" + assert breaker.results_bucket == deque([True, False, False, False]) + assert breaker.state == BreakerState.TRIPPED + assert breaker.tripped_until == datetime(2021, 2, 25, 0, 0, 58, 200000, tzinfo=UTC) + + # call (and fail) while in testing state + datetime_mock.utcnow.side_effect = [ + datetime(2021, 2, 25, 0, 1, 0, tzinfo=UTC), + datetime(2021, 2, 25, 0, 1, 1, tzinfo=UTC), + datetime(2021, 2, 25, 0, 1, 2, tzinfo=UTC), + datetime(2021, 2, 25, 0, 1, 3, tzinfo=UTC), + ] + + self.client.get_something.reset_mock() + self.client.get_something.side_effect = [ValueError] + + assert breaker.state == BreakerState.TESTING + + with pytest.raises(ValueError): + make_call(self.server_span, self.context, "g") + + self.client.get_something.assert_called_once_with("g") + + assert breaker.results_bucket == deque([False, False, False, False]) + assert breaker.state == BreakerState.TRIPPED + assert breaker.tripped_until == datetime(2021, 2, 25, 0, 1, 56, 200000, tzinfo=UTC) + + # call (and succeed) while in testing state + datetime_mock.utcnow.side_effect = [ + datetime(2021, 2, 25, 0, 2, 0, tzinfo=UTC), + datetime(2021, 2, 25, 0, 2, 1, tzinfo=UTC), + ] + + self.client.get_something.reset_mock() + self.client.get_something.side_effect = [None] + + assert breaker.state == BreakerState.TESTING + + make_call(self.server_span, self.context, "h") + + self.client.get_something.assert_called_once_with("h") + + assert breaker.results_bucket == deque([]) + assert breaker.state == BreakerState.WORKING + assert breaker.tripped_until == datetime(2021, 2, 25, 0, 2, 1, tzinfo=UTC) diff --git a/tests/unit/lib/circuit_breaker_test.py b/tests/unit/lib/circuit_breaker_test.py index e3a13d099..7b2453240 100644 --- a/tests/unit/lib/circuit_breaker_test.py +++ b/tests/unit/lib/circuit_breaker_test.py @@ -3,14 +3,18 @@ import pytest -from 
graphql_api.lib.circuit_breaker.breaker import Breaker -from graphql_api.lib.circuit_breaker.breaker import BreakerState +from baseplate.lib.circuit_breaker.breaker import Breaker +from baseplate.lib.circuit_breaker.breaker import BreakerState @pytest.fixture def breaker(): return Breaker( - name="test", samples=4, trip_failure_ratio=0.5, trip_for=timedelta(seconds=60), fuzz_ratio=0.1 + name="test", + samples=4, + trip_failure_ratio=0.5, + trip_for=timedelta(seconds=60), + fuzz_ratio=0.1, )