From bd7fd0a6b363733f43888239f854cc368ce7b981 Mon Sep 17 00:00:00 2001 From: Shijie Date: Sat, 27 Jan 2024 07:29:25 +0800 Subject: [PATCH] [Paddle] Support GQA (#595) * use separate qkv Signed-off-by: jaywan * add support for GQA Signed-off-by: jaywan * minor changes Signed-off-by: Shijie Wang * change rtol Signed-off-by: Shijie Wang * fix reshape issue Signed-off-by: Shijie Wang --------- Signed-off-by: jaywan Signed-off-by: Shijie Wang --- tests/paddle/test_layers.py | 140 ++++----- tests/paddle/test_operators.py | 102 +++++++ transformer_engine/paddle/cpp_extensions.py | 183 ++++++++++++ transformer_engine/paddle/csrc/custom_ops.cu | 204 +++++++++++++ transformer_engine/paddle/layer/attention.py | 280 ++++++++++++------ .../paddle/layer/transformer.py | 10 + 6 files changed, 751 insertions(+), 168 deletions(-) diff --git a/tests/paddle/test_layers.py b/tests/paddle/test_layers.py index 83b953d814..686b261b9e 100644 --- a/tests/paddle/test_layers.py +++ b/tests/paddle/test_layers.py @@ -3,7 +3,6 @@ # See LICENSE for license information. """Test TE Paddle Layer-level APIs""" -import math import os from utils import assert_allclose, is_fused_attention_supported @@ -785,7 +784,7 @@ def test_layernorm_mlp_fp8_microbatch(bs, hidden_size, ffn_hidden_size, activati @pytest.mark.parametrize('bs', [1, 2, 8]) @pytest.mark.parametrize('hidden_size, num_heads', [[1024, 16], [768, 12]]) -@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]]) +@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]]) @pytest.mark.parametrize('attn_type', ['self', 'cross']) @pytest.mark.parametrize('mask_type', ['causal', 'padding']) @pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16']) @@ -808,24 +807,18 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen, head_size=head_size, dtype=math_dtype, dropout=0.0, - qkv_layout="bs3hd" if attn_type == "self" else "bshd_bs2hd", + qkv_layout="bshd_bshd_bshd", bias_type="no_bias", mask_type=mask_type, ): pytest.skip("cuDNN fused attention is not supported") - self_attn_qkv_input = paddle.normal(mean=0.0, - std=0.02, - shape=(bs, q_seqlen, 3, num_heads, - head_size)).astype(math_dtype) - cross_attn_q_input = paddle.normal(mean=0.0, - std=0.02, - shape=(bs, q_seqlen, num_heads, - head_size)).astype(math_dtype) - cross_attn_kv_input = paddle.normal(mean=0.0, - std=0.02, - shape=(bs, kv_seqlen, 2, num_heads, - head_size)).astype(math_dtype) + attn_q_input = paddle.normal(mean=0.0, std=0.02, + shape=(bs, q_seqlen, num_heads, head_size)).astype(math_dtype) + attn_k_input = paddle.normal(mean=0.0, std=0.02, + shape=(bs, kv_seqlen, num_heads, head_size)).astype(math_dtype) + attn_v_input = paddle.normal(mean=0.0, std=0.02, + shape=(bs, kv_seqlen, num_heads, head_size)).astype(math_dtype) q_actual_seqlen = paddle.randint(low=20, high=q_seqlen, shape=(bs,), dtype='int32') kv_actual_seqlen = paddle.randint(low=20, high=kv_seqlen, shape=(bs,), @@ -841,57 +834,36 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen, for i in range(0, bs): attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False - norm_factor = math.sqrt(hidden_size // num_heads) - layer_te = te.DotProductAttention(norm_factor, + head_size = hidden_size // num_heads + layer_te = te.DotProductAttention(num_heads, + head_size, attention_dropout=0.0, attn_mask_type=mask_type, attention_type=attn_type, backend='transformer_engine') - layer_pd = te.DotProductAttention(norm_factor, + layer_pd = 
te.DotProductAttention(num_heads, + head_size, attention_dropout=0.0, attn_mask_type=mask_type, attention_type=attn_type, backend='paddle') - def calc_attn_output_and_grad(layer, q, kv, mask, dout): + def calc_attn_output_and_grad(layer, q, k, v, mask, dout): _q = paddle.to_tensor(q, stop_gradient=False) - _kv = paddle.to_tensor(kv, stop_gradient=False) if kv is not None else None + _k = paddle.to_tensor(k, stop_gradient=False) + _v = paddle.to_tensor(v, stop_gradient=False) - out = layer(_q, _kv, mask) + out = layer(_q, _k, _v, mask) out.backward(dout) - return out, _q.grad, _kv.grad if _kv is not None else None - - if attn_type == 'self': - out, qkv_grad, _ = calc_attn_output_and_grad(layer_te, self_attn_qkv_input, None, attn_mask, - grad_out) - out_ref, qkv_grad_ref, _ = calc_attn_output_and_grad(layer_pd, self_attn_qkv_input, None, - attn_mask, grad_out) - valid_out_ref = paddle.full_like(out_ref, 0) - for i in range(0, bs): - valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :] - - q_grad = qkv_grad[:, :, 0] - k_grad = qkv_grad[:, :, 1] - v_grad = qkv_grad[:, :, 2] - q_grad_ref = qkv_grad_ref[:, :, 0] - k_grad_ref = qkv_grad_ref[:, :, 1] - v_grad_ref = qkv_grad_ref[:, :, 2] - - else: - out, q_grad, kv_grad = calc_attn_output_and_grad(layer_te, cross_attn_q_input, - cross_attn_kv_input, attn_mask, grad_out) - out_ref, q_grad_ref, kv_grad_ref = calc_attn_output_and_grad(layer_pd, cross_attn_q_input, - cross_attn_kv_input, attn_mask, - grad_out) - - valid_out_ref = paddle.full_like(out_ref, 0) - for i in range(0, bs): - valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :] + return out, _q.grad, _k.grad, _v.grad - k_grad = kv_grad[:, :, 0] - v_grad = kv_grad[:, :, 1] - k_grad_ref = kv_grad_ref[:, :, 0] - v_grad_ref = kv_grad_ref[:, :, 1] + out, q_grad, k_grad, v_grad = calc_attn_output_and_grad(layer_te, attn_q_input, attn_k_input, + attn_v_input, attn_mask, grad_out) + out_ref, q_grad_ref, k_grad_ref, v_grad_ref = calc_attn_output_and_grad( + layer_pd, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out) + valid_out_ref = paddle.full_like(out_ref, 0) + for i in range(0, bs): + valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :] valid_q_grad_ref = paddle.full_like(q_grad_ref, 0) valid_k_grad_ref = paddle.full_like(k_grad_ref, 0) @@ -910,17 +882,18 @@ def calc_attn_output_and_grad(layer, q, kv, mask, dout): @pytest.mark.parametrize('bs', [1, 2, 8]) +@pytest.mark.parametrize('num_gqa_groups', [1, 4, 16]) @pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]]) -@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]]) +@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]]) @pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]]) @pytest.mark.parametrize('no_wgrad', [True, False]) @pytest.mark.parametrize('mask_type', ['causal', 'padding']) @pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16']) @pytest.mark.parametrize('output_layernorm', [True, False]) @pytest.mark.parametrize('return_layernorm_output', [True, False]) -def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, has_bias, no_dbias, - no_wgrad, q_seqlen, kv_seqlen, mask_type, math_dtype, - output_layernorm, return_layernorm_output): +def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, ffn_hidden_size, + has_bias, no_dbias, no_wgrad, q_seqlen, kv_seqlen, 
mask_type, + math_dtype, output_layernorm, return_layernorm_output): """ Test Transformer Encoder Layer """ @@ -932,13 +905,13 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, # Skip if cuDNN fused attention is not supported if not is_fused_attention_supported( num_heads=num_heads, - num_gqa_groups=num_heads, + num_gqa_groups=num_gqa_groups, q_seqlen=q_seqlen, kv_seqlen=kv_seqlen, head_size=hidden_size // num_heads, dtype=math_dtype, dropout=0.0, - qkv_layout="bs3hd", + qkv_layout="bshd_bshd_bshd", bias_type="no_bias", mask_type=mask_type, ): @@ -962,6 +935,7 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, layer_te = te.TransformerLayer(hidden_size, ffn_hidden_size, num_heads, + num_gqa_groups=num_gqa_groups, layernorm_epsilon=eps, hidden_dropout=0.0, attention_dropout=0.0, @@ -975,6 +949,7 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, layer_pd = te.TransformerLayer(hidden_size, ffn_hidden_size, num_heads, + num_gqa_groups=num_gqa_groups, layernorm_epsilon=eps, hidden_dropout=0.0, attention_dropout=0.0, @@ -1088,8 +1063,9 @@ def calc_transformer_output_and_grad(layer, encoder_input, mask, dout): @pytest.mark.parametrize('bs', [1, 2, 8]) +@pytest.mark.parametrize('num_gqa_groups', [1, 4, 16]) @pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]]) -@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]]) +@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]]) @pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]]) @pytest.mark.parametrize('no_wgrad', [True, False]) @pytest.mark.parametrize('mask_type', ['causal', 'padding']) @@ -1097,9 +1073,9 @@ def calc_transformer_output_and_grad(layer, encoder_input, mask, dout): @pytest.mark.parametrize('output_layernorm', [True, False]) @pytest.mark.parametrize('return_layernorm_output', [True, False]) @pytest.mark.parametrize('recompute_core_attention', [True, False]) -def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, has_bias, no_dbias, - no_wgrad, q_seqlen, kv_seqlen, mask_type, math_dtype, - output_layernorm, return_layernorm_output, +def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, ffn_hidden_size, + has_bias, no_dbias, no_wgrad, q_seqlen, kv_seqlen, mask_type, + math_dtype, output_layernorm, return_layernorm_output, recompute_core_attention): """ Test Transformer Decoder Layer @@ -1112,39 +1088,35 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, # Skip if cuDNN fused attention is not supported if not is_fused_attention_supported( num_heads=num_heads, - num_gqa_groups=num_heads, + num_gqa_groups=num_gqa_groups, q_seqlen=q_seqlen, kv_seqlen=kv_seqlen, head_size=hidden_size // num_heads, dtype=math_dtype, dropout=0.0, - qkv_layout="bs3hd", - bias_type="no_bias", - mask_type=mask_type, - ): - pytest.skip("cuDNN fused attention is not supported") - if not is_fused_attention_supported( - head_size=hidden_size // num_heads, - num_heads=num_heads, - num_gqa_groups=num_heads, - q_seqlen=q_seqlen, - kv_seqlen=kv_seqlen, - dtype=math_dtype, - dropout=0.0, - qkv_layout="bshd_bs2hd", + qkv_layout="bshd_bshd_bshd", bias_type="no_bias", mask_type=mask_type, ): pytest.skip("cuDNN fused attention is not supported") - encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype) - encoder_output = paddle.uniform(shape=(bs, kv_seqlen, 
hidden_size), dtype=math_dtype) + encoder_input = paddle.normal(mean=0.0, std=0.1, + shape=(bs, q_seqlen, hidden_size)).astype(math_dtype) + encoder_output = paddle.normal(mean=0.0, std=0.1, + shape=(bs, kv_seqlen, hidden_size)).astype(math_dtype) q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen kv_actual_seqlen = q_actual_seqlen attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool') - grad_out = paddle.normal(mean=0.0, std=0.2, shape=(bs, q_seqlen, hidden_size)).astype('float32') + grad_out = paddle.normal(mean=0.0, std=0.01, + shape=(bs, q_seqlen, hidden_size)).astype('float32') + + # rounding to avoid numerical issues + encoder_input = paddle.round(encoder_input * 1000) / 1000 + encoder_output = paddle.round(encoder_output * 1000) / 1000 + grad_out = paddle.round(grad_out * 1000) / 1000 + for i in range(0, bs): grad_out[i, q_actual_seqlen[i]:, :] = 0 grad_out = grad_out.astype(math_dtype) @@ -1155,6 +1127,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, layer_te = te.TransformerLayer(hidden_size, ffn_hidden_size, num_heads, + num_gqa_groups=num_gqa_groups, layernorm_epsilon=eps, hidden_dropout=0.0, attention_dropout=0.0, @@ -1168,6 +1141,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, layer_pd = te.TransformerLayer(hidden_size, ffn_hidden_size, num_heads, + num_gqa_groups=num_gqa_groups, layernorm_epsilon=eps, hidden_dropout=0.0, attention_dropout=0.0, @@ -1319,7 +1293,7 @@ def calc_transformer_output_and_grad(layer, assert_allclose(layer_te.self_attention.layernorm_qkv.weight.grad, layer_pd.self_attention.layernorm_qkv.weight.grad.T, rtol=rtol, - atol=0.1) + atol=atol) assert_allclose(layer_te.inter_attention.layernorm_query.weight.grad, layer_pd.inter_attention.layernorm_query.weight.grad.T, rtol=rtol, @@ -1328,7 +1302,7 @@ def calc_transformer_output_and_grad(layer, if output_layernorm: assert_allclose(layer_te.self_attention.qkv.bias.grad, layer_pd.self_attention.qkv.bias.grad, - rtol=0.01, + rtol=0.5, atol=0.6) else: assert_allclose(layer_te.self_attention.layernorm_qkv.bias.grad, diff --git a/tests/paddle/test_operators.py b/tests/paddle/test_operators.py index b07e091cf0..e1493157b4 100644 --- a/tests/paddle/test_operators.py +++ b/tests/paddle/test_operators.py @@ -5,6 +5,12 @@ import struct +from utils import ( + assert_allclose, + create_fp8_meta, + get_fused_attention_backend, + is_fused_attention_supported, +) import numpy as np import paddle import paddle.nn.functional as F @@ -39,6 +45,8 @@ fused_attn_bwd_qkvpacked, fused_attn_fwd_kvpacked, fused_attn_bwd_kvpacked, + fused_attn_fwd, + fused_attn_bwd, scaled_softmax_forward, scaled_softmax_backward, scaled_masked_softmax_forward, @@ -594,6 +602,7 @@ def _random(shape): self.q = _random(self.q_shape) if self.attn_mode == "self_attn": + assert self.q_seqlen == self.kv_seqlen, "self attention requires q_seqlen == kv_seqlen" self.kv = self.q else: self.kv = _random(self.kv_shape) @@ -774,6 +783,70 @@ def _get_fused_attention_out(self): return out, q_grad, k_grad, v_grad + def _get_fused_attention_with_separate_qkv(self): + paddle.disable_static(place=paddle.CUDAPlace(0)) + + q_tensor = paddle.to_tensor(self.q, stop_gradient=False) + k_tensor = paddle.to_tensor(self.kv, stop_gradient=False) + v_tensor = paddle.to_tensor(self.kv, stop_gradient=False) + + q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True) + kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", 
stop_gradient=True) + + qkv_layout = "bshd_bshd_bshd" + fused_attention_backend = get_fused_attention_backend( + num_heads=self.num_heads, + num_gqa_groups=self.num_heads, + q_seqlen=self.q_seqlen, + kv_seqlen=self.kv_seqlen, + head_size=self.head_size, + dtype=self.dtype, + dropout=self.dropout_prob, + qkv_layout=qkv_layout, + bias_type="no_bias", + mask_type="causal" if self.is_causal_masking else "padding", + ) + + qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16 + out, softmax_aux_tensor, rng_state = fused_attn_fwd( + q_tensor, + k_tensor, + v_tensor, + q_cu_seqlen_tensor, + kv_cu_seqlen_tensor, + is_training=True, + max_seqlen_q=self.q_seqlen, + max_seqlen_kv=self.kv_seqlen, + qkv_dtype=qkv_dtype, + fused_attention_backend=fused_attention_backend, + Bias=None, + attn_scale=self.scaling_factor, + dropout=self.dropout_prob, + set_zero=False, + qkv_layout=qkv_layout, + attn_mask_type="causal" if self.is_causal_masking else "padding") + dq, dk, dv, _ = fused_attn_bwd( + q_tensor, + k_tensor, + v_tensor, + q_cu_seqlen_tensor, + kv_cu_seqlen_tensor, + rng_state, + out, + self.dout, + softmax_aux_tensor, + fused_attention_backend=fused_attention_backend, + max_seqlen_q=self.q_seqlen, + max_seqlen_kv=self.kv_seqlen, + qkv_dtype=qkv_dtype, + attn_scale=self.scaling_factor, + dropout=self.dropout_prob, + set_zero=False, + qkv_layout=qkv_layout, + attn_mask_type="causal" if self.is_causal_masking else "padding") + + return out, dq, dk, dv + @pytest.mark.parametrize('b, s, h, d', SELF_ATTN_CASES) @pytest.mark.parametrize('dtype', ['float16', 'bfloat16']) @pytest.mark.parametrize('is_causal_masking', [True, False]) @@ -857,6 +930,35 @@ def test_flash_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking) assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2) assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2) + @pytest.mark.parametrize('b, s, h, d', FLASH_ATTN_CASES) + @pytest.mark.parametrize('dtype', ['float16', 'bfloat16']) + @pytest.mark.parametrize('is_causal_masking', [False, True]) + def test_fused_attn_with_separate_qkv_forward_backward(self, b, s, h, d, dtype, + is_causal_masking): + """ + test flash attention forward + backward with separate qkv inputs + """ + if not is_fused_attention_supported( + num_heads=h, + num_gqa_groups=h, + q_seqlen=s, + kv_seqlen=s, + head_size=d, + dtype=dtype, + dropout=0.0, + qkv_layout="bshd_bshd_bshd", + bias_type="no_bias", + mask_type="causal" if is_causal_masking else "padding", + ): + pytest.skip("cuDNN fused attention is not supported") + self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking) + reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out() + fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_with_separate_qkv() + assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2) + assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2) + assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2) + assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2) + class TestSoftmax: """ diff --git a/transformer_engine/paddle/cpp_extensions.py b/transformer_engine/paddle/cpp_extensions.py index 2a5aea0643..fd9928f7be 100644 --- a/transformer_engine/paddle/cpp_extensions.py +++ b/transformer_engine/paddle/cpp_extensions.py @@ -792,6 +792,189 @@ def fused_attn_bwd_kvpacked( return dq, dkv, dbias +def fused_attn_fwd( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + cu_seqlens_q: paddle.Tensor, + cu_seqlens_kv: 
paddle.Tensor, + is_training: bool, + max_seqlen_q: int, + max_seqlen_kv: int, + qkv_dtype: tex.DType, + fused_attention_backend: tex.NVTE_Fused_Attn_Backend, + Bias: paddle.Tensor = None, + attn_scale: float = None, + dropout: float = 0.0, + set_zero: bool = True, + qkv_layout: str = "bshd_bshd_bshd", + bias_type: str = "no_bias", + attn_mask_type: str = "padding", +) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Fused Attention FWD for unpacked QKV input""" + + assert (qkv_dtype in (tex.DType.kBFloat16, + tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention." + assert (cu_seqlens_q.shape == cu_seqlens_kv.shape + ), "cu_seqlens_q and cu_seqlens_kv must have the same shape" + assert (qkv_layout == "bshd_bshd_bshd" + ), "Only support bshd_bshd_bshd layout for unpacked QKV input for now." + b = cu_seqlens_q.shape[0] - 1 + + h = q.shape[-2] + d = q.shape[-1] + + if attn_scale is None: + attn_scale = 1.0 / math.sqrt(d) + + if bias_type != "no_bias": + assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias." + assert (Bias.shape == [ + 1, h, max_seqlen_q, max_seqlen_kv + ]), "bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape." + assert (Bias.dtype == q.dtype), "bias tensor must be in the same dtype as qkv." + + assert (fused_attention_backend != FusedAttnBackend["No_Backend"] + ), "Fused attention does not support this input combination." + + # BF16/FP16 fused attention API from fmha_v1 apex + if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: + rng_elts_per_thread = (max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA - + 1) // BACKEND_F16m512_THREADS_PER_CTA + + # BF16/FP16 fused attention API from fmha_v2 + if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: + rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS + + if set_zero: + out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype) + else: + out = paddle.empty(shape=[b, max_seqlen_q, h, d], dtype=q.dtype) + + if is_training: + if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: + softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) + elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: + softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype='float32') + else: + raise ValueError("Unsupported fused attention backend.") + else: + softmax_aux = None + + rng_state = paddle.empty(shape=[ + 2, + ], dtype=paddle.int64) + + # execute kernel + tex.te_fused_attn_fwd( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + Bias, + out, + softmax_aux, + rng_state, + b, + h, + d, + max_seqlen_q, + max_seqlen_kv, + is_training, + attn_scale, + dropout, + qkv_layout, + bias_type, + attn_mask_type, + int(qkv_dtype), + rng_elts_per_thread, + ) + return out, softmax_aux, rng_state + + +def fused_attn_bwd( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + cu_seqlens_q: paddle.Tensor, + cu_seqlens_kv: paddle.Tensor, + rng_state: paddle.Tensor, + o: paddle.Tensor, + d_o: paddle.Tensor, + softmax_aux: paddle.Tensor, + fused_attention_backend: tex.NVTE_Fused_Attn_Backend, + max_seqlen_q: int, + max_seqlen_kv: int, + qkv_dtype: tex.DType, + attn_scale: float = None, + dropout: float = 0.0, + set_zero: bool = True, + qkv_layout: str = "bshd_bshd_bshd", + bias_type: str = "no_bias", + attn_mask_type: str = "padding", +) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Fused Attention BWD for packed KV input""" + + assert (qkv_dtype in 
(tex.DType.kBFloat16, + tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention." + assert (cu_seqlens_q.shape == cu_seqlens_kv.shape + ), "cu_seqlens_q and cu_seqlens_kv must have the same shape" + assert (qkv_layout == "bshd_bshd_bshd" + ), "Only support bshd_bshd_bshd layout for unpacked QKV input for now." + + b = cu_seqlens_q.shape[0] - 1 + h = q.shape[-2] + d = q.shape[-1] + + if attn_scale is None: + attn_scale = 1.0 / math.sqrt(d) + + assert (fused_attention_backend != FusedAttnBackend["No_Backend"] + ), "Fused attention does not support this input combination." + + if set_zero: + dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype) + dk = paddle.full(shape=k.shape, fill_value=0, dtype=k.dtype) + dv = paddle.full(shape=v.shape, fill_value=0, dtype=v.dtype) + else: + dq = paddle.empty(shape=q.shape, dtype=q.dtype) + dk = paddle.empty(shape=k.shape, dtype=k.dtype) + dv = paddle.empty(shape=v.shape, dtype=v.dtype) + if bias_type != "no_bias": + dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) + else: + dbias = None + # execute kernel + tex.te_fused_attn_bwd( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + o, + d_o, + softmax_aux, + dq, + dk, + dv, + dbias, + rng_state, + b, + h, + d, + max_seqlen_q, + max_seqlen_kv, + attn_scale, + dropout, + qkv_layout, + bias_type, + attn_mask_type, + int(qkv_dtype), + ) + return dq, dk, dv, dbias + + def scaled_softmax_forward( inp: paddle.Tensor, scale_factor: float, diff --git a/transformer_engine/paddle/csrc/custom_ops.cu b/transformer_engine/paddle/csrc/custom_ops.cu index aadd457da2..2478db32b3 100644 --- a/transformer_engine/paddle/csrc/custom_ops.cu +++ b/transformer_engine/paddle/csrc/custom_ops.cu @@ -864,6 +864,183 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); } +void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const paddle::Tensor &V, + const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &cu_seqlens_kv, + const paddle::optional &Bias, + paddle::Tensor &O, // NOLINT + paddle::optional &softmax_aux, // NOLINT + paddle::Tensor &rng_state, // NOLINT + int64_t b, int64_t h, int64_t d, int64_t max_seqlen_q, int64_t max_seqlen_kv, + bool is_training, float attn_scale, float p_dropout, + const std::string &qkv_layout, const std::string &bias_type, + const std::string &attn_mask_type, const int64_t qkv_type, + int64_t rng_elts_per_thread) { + if (is_training && !softmax_aux) { + NVTE_ERROR("softmax_aux must be provided when training. \n"); + } + + auto qkv_dtype = Int2NvteDType(qkv_type); + // construct NVTE tensors + TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv; + if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { + // BF16 or FP16 + te_Q = MakeNvteTensor(Q); + te_K = MakeNvteTensor(K); + te_V = MakeNvteTensor(V); + te_S = MakeNvteTensor(nullptr, std::vector{0}, DType::kFloat32); + te_O = MakeNvteTensor(O); + } else { // TODO: support fp8 + NVTE_ERROR("Fused attention only supports BF16/FP16 data types. 
\n"); + } + if ((bias_type != "no_bias") && Bias) { + auto bias_shape = Bias->shape(); + std::vector shape{bias_shape.begin(), bias_shape.end()}; + te_Bias = MakeNvteTensor(GetOptionalDataPtr(Bias), shape, DType::kFloat32); + } + te_cu_seqlens_q = + MakeNvteTensor(cu_seqlens_q.data(), {static_cast(b + 1)}, DType::kInt32); + te_cu_seqlens_kv = + MakeNvteTensor(cu_seqlens_kv.data(), {static_cast(b + 1)}, DType::kInt32); + + // convert strings to enums + NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); + NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); + NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); + + // extract random number generator seed and offset + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(Q.place()); + auto gen_cuda = dev_ctx->GetGenerator(); + auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread); + set_rng_state<<<1, 1, 0, Q.stream()>>>(seed_offset, static_cast(rng_state.data())); + + auto te_rng_state = MakeNvteTensor(rng_state); + + // create auxiliary output tensors + NVTETensorPack nvte_aux_tensor_pack; + nvte_tensor_pack_create(&nvte_aux_tensor_pack); + + // create workspace + TensorWrapper workspace; + + // populate tensors with appropriate shapes and dtypes + nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), + te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, + is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, + attn_mask_type_enum, workspace.data(), Q.stream()); + + // allocate memory for workspace and auxiliary output tensors + auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); + + workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); + + auto *output_s = + reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); + output_s->data.dptr = GetOptionalDataPtr(softmax_aux); + + // execute the kernel + nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), + te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, + is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, + attn_mask_type_enum, workspace.data(), Q.stream()); + + // destroy tensor wrappers, but not allocated memory + nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); +} + +void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const paddle::Tensor &V, + const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &cu_seqlens_kv, + const paddle::Tensor &O, const paddle::Tensor &dO, + const paddle::Tensor &softmax_aux, + paddle::Tensor &dQ, // NOLINT + paddle::Tensor &dK, // NOLINT + paddle::Tensor &dV, // NOLINT + paddle::optional &dBias, // NOLINT + paddle::Tensor &rng_state, // NOLINT + int64_t b, int64_t h, int64_t d, int64_t max_seqlen_q, int64_t max_seqlen_kv, + float attn_scale, float p_dropout, const std::string &qkv_layout, + const std::string &bias_type, const std::string &attn_mask_type, + int64_t qkv_type) { + TensorWrapper te_dBias; + if (bias_type != "no_bias" && dBias) { + auto bias_shape = dBias->shape(); + std::vector shape{bias_shape.begin(), bias_shape.end()}; + te_dBias = MakeNvteTensor(GetOptionalDataPtr(dBias), shape, DType::kFloat32); + } + + auto qkv_dtype = Int2NvteDType(qkv_type); + // construct NVTE tensors + TensorWrapper te_Q, te_K, te_V, te_O, te_dO, 
te_S, te_dP, te_dQ, te_dK, te_dV; + if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { + // BF16 or FP16 + te_Q = MakeNvteTensor(Q); + te_K = MakeNvteTensor(K); + te_V = MakeNvteTensor(V); + te_O = MakeNvteTensor(O); + te_dO = MakeNvteTensor(dO); + te_S = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); + te_dP = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); + te_dQ = MakeNvteTensor(dQ); + te_dK = MakeNvteTensor(dK); + te_dV = MakeNvteTensor(dV); + } else { + NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n"); + } + + // convert strings to enums + NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); + NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); + NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); + + // convert auxiliary tensors from forward into NVTETensors + NVTETensorPack nvte_aux_tensor_pack; + nvte_tensor_pack_create(&nvte_aux_tensor_pack); + + nvte_aux_tensor_pack.size = 2; + auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); + auto *fwd_rng_state = reinterpret_cast(nvte_aux_tensor_pack.tensors[1]); + output_s->data.shape = std::vector({static_cast(b), static_cast(h), + static_cast(max_seqlen_q), + static_cast(max_seqlen_kv)}); + output_s->data.dptr = const_cast(softmax_aux.data()); + fwd_rng_state->data.shape = std::vector({2}); + fwd_rng_state->data.dptr = const_cast(rng_state.data()); + + // create cu_seqlens tensorwrappers + TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; + te_cu_seqlens_q = + MakeNvteTensor(cu_seqlens_q.data(), {static_cast(b + 1)}, DType::kInt32); + te_cu_seqlens_kv = + MakeNvteTensor(cu_seqlens_kv.data(), {static_cast(b + 1)}, DType::kInt32); + + // create workspace + TensorWrapper workspace; + + // populate tensors with appropriate shapes and dtypes + nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), + te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), + te_dK.data(), te_dV.data(), te_dBias.data(), te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, + qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), + Q.stream()); + + // allocate memory for workspace + auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); + workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); + + // execute kernel + nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), + te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), + te_dK.data(), te_dV.data(), te_dBias.data(), te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, + qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), + Q.stream()); + + // destroy tensor wrappers + nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); +} + std::vector te_scaled_softmax_forward(const paddle::Tensor &input, float scale_factor) { NVTE_CHECK(input.shape().size() == 4, "expected 4D tensor"); @@ -1316,6 +1493,33 @@ PD_BUILD_OP(te_fused_attn_bwd_kvpacked) {paddle::Optional("_dBias"), paddle::Optional("dBias")}}) .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd_kvpacked)); +PD_BUILD_OP(te_fused_attn_fwd) + .Inputs({"Q", "K", "V", "cu_seqlens_q", "cu_seqlens_kv", paddle::Optional("Bias"), "_O", + paddle::Optional("_softmax_aux"), "_rng_state"}) + .Outputs({"O", paddle::Optional("softmax_aux"), "rng_state"}) + 
.Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "max_seqlen_q: int64_t", + "max_seqlen_kv: int64_t", "is_training: bool", "attn_scale: float", "p_dropout: float", + "qkv_layout: std::string", "bias_type: std::string", "attn_mask_type: std::string", + "qkv_type: int64_t", "rng_elts_per_thread: int64_t"}) + .SetInplaceMap({{"_O", "O"}, + {paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")}, + {"_rng_state", "rng_state"}}) + .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd)); + +PD_BUILD_OP(te_fused_attn_bwd) + .Inputs({"Q", "K", "V", "cu_seqlens_q", "cu_seqlens_kv", "O", "dO", "softmax_aux", "_dQ", "_dK", + "_dV", paddle::Optional("_dBias"), "rng_state"}) + .Outputs({"dQ", "dK", "dV", paddle::Optional("dBias")}) + .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "max_seqlen_q: int64_t", + "max_seqlen_kv: int64_t", "attn_scale: float", "p_dropout: float", + "qkv_layout: std::string", "bias_type: std::string", "attn_mask_type: std::string", + "qkv_type: int64_t"}) + .SetInplaceMap({{"_dQ", "dQ"}, + {"_dK", "dK"}, + {"_dV", "dV"}, + {paddle::Optional("_dBias"), paddle::Optional("dBias")}}) + .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd)); + PD_BUILD_OP(te_scaled_softmax_forward) .Inputs({"input"}) .Outputs({"softmax_results"}) diff --git a/transformer_engine/paddle/layer/attention.py b/transformer_engine/paddle/layer/attention.py index bf75986eca..91afea5449 100644 --- a/transformer_engine/paddle/layer/attention.py +++ b/transformer_engine/paddle/layer/attention.py @@ -22,6 +22,8 @@ fused_attn_bwd_qkvpacked, fused_attn_fwd_kvpacked, fused_attn_bwd_kvpacked, + fused_attn_fwd, + fused_attn_bwd, mask_to_cu_seqlens, ) from ..distributed import get_tp_group_and_world_size, track_rng_state @@ -31,6 +33,20 @@ __all__ = ["DotProductAttention", "MultiHeadAttention"] +def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor: + """ + Used to repeat the key and value states for GQA. 
+ The hidden states go from (batch, seqlen, num_gqa_groups, head_size) + to (batch, seqlen, num_heads, head_size) + """ + batch, seqlen, num_gqa_groups, head_size = hidden_states.shape + if n_rep == 1: + return hidden_states + + hidden_states = hidden_states.unsqueeze(-2).tile([1, 1, 1, n_rep, 1]) + return hidden_states.reshape([batch, seqlen, num_gqa_groups * n_rep, head_size]) + + class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer): """Function for FusedAttention with packed QKV input""" @@ -130,6 +146,50 @@ def backward(ctx, d_out): return (dq, dkv, None, None, rest[0]) +class FusedAttnFunc(paddle.autograd.PyLayer): + """Function for FusedAttention with separate Q, K, V tensors""" + + @staticmethod + def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_kv, attn_bias, max_seqlen_q, max_seqlen_kv, + attn_scale, qkv_dtype, dropout_p, set_zero, qkv_layout, attn_bias_type, + attn_mask_type, is_training, fused_attention_backend): + """Forward function for FusedAttention with separate Q, K, V tensors""" + out, softmax_aux, rng_state = fused_attn_fwd(q, k, v, cu_seqlens_q, cu_seqlens_kv, + is_training, max_seqlen_q, max_seqlen_kv, + qkv_dtype, fused_attention_backend, attn_bias, + attn_scale, dropout_p, set_zero, qkv_layout, + attn_bias_type, attn_mask_type) + + ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux) + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_kv = max_seqlen_kv + ctx.qkv_dtype = qkv_dtype + ctx.attn_scale = attn_scale + ctx.dropout_p = dropout_p + ctx.set_zero = set_zero + ctx.qkv_layout = qkv_layout + ctx.attn_bias_type = attn_bias_type + ctx.attn_mask_type = attn_mask_type + ctx.fused_attention_backend = fused_attention_backend + + return out + + @staticmethod + def backward(ctx, d_out): + """Backward function for FusedAttention with separate Q, K, V tensors""" + q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux = ctx.saved_tensor() + dq, dk, dv, *rest = fused_attn_bwd(q, k, v, cu_seqlens_q, cu_seqlens_kv, rng_state, out, + d_out, softmax_aux, ctx.fused_attention_backend, + ctx.max_seqlen_q, ctx.max_seqlen_kv, ctx.qkv_dtype, + ctx.attn_scale, ctx.dropout_p, ctx.set_zero, + ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + # if no_bias, return dq, dk, dv + if ctx.attn_bias_type == "no_bias": + return (dq, dk, dv, None, None) + # else, return (dq, dk, dv, dbias) + return (dq, dk, dv, None, None, rest[0]) + + class DotProductAttention(paddle.nn.Layer): """ Allows the model to jointly attend to information from different @@ -143,31 +203,51 @@ class DotProductAttention(paddle.nn.Layer): Parameters ---------- - norm_factor : float - normalization factor for the attention scores. + num_attention_heads: int + number of attention heads in the transformer layer. + kv_channels: int + number of channels in the key and value tensors. + num_gqa_groups : Optional[int] = None + number of GQA groups in the transformer layer. + Grouped Query Attention is described in + `this paper `_. + This only affects the keys and values, not the queries. + GQA-1 is equivalent to Multi-Query Attention + (`MQA `_), while GQA-H + is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`. attention_dropout: float, default = 0.1 dropout probability for the dropout op during multi-head attention. attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` type of attention mask passed into softmax operation. attention_type: {'self', 'cross'}, default = `self` type of attention operation. 
+ tp_group : ProcessGroup, default = `None` + tensor parallel process group. backend: {'transformer_engine', 'paddle'}, default = `transformer_engine` backend to use for attention operation. """ def __init__(self, - norm_factor: float, + num_attention_heads: int, + kv_channels: int, + num_gqa_groups: Optional[int] = None, attention_dropout: float = 0.1, attn_mask_type: str = "causal", attention_type: str = "self", + tp_size: int = 1, backend: str = 'transformer_engine') -> None: super().__init__() - self.norm_factor = norm_factor self.attn_mask_type = attn_mask_type self.attention_dropout = attention_dropout self.attention_type = attention_type - self.qkv_layout = "bs3hd" if attention_type == "self" else "bshd_bs2hd" + self.qkv_layout = "bshd_bshd_bshd" + self.hidden_size_per_attention_head = kv_channels + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + self.tp_size = tp_size + self.num_gqa_groups = (num_attention_heads if num_gqa_groups is None else num_gqa_groups) + self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size) + self.num_queries_per_key_value = num_attention_heads // self.num_gqa_groups self.backend = backend @@ -185,7 +265,8 @@ def __init__(self, def forward( self, query_layer: paddle.Tensor, - key_value_layer: paddle.Tensor = None, + key_layer: paddle.Tensor, + value_layer: paddle.Tensor, attention_mask: Optional[paddle.Tensor] = None, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[paddle.Tensor] = None, @@ -199,26 +280,15 @@ def forward( Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type` is set to `"causal"`. - .. note:: - - For self attention, :attr:`query_layer` is the `[query, key, value]` tensor - stacked along the 2nd dimension, which must be of shape (:attr:`batch_size`, - :attr:`seq_length`, 3, :attr:`num_attention_heads`, :attr:`size_per_head`). - And :attr:`key_value_layer` is `None`. - For cross attention, :attr:`query_layer` is the `[query]` tensor, which must - be of shape (:attr:`batch_size`, :attr:`seq_length`, :attr:`num_attention_heads`, - :attr:`size_per_head`). And :attr:`key_value_layer` is the `[key, value]` tensor, - which must be of shape (:attr:`batch_size`, :attr:`seq_length`, 2, - :attr:`num_attention_heads`, :attr:`size_per_head`). - - Parameters ---------- query_layer : paddle.Tensor Query tensor. - key_value_layer : paddle.Tensor - Key tensor. + key_layer : paddle.Tensor + Key tensor. + value_layer : paddle.Tensor + Value tensor. attention_mask : Optional[paddle.Tensor], default = `None` Boolean tensor used to mask out softmax input when not using attention. core_attention_bias_type: str, default = `no_bias` @@ -231,21 +301,25 @@ def forward( backend = self.backend + assert (key_layer.shape == value_layer.shape), "Keys and values must have the same shape!" + assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition + ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!" 
+ if backend == 'transformer_engine': max_s_q = query_layer.shape[1] - max_s_kv = max_s_q if self.attention_type == "self" else key_value_layer.shape[1] + max_s_kv = max_s_q if self.attention_type == "self" else key_layer.shape[1] self.fused_attention_backend = tex.get_fused_attn_backend( TE_DType[query_layer.dtype], TE_DType[query_layer.dtype], tex.get_nvte_qkv_layout(self.qkv_layout), AttnBiasType[core_attention_bias_type], AttnMaskType[self.attn_mask_type], self.attention_dropout, query_layer.shape[-2], - key_value_layer.shape[-2] if key_value_layer is not None else query_layer.shape[-2], - max_s_q, max_s_kv, query_layer.shape[-1]) + key_layer.shape[-2] if key_layer is not None else query_layer.shape[-2], max_s_q, + max_s_kv, query_layer.shape[-1]) is_backend_avail = (self.fused_attention_backend in [ FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"] ]) if is_backend_avail and self.use_fused_attention: - return self._te_forward(query_layer, key_value_layer, attention_mask, + return self._te_forward(query_layer, key_layer, value_layer, attention_mask, core_attention_bias_type, core_attention_bias, set_zero) warnings.warn("Fused attention is not enabled, falling back to Paddle backend") backend = 'paddle' @@ -256,13 +330,14 @@ def forward( if core_attention_bias_type != "no_bias": warnings.warn("Paddle backend dot product attention does not support bias yet. " "Bias will be ignored.") - return self._pd_forward(query_layer, key_value_layer, attention_mask) + return self._pd_forward(query_layer, key_layer, value_layer, attention_mask) raise AttributeError(f"Backend {backend} is not supported.") def _te_forward( self, query_layer: paddle.Tensor, - key_value_layer: paddle.Tensor = None, + key_layer: paddle.Tensor, + value_layer: paddle.Tensor, attention_mask: Optional[paddle.Tensor] = None, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[paddle.Tensor] = None, @@ -270,10 +345,10 @@ def _te_forward( ) -> paddle.Tensor: if self.attention_type == "self": - # self attention - q: [b, s, 3, h, d] kv: None - assert (len(query_layer.shape) == 5 and query_layer.shape[2] == 3 - and key_value_layer is None - ), "query shape must be [b, s, 3, h, d] for dot product self attention" + # self attention - q: [b, s, h, d] kv: None + assert (len(query_layer.shape) == 4 and len(key_layer.shape) == 4 + and len(value_layer.shape) + == 4), "q,k,v shape must be [b, s, h, d] for dot product self attention" max_seqlen = query_layer.shape[1] if self.attn_mask_type == "causal" or attention_mask is None: cu_seqlens = paddle.arange(0, (query_layer.shape[0] + 1) * query_layer.shape[1], @@ -283,32 +358,33 @@ def _te_forward( cu_seqlens, _ = mask_to_cu_seqlens(attention_mask, need_kv=False) qkv_dtype = TE_DType[query_layer.dtype] - output = FusedAttnFuncPackedQKV.apply(query_layer, cu_seqlens, core_attention_bias, - max_seqlen, 1.0 / self.norm_factor, qkv_dtype, - self.attention_dropout if self.training else 0.0, - set_zero, self.qkv_layout, - core_attention_bias_type, self.attn_mask_type, - self.training, self.fused_attention_backend) + output = FusedAttnFunc.apply(query_layer, key_layer, value_layer, cu_seqlens, + cu_seqlens, core_attention_bias, max_seqlen, max_seqlen, + 1.0 / self.norm_factor, qkv_dtype, + self.attention_dropout if self.training else 0.0, set_zero, + self.qkv_layout, core_attention_bias_type, + self.attn_mask_type, self.training, + self.fused_attention_backend) elif self.attention_type == "cross": - # cross attention - q: [b, s_q, h, d] kv: [b, s_kv, 
2, h, d] + # cross attention - q: [b, s_q, h, d] k,v: [b, s_kv, h, d] assert ( - len(query_layer.shape) == 4 and len(key_value_layer.shape) == 5 - and key_value_layer.shape[2] == 2 - ), "query shape must be [b, s, h, d] and key shape must be [b, s, 2, h, d]" \ + len(query_layer.shape) == 4 and len(key_layer.shape) == 4 + and len(value_layer.shape) == 4 + ), "query shape must be [b, s_q, h, d] and key shape must be [b, s_kv, h, d]" \ "for dot product cross attention" assert (attention_mask is not None), "attention_mask must be provided for cross attention" max_seqlen_q = query_layer.shape[1] - max_seqlen_kv = key_value_layer.shape[1] + max_seqlen_kv = key_layer.shape[1] cu_seqlens_q, cu_seqlens_kv = mask_to_cu_seqlens(attention_mask, need_kv=True) qkv_dtype = TE_DType[query_layer.dtype] - output = FusedAttnFuncPackedKV.apply(query_layer, key_value_layer, cu_seqlens_q, - cu_seqlens_kv, core_attention_bias, max_seqlen_q, - max_seqlen_kv, 1.0 / self.norm_factor, qkv_dtype, - self.attention_dropout if self.training else 0.0, - set_zero, self.qkv_layout, - core_attention_bias_type, self.attn_mask_type, - self.training, self.fused_attention_backend) + output = FusedAttnFunc.apply(query_layer, key_layer, value_layer, cu_seqlens_q, + cu_seqlens_kv, core_attention_bias, max_seqlen_q, + max_seqlen_kv, 1.0 / self.norm_factor, qkv_dtype, + self.attention_dropout if self.training else 0.0, set_zero, + self.qkv_layout, core_attention_bias_type, + self.attn_mask_type, self.training, + self.fused_attention_backend) else: raise ValueError("attention_type must be one of ['self', 'cross']") return output @@ -316,28 +392,14 @@ def _te_forward( def _pd_forward( self, query_layer: paddle.Tensor, - key_value_layer: paddle.Tensor = None, + key_layer: paddle.Tensor, + value_layer: paddle.Tensor, attention_mask: Optional[paddle.Tensor] = None, ) -> paddle.Tensor: - if self.attention_type == "self": - # self attention - q: [b, s, 3, h, d] k: None - assert (len(query_layer.shape) == 5 and query_layer.shape[2] == 3 - and key_value_layer is None - ), "query shape must be [b, s, 3, h, d] for dot product self attention" - q = query_layer[:, :, 0] - k = query_layer[:, :, 1] - v = query_layer[:, :, 2] - elif self.attention_type == "cross": - # cross attention - q: [b, s, h, d] kv: [b, s, 2, h, d] - assert ( - len(query_layer.shape) == 4 and len(key_value_layer.shape) == 5 - and key_value_layer.shape[2] == 2 - ), f"query shape must be [b, s, h, d] and key_value shape must be [b, s, 2, h, d]" \ - f"for dot product cross attention. The actual shape is q: {query_layer.shape}" \ - f"kv: {key_value_layer.shape}" - q = query_layer - k = key_value_layer[:, :, 0] - v = key_value_layer[:, :, 1] + + q = query_layer + k = repeat_kv(key_layer, self.num_queries_per_key_value) + v = repeat_kv(value_layer, self.num_queries_per_key_value) q = paddle.transpose(x=q, perm=[0, 2, 1, 3]) k = paddle.transpose(x=k, perm=[0, 2, 1, 3]) @@ -404,6 +466,14 @@ class MultiHeadAttention(paddle.nn.Layer): if set to `True`, uses sequence parallelism. tp_group : ProcessGroup, default = `None` tensor parallel process group. + num_gqa_groups : int, default = `None` + number of GQA groups in the transformer layer. + Grouped Query Attention is described in + `this paper `_. + This only affects the keys and values, not the querys. + GQA-1 is equivalent to Multi-Query Attention + (`MQA `_), while GQA-H + is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`. 
rng_state_name : str, default = `local_seed` Controls the rng state used for dropout on attention probs. The specified rng should be set different seeds for different TP ranks. @@ -430,6 +500,7 @@ def __init__( set_parallel_mode: bool = False, sequence_parallel: bool = False, tp_group: Optional[dist_group_type] = None, + num_gqa_groups: Optional[int] = None, rng_state_name: str = 'local_seed', backend: str = 'transformer_engine', ) -> None: @@ -450,19 +521,25 @@ def __init__( self.sequence_parallel = self.tensor_parallel and sequence_parallel self.hidden_size_per_attention_head = hidden_size // num_attention_heads self.num_attention_heads = num_attention_heads - norm_factor = math.sqrt(self.hidden_size_per_attention_head) self.set_parallel_mode = set_parallel_mode self.rng_state_name = rng_state_name self.backend = backend self.num_attention_heads_per_partition = divide(self.num_attention_heads, self.tp_size) + self.num_gqa_groups = (num_attention_heads if num_gqa_groups is None else num_gqa_groups) + assert (self.num_attention_heads % self.num_gqa_groups == 0 + ), "The number of attention heads must be divisible by the number of GQA groups!" + assert (self.num_gqa_groups % self.tp_size == 0 + ), "The number of GQA groups must be divisible by tensor parallel size!" + self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size) + self.hidden_size_kv = int(hidden_size * self.num_gqa_groups // self.num_attention_heads) qkv_parallel_mode = "column" if set_parallel_mode else None if self.attention_type == "self": if self.input_layernorm: self.layernorm_qkv = LayerNormLinear( hidden_size, - 3 * hidden_size, + hidden_size + 2 * self.hidden_size_kv, eps=layernorm_epsilon, weight_attr=self.weight_attr, bias_attr=self.bias_attr, @@ -476,7 +553,7 @@ def __init__( else: self.qkv = Linear( hidden_size, - 3 * hidden_size, + hidden_size + 2 * self.hidden_size_kv, self.weight_attr, self.bias_attr, parallel_mode=qkv_parallel_mode, @@ -513,7 +590,7 @@ def __init__( ) self.key_value = Linear( hidden_size, - 2 * hidden_size, + 2 * self.hidden_size_kv, self.weight_attr, self.bias_attr, parallel_mode=qkv_parallel_mode, @@ -524,10 +601,13 @@ def __init__( # Attention. 
self.core_attention = DotProductAttention( - norm_factor, + self.num_attention_heads, + self.hidden_size_per_attention_head, + self.num_gqa_groups, attention_dropout, attn_mask_type=attn_mask_type, attention_type=self.attention_type, + tp_size=self.tp_size, backend=self.backend, ) @@ -619,18 +699,37 @@ def forward( is_first_microbatch=is_first_microbatch, ) - # [b, s_q, 3 * hidden_size] --> [b, s_q, 3, num_heads, head_size] + num_queries_per_key_value = (self.num_attention_heads_per_partition // + self.num_gqa_groups_per_partition) + + # [b, s_q, hidden_size+2*hidden_size_kv] --> [b, s_q, (h/ng+2), ng, d] mixed_qkv_layer = mixed_qkv_layer.reshape(shape=[ - -1, max_seq_len, 3, self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head + -1, max_seq_len, ( + num_queries_per_key_value + + 2), self.num_gqa_groups_per_partition, self.hidden_size_per_attention_head ]) + # [b, s_q, (h/ng+2), ng, d] + # --> [b, s_q, (h/ng), ng, d] [b, s_q, 1, ng, d] [b, s_q, 1, ng, d] + query_layer, key_layer, value_layer = paddle.split( + mixed_qkv_layer, + num_or_sections=(num_queries_per_key_value, 1, 1), + axis=2, + ) + + # query: -> [b, s, h, d] + # key, value: -> [b, s, ng, d] + query_layer, key_layer, value_layer = (x.reshape( + shape=[x.shape[0], x.shape[1], -1, self.hidden_size_per_attention_head]) + for x in (query_layer, key_layer, value_layer)) + with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name): if recompute_core_attention: context_layer = recompute( self.core_attention, - mixed_qkv_layer, - None, + query_layer, + key_layer, + value_layer, attention_mask, core_attention_bias_type, core_attention_bias, @@ -639,8 +738,9 @@ def forward( ) else: context_layer = self.core_attention( - query_layer=mixed_qkv_layer, - key_value_layer=None, + query_layer=query_layer, + key_layer=key_layer, + value_layer=value_layer, attention_mask=attention_mask, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, @@ -654,10 +754,17 @@ def forward( ) # [b, s_kv, 2 * hidden_size] --> [b, s_kv, 2, num_heads, head_size] mixed_kv_layer = mixed_kv_layer.reshape(shape=[ - -1, max_seq_len, 2, self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head + 0, 0, 2 * self.num_gqa_groups_per_partition, self.hidden_size_per_attention_head ]) + # [b, s_kv, 2 * ng, head_size] + # --> 2 [b, s_kv, ng, head_size] + key_layer, value_layer = paddle.split( + mixed_kv_layer, + num_or_sections=2, + axis=2, + ) + if self.input_layernorm: layernorm_query_outputs = self.layernorm_query( hidden_states, @@ -673,6 +780,7 @@ def forward( is_first_microbatch=is_first_microbatch, ) + # [b, s, hidden_size] --> [b, s, h, d] query_layer = query_layer.reshape(shape=[ -1, max_seq_len, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head @@ -682,7 +790,8 @@ def forward( context_layer = recompute( self.core_attention, query_layer, - mixed_kv_layer, + key_layer, + value_layer, attention_mask, core_attention_bias_type, core_attention_bias, @@ -692,7 +801,8 @@ def forward( else: context_layer = self.core_attention( query_layer=query_layer, - key_value_layer=mixed_kv_layer, + key_layer=key_layer, + value_layer=value_layer, attention_mask=attention_mask, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, diff --git a/transformer_engine/paddle/layer/transformer.py b/transformer_engine/paddle/layer/transformer.py index 50e5f118b6..e8a69c40a7 100644 --- a/transformer_engine/paddle/layer/transformer.py +++ 
b/transformer_engine/paddle/layer/transformer.py @@ -27,6 +27,14 @@ class TransformerLayer(paddle.nn.Layer): intermediate size to which input samples are projected. num_attention_heads : int number of attention heads in the transformer layer. + num_gqa_groups : Optional[int], default = `None` + number of GQA groups in the transformer layer. + Grouped Query Attention is described in + `this paper `_. + This only affects the keys and values, not the queries. + GQA-1 is equivalent to Multi-Query Attention + (`MQA `_), while GQA-H + is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`. layernorm_epsilon : float, default = 1e-5 a value added to the denominator of layer normalization for numerical stability. @@ -97,6 +105,7 @@ def __init__(self, hidden_size: int, ffn_hidden_size: int, num_attention_heads: int, + num_gqa_groups: Optional[int] = None, layernorm_epsilon: float = 1e-5, hidden_dropout: float = 0.1, attention_dropout: float = 0.1, @@ -153,6 +162,7 @@ def __init__(self, "set_parallel_mode": set_parallel_mode, "sequence_parallel": self.sequence_parallel, "tp_group": tp_group, + "num_gqa_groups": num_gqa_groups, "rng_state_name": attention_dropout_rng_state_name, "backend": backend, }
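
For reference, the snippet below sketches how the GQA-enabled API introduced by this patch is exercised, following the updated tests: `DotProductAttention` now takes `num_attention_heads`/`kv_channels` (plus an optional `num_gqa_groups`) and consumes separate `q`/`k`/`v` tensors in `bshd` layout, and `TransformerLayer`/`MultiHeadAttention` accept `num_gqa_groups`. Shapes, dtypes, and tensor values here are illustrative only; a CUDA build with cuDNN fused attention is assumed, and unsupported configurations fall back to the pure-Paddle backend with a warning.

    # Minimal usage sketch of the GQA-enabled Paddle API added by this patch,
    # adapted from the updated tests above (illustrative shapes/values only).
    import paddle
    import transformer_engine.paddle as te

    bs, seqlen, num_heads, head_size = 2, 512, 16, 64
    num_gqa_groups = 4    # K/V carry fewer heads than Q (grouped-query attention)

    # As in the tests, create layer parameters in the math dtype.
    paddle.set_default_dtype('bfloat16')

    # DotProductAttention now takes (num_heads, head_size) plus an optional
    # num_gqa_groups, and separate q/k/v tensors in bshd layout.
    attn = te.DotProductAttention(num_heads,
                                  head_size,
                                  num_gqa_groups=num_gqa_groups,
                                  attention_dropout=0.0,
                                  attn_mask_type='causal',
                                  attention_type='self',
                                  backend='transformer_engine')

    q = paddle.normal(mean=0.0, std=0.02,
                      shape=(bs, seqlen, num_heads, head_size)).astype('bfloat16')
    k = paddle.normal(mean=0.0, std=0.02,
                      shape=(bs, seqlen, num_gqa_groups, head_size)).astype('bfloat16')
    v = paddle.normal(mean=0.0, std=0.02,
                      shape=(bs, seqlen, num_gqa_groups, head_size)).astype('bfloat16')
    out = attn(q, k, v)    # -> [bs, seqlen, num_heads, head_size]

    # TransformerLayer forwards num_gqa_groups down to MultiHeadAttention, which
    # sizes the fused QKV projection as hidden_size + 2 * hidden_size_kv.
    layer = te.TransformerLayer(num_heads * head_size,
                                4 * num_heads * head_size,
                                num_heads,
                                num_gqa_groups=num_gqa_groups,
                                hidden_dropout=0.0,
                                attention_dropout=0.0,
                                backend='transformer_engine')
    inp = paddle.normal(mean=0.0, std=0.1,
                        shape=(bs, seqlen, num_heads * head_size)).astype('bfloat16')
    mask = paddle.ones(shape=(bs, 1, seqlen, seqlen), dtype='bool')
    layer_out = layer(inp, mask)    # mirrors the encoder-layer test above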