diff --git a/tests/modules/layers/test_position_embedding.py b/tests/modules/layers/test_position_embedding.py
index 58339460..da292528 100644
--- a/tests/modules/layers/test_position_embedding.py
+++ b/tests/modules/layers/test_position_embedding.py
@@ -4,13 +4,16 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import pytest
+import math
+import pytest
 import torch
 from tests.test_utils import assert_expected
 from torch import nn
+
 from torchmultimodal.modules.layers.position_embedding import (
     BroadcastedPositionEmbedding,
+    RotaryPositionalEmbeddings,
     SinusoidalPositionEmbeddings,
 )
 
 
@@ -112,3 +115,38 @@ def test_forward(self, data, emb):
         actual = emb(data)
         expected = torch.Size([3, 5])
         assert_expected(actual.shape, expected)
+
+
+def test_rotary_embeddings_math():
+    q = (
+        torch.tensor([[1, 0], [1, 0]], dtype=torch.float).unsqueeze(0).unsqueeze(0)
+    )  # b h s e
+
+    k = 2 * torch.tensor([[1, 0], [1, 0]], dtype=torch.float).unsqueeze(0).unsqueeze(
+        0
+    )  # b h s e
+
+    rotary_embeddings = RotaryPositionalEmbeddings(2, 2, 1)
+    qr, kr = rotary_embeddings(q, k, 0)
+    rot0 = torch.tensor([[math.cos(0), -math.sin(0)], [math.sin(0), math.cos(0)]])
+    rot1 = torch.tensor([[math.cos(1), -math.sin(1)], [math.sin(1), math.cos(1)]])
+
+    assert_expected(torch.matmul(rot0, q[..., 0, :].squeeze()), qr[..., 0, :].squeeze())
+    assert_expected(torch.matmul(rot1, q[..., 1, :].squeeze()), qr[..., 1, :].squeeze())
+    assert_expected(torch.matmul(rot0, k[..., 0, :].squeeze()), kr[..., 0, :].squeeze())
+    assert_expected(torch.matmul(rot1, k[..., 1, :].squeeze()), kr[..., 1, :].squeeze())
+
+
+def test_rotary_embeddings_left_padding():
+    q = torch.ones(2, 1, 4, 16, dtype=torch.float)  # b h s e
+    k = 2 * torch.ones(2, 1, 4, 16, dtype=torch.float)  # b h s e
+    rotary_embeddings = RotaryPositionalEmbeddings(16, 32)
+
+    qr, kr = rotary_embeddings(q, k, 0)
+    qr2, kr2 = rotary_embeddings(q, k, torch.tensor([0, 1]))
+
+    assert_expected(qr[0], qr2[0])
+    assert_expected(qr[0, :, 1], qr2[1, :, 0])
+
+    assert_expected(kr[0], kr2[0])
+    assert_expected(kr[0, :, 1], kr2[1, :, 0])
diff --git a/torchmultimodal/modules/layers/position_embedding.py b/torchmultimodal/modules/layers/position_embedding.py
index 7920ce1c..acb4a927 100644
--- a/torchmultimodal/modules/layers/position_embedding.py
+++ b/torchmultimodal/modules/layers/position_embedding.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import itertools
-from typing import Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 from torch import nn, Tensor
@@ -169,3 +169,108 @@ def forward(self, t: Tensor) -> Tensor:
         if self.embed_dim % 2 == 1:
             embeddings = nn.functional.pad(embeddings, (0, 1))
         return embeddings
+
+
+class RotaryPositionalEmbeddings(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        max_position_embeddings: Union[int, float] = 2048,
+        ratio: int = 10000,
+        device: Optional[torch.device] = None,
+    ):
+        """
+        Implements Rotary Positional Embeddings (RoPE)
+        proposed in https://arxiv.org/abs/2104.09864
+
+        Args
+        ----
+        dim : int
+            Per-head embedding dimension
+        max_position_embeddings : int
+            Maximum expected sequence length for the model; if exceeded, the cached freqs are recomputed
+        ratio : int
+            The ratio for the geometric progression used to compute the rotation angles
+        """
+        super().__init__()
+        self.register_buffer(
+            "freqs",
+            1.0
+            / (
+                ratio
+                ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim)
+            ),
+        )
+        self.compute_freqs_cis(max_position_embeddings)
+
+    def compute_freqs_cis(
+        self, max_position_embeddings: Union[int, float] = 2048
+    ) -> None:
+        t = torch.arange(
+            max_position_embeddings, device=self.freqs.device, dtype=self.freqs.dtype
+        )
+        freqs = torch.outer(t, self.freqs).float()
+        self.max_seq_len_cached = max_position_embeddings
+        self.register_buffer(
+            "cached_freqs",
+            torch.stack(
+                [
+                    torch.cos(freqs),
+                    -torch.sin(freqs),
+                    torch.sin(freqs),
+                    torch.cos(freqs),
+                ],
+                dim=2,
+            ).view(*freqs.shape, 2, 2),
+        )
+
+    def reshape_for_broadcast(
+        self, x: torch.Tensor, cur_freqs: torch.Tensor
+    ) -> torch.Tensor:
+        ndim = x.ndim
+        assert 1 < ndim
+        assert cur_freqs.shape[:2] == (x.shape[2], x.shape[-2])
+        shape = [d if i == 2 or i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
+        return cur_freqs.view(*shape, 2)
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        start_pos: Union[int, float, torch.LongTensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args
+        ----
+        q : torch.Tensor
+            Embedded query tensor, expected size is B x H x S x Eh
+        k : torch.Tensor
+            Embedded key tensor, expected size is B x H x S x Eh
+        start_pos : Union[int, torch.LongTensor]
+            The starting position of the tokens encoded in q and k. This matters for
+            kv-caching and left-padding, where the rotation to apply is not always the
+            pre-cached positions 0...S. For kv-caching without dynamic batching,
+            start_pos is shared across the batch.
+        """
+        seq_len = q.shape[2]
+        q_ = q.float().reshape(*q.shape[:-1], -1, 2)  # B H L D/2 2
+        k_ = k.float().reshape(*k.shape[:-1], -1, 2)  # B H L D/2 2
+
+        if isinstance(start_pos, int):
+            if start_pos + seq_len > self.max_seq_len_cached:
+                self.compute_freqs_cis(start_pos + seq_len)
+            cur_freqs = self.cached_freqs[start_pos : start_pos + seq_len]
+            freqs = self.reshape_for_broadcast(q_, cur_freqs)
+        else:
+            max_start_pos = torch.max(start_pos).item()
+            if max_start_pos + seq_len > self.max_seq_len_cached:
+                self.compute_freqs_cis(max_start_pos + seq_len)
+            freqs_idxs = torch.arange(0, seq_len, dtype=torch.long).repeat(
+                start_pos.shape[0]
+            ).view(-1, seq_len) + start_pos.view(-1, 1)
+            freqs = self.cached_freqs[freqs_idxs].unsqueeze(1)
+
+        freqs = freqs.float()  # 1 1 L D/2 2 2
+        q_out = freqs.mul(q_.unsqueeze(-2)).sum(5).flatten(3)
+        k_out = freqs.mul(k_.unsqueeze(-2)).sum(5).flatten(3)
+        return q_out.type_as(q).contiguous(), k_out.type_as(k).contiguous()
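
Usage sketch (not part of the patch): a minimal example of how the new module would be wired into an attention computation once this diff is applied. Tensor shapes follow the B x H x S x Eh convention from the forward() docstring; the sizes, the random projections, and the toy attention step below are illustrative assumptions, not code from this PR.

# Usage sketch -- not part of the patch.
import torch

from torchmultimodal.modules.layers.position_embedding import (
    RotaryPositionalEmbeddings,
)

batch, heads, seq_len, head_dim = 2, 4, 8, 16

# One RoPE module per attention layer; the cos/sin table is cached up to
# max_position_embeddings and recomputed on demand for longer ranges.
rope = RotaryPositionalEmbeddings(head_dim, max_position_embeddings=32)

q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)

# Prefill: every sequence starts at position 0 (int start_pos shared by the batch).
qr, kr = rope(q, k, 0)
attn = torch.softmax(qr @ kr.transpose(-2, -1) / head_dim**0.5, dim=-1)

# Single decoding step with a kv-cache: the new token sits at position seq_len.
q_step = torch.randn(batch, heads, 1, head_dim)
k_step = torch.randn(batch, heads, 1, head_dim)
qr_step, kr_step = rope(q_step, k_step, seq_len)

# Left-padded batch: per-sample start positions as a LongTensor.
qr_pad, kr_pad = rope(q, k, torch.tensor([0, 3]))

An int start_pos covers the shared kv-cache case, while a per-sample LongTensor covers left-padded batches, mirroring test_rotary_embeddings_left_padding above.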