Use of jax.lax.stop_gradient in nn.softmax #24935
Answered by jakevdp
amishra791 asked this question in Q&A
Hi, I was taking a look at the softmax function and its deprecated version. I noticed that the deprecated version has |
Answered by jakevdp on Nov 20, 2024
Answer selected by amishra791
The reason for this is that the non-deprecated softmax has a custom JVP, and so we don't differentiate through the implementation, which makes the `stop_gradient` call unnecessary.
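
To make the point concrete, here is a minimal sketch of the two styles. It is a simplified illustration, not the actual `jax.nn` source: the function names `softmax_custom` and `softmax_plain` are hypothetical, and placing `stop_gradient` on the max-subtraction term is an assumption about what the deprecated version does, since the question above is cut off.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def softmax_custom(x):
    # Forward computation only. Because a JVP rule is attached below,
    # autodiff uses that rule instead of tracing through this body,
    # so no stop_gradient is needed on the max.
    unnormalized = jnp.exp(x - jnp.max(x, axis=-1, keepdims=True))
    return unnormalized / jnp.sum(unnormalized, axis=-1, keepdims=True)


@softmax_custom.defjvp
def _softmax_custom_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = softmax_custom(x)
    # Analytic softmax derivative: y * (x_dot - <y, x_dot>).
    return y, y * (x_dot - jnp.sum(y * x_dot, axis=-1, keepdims=True))


def softmax_plain(x):
    # No custom rule here, so autodiff differentiates the body directly.
    # Assumption: stop_gradient wraps the max used for numerical stability.
    x_max = jax.lax.stop_gradient(jnp.max(x, axis=-1, keepdims=True))
    unnormalized = jnp.exp(x - x_max)
    return unnormalized / jnp.sum(unnormalized, axis=-1, keepdims=True)


x = jnp.array([1.0, 2.0, 3.0])
jac_custom = jax.jacobian(softmax_custom)(x)
jac_plain = jax.jacobian(softmax_plain)(x)
```

Both Jacobians should agree up to floating-point error. In the custom-JVP variant the derivative comes from the hand-written rule, so what happens to the max inside the function body has no effect on the gradient.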