Use of jax.lax.stop_gradient in nn.softmax #24935

Answered by jakevdp
amishra791 asked this question in Q&A
The reason is that the non-deprecated softmax has a custom JVP, so autodiff never differentiates through its implementation, which makes the stop_gradient call unnecessary there.
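To make the point concrete, here is a minimal sketch (not JAX's actual source) comparing two hand-written softmax variants. The names `softmax_custom`, `softmax_custom_jvp`, and `softmax_stop_gradient` are hypothetical and exist only for illustration. The first attaches a derivative rule with `jax.custom_jvp`, so its forward body is never traced for differentiation. The second relies on ordinary autodiff and uses `jax.lax.stop_gradient` on the max, which is safe because softmax is invariant to shifting its input by a constant.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def softmax_custom(x):
    # Forward pass: subtract the max for numerical stability.
    # No stop_gradient here: autodiff uses the rule below instead of
    # differentiating through this body.
    x_max = jnp.max(x, axis=-1, keepdims=True)
    unnormalized = jnp.exp(x - x_max)
    return unnormalized / jnp.sum(unnormalized, axis=-1, keepdims=True)


@softmax_custom.defjvp
def softmax_custom_jvp(primals, tangents):
    # Analytic JVP of softmax: y * (x_dot - sum(y * x_dot)).
    (x,), (x_dot,) = primals, tangents
    y = softmax_custom(x)
    y_dot = y * (x_dot - jnp.sum(y * x_dot, axis=-1, keepdims=True))
    return y, y_dot


def softmax_stop_gradient(x):
    # Plain autodiff version: stop_gradient keeps the max out of the
    # differentiated graph.
    x_max = jnp.max(x, axis=-1, keepdims=True)
    unnormalized = jnp.exp(x - jax.lax.stop_gradient(x_max))
    return unnormalized / jnp.sum(unnormalized, axis=-1, keepdims=True)


x = jnp.array([1.0, 2.0, 3.0])
jac_custom = jax.jacobian(softmax_custom)(x)
jac_stop = jax.jacobian(softmax_stop_gradient)(x)
jac_builtin = jax.jacobian(jax.nn.softmax)(x)

# All three Jacobians should agree up to floating-point error.
print(jnp.allclose(jac_custom, jac_stop, atol=1e-6))
print(jnp.allclose(jac_custom, jac_builtin, atol=1e-6))
```

In the `custom_jvp` variant the max subtraction only affects the forward computation, so there is nothing for `stop_gradient` to block.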

Answer selected by amishra791