From 1f273a09c683a33d4e21caa7f5c89f899deb30b7 Mon Sep 17 00:00:00 2001
From: lessw2020
Date: Thu, 17 Aug 2023 04:37:55 +0000
Subject: [PATCH 1/8] add Rotary Embeddings main code

---
 .../modules/layers/position_embedding.py | 98 +++++++++++++++++++
 1 file changed, 98 insertions(+)

diff --git a/torchmultimodal/modules/layers/position_embedding.py b/torchmultimodal/modules/layers/position_embedding.py
index 7920ce1c..b244d9d0 100644
--- a/torchmultimodal/modules/layers/position_embedding.py
+++ b/torchmultimodal/modules/layers/position_embedding.py
@@ -169,3 +169,101 @@ def forward(self, t: Tensor) -> Tensor:
         if self.embed_dim % 2 == 1:
             embeddings = nn.functional.pad(embeddings, (0, 1))
         return embeddings
+
+
+class RotaryEmbedding(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        max_position_embeddings: int = 2048,
+        ratio: int = 10000,
+        device=None,
+    ):
+        """
+        Implements Rotary Positional Embeddings (RoPE)
+        proposed in: https://arxiv.org/abs/2104.09864
+
+        Args
+        ----
+        dim : int
+            Per-head embedding dimension
+        max_position_embeddings : int
+            Maximum expected sequence length for the model; if exceeded, the cached freqs are recomputed
+        ratio: int
+            The ratio for the geometric progression to compute the rotation angles
+        """
+        super(RotaryEmbedding, self).__init__()
+        self.register_buffer(
+            "freqs",
+            1.0
+            / (
+                ratio
+                ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim)
+            ),
+        )
+        self.compute_freqs_cis(max_position_embeddings)
+
+    def compute_freqs_cis(self, max_position_embeddings=2048):
+        t = torch.arange(
+            max_position_embeddings, device=self.freqs.device, dtype=self.freqs.dtype
+        )
+        freqs = torch.outer(t, self.freqs).float()
+        self.max_seq_len_cached = max_position_embeddings
+        self.register_buffer(
+            "cached_freqs",
+            torch.stack(
+                [
+                    torch.cos(freqs),
+                    -torch.sin(freqs),
+                    torch.sin(freqs),
+                    torch.cos(freqs),
+                ],
+                dim=2,
+            ).view(*freqs.shape, 2, 2),
+        )
+
+    def reshape_for_broadcast(self, x: torch.Tensor, cur_freqs):
+        ndim = x.ndim
+        assert 1 < ndim
+        assert cur_freqs.shape[:2] == (x.shape[2], x.shape[-2])
+        shape = [d if i == 2 or i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
+        return cur_freqs.view(*shape, 2)
+
+    def forward(
+        self, q: torch.Tensor, k: torch.Tensor, start_pos: Union[int, torch.LongTensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args
+        ----
+        q : torch.Tensor
+            Embedded query tensor, expected size is B x H x S x Eh
+        k : torch.Tensor
+            Embedded key tensor, expected size is B x H x S x Eh
+        start_pos : Union[int, torch.LongTensor]
+            The starting position of the tokens encoded in q and k. This is important for
+            kv-caching and left-padding situations, where the rotation to be applied might
+            not always correspond to the pre-cached positions 0...S. For kv-caching without
+            dynamic batching, start_pos is shared across the whole batch.
+        """
+        seq_len = q.shape[2]
+        q_ = q.float().reshape(*q.shape[:-1], -1, 2)  # B H L D/2 2
+        k_ = k.float().reshape(*k.shape[:-1], -1, 2)  # B H L D/2 2
+
+        if isinstance(start_pos, int):
+            if start_pos + seq_len > self.max_seq_len_cached:
+                self.compute_freqs_cis(start_pos + seq_len)
+            cur_freqs = self.cached_freqs[start_pos : start_pos + seq_len]
+            freqs = self.reshape_for_broadcast(q_, cur_freqs)
+        else:
+            max_start_pos = torch.max(start_pos).item()
+            if max_start_pos + seq_len > self.max_seq_len_cached:
+                self.compute_freqs_cis(max_start_pos + seq_len)
+            freqs_idxs = torch.arange(0, seq_len, dtype=torch.long).repeat(
+                start_pos.shape[0]
+            ).view(-1, seq_len) + start_pos.view(-1, 1)
+            freqs = self.cached_freqs[freqs_idxs].unsqueeze(1)
+
+        freqs = freqs.float()  # 1 1 L D/2 2 2
+        q_out = freqs.mul(q_.unsqueeze(-2)).sum(5).flatten(3)
+        k_out = freqs.mul(k_.unsqueeze(-2)).sum(5).flatten(3)
+        return q_out.type_as(q).contiguous(), k_out.type_as(k).contiguous()

From b01516a0f297b9d1b10ed81f5ae209985f98644c Mon Sep 17 00:00:00 2001
From: lessw2020
Date: Thu, 17 Aug 2023 04:59:09 +0000
Subject: [PATCH 2/8] add rotary position tests

---
 .../modules/layers/test_position_embedding.py | 40 ++++++++++++++++++-
 .../modules/layers/position_embedding.py      |  6 +--
 2 files changed, 42 insertions(+), 4 deletions(-)

diff --git a/tests/modules/layers/test_position_embedding.py b/tests/modules/layers/test_position_embedding.py
index 58339460..da292528 100644
--- a/tests/modules/layers/test_position_embedding.py
+++ b/tests/modules/layers/test_position_embedding.py
@@ -4,13 +4,16 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import pytest
+import math
 
+import pytest
 import torch
 from tests.test_utils import assert_expected
 from torch import nn
+
 from torchmultimodal.modules.layers.position_embedding import (
     BroadcastedPositionEmbedding,
+    RotaryPositionalEmbeddings,
     SinusoidalPositionEmbeddings,
 )
@@ -112,3 +115,38 @@ def test_forward(self, data, emb):
         actual = emb(data)
         expected = torch.Size([3, 5])
         assert_expected(actual.shape, expected)
+
+
+def test_rotary_embeddings_math():
+    q = (
+        torch.tensor([[1, 0], [1, 0]], dtype=torch.float).unsqueeze(0).unsqueeze(0)
+    )  # b h s e
+
+    k = 2 * torch.tensor([[1, 0], [1, 0]], dtype=torch.float).unsqueeze(0).unsqueeze(
+        0
+    )  # b h s e
+
+    rotary_embeddings = RotaryPositionalEmbeddings(2, 2, 1)
+    qr, kr = rotary_embeddings(q, k, 0)
+    rot0 = torch.tensor([[math.cos(0), -math.sin(0)], [math.sin(0), math.cos(0)]])
+    rot1 = torch.tensor([[math.cos(1), -math.sin(1)], [math.sin(1), math.cos(1)]])
+
+    assert_expected(torch.matmul(rot0, q[..., 0, :].squeeze()), qr[..., 0, :].squeeze())
+    assert_expected(torch.matmul(rot1, q[..., 1, :].squeeze()), qr[..., 1, :].squeeze())
+    assert_expected(torch.matmul(rot0, k[..., 0, :].squeeze()), kr[..., 0, :].squeeze())
+    assert_expected(torch.matmul(rot1, k[..., 1, :].squeeze()), kr[..., 1, :].squeeze())
+
+
+def test_rotary_embeddings_left_padding():
+    q = torch.ones(2, 1, 4, 16, dtype=torch.float)  # b h s e
+    k = 2 * torch.ones(2, 1, 4, 16, dtype=torch.float)  # b h s e
+    rotary_embeddings = RotaryPositionalEmbeddings(16, 32)
+
+    qr, kr = rotary_embeddings(q, k, 0)
+    qr2, kr2 = rotary_embeddings(q, k, torch.tensor([0, 1]))
+
+    assert_expected(qr[0], qr2[0])
+    assert_expected(qr[0, :, 1], qr2[1, :, 0])
+
+    assert_expected(kr[0], kr2[0])
+    assert_expected(kr[0, :, 1], kr2[1, :, 0])
diff --git a/torchmultimodal/modules/layers/position_embedding.py b/torchmultimodal/modules/layers/position_embedding.py
index b244d9d0..7847793a 100644
--- a/torchmultimodal/modules/layers/position_embedding.py
+++ b/torchmultimodal/modules/layers/position_embedding.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import itertools
-from typing import Tuple
+from typing import Tuple, Union
 
 import torch
 from torch import nn, Tensor
@@ -171,7 +171,7 @@ def forward(self, t: Tensor) -> Tensor:
         return embeddings
 
 
-class RotaryEmbedding(nn.Module):
+class RotaryPositionalEmbeddings(nn.Module):
     def __init__(
         self,
         dim: int,
@@ -192,7 +192,7 @@ def __init__(
         ratio: int
             The ratio for the geometric progression to compute the rotation angles
         """
-        super(RotaryEmbedding, self).__init__()
+        super().__init__()
         self.register_buffer(
             "freqs",
             1.0

From ceda766caf80ad383cba0bf883ead5d7c9907279 Mon Sep 17 00:00:00 2001
From: lessw2020
Date: Thu, 17 Aug 2023 05:16:56 +0000
Subject: [PATCH 3/8] add missing types

---
 torchmultimodal/modules/layers/position_embedding.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torchmultimodal/modules/layers/position_embedding.py b/torchmultimodal/modules/layers/position_embedding.py
index 7847793a..91c8db3c 100644
--- a/torchmultimodal/modules/layers/position_embedding.py
+++ b/torchmultimodal/modules/layers/position_embedding.py
@@ -222,7 +222,9 @@ def compute_freqs_cis(self, max_position_embeddings=2048):
             ).view(*freqs.shape, 2, 2),
         )
 
-    def reshape_for_broadcast(self, x: torch.Tensor, cur_freqs):
+    def reshape_for_broadcast(
+        self, x: torch.Tensor, cur_freqs: torch.Tensor
+    ) -> torch.Tensor:
         ndim = x.ndim
         assert 1 < ndim
         assert cur_freqs.shape[:2] == (x.shape[2], x.shape[-2])

From aafa9b34c84fc1b8d2e7a224983d6a8b9dc85320 Mon Sep 17 00:00:00 2001
From: lessw2020
Date: Thu, 17 Aug 2023 05:19:13 +0000
Subject: [PATCH 4/8] add int typing for embedding count

---
 torchmultimodal/modules/layers/position_embedding.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchmultimodal/modules/layers/position_embedding.py b/torchmultimodal/modules/layers/position_embedding.py
index 91c8db3c..49b26250 100644
--- a/torchmultimodal/modules/layers/position_embedding.py
+++ b/torchmultimodal/modules/layers/position_embedding.py
@@ -203,7 +203,7 @@ def __init__(
         )
         self.compute_freqs_cis(max_position_embeddings)
 
-    def compute_freqs_cis(self, max_position_embeddings=2048):
+    def compute_freqs_cis(self, max_position_embeddings: int = 2048):
         t = torch.arange(
             max_position_embeddings, device=self.freqs.device, dtype=self.freqs.dtype
         )

From a2a98ba1e3411062d9abb84eb78741366d35fee4 Mon Sep 17 00:00:00 2001
From: lessw2020
Date: Thu, 17 Aug 2023 05:20:26 +0000
Subject: [PATCH 5/8] add torch.device typing

---
 torchmultimodal/modules/layers/position_embedding.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchmultimodal/modules/layers/position_embedding.py b/torchmultimodal/modules/layers/position_embedding.py
index 49b26250..894ea045 100644
--- a/torchmultimodal/modules/layers/position_embedding.py
+++ b/torchmultimodal/modules/layers/position_embedding.py
@@ -177,7 +177,7 @@ def __init__(
         dim: int,
         max_position_embeddings: int = 2048,
         ratio: int = 10000,
-        device=None,
+        device: torch.device = None,
     ):
         """
         Implements Rotary Positional Embeddings (RoPE)

From 15c746962b43c6d47bd79de16b757a96dacd7a2b Mon Sep 17 00:00:00 2001
From: lessw2020
Date: Thu, 17 Aug 2023 18:41:09 +0000
Subject: [PATCH 6/8] add typing Union for compute_freq_cis

---
 torchmultimodal/modules/layers/position_embedding.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torchmultimodal/modules/layers/position_embedding.py b/torchmultimodal/modules/layers/position_embedding.py
index 894ea045..419c2dd3 100644
--- a/torchmultimodal/modules/layers/position_embedding.py
+++ b/torchmultimodal/modules/layers/position_embedding.py
@@ -203,7 +203,9 @@ def __init__(
         )
         self.compute_freqs_cis(max_position_embeddings)
 
-    def compute_freqs_cis(self, max_position_embeddings: int = 2048):
+    def compute_freqs_cis(
+        self, max_position_embeddings: Union[int, torch.LongTensor] = 2048
+    ) -> None:
         t = torch.arange(
             max_position_embeddings, device=self.freqs.device, dtype=self.freqs.dtype
         )

From 7c21fbaaf94419c837121e5c3815d4015e750e30 Mon Sep 17 00:00:00 2001
From: lessw2020
Date: Thu, 17 Aug 2023 19:23:38 +0000
Subject: [PATCH 7/8] more typing - Union[int,float] for start_pos

---
 torchmultimodal/modules/layers/position_embedding.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/torchmultimodal/modules/layers/position_embedding.py b/torchmultimodal/modules/layers/position_embedding.py
index 419c2dd3..5a378a4e 100644
--- a/torchmultimodal/modules/layers/position_embedding.py
+++ b/torchmultimodal/modules/layers/position_embedding.py
@@ -175,7 +175,7 @@ class RotaryPositionalEmbeddings(nn.Module):
     def __init__(
         self,
         dim: int,
-        max_position_embeddings: int = 2048,
+        max_position_embeddings: Union[int, float] = 2048,
         ratio: int = 10000,
         device: torch.device = None,
     ):
@@ -204,7 +204,7 @@ def __init__(
         self.compute_freqs_cis(max_position_embeddings)
 
     def compute_freqs_cis(
-        self, max_position_embeddings: Union[int, torch.LongTensor] = 2048
+        self, max_position_embeddings: Union[int, float] = 2048
     ) -> None:
         t = torch.arange(
             max_position_embeddings, device=self.freqs.device, dtype=self.freqs.dtype
@@ -234,7 +234,7 @@ def reshape_for_broadcast(
         return cur_freqs.view(*shape, 2)
 
     def forward(
-        self, q: torch.Tensor, k: torch.Tensor, start_pos: Union[int, torch.LongTensor]
+        self, q: torch.Tensor, k: torch.Tensor, start_pos: Union[int, float]
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args

From 06fe2a99ab45d4c0e14d7b8e9778808523802821 Mon Sep 17 00:00:00 2001
From: lessw2020
Date: Fri, 18 Aug 2023 21:55:13 +0000
Subject: [PATCH 8/8] add torch.LongTensor typing

---
 torchmultimodal/modules/layers/position_embedding.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/torchmultimodal/modules/layers/position_embedding.py b/torchmultimodal/modules/layers/position_embedding.py
index 5a378a4e..acb4a927 100644
--- a/torchmultimodal/modules/layers/position_embedding.py
+++ b/torchmultimodal/modules/layers/position_embedding.py
@@ -234,7 +234,10 @@ def reshape_for_broadcast(
         return cur_freqs.view(*shape, 2)
 
     def forward(
-        self, q: torch.Tensor, k: torch.Tensor, start_pos: Union[int, float]
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        start_pos: Union[int, float, torch.LongTensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args
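
Editor's note: the snippet below is a minimal usage sketch of the RotaryPositionalEmbeddings module added by this series, assuming the final signatures above. It is illustrative only and not part of the patches; the tensor sizes and the kv-cache names (q_step, k_step, cache_len) are made up for the example.

    import torch
    from torchmultimodal.modules.layers.position_embedding import RotaryPositionalEmbeddings

    # q and k follow the documented layout: B x H x S x Eh (per-head embedding dim).
    batch, heads, seq_len, head_dim = 2, 4, 16, 64
    q = torch.randn(batch, heads, seq_len, head_dim)
    k = torch.randn(batch, heads, seq_len, head_dim)

    rope = RotaryPositionalEmbeddings(dim=head_dim, max_position_embeddings=128)
    q_rot, k_rot = rope(q, k, 0)  # start_pos=0 rotates positions 0..S-1
    assert q_rot.shape == q.shape and k_rot.shape == k.shape

    # With kv-caching, pass the absolute position of the first new token instead,
    # e.g. rope(q_step, k_step, cache_len) for a single decode step; a LongTensor of
    # per-sample offsets handles left-padded batches as in the left-padding test above.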