diff --git a/openrl/algorithms/ppo.py b/openrl/algorithms/ppo.py
index e72e01bb..80e9f23f 100644
--- a/openrl/algorithms/ppo.py
+++ b/openrl/algorithms/ppo.py
@@ -179,7 +179,7 @@ def cal_value_loss(
             -self.clip_param, self.clip_param
         )
 
-        if self._use_popart or self._use_valuenorm:
+        if (self._use_popart or self._use_valuenorm) and value_normalizer is not None:
             value_normalizer.update(return_batch)
             error_clipped = (
                 value_normalizer.normalize(return_batch) - value_pred_clipped
@@ -382,9 +382,12 @@ def train_ppo(self, buffer, turn_on):
                 ].module.value_normalizer
             else:
                 value_normalizer = self.algo_module.get_critic_value_normalizer()
-            advantages = buffer.returns[:-1] - value_normalizer.denormalize(
-                buffer.value_preds[:-1]
-            )
+            if value_normalizer is not None:
+                advantages = buffer.returns[:-1] - value_normalizer.denormalize(
+                    buffer.value_preds[:-1]
+                )
+            else:
+                advantages = buffer.returns[:-1] - buffer.value_preds[:-1]
         else:
             advantages = buffer.returns[:-1] - buffer.value_preds[:-1]
 
diff --git a/openrl/buffers/replay_data.py b/openrl/buffers/replay_data.py
index 8d092d7d..a8f4c1b7 100644
--- a/openrl/buffers/replay_data.py
+++ b/openrl/buffers/replay_data.py
@@ -323,7 +323,9 @@ def compute_returns(self, next_value, value_normalizer=None):
                 self.value_preds[-1] = next_value
                 gae = 0
                 for step in reversed(range(self.rewards.shape[0])):
-                    if self._use_popart or self._use_valuenorm:
+                    if (
+                        self._use_popart or self._use_valuenorm
+                    ) and value_normalizer is not None:
                         # step + 1
                         delta = (
                             self.rewards[step]
@@ -357,7 +359,9 @@ def compute_returns(self, next_value, value_normalizer=None):
             else:
                 self.returns[-1] = next_value
                 for step in reversed(range(self.rewards.shape[0])):
-                    if self._use_popart or self._use_valuenorm:
+                    if (
+                        self._use_popart or self._use_valuenorm
+                    ) and value_normalizer is not None:
                         self.returns[step] = (
                             self.returns[step + 1] * self.gamma * self.masks[step + 1]
                             + self.rewards[step]
diff --git a/tests/test_buffer/test_generator.py b/tests/test_buffer/test_generator.py
index 4de33c02..27763635 100644
--- a/tests/test_buffer/test_generator.py
+++ b/tests/test_buffer/test_generator.py
@@ -25,6 +25,11 @@
 from openrl.runners.common import PPOAgent as Agent
 
 
+@pytest.fixture(scope="module", params=["--episode_length 10"])
+def episode_length(request):
+    return request.param
+
+
 @pytest.fixture(
     scope="module",
     params=[
@@ -64,9 +69,17 @@ def use_popart(request):
 
 
 @pytest.fixture(scope="module")
-def config(use_proper_time_limits, use_popart, use_gae, generator_type):
+def config(use_proper_time_limits, use_popart, use_gae, generator_type, episode_length):
     config_str = (
-        use_proper_time_limits + " " + use_popart + " " + use_gae + " " + generator_type
+        use_proper_time_limits
+        + " "
+        + use_popart
+        + " "
+        + use_gae
+        + " "
+        + generator_type
+        + " "
+        + episode_length
     )
 
     from openrl.configs.config import create_config_parser
@@ -80,7 +93,7 @@ def config(use_proper_time_limits, use_popart, use_gae, generator_type):
 def test_buffer_generator(config):
     env = make("CartPole-v1", env_num=2)
     agent = Agent(Net(env, cfg=config))
-    agent.train(total_time_steps=200)
+    agent.train(total_time_steps=50)
     env.close()
 
 
diff --git a/tests/test_buffer/test_offpolicy_generator.py b/tests/test_buffer/test_offpolicy_generator.py
index 5e5da276..ec960973 100644
--- a/tests/test_buffer/test_offpolicy_generator.py
+++ b/tests/test_buffer/test_offpolicy_generator.py
@@ -25,6 +25,11 @@
 from openrl.runners.common import DQNAgent as Agent
 
 
+@pytest.fixture(scope="module", params=["--episode_length 10"])
+def episode_length(request):
+    return request.param
+
+
 @pytest.fixture(
     scope="module",
     params=[
@@ -46,8 +51,16 @@ def use_popart(request):
 
 
 @pytest.fixture(scope="module")
-def config(use_proper_time_limits, use_popart, generator_type):
-    config_str = use_proper_time_limits + " " + use_popart + " " + generator_type
+def config(use_proper_time_limits, use_popart, generator_type, episode_length):
+    config_str = (
+        use_proper_time_limits
+        + " "
+        + use_popart
+        + " "
+        + generator_type
+        + " "
+        + episode_length
+    )
 
     from openrl.configs.config import create_config_parser
 
@@ -60,7 +73,7 @@ def config(use_proper_time_limits, use_popart, generator_type):
 def test_buffer_generator(config):
     env = make("CartPole-v1", env_num=2)
     agent = Agent(Net(env, cfg=config))
-    agent.train(total_time_steps=200)
+    agent.train(total_time_steps=50)
     env.close()