Hi everyone, hope you are doing well! I'm working on a research project with the DeepMind JAX ecosystem (Haiku, Optax), but for some reason, I find that when I train over a dataset, the training loss doesn't go down, as shown in this screenshot.
I'm trying to do something pretty simple: train Random Network Distillation (https://arxiv.org/abs/1810.12894, https://github.com/deepmind/acme/tree/master/acme/agents/jax/rnd) on an offline dataset of D4RL MuJoCo data. I tried a few sanity checks, including training on one random data point for some number of iterations. That loss also doesn't go down: it basically stays at 0.005 for 1000 straight epochs (shown in the screenshots below).
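Just to be explicit about the objective I mean: RND trains a predictor network to match the output of a fixed, randomly initialised target network, and the prediction error is both the training loss and the novelty signal. Schematically, with toy linear maps standing in for the real networks and made-up shapes:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k_target, k_pred, k_obs = jax.random.split(key, 3)

# Toy stand-ins for the two networks: a frozen random "target" and a trainable "predictor".
W_target = jax.random.normal(k_target, (17, 8))  # never updated
W_pred = jax.random.normal(k_pred, (17, 8))      # trained to match the target
obs = jax.random.normal(k_obs, (32, 17))         # batch of observations

def rnd_loss(W_pred, W_target, obs):
    # The predictor tries to reproduce the frozen target's output; the squared
    # error is the training loss (and, per observation, the novelty bonus).
    return jnp.mean(jnp.square(obs @ W_pred - obs @ W_target))

grads = jax.grad(rnd_loss)(W_pred, W_target, obs)  # gradient only w.r.t. the predictor
```

Since the target never changes, the loss on a fixed batch (let alone a single datapoint) should be able to get very close to zero.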
Here are some snippets:
RND neural network + trainer code:
```python
import functools
from typing import NamedTuple

import haiku as hk
import jax
import jax.numpy as jnp
import optax


class RNDTrainState(NamedTuple):
    params: hk.Params
    target_params: hk.Params
    opt_state: optax.OptState


class MLPRNDModel(hk.Module):

    def __init__(self, cfg):
        super().__init__()
        self.encoder = hk.nets.MLP(
            [cfg.hidden_dim, cfg.hidden_dim],
            activation=jax.nn.swish
        )
        self.predictor = RNDPredictor(cfg)

    def __call__(self, obs):
        reprs = self.encoder(obs)
        return self.predictor(reprs)


class RNDModelTrainer:
    '''RND model trainer.'''

    def __init__(self, cfg):
        self.cfg = cfg
        if cfg.task in MUJOCO_ENVS:
            rnd_fn = lambda o: MLPRNDModel(cfg.d4rl)(o)
        else:
            rnd_fn = lambda o: ConvRNDModel(cfg.vd4rl)(o)
        self.rnd = hk.without_apply_rng(hk.transform(rnd_fn))

        # params
        key = jax.random.PRNGKey(cfg.seed)
        k1, k2 = jax.random.split(key)
        rnd_params = self.rnd.init(k1, jnp.zeros((1,) + tuple(cfg.obs_shape)))
        target_params = self.rnd.init(k2, jnp.zeros((1,) + tuple(cfg.obs_shape)))

        # optimizer
        self.rnd_opt = optax.adam(cfg.lr)
        rnd_opt_state = self.rnd_opt.init(rnd_params)

        self.train_state = RNDTrainState(
            params=rnd_params,
            target_params=target_params,
            opt_state=rnd_opt_state
        )

    @functools.partial(jax.jit, static_argnames=('self',))
    def rnd_loss_fn(self, params, target_params, obs):
        output = self.rnd.apply(params, obs)
        target_output = self.rnd.apply(target_params, obs)
        # no need to do jax.lax.stop_gradient, as gradient is only taken w.r.t. first param
        return jnp.mean(jnp.square(target_output - output))

    @functools.partial(jax.jit, static_argnames=('self',))
    def update(self, obs, step):
        del step
        loss_grad_fn = jax.value_and_grad(self.rnd_loss_fn)
        loss, grads = loss_grad_fn(self.train_state.params, self.train_state.target_params, obs)
        update, new_opt_state = self.rnd_opt.update(grads, self.train_state.opt_state)
        new_params = optax.apply_updates(self.train_state.params, update)
        metrics = {'rnd_loss': loss}
        new_train_state = RNDTrainState(
            params=new_params,
            target_params=self.train_state.target_params,
            opt_state=new_opt_state
        )
        return new_train_state, metrics
```
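One extra check that might be worth running on top of this, sketched below with placeholder names: `trainer` stands for an instantiated `RNDModelTrainer` and `obs_batch` for one observation batch from the dataloader. It just compares the parameters before and after a single `update` call:

```python
import jax
import optax

# `trainer` and `obs_batch` are placeholders for an instantiated RNDModelTrainer
# and one observation batch of shape (batch, *obs_shape).
params_before = trainer.train_state.params
new_state, metrics = trainer.update(obs_batch, 0)

# Global norm of the parameter change; if this is ~0, the Adam step isn't moving the predictor.
delta = jax.tree_util.tree_map(lambda a, b: b - a, params_before, new_state.params)
print('rnd_loss:', metrics['rnd_loss'])
print('param delta norm:', optax.global_norm(delta))

trainer.train_state = new_state  # same reassignment the training loop does
```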
Training loop code:
```python
def train_rnd(self):
    '''Train RND model offline.'''
    for epoch in trange(1, self.cfg.model_train_epochs + 1):
        epoch_metrics = defaultdict(AverageMeter)
        for batch in self.rnd_dataloader:
            obs, _, _, _, _ = batch
            self.rnd_trainer.train_state, batch_metrics = self.rnd_trainer.update(obs, self.global_step)
            for k, v in batch_metrics.items():
                epoch_metrics[k].update(v, obs.shape[0])
        if self.cfg.wandb:
            log_dump = {k: v.value() for k, v in epoch_metrics.items()}
            wandb.log(log_dump)
        if self.cfg.save_model and epoch % self.cfg.model_save_every == 0:
            model_path = self.pretrained_rnd_dir / f'rnd_{epoch}.pkl'
            self.rnd_trainer.save(model_path)

def train_one_datapoint(self):
    '''Train on one datapoint for sanity checking. Loss SHOULD converge to 0.'''
    self.rng, subkey = jax.random.split(self.rng)
    rand_datapoint = jax.random.normal(key=subkey, shape=(1,) + tuple(self.cfg.obs_shape), dtype=jnp.float32)
    for epoch in trange(1, self.cfg.model_train_epochs + 1):
        self.rnd_trainer.train_state, metrics = self.rnd_trainer.update(rand_datapoint, self.global_step)
        print(f'metrics for epoch {epoch}: {metrics["rnd_loss"]}')
        if self.cfg.wandb:
            wandb.log(metrics)
```
Here, `self` refers to a workspace object with an experiment config `cfg`, where I train and save everything of interest.
As shown, I use the `optax.adam` optimizer with learning rate `1e-3`. I think this is standard (maybe a bit large, but I've swept a few learning rates both larger and smaller and get the same results).
I'm wondering where I'm going wrong in this training approach. I think the setup is correct, but I'm clearly missing something. Any help would be greatly appreciated! If you have additional questions, I'm happy to follow up here or over a video chat. Also, let me know if the Optax repo is the right place for this message; I don't think it's an issue yet (more on me than on the package), so I'm posting it in the Discussions tab.
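In case it makes this easier to poke at, here's a condensed, self-contained sketch of the same single-datapoint sanity check with the class machinery stripped out. The observation dimension, layer sizes, and names here are placeholders rather than my actual config:

```python
import jax
import jax.numpy as jnp
import haiku as hk
import optax

OBS_DIM = 17  # placeholder; my real obs_shape comes from the config

def rnd_fn(obs):
    return hk.nets.MLP([256, 256, 32], activation=jax.nn.swish)(obs)

rnd = hk.without_apply_rng(hk.transform(rnd_fn))

key = jax.random.PRNGKey(0)
k_pred, k_target, k_obs = jax.random.split(key, 3)
dummy = jnp.zeros((1, OBS_DIM))

params = rnd.init(k_pred, dummy)
target_params = rnd.init(k_target, dummy)  # frozen target
opt = optax.adam(1e-3)
opt_state = opt.init(params)

obs = jax.random.normal(k_obs, (1, OBS_DIM))  # the single random datapoint

def loss_fn(params, obs):
    return jnp.mean(jnp.square(rnd.apply(target_params, obs) - rnd.apply(params, obs)))

@jax.jit
def update(params, opt_state, obs):
    loss, grads = jax.value_and_grad(loss_fn)(params, obs)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

for step in range(1000):
    params, opt_state, loss = update(params, opt_state, obs)
    if step % 100 == 0:
        print(step, float(loss))
```

It mirrors the trainer's update logic above, just written as a free function with the parameters and optimizer state passed in explicitly.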
Regarding package versions, I am using Haiku 0.0.7, Optax 0.1.3, JAX 0.3.16 on CUDA for these experiments. I love the framework by the way!