Commit

dout shape fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
xrennvidia committed Sep 26, 2024
1 parent 1655edc commit edbe898
Showing 1 changed file with 30 additions and 16 deletions.
46 changes: 30 additions & 16 deletions transformer_engine/pytorch/attention.py
@@ -2172,12 +2172,15 @@ def forward(
             out = flash_attn_a2a_communicate(
                 out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False
             )
-            if qkv_format == "bshd":
-                # [b*s, np, hn] -> [b, s, np, hn]
-                out = out.view(ctx.batch_size, -1, *out.shape[-2:])
-            elif qkv_format == "sbhd":
-                # [s*b, np, hn] -> [s, b, np, hn]
-                out = out.view(-1, ctx.batch_size, *out.shape[-2:])
+            if use_fused_attention:
+                if qkv_format == "bshd":
+                    # [b*s, np, hn] -> [b, s, np, hn]
+                    out = out.view(ctx.batch_size, -1, *out.shape[-2:])
+                elif qkv_format == "sbhd":
+                    # [s*b, np, hn] -> [s, b, np, hn]
+                    out = out.view(-1, ctx.batch_size, *out.shape[-2:])
+        elif not use_fused_attention:
+            out = out.view(-1, *out.shape[-2:])
 
         if fp8 and use_fused_attention:
             amax_cp_fwd = amax_per_step.amax(dim=1)
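In this hunk (and the two later forward hunks that receive the same gating), the post-all-to-all reshape now applies only to the fused-attention path, which uses a 4-D [b, s, np, hn] or [s, b, np, hn] output, while the flash-attention path keeps or is flattened back to the 3-D [b*s (or s*b), np, hn] layout. A standalone sketch of those views, with made-up sizes (b, s, np, hn below are illustrative, not taken from the repository):

    import torch

    b, s, np_, hn = 2, 8, 4, 16                               # illustrative sizes only
    flat_bs = torch.randn(b * s, np_, hn)                     # flash-attn style [b*s, np, hn]
    flat_sb = torch.randn(s * b, np_, hn)                     # flash-attn style [s*b, np, hn]

    fused_bshd = flat_bs.view(b, -1, *flat_bs.shape[-2:])     # -> [b, s, np, hn]
    fused_sbhd = flat_sb.view(-1, b, *flat_sb.shape[-2:])     # -> [s, b, np, hn]
    flash_out = fused_bshd.view(-1, *fused_bshd.shape[-2:])   # back to [b*s, np, hn]

    assert fused_bshd.shape == (b, s, np_, hn)
    assert fused_sbhd.shape == (s, b, np_, hn)
    assert flash_out.shape == (b * s, np_, hn)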
@@ -2377,6 +2380,9 @@ def backward(ctx, dout):
             fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
 
         if cp_size_a2a > 1:
+            if not ctx.use_fused_attention:
+                out = out.view(ctx.batch_size, -1, *out.shape[-2:])
+                dout = dout.view(*out.shape)
             chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, True)
             out, dout = flash_attn_a2a_communicate(
                 [out, dout], chunk_ids_for_a2a, seq_dim, cp_size_a2a, ctx.cp_group_a2a, ctx.cp_stream, True
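This backward hunk is the fix the commit title refers to: on the flash-attention path the saved out and the incoming dout are 3-D, so both are viewed to a 4-D layout with ctx.batch_size leading before the pair is handed to flash_attn_a2a_communicate, which is given a seq_dim to shuffle along. A minimal sketch of that shape alignment, with assumed sizes (batch_size, seq, np, hn are illustrative):

    import torch

    batch_size, seq, np_, hn = 2, 8, 4, 16               # illustrative sizes only
    out = torch.randn(batch_size * seq, np_, hn)         # saved flash-attn output, [b*s, np, hn]
    dout = torch.randn_like(out)                         # grad arrives with out's shape

    out = out.view(batch_size, -1, *out.shape[-2:])      # [b*s, np, hn] -> [b, s, np, hn]
    dout = dout.view(*out.shape)                         # mirror the same 4-D layout

    assert out.shape == dout.shape == (batch_size, seq, np_, hn)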
@@ -3257,10 +3263,13 @@ def forward(
 
         torch.cuda.current_stream().wait_stream(cp_stream)
 
-        if qkv_format == "bshd":
-            out = out.view(out.shape[0], -1, *out.shape[-2:])
-        elif qkv_format == "sbhd":
-            out = out.view(-1, *out.shape[-3:])
+        if use_fused_attention:
+            if qkv_format == "bshd":
+                out = out.view(out.shape[0], -1, *out.shape[-2:])
+            elif qkv_format == "sbhd":
+                out = out.view(-1, *out.shape[-3:])
+        else:
+            out = out.view(-1, *out.shape[-2:])
 
         ctx.save_for_backward(
             q,
@@ -3743,12 +3752,13 @@ def forward(
             out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False
         )
 
-        if qkv_format == "bshd":
-            # [b*s, np, hn] -> [b, s, np, hn]
-            out = out.view(batch_size, -1, *out.shape[-2:])
-        elif qkv_format == "sbhd":
-            # [s*b, np, hn] -> [s, b, np, hn]
-            out = out.view(-1, batch_size, *out.shape[-2:])
+        if use_fused_attention:
+            if qkv_format == "bshd":
+                # [b*s, np, hn] -> [b, s, np, hn]
+                out = out.view(batch_size, -1, *out.shape[-2:])
+            elif qkv_format == "sbhd":
+                # [s*b, np, hn] -> [s, b, np, hn]
+                out = out.view(-1, batch_size, *out.shape[-2:])
 
         if fp8:
             if fp8_meta["recipe"].fp8_mha:
@@ -3888,6 +3898,10 @@ def backward(ctx, dout):
             fused_attn_dqkv_dtype = TE_DType[dout.dtype]
             fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
 
+        if not ctx.use_fused_attention:
+            out = out.view(ctx.batch_size, -1, *out.shape[-2:])
+            dout = dout.view(*out.shape)
+
         chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, True)
         out, dout = flash_attn_a2a_communicate(
             [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True
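The contract behind the "dout shape fix" as a whole: torch.autograd hands backward() a gradient shaped like whatever forward() returned, so once the forward hunks stop reshaping the flash-attention output to 4-D, the backward hunks have to view out and dout back to 4-D themselves before the all-to-all. A self-contained illustration of that contract; _FlattenSeq below is a made-up example, not part of Transformer Engine:

    import torch

    class _FlattenSeq(torch.autograd.Function):
        """Toy function that returns a flattened copy of its 4-D input."""

        @staticmethod
        def forward(ctx, x):
            ctx.in_shape = x.shape
            # [b, s, np, hn] -> [b*s, np, hn]; clone so the output owns its storage
            return x.reshape(-1, *x.shape[-2:]).clone()

        @staticmethod
        def backward(ctx, dout):
            # dout arrives shaped like the forward output ([b*s, np, hn]),
            # so it must be viewed back to the input layout.
            return dout.view(ctx.in_shape)

    x = torch.randn(2, 8, 4, 16, requires_grad=True)   # illustrative [b, s, np, hn]
    y = _FlattenSeq.apply(x)                           # [b*s, np, hn]
    y.backward(torch.randn_like(y))                    # gradient must match y's shape
    assert x.grad.shape == x.shape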
