diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 20b16c2809..9bf3f9fa91 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -25,6 +25,7 @@ _jax_dbias_cast_transpose, ) from transformer_engine.jax.cpp_extensions.quantization import _jax_cast_fp8 +from transformer_engine.jax.gemm import fp8_gemm, gemm from transformer_engine.jax import cpp_extensions as tex @@ -415,6 +416,60 @@ def ref_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_ ) +class TestGemm: + + @staticmethod + def _generate_inputs(b, m, n, k, dtype): + key = jax.random.PRNGKey(0) + subkeys = jax.random.split(key, 3) + a = jax.random.normal(subkeys[0], (b, m, k), dtype) + b = jax.random.normal(subkeys[1], (n, k), dtype) + bias_dtype = dtype if dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2] else jnp.bfloat16 + bias = jax.random.normal(subkeys[2], (n, ), bias_dtype) + return a, b, bias + + @staticmethod + def _generate_fp8_inputs(b, m, n, k, fp8_dtype): + a, b, bias = TestGemm._generate_inputs(b, m, n, k, jnp.bfloat16) + a_scale, b_scale = map( + lambda x: (jnp.max(jnp.abs(x)) / 127.).astype(jnp.float32), + [a, b] + ) + a_q, b_q = map( + lambda x, x_scale: jnp.round(x / x_scale).astype(fp8_dtype), + [(a, a_scale), (b, b_scale)] + ) + return a, a_q, jnp.reciprocal(a_scale), b, b_q, jnp.reciprocal(b_scale), bias + + @pytest.mark.parametrize("m,n,k", GEMM_CASES) + @pytest.mark.parametrize("use_bias", (False, True)) + @pytest.mark.parametrize("do_gelu", (False, True)) + def test_gemm(self, b, m, n, k, use_bias, do_gelu): + a, b, bias = self._generate_inputs(b, m, n, k, jnp.bfloat16) + + primitive_out = gemm(a, b, bias=bias if use_bias else None, layout='NT', do_gelu=do_gelu) + ref_out = jnp.dot(a, b) + if use_bias: + ref_out += bias + if do_gelu: + ref_out = jax.nn.gelu(ref_out) + + assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) + + @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest.mark.parametrize("m,n,k", GEMM_CASES) + @pytest.mark.parametrize("fp8_dtype", FP8_COMPUTE_TYPE) + def test_fp8_gemm(self, m, n, k, fp8_dtype): + a, a_q, a_scale_inv, b, b_q, b_scale_inv, _ = self._generate_fp8_inputs( + m, n, k, fp8_dtype + ) + + primitive_out = fp8_gemm(a_q, a_scale_inv, b_q, b_scale_inv, out_dtype=jnp.bfloat16) + ref_out = jnp.dot(a, b) + + assert_allclose(primitive_out, ref_out, dtype=fp8_dtype) + + @pytest.fixture(name="random_inputs") def random_inputs_fixture(shape): key = jax.random.PRNGKey(0) diff --git a/transformer_engine/jax/cpp_extensions/__init__.py b/transformer_engine/jax/cpp_extensions/__init__.py index 579daa8e41..1e5cc4c07e 100644 --- a/transformer_engine/jax/cpp_extensions/__init__.py +++ b/transformer_engine/jax/cpp_extensions/__init__.py @@ -4,6 +4,7 @@ """Python interface for c++ extensions""" from .activation import * from .attention import * +from .gemm import * from .normalization import * from .quantization import * from .softmax import * diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py new file mode 100644 index 0000000000..677fabca59 --- /dev/null +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -0,0 +1,647 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+"""JAX/TE custom ops for cuBlasLt GEMM""" +import warnings +import operator +from functools import reduce +from typing import Optional, Union, Tuple + +import jax +import jax.numpy as jnp +from jax import dtypes +from jax.interpreters import mlir +from jax.interpreters.mlir import ir +from jax.sharding import PartitionSpec, NamedSharding +from jax.extend import ffi +from jax.typing import ArrayLike + +from transformer_engine import transformer_engine_jax as tex +from .base import BasePrimitive, register_primitive +from .custom_call import custom_caller, CustomCallArgsWrapper +from .misc import ( + jax_dtype_to_te_dtype, + jax_dtype_is_fp8, + get_padded_spec, + is_ffi_enabled, +) +from ..sharding import ( + global_mesh_resource, + get_mesh_axis_size, + lax_paral_op, + all_reduce_max_along_all_axes_except_PP, +) + + +__all__ = [ + "fp8_gemm_impl", + "gemm_impl", +] + + +def get_cublas_workspace_size_bytes() -> None: + """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" + if tex.get_device_compute_capability() >= 90: + return 33_554_432 + return 4_194_304 + + +class CollectiveGemmPrimitive(BasePrimitive): + """ + cuBlasLt GEMM Primitive w/ support for distributed inputs + """ + + name = "te_gemm" + impl_static_args = (8, 9, 10, 11, 12, 13, 14) + multiple_results = True + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_aval, + gelu_input_aval, out_amax_aval, out_scale_aval, out_dtype, contracting_dims, + fuse_gelu, fuse_bias, grad, accumulate, use_split_accumulator): + """ + cuBlasLt GEMM abstract + """ + del grad, accumulate, use_split_accumulator + + # Validate operand dtypes + lhs_dtype = dtypes.canonicalize_dtype(lhs_aval.dtype) + rhs_dtype = dtypes.canonicalize_dtype(rhs_aval.dtype) + assert lhs_dtype == rhs_dtype, "Mismatched matrix dtypes for GEMM." + is_fp8 = False + if jax_dtype_is_fp8(lhs_dtype): + assert ( + lhs_scale_inv_aval.size == 1 + and dtypes.canonicalize_dtype(lhs_scale_inv_aval.dtype) == jnp.float32 + ), "Missing LHS operand scale inverse in FP8 GEMM." + is_fp8 = True + if jax_dtype_is_fp8(rhs_dtype): + assert ( + rhs_scale_inv_aval.size == 1 + and dtypes.canonicalize_dtype(rhs_scale_inv_aval.dtype) == jnp.float32 + ), "Missing RHS operand scale inverse in FP8 GEMM." + + # Disallow batching for RHS + assert rhs_aval.ndim == 2, "GEMM does not support batching the RHS operand." + + # Validate operand layouts + lhs_inner_dim, rhs_inner_dim = map( + lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, + contracting_dims, + (lhs_aval.ndim, rhs_aval.ndim) + ) + assert ( + lhs_aval.shape[lhs_inner_dim] == rhs_aval.shape[rhs_inner_dim] + ), f"Incompatible operand sizes: {lhs_aval.shape} x {rhs_aval.shape}." + + lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 + rhs_trans = rhs_inner_dim == 1 + assert ( + not (lhs_trans and rhs_trans) + ), "GEMM does not support transposed LHS and transposed RHS at the same time." + if is_fp8: + assert lhs_trans, "FP8 GEMM does not support transposed LHS." + assert rhs_trans, "FP8 GEMM requires transposed RHS." + + # Validate output dtype + if jax_dtype_is_fp8(out_dtype): + assert ( + jax_dtype_is_fp8(lhs_dtype) and jax_dtype_is_fp8(rhs_dtype) + ), "FP8 GEMM output requires FP8 inputs." + assert ( + out_amax_aval.size == out_scale_aval.size == 1 + ), "Invalid/missing output amax and scale." 
+ out_amax_updated_dtype = dtypes.canonicalize_dtype(out_amax_aval.dtype) + out_scale_updated_dtype = dtypes.canonicalize_dtype(out_scale_aval.dtype) + assert ( + out_amax_updated_dtype == out_scale_updated_dtype == jnp.float32 + ), "Invalid output amax or scale dtype." + else: + out_dtype = lhs_dtype + out_amax_updated_dtype = jnp.float32 + out_scale_updated_dtype = jnp.float32 + + # Infer output shape + rhs_outer_dim = 0 if rhs_trans else 1 + lhs_outer_dim = lhs_aval.ndim - 1 if lhs_trans else lhs_aval.ndim - 2 + lhs_bdims = [dim for dim in range(lhs_aval.ndim) + if dim not in [lhs_outer_dim, lhs_inner_dim]] + lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims] + out_shape = (*lhs_batch_shape, lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim]) + + # Validate bias/bias_grad shape against inferred output + bias_dtype = jnp.bfloat16 if jax_dtype_is_fp8(out_dtype) else out_dtype + if fuse_bias: + assert ( + bias_aval.size > 0 + and bias_aval.ndim == 1 + and bias_aval.shape[0] == out_shape[-1] + ), "Incorrect bias shape." + bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) + else: + assert bias_aval.size == 0, "Internal TE error." + + # Validate GELU input/output + if fuse_gelu: + assert ( + all([gelu_input_aval.shape[i] == out_shape[i] for i in len(out_shape)]) + ), "Invalid GELU input shape." + assert gelu_input_aval.dtype == bias_dtype, "Invalid GELU dtype." + else: + assert gelu_input_aval.size == 0, "Internal TE error." + + # Create abstract arrays for all outputs + out_aval = lhs_aval.update(shape=out_shape, dtype=out_dtype) + out_amax_updated_aval = out_amax_aval.update(shape=out_amax_aval.shape, + dtype=out_amax_updated_dtype) + out_scale_updated_aval = out_scale_aval.update(shape=out_scale_aval.shape, + dtype=out_scale_updated_dtype) + pre_gelu_out_aval = gelu_input_aval.update(shape=gelu_input_aval.shape, dtype=bias_dtype) + bias_grad_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) + workspace_aval = jax.core.ShapedArray(shape=(get_cublas_workspace_size_bytes(), ), + dtype=jnp.uint8) + + return ( + out_aval, + out_amax_updated_aval, + out_scale_updated_aval, + pre_gelu_out_aval, + bias_grad_aval, + workspace_aval + ) + + @staticmethod + def outer_abstract(*args, **kwargs): + """ + cuBlasLt GEMM outer abstract + """ + ( + out_aval, + out_amax_aval, + out_scale_aval, + pre_gelu_out_aval, + bias_grad_aval, + _ + ) = CollectiveGemmPrimitive.abstract(*args, **kwargs) + return out_aval, out_amax_aval, out_scale_aval, pre_gelu_out_aval, bias_grad_aval + + @staticmethod + def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out_scale, + *, out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulate, + use_split_accumulator): + """ + Fused attention fwd lowering rules + """ + lhs_aval, _, rhs_aval, _, bias_aval, *_ = ctx.avals_in + lhs_inner_dim, rhs_inner_dim = map( + lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, + contracting_dims, + (lhs_aval.ndim, rhs_aval.ndim) + ) + lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 + rhs_trans = rhs_inner_dim == 1 + + operand_output_aliases = { + 4: 4, # bias <--> bias_grad + 5: 3, # gelu_input <--> pre_gelu_out + 6: 1, # out_amax <--> out_amax_updated + 7: 2, # out_scale <--> out_scale_updated + } + + if is_ffi_enabled(): + name = "te_gemm_ffi" + return ffi.ffi_lowering(name, operand_output_aliases=operand_output_aliases)( + ctx, + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + 
lhs_trans=lhs_trans, + rhs_trans=rhs_trans, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator + ) + else: + operands = [ + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + ] + operand_shapes = map(lambda x: ir.RankedTensorType(x.type).shape, operands) + out_types = [ + ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_dtype(output.dtype)) + for output in ctx.avals_out + ] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + rhs_outer_dim = 0 if rhs_trans else 1 + lhs_outer_dim = lhs_aval.ndim - 1 if lhs_trans else lhs_aval.ndim - 2 + lhs_bdims = [dim for dim in range(lhs_aval.ndim) + if dim not in [lhs_outer_dim, lhs_inner_dim]] + lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims] + m = reduce(operator.mul, lhs_batch_shape, 1) * lhs_aval.shape[lhs_outer_dim] + k = rhs_aval.shape[rhs_inner_dim] + n = rhs_aval.shape[rhs_outer_dim] + workspace_size = get_cublas_workspace_size_bytes() + operand_dtype = jax_dtype_to_te_dtype(lhs_aval.dtype) + bias_dtype = jax_dtype_to_te_dtype(bias_aval.dtype) + opaque = tex.pack_gemm_descriptor(m, n, k, workspace_size, operand_dtype, + jax_dtype_to_te_dtype(out_dtype), bias_dtype, + lhs_trans, rhs_trans, fuse_gelu, fuse_bias, grad, + accumulate, use_split_accumulator) + + return custom_caller( + CollectiveGemmPrimitive.name, + args, + opaque, + has_side_effect=False, + operand_output_aliases=operand_output_aliases, + ) + + @staticmethod + def impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out_scale, + out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulate, + use_split_accumulator): + assert CollectiveGemmPrimitive.inner_primitive is not None + + ( + out, + out_amax_updated, + out_scale_updated, + pre_gelu_out, + bias_grad, + _, + ) = CollectiveGemmPrimitive.inner_primitive.bind( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + out_dtype=out_dtype, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad + + @staticmethod + def batcher(batched_args, batch_dims, *, out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, + accumulate, use_split_accumulator): + assert CollectiveGemmPrimitive.outer_primitive is not None + + lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out_scale = batched_args + assert rhs.ndim == 2, "TE/JAX GEMM custom op does not support batching RHS operands." 
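+        # Collapse all LHS batch dimensions into the GEMM M dimension, run the 2D primitive
+        # once, and restore the original batch shape on the outputs afterwards.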
+ + # Get contracting and batch dimensions out + lhs_inner_dim, rhs_inner_dim = map( + lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, + contracting_dims, + (lhs.ndim, rhs.ndim) + ) + lhs_trans = lhs_inner_dim != lhs.ndim - 1 + rhs_trans = rhs_inner_dim == 1 + lhs_outer_dim = lhs.ndim - 1 if lhs_trans else lhs.ndim - 2 + rhs_outer_dim = 0 if rhs_trans else 1 + lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] + + # FP8 GEMM only supports lhs_trans = False and rhs_trans = True so we may need to + # reorder the axes here to match + if jax_dtype_is_fp8(lhs.dtype): + lhs = jnp.transpose(lhs, (*lhs_bdims, lhs_outer_dim, lhs_inner_dim)) + lhs_trans = False + rhs = jnp.transpose(rhs, (rhs_outer_dim, rhs_inner_dim)) + rhs_trans = True + contracting_dims = (1, 1) + + # Collapse all non-contracting dimensions + batch_shape = [lhs.shape[dim] for dim in lhs_bdims] + batch_size = reduce(operator.mul, batch_shape, 1) + lhs_outer_size = lhs.shape[lhs_outer_dim] + lhs_shape_2d = ( + (lhs.shape[lhs_inner_dim], batch_size * lhs_outer_size) + if lhs_trans + else (batch_size * lhs_outer_size, lhs.shape[lhs_inner_dim]) + ) + lhs = jnp.reshape(lhs, lhs_shape_2d) + if fuse_gelu: + gelu_input = jnp.reshape( + gelu_input, (batch_size * lhs_outer_size, rhs.shape[rhs_outer_dim]) + ) + + outputs = CollectiveGemmPrimitive.outer_primitive.bind( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + out_dtype=out_dtype, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + # Reshape output to recover original LHS batch shape + outputs[0] = jnp.reshape( + outputs[0], + (*batch_shape, lhs_outer_size, rhs.shape[rhs_outer_dim]) + ) + gelu_bdims = batch_dims[3] + if fuse_gelu: + outputs[3] = jnp.reshape(outputs[3], outputs[0].shape) + gelu_bdims = lhs_bdims + + return ( + outputs, + (lhs_bdims, batch_dims[1], batch_dims[2], gelu_bdims, batch_dims[4]) + ) + + @staticmethod + def infer_sharding_from_operands(out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, + accumulate, use_split_accumulator, mesh, arg_infos, + result_infos): + del out_dtype, accumulate, use_split_accumulator, result_infos + lhs, _, rhs, *_ = arg_infos + lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) + + lhs_inner_dim, rhs_inner_dim = map( + lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, + contracting_dims, + (lhs.ndim, rhs.ndim) + ) + if lhs_spec[lhs_inner_dim] != rhs_spec[rhs_inner_dim] and not grad: + warnings.warn("Forcing the inner dimension of LHS to match the sharding of inner " + + "dimension of RHS. 
This can trigger additional communication if LHS is " + + "not already partitioned correctly.") + + lhs_trans = lhs_inner_dim != lhs.ndim - 1 + rhs_trans = rhs_inner_dim == 1 + lhs_outer_dim = lhs.ndim - 1 if lhs_trans else lhs.ndim - 2 + rhs_outer_dim = 0 if rhs_trans else 1 + lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] + batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims] + rhs_outer_spec = rhs_spec[rhs_outer_dim] + + if rhs_spec[rhs_inner_dim] is not None and rhs_outer_spec is not None: + raise RuntimeError("Both inner and outer dimensions of RHS cannot be sharded.") + + # Outer (sequence) dimension of the GEMM output is always unsharded + out_spec = [*batch_specs, None, rhs_outer_spec] + out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) + + # FP8 metas are always unsharded + fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None)) + + # Pre-GELU output matches output spec if GELU fusion is turned on, otherwise unsharded + gelu_spec = out_spec if fuse_gelu else [None] + gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) + + # Bias gradient spec matches outer dimension of output if bias fusion is turned on + bias_sharding = NamedSharding(mesh, PartitionSpec(rhs_outer_spec if fuse_bias else None)) + + return (out_sharding, fp8_meta_sharding, fp8_meta_sharding, gelu_sharding, bias_sharding) + + @staticmethod + def partition(out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulate, + use_split_accumulator, mesh, arg_infos, result_infos): + del result_infos + lhs, _, rhs, *_ = arg_infos + lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) + + lhs_inner_dim, rhs_inner_dim = map( + lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, + contracting_dims, + (lhs.ndim, rhs.ndim) + ) + + lhs_trans = lhs_inner_dim != lhs.ndim - 1 + rhs_trans = rhs_inner_dim == 1 + lhs_outer_dim = lhs.ndim - 1 if lhs_trans else lhs.ndim - 2 + rhs_outer_dim = 0 if rhs_trans else 1 + lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] + batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims] + rhs_outer_spec = rhs_spec[rhs_outer_dim] + + # Force all-gather the outer (sequence) dimension of the LHS operand + lhs_spec_new = [spec for spec in lhs_spec] + lhs_spec_new[lhs_outer_dim] = None + lhs_spec_new[lhs_inner_dim] = rhs_spec[rhs_inner_dim] + lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_spec_new)) + + # RHS operand is unchanged, we already enforce that only one dimension can be sharded + rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_spec)) + + # Bias is sharded to match outer dimension spec of the RHS operand (also the output) + bias_sharding = NamedSharding(mesh, PartitionSpec(rhs_outer_spec if fuse_bias else None)) + + # FP8 metas are always unsharded + fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None)) + + # Outer (sequence) dimension of the GEMM output is always unsharded + out_spec = [*batch_specs, None, rhs_outer_spec] + out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) + + # Pre-GELU output matches output spec if GELU fusion is turned on, otherwise unsharded + gelu_spec = out_spec if fuse_gelu else [None] + gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) + + arg_shardings = (lhs_sharding, fp8_meta_sharding, rhs_sharding, fp8_meta_sharding, + bias_sharding, gelu_sharding, fp8_meta_sharding, fp8_meta_sharding) + out_shardings = (out_sharding, fp8_meta_sharding, fp8_meta_sharding, gelu_sharding, + bias_sharding) + + 
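        # The sharded implementation below runs the local GEMM on each device's shard and then
+        # applies the collectives implied by the shardings above: an all-reduce of the FP8 amax
+        # and a psum over the contracting dimension when that dimension is sharded.
+        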
def sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, + out_scale): + ( + out, + out_amax_updated, + out_scale_updated, + pre_gelu_out, + bias_grad, + ) = CollectiveGemmPrimitive.impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + out_dtype=out_dtype, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + # FP8 amax reduction + if jax_dtype_is_fp8(lhs.dtype): + out_amax_updated = all_reduce_max_along_all_axes_except_PP(out_amax_updated, mesh) + + if rhs_spec[rhs_inner_dim] is not None: + # GEMM output needs to be all-reduced when the contracting dimension is sharded. + # If the layer is sequence-parallel, we also need to scatter the output, which we + # can combine into a reduce-scatter here. + out = lax_paral_op(out, jax.lax.psum, global_mesh_resource().cp_resource, + mesh) + if fuse_gelu: + pre_gelu_out = lax_paral_op( + pre_gelu_out, jax.lax.psum, global_mesh_resource().cp_resource, mesh + ) + + return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad + + return mesh, sharded_impl, out_shardings, arg_shardings + + +register_primitive(CollectiveGemmPrimitive) + + +def fp8_gemm_impl( + lhs: ArrayLike, + lhs_scale_inv: ArrayLike, + rhs: ArrayLike, + rhs_scale_inv: ArrayLike, + bias: Optional[ArrayLike] = None, + gelu_input: Optional[ArrayLike] = None, + out_amax: Optional[ArrayLike] = None, + out_scale: Optional[ArrayLike] = None, + out_dtype: jnp.dtype = jnp.bfloat16, + contracting_dims: Tuple[int, int] = (1, 1), + fuse_gelu: bool = False, + fuse_bias: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, +) -> Tuple[ArrayLike, ...]: + """FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" + if out_dtype is not None and jax_dtype_is_fp8(out_dtype): + assert out_amax is not None and out_scale is not None, "Missing output amax and scale." + else: + out_amax = jnp.zeros(0, dtype=jnp.float32) + out_scale = jnp.zeros(0, dtype=jnp.float32) + + if not fuse_bias: + bias = jnp.zeros(0, dtype=jnp.bfloat16) + else: + assert ( + bias is not None + ), "Missing bias in forward GEMM when bias epilogue is enabled." 
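+    # When the GELU epilogue is enabled but no pre-GELU buffer is provided, allocate one with
+    # the expected output shape so the custom call has a buffer to write the pre-GELU output to.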
+ + if not fuse_gelu: + gelu_input = jnp.zeros(0, dtype=bias.dtype) + elif gelu_input is None: + lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2 + rhs_outer_dim = 1 if contracting_dims[1] == 0 else 0 + out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) + gelu_input = jnp.zeros(out_shape, dtype=bias.dtype) + + out, out_amax, out_scale, pre_gelu_out, _ = CollectiveGemmPrimitive.outer_primitive.bind( + rhs, + rhs_scale_inv, + lhs, + lhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + out_dtype=out_dtype, + contracting_dims=tuple(reversed(contracting_dims)), + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=False, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + return out, out_amax, out_scale, pre_gelu_out + + +def gemm_impl( + lhs: ArrayLike, + rhs: ArrayLike, + bias: Optional[ArrayLike] = None, + gelu_input: Optional[ArrayLike] = None, + contracting_dims: Tuple[int, int] = (1, 0), + fuse_gelu: bool = False, + fuse_bias: bool = False, + grad: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, +) -> Tuple[ArrayLike, ...]: + """Non-FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" + dummy_fp8_meta = jnp.zeros(0, dtype=jnp.float32) + + lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2 + rhs_outer_dim = 1 if contracting_dims[1] == 0 else 0 + out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) + + if not fuse_bias: + bias = jnp.zeros(0, dtype=lhs.dtype) + elif grad: + bias = jnp.zeros(out_shape[-1], dtype=lhs.dtype) + else: + assert ( + bias is not None + ), "Missing bias in forward GEMM when bias epilogue is enabled." + + if not fuse_gelu: + gelu_input = jnp.zeros(0, dtype=lhs.dtype) + elif grad: + assert ( + gelu_input is not None + ), "Backward GEMM with dGELU epilogue requires pre-GELU output from forward GEMM." + elif gelu_input is None: + lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2 + rhs_outer_dim = 1 if contracting_dims[1] == 0 else 0 + out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) + gelu_input = jnp.zeros(out_shape, dtype=lhs.dtypes) + + out, _, _, pre_gelu_out, bias_grad = CollectiveGemmPrimitive.outer_primitive.bind( + lhs, + dummy_fp8_meta, + rhs, + dummy_fp8_meta, + bias, + gelu_input, + dummy_fp8_meta, + dummy_fp8_meta, + out_dtype=lhs.dtype, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + if grad: + return out, pre_gelu_out, bias_grad + else: + return out, pre_gelu_out diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 1f13484b98..15d7537fbd 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -81,6 +81,13 @@ def jax_dtype_to_te_dtype(jax_dtype): return converter.get(jax_dtype) +def jax_dtype_is_fp8(dtype): + """ + Check if the given jax.numpy.dtype is an FP8 dtype. 
+ """ + return dtypes.canonicalize_dtype(dtype) in [jnp.float8_e4m3fn, jnp.float8_e5m2] + + def get_padded_spec(arg_info): """ Get padded spec for partitioning from arguments' information diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 02e6aaf9d5..afac283a6f 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -147,6 +147,31 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor( NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right); +struct CustomCallGemmDescriptor { + size_t batch; + size_t m; + size_t k; + size_t n; + size_t workspace_size; + DType operand_dtype; + DType bias_dtype; + DType out_dtype; + bool lhs_trans; + bool rhs_trans; + bool fuse_gelu; + bool fuse_bias; + bool grad; + bool accumulate; + bool use_split_accumulator; +}; + +pybind11::bytes PackCustomCallGemmDescriptor(size_t batch, size_t m, size_t n, size_t k, + size_t workspace_size, DType operand_dtype, + DType out_dtype, DType bias_dtype, bool lhs_trans, + bool rhs_trans, bool fuse_gelu, bool fuse_bias, + bool grad, bool accumulate, + bool use_split_accumulator); + // Transpose void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); @@ -308,6 +333,20 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler); +// GEMM + +void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); + +Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, + Buffer_Type out_amax, Buffer_Type out_scale, Result_Type out, + Result_Type out_amax_updated, Result_Type out_scale_updated, + Result_Type pre_gelu_out, Result_Type bias_grad, Result_Type workspace, + bool lhs_trans, bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, + bool accumulate, bool use_split_accumulator); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp new file mode 100644 index 0000000000..f60ae510df --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -0,0 +1,170 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/
+
+#include "transformer_engine/gemm.h"
+
+#include "common/util/cuda_runtime.h"
+#include "common/util/system.h"
+#include "extensions.h"
+
+namespace transformer_engine {
+
+namespace jax {
+
+void GemmImpl(cudaStream_t stream, void *lhs, const std::vector<size_t> &lhs_shape,
+              float *lhs_scale_inv, bool lhs_trans, void *rhs, const std::vector<size_t> &rhs_shape,
+              float *rhs_scale_inv, bool rhs_trans, DType operand_dtype, void *bias,
+              DType bias_dtype, void *out, float *out_amax, float *out_scale, DType out_dtype,
+              void *pre_gelu_out, void *workspace, size_t workspace_size, bool fuse_gelu,
+              bool fuse_bias, bool grad, bool accumulate, bool use_split_accumulator) {
+  auto lhs_ = TensorWrapper(lhs, lhs_shape, operand_dtype, nullptr, nullptr, lhs_scale_inv);
+  auto rhs_ = TensorWrapper(rhs, rhs_shape, operand_dtype, nullptr, nullptr, rhs_scale_inv);
+
+  std::vector<size_t> out_shape(2, 0);
+  out_shape[0] = (lhs_trans) ? lhs_shape[1] : lhs_shape[0];
+  out_shape[1] = (rhs_trans) ? rhs_shape[0] : rhs_shape[1];
+  auto out_ = TensorWrapper(out, out_shape, out_dtype, out_amax, out_scale, nullptr);
+
+  void *bias_ptr = (fuse_bias) ? bias : nullptr;
+  std::vector<size_t> bias_shape =
+      (fuse_bias) ? std::vector<size_t>{out_shape[1]} : std::vector<size_t>{0};
+  auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype);
+
+  void *pre_gelu_ptr = (fuse_gelu) ? pre_gelu_out : nullptr;
+  std::vector<size_t> pre_gelu_shape = (fuse_gelu) ? out_shape : std::vector<size_t>{0};
+  auto pre_gelu_out_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, bias_dtype);
+  auto workspace_ = TensorWrapper(workspace, std::vector<size_t>{workspace_size}, DType::kByte);
+
+  // cuBLAS is column-major, so we swap LHS and RHS in the arguments
+  auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0);
+  nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_out_.data(),
+                   (rhs_trans) ? CUBLAS_OP_T : CUBLAS_OP_N, (lhs_trans) ? CUBLAS_OP_T : CUBLAS_OP_N,
+                   grad, workspace_.data(), accumulate, use_split_accumulator, num_math_sm, stream);
+}
+
+void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
+  // Inputs
+  auto *lhs = buffers[0];
+  auto *lhs_scale_inv = reinterpret_cast<float *>(buffers[1]);
+  auto *rhs = buffers[2];
+  auto *rhs_scale_inv = reinterpret_cast<float *>(buffers[3]);
+  auto *bias = buffers[4];
+  auto *gelu_input = buffers[5];
+  auto *out_amax = reinterpret_cast<float *>(buffers[6]);
+  auto *out_scale = reinterpret_cast<float *>(buffers[7]);
+
+  // Outputs
+  auto *out = buffers[8];
+  auto *out_amax_updated = reinterpret_cast<float *>(buffers[9]);
+  auto *out_scale_updated = reinterpret_cast<float *>(buffers[10]);
+  auto *pre_gelu_out = buffers[11];
+  auto *bias_grad = buffers[12];
+  auto *workspace = buffers[13];
+
+  // Operand aliasing
+  NVTE_CHECK(bias == bias_grad,
+             "bias not bound to bias_grad in TE/JAX GEMM");
+  NVTE_CHECK(gelu_input == pre_gelu_out,
+             "gelu_input not bound to pre_gelu_out in TE/JAX GEMM");
+  NVTE_CHECK(out_amax == out_amax_updated,
+             "out_amax not bound to out_amax_updated in TE/JAX GEMM");
+  NVTE_CHECK(out_scale == out_scale_updated,
+             "out_scale not bound to out_scale_updated in TE/JAX GEMM");
+
+  // GEMM sizing
+  const auto &desc = *UnpackOpaque<CustomCallGemmDescriptor>(opaque, opaque_len);
+  std::vector<size_t> lhs_shape = {(desc.lhs_trans) ? desc.k : desc.m,
+                                   (desc.lhs_trans) ? desc.m : desc.k};
+  std::vector<size_t> rhs_shape = {(desc.rhs_trans) ? desc.n : desc.k,
+                                   (desc.rhs_trans) ? desc.k : desc.n};
+
+  GemmImpl(stream, lhs, lhs_shape, lhs_scale_inv, desc.lhs_trans, rhs, rhs_shape, rhs_scale_inv,
+           desc.rhs_trans, desc.operand_dtype, bias, desc.bias_dtype, out, out_amax, out_scale,
+           desc.out_dtype, pre_gelu_out, workspace, desc.workspace_size, desc.fuse_gelu,
+           desc.fuse_bias, desc.grad, desc.accumulate, desc.use_split_accumulator);
+}
+
+Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
+                   Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input,
+                   Buffer_Type out_amax, Buffer_Type out_scale, Result_Type out,
+                   Result_Type out_amax_updated, Result_Type out_scale_updated,
+                   Result_Type pre_gelu_out, Result_Type bias_grad, Result_Type workspace,
+                   bool lhs_trans, bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad,
+                   bool accumulate, bool use_split_accumulator) {
+  // Inputs
+  auto lhs_ptr = lhs.untyped_data();
+  auto lhs_scale_inv_ptr = reinterpret_cast<float *>(lhs_scale_inv.untyped_data());
+  auto rhs_ptr = rhs.untyped_data();
+  auto rhs_scale_inv_ptr = reinterpret_cast<float *>(rhs_scale_inv.untyped_data());
+  auto operand_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type());
+  auto bias_ptr = bias.untyped_data();
+  auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type());
+  auto gelu_input_ptr = gelu_input.untyped_data();
+  auto out_amax_ptr = reinterpret_cast<float *>(out_amax.untyped_data());
+  auto out_scale_ptr = reinterpret_cast<float *>(out_scale.untyped_data());
+
+  // Outputs
+  auto out_ptr = out->untyped_data();
+  auto out_amax_updated_ptr = reinterpret_cast<float *>(out_amax_updated->untyped_data());
+  auto out_scale_updated_ptr = reinterpret_cast<float *>(out_scale_updated->untyped_data());
+  auto out_dtype = convert_ffi_datatype_to_te_dtype(out->element_type());
+  auto pre_gelu_out_ptr = pre_gelu_out->untyped_data();
+  auto bias_grad_ptr = bias_grad->untyped_data();
+  auto workspace_ptr = workspace->untyped_data();
+  auto workspace_size = workspace->dimensions().back();
+
+  // Operand aliasing
+  NVTE_CHECK(bias_ptr == bias_grad_ptr,
+             "bias not bound to bias_grad in TE/JAX GEMM");
+  NVTE_CHECK(gelu_input_ptr == pre_gelu_out_ptr,
+             "gelu_input not bound to pre_gelu_out in TE/JAX GEMM");
+  NVTE_CHECK(out_amax_ptr == out_amax_updated_ptr,
+             "out_amax not bound to out_amax_updated in TE/JAX GEMM");
+  NVTE_CHECK(out_scale_ptr == out_scale_updated_ptr,
+             "out_scale not bound to out_scale_updated in TE/JAX GEMM");
+
+  // GEMM sizing
+  std::vector<size_t> lhs_shape(lhs.dimensions().begin(), lhs.dimensions().end());
+  std::vector<size_t> rhs_shape(rhs.dimensions().begin(), rhs.dimensions().end());
+
+  // Swap A and B argument locations to match what the TE/common kernel expects
+  GemmImpl(stream, lhs_ptr, lhs_shape, lhs_scale_inv_ptr, lhs_trans, rhs_ptr, rhs_shape,
+           rhs_scale_inv_ptr, rhs_trans, operand_dtype, bias_ptr, bias_dtype, out_ptr, out_amax_ptr,
+           out_scale_ptr, out_dtype, pre_gelu_out_ptr, workspace_ptr, workspace_size, fuse_gelu,
+           fuse_bias, grad, accumulate, use_split_accumulator);
+
+  return ffi_with_cuda_error_check();
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI,
+                              FFI::Bind()
+                                  .Ctx<FFI_Stream_Type>()  // stream
+                                  .Arg<Buffer_Type>()      // lhs
+                                  .Arg<Buffer_Type>()      // lhs_scale_inv
+                                  .Arg<Buffer_Type>()      // rhs
+                                  .Arg<Buffer_Type>()      // rhs_scale_inv
+                                  .Arg<Buffer_Type>()      // bias
+                                  .Arg<Buffer_Type>()      // gelu_input
+                                  .Arg<Buffer_Type>()      // out_amax
+                                  .Arg<Buffer_Type>()      // out_scale
+                                  .Ret<Buffer_Type>()      // out
+                                  .Ret<Buffer_Type>()      // out_amax_updated
+                                  .Ret<Buffer_Type>()      // out_scale_updated
+                                  .Ret<Buffer_Type>()      // pre_gelu_out
+                                  .Ret<Buffer_Type>()      // bias_grad
+                                  .Ret<Buffer_Type>()      // workspace
+                                  .Attr<bool>("lhs_trans")
+                                  .Attr<bool>("rhs_trans")
+                                  .Attr<bool>("fuse_gelu")
+                                  .Attr<bool>("fuse_bias")
+                                  .Attr<bool>("grad")
+                                  
.Attr("accumulate") + .Attr("use_split_accumulator"), + FFI_CudaGraph_Traits); + +} // namespace jax + +} // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/packing.cpp b/transformer_engine/jax/csrc/extensions/packing.cpp index 298478603b..1a9ce987af 100644 --- a/transformer_engine/jax/csrc/extensions/packing.cpp +++ b/transformer_engine/jax/csrc/extensions/packing.cpp @@ -80,5 +80,16 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor( deterministic, window_size_left, window_size_right}); } +pybind11::bytes PackCustomCallGemmDescriptor(size_t batch, size_t m, size_t n, size_t k, + size_t workspace_size, DType operand_dtype, + DType bias_dtype, DType out_dtype, bool lhs_trans, + bool rhs_trans, bool fuse_gelu, bool fuse_bias, + bool grad, bool accumulate, + bool use_split_accumulator) { + return PackOpaque(CustomCallGemmDescriptor{batch, m, n, k, workspace_size, operand_dtype, + bias_dtype, out_dtype, lhs_trans, rhs_trans, fuse_gelu, + fuse_bias, grad, accumulate, use_split_accumulator}); +} + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 9b5c156e5d..7b8ebdcdd2 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -51,6 +51,7 @@ pybind11::dict Registrations() { EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward); dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward); dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward); + dict["te_gemm"] = EncapsulateFunction(Gemm); // Transpose dict["te_transpose_ffi"] = EncapsulateFFI(TransposeHandler); @@ -101,6 +102,7 @@ pybind11::dict Registrations() { fused_attn_backward_ffi["execute"] = EncapsulateFFI(FusedAttnBackwardHandler); dict["te_fused_attn_backward_ffi"] = fused_attn_backward_ffi; + dict["te_gemm_ffi"] = EncapsulateFFI(GemmHandler); return dict; } @@ -114,10 +116,11 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor); m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor); m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor); + m.def("pack_gemm_descriptor", &PackCustomCallGemmDescriptor); m.def("get_fused_attn_backend", &GetFusedAttnBackend); m.def("get_cuda_version", &GetCudaRuntimeVersion); m.def("get_cudnn_version", &GetCudnnRuntimeVersion); - m.def("get_device_compute_capability", &GetDeviceComputeCapability); + m.def("get_device_compute_capability", &GetDeviceComputeCapability, pybind11::arg("gpu_id") = -1); m.def("get_cublasLt_version", &cublasLtGetVersion); m.def("get_dact_dbias_ct_workspace_sizes", &GetDActDBiasCastTransposeWorkspaceSizes); m.def("get_dbias_ct_workspace_sizes", &GetDBiasCastTransposeWorkspaceSizes); diff --git a/transformer_engine/jax/csrc/utils.h b/transformer_engine/jax/csrc/utils.h index 32de33bac9..b328c6e278 100644 --- a/transformer_engine/jax/csrc/utils.h +++ b/transformer_engine/jax/csrc/utils.h @@ -23,7 +23,7 @@ namespace jax { int GetCudaRuntimeVersion(); size_t GetCudnnRuntimeVersion(); -int GetDeviceComputeCapability(int gpu_id); +int GetDeviceComputeCapability(int gpu_id = -1); void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen, size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend, diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 8b13c47cd4..7312aa8295 100644 --- 
a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -334,6 +334,7 @@ def generate_fp8_meta_set(postfix: str) -> FP8MetaPackage: input_name_post_fix = f"_i_{postfix}" weight_name_post_fix = f"_w_{postfix}" grad_name_post_fix = f"_g_{postfix}" + output_name_post_fix = f"_o_{postfix}" def generate_a_set(target_postfix): amax = nn_partitioning.variable_with_axes( @@ -359,10 +360,10 @@ def generate_a_set(target_postfix): input_amax, input_scale = generate_a_set(input_name_post_fix) weight_amax, weight_scale = generate_a_set(weight_name_post_fix) grad_amax, grad_scale = generate_a_set(grad_name_post_fix) + output_amax, output_scale = generate_a_set(output_name_post_fix) - return FP8MetaPackage( - input_amax, input_scale, weight_amax, weight_scale, grad_amax, grad_scale - ) + return FP8MetaPackage(input_amax, input_scale, weight_amax, weight_scale, grad_amax, + grad_scale, output_amax, output_scale) class DenseGeneral(TransformerEngineBase): diff --git a/transformer_engine/jax/fp8.py b/transformer_engine/jax/fp8.py index 5df8ce4386..3d58c86e3e 100644 --- a/transformer_engine/jax/fp8.py +++ b/transformer_engine/jax/fp8.py @@ -86,10 +86,11 @@ class FP8MetaPackage: A container that contains all required meta data for FP8 """ - NUM_OF_META: int = 3 + NUM_OF_META: int = 4 INPUT_IDX: int = 0 WEIGHT_IDX: int = 1 GRAD_IDX: int = 2 + OUTPUT_IDX: int = 3 def __init__( self, @@ -99,6 +100,8 @@ def __init__( weight_scale: jnp.ndarray, grad_amax: jnp.ndarray, grad_scale: jnp.ndarray, + output_amax: jnp.ndarray, + output_scale: jnp.ndarray, ) -> None: self._amax_list = [None] * FP8MetaPackage.NUM_OF_META @@ -110,6 +113,8 @@ def __init__( self._scale_list[FP8MetaPackage.WEIGHT_IDX] = weight_scale self._amax_list[FP8MetaPackage.GRAD_IDX] = grad_amax self._scale_list[FP8MetaPackage.GRAD_IDX] = grad_scale + self._amax_list[FP8MetaPackage.OUTPUT_IDX] = output_amax + self._scale_list[FP8MetaPackage.OUTPUT_IDX] = output_scale @property def amax_list(self) -> List[jnp.ndarray]: diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py new file mode 100644 index 0000000000..ccd109e095 --- /dev/null +++ b/transformer_engine/jax/gemm.py @@ -0,0 +1,425 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
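+"""High-level GEMM API for TE/JAX.
+
+`gemm()`, `fp8_gemm()` and `type_safe_gemm()` wrap the cuBlasLt custom call with custom VJP
+rules and optional bias-add and GELU epilogue fusions."""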
+from functools import partial +from typing import Optional, Tuple, Union + +import jax +import jax.numpy as jnp +from jax.typing import ArrayLike +from jax.ad_checkpoint import checkpoint_name + +from .fp8 import FP8Helper, FP8MetaPackage +from .cpp_extensions import ( + gemm_impl, + fp8_gemm_impl, + cast_fp8, + cast_transpose, + dact_lu, + dbias_cast_transpose, + dact_lu_dbias_cast_transpose, +) + + + +__all__ = [ + "gemm", + "fp8_gemm", + "type_safe_gemm", +] + + +def gemm( + x: ArrayLike, + kernel: ArrayLike, + bias: Optional[ArrayLike] = None, + contracting_dims: Tuple[int, int] = (1, 0), + fuse_gelu: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, +) -> ArrayLike: + """Non-FP8 collective/distributed `nvte_cublas_gemm()` with GELU and bias-add fusions.""" + return _gemm(x, kernel, bias, contracting_dims, fuse_gelu, accumulate, use_split_accumulator) + + +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6)) +def _gemm( + x: ArrayLike, + kernel: ArrayLike, + bias: Union[ArrayLike, None], + contracting_dims: Tuple[int, int], + fuse_gelu: bool, + accumulate: bool, + use_split_accumulator: bool, +) -> ArrayLike: + out, _ = _gemm_fwd_rule(x, kernel, bias, contracting_dims, fuse_gelu, accumulate, + use_split_accumulator) + return out + + +def _gemm_fwd_rule( + x: ArrayLike, + kernel: ArrayLike, + bias: ArrayLike, + contracting_dims: Tuple[int, int], + fuse_gelu: bool, + accumulate: bool, + use_split_accumulator: bool, +) -> Tuple[ArrayLike, ...]: + fuse_bias = bias is not None + + out, pre_gelu_out = gemm_impl( + x, + kernel, + bias=bias, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator + ) + + ctx = ( + x, + kernel, + pre_gelu_out if fuse_gelu else None, + fuse_bias, + ) + + return out, ctx + + +def _gemm_bwd_rule( + contracting_dims, + fuse_gelu, + accumulate, + use_split_accumulator, + ctx, + grad, +): + x, kernel, pre_gelu_out, fuse_bias = ctx + + x_t_contracting = 0 if contracting_dims[0] == 1 else 1 + wgrad, dgelu, bgrad = gemm_impl( + x, + grad, + gelu_input=pre_gelu_out, + contracting_dims=(x_t_contracting, 0), + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=True, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + kernel_t_contracting = 1 if contracting_dims[1] == 0 else 0 + dgrad, *_ = gemm_impl( + dgelu if fuse_gelu else grad, + kernel, + gelu_input=pre_gelu_out, + contracting_dims=(1, kernel_t_contracting), + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=True, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + if not fuse_bias: + bgrad = None + + return dgrad, wgrad, bgrad + + +_gemm.defvjp(_gemm_fwd_rule, _gemm_bwd_rule) + + +def fp8_gemm( + x: ArrayLike, + kernel: ArrayLike, + fp8_meta: FP8MetaPackage, + bias: Optional[ArrayLike] = None, + out_dtype: jnp.dtype = jnp.bfloat16, + contracting_dims: Tuple[int, int] = (1, 1), + fuse_gelu: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, +) -> ArrayLike: + return _fp8_gemm(x, kernel, bias, fp8_meta.amax_list, fp8_meta.scale_list, out_dtype, + contracting_dims, fuse_gelu, accumulate, use_split_accumulator) + + +@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9)) +def _fp8_gemm( + x: ArrayLike, + kernel: ArrayLike, + bias: ArrayLike, + amax_list: ArrayLike, + scale_list: ArrayLike, + out_dtype: jnp.dtype, + contracting_dims: Tuple[int, int], + fuse_gelu: bool, + accumulate: bool, + 
use_split_accumulator: bool, +) -> ArrayLike: + """Non-FP8 `nvte_cublas_gemm()` with optional GELU and bias-add fusions.""" + out, _ = _fp8_gemm_fwd_rule(x, kernel, bias, amax_list, scale_list, out_dtype, + contracting_dims, fuse_gelu, accumulate, use_split_accumulator) + return out + + +def _fp8_gemm_fwd_rule( + x: ArrayLike, + kernel: ArrayLike, + bias: ArrayLike, + amax_list: ArrayLike, + scale_list: ArrayLike, + out_dtype: jnp.dtype, + contracting_dims: Tuple[int, int], + fuse_gelu: bool, + accumulate: bool, + use_split_accumulator: bool, +) -> Tuple[ArrayLike, ...]: + fuse_bias = bias is not None + + maybe_fm32_to_fp32, maybe_fp32_to_fm32 = FP8Helper.generate_fp8_meta_dtype_converter_pair( + *amax_list, *scale_list, + ) + amax_list = maybe_fm32_to_fp32(*amax_list) + scale_list = maybe_fm32_to_fp32(*scale_list) + + fwd_dtype = FP8Helper.FWD_DTYPE + bwd_dtype = FP8Helper.BWD_DTYPE + fp8_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype, fwd_dtype] + scale_list, scale_inv_list = FP8MetaPackage.update_fp8_scale( + amax_list, scale_list, fp8_dtype_list + ) + amax_list = FP8MetaPackage.update_amax_list(amax_list) + + x_amax = amax_list[FP8MetaPackage.INPUT_IDX][0:1] + x_scale = scale_list[FP8MetaPackage.INPUT_IDX] + x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] + if x.dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: + if contracting_dims[0] == 0: + _, casted_x, updated_x_amax = cast_transpose( + x, + x_amax, + x_scale, + x_scale_inv, + fwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) + else: + casted_x, updated_x_amax = cast_fp8(x, x_amax, x_scale, x_scale_inv, fwd_dtype) + else: + if contracting_dims[0] == 0: + casted_x_t = x + casted_x = casted_x_t.transpose() + else: + casted_x = x + updated_x_amax = x_amax + + kernel_amax = amax_list[FP8MetaPackage.WEIGHT_IDX][0:1] + kernel_scale = scale_list[FP8MetaPackage.WEIGHT_IDX] + kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX] + if kernel.dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: + if contracting_dims[1] == 0: # need to transpose the kernel for FP8 GEMM + _, casted_kernel_t, updated_kernel_amax = cast_transpose( + kernel, + kernel_amax, + kernel_scale, + kernel_scale_inv, + fwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) + else: + casted_kernel_t, updated_kernel_amax = cast_fp8( + kernel, + kernel_amax, + kernel_scale, + kernel_scale_inv, + fwd_dtype, + ) + else: + if contracting_dims[1] == 0: + casted_kernel = kernel + casted_kernel_t = casted_kernel.transpose() + else: + casted_kernel_t = kernel + updated_kernel_amax = kernel_amax + + out_amax = ( + amax_list[FP8MetaPackage.OUTPUT_IDX][0:1] + if out_dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] + else None + ) + out_scale = ( + scale_list[FP8MetaPackage.OUTPUT_IDX][0:1] + if out_dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] + else None + ) + out, updated_out_amax, updated_out_scale, pre_gelu_out = fp8_gemm_impl( + casted_x, + x_scale_inv, + casted_kernel_t, + kernel_scale_inv, + bias=bias, + out_amax=out_amax, + out_scale=out_scale, + out_dtype=out_dtype, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator + ) + if out_dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: + updated_out_amax = None + updated_out_scale = None + + ctx = ( + casted_x, + casted_kernel_t, + amax_list, + scale_list, + scale_inv_list, + updated_x_amax, + updated_kernel_amax, + pre_gelu_out if fuse_gelu else None, + fuse_bias, + maybe_fp32_to_fm32 + ) + + return (out, 
updated_out_amax, updated_out_scale), ctx + + +def _fp8_gemm_bwd_rule( + out_dtype, + contracting_dims, + fuse_gelu, + accumulate, + use_split_accumulator, + ctx, + grad, +): + ( + casted_x, + casted_kernel_t, + amax_list, + scale_list, + scale_inv_list, + updated_x_amax, + updated_kernel_amax, + pre_gelu_out, + fuse_bias, + maybe_fp32_to_fm32 + ) = ctx + + fwd_dtype = FP8Helper.FWD_DTYPE + bwd_dtype = FP8Helper.BWD_DTYPE + + grad_amax = amax_list[FP8MetaPackage.GRAD_IDX][0:1] + grad_scale = scale_list[FP8MetaPackage.GRAD_IDX] + grad_scale_inv = scale_inv_list[FP8MetaPackage.GRAD_ID] + if fuse_bias and not fuse_gelu: + # Since there is no GELU fusion, we need to fuse dbias into this cast_transpose. + _, casted_grad_t, bgrad, updated_grad_amax = dbias_cast_transpose( + grad, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) + else: + # If both bias and GELU is fused into the forward pass, we will fuse dbias later with + # dGELU. No need to do it here. + _, casted_grad_t, updated_grad_amax = cast_transpose( + grad, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) + bgrad = None + + + + x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] + wgrad, *_ = fp8_gemm_impl( + casted_x, + x_scale_inv, + casted_grad_t, + grad_scale_inv, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + if fuse_gelu and fuse_bias: + # Fuse dbias into this dGELU. + casted_dgelu, casted_dgelu_t, bgrad, updated_dgelu_amax = dact_lu_dbias_cast_transpose( + grad, + pre_gelu_out, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + activation_type=("gelu", ), + ) + elif fuse_gelu: + # No bias to fuse so we just do dGELU. + casted_dgelu, casted_dgelu_t, updated_dgelu_amax = dact_lu(grad, pre_gelu_out, ("gelu", )) + bgrad = None + + kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX] + dgrad, *_ = gemm_impl( + casted_dgelu if fuse_gelu else grad, + grad_scale_inv, + casted_kernel_t, + kernel_scale_inv, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + amax_list[FP8MetaPackage.INPUT_IDX] = ( + amax_list[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax[0]) + ) + amax_list[FP8MetaPackage.WEIGHT_IDX] = ( + amax_list[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_amax[0]) + ) + + amax_list = maybe_fp32_to_fm32(*amax_list) + scale_list = maybe_fp32_to_fm32(*scale_list) + + return dgrad, wgrad, bgrad, amax_list, scale_list + + +_fp8_gemm.defvjp(_fp8_gemm_fwd_rule, _fp8_gemm_bwd_rule) + + +def type_safe_gemm( + x: ArrayLike, + kernel: ArrayLike, + bias: Optional[ArrayLike] = None, + fp8_meta: Optional[FP8MetaPackage] = None, + out_dtype: Optional[jnp.dtype] = None, + contracting_dims: Tuple[int, int] = (1, 0), + fuse_gelu: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, +) -> ArrayLike: + if (x.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] + or kernel.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2]): + assert fp8_meta is not None, "GEMM operands have FP8 dtypes but FP8MetaPackage is None." + + if fp8_meta is not None: + return fp8_gemm(x, kernel, bias, fp8_meta, out_dtype, contracting_dims, fuse_gelu, + accumulate, use_split_accumulator) + else: + return gemm(x, kernel, bias, contracting_dims, fuse_gelu, accumulate, use_split_accumulator)
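
A minimal usage sketch of the new high-level entry points in transformer_engine/jax/gemm.py, based only on the signatures added above; the shapes, dtypes, and single-device setting are illustrative assumptions rather than part of the change:

    import jax
    import jax.numpy as jnp

    from transformer_engine.jax.gemm import gemm, type_safe_gemm

    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (128, 256), jnp.bfloat16)       # (M, K)
    kernel = jax.random.normal(key, (256, 512), jnp.bfloat16)  # (K, N)
    bias = jnp.zeros((512,), jnp.bfloat16)

    # Non-FP8 path: contracting_dims=(1, 0) contracts the last dim of x with the first dim of
    # kernel, and the bias-add and GELU epilogues are fused into the cuBlasLt call.
    out = gemm(x, kernel, bias=bias, contracting_dims=(1, 0), fuse_gelu=True)

    # type_safe_gemm() dispatches to fp8_gemm() when an FP8MetaPackage is supplied (FP8
    # operands require one) and falls back to gemm() otherwise.
    out = type_safe_gemm(x, kernel, bias=bias)

The FP8 path goes through fp8_gemm(x, kernel, fp8_meta, bias=..., out_dtype=...), where fp8_meta is an FP8MetaPackage that now also carries the output amax and scale added to transformer_engine/jax/fp8.py in this change.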