Skip to content

Commit

Permalink
[Shogi] Unuse assets (#1272)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Oct 31, 2024
1 parent ac67514 commit 1e08b48
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 7 deletions.
Binary file removed pgx/_src/assets/between.npy
Binary file not shown.
Binary file removed pgx/_src/assets/can_move.npy
Binary file not shown.
137 changes: 131 additions & 6 deletions pgx/_src/shogi_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import os

import numpy as np
import jax
import jax.numpy as jnp
import numpy as np
Expand All @@ -30,20 +31,144 @@
[15, -1, 14, -1, -1, -1, 0, -1, 1]]).flatten() # noqa: E241
# fmt: on

EMPTY = -1 # 空白
PAWN = 0 # 歩
LANCE = 1 # 香
KNIGHT = 2 # 桂
SILVER = 3 # 銀
BISHOP = 4 # 角
ROOK = 5 # 飛
GOLD = 6 # 金
KING = 7 # 玉
PRO_PAWN = 8 # と
PRO_LANCE = 9 # 成香
PRO_KNIGHT = 10 # 成桂
PRO_SILVER = 11 # 成銀
HORSE = 12 # 馬
DRAGON = 13 # 龍


# Can <piece,14> reach from <from,81> to <to,81> ignoring pieces on board?
file_path = "assets/can_move.npy"
with open(os.path.join(os.path.dirname(__file__), file_path), "rb") as f:
CAN_MOVE = jnp.load(f)
def can_move_to(piece, from_, to):
"""Can <piece> move from <from_> to <to>?"""
if from_ == to:
return False
x0, y0 = from_ // 9, from_ % 9
x1, y1 = to // 9, to % 9
dx = x1 - x0
dy = y1 - y0
if piece == PAWN:
if dx == 0 and dy == -1:
return True
else:
return False
elif piece == LANCE:
if dx == 0 and dy < 0:
return True
else:
return False
elif piece == KNIGHT:
if dx in (-1, 1) and dy == -2:
return True
else:
return False
elif piece == SILVER:
if dx in (-1, 0, 1) and dy == -1:
return True
elif dx in (-1, 1) and dy == 1:
return True
else:
return False
elif piece == BISHOP:
if dx == dy or dx == -dy:
return True
else:
return False
elif piece == ROOK:
if dx == 0 or dy == 0:
return True
else:
return False
if piece in (GOLD, PRO_PAWN, PRO_LANCE, PRO_KNIGHT, PRO_SILVER):
if dx in (-1, 0, 1) and dy in (0, -1):
return True
elif dx == 0 and dy == 1:
return True
else:
return False
elif piece == KING:
if abs(dx) <= 1 and abs(dy) <= 1:
return True
else:
return False
elif piece == HORSE:
if abs(dx) <= 1 and abs(dy) <= 1:
return True
elif dx == dy or dx == -dy:
return True
else:
return False
elif piece == DRAGON:
if abs(dx) <= 1 and abs(dy) <= 1:
return True
if dx == 0 or dy == 0:
return True
else:
return False
else:
assert False


def is_on_the_way(piece, from_, to, point):
if to == point:
return False
if piece not in (LANCE, BISHOP, ROOK, HORSE, DRAGON):
return False
if not can_move_to(piece, from_, to):
return False
if not can_move_to(piece, from_, point):
return False

x0, y0 = from_ // 9, from_ % 9
x1, y1 = to // 9, to % 9
x2, y2 = point // 9, point % 9
dx1, dy1 = x1 - x0, y1 - y0
dx2, dy2 = x2 - x0, y2 - y0

def sign(d):
if d == 0:
return 0
return d > 0

if (sign(dx1) != sign(dx2)) or (sign(dy1) != sign(dy2)):
return False

return abs(dx2) <= abs(dx1) and abs(dy2) <= abs(dy1)


CAN_MOVE = np.zeros((14, 81, 81), dtype=jnp.bool_)
for piece in range(14):
for from_ in range(81):
for to in range(81):
CAN_MOVE[piece, from_, to] = can_move_to(piece, from_, to)

assert CAN_MOVE.sum() == 8228
CAN_MOVE = jnp.array(CAN_MOVE)


# When <lance/bishop/rook/horse/dragon,5> moves from <from,81> to <to,81>,
# is <point,81> on the way between two points?
file_path = "assets/between.npy"
with open(os.path.join(os.path.dirname(__file__), file_path), "rb") as f:
BETWEEN = jnp.load(f)
BETWEEN = np.zeros((5, 81, 81, 81), dtype=np.bool_)
for i, piece in enumerate((LANCE, BISHOP, ROOK, HORSE, DRAGON)):
for from_ in range(81):
for to in range(81):
for p in range(81):
BETWEEN[i, from_, to, p] = is_on_the_way(piece, from_, to, p)

BETWEEN = jnp.array(BETWEEN)
assert BETWEEN.sum() == 10564


# Give <dir,10> and <to,81>, return the legal <from> idx
# E.g. LEGAL_FROM_IDX[Up, to=19] = [20, 21, ..., -1] (filled by -1)
# Used for computing dlshogi action
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def _read_requirements(fname):
keywords="",
packages=find_packages(),
package_data={
"": ["LICENSE", "*.svg", "_src/assets/*.npy"]
"": ["LICENSE", "*.svg"]
},
include_package_data=True,
install_requires=_read_requirements("requirements.txt"),
Expand Down

0 comments on commit 1e08b48

Please sign in to comment.