
Differences between the attention computation in JAX and Torch #24694

Closed · Answered by AakashKumarNain
AakashKumarNain asked this question in Q&A

@dfm I got it. The dot_product_attention(...) in torch accepts BNTH (batch, num_heads, seq_len, head_dim), while in JAX it accepts BTNH (batch, seq_len, num_heads, head_dim). The JAX layout makes more sense to me, though I am aware of why torch uses its convention. I guess I simply overlooked the dimensions in the documentation.

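For anyone landing here later, here is a minimal sketch of the layout difference, using jax.nn.dot_product_attention as the JAX entry point and torch's scaled_dot_product_attention convention as the point of comparison. The shape values are purely illustrative; the key point is that a tensor laid out the torch way (BNTH) needs its sequence and head axes swapped before being passed to the JAX function.

```python
import jax
import jax.numpy as jnp

# Illustrative shapes only.
B, T, N, H = 2, 16, 4, 64  # batch, seq_len, num_heads, head_dim

rng = jax.random.PRNGKey(0)
kq, kk, kv = jax.random.split(rng, 3)

# jax.nn.dot_product_attention expects BTNH: (batch, seq_len, num_heads, head_dim).
q = jax.random.normal(kq, (B, T, N, H))
k = jax.random.normal(kk, (B, T, N, H))
v = jax.random.normal(kv, (B, T, N, H))

out = jax.nn.dot_product_attention(q, k, v)
print(out.shape)  # (B, T, N, H)

# torch's scaled_dot_product_attention expects BNTH: (batch, num_heads, seq_len, head_dim),
# so a tensor in the torch layout needs the seq and head axes swapped before being fed
# to the JAX function (and swapped back afterwards if you want the torch layout again).
q_bnth = jnp.transpose(q, (0, 2, 1, 3))       # BTNH -> BNTH (torch layout)
q_btnh = jnp.transpose(q_bnth, (0, 2, 1, 3))  # BNTH -> BTNH (back to JAX layout)
assert q_btnh.shape == (B, T, N, H)
```
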
Replies: 2 comments

dfm (Collaborator) · Nov 4, 2024
Answer selected by AakashKumarNain