-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #21 from garethjns/bug_fixes
Fix bug with environment's automatic target selection for vaccination…
- Loading branch information
Showing
13 changed files
with
235 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import gym | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import seaborn as sns | ||
|
||
import social_distancing_sim.environment as env | ||
from social_distancing_sim.environment.gym.gym_env import GymEnv | ||
from social_distancing_sim.sim import MultiSim, Sim | ||
from social_distancing_sim.templates.template_base import TemplateBase | ||
|
||
|
||
class EnvTemplate(TemplateBase): | ||
def build(self): | ||
return env.Environment(name="visual_run_simulation_with_agent_custom_env", | ||
action_space=env.ActionSpace(isolate_efficiency=0.5, | ||
vaccinate_efficiency=0.95), | ||
disease=env.Disease(name='COVID-19', | ||
virulence=0.008, | ||
immunity_mean=0.7, | ||
recovery_rate=0.9, | ||
immunity_decay_mean=0.01), | ||
healthcare=env.Healthcare(capacity=75), | ||
environment_plotting=env.EnvironmentPlotting(ts_fields_g2=["Actions taken", | ||
"Overall score"]), | ||
observation_space=env.ObservationSpace( | ||
graph=env.Graph(community_n=20, | ||
community_size_mean=15, | ||
considered_immune_threshold=0.7), | ||
test_rate=1), | ||
initial_infections=15) | ||
|
||
|
||
class CustomEnv(GymEnv): | ||
template = EnvTemplate() | ||
|
||
|
||
if __name__ == "__main__": | ||
# Prepare a custom environment | ||
env_name = f"SDSTests-CustomEnv{np.random.randint(2e6)}-v0" | ||
gym.envs.register(id=env_name, | ||
entry_point='scripts.stats_run_single_population:CustomEnv', | ||
max_episode_steps=1000) | ||
|
||
sim = Sim(env_spec=gym.make(env_name).spec) | ||
|
||
ms = MultiSim(sim, n_reps=10, n_jobs=50) | ||
|
||
ms.run() | ||
|
||
sns.distplot(ms.results['Overall score']) | ||
sns.distplot(ms.results['Total deaths']) | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,36 +1,69 @@ | ||
from functools import partial | ||
|
||
import gym | ||
import numpy as np | ||
from reinforcement_learning_keras.agents.components.helpers.virtual_gpu import VirtualGPU | ||
|
||
import social_distancing_sim.environment as env | ||
from social_distancing_sim.agent.rl_agents.q_learning.dqn_untargeted import DQNUntargeted | ||
from social_distancing_sim.agent.rl_agents.rlk_agent_configs import RLKAgentConfigs | ||
from social_distancing_sim.environment import ActionSpace, EnvironmentPlotting | ||
from social_distancing_sim.environment.gym.gym_env import GymEnv | ||
from social_distancing_sim.environment.gym.wrappers.flatten_obs_wrapper import FlattenObsWrapper | ||
from social_distancing_sim.environment.gym.wrappers.limit_obs_wrapper import LimitObsWrapper | ||
from social_distancing_sim.sim import Sim | ||
from social_distancing_sim.templates.template_base import TemplateBase | ||
|
||
|
||
class EnvTemplate(TemplateBase): | ||
|
||
@classmethod | ||
def build(cls) -> env.Environment: | ||
return env.Environment(name="agent training example", | ||
action_space=ActionSpace(), | ||
environment_plotting=EnvironmentPlotting( | ||
ts_fields_g2=['Vaccinate actions completed', 'Isolate actions completed', | ||
'Reconnect actions completed', 'Treat actions completed', | ||
'Mask actions completed']), | ||
disease=env.Disease(name='COVID-19', | ||
virulence=0.006, | ||
immunity_mean=0.6, | ||
immunity_decay_mean=0.15), | ||
healthcare=env.Healthcare(), | ||
observation_space=env.ObservationSpace(graph=env.Graph(community_n=50, | ||
community_size_mean=15, | ||
community_p_in=0.1, | ||
community_p_out=0.05, | ||
seed=20200423), | ||
test_rate=1)) | ||
|
||
|
||
class CustomEnv(GymEnv): | ||
template = EnvTemplate() | ||
|
||
|
||
if __name__ == "__main__": | ||
gpu = VirtualGPU(gpu_memory_limit=2048, | ||
gpu_device_id=0) | ||
|
||
gym.envs.register(id='SDS-746-v0', | ||
entry_point='social_distancing_sim.environment.gym.environments.sds_746:SDS746', | ||
env_name = f"SDS-CustomEnv{np.random.randint(2e6)}-v0" | ||
gym.envs.register(id=env_name, | ||
entry_point='scripts.train_and_evaluate_untargeted_dqn:CustomEnv', | ||
max_episode_steps=1000) | ||
|
||
config_dict = RLKAgentConfigs(agent_name='flat_obs_dqn', env_spec='SDS-746-v0', expected_obs_shape=(746 * 6,), | ||
config_dict = RLKAgentConfigs(agent_name='flat_obs_dqn', env_spec=env_name, expected_obs_shape=(746 * 6,), | ||
env_wrappers=(partial(LimitObsWrapper, output=2), | ||
FlattenObsWrapper), | ||
n_actions=5).build_for_dqn_untargeted() | ||
|
||
# Train agent using rlk agents built in train function. Note that the agent only takes a single action per turn | ||
# unless the multiple actions wrapper is added. TODO: Add this wrapper for training but remove for future use. | ||
agent = DQNUntargeted(**config_dict) | ||
agent.train(render=False, n_episodes=16) | ||
agent.train(render=False, n_episodes=25) | ||
agent.save() | ||
|
||
# Eval | ||
env_spec = gym.make('SDS-746-v0').spec | ||
sim = Sim(env_spec=env_spec, agent=agent, n_steps=200, plot=True, save=True, tqdm_on=True, | ||
env_spec = gym.make(env_name).spec | ||
sim = Sim(env_spec=env_spec, agent=agent, n_steps=200, plot=False, save=True, tqdm_on=True, logging=True, | ||
save_dir='exps/untargeted_dqn') | ||
sim.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
MAJOR = 0 | ||
MINOR = 10 | ||
PATCH = 1 | ||
PATCH = 2 | ||
|
||
__version__ = ".".join(str(v) for v in [MAJOR, MINOR, PATCH]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.