Skip to content

Commit

Permalink
update readme
Browse files Browse the repository at this point in the history
  • Loading branch information
huangshiyu13 committed Mar 22, 2024
1 parent 0157948 commit a4a0929
Show file tree
Hide file tree
Showing 9 changed files with 20 additions and 26 deletions.
5 changes: 3 additions & 2 deletions examples/snake/submissions/rule_v1/submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,9 @@ def step(self): # delay: prevent rear-end collision
and self.state + state == 0
): # third claim or more
print(
"snake {} meets third or more claim in grid ({}, {})"
.format(key, x_, y_)
"snake {} meets third or more claim in grid ({}, {})".format(
key, x_, y_
)
)
controversy = self.controversy[(x_, y_)]
pprint.pprint(controversy)
Expand Down
4 changes: 1 addition & 3 deletions openrl/buffers/offpolicy_replay_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,7 @@ def feed_forward_generator(
batch_size = n_rollout_threads * (episode_length - 1) * num_agents

if mini_batch_size is None:
assert (
batch_size >= num_mini_batch
), (
assert batch_size >= num_mini_batch, (
"DQN requires the number of processes ({}) "
"* number of steps ({}) * number of agents ({}) = {} "
"to be greater than or equal to the number of DQN mini batches ({})."
Expand Down
12 changes: 3 additions & 9 deletions openrl/buffers/replay_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,9 +561,7 @@ def feed_forward_generator(
batch_size = n_rollout_threads * episode_length * num_agents

if mini_batch_size is None:
assert (
batch_size >= num_mini_batch
), (
assert batch_size >= num_mini_batch, (
"PPO requires the number of processes ({}) "
"* number of steps ({}) * number of agents ({}) = {} "
"to be greater than or equal to the number of PPO mini batches ({})."
Expand Down Expand Up @@ -658,9 +656,7 @@ def feed_forward_critic_obs_generator(
batch_size = n_rollout_threads * episode_length

if mini_batch_size is None:
assert (
batch_size >= num_mini_batch
), (
assert batch_size >= num_mini_batch, (
"PPO requires the number of processes ({}) "
"* number of steps ({}) * number of agents ({}) = {} "
"to be greater than or equal to the number of PPO mini batches ({})."
Expand Down Expand Up @@ -721,9 +717,7 @@ def feed_forward_generator_transformer(
batch_size = n_rollout_threads * episode_length

if mini_batch_size is None:
assert (
batch_size >= num_mini_batch
), (
assert batch_size >= num_mini_batch, (
"PPO requires the number of processes ({}) "
"* number of steps ({}) = {} "
"to be greater than or equal to the number of PPO mini batches ({})."
Expand Down
5 changes: 3 additions & 2 deletions openrl/envs/mpe/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,9 @@ def get_display(spec):
return pyglet.canvas.Display(spec)
else:
raise error.Error(
"Invalid display specification: {}. (Must be a string like :0 or None.)"
.format(spec)
"Invalid display specification: {}. (Must be a string like :0 or None.)".format(
spec
)
)


Expand Down
2 changes: 1 addition & 1 deletion openrl/envs/nlp/utils/metrics/meteor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" METEOR metric. """
"""METEOR metric."""

import datasets
import evaluate
Expand Down
4 changes: 1 addition & 3 deletions openrl/envs/vec_env/async_venv.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,9 +342,7 @@ def step_send(self, actions: np.ndarray):
pipe.send(("step", action))
self._state = AsyncState.WAITING_STEP

def step_fetch(
self, timeout: Optional[Union[int, float]] = None
) -> Union[
def step_fetch(self, timeout: Optional[Union[int, float]] = None) -> Union[
Tuple[Any, NDArray[Any], NDArray[Any], List[Dict[str, Any]]],
Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], List[Dict[str, Any]]],
]:
Expand Down
5 changes: 3 additions & 2 deletions openrl/envs/vec_env/wrappers/base_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,9 @@ def step(self, actions, *args, **kwargs):
)
else:
raise ValueError(
"Invalid step return value, expected 4 or 5 values, got {} values"
.format(len(results))
"Invalid step return value, expected 4 or 5 values, got {} values".format(
len(results)
)
)

def observation(self, observation: ObsType) -> ObsType:
Expand Down
2 changes: 1 addition & 1 deletion openrl/selfplay/opponents/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def check_opponent_template(opponent_template: Union[str, Path]):


def get_opponent_info(
info_path: Optional[Union[str, Path]]
info_path: Optional[Union[str, Path]],
) -> Optional[Dict[str, str]]:
if info_path is None:
return None
Expand Down
7 changes: 4 additions & 3 deletions openrl/selfplay/wrappers/opponent_pool_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,10 @@ def on_episode_end(
else:
loser_id = self.opponent.opponent_id
loser_ids.append(loser_id)
assert set(winner_ids).isdisjoint(set(loser_ids)), (
"winners and losers must be disjoint, but get winners: {}, losers: {}"
.format(winner_ids, loser_ids)
assert set(winner_ids).isdisjoint(
set(loser_ids)
), "winners and losers must be disjoint, but get winners: {}, losers: {}".format(
winner_ids, loser_ids
)
battle_info = {"winner_ids": winner_ids, "loser_ids": loser_ids}
self.api_client.add_battle_result(battle_info)

0 comments on commit a4a0929

Please sign in to comment.