From 8be7f4895968574d44ae06785fdbe115baa1c13c Mon Sep 17 00:00:00 2001 From: YiSheng5 Date: Fri, 4 Oct 2024 18:01:49 +0800 Subject: [PATCH] fix init issue for rms_norm in squence_parallel (#448) --- megatron/model/fused_rmsnorm.py | 10 ++++++++-- megatron/model/rmsnorm.py | 5 +++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/megatron/model/fused_rmsnorm.py b/megatron/model/fused_rmsnorm.py index a3b9927c87..ab1c42e24e 100644 --- a/megatron/model/fused_rmsnorm.py +++ b/megatron/model/fused_rmsnorm.py @@ -1,13 +1,19 @@ +from megatron import get_args + import torch from torch.nn.parameter import Parameter +from torch.nn import init import intel_extension_for_pytorch as ipex # noqa # Taken from facebookresearch/llama class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): + def __init__(self, dim: int, eps: float = 1e-6, sequence_parallel=False): super().__init__() self.eps = eps - self.weight = Parameter(torch.ones(dim)) + self.weight = Parameter(torch.ones(dim, + dtype=get_args().params_dtype)) + self.sequence_parallel = sequence_parallel + setattr(self.weight, 'sequence_parallel', self.sequence_parallel) def forward(self, x): output = torch.xpu.IpexRmsNorm(x, self.weight.shape, self.weight, self.eps) diff --git a/megatron/model/rmsnorm.py b/megatron/model/rmsnorm.py index 4860d81716..7bcaec37ef 100644 --- a/megatron/model/rmsnorm.py +++ b/megatron/model/rmsnorm.py @@ -9,7 +9,7 @@ # Taken from facebookresearch/llama class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): + def __init__(self, dim: int, eps: float = 1e-6, sequence_parallel=False): super().__init__() self.eps = eps init_device = None @@ -19,7 +19,8 @@ def __init__(self, dim: int, eps: float = 1e-6): device=init_device, dtype=get_args().params_dtype)) init.ones_(self.weight) - setattr(self.weight, 'sequence_parallel', sequence_parallel) + self.sequence_parallel = sequence_parallel + setattr(self.weight, 'sequence_parallel', self.sequence_parallel) def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)