diff --git a/nnabla_rl/model_trainers/policy/ppo_policy_trainer.py b/nnabla_rl/model_trainers/policy/ppo_policy_trainer.py
index 859ea803..cee0863e 100644
--- a/nnabla_rl/model_trainers/policy/ppo_policy_trainer.py
+++ b/nnabla_rl/model_trainers/policy/ppo_policy_trainer.py
@@ -108,10 +108,12 @@ def _build_one_step_graph(self, models: Sequence[Model], training_variables: Tra
         lower_bounds = NF.minimum2(probability_ratio * advantage, clipped_ratio * advantage)
         clip_loss = NF.mean(lower_bounds)
 
-        entropy = distribution.entropy()
-        entropy_loss = NF.mean(entropy)
-
-        self._pi_loss += 0.0 if ignore_loss else (-clip_loss - self._config.entropy_coefficient * entropy_loss)
+        if self._config.entropy_coefficient != 0.0:
+            entropy = distribution.entropy()
+            entropy_loss = NF.mean(entropy)
+            self._pi_loss += 0.0 if ignore_loss else (-clip_loss - self._config.entropy_coefficient * entropy_loss)
+        else:
+            self._pi_loss += 0.0 if ignore_loss else -clip_loss
 
     def _setup_training_variables(self, batch_size) -> TrainingVariables:
         # Training input variables
diff --git a/tests/model_trainers/test_policy_trainers.py b/tests/model_trainers/test_policy_trainers.py
index a05dc4e4..51654d0f 100644
--- a/tests/model_trainers/test_policy_trainers.py
+++ b/tests/model_trainers/test_policy_trainers.py
@@ -24,10 +24,12 @@
 import nnabla.parametric_functions as NPF
 import nnabla_rl.model_trainers as MT
 from nnabla_rl.distributions.gaussian import Gaussian
+from nnabla_rl.distributions.squashed_gaussian import SquashedGaussian
 from nnabla_rl.environments.dummy import DummyContinuous
 from nnabla_rl.environments.environment_info import EnvironmentInfo
 from nnabla_rl.model_trainers.model_trainer import LossIntegration
 from nnabla_rl.model_trainers.policy.dpg_policy_trainer import DPGPolicyTrainer
+from nnabla_rl.model_trainers.policy.ppo_policy_trainer import PPOPolicyTrainer
 from nnabla_rl.model_trainers.policy.soft_policy_trainer import AdjustableTemperature, SoftPolicyTrainer
 from nnabla_rl.model_trainers.policy.trpo_policy_trainer import (
     _concat_network_params_in_ndarray,
@@ -65,11 +67,30 @@ def pi(self, s):
         return s
 
 
+class StochasticNonRnnPolicy(StochasticPolicy):
+    def __init__(self, scope_name: str, squash: bool = False):
+        super().__init__(scope_name)
+        self._squash = squash
+
+    def pi(self, s):
+        if self._squash:
+            return SquashedGaussian(
+                mean=nn.Variable.from_numpy_array(np.zeros(s.shape)),
+                ln_var=nn.Variable.from_numpy_array(np.ones(s.shape)),
+            )
+        else:
+            return Gaussian(
+                mean=nn.Variable.from_numpy_array(np.zeros(s.shape)),
+                ln_var=nn.Variable.from_numpy_array(np.ones(s.shape)),
+            )
+
+
 class StochasticRnnPolicy(StochasticPolicy):
-    def __init__(self, scope_name: str):
+    def __init__(self, scope_name: str, squash: bool = False):
         super().__init__(scope_name)
         self._internal_state_shape = (10,)
         self._fake_internal_state = None
+        self._squash = squash
 
     def is_recurrent(self) -> bool:
         return True
@@ -85,9 +106,16 @@ def get_internal_states(self):
 
     def pi(self, s):
         self._fake_internal_state = self._fake_internal_state * 2
-        return Gaussian(
-            mean=nn.Variable.from_numpy_array(np.zeros(s.shape)), ln_var=nn.Variable.from_numpy_array(np.ones(s.shape))
-        )
+        if self._squash:
+            return SquashedGaussian(
+                mean=nn.Variable.from_numpy_array(np.zeros(s.shape)),
+                ln_var=nn.Variable.from_numpy_array(np.ones(s.shape)),
+            )
+        else:
+            return Gaussian(
+                mean=nn.Variable.from_numpy_array(np.zeros(s.shape)),
+                ln_var=nn.Variable.from_numpy_array(np.ones(s.shape)),
+            )
 
 
 class TrainerTest(metaclass=ABCMeta):
@@ -349,6 +377,43 @@ def test_with_rnn_model(self, unroll_steps, burn_in_steps, loss_integration):
         # pass: If no error occurs
 
 
+class TestPPOPolicyTrainer(TrainerTest):
+    def setup_method(self, method):
+        nn.clear_parameters()
+
+    @pytest.mark.parametrize("unroll_steps", [1, 2])
+    @pytest.mark.parametrize("burn_in_steps", [0, 1, 2])
+    @pytest.mark.parametrize("loss_integration", [LossIntegration.LAST_TIMESTEP_ONLY, LossIntegration.ALL_TIMESTEPS])
+    @pytest.mark.parametrize("entropy_coefficient", [0.0, 1.0])
+    @pytest.mark.parametrize("squash", [True, False])
+    def test_with_non_rnn_model(self, unroll_steps, burn_in_steps, loss_integration, entropy_coefficient, squash):
+        env_info = EnvironmentInfo.from_env(DummyContinuous())
+
+        policy = StochasticNonRnnPolicy("stub_pi", squash=squash)
+        config = MT.policy_trainers.PPOPolicyTrainerConfig(
+            unroll_steps=unroll_steps,
+            burn_in_steps=burn_in_steps,
+            loss_integration=loss_integration,
+            entropy_coefficient=entropy_coefficient,
+        )
+        if squash and entropy_coefficient != 0.0:
+            with pytest.raises(NotImplementedError):
+                PPOPolicyTrainer(
+                    policy,
+                    solvers={},
+                    env_info=env_info,
+                    config=config,
+                )
+        else:
+            PPOPolicyTrainer(
+                policy,
+                solvers={},
+                env_info=env_info,
+                config=config,
+            )
+            # pass: If no error occurs
+
+
 class TestAdjustableTemperature(TrainerTest):
     def test_initial_temperature(self):
         initial_value = 5.0