Clip value of a variable during optimization step? #5671
Unanswered
adam-hartshorne asked this question in Q&A
Replies: 1 comment · 1 reply
-
I think you need to do something like this:

```python
params = get_params(opt_state)
params = jax.tree_map(lambda x: jnp.clip(x, min_val, max_val), params)
# set params back into opt_state
```

Which framework are you using? I think this is a bit nicer to do in […]
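The "set params back into opt_state" step is the part `jax.example_libraries.optimizers` does not expose directly. One possible sketch is below; it reuses `unpack_optimizer_state` / `pack_optimizer_state`, which exist for serializing optimizer state, so it leans on internals rather than a supported mutation API. It also assumes the parameters form a flat dict and that Adam's per-parameter state is an `(x, m, v)` triple:

```python
import jax.numpy as jnp
from jax.example_libraries import optimizers

def clip_variable(opt_state, name, min_val, max_val):
    # unpack_optimizer_state returns the params pytree with each leaf
    # replaced by a JoinPoint whose .subtree holds that parameter's
    # optimizer sub-state; for Adam this is assumed to be (x, m, v).
    unpacked = optimizers.unpack_optimizer_state(opt_state)
    x, m, v = unpacked[name].subtree
    unpacked[name] = optimizers.JoinPoint(
        (jnp.clip(x, min_val, max_val), m, v))
    # Re-pack the modified pytree into an OptimizerState.
    return optimizers.pack_optimizer_state(unpacked)

# Used right after opt_update in the training step, e.g.:
# opt_state = clip_variable(opt_state, "w", 0.0, 1.0)
```

A cruder alternative is to rebuild the state with `opt_init(clipped_params)`, at the cost of resetting the optimizer's accumulated statistics (e.g. Adam's moment estimates).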
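The suggestion above is cut off, but in a library where the parameters live outside the optimizer state, such as Optax, this kind of constraint is indeed simpler to express. A sketch, with `loss_fn` and the clipping range as illustrative stand-ins:

```python
import jax
import jax.numpy as jnp
import optax

def loss_fn(params, batch):  # illustrative loss
    x, y = batch
    return jnp.mean((params["w"] * x + params["b"] - y) ** 2)

optimizer = optax.adam(1e-3)
params = {"w": jnp.ones(3), "b": jnp.zeros(3)}
opt_state = optimizer.init(params)

@jax.jit
def step(params, opt_state, batch):
    grads = jax.grad(loss_fn)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    # Parameters live outside opt_state, so clipping one variable is an
    # ordinary functional update rather than surgery on optimizer state:
    params = {**params, "w": jnp.clip(params["w"], 0.0, 1.0)}
    return params, opt_state
```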
-
After setting up an optimization function as follows, what I would like to do, at the point indicated in the code, is clip a particular variable's values to a certain range once the gradients have been applied. get_params(opt_state) will get me a dictionary of the variable values that I can then clip, but I don't know how (or whether it is even possible) to take the updated dictionary of variables and write it back into opt_state.
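A minimal sketch of the kind of setup described, assuming the `jax.example_libraries.optimizers` API (formerly `jax.experimental.optimizers`); the loss function and the parameter names `"w"` and `"b"` are illustrative, not from the original post:

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

def loss_fn(params, batch):  # illustrative loss
    x, y = batch
    return jnp.mean((params["w"] * x + params["b"] - y) ** 2)

opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)
opt_state = opt_init({"w": jnp.ones(3), "b": jnp.zeros(3)})

@jax.jit
def step(i, opt_state, batch):
    params = get_params(opt_state)
    grads = jax.grad(loss_fn)(params, batch)
    opt_state = opt_update(i, grads, opt_state)
    # <-- at this point: clip one variable's values and write them back
    #     into opt_state (the part being asked about)
    return opt_state
```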