Logging gradient norms before clipping #1026

Closed · Answered by vroulet
AakashKumarNain asked this question in Q&A

You can create a custom `GradientTransformation` that leaves the updates untouched and only computes their norm, storing it in its state (or simply printing it, if you prefer). Then build the usual chain, but insert this custom transform just before the clipping step.

You can then fetch the gradient norm from the overall optimizer state using `optax.tree_utils.tree_get`.

Something along the following lines:
```python
import typing

import jax
import jax.numpy as jnp
import optax

class RecordNormState(typing.NamedTuple):
    grad_norm: jax.Array

def record_norm():
    def init_fn(params):
        return RecordNormState(grad_norm=jnp.asarray(0.0))

    def update_fn(updates, state, params=None):
        # Leave the updates untouched; only record their global L2 norm.
        return updates, RecordNormState(grad_norm=optax.tree_utils.tree_l2_norm(updates))

    return optax.GradientTransforma…
```

Replies: 1 comment, 1 reply

1 reply from @AakashKumarNain

Answer selected by AakashKumarNain
Category: Q&A
Labels: None yet
2 participants