Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added Jitable eval_reset for test cases generation #78

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions .github/workflows/docker-tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
name: Run tests on Docker Image
on: [push, pull_request]

jobs:

build:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- name: Build the Docker image
run: make build
- name: Run tests
run: make workflow-test
34 changes: 0 additions & 34 deletions .github/workflows/tests.yml

This file was deleted.

2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,5 @@ tmp/
wandb/
outputs/
models/
.devcontainer/
.gitignore
6 changes: 3 additions & 3 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ RUN export XLA_PYTHON_CLIENT_PREALLOCATE=false
RUN export XLA_PYTHON_CLIENT_MEM_FRACTION=0.25
RUN export TF_FORCE_GPU_ALLOW_GROWTH=true

# if you want jupyter
RUN pip install pip install jupyterlab
# Uncomment below if you want jupyter
# RUN pip install jupyterlab

#for secrets and debug
ENV WANDB_API_KEY=""
ENV WANDB_ENTITY=""
RUN git config --global --add safe.directory /home/workdir
RUN git config --global --add safe.directory /home/workdir
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ endif
BASE_FLAGS=-it --rm -v ${PWD}:/home/workdir --shm-size 20G
RUN_FLAGS=$(GPUS) $(BASE_FLAGS)

DOCKER_IMAGE_NAME = jaxmarl
DOCKER_IMAGE_NAME = jaxmarl-cf
IMAGE = $(DOCKER_IMAGE_NAME):latest
DOCKER_RUN=docker run $(RUN_FLAGS) $(IMAGE)
USE_CUDA = $(if $(GPUS),true,false)
Expand All @@ -26,3 +26,6 @@ run:
test:
$(DOCKER_RUN) /bin/bash -c "pytest ./tests/"

workflow-test:
# without -it flag
docker run --rm -v ${PWD}:/home/workdir --shm-size 20G $(IMAGE) /bin/bash -c "pytest ./tests/"
4 changes: 2 additions & 2 deletions baselines/IPPO/config/ippo_ff_hint_guess.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@

# WandB Params
"WANDB_MODE": "online"
"ENTITY": "mttga"
"PROJECT": "hint_guess"
"ENTITY": "clf26"
"PROJECT": "action-feature"
4 changes: 2 additions & 2 deletions baselines/IPPO/ippo_ff_hint_guess.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,8 +423,8 @@ def wrapped_make_train():
@hydra.main(version_base=None, config_path="config", config_name="ippo_ff_hint_guess")
def main(config):
config = OmegaConf.to_container(config)
#single_run(config)
tune(config)
single_run(config)
# tune(config)


if __name__ == "__main__":
Expand Down
147 changes: 146 additions & 1 deletion jaxmarl/environments/hanabi/hint_guess.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from flax import struct
from jaxmarl.environments.multi_agent_env import MultiAgentEnv
from gymnax.environments.spaces import Discrete
import copy


@struct.dataclass
Expand Down Expand Up @@ -46,6 +47,7 @@ def __init__(
self.num_classes_per_feature = num_classes_per_feature
self.num_cards = np.prod(self.num_classes_per_feature)
self.matrix_obs = matrix_obs
self.card_feature_space = jnp.array(list(product(*[np.arange(n_c) for n_c in self.num_classes_per_feature])))

# generate the deck of one-hot encoded cards
if card_encoding == "onehot":
Expand Down Expand Up @@ -76,7 +78,7 @@ def reset(self, rng):
self.hand_size,
),
)

# every agent sees the hands in different order
_rngs = jax.random.split(rng_hands, self.num_agents)
permuted_hands = jax.vmap(
Expand All @@ -91,6 +93,137 @@ def reset(self, rng):
)
return jax.lax.stop_gradient(self.get_obs(state)), state

