diff --git a/.github/unittest/install_magent2.sh b/.github/unittest/install_magent2.sh new file mode 100644 index 00000000..a4d156cd --- /dev/null +++ b/.github/unittest/install_magent2.sh @@ -0,0 +1,6 @@ + + +pip install git+https://github.com/Farama-Foundation/MAgent2 + +sudo apt-get update +sudo apt-get install python3-opengl xvfb diff --git a/.github/workflows/magent_tests.yml b/.github/workflows/magent_tests.yml new file mode 100644 index 00000000..5b10f6c5 --- /dev/null +++ b/.github/workflows/magent_tests.yml @@ -0,0 +1,43 @@ +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: +# https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + + +name: magent_tests + +on: + push: + branches: [ $default-branch , "main" ] + pull_request: + branches: [ $default-branch , "main" ] + +permissions: + contents: read + +jobs: + tests: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.11"] + + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + bash .github/unittest/install_dependencies_nightly.sh + - name: Install magent2 + run: | + bash .github/unittest/install_magent2.sh + - name: Test with pytest + run: | + xvfb-run -s "-screen 0 1024x768x24" pytest test/test_magent.py --doctest-modules --junitxml=junit/test-results.xml --cov=. 
--cov-report=xml --cov-report=html + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + fail_ci_if_error: false diff --git a/README.md b/README.md index 3e515132..cdef823e 100644 --- a/README.md +++ b/README.md @@ -116,6 +116,13 @@ pip install "pettingzoo[all]" ```bash pip install dm-meltingpot ``` + +##### MAgent2 + +```bash +pip install git+https://github.com/Farama-Foundation/MAgent2 +``` + ##### SMACv2 Follow the instructions on the environment [repository](https://github.com/oxwhirl/smacv2). @@ -239,13 +246,14 @@ determine the training strategy. Here is a table with the currently implemented challenge to solve. They differ based on many aspects, here is a table with the current environments in BenchMARL -| Environment | Tasks | Cooperation | Global state | Reward function | Action space | Vectorized | -|--------------------------------------------------------------------|--------------------------------------|---------------------------|--------------|-------------------------------|-----------------------|:----------------:| -| [VMAS](https://github.com/proroklab/VectorizedMultiAgentSimulator) | [27](benchmarl/conf/task/vmas) | Cooperative + Competitive | No | Shared + Independent + Global | Continuous + Discrete | Yes | -| [SMACv2](https://github.com/oxwhirl/smacv2) | [15](benchmarl/conf/task/smacv2) | Cooperative | Yes | Global | Discrete | No | -| [MPE](https://github.com/openai/multiagent-particle-envs) | [8](benchmarl/conf/task/pettingzoo) | Cooperative + Competitive | Yes | Shared + Independent | Continuous + Discrete | No | -| [SISL](https://github.com/sisl/MADRL) | [2](benchmarl/conf/task/pettingzoo) | Cooperative | No | Shared | Continuous | No | -| [MeltingPot](https://github.com/google-deepmind/meltingpot) | [49](benchmarl/conf/task/meltingpot) | Cooperative + Competitive | Yes | Independent | Discrete | No | +| Environment | Tasks | Cooperation | Global state | Reward function | Action space | Vectorized | 
+|---------------------------------------------------------------------|--------------------------------------|---------------------------|--------------|-------------------------------|-----------------------|:----------------:| +| [VMAS](https://github.com/proroklab/VectorizedMultiAgentSimulator) | [27](benchmarl/conf/task/vmas) | Cooperative + Competitive | No | Shared + Independent + Global | Continuous + Discrete | Yes | +| [SMACv2](https://github.com/oxwhirl/smacv2) | [15](benchmarl/conf/task/smacv2) | Cooperative | Yes | Global | Discrete | No | +| [MPE](https://github.com/openai/multiagent-particle-envs) | [8](benchmarl/conf/task/pettingzoo) | Cooperative + Competitive | Yes | Shared + Independent | Continuous + Discrete | No | +| [SISL](https://github.com/sisl/MADRL) | [2](benchmarl/conf/task/pettingzoo) | Cooperative | No | Shared | Continuous | No | +| [MeltingPot](https://github.com/google-deepmind/meltingpot) | [49](benchmarl/conf/task/meltingpot) | Cooperative + Competitive | Yes | Independent | Discrete | No | +| [MAgent2](https://github.com/Farama-Foundation/magent2) | [1](benchmarl/conf/task/magent) | Cooperative + Competitive | Yes | Global in groups | Discrete | No | > [!NOTE] diff --git a/benchmarl/conf/task/magent/adversarial_pursuit.yaml b/benchmarl/conf/task/magent/adversarial_pursuit.yaml new file mode 100644 index 00000000..cfd5f402 --- /dev/null +++ b/benchmarl/conf/task/magent/adversarial_pursuit.yaml @@ -0,0 +1,9 @@ +defaults: + - magent_adversarial_pursuit_config + - _self_ + +map_size: 45 +minimap_mode: False +tag_penalty: -0.2 +max_cycles: 500 +extra_features: False diff --git a/benchmarl/environments/__init__.py b/benchmarl/environments/__init__.py index 4648cc0b..55f070e5 100644 --- a/benchmarl/environments/__init__.py +++ b/benchmarl/environments/__init__.py @@ -5,6 +5,8 @@ # from .common import _get_task_config_class, Task + +from .magent.common import MAgentTask from .meltingpot.common import MeltingPotTask from 
.pettingzoo.common import PettingZooTask from .smacv2.common import Smacv2Task @@ -12,7 +14,7 @@ # The enum classes for the environments available. # This is the only object in this file you need to modify when adding a new environment. -tasks = [VmasTask, Smacv2Task, PettingZooTask, MeltingPotTask] +tasks = [VmasTask, Smacv2Task, PettingZooTask, MeltingPotTask, MAgentTask] # This is a registry mapping "envname/task_name" to the EnvNameTask.TASK_NAME enum # It is used by automatically load task enums from yaml files. diff --git a/benchmarl/environments/magent/__init__.py b/benchmarl/environments/magent/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/benchmarl/environments/magent/adversarial_pursuit.py b/benchmarl/environments/magent/adversarial_pursuit.py new file mode 100644 index 00000000..93a56858 --- /dev/null +++ b/benchmarl/environments/magent/adversarial_pursuit.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +from dataclasses import dataclass, MISSING + + +@dataclass +class TaskConfig: + map_size: int = MISSING + minimap_mode: bool = MISSING + tag_penalty: float = MISSING + max_cycles: int = MISSING + extra_features: bool = MISSING diff --git a/benchmarl/environments/magent/common.py b/benchmarl/environments/magent/common.py new file mode 100644 index 00000000..b8964ddd --- /dev/null +++ b/benchmarl/environments/magent/common.py @@ -0,0 +1,131 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# + +from typing import Callable, Dict, List, Optional + +from torchrl.data import Composite +from torchrl.envs import EnvBase, PettingZooWrapper + +from benchmarl.environments.common import Task + +from benchmarl.utils import DEVICE_TYPING + + +class MAgentTask(Task): + """Enum for MAgent2 tasks.""" + + ADVERSARIAL_PURSUIT = None + # BATTLE = None + # BATTLEFIELD = None + # COMBINED_ARMS = None + # GATHER = None + # TIGER_DEER = None + + def get_env_fun( + self, + num_envs: int, + continuous_actions: bool, + seed: Optional[int], + device: DEVICE_TYPING, + ) -> Callable[[], EnvBase]: + + return lambda: PettingZooWrapper( + env=self.__get_env(), + return_state=True, + seed=seed, + done_on_any=False, + use_mask=False, + device=device, + ) + + def __get_env(self) -> EnvBase: + try: + from magent2.environments import ( + adversarial_pursuit_v4, + # battle_v4, + # battlefield_v5, + # combined_arms_v6, + # gather_v5, + # tiger_deer_v4 + ) + except ImportError: + raise ImportError( + "Module `magent2` not found, install it using `pip install magent2`" + ) + + envs = { + "ADVERSARIAL_PURSUIT": adversarial_pursuit_v4, + # "BATTLE": battle_v4, + # "BATTLEFIELD": battlefield_v5, + # "COMBINED_ARMS": combined_arms_v6, + # "GATHER": gather_v5, + # "TIGER_DEER": tiger_deer_v4 + } + if self.name not in envs: + raise Exception(f"{self.name} is not an environment of MAgent2") + return envs[self.name].parallel_env(**self.config, render_mode="rgb_array") + + def supports_continuous_actions(self) -> bool: + return False + + def supports_discrete_actions(self) -> bool: + return True + + def has_state(self) -> bool: + return True + + def has_render(self, env: EnvBase) -> bool: + return True + + def max_steps(self, env: EnvBase) -> int: + return self.config["max_cycles"] + + def group_map(self, env: EnvBase) -> Dict[str, List[str]]: + return env.group_map + + def state_spec(self, env: EnvBase) -> Optional[Composite]: + return Composite({"state": env.observation_spec["state"].clone()}) + 
+ def action_mask_spec(self, env: EnvBase) -> Optional[Composite]: + observation_spec = env.observation_spec.clone() + for group in self.group_map(env): + group_obs_spec = observation_spec[group] + for key in list(group_obs_spec.keys()): + if key != "action_mask": + del group_obs_spec[key] + if group_obs_spec.is_empty(): + del observation_spec[group] + del observation_spec["state"] + if observation_spec.is_empty(): + return None + return observation_spec + + def observation_spec(self, env: EnvBase) -> Composite: + observation_spec = env.observation_spec.clone() + for group in self.group_map(env): + group_obs_spec = observation_spec[group] + for key in list(group_obs_spec.keys()): + if key != "observation": + del group_obs_spec[key] + del observation_spec["state"] + return observation_spec + + def info_spec(self, env: EnvBase) -> Optional[Composite]: + observation_spec = env.observation_spec.clone() + for group in self.group_map(env): + group_obs_spec = observation_spec[group] + for key in list(group_obs_spec.keys()): + if key != "info": + del group_obs_spec[key] + del observation_spec["state"] + return observation_spec + + def action_spec(self, env: EnvBase) -> Composite: + return env.full_action_spec + + @staticmethod + def env_name() -> str: + return "magent" diff --git a/docs/source/concepts/components.rst b/docs/source/concepts/components.rst index 4c718abb..cb8903d0 100644 --- a/docs/source/concepts/components.rst +++ b/docs/source/concepts/components.rst @@ -91,6 +91,8 @@ They differ based on many aspects, here is a table with the current environments +-------------------------------------------------+-------+---------------------------+--------------+-------------------------------+-----------------------+------------+ | :class:`~benchmarl.environments.MeltingPotTask` | 49 | Cooperative + Competitive | Yes | Independent | Discrete | No | 
+-------------------------------------------------+-------+---------------------------+--------------+-------------------------------+-----------------------+------------+ + | :class:`~benchmarl.environments.MAgentTask` | 1 | Cooperative + Competitive | Yes | Global in groups | Discrete | No | + +-------------------------------------------------+-------+---------------------------+--------------+-------------------------------+-----------------------+------------+ diff --git a/docs/source/usage/installation.rst b/docs/source/usage/installation.rst index 73644a05..492b42ba 100644 --- a/docs/source/usage/installation.rst +++ b/docs/source/usage/installation.rst @@ -87,6 +87,16 @@ Follow the instructions on the environment `repository `_ is how we install it on linux. +MAgent2 +^^^^^^^ +:github:`null` `GitHub <https://github.com/Farama-Foundation/MAgent2>`__ + + +.. code-block:: console + + pip install git+https://github.com/Farama-Foundation/MAgent2 + + Install models -------------- diff --git a/test/conftest.py b/test/conftest.py index 4ef505bc..3f53416e 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -90,6 +90,31 @@ def mlp_gnn_sequence_config() -> ModelConfig: ) +@pytest.fixture +def cnn_gnn_sequence_config() -> ModelConfig: + return SequenceModelConfig( + model_configs=[ + CnnConfig( + cnn_num_cells=[4, 3], + cnn_kernel_sizes=[3, 2], + cnn_strides=1, + cnn_paddings=0, + cnn_activation_class=nn.Tanh, + mlp_num_cells=[4], + mlp_activation_class=nn.Tanh, + mlp_layer_class=nn.Linear, + ), + GnnConfig( + topology="full", + self_loops=False, + gnn_class=torch_geometric.nn.conv.GATv2Conv, + ), + MlpConfig(num_cells=[4], activation_class=nn.Tanh, layer_class=nn.Linear), + ], + intermediate_sizes=[5, 3], + ) + + @pytest.fixture def gru_mlp_sequence_config() -> ModelConfig: return SequenceModelConfig( @@ -128,3 +153,32 @@ def lstm_mlp_sequence_config() -> ModelConfig: ], intermediate_sizes=[5], ) + + +@pytest.fixture +def cnn_lstm_sequence_config() -> ModelConfig: + return SequenceModelConfig( + model_configs=[ 
+ CnnConfig( + cnn_num_cells=[4, 3], + cnn_kernel_sizes=[3, 2], + cnn_strides=1, + cnn_paddings=0, + cnn_activation_class=nn.Tanh, + mlp_num_cells=[4], + mlp_activation_class=nn.Tanh, + mlp_layer_class=nn.Linear, + ), + LstmConfig( + hidden_size=13, + mlp_num_cells=[], + mlp_activation_class=nn.Tanh, + mlp_layer_class=nn.Linear, + n_layers=1, + bias=True, + dropout=0, + compile=False, + ), + ], + intermediate_sizes=[5], + ) diff --git a/test/test_magent.py b/test/test_magent.py new file mode 100644 index 00000000..9be8240c --- /dev/null +++ b/test/test_magent.py @@ -0,0 +1,136 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + + +import pytest + +from benchmarl.algorithms import ( + algorithm_config_registry, + IppoConfig, + IsacConfig, + MasacConfig, + QmixConfig, +) +from benchmarl.algorithms.common import AlgorithmConfig +from benchmarl.environments import MAgentTask, Task +from benchmarl.experiment import Experiment + +from utils import _has_magent2 +from utils_experiment import ExperimentUtils + + +@pytest.mark.skipif(not _has_magent2, reason="magent2 not found") +class TestMagent: + @pytest.mark.parametrize("algo_config", algorithm_config_registry.values()) + @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT]) + def test_all_algos( + self, + algo_config: AlgorithmConfig, + task: Task, + experiment_config, + cnn_sequence_config, + ): + + # To not run unsupported algo-task pairs + if not algo_config.supports_discrete_actions(): + pytest.skip() + + task = task.get_from_yaml() + experiment = Experiment( + algorithm_config=algo_config.get_from_yaml(), + model_config=cnn_sequence_config, + seed=0, + config=experiment_config, + task=task, + ) + experiment.run() + + @pytest.mark.parametrize("algo_config", [IppoConfig, QmixConfig, IsacConfig]) + @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT]) + def 
test_gnn( + self, + algo_config: AlgorithmConfig, + task: Task, + experiment_config, + cnn_gnn_sequence_config, + ): + task = task.get_from_yaml() + experiment = Experiment( + algorithm_config=algo_config.get_from_yaml(), + model_config=cnn_gnn_sequence_config, + critic_model_config=cnn_gnn_sequence_config, + seed=0, + config=experiment_config, + task=task, + ) + experiment.run() + + @pytest.mark.parametrize("algo_config", [IppoConfig, QmixConfig, MasacConfig]) + @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT]) + def test_lstm( + self, + algo_config: AlgorithmConfig, + task: Task, + experiment_config, + cnn_lstm_sequence_config, + ): + algo_config = algo_config.get_from_yaml() + if algo_config.has_critic(): + algo_config.share_param_critic = False + experiment_config.share_policy_params = False + task = task.get_from_yaml() + experiment = Experiment( + algorithm_config=algo_config, + model_config=cnn_lstm_sequence_config, + critic_model_config=cnn_lstm_sequence_config, + seed=0, + config=experiment_config, + task=task, + ) + experiment.run() + + @pytest.mark.parametrize("algo_config", [IppoConfig]) + @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT]) + def test_reloading_trainer( + self, + algo_config: AlgorithmConfig, + task: Task, + experiment_config, + cnn_sequence_config, + ): + # To not run unsupported algo-task pairs + if not algo_config.supports_discrete_actions(): + pytest.skip() + algo_config = algo_config.get_from_yaml() + + ExperimentUtils.check_experiment_loading( + algo_config=algo_config, + model_config=cnn_sequence_config, + experiment_config=experiment_config, + task=task.get_from_yaml(), + ) + + @pytest.mark.parametrize("algo_config", [QmixConfig, IppoConfig, MasacConfig]) + @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT]) + @pytest.mark.parametrize("share_params", [True, False]) + def test_share_policy_params( + self, + algo_config: AlgorithmConfig, + task: Task, + share_params, + 
experiment_config, + cnn_sequence_config, + ): + experiment_config.share_policy_params = share_params + task = task.get_from_yaml() + experiment = Experiment( + algorithm_config=algo_config.get_from_yaml(), + model_config=cnn_sequence_config, + seed=0, + config=experiment_config, + task=task, + ) + experiment.run() diff --git a/test/utils.py b/test/utils.py index 848b149b..1b5f5b41 100644 --- a/test/utils.py +++ b/test/utils.py @@ -10,3 +10,4 @@ _has_smacv2 = importlib.util.find_spec("smacv2") is not None _has_pettingzoo = importlib.util.find_spec("pettingzoo") is not None _has_meltingpot = importlib.util.find_spec("meltingpot") is not None +_has_magent2 = importlib.util.find_spec("magent2") is not None