Skip to content

Commit

Permalink
[Chess] Accelerate chess_utils.py import (jnp => np) (#1170)
Browse files Browse the repository at this point in the history
  • Loading branch information
Akulen committed Mar 13, 2024
1 parent 2c36418 commit 262afea
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 61 deletions.
68 changes: 37 additions & 31 deletions pgx/_src/chess_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# type: ignore
import jax.numpy as jnp
import jax.random
import numpy as np

TO_MAP = -jnp.ones((64, 73), dtype=jnp.int32)
PLANE_MAP = -jnp.ones((64, 64), dtype=jnp.int32) # ignores underpromotion
TO_MAP = -np.ones((64, 73), dtype=np.int32)
PLANE_MAP = -np.ones((64, 64), dtype=np.int32) # ignores underpromotion
# underpromotiona
for from_ in range(64):
if (from_ % 8) not in (1, 6):
Expand All @@ -18,10 +19,10 @@
# black
# 2 6 14 22 30 38 46 54 62
# 1 7 15 23 31 39 47 55 63
to = from_ + jnp.int32([+1, +9, -7])[dir_]
to = from_ + [+1, +9, -7][dir_]
if not (0 <= to < 64):
continue
TO_MAP = TO_MAP.at[from_, plane].set(to)
TO_MAP[from_, plane] = to
# normal move
seq = list(range(1, 8))
zeros = [0 for _ in range(7)]
Expand Down Expand Up @@ -59,12 +60,12 @@
c = c + dc[plane - 9]
if r < 0 or r >= 8 or c < 0 or c >= 8:
continue
to = jnp.int32(c * 8 + r)
TO_MAP = TO_MAP.at[from_, plane].set(to)
PLANE_MAP = PLANE_MAP.at[from_, to].set(jnp.int32(plane))
to = c * 8 + r
TO_MAP[from_, plane] = to
PLANE_MAP[from_, to] = plane


CAN_MOVE = -jnp.ones((7, 64, 27), jnp.int32)
CAN_MOVE = -np.ones((7, 64, 27), np.int32)
# usage: CAN_MOVE[piece, from_x, from_y]
# CAN_MOVE[0, :, :] are all -1
# Note that the board is not symmetric about the center (different from shogi)
Expand All @@ -79,25 +80,25 @@
legal_dst = []
for to in range(64):
r1, c1 = to % 8, to // 8
if jnp.abs(r1 - r0) == 1 and jnp.abs(c1 - c0) <= 1:
if np.abs(r1 - r0) == 1 and np.abs(c1 - c0) <= 1:
legal_dst.append(to)
# init move
if (r0 == 1 or r0 == 6) and (jnp.abs(c1 - c0) == 0 and jnp.abs(r1 - r0) == 2):
if (r0 == 1 or r0 == 6) and (np.abs(c1 - c0) == 0 and np.abs(r1 - r0) == 2):
legal_dst.append(to)
assert len(legal_dst) <= 8
CAN_MOVE = CAN_MOVE.at[1, from_, : len(legal_dst)].set(jnp.int32(legal_dst))
CAN_MOVE[1, from_, : len(legal_dst)] = legal_dst
# KNIGHT
for from_ in range(64):
r0, c0 = from_ % 8, from_ // 8
legal_dst = []
for to in range(64):
r1, c1 = to % 8, to // 8
if jnp.abs(r1 - r0) == 1 and jnp.abs(c1 - c0) == 2:
if np.abs(r1 - r0) == 1 and np.abs(c1 - c0) == 2:
legal_dst.append(to)
if jnp.abs(r1 - r0) == 2 and jnp.abs(c1 - c0) == 1:
if np.abs(r1 - r0) == 2 and np.abs(c1 - c0) == 1:
legal_dst.append(to)
assert len(legal_dst) <= 27
CAN_MOVE = CAN_MOVE.at[2, from_, : len(legal_dst)].set(jnp.int32(legal_dst))
CAN_MOVE[2, from_, : len(legal_dst)] = legal_dst
# BISHOP
for from_ in range(64):
r0, c0 = from_ % 8, from_ // 8
Expand All @@ -106,10 +107,10 @@
r1, c1 = to % 8, to // 8
if from_ == to:
continue
if jnp.abs(r1 - r0) == jnp.abs(c1 - c0):
if np.abs(r1 - r0) == np.abs(c1 - c0):
legal_dst.append(to)
assert len(legal_dst) <= 27
CAN_MOVE = CAN_MOVE.at[3, from_, : len(legal_dst)].set(jnp.int32(legal_dst))
CAN_MOVE[3, from_, : len(legal_dst)] = legal_dst
# ROOK
for from_ in range(64):
r0, c0 = from_ % 8, from_ // 8
Expand All @@ -118,10 +119,10 @@
r1, c1 = to % 8, to // 8
if from_ == to:
continue
if jnp.abs(r1 - r0) == 0 or jnp.abs(c1 - c0) == 0:
if np.abs(r1 - r0) == 0 or np.abs(c1 - c0) == 0:
legal_dst.append(to)
assert len(legal_dst) <= 27
CAN_MOVE = CAN_MOVE.at[4, from_, : len(legal_dst)].set(jnp.int32(legal_dst))
CAN_MOVE[4, from_, : len(legal_dst)] = legal_dst
# QUEEN
for from_ in range(64):
r0, c0 = from_ % 8, from_ // 8
Expand All @@ -130,12 +131,12 @@
r1, c1 = to % 8, to // 8
if from_ == to:
continue
if jnp.abs(r1 - r0) == 0 or jnp.abs(c1 - c0) == 0:
if np.abs(r1 - r0) == 0 or np.abs(c1 - c0) == 0:
legal_dst.append(to)
if jnp.abs(r1 - r0) == jnp.abs(c1 - c0):
if np.abs(r1 - r0) == np.abs(c1 - c0):
legal_dst.append(to)
assert len(legal_dst) <= 27
CAN_MOVE = CAN_MOVE.at[5, from_, : len(legal_dst)].set(jnp.int32(legal_dst))
CAN_MOVE[5, from_, : len(legal_dst)] = legal_dst
# KING
for from_ in range(64):
r0, c0 = from_ % 8, from_ // 8
Expand All @@ -144,19 +145,19 @@
r1, c1 = to % 8, to // 8
if from_ == to:
continue
if (jnp.abs(r1 - r0) <= 1) and (jnp.abs(c1 - c0) <= 1):
if (np.abs(r1 - r0) <= 1) and (np.abs(c1 - c0) <= 1):
legal_dst.append(to)
# castling
# if from_ == 32:
# legal_dst += [16, 48]
# if from_ == 39:
# legal_dst += [23, 55]
assert len(legal_dst) <= 8
CAN_MOVE = CAN_MOVE.at[6, from_, : len(legal_dst)].set(jnp.int32(legal_dst))
CAN_MOVE[6, from_, : len(legal_dst)] = legal_dst

assert (CAN_MOVE[0, :, :] == -1).all()

CAN_MOVE_ANY = -jnp.ones((64, 35), jnp.int32)
CAN_MOVE_ANY = -np.ones((64, 35), np.int32)
for from_ in range(64):
legal_dst = []
for i in range(27):
Expand All @@ -167,16 +168,16 @@
to = CAN_MOVE[2, from_, i] # KNIGHT
if to >= 0:
legal_dst.append(to)
CAN_MOVE_ANY = CAN_MOVE_ANY.at[from_, : len(legal_dst)].set(jnp.int32(legal_dst))
CAN_MOVE_ANY[from_, : len(legal_dst)] = legal_dst


# Between
BETWEEN = -jnp.ones((64, 64, 6), dtype=jnp.int32)
BETWEEN = -np.ones((64, 64, 6), dtype=np.int32)
for from_ in range(64):
for to in range(64):
r0, c0 = from_ % 8, from_ // 8
r1, c1 = to % 8, to // 8
if not ((jnp.abs(r1 - r0) == 0 or jnp.abs(c1 - c0) == 0) or (jnp.abs(r1 - r0) == jnp.abs(c1 - c0))):
if not ((np.abs(r1 - r0) == 0 or np.abs(c1 - c0) == 0) or (np.abs(r1 - r0) == np.abs(c1 - c0))):
continue
dr = max(min(r1 - r0, 1), -1)
dc = max(min(c1 - c0, 1), -1)
Expand All @@ -190,17 +191,22 @@
break
bet.append(c * 8 + r)
assert len(bet) <= 6
BETWEEN = BETWEEN.at[from_, to, : len(bet)].set(jnp.int32(bet))
BETWEEN[from_, to, : len(bet)] = bet

INIT_LEGAL_ACTION_MASK = jnp.zeros(64 * 73, dtype=jnp.bool_)
INIT_LEGAL_ACTION_MASK = np.zeros(64 * 73, dtype=np.bool_)
# fmt: off
ixs = [89, 90, 652, 656, 673, 674, 1257, 1258, 1841, 1842, 2425, 2426, 3009, 3010, 3572, 3576, 3593, 3594, 4177, 4178]
# fmt: on
for ix in ixs:
INIT_LEGAL_ACTION_MASK = INIT_LEGAL_ACTION_MASK.at[ix].set(True)
INIT_LEGAL_ACTION_MASK[ixs] = True
assert INIT_LEGAL_ACTION_MASK.shape == (64 * 73,)
assert INIT_LEGAL_ACTION_MASK.sum() == 20

TO_MAP = jnp.array(TO_MAP)
PLANE_MAP = jnp.array(PLANE_MAP)
CAN_MOVE = jnp.array(CAN_MOVE)
CAN_MOVE_ANY = jnp.array(CAN_MOVE_ANY)
BETWEEN = jnp.array(BETWEEN)
INIT_LEGAL_ACTION_MASK = jnp.array(INIT_LEGAL_ACTION_MASK)
INIT_POSSIBLE_PIECE_POSITIONS = jnp.int32(
[
[0, 1, 8, 9, 16, 17, 24, 25, 32, 33, 40, 41, 48, 49, 56, 57],
Expand Down
68 changes: 38 additions & 30 deletions pgx/_src/gardner_chess_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# type: ignore
import jax
import jax.numpy as jnp
import numpy as np

TO_MAP = -jnp.ones((25, 49), dtype=jnp.int32)
PLANE_MAP = -jnp.ones((25, 25), dtype=jnp.int32) # ignores underpromotions
TO_MAP = -np.ones((25, 49), dtype=np.int32)
PLANE_MAP = -np.ones((25, 25), dtype=np.int32) # ignores underpromotions
# underpromotions
for from_ in range(25):
if from_ % 5 != 3: # 4th row in current player view
Expand All @@ -16,10 +17,10 @@
# board index (flipped black view)
# 5 0 5 10 15 20
# 4 1 6 11 16 21
to = from_ + jnp.int32([+1, +6, -4])[dir_]
to = from_ + [+1, +6, -4][dir_]
if not (0 <= to < 25):
continue
TO_MAP = TO_MAP.at[from_, plane].set(to)
TO_MAP[from_, plane] = to
# normal move
# fmt: off
dr = [-4, -3, -2, -1, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, -4, -3, -2, -1, 1, 2, 3, 4, 4, 3, 2, 1, -1, -2, -3, -4, -1, +1, -2, +2, -1, +1, -2, +2] # noqa
Expand All @@ -32,18 +33,18 @@
c = c + dc[plane - 9]
if r < 0 or r >= 5 or c < 0 or c >= 5:
continue
to = jnp.int32(c * 5 + r)
TO_MAP = TO_MAP.at[from_, plane].set(to)
PLANE_MAP = PLANE_MAP.at[from_, to].set(jnp.int32(plane))
to = c * 5 + r
TO_MAP[from_, plane] = to
PLANE_MAP[from_, to] = plane


CAN_MOVE = -jnp.ones((7, 25, 16), jnp.int32)
CAN_MOVE = -np.ones((7, 25, 16), np.int32)
# usage: CAN_MOVE[piece, from_x, from_y]
# CAN_MOVE[0, :, :] are all -1
# Note that the board is not symmetric about the center (different from shogi)
# You can imagine that the viewpoint is always from the white side.
# Except PAWN, the moves are symmetric about the center.
# We define PAWN as a piece that can move up, down, and diagonally, and filter it according to the turn.
# We define PAWN as a piece that can move up and diagonally.


# PAWN
Expand All @@ -52,22 +53,22 @@
legal_dst = []
for to in range(25):
r1, c1 = to % 5, to // 5
if r1 - r0 == 1 and jnp.abs(c1 - c0) <= 1:
if r1 - r0 == 1 and np.abs(c1 - c0) <= 1:
legal_dst.append(to)
assert len(legal_dst) <= 6, f"{from_=}, {to=}, {legal_dst=}"
CAN_MOVE = CAN_MOVE.at[1, from_, : len(legal_dst)].set(jnp.int32(legal_dst))
CAN_MOVE[1, from_, : len(legal_dst)] = legal_dst
# KNIGHT
for from_ in range(25):
r0, c0 = from_ % 5, from_ // 5
legal_dst = []
for to in range(25):
r1, c1 = to % 5, to // 5
if jnp.abs(r1 - r0) == 1 and jnp.abs(c1 - c0) == 2:
if np.abs(r1 - r0) == 1 and np.abs(c1 - c0) == 2:
legal_dst.append(to)
if jnp.abs(r1 - r0) == 2 and jnp.abs(c1 - c0) == 1:
if np.abs(r1 - r0) == 2 and np.abs(c1 - c0) == 1:
legal_dst.append(to)
assert len(legal_dst) <= 8
CAN_MOVE = CAN_MOVE.at[2, from_, : len(legal_dst)].set(jnp.int32(legal_dst))
CAN_MOVE[2, from_, : len(legal_dst)] = legal_dst
# BISHOP
for from_ in range(25):
r0, c0 = from_ % 5, from_ // 5
Expand All @@ -76,10 +77,10 @@
r1, c1 = to % 5, to // 5
if from_ == to:
continue
if jnp.abs(r1 - r0) == jnp.abs(c1 - c0):
if np.abs(r1 - r0) == np.abs(c1 - c0):
legal_dst.append(to)
assert len(legal_dst) <= 8
CAN_MOVE = CAN_MOVE.at[3, from_, : len(legal_dst)].set(jnp.int32(legal_dst))
CAN_MOVE[3, from_, : len(legal_dst)] = legal_dst
# ROOK
for from_ in range(25):
r0, c0 = from_ % 5, from_ // 5
Expand All @@ -88,10 +89,10 @@
r1, c1 = to % 5, to // 5
if from_ == to:
continue
if jnp.abs(r1 - r0) == 0 or jnp.abs(c1 - c0) == 0:
if np.abs(r1 - r0) == 0 or np.abs(c1 - c0) == 0:
legal_dst.append(to)
assert len(legal_dst) <= 8
CAN_MOVE = CAN_MOVE.at[4, from_, : len(legal_dst)].set(jnp.int32(legal_dst))
CAN_MOVE[4, from_, : len(legal_dst)] = legal_dst
# QUEEN
for from_ in range(25):
r0, c0 = from_ % 5, from_ // 5
Expand All @@ -100,12 +101,12 @@
r1, c1 = to % 5, to // 5
if from_ == to:
continue
if jnp.abs(r1 - r0) == 0 or jnp.abs(c1 - c0) == 0:
if np.abs(r1 - r0) == 0 or np.abs(c1 - c0) == 0:
legal_dst.append(to)
if jnp.abs(r1 - r0) == jnp.abs(c1 - c0):
if np.abs(r1 - r0) == np.abs(c1 - c0):
legal_dst.append(to)
assert len(legal_dst) <= 16
CAN_MOVE = CAN_MOVE.at[5, from_, : len(legal_dst)].set(jnp.int32(legal_dst))
CAN_MOVE[5, from_, : len(legal_dst)] = legal_dst
# KING
for from_ in range(25):
r0, c0 = from_ % 5, from_ // 5
Expand All @@ -114,14 +115,14 @@
r1, c1 = to % 5, to // 5
if from_ == to:
continue
if (jnp.abs(r1 - r0) <= 1) and (jnp.abs(c1 - c0) <= 1):
if (np.abs(r1 - r0) <= 1) and (np.abs(c1 - c0) <= 1):
legal_dst.append(to)
assert len(legal_dst) <= 8
CAN_MOVE = CAN_MOVE.at[6, from_, : len(legal_dst)].set(jnp.int32(legal_dst))
CAN_MOVE[6, from_, : len(legal_dst)] = legal_dst

assert (CAN_MOVE[0, :, :] == -1).all()

CAN_MOVE_ANY = -jnp.ones((25, 24), jnp.int32)
CAN_MOVE_ANY = -np.ones((25, 24), np.int32)
for from_ in range(25):
legal_dst = []
for i in range(16):
Expand All @@ -133,14 +134,14 @@
if to >= 0:
legal_dst.append(to)
assert len(legal_dst) <= 24
CAN_MOVE_ANY = CAN_MOVE_ANY.at[from_, : len(legal_dst)].set(jnp.int32(legal_dst))
CAN_MOVE_ANY[from_, : len(legal_dst)] = legal_dst

BETWEEN = -jnp.ones((25, 25, 3), dtype=jnp.int32)
BETWEEN = -np.ones((25, 25, 3), dtype=np.int32)
for from_ in range(25):
for to in range(25):
r0, c0 = from_ % 5, from_ // 5
r1, c1 = to % 5, to // 5
if not ((jnp.abs(r1 - r0) == 0 or jnp.abs(c1 - c0) == 0) or (jnp.abs(r1 - r0) == jnp.abs(c1 - c0))):
if not ((np.abs(r1 - r0) == 0 or np.abs(c1 - c0) == 0) or (np.abs(r1 - r0) == np.abs(c1 - c0))):
continue
dr = max(min(r1 - r0, 1), -1)
dc = max(min(c1 - c0, 1), -1)
Expand All @@ -154,18 +155,25 @@
break
bet.append(c * 5 + r)
assert len(bet) <= 3
BETWEEN = BETWEEN.at[from_, to, : len(bet)].set(jnp.int32(bet))
BETWEEN[from_, to, : len(bet)] = bet


INIT_LEGAL_ACTION_MASK = jnp.zeros(25 * 49, dtype=jnp.bool_)
INIT_LEGAL_ACTION_MASK = np.zeros(25 * 49, dtype=np.bool_)
# fmt: off
ixs = [62, 289, 293, 307, 552, 797, 1042]
# fmt: on
for ix in ixs:
INIT_LEGAL_ACTION_MASK = INIT_LEGAL_ACTION_MASK.at[ix].set(True)
INIT_LEGAL_ACTION_MASK[ix] = True
assert INIT_LEGAL_ACTION_MASK.shape == (25 * 49,)
assert INIT_LEGAL_ACTION_MASK.sum() == 7

TO_MAP = jnp.array(TO_MAP)
PLANE_MAP = jnp.array(PLANE_MAP)
CAN_MOVE = jnp.array(CAN_MOVE)
CAN_MOVE_ANY = jnp.array(CAN_MOVE_ANY)
BETWEEN = jnp.array(BETWEEN)
INIT_LEGAL_ACTION_MASK = jnp.array(INIT_LEGAL_ACTION_MASK)


key = jax.random.PRNGKey(238942)
key, subkey = jax.random.split(key)
Expand Down

0 comments on commit 262afea

Please sign in to comment.