diff --git a/test/test_magent.py b/test/test_magent.py index 5a246840..881a4ce0 100644 --- a/test/test_magent.py +++ b/test/test_magent.py @@ -19,6 +19,7 @@ from benchmarl.experiment import Experiment from utils import _has_magent2 +from utils_experiment import ExperimentUtils @pytest.mark.skipif(not _has_magent2, reason="magent2 not found") @@ -91,26 +92,26 @@ def test_lstm( ) experiment.run() - # @pytest.mark.parametrize("algo_config", algorithm_config_registry.values()) - # @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT]) - # def test_reloading_trainer( - # self, - # algo_config: AlgorithmConfig, - # task: Task, - # experiment_config, - # cnn_sequence_config, - # ): - # # To not run unsupported algo-task pairs - # if not algo_config.supports_discrete_actions(): - # pytest.skip() - # algo_config = algo_config.get_from_yaml() - # - # ExperimentUtils.check_experiment_loading( - # algo_config=algo_config, - # model_config=cnn_sequence_config, - # experiment_config=experiment_config, - # task=task.get_from_yaml(), - # ) + @pytest.mark.parametrize("algo_config", IppoConfig) + @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT]) + def test_reloading_trainer( + self, + algo_config: AlgorithmConfig, + task: Task, + experiment_config, + cnn_sequence_config, + ): + # To not run unsupported algo-task pairs + if not algo_config.supports_discrete_actions(): + pytest.skip() + algo_config = algo_config.get_from_yaml() + + ExperimentUtils.check_experiment_loading( + algo_config=algo_config, + model_config=cnn_sequence_config, + experiment_config=experiment_config, + task=task.get_from_yaml(), + ) @pytest.mark.parametrize("algo_config", [QmixConfig, IppoConfig, MasacConfig]) @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])