diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 192f430ae1..4a69a5be80 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -1561,7 +1561,7 @@ def forward(
                 fused_attn_qkv_dtype = TE_DType[q.dtype]
                 fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
 
-        p2p_comm_buffers = [None for _ in range(cp_size)]
+        p2p_comm_buffers = [None, None]
         if use_fused_attention and qkv_format in ["bshd", "sbhd"]:
             p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3)
         else:
@@ -1576,12 +1576,12 @@ def forward(
                         req.wait()
 
                     if i < (cp_size - 1):
-                        p2p_comm_buffers[i + 1] = torch.empty_like(p2p_comm_buffers[i])
+                        p2p_comm_buffers[(i + 1) % 2] = torch.empty_like(p2p_comm_buffers[i % 2])
                         send_recv_reqs[i % 2] = flash_attn_p2p_communicate(
                             rank,
-                            p2p_comm_buffers[i],
+                            p2p_comm_buffers[i % 2],
                             send_dst,
-                            p2p_comm_buffers[i + 1],
+                            p2p_comm_buffers[(i + 1) % 2],
                             recv_src,
                             cp_group,
                             batch_p2p_comm,
@@ -1592,11 +1592,11 @@ def forward(
                         or fp8_meta["recipe"].fp8_mha
                         or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
                     ):
-                        kv_inputs[i % 2] = p2p_comm_buffers[i]
+                        kv_inputs[i % 2] = p2p_comm_buffers[i % 2]
                     else:
                         # KV exchange is in BF16/FP16, cast received KV in each step
                         kv_inputs[i % 2] = cast_to_fp8(
-                            p2p_comm_buffers[i],
+                            p2p_comm_buffers[i % 2],
                             fp8_meta["scaling_fwd"],
                             META_QKV,
                             fp8_dtype_forward,
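
For context (not part of the patch): the change replaces the per-step list of `cp_size` KV buffers with a two-slot double buffer. At ring step `i`, the KV chunk currently held (and sent to the next rank) lives in slot `i % 2`, and the chunk arriving from the previous rank is written into slot `(i + 1) % 2`, so only two buffers are ever live regardless of `cp_size`. Below is a minimal, self-contained sketch of that indexing pattern; the ring exchange is simulated in-process (no `torch.distributed`), and all names are illustrative, not TransformerEngine's API.

```python
# Minimal sketch of the two-slot (double-buffer) ring-exchange indexing used by
# the patch. The "communication" is a plain Python stand-in; names and structure
# are illustrative only, not TransformerEngine's actual implementation.


def simulate_recv(prev_chunk):
    # Stand-in for the async P2P recv: pretend the neighbor's chunk is prev_chunk + 1.
    return prev_chunk + 1


def ring_exchange_double_buffer(local_kv, cp_size):
    """Yield the KV chunk visible at each ring step, using only two buffer slots."""
    # Slot 0 starts with this rank's own KV chunk; slot 1 is filled on demand.
    p2p_comm_buffers = [local_kv, None]
    for i in range(cp_size):
        cur_kv = p2p_comm_buffers[i % 2]  # chunk attended to (and forwarded) at step i
        if i < cp_size - 1:
            # "Receive" the next chunk into the other slot, mirroring
            # p2p_comm_buffers[(i + 1) % 2] in the patched code.
            p2p_comm_buffers[(i + 1) % 2] = simulate_recv(cur_kv)
        yield i, cur_kv


if __name__ == "__main__":
    cp_size = 4
    for step, kv in ring_exchange_double_buffer(local_kv=0, cp_size=cp_size):
        print(f"step {step}: attending to KV chunk {kv}")
    # Only two buffer slots are ever allocated, regardless of cp_size.
```

The apparent intent of the change is memory: the old code kept every `torch.empty_like` allocation alive in a `cp_size`-long list for the whole loop, whereas alternating between two slots caps the live P2P buffers at two, independent of the context-parallel world size.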