
add parallel_attention_blocks #446

Closed · wants to merge 2 commits

Conversation

lessw2020 (Contributor):

Summary:
This PR adds Parallel Attention blocks to TorchMultimodal.
There are 3 main additions:
a - Rotary Embeddings
b - RMS Norm
c - Parallel_Blocks

Test plan:
The code has been tested separately in ViT and LLM applications, and the rotary embeddings have their own unit tests. A general unit test still needs to be added here, and the rotary unit tests need to be migrated.
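
For context, a minimal sketch of the parallel-block idea this PR describes: attention and MLP read the same normalized input and their outputs are summed into the residual, rather than the MLP consuming the attention output. The class here is illustrative only (it uses LayerNorm and a GELU MLP to stay self-contained) and is not the API added by this PR:

```python
import torch
from torch import nn, Tensor


class ParallelBlockSketch(nn.Module):
    """Illustrative parallel attention block: x + attn(norm(x)) + mlp(norm(x))."""

    def __init__(self, dim: int, num_heads: int) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim)  # the PR adds RMSNorm; LayerNorm keeps this sketch self-contained
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x: Tensor) -> Tensor:
        h = self.norm(x)  # one shared normalization feeds both branches
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        # The two branches are independent, so they can be fused or run concurrently.
        return x + attn_out + self.mlp(h)


x = torch.randn(2, 16, 128)
print(ParallelBlockSketch(dim=128, num_heads=8)(x).shape)  # torch.Size([2, 16, 128])
```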

@facebook-github-bot added the CLA Signed label on Aug 15, 2023.
@rohan-varma (Contributor) left a comment:


Looking forward to the unit tests; will do a more thorough review after that.

@@ -45,3 +45,17 @@ def forward(self, x: Tensor) -> Tensor:
             self.eps,
         )
         return output.type_as(x)
+
+
+class RMSNorm(nn.Module):

Contributor:

add docstring
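
For reference, RMSNorm (Zhang & Sennrich, 2019) scales by the root mean square of the features instead of subtracting a mean and dividing by a standard deviation: y = x * rsqrt(mean(x^2) + eps) * weight. A minimal sketch with the kind of docstring being requested; this is the standard formulation, not necessarily this PR's exact implementation:

```python
import torch
from torch import nn, Tensor


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (Zhang & Sennrich, 2019).

    Normalizes the last dimension by its root mean square and applies a
    learnable per-feature scale; unlike LayerNorm there is no mean
    subtraction and no bias.

    Args:
        dim: size of the last (feature) dimension.
        eps: small constant added for numerical stability.
    """

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        # Compute in float32 for stability, then cast back to the input dtype.
        x_fp32 = x.float()
        normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (normed * self.weight).type_as(x)
```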


* We use SwiGLU for the activation function
* SwiGLU will approximate same total num params as traditional MLP with GELU
* Cross Attention is not enabled here (but available)

Contributor:

what does "but available" mean?

lessw2020 (Author):

Ah - we support cross attention in the main codebase for parallel attention. However, it's not really useful for TMM, so I removed it for this PR (hence the "but available"). Let me expand that comment to note it's in the main codebase.
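
On the SwiGLU line in the quoted docstring: the usual reason the parameter counts end up comparable is that SwiGLU has three projections instead of two, so the hidden width is scaled by roughly 2/3. A sketch of that convention (assumed here, not read from this PR's code):

```python
import torch
from torch import nn, Tensor
import torch.nn.functional as F


class SwiGLUFeedForward(nn.Module):
    """Illustrative SwiGLU MLP: w2(SiLU(w1(x)) * w3(x))."""

    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        # Three weight matrices instead of two, so shrink hidden_dim by ~2/3
        # to roughly match the parameter count of a GELU MLP of the same width:
        # 3 * dim * (2/3 * hidden_dim) == 2 * dim * hidden_dim.
        hidden_dim = int(2 * hidden_dim / 3)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
```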

* We use SwiGLU for the activation function
* SwiGLU will approximate same total num params as traditional MLP with GELU
* Cross Attention is not enabled here (but available)
* MQA and GQA are enabled - modify heads via 'num_heads_group_query_attn'

Contributor:

This seems to be for GQA; how would I use MQA?

lessw2020 (Author):

Easy (but I will expand the comments to clarify):
num_heads_group_query_attn = 1 gives you MQA.
1 < num_heads_group_query_attn < Q_heads gives you GQA.
I should probably also clarify that the number of Q heads needs to be a multiple of the group-query heads (there is an assert check, so you ultimately can't miss it, but it would be nicer to note in the docstring).
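
Concretely, the setting above controls how many K/V heads exist and therefore how many query heads share each one. A sketch of the head bookkeeping under that convention; the parameter name num_heads_group_query_attn comes from this PR, while everything else (function name, shapes) is illustrative:

```python
import torch


def expand_kv_heads(k: torch.Tensor, num_q_heads: int) -> torch.Tensor:
    """Repeat K (or V) heads so they line up with the query heads.

    k: (batch, num_kv_heads, seq_len, head_dim), where num_kv_heads plays the
    role of num_heads_group_query_attn: 1 -> MQA, between 1 and num_q_heads ->
    GQA, equal to num_q_heads -> standard multi-head attention.
    """
    num_kv_heads = k.shape[1]
    assert num_q_heads % num_kv_heads == 0, "query heads must be a multiple of kv heads"
    return k.repeat_interleave(num_q_heads // num_kv_heads, dim=1)


# MQA: a single shared K/V head for all 8 query heads.
k_mqa = torch.randn(2, 1, 16, 64)
print(expand_kv_heads(k_mqa, num_q_heads=8).shape)  # torch.Size([2, 8, 16, 64])

# GQA: 2 K/V heads, each shared by a group of 4 query heads.
k_gqa = torch.randn(2, 2, 16, 64)
print(expand_kv_heads(k_gqa, num_q_heads=8).shape)  # torch.Size([2, 8, 16, 64])
```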

Comment on lines +246 to +247:
+    q_ = q.float().reshape(*q.shape[:-1], -1, 2)  # B H L D/2 2
+    k_ = k.float().reshape(*k.shape[:-1], -1, 2)  # B H L D/2 2

Contributor:

If we care about making this scriptable, it might be good to ditch the list unpacking.
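
Two ways to get the same (B, H, L, D/2, 2) view without the star-unpacking, sketched here for illustration rather than taken from the PR:

```python
import torch

q = torch.randn(2, 8, 16, 64)  # (B, H, L, D)

# Star-unpacked version from the diff:
q_ref = q.float().reshape(*q.shape[:-1], -1, 2)  # (B, H, L, D/2, 2)

# Without list unpacking, indexing the shape explicitly:
q_a = q.float().reshape(q.shape[0], q.shape[1], q.shape[2], -1, 2)

# Or splitting the last dim into (D/2, 2) pairs directly:
q_b = q.float().unflatten(-1, (-1, 2))

assert torch.equal(q_ref, q_a) and torch.equal(q_ref, q_b)
```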

@lessw2020 lessw2020 closed this by deleting the head repository Aug 16, 2023