Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Jun 13, 2024
1 parent b0ae1c6 commit 43a99c6
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
15 changes: 11 additions & 4 deletions benchmarl/models/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class Gnn(Model):
representing the agent position. This key will not be processed as a node feature, but it will used to construct
edge features. In particular it be used to compute relative positions (``pos_node_1 - pos_node_2``) and a
one-dimensional distance for all neighbours in the graph.
exclude_pos_from_node_features (bool): If ``position_key`` is provided,
exclude_pos_from_node_features (optional, bool): If ``position_key`` is provided,
wether to use it just to compute edge features or also include it in node features.
velocity_key (str, optional): if provided, it will need to match a leaf key in the env observation spec
representing the agent velocity. This key will not be processed as a node feature, but it will used to construct
Expand Down Expand Up @@ -117,7 +117,7 @@ def __init__(
gnn_class: Type[torch_geometric.nn.MessagePassing],
gnn_kwargs: Optional[dict],
position_key: Optional[str],
exclude_pos_from_node_features: bool,
exclude_pos_from_node_features: Optional[bool],
velocity_key: Optional[str],
edge_radius: Optional[float],
**kwargs,
Expand Down Expand Up @@ -201,6 +201,13 @@ def _perform_checks(self):
)
if self.topology == "from_pos" and self.position_key is None:
raise ValueError("If topology is from_pos, position_key must be provided")
if (
self.position_key is not None
and self.exclude_pos_from_node_features is None
):
raise ValueError(
"exclude_pos_from_node_features needs to be specified when position_key is provided"
)

if not self.input_has_agent_dim:
raise ValueError(
Expand Down Expand Up @@ -418,11 +425,11 @@ class GnnConfig(ModelConfig):
self_loops: bool = MISSING

gnn_class: Type[torch_geometric.nn.MessagePassing] = MISSING
exclude_pos_from_node_features: bool = MISSING

gnn_kwargs: Optional[dict] = None

position_key: Optional[str] = None
velocity_key: Optional[str] = None
exclude_pos_from_node_features: Optional[bool] = None
edge_radius: Optional[float] = None

@staticmethod
Expand Down
1 change: 0 additions & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def mlp_gnn_sequence_config() -> ModelConfig:
topology="full",
self_loops=False,
gnn_class=torch_geometric.nn.conv.GATv2Conv,
gnn_kwargs={},
),
MlpConfig(num_cells=[4], activation_class=nn.Tanh, layer_class=nn.Linear),
],
Expand Down

0 comments on commit 43a99c6

Please sign in to comment.