Skip to content

Commit

Permalink
[MinAtar] Restore MinAtar tests (#1203)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk committed Jul 24, 2024
1 parent c6b9255 commit 95b0fb4
Show file tree
Hide file tree
Showing 7 changed files with 722 additions and 0 deletions.
1 change: 1 addition & 0 deletions requirements/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ blackdoc
isort
flake8
mypy
git+https://github.com/sotetsuk/minatar
101 changes: 101 additions & 0 deletions tests/minatar_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import copy
from typing import Any, Dict

import numpy as np
from jax import numpy as jnp


INF = 99

def extract_state(env, state_keys):
    """Deep-copy selected attributes of ``env.env`` into a dict.

    Args:
        env: Object whose ``env`` attribute carries the attributes to read.
        state_keys: Iterable of attribute names (task-dependent).

    Returns:
        Dict mapping each name to a deep copy of ``getattr(env.env, name)``,
        so the snapshot is unaffected by later mutation of the source.
    """
    return {
        name: copy.deepcopy(getattr(env.env, name))
        for name in state_keys
    }


def assert_states(state1, state2):
    """Assert that two state dicts have the same keys and matching values.

    The ``"entities"`` entry is checked for equal length and then compared
    item by item with ``==``; every other entry is compared with
    ``np.allclose``.

    Raises:
        AssertionError: If the key sets differ or any entry does not match.
    """
    names = state1.keys()
    assert names == state2.keys()
    for name in names:
        lhs, rhs = state1[name], state2[name]
        if name != "entities":
            assert np.allclose(
                lhs, rhs
            ), f"{name}, {lhs}, {rhs}\n{state1}\n{state2}"
            continue

        assert len(lhs) == len(rhs)
        for left_item, right_item in zip(lhs, rhs):
            assert (
                left_item == right_item
            ), f"{left_item}, {right_item}\n{state1}\n{state2}"


def pgx2minatar(state, keys) -> Dict[str, Any]:
    """Read ``"_"``-prefixed attributes of ``state`` into a plain dict.

    Args:
        state: Object exposing one ``_<key>`` attribute per requested key.
        keys: Iterable of key names, without the leading underscore.

    Returns:
        Dict keyed by the unprefixed names. Values that are ``jnp.ndarray``
        are converted to NumPy arrays. ``"entities"`` becomes a list of 8
        items: ``None`` where the row's first element equals ``INF``,
        otherwise a list of the row's first 4 elements.
    """
    converted = {}
    for name in keys:
        value = copy.deepcopy(getattr(state, "_" + name))
        if isinstance(value, jnp.ndarray):
            value = np.array(value)
        if name == "entities":
            # A row whose first element is INF stands for "no entity".
            value = [
                None
                if value[slot][0] == INF
                else [value[slot][col] for col in range(4)]
                for slot in range(8)
            ]
        converted[name] = value
    return converted


def minatar2pgx(state_dict: Dict[str, Any], state_cls):
    """Convert a dict of state values to JAX arrays and build ``state_cls``.

    Args:
        state_dict: Mapping from unprefixed state names to their values.
        state_cls: Callable invoked as ``state_cls(**fields)``, where every
            field name is the original key prefixed with ``"_"``.

    Returns:
        The object returned by ``state_cls``.
    """
    # (rows, columns) of the fixed-size tables used for the Seaquest lists.
    padded_shapes = {
        "f_bullets": (5, 3),
        "e_bullets": (25, 3),
        "e_fish": (25, 4),
        "e_subs": (25, 5),
        "divers": (5, 4),
    }
    int32_keys = ("terminate_timer", "oxygen")
    bool_array_keys = (
        "brick_map",
        "alien_map",
        "f_bullet_map",
        "e_bullet_map",
        "allien_map",
    )
    bool_scalar_keys = ("terminal", "sub_or", "surface")

    fields = {}
    for name, raw in state_dict.items():
        value = copy.deepcopy(raw)

        if name == "entities":
            # Exception in Asterix: a missing entity (None) becomes a row
            # filled with INF.
            rows = [
                [INF if entity is None else entity[col] for col in range(4)]
                for entity in value
            ]
            converted = jnp.array(rows, dtype=jnp.int32)
        elif name in padded_shapes:
            # Exception in Seaquest: copy the items into the leading rows of
            # a table pre-filled with -1.
            converted = -jnp.ones(padded_shapes[name], dtype=jnp.int32)
            for row, item in enumerate(value):
                converted = converted.at[row, :].set(jnp.array(item))
        else:
            if name in int32_keys:
                dtype = jnp.int32
            elif isinstance(value, np.ndarray):
                dtype = jnp.bool_ if name in bool_array_keys else jnp.int32
            else:
                dtype = jnp.bool_ if name in bool_scalar_keys else jnp.int32
            converted = jnp.array(value, dtype=dtype)

        fields["_" + name] = converted

    return state_cls(**fields)
111 changes: 111 additions & 0 deletions tests/test_asterix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import jax
import random

from minatar import Environment

from pgx_minatar import asterix

from tests.minatar_utils import *

# Names of the state attributes checked by the tests in this module. They are
# passed to extract_state (read from `env.env`), to pgx2minatar (read from the
# pgx state with a leading "_"), and end up as the keys given to minatar2pgx.
state_keys = {
"player_x",
"player_y",
"entities",
"shot_timer",
"spawn_speed",
"spawn_timer",
"move_speed",
"move_timer",
"ramp_timer",
"ramp_index",
"terminal",
"last_action",
}

# Fill value for the entities table built in test_spawn_entity. Same value as
# INF in tests.minatar_utils, which is also star-imported into this module.
INF = 99


# Jit-compiled wrappers around the asterix functions exercised by the tests.
_spawn_entity = jax.jit(asterix._spawn_entity)
_step_det = jax.jit(asterix._step_det)
_observe = jax.jit(asterix._observe)

def test_spawn_entity():
    """Check the row written by ``_spawn_entity`` into an INF-filled table."""
    blank = jnp.ones((8, 4), dtype=jnp.int32) * INF
    # NOTE(review): the trailing arguments are presumably (lr, is_gold, slot),
    # mirroring the _step_det call in test_step_det — confirm in asterix.
    entities = blank.at[:, :].set(_spawn_entity(blank, True, True, 1))

    expected_row = (0, 2, 1, 1)
    for col, expected in enumerate(expected_row):
        assert entities[1][col] == expected, entities


def test_step_det():
    """Compare ``_step_det`` with MinAtar Asterix on random actions.

    Runs 100 episodes. After every MinAtar step, the previous state is
    converted to a pgx state and stepped with the same action and the
    ``lr`` / ``is_gold`` / ``slot`` values read from ``env.env``; the
    resulting state, reward and termination flag must match MinAtar's.
    """
    env = Environment("asterix", sticky_action_prob=0.0)
    num_actions = env.num_actions()

    num_episodes = 100
    for _ in range(num_episodes):
        env.reset()
        done = False
        while not done:
            before = extract_state(env, state_keys)
            action = random.randrange(num_actions)
            reward, done = env.act(action)

            game = env.env
            after = extract_state(env, state_keys)
            after_pgx = _step_det(
                minatar2pgx(before, asterix.State),
                action,
                game.lr,
                game.is_gold,
                game.slot,
            )

            assert_states(after, pgx2minatar(after_pgx, state_keys))
            assert reward == after_pgx.rewards[0]
            assert done == after_pgx.terminated


def test_observe():
    """Compare ``_observe`` with ``env.state()`` along random Asterix episodes.

    Runs 10 episodes. The observation is checked before every action and
    once more on the terminal state.
    """
    env = Environment("asterix", sticky_action_prob=0.0)
    num_actions = env.num_actions()

    num_episodes = 10
    for _ in range(num_episodes):
        env.reset()
        done = False
        while True:
            snapshot = extract_state(env, state_keys)
            obs_pgx = _observe(minatar2pgx(snapshot, asterix.State))
            assert jnp.allclose(
                env.state(),
                obs_pgx,
            )
            if done:
                # The check above just covered the terminal state.
                break
            action = random.randrange(num_actions)
            _, done = env.act(action)


def test_minimal_action_set():
    """Check that ``minatar-asterix`` reports 5 actions and a (5,) legal mask."""
    import pgx

    env = pgx.make("minatar-asterix")
    init = jax.jit(env.init)
    step = jax.jit(env.step)
    expected_actions = 5

    assert env.num_actions == expected_actions
    state = init(jax.random.PRNGKey(0))
    assert state.legal_action_mask.shape == (expected_actions,)
    state = step(state, 0, jax.random.PRNGKey(0))
    assert state.legal_action_mask.shape == (expected_actions,)


def test_api():
    """Run ``pgx.api_test`` against the ``minatar-asterix`` environment."""
    import pgx

    pgx.api_test(pgx.make("minatar-asterix"), 10, True)
100 changes: 100 additions & 0 deletions tests/test_breakout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import random
import jax

from minatar import Environment

from pgx_minatar import breakout

from tests.minatar_utils import *

# Names of the state attributes checked by the tests in this module. They are
# passed to extract_state (read from `env.env`), to pgx2minatar (read from the
# pgx state with a leading "_"), and end up as the keys given to minatar2pgx.
state_keys = {
"ball_y",
"ball_x",
"ball_dir",
"pos",
"brick_map",
"strike",
"last_x",
"last_y",
"terminal",
"last_action",
}
# Jit-compiled wrappers around the breakout functions exercised by the tests.
_step_det = jax.jit(breakout._step_det)
_init_det = jax.jit(breakout._init_det)
observe = jax.jit(breakout._observe)

def test_step_det():
    """Compare ``_step_det`` with MinAtar Breakout on random actions.

    Runs 1000 episodes. After every MinAtar step, the previous state is
    converted to a pgx state and stepped with the same action; the resulting
    state, reward and termination flag must match MinAtar's.
    """
    env = Environment("breakout", sticky_action_prob=0.0)
    num_actions = env.num_actions()

    num_episodes = 1000
    for _ in range(num_episodes):
        env.reset()
        done = False
        while not done:
            before = extract_state(env, state_keys)
            action = random.randrange(num_actions)
            reward, done = env.act(action)
            after = extract_state(env, state_keys)

            after_pgx = _step_det(minatar2pgx(before, breakout.State), action)

            assert_states(after, pgx2minatar(after_pgx, state_keys))
            assert reward == after_pgx.rewards[0]
            assert done == after_pgx.terminated


def test_init_det():
    """Compare ``_init_det`` with a freshly reset MinAtar Breakout state.

    The argument to ``_init_det`` is 0 when MinAtar's ``ball_x`` is 0 and
    1 otherwise.
    """
    env = Environment("breakout", sticky_action_prob=0.0)

    num_resets = 1
    for _ in range(num_resets):
        env.reset()
        ball_start = int(env.env.ball_x != 0)
        expected = extract_state(env, state_keys)
        actual = pgx2minatar(_init_det(ball_start), state_keys)
        assert_states(expected, actual)


def test_observe():
    """Compare ``observe`` with ``env.state()`` along random Breakout episodes.

    Runs 100 episodes. The observation is checked before every action and
    once more on the terminal state.
    """
    env = Environment("breakout", sticky_action_prob=0.0)
    num_actions = env.num_actions()

    num_episodes = 100
    for _ in range(num_episodes):
        env.reset()
        done = False
        while True:
            snapshot = extract_state(env, state_keys)
            obs_pgx = observe(minatar2pgx(snapshot, breakout.State))
            assert jnp.allclose(
                env.state(),
                obs_pgx,
            )
            if done:
                # The check above just covered the terminal state.
                break
            action = random.randrange(num_actions)
            _, done = env.act(action)


def test_minimal_action_set():
    """Check that ``minatar-breakout`` reports 3 actions and a (3,) legal mask."""
    import pgx

    env = pgx.make("minatar-breakout")
    init = jax.jit(env.init)
    step = jax.jit(env.step)
    expected_actions = 3

    assert env.num_actions == expected_actions
    state = init(jax.random.PRNGKey(0))
    assert state.legal_action_mask.shape == (expected_actions,)
    state = step(state, 0, jax.random.PRNGKey(0))
    assert state.legal_action_mask.shape == (expected_actions,)


def test_api():
    """Run ``pgx.api_test`` against the ``minatar-breakout`` environment."""
    import pgx

    pgx.api_test(pgx.make("minatar-breakout"), 10, True)
Loading

0 comments on commit 95b0fb4

Please sign in to comment.