Commit

[PyTorch] Fix multiple calls to saved_tensors in CP attention (#1334)
* Limit to one call of ctx.saved_tensors per autograd bwd

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
ksivaman and pre-commit-ci[bot] authored Nov 14, 2024
1 parent 28aa41a commit d1488e7
Showing 1 changed file with 18 additions and 17 deletions.
35 changes: 18 additions & 17 deletions transformer_engine/pytorch/attention.py
@@ -2528,12 +2528,13 @@ def backward(ctx, dout):
recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)

- (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6]
- (fp8_fwd_scales, fp8_fwd_scale_invs) = ctx.saved_tensors[6:8]
- cu_seqlens_q_per_step = ctx.saved_tensors[8 : 8 + cp_size]
- cu_seqlens_kv_per_step = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2]
- rng_states = ctx.saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3]
- attn_biases = ctx.saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4]
+ (*saved_tensors,) = ctx.saved_tensors
+ (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = saved_tensors[:6]
+ (fp8_fwd_scales, fp8_fwd_scale_invs) = saved_tensors[6:8]
+ cu_seqlens_q_per_step = saved_tensors[8 : 8 + cp_size]
+ cu_seqlens_kv_per_step = saved_tensors[8 + cp_size : 8 + cp_size * 2]
+ rng_states = saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3]
+ attn_biases = saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4]

causal = "causal" in ctx.attn_mask_type
padding = "padding" in ctx.attn_mask_type
@@ -3577,11 +3578,12 @@ def backward(ctx, dout):
cp_size = get_distributed_world_size(ctx.cp_group)
rank = get_distributed_rank(ctx.cp_group)

- (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = ctx.saved_tensors[:5]
- cu_seqlens_kv_per_step = ctx.saved_tensors[5:7]
- out_per_step = ctx.saved_tensors[7:9]
- softmax_lse_per_step = ctx.saved_tensors[9:11]
- rng_states = ctx.saved_tensors[11:13]
+ (*saved_tensors,) = ctx.saved_tensors
+ (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = saved_tensors[:5]
+ cu_seqlens_kv_per_step = saved_tensors[5:7]
+ out_per_step = saved_tensors[7:9]
+ softmax_lse_per_step = saved_tensors[9:11]
+ rng_states = saved_tensors[11:13]
kv_seq_range_per_step = ctx.kv_seq_range_per_step
window_size_per_step = ctx.window_size_per_step

@@ -4056,12 +4058,11 @@ def backward(ctx, dout):
# pylint: disable=missing-function-docstring
cp_size = get_distributed_world_size(ctx.cp_group)

- q, k, v, out = ctx.saved_tensors[:4]
- cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = ctx.saved_tensors[
-     4:8
- ]
- fp8_fwd_scales, fp8_fwd_scale_invs = ctx.saved_tensors[8:10]
- aux_ctx_tensors = ctx.saved_tensors[10:]
+ (*saved_tensors,) = ctx.saved_tensors
+ q, k, v, out = saved_tensors[:4]
+ cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = saved_tensors[4:8]
+ fp8_fwd_scales, fp8_fwd_scale_invs = saved_tensors[8:10]
+ aux_ctx_tensors = saved_tensors[10:]

qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
causal = "causal" in ctx.attn_mask_type
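
The diff applies the same change in three backward implementations: read ctx.saved_tensors once into a local tuple and slice that tuple, rather than indexing ctx.saved_tensors repeatedly. Each access to ctx.saved_tensors goes through PyTorch's saved-tensor unpacking (including any torch.autograd.graph.saved_tensors_hooks, e.g. for activation offloading), so limiting backward to a single access can avoid repeating that work. Below is a minimal sketch of the pattern; ScaledMul and its tensors are hypothetical and only illustrate the single-access idiom, they are not Transformer Engine code.

import torch

class ScaledMul(torch.autograd.Function):
    """Toy autograd function (hypothetical, for illustration only)."""

    @staticmethod
    def forward(ctx, a, b):
        # Save both inputs; backward needs them to form the gradients.
        ctx.save_for_backward(a, b)
        return a * b

    @staticmethod
    def backward(ctx, grad_out):
        # Access ctx.saved_tensors exactly once, then slice the local tuple,
        # mirroring the change this commit makes in the CP attention backwards.
        saved_tensors = ctx.saved_tensors
        a, b = saved_tensors[:2]
        return grad_out * b, grad_out * a

# Usage sketch: one backward pass exercises the single-access pattern.
x = torch.randn(4, requires_grad=True)
y = torch.randn(4, requires_grad=True)
ScaledMul.apply(x, y).sum().backward()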