Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Fix data type annotation in hook definition #289

Merged
merged 1 commit into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions contract-tests/hook.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ldclient.hook import Hook, EvaluationSeriesContext
from ldclient.evaluation import EvaluationDetail

from typing import Any, Optional
from typing import Optional
import requests


Expand All @@ -12,13 +12,13 @@ def __init__(self, name: str, callback: str, data: dict, errors: dict):
self.__data = data
self.__errors = errors

def before_evaluation(self, series_context: EvaluationSeriesContext, data: dict) -> Any:
def before_evaluation(self, series_context: EvaluationSeriesContext, data: dict) -> dict:
return self.__post("beforeEvaluation", series_context, data, None)

def after_evaluation(self, series_context: EvaluationSeriesContext, data: Any, detail: EvaluationDetail) -> Any:
def after_evaluation(self, series_context: EvaluationSeriesContext, data: dict, detail: EvaluationDetail) -> dict:
return self.__post("afterEvaluation", series_context, data, detail)

def __post(self, stage: str, series_context: EvaluationSeriesContext, data: Any, detail: Optional[EvaluationDetail]) -> Any:
def __post(self, stage: str, series_context: EvaluationSeriesContext, data: dict, detail: Optional[EvaluationDetail]) -> dict:
if stage in self.__errors:
raise Exception(self.__errors[stage])

Expand All @@ -42,4 +42,4 @@ def __post(self, stage: str, series_context: EvaluationSeriesContext, data: Any,

requests.post(self.__callback, json=payload)

return {**(data or {}), **self.__data.get(stage, {})}
return {**data, **self.__data.get(stage, {})}
12 changes: 4 additions & 8 deletions ldclient/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,6 @@
from ldclient.migrations import Stage, OpTracker
from ldclient.impl.flag_tracker import FlagTrackerImpl

from threading import Lock




class _FeatureStoreClientWrapper(FeatureStore):
"""Provides additional behavior that the client requires before or after feature store operations.
Expand Down Expand Up @@ -639,24 +635,24 @@ def __evaluate_with_hooks(self, key: str, context: Context, default_value: Any,

return evaluation_result

def __execute_before_evaluation(self, hooks: List[Hook], series_context: EvaluationSeriesContext) -> List[Any]:
def __execute_before_evaluation(self, hooks: List[Hook], series_context: EvaluationSeriesContext) -> List[dict]:
return [
self.__try_execute_stage("beforeEvaluation", hook.metadata.name, lambda: hook.before_evaluation(series_context, {}))
for hook in hooks
]

def __execute_after_evaluation(self, hooks: List[Hook], series_context: EvaluationSeriesContext, hook_data: List[Any], evaluation_detail: EvaluationDetail) -> List[Any]:
def __execute_after_evaluation(self, hooks: List[Hook], series_context: EvaluationSeriesContext, hook_data: List[dict], evaluation_detail: EvaluationDetail) -> List[dict]:
return [
self.__try_execute_stage("afterEvaluation", hook.metadata.name, lambda: hook.after_evaluation(series_context, data, evaluation_detail))
for (hook, data) in reversed(list(zip(hooks, hook_data)))
]

def __try_execute_stage(self, method: str, hook_name: str, block: Callable[[], Any]) -> Any:
def __try_execute_stage(self, method: str, hook_name: str, block: Callable[[], dict]) -> dict:
try:
return block()
except BaseException as e:
log.error(f"An error occurred in {method} of the hook {hook_name}: #{e}")
return None
return {}

@property
def big_segment_store_status_provider(self) -> BigSegmentStoreStatusProvider:
Expand Down
4 changes: 2 additions & 2 deletions ldclient/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def metadata(self) -> Metadata:
return Metadata(name='UNDEFINED')

@abstractmethod
def before_evaluation(self, series_context: EvaluationSeriesContext, data: Any) -> Any:
def before_evaluation(self, series_context: EvaluationSeriesContext, data: dict) -> dict:
"""
The before method is called during the execution of a variation method
before the flag value has been determined. The method is executed
Expand All @@ -63,7 +63,7 @@ def before_evaluation(self, series_context: EvaluationSeriesContext, data: Any)
return data

@abstractmethod
def after_evaluation(self, series_context: EvaluationSeriesContext, data: Any, detail: EvaluationDetail) -> dict:
def after_evaluation(self, series_context: EvaluationSeriesContext, data: dict, detail: EvaluationDetail) -> dict:
"""
The after method is called during the execution of the variation method
after the flag value has been determined. The method is executed
Expand Down
15 changes: 7 additions & 8 deletions ldclient/testing/test_ldclient_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
from ldclient import LDClient, Config, Context
from ldclient.hook import Hook, Metadata, EvaluationSeriesContext
from ldclient.migrations import Stage

from ldclient.integrations.test_data import TestData

from typing import Callable, Any
from typing import Callable


def record(label, log):
Expand All @@ -16,18 +15,18 @@ def inner(*args, **kwargs):


class MockHook(Hook):
def __init__(self, before_evaluation: Callable[[EvaluationSeriesContext, Any], dict], after_evaluation: Callable[[EvaluationSeriesContext, Any, EvaluationDetail], dict]):
def __init__(self, before_evaluation: Callable[[EvaluationSeriesContext, dict], dict], after_evaluation: Callable[[EvaluationSeriesContext, dict, EvaluationDetail], dict]):
self.__before_evaluation = before_evaluation
self.__after_evaluation = after_evaluation

@property
def metadata(self) -> Metadata:
return Metadata(name='test-hook')

def before_evaluation(self, series_context: EvaluationSeriesContext, data):
def before_evaluation(self, series_context: EvaluationSeriesContext, data: dict) -> dict:
return self.__before_evaluation(series_context, data)

def after_evaluation(self, series_context: EvaluationSeriesContext, data, detail: EvaluationDetail):
def after_evaluation(self, series_context: EvaluationSeriesContext, data: dict, detail: EvaluationDetail) -> dict:
return self.__after_evaluation(series_context, data, detail)


Expand Down Expand Up @@ -95,7 +94,7 @@ def test_passing_data_from_before_to_after():
assert calls[0] == "from before"


def test_exception_in_before_passes_none():
def test_exception_in_before_passes_empty_dict():
def raise_exception(series_context, data):
raise Exception("error")

Expand All @@ -107,7 +106,7 @@ def raise_exception(series_context, data):
client.variation('flag-key', user, False)

assert len(calls) == 1
assert calls[0] is None
assert calls[0] == {}


def test_exceptions_do_not_affect_data_passing_order():
Expand All @@ -127,7 +126,7 @@ def raise_exception(series_context, data):
# NOTE: These are reversed since the push happens in the after_evaluation
# (when hooks are reversed)
assert calls[0] == "third hook"
assert calls[1] is None
assert calls[1] == {}
assert calls[2] == "first hook"


Expand Down
Loading