Skip to content

Commit

Permalink
[Typing] Add types.py (#1085)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk committed Nov 6, 2023
1 parent 0dfe71c commit 3c1f9e8
Show file tree
Hide file tree
Showing 26 changed files with 596 additions and 615 deletions.
18 changes: 10 additions & 8 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@


install-dev:
python3 -m pip install \
pytest==7.1.2 \
python3 -m pip install -U pip
python3 -m pip install -U \
pytest \
matplotlib \
ipython \
jax[cpu] \
Expand All @@ -12,12 +13,13 @@ install-dev:
pgx-minatar

install-fmt:
python3 -m pip install \
black==22.6.0 \
blackdoc==0.3.6 \
isort==5.10.1 \
flake8==5.0.4 \
mypy==0.971
python3 -m pip install -U pip
python3 -m pip install -U \
black \
blackdoc \
isort \
flake8 \
mypy

clean:
rm -rf build
Expand Down
4 changes: 4 additions & 0 deletions pgx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pgx._src.api_test import api_test
from pgx._src.baseline import BaselineModelId, make_baseline_model
from pgx._src.types import Array, PRNGKey
from pgx._src.visualizer import (
save_svg,
save_svg_animation,
Expand All @@ -10,6 +11,9 @@
__version__ = "2.0.0"

__all__ = [
# types
"Array",
"PRNGKey",
# v1 api components
"State",
"Env",
Expand Down
5 changes: 5 additions & 0 deletions pgx/_src/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from typing import Any

# typing only for documentation
Array = Any
PRNGKey = Any
2 changes: 1 addition & 1 deletion pgx/_src/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ def save_svg_animation(
[
e
for e in dwg.elements
if type(e) == svgwrite.container.Group
if type(e) is svgwrite.container.Group
]
)
== 1
Expand Down
55 changes: 27 additions & 28 deletions pgx/animal_shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import pgx.core as core
from pgx._src.struct import dataclass
from pgx._src.types import Array, PRNGKey

TRUE = jnp.bool_(True)
FALSE = jnp.bool_(False)
Expand Down Expand Up @@ -69,28 +70,28 @@

@dataclass
class State(core.State):
current_player: jnp.ndarray = jnp.int32(0)
rewards: jnp.ndarray = jnp.float32([0.0, 0.0])
terminated: jnp.ndarray = FALSE
truncated: jnp.ndarray = FALSE
legal_action_mask: jnp.ndarray = jnp.ones(132, dtype=jnp.bool_) # (132,)
observation: jnp.ndarray = jnp.zeros((4, 3, 22), dtype=jnp.bool_)
_step_count: jnp.ndarray = jnp.int32(0)
current_player: Array = jnp.int32(0)
rewards: Array = jnp.float32([0.0, 0.0])
terminated: Array = FALSE
truncated: Array = FALSE
legal_action_mask: Array = jnp.ones(132, dtype=jnp.bool_) # (132,)
observation: Array = jnp.zeros((4, 3, 22), dtype=jnp.bool_)
_step_count: Array = jnp.int32(0)
# --- Animal Shogi specific ---
_turn: jnp.ndarray = jnp.int32(0)
_board: jnp.ndarray = INIT_BOARD # (12,)
_hand: jnp.ndarray = jnp.zeros((2, 3), dtype=jnp.int32)
_zobrist_hash: jnp.ndarray = jnp.uint32([233882788, 593924309])
_hash_history: jnp.ndarray = (
_turn: Array = jnp.int32(0)
_board: Array = INIT_BOARD # (12,)
_hand: Array = jnp.zeros((2, 3), dtype=jnp.int32)
_zobrist_hash: Array = jnp.uint32([233882788, 593924309])
_hash_history: Array = (
jnp.zeros((MAX_TERMINATION_STEPS + 1, 2), dtype=jnp.uint32)
.at[0]
.set(jnp.uint32([233882788, 593924309]))
)
_board_history: jnp.ndarray = (
_board_history: Array = (
(-jnp.ones((8, 12), dtype=jnp.int32)).at[0, :].set(INIT_BOARD)
)
_hand_history: jnp.ndarray = jnp.zeros((8, 6), dtype=jnp.int32)
_rep_history: jnp.ndarray = jnp.zeros((8,), dtype=jnp.int32)
_hand_history: Array = jnp.zeros((8, 6), dtype=jnp.int32)
_rep_history: Array = jnp.zeros((8,), dtype=jnp.int32)

@property
def env_id(self) -> core.EnvId:
Expand All @@ -99,13 +100,13 @@ def env_id(self) -> core.EnvId:

@dataclass
class Action:
is_drop: jnp.ndarray = FALSE
from_: jnp.ndarray = jnp.int32(-1)
to: jnp.ndarray = jnp.int32(-1)
drop_piece: jnp.ndarray = jnp.int32(-1)
is_drop: Array = FALSE
from_: Array = jnp.int32(-1)
to: Array = jnp.int32(-1)
drop_piece: Array = jnp.int32(-1)

@staticmethod
def _from_label(a: jnp.ndarray):
def _from_label(a: Array):
# Implements AlphaZero like action label:
# 132 labels =
# [Move] 8 (direction) * 12 (from_) +
Expand All @@ -123,13 +124,13 @@ class AnimalShogi(core.Env):
def __init__(self):
super().__init__()

def _init(self, key: jax.random.KeyArray) -> State:
def _init(self, key: PRNGKey) -> State:
current_player = jnp.int32(jax.random.bernoulli(key))
state = State(current_player=current_player) # type: ignore
state = state.replace(legal_action_mask=_legal_action_mask(state)) # type: ignore
return state

def _step(self, state: core.State, action: jnp.ndarray, key) -> State:
def _step(self, state: core.State, action: Array, key) -> State:
del key
assert isinstance(state, State)
state = _step(state, action)
Expand All @@ -141,9 +142,7 @@ def _step(self, state: core.State, action: jnp.ndarray, key) -> State:
)
return state # type: ignore

def _observe(
self, state: core.State, player_id: jnp.ndarray
) -> jnp.ndarray:
def _observe(self, state: core.State, player_id: Array) -> Array:
assert isinstance(state, State)
return _observe(state, player_id)

Expand All @@ -160,7 +159,7 @@ def num_players(self) -> int:
return 2


def _step(state: State, action: jnp.ndarray):
def _step(state: State, action: Array):
a = Action._from_label(action)
# apply move/drop action
state = jax.lax.cond(a.is_drop, _step_drop, _step_move, *(state, a))
Expand Down Expand Up @@ -205,7 +204,7 @@ def _step(state: State, action: jnp.ndarray):
)


def _observe(state: State, player_id: jnp.ndarray) -> jnp.ndarray:
def _observe(state: State, player_id: Array) -> Array:
# player_id's color
color = jax.lax.select(
state.current_player == player_id, state._turn, 1 - state._turn
Expand Down Expand Up @@ -308,7 +307,7 @@ def _step_drop(state: State, action: Action) -> State:


def _legal_action_mask(state: State):
def is_legal(label: jnp.ndarray):
def is_legal(label: Array):
action = Action._from_label(label)
return jax.lax.cond(
action.is_drop, is_legal_drop, is_legal_move, action
Expand Down
Loading

0 comments on commit 3c1f9e8

Please sign in to comment.