fix TFLOPs calculation (#371)
* fix TFLOPs calculation

when GQA is used, we observe the correct TFLOPs after this fix.
when GQA is not used, the huge difference in TFLOPs seen with
selective recompute is also resolved.
some other minor differences will also be observed, since the logits MACs are
now included as well (a standalone sketch of the new accounting follows below).

* add copyrights
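
A minimal standalone sketch of the MAC-based accounting this commit introduces. The function name and plain keyword arguments are illustrative only; the real `throughput_calculator` in `megatron/utils.py` reads these values from `args`:

```python
# Illustrative sketch; mirrors the MAC-based accounting added by this commit.
# The function name and arguments are hypothetical, not the Megatron API.
def estimated_tflops_per_gpu(batch_size, seq_len, hidden_size, num_layers,
                             num_attention_heads, num_key_value_heads,
                             ffn_hidden_size, vocab_size, swiglu,
                             recompute_granularity, elapsed_time_per_iter, world_size):
    head_dim = hidden_size // num_attention_heads
    gqa = num_attention_heads // num_key_value_heads   # GQA grouping factor
    ffn_multiplier = 3 if swiglu else 2                # SwiGLU adds a third FFN GEMM
    macs_per_flops = 2                                 # one multiply-accumulate = 2 FLOPs

    # forward-pass MACs per iteration (2 // gqa keeps the commit's integer division)
    pre_and_post_mha_gemm_macs = batch_size * num_layers * (1 + (2 // gqa) + 1) * hidden_size**2 * seq_len
    mha_bgemm_macs = batch_size * num_layers * 2 * head_dim * num_attention_heads * seq_len**2
    ffn_gemm_macs = batch_size * num_layers * ffn_multiplier * ffn_hidden_size * hidden_size * seq_len
    logit_lmhead_gemm_macs = batch_size * vocab_size * hidden_size * seq_len

    fwd_macs = pre_and_post_mha_gemm_macs + mha_bgemm_macs + ffn_gemm_macs + logit_lmhead_gemm_macs
    fwd_bwd_macs = fwd_macs + 2 * fwd_macs             # backward is roughly 2x forward

    # activation recomputation repeats part or all of the forward pass
    if recompute_granularity == 'full':
        fwd_bwd_macs += fwd_macs
    elif recompute_granularity == 'selective':
        fwd_bwd_macs += mha_bgemm_macs                 # only the attention GEMMs are redone

    flops_per_iteration = fwd_bwd_macs * macs_per_flops
    return flops_per_iteration / (elapsed_time_per_iter * world_size * 1e12)
```

Note that `2 // gqa` is integer division, exactly as in the commit, so the K and V projections contribute zero MACs once the grouping factor exceeds 2.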
polisettyvarma authored Aug 19, 2024
1 parent cdf5194 commit b7b2d5e
Showing 1 changed file with 24 additions and 8 deletions.
megatron/utils.py: 24 additions & 8 deletions
@@ -275,22 +275,38 @@ def throughput_calculator(model, args, iteration_time, total_iterations):

#flops calculator
hidden_size = args.hidden_size
num_attention_heads = args.num_attention_heads
head_dim = hidden_size // num_attention_heads
ffn_hidden_size = args.ffn_hidden_size
num_layers = args.num_layers
vocab_size = args.padded_vocab_size
gqa = args.num_attention_heads // args.num_key_value_heads
ffn_multiplier = 3 if args.swiglu else 2
macs_per_flops = 2

# General TFLOPs formula (borrowed from Equation 3 in Section 5.1 of
# https://arxiv.org/pdf/2104.04473.pdf).
# The factor of 4 is when used with activation check-pointing,
# otherwise it will be 3.
checkpoint_activations_factor = 3
if hasattr(args, 'checkpoint_activations') and args.checkpoint_activations:
    checkpoint_activations_factor = 4
if hasattr(args, 'recompute_granularity') and (args.recompute_granularity == 'selective' or args.recompute_granularity == 'full'):
    checkpoint_activations_factor = 4
# the TFLOPs formula has been corrected due to incorrect behavior observed
# with selective recompute when GQA is not used, and in all cases when GQA is used
seq_len = args.seq_length
if hasattr(args, 'actual_seq_length'):
    seq_len = args.actual_seq_length
flops_per_iteration = (24 * checkpoint_activations_factor * batch_size * seq_len * num_layers * (hidden_size**2)) * (1. + (seq_len / (6. * hidden_size)) + (vocab_size / (16. * num_layers * hidden_size)))

pre_and_post_mha_gemm_macs = batch_size * num_layers * (1 + (2 // gqa) + 1) * (hidden_size**2) * seq_len
mha_bgemm_macs = batch_size * num_layers * 2 * head_dim * num_attention_heads * (seq_len**2)
ffn_gemm_macs = batch_size * num_layers * ffn_multiplier * ffn_hidden_size * hidden_size * seq_len
logit_lmhead_gemm_macs = batch_size * vocab_size * hidden_size * seq_len

fwd_macs = pre_and_post_mha_gemm_macs + mha_bgemm_macs + ffn_gemm_macs + logit_lmhead_gemm_macs
bwd_macs = 2 * fwd_macs
fwd_bwd_macs = fwd_macs + bwd_macs

if (hasattr(args, 'checkpoint_activations') and args.checkpoint_activations) or (hasattr(args, 'recompute_granularity') and args.recompute_granularity == 'full'):
    fwd_bwd_macs += fwd_macs
if hasattr(args, 'recompute_granularity') and args.recompute_granularity == 'selective':
    fwd_bwd_macs += mha_bgemm_macs

flops_per_iteration = fwd_bwd_macs * macs_per_flops
tflops = flops_per_iteration / (elapsed_time_per_iter * args.world_size * (10**12))
return samples_per_second, tflops, approx_parameters_in_billions
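
For a rough sense of the size of the correction, the snippet below compares the removed Equation-3-style formula with the new MAC-based accounting for a hypothetical 7B-scale configuration with GQA, assuming no activation recomputation (the config numbers are illustrative and not taken from the commit):

```python
# hypothetical 7B-scale config with GQA; numbers are illustrative, not from the commit
b, s, h, l, v = 1, 4096, 4096, 32, 32000        # batch, seq len, hidden, layers, vocab
heads, kv_heads, ffn = 32, 8, 11008             # 32 query heads share 8 KV heads
gqa = heads // kv_heads
head_dim = h // heads

# old formula (Eq. 3 of https://arxiv.org/pdf/2104.04473.pdf), factor 3 without checkpointing
old_flops = (24 * 3 * b * s * l * h**2) * (1 + s / (6 * h) + v / (16 * l * h))

# new MAC-based accounting: forward + 2x backward, 2 FLOPs per MAC, no recompute
fwd_macs = (b * l * (1 + 2 // gqa + 1) * h**2 * s   # QKV + output projections
            + b * l * 2 * head_dim * heads * s**2   # attention batched GEMMs
            + b * l * 3 * ffn * h * s               # SwiGLU FFN GEMMs
            + b * v * h * s)                        # LM-head logits
new_flops = 3 * fwd_macs * 2

print(f"old: {old_flops:.3e}  new: {new_flops:.3e}  ratio: {old_flops / new_flops:.2f}")
# with GQA, the old formula reports noticeably higher TFLOPs than the corrected accounting
```

The selective-recompute discrepancy mentioned in the commit message has a similar origin: the old code bumped the whole multiplier from 3 to 4 for any recomputation, while the new accounting adds back only the MACs that are actually recomputed (just the attention GEMMs in the selective case).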

