diff --git a/tests/modules/layers/test_normalizations.py b/tests/modules/layers/test_normalizations.py
index ed16d052..762f60e1 100644
--- a/tests/modules/layers/test_normalizations.py
+++ b/tests/modules/layers/test_normalizations.py
@@ -11,6 +11,7 @@
     Fp32GroupNorm,
     Fp32LayerNorm,
     RMSNorm,
+    SimpleRMSNorm,
 )
 
 
@@ -61,3 +62,63 @@ def test_rms_norm_core_algo():
     assert_expected(output_ones, input_ones)
     assert_expected(output_fixed, fixed_expected, atol=1e-04, rtol=1e-05)
     assert output_fixed.dtype == torch.float32
+
+
+def test_simple_rmsnorm():
+    dims = 12
+    srms_norm = SimpleRMSNorm(dims)
+
+    input_bf16_ones = torch.ones(dims, dtype=torch.bfloat16)
+
+    input_fixed_fp32 = torch.tensor(
+        [
+            0.999,
+            1.1111,
+            2.222,
+            3.333,
+            4.444,
+            5.555,
+            6.678,
+            7.987,
+            8.123,
+            9.101010,
+            110.00,
+            120.2589,
+        ],
+        dtype=torch.float32,
+    )
+
+    expected_output_bf16_ones = torch.tensor(
+        [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+        dtype=torch.bfloat16,
+    )
+    expected_output_fixed = torch.tensor(
+        [
+            0.0211,
+            0.0235,
+            0.0469,
+            0.0704,
+            0.0939,
+            0.1174,
+            0.1411,
+            0.1687,
+            0.1716,
+            0.1923,
+            2.3238,
+            2.5405,
+        ],
+        dtype=torch.float32,
+    )
+
+    actual_output_bf16_ones = srms_norm(input_bf16_ones)
+    actual_output_fixed = srms_norm(input_fixed_fp32)
+
+    # verify ones output and dtype
+    assert_expected(
+        actual_output_bf16_ones, expected_output_bf16_ones, atol=1e-04, rtol=1e-05
+    )
+    assert actual_output_bf16_ones.dtype == torch.bfloat16
+
+    # verify fixed output and dtype
+    assert_expected(actual_output_fixed, expected_output_fixed, atol=1e-04, rtol=1e-05)
+    assert actual_output_fixed.dtype == torch.float32
diff --git a/torchmultimodal/modules/layers/normalizations.py b/torchmultimodal/modules/layers/normalizations.py
index 63623737..9e835f9f 100644
--- a/torchmultimodal/modules/layers/normalizations.py
+++ b/torchmultimodal/modules/layers/normalizations.py
@@ -72,3 +72,26 @@ def _norm(self, x: Tensor) -> Tensor:
     def forward(self, x: Tensor) -> Tensor:
         x_normed = self._norm(x.float()).type_as(x)
         return x_normed * self.scale
+
+
+class SimpleRMSNorm(nn.Module):
+    """Simple RMSNorm
+
+    SRMSNorm(x) = (x / ∥x∥₂) * √d
+
+    as proposed in:
+    Scaling TransNormer to 175 Billion Parameters
+    https://arxiv.org/abs/2307.14995
+
+    Usage: use as a drop-in replacement for RMSNorm.
+    """
+
+    def __init__(self, dim: int, eps: float = 1e-12):
+        super().__init__()
+        self.scaling = dim**0.5
+        self.eps = eps
+
+    def forward(self, x: Tensor) -> Tensor:
+        # L2-normalize along the last dim; clamp the norm to eps to avoid division by zero
+        denom = x.norm(p=2, dim=-1, keepdim=True).clamp_min(self.eps).expand_as(x)
+        return (x / denom) * self.scaling
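
Reviewer note: a minimal usage sketch (not part of the diff) of SimpleRMSNorm as a drop-in replacement for RMSNorm. Unlike RMSNorm, the layer carries no learned scale parameter, so the output is fully determined by the input. The tensor shapes below are illustrative assumptions; the import path and the SRMSNorm(x) = (x / ∥x∥₂) * √d identity come from the code added above, which implies every feature vector along the last dim is rescaled to L2 norm √d (i.e. unit RMS).

import torch

from torchmultimodal.modules.layers.normalizations import SimpleRMSNorm

dim = 12
norm = SimpleRMSNorm(dim)

x = torch.randn(2, 8, dim)  # (batch, seq_len, dim) -- illustrative shape
y = norm(x)

# Each feature vector is rescaled to L2 norm sqrt(d), i.e. unit RMS.
expected_norm = torch.full(y.shape[:-1], dim**0.5)
torch.testing.assert_close(y.norm(p=2, dim=-1), expected_norm)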