diff --git a/benchmarl/algorithms/common.py b/benchmarl/algorithms/common.py index e9b885b1..07cd9840 100644 --- a/benchmarl/algorithms/common.py +++ b/benchmarl/algorithms/common.py @@ -67,12 +67,6 @@ def _check_specs(self): "you can apply a transform to your environment to satisfy this criteria." ) for group in self.group_map.keys(): - if len(self.observation_spec[group].keys(True, True)) != 1: - raise ValueError( - "Observation spec must contain one entry per group" - " to follow the library conventions, " - "you can apply a transform to your environment to satisfy this criteria." - ) if ( len(self.action_spec[group].keys(True, True)) != 1 or list(self.action_spec[group].keys())[0] != "action" diff --git a/benchmarl/environments/meltingpot/common.py b/benchmarl/environments/meltingpot/common.py index f209d8b1..6dfb39fa 100644 --- a/benchmarl/environments/meltingpot/common.py +++ b/benchmarl/environments/meltingpot/common.py @@ -10,7 +10,13 @@ from tensordict import TensorDictBase from torchrl.data import CompositeSpec -from torchrl.envs import DoubleToFloat, DTypeCastTransform, EnvBase, Transform +from torchrl.envs import ( + DoubleToFloat, + DTypeCastTransform, + EnvBase, + FlattenObservation, + Transform, +) from benchmarl.environments.common import Task from benchmarl.utils import DEVICE_TYPING @@ -81,6 +87,7 @@ def get_env_fun( return lambda: MeltingpotEnv( substrate=self.name.lower(), categorical_actions=True, + device=device, **self.config, ) @@ -100,7 +107,23 @@ def group_map(self, env: EnvBase) -> Dict[str, List[str]]: return env.group_map def get_env_transforms(self, env: EnvBase) -> List[Transform]: - return [DoubleToFloat()] + interaction_inventories_keys = [ + (group, "observation", "INTERACTION_INVENTORIES") + for group in self.group_map(env).keys() + if (group, "observation", "INTERACTION_INVENTORIES") + in env.observation_spec.keys(True, True) + ] + return [DoubleToFloat()] + ( + [ + FlattenObservation( + in_keys=interaction_inventories_keys, + first_dim=-2, + last_dim=-1, + ) + ] + if len(interaction_inventories_keys) + else [] + ) def get_replay_buffer_transforms(self, env: EnvBase) -> List[Transform]: return [ @@ -141,11 +164,6 @@ def observation_spec(self, env: EnvBase) -> CompositeSpec: for group_key in list(observation_spec.keys()): if group_key not in self.group_map(env).keys(): del observation_spec[group_key] - else: - group_obs_spec = observation_spec[group_key]["observation"] - for key in list(group_obs_spec.keys()): - if key != "RGB": - del group_obs_spec[key] return observation_spec def info_spec(self, env: EnvBase) -> Optional[CompositeSpec]: diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index f7a75a99..e74a97a0 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -387,14 +387,6 @@ def _setup_task(self): device=self.config.sampling_device, ) ) - self.observation_spec = self.task.observation_spec(test_env) - self.info_spec = self.task.info_spec(test_env) - self.state_spec = self.task.state_spec(test_env) - self.action_mask_spec = self.task.action_mask_spec(test_env) - self.action_spec = self.task.action_spec(test_env) - self.group_map = self.task.group_map(test_env) - self.train_group_map = copy.deepcopy(self.group_map) - self.max_steps = self.task.max_steps(test_env) transforms_env = self.task.get_env_transforms(test_env) transforms_training = transforms_env + [ @@ -418,6 +410,15 @@ def _setup_task(self): self.config.sampling_device ) + self.observation_spec = self.task.observation_spec(self.test_env) + self.info_spec = self.task.info_spec(self.test_env) + self.state_spec = self.task.state_spec(self.test_env) + self.action_mask_spec = self.task.action_mask_spec(self.test_env) + self.action_spec = self.task.action_spec(self.test_env) + self.group_map = self.task.group_map(self.test_env) + self.train_group_map = copy.deepcopy(self.group_map) + self.max_steps = self.task.max_steps(self.test_env) + def _setup_algorithm(self): self.algorithm = self.algorithm_config.get_algorithm(experiment=self) self.replay_buffers = {