add parallel_attention_blocks #446
Conversation
Looking forward to the unit tests; will do a more thorough review after that.
@@ -45,3 +45,17 @@ def forward(self, x: Tensor) -> Tensor:
            self.eps,
        )
        return output.type_as(x)


class RMSNorm(nn.Module):
add docstring
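For illustration, a minimal sketch of what the documented class could look like; the constructor arguments, the eps default, and the fp32 upcast shown here are assumptions for the sketch, not necessarily the PR's actual signature:

```python
import torch
from torch import Tensor, nn


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (Zhang & Sennrich, 2019).

    Normalizes by the root-mean-square of the last dimension instead of
    mean and variance, so there is no centering and no bias term.

    Args:
        dim (int): size of the last dimension of the input.
        eps (float): small constant added for numerical stability.
    """

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        # Compute in fp32 for stability, then cast back to the input dtype.
        x_fp32 = x.float()
        normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        return (normed * self.scale).type_as(x)
```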
* We use SwiGLU for the activation function
* SwiGLU will approximate same total num params as traditional MLP with GELU
* Cross Attention is not enabled here (but available)
what does "but available" mean?
Ah - we support cross attention in the main codebase for parallel attention. However, it's not really useful for TMM, so I removed it for this PR (hence the 'but available'). Let me expand that comment to note it's in the main codebase.
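On the SwiGLU bullet above: the parameter-count equivalence usually comes from scaling the hidden dimension by 2/3, so the extra gate projection does not add parameters overall. A hedged sketch under that assumption (module and attribute names are illustrative, not the PR's):

```python
import torch.nn.functional as F
from torch import Tensor, nn


class SwiGLUFeedForward(nn.Module):
    """SwiGLU MLP: w2(silu(w1(x)) * w3(x)).

    The hidden dim is scaled by 2/3 so the three projections hold roughly
    the same number of parameters as a two-layer GELU MLP with 4x expansion:
    3 * dim * (8/3 * dim) == 2 * dim * (4 * dim).
    """

    def __init__(self, dim: int, expansion: int = 4) -> None:
        super().__init__()
        hidden_dim = int(2 * expansion * dim / 3)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # value projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # output projection

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
```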
* MQA and GQA are enabled - modify heads via 'num_heads_group_query_attn'
This seems to cover GQA - how would I use MQA?
Easy (but I will expand this in the comments to clarify):
num_heads_group_query_attn = 1 and you now have MQA.
num_heads_group_query_attn > 1 and < Q heads gives you GQA.
I should probably also clarify that the number of Q heads needs to be a multiple of 'num_heads_group_query_attn' (there is an assert check so you ultimately can't miss it, but it might be nicer to note this in the docstring).
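To make the num_heads_group_query_attn settings concrete, here is a hedged sketch of the usual K/V head grouping; the repeat_kv helper and the shapes below are illustrative assumptions, not the PR's exact code:

```python
import torch


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand (B, n_kv_heads, L, head_dim) keys/values to match the query head count."""
    if n_rep == 1:
        return x
    bsz, n_kv_heads, seqlen, head_dim = x.shape
    return (
        x[:, :, None, :, :]
        .expand(bsz, n_kv_heads, n_rep, seqlen, head_dim)
        .reshape(bsz, n_kv_heads * n_rep, seqlen, head_dim)
    )


num_q_heads = 8
num_heads_group_query_attn = 1  # 1 -> MQA; 1 < n < num_q_heads -> GQA; n == num_q_heads -> standard MHA
assert num_q_heads % num_heads_group_query_attn == 0  # Q heads must be a multiple of the KV head count

k = torch.randn(2, num_heads_group_query_attn, 16, 64)       # (B, n_kv_heads, L, head_dim)
k = repeat_kv(k, num_q_heads // num_heads_group_query_attn)  # -> (2, 8, 16, 64), matches Q heads
```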
q_ = q.float().reshape(*q.shape[:-1], -1, 2)  # B H L D/2 2
k_ = k.float().reshape(*k.shape[:-1], -1, 2)  # B H L D/2 2
If we care about making this scriptable, it might be good to ditch the list unpacking.
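One hedged way to address the scriptability concern, assuming the q/k layout stays B H L D: spell out the leading dimensions so no *-unpacking of the shape is needed (the dummy tensors below are only for illustration):

```python
import torch

# Illustrative q/k with the documented B H L D layout.
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)

# Scriptable alternative: index the leading dims explicitly instead of
# *-unpacking q.shape, which TorchScript does not compile.
q_ = q.float().reshape(q.shape[0], q.shape[1], q.shape[2], -1, 2)  # B H L D/2 2
k_ = k.float().reshape(k.shape[0], k.shape[1], k.shape[2], -1, 2)  # B H L D/2 2
```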
Summary:
This PR adds Parallel Attention blocks to Torch MultiModal.
There are three main additions:
a - Rotary Embeddings
b - RMS Norm
c - Parallel_Blocks (a generic sketch of the parallel block pattern follows below)
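For readers unfamiliar with the pattern, a generic sketch of a parallel attention block, where attention and the MLP read the same normalized input and both outputs are summed into the residual. This sketch uses stock LayerNorm/GELU/MultiheadAttention for brevity; the PR's version swaps in RMS Norm, SwiGLU, rotary embeddings, and MQA/GQA per the additions above:

```python
import torch
from torch import Tensor, nn


class ParallelTransformerBlock(nn.Module):
    """Parallel-form block: x + attn(norm(x)) + mlp(norm(x)).

    A single pre-norm feeds both branches, so the attention and MLP outputs
    are computed from the same input and summed into the residual, rather
    than applied sequentially as in a standard transformer block.
    """

    def __init__(self, dim: int, num_heads: int, mlp_ratio: int = 4) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim),
            nn.GELU(),
            nn.Linear(mlp_ratio * dim, dim),
        )

    def forward(self, x: Tensor) -> Tensor:
        h = self.norm(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        return x + attn_out + self.mlp(h)


# Example usage with illustrative sizes.
block = ParallelTransformerBlock(dim=64, num_heads=8)
out = block(torch.randn(2, 16, 64))  # (batch, seq_len, dim)
```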
Test plan:
The code has been tested separately in ViT and LLM applications and with rotary unit tests.
However, a general unit test still needs to be added, and the rotary unit tests need to be migrated.