diff --git a/envpool/__init__.py b/envpool/__init__.py index 270f2570..23626636 100644 --- a/envpool/__init__.py +++ b/envpool/__init__.py @@ -23,7 +23,7 @@ register, ) -__version__ = "0.6.2.post2" +__version__ = "0.6.3" __all__ = [ "register", "make", diff --git a/envpool/atari/api_test.py b/envpool/atari/api_test.py index aac4a6b0..7295be74 100644 --- a/envpool/atari/api_test.py +++ b/envpool/atari/api_test.py @@ -259,7 +259,7 @@ def test_lowlevel_step(self) -> None: np.testing.assert_allclose(done.shape, (num_envs,)) self.assertEqual(done.dtype, np.bool_) self.assertIsInstance(info, dict) - self.assertEqual(len(info), 5) + self.assertEqual(len(info), 6) self.assertEqual(info["env_id"].dtype, np.int32) self.assertEqual(info["lives"].dtype, np.int32) self.assertEqual(info["players"]["env_id"].dtype, np.int32) @@ -295,7 +295,7 @@ def test_highlevel_step(self) -> None: np.testing.assert_allclose(done.shape, (num_envs,)) self.assertEqual(done.dtype, np.bool_) self.assertIsInstance(info, dict) - self.assertEqual(len(info), 5) + self.assertEqual(len(info), 6) self.assertEqual(info["env_id"].dtype, np.int32) self.assertEqual(info["lives"].dtype, np.int32) self.assertEqual(info["players"]["env_id"].dtype, np.int32) diff --git a/envpool/atari/atari_env.h b/envpool/atari/atari_env.h index d417ea2a..f20ded12 100644 --- a/envpool/atari/atari_env.h +++ b/envpool/atari/atari_env.h @@ -65,8 +65,9 @@ class AtariEnvFns { conf["img_height"_], conf["img_width"_]}, {0, 255})), "discount"_.Bind(Spec({-1}, {0.0, 1.0})), - "info:lives"_.Bind(Spec({-1}, {0, 5})), - "info:reward"_.Bind(Spec({-1}))); + "info:lives"_.Bind(Spec({-1})), + "info:reward"_.Bind(Spec({-1})), + "info:terminated"_.Bind(Spec({-1}, {0, 1}))); } template static decltype(auto) ActionSpec(const Config& conf) { @@ -199,7 +200,7 @@ class AtariEnv : public Env { PushStack(false, skip_id == 0); ++elapsed_step_; done_ |= (elapsed_step_ >= max_episode_steps_); - if (episodic_life_ && env_->lives() < lives_) { + if (episodic_life_ && 0 < env_->lives() && env_->lives() < lives_) { done_ = true; } float discount; @@ -229,6 +230,7 @@ class AtariEnv : public Env { state["reward"_] = reward; state["info:lives"_] = lives_; state["info:reward"_] = info_reward; + state["info:terminated"_] = env_->game_over(); for (int i = 0; i < stack_num_; ++i) { state["obs"_] .Slice(gray_scale_ ? i : i * 3, gray_scale_ ? i + 1 : (i + 1) * 3) diff --git a/envpool/atari/atari_envpool_test.py b/envpool/atari/atari_envpool_test.py index 2786c6f6..826ce0b1 100644 --- a/envpool/atari/atari_envpool_test.py +++ b/envpool/atari/atari_envpool_test.py @@ -78,6 +78,51 @@ def test_align(self) -> None: np.testing.assert_allclose(obs0, obs1) # cv2.imwrite(f"/tmp/log/align{i}.png", obs0[0, 1:].transpose(1, 2, 0)) + def test_reset_life(self) -> None: + """Issue 171.""" + for env_id in [ + "atlantis", "backgammon", "breakout", "pong", "wizard_of_wor" + ]: + np.random.seed(0) + env = AtariGymEnvPool( + AtariEnvSpec( + AtariEnvSpec.gen_config(task=env_id, num_envs=1, episodic_life=True) + ) + ) + action_num = env.action_space.n # type: ignore + env.reset() + info = env.step(np.array([0]))[-1] + if info["lives"].sum() == 0: + # no life in this game + continue + for _ in range(10000): + _, _, done, info = env.step(np.random.randint(0, action_num, 1)) + if info["lives"][0] == 0: + break + else: + self.assertFalse(info["terminated"][0]) + if info["lives"][0] > 0: + # step too long + continue + # for normal atari (e.g., breakout) + # take an additional step after all lives are exhausted + _, _, next_done, next_info = env.step( + np.random.randint(0, action_num, 1) + ) + if done[0] and next_info["lives"][0] > 0: + self.assertTrue(info["terminated"][0]) + continue + self.assertFalse(done[0]) + self.assertFalse(info["terminated"][0]) + while not done[0]: + self.assertFalse(info["terminated"][0]) + _, _, done, info = env.step(np.random.randint(0, action_num, 1)) + _, _, next_done, next_info = env.step( + np.random.randint(0, action_num, 1) + ) + self.assertTrue(next_info["lives"][0] > 0) + self.assertTrue(info["terminated"][0]) + def test_partial_step(self) -> None: num_envs = 5 max_episode_steps = 10 diff --git a/setup.cfg b/setup.cfg index 77e17157..4d435609 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = envpool -version = 0.6.2.post2 +version = 0.6.3 author = "EnvPool Contributors" author_email = "sail@sea.com" description = "C++-based high-performance parallel environment execution engine (vectorized env) for general RL environments."