[Refactor] Rename HabanaAttention -> HPUAttention (#362)

I've missed the attention backend in #359.

kzawora-intel authored Oct 7, 2024
1 parent 1f6de5d commit ad08dd4
Showing 3 changed files with 27 additions and 29 deletions.

vllm/attention/backends/{habana_attn.py → hpu_attn.py}

@@ -13,22 +13,22 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
-from vllm.attention.ops.habana_paged_attn import (HabanaPagedAttention,
-                                                  HabanaPagedAttentionMetadata)
+from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
+                                               HPUPagedAttentionMetadata)
 from vllm.logger import init_logger

 logger = init_logger(__name__)


-class HabanaAttentionBackend(AttentionBackend):
+class HPUAttentionBackend(AttentionBackend):

     @staticmethod
-    def get_impl_cls() -> Type["HabanaAttentionImpl"]:
-        return HabanaAttentionImpl
+    def get_impl_cls() -> Type["HPUAttentionImpl"]:
+        return HPUAttentionImpl

     @staticmethod
     def get_metadata_cls() -> Type["AttentionMetadata"]:
-        return HabanaAttentionMetadata
+        return HPUAttentionMetadata

     @staticmethod
     def get_state_cls() -> Type["CommonAttentionState"]:
@@ -41,37 +41,36 @@ def get_kv_cache_shape(
         num_kv_heads: int,
         head_size: int,
     ) -> Tuple[int, ...]:
-        return HabanaPagedAttention.get_kv_cache_shape(num_blocks, block_size,
-                                                       num_kv_heads, head_size)
+        return HPUPagedAttention.get_kv_cache_shape(num_blocks, block_size,
+                                                    num_kv_heads, head_size)

     @staticmethod
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
         src_to_dst: Dict[int, int],
     ) -> None:
-        HabanaPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache,
-                                         src_to_dst)
+        HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
         src_to_dists: Dict[int, List[int]],
     ) -> None:
-        HabanaPagedAttention.copy_blocks(kv_caches, src_to_dists)
+        HPUPagedAttention.copy_blocks(kv_caches, src_to_dists)


 @dataclass
-class HabanaAttentionMetadata(HabanaPagedAttentionMetadata, AttentionMetadata):
-    """Metadata for HabanaAttentionbackend."""
+class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
+    """Metadata for HPUAttentionbackend."""
     # Currently, input sequences can only contain all prompts
     # or all decoding. True if all sequences are prompts.
     is_prompt: bool
     attn_bias: Optional[torch.Tensor]
     seq_lens_tensor: Optional[torch.Tensor]


-class HabanaAttentionImpl(AttentionImpl, torch.nn.Module):
+class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
     """
     If the input tensors contain prompt tokens, the layout is as follows:
     |<--------------- num_prefill_tokens ----------------->|
@@ -126,7 +125,7 @@ def __init__(
         assert alibi_slopes is None, \
             'Prefill with FusedSDPA not supported with alibi slopes!'

-        suppored_head_sizes = HabanaPagedAttention.get_supported_head_sizes()
+        suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
         if head_size not in suppored_head_sizes:
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
@@ -138,7 +137,7 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
-        attn_metadata: HabanaAttentionMetadata,
+        attn_metadata: HPUAttentionMetadata,
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
@@ -158,7 +157,7 @@ def forward(
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
-                                     "HabanaAttentionImpl")
+                                     "HPUAttentionImpl")
        batch_size, seq_len, hidden_size = query.shape
        _, seq_len_kv, _ = key.shape

@@ -171,7 +170,7 @@ def forward(
        key = key.unflatten(0, (block_indices.size(0), -1))
        value = value.unflatten(0, (block_indices.size(0), -1))
        if kv_cache is not None:
-            key_cache, value_cache = HabanaPagedAttention.split_kv_cache(
+            key_cache, value_cache = HPUPagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
@@ -216,7 +215,7 @@ def forward(
            output = out.reshape(batch_size, seq_len, hidden_size)
        else:
            # Decoding run.
-            output = HabanaPagedAttention.forward_decode(
+            output = HPUPagedAttention.forward_decode(
                query=query,
                key_cache=key_cache,
                value_cache=value_cache,
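
For orientation, the sketch below shows how the renamed pieces fit together from a caller's point of view, using only names visible in this diff. It is a hypothetical illustration rather than code from the commit: it assumes an environment where this vLLM fork and its HPU dependencies are installed, and the shape arguments are arbitrary example values.

# Hypothetical usage sketch (not part of the commit); assumes this vLLM fork
# and its HPU dependencies are installed.
from vllm.attention.backends.hpu_attn import HPUAttentionBackend

impl_cls = HPUAttentionBackend.get_impl_cls()          # HPUAttentionImpl
metadata_cls = HPUAttentionBackend.get_metadata_cls()  # HPUAttentionMetadata

# The KV-cache shape is delegated to HPUPagedAttention.get_kv_cache_shape();
# the arguments below (num_blocks, block_size, num_kv_heads, head_size) are
# arbitrary example values.
kv_cache_shape = HPUAttentionBackend.get_kv_cache_shape(128, 16, 8, 128)
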

vllm/attention/ops/{habana_paged_attn.py → hpu_paged_attn.py}

@@ -13,7 +13,7 @@


 @dataclass
-class HabanaPagedAttentionMetadata:
+class HPUPagedAttentionMetadata:
     """Metadata for PagedAttention."""
     block_list: Optional[torch.Tensor]
     block_mapping: Optional[torch.Tensor]
@@ -22,7 +22,7 @@ class HabanaPagedAttentionMetadata:
     block_offsets: Optional[torch.Tensor]


-class HabanaPagedAttention:
+class HPUPagedAttention:

     @staticmethod
     def get_supported_head_sizes() -> List[int]:
@@ -76,7 +76,7 @@ def forward_prefix(
         sliding_window: Optional[int],
     ) -> torch.Tensor:
         raise NotImplementedError(
-            "forward_prefix is not implemented for HabanaPagedAttention")
+            "forward_prefix is not implemented for HPUPagedAttention")

     @staticmethod
     def swap_blocks(
13 changes: 6 additions & 7 deletions vllm/attention/selector.py
@@ -22,7 +22,7 @@ class _Backend(enum.Enum):
     TORCH_SDPA = enum.auto()
     OPENVINO = enum.auto()
     FLASHINFER = enum.auto()
-    HABANA_ATTN = enum.auto()
+    HPU_ATTN = enum.auto()
     PALLAS = enum.auto()
     IPEX = enum.auto()

@@ -143,11 +143,10 @@ def get_attn_backend(
         logger.info("Using Flashinfer backend.")
         from vllm.attention.backends.flashinfer import FlashInferBackend
         return FlashInferBackend
-    elif backend == _Backend.HABANA_ATTN:
-        logger.info("Using HabanaAttention backend.")
-        from vllm.attention.backends.habana_attn import (  # noqa: F401
-            HabanaAttentionBackend)
-        return HabanaAttentionBackend
+    elif backend == _Backend.HPU_ATTN:
+        logger.info("Using HPUAttention backend.")
+        from vllm.attention.backends.hpu_attn import HPUAttentionBackend
+        return HPUAttentionBackend
     elif backend == _Backend.PALLAS:
         logger.info("Using Pallas backend.")
         from vllm.attention.backends.pallas import PallasAttentionBackend
@@ -217,7 +216,7 @@ def which_attn_to_use(
         return _Backend.ROCM_FLASH

     if current_platform.is_hpu():
-        return _Backend.HABANA_ATTN
+        return _Backend.HPU_ATTN

     # FlashAttn in NVIDIA GPUs.
     if selected_backend == _Backend.FLASH_ATTN:
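
Taken together, the selector changes route HPU runs to the renamed backend: which_attn_to_use() returns _Backend.HPU_ATTN whenever current_platform.is_hpu() is true, and get_attn_backend() resolves that enum member by importing HPUAttentionBackend from vllm.attention.backends.hpu_attn. The condensed sketch below mirrors only that path; the helper names and signatures are simplified stand-ins rather than the real vLLM functions, and every non-HPU branch is omitted.

# Condensed, hypothetical sketch of the HPU dispatch path after this commit.
# pick_backend() and resolve_backend() are simplified stand-ins for
# which_attn_to_use() and get_attn_backend(); the real functions take
# additional parameters.
import enum


class _BackendSketch(enum.Enum):
    HPU_ATTN = enum.auto()  # other members elided in this sketch


def pick_backend(is_hpu: bool) -> _BackendSketch:
    # Mirrors the current_platform.is_hpu() branch of which_attn_to_use().
    if is_hpu:
        return _BackendSketch.HPU_ATTN
    raise NotImplementedError("non-HPU platforms are elided in this sketch")


def resolve_backend(backend: _BackendSketch):
    # Mirrors the _Backend.HPU_ATTN branch of get_attn_backend().
    if backend == _BackendSketch.HPU_ATTN:
        # New import path introduced by this commit
        # (previously vllm.attention.backends.habana_attn).
        from vllm.attention.backends.hpu_attn import HPUAttentionBackend
        return HPUAttentionBackend
    raise NotImplementedError("other backends are elided in this sketch")
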
