From 8c196b39d1755dee7cca44a581d9d61c6fab8fdb Mon Sep 17 00:00:00 2001 From: huangshiyu Date: Mon, 13 Nov 2023 15:38:17 +0800 Subject: [PATCH 1/2] update --- examples/custom_env/pettingzoo_env.py | 2 +- examples/custom_env/rock_paper_scissors.py | 1 + openrl/envs/vec_env/sync_venv.py | 2 + openrl/envs/wrappers/extra_wrappers.py | 79 +++++++++++++++++++ .../wrappers/base_multiplayer_wrapper.py | 1 + 5 files changed, 84 insertions(+), 1 deletion(-) diff --git a/examples/custom_env/pettingzoo_env.py b/examples/custom_env/pettingzoo_env.py index d5644b7b..64211512 100644 --- a/examples/custom_env/pettingzoo_env.py +++ b/examples/custom_env/pettingzoo_env.py @@ -28,7 +28,7 @@ env = make( "RockPaperScissors", - env_num=10, + env_num=1, opponent_wrappers=[RandomOpponentWrapper], ) diff --git a/examples/custom_env/rock_paper_scissors.py b/examples/custom_env/rock_paper_scissors.py index 71d0fa74..f18e1841 100644 --- a/examples/custom_env/rock_paper_scissors.py +++ b/examples/custom_env/rock_paper_scissors.py @@ -18,6 +18,7 @@ import functools +import time import gymnasium import numpy as np diff --git a/openrl/envs/vec_env/sync_venv.py b/openrl/envs/vec_env/sync_venv.py index 6a61d489..1e208e4c 100644 --- a/openrl/envs/vec_env/sync_venv.py +++ b/openrl/envs/vec_env/sync_venv.py @@ -15,6 +15,7 @@ # limitations under the License. """""" +import time from copy import deepcopy from typing import Any, Callable, Iterable, List, Optional, Sequence, Union @@ -202,6 +203,7 @@ def _step(self, actions: ActType): self._truncateds[i], info, ) = returns + need_reset = _need_reset and ( all(self._terminateds[i]) or all(self._truncateds[i]) ) diff --git a/openrl/envs/wrappers/extra_wrappers.py b/openrl/envs/wrappers/extra_wrappers.py index da819a87..27359d9e 100644 --- a/openrl/envs/wrappers/extra_wrappers.py +++ b/openrl/envs/wrappers/extra_wrappers.py @@ -21,6 +21,9 @@ import gymnasium as gym import numpy as np from gymnasium import spaces +from gymnasium.utils.step_api_compatibility import ( + convert_to_terminated_truncated_step_api, +) from gymnasium.wrappers import AutoResetWrapper, StepAPICompatibility from openrl.envs.wrappers import BaseObservationWrapper, BaseRewardWrapper, BaseWrapper @@ -46,6 +49,76 @@ def step(self, action): return obs, total_reward, term, trunc, info +def convert_to_done_step_api( + step_returns, + is_vector_env: bool = False, +): + if len(step_returns) == 4: + return step_returns + else: + assert len(step_returns) == 5 + observations, rewards, terminated, truncated, infos = step_returns + + # Cases to handle - info single env / info vector env (list) / info vector env (dict) + # if truncated[0]: + # import pdb; + # pdb.set_trace() + + if is_vector_env is False: + if isinstance(terminated, list): + infos["TimeLimit.truncated"] = truncated[0] and not terminated[0] + done_return = np.logical_or(terminated, truncated) + else: + if truncated or terminated: + infos["TimeLimit.truncated"] = truncated and not terminated + done_return = terminated or truncated + return ( + observations, + rewards, + done_return, + infos, + ) + elif isinstance(infos, list): + for info, env_truncated, env_terminated in zip( + infos, truncated, terminated + ): + if env_truncated or env_terminated: + info["TimeLimit.truncated"] = env_truncated and not env_terminated + return ( + observations, + rewards, + np.logical_or(terminated, truncated), + infos, + ) + elif isinstance(infos, dict): + if np.logical_or(np.any(truncated), np.any(terminated)): + infos["TimeLimit.truncated"] = np.logical_and( + truncated, np.logical_not(terminated) + ) + return ( + observations, + rewards, + np.logical_or(terminated, truncated), + infos, + ) + else: + raise TypeError( + "Unexpected value of infos, as is_vector_envs=False, expects `info` to" + f" be a list or dict, actual type: {type(infos)}" + ) + + +def step_api_compatibility( + step_returns, + output_truncation_bool: bool = True, + is_vector_env: bool = False, +): + if output_truncation_bool: + return convert_to_terminated_truncated_step_api(step_returns, is_vector_env) + else: + return convert_to_done_step_api(step_returns, is_vector_env) + + class RemoveTruncated(StepAPICompatibility, BaseWrapper): def __init__( self, @@ -54,6 +127,12 @@ def __init__( output_truncation_bool = False super().__init__(env, output_truncation_bool=output_truncation_bool) + def step(self, action): + step_returns = self.env.step(action) + return step_api_compatibility( + step_returns, self.output_truncation_bool, self.is_vector_env + ) + class FlattenObservation(BaseObservationWrapper): def __init__(self, env: gym.Env): diff --git a/openrl/selfplay/wrappers/base_multiplayer_wrapper.py b/openrl/selfplay/wrappers/base_multiplayer_wrapper.py index 5cd6116c..ca8d1e95 100644 --- a/openrl/selfplay/wrappers/base_multiplayer_wrapper.py +++ b/openrl/selfplay/wrappers/base_multiplayer_wrapper.py @@ -104,6 +104,7 @@ def reset(self, *, seed: Optional[int] = None, **kwargs): action = self.get_opponent_action( player_name, observation, reward, termination, truncation, info ) + self.env.step(action) def on_episode_end( From 6e9ce0f6a81309acecaee6f09924d1b788d49219 Mon Sep 17 00:00:00 2001 From: huangshiyu Date: Mon, 13 Nov 2023 15:41:47 +0800 Subject: [PATCH 2/2] fix petting zoo --- examples/custom_env/pettingzoo_env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/custom_env/pettingzoo_env.py b/examples/custom_env/pettingzoo_env.py index 64211512..d5644b7b 100644 --- a/examples/custom_env/pettingzoo_env.py +++ b/examples/custom_env/pettingzoo_env.py @@ -28,7 +28,7 @@ env = make( "RockPaperScissors", - env_num=1, + env_num=10, opponent_wrappers=[RandomOpponentWrapper], )