diff --git a/tests/test_translator.py b/tests/test_translator.py index 43c07b1..b1cd01b 100644 --- a/tests/test_translator.py +++ b/tests/test_translator.py @@ -1,10 +1,28 @@ import logging -from trame_server.core import State +from trame_server.core import State, Translator logger = logging.getLogger(__name__) +def test_translator(): + a_translator = Translator() + a_translator.add_translation("foo", "a_foo") + + assert a_translator.translate_key("foo") == "a_foo" + assert a_translator.translate_key("bar") == "bar" + assert a_translator.reverse_translate_key("a_foo") == "foo" + assert a_translator.reverse_translate_key("bar") == "bar" + + b_translator = Translator() + b_translator.set_prefix("b_") + + assert b_translator.translate_key("foo") == "b_foo" + assert b_translator.translate_key("bar") == "b_bar" + assert b_translator.reverse_translate_key("b_foo") == "foo" + assert b_translator.reverse_translate_key("b_bar") == "bar" + + def test_translation(): root_state = State() a_state = State(internal=root_state) @@ -131,3 +149,50 @@ def test_prefix_and_translation(): common_shared_value=789, ) assert expected_state == root_state.to_dict() + + +def test_change_callback(): + # Ensure change callbacks are passed translated kwargs when using translations + test_passed = False + + a_state = State() + a_state.translator.add_translation("foo", "a_foo") + + a_state.bar = 1 + + def on_a_foo_change(*args, **kwargs): + nonlocal test_passed + assert "foo" in kwargs + assert "a_foo" not in kwargs + assert kwargs["foo"] == 123 + test_passed = "foo" in kwargs and "a_foo" not in kwargs + + a_state.change("foo")(on_a_foo_change) + a_state.ready() + a_state.foo = 123 + a_state.flush() + + assert test_passed + + # Ensure change callbacks are passed translated kwargs when using prefix + test_passed = False + + b_state = State() + b_state.translator.set_prefix("b_") + + def on_b_foo_change(*args, **kwargs): + nonlocal test_passed + assert "foo" in kwargs + assert "b_foo" not in kwargs + assert kwargs["foo"] == 456 + test_passed = "foo" in kwargs and "b_foo" not in kwargs + + b_state.change("foo")(on_b_foo_change) + + b_state.ready() + + b_state.foo = 456 + + b_state.flush() + + assert test_passed diff --git a/trame_server/state.py b/trame_server/state.py index 22ce1c3..1dd7bf9 100644 --- a/trame_server/state.py +++ b/trame_server/state.py @@ -238,12 +238,17 @@ def flush(self): # Execute state listeners self._state_listeners.add_all(_keys) + + pushed_state = self.translator.reverse_translate_dict( + self._pushed_state + ) + for callback in self._state_listeners: if self._hot_reload: if not inspect.iscoroutinefunction(callback): callback = reload(callback) - coroutine = callback(**self._pushed_state) + coroutine = callback(**pushed_state) if inspect.isawaitable(coroutine): asynchronous.create_task(coroutine) diff --git a/trame_server/utils/namespace.py b/trame_server/utils/namespace.py index cd0c89a..ed94f14 100644 --- a/trame_server/utils/namespace.py +++ b/trame_server/utils/namespace.py @@ -62,12 +62,14 @@ def __init__(self, prefix=None): logger.info("Translator(prefix=%s)", prefix) self._prefix = prefix self._transl = {} + self._reverse_transl = {} def set_prefix(self, prefix): self._prefix = prefix def add_translation(self, key, translated_key): self._transl[key] = translated_key + self._reverse_transl[translated_key] = key def translate_key(self, key): # Reserved keys @@ -82,12 +84,31 @@ def translate_key(self, key): return key + def reverse_translate_key(self, translated_key): + # Reserved keys + if is_name_reserved(translated_key): + return translated_key + + if translated_key in self._reverse_transl: + return self._reverse_transl[translated_key] + + if self._prefix: + return translated_key.removeprefix(self._prefix) + + return translated_key + def translate_list(self, key_list): return [self.translate_key(v) for v in key_list] def translate_dict(self, key_dict): return {self.translate_key(k): v for k, v in key_dict.items()} + def reverse_translate_list(self, key_list): + return [self.reverse_translate_key(v) for v in key_list] + + def reverse_translate_dict(self, key_dict): + return {self.reverse_translate_key(k): v for k, v in key_dict.items()} + def translate_js_expression(self, state, expression): tokens = [] for token in split_when(expression, js_tokenizer):