From 1c6bada23884043cdd2a5715bce405bf2bb000f0 Mon Sep 17 00:00:00 2001
From: Konrad Zawora
Date: Thu, 26 Sep 2024 14:53:29 +0200
Subject: [PATCH] Chunk prefill cache writes, remove div_i32 from insert_or_update_cache (#289)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Re-implements the following PRs for the current habana_main:
https://github.com/HabanaAI/vllm-fork/pull/102 (removing div_i32 operations from each layer)
https://github.com/HabanaAI/vllm-fork/pull/115 (removing scatter for reshape&cache in the prompt case)

Accuracy (GSM8K on Llama3.1-8B-Instruct):

|     Tasks     |Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|---------------|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k_cot_llama|      3|flexible-extract|     8|exact_match|↑  |0.8415|±  |0.0101|
|               |       |strict-match    |     8|exact_match|↑  |0.8400|±  |0.0101|

I've benchmarked this change on Llama3.1-8B-Instruct: on average, a +2.50%
throughput gain (+558.14 tok/s, ~21594 tok/s -> ~22152 tok/s) is observed
across all prefill buckets on G2, with up to a +4.40% throughput increase
(+956.79 tok/s, ~25031 tok/s -> ~25988 tok/s) in compute-bound scenarios.
---
 requirements-hpu.txt                    |  3 +--
 vllm/attention/backends/habana_attn.py  | 17 +++++++++--------
 vllm/attention/ops/habana_paged_attn.py |  2 ++
 vllm/worker/habana_model_runner.py      | 22 +++++++++++++++++++++-
 4 files changed, 33 insertions(+), 11 deletions(-)

diff --git a/requirements-hpu.txt b/requirements-hpu.txt
index 1af5460128fbb..33619dc4883d5 100644
--- a/requirements-hpu.txt
+++ b/requirements-hpu.txt
@@ -6,5 +6,4 @@ ray == 2.32.0
 triton
 pandas
 tabulate
-
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@0a7adab
\ No newline at end of file
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@940fdb7
diff --git a/vllm/attention/backends/habana_attn.py b/vllm/attention/backends/habana_attn.py
index 59a99b89c293f..dad33fefc51f3 100644
--- a/vllm/attention/backends/habana_attn.py
+++ b/vllm/attention/backends/habana_attn.py
@@ -8,7 +8,6 @@
 
 import torch
 import vllm_hpu_extension.ops as ops
-from vllm_hpu_extension import cache_ops
 from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -166,6 +165,11 @@ def forward(
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)
+        block_indices = attn_metadata.block_indices
+        block_offsets = attn_metadata.block_offsets
+        if attn_metadata.is_prompt:
+            key = key.unflatten(0, (block_indices.size(0), -1))
+            value = value.unflatten(0, (block_indices.size(0), -1))
         if kv_cache is not None:
             key_cache, value_cache = HabanaPagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)
@@ -173,13 +177,10 @@ def forward(
             # Reshape the input keys and values and store them in the cache.
             # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
-            num_kv_cache_passes, num_slots_available, indices, offsets = \
-                cache_ops.prepare_to_cache(key_cache,
-                                           attn_metadata.slot_mapping)
-            key_cache = self.k_cache(key, key_cache, num_kv_cache_passes,
-                                     num_slots_available, indices, offsets)
-            value_cache = self.v_cache(value, value_cache, num_kv_cache_passes,
-                                       num_slots_available, indices, offsets)
+            key_cache = self.k_cache(key, key_cache, block_indices,
+                                     block_offsets)
+            value_cache = self.v_cache(value, value_cache, block_indices,
+                                       block_offsets)
 
         if attn_metadata.is_prompt:
             # Prompt run.
diff --git a/vllm/attention/ops/habana_paged_attn.py b/vllm/attention/ops/habana_paged_attn.py
index 49a3e3f774d58..7f080e0727457 100644
--- a/vllm/attention/ops/habana_paged_attn.py
+++ b/vllm/attention/ops/habana_paged_attn.py
@@ -18,6 +18,8 @@ class HabanaPagedAttentionMetadata:
     block_list: Optional[torch.Tensor]
     block_mapping: Optional[torch.Tensor]
     block_usage: Optional[torch.Tensor]
+    block_indices: Optional[torch.Tensor]
+    block_offsets: Optional[torch.Tensor]
 
 
 class HabanaPagedAttention:
diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index f3bda39ec4822..d3d2973688843 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -245,6 +245,17 @@ def pad_list(list, k, v):
     return list + [v] * padding
 
 
+def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt):
+    slot_mapping = slot_mapping.flatten()
+    indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
+    if is_prompt:
+        indices = indices.unflatten(0, (-1, block_size))[:, 0]
+        offsets = None
+    else:
+        offsets = torch.fmod(slot_mapping, block_size)
+    return indices, offsets
+
+
 class HpuModelAdapter():
 
     def __init__(self, model, block_size, dtype, enforce_eager):
@@ -890,11 +901,15 @@ def _prepare_prompt(
                                             dtype=torch.long,
                                             device=self.device)
 
+        block_indices, block_offsets = precompute_indices_and_offsets(
+            self.block_size, slot_mapping, True)
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=True,
             block_list=None,
             block_mapping=None,
             block_usage=None,
+            block_indices=block_indices,
+            block_offsets=block_offsets,
             attn_bias=None,
             seq_lens_tensor=seq_lens_tensor,
             num_prefills=real_num_seqs,
@@ -1044,11 +1059,15 @@ def _prepare_decode(
                                     dtype=torch.long,
                                     device=self.device)
 
+        block_indices, block_offsets = precompute_indices_and_offsets(
+            self.block_size, slot_mapping, False)
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=False,
             block_list=block_list,
             block_mapping=block_mapping,
             block_usage=block_usage,
+            block_indices=block_indices,
+            block_offsets=block_offsets,
             attn_bias=None,
             seq_lens_tensor=None,
             num_prefills=0,
@@ -1266,7 +1285,8 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
         # input_hash("abc") != input_hash("cba")
         attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
             'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping',
-            'block_usage', 'slot_mapping', 'is_prompt'
+            'block_usage', 'slot_mapping', 'is_prompt', 'block_indices',
+            'block_offsets'
         ])
         return attention_metadata
 
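
Below is a minimal standalone sketch of what the new
precompute_indices_and_offsets helper returns on the prompt and decode paths.
The helper body is taken from the diff above; the block_size of 4 and the toy
slot_mapping tensors are illustrative assumptions, not values taken from vLLM.

# Sketch of the per-batch index/offset precomputation applied to toy inputs.
# Assumes the flattened prompt slot_mapping length is a multiple of
# block_size; the unflatten call relies on this.
import torch


def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt):
    slot_mapping = slot_mapping.flatten()
    indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    if is_prompt:
        # One index per block: the whole block is written at once,
        # so no per-slot offsets (and no scatter) are needed.
        indices = indices.unflatten(0, (-1, block_size))[:, 0]
        offsets = None
    else:
        # Decodes write a single slot per sequence, so keep per-slot
        # offsets within each block.
        offsets = torch.fmod(slot_mapping, block_size)
    return indices, offsets


if __name__ == "__main__":
    block_size = 4  # assumed for this example
    # Prompt: one sequence occupying slots 8..15, i.e. blocks 2 and 3.
    prompt_slots = torch.tensor([[8, 9, 10, 11, 12, 13, 14, 15]])
    print(precompute_indices_and_offsets(block_size, prompt_slots, True))
    # -> (tensor([2, 3]), None)

    # Decode: two sequences writing one token each, at slots 5 and 9.
    decode_slots = torch.tensor([[5], [9]])
    print(precompute_indices_and_offsets(block_size, decode_slots, False))
    # -> (tensor([1, 2]), tensor([1, 1]))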