@partial(jax.jit, static_argnums=[0, 2, 3])
def reset_for_eval(self, rng, reset_mode="exact_match", replace=True):
    """Reset the game into a constructed evaluation scenario.

    A target card is sampled uniformly from the deck. The hinter's hint
    card and the remaining cards of both hands are then sampled from
    distributions that depend on ``reset_mode``:

    * ``"exact_match"``: the hint card is the target card itself; the
      remaining cards are drawn from all non-target cards.
    * ``"similarity_match"``: the hint card shares one randomly chosen
      feature with the target (and is not the target); the remaining
      cards share no feature with the target.
    * ``"mutual_exclusive"``: the hint card shares no feature with the
      target; the remaining cards share no feature with the target and
      differ from the hint card. Both players get the same remaining
      cards.
    * ``"mutual_exclusive_similarity"``: the hint card shares no feature
      with the target; the remaining cards share at least one feature
      with the target (and are not the target). Both players get the
      same remaining cards.

    Args:
        rng: JAX PRNG key.
        reset_mode: One of the four mode names above (static under jit).
        replace: Forwarded to ``jax.random.choice`` when sampling cards
            (static under jit).

    Returns:
        A tuple ``(obs, state, hint_indices, target_indices)`` where
        ``obs`` is ``self.get_obs(state)`` with gradients stopped,
        ``state`` holds the shuffled hands, and ``hint_indices`` /
        ``target_indices`` give, for every shuffled copy of the hands,
        the position of the hint card in the hinter's hand and of the
        target card in the guesser's hand.

    Raises:
        ValueError: If ``reset_mode`` is not a supported mode.
    """

    def p_exact_match(masks):
        # Hint is the target; everything else comes from non-target cards.
        target_mask, non_target_mask, _, _, _ = masks
        p_hint = target_mask / jnp.sum(target_mask)
        p_rest_of_hand = non_target_mask / jnp.sum(non_target_mask)
        return p_hint, p_rest_of_hand

    def p_similarity_match(masks):
        # Hint shares one random feature with the target; the rest of the
        # hand shares none.
        _, _, random_similar_exclude_target_mask, _, non_similar_feature_mask = masks
        p_hint = random_similar_exclude_target_mask / jnp.sum(random_similar_exclude_target_mask)
        p_rest_of_hand = non_similar_feature_mask / jnp.sum(non_similar_feature_mask)
        return p_hint, p_rest_of_hand

    def p_mutual_exclusive(masks):
        # Fix the hint card up front so it can be excluded from the
        # distribution used for the rest of the hand.
        _, _, _, _, non_similar_feature_mask = masks
        p_non_similar = non_similar_feature_mask / jnp.sum(non_similar_feature_mask)
        hint_flat_id = jax.random.choice(
            hint_rng,
            card_space,
            shape=(1,),
            p=p_non_similar,
        )
        # The one-hot mask doubles as p_hint: the chosen card has p=1.
        hint_mask = jax.nn.one_hot(x=hint_flat_id, num_classes=self.num_cards).flatten()
        rest_of_hand_mask = jnp.logical_and(
            non_similar_feature_mask, jnp.logical_not(hint_mask)
        )
        p_rest_of_hand = rest_of_hand_mask / jnp.sum(rest_of_hand_mask)
        return hint_mask, p_rest_of_hand

    def p_mutual_exclusive_similarity(masks):
        # Hint shares nothing with the target; the rest of the hand does.
        _, _, _, similar_feature_exclude_target_mask, non_similar_feature_mask = masks
        p_hint = non_similar_feature_mask / jnp.sum(non_similar_feature_mask)
        p_rest_of_hand = similar_feature_exclude_target_mask / jnp.sum(
            similar_feature_exclude_target_mask
        )
        return p_hint, p_rest_of_hand

    def shuffle_and_index(rng, players_hands):
        """Shuffle both hands while keeping track of their first cards.

        Generates one permutation per hand such that the hint card and
        target card (stored at position 0 of the hinter's and guesser's
        hand respectively) remain traceable after the shuffle.

        Returns the permuted hands, the new index of the hint card in the
        hinter's hand and the new index of the target card in the
        guesser's hand.
        """

        def set_single_hand(hand, index):
            # Scatter: card i of the original hand lands at index[i].
            empty_hand = jnp.zeros(self.hand_size, dtype=jnp.int32)
            return empty_hand.at[index].set(hand)

        rngs = jax.random.split(rng, 2)
        permutation_index = jax.vmap(
            lambda key: jax.random.permutation(key, self.hand_size)
        )(rngs)
        permuted_hands = jax.vmap(set_single_hand, in_axes=(0, 0))(
            players_hands, permutation_index
        )
        return permuted_hands, permutation_index[0, 0], permutation_index[1, 0]

    target_rng, hint_rng, hinter_hand_rngs, guesser_hand_rngs = jax.random.split(rng, 4)
    # NOTE(review): hint_rng is consumed more than once below (feature
    # selection, hint sampling, and inside p_mutual_exclusive), and the two
    # hand keys are reused for shuffling. Kept as-is to preserve the sampled
    # values; consider splitting dedicated keys.

    # target randomisation
    target_flat_id = jax.random.choice(target_rng, self.num_cards)
    target_multi_id = jnp.array(
        jnp.unravel_index(target_flat_id, self.num_classes_per_feature)
    )

    # flat card ids and their per-feature class ids
    card_space = jnp.arange(self.num_cards)
    card_feature_space = self.card_feature_space

    # masks for the exact target card and for every other card
    target_mask = jnp.where(target_flat_id == card_space, 1, 0).flatten()
    non_target_mask = 1 - target_mask

    # mask of cards sharing one randomly selected feature with the target
    random_feature_of_interest = jax.random.choice(hint_rng, self.num_features)
    random_similar_feature_mask = jnp.where(
        card_feature_space[:, random_feature_of_interest]
        == target_multi_id[random_feature_of_interest],
        1,
        0,
    ).flatten()
    random_similar_feature_exclude_target_mask = non_target_mask * random_similar_feature_mask

    # masks of cards sharing at least one / no feature with the target
    similar_feature_mask = jnp.zeros(self.num_cards)
    non_similar_feature_mask = jnp.ones(self.num_cards)
    for feature_dim in range(self.num_features):
        shares_feature = jnp.where(
            card_feature_space[:, feature_dim] == target_multi_id[feature_dim],
            1,
            0,
        ).flatten()
        # maximum is a logical OR on 0/1 masks; a plain sum would give
        # cards that share several features a larger sampling weight.
        similar_feature_mask = jnp.maximum(similar_feature_mask, shares_feature)
        # product is a logical AND on 0/1 masks
        non_similar_feature_mask = non_similar_feature_mask * (1 - shares_feature)
    similar_feature_exclude_target_mask = similar_feature_mask * non_target_mask

    masks = (
        target_mask,
        non_target_mask,
        random_similar_feature_exclude_target_mask,
        similar_feature_exclude_target_mask,
        non_similar_feature_mask,
    )
    p_reset_modes = {
        "exact_match": p_exact_match,
        "similarity_match": p_similarity_match,
        "mutual_exclusive": p_mutual_exclusive,
        "mutual_exclusive_similarity": p_mutual_exclusive_similarity,
    }
    if reset_mode not in p_reset_modes:
        raise ValueError("reset_mode is not supported")
    p_hint, p_other = p_reset_modes[reset_mode](masks)

    hinter_flat_id = jax.random.choice(
        hint_rng,
        card_space,
        shape=(1,),
        replace=replace,
        p=p_hint,
    )

    hinter_flat_rest_of_hand = jax.random.choice(
        hinter_hand_rngs,
        card_space,
        shape=(self.hand_size - 1,),
        replace=replace,
        p=p_other,
    )
    if reset_mode in ("mutual_exclusive", "mutual_exclusive_similarity"):
        # Both players hold the same non-hint / non-target cards.
        guesser_flat_rest_of_hand = hinter_flat_rest_of_hand
    else:
        guesser_flat_rest_of_hand = jax.random.choice(
            guesser_hand_rngs,
            card_space,
            shape=(self.hand_size - 1,),
            replace=replace,
            p=p_other,
        )

    # Position 0 holds the hint card (hinter) and the target card (guesser);
    # shuffle_and_index reports where they end up after shuffling.
    hinter_hand = jnp.append(hinter_flat_id, hinter_flat_rest_of_hand)
    guesser_hand = jnp.append(target_flat_id, guesser_flat_rest_of_hand)

    player_hands = jnp.stack((hinter_hand, guesser_hand))
    rngs = jnp.stack((hinter_hand_rngs, guesser_hand_rngs))
    # One independently shuffled copy of both hands per key in rngs.
    permuted_hands, hint_indices, target_indices = jax.vmap(
        shuffle_and_index, in_axes=(0, None), out_axes=(0, 0, 0)
    )(rngs, player_hands)
    state = State(
        player_hands=permuted_hands, target=target_flat_id, hint=-1, guess=-1, turn=0
    )

    return jax.lax.stop_gradient(self.get_obs(state)), state, hint_indices, target_indices

