better use of flashbax in shaq
mttga committed Dec 7, 2023
1 parent 90d5e69 commit 3ddb2f9
Showing 2 changed files with 50 additions and 28 deletions.
baselines/QLearning/README.md (2 changes: 1 addition & 1 deletion)
@@ -23,6 +23,7 @@ pip install -r requirements/requirements-qlearning.txt
❗The implementations were tested in the following environments:
- MPE
- SMAX
- Hanabi
```

## 🔎 Implementation Details
@@ -57,7 +58,6 @@ If you have cloned JaxMARL and you are in the repository root, you can run the a
python baselines/QLearning/iql.py +alg=iql_mpe +env=mpe_speaker_listener
# VDN with MPE spread
python baselines/QLearning/vdn.py +alg=vdn_mpe +env=mpe_spread
python baselines/QLearning/qmix.py +alg=qmix_mpe +env=mpe_spread
# QMix with SMAX
python baselines/QLearning/qmix.py +alg=qmix_smax +env=smax
# QMix with hanabi
baselines/QLearning/shaq.py (76 changes: 49 additions & 27 deletions)
@@ -41,7 +41,7 @@

from jaxmarl import make
from jaxmarl.environments.smax import map_name_to_scenario
from jaxmarl.wrappers.baselines import CTRolloutManager
from jaxmarl.wrappers.baselines import LogWrapper, SMAXLogWrapper, CTRolloutManager

class ScannedRNN(nn.Module):

@@ -299,6 +299,7 @@ class Transition(NamedTuple):
actions: dict
rewards: dict
dones: dict
infos: dict

def make_train(config, env):

@@ -323,20 +324,21 @@ def _env_sample_step(env_state, unused):
key_a = jax.random.split(key_a, env.num_agents)
actions = {agent: wrapped_env.batch_sample(key_a[i], agent) for i, agent in enumerate(env.agents)}
obs, env_state, rewards, dones, infos = wrapped_env.batch_step(key_s, env_state, actions)
transition = Transition(obs, actions, rewards, dones)
transition = Transition(obs, actions, rewards, dones, infos)
return env_state, transition
_, sample_traj = jax.lax.scan(
_env_sample_step, env_state, None, config["NUM_STEPS"]
)
sample_traj_unbatched = jax.tree_map(lambda x: x[:, 0], sample_traj) # remove the NUM_ENV dim
buffer = fbx.make_flat_buffer(
max_length=config['BUFFER_SIZE'],
min_length=config['BUFFER_BATCH_SIZE'],
buffer = fbx.make_trajectory_buffer(
max_length_time_axis=config['BUFFER_SIZE']//config['NUM_ENVS'],
min_length_time_axis=config['BUFFER_BATCH_SIZE'],
sample_batch_size=config['BUFFER_BATCH_SIZE'],
add_sequences=True,
add_batch_size=None,
add_batch_size=config['NUM_ENVS'],
sample_sequence_length=1,
period=1,
)
buffer_state = buffer.init(sample_traj_unbatched)
buffer_state = buffer.init(sample_traj_unbatched)


# INIT NETWORK
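
For orientation, here is a minimal, self-contained sketch of the flashbax trajectory-buffer API that the block above switches to (`make_trajectory_buffer` / `init` / `add` / `sample`). The sizes and the dummy transition are illustrative assumptions, not values from the repository config; the commit itself stores each whole rollout as a single buffer entry by adding a dummy sequence dimension, which is handled further down.

```python
import jax
import jax.numpy as jnp
import flashbax as fbx

# Illustrative sizes only (the real ones come from the algorithm config).
NUM_ENVS, NUM_STEPS, BUFFER_SIZE, BATCH_SIZE = 8, 64, 5000, 32

buffer = fbx.make_trajectory_buffer(
    max_length_time_axis=BUFFER_SIZE // NUM_ENVS,  # capacity per batch row, along the time axis
    min_length_time_axis=BATCH_SIZE,               # timesteps required before sampling
    sample_batch_size=BATCH_SIZE,                  # rows per sampled batch
    add_batch_size=NUM_ENVS,                       # leading dim expected by every add() call
    sample_sequence_length=1,                      # length of each sampled slice
    period=1,
)

# init() takes a single unbatched transition that fixes the pytree structure.
dummy_transition = {"obs": jnp.zeros(10), "reward": jnp.zeros(())}
buffer_state = buffer.init(dummy_transition)

# add() expects (add_batch_size, time, ...); sample() returns .experience with
# shape (sample_batch_size, sample_sequence_length, ...).
batch = {
    "obs": jnp.zeros((NUM_ENVS, NUM_STEPS, 10)),
    "reward": jnp.zeros((NUM_ENVS, NUM_STEPS)),
}
buffer_state = buffer.add(buffer_state, batch)
if buffer.can_sample(buffer_state):
    sampled = buffer.sample(buffer_state, jax.random.PRNGKey(0)).experience
```
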
@@ -458,7 +460,7 @@ def _env_step(step_state, unused):

# STEP ENV
obs, env_state, rewards, dones, infos = wrapped_env.batch_step(key_s, env_state, actions)
transition = Transition(last_obs, actions, rewards, dones)
transition = Transition(last_obs, actions, rewards, dones, infos)

step_state = (params, env_state, obs, dones, hstate, rng, t+1)
return step_state, transition
@@ -486,7 +488,10 @@ def _env_step(step_state, unused):
)

# BUFFER UPDATE: save the collected trajectory in the buffer
buffer_traj_batch = jax.tree_util.tree_map(lambda x:jnp.swapaxes(x, 0, 1), traj_batch) # put the batch size (num envs) in first axis
buffer_traj_batch = jax.tree_util.tree_map(
lambda x:jnp.swapaxes(x, 0, 1)[:, np.newaxis], # put the batch dim first and add a dummy sequence dim
traj_batch
) # (num_envs, 1, time_steps, ...)
buffer_state = buffer.add(buffer_state, buffer_traj_batch)

# LEARN PHASE
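
The swap-axes plus dummy-dimension step in the add path above (and its inverse in the sample path a few hunks below) is pure shape bookkeeping between the rollout layout `(time, num_envs, ...)` and the buffer layout; a tiny sketch with assumed shapes:

```python
import jax.numpy as jnp
import numpy as np

NUM_STEPS, NUM_ENVS, OBS_DIM = 25, 8, 5           # illustrative only
traj = jnp.zeros((NUM_STEPS, NUM_ENVS, OBS_DIM))  # rollout layout: (time, num_envs, ...)

# add path: env/batch dim first, plus a dummy sequence dim of length 1
to_buffer = jnp.swapaxes(traj, 0, 1)[:, np.newaxis]       # (num_envs, 1, time, ...)

# sample path: drop the dummy sequence dim and put time first again
experience = to_buffer                                     # stand-in for buffer.sample(...).experience
learn_traj = jnp.swapaxes(experience[:, 0], 0, 1)          # (time, num_envs, ...)
assert learn_traj.shape == traj.shape
```
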
@@ -582,8 +587,11 @@ def _td_lambda_target(ret, values):

# sample a batched trajectory from the buffer and set the time step dim in first axis
rng, _rng = jax.random.split(rng)
learn_traj = buffer.sample(buffer_state, _rng).experience.first # (batch_size, max_time_steps, ...)
learn_traj = jax.tree_map(lambda x: jnp.swapaxes(x, 0, 1), learn_traj) # (max_time_steps, batch_size, ...)
learn_traj = buffer.sample(buffer_state, _rng).experience # (batch_size, 1, max_time_steps, ...)
learn_traj = jax.tree_map(
lambda x: jnp.swapaxes(x[:, 0], 0, 1), # remove the dummy sequence dim (1) and swap batch and temporal dims
learn_traj
) # (max_time_steps, batch_size, ...)
if config["PARAMETERS_SHARING"]:
init_hs = ScannedRNN.initialize_carry(config['AGENT_HIDDEN_DIM'], len(env.agents)*config["BUFFER_BATCH_SIZE"]) # (n_agents*batch_size, hs_size)
else:
@@ -634,22 +642,29 @@ def _td_lambda_target(ret, values):
'timesteps': time_state['timesteps']*config['NUM_ENVS'],
'updates' : time_state['updates'],
'loss': loss,
'rewards': jax.tree_util.tree_map(lambda x: jnp.sum(x, axis=0).mean(), traj_batch.rewards)
'rewards': jax.tree_util.tree_map(lambda x: jnp.sum(x, axis=0).mean(), traj_batch.rewards),
'eps': explorer.get_epsilon(time_state['timesteps'])
}
metrics.update(test_metrics) # add the test metrics dictionary
metrics['test_metrics'] = test_metrics # add the test metrics dictionary

if config.get('WANDB_ONLINE_REPORT', False):
def callback(metrics):
def callback(metrics, infos):
info_metrics = {
k:v[...,0][infos["returned_episode"][..., 0]].mean()
for k,v in infos.items() if k!="returned_episode"
}
wandb.log(
{
"returns": metrics['rewards']['__all__'].mean(),
"test_returns": metrics['test_returns']['__all__'].mean(),
"timestep": metrics['timesteps'],
"updates": metrics['updates'],
"loss": metrics['loss'],
'epsilon': metrics['eps'],
**info_metrics,
**{k:v.mean() for k, v in metrics['test_metrics'].items()}
}
)
jax.debug.callback(callback, metrics)
jax.debug.callback(callback, metrics, traj_batch.infos)
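
The callback now also receives `traj_batch.infos` and averages each info field only over entries where an episode actually finished, using the `returned_episode` mask. A small sketch of that boolean masking on made-up arrays; the `(time, num_envs, num_agents)` layout is an assumption for illustration:

```python
import jax.numpy as jnp

# Made-up info arrays shaped (time, num_envs, num_agents).
returned_episode = jnp.zeros((3, 2, 4), dtype=bool).at[1, 0].set(True).at[2, 1].set(True)
episode_returns  = jnp.zeros((3, 2, 4)).at[1, 0].set(3.0).at[2, 1].set(5.0)

# Take one agent's copy and keep only the entries where an episode finished.
mean_return = episode_returns[..., 0][returned_episode[..., 0]].mean()  # (3.0 + 5.0) / 2 = 4.0
```
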

runner_state = (
train_state_agent,
@@ -676,10 +691,10 @@ def _greedy_env_step(step_state, unused):
obs_ = jax.tree_map(lambda x: x[np.newaxis, :], obs_)
dones_ = jax.tree_map(lambda x: x[np.newaxis, :], last_dones)
hstate, q_vals = homogeneous_pass(params, hstate, obs_, dones_)
actions = jax.tree_util.tree_map(lambda q, valid_idx: jnp.argmax(q.squeeze(0)[..., valid_idx], axis=-1), q_vals, wrapped_env.valid_actions)
actions = jax.tree_util.tree_map(lambda q, valid_idx: jnp.argmax(q.squeeze(0)[..., valid_idx], axis=-1), q_vals, test_env.valid_actions)
obs, env_state, rewards, dones, infos = test_env.batch_step(key_s, env_state, actions)
step_state = (params, env_state, obs, dones, hstate, rng)
return step_state, (rewards, dones)
return step_state, (rewards, dones, infos)
rng, _rng = jax.random.split(rng)
init_obs, env_state = test_env.batch_reset(_rng)
init_dones = {agent:jnp.zeros((config["NUM_TEST_EPISODES"]), dtype=bool) for agent in env.agents+['__all__']}
@@ -696,23 +711,25 @@ def _greedy_env_step(step_state, unused):
hstate,
_rng,
)
step_state, rews_dones = jax.lax.scan(
step_state, (rewards, dones, infos) = jax.lax.scan(
_greedy_env_step, step_state, None, config["NUM_STEPS"]
)
# compute the episode returns of the first episode that is done for each parallel env
# compute the metrics of the first episode that is done for each parallel env
def first_episode_returns(rewards, dones):
first_done = jax.lax.select(jnp.argmax(dones)==0., dones.size, jnp.argmax(dones))
first_episode_mask = jnp.where(jnp.arange(dones.size) <= first_done, True, False)
return jnp.where(first_episode_mask, rewards, 0.).sum()
all_dones = rews_dones[1]['__all__']
returns = jax.tree_map(lambda r: jax.vmap(first_episode_returns, in_axes=1)(r, all_dones), rews_dones[0])
all_dones = dones['__all__']
first_returns = jax.tree_map(lambda r: jax.vmap(first_episode_returns, in_axes=1)(r, all_dones), rewards)
first_infos = jax.tree_map(lambda i: jax.vmap(first_episode_returns, in_axes=1)(i[..., 0], all_dones), infos)
metrics = {
'test_returns': returns # episode returns
'test_returns': first_returns['__all__'],# episode returns
**{'test_'+k:v for k,v in first_infos.items()}
}
if config.get('VERBOSE', False):
def callback(timestep, val):
print(f"Timestep: {timestep}, return: {val}")
jax.debug.callback(callback, time_state['timesteps']*config['NUM_ENVS'], returns['__all__'].mean())
jax.debug.callback(callback, time_state['timesteps']*config['NUM_ENVS'], first_returns['__all__'].mean())
return metrics
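
`first_episode_returns` keeps everything up to and including the first done flag and sums it, so later (partial) episodes in the same test rollout are ignored. A tiny worked example, reusing the same function body on hand-written inputs:

```python
import jax
import jax.numpy as jnp

def first_episode_returns(rewards, dones):
    # mask everything after the first done flag, then sum
    first_done = jax.lax.select(jnp.argmax(dones) == 0., dones.size, jnp.argmax(dones))
    first_episode_mask = jnp.where(jnp.arange(dones.size) <= first_done, True, False)
    return jnp.where(first_episode_mask, rewards, 0.).sum()

dones   = jnp.array([False, False, True, False, True])
rewards = jnp.array([1.0,   2.0,   3.0,  10.0,  10.0])
print(first_episode_returns(rewards, dones))  # 6.0 -> only the first episode counts
```
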

time_state = {
@@ -754,9 +771,14 @@ def main(config):
alg_name = f'shaq_{"ps" if config["alg"].get("PARAMETERS_SHARING", True) else "ns"}'

# smax init needs a scenario
if 'SMAC' in env_name:
if 'smax' in env_name.lower():
config['env']['ENV_KWARGS']['scenario'] = map_name_to_scenario(config['env']['MAP_NAME'])
env_name = 'jaxmarl_'+config['env']['MAP_NAME']
env_name = f"{config['env']['ENV_NAME']}_{config['env']['MAP_NAME']}"
env = make(config["env"]["ENV_NAME"], **config['env']['ENV_KWARGS'])
env = SMAXLogWrapper(env)
else:
env = make(config["env"]["ENV_NAME"], **config['env']['ENV_KWARGS'])
env = LogWrapper(env)

env = make(config["env"]["ENV_NAME"], **config['env']['ENV_KWARGS'])
config["alg"]["NUM_STEPS"] = config["alg"].get("NUM_STEPS", env.max_steps) # default steps defined by the env
