Patch summary (free text before the first "diff --git" header; ignored by git apply / patch):
  * examples/custom_env/pettingzoo_env.py: add a blank line before the `env = make(...)` call.
  * examples/custom_env/rock_paper_scissors.py: RockPaperScissors.__init__ takes an `id`
    argument; observe() returns a one-hot array of shape [1, 4] instead of (4,).
  * openrl/selfplay/wrappers/base_multiplayer_wrapper.py: on termination/truncation, _step
    falls back to reward 0 / info {} when self_player is missing from env.rewards / env.infos.
    The info fallback tests membership in env.infos (it previously tested env.rewards, which
    could still raise KeyError when the player was present in rewards but not in infos).
NOTE(review): line breaks and indentation were reconstructed from a whitespace-collapsed copy
of this patch; verify the context lines against the target tree before applying.

diff --git a/examples/custom_env/pettingzoo_env.py b/examples/custom_env/pettingzoo_env.py
index 8b173449..d5644b7b 100644
--- a/examples/custom_env/pettingzoo_env.py
+++ b/examples/custom_env/pettingzoo_env.py
@@ -25,6 +25,7 @@ from openrl.selfplay.wrappers.random_opponent_wrapper import RandomOpponentWrapper
 
 register("RockPaperScissors", RockPaperScissors)
 
+
 env = make(
     "RockPaperScissors",
     env_num=10,
diff --git a/examples/custom_env/rock_paper_scissors.py b/examples/custom_env/rock_paper_scissors.py
index 7d5649d1..71d0fa74 100644
--- a/examples/custom_env/rock_paper_scissors.py
+++ b/examples/custom_env/rock_paper_scissors.py
@@ -54,7 +54,7 @@ class RockPaperScissors(AECEnv):
 
     metadata = {"render_modes": ["human"], "name": "rps_v2"}
 
-    def __init__(self, render_mode=None):
+    def __init__(self, id, render_mode=None):
         """
         The init method takes in environment arguments and
         should define the following attributes:
@@ -122,8 +122,8 @@ def observe(self, agent):
         """
         # observation of one agent is the previous state of the other
         # return np.array(self.observations[agent])
-        obs = np.zeros(4, dtype=np.int64)
-        obs[self.observations[agent]] = 1
+        obs = np.zeros([1, 4], dtype=np.int64)
+        obs[0, self.observations[agent]] = 1
         return obs
 
     def close(self):
diff --git a/openrl/selfplay/wrappers/base_multiplayer_wrapper.py b/openrl/selfplay/wrappers/base_multiplayer_wrapper.py
index a3de3c0f..5cd6116c 100644
--- a/openrl/selfplay/wrappers/base_multiplayer_wrapper.py
+++ b/openrl/selfplay/wrappers/base_multiplayer_wrapper.py
@@ -147,10 +147,18 @@ def _step(self, action):
             if termination or truncation:
                 return (
                     copy.copy(self.env.observe(self.self_player)),
-                    self.env.rewards[self.self_player],
+                    (
+                        self.env.rewards[self.self_player]
+                        if self.self_player in self.env.rewards
+                        else 0
+                    ),
                     termination,
                     truncation,
-                    self.env.infos[self.self_player],
+                    (
+                        self.env.infos[self.self_player]
+                        if self.self_player in self.env.infos
+                        else {}
+                    ),
                 )
 
             else: