From 6bf3c57f0098dacc5a7e5a6aecf395a3043aa213 Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Thu, 25 Apr 2024 13:44:05 -0400 Subject: [PATCH] chore: Fix data type annotation in hook definition --- contract-tests/hook.py | 10 +++++----- ldclient/client.py | 12 ++++-------- ldclient/hook.py | 4 ++-- ldclient/testing/test_ldclient_hooks.py | 15 +++++++-------- 4 files changed, 18 insertions(+), 23 deletions(-) diff --git a/contract-tests/hook.py b/contract-tests/hook.py index ec2708c..866ae41 100644 --- a/contract-tests/hook.py +++ b/contract-tests/hook.py @@ -1,7 +1,7 @@ from ldclient.hook import Hook, EvaluationSeriesContext from ldclient.evaluation import EvaluationDetail -from typing import Any, Optional +from typing import Optional import requests @@ -12,13 +12,13 @@ def __init__(self, name: str, callback: str, data: dict, errors: dict): self.__data = data self.__errors = errors - def before_evaluation(self, series_context: EvaluationSeriesContext, data: dict) -> Any: + def before_evaluation(self, series_context: EvaluationSeriesContext, data: dict) -> dict: return self.__post("beforeEvaluation", series_context, data, None) - def after_evaluation(self, series_context: EvaluationSeriesContext, data: Any, detail: EvaluationDetail) -> Any: + def after_evaluation(self, series_context: EvaluationSeriesContext, data: dict, detail: EvaluationDetail) -> dict: return self.__post("afterEvaluation", series_context, data, detail) - def __post(self, stage: str, series_context: EvaluationSeriesContext, data: Any, detail: Optional[EvaluationDetail]) -> Any: + def __post(self, stage: str, series_context: EvaluationSeriesContext, data: dict, detail: Optional[EvaluationDetail]) -> dict: if stage in self.__errors: raise Exception(self.__errors[stage]) @@ -42,4 +42,4 @@ def __post(self, stage: str, series_context: EvaluationSeriesContext, data: Any, requests.post(self.__callback, json=payload) - return {**(data or {}), **self.__data.get(stage, {})} + return {**data, **self.__data.get(stage, {})} diff --git a/ldclient/client.py b/ldclient/client.py index d7bf1e0..db7c966 100644 --- a/ldclient/client.py +++ b/ldclient/client.py @@ -37,10 +37,6 @@ from ldclient.migrations import Stage, OpTracker from ldclient.impl.flag_tracker import FlagTrackerImpl -from threading import Lock - - - class _FeatureStoreClientWrapper(FeatureStore): """Provides additional behavior that the client requires before or after feature store operations. @@ -639,24 +635,24 @@ def __evaluate_with_hooks(self, key: str, context: Context, default_value: Any, return evaluation_result - def __execute_before_evaluation(self, hooks: List[Hook], series_context: EvaluationSeriesContext) -> List[Any]: + def __execute_before_evaluation(self, hooks: List[Hook], series_context: EvaluationSeriesContext) -> List[dict]: return [ self.__try_execute_stage("beforeEvaluation", hook.metadata.name, lambda: hook.before_evaluation(series_context, {})) for hook in hooks ] - def __execute_after_evaluation(self, hooks: List[Hook], series_context: EvaluationSeriesContext, hook_data: List[Any], evaluation_detail: EvaluationDetail) -> List[Any]: + def __execute_after_evaluation(self, hooks: List[Hook], series_context: EvaluationSeriesContext, hook_data: List[dict], evaluation_detail: EvaluationDetail) -> List[dict]: return [ self.__try_execute_stage("afterEvaluation", hook.metadata.name, lambda: hook.after_evaluation(series_context, data, evaluation_detail)) for (hook, data) in reversed(list(zip(hooks, hook_data))) ] - def __try_execute_stage(self, method: str, hook_name: str, block: Callable[[], Any]) -> Any: + def __try_execute_stage(self, method: str, hook_name: str, block: Callable[[], dict]) -> dict: try: return block() except BaseException as e: log.error(f"An error occurred in {method} of the hook {hook_name}: #{e}") - return None + return {} @property def big_segment_store_status_provider(self) -> BigSegmentStoreStatusProvider: diff --git a/ldclient/hook.py b/ldclient/hook.py index 349d5d1..3f594fc 100644 --- a/ldclient/hook.py +++ b/ldclient/hook.py @@ -48,7 +48,7 @@ def metadata(self) -> Metadata: return Metadata(name='UNDEFINED') @abstractmethod - def before_evaluation(self, series_context: EvaluationSeriesContext, data: Any) -> Any: + def before_evaluation(self, series_context: EvaluationSeriesContext, data: dict) -> dict: """ The before method is called during the execution of a variation method before the flag value has been determined. The method is executed @@ -63,7 +63,7 @@ def before_evaluation(self, series_context: EvaluationSeriesContext, data: Any) return data @abstractmethod - def after_evaluation(self, series_context: EvaluationSeriesContext, data: Any, detail: EvaluationDetail) -> dict: + def after_evaluation(self, series_context: EvaluationSeriesContext, data: dict, detail: EvaluationDetail) -> dict: """ The after method is called during the execution of the variation method after the flag value has been determined. The method is executed diff --git a/ldclient/testing/test_ldclient_hooks.py b/ldclient/testing/test_ldclient_hooks.py index 41c095c..90dd2e6 100644 --- a/ldclient/testing/test_ldclient_hooks.py +++ b/ldclient/testing/test_ldclient_hooks.py @@ -2,10 +2,9 @@ from ldclient import LDClient, Config, Context from ldclient.hook import Hook, Metadata, EvaluationSeriesContext from ldclient.migrations import Stage - from ldclient.integrations.test_data import TestData -from typing import Callable, Any +from typing import Callable def record(label, log): @@ -16,7 +15,7 @@ def inner(*args, **kwargs): class MockHook(Hook): - def __init__(self, before_evaluation: Callable[[EvaluationSeriesContext, Any], dict], after_evaluation: Callable[[EvaluationSeriesContext, Any, EvaluationDetail], dict]): + def __init__(self, before_evaluation: Callable[[EvaluationSeriesContext, dict], dict], after_evaluation: Callable[[EvaluationSeriesContext, dict, EvaluationDetail], dict]): self.__before_evaluation = before_evaluation self.__after_evaluation = after_evaluation @@ -24,10 +23,10 @@ def __init__(self, before_evaluation: Callable[[EvaluationSeriesContext, Any], d def metadata(self) -> Metadata: return Metadata(name='test-hook') - def before_evaluation(self, series_context: EvaluationSeriesContext, data): + def before_evaluation(self, series_context: EvaluationSeriesContext, data: dict) -> dict: return self.__before_evaluation(series_context, data) - def after_evaluation(self, series_context: EvaluationSeriesContext, data, detail: EvaluationDetail): + def after_evaluation(self, series_context: EvaluationSeriesContext, data: dict, detail: EvaluationDetail) -> dict: return self.__after_evaluation(series_context, data, detail) @@ -95,7 +94,7 @@ def test_passing_data_from_before_to_after(): assert calls[0] == "from before" -def test_exception_in_before_passes_none(): +def test_exception_in_before_passes_empty_dict(): def raise_exception(series_context, data): raise Exception("error") @@ -107,7 +106,7 @@ def raise_exception(series_context, data): client.variation('flag-key', user, False) assert len(calls) == 1 - assert calls[0] is None + assert calls[0] == {} def test_exceptions_do_not_affect_data_passing_order(): @@ -127,7 +126,7 @@ def raise_exception(series_context, data): # NOTE: These are reversed since the push happens in the after_evaluation # (when hooks are reversed) assert calls[0] == "third hook" - assert calls[1] is None + assert calls[1] == {} assert calls[2] == "first hook"