Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RMSNorm (part of parallel_attn_blocks) #448

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion tests/modules/layers/test_normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
# LICENSE file in the root directory of this source tree.

import torch
from torchmultimodal.modules.layers.normalizations import Fp32GroupNorm, Fp32LayerNorm
from tests.test_utils import assert_expected

from torchmultimodal.modules.layers.normalizations import (
Fp32GroupNorm,
Fp32LayerNorm,
RMSNorm,
)


def test_fp32layernorm():
Expand All @@ -20,3 +26,38 @@ def test_fp32groupnorm():
norm = Fp32GroupNorm(2, 4)
output = norm(x)
assert output.dtype == torch.float16


def test_rms_norm_core_algo():
lessw2020 marked this conversation as resolved.
Show resolved Hide resolved
"""compare RMSNorm with RMSNorm using F.norm version"""
dims = 10
rms_norm = RMSNorm(dims)

input_ones = torch.ones(dims, dtype=torch.float)

input_fixed = torch.tensor(
[0.999, 1.1111, 2.222, 3.333, 4.444, 5.555, 6.678, 7.987, 8.123, 9.101010],
dtype=torch.float16,
)
fixed_expected = torch.tensor(
[
0.1749,
0.1946,
0.3892,
0.5835,
0.7783,
0.9727,
1.1699,
1.3984,
1.4229,
1.5938,
],
dtype=torch.float,
)

output_fixed = rms_norm(input_fixed)
output_ones = rms_norm(input_ones)

assert_expected(output_ones, input_ones)
assert_expected(output_fixed, fixed_expected, atol=1e-04, rtol=1e-05)
assert output_fixed.dtype == torch.float32
27 changes: 27 additions & 0 deletions torchmultimodal/modules/layers/normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from typing import Any

import torch
from torch import nn, Tensor


Expand Down Expand Up @@ -45,3 +46,29 @@ def forward(self, x: Tensor) -> Tensor:
self.eps,
)
return output.type_as(x)


class RMSNorm(nn.Module):
"""Root Mean Square Layer Normalization
as proposed in: https://arxiv.org/abs/1910.07467

lessw2020 marked this conversation as resolved.
Show resolved Hide resolved
Calcs are done in fp32.

original impl: https://github.com/facebookresearch/llama/blob/main/llama/model.py

Args:
dim(int) = model size
eps(float) = epsilon
"""

def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.scale = nn.Parameter(torch.ones(dim))

def _norm(self, x: Tensor) -> Tensor:
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x: Tensor) -> Tensor:
x_normed = self._norm(x.float()).type_as(x)
return x_normed * self.scale
Loading