From 013299abd8a80d03a7fb9dd930d2d471912827d0 Mon Sep 17 00:00:00 2001
From: Konrad Zawora
Date: Tue, 13 Aug 2024 12:38:35 +0300
Subject: [PATCH] fix layernorm once for all

---
 vllm/hpu/ops.py                         | 10 ++++++++++
 vllm/model_executor/layers/layernorm.py | 11 +----------
 2 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py
index 8ae292b5413aa..7a40e6e720259 100644
--- a/vllm/hpu/ops.py
+++ b/vllm/hpu/ops.py
@@ -12,6 +12,16 @@
 import torch.nn.functional as F

 import vllm.hpu.utils as hpu_utils
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+HPUFusedRMSNorm = None
+try:
+    from habana_frameworks.torch.hpex.normalization import FusedRMSNorm
+    HPUFusedRMSNorm = FusedRMSNorm
+except ImportError:
+    logger.warning("Could not import HPU FusedRMSNorm kernel. "
+                   "vLLM will use forward_native implementation of RMSNorm.")

 PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', '1') == '1')

diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index e00cb9ca6e1ac..55cbbabd7da44 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -6,18 +6,8 @@

 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
-from vllm.utils import is_hpu

 logger = init_logger(__name__)
-if is_hpu():
-    try:
-        from habana_frameworks.torch.hpex.normalization import (
-            FusedRMSNorm as HPUFusedRMSNorm)
-    except ImportError:
-        logger.warning(
-            "Could not import HPU FusedRMSNorm kernel. "
-            "vLLM will use forward_native implementation of RMSNorm.")
-        HPUFusedRMSNorm = None


 class RMSNorm(CustomOp):
@@ -85,6 +75,7 @@ def forward_hpu(
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        from vllm.hpu.ops import HPUFusedRMSNorm
         if HPUFusedRMSNorm is None:
             return self.forward_native(x, residual)
         if residual is not None:
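
For reference, the sketch below illustrates the guarded-import-with-fallback pattern that this patch centralizes in vllm/hpu/ops.py: the fused kernel is imported once at module level behind a None sentinel, and callers fall back to a native RMSNorm when the Habana kernel is unavailable. Only HPUFusedRMSNorm, FusedRMSNorm, and the RMSNorm math are taken from the patch; the rms_norm helper and the demo at the bottom are illustrative assumptions, not vLLM code.

    import torch

    # Sentinel stays None when the Habana fused kernel cannot be imported,
    # mirroring the module-level guard added to vllm/hpu/ops.py.
    HPUFusedRMSNorm = None
    try:
        from habana_frameworks.torch.hpex.normalization import FusedRMSNorm
        HPUFusedRMSNorm = FusedRMSNorm
    except ImportError:
        pass  # fall back to the native implementation below


    def rms_norm(x: torch.Tensor, weight: torch.Tensor,
                 eps: float = 1e-6) -> torch.Tensor:
        """Apply RMSNorm, preferring the fused HPU kernel when available.

        Hypothetical helper for illustration; vLLM does this inside
        RMSNorm.forward_hpu / forward_native.
        """
        if HPUFusedRMSNorm is not None:
            # The fused kernel is invoked via .apply(input, weight, eps),
            # as in RMSNorm.forward_hpu.
            return HPUFusedRMSNorm.apply(x, weight, eps)
        # Simplified native fallback (same math as forward_native).
        variance = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(variance + eps) * weight


    if __name__ == "__main__":
        # On a machine without habana_frameworks this exercises the fallback.
        x = torch.randn(2, 8)
        w = torch.ones(8)
        print(rms_norm(x, w).shape)  # torch.Size([2, 8])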