diff --git a/tests/modules/layers/test_normalizations.py b/tests/modules/layers/test_normalizations.py
index e9ef3817..ed16d052 100644
--- a/tests/modules/layers/test_normalizations.py
+++ b/tests/modules/layers/test_normalizations.py
@@ -5,7 +5,13 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
-from torchmultimodal.modules.layers.normalizations import Fp32GroupNorm, Fp32LayerNorm
+from tests.test_utils import assert_expected
+
+from torchmultimodal.modules.layers.normalizations import (
+    Fp32GroupNorm,
+    Fp32LayerNorm,
+    RMSNorm,
+)
 
 
 def test_fp32layernorm():
@@ -20,3 +26,38 @@ def test_fp32groupnorm():
     norm = Fp32GroupNorm(2, 4)
     output = norm(x)
     assert output.dtype == torch.float16
+
+
+def test_rms_norm_core_algo():
+    """Compare RMSNorm output against expected values from a reference F.norm-based implementation."""
+    dims = 10
+    rms_norm = RMSNorm(dims)
+
+    input_ones = torch.ones(dims, dtype=torch.float)
+
+    input_fixed = torch.tensor(
+        [0.999, 1.1111, 2.222, 3.333, 4.444, 5.555, 6.678, 7.987, 8.123, 9.101010],
+        dtype=torch.float16,
+    )
+    fixed_expected = torch.tensor(
+        [
+            0.1749,
+            0.1946,
+            0.3892,
+            0.5835,
+            0.7783,
+            0.9727,
+            1.1699,
+            1.3984,
+            1.4229,
+            1.5938,
+        ],
+        dtype=torch.float,
+    )
+
+    output_fixed = rms_norm(input_fixed)
+    output_ones = rms_norm(input_ones)
+
+    assert_expected(output_ones, input_ones)
+    assert_expected(output_fixed, fixed_expected, atol=1e-04, rtol=1e-05)
+    assert output_fixed.dtype == torch.float32
diff --git a/torchmultimodal/modules/layers/normalizations.py b/torchmultimodal/modules/layers/normalizations.py
index b4fc64b7..63623737 100644
--- a/torchmultimodal/modules/layers/normalizations.py
+++ b/torchmultimodal/modules/layers/normalizations.py
@@ -6,6 +6,7 @@
 
 from typing import Any
 
+import torch
 from torch import nn, Tensor
 
 
@@ -45,3 +46,29 @@ def forward(self, x: Tensor) -> Tensor:
             self.eps,
         )
         return output.type_as(x)
+
+
+class RMSNorm(nn.Module):
+    """Root Mean Square Layer Normalization
+    as proposed in: https://arxiv.org/abs/1910.07467
+
+    Calculations are done in fp32.
+
+    Original impl: https://github.com/facebookresearch/llama/blob/main/llama/model.py
+
+    Args:
+        dim (int): model dimension
+        eps (float): small value added for numerical stability
+    """
+
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.scale = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x: Tensor) -> Tensor:
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x_normed = self._norm(x.float()).type_as(x)
+        return x_normed * self.scale
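
For reference, a minimal usage sketch of the new layer (not part of the patch; the shapes, dtype, and variable names below are illustrative only):

    import torch

    from torchmultimodal.modules.layers.normalizations import RMSNorm

    # RMSNorm rescales activations by the root mean square over the last dimension
    # (no mean subtraction, unlike LayerNorm) and applies a learned per-dimension scale.
    norm = RMSNorm(dim=512)
    x = torch.randn(2, 16, 512, dtype=torch.float16)
    y = norm(x)

    # The RMS statistic is computed in fp32 and cast back to the input dtype; multiplying
    # by the fp32 scale parameter then promotes the output to float32, which matches the
    # dtype assertion in the new test.
    print(y.shape, y.dtype)  # torch.Size([2, 16, 512]) torch.float32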