[Paddle] Support GQA (#595)
* use separate qkv

Signed-off-by: jaywan <jaywan@nvidia.com>

* add support for GQA

Signed-off-by: jaywan <jaywan@nvidia.com>

* minor changes

Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* change rtol

Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* fix reshape issue

Signed-off-by: Shijie Wang <jaywan@nvidia.com>

---------

Signed-off-by: jaywan <jaywan@nvidia.com>
Signed-off-by: Shijie Wang <jaywan@nvidia.com>
Wong4j authored Jan 26, 2024
1 parent e531cd2 commit bd7fd0a
Showing 6 changed files with 751 additions and 168 deletions.
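Note on the feature being added: in grouped-query attention (GQA), num_gqa_groups key/value heads are shared across the num_heads query heads; num_gqa_groups == num_heads recovers standard multi-head attention, and num_gqa_groups == 1 is multi-query attention. A minimal unfused reference sketch in Paddle (illustrative only; the function name and shapes are assumptions, not the fused Transformer Engine kernel):

import paddle

def gqa_reference(q, k, v, num_gqa_groups):
    # q: [b, s_q, num_heads, d]; k and v: [b, s_kv, num_gqa_groups, d]
    num_heads = q.shape[2]
    assert num_heads % num_gqa_groups == 0
    # Broadcast each K/V head to its group of query heads.
    k = paddle.repeat_interleave(k, num_heads // num_gqa_groups, axis=2)
    v = paddle.repeat_interleave(v, num_heads // num_gqa_groups, axis=2)
    # Move heads ahead of sequence for the batched matmuls: [b, h, s, d].
    q, k, v = (paddle.transpose(t, perm=[0, 2, 1, 3]) for t in (q, k, v))
    scores = paddle.matmul(q, k, transpose_y=True) / (q.shape[-1] ** 0.5)
    probs = paddle.nn.functional.softmax(scores, axis=-1)
    out = paddle.matmul(probs, v)
    return paddle.transpose(out, perm=[0, 2, 1, 3])  # [b, s_q, num_heads, d]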
140 changes: 57 additions & 83 deletions tests/paddle/test_layers.py
@@ -3,7 +3,6 @@
# See LICENSE for license information.
"""Test TE Paddle Layer-level APIs"""

-import math
import os
from utils import assert_allclose, is_fused_attention_supported

@@ -785,7 +784,7 @@ def test_layernorm_mlp_fp8_microbatch(bs, hidden_size, ffn_hidden_size, activati

@pytest.mark.parametrize('bs', [1, 2, 8])
@pytest.mark.parametrize('hidden_size, num_heads', [[1024, 16], [768, 12]])
-@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])
+@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]])
@pytest.mark.parametrize('attn_type', ['self', 'cross'])
@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
@@ -808,24 +807,18 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
head_size=head_size,
dtype=math_dtype,
dropout=0.0,
-qkv_layout="bs3hd" if attn_type == "self" else "bshd_bs2hd",
+qkv_layout="bshd_bshd_bshd",
bias_type="no_bias",
mask_type=mask_type,
):
pytest.skip("cuDNN fused attention is not supported")

-self_attn_qkv_input = paddle.normal(mean=0.0,
-                                    std=0.02,
-                                    shape=(bs, q_seqlen, 3, num_heads,
-                                           head_size)).astype(math_dtype)
-cross_attn_q_input = paddle.normal(mean=0.0,
-                                   std=0.02,
-                                   shape=(bs, q_seqlen, num_heads,
-                                          head_size)).astype(math_dtype)
-cross_attn_kv_input = paddle.normal(mean=0.0,
-                                    std=0.02,
-                                    shape=(bs, kv_seqlen, 2, num_heads,
-                                           head_size)).astype(math_dtype)
+attn_q_input = paddle.normal(mean=0.0, std=0.02,
+                             shape=(bs, q_seqlen, num_heads, head_size)).astype(math_dtype)
+attn_k_input = paddle.normal(mean=0.0, std=0.02,
+                             shape=(bs, kv_seqlen, num_heads, head_size)).astype(math_dtype)
+attn_v_input = paddle.normal(mean=0.0, std=0.02,
+                             shape=(bs, kv_seqlen, num_heads, head_size)).astype(math_dtype)

q_actual_seqlen = paddle.randint(low=20, high=q_seqlen, shape=(bs,), dtype='int32')
kv_actual_seqlen = paddle.randint(low=20, high=kv_seqlen, shape=(bs,),
@@ -841,57 +834,36 @@ for i in range(0, bs):
for i in range(0, bs):
attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False

-norm_factor = math.sqrt(hidden_size // num_heads)
-layer_te = te.DotProductAttention(norm_factor,
+head_size = hidden_size // num_heads
+layer_te = te.DotProductAttention(num_heads,
+                                  head_size,
attention_dropout=0.0,
attn_mask_type=mask_type,
attention_type=attn_type,
backend='transformer_engine')
-layer_pd = te.DotProductAttention(norm_factor,
+layer_pd = te.DotProductAttention(num_heads,
+                                  head_size,
attention_dropout=0.0,
attn_mask_type=mask_type,
attention_type=attn_type,
backend='paddle')

-def calc_attn_output_and_grad(layer, q, kv, mask, dout):
+def calc_attn_output_and_grad(layer, q, k, v, mask, dout):
_q = paddle.to_tensor(q, stop_gradient=False)
-_kv = paddle.to_tensor(kv, stop_gradient=False) if kv is not None else None
+_k = paddle.to_tensor(k, stop_gradient=False)
+_v = paddle.to_tensor(v, stop_gradient=False)

-out = layer(_q, _kv, mask)
+out = layer(_q, _k, _v, mask)
out.backward(dout)
-return out, _q.grad, _kv.grad if _kv is not None else None
-
-if attn_type == 'self':
-    out, qkv_grad, _ = calc_attn_output_and_grad(layer_te, self_attn_qkv_input, None, attn_mask,
-                                                 grad_out)
-    out_ref, qkv_grad_ref, _ = calc_attn_output_and_grad(layer_pd, self_attn_qkv_input, None,
-                                                         attn_mask, grad_out)
-    valid_out_ref = paddle.full_like(out_ref, 0)
-    for i in range(0, bs):
-        valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :]
-
-    q_grad = qkv_grad[:, :, 0]
-    k_grad = qkv_grad[:, :, 1]
-    v_grad = qkv_grad[:, :, 2]
-    q_grad_ref = qkv_grad_ref[:, :, 0]
-    k_grad_ref = qkv_grad_ref[:, :, 1]
-    v_grad_ref = qkv_grad_ref[:, :, 2]
-
-else:
-    out, q_grad, kv_grad = calc_attn_output_and_grad(layer_te, cross_attn_q_input,
-                                                     cross_attn_kv_input, attn_mask, grad_out)
-    out_ref, q_grad_ref, kv_grad_ref = calc_attn_output_and_grad(layer_pd, cross_attn_q_input,
-                                                                 cross_attn_kv_input, attn_mask,
-                                                                 grad_out)
-
-    valid_out_ref = paddle.full_like(out_ref, 0)
-    for i in range(0, bs):
-        valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :]
+return out, _q.grad, _k.grad, _v.grad

-k_grad = kv_grad[:, :, 0]
-v_grad = kv_grad[:, :, 1]
-k_grad_ref = kv_grad_ref[:, :, 0]
-v_grad_ref = kv_grad_ref[:, :, 1]
+out, q_grad, k_grad, v_grad = calc_attn_output_and_grad(layer_te, attn_q_input, attn_k_input,
+                                                        attn_v_input, attn_mask, grad_out)
+out_ref, q_grad_ref, k_grad_ref, v_grad_ref = calc_attn_output_and_grad(
+    layer_pd, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out)
+valid_out_ref = paddle.full_like(out_ref, 0)
+for i in range(0, bs):
+    valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :]

valid_q_grad_ref = paddle.full_like(q_grad_ref, 0)
valid_k_grad_ref = paddle.full_like(k_grad_ref, 0)
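The hunk above replaces the packed-QKV paths (bs3hd / bshd_bs2hd) with separate Q, K and V tensors, which is what makes differing query and key/value head counts expressible. A condensed sketch of the call pattern the test now exercises (mirroring the diff, not new API surface beyond it; in this test the K/V head count equals num_heads, while the encoder/decoder tests below vary num_gqa_groups):

# Shapes follow the test inputs above; backend='paddle' gives the unfused reference.
layer = te.DotProductAttention(num_heads,
                               head_size,
                               attention_dropout=0.0,
                               attn_mask_type=mask_type,
                               attention_type=attn_type,
                               backend='transformer_engine')
out = layer(q, k, v, attn_mask)  # q: [b, s_q, h, d]; k, v: [b, s_kv, h_kv, d]
out.backward(grad_out)           # gradients land on q.grad, k.grad, v.grad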
@@ -910,17 +882,18 @@ def calc_attn_output_and_grad(layer, q, kv, mask, dout):


@pytest.mark.parametrize('bs', [1, 2, 8])
+@pytest.mark.parametrize('num_gqa_groups', [1, 4, 16])
@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
-@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])
+@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]])
@pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]])
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
@pytest.mark.parametrize('output_layernorm', [True, False])
@pytest.mark.parametrize('return_layernorm_output', [True, False])
-def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, has_bias, no_dbias,
-                                   no_wgrad, q_seqlen, kv_seqlen, mask_type, math_dtype,
-                                   output_layernorm, return_layernorm_output):
+def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, ffn_hidden_size,
+                                   has_bias, no_dbias, no_wgrad, q_seqlen, kv_seqlen, mask_type,
+                                   math_dtype, output_layernorm, return_layernorm_output):
"""
Test Transformer Encoder Layer
"""
@@ -932,13 +905,13 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
# Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported(
num_heads=num_heads,
-num_gqa_groups=num_heads,
+num_gqa_groups=num_gqa_groups,
q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen,
head_size=hidden_size // num_heads,
dtype=math_dtype,
dropout=0.0,
-qkv_layout="bs3hd",
+qkv_layout="bshd_bshd_bshd",
bias_type="no_bias",
mask_type=mask_type,
):
@@ -962,6 +935,7 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
layer_te = te.TransformerLayer(hidden_size,
ffn_hidden_size,
num_heads,
+num_gqa_groups=num_gqa_groups,
layernorm_epsilon=eps,
hidden_dropout=0.0,
attention_dropout=0.0,
@@ -975,6 +949,7 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
layer_pd = te.TransformerLayer(hidden_size,
ffn_hidden_size,
num_heads,
+num_gqa_groups=num_gqa_groups,
layernorm_epsilon=eps,
hidden_dropout=0.0,
attention_dropout=0.0,
@@ -1088,18 +1063,19 @@ def calc_transformer_output_and_grad(layer, encoder_input, mask, dout):


@pytest.mark.parametrize('bs', [1, 2, 8])
+@pytest.mark.parametrize('num_gqa_groups', [1, 4, 16])
@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
-@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])
+@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]])
@pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]])
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
@pytest.mark.parametrize('output_layernorm', [True, False])
@pytest.mark.parametrize('return_layernorm_output', [True, False])
@pytest.mark.parametrize('recompute_core_attention', [True, False])
-def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, has_bias, no_dbias,
-                                   no_wgrad, q_seqlen, kv_seqlen, mask_type, math_dtype,
-                                   output_layernorm, return_layernorm_output,
+def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, ffn_hidden_size,
+                                   has_bias, no_dbias, no_wgrad, q_seqlen, kv_seqlen, mask_type,
+                                   math_dtype, output_layernorm, return_layernorm_output,
recompute_core_attention):
"""
Test Transformer Decoder Layer
@@ -1112,39 +1088,35 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
# Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported(
num_heads=num_heads,
-num_gqa_groups=num_heads,
+num_gqa_groups=num_gqa_groups,
q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen,
head_size=hidden_size // num_heads,
dtype=math_dtype,
dropout=0.0,
-qkv_layout="bs3hd",
-bias_type="no_bias",
-mask_type=mask_type,
-):
-    pytest.skip("cuDNN fused attention is not supported")
-if not is_fused_attention_supported(
-    head_size=hidden_size // num_heads,
-    num_heads=num_heads,
-    num_gqa_groups=num_heads,
-    q_seqlen=q_seqlen,
-    kv_seqlen=kv_seqlen,
-    dtype=math_dtype,
-    dropout=0.0,
-    qkv_layout="bshd_bs2hd",
+qkv_layout="bshd_bshd_bshd",
bias_type="no_bias",
mask_type=mask_type,
):
pytest.skip("cuDNN fused attention is not supported")

-encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)
-encoder_output = paddle.uniform(shape=(bs, kv_seqlen, hidden_size), dtype=math_dtype)
+encoder_input = paddle.normal(mean=0.0, std=0.1,
+                              shape=(bs, q_seqlen, hidden_size)).astype(math_dtype)
+encoder_output = paddle.normal(mean=0.0, std=0.1,
+                               shape=(bs, kv_seqlen, hidden_size)).astype(math_dtype)

q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen
kv_actual_seqlen = q_actual_seqlen
attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool')

-grad_out = paddle.normal(mean=0.0, std=0.2, shape=(bs, q_seqlen, hidden_size)).astype('float32')
+grad_out = paddle.normal(mean=0.0, std=0.01,
+                         shape=(bs, q_seqlen, hidden_size)).astype('float32')
+
+# rounding to avoid numerical issues
+encoder_input = paddle.round(encoder_input * 1000) / 1000
+encoder_output = paddle.round(encoder_output * 1000) / 1000
+grad_out = paddle.round(grad_out * 1000) / 1000

for i in range(0, bs):
grad_out[i, q_actual_seqlen[i]:, :] = 0
grad_out = grad_out.astype(math_dtype)
@@ -1155,6 +1127,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
layer_te = te.TransformerLayer(hidden_size,
ffn_hidden_size,
num_heads,
+num_gqa_groups=num_gqa_groups,
layernorm_epsilon=eps,
hidden_dropout=0.0,
attention_dropout=0.0,
@@ -1168,6 +1141,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
layer_pd = te.TransformerLayer(hidden_size,
ffn_hidden_size,
num_heads,
+num_gqa_groups=num_gqa_groups,
layernorm_epsilon=eps,
hidden_dropout=0.0,
attention_dropout=0.0,
@@ -1319,7 +1293,7 @@ def calc_transformer_output_and_grad(layer,
assert_allclose(layer_te.self_attention.layernorm_qkv.weight.grad,
layer_pd.self_attention.layernorm_qkv.weight.grad.T,
rtol=rtol,
-atol=0.1)
+atol=atol)
assert_allclose(layer_te.inter_attention.layernorm_query.weight.grad,
layer_pd.inter_attention.layernorm_query.weight.grad.T,
rtol=rtol,
Expand All @@ -1328,7 +1302,7 @@ def calc_transformer_output_and_grad(layer,
if output_layernorm:
assert_allclose(layer_te.self_attention.qkv.bias.grad,
layer_pd.self_attention.qkv.bias.grad,
-rtol=0.01,
+rtol=0.5,
atol=0.6)
else:
assert_allclose(layer_te.self_attention.layernorm_qkv.bias.grad,
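On the "change rtol" commit above (atol=0.1 → atol=atol, rtol=0.01 → rtol=0.5): assuming assert_allclose in utils follows numpy semantics (an assumption; the helper is not shown in this diff), the check passes when every element satisfies |actual - desired| <= atol + rtol * |desired|, so rtol=0.5 with atol=0.6 tolerates large relative error on the noisy bias gradients while still bounding absolute error near zero. A sketch of that semantics:

import numpy as np

def allclose(actual, desired, rtol, atol):
    # numpy-style closeness test (assumed semantics of the test helper):
    # passes iff every element satisfies |a - d| <= atol + rtol * |d|.
    return bool(np.all(np.abs(actual - desired) <= atol + rtol * np.abs(desired)))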