Commit ef2ff72

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Sep 26, 2024
1 parent a93256f commit ef2ff72
Showing 1 changed file with 5 additions and 5 deletions.
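
The changes below are formatting-only: the hooks inserted PEP 8 spacing around binary operators, leaving behavior untouched. A minimal before/after illustration in Python (the specific hook is not named in the commit; black is a formatter commonly run by pre-commit.ci, so treat that attribution as an assumption):

    # Formatting-only rewrite of the kind applied in this commit: PEP 8
    # spacing around binary operators. Both expressions index the same element.
    i = 3
    buffers = ["kv_even", "kv_odd"]

    dense = buffers[(i+1)%2]        # old style, as in the removed lines
    spaced = buffers[(i + 1) % 2]   # new style, as in the added lines
    assert dense == spaced          # identical behavior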
10 changes: 5 additions & 5 deletions transformer_engine/pytorch/attention.py
@@ -1576,12 +1576,12 @@ def forward(
                         req.wait()
 
                     if i < (cp_size - 1):
-                        p2p_comm_buffers[(i+1)%2] = torch.empty_like(p2p_comm_buffers[i%2])
+                        p2p_comm_buffers[(i + 1) % 2] = torch.empty_like(p2p_comm_buffers[i % 2])
                         send_recv_reqs[i % 2] = flash_attn_p2p_communicate(
                             rank,
-                            p2p_comm_buffers[i%2],
+                            p2p_comm_buffers[i % 2],
                             send_dst,
-                            p2p_comm_buffers[(i+1)%2],
+                            p2p_comm_buffers[(i + 1) % 2],
                             recv_src,
                             cp_group,
                             batch_p2p_comm,
@@ -1592,11 +1592,11 @@
                             or fp8_meta["recipe"].fp8_mha
                             or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
                         ):
-                            kv_inputs[i % 2] = p2p_comm_buffers[i%2]
+                            kv_inputs[i % 2] = p2p_comm_buffers[i % 2]
                         else:
                             # KV exchange is in BF16/FP16, cast received KV in each step
                             kv_inputs[i % 2] = cast_to_fp8(
-                                p2p_comm_buffers[i%2],
+                                p2p_comm_buffers[i % 2],
                                 fp8_meta["scaling_fwd"],
                                 META_QKV,
                                 fp8_dtype_forward,
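For context, the touched code is the double-buffered KV ring exchange in Transformer Engine's context-parallel attention: at step i a rank attends over buffer i % 2 while the next rank's chunk lands in buffer (i + 1) % 2, and the roles swap every step. A minimal sketch of that indexing pattern follows; the communication stubs are hypothetical placeholders, not the library's flash_attn_p2p_communicate API.

    # Sketch of the double-buffered ring exchange pattern seen in the diff.
    # Assumption: comments stand in for the real async send/recv calls.
    import torch

    def ring_steps(local_kv: torch.Tensor, cp_size: int):
        # Two buffers alternate roles each step: one is consumed by the
        # attention kernel, the other receives the next rank's KV chunk.
        buffers = [local_kv, torch.empty_like(local_kv)]
        for i in range(cp_size):
            if i < cp_size - 1:
                # Fresh landing buffer for the incoming chunk, mirroring the diff.
                buffers[(i + 1) % 2] = torch.empty_like(buffers[i % 2])
                # Hypothetical: post an async send of buffers[i % 2] to the next
                # rank and an async recv into buffers[(i + 1) % 2], overlapping
                # the wait with this step's attention compute.
            yield buffers[i % 2]  # KV chunk attended over at step i

After cp_size steps every rank has seen all KV chunks while only ever holding two buffers, which is why the indices in the diff alternate between i % 2 and (i + 1) % 2.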
