diff --git a/jaxmarl/environments/hanabi/hint_guess.py b/jaxmarl/environments/hanabi/hint_guess.py index c848adfa..dceebc94 100644 --- a/jaxmarl/environments/hanabi/hint_guess.py +++ b/jaxmarl/environments/hanabi/hint_guess.py @@ -47,7 +47,7 @@ def __init__( self.num_classes_per_feature = num_classes_per_feature self.num_cards = np.prod(self.num_classes_per_feature) self.matrix_obs = matrix_obs - self.feature_tree = [np.arange(n_c) for n_c in num_classes_per_feature] + self.card_feature_space = jnp.array(list(product(*[np.arange(n_c) for n_c in self.num_classes_per_feature]))) # generate the deck of one-hot encoded cards if card_encoding == "onehot": @@ -93,101 +93,39 @@ def reset(self, rng): ) return jax.lax.stop_gradient(self.get_obs(state)), state - @partial(jax.jit, static_argnums=[0, 2]) - def reset_for_eval(self, rng, reset_mode="exact_match"): + @partial(jax.jit, static_argnums=[0, 2, 3]) + def reset_for_eval(self, rng, reset_mode="exact_match", replace=True): - def exact_match(card_multi_set): - card_flat_set = card_multi_set.flatten() - hint_flat_id = target_flat_id - hinter_and_guesser_flat_hand_set = jnp.delete(card_flat_set, target_flat_id, assume_unique_indices=True) - hinter_flat_rest_of_hand = jax.random.choice(hinter_hand_rngs, - hinter_and_guesser_flat_hand_set, - shape=(self.hand_size-1,)) - guesser_flat_rest_of_hand = jax.random.choice(guesser_hand_rngs, - hinter_and_guesser_flat_hand_set, - shape=(self.hand_size-1,)) - return hint_flat_id, hinter_flat_rest_of_hand, guesser_flat_rest_of_hand + def p_exact_match(masks): + target_mask, non_target_mask, _, _, _ = masks + p_hint = target_mask/jnp.sum(target_mask) + p_hinter_and_guesser_rest_of_hand = non_target_mask/jnp.sum(non_target_mask) + return p_hint, p_hinter_and_guesser_rest_of_hand - def similarity_match(card_multi_set): - feature_of_interest = jax.random.choice(hint_rng, self.num_features) - target_index_of_interest = target_multi_id[feature_of_interest] - # note this hint_set also include target card, need to be removed - # print(feature_of_interest, target_index_of_interest) - hint_set = jax.lax.dynamic_index_in_dim(card_multi_set, - target_index_of_interest, - feature_of_interest, - keepdims=False) - # find the target card id in hint set after slicing from the feature of interest - target_id = jnp.concatenate((target_multi_id[:feature_of_interest], target_multi_id[feature_of_interest+1:])) - hint_flat_set = jnp.delete(hint_set, target_id, assume_unique_indices=True).flatten() - - non_similar_hand_set = copy.deepcopy(card_multi_set) - for feature_dim in range(self.num_features): - non_similar_hand_set = jnp.delete(non_similar_hand_set, - target_multi_id[feature_dim], - axis=feature_dim, - assume_unique_indices=True) - - non_similar_flat_hand_set = non_similar_hand_set.flatten() - hinter_flat_id = jax.random.choice(hint_rng, - hint_flat_set, - shape=(1,)) - hinter_flat_rest_of_hand = jax.random.choice(hinter_hand_rngs, - non_similar_flat_hand_set, - shape=(self.hand_size-1,)) - guesser_flat_rest_of_hand = jax.random.choice(guesser_hand_rngs, - non_similar_flat_hand_set, - shape=(self.hand_size-1,)) - return hinter_flat_id, hinter_flat_rest_of_hand, guesser_flat_rest_of_hand + def p_similarity_match(masks): + _, _, random_similar_feature_exclude_target_mask, _, non_similar_feature_mask = masks + hint_p = random_similar_feature_exclude_target_mask/jnp.sum(random_similar_feature_exclude_target_mask) + p_hinter_and_guesser_rest_of_hand = non_similar_feature_mask/jnp.sum(non_similar_feature_mask) + return hint_p, p_hinter_and_guesser_rest_of_hand - def mutual_exclusive(card_multi_set): - non_similar_hand_set = copy.deepcopy(card_multi_set) - for feature_dim in range(self.num_features): - non_similar_hand_set = jnp.delete(non_similar_hand_set, - target_multi_id[feature_dim], - axis=feature_dim, - assume_unique_indices=True) - - non_similar_flat_hand_set = non_similar_hand_set.flatten() + def p_mutual_exclusive(masks): + _, _, _, _, non_similar_feature_mask = masks + p_non_sim = non_similar_feature_mask/jnp.sum(non_similar_feature_mask) hint_flat_id = jax.random.choice(hint_rng, - non_similar_flat_hand_set, - shape=(1,)) - hinter_and_guesser_flat_rest_of_hand_set = jnp.delete(non_similar_flat_hand_set, hint_flat_id, assume_unique_indices=True) - # note the rest of hand of both players are the same, so use either of the rngs - hinter_and_guesser_flat_rest_of_hand = jax.random.choice(hinter_hand_rngs, - hinter_and_guesser_flat_rest_of_hand_set, - shape=(self.hand_size-1,)) - return hint_flat_id, hinter_and_guesser_flat_rest_of_hand, hinter_and_guesser_flat_rest_of_hand + card_space, + shape=(1,), + p=p_non_sim) + hint_mask = jax.nn.one_hot(x=hint_flat_id, num_classes=self.num_cards).flatten() # note this is also p_hint, as the chosen card has p=1 + hinter_and_guesser_rest_of_hand_mask = jnp.logical_and(non_similar_feature_mask, jnp.logical_not(hint_mask)) + p_hinter_and_guesser_rest_of_hand = hinter_and_guesser_rest_of_hand_mask/jnp.sum(hinter_and_guesser_rest_of_hand_mask) + return hint_mask, p_hinter_and_guesser_rest_of_hand - def mutual_exclusice_similarity(card_multi_set): - # the target will be included by the first slice, thus need to be removed - similar_cards_of_the_first_feature = jax.lax.dynamic_index_in_dim(card_multi_set, - target_multi_id[0], - 0, - keepdims=False) - target_id = target_multi_id[1:] - hinter_and_guesser_flat_rest_of_hand_set = jnp.delete(similar_cards_of_the_first_feature, target_id, assume_unique_indices=True).flatten() - - # later slices does include the target card - for feature_dim in range(1, self.num_features): - similar_cards = jax.lax.dynamic_index_in_dim(card_multi_set, - target_multi_id[feature_dim], - feature_dim, - keepdims=False) - hinter_and_guesser_flat_rest_of_hand_set = jnp.append(hinter_and_guesser_flat_rest_of_hand_set, - similar_cards.flatten()) - card_multi_set = jnp.delete(card_multi_set, - target_multi_id[feature_dim], - axis=feature_dim, - assume_unique_indices=True) - - hint_flat_set = card_multi_set.flatten() - hint_flat_id = jax.random.choice(hint_rng, hint_flat_set, shape=(1,)) - # note the rest of hand of both players are the same, so use either of the rngs - hinter_and_guesser_flat_rest_of_hand = jax.random.choice(hinter_hand_rngs, - hinter_and_guesser_flat_rest_of_hand_set, - shape=(self.hand_size-1,)) - return hint_flat_id, hinter_and_guesser_flat_rest_of_hand, hinter_and_guesser_flat_rest_of_hand + def p_mutual_exclusice_similarity(masks): + _, _, _, similar_feature_exclude_target_mask, non_similar_feature_mask = masks + p_hint = non_similar_feature_mask/jnp.sum(non_similar_feature_mask) + p_hinter_and_guesser_rest_of_hand = similar_feature_exclude_target_mask/jnp.sum(similar_feature_exclude_target_mask) + print(similar_feature_exclude_target_mask, p_hinter_and_guesser_rest_of_hand) + return p_hint, p_hinter_and_guesser_rest_of_hand def shuffle_and_index(rng, players_hands): def set_single_hand(hand, index): @@ -195,47 +133,96 @@ def set_single_hand(hand, index): return empty_hands.at[index].set(hand) """ generates a permutation mapping for the hands of the players such that the target_card and hint_card are tractable after the permutation - returns permuted hands, hint_card_index and target_card_index in the permuted hands + returns permuted hands, hint_card_index and target_card_index in the permuted hands of hinter and guesser """ rngs = jax.random.split(rng, 2) permutation_index = jax.vmap(jax.random.permutation, in_axes=(0, None))(rngs, 5) permuted_hands = jax.vmap(set_single_hand, in_axes=(0, 0))(players_hands, permutation_index) - return permuted_hands, permutation_index[0, 0], permutation_index[0, 1] + return permuted_hands, permutation_index[0, 0], permutation_index[1, 0] target_rng, hint_rng, hinter_hand_rngs, guesser_hand_rngs = jax.random.split(rng, 4) - # constants + # target randomisation target_flat_id = jax.random.choice(target_rng, self.num_cards) target_multi_id = jnp.array(jnp.unravel_index(target_flat_id, self.num_classes_per_feature)) - card_multi_set = jnp.arange(self.num_cards).reshape(self.num_classes_per_feature) - - if reset_mode == "exact_match": - hint_flat_id, hinter_flat_rest_of_hand, guesser_flat_rest_of_hand = exact_match(card_multi_set) - elif reset_mode == "similarity_match": - hint_flat_id, hinter_flat_rest_of_hand, guesser_flat_rest_of_hand = similarity_match(card_multi_set) - elif reset_mode == "mutual_exclusive": - hint_flat_id, hinter_flat_rest_of_hand, guesser_flat_rest_of_hand = mutual_exclusive(card_multi_set) - elif reset_mode == "mutual_exclusive_similarity": - hint_flat_id, hinter_flat_rest_of_hand, guesser_flat_rest_of_hand = mutual_exclusice_similarity(card_multi_set) + #copy card space to ensure env is not modified + card_space = jnp.arange(self.num_cards) + card_feature_space = self.card_feature_space + + # generate mask for exact match and non_exact match + target_mask = jnp.where(target_flat_id + == card_space, + 1, + 0).flatten() + non_target_mask = 1 - target_mask + + # generate mask for similar cards for a randomly selected feature + random_feature_of_interest = jax.random.choice(hint_rng, self.num_features) + random_similar_feature_mask = jnp.where(card_feature_space[:, random_feature_of_interest] + == target_multi_id[random_feature_of_interest], + 1, + 0).flatten() + random_similar_feature_exclude_target_mask = non_target_mask * random_similar_feature_mask + + # generate mask for all non-similar cards for all features + similar_feature_mask = jnp.zeros(self.num_cards) + non_similar_feature_mask = jnp.ones(self.num_cards) + for feature_dim in range(self.num_features): + # + is logical or operation, * is logical and operation + similar_feature_mask = similar_feature_mask + jnp.where(card_feature_space[:, feature_dim] + == target_multi_id[feature_dim], + 1, + 0).flatten() + non_similar_feature_mask = non_similar_feature_mask * jnp.where(card_feature_space[:, feature_dim] + != target_multi_id[feature_dim], + 1, + 0).flatten() + similar_feature_mask_exculde_target = similar_feature_mask * non_target_mask + + masks = (target_mask, non_target_mask, random_similar_feature_exclude_target_mask, similar_feature_mask_exculde_target, non_similar_feature_mask) + p_reset_modes = { + "exact_match": p_exact_match, + "similarity_match": p_similarity_match, + "mutual_exclusive": p_mutual_exclusive, + "mutual_exclusive_similarity": p_mutual_exclusice_similarity, + } + if reset_mode in p_reset_modes: + p_hint, p_other = p_reset_modes[reset_mode](masks) else: raise ValueError("reset_mode is not supported") + + hinter_flat_id = jax.random.choice(hint_rng, + card_space, + shape=(1,), + replace=replace, + p=p_hint) + + hinter_flat_rest_of_hand = jax.random.choice(hinter_hand_rngs, + card_space, + shape=(self.hand_size-1,), + replace=replace, + p=p_other) + if reset_mode == "mutual_exclusive" or reset_mode == "mutual_exclusive_similarity": + guesser_flat_rest_of_hand = hinter_flat_rest_of_hand + else: + guesser_flat_rest_of_hand = jax.random.choice(guesser_hand_rngs, + card_space, + shape=(self.hand_size-1,), + replace=replace, + p=p_other) - hinter_hand = jnp.append(hint_flat_id, hinter_flat_rest_of_hand) + hinter_hand = jnp.append(hinter_flat_id, hinter_flat_rest_of_hand) guesser_hand = jnp.append(target_flat_id, guesser_flat_rest_of_hand) player_hands = jnp.stack((hinter_hand, guesser_hand)) - print(player_hands.shape) rngs = jnp.stack((hinter_hand_rngs, guesser_hand_rngs)) - permuted_hands, hints, targets = jax.vmap(shuffle_and_index, in_axes=(0, None), out_axes=(0, 0, 0))(rngs, player_hands) + permuted_hands, hint_indices, target_indices = jax.vmap(shuffle_and_index, in_axes=(0, None), out_axes=(0, 0, 0))(rngs, player_hands) state = State( player_hands=permuted_hands, target=target_flat_id, hint=-1, guess=-1, turn=0 ) - return jax.lax.stop_gradient(self.get_obs(state)), state, hints, targets - - - + return jax.lax.stop_gradient(self.get_obs(state)), state, hint_indices, target_indices @partial(jax.jit, static_argnums=[0]) def step_env(self, rng, state, actions): @@ -359,8 +346,12 @@ def get_onehot_encodings(self): if __name__ == "__main__": - jax.config.update("jax_disable_jit", True) + # jax.config.update("jax_disable_jit", True) env = HintGuessGame() - rng = jax.random.PRNGKey(0) - _, state, _, _ = env.reset_for_eval(rng, reset_mode="exact_match") + rng = jax.random.PRNGKey(10) + # reset_modes: exact_match, similarity_match, mutual_exclusive, mutual_exclusive_similarity + _, state, hints, targets = env.reset_for_eval(rng, reset_mode="similarity_match", replace=True) + print(jnp.arange(9).reshape(3, 3)) print(state) + print(hints) + print(targets)