Skip to content

Commit

Permalink
add selfplay test
Browse files Browse the repository at this point in the history
  • Loading branch information
huangshiyu13 committed Nov 23, 2023
1 parent 7373b04 commit 7ded5d5
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 7 deletions.
2 changes: 1 addition & 1 deletion examples/selfplay/selfplay.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
globals:
selfplay_api_host: 127.0.0.1
selfplay_api_port: 10086
selfplay_api_port: 13486

seed: 0
selfplay_api:
Expand Down
6 changes: 5 additions & 1 deletion openrl/selfplay/callbacks/selfplay_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ def _init_callback(self) -> None:
success = self.api_client.set_sample_strategy(self.sample_strategy)
try_time -= 1
if try_time <= 0:
raise RuntimeError("Failed to set sample strategy.")
raise RuntimeError(
f"Failed to set sample strategy: {self.sample_strategy}. host:"
f" {self.host}, port: {self.port}"
)

def _on_step(self) -> bool:
# print("To send request to API server.")
Expand All @@ -72,5 +75,6 @@ def _on_training_end(self) -> None:
print(f"deleting {application_name}")
serve.delete(application_name)
del self.bind
serve.shutdown()
if self.verbose >= 2:
print(f"delete {application_name} done!")
10 changes: 8 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
def get_install_requires() -> list:
return [
"setuptools>=67.0",
"gymnasium",
"gymnasium>=0.29",
"click",
"termcolor",
"gym",
Expand Down Expand Up @@ -71,7 +71,13 @@ def get_extra_requires() -> dict:
"evaluate",
],
"selfplay": ["ray[default]", "ray[serve]", "pettingzoo[classic]", "trueskill"],
"selfplay_test": ["pettingzoo[mpe]", "pettingzoo[butterfly]"],
"selfplay_test": [
"ray[default]",
"ray[serve]",
"fastapi",
"pettingzoo[mpe]",
"pettingzoo[butterfly]",
],
"retro": ["gym-retro"],
"super_mario": ["gym-super-mario-bros"],
"atari": ["gymnasium[atari]", "gymnasium[accept-rom-license]"],
Expand Down
14 changes: 11 additions & 3 deletions tests/test_selfplay/test_train_selfplay.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import pytest
import ray
import torch

from openrl.configs.config import create_config_parser
Expand All @@ -18,22 +19,29 @@
@pytest.fixture(
scope="module",
params=[
"RandomOpponent",
"LastOpponent",
{"port": 13486, "strategy": "RandomOpponent"},
{"port": 13487, "strategy": "LastOpponent"},
],
)
def config(request):
cfg_parser = create_config_parser()
cfg = cfg_parser.parse_args(["--config", "./examples/selfplay/selfplay.yaml"])
cfg.selfplay_api.port = request.param["port"]
for i, c in enumerate(cfg.callbacks):
if c["id"] == "SelfplayCallback":
c["args"][
"opponent_template"
] = "./examples/selfplay/opponent_templates/tictactoe_opponent"
port = c["args"]["api_address"].split(":")[-1].split("/")[0]
c["args"]["api_address"] = c["args"]["api_address"].replace(
port, str(request.param["port"])
)
cfg.callbacks[i] = c
elif c["id"] == "SelfplayAPI":
c["args"]["sample_strategy"] = request.param
c["args"]["sample_strategy"] = request.param["strategy"]
c["args"]["port"] = request.param["port"]
cfg.callbacks[i] = c

else:
pass

Expand Down

0 comments on commit 7ded5d5

Please sign in to comment.