Skip to content

Commit

Permalink
fixed_jit_issue for eval cases generation
Browse files Browse the repository at this point in the history
  • Loading branch information
collinfeng committed Apr 1, 2024
1 parent bd4fa01 commit feb82dc
Showing 1 changed file with 105 additions and 114 deletions.
219 changes: 105 additions & 114 deletions jaxmarl/environments/hanabi/hint_guess.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
self.num_classes_per_feature = num_classes_per_feature
self.num_cards = np.prod(self.num_classes_per_feature)
self.matrix_obs = matrix_obs
self.feature_tree = [np.arange(n_c) for n_c in num_classes_per_feature]
self.card_feature_space = jnp.array(list(product(*[np.arange(n_c) for n_c in self.num_classes_per_feature])))

# generate the deck of one-hot encoded cards
if card_encoding == "onehot":
Expand Down Expand Up @@ -93,149 +93,136 @@ def reset(self, rng):
)
return jax.lax.stop_gradient(self.get_obs(state)), state

@partial(jax.jit, static_argnums=[0, 2])
def reset_for_eval(self, rng, reset_mode="exact_match"):
@partial(jax.jit, static_argnums=[0, 2, 3])
def reset_for_eval(self, rng, reset_mode="exact_match", replace=True):

def exact_match(card_multi_set):
card_flat_set = card_multi_set.flatten()
hint_flat_id = target_flat_id
hinter_and_guesser_flat_hand_set = jnp.delete(card_flat_set, target_flat_id, assume_unique_indices=True)
hinter_flat_rest_of_hand = jax.random.choice(hinter_hand_rngs,
hinter_and_guesser_flat_hand_set,
shape=(self.hand_size-1,))
guesser_flat_rest_of_hand = jax.random.choice(guesser_hand_rngs,
hinter_and_guesser_flat_hand_set,
shape=(self.hand_size-1,))
return hint_flat_id, hinter_flat_rest_of_hand, guesser_flat_rest_of_hand
def p_exact_match(masks):
    """Sampling distributions for the "exact_match" evaluation case.

    Args:
        masks: the 5-tuple of card masks built by the enclosing reset
            routine. Only the first two entries are read: the target-card
            mask and the non-target mask.

    Returns:
        A pair ``(hint distribution, rest-of-hand distribution)``. The hint
        distribution is the target mask divided by its sum; the rest-of-hand
        distribution is the non-target mask divided by its sum.
    """
    is_target, is_not_target, _, _, _ = masks
    target_total = jnp.sum(is_target)
    other_total = jnp.sum(is_not_target)
    hint_probs = is_target / target_total
    rest_probs = is_not_target / other_total
    return hint_probs, rest_probs

def similarity_match(card_multi_set):
feature_of_interest = jax.random.choice(hint_rng, self.num_features)
target_index_of_interest = target_multi_id[feature_of_interest]
# note this hint_set also include target card, need to be removed
# print(feature_of_interest, target_index_of_interest)
hint_set = jax.lax.dynamic_index_in_dim(card_multi_set,
target_index_of_interest,
feature_of_interest,
keepdims=False)
# find the target card id in hint set after slicing from the feature of interest
target_id = jnp.concatenate((target_multi_id[:feature_of_interest], target_multi_id[feature_of_interest+1:]))
hint_flat_set = jnp.delete(hint_set, target_id, assume_unique_indices=True).flatten()

non_similar_hand_set = copy.deepcopy(card_multi_set)
for feature_dim in range(self.num_features):
non_similar_hand_set = jnp.delete(non_similar_hand_set,
target_multi_id[feature_dim],
axis=feature_dim,
assume_unique_indices=True)

