[PyTorch] FP8 MHA with RoPE and Miscellaneous Improvements (#1100)
* fp8 mha with rope

Signed-off-by: Xin Yao <xiny@nvidia.com>
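
A minimal sketch of how the new path can be exercised, mirroring the updated test further down: RoPE frequencies from RotaryPositionEmbedding are passed to MultiheadAttention while FP8 attention is enabled through the recipe's fp8_dpa/fp8_mha flags. The sizes, dtype, and recipe settings below are illustrative assumptions, not values taken from this commit.

# Hedged sketch: FP8 MHA with RoPE (illustrative sizes/settings, not from this commit).
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.attention import RotaryPositionEmbedding

seqlen, batch, hidden, heads, head_dim = 512, 2, 1024, 16, 64

# Enable FP8 dot-product attention / MHA via the delayed-scaling recipe
# (assumes a TE build that exposes the fp8_dpa/fp8_mha flags).
fp8_recipe = recipe.DelayedScaling(fp8_dpa=True, fp8_mha=True)

# RoPE frequencies, built the same way as in the updated test.
rotary_pos_emb = RotaryPositionEmbedding(dim=head_dim)(seqlen).to(device="cuda")

with te.fp8_model_init(enabled=True):
    mha = te.MultiheadAttention(
        hidden_size=hidden,
        num_attention_heads=heads,
        fuse_qkv_params=True,
        attention_type="self",
        params_dtype=torch.bfloat16,
    ).cuda()

x = torch.randn(seqlen, batch, hidden, device="cuda", dtype=torch.bfloat16, requires_grad=True)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = mha(x, rotary_pos_emb=rotary_pos_emb)  # RoPE now composes with the FP8 MHA path
out.backward(torch.ones_like(out))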

* avoid index select in cast ops

Signed-off-by: Xin Yao <xiny@nvidia.com>

* avoid index select in fused_attn_fwd

Signed-off-by: Xin Yao <xiny@nvidia.com>
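
The two "avoid index select" items above refer to a calling-convention change visible in the last diff hunk below: rather than slicing single elements out of the FP8 scaling tensors on every call (an index-select that materializes a fresh one-element tensor), the whole tensor is passed together with an integer offset. A minimal illustration of the pattern follows; consume_scales() is a hypothetical stand-in for the fused kernel entry point, and only META_QKV mirrors a constant from the diff.

# Minimal sketch of the "tensor + offset" pattern; consume_scales() is hypothetical,
# META_QKV mirrors the constant used in the test diff below.
import torch

META_QKV = 0  # position of the QKV scale within the packed scale_inv tensor
device = "cuda" if torch.cuda.is_available() else "cpu"
scale_inv = torch.ones(8, device=device)

# Before: a per-call index select that materializes a new one-element tensor.
d_scale_qkv = scale_inv[META_QKV]

# After: pass the whole tensor plus a plain-int offset; the consumer reads element
# [offset] itself, so no extra slicing happens at the call site.
def consume_scales(scales: torch.Tensor, offset: int) -> torch.Tensor:
    return scales[offset]  # stand-in for what the fused attention kernel does internally

assert torch.equal(d_scale_qkv, consume_scales(scale_inv, META_QKV))

The same idea shows up in the fused_attn_fwd call at the end of the diff, where the d_scale_*, q_scale_*, and amax_* arguments now arrive as (tensor, offset) pairs.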

* rename is_first_module_in_mha to fp8_output

Signed-off-by: Xin Yao <xiny@nvidia.com>

* resolve comments

Signed-off-by: Xin Yao <xiny@nvidia.com>

* resolve comments

Signed-off-by: Xin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* move transpose to backward for fp8 input

Signed-off-by: Xin Yao <xiny@nvidia.com>

* fix ut

Signed-off-by: Xin Yao <xiny@nvidia.com>

* resolve comments

Signed-off-by: Xin Yao <xiny@nvidia.com>

* update argument list for CP

Signed-off-by: Xin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix for FA3

Signed-off-by: Xin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove unnecessary copy of scale_inv

Signed-off-by: Xin Yao <xiny@nvidia.com>

* skip fp8 dpa/mha tests when fa3 is not available

Signed-off-by: Xin Yao <xiny@nvidia.com>

* fix a merge bug

Signed-off-by: Xin Yao <xiny@nvidia.com>

---------

Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
yaox12 and pre-commit-ci[bot] authored Sep 5, 2024
1 parent 247850e commit 5fafeb0
Showing 10 changed files with 491 additions and 287 deletions.
36 changes: 25 additions & 11 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -1344,32 +1344,35 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
 @pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
 @pytest.mark.parametrize("input_layernorm", [True, False])
 @pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
+@pytest.mark.parametrize("RoPE", [True, False])
 @pytest.mark.parametrize("is_training", [True, False])
-def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, is_training):
+def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training):
     os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
     os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
     config = model_configs_fp8_vs_f16[model]
 
     if _flash_attn_3_plus and not is_training:
+        if RoPE:
+            pytest.skip("Flash Attention doesn't support FP8 MHA with RoPE.")
         os.environ["NVTE_FLASH_ATTN"] = "1"
         os.environ["NVTE_FUSED_ATTN"] = "0"
         _attention_backends["backend_selection_requires_update"] = True
         logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
         flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
-            dtype, config, True, qkv_format, input_layernorm, is_training
+            dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
         )
 
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "1"
     _attention_backends["backend_selection_requires_update"] = True
     logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
     fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
-        dtype, config, True, qkv_format, input_layernorm, is_training
+        dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
     )
 
     logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
     fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
-        dtype, config, False, qkv_format, input_layernorm, is_training
+        dtype, config, False, qkv_format, input_layernorm, RoPE, is_training
     )
 
     atol = 5e-1
@@ -1410,7 +1413,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
     )
 
 
-def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, is_training):
+def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training):
     reset_rng_states()
     _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
     _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
@@ -1429,6 +1432,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
     )
 
     with fp8_model_init(enabled=fp8_mha):
+        rotary_pos_emb = None
+        if RoPE:
+            PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
+            rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")
         mha = MultiheadAttention(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_heads,
@@ -1489,6 +1496,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
             checkpoint_core_attention=False,
             core_attention_bias_type=config.attn_bias_type,
             is_first_microbatch=None,
+            rotary_pos_emb=rotary_pos_emb,
         )
     if is_training:
         out.backward(out_grad)
Expand Down Expand Up @@ -1977,12 +1985,18 @@ def forward(
None,
None,
None,
fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S],
fp8_meta["scaling_fwd"].scale[META_O],
fp8_meta["scaling_fwd"].amax_history[0][META_S],
fp8_meta["scaling_fwd"].amax_history[0][META_O],
fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv
META_QKV, # d_scale_qkv_offset
fp8_meta["scaling_fwd"].scale_inv, # d_scale_s
META_S, # d_scale_s_offset
fp8_meta["scaling_fwd"].scale, # q_scale_s
META_S, # q_scale_s_offset
fp8_meta["scaling_fwd"].scale, # q_scale_o
META_O, # q_scale_o_offset
fp8_meta["scaling_fwd"].amax_history, # amax_s
META_S, # amax_s_offset
fp8_meta["scaling_fwd"].amax_history, # amax_o
META_O, # amax_o_offset
attn_scale=None,
dropout=p_dropout,
fast_zero_fill=fast_zero_fill,
