From 5fafeb0efef60d6f10574bb4366cdc5a8db7192d Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Thu, 5 Sep 2024 13:57:30 +0800 Subject: [PATCH] [PyTorch] FP8 MHA with RoPE and Miscellaneous Improvements (#1100) * fp8 mha with rope Signed-off-by: Xin Yao * avoid index select in cast ops Signed-off-by: Xin Yao * avoid index select in fused_attn_fwd Signed-off-by: Xin Yao * rename is_first_module_in_mha to fp8_output Signed-off-by: Xin Yao * resolve comments Signed-off-by: Xin Yao * resolve comments Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * move transpose to backward for fp8 input Signed-off-by: Xin Yao * fix ut Signed-off-by: Xin Yao * resolve comments Signed-off-by: Xin Yao * update argument list for CP Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix for FA3 Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove unnecessary copy of scale_inv Signed-off-by: Xin Yao * skip fp8 dpa/mha tests when fa3 is not available Signed-off-by: Xin Yao * fix a merge bug Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/fused_attn/test_fused_attn.py | 36 +- transformer_engine/pytorch/attention.py | 444 ++++++++++-------- .../pytorch/cpp_extensions/fused_attn.py | 84 +++- transformer_engine/pytorch/csrc/extensions.h | 47 +- .../pytorch/csrc/extensions/attention.cu | 81 ++-- .../pytorch/csrc/extensions/cast.cu | 36 +- .../pytorch/csrc/extensions/pybind.cpp | 12 +- transformer_engine/pytorch/csrc/ts_fp8_op.cpp | 8 +- .../pytorch/module/layernorm_linear.py | 6 +- transformer_engine/pytorch/module/linear.py | 24 +- 10 files changed, 491 insertions(+), 287 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index da26c7c42f..a1ebead04a 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -1344,19 +1344,22 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol): @pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16) @pytest.mark.parametrize("input_layernorm", [True, False]) @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) +@pytest.mark.parametrize("RoPE", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) -def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, is_training): +def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training): os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" config = model_configs_fp8_vs_f16[model] if _flash_attn_3_plus and not is_training: + if RoPE: + pytest.skip("Flash Attention doesn't support FP8 MHA with RoPE.") os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16( - dtype, config, True, qkv_format, input_layernorm, is_training + dtype, config, True, qkv_format, input_layernorm, RoPE, is_training ) os.environ["NVTE_FLASH_ATTN"] = "0" @@ -1364,12 +1367,12 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, 
_attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16( - dtype, config, True, qkv_format, input_layernorm, is_training + dtype, config, True, qkv_format, input_layernorm, RoPE, is_training ) logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False") fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16( - dtype, config, False, qkv_format, input_layernorm, is_training + dtype, config, False, qkv_format, input_layernorm, RoPE, is_training ) atol = 5e-1 @@ -1410,7 +1413,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, ) -def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, is_training): +def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training): reset_rng_states() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) @@ -1429,6 +1432,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: ) with fp8_model_init(enabled=fp8_mha): + rotary_pos_emb = None + if RoPE: + PE = RotaryPositionEmbedding(dim=config.head_dim_qk) + rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda") mha = MultiheadAttention( hidden_size=config.hidden_size, num_attention_heads=config.num_heads, @@ -1489,6 +1496,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: checkpoint_core_attention=False, core_attention_bias_type=config.attn_bias_type, is_first_microbatch=None, + rotary_pos_emb=rotary_pos_emb, ) if is_training: out.backward(out_grad) @@ -1977,12 +1985,18 @@ def forward( None, None, None, - fp8_meta["scaling_fwd"].scale_inv[META_QKV], - fp8_meta["scaling_fwd"].scale_inv[META_S], - fp8_meta["scaling_fwd"].scale[META_S], - fp8_meta["scaling_fwd"].scale[META_O], - fp8_meta["scaling_fwd"].amax_history[0][META_S], - fp8_meta["scaling_fwd"].amax_history[0][META_O], + fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv + META_QKV, # d_scale_qkv_offset + fp8_meta["scaling_fwd"].scale_inv, # d_scale_s + META_S, # d_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_s + META_S, # q_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_o + META_O, # q_scale_o_offset + fp8_meta["scaling_fwd"].amax_history, # amax_s + META_S, # amax_s_offset + fp8_meta["scaling_fwd"].amax_history, # amax_o + META_O, # amax_o_offset attn_scale=None, dropout=p_dropout, fast_zero_fill=fast_zero_fill, diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 27d7c0fdc4..59bc26140d 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -38,8 +38,20 @@ AttnBiasType, AttnMaskType, FusedAttnBackend, + META_QKV, + META_DQKV, + META_O, + META_DO, + META_S, + META_DP, + META_O_CP, + META_DQKV_CP, +) +from transformer_engine.pytorch.fp8 import ( + FP8GlobalStateManager, + get_fp8_te_dtype, + get_fp8_torch_dtype, ) -from transformer_engine.pytorch.fp8 import get_fp8_te_dtype, get_fp8_torch_dtype from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.module import LayerNormLinear, Linear from transformer_engine.pytorch.module.base import TransformerEngineBaseModule @@ -120,15 +132,6 @@ from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd -META_QKV = 
tex.FP8FwdTensors.GEMM1_OUTPUT -META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1 -META_O = tex.FP8FwdTensors.GEMM2_INPUT -META_DO = tex.FP8BwdTensors.GRAD_INPUT2 -META_S = tex.FP8FwdTensors.GEMM3_OUTPUT -META_DP = tex.FP8BwdTensors.GRAD_INPUT3 -# repurpose some unused amax history buffers for partial results of CP fwd and bwd -META_O_CP = tex.FP8FwdTensors.GEMM2_OUTPUT -META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1 # NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) @@ -1546,10 +1549,14 @@ def forward( for x in [k_f16, v_f16] ] fp8_meta_kwargs = {} - fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv[META_QKV] - fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv[META_S] - fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale[META_S] - fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale[META_O_CP] + fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv + fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV + fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv + fp8_meta_kwargs["d_scale_s_offset"] = META_S + fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale + fp8_meta_kwargs["q_scale_s_offset"] = META_S + fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale + fp8_meta_kwargs["q_scale_o_offset"] = META_O_CP amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) else: assert False, "FP8 is only supported with Fused Attention!" @@ -1601,8 +1608,10 @@ def forward( fp8_dtype_forward, ) if fp8 and use_fused_attention: - fp8_meta_kwargs["amax_s"] = amax_per_step[0][i] - fp8_meta_kwargs["amax_o"] = amax_per_step[1][i] + fp8_meta_kwargs["amax_s"] = amax_per_step + fp8_meta_kwargs["amax_s_offset"] = i + fp8_meta_kwargs["amax_o"] = amax_per_step + fp8_meta_kwargs["amax_o_offset"] = cp_size + i if causal: if i == 0: if pad_between_seqs_q: @@ -4153,9 +4162,8 @@ def run_iteratively(q, k, v): stride = q.stride() check_strides_qkv = all(stride == x.stride() for x in [q, k, v]) - stride = k.stride() - check_strides_kv = torch.equal( - torch.Tensor(stride[:-1]) / k.shape[-1], torch.Tensor(v.stride()[:-1]) / v.shape[-1] + check_strides_kv = tuple(sk / k.shape[-1] for sk in k.stride()[:-1]) == tuple( + sv / v.shape[-1] for sv in v.stride()[:-1] ) shape = q.shape @@ -4635,19 +4643,20 @@ def forward( fp8_meta, deterministic, ): + is_input_fp8 = False + is_output_fp8 = fp8_meta["recipe"].fp8_mha if fp8: - if fp8_meta["recipe"].fp8_mha: - assert isinstance(qkv, Float8Tensor), "qkv must be Float8Tensors for FP8 MHA." + is_input_fp8 = isinstance(qkv, Float8Tensor) + if is_input_fp8: fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv fused_attention_backend = FusedAttnBackend["FP8"] fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) # 1: qkv packed, 2: kv packed, 3: qkv separate qkv_group = len(qkv_layout.split("_")) - assert qkv_group == 1, ( - "qkv layout should conform to 3hd or h3d, e.g. sb3hd, but found" - f" {qkv_layout}." - ) - if fp8_meta["recipe"].fp8_mha: + assert ( + qkv_group == 1 + ), f"qkv layout should conform to 3hd or h3d, e.g. sb3hd, but found {qkv_layout}." 
+ if is_input_fp8: qkv_fp8 = qkv._data else: qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) @@ -4663,12 +4672,18 @@ def forward( fused_attention_backend, attn_bias, cu_seqlens_padded, - fp8_meta["scaling_fwd"].scale_inv[META_QKV], - fp8_meta["scaling_fwd"].scale_inv[META_S], - fp8_meta["scaling_fwd"].scale[META_S], - fp8_meta["scaling_fwd"].scale[META_O], - fp8_meta["scaling_fwd"].amax_history[0][META_S], - fp8_meta["scaling_fwd"].amax_history[0][META_O], + fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv + META_QKV, # d_scale_qkv_offset + fp8_meta["scaling_fwd"].scale_inv, # d_scale_s + META_S, # d_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_s + META_S, # q_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_o + META_O, # q_scale_o_offset + fp8_meta["scaling_fwd"].amax_history, # amax_s + META_S, # amax_s_offset + fp8_meta["scaling_fwd"].amax_history, # amax_o + META_O, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -4678,7 +4693,7 @@ def forward( window_size, rng_gen, ) - if fp8_meta["recipe"].fp8_mha: + if is_output_fp8: out_ret = Float8Tensor( data=out_fp8, fp8_meta=fp8_meta, @@ -4696,22 +4711,24 @@ def forward( qkv_dtype, ).view(out_fp8.shape) out_save = out_ret - if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv = cast_from_fp8( - qkv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[qkv.dtype], - ).view(qkv.shape) - out_save = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) + if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + if is_input_fp8: + qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) + qkv = cast_from_fp8( + qkv_c._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[qkv.dtype], + ).view(qkv.shape) + if is_output_fp8: + out_save = cast_from_fp8( + out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), + fp8_meta["scaling_fwd"], + META_O, + fp8_dtype_forward, + qkv_dtype, + ).view(out_fp8.shape) fp8_tensors = ( qkv_fp8, out_fp8, @@ -4728,12 +4745,18 @@ def forward( fused_attention_backend, attn_bias, cu_seqlens_padded, - None, - None, - None, - None, - None, - None, + None, # d_scale_qkv + 0, # d_scale_qkv_offset + None, # d_scale_s + 0, # d_scale_s_offset + None, # q_scale_s + 0, # q_scale_s_offset + None, # q_scale_o + 0, # q_scale_o_offset + None, # amax_s + 0, # amax_s_offset + None, # amax_o + 0, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -4747,6 +4770,8 @@ def forward( out_save = out_ret ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ctx.is_input_fp8 = is_input_fp8 + ctx.is_output_fp8 = is_output_fp8 qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None) ctx.save_for_backward( *qkvo_tensors, cu_seqlens, cu_seqlens_padded, *fp8_tensors, *aux_ctx_tensors @@ -4771,7 +4796,7 @@ def forward( @staticmethod def backward(ctx, d_out): - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: assert isinstance( d_out, Float8Tensor ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." 
@@ -4828,7 +4853,7 @@ def backward(ctx, d_out): fp8_dtype_backward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=False ) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: d_out_fp8 = d_out ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv else: @@ -4868,7 +4893,7 @@ def backward(ctx, d_out): ctx.window_size, ctx.deterministic, ) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_input_fp8: dqkv = Float8Tensor( data=dqkv_fp8, fp8_meta=ctx.fp8_meta, @@ -5006,22 +5031,23 @@ def forward( fp8_meta, deterministic, ): + is_input_fp8 = False + is_output_fp8 = fp8_meta["recipe"].fp8_mha if fp8: - if fp8_meta["recipe"].fp8_mha: - assert isinstance(q, Float8Tensor) and isinstance( - kv, Float8Tensor - ), "q/kv must be Float8Tensors for FP8 MHA." + assert isinstance(kv, q.__class__), "q and kv must have the same type." + is_input_fp8 = isinstance(q, Float8Tensor) + if is_input_fp8: fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv fused_attention_backend = FusedAttnBackend["FP8"] fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if fp8_meta["recipe"].fp8_mha: + if is_input_fp8: q_fp8, kv_fp8 = q._data, kv._data else: # 1: qkv packed, 2: kv packed, 3: qkv separate qkv_group = len(qkv_layout.split("_")) assert qkv_group == 2, ( - "qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, " - f" but found {qkv_layout}." + "qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, " + f"but found {qkv_layout}." ) q_fp8 = cast_to_fp8(q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward).view( q.shape @@ -5043,12 +5069,18 @@ def forward( attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - fp8_meta["scaling_fwd"].scale_inv[META_QKV], - fp8_meta["scaling_fwd"].scale_inv[META_S], - fp8_meta["scaling_fwd"].scale[META_S], - fp8_meta["scaling_fwd"].scale[META_O], - fp8_meta["scaling_fwd"].amax_history[0][META_S], - fp8_meta["scaling_fwd"].amax_history[0][META_O], + fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv + META_QKV, # d_scale_qkv_offset + fp8_meta["scaling_fwd"].scale_inv, # d_scale_s + META_S, # d_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_s + META_S, # q_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_o + META_O, # q_scale_o_offset + fp8_meta["scaling_fwd"].amax_history, # amax_s + META_S, # amax_s_offset + fp8_meta["scaling_fwd"].amax_history, # amax_o + META_O, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -5058,7 +5090,7 @@ def forward( window_size, rng_gen, ) - if fp8_meta["recipe"].fp8_mha: + if is_output_fp8: out_ret = Float8Tensor( data=out_fp8, fp8_meta=fp8_meta, @@ -5076,25 +5108,31 @@ def forward( qkv_dtype, ).view(out_fp8.shape) out_save = out_ret - if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q = cast_from_fp8( - q._data, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward, TE_DType[q.dtype] - ).view(q.shape) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv = cast_from_fp8( - kv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[kv.dtype], - ).view(kv.shape) - out_save = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) + if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + if is_input_fp8: + q = cast_from_fp8( + q._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[q.dtype], + ).view(q.shape) + kv_c = kv.view(-1, kv.shape[-3] * 
kv.shape[-2] * kv.shape[-1]) + kv = cast_from_fp8( + kv_c._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[kv.dtype], + ).view(kv.shape) + if is_output_fp8: + out_save = cast_from_fp8( + out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), + fp8_meta["scaling_fwd"], + META_O, + fp8_dtype_forward, + qkv_dtype, + ).view(out_fp8.shape) fp8_tensors = ( q_fp8, kv_fp8, @@ -5116,12 +5154,18 @@ def forward( attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - None, - None, - None, - None, - None, - None, + None, # d_scale_qkv + 0, # d_scale_qkv_offset + None, # d_scale_s + 0, # d_scale_s_offset + None, # q_scale_s + 0, # q_scale_s_offset + None, # q_scale_o + 0, # q_scale_o_offset + None, # amax_s + 0, # amax_s_offset + None, # amax_o + 0, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -5135,6 +5179,8 @@ def forward( fp8_tensors = (None, None, None, None, None) ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ctx.is_input_fp8 = is_input_fp8 + ctx.is_output_fp8 = is_output_fp8 qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None) ctx.save_for_backward( *qkvo_tensors, @@ -5166,7 +5212,7 @@ def forward( @staticmethod def backward(ctx, d_out): - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: assert isinstance( d_out, Float8Tensor ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." @@ -5227,7 +5273,7 @@ def backward(ctx, d_out): fp8_dtype_backward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=False ) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: d_out_fp8 = d_out ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv else: @@ -5271,7 +5317,7 @@ def backward(ctx, d_out): ctx.window_size, ctx.deterministic, ) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_input_fp8: dq = Float8Tensor( data=dq_fp8, fp8_meta=ctx.fp8_meta, @@ -5437,15 +5483,16 @@ def forward( fp8_meta, deterministic, ): + is_input_fp8 = False + is_output_fp8 = fp8_meta["recipe"].fp8_mha if fp8: fused_attention_backend = FusedAttnBackend["FP8"] fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if fp8_meta["recipe"].fp8_mha: - assert ( - isinstance(q, Float8Tensor) - and isinstance(k, Float8Tensor) - and isinstance(v, Float8Tensor) - ), "q/k/v must be Float8Tensors for FP8 MHA." + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, and v must have the same type." 
+ is_input_fp8 = isinstance(q, Float8Tensor) + if is_input_fp8: fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data else: @@ -5496,12 +5543,18 @@ def forward( attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - fp8_meta["scaling_fwd"].scale_inv[META_QKV], - fp8_meta["scaling_fwd"].scale_inv[META_S], - fp8_meta["scaling_fwd"].scale[META_S], - fp8_meta["scaling_fwd"].scale[META_O], - fp8_meta["scaling_fwd"].amax_history[0][META_S], - fp8_meta["scaling_fwd"].amax_history[0][META_O], + fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv + META_QKV, # d_scale_qkv_offset + fp8_meta["scaling_fwd"].scale_inv, # d_scale_s + META_S, # d_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_s + META_S, # q_scale_s_offset + fp8_meta["scaling_fwd"].scale, # q_scale_o + META_O, # q_scale_o_offset + fp8_meta["scaling_fwd"].amax_history, # amax_s + META_S, # amax_s_offset + fp8_meta["scaling_fwd"].amax_history, # amax_o + META_O, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -5511,7 +5564,7 @@ def forward( window_size, rng_gen, ) - if fp8_meta["recipe"].fp8_mha: + if is_output_fp8: out_ret = Float8Tensor( data=out_fp8, fp8_meta=fp8_meta, @@ -5530,71 +5583,73 @@ def forward( ).view(out_fp8.shape) out_save = out_ret - if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.split("_")) - if qkv_group == 1: - dim = qkv_layout.find("3") - qkv = _combine_tensors([q, k, v], dim) - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_no_fp8 = cast_from_fp8( - qkv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[qkv.dtype], - ).view(qkv.shape) - q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1]) - q, k, v = [x.squeeze(dim) for x in [q, k, v]] - if qkv_group == 2: - q = cast_from_fp8( - q._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[q.dtype], - ).view(q.shape) - dim = qkv_layout.split("_")[1].find("2") - kv = _combine_tensors([k, v], dim) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_no_fp8 = cast_from_fp8( - kv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[kv.dtype], - ).view(kv.shape) - k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1]) - k, v = [x.squeeze(dim) for x in [k, v]] - if qkv_group == 3: - q = cast_from_fp8( - q._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[q.dtype], - ).view(q.shape) - k = cast_from_fp8( - k._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[k.dtype], - ).view(k.shape) - v = cast_from_fp8( - v._data, + if is_input_fp8: + qkv_group = len(qkv_layout.split("_")) + if qkv_group == 1: + dim = qkv_layout.find("3") + qkv = _combine_tensors([q, k, v], dim) + qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) + qkv_no_fp8 = cast_from_fp8( + qkv_c._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[qkv.dtype], + ).view(qkv.shape) + q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1]) + q, k, v = [x.squeeze(dim) for x in [q, k, v]] + if qkv_group == 2: + q = cast_from_fp8( + q._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[q.dtype], + ).view(q.shape) + dim = qkv_layout.split("_")[1].find("2") + kv = _combine_tensors([k, v], dim) + kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) 
+ kv_no_fp8 = cast_from_fp8( + kv_c._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[kv.dtype], + ).view(kv.shape) + k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1]) + k, v = [x.squeeze(dim) for x in [k, v]] + if qkv_group == 3: + q = cast_from_fp8( + q._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[q.dtype], + ).view(q.shape) + k = cast_from_fp8( + k._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[k.dtype], + ).view(k.shape) + v = cast_from_fp8( + v._data, + fp8_meta["scaling_fwd"], + META_QKV, + fp8_dtype_forward, + TE_DType[v.dtype], + ).view(v.shape) + if is_output_fp8: + out_save = cast_from_fp8( + out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), fp8_meta["scaling_fwd"], - META_QKV, + META_O, fp8_dtype_forward, - TE_DType[v.dtype], - ).view(v.shape) - out_save = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) + qkv_dtype, + ).view(out_fp8.shape) fp8_tensors = ( q_fp8, @@ -5619,12 +5674,18 @@ def forward( attn_bias, cu_seqlens_q_padded, cu_seqlens_kv_padded, - None, - None, - None, - None, - None, - None, + None, # d_scale_qkv + 0, # d_scale_qkv_offset + None, # d_scale_s + 0, # d_scale_s_offset + None, # q_scale_s + 0, # q_scale_s_offset + None, # q_scale_o + 0, # q_scale_o_offset + None, # amax_s + 0, # amax_s_offset + None, # amax_o + 0, # amax_o_offset attn_scale, dropout_p, fast_zero_fill, @@ -5647,6 +5708,8 @@ def forward( tensor.activation_offloading = True ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ctx.is_input_fp8 = is_input_fp8 + ctx.is_output_fp8 = is_output_fp8 qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None) ctx.save_for_backward( *qkvo_tensors, @@ -5678,7 +5741,7 @@ def forward( @staticmethod def backward(ctx, d_out): - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: assert isinstance( d_out, Float8Tensor ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." 
@@ -5743,7 +5806,7 @@ def backward(ctx, d_out): fp8_dtype_backward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=False ) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: d_out_fp8 = d_out ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv else: @@ -5789,7 +5852,7 @@ def backward(ctx, d_out): ctx.deterministic, ) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_input_fp8: dq = Float8Tensor( data=dq_fp8, fp8_meta=ctx.fp8_meta, @@ -7719,12 +7782,18 @@ def forward( # Query, Key, and Value # ====================== + fp8_mha = ( + FP8GlobalStateManager.is_fp8_enabled() + and FP8GlobalStateManager.get_fp8_recipe().fp8_mha + ) + if self.attention_type == "self": # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn] if self.input_layernorm: layernorm_qkv_outputs = self.layernorm_qkv( hidden_states, is_first_microbatch=is_first_microbatch, + fp8_output=fp8_mha and rotary_pos_emb is None, ) if self.return_layernorm_output: mixed_x_layer, layernorm_output = layernorm_qkv_outputs @@ -7734,7 +7803,7 @@ def forward( mixed_x_layer = self.qkv( hidden_states, is_first_microbatch=is_first_microbatch, - is_first_module_in_mha=True, # specific to FP8 MHA + fp8_output=fp8_mha and rotary_pos_emb is None, ) num_queries_per_key_value = ( @@ -7795,7 +7864,7 @@ def forward( mixed_kv_layer = self.key_value( encoder_output, is_first_microbatch=is_first_microbatch, - is_first_module_in_mha=True, # specific to FP8 MHA + fp8_output=fp8_mha and rotary_pos_emb is None, ) if self.qkv_weight_interleaved: @@ -7845,6 +7914,7 @@ def forward( layernorm_query_outputs = self.layernorm_query( hidden_states, is_first_microbatch=is_first_microbatch, + fp8_output=fp8_mha and rotary_pos_emb is None, ) if self.return_layernorm_output: query_layer, layernorm_output = layernorm_query_outputs @@ -7854,7 +7924,7 @@ def forward( query_layer = self.query_layer( hidden_states, is_first_microbatch=is_first_microbatch, - is_first_module_in_mha=True, # specific to FP8 MHA + fp8_output=fp8_mha and rotary_pos_emb is None, ) # [sq, b, hp] --> [sq, b, np, hn] diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index d0ba644621..cd0ecbaa6c 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -78,6 +78,16 @@ BACKEND_F16m512_FP8_THREADS_PER_CTA = 128 BACKEND_F16arb_ELTS_PER_THREADS = 16 +META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT +META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1 +META_O = tex.FP8FwdTensors.GEMM2_INPUT +META_DO = tex.FP8BwdTensors.GRAD_INPUT2 +META_S = tex.FP8FwdTensors.GEMM3_OUTPUT +META_DP = tex.FP8BwdTensors.GRAD_INPUT3 +# repurpose some unused amax history buffers for partial results of CP fwd and bwd +META_O_CP = tex.FP8FwdTensors.GEMM2_OUTPUT +META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1 + def fused_attn_fwd_qkvpacked( is_training: bool, @@ -89,11 +99,17 @@ def fused_attn_fwd_qkvpacked( attn_bias: torch.Tensor = None, cu_seqlens_padded: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, + d_scale_qkv_offset: int = META_QKV, d_scale_s: torch.Tensor = None, + d_scale_s_offset: int = META_S, q_scale_s: torch.Tensor = None, + q_scale_s_offset: int = META_S, q_scale_o: torch.Tensor = None, + q_scale_o_offset: int = META_O, amax_s: torch.Tensor = None, + amax_s_offset: int = META_S, amax_o: torch.Tensor = None, + amax_o_offset: int = META_O, attn_scale: float = None, dropout: float = 0.0, fast_zero_fill: bool = True, @@ -128,16 
+144,28 @@ def fused_attn_fwd_qkvpacked( cumulative sequence offsets for QKV; shape [batch_size + 1] d_scale_qkv: torch.Tensor, default = None input tensor for the dequantization of QKV in FP8 computations + d_scale_qkv_offset: int, default = META_QKV + offset in d_scale_qkv for QKV d_scale_s: torch.Tensor, default = None input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) + d_scale_s_offset: int, default = META_S + offset in d_scale_s for S q_scale_s: torch.Tensor, default = None input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) + q_scale_s_offset: int, default = META_S + offset in q_scale_s for S q_scale_o: torch.Tensor, default = None input tensor for the quantization of O in FP8 computations + q_scale_o_offset: int, default = META_O + offset in q_scale_o for O amax_s: torch.Tensor, default = None output tensor, amax of S, used by the next iteration in FP8 computations + amax_s_offset: int, default = META_S + offset in amax_s for S amax_o: torch.Tensor, default = None output tensor, amax of O, used by the next iteration in FP8 computations + amax_o_offset: int, default = META_O + offset in amax_o for O attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; if None, use 1.0/sqrt(head_dim_qk) as the default @@ -248,11 +276,17 @@ def fused_attn_fwd_qkvpacked( qkv_dtype, cu_seqlens_padded, d_scale_qkv, + d_scale_qkv_offset, d_scale_s, + d_scale_s_offset, q_scale_s, + q_scale_s_offset, q_scale_o, + q_scale_o_offset, amax_s, + amax_s_offset, amax_o, + amax_o_offset, attn_bias, rng_gen, rng_elts_per_thread, @@ -448,11 +482,17 @@ def fused_attn_fwd_kvpacked( cu_seqlens_q_padded: torch.Tensor = None, cu_seqlens_kv_padded: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, + d_scale_qkv_offset: int = META_QKV, d_scale_s: torch.Tensor = None, + d_scale_s_offset: int = META_S, q_scale_s: torch.Tensor = None, + q_scale_s_offset: int = META_S, q_scale_o: torch.Tensor = None, + q_scale_o_offset: int = META_O, amax_s: torch.Tensor = None, + amax_s_offset: int = META_S, amax_o: torch.Tensor = None, + amax_o_offset: int = META_O, attn_scale: float = None, dropout: float = 0.0, fast_zero_fill: bool = True, @@ -496,16 +536,28 @@ def fused_attn_fwd_kvpacked( cumulative sequence offsets for KV; shape [batch_size + 1] d_scale_qkv: torch.Tensor, default = None input tensor for the dequantization of QKV in FP8 computations + d_scale_qkv_offset: int, default = META_QKV + offset in d_scale_qkv for QKV d_scale_s: torch.Tensor, default = None input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) + d_scale_s_offset: int, default = META_S + offset in d_scale_s for S q_scale_s: torch.Tensor, default = None input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) + q_scale_s_offset: int, default = META_S + offset in q_scale_s for S q_scale_o: torch.Tensor, default = None input tensor for the quantization of O in FP8 computations + q_scale_o_offset: int, default = META_O + offset in q_scale_o for O amax_s: torch.Tensor, default = None output tensor, amax of S, used by the next iteration in FP8 computations + amax_s_offset: int, default = META_S + offset in amax_s for S amax_o: torch.Tensor, default = None output tensor, amax of O, used by the next iteration in FP8 computations + amax_o_offset: int, default = META_O + offset in amax_o for O attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; if None, use 
1.0/sqrt(head_dim_qk) as the default @@ -621,11 +673,17 @@ def fused_attn_fwd_kvpacked( cu_seqlens_q_padded, cu_seqlens_kv_padded, d_scale_qkv, + d_scale_qkv_offset, d_scale_s, + d_scale_s_offset, q_scale_s, + q_scale_s_offset, q_scale_o, + q_scale_o_offset, amax_s, + amax_s_offset, amax_o, + amax_o_offset, attn_bias, rng_gen, rng_elts_per_thread, @@ -843,11 +901,17 @@ def fused_attn_fwd( cu_seqlens_q_padded: torch.Tensor = None, cu_seqlens_kv_padded: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, + d_scale_qkv_offset: int = META_QKV, d_scale_s: torch.Tensor = None, + d_scale_s_offset: int = META_S, q_scale_s: torch.Tensor = None, + q_scale_s_offset: int = META_S, q_scale_o: torch.Tensor = None, + q_scale_o_offset: int = META_O, amax_s: torch.Tensor = None, + amax_s_offset: int = META_S, amax_o: torch.Tensor = None, + amax_o_offset: int = META_O, attn_scale: float = None, dropout: float = 0.0, fast_zero_fill: bool = True, @@ -894,17 +958,29 @@ def fused_attn_fwd( cu_seqlens_kv_padded: torch.Tensor, default = None cumulative sequence offsets for KV; shape [batch_size + 1] d_scale_qkv: torch.Tensor, default = None - input tensor for the dequantization of Q, K and V in FP8 computations + input tensor for the dequantization of QKV in FP8 computations + d_scale_qkv_offset: int, default = META_QKV + offset in d_scale_qkv for QKV d_scale_s: torch.Tensor, default = None input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) + d_scale_s_offset: int, default = META_S + offset in d_scale_s for S q_scale_s: torch.Tensor, default = None input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) + q_scale_s_offset: int, default = META_S + offset in q_scale_s for S q_scale_o: torch.Tensor, default = None input tensor for the quantization of O in FP8 computations + q_scale_o_offset: int, default = META_O + offset in q_scale_o for O amax_s: torch.Tensor, default = None output tensor, amax of S, used by the next iteration in FP8 computations + amax_s_offset: int, default = META_S + offset in amax_s for S amax_o: torch.Tensor, default = None output tensor, amax of O, used by the next iteration in FP8 computations + amax_o_offset: int, default = META_O + offset in amax_o for O attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; if None, use 1.0/sqrt(head_dim_qk) as the default @@ -1023,11 +1099,17 @@ def fused_attn_fwd( cu_seqlens_q_padded, cu_seqlens_kv_padded, d_scale_qkv, + d_scale_qkv_offset, d_scale_s, + d_scale_s_offset, q_scale_s, + q_scale_s_offset, q_scale_o, + q_scale_o_offset, amax_s, + amax_s_offset, amax_o, + amax_o_offset, attn_bias, rng_gen, rng_elts_per_thread, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 1a6f5f157e..45ef9951d7 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -48,11 +48,13 @@ std::vector fused_attn_fwd_qkvpacked( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional scale_S, const c10::optional scale_O, - c10::optional amax_S, c10::optional amax_O, - const c10::optional Bias, const c10::optional rng_gen, - size_t rng_elts_per_thread); + const c10::optional descale_QKV, 
const int descale_QKV_offset, + const c10::optional descale_S, const int descale_S_offset, + const c10::optional scale_S, const int scale_S_offset, + const c10::optional scale_O, const int scale_O_offset, + c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, + const int amax_O_offset, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread); std::vector fused_attn_bwd_qkvpacked( size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, @@ -75,11 +77,13 @@ std::vector fused_attn_fwd_kvpacked( const at::Tensor KV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional scale_S, const c10::optional scale_O, - c10::optional amax_S, c10::optional amax_O, - const c10::optional Bias, const c10::optional rng_gen, - size_t rng_elts_per_thread); + const c10::optional descale_QKV, const int descale_QKV_offset, + const c10::optional descale_S, const int descale_S_offset, + const c10::optional scale_S, const int scale_S_offset, + const c10::optional scale_O, const int scale_O_offset, + c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, + const int amax_O_offset, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread); std::vector fused_attn_bwd_kvpacked( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, @@ -104,11 +108,13 @@ std::vector fused_attn_fwd( const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional scale_S, const c10::optional scale_O, - c10::optional amax_S, c10::optional amax_O, - const c10::optional Bias, const c10::optional rng_gen, - size_t rng_elts_per_thread); + const c10::optional descale_QKV, const int descale_QKV_offset, + const c10::optional descale_S, const int descale_S_offset, + const c10::optional scale_S, const int scale_S_offset, + const c10::optional scale_O, const int scale_O_offset, + c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, + const int amax_O_offset, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread); std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, @@ -335,13 +341,18 @@ at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, fl **************************************************************************************************/ at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, at::Tensor amax, - at::Tensor scale_inv, transformer_engine::DType otype); + at::Tensor scale_inv, transformer_engine::DType otype, + const int scale_offset = 0, const int amax_offset = 0, + const int scale_inv_offset = 0); void cast_to_fp8_noalloc(const at::Tensor &input, const at::Tensor &scale, at::Tensor output, - at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype); + at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, + const int scale_offset = 0, const int amax_offset = 0, + const int scale_inv_offset = 0); at::Tensor cast_from_fp8(const at::Tensor &input, const at::Tensor &scale_inv, - transformer_engine::DType itype, transformer_engine::DType otype); + 
transformer_engine::DType itype, transformer_engine::DType otype, + const int scale_inv_offset = 0); /*************************************************************************************************** * Softmax diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 50eb7b830f..fb1fc97a33 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -83,11 +83,13 @@ std::vector fused_attn_fwd_qkvpacked( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional scale_S, const c10::optional scale_O, - c10::optional amax_S, c10::optional amax_O, - const c10::optional Bias, const c10::optional rng_gen, - size_t rng_elts_per_thread) { + const c10::optional descale_QKV, const int descale_QKV_offset, + const c10::optional descale_S, const int descale_S_offset, + const c10::optional scale_S, const int scale_S_offset, + const c10::optional scale_O, const int scale_O_offset, + c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, + const int amax_O_offset, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread) { using namespace transformer_engine; auto qkv_sizes = QKV.sizes().vec(); @@ -122,11 +124,14 @@ std::vector fused_attn_fwd_qkvpacked( NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); } te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(), - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(), - scale_O.value().data_ptr(), nullptr); + getDataPtr(descale_QKV.value(), descale_QKV_offset)); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, + getDataPtr(amax_S.value(), amax_S_offset), + getDataPtr(scale_S.value(), scale_S_offset), + getDataPtr(descale_S.value(), descale_S_offset)); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, + getDataPtr(amax_O.value(), amax_O_offset), + getDataPtr(scale_O.value(), scale_O_offset), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { O.fill_(0); @@ -393,11 +398,13 @@ std::vector fused_attn_fwd_kvpacked( const at::Tensor KV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional scale_S, const c10::optional scale_O, - c10::optional amax_S, c10::optional amax_O, - const c10::optional Bias, const c10::optional rng_gen, - size_t rng_elts_per_thread) { + const c10::optional descale_QKV, const int descale_QKV_offset, + const c10::optional descale_S, const int descale_S_offset, + const c10::optional scale_S, const int scale_S_offset, + const c10::optional scale_O, const int scale_O_offset, + c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, + const int amax_O_offset, const c10::optional Bias, + const 
c10::optional rng_gen, size_t rng_elts_per_thread) { using namespace transformer_engine; auto q_sizes = Q.sizes().vec(); @@ -429,13 +436,16 @@ std::vector fused_attn_fwd_kvpacked( NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); } te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); + getDataPtr(descale_QKV.value(), descale_QKV_offset)); te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(), - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(), - scale_O.value().data_ptr(), nullptr); + getDataPtr(descale_QKV.value(), descale_QKV_offset)); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, + getDataPtr(amax_S.value(), amax_S_offset), + getDataPtr(scale_S.value(), scale_S_offset), + getDataPtr(descale_S.value(), descale_S_offset)); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, + getDataPtr(amax_O.value(), amax_O_offset), + getDataPtr(scale_O.value(), scale_O_offset), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { O.fill_(0); @@ -747,11 +757,13 @@ std::vector fused_attn_fwd( const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional scale_S, const c10::optional scale_O, - c10::optional amax_S, c10::optional amax_O, - const c10::optional Bias, const c10::optional rng_gen, - size_t rng_elts_per_thread) { + const c10::optional descale_QKV, const int descale_QKV_offset, + const c10::optional descale_S, const int descale_S_offset, + const c10::optional scale_S, const int scale_S_offset, + const c10::optional scale_O, const int scale_O_offset, + c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, + const int amax_O_offset, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread) { using namespace transformer_engine; auto q_sizes = Q.sizes().vec(); @@ -788,15 +800,18 @@ std::vector fused_attn_fwd( NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. 
\n")); } te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); + getDataPtr(descale_QKV.value(), descale_QKV_offset)); te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); + getDataPtr(descale_QKV.value(), descale_QKV_offset)); te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(), - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, amax_O.value().data_ptr(), - scale_O.value().data_ptr(), nullptr); + getDataPtr(descale_QKV.value(), descale_QKV_offset)); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, + getDataPtr(amax_S.value(), amax_S_offset), + getDataPtr(scale_S.value(), scale_S_offset), + getDataPtr(descale_S.value(), descale_S_offset)); + te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, + getDataPtr(amax_O.value(), amax_O_offset), + getDataPtr(scale_O.value(), scale_O_offset), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { O.fill_(0); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cu b/transformer_engine/pytorch/csrc/extensions/cast.cu index c783c9d988..47f5825866 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cu +++ b/transformer_engine/pytorch/csrc/extensions/cast.cu @@ -6,8 +6,9 @@ #include "extensions.h" -at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, at::Tensor amax, - at::Tensor scale_inv, transformer_engine::DType otype) { +at::Tensor cast_to_fp8(const at::Tensor& input, const at::Tensor& scale, at::Tensor amax, + at::Tensor scale_inv, transformer_engine::DType otype, + const int scale_offset, const int amax_offset, const int scale_inv_offset) { using namespace transformer_engine; auto input_shape = input.sizes().vec(); std::vector shape{input_shape.begin(), input_shape.end()}; @@ -16,32 +17,45 @@ at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, at::Ten if (input.numel() == 0) return output; + // Get pointers for FP8 scale, amax, scale-inverse + void* scale_dptr = getDataPtr(scale, scale_offset); + void* amax_dptr = getDataPtr(amax, amax_offset); + void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); + auto input_cu = makeTransformerEngineTensor(input); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype, amax_dptr, + scale_dptr, scale_inv_dptr); nvte_fp8_quantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); return output; } -void cast_to_fp8_noalloc(const at::Tensor &input, const at::Tensor &scale, at::Tensor output, - at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype) { +void cast_to_fp8_noalloc(const at::Tensor& input, const at::Tensor& scale, at::Tensor output, + at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, + const int scale_offset, const int amax_offset, + const int scale_inv_offset) { using namespace transformer_engine; size_t N = static_cast(input.size(0)); size_t H = static_cast(input.size(1)); + // Get pointers for FP8 scale, amax, 
scale-inverse + void* scale_dptr = getDataPtr(scale, scale_offset); + void* amax_dptr = getDataPtr(amax, amax_offset); + void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); + auto input_cu = makeTransformerEngineTensor(input); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype, amax_dptr, + scale_dptr, scale_inv_dptr); nvte_fp8_quantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); return; } -at::Tensor cast_from_fp8(const at::Tensor &input, const at::Tensor &scale_inv, - transformer_engine::DType itype, transformer_engine::DType otype) { +at::Tensor cast_from_fp8(const at::Tensor& input, const at::Tensor& scale_inv, + transformer_engine::DType itype, transformer_engine::DType otype, + const int scale_inv_offset) { using namespace transformer_engine; auto input_shape = input.sizes().vec(); std::vector shape{input_shape.begin(), input_shape.end()}; @@ -49,7 +63,7 @@ at::Tensor cast_from_fp8(const at::Tensor &input, const at::Tensor &scale_inv, auto output = at::empty_like(input, at::CUDA(GetATenDType(otype))); auto input_cu = makeTransformerEngineTensor(input.data_ptr(), shape, itype, nullptr, nullptr, - scale_inv.data_ptr()); + getDataPtr(scale_inv, scale_inv_offset)); auto output_cu = makeTransformerEngineTensor(output); nvte_fp8_dequantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index f903a1c35b..dc82b6e2df 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -93,10 +93,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fused_multi_cast_transpose_alloc", &fused_multi_cast_transpose_alloc, "Fused Multi-tensor Cast + Transpose with allocating output tensors", py::call_guard()); - m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8", py::call_guard()); + m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8", py::call_guard(), + py::arg("input"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), + py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8", - py::call_guard()); - m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8", py::call_guard()); + py::call_guard(), py::arg("input"), py::arg("scale"), + py::arg("output"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), + py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); + m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8", py::call_guard(), + py::arg("input"), py::arg("scale_inv"), py::arg("itype"), py::arg("otype"), + py::arg("scale_inv_offset") = 0); m.def("te_gemm", &te_gemm, "CublasLt GEMM"); /// TODO Think m.def("te_grouped_gemm", &te_grouped_gemm, "Grouped GEMM"); m.def("fused_attn_fwd_qkvpacked", &fused_attn_fwd_qkvpacked, diff --git a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp index 8515092ae0..8c480e8343 100644 --- a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp +++ b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp @@ -26,7 +26,7 @@ at::Tensor cast_to_fp8_ts(const at::Tensor &input, const at::Tensor &scale, at:: at::Tensor scale_inv, int64_t fp8_tensor, int64_t otype) { 
transformer_engine::DType otype_arg = reverse_map_dtype(otype); at::Tensor output = - cast_to_fp8(input, scale[fp8_tensor], amax[0][fp8_tensor], scale_inv[fp8_tensor], otype_arg); + cast_to_fp8(input, scale, amax, scale_inv, otype_arg, fp8_tensor, fp8_tensor, fp8_tensor); return output; } @@ -34,8 +34,8 @@ at::Tensor cast_to_fp8_noalloc_ts(const at::Tensor &input, const at::Tensor &sca at::Tensor output, at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor, int64_t otype) { transformer_engine::DType otype_arg = reverse_map_dtype(otype); - cast_to_fp8_noalloc(input, scale[fp8_tensor], output, amax[0][fp8_tensor], scale_inv[fp8_tensor], - otype_arg); + cast_to_fp8_noalloc(input, scale, output, amax, scale_inv, otype_arg, fp8_tensor, fp8_tensor, + fp8_tensor); return output; } @@ -43,7 +43,7 @@ at::Tensor cast_from_fp8_ts(const at::Tensor &input, const at::Tensor &scale_inv int64_t fp8_tensor, int64_t itype, int64_t otype) { transformer_engine::DType itype_arg = reverse_map_dtype(itype); transformer_engine::DType otype_arg = reverse_map_dtype(otype); - at::Tensor output = cast_from_fp8(input, scale_inv[fp8_tensor], itype_arg, otype_arg); + at::Tensor output = cast_from_fp8(input, scale_inv, itype_arg, otype_arg, fp8_tensor); return output; } diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index d6045d8e77..9586d6d345 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -91,6 +91,7 @@ def forward( ub_overlap_rs_dgrad: bool, ub_overlap_ag: bool, ub_name: str, + fp8_output: bool, fsdp_group: Union[dist_group_type, None], ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible @@ -220,7 +221,7 @@ def forward( if is_in_onnx_export_mode(): ln_out_scale_inv.fill_(ln_out_scale_inv.item()) - if fp8_meta["recipe"].fp8_mha: + if fp8_output: out_index, meta_tensor, output_te_dtype, output_dtype = ( tex.FP8FwdTensors.GEMM1_OUTPUT, fp8_meta["scaling_fwd"], @@ -765,6 +766,7 @@ def backward( None, # ub_overlap_rs_dgrad None, # ub_overlap_ag None, # ub_name + None, # fp8_output None, # fsdp_group ) @@ -1117,6 +1119,7 @@ def forward( self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None, + fp8_output: Optional[bool] = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply layer normalization to the input followed by a linear transformation. 
@@ -1244,6 +1247,7 @@ def forward( self.ub_overlap_rs_dgrad, self.ub_overlap_ag, self.ub_name, + fp8_output, self.fsdp_group, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 175e5ab5cf..f92a2db2d9 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -82,12 +82,10 @@ def forward( ub_overlap_rs: bool, ub_overlap_ag: bool, ub_name: str, - is_first_module_in_mha: bool, + fp8_output: bool, fsdp_group: Union[dist_group_type, None], ) -> torch.Tensor: is_input_fp8 = isinstance(inp, Float8Tensor) - if is_input_fp8: - fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT] = inp._scale_inv[0] # Make sure input dimensions are compatible in_features = weight.shape[-1] @@ -110,14 +108,6 @@ def forward( fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) if isinstance(inputmat, Float8Tensor): inputmat_scale_inv = inputmat._scale_inv - if ( - not fp8_meta["recipe"].override_linear_precision.wgrad - and is_grad_enabled - and weight.requires_grad - and not sequence_parallel - ): - # FP8 input for forward, FP8 input transpose for backward wgrad - inputmat_t = inputmat.transpose_2d() else: inputmat_scale_inv = torch.empty([1], dtype=torch.float32, device=inputmat.device) if ( @@ -171,7 +161,7 @@ def forward( assert isinstance(weight_fp8, Float8Tensor) - if is_first_module_in_mha: + if fp8_output: proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( tex.FP8FwdTensors.GEMM1_OUTPUT, fp8_meta["scaling_fwd"], @@ -240,7 +230,7 @@ def forward( fp8_meta_tensor=meta_tensor, D_dtype=proj_out_tetype, ) - if is_first_module_in_mha: + if fp8_output: out = Float8Tensor( data=out, fp8_meta=fp8_meta, @@ -639,7 +629,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, # ub_overlap_rs None, # ub_overlap_ag None, # ub_name - None, # is_first_module_in_mha + None, # fp8_output None, # fsdp_group ) @@ -917,7 +907,7 @@ def forward( self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None, - is_first_module_in_mha: Optional[bool] = False, + fp8_output: Optional[bool] = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply the linear transformation to the input. @@ -951,8 +941,6 @@ def forward( allow_non_contiguous=isinstance(inp, Float8Tensor), ) as inp: - is_first_module_in_mha = is_first_module_in_mha and self.fp8_meta["recipe"].fp8_mha - # Get concatenated weight and bias tensors unfused_weights = [getattr(self, name) for name in self.weight_names] if any(isinstance(w, Float8Tensor) for w in unfused_weights): @@ -1037,7 +1025,7 @@ def forward( self.ub_overlap_rs, self.ub_overlap_ag, self.ub_name, - is_first_module_in_mha, + fp8_output, self.fsdp_group, ) out = linear_fn(*args)