non_similar_flat_hand_set = non_similar_hand_set.flatten()
hinter_flat_id = jax.random.choice(hint_rng,
hint_flat_set,
shape=(1,))
hinter_flat_rest_of_hand = jax.random.choice(hinter_hand_rngs,
non_similar_flat_hand_set,
shape=(self.hand_size-1,))
guesser_flat_rest_of_hand = jax.random.choice(guesser_hand_rngs,
non_similar_flat_hand_set,
shape=(self.hand_size-1,))
return hinter_flat_id, hinter_flat_rest_of_hand, guesser_flat_rest_of_hand
def p_similarity_match(masks):
    """Sampling distributions for the "similarity_match" evaluation case.

    Args:
        masks: the 5-tuple of card masks built by the enclosing reset
            routine. Only the third entry (cards sharing the randomly chosen
            feature with the target, target itself excluded) and the fifth
            entry (cards sharing no feature with the target) are read.

    Returns:
        A pair ``(hint distribution, rest-of-hand distribution)``: the third
        mask divided by its sum, and the fifth mask divided by its sum.
    """
    _, _, shares_random_feature, _, shares_no_feature = masks
    hint_total = jnp.sum(shares_random_feature)
    rest_total = jnp.sum(shares_no_feature)
    hint_probs = shares_random_feature / hint_total
    rest_probs = shares_no_feature / rest_total
    return hint_probs, rest_probs

def mutual_exclusive(card_multi_set):
non_similar_hand_set = copy.deepcopy(card_multi_set)
for feature_dim in range(self.num_features):
non_similar_hand_set = jnp.delete(non_similar_hand_set,
target_multi_id[feature_dim],
axis=feature_dim,
assume_unique_indices=True)

non_similar_flat_hand_set = non_similar_hand_set.flatten()
def p_mutual_exclusive(masks):
_, _, _, _, non_similar_feature_mask = masks
p_non_sim = non_similar_feature_mask/jnp.sum(non_similar_feature_mask)
hint_flat_id = jax.random.choice(hint_rng,
non_similar_flat_hand_set,
shape=(1,))
hinter_and_guesser_flat_rest_of_hand_set = jnp.delete(non_similar_flat_hand_set, hint_flat_id, assume_unique_indices=True)
# note the rest of hand of both players are the same, so use either of the rngs
hinter_and_guesser_flat_rest_of_hand = jax.random.choice(hinter_hand_rngs,
hinter_and_guesser_flat_rest_of_hand_set,
shape=(self.hand_size-1,))
return hint_flat_id, hinter_and_guesser_flat_rest_of_hand, hinter_and_guesser_flat_rest_of_hand
card_space,
shape=(1,),
p=p_non_sim)
hint_mask = jax.nn.one_hot(x=hint_flat_id, num_classes=self.num_cards).flatten() # note this is also p_hint, as the chosen card has p=1
hinter_and_guesser_rest_of_hand_mask = jnp.logical_and(non_similar_feature_mask, jnp.logical_not(hint_mask))
p_hinter_and_guesser_rest_of_hand = hinter_and_guesser_rest_of_hand_mask/jnp.sum(hinter_and_guesser_rest_of_hand_mask)
return hint_mask, p_hinter_and_guesser_rest_of_hand

def mutual_exclusice_similarity(card_multi_set):
# the target will be included by the first slice, thus need to be removed
similar_cards_of_the_first_feature = jax.lax.dynamic_index_in_dim(card_multi_set,
target_multi_id[0],
0,
keepdims=False)
target_id = target_multi_id[1:]
hinter_and_guesser_flat_rest_of_hand_set = jnp.delete(similar_cards_of_the_first_feature, target_id, assume_unique_indices=True).flatten()

# later slices does include the target card
for feature_dim in range(1, self.num_features):
similar_cards = jax.lax.dynamic_index_in_dim(card_multi_set,
target_multi_id[feature_dim],
feature_dim,
keepdims=False)
hinter_and_guesser_flat_rest_of_hand_set = jnp.append(hinter_and_guesser_flat_rest_of_hand_set,
similar_cards.flatten())
card_multi_set = jnp.delete(card_multi_set,
target_multi_id[feature_dim],
axis=feature_dim,
assume_unique_indices=True)

