Add Rotary Positional Embeddings (RoPE) - part 2 of parallel attention blocks #450

Open · wants to merge 9 commits into base: main
40 changes: 39 additions & 1 deletion tests/modules/layers/test_position_embedding.py
@@ -4,13 +4,16 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math

import pytest
import torch
from tests.test_utils import assert_expected
from torch import nn

from torchmultimodal.modules.layers.position_embedding import (
BroadcastedPositionEmbedding,
RotaryPositionalEmbeddings,
SinusoidalPositionEmbeddings,
)

@@ -112,3 +115,38 @@ def test_forward(self, data, emb):
actual = emb(data)
expected = torch.Size([3, 5])
assert_expected(actual.shape, expected)


def test_rotary_embeddings_math():
Contributor: Can we put these unit tests into a class? (Similar to the other tests in this file)

Contributor Author: yes, will do.

q = (
torch.tensor([[1, 0], [1, 0]], dtype=torch.float).unsqueeze(0).unsqueeze(0)
) # b h s e

k = 2 * torch.tensor([[1, 0], [1, 0]], dtype=torch.float).unsqueeze(0).unsqueeze(
0
) # b h s e

rotary_embeddings = RotaryPositionalEmbeddings(2, 2, 1)
qr, kr = rotary_embeddings(q, k, 0)
rot0 = torch.tensor([[math.cos(0), -math.sin(0)], [math.sin(0), math.cos(0)]])
rot1 = torch.tensor([[math.cos(1), -math.sin(1)], [math.sin(1), math.cos(1)]])

assert_expected(torch.matmul(rot0, q[..., 0, :].squeeze()), qr[..., 0, :].squeeze())
assert_expected(torch.matmul(rot1, q[..., 1, :].squeeze()), qr[..., 1, :].squeeze())
assert_expected(torch.matmul(rot0, k[..., 0, :].squeeze()), kr[..., 0, :].squeeze())
assert_expected(torch.matmul(rot1, k[..., 1, :].squeeze()), kr[..., 1, :].squeeze())


def test_rotary_embeddings_left_padding():
q = torch.ones(2, 1, 4, 16, dtype=torch.float) # b h s e
k = 2 * torch.ones(2, 1, 4, 16, dtype=torch.float) # b h s e
rotary_embeddings = RotaryPositionalEmbeddings(16, 32)

qr, kr = rotary_embeddings(q, k, 0)
qr2, kr2 = rotary_embeddings(q, k, torch.tensor([0, 1]))

assert_expected(qr[0], qr2[0])
assert_expected(qr[0, :, 1], qr2[1, :, 0])

assert_expected(kr[0], kr2[0])
assert_expected(kr[0, :, 1], kr2[1, :, 0])
Contributor: Can we also add a test for updating the cached frequencies? (As far as I can tell this second test is not hitting that block in L262-268, lmk if I'm misunderstanding)

Contributor Author: yes, that's a good idea.
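
Picking up both suggestions above (grouping the tests into a class and exercising the cache update), a rough sketch of what that could look like; the class and test names are hypothetical and not part of this PR, the two functions above would become methods of the same class, and the checks assume the cached_freqs buffer and max_seq_len_cached attribute from the current diff stay observable:

class TestRotaryPositionalEmbeddings:
    def test_cached_freqs_update(self):
        # The cache is precomputed for 4 positions at construction time.
        rotary_embeddings = RotaryPositionalEmbeddings(16, max_position_embeddings=4)
        assert rotary_embeddings.cached_freqs.shape[0] == 4

        q = torch.ones(1, 1, 4, 16, dtype=torch.float)  # b h s e
        k = 2 * torch.ones(1, 1, 4, 16, dtype=torch.float)  # b h s e

        # start_pos + seq_len = 2 + 4 runs past the 4 cached positions, so
        # forward should recompute the cache to cover 6 positions.
        rotary_embeddings(q, k, 2)
        assert rotary_embeddings.max_seq_len_cached == 6
        assert rotary_embeddings.cached_freqs.shape[0] == 6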

107 changes: 106 additions & 1 deletion torchmultimodal/modules/layers/position_embedding.py
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import itertools
from typing import Tuple
from typing import Optional, Tuple, Union

import torch
from torch import nn, Tensor
@@ -169,3 +169,108 @@ def forward(self, t: Tensor) -> Tensor:
if self.embed_dim % 2 == 1:
embeddings = nn.functional.pad(embeddings, (0, 1))
return embeddings


class RotaryPositionalEmbeddings(nn.Module):
def __init__(
self,
dim: int,
max_position_embeddings: Union[int, float] = 2048,
ratio: int = 10000,
device: Optional[torch.device] = None,
):
"""
Implements Rotary Positional Embeddings (RoPE)
proposed in: https://arxiv.org/abs/2104.09864

Args
----
dim : int
Per-head embedding dimension
max_position_embeddings : int
Maximum expected sequence length for the model; if exceeded, the cached freqs will be recomputed
ratio : int
The ratio for the geometric progression used to compute the rotation angles
"""
Contributor: It'd be nice to add more in the docstring on the exact details of these embeddings, e.g. at least the [[cos, -sin], [sin, cos]] matrix and maybe even a small example (like the simple 2D one you wrote for the unit test)
super().__init__()
self.register_buffer(
"freqs",
1.0
/ (
ratio
** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim)
),
)
self.compute_freqs_cis(max_position_embeddings)

def compute_freqs_cis(
Contributor: Random q: what does cis mean here?

Contributor Author: it's short form for the rotation transform: technically it's doing e^(alpha*i) = cos(alpha) + i * sin(alpha), or shortened, cos + i * sin = cis.

Contributor Author: should probably add that in the docstring actually, otherwise it's too cryptic.

self, max_position_embeddings: Union[int, float] = 2048
) -> None:
t = torch.arange(
max_position_embeddings, device=self.freqs.device, dtype=self.freqs.dtype
)
freqs = torch.outer(t, self.freqs).float()
self.max_seq_len_cached = max_position_embeddings
self.register_buffer(
"cached_freqs",
torch.stack(
[
torch.cos(freqs),
-torch.sin(freqs),
torch.sin(freqs),
torch.cos(freqs),
],
dim=2,
).view(*freqs.shape, 2, 2),
)

def reshape_for_broadcast(
self, x: torch.Tensor, cur_freqs: torch.Tensor
) -> torch.Tensor:
ndim = x.ndim
assert 1 < ndim
assert cur_freqs.shape[:2] == (x.shape[2], x.shape[-2])
shape = [d if i == 2 or i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
return cur_freqs.view(*shape, 2)

def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
start_pos: Union[int, float, torch.LongTensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args
----
q : torch.Tensor
Embedded query tensor, expected size is B x H x S x Eh
k : torch.Tensor
Embedded key tensor, expected size is B x H x S x Eh
start_pos : Union[int, torch.LongTensor]
The starting position of the tokens encoded in q and k. This is important for
kv-caching and left-padding situations, where the rotation to be applied might
not always start at the pre-cached positions 0...S. For kv-caching without dynamic
batching, start_pos is shared across the whole batch.
"""
seq_len = q.shape[2]
q_ = q.float().reshape(*q.shape[:-1], -1, 2) # B H L D/2 2
k_ = k.float().reshape(*k.shape[:-1], -1, 2) # B H L D/2 2

if isinstance(start_pos, int):
if start_pos + seq_len > self.max_seq_len_cached:
Contributor: Some comments here about when the frequencies need to be recomputed might be helpful

Contributor Author: sounds good - offhand should be changing dtype, changing device, and resetting seq len > max_seq_len.

# The requested positions run past the precomputed table, so rebuild the
# cache to cover start_pos + seq_len positions before indexing into it.
self.compute_freqs_cis(start_pos + seq_len)
cur_freqs = self.cached_freqs[start_pos : start_pos + seq_len]
freqs = self.reshape_for_broadcast(q_, cur_freqs)
else:
max_start_pos = torch.max(start_pos).item()
if max_start_pos + seq_len > self.max_seq_len_cached:
self.compute_freqs_cis(max_start_pos + seq_len)
freqs_idxs = torch.arange(0, seq_len, dtype=torch.long).repeat(
start_pos.shape[0]
).view(-1, seq_len) + start_pos.view(-1, 1)
freqs = self.cached_freqs[freqs_idxs].unsqueeze(1)

freqs = freqs.float() # 1 1 L D/2 2 2
q_out = freqs.mul(q_.unsqueeze(-2)).sum(5).flatten(3)
k_out = freqs.mul(k_.unsqueeze(-2)).sum(5).flatten(3)
return q_out.type_as(q).contiguous(), k_out.type_as(k).contiguous()
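
Not part of the PR, but a small sketch tying together the "cis" discussion above and the two start_pos call patterns described in the forward docstring. The identity check uses plain Python complex numbers; the module call assumes the class as defined in this diff, and the tensor shapes are illustrative only.

import math

import torch

from torchmultimodal.modules.layers.position_embedding import RotaryPositionalEmbeddings

# "cis" identity: multiplying x + iy by e^(i*theta) = cos(theta) + i*sin(theta)
# is the same as applying the 2x2 rotation matrix [[cos, -sin], [sin, cos]]
# to the vector [x, y]; the module caches exactly these 2x2 blocks per position.
theta = 0.3
x, y = 1.0, 2.0
rotated_complex = complex(x, y) * complex(math.cos(theta), math.sin(theta))
rot = torch.tensor(
    [[math.cos(theta), -math.sin(theta)], [math.sin(theta), math.cos(theta)]]
)
rotated_vec = rot @ torch.tensor([x, y])
assert torch.allclose(
    rotated_vec, torch.tensor([rotated_complex.real, rotated_complex.imag])
)

# start_pos can be a shared int (kv-caching without dynamic batching) or a
# per-sample LongTensor (left padding / dynamic batching).
rope = RotaryPositionalEmbeddings(64)
q = torch.randn(2, 8, 16, 64)  # B x H x S x Eh
k = torch.randn(2, 8, 16, 64)
qr, kr = rope(q, k, 0)  # every sample starts at position 0
qr2, kr2 = rope(q, k, torch.tensor([0, 3]))  # second sample starts at position 3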