diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py
index ae936925b5001..b69710e3a0a78 100644
--- a/vllm/hpu/ops.py
+++ b/vllm/hpu/ops.py
@@ -13,7 +13,6 @@
 import torch
 import torch.nn.functional as F

-import vllm.hpu.utils as hpu_utils
 from vllm.worker.profiler import Profiler

 from vllm.logger import init_logger
@@ -93,14 +92,15 @@ def paged_attention_v1(query,

     end_time = time.time()
     flops = flops_counter_decode(num_att_heads=query.shape[1],
-                                     batch_size=batch_size,
-                                     query_seq_len=query.shape[2],
-                                     max_seq_len=key_cache.shape[2],
-                                     block_size=block_size,
-                                     query_embedding_dim=query.shape[3],
-                                     value_embedding_dim=key_cache.shape[3],
-                                     duration=end_time - start_time)
-    habana_profiler.record_counter(habana_profiler.get_timestamp_us(), {"TFLOPS": flops / 1e12})
+                                 batch_size=batch_size,
+                                 query_seq_len=query.shape[2],
+                                 max_seq_len=key_cache.shape[2],
+                                 block_size=block_size,
+                                 query_embedding_dim=query.shape[3],
+                                 value_embedding_dim=key_cache.shape[3],
+                                 duration=end_time - start_time)
+    habana_profiler.record_counter(habana_profiler.get_timestamp_us(),
+                                   {"PA TFLOPS": flops / 1e12})

     return attn_weights.squeeze(-2)

@@ -176,13 +176,14 @@ def prompt_attention(

     end_time = time.time()
     flops = flops_counter_prompt(num_att_heads=query.shape[1],
-                                     batch_size=query.shape[0],
-                                     query_seq_len=query.shape[2],
-                                     max_seq_len=key.shape[2],
-                                     query_embedding_dim=query.shape[3],
-                                     value_embedding_dim=key.shape[3],
-                                     duration=end_time - start_time)
-    habana_profiler.record_counter(habana_profiler.get_timestamp_us(), {"TFLOPS": flops / 1e12})
+                                 batch_size=query.shape[0],
+                                 query_seq_len=query.shape[2],
+                                 max_seq_len=key.shape[2],
+                                 query_embedding_dim=query.shape[3],
+                                 value_embedding_dim=key.shape[3],
+                                 duration=end_time - start_time)
+    habana_profiler.record_counter(habana_profiler.get_timestamp_us(),
+                                   {"Prompt TFLOPS": flops / 1e12})

     return attn_weights

@@ -195,8 +196,9 @@ def flops_counter_decode(num_att_heads,
                          query_embedding_dim,
                          value_embedding_dim,
                          duration) -> float:
-    return (batch_size * num_att_heads * query_seq_len * ceil(max_seq_len / block_size)
-            * block_size * 2 * (query_embedding_dim + value_embedding_dim) / duration)
+    return (batch_size * num_att_heads * query_seq_len *
+            ceil(max_seq_len / block_size) * block_size * 2 *
+            (query_embedding_dim + value_embedding_dim) / duration)


 def flops_counter_prompt(num_att_heads,
@@ -206,5 +208,6 @@ def flops_counter_prompt(num_att_heads,
                          query_embedding_dim,
                          value_embedding_dim,
                          duration) -> float:
-    return (batch_size * num_att_heads * query_seq_len * max_seq_len * 2
-            * (query_embedding_dim + value_embedding_dim) / duration)
+    return (batch_size * num_att_heads * query_seq_len *
+            max_seq_len * 2 * (query_embedding_dim +
+            value_embedding_dim) / duration)
diff --git a/vllm/worker/profiler.py b/vllm/worker/profiler.py
index 63b24351923c4..3e1a695d79105 100644
--- a/vllm/worker/profiler.py
+++ b/vllm/worker/profiler.py
@@ -49,10 +49,12 @@ def run(self):

 def singleton(class_):
     instances = {}
+
     def getinstance(*args, **kwargs):
         if class_ not in instances:
             instances[class_] = class_(*args, **kwargs)
         return instances[class_]
+
     return getinstance

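For reference, below is a minimal standalone sketch of the decode-phase FLOPS estimate that the patch reports under the "PA TFLOPS" counter. The formula is the one from flops_counter_decode above; the tensor shapes and the 5 ms duration are made-up example values, not measurements from this change.

# Standalone sketch of the decode-phase FLOPS estimate used in
# paged_attention_v1. All shape values and the duration below are
# hypothetical examples chosen only to illustrate the formula.
from math import ceil


def flops_counter_decode(num_att_heads, batch_size, query_seq_len,
                         max_seq_len, block_size, query_embedding_dim,
                         value_embedding_dim, duration):
    # Each query token attends over ceil(max_seq_len / block_size) full
    # KV-cache blocks; Q*K^T and attn*V together cost roughly
    # 2 * (query_embedding_dim + value_embedding_dim) floating-point ops
    # per query/key position pair. Dividing by duration gives FLOPS.
    return (batch_size * num_att_heads * query_seq_len *
            ceil(max_seq_len / block_size) * block_size * 2 *
            (query_embedding_dim + value_embedding_dim) / duration)


# Example: batch 32, 32 heads, one decode token per sequence, 2048-token
# KV cache in blocks of 128, head dim 128, measured kernel time 5 ms.
flops = flops_counter_decode(num_att_heads=32, batch_size=32,
                             query_seq_len=1, max_seq_len=2048,
                             block_size=128, query_embedding_dim=128,
                             value_embedding_dim=128, duration=5e-3)
print(f"PA TFLOPS: {flops / 1e12:.2f}")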