Skip to content

Commit

Permalink
Auto-set num_envs_per_worker in rllib_example.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Ivan-267 authored Apr 5, 2024
1 parent beb1203 commit fe60a8b
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions examples/rllib_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,21 @@ def env_creator(env_config):

tune.register_env(env_name, env_creator)

# Make temp env to get info needed for multi-agent training config
if is_multiagent:
policy_names = None
num_envs = None
tmp_env = None

if is_multiagent: # Make temp env to get info needed for multi-agent training config
print("Starting a temporary multi-agent env to get the policy names")
tmp_env = GDRLPettingZooEnv(config=exp["config"]["env_config"], show_window=False)
policy_names = tmp_env.agent_policy_names
print("Policy names for each Agent (AIController) set in the Godot Environment", policy_names)
tmp_env.close()
else: # Make temp env to get info needed for setting num_workers training config
print("Starting a temporary env to get the number of envs and auto-set the num_envs_per_worker config value")
tmp_env = GodotEnv(env_path=exp["config"]["env_config"]["env_path"], show_window=False)
num_envs = tmp_env.num_envs

tmp_env.close()

def policy_mapping_fn(agent_id: int, episode, worker, **kwargs) -> str:
return policy_names[agent_id]
Expand All @@ -67,6 +75,8 @@ def policy_mapping_fn(agent_id: int, episode, worker, **kwargs) -> str:
"policies": {policy_name: PolicySpec() for policy_name in policy_names},
"policy_mapping_fn": policy_mapping_fn,
}
else:
exp["config"]["num_envs_per_worker"] = num_envs

tuner = None
if not args.restore:
Expand All @@ -89,6 +99,7 @@ def policy_mapping_fn(agent_id: int, episode, worker, **kwargs) -> str:

# Onnx export after training if a checkpoint was saved
checkpoint = result.get_best_result().checkpoint

if checkpoint:
result_path = result.get_best_result().path
ppo = Algorithm.from_checkpoint(checkpoint)
Expand Down

0 comments on commit fe60a8b

Please sign in to comment.