diff --git a/.github/unittest/install_dependencies.sh b/.github/unittest/install_dependencies.sh index aeef298d..581f0361 100644 --- a/.github/unittest/install_dependencies.sh +++ b/.github/unittest/install_dependencies.sh @@ -1,7 +1,7 @@ python -m pip install --upgrade pip -python -m pip install flake8 pytest pytest-cov hydra-core tqdm torch torch_geometric torchvision av +python -m pip install flake8 pytest pytest-cov hydra-core tqdm torch torch_geometric torchvision "av<14" if [ -f requirements.txt ]; then pip install -r requirements.txt; fi diff --git a/.github/unittest/install_dependencies_nightly.sh b/.github/unittest/install_dependencies_nightly.sh index bd71bb8d..c53a5f6c 100644 --- a/.github/unittest/install_dependencies_nightly.sh +++ b/.github/unittest/install_dependencies_nightly.sh @@ -5,7 +5,7 @@ python -m pip install flake8 pytest pytest-cov hydra-core tqdm torch_geometric if [ -f requirements.txt ]; then pip install -r requirements.txt; fi -python -m pip install torch torchvision av +python -m pip install torch torchvision "av<14" # Not using nightly torch # python -m pip install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall diff --git a/docs/source/usage/installation.rst b/docs/source/usage/installation.rst index 492b42ba..d9e6ff30 100644 --- a/docs/source/usage/installation.rst +++ b/docs/source/usage/installation.rst @@ -43,7 +43,7 @@ You may want to install the following rendering and logging tools .. code-block:: console - pip install wandb moviepy torchvision av + pip install wandb moviepy torchvision "av<14" Install environments -------------------- diff --git a/setup.py b/setup.py index d8c0c27c..1a5ff528 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,7 @@ def get_version(): "pettingzoo": ["pettingzoo[all]>=1.24.3"], "meltingpot": ["dm-meltingpot"], "gnn": ["torch_geometric"], - "logging": ["moviepy", "wandb", "torchvision"], + "logging": ["moviepy", "wandb", "torchvision", "av<14"], }, packages=find_packages(), include_package_data=True, diff --git a/test/test_magent.py b/test/test_magent.py index 9be8240c..3db33220 100644 --- a/test/test_magent.py +++ b/test/test_magent.py @@ -113,7 +113,7 @@ def test_reloading_trainer( task=task.get_from_yaml(), ) - @pytest.mark.parametrize("algo_config", [QmixConfig, IppoConfig, MasacConfig]) + @pytest.mark.parametrize("algo_config", [QmixConfig, MasacConfig]) @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT]) @pytest.mark.parametrize("share_params", [True, False]) def test_share_policy_params(