why separate apply_updates from update? #155
From the flax documentation, a simple GD algorithm runs as follows:

```python
tx = optax.sgd(learning_rate=alpha)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(loss)
for i in range(101):
    loss_val, grads = loss_grad_fn(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
```

Why is `update` separated from `apply_updates`? From reading the documentation it seems that the separation is there to be able to use optimisers with extra steps, like gradient clipping, which aren't in the default optimisers (sgd, adam, etc.). Thanks!
Replies: 1 comment 1 reply
Hello,
Separating the transformation of updates from their application to the params has several advantages
One is `chain`: you can create custom optimisers by chaining together different existing gradient transformations, without having to rewrite the entire thing as a single monolithic optimiser. For a very trivial example, you might want to first clip gradients and then rescale them using Adam, or vice versa, but you may also do more sophisticated combinations.
If you take a look at alias.py you can see that many popular optimisers are actually built from a relatively small set of primitives; by freely combining these you can experiment …