diff --git a/.github/workflows/clang-format.yml b/.github/workflows/clang-format.yml
index e9b6e28fa6bcb..9d40813a98d7a 100644
--- a/.github/workflows/clang-format.yml
+++ b/.github/workflows/clang-format.yml
@@ -2,13 +2,13 @@ name: clang-format
 
 on:
   # Trigger the workflow on push or pull request,
-  # but only for the main branch
+  # but only for the habana_main branch
   push:
     branches:
-      - main
+      - habana_main
   pull_request:
     branches:
-      - main
+      - habana_main
 
 jobs:
   clang-format:
diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml
index 5780f09a646cb..c2674b914f485 100644
--- a/.github/workflows/mypy.yaml
+++ b/.github/workflows/mypy.yaml
@@ -2,13 +2,13 @@ name: mypy
 
 on:
   # Trigger the workflow on push or pull request,
-  # but only for the main branch
+  # but only for the habana_main branch
   push:
     branches:
-      - main
+      - habana_main
   pull_request:
     branches:
-      - main
+      - habana_main
 
 jobs:
   ruff:
@@ -50,4 +50,6 @@ jobs:
         mypy vllm/transformers_utils --config-file pyproject.toml
         mypy vllm/usage --config-file pyproject.toml
         mypy vllm/worker --config-file pyproject.toml
+        mypy vllm/hpu --config-file pyproject.toml
+
 
diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml
index 773def58fd966..a2b7aa2549af9 100644
--- a/.github/workflows/ruff.yml
+++ b/.github/workflows/ruff.yml
@@ -2,13 +2,13 @@ name: ruff
 
 on:
   # Trigger the workflow on push or pull request,
-  # but only for the main branch
+  # but only for the habana_main branch
   push:
     branches:
-      - main
+      - habana_main
   pull_request:
     branches:
-      - main
+      - habana_main
 
 jobs:
   ruff:
diff --git a/.github/workflows/yapf.yml b/.github/workflows/yapf.yml
index 04f307bcf8b0e..4e0d67c5b59d6 100644
--- a/.github/workflows/yapf.yml
+++ b/.github/workflows/yapf.yml
@@ -2,13 +2,13 @@ name: yapf
 
 on:
   # Trigger the workflow on push or pull request,
-  # but only for the main branch
+  # but only for the habana_main branch
   push:
     branches:
-      - main
+      - habana_main
   pull_request:
     branches:
-      - main
+      - habana_main
 jobs:
   yapf:
     runs-on: ubuntu-latest
diff --git a/format.sh b/format.sh
index 5ad6d6f2938bb..fbfc27a68bb3d 100755
--- a/format.sh
+++ b/format.sh
@@ -113,6 +113,7 @@ mypy vllm/spec_decode --config-file pyproject.toml
 mypy vllm/transformers_utils --config-file pyproject.toml
 mypy vllm/usage --config-file pyproject.toml
 mypy vllm/worker --config-file pyproject.toml
+mypy vllm/hpu --config-file pyproject.toml
 
 
 # If git diff returns a file that is in the skip list, the file may be checked anyway:
diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py
index 3748eb3544dd1..7a40e6e720259 100644
--- a/vllm/hpu/ops.py
+++ b/vllm/hpu/ops.py
@@ -12,6 +12,16 @@
 import torch.nn.functional as F
 
 import vllm.hpu.utils as hpu_utils
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+HPUFusedRMSNorm = None
+try:
+    from habana_frameworks.torch.hpex.normalization import FusedRMSNorm
+    HPUFusedRMSNorm = FusedRMSNorm
+except ImportError:
+    logger.warning("Could not import HPU FusedRMSNorm kernel. "
+                   "vLLM will use forward_native implementation of RMSNorm.")
 
 PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', '1') == '1')
 
@@ -52,8 +62,7 @@ def paged_attention_v1(query,
         keys = [k.unflatten(1, (kv_heads, 1)) for k in keys]
         mask = mask.unsqueeze(2)
 
-    attn_weights = [torch.matmul(query, k) for k in keys]
-    attn_weights = torch.cat(attn_weights, dim=-1)
+    attn_weights = torch.cat([torch.matmul(query, k) for k in keys], dim=-1)
     if alibi_slopes is not None:
         attn_weights.add_(alibi_slopes[:, :, -attn_weights.size(2):,
                                        -attn_weights.size(3):])
@@ -128,7 +137,8 @@ def prompt_attention(
         query = query.unflatten(1, (kv_heads, -1))
         key = key.unflatten(1, (kv_heads, 1))
         value = value.unflatten(1, (kv_heads, 1))
-        attn_bias = attn_bias.unsqueeze(2)
+        if attn_bias is not None:
+            attn_bias = attn_bias.unsqueeze(2)
     attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
     if attn_bias is not None:
         attn_weights.add_(attn_bias)
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 01429d2fcbd17..55cbbabd7da44 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -6,19 +6,8 @@
 
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
-from vllm.utils import is_hpu
 
 logger = init_logger(__name__)
-if is_hpu():
-    try:
-        from habana_frameworks.torch.hpex.normalization import (FusedRMSNorm as
-                                                                HPUFusedRMSNorm
-                                                                )
-    except ImportError:
-        logger.warning(
-            "Could not import HPU FusedRMSNorm kernel. "
-            "vLLM will use forward_native implementation of RMSNorm.")
-        HPUFusedRMSNorm = None
 
 
 class RMSNorm(CustomOp):
@@ -86,6 +75,7 @@ def forward_hpu(
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        from vllm.hpu.ops import HPUFusedRMSNorm
         if HPUFusedRMSNorm is None:
             return self.forward_native(x, residual)
         if residual is not None: