From dac28047bb94d344e54632649d7ab11d2d296e29 Mon Sep 17 00:00:00 2001 From: Geo Jolly Date: Thu, 7 Dec 2023 14:25:09 +0530 Subject: [PATCH 1/6] Add envpool to openrl --- examples/envpool/test_model.py | 78 ++++++++++ examples/envpool/train_ppo.py | 86 +++++++++++ openrl/envs/common/build_envs.py | 24 ++- openrl/envs/common/registration.py | 14 +- openrl/envs/envpool/__init__.py | 47 ++++++ openrl/envs/wrappers/envpool_wrappers.py | 182 +++++++++++++++++++++++ setup.py | 2 + 7 files changed, 425 insertions(+), 8 deletions(-) create mode 100644 examples/envpool/test_model.py create mode 100644 examples/envpool/train_ppo.py create mode 100644 openrl/envs/envpool/__init__.py create mode 100644 openrl/envs/wrappers/envpool_wrappers.py diff --git a/examples/envpool/test_model.py b/examples/envpool/test_model.py new file mode 100644 index 00000000..c0b4ddfb --- /dev/null +++ b/examples/envpool/test_model.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" + +# Use OpenRL to load stable-baselines's model for testing + +import numpy as np +import torch + +from openrl.configs.config import create_config_parser +from openrl.envs.common import make +from openrl.modules.common.ppo_net import PPONet as Net +from openrl.modules.networks.policy_value_network_sb3 import ( + PolicyValueNetworkSB3 as PolicyValueNetwork, +) +from openrl.runners.common import PPOAgent as Agent + + +def evaluation(local_trained_file_path=None): + # begin to test + + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(["--config", "ppo.yaml"]) + + # Create an environment for testing and set the number of environments to interact with to 9. Set rendering mode to group_human. + render_mode = "group_human" + render_mode = None + env = make("CartPole-v1", render_mode=render_mode, env_num=9, asynchronous=True) + model_dict = {"model": PolicyValueNetwork} + net = Net( + env, + cfg=cfg, + model_dict=model_dict, + device="cuda" if torch.cuda.is_available() else "cpu", + ) + # initialize the trainer + agent = Agent( + net, + ) + if local_trained_file_path is not None: + agent.load(local_trained_file_path) + # The trained agent sets up the interactive environment it needs. + agent.set_env(env) + # Initialize the environment and get initial observations and environmental information. + obs, info = env.reset() + done = False + + total_step = 0 + total_reward = 0.0 + while not np.any(done): + # Based on environmental observation input, predict next action. 
+ action, _ = agent.act(obs, deterministic=True) + obs, r, done, info = env.step(action) + total_step += 1 + total_reward += np.mean(r) + if total_step % 50 == 0: + print(f"{total_step}: reward:{np.mean(r)}") + env.close() + print("total step:", total_step) + print("total reward:", total_reward) + + +if __name__ == "__main__": + evaluation() diff --git a/examples/envpool/train_ppo.py b/examples/envpool/train_ppo.py new file mode 100644 index 00000000..4120ee4a --- /dev/null +++ b/examples/envpool/train_ppo.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +import numpy as np +from test_model import evaluation + +from openrl.configs.config import create_config_parser +from openrl.envs.common import make +from openrl.envs.wrappers.envpool_wrappers import VecAdapter, VecMonitor +from openrl.modules.common import PPONet as Net +from openrl.modules.common.ppo_net import PPONet as Net +from openrl.runners.common import PPOAgent as Agent + + +def train(): + # create the neural network + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args() + + # create environment, set environment parallelism to 9 + env = make( + "envpool:Adventure-v5", + render_mode=None, + env_num=9, + asynchronous=False, + env_wrappers=[VecAdapter, VecMonitor], + env_type="gym", + ) + + net = Net( + env, + cfg=cfg, + ) + # initialize the trainer + agent = Agent(net, use_wandb=False, project_name="envpool:Adventure-v5") + # start training, set total number of training steps to 20000 + agent.train(total_time_steps=20000) + + env.close() + return agent + + +def evaluation(agent): + # begin to test + # Create an environment for testing and set the number of environments to interact with to 9. Set rendering mode to group_human. + render_mode = "group_human" + render_mode = None + env = make("CartPole-v1", render_mode=render_mode, env_num=9, asynchronous=True) + # The trained agent sets up the interactive environment it needs. + agent.set_env(env) + # Initialize the environment and get initial observations and environmental information. + obs, info = env.reset() + done = False + step = 0 + total_step, total_reward = 0, 0 + while not np.any(done): + # Based on environmental observation input, predict next action. 
+ action, _ = agent.act(obs, deterministic=True) + obs, r, done, info = env.step(action) + step += 1 + total_step += 1 + total_reward += np.mean(r) + if step % 50 == 0: + print(f"{step}: reward:{np.mean(r)}") + env.close() + print("total step:", total_step) + print("total reward:", total_reward) + + +if __name__ == "__main__": + agent = train() + evaluation(agent) diff --git a/openrl/envs/common/build_envs.py b/openrl/envs/common/build_envs.py index 76f4b35b..386c4adc 100644 --- a/openrl/envs/common/build_envs.py +++ b/openrl/envs/common/build_envs.py @@ -6,6 +6,7 @@ from gymnasium import Env from openrl.envs.wrappers.base_wrapper import BaseWrapper +from openrl.envs.wrappers.envpool_wrappers import VecEnvWrapper, VecMonitor def build_envs( @@ -36,13 +37,22 @@ def _make_env() -> Env: new_kwargs["env_num"] = env_num if id.startswith("ALE/") or id in gym.envs.registry.keys(): new_kwargs.pop("cfg", None) - - env = make( - id, - render_mode=env_render_mode, - disable_env_checker=_disable_env_checker, - **new_kwargs, - ) + if "envpool" in new_kwargs: + # for now envpool doesnt support any render mode + # envpool also doesnt stores the id anywhere + new_kwargs.pop("envpool") + env = make( + id, + **new_kwargs, + ) + env.unwrapped.spec.id = id + else: + env = make( + id, + render_mode=env_render_mode, + disable_env_checker=_disable_env_checker, + **new_kwargs, + ) if wrappers is not None: if callable(wrappers): diff --git a/openrl/envs/common/registration.py b/openrl/envs/common/registration.py index 5d1ed645..053dd104 100644 --- a/openrl/envs/common/registration.py +++ b/openrl/envs/common/registration.py @@ -17,6 +17,7 @@ """""" from typing import Callable, Optional +import envpool import gymnasium as gym import openrl @@ -72,7 +73,6 @@ def make( env_fns = make_single_agent_drone_envs( id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs ) - elif id.startswith("snakes_"): from openrl.envs.snake import make_snake_envs @@ -155,6 +155,18 @@ def make( env_fns = make_PettingZoo_envs( id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs ) + elif ( + "envpool:" in id + and id.split(":")[-1] in envpool.registration.list_all_envs() + ): + from openrl.envs.envpool import make_envpool_envs + + env_fns = make_envpool_envs( + id=id.split(":")[-1], + env_num=env_num, + render_mode=convert_render_mode, + **kwargs, + ) else: raise NotImplementedError(f"env {id} is not supported.") diff --git a/openrl/envs/envpool/__init__.py b/openrl/envs/envpool/__init__.py new file mode 100644 index 00000000..48fbd1f5 --- /dev/null +++ b/openrl/envs/envpool/__init__.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+""""""
+from typing import List, Optional, Union
+
+import envpool
+
+from openrl.envs.common import build_envs
+
+
+def make_envpool_envs(
+    id: str,
+    env_num: int = 1,
+    render_mode: Optional[Union[str, List[str]]] = None,
+    **kwargs,
+):
+    assert "env_type" in kwargs
+    assert kwargs.get("env_type") in ["gym", "dm", "gymnasium"]
+    # Since render_mode is not supported, we set envpool to True
+    # so that we can remove the render_mode keyword argument from build_envs
+    assert render_mode is None, "envpool does not support render_mode yet"
+    kwargs["envpool"] = True
+
+    env_wrappers = kwargs.pop("env_wrappers")
+    env_fns = build_envs(
+        make=envpool.make,
+        id=id,
+        env_num=env_num,
+        render_mode=render_mode,
+        wrappers=env_wrappers,
+        **kwargs,
+    )
+    return env_fns
diff --git a/openrl/envs/wrappers/envpool_wrappers.py b/openrl/envs/wrappers/envpool_wrappers.py
new file mode 100644
index 00000000..d0da090a
--- /dev/null
+++ b/openrl/envs/wrappers/envpool_wrappers.py
@@ -0,0 +1,182 @@
+import time
+import warnings
+from typing import Optional
+
+import gym
+import gymnasium
+import numpy as np
+from envpool.python.protocol import EnvPool
+from packaging import version
+from stable_baselines3.common.vec_env import VecEnvWrapper as BaseWrapper
+from stable_baselines3.common.vec_env import VecMonitor
+from stable_baselines3.common.vec_env.base_vec_env import (VecEnvObs,
+                                                           VecEnvStepReturn)
+
+is_legacy_gym = version.parse(gym.__version__) < version.parse("0.26.0")
+
+
+class VecEnvWrapper(BaseWrapper):
+    @property
+    def agent_num(self):
+        if self.is_original_envpool_env():
+            return 1
+        else:
+            return self.env.agent_num
+
+    def is_original_envpool_env(self):
+        return not hasattr(self.venv, "agent_num")
+
+
+class VecAdapter(VecEnvWrapper):
+    """
+    Convert EnvPool object to a Stable-Baselines3 (SB3) VecEnv.
+
+    :param venv: The envpool object.
+ """ + + def __init__(self, venv: EnvPool): + venv.num_envs = venv.spec.config.num_envs + observation_space = venv.observation_space + new_observation_space = gymnasium.spaces.Box( + low=observation_space.low, + high=observation_space.high, + dtype=observation_space.dtype, + ) + action_space = venv.action_space + if isinstance(action_space, gym.spaces.Discrete): + new_action_space = gymnasium.spaces.Discrete(action_space.n) + elif isinstance(action_space, gym.spaces.MultiDiscrete): + new_action_space = gymnasium.spaces.MultiDiscrete(action_space.nvec) + elif isinstance(action_space, gym.spaces.MultiBinary): + new_action_space = gymnasium.spaces.MultiBinary(action_space.n) + elif isinstance(action_space, gym.spaces.Box): + new_action_space = gymnasium.spaces.Box( + low=action_space.low, + high=action_space.high, + dtype=action_space.dtype, + ) + else: + raise NotImplementedError(f"Action space {action_space} is not supported") + super().__init__( + venv=venv, + observation_space=new_observation_space, + action_space=new_action_space, + ) + + def step_async(self, actions: np.ndarray) -> None: + self.actions = actions + + def reset(self) -> VecEnvObs: + if is_legacy_gym: + return self.venv.reset(), {} + else: + return self.venv.reset() + + def step_wait(self) -> VecEnvStepReturn: + if is_legacy_gym: + obs, rewards, dones, info_dict = self.venv.step(self.actions) + else: + obs, rewards, terms, truncs, info_dict = self.venv.step(self.actions) + dones = terms + truncs + rewards = rewards + infos = [] + for i in range(self.num_envs): + infos.append( + { + key: info_dict[key][i] + for key in info_dict.keys() + if isinstance(info_dict[key], np.ndarray) + } + ) + if dones[i]: + infos[i]["terminal_observation"] = obs[i] + if is_legacy_gym: + obs[i] = self.venv.reset(np.array([i])) + else: + obs[i] = self.venv.reset(np.array([i]))[0] + return obs, rewards, dones, infos + + +class VecMonitor(VecEnvWrapper): + def __init__( + self, + venv, + filename: Optional[str] = None, + info_keywords=(), + ): + # Avoid circular import + from stable_baselines3.common.monitor import Monitor, ResultsWriter + + try: + is_wrapped_with_monitor = venv.env_is_wrapped(Monitor)[0] + except AttributeError: + is_wrapped_with_monitor = False + + if is_wrapped_with_monitor: + warnings.warn( + "The environment is already wrapped with a `Monitor` wrapper" + "but you are wrapping it with a `VecMonitor` wrapper, the `Monitor` statistics will be" + "overwritten by the `VecMonitor` ones.", + UserWarning, + ) + + VecEnvWrapper.__init__(self, venv) + self.episode_count = 0 + self.t_start = time.time() + + env_id = None + if hasattr(venv, "spec") and venv.spec is not None: + env_id = venv.spec.id + + self.results_writer: Optional[ResultsWriter] = None + if filename: + self.results_writer = ResultsWriter( + filename, + header={"t_start": self.t_start, "env_id": str(env_id)}, + extra_keys=info_keywords, + ) + + self.info_keywords = info_keywords + self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) + self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) + + def reset(self, **kwargs) -> VecEnvObs: + obs, info = self.venv.reset() + self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) + self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) + return obs, info + + def step_wait(self) -> VecEnvStepReturn: + obs, rewards, dones, infos = self.venv.step_wait() + self.episode_returns += rewards + self.episode_lengths += 1 + new_infos = list(infos[:]) + for i in range(len(dones)): + if dones[i]: + info = 
infos[i].copy() + episode_return = self.episode_returns[i] + episode_length = self.episode_lengths[i] + episode_info = { + "r": episode_return, + "l": episode_length, + "t": round(time.time() - self.t_start, 6), + } + for key in self.info_keywords: + episode_info[key] = info[key] + info["episode"] = episode_info + self.episode_count += 1 + self.episode_returns[i] = 0 + self.episode_lengths[i] = 0 + if self.results_writer: + self.results_writer.write_row(episode_info) + new_infos[i] = info + rewards = np.expand_dims(rewards, 1) + return obs, rewards, dones, new_infos + + def close(self) -> None: + if self.results_writer: + self.results_writer.close() + return self.venv.close() + + +__all__ = ["VecAdapter", "VecMonitor"] diff --git a/setup.py b/setup.py index 28cffd3c..faffbe84 100644 --- a/setup.py +++ b/setup.py @@ -76,6 +76,7 @@ def get_extra_requires() -> dict: "async_timeout", "pettingzoo[classic]", "trueskill", + "envpool", ], "selfplay_test": [ "ray[default]>=2.7", @@ -84,6 +85,7 @@ def get_extra_requires() -> dict: "fastapi", "pettingzoo[mpe]", "pettingzoo[butterfly]", + "envpool", ], "retro": ["gym-retro"], "super_mario": ["gym-super-mario-bros"], From 3c31b5a48a35830af2d1cfd633d04bdf05fae06a Mon Sep 17 00:00:00 2001 From: Geo Jolly Date: Thu, 7 Dec 2023 15:49:19 +0530 Subject: [PATCH 2/6] Remove unwanted test for envpool --- examples/envpool/test_model.py | 78 ---------------------------------- examples/envpool/train_ppo.py | 1 - 2 files changed, 79 deletions(-) delete mode 100644 examples/envpool/test_model.py diff --git a/examples/envpool/test_model.py b/examples/envpool/test_model.py deleted file mode 100644 index c0b4ddfb..00000000 --- a/examples/envpool/test_model.py +++ /dev/null @@ -1,78 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# Copyright 2023 The OpenRL Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""""" - -# Use OpenRL to load stable-baselines's model for testing - -import numpy as np -import torch - -from openrl.configs.config import create_config_parser -from openrl.envs.common import make -from openrl.modules.common.ppo_net import PPONet as Net -from openrl.modules.networks.policy_value_network_sb3 import ( - PolicyValueNetworkSB3 as PolicyValueNetwork, -) -from openrl.runners.common import PPOAgent as Agent - - -def evaluation(local_trained_file_path=None): - # begin to test - - cfg_parser = create_config_parser() - cfg = cfg_parser.parse_args(["--config", "ppo.yaml"]) - - # Create an environment for testing and set the number of environments to interact with to 9. Set rendering mode to group_human. 
- render_mode = "group_human" - render_mode = None - env = make("CartPole-v1", render_mode=render_mode, env_num=9, asynchronous=True) - model_dict = {"model": PolicyValueNetwork} - net = Net( - env, - cfg=cfg, - model_dict=model_dict, - device="cuda" if torch.cuda.is_available() else "cpu", - ) - # initialize the trainer - agent = Agent( - net, - ) - if local_trained_file_path is not None: - agent.load(local_trained_file_path) - # The trained agent sets up the interactive environment it needs. - agent.set_env(env) - # Initialize the environment and get initial observations and environmental information. - obs, info = env.reset() - done = False - - total_step = 0 - total_reward = 0.0 - while not np.any(done): - # Based on environmental observation input, predict next action. - action, _ = agent.act(obs, deterministic=True) - obs, r, done, info = env.step(action) - total_step += 1 - total_reward += np.mean(r) - if total_step % 50 == 0: - print(f"{total_step}: reward:{np.mean(r)}") - env.close() - print("total step:", total_step) - print("total reward:", total_reward) - - -if __name__ == "__main__": - evaluation() diff --git a/examples/envpool/train_ppo.py b/examples/envpool/train_ppo.py index 4120ee4a..49de50f4 100644 --- a/examples/envpool/train_ppo.py +++ b/examples/envpool/train_ppo.py @@ -16,7 +16,6 @@ """""" import numpy as np -from test_model import evaluation from openrl.configs.config import create_config_parser from openrl.envs.common import make From 990f8c5fd3e448f5594a20f05a48a3d8ed65696c Mon Sep 17 00:00:00 2001 From: Geo Jolly Date: Thu, 7 Dec 2023 15:52:42 +0530 Subject: [PATCH 3/6] Fix a typo: envpool/train-ppo --- examples/envpool/train_ppo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/envpool/train_ppo.py b/examples/envpool/train_ppo.py index 49de50f4..49eb4456 100644 --- a/examples/envpool/train_ppo.py +++ b/examples/envpool/train_ppo.py @@ -32,7 +32,7 @@ def train(): # create environment, set environment parallelism to 9 env = make( - "envpool:Adventure-v5", + "envpool:CartPole-v1", render_mode=None, env_num=9, asynchronous=False, @@ -45,7 +45,7 @@ def train(): cfg=cfg, ) # initialize the trainer - agent = Agent(net, use_wandb=False, project_name="envpool:Adventure-v5") + agent = Agent(net, use_wandb=False, project_name="envpool:CartPole-v1") # start training, set total number of training steps to 20000 agent.train(total_time_steps=20000) From 448cc6f1189ca0ed46aa57b23c87b96e23f99e21 Mon Sep 17 00:00:00 2001 From: Geo Jolly Date: Thu, 7 Dec 2023 16:25:03 +0530 Subject: [PATCH 4/6] Fix dependency error: stablebaseline3 --- openrl/envs/common/build_envs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/openrl/envs/common/build_envs.py b/openrl/envs/common/build_envs.py index 386c4adc..37c17c01 100644 --- a/openrl/envs/common/build_envs.py +++ b/openrl/envs/common/build_envs.py @@ -6,7 +6,6 @@ from gymnasium import Env from openrl.envs.wrappers.base_wrapper import BaseWrapper -from openrl.envs.wrappers.envpool_wrappers import VecEnvWrapper, VecMonitor def build_envs( From 753fba760929fb5502eef105c59d19113e88d8b1 Mon Sep 17 00:00:00 2001 From: Geo Jolly Date: Wed, 20 Dec 2023 11:47:42 +0530 Subject: [PATCH 5/6] Move envpool to examples --- examples/envpool/README.md | 20 +++ .../envpool}/envpool_wrappers.py | 0 examples/envpool/make_env.py | 128 ++++++++++++++++++ examples/envpool/train_ppo.py | 17 ++- openrl/envs/common/registration.py | 13 -- openrl/envs/envpool/__init__.py | 47 ------- setup.py | 2 - 7 files changed, 160 insertions(+), 
67 deletions(-)
 create mode 100644 examples/envpool/README.md
 rename {openrl/envs/wrappers => examples/envpool}/envpool_wrappers.py (100%)
 create mode 100644 examples/envpool/make_env.py
 delete mode 100644 openrl/envs/envpool/__init__.py

diff --git a/examples/envpool/README.md b/examples/envpool/README.md
new file mode 100644
index 00000000..e9a16389
--- /dev/null
+++ b/examples/envpool/README.md
@@ -0,0 +1,20 @@
+## Installation
+
+
+Install envpool with:
+
+``` shell
+pip install envpool
+```
+
+Note: envpool only supports the Linux operating system.
+
+## Usage
+
+You can use `OpenRL` to train CartPole (envpool) via:
+
+``` shell
+PYTHON_PATH train_ppo.py
+```
+
+You can also add custom wrappers in `envpool_wrappers.py`. Currently we provide the `VecAdapter` and `VecMonitor` wrappers.
\ No newline at end of file
diff --git a/openrl/envs/wrappers/envpool_wrappers.py b/examples/envpool/envpool_wrappers.py
similarity index 100%
rename from openrl/envs/wrappers/envpool_wrappers.py
rename to examples/envpool/envpool_wrappers.py
diff --git a/examples/envpool/make_env.py b/examples/envpool/make_env.py
new file mode 100644
index 00000000..92c1b51a
--- /dev/null
+++ b/examples/envpool/make_env.py
@@ -0,0 +1,128 @@
+import copy
+import inspect
+from typing import Callable, Iterable, List, Optional, Union
+
+import envpool
+from gymnasium import Env
+
+
+from openrl.envs.vec_env import (AsyncVectorEnv, RewardWrapper,
+                                 SyncVectorEnv, VecMonitorWrapper)
+from openrl.envs.vec_env.vec_info import VecInfoFactory
+from openrl.envs.wrappers.base_wrapper import BaseWrapper
+from openrl.rewards import RewardFactory
+
+
+def build_envs(
+    make,
+    id: str,
+    env_num: int = 1,
+    wrappers: Optional[Union[Callable[[Env], Env], List[Callable[[Env], Env]]]] = None,
+    need_env_id: bool = False,
+    **kwargs,
+) -> List[Callable[[], Env]]:
+    cfg = kwargs.get("cfg", None)
+
+    def create_env(env_id: int, env_num: int, need_env_id: bool) -> Callable[[], Env]:
+        def _make_env() -> Env:
+            new_kwargs = copy.deepcopy(kwargs)
+            if need_env_id:
+                new_kwargs["env_id"] = env_id
+                new_kwargs["env_num"] = env_num
+            if "envpool" in new_kwargs:
+                # for now, envpool does not support any render mode
+                # envpool also does not store the id anywhere
+                new_kwargs.pop("envpool")
+                env = make(
+                    id,
+                    **new_kwargs,
+                )
+                env.unwrapped.spec.id = id
+
+            if wrappers is not None:
+                if callable(wrappers):
+                    if issubclass(wrappers, BaseWrapper):
+                        env = wrappers(env, cfg=cfg)
+                    else:
+                        env = wrappers(env)
+                elif isinstance(wrappers, Iterable) and all(
+                    [callable(w) for w in wrappers]
+                ):
+                    for wrapper in wrappers:
+                        if (
+                            issubclass(wrapper, BaseWrapper)
+                            and "cfg" in inspect.signature(wrapper.__init__).parameters
+                        ):
+                            env = wrapper(env, cfg=cfg)
+                        else:
+                            env = wrapper(env)
+                else:
+                    raise NotImplementedError
+
+            return env
+
+        return _make_env
+
+    env_fns = [create_env(env_id, env_num, need_env_id) for env_id in range(env_num)]
+    return env_fns
+
+
+def make_envpool_envs(
+    id: str,
+    env_num: int = 1,
+    **kwargs,
+):
+    assert "env_type" in kwargs
+    assert kwargs.get("env_type") in ["gym", "dm", "gymnasium"]
+    kwargs["envpool"] = True
+
+    if 'env_wrappers' in kwargs:
+        env_wrappers = kwargs.pop("env_wrappers")
+    else:
+        env_wrappers = []
+    env_fns = build_envs(
+        make=envpool.make,
+        id=id,
+        env_num=env_num,
+        wrappers=env_wrappers,
+        **kwargs,
+    )
+    return env_fns
+
+
+def make(
+    id: str,
+    env_num: int = 1,
+    asynchronous: bool = False,
+    add_monitor: bool = True,
+    render_mode: Optional[str] = None,
+    auto_reset: bool = True,
+    **kwargs,
+):
+    
cfg = kwargs.get("cfg", None) + if id in envpool.registration.list_all_envs(): + env_fns = make_envpool_envs( + id=id.split(":")[-1], + env_num=env_num, + **kwargs, + ) + if asynchronous: + env = AsyncVectorEnv( + env_fns, render_mode=render_mode, auto_reset=auto_reset + ) + else: + env = SyncVectorEnv(env_fns, render_mode=render_mode, auto_reset=auto_reset) + + reward_class = cfg.reward_class if cfg else None + reward_class = RewardFactory.get_reward_class(reward_class, env) + + env = RewardWrapper(env, reward_class) + + if add_monitor: + vec_info_class = cfg.vec_info_class if cfg else None + vec_info_class = VecInfoFactory.get_vec_info_class(vec_info_class, env) + env = VecMonitorWrapper(vec_info_class, env) + + return env + else: + raise NotImplementedError(f"env {id} is not supported") diff --git a/examples/envpool/train_ppo.py b/examples/envpool/train_ppo.py index 49eb4456..a02151f7 100644 --- a/examples/envpool/train_ppo.py +++ b/examples/envpool/train_ppo.py @@ -18,8 +18,8 @@ import numpy as np from openrl.configs.config import create_config_parser -from openrl.envs.common import make -from openrl.envs.wrappers.envpool_wrappers import VecAdapter, VecMonitor +from make_env import make +from examples.envpool.envpool_wrappers import VecAdapter, VecMonitor from openrl.modules.common import PPONet as Net from openrl.modules.common.ppo_net import PPONet as Net from openrl.runners.common import PPOAgent as Agent @@ -32,7 +32,7 @@ def train(): # create environment, set environment parallelism to 9 env = make( - "envpool:CartPole-v1", + "CartPole-v1", render_mode=None, env_num=9, asynchronous=False, @@ -45,7 +45,7 @@ def train(): cfg=cfg, ) # initialize the trainer - agent = Agent(net, use_wandb=False, project_name="envpool:CartPole-v1") + agent = Agent(net, use_wandb=False, project_name="CartPole-v1") # start training, set total number of training steps to 20000 agent.train(total_time_steps=20000) @@ -58,7 +58,14 @@ def evaluation(agent): # Create an environment for testing and set the number of environments to interact with to 9. Set rendering mode to group_human. render_mode = "group_human" render_mode = None - env = make("CartPole-v1", render_mode=render_mode, env_num=9, asynchronous=True) + env = make( + "CartPole-v1", + env_wrappers=[VecAdapter, VecMonitor], + render_mode=render_mode, + env_num=9, + asynchronous=True, + env_type="gym", + ) # The trained agent sets up the interactive environment it needs. agent.set_env(env) # Initialize the environment and get initial observations and environmental information. 
diff --git a/openrl/envs/common/registration.py b/openrl/envs/common/registration.py index 053dd104..bb6c4462 100644 --- a/openrl/envs/common/registration.py +++ b/openrl/envs/common/registration.py @@ -17,7 +17,6 @@ """""" from typing import Callable, Optional -import envpool import gymnasium as gym import openrl @@ -155,18 +154,6 @@ def make( env_fns = make_PettingZoo_envs( id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs ) - elif ( - "envpool:" in id - and id.split(":")[-1] in envpool.registration.list_all_envs() - ): - from openrl.envs.envpool import make_envpool_envs - - env_fns = make_envpool_envs( - id=id.split(":")[-1], - env_num=env_num, - render_mode=convert_render_mode, - **kwargs, - ) else: raise NotImplementedError(f"env {id} is not supported.") diff --git a/openrl/envs/envpool/__init__.py b/openrl/envs/envpool/__init__.py deleted file mode 100644 index 48fbd1f5..00000000 --- a/openrl/envs/envpool/__init__.py +++ /dev/null @@ -1,47 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# Copyright 2023 The OpenRL Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""""" -from typing import List, Optional, Union - -import envpool - -from openrl.envs.common import build_envs - - -def make_envpool_envs( - id: str, - env_num: int = 1, - render_mode: Optional[Union[str, List[str]]] = None, - **kwargs, -): - assert "env_type" in kwargs - assert kwargs.get("env_type") in ["gym", "dm", "gymnasium"] - # Since render_mode is not supported, we set envpool to True - # so that we can remove render_mode keyword argument from build_envs - assert render_mode is None, "envpool does not support render_mode yet" - kwargs["envpool"] = True - - env_wrappers = kwargs.pop("env_wrappers") - env_fns = build_envs( - make=envpool.make, - id=id, - env_num=env_num, - render_mode=render_mode, - wrappers=env_wrappers, - **kwargs, - ) - return env_fns diff --git a/setup.py b/setup.py index faffbe84..28cffd3c 100644 --- a/setup.py +++ b/setup.py @@ -76,7 +76,6 @@ def get_extra_requires() -> dict: "async_timeout", "pettingzoo[classic]", "trueskill", - "envpool", ], "selfplay_test": [ "ray[default]>=2.7", @@ -85,7 +84,6 @@ def get_extra_requires() -> dict: "fastapi", "pettingzoo[mpe]", "pettingzoo[butterfly]", - "envpool", ], "retro": ["gym-retro"], "super_mario": ["gym-super-mario-bros"], From 693d2e1c4d46ef636d0a8d3f4962378e9fa95da0 Mon Sep 17 00:00:00 2001 From: Geo Jolly Date: Wed, 20 Dec 2023 11:51:07 +0530 Subject: [PATCH 6/6] Revert files in openrl folder --- openrl/envs/common/build_envs.py | 25 ++++++++----------------- openrl/envs/common/registration.py | 3 ++- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/openrl/envs/common/build_envs.py b/openrl/envs/common/build_envs.py index 37c17c01..a0c59c6f 100644 --- a/openrl/envs/common/build_envs.py +++ b/openrl/envs/common/build_envs.py @@ -36,22 +36,13 @@ def _make_env() -> Env: new_kwargs["env_num"] = env_num if id.startswith("ALE/") or id in gym.envs.registry.keys(): new_kwargs.pop("cfg", None) - if "envpool" in 
new_kwargs: - # for now envpool doesnt support any render mode - # envpool also doesnt stores the id anywhere - new_kwargs.pop("envpool") - env = make( - id, - **new_kwargs, - ) - env.unwrapped.spec.id = id - else: - env = make( - id, - render_mode=env_render_mode, - disable_env_checker=_disable_env_checker, - **new_kwargs, - ) + + env = make( + id, + render_mode=env_render_mode, + disable_env_checker=_disable_env_checker, + **new_kwargs, + ) if wrappers is not None: if callable(wrappers): @@ -78,4 +69,4 @@ def _make_env() -> Env: return _make_env env_fns = [create_env(env_id, env_num, need_env_id) for env_id in range(env_num)] - return env_fns + return env_fns \ No newline at end of file diff --git a/openrl/envs/common/registration.py b/openrl/envs/common/registration.py index bb6c4462..1ee9b532 100644 --- a/openrl/envs/common/registration.py +++ b/openrl/envs/common/registration.py @@ -72,6 +72,7 @@ def make( env_fns = make_single_agent_drone_envs( id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs ) + elif id.startswith("snakes_"): from openrl.envs.snake import make_snake_envs @@ -172,4 +173,4 @@ def make( vec_info_class = VecInfoFactory.get_vec_info_class(vec_info_class, env) env = VecMonitorWrapper(vec_info_class, env) - return env + return env \ No newline at end of file
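After this series, everything envpool-specific lives under `examples/envpool/`. For a quick standalone check of the `VecAdapter` and `VecMonitor` wrappers outside OpenRL's trainer, a minimal sketch along the following lines should work, assuming the series is applied, `envpool` and `stable-baselines3` are installed, and the script runs from `examples/envpool/` so the local import resolves. The `CartPole-v1` id, the `env_type="gym"` argument, and the `spec.id` assignment mirror the patches above; the environment count and the random-action loop are illustrative only.

```python
# Standalone smoke test for the wrappers in examples/envpool/envpool_wrappers.py (sketch).
import envpool
import numpy as np

from envpool_wrappers import VecAdapter, VecMonitor

# Build a batched envpool environment with the gym API, as make_env.py does.
venv = envpool.make("CartPole-v1", env_type="gym", num_envs=4)
# envpool does not store the task id itself; set it so VecMonitor can record it.
venv.unwrapped.spec.id = "CartPole-v1"
# Adapt the envpool object to the Stable-Baselines3 VecEnv interface and track episode stats.
venv = VecMonitor(VecAdapter(venv))

obs, info = venv.reset()  # this VecMonitor returns (obs, info)
for _ in range(200):
    # Sample one random action per parallel environment.
    actions = np.array([venv.action_space.sample() for _ in range(venv.num_envs)])
    obs, rewards, dones, infos = venv.step(actions)
    for item in infos:
        if "episode" in item:
            print("episode return:", item["episode"]["r"])
venv.close()
```

Training itself goes through `make_env.make`, which builds these wrappers into OpenRL's vectorized environment stack, as `train_ppo.py` shows.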