
Add Rotary Positional Embeddings (RoPE) - part 2 of parallel attention blocks #450

Open · wants to merge 9 commits into main

Conversation

lessw2020 (Contributor):

Summary:
Adds Rotary Positional Embeddings (RoPE)

Test plan:
Two unit tests: one for the math, one for padding.
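
For readers unfamiliar with the technique, a minimal sketch of the idea (independent of the implementation in this PR): RoPE rotates each (even, odd) feature pair of q and k by an angle proportional to the token position, so relative positions show up as phase differences in the attention dot product. Names and shapes below are illustrative only.

```python
import torch

def apply_rope(x: torch.Tensor, theta: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    """Rotate each (even, odd) feature pair of x by positions[m] * theta[j].

    x: (..., seq_len, dim) with dim even; theta: (dim // 2,); positions: (seq_len,).
    Illustrative sketch only -- not the module added in this PR.
    """
    x_pairs = x.float().reshape(*x.shape[:-1], -1, 2)   # (..., seq_len, dim/2, 2)
    angles = positions[:, None] * theta[None, :]         # (seq_len, dim/2)
    cos, sin = angles.cos(), angles.sin()
    x_even, x_odd = x_pairs[..., 0], x_pairs[..., 1]
    rotated = torch.stack(
        (x_even * cos - x_odd * sin, x_even * sin + x_odd * cos), dim=-1
    )
    return rotated.reshape(*x.shape)

# Example: rotate queries of shape (batch, heads, seq_len, head_dim)
q = torch.randn(2, 4, 16, 64)
theta = 1.0 / (10000 ** (torch.arange(0, 64, 2).float() / 64))
q_rotated = apply_rope(q, theta, torch.arange(16).float())
```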

@facebook-github-bot added the CLA Signed label (authors need to sign the CLA before a PR can be reviewed) on Aug 17, 2023
codecov-commenter commented on Aug 17, 2023:

Codecov Report

Patch coverage: 96.55% and project coverage change: +0.13% 🎉

Comparison is base (951a452) 69.11% compared to head (15c7469) 69.24%.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #450      +/-   ##
==========================================
+ Coverage   69.11%   69.24%   +0.13%     
==========================================
  Files         170      170              
  Lines       11524    11580      +56     
==========================================
+ Hits         7965     8019      +54     
- Misses       3559     3561       +2     
Files Changed                                           Coverage Δ
...rchmultimodal/modules/layers/position_embedding.py  97.50% <94.28%> (-2.50%) ⬇️
tests/modules/layers/test_position_embedding.py        100.00% <100.00%> (ø)


ebsmothers (Contributor) left a comment:

Looks good! Just a few minor things, mainly around testing and comments

@@ -112,3 +115,38 @@ def test_forward(self, data, emb):
        actual = emb(data)
        expected = torch.Size([3, 5])
        assert_expected(actual.shape, expected)


def test_rotary_embeddings_math():
Contributor:

Can we put these unit tests into a class? (Similar to the other tests in this file)

lessw2020 (Author):

yes, will do.
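
For reference, a minimal sketch of that grouping in the pytest class/fixture style used elsewhere in this file (class, fixture, and test names here are illustrative, not the final code):

```python
import pytest
import torch

class TestRotaryEmbeddings:
    @pytest.fixture
    def dim(self):
        return 8

    def test_rotary_embeddings_math(self, dim):
        # compare the module's output against an explicit 2D rotation
        ...

    def test_rotary_embeddings_padding(self, dim):
        # check behavior when sequences are padded
        ...
```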

        return cur_freqs.view(*shape, 2)

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, start_pos: Union[int, float]
Contributor:

Do you think it makes sense to have start_pos default to 0? (My assumption is that this would at least be the starting point for most users)
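
For reference, the suggested default would just be (hypothetical class name and return type; a sketch only, not the PR's code):

```python
from typing import Tuple, Union

import torch
import torch.nn as nn

class RotaryEmbeddings(nn.Module):  # hypothetical name for illustration
    def forward(
        self, q: torch.Tensor, k: torch.Tensor, start_pos: Union[int, float] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        ...
```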

Maximum expected sequence length for the model; if exceeded, the cached freqs will be recomputed
ratio: int
The ratio for the geometric progression to compute the rotation angles
"""
Contributor:

It'd be nice to add more in the docstring on the exact details of these embeddings, e.g. at least the [[cos, -sin], [sin, cos]] matrix and maybe even a small example (like the simple 2D one you wrote for the unit test)
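
For example, a small 2D illustration of the rotation that such a docstring could include (illustrative only, not taken from the PR):

```python
import math
import torch

def rotate_pair(x: torch.Tensor, pos: int, theta: float = 1.0) -> torch.Tensor:
    """Apply the [[cos, -sin], [sin, cos]] rotation by angle pos * theta to a 2D pair."""
    angle = pos * theta
    rot = torch.tensor(
        [[math.cos(angle), -math.sin(angle)],
         [math.sin(angle), math.cos(angle)]]
    )
    return rot @ x

q_pair = torch.tensor([1.0, 0.0])
print(rotate_pair(q_pair, pos=0))  # tensor([1., 0.])        -- position 0 leaves the pair unchanged
print(rotate_pair(q_pair, pos=1))  # tensor([0.5403, 0.8415]) -- rotated by theta radians
```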

    assert_expected(qr[0, :, 1], qr2[1, :, 0])

    assert_expected(kr[0], kr2[0])
    assert_expected(kr[0, :, 1], kr2[1, :, 0])
Contributor:

Can we also add a test for updating the cached frequencies? (As far as I can tell this second test is not hitting that block in L262-268, lmk if I'm misunderstanding)

lessw2020 (Author):

yes, that's a good idea.
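
A possible shape for that test, assuming a hypothetical RotaryEmbeddings module with a max_seq_len_cached attribute (names mirror the discussion above but are not confirmed against the PR):

```python
import torch

def test_rotary_embeddings_cache_update():
    # Hypothetical constructor/attribute names for illustration; adjust to the actual API.
    rope = RotaryEmbeddings(dim=8, max_position_embeddings=16)
    old_cache_len = rope.max_seq_len_cached

    # Queries/keys longer than the cached length should trigger recomputation of the freqs.
    seq_len = old_cache_len + 4
    q = torch.randn(1, 2, seq_len, 8)
    k = torch.randn(1, 2, seq_len, 8)
    rope(q, k, start_pos=0)

    assert rope.max_seq_len_cached >= seq_len
```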

        k_ = k.float().reshape(*k.shape[:-1], -1, 2)  # B H L D/2 2

        if isinstance(start_pos, int):
            if start_pos + seq_len > self.max_seq_len_cached:
Contributor:

Some comments here about when the frequencies need to be recomputed might be helpful

lessw2020 (Author):

sounds good - offhand the triggers should be changing dtype, changing device, and requesting seq len > max_seq_len.
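
Those conditions could be captured in a small predicate like the one below (a hypothetical standalone helper written against assumed attribute names, not the PR's actual code):

```python
import torch

def cache_needs_recompute(
    cached_freqs: torch.Tensor,
    max_seq_len_cached: int,
    start_pos: int,
    seq_len: int,
    q: torch.Tensor,
) -> bool:
    """Return True when the cached rotation frequencies cannot be reused.

    Triggers: the requested positions overrun the cache, the inputs arrived in a
    different dtype than the cache, or they live on a different device.
    """
    return (
        start_pos + seq_len > max_seq_len_cached
        or q.dtype != cached_freqs.dtype
        or q.device != cached_freqs.device
    )
```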

        )
        self.compute_freqs_cis(max_position_embeddings)

    def compute_freqs_cis(
Contributor:

Random q: what does cis mean here?

lessw2020 (Author):

it's shorthand for the rotation transform: technically e^(i * alpha) = cos(alpha) + i * sin(alpha), so "cos + i * sin" gets shortened to "cis".

lessw2020 (Author):

should probably add that in the docstring actually, otherwise it's too cryptic.
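
For reference, the common LLaMA-style way such a helper is written, where the cached tensor literally holds cis(m * theta_j) = cos(m * theta_j) + i * sin(m * theta_j) as complex numbers (a sketch; the PR's compute_freqs_cis may differ in signature and caching details):

```python
import torch

def precompute_freqs_cis(dim: int, max_seq_len: int, ratio: float = 10000.0) -> torch.Tensor:
    """Precompute cis(m * theta_j) for every position m and feature-pair index j."""
    theta = 1.0 / (ratio ** (torch.arange(0, dim, 2).float() / dim))  # (dim // 2,)
    positions = torch.arange(max_seq_len).float()                     # (max_seq_len,)
    angles = torch.outer(positions, theta)                            # (max_seq_len, dim // 2)
    # torch.polar(r, phi) = r * (cos(phi) + i * sin(phi)); with r = 1 this is exactly cis(phi).
    return torch.polar(torch.ones_like(angles), angles)
```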

facebook-github-bot (Contributor):

@rohan-varma has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

(1 similar comment from facebook-github-bot)

rohan-varma (Contributor) left a comment:

High-level comment, but let's maybe create a modules/layers/embeddings folder in the future, as we might have multiple embedding layers.

Labels: CLA Signed
5 participants