Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] DeepSets #96

Merged
merged 16 commits into from
Jun 13, 2024
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -260,11 +260,12 @@ when requested, as critics. We provide a set of base models (layers) and a Seque
different layers. All the models can be used with or without parameter sharing within an
agent group. Here is a table of the models implemented in BenchMARL

| Name | Decentralized | Centralized with local inputs | Centralized with global input |
|--------------------------------|:-------------:|:-----------------------------:|:-----------------------------:|
| [MLP](benchmarl/models/mlp.py) | Yes | Yes | Yes |
| [GNN](benchmarl/models/gnn.py) | Yes | Yes | No |
| [CNN](benchmarl/models/cnn.py) | Yes | Yes | Yes |
| Name | Decentralized | Centralized with local inputs | Centralized with global input |
|------------------------------------------|:-------------:|:-----------------------------:|:-----------------------------:|
| [MLP](benchmarl/models/mlp.py) | Yes | Yes | Yes |
| [GNN](benchmarl/models/gnn.py) | Yes | Yes | No |
| [CNN](benchmarl/models/cnn.py) | Yes | Yes | Yes |
| [Deepsets](benchmarl/models/deepsets.py) | Yes | Yes | Yes |

And the ones that are _work in progress_

Expand Down
9 changes: 9 additions & 0 deletions benchmarl/conf/model/layers/deepsets.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@

name: deepsets

aggr: "sum"
local_nn_num_cells: [128, 128]
local_nn_activation_class: torch.nn.Tanh
out_features_local_nn: 256
global_nn_num_cells: [256, 256]
global_nn_activation_class: torch.nn.Tanh
19 changes: 17 additions & 2 deletions benchmarl/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,24 @@

from .cnn import Cnn, CnnConfig
from .common import Model, ModelConfig, SequenceModel, SequenceModelConfig
from .deepsets import Deepsets, DeepsetsConfig
from .gnn import Gnn, GnnConfig
from .mlp import Mlp, MlpConfig

classes = ["Mlp", "MlpConfig", "Gnn", "GnnConfig", "Cnn", "CnnConfig"]
classes = [
"Mlp",
"MlpConfig",
"Gnn",
"GnnConfig",
"Cnn",
"CnnConfig",
"Deepsets",
"DeepsetsConfig",
]

model_config_registry = {"mlp": MlpConfig, "gnn": GnnConfig, "cnn": CnnConfig}
model_config_registry = {
"mlp": MlpConfig,
"gnn": GnnConfig,
"cnn": CnnConfig,
"deepsets": DeepsetsConfig,
}
Loading
Loading