Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Dec 17, 2024
1 parent a9db46b commit d2b242f
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,24 +442,30 @@ def _setup_task(self):
transforms_training = transforms_env + [
self.task.get_reward_sum_transform(test_env)
]

transforms_env = Compose(*transforms_env)
transforms_training = Compose(*transforms_training)

self.observation_spec = self.task.observation_spec(test_env)
self.info_spec = self.task.info_spec(test_env)
self.state_spec = self.task.state_spec(test_env)
self.action_mask_spec = self.task.action_mask_spec(test_env)
self.action_spec = self.task.action_spec(test_env)
self.group_map = self.task.group_map(test_env)
# Initialize test env
self.test_env = TransformedEnv(test_env, transforms_env.clone()).to(
self.config.sampling_device
)

self.observation_spec = self.task.observation_spec(self.test_env)
self.info_spec = self.task.info_spec(self.test_env)
self.state_spec = self.task.state_spec(self.test_env)
self.action_mask_spec = self.task.action_mask_spec(self.test_env)
self.action_spec = self.task.action_spec(self.test_env)
self.group_map = self.task.group_map(self.test_env)
self.train_group_map = copy.deepcopy(self.group_map)
self.max_steps = self.task.max_steps(test_env)
self.max_steps = self.task.max_steps(self.test_env)

# Add rnn transforms here so they do not show in the benchmarl specs
if self.model_config.is_rnn:
test_env = self._add_rnn_transforms(lambda: test_env)()
self.test_env = self._add_rnn_transforms(lambda: self.test_env)()
env_func = self._add_rnn_transforms(env_func)

if test_env.batch_size == ():
# Initialize train env
if self.test_env.batch_size == ():
# If the environment is not vectorized, we simulate vectorization using parallel or serial environments
env_class = (
SerialEnv if not self.config.parallel_collection else ParallelEnv
Expand All @@ -469,14 +475,11 @@ def _setup_task(self):
transforms_training.clone(),
)
else:
# Otherwise it is already vectorized
self.env_func = lambda: TransformedEnv(
env_func(), transforms_training.clone()
)

self.test_env = TransformedEnv(test_env, transforms_env.clone()).to(
self.config.sampling_device
)

def _setup_algorithm(self):
self.algorithm = self.algorithm_config.get_algorithm(experiment=self)

Expand Down

0 comments on commit d2b242f

Please sign in to comment.