add fused_rms_norm support on XPU device (#431)
ys950902 authored Aug 4, 2024
1 parent 1bfc35c commit 53b241f
Showing 2 changed files with 20 additions and 1 deletion.
7 changes: 6 additions & 1 deletion megatron/model/__init__.py
@@ -1,11 +1,16 @@
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 
+import torch
+import intel_extension_for_pytorch
 from deepspeed.accelerator.real_accelerator import get_accelerator
 if get_accelerator().device_name() == 'cuda':
     from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
     from apex.normalization import MixedFusedRMSNorm as RMSNorm
 else:
-    from .rmsnorm import RMSNorm
+    if hasattr(torch.xpu, "IpexRmsNorm"):
+        from .fused_rmsnorm import RMSNorm
+    else:
+        from .rmsnorm import RMSNorm
     from torch.nn import LayerNorm
 
 from .distributed import DistributedDataParallel
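With this change, megatron.model re-exports a fused RMSNorm on XPU when IPEX provides torch.xpu.IpexRmsNorm, falls back to the pure-Python rmsnorm module otherwise, and keeps the Apex path on CUDA. Below is a minimal usage sketch of the re-exported symbol; it is not part of the commit, and the hidden size and tensor shapes are illustrative.

    # Illustrative sketch only (not from the commit): consuming the RMSNorm that
    # megatron/model/__init__.py selected for the active accelerator.
    import torch
    from deepspeed.accelerator.real_accelerator import get_accelerator
    from megatron.model import RMSNorm

    device = get_accelerator().current_device_name()  # e.g. 'xpu:0' or 'cuda:0'
    norm = RMSNorm(4096).to(device)                   # 4096 = illustrative hidden size
    hidden = torch.randn(2, 8, 4096, device=device)
    out = norm(hidden)                                # output shape matches the input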
14 changes: 14 additions & 0 deletions megatron/model/fused_rmsnorm.py
@@ -0,0 +1,14 @@
+import torch
+from torch.nn.parameter import Parameter
+import intel_extension_for_pytorch as ipex  # noqa
+
+# Taken from facebookresearch/llama
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = Parameter(torch.ones(dim))
+
+    def forward(self, x):
+        output = torch.xpu.IpexRmsNorm(x, self.weight.shape, self.weight, self.eps)
+        return output
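For reference, the sketch below restates the standard un-fused RMSNorm math (as in facebookresearch/llama) that the fused torch.xpu.IpexRmsNorm kernel is expected to match; it is not part of the commit and is only useful for sanity-checking the fused XPU path against an eager-mode baseline.

    # Reference-only sketch (not from the commit): eager-mode RMSNorm math.
    import torch

    def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Scale by the reciprocal root-mean-square over the last dimension,
        # then apply the learned elementwise weight.
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(variance + eps) * weight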
