diff --git a/Makefile b/Makefile index d6a00deb..4427a1a3 100644 --- a/Makefile +++ b/Makefile @@ -14,7 +14,7 @@ lint: # see https://www.flake8rules.com/ ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full # exit-zero treats all errors as warnings. - ruff check ${LINT_PATHS} --exit-zero + ruff check ${LINT_PATHS} --exit-zero --output-format=concise format: # Sort imports diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 109b2eb5..60e36123 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,32 @@ Changelog ========== + +Release 2.4.0a4 (WIP) +-------------------------- + +Breaking Changes: +^^^^^^^^^^^^^^^^^ +- Upgraded to Stable-Baselines3 >= 2.4.0 + +New Features: +^^^^^^^^^^^^^ + +Bug Fixes: +^^^^^^^^^^ + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ +- Updated PyTorch version on CI to 2.3.1 +- Remove unnecessary SDE noise resampling in PPO/TRPO update + +Documentation: +^^^^^^^^^^^^^^ + + Release 2.3.0 (2024-03-31) -------------------------- diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index 05ffb010..7cd97cf3 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -342,10 +342,6 @@ def train(self) -> None: # Convert mask from float to bool mask = rollout_data.mask > 1e-8 - # Re-sample the noise matrix because the log_std has changed - if self.use_sde: - self.policy.reset_noise(self.batch_size) - values, log_prob, entropy = self.policy.evaluate_actions( rollout_data.observations, actions, diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index 8165081b..a8e65567 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -261,11 +261,6 @@ def train(self) -> None: # Convert discrete action from float to long actions = rollout_data.actions.long().flatten() - # Re-sample the noise matrix because the log_std has changed - if self.use_sde: - # batch_size is only used for the value function - self.policy.reset_noise(actions.shape[0]) - with th.no_grad(): # Note: is copy enough, no need for deepcopy? # If using gSDE and deepcopy, we need to use `old_distribution.distribution` diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index 276cbf9e..2d22b158 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -2.3.0 +2.4.0a4 diff --git a/setup.py b/setup.py index 49b24588..158fff98 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=2.3.0,<3.0", + "stable_baselines3>=2.4.0a4,<3.0", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin",