Enable FusedSDPA for prompt attention with VLLM_PROMPT_USE_FUSEDSDPA (#…
libinta authored Aug 19, 2024
1 parent f7dd91d commit 275e325
Showing 3 changed files with 70 additions and 24 deletions.
29 changes: 20 additions & 9 deletions vllm/attention/backends/habana_attn.py
@@ -2,6 +2,7 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################

import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

@@ -166,6 +167,12 @@ def __init__(
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
'0').lower() in ['1', 'true']
if self.prefill_usefusedsdpa:
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'

suppored_head_sizes = HabanaPagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
raise ValueError(
@@ -223,15 +230,18 @@ def forward(
if attn_metadata.is_prompt:
# Prompt run.
if kv_cache is None or attn_metadata.block_tables.numel() == 0:
if not self.prefill_usefusedsdpa:
    # TODO: move this outside of model
    assert attn_metadata.attn_bias is not None, \
        'attn_bias must be set before calling model.forward!'
    attn_bias = attn_metadata.attn_bias
    if self.alibi_slopes is not None and \
       self.position_bias is not None:
        attn_bias.add_(self.position_bias[:, :,
                                          -attn_bias.size(2):,
                                          -attn_bias.size(3):])
else:
    attn_bias = None

query_shape = (batch_size, seq_len, self.num_heads,
self.head_size)
@@ -247,6 +257,7 @@ def forward(
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
valid_seq_lengths=attn_metadata.seq_lens_tensor,
)
output = out.reshape(batch_size, seq_len, hidden_size)
else:
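For reference, the flag gating the backend change above is a plain environment variable. A minimal standalone sketch (illustrative, not part of the commit) of how it is read and what it switches:

import os

# Opt in before the attention backend is constructed; only "1" or "true"
# (case-insensitive) enable the FusedSDPA prompt path.
os.environ['VLLM_PROMPT_USE_FUSEDSDPA'] = '1'

prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
                                  '0').lower() in ['1', 'true']
print(prefill_use_fusedsdpa)  # True

# With the flag set, the prompt path above passes attn_bias=None and forwards
# valid_seq_lengths=attn_metadata.seq_lens_tensor to prompt_attention, and the
# constructor asserts that ALiBi slopes are not in use.

In practice the variable would be exported in the shell before launching vLLM, since both this backend and HpuModelAdapter in the model runner read it at construction time.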
58 changes: 46 additions & 12 deletions vllm/hpu/ops.py
@@ -21,6 +21,13 @@
except ImportError:
logger.warning("Could not import HPU FusedRMSNorm kernel. "
"vLLM will use forward_native implementation of RMSNorm.")
HPUFusedSDPA = None
try:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
HPUFusedSDPA = FusedSDPA
except ImportError:
logger.warning("Could not import HPU FusedSDPA kernel. "
"vLLM will use native implementation.")

PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', '1') == '1')

@@ -126,6 +133,21 @@ def static_fused_moe(hidden_states, w1, w2, score, topk):
return final_hidden_states.view(-1, D)


#TODO: remove after fusedsdpa fix for query_head != kv_head
def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
The kv go from (batch, num_key_value_heads, seqlen, head_dim) to
(batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = kv.shape
if n_rep == 1:
return kv
kv = kv[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen,
head_dim)
return kv.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def prompt_attention(
query: torch.Tensor,
key: torch.Tensor,
@@ -136,24 +158,36 @@ def prompt_attention(
matmul_qk_op=torch.matmul,
softmax_op=torch.softmax,
matmul_av_op=torch.matmul,
valid_seq_lengths: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
query_heads = query.size(1)
kv_heads = key.size(1)
if attn_bias is not None or HPUFusedSDPA is None:
    if query_heads != kv_heads:
        query = query.unflatten(1, (kv_heads, -1))
        key = key.unflatten(1, (kv_heads, 1))
        value = value.unflatten(1, (kv_heads, 1))
        if attn_bias is not None:
            attn_bias = attn_bias.unsqueeze(2)
    attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2))
    if attn_bias is not None:
        attn_weights.add_(attn_bias)
    attn_weights = softmax_op(attn_weights, dim=-1)
    attn_weights = matmul_av_op(attn_weights, value)
    if query_heads != kv_heads:
        attn_weights = attn_weights.flatten(1, 2)
else:
    #TODO: remove after fusedsdpa fix for query_heads != kv_heads
    if query_heads != kv_heads:
        key = repeat_kv(key, int(query_heads // kv_heads))
        value = repeat_kv(value, int(query_heads // kv_heads))
    softmax_mode = 'fast'
    recompute_mode = True
    attn_weights = FusedSDPA.apply(query, key, value, None, 0.0, True,
                                   scale, softmax_mode, recompute_mode,
                                   valid_seq_lengths, 'right')
attn_weights = attn_weights.transpose(1, 2)
return attn_weights
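The fused kernel does not yet handle grouped-query attention (query_heads != kv_heads), hence the repeat_kv workaround on the FusedSDPA branch. A standalone shape check of that helper (tensor sizes are illustrative):

import torch

def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same logic as the helper added above: expand the KV heads so that
    # query_heads == kv_heads before the fused kernel is called.
    batch, num_key_value_heads, slen, head_dim = kv.shape
    if n_rep == 1:
        return kv
    kv = kv[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen,
                                     head_dim)
    return kv.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

kv = torch.randn(2, 8, 128, 64)   # (batch, kv_heads, seqlen, head_dim)
out = repeat_kv(kv, n_rep=4)      # 8 KV heads expanded to 32 query heads
assert out.shape == (2, 32, 128, 64)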
7 changes: 4 additions & 3 deletions vllm/worker/habana_model_runner.py
@@ -151,6 +151,9 @@ class HpuModelAdapter():

def __init__(self, model, enforce_eager):
self.model = model
self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
'0').lower() in ['1', 'true']

if not htorch.utils.internal.is_lazy() and not enforce_eager:
self.model = torch.compile(self.model,
backend='hpu_backend',
@@ -159,7 +162,7 @@ def __init__(self, model, enforce_eager):
def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
dtype):
prefill_metadata = attn_metadata
if prefill_metadata is None or self.prefill_use_fusedsdpa:
return attn_metadata

seq_lens_t = prefill_metadata.seq_lens_tensor
@@ -599,7 +602,6 @@ def _prepare_prompt(
# actual prompt lens
context_lens.append(context_len)
query_lens.append(seq_len - context_len)

input_tokens.append(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
@@ -672,7 +674,6 @@ def _prepare_prompt(
max_prompt_len = max(
find_bucket(max(seq_lens), self.prompt_seq_bucket_cfg),
self.block_size)

input_tokens = make_tensor_with_pad(input_tokens,
max_len=max_prompt_len,
pad=0,
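The runner-side change mirrors the backend: with the flag set, _set_attn_bias returns the metadata untouched, since FusedSDPA applies causal masking from the valid sequence lengths and no dense bias tensor has to be built. A simplified, hypothetical stand-in for that guard:

def set_attn_bias_sketch(prefill_metadata, prefill_use_fusedsdpa: bool):
    # Hypothetical, simplified stand-in for HpuModelAdapter._set_attn_bias:
    # a dense prompt attention bias is only materialised on the non-fused path.
    if prefill_metadata is None or prefill_use_fusedsdpa:
        return prefill_metadata
    # (the non-fused path builds the per-sequence causal bias here)
    return prefill_metadata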
