-
Notifications
You must be signed in to change notification settings - Fork 62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Issue #216: Add envpool to openrl #278
Changes from 4 commits
dac2804
3c31b5a
990f8c5
448cc6f
753fba7
693d2e1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# Copyright 2023 The OpenRL Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""""" | ||
import numpy as np | ||
|
||
from openrl.configs.config import create_config_parser | ||
from openrl.envs.common import make | ||
from openrl.envs.wrappers.envpool_wrappers import VecAdapter, VecMonitor | ||
from openrl.modules.common import PPONet as Net | ||
from openrl.modules.common.ppo_net import PPONet as Net | ||
from openrl.runners.common import PPOAgent as Agent | ||
|
||
|
||
def train(): | ||
# create the neural network | ||
cfg_parser = create_config_parser() | ||
cfg = cfg_parser.parse_args() | ||
|
||
# create environment, set environment parallelism to 9 | ||
env = make( | ||
"envpool:CartPole-v1", | ||
render_mode=None, | ||
env_num=9, | ||
asynchronous=False, | ||
env_wrappers=[VecAdapter, VecMonitor], | ||
env_type="gym", | ||
) | ||
|
||
net = Net( | ||
env, | ||
cfg=cfg, | ||
) | ||
# initialize the trainer | ||
agent = Agent(net, use_wandb=False, project_name="envpool:CartPole-v1") | ||
# start training, set total number of training steps to 20000 | ||
agent.train(total_time_steps=20000) | ||
|
||
env.close() | ||
return agent | ||
|
||
|
||
def evaluation(agent): | ||
# begin to test | ||
# Create an environment for testing and set the number of environments to interact with to 9. Set rendering mode to group_human. | ||
render_mode = "group_human" | ||
render_mode = None | ||
env = make("CartPole-v1", render_mode=render_mode, env_num=9, asynchronous=True) | ||
# The trained agent sets up the interactive environment it needs. | ||
agent.set_env(env) | ||
# Initialize the environment and get initial observations and environmental information. | ||
obs, info = env.reset() | ||
done = False | ||
step = 0 | ||
total_step, total_reward = 0, 0 | ||
while not np.any(done): | ||
# Based on environmental observation input, predict next action. | ||
action, _ = agent.act(obs, deterministic=True) | ||
obs, r, done, info = env.step(action) | ||
step += 1 | ||
total_step += 1 | ||
total_reward += np.mean(r) | ||
if step % 50 == 0: | ||
print(f"{step}: reward:{np.mean(r)}") | ||
env.close() | ||
print("total step:", total_step) | ||
print("total reward:", total_reward) | ||
|
||
|
||
if __name__ == "__main__": | ||
agent = train() | ||
evaluation(agent) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,13 +36,22 @@ | |
new_kwargs["env_num"] = env_num | ||
if id.startswith("ALE/") or id in gym.envs.registry.keys(): | ||
new_kwargs.pop("cfg", None) | ||
|
||
env = make( | ||
id, | ||
render_mode=env_render_mode, | ||
disable_env_checker=_disable_env_checker, | ||
**new_kwargs, | ||
) | ||
if "envpool" in new_kwargs: | ||
# for now envpool doesnt support any render mode | ||
# envpool also doesnt stores the id anywhere | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Envpool can not be installed on many platforms (such as Windows), so you should use the envpool as an external env which is out of OpenRL. Moreover, with the envpool as an external env, you don't need to write test code for codecov, because we only track the code in the openrl folder. Examples: |
||
new_kwargs.pop("envpool") | ||
env = make( | ||
id, | ||
**new_kwargs, | ||
) | ||
env.unwrapped.spec.id = id | ||
else: | ||
env = make( | ||
id, | ||
render_mode=env_render_mode, | ||
disable_env_checker=_disable_env_checker, | ||
**new_kwargs, | ||
) | ||
|
||
if wrappers is not None: | ||
if callable(wrappers): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
"""""" | ||
from typing import Callable, Optional | ||
|
||
import envpool | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Envpool can not be installed on many platforms (such as Windows), so you should use the envpool as an external env which is out of OpenRL. Moreover, with the envpool as an external env, you don't need to write test code for codecov, because we only track the code in the openrl folder. Examples: |
||
import gymnasium as gym | ||
|
||
import openrl | ||
|
@@ -72,7 +73,6 @@ | |
env_fns = make_single_agent_drone_envs( | ||
id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs | ||
) | ||
|
||
elif id.startswith("snakes_"): | ||
from openrl.envs.snake import make_snake_envs | ||
|
||
|
@@ -155,6 +155,18 @@ | |
env_fns = make_PettingZoo_envs( | ||
id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs | ||
) | ||
elif ( | ||
"envpool:" in id | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Envpool can not be installed on many platforms (such as Windows), so you should use the envpool as an external env which is out of OpenRL. Moreover, with the envpool as an external env, you don't need to write test code for codecov, because we only track the code in the openrl folder. Examples: |
||
and id.split(":")[-1] in envpool.registration.list_all_envs() | ||
): | ||
from openrl.envs.envpool import make_envpool_envs | ||
|
||
env_fns = make_envpool_envs( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Envpool can not be installed on many platforms (such as Windows), so you should use the envpool as an external env which is out of OpenRL. Moreover, with the envpool as an external env, you don't need to write test code for codecov, because we only track the code in the openrl folder. Examples: |
||
id=id.split(":")[-1], | ||
env_num=env_num, | ||
render_mode=convert_render_mode, | ||
**kwargs, | ||
) | ||
else: | ||
raise NotImplementedError(f"env {id} is not supported.") | ||
|
||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should move to examples/envpool |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# Copyright 2023 The OpenRL Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""""" | ||
from typing import List, Optional, Union | ||
|
||
import envpool | ||
|
||
from openrl.envs.common import build_envs | ||
|
||
|
||
def make_envpool_envs( | ||
id: str, | ||
env_num: int = 1, | ||
render_mode: Optional[Union[str, List[str]]] = None, | ||
**kwargs, | ||
): | ||
assert "env_type" in kwargs | ||
assert kwargs.get("env_type") in ["gym", "dm", "gymnasium"] | ||
# Since render_mode is not supported, we set envpool to True | ||
# so that we can remove render_mode keyword argument from build_envs | ||
assert render_mode is None, "envpool does not support render_mode yet" | ||
kwargs["envpool"] = True | ||
|
||
env_wrappers = kwargs.pop("env_wrappers") | ||
env_fns = build_envs( | ||
make=envpool.make, | ||
id=id, | ||
env_num=env_num, | ||
render_mode=render_mode, | ||
wrappers=env_wrappers, | ||
**kwargs, | ||
) | ||
return env_fns | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should move to examples/envpool |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
import time | ||
import warnings | ||
from typing import Optional | ||
|
||
import gym | ||
import gymnasium | ||
import numpy as np | ||
from envpool.python.protocol import EnvPool | ||
from packaging import version | ||
from stable_baselines3.common.vec_env import VecEnvWrapper as BaseWrapper | ||
from stable_baselines3.common.vec_env import VecMonitor | ||
from stable_baselines3.common.vec_env.base_vec_env import (VecEnvObs, | ||
VecEnvStepReturn) | ||
|
||
is_legacy_gym = version.parse(gym.__version__) < version.parse("0.26.0") | ||
|
||
|
||
class VecEnvWrapper(BaseWrapper): | ||
@property | ||
def agent_num(self): | ||
if self.is_original_envpool_env(): | ||
return 1 | ||
else: | ||
return self.env.agent_num | ||
|
||
def is_original_envpool_env(self): | ||
return not hasattr(self.venv, "agent_num`") | ||
|
||
|
||
class VecAdapter(VecEnvWrapper): | ||
""" | ||
Convert EnvPool object to a Stable-Baselines3 (SB3) VecEnv. | ||
|
||
:param venv: The envpool object. | ||
""" | ||
|
||
def __init__(self, venv: EnvPool): | ||
venv.num_envs = venv.spec.config.num_envs | ||
observation_space = venv.observation_space | ||
new_observation_space = gymnasium.spaces.Box( | ||
low=observation_space.low, | ||
high=observation_space.high, | ||
dtype=observation_space.dtype, | ||
) | ||
action_space = venv.action_space | ||
if isinstance(action_space, gym.spaces.Discrete): | ||
new_action_space = gymnasium.spaces.Discrete(action_space.n) | ||
elif isinstance(action_space, gym.spaces.MultiDiscrete): | ||
new_action_space = gymnasium.spaces.MultiDiscrete(action_space.nvec) | ||
elif isinstance(action_space, gym.spaces.MultiBinary): | ||
new_action_space = gymnasium.spaces.MultiBinary(action_space.n) | ||
elif isinstance(action_space, gym.spaces.Box): | ||
new_action_space = gymnasium.spaces.Box( | ||
low=action_space.low, | ||
high=action_space.high, | ||
dtype=action_space.dtype, | ||
) | ||
else: | ||
raise NotImplementedError(f"Action space {action_space} is not supported") | ||
super().__init__( | ||
venv=venv, | ||
observation_space=new_observation_space, | ||
action_space=new_action_space, | ||
) | ||
|
||
def step_async(self, actions: np.ndarray) -> None: | ||
self.actions = actions | ||
|
||
def reset(self) -> VecEnvObs: | ||
if is_legacy_gym: | ||
return self.venv.reset(), {} | ||
else: | ||
return self.venv.reset() | ||
|
||
def step_wait(self) -> VecEnvStepReturn: | ||
if is_legacy_gym: | ||
obs, rewards, dones, info_dict = self.venv.step(self.actions) | ||
else: | ||
obs, rewards, terms, truncs, info_dict = self.venv.step(self.actions) | ||
dones = terms + truncs | ||
rewards = rewards | ||
infos = [] | ||
for i in range(self.num_envs): | ||
infos.append( | ||
{ | ||
key: info_dict[key][i] | ||
for key in info_dict.keys() | ||
if isinstance(info_dict[key], np.ndarray) | ||
} | ||
) | ||
if dones[i]: | ||
infos[i]["terminal_observation"] = obs[i] | ||
if is_legacy_gym: | ||
obs[i] = self.venv.reset(np.array([i])) | ||
else: | ||
obs[i] = self.venv.reset(np.array([i]))[0] | ||
return obs, rewards, dones, infos | ||
|
||
|
||
class VecMonitor(VecEnvWrapper): | ||
def __init__( | ||
self, | ||
venv, | ||
filename: Optional[str] = None, | ||
info_keywords=(), | ||
): | ||
# Avoid circular import | ||
from stable_baselines3.common.monitor import Monitor, ResultsWriter | ||
|
||
try: | ||
is_wrapped_with_monitor = venv.env_is_wrapped(Monitor)[0] | ||
except AttributeError: | ||
is_wrapped_with_monitor = False | ||
|
||
if is_wrapped_with_monitor: | ||
warnings.warn( | ||
"The environment is already wrapped with a `Monitor` wrapper" | ||
"but you are wrapping it with a `VecMonitor` wrapper, the `Monitor` statistics will be" | ||
"overwritten by the `VecMonitor` ones.", | ||
UserWarning, | ||
) | ||
|
||
VecEnvWrapper.__init__(self, venv) | ||
self.episode_count = 0 | ||
self.t_start = time.time() | ||
|
||
env_id = None | ||
if hasattr(venv, "spec") and venv.spec is not None: | ||
env_id = venv.spec.id | ||
|
||
self.results_writer: Optional[ResultsWriter] = None | ||
if filename: | ||
self.results_writer = ResultsWriter( | ||
filename, | ||
header={"t_start": self.t_start, "env_id": str(env_id)}, | ||
extra_keys=info_keywords, | ||
) | ||
|
||
self.info_keywords = info_keywords | ||
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) | ||
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) | ||
|
||
def reset(self, **kwargs) -> VecEnvObs: | ||
obs, info = self.venv.reset() | ||
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) | ||
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) | ||
return obs, info | ||
|
||
def step_wait(self) -> VecEnvStepReturn: | ||
obs, rewards, dones, infos = self.venv.step_wait() | ||
self.episode_returns += rewards | ||
self.episode_lengths += 1 | ||
new_infos = list(infos[:]) | ||
for i in range(len(dones)): | ||
if dones[i]: | ||
info = infos[i].copy() | ||
episode_return = self.episode_returns[i] | ||
episode_length = self.episode_lengths[i] | ||
episode_info = { | ||
"r": episode_return, | ||
"l": episode_length, | ||
"t": round(time.time() - self.t_start, 6), | ||
} | ||
for key in self.info_keywords: | ||
episode_info[key] = info[key] | ||
info["episode"] = episode_info | ||
self.episode_count += 1 | ||
self.episode_returns[i] = 0 | ||
self.episode_lengths[i] = 0 | ||
if self.results_writer: | ||
self.results_writer.write_row(episode_info) | ||
new_infos[i] = info | ||
rewards = np.expand_dims(rewards, 1) | ||
return obs, rewards, dones, new_infos | ||
|
||
def close(self) -> None: | ||
if self.results_writer: | ||
self.results_writer.close() | ||
return self.venv.close() | ||
|
||
|
||
__all__ = ["VecAdapter", "VecMonitor"] | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -76,6 +76,7 @@ def get_extra_requires() -> dict: | |
"async_timeout", | ||
"pettingzoo[classic]", | ||
"trueskill", | ||
"envpool", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. delete this. Add a README.md in examples/envpool. And show how to install dependencies in the markdown. |
||
], | ||
"selfplay_test": [ | ||
"ray[default]>=2.7", | ||
|
@@ -84,6 +85,7 @@ def get_extra_requires() -> dict: | |
"fastapi", | ||
"pettingzoo[mpe]", | ||
"pettingzoo[butterfly]", | ||
"envpool", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. delete this. Add a README.md in examples/envpool. And show how to install dependencies in the markdown. |
||
], | ||
"retro": ["gym-retro"], | ||
"super_mario": ["gym-super-mario-bros"], | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Envpool can not be installed on many platforms (such as Windows), so you should use the envpool as an external env which is out of OpenRL. Moreover, with the envpool as an external env, you don't need to write test code for codecov, because we only track the code in the openrl folder.
Examples: