Help: How to prevent the params of a hybrid module from becoming dynamic or traced arrays after an optax update? #707

Closed. Answered by BraveDrXuTF
BraveDrXuTF asked this question in Q&A

Hi @fabianp. Thank you for your response!
Yes, I can use `jax.debug` to print the traced arrays. However, I strongly suspect that I am using optax incorrectly in some way, because my parameters become traced arrays after running `apply_updates`. The parameters in another example from your gallery do NOT (I tested it in Colab with Python's `print` function and confirmed that the parameters remain `jnp.array`s), even with `@jax.jit`:

```python
import jax

# In your example, params are not transformed.
@jax.jit
def train_step(params, net_state, solver_state, batch):
  # Performs a single update step.
  (loss, aux), grad = jax.value_and_grad(loss_accuracy, has_aux=True)(
…
```

Replies: 1 comment 2 replies

2 replies
@BraveDrXuTF

Answer selected by fabianp
@fabianp

Category: Q&A
Labels: None yet
2 participants