From 6c582b68f49a9bebda7c4af46f36f50887dca3e8 Mon Sep 17 00:00:00 2001 From: Sotetsu KOYAMADA Date: Wed, 13 Mar 2024 16:56:47 +0900 Subject: [PATCH] [Hotfix] Restore CI (#1175) --- tests/test_baseline.py | 32 -------------------------------- tests/test_bridge_bidding.py | 2 +- 2 files changed, 1 insertion(+), 33 deletions(-) delete mode 100644 tests/test_baseline.py diff --git a/tests/test_baseline.py b/tests/test_baseline.py deleted file mode 100644 index 12bb6f71f..000000000 --- a/tests/test_baseline.py +++ /dev/null @@ -1,32 +0,0 @@ -import jax -import pgx -import pgx - -import haiku as hk - - -# def test_az_basline(): -# batch_size = 2 -# test_cases = ( -# ("animal_shogi", "animal_shogi_v0"), -# ("gardner_chess", "gardner_chess_v0"), -# ("go_9x9", "go_9x9_v0"), -# ("hex", "hex_v0"), -# ("othello", "othello_v0"), -# ("minatar-asterix", "minatar-asterix_v0"), -# ("minatar-breakout", "minatar-breakout_v0"), -# ("minatar-freeway", "minatar-freeway_v0"), -# ("minatar-seaquest", "minatar-seaquest_v0"), -# ("minatar-space_invaders", "minatar-space_invaders_v0") -# ) -# -# for env_id, model_id in test_cases: -# env = pgx.make(env_id) -# model = pgx.make_baseline_model(model_id) -# state = jax.jit(jax.vmap(env.init))( -# jax.random.split(jax.random.PRNGKey(0), batch_size) -# ) -# -# logits, value = model(state.observation) -# assert logits.shape == (batch_size, env.num_actions) -# assert value.shape == (batch_size,) diff --git a/tests/test_bridge_bidding.py b/tests/test_bridge_bidding.py index f5e5474d4..28a4d7a18 100644 --- a/tests/test_bridge_bidding.py +++ b/tests/test_bridge_bidding.py @@ -25,7 +25,7 @@ ) -def init(rng: jax.random.KeyArray) -> State: +def init(rng: jax.Array) -> State: rng1, rng2, rng3, rng4, rng5, rng6 = jax.random.split(rng, num=6) hand = jnp.arange(0, 52) hand = jax.random.permutation(rng2, hand)