From 10e9288fed15f298209b81982e9b53d894c2952a Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 29 Apr 2024 11:22:23 +0100 Subject: [PATCH 1/3] amend --- benchmarl/algorithms/common.py | 6 ------ benchmarl/environments/meltingpot/common.py | 6 +----- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/benchmarl/algorithms/common.py b/benchmarl/algorithms/common.py index e9b885b1..07cd9840 100644 --- a/benchmarl/algorithms/common.py +++ b/benchmarl/algorithms/common.py @@ -67,12 +67,6 @@ def _check_specs(self): "you can apply a transform to your environment to satisfy this criteria." ) for group in self.group_map.keys(): - if len(self.observation_spec[group].keys(True, True)) != 1: - raise ValueError( - "Observation spec must contain one entry per group" - " to follow the library conventions, " - "you can apply a transform to your environment to satisfy this criteria." - ) if ( len(self.action_spec[group].keys(True, True)) != 1 or list(self.action_spec[group].keys())[0] != "action" diff --git a/benchmarl/environments/meltingpot/common.py b/benchmarl/environments/meltingpot/common.py index f209d8b1..a3111f9c 100644 --- a/benchmarl/environments/meltingpot/common.py +++ b/benchmarl/environments/meltingpot/common.py @@ -81,6 +81,7 @@ def get_env_fun( return lambda: MeltingpotEnv( substrate=self.name.lower(), categorical_actions=True, + device=device, **self.config, ) @@ -141,11 +142,6 @@ def observation_spec(self, env: EnvBase) -> CompositeSpec: for group_key in list(observation_spec.keys()): if group_key not in self.group_map(env).keys(): del observation_spec[group_key] - else: - group_obs_spec = observation_spec[group_key]["observation"] - for key in list(group_obs_spec.keys()): - if key != "RGB": - del group_obs_spec[key] return observation_spec def info_spec(self, env: EnvBase) -> Optional[CompositeSpec]: From d96a40f5ba1d04e103b644788286a1ecee8a56fb Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 29 Apr 2024 11:53:03 +0100 Subject: [PATCH 2/3] amend --- benchmarl/environments/meltingpot/common.py | 20 ++++++++++++++++++-- benchmarl/experiment/experiment.py | 17 +++++++++-------- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/benchmarl/environments/meltingpot/common.py b/benchmarl/environments/meltingpot/common.py index a3111f9c..2e16311b 100644 --- a/benchmarl/environments/meltingpot/common.py +++ b/benchmarl/environments/meltingpot/common.py @@ -10,7 +10,13 @@ from tensordict import TensorDictBase from torchrl.data import CompositeSpec -from torchrl.envs import DoubleToFloat, DTypeCastTransform, EnvBase, Transform +from torchrl.envs import ( + DoubleToFloat, + DTypeCastTransform, + EnvBase, + FlattenObservation, + Transform, +) from benchmarl.environments.common import Task from benchmarl.utils import DEVICE_TYPING @@ -101,7 +107,17 @@ def group_map(self, env: EnvBase) -> Dict[str, List[str]]: return env.group_map def get_env_transforms(self, env: EnvBase) -> List[Transform]: - return [DoubleToFloat()] + return [ + DoubleToFloat(), + FlattenObservation( + in_keys=[ + (group, "observation", "INTERACTION_INVENTORIES") + for group in self.group_map(env).keys() + ], + first_dim=-2, + last_dim=-1, + ), + ] def get_replay_buffer_transforms(self, env: EnvBase) -> List[Transform]: return [ diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index f7a75a99..e74a97a0 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -387,14 +387,6 @@ def _setup_task(self): device=self.config.sampling_device, ) ) - self.observation_spec = self.task.observation_spec(test_env) - self.info_spec = self.task.info_spec(test_env) - self.state_spec = self.task.state_spec(test_env) - self.action_mask_spec = self.task.action_mask_spec(test_env) - self.action_spec = self.task.action_spec(test_env) - self.group_map = self.task.group_map(test_env) - self.train_group_map = copy.deepcopy(self.group_map) - self.max_steps = self.task.max_steps(test_env) transforms_env = self.task.get_env_transforms(test_env) transforms_training = transforms_env + [ @@ -418,6 +410,15 @@ def _setup_task(self): self.config.sampling_device ) + self.observation_spec = self.task.observation_spec(self.test_env) + self.info_spec = self.task.info_spec(self.test_env) + self.state_spec = self.task.state_spec(self.test_env) + self.action_mask_spec = self.task.action_mask_spec(self.test_env) + self.action_spec = self.task.action_spec(self.test_env) + self.group_map = self.task.group_map(self.test_env) + self.train_group_map = copy.deepcopy(self.group_map) + self.max_steps = self.task.max_steps(self.test_env) + def _setup_algorithm(self): self.algorithm = self.algorithm_config.get_algorithm(experiment=self) self.replay_buffers = { From affe867515b5929762cc3034c08e748cf6123d34 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 29 Apr 2024 14:06:12 +0100 Subject: [PATCH 3/3] amend --- benchmarl/environments/meltingpot/common.py | 26 +++++++++++++-------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/benchmarl/environments/meltingpot/common.py b/benchmarl/environments/meltingpot/common.py index 2e16311b..6dfb39fa 100644 --- a/benchmarl/environments/meltingpot/common.py +++ b/benchmarl/environments/meltingpot/common.py @@ -107,17 +107,23 @@ def group_map(self, env: EnvBase) -> Dict[str, List[str]]: return env.group_map def get_env_transforms(self, env: EnvBase) -> List[Transform]: - return [ - DoubleToFloat(), - FlattenObservation( - in_keys=[ - (group, "observation", "INTERACTION_INVENTORIES") - for group in self.group_map(env).keys() - ], - first_dim=-2, - last_dim=-1, - ), + interaction_inventories_keys = [ + (group, "observation", "INTERACTION_INVENTORIES") + for group in self.group_map(env).keys() + if (group, "observation", "INTERACTION_INVENTORIES") + in env.observation_spec.keys(True, True) ] + return [DoubleToFloat()] + ( + [ + FlattenObservation( + in_keys=interaction_inventories_keys, + first_dim=-2, + last_dim=-1, + ) + ] + if len(interaction_inventories_keys) + else [] + ) def get_replay_buffer_transforms(self, env: EnvBase) -> List[Transform]: return [