Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Nov 25, 2024
1 parent 9813807 commit db6c2a5
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 13 deletions.
7 changes: 4 additions & 3 deletions benchmarl/environments/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,13 +282,14 @@ def get_env_transforms(self, env: EnvBase) -> List[Transform]:
"""
return []

def get_replay_buffer_transforms(self, env: EnvBase) -> List[Transform]:
def get_replay_buffer_transforms(self, env: EnvBase, group: str) -> List[Transform]:
"""
Returns a list of :class:`torchrl.envs.Transform` to be applied to the :class:`torchrl.data.ReplayBuffer`.
Returns a list of :class:`torchrl.envs.Transform` to be applied to the :class:`torchrl.data.ReplayBuffer`
of the specified group.
Args:
env (EnvBase): An environment created via self.get_env_fun
group (str): The agent group using the replay buffer
"""
return []
Expand Down
12 changes: 3 additions & 9 deletions benchmarl/environments/meltingpot/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,22 +125,16 @@ def get_env_transforms(self, env: EnvBase) -> List[Transform]:
else []
)

def get_replay_buffer_transforms(self, env: EnvBase) -> List[Transform]:
def get_replay_buffer_transforms(self, env: EnvBase, group: str) -> List[Transform]:
return [
DTypeCastTransform(
dtype_in=torch.uint8,
dtype_out=torch.float,
in_keys=[
"RGB",
*[
(group, "observation", "RGB")
for group in self.group_map(env).keys()
],
(group, "observation", "RGB"),
("next", "RGB"),
*[
("next", group, "observation", "RGB")
for group in self.group_map(env).keys()
],
("next", group, "observation", "RGB"),
],
in_keys_inv=[],
)
Expand Down
2 changes: 1 addition & 1 deletion benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def _setup_algorithm(self):
self.replay_buffers = {
group: self.algorithm.get_replay_buffer(
group=group,
transforms=self.task.get_replay_buffer_transforms(self.test_env),
transforms=self.task.get_replay_buffer_transforms(self.test_env, group),
)
for group in self.group_map.keys()
}
Expand Down

0 comments on commit db6c2a5

Please sign in to comment.