
add Alibi positional embeddings #462

Open · wants to merge 15 commits into base: main
Conversation

lessw2020
Contributor

Summary:
This PR adds an ALiBi positional embeddings class (per the ALiBi paper, https://arxiv.org/abs/2108.12409).
It generates the ALiBi attention mask, which is added to the attention scores after QK^T / sqrt(d_k), and replaces the usual sinusoidal-type positional embeddings.
The class is designed to be instantiated once, outside the transformer block loop, based on max_seq_length; the layers then retrieve the attention mask for the current sequence length, so only a single mask buffer needs to be created.
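Rough usage sketch (illustrative only: the import path is taken from the file touched in this PR, and the call that slices the buffer to the current sequence length is an assumed interface, not necessarily the final one):

```python
import torch

# Hypothetical usage sketch, not the final API: how the bias is retrieved per
# layer (calling the module with the current sequence length) is assumed.
from torchmultimodal.modules.layers.position_embedding import AlibiPositionEmbeddings

max_seq_len, num_heads, curr_seq_len, head_dim = 2048, 8, 512, 64

# instantiate once, outside the transformer block loop
alibi = AlibiPositionEmbeddings(max_seq_len=max_seq_len, num_heads=num_heads)

# inside each attention layer: slice the single precomputed buffer to the
# current sequence length...
alibi_bias = alibi(curr_seq_len)  # assumed shape (num_heads, curr_seq_len, curr_seq_len)

# ...and add it to the raw attention scores after QK^T / sqrt(d_k)
q = torch.randn(1, num_heads, curr_seq_len, head_dim)
k = torch.randn(1, num_heads, curr_seq_len, head_dim)
scores = q @ k.transpose(-2, -1) / head_dim**0.5
scores = scores + alibi_bias  # broadcasts over the batch dim; softmax follows
```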

Test plan:
I tested by running a 200M GPT-2 model on 10% of OpenWebText to compare training curves between learned embeddings (the GPT-2 default) and ALiBi.
[Image: alibi_training_curves]

I also added a unit test with three tests:
a - the shape of the ALiBi mask
b - verify the first head's row entry
c - verify the last head's last row entry
Note that half the mask is -inf; when trying to use allclose with the -inf entries they will not match, so I targeted entries that contain only finite values.
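For illustration, the kind of spot checks this amounts to (not the actual test code; just a tiny stand-in mask):

```python
import torch

# Illustrative only, not the PR's test code: build a tiny causal bias with
# -inf above the diagonal and check shape / finite entries / -inf positions
# separately, rather than comparing whole tensors that contain -inf.
seq_len = 4
upper = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
bias = torch.zeros(seq_len, seq_len).masked_fill(upper, float("-inf"))

assert bias.shape == (seq_len, seq_len)                        # shape check
assert torch.isclose(bias[seq_len - 1, 0], torch.tensor(0.0))  # a finite entry
assert torch.isinf(bias[0, seq_len - 1])                       # upper triangle is -inf
```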

facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Sep 6, 2023
@lessw2020
Contributor Author

The unit test failure is not related.

@lessw2020
Contributor Author

The test failure is not related; it appears to be a rounding issue:
test_model.py::TestAudioMaskedAutoEncoder::test_audio_mae_train_masking - AssertionError: actual: 512.999755859375, expected: 513.

@lessw2020
Contributor Author

I reran the training with the updates to confirm:
[Image: proper_causal_mask]

codecov-commenter commented Sep 7, 2023

Codecov Report

Attention: 1 line in your changes is missing coverage. Please review.

Comparison: base (1fd96dc) 74.01% vs. head (c54548f) 74.13%.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #462      +/-   ##
==========================================
+ Coverage   74.01%   74.13%   +0.12%     
==========================================
  Files         207      207              
  Lines       14203    14274      +71     
==========================================
+ Hits        10512    10582      +70     
- Misses       3691     3692       +1     
| Files | Coverage Δ |
| --- | --- |
| ...rchmultimodal/modules/layers/position_embedding.py | 100.00% <100.00%> (ø) |
| tests/modules/layers/test_position_embedding.py | 98.76% <97.29%> (-1.24%) ⬇️ |


return self.alibi_mask[..., :curr_seq_len, :curr_seq_len]

@classmethod
def build_causal_attention_mask(cls, seq_len: int, num_heads: int) -> torch.Tensor:
Contributor

FWIW, there is also the get_causal_attention_mask utility (you may even be able to use get_extended_attention_mask from the same file in lieu of the repeat, though it does broadcast to an extra dim for batch size).
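For context, roughly what such a utility provides (plain torch illustration, not the actual torchmultimodal helpers named above):

```python
import torch

# Generic illustration (plain torch, not the torchmultimodal utilities named
# above): a boolean causal mask, optionally repeated per attention head.
seq_len, num_heads = 8, 4
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # True = may attend
per_head = causal.unsqueeze(0).repeat(num_heads, 1, 1)               # (num_heads, seq_len, seq_len)
```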

max_seq_len: int,
num_heads: int,
) -> None:
"""recommended usage: create alibi mask before transformer block loop and integrate
Contributor

Yeah, this is a bit tricky. Kinda similar to RoPE embeddings: integrating this properly will necessitate rethinking some aspects of our transformer implementation. For instance, it seems like one assumption here is that our transformer's mask should be float dtype and not bool.
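For concreteness, a minimal sketch of the two mask conventions (arbitrary shapes; not our transformer code):

```python
import torch

# Minimal sketch (arbitrary shapes): a bool "keep" mask is applied with
# masked_fill, while a float additive mask (like the ALiBi bias) is simply
# added to the attention scores.
scores = torch.randn(2, 8, 16, 16)  # (batch, heads, q_len, k_len)

bool_mask = torch.tril(torch.ones(16, 16, dtype=torch.bool))  # True = attend
scores_bool = scores.masked_fill(~bool_mask, float("-inf"))

float_mask = torch.zeros(8, 16, 16)   # e.g. per-head ALiBi bias plus causal -inf
scores_float = scores + float_mask    # broadcasts over the batch dim
```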

@@ -169,3 +170,108 @@ def forward(self, t: Tensor) -> Tensor:
if self.embed_dim % 2 == 1:
embeddings = nn.functional.pad(embeddings, (0, 1))
return embeddings


class AlibiPositionEmbeddings(nn.Module):
Contributor

High level q: if we're not using model forward and mostly using class/static methods, why not just define it as a function? Offhand I don't see a reason why this needs to be stateful (it's very possible I'm missing something though).
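For concreteness, a sketch of the stateless alternative (hypothetical; build_alibi_bias is not part of this PR, and for brevity it only handles power-of-2 head counts):

```python
import torch

# Hypothetical sketch of a stateless-function alternative; build_alibi_bias is
# not part of this PR. Slopes follow the paper's power-of-2 rule: a geometric
# sequence starting at 2^(-8/num_heads) with that same value as its ratio.
def build_alibi_bias(max_seq_len: int, num_heads: int) -> torch.Tensor:
    start = 2 ** (-8.0 / num_heads)
    slopes = torch.tensor([start ** (i + 1) for i in range(num_heads)])
    # relative position (j - i) is <= 0 for the allowed (past/present) keys
    rel_pos = torch.arange(max_seq_len)[None, :] - torch.arange(max_seq_len)[:, None]
    bias = slopes[:, None, None] * rel_pos.clamp(max=0)
    causal = torch.triu(torch.full((max_seq_len, max_seq_len), float("-inf")), diagonal=1)
    return bias + causal  # (num_heads, max_seq_len, max_seq_len)

# call sites would then just slice: build_alibi_bias(2048, 8)[..., :seq_len, :seq_len]
```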

@staticmethod
def get_slopes(num_heads: int) -> List[float]:
"""for n heads, a range from (0,1) and is the geometric sequence
that starts at 2^(-8/n) and uses this same value as its ratio
Contributor

Thank you for explaining/documenting the magic numbers 🙂
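For anyone reading along, the concrete values for n = 8 heads (a quick sketch matching the docstring's description):

```python
# Geometric sequence starting at 2^(-8/n) with that same value as its ratio;
# for n = 8 heads this is 1/2, 1/4, ..., 1/256.
n = 8
start = 2 ** (-8.0 / n)
slopes = [start ** (i + 1) for i in range(n)]
print(slopes)  # [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
```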

return get_slopes_power_of_2(num_heads)

# paper authors note that they only trained models that have 2^a heads for some a.
# This has beneficial properties related to input being power of 2.
Contributor

Do you know what these properties are? Tbh I am confused by this because even if n is a power of 2 some of the ratios will not be rational for n > 8

b = get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
: num_heads - closest_power_of_2
]
return [x for pair in zip(b, a) for x in pair] + a[len(b) :]
Contributor

IMO this is hard to parse. Agree with @daviswer's comment about returning values in order, but could we just do sorted(a+b)? (Maybe I'm missing a tricky case; if so, a comment explaining this would suffice instead.)
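For example, with num_heads = 6 (illustrative sketch; the helper is re-derived from the docstring's geometric-sequence rule):

```python
# Worked example of the splicing above for num_heads = 6, so closest_power_of_2 = 4.
def get_slopes_power_of_2(n):
    start = 2 ** (-8.0 / n)
    return [start ** (i + 1) for i in range(n)]

a = get_slopes_power_of_2(4)            # [0.25, 0.0625, 0.015625, 0.00390625]
b = get_slopes_power_of_2(8)[0::2][:2]  # [0.5, 0.125]
spliced = [x for pair in zip(b, a) for x in pair] + a[len(b):]
print(spliced)  # [0.5, 0.25, 0.125, 0.0625, 0.015625, 0.00390625]
# i.e. in this case the two lists end up merged in descending order
```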

# paper authors note that they only trained models that have 2^a heads for some a.
# This has beneficial properties related to input being power of 2.

# Closest power of 2 below is workaround for when num of heads is not power of 2
Contributor

Their method of interpolating is a bit unusual. Maybe explicitly explain that for $\mathrm{num\_heads} = 2^N + k$ they are splicing the geometric series with ratio $2^{-8/2^N}$ with every other one of the first $2k$ elements of the geometric series with ratio $2^{-8/2^{N+1}}$ (assuming I am even understanding it correctly 😅)
