Replies: 2 comments
-
Alternatively, I guess what I am asking is how to release the memory occupied by the graph used to compute gradients. I find that in JAX, the forward part of …
-
Hey @zw615, there are 3 options for how to stop gradients in JAX: …
For more info, check out Flax's Transfer Learning guide.
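The list of options isn't preserved in this snippet, but one of them is presumably `jax.lax.stop_gradient`, which the question already references. A minimal sketch of how it behaves; the two-layer model, the `backbone`/`head` parameter names, and the shapes are made up for illustration:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x):
    # Hypothetical two-layer model; "backbone" and "head" are illustrative names.
    h = jnp.tanh(x @ params["backbone"])
    h = jax.lax.stop_gradient(h)          # gradient flow stops here
    return jnp.sum(h @ params["head"])

params = {"backbone": jnp.ones((4, 8)), "head": jnp.ones((8, 1))}
x = jnp.ones((2, 4))

grads = jax.grad(loss_fn)(params, x)
print(grads["backbone"])  # all zeros: nothing upstream of stop_gradient gets a gradient
print(grads["head"])      # nonzero: the head still receives gradients
```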
-
Hi there! I want to turn off gradient computation when a model forwards some input, which is important in cases like gradient accumulation when memory is limited. I have searched the issues and discussion panels and found this post: #1937, which talks about jax.lax.stop_gradient. However, I find that the code below only disables gradient flow through the jax.lax.stop_gradient op, but still performs computational graph building/tracing. As a result, the gradient-accumulation technique does not save any memory. I wonder how I can extract features without any gradient operations, just like inference under torch.no_grad? Thanks a lot!
Note that this gradient-accumulation implementation is different from the one commonly used in supervised training, like here: https://github.com/google-research/big_vision/blob/47ac2fd075fcb66cadc0e39bd959c78a6080070d/big_vision/utils.py#L296. This implementation is useful in contrastive learning such as CLIP.
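For context on the torch.no_grad question: in JAX, differentiation is an explicit transform rather than something recorded during the forward pass, so a plain call to the forward function is already the "no-grad" path; residuals for the backward pass are only kept for computation that runs under jax.grad / jax.value_and_grad / jax.vjp. A minimal sketch of that distinction, where `apply_fn` is a made-up stand-in for a real model's forward function:

```python
import jax
import jax.numpy as jnp

def apply_fn(params, x):
    # Hypothetical forward pass standing in for a real model apply function.
    return jnp.tanh(x @ params["w"])

params = {"w": jnp.ones((4, 8))}
x = jnp.ones((16, 4))

# Plain (optionally jit-compiled) call: no autodiff bookkeeping is kept around,
# which is the closest analogue of running under torch.no_grad().
features = jax.jit(apply_fn)(params, x)

# Residuals for the backward pass only exist inside an explicit grad transform:
def loss_fn(params, x):
    return jnp.sum(apply_fn(params, x))

grads = jax.grad(loss_fn)(params, x)
```

Whether this actually saves memory in a contrastive gradient-accumulation loop depends on how the pre-computed features are fed back into the differentiated loss, so treat this as a sketch of the mechanism rather than a full recipe.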