Skip to content

Commit

Permalink
Automatic generation of tests (#37)
Browse files Browse the repository at this point in the history
* Make tests parameterizable

* Add nondeterministic flag to mo-highway

* Fix minecart observation bounds

* Fix water-reservoir observation dtype
  • Loading branch information
LucasAlegre authored Feb 13, 2023
1 parent 6dc669c commit 3bb3ce9
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 117 deletions.
4 changes: 2 additions & 2 deletions mo_gymnasium/envs/highway/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from gymnasium.envs.registration import register


register(id="mo-highway-v0", entry_point="mo_gymnasium.envs.highway.highway:MOHighwayEnv")
register(id="mo-highway-v0", entry_point="mo_gymnasium.envs.highway.highway:MOHighwayEnv", nondeterministic=True)

register(id="mo-highway-fast-v0", entry_point="mo_gymnasium.envs.highway.highway:MOHighwayEnvFast")
register(id="mo-highway-fast-v0", entry_point="mo_gymnasium.envs.highway.highway:MOHighwayEnvFast", nondeterministic=True)
2 changes: 1 addition & 1 deletion mo_gymnasium/envs/minecart/minecart.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def __init__(
dtype=np.uint8,
)
else:
self.observation_space = Box(np.zeros(7), np.ones(7), dtype=np.float32)
self.observation_space = Box(-np.ones(7), np.ones(7), dtype=np.float32)

self.action_space = Discrete(6)
self.reward_space = Box(low=-1, high=self.capacity, shape=(self.ore_cnt + 1,))
Expand Down
6 changes: 3 additions & 3 deletions mo_gymnasium/envs/water_reservoir/dam_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
if not self.penalize:
state = self.np_random.choice(DamEnv.s_init, size=1)
else:
state = self.np_random.randint(0, 160, size=1).astype(np.float32)
state = self.np_random.randint(0, 160, size=1)

self.state = np.array(state)
self.state = np.array(state, dtype=np.float32)
return self.state, {}

def step(self, action):
Expand All @@ -112,7 +112,7 @@ def step(self, action):
action = bounded_action
dam_inflow = self.np_random.normal(DamEnv.DAM_INFLOW_MEAN, DamEnv.DAM_INFLOW_STD, len(self.state))
# small chance dam_inflow < 0
n_state = np.clip(self.state + dam_inflow - action, 0, None)
n_state = np.clip(self.state + dam_inflow - action, 0, None).astype(np.float32)

# cost due to excess level wrt a flooding threshold (upstream)
r0 = -np.clip(n_state / DamEnv.S - DamEnv.H_FLO_U, 0, None) + penalty
Expand Down
214 changes: 103 additions & 111 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
@@ -1,142 +1,134 @@
import pickle

import gymnasium as gym
import numpy as np
import pytest
from gymnasium.envs.registration import EnvSpec
from gymnasium.utils.env_checker import check_env, data_equivalence

import mo_gymnasium as mo_gym
from mo_gymnasium import LinearReward


def _test_pickle_env(env: gym.Env):
pickled_env = pickle.loads(pickle.dumps(env))

data_equivalence(env.reset(), pickled_env.reset())

action = env.action_space.sample()
data_equivalence(env.step(action), pickled_env.step(action))
env.close()
pickled_env.close()


def test_deep_sea_treasure():
env = mo_gym.make("deep-sea-treasure-v0")
env = LinearReward(env)
check_env(env, skip_render_check=True)
_test_pickle_env(env)


def test_fishwood():
env = mo_gym.make("fishwood-v0")
env = LinearReward(env)
check_env(env, skip_render_check=True)
_test_pickle_env(env)


def test_four_room():
env = mo_gym.make("four-room-v0")
env = LinearReward(env)
check_env(env, skip_render_check=True)
_test_pickle_env(env)


def test_minecart():
env = mo_gym.make("minecart-v0")
env = LinearReward(env)
check_env(env, skip_render_check=True)
_test_pickle_env(env)


def test_mountaincar():
env = mo_gym.make("mo-mountaincar-v0")
env = LinearReward(env)
check_env(env, skip_render_check=True)
_test_pickle_env(env)


def test_continuous_mountaincar():
env = mo_gym.make("mo-mountaincarcontinuous-v0")
env = LinearReward(env)
all_testing_env_specs = []
for env_spec in gym.envs.registry.values():
if type(env_spec.entry_point) is not str:
continue
# collect MO Gymnasium envs
if env_spec.entry_point.split(".")[0] == "mo_gymnasium":
# Ignore highway as they do not deal with the random seed appropriately
if not env_spec.id.startswith("mo-highway"):
all_testing_env_specs.append(env_spec)


@pytest.mark.parametrize(
"spec",
all_testing_env_specs,
ids=[spec.id for spec in all_testing_env_specs],
)
def test_all_env_api(spec):
"""Check that all environments pass the environment checker."""
env = mo_gym.make(spec.id)
env = mo_gym.LinearReward(env)
check_env(env, skip_render_check=True)
_test_pickle_env(env)


def test_resource_gathering():
env = mo_gym.make("resource-gathering-v0")
env = LinearReward(env)
check_env(env, skip_render_check=True)
_test_pickle_env(env)
@pytest.mark.parametrize("spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs])
def test_all_env_passive_env_checker(spec):
env = mo_gym.make(spec.id)
env.reset()
env.step(env.action_space.sample())
env.close()


def test_fruit_tree():
env = mo_gym.make("fruit-tree-v0")
env = LinearReward(env)
check_env(env, skip_render_check=True)
_test_pickle_env(env)
# Note that this precludes running this test in multiple threads.
# However, we probably already can't do multithreading due to some environments.
SEED = 0
NUM_STEPS = 50


def test_mario():
env = mo_gym.make("mo-supermario-v0")
env = LinearReward(env)
check_env(env, skip_render_check=True)
_test_pickle_env(env)
@pytest.mark.parametrize(
"env_spec",
all_testing_env_specs,
ids=[env.id for env in all_testing_env_specs],
)
def test_env_determinism_rollout(env_spec: EnvSpec):
"""Run a rollout with two environments and assert equality.
This test run a rollout of NUM_STEPS steps with two environments
initialized with the same seed and assert that:
- observation after first reset are the same
- same actions are sampled by the two envs
- observations are contained in the observation space
- obs, rew, done and info are equals between the two envs
"""
# Don't check rollout equality if it's a nondeterministic environment.
if env_spec.nondeterministic is True:
return

env_1 = env_spec.make(disable_env_checker=True)
env_2 = env_spec.make(disable_env_checker=True)
env_1 = mo_gym.LinearReward(env_1)
env_2 = mo_gym.LinearReward(env_2)

def test_reacher_pybullet():
env = mo_gym.make("mo-reacher-v0")
env = LinearReward(env)
check_env(env, skip_render_check=True)
_test_pickle_env(env)
initial_obs_1, initial_info_1 = env_1.reset(seed=SEED)
initial_obs_2, initial_info_2 = env_2.reset(seed=SEED)
assert_equals(initial_obs_1, initial_obs_2)

env_1.action_space.seed(SEED)

def test_reacher_mujoco():
env = mo_gym.make("mo-reacher-v4")
env = LinearReward(env)
check_env(env, skip_render_check=True)
_test_pickle_env(env)
for time_step in range(NUM_STEPS):
# We don't evaluate the determinism of actions
action = env_1.action_space.sample()

obs_1, rew_1, terminated_1, truncated_1, info_1 = env_1.step(action)
obs_2, rew_2, terminated_2, truncated_2, info_2 = env_2.step(action)

# TODO: failing because highway_env is not deterministic given a random seed
""" def test_highway():
env = mo_gym.make('mo-highway-v0')
env = LinearReward(env)
check_env(env)
assert_equals(obs_1, obs_2, f"[{time_step}] ")
assert env_1.observation_space.contains(obs_1) # obs_2 verified by previous assertion

def test_highway_fast():
env = mo_gym.make('mo-highway-fast-v0')
env = LinearReward(env)
check_env(env) """


def test_halfcheetah():
env = mo_gym.make("mo-halfcheetah-v4")
env = LinearReward(env)
check_env(env, skip_render_check=True)
_test_pickle_env(env)
assert rew_1 == rew_2, f"[{time_step}] reward 1={rew_1}, reward 2={rew_2}"
assert terminated_1 == terminated_2, f"[{time_step}] done 1={terminated_1}, done 2={terminated_2}"
assert truncated_1 == truncated_2, f"[{time_step}] done 1={truncated_1}, done 2={truncated_2}"
assert_equals(info_1, info_2, f"[{time_step}] ")

if terminated_1 or truncated_1: # terminated_2, truncated_2 verified by previous assertion
env_1.reset(seed=SEED)
env_2.reset(seed=SEED)

def test_hopper():
env = mo_gym.make("mo-hopper-v4")
env = LinearReward(env)
check_env(env, skip_render_check=True)
_test_pickle_env(env)
env_1.close()
env_2.close()


def test_breakable_bottles():
env = mo_gym.make("breakable-bottles-v0")
env = LinearReward(env)
check_env(env)
_test_pickle_env(env)
def _test_pickle_env(env: gym.Env):
pickled_env = pickle.loads(pickle.dumps(env))

data_equivalence(env.reset(), pickled_env.reset())

def test_water_reservoir():
env = mo_gym.make("water-reservoir-v0")
env = LinearReward(env)
check_env(env, skip_render_check=True)
_test_pickle_env(env)
action = env.action_space.sample()
data_equivalence(env.step(action), pickled_env.step(action))
env.close()
pickled_env.close()


def test_lunar_lander():
env = mo_gym.make("mo-lunar-lander-v2")
env = LinearReward(env)
check_env(env, skip_render_check=True)
_test_pickle_env(env)
def assert_equals(a, b, prefix=None):
"""Assert equality of data structures `a` and `b`.
Args:
a: first data structure
b: second data structure
prefix: prefix for failed assertion message for types and dicts
"""
assert type(a) == type(b), f"{prefix}Differing types: {a} and {b}"
if isinstance(a, dict):
assert list(a.keys()) == list(b.keys()), f"{prefix}Key sets differ: {a} and {b}"

for k in a.keys():
v_a = a[k]
v_b = b[k]
assert_equals(v_a, v_b)
elif isinstance(a, np.ndarray):
np.testing.assert_array_equal(a, b)
elif isinstance(a, tuple):
for elem_from_a, elem_from_b in zip(a, b):
assert_equals(elem_from_a, elem_from_b)
else:
assert a == b

0 comments on commit 3bb3ce9

Please sign in to comment.