Skip to content

Commit

Permalink
Merge pull request #272 from huangshiyu13/main
Browse files Browse the repository at this point in the history
fix petting zoo bugs
  • Loading branch information
huangshiyu13 authored Nov 13, 2023
2 parents b71b07b + 6e9ce0f commit 6cd773a
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 0 deletions.
1 change: 1 addition & 0 deletions examples/custom_env/rock_paper_scissors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@


import functools
import time

import gymnasium
import numpy as np
Expand Down
2 changes: 2 additions & 0 deletions openrl/envs/vec_env/sync_venv.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.

""""""
import time
from copy import deepcopy
from typing import Any, Callable, Iterable, List, Optional, Sequence, Union

Expand Down Expand Up @@ -202,6 +203,7 @@ def _step(self, actions: ActType):
self._truncateds[i],
info,
) = returns

need_reset = _need_reset and (
all(self._terminateds[i]) or all(self._truncateds[i])
)
Expand Down
79 changes: 79 additions & 0 deletions openrl/envs/wrappers/extra_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from gymnasium.utils.step_api_compatibility import (
convert_to_terminated_truncated_step_api,
)
from gymnasium.wrappers import AutoResetWrapper, StepAPICompatibility

from openrl.envs.wrappers import BaseObservationWrapper, BaseRewardWrapper, BaseWrapper
Expand All @@ -46,6 +49,76 @@ def step(self, action):
return obs, total_reward, term, trunc, info


def convert_to_done_step_api(
    step_returns,
    is_vector_env: bool = False,
):
    """Convert step returns to the 4-tuple ``(obs, reward, done, info)`` form.

    A 4-element ``step_returns`` is returned unchanged.  A 5-element
    ``(obs, reward, terminated, truncated, info)`` tuple is collapsed by
    OR-ing ``terminated`` and ``truncated`` into a single ``done`` value and
    recording ``"TimeLimit.truncated"`` in ``info``.

    Args:
        step_returns: A 4- or 5-element sequence as returned by ``env.step``.
        is_vector_env: Must be exactly ``False`` to select the single-env
            handling; any other value selects the vector-env handling, where
            ``infos`` has to be a list (one dict per env) or a single dict.

    Returns:
        A 4-tuple ``(observations, rewards, done, infos)``.  ``infos`` is
        mutated in place when the truncation flag is written.

    Raises:
        AssertionError: If ``step_returns`` has neither 4 nor 5 elements.
        TypeError: In vector-env mode, if ``infos`` is neither a list nor a
            dict.
    """
    if len(step_returns) == 4:
        # Already in the done-style API; hand it back untouched.
        return step_returns

    assert (
        len(step_returns) == 5
    ), f"step_returns must have 4 or 5 elements, got {len(step_returns)}"
    observations, rewards, terminated, truncated, infos = step_returns

    # Cases to handle: single env (scalar flags or per-agent list flags),
    # vector env with a list of infos, vector env with a dict of infos.
    if is_vector_env is False:
        if isinstance(terminated, list):
            # Per-agent flags: the truncation marker is taken from the first
            # entry only and is written on every step, done or not.
            infos["TimeLimit.truncated"] = truncated[0] and not terminated[0]
            done_return = np.logical_or(terminated, truncated)
        else:
            if truncated or terminated:
                infos["TimeLimit.truncated"] = truncated and not terminated
            done_return = terminated or truncated
        return observations, rewards, done_return, infos

    if isinstance(infos, list):
        for info, env_truncated, env_terminated in zip(infos, truncated, terminated):
            # Only envs whose episode ended get the marker.
            if env_truncated or env_terminated:
                info["TimeLimit.truncated"] = env_truncated and not env_terminated
        return observations, rewards, np.logical_or(terminated, truncated), infos

    if isinstance(infos, dict):
        if np.logical_or(np.any(truncated), np.any(terminated)):
            infos["TimeLimit.truncated"] = np.logical_and(
                truncated, np.logical_not(terminated)
            )
        return observations, rewards, np.logical_or(terminated, truncated), infos

    # Only reachable in vector-env mode, so the message must say so.
    raise TypeError(
        "Unexpected value of infos, as is_vector_env=True, expects `info` to"
        f" be a list or dict, actual type: {type(infos)}"
    )


def step_api_compatibility(
    step_returns,
    output_truncation_bool: bool = True,
    is_vector_env: bool = False,
):
    """Route ``step_returns`` to the converter selected by ``output_truncation_bool``.

    A truthy ``output_truncation_bool`` delegates to
    ``convert_to_terminated_truncated_step_api``; otherwise the call goes to
    ``convert_to_done_step_api``.  ``is_vector_env`` is forwarded unchanged.
    """
    converter = (
        convert_to_terminated_truncated_step_api
        if output_truncation_bool
        else convert_to_done_step_api
    )
    return converter(step_returns, is_vector_env)


class RemoveTruncated(StepAPICompatibility, BaseWrapper):
def __init__(
self,
Expand All @@ -54,6 +127,12 @@ def __init__(
output_truncation_bool = False
super().__init__(env, output_truncation_bool=output_truncation_bool)

def step(self, action):
    """Step the wrapped env and convert the result via ``step_api_compatibility``.

    The conversion is driven by this wrapper's ``output_truncation_bool``
    and ``is_vector_env`` attributes.
    """
    return step_api_compatibility(
        self.env.step(action),
        self.output_truncation_bool,
        self.is_vector_env,
    )


class FlattenObservation(BaseObservationWrapper):
def __init__(self, env: gym.Env):
Expand Down
1 change: 1 addition & 0 deletions openrl/selfplay/wrappers/base_multiplayer_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def reset(self, *, seed: Optional[int] = None, **kwargs):
action = self.get_opponent_action(
player_name, observation, reward, termination, truncation, info
)

self.env.step(action)

def on_episode_end(
Expand Down

0 comments on commit 6cd773a

Please sign in to comment.