diff --git a/megatron/utils.py b/megatron/utils.py index cbb7aa6426..15160b1644 100644 --- a/megatron/utils.py +++ b/megatron/utils.py @@ -275,22 +275,38 @@ def throughput_calculator(model, args, iteration_time, total_iterations): #flops calculator hidden_size = args.hidden_size + num_attention_heads = args.num_attention_heads + head_dim = hidden_size // num_attention_heads + ffn_hidden_size = args.ffn_hidden_size num_layers = args.num_layers vocab_size = args.padded_vocab_size + gqa = args.num_attention_heads // args.num_key_value_heads + ffn_multiplier = 3 if args.swiglu else 2 + macs_per_flops = 2 # General TFLOPs formula (borrowed from Equation 3 in Section 5.1 of # https://arxiv.org/pdf/2104.04473.pdf). - # The factor of 4 is when used with activation check-pointing, - # otherwise it will be 3. - checkpoint_activations_factor = 3 - if hasattr(args, 'checkpoint_activations') and args.checkpoint_activations: - checkpoint_activations_factor = 4 - if hasattr(args, 'recompute_granularity') and (args.recompute_granularity == 'selective' or args.recompute_granularity == 'full'): - checkpoint_activations_factor = 4 + # correction has been made to TFLOPs formula due to incorrect behavior + # observed with selective recompute when GQA not used and for all with GQA seq_len = args.seq_length if hasattr(args, 'actual_seq_length'): seq_len = args.actual_seq_length - flops_per_iteration = (24 * checkpoint_activations_factor * batch_size * seq_len * num_layers * (hidden_size**2)) * (1. + (seq_len / (6. * hidden_size)) + (vocab_size / (16. * num_layers * hidden_size))) + + pre_and_post_mha_gemm_macs = batch_size * num_layers * (1 + (2 // gqa) + 1) * (hidden_size**2) * seq_len + mha_bgemm_macs = batch_size * num_layers * 2 * head_dim * num_attention_heads * (seq_len**2) + ffn_gemm_macs = batch_size * num_layers * ffn_multiplier * ffn_hidden_size * hidden_size * seq_len + logit_lmhead_gemm_macs = batch_size * vocab_size * hidden_size * seq_len + + fwd_macs = pre_and_post_mha_gemm_macs + mha_bgemm_macs + ffn_gemm_macs + logit_lmhead_gemm_macs + bwd_macs = 2 * fwd_macs + fwd_bwd_macs = fwd_macs + bwd_macs + + if (hasattr(args, 'checkpoint_activations') and args.checkpoint_activations) or (hasattr(args, 'recompute_granularity') and args.recompute_granularity == 'full'): + fwd_bwd_macs += fwd_macs + if hasattr(args, 'recompute_granularity') and args.recompute_granularity == 'selective': + fwd_bwd_macs += mha_bgemm_macs + + flops_per_iteration = fwd_bwd_macs * macs_per_flops tflops = flops_per_iteration / (elapsed_time_per_iter * args.world_size * (10**12)) return samples_per_second, tflops, approx_parameters_in_billions