[JAX] Support SP + RoPE + GeLU (#602)
* Adding support for sequence parallelism

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding RoPE

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix wrong batch_logical_axes

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Renaming FSDP outer env var

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Porting RoPE to Praxis layers.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Porting GeLU + [FP8 Cast].

Signed-off-by: Ming Huang <mingh@nvidia.com>

* WAR to make XLA successfully match FP8 GEMM on FFN1 with GeLU.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Allowing arbitrary dimensions of NVShape for the workspace allocation

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding checkpoint_name to the fused functions in mlp.py to get better performance with nn.scan.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Modify with review feedback.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix bugs

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix typo.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fixed for lint

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Follow review feedback to modify code.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix typo.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Port SP to Praxis

Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Fix an issue when enabling both GQA and RoPE.

Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Update docs

Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

---------

Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>
mingxu1067 authored Feb 2, 2024
1 parent 29b0c9c commit ce163f9
Showing 17 changed files with 2,133 additions and 370 deletions.
19 changes: 16 additions & 3 deletions docs/api/jax.rst
@@ -6,10 +6,23 @@
Jax
=======

.. autoapiclass:: transformer_engine.jax.MajorShardingType
.. autoapiclass:: transformer_engine.jax.ShardingType
Pre-defined Variables of Logical Axes
--------------------------------------
These variables are available in `transformer_engine.jax.sharding`; a short usage sketch follows the list below.

* BATCH_AXES: The logical axis of the batch dimension. It is usually sharded along DP + FSDP on the Mesh.
* SEQLEN_AXES: The logical axis of the sequence-length dimension. It is usually not sharded.
* SEQLEN_TP_AXES: The logical axis of the sequence-length dimension. It is usually sharded along TP on the Mesh.
* HEAD_AXES: The logical axis of the head dimension of MHA. It is usually sharded along TP on the Mesh.
* HIDDEN_AXES: The logical axis of the hidden dimension. It is usually not sharded.
* HIDDEN_TP_AXES: The logical axis of the hidden dimension. It is usually sharded along TP on the Mesh.
* JOINED_AXES: The logical axis of unspecified dimensions. It is usually not sharded.
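
A minimal sketch (assuming a 2-D device mesh whose axis names are 'data' and 'model'; the names and device layout are illustrative) of how these logical axes are resolved, mirroring the encoder example in this commit::

    import numpy as np
    import jax
    import transformer_engine.jax as te

    # Illustrative 2-D mesh: the first axis backs DP (BATCH_AXES),
    # the second backs TP (SEQLEN_TP_AXES, HEAD_AXES, HIDDEN_TP_AXES).
    devices = np.asarray(jax.devices()).reshape(-1, 1)
    with jax.sharding.Mesh(devices, ('data', 'model')):
        # MeshResource tells TE which mesh axes back the logical axes listed above
        # (same call pattern as the encoder example: dp axis, tp axis, then None, None).
        with te.fp8_autocast(False, mesh_resource=te.MeshResource('data', 'model', None, None)):
            ...  # build and run transformer_engine.jax.flax modules here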


Modules
------------------------------------
.. autoapiclass:: transformer_engine.jax.flax.TransformerLayerType
.. autoapiclass:: transformer_engine.jax.ShardingResource(dp_resource=None, tp_resource=None)
.. autoapiclass:: transformer_engine.jax.MeshResource()


.. autoapifunction:: transformer_engine.jax.fp8_autocast
27 changes: 26 additions & 1 deletion examples/jax/encoder/test_model_parallel_encoder.py
Expand Up @@ -35,6 +35,7 @@
class Net(nn.Module):
"""NLP Encoder"""
num_embed: int
enable_seq_paral: bool

@nn.compact
def __call__(self, x, mask, disable_dropout=False):
@@ -50,11 +51,17 @@ def __call__(self, x, mask, disable_dropout=False):
layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type='padding',
enable_relative_embedding=False,
enable_sequence_parallel=self.enable_seq_paral,
dtype=jnp.bfloat16)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)

x = x.reshape(x.shape[0], -1)

if self.enable_seq_paral:
# Trigger an all-gather to collect the complete tensor along the sequence dimension on each device.
x = jax.lax.with_sharding_constraint(x,
jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None))

x = te_flax.DenseGeneral(features=256,
kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
bias_axes=(NAMED_TP_AXIS,),
@@ -266,7 +273,7 @@ def train_and_evaluate(args):
with te.fp8_autocast(args.use_fp8,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None,
None)):
encoder = Net(num_embed)
encoder = Net(num_embed, args.enable_sp)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
@@ -379,6 +386,10 @@ def encoder_parser(args):
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration")
parser.add_argument("--enable-sp",
action="store_true",
default=False,
help="Enable sequence parallelism.")

return parser.parse_args(args)

@@ -405,6 +416,20 @@ def test_te_fp8(self):
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79

def test_te_bf16_sp(self):
"""Test Transformer Engine with BF16 + SP"""
self.args.enable_sp = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79

@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8_sp(self):
"""Test Transformer Engine with FP8 + SP"""
self.args.enable_sp = True
self.args.use_fp8 = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79


if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
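
For reference, the new switch can be exercised from the command line as `python test_model_parallel_encoder.py --enable-sp`, optionally combined with the existing FP8 flag (presumably spelled `--use-fp8`, per `args.use_fp8` above); the two new unittest cases drive the same combinations programmatically.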
213 changes: 213 additions & 0 deletions tests/jax/test_custom_call_compute.py
@@ -14,13 +14,16 @@
from flax import linen as nn

from utils import assert_allclose
from transformer_engine.jax.cpp_extensions import dgelu, dgelu_dbias_cast_transpose
from transformer_engine.jax.cpp_extensions import gelu, gelu_fp8
from transformer_engine.jax.cpp_extensions import dgated_gelu, gated_gelu
from transformer_engine.jax.cpp_extensions import dgated_gelu_cast_transpose, gated_gelu_fp8
from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.mlp import layernorm_geglu_fp8_mlp
from transformer_engine.jax.mlp import layernorm_gelu_fp8_mlp

GEMM_CASES = [
(256, 256, 512),
@@ -285,6 +288,126 @@ def ref_func(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_
jnp.asarray(ref_s_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)

@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024),
(16384, 1024, 1024)])
def test_grad_ln_gelu_fp8_mlp(self, m, n, k):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 6)
activations = ('gelu',)

a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
k1 = jax.random.normal(subkeys[1], (k, len(activations), n), jnp.bfloat16)
k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16)
b1 = jax.random.normal(subkeys[3], (len(activations), n), jnp.bfloat16)
b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)

init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
init_fp8_metas_amax = jnp.zeros(
(FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_LEN), jnp.float32)
init_fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
init_fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)

def primitive_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv):
# x: 2-D input tensor; ln_s: RMSNorm scale
# y, z: 2-D kernels; w, v: biases
# out = gelu((rmsnorm(x) * ln_s) @ y + w) @ z + v
fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
return jnp.mean(
layernorm_gelu_fp8_mlp(x, ln_s, None, [y, z], [w, v], fp8_meta_pkg, "rmsnorm"))

def ln_gelu_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray,
kernel_2: jnp.ndarray, bias_1: jnp.ndarray, bias_2: jnp.ndarray,
fp8_maxs: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
scale_inv: jnp.ndarray) -> jnp.ndarray:

x = jnp.asarray(x, jnp.float32)
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
y = jnp.asarray(x * jax.lax.rsqrt(mean2 + 1e-6), jnp.bfloat16)
ln_out = y * ln_scale
ln_out = jnp.asarray(ln_out, jnp.bfloat16)

fp8_gemm_1_pkg = FP8MetaPackage(1, fp8_maxs[:FP8Helper.NUM_META_PER_GEMM],
amax[:FP8Helper.NUM_META_PER_GEMM],
scale[:FP8Helper.NUM_META_PER_GEMM],
scale_inv[:FP8Helper.NUM_META_PER_GEMM])
linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_gemm_1_pkg, ((1,), (0,)))

bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
linear_1_out += jnp.reshape(bias_1, bias_1_shape)

x = jax.nn.gelu(linear_1_out)
x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16)

fp8_gemm_2_pkg = FP8MetaPackage(1, fp8_maxs[FP8Helper.NUM_META_PER_GEMM:],
amax[FP8Helper.NUM_META_PER_GEMM:],
scale[FP8Helper.NUM_META_PER_GEMM:],
scale_inv[FP8Helper.NUM_META_PER_GEMM:])
output = type_safe_dot_general(x, kernel_2, fp8_gemm_2_pkg, ((1,), (0,)))

bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape
output += jnp.reshape(bias_2, bias_2_shape)

return output

def ref_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv):
return jnp.mean(
ln_gelu_fp8_mlp_ref(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv))

value_n_grad_primitive_func = jit(
value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))
value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))

ref_fp8_max = init_fp8_max
ref_fp8_metas_amax = init_fp8_metas_amax
ref_fp8_metas_scale = init_fp8_metas_scale
ref_fp8_metas_scale_inv = init_fp8_metas_scale_inv

pri_fp8_max = init_fp8_max
pri_fp8_metas_amax = init_fp8_metas_amax
pri_fp8_metas_scale = init_fp8_metas_scale
pri_fp8_metas_scale_inv = init_fp8_metas_scale_inv

for _ in range(3):
ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_b1_grad, ref_b2_grad,
ref_fp8_max, ref_fp8_metas_amax, ref_fp8_metas_scale,
ref_fp8_metas_scale_inv) = value_n_grad_ref_func(
a, s, k1, k2, b1, b2, ref_fp8_max, ref_fp8_metas_amax,
ref_fp8_metas_scale, ref_fp8_metas_scale_inv)

for _ in range(3):
primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
primitive_k2_grad, primitive_b1_grad, primitive_b2_grad, pri_fp8_max,
pri_fp8_metas_amax, pri_fp8_metas_scale,
pri_fp8_metas_scale_inv) = value_n_grad_primitive_func(
a, s, k1, k2, b1, b2, pri_fp8_max, pri_fp8_metas_amax,
pri_fp8_metas_scale, pri_fp8_metas_scale_inv)

assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(jnp.asarray(primitive_a_grad, np.float32),
jnp.asarray(ref_a_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_k1_grad, np.float32),
jnp.asarray(ref_k1_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
jnp.asarray(ref_k2_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
jnp.asarray(ref_s_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_b1_grad, np.float32),
jnp.asarray(ref_b1_grad, np.float32),
dtype=jnp.bfloat16)
assert_allclose(jnp.asarray(primitive_b2_grad, np.float32),
jnp.asarray(ref_b2_grad, np.float32),
dtype=jnp.bfloat16)


@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
@@ -294,6 +417,96 @@ def random_inputs_fixture(shape):
return out


class TestGeLu:

def ref_func(self, inputs):

func = jit(value_and_grad(lambda x: jnp.mean(jax.nn.gelu(x))))
return func(inputs)

def prim_func(self, inputs):

@jax.custom_vjp
def primitive(x):
out, _ = primitive_fwd(x)
return out

def primitive_fwd(x):
out = gelu(x)
ctx = x
return out, ctx

def primitive_bwd(ctx, g):
x = ctx
out = dgelu(g, x)
return (out,)

primitive.defvjp(primitive_fwd, primitive_bwd)
func = value_and_grad(lambda x: jnp.mean(primitive(x)))
return func(inputs)

@pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
def test_gelu(self, random_inputs):
x = random_inputs
prim_out, prim_grad = self.prim_func(x)
ref_out, ref_grad = self.ref_func(x)

assert_allclose(prim_out, ref_out, dtype=x.dtype)
assert_allclose(prim_grad, ref_grad, dtype=x.dtype)


class TestGeLuFP8(TestGeLu):

def prim_func(self, inputs):
amax = self.amax
scale = self.scale
scale_inv = self.scale_inv
no_use = jnp.zeros(1, jnp.float32)

@jax.custom_vjp
def primitive(x, y, z, w):
out, _ = primitive_fwd(x, y, z, w)
return out

def primitive_fwd(x, y, z, w):
out, _ = gelu_fp8(x, amax, scale, scale_inv, jnp.float8_e4m3fn)
out = dequantize(out, x.dtype, scale_inv)
ctx = x
return out, ctx

def primitive_bwd(ctx, g):
x = ctx
dgelu, dgelu_trans, dbias, amax_out = dgelu_dbias_cast_transpose(
g, x, amax, scale, scale_inv, jnp.float8_e5m2, -1)
dgelu = dequantize(dgelu, x.dtype, scale_inv)
dgelu_trans = dequantize(dgelu_trans, x.dtype, scale_inv)
return dgelu, dgelu_trans, dbias, amax_out

primitive.defvjp(primitive_fwd, primitive_bwd)
func = value_and_grad(lambda x, y, z, w: jnp.mean(primitive(x, y, z, w)), (0, 1, 2, 3))

return func(inputs, no_use, no_use, no_use)

@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)])
def test_gelu(self, random_inputs):
self.amax = jnp.zeros(1, jnp.float32)
self.scale = jnp.ones(1, jnp.float32)
self.scale_inv = jnp.ones(1, jnp.float32)

x = random_inputs
prim_out, (prim_grad, prim_grad_trans, dbias, amax) = self.prim_func(x)
ref_out, ref_grad = self.ref_func(x)

assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
assert_allclose(dbias, jnp.sum(ref_grad, axis=(i for i in range(x.ndim - 1))))
assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
assert_allclose(prim_grad_trans,
jnp.transpose(ref_grad, (2, 0, 1)),
dtype=FP8Helper.BWD_DTYPE)


class TestGatedGeLu:

def ref_func(self, inputs):
21 changes: 20 additions & 1 deletion tests/jax/test_layer.py
@@ -88,6 +88,7 @@ def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
_KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits"
_KEY_OF_NUM_HEADS = 'num_attention_heads'
_KEY_OF_NUM_GQA_GROUPS = 'num_gqa_groups'
_KEY_OF_ENABLE_ROPE = "enable_rotary_pos_emb"

BASE_ATTRS = {
_KEY_OF_TRANSPOSE_BS: True,
@@ -137,7 +138,25 @@ def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
_KEY_OF_FUSE_MLP_WI: True
}, {
_KEY_OF_NUM_HEADS: 8,
_KEY_OF_NUM_GQA_GROUPS: 4
_KEY_OF_NUM_GQA_GROUPS: 4,
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_MLP_ACTIVATIONS: (('gelu',)),
_KEY_OF_FUSE_MLP_WI: True
}, {
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_FUSE_MLP_WI: True,
_KEY_OF_ENABLE_ROPE: True
}, {
_KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_FUSE_MLP_WI: True,
_KEY_OF_ENABLE_ROPE: True
}]

ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
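
The new ATTRS entries exercise `enable_rotary_pos_emb` under both settings of `_KEY_OF_TRANSPOSE_BS`. A minimal, hypothetical sketch of enabling the options added in this PR on the Flax layer (only kwargs that appear in this commit are shown; all other constructor arguments are assumed to keep their defaults):

    import jax.numpy as jnp
    import transformer_engine.jax.flax as te_flax

    # Hypothetical instantiation; RoPE and SP are the options added by this PR.
    layer = te_flax.TransformerLayer(
        layer_type=te_flax.TransformerLayerType.ENCODER,
        self_attn_mask_type='padding',
        enable_rotary_pos_emb=True,       # RoPE
        enable_sequence_parallel=True,    # sequence parallelism
        dtype=jnp.bfloat16)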
