Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
m-wojnar committed Jul 17, 2023
1 parent bfce5e9 commit 99c3f87
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 35 deletions.
2 changes: 1 addition & 1 deletion reinforced_lib/exts/base_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def observation_space(self) -> gym.spaces.Space:
def get_agent_params(
self,
agent_type: type = None,
agent_parameter_space: gym.spaces.dict = None,
agent_parameter_space: gym.spaces.Dict = None,
user_parameters: dict[str, any] = None
) -> dict[str, any]:
"""
Expand Down
14 changes: 8 additions & 6 deletions reinforced_lib/rlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def observation_space(self) -> gym.spaces.Space:
if not self._agent:
raise NoAgentError()
else:
return gym.spaces.dict({
return gym.spaces.Dict({
'update_observation_space': self._agent.update_observation_space,
'sample_observation_space': self._agent.sample_observation_space
})
Expand Down Expand Up @@ -522,15 +522,17 @@ def load(

rlib = RLib(
save_directory=experiment_state["save_directory"],
auto_checkpoint=experiment_state["auto_checkpoint"]
auto_checkpoint=experiment_state["auto_checkpoint"],
no_ext_mode=experiment_state["ext_type"] is None
)

rlib._agent_containers = []

if ext_params:
rlib.set_ext(experiment_state["ext_type"], ext_params)
else:
rlib.set_ext(experiment_state["ext_type"], experiment_state["ext_params"])
if experiment_state["ext_type"] is not None:
if ext_params:
rlib.set_ext(experiment_state["ext_type"], ext_params)
else:
rlib.set_ext(experiment_state["ext_type"], experiment_state["ext_params"])

if agent_params:
rlib.set_agent(experiment_state["agent_type"], agent_params)
Expand Down
Empty file removed test/agents/wifi/__init__.py
Empty file.
27 changes: 11 additions & 16 deletions test/test_rlib.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,28 @@
import reinforced_lib as rfl
from reinforced_lib.agents.wifi import ParticleFilter
from reinforced_lib.exts.wifi import IEEE_802_11_ax_RA
from reinforced_lib.agents.mab import EGreedy


if __name__ == '__main__':
rl = rfl.RLib(
agent_type=ParticleFilter,
ext_type=IEEE_802_11_ax_RA
agent_type=EGreedy,
agent_params={'n_arms': 4, 'e': 0.1},
no_ext_mode=True
)

print(rl.observation_space)

observations = {
'time': 0.0,
'n_successful': 0,
'n_failed': 0,
'power': 16.0,
'cw': 15
'action': 3,
'reward': 1.0
}

action = rl.sample(**observations)
action = rl.sample(update_observations=observations)
print(action)

observations = {
'time': 0.0,
'n_successful': 10,
'n_failed': 0,
'power': 16.0,
'cw': 15
'action': 3,
'reward': 1.0
}

action = rl.sample(**observations)
action = rl.sample(update_observations=observations)
print(action)
21 changes: 9 additions & 12 deletions test/test_rlib_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import jax.numpy as jnp
import reinforced_lib as rfl

from reinforced_lib.agents.mab import ThompsonSampling
from reinforced_lib.exts.wifi import IEEE_802_11_ax_RA
from reinforced_lib.agents.mab import EGreedy
from reinforced_lib.rlib import RLib
from reinforced_lib.logs import *

Expand All @@ -21,9 +20,9 @@ class TestRLibSerialization(unittest.TestCase):

def run_experiment(self, reload: bool, new_decay: float = None) -> list[int]:
rl = rfl.RLib(
agent_type=ThompsonSampling,
agent_params={"decay": 0.0},
ext_type=IEEE_802_11_ax_RA,
agent_type=EGreedy,
agent_params={'n_arms': len(self.arms_probs), 'e': 0.1},
no_ext_mode=True,
logger_types=CsvLogger,
logger_sources=['n_failed', 'n_successful', ('action', SourceType.METRIC)],
logger_params={'csv_path': f'output_reload={reload}_new-decay={new_decay}.csv'}
Expand All @@ -36,20 +35,18 @@ def run_experiment(self, reload: bool, new_decay: float = None) -> list[int]:
for t in self.time:
r = int(jax.random.uniform(self.key) < self.arms_probs[a])
observations = {
'time': t,
'mcs': a,
'n_successful': r,
'n_failed': 1 - r,
'action': a,
'reward': r
}

a = rl.sample(**observations)
actions.append(int(a))
a = rl.sample(update_observations=observations)
actions.append(a)

if t > self.t_change and not reloaded:
save_path = rl.save()

if new_decay:
rl = RLib.load(save_path, agent_params={"decay": new_decay}, restore_loggers=False)
rl = RLib.load(save_path, agent_params={'n_arms': len(self.arms_probs), 'e': 0.5}, restore_loggers=False)
else:
rl = RLib.load(save_path, restore_loggers=False)
reloaded = True
Expand Down

0 comments on commit 99c3f87

Please sign in to comment.