diff --git a/benchmarl/experiment/logger.py b/benchmarl/experiment/logger.py
index b33a8243..f1b15717 100644
--- a/benchmarl/experiment/logger.py
+++ b/benchmarl/experiment/logger.py
@@ -93,13 +93,21 @@ def log_collection(
         step: int,
     ) -> float:
         to_log = {}
-        json_metrics = {}
+        groups_episode_rewards = []
+        global_done = self._get_global_done(batch)  # Does not have agent dim
+        any_episode_ended = global_done.nonzero().numel() > 0
+        if not any_episode_ended:
+            warnings.warn(
+                "No episode terminated this iteration and thus the episode rewards will be NaN, "
+                "this is normal if your horizon is longer than one iteration. Learning is proceeding fine. "
+                "The episodes will probably terminate in a future iteration."
+            )
         for group in self.group_map.keys():
-            done = self._get_global_done(group, batch)  # Does not have agent dim
             group_episode_rewards = self._log_individual_and_group_rewards(
-                group, batch, done, to_log, prefix="collection"
+                group, batch, global_done, any_episode_ended, to_log
             )
-            json_metrics[group + "_return"] = group_episode_rewards
+            # group_episode_rewards has shape (n_episodes) as we took the mean over agents in the group
+            groups_episode_rewards.append(group_episode_rewards)

             if "info" in batch.get(("next", group)).keys():
                 to_log.update(
@@ -118,8 +126,9 @@ def log_collection(
                     }
                 )
         to_log.update(task.log_info(batch))
+        # global_episode_rewards has shape (n_episodes) as we took the mean over groups
         global_episode_rewards = self._log_global_episode_reward(
-            list(json_metrics.values()), to_log, prefix="collection"
+            groups_episode_rewards, to_log, prefix="collection"
         )

         self.log(to_log, step=step)
@@ -145,57 +154,51 @@ def log_evaluation(
             not len(self.loggers) and not self.experiment_config.create_json
         ) or not len(rollouts):
             return
+
+        # Cut rollouts at first done
+        max_length_rollout_0 = 0
+        for i in range(len(rollouts)):
+            r = rollouts[i]
+            next_done = self._get_global_done(r)
+            # Reduce it to batch size
+            next_done = next_done.sum(
+                tuple(range(r.batch_dims, next_done.ndim)),
+                dtype=torch.bool,
+            )
+            # First done index for this traj
+            done_index = next_done.nonzero(as_tuple=True)[0]
+            if done_index.numel() > 0:
+                done_index = done_index[0]
+                r = r[: done_index + 1]
+            if i == 0:
+                max_length_rollout_0 = max(r.batch_size[0], max_length_rollout_0)
+            rollouts[i] = r
+
         to_log = {}
         json_metrics = {}
-        max_length_rollout_0 = 0
         for group in self.group_map.keys():
-            # Cut the rollouts at the first done
-            rollouts_group = []
-            for i, r in enumerate(rollouts):
-                next_done = self._get_agents_done(group, r)
-                # Reduce it to batch size
-                next_done = next_done.sum(
-                    tuple(range(r.batch_dims, next_done.ndim)),
-                    dtype=torch.bool,
-                )
-                # First done index for this traj
-                done_index = next_done.nonzero(as_tuple=True)[0]
-                if done_index.numel() > 0:
-                    done_index = done_index[0]
-                    r = r[: done_index + 1]
-                if i == 0:
-                    max_length_rollout_0 = max(r.batch_size[0], max_length_rollout_0)
-                rollouts_group.append(r)
-
-            returns = [
-                self._get_reward(group, td).sum(0).mean().item()
-                for td in rollouts_group
-            ]
-            json_metrics[group + "_return"] = torch.tensor(
-                returns, device=rollouts_group[0].device
+            # returns has shape (n_episodes, n_agents_in_group)
+            returns = torch.stack(
+                [self._get_reward(group, td).sum(0) for td in rollouts],
+                dim=0,
             )
-            to_log.update(
-                {
-                    f"eval/{group}/reward/episode_reward_min": min(returns),
-                    f"eval/{group}/reward/episode_reward_mean": sum(returns)
-                    / len(rollouts_group),
-                    f"eval/{group}/reward/episode_reward_max": max(returns),
-                }
+            self._log_min_mean_max(
+                to_log, f"eval/{group}/reward/episode_reward", returns
             )
+            json_metrics[group + "_return"] = returns.mean(
+                dim=tuple(range(1, returns.ndim))
+            )  # result has shape (n_episodes) as we take the mean over agents in the group

-        mean_group_return = torch.stack(
-            [value for key, value in json_metrics.items()], dim=0
-        ).mean(0)
-        to_log.update(
-            {
-                "eval/reward/episode_reward_min": mean_group_return.min().item(),
-                "eval/reward/episode_reward_mean": mean_group_return.mean().item(),
-                "eval/reward/episode_reward_max": mean_group_return.max().item(),
-                "eval/reward/episode_len_mean": sum(td.batch_size[0] for td in rollouts)
-                / len(rollouts),
-            }
+        mean_group_return = self._log_global_episode_reward(
+            list(json_metrics.values()), to_log, prefix="eval"
         )
+        # mean_group_return has shape (n_episodes) as we take the mean over groups
         json_metrics["return"] = mean_group_return
+
+        to_log["eval/reward/episode_len_mean"] = sum(
+            td.batch_size[0] for td in rollouts
+        ) / len(rollouts)
+
         if self.json_writer is not None:
             self.json_writer.write(
                 metrics=json_metrics,
@@ -252,15 +255,15 @@ def _get_reward(
     def _get_agents_done(
         self, group: str, td: TensorDictBase, remove_agent_dim: bool = False
     ):
-        if ("next", group, "done") not in td.keys(True, True):
+        group_td = td.get(("next", group))
+        if "done" not in group_td.keys():
             done = td.get(("next", "done")).expand(td.get(group).shape).unsqueeze(-1)
         else:
-            done = td.get(("next", group, "done"))
+            done = group_td.get("done")
         return done.any(-2) if remove_agent_dim else done

     def _get_global_done(
         self,
-        group: str,
         td: TensorDictBase,
     ):
         done = td.get(("next", "done"))
@@ -273,67 +276,77 @@ def _get_episode_reward(
         return episode_reward.mean(-2) if remove_agent_dim else episode_reward

     def _log_individual_and_group_rewards(
-        self, group, batch, global_done, to_log, prefix: str
+        self,
+        group: str,
+        batch: TensorDictBase,
+        global_done: Tensor,
+        any_episode_ended: bool,
+        to_log: Dict[str, Tensor],
+        prefix: str = "collection",
+        log_individual_agents: bool = True,
     ):
         reward = self._get_reward(group, batch)  # Has agent dim
         episode_reward = self._get_episode_reward(group, batch)  # Has agent dim
         n_agents_in_group = episode_reward.shape[-2]
-        # The trajectories are considered until the global done
-        episode_reward = episode_reward[
-            global_done.unsqueeze(-1).expand((*batch.get_item_shape(group), 1))
-        ]
+        # Add multiagent dim
+        unsqueeze_global_done = global_done.unsqueeze(-1).expand(
+            (*batch.get_item_shape(group), 1)
+        )
+        #######
+        # All trajectories are considered done at the global done
+        #######
+
+        # 1. Here we log rewards from individual agent data
+        if log_individual_agents:
+            for i in range(n_agents_in_group):
+                self._log_min_mean_max(
+                    to_log,
+                    f"{prefix}/{group}/reward/agent_{i}/reward",
+                    reward[..., i, :],
+                )
+                if any_episode_ended:
+                    agent_global_done = unsqueeze_global_done[..., i, :]
+                    self._log_min_mean_max(
+                        to_log,
+                        f"{prefix}/{group}/reward/agent_{i}/episode_reward",
+                        episode_reward[..., i, :][agent_global_done],
+                    )

-        for i in range(n_agents_in_group):
-            to_log.update(
-                {
-                    f"{prefix}/{group}/reward/agent_{i}/reward_min": reward[..., i, :]
-                    .min()
-                    .item(),
-                    f"{prefix}/{group}/reward/agent_{i}/reward_mean": reward[..., i, :]
-                    .mean()
-                    .item(),
-                    f"{prefix}/{group}/reward/agent_{i}/reward_max": reward[..., i, :]
-                    .max()
-                    .item(),
-                }
+        # 2. Here we log rewards from group data
+        if any_episode_ended:
+            group_episode_reward = episode_reward[unsqueeze_global_done]
+            self._log_min_mean_max(
+                to_log, f"{prefix}/{group}/reward/episode_reward", group_episode_reward
             )
-        if episode_reward.numel() > 0:
+        self._log_min_mean_max(to_log, f"{prefix}/reward/reward", reward)

-            to_log.update(
-                {
-                    f"{prefix}/{group}/reward/episode_reward_min": episode_reward.min().item(),
-                    f"{prefix}/{group}/reward/episode_reward_mean": episode_reward.mean().item(),
-                    f"{prefix}/{group}/reward/episode_reward_max": episode_reward.max().item(),
-                }
+        # 3. We take the mean over agents in the group so that we will later log from a global perspective
+        return episode_reward.mean(-2)[global_done]
+
+    def _log_global_episode_reward(
+        self, episode_rewards: List[Tensor], to_log: Dict[str, Tensor], prefix: str
+    ):
+        # Each element in the list is the episode reward (with shape n_episodes) for the group at the global done,
+        # so they will have the same shape as done is shared
+        episode_rewards = torch.stack(episode_rewards, dim=0).mean(
+            0
+        )  # Mean over groups
+        if episode_rewards.numel() > 0:
+            self._log_min_mean_max(
+                to_log, f"{prefix}/reward/episode_reward", episode_rewards
             )
+
+        return episode_rewards
+
+    def _log_min_mean_max(self, to_log: Dict[str, Tensor], key: str, value: Tensor):
         to_log.update(
             {
-                f"{prefix}/{group}/reward/reward_min": reward.min().item(),
-                f"{prefix}/{group}/reward/reward_mean": reward.mean().item(),
-                f"{prefix}/{group}/reward/reward_max": reward.max().item(),
+                key + "_min": value.min().item(),
+                key + "_mean": value.mean().item(),
+                key + "_max": value.max().item(),
             }
         )
-        return episode_reward
-
-    def _log_global_episode_reward(self, episode_rewards, to_log, prefix: str):
-        # Each element in the list is the episode reward for the group at the global done, so they will have same shape as done is shared
-        episode_rewards = torch.stack(episode_rewards, dim=0)
-        if episode_rewards.numel() > 0:
-            to_log.update(
-                {
-                    f"{prefix}/reward/episode_reward_min": episode_rewards.min().item(),
-                    f"{prefix}/reward/episode_reward_mean": episode_rewards.mean().item(),
-                    f"{prefix}/reward/episode_reward_max": episode_rewards.max().item(),
-                }
-            )
-        else:
-            warnings.warn(
-                "No episode terminated this iteration and thus the overall episode reward in NaN, "
-                "this is normal if your horizon is longer then one iteration. Learning is proceeding fine."
-                "The episodes will probably terminate in a future iteration."
-            )
-        return episode_rewards


 class JsonWriter:
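
Note (commentary, not part of the patch): below is a minimal self-contained sketch of the min/mean/max pattern that the new _log_min_mean_max and _log_global_episode_reward helpers implement. The free-standing function name and the example tensors are illustrative assumptions, not BenchMARL API.

    import torch
    from typing import Dict, List
    from torch import Tensor

    def log_min_mean_max(to_log: Dict[str, float], key: str, value: Tensor) -> None:
        # Same pattern as Logger._log_min_mean_max: three scalars under a shared key prefix.
        to_log.update(
            {
                key + "_min": value.min().item(),
                key + "_mean": value.mean().item(),
                key + "_max": value.max().item(),
            }
        )

    # Hypothetical per-group episode returns, each of shape (n_episodes); done is shared across
    # groups, so the shapes match and stacking then averaging over dim 0 gives the global return.
    group_returns: List[Tensor] = [torch.tensor([1.0, 3.0]), torch.tensor([2.0, 4.0])]
    global_return = torch.stack(group_returns, dim=0).mean(0)  # tensor([1.5, 3.5])

    to_log: Dict[str, float] = {}
    log_min_mean_max(to_log, "collection/reward/episode_reward", global_return)
    # to_log == {"collection/reward/episode_reward_min": 1.5,
    #            "collection/reward/episode_reward_mean": 2.5,
    #            "collection/reward/episode_reward_max": 3.5}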