Commit: improve test

huangshiyu13 committed Dec 12, 2023
1 parent e29fd90 commit 40d8304
Showing 4 changed files with 45 additions and 12 deletions.
11 changes: 7 additions & 4 deletions openrl/algorithms/ppo.py
@@ -179,7 +179,7 @@ def cal_value_loss(
                 -self.clip_param, self.clip_param
             )

-        if self._use_popart or self._use_valuenorm:
+        if (self._use_popart or self._use_valuenorm) and value_normalizer is not None:
             value_normalizer.update(return_batch)
             error_clipped = (
                 value_normalizer.normalize(return_batch) - value_pred_clipped
@@ -382,9 +382,12 @@ def train_ppo(self, buffer, turn_on):
                 ].module.value_normalizer
             else:
                 value_normalizer = self.algo_module.get_critic_value_normalizer()
-            advantages = buffer.returns[:-1] - value_normalizer.denormalize(
-                buffer.value_preds[:-1]
-            )
+            if value_normalizer is not None:
+                advantages = buffer.returns[:-1] - value_normalizer.denormalize(
+                    buffer.value_preds[:-1]
+                )
+            else:
+                advantages = buffer.returns[:-1] - buffer.value_preds[:-1]
         else:
             advantages = buffer.returns[:-1] - buffer.value_preds[:-1]

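Both hunks guard against a missing value normalizer: with --use_popart or --use_valuenorm enabled but no normalizer attached, the old code would call update() or denormalize() on None. A minimal sketch of the guarded advantage computation, assuming numpy arrays and a normalizer object exposing denormalize() (names illustrative, not the exact OpenRL API):

    import numpy as np

    def compute_advantages(returns, value_preds, value_normalizer=None):
        # Denormalize critic outputs only when a normalizer is attached;
        # otherwise compare returns against the raw value predictions.
        if value_normalizer is not None:
            return returns[:-1] - value_normalizer.denormalize(value_preds[:-1])
        return returns[:-1] - value_preds[:-1]
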
8 changes: 6 additions & 2 deletions openrl/buffers/replay_data.py
@@ -323,7 +323,9 @@ def compute_returns(self, next_value, value_normalizer=None):
             self.value_preds[-1] = next_value
             gae = 0
             for step in reversed(range(self.rewards.shape[0])):
-                if self._use_popart or self._use_valuenorm:
+                if (
+                    self._use_popart or self._use_valuenorm
+                ) and value_normalizer is not None:
                     # step + 1
                     delta = (
                         self.rewards[step]
@@ -357,7 +359,9 @@ def compute_returns(self, next_value, value_normalizer=None):
         else:
             self.returns[-1] = next_value
             for step in reversed(range(self.rewards.shape[0])):
-                if self._use_popart or self._use_valuenorm:
+                if (
+                    self._use_popart or self._use_valuenorm
+                ) and value_normalizer is not None:
                     self.returns[step] = (
                         self.returns[step + 1] * self.gamma * self.masks[step + 1]
                         + self.rewards[step]
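The same None guard lands in the return computation. For context, a self-contained sketch of the GAE recursion these hunks protect, with optional denormalization of the value predictions (array shapes and the normalizer interface are assumptions, not the exact buffer layout):

    import numpy as np

    def gae_returns(rewards, value_preds, masks,
                    gamma=0.99, gae_lambda=0.95, value_normalizer=None):
        # rewards: [T]; value_preds, masks: [T + 1]; masks are 0.0 at episode ends.
        denorm = (value_normalizer.denormalize
                  if value_normalizer is not None else lambda x: x)
        returns = np.zeros_like(rewards, dtype=float)
        gae = 0.0
        for step in reversed(range(rewards.shape[0])):
            delta = (rewards[step]
                     + gamma * denorm(value_preds[step + 1]) * masks[step + 1]
                     - denorm(value_preds[step]))
            gae = delta + gamma * gae_lambda * masks[step + 1] * gae
            returns[step] = gae + denorm(value_preds[step])
        return returns
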
19 changes: 16 additions & 3 deletions tests/test_buffer/test_generator.py
@@ -25,6 +25,11 @@
 from openrl.runners.common import PPOAgent as Agent


+@pytest.fixture(scope="module", params=["--episode_length 10"])
+def episode_length(request):
+    return request.param
+
+
 @pytest.fixture(
     scope="module",
     params=[
@@ -64,9 +69,17 @@ def use_popart(request):


 @pytest.fixture(scope="module")
-def config(use_proper_time_limits, use_popart, use_gae, generator_type):
+def config(use_proper_time_limits, use_popart, use_gae, generator_type, episode_length):
     config_str = (
-        use_proper_time_limits + " " + use_popart + " " + use_gae + " " + generator_type
+        use_proper_time_limits
+        + " "
+        + use_popart
+        + " "
+        + use_gae
+        + " "
+        + generator_type
+        + " "
+        + episode_length
     )

     from openrl.configs.config import create_config_parser
@@ -80,7 +93,7 @@ def config(use_proper_time_limits, use_popart, use_gae, generator_type):
 def test_buffer_generator(config):
     env = make("CartPole-v1", env_num=2)
     agent = Agent(Net(env, cfg=config))
-    agent.train(total_time_steps=200)
+    agent.train(total_time_steps=50)
     env.close()

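Two speedups here: a new module-scoped episode_length fixture appends --episode_length 10 to the config string, and training drops from 200 to 50 total steps. A minimal sketch of the fixture-composition pattern, with an illustrative subset of the flags (pytest runs the test once per combination of fixture params):

    import pytest

    @pytest.fixture(scope="module", params=["--episode_length 10"])
    def episode_length(request):
        return request.param

    @pytest.fixture(scope="module", params=["--use_popart true", "--use_popart false"])
    def use_popart(request):
        return request.param

    @pytest.fixture(scope="module")
    def config_str(use_popart, episode_length):
        # Fixtures compose: each param combination yields one config string.
        return use_popart + " " + episode_length

    def test_config_str(config_str):
        assert "--episode_length 10" in config_str
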
19 changes: 16 additions & 3 deletions tests/test_buffer/test_offpolicy_generator.py
@@ -25,6 +25,11 @@
 from openrl.runners.common import DQNAgent as Agent


+@pytest.fixture(scope="module", params=["--episode_length 10"])
+def episode_length(request):
+    return request.param
+
+
 @pytest.fixture(
     scope="module",
     params=[
@@ -46,8 +51,16 @@ def use_popart(request):


 @pytest.fixture(scope="module")
-def config(use_proper_time_limits, use_popart, generator_type):
-    config_str = use_proper_time_limits + " " + use_popart + " " + generator_type
+def config(use_proper_time_limits, use_popart, generator_type, episode_length):
+    config_str = (
+        use_proper_time_limits
+        + " "
+        + use_popart
+        + " "
+        + generator_type
+        + " "
+        + episode_length
+    )

     from openrl.configs.config import create_config_parser

@@ -60,7 +73,7 @@ def config(use_proper_time_limits, use_popart, generator_type):
 def test_buffer_generator(config):
     env = make("CartPole-v1", env_num=2)
     agent = Agent(Net(env, cfg=config))
-    agent.train(total_time_steps=200)
+    agent.train(total_time_steps=50)
     env.close()

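The off-policy test mirrors the on-policy one, swapping PPOAgent for DQNAgent and dropping the use_gae fixture. Assuming a standard checkout, the shortened suite should run with plain pytest, e.g. pytest tests/test_buffer -q.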
