diff --git a/benchmarl/algorithms/common.py b/benchmarl/algorithms/common.py index ed5efadc..219cc046 100644 --- a/benchmarl/algorithms/common.py +++ b/benchmarl/algorithms/common.py @@ -7,11 +7,12 @@ import pathlib from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Iterable, List, Optional, Tuple, Type +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type from tensordict import TensorDictBase from tensordict.nn import TensorDictModule, TensorDictSequential from torchrl.data import ( + CompositeSpec, DiscreteTensorSpec, LazyTensorStorage, OneHotDiscreteTensorSpec, @@ -19,7 +20,14 @@ TensorDictReplayBuffer, ) from torchrl.data.replay_buffers import RandomSampler, SamplerWithoutReplacement -from torchrl.envs import Compose, Transform +from torchrl.envs import ( + Compose, + EnvBase, + InitTracker, + TensorDictPrimer, + Transform, + TransformedEnv, +) from torchrl.objectives import LossModule from torchrl.objectives.utils import HardUpdate, SoftUpdate, TargetNetUpdater @@ -51,6 +59,16 @@ def __init__(self, experiment): self.action_spec = experiment.action_spec self.state_spec = experiment.state_spec self.action_mask_spec = experiment.action_mask_spec + self.has_independent_critic = ( + experiment.algorithm_config.has_independent_critic() + ) + self.has_centralized_critic = ( + experiment.algorithm_config.has_centralized_critic() + ) + self.has_critic = experiment.algorithm_config.has_critic + self.has_rnn = self.model_config.is_rnn or ( + self.critic_model_config.is_rnn and self.has_critic + ) # Cached values that will be instantiated only once and then remain fixed self._losses_and_updaters = {} @@ -142,10 +160,7 @@ def get_replay_buffer( """ memory_size = self.experiment_config.replay_buffer_memory_size(self.on_policy) sampling_size = self.experiment_config.train_minibatch_size(self.on_policy) - if ( - self.experiment.model_config.is_rnn - or self.experiment.critic_model_config.is_rnn - ): + if self.has_rnn: sequence_length = -( -self.experiment_config.collected_frames_per_batch(self.on_policy) // self.experiment_config.n_envs_per_worker(self.on_policy) @@ -229,6 +244,69 @@ def get_parameters(self, group: str) -> Dict[str, Iterable]: loss=self.get_loss_and_updater(group)[0], ) + def process_env_fun( + self, + env_fun: Callable[[], EnvBase], + ) -> Callable[[], EnvBase]: + """ + This function can be used to wrap env_fun + Args: + env_fun (callable): a function that takes no args and creates an enviornment + + Returns: a function that takes no args and creates an enviornment + + """ + if self.has_rnn: + + def model_fun(): + env = env_fun() + + spec_actor = self.model_config.get_model_state_spec() + spec_actor = CompositeSpec( + { + group: CompositeSpec( + spec_actor.expand(len(agents), *spec_actor.shape), + shape=(len(agents),), + ) + for group, agents in self.group_map.items() + } + ) + # if self.has_critic and self.critic_model_config.is_rnn and False: + # spec_critic = self.critic_model_config.get_model_state_spec() + # if ( + # self.has_independent_critic + # or not self.critic_model_config.share_param_critic + # ): + # spec_critic = CompositeSpec( + # { + # group: CompositeSpec( + # spec_critic.expand(len(agents), *spec_critic.shape), + # shape=(len(agents),), + # ) + # for group, agents in self.group_map.items() + # } + # ) + # spec_actor.update(spec_critic) + + env = TransformedEnv( + env, + Compose( + *( + [InitTracker(init_key="is_init")] + + ( + [TensorDictPrimer(spec_actor)] + if len(spec_actor.keys(True, True)) > 0 + else [] + ) + ) + ), + ) + return env + + return model_fun + + return env_fun + ############################### # Abstract methods to implement ############################### @@ -410,3 +488,27 @@ def supports_discrete_actions() -> bool: If the algorithm supports discrete actions """ raise NotImplementedError + + @staticmethod + def has_independent_critic() -> bool: + """ + If the algorithm uses an independent critic + """ + return False + + @staticmethod + def has_centralized_critic() -> bool: + """ + If the algorithm uses a centralized critic + """ + return False + + def has_critic(self) -> bool: + """ + If the algorithm uses a critic + """ + if self.has_centralized_critic() and self.has_independent_critic(): + raise ValueError( + "Algorithm can either have a centralized critic or an indpendent one" + ) + return self.has_centralized_critic() or self.has_independent_critic() diff --git a/benchmarl/algorithms/iddpg.py b/benchmarl/algorithms/iddpg.py index a114123b..ced5d137 100644 --- a/benchmarl/algorithms/iddpg.py +++ b/benchmarl/algorithms/iddpg.py @@ -123,12 +123,14 @@ def _get_policy_for_loss( in_keys=[(group, "param")], out_keys=[(group, "action")], distribution_class=TanhDelta if self.use_tanh_mapping else Delta, - distribution_kwargs={ - "min": self.action_spec[(group, "action")].space.low, - "max": self.action_spec[(group, "action")].space.high, - } - if self.use_tanh_mapping - else {}, + distribution_kwargs=( + { + "min": self.action_spec[(group, "action")].space.low, + "max": self.action_spec[(group, "action")].space.high, + } + if self.use_tanh_mapping + else {} + ), return_log_prob=False, safe=not self.use_tanh_mapping, ) @@ -249,3 +251,7 @@ def supports_discrete_actions() -> bool: @staticmethod def on_policy() -> bool: return False + + @staticmethod + def has_independent_critic() -> bool: + return True diff --git a/benchmarl/algorithms/ippo.py b/benchmarl/algorithms/ippo.py index aac2cd88..f66d7993 100644 --- a/benchmarl/algorithms/ippo.py +++ b/benchmarl/algorithms/ippo.py @@ -325,3 +325,7 @@ def supports_discrete_actions() -> bool: @staticmethod def on_policy() -> bool: return True + + @staticmethod + def has_independent_critic() -> bool: + return True diff --git a/benchmarl/algorithms/isac.py b/benchmarl/algorithms/isac.py index 74c29ea7..d96431d7 100644 --- a/benchmarl/algorithms/isac.py +++ b/benchmarl/algorithms/isac.py @@ -199,15 +199,17 @@ def _get_policy_for_loss( spec=self.action_spec[group, "action"], in_keys=[(group, "loc"), (group, "scale")], out_keys=[(group, "action")], - distribution_class=IndependentNormal - if not self.use_tanh_normal - else TanhNormal, - distribution_kwargs={ - "min": self.action_spec[(group, "action")].space.low, - "max": self.action_spec[(group, "action")].space.high, - } - if self.use_tanh_normal - else {}, + distribution_class=( + IndependentNormal if not self.use_tanh_normal else TanhNormal + ), + distribution_kwargs=( + { + "min": self.action_spec[(group, "action")].space.low, + "max": self.action_spec[(group, "action")].space.high, + } + if self.use_tanh_normal + else {} + ), return_log_prob=True, log_prob_key=(group, "log_prob"), ) @@ -387,3 +389,7 @@ def supports_discrete_actions() -> bool: @staticmethod def on_policy() -> bool: return False + + @staticmethod + def has_independent_critic() -> bool: + return True diff --git a/benchmarl/algorithms/maddpg.py b/benchmarl/algorithms/maddpg.py index 1590f81f..8967573b 100644 --- a/benchmarl/algorithms/maddpg.py +++ b/benchmarl/algorithms/maddpg.py @@ -123,12 +123,14 @@ def _get_policy_for_loss( in_keys=[(group, "param")], out_keys=[(group, "action")], distribution_class=TanhDelta if self.use_tanh_mapping else Delta, - distribution_kwargs={ - "min": self.action_spec[(group, "action")].space.low, - "max": self.action_spec[(group, "action")].space.high, - } - if self.use_tanh_mapping - else {}, + distribution_kwargs=( + { + "min": self.action_spec[(group, "action")].space.low, + "max": self.action_spec[(group, "action")].space.high, + } + if self.use_tanh_mapping + else {} + ), return_log_prob=False, safe=not self.use_tanh_mapping, ) @@ -299,3 +301,7 @@ def supports_discrete_actions() -> bool: @staticmethod def on_policy() -> bool: return False + + @staticmethod + def has_centralized_critic() -> bool: + return True diff --git a/benchmarl/algorithms/mappo.py b/benchmarl/algorithms/mappo.py index 891200ef..f8efedcd 100644 --- a/benchmarl/algorithms/mappo.py +++ b/benchmarl/algorithms/mappo.py @@ -361,3 +361,7 @@ def supports_discrete_actions() -> bool: @staticmethod def on_policy() -> bool: return True + + @staticmethod + def has_centralized_critic() -> bool: + return True diff --git a/benchmarl/algorithms/masac.py b/benchmarl/algorithms/masac.py index 358010ef..4f204158 100644 --- a/benchmarl/algorithms/masac.py +++ b/benchmarl/algorithms/masac.py @@ -199,15 +199,17 @@ def _get_policy_for_loss( spec=self.action_spec[group, "action"], in_keys=[(group, "loc"), (group, "scale")], out_keys=[(group, "action")], - distribution_class=IndependentNormal - if not self.use_tanh_normal - else TanhNormal, - distribution_kwargs={ - "min": self.action_spec[(group, "action")].space.low, - "max": self.action_spec[(group, "action")].space.high, - } - if self.use_tanh_normal - else {}, + distribution_class=( + IndependentNormal if not self.use_tanh_normal else TanhNormal + ), + distribution_kwargs=( + { + "min": self.action_spec[(group, "action")].space.low, + "max": self.action_spec[(group, "action")].space.high, + } + if self.use_tanh_normal + else {} + ), return_log_prob=True, log_prob_key=(group, "log_prob"), ) @@ -461,3 +463,7 @@ def supports_discrete_actions() -> bool: @staticmethod def on_policy() -> bool: return False + + @staticmethod + def has_centralized_critic() -> bool: + return True diff --git a/benchmarl/conf/model/layers/gru.yaml b/benchmarl/conf/model/layers/gru.yaml index 69b0a2f5..edd9aafe 100644 --- a/benchmarl/conf/model/layers/gru.yaml +++ b/benchmarl/conf/model/layers/gru.yaml @@ -2,7 +2,7 @@ name: gru hidden_size: 128 -compile: True +compile: False mlp_num_cells: [256, 256] mlp_layer_class: torch.nn.Linear diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index 05abaf48..b26d3fa9 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -322,8 +322,13 @@ def __init__( self.task = task self.model_config = model_config self.critic_model_config = ( - critic_model_config if critic_model_config is not None else model_config + critic_model_config + if critic_model_config is not None + else copy.deepcopy(model_config) ) + self.model_config.is_critic = False + self.critic_model_config.is_critic = True + self.algorithm_config = algorithm_config self.seed = seed @@ -377,23 +382,17 @@ def _set_action_type(self): ) def _setup_task(self): - test_env = self.model_config.process_env_fun( - self.task.get_env_fun( - num_envs=self.config.evaluation_episodes, - continuous_actions=self.continuous_actions, - seed=self.seed, - device=self.config.sampling_device, - ), - self.task, + test_env = self.task.get_env_fun( + num_envs=self.config.evaluation_episodes, + continuous_actions=self.continuous_actions, + seed=self.seed, + device=self.config.sampling_device, )() - env_func = self.model_config.process_env_fun( - self.task.get_env_fun( - num_envs=self.config.n_envs_per_worker(self.on_policy), - continuous_actions=self.continuous_actions, - seed=self.seed, - device=self.config.sampling_device, - ), - self.task, + env_func = self.task.get_env_fun( + num_envs=self.config.n_envs_per_worker(self.on_policy), + continuous_actions=self.continuous_actions, + seed=self.seed, + device=self.config.sampling_device, ) transforms_env = self.task.get_env_transforms(test_env) @@ -429,6 +428,10 @@ def _setup_task(self): def _setup_algorithm(self): self.algorithm = self.algorithm_config.get_algorithm(experiment=self) + + self.test_env = self.algorithm.process_env_fun(lambda: self.test_env)() + self.env_func = self.algorithm.process_env_fun(self.env_func) + self.replay_buffers = { group: self.algorithm.get_replay_buffer( group=group, @@ -612,7 +615,7 @@ def _collection_loop(self): for group in self.train_group_map.keys(): group_batch = batch.exclude(*self._get_excluded_keys(group)) group_batch = self.algorithm.process_batch(group, group_batch) - if not (self.model_config.is_rnn or self.critic_model_config.is_rnn): + if not self.algorithm.has_rnn: group_batch = group_batch.reshape(-1) self.replay_buffers[group].extend(group_batch) diff --git a/benchmarl/models/cnn.py b/benchmarl/models/cnn.py index c3e7611d..95a3c426 100644 --- a/benchmarl/models/cnn.py +++ b/benchmarl/models/cnn.py @@ -114,6 +114,8 @@ def __init__( share_params=kwargs.pop("share_params"), device=kwargs.pop("device"), action_spec=kwargs.pop("action_spec"), + model_index=kwargs.pop("model_index"), + is_critic=kwargs.pop("is_critic"), ) self.x = self.input_spec[self.image_in_keys[0]].shape[-3] diff --git a/benchmarl/models/common.py b/benchmarl/models/common.py index 3870990e..373a2846 100644 --- a/benchmarl/models/common.py +++ b/benchmarl/models/common.py @@ -8,13 +8,12 @@ import warnings from abc import ABC, abstractmethod from dataclasses import asdict, dataclass -from typing import Any, Callable, Dict, List, Optional, Sequence +from typing import Any, Dict, List, Optional, Sequence from tensordict import TensorDictBase from tensordict.nn import TensorDictModuleBase, TensorDictSequential from tensordict.utils import NestedKey from torchrl.data import CompositeSpec, TensorSpec, UnboundedContinuousTensorSpec -from torchrl.envs import EnvBase from benchmarl.utils import _class_from_name, _read_yaml_config, DEVICE_TYPING @@ -88,6 +87,8 @@ def __init__( share_params: bool, device: DEVICE_TYPING, action_spec: CompositeSpec, + model_index: int, + is_critic: bool, ): TensorDictModuleBase.__init__(self) @@ -100,6 +101,8 @@ def __init__( self.device = device self.n_agents = n_agents self.action_spec = action_spec + self.model_index = model_index + self.is_critic = is_critic self.in_keys = list(self.input_spec.keys(True, True)) self.out_keys = list(self.output_spec.keys(True, True)) @@ -220,6 +223,8 @@ def __init__( agent_group=models[0].agent_group, input_has_agent_dim=models[0].input_has_agent_dim, action_spec=models[0].action_spec, + model_index=models[0].model_index, + is_critic=models[0].is_critic, ) self.models = TensorDictSequential(*models) @@ -250,6 +255,7 @@ def get_model( share_params: bool, device: DEVICE_TYPING, action_spec: CompositeSpec, + model_index: int = 0, ) -> Model: """ Creates the model from the config. @@ -288,6 +294,8 @@ def get_model( share_params=share_params, device=device, action_spec=action_spec, + model_index=model_index, + is_critic=self.is_critic, ) @staticmethod @@ -298,27 +306,23 @@ def associated_class(): """ raise NotImplementedError - def process_env_fun( - self, - env_fun: Callable[[], EnvBase], - task, - model_index: int = 0, - ) -> Callable[[], EnvBase]: - """ - This function can be used to wrap env_fun - Args: - env_fun (callable): a function that takes no args and creates an enviornment - task (Task): the task - - Returns: a function that takes no args and creates an enviornment - - """ - return env_fun - @property def is_rnn(self) -> bool: return False + @property + def is_critic(self): + if not hasattr(self, "_is_critic"): + raise AttributeError() + return self._is_critic + + @is_critic.setter + def is_critic(self, value): + self._is_critic = value + + def get_model_state_spec(self, model_index: int = 0) -> CompositeSpec: + return CompositeSpec() + @staticmethod def _load_from_yaml(name: str) -> Dict[str, Any]: yaml_path = ( @@ -402,6 +406,7 @@ def get_model( share_params: bool, device: DEVICE_TYPING, action_spec: CompositeSpec, + model_index: int = 0, ) -> Model: n_models = len(self.model_configs) if not n_models > 0: @@ -418,7 +423,7 @@ def get_model( intermediate_specs = [ CompositeSpec( { - f"_{agent_group}_intermediate_{i}": UnboundedContinuousTensorSpec( + f"_{agent_group}{'_critic' if self.is_critic else ''}_intermediate_{i}": UnboundedContinuousTensorSpec( shape=(n_agents, size) if out_has_agent_dim else (size,) ) } @@ -437,6 +442,8 @@ def get_model( share_params=share_params, device=device, action_spec=action_spec, + model_index=0, + is_critic=self.is_critic, ) ] @@ -451,6 +458,8 @@ def get_model( share_params=share_params, device=device, action_spec=action_spec, + model_index=i, + is_critic=self.is_critic, ) for i in range(1, n_models) ] @@ -461,15 +470,23 @@ def get_model( def associated_class(): return SequenceModel - def process_env_fun( - self, - env_fun: Callable[[], EnvBase], - task, - model_index: int = 0, - ) -> Callable[[], EnvBase]: + @property + def is_critic(self): + if not hasattr(self, "_is_critic"): + raise AttributeError() + return self._is_critic + + @is_critic.setter + def is_critic(self, value): + self._is_critic = value + for model_config in self.model_configs: + model_config.is_critic = value + + def get_model_state_spec(self, model_index: int = 0) -> CompositeSpec: + spec = CompositeSpec() for i, model_config in enumerate(self.model_configs): - env_fun = model_config.process_env_fun(env_fun, task, i) - return env_fun + spec.update(model_config.get_model_state_spec(model_index=i)) + return spec @property def is_rnn(self) -> bool: diff --git a/benchmarl/models/gru.py b/benchmarl/models/gru.py index 4051577e..dfe88262 100644 --- a/benchmarl/models/gru.py +++ b/benchmarl/models/gru.py @@ -13,14 +13,14 @@ from __future__ import annotations from dataclasses import dataclass, MISSING -from typing import Callable, Optional, Sequence, Type +from typing import Optional, Sequence, Type import torch from tensordict import TensorDictBase from tensordict.utils import expand_as_right from torch import nn from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec -from torchrl.envs import Compose, EnvBase, InitTracker, TensorDictPrimer, TransformedEnv + from torchrl.modules import GRUCell, MLP, MultiAgentMLP from benchmarl.models.common import Model, ModelConfig @@ -83,15 +83,20 @@ def __init__( if self.centralised: self.input_size = self.input_size * self.n_agents - self.gru = GRU( + self.base_gru = GRU( input_size, hidden_size, device=self.device, ) + if self.compile: - self.gru = torch.compile(self.gru, mode="reduce-overhead", fullgraph=True) + self.base_gru = torch.compile( + self.base_gru, mode="reduce-overhead", fullgraph=True + ) if not self.centralised: - self.gru = torch.vmap(self.gru, in_dims=-2, out_dims=-2) + self.gru = torch.vmap(self.base_gru, in_dims=-2, out_dims=-2) + else: + self.gru = self.base_gru def forward( self, @@ -99,6 +104,10 @@ def forward( is_init, h_0=None, ): + # Input and output always have the multiagent dimension + # Hidden state only has it when not centralised + # is_init never has it + assert is_init is not None, "We need to pass is_init" training = h_0 is None if ( @@ -112,13 +121,8 @@ def forward( assert input.shape == (batch, seq, self.n_agents, self.input_size) if h_0 is not None: # Collection - assert h_0.shape == ( - batch, - self.n_agents, - self.hidden_size, - ) - if is_init is not None: # Set hidden to 0 when is_init - h_0 = torch.where(expand_as_right(is_init, h_0), 0, h_0) + # Set hidden to 0 when is_init + h_0 = torch.where(expand_as_right(is_init, h_0), 0, h_0) if not training: # If in collection emulate the sequence dimension is_init = is_init.unsqueeze(1) @@ -126,25 +130,32 @@ def forward( is_init = is_init.unsqueeze(-2).expand(batch, seq, self.n_agents, 1) if h_0 is None: + if self.centralised: + shape = ( + batch, + self.hidden_size, + ) + else: + shape = ( + batch, + self.n_agents, + self.hidden_size, + ) h_0 = torch.zeros( - batch, - self.n_agents, - self.hidden_size, + shape, device=self.device, dtype=torch.float, ) if self.centralised: input = input.view(batch, seq, self.n_agents * self.input_size) - h_0 = h_0[..., 0, :] is_init = is_init.view(batch, seq, self.n_agents) - output, h_n = self.vmap_gru(input, is_init, h_0) + output, h_n = self.gru(input, is_init, h_0) if self.centralised: output = output.unsqueeze(-2).expand( batch, seq, self.n_agents, self.hidden_size ) - h_n = h_n.unsqueeze(-2).expand(batch, self.n_agents, self.hidden_size) if not training: output = output.squeeze(1) @@ -169,7 +180,14 @@ def __init__( share_params=kwargs.pop("share_params"), device=kwargs.pop("device"), action_spec=kwargs.pop("action_spec"), + model_index=kwargs.pop("model_index"), + is_critic=kwargs.pop("is_critic"), ) + self.hidden_state_name = (f"_hidden_gru_{self.model_index}",) + if not self.centralised: + self.hidden_state_name = (self.agent_group, *self.hidden_state_name) + self.rnn_keys = ["is_init", self.hidden_state_name] + self.in_keys += self.rnn_keys self.hidden_size = hidden_size self.compile = compile @@ -188,6 +206,17 @@ def __init__( centralised=self.centralised, compile=self.compile, ) + else: + self.gru = nn.ModuleList( + [ + GRU( + self.input_features, + self.hidden_size, + device=self.device, + ) + for _ in range(self.n_agents if not self.share_params else 1) + ] + ) mlp_net_kwargs = { "_".join(k.split("_")[1:]): v @@ -253,17 +282,67 @@ def _perform_checks(self): def _forward(self, tensordict: TensorDictBase) -> TensorDictBase: # Gather in_key - input = torch.cat([tensordict.get(in_key) for in_key in self.in_keys], dim=-1) - h_0 = tensordict.get((self.agent_group, "_hidden_gru"), None) + input = torch.cat( + [ + tensordict.get(in_key) + for in_key in self.in_keys + if in_key not in self.rnn_keys + ], + dim=-1, + ) + h_0 = tensordict.get(self.hidden_state_name, None) is_init = tensordict.get("is_init") + training = h_0 is None # Has multi-agent input dimension if self.input_has_agent_dim: output, h_n = self.gru(input, is_init, h_0) if not self.output_has_agent_dim: output = output[..., 0, :] - else: - pass + else: # Is a global input + if ( + not training + ): # In collection we emulate the sequence dimension and we have the hidden state + input = input.unsqueeze(1) + # Check input + batch = input.shape[0] + seq = input.shape[1] + assert input.shape == (batch, seq, self.input_size) + + if h_0 is not None: # Collection + # Set hidden to 0 when is_init + h_0 = torch.where(expand_as_right(is_init, h_0), 0, h_0) + + if not training: # If in collection emulate the sequence dimension + is_init = is_init.unsqueeze(1) + assert is_init.shape == (batch, seq, 1) + + if h_0 is None: + h_0 = torch.zeros( + ( + batch, + self.hidden_size, + ), + device=self.device, + dtype=torch.float, + ) + if self.share_params: + output, h_n = self.gru[0](input, is_init, h_0) + else: + outputs = [] + h_ns = [] + for i, net in enumerate(self.mlp): + if h_0 is not None: + h_i = h_0[..., i, :] + else: + h_i = None + output, h_n = net(input, is_init, h_i) + outputs.append(output) + h_ns.append(h_n) + output = torch.stack(outputs, dim=-2) + h_n = torch.stack(h_ns, dim=-2) + if not training: + output = output.squeeze(1) # Mlp if self.output_has_agent_dim: @@ -278,8 +357,8 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase: output = self.mlp[0](output) tensordict.set(self.out_key, output) - if h_0 is not None: - tensordict.set(("next", self.agent_group, "_hidden_gru"), h_n) + if not training: + tensordict.set(("next", *self.hidden_state_name), h_n) return tensordict @@ -306,42 +385,9 @@ def associated_class(): def is_rnn(self) -> bool: return True - def process_env_fun( - self, - env_fun: Callable[[], EnvBase], - task, - model_index: int = 0, - ) -> Callable[[], EnvBase]: - """ - This function can be used to wrap env_fun - Args: - env_fun (callable): a function that takes no args and creates an enviornment - - Returns: a function that takes no args and creates an enviornment - - """ - - def model_fun(): - env = env_fun() - env = TransformedEnv( - env, - Compose( - InitTracker(init_key="is_init"), - TensorDictPrimer( - { - group: CompositeSpec( - { - "_hidden_gru": UnboundedContinuousTensorSpec( - shape=(len(agents), self.hidden_size) - ) - }, - shape=(len(agents),), - ) - for group, agents in task.group_map(env).items() - } - ), - ), - ) - return env - - return model_fun + def get_model_state_spec(self, model_index: int = 0) -> CompositeSpec: + name = f"_hidden_gru_{model_index}" + spec = CompositeSpec( + {name: UnboundedContinuousTensorSpec(shape=(self.hidden_size,))} + ) + return spec diff --git a/benchmarl/models/mlp.py b/benchmarl/models/mlp.py index 66e2c3c4..c353be11 100644 --- a/benchmarl/models/mlp.py +++ b/benchmarl/models/mlp.py @@ -46,6 +46,8 @@ def __init__( share_params=kwargs.pop("share_params"), device=kwargs.pop("device"), action_spec=kwargs.pop("action_spec"), + model_index=kwargs.pop("model_index"), + is_critic=kwargs.pop("is_critic"), ) self.input_features = sum(