From 387f10c908207742598095913d07e2b68e62d473 Mon Sep 17 00:00:00 2001 From: Tilman Troester Date: Fri, 20 Oct 2023 16:17:56 +0200 Subject: [PATCH] Add set_parameter method to Updatable --- firecrown/updatable.py | 13 +++++++++++-- tests/test_updatable.py | 12 ++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/firecrown/updatable.py b/firecrown/updatable.py index 42e46a138..b37667e77 100644 --- a/firecrown/updatable.py +++ b/firecrown/updatable.py @@ -86,12 +86,21 @@ def __setattr__(self, key: str, value: Any) -> None: """ if isinstance(value, (Updatable, UpdatableCollection)): self._updatables.append(value) + if isinstance(value, (InternalParameter, SamplerParameter)): + self.set_parameter(key, value) + else: + super().__setattr__(key, value) + + def set_parameter( + self, key: str, value: Union[InternalParameter, SamplerParameter] + ) -> None: + """Assure this InternalParameter or SamplerParameter has not already + been set, and then set it.""" + if isinstance(value, SamplerParameter): self.set_sampler_parameter(key, value) elif isinstance(value, InternalParameter): self.set_internal_parameter(key, value) - else: - super().__setattr__(key, value) def set_internal_parameter(self, key: str, value: InternalParameter) -> None: """Assure this InternalParameter has not already been set, and then set it.""" diff --git a/tests/test_updatable.py b/tests/test_updatable.py index 3e93e2dc6..3d8ba02fc 100644 --- a/tests/test_updatable.py +++ b/tests/test_updatable.py @@ -184,6 +184,18 @@ def test_set_internal_parameter_rejects_duplicates(): ) +def test_set_parameter(): + my_updatable = MinimalUpdatable() + my_updatable.set_parameter("the_meaning_of_life", parameters.create(42.0)) + my_updatable.set_parameter("no_meaning_of_life", parameters.create()) + + assert hasattr(my_updatable, "the_meaning_of_life") + assert my_updatable.the_meaning_of_life == 42.0 + + assert hasattr(my_updatable, "no_meaning_of_life") + assert my_updatable.no_meaning_of_life is None + + def test_update_rejects_internal_parameters(): my_updatable = MinimalUpdatable() my_updatable.set_internal_parameter("the_meaning_of_life", parameters.create(42.0))