[BugFix] More flexible episode_reward computation in logger (#136)
* amend

* amend

* disable single agent logging

* amend

* amend

* amend
matteobettini authored Oct 4, 2024
1 parent 58ff47b commit b214375
Showing 1 changed file with 147 additions and 85 deletions.
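
For orientation, here is a minimal, self-contained sketch (toy tensors and made-up group names, not the actual BenchMARL classes) of the aggregation pattern this commit introduces: per-group episode rewards, already reduced over the agent dimension, are averaged over groups and logged as min/mean/max scalars, mirroring the new _log_min_mean_max and _log_global_episode_reward helpers shown in the diff below.

import torch
from typing import Dict, List

def log_min_mean_max(to_log: Dict[str, float], key: str, value: torch.Tensor) -> None:
    # Same pattern as the new _log_min_mean_max helper: one scalar per statistic.
    to_log[key + "_min"] = value.min().item()
    to_log[key + "_mean"] = value.mean().item()
    to_log[key + "_max"] = value.max().item()

# Toy per-group episode rewards, shape (n_episodes,) for each group (values made up).
group_episode_rewards: List[torch.Tensor] = [
    torch.tensor([1.0, 2.0, 3.0]),  # e.g. group "agents_a"
    torch.tensor([0.5, 1.5, 2.5]),  # e.g. group "agents_b"
]

to_log: Dict[str, float] = {}
# Mean over groups yields the global per-episode reward, shape (n_episodes,).
global_episode_rewards = torch.stack(group_episode_rewards, dim=0).mean(0)
if global_episode_rewards.numel() > 0:
    log_min_mean_max(to_log, "collection/reward/episode_reward", global_episode_rewards)
print(to_log)  # {'..._min': 0.75, '..._mean': 1.75, '..._max': 2.75}
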
232 changes: 147 additions & 85 deletions benchmarl/experiment/logger.py
@@ -6,7 +6,9 @@

import json
import os
import warnings
from pathlib import Path

from typing import Dict, List, Optional

import numpy as np
@@ -91,28 +93,27 @@ def log_collection(
step: int,
) -> float:
to_log = {}
json_metrics = {}
groups_episode_rewards = []
global_done = self._get_global_done(batch)  # Does not have agent dim
any_episode_ended = global_done.nonzero().numel() > 0
if not any_episode_ended:
warnings.warn(
"No episode terminated this iteration, so the episode rewards will be NaN; "
"this is normal if your horizon is longer than one iteration and learning is proceeding fine. "
"The episodes will probably terminate in a future iteration."
)
for group in self.group_map.keys():
episode_reward = self._get_episode_reward(group, batch)
done = self._get_done(group, batch)
reward = self._get_reward(group, batch)
to_log.update(
{
f"collection/{group}/reward/reward_min": reward.min().item(),
f"collection/{group}/reward/reward_mean": reward.mean().item(),
f"collection/{group}/reward/reward_max": reward.max().item(),
}
group_episode_rewards = self._log_individual_and_group_rewards(
group,
batch,
global_done,
any_episode_ended,
to_log,
log_individual_agents=False,  # Turn on if you want single-agent granularity
)
json_metrics[group + "_return"] = episode_reward.mean(-2)[done.any(-2)]
episode_reward = episode_reward[done]
if episode_reward.numel() > 0:
to_log.update(
{
f"collection/{group}/reward/episode_reward_min": episode_reward.min().item(),
f"collection/{group}/reward/episode_reward_mean": episode_reward.mean().item(),
f"collection/{group}/reward/episode_reward_max": episode_reward.max().item(),
}
)
# group_episode_rewards has shape (n_episodes) as we took the mean over agents in the group
groups_episode_rewards.append(group_episode_rewards)

if "info" in batch.get(("next", group)).keys():
to_log.update(
{
@@ -130,19 +131,13 @@ def log_collection(
}
)
to_log.update(task.log_info(batch))
mean_group_return = torch.stack(
[value for key, value in json_metrics.items()], dim=0
).mean(0)
if mean_group_return.numel() > 0:
to_log.update(
{
"collection/reward/episode_reward_min": mean_group_return.min().item(),
"collection/reward/episode_reward_mean": mean_group_return.mean().item(),
"collection/reward/episode_reward_max": mean_group_return.max().item(),
}
)
# global_episode_rewards has shape (n_episodes) as we took the mean over groups
global_episode_rewards = self._log_global_episode_reward(
groups_episode_rewards, to_log, prefix="collection"
)

self.log(to_log, step=step)
return mean_group_return.mean().item()
return global_episode_rewards.mean().item()

def log_training(self, group: str, training_td: TensorDictBase, step: int):
if not len(self.loggers):
@@ -164,57 +159,45 @@ def log_evaluation(
not len(self.loggers) and not self.experiment_config.create_json
) or not len(rollouts):
return

# Cut rollouts at first done
max_length_rollout_0 = 0
for i in range(len(rollouts)):
r = rollouts[i]
next_done = self._get_global_done(r).squeeze(-1)

# First done index for this traj
done_index = next_done.nonzero(as_tuple=True)[0]
if done_index.numel() > 0:
done_index = done_index[0]
r = r[: done_index + 1]
if i == 0:
max_length_rollout_0 = max(r.batch_size[0], max_length_rollout_0)
rollouts[i] = r

to_log = {}
json_metrics = {}
max_length_rollout_0 = 0
for group in self.group_map.keys():
# Cut the rollouts at the first done
rollouts_group = []
for i, r in enumerate(rollouts):
next_done = self._get_done(group, r)
# Reduce it to batch size
next_done = next_done.sum(
tuple(range(r.batch_dims, next_done.ndim)),
dtype=torch.bool,
)
# First done index for this traj
done_index = next_done.nonzero(as_tuple=True)[0]
if done_index.numel() > 0:
done_index = done_index[0]
r = r[: done_index + 1]
if i == 0:
max_length_rollout_0 = max(r.batch_size[0], max_length_rollout_0)
rollouts_group.append(r)

returns = [
self._get_reward(group, td).sum(0).mean().item()
for td in rollouts_group
]
json_metrics[group + "_return"] = torch.tensor(
returns, device=rollouts_group[0].device
# returns has shape (n_episodes)
returns = torch.stack(
[self._get_reward(group, td).sum(0).mean() for td in rollouts],
dim=0,
)
to_log.update(
{
f"eval/{group}/reward/episode_reward_min": min(returns),
f"eval/{group}/reward/episode_reward_mean": sum(returns)
/ len(rollouts_group),
f"eval/{group}/reward/episode_reward_max": max(returns),
}
self._log_min_mean_max(
to_log, f"eval/{group}/reward/episode_reward", returns
)
json_metrics[group + "_return"] = returns

mean_group_return = torch.stack(
[value for key, value in json_metrics.items()], dim=0
).mean(0)
to_log.update(
{
"eval/reward/episode_reward_min": mean_group_return.min().item(),
"eval/reward/episode_reward_mean": mean_group_return.mean().item(),
"eval/reward/episode_reward_max": mean_group_return.max().item(),
"eval/reward/episode_len_mean": sum(td.batch_size[0] for td in rollouts)
/ len(rollouts),
}
mean_group_return = self._log_global_episode_reward(
list(json_metrics.values()), to_log, prefix="eval"
)
# mean_group_return has shape (n_episodes) as we take the mean over groups
json_metrics["return"] = mean_group_return

to_log["eval/reward/episode_len_mean"] = sum(
td.batch_size[0] for td in rollouts
) / len(rollouts)

if self.json_writer is not None:
self.json_writer.write(
metrics=json_metrics,
@@ -265,34 +248,113 @@ def finish(self):
def _get_reward(
self, group: str, td: TensorDictBase, remove_agent_dim: bool = False
):
if ("next", group, "reward") not in td.keys(True, True):
reward = td.get(("next", group, "reward"), None)
if reward is None:
reward = (
td.get(("next", "reward")).expand(td.get(group).shape).unsqueeze(-1)
)
else:
reward = td.get(("next", group, "reward"))
return reward.mean(-2) if remove_agent_dim else reward

def _get_done(self, group: str, td: TensorDictBase, remove_agent_dim: bool = False):
if ("next", group, "done") not in td.keys(True, True):
def _get_agents_done(
self, group: str, td: TensorDictBase, remove_agent_dim: bool = False
):
done = td.get(("next", group, "done"), None)
if done is None:
done = td.get(("next", "done")).expand(td.get(group).shape).unsqueeze(-1)
else:
done = td.get(("next", group, "done"))

return done.any(-2) if remove_agent_dim else done

def _get_global_done(
self,
td: TensorDictBase,
):
done = td.get(("next", "done"))
return done

def _get_episode_reward(
self, group: str, td: TensorDictBase, remove_agent_dim: bool = False
):
if ("next", group, "episode_reward") not in td.keys(True, True):
episode_reward = td.get(("next", group, "episode_reward"), None)
if episode_reward is None:
episode_reward = (
td.get(("next", "episode_reward"))
.expand(td.get(group).shape)
.unsqueeze(-1)
)
else:
episode_reward = td.get(("next", group, "episode_reward"))
return episode_reward.mean(-2) if remove_agent_dim else episode_reward

def _log_individual_and_group_rewards(
self,
group: str,
batch: TensorDictBase,
global_done: Tensor,
any_episode_ended: bool,
to_log: Dict[str, Tensor],
prefix: str = "collection",
log_individual_agents: bool = True,
):
reward = self._get_reward(group, batch) # Has agent dim
episode_reward = self._get_episode_reward(group, batch) # Has agent dim
n_agents_in_group = episode_reward.shape[-2]

# Add multiagent dim
unsqueeze_global_done = global_done.unsqueeze(-1).expand(
(*batch.get_item_shape(group), 1)
)
#######
# All trajectories are considered done at the global done
#######

# 1. Here we log rewards from individual agent data
if log_individual_agents:
for i in range(n_agents_in_group):
self._log_min_mean_max(
to_log,
f"{prefix}/{group}/reward/agent_{i}/reward",
reward[..., i, :],
)
if any_episode_ended:
agent_global_done = unsqueeze_global_done[..., i, :]
self._log_min_mean_max(
to_log,
f"{prefix}/{group}/reward/agent_{i}/episode_reward",
episode_reward[..., i, :][agent_global_done],
)

# 2. Here we log rewards from group data taking the mean over agents
group_episode_reward = episode_reward.mean(-2)[global_done]
if any_episode_ended:
self._log_min_mean_max(
to_log, f"{prefix}/{group}/reward/episode_reward", group_episode_reward
)
self._log_min_mean_max(to_log, f"{prefix}/reward/reward", reward)

return group_episode_reward

def _log_global_episode_reward(
self, episode_rewards: List[Tensor], to_log: Dict[str, Tensor], prefix: str
):
# Each element in the list is the episode reward (with shape n_episodes) for a group at the global done,
# so all elements have the same shape, since the done is shared across groups
episode_rewards = torch.stack(episode_rewards, dim=0).mean(
0
) # Mean over groups
if episode_rewards.numel() > 0:
self._log_min_mean_max(
to_log, f"{prefix}/reward/episode_reward", episode_rewards
)

return episode_rewards

def _log_min_mean_max(self, to_log: Dict[str, Tensor], key: str, value: Tensor):
to_log.update(
{
key + "_min": value.min().item(),
key + "_mean": value.mean().item(),
key + "_max": value.max().item(),
}
)


class JsonWriter:
"""
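To complement the diff, here is a short standalone sketch (hypothetical tensor sizes, values made up) of the shape handling used by the new _log_individual_and_group_rewards helper: the per-agent episode_reward is averaged over the agent dimension (-2) and then masked with the shared global done, so each group contributes one value per finished episode.

import torch

# Toy shapes for illustration: 4 collected steps, 2 agents in the group.
n_steps, n_agents = 4, 2
episode_reward = torch.arange(n_steps * n_agents, dtype=torch.float).reshape(n_steps, n_agents, 1)
global_done = torch.tensor([False, False, True, True]).unsqueeze(-1)  # shape (n_steps, 1)

# As in _log_individual_and_group_rewards: mean over the agent dim (-2),
# then keep only the entries where the shared done flag is set -> shape (n_episodes,).
group_episode_reward = episode_reward.mean(-2)[global_done]
print(group_episode_reward)  # tensor([4.5000, 6.5000]), one value per finished episode

With log_individual_agents=True the helper additionally logs the same statistics per agent, which is the single-agent granularity this commit disables by default.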
