Merge branch 'main' into add_padding_for_unfused
cyanguwa authored Aug 16, 2024
2 parents 4d8dc92 + 4edcff5 commit 5d46ea9
Showing 13 changed files with 1,386 additions and 285 deletions.
41 changes: 34 additions & 7 deletions build_tools/utils.py
@@ -14,7 +14,7 @@
import importlib
from pathlib import Path
from subprocess import CalledProcessError
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union


@functools.lru_cache(maxsize=None)
@@ -254,12 +254,39 @@ def get_frameworks() -> List[str]:
return _frameworks


def copy_common_headers(te_src, dst):
headers = te_src / "common"
for file_path in glob.glob(os.path.join(str(headers), "**", "*.h"), recursive=True):
new_path = os.path.join(dst, file_path[len(str(te_src)) + 1 :])
Path(new_path).parent.mkdir(exist_ok=True, parents=True)
shutil.copy(file_path, new_path)
def copy_common_headers(
src_dir: Union[Path, str],
dst_dir: Union[Path, str],
) -> None:
"""Copy headers from core library
src_dir should be the transformer_engine directory within the root
Transformer Engine repository. All .h and .cuh files within
transformer_engine/common are copied into dst_dir. Relative paths
are preserved.
"""

# Find common header files in src dir
headers = glob.glob(
os.path.join(str(src_dir), "common", "**", "*.h"),
recursive=True,
)
headers.extend(
glob.glob(
os.path.join(str(src_dir), "common", "**", "*.cuh"),
recursive=True,
)
)
headers = [Path(path) for path in headers]

# Copy common header files to dst dir
src_dir = Path(src_dir)
dst_dir = Path(dst_dir)
for path in headers:
new_path = dst_dir / path.relative_to(src_dir)
new_path.parent.mkdir(exist_ok=True, parents=True)
shutil.copy(path, new_path)


def install_and_import(package):
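A note on the refactored helper above: copy_common_headers now accepts Path or str arguments, picks up both *.h and *.cuh files, and preserves paths relative to src_dir. A minimal usage sketch, assuming build_tools is importable from the repository root; the staging paths here are illustrative, not taken from the build scripts:

from pathlib import Path

from build_tools.utils import copy_common_headers

# Illustrative paths: the repository's transformer_engine/ source tree and a
# temporary staging directory used while assembling a framework extension.
repo_root = Path(".").resolve()
src_dir = repo_root / "transformer_engine"
dst_dir = repo_root / "build" / "include_staging"

# Copies every *.h and *.cuh under transformer_engine/common into
# build/include_staging/common/..., preserving the relative layout.
copy_common_headers(src_dir, dst_dir)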
68 changes: 41 additions & 27 deletions tests/pytorch/fused_attn/run_fused_attn_with_cp.py
@@ -13,7 +13,9 @@
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16}


def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention"):
def run_dpa_with_cp(
dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention", cp_comm_type="p2p"
):
"""Test DotProductAttention module with context parallelism"""

os.environ["NVTE_FLASH_ATTN"] = "0"
@@ -24,10 +26,16 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
if kernel_backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
config = model_configs_fused_attn[model]
if qkv_format == "thd" and (
config.num_heads != config.num_gqa_groups or config.attn_bias_type == "post_scale_bias"
):
return

assert config.attn_mask_type in [
"causal",
"no_mask",
], f"{config.attn_mask_type} is an unsupported attention mask type!"
if kernel_backend == "FusedAttention" and qkv_format == "thd":
if "causal" in config.attn_mask_type:
config.attn_mask_type = "padding_causal"
else:
config.attn_mask_type = "padding"

rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
@@ -49,73 +57,77 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
assert rank in cp_comm_ranks
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")

assert config.attn_mask_type in [
"causal",
"no_mask",
], f"{config.attn_mask_type} is an unsupported attention mask type!"

if kernel_backend == "FusedAttention" and qkv_format == "thd":
if "causal" in config.attn_mask_type:
config.attn_mask_type = "padding_causal"
else:
config.attn_mask_type = "padding"

# instantiate core attn module
core_attn = DotProductAttention(
config.num_heads,
config.head_dim,
config.head_dim_qk,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=config.attn_mask_type,
window_size=config.window_size,
)
core_attn = core_attn.cuda()

# create flash attn inputs
if qkv_format == "bshd":
q_input_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim)
q_input_shape = (
config.batch_size,
config.max_seqlen_q,
config.num_heads,
config.head_dim_qk,
)
kv_input_shape = (
config.batch_size,
config.max_seqlen_kv,
config.num_gqa_groups,
config.head_dim,
config.head_dim_qk,
)
attn_output_shape = (
config.batch_size,
config.max_seqlen_q,
config.num_heads * config.head_dim,
config.num_heads * config.head_dim_qk,
)
cu_seqlens_q = None
cu_seqlens_kv = None
cu_seqlens_q_padded = None
cu_seqlens_kv_padded = None
elif qkv_format == "sbhd":
q_input_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim)
q_input_shape = (
config.max_seqlen_q,
config.batch_size,
config.num_heads,
config.head_dim_qk,
)
kv_input_shape = (
config.max_seqlen_kv,
config.batch_size,
config.num_gqa_groups,
config.head_dim,
config.head_dim_qk,
)
attn_output_shape = (
config.max_seqlen_q,
config.batch_size,
config.num_heads * config.head_dim,
config.num_heads * config.head_dim_qk,
)
cu_seqlens_q = None
cu_seqlens_kv = None
cu_seqlens_q_padded = None
cu_seqlens_kv_padded = None
elif qkv_format == "thd":
q_input_shape = (config.batch_size * config.max_seqlen_q, config.num_heads, config.head_dim)
q_input_shape = (
config.batch_size * config.max_seqlen_q,
config.num_heads,
config.head_dim_qk,
)
kv_input_shape = (
config.batch_size * config.max_seqlen_q,
config.num_gqa_groups,
config.head_dim,
config.head_dim_qk,
)
attn_output_shape = (
config.batch_size * config.max_seqlen_q,
config.num_heads * config.head_dim,
config.num_heads * config.head_dim_qk,
)
seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)
@@ -211,7 +223,9 @@ def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend=
)
bias_ = bias_.index_select(2, seq_idx)
bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream())
core_attn.set_context_parallel_group(
cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type
)
out_ = core_attn(
q_,
k_,
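A worked example for the "thd" branch above: per-sequence lengths are rounded up to a multiple of 2 * world_size via the ceil-to-multiple formula (x + m - 1) // m * m, which is what lets each sequence be split evenly into 2 * world_size chunks across context-parallel ranks. The sketch below shows only that arithmetic; the test's actual cu_seqlens construction continues in the collapsed part of the diff.

import torch

world_size = 4
multiple = 2 * world_size  # m = 8

# Random per-sequence lengths in the test; fixed here for clarity.
seqlens_q = torch.tensor([5, 16, 23, 64], dtype=torch.int32)

# Ceil each length to the next multiple of 2 * world_size: 5 -> 8, 23 -> 24, ...
seqlens_q_padded = (seqlens_q + multiple - 1) // multiple * multiple
# seqlens_q_padded: [8, 16, 24, 64]

# cu_seqlens-style offsets are prefix sums over the (padded) lengths.
cu_seqlens_q_padded = torch.zeros(len(seqlens_q) + 1, dtype=torch.int32)
cu_seqlens_q_padded[1:] = torch.cumsum(seqlens_q_padded, dim=0)
# cu_seqlens_q_padded: [0, 8, 24, 48, 112]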
63 changes: 59 additions & 4 deletions tests/pytorch/fused_attn/test_fused_attn_with_cp.py
@@ -16,11 +16,17 @@
)

model_configs_flash_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # MHA
"cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(
2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # GQA
}
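For readers skimming these configs: the positional arguments annotated as "b, h, hg, d, sq, skv, p, mask, bias" line up with the attributes that run_fused_attn_with_cp.py reads from the config object. The sketch below is a hypothetical reconstruction of that mapping; the real ModelConfig lives in the shared fused-attention test utilities and is not part of this diff.

from dataclasses import dataclass
from typing import Tuple

@dataclass
class ModelConfigSketch:
    """Hypothetical mirror of the fields the CP tests rely on."""
    batch_size: int        # b
    num_heads: int         # h
    num_gqa_groups: int    # hg
    head_dim_qk: int       # d
    max_seqlen_q: int      # sq
    max_seqlen_kv: int     # skv
    dropout_p: float       # p
    attn_mask_type: str    # mask: "causal", "no_mask", ...
    attn_bias_type: str    # bias: "no_bias", "post_scale_bias", ...
    # Default is a guess; (-1, -1) and (-1, 0) are both treated as
    # "no sliding window" by the skip checks in the tests below.
    window_size: Tuple[int, int] = (-1, -1)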


@@ -39,7 +45,28 @@ def get_bash_arguments(**kwargs):
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
def test_cp_with_flash_attention(dtype, model, qkv_format):
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"])
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
config = model_configs_flash_attn[model]
if cp_comm_type == "all_gather" and qkv_format == "thd":
pytest.skip(
f"CP implementation with KV all-gather does not support {qkv_format} format yet!"
)
if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type:
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask"
" type yet!"
)
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias"
" type yet!"
)
if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip(
f"CP implementation with KV P2P does not support window size {config.window_size} yet!"
)

subprocess.run(
get_bash_arguments(
dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FlashAttention"
@@ -49,7 +76,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format):


model_configs_fused_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA
@@ -66,9 +93,37 @@ def test_cp_with_fused_attention(dtype, model, qkv_format):
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
def test_cp_with_fused_attention(dtype, model, qkv_format):
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"])
def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
if qkv_format == "thd" and get_device_compute_capability() < (9, 0):
pytest.skip("THD format is only supported on sm90+.")
if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0):
pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0")

config = model_configs_fused_attn[model]
if qkv_format == "thd" and config.num_heads != config.num_gqa_groups:
pytest.skip(f"{qkv_format} format does not support QGA/MQA yet!")
if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
pytest.skip(f"{qkv_format} format does not support {config.attn_bias_type} bias type yet!")
if cp_comm_type == "all_gather" and qkv_format == "thd":
pytest.skip(
f"CP implementation with KV all-gather does not support {qkv_format} format yet!"
)
if cp_comm_type == "all_gather" and "causal" not in config.attn_mask_type:
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_mask_type} mask"
" type yet!"
)
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip(
f"CP implementation with KV all-gather does not support {config.attn_bias_type} bias"
" type yet!"
)
if config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip(
f"Fused attention does not support sliding window attention + context parallelism yet!"
)

subprocess.run(
get_bash_arguments(
dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FusedAttention"
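The skip conditions added above encode the support matrix being exercised for the two CP communication schemes. Restated as a standalone predicate, derived only from the checks visible in test_cp_with_flash_attention (the fused-attention test layers a blanket sliding-window skip on top of these):

def cp_case_supported(cp_comm_type, qkv_format, attn_mask_type, attn_bias_type, window_size):
    """True if the combination is run rather than skipped, per the checks above."""
    if cp_comm_type == "all_gather":
        # KV all-gather: no thd layout, causal-style masks only, no bias.
        if qkv_format == "thd":
            return False
        if "causal" not in attn_mask_type:
            return False
        if attn_bias_type != "no_bias":
            return False
    if cp_comm_type == "p2p":
        # KV P2P: sliding-window attention is not supported yet.
        if window_size not in ((-1, 0), (-1, -1)):
            return False
    return True


# Example: the sliding-window config "cp_1_2" (causal, window_size=(512, 0))
# runs only with the all-gather scheme, matching the skips above.
assert cp_case_supported("all_gather", "bshd", "causal", "no_bias", (512, 0))
assert not cp_case_supported("p2p", "bshd", "causal", "no_bias", (512, 0))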
(Diffs for the remaining 10 changed files are not shown in this view.)
