Skip to content

Commit

Permalink
revert buffer update
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Dec 17, 2024
1 parent a6952e0 commit a1c10ff
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,6 @@ def _setup(self):
self._setup_task()
self._setup_algorithm()
self._setup_collector()
self._setup_buffers()
self._setup_name()
self._setup_logger()
self._on_setup()
Expand Down Expand Up @@ -480,6 +479,13 @@ def _setup_algorithm(self):
self.test_env = self.algorithm.process_env_fun(lambda: self.test_env)()
self.env_func = self.algorithm.process_env_fun(self.env_func)

self.replay_buffers = {
group: self.algorithm.get_replay_buffer(
group=group,
transforms=self.task.get_replay_buffer_transforms(self.test_env, group),
)
for group in self.group_map.keys()
}
self.losses = {
group: self.algorithm.get_loss_and_updater(group)[0]
for group in self.group_map.keys()
Expand Down Expand Up @@ -528,15 +534,6 @@ def _setup_collector(self):
)
self.rollout_env = self.env_func().to(self.config.sampling_device)

def _setup_buffers(self):
self.replay_buffers = {
group: self.algorithm.get_replay_buffer(
group=group,
transforms=self.task.get_replay_buffer_transforms(self.test_env, group),
)
for group in self.group_map.keys()
}

def _setup_name(self):
self.algorithm_name = self.algorithm_config.associated_class().__name__.lower()
self.model_name = self.model_config.associated_class().__name__.lower()
Expand Down

0 comments on commit a1c10ff

Please sign in to comment.