diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index 212e3c99..223fc9a4 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -442,24 +442,30 @@ def _setup_task(self): transforms_training = transforms_env + [ self.task.get_reward_sum_transform(test_env) ] - transforms_env = Compose(*transforms_env) transforms_training = Compose(*transforms_training) - self.observation_spec = self.task.observation_spec(test_env) - self.info_spec = self.task.info_spec(test_env) - self.state_spec = self.task.state_spec(test_env) - self.action_mask_spec = self.task.action_mask_spec(test_env) - self.action_spec = self.task.action_spec(test_env) - self.group_map = self.task.group_map(test_env) + # Initialize test env + self.test_env = TransformedEnv(test_env, transforms_env.clone()).to( + self.config.sampling_device + ) + + self.observation_spec = self.task.observation_spec(self.test_env) + self.info_spec = self.task.info_spec(self.test_env) + self.state_spec = self.task.state_spec(self.test_env) + self.action_mask_spec = self.task.action_mask_spec(self.test_env) + self.action_spec = self.task.action_spec(self.test_env) + self.group_map = self.task.group_map(self.test_env) self.train_group_map = copy.deepcopy(self.group_map) - self.max_steps = self.task.max_steps(test_env) + self.max_steps = self.task.max_steps(self.test_env) + # Add rnn transforms here so they do not show in the benchmarl specs if self.model_config.is_rnn: - test_env = self._add_rnn_transforms(lambda: test_env)() + self.test_env = self._add_rnn_transforms(lambda: self.test_env)() env_func = self._add_rnn_transforms(env_func) - if test_env.batch_size == (): + # Initialize train env + if self.test_env.batch_size == (): # If the environment is not vectorized, we simulate vectorization using parallel or serial environments env_class = ( SerialEnv if not self.config.parallel_collection else ParallelEnv @@ -469,14 +475,11 @@ def _setup_task(self): transforms_training.clone(), ) else: + # Otherwise it is already vectorized self.env_func = lambda: TransformedEnv( env_func(), transforms_training.clone() ) - self.test_env = TransformedEnv(test_env, transforms_env.clone()).to( - self.config.sampling_device - ) - def _setup_algorithm(self): self.algorithm = self.algorithm_config.get_algorithm(experiment=self)