From 6b73e59ebeeb60aaa34e40e8a8c5a8363416224e Mon Sep 17 00:00:00 2001
From: Ivan-267 <61947090+Ivan-267@users.noreply.github.com>
Date: Fri, 4 Aug 2023 00:20:29 +0200
Subject: [PATCH 1/2] Adds n_parallel to cleanrl

---
 examples/clean_rl_example.py          | 71 ++++++++++----------
 godot_rl/wrappers/clean_rl_wrapper.py | 96 +++++++++++++++++++--------
 2 files changed, 106 insertions(+), 61 deletions(-)

diff --git a/examples/clean_rl_example.py b/examples/clean_rl_example.py
index 1f123b65..3e81e5c6 100644
--- a/examples/clean_rl_example.py
+++ b/examples/clean_rl_example.py
@@ -13,68 +13,70 @@ from torch.utils.tensorboard import SummaryWriter
 
 from godot_rl.wrappers.clean_rl_wrapper import CleanRLGodotEnv
 
+
 def parse_args():
     # fmt: off
     parser = argparse.ArgumentParser()
-    parser.add_argument("--viz", default=False, type=bool,
-        help="If set, the simulation will be displayed in a window during training. Otherwise "
-             "training will run without rendering the simualtion. This setting does not apply to in-editor training.")
+    parser.add_argument("--viz", action="store_true", default=False,
+                        help="If set, the simulation will be displayed in a window during training. Otherwise "
+                             "training will run without rendering the simulation. This setting does not apply to "
+                             "in-editor training.")
     parser.add_argument("--experiment_dir", default="logs/cleanrl", type=str,
-        help="The name of the experiment directory, in which the tensorboard logs are getting stored")
+                        help="The name of the experiment directory, in which the tensorboard logs are getting stored")
     parser.add_argument("--experiment_name", default=os.path.basename(__file__).rstrip(".py"), type=str,
-        help="The name of the experiment, which will be displayed in tensorboard")
+                        help="The name of the experiment, which will be displayed in tensorboard")
     parser.add_argument("--seed", type=int, default=1,
-        help="seed of the experiment")
+                        help="seed of the experiment")
     parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
-        help="if toggled, `torch.backends.cudnn.deterministic=False`")
+                        help="if toggled, `torch.backends.cudnn.deterministic=False`")
     parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
-        help="if toggled, cuda will be enabled by default")
+                        help="if toggled, cuda will be enabled by default")
     parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
-        help="if toggled, this experiment will be tracked with Weights and Biases")
+                        help="if toggled, this experiment will be tracked with Weights and Biases")
    parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
-        help="the wandb's project name")
+                        help="the wandb's project name")
     parser.add_argument("--wandb-entity", type=str, default=None,
-        help="the entity (team) of wandb's project")
+                        help="the entity (team) of wandb's project")
     parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
-        help="whether to capture videos of the agent performances (check out `videos` folder)")
+                        help="whether to capture videos of the agent performances (check out `videos` folder)")
 
     # Algorithm specific arguments
     parser.add_argument("--env_path", type=str, default=None,
-        help="the path of the environment")
+                        help="the path of the environment")
     parser.add_argument("--speedup", type=int, default=8,
-        help="the speedup of the godot environment")
+                        help="the speedup of the godot environment")
     parser.add_argument("--total-timesteps", type=int, default=1000000,
-        help="total timesteps of the experiments")
+                        help="total timesteps of the experiments")
     parser.add_argument("--learning-rate", type=float, default=3e-4,
-        help="the learning rate of the optimizer")
-    parser.add_argument("--num-envs", type=int, default=1,
-        help="the number of parallel game environments")
+                        help="the learning rate of the optimizer")
     parser.add_argument("--num-steps", type=int, default=32,
-        help="the number of steps to run in each environment per policy rollout")
+                        help="the number of steps to run in each environment per policy rollout")
     parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
-        help="Toggle learning rate annealing for policy and value networks")
+                        help="Toggle learning rate annealing for policy and value networks")
     parser.add_argument("--gamma", type=float, default=0.99,
-        help="the discount factor gamma")
+                        help="the discount factor gamma")
     parser.add_argument("--gae-lambda", type=float, default=0.95,
-        help="the lambda for the general advantage estimation")
+                        help="the lambda for the general advantage estimation")
     parser.add_argument("--num-minibatches", type=int, default=8,
-        help="the number of mini-batches")
+                        help="the number of mini-batches")
     parser.add_argument("--update-epochs", type=int, default=10,
-        help="the K epochs to update the policy")
+                        help="the K epochs to update the policy")
     parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
-        help="Toggles advantages normalization")
+                        help="Toggles advantages normalization")
     parser.add_argument("--clip-coef", type=float, default=0.2,
-        help="the surrogate clipping coefficient")
+                        help="the surrogate clipping coefficient")
     parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
-        help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
+                        help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
     parser.add_argument("--ent-coef", type=float, default=0.0001,
-        help="coefficient of the entropy")
+                        help="coefficient of the entropy")
     parser.add_argument("--vf-coef", type=float, default=0.5,
-        help="coefficient of the value function")
+                        help="coefficient of the value function")
     parser.add_argument("--max-grad-norm", type=float, default=0.5,
-        help="the maximum norm for the gradient clipping")
+                        help="the maximum norm for the gradient clipping")
     parser.add_argument("--target-kl", type=float, default=None,
-        help="the target KL divergence threshold")
+                        help="the target KL divergence threshold")
+    parser.add_argument("--n_parallel", default=1, type=int, help="How many instances of the environment executable to "
+                                                                  "launch - requires --env_path to be set if > 1.")
 
     args = parser.parse_args()
     # fmt: on
@@ -85,6 +87,7 @@ def make_env(env_path, speedup):
     def thunk():
         env = CleanRLGodotEnv(env_path=env_path, show_window=True, speedup=speedup)
         return env
+
     return thunk
@@ -156,8 +159,8 @@ def get_action_and_value(self, x, action=None):
     device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
 
     # env setup
-
-    envs = env = CleanRLGodotEnv(env_path=args.env_path, show_window=args.viz, speedup=args.speedup, convert_action_space=True)  # Godot envs are already vectorized
+
+    envs = env = CleanRLGodotEnv(env_path=args.env_path, show_window=args.viz, speedup=args.speedup, seed=args.seed, n_parallel=args.n_parallel)
     args.num_envs = envs.num_envs
     args.batch_size = int(args.num_envs * args.num_steps)
     args.minibatch_size = int(args.batch_size // args.num_minibatches)
@@ -211,7 +214,7 @@ def get_action_and_value(self, x, action=None):
             next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)
 
             accum_rewards += np.array(reward)
-
+
             for i, d in enumerate(done):
                 if d:
                     episode_returns.append(accum_rewards[i])
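
Note on batch geometry: the example no longer exposes a --num-envs flag; the vectorized env count is read back from the wrapper (args.num_envs = envs.num_envs), so it depends both on --n_parallel and on how many agents the Godot scene itself defines. A quick sketch of the resulting PPO batch arithmetic, using illustrative numbers only (agents_per_instance comes from the scene, not from this patch):

    # Illustrative values; agents_per_instance is determined by the Godot scene.
    n_parallel = 4            # --n_parallel: game processes launched
    agents_per_instance = 16  # num_envs reported by each GodotEnv instance
    num_steps = 32            # --num-steps default
    num_minibatches = 8       # --num-minibatches default

    num_envs = n_parallel * agents_per_instance      # what envs.num_envs returns
    batch_size = num_envs * num_steps                # 64 * 32 = 2048 transitions per rollout
    minibatch_size = batch_size // num_minibatches   # 2048 // 8 = 256 per gradient step
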
environment") parser.add_argument("--total-timesteps", type=int, default=1000000, - help="total timesteps of the experiments") + help="total timesteps of the experiments") parser.add_argument("--learning-rate", type=float, default=3e-4, - help="the learning rate of the optimizer") - parser.add_argument("--num-envs", type=int, default=1, - help="the number of parallel game environments") + help="the learning rate of the optimizer") parser.add_argument("--num-steps", type=int, default=32, - help="the number of steps to run in each environment per policy rollout") + help="the number of steps to run in each environment per policy rollout") parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, - help="Toggle learning rate annealing for policy and value networks") + help="Toggle learning rate annealing for policy and value networks") parser.add_argument("--gamma", type=float, default=0.99, - help="the discount factor gamma") + help="the discount factor gamma") parser.add_argument("--gae-lambda", type=float, default=0.95, - help="the lambda for the general advantage estimation") + help="the lambda for the general advantage estimation") parser.add_argument("--num-minibatches", type=int, default=8, - help="the number of mini-batches") + help="the number of mini-batches") parser.add_argument("--update-epochs", type=int, default=10, - help="the K epochs to update the policy") + help="the K epochs to update the policy") parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, - help="Toggles advantages normalization") + help="Toggles advantages normalization") parser.add_argument("--clip-coef", type=float, default=0.2, - help="the surrogate clipping coefficient") + help="the surrogate clipping coefficient") parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, - help="Toggles whether or not to use a clipped loss for the value function, as per the paper.") + help="Toggles whether or not to use a clipped loss for the value function, as per the paper.") parser.add_argument("--ent-coef", type=float, default=0.0001, - help="coefficient of the entropy") + help="coefficient of the entropy") parser.add_argument("--vf-coef", type=float, default=0.5, - help="coefficient of the value function") + help="coefficient of the value function") parser.add_argument("--max-grad-norm", type=float, default=0.5, - help="the maximum norm for the gradient clipping") + help="the maximum norm for the gradient clipping") parser.add_argument("--target-kl", type=float, default=None, - help="the target KL divergence threshold") + help="the target KL divergence threshold") + parser.add_argument("--n_parallel", default=1, type=int, help="How many instances of the environment executable to " + "launch - requires --env_path to be set if > 1.") args = parser.parse_args() # fmt: on @@ -85,6 +87,7 @@ def make_env(env_path, speedup): def thunk(): env = CleanRLGodotEnv(env_path=env_path, show_window=True, speedup=speedup) return env + return thunk @@ -156,8 +159,8 @@ def get_action_and_value(self, x, action=None): device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") # env setup - - envs = env = CleanRLGodotEnv(env_path=args.env_path, show_window=args.viz, speedup=args.speedup, convert_action_space=True) # Godot envs are already vectorized + + envs = env = CleanRLGodotEnv(env_path=args.env_path, show_window=args.viz, speedup=args.speedup, seed=args.seed, 
n_parallel=args.n_parallel) args.num_envs = envs.num_envs args.batch_size = int(args.num_envs * args.num_steps) args.minibatch_size = int(args.batch_size // args.num_minibatches) @@ -211,7 +214,7 @@ def get_action_and_value(self, x, action=None): next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device) accum_rewards += np.array(reward) - + for i, d in enumerate(done): if d: episode_returns.append(accum_rewards[i]) diff --git a/godot_rl/wrappers/clean_rl_wrapper.py b/godot_rl/wrappers/clean_rl_wrapper.py index bd73d4e2..edc0497d 100644 --- a/godot_rl/wrappers/clean_rl_wrapper.py +++ b/godot_rl/wrappers/clean_rl_wrapper.py @@ -1,47 +1,89 @@ - import numpy as np import gymnasium as gym +from numpy import ndarray + from godot_rl.core.utils import lod_to_dol from godot_rl.core.godot_env import GodotEnv +from typing import Any, Dict, List, Optional, Tuple, Union class CleanRLGodotEnv: - def __init__(self, env_path=None, convert_action_space=False, **kwargs): - # convert_action_space: combine multiple continue action spaces into one larger space - self._env = GodotEnv(env_path=env_path,convert_action_space=convert_action_space, **kwargs) - + def __init__(self, env_path: Optional[str] = None, n_parallel: int = 1, seed: int = 0, **kwargs: object) -> None: + + # If we are doing editor training, n_parallel must be 1 + if env_path is None and n_parallel > 1: + raise ValueError("You must provide the path to a exported game executable if n_parallel > 1") + + # Define the default port + port = kwargs.pop("port", GodotEnv.DEFAULT_PORT) - def _check_valid_action_space(self): - action_space = self._env.action_space + # Create a list of GodotEnv instances + self.envs = [GodotEnv(env_path=env_path, convert_action_space=True, port=port + p, seed=seed + p, **kwargs) for + p in range(n_parallel)] + + # Store the number of parallel environments + self.n_parallel = n_parallel + + def _check_valid_action_space(self) -> None: + # Check if the action space is a tuple space with multiple spaces + action_space = self.envs[0].action_space if isinstance(action_space, gym.spaces.Tuple): assert ( - len(action_space.spaces) == 1 - ), f"clearn rl supports a single action space, this env constains multiple spaces {action_space}" + len(action_space.spaces) == 1 + ), f"sb3 supports a single action space, this env contains multiple spaces {action_space}" + + def step(self, action: np.ndarray) -> tuple[ndarray, list[Any], list[Any], list[Any], list[Any]]: + # Initialize lists for collecting results + all_obs = [] + all_rewards = [] + all_term = [] + all_trunc = [] + all_info = [] - @staticmethod - def action_preprocessor(action): - return action + # Get the number of environments + num_envs = self.envs[0].num_envs - def step(self, action): - action = self.action_preprocessor(action) - obs, reward, term, trunc, info = self._env.step(action) - obs = lod_to_dol(obs) - return np.stack(obs["obs"]), reward, term, trunc, info + # Send actions to each environment + for i in range(self.n_parallel): + self.envs[i].step_send(action[i * num_envs:(i + 1) * num_envs]) - def reset(self, seed): - obs, info = self._env.reset(seed) - obs = lod_to_dol(obs) - return np.stack(obs["obs"]), info + # Receive results from each environment + for i in range(self.n_parallel): + obs, reward, term, trunc, info = self.envs[i].step_recv() + all_obs.extend(obs) + all_rewards.extend(reward) + all_term.extend(term) + all_trunc.extend(trunc) + all_info.extend(info) + + # Convert list of dictionaries to dictionary of lists + obs = 
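
Note on the dispatch scheme above: step() slices one flat action batch across the instances (actions [0, num_envs) go to envs[0], the next num_envs to envs[1], and so on), and every step_send is issued before the first step_recv, so the game processes simulate concurrently. Each instance is launched on its own port (port + p) with its own seed (seed + p); the seed argument of reset() is kept for CleanRL API compatibility but is not forwarded, since seeding happens at construction. A minimal round-trip sketch, assuming a hypothetical exported executable path:

    import numpy as np

    from godot_rl.wrappers.clean_rl_wrapper import CleanRLGodotEnv

    # "path/to/game.x86_64" is a placeholder; an exported game is required when n_parallel > 1.
    env = CleanRLGodotEnv(env_path="path/to/game.x86_64", n_parallel=2, seed=42)

    obs, info = env.reset(seed=42)  # argument is a no-op here; per-instance seeds were set in __init__
    actions = np.stack([env.single_action_space.sample() for _ in range(env.num_envs)])
    obs, rewards, terms, truncs, infos = env.step(actions)
    print(obs.shape, len(rewards))  # (num_envs, *obs_shape) and num_envs entries
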
From 680700db75a2db1e544f830831cc4aa738b1cda6 Mon Sep 17 00:00:00 2001
From: Ivan-267 <61947090+Ivan-267@users.noreply.github.com>
Date: Fri, 4 Aug 2023 00:45:06 +0200
Subject: [PATCH 2/2] Adds close method

---
 godot_rl/wrappers/clean_rl_wrapper.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/godot_rl/wrappers/clean_rl_wrapper.py b/godot_rl/wrappers/clean_rl_wrapper.py
index edc0497d..0c13fdfe 100644
--- a/godot_rl/wrappers/clean_rl_wrapper.py
+++ b/godot_rl/wrappers/clean_rl_wrapper.py
@@ -84,6 +84,12 @@ def single_observation_space(self):
     @property
     def single_action_space(self):
         return self.envs[0].action_space
+
     @property
     def num_envs(self) -> int:
         return self.envs[0].num_envs * self.n_parallel
+
+    def close(self) -> None:
+        # Close each environment
+        for env in self.envs:
+            env.close()
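
Note on cleanup: close() loops over every wrapped GodotEnv, so a single call shuts down all n_parallel game processes. A small usage sketch (the try/finally pattern is a suggestion, not part of the patch; the path is again a placeholder):

    from godot_rl.wrappers.clean_rl_wrapper import CleanRLGodotEnv

    env = CleanRLGodotEnv(env_path="path/to/game.x86_64", n_parallel=4)
    try:
        obs, info = env.reset(seed=0)
        # ... rollout / training loop ...
    finally:
        env.close()  # closes each GodotEnv instance and its game process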