@partial(jax.jit, static_argnums=[0])
def step_env(self, rng, state, actions):

Expand Down Expand Up @@ -210,3 +343,15 @@ def get_onehot_encodings(self):
[jnp.concatenate(combination) for combination in list(product(*encodings))]
)
return encodings


if __name__ == "__main__":
    # Manual smoke run of the evaluation reset.
    # jax.config.update("jax_disable_jit", True)
    demo_env = HintGuessGame()
    demo_key = jax.random.PRNGKey(10)
    # reset_modes: exact_match, similarity_match, mutual_exclusive, mutual_exclusive_similarity
    _, demo_state, hint_positions, target_positions = demo_env.reset_for_eval(
        demo_key,
        reset_mode="similarity_match",
        replace=True,
    )
    # 3x3 grid of the integers 0..8, printed for reference next to the state.
    print(jnp.arange(9).reshape(3, 3))
    for item in (demo_state, hint_positions, target_positions):
        print(item)
4 changes: 2 additions & 2 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# requirements are aligned with the nvcr.io/nvidia/jax:23.10-py3 image
jax==0.4.17
jaxlib==0.4.17
jax==0.4.17.*
jaxlib==0.4.17.*
flax==0.7.4
chex==0.1.84
optax==0.1.7
Expand Down
Loading