Skip to content

Commit

Permalink
arena add test more envs
Browse files Browse the repository at this point in the history
arena add test more envs
  • Loading branch information
huangshiyu13 authored Oct 24, 2023
2 parents 962d45a + 537f822 commit 73c8108
Show file tree
Hide file tree
Showing 8 changed files with 264 additions and 11 deletions.
9 changes: 9 additions & 0 deletions examples/arena/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

```bash
pip install "openrl[selfplay]"
pip install "pettingzoo[mpe]" "pettingzoo[butterfly]"
```

### Usage
Expand All @@ -15,3 +16,11 @@ python run_arena.py
### Evaluate Google Research Football submissions for JiDi locally

If you want to evaluate your Google Research Football submissions for JiDi locally, please try to use tizero as illustrated [here](https://github.com/OpenRL-Lab/TiZero#evaluate-jidi-submissions-locally).

### Evaluate more environments

We also provide a script to evaluate more environments, including MPE, Go, Texas Hold'em, and Butterfly. You can run the script as follows:

```shell
python evaluate_more_envs.py
```
122 changes: 122 additions & 0 deletions examples/arena/evaluate_more_envs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluate arena matches on additional PettingZoo environments (MPE, Go,
Texas Hold'em, Butterfly) using two random agents."""

from pettingzoo.butterfly import cooperative_pong_v5
from pettingzoo.classic import connect_four_v3, go_v5, texas_holdem_no_limit_v6
from pettingzoo.mpe import simple_push_v3

from examples.custom_env.rock_paper_scissors import RockPaperScissors
from openrl.arena import make_arena
from openrl.arena.agents.local_agent import LocalAgent
from openrl.envs.PettingZoo.registration import register
from openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner


def ConnectFourEnv(render_mode, **kwargs):
    """Factory for the PettingZoo Connect Four environment.

    Args:
        render_mode: Render mode forwarded to the environment (e.g. ``None``,
            ``"human"``, ``"rgb_array"``).
        **kwargs: Ignored; accepted so all env factories share one signature.
    """
    # Pass render_mode by keyword: this matches the sibling factories below
    # and stays robust if the PettingZoo factory only accepts **kwargs.
    return connect_four_v3.env(render_mode=render_mode)


def RockPaperScissorsEnv(render_mode, **kwargs):
    """Factory for the custom Rock-Paper-Scissors example environment.

    Extra keyword arguments are accepted but unused, keeping the signature
    uniform with the other environment factories in this module.
    """
    env = RockPaperScissors(render_mode)
    return env


def GoEnv(render_mode, board_size=5, komi=7.5, **kwargs):
    """Factory for the PettingZoo Go environment.

    Args:
        render_mode: Render mode forwarded to the environment.
        board_size: Edge length of the Go board. Defaults to 5 (same value as
            before the parameterization) so evaluation games stay short.
        komi: Compensation points for the second player. Defaults to 7.5.
        **kwargs: Ignored; accepted so all env factories share one signature.
    """
    return go_v5.env(render_mode=render_mode, board_size=board_size, komi=komi)


def TexasHoldemEnv(render_mode, **kwargs):
    """Factory for the PettingZoo no-limit Texas Hold'em environment.

    Extra keyword arguments are accepted but unused, keeping the signature
    uniform with the other environment factories in this module.
    """
    env = texas_holdem_no_limit_v6.env(render_mode=render_mode)
    return env


# MPE
def SimplePushEnv(render_mode, **kwargs):
    """Factory for the MPE simple-push environment.

    Extra keyword arguments are accepted but unused, keeping the signature
    uniform with the other environment factories in this module.
    """
    env = simple_push_v3.env(render_mode=render_mode)
    return env


def CooperativePongEnv(render_mode, **kwargs):
    """Factory for the Butterfly cooperative-pong environment.

    Extra keyword arguments are accepted but unused, keeping the signature
    uniform with the other environment factories in this module.
    """
    env = cooperative_pong_v5.env(render_mode=render_mode)
    return env


def register_new_envs():
    """Register the extra PettingZoo/custom environments with OpenRL.

    Returns:
        A view of the registered environment ids.
    """
    new_env_dict = {
        "connect_four_v3": ConnectFourEnv,
        "RockPaperScissors": RockPaperScissorsEnv,
        "go_v5": GoEnv,
        "texas_holdem_no_limit_v6": TexasHoldemEnv,
        "simple_push_v3": SimplePushEnv,
        "cooperative_pong_v5": CooperativePongEnv,
    }

    for env_id in new_env_dict:
        register(env_id, new_env_dict[env_id])
    return new_env_dict.keys()


def run_arena(
    env_id: str,
    parallel: bool = True,
    seed=0,
    total_games: int = 10,
    max_game_onetime: int = 5,
):
    """Play a series of games between two random agents on ``env_id``.

    Args:
        env_id: Id of a registered arena environment.
        parallel: Whether the arena runs games in parallel.
        seed: Seed passed to the arena for reproducibility.
        total_games: Total number of games to play.
        max_game_onetime: Maximum number of games run at the same time.

    Returns:
        The result object produced by the arena run (also printed to stdout).
    """
    arena = make_arena(env_id, env_wrappers=[RecordWinner], use_tqdm=False)

    # Both contenders are instances of the same random-opponent template.
    template_path = "../selfplay/opponent_templates/random_opponent"
    contenders = {
        "agent1": LocalAgent(template_path),
        "agent2": LocalAgent(template_path),
    }

    arena.reset(
        agents=contenders,
        total_games=total_games,
        max_game_onetime=max_game_onetime,
        seed=seed,
    )
    result = arena.run(parallel=parallel)
    arena.close()
    print(result)
    return result


def test_new_envs():
    """Smoke-test each newly registered environment with one serial game."""
    for env_id in register_new_envs():
        run_arena(env_id=env_id, seed=0, parallel=False, total_games=1)


# Script entry point: run the smoke test over all newly registered envs.
if __name__ == "__main__":
    test_new_envs()
1 change: 1 addition & 0 deletions examples/custom_env/rock_paper_scissors.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def step(self, action):
# handles stepping an agent which is already dead
# accepts a None action for the one agent, and moves the agent_selection to
# the next dead agent, or if there are no more dead agents, to the next live agent
action = None
self._was_dead_step(action)
return

Expand Down
14 changes: 8 additions & 6 deletions openrl/arena/games/two_player_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ def default_dispatch_func(
players: List[str],
agent_names: List[str],
) -> Dict[str, str]:
assert len(players) == len(
agent_names
), "The number of players must be equal to the number of agents."
assert len(players) == len(agent_names), (
f"The number of players {len(players)} must be equal to the number of"
f" agents: {len(agent_names)}."
)
assert len(players) == 2, "The number of players must be equal to 2."
np_random.shuffle(agent_names)
return dict(zip(players, agent_names))
Expand All @@ -49,20 +50,21 @@ def _run(self, env_fn: Callable, agents: List[BaseAgent]):
for player, agent in player2agent.items():
agent.reset(env, player)
result = {}
truncation_dict = {}
while True:
termination = False
info = {}
for player_name in env.agent_iter():
observation, reward, termination, truncation, info = env.last()

if termination:
truncation_dict[player_name] = truncation
if termination or all(truncation_dict.values()):
break
action = player2agent[player_name].act(
player_name, observation, reward, termination, truncation, info
)
env.step(action)

if termination:
if termination or all(truncation_dict.values()):
assert "winners" in info, "The game is terminated but no winners."
assert "losers" in info, "The game is terminated but no losers."

Expand Down
3 changes: 2 additions & 1 deletion openrl/envs/wrappers/pettingzoo_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,9 @@ def last(self, observe: bool = True):

winners = None
losers = None

for agent in self.terminations:
if self.terminations[agent]:
if self.terminations[agent] or all(self.truncations):
if winners is None:
winners = self.get_winners()
losers = [player for player in self.agents if player not in winners]
Expand Down
17 changes: 13 additions & 4 deletions openrl/selfplay/opponents/random_opponent.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,20 @@ def _sample_random_action(
action = []

for obs, space in zip(observation, action_space):
mask = obs.get("action_mask", None)
action.append(space.sample(mask))
if termination or truncation:
action.append(None)
else:
if isinstance(obs, dict):
mask = obs.get("action_mask", None)
else:
mask = None
action.append(space.sample(mask))
else:
mask = observation.get("action_mask", None)
action = action_space.sample(mask)
if termination or truncation:
action = None
else:
mask = observation.get("action_mask", None)
action = action_space.sample(mask)
return action

def _load(self, opponent_path: Union[str, Path]):
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,13 @@ def get_extra_requires() -> dict:
"evaluate",
],
"selfplay": ["ray[default]", "ray[serve]", "pettingzoo[classic]", "trueskill"],
"selfplay_test": ["pettingzoo[mpe]", "pettingzoo[butterfly]"],
"retro": ["gym-retro"],
"super_mario": ["gym-super-mario-bros"],
"atari": ["gymnasium[atari]", "gymnasium[accept-rom-license]"],
}
req["test"].extend(req["selfplay"])
req["test"].extend(req["selfplay_test"])
req["test"].extend(req["atari"])
req["test"].extend(req["nlp_test"])
return req
Expand Down
107 changes: 107 additions & 0 deletions tests/test_arena/test_new_envs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for arena evaluation on additional PettingZoo environments."""
import os
import sys

import pytest
from pettingzoo.butterfly import cooperative_pong_v5
from pettingzoo.classic import connect_four_v3, go_v5, texas_holdem_no_limit_v6
from pettingzoo.mpe import simple_push_v3

from examples.custom_env.rock_paper_scissors import RockPaperScissors
from openrl.arena import make_arena
from openrl.arena.agents.local_agent import LocalAgent
from openrl.envs.PettingZoo.registration import register
from openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner


def ConnectFourEnv(render_mode, **kwargs):
    """Factory for the PettingZoo Connect Four environment.

    Args:
        render_mode: Render mode forwarded to the environment (e.g. ``None``,
            ``"human"``, ``"rgb_array"``).
        **kwargs: Ignored; accepted so all env factories share one signature.
    """
    # Pass render_mode by keyword: this matches the sibling factories below
    # and stays robust if the PettingZoo factory only accepts **kwargs.
    return connect_four_v3.env(render_mode=render_mode)


def RockPaperScissorsEnv(render_mode, **kwargs):
    """Factory for the custom Rock-Paper-Scissors example environment.

    Extra keyword arguments are accepted but unused, keeping the signature
    uniform with the other environment factories in this module.
    """
    env = RockPaperScissors(render_mode)
    return env


def GoEnv(render_mode, board_size=5, komi=7.5, **kwargs):
    """Factory for the PettingZoo Go environment.

    Args:
        render_mode: Render mode forwarded to the environment.
        board_size: Edge length of the Go board. Defaults to 5 (same value as
            before the parameterization) so evaluation games stay short.
        komi: Compensation points for the second player. Defaults to 7.5.
        **kwargs: Ignored; accepted so all env factories share one signature.
    """
    return go_v5.env(render_mode=render_mode, board_size=board_size, komi=komi)


def TexasHoldemEnv(render_mode, **kwargs):
    """Factory for the PettingZoo no-limit Texas Hold'em environment.

    Extra keyword arguments are accepted but unused, keeping the signature
    uniform with the other environment factories in this module.
    """
    env = texas_holdem_no_limit_v6.env(render_mode=render_mode)
    return env


# MPE
def SimplePushEnv(render_mode, **kwargs):
    """Factory for the MPE simple-push environment.

    Extra keyword arguments are accepted but unused, keeping the signature
    uniform with the other environment factories in this module.
    """
    env = simple_push_v3.env(render_mode=render_mode)
    return env


def CooperativePongEnv(render_mode, **kwargs):
    """Factory for the Butterfly cooperative-pong environment.

    Extra keyword arguments are accepted but unused, keeping the signature
    uniform with the other environment factories in this module.
    """
    env = cooperative_pong_v5.env(render_mode=render_mode)
    return env


def register_new_envs():
    """Register the extra PettingZoo/custom environments with OpenRL.

    Returns:
        A view of the registered environment ids.
    """
    new_env_dict = {
        "connect_four_v3": ConnectFourEnv,
        "RockPaperScissors": RockPaperScissorsEnv,
        "go_v5": GoEnv,
        "texas_holdem_no_limit_v6": TexasHoldemEnv,
        "simple_push_v3": SimplePushEnv,
        "cooperative_pong_v5": CooperativePongEnv,
    }

    for env_id in new_env_dict:
        register(env_id, new_env_dict[env_id])
    return new_env_dict.keys()


def run_arena(
    env_id: str,
    parallel: bool = True,
    seed=0,
    total_games: int = 10,
    max_game_onetime: int = 5,
):
    """Play a series of games between two random agents on ``env_id``.

    Args:
        env_id: Id of a registered arena environment.
        parallel: Whether the arena runs games in parallel.
        seed: Seed passed to the arena for reproducibility.
        total_games: Total number of games to play.
        max_game_onetime: Maximum number of games run at the same time.

    Returns:
        The result object produced by the arena run.
    """
    arena = make_arena(env_id, env_wrappers=[RecordWinner], use_tqdm=False)

    # Both contenders are instances of the same random-opponent template.
    # Path is relative to the repository root, where pytest is invoked.
    template_path = "./examples/selfplay/opponent_templates/random_opponent"
    contenders = {
        "agent1": LocalAgent(template_path),
        "agent2": LocalAgent(template_path),
    }

    arena.reset(
        agents=contenders,
        total_games=total_games,
        max_game_onetime=max_game_onetime,
        seed=seed,
    )
    result = arena.run(parallel=parallel)
    arena.close()
    return result


@pytest.mark.unittest
def test_new_envs():
    """Smoke-test each newly registered environment with one serial game."""
    for env_id in register_new_envs():
        run_arena(env_id=env_id, seed=0, parallel=False, total_games=1)


# Allow running this test module directly: delegate to pytest and propagate
# its exit status to the shell.
if __name__ == "__main__":
    sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))

0 comments on commit 73c8108

Please sign in to comment.