Skip to content

Commit

Permalink
all tests apart reloading
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Nov 27, 2024
1 parent fbcff01 commit cb35028
Showing 1 changed file with 67 additions and 66 deletions.
133 changes: 67 additions & 66 deletions test/test_magent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import pytest

from algorithms import IppoConfig, IsacConfig, MasacConfig, QmixConfig
from benchmarl.algorithms import algorithm_config_registry
from benchmarl.algorithms.common import AlgorithmConfig
from benchmarl.environments import MAgentTask, Task
Expand Down Expand Up @@ -41,50 +42,50 @@ def test_all_algos(
)
experiment.run()

# @pytest.mark.parametrize("algo_config", [MappoConfig, QmixConfig, IsacConfig])
# @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
# def test_gnn(
# self,
# algo_config: AlgorithmConfig,
# task: Task,
# experiment_config,
# cnn_gnn_sequence_config,
# ):
# task = task.get_from_yaml()
# experiment = Experiment(
# algorithm_config=algo_config.get_from_yaml(),
# model_config=cnn_gnn_sequence_config,
# critic_model_config=cnn_gnn_sequence_config,
# seed=0,
# config=experiment_config,
# task=task,
# )
# experiment.run()
#
# @pytest.mark.parametrize("algo_config", [IppoConfig, QmixConfig, MasacConfig])
# @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
# def test_lstm(
# self,
# algo_config: AlgorithmConfig,
# task: Task,
# experiment_config,
# cnn_lstm_sequence_config,
# ):
# algo_config = algo_config.get_from_yaml()
# if algo_config.has_critic():
# algo_config.share_param_critic = False
# experiment_config.share_policy_params = False
# task = task.get_from_yaml()
# experiment = Experiment(
# algorithm_config=algo_config,
# model_config=cnn_lstm_sequence_config,
# critic_model_config=cnn_lstm_sequence_config,
# seed=0,
# config=experiment_config,
# task=task,
# )
# experiment.run()
#
@pytest.mark.parametrize("algo_config", [IppoConfig, QmixConfig, IsacConfig])
@pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
def test_gnn(
self,
algo_config: AlgorithmConfig,
task: Task,
experiment_config,
cnn_gnn_sequence_config,
):
task = task.get_from_yaml()
experiment = Experiment(
algorithm_config=algo_config.get_from_yaml(),
model_config=cnn_gnn_sequence_config,
critic_model_config=cnn_gnn_sequence_config,
seed=0,
config=experiment_config,
task=task,
)
experiment.run()

@pytest.mark.parametrize("algo_config", [IppoConfig, QmixConfig, MasacConfig])
@pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
def test_lstm(
self,
algo_config: AlgorithmConfig,
task: Task,
experiment_config,
cnn_lstm_sequence_config,
):
algo_config = algo_config.get_from_yaml()
if algo_config.has_critic():
algo_config.share_param_critic = False
experiment_config.share_policy_params = False
task = task.get_from_yaml()
experiment = Experiment(
algorithm_config=algo_config,
model_config=cnn_lstm_sequence_config,
critic_model_config=cnn_lstm_sequence_config,
seed=0,
config=experiment_config,
task=task,
)
experiment.run()

# @pytest.mark.parametrize("algo_config", algorithm_config_registry.values())
# @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
# def test_reloading_trainer(
Expand All @@ -105,25 +106,25 @@ def test_all_algos(
# experiment_config=experiment_config,
# task=task.get_from_yaml(),
# )
#
# @pytest.mark.parametrize("algo_config", [QmixConfig, IppoConfig, MasacConfig])
# @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
# @pytest.mark.parametrize("share_params", [True, False])
# def test_share_policy_params(
# self,
# algo_config: AlgorithmConfig,
# task: Task,
# share_params,
# experiment_config,
# cnn_sequence_config,
# ):
# experiment_config.share_policy_params = share_params
# task = task.get_from_yaml()
# experiment = Experiment(
# algorithm_config=algo_config.get_from_yaml(),
# model_config=cnn_sequence_config,
# seed=0,
# config=experiment_config,
# task=task,
# )
# experiment.run()

@pytest.mark.parametrize("algo_config", [QmixConfig, IppoConfig, MasacConfig])
@pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
@pytest.mark.parametrize("share_params", [True, False])
def test_share_policy_params(
self,
algo_config: AlgorithmConfig,
task: Task,
share_params,
experiment_config,
cnn_sequence_config,
):
experiment_config.share_policy_params = share_params
task = task.get_from_yaml()
experiment = Experiment(
algorithm_config=algo_config.get_from_yaml(),
model_config=cnn_sequence_config,
seed=0,
config=experiment_config,
task=task,
)
experiment.run()

0 comments on commit cb35028

Please sign in to comment.