Skip to content

Commit

Permalink
Improve CP P2P efficiency
Browse files Browse the repository at this point in the history
Signed-off-by: Yen-Chen Lin <yenchenl@nvidia.com>
  • Loading branch information
Yen-Chen Lin committed Sep 26, 2024
1 parent 209b8e5 commit 4906662
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions transformer_engine/pytorch/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1561,7 +1561,7 @@ def forward(
fused_attn_qkv_dtype = TE_DType[q.dtype]
fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]

p2p_comm_buffers = [None for _ in range(cp_size)]
p2p_comm_buffers = [None, None]
if use_fused_attention and qkv_format in ["bshd", "sbhd"]:
p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3)
else:
Expand All @@ -1576,12 +1576,12 @@ def forward(
req.wait()

if i < (cp_size - 1):
p2p_comm_buffers[i + 1] = torch.empty_like(p2p_comm_buffers[i])
p2p_comm_buffers[(i + 1) % 2] = torch.empty_like(p2p_comm_buffers[i % 2])
send_recv_reqs[i % 2] = flash_attn_p2p_communicate(
rank,
p2p_comm_buffers[i],
p2p_comm_buffers[i % 2],
send_dst,
p2p_comm_buffers[i + 1],
p2p_comm_buffers[(i + 1) % 2],
recv_src,
cp_group,
batch_p2p_comm,
Expand All @@ -1592,11 +1592,11 @@ def forward(
or fp8_meta["recipe"].fp8_mha
or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
):
kv_inputs[i % 2] = p2p_comm_buffers[i]
kv_inputs[i % 2] = p2p_comm_buffers[i % 2]
else:
# KV exchange is in BF16/FP16, cast received KV in each step
kv_inputs[i % 2] = cast_to_fp8(
p2p_comm_buffers[i],
p2p_comm_buffers[i % 2],
fp8_meta["scaling_fwd"],
META_QKV,
fp8_dtype_forward,
Expand Down

0 comments on commit 4906662

Please sign in to comment.