diff --git a/README.MD b/README.MD index 214dfd5..4f7454d 100644 --- a/README.MD +++ b/README.MD @@ -96,7 +96,7 @@ python3 -m social-distancing-sim.scripts.run_single_population ````python import social_distancing_sim.environment as env -# The graph is the "true" population model, containing all the nodes and their data +# The graph is the "true" environment model, containing all the nodes and their data graph = env.Graph(community_n=50, community_size_mean=15, community_p_in=0.06, # The likelihood of intra-community connections @@ -159,7 +159,7 @@ def run_and_replay(pop, *args, **kwargs): save = True -# Create a population with high inter and intra connectivity +# Create an environment with high inter and intra connectivity pop = env.Environment(name='A herd of cats', disease=env.Disease(name='COVID-19'), observation_space=env.ObservationSpace(graph=env.Graph(community_n=40, @@ -168,7 +168,7 @@ pop = env.Environment(name='A herd of cats', test_rate=1), environment_plotting=env.EnvironmentPlotting(ts_fields_g2=["Turn score"])) -# Create a population with reduced inter and intra connectivity +# Create an environment with reduced inter and intra connectivity pop_distanced = env.Environment(name='A socially responsible environment', disease=env.Disease(name='COVID-19'), observation_space=env.ObservationSpace(graph=env.Graph(community_n=40, @@ -312,7 +312,6 @@ Parallel(n_jobs=2, # Basic agents and strategy comparison -![Test basic agents](https://github.com/garethjns/social-distancing-sim/blob/master/images/basic_agents_example.gif) ````bash python3 -m social-distancing-sim.scripts.visual_compare_basic_agents @@ -379,6 +378,7 @@ Parallel(n_jobs=4, ``` # MultiSims: Statistical comparisons - basic agents and strategy comparison +![Test basic agents](https://github.com/garethjns/social-distancing-sim/blob/master/images/agent_comparison_score_example.png) ````bash python3 -m social-distancing-sim.scripts.stats_compare_basic_agents diff --git a/images/multiagents.gif 
b/images/multiagents.gif new file mode 100644 index 0000000..afa3bbf Binary files /dev/null and b/images/multiagents.gif differ diff --git a/scripts/stats_compare_basic_agents.py b/scripts/stats_compare_basic_agents.py index 2733ed6..387b93e 100644 --- a/scripts/stats_compare_basic_agents.py +++ b/scripts/stats_compare_basic_agents.py @@ -2,8 +2,9 @@ Run all the basic agents with a number of actions per turn using a MultiSim (n reps = 100). Doesn't save .gifs of each rep, rather plots distributions of final scores. -Parameters here match the visual version run in scripts/stats_compare_basic_agents.py. +Parameters here are similar to the visual version run in scripts/visual_compare_basic_agents.py. """ +from typing import List import matplotlib.pyplot as plt import numpy as np @@ -15,7 +16,8 @@ import social_distancing_sim.sim as sim -def plot_dists(result: str = "Overall score") -> plt.Figure: +def plot_dists(multi_sims: List[sim.MultiSim], + result: str = "Overall score") -> plt.Figure: """Plot final score distributions across repetitions, for all agents.""" fig, axs = plt.subplots(nrows=4, ncols=1, @@ -84,17 +86,17 @@ def plot_dists(result: str = "Overall score") -> plt.Figure: n_steps=125) multi_sims.append(sim.MultiSim(sim_, - name='basic agent comparison 2', + name='basic agent comparison', n_reps=100)) # Run all the sims. 
No need to parallelize here as it's done across n reps in MultiSim.run() for ms in tqdm(multi_sims): ms.run() - fig = plot_dists("Overall score") + fig = plot_dists(multi_sims, "Overall score") plt.show() - fig.savefig('agent_comparison_score.png') + fig.savefig('basic_agent_comparison_score.png') - fig = plot_dists("Total deaths") + fig = plot_dists(multi_sims, "Total deaths") plt.show() - fig.savefig('agent_comparison_deaths.png') + fig.savefig('basic_agent_comparison_deaths.png') diff --git a/scripts/stats_compare_multi_agents.py b/scripts/stats_compare_multi_agents.py new file mode 100644 index 0000000..ef8bbc1 --- /dev/null +++ b/scripts/stats_compare_multi_agents.py @@ -0,0 +1,131 @@ +""" +Run all the basic agents with a number of actions per turn using a MultiSim (n reps = 100). Doesn't save .gifs of +each rep, rather plots distributions of final scores. + +Parameters here are similar to the visual version run in scripts/visual_compare_multi_agents.py. +""" +from typing import List + +import matplotlib.pyplot as plt +import seaborn as sns +from tqdm import tqdm + +import social_distancing_sim.agent as agent +import social_distancing_sim.environment as env +import social_distancing_sim.sim as sim + + +def plot_dists(multi_sims: List[sim.MultiSim], + result: str = "Overall score") -> plt.Figure: + """Plot final score distributions across repetitions, for all agents.""" + fig, ax = plt.subplots(nrows=1, + ncols=1, + figsize=(8, 8)) + + min_score = 0 + max_score = 0 + for run in multi_sims: + min_score = min(min_score, run.results[result].min()) + max_score = max(max_score, run.results[result].max()) + + sns.distplot(run.results[result], + hist=False, + label=run.sim.agent.name) + + ax.set_title("Policy comparison", + fontweight='bold') + ax.set_xlim([min_score - abs(min_score * 0.2), max_score + abs(max_score * 0.2)]) + ax.set_xlabel(ax.get_xlabel(), + fontweight='bold') + ax.set_ylabel('Prop.', + fontweight='bold') + ax.legend(title='Agent') + + return fig + + 
+if __name__ == "__main__": + seed = 123 + steps = 250 + distancing_params = {"actions_per_turn": 15, + "start_step": {'isolate': 15, 'reconnect': 60}, + "end_step": {'isolate': 55, 'reconnect': steps}} + vaccination_params = {"actions_per_turn": 5, + "start_step": {'vaccinate': 60}, + "end_step": {'vaccinate': steps}} + treatment_params = {"actions_per_turn": 5, + "start_step": {'treat': 50}, + "end_step": {'treat': steps}} + + # Create a parameter set containing all combinations of the 3 policy agents, and a small set of n_actions + agents = [agent.MultiAgent(name="Distancing", + agents=[agent.DistancingPolicyAgent(**distancing_params)]), + agent.MultiAgent(name="Vaccination", + agents=[agent.VaccinationPolicyAgent(**vaccination_params)]), + agent.MultiAgent(name="Treatment", + agents=[agent.TreatmentPolicyAgent(**treatment_params)]), + agent.MultiAgent(name="Distancing, vaccination", + agents=[agent.DistancingPolicyAgent(**distancing_params), + agent.VaccinationPolicyAgent(**vaccination_params)]), + agent.MultiAgent(name="Distancing, treatment", + agents=[agent.DistancingPolicyAgent(**distancing_params), + agent.TreatmentPolicyAgent(**treatment_params)]), + agent.MultiAgent(name="Vaccination, treatment", + agents=[agent.VaccinationPolicyAgent(**vaccination_params), + agent.TreatmentPolicyAgent(**treatment_params)]), + agent.MultiAgent(name="Distancing, vaccination, treatment", + agents=[agent.DistancingPolicyAgent(**distancing_params), + agent.VaccinationPolicyAgent(**vaccination_params), + agent.TreatmentPolicyAgent(**treatment_params)])] + + # Loop over the parameter set and create the Agents, Environments, and the Sim handler + multi_sims = [] + for agt in agents: + # Name the environment according to the agent used + env_ = env.Environment(name=f"{type(agt).__name__} - {agt.name}", + action_space=env.ActionSpace(vaccinate_cost=0, + treat_cost=0, + isolate_cost=0, + isolate_efficiency=0.70, + reconnect_efficiency=0.2, + treatment_conclusion_chance=0.5, + 
treatment_recovery_rate_modifier=1.8, + vaccinate_efficiency=1.25), + disease=env.Disease(name='COVID-19', + virulence=0.0055, + seed=None, + immunity_mean=0.7, + recovery_rate=0.9, + immunity_decay_mean=0.01), + healthcare=env.Healthcare(capacity=200), + observation_space=env.ObservationSpace( + graph=env.Graph(community_n=30, + community_size_mean=20, + community_p_out=0.08, + community_p_in=0.16, + seed=None), + test_rate=1, + seed=None), + initial_infections=5, + random_infection_chance=1, + seed=None) + + sim_ = sim.Sim(env=env_, + agent=agt, + n_steps=150) + + multi_sims.append(sim.MultiSim(sim_, + name='policy agent comparison', + n_reps=100)) + + # Run all the sims. No need to parallelize here as it's done across n reps in MultiSim.run() + for ms in tqdm(multi_sims): + ms.run() + + fig = plot_dists(multi_sims, "Overall score") + plt.show() + fig.savefig('multi_agent_comparison_score.png') + + fig = plot_dists(multi_sims, "Total deaths") + plt.show() + fig.savefig('multi_agent_comparison_deaths.png') diff --git a/scripts/visual_compare_basic_agents_small.py b/scripts/visual_compare_basic_agents_small.py index 5d3db78..1912f5b 100644 --- a/scripts/visual_compare_basic_agents_small.py +++ b/scripts/visual_compare_basic_agents_small.py @@ -1,9 +1,4 @@ - -""" -Run all the basic agents with a number of actions per turn (n reps = 1). Generate and save .gif. - -Parameters here match the stats version run in scripts/stats_compare_basic_agents.py. -""" +"""Run all the basic agents with a number of actions per turn (n reps = 1). 
Generate and save .gif.""" import numpy as np from joblib import Parallel, delayed @@ -62,4 +57,3 @@ def run_and_replay(sim): # Run all the prepared Sims Parallel(n_jobs=4, backend='loky')(delayed(run_and_replay)(sim) for sim in sims) - diff --git a/scripts/visual_compare_multi_agents.py b/scripts/visual_compare_multi_agents.py new file mode 100644 index 0000000..a6d8999 --- /dev/null +++ b/scripts/visual_compare_multi_agents.py @@ -0,0 +1,92 @@ +"""A number of different MultiAgent setups (n reps = 1). Generate and save .gif.""" + +from joblib import Parallel, delayed + +import social_distancing_sim.agent as agent +import social_distancing_sim.environment as env +import social_distancing_sim.sim as sim + + +def run_and_replay(sim): + sim.run() + if sim.save: + sim.env.replay() + + +if __name__ == "__main__": + seed = 123 + steps = 250 + distancing_params = {"actions_per_turn": 15, + "start_step": {'isolate': 15, 'reconnect': 60}, + "end_step": {'isolate': 55, 'reconnect': steps}} + vaccination_params = {"actions_per_turn": 5, + "start_step": {'vaccinate': 60}, + "end_step": {'vaccinate': steps}} + treatment_params = {"actions_per_turn": 5, + "start_step": {'treat': 50}, + "end_step": {'treat': steps}} + + # Create a parameter set containing all combinations of the 3 policy agents, and a small set of n_actions + agents = [agent.MultiAgent(name="Distancing", + agents=[agent.DistancingPolicyAgent(**distancing_params)]), + agent.MultiAgent(name="Vaccination", + agents=[agent.VaccinationPolicyAgent(**vaccination_params)]), + agent.MultiAgent(name="Treatment", + agents=[agent.TreatmentPolicyAgent(**treatment_params)]), + agent.MultiAgent(name="Distancing, vaccination", + agents=[agent.DistancingPolicyAgent(**distancing_params), + agent.VaccinationPolicyAgent(**vaccination_params)]), + agent.MultiAgent(name="Distancing, treatment", + agents=[agent.DistancingPolicyAgent(**distancing_params), + agent.TreatmentPolicyAgent(**treatment_params)]), + 
agent.MultiAgent(name="Vaccination, treatment", + agents=[agent.VaccinationPolicyAgent(**vaccination_params), + agent.TreatmentPolicyAgent(**treatment_params)]), + agent.MultiAgent(name="Distancing, vaccination, treatment", + agents=[agent.DistancingPolicyAgent(**distancing_params), + agent.VaccinationPolicyAgent(**vaccination_params), + agent.TreatmentPolicyAgent(**treatment_params)])] + + # Loop over the parameter set and create the Agents, Environments, and the Sim handler + sims = [] + for agt in agents: + # Name the environment according to the agent used + env_ = env.Environment(name=f"{type(agt).__name__} - {agt.name}", + action_space=env.ActionSpace(isolate_efficiency=0.75, + reconnect_efficiency=0.2, + treatment_conclusion_chance=0.2, + treatment_recovery_rate_modifier=1.8, + vaccinate_efficiency=0.95), + disease=env.Disease(name='COVID-19', + virulence=0.0055, + seed=seed, + immunity_mean=0.7, + recovery_rate=0.95, + immunity_decay_mean=0.012), + healthcare=env.Healthcare(capacity=200), + environment_plotting=env.EnvironmentPlotting( + auto_lim_x=False, + ts_fields_g2=["Actions taken", "Vaccinate actions", "Isolate actions", + "Reconnect actions", "Treat actions"]), + observation_space=env.ObservationSpace( + graph=env.Graph(community_n=30, + community_size_mean=20, + community_p_out=0.08, + community_p_in=0.16, + seed=seed + 1), + test_rate=1, + seed=seed + 2), + initial_infections=5, + random_infection_chance=1, + seed=seed + 3) + + sims.append(sim.Sim(env=env_, + agent=agt, + n_steps=steps, + plot=False, + save=True, + tqdm_on=True)) # Show progress bars for running sims + + # Run all the prepared Sims + Parallel(n_jobs=9, + backend='loky')(delayed(run_and_replay)(sim) for sim in sims) diff --git a/scripts/visual_compare_policy_agents.py b/scripts/visual_compare_policy_agents.py new file mode 100644 index 0000000..5c09506 --- /dev/null +++ b/scripts/visual_compare_policy_agents.py @@ -0,0 +1,76 @@ +"""Run all the basic policy agents with a number of 
actions per turn (n reps = 1). Generate and save .gif.""" + +from functools import partial + +import numpy as np +from joblib import Parallel, delayed + +import social_distancing_sim.agent as agent +import social_distancing_sim.environment as env +import social_distancing_sim.sim as sim + + +def run_and_replay(sim): + sim.run() + if sim.save: + sim.env.replay() + + +if __name__ == "__main__": + seed = 123 + + # Create a parameter set containing all combinations of the 3 policy agents, and a small set of n_actions + agents = [agent.DummyAgent, + partial(agent.DistancingPolicyAgent, + start_step={'isolate': 10, + 'reconnect': 50}, + end_step={'isolate': 40, + 'reconnect': 60}), + partial(agent.VaccinationPolicyAgent, + start_step={'vaccinate': 20}, + end_step={'vaccinate': 40}), + partial(agent.TreatmentPolicyAgent, + start_step={'treat': 20}, + end_step={'treat': 40}), + ] + n_actions = [10, 20] + sims = [] + + # Loop over the parameter set and create the Agents, Environments, and the Sim handler + for n_act, agt in np.array(np.meshgrid(n_actions, + agents)).T.reshape(-1, 2): + agt_ = agt(actions_per_turn=n_act) + + # Name the environment according to the agent used + env_ = env.Environment(name=f"{type(agt_).__name__} - {n_act} actions", + action_space=env.ActionSpace(isolate_efficiency=0.5, + vaccinate_efficiency=0.95), + disease=env.Disease(name='COVID-19', + virulence=0.008, + seed=seed, + immunity_mean=0.7, + recovery_rate=0.9, + immunity_decay_mean=0.01), + healthcare=env.Healthcare(capacity=75), + environment_plotting=env.EnvironmentPlotting(ts_fields_g2=["Actions taken", + "Overall score"]), + observation_space=env.ObservationSpace( + graph=env.Graph(community_n=20, + community_size_mean=15, + considered_immune_threshold=0.7, + seed=seed + 1), + test_rate=1, + seed=seed + 2), + initial_infections=15, + seed=seed + 3) + + sims.append(sim.Sim(env=env_, + agent=agt_, + n_steps=300, + plot=False, + save=True, + tqdm_on=True)) # Show progress bars for running 
sims + + # Run all the prepared Sims + Parallel(n_jobs=4, + backend='loky')(delayed(run_and_replay)(sim) for sim in sims) diff --git a/scripts/visual_compare_testing_rate.py b/scripts/visual_compare_testing_rate.py index f03de2f..858ee56 100644 --- a/scripts/visual_compare_testing_rate.py +++ b/scripts/visual_compare_testing_rate.py @@ -5,7 +5,7 @@ pop = env.Environment(name='A herd of cats, observed', disease=env.Disease(name='COVID-19'), - healthcare=env.healthcare(), + healthcare=env.Healthcare(), observation_space=env.ObservationSpace(graph=env.Graph(community_n=40, community_size_mean=16, seed=123), diff --git a/scripts/visual_compare_two_populations.py b/scripts/visual_compare_two_populations.py index 2113939..9165779 100644 --- a/scripts/visual_compare_two_populations.py +++ b/scripts/visual_compare_two_populations.py @@ -12,7 +12,7 @@ def run_and_replay(pop, *args, **kwargs): if __name__ == "__main__": save = True - # Create a population with high inter and intra connectivity + # Create an environment with high inter and intra connectivity pop = env.Environment(name='A herd of cats', disease=env.Disease(name='COVID-19'), observation_space=env.ObservationSpace(graph=env.Graph(community_n=40, @@ -21,7 +21,7 @@ def run_and_replay(pop, *args, **kwargs): test_rate=1), environment_plotting=env.EnvironmentPlotting(ts_fields_g2=["Turn score"])) - # Create a population with reduced inter and intra connectivity + # Create an environment with reduced inter and intra connectivity pop_distanced = env.Environment(name='A socially responsible environment', disease=env.Disease(name='COVID-19'), observation_space=env.ObservationSpace(graph=env.Graph(community_n=40, diff --git a/scripts/visual_run_simulation_with_agent.py b/scripts/visual_run_simulation_with_agent.py index e01ae8c..2fe5a9d 100644 --- a/scripts/visual_run_simulation_with_agent.py +++ b/scripts/visual_run_simulation_with_agent.py @@ -24,8 +24,8 @@ "Observed overall score"])) sim = sim.Sim(env=pop, - 
agent=env.accinationAgent(actions_per_turn=25, - seed=seed), + agent=env.VaccinationAgent(actions_per_turn=25, + seed=seed), plot=True, save=True) diff --git a/scripts/visual_run_single_population.py b/scripts/visual_run_single_population.py index fbad323..258b25b 100644 --- a/scripts/visual_run_single_population.py +++ b/scripts/visual_run_single_population.py @@ -3,7 +3,7 @@ import social_distancing_sim.environment as env if __name__ == "__main__": - # The graph is the "true" population model, containing all the nodes and their data + # The graph is the "true" environment model, containing all the nodes and their data graph = env.Graph(community_n=50, community_size_mean=15, community_p_in=0.06, # The likelihood of intra-community connections diff --git a/social_distancing_sim/__init__.py b/social_distancing_sim/__init__.py index d94c6fe..b1d0757 100644 --- a/social_distancing_sim/__init__.py +++ b/social_distancing_sim/__init__.py @@ -1,5 +1,5 @@ MAJOR = 0 -MINOR = 4 +MINOR = 5 PATCH = 0 __version__ = ".".join(str(v) for v in [MAJOR, MINOR, PATCH]) diff --git a/social_distancing_sim/agent/__init__.py b/social_distancing_sim/agent/__init__.py index 93e2832..17c6624 100644 --- a/social_distancing_sim/agent/__init__.py +++ b/social_distancing_sim/agent/__init__.py @@ -1,4 +1,8 @@ -from social_distancing_sim.agent.random_agent import RandomAgent -from social_distancing_sim.agent.isolation_agent import IsolationAgent -from social_distancing_sim.agent.vaccination_agent import VaccinationAgent -from social_distancing_sim.agent.dummy_agent import DummyAgent +from social_distancing_sim.agent.basic_agents.random_agent import RandomAgent +from social_distancing_sim.agent.basic_agents.isolation_agent import IsolationAgent +from social_distancing_sim.agent.basic_agents.vaccination_agent import VaccinationAgent +from social_distancing_sim.agent.basic_agents.dummy_agent import DummyAgent +from social_distancing_sim.agent.policy_agents.distancing_policy_agent import 
DistancingPolicyAgent +from social_distancing_sim.agent.policy_agents.vaccination_policy_agent import VaccinationPolicyAgent +from social_distancing_sim.agent.policy_agents.treatment_policy_agent import TreatmentPolicyAgent +from social_distancing_sim.agent.multi_agent.multi_agent import MultiAgent diff --git a/social_distancing_sim/agent/agent_base.py b/social_distancing_sim/agent/agent_base.py index 08cbca5..f746c8d 100644 --- a/social_distancing_sim/agent/agent_base.py +++ b/social_distancing_sim/agent/agent_base.py @@ -1,6 +1,6 @@ import abc import copy -from typing import List, Union, Dict, Tuple +from typing import List, Union, Dict import numpy as np @@ -14,10 +14,29 @@ class AgentBase(metaclass=abc.ABCMeta): def __init__(self, name: str = 'unnamed_agent', seed: Union[None, int] = None, - actions_per_turn: int = 5) -> None: + actions_per_turn: int = 5, + start_step: Union[Dict[str, int], None] = None, + end_step: Union[Dict[str, int], None] = None) -> None: + """ + + :param name: Agent name. + :param seed: Seed. + :param actions_per_turn: Number of actions to return each turn. Automatically limited by available targets each + turn, if necessary. + :param start_step: Dict keyed by action names, with ints indicating step to start performing actions. + :param end_step: Dict keyed by action names, with ints indicating step to stop performing actions. 
+ """ self.seed = seed self.name = name + + if start_step is None: + start_step = {} + self.start_step = start_step + if end_step is None: + end_step = {} + self.end_step = end_step self.actions_per_turn = actions_per_turn + self._step = 0 # Track steps as number of .sample calls to agent self._prepare_random_state() def _prepare_random_state(self) -> None: @@ -32,6 +51,15 @@ def available_actions(self) -> List[str]: """ return self.action_space.available_actions + @property + def currently_active_actions(self) -> List[str]: + """Return the current active actions based on defined time periods.""" + + active_actions = [a for a in self.available_actions + if (self._step >= self.start_step.get(a, 0)) and (self._step <= self.end_step.get(a, np.inf))] + + return active_actions + @staticmethod def available_targets(obs: ObservationSpace) -> List[int]: """ @@ -48,6 +76,7 @@ def _check_available_targets(self, obs: ObservationSpace) -> int: :param obs: Observation space to get targets from. :return: Number of possible actions given available targets. """ + n_available_targets = len(self.available_targets(obs)) return min(self.actions_per_turn, n_available_targets) @@ -60,7 +89,21 @@ def select_actions(self, obs: ObservationSpace) -> Dict[int, str]: """ pass - def sample(self, obs: ObservationSpace) -> Dict[int, str]: + def get_actions(self, obs: ObservationSpace) -> Dict[int, str]: + """Get next set of actions and targets and track.""" + + actions = self.select_actions(obs) + self._step += 1 + return actions + + def sample(self, obs: ObservationSpace, + track: bool = True) -> Dict[int, str]: + """ + Randomly return self.actions_per_turn actions and targets and optionally track. + + :param obs: ObservationSpace. + :param track: If True, track in self._steps. Can be set to False if desired. 
+ """ n = self._check_available_targets(obs) # Randomly pick n actions and targets @@ -70,6 +113,9 @@ def sample(self, obs: ObservationSpace) -> Dict[int, str]: replace=False, size=n) + if track: + self._step += 1 + return {t: a for t, a in zip(targets, actions)} def clone(self) -> "AgentBase": @@ -77,3 +123,7 @@ def clone(self) -> "AgentBase": clone = copy.deepcopy(self) clone._prepare_random_state() return clone + + def reset(self): + self._step = 0 + self._prepare_random_state() diff --git a/tests/integration/population/__init__.py b/social_distancing_sim/agent/basic_agents/__init__.py similarity index 100% rename from tests/integration/population/__init__.py rename to social_distancing_sim/agent/basic_agents/__init__.py diff --git a/social_distancing_sim/agent/dummy_agent.py b/social_distancing_sim/agent/basic_agents/dummy_agent.py similarity index 99% rename from social_distancing_sim/agent/dummy_agent.py rename to social_distancing_sim/agent/basic_agents/dummy_agent.py index f41d686..19a1360 100644 --- a/social_distancing_sim/agent/dummy_agent.py +++ b/social_distancing_sim/agent/basic_agents/dummy_agent.py @@ -6,7 +6,6 @@ class DummyAgent(AgentBase): """Doesn't do anything.""" - @property def available_actions(self) -> list: return [] diff --git a/social_distancing_sim/agent/isolation_agent.py b/social_distancing_sim/agent/basic_agents/isolation_agent.py similarity index 89% rename from social_distancing_sim/agent/isolation_agent.py rename to social_distancing_sim/agent/basic_agents/isolation_agent.py index 6747d9d..d4f2f9e 100644 --- a/social_distancing_sim/agent/isolation_agent.py +++ b/social_distancing_sim/agent/basic_agents/isolation_agent.py @@ -16,7 +16,9 @@ def available_targets(obs: ObservationSpace) -> Dict[str, List[int]]: 'reconnect': list(set(obs.current_clear_nodes).intersection(obs.isolated_nodes))} def select_actions(self, obs: ObservationSpace) -> Dict[int, str]: + """Selects randomly between both actions, any time frames are totally 
ignored.""" actions = self._random_state.choice(self.available_actions, + replace=True, size=self.actions_per_turn) available_targets = self.available_targets(obs) diff --git a/social_distancing_sim/agent/random_agent.py b/social_distancing_sim/agent/basic_agents/random_agent.py similarity index 65% rename from social_distancing_sim/agent/random_agent.py rename to social_distancing_sim/agent/basic_agents/random_agent.py index 5fb8bf9..f3daeb8 100644 --- a/social_distancing_sim/agent/random_agent.py +++ b/social_distancing_sim/agent/basic_agents/random_agent.py @@ -5,6 +5,12 @@ class RandomAgent(AgentBase): + """ + RandomAgent randomly selects an action and target. + + It doesn't support reconnection action, as this breaks on connected nodes - this is a good error check so will stay + as it is for now. + """ @property def available_actions(self) -> List[str]: return ['vaccinate', 'isolate'] diff --git a/social_distancing_sim/agent/basic_agents/treatment_agent.py b/social_distancing_sim/agent/basic_agents/treatment_agent.py new file mode 100644 index 0000000..f89c1fd --- /dev/null +++ b/social_distancing_sim/agent/basic_agents/treatment_agent.py @@ -0,0 +1,23 @@ +from social_distancing_sim.agent.agent_base import AgentBase +from typing import List, Dict + +from social_distancing_sim.agent.agent_base import AgentBase +from social_distancing_sim.environment.observation_space import ObservationSpace + + +class TreatmentAgent(AgentBase): + """TreatmentAgent randomly treats infected nodes.""" + + @property + def available_actions(self) -> List[str]: + """Treatment agent can only treat.""" + return ['treat'] + + @staticmethod + def available_targets(obs: ObservationSpace) -> List[int]: + return obs.current_infected_nodes + + def select_actions(self, obs: ObservationSpace) -> Dict[int, str]: + # Don't track sample call here as self.get_actions() will handle that. 
+ return self.sample(obs, + track=False) \ No newline at end of file diff --git a/social_distancing_sim/agent/vaccination_agent.py b/social_distancing_sim/agent/basic_agents/vaccination_agent.py similarity index 74% rename from social_distancing_sim/agent/vaccination_agent.py rename to social_distancing_sim/agent/basic_agents/vaccination_agent.py index 0796ef7..9ba1b44 100644 --- a/social_distancing_sim/agent/vaccination_agent.py +++ b/social_distancing_sim/agent/basic_agents/vaccination_agent.py @@ -5,6 +5,7 @@ class VaccinationAgent(AgentBase): + """VaccinationAgent randomly vaccinates clear nodes.""" @property def available_actions(self) -> List[str]: """Isolation agent can only isolate. It can't even un-isolate (yet?)""" @@ -15,4 +16,6 @@ def available_targets(obs: ObservationSpace) -> List[int]: return list(set(obs.current_clear_nodes).difference(obs.current_immune_nodes)) def select_actions(self, obs: ObservationSpace) -> Dict[int, str]: - return self.sample(obs) + # Don't track sample call here as self.get_actions() will handle that. 
+ return self.sample(obs, + track=False) diff --git a/social_distancing_sim/agent/multi_agent/__init__.py b/social_distancing_sim/agent/multi_agent/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/social_distancing_sim/agent/multi_agent/multi_agent.py b/social_distancing_sim/agent/multi_agent/multi_agent.py new file mode 100644 index 0000000..8118342 --- /dev/null +++ b/social_distancing_sim/agent/multi_agent/multi_agent.py @@ -0,0 +1,28 @@ +from typing import List, Dict + +from social_distancing_sim.agent.agent_base import AgentBase +from social_distancing_sim.environment.observation_space import ObservationSpace + + +class MultiAgent(AgentBase): + """ + Combine other agents/policy agents + + Actions per turn is dynamic and determined by individual agents + """ + + def __init__(self, agents: List[AgentBase], *args, **kwargs): + super().__init__(*args, **kwargs) + + self.agents = agents + self.actions_per_turn = sum([a.actions_per_turn for a in agents]) + + def select_actions(self, obs: ObservationSpace) -> Dict[int, str]: + """Ask each agent for their actions. 
They handle n and availability""" + + actions = {} + for agent in self.agents: + actions.update(agent.get_actions(obs)) + + return actions + \ No newline at end of file diff --git a/social_distancing_sim/agent/policy_agents/__init__.py b/social_distancing_sim/agent/policy_agents/__init__.py new file mode 100644 index 0000000..7279a5e --- /dev/null +++ b/social_distancing_sim/agent/policy_agents/__init__.py @@ -0,0 +1 @@ +"""Policy agents are Agents that use the start and end options to define an active period.""" \ No newline at end of file diff --git a/social_distancing_sim/agent/policy_agents/distancing_policy_agent.py b/social_distancing_sim/agent/policy_agents/distancing_policy_agent.py new file mode 100644 index 0000000..fb11e66 --- /dev/null +++ b/social_distancing_sim/agent/policy_agents/distancing_policy_agent.py @@ -0,0 +1,47 @@ +from typing import List, Dict + +from social_distancing_sim.agent.agent_base import AgentBase +from social_distancing_sim.environment.observation_space import ObservationSpace + + +class DistancingPolicyAgent(AgentBase): + """ + DistancingPolicyAgent applies isolation on a set turn and reconnection on a later turn. + + It can be used to model start and end of social distancing or quarantine periods. Note that this agent is similar to + the isolation agent, but will isolate any node not just infected ones. + + 0 start['isolate'] end['isolate'] start['reconnect'] end['reconnect'] + | Does nothing | Isolates ANY node | Does nothing | Reconnects Nodes | Does nothing ... 
+ + """ + + @property + def available_actions(self) -> List[str]: + return ['isolate', 'reconnect'] + + @staticmethod + def available_targets(obs: ObservationSpace) -> Dict[str, List[int]]: + """Slightly different IsolationAgent - also isolates clear nodes and reconnects any isolated node.""" + return {'isolate': list(set(obs.current_clear_nodes).difference(obs.isolated_nodes)), + 'reconnect': obs.isolated_nodes} + + def select_actions(self, obs: ObservationSpace) -> Dict[int, str]: + """Selects from actions that are currently available. If both are active, selects randomly between them.""" + + available_actions = {} + if len(self.currently_active_actions) > 0: + actions = self._random_state.choice(self.currently_active_actions, + replace=True, + size=self.actions_per_turn) + available_targets = self.available_targets(obs) + + # This effectively discards duplicate actions/target if same target is randomly selected twice + # Not sure if this is a good approach. When pool of targets is small, more likely to not take all available + # actions. + for ac in actions: + available_targets_for_this_action = available_targets[ac] + if len(available_targets_for_this_action) > 0: + available_actions.update({self._random_state.choice(available_targets_for_this_action): ac}) + + return available_actions diff --git a/social_distancing_sim/agent/policy_agents/treatment_policy_agent.py b/social_distancing_sim/agent/policy_agents/treatment_policy_agent.py new file mode 100644 index 0000000..e99dd2f --- /dev/null +++ b/social_distancing_sim/agent/policy_agents/treatment_policy_agent.py @@ -0,0 +1,34 @@ +from social_distancing_sim.agent.agent_base import AgentBase +from typing import List, Dict + +from social_distancing_sim.agent.agent_base import AgentBase +from social_distancing_sim.environment.observation_space import ObservationSpace + + +class TreatmentPolicyAgent(AgentBase): + """ + TreatmentPolicyAgent applies treatment to random infected nodes in active time frame. 
+ + It can be used to model the start and end of a period in which treatment is available. Note that this agent is similar to + the treatment agent, but only acts while 'treat' is within its active time frame. + + 0 start['treat'] end['treat'] + | Does nothing | Treats infected nodes | Does nothing ... + + """ + @property + def available_actions(self) -> List[str]: + return ['treat'] + + @staticmethod + def available_targets(obs: ObservationSpace) -> List[int]: + """Targets are the currently infected nodes.""" + return obs.current_infected_nodes + + def select_actions(self, obs: ObservationSpace) -> Dict[int, str]: + if len(self.currently_active_actions) > 0: + # Don't track sample call here as self.get_actions() will handle that. + return self.sample(obs, + track=False) + else: + return {} diff --git a/social_distancing_sim/agent/policy_agents/vaccination_policy_agent.py b/social_distancing_sim/agent/policy_agents/vaccination_policy_agent.py new file mode 100644 index 0000000..41c80ac --- /dev/null +++ b/social_distancing_sim/agent/policy_agents/vaccination_policy_agent.py @@ -0,0 +1,35 @@ +from typing import List, Dict + +from social_distancing_sim.agent.agent_base import AgentBase +from social_distancing_sim.environment.observation_space import ObservationSpace + + +class VaccinationPolicyAgent(AgentBase): + """ + VaccinationPolicyAgent applies vaccination during a set time period. + + Unlike VaccinationAgent, vaccinates any clear node even if they have some immunity. + + It can be used to model availability of a vaccine, for max or staggered use. + + 0 start['vaccinate'] end['vaccinate'] + | Does nothing | Vaccinates clear nodes | Does nothing ... 
+ """ + + @property + def available_actions(self) -> List[str]: + """Vaccination policy agent can only vaccinate."""
+ return ['vaccinate'] + + @staticmethod + def available_targets(obs: ObservationSpace) -> List[int]: + """Unlike VaccinationAgent, current immune nodes are not excluded from the clear nodes.""" + return list(set(obs.current_clear_nodes)) + + def select_actions(self, obs: ObservationSpace) -> Dict[int, str]: + if len(self.currently_active_actions) > 0: + # Don't track sample call here as self.get_actions() will handle that. + return self.sample(obs, + track=False) + else: + return {} diff --git a/social_distancing_sim/environment/action_space.py b/social_distancing_sim/environment/action_space.py index ce901a6..cb6af10 100644 --- a/social_distancing_sim/environment/action_space.py +++ b/social_distancing_sim/environment/action_space.py @@ -11,8 +11,14 @@ class ActionSpace: TODO: Standardise api and remove **kwargs """ - vaccinate_cost: int = -2 - isolate_cost: int = 0 + vaccinate_cost: float = -2 + isolate_cost: float = 0 + treat_cost: float = -3 + vaccinate_efficiency: float = 0.95 + isolate_efficiency: float = 0.95 + reconnect_efficiency: float = 0.95 + treatment_conclusion_chance: float = 0.9 + treatment_recovery_rate_modifier: float = 1.5 seed: Union[int, None] = None def __post_init__(self) -> None: @@ -21,25 +27,41 @@ def __post_init__(self) -> None: def _prepare_random_state(self) -> None: self.state = np.random.RandomState(seed=self.seed) + @property + def n(self): + return len(self.available_actions) + @property def available_actions(self) -> List[str]: - return ['vaccinate', 'isolate', 'reconnect'] + return ['vaccinate', 'isolate', 'reconnect', 'treat'] + + def treat(self, **kwargs) -> float: + kwargs["env"].disease.conclude(kwargs["env"].observation_space.graph.g_.nodes[kwargs["target_node_id"]], + chance_to_force=self.treatment_conclusion_chance, + recovery_rate_modifier=self.treatment_recovery_rate_modifier) + kwargs["env"].observation_space.graph.g_.nodes[kwargs["target_node_id"]]["status"].immune = True + 
kwargs["env"].observation_space.graph.g_.nodes[kwargs["target_node_id"]]["last_tested"] = kwargs["step"] + + return self.treat_cost - def vaccinate(self, **kwargs) -> int: - kwargs["env"].disease.give_immunity(kwargs["env"].observation_space.graph.g_.nodes[kwargs["target_node_id"]]) + def vaccinate(self, **kwargs) -> float: + kwargs["env"].disease.give_immunity(kwargs["env"].observation_space.graph.g_.nodes[kwargs["target_node_id"]], + immunity=self.vaccinate_efficiency) kwargs["env"].observation_space.graph.g_.nodes[kwargs["target_node_id"]]["status"].immune = True kwargs["env"].observation_space.graph.g_.nodes[kwargs["target_node_id"]]["last_tested"] = kwargs["step"] return self.vaccinate_cost - def isolate(self, **kwargs) -> int: - kwargs["env"].observation_space.graph.isolate_node(kwargs["target_node_id"]) + def isolate(self, **kwargs) -> float: + kwargs["env"].observation_space.graph.isolate_node(kwargs["target_node_id"], + effectiveness=self.isolate_efficiency) kwargs["env"].observation_space.graph.g_.nodes[kwargs["target_node_id"]]["status"].isolated = True return self.isolate_cost - def reconnect(self, **kwargs) -> int: - kwargs["env"].observation_space.graph.reconnect_node(kwargs["target_node_id"]) + def reconnect(self, **kwargs) -> float: + kwargs["env"].observation_space.graph.reconnect_node(kwargs["target_node_id"], + effectiveness=self.reconnect_efficiency) kwargs["env"].observation_space.graph.g_.nodes[kwargs["target_node_id"]]["status"].isolated = False return self.isolate_cost diff --git a/social_distancing_sim/environment/disease.py b/social_distancing_sim/environment/disease.py index 2498f7e..084e76b 100644 --- a/social_distancing_sim/environment/disease.py +++ b/social_distancing_sim/environment/disease.py @@ -31,11 +31,31 @@ def __post_init__(self) -> None: def _prepare_random_state(self) -> None: self.state = np.random.RandomState(seed=self.seed) - def give_immunity(self, node: Dict[Hashable, Any]) -> Dict[Hashable, Any]: - node["immune"] = 
min(self.immunity_mean + self.state.normal(scale=self.immunity_std), 1.0) + def give_immunity(self, node: Dict[Hashable, Any], + immunity: float = None) -> Dict[Hashable, Any]: + """ + Give immunity to a node. + + :param node: Graph node data. + :param immunity: Amount of immunity to give to node, optional. Default None, which gives self.immunity_ (the + values applied on recovery). Allows for bonus immunity from actions, like vaccination. + :return: Modified node data. + """ + if immunity is None: + immunity = min(self.immunity_mean + self.state.normal(scale=self.immunity_std), 1.0) + node["immune"] = immunity + return node def decay_immunity(self, node: Dict[Hashable, Any]) -> Dict[Hashable, Any]: + """ + Decay immunity on a node. + + Uses self.immunity_decay_ values. No option to specify decay amount for now, because YAGNI. + + :param node: Graph node data. + :return: Modified node data. + """ decay = self.immunity_decay_mean + self.state.normal(scale=self.immunity_decay_std) new_immunity = max(0.0, node["immune"] - node["immune"] * decay) node["immune"] = new_immunity @@ -46,25 +66,38 @@ def modified_virulence(self, immunity: float) -> float: return min(max(1e-7, self.virulence * (1 - immunity)), 0.999) def conclude(self, node: Dict[Hashable, Any], + chance_to_force: float = 0.0, recovery_rate_modifier: float = 1) -> Dict[Hashable, Any]: """ :param node: Graph node to update. + + :param chance_to_force: Chance to force conclusion. Must be between 0 -> 1. Optional, default 0. + Default uses self.duration_* params instead. Allows for specifying a treatment efficacy + along with recovery rate modifier. Note that if conclusion is forced, node can still + die. So treatment can be worse than no treatment if high chance to force conclusion and + recovery rate penalty. :param recovery_rate_modifier: Modify recovery rate depending on external factors such as healthcare burden. - Default 1.0 (no modification). + Default 1.0 (no modification). Only relevant if disease ends. 
+ :return: Updated graph node. """ - modified_recovery_rate = self.recovery_rate * recovery_rate_modifier + # Decide if forcing conclusion or not + force = False + if chance_to_force > 0: + force = self.state.binomial(1, chance_to_force) # Decide end of disease if node["infected"] > self.state.normal(self.duration_mean, self.duration_std, - size=1): + size=1) or force: + # Concluding, decide fate node["infected"] = 0 - node = self.give_immunity(node) + modified_recovery_rate = max(0.0, min(1.0, self.recovery_rate * recovery_rate_modifier)) if self.state.binomial(1, modified_recovery_rate): + node = self.give_immunity(node) node["alive"] = True else: node["alive"] = False diff --git a/social_distancing_sim/environment/environment.py b/social_distancing_sim/environment/environment.py index 44c8b9b..7b14ce3 100644 --- a/social_distancing_sim/environment/environment.py +++ b/social_distancing_sim/environment/environment.py @@ -6,8 +6,8 @@ import numpy as np from tqdm import tqdm -from social_distancing_sim.environment.disease import Disease from social_distancing_sim.environment.action_space import ActionSpace +from social_distancing_sim.environment.disease import Disease from social_distancing_sim.environment.environment_plotting import EnvironmentPlotting from social_distancing_sim.environment.healthcare import Healthcare from social_distancing_sim.environment.history import History @@ -95,7 +95,8 @@ def _conclude_all(self) -> Tuple[int, int]: def _log(self, new_infections: int, known_new_infections: int, deaths: int, recoveries: int, turn_score: float = 0.0, - obs_turn_score: float = 0.0) -> None: + obs_turn_score: float = 0.0, + actions_taken: Dict[int, str] = 0) -> None: # Log counts/score for this turn self.history.log({"Turn score": turn_score, @@ -104,6 +105,13 @@ def _log(self, new_infections: int, known_new_infections: int, deaths: int, reco "Known new infections": known_new_infections, "New deaths": deaths, "Current recovered": recoveries}) + # Log actions + 
# TODO: Might be worth making this less manual and/or handling it somewhere else? + self.history.log({"Actions taken": len(actions_taken.values()), + "Vaccinate actions": len([a for a in actions_taken.values() if a == 'vaccinate']), + "Isolate actions": len([a for a in actions_taken.values() if a == 'isolate']), + "Reconnect actions": len([a for a in actions_taken.values() if a == 'reconnect']), + "Treat actions": len([a for a in actions_taken.values() if a == 'treat'])}) # Log full space and observed space self.history.log({"Current infections": self.observation_space.graph.n_current_infected, @@ -210,7 +218,8 @@ def step(self, actions: Dict[int, str]) -> Tuple[Dict[str, Any], float, bool, Di deaths=deaths, recoveries=recoveries, turn_score=turn_score, - obs_turn_score=obs_turn_score) + obs_turn_score=obs_turn_score, + actions_taken=actions) self._step += 1 diff --git a/social_distancing_sim/environment/environment_plotting.py b/social_distancing_sim/environment/environment_plotting.py index 9b69c2b..2001843 100644 --- a/social_distancing_sim/environment/environment_plotting.py +++ b/social_distancing_sim/environment/environment_plotting.py @@ -18,6 +18,8 @@ @dataclass class EnvironmentPlotting: both: bool = True + auto_lim_x: bool = True + auto_lim_y: bool = True ts_fields_g1: List[str] = None ts_fields_g2: List[str] = None ts_obs_fields_g1: List[str] = None @@ -117,8 +119,8 @@ def plot_ts(self, history: History, healthcare: Healthcare, total_steps: int, to step: int) -> None: for ax, fields in zip(self._ts_ax_g1, [self.ts_fields_g1, self.ts_obs_fields_g1]): history.plot(ks=fields, - x_lim=(-1, total_steps), - y_lim=(-10, int(total_population + total_population * 0.05)), + x_lim=(-1, total_steps) if self.auto_lim_x else None, + y_lim=(-10, int(total_population + total_population * 0.05)) if self.auto_lim_y else None, x_label='Day' if not self._g2_on else None, remove_x_tick_labels=self._g2_on, ax=ax, @@ -130,15 +132,15 @@ def plot_ts(self, history: History, 
healthcare: Healthcare, total_steps: int, to if self._g2_on: for ax, fields in zip(self._ts_ax_g2, [self.ts_fields_g2, self.ts_obs_fields_g2]): history.plot(ks=fields, - y_label='Score', - x_lim=(-1, total_steps), + y_label='', + x_lim=(-1, total_steps) if self.auto_lim_x else None, # y_lim=(-0.1, 1.1), ax=ax, show=False) def plot_graphs(self, obs: ObservationSpace, title: str): obs.graph.plot(ax=self._graph_ax[0]) - self._graph_ax[0].set_title(f"Full sim: {title}") + self._graph_ax[0].set_title(f"Full sim: {title}", fontsize=14) if (obs.test_rate < 1) & self.both: obs.plot(ax=self._graph_ax[1]) diff --git a/social_distancing_sim/environment/graph.py b/social_distancing_sim/environment/graph.py index db70b1e..7e6a57c 100644 --- a/social_distancing_sim/environment/graph.py +++ b/social_distancing_sim/environment/graph.py @@ -62,6 +62,8 @@ def _generate_graph(self) -> None: nv["infected"] = 0 nv["immune"] = False nv["alive"] = True + nv["_edges"] = [] + nv["isolated"] = False def _prepare_random_state(self) -> None: self._random_state = np.random.RandomState(seed=self.seed) @@ -117,16 +119,16 @@ def overall_death_rate(self) -> float: return death_rate def isolate_node(self, node_id: int, - effectiveness: float = 0.95): + effectiveness: float = 0.95) -> None: """ Remove some or all edges from a node, and store on node. + Flag node as isolated if any edges removed and stored in _edges. + :param node_id: Node index. :param effectiveness: Proportion of edges to remove """ node = self.g_.nodes[node_id] - # Do NOT deepcopy EdgeView!!! Copy won't work either. - node["_edges"] = copy.deepcopy(list(self.g_.edges(node_id))) node["isolated"] = True # Select edges to remove @@ -135,12 +137,34 @@ def isolate_node(self, node_id: int, if self._random_state.binomial(1, effectiveness): to_remove.append(uv) + # Do NOT deepcopy EdgeView!!! Copy won't work either. 
+ node["_edges"] += to_remove + self.g_.remove_edges_from(to_remove) - def reconnect_node(self, node_id: int): + def reconnect_node(self, node_id: int, + effectiveness: float = 0.95) -> None: + """ + Restore edges with probability defined in effectiveness. + + Flag node as not isolated when all edges have been restored. + + :param node_id: Node index. + :param effectiveness: Proportion of edges to re-add. + """ node = self.g_.nodes[node_id] - self.g_.add_edges_from(node["_edges"]) - node["isolated"] = False + to_add = [] + leave = [] + for uv in node["_edges"]: + if self._random_state.binomial(1, effectiveness): + to_add.append(uv) + else: + leave.append(uv) + + self.g_.add_edges_from(to_add) + node["_edges"] = leave + if len(node["_edges"]) == 0: + node["isolated"] = False def plot(self, ax: Union[None, plt.Axes] = None, diff --git a/social_distancing_sim/sim/multi_sim.py b/social_distancing_sim/sim/multi_sim.py index 5337355..f138916 100644 --- a/social_distancing_sim/sim/multi_sim.py +++ b/social_distancing_sim/sim/multi_sim.py @@ -92,40 +92,3 @@ def log(self): mlflow.log_metrics(metrics_to_log) mlflow.end_run() - - -if __name__ == "__main__": - from social_distancing_sim.agent import RandomAgent - from social_distancing_sim.environment.graph import Graph - from social_distancing_sim.environment.healthcare import Healthcare - from social_distancing_sim.environment.observation_space import ObservationSpace - from social_distancing_sim.environment.environment import Environment - from social_distancing_sim.disease.disease import Disease - - seed = None - - pop = Environment(name="multi sim environment", - disease=Disease(name='COVID-19', - virulence=0.01, - seed=seed, - immunity_mean=0.95, - immunity_decay_mean=0.05), - healthcare=Healthcare(capacity=5), - observation_space=ObservationSpace(graph=Graph(community_n=15, - community_size_mean=10, - seed=seed), - test_rate=1, - seed=seed), - seed=seed, - environment_plotting=EnvironmentPlotting(ts_fields_g2=["Turn 
score", "Action cost", - "Overall score"], - ts_obs_fields_g2=["Observed turn score", "Action cost", - "Observed overall score"])) - - sim = Sim(env=pop, - n_steps=150, - agent=RandomAgent(actions_per_turn=10, - seed=seed),) - - multi_sim = MultiSim(sim) - multi_sim.run() diff --git a/social_distancing_sim/sim/sim.py b/social_distancing_sim/sim/sim.py index 9e2b971..ee4007c 100644 --- a/social_distancing_sim/sim/sim.py +++ b/social_distancing_sim/sim/sim.py @@ -5,7 +5,7 @@ from tqdm import tqdm from social_distancing_sim.agent.agent_base import AgentBase -from social_distancing_sim.agent.dummy_agent import DummyAgent +from social_distancing_sim.agent.basic_agents.dummy_agent import DummyAgent from social_distancing_sim.environment.environment import Environment from social_distancing_sim.environment.history import History @@ -41,7 +41,7 @@ def run(self) -> History: desc=self.env.name): # Pick action - actions = self.agent.select_actions(obs=self.env.observation_space) + actions = self.agent.get_actions(obs=self.env.observation_space) # Step the simulation observation, reward, done, info = self.env.step(actions) @@ -72,46 +72,3 @@ def clone(self) -> "Sim": agent=self.agent.clone() if self.agent is not None else None, n_steps=self.n_steps, plot=self.plot, save=self.save, tqdm_on=self.tqdm_on) - - -if __name__ == "__main__": - from social_distancing_sim.environment.graph import Graph - from social_distancing_sim.environment.healthcare import Healthcare - from social_distancing_sim.environment.observation_space import ObservationSpace - from social_distancing_sim.environment.environment import Environment - from social_distancing_sim.disease.disease import Disease - from social_distancing_sim.environment.environment_plotting import EnvironmentPlotting - import social_distancing_sim.agent as agents - - seed = 100 - - env = Environment(name="example environment", - disease=Disease(name='COVID-19', - virulence=0.1, - duration_mean=5, - seed=seed, - immunity_mean=0.95, - 
immunity_decay_mean=0.05), - healthcare=Healthcare(capacity=5), - observation_space=ObservationSpace(graph=Graph(community_n=15, - community_size_mean=10, - community_p_in=1, - community_p_out=0.3, - seed=seed + 1), - test_rate=1, - seed=seed + 2), - environment_plotting=EnvironmentPlotting(ts_fields_g2=["Turn score", "Action cost", - "Overall score"], - ts_obs_fields_g2=["Observed turn score", "Action cost", - "Observed overall score"]), - seed=seed + 3) - - sim = Sim(env=env, - n_steps=150, - agent=agents.VaccinationAgent(seed=8, - actions_per_turn=10), - tqdm_on=True, - plot=True, - save=False) - - sim.run() diff --git a/tests/integration/environment/__init__.py b/tests/integration/environment/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/environment/test_action_space.py b/tests/integration/environment/test_action_space.py new file mode 100644 index 0000000..33bf561 --- /dev/null +++ b/tests/integration/environment/test_action_space.py @@ -0,0 +1,157 @@ +import copy +import unittest + +import social_distancing_sim.environment as env +from social_distancing_sim.environment.action_space import ActionSpace + + +class TestActionSpace(unittest.TestCase): + _sut = ActionSpace + + def setUp(self) -> None: + self.env = env.Environment(name='Test env', + disease=env.Disease(name='test disease'), + observation_space=env.ObservationSpace(graph=env.Graph(community_n=30, + community_p_in=1, + community_p_out=0.9, + community_size_mean=5, + seed=123))) + + def test_expected_default_actions_are_available(self): + # Arrange + action_space = self._sut() + + # Act + available_actions = action_space.available_actions + + # Assert + self.assertListEqual(['vaccinate', 'isolate', 'reconnect', 'treat'], available_actions) + + def test_vaccinate_action_adds_default_immunity(self): + # Arrange + action_space = self._sut() + + # Act + cost = action_space.vaccinate(env=self.env, target_node_id=1, step=1) + + # Assert + 
self.assertFalse(self.env.observation_space.graph.g_.nodes[0]["immune"] > 0) + self.assertTrue(self.env.observation_space.graph.g_.nodes[1]["immune"] > 0) + self.assertFalse(self.env.observation_space.graph.g_.nodes[0]["status"].immune) + self.assertTrue(self.env.observation_space.graph.g_.nodes[1]["status"].immune) + + def test_vaccinate_action_adds_extra_immunity_if_specified(self): + # Arrange + action_space = self._sut(vaccinate_efficiency=1) + + # Act + cost = action_space.vaccinate(env=self.env, target_node_id=1, step=1) + + # Assert + self.assertFalse(self.env.observation_space.graph.g_.nodes[0]["immune"] > 0) + self.assertTrue(self.env.observation_space.graph.g_.nodes[1]["immune"] == 1) + self.assertFalse(self.env.observation_space.graph.g_.nodes[0]["status"].immune) + self.assertTrue(self.env.observation_space.graph.g_.nodes[1]["status"].immune) + + def test_isolate_action_isolates_all_with_full_effectiveness(self): + # Arrange + action_space = self._sut(isolate_efficiency=1) + + # Act + cost = action_space.isolate(env=self.env, target_node_id=1, step=1) + + # Assert + self.assertFalse(self.env.observation_space.graph.g_.nodes[0]["isolated"]) + self.assertTrue(self.env.observation_space.graph.g_.nodes[1]["isolated"]) + self.assertTrue(len(self.env.observation_space.graph.g_.edges(0)) > 0) + self.assertEqual(0, len(self.env.observation_space.graph.g_.edges(1))) + self.assertEqual(0, len(self.env.observation_space.graph.g_.nodes[0]["_edges"])) + self.assertTrue(len(self.env.observation_space.graph.g_.nodes[1]["_edges"]) > 0) + + def test_isolate_action_isolates_does_not_isolate_all_with_limited_effectiveness(self): + # Arrange + action_space = self._sut(isolate_efficiency=0.1) + + # Act + cost = action_space.isolate(env=self.env, target_node_id=1, step=1) + + # Assert + self.assertFalse(self.env.observation_space.graph.g_.nodes[0]["isolated"]) + self.assertTrue(self.env.observation_space.graph.g_.nodes[1]["isolated"]) + 
self.assertTrue(len(self.env.observation_space.graph.g_.edges(0)) > 0) + self.assertTrue(len(self.env.observation_space.graph.g_.edges(1)) > 0) + self.assertEqual(0, len(self.env.observation_space.graph.g_.nodes[0]["_edges"])) + self.assertTrue(len(self.env.observation_space.graph.g_.nodes[1]["_edges"]) > 0) + + def test_reconnect_action_reconnects_all_with_full_effectiveness(self): + # Arrange + action_space = self._sut(isolate_efficiency=1, + reconnect_efficiency=1) + cost = action_space.isolate(env=self.env, target_node_id=1, step=1) + + # Act + cost = action_space.reconnect(env=self.env, target_node_id=1, step=1) + + # Assert + self.assertFalse(self.env.observation_space.graph.g_.nodes[0]["isolated"]) + self.assertFalse(self.env.observation_space.graph.g_.nodes[1]["isolated"]) + self.assertTrue(len(self.env.observation_space.graph.g_.edges(0)) > 0) + self.assertTrue(len(self.env.observation_space.graph.g_.edges(1)) > 0) + self.assertEqual(0, len(self.env.observation_space.graph.g_.nodes[0]["_edges"])) + self.assertEqual(0, len(self.env.observation_space.graph.g_.nodes[1]["_edges"])) + + def test_reconnect_action_does_not_reconnect_all_with_limited_effectiveness(self): + # Arrange + action_space = self._sut(isolate_efficiency=1, + reconnect_efficiency=0.2) + cost = action_space.isolate(env=self.env, target_node_id=1, step=1) + + # Act + cost = action_space.reconnect(env=self.env, target_node_id=1, step=1) + + # Assert + self.assertFalse(self.env.observation_space.graph.g_.nodes[0]["isolated"]) + self.assertTrue(self.env.observation_space.graph.g_.nodes[1]["isolated"]) + self.assertTrue(len(self.env.observation_space.graph.g_.edges(0)) > 0) + self.assertTrue(len(self.env.observation_space.graph.g_.edges(1)) > 0) + self.assertEqual(0, len(self.env.observation_space.graph.g_.nodes[0]["_edges"])) + self.assertTrue(len(self.env.observation_space.graph.g_.nodes[1]["_edges"]) > 0) + + def 
test_reconnect_action_eventually_reconnects_all_with_limited_effectiveness_and_repeated_calls(self): + # Arrange + action_space = self._sut(isolate_efficiency=1, + reconnect_efficiency=0.7) + original_edges = copy.deepcopy(list(self.env.observation_space.graph.g_.edges(1))) + cost = action_space.isolate(env=self.env, target_node_id=1, step=1) + + # Act + cost = action_space.reconnect(env=self.env, target_node_id=1, step=1) + cost = action_space.reconnect(env=self.env, target_node_id=1, step=1) + cost = action_space.reconnect(env=self.env, target_node_id=1, step=1) + cost = action_space.reconnect(env=self.env, target_node_id=1, step=1) + cost = action_space.reconnect(env=self.env, target_node_id=1, step=1) + cost = action_space.reconnect(env=self.env, target_node_id=1, step=1) + + # Assert + self.assertFalse(self.env.observation_space.graph.g_.nodes[0]["isolated"]) + self.assertFalse(self.env.observation_space.graph.g_.nodes[1]["isolated"]) + self.assertTrue(len(self.env.observation_space.graph.g_.edges(0)) > 0) + self.assertEqual(len(original_edges), len(self.env.observation_space.graph.g_.edges(1))) + self.assertEqual(0, len(self.env.observation_space.graph.g_.nodes[0]["_edges"])) + self.assertEqual(0, len(self.env.observation_space.graph.g_.nodes[1]["_edges"])) + + def test_treatment_removes_infection_when_forced(self): + # Arrange + action_space = self._sut(treatment_conclusion_chance= 1, + treatment_recovery_rate_modifier=10) + self.env.observation_space.graph.g_.nodes[1]["infected"] = 3 + self.env.observation_space.graph.g_.nodes[1]["status"].infected = True + + # Act + cost = action_space.treat(env=self.env, target_node_id=1, step=1) + + # Assert + self.assertFalse(self.env.observation_space.graph.g_.nodes[0]["infected"] > 0) + self.assertFalse(self.env.observation_space.graph.g_.nodes[1]["infected"] > 0) + self.assertIsNone(self.env.observation_space.graph.g_.nodes[0]["status"].infected) + 
self.assertFalse(self.env.observation_space.graph.g_.nodes[1]["status"].infected) diff --git a/tests/integration/population/test_population.py b/tests/integration/environment/test_population.py similarity index 100% rename from tests/integration/population/test_population.py rename to tests/integration/environment/test_population.py diff --git a/tests/integration/sim/test_multi_sim.py b/tests/integration/sim/test_multi_sim.py index fbfe229..675c944 100644 --- a/tests/integration/sim/test_multi_sim.py +++ b/tests/integration/sim/test_multi_sim.py @@ -3,8 +3,8 @@ import numpy as np from tqdm import tqdm -from social_distancing_sim.agent.isolation_agent import IsolationAgent -from social_distancing_sim.agent.vaccination_agent import VaccinationAgent +from social_distancing_sim.agent.basic_agents.isolation_agent import IsolationAgent +from social_distancing_sim.agent.basic_agents.vaccination_agent import VaccinationAgent from social_distancing_sim.environment.disease import Disease from social_distancing_sim.environment.environment import Environment from social_distancing_sim.environment.graph import Graph diff --git a/tests/integration/sim/test_sim.py b/tests/integration/sim/test_sim.py index 78fbbc9..37b7289 100644 --- a/tests/integration/sim/test_sim.py +++ b/tests/integration/sim/test_sim.py @@ -1,6 +1,6 @@ import unittest -from social_distancing_sim.agent.vaccination_agent import VaccinationAgent +from social_distancing_sim.agent.basic_agents.vaccination_agent import VaccinationAgent from social_distancing_sim.environment.disease import Disease from social_distancing_sim.environment.environment import Environment from social_distancing_sim.environment.environment_plotting import EnvironmentPlotting @@ -9,8 +9,19 @@ from social_distancing_sim.environment.observation_space import ObservationSpace from social_distancing_sim.sim.sim import Sim +import os +import shutil + class TestSim(unittest.TestCase): + + def setUp(self): + self._to_delete = None + + def 
tearDown(self): + if self._to_delete is not None: + shutil.rmtree(self._to_delete, ignore_errors=True) + def test_default_sim_run(self): pop = Environment(disease=Disease(), healthcare=Healthcare(), @@ -26,7 +37,7 @@ def test_default_sim_run(self): def test_example_sim_run(self): seed = 123 - pop = Environment(name="agent example environment", + pop = Environment(name="agent example environment 1", disease=Disease(name='COVID-19', virulence=0.01, seed=seed, @@ -51,3 +62,106 @@ def test_example_sim_run(self): save=False) sim.run() + + self._to_delete = pop.name + + def test_example_sim_run_with_plotting(self): + + seed = 123 + + pop = Environment(name="agent example environment 2", + disease=Disease(name='COVID-19', + virulence=0.01, + seed=seed, + immunity_mean=0.95, + immunity_decay_mean=0.05), + healthcare=Healthcare(capacity=5), + observation_space=ObservationSpace(graph=Graph(community_n=15, + community_size_mean=10, + seed=seed + 1), + test_rate=1, + seed=seed + 2), + seed=seed + 3, + environment_plotting=EnvironmentPlotting(ts_fields_g2=["Score", "Action cost", + "Overall score"], + ts_obs_fields_g2=["Observed Score", + "Action cost", + "Observed overall score"])) + + sim = Sim(env=pop, + n_steps=3, + agent=VaccinationAgent(actions_per_turn=25, + seed=seed), + plot=False, + save=True) + + sim.run() + sim.env.replay() + self._to_delete = pop.name + + def test_example_sim_run_with_extra_plotting(self): + seed = 123 + + pop = Environment(name="agent example environment 3", + disease=Disease(name='COVID-19', + virulence=0.01, + seed=seed, + immunity_mean=0.95, + immunity_decay_mean=0.05), + healthcare=Healthcare(capacity=5), + observation_space=ObservationSpace(graph=Graph(community_n=15, + community_size_mean=10, + seed=seed + 1), + test_rate=1, + seed=seed + 2), + seed=seed + 3, + environment_plotting=EnvironmentPlotting(ts_fields_g2=["Score", "Action cost", + "Overall score"], + ts_obs_fields_g2=["Observed Score", + "Action cost", + "Observed overall score"])) 
+ + sim = Sim(env=pop, + n_steps=3, + agent=VaccinationAgent(actions_per_turn=25, + seed=seed), + plot=False, + save=True) + + sim.run() + sim.env.replay() + self._to_delete = pop.name + + def test_example_sim_run_with_all_plotting(self): + seed = 123 + + pop = Environment(name="agent example environment 4", + disease=Disease(name='COVID-19', + virulence=0.01, + seed=seed, + immunity_mean=0.95, + immunity_decay_mean=0.05), + healthcare=Healthcare(capacity=5), + observation_space=ObservationSpace(graph=Graph(community_n=15, + community_size_mean=10, + seed=seed + 1), + test_rate=0.1, + seed=seed + 2), + seed=seed + 3, + environment_plotting=EnvironmentPlotting(ts_fields_g2=["Score", "Action cost", + "Overall score"], + ts_obs_fields_g2=["Observed Score", + "Action cost", + "Observed overall score"])) + + sim = Sim(env=pop, + n_steps=3, + agent=VaccinationAgent(actions_per_turn=25, + seed=seed), + plot=False, + save=True) + + sim.run() + sim.env.replay() + self._to_delete = pop.name + diff --git a/tests/unit/agent/basic_agents/__init__.py b/tests/unit/agent/basic_agents/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/agent/test_isolation_agent.py b/tests/unit/agent/basic_agents/test_isolation_agent.py similarity index 83% rename from tests/unit/agent/test_isolation_agent.py rename to tests/unit/agent/basic_agents/test_isolation_agent.py index e68e367..c029cce 100644 --- a/tests/unit/agent/test_isolation_agent.py +++ b/tests/unit/agent/basic_agents/test_isolation_agent.py @@ -1,6 +1,6 @@ import unittest -from social_distancing_sim.agent.isolation_agent import IsolationAgent +from social_distancing_sim.agent.basic_agents.isolation_agent import IsolationAgent class TestVaccinationAgent(unittest.TestCase): diff --git a/tests/unit/agent/test_random_agent.py b/tests/unit/agent/basic_agents/test_random_agent.py similarity index 84% rename from tests/unit/agent/test_random_agent.py rename to tests/unit/agent/basic_agents/test_random_agent.py 
index 096a8aa..e0aae0c 100644 --- a/tests/unit/agent/test_random_agent.py +++ b/tests/unit/agent/basic_agents/test_random_agent.py @@ -1,6 +1,6 @@ import unittest -from social_distancing_sim.agent.random_agent import RandomAgent +from social_distancing_sim.agent.basic_agents.random_agent import RandomAgent class TestRandomAgent(unittest.TestCase): diff --git a/tests/unit/agent/basic_agents/test_treatment_agent.py b/tests/unit/agent/basic_agents/test_treatment_agent.py new file mode 100644 index 0000000..3b29dc1 --- /dev/null +++ b/tests/unit/agent/basic_agents/test_treatment_agent.py @@ -0,0 +1,21 @@ +import unittest + +from social_distancing_sim.agent.basic_agents.treatment_agent import TreatmentAgent + + +class TestTreatmentAgent(unittest.TestCase): + _sut = TreatmentAgent + + def test_init_with_defaults(self): + # Act + agent = self._sut() + + # Assert + self.assertIsInstance(agent, TreatmentAgent) + + def test_available_actions(self): + # Arrange + agent = self._sut() + + # Assert + self.assertListEqual(['treat'], agent.available_actions) diff --git a/tests/unit/agent/test_vaccination_agent.py b/tests/unit/agent/basic_agents/test_vaccination_agent.py similarity index 82% rename from tests/unit/agent/test_vaccination_agent.py rename to tests/unit/agent/basic_agents/test_vaccination_agent.py index bd4ca31..26bf5c8 100644 --- a/tests/unit/agent/test_vaccination_agent.py +++ b/tests/unit/agent/basic_agents/test_vaccination_agent.py @@ -1,6 +1,6 @@ import unittest -from social_distancing_sim.agent.vaccination_agent import VaccinationAgent +from social_distancing_sim.agent.basic_agents.vaccination_agent import VaccinationAgent class TestVaccinationAgent(unittest.TestCase): diff --git a/tests/unit/agent/policy_agents/__init__.py b/tests/unit/agent/policy_agents/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/agent/policy_agents/test_distancing_policy_agent.py b/tests/unit/agent/policy_agents/test_distancing_policy_agent.py new file mode 
100644 index 0000000..e077083 --- /dev/null +++ b/tests/unit/agent/policy_agents/test_distancing_policy_agent.py @@ -0,0 +1,154 @@ +import unittest +from unittest.mock import MagicMock + +from social_distancing_sim.agent.policy_agents.distancing_policy_agent import DistancingPolicyAgent + + +class TestDistancingPolicyAgent(unittest.TestCase): + _sut = DistancingPolicyAgent + + def setUp(self) -> None: + mock_observation_space = MagicMock() + mock_observation_space.current_clear_nodes = [9, 10, 11, 12] + mock_observation_space.isolated_nodes = [12, 13, 14] + self.mock_observation_space = mock_observation_space + + def test_init_with_defaults(self): + # Act + agent = self._sut() + + # Assert + self.assertIsInstance(agent, DistancingPolicyAgent) + + def test_available_actions(self): + # Arrange + agent = self._sut() + + # Assert + self.assertListEqual(['isolate', 'reconnect'], agent.available_actions) + + def test_no_actions_outside_active_period(self): + # Arrange + agent = self._sut(name='test_agent', + actions_per_turn=1, + start_step={'isolate': 25, + 'reconnect': 35}, + end_step={'isolate': 30, + 'reconnect': 40}) + + # Act + action = agent.get_actions(self.mock_observation_space) + + # Assert + self.assertListEqual([], list(action.keys())) + + def test_n_actions_inside_first_active_period(self): + # Arrange + agent = self._sut(name='test_agent', + actions_per_turn=1, + start_step={'isolate': 25, + 'reconnect': 35}, + end_step={'isolate': 30, + 'reconnect': 40}) + agent._step = 26 + + # Act + action = agent.get_actions(self.mock_observation_space) + + # Assert + self.assertListEqual(['isolate'], list(action.values())) + + def test_no_actions_between_active_periods(self): + # Arrange + agent = self._sut(name='test_agent', + actions_per_turn=1, + start_step={'isolate': 25, + 'reconnect': 35}, + end_step={'isolate': 30, + 'reconnect': 40}) + agent._step = 32 + + # Act + action = agent.get_actions(self.mock_observation_space) + + # Assert + self.assertListEqual([], 
list(action.values())) + + def test_n_actions_inside_second_active_period(self): + # Arrange + agent = self._sut(name='test_agent', + actions_per_turn=1, + start_step={'isolate': 25, + 'reconnect': 35}, + end_step={'isolate': 30, + 'reconnect': 40}) + agent._step = 36 + + # Act + action = agent.get_actions(self.mock_observation_space) + + # Assert + self.assertListEqual(['reconnect'], list(action.values())) + + def test_no_actions_after_active_periods(self): + # Arrange + agent = self._sut(name='test_agent', + actions_per_turn=1, + start_step={'isolate': 25, + 'reconnect': 35}, + end_step={'isolate': 30, + 'reconnect': 40}) + agent._step = 45 + + # Act + action = agent.get_actions(self.mock_observation_space) + + # Assert + self.assertListEqual([], list(action.values())) + + def test_whole_active_period_returns_actions_with_single_actions(self): + # Arrange + agent = self._sut(name='test_agent', + actions_per_turn=1, + start_step={'isolate': 5, + 'reconnect': 12}, + end_step={'isolate': 10, + 'reconnect': 16}) + + # Act + actions = [] + for s in range(20): + actions.append(agent.get_actions(self.mock_observation_space)) + + # Assert + self.assertEqual(20, len(actions)) + self.assertEqual(20, agent._step) + self.assertListEqual([0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, + 0, + 1, 1, 1, 1, 1, + 0, 0, 0], [len(d.keys()) for d in actions]) + + def test_whole_active_period_returns_actions_with_multiple_actions(self): + # Arrange + agent = self._sut(name='test_agent', + actions_per_turn=3, + start_step={'isolate': 5, + 'reconnect': 12}, + end_step={'isolate': 10, + 'reconnect': 16}) + + # Act + actions = [] + for s in range(20): + actions.append(agent.get_actions(self.mock_observation_space)) + + # Assert + self.assertEqual(20, len(actions)) + self.assertEqual(20, agent._step) + # Likely to be fewer than 3 actions if duplicate target selected from small pool. 
+ self.assertListEqual([False, False, False, False, False, + True, True, True, True, True, True, + False, + True, True, True, True, True, + False, False, False], [len(d.keys()) > 0 for d in actions]) diff --git a/tests/unit/agent/policy_agents/test_treatment_policy_agent.py b/tests/unit/agent/policy_agents/test_treatment_policy_agent.py new file mode 100644 index 0000000..a81a899 --- /dev/null +++ b/tests/unit/agent/policy_agents/test_treatment_policy_agent.py @@ -0,0 +1,92 @@ +import unittest +from unittest.mock import MagicMock + +from social_distancing_sim.agent.policy_agents.treatment_policy_agent import TreatmentPolicyAgent + + +class TestTreatmentPolicyAgent(unittest.TestCase): + _sut = TreatmentPolicyAgent + + def setUp(self) -> None: + mock_observation_space = MagicMock() + mock_observation_space.current_infected_nodes = [10, 11, 12] + self.mock_observation_space = mock_observation_space + + def test_init_with_defaults(self): + # Act + agent = self._sut() + + # Assert + self.assertIsInstance(agent, TreatmentPolicyAgent) + + def test_available_actions(self): + # Arrange + agent = self._sut() + + # Assert + self.assertListEqual(['treat'], agent.available_actions) + + def test_no_actions_outside_active_period(self): + # Arrange + agent = self._sut(name='test_agent', + actions_per_turn=1, + start_step={'treat': 25}, + end_step={'treat': 30}) + + # Act + action = agent.get_actions(self.mock_observation_space) + + # Assert + self.assertListEqual([], list(action.keys())) + + def test_n_actions_inside_active_period(self): + # Arrange + agent = self._sut(name='test_agent', + actions_per_turn=1, + start_step={'treat': 25}, + end_step={'treat': 30}) + agent._step = 26 + + # Act + action = agent.get_actions(self.mock_observation_space) + + # Assert + self.assertListEqual(['treat'], list(action.values())) + + def test_whole_active_period_returns_actions_with_single_actions(self): + # Arrange + agent = self._sut(name='test_agent', + actions_per_turn=1, + 
start_step={'treat': 5}, + end_step={'treat': 10}) + + # Act + actions = [] + for s in range(15): + actions.append(agent.get_actions(self.mock_observation_space)) + + # Assert + self.assertEqual(15, len(actions)) + self.assertEqual(15, agent._step) + self.assertListEqual([0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0], [len(d.keys()) for d in actions]) + + def test_whole_active_period_returns_actions_with_multiple_actions(self): + # Arrange + agent = self._sut(name='test_agent', + actions_per_turn=3, + start_step={'treat': 5}, + end_step={'treat': 10}) + + # Act + actions = [] + for s in range(15): + actions.append(agent.get_actions(self.mock_observation_space)) + + # Assert + self.assertEqual(15, len(actions)) + self.assertEqual(15, agent._step) + self.assertListEqual([0, 0, 0, 0, 0, + 3, 3, 3, 3, 3, 3, + 0, 0, 0, 0], [len(d.keys()) for d in actions]) diff --git a/tests/unit/agent/policy_agents/test_vaccination_policy_agent.py b/tests/unit/agent/policy_agents/test_vaccination_policy_agent.py new file mode 100644 index 0000000..ad4c203 --- /dev/null +++ b/tests/unit/agent/policy_agents/test_vaccination_policy_agent.py @@ -0,0 +1,92 @@ +import unittest +from unittest.mock import MagicMock + +from social_distancing_sim.agent.policy_agents.vaccination_policy_agent import VaccinationPolicyAgent + + +class TestVaccinationPolicyAgent(unittest.TestCase): + _sut = VaccinationPolicyAgent + + def setUp(self) -> None: + mock_observation_space = MagicMock() + mock_observation_space.current_clear_nodes = [10, 11, 12] + self.mock_observation_space = mock_observation_space + + def test_init_with_defaults(self): + # Act + agent = self._sut() + + # Assert + self.assertIsInstance(agent, VaccinationPolicyAgent) + + def test_available_actions(self): + # Arrange + agent = self._sut() + + # Assert + self.assertListEqual(['vaccinate'], agent.available_actions) + + def test_no_actions_outside_active_period(self): + # Arrange + agent = self._sut(name='test_agent', + actions_per_turn=1, + 
start_step={'vaccinate': 25}, + end_step={'vaccinate': 30}) + + # Act + action = agent.get_actions(self.mock_observation_space) + + # Assert + self.assertListEqual([], list(action.keys())) + + def test_n_actions_inside_active_period(self): + # Arrange + agent = self._sut(name='test_agent', + actions_per_turn=1, + start_step={'vaccinate': 25}, + end_step={'vaccinate': 30}) + agent._step = 26 + + # Act + action = agent.get_actions(self.mock_observation_space) + + # Assert + self.assertListEqual(['vaccinate'], list(action.values())) + + def test_whole_active_period_returns_actions_with_single_actions(self): + # Arrange + agent = self._sut(name='test_agent', + actions_per_turn=1, + start_step={'vaccinate': 5}, + end_step={'vaccinate': 10}) + + # Act + actions = [] + for s in range(15): + actions.append(agent.get_actions(self.mock_observation_space)) + + # Assert + self.assertEqual(15, len(actions)) + self.assertEqual(15, agent._step) + self.assertListEqual([0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0], [len(d.keys()) for d in actions]) + + def test_whole_active_period_returns_actions_with_multiple_actions(self): + # Arrange + agent = self._sut(name='test_agent', + actions_per_turn=3, + start_step={'vaccinate': 5}, + end_step={'vaccinate': 10}) + + # Act + actions = [] + for s in range(15): + actions.append(agent.get_actions(self.mock_observation_space)) + + # Assert + self.assertEqual(15, len(actions)) + self.assertEqual(15, agent._step) + self.assertListEqual([0, 0, 0, 0, 0, + 3, 3, 3, 3, 3, 3, + 0, 0, 0, 0], [len(d.keys()) for d in actions]) diff --git a/tests/unit/environment/test_action_space.py b/tests/unit/environment/test_action_space.py index b4f1dee..979004f 100644 --- a/tests/unit/environment/test_action_space.py +++ b/tests/unit/environment/test_action_space.py @@ -1,25 +1,29 @@ import unittest + from social_distancing_sim.environment.action_space import ActionSpace class TestActionSpace(unittest.TestCase): _sut = ActionSpace() + _implemented_actions = 
['vaccinate', 'isolate', 'reconnect', 'treat'] def test_expected_default_actions_are_available(self): + # Act + available_actions = self._sut.available_actions + + # Assert + self.assertListEqual(self._implemented_actions, available_actions) + + def test_n_returns_expected_n_actions(self): + # Act + n_available_actions = self._sut.n + + # Assert + self.assertEqual(len(self._implemented_actions), n_available_actions) + + def test_sampled_returns_valid_action(self): + # Act + action = self._sut.sample() + # Assert - self.assertListEqual(['vaccinate', 'isolate', 'reconnect'], self._sut.available_actions) - - @unittest.skip(reason='TODO') - def test_vaccinate_action(self): - # TODO - pass - - @unittest.skip(reason='TODO') - def test_isolate_action(self): - # TODO - pass - - @unittest.skip(reason='TODO') - def test_reconnect_action(self): - # TODO - pass + self.assertIn(action, self._implemented_actions)