Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
adobrzyniewicz-habana committed Aug 13, 2024
1 parent 396015b commit 2c7437a
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 20 deletions.
43 changes: 23 additions & 20 deletions vllm/hpu/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import torch
import torch.nn.functional as F

import vllm.hpu.utils as hpu_utils
from vllm.worker.profiler import Profiler
from vllm.logger import init_logger

Expand Down Expand Up @@ -93,14 +92,15 @@ def paged_attention_v1(query,
end_time = time.time()

flops = flops_counter_decode(num_att_heads=query.shape[1],
batch_size=batch_size,
query_seq_len=query.shape[2],
max_seq_len=key_cache.shape[2],
block_size=block_size,
query_embedding_dim=query.shape[3],
value_embedding_dim=key_cache.shape[3],
duration=end_time - start_time)
habana_profiler.record_counter(habana_profiler.get_timestamp_us(), {"TFLOPS": flops / 1e12})
batch_size=batch_size,
query_seq_len=query.shape[2],
max_seq_len=key_cache.shape[2],
block_size=block_size,
query_embedding_dim=query.shape[3],
value_embedding_dim=key_cache.shape[3],
duration=end_time - start_time)
habana_profiler.record_counter(habana_profiler.get_timestamp_us(),
{"PA TFLOPS": flops / 1e12})

return attn_weights.squeeze(-2)

Expand Down Expand Up @@ -176,13 +176,14 @@ def prompt_attention(

end_time = time.time()
flops = flops_counter_prompt(num_att_heads=query.shape[1],
batch_size=query.shape[0],
query_seq_len=query.shape[2],
max_seq_len=key.shape[2],
query_embedding_dim=query.shape[3],
value_embedding_dim=key.shape[3],
duration=end_time - start_time)
habana_profiler.record_counter(habana_profiler.get_timestamp_us(), {"TFLOPS": flops / 1e12})
batch_size=query.shape[0],
query_seq_len=query.shape[2],
max_seq_len=key.shape[2],
query_embedding_dim=query.shape[3],
value_embedding_dim=key.shape[3],
duration=end_time - start_time)
habana_profiler.record_counter(habana_profiler.get_timestamp_us(),
{"Prompt TFLOPS": flops / 1e12})

return attn_weights

Expand All @@ -195,8 +196,9 @@ def flops_counter_decode(num_att_heads,
query_embedding_dim,
value_embedding_dim,
duration) -> float:
return (batch_size * num_att_heads * query_seq_len * ceil(max_seq_len / block_size)
* block_size * 2 * (query_embedding_dim + value_embedding_dim) / duration)
return (batch_size * num_att_heads * query_seq_len *
ceil(max_seq_len / block_size) * block_size * 2 *
(query_embedding_dim + value_embedding_dim) / duration)


def flops_counter_prompt(num_att_heads,
Expand All @@ -206,5 +208,6 @@ def flops_counter_prompt(num_att_heads,
query_embedding_dim,
value_embedding_dim,
duration) -> float:
return (batch_size * num_att_heads * query_seq_len * max_seq_len * 2
* (query_embedding_dim + value_embedding_dim) / duration)
return (batch_size * num_att_heads * query_seq_len *
max_seq_len * 2 * (query_embedding_dim +
value_embedding_dim) / duration)
2 changes: 2 additions & 0 deletions vllm/worker/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,12 @@ def run(self):

def singleton(class_):
instances = {}

def getinstance(*args, **kwargs):
if class_ not in instances:
instances[class_] = class_(*args, **kwargs)
return instances[class_]

return getinstance


Expand Down

0 comments on commit 2c7437a

Please sign in to comment.