From a4a0929d673a92067752ae1eb0467bd4c3a488f3 Mon Sep 17 00:00:00 2001 From: huangshiyu Date: Fri, 22 Mar 2024 10:54:57 +0800 Subject: [PATCH] update readme --- examples/snake/submissions/rule_v1/submission.py | 5 +++-- openrl/buffers/offpolicy_replay_data.py | 4 +--- openrl/buffers/replay_data.py | 12 +++--------- openrl/envs/mpe/rendering.py | 5 +++-- openrl/envs/nlp/utils/metrics/meteor.py | 2 +- openrl/envs/vec_env/async_venv.py | 4 +--- openrl/envs/vec_env/wrappers/base_wrapper.py | 5 +++-- openrl/selfplay/opponents/utils.py | 2 +- openrl/selfplay/wrappers/opponent_pool_wrapper.py | 7 ++++--- 9 files changed, 20 insertions(+), 26 deletions(-) diff --git a/examples/snake/submissions/rule_v1/submission.py b/examples/snake/submissions/rule_v1/submission.py index db9a81e9..14a4b414 100644 --- a/examples/snake/submissions/rule_v1/submission.py +++ b/examples/snake/submissions/rule_v1/submission.py @@ -243,8 +243,9 @@ def step(self): # delay: prevent rear-end collision and self.state + state == 0 ): # third claim or more print( - "snake {} meets third or more claim in grid ({}, {})" - .format(key, x_, y_) + "snake {} meets third or more claim in grid ({}, {})".format( + key, x_, y_ + ) ) controversy = self.controversy[(x_, y_)] pprint.pprint(controversy) diff --git a/openrl/buffers/offpolicy_replay_data.py b/openrl/buffers/offpolicy_replay_data.py index 31e52e85..7b67fcdd 100644 --- a/openrl/buffers/offpolicy_replay_data.py +++ b/openrl/buffers/offpolicy_replay_data.py @@ -251,9 +251,7 @@ def feed_forward_generator( batch_size = n_rollout_threads * (episode_length - 1) * num_agents if mini_batch_size is None: - assert ( - batch_size >= num_mini_batch - ), ( + assert batch_size >= num_mini_batch, ( "DQN requires the number of processes ({}) " "* number of steps ({}) * number of agents ({}) = {} " "to be greater than or equal to the number of DQN mini batches ({})." diff --git a/openrl/buffers/replay_data.py b/openrl/buffers/replay_data.py index a8f4c1b7..b81b493f 100644 --- a/openrl/buffers/replay_data.py +++ b/openrl/buffers/replay_data.py @@ -561,9 +561,7 @@ def feed_forward_generator( batch_size = n_rollout_threads * episode_length * num_agents if mini_batch_size is None: - assert ( - batch_size >= num_mini_batch - ), ( + assert batch_size >= num_mini_batch, ( "PPO requires the number of processes ({}) " "* number of steps ({}) * number of agents ({}) = {} " "to be greater than or equal to the number of PPO mini batches ({})." @@ -658,9 +656,7 @@ def feed_forward_critic_obs_generator( batch_size = n_rollout_threads * episode_length if mini_batch_size is None: - assert ( - batch_size >= num_mini_batch - ), ( + assert batch_size >= num_mini_batch, ( "PPO requires the number of processes ({}) " "* number of steps ({}) * number of agents ({}) = {} " "to be greater than or equal to the number of PPO mini batches ({})." @@ -721,9 +717,7 @@ def feed_forward_generator_transformer( batch_size = n_rollout_threads * episode_length if mini_batch_size is None: - assert ( - batch_size >= num_mini_batch - ), ( + assert batch_size >= num_mini_batch, ( "PPO requires the number of processes ({}) " "* number of steps ({}) = {} " "to be greater than or equal to the number of PPO mini batches ({})." diff --git a/openrl/envs/mpe/rendering.py b/openrl/envs/mpe/rendering.py index a7197dca..9e458999 100644 --- a/openrl/envs/mpe/rendering.py +++ b/openrl/envs/mpe/rendering.py @@ -58,8 +58,9 @@ def get_display(spec): return pyglet.canvas.Display(spec) else: raise error.Error( - "Invalid display specification: {}. (Must be a string like :0 or None.)" - .format(spec) + "Invalid display specification: {}. (Must be a string like :0 or None.)".format( + spec + ) ) diff --git a/openrl/envs/nlp/utils/metrics/meteor.py b/openrl/envs/nlp/utils/metrics/meteor.py index ab15e66d..e9930265 100644 --- a/openrl/envs/nlp/utils/metrics/meteor.py +++ b/openrl/envs/nlp/utils/metrics/meteor.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" METEOR metric. """ +"""METEOR metric.""" import datasets import evaluate diff --git a/openrl/envs/vec_env/async_venv.py b/openrl/envs/vec_env/async_venv.py index e4f10d2b..220f999c 100644 --- a/openrl/envs/vec_env/async_venv.py +++ b/openrl/envs/vec_env/async_venv.py @@ -342,9 +342,7 @@ def step_send(self, actions: np.ndarray): pipe.send(("step", action)) self._state = AsyncState.WAITING_STEP - def step_fetch( - self, timeout: Optional[Union[int, float]] = None - ) -> Union[ + def step_fetch(self, timeout: Optional[Union[int, float]] = None) -> Union[ Tuple[Any, NDArray[Any], NDArray[Any], List[Dict[str, Any]]], Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], List[Dict[str, Any]]], ]: diff --git a/openrl/envs/vec_env/wrappers/base_wrapper.py b/openrl/envs/vec_env/wrappers/base_wrapper.py index 85ca5082..87e5f8a9 100644 --- a/openrl/envs/vec_env/wrappers/base_wrapper.py +++ b/openrl/envs/vec_env/wrappers/base_wrapper.py @@ -230,8 +230,9 @@ def step(self, actions, *args, **kwargs): ) else: raise ValueError( - "Invalid step return value, expected 4 or 5 values, got {} values" - .format(len(results)) + "Invalid step return value, expected 4 or 5 values, got {} values".format( + len(results) + ) ) def observation(self, observation: ObsType) -> ObsType: diff --git a/openrl/selfplay/opponents/utils.py b/openrl/selfplay/opponents/utils.py index 42ddbb2b..73abc041 100644 --- a/openrl/selfplay/opponents/utils.py +++ b/openrl/selfplay/opponents/utils.py @@ -47,7 +47,7 @@ def check_opponent_template(opponent_template: Union[str, Path]): def get_opponent_info( - info_path: Optional[Union[str, Path]] + info_path: Optional[Union[str, Path]], ) -> Optional[Dict[str, str]]: if info_path is None: return None diff --git a/openrl/selfplay/wrappers/opponent_pool_wrapper.py b/openrl/selfplay/wrappers/opponent_pool_wrapper.py index d42c17d1..a24ae10c 100644 --- a/openrl/selfplay/wrappers/opponent_pool_wrapper.py +++ b/openrl/selfplay/wrappers/opponent_pool_wrapper.py @@ -111,9 +111,10 @@ def on_episode_end( else: loser_id = self.opponent.opponent_id loser_ids.append(loser_id) - assert set(winner_ids).isdisjoint(set(loser_ids)), ( - "winners and losers must be disjoint, but get winners: {}, losers: {}" - .format(winner_ids, loser_ids) + assert set(winner_ids).isdisjoint( + set(loser_ids) + ), "winners and losers must be disjoint, but get winners: {}, losers: {}".format( + winner_ids, loser_ids ) battle_info = {"winner_ids": winner_ids, "loser_ids": loser_ids} self.api_client.add_battle_result(battle_info)