adding gru_hidden_dim, correcting hanabi world state
amacrutherford committed Mar 30, 2024
1 parent 3c3d1b7 commit 5edbef0
Showing 6 changed files with 65 additions and 56 deletions.
13 changes: 7 additions & 6 deletions baselines/MAPPO/config/mappo_homogenous_ff_hanabi.yaml
@@ -1,9 +1,8 @@
"ENV_NAME": "hanabi"
"ENV_KWARGS": {}
"LR": 5.0e-4
"NUM_ENVS": 1024
"NUM_STEPS": 128
"TOTAL_TIMESTEPS": 1e10
"FC_DIM_SIZE": 128
"UPDATE_EPOCHS": 4
"NUM_MINIBATCHES": 4
"GAMMA": 0.99
@@ -14,11 +13,13 @@
"VF_COEF": 0.5
"MAX_GRAD_NORM": 0.5
"ACTIVATION": "relu"
"LAYER_WIDTH": 512
"ENV_NAME": "hanabi"
"SEED": 30
"ANNEAL_LR": False
"NUM_SEEDS": 2
"ENV_KWARGS": {}
"ANNEAL_LR": True

# WandB Params
"WANDB_MODE": "disabled"
"ENTITY": ""
"PROJECT": "jaxmarl-hanabi"
"ENTITY": "amacrutherford"
"PROJECT": "jaxmarl-hanabi"
4 changes: 2 additions & 2 deletions baselines/MAPPO/config/mappo_homogenous_rnn_hanabi.yaml
@@ -21,6 +21,6 @@
"ANNEAL_LR": True

# WandB Params
"WANDB_MODE": "disabled"
"ENTITY": ""
"WANDB_MODE": "online"
"ENTITY": "amacrutherford"
"PROJECT": "jaxmarl-hanabi"
74 changes: 41 additions & 33 deletions baselines/MAPPO/mappo_ff_hanabi.py
@@ -52,34 +52,33 @@ def world_state(self, obs, state):
"""
For each agent: [agent obs, own hand]
"""
all_obs = jnp.array([obs[agent] for agent in self._env.agents])
hands = state.player_hands.reshape((self._env.num_agents, -1))
return jnp.concatenate((all_obs, hands), axis=1)
return jnp.array([obs[agent] for agent in self._env.agents])
# hands = state.player_hands.reshape((self._env.num_agents, -1))
# return jnp.concatenate((all_obs, hands), axis=1)


def world_state_size(self):

return self._env.observation_space(self._env.agents[0]).n + 125 # NOTE hardcoded hand size
return self._env.observation_space(self._env.agents[0]).n #+ 125 # NOTE hardcoded hand size


class ActorFF(nn.Module):
action_dim: Sequence[int]
activation: str = "relu"
layer_dim: int = 512
config: Dict

@nn.compact
def __call__(self, x):
if self.activation == "relu":
if self.config["ACTIVATION"] == "relu":
activation = nn.relu
else:
activation = nn.tanh
obs, avail_actions = x
actor_mean = nn.Dense(
self.layer_dim, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
self.config["FC_DIM_SIZE"], kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
)(obs)
actor_mean = activation(actor_mean)
actor_mean = nn.Dense(
self.layer_dim, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
self.config["FC_DIM_SIZE"], kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
)(actor_mean)
actor_mean = activation(actor_mean)
action_logits = nn.Dense(
@@ -94,22 +93,21 @@ def __call__(self, x):


class CriticFF(nn.Module):
activation: str = "relu"
layer_dim: int = 512
config: Dict

@nn.compact
def __call__(self, x):
if self.activation == "relu":
if self.config["ACTIVATION"] == "relu":
activation = nn.relu
else:
activation = nn.tanh

critic = nn.Dense(
self.layer_dim, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
self.config["FC_DIM_SIZE"], kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
)(x)
critic = activation(critic)
critic = nn.Dense(
self.layer_dim, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
self.config["FC_DIM_SIZE"], kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
)(critic)
critic = activation(critic)
critic = nn.Dense(
@@ -119,6 +117,7 @@ def __call__(self, x):
return jnp.squeeze(critic, axis=-1)

class Transition(NamedTuple):
global_done: jnp.ndarray
done: jnp.ndarray
action: jnp.ndarray
value: jnp.ndarray
@@ -167,21 +166,17 @@ def train(rng):
# INIT NETWORK
actor_network = ActorFF(
env.action_space(env.agents[0]).n,
activation=config["ACTIVATION"],
layer_dim=config["LAYER_WIDTH"],
)
critic_network = CriticFF(
activation=config["ACTIVATION"],
layer_dim=config["LAYER_WIDTH"],
config,
)
critic_network = CriticFF(config)
rng, _rng_actor, _rng_critic = jax.random.split(rng, 3)
ac_init_x = (
jnp.zeros((env.observation_space(env.agents[0]).n,)),
jnp.zeros((env.action_space(env.agents[0]).n,)),
)
actor_network_params = actor_network.init(_rng_actor, ac_init_x)

cr_init_x = jnp.zeros((env.world_state_size(),))
cr_init_x = jnp.zeros((658,)) # NOTE hardcoded >:(

critic_network_params = critic_network.init(_rng_critic, cr_init_x)

@@ -244,7 +239,8 @@ def _env_step(runner_state, unused):
env_act = unbatchify(
action, env.agents, config["NUM_ENVS"], env.num_agents
)

env_act = jax.tree_map(lambda x: x.squeeze(), env_act)

# VALUE
world_state = last_obs["world_state"].swapaxes(0,1)
world_state = world_state.reshape((config["NUM_ACTORS"],-1))
@@ -259,7 +255,8 @@ def _env_step(runner_state, unused):
info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
transition = Transition(
done_batch,
jnp.tile(done["__all__"], env.num_agents),
last_done,
action.squeeze(),
value.squeeze(),
batchify(reward, env.agents, config["NUM_ACTORS"]).squeeze(),
@@ -288,7 +285,7 @@ def _calculate_gae(traj_batch, last_val):
def _get_advantages(gae_and_next_value, transition):
gae, next_value = gae_and_next_value
done, value, reward = (
transition.done,
transition.global_done,
transition.value,
transition.reward,
)
@@ -325,7 +322,8 @@ def _actor_loss_fn(actor_params, traj_batch, gae):
log_prob = pi.log_prob(traj_batch.action)

# CALCULATE ACTOR LOSS
ratio = jnp.exp(log_prob - traj_batch.log_prob)
logratio = log_prob - traj_batch.log_prob
ratio = jnp.exp(logratio)
gae = (gae - gae.mean()) / (gae.std() + 1e-8)
loss_actor1 = ratio * gae
loss_actor2 = (
@@ -337,13 +335,18 @@ def _actor_loss_fn(actor_params, traj_batch, gae):
* gae
)
loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
loss_actor = loss_actor.mean(where=(1 - traj_batch.done))
entropy = pi.entropy().mean(where=(1 - traj_batch.done))
loss_actor = loss_actor.mean()
entropy = pi.entropy().mean()

# debug
approx_kl = ((ratio - 1) - logratio).mean()
clip_frac = jnp.mean(jnp.abs(ratio - 1) > config["CLIP_EPS"])

actor_loss = (
loss_actor
- config["ENT_COEF"] * entropy
)
return actor_loss, (loss_actor, entropy)
return actor_loss, (loss_actor, entropy, ratio, approx_kl, clip_frac)

def _critic_loss_fn(critic_params, traj_batch, targets):
# RERUN NETWORK
@@ -356,7 +359,7 @@ def _critic_loss_fn(critic_params, traj_batch, targets):
value_losses = jnp.square(value - targets)
value_losses_clipped = jnp.square(value_pred_clipped - targets)
value_loss = (
0.5 * jnp.maximum(value_losses, value_losses_clipped).mean(where=(1 - traj_batch.done))
0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
)
critic_loss = config["VF_COEF"] * value_loss
return critic_loss, (value_loss)
@@ -377,8 +380,11 @@ def _critic_loss_fn(critic_params, traj_batch, targets):
loss_info = {
"total_loss": total_loss,
"actor_loss": actor_loss[0],
"critic_loss": critic_loss[0],
"value_loss": critic_loss[0],
"entropy": actor_loss[1][1],
"ratio": actor_loss[1][2],
"approx_kl": actor_loss[1][3],
"clip_frac": actor_loss[1][4],
}

return (actor_train_state, critic_train_state), loss_info
@@ -439,11 +445,12 @@ def _critic_loss_fn(critic_params, traj_batch, targets):
)
update_state, loss_info = jax.lax.scan(
_update_epoch, update_state, None, config["UPDATE_EPOCHS"]
)
loss_info = jax.tree_map(lambda x: x.mean(), loss_info)

)
train_states = update_state[0]
metric = traj_batch.info
loss_info["ratio_0"] = loss_info["ratio"].at[0,0].get()
loss_info = jax.tree_map(lambda x: x.mean(), loss_info)
metric["loss"] = loss_info
rng = update_state[-1]

def callback(metric):
@@ -454,6 +461,7 @@ def callback(metric):
"env_step": metric["update_steps"]
* config["NUM_ENVS"]
* config["NUM_STEPS"],
**metric["loss"],
}
)

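
Two things stand out in this file. The wrapper's world_state now returns only the stacked per-agent observations (the concatenation of each player's own hand is commented out, and world_state_size shrinks accordingly), and the actor, entropy, and value losses are averaged over all transitions instead of being masked with where=(1 - traj_batch.done). The new debug metrics are standard PPO diagnostics: approx_kl is the (ratio - 1) - log(ratio) low-variance estimator of the KL divergence between the old and current policies, and clip_frac is the fraction of samples whose ratio leaves the clipping band. A self-contained restatement of that computation (the function name and arguments are illustrative, not from the repository):

    import jax.numpy as jnp

    def ppo_diagnostics(log_prob_new, log_prob_old, clip_eps=0.2):
        """Sketch of the diagnostics added to _actor_loss_fn above."""
        logratio = log_prob_new - log_prob_old       # log pi_new(a|s) - log pi_old(a|s)
        ratio = jnp.exp(logratio)                    # importance ratio
        approx_kl = ((ratio - 1) - logratio).mean()  # low-variance KL estimate
        clip_frac = jnp.mean(jnp.abs(ratio - 1) > clip_eps)
        return approx_kl, clip_frac
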
22 changes: 11 additions & 11 deletions baselines/MAPPO/mappo_rnn_hanabi.py
@@ -53,14 +53,14 @@ def world_state(self, obs, state):
For each agent: [agent obs, own hand]
"""

all_obs = jnp.array([obs[agent] for agent in self._env.agents])
hands = state.player_hands.reshape((self._env.num_agents, -1))
return jnp.concatenate((all_obs, hands), axis=1)
return jnp.array([obs[agent] for agent in self._env.agents])
# hands = state.player_hands.reshape((self._env.num_agents, -1))
# return jnp.concatenate((all_obs, hands), axis=1)

@partial(jax.jit, static_argnums=0)
def world_state_size(self):

return self._env.observation_space(self._env.agents[0]).n + 125 # NOTE hardcoded hand size
return self._env.observation_space(self._env.agents[0]).n * self._env.num_agents # + 125 # NOTE hardcoded hand size

class ScannedRNN(nn.Module):
@functools.partial(
@@ -105,7 +105,7 @@ def __call__(self, hidden, x):
rnn_in = (embedding, dones)
hidden, embedding = ScannedRNN()(hidden, rnn_in)

actor_mean = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(
actor_mean = nn.Dense(self.config["GRU_HIDDEN_DIM"], kernel_init=orthogonal(2), bias_init=constant(0.0))(
embedding
)
actor_mean = nn.relu(actor_mean)
@@ -200,14 +200,14 @@ def train(rng):
jnp.zeros((1, config["NUM_ENVS"])),
jnp.zeros((1, config["NUM_ENVS"], env.action_space(env.agents[0]).n)),
)
ac_init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], 128)
ac_init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], config["GRU_HIDDEN_DIM"])
actor_network_params = actor_network.init(_rng_actor, ac_init_hstate, ac_init_x)
print('ac init x',ac_init_x)
cr_init_x = (
jnp.zeros((1, config["NUM_ENVS"], 658+125,)), # NOTE hardcoded
jnp.zeros((1, config["NUM_ENVS"], 658,)), # NOTE hardcoded
jnp.zeros((1, config["NUM_ENVS"])),
)
cr_init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], 128)
cr_init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], config["GRU_HIDDEN_DIM"])
critic_network_params = critic_network.init(_rng_critic, cr_init_hstate, cr_init_x)

if config["ANNEAL_LR"]:
@@ -234,7 +234,7 @@
tx=actor_tx,
)
critic_train_state = TrainState.create(
apply_fn=actor_network.apply,
apply_fn=critic_network.apply,
params=critic_network_params,
tx=critic_tx,
)
@@ -243,8 +243,8 @@
rng, _rng = jax.random.split(rng)
reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
obsv, env_state = jax.vmap(env.reset, in_axes=(0,))(reset_rng)
ac_init_hstate = ScannedRNN.initialize_carry(config["NUM_ACTORS"], 128)
cr_init_hstate = ScannedRNN.initialize_carry(config["NUM_ACTORS"], 128)
ac_init_hstate = ScannedRNN.initialize_carry(config["NUM_ACTORS"], config["GRU_HIDDEN_DIM"])
cr_init_hstate = ScannedRNN.initialize_carry(config["NUM_ACTORS"], config["GRU_HIDDEN_DIM"])

# TRAIN LOOP
def _update_step(update_runner_state, unused):
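
As in the feed-forward script, world_state here drops the hand concatenation, and world_state_size now returns the single-agent observation size multiplied by the number of agents. The file also corrects critic_train_state to use critic_network.apply rather than actor_network.apply, and it threads the new GRU_HIDDEN_DIM config entry through every ScannedRNN.initialize_carry call that previously hardcoded 128. A minimal stand-in showing only the shape contract (this is not the repository's ScannedRNN; init_carry below is a hypothetical zero-carry helper):

    import jax.numpy as jnp

    def init_carry(batch_size: int, hidden_dim: int):
        # hypothetical zero GRU carry of shape (batch, hidden)
        return jnp.zeros((batch_size, hidden_dim))

    config = {"NUM_ENVS": 1024, "GRU_HIDDEN_DIM": 128, "NUM_ACTORS": 2 * 1024}  # 2 Hanabi players assumed
    ac_init_hstate = init_carry(config["NUM_ENVS"], config["GRU_HIDDEN_DIM"])   # shape (1024, 128)
    ac_hstate = init_carry(config["NUM_ACTORS"], config["GRU_HIDDEN_DIM"])      # shape (2048, 128)
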
4 changes: 2 additions & 2 deletions baselines/MAPPO/mappo_rnn_mpe.py
@@ -199,14 +199,14 @@ def train(rng):
jnp.zeros((1, config["NUM_ENVS"], env.observation_space(env.agents[0]).shape[0])),
jnp.zeros((1, config["NUM_ENVS"])),
)
ac_init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], 128)
ac_init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], config["GRU_HIDDEN_DIM"])
actor_network_params = actor_network.init(_rng_actor, ac_init_hstate, ac_init_x)

cr_init_x = (
jnp.zeros((1, config["NUM_ENVS"], env.world_state_size(),)), # + env.observation_space(env.agents[0]).shape[0]
jnp.zeros((1, config["NUM_ENVS"])),
)
cr_init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], 128)
cr_init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], config["GRU_HIDDEN_DIM"])
critic_network_params = critic_network.init(_rng_critic, cr_init_hstate, cr_init_x)

if config["ANNEAL_LR"]:
4 changes: 2 additions & 2 deletions baselines/MAPPO/mappo_rnn_smax.py
@@ -222,13 +222,13 @@ def train(rng):
jnp.zeros((1, config["NUM_ENVS"])),
jnp.zeros((1, config["NUM_ENVS"], env.action_space(env.agents[0]).n)),
)
ac_init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], 128)
ac_init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], config["GRU_HIDDEN_DIM"])
actor_network_params = actor_network.init(_rng_actor, ac_init_hstate, ac_init_x)
cr_init_x = (
jnp.zeros((1, config["NUM_ENVS"], env.world_state_size(),)),
jnp.zeros((1, config["NUM_ENVS"])),
)
cr_init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], 128)
cr_init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], config["GRU_HIDDEN_DIM"])
critic_network_params = critic_network.init(_rng_critic, cr_init_hstate, cr_init_x)

if config["ANNEAL_LR"]:
