Skip to content

Commit

Permalink
remove internal ticket #
Browse files Browse the repository at this point in the history
  • Loading branch information
libinta committed Aug 19, 2024
1 parent 32910c3 commit 4b669fc
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions vllm/hpu/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,11 @@ def static_fused_moe(hidden_states, w1, w2, score, topk):
return final_hidden_states.view(-1, D)


#TODO: remove after SW-195415 fix
#TODO: remove after fusedsdpa fix for query_head != kv_head
def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
The kv go from (batch, num_key_value_heads, seqlen, head_dim) to
(batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = kv.shape
Expand Down Expand Up @@ -180,7 +180,7 @@ def prompt_attention(
if query_heads != kv_heads:
attn_weights = attn_weights.flatten(1, 2)
else:
#TODO: remove after SW-195415 fix
#TODO: remove after fusedsdpa fix for query_heads != kv_heads
if query_heads != kv_heads:
key = repeat_kv(key, int(query_heads // kv_heads))
value = repeat_kv(value, int(query_heads // kv_heads))
Expand Down

0 comments on commit 4b669fc

Please sign in to comment.