Skip to content

Commit

Permalink
[Model] LSTM (#120)
Browse files Browse the repository at this point in the history
* amend

* copyright

* test

* test

* test

* doc
  • Loading branch information
matteobettini authored Aug 2, 2024
1 parent 8f84b67 commit d260eea
Show file tree
Hide file tree
Showing 10 changed files with 669 additions and 30 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ agent group. Here is a table of the models implemented in BenchMARL
|------------------------------------------|:-------------:|:-----------------------------:|:-----------------------------:|
| [MLP](benchmarl/models/mlp.py) | Yes | Yes | Yes |
| [GRU](benchmarl/models/gru.py) | Yes | Yes | Yes |
| [LSTM](benchmarl/models/lstm.py) | Yes | Yes | Yes |
| [GNN](benchmarl/models/gnn.py) | Yes | Yes | No |
| [CNN](benchmarl/models/cnn.py) | Yes | Yes | Yes |
| [Deepsets](benchmarl/models/deepsets.py) | Yes | Yes | Yes |
Expand Down
15 changes: 15 additions & 0 deletions benchmarl/conf/model/layers/lstm.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

name: lstm

hidden_size: 128
n_layers: 1
bias: True
dropout: 0
compile: False

mlp_num_cells: [256, 256]
mlp_layer_class: torch.nn.Linear
mlp_activation_class: torch.nn.Tanh
mlp_activation_kwargs: null
mlp_norm_class: null
mlp_norm_kwargs: null
4 changes: 4 additions & 0 deletions benchmarl/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .deepsets import Deepsets, DeepsetsConfig
from .gnn import Gnn, GnnConfig
from .gru import Gru, GruConfig
from .lstm import Lstm, LstmConfig
from .mlp import Mlp, MlpConfig

classes = [
Expand All @@ -22,6 +23,8 @@
"DeepsetsConfig",
"Gru",
"GruConfig",
"Lstm",
"LstmConfig",
]

model_config_registry = {
Expand All @@ -30,4 +33,5 @@
"cnn": CnnConfig,
"deepsets": DeepsetsConfig,
"gru": GruConfig,
"lstm": LstmConfig,
}
21 changes: 7 additions & 14 deletions benchmarl/models/gru.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,6 @@
# LICENSE file in the root directory of this source tree.
#

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from __future__ import annotations

from dataclasses import dataclass, MISSING
Expand Down Expand Up @@ -167,7 +161,7 @@ def forward(
h_0=None,
):
# Input and output always have the multiagent dimension
# Hidden state only has it when not centralised
# Hidden states always have it apart from when it is centralized and share params
# is_init never has it

assert is_init is not None, "We need to pass is_init"
Expand Down Expand Up @@ -202,7 +196,7 @@ def forward(
is_init = is_init.unsqueeze(-2).expand(batch, seq, self.n_agents, 1)

if h_0 is None:
if self.centralised:
if self.centralised and self.share_params:
shape = (
batch,
self.n_layers,
Expand Down Expand Up @@ -243,8 +237,8 @@ def run_net(self, input, is_init, h_0):
if self.centralised:
output, h_n = self.vmap_func_module(
self._empty_gru,
(0, None, None, None),
(-2, -2),
(0, None, None, -3),
(-2, -3),
)(self.params, input, is_init, h_0)
else:
output, h_n = self.vmap_func_module(
Expand Down Expand Up @@ -283,8 +277,8 @@ class Gru(Model):
The BenchMARL GRU accepts multiple inputs of type array: Tensors of shape ``(*batch,F)``
Where `F` is the number of features.
The features `F` will be processed to features of `hidden_size` by the GRU.
Where `F` is the number of features. These arrays will be concatenated along the F dimensions,
which will be processed to features of `hidden_size` by the GRU.
Args:
hidden_size (int): The number of features in the hidden state.
Expand Down Expand Up @@ -516,10 +510,9 @@ def is_rnn(self) -> bool:
return True

def get_model_state_spec(self, model_index: int = 0) -> CompositeSpec:
name = f"_hidden_gru_{model_index}"
spec = CompositeSpec(
{
name: UnboundedContinuousTensorSpec(
f"_hidden_gru_{model_index}": UnboundedContinuousTensorSpec(
shape=(self.n_layers, self.hidden_size)
)
}
Expand Down
Loading

0 comments on commit d260eea

Please sign in to comment.