From 38da73a3ac4cbd99ee1b1ed8830040d419e074a3 Mon Sep 17 00:00:00 2001 From: huangshiyu Date: Thu, 14 Dec 2023 14:34:23 +0800 Subject: [PATCH] improve test --- tests/test_algorithm/test_a2c_algorithm.py | 95 +++++++++++++++++++++ tests/test_algorithm/test_bc_algorithm.py | 84 ++++++++++++++++++ tests/test_algorithm/test_ddpg_algorithm.py | 91 ++++++++++++++++++++ tests/test_algorithm/test_ppo_algorithm.py | 4 +- tests/test_algorithm/test_sac_algorithm.py | 91 ++++++++++++++++++++ 5 files changed, 364 insertions(+), 1 deletion(-) create mode 100644 tests/test_algorithm/test_a2c_algorithm.py create mode 100644 tests/test_algorithm/test_bc_algorithm.py create mode 100644 tests/test_algorithm/test_ddpg_algorithm.py create mode 100644 tests/test_algorithm/test_sac_algorithm.py diff --git a/tests/test_algorithm/test_a2c_algorithm.py b/tests/test_algorithm/test_a2c_algorithm.py new file mode 100644 index 00000000..0f4f7226 --- /dev/null +++ b/tests/test_algorithm/test_a2c_algorithm.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +import os +import sys + +import numpy as np +import pytest +from gymnasium import spaces + + +@pytest.fixture +def obs_space(): + return spaces.Box(low=-np.inf, high=+np.inf, shape=(1,), dtype=np.float32) + + +@pytest.fixture +def act_space(): + return spaces.Discrete(2) + + +@pytest.fixture( + scope="module", params=["--use_share_model false", "--use_share_model true"] +) +def config(request): + from openrl.configs.config import create_config_parser + + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(request.param.split()) + return cfg + + +@pytest.fixture +def amp_config(): + from openrl.configs.config import create_config_parser + + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args("") + return cfg + + +@pytest.fixture +def init_module(config, obs_space, act_space): + from openrl.modules.ppo_module import PPOModule + + module = PPOModule( + config, + policy_input_space=obs_space, + critic_input_space=obs_space, + act_space=act_space, + share_model=config.use_share_model, + ) + return module + + +@pytest.fixture +def buffer_data(config, obs_space, act_space): + from openrl.buffers.normal_buffer import NormalReplayBuffer + + buffer = NormalReplayBuffer( + config, + num_agents=1, + obs_space=obs_space, + act_space=act_space, + data_client=None, + episode_length=100, + ) + return buffer.data + + +@pytest.mark.unittest +def test_a2c_algorithm(config, init_module, buffer_data): + from openrl.algorithms.a2c import A2CAlgorithm + + a2c_algo = A2CAlgorithm(config, init_module) + + a2c_algo.train(buffer_data) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_algorithm/test_bc_algorithm.py b/tests/test_algorithm/test_bc_algorithm.py new file mode 100644 index 00000000..fa073174 --- /dev/null +++ b/tests/test_algorithm/test_bc_algorithm.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +import os +import sys + +import numpy as np +import pytest +from gymnasium import spaces + + +@pytest.fixture +def obs_space(): + return spaces.Box(low=-np.inf, high=+np.inf, shape=(1,), dtype=np.float32) + + +@pytest.fixture +def act_space(): + return spaces.Discrete(2) + + +@pytest.fixture(scope="module", params=["", "--use_share_model true"]) +def config(request): + from openrl.configs.config import create_config_parser + + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(request.param.split()) + return cfg + + +@pytest.fixture +def init_module(config, obs_space, act_space): + from openrl.modules.bc_module import BCModule + + module = BCModule( + config, + policy_input_space=obs_space, + critic_input_space=obs_space, + act_space=act_space, + share_model=config.use_share_model, + ) + return module + + +@pytest.fixture +def buffer_data(config, obs_space, act_space): + from openrl.buffers.normal_buffer import NormalReplayBuffer + + buffer = NormalReplayBuffer( + config, + num_agents=1, + obs_space=obs_space, + act_space=act_space, + data_client=None, + episode_length=100, + ) + return buffer.data + + +@pytest.mark.unittest +def test_bc_algorithm(config, init_module, buffer_data): + from openrl.algorithms.behavior_cloning import BCAlgorithm + + bc_algo = BCAlgorithm(config, init_module) + + bc_algo.train(buffer_data) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_algorithm/test_ddpg_algorithm.py b/tests/test_algorithm/test_ddpg_algorithm.py new file mode 100644 index 00000000..b31a56df --- /dev/null +++ b/tests/test_algorithm/test_ddpg_algorithm.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +import os +import sys + +import numpy as np +import pytest +from gymnasium import spaces + + +@pytest.fixture +def obs_space(): + return spaces.Box(low=-np.inf, high=+np.inf, shape=(1,), dtype=np.float32) + + +@pytest.fixture +def act_space(): + return spaces.box.Box(low=-np.inf, high=+np.inf, shape=(1,), dtype=np.float32) + + +@pytest.fixture(scope="module", params=[""]) +def config(request): + from openrl.configs.config import create_config_parser + + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(request.param.split()) + return cfg + + +@pytest.fixture +def amp_config(): + from openrl.configs.config import create_config_parser + + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args("") + return cfg + + +@pytest.fixture +def init_module(config, obs_space, act_space): + from openrl.modules.ddpg_module import DDPGModule + + module = DDPGModule( + config, + input_space=obs_space, + act_space=act_space, + ) + return module + + +@pytest.fixture +def buffer_data(config, obs_space, act_space): + from openrl.buffers.offpolicy_buffer import OffPolicyReplayBuffer + + buffer = OffPolicyReplayBuffer( + config, + num_agents=1, + obs_space=obs_space, + act_space=act_space, + data_client=None, + episode_length=5000, + ) + return buffer.data + + +@pytest.mark.unittest +def test_ddpg_algorithm(config, init_module, buffer_data): + from openrl.algorithms.ddpg import DDPGAlgorithm + + ddpg_algo = DDPGAlgorithm(config, init_module) + + ddpg_algo.train(buffer_data) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_algorithm/test_ppo_algorithm.py b/tests/test_algorithm/test_ppo_algorithm.py index 8ac5c865..98a8a5d4 100644 --- a/tests/test_algorithm/test_ppo_algorithm.py +++ b/tests/test_algorithm/test_ppo_algorithm.py @@ -33,7 +33,9 @@ def act_space(): return spaces.Discrete(2) -@pytest.fixture(scope="module", params=["", "--use_share_model true"]) +@pytest.fixture( + scope="module", params=["--use_share_model false", "--use_share_model true"] +) def config(request): from openrl.configs.config import create_config_parser diff --git a/tests/test_algorithm/test_sac_algorithm.py b/tests/test_algorithm/test_sac_algorithm.py new file mode 100644 index 00000000..80447a3a --- /dev/null +++ b/tests/test_algorithm/test_sac_algorithm.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +import os +import sys + +import numpy as np +import pytest +from gymnasium import spaces + + +@pytest.fixture +def obs_space(): + return spaces.Box(low=-np.inf, high=+np.inf, shape=(1,), dtype=np.float32) + + +@pytest.fixture +def act_space(): + return spaces.box.Box(low=-np.inf, high=+np.inf, shape=(1,), dtype=np.float32) + + +@pytest.fixture(scope="module", params=[""]) +def config(request): + from openrl.configs.config import create_config_parser + + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(request.param.split()) + return cfg + + +@pytest.fixture +def amp_config(): + from openrl.configs.config import create_config_parser + + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args("") + return cfg + + +@pytest.fixture +def init_module(config, obs_space, act_space): + from openrl.modules.sac_module import SACModule + + module = SACModule( + config, + input_space=obs_space, + act_space=act_space, + ) + return module + + +@pytest.fixture +def buffer_data(config, obs_space, act_space): + from openrl.buffers.offpolicy_buffer import OffPolicyReplayBuffer + + buffer = OffPolicyReplayBuffer( + config, + num_agents=1, + obs_space=obs_space, + act_space=act_space, + data_client=None, + episode_length=5000, + ) + return buffer.data + + +@pytest.mark.unittest +def test_sac_algorithm(config, init_module, buffer_data): + from openrl.algorithms.sac import SACAlgorithm + + sac_algo = SACAlgorithm(config, init_module) + + sac_algo.train(buffer_data) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))