diff --git a/test/test_magent.py b/test/test_magent.py index cc9aecdf..26c9aab1 100644 --- a/test/test_magent.py +++ b/test/test_magent.py @@ -7,6 +7,7 @@ import pytest +from algorithms import IppoConfig, IsacConfig, MasacConfig, QmixConfig from benchmarl.algorithms import algorithm_config_registry from benchmarl.algorithms.common import AlgorithmConfig from benchmarl.environments import MAgentTask, Task @@ -41,50 +42,50 @@ def test_all_algos( ) experiment.run() - # @pytest.mark.parametrize("algo_config", [MappoConfig, QmixConfig, IsacConfig]) - # @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT]) - # def test_gnn( - # self, - # algo_config: AlgorithmConfig, - # task: Task, - # experiment_config, - # cnn_gnn_sequence_config, - # ): - # task = task.get_from_yaml() - # experiment = Experiment( - # algorithm_config=algo_config.get_from_yaml(), - # model_config=cnn_gnn_sequence_config, - # critic_model_config=cnn_gnn_sequence_config, - # seed=0, - # config=experiment_config, - # task=task, - # ) - # experiment.run() - # - # @pytest.mark.parametrize("algo_config", [IppoConfig, QmixConfig, MasacConfig]) - # @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT]) - # def test_lstm( - # self, - # algo_config: AlgorithmConfig, - # task: Task, - # experiment_config, - # cnn_lstm_sequence_config, - # ): - # algo_config = algo_config.get_from_yaml() - # if algo_config.has_critic(): - # algo_config.share_param_critic = False - # experiment_config.share_policy_params = False - # task = task.get_from_yaml() - # experiment = Experiment( - # algorithm_config=algo_config, - # model_config=cnn_lstm_sequence_config, - # critic_model_config=cnn_lstm_sequence_config, - # seed=0, - # config=experiment_config, - # task=task, - # ) - # experiment.run() - # + @pytest.mark.parametrize("algo_config", [IppoConfig, QmixConfig, IsacConfig]) + @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT]) + def test_gnn( + self, + algo_config: AlgorithmConfig, + task: Task, + experiment_config, + cnn_gnn_sequence_config, + ): + task = task.get_from_yaml() + experiment = Experiment( + algorithm_config=algo_config.get_from_yaml(), + model_config=cnn_gnn_sequence_config, + critic_model_config=cnn_gnn_sequence_config, + seed=0, + config=experiment_config, + task=task, + ) + experiment.run() + + @pytest.mark.parametrize("algo_config", [IppoConfig, QmixConfig, MasacConfig]) + @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT]) + def test_lstm( + self, + algo_config: AlgorithmConfig, + task: Task, + experiment_config, + cnn_lstm_sequence_config, + ): + algo_config = algo_config.get_from_yaml() + if algo_config.has_critic(): + algo_config.share_param_critic = False + experiment_config.share_policy_params = False + task = task.get_from_yaml() + experiment = Experiment( + algorithm_config=algo_config, + model_config=cnn_lstm_sequence_config, + critic_model_config=cnn_lstm_sequence_config, + seed=0, + config=experiment_config, + task=task, + ) + experiment.run() + # @pytest.mark.parametrize("algo_config", algorithm_config_registry.values()) # @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT]) # def test_reloading_trainer( @@ -105,25 +106,25 @@ def test_all_algos( # experiment_config=experiment_config, # task=task.get_from_yaml(), # ) - # - # @pytest.mark.parametrize("algo_config", [QmixConfig, IppoConfig, MasacConfig]) - # @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT]) - # @pytest.mark.parametrize("share_params", [True, False]) - # def test_share_policy_params( - # self, - # algo_config: AlgorithmConfig, - # task: Task, - # share_params, - # experiment_config, - # cnn_sequence_config, - # ): - # experiment_config.share_policy_params = share_params - # task = task.get_from_yaml() - # experiment = Experiment( - # algorithm_config=algo_config.get_from_yaml(), - # model_config=cnn_sequence_config, - # seed=0, - # config=experiment_config, - # task=task, - # ) - # experiment.run() + + @pytest.mark.parametrize("algo_config", [QmixConfig, IppoConfig, MasacConfig]) + @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT]) + @pytest.mark.parametrize("share_params", [True, False]) + def test_share_policy_params( + self, + algo_config: AlgorithmConfig, + task: Task, + share_params, + experiment_config, + cnn_sequence_config, + ): + experiment_config.share_policy_params = share_params + task = task.get_from_yaml() + experiment = Experiment( + algorithm_config=algo_config.get_from_yaml(), + model_config=cnn_sequence_config, + seed=0, + config=experiment_config, + task=task, + ) + experiment.run()