Commit

Github actions pass for tflops
adobrzyniewicz-habana committed Aug 14, 2024
1 parent b5d1035 commit cb5c62a
Showing 2 changed files with 26 additions and 36 deletions.
vllm/hpu/ops.py (60 changes: 25 additions & 35 deletions)
@@ -92,16 +92,16 @@ def paged_attention_v1(query,
     end_time = time.time()
 
     flops = flops_counter_decode(num_att_heads=query.shape[1],
-                                    batch_size=batch_size,
-                                    query_seq_len=query.shape[2],
-                                    max_seq_len=key_cache.shape[2],
-                                    block_size=block_size,
-                                    query_embedding_dim=query.shape[3],
-                                    value_embedding_dim=key_cache.shape[3],
-                                    duration=end_time - start_time)
-    habana_profiler.record_counter(habana_profiler.get_timestamp_us(),
-                                    {"PA TFLOPS": flops / 1e12})
+                                 batch_size=batch_size,
+                                 query_seq_len=query.shape[2],
+                                 max_seq_len=key_cache.shape[2],
+                                 block_size=block_size,
+                                 query_embedding_dim=query.shape[3],
+                                 value_embedding_dim=key_cache.shape[3],
+                                 duration=end_time - start_time)
+    habana_profiler.record_counter(habana_profiler.get_timestamp_us(),
+                                   {"PA TFLOPS": flops / 1e12})
 
     return attn_weights.squeeze(-2)

@@ -173,41 +173,31 @@ def prompt_attention(
         attn_weights = attn_weights.flatten(1, 2)
     attn_weights = attn_weights.transpose(1, 2)
     htorch.core.mark_step()
 
     end_time = time.time()
     flops = flops_counter_prompt(num_att_heads=query.shape[1],
-                                    batch_size=query.shape[0],
-                                    query_seq_len=query.shape[2],
-                                    max_seq_len=key.shape[2],
-                                    query_embedding_dim=query.shape[3],
-                                    value_embedding_dim=key.shape[3],
-                                    duration=end_time - start_time)
-    habana_profiler.record_counter(habana_profiler.get_timestamp_us(),
-                                    {"Prompt TFLOPS": flops / 1e12})
+                                 batch_size=query.shape[0],
+                                 query_seq_len=query.shape[2],
+                                 max_seq_len=key.shape[2],
+                                 query_embedding_dim=query.shape[3],
+                                 value_embedding_dim=key.shape[3],
+                                 duration=end_time - start_time)
+    habana_profiler.record_counter(habana_profiler.get_timestamp_us(),
+                                   {"Prompt TFLOPS": flops / 1e12})
 
     return attn_weights


-def flops_counter_decode(num_att_heads,
-                         batch_size,
-                         query_seq_len,
-                         max_seq_len,
-                         block_size,
-                         query_embedding_dim,
-                         value_embedding_dim,
-                         duration) -> float:
+def flops_counter_decode(num_att_heads, batch_size, query_seq_len, max_seq_len,
+                         block_size, query_embedding_dim, value_embedding_dim,
+                         duration) -> float:
     return (batch_size * num_att_heads * query_seq_len *
             ceil(max_seq_len / block_size) * block_size * 2 *
             (query_embedding_dim + value_embedding_dim) / duration)


-def flops_counter_prompt(num_att_heads,
-                         batch_size,
-                         query_seq_len,
-                         max_seq_len,
-                         query_embedding_dim,
-                         value_embedding_dim,
-                         duration) -> float:
-    return (batch_size * num_att_heads * query_seq_len *
-            max_seq_len * 2 * (query_embedding_dim +
-            value_embedding_dim) / duration)
+def flops_counter_prompt(num_att_heads, batch_size, query_seq_len, max_seq_len,
+                         query_embedding_dim, value_embedding_dim,
+                         duration) -> float:
+    return (batch_size * num_att_heads * query_seq_len * max_seq_len * 2 *
+            (query_embedding_dim + value_embedding_dim) / duration)
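
Note on the counters above: both estimate attention FLOPs as 2 multiply-accumulate FLOPs per query/key element across the QK^T and softmax(QK^T)V matmuls, divided by the measured duration; the decode variant rounds the key length up to whole cache blocks, since paged attention always processes full blocks. A minimal worked example of the decode formula, with made-up shapes (none of these numbers appear in the commit):

    from math import ceil

    # Hypothetical decode-step shapes, for illustration only.
    num_att_heads = 32                # query.shape[1]
    batch_size = 8
    query_seq_len = 1                 # one new token per sequence in decode
    max_seq_len = 2000                # longest KV cache in the batch
    block_size = 128                  # KV-cache page size
    query_embedding_dim = value_embedding_dim = 128
    duration = 0.5e-3                 # made-up kernel time, in seconds

    # Same arithmetic as flops_counter_decode: ceil() rounds the key
    # length up to whole blocks (2000 tokens -> 16 blocks -> 2048 keys).
    flops = (batch_size * num_att_heads * query_seq_len *
             ceil(max_seq_len / block_size) * block_size * 2 *
             (query_embedding_dim + value_embedding_dim) / duration)
    print(f"PA TFLOPS: {flops / 1e12:.3f}")  # -> 0.537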
vllm/worker/profiler.py (2 changes: 1 addition & 1 deletion)
@@ -54,7 +54,7 @@ def getinstance(*args, **kwargs):
         if class_ not in instances:
             instances[class_] = class_(*args, **kwargs)
         return instances[class_]
 
     return getinstance
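
The hunk above is the closure half of a singleton decorator used by the profiler. The enclosing lines fall outside this diff, so the following is only a sketch of the usual instances-dict pattern it implies; the decorator and class names here are assumptions, not taken from the file:

    def singleton(class_):            # assumed name for the enclosing decorator
        instances = {}

        def getinstance(*args, **kwargs):
            # Construct on first call only; later calls return the cached object.
            if class_ not in instances:
                instances[class_] = class_(*args, **kwargs)
            return instances[class_]

        return getinstance


    @singleton
    class DemoProfiler:               # hypothetical class, not from the diff
        def __init__(self):
            self.counters = []


    assert DemoProfiler() is DemoProfiler()  # every call yields the one instance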


