Skip to content

Commit

Permalink
Merge pull request #88 from AutoResearch/add-alias-for-model-in-stand…
Browse files Browse the repository at this point in the history
…ardstate

Add alias for model in standardstate
  • Loading branch information
younesStrittmatter authored Aug 10, 2024
2 parents f565ae3 + 7555362 commit 8572b4a
Showing 1 changed file with 156 additions and 3 deletions.
159 changes: 156 additions & 3 deletions src/autora/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,12 @@ def _get_value(f, other: Union[Delta, Mapping]):
>>> from dataclasses import field, dataclass, fields
>>> @dataclass
... class Example:
... a: int = field()
... a: int = field() # base case
... b: List[int] = field(metadata={"aliases": {"ba": lambda b: [b]}}) # Single alias
... c: List[int] = field(metadata={"aliases": {
... "ca": lambda x: x, # pass the value unchanged
... "cb": lambda x: [x] # wrap the value in a list
... }}) # Multiple alias
For a field with no aliases, we retrieve values with the base name:
>>> f_a = fields(Example)[0]
Expand All @@ -342,15 +347,104 @@ def _get_value(f, other: Union[Delta, Mapping]):
>>> _get_value(f_a, Delta(b=2, a=1))
(1, 'a')
For fields with an alias, we retrieve values with the base name:
>>> f_b = fields(Example)[1]
>>> _get_value(f_b, Delta(b=[2]))
([2], 'b')
... or for the alias name, transformed by the alias lambda function:
>>> _get_value(f_b, Delta(ba=21))
([21], 'ba')
We preferentially get the base name, and then any aliases:
>>> _get_value(f_b, Delta(b=2, ba=21))
(2, 'b')
... , regardless of their order in the `Delta` object:
>>> _get_value(f_b, Delta(ba=21, b=2))
(2, 'b')
Other names are ignored:
>>> _get_value(f_b, Delta(a=1))
(None, None)
and the order of other names is unimportant:
>>> _get_value(f_b, Delta(a=1, b=2))
(2, 'b')
For fields with multiple aliases, we retrieve values with the base name:
>>> f_c = fields(Example)[2]
>>> _get_value(f_c, Delta(c=[3]))
([3], 'c')
... for any alias:
>>> _get_value(f_c, Delta(ca=31))
(31, 'ca')
... transformed by the alias lambda function :
>>> _get_value(f_c, Delta(cb=32))
([32], 'cb')
... and ignoring any other names:
>>> print(_get_value(f_c, Delta(a=1)))
(None, None)
... preferentially in the order base name, 1st alias, 2nd alias, ... nth alias:
>>> _get_value(f_c, Delta(c=3, ca=31, cb=32))
(3, 'c')
>>> _get_value(f_c, Delta(ca=31, cb=32))
(31, 'ca')
>>> _get_value(f_c, Delta(cb=32))
([32], 'cb')
>>> print(_get_value(f_c, Delta()))
(None, None)
This works with dict objects:
>>> _get_value(f_a, dict(a=13))
(13, 'a')
... with multiple keys:
>>> _get_value(f_b, dict(a=13, b=24, c=35))
(24, 'b')
... and with aliases:
>>> _get_value(f_b, dict(ba=222))
([222], 'ba')
This works with UserDicts:
>>> class MyDelta(UserDict):
... pass
>>> _get_value(f_a, MyDelta(a=14))
(14, 'a')
... with multiple keys:
>>> _get_value(f_b, MyDelta(a=1, b=4, c=9))
(4, 'b')
... and with aliases:
>>> _get_value(f_b, MyDelta(ba=234))
([234], 'ba')
"""

key = f.name
aliases = f.metadata.get("aliases", {})

value, used_key = None, None

if key in other.keys():
value = other[key]
used_key = key
elif aliases: # ... is not an empty dict
for alias_key, wrapping_function in aliases.items():
if alias_key in other:
value = wrapping_function(other[alias_key])
used_key = alias_key
break # we only evaluate the first match

return value, used_key

Expand Down Expand Up @@ -405,8 +499,23 @@ def _get_field_names_and_aliases(s: State):
>>> _get_field_names_and_aliases(SomeState())
['l', 'm']
>>> @dataclass(frozen=True)
... class SomeStateWithAliases(State):
... l: List = field(default_factory=list, metadata={"aliases": {"l1": None, "l2": None}})
... m: List = field(default_factory=list, metadata={"aliases": {"m1": None}})
>>> _get_field_names_and_aliases(SomeStateWithAliases())
['l', 'l1', 'l2', 'm', 'm1']
"""
result = [f.name for f in fields(s)]
result = []

for f in fields(s):
name = f.name
result.append(name)

aliases = f.metadata.get("aliases", {})
result.extend(aliases)

return result


Expand Down Expand Up @@ -1252,6 +1361,20 @@ class StandardState(State):
>>> (s + dm1 + dm2).models
[DummyClassifier(constant=1), DummyClassifier(constant=2), DummyClassifier(constant=3)]
The last model is available under the `model` property:
>>> (s + dm1 + dm2).model
DummyClassifier(constant=3)
If there is no model, `None` is returned:
>>> print(s.model)
None
`models` can also be updated using a Delta with a single `model`:
>>> dm3 = Delta(model=DummyClassifier(constant=4))
>>> (s + dm1 + dm3).model
DummyClassifier(constant=4)
We can use properties X, y, iv_names and dv_names as 'getters' ...
>>> x_v = Variable('x')
>>> y_v = Variable('y')
Expand Down Expand Up @@ -1280,6 +1403,24 @@ class StandardState(State):
1 2
2 3
However, if the property has a deticated setter, we can still use them as getter:
>>> s.model is None
True
>>> from sklearn.linear_model import LinearRegression
>>> @on_state
... def add_model(_model):
... return Delta(model=_model)
>>> s = add_model(s, _model=LinearRegression())
>>> s.models
[LinearRegression()]
>>> s.model
LinearRegression()
"""
Expand All @@ -1295,7 +1436,7 @@ class StandardState(State):
)
models: List[BaseEstimator] = field(
default_factory=list,
metadata={"delta": "extend"},
metadata={"delta": "extend", "aliases": {"model": lambda model: [model]}},
)

@property
Expand Down Expand Up @@ -1392,6 +1533,18 @@ def y(self) -> pd.DataFrame:
return pd.DataFrame()
return self.experiment_data[self.dv_names]

@property
def model(self):
if len(self.models) == 0:
return None
# The property to access the backing field
return self.models[-1]

@model.setter
def model(self, value):
# Control the setting behavior
self.models.append(value)


X = TypeVar("X")
Y = TypeVar("Y")
Expand Down

0 comments on commit 8572b4a

Please sign in to comment.