From cb5c62aab8b353e636fa73cf0b809cfc957a67f5 Mon Sep 17 00:00:00 2001 From: Agata Dobrzyniewicz Date: Wed, 14 Aug 2024 09:56:58 +0300 Subject: [PATCH] Github actions pass for tflops --- vllm/hpu/ops.py | 60 +++++++++++++++++------------------------ vllm/worker/profiler.py | 2 +- 2 files changed, 26 insertions(+), 36 deletions(-) diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index d26b33c8e1831..01dd5bf0155e2 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -92,16 +92,16 @@ def paged_attention_v1(query, end_time = time.time() flops = flops_counter_decode(num_att_heads=query.shape[1], - batch_size=batch_size, - query_seq_len=query.shape[2], - max_seq_len=key_cache.shape[2], - block_size=block_size, - query_embedding_dim=query.shape[3], - value_embedding_dim=key_cache.shape[3], - duration=end_time - start_time) - habana_profiler.record_counter(habana_profiler.get_timestamp_us(), + batch_size=batch_size, + query_seq_len=query.shape[2], + max_seq_len=key_cache.shape[2], + block_size=block_size, + query_embedding_dim=query.shape[3], + value_embedding_dim=key_cache.shape[3], + duration=end_time - start_time) + habana_profiler.record_counter(habana_profiler.get_timestamp_us(), {"PA TFLOPS": flops / 1e12}) - + return attn_weights.squeeze(-2) @@ -173,41 +173,31 @@ def prompt_attention( attn_weights = attn_weights.flatten(1, 2) attn_weights = attn_weights.transpose(1, 2) htorch.core.mark_step() - + end_time = time.time() flops = flops_counter_prompt(num_att_heads=query.shape[1], - batch_size=query.shape[0], - query_seq_len=query.shape[2], - max_seq_len=key.shape[2], - query_embedding_dim=query.shape[3], - value_embedding_dim=key.shape[3], - duration=end_time - start_time) - habana_profiler.record_counter(habana_profiler.get_timestamp_us(), + batch_size=query.shape[0], + query_seq_len=query.shape[2], + max_seq_len=key.shape[2], + query_embedding_dim=query.shape[3], + value_embedding_dim=key.shape[3], + duration=end_time - start_time) + habana_profiler.record_counter(habana_profiler.get_timestamp_us(), {"Prompt TFLOPS": flops / 1e12}) return attn_weights -def flops_counter_decode(num_att_heads, - batch_size, - query_seq_len, - max_seq_len, - block_size, - query_embedding_dim, - value_embedding_dim, - duration) -> float: +def flops_counter_decode(num_att_heads, batch_size, query_seq_len, max_seq_len, + block_size, query_embedding_dim, value_embedding_dim, + duration) -> float: return (batch_size * num_att_heads * query_seq_len * ceil(max_seq_len / block_size) * block_size * 2 * (query_embedding_dim + value_embedding_dim) / duration) -def flops_counter_prompt(num_att_heads, - batch_size, - query_seq_len, - max_seq_len, - query_embedding_dim, - value_embedding_dim, - duration) -> float: - return (batch_size * num_att_heads * query_seq_len * - max_seq_len * 2 * (query_embedding_dim + - value_embedding_dim) / duration) +def flops_counter_prompt(num_att_heads, batch_size, query_seq_len, max_seq_len, + query_embedding_dim, value_embedding_dim, + duration) -> float: + return (batch_size * num_att_heads * query_seq_len * max_seq_len * 2 * + (query_embedding_dim + value_embedding_dim) / duration) diff --git a/vllm/worker/profiler.py b/vllm/worker/profiler.py index 3e1a695d79105..fdb80312d2ac4 100644 --- a/vllm/worker/profiler.py +++ b/vllm/worker/profiler.py @@ -54,7 +54,7 @@ def getinstance(*args, **kwargs): if class_ not in instances: instances[class_] = class_(*args, **kwargs) return instances[class_] - + return getinstance