Skip to content

Commit

Permalink
Fix zero-life reset error in Atari env (#175)
Browse files Browse the repository at this point in the history
Add `info["terminated"]` as an indicator of `env.game_over()`.
  • Loading branch information
Trinkle23897 authored Jul 24, 2022
1 parent ea86c2b commit 1eedd34
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 7 deletions.
2 changes: 1 addition & 1 deletion envpool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
register,
)

__version__ = "0.6.2.post2"
__version__ = "0.6.3"
__all__ = [
"register",
"make",
Expand Down
4 changes: 2 additions & 2 deletions envpool/atari/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def test_lowlevel_step(self) -> None:
np.testing.assert_allclose(done.shape, (num_envs,))
self.assertEqual(done.dtype, np.bool_)
self.assertIsInstance(info, dict)
self.assertEqual(len(info), 5)
self.assertEqual(len(info), 6)
self.assertEqual(info["env_id"].dtype, np.int32)
self.assertEqual(info["lives"].dtype, np.int32)
self.assertEqual(info["players"]["env_id"].dtype, np.int32)
Expand Down Expand Up @@ -295,7 +295,7 @@ def test_highlevel_step(self) -> None:
np.testing.assert_allclose(done.shape, (num_envs,))
self.assertEqual(done.dtype, np.bool_)
self.assertIsInstance(info, dict)
self.assertEqual(len(info), 5)
self.assertEqual(len(info), 6)
self.assertEqual(info["env_id"].dtype, np.int32)
self.assertEqual(info["lives"].dtype, np.int32)
self.assertEqual(info["players"]["env_id"].dtype, np.int32)
Expand Down
8 changes: 5 additions & 3 deletions envpool/atari/atari_env.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ class AtariEnvFns {
conf["img_height"_], conf["img_width"_]},
{0, 255})),
"discount"_.Bind(Spec<float>({-1}, {0.0, 1.0})),
"info:lives"_.Bind(Spec<int>({-1}, {0, 5})),
"info:reward"_.Bind(Spec<float>({-1})));
"info:lives"_.Bind(Spec<int>({-1})),
"info:reward"_.Bind(Spec<float>({-1})),
"info:terminated"_.Bind(Spec<int>({-1}, {0, 1})));
}
template <typename Config>
static decltype(auto) ActionSpec(const Config& conf) {
Expand Down Expand Up @@ -199,7 +200,7 @@ class AtariEnv : public Env<AtariEnvSpec> {
PushStack(false, skip_id == 0);
++elapsed_step_;
done_ |= (elapsed_step_ >= max_episode_steps_);
if (episodic_life_ && env_->lives() < lives_) {
if (episodic_life_ && 0 < env_->lives() && env_->lives() < lives_) {
done_ = true;
}
float discount;
Expand Down Expand Up @@ -229,6 +230,7 @@ class AtariEnv : public Env<AtariEnvSpec> {
state["reward"_] = reward;
state["info:lives"_] = lives_;
state["info:reward"_] = info_reward;
state["info:terminated"_] = env_->game_over();
for (int i = 0; i < stack_num_; ++i) {
state["obs"_]
.Slice(gray_scale_ ? i : i * 3, gray_scale_ ? i + 1 : (i + 1) * 3)
Expand Down
45 changes: 45 additions & 0 deletions envpool/atari/atari_envpool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,51 @@ def test_align(self) -> None:
np.testing.assert_allclose(obs0, obs1)
# cv2.imwrite(f"/tmp/log/align{i}.png", obs0[0, 1:].transpose(1, 2, 0))

def test_reset_life(self) -> None:
    """Check ``info["terminated"]`` around life exhaustion (issue 171).

    For each listed game an episodic-life environment is stepped with
    seeded random actions until the reported life count reaches zero.
    While lives remain, ``info["terminated"]`` must be false. After lives
    reach zero, the flag must be true on the step that reports ``done``,
    and the step following it must report a positive life count again.

    Games that report no lives at all, or that do not exhaust their lives
    within 10000 random steps, are skipped.
    """
    games = ["atlantis", "backgammon", "breakout", "pong", "wizard_of_wor"]
    for env_id in games:
        # subTest labels a failure with the offending game and lets the
        # remaining games run instead of aborting on the first failure.
        with self.subTest(env_id=env_id):
            np.random.seed(0)
            config = AtariEnvSpec.gen_config(
                task=env_id, num_envs=1, episodic_life=True
            )
            env = AtariGymEnvPool(AtariEnvSpec(config))
            action_num = env.action_space.n  # type: ignore
            env.reset()
            info = env.step(np.array([0]))[-1]
            if info["lives"].sum() == 0:
                # no life in this game
                continue
            for _ in range(10000):
                _, _, done, info = env.step(
                    np.random.randint(0, action_num, 1)
                )
                if info["lives"][0] == 0:
                    break
                self.assertFalse(info["terminated"][0])
            if info["lives"][0] > 0:
                # step too long
                continue
            # for normal atari (e.g., breakout)
            # take an additional step after all lives are exhausted
            next_info = env.step(np.random.randint(0, action_num, 1))[-1]
            if done[0] and next_info["lives"][0] > 0:
                self.assertTrue(info["terminated"][0])
                continue
            # Otherwise the zero-life step must not have ended the episode:
            # keep stepping until the environment reports done.
            self.assertFalse(done[0])
            self.assertFalse(info["terminated"][0])
            while not done[0]:
                self.assertFalse(info["terminated"][0])
                _, _, done, info = env.step(
                    np.random.randint(0, action_num, 1)
                )
            next_info = env.step(np.random.randint(0, action_num, 1))[-1]
            self.assertGreater(next_info["lives"][0], 0)
            self.assertTrue(info["terminated"][0])

def test_partial_step(self) -> None:
num_envs = 5
max_episode_steps = 10
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = envpool
version = 0.6.2.post2
version = 0.6.3
author = "EnvPool Contributors"
author_email = "sail@sea.com"
description = "C++-based high-performance parallel environment execution engine (vectorized env) for general RL environments."
Expand Down

0 comments on commit 1eedd34

Please sign in to comment.