Skip to content

Commit

Permalink
Add set_parameter method to Updatable
Browse files Browse the repository at this point in the history
  • Loading branch information
tilmantroester committed Oct 20, 2023
1 parent 08d8ffd commit 387f10c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
13 changes: 11 additions & 2 deletions firecrown/updatable.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,21 @@ def __setattr__(self, key: str, value: Any) -> None:
"""
if isinstance(value, (Updatable, UpdatableCollection)):
self._updatables.append(value)
if isinstance(value, (InternalParameter, SamplerParameter)):
self.set_parameter(key, value)
else:
super().__setattr__(key, value)

def set_parameter(
self, key: str, value: Union[InternalParameter, SamplerParameter]
) -> None:
"""Assure this InternalParameter or SamplerParameter has not already
been set, and then set it."""

if isinstance(value, SamplerParameter):
self.set_sampler_parameter(key, value)
elif isinstance(value, InternalParameter):
self.set_internal_parameter(key, value)
else:
super().__setattr__(key, value)

def set_internal_parameter(self, key: str, value: InternalParameter) -> None:
"""Assure this InternalParameter has not already been set, and then set it."""
Expand Down
12 changes: 12 additions & 0 deletions tests/test_updatable.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,18 @@ def test_set_internal_parameter_rejects_duplicates():
)


def test_set_parameter():
my_updatable = MinimalUpdatable()
my_updatable.set_parameter("the_meaning_of_life", parameters.create(42.0))
my_updatable.set_parameter("no_meaning_of_life", parameters.create())

assert hasattr(my_updatable, "the_meaning_of_life")
assert my_updatable.the_meaning_of_life == 42.0

assert hasattr(my_updatable, "no_meaning_of_life")
assert my_updatable.no_meaning_of_life is None


def test_update_rejects_internal_parameters():
my_updatable = MinimalUpdatable()
my_updatable.set_internal_parameter("the_meaning_of_life", parameters.create(42.0))
Expand Down

0 comments on commit 387f10c

Please sign in to comment.