diff --git a/README.md b/README.md index c5ee12ce..a62baf55 100644 --- a/README.md +++ b/README.md @@ -263,16 +263,11 @@ agent group. Here is a table of the models implemented in BenchMARL | Name | Decentralized | Centralized with local inputs | Centralized with global input | |------------------------------------------|:-------------:|:-----------------------------:|:-----------------------------:| | [MLP](benchmarl/models/mlp.py) | Yes | Yes | Yes | +| [GRU](benchmarl/models/gru.py) | Yes | Yes | Yes | | [GNN](benchmarl/models/gnn.py) | Yes | Yes | No | | [CNN](benchmarl/models/cnn.py) | Yes | Yes | Yes | | [Deepsets](benchmarl/models/deepsets.py) | Yes | Yes | Yes | -And the ones that are _work in progress_ - -| Name | Decentralized | Centralized with local inputs | Centralized with global input | -|--------------------|:-------------:|:-----------------------------:|:-----------------------------:| -| RNN (GRU and LSTM) | Yes | Yes | Yes | - ## Fine-tuned public benchmarks > [!WARNING] diff --git a/benchmarl/algorithms/common.py b/benchmarl/algorithms/common.py index e742d808..f76e57ec 100644 --- a/benchmarl/algorithms/common.py +++ b/benchmarl/algorithms/common.py @@ -7,11 +7,12 @@ import pathlib from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Iterable, List, Optional, Tuple, Type +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type from tensordict import TensorDictBase from tensordict.nn import TensorDictModule, TensorDictSequential from torchrl.data import ( + CompositeSpec, DiscreteTensorSpec, LazyTensorStorage, OneHotDiscreteTensorSpec, @@ -19,7 +20,14 @@ TensorDictReplayBuffer, ) from torchrl.data.replay_buffers import RandomSampler, SamplerWithoutReplacement -from torchrl.envs import Compose, Transform +from torchrl.envs import ( + Compose, + EnvBase, + InitTracker, + TensorDictPrimer, + Transform, + TransformedEnv, +) from torchrl.objectives import LossModule from torchrl.objectives.utils import HardUpdate, SoftUpdate, TargetNetUpdater @@ -51,6 +59,16 @@ def __init__(self, experiment): self.action_spec = experiment.action_spec self.state_spec = experiment.state_spec self.action_mask_spec = experiment.action_mask_spec + self.has_independent_critic = ( + experiment.algorithm_config.has_independent_critic() + ) + self.has_centralized_critic = ( + experiment.algorithm_config.has_centralized_critic() + ) + self.has_critic = experiment.algorithm_config.has_critic + self.has_rnn = self.model_config.is_rnn or ( + self.critic_model_config.is_rnn and self.has_critic + ) # Cached values that will be instantiated only once and then remain fixed self._losses_and_updaters = {} @@ -142,6 +160,14 @@ def get_replay_buffer( """ memory_size = self.experiment_config.replay_buffer_memory_size(self.on_policy) sampling_size = self.experiment_config.train_minibatch_size(self.on_policy) + if self.has_rnn: + sequence_length = -( + -self.experiment_config.collected_frames_per_batch(self.on_policy) + // self.experiment_config.n_envs_per_worker(self.on_policy) + ) + memory_size = -(-memory_size // sequence_length) + sampling_size = -(-sampling_size // sequence_length) + sampler = SamplerWithoutReplacement() if self.on_policy else RandomSampler() return TensorDictReplayBuffer( storage=LazyTensorStorage( @@ -218,6 +244,54 @@ def get_parameters(self, group: str) -> Dict[str, Iterable]: loss=self.get_loss_and_updater(group)[0], ) + def process_env_fun( + self, + env_fun: Callable[[], EnvBase], + ) -> Callable[[], EnvBase]: + """ + This function can be used to wrap env_fun + + Args: + env_fun (callable): a function that takes no args and creates an enviornment + + Returns: a function that takes no args and creates an enviornment + + """ + if self.has_rnn: + + def model_fun(): + env = env_fun() + + spec_actor = self.model_config.get_model_state_spec() + spec_actor = CompositeSpec( + { + group: CompositeSpec( + spec_actor.expand(len(agents), *spec_actor.shape), + shape=(len(agents),), + ) + for group, agents in self.group_map.items() + } + ) + + env = TransformedEnv( + env, + Compose( + *( + [InitTracker(init_key="is_init")] + + ( + [TensorDictPrimer(spec_actor, reset_key="_reset")] + if len(spec_actor.keys(True, True)) > 0 + else [] + ) + ) + ), + ) + return env + + return model_fun + + return env_fun + ############################### # Abstract methods to implement ############################### @@ -399,3 +473,27 @@ def supports_discrete_actions() -> bool: If the algorithm supports discrete actions """ raise NotImplementedError + + @staticmethod + def has_independent_critic() -> bool: + """ + If the algorithm uses an independent critic + """ + return False + + @staticmethod + def has_centralized_critic() -> bool: + """ + If the algorithm uses a centralized critic + """ + return False + + def has_critic(self) -> bool: + """ + If the algorithm uses a critic + """ + if self.has_centralized_critic() and self.has_independent_critic(): + raise ValueError( + "Algorithm can either have a centralized critic or an indpendent one" + ) + return self.has_centralized_critic() or self.has_independent_critic() diff --git a/benchmarl/algorithms/iddpg.py b/benchmarl/algorithms/iddpg.py index 334aa526..a3b5c161 100644 --- a/benchmarl/algorithms/iddpg.py +++ b/benchmarl/algorithms/iddpg.py @@ -251,3 +251,7 @@ def supports_discrete_actions() -> bool: @staticmethod def on_policy() -> bool: return False + + @staticmethod + def has_independent_critic() -> bool: + return True diff --git a/benchmarl/algorithms/ippo.py b/benchmarl/algorithms/ippo.py index 012c6880..8d7ef64a 100644 --- a/benchmarl/algorithms/ippo.py +++ b/benchmarl/algorithms/ippo.py @@ -325,3 +325,7 @@ def supports_discrete_actions() -> bool: @staticmethod def on_policy() -> bool: return True + + @staticmethod + def has_independent_critic() -> bool: + return True diff --git a/benchmarl/algorithms/isac.py b/benchmarl/algorithms/isac.py index 03024c68..5407cc88 100644 --- a/benchmarl/algorithms/isac.py +++ b/benchmarl/algorithms/isac.py @@ -389,3 +389,7 @@ def supports_discrete_actions() -> bool: @staticmethod def on_policy() -> bool: return False + + @staticmethod + def has_independent_critic() -> bool: + return True diff --git a/benchmarl/algorithms/maddpg.py b/benchmarl/algorithms/maddpg.py index c3ad1069..b963bfb7 100644 --- a/benchmarl/algorithms/maddpg.py +++ b/benchmarl/algorithms/maddpg.py @@ -301,3 +301,7 @@ def supports_discrete_actions() -> bool: @staticmethod def on_policy() -> bool: return False + + @staticmethod + def has_centralized_critic() -> bool: + return True diff --git a/benchmarl/algorithms/mappo.py b/benchmarl/algorithms/mappo.py index 3ddd8d53..9267fea4 100644 --- a/benchmarl/algorithms/mappo.py +++ b/benchmarl/algorithms/mappo.py @@ -361,3 +361,7 @@ def supports_discrete_actions() -> bool: @staticmethod def on_policy() -> bool: return True + + @staticmethod + def has_centralized_critic() -> bool: + return True diff --git a/benchmarl/algorithms/masac.py b/benchmarl/algorithms/masac.py index 1991403e..29080f2d 100644 --- a/benchmarl/algorithms/masac.py +++ b/benchmarl/algorithms/masac.py @@ -463,3 +463,7 @@ def supports_discrete_actions() -> bool: @staticmethod def on_policy() -> bool: return False + + @staticmethod + def has_centralized_critic() -> bool: + return True diff --git a/benchmarl/conf/model/layers/gru.yaml b/benchmarl/conf/model/layers/gru.yaml new file mode 100644 index 00000000..b882b159 --- /dev/null +++ b/benchmarl/conf/model/layers/gru.yaml @@ -0,0 +1,15 @@ + +name: gru + +hidden_size: 128 +n_layers: 1 +bias: True +dropout: 0 +compile: False + +mlp_num_cells: [256, 256] +mlp_layer_class: torch.nn.Linear +mlp_activation_class: torch.nn.Tanh +mlp_activation_kwargs: null +mlp_norm_class: null +mlp_norm_kwargs: null diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index b8a294c8..12fa8859 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -26,6 +26,8 @@ from torchrl.record.loggers import generate_exp_name from tqdm import tqdm +from benchmarl.algorithms import IsacConfig, MasacConfig + from benchmarl.algorithms.common import AlgorithmConfig from benchmarl.environments import Task from benchmarl.experiment.callback import Callback, CallbackNotifier @@ -322,8 +324,12 @@ def __init__( self.task = task self.model_config = model_config self.critic_model_config = ( - critic_model_config if critic_model_config is not None else model_config + critic_model_config + if critic_model_config is not None + else copy.deepcopy(model_config) ) + self.critic_model_config.is_critic = True + self.algorithm_config = algorithm_config self.seed = seed @@ -345,6 +351,7 @@ def on_policy(self) -> bool: def _setup(self): self.config.validate(self.on_policy) seed_everything(self.seed) + self._perfrom_checks() self._set_action_type() self._setup_task() self._setup_algorithm() @@ -353,6 +360,15 @@ def _setup(self): self._setup_logger() self._on_setup() + def _perfrom_checks(self): + if isinstance(self.algorithm_config, (MasacConfig, IsacConfig)) and ( + self.model_config.is_rnn or self.critic_model_config.is_rnn + ): + raise ValueError( + "SAC based losses not compatible with RNNs due to https://github.com/pytorch/rl/issues/2338." + " Please leave a comment on the issue if you would like this feature." + ) + def _set_action_type(self): if ( self.task.supports_continuous_actions() @@ -377,21 +393,17 @@ def _set_action_type(self): ) def _setup_task(self): - test_env = self.model_config.process_env_fun( - self.task.get_env_fun( - num_envs=self.config.evaluation_episodes, - continuous_actions=self.continuous_actions, - seed=self.seed, - device=self.config.sampling_device, - ) + test_env = self.task.get_env_fun( + num_envs=self.config.evaluation_episodes, + continuous_actions=self.continuous_actions, + seed=self.seed, + device=self.config.sampling_device, )() - env_func = self.model_config.process_env_fun( - self.task.get_env_fun( - num_envs=self.config.n_envs_per_worker(self.on_policy), - continuous_actions=self.continuous_actions, - seed=self.seed, - device=self.config.sampling_device, - ) + env_func = self.task.get_env_fun( + num_envs=self.config.n_envs_per_worker(self.on_policy), + continuous_actions=self.continuous_actions, + seed=self.seed, + device=self.config.sampling_device, ) transforms_env = self.task.get_env_transforms(test_env) @@ -427,6 +439,10 @@ def _setup_task(self): def _setup_algorithm(self): self.algorithm = self.algorithm_config.get_algorithm(experiment=self) + + self.test_env = self.algorithm.process_env_fun(lambda: self.test_env)() + self.env_func = self.algorithm.process_env_fun(self.env_func) + self.replay_buffers = { group: self.algorithm.get_replay_buffer( group=group, @@ -610,7 +626,8 @@ def _collection_loop(self): for group in self.train_group_map.keys(): group_batch = batch.exclude(*self._get_excluded_keys(group)) group_batch = self.algorithm.process_batch(group, group_batch) - group_batch = group_batch.reshape(-1) + if not self.algorithm.has_rnn: + group_batch = group_batch.reshape(-1) self.replay_buffers[group].extend(group_batch) training_tds = [] diff --git a/benchmarl/models/__init__.py b/benchmarl/models/__init__.py index 8bd743be..554ee985 100644 --- a/benchmarl/models/__init__.py +++ b/benchmarl/models/__init__.py @@ -8,6 +8,7 @@ from .common import Model, ModelConfig, SequenceModel, SequenceModelConfig from .deepsets import Deepsets, DeepsetsConfig from .gnn import Gnn, GnnConfig +from .gru import Gru, GruConfig from .mlp import Mlp, MlpConfig classes = [ @@ -19,6 +20,8 @@ "CnnConfig", "Deepsets", "DeepsetsConfig", + "Gru", + "GruConfig", ] model_config_registry = { @@ -26,4 +29,5 @@ "gnn": GnnConfig, "cnn": CnnConfig, "deepsets": DeepsetsConfig, + "gru": GruConfig, } diff --git a/benchmarl/models/cnn.py b/benchmarl/models/cnn.py index c3e7611d..95a3c426 100644 --- a/benchmarl/models/cnn.py +++ b/benchmarl/models/cnn.py @@ -114,6 +114,8 @@ def __init__( share_params=kwargs.pop("share_params"), device=kwargs.pop("device"), action_spec=kwargs.pop("action_spec"), + model_index=kwargs.pop("model_index"), + is_critic=kwargs.pop("is_critic"), ) self.x = self.input_spec[self.image_in_keys[0]].shape[-3] diff --git a/benchmarl/models/common.py b/benchmarl/models/common.py index 85e07ee6..f788a52a 100644 --- a/benchmarl/models/common.py +++ b/benchmarl/models/common.py @@ -8,13 +8,12 @@ import warnings from abc import ABC, abstractmethod from dataclasses import asdict, dataclass -from typing import Any, Callable, Dict, List, Optional, Sequence +from typing import Any, Dict, List, Optional, Sequence from tensordict import TensorDictBase from tensordict.nn import TensorDictModuleBase, TensorDictSequential from tensordict.utils import NestedKey from torchrl.data import CompositeSpec, TensorSpec, UnboundedContinuousTensorSpec -from torchrl.envs import EnvBase from benchmarl.utils import _class_from_name, _read_yaml_config, DEVICE_TYPING @@ -75,6 +74,8 @@ class Model(TensorDictModuleBase, ABC): This is independent of the other options as it is possible to have different parameters for centralized critics with global input. action_spec (CompositeSpec): The action spec of the environment + model_index (int): the index of the model in a sequence + is_critic (bool): Whether the model is a critic """ def __init__( @@ -88,6 +89,8 @@ def __init__( share_params: bool, device: DEVICE_TYPING, action_spec: CompositeSpec, + model_index: int, + is_critic: bool, ): TensorDictModuleBase.__init__(self) @@ -100,6 +103,8 @@ def __init__( self.device = device self.n_agents = n_agents self.action_spec = action_spec + self.model_index = model_index + self.is_critic = is_critic self.in_keys = list(self.input_spec.keys(True, True)) self.out_keys = list(self.output_spec.keys(True, True)) @@ -220,8 +225,12 @@ def __init__( agent_group=models[0].agent_group, input_has_agent_dim=models[0].input_has_agent_dim, action_spec=models[0].action_spec, + model_index=models[0].model_index, + is_critic=models[0].is_critic, ) self.models = TensorDictSequential(*models) + self.in_keys = self.models.in_keys + self.out_keys = self.models.out_keys def _forward(self, tensordict: TensorDictBase) -> TensorDictBase: return self.models(tensordict) @@ -250,6 +259,7 @@ def get_model( share_params: bool, device: DEVICE_TYPING, action_spec: CompositeSpec, + model_index: int = 0, ) -> Model: """ Creates the model from the config. @@ -273,6 +283,7 @@ def get_model( This is independent of the other options as it is possible to have different parameters for centralized critics with global input. action_spec (CompositeSpec): The action spec of the environment + model_index (int): the index of the model in a sequence. Defaults to 0. Returns: the Model @@ -288,6 +299,8 @@ def get_model( share_params=share_params, device=device, action_spec=action_spec, + model_index=model_index, + is_critic=self.is_critic, ) @staticmethod @@ -298,16 +311,43 @@ def associated_class(): """ raise NotImplementedError - def process_env_fun(self, env_fun: Callable[[], EnvBase]) -> Callable[[], EnvBase]: + @property + def is_rnn(self) -> bool: """ - This function can be used to wrap env_fun - Args: - env_fun (callable): a function that takes no args and creates an enviornment + Whether the model is an RNN + """ + return False - Returns: a function that takes no args and creates an enviornment + @property + def is_critic(self): + """ + Whether the model is a critic + """ + if not hasattr(self, "_is_critic"): + self._is_critic = False + return self._is_critic + @is_critic.setter + def is_critic(self, value): + """ + Set whether the model is a critic """ - return env_fun + self._is_critic = value + + def get_model_state_spec(self, model_index: int = 0) -> CompositeSpec: + """Get additional specs needed by the model as input. + + This method is useful for adding recurrent states. + + The returned value should be key: spec with the desired ending shape. + + The batch and agent dimensions will automatically be added to the spec. + + Args: + model_index (int, optional): the index of the model. Defaults to 0. + + """ + return CompositeSpec() @staticmethod def _load_from_yaml(name: str) -> Dict[str, Any]: @@ -392,6 +432,7 @@ def get_model( share_params: bool, device: DEVICE_TYPING, action_spec: CompositeSpec, + model_index: int = 0, ) -> Model: n_models = len(self.model_configs) if not n_models > 0: @@ -408,7 +449,7 @@ def get_model( intermediate_specs = [ CompositeSpec( { - f"_{agent_group}_intermediate_{i}": UnboundedContinuousTensorSpec( + f"_{agent_group}{'_critic' if self.is_critic else ''}_intermediate_{i}": UnboundedContinuousTensorSpec( shape=(n_agents, size) if out_has_agent_dim else (size,) ) } @@ -427,6 +468,7 @@ def get_model( share_params=share_params, device=device, action_spec=action_spec, + model_index=0, ) ] @@ -441,6 +483,7 @@ def get_model( share_params=share_params, device=device, action_spec=action_spec, + model_index=i, ) for i in range(1, n_models) ] @@ -451,10 +494,30 @@ def get_model( def associated_class(): return SequenceModel - def process_env_fun(self, env_fun: Callable[[], EnvBase]) -> Callable[[], EnvBase]: + @property + def is_critic(self): + if not hasattr(self, "_is_critic"): + self._is_critic = False + return self._is_critic + + @is_critic.setter + def is_critic(self, value): + self._is_critic = value + for model_config in self.model_configs: + model_config.is_critic = value + + def get_model_state_spec(self, model_index: int = 0) -> CompositeSpec: + spec = CompositeSpec() + for i, model_config in enumerate(self.model_configs): + spec.update(model_config.get_model_state_spec(model_index=i)) + return spec + + @property + def is_rnn(self) -> bool: + is_rnn = False for model_config in self.model_configs: - env_fun = model_config.process_env_fun(env_fun) - return env_fun + is_rnn += model_config.is_rnn + return is_rnn @classmethod def get_from_yaml(cls, path: Optional[str] = None): diff --git a/benchmarl/models/gru.py b/benchmarl/models/gru.py new file mode 100644 index 00000000..d4ebe062 --- /dev/null +++ b/benchmarl/models/gru.py @@ -0,0 +1,527 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +from __future__ import annotations + +from dataclasses import dataclass, MISSING +from typing import Optional, Sequence, Type + +import torch +import torch.nn.functional as F +from tensordict import TensorDict, TensorDictBase +from tensordict.utils import expand_as_right, unravel_key_list +from torch import nn +from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec + +from torchrl.modules import GRUCell, MLP, MultiAgentMLP + +from benchmarl.models.common import Model, ModelConfig +from benchmarl.utils import DEVICE_TYPING + + +class GRU(torch.nn.Module): + def __init__( + self, + input_size: int, + hidden_size: int, + device: DEVICE_TYPING, + n_layers: int, + dropout: float, + bias: bool, + time_dim: int = -2, + ): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.device = device + self.time_dim = time_dim + self.n_layers = n_layers + self.dropout = dropout + self.bias = bias + + self.grus = torch.nn.ModuleList( + [ + GRUCell( + input_size if i == 0 else hidden_size, + hidden_size, + device=self.device, + bias=self.bias, + ) + for i in range(self.n_layers) + ] + ) + + def forward( + self, + input, + is_init, + h, + ): + hs = [] + h = list(h.unbind(dim=-2)) + for in_t, init_t in zip( + input.unbind(self.time_dim), is_init.unbind(self.time_dim) + ): + for layer in range(self.n_layers): + h[layer] = torch.where(init_t, 0, h[layer]) + + h[layer] = self.grus[layer](in_t, h[layer]) + + if layer < self.n_layers - 1 and self.dropout: + in_t = F.dropout(h[layer], p=self.dropout, training=self.training) + else: + in_t = h[layer] + + hs.append(in_t) + h_n = torch.stack(h, dim=-2) + output = torch.stack(hs, self.time_dim) + + return output, h_n + + +def get_net(input_size, hidden_size, n_layers, bias, device, dropout, compile): + gru = GRU( + input_size, + hidden_size, + n_layers=n_layers, + bias=bias, + device=device, + dropout=dropout, + ) + if compile: + gru = torch.compile(gru, mode="reduce-overhead") + return gru + + +class MultiAgentGRU(torch.nn.Module): + def __init__( + self, + input_size: int, + hidden_size: int, + n_agents: int, + device: DEVICE_TYPING, + centralised: bool, + share_params: bool, + n_layers: int, + dropout: float, + bias: bool, + compile: bool, + ): + super().__init__() + self.input_size = input_size + self.n_agents = n_agents + self.hidden_size = hidden_size + self.device = device + self.centralised = centralised + self.share_params = share_params + self.n_layers = n_layers + self.bias = bias + self.dropout = dropout + self.compile = compile + + if self.centralised: + input_size = input_size * self.n_agents + + agent_networks = [ + get_net( + input_size=input_size, + hidden_size=self.hidden_size, + n_layers=self.n_layers, + bias=self.bias, + device=self.device, + dropout=self.dropout, + compile=self.compile, + ) + for _ in range(self.n_agents if not self.share_params else 1) + ] + self._make_params(agent_networks) + + with torch.device("meta"): + self._empty_gru = get_net( + input_size=input_size, + hidden_size=self.hidden_size, + n_layers=self.n_layers, + bias=self.bias, + device="meta", + dropout=self.dropout, + compile=self.compile, + ) + # Remove all parameters + TensorDict.from_module(self._empty_gru).data.to("meta").to_module( + self._empty_gru + ) + + def forward( + self, + input, + is_init, + h_0=None, + ): + # Input and output always have the multiagent dimension + # Hidden state only has it when not centralised + # is_init never has it + + assert is_init is not None, "We need to pass is_init" + training = h_0 is None + + missing_batch = False + if ( + not training and len(input.shape) < 3 + ): # In evaluation the batch might be missing + missing_batch = True + input = input.unsqueeze(0) + h_0 = h_0.unsqueeze(0) + is_init = is_init.unsqueeze(0) + + if ( + not training + ): # In collection we emulate the sequence dimension and we have the hidden state + input = input.unsqueeze(1) + + # Check input + batch = input.shape[0] + seq = input.shape[1] + assert input.shape == (batch, seq, self.n_agents, self.input_size) + + if h_0 is not None: # Collection + # Set hidden to 0 when is_init + h_0 = torch.where(expand_as_right(is_init, h_0), 0, h_0) + + if not training: # If in collection emulate the sequence dimension + is_init = is_init.unsqueeze(1) + assert is_init.shape == (batch, seq, 1) + is_init = is_init.unsqueeze(-2).expand(batch, seq, self.n_agents, 1) + + if h_0 is None: + if self.centralised: + shape = ( + batch, + self.n_layers, + self.hidden_size, + ) + else: + shape = ( + batch, + self.n_agents, + self.n_layers, + self.hidden_size, + ) + h_0 = torch.zeros( + shape, + device=self.device, + dtype=torch.float, + ) + if self.centralised: + input = input.view(batch, seq, self.n_agents * self.input_size) + is_init = is_init[..., 0, :] + + output, h_n = self.run_net(input, is_init, h_0) + + if self.centralised and self.share_params: + output = output.unsqueeze(-2).expand( + batch, seq, self.n_agents, self.hidden_size + ) + + if not training: + output = output.squeeze(1) + if missing_batch: + output = output.squeeze(0) + h_n = h_n.squeeze(0) + return output, h_n + + def run_net(self, input, is_init, h_0): + if not self.share_params: + if self.centralised: + output, h_n = self.vmap_func_module( + self._empty_gru, + (0, None, None, None), + (-2, -2), + )(self.params, input, is_init, h_0) + else: + output, h_n = self.vmap_func_module( + self._empty_gru, + (0, -2, -2, -3), + (-2, -3), + )(self.params, input, is_init, h_0) + else: + with self.params.to_module(self._empty_gru): + if self.centralised: + output, h_n = self._empty_gru(input, is_init, h_0) + else: + output, h_n = torch.vmap( + self._empty_gru, in_dims=(-2, -2, -3), out_dims=(-2, -3) + )(input, is_init, h_0) + + return output, h_n + + def vmap_func_module(self, module, *args, **kwargs): + def exec_module(params, *input): + with params.to_module(module): + return module(*input) + + return torch.vmap(exec_module, *args, **kwargs) + + def _make_params(self, agent_networks): + if self.share_params: + self.params = TensorDict.from_module(agent_networks[0], as_module=True) + else: + self.params = TensorDict.from_modules(*agent_networks, as_module=True) + + +class Gru(Model): + r"""A multi-layer Gated Recurrent Unit (GRU) RNN like the one from + `torch `__ . + + The BenchMARL GRU accepts multiple inputs of type array: Tensors of shape ``(*batch,F)`` + + Where `F` is the number of features. + The features `F` will be processed to features of `hidden_size` by the GRU. + + Args: + hidden_size (int): The number of features in the hidden state. + num_layers (int): Number of recurrent layers. E.g., setting ``num_layers=2`` + would mean stacking two GRUs together to form a `stacked GRU`, + with the second GRU taking in outputs of the first GRU and + computing the final results. Default: 1 + bias (bool): If ``False``, then the GRU layers do not use bias. + Default: ``True`` + dropout (float): If non-zero, introduces a `Dropout` layer on the outputs of each + GRU layer except the last layer, with dropout probability equal to + :attr:`dropout`. Default: 0 + compile (bool): If ``True``, compiles underlying gru model. Default: ``False`` + + """ + + def __init__( + self, + hidden_size: int, + n_layers: int, + bias: bool, + dropout: float, + compile: bool, + **kwargs, + ): + + super().__init__( + input_spec=kwargs.pop("input_spec"), + output_spec=kwargs.pop("output_spec"), + agent_group=kwargs.pop("agent_group"), + input_has_agent_dim=kwargs.pop("input_has_agent_dim"), + n_agents=kwargs.pop("n_agents"), + centralised=kwargs.pop("centralised"), + share_params=kwargs.pop("share_params"), + device=kwargs.pop("device"), + action_spec=kwargs.pop("action_spec"), + model_index=kwargs.pop("model_index"), + is_critic=kwargs.pop("is_critic"), + ) + + self.hidden_state_name = (self.agent_group, f"_hidden_gru_{self.model_index}") + self.rnn_keys = unravel_key_list(["is_init", self.hidden_state_name]) + self.in_keys += self.rnn_keys + + self.hidden_size = hidden_size + self.n_layers = n_layers + self.bias = bias + self.dropout = dropout + self.compile = compile + + self.input_features = sum( + [spec.shape[-1] for spec in self.input_spec.values(True, True)] + ) + self.output_features = self.output_leaf_spec.shape[-1] + + if self.input_has_agent_dim: + self.gru = MultiAgentGRU( + self.input_features, + self.hidden_size, + self.n_agents, + self.device, + bias=self.bias, + n_layers=self.n_layers, + centralised=self.centralised, + share_params=self.share_params, + dropout=self.dropout, + compile=self.compile, + ) + else: + self.gru = nn.ModuleList( + [ + get_net( + input_size=self.input_features, + hidden_size=self.hidden_size, + n_layers=self.n_layers, + bias=self.bias, + device=self.device, + dropout=self.dropout, + compile=self.compile, + ) + for _ in range(self.n_agents if not self.share_params else 1) + ] + ) + + mlp_net_kwargs = { + "_".join(k.split("_")[1:]): v + for k, v in kwargs.items() + if k.startswith("mlp_") + } + if self.output_has_agent_dim: + self.mlp = MultiAgentMLP( + n_agent_inputs=self.hidden_size, + n_agent_outputs=self.output_features, + n_agents=self.n_agents, + centralised=self.centralised, + share_params=self.share_params, + device=self.device, + **mlp_net_kwargs, + ) + else: + self.mlp = nn.ModuleList( + [ + MLP( + in_features=self.hidden_size, + out_features=self.output_features, + device=self.device, + **mlp_net_kwargs, + ) + for _ in range(self.n_agents if not self.share_params else 1) + ] + ) + + def _perform_checks(self): + super()._perform_checks() + + input_shape = None + for input_key, input_spec in self.input_spec.items(True, True): + if (self.input_has_agent_dim and len(input_spec.shape) == 2) or ( + not self.input_has_agent_dim and len(input_spec.shape) == 1 + ): + if input_shape is None: + input_shape = input_spec.shape[:-1] + else: + if input_spec.shape[:-1] != input_shape: + raise ValueError( + f"GRU inputs should all have the same shape up to the last dimension, got {self.input_spec}" + ) + else: + raise ValueError( + f"GRU input value {input_key} from {self.input_spec} has an invalid shape, maybe you need a CNN?" + ) + if self.input_has_agent_dim: + if input_shape[-1] != self.n_agents: + raise ValueError( + "If the GRU input has the agent dimension," + f" the second to last spec dimension should be the number of agents, got {self.input_spec}" + ) + if ( + self.output_has_agent_dim + and self.output_leaf_spec.shape[-2] != self.n_agents + ): + raise ValueError( + "If the GRU output has the agent dimension," + " the second to last spec dimension should be the number of agents" + ) + + def _forward(self, tensordict: TensorDictBase) -> TensorDictBase: + # Gather in_key + input = torch.cat( + [ + tensordict.get(in_key) + for in_key in self.in_keys + if in_key not in self.rnn_keys + ], + dim=-1, + ) + h_0 = tensordict.get(self.hidden_state_name, None) + is_init = tensordict.get("is_init") + training = h_0 is None + + # Has multi-agent input dimension + if self.input_has_agent_dim: + output, h_n = self.gru(input, is_init, h_0) + if not self.output_has_agent_dim: + output = output[..., 0, :] + else: # Is a global input, this is a critic + # Check input + batch = input.shape[0] + seq = input.shape[1] + assert input.shape == (batch, seq, self.input_features) + assert is_init.shape == (batch, seq, 1) + + h_0 = torch.zeros( + (batch, self.n_layers, self.hidden_size), + device=self.device, + dtype=torch.float, + ) + if self.share_params: + output, _ = self.gru[0](input, is_init, h_0) + else: + outputs = [] + for net in self.gru: + output, _ = net(input, is_init, h_0) + outputs.append(output) + output = torch.stack(outputs, dim=-2) + + # Mlp + if self.output_has_agent_dim: + output = self.mlp.forward(output) + else: + if not self.share_params: + output = torch.stack( + [net(output) for net in self.mlp], + dim=-2, + ) + else: + output = self.mlp[0](output) + + tensordict.set(self.out_key, output) + if not training: + tensordict.set(("next", *self.hidden_state_name), h_n) + return tensordict + + +@dataclass +class GruConfig(ModelConfig): + """Dataclass config for a :class:`~benchmarl.models.Gru`.""" + + hidden_size: int = MISSING + n_layers: int = MISSING + bias: bool = MISSING + dropout: float = MISSING + compile: bool = MISSING + + mlp_num_cells: Sequence[int] = MISSING + mlp_layer_class: Type[nn.Module] = MISSING + mlp_activation_class: Type[nn.Module] = MISSING + + mlp_activation_kwargs: Optional[dict] = None + mlp_norm_class: Type[nn.Module] = None + mlp_norm_kwargs: Optional[dict] = None + + @staticmethod + def associated_class(): + return Gru + + @property + def is_rnn(self) -> bool: + return True + + def get_model_state_spec(self, model_index: int = 0) -> CompositeSpec: + name = f"_hidden_gru_{model_index}" + spec = CompositeSpec( + { + name: UnboundedContinuousTensorSpec( + shape=(self.n_layers, self.hidden_size) + ) + } + ) + return spec diff --git a/benchmarl/models/mlp.py b/benchmarl/models/mlp.py index ea810b00..c353be11 100644 --- a/benchmarl/models/mlp.py +++ b/benchmarl/models/mlp.py @@ -46,6 +46,8 @@ def __init__( share_params=kwargs.pop("share_params"), device=kwargs.pop("device"), action_spec=kwargs.pop("action_spec"), + model_index=kwargs.pop("model_index"), + is_critic=kwargs.pop("is_critic"), ) self.input_features = sum( @@ -99,7 +101,7 @@ def _perform_checks(self): if input_shape[-1] != self.n_agents: raise ValueError( "If the MLP input has the agent dimension," - " the second to last spec dimension should be the number of agents, got {self.input_spec}" + f" the second to last spec dimension should be the number of agents, got {self.input_spec}" ) if ( self.output_has_agent_dim diff --git a/docs/source/concepts/components.rst b/docs/source/concepts/components.rst index f8e22871..4bed457a 100644 --- a/docs/source/concepts/components.rst +++ b/docs/source/concepts/components.rst @@ -112,6 +112,8 @@ agent group. Here is a table of the models implemented in BenchMARL +=====================================+===============+===============================+===============================+ | :class:`~benchmarl.models.Mlp` | Yes | Yes | Yes | +-------------------------------------+---------------+-------------------------------+-------------------------------+ + | :class:`~benchmarl.models.Gru` | Yes | Yes | Yes | + +-------------------------------------+---------------+-------------------------------+-------------------------------+ | :class:`~benchmarl.models.Gnn` | Yes | Yes | No | +-------------------------------------+---------------+-------------------------------+-------------------------------+ | :class:`~benchmarl.models.Cnn` | Yes | Yes | Yes | diff --git a/examples/extending/model/models/custommodel.py b/examples/extending/model/models/custommodel.py index 7dea32cb..72d54ac6 100644 --- a/examples/extending/model/models/custommodel.py +++ b/examples/extending/model/models/custommodel.py @@ -176,6 +176,7 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase: @dataclass class CustomModelConfig(ModelConfig): + # The config parameters for this class, these will be loaded from yaml custom_param: int = MISSING activation_class: Type[nn.Module] = MISSING @@ -184,3 +185,10 @@ class CustomModelConfig(ModelConfig): def associated_class(): # The associated algorithm class return CustomModel + + @property + def is_rnn(self) -> bool: + """ + Whether the model is an RNN + """ + return False diff --git a/test/conftest.py b/test/conftest.py index 4bdd195f..2df7e679 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -8,7 +8,7 @@ import torch_geometric.nn.conv from benchmarl.experiment import ExperimentConfig -from benchmarl.models import CnnConfig, GnnConfig, MlpConfig +from benchmarl.models import CnnConfig, GnnConfig, GruConfig, MlpConfig from benchmarl.models.common import ModelConfig, SequenceModelConfig from torch import nn @@ -88,3 +88,23 @@ def mlp_gnn_sequence_config() -> ModelConfig: ], intermediate_sizes=[5, 3], ) + + +@pytest.fixture +def gru_mlp_sequence_config() -> ModelConfig: + return SequenceModelConfig( + model_configs=[ + GruConfig( + hidden_size=13, + mlp_num_cells=[], + mlp_activation_class=nn.Tanh, + mlp_layer_class=nn.Linear, + n_layers=1, + bias=True, + dropout=0, + compile=False, + ), + MlpConfig(num_cells=[4], activation_class=nn.Tanh, layer_class=nn.Linear), + ], + intermediate_sizes=[5], + ) diff --git a/test/test_models.py b/test/test_models.py index 4af11458..c27495ec 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -6,9 +6,11 @@ import contextlib from typing import List +import packaging import pytest import torch import torch_geometric.nn +import torchrl from benchmarl.hydra_config import load_model_config_from_hydra from benchmarl.models import GnnConfig, model_config_registry @@ -143,6 +145,9 @@ def test_loading_sequence_models(model_name, intermediate_size=10): ["cnn", "gnn", "mlp"], ["cnn", "mlp", "gnn"], ["cnn", "mlp"], + ["cnn", "gru", "gnn", "mlp"], + ["cnn", "gru", "mlp"], + ["gru", "mlp"], ], ) def test_models_forward_shape( @@ -155,6 +160,11 @@ def test_models_forward_shape( or (isinstance(model_name, list) and model_name[0] != "gnn") ): pytest.skip("gnn model needs agent dim as input") + if ( + packaging.version.parse(torchrl.__version__).local is None + and "gru" in model_name + ): + pytest.skip("gru model needs torchrl from github") torch.manual_seed(0) @@ -176,6 +186,8 @@ def test_models_forward_shape( n_agents=n_agents, ) + if centralised: + config.is_critic = True model = config.get_model( input_spec=input_spec, output_spec=output_spec, @@ -187,8 +199,23 @@ def test_models_forward_shape( agent_group="agents", action_spec=None, ) - input_td = input_spec.expand(batch_size).rand() - out_td = model(input_td) + input_td = input_spec.rand() + if "gru" in model_name: + if len(batch_size) < 2: + if centralised: + pytest.skip("gru model with this batch sizes is a policy") + hidden_spec = config.get_model_state_spec() + hidden_spec = CompositeSpec( + { + "agents": CompositeSpec( + hidden_spec.expand(n_agents, *hidden_spec.shape), + shape=(n_agents,), + ) + } + ) + input_td.update(hidden_spec.rand()) + input_td["is_init"] = torch.randint(0, 2, (1,), dtype=torch.bool) + out_td = model(input_td.expand(batch_size)) assert output_spec.expand(batch_size).is_in(out_td) @@ -202,6 +229,9 @@ def test_models_forward_shape( ["cnn", "gnn", "mlp"], ["cnn", "mlp", "gnn"], ["cnn", "mlp"], + ["cnn", "gru", "gnn", "mlp"], + ["cnn", "gru", "mlp"], + ["gru", "mlp"], ], ) @pytest.mark.parametrize("batch_size", [(), (2,), (3, 2)]) @@ -220,6 +250,11 @@ def test_share_params_between_models( or (isinstance(model_name, list) and model_name[0] != "gnn") ): pytest.skip("gnn model needs agent dim as input") + if ( + packaging.version.parse(torchrl.__version__).local is None + and "gru" in model_name + ): + pytest.skip("gru model needs torchrl from github") torch.manual_seed(1) input_spec, output_spec = _get_input_and_output_specs( @@ -239,6 +274,8 @@ def test_share_params_between_models( ) else: config = model_config_registry[model_name].get_from_yaml() + if centralised: + config.is_critic = True model = config.get_model( input_spec=input_spec, output_spec=output_spec, diff --git a/test/test_pettingzoo.py b/test/test_pettingzoo.py index d8799bb5..c34e886b 100644 --- a/test/test_pettingzoo.py +++ b/test/test_pettingzoo.py @@ -5,13 +5,16 @@ # +import packaging import pytest - +import torchrl from benchmarl.algorithms import ( algorithm_config_registry, + IddpgConfig, IppoConfig, IsacConfig, MaddpgConfig, + MappoConfig, MasacConfig, QmixConfig, ) @@ -104,6 +107,36 @@ def test_gnn( ) experiment.run() + @pytest.mark.parametrize( + "algo_config", [IddpgConfig, MaddpgConfig, IppoConfig, MappoConfig, QmixConfig] + ) + @pytest.mark.parametrize("task", [PettingZooTask.SIMPLE_TAG]) + @pytest.mark.skipif( + packaging.version.parse(torchrl.__version__).local is None, + reason="gru model needs torchrl from github", + ) + def test_gru( + self, + algo_config: AlgorithmConfig, + task: Task, + experiment_config, + gru_mlp_sequence_config, + ): + algo_config = algo_config.get_from_yaml() + if algo_config.has_critic(): + algo_config.share_param_critic = False + experiment_config.share_policy_params = False + task = task.get_from_yaml() + experiment = Experiment( + algorithm_config=algo_config, + model_config=gru_mlp_sequence_config, + critic_model_config=gru_mlp_sequence_config, + seed=0, + config=experiment_config, + task=task, + ) + experiment.run() + @pytest.mark.parametrize("algo_config", algorithm_config_registry.values()) @pytest.mark.parametrize("prefer_continuous", [True, False]) @pytest.mark.parametrize("task", [PettingZooTask.SIMPLE_TAG]) diff --git a/test/test_smacv2.py b/test/test_smacv2.py index 671c84b3..aa518b3d 100644 --- a/test/test_smacv2.py +++ b/test/test_smacv2.py @@ -4,8 +4,9 @@ # LICENSE file in the root directory of this source tree. # - +import packaging import pytest +import torchrl from benchmarl.algorithms import algorithm_config_registry, MappoConfig, QmixConfig from benchmarl.algorithms.common import AlgorithmConfig @@ -77,3 +78,27 @@ def test_gnn( task=task, ) experiment.run() + + @pytest.mark.parametrize("algo_config", [QmixConfig]) + @pytest.mark.parametrize("task", [Smacv2Task.PROTOSS_5_VS_5]) + @pytest.mark.skipif( + packaging.version.parse(torchrl.__version__).local is None, + reason="gru model needs torchrl from github", + ) + def test_gru( + self, + algo_config, + task, + experiment_config, + gru_mlp_sequence_config, + ): + task = task.get_from_yaml() + experiment = Experiment( + algorithm_config=algo_config.get_from_yaml(), + model_config=gru_mlp_sequence_config, + critic_model_config=gru_mlp_sequence_config, + seed=0, + config=experiment_config, + task=task, + ) + experiment.run() diff --git a/test/test_vmas.py b/test/test_vmas.py index 9a6eb6cc..a7cd3aa0 100644 --- a/test/test_vmas.py +++ b/test/test_vmas.py @@ -4,15 +4,16 @@ # LICENSE file in the root directory of this source tree. # - +import packaging import pytest - +import torchrl from benchmarl.algorithms import ( algorithm_config_registry, IddpgConfig, IppoConfig, IsacConfig, MaddpgConfig, + MappoConfig, MasacConfig, QmixConfig, ) @@ -114,6 +115,37 @@ def test_gnn( ) experiment.run() + @pytest.mark.parametrize( + "algo_config", [IddpgConfig, MaddpgConfig, IppoConfig, MappoConfig, QmixConfig] + ) + @pytest.mark.parametrize("task", [VmasTask.NAVIGATION]) + @pytest.mark.skipif( + packaging.version.parse(torchrl.__version__).local is None, + reason="gru model needs torchrl from github", + ) + def test_gru( + self, + algo_config: AlgorithmConfig, + task: Task, + experiment_config, + gru_mlp_sequence_config, + share_params: bool = False, + ): + algo_config = algo_config.get_from_yaml() + if algo_config.has_critic(): + algo_config.share_param_critic = share_params + experiment_config.share_policy_params = share_params + task = task.get_from_yaml() + experiment = Experiment( + algorithm_config=algo_config, + model_config=gru_mlp_sequence_config, + critic_model_config=gru_mlp_sequence_config, + seed=0, + config=experiment_config, + task=task, + ) + experiment.run() + @pytest.mark.parametrize("algo_config", algorithm_config_registry.values()) @pytest.mark.parametrize("task", [VmasTask.BALANCE]) def test_reloading_trainer(