From 0793eb497d8f08b5a41a1046ca40983c3e09072c Mon Sep 17 00:00:00 2001
From: lessw2020
Date: Tue, 26 Sep 2023 08:59:47 -0700
Subject: [PATCH] add SimpleRMSNorm (#465)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
This PR implements SimpleRMSNorm, as proposed in "Scaling TransNormer to
175 Billion Parameters". From the paper:

"In TransNormerLLM, we replace the RMSNorm with a new simple normalization
function called SimpleRMSNorm, abbreviated as SRMSNorm:

    SRMSNorm(x) = x / (∥x∥2 / √d)

We empirically find that using SRMSNorm does not lead to any performance
loss, as demonstrated in the ablation study [below]:

    Norm Type    Params    Updates    Loss     PPL
    SRMSNorm     385M      100K       2.247    4.765
    RMSNorm      385M      100K       2.247    4.766
    LayerNorm    385M      100K       2.247    4.765
"

Note that their architecture is not a TransFormer but a TransNormer;
therefore, I also tested this on a GPT-2 Transformer and saw equivalent
results between LayerNorm and SimpleRMSNorm, as shown below:

[figure: simpleRMS_gpt2]

In addition, SimpleRMSNorm is ~34% faster than regular RMSNorm (eager-mode
comparison).

Pull Request resolved: https://github.com/facebookresearch/multimodal/pull/465

Test Plan: Tested on GPT-2 training as shown above, and added 4 unit tests
(2 for the BF16 dtype and 2 for FP32).

Reviewed By: ebsmothers

Differential Revision: D49638459

Pulled By: pbontrager

fbshipit-source-id: 203b2bdd95dd79a5817060d85fc5920c6523733a
---
 tests/modules/layers/test_normalizations.py | 61 +++++++++++++++++++
 .../modules/layers/normalizations.py        | 22 +++++++
 2 files changed, 83 insertions(+)

diff --git a/tests/modules/layers/test_normalizations.py b/tests/modules/layers/test_normalizations.py
index ed16d052..762f60e1 100644
--- a/tests/modules/layers/test_normalizations.py
+++ b/tests/modules/layers/test_normalizations.py
@@ -11,6 +11,7 @@
     Fp32GroupNorm,
     Fp32LayerNorm,
     RMSNorm,
+    SimpleRMSNorm,
 )
 
 
@@ -61,3 +62,63 @@ def test_rms_norm_core_algo():
     assert_expected(output_ones, input_ones)
     assert_expected(output_fixed, fixed_expected, atol=1e-04, rtol=1e-05)
     assert output_fixed.dtype == torch.float32
+
+
+def test_simple_rmsnorm():
+    dims = 12
+    srms_norm = SimpleRMSNorm(dims)
+
+    input_bf16_ones = torch.ones(dims, dtype=torch.bfloat16)
+
+    input_fixed_fp32 = torch.tensor(
+        [
+            0.999,
+            1.1111,
+            2.222,
+            3.333,
+            4.444,
+            5.555,
+            6.678,
+            7.987,
+            8.123,
+            9.101010,
+            110.00,
+            120.2589,
+        ],
+        dtype=torch.float32,
+    )
+
+    expected_output_bf16_ones = torch.tensor(
+        [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+        dtype=torch.bfloat16,
+    )
+    expected_output_fixed = torch.tensor(
+        [
+            0.0211,
+            0.0235,
+            0.0469,
+            0.0704,
+            0.0939,
+            0.1174,
+            0.1411,
+            0.1687,
+            0.1716,
+            0.1923,
+            2.3238,
+            2.5405,
+        ],
+        dtype=torch.float32,
+    )
+
+    actual_output_bf16_ones = srms_norm(input_bf16_ones)
+    actual_output_fixed = srms_norm(input_fixed_fp32)
+
+    # verify ones output and dtype
+    assert_expected(
+        actual_output_bf16_ones, expected_output_bf16_ones, atol=1e-04, rtol=1e-05
+    )
+    assert actual_output_bf16_ones.dtype == torch.bfloat16
+
+    # verify fixed output and dtype
+    assert_expected(actual_output_fixed, expected_output_fixed, atol=1e-04, rtol=1e-05)
+    assert actual_output_fixed.dtype == torch.float32
diff --git a/torchmultimodal/modules/layers/normalizations.py b/torchmultimodal/modules/layers/normalizations.py
index 63623737..9e835f9f 100644
--- a/torchmultimodal/modules/layers/normalizations.py
+++ b/torchmultimodal/modules/layers/normalizations.py
@@ -72,3 +72,25 @@ def _norm(self, x: Tensor) -> Tensor:
     def forward(self, x: Tensor) -> Tensor:
         x_normed = self._norm(x.float()).type_as(x)
         return x_normed * self.scale
+
+
+class SimpleRMSNorm(nn.Module):
+    """Simple RMSNorm
+
+    SRMSNorm(x) = x / (∥x∥2 / √d)
+
+    as proposed in:
+    Scaling TransNormer to 175 Billion Parameters
+    https://arxiv.org/abs/2307.14995
+
+    Usage: use as a drop-in replacement for RMSNorm.
+    """
+
+    def __init__(self, dim: int, eps: float = 1e-12):
+        super().__init__()
+        self.scaling = dim**0.5
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        denom = x.norm(p=2, dim=-1, keepdim=True).clamp_min(self.eps).expand_as(x)
+        return (x / denom) * self.scaling
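
Reviewer note (not part of the patch): below is a minimal usage sketch for trying
the new layer outside the unit tests. It assumes an install that includes this
diff so SimpleRMSNorm is importable from
torchmultimodal.modules.layers.normalizations; the ToyBlock module, shapes, and
dimensions are illustrative assumptions only.

    # Minimal sketch: swap SimpleRMSNorm in as the norm of a toy pre-norm block.
    # ToyBlock and the shapes below are hypothetical; only SimpleRMSNorm is from this diff.
    import torch
    from torch import nn
    from torchmultimodal.modules.layers.normalizations import SimpleRMSNorm


    class ToyBlock(nn.Module):
        def __init__(self, dim: int) -> None:
            super().__init__()
            self.norm = SimpleRMSNorm(dim)  # was nn.LayerNorm(dim) / RMSNorm(dim)
            self.proj = nn.Linear(dim, dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Pre-norm residual: normalize over the last (feature) dim, then project.
            return x + self.proj(self.norm(x))


    x = torch.randn(2, 16, 64)  # (batch, seq_len, dim)
    print(ToyBlock(64)(x).shape)  # torch.Size([2, 16, 64])

Unlike the existing RMSNorm in the same file, SimpleRMSNorm carries no learnable
scale parameter; it only applies the fixed √d rescaling after L2-normalizing the
last dimension.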
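The ~34% eager-mode speedup quoted in the summary is not reproduced by anything
in this patch; a rough sanity check could look like the sketch below. Shapes,
iteration counts, and hardware are arbitrary assumptions, and it presumes the
existing RMSNorm in the same module constructs as RMSNorm(dim) with its default
eps; actual numbers will vary.

    # Rough eager-mode timing comparison; results are hardware- and shape-dependent.
    import time

    import torch
    from torchmultimodal.modules.layers.normalizations import RMSNorm, SimpleRMSNorm


    @torch.no_grad()
    def time_norm(norm: torch.nn.Module, x: torch.Tensor, iters: int = 200) -> float:
        # Warm up, then measure wall-clock time over `iters` forward passes.
        for _ in range(10):
            norm(x)
        if x.is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            norm(x)
        if x.is_cuda:
            torch.cuda.synchronize()
        return time.perf_counter() - start


    dim = 1024
    x = torch.randn(8, 512, dim)  # assumed (batch, seq_len, dim)
    print("RMSNorm      :", time_norm(RMSNorm(dim), x))
    print("SimpleRMSNorm:", time_norm(SimpleRMSNorm(dim), x))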