Skip to content

Commit

Permalink
[BridgeBidding] Improve efficiency in using dds results (#1191)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk committed Jul 17, 2024
1 parent 25d5a50 commit d0af062
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 94 deletions.
82 changes: 15 additions & 67 deletions pgx/bridge_bidding.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import jax
import jax.numpy as jnp
import numpy as np

import pgx.core as core
from pgx._src.struct import dataclass
Expand All @@ -28,6 +27,8 @@
TRUE = jnp.bool_(True)
FALSE = jnp.bool_(False)

HandArray = Array # (4,) int array

# The card and number correspondence
# 0~12 spade, 13~25 heart, 26~38 diamond, 39~51 club
# For each suit, the numbers are arranged in the following order
Expand All @@ -45,42 +46,19 @@
def download_dds_results(download_dir="dds_results"):
"""Download and split the results into 100K chunks."""

def split_data(data, prefix, base_i=0):
with open(data, "rb") as f:
keys, values = jnp.load(f)
n = 100_000
m = keys.shape[0] // n
for i in range(m):
fname = os.path.join(download_dir, f"{prefix}_{base_i + i:03d}.npy")
with open(fname, "wb") as f:
print(
f"saving {fname} ... [{i * n}, {(i + 1) * n})",
file=sys.stderr,
)
jnp.save(
f,
(
keys[i * n : (i + 1) * n],
values[i * n : (i + 1) * n],
),
)

os.makedirs(download_dir, exist_ok=True)

train_small_fname = os.path.join(download_dir, "dds_results_2.5M.npy")
if not os.path.exists(train_small_fname):
_download(DDS_RESULTS_TRAIN_SMALL_URL, train_small_fname)
split_data(train_small_fname, "train")

train_large_fname = os.path.join(download_dir, "dds_results_10M.npy")
if not os.path.exists(train_large_fname):
_download(DDS_RESULTS_TRAIN_LARGE_URL, train_large_fname)
split_data(train_large_fname, "train", base_i=25)

test_fname = os.path.join(download_dir, "dds_results_500K.npy")
if not os.path.exists(test_fname):
_download(DDS_RESULTS_TEST_URL, test_fname)
split_data(test_fname, "test")


@dataclass
Expand Down Expand Up @@ -136,14 +114,15 @@ class State(core.State):
_first_denomination_EW: Array = jnp.full(5, -1, dtype=jnp.int32)
# Number of pass
_pass_num: Array = jnp.array(0, dtype=jnp.int32)
_dds_val: Array = jnp.zeros(4, dtype=jnp.int32)

@property
def env_id(self) -> core.EnvId:
return "bridge_bidding"


class BridgeBidding(core.Env):
def __init__(self, dds_results_table_path: str = "dds_results/train_000.npy"):
def __init__(self, dds_results_table_path: str = "dds_results/dds_results_10M.npy"):
super().__init__()
print(
f"Loading dds results from {dds_results_table_path} ...",
Expand Down Expand Up @@ -180,12 +159,14 @@ def __init__(self, dds_results_table_path: str = "dds_results/train_000.npy"):

def _init(self, key: PRNGKey) -> State:
key1, key2 = jax.random.split(key, num=2)
return _init_by_key(jax.random.choice(key1, self._lut_keys), key2)
ix = jax.random.choice(key1, jnp.arange(self._lut_keys.shape[0]))
key, val = self._lut_keys[ix], self._lut_values[ix]
return _init_by_key(key, val, key2)

def _step(self, state: core.State, action: int, key) -> State:
del key
assert isinstance(state, State)
return _step(state, action, self._lut_keys, self._lut_values)
return _step(state, action)

def _observe(self, state: core.State, player_id: Array) -> Array:
assert isinstance(state, State)
Expand All @@ -208,7 +189,7 @@ def _illegal_action_penalty(self) -> float:
return -7600.0


def _init_by_key(key: PRNGKey, rng: PRNGKey) -> State:
def _init_by_key(key: HandArray, val: Array, rng: PRNGKey) -> State:
"""Make init state from key"""
rng1, rng2, rng3, rng4 = jax.random.split(rng, num=4)
hand = _key_to_hand(key)
Expand All @@ -230,6 +211,7 @@ def _init_by_key(key: PRNGKey, rng: PRNGKey) -> State:
_vul_NS=vul_NS,
_vul_EW=vul_EW,
legal_action_mask=legal_actions,
_dds_val=val,
)
return state

Expand Down Expand Up @@ -285,8 +267,6 @@ def _player_position(player: Array, state: State) -> Array:
def _step(
state: State,
action: int,
lut_keys: Array,
lut_values: Array,
) -> State:
# fmt: off
state = state.replace(_bidding_history=state._bidding_history.at[state._turn].set(action)) # type: ignore
Expand All @@ -298,7 +278,7 @@ def _step(
[
lambda: jax.lax.cond(
_is_terminated(_state_pass(state)),
lambda: _terminated_step(_state_pass(state), lut_keys, lut_values),
lambda: _terminated_step(_state_pass(state)),
lambda: _continue_step(_state_pass(state)),
),
lambda: _continue_step(_state_X(state)),
Expand Down Expand Up @@ -392,12 +372,10 @@ def _convert_card_pgx_to_openspiel(card: Array) -> Array:

def _terminated_step(
state: State,
lut_keys: Array,
lut_values: Array,
) -> State:
"""Return state if the game is successfully completed"""
terminated = jnp.bool_(True)
reward = _reward(state, lut_keys, lut_values)
reward = _reward(state)
# fmt: off
return state.replace(terminated=terminated, rewards=reward) # type: ignore
# fmt: on
Expand Down Expand Up @@ -433,23 +411,19 @@ def _is_terminated(state: State) -> bool:

def _reward(
state: State,
lut_keys: Array,
lut_values: Array,
) -> Array:
"""Return reward
If pass out, 0 reward for everyone; if bid, calculate and return reward
"""
return jax.lax.cond(
(state._last_bid == -1) & (state._pass_num == 4),
lambda: jnp.zeros(4, dtype=jnp.float32), # pass out
lambda: _make_reward(state, lut_keys, lut_values), # caluculate reward
lambda: _make_reward(state), # caluculate reward
)


def _make_reward(
state: State,
lut_keys: Array,
lut_values: Array,
) -> Array:
"""Calculate rewards for each player by dds results
Expand All @@ -459,7 +433,7 @@ def _make_reward(
# Extract contract from state
declare_position, denomination, level, vul = _contract(state)
# Calculate trick table from hash table
dds_tricks = _calculate_dds_tricks(state, lut_keys, lut_values)
dds_tricks = _value_to_dds_tricks(state._dds_val)
# Calculate the tricks you could have accomplished with this contraption
dds_trick = dds_tricks[declare_position * 5 + denomination]
# Clculate score
Expand Down Expand Up @@ -814,7 +788,7 @@ def _card_str_to_int(card: str) -> int:
return int(card) - 1


def _key_to_hand(key: PRNGKey) -> Array:
def _key_to_hand(key: HandArray) -> Array:
"""Convert key to hand"""

def _convert_quat(j):
Expand Down Expand Up @@ -859,32 +833,6 @@ def _convert_hex(j):
return jnp.array(hex_digits, dtype=jnp.int32)


def _calculate_dds_tricks(
state: State,
lut_keys: Array,
lut_values: Array,
) -> Array:
key = _state_to_key(state)
return _value_to_dds_tricks(_find_value_from_key(key, lut_keys, lut_values))


def _find_value_from_key(key: PRNGKey, lut_keys: Array, lut_values: Array):
"""Find a value matching key without batch processing
>>> VALUES = jnp.arange(20).reshape(5, 4)
>>> KEYS = jnp.arange(20).reshape(5, 4)
>>> key = jnp.arange(4, 8)
>>> _find_value_from_key(key, KEYS, VALUES)
Array([4, 5, 6, 7], dtype=int32)
"""
mask = jnp.where(
jnp.all((lut_keys == key), axis=1),
jnp.ones(1, dtype=np.bool_),
jnp.zeros(1, dtype=np.bool_),
)
ix = jnp.argmax(mask)
return lut_values[ix]


def _load_sample_hash() -> Tuple[Array, Array]:
# fmt: off
return jnp.array([[19556549, 61212362, 52381660, 50424958], [53254536, 21854346, 37287883, 14009558], [44178585, 6709002, 23279217, 16304124], [36635659, 48114215, 13583653, 26208086], [44309474, 39388022, 28376136, 59735189], [61391908, 52173479, 29276467, 31670621], [34786519, 13802254, 57433417, 43152306], [48319039, 55845612, 44614774, 58169152], [47062227, 32289487, 12941848, 21338650], [36579116, 15643926, 64729756, 18678099], [62136384, 37064817, 59701038, 39188202], [13417016, 56577539, 25995845, 27248037], [61125047, 43238281, 23465183, 20030494], [7139188, 31324229, 58855042, 14296487], [2653767, 47502150, 35507905, 43823846], [31453323, 11605145, 6716808, 41061859], [21294711, 49709, 26110952, 50058629], [48130172, 3340423, 60445890, 7686579], [16041939, 27817393, 37167847, 9605779], [61154057, 17937858, 12254613, 12568801], [13796245, 46546127, 49123920, 51772041], [7195005, 45581051, 41076865, 17429796], [20635965, 14642724, 7001617, 45370595], [35616421, 19938131, 45131030, 16524847], [14559399, 15413729, 39188470, 535365], [48743216, 39672069, 60203571, 60210880], [63862780, 2462075, 23267370, 36595020], [11229980, 11616119, 20292263, 3695004], [24135854, 37532826, 54421444, 14130249], [42798085, 33026223, 2460251, 18566823], [49558558, 65537599, 14768519, 31103243], [44321156, 20075251, 42663767, 11615602], [20186726, 42678073, 11763300, 56739471], [57534601, 16703645, 6039937, 17088125], [50795278, 17350238, 11955835, 21538127], [45919621, 5520088, 27736513, 52674927], [13928720, 57324148, 28222453, 15480785], [910719, 47238830, 26345802, 56166394], [58841430, 1098476, 61890558, 26907706], [10379825, 8624220, 39701822, 29045990], [54444873, 50000486, 48563308, 55867521], [47291672, 22084522, 45484828, 32878832], [55350706, 23903891, 46142039, 11499952], [4708326, 27588734, 31010458, 11730972], [27078872, 59038086, 62842566, 51147874], [28922172, 32377861, 9109075, 10154350], [26104086, 62786977, 224865, 14335943], [20448626, 33187645, 34338784, 26382893], [29194006, 19635744, 24917755, 8532577], [64047742, 34885257, 5027048, 58399668], [27603972, 26820121, 44837703, 63748595], [60038456, 19611050, 7928914, 38555047], [13583610, 19626473, 22239272, 19888268], [28521006, 1743692, 31319264, 15168920], [64585849, 63931241, 57019799, 14189800], [2632453, 7269809, 60404342, 57986125], [1996183, 49918209, 49490468, 47760867], [6233580, 15318425, 51356120, 55074857], [15769884, 61654638, 8374039, 43685186], [44162419, 47272176, 62693156, 35359329], [36345796, 15667465, 53341561, 2978505], [1664472, 12761950, 34145519, 55197543], [37567005, 3228834, 6198166, 15646487], [63233399, 42640049, 12969011, 41620641], [22090925, 3386355, 56655568, 31631004], [16442787, 9420273, 48595545, 29770176], [49404288, 37823218, 58551818, 6772527], [36575583, 53847347, 32379432, 1630009], [9004247, 12999580, 48379959, 14252211], [25850203, 26136823, 64934025, 29362603], [10214276, 43557352, 33387586, 55512773], [45810841, 49561478, 41130845, 27034816], [34460081, 16560450, 57722793, 41007718], [53414778, 6845803, 15340368, 16647575], [30535873, 5193469, 43608154, 11391114], [20622004, 34424126, 31475211, 29619615], [10428836, 49656416, 7912677, 33427787], [57600861, 18251799, 46147432, 58946294], [6760779, 14675737, 42952146, 5480498], [46037552, 39969058, 30103468, 55330772], [64466305, 29376674, 49914839, 55269895], [36494113, 27010567, 65752150, 12395385], [49385632, 19550767, 39809394, 58806235], [20987521, 37444597, 49290126, 42326125], [37150229, 37487849, 28254397, 32949826], [9724895, 53813417, 19431235, 27438556], [42132748, 47073733, 19396568, 10026137], [3961481, 27204521, 62087205, 37602005], [22178323, 17505521, 42006207, 44143605], [12753258, 63063515, 61993175, 8920985], [10998000, 64833190, 6446892, 63676805], [66983817, 63684932, 18378359, 39946382], [63476803, 60000436, 19442420, 66417845], [38004446, 64752157, 42570179, 52844512], [1270809, 23735482, 17543294, 18795903], [4862706, 16352249, 57100612, 6219870], [63203206, 25630930, 35608240, 51357885], [59819625, 64662579, 50925335, 55670434], [29216830, 26446697, 52243336, 58475666], [43138915, 30592834, 43931516, 50628002]], dtype=jnp.int32), jnp.array([[71233, 771721, 71505, 706185], [289177, 484147, 358809, 484147], [359355, 549137, 359096, 549137], [350631, 558133, 350630, 554037], [370087, 538677, 370087, 538677], [4432, 899725, 4432, 904077], [678487, 229987, 678487, 229987], [423799, 480614, 423799, 480870], [549958, 284804, 549958, 280708], [423848, 480565, 423848, 480549], [489129, 283940, 554921, 283940], [86641, 822120, 86641, 822120], [206370, 702394, 206370, 567209], [500533, 407959, 500533, 407959], [759723, 79137, 759723, 79137], [563305, 345460, 559209, 345460], [231733, 611478, 231733, 611478], [502682, 406082, 498585, 406082], [554567, 288662, 554567, 288662], [476823, 427846, 476823, 427846], [488823, 415846, 488823, 415846], [431687, 477078, 431687, 477078], [419159, 424070, 415062, 424070], [493399, 345734, 493143, 345718], [678295, 230451, 678295, 230451], [496520, 342596, 496520, 346709], [567109, 276116, 567109, 276116], [624005, 284758, 624005, 284758], [420249, 484420, 420248, 484420], [217715, 621418, 217715, 621418], [344884, 493977, 344884, 493977], [550841, 292132, 550841, 292132], [284262, 558967, 284006, 558967], [152146, 756616, 152146, 756616], [144466, 698763, 144466, 694667], [284261, 624504, 284261, 624504], [288406, 620102, 288405, 620358], [301366, 607383, 301366, 607382], [468771, 435882, 468771, 435882], [555688, 283444, 555688, 283444], [485497, 414820, 485497, 414820], [633754, 275010, 633754, 275010], [419141, 489608, 419157, 489608], [694121, 214387, 694121, 214387], [480869, 427639, 481125, 427639], [489317, 419447, 489301, 419447], [152900, 747672, 152900, 747672], [348516, 494457, 348516, 494457], [534562, 370088, 534562, 370088], [371272, 537475, 371274, 537475], [144194, 760473, 144194, 760473], [567962, 275011, 567962, 275011], [493161, 350052, 493161, 350052], [490138, 348979, 490138, 348979], [328450, 506552, 328450, 506552], [148882, 759593, 148626, 755497], [642171, 266593, 642171, 266593], [685894, 218774, 685894, 218774], [674182, 234548, 674214, 234548], [756347, 152146, 690811, 86353], [612758, 291894, 612758, 291894], [296550, 612214, 296550, 612214], [363130, 475730, 363130, 475730], [691559, 16496, 691559, 16496], [340755, 502202, 336659, 502218], [632473, 210499, 628377, 210483], [564410, 266513, 564410, 266513], [427366, 481399, 427366, 481399], [493159, 349797, 493159, 415605], [331793, 576972, 331793, 576972], [416681, 492084, 416681, 492084], [813496, 95265, 813496, 91153], [695194, 213571, 695194, 213571], [436105, 407124, 436105, 407124], [836970, 6243, 902506, 6243], [160882, 747882, 160882, 747882], [493977, 414788, 489624, 414788], [29184, 551096, 29184, 616888], [903629, 4880, 899517, 4880], [351419, 553250, 351419, 553250], [75554, 767671, 75554, 767671], [279909, 563304, 279909, 563304], [215174, 628054, 215174, 628054], [361365, 481864, 361365, 481864], [424022, 484743, 358486, 484725], [271650, 633018, 271650, 633018], [681896, 226867, 616088, 226867], [222580, 686184, 222564, 686184], [144451, 698778, 209987, 698778], [532883, 310086, 532883, 310086], [628872, 279893, 628872, 279893], [533797, 374951, 533797, 374951], [91713, 817036, 91713, 817036], [427605, 477046, 431718, 477046], [145490, 689529, 145490, 689529], [551098, 291875, 551098, 291875], [349781, 558984, 349781, 558983], [205378, 703115, 205378, 703115], [362053, 546456, 362053, 546456], [612248, 226371, 678040, 226371]], dtype=jnp.int32)
Expand Down
32 changes: 5 additions & 27 deletions tests/test_bridge_bidding.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
BridgeBidding,
State,
_calc_score,
_calculate_dds_tricks,
_contract,
_init_by_key,
_key_to_hand,
Expand Down Expand Up @@ -64,7 +63,6 @@ def init(rng: jax.Array) -> State:
step = jax.jit(env.step)
observe = jax.jit(env.observe)
_calc_score = jax.jit(_calc_score)
_calculate_dds_tricks = jax.jit(_calculate_dds_tricks)
_contract = jax.jit(_contract)
_init_by_key = jax.jit(_init_by_key)
_key_to_hand = jax.jit(_key_to_hand)
Expand Down Expand Up @@ -131,7 +129,7 @@ def test_step():
# fmt: on
key = jax.random.PRNGKey(0)
HASH_TABLE_SAMPLE_KEYS, HASH_TABLE_SAMPLE_VALUES = _load_sample_hash()
state = _init_by_key(HASH_TABLE_SAMPLE_KEYS[0], key)
state = _init_by_key(HASH_TABLE_SAMPLE_KEYS[0], HASH_TABLE_SAMPLE_VALUES[0], key)
# state = init_by_key(HASH_TABLE_SAMPLE_KEYS[0], key)
state = state.replace(
_dealer=jnp.int32(1),
Expand Down Expand Up @@ -957,7 +955,7 @@ def max_action_length_agent(state: State) -> int:
def test_max_action():
key = jax.random.PRNGKey(0)
HASH_TABLE_SAMPLE_KEYS, HASH_TABLE_SAMPLE_VALUES = _load_sample_hash()
state = _init_by_key(HASH_TABLE_SAMPLE_KEYS[0], key)
state = _init_by_key(HASH_TABLE_SAMPLE_KEYS[0], HASH_TABLE_SAMPLE_VALUES[0],key)

for i in range(319):
if i < 318:
Expand All @@ -984,7 +982,7 @@ def max_action_length_agent(state: State) -> int:
def test_max_action():
key = jax.random.PRNGKey(0)
HASH_TABLE_SAMPLE_KEYS, HASH_TABLE_SAMPLE_VALUES = _load_sample_hash()
state = _init_by_key(HASH_TABLE_SAMPLE_KEYS[0], key)
state = _init_by_key(HASH_TABLE_SAMPLE_KEYS[0], HASH_TABLE_SAMPLE_VALUES[0],key)

for i in range(319):
if i < 318:
Expand All @@ -1005,7 +1003,7 @@ def test_pass_out():
actions = iter([0, 0, 0, 0])
key = jax.random.PRNGKey(0)
HASH_TABLE_SAMPLE_KEYS, HASH_TABLE_SAMPLE_VALUES = _load_sample_hash()
state = _init_by_key(HASH_TABLE_SAMPLE_KEYS[0], key)
state = _init_by_key(HASH_TABLE_SAMPLE_KEYS[0], HASH_TABLE_SAMPLE_VALUES[0],key)
# state = init_by_key(HASH_TABLE_SAMPLE_KEYS[1], key)
state = state.replace(
_dealer=jnp.int32(1),
Expand Down Expand Up @@ -1146,7 +1144,7 @@ def test_observe():
)
key = jax.random.PRNGKey(0)
HASH_TABLE_SAMPLE_KEYS, HASH_TABLE_SAMPLE_VALUES = _load_sample_hash()
state = _init_by_key(HASH_TABLE_SAMPLE_KEYS[0], key)
state = _init_by_key(HASH_TABLE_SAMPLE_KEYS[0], HASH_TABLE_SAMPLE_VALUES[0],key)
state = state.replace(
_dealer=jnp.int32(1),
current_player=jnp.int32(3),
Expand Down Expand Up @@ -2193,26 +2191,6 @@ def test_state_to_key_cycle():
assert jnp.all(sorted_hand == reconst_hand)


def test_calcurate_dds_tricks():
HASH_TABLE_SAMPLE_KEYS, HASH_TABLE_SAMPLE_VALUES = _load_sample_hash()
samples = []
with open("tests/assets/contractbridge-ddstable-sample100.csv", "r") as f:
reader = csv.reader(f, delimiter=",")
for i in reader:
samples.append([i[0], np.array(i[1:]).astype(np.int32)])
key = jax.random.PRNGKey(0)
for i in range(len(HASH_TABLE_SAMPLE_KEYS)):
key, subkey = jax.random.split(key)
state = init(subkey)
state = state.replace(_hand=_key_to_hand(HASH_TABLE_SAMPLE_KEYS[i]))
dds_tricks = _calculate_dds_tricks(
state, HASH_TABLE_SAMPLE_KEYS, HASH_TABLE_SAMPLE_VALUES
)
# calculate dds results from hash table made by dample data
# check whether the results are conssitent with sample data
assert jnp.all(dds_tricks == samples[i][1])


def test_value_to_dds_tricks():
value = jnp.array([4160, 904605, 4160, 904605])
# fmt: off
Expand Down

0 comments on commit d0af062

Please sign in to comment.