fix layernorm once for all
kzawora-intel committed Aug 13, 2024
1 parent e6c4086 commit 013299a
Showing 2 changed files with 11 additions and 10 deletions.
vllm/hpu/ops.py (10 additions, 0 deletions)

@@ -12,6 +12,16 @@
 import torch.nn.functional as F
 
 import vllm.hpu.utils as hpu_utils
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+HPUFusedRMSNorm = None
+try:
+    from habana_frameworks.torch.hpex.normalization import FusedRMSNorm
+    HPUFusedRMSNorm = FusedRMSNorm
+except ImportError:
+    logger.warning("Could not import HPU FusedRMSNorm kernel. "
+                   "vLLM will use forward_native implementation of RMSNorm.")
 
 PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', '1') == '1')
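The guard added to vllm/hpu/ops.py resolves the optional fused kernel once at import time and leaves HPUFusedRMSNorm as None when the Habana package is absent, so callers only need a None check. Below is a minimal standalone sketch of the same guarded-import-with-fallback idiom; the package name optional_accel, the plain logging setup, and the fallback math are illustrative, not taken from the commit.

import logging

import torch

logger = logging.getLogger(__name__)

# Resolve the optional fused kernel exactly once, at import time.
# `optional_accel` is a placeholder package name, not a real dependency.
FusedRMSNormKernel = None
try:
    from optional_accel import FusedRMSNorm as FusedRMSNormKernel
except ImportError:
    logger.warning("Optional fused RMSNorm kernel not available; "
                   "falling back to the pure-PyTorch implementation.")


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Callers branch on the module-level symbol instead of re-probing the
    # import on every call, mirroring the `if HPUFusedRMSNorm is None`
    # check in forward_hpu.
    if FusedRMSNormKernel is None:
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(variance + eps) * weight
    return FusedRMSNormKernel.apply(x, weight, eps)  # illustrative call shape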
vllm/model_executor/layers/layernorm.py (1 addition, 10 deletions)

@@ -6,18 +6,8 @@
 
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
-from vllm.utils import is_hpu
 
 logger = init_logger(__name__)
-if is_hpu():
-    try:
-        from habana_frameworks.torch.hpex.normalization import (
-            FusedRMSNorm as HPUFusedRMSNorm)
-    except ImportError:
-        logger.warning(
-            "Could not import HPU FusedRMSNorm kernel. "
-            "vLLM will use forward_native implementation of RMSNorm.")
-        HPUFusedRMSNorm = None
 
 
 class RMSNorm(CustomOp):

@@ -85,6 +75,7 @@ def forward_hpu(
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        from vllm.hpu.ops import HPUFusedRMSNorm
         if HPUFusedRMSNorm is None:
             return self.forward_native(x, residual)
         if residual is not None:
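With this change, layernorm.py imports HPUFusedRMSNorm lazily inside forward_hpu and falls back to forward_native when the symbol is None. For reference, here is a self-contained approximation of what that native fallback computes, including the fused residual-add variant that returns both outputs; the class and variable names are illustrative, not vLLM's exact code.

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn


class SimpleRMSNorm(nn.Module):
    """Reference RMSNorm with an optional fused residual add."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            # Fused residual add: the pre-norm sum is also returned so the
            # next layer can reuse it as its residual input.
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = (x * self.weight.to(torch.float32)).to(orig_dtype)
        return x if residual is None else (x, residual)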
