-
-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Make tests parameterizable * Add nondeterministic flag to mo-highway * Fix minecart observation bounds * Fix water-reservoir observation dtype
- Loading branch information
1 parent
6dc669c
commit 3bb3ce9
Showing
4 changed files
with
109 additions
and
117 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
from gymnasium.envs.registration import register | ||
|
||
|
||
register(id="mo-highway-v0", entry_point="mo_gymnasium.envs.highway.highway:MOHighwayEnv") | ||
register(id="mo-highway-v0", entry_point="mo_gymnasium.envs.highway.highway:MOHighwayEnv", nondeterministic=True) | ||
|
||
register(id="mo-highway-fast-v0", entry_point="mo_gymnasium.envs.highway.highway:MOHighwayEnvFast") | ||
register(id="mo-highway-fast-v0", entry_point="mo_gymnasium.envs.highway.highway:MOHighwayEnvFast", nondeterministic=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,142 +1,134 @@ | ||
import pickle | ||
|
||
import gymnasium as gym | ||
import numpy as np | ||
import pytest | ||
from gymnasium.envs.registration import EnvSpec | ||
from gymnasium.utils.env_checker import check_env, data_equivalence | ||
|
||
import mo_gymnasium as mo_gym | ||
from mo_gymnasium import LinearReward | ||
|
||
|
||
def _test_pickle_env(env: gym.Env):
    """Check that a pickled copy of `env` behaves like the original.

    The environment is round-tripped through ``pickle``; both copies are then
    reset with the same seed and stepped with the same sampled action, and the
    data they return must be equivalent.

    Args:
        env: the environment to test; it is closed before returning.

    Raises:
        AssertionError: if the reset or step results of the two copies differ.
    """
    pickled_env = pickle.loads(pickle.dumps(env))
    try:
        # data_equivalence returns a bool instead of raising, so its result has to be asserted.
        # Both copies are reset with the same seed so that their results are comparable.
        assert data_equivalence(
            env.reset(seed=0), pickled_env.reset(seed=0)
        ), "reset results differ between the original and the pickled environment"

        action = env.action_space.sample()
        assert data_equivalence(
            env.step(action), pickled_env.step(action)
        ), "step results differ between the original and the pickled environment"
    finally:
        # Release both environments even when one of the checks above fails.
        env.close()
        pickled_env.close()
|
||
|
||
def test_deep_sea_treasure():
    """Run the env checker and the pickle check on deep-sea-treasure-v0 wrapped in LinearReward."""
    wrapped_env = LinearReward(mo_gym.make("deep-sea-treasure-v0"))
    check_env(wrapped_env, skip_render_check=True)
    _test_pickle_env(wrapped_env)
|
||
|
||
def test_fishwood():
    """Run the env checker and the pickle check on fishwood-v0 wrapped in LinearReward."""
    wrapped_env = LinearReward(mo_gym.make("fishwood-v0"))
    check_env(wrapped_env, skip_render_check=True)
    _test_pickle_env(wrapped_env)
|
||
|
||
def test_four_room():
    """Run the env checker and the pickle check on four-room-v0 wrapped in LinearReward."""
    wrapped_env = LinearReward(mo_gym.make("four-room-v0"))
    check_env(wrapped_env, skip_render_check=True)
    _test_pickle_env(wrapped_env)
|
||
|
||
def test_minecart():
    """Run the env checker and the pickle check on minecart-v0 wrapped in LinearReward."""
    wrapped_env = LinearReward(mo_gym.make("minecart-v0"))
    check_env(wrapped_env, skip_render_check=True)
    _test_pickle_env(wrapped_env)
|
||
|
||
def test_mountaincar():
    """Run the env checker and the pickle check on mo-mountaincar-v0 wrapped in LinearReward."""
    wrapped_env = LinearReward(mo_gym.make("mo-mountaincar-v0"))
    check_env(wrapped_env, skip_render_check=True)
    _test_pickle_env(wrapped_env)
|
||
|
||
def test_continuous_mountaincar(): | ||
env = mo_gym.make("mo-mountaincarcontinuous-v0") | ||
env = LinearReward(env) | ||
# Specs of the registered environments that the parametrized tests run on.
# Only specs whose entry point is a string rooted in the "mo_gymnasium" package are kept.
# The highway environments are ignored as they do not deal with the random seed appropriately.
all_testing_env_specs = [
    candidate
    for candidate in gym.envs.registry.values()
    if type(candidate.entry_point) is str
    and candidate.entry_point.split(".")[0] == "mo_gymnasium"
    and not candidate.id.startswith("mo-highway")
]
|
||
|
||
@pytest.mark.parametrize("spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs])
def test_all_env_api(spec):
    """Run the environment checker and the pickle check on each collected environment.

    Args:
        spec: the spec of the environment under test (one per parametrized case).
    """
    wrapped_env = mo_gym.LinearReward(mo_gym.make(spec.id))
    check_env(wrapped_env, skip_render_check=True)
    _test_pickle_env(wrapped_env)
|
||
|
||
def test_resource_gathering():
    """Run the env checker and the pickle check on resource-gathering-v0 wrapped in LinearReward."""
    wrapped_env = LinearReward(mo_gym.make("resource-gathering-v0"))
    check_env(wrapped_env, skip_render_check=True)
    _test_pickle_env(wrapped_env)
@pytest.mark.parametrize(
    "spec",
    all_testing_env_specs,
    ids=[spec.id for spec in all_testing_env_specs],
)
def test_all_env_passive_env_checker(spec):
    """Make each collected environment, reset it, take one sampled action and close it.

    Args:
        spec: the spec of the environment under test (one per parametrized case).
    """
    env = mo_gym.make(spec.id)
    env.reset()
    sampled_action = env.action_space.sample()
    env.step(sampled_action)
    env.close()
|
||
|
||
def test_fruit_tree():
    """Run the env checker and the pickle check on fruit-tree-v0 wrapped in LinearReward."""
    wrapped_env = LinearReward(mo_gym.make("fruit-tree-v0"))
    check_env(wrapped_env, skip_render_check=True)
    _test_pickle_env(wrapped_env)
# Note that this precludes running this test in multiple threads. | ||
# However, we probably already can't do multithreading due to some environments. | ||
SEED = 0 | ||
NUM_STEPS = 50 | ||
|
||
|
||
def test_mario():
    """Run the env checker and the pickle check on mo-supermario-v0 wrapped in LinearReward."""
    wrapped_env = LinearReward(mo_gym.make("mo-supermario-v0"))
    check_env(wrapped_env, skip_render_check=True)
    _test_pickle_env(wrapped_env)
@pytest.mark.parametrize( | ||
"env_spec", | ||
all_testing_env_specs, | ||
ids=[env.id for env in all_testing_env_specs], | ||
) | ||
def test_env_determinism_rollout(env_spec: EnvSpec): | ||
"""Run a rollout with two environments and assert equality. | ||
This test runs a rollout of NUM_STEPS steps with two environments | ||
initialized with the same seed and assert that: | ||
- observations after the first reset are the same | ||
- the same action (sampled from the first env's action space) is applied to both envs | ||
- observations are contained in the observation space | ||
- obs, rew, done and info are equal between the two envs | ||
""" | ||
# Don't check rollout equality if it's a nondeterministic environment. | ||
if env_spec.nondeterministic is True: | ||
return | ||
|
||
env_1 = env_spec.make(disable_env_checker=True) | ||
env_2 = env_spec.make(disable_env_checker=True) | ||
env_1 = mo_gym.LinearReward(env_1) | ||
env_2 = mo_gym.LinearReward(env_2) | ||
|
||
def test_reacher_pybullet():
    """Run the env checker and the pickle check on mo-reacher-v0 wrapped in LinearReward."""
    wrapped_env = LinearReward(mo_gym.make("mo-reacher-v0"))
    check_env(wrapped_env, skip_render_check=True)
    _test_pickle_env(wrapped_env)
initial_obs_1, initial_info_1 = env_1.reset(seed=SEED) | ||
initial_obs_2, initial_info_2 = env_2.reset(seed=SEED) | ||
assert_equals(initial_obs_1, initial_obs_2) | ||
|
||
env_1.action_space.seed(SEED) | ||
|
||
def test_reacher_mujoco():
    """Run the env checker and the pickle check on mo-reacher-v4 wrapped in LinearReward."""
    wrapped_env = LinearReward(mo_gym.make("mo-reacher-v4"))
    check_env(wrapped_env, skip_render_check=True)
    _test_pickle_env(wrapped_env)
for time_step in range(NUM_STEPS): | ||
# We don't evaluate the determinism of actions | ||
action = env_1.action_space.sample() | ||
|
||
obs_1, rew_1, terminated_1, truncated_1, info_1 = env_1.step(action) | ||
obs_2, rew_2, terminated_2, truncated_2, info_2 = env_2.step(action) | ||
|
||
# TODO: failing because highway_env is not deterministic given a random seed | ||
""" def test_highway(): | ||
env = mo_gym.make('mo-highway-v0') | ||
env = LinearReward(env) | ||
check_env(env) | ||
assert_equals(obs_1, obs_2, f"[{time_step}] ") | ||
assert env_1.observation_space.contains(obs_1) # obs_2 verified by previous assertion | ||
|
||
def test_highway_fast(): | ||
env = mo_gym.make('mo-highway-fast-v0') | ||
env = LinearReward(env) | ||
check_env(env) """ | ||
|
||
|
||
def test_halfcheetah():
    """Run the env checker and the pickle check on mo-halfcheetah-v4 wrapped in LinearReward."""
    wrapped_env = LinearReward(mo_gym.make("mo-halfcheetah-v4"))
    check_env(wrapped_env, skip_render_check=True)
    _test_pickle_env(wrapped_env)
assert rew_1 == rew_2, f"[{time_step}] reward 1={rew_1}, reward 2={rew_2}" | ||
assert terminated_1 == terminated_2, f"[{time_step}] done 1={terminated_1}, done 2={terminated_2}" | ||
assert truncated_1 == truncated_2, f"[{time_step}] done 1={truncated_1}, done 2={truncated_2}" | ||
assert_equals(info_1, info_2, f"[{time_step}] ") | ||
|
||
if terminated_1 or truncated_1: # terminated_2, truncated_2 verified by previous assertion | ||
env_1.reset(seed=SEED) | ||
env_2.reset(seed=SEED) | ||
|
||
def test_hopper():
    """Run the env checker and the pickle check on mo-hopper-v4 wrapped in LinearReward."""
    wrapped_env = LinearReward(mo_gym.make("mo-hopper-v4"))
    check_env(wrapped_env, skip_render_check=True)
    _test_pickle_env(wrapped_env)
env_1.close() | ||
env_2.close() | ||
|
||
|
||
def test_breakable_bottles():
    """Run the env checker (render check included) and the pickle check on breakable-bottles-v0."""
    wrapped_env = LinearReward(mo_gym.make("breakable-bottles-v0"))
    # Unlike the sibling tests, skip_render_check is not passed here.
    check_env(wrapped_env)
    _test_pickle_env(wrapped_env)
def _test_pickle_env(env: gym.Env): | ||
pickled_env = pickle.loads(pickle.dumps(env)) | ||
|
||
data_equivalence(env.reset(), pickled_env.reset()) | ||
|
||
def test_water_reservoir():
    """Run the env checker and the pickle check on water-reservoir-v0 wrapped in LinearReward."""
    wrapped_env = LinearReward(mo_gym.make("water-reservoir-v0"))
    check_env(wrapped_env, skip_render_check=True)
    _test_pickle_env(wrapped_env)
action = env.action_space.sample() | ||
data_equivalence(env.step(action), pickled_env.step(action)) | ||
env.close() | ||
pickled_env.close() | ||
|
||
|
||
def test_lunar_lander():
    """Run the env checker and the pickle check on mo-lunar-lander-v2 wrapped in LinearReward."""
    wrapped_env = LinearReward(mo_gym.make("mo-lunar-lander-v2"))
    check_env(wrapped_env, skip_render_check=True)
    _test_pickle_env(wrapped_env)
def assert_equals(a, b, prefix=None):
    """Assert equality of data structures `a` and `b`.

    Dicts and tuples are compared recursively, NumPy arrays element-wise with
    ``np.testing.assert_array_equal`` and everything else with ``==``.

    Args:
        a: first data structure
        b: second data structure
        prefix: prefix for every failed assertion message (``None`` means no prefix)

    Raises:
        AssertionError: if the structures differ in type, keys, length or value.
    """
    # Without this, the default prefix would show up as the literal text "None" in messages.
    prefix = "" if prefix is None else prefix

    assert type(a) is type(b), f"{prefix}Differing types: {a} and {b}"
    if isinstance(a, dict):
        assert list(a.keys()) == list(b.keys()), f"{prefix}Key sets differ: {a} and {b}"

        for key, value_a in a.items():
            # Pass the prefix down so nested failures keep their context.
            assert_equals(value_a, b[key], prefix)
    elif isinstance(a, np.ndarray):
        np.testing.assert_array_equal(a, b, err_msg=prefix)
    elif isinstance(a, tuple):
        # zip() stops at the shorter tuple, so a length mismatch must be caught explicitly.
        assert len(a) == len(b), f"{prefix}Tuple lengths differ: {a} and {b}"
        for elem_from_a, elem_from_b in zip(a, b):
            assert_equals(elem_from_a, elem_from_b, prefix)
    else:
        assert a == b, f"{prefix}Values differ: {a} and {b}"