
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Sep 27, 2024
1 parent edbe898 commit b096051
Showing 3 changed files with 33 additions and 9 deletions.
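
A note on what "auto fixes" means here: the hooks rewrite the files in place, and the diffs below are pure formatting changes (spacing added around operators, long calls and asserts split across lines). A minimal illustration, assuming the repository's pre-commit hooks include black (the exact hook set is not shown on this page):

import black

src = 'cp_comm_sub_ranks = [range(i*2, (i+1)*2) for i in range(world_size//2)]\n'
print(black.format_str(src, mode=black.Mode()), end="")
# prints: cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
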
tests/pytorch/fused_attn/run_fused_attn_with_cp.py (7 changes: 5 additions & 2 deletions)
@@ -61,7 +61,7 @@ def run_dpa_with_cp(
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
if cp_comm_type == "a2a+p2p":
assert world_size % 2 == 0
- cp_comm_sub_ranks = [range(i*2, (i+1)*2) for i in range(world_size//2)]
+ cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
cp_comm_sub_groups = []
for sub_ranks in cp_comm_sub_ranks:
@@ -241,7 +241,10 @@ def run_dpa_with_cp(
bias_ = bias_.index_select(2, seq_idx)
bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
core_attn.set_context_parallel_group(
- cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type
+ cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
+ cp_comm_ranks,
+ torch.cuda.Stream(),
+ cp_comm_type,
)

if dtype == "fp8":
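For a concrete picture of the two-level layout built in the first hunk above, here is a small standalone sketch (illustration only, not part of the diff) of what cp_comm_sub_ranks evaluates to for world_size = 4:

# Illustration of the hierarchical ("a2a+p2p") rank layout from run_dpa_with_cp,
# evaluated for world_size = 4 (ranks 0..3).
world_size = 4
cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
print([list(r) for r in cp_comm_sub_ranks])
# [[0, 1], [2, 3], [0, 2], [1, 3]]
# The adjacent pairs come first, followed by the strided groups; each rank appears
# in exactly one group of each kind, which is presumably what the loop over
# cp_comm_sub_ranks collects into cp_comm_sub_groups.
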
tests/pytorch/fused_attn/test_fused_attn_with_cp.py (2 changes: 1 addition & 1 deletion)
@@ -37,7 +37,7 @@


def get_bash_arguments(num_gpus, **kwargs):
args = ["python", "-m", "torch.distributed.launch", "--nproc-per-node="+str(num_gpus)]
args = ["python", "-m", "torch.distributed.launch", "--nproc-per-node=" + str(num_gpus)]
te_path = os.getenv("TE_PATH", "/opt/transformerengine")
script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py")
args.append(script_path)
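For context on how this helper is consumed, a hypothetical usage sketch (the kwargs handling and the call site are assumptions; this diff only shows the first lines of the function):

# Hypothetical sketch: launching the CP test script with the argument list
# that get_bash_arguments() assembles. Only the first four lines of the body
# are taken from the diff above; the kwargs loop and the subprocess call are
# illustrative assumptions.
import os
import subprocess


def get_bash_arguments(num_gpus, **kwargs):
    args = ["python", "-m", "torch.distributed.launch", "--nproc-per-node=" + str(num_gpus)]
    te_path = os.getenv("TE_PATH", "/opt/transformerengine")
    script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py")
    args.append(script_path)
    for key, value in kwargs.items():
        args.append(f"{key}={value}")  # assumed: remaining kwargs become key=value CLI args
    return args


subprocess.run(get_bash_arguments(num_gpus=2, dtype="bf16"), check=True)
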
transformer_engine/pytorch/attention.py (33 changes: 27 additions & 6 deletions)
@@ -1445,8 +1445,13 @@ def forward(
softmax_scale = q.shape[-1] ** (-0.5)

if isinstance(cp_group, list):
- assert qkv_format != "thd", f"{qkv_format} format is not supported with hierarchical CP implementation yet!"
- assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported with hierarchical CP implementation yet!"
+ assert (
+ qkv_format != "thd"
+ ), f"{qkv_format} format is not supported with hierarchical CP implementation yet!"
+ assert attn_bias_type == "no_bias", (
+ f"{attn_bias_type} bias type is not supported with hierarchical CP implementation"
+ " yet!"
+ )
cp_group_a2a = cp_group[0]
cp_size_a2a = get_distributed_world_size(cp_group_a2a)
rank_a2a = get_distributed_rank(cp_group_a2a)
@@ -2385,7 +2390,13 @@ def backward(ctx, dout):
dout = dout.view(*out.shape)
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, True)
out, dout = flash_attn_a2a_communicate(
- [out, dout], chunk_ids_for_a2a, seq_dim, cp_size_a2a, ctx.cp_group_a2a, ctx.cp_stream, True
+ [out, dout],
+ chunk_ids_for_a2a,
+ seq_dim,
+ cp_size_a2a,
+ ctx.cp_group_a2a,
+ ctx.cp_stream,
+ True,
)
if not ctx.fp8 and ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha:
dout = cast_from_fp8(
@@ -2992,7 +3003,13 @@ def backward(ctx, dout):
if cp_size_a2a > 1:
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, False)
dq, dk, dv = flash_attn_a2a_communicate(
- [dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size_a2a, ctx.cp_group_a2a, ctx.cp_stream, False
+ [dq, dk, dv],
+ chunk_ids_for_a2a,
+ seq_dim,
+ cp_size_a2a,
+ ctx.cp_group_a2a,
+ ctx.cp_stream,
+ False,
)
if ctx.qkv_format == "bshd":
dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]]
@@ -4062,7 +4079,9 @@ def attn_forward_func_with_cp(
"""

if cp_comm_type == "a2a+p2p":
- assert isinstance(cp_group, list), "Hierarchical CP implementation needs multi-level CP groups!"
+ assert isinstance(
+ cp_group, list
+ ), "Hierarchical CP implementation needs multi-level CP groups!"
assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!"
if get_distributed_world_size(cp_group[0]) == 1:
cp_group = cp_group[1]
@@ -4071,7 +4090,9 @@
cp_group = cp_group[0]
cp_comm_type = "a2a"
else:
- assert isinstance(cp_group, dist_group_type), f"Unsupported process group for CP communication type {cp_comm_type}!"
+ assert isinstance(
+ cp_group, dist_group_type
+ ), f"Unsupported process group for CP communication type {cp_comm_type}!"

assert qkv_format in [
"bshd",
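Putting the attention.py hunks together: the hierarchical ("a2a+p2p") path expects cp_group to be a two-element list and falls back to a single-level scheme when either level spans only one rank. A minimal standalone sketch of that dispatch (the helper name and the "p2p" fallback branch are assumptions inferred from the visible context, not verbatim from this diff):

# Illustrative dispatch mirroring attn_forward_func_with_cp's group handling.
# This helper is a sketch, not TE's API; dist.get_world_size stands in for
# get_distributed_world_size.
import torch.distributed as dist


def normalize_cp_group(cp_group, cp_comm_type):
    if cp_comm_type == "a2a+p2p":
        assert isinstance(cp_group, list), "Hierarchical CP implementation needs multi-level CP groups!"
        assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!"
        if dist.get_world_size(group=cp_group[0]) == 1:
            # first (a2a) level is trivial -> assumed fallback to pure p2p
            return cp_group[1], "p2p"
        if dist.get_world_size(group=cp_group[1]) == 1:
            # second (p2p) level is trivial -> fall back to pure a2a (shown in the diff)
            return cp_group[0], "a2a"
        return cp_group, cp_comm_type
    assert isinstance(
        cp_group, dist.ProcessGroup
    ), f"Unsupported process group for CP communication type {cp_comm_type}!"
    return cp_group, cp_comm_type
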
