diff --git a/examples/envpool/README.md b/examples/envpool/README.md
new file mode 100644
index 0000000..e9a1638
--- /dev/null
+++ b/examples/envpool/README.md
@@ -0,0 +1,20 @@
+## Installation
+
+
+Install envpool with:
+
+``` shell
+pip install envpool
+```
+
+Note: envpool only supports the Linux operating system.
+
+## Usage
+
+You can use `OpenRL` to train CartPole (envpool) via:
+
+``` shell
+PYTHON_PATH train_ppo.py
+```
+
+You can also add custom wrappers in `envpool_wrappers.py`. Currently, the `VecAdapter` and `VecMonitor` wrappers are provided.
\ No newline at end of file
diff --git a/examples/envpool/envpool_wrappers.py b/examples/envpool/envpool_wrappers.py
new file mode 100644
index 0000000..d0da090
--- /dev/null
+++ b/examples/envpool/envpool_wrappers.py
@@ -0,0 +1,182 @@
+import time
+import warnings
+from typing import Optional
+
+import gym
+import gymnasium
+import numpy as np
+from envpool.python.protocol import EnvPool
+from packaging import version
+from stable_baselines3.common.vec_env import VecEnvWrapper as BaseWrapper
+from stable_baselines3.common.vec_env.base_vec_env import (VecEnvObs,
+                                                           VecEnvStepReturn)
+
+is_legacy_gym = version.parse(gym.__version__) < version.parse("0.26.0")
+
+
+class VecEnvWrapper(BaseWrapper):
+    @property
+    def agent_num(self):
+        if self.is_original_envpool_env():
+            return 1
+        else:
+            return self.env.agent_num
+
+    def is_original_envpool_env(self):
+        return not hasattr(self.venv, "agent_num")
+
+
+class VecAdapter(VecEnvWrapper):
+    """
+    Convert an EnvPool object to a Stable-Baselines3 (SB3) VecEnv.
+
+    :param venv: The envpool object.
+    """
+
+    def __init__(self, venv: EnvPool):
+        venv.num_envs = venv.spec.config.num_envs
+        # convert envpool's gym spaces to gymnasium spaces
+        observation_space = venv.observation_space
+        new_observation_space = gymnasium.spaces.Box(
+            low=observation_space.low,
+            high=observation_space.high,
+            dtype=observation_space.dtype,
+        )
+        action_space = venv.action_space
+        if isinstance(action_space, gym.spaces.Discrete):
+            new_action_space = gymnasium.spaces.Discrete(action_space.n)
+        elif isinstance(action_space, gym.spaces.MultiDiscrete):
+            new_action_space = gymnasium.spaces.MultiDiscrete(action_space.nvec)
+        elif isinstance(action_space, gym.spaces.MultiBinary):
+            new_action_space = gymnasium.spaces.MultiBinary(action_space.n)
+        elif isinstance(action_space, gym.spaces.Box):
+            new_action_space = gymnasium.spaces.Box(
+                low=action_space.low,
+                high=action_space.high,
+                dtype=action_space.dtype,
+            )
+        else:
+            raise NotImplementedError(f"Action space {action_space} is not supported")
+        super().__init__(
+            venv=venv,
+            observation_space=new_observation_space,
+            action_space=new_action_space,
+        )
+
+    def step_async(self, actions: np.ndarray) -> None:
+        self.actions = actions
+
+    def reset(self) -> VecEnvObs:
+        if is_legacy_gym:
+            return self.venv.reset(), {}
+        else:
+            return self.venv.reset()
+
+    def step_wait(self) -> VecEnvStepReturn:
+        if is_legacy_gym:
+            obs, rewards, dones, info_dict = self.venv.step(self.actions)
+        else:
+            obs, rewards, terms, truncs, info_dict = self.venv.step(self.actions)
+            # with the gymnasium API, an episode ends when terminated or truncated
+            dones = terms + truncs
+        infos = []
+        for i in range(self.num_envs):
+            infos.append(
+                {
+                    key: info_dict[key][i]
+                    for key in info_dict.keys()
+                    if isinstance(info_dict[key], np.ndarray)
+                }
+            )
+            if dones[i]:
+                # store the terminal observation and reset only the finished sub-env
+                infos[i]["terminal_observation"] = obs[i]
+                if is_legacy_gym:
+                    obs[i] = self.venv.reset(np.array([i]))
+                else:
+                    obs[i] = self.venv.reset(np.array([i]))[0]
+        return obs, rewards, dones, infos
+
+
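+# Usage note (illustrative sketch, assuming a raw envpool environment; in this
+# example `make_env.make` applies the wrappers for you via
+# `env_wrappers=[VecAdapter, VecMonitor]`):
+#
+#     venv = envpool.make("CartPole-v1", env_type="gym", num_envs=9)
+#     venv = VecMonitor(VecAdapter(venv))
+
+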
+class VecMonitor(VecEnvWrapper):
+    def __init__(
+        self,
+        venv,
+        filename: Optional[str] = None,
+        info_keywords=(),
+    ):
+        # Avoid circular import
+        from stable_baselines3.common.monitor import Monitor, ResultsWriter
+
+        try:
+            is_wrapped_with_monitor = venv.env_is_wrapped(Monitor)[0]
+        except AttributeError:
+            is_wrapped_with_monitor = False
+
+        if is_wrapped_with_monitor:
+            warnings.warn(
+                "The environment is already wrapped with a `Monitor` wrapper "
+                "but you are wrapping it with a `VecMonitor` wrapper, the `Monitor` statistics will be "
+                "overwritten by the `VecMonitor` ones.",
+                UserWarning,
+            )
+
+        VecEnvWrapper.__init__(self, venv)
+        self.episode_count = 0
+        self.t_start = time.time()
+
+        env_id = None
+        if hasattr(venv, "spec") and venv.spec is not None:
+            env_id = venv.spec.id
+
+        self.results_writer: Optional[ResultsWriter] = None
+        if filename:
+            self.results_writer = ResultsWriter(
+                filename,
+                header={"t_start": self.t_start, "env_id": str(env_id)},
+                extra_keys=info_keywords,
+            )
+
+        self.info_keywords = info_keywords
+        self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
+        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
+
+    def reset(self, **kwargs) -> VecEnvObs:
+        obs, info = self.venv.reset()
+        self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
+        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
+        return obs, info
+
+    def step_wait(self) -> VecEnvStepReturn:
+        obs, rewards, dones, infos = self.venv.step_wait()
+        self.episode_returns += rewards
+        self.episode_lengths += 1
+        new_infos = list(infos[:])
+        for i in range(len(dones)):
+            if dones[i]:
+                info = infos[i].copy()
+                episode_return = self.episode_returns[i]
+                episode_length = self.episode_lengths[i]
+                episode_info = {
+                    "r": episode_return,
+                    "l": episode_length,
+                    "t": round(time.time() - self.t_start, 6),
+                }
+                for key in self.info_keywords:
+                    episode_info[key] = info[key]
+                info["episode"] = episode_info
+                self.episode_count += 1
+                self.episode_returns[i] = 0
+                self.episode_lengths[i] = 0
+                if self.results_writer:
+                    self.results_writer.write_row(episode_info)
+                new_infos[i] = info
+        rewards = np.expand_dims(rewards, 1)
+        return obs, rewards, dones, new_infos
+
+    def close(self) -> None:
+        if self.results_writer:
+            self.results_writer.close()
+        return self.venv.close()
+
+
+__all__ = ["VecAdapter", "VecMonitor"]
diff --git a/examples/envpool/make_env.py b/examples/envpool/make_env.py
new file mode 100644
index 0000000..92c1b51
--- /dev/null
+++ b/examples/envpool/make_env.py
@@ -0,0 +1,128 @@
+import copy
+import inspect
+from typing import Callable, Iterable, List, Optional, Union
+
+import envpool
+from gymnasium import Env
+
+
+from openrl.envs.vec_env import (AsyncVectorEnv, RewardWrapper,
+                                 SyncVectorEnv, VecMonitorWrapper)
+from openrl.envs.vec_env.vec_info import VecInfoFactory
+from openrl.envs.wrappers.base_wrapper import BaseWrapper
+from openrl.rewards import RewardFactory
+
+
+def build_envs(
+    make,
+    id: str,
+    env_num: int = 1,
+    wrappers: Optional[Union[Callable[[Env], Env], List[Callable[[Env], Env]]]] = None,
+    need_env_id: bool = False,
+    **kwargs,
+) -> List[Callable[[], Env]]:
+    cfg = kwargs.get("cfg", None)
+
+    def create_env(env_id: int, env_num: int, need_env_id: bool) -> Callable[[], Env]:
+        def _make_env() -> Env:
+            new_kwargs = copy.deepcopy(kwargs)
+            if need_env_id:
+                new_kwargs["env_id"] = env_id
+                new_kwargs["env_num"] = env_num
+            if "envpool" in new_kwargs:
+                # for now, envpool does not support any render mode
+                # and it does not store the env id anywhere
+                new_kwargs.pop("envpool")
+            env = make(
+                id,
+                **new_kwargs,
+            )
+            env.unwrapped.spec.id = id
+
+            if wrappers is not None:
+                if callable(wrappers):
+                    if issubclass(wrappers, BaseWrapper):
+                        env = wrappers(env, cfg=cfg)
+                    else:
+                        env = wrappers(env)
+                elif isinstance(wrappers, Iterable) and all(
+                    [callable(w) for w in wrappers]
+                ):
+                    for wrapper in wrappers:
+                        if (
+                            issubclass(wrapper, BaseWrapper)
+                            and "cfg" in inspect.signature(wrapper.__init__).parameters
+                        ):
+                            env = wrapper(env, cfg=cfg)
+                        else:
+                            env = wrapper(env)
+                else:
+                    raise NotImplementedError
+
+            return env
+
+        return _make_env
+
+    env_fns = [create_env(env_id, env_num, need_env_id) for env_id in range(env_num)]
+    return env_fns
+
+
+def make_envpool_envs(
+    id: str,
+    env_num: int = 1,
+    **kwargs,
+):
+    assert "env_type" in kwargs
+    assert kwargs.get("env_type") in ["gym", "dm", "gymnasium"]
+    kwargs["envpool"] = True
+
+    if "env_wrappers" in kwargs:
+        env_wrappers = kwargs.pop("env_wrappers")
+    else:
+        env_wrappers = []
+    env_fns = build_envs(
+        make=envpool.make,
+        id=id,
+        env_num=env_num,
+        wrappers=env_wrappers,
+        **kwargs,
+    )
+    return env_fns
+
+
+def make(
+    id: str,
+    env_num: int = 1,
+    asynchronous: bool = False,
+    add_monitor: bool = True,
+    render_mode: Optional[str] = None,
+    auto_reset: bool = True,
+    **kwargs,
+):
+    cfg = kwargs.get("cfg", None)
+    if id in envpool.registration.list_all_envs():
+        env_fns = make_envpool_envs(
+            id=id.split(":")[-1],
+            env_num=env_num,
+            **kwargs,
+        )
+        if asynchronous:
+            env = AsyncVectorEnv(
+                env_fns, render_mode=render_mode, auto_reset=auto_reset
+            )
+        else:
+            env = SyncVectorEnv(env_fns, render_mode=render_mode, auto_reset=auto_reset)
+
+        reward_class = cfg.reward_class if cfg else None
+        reward_class = RewardFactory.get_reward_class(reward_class, env)
+
+        env = RewardWrapper(env, reward_class)
+
+        if add_monitor:
+            vec_info_class = cfg.vec_info_class if cfg else None
+            vec_info_class = VecInfoFactory.get_vec_info_class(vec_info_class, env)
+            env = VecMonitorWrapper(vec_info_class, env)
+
+        return env
+    else:
+        raise NotImplementedError(f"env {id} is not supported")
diff --git a/examples/envpool/train_ppo.py b/examples/envpool/train_ppo.py
new file mode 100644
index 0000000..a02151f
--- /dev/null
+++ b/examples/envpool/train_ppo.py
@@ -0,0 +1,92 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Train and evaluate a PPO agent on envpool's CartPole-v1 with OpenRL."""
+import numpy as np
+
+from openrl.configs.config import create_config_parser
+from openrl.modules.common import PPONet as Net
+from openrl.runners.common import PPOAgent as Agent
+
+from envpool_wrappers import VecAdapter, VecMonitor
+from make_env import make
+
+
+def train():
+    # load the training configuration
+    cfg_parser = create_config_parser()
+    cfg = cfg_parser.parse_args()
+
+    # create environment, set environment parallelism to 9
+    env = make(
+        "CartPole-v1",
+        render_mode=None,
+        env_num=9,
+        asynchronous=False,
+        env_wrappers=[VecAdapter, VecMonitor],
+        env_type="gym",
+    )
+
+    # create the neural network
+    net = Net(
+        env,
+        cfg=cfg,
+    )
+    # initialize the trainer
+    agent = Agent(net, use_wandb=False, project_name="CartPole-v1")
+    # start training, set total number of training steps to 20000
+    agent.train(total_time_steps=20000)
+
+    env.close()
+    return agent
+
+
+def evaluation(agent):
+    # begin to test
+    # Create an environment for testing and set the number of environments to interact with to 9.
+    # Rendering is disabled because envpool does not support render modes.
+    render_mode = None
+    env = make(
+        "CartPole-v1",
+        env_wrappers=[VecAdapter, VecMonitor],
+        render_mode=render_mode,
+        env_num=9,
+        asynchronous=True,
+        env_type="gym",
+    )
+    # The trained agent sets up the interactive environment it needs.
+    agent.set_env(env)
+    # Initialize the environment and get initial observations and environmental information.
+    obs, info = env.reset()
+    done = False
+    step = 0
+    total_step, total_reward = 0, 0
+    while not np.any(done):
+        # Based on environmental observation input, predict the next action.
+        action, _ = agent.act(obs, deterministic=True)
+        obs, r, done, info = env.step(action)
+        step += 1
+        total_step += 1
+        total_reward += np.mean(r)
+        if step % 50 == 0:
+            print(f"{step}: reward: {np.mean(r)}")
+    env.close()
+    print("total step:", total_step)
+    print("total reward:", total_reward)
+
+
+if __name__ == "__main__":
+    agent = train()
+    evaluation(agent)
diff --git a/openrl/envs/common/build_envs.py b/openrl/envs/common/build_envs.py
index 76f4b35..a0c59c6 100644
--- a/openrl/envs/common/build_envs.py
+++ b/openrl/envs/common/build_envs.py
@@ -69,4 +69,4 @@ def _make_env() -> Env:
         return _make_env
 
     env_fns = [create_env(env_id, env_num, need_env_id) for env_id in range(env_num)]
-    return env_fns
+    return env_fns
\ No newline at end of file
diff --git a/openrl/envs/common/registration.py b/openrl/envs/common/registration.py
index 5d1ed64..1ee9b53 100644
--- a/openrl/envs/common/registration.py
+++ b/openrl/envs/common/registration.py
@@ -173,4 +173,4 @@ def make(
     vec_info_class = VecInfoFactory.get_vec_info_class(vec_info_class, env)
     env = VecMonitorWrapper(vec_info_class, env)
 
-    return env
+    return env
\ No newline at end of file
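A note on the "custom wrappers" remark in the README: an additional wrapper follows the same `reset`/`step_wait` pattern as `VecAdapter` and `VecMonitor`, and is passed through the `env_wrappers` argument of `make`. Below is a minimal sketch; the `VecClipReward` name, the reward-clipping behaviour, and its placement next to `envpool_wrappers.py` are illustrative assumptions, not part of this change.

```python
# Illustrative custom vectorized wrapper (assumed to live alongside
# envpool_wrappers.py); it clips per-step rewards to [-1, 1].
import numpy as np

from envpool_wrappers import VecEnvWrapper


class VecClipReward(VecEnvWrapper):
    def reset(self):
        # nothing extra to do on reset; forward to the wrapped envs
        return self.venv.reset()

    def step_wait(self):
        obs, rewards, dones, infos = self.venv.step_wait()
        return obs, np.clip(rewards, -1.0, 1.0), dones, infos
```

It would then be listed in `train_ppo.py` as `env_wrappers=[VecAdapter, VecClipReward, VecMonitor]`.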