Chunk prefill cache writes, remove div_i32 from insert_or_update_cache (#289)
Re-implements the following PRs for the current habana_main:
- #102 (removing div_i32 operations from each layer, sketched below)
- #115 (removing scatter for reshape&cache in the prompt case)

Accuracy (GSM8K on Llama3.1-8B-Instruct):

| Tasks           | Version | Filter           | n-shot | Metric      |   |  Value |   | Stderr |
|-----------------|--------:|------------------|-------:|-------------|---|-------:|---|-------:|
| gsm8k_cot_llama |       3 | flexible-extract |      8 | exact_match | ↑ | 0.8415 | ± | 0.0101 |
|                 |         | strict-match     |      8 | exact_match | ↑ | 0.8400 | ± | 0.0101 |

I've benchmarked this change on Llama3.1-8B-Instruct: on average it yields a +2.50% throughput gain (+558.14 tok/s, ~21594 tok/s -> ~22152 tok/s) across all prefill buckets on G2, with up to a +4.40% throughput increase (+956.79 tok/s, ~25031 -> ~25988 tok/s) in compute-bound scenarios.
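To illustrate what the div_i32 removal from #102 amounts to, here is a minimal sketch with made-up values (not taken from the benchmark): the flat slot mapping is split into block indices and in-block offsets once per batch, using a single floor division and modulo, instead of repeating that division inside every attention layer; the layers then just read the precomputed tensors from the attention metadata.

```python
import torch

# Toy values for illustration: 4 tokens writing into a paged KV cache
# with block_size = 4. slot_mapping holds the flat cache slot per token.
block_size = 4
slot_mapping = torch.tensor([5, 6, 7, 12])

# Done once per batch in the model runner, reused by every layer.
block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_offsets = torch.fmod(slot_mapping, block_size)

print(block_indices.tolist())  # [1, 1, 1, 3] -> cache block each token lands in
print(block_offsets.tolist())  # [1, 2, 3, 0] -> position within that block
```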
kzawora-intel authored Sep 26, 2024
1 parent 4c8a6c6 commit 1c6bada
Showing 4 changed files with 33 additions and 11 deletions.
requirements-hpu.txt (3 changes: 1 addition & 2 deletions)
@@ -6,5 +6,4 @@ ray == 2.32.0
 triton
 pandas
 tabulate
-
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@0a7adab
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@940fdb7
vllm/attention/backends/habana_attn.py (17 changes: 9 additions & 8 deletions)
@@ -8,7 +8,6 @@
 
 import torch
 import vllm_hpu_extension.ops as ops
-from vllm_hpu_extension import cache_ops
 from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -166,20 +165,22 @@ def forward(
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)
+        block_indices = attn_metadata.block_indices
+        block_offsets = attn_metadata.block_offsets
+        if attn_metadata.is_prompt:
+            key = key.unflatten(0, (block_indices.size(0), -1))
+            value = value.unflatten(0, (block_indices.size(0), -1))
         if kv_cache is not None:
             key_cache, value_cache = HabanaPagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)
 
             # Reshape the input keys and values and store them in the cache.
             # If kv_cache is not provided, the new key and value tensors are
             # not cached. This happens during the initial memory profiling run.
-            num_kv_cache_passes, num_slots_available, indices, offsets = \
-                cache_ops.prepare_to_cache(key_cache,
-                                           attn_metadata.slot_mapping)
-            key_cache = self.k_cache(key, key_cache, num_kv_cache_passes,
-                                     num_slots_available, indices, offsets)
-            value_cache = self.v_cache(value, value_cache, num_kv_cache_passes,
-                                       num_slots_available, indices, offsets)
+            key_cache = self.k_cache(key, key_cache, block_indices,
+                                     block_offsets)
+            value_cache = self.v_cache(value, value_cache, block_indices,
+                                       block_offsets)
 
         if attn_metadata.is_prompt:
             # Prompt run.
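As a rough picture of what the prompt path above buys, here is a simplified sketch of a chunked cache write; the actual update is performed by VLLMKVCache in vllm-hpu-extension, so the cache layout and the index_copy_-based write below are illustrative assumptions only. Because the new keys are unflattened into whole blocks, the cache can be updated with a few block-sized copies indexed by block_indices instead of a per-token scatter over slot_mapping.

```python
import torch

# Assumed toy cache layout: (num_blocks, block_size, num_kv_heads, head_size).
num_blocks, block_size, num_kv_heads, head_size = 8, 4, 2, 16
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)

# A prompt that fills blocks 3 and 5: 2 * block_size new key vectors.
block_indices = torch.tensor([3, 5])
key = torch.randn(2 * block_size, num_kv_heads, head_size)

# Same reshape as in forward(): group the flat tokens into whole blocks.
key = key.unflatten(0, (block_indices.size(0), -1))  # (2, block_size, heads, head_size)

# Block-granular write: one indexed copy instead of a token-by-token scatter.
key_cache.index_copy_(0, block_indices, key)
```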
vllm/attention/ops/habana_paged_attn.py (2 changes: 2 additions & 0 deletions)
@@ -18,6 +18,8 @@ class HabanaPagedAttentionMetadata:
     block_list: Optional[torch.Tensor]
     block_mapping: Optional[torch.Tensor]
     block_usage: Optional[torch.Tensor]
+    block_indices: Optional[torch.Tensor]
+    block_offsets: Optional[torch.Tensor]
 
 
 class HabanaPagedAttention:
vllm/worker/habana_model_runner.py (22 changes: 21 additions & 1 deletion)
@@ -245,6 +245,17 @@ def pad_list(list, k, v):
     return list + [v] * padding
 
 
+def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt):
+    slot_mapping = slot_mapping.flatten()
+    indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
+    if is_prompt:
+        indices = indices.unflatten(0, (-1, block_size))[:, 0]
+        offsets = None
+    else:
+        offsets = torch.fmod(slot_mapping, block_size)
+    return indices, offsets
+
+
 class HpuModelAdapter():
 
     def __init__(self, model, block_size, dtype, enforce_eager):
@@ -890,11 +901,15 @@ def _prepare_prompt(
                                     dtype=torch.long,
                                     device=self.device)
 
+        block_indices, block_offsets = precompute_indices_and_offsets(
+            self.block_size, slot_mapping, True)
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=True,
             block_list=None,
             block_mapping=None,
             block_usage=None,
+            block_indices=block_indices,
+            block_offsets=block_offsets,
             attn_bias=None,
             seq_lens_tensor=seq_lens_tensor,
             num_prefills=real_num_seqs,
@@ -1044,11 +1059,15 @@ def _prepare_decode(
                                     dtype=torch.long,
                                     device=self.device)
 
+        block_indices, block_offsets = precompute_indices_and_offsets(
+            self.block_size, slot_mapping, False)
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=False,
             block_list=block_list,
             block_mapping=block_mapping,
             block_usage=block_usage,
+            block_indices=block_indices,
+            block_offsets=block_offsets,
             attn_bias=None,
             seq_lens_tensor=None,
             num_prefills=0,
@@ -1266,7 +1285,8 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
         # input_hash("abc") != input_hash("cba")
         attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
             'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping',
-            'block_usage', 'slot_mapping', 'is_prompt'
+            'block_usage', 'slot_mapping', 'is_prompt', 'block_indices',
+            'block_offsets'
         ])
         return attention_metadata
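A small worked example of the helper above, with made-up inputs (block_size = 4) and assuming precompute_indices_and_offsets from the diff is in scope: for a prompt it keeps one index per written block and returns offsets as None, matching the block-granular writes in the attention backend, while for decode every token keeps its own block index and in-block offset.

```python
import torch

block_size = 4

# Prompt: one sequence of 8 tokens occupying blocks 2 and 3 (slots 8..15).
prompt_slots = torch.tensor([[8, 9, 10, 11, 12, 13, 14, 15]])
indices, offsets = precompute_indices_and_offsets(block_size, prompt_slots, True)
print(indices.tolist(), offsets)  # [2, 3] None -> one entry per written block

# Decode: two sequences each writing a single token, into slots 13 and 7.
decode_slots = torch.tensor([[13], [7]])
indices, offsets = precompute_indices_and_offsets(block_size, decode_slots, False)
print(indices.tolist(), offsets.tolist())  # [3, 1] and [1, 3]
```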
