[C++/PyTorch] Add alibi_slopes support (#608)
* test alibi between fa and fu

* move alibi slopes and bias to global to avoid repeating calculation

* fix alibi slopes/bias generation

* fix _is_flash_attention_supported to allow alibi type

* disable padding mask when alibi is used for fused attn arbi backend

* add support for custom [n_heads] alibi_slopes in flash, fused, unfused attention

* clean up last commit

* remove alibi_type=none tests as they are unnecessary

* update cudnn-frontend to 1.0.2

* change bias/dbias shape to allow b,1/1,h/b,h in arbi backend

* tweak tests for arbi post_scale_bias [1,h,s,s] or alibi_slopes [n_heads]

* change bias/dbias shape in max512 backend - incomplete

* remove max512 changes from last commit and disable max512 (and arbi temporarily) for [b, h, s, s]; pending cuDNN backend support

* clean up and tweak backend selection logic

* replace || with () in docstring

* fix bias shape for max512 backend

* combine slopes/bias generation to one function get_alibi() and fix alibi tests

* fix lint

* fix PR557 bugs

* Update transformer_engine/pytorch/attention.py

* encapsulate global alibi tensors into a dict cache

* reduce alibi slopes test size

* update to cudnn-frontend 1.0.3

* use dBias shape to define bias_b/bias_h because jax materializes dBias rather than Bias in bwd abstract
---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
cyanguwa and timmoon10 authored Feb 8, 2024
1 parent da30634 commit 94de051
Showing 7 changed files with 317 additions and 141 deletions.
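
For orientation before the diff, here is a minimal usage sketch of the new argument as exercised by the updated tests. The module and keyword names (DotProductAttention, core_attention_bias_type, alibi_slopes) come from the test changes below; shapes follow the alibi_2_0 config, and the constructor defaults are assumed rather than verified here.

import torch
from transformer_engine.pytorch import DotProductAttention

b, h, d, s = 2, 24, 128, 1024                      # batch, heads, head dim, seqlen
q = torch.randn(s, b, h, d, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

attn = DotProductAttention(h, d, attn_mask_type="causal")

# "vanilla" ALiBi: slopes are generated internally from the head count
out = attn(q, k, v, core_attention_bias_type="alibi")

# custom ALiBi: one float32 slope per head, as in test_dpa_alibi_slopes
slopes = torch.randn(h).abs().to(dtype=torch.float32, device="cuda")
out = attn(q, k, v, core_attention_bias_type="alibi", alibi_slopes=slopes)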
75 changes: 54 additions & 21 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -56,6 +56,8 @@
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()

_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))

def reset_rng_states() -> None:
"""Revert back to initial RNG state"""
torch.set_rng_state(_cpu_rng_state)
@@ -81,6 +83,7 @@ def __init__(
dropout_p: float,
attn_mask_type: str,
attn_bias_type: str,
alibi_type: str = "none",
num_layers: int = 1,
):
self.batch_size = batch_size
@@ -94,6 +97,7 @@ def __init__(
self.dropout_p = dropout_p
self.attn_mask_type = attn_mask_type
self.attn_bias_type = attn_bias_type
self.alibi_type = alibi_type
self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
self.num_layers = num_layers

@@ -167,7 +171,7 @@ def _is_flash_attention_supported(config: ModelConfig) -> bool:
"""Check if FlashAttention supports a model configuration"""
if get_device_compute_capability() < (8, 0):
return False
if config.attn_bias_type != "no_bias":
if config.attn_bias_type not in ["no_bias", "alibi"]:
return False
if config.num_heads != config.num_gqa_groups and not _is_flash_attention_2_available():
return False
@@ -283,18 +287,26 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace
)

if unfused_attn_supported and fused_attn_supported:
if _NVTE_DEBUG:
print("[test_dot_product_attention]: unfused attn vs fused attn")
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
for i,_ in enumerate(unfused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
if unfused_attn_supported and flash_attn_supported:
if _NVTE_DEBUG:
print("[test_dot_product_attention]: unfused attn vs flash attn")
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
for i,_ in enumerate(flash_attn_bwd):
torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols)
if fused_attn_supported and flash_attn_supported:
if _NVTE_DEBUG:
print("[test_dot_product_attention]: fused attn vs flash attn")
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
for i,_ in enumerate(flash_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
if fused_attn_supported and len(fused_attn_backend) == 2:
if _NVTE_DEBUG:
print("[test_dot_product_attention]: fused attn backend 0 vs 1")
torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
for i,_ in enumerate(fused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)
@@ -382,6 +394,21 @@ def test_dpa_sliding_window(dtype, model_configs, model):
"""Test DotProductAttention module with sliding window attention"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, True)

model_configs_alibi_slopes = {
# test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type
"alibi_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "alibi", alibi_type="vanilla"),
"alibi_1_1": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "causal", "alibi", alibi_type="vanilla"),
"alibi_2_0": ModelConfig(2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type= "custom"),
"alibi_2_1": ModelConfig(1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type= "custom"),
}
@pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
@pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
def test_dpa_alibi_slopes(dtype, model_configs, model):
"""Test DotProductAttention module with ALiBi slopes"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)

qkv_layouts = [
'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd',
'bs3hd', 'bsh3d', 'bshd_bs2hd', 'bshd_bsh2d', 'bshd_bshd_bshd',
@@ -477,9 +504,17 @@ def _run_dot_product_attention(
attention_mask_q.to(device="cuda"), attention_mask_kv.to(device="cuda"))
if swa:
window_size, attention_mask = get_swa(config.max_seqlen_q, config.max_seqlen_kv)
elif "causal" in config.attn_mask_type:
window_size, attention_mask = (-1, 0), None
else:
window_size, attention_mask = None, None

alibi_slopes = None
if config.attn_bias_type == "alibi":
if config.alibi_type == "custom":
alibi_slopes = torch.randn(
config.num_heads).abs().to(dtype=torch.float32, device="cuda")

# Create input tensors
dim_to_num = {
'b' : config.batch_size,
@@ -570,6 +605,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
checkpoint_core_attention=ckpt_attn,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=True)
out.backward(out_grad)

@@ -583,6 +619,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
"te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"),
"te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"),
}

@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@@ -654,12 +692,18 @@ def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, f
)

if unfused_attn_supported and fused_attn_supported:
if _NVTE_DEBUG:
print("[test_transformer_layer]: unfused attn vs fused attn")
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
if unfused_attn_supported and flash_attn_supported:
if _NVTE_DEBUG:
print("[test_transformer_layer]: unfused attn vs flash attn")
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, **tols)
if fused_attn_supported and flash_attn_supported:
if _NVTE_DEBUG:
print("[test_transformer_layer]: fused attn vs flash attn")
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)

@@ -758,28 +802,10 @@ def _run_transformer_layer(
rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]

# Create bias
if config.attn_bias_type == 'no_bias':
bias = None
bias = None
if config.attn_bias_type == 'post_scale_bias':
bias = torch.randn(1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv,
dtype=dtype, device="cuda")
elif config.attn_bias_type == 'alibi':
if os.environ['NVTE_FUSED_ATTN_BACKEND'] == '0':
config.attn_bias_type = 'post_scale_bias'
n = 2 ** math.floor(math.log2(config.num_heads))
m_0 = 2.0 ** (-8.0 / n)
m = torch.pow(m_0, torch.arange(1, 1 + n))

a = torch.ones(config.max_seqlen_q, config.max_seqlen_kv)
b = torch.triu(a,diagonal=1)
c = b.cumsum(dim=-1)
d = c - torch.transpose(c, 0, 1)
bias = d.expand(1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv)
for i in range(config.num_heads):
bias[0,i,:,:] = m[i] * bias[0,i,:,:]
bias = bias.to(dtype=dtype, device="cuda")
else:
bias = None

# Create RoPE
rotary_pos_emb = None
@@ -825,14 +851,21 @@ def _run_transformer_layer(
.to(dtype=dtype, device="cuda")
)

# Create ALiBi slopes
alibi_slopes = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
alibi_slopes = torch.randn(
config.num_heads).abs().to(dtype=torch.float32, device="cuda")

# Run a forward and backward pass
out = block(inp,
attention_mask=attention_mask,
self_attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=False,
rotary_pos_emb=rotary_pos_emb,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias)
core_attention_bias=bias,
alibi_slopes=alibi_slopes)
loss = out.sum()
loss.backward()

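The inline bias construction removed above is what the commit message describes as being consolidated into a single get_alibi() helper. As a reference for the math it implements, here is a minimal sketch of vanilla ALiBi, assuming a square sequence length and a power-of-two head count; the actual helper also caches the generated tensors and accepts custom slopes.

import torch

def vanilla_alibi(num_heads: int, seqlen: int):
    # geometric slopes: m_i = 2^(-8*i/n), i = 1..n
    m_0 = 2.0 ** (-8.0 / num_heads)
    slopes = torch.pow(m_0, torch.arange(1, num_heads + 1, dtype=torch.float32))
    # antisymmetric relative-position matrix: rel[i, j] = j - i
    pos = torch.arange(seqlen, dtype=torch.float32)
    rel = pos.view(1, -1) - pos.view(-1, 1)
    # bias added to the pre-softmax logits, shaped [1, num_heads, seqlen, seqlen]
    bias = slopes.view(1, -1, 1, 1) * rel.view(1, 1, seqlen, seqlen)
    return slopes, bias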
1 change: 1 addition & 0 deletions transformer_engine/common/fused_attn/fused_attn.cpp
@@ -135,6 +135,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS
|| (bias_type == NVTE_Bias_Type::NVTE_ALIBI
&& attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK
&& attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK
&& sm_arch_ == 90)
|| (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS
&& sm_arch_ == 90))))
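The added condition narrows ALiBi support in the arbitrary-sequence-length backend: it is selected only when the attention mask is neither no_mask nor a padding mask, and only on sm90, matching the "disable padding mask when alibi is used" item in the commit message. Restated as an illustrative Python predicate (the C++ check above is authoritative):

def arbitrary_seqlen_backend_accepts_bias(bias_type: str, mask_type: str, sm_arch: int) -> bool:
    # mirrors the bias-related clause of nvte_get_fused_attn_backend shown above
    if bias_type == "no_bias":
        return True
    if bias_type == "alibi":
        return mask_type not in ("no_mask", "padding") and sm_arch == 90
    if bias_type == "post_scale_bias":
        return sm_arch == 90
    return False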