Skip to content

Commit

Permalink
fix init issue for rms_norm in squence_parallel (#448)
Browse files Browse the repository at this point in the history
  • Loading branch information
ys950902 authored Oct 4, 2024
1 parent 598c092 commit 8be7f48
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
10 changes: 8 additions & 2 deletions megatron/model/fused_rmsnorm.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
from megatron import get_args

import torch
from torch.nn.parameter import Parameter
from torch.nn import init
import intel_extension_for_pytorch as ipex # noqa

# Taken from facebookresearch/llama
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
def __init__(self, dim: int, eps: float = 1e-6, sequence_parallel=False):
super().__init__()
self.eps = eps
self.weight = Parameter(torch.ones(dim))
self.weight = Parameter(torch.ones(dim,
dtype=get_args().params_dtype))
self.sequence_parallel = sequence_parallel
setattr(self.weight, 'sequence_parallel', self.sequence_parallel)

def forward(self, x):
output = torch.xpu.IpexRmsNorm(x, self.weight.shape, self.weight, self.eps)
Expand Down
5 changes: 3 additions & 2 deletions megatron/model/rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

# Taken from facebookresearch/llama
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
def __init__(self, dim: int, eps: float = 1e-6, sequence_parallel=False):
super().__init__()
self.eps = eps
init_device = None
Expand All @@ -19,7 +19,8 @@ def __init__(self, dim: int, eps: float = 1e-6):
device=init_device,
dtype=get_args().params_dtype))
init.ones_(self.weight)
setattr(self.weight, 'sequence_parallel', sequence_parallel)
self.sequence_parallel = sequence_parallel
setattr(self.weight, 'sequence_parallel', self.sequence_parallel)

def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
Expand Down

0 comments on commit 8be7f48

Please sign in to comment.