diff --git a/megatron/model/__init__.py b/megatron/model/__init__.py
index 2306749fcb..00c7322331 100644
--- a/megatron/model/__init__.py
+++ b/megatron/model/__init__.py
@@ -1,11 +1,16 @@
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 
+import torch
+import intel_extension_for_pytorch
 from deepspeed.accelerator.real_accelerator import get_accelerator
 if get_accelerator().device_name() == 'cuda':
     from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
     from apex.normalization import MixedFusedRMSNorm as RMSNorm
 else:
-    from .rmsnorm import RMSNorm
+    if hasattr(torch.xpu, "IpexRmsNorm"):
+        from .fused_rmsnorm import RMSNorm
+    else:
+        from .rmsnorm import RMSNorm
     from torch.nn import LayerNorm
 
 from .distributed import DistributedDataParallel
diff --git a/megatron/model/fused_rmsnorm.py b/megatron/model/fused_rmsnorm.py
new file mode 100644
index 0000000000..a3b9927c87
--- /dev/null
+++ b/megatron/model/fused_rmsnorm.py
@@ -0,0 +1,14 @@
+import torch
+from torch.nn.parameter import Parameter
+import intel_extension_for_pytorch as ipex  # noqa
+
+# Taken from facebookresearch/llama
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = Parameter(torch.ones(dim))
+
+    def forward(self, x):
+        output = torch.xpu.IpexRmsNorm(x, self.weight.shape, self.weight, self.eps)
+        return output
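Note: the fused torch.xpu.IpexRmsNorm call above is expected to compute the same normalization as the facebookresearch/llama RMSNorm formulation the comment cites. The sketch below is not part of the patch; it is an assumed unfused baseline (class name ReferenceRMSNorm is ours) that can be used to sanity-check the XPU path on any device:

import torch

class ReferenceRMSNorm(torch.nn.Module):
    """Unfused RMSNorm, following the facebookresearch/llama formulation."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Normalize by the root-mean-square over the last dimension, then scale.
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x * rms) * self.weight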