Hierarchical CP implementation (Ulysses + Ring) #1209

Open

wants to merge 24 commits into main from xren/cp_a2a_p2p

Changes from all commits

Commits (24)
2ee1389
change API for hierarchical CP
xrennvidia Sep 19, 2024
3743c5f
Merge branch 'main' into xren/cp_a2a_p2p
xrennvidia Sep 19, 2024
0e826e7
move fp8 code before qkv reshape
xrennvidia Sep 20, 2024
44db43b
try to insert A2A for hierarchical CP
xrennvidia Sep 23, 2024
e717b81
Merge branch 'main' into xren/cp_a2a_p2p
xrennvidia Sep 23, 2024
705da7e
make fwd work
xrennvidia Sep 24, 2024
99ace46
remove a redundant sync
xrennvidia Sep 24, 2024
f73b2fb
make bwd of hierarchical CP work
xrennvidia Sep 24, 2024
0be5dd7
fix dout a2a in bwd
xrennvidia Sep 24, 2024
a6833c7
fix q_f16 with fp8
xrennvidia Sep 24, 2024
caf3746
assert hierarchical CP implementation does not support THD format
xrennvidia Sep 24, 2024
03806c4
bug fix
xrennvidia Sep 25, 2024
d3d336a
Merge branch 'main' into xren/cp_a2a_p2p
xrennvidia Sep 25, 2024
ea1e4a3
assert hierarchical CP does not support attn bias
xrennvidia Sep 25, 2024
3cd21ab
add unit test for hierarchical CP
xrennvidia Sep 25, 2024
90c0bb8
fix cp_comm_type in unit test
xrennvidia Sep 25, 2024
8c6139d
bug fix and code cleaning
xrennvidia Sep 25, 2024
52c7ae3
Merge branch 'main' into xren/cp_a2a_p2p
xrennvidia Sep 25, 2024
1db4e8b
minor change
xrennvidia Sep 25, 2024
1655edc
an assert info change
xrennvidia Sep 25, 2024
edbe898
dout shape fix
xrennvidia Sep 26, 2024
b096051
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2024
6063bec
Merge branch 'main' into xren/cp_a2a_p2p
xrennvidia Sep 27, 2024
67c590d
move function definitions to the front of the first call
xrennvidia Sep 28, 2024
21 changes: 13 additions & 8 deletions tests/pytorch/fused_attn/run_fused_attn_with_cp.py
@@ -59,6 +59,15 @@ def run_dpa_with_cp(
     cp_comm_ranks = range(world_size)
     assert rank in cp_comm_ranks
     cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
+    if cp_comm_type == "a2a+p2p":
+        assert world_size % 2 == 0
+        cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
+        cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
+        cp_comm_sub_groups = []
+        for sub_ranks in cp_comm_sub_ranks:
+            sub_group = dist.new_group(sub_ranks, backend="nccl")
+            if rank in sub_ranks:
+                cp_comm_sub_groups.append(sub_group)
 
     if dtype == "fp8":
         fp8_recipe = DelayedScaling(fp8_dpa=True)
@@ -167,13 +176,6 @@ def run_dpa_with_cp(
     else:
         bias = None
 
-    # make sure all GPU ranks have same inputs
-    for x in [q, k, v, dout] + ([] if bias is None else [bias]):
-        dist.broadcast(x, 0, group=cp_comm_group)
-    if qkv_format == "thd":
-        for x in [cu_seqlens_q, cu_seqlens_q_padded, cu_seqlens_kv, cu_seqlens_kv_padded]:
-            dist.broadcast(x, 0, group=cp_comm_group)
-
     # run core_attn without CP
     for x in [q, k, v]:
         x.requires_grad = True
@@ -239,7 +241,10 @@ def run_dpa_with_cp(
         bias_ = bias_.index_select(2, seq_idx)
         bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
     core_attn.set_context_parallel_group(
-        cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type
+        cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
+        cp_comm_ranks,
+        torch.cuda.Stream(),
+        cp_comm_type,
     )
 
     if dtype == "fp8":
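To make the new layout concrete: for the hierarchical "a2a+p2p" type, the diff above builds two kinds of sub-groups on top of the flat CP group and passes the per-rank pair of sub-groups to set_context_parallel_group instead of the single group. The standalone sketch below mirrors that grouping logic with the inner group size of 2 that the test hard-codes; the helper name and the print-out are illustrative only, and the mapping of consecutive ranks to the A2A (Ulysses) level and strided ranks to the P2P (ring) level is inferred from the PR title rather than shown in this diff.

# Minimal sketch of the "a2a+p2p" rank layout built by the updated test.
# The helper name is illustrative; the grouping mirrors the diff above,
# with the inner (A2A) group size hard-coded to 2 as in run_fused_attn_with_cp.py.

def hierarchical_cp_sub_ranks(world_size: int, a2a_size: int = 2):
    """Return rank lists: first the consecutive A2A groups, then the strided P2P groups."""
    assert world_size % a2a_size == 0
    # Consecutive ranks: all-to-all (Ulysses-style) exchange within the inner groups.
    sub_ranks = [list(range(i * a2a_size, (i + 1) * a2a_size)) for i in range(world_size // a2a_size)]
    # Strided ranks: ring-style point-to-point exchange across the outer groups.
    sub_ranks += [list(range(i, world_size, a2a_size)) for i in range(a2a_size)]
    return sub_ranks


if __name__ == "__main__":
    # With the 4 GPUs the test launches for "a2a+p2p":
    # A2A groups [0, 1] and [2, 3]; P2P groups [0, 2] and [1, 3].
    # Rank 1, for example, keeps [0, 1] and [1, 3], and those two process
    # groups are what the test passes to core_attn.set_context_parallel_group.
    for ranks in hierarchical_cp_sub_ranks(4):
        print(ranks)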
24 changes: 13 additions & 11 deletions tests/pytorch/fused_attn/test_fused_attn_with_cp.py
@@ -36,8 +36,8 @@
 }
 
 
-def get_bash_arguments(**kwargs):
-    args = ["python", "-m", "torch.distributed.launch", "--nproc-per-node=2"]
+def get_bash_arguments(num_gpus, **kwargs):
+    args = ["python", "-m", "torch.distributed.launch", "--nproc-per-node=" + str(num_gpus)]
     te_path = os.getenv("TE_PATH", "/opt/transformerengine")
     script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py")
     args.append(script_path)
@@ -51,27 +51,28 @@ def get_bash_arguments(**kwargs):
 @pytest.mark.parametrize("dtype", ["bf16", "fp16"])
 @pytest.mark.parametrize("model", model_configs_flash_attn.keys())
 @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
-@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"])
+@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
 def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
     config = model_configs_flash_attn[model]
-    if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
+    if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
         pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
     if cp_comm_type == "all_gather" and qkv_format == "thd":
         pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
     if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
         pytest.skip("CP implementation with KV all-gather does not support bias yet!")
-    if cp_comm_type == "a2a" and qkv_format == "thd":
+    if "a2a" in cp_comm_type and qkv_format == "thd":
         pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
-    if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias":
+    if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias":
         pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
-    if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
+    if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
         pytest.skip(
             f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
             f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
         )
 
     subprocess.run(
         get_bash_arguments(
+            num_gpus=4 if cp_comm_type == "a2a+p2p" else 2,
             dtype=dtype,
             model=model,
             qkv_format=qkv_format,
@@ -106,7 +107,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
 @pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"])
 @pytest.mark.parametrize("model", model_configs_fused_attn.keys())
 @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
-@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"])
+@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
 def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
     if qkv_format == "thd" and get_device_compute_capability() < (9, 0):
         pytest.skip("THD format is only supported on sm90+!")
@@ -122,7 +123,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
         pytest.skip("THD format does not support post_scale_bias yet!")
     if qkv_format == "thd" and cp_comm_type == "all_gather":
         pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
-    if qkv_format == "thd" and cp_comm_type == "a2a":
+    if qkv_format == "thd" and "a2a" in cp_comm_type:
         pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
     if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type != "a2a":
         pytest.skip(
@@ -140,16 +141,17 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
         pytest.skip("FP8 attention cannot work with sliding window yet!")
     if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
         pytest.skip("CP implementation with KV all-gather does not support bias yet!")
-    if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias":
+    if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias":
         pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
-    if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
+    if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
         pytest.skip(
             f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
             f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
         )
 
     subprocess.run(
         get_bash_arguments(
+            num_gpus=4 if cp_comm_type == "a2a+p2p" else 2,
             dtype=dtype,
             model=model,
             qkv_format=qkv_format,
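The launcher change in this file is small but necessary: a hierarchical run needs both an A2A group and a P2P group per rank, so the test asks torch.distributed.launch for 4 processes when cp_comm_type is "a2a+p2p" and keeps 2 otherwise. A rough sketch of that selection follows; the function name, the kwarg-to-flag serialization, and the final argument list are assumptions, since only the nproc-per-node handling and the script path setup appear in the diff above.

# Rough sketch of how the updated launcher picks the GPU count per cp_comm_type.
# Only the nproc-per-node handling appears in the diff; the kwarg-to-flag
# serialization and the helper name here are assumptions, not the PR's code.
import os


def build_launch_args(cp_comm_type: str, **kwargs) -> list:
    # "a2a+p2p" nests A2A groups of 2 inside P2P groups of 2, so it needs 4 ranks.
    num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2
    args = ["python", "-m", "torch.distributed.launch", f"--nproc-per-node={num_gpus}"]
    te_path = os.getenv("TE_PATH", "/opt/transformerengine")
    args.append(os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py"))
    for key, value in {**kwargs, "cp_comm_type": cp_comm_type}.items():
        args.append(f"--{key}={value}")  # assumed flag format
    return args


if __name__ == "__main__":
    print(build_launch_args("a2a+p2p", dtype="bf16", qkv_format="bshd"))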