Skip to content

Commit

Permalink
compsitespec -> composite
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Nov 27, 2024
1 parent 64851a9 commit fbcff01
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions benchmarl/environments/magent/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from typing import Callable, Dict, List, Optional

from torchrl.data import CompositeSpec
from torchrl.data import Composite
from torchrl.envs import EnvBase, PettingZooWrapper

from benchmarl.environments.common import Task
Expand Down Expand Up @@ -86,10 +86,10 @@ def max_steps(self, env: EnvBase) -> int:
def group_map(self, env: EnvBase) -> Dict[str, List[str]]:
return env.group_map

def state_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
return CompositeSpec({"state": env.observation_spec["state"].clone()})
def state_spec(self, env: EnvBase) -> Optional[Composite]:
return Composite({"state": env.observation_spec["state"].clone()})

def action_mask_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
def action_mask_spec(self, env: EnvBase) -> Optional[Composite]:
observation_spec = env.observation_spec.clone()
for group in self.group_map(env):
group_obs_spec = observation_spec[group]
Expand All @@ -103,7 +103,7 @@ def action_mask_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
return None
return observation_spec

def observation_spec(self, env: EnvBase) -> CompositeSpec:
def observation_spec(self, env: EnvBase) -> Composite:
observation_spec = env.observation_spec.clone()
for group in self.group_map(env):
group_obs_spec = observation_spec[group]
Expand All @@ -113,7 +113,7 @@ def observation_spec(self, env: EnvBase) -> CompositeSpec:
del observation_spec["state"]
return observation_spec

def info_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
def info_spec(self, env: EnvBase) -> Optional[Composite]:
observation_spec = env.observation_spec.clone()
for group in self.group_map(env):
group_obs_spec = observation_spec[group]
Expand All @@ -123,7 +123,7 @@ def info_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
del observation_spec["state"]
return observation_spec

def action_spec(self, env: EnvBase) -> CompositeSpec:
def action_spec(self, env: EnvBase) -> Composite:
return env.full_action_spec

@staticmethod
Expand Down

0 comments on commit fbcff01

Please sign in to comment.