
Commit

Merge pull request #258 from huangshiyu13/main
fix arena petting zoo import error
huangshiyu13 authored Oct 23, 2023
2 parents 7e70569 + 0228e51 commit 962d45a
Showing 8 changed files with 32 additions and 42 deletions.
openrl/algorithms/dqn.py (4 changes: 3 additions & 1 deletion)
@@ -167,7 +167,9 @@ def prepare_loss(
)

q_targets = rewards_batch + self.gamma * max_next_q_values * next_masks_batch
q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach())) # mean squared error loss
q_loss = torch.mean(
F.mse_loss(q_values, q_targets.detach())
) # mean squared error loss

loss_list.append(q_loss)

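
The reformatted statement above is the standard DQN temporal-difference update: the target is r + gamma * max_a' Q(s', a'), masked where the episode ended, and detach() keeps gradients out of the bootstrapped target. A self-contained sketch with dummy tensors (batch size, shapes, and variable names here are illustrative, not taken from openrl):

import torch
import torch.nn.functional as F

gamma = 0.99
batch_size = 4

q_values = torch.rand(batch_size, 1, requires_grad=True)  # Q(s, a) from the online network
max_next_q_values = torch.rand(batch_size, 1)  # max_a' Q(s', a') from the target network
rewards_batch = torch.rand(batch_size, 1)
next_masks_batch = torch.ones(batch_size, 1)  # 0.0 where the episode terminated

# Same target as in the hunk: r + gamma * max_a' Q(s', a') * mask
q_targets = rewards_batch + gamma * max_next_q_values * next_masks_batch

# detach() blocks gradients from flowing into the target
q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach()))
q_loss.backward()
print(q_loss.item())
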
openrl/algorithms/vdn.py (4 changes: 3 additions & 1 deletion)
@@ -211,7 +211,9 @@ def prepare_loss(
rewards_batch = rewards_batch.reshape(-1, self.n_agent, 1)
rewards_batch = torch.sum(rewards_batch, dim=1, keepdim=True).view(-1, 1)
q_targets = rewards_batch + self.gamma * max_next_q_values * next_masks_batch
q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach())) # mean squared error loss
q_loss = torch.mean(
F.mse_loss(q_values, q_targets.detach())
) # mean squared error loss

loss_list.append(q_loss)
return loss_list
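
The VDN hunk is the same target computation, except per-agent rewards are first summed into a single team reward. A small shape sketch (the agent count and values below are made up):

import torch

n_agent, batch = 2, 3
rewards_batch = torch.arange(batch * n_agent, dtype=torch.float32).view(-1, 1)

# (batch * n_agent, 1) -> (batch, n_agent, 1): group per-agent rewards by step
rewards_batch = rewards_batch.reshape(-1, n_agent, 1)

# Sum over the agent dimension, then flatten back to (batch, 1)
team_rewards = torch.sum(rewards_batch, dim=1, keepdim=True).view(-1, 1)
print(team_rewards)  # tensor([[1.], [5.], [9.]])
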
openrl/arena/__init__.py (4 changes: 3 additions & 1 deletion)
@@ -30,9 +30,11 @@ def make_arena(
**kwargs,
):
if custom_build_env is None:
from openrl.envs import PettingZoo

if (
env_id in pettingzoo_all_envs
or env_id in openrl.envs.PettingZoo.registration.pettingzoo_env_dict.keys()
or env_id in PettingZoo.registration.pettingzoo_env_dict.keys()
):
from openrl.envs.PettingZoo import make_PettingZoo_env

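
This hunk is the fix named in the PR title: the attribute chain `openrl.envs.PettingZoo.registration` only resolves if the PettingZoo submodule has already been imported somewhere, while `from openrl.envs import PettingZoo` forces Python to load it. The same pitfall can be reproduced with the standard library alone, in a fresh interpreter (xml and xml.dom stand in for openrl.envs and its PettingZoo submodule):

import xml

try:
    xml.dom  # parent package is imported, the submodule is not: AttributeError
except AttributeError as err:
    print("attribute access fails:", err)

from xml import dom  # explicitly importing the submodule loads it

print(xml.dom is dom)  # True: the attribute chain now resolves as well
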
openrl/envs/mpe/rendering.py (10 changes: 4 additions & 6 deletions)
@@ -31,12 +31,10 @@
except ImportError:
print(
"Error occured while running `from pyglet.gl import *`",
(
"HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get"
" install python-opengl'. If you're running on a server, you may need a"
" virtual frame buffer; something like this should work: 'xvfb-run -s"
' "-screen 0 1400x900x24" python <your_script.py>\''
),
"HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get"
" install python-opengl'. If you're running on a server, you may need a"
" virtual frame buffer; something like this should work: 'xvfb-run -s"
' "-screen 0 1400x900x24" python <your_script.py>\'',
)

import math
openrl/envs/snake/snake.py (4 changes: 3 additions & 1 deletion)
@@ -674,7 +674,9 @@ class Snake:
def __init__(self, player_id, board_width, board_height, init_len):
self.actions = [-2, 2, -1, 1]
self.actions_name = {-2: "up", 2: "down", -1: "left", 1: "right"}
self.direction = random.choice(self.actions) # directions [-2,2,-1,1] correspond to [up, down, left, right]
self.direction = random.choice(
self.actions
) # directions [-2,2,-1,1] correspond to [up, down, left, right]
self.board_width = board_width
self.board_height = board_height
x = random.randrange(0, board_height)
openrl/envs/vec_env/async_venv.py (34 changes: 11 additions & 23 deletions)
@@ -234,10 +234,8 @@ def reset_send(

if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
(
"Calling `reset_send` while waiting for a pending call to"
f" `{self._state.value}` to complete"
),
"Calling `reset_send` while waiting for a pending call to"
f" `{self._state.value}` to complete",
self._state.value,
)

@@ -329,10 +327,8 @@ def step_send(self, actions: np.ndarray):
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
(
"Calling `step_send` while waiting for a pending call to"
f" `{self._state.value}` to complete."
),
"Calling `step_send` while waiting for a pending call to"
f" `{self._state.value}` to complete.",
self._state.value,
)

@@ -342,9 +338,7 @@ def step_send(self, actions: np.ndarray):
pipe.send(("step", action))
self._state = AsyncState.WAITING_STEP

def step_fetch(
self, timeout: Optional[Union[int, float]] = None
) -> Union[
def step_fetch(self, timeout: Optional[Union[int, float]] = None) -> Union[
Tuple[Any, NDArray[Any], NDArray[Any], List[Dict[str, Any]]],
Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], List[Dict[str, Any]]],
]:
@@ -576,10 +570,8 @@ def call_send(self, name: str, *args, **kwargs):
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
(
"Calling `call_send` while waiting "
f"for a pending call to `{self._state.value}` to complete."
),
"Calling `call_send` while waiting "
f"for a pending call to `{self._state.value}` to complete.",
str(self._state.value),
)

@@ -636,10 +628,8 @@ def exec_func_send(self, func: Callable, indices, *args, **kwargs):
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
(
"Calling `exec_func_send` while waiting "
f"for a pending call to `{self._state.value}` to complete."
),
"Calling `exec_func_send` while waiting "
f"for a pending call to `{self._state.value}` to complete.",
str(self._state.value),
)

@@ -717,10 +707,8 @@ def set_attr(self, name: str, values: Union[List[Any], Tuple[Any], object]):

if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
(
"Calling `set_attr` while waiting "
f"for a pending call to `{self._state.value}` to complete."
),
"Calling `set_attr` while waiting "
f"for a pending call to `{self._state.value}` to complete.",
str(self._state.value),
)

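
All of the async_venv.py hunks reformat the same guard: each `*_send` method raises AlreadyPendingCallError if another call is still pending, and the matching `*_fetch` resets the state. A self-contained sketch of that send/fetch state machine (the ToyAsyncEnv class and its behaviour are illustrative only, not OpenRL's implementation):

from enum import Enum


class AsyncState(Enum):
    DEFAULT = "default"
    WAITING_STEP = "step"


class AlreadyPendingCallError(RuntimeError):
    def __init__(self, message: str, pending: str):
        super().__init__(message)
        self.pending = pending


class ToyAsyncEnv:
    """Send/fetch pair guarded the same way as the methods in the diff."""

    def __init__(self):
        self._state = AsyncState.DEFAULT
        self._result = None

    def step_send(self, action):
        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                "Calling `step_send` while waiting for a pending call to"
                f" `{self._state.value}` to complete.",
                self._state.value,
            )
        self._result = action * 2  # stand-in for dispatching work to worker processes
        self._state = AsyncState.WAITING_STEP

    def step_fetch(self):
        result, self._result = self._result, None
        self._state = AsyncState.DEFAULT
        return result


env = ToyAsyncEnv()
env.step_send(3)
print(env.step_fetch())  # 6

try:
    env.step_send(4)
    env.step_send(5)  # a second send before the fetch: the guard fires
except AlreadyPendingCallError as err:
    print("pending call:", err.pending)
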
openrl/utils/callbacks/checkpoint_callback.py (4 changes: 1 addition & 3 deletions)
@@ -72,9 +72,7 @@ def _checkpoint_path(self, checkpoint_type: str = "", extension: str = "") -> st
"""
return os.path.join(
self.save_path,
(
f"{self.name_prefix}_{checkpoint_type}{self.num_time_steps}_steps{'.' if extension else ''}{extension}"
),
f"{self.name_prefix}_{checkpoint_type}{self.num_time_steps}_steps{'.' if extension else ''}{extension}",
)

def _on_step(self) -> bool:
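
The checkpoint hunk only drops redundant parentheses around the f-string, so the generated filename is unchanged. A quick worked example of the pattern (the prefix, step count, and extension below are made-up values, not library defaults):

import os

save_path = "./checkpoints"
name_prefix = "rl_model"
checkpoint_type = ""
num_time_steps = 50000
extension = "zip"

path = os.path.join(
    save_path,
    f"{name_prefix}_{checkpoint_type}{num_time_steps}_steps{'.' if extension else ''}{extension}",
)
print(path)  # e.g. ./checkpoints/rl_model_50000_steps.zip
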
openrl/utils/evaluation.py (10 changes: 4 additions & 6 deletions)
@@ -68,12 +68,10 @@ def evaluate_policy(

if not is_monitor_wrapped and warn:
warnings.warn(
(
"Evaluation environment is not wrapped with a ``Monitor`` wrapper. This"
" may result in reporting modified episode lengths and rewards, if"
" other wrappers happen to modify these. Consider wrapping environment"
" first with ``Monitor`` wrapper."
),
"Evaluation environment is not wrapped with a ``Monitor`` wrapper. This"
" may result in reporting modified episode lengths and rewards, if"
" other wrappers happen to modify these. Consider wrapping environment"
" first with ``Monitor`` wrapper.",
UserWarning,
)

