From fe60a8ba6b4ebd9ac2ae0e58bfda11b4bae6539a Mon Sep 17 00:00:00 2001
From: Ivan-267 <61947090+Ivan-267@users.noreply.github.com>
Date: Fri, 5 Apr 2024 22:10:48 +0200
Subject: [PATCH] Auto-set num_envs_per_worker in rllib_example.py

---
 examples/rllib_example.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/examples/rllib_example.py b/examples/rllib_example.py
index 8fcca985..0907422d 100644
--- a/examples/rllib_example.py
+++ b/examples/rllib_example.py
@@ -49,13 +49,21 @@ def env_creator(env_config):
 
 tune.register_env(env_name, env_creator)
 
-# Make temp env to get info needed for multi-agent training config
-if is_multiagent:
+policy_names = None
+num_envs = None
+tmp_env = None
+
+if is_multiagent:  # Make temp env to get info needed for multi-agent training config
     print("Starting a temporary multi-agent env to get the policy names")
     tmp_env = GDRLPettingZooEnv(config=exp["config"]["env_config"], show_window=False)
     policy_names = tmp_env.agent_policy_names
     print("Policy names for each Agent (AIController) set in the Godot Environment", policy_names)
-    tmp_env.close()
+else:  # Make temp env to get info needed for setting num_workers training config
+    print("Starting a temporary env to get the number of envs and auto-set the num_envs_per_worker config value")
+    tmp_env = GodotEnv(env_path=exp["config"]["env_config"]["env_path"], show_window=False)
+    num_envs = tmp_env.num_envs
+
+tmp_env.close()
 
 def policy_mapping_fn(agent_id: int, episode, worker, **kwargs) -> str:
     return policy_names[agent_id]
@@ -67,6 +75,8 @@ def policy_mapping_fn(agent_id: int, episode, worker, **kwargs) -> str:
         "policies": {policy_name: PolicySpec() for policy_name in policy_names},
         "policy_mapping_fn": policy_mapping_fn,
     }
+else:
+    exp["config"]["num_envs_per_worker"] = num_envs
 
 tuner = None
 if not args.restore:
@@ -89,6 +99,7 @@ def policy_mapping_fn(agent_id: int, episode, worker, **kwargs) -> str:
     # Onnx export after training if a checkpoint was saved
     checkpoint = result.get_best_result().checkpoint
+    if checkpoint:
         result_path = result.get_best_result().path
         ppo = Algorithm.from_checkpoint(checkpoint)