Merge pull request #278 from kingjuno/Issue-#216
Issue #216: Add envpool to openrl
huangshiyu13 authored Dec 20, 2023
2 parents 8185373 + 693d2e1 commit e864a08
Showing 6 changed files with 424 additions and 2 deletions.
20 changes: 20 additions & 0 deletions examples/envpool/README.md
@@ -0,0 +1,20 @@
## Installation


Install envpool with:

``` shell
pip install envpool
```

Note: envpool only supports the Linux operating system.

## Usage

You can use `OpenRL` to train CartPole with envpool via:

``` shell
python train_ppo.py
```

You can also add custom wrappers in `envpool_wrappers.py`. Currently, the `VecAdapter` and `VecMonitor` wrappers are provided.
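
As a rough orientation, a training script along these lines should work with the helpers added in this example. The snippet below is an illustrative sketch that assumes OpenRL's standard `PPONet`/`PPOAgent` API together with `make_env.make` and the wrappers from `envpool_wrappers.py`; it is not the verbatim content of `train_ppo.py` (which is part of this PR but not shown in this excerpt), and the argument values are assumptions.

``` python
# Illustrative sketch only; argument values are assumptions, not taken from train_ppo.py.
from envpool_wrappers import VecAdapter, VecMonitor
from make_env import make

from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent


def train():
    # Build envpool-backed CartPole environments, adapted to OpenRL's vec-env interface.
    env = make(
        "CartPole-v1",
        env_num=8,
        env_type="gym",
        env_wrappers=[VecAdapter, VecMonitor],
    )
    net = Net(env)
    agent = Agent(net)
    agent.train(total_time_steps=20000)
    env.close()


if __name__ == "__main__":
    train()
```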
182 changes: 182 additions & 0 deletions examples/envpool/envpool_wrappers.py
@@ -0,0 +1,182 @@
import time
import warnings
from typing import Optional

import gym
import gymnasium
import numpy as np
from envpool.python.protocol import EnvPool
from packaging import version
from stable_baselines3.common.vec_env import VecEnvWrapper as BaseWrapper
from stable_baselines3.common.vec_env.base_vec_env import (VecEnvObs,
VecEnvStepReturn)

is_legacy_gym = version.parse(gym.__version__) < version.parse("0.26.0")


class VecEnvWrapper(BaseWrapper):
@property
def agent_num(self):
if self.is_original_envpool_env():
return 1
else:
return self.env.agent_num

def is_original_envpool_env(self):
        return not hasattr(self.venv, "agent_num")


class VecAdapter(VecEnvWrapper):
"""
    Convert an EnvPool object to a Stable-Baselines3 (SB3) VecEnv.
:param venv: The envpool object.
"""

def __init__(self, venv: EnvPool):
venv.num_envs = venv.spec.config.num_envs
observation_space = venv.observation_space
new_observation_space = gymnasium.spaces.Box(
low=observation_space.low,
high=observation_space.high,
dtype=observation_space.dtype,
)
action_space = venv.action_space
if isinstance(action_space, gym.spaces.Discrete):
new_action_space = gymnasium.spaces.Discrete(action_space.n)
elif isinstance(action_space, gym.spaces.MultiDiscrete):
new_action_space = gymnasium.spaces.MultiDiscrete(action_space.nvec)
elif isinstance(action_space, gym.spaces.MultiBinary):
new_action_space = gymnasium.spaces.MultiBinary(action_space.n)
elif isinstance(action_space, gym.spaces.Box):
new_action_space = gymnasium.spaces.Box(
low=action_space.low,
high=action_space.high,
dtype=action_space.dtype,
)
else:
raise NotImplementedError(f"Action space {action_space} is not supported")
super().__init__(
venv=venv,
observation_space=new_observation_space,
action_space=new_action_space,
)

def step_async(self, actions: np.ndarray) -> None:
self.actions = actions

def reset(self) -> VecEnvObs:
if is_legacy_gym:
return self.venv.reset(), {}
else:
return self.venv.reset()

def step_wait(self) -> VecEnvStepReturn:
if is_legacy_gym:
obs, rewards, dones, info_dict = self.venv.step(self.actions)
else:
obs, rewards, terms, truncs, info_dict = self.venv.step(self.actions)
dones = terms + truncs
infos = []
for i in range(self.num_envs):
infos.append(
{
key: info_dict[key][i]
for key in info_dict.keys()
if isinstance(info_dict[key], np.ndarray)
}
)
if dones[i]:
infos[i]["terminal_observation"] = obs[i]
if is_legacy_gym:
obs[i] = self.venv.reset(np.array([i]))
else:
obs[i] = self.venv.reset(np.array([i]))[0]
return obs, rewards, dones, infos


class VecMonitor(VecEnvWrapper):
def __init__(
self,
venv,
filename: Optional[str] = None,
info_keywords=(),
):
# Avoid circular import
from stable_baselines3.common.monitor import Monitor, ResultsWriter

try:
is_wrapped_with_monitor = venv.env_is_wrapped(Monitor)[0]
except AttributeError:
is_wrapped_with_monitor = False

if is_wrapped_with_monitor:
warnings.warn(
"The environment is already wrapped with a `Monitor` wrapper"
"but you are wrapping it with a `VecMonitor` wrapper, the `Monitor` statistics will be"
"overwritten by the `VecMonitor` ones.",
UserWarning,
)

VecEnvWrapper.__init__(self, venv)
self.episode_count = 0
self.t_start = time.time()

env_id = None
if hasattr(venv, "spec") and venv.spec is not None:
env_id = venv.spec.id

self.results_writer: Optional[ResultsWriter] = None
if filename:
self.results_writer = ResultsWriter(
filename,
header={"t_start": self.t_start, "env_id": str(env_id)},
extra_keys=info_keywords,
)

self.info_keywords = info_keywords
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)

def reset(self, **kwargs) -> VecEnvObs:
obs, info = self.venv.reset()
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
return obs, info

def step_wait(self) -> VecEnvStepReturn:
obs, rewards, dones, infos = self.venv.step_wait()
self.episode_returns += rewards
self.episode_lengths += 1
new_infos = list(infos[:])
for i in range(len(dones)):
if dones[i]:
info = infos[i].copy()
episode_return = self.episode_returns[i]
episode_length = self.episode_lengths[i]
episode_info = {
"r": episode_return,
"l": episode_length,
"t": round(time.time() - self.t_start, 6),
}
for key in self.info_keywords:
episode_info[key] = info[key]
info["episode"] = episode_info
self.episode_count += 1
self.episode_returns[i] = 0
self.episode_lengths[i] = 0
if self.results_writer:
self.results_writer.write_row(episode_info)
new_infos[i] = info
rewards = np.expand_dims(rewards, 1)
return obs, rewards, dones, new_infos

def close(self) -> None:
if self.results_writer:
self.results_writer.close()
return self.venv.close()


__all__ = ["VecAdapter", "VecMonitor"]
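
For reference, a minimal sketch of how these two wrappers are meant to compose around a raw envpool environment; the object names and values below are illustrative assumptions, not part of this commit.

``` python
# Minimal composition sketch (assumed usage, not taken from this PR).
import envpool
import numpy as np

from envpool_wrappers import VecAdapter, VecMonitor

# envpool.make returns one batched environment object for all sub-envs.
venv = envpool.make("CartPole-v1", env_type="gym", num_envs=4)
venv = VecAdapter(venv)  # expose the envpool batch through the SB3 VecEnv interface
venv = VecMonitor(venv)  # record per-episode returns/lengths into info["episode"]

obs, info = venv.reset()
actions = np.array([venv.action_space.sample() for _ in range(venv.num_envs)])
obs, rewards, dones, infos = venv.step(actions)  # rewards is expanded to shape (num_envs, 1)
```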
128 changes: 128 additions & 0 deletions examples/envpool/make_env.py
@@ -0,0 +1,128 @@
import copy
import inspect
from typing import Callable, Iterable, List, Optional, Union

import envpool
from gymnasium import Env


from openrl.envs.vec_env import (AsyncVectorEnv, RewardWrapper,
SyncVectorEnv, VecMonitorWrapper)
from openrl.envs.vec_env.vec_info import VecInfoFactory
from openrl.envs.wrappers.base_wrapper import BaseWrapper
from openrl.rewards import RewardFactory


def build_envs(
make,
id: str,
env_num: int = 1,
wrappers: Optional[Union[Callable[[Env], Env], List[Callable[[Env], Env]]]] = None,
need_env_id: bool = False,
**kwargs,
) -> List[Callable[[], Env]]:
cfg = kwargs.get("cfg", None)

def create_env(env_id: int, env_num: int, need_env_id: bool) -> Callable[[], Env]:
def _make_env() -> Env:
new_kwargs = copy.deepcopy(kwargs)
if need_env_id:
new_kwargs["env_id"] = env_id
new_kwargs["env_num"] = env_num
if "envpool" in new_kwargs:
                # For now, envpool doesn't support any render mode.
                # envpool also doesn't store the id anywhere.
new_kwargs.pop("envpool")
env = make(
id,
**new_kwargs,
)
env.unwrapped.spec.id = id

if wrappers is not None:
if callable(wrappers):
if issubclass(wrappers, BaseWrapper):
env = wrappers(env, cfg=cfg)
else:
env = wrappers(env)
elif isinstance(wrappers, Iterable) and all(
[callable(w) for w in wrappers]
):
for wrapper in wrappers:
if (
issubclass(wrapper, BaseWrapper)
and "cfg" in inspect.signature(wrapper.__init__).parameters
):
env = wrapper(env, cfg=cfg)
else:
env = wrapper(env)
else:
raise NotImplementedError

return env

return _make_env

env_fns = [create_env(env_id, env_num, need_env_id) for env_id in range(env_num)]
return env_fns


def make_envpool_envs(
id: str,
env_num: int = 1,
**kwargs,
):
assert "env_type" in kwargs
assert kwargs.get("env_type") in ["gym", "dm", "gymnasium"]
kwargs["envpool"] = True

    env_wrappers = kwargs.pop("env_wrappers", [])
env_fns = build_envs(
make=envpool.make,
id=id,
env_num=env_num,
wrappers=env_wrappers,
**kwargs,
)
return env_fns


def make(
id: str,
env_num: int = 1,
asynchronous: bool = False,
add_monitor: bool = True,
render_mode: Optional[str] = None,
auto_reset: bool = True,
**kwargs,
):
cfg = kwargs.get("cfg", None)
if id in envpool.registration.list_all_envs():
env_fns = make_envpool_envs(
id=id.split(":")[-1],
env_num=env_num,
**kwargs,
)
if asynchronous:
env = AsyncVectorEnv(
env_fns, render_mode=render_mode, auto_reset=auto_reset
)
else:
env = SyncVectorEnv(env_fns, render_mode=render_mode, auto_reset=auto_reset)

reward_class = cfg.reward_class if cfg else None
reward_class = RewardFactory.get_reward_class(reward_class, env)

env = RewardWrapper(env, reward_class)

if add_monitor:
vec_info_class = cfg.vec_info_class if cfg else None
vec_info_class = VecInfoFactory.get_vec_info_class(vec_info_class, env)
env = VecMonitorWrapper(vec_info_class, env)

return env
else:
raise NotImplementedError(f"env {id} is not supported")
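
A hedged usage sketch for the `make` helper above: the id must be registered with envpool, `env_type` is required, and the result is wrapped by `RewardWrapper` and (optionally) `VecMonitorWrapper`. The values below are illustrative assumptions, not taken from this commit.

``` python
# Assumed usage of make(); not part of this commit.
from make_env import make

env = make(
    "CartPole-v1",      # must appear in envpool.registration.list_all_envs()
    env_num=4,          # number of env_fns handed to the vector env
    asynchronous=True,  # AsyncVectorEnv; the default False uses SyncVectorEnv
    env_type="gym",     # required: one of "gym", "dm", "gymnasium"
)
# Resulting stack: VecMonitorWrapper -> RewardWrapper -> AsyncVectorEnv -> per-env fns
```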