Skip to content

Commit

Permalink
[Kuhn Poker] Simplify action space (#1171)
Browse files Browse the repository at this point in the history
  • Loading branch information
Egiob committed Mar 13, 2024
1 parent 6c582b6 commit 309bf22
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 56 deletions.
21 changes: 10 additions & 11 deletions docs/kuhn_poker.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,21 @@ Kuhn poker is a simplified poker with three cards: J, Q, and K.
## Rules

Each player is dealt one card and the remaining card is unused.
There are four actions: *check*, *call*, *bet*, and *fold* and five possible scenarios.
There are two actions, *bet* and *pass*, and five possible scenarios.

1. `bet (1st) - call (2nd)` : *Showdown* and the winner takes `+2`
2. `bet (1st) - fold (2nd)` : 1st player takes `+1`
3. `check (1st) - check (2nd)` : *Showdown* and the winner takes `+1`
4. `check (1st) - bet (2nd) - call (1st)` : *Showdown* and the winner takes `+2`
5. `check (1st) - bet (2nd) - fold (1st)` : 2nd takes `+1`
1. `bet (1st) - bet (2nd)` : *Showdown* and the winner takes `+2`
2. `bet (1st) - pass (2nd)` : 1st player takes `+1`
3. `pass (1st) - pass (2nd)` : *Showdown* and the winner takes `+1`
4. `pass (1st) - bet (2nd) - bet (1st)` : *Showdown* and the winner takes `+2`
5. `pass (1st) - bet (2nd) - pass (1st)` : 2nd takes `+1`

## Specs

| Name | Value |
|:---|:----:|
| Version | `v0` |
| Number of players | `2` |
| Number of actions | `4` |
| Number of actions | `2` |
| Observation shape | `(7,)` |
| Observation type | `bool` |
| Rewards | `{-2, -1, +1, +2}` |
Expand All @@ -55,10 +55,9 @@ There are four distinct actions.

| Action | Index |
|:---|----:|
| Call | 0|
| Bet | 1|
| Fold | 2|
| Check | 3|
| Bet | 0|
| Pass | 1|


## Rewards
The winner takes `+2` or `+1` depending on the game payoff.
Expand Down
30 changes: 10 additions & 20 deletions pgx/kuhn_poker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,8 @@

FALSE = jnp.bool_(False)
TRUE = jnp.bool_(True)
CALL = jnp.int32(0)
BET = jnp.int32(1)
FOLD = jnp.int32(2)
CHECK = jnp.int32(3)
BET = jnp.int32(0)
PASS = jnp.int32(1)


@dataclass
Expand All @@ -34,13 +32,13 @@ class State(core.State):
rewards: Array = jnp.float32([0.0, 0.0])
terminated: Array = FALSE
truncated: Array = FALSE
legal_action_mask: Array = jnp.ones(4, dtype=jnp.bool_)
legal_action_mask: Array = jnp.ones(2, dtype=jnp.bool_)
_step_count: Array = jnp.int32(0)
# --- Kuhn poker specific ---
_cards: Array = jnp.int32([-1, -1])
# [(player 0),(player 1)]
_last_action: Array = jnp.int32(-1)
# 0(Call) 1(Bet) 2(Fold) 3(Check)
# 0(Bet) 1(Pass)
_pot: Array = jnp.int32([0, 0])

@property
Expand Down Expand Up @@ -84,46 +82,38 @@ def _init(rng: PRNGKey) -> State:
return State( # type:ignore
current_player=current_player,
_cards=init_card,
legal_action_mask=jnp.bool_([0, 1, 0, 1]),
legal_action_mask=jnp.bool_([1, 1]),
)


def _step(state: State, action):
action = jnp.int32(action)
pot = jax.lax.cond(
(action == BET) | (action == CALL),
(action == BET),
lambda: state._pot.at[state.current_player].add(1),
lambda: state._pot,
)

terminated, reward = jax.lax.cond(
action == FOLD,
(state._last_action == BET) & (action == PASS),
lambda: (
TRUE,
jnp.float32([-1, -1]).at[1 - state.current_player].set(1),
),
lambda: (FALSE, jnp.float32([0, 0])),
)
terminated, reward = jax.lax.cond(
(state._last_action == BET) & (action == CALL),
(state._last_action == BET) & (action == BET),
lambda: (TRUE, _get_unit_reward(state) * 2),
lambda: (terminated, reward),
)
terminated, reward = jax.lax.cond(
(state._last_action == CHECK) & (action == CHECK),
(state._last_action == PASS) & (action == PASS),
lambda: (TRUE, _get_unit_reward(state)),
lambda: (terminated, reward),
)

legal_action = jax.lax.switch(
action,
[
lambda: jnp.bool_([0, 0, 0, 0]), # CALL
lambda: jnp.bool_([1, 0, 1, 0]), # BET
lambda: jnp.bool_([0, 0, 0, 0]), # FOLD
lambda: jnp.bool_([0, 1, 0, 1]), # CHECK
],
)
legal_action = jax.lax.select(terminated, jnp.bool_([0, 0]), jnp.bool_([1, 1]))

return state.replace( # type:ignore
current_player=1 - state.current_player,
Expand Down
52 changes: 27 additions & 25 deletions tests/test_kuhn_poker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import jax
import jax.numpy as jnp
from pgx.kuhn_poker import KuhnPoker, CALL, BET, FOLD, CHECK

from pgx.kuhn_poker import BET, PASS, KuhnPoker

env = KuhnPoker()
init = jax.jit(env.init)
Expand All @@ -12,48 +13,48 @@ def test_init():
key = jax.random.PRNGKey(0)
state = init(key=key)
assert state._cards[0] != state._cards[1]
assert (state.legal_action_mask == jnp.bool_([0, 1, 0, 1])).all()
assert (state.legal_action_mask == jnp.bool_([1, 1])).all()


def test_step():
key = jax.random.PRNGKey(0)
# cards = [2, 0]
state = init(key)
state = step(state, CHECK)
state = step(state, PASS)
assert not state.terminated
state = step(state, CHECK)
state = step(state, PASS)
assert state.terminated
assert (state.rewards == jnp.float32([1, -1])).all()

state = init(key)
state = step(state, CHECK)
state = step(state, PASS)
assert not state.terminated
state = step(state, BET)
assert not state.terminated
state = step(state, FOLD)
state = step(state, PASS)
assert state.terminated
assert (state.rewards == jnp.float32([-1, 1])).all()

state = init(key)
state = step(state, CHECK)
state = step(state, PASS)
assert not state.terminated
state = step(state, BET)
assert not state.terminated
state = step(state, CALL)
state = step(state, BET)
assert state.terminated
assert (state.rewards == jnp.float32([2, -2])).all()

state = init(key)
state = step(state, BET)
assert not state.terminated
state = step(state, FOLD)
state = step(state, PASS)
assert state.terminated
assert (state.rewards == jnp.float32([1, -1])).all()

state = init(key)
state = step(state, BET)
assert not state.terminated
state = step(state, CALL)
state = step(state, BET)
assert state.terminated
assert (state.rewards == jnp.float32([2, -2])).all()

Expand All @@ -62,37 +63,37 @@ def test_legal_action():
key = jax.random.PRNGKey(0)
# cards = [2, 0]
state = init(key)
state = step(state, CHECK)
assert (state.legal_action_mask == jnp.bool_([0, 1, 0, 1])).all()
state = step(state, CHECK)
state = step(state, PASS)
assert (state.legal_action_mask == jnp.bool_([1, 1])).all()
state = step(state, PASS)
assert state.terminated

state = init(key)
state = step(state, CHECK)
assert (state.legal_action_mask == jnp.bool_([0, 1, 0, 1])).all()
state = step(state, PASS)
assert (state.legal_action_mask == jnp.bool_([1, 1])).all()
state = step(state, BET)
assert (state.legal_action_mask == jnp.bool_([1, 0, 1, 0])).all()
state = step(state, FOLD)
assert (state.legal_action_mask == jnp.bool_([1, 1])).all()
state = step(state, PASS)
assert state.terminated

state = init(key)
state = step(state, CHECK)
assert (state.legal_action_mask == jnp.bool_([0, 1, 0, 1])).all()
state = step(state, PASS)
assert (state.legal_action_mask == jnp.bool_([1, 1])).all()
state = step(state, BET)
assert (state.legal_action_mask == jnp.bool_([1, 1])).all()
state = step(state, BET)
assert (state.legal_action_mask == jnp.bool_([1, 0, 1, 0])).all()
state = step(state, CALL)
assert state.terminated

state = init(key)
state = step(state, BET)
assert (state.legal_action_mask == jnp.bool_([1, 0, 1, 0])).all()
state = step(state, FOLD)
assert (state.legal_action_mask == jnp.bool_([1, 1])).all()
state = step(state, PASS)
assert state.terminated

state = init(key)
state = step(state, BET)
assert (state.legal_action_mask == jnp.bool_([1, 0, 1, 0])).all()
state = step(state, CALL)
assert (state.legal_action_mask == jnp.bool_([1, 1])).all()
state = step(state, BET)
assert state.terminated


Expand All @@ -114,6 +115,7 @@ def test_observation():

def test_api():
import pgx

env = pgx.make("kuhn_poker")
pgx.api_test(env, 3, use_key=False)
pgx.api_test(env, 3, use_key=True)

0 comments on commit 309bf22

Please sign in to comment.