hint_flat_set = card_multi_set.flatten()
hint_flat_id = jax.random.choice(hint_rng, hint_flat_set, shape=(1,))
# note the rest of hand of both players are the same, so use either of the rngs
hinter_and_guesser_flat_rest_of_hand = jax.random.choice(hinter_hand_rngs,
hinter_and_guesser_flat_rest_of_hand_set,
shape=(self.hand_size-1,))
return hint_flat_id, hinter_and_guesser_flat_rest_of_hand, hinter_and_guesser_flat_rest_of_hand
def p_mutual_exclusice_similarity(masks):
    """Sampling distributions for the "mutual_exclusive_similarity" case.

    Args:
        masks: the 5-tuple of card masks built by the enclosing reset
            routine. Only the fourth entry (cards sharing at least one
            feature with the target, target excluded) and the fifth entry
            (cards sharing no feature with the target) are read.

    Returns:
        A pair ``(hint distribution, rest-of-hand distribution)``: the
        no-shared-feature mask divided by its sum, and the shared-feature
        mask divided by its sum.
    """
    _, _, _, similar_feature_exclude_target_mask, non_similar_feature_mask = masks
    p_hint = non_similar_feature_mask / jnp.sum(non_similar_feature_mask)
    # NOTE(review): the caller builds the shared-feature mask by *adding*
    # per-feature matches, so with more than two features an entry can exceed
    # 1 and such cards get proportionally more weight -- confirm intended.
    p_hinter_and_guesser_rest_of_hand = (
        similar_feature_exclude_target_mask
        / jnp.sum(similar_feature_exclude_target_mask)
    )
    # No debug print here: this helper runs inside a jitted function, where a
    # Python print fires only at trace time and shows tracers, not values.
    return p_hint, p_hinter_and_guesser_rest_of_hand

def shuffle_and_index(rng, players_hands):
def set_single_hand(hand, index):
empty_hands = jnp.zeros(5, dtype=jnp.int32)
return empty_hands.at[index].set(hand)
"""
generates a permutation mapping for the hands of the players such that the target_card and hint_card can still be located after the permutation
returns permuted hands, hint_card_index and target_card_index in the permuted hands
returns permuted hands, hint_card_index and target_card_index in the permuted hands of hinter and guesser
"""
rngs = jax.random.split(rng, 2)
permutation_index = jax.vmap(jax.random.permutation, in_axes=(0, None))(rngs, 5)
permuted_hands = jax.vmap(set_single_hand, in_axes=(0, 0))(players_hands, permutation_index)
return permuted_hands, permutation_index[0, 0], permutation_index[0, 1]
return permuted_hands, permutation_index[0, 0], permutation_index[1, 0]

target_rng, hint_rng, hinter_hand_rngs, guesser_hand_rngs = jax.random.split(rng, 4)

# constants
# target randomisation
target_flat_id = jax.random.choice(target_rng, self.num_cards)
target_multi_id = jnp.array(jnp.unravel_index(target_flat_id, self.num_classes_per_feature))
card_multi_set = jnp.arange(self.num_cards).reshape(self.num_classes_per_feature)


if reset_mode == "exact_match":
hint_flat_id, hinter_flat_rest_of_hand, guesser_flat_rest_of_hand = exact_match(card_multi_set)
elif reset_mode == "similarity_match":
hint_flat_id, hinter_flat_rest_of_hand, guesser_flat_rest_of_hand = similarity_match(card_multi_set)
elif reset_mode == "mutual_exclusive":
hint_flat_id, hinter_flat_rest_of_hand, guesser_flat_rest_of_hand = mutual_exclusive(card_multi_set)
elif reset_mode == "mutual_exclusive_similarity":
hint_flat_id, hinter_flat_rest_of_hand, guesser_flat_rest_of_hand = mutual_exclusice_similarity(card_multi_set)
#copy card space to ensure env is not modified
card_space = jnp.arange(self.num_cards)
card_feature_space = self.card_feature_space

# generate mask for exact match and non_exact match
target_mask = jnp.where(target_flat_id
== card_space,
1,
0).flatten()
non_target_mask = 1 - target_mask

# generate mask for similar cards for a randomly selected feature
random_feature_of_interest = jax.random.choice(hint_rng, self.num_features)
random_similar_feature_mask = jnp.where(card_feature_space[:, random_feature_of_interest]
== target_multi_id[random_feature_of_interest],
1,
0).flatten()
random_similar_feature_exclude_target_mask = non_target_mask * random_similar_feature_mask

# generate mask for all non-similar cards for all features
similar_feature_mask = jnp.zeros(self.num_cards)
non_similar_feature_mask = jnp.ones(self.num_cards)
for feature_dim in range(self.num_features):
# + is logical or operation, * is logical and operation
similar_feature_mask = similar_feature_mask + jnp.where(card_feature_space[:, feature_dim]
== target_multi_id[feature_dim],
1,
0).flatten()
non_similar_feature_mask = non_similar_feature_mask * jnp.where(card_feature_space[:, feature_dim]
!= target_multi_id[feature_dim],
1,
0).flatten()
similar_feature_mask_exculde_target = similar_feature_mask * non_target_mask

masks = (target_mask, non_target_mask, random_similar_feature_exclude_target_mask, similar_feature_mask_exculde_target, non_similar_feature_mask)
p_reset_modes = {
"exact_match": p_exact_match,
"similarity_match": p_similarity_match,
"mutual_exclusive": p_mutual_exclusive,
"mutual_exclusive_similarity": p_mutual_exclusice_similarity,
}
if reset_mode in p_reset_modes:
p_hint, p_other = p_reset_modes[reset_mode](masks)
else:
raise ValueError("reset_mode is not supported")

hinter_flat_id = jax.random.choice(hint_rng,
card_space,
shape=(1,),
replace=replace,
p=p_hint)

hinter_flat_rest_of_hand = jax.random.choice(hinter_hand_rngs,
card_space,
shape=(self.hand_size-1,),
replace=replace,
p=p_other)
if reset_mode == "mutual_exclusive" or reset_mode == "mutual_exclusive_similarity":
guesser_flat_rest_of_hand = hinter_flat_rest_of_hand
else:
guesser_flat_rest_of_hand = jax.random.choice(guesser_hand_rngs,
card_space,
shape=(self.hand_size-1,),
replace=replace,
p=p_other)

hinter_hand = jnp.append(hint_flat_id, hinter_flat_rest_of_hand)
hinter_hand = jnp.append(hinter_flat_id, hinter_flat_rest_of_hand)
guesser_hand = jnp.append(target_flat_id, guesser_flat_rest_of_hand)

player_hands = jnp.stack((hinter_hand, guesser_hand))
print(player_hands.shape)
rngs = jnp.stack((hinter_hand_rngs, guesser_hand_rngs))
permuted_hands, hints, targets = jax.vmap(shuffle_and_index, in_axes=(0, None), out_axes=(0, 0, 0))(rngs, player_hands)
permuted_hands, hint_indices, target_indices = jax.vmap(shuffle_and_index, in_axes=(0, None), out_axes=(0, 0, 0))(rngs, player_hands)
state = State(
player_hands=permuted_hands, target=target_flat_id, hint=-1, guess=-1, turn=0
)

return jax.lax.stop_gradient(self.get_obs(state)), state, hints, targets



return jax.lax.stop_gradient(self.get_obs(state)), state, hint_indices, target_indices

@partial(jax.jit, static_argnums=[0])
def step_env(self, rng, state, actions):
Expand Down Expand Up @@ -359,8 +346,12 @@ def get_onehot_encodings(self):


if __name__ == "__main__":
jax.config.update("jax_disable_jit", True)
# jax.config.update("jax_disable_jit", True)
env = HintGuessGame()
rng = jax.random.PRNGKey(0)
_, state, _, _ = env.reset_for_eval(rng, reset_mode="exact_match")
rng = jax.random.PRNGKey(10)
# reset_modes: exact_match, similarity_match, mutual_exclusive, mutual_exclusive_similarity
_, state, hints, targets = env.reset_for_eval(rng, reset_mode="similarity_match", replace=True)
print(jnp.arange(9).reshape(3, 3))
print(state)
print(hints)
print(targets)

0 comments on commit feb82dc

Please sign in to comment.