From 4077ccc1a5f1f031bfd86ab7cf0772f86503f159 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Mon, 29 Jan 2024 12:03:40 -0600 Subject: [PATCH] [JAX] Custom Op Workspace Tensors from XLA Buffers (#532) * Removed cudaMalloc/WorkspaceManager in JAX csrc. JAX custom ops now request buffers from XLA for their workspace tensors. Signed-off-by: Alp Dener * removed unused GEMM C++ API in TE-JAX Signed-off-by: Alp Dener * fixed typo in layernorm_geglu_fp8_mlp and removed unnecessary shape reductions in primitives Signed-off-by: Alp Dener * fixed import order for linting Signed-off-by: Alp Dener * fixed custom op errors due to incorrect static arg nums in JAX jit Signed-off-by: Alp Dener * shifted cudnnSetStream further down the kernel to avoid error when executing dummy kernel call with nullptr stream Signed-off-by: Alp Dener * fixed linting errors for blank lines Signed-off-by: Alp Dener --------- Signed-off-by: Alp Dener --- .gitignore | 2 + tests/jax/test_custom_call_compute.py | 4 +- .../fused_attn_f16_arbitrary_seqlen.cu | 12 +- .../fused_attn_f16_max512_seqlen.cu | 13 +- .../common/fused_attn/fused_attn_fp8.cu | 12 +- .../common/layer_norm/ln_api.cpp | 37 +- .../common/rmsnorm/rmsnorm_api.cpp | 30 +- transformer_engine/jax/cpp_extensions.py | 550 +++++++--- transformer_engine/jax/csrc/extensions.cpp | 14 +- transformer_engine/jax/csrc/modules.cpp | 994 +++++++++++------- transformer_engine/jax/csrc/modules.h | 99 +- transformer_engine/jax/csrc/utils.h | 60 -- transformer_engine/jax/flax/module.py | 4 +- transformer_engine/jax/mlp.py | 16 +- 14 files changed, 1190 insertions(+), 657 deletions(-) diff --git a/.gitignore b/.gitignore index 71d6bacc39..4502c06264 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,5 @@ tests/cpp/build/ docs/_build .ipynb_checkpoints docs/doxygen +*.log +CMakeFiles/CMakeSystem.cmake \ No newline at end of file diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index d5382bcc2f..163e0a5a5d 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -20,7 +20,7 @@ from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper from transformer_engine.jax.fp8 import is_fp8_available from transformer_engine.jax.layernorm import layernorm -from transformer_engine.jax.mlp import layernrom_geglu_fp8_mlp +from transformer_engine.jax.mlp import layernorm_geglu_fp8_mlp GEMM_CASES = [ (256, 256, 512), @@ -196,7 +196,7 @@ def primitive_func(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale, # out = (x * y) * z fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv) - return jnp.mean(layernrom_geglu_fp8_mlp(x, ln_s, None, [y, z], fp8_meta_pkg, "rmsnorm")) + return jnp.mean(layernorm_geglu_fp8_mlp(x, ln_s, None, [y, z], fp8_meta_pkg, "rmsnorm")) def _convert_to_activation_function(fn_or_string): """Convert a string to an activation function.""" diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 491b43347d..e11da334a8 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -59,8 +59,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl( cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { - NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); - bool is_bias = (bias_type == 
NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) @@ -248,6 +246,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( return; } + // cuDNN stream check needs to be moved here to support dummy kernel calls with + // null streams for sizing the cuDNN workspace. + NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); + // Build variant pack std::unordered_map, void*> variant_pack = { {Q, devPtrQ}, @@ -300,8 +302,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl( void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { - NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); - bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) @@ -519,6 +519,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( return; } + // cuDNN stream check needs to be moved here to support dummy kernel calls with + // null streams for sizing the cuDNN workspace. + NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); + // build variant pack std::unordered_map, void*> variant_pack = { {q, devPtrQ}, diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu index 9d9e6d05f4..8e899788dc 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu @@ -642,8 +642,6 @@ void fused_attn_max_512_fwd_impl( void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *workspace, size_t *workspace_size, cudnnDataType_t tensorType, cudaStream_t stream, cudnnHandle_t handle) { try { - NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); - FADescriptor descriptor{b, h, s_q, s_kv, d, scaling_factor, @@ -754,6 +752,10 @@ void fused_attn_max_512_fwd_impl( return; } + // cuDNN stream check needs to be moved here to support dummy kernel calls with + // null streams for sizing the cuDNN workspace. + NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); + // Prepare actual seqlen constexpr size_t nthreads_per_block = 128; const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; @@ -845,9 +847,6 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv size_t *workspace_size, cudnnDataType_t tensorType, cudaStream_t stream, cudnnHandle_t handle) { try { - // Create cudnn handle - NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); - FADescriptor descriptor{ b, h, s_q, s_kv, d, scaling_factor, true, dropout_probability, layout, bias_type, mask_type, tensorType, false}; @@ -1194,6 +1193,10 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv return; } + // cuDNN stream check needs to be moved here to support dummy kernel calls with + // null streams for sizing the cuDNN workspace. 
+ NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); + constexpr size_t nthreads_per_block = 128; const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 38c5d8c7b3..1d9a881fba 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1007,8 +1007,6 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in cudaStream_t stream, cudnnHandle_t handle_) { try { - NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream)); - FADescriptor descriptor{ b, h, s_q, s_kv, d, attnScale, isTraining, dropoutProbability, layout, @@ -1212,6 +1210,10 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in return; } + // cuDNN stream check needs to be moved here to support dummy kernel calls with + // null streams for sizing the cuDNN workspace. + NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream)); + int32_t* qkv_ragged_offset = reinterpret_cast( reinterpret_cast(workspace_ptr) + wkspace_size); int32_t* o_ragged_offset = reinterpret_cast( @@ -1324,8 +1326,6 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in cudaStream_t stream, cudnnHandle_t handle_) { try { - NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream)); - FADescriptor descriptor{ b, h, s_q, s_kv, d, attnScale, false, dropoutProbability, layout, @@ -1745,6 +1745,10 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in return; } + // cuDNN stream check needs to be moved here to support dummy kernel calls with + // null streams for sizing the cuDNN workspace. 
+ NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream)); + int32_t* qkv_ragged_offset = reinterpret_cast( reinterpret_cast(workspace_ptr) + wkspace_size); int32_t* o_ragged_offset = reinterpret_cast( diff --git a/transformer_engine/common/layer_norm/ln_api.cpp b/transformer_engine/common/layer_norm/ln_api.cpp index 63aa43622d..9112d89666 100644 --- a/transformer_engine/common/layer_norm/ln_api.cpp +++ b/transformer_engine/common/layer_norm/ln_api.cpp @@ -159,14 +159,6 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size const bool fp8_out = is_fp8_dtype(otype); const auto ctype = layer_norm::DType::kFloat32; - CheckInputTensor(x, "x"); - CheckInputTensor(gamma, "gamma"); - CheckInputTensor(beta, "beta"); - - CheckOutputTensor(*z, "z"); - CheckOutputTensor(*mu, "mu"); - CheckOutputTensor(*rsigma, "rsigma"); - NVTE_CHECK(x.data.shape.size() == 2); const size_t rows = x.data.shape[0]; @@ -227,6 +219,16 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size return; } + + // Tensor checks are delayed here in order to recover workspace sizes with null data + CheckInputTensor(x, "x"); + CheckInputTensor(gamma, "gamma"); + CheckInputTensor(beta, "beta"); + + CheckOutputTensor(*z, "z"); + CheckOutputTensor(*mu, "mu"); + CheckOutputTensor(*rsigma, "rsigma"); + if ( launch_params.barrier_size > 0 ) { params.workspace = workspace->data.dptr; params.barrier = reinterpret_cast(barrier->data.dptr); @@ -273,15 +275,6 @@ void layernorm_bwd(const Tensor& dz, auto otype = wtype; auto ctype = DType::kFloat32; - CheckInputTensor(dz, "dz"); - CheckInputTensor(x, "x"); - CheckInputTensor(mu, "mu"); - CheckInputTensor(rsigma, "rsigma"); - CheckInputTensor(gamma, "gamma"); - CheckOutputTensor(*dx, "dx"); - CheckOutputTensor(*dgamma, "dgamma"); - CheckOutputTensor(*dbeta, "dbeta"); - NVTE_CHECK(dz.data.dtype == otype); NVTE_CHECK(mu.data.dtype == ctype); NVTE_CHECK(rsigma.data.dtype == ctype); @@ -354,6 +347,16 @@ void layernorm_bwd(const Tensor& dz, return; } + // Tensor checks are delayed here in order to recover workspace sizes with null data + CheckInputTensor(dz, "dz"); + CheckInputTensor(x, "x"); + CheckInputTensor(mu, "mu"); + CheckInputTensor(rsigma, "rsigma"); + CheckInputTensor(gamma, "gamma"); + CheckOutputTensor(*dx, "dx"); + CheckOutputTensor(*dgamma, "dgamma"); + CheckOutputTensor(*dbeta, "dbeta"); + if ( launch_params.barrier_size > 0 ) { params.workspace = workspace->data.dptr; params.barrier = reinterpret_cast(barrier->data.dptr); diff --git a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/rmsnorm/rmsnorm_api.cpp index f9d0825c0c..6487336b20 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/rmsnorm/rmsnorm_api.cpp @@ -113,12 +113,6 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens const bool fp8_out = is_fp8_dtype(otype); auto ctype = DType::kFloat32; - CheckInputTensor(x, "x"); - CheckInputTensor(gamma, "gamma"); - - CheckOutputTensor(*z, "z"); - CheckOutputTensor(*rsigma, "rsigma"); - NVTE_CHECK(x.data.shape.size() == 2); const size_t rows = x.data.shape[0]; @@ -172,6 +166,15 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens return; } + + // Tensor checks are delayed here in order to recover workspace sizes with null data + CheckInputTensor(x, "x"); + CheckInputTensor(gamma, "gamma"); + + CheckOutputTensor(*z, "z"); + CheckOutputTensor(*rsigma, "rsigma"); + + if (launch_params.barrier_size > 0) { params.workspace = workspace->data.dptr; params.barrier 
= reinterpret_cast(barrier->data.dptr); @@ -204,13 +207,6 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const auto otype = wtype; auto ctype = DType::kFloat32; - CheckInputTensor(dz, "dz"); - CheckInputTensor(x, "x"); - CheckInputTensor(rsigma, "rsigma"); - CheckInputTensor(gamma, "gamma"); - CheckOutputTensor(*dx, "dx"); - CheckOutputTensor(*dgamma, "dgamma"); - NVTE_CHECK(dz.data.dtype == otype); NVTE_CHECK(rsigma.data.dtype == ctype); @@ -268,6 +264,14 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const return; } + // Tensor checks are delayed here in order to recover workspace sizes with null data + CheckInputTensor(dz, "dz"); + CheckInputTensor(x, "x"); + CheckInputTensor(rsigma, "rsigma"); + CheckInputTensor(gamma, "gamma"); + CheckOutputTensor(*dx, "dx"); + CheckOutputTensor(*dgamma, "dgamma"); + if (launch_params.barrier_size > 0) { params.workspace = workspace->data.dptr; params.barrier = reinterpret_cast(barrier->data.dptr); diff --git a/transformer_engine/jax/cpp_extensions.py b/transformer_engine/jax/cpp_extensions.py index 015c43358e..d9f7446ba9 100644 --- a/transformer_engine/jax/cpp_extensions.py +++ b/transformer_engine/jax/cpp_extensions.py @@ -10,13 +10,6 @@ import os import warnings -import transformer_engine_jax -from transformer_engine_jax import DType as TEDType -from transformer_engine_jax import NVTE_Bias_Type -from transformer_engine_jax import NVTE_Mask_Type -from transformer_engine_jax import NVTE_QKV_Layout -from transformer_engine_jax import NVTE_Fused_Attn_Backend - import numpy as np import jax.numpy as jnp from jax.lib import xla_client @@ -28,6 +21,13 @@ from jax._src.interpreters import batching from jax._src import dispatch +import transformer_engine_jax +from transformer_engine_jax import DType as TEDType +from transformer_engine_jax import NVTE_Bias_Type +from transformer_engine_jax import NVTE_Mask_Type +from transformer_engine_jax import NVTE_QKV_Layout +from transformer_engine_jax import NVTE_Fused_Attn_Backend + from .sharding import all_reduce_max_along_all_axes_except_PP from .sharding import all_reduce_sum_along_dp_fsdp from .sharding import get_all_mesh_axes, num_of_devices @@ -58,6 +58,7 @@ def te_dtype_to_jax_dtype(te_dtype): TEDType.kInt64: jnp.int64, TEDType.kFloat8E4M3: jnp.float8_e4m3fn, TEDType.kFloat8E5M2: jnp.float8_e5m2, + TEDType.kByte: jnp.uint8 } if te_dtype not in converter: @@ -94,6 +95,7 @@ def jax_dtype_to_te_dtype(jax_dtype): jnp.int64.dtype: TEDType.kInt64, jnp.float8_e4m3fn.dtype: TEDType.kFloat8E4M3, jnp.float8_e5m2.dtype: TEDType.kFloat8E5M2, + jnp.uint8.dtype: TEDType.kByte, } if jax_dtype not in converter: @@ -124,7 +126,7 @@ def _check_valid_batch_dims(bdims): class BasePrimitive(metaclass=ABCMeta): """ - jax premitive + jax primitive """ @staticmethod @@ -135,6 +137,13 @@ def abstract(): """ return NotImplemented + @classmethod + def outer_abstract(cls, *args, **kwargs): + """ + optional abstract wrapper to eliminate workspace tensors + """ + return cls.abstract(*args, **kwargs) + @staticmethod @abstractmethod def lowering(): @@ -196,7 +205,7 @@ def name_of_wrapper_p(): dispatch.prim_requires_devices_during_lowering.add(outer_p) outer_p.multiple_results = cls.multiple_results outer_p.def_impl(cls.impl) - outer_p.def_abstract_eval(cls.abstract) + outer_p.def_abstract_eval(cls.outer_abstract) batching.primitive_batchers[outer_p] = cls.batcher outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) 
outer_p_lower.def_partition(infer_sharding_from_operands=cls.infer_sharding_from_operands, @@ -287,9 +296,9 @@ class LayerNormFwdPrimitive(BasePrimitive): outer_primitive = None @staticmethod - def abstract(x_aval, gamma_aval, beta_aval, **kwargs): # pylint: disable=unused-argument + def abstract(x_aval, gamma_aval, beta_aval, **kwargs): """ - LayerNorm fwd abstract + LayerNorm fwd inner primitive abstract """ x_dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] @@ -303,6 +312,28 @@ def abstract(x_aval, gamma_aval, beta_aval, **kwargs): # pylint: disable=unus hidden_size = gamma_aval.size assert x_aval.size % hidden_size == 0 + wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( + x_aval.size // hidden_size, # batch size + hidden_size, + jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype + jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype + jax_dtype_to_te_dtype(x_aval.dtype), # out te_dtype (same as input for Fp16/Bf16) + True, kwargs['zero_centered_gamma'], kwargs['epsilon'] + ) + wkspace_aval = out_aval.update(shape=wkspace_info[0], + dtype=te_dtype_to_jax_dtype(wkspace_info[1])) + barrier_aval = out_aval.update(shape=barrier_info[0], + dtype=te_dtype_to_jax_dtype(barrier_info[1])) + + return out_aval, mu_aval, rsigma_aval, wkspace_aval, barrier_aval + + @staticmethod + def outer_abstract(*args, **kwargs): + """ + LayerNorm fwd outer primitive abstract + """ + out_aval, mu_aval, rsigma_aval, _, _ = \ + LayerNormFwdPrimitive.abstract(*args, **kwargs) return out_aval, mu_aval, rsigma_aval @staticmethod @@ -333,10 +364,14 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon): batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size + wkspace_aval, barrier_aval = ctx.avals_out[-2:] + out_types = [ ir.RankedTensorType.get(out_shape, output_type), ir.RankedTensorType.get(batch_shape, ir_mu_dtype), ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype), + ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), + ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)) ] operands = [x, gamma, beta] operand_shapes = [x_shape, g_shape, b_shape] @@ -347,8 +382,16 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon): opaque = transformer_engine_jax.pack_norm_descriptor( batch_size, hidden_size, + wkspace_aval.size, + barrier_aval.size, + 0, # no dgamma_part in FWD pass + 0, # no dbeta_part in BWD pass jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), + jax_dtype_to_te_dtype(barrier_aval.dtype), + TEDType.kByte, # dummy dgamma_part te_dtype + TEDType.kByte, # dummy dbeta_part te_dtype zero_centered_gamma, epsilon, sm_margin, @@ -364,7 +407,7 @@ def impl(x, gamma, beta, zero_centered_gamma, epsilon): to describe implementation """ assert LayerNormFwdPrimitive.inner_primitive is not None - out, mu, rsigma = LayerNormFwdPrimitive.inner_primitive.bind( + out, mu, rsigma, _, _ = LayerNormFwdPrimitive.inner_primitive.bind( x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon) return out, mu, rsigma @@ -449,9 +492,9 @@ class LayerNormBwdPrimitive(BasePrimitive): outer_primitive = None @staticmethod - def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs): # pylint: disable=unused-argument + def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs): """ - 
Layernorm bwd abstract + Layernorm bwd inner primitive abstract """ w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype) mu_dtype = dtypes.canonicalize_dtype(mu_aval.dtype) @@ -464,6 +507,34 @@ def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs): # dx_aval = core.raise_to_shaped(dz_aval) dgamma_aval = dbeta_aval = core.raise_to_shaped(gamma_aval) + + wkspace_info, barrier_info, dgamma_part_info, dbeta_part_info = \ + transformer_engine_jax.get_layernorm_bwd_workspace_sizes( + x_aval.size // gamma_aval.size, # batch size + gamma_aval.size, # hidden size + jax_dtype_to_te_dtype(x_aval.dtype), # input te_dtype + jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype + True, kwargs['zero_centered_gamma'], kwargs['epsilon'] + ) + wkspace_aval = dx_aval.update(shape=wkspace_info[0], + dtype=te_dtype_to_jax_dtype(wkspace_info[1])) + barrier_aval = dx_aval.update(shape=barrier_info[0], + dtype=te_dtype_to_jax_dtype(barrier_info[1])) + dgamma_part_aval = dgamma_aval.update(shape=dgamma_part_info[0], + dtype=te_dtype_to_jax_dtype(dgamma_part_info[1])) + dbeta_part_aval = dbeta_aval.update(shape=dbeta_part_info[0], + dtype=te_dtype_to_jax_dtype(dbeta_part_info[1])) + + return dx_aval, dgamma_aval, dbeta_aval, wkspace_aval, barrier_aval, \ + dgamma_part_aval, dbeta_part_aval + + @staticmethod + def outer_abstract(*args, **kwargs): + """ + LayerNorm bwd outer primitive abstract + """ + dx_aval, dgamma_aval, dbeta_aval, _, _, _, _ = \ + LayerNormBwdPrimitive.abstract(*args, **kwargs) return dx_aval, dgamma_aval, dbeta_aval @staticmethod @@ -488,22 +559,32 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon): hidden_size = reduce(operator.mul, g_shape) batch_size = reduce(operator.mul, x_shape) // hidden_size + out_types = [ - ir.RankedTensorType.get(x_shape, x_type.element_type), - ir.RankedTensorType.get(g_shape, g_type.element_type), - ir.RankedTensorType.get(b_shape, b_type.element_type), + ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) + for output in ctx.avals_out ] + operands = [dz, mu, rsigma, x, gamma] operand_shapes = [dz_shape, mu_shape, rsigma_shape, x_shape, g_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) + wkspace_aval, barrier_aval, dgamma_part_aval, dbeta_part_aval = ctx.avals_out[-4:] opaque = transformer_engine_jax.pack_norm_descriptor( batch_size, hidden_size, + wkspace_aval.size, + barrier_aval.size, + dgamma_part_aval.size, + dbeta_part_aval.size, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), + jax_dtype_to_te_dtype(barrier_aval.dtype), + jax_dtype_to_te_dtype(dgamma_part_aval.dtype), + jax_dtype_to_te_dtype(dbeta_part_aval.dtype), zero_centered_gamma, epsilon, sm_margin, @@ -516,7 +597,7 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon): @staticmethod def impl(dz, x, mu, rsigma, gamma, zero_centered_gamma, epsilon): assert LayerNormBwdPrimitive.inner_primitive is not None - dx, dgamma, dbeta = LayerNormBwdPrimitive.inner_primitive.bind( + dx, dgamma, dbeta, _, _, _, _ = LayerNormBwdPrimitive.inner_primitive.bind( dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon) return dx, dgamma, dbeta @@ -609,9 +690,9 @@ class RmsNormFwdPrimitive(BasePrimitive): outer_primitive = None @staticmethod - def abstract(x_aval, gamma_aval, **kwargs): # pylint: disable=unused-argument + def 
abstract(x_aval, gamma_aval, **kwargs): """ - RMSNorm fwd abstract + RMSNorm fwd inner primitive abstract """ x_dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] @@ -624,6 +705,27 @@ def abstract(x_aval, gamma_aval, **kwargs): # pylint: disable=unused-argument hidden_size = gamma_aval.size assert x_aval.size % hidden_size == 0 + wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( + x_aval.size // hidden_size, # batch size + hidden_size, + jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype + jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype + jax_dtype_to_te_dtype(x_aval.dtype), # out te_dtype (same as input for Fp16/Bf16) + False, False, kwargs['epsilon'] + ) + wkspace_aval = out_aval.update(shape=wkspace_info[0], + dtype=te_dtype_to_jax_dtype(wkspace_info[1])) + barrier_aval = out_aval.update(shape=barrier_info[0], + dtype=te_dtype_to_jax_dtype(barrier_info[1])) + + return out_aval, rsigma_aval, wkspace_aval, barrier_aval + + @staticmethod + def outer_abstract(*args, **kwargs): + """ + RMSNorm fwd outer primitive abstract + """ + out_aval, rsigma_aval, _, _ = RmsNormFwdPrimitive.abstract(*args, **kwargs) return out_aval, rsigma_aval @staticmethod @@ -643,9 +745,13 @@ def lowering(ctx, x, gamma, *, epsilon): batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size + wkspace_aval, barrier_aval = ctx.avals_out[-2:] + out_types = [ ir.RankedTensorType.get(out_shape, x_type.element_type), ir.RankedTensorType.get(batch_shape, rsigma_element_type), + ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), + ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)) ] operands = [x, gamma] operand_shapes = [x_shape, g_shape] @@ -656,8 +762,16 @@ def lowering(ctx, x, gamma, *, epsilon): opaque = transformer_engine_jax.pack_norm_descriptor( batch_size, hidden_size, + wkspace_aval.size, + barrier_aval.size, + 0, # no dgamma_part in FWD pass + 0, # no dbeta_part in BWD pass jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), + jax_dtype_to_te_dtype(barrier_aval.dtype), + TEDType.kByte, # dummy dgamma_part te_dtype + TEDType.kByte, # dummy dbeta_part te_dtype False, # RMSNorm doesn't support zero_centered_gamma epsilon, sm_margin, @@ -673,7 +787,7 @@ def impl(x, gamma, epsilon): to describe implementation """ assert RmsNormFwdPrimitive.inner_primitive is not None - out, rsigma = RmsNormFwdPrimitive.inner_primitive.bind(x, gamma, epsilon=epsilon) + out, rsigma, _, _ = RmsNormFwdPrimitive.inner_primitive.bind(x, gamma, epsilon=epsilon) return out, rsigma @staticmethod @@ -744,15 +858,9 @@ class RmsNormBwdPrimitive(BasePrimitive): outer_primitive = None @staticmethod - def abstract( - dz_aval, - x_aval, - rsigma_aval, - gamma_aval, - **kwargs # pylint: disable=unused-argument - ): + def abstract(dz_aval, x_aval, rsigma_aval, gamma_aval, **kwargs): """ - RMSNorm bwd abstract + RMSNorm bwd inner primitive abstract """ w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype) rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype) @@ -764,6 +872,30 @@ def abstract( dx_aval = core.raise_to_shaped(dz_aval) dgamma_aval = core.raise_to_shaped(gamma_aval) + + wkspace_info, barrier_info, dgamma_part_info, _ = \ + transformer_engine_jax.get_layernorm_bwd_workspace_sizes( + x_aval.size // gamma_aval.size, # batch size + gamma_aval.size, # hidden size + 
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype + jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype + False, False, kwargs['epsilon'] + ) + wkspace_aval = dx_aval.update(shape=wkspace_info[0], + dtype=te_dtype_to_jax_dtype(wkspace_info[1])) + barrier_aval = dx_aval.update(shape=barrier_info[0], + dtype=te_dtype_to_jax_dtype(barrier_info[1])) + dgamma_part_aval = dgamma_aval.update(shape=dgamma_part_info[0], + dtype=te_dtype_to_jax_dtype(dgamma_part_info[1])) + + return dx_aval, dgamma_aval, wkspace_aval, barrier_aval, dgamma_part_aval + + @staticmethod + def outer_abstract(*args, **kwargs): + """ + RMSNorm bwd outer primitive abstract + """ + dx_aval, dgamma_aval, _, _, _ = RmsNormBwdPrimitive.abstract(*args, **kwargs) return dx_aval, dgamma_aval @staticmethod @@ -782,9 +914,15 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon): hidden_size = reduce(operator.mul, g_shape) batch_size = reduce(operator.mul, x_shape) // hidden_size + wkspace_aval, barrier_aval, dgamma_part_aval = ctx.avals_out[-3:] + out_types = [ ir.RankedTensorType.get(x_shape, x_type.element_type), ir.RankedTensorType.get(g_shape, g_type.element_type), + ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), + ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)), + ir.RankedTensorType.get(dgamma_part_aval.shape, + jax_dtype_to_ir_dtype(dgamma_part_aval.dtype)) ] operands = [dz, rsigma, x, gamma] operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape] @@ -795,8 +933,16 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon): opaque = transformer_engine_jax.pack_norm_descriptor( batch_size, hidden_size, + wkspace_aval.size, + barrier_aval.size, + dgamma_part_aval.size, + 0, # no dbeta_part for RMSnorm jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), + jax_dtype_to_te_dtype(barrier_aval.dtype), + jax_dtype_to_te_dtype(dgamma_part_aval.dtype), + TEDType.kByte, # dummy dbeta_part te_dtype False, # RMSNorm doesn't support zero_centered_gamma epsilon, sm_margin, @@ -809,7 +955,8 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon): @staticmethod def impl(dz, x, rsigma, gamma, epsilon): assert RmsNormBwdPrimitive.inner_primitive is not None - dx, dgamma = RmsNormBwdPrimitive.inner_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon) + dx, dgamma, _, _, _ = \ + RmsNormBwdPrimitive.inner_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon) return dx, dgamma @staticmethod @@ -1721,40 +1868,60 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive): def abstract(qkv_aval, bias_aval, seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training): """ - Self fused attention fwd abstract + Self fused attention fwd inner primitive abstract """ - # outer_primitve is seqlen, inner_primitive is cu_seqlen - del seqlen_or_cu_seqlen_aval, scaling_factor, is_training + # outer_primitve is squeezed_mask, inner_primitive is cu_seqlen + del seqlen_or_cu_seqlen_aval qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype) - *batch_shape, max_seqlen, nqkv, num_head, head_dim = qkv_aval.shape + *batch_shape, max_seqlen, nqkv, num_heads, head_dim = qkv_aval.shape assert nqkv == 3 assert qkv_aval.dtype == bias_aval.dtype - output_shape = (*batch_shape, max_seqlen, num_head, head_dim) - output_dtype = qkv_dtype + output_shape = (*batch_shape, max_seqlen, num_heads, head_dim) + out_aval = qkv_aval.update(shape=output_shape, dtype=qkv_dtype) + # 
backend determines the softmax buffer shape/dtype backend = FusedAttnHelper(qkv_dtype, qkv_dtype, NVTE_QKV_Layout.NVTE_BS3HD, attn_bias_type, - attn_mask_type, dropout_probability, num_head, num_head, + attn_mask_type, dropout_probability, num_heads, num_heads, max_seqlen, max_seqlen, head_dim).get_fused_attn_backend() - if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen: - softmax_aux_shape = (*batch_shape, num_head, max_seqlen, max_seqlen) + softmax_shape = (*batch_shape, num_heads, max_seqlen, max_seqlen) softmax_dtype = qkv_dtype elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: - softmax_aux_shape = (*batch_shape, num_head, max_seqlen, 1) + softmax_shape = (*batch_shape, num_heads, max_seqlen, 1) softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) else: raise ValueError(f'Unsupported {backend=}') + softmax_aux_aval = qkv_aval.update(shape=softmax_shape, dtype=softmax_dtype) + # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with + # 32-bit unsigned int to get the buffer size we need in the C++ kernel checker = _FusedAttnRNGStateChecker() seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype) assert seed_dtype == checker.rng_state_dtype rng_state_shape = (seed_aval.shape[0], checker.rng_state_size) - rng_state_dtype = seed_dtype + rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype) - out_aval = qkv_aval.update(shape=output_shape, dtype=output_dtype) - softmax_aux_aval = qkv_aval.update(shape=softmax_aux_shape, dtype=softmax_dtype) - rng_state_aval = qkv_aval.update(shape=rng_state_shape, dtype=rng_state_dtype) + # do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to + # prepare for the active fused-attn backend + batch_size = reduce(operator.mul, batch_shape) + wkspace_info = transformer_engine_jax.get_self_fused_attn_fwd_workspace_sizes( + batch_size, max_seqlen, num_heads, head_dim, + scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, + jax_dtype_to_te_dtype(qkv_aval.dtype), is_training + ) + wkspace_aval = qkv_aval.update(shape=wkspace_info[0], + dtype=te_dtype_to_jax_dtype(wkspace_info[1])) + + return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval + + @staticmethod + def outer_abstract(*args, **kwargs): + """ + Self fused attention fwd outer primitive abstract + """ + out_aval, softmax_aux_aval, rng_state_aval, _ = \ + SelfFusedAttnFwdPrimitive.abstract(*args, **kwargs) return out_aval, softmax_aux_aval, rng_state_aval @staticmethod @@ -1763,23 +1930,25 @@ def lowering(ctx, qkv, bias, cu_seqlen, seed, *, attn_bias_type, attn_mask_type, """ Self fused attention fwd lowering rules """ - qkv_aval, _, _, _ = ctx.avals_in - - *batch_shape, max_seqlen, _, num_head, head_dim = qkv_aval.shape - batch = reduce(operator.mul, batch_shape) - operands = [qkv, bias, cu_seqlen, seed] operand_shapes = map(lambda x: x.type.shape, operands) out_types = [ ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) for output in ctx.avals_out ] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + qkv_aval = ctx.avals_in[0] + *batch_shape, max_seqlen, _, num_heads, head_dim = qkv_aval.shape + batch_size = reduce(operator.mul, batch_shape) + + wkspace_aval = ctx.avals_out[-1] + opaque = transformer_engine_jax.pack_fused_attn_descriptor( - batch, num_head, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, - dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(qkv_aval.dtype), 
is_training) + batch_size, max_seqlen, max_seqlen, num_heads, num_heads, head_dim, wkspace_aval.size, + scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, + jax_dtype_to_te_dtype(qkv_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), + is_training) out = custom_caller(SelfFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) @@ -1792,7 +1961,7 @@ def impl(qkv, bias, seqlen, seed, attn_bias_type, attn_mask_type, scaling_factor cu_seqlen = generate_cu_seqlen(seqlen) - output, softmax_aux, rng_state = SelfFusedAttnFwdPrimitive.inner_primitive.bind( + output, softmax_aux, rng_state, _ = SelfFusedAttnFwdPrimitive.inner_primitive.bind( qkv, bias, cu_seqlen, @@ -1897,16 +2066,35 @@ def abstract(qkv_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval, """ Self fused attention bwd abstract """ - del softmax_aux_aval, rng_state_aval - # outer_primitve is seqlen, inner_primitive is cu_seqlen - del seqlen_or_cu_seqlen_aval, attn_bias_type, attn_mask_type - del scaling_factor, dropout_probability, is_training + del softmax_aux_aval, rng_state_aval, seqlen_or_cu_seqlen_aval + + assert qkv_aval.dtype == bias_aval.dtype == output_aval.dtype == doutput_aval.dtype + *batch_shape, max_seqlen, nqkv, num_heads, head_dim = qkv_aval.shape + assert nqkv == 3 qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype) bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) - assert qkv_aval.dtype == bias_aval.dtype == output_aval.dtype == doutput_aval.dtype + + batch_size = reduce(operator.mul, batch_shape) + wkspace_shape, wkspace_dtype = \ + transformer_engine_jax.get_self_fused_attn_bwd_workspace_sizes( + batch_size, max_seqlen, num_heads, head_dim, + scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, + jax_dtype_to_te_dtype(qkv_aval.dtype), is_training + ) dqkv_aval = qkv_aval.update(shape=qkv_aval.shape, dtype=qkv_dtype) dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) + wkspace_aval = qkv_aval.update(shape=wkspace_shape, + dtype=te_dtype_to_jax_dtype(wkspace_dtype)) + + return dqkv_aval, dbias_aval, wkspace_aval + + @staticmethod + def outer_abstract(*args, **kwargs): + """ + Self fused attention bwd outer primitive abstract + """ + dqkv_aval, dbias_aval, _ = SelfFusedAttnBwdPrimitive.abstract(*args, **kwargs) return dqkv_aval, dbias_aval @staticmethod @@ -1915,24 +2103,25 @@ def lowering(ctx, qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen, """ Self fused attention bwd lowering rules """ - qkv_aval, _, _, _, _, _, _ = ctx.avals_in - - *batch_shape, max_seqlen, _, num_head, head_dim = qkv_aval.shape - batch = reduce(operator.mul, batch_shape) - operands = [qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen] operand_shapes = map(lambda x: x.type.shape, operands) out_types = [ ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) for output in ctx.avals_out ] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + qkv_aval = ctx.avals_in[0] + *batch_shape, max_seqlen, _, num_heads, head_dim = qkv_aval.shape + batch_size = reduce(operator.mul, batch_shape) + + wkspace_aval = ctx.avals_out[-1] + opaque = transformer_engine_jax.pack_fused_attn_descriptor( - batch, num_head, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, - dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(qkv_aval.dtype), is_training) + batch_size, max_seqlen, max_seqlen, num_heads, num_heads, head_dim, wkspace_aval.size, + scaling_factor, dropout_probability, 
attn_bias_type, attn_mask_type, + jax_dtype_to_te_dtype(qkv_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), + is_training) out = custom_caller(SelfFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) @@ -1945,7 +2134,7 @@ def impl(qkv, bias, softmax_aux, rng_state, output, doutput, seqlen, attn_bias_t cu_seqlen = generate_cu_seqlen(seqlen) - dqkv, dbias = SelfFusedAttnBwdPrimitive.inner_primitive.bind( + dqkv, dbias, _ = SelfFusedAttnBwdPrimitive.inner_primitive.bind( qkv, bias, softmax_aux, @@ -2067,50 +2256,62 @@ def abstract(q_aval, kv_aval, bias_aval, q_seqlen_or_cu_seqlen_aval, """ Cross fused attention fwd abstract """ - # outer_primitve is seqlen, inner_primitive is cu_seqlen - del scaling_factor, is_training - q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) - *q_batch_shape, q_max_seqlen, q_num_head, q_head_dim = q_aval.shape - kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype) - *kv_batch_shape, kv_max_seqlen, nkv, kv_num_head, kv_head_dim = kv_aval.shape - bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) - assert q_dtype == kv_dtype == bias_dtype + assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype + + *q_batch_shape, q_max_seqlen, num_heads, q_head_dim = q_aval.shape + *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = kv_aval.shape assert q_batch_shape == kv_batch_shape assert q_head_dim == kv_head_dim assert nkv == 2 - assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype - - output_shape = q_aval.shape - output_dtype = q_dtype + out_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) + # backend determines the softmax buffer shape/dtype backend = FusedAttnHelper(q_dtype, kv_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD, - attn_bias_type, attn_mask_type, dropout_probability, q_num_head, - kv_num_head, q_max_seqlen, kv_max_seqlen, + attn_bias_type, attn_mask_type, dropout_probability, + num_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, q_head_dim).get_fused_attn_backend() - if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen: - softmax_aux_shape = (*q_batch_shape, q_num_head, q_max_seqlen, kv_max_seqlen) - softmax_aux_dtype = q_dtype + softmax_shape = (*q_batch_shape, num_heads, q_max_seqlen, kv_max_seqlen) + softmax_dtype = q_dtype elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: - softmax_aux_shape = (*q_batch_shape, q_num_head, q_max_seqlen, 1) - softmax_aux_dtype = dtypes.canonicalize_dtype(jnp.float32) + softmax_shape = (*q_batch_shape, num_heads, q_max_seqlen, 1) + softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) else: raise ValueError(f'Unsupported {backend=}') + softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype) + # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with + # 32-bit unsigned int to get the buffer size we need in the C++ kernel checker = _FusedAttnRNGStateChecker() seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype) assert seed_dtype == checker.rng_state_dtype rng_state_shape = (seed_aval.shape[0], checker.rng_state_size) - rng_state_dtype = seed_dtype + rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype) + + # do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to + # prepare for the active fused-attn backend + batch_size = reduce(operator.mul, q_batch_shape) + wkspace_info = transformer_engine_jax.get_cross_fused_attn_fwd_workspace_sizes( + batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, 
q_head_dim, + scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, + jax_dtype_to_te_dtype(q_aval.dtype), is_training + ) + wkspace_aval = q_aval.update(shape=wkspace_info[0], + dtype=te_dtype_to_jax_dtype(wkspace_info[1])) - out_aval = q_aval.update(shape=output_shape, dtype=output_dtype) - softmax_aux_aval = q_aval.update(shape=softmax_aux_shape, dtype=softmax_aux_dtype) - rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=rng_state_dtype) + return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval + @staticmethod + def outer_abstract(*args, **kwargs): + """ + Cross fused attention fwd outer primitive abstract + """ + out_aval, softmax_aux_aval, rng_state_aval, _ = \ + CrossFusedAttnFwdPrimitive.abstract(*args, **kwargs) return out_aval, softmax_aux_aval, rng_state_aval @staticmethod @@ -2119,25 +2320,27 @@ def lowering(ctx, q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_typ """ Cross fused attention fwd lowering rules """ - q_aval, kv_aval, *_ = ctx.avals_in - assert q_aval.dtype == kv_aval.dtype - - *batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape - batch = reduce(operator.mul, batch_shape) - kv_max_seqlen, kv_num_head = kv_aval.shape[-4], kv_aval.shape[-2] - operands = [q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed] operand_shapes = map(lambda x: x.type.shape, operands) out_types = [ ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) for output in ctx.avals_out ] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + q_aval, kv_aval, *_ = ctx.avals_in + *batch_shape, q_max_seqlen, num_heads, head_dim = q_aval.shape + *_, kv_max_seqlen, _, num_gqa_groups, _ = kv_aval.shape + batch_size = reduce(operator.mul, batch_shape) + + wkspace_aval = ctx.avals_out[-1] + opaque = transformer_engine_jax.pack_fused_attn_descriptor( - batch, num_head, kv_num_head, q_max_seqlen, kv_max_seqlen, head_dim, + batch_size, q_max_seqlen, kv_max_seqlen, + num_heads, num_gqa_groups, head_dim, wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(q_aval.dtype), is_training) + jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), + is_training) out = custom_caller(CrossFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) @@ -2151,7 +2354,7 @@ def impl(q, kv, bias, q_seqlen, kv_seqlen, seed, attn_bias_type, attn_mask_type, q_cu_seqlen = generate_cu_seqlen(q_seqlen) kv_cu_seqlen = generate_cu_seqlen(kv_seqlen) - output, softmax_aux, rng_state = CrossFusedAttnFwdPrimitive.inner_primitive.bind( + output, softmax_aux, rng_state, _ = CrossFusedAttnFwdPrimitive.inner_primitive.bind( q, kv, bias, @@ -2266,7 +2469,7 @@ def abstract(q_aval, kv_aval, bias_aval, softmax_aux_aval, rng_state_aval, outpu Cross fused attention bwd abstract """ del softmax_aux_aval, rng_state_aval, output_aval - del attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training + q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype) bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) @@ -2274,9 +2477,35 @@ def abstract(q_aval, kv_aval, bias_aval, softmax_aux_aval, rng_state_aval, outpu assert q_dtype == kv_dtype == bias_dtype == doutput_dtype assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype + *q_batch_shape, q_max_seqlen, num_heads, q_head_dim = q_aval.shape + *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = kv_aval.shape + assert q_batch_shape == 
kv_batch_shape + assert q_head_dim == kv_head_dim + assert nkv == 2 + + batch_size = reduce(operator.mul, q_batch_shape) + wkspace_shape, wkspace_dtype = \ + transformer_engine_jax.get_cross_fused_attn_bwd_workspace_sizes( + batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, q_head_dim, + scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, + jax_dtype_to_te_dtype(q_aval.dtype), is_training + ) + dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) dkv_aval = kv_aval.update(shape=kv_aval.shape, dtype=kv_dtype) dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) + wkspace_aval = q_aval.update(shape=wkspace_shape, + dtype=te_dtype_to_jax_dtype(wkspace_dtype)) + + return dq_aval, dkv_aval, dbias_aval, wkspace_aval + + @staticmethod + def outer_abstract(*args, **kwargs): + """ + Cross fused attention fwd outer primitive abstract + """ + dq_aval, dkv_aval, dbias_aval, _ = \ + CrossFusedAttnBwdPrimitive.abstract(*args, **kwargs) return dq_aval, dkv_aval, dbias_aval @staticmethod @@ -2286,13 +2515,6 @@ def lowering(ctx, q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seq """ Cross fused attention bwd lowering rules """ - q_aval, kv_aval, *_ = ctx.avals_in - assert q_aval.dtype == kv_aval.dtype - - *batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape - batch = reduce(operator.mul, batch_shape) - kv_max_seqlen, kv_num_head = kv_aval.shape[-4], kv_aval.shape[-2] - operands = [q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen] operand_shapes = map(lambda x: x.type.shape, operands) out_types = [ @@ -2302,12 +2524,19 @@ def lowering(ctx, q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seq args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - # the dropout elements are encoded in the forward auxiliary tensor - # so seed is not needed in backward + q_aval, kv_aval, *_ = ctx.avals_in + *batch_shape, q_max_seqlen, num_heads, head_dim = q_aval.shape + *_, kv_max_seqlen, _, num_gqa_groups, _ = kv_aval.shape + batch_size = reduce(operator.mul, batch_shape) + + wkspace_aval = ctx.avals_out[-1] + opaque = transformer_engine_jax.pack_fused_attn_descriptor( - batch, num_head, kv_num_head, q_max_seqlen, kv_max_seqlen, head_dim, + batch_size, q_max_seqlen, kv_max_seqlen, + num_heads, num_gqa_groups, head_dim, wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(q_aval.dtype), is_training) + jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), + is_training) out = custom_caller(CrossFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) @@ -2321,7 +2550,7 @@ def impl(q, kv, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seql q_cu_seqlen = generate_cu_seqlen(q_seqlen) kv_cu_seqlen = generate_cu_seqlen(kv_seqlen) - dq, dkv, dbias = CrossFusedAttnBwdPrimitive.inner_primitive.bind( + dq, dkv, dbias, _ = CrossFusedAttnBwdPrimitive.inner_primitive.bind( q, kv, bias, @@ -3143,9 +3372,8 @@ class LayerNormFwdFp8Primitive(BasePrimitive): def abstract(x_aval, gamma_aval, beta_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, zero_centered_gamma, epsilon): """ - LayerNorm fwd (fp8 out) abstract + LayerNorm fwd (fp8 out) inner primitive abstract """ - del zero_centered_gamma, epsilon x_dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] @@ -3157,10 +3385,32 @@ def abstract(x_aval, gamma_aval, beta_aval, amax_aval, scale_aval, 
scale_inv_ava assert gamma_aval.size == beta_aval.size + wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( + x_aval.size // gamma_aval.size, # batch size + gamma_aval.size, # hidden size + jax_dtype_to_te_dtype(x_aval.dtype), # in type + jax_dtype_to_te_dtype(gamma_aval.dtype), # weight type + jax_dtype_to_te_dtype(out_dtype), + True, zero_centered_gamma, epsilon + ) + out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype) updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) + wkspace_aval = x_aval.update(shape=wkspace_info[0], + dtype=te_dtype_to_jax_dtype(wkspace_info[1])) + barrier_aval = x_aval.update(shape=barrier_info[0], + dtype=te_dtype_to_jax_dtype(barrier_info[1])) + return out_aval, mu_aval, rsigma_aval, updated_amax_aval, wkspace_aval, barrier_aval + + @staticmethod + def outer_abstract(*args, **kwargs): + """ + LayerNorm fwd (fp8 out) outer primitive abstract + """ + out_aval, mu_aval, rsigma_aval, updated_amax_aval, _, _ = \ + LayerNormFwdFp8Primitive.abstract(*args, **kwargs) return out_aval, mu_aval, rsigma_aval, updated_amax_aval @staticmethod @@ -3204,11 +3454,15 @@ def lowering(ctx, x, gamma, beta, amax, scale, scale_inv, *, out_dtype, zero_cen batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size + wkspace_aval, barrier_aval = ctx.avals_out[-2:] + out_types = [ ir.RankedTensorType.get(out_shape, ir_out_dtype), ir.RankedTensorType.get(batch_shape, ir_mu_dtype), ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype), ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), + ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)) ] operands = [x, gamma, beta, amax, scale, scale_inv] operand_shapes = [ @@ -3221,8 +3475,16 @@ def lowering(ctx, x, gamma, beta, amax, scale, scale_inv, *, out_dtype, zero_cen opaque = transformer_engine_jax.pack_norm_descriptor( batch_size, hidden_size, + wkspace_aval.size, + barrier_aval.size, + 0, # no dgamma_part in FWD pass + 0, # no dbeta_part in BWD pass jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), + jax_dtype_to_te_dtype(barrier_aval.dtype), + TEDType.kByte, # dummy dgamma_part te_dtype + TEDType.kByte, # dummy dbeta_part te_dtype zero_centered_gamma, epsilon, sm_margin, @@ -3242,7 +3504,7 @@ def impl(x, gamma, beta, amax, scale, scale_inv, out_dtype, zero_centered_gamma, to describe implementation """ assert LayerNormFwdFp8Primitive.inner_primitive is not None - out, mu, rsigma, updated_amax = LayerNormFwdFp8Primitive.inner_primitive.bind( + out, mu, rsigma, updated_amax, _, _ = LayerNormFwdFp8Primitive.inner_primitive.bind( x, gamma, beta, @@ -3359,9 +3621,8 @@ class RmsNormFwdFp8Primitive(BasePrimitive): @staticmethod def abstract(x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval, out_dtype, epsilon): """ - RMSNorm fwd (fp8 out) abstract + RMSNorm fwd (fp8 out) inner primitive abstract """ - del epsilon x_dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] @@ -3374,10 +3635,31 @@ def abstract(x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval, out_dtyp rsigama_dtype = jnp.float32 + wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( + 
x_aval.size // hidden_size, # batch_size + hidden_size, + jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype + jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype + jax_dtype_to_te_dtype(out_dtype), # out te_dtype + False, False, epsilon + ) + out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype) amax_aval = out_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) + wkspace_aval = x_aval.update(shape=wkspace_info[0], + dtype=te_dtype_to_jax_dtype(wkspace_info[1])) + barrier_aval = x_aval.update(shape=barrier_info[0], + dtype=te_dtype_to_jax_dtype(barrier_info[1])) + + return out_aval, rsigma_aval, amax_aval, wkspace_aval, barrier_aval + @staticmethod + def outer_abstract(*args, **kwargs): + """ + RMSNorm fwd (fp8 out) outer primitive abstract + """ + out_aval, rsigma_aval, amax_aval, _, _ = RmsNormFwdFp8Primitive.abstract(*args, **kwargs) return out_aval, rsigma_aval, amax_aval @staticmethod @@ -3414,10 +3696,14 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon): batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size + wkspace_aval, barrier_aval = ctx.avals_out[-2:] + out_types = [ ir.RankedTensorType.get(out_shape, ir_out_dtype), ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype), ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), + ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)) ] operands = [x, gamma, amax, scale, scale_inv] operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] @@ -3428,8 +3714,16 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon): opaque = transformer_engine_jax.pack_norm_descriptor( batch_size, hidden_size, + wkspace_aval.size, + barrier_aval.size, + 0, # no dgamma_part in FWD pass + 0, # no dbeta_part in BWD pass jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), + jax_dtype_to_te_dtype(barrier_aval.dtype), + TEDType.kByte, # dummy dgamma_part te_dtype + TEDType.kByte, # dummy dbeta_part te_dtype False, # RMSNorm doesn't support zero_centered_gamma epsilon, sm_margin, @@ -3449,13 +3743,13 @@ def impl(x, gamma, amax, scale, scale_inv, out_dtype, epsilon): to describe implementation """ assert RmsNormFwdFp8Primitive.inner_primitive is not None - out, rsigma, amax = RmsNormFwdFp8Primitive.inner_primitive.bind(x, - gamma, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - epsilon=epsilon) + out, rsigma, amax, _, _ = RmsNormFwdFp8Primitive.inner_primitive.bind(x, + gamma, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + epsilon=epsilon) return out, rsigma, amax @staticmethod diff --git a/transformer_engine/jax/csrc/extensions.cpp b/transformer_engine/jax/csrc/extensions.cpp index 153ec47a2a..ee428d13bd 100644 --- a/transformer_engine/jax/csrc/extensions.cpp +++ b/transformer_engine/jax/csrc/extensions.cpp @@ -29,7 +29,6 @@ pybind11::dict Registrations() { dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8); dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu); dict["te_dgated_gelu_cast_transpose"] = EncapsulateFunction(DGatedGeluCastTranspose); - dict["te_gemm"] = EncapsulateFunction(Gemm); dict["te_layernorm_forward"] = EncapsulateFunction(LayerNormForward); dict["te_layernorm_forward_fp8"] = 
EncapsulateFunction(LayerNormForwardFP8); dict["te_layernorm_backward"] = EncapsulateFunction(LayerNormBackward); @@ -56,14 +55,19 @@ pybind11::dict Registrations() { PYBIND11_MODULE(transformer_engine_jax, m) { m.def("registrations", &Registrations); m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor); - m.def("pack_gemm_descriptor", &PackCustomCallGemmDescriptor); m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor); m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor); - m.def("get_cublasLt_version", &cublasLtGetVersion); - m.def("get_cuda_version", &GetCudaRuntimeVersion); - m.def("get_device_compute_capability", &GetDeviceComputeCapability); m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor); m.def("get_fused_attn_backend", &GetFusedAttnBackend); + m.def("get_cuda_version", &GetCudaRuntimeVersion); + m.def("get_device_compute_capability", &GetDeviceComputeCapability); + m.def("get_cublasLt_version", &cublasLtGetVersion); + m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes); + m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes); + m.def("get_self_fused_attn_fwd_workspace_sizes", &GetSelfFusedAttnForwardWorkspaceSizes); + m.def("get_self_fused_attn_bwd_workspace_sizes", &GetSelfFusedAttnBackwardWorkspaceSizes); + m.def("get_cross_fused_attn_fwd_workspace_sizes", &GetCrossFusedAttnForwardWorkspaceSizes); + m.def("get_cross_fused_attn_bwd_workspace_sizes", &GetCrossFusedAttnBackwardWorkspaceSizes); pybind11::enum_(m, "DType", pybind11::module_local()) .value("kByte", DType::kByte) diff --git a/transformer_engine/jax/csrc/modules.cpp b/transformer_engine/jax/csrc/modules.cpp index 487a66c720..25ddbc67c3 100644 --- a/transformer_engine/jax/csrc/modules.cpp +++ b/transformer_engine/jax/csrc/modules.cpp @@ -22,7 +22,6 @@ #include "transformer_engine/activation.h" #include "transformer_engine/cast.h" #include "transformer_engine/fused_attn.h" -#include "transformer_engine/gemm.h" #include "transformer_engine/layer_norm.h" #include "transformer_engine/rmsnorm.h" #include "transformer_engine/softmax.h" @@ -33,11 +32,12 @@ namespace transformer_engine { namespace jax { -constexpr size_t kCublasLtForwardWorkspaceSize = 32 * 1024 * 1024; -constexpr size_t kCublasLtBackwardWorkspaceSize = 32 * 1024 * 1024; - inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } +std::vector MakeShapeVector(NVTEShape shape) { + return std::vector(shape.data, shape.data + shape.ndim); +} + template pybind11::bytes PackOpaque(const T &descriptor) { auto str = std::string(reinterpret_cast(&descriptor), sizeof(T)); @@ -61,33 +61,37 @@ pybind11::bytes PackCustomCallCommonDescriptor(const std::vector &shape, return PackOpaque(desc); } -pybind11::bytes PackCustomCallGemmDescriptor(size_t m, size_t n, size_t k, DType A_dtype, - DType B_dtype, DType D_dtype, bool transa, bool transb, - bool use_split_accumulator) { - return PackOpaque(CustomCallGemmDescriptor{m, n, k, A_dtype, B_dtype, D_dtype, transa, transb, - use_split_accumulator}); -} - -pybind11::bytes PackCustomCallNormDescriptor(size_t n, size_t hidden, DType x_dtype, DType w_dtype, +pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size, + size_t wkspace_size, size_t barrier_size, + size_t *dgamma_part_sizes, size_t *dbeta_part_sizes, + DType x_dtype, DType w_dtype, + DType wkspace_dtype, DType barrier_dtype, + DType dgamma_part_dtype, DType dbeta_part_dtype, bool 
zero_centered_gamma, float eps, int sm_margin) { - return PackOpaque( - CustomCallNormDescriptor{n, hidden, x_dtype, w_dtype, zero_centered_gamma, eps, sm_margin}); + return PackOpaque(CustomCallNormDescriptor{batch_size, hidden_size, wkspace_size, barrier_size, + dgamma_part_sizes, dbeta_part_sizes, + x_dtype, w_dtype, wkspace_dtype, barrier_dtype, + dgamma_part_dtype, dbeta_part_dtype, + zero_centered_gamma, eps, sm_margin}); } -pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch, size_t heads, - size_t q_seqlen, size_t k_seqlen, DType dtype, - float scale_factor) { - return PackOpaque( - SoftmaxDescriptor{batch, pad_batch, heads, q_seqlen, k_seqlen, dtype, scale_factor}); +pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size, + size_t head_dim, size_t q_seqlen, size_t k_seqlen, + DType dtype, float scale_factor) { + return PackOpaque(SoftmaxDescriptor{batch_size, padding_size, head_dim, q_seqlen, k_seqlen, + dtype, scale_factor}); } pybind11::bytes PackCustomCallFusedAttnDescriptor( - size_t batch, size_t num_head, size_t num_gqa_groups, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, DType dtype, bool is_training) { + size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, + size_t num_heads, size_t num_gqa_groups, size_t head_dim, size_t wkspace_size, + float scaling_factor, float dropout_probability, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + DType dtype, DType wkspace_dtype, bool is_training) { return PackOpaque(CustomCallFusedAttnDescriptor{ - batch, num_head, num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim, scaling_factor, - dropout_probability, bias_type, mask_type, dtype, is_training}); + batch_size, q_max_seqlen, kv_max_seqlen, num_heads, num_gqa_groups, head_dim, wkspace_size, + scaling_factor, dropout_probability, bias_type, mask_type, dtype, wkspace_dtype, + is_training}); } void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream, @@ -247,48 +251,56 @@ void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *op output_trans_tensor.data(), stream); } -void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *A = buffers[0]; - auto *B = buffers[1]; - auto *A_scale_inverse = reinterpret_cast(buffers[2]); - auto *B_scale_inverse = reinterpret_cast(buffers[3]); - auto *D = buffers[4]; - - // We transposes shape of A, B and D here to correctly invoke - // cuBlasLt GEMM (col-major) for row-major data. 
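GetLayerNormForwardWorkspaceSizes above runs the norm entry point with empty TensorWrappers and a null stream purely to learn how big the workspace and barrier must be; the get_layernorm_fwd_workspace_sizes binding registered in extensions.cpp exposes that query to Python. A rough sketch of calling it from the JAX side, assuming the extension module is importable as transformer_engine_jax (as cpp_extensions.py does); the dtypes and sizes are illustrative.

    import transformer_engine_jax

    TEDType = transformer_engine_jax.DType
    batch_size, hidden_size = 4096, 1024      # flattened tokens x features

    (wk_shape, wk_dtype), (barrier_shape, barrier_dtype) = \
        transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
            batch_size,
            hidden_size,
            TEDType.kBFloat16,    # input dtype
            TEDType.kBFloat16,    # gamma dtype
            TEDType.kBFloat16,    # output dtype (kFloat8E4M3 on the FP8 path)
            True,                 # is_layer_norm; False sizes the RMSNorm kernel instead
            False,                # zero_centered_gamma
            1e-6,                 # epsilon
        )
    # The two (shape, te_dtype) pairs become the extra workspace/barrier outputs that
    # the layernorm/rmsnorm forward custom calls now expect from XLA.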
- const auto &desc = *UnpackOpaque(opaque, opaque_len); - - auto m = desc.m; - auto n = desc.n; - auto k = desc.k; - auto A_shape = std::vector{k, m}; - auto A_tensor = TensorWrapper(A, A_shape, desc.A_dtype, nullptr, nullptr, A_scale_inverse); - - auto B_shape = std::vector{n, k}; - auto B_tensor = TensorWrapper(B, B_shape, desc.B_dtype, nullptr, nullptr, B_scale_inverse); - - auto D_shape = std::vector{n, m}; - auto D_tensor = TensorWrapper(D, D_shape, desc.D_dtype); - - auto null_tensor = TensorWrapper(nullptr, std::vector{0}, DType::kFloat32); +pybind11::tuple GetLayerNormForwardWorkspaceSizes( + size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype, + bool is_layer_norm, bool zero_centered_gamma, float eps +) { + auto input_shape = std::vector{batch_size, hidden_size}; + auto weight_shape = std::vector{hidden_size}; + auto intermediates_shape = std::vector{batch_size}; + + // empty tensor wrappers are okay just to get workspace size + auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype); + auto gamma_tensor = TensorWrapper(nullptr, weight_shape, in_dtype); + auto output_tensor = TensorWrapper(nullptr, input_shape, out_dtype); + auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32); + + // dummy tensor wrappers that will carry workspace size info later + TensorWrapper dummy_work_tensor, dummy_barrier_tensor; + auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); + auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; + if (is_layer_norm) { + auto beta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype); + auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32); - size_t workspace_size = kCublasLtForwardWorkspaceSize; - auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size); - auto wk_tensor = TensorWrapper(workspace, std::vector{workspace_size}, DType::kByte); + layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps, + output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), nullptr, + num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data()); + } else { + NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); + nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(), + rsigma_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(), + dummy_barrier_tensor.data()); + } - nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), null_tensor.data(), - null_tensor.data(), (desc.transa) ? CUBLAS_OP_T : CUBLAS_OP_N, - (desc.transb) ? 
CUBLAS_OP_T : CUBLAS_OP_N, false, wk_tensor.data(), false, - desc.use_split_accumulator, 0, stream); + auto work_shape = MakeShapeVector(dummy_work_tensor.shape()); + auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape()); + return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()), + std::make_pair(barrier_shape, dummy_barrier_tensor.dtype())); } -void LayerNormForwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, float eps, - int sm_margin, void *input, DType in_dtype, void *weight, DType w_dtype, - void *bias, void *output, DType out_dtype, void *mu, void *rsigma, - float *amax, float *scale, float *scale_inv, cudaStream_t stream) { - auto input_shape = std::vector{n, hidden}; - auto weight_shape = std::vector{hidden}; - auto intermediates_shape = std::vector{n}; +void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, + size_t workspace_size, size_t barrier_size, + bool zero_centered_gamma, float eps, void *input, DType in_dtype, + void *weight, DType w_dtype, void *bias, void *output, DType out_dtype, + void *workspace, DType work_dtype, void *barrier, DType barrier_dtype, + void *mu, void *rsigma, float *amax, float *scale, float *scale_inv, + cudaStream_t stream) { + auto input_shape = std::vector{batch_size, hidden_size}; + auto weight_shape = std::vector{hidden_size}; + auto intermediates_shape = std::vector{batch_size}; + auto workspace_shape = std::vector{workspace_size}; + auto barrier_shape = std::vector{barrier_size}; auto is_layer_norm = (bias) ? true : false; auto input_tensor = TensorWrapper(input, input_shape, in_dtype); @@ -300,63 +312,95 @@ void LayerNormForwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, flo auto output_tensor = TensorWrapper(output, input_shape, out_dtype, amax, scale, scale_inv); auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32); - // Create uninitialized workspace, barrier and init them on the first - TensorWrapper dummy_workspace_tensor, dummy_barrier_tensor; - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin; - + auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); auto layernorm_fwd_func = zero_centered_gamma ? 
nvte_layernorm1p_fwd : nvte_layernorm_fwd; - if (!is_layer_norm) { - NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); - } - // The first call is to query the required workspace + auto workspace_tensor = TensorWrapper(workspace, workspace_shape, work_dtype); + auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype); + if (is_layer_norm) { auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype); auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32); layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps, output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream, - num_sm, dummy_workspace_tensor.data(), dummy_barrier_tensor.data()); + num_sm, workspace_tensor.data(), barrier_tensor.data()); } else { + NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(), - rsigma_tensor.data(), stream, num_sm, dummy_workspace_tensor.data(), - dummy_barrier_tensor.data()); + rsigma_tensor.data(), stream, num_sm, workspace_tensor.data(), + barrier_tensor.data()); } +} - size_t workspace_size = - dummy_workspace_tensor.shape().data[0] * typeToSize(dummy_workspace_tensor.dtype()) + - dummy_barrier_tensor.shape().data[0] * typeToSize(dummy_barrier_tensor.dtype()); - - void *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size); +pybind11::tuple GetLayerNormBackwardWorkspaceSizes( + size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, bool is_layer_norm, + bool zero_centered_gamma, float eps +) { + auto input_shape = std::vector{batch_size, hidden_size}; + auto weight_shape = std::vector{hidden_size}; + auto intermediates_shape = std::vector{batch_size}; + auto intermediates_dtype = DType::kFloat32; - auto workspace_tensor = - TensorWrapper(workspace, dummy_workspace_tensor.shape(), dummy_workspace_tensor.dtype()); + // empty tensor wrappers are okay just to get workspace size + auto dz_tensor = TensorWrapper(nullptr, input_shape, in_dtype); + auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype); + auto x_tensor = TensorWrapper(nullptr, input_shape, in_dtype); + auto gamma_tensor = TensorWrapper(nullptr, weight_shape, w_dtype); + auto xgrad_tensor = TensorWrapper(nullptr, input_shape, in_dtype); + auto wgrad_tensor = TensorWrapper(nullptr, weight_shape, w_dtype); - auto barrier_tensor = - TensorWrapper(reinterpret_cast(workspace) + dummy_workspace_tensor.shape().data[0], - dummy_barrier_tensor.shape(), dummy_barrier_tensor.dtype()); + // dummy tensor wrappers that will carry workspace size info later + TensorWrapper dummy_work_tensor, dummy_barrier_tensor; + TensorWrapper dummy_dgamma_part_tensor, dummy_dbeta_part_tensor; + auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); + auto layernorm_bwd_func = zero_centered_gamma ? 
nvte_layernorm1p_bwd : nvte_layernorm_bwd; + // initialize dBeta information here -- layernorm will modify but RMSnorm will not + std::vector dbeta_part_shape; if (is_layer_norm) { - auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype); - auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32); + auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype); + auto dbeta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype); - layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps, - output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream, - num_sm, workspace_tensor.data(), barrier_tensor.data()); + layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), + rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(), + wgrad_tensor.data(), dbeta_tensor.data(), + dummy_dgamma_part_tensor.data(), dummy_dbeta_part_tensor.data(), nullptr, + num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data()); + + dbeta_part_shape = MakeShapeVector(dummy_dbeta_part_tensor.shape()); } else { - nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(), - rsigma_tensor.data(), stream, num_sm, workspace_tensor.data(), - barrier_tensor.data()); + NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); + nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(), + gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(), + dummy_dgamma_part_tensor.data(), nullptr, + num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data()); + + dbeta_part_shape = std::vector{0, 0}; } + + auto work_shape = MakeShapeVector(dummy_work_tensor.shape()); + auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape()); + auto dgamma_part_shape = MakeShapeVector(dummy_dgamma_part_tensor.shape()); + return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()), + std::make_pair(barrier_shape, dummy_barrier_tensor.dtype()), + std::make_pair(dgamma_part_shape, dummy_dgamma_part_tensor.dtype()), + std::make_pair(dbeta_part_shape, dummy_dbeta_part_tensor.dtype())); } -void LayerNormBackwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, float eps, - int sm_margin, void *input, DType in_dtype, void *weight, DType w_dtype, - void *ograd, void *mu, void *rsigma, void *xgrad, void *wgrad, - void *dbeta, cudaStream_t stream) { - auto input_shape = std::vector{n, hidden}; - auto weight_shape = std::vector{hidden}; - auto intermediates_shape = std::vector{n}; +void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, + size_t wkspace_size, size_t barrier_size, + size_t *dgamma_part_sizes, size_t *dbeta_part_sizes, + bool zero_centered_gamma, float eps, + void *input, DType in_dtype, void *weight, DType w_dtype, void *ograd, + void *workspace, DType wkspace_dtype, void *barrier, DType barrier_dtype, + void *mu, void *rsigma, void *xgrad, void *wgrad, void *dbeta, + void *dgamma_part, DType dgamma_dtype, + void* dbeta_part, DType dbeta_dtype, + cudaStream_t stream) { + auto input_shape = std::vector{batch_size, hidden_size}; + auto weight_shape = std::vector{hidden_size}; + auto intermediates_shape = std::vector{batch_size}; auto intermediates_dtype = DType::kFloat32; auto is_layer_norm = (dbeta) ? 
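GetLayerNormBackwardWorkspaceSizes reports four buffers rather than two, because the backward kernels also stage dgamma_part/dbeta_part, the 2D partial reductions of the weight gradients, before the final reduce. A sketch of the matching Python query through get_layernorm_bwd_workspace_sizes; the import path and concrete values are assumptions.

    import transformer_engine_jax

    TEDType = transformer_engine_jax.DType

    wk, barrier, dgamma_part, dbeta_part = \
        transformer_engine_jax.get_layernorm_bwd_workspace_sizes(
            4096, 1024,           # batch_size, hidden_size
            TEDType.kBFloat16,    # input dtype
            TEDType.kBFloat16,    # weight dtype
            False,                # is_layer_norm=False, i.e. the RMSNorm backward
            False,                # zero_centered_gamma
            1e-6,                 # epsilon
        )
    # Each result is a (shape, te_dtype) pair. On the RMSNorm path dbeta_part comes
    # back with shape (0, 0), matching the dbeta_part_sizes = {0, 0} used by
    # RMSNormBackward further down.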
true : false; @@ -374,62 +418,21 @@ void LayerNormBackwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, fl auto xgrad_tensor = TensorWrapper(xgrad, input_shape, x_dtype); auto wgrad_tensor = TensorWrapper(wgrad, weight_shape, w_dtype); - TensorWrapper dummy_workspace_tensor, dummy_barrier_tensor; - TensorWrapper dummy_dgamma_part_tensor, dummy_dbeta_part_tensor; - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin; - size_t dbeta_part_size{}; - + auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd; - if (!is_layer_norm) { - NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); - } - - // The first call is to query the workspace - if (is_layer_norm) { - auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype); - auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype); - - layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), - rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(), - wgrad_tensor.data(), dbeta_tensor.data(), - dummy_dgamma_part_tensor.data(), dummy_dbeta_part_tensor.data(), stream, - num_sm, dummy_workspace_tensor.data(), dummy_barrier_tensor.data()); - dbeta_part_size = dummy_dbeta_part_tensor.shape().data[0] * - dummy_dbeta_part_tensor.shape().data[1] * - typeToSize(dummy_dbeta_part_tensor.dtype()); - } else { - nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(), - gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(), - dummy_dgamma_part_tensor.data(), stream, num_sm, - dummy_workspace_tensor.data(), dummy_barrier_tensor.data()); - } - - size_t workspace_size = - dummy_workspace_tensor.shape().data[0] * typeToSize(dummy_workspace_tensor.dtype()); - size_t barrier_size = - dummy_barrier_tensor.shape().data[0] * typeToSize(dummy_barrier_tensor.dtype()); - size_t dgamma_part_size = dummy_dgamma_part_tensor.shape().data[0] * - dummy_dgamma_part_tensor.shape().data[1] * - typeToSize(dummy_dgamma_part_tensor.dtype()); - - auto [workspace, dgamma_part, dbeta_part, barrier] = WorkspaceManager::Instance().GetWorkspace( - workspace_size, dgamma_part_size, dbeta_part_size, barrier_size); - - auto workspace_tensor = - TensorWrapper(workspace, dummy_workspace_tensor.shape(), dummy_workspace_tensor.dtype()); - - auto barrier_tensor = - TensorWrapper(barrier, dummy_barrier_tensor.shape(), dummy_barrier_tensor.dtype()); - - auto dgamma_part_tensor = TensorWrapper(dgamma_part, dummy_dgamma_part_tensor.shape(), - dummy_dgamma_part_tensor.dtype()); + auto workspace_shape = std::vector{wkspace_size}; + auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype); + auto barrier_shape = std::vector{barrier_size}; + auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype); + auto dgamma_part_shape = std::vector{dgamma_part_sizes[0], dgamma_part_sizes[1]}; + auto dgamma_part_tensor = TensorWrapper(dgamma_part, dgamma_part_shape, dgamma_dtype); if (is_layer_norm) { auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype); auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype); - auto dbeta_part_tensor = TensorWrapper(dbeta_part, dummy_dbeta_part_tensor.shape(), - dummy_dbeta_part_tensor.dtype()); + auto dbeta_part_shape = std::vector{dbeta_part_sizes[0], dbeta_part_sizes[1]}; + auto dbeta_part_tensor = TensorWrapper(dbeta_part, dbeta_part_shape, 
dbeta_dtype); layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(), @@ -437,6 +440,7 @@ void LayerNormBackwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, fl dbeta_part_tensor.data(), stream, num_sm, workspace_tensor.data(), barrier_tensor.data()); } else { + NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(), dgamma_part_tensor.data(), stream, num_sm, workspace_tensor.data(), @@ -456,22 +460,29 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque auto *mu = buffers[7]; auto *rsigma = buffers[8]; auto *amax_out = buffers[9]; + auto *workspace = buffers[10]; + auto *barrier = buffers[11]; assert(amax_out == amax); const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto n = desc.n; - auto hidden = desc.hidden; + auto batch_size = desc.batch_size; + auto hidden_size = desc.hidden_size; + auto wkspace_size = desc.wkspace_size; + auto barrier_size = desc.barrier_size; auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; + auto wkspace_dtype = desc.wkspace_dtype; + auto barrier_dtype = desc.barrier_dtype; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; auto sm_margin = desc.sm_margin; auto out_dtype = DType::kFloat8E4M3; - LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight, - w_dtype, bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv, - stream); + LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, + zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias, + output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype, + mu, rsigma, amax, scale, scale_inv, stream); } void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -481,33 +492,48 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s auto *output = buffers[3]; auto *mu = buffers[4]; auto *rsigma = buffers[5]; + auto *workspace = buffers[6]; + auto *barrier = buffers[7]; float *amax = nullptr; float *scale = nullptr; float *scale_inv = nullptr; const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto n = desc.n; - auto hidden = desc.hidden; + auto batch_size = desc.batch_size; + auto hidden_size = desc.hidden_size; + auto wkspace_size = desc.wkspace_size; + auto barrier_size = desc.barrier_size; auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; + auto wkspace_dtype = desc.wkspace_dtype; + auto barrier_dtype = desc.barrier_dtype; auto eps = desc.eps; auto out_dtype = in_dtype; auto zero_centered_gamma = desc.zero_centered_gamma; auto sm_margin = desc.sm_margin; - LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight, - w_dtype, bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv, - stream); + LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, + zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias, + output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype, + mu, rsigma, amax, scale, scale_inv, stream); } void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto n = desc.n; - auto hidden = desc.hidden; + auto batch_size = desc.batch_size; + auto 
hidden_size = desc.hidden_size; + auto wkspace_size = desc.wkspace_size; + auto barrier_size = desc.barrier_size; + auto *dgamma_part_sizes = desc.dgamma_part_sizes; + auto *dbeta_part_sizes = desc.dbeta_part_sizes; auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; + auto wkspace_dtype = desc.wkspace_dtype; + auto barrier_dtype = desc.barrier_dtype; + auto dgamma_part_dtype = desc.dgamma_part_dtype; + auto dbeta_part_dtype = desc.dbeta_part_dtype; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; auto sm_margin = desc.sm_margin; @@ -520,9 +546,16 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, auto *xgrad = buffers[5]; auto *wgrad = buffers[6]; auto *dbeta = buffers[7]; - - LayerNormBackwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight, - w_dtype, ograd, mu, rsigma, xgrad, wgrad, dbeta, stream); + auto *workspace = buffers[8]; + auto *barrier = buffers[9]; + auto *dgamma_part = buffers[10]; + auto *dbeta_part = buffers[11]; + + LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, + dgamma_part_sizes, dbeta_part_sizes, zero_centered_gamma, eps, + input, in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype, + barrier, barrier_dtype, mu, rsigma, xgrad, wgrad, dbeta, + dgamma_part, dgamma_part_dtype, dbeta_part, dbeta_part_dtype, stream); } void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -534,24 +567,31 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, auto *output = buffers[5]; auto *rsigma = buffers[6]; auto *amax_out = buffers[7]; + auto *workspace = buffers[8]; + auto *barrier = buffers[9]; assert(amax_out == amax); void *bias = nullptr; void *mu = nullptr; const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto n = desc.n; - auto hidden = desc.hidden; + auto batch_size = desc.batch_size; + auto hidden_size = desc.hidden_size; + auto wkspace_size = desc.wkspace_size; + auto barrier_size = desc.barrier_size; auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; + auto wkspace_dtype = desc.wkspace_dtype; + auto barrier_dtype = desc.barrier_dtype; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; auto sm_margin = desc.sm_margin; auto out_dtype = DType::kFloat8E4M3; - LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight, - w_dtype, bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv, - stream); + LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, + zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias, + output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype, + mu, rsigma, amax, scale, scale_inv, stream); } void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -559,6 +599,8 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz auto *weight = buffers[1]; auto *output = buffers[2]; auto *rsigma = buffers[3]; + auto *workspace = buffers[4]; + auto *barrier = buffers[5]; void *bias = nullptr; void *mu = nullptr; @@ -567,18 +609,23 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz float *scale_inv = nullptr; const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto n = desc.n; - auto hidden = desc.hidden; + auto batch_size = desc.batch_size; + auto hidden_size = desc.hidden_size; + auto wkspace_size = desc.wkspace_size; + auto barrier_size = 
desc.barrier_size; auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; + auto wkspace_dtype = desc.wkspace_dtype; + auto barrier_dtype = desc.barrier_dtype; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; auto sm_margin = desc.sm_margin; auto out_dtype = in_dtype; - LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight, - w_dtype, bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv, - stream); + LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, + zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, bias, + output, out_dtype, workspace, wkspace_dtype, barrier, barrier_dtype, + mu, rsigma, amax, scale, scale_inv, stream); } void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -588,21 +635,35 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si auto *weight = buffers[3]; auto *xgrad = buffers[4]; auto *wgrad = buffers[5]; + auto *workspace = buffers[6]; + auto *barrier = buffers[7]; + auto *dgamma_part = buffers[8]; + + void *mu = nullptr; + void *dbeta = nullptr; + void *dbeta_part = nullptr; const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto n = desc.n; - auto hidden = desc.hidden; + auto batch_size = desc.batch_size; + auto hidden_size = desc.hidden_size; + auto wkspace_size = desc.wkspace_size; + auto barrier_size = desc.barrier_size; + auto dgamma_part_sizes = desc.dgamma_part_sizes; + size_t dbeta_part_sizes[2] = {0, 0}; auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; + auto wkspace_dtype = desc.wkspace_dtype; + auto barrier_dtype = desc.barrier_dtype; + auto dgamma_part_dtype = desc.dgamma_part_dtype; + auto dbeta_part_dtype = DType::kByte; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; - auto sm_margin = desc.sm_margin; - void *mu = nullptr; - void *dbeta = nullptr; - - LayerNormBackwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight, - w_dtype, ograd, mu, rsigma, xgrad, wgrad, dbeta, stream); + LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, + dgamma_part_sizes, dbeta_part_sizes, zero_centered_gamma, eps, + input, in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype, + barrier, barrier_dtype, mu, rsigma, xgrad, wgrad, dbeta, + dgamma_part, dgamma_part_dtype, dbeta_part, dbeta_part_dtype, stream); } void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -645,7 +706,7 @@ void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaqu auto *output = buffers[1]; const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto shape = std::vector{desc.batch, desc.heads, desc.q_seqlen, desc.k_seqlen}; + auto shape = std::vector{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen}; auto dtype = desc.dtype; auto input_tensor = TensorWrapper(input, shape, dtype); @@ -662,7 +723,7 @@ void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaq auto *dgrad = buffers[2]; const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto shape = std::vector{desc.batch, desc.heads, desc.q_seqlen, desc.k_seqlen}; + auto shape = std::vector{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen}; auto dtype = desc.dtype; auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype); @@ -680,8 +741,9 @@ void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char auto *output = buffers[2]; const auto 
&desc = *UnpackOpaque(opaque, opaque_len); - auto io_shape = std::vector{desc.batch, desc.heads, desc.q_seqlen, desc.k_seqlen}; - auto mask_shape = std::vector{desc.pad_batch, 1, desc.q_seqlen, desc.k_seqlen}; + auto io_shape = std::vector{desc.batch_size, desc.head_dim, + desc.q_seqlen, desc.k_seqlen}; + auto mask_shape = std::vector{desc.padding_size, 1, desc.q_seqlen, desc.k_seqlen}; auto dtype = desc.dtype; auto input_tensor = TensorWrapper(input, io_shape, dtype); @@ -705,7 +767,7 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, auto *output = buffers[1]; const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto attn_batch = desc.batch * desc.heads; + auto attn_batch = desc.batch_size * desc.head_dim; auto shape = std::vector{attn_batch, desc.q_seqlen, desc.k_seqlen}; auto dtype = desc.dtype; @@ -724,7 +786,7 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, auto *dgrad = buffers[2]; const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto attn_batch = desc.batch * desc.heads; + auto attn_batch = desc.batch_size * desc.head_dim; auto shape = std::vector{attn_batch, desc.q_seqlen, desc.k_seqlen}; auto dtype = desc.dtype; @@ -750,91 +812,225 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, return backend; } +/* + NOTE: PrepareFusedAttnForwardAuxTensors unifies the auxiliary tensor pack logic from the fused + attention forward kernels in: + - common/fused_attn/fused_attn_f16_max512_seqlen.cu lines 594-634 and 773-812 + - common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu lines 1270-1281 and 1348-1359 +*/ +void PrepareFusedAttnForwardAuxTensors( + NVTETensorPack *tensor_pack, const CustomCallFusedAttnDescriptor *desc, + NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend, + void *softmax_buf, void *rng_state_buf = nullptr, void *bias_buf = nullptr +) { + auto batch_size = desc->batch_size; + auto num_heads = desc->num_heads; + auto q_max_seqlen = desc->q_max_seqlen; + auto kv_max_seqlen = desc->kv_max_seqlen; + + // all backends need softmax but expect different shapes/dtypes + // start with the max512 sequence length softmax shape/dtype and correct later + tensor_pack->size = 1; + Tensor *softmax_aux = reinterpret_cast(tensor_pack->tensors[0]); + softmax_aux->data.dptr = softmax_buf; + softmax_aux->data.shape = std::vector{ + batch_size, num_heads, q_max_seqlen, kv_max_seqlen}; + softmax_aux->data.dtype = desc->dtype; + + // arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax + if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { + tensor_pack->size = 2; + Tensor *rng_state_aux = reinterpret_cast(tensor_pack->tensors[1]); + rng_state_aux->data.dptr = rng_state_buf; + rng_state_aux->data.shape = std::vector{2}; + rng_state_aux->data.dtype = DType::kInt64; + // correct softmax shape/dtype + softmax_aux->data.shape.at(3) = 1; // {B,H,Qs,Ks} -> {B,H,Qs,1} + softmax_aux->data.dtype = DType::kFloat32; + + // include bias if enabled + if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS && bias_type != NVTE_Bias_Type::NVTE_ALIBI) { + tensor_pack->size = 3; + Tensor *bias_aux = reinterpret_cast(tensor_pack->tensors[2]); + bias_aux->data.dptr = bias_buf; + bias_aux->data.shape = std::vector{1, num_heads, q_max_seqlen, kv_max_seqlen}; + bias_aux->data.dtype = desc->dtype; + } + } +} + +/* + NOTE: Backward fused attention kernels accept auxiliary tensors as explicit function arguments + instead of an NVTETensorPack and nvte_fused_attn_bwd() 
API does all the logic for pulling the + necessary tensors out of the tensor pack for the active kernel. That means we can just dump + everything we got into the tensor pack and not worry about its sizing for the backward pass. + + TODO(Alp): Refactor the nvte_fused_attn_fwd() to work like nvte_fused_attn_bwd()? +*/ +void PrepareFusedAttnBackwardAuxTensors( + NVTETensorPack* tensor_pack, const CustomCallFusedAttnDescriptor *desc, + NVTE_Fused_Attn_Backend backend, void* softmax_buf, void* rng_state_buf, void* bias_buf +) { + // Backward calls put everything into the tensor pack for every backend + // so we set dummy bias_type and backend choices here to follow the correct code path + auto dummy_bias_type = NVTE_Bias_Type::NVTE_POST_SCALE_BIAS; + auto dummy_backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; + PrepareFusedAttnForwardAuxTensors(tensor_pack, desc, dummy_bias_type, dummy_backend, + softmax_buf, rng_state_buf, bias_buf); + + // correct softmax shape for max512 sequence length kernel + if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { + Tensor* softmax_aux = reinterpret_cast(tensor_pack->tensors[0]); + softmax_aux->data.shape.at(3) = desc->kv_max_seqlen; // {B,H,Qs,1} -> {B,H,Qs,Ks} + softmax_aux->data.dtype = desc->dtype; + } +} + +pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes( + size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, + float scaling_factor, float dropout_probability, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training +) { + constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD; + + auto qkv_shape = std::vector{batch_size * max_seqlen, 3, num_heads, head_dim}; + auto bias_shape = std::vector{1, num_heads, max_seqlen, max_seqlen}; + + auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); + auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype); + auto cu_seqlens_tensor = TensorWrapper( + nullptr, std::vector{batch_size + 1}, DType::kInt32); + auto o_tensor = TensorWrapper( + nullptr, std::vector{batch_size * max_seqlen, num_heads, head_dim}, dtype); + auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); + auto rng_state_tensor = TensorWrapper(nullptr, std::vector{2}, DType::kInt64); + + auto backend = nvte_get_fused_attn_backend( + static_cast(dtype), static_cast(dtype), qkv_layout, + bias_type, mask_type, dropout_probability, num_heads, num_heads, + max_seqlen, max_seqlen, head_dim); + + NVTETensorPack aux_output_tensors; + nvte_tensor_pack_create(&aux_output_tensors); + + TensorWrapper query_workspace_tensor; + nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), + o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(), + rng_state_tensor.data(), max_seqlen, is_training, + scaling_factor, dropout_probability, qkv_layout, + bias_type, mask_type, query_workspace_tensor.data(), nullptr); + + auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); + return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); +} + void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { const CustomCallFusedAttnDescriptor &descriptor = *UnpackOpaque(opaque, opaque_len); - // input + // input buffers from XLA void *qkv = buffers[0]; void *bias = buffers[1]; void *cu_seqlens = buffers[2]; void *seed = buffers[3]; - // output + // output buffers from XLA void *output = buffers[4]; void *softmax_aux = buffers[5]; void *rng_state = buffers[6]; + void *workspace = 
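GetSelfFusedAttnForwardWorkspaceSizes applies the same recipe to the packed-QKV attention path: empty TensorWrappers, a call into nvte_fused_attn_fwd_qkvpacked with a null stream, and the queried workspace shape/dtype returned to Python as get_self_fused_attn_fwd_workspace_sizes. A usage sketch; the argument order follows the C++ signature above, while the enum spellings on the module are an assumption.

    import transformer_engine_jax

    wk_shape, wk_dtype = transformer_engine_jax.get_self_fused_attn_fwd_workspace_sizes(
        32,                                  # batch_size
        512,                                 # max_seqlen (q == kv for self-attention)
        16,                                  # num_heads
        64,                                  # head_dim
        1.0 / 8.0,                           # scaling_factor = 1/sqrt(head_dim)
        0.1,                                 # dropout_probability
        transformer_engine_jax.NVTE_Bias_Type.NVTE_NO_BIAS,
        transformer_engine_jax.NVTE_Mask_Type.NVTE_CAUSAL_MASK,
        transformer_engine_jax.DType.kBFloat16,
        True,                                # is_training
    )
    # The resulting (shape, dtype) sizes the extra workspace output that
    # SelfFusedAttnForward picks up as buffers[7].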
buffers[7]; - auto batch = descriptor.batch; - auto num_head = descriptor.num_head; - auto num_gqa_groups = descriptor.num_gqa_groups; - auto q_max_seqlen = descriptor.q_max_seqlen; - auto kv_max_seqlen = descriptor.kv_max_seqlen; + // tensor sizes + auto batch_size = descriptor.batch_size; + auto max_seqlen = descriptor.q_max_seqlen; + auto num_heads = descriptor.num_heads; auto head_dim = descriptor.head_dim; auto dropout_probability = descriptor.dropout_probability; auto bias_type = descriptor.bias_type; auto mask_type = descriptor.mask_type; - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD; - - NVTE_CHECK(q_max_seqlen == kv_max_seqlen, - "q_max_seqlen should be equal to kv_max_seqlen in the self attention."); - - NVTE_CHECK(num_head == num_gqa_groups, - "num_head should be equal to num_gqa_groups in the qkvpacked attention"); auto dtype = descriptor.dtype; - auto qkv_shape = std::vector{batch * q_max_seqlen, 3, num_head, head_dim}; - auto bias_shape = std::vector{1, num_head, q_max_seqlen, kv_max_seqlen}; + auto qkv_shape = std::vector{batch_size * max_seqlen, 3, num_heads, head_dim}; + auto bias_shape = std::vector{1, num_heads, max_seqlen, max_seqlen}; // input tensors auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); auto bias_tensor = TensorWrapper(bias, bias_shape, dtype); - auto cu_seqlens_tensor = - TensorWrapper(cu_seqlens, std::vector{batch + 1}, DType::kInt32); + auto cu_seqlens_tensor = TensorWrapper( + cu_seqlens, std::vector{batch_size + 1}, DType::kInt32); // output tensors - auto o_tensor = - TensorWrapper(output, std::vector{batch * q_max_seqlen, num_head, head_dim}, dtype); - - // F16 doesn't use this tensor - auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); + auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in FP16/BF16 + auto o_tensor = TensorWrapper( + output, std::vector{batch_size * max_seqlen, num_heads, head_dim}, dtype); - // aux tensors + // prep RNG state + constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD; auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); + auto backend = nvte_get_fused_attn_backend( + static_cast(dtype), static_cast(dtype), qkv_layout, + bias_type, mask_type, dropout_probability, num_heads, num_heads, + max_seqlen, max_seqlen, head_dim); + PopulateRngStateAsync(rng_state, seed, max_seqlen, max_seqlen, backend, stream); - auto backend = - nvte_get_fused_attn_backend(static_cast(dtype), static_cast(dtype), - qkv_layout, bias_type, mask_type, dropout_probability, num_head, - num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim); - PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); - + // auxiliary tensors (to be propagated to the backward pass later) NVTETensorPack aux_output_tensors; nvte_tensor_pack_create(&aux_output_tensors); + PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend, + softmax_aux); - TensorWrapper query_workspace_tensor; + // cuDNN workspace + auto wkspace_size = std::vector{descriptor.wkspace_size}; + auto wkspace_dtype = descriptor.wkspace_dtype; + auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype); nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(), - rng_state_tensor.data(), q_max_seqlen, descriptor.is_training, + rng_state_tensor.data(), max_seqlen, descriptor.is_training, descriptor.scaling_factor, dropout_probability, qkv_layout, - 
bias_type, mask_type, query_workspace_tensor.data(), stream); + bias_type, mask_type, workspace_tensor.data(), stream); - auto *output_s = reinterpret_cast(aux_output_tensors.tensors[0]); - output_s->data.dptr = softmax_aux; + nvte_tensor_pack_destroy(&aux_output_tensors); +} - auto workspace_size = query_workspace_tensor.shape().data[0]; - auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size); - auto workspace_tensor = - TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype()); +pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes( + size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim, + float scaling_factor, float dropout_probability, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training +) { + constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD; - nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), - o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(), - rng_state_tensor.data(), q_max_seqlen, descriptor.is_training, - descriptor.scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, workspace_tensor.data(), stream); + auto qkv_shape = std::vector{batch_size * max_seqlen, 3, num_heads, head_dim}; + auto output_shape = std::vector{batch_size * max_seqlen, num_heads, head_dim}; + auto bias_shape = std::vector{1, num_heads, max_seqlen, max_seqlen}; - nvte_tensor_pack_destroy(&aux_output_tensors); + auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); + auto output_tensor = TensorWrapper(nullptr, output_shape, dtype); + auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype); + // F16 doesn't use this tensor + auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); + + auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); + auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype); + + auto cu_seqlens_tensor = TensorWrapper(nullptr, std::vector{batch_size + 1}, + DType::kInt32); + + NVTETensorPack aux_input_tensors; + nvte_tensor_pack_create(&aux_input_tensors); + + TensorWrapper query_workspace_tensor; + nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), + cu_seqlens_tensor.data(), max_seqlen, scaling_factor, + dropout_probability, qkv_layout, bias_type, mask_type, + query_workspace_tensor.data(), nullptr); + + auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); + return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); } void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, @@ -842,7 +1038,7 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq const CustomCallFusedAttnDescriptor &descriptor = *UnpackOpaque(opaque, opaque_len); - // input + // input buffers from XLA void *qkv = buffers[0]; void *bias = buffers[1]; void *softmax_aux = buffers[2]; @@ -851,82 +1047,107 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq void *doutput = buffers[5]; void *cu_seqlens = buffers[6]; - // output + // output buffers from XLA void *dqkv = buffers[7]; void *dbias = buffers[8]; + void *workspace = buffers[9]; - auto batch = descriptor.batch; - auto num_head = descriptor.num_head; - auto num_gqa_groups = descriptor.num_gqa_groups; - auto q_max_seqlen = descriptor.q_max_seqlen; - auto kv_max_seqlen 
= descriptor.kv_max_seqlen; + // tensor sizes + auto batch_size = descriptor.batch_size; + auto max_seqlen = descriptor.q_max_seqlen; + auto num_heads = descriptor.num_heads; auto head_dim = descriptor.head_dim; + auto scaling_factor = descriptor.scaling_factor; auto dropout_probability = descriptor.dropout_probability; auto bias_type = descriptor.bias_type; auto mask_type = descriptor.mask_type; - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD; - - NVTE_CHECK(q_max_seqlen == kv_max_seqlen, - "q_max_seqlen should be equal to kv_max_seqlen in the self attention."); - - NVTE_CHECK(num_head == num_gqa_groups, - "num_head should be equal to num_gqa_groups in the qkvpacked attention"); auto dtype = descriptor.dtype; - auto qkv_shape = std::vector{batch * q_max_seqlen, 3, num_head, head_dim}; - auto output_shape = std::vector{batch * q_max_seqlen, num_head, head_dim}; - auto bias_shape = std::vector{1, num_head, q_max_seqlen, kv_max_seqlen}; + auto qkv_shape = std::vector{batch_size * max_seqlen, 3, num_heads, head_dim}; + auto output_shape = std::vector{batch_size * max_seqlen, num_heads, head_dim}; + auto bias_shape = std::vector{1, num_heads, max_seqlen, max_seqlen}; + // input tensors auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); auto output_tensor = TensorWrapper(output, output_shape, dtype); auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype); - // F16 doesn't use this tensor - auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); + // output tensors + auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in FP16/BF16 auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype); auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype); - auto cu_seqlens_tensor = - TensorWrapper(cu_seqlens, std::vector{batch + 1}, DType::kInt32); + TensorWrapper(cu_seqlens, std::vector{batch_size + 1}, DType::kInt32); - // TODO: needs to think about how to pass aux_output_tensors - NVTETensorPack aux_output_tensors; - nvte_tensor_pack_create(&aux_output_tensors); + // auxiliary tensors (propagated from the forward pass) + NVTETensorPack aux_input_tensors; + nvte_tensor_pack_create(&aux_input_tensors); + constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD; + auto backend = nvte_get_fused_attn_backend( + static_cast(dtype), static_cast(dtype), qkv_layout, + bias_type, mask_type, dropout_probability, num_heads, num_heads, + max_seqlen, max_seqlen, head_dim); + PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, + softmax_aux, rng_state, bias); - aux_output_tensors.size = 3; - auto *output_s = reinterpret_cast(aux_output_tensors.tensors[0]); - output_s->data.dptr = softmax_aux; - auto *rng_state_tensor = reinterpret_cast(aux_output_tensors.tensors[1]); - rng_state_tensor->data.shape = std::vector{2}; - rng_state_tensor->data.dtype = DType::kInt64; - rng_state_tensor->data.dptr = rng_state; - auto *bias_tensor = reinterpret_cast(aux_output_tensors.tensors[2]); - bias_tensor->data = SimpleTensor(bias, bias_shape, dtype); - - TensorWrapper query_workspace_tensor; + // cuDNN workspace + auto wkspace_size = std::vector{descriptor.wkspace_size}; + auto wkspace_dtype = descriptor.wkspace_dtype; + auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype); nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16 - &aux_output_tensors, dqkv_tensor.data(), dbias_tensor.data(), - cu_seqlens_tensor.data(), 
q_max_seqlen, descriptor.scaling_factor, + &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), + cu_seqlens_tensor.data(), max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - query_workspace_tensor.data(), stream); + workspace_tensor.data(), stream); - size_t workspace_size = query_workspace_tensor.shape().data[0]; - auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size); - auto workspace_tensor = - TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype()); + nvte_tensor_pack_destroy(&aux_input_tensors); +} - nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_output_tensors, dqkv_tensor.data(), dbias_tensor.data(), - cu_seqlens_tensor.data(), q_max_seqlen, descriptor.scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, - workspace_tensor.data(), stream); +pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes( + size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, + size_t num_heads, size_t num_gqa_groups, size_t head_dim, + float scaling_factor, float dropout_probability, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training +) { + constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD; - nvte_tensor_pack_destroy(&aux_output_tensors); + auto q_shape = std::vector{batch_size * q_max_seqlen, num_heads, head_dim}; + auto kv_shape = std::vector{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim}; + auto bias_shape = std::vector{1, num_heads, q_max_seqlen, kv_max_seqlen}; + + auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); + auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype); + + // TODO(rewang): add bias for cross attn? 
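The cross-attention (KV-packed) sizing helper differs only in taking separate q/kv sequence lengths and a num_gqa_groups that may be smaller than num_heads. A sketch of the corresponding Python query; module import, enum access, and values are illustrative assumptions. The *_bwd_workspace_sizes variants take the same arguments and size the workspace for the backward kernels the same way.

    import transformer_engine_jax

    wk_shape, wk_dtype = transformer_engine_jax.get_cross_fused_attn_fwd_workspace_sizes(
        32, 512, 128,                        # batch_size, q_max_seqlen, kv_max_seqlen
        16, 4, 64,                           # num_heads, num_gqa_groups, head_dim
        1.0 / 8.0,                           # scaling_factor
        0.0,                                 # dropout_probability
        transformer_engine_jax.NVTE_Bias_Type.NVTE_NO_BIAS,
        transformer_engine_jax.NVTE_Mask_Type.NVTE_PADDING_MASK,
        transformer_engine_jax.DType.kBFloat16,
        False,                               # is_training
    )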
+ auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype); + + // FP16/BF16 doesn't use this tensor + auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); + auto o_tensor = TensorWrapper(nullptr, q_shape, dtype); + + auto q_cu_seqlens_tensor = TensorWrapper( + nullptr, std::vector{batch_size + 1}, DType::kInt32); + auto kv_cu_seqlens_tensor = TensorWrapper( + nullptr, std::vector{batch_size + 1}, DType::kInt32); + + auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector{2}, DType::kInt64); + + NVTETensorPack aux_output_tensors; + nvte_tensor_pack_create(&aux_output_tensors); + + TensorWrapper query_workspace_tensor; + nvte_fused_attn_fwd_kvpacked( + q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), + &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + query_workspace_tensor.data(), nullptr); + + auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); + return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); } void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, @@ -934,7 +1155,7 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq const CustomCallFusedAttnDescriptor &descriptor = *UnpackOpaque(opaque, opaque_len); - // input + // input buffers from XLA void *q = buffers[0]; void *kv = buffers[1]; void *bias = buffers[2]; @@ -942,83 +1163,115 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq void *kv_cu_seqlens = buffers[4]; void *seed = buffers[5]; - // output + // output buffers from XLA void *output = buffers[6]; void *softmax_aux = buffers[7]; void *rng_state = buffers[8]; + void *workspace = buffers[9]; - auto batch = descriptor.batch; - auto num_head = descriptor.num_head; - auto num_gqa_groups = descriptor.num_gqa_groups; + // tensor sizes + auto batch_size = descriptor.batch_size; auto q_max_seqlen = descriptor.q_max_seqlen; auto kv_max_seqlen = descriptor.kv_max_seqlen; + auto num_heads = descriptor.num_heads; + auto num_gqa_groups = descriptor.num_gqa_groups; auto head_dim = descriptor.head_dim; + auto scaling_factor = descriptor.scaling_factor; auto dropout_probability = descriptor.dropout_probability; auto bias_type = descriptor.bias_type; auto mask_type = descriptor.mask_type; - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD; - auto dtype = descriptor.dtype; - auto q_shape = std::vector{batch * q_max_seqlen, num_head, head_dim}; - auto kv_shape = std::vector{batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; - auto bias_shape = std::vector{1, num_head, q_max_seqlen, kv_max_seqlen}; + auto q_shape = std::vector{batch_size * q_max_seqlen, num_heads, head_dim}; + auto kv_shape = std::vector{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim}; + auto bias_shape = std::vector{1, num_heads, q_max_seqlen, kv_max_seqlen}; // input tensors + auto dtype = descriptor.dtype; auto q_tensor = TensorWrapper(q, q_shape, dtype); auto kv_tensor = TensorWrapper(kv, kv_shape, dtype); - auto bias_tensor = TensorWrapper(bias, bias_shape, dtype); + // output tensors + auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in FP16/BF16 + auto o_tensor = TensorWrapper(output, q_shape, dtype); auto q_cu_seqlens_tensor = - TensorWrapper(q_cu_seqlens, std::vector{batch + 1}, DType::kInt32); + 
TensorWrapper(q_cu_seqlens, std::vector{batch_size + 1}, DType::kInt32); auto kv_cu_seqlens_tensor = - TensorWrapper(kv_cu_seqlens, std::vector{batch + 1}, DType::kInt32); - - // output tensors - auto o_tensor = - TensorWrapper(output, std::vector{batch * q_max_seqlen, num_head, head_dim}, dtype); - - // aux tensors - - // F16 doesn't use s_tensor - auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); + TensorWrapper(kv_cu_seqlens, std::vector{batch_size + 1}, DType::kInt32); + // prep RNG state + constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD; auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); - - auto backend = - nvte_get_fused_attn_backend(static_cast(dtype), static_cast(dtype), - qkv_layout, bias_type, mask_type, dropout_probability, num_head, - num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim); + auto backend = nvte_get_fused_attn_backend( + static_cast(dtype), static_cast(dtype), qkv_layout, + bias_type, mask_type, dropout_probability, num_heads, num_gqa_groups, + q_max_seqlen, kv_max_seqlen, head_dim); PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); + // auxiliary tensors (to be propagated to the backward pass later) NVTETensorPack aux_output_tensors; nvte_tensor_pack_create(&aux_output_tensors); + PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend, + softmax_aux); - TensorWrapper query_workspace_tensor; + // cuDNN workspace + auto workspace_tensor = TensorWrapper( + workspace, std::vector{descriptor.wkspace_size}, descriptor.wkspace_dtype); nvte_fused_attn_fwd_kvpacked( q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training, - descriptor.scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - query_workspace_tensor.data(), stream); + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + workspace_tensor.data(), stream); - auto *output_s = reinterpret_cast(aux_output_tensors.tensors[0]); - output_s->data.dptr = softmax_aux; + nvte_tensor_pack_destroy(&aux_output_tensors); +} - auto workspace_size = query_workspace_tensor.shape().data[0]; - auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size); - auto workspace_tensor = - TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype()); +pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes( + size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, + size_t num_heads, size_t num_gqa_groups, size_t head_dim, + float scaling_factor, float dropout_probability, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training +) { + constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD; - nvte_fused_attn_fwd_kvpacked( - q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), - &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training, - descriptor.scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - workspace_tensor.data(), stream); + auto q_shape = std::vector{batch_size * q_max_seqlen, num_heads, head_dim}; + auto kv_shape = std::vector{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim}; + auto output_shape = std::vector{batch_size * q_max_seqlen, 
num_heads, head_dim}; + auto bias_shape = std::vector{1, num_heads, q_max_seqlen, kv_max_seqlen}; - nvte_tensor_pack_destroy(&aux_output_tensors); + auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); + auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype); + auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype); + auto output_tensor = TensorWrapper(nullptr, output_shape, dtype); + // FP16/BF16 doesn't use this tensor + auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); + + auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); + auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype); + auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype); + + auto q_cu_seqlens_tensor = TensorWrapper( + nullptr, std::vector{batch_size + 1}, DType::kInt32); + auto kv_cu_seqlens_tensor = TensorWrapper( + nullptr, std::vector{batch_size + 1}, DType::kInt32); + + NVTETensorPack aux_input_tensors; + nvte_tensor_pack_create(&aux_input_tensors); + + TensorWrapper query_workspace_tensor; + nvte_fused_attn_bwd_kvpacked( + q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), + s_tensor.data(), // not used for FP16/BF16 + s_tensor.data(), // not used for FP16/BF16 + &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, + scaling_factor, dropout_probability, qkv_layout, + bias_type, mask_type, query_workspace_tensor.data(), nullptr); + + auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); + return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); } void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, @@ -1026,7 +1279,7 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa const CustomCallFusedAttnDescriptor &descriptor = *UnpackOpaque(opaque, opaque_len); - // input + // input buffers from XLA void *q = buffers[0]; void *kv = buffers[1]; void *bias = buffers[2]; @@ -1037,85 +1290,72 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa void *q_cu_seqlens = buffers[7]; void *kv_cu_seqlens = buffers[8]; - // output + // output buffers from XLA void *dq = buffers[9]; void *dkv = buffers[10]; void *dbias = buffers[11]; + void *workspace = buffers[12]; - auto batch = descriptor.batch; - auto num_head = descriptor.num_head; - auto num_gqa_groups = descriptor.num_gqa_groups; + // tensor sizes + auto batch_size = descriptor.batch_size; auto q_max_seqlen = descriptor.q_max_seqlen; auto kv_max_seqlen = descriptor.kv_max_seqlen; + auto num_heads = descriptor.num_heads; + auto num_gqa_groups = descriptor.num_gqa_groups; auto head_dim = descriptor.head_dim; + auto scaling_factor = descriptor.scaling_factor; auto dropout_probability = descriptor.dropout_probability; auto bias_type = descriptor.bias_type; auto mask_type = descriptor.mask_type; - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD; - auto dtype = descriptor.dtype; - auto q_shape = std::vector{batch * q_max_seqlen, num_head, head_dim}; - auto kv_shape = std::vector{batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; - auto output_shape = std::vector{batch * q_max_seqlen, num_head, head_dim}; - auto bias_shape = std::vector{1, num_head, q_max_seqlen, kv_max_seqlen}; + auto q_shape = std::vector{batch_size * q_max_seqlen, num_heads, head_dim}; + auto kv_shape = std::vector{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim}; + auto output_shape = 
+    auto bias_shape = std::vector<size_t>{1, num_heads, q_max_seqlen, kv_max_seqlen};
 
+    // input tensors
+    auto dtype = descriptor.dtype;
     auto q_tensor = TensorWrapper(q, q_shape, dtype);
     auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
     auto output_tensor = TensorWrapper(output, output_shape, dtype);
     auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
-    // F16 doesn't use this tensor
-    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
+    // output tensors
+    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in FP16/BF16
     auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
     auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype);
     auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
+    auto q_cu_seqlens_tensor = TensorWrapper(
+        q_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
+    auto kv_cu_seqlens_tensor = TensorWrapper(
+        kv_cu_seqlens, std::vector<size_t>{batch_size + 1}, DType::kInt32);
+
+    // auxiliary tensors (propagated from the forward pass)
+    NVTETensorPack aux_input_tensors;
+    nvte_tensor_pack_create(&aux_input_tensors);
+    constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
+    auto backend = nvte_get_fused_attn_backend(
+        static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
+        bias_type, mask_type, dropout_probability, num_heads, num_gqa_groups,
+        q_max_seqlen, kv_max_seqlen, head_dim);
+    PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend,
+                                       softmax_aux, rng_state, bias);
 
-    auto q_cu_seqlens_tensor =
-        TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
-    auto kv_cu_seqlens_tensor =
-        TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
-
-    // TODO(rewang): need to think about how to pass aux_output_tensors
-    NVTETensorPack aux_output_tensors;
-    nvte_tensor_pack_create(&aux_output_tensors);
-
-    aux_output_tensors.size = 3;
-    auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
-    output_s->data.dptr = softmax_aux;
-    auto *rng_state_tensor = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[1]);
-    rng_state_tensor->data.shape = std::vector<size_t>{2};
-    rng_state_tensor->data.dtype = DType::kInt64;
-    rng_state_tensor->data.dptr = rng_state;
-    auto *bias_tensor = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[2]);
-    bias_tensor->data = SimpleTensor(bias, bias_shape, dtype);
-
-    TensorWrapper query_workspace_tensor;
-
-    nvte_fused_attn_bwd_kvpacked(
-        q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
-        s_tensor.data(),  // not used for FP16/BF16
-        s_tensor.data(),  // not used for FP16/BF16
-        &aux_output_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
-        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
-        descriptor.scaling_factor, dropout_probability, NVTE_QKV_Layout::NVTE_BSHD_BS2HD, bias_type,
-        mask_type, query_workspace_tensor.data(), stream);
-
-    size_t workspace_size = query_workspace_tensor.shape().data[0];
-    auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);
-
-    auto workspace_tensor =
-        TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
+    // cuDNN workspace
+    auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
+    auto wkspace_dtype = descriptor.wkspace_dtype;
+    auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);
 
     nvte_fused_attn_bwd_kvpacked(
         q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
         s_tensor.data(),  // not used for FP16/BF16
         s_tensor.data(),  // not used for FP16/BF16
-        &aux_output_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
+        &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
         q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
-        descriptor.scaling_factor, dropout_probability, NVTE_QKV_Layout::NVTE_BSHD_BS2HD, bias_type,
-        mask_type, workspace_tensor.data(), stream);
+        scaling_factor, dropout_probability, qkv_layout,
+        bias_type, mask_type, workspace_tensor.data(), stream);
 
-    nvte_tensor_pack_destroy(&aux_output_tensors);
+    nvte_tensor_pack_destroy(&aux_input_tensors);
 }
 
 }  // namespace jax
diff --git a/transformer_engine/jax/csrc/modules.h b/transformer_engine/jax/csrc/modules.h
index 2878a15e91..471cea8a65 100644
--- a/transformer_engine/jax/csrc/modules.h
+++ b/transformer_engine/jax/csrc/modules.h
@@ -52,68 +52,69 @@ struct CustomCallCommonDescriptor {
 pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
                                                DType out_dtype);
 
-struct CustomCallGemmDescriptor {
-    size_t m;
-    size_t n;
-    size_t k;
-    DType A_dtype;
-    DType B_dtype;
-    DType D_dtype;
-    bool transa;
-    bool transb;
-    bool use_split_accumulator;
-};
-
-pybind11::bytes PackCustomCallGemmDescriptor(size_t m, size_t n, size_t k, DType A_dtype,
-                                             DType B_dtype, DType D_dtype, bool transa, bool transb,
-                                             bool use_split_accumulator);
-
 struct CustomCallNormDescriptor {
-    size_t n;
-    size_t hidden;
+    size_t batch_size;
+    size_t hidden_size;
+    size_t wkspace_size;
+    size_t barrier_size;
+    size_t *dgamma_part_sizes;  // 2D tensor
+    size_t *dbeta_part_sizes;   // 2D tensor
     DType x_dtype;
     DType w_dtype;
+    DType wkspace_dtype;
+    DType barrier_dtype;
+    DType dgamma_part_dtype;
+    DType dbeta_part_dtype;
     bool zero_centered_gamma;
     float eps;
     int sm_margin;
 };
 
-pybind11::bytes PackCustomCallNormDescriptor(size_t n, size_t hidden, DType x_dtype, DType w_dtype,
+pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
+                                             size_t wkspace_size, size_t barrier_size,
+                                             size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
+                                             DType x_dtype, DType w_dtype,
+                                             DType wkspace_dtype, DType barrier_dtype,
+                                             DType dgamma_part_dtype, DType dbeta_part_dtype,
                                              bool zero_centered_gamma, float eps, int sm_margin);
 
 struct SoftmaxDescriptor {
-    size_t batch;
-    size_t pad_batch;
-    size_t heads;
+    size_t batch_size;
+    size_t padding_size;
+    size_t head_dim;
     size_t q_seqlen;
     size_t k_seqlen;
     DType dtype;
     float scale_factor;
 };
 
-pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch, size_t heads,
-                                                size_t q_seqlen, size_t k_seqlen, DType dtype,
-                                                float scale_factor);
+pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
+                                                size_t head_dim, size_t q_seqlen, size_t k_seqlen,
+                                                DType dtype, float scale_factor);
 
 struct CustomCallFusedAttnDescriptor {
-    size_t batch;
-    size_t num_head;
-    size_t num_gqa_groups;
+    size_t batch_size;
     size_t q_max_seqlen;
     size_t kv_max_seqlen;
+    size_t num_heads;
+    size_t num_gqa_groups;
     size_t head_dim;
+    size_t wkspace_size;
     float scaling_factor;
     float dropout_probability;
     NVTE_Bias_Type bias_type;
     NVTE_Mask_Type mask_type;
     DType dtype;
+    DType wkspace_dtype;
     bool is_training;
 };
 
 pybind11::bytes PackCustomCallFusedAttnDescriptor(
-    size_t batch, size_t num_head, size_t num_gqa_groups, size_t q_max_seqlen, size_t kv_max_seqlen,
-    size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, DType dtype, bool is_training);
+    size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
+    size_t num_heads, size_t num_gqa_groups, size_t head_dim, size_t wkspace_size,
+    float scaling_factor, float dropout_probability,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    DType dtype, DType wkspace_dtype, bool is_training);
 
 NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                             NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
@@ -135,13 +136,21 @@ void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t
 void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                              size_t opaque_len);
 
-void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
+pybind11::tuple GetLayerNormForwardWorkspaceSizes(
+    size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype,
+    bool is_layer_norm, bool zero_centered_gamma, float eps
+);
 
 void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
 
 void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
                          size_t opaque_len);
 
+pybind11::tuple GetLayerNormBackwardWorkspaceSizes(
+    size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, bool is_layer_norm,
+    bool zero_centered_gamma, float eps
+);
+
 void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
 
 void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
@@ -172,15 +181,41 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers,
 void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers,
                                             const char *opaque, std::size_t opaque_len);
 
+pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
+    size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim,
+    float scaling_factor, float dropout_probability,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
+);
+
 void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
                           size_t opaque_len);
 
+pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
+    size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim,
+    float scaling_factor, float dropout_probability,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
+);
+
 void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
                            size_t opaque_len);
 
+pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
+    size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
+    size_t num_heads, size_t num_gqa_groups, size_t head_dim,
+    float scaling_factor, float dropout_probability,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
+);
+
 void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
                            size_t opaque_len);
 
+pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
+    size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
+    size_t num_heads, size_t num_gqa_groups, size_t head_dim,
+    float scaling_factor, float dropout_probability,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
+);
+
 void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
                             size_t opaque_len);
 
diff --git a/transformer_engine/jax/csrc/utils.h b/transformer_engine/jax/csrc/utils.h
index 2db9e85dca..01e10de984 100644
--- a/transformer_engine/jax/csrc/utils.h
+++ b/transformer_engine/jax/csrc/utils.h
@@ -28,66 +28,6 @@ void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q
                            size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
                            cudaStream_t stream);
 
-class WorkspaceManager {
- public:
-    static WorkspaceManager &Instance() {
-        static thread_local WorkspaceManager instance;
-        return instance;
-    }
-
-    WorkspaceManager() {}
-    ~WorkspaceManager() { Clear_(); }
-
-    void *GetWorkspace(size_t size = 4194304) {
-        ReallocateIfNeed_(size);
-        return workspace_;
-    }
-
-    template <typename... Args>
-    inline auto GetWorkspace(Args... args) {
-        auto asks = std::array<size_t, sizeof...(Args)>{args...};
-        std::array<size_t, sizeof...(Args) + 1> offsets = {0};
-        std::array<void *, sizeof...(Args)> workspaces = {nullptr};
-        std::transform_inclusive_scan(
-            asks.cbegin(), asks.cend(), offsets.begin() + 1, std::plus{},
-            [=](auto x) { return PadSize_(x); }, 0);
-        auto *workspace = GetWorkspace(offsets.back());
-        std::transform(offsets.cbegin(), offsets.cend() - 1, workspaces.begin(),
-                       [workspace](auto x) { return static_cast<char *>(workspace) + x; });
-        return workspaces;
-    }
-
- private:
-    void *workspace_ = nullptr;
-    size_t size_ = 0;
-
-    size_t PadSize_(size_t size) {
-        constexpr size_t alignment = 128;
-        return ((size + alignment - 1) / alignment) * alignment;
-    }
-
-    void Clear_() {
-        if (workspace_ != nullptr) {
-            NVTE_CHECK_CUDA(cudaFree(workspace_));
-        }
-        workspace_ = nullptr;
-        size_ = 0;
-    }
-
-    void Allocate_(size_t new_size) {
-        new_size = PadSize_(new_size);
-        NVTE_CHECK_CUDA(cudaMalloc(&workspace_, new_size));
-        size_ = new_size;
-    }
-
-    void ReallocateIfNeed_(size_t new_size) {
-        if (new_size > size_) {
-            Clear_();
-            Allocate_(new_size);
-        }
-    }
-};
-
 class cudaDevicePropertiesManager {
  public:
     static cudaDevicePropertiesManager &Instance() {
diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py
index c1c5ecc490..2cb971d306 100644
--- a/transformer_engine/jax/flax/module.py
+++ b/transformer_engine/jax/flax/module.py
@@ -22,7 +22,7 @@ from ..fp8 import FP8Helper, FP8MetaPackage
 
 from ..layernorm import canonicalize_layernorm_type
 from ..layernorm import layernorm, layernorm_fp8_dot
-from ..mlp import layernrom_geglu_fp8_mlp, geglu
+from ..mlp import layernorm_geglu_fp8_mlp, geglu
 from ..softmax import is_softmax_kernel_available
 from ..softmax import softmax, SoftmaxType
 
@@ -886,7 +886,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
         if use_fused_ln_mlp:
             assert self.axis == -1    # Only support axis = -1 at this moment
-            out = layernrom_geglu_fp8_mlp(y,
+            out = layernorm_geglu_fp8_mlp(y,
                                           scale,
                                           ln_bias,
                                           [kernel_1, kernel_2],
                                           fp8_meta_package,
diff --git a/transformer_engine/jax/mlp.py b/transformer_engine/jax/mlp.py
index b390ba5910..f2466a90b8 100644
--- a/transformer_engine/jax/mlp.py
+++ b/transformer_engine/jax/mlp.py
@@ -55,7 +55,7 @@ def _geglu_bwd_rule(ctx, g):
 _geglu.defvjp(_geglu_fwd_rule, _geglu_bwd_rule)
 
 
-def layernrom_geglu_fp8_mlp(x: jnp.ndarray,
+def layernorm_geglu_fp8_mlp(x: jnp.ndarray,
                             gamma: jnp.ndarray,
                             beta: jnp.ndarray,
                             kernels: List[jnp.ndarray],
@@ -86,25 +86,25 @@ def layernrom_geglu_fp8_mlp(x: jnp.ndarray,
     assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
         "if layernorm_type is 'rmsnorm'"
 
-    output = _layernrom_geglu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax, scale,
+    output = _layernorm_geglu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax, scale,
                                       scale_inv, fwd_dtype, bwd_dtype, layernorm_type,
                                       zero_centered_gamma, epsilon)
     return output
 
 
 @partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13))
-def _layernrom_geglu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
+def _layernorm_geglu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
                              kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_max: jnp.ndarray,
                              amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
                              fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype, layernorm_type: str,
                              zero_centered_gamma: bool, epsilon: float):
-    output, _ = _layernrom_geglu_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax,
+    output, _ = _layernorm_geglu_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax,
                                                   scale, scale_inv, fwd_dtype, bwd_dtype,
                                                   layernorm_type, zero_centered_gamma, epsilon)
     return output
 
 
-def _layernrom_geglu_fp8_mlp_fwd_rule(
+def _layernorm_geglu_fp8_mlp_fwd_rule(
         x,
         gamma,
         beta,
@@ -209,7 +209,7 @@ def _layernrom_geglu_fp8_mlp_fwd_rule(
     return dot_2_output, ctx
 
 
-def _layernrom_geglu_fp8_mlp_bwd_rule(
+def _layernorm_geglu_fp8_mlp_bwd_rule(
         fwd_dtype,    # pylint: disable=unused-argument
         bwd_dtype,
         layernorm_type,
@@ -307,5 +307,5 @@ def _layernrom_geglu_fp8_mlp_bwd_rule(
         fp8_max, amax, scale, scale_inv
 
 
-_layernrom_geglu_fp8_mlp.defvjp(_layernrom_geglu_fp8_mlp_fwd_rule,
-                                _layernrom_geglu_fp8_mlp_bwd_rule)
+_layernorm_geglu_fp8_mlp.defvjp(_layernorm_geglu_fp8_mlp_fwd_rule,
+                                _layernorm_geglu_fp8_mlp_bwd_rule)
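
Illustrative sketch (an assumption, not part of the patch and not the TE-JAX API): the Get*WorkspaceSizes helpers above let the Python side query the workspace (shape, dtype) ahead of time and expose it as an extra primitive output, so the buffer is allocated and owned by XLA instead of a cudaMalloc inside the custom op. The names get_workspace_sizes and fused_op below are hypothetical stand-ins that only demonstrate this buffer-ownership pattern.

import jax
import jax.numpy as jnp


def get_workspace_sizes(batch_size: int, hidden_size: int):
    # Hypothetical stand-in for a C++ sizing helper: report the workspace
    # (shape, dtype) without allocating any device memory.
    return (batch_size * hidden_size,), jnp.uint8


def fused_op(x):
    # A real primitive would declare the workspace as an additional abstract
    # output so XLA allocates the buffer and hands it to the kernel alongside
    # the real outputs; here a plain zeros array stands in for that buffer.
    wk_shape, wk_dtype = get_workspace_sizes(*x.shape)
    workspace = jnp.zeros(wk_shape, wk_dtype)  # placeholder for the XLA-owned buffer
    return x * 2.0, workspace


out, _ = jax.jit(fused_op)(jnp.ones((4, 8), jnp.float32))

Letting XLA own the workspace removes the per-call cudaMalloc/cudaFree of the old WorkspaceManager and lets the allocation take part in XLA's usual buffer reuse.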