
Add Parallel_Attention_Blocks (3 of 3) #457

Open · wants to merge 8 commits into base: main
Conversation

lessw2020 (Contributor)

Summary:
This PR adds the main and final piece for upstreaming Parallel Attention Blocks: the Parallel Attention Block class itself.
It supports MHA, MQA, and GQA attention head setups.
RMSNorm has already landed, so it can be used here.
Cross attention has been removed as requested.

Test plan:
Added three unit tests, one each for Parallel Attention Blocks using MHA, MQA, and GQA.
Each test checks the number of query heads and the number of KV heads to ensure the expected head counts (e.g. for GQA, the number of KV heads is set to 2 and then verified).
From there, a Parallel Attention layer runs a forward pass on a fixed single input tensor, and both the first row of the output and the attn_output shape are checked.
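A condensed sketch of the shape one of these tests takes (fixture and helper names follow those visible in the review snippets below; the float cast and import path are assumptions):

import torch
from tests.test_utils import assert_expected  # test helper, import path assumed


def test_mha_parallel_attention(mha_parallel_attention, num_heads, max_seq_len, embedding_dim):
    # confirm num Q matches num_heads (the MQA/GQA tests additionally verify the KV head count)
    assert_expected(num_heads, mha_parallel_attention.num_heads)

    # forward pass on a fixed single-batch input, then check the output shape
    x = torch.randint(0, 256, (1, max_seq_len, embedding_dim)).float()  # cast assumed
    attn_output = mha_parallel_attention(x)
    assert_expected(torch.Size([1, max_seq_len, embedding_dim]), attn_output.shape)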

@facebook-github-bot added the CLA Signed label Aug 27, 2023 (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed).
codecov-commenter commented Aug 27, 2023

Codecov Report

Patch coverage: 95.12% and project coverage change: +0.34% 🎉

Comparison is base (1b4f79f) 69.78% compared to head (2156f2f) 70.12%.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #457      +/-   ##
==========================================
+ Coverage   69.78%   70.12%   +0.34%     
==========================================
  Files         175      177       +2     
  Lines       11992    12156     +164     
==========================================
+ Hits         8369     8525     +156     
- Misses       3623     3631       +8     
Files Changed Coverage Δ
...rchmultimodal/modules/layers/parallel_attention.py 91.66% <91.66%> (ø)
tests/modules/layers/test_parallel_attention.py 100.00% <100.00%> (ø)


@rohan-varma (Contributor) left a comment


LG overall, @ebsmothers do you mind taking a pass through this as well?

# confirm num Q matches num_heads
assert_expected(num_heads, mha_parallel_attention.num_heads)

# input_ones = torch.ones(dims, dtype=torch.float)

nit: rm

fixed_output_shape = torch.Size([1, max_seq_len, embedding_dim])

assert_expected(fixed_result_firstrow, attn_output[0][0], rtol=0, atol=1e-4)
assert_expected(fixed_output_shape, attn_output.shape)

@ebsmothers, do we / should we do any additional testing besides verifying the first row of outputs?


This plus shape should be sufficient. One nitpick would be to assert on something besides the first row (maybe the mean over an axis) just because I have seen cases where the first row is actually correct but others are not (e.g. if there is a bug in masking logic).
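For instance, following the pattern of fixed_result_firstrow above, the test could also assert on a per-sequence statistic (fixed_result_seq_mean here is a hypothetical precomputed reference tensor, analogous to fixed_result_firstrow):

# mean over the sequence axis catches bugs (e.g. in masking) that leave row 0 correct but later rows wrong
assert_expected(fixed_result_seq_mean, attn_output.mean(dim=1)[0], rtol=0, atol=1e-4)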

)
fixed_output_shape = torch.Size([1, max_seq_len, embedding_dim])
assert_expected(fixed_output_shape, attn_output.shape)
# print(f"{attn_output[0][0]}")

nit: remove

rel_pos_bias: Optional[torch.Tensor] = None,
has_causal_mask: bool = False,
) -> torch.Tensor:
"""TODO: No KV cache support yet"""

shall we file an issue for this?


return 32

@pytest.fixture
def mha_parallel_attention(self, embedding_dim, num_heads, total_layers):

One thing to be careful about for these unit tests: if you are not explicitly initializing the params of the modules, then the test results will be sensitive to the order in which submodules are initialized. In the past we've seen cases where an otherwise no-op change breaks tests just because of changes in initialization order. We have the util init_weights_with_constant for this, but the trade-off is that it also makes the test case a lot more trivial (since all weights are 1s).

lessw2020 (Author):

Good point!
Currently the parallel attention blocks have their own full init that happens automatically, so that covers this concern. I think I should still add a comment stating this assumption, so that if it breaks in the future the reader can quickly ascertain what might be going awry.
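For reference, a minimal sketch of the constant-init approach mentioned above (import path and fixture wiring assumed; as noted, the trade-off is that all-ones weights make the numerical check fairly trivial):

import torch
from tests.test_utils import init_weights_with_constant  # import path assumed


def test_mha_forward_with_constant_init(mha_parallel_attention, max_seq_len, embedding_dim):
    # make expected values independent of submodule initialization order
    init_weights_with_constant(mha_parallel_attention)
    x = torch.ones(1, max_seq_len, embedding_dim)
    attn_output = mha_parallel_attention(x)
    assert attn_output.shape == torch.Size([1, max_seq_len, embedding_dim])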

# from position_embedding import RotaryEmbedding


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:

Noob question: is this more efficient than repeat_interleave? I thought in this case the extra memory would be allocated either way
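For context, a common way this kind of KV expansion is written (a general sketch, not necessarily this PR's exact implementation) uses expand, which only creates a view; the copy is materialized by the final reshape, so peak memory ends up similar to repeat_interleave, which allocates the repeated tensor immediately:

import torch


def repeat_kv_sketch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # x: (batch, seq_len, num_kv_heads, head_dim)
    bsz, seq_len, num_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    # expand adds a repeat dimension as a view (no allocation);
    # the reshape then materializes the expanded copy
    return (
        x[:, :, :, None, :]
        .expand(bsz, seq_len, num_kv_heads, n_rep, head_dim)
        .reshape(bsz, seq_len, num_kv_heads * n_rep, head_dim)
    )


# equivalent result, allocating the repeated tensor directly:
# torch.repeat_interleave(x, repeats=n_rep, dim=2)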

# swiglu
activated_mlp = self.mlp_activation(inner_mlp) * gate

if self.mlp_dropout.p:

Why the if statement here? Isn't it just a no-op if p=0.0 anyways?
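A quick illustration of the reviewer's point (nn.Dropout with p=0.0 passes the input through unchanged even in training mode, so the guard mainly just skips a function call):

import torch
from torch import nn

dropout = nn.Dropout(p=0.0)
dropout.train()  # identity even in training mode when p=0.0
x = torch.randn(2, 4)
assert torch.equal(dropout(x), x)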

q, k = self.rotary_emb(q, k, start_pos)

# group query expansion
def kv_expansion(head: torch.Tensor) -> torch.Tensor:

nit: can we either make this a method of the class or a standalone function? I feel the nested function harms readability here

emb_dimension: int,
num_heads: int,
head_dimension: int = None,
mlp_expansion_ratio: float = 2.6875, # 8/3 is param matching

nit: any reason not to just use an integer value here and skip the multiplication by ratio?

lessw2020 (Author):

Oh, there is a reason for this actually: the idea was to match the parameter count most people get with mlp_expansion = 4, while keeping the hidden size a multiple of 8.
A common mistake when people compare SwiGLU vs, say, GeLU is to claim SwiGLU is slower... but what's really happening is that SwiGLU uses a gate, so you end up with more total params if you keep the same 4.0 expansion factor and simply drop SwiGLU into the MLP.
Therefore, this 2.6875 gives you roughly the same params as a plain activation with a 4x expansion factor, but keeps the hidden size a multiple of 8 (this was from a paper showing a slight efficiency gain; I can't remember if it was multiples of 8 or 16, but something like that). Hence the comment that the exact match is 8/3.
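For reference, a quick back-of-the-envelope check of that parameter matching (an embedding dim of 4096 is just an illustrative choice):

embedding_dim = 4096  # illustrative width

# plain activation MLP with 4x expansion: two matrices, d -> 4d and 4d -> d
plain_mlp_params = 2 * embedding_dim * (4 * embedding_dim)  # 8 * d^2 = 134,217,728

# SwiGLU MLP: three matrices (up, gate, down), each d x hidden;
# matching 8 * d^2 gives hidden = (8/3) * d, and 2.6875 * d rounds that up to a multiple of 8
hidden = int(2.6875 * embedding_dim)                        # 11008
swiglu_mlp_params = 3 * embedding_dim * hidden              # 135,266,304, within ~1% of the above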

# input_ones = torch.ones(dims, dtype=torch.float)

x = torch.randint(0, 256, (1, max_seq_len, embedding_dim)) # bs =1,
attn_output = mha_parallel_attention(x)

Would also consider adding a test case covering mask and/or rel_pos_bias args
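For example, a sketch of such a test (the forward signature with rel_pos_bias and has_causal_mask is taken from the snippet quoted earlier; the bias shape and the float cast are assumptions):

import torch


def test_mha_with_causal_mask_and_rel_pos_bias(
    mha_parallel_attention, num_heads, max_seq_len, embedding_dim
):
    x = torch.randint(0, 256, (1, max_seq_len, embedding_dim)).float()  # cast assumed
    # hypothetical bias shape: one (seq_len x seq_len) bias per head
    rel_pos_bias = torch.zeros(1, num_heads, max_seq_len, max_seq_len)
    attn_output = mha_parallel_attention(
        x, rel_pos_bias=rel_pos_bias, has_causal_mask=True
    )
    assert attn_output.shape == torch.Size([1, max_seq_len, embedding_dim])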


# input_ones = torch.ones(dims, dtype=torch.float)

x = torch.randint(0, 256, (1, max_seq_len, embedding_dim)) # bs =1,

Can probably just make this a fixture (since I think it's used in all test cases)
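Something like the following (a sketch reusing the fixture names already in this file):

import pytest
import torch


@pytest.fixture
def input_tensor(max_seq_len, embedding_dim):
    # fixed single-batch input shared across the MHA/MQA/GQA test cases
    return torch.randint(0, 256, (1, max_seq_len, embedding_dim))  # bs = 1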


self.num_q = 1

self.in_proj_dims = [

nit: add a comment either here or on self.in_proj definition about the fusing you're doing. I think fusing of the MLP and gate could be a bit unusual for those unfamiliar with parallel attention
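For readers unfamiliar with the fused projection used in parallel attention blocks, here is a rough sketch of the idea (the sizes and the exact ordering of in_proj_dims are illustrative, not necessarily what this PR does):

import torch
from torch import nn

# parallel attention computes attention and the MLP from the same normalized input,
# so Q, K, V, the MLP "up" projection, and the SwiGLU gate can share one fused matmul
emb_dim, num_heads, head_dim, num_kv_heads, mlp_hidden = 64, 8, 8, 2, 172

in_proj_dims = [
    num_heads * head_dim,     # q
    num_kv_heads * head_dim,  # k
    num_kv_heads * head_dim,  # v
    mlp_hidden,               # mlp up projection
    mlp_hidden,               # swiglu gate
]
in_proj = nn.Linear(emb_dim, sum(in_proj_dims), bias=False)

x = torch.randn(1, 16, emb_dim)
q, k, v, inner_mlp, gate = in_proj(x).split(in_proj_dims, dim=-1)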


from torchmultimodal.modules.layers.normalizations import RMSNorm

# from position_embedding import RotaryEmbedding

remove this?

@rohan-varma self-requested a review September 25, 2023 23:06