
Commit

amend
matteobettini committed Oct 4, 2024
1 parent 39df9cc commit ab30f13
Showing 1 changed file with 112 additions and 99 deletions.
211 changes: 112 additions & 99 deletions benchmarl/experiment/logger.py
@@ -93,13 +93,21 @@ def log_collection(
step: int,
) -> float:
to_log = {}
json_metrics = {}
groups_episode_rewards = []
global_done = self._get_global_done(batch) # Does not have agent dim
any_episode_ended = global_done.nonzero().numel() > 0
if not any_episode_ended:
warnings.warn(
"No episode terminated this iteration and thus the episode rewards will be NaN, "
"this is normal if your horizon is longer then one iteration. Learning is proceeding fine."
"The episodes will probably terminate in a future iteration."
)
for group in self.group_map.keys():
done = self._get_global_done(group, batch) # Does not have agent dim
group_episode_rewards = self._log_individual_and_group_rewards(
group, batch, done, to_log, prefix="collection"
group, batch, global_done, any_episode_ended, to_log
)
json_metrics[group + "_return"] = group_episode_rewards
# group_episode_rewards has shape (n_episodes) as we took the mean over agents in the group
groups_episode_rewards.append(group_episode_rewards)

if "info" in batch.get(("next", group)).keys():
to_log.update(
@@ -118,8 +126,9 @@ def log_collection(
}
)
to_log.update(task.log_info(batch))
# global_episode_rewards has shape (n_episodes) as we took the mean over groups
global_episode_rewards = self._log_global_episode_reward(
list(json_metrics.values()), to_log, prefix="collection"
groups_episode_rewards, to_log, prefix="collection"
)

self.log(to_log, step=step)
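
For readers following the new log_collection flow above, here is a minimal, self-contained sketch of the masking pattern it relies on. All shapes and values below are hypothetical stand-ins rather than code from the repository; the point is only to show how the shared global done is broadcast over the agent dimension and how per-agent episode rewards collapse to one return per finished episode.

import torch

# Hypothetical collection batch: 5 frames, 2 agents in the group
n_frames, n_agents = 5, 2
episode_reward = torch.rand(n_frames, n_agents, 1)  # per-agent cumulative episode reward
global_done = torch.tensor([False, False, True, False, True]).unsqueeze(-1)  # (n_frames, 1)

any_episode_ended = global_done.nonzero().numel() > 0  # True: two frames end an episode

# Broadcast the shared done over the agent dimension, as done for the group tensors
agent_done = global_done.unsqueeze(-1).expand(n_frames, n_agents, 1)  # (n_frames, n_agents, 1)

if any_episode_ended:
    per_agent_finished = episode_reward[agent_done]  # one entry per agent per finished episode
    group_episode_rewards = episode_reward.mean(-2)[global_done]  # shape (n_finished_episodes,)

Here group_episode_rewards plays the role of the per-group tensor appended to groups_episode_rewards in the diff.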
Expand All @@ -145,57 +154,51 @@ def log_evaluation(
not len(self.loggers) and not self.experiment_config.create_json
) or not len(rollouts):
return

# Cut rollouts at first done
max_length_rollout_0 = 0
for i in range(len(rollouts)):
r = rollouts[i]
next_done = self._get_global_done(r)
# Reduce it to batch size
next_done = next_done.sum(
tuple(range(r.batch_dims, next_done.ndim)),
dtype=torch.bool,
)
# First done index for this traj
done_index = next_done.nonzero(as_tuple=True)[0]
if done_index.numel() > 0:
done_index = done_index[0]
r = r[: done_index + 1]
if i == 0:
max_length_rollout_0 = max(r.batch_size[0], max_length_rollout_0)
rollouts[i] = r

to_log = {}
json_metrics = {}
max_length_rollout_0 = 0
for group in self.group_map.keys():
# Cut the rollouts at the first done
rollouts_group = []
for i, r in enumerate(rollouts):
next_done = self._get_agents_done(group, r)
# Reduce it to batch size
next_done = next_done.sum(
tuple(range(r.batch_dims, next_done.ndim)),
dtype=torch.bool,
)
# First done index for this traj
done_index = next_done.nonzero(as_tuple=True)[0]
if done_index.numel() > 0:
done_index = done_index[0]
r = r[: done_index + 1]
if i == 0:
max_length_rollout_0 = max(r.batch_size[0], max_length_rollout_0)
rollouts_group.append(r)

returns = [
self._get_reward(group, td).sum(0).mean().item()
for td in rollouts_group
]
json_metrics[group + "_return"] = torch.tensor(
returns, device=rollouts_group[0].device
# returns has shape (n_episodes, n_agents_in_group)
returns = torch.stack(
[self._get_reward(group, td).sum(0) for td in rollouts],
dim=0,
)
to_log.update(
{
f"eval/{group}/reward/episode_reward_min": min(returns),
f"eval/{group}/reward/episode_reward_mean": sum(returns)
/ len(rollouts_group),
f"eval/{group}/reward/episode_reward_max": max(returns),
}
self._log_min_mean_max(
to_log, f"eval/{group}/reward/episode_reward", returns
)
json_metrics[group + "_return"] = returns.mean(
dim=tuple(range(1, returns.ndim))
) # result has shape (n_episodes) as we take the mean over agents in the group

mean_group_return = torch.stack(
[value for key, value in json_metrics.items()], dim=0
).mean(0)
to_log.update(
{
"eval/reward/episode_reward_min": mean_group_return.min().item(),
"eval/reward/episode_reward_mean": mean_group_return.mean().item(),
"eval/reward/episode_reward_max": mean_group_return.max().item(),
"eval/reward/episode_len_mean": sum(td.batch_size[0] for td in rollouts)
/ len(rollouts),
}
mean_group_return = self._log_global_episode_reward(
list(json_metrics.values()), to_log, prefix="eval"
)
# mean_group_return has shape (n_episodes) as we take the mean over groups
json_metrics["return"] = mean_group_return

to_log["eval/reward/episode_len_mean"] = sum(
td.batch_size[0] for td in rollouts
) / len(rollouts)

if self.json_writer is not None:
self.json_writer.write(
metrics=json_metrics,
@@ -252,15 +255,15 @@ def _get_reward(
def _get_agents_done(
self, group: str, td: TensorDictBase, remove_agent_dim: bool = False
):
if ("next", group, "done") not in td.keys(True, True):
group_td = td.get(("next", group))
if ("done") not in group_td.keys():
done = td.get(("next", "done")).expand(td.get(group).shape).unsqueeze(-1)
else:
done = td.get(("next", group, "done"))
done = group_td.get("done")
return done.any(-2) if remove_agent_dim else done
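
As a side note on the fallback in _get_agents_done above: when a group does not provide its own "done" entry, the shared global done is broadcast to a per-agent done. Below is a minimal sketch with hypothetical shapes (4 frames, 3 agents), not code from the repository.

import torch

global_done = torch.tensor([[False], [False], [True], [False]])  # ("next", "done"), shape (4, 1)
group_shape = (4, 3)  # shape of the group sub-tensordict, i.e. td.get(group).shape

per_agent_done = global_done.expand(group_shape).unsqueeze(-1)  # (4, 3, 1): all agents share the env done
collapsed = per_agent_done.any(-2)  # (4, 1): what the remove_agent_dim=True branch returns

Because the per-agent flags are copies of the same environment flag, any(-2) simply recovers the original (4, 1) mask.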

def _get_global_done(
self,
group: str,
td: TensorDictBase,
):
done = td.get(("next", "done"))
@@ -273,67 +276,77 @@ def _get_episode_reward(
return episode_reward.mean(-2) if remove_agent_dim else episode_reward

def _log_individual_and_group_rewards(
self, group, batch, global_done, to_log, prefix: str
self,
group: str,
batch: TensorDictBase,
global_done: Tensor,
any_episode_ended: bool,
to_log: Dict[str, Tensor],
prefix: str = "collection",
log_individual_agents: bool = True,
):
reward = self._get_reward(group, batch) # Has agent dim
episode_reward = self._get_episode_reward(group, batch) # Has agent dim
n_agents_in_group = episode_reward.shape[-2]

# The trajectories are considered until the global done
episode_reward = episode_reward[
global_done.unsqueeze(-1).expand((*batch.get_item_shape(group), 1))
]
# Add multiagent dim
unsqueeze_global_done = global_done.unsqueeze(-1).expand(
(*batch.get_item_shape(group), 1)
)
#######
# All trajectories are considered done at the global done
#######

# 1. Here we log rewards from individual agent data
if log_individual_agents:
for i in range(n_agents_in_group):
self._log_min_mean_max(
to_log,
f"{prefix}/{group}/reward/agent_{i}/reward",
reward[..., i, :],
)
if any_episode_ended:
agent_global_done = unsqueeze_global_done[..., i, :]
self._log_min_mean_max(
to_log,
f"{prefix}/{group}/reward/agent_{i}/episode_reward",
episode_reward[..., i, :][agent_global_done],
)

for i in range(n_agents_in_group):
to_log.update(
{
f"{prefix}/{group}/reward/agent_{i}/reward_min": reward[..., i, :]
.min()
.item(),
f"{prefix}/{group}/reward/agent_{i}/reward_mean": reward[..., i, :]
.mean()
.item(),
f"{prefix}/{group}/reward/agent_{i}/reward_max": reward[..., i, :]
.max()
.item(),
}
# 2. Here we log rewards from group data
if any_episode_ended:
group_episode_reward = episode_reward[unsqueeze_global_done]
self._log_min_mean_max(
to_log, f"{prefix}/{group}/reward/episode_reward", group_episode_reward
)
if episode_reward.numel() > 0:
self._log_min_mean_max(to_log, f"{prefix}/reward/reward", reward)

to_log.update(
{
f"{prefix}/{group}/reward/episode_reward_min": episode_reward.min().item(),
f"{prefix}/{group}/reward/episode_reward_mean": episode_reward.mean().item(),
f"{prefix}/{group}/reward/episode_reward_max": episode_reward.max().item(),
}
# 3. We take the mean over agents in the group so that we will later log from a global perspective
return episode_reward.mean(-2)[global_done]

def _log_global_episode_reward(
self, episode_rewards: List[Tensor], to_log: Dict[str, Tensor], prefix: str
):
# Each element in the list is the episode reward (with shape n_episodes) for the group at the global done,
# so they all have the same shape, as the done is shared
episode_rewards = torch.stack(episode_rewards, dim=0).mean(
0
) # Mean over groups
if episode_rewards.numel() > 0:
self._log_min_mean_max(
to_log, f"{prefix}/reward/episode_reward", episode_rewards
)

return episode_rewards
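
A self-contained sketch (hypothetical values, not repository code) of the aggregation performed by the new _log_global_episode_reward above, together with the _log_min_mean_max helper defined just below: the per-group episode returns all have the same length because the done is global, so they can be stacked, averaged over groups, and logged as min/mean/max.

import torch

# Two groups, three finished episodes each (hypothetical returns)
episode_rewards = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([3.0, 4.0, 5.0])]

global_return = torch.stack(episode_rewards, dim=0).mean(0)  # mean over groups -> tensor([2., 3., 4.])

to_log = {}
key = "collection/reward/episode_reward"
to_log.update(
    {
        key + "_min": global_return.min().item(),  # 2.0
        key + "_mean": global_return.mean().item(),  # 3.0
        key + "_max": global_return.max().item(),  # 4.0
    }
)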

def _log_min_mean_max(self, to_log: Dict[str, Tensor], key: str, value: Tensor):
to_log.update(
{
f"{prefix}/{group}/reward/reward_min": reward.min().item(),
f"{prefix}/{group}/reward/reward_mean": reward.mean().item(),
f"{prefix}/{group}/reward/reward_max": reward.max().item(),
key + "_min": value.min().item(),
key + "_mean": value.mean().item(),
key + "_max": value.max().item(),
}
)
return episode_reward

def _log_global_episode_reward(self, episode_rewards, to_log, prefix: str):
# Each element in the list is the episode reward for the group at the global done, so they all have the same shape as the done is shared
episode_rewards = torch.stack(episode_rewards, dim=0)
if episode_rewards.numel() > 0:
to_log.update(
{
f"{prefix}/reward/episode_reward_min": episode_rewards.min().item(),
f"{prefix}/reward/episode_reward_mean": episode_rewards.mean().item(),
f"{prefix}/reward/episode_reward_max": episode_rewards.max().item(),
}
)
else:
warnings.warn(
"No episode terminated this iteration and thus the overall episode reward in NaN, "
"this is normal if your horizon is longer then one iteration. Learning is proceeding fine."
"The episodes will probably terminate in a future iteration."
)
return episode_rewards


class JsonWriter:

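Finally, a note on the rollout truncation that log_evaluation now performs once, before the per-group loop: every trajectory is cut at its first global done so that returns are computed over a single episode. The sketch below only mirrors that slicing logic with hypothetical tensors (a plain torch.arange stands in for the rollout TensorDict r); it is not code from the repository.

import torch

traj_len = 8
next_done = torch.zeros(traj_len, 1, dtype=torch.bool)
next_done[5] = True  # the global done fires at frame 5

# Collapse any trailing dims onto the time dimension, as in the sum(..., dtype=torch.bool) call
next_done = next_done.sum(tuple(range(1, next_done.ndim)), dtype=torch.bool)  # shape (traj_len,)

rollout = torch.arange(traj_len)  # hypothetical stand-in for the rollout TensorDict r
done_index = next_done.nonzero(as_tuple=True)[0]
if done_index.numel() > 0:
    rollout = rollout[: done_index[0] + 1]  # keep frames up to and including the first done

print(rollout)  # tensor([0, 1, 2, 3, 4, 5])

In the diff, the same slice is applied to the whole TensorDict, so rewards, dones, and infos are truncated together.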