From b2143754ca0f6a6dc09fec42fb630ab6f77d144c Mon Sep 17 00:00:00 2001
From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
Date: Fri, 4 Oct 2024 12:38:21 +0200
Subject: [PATCH] [BugFix] More flexible episode_reward computation in logger
 (#136)

* amend
* amend
* disable single agent logging
* amend
* amend
* amend
---
 benchmarl/experiment/logger.py | 232 +++++++++++++++++++++------------
 1 file changed, 147 insertions(+), 85 deletions(-)

diff --git a/benchmarl/experiment/logger.py b/benchmarl/experiment/logger.py
index 6ed16ae4..20e516aa 100644
--- a/benchmarl/experiment/logger.py
+++ b/benchmarl/experiment/logger.py
@@ -6,7 +6,9 @@
 
 import json
 import os
+import warnings
 from pathlib import Path
+
 from typing import Dict, List, Optional
 
 import numpy as np
@@ -91,28 +93,27 @@ def log_collection(
         step: int,
     ) -> float:
         to_log = {}
-        json_metrics = {}
+        groups_episode_rewards = []
+        global_done = self._get_global_done(batch)  # Does not have agent dim
+        any_episode_ended = global_done.nonzero().numel() > 0
+        if not any_episode_ended:
+            warnings.warn(
+                "No episode terminated this iteration, so the episode rewards will be NaN; "
+                "this is normal if your horizon is longer than one iteration and learning is proceeding fine. "
+                "The episodes will probably terminate in a future iteration."
+            )
         for group in self.group_map.keys():
-            episode_reward = self._get_episode_reward(group, batch)
-            done = self._get_done(group, batch)
-            reward = self._get_reward(group, batch)
-            to_log.update(
-                {
-                    f"collection/{group}/reward/reward_min": reward.min().item(),
-                    f"collection/{group}/reward/reward_mean": reward.mean().item(),
-                    f"collection/{group}/reward/reward_max": reward.max().item(),
-                }
+            group_episode_rewards = self._log_individual_and_group_rewards(
+                group,
+                batch,
+                global_done,
+                any_episode_ended,
+                to_log,
+                log_individual_agents=False,  # Turn on if you want single-agent granularity
             )
-            json_metrics[group + "_return"] = episode_reward.mean(-2)[done.any(-2)]
-            episode_reward = episode_reward[done]
-            if episode_reward.numel() > 0:
-                to_log.update(
-                    {
-                        f"collection/{group}/reward/episode_reward_min": episode_reward.min().item(),
-                        f"collection/{group}/reward/episode_reward_mean": episode_reward.mean().item(),
-                        f"collection/{group}/reward/episode_reward_max": episode_reward.max().item(),
-                    }
-                )
+            # group_episode_rewards has shape (n_episodes) as we took the mean over agents in the group
+            groups_episode_rewards.append(group_episode_rewards)
+
             if "info" in batch.get(("next", group)).keys():
                 to_log.update(
                     {
@@ -130,19 +131,13 @@ def log_collection(
                 }
             )
         to_log.update(task.log_info(batch))
-        mean_group_return = torch.stack(
-            [value for key, value in json_metrics.items()], dim=0
-        ).mean(0)
-        if mean_group_return.numel() > 0:
-            to_log.update(
-                {
-                    "collection/reward/episode_reward_min": mean_group_return.min().item(),
-                    "collection/reward/episode_reward_mean": mean_group_return.mean().item(),
-                    "collection/reward/episode_reward_max": mean_group_return.max().item(),
-                }
-            )
+        # global_episode_rewards has shape (n_episodes) as we took the mean over groups
+        global_episode_rewards = self._log_global_episode_reward(
+            groups_episode_rewards, to_log, prefix="collection"
+        )
+
         self.log(to_log, step=step)
-        return mean_group_return.mean().item()
+        return global_episode_rewards.mean().item()
 
     def log_training(self, group: str, training_td: TensorDictBase, step: int):
         if not len(self.loggers):
@@ -164,57 +159,45 @@ def log_evaluation(
             not len(self.loggers) and not self.experiment_config.create_json
         ) or not len(rollouts):
             return
+
+        # Cut the rollouts at the first done
+        max_length_rollout_0 = 0
+        for i in range(len(rollouts)):
+            r = rollouts[i]
+            next_done = self._get_global_done(r).squeeze(-1)
+
+            # First done index for this traj
+            done_index = next_done.nonzero(as_tuple=True)[0]
+            if done_index.numel() > 0:
+                done_index = done_index[0]
+                r = r[: done_index + 1]
+            if i == 0:
+                max_length_rollout_0 = max(r.batch_size[0], max_length_rollout_0)
+            rollouts[i] = r
+
         to_log = {}
         json_metrics = {}
-        max_length_rollout_0 = 0
         for group in self.group_map.keys():
-            # Cut the rollouts at the first done
-            rollouts_group = []
-            for i, r in enumerate(rollouts):
-                next_done = self._get_done(group, r)
-                # Reduce it to batch size
-                next_done = next_done.sum(
-                    tuple(range(r.batch_dims, next_done.ndim)),
-                    dtype=torch.bool,
-                )
-                # First done index for this traj
-                done_index = next_done.nonzero(as_tuple=True)[0]
-                if done_index.numel() > 0:
-                    done_index = done_index[0]
-                    r = r[: done_index + 1]
-                if i == 0:
-                    max_length_rollout_0 = max(r.batch_size[0], max_length_rollout_0)
-                rollouts_group.append(r)
-
-            returns = [
-                self._get_reward(group, td).sum(0).mean().item()
-                for td in rollouts_group
-            ]
-            json_metrics[group + "_return"] = torch.tensor(
-                returns, device=rollouts_group[0].device
+            # returns has shape (n_episodes)
+            returns = torch.stack(
+                [self._get_reward(group, td).sum(0).mean() for td in rollouts],
+                dim=0,
             )
-            to_log.update(
-                {
-                    f"eval/{group}/reward/episode_reward_min": min(returns),
-                    f"eval/{group}/reward/episode_reward_mean": sum(returns)
-                    / len(rollouts_group),
-                    f"eval/{group}/reward/episode_reward_max": max(returns),
-                }
+            self._log_min_mean_max(
+                to_log, f"eval/{group}/reward/episode_reward", returns
             )
+            json_metrics[group + "_return"] = returns
 
-        mean_group_return = torch.stack(
-            [value for key, value in json_metrics.items()], dim=0
-        ).mean(0)
-        to_log.update(
-            {
-                "eval/reward/episode_reward_min": mean_group_return.min().item(),
-                "eval/reward/episode_reward_mean": mean_group_return.mean().item(),
-                "eval/reward/episode_reward_max": mean_group_return.max().item(),
-                "eval/reward/episode_len_mean": sum(td.batch_size[0] for td in rollouts)
-                / len(rollouts),
-            }
+        mean_group_return = self._log_global_episode_reward(
+            list(json_metrics.values()), to_log, prefix="eval"
         )
+        # mean_group_return has shape (n_episodes) as we take the mean over groups
        json_metrics["return"] = mean_group_return
+
+        to_log["eval/reward/episode_len_mean"] = sum(
+            td.batch_size[0] for td in rollouts
+        ) / len(rollouts)
+
         if self.json_writer is not None:
             self.json_writer.write(
                 metrics=json_metrics,
@@ -265,34 +248,113 @@ def finish(self):
     def _get_reward(
         self, group: str, td: TensorDictBase, remove_agent_dim: bool = False
     ):
-        if ("next", group, "reward") not in td.keys(True, True):
+        reward = td.get(("next", group, "reward"), None)
+        if reward is None:
             reward = (
                 td.get(("next", "reward")).expand(td.get(group).shape).unsqueeze(-1)
             )
-        else:
-            reward = td.get(("next", group, "reward"))
         return reward.mean(-2) if remove_agent_dim else reward
 
-    def _get_done(self, group: str, td: TensorDictBase, remove_agent_dim: bool = False):
-        if ("next", group, "done") not in td.keys(True, True):
+    def _get_agents_done(
+        self, group: str, td: TensorDictBase, remove_agent_dim: bool = False
+    ):
+        done = td.get(("next", group, "done"), None)
+        if done is None:
             done = td.get(("next", "done")).expand(td.get(group).shape).unsqueeze(-1)
-        else:
-            done = td.get(("next", group, "done"))
+
         return done.any(-2) if remove_agent_dim else done
 
+    def _get_global_done(
+        self,
+        td: TensorDictBase,
+    ):
+        done = td.get(("next", "done"))
+        return done
+
     def _get_episode_reward(
         self, group: str, td: TensorDictBase, remove_agent_dim: bool = False
     ):
-        if ("next", group, "episode_reward") not in td.keys(True, True):
+        episode_reward = td.get(("next", group, "episode_reward"), None)
+        if episode_reward is None:
             episode_reward = (
                 td.get(("next", "episode_reward"))
                 .expand(td.get(group).shape)
                 .unsqueeze(-1)
             )
-        else:
-            episode_reward = td.get(("next", group, "episode_reward"))
         return episode_reward.mean(-2) if remove_agent_dim else episode_reward
 
+    def _log_individual_and_group_rewards(
+        self,
+        group: str,
+        batch: TensorDictBase,
+        global_done: Tensor,
+        any_episode_ended: bool,
+        to_log: Dict[str, Tensor],
+        prefix: str = "collection",
+        log_individual_agents: bool = True,
+    ):
+        reward = self._get_reward(group, batch)  # Has agent dim
+        episode_reward = self._get_episode_reward(group, batch)  # Has agent dim
+        n_agents_in_group = episode_reward.shape[-2]
+
+        # Add multiagent dim
+        unsqueeze_global_done = global_done.unsqueeze(-1).expand(
+            (*batch.get_item_shape(group), 1)
+        )
+        #######
+        # All trajectories are considered done at the global done
+        #######
+
+        # 1. Here we log rewards from individual agent data
+        if log_individual_agents:
+            for i in range(n_agents_in_group):
+                self._log_min_mean_max(
+                    to_log,
+                    f"{prefix}/{group}/reward/agent_{i}/reward",
+                    reward[..., i, :],
+                )
+                if any_episode_ended:
+                    agent_global_done = unsqueeze_global_done[..., i, :]
+                    self._log_min_mean_max(
+                        to_log,
+                        f"{prefix}/{group}/reward/agent_{i}/episode_reward",
+                        episode_reward[..., i, :][agent_global_done],
+                    )
+
+        # 2. Here we log rewards from group data taking the mean over agents
+        group_episode_reward = episode_reward.mean(-2)[global_done]
+        if any_episode_ended:
+            self._log_min_mean_max(
+                to_log, f"{prefix}/{group}/reward/episode_reward", group_episode_reward
+            )
+        self._log_min_mean_max(to_log, f"{prefix}/reward/reward", reward)
+
+        return group_episode_reward
+
+    def _log_global_episode_reward(
+        self, episode_rewards: List[Tensor], to_log: Dict[str, Tensor], prefix: str
+    ):
+        # Each element in the list is the episode reward (with shape n_episodes) for the group at the global done,
+        # so they all have the same shape since the done is shared
+        episode_rewards = torch.stack(episode_rewards, dim=0).mean(
+            0
+        )  # Mean over groups
+        if episode_rewards.numel() > 0:
+            self._log_min_mean_max(
+                to_log, f"{prefix}/reward/episode_reward", episode_rewards
+            )
+
+        return episode_rewards
+
+    def _log_min_mean_max(self, to_log: Dict[str, Tensor], key: str, value: Tensor):
+        to_log.update(
+            {
+                key + "_min": value.min().item(),
+                key + "_mean": value.mean().item(),
+                key + "_max": value.max().item(),
+            }
+        )
+
 
 class JsonWriter:
     """
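As an aside for reviewers, the aggregation introduced above can be summarised in a few lines. The sketch below is NOT part of the patch: the two-group loop, step count, agent count, and done indices are made-up toy values, and it only mirrors what _log_individual_and_group_rewards and _log_global_episode_reward compute (agent-mean episode rewards taken at the shared global done, then a mean over groups).

import torch

n_steps, n_agents = 6, 3
global_done = torch.zeros(n_steps, 1, dtype=torch.bool)
global_done[2] = True  # first episode ends at step 2
global_done[5] = True  # second episode ends at step 5

groups_episode_rewards = []
for _ in range(2):  # two hypothetical agent groups
    # Per-agent cumulative episode reward, shape (n_steps, n_agents, 1)
    episode_reward = torch.randn(n_steps, n_agents, 1)
    # Mean over the agent dim, then keep only the steps where the global done fired:
    # shape (n_episodes,), like the value returned by _log_individual_and_group_rewards
    groups_episode_rewards.append(episode_reward.mean(-2)[global_done])

# Mean over groups (they share the same global done, hence the same shape),
# mirroring _log_global_episode_reward
global_episode_rewards = torch.stack(groups_episode_rewards, dim=0).mean(0)
print(global_episode_rewards.shape)  # torch.Size([2]): one value per finished episode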