[JAX] Custom Op Workspace Tensors from XLA Buffers (#532)
* Removed cudaMalloc/WorkspaceManager in JAX csrc. JAX custom ops now request buffers from XLA for their workspace tensors (see the sketch after this commit message).

Signed-off-by: Alp Dener <adener@nvidia.com>

* removed unused GEMM C++ API in TE-JAX

Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed typo in layernorm_geglu_fp8_mlp and removed unnecessary shape reductions in primitives

Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed import order for linting

Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed custom op errors due to incorrect static arg nums in JAX jit

Signed-off-by: Alp Dener <adener@nvidia.com>

* shifted cudnnSetStream further down in the kernel implementations to avoid an error when executing a dummy kernel call with a nullptr stream

Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed linting errors for blank lines

Signed-off-by: Alp Dener <adener@nvidia.com>

---------

Signed-off-by: Alp Dener <adener@nvidia.com>
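
Taken together, the commit's first bullet and the cudnnSetStream bullet describe a two-phase flow: each custom op is first invoked as a dummy call with null buffers and a null stream so it can report its workspace requirement, and on the real call XLA supplies that workspace as an ordinary buffer, replacing the old cudaMalloc/WorkspaceManager path. The following is a minimal, self-contained sketch of that flow, assuming the classic XLA GPU custom-call signature (stream, buffers, opaque); the function names, buffer order, and sizes are illustrative, not the TE-JAX implementation.

```cpp
// Minimal sketch (not TE source) of the workspace-from-XLA pattern.
#include <cstddef>
#include <cstdio>
#include <vector>

using Stream = void*;  // stand-in for cudaStream_t so the sketch compiles anywhere

void fused_kernel_impl(const float* x, float* y, size_t n,
                       void* workspace, size_t* workspace_size, Stream stream) {
  const size_t required = n * sizeof(float);
  if (workspace == nullptr) {
    *workspace_size = required;   // sizing pass: report the requirement and return
    return;                       // before touching the (possibly null) stream
  }
  // Execution pass: stream and data pointers are valid here; the real kernels
  // call cudnnSetStream(handle, stream) at this point, then launch.
  float* scratch = static_cast<float*>(workspace);
  for (size_t i = 0; i < n; ++i) { scratch[i] = x[i] * 2.0f; y[i] = scratch[i] + 1.0f; }
  (void)stream;
}

// Custom-call-style wrapper: every buffer, including the workspace, is handed
// in by XLA in whatever order the JAX primitive declared.
void FusedOpCustomCall(Stream stream, void** buffers,
                       const char* /*opaque*/, size_t /*opaque_len*/) {
  const float* x  = static_cast<const float*>(buffers[0]);
  float* y        = static_cast<float*>(buffers[1]);
  void* workspace = buffers[2];   // XLA-allocated scratch tensor
  size_t unused = 0;
  fused_kernel_impl(x, y, 4, workspace, &unused, stream);
}

int main() {
  float x[4] = {1, 2, 3, 4}, y[4] = {0};
  size_t wsize = 0;
  // Sizing pass: the "dummy kernel call" with null workspace and null stream.
  fused_kernel_impl(x, y, 4, nullptr, &wsize, nullptr);
  std::vector<char> xla_buffer(wsize);   // in TE-JAX this buffer comes from XLA
  void* buffers[] = {x, y, xla_buffer.data()};
  FusedOpCustomCall(/*stream=*/nullptr, buffers, nullptr, 0);
  std::printf("workspace bytes = %zu, y[0] = %g\n", wsize, y[0]);  // 16, 3
  return 0;
}
```

In the actual bindings the sizing pass happens while the primitive works out its shapes, so the workspace can be exposed to XLA as just another buffer of the op; the sketch simply compresses both phases into one program.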
denera authored Jan 29, 2024
1 parent bd7fd0a commit 4077ccc
Showing 14 changed files with 1,190 additions and 657 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -25,3 +25,5 @@ tests/cpp/build/
docs/_build
.ipynb_checkpoints
docs/doxygen
*.log
CMakeFiles/CMakeSystem.cmake
4 changes: 2 additions & 2 deletions tests/jax/test_custom_call_compute.py
@@ -20,7 +20,7 @@
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.mlp import layernrom_geglu_fp8_mlp
from transformer_engine.jax.mlp import layernorm_geglu_fp8_mlp

GEMM_CASES = [
(256, 256, 512),
@@ -196,7 +196,7 @@ def primitive_func(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale,
# out = (x * y) * z
fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
return jnp.mean(layernrom_geglu_fp8_mlp(x, ln_s, None, [y, z], fp8_meta_pkg, "rmsnorm"))
return jnp.mean(layernorm_geglu_fp8_mlp(x, ln_s, None, [y, z], fp8_meta_pkg, "rmsnorm"))

def _convert_to_activation_function(fn_or_string):
"""Convert a string to an activation function."""
@@ -59,8 +59,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
cudnn_frontend::DataType_t tensorType,
void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle) {
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));

bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
@@ -248,6 +246,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
return;
}

// cuDNN stream check needs to be moved here to support dummy kernel calls with
// null streams for sizing the cuDNN workspace.
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));

// Build variant pack
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
{Q, devPtrQ},
@@ -300,8 +302,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV,
cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle) {
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));

bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
@@ -519,6 +519,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
return;
}

// cuDNN stream check needs to be moved here to support dummy kernel calls with
// null streams for sizing the cuDNN workspace.
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));

// build variant pack
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
{q, devPtrQ},
@@ -642,8 +642,6 @@ void fused_attn_max_512_fwd_impl(
void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *workspace, size_t *workspace_size,
cudnnDataType_t tensorType, cudaStream_t stream, cudnnHandle_t handle) {
try {
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));

FADescriptor descriptor{b, h,
s_q, s_kv,
d, scaling_factor,
@@ -754,6 +752,10 @@ void fused_attn_max_512_fwd_impl(
return;
}

// cuDNN stream check needs to be moved here to support dummy kernel calls with
// null streams for sizing the cuDNN workspace.
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));

// Prepare actual seqlen
constexpr size_t nthreads_per_block = 128;
const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
@@ -845,9 +847,6 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
size_t *workspace_size, cudnnDataType_t tensorType,
cudaStream_t stream, cudnnHandle_t handle) {
try {
// Create cudnn handle
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));

FADescriptor descriptor{
b, h, s_q, s_kv, d, scaling_factor, true, dropout_probability,
layout, bias_type, mask_type, tensorType, false};
@@ -1194,6 +1193,10 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
return;
}

// cuDNN stream check needs to be moved here to support dummy kernel calls with
// null streams for sizing the cuDNN workspace.
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));

constexpr size_t nthreads_per_block = 128;
const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
12 changes: 8 additions & 4 deletions transformer_engine/common/fused_attn/fused_attn_fp8.cu
@@ -1007,8 +1007,6 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
cudaStream_t stream,
cudnnHandle_t handle_) {
try {
NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream));

FADescriptor descriptor{
b, h, s_q, s_kv, d,
attnScale, isTraining, dropoutProbability, layout,
@@ -1212,6 +1210,10 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
return;
}

// cuDNN stream check needs to be moved here to support dummy kernel calls with
// null streams for sizing the cuDNN workspace.
NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream));

int32_t* qkv_ragged_offset = reinterpret_cast<int32_t*>(
reinterpret_cast<int8_t*>(workspace_ptr) + wkspace_size);
int32_t* o_ragged_offset = reinterpret_cast<int32_t*>(
@@ -1324,8 +1326,6 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
cudaStream_t stream,
cudnnHandle_t handle_) {
try {
NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream));

FADescriptor descriptor{
b, h, s_q, s_kv, d,
attnScale, false, dropoutProbability, layout,
@@ -1745,6 +1745,10 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
return;
}

// cuDNN stream check needs to be moved here to support dummy kernel calls with
// null streams for sizing the cuDNN workspace.
NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream));

int32_t* qkv_ragged_offset = reinterpret_cast<int32_t*>(
reinterpret_cast<int8_t*>(workspace_ptr) + wkspace_size);
int32_t* o_ragged_offset = reinterpret_cast<int32_t*>(
37 changes: 20 additions & 17 deletions transformer_engine/common/layer_norm/ln_api.cpp
@@ -159,14 +159,6 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
const bool fp8_out = is_fp8_dtype(otype);
const auto ctype = layer_norm::DType::kFloat32;

CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
CheckInputTensor(beta, "beta");

CheckOutputTensor(*z, "z");
CheckOutputTensor(*mu, "mu");
CheckOutputTensor(*rsigma, "rsigma");

NVTE_CHECK(x.data.shape.size() == 2);

const size_t rows = x.data.shape[0];
@@ -227,6 +219,16 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size

return;
}

// Tensor checks are delayed here in order to recover workspace sizes with null data
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
CheckInputTensor(beta, "beta");

CheckOutputTensor(*z, "z");
CheckOutputTensor(*mu, "mu");
CheckOutputTensor(*rsigma, "rsigma");

if ( launch_params.barrier_size > 0 ) {
params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int*>(barrier->data.dptr);
@@ -273,15 +275,6 @@ void layernorm_bwd(const Tensor& dz,
auto otype = wtype;
auto ctype = DType::kFloat32;

CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(mu, "mu");
CheckInputTensor(rsigma, "rsigma");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*dx, "dx");
CheckOutputTensor(*dgamma, "dgamma");
CheckOutputTensor(*dbeta, "dbeta");

NVTE_CHECK(dz.data.dtype == otype);
NVTE_CHECK(mu.data.dtype == ctype);
NVTE_CHECK(rsigma.data.dtype == ctype);
@@ -354,6 +347,16 @@ void layernorm_bwd(const Tensor& dz,
return;
}

// Tensor checks are delayed here in order to recover workspace sizes with null data
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(mu, "mu");
CheckInputTensor(rsigma, "rsigma");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*dx, "dx");
CheckOutputTensor(*dgamma, "dgamma");
CheckOutputTensor(*dbeta, "dbeta");

if ( launch_params.barrier_size > 0 ) {
params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int*>(barrier->data.dptr);
30 changes: 17 additions & 13 deletions transformer_engine/common/rmsnorm/rmsnorm_api.cpp
@@ -113,12 +113,6 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
const bool fp8_out = is_fp8_dtype(otype);
auto ctype = DType::kFloat32;

CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");

CheckOutputTensor(*z, "z");
CheckOutputTensor(*rsigma, "rsigma");

NVTE_CHECK(x.data.shape.size() == 2);

const size_t rows = x.data.shape[0];
@@ -172,6 +166,15 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens

return;
}

// Tensor checks are delayed here in order to recover workspace sizes with null data
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");

CheckOutputTensor(*z, "z");
CheckOutputTensor(*rsigma, "rsigma");


if (launch_params.barrier_size > 0) {
params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int *>(barrier->data.dptr);
@@ -204,13 +207,6 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
auto otype = wtype;
auto ctype = DType::kFloat32;

CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(rsigma, "rsigma");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*dx, "dx");
CheckOutputTensor(*dgamma, "dgamma");

NVTE_CHECK(dz.data.dtype == otype);
NVTE_CHECK(rsigma.data.dtype == ctype);

@@ -268,6 +264,14 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
return;
}

// Tensor checks are delayed here in order to recover workspace sizes with null data
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(rsigma, "rsigma");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*dx, "dx");
CheckOutputTensor(*dgamma, "dgamma");

if (launch_params.barrier_size > 0) {
params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int *>(barrier->data.dptr);
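
The comments added to ln_api.cpp and rmsnorm_api.cpp above describe the same convention from the common library's side: the first call may arrive with tensors whose data pointers are null, purely to recover the workspace and barrier shapes, so the CheckInputTensor/CheckOutputTensor calls have to wait until after that early return. Below is a self-contained sketch of that calling convention; the Tensor stand-in and the size formula are assumptions for illustration, not the transformer_engine types.

```cpp
// Sketch of the delayed-check, two-phase call: pass 1 supplies null data
// pointers just to learn workspace shapes; pass 2 supplies real pointers,
// so the checks run and the kernel would launch.
#include <cstddef>
#include <cstdio>
#include <vector>

struct Tensor {                        // stand-in for transformer_engine::Tensor
  void* dptr = nullptr;
  std::vector<size_t> shape;
};

void norm_fwd_like(const Tensor& x, Tensor* z, Tensor* workspace) {
  if (workspace->dptr == nullptr) {
    workspace->shape = {x.shape[0] * 4};   // made-up requirement for the sketch
    return;                                // no tensor checks on the sizing pass
  }
  // Checks are safe only now, because data pointers are guaranteed non-null.
  if (x.dptr == nullptr || z->dptr == nullptr) { std::puts("bad tensor"); return; }
  // ... launch the normalization kernel using workspace->dptr ...
}

int main() {
  Tensor x{nullptr, {128, 1024}}, z{nullptr, {128, 1024}}, workspace;

  norm_fwd_like(x, &z, &workspace);        // pass 1: shapes only
  std::vector<float> xbuf(128 * 1024), zbuf(128 * 1024), wbuf(workspace.shape[0]);
  x.dptr = xbuf.data(); z.dptr = zbuf.data(); workspace.dptr = wbuf.data();

  norm_fwd_like(x, &z, &workspace);        // pass 2: checks run, kernel launches
  std::printf("workspace elements: %zu\n", workspace.shape[0]);
  return 0;
}
```

The fused-attention diffs follow the same reasoning: cudnnSetStream is deferred past the sizing-only early return so that those dummy calls can pass a null stream safely.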