diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index a9d26ed0e1..4a69a5be80 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -1576,12 +1576,12 @@ def forward(
                         req.wait()
 
                     if i < (cp_size - 1):
-                        p2p_comm_buffers[(i+1)%2] = torch.empty_like(p2p_comm_buffers[i%2])
+                        p2p_comm_buffers[(i + 1) % 2] = torch.empty_like(p2p_comm_buffers[i % 2])
                         send_recv_reqs[i % 2] = flash_attn_p2p_communicate(
                             rank,
-                            p2p_comm_buffers[i%2],
+                            p2p_comm_buffers[i % 2],
                             send_dst,
-                            p2p_comm_buffers[(i+1)%2],
+                            p2p_comm_buffers[(i + 1) % 2],
                             recv_src,
                             cp_group,
                             batch_p2p_comm,
@@ -1592,11 +1592,11 @@ def forward(
                         or fp8_meta["recipe"].fp8_mha
                         or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
                     ):
-                        kv_inputs[i % 2] = p2p_comm_buffers[i%2]
+                        kv_inputs[i % 2] = p2p_comm_buffers[i % 2]
                     else:
                         # KV exchange is in BF16/FP16, cast received KV in each step
                         kv_inputs[i % 2] = cast_to_fp8(
-                            p2p_comm_buffers[i%2],
+                            p2p_comm_buffers[i % 2],
                             fp8_meta["scaling_fwd"],
                             META_QKV,
                             fp8_dtype_forward,
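
The diff above only adds whitespace around operators; the underlying logic is the double-buffered ("ping-pong") KV ring exchange used for context parallelism, where buffer `i % 2` is consumed at step `i` while buffer `(i + 1) % 2` receives the next rank's chunk. The sketch below is illustrative only and not TransformerEngine code: `cp_size`, `chunks`, `simulate_rank`, and the ring direction are assumptions chosen to keep it self-contained, and the P2P send/recv is simulated with plain list reads.

```python
# Minimal sketch of the alternating-buffer indexing (assumed ring direction,
# communication simulated in-process; not the library's implementation).

cp_size = 4  # number of context-parallel ranks (assumed)
chunks = [f"kv_from_rank_{r}" for r in range(cp_size)]

def simulate_rank(rank):
    # Two communication buffers, alternated every step ("ping-pong").
    p2p_comm_buffers = [chunks[rank], None]
    received = []
    for i in range(cp_size):
        if i < (cp_size - 1):
            # Post the next "receive" into the other buffer while the
            # current buffer is still being consumed by attention.
            src = (rank + i + 1) % cp_size
            p2p_comm_buffers[(i + 1) % 2] = chunks[src]
        # Consume the chunk that arrived for this step.
        received.append(p2p_comm_buffers[i % 2])
    return received

print(simulate_rank(0))  # every rank eventually sees all cp_size KV chunks
```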