Support GEMM-GELU fusion with split AG overlap (#661)
* Support GEMM-GELU fusion with split AG overlap

Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>

* Fix linter complaints

Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jaemin Choi <minitu77@gmail.com>

* Avoid code duplication

Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>

* Fix issue with modifying tuple

Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>

* Disable GEMM-GELU fusion when split AG overlap is not enabled

Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>

* Add ub_split_ag parameter to LayerNormMLP unit test

Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>

* Move knob into LayerNormMLP, auto-disable fusion when split AG overlap is not enabled

Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>

* Revert changes to test_layernorm_mlp_accuracy

Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jaemin Choi <minitu77@gmail.com>

---------

Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>
Signed-off-by: Jaemin Choi <minitu77@gmail.com>
Co-authored-by: Jaemin Choi <jaeminc@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
3 people authored Feb 12, 2024
1 parent a950061 · commit a174985
Showing 2 changed files with 51 additions and 9 deletions.
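The change gates the new fusion behind an environment variable and only enables it when the GELU activation and split all-gather (AG) overlap are both in use; otherwise the module silently falls back to the unfused path. A minimal usage sketch follows. The NVTE_GEMM_GELU_FUSION knob, the activation and ub_split_ag arguments, and the LayerNormMLP module come from this commit and the existing API; the layer sizes and the omitted distributed/userbuffers setup are illustrative assumptions only.

# Hedged sketch: enabling the GEMM-GELU fusion added in this commit.
# The env knob, activation="gelu", and ub_split_ag=True mirror the diff;
# sizes and the distributed/userbuffers initialization are illustrative
# assumptions and are omitted.
import os
os.environ["NVTE_GEMM_GELU_FUSION"] = "1"  # read in LayerNormMLP.__init__

import transformer_engine.pytorch as te

# Requires tensor-parallel userbuffers communicators to be initialized
# beforehand (not shown); without ub_split_ag=True the fusion is auto-disabled.
mlp = te.LayerNormMLP(
    hidden_size=4096,        # illustrative
    ffn_hidden_size=16384,   # illustrative
    activation="gelu",       # fusion requires GELU
    ub_split_ag=True,        # fusion requires split AG overlap
)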
20 changes: 18 additions & 2 deletions transformer_engine/pytorch/csrc/comm_gemm_overlap.h
@@ -793,10 +793,15 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
 
     // Get communication and GEMM output chunk sizes
     const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
-    const int output_chunk_bytes = (n_chunk * m) * HALF_BYTES;
+    const bool do_gelu = pre_gelu_out.numel() > 0;
+    const int output_chunk_bytes = (do_gelu
+                                    ? (n_chunk * m) * D.element_size()
+                                    : (n_chunk * m) * HALF_BYTES);
+    const int aux_chunk_bytes = do_gelu ? (n_chunk * m) * pre_gelu_out.element_size() : 0;
 
     // Get output and workspace data pointers
     char *output_ptr = reinterpret_cast<char *>(D.data_ptr());
+    char *pre_gelu_out_ptr = reinterpret_cast<char *>(pre_gelu_out.data_ptr());
     char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
     int workspace_size_chunk = workspaceSize / _stream_compute.size();
 
@@ -809,7 +814,6 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
     at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
     CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
 
-    assert(pre_gelu_out.numel() == 0);
     if (_aggregate2) {
       // Catch up the default torch stream
       CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
@@ -848,6 +852,12 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
           torch::from_blob(input_b_ptr + send_offset, {n_chunk * 2, k}, _ubuf.options());
       torch::Tensor output_chunk = torch::from_blob(
           output_ptr + (send_chunk_id * output_chunk_bytes), {n_chunk * 2, m}, D.options());
+      if (do_gelu) {
+        pre_gelu_out = torch::from_blob(
+            pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes),
+            {n_chunk * 2, m},
+            pre_gelu_out.options());
+      }
       torch::Tensor workspace_chunk =
           torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
                            {workspace_size_chunk}, workspace.options());
@@ -901,6 +911,12 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
       // GEMM
       torch::Tensor output_chunk = torch::from_blob(
           output_ptr + (send_chunk_id * output_chunk_bytes), {n_chunk, m}, D.options());
+      if (do_gelu) {
+        pre_gelu_out = torch::from_blob(
+            pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes),
+            {n_chunk, m},
+            pre_gelu_out.options());
+      }
       torch::Tensor workspace_chunk =
           torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
                            {workspace_size_chunk}, workspace.options());
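To restate the header change above: when a non-empty pre_gelu_out buffer is passed in (do_gelu), the split AG overlap path sizes each output chunk by D's element size instead of assuming half precision, and carves the auxiliary pre-GELU buffer into matching (n_chunk, m) views at the same chunk index. The short sketch below mirrors that per-chunk view arithmetic on ordinary torch tensors; the shapes, dtypes, and tensor names are illustrative stand-ins, not the actual userbuffers-backed buffers.

# Hedged illustration of the chunk indexing above: send_chunk_id selects the
# same (n_chunk, m) slice in both the GEMM output and the pre-GELU buffer.
# All sizes/dtypes here are made up; the real code offsets raw byte pointers
# with torch::from_blob instead of slicing tensors.
import torch

n_chunk, m, num_chunks = 128, 1024, 4                                      # illustrative
D = torch.empty(num_chunks * n_chunk, m, dtype=torch.uint8)                # fused FP8 output
pre_gelu_out = torch.empty(num_chunks * n_chunk, m, dtype=torch.bfloat16)  # aux buffer

def chunk_views(send_chunk_id: int):
    # Each chunk is a contiguous (n_chunk, m) block at the same index in both buffers.
    lo = send_chunk_id * n_chunk * m
    hi = lo + n_chunk * m
    out_view = D.view(-1)[lo:hi].view(n_chunk, m)
    aux_view = pre_gelu_out.view(-1)[lo:hi].view(n_chunk, m)
    return out_view, aux_view

out0, aux0 = chunk_views(0)
assert out0.shape == (n_chunk, m) and aux0.shape == (n_chunk, m)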
40 changes: 33 additions & 7 deletions transformer_engine/pytorch/module/layernorm_mlp.py
@@ -118,6 +118,7 @@ def forward(
         ub_atomic_gemm_rs: bool,
         ub_split_ag: bool,
         ub_atomic_gemm_ag: bool,
+        gemm_gelu_fusion: bool,
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
         # Make sure input dimensions are compatible
         in_features = ln_weight.numel()
@@ -261,7 +262,9 @@ def forward(
 
             ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None
             ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo
-            fc1_out, _ = tex.fp8_gemm(
+
+            # Perform FP8 GEMM
+            fp8_gemm_args = [
                 fc1_weight_fp8._data,
                 fp8_meta["scaling_fwd"].scale_inv,
                 tex.FP8FwdTensors.GEMM1_WEIGHT,
@@ -272,22 +275,40 @@ def forward(
                 fp8_dtype_forward,
                 activation_dtype,
                 get_workspace(),
+            ]
+            fp8_gemm_kwargs = dict(
                 bias=fc1_bias,
                 use_bias=use_fc1_bias,
                 use_split_accumulator=_2X_ACC_FPROP,
                 ub_algo=ub_algo,
                 ub=ub_obj_lnout if ub_overlap_ag else None,
                 extra_output_tensor=ln_out if ub_overlap_ag else None,
             )
+            if gemm_gelu_fusion:
+                fp8_gemm_args[8] = torch.uint8  # out_dtype
+                fp8_gemm_kwargs.update(
+                    dict(
+                        gelu=True,
+                        out_index=tex.FP8FwdTensors.GEMM2_INPUT,
+                        fp8_meta_tensor=fp8_meta["scaling_fwd"],
+                        D_dtype=fp8_dtype_forward,
+                    )
+                )
+            fp8_gemm_out = tex.fp8_gemm(*fp8_gemm_args, **fp8_gemm_kwargs)
             if not is_grad_enabled:
                 clear_tensor_data(ln_out_total)
 
-            gelu_out = activation_func(
-                fc1_out,
-                fp8_meta["scaling_fwd"],
-                tex.FP8FwdTensors.GEMM2_INPUT,
-                fp8_dtype_forward,
-            )
+            # Perform activation
+            if gemm_gelu_fusion:
+                gelu_out, fc1_out = fp8_gemm_out
+            else:
+                fc1_out, _ = fp8_gemm_out
+                gelu_out = activation_func(
+                    fc1_out,
+                    fp8_meta["scaling_fwd"],
+                    tex.FP8FwdTensors.GEMM2_INPUT,
+                    fp8_dtype_forward,
+                )
             if not is_grad_enabled:
                 clear_tensor_data(fc1_out)
 
@@ -1033,6 +1054,7 @@ def backward(
             None,
             None,
             None,
+            None,
         )
 
 
@@ -1175,6 +1197,9 @@ def __init__(
         self.ub_split_ag = ub_split_ag
         self.ub_atomic_gemm_rs = ub_atomic_gemm_rs
         self.ub_atomic_gemm_ag = ub_atomic_gemm_ag
+        # GEMM-GELU fusion is currently only supported with split GEMM-AG overlap
+        self.gemm_gelu_fusion = (bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) and
+                                 self.activation == 'gelu' and self.ub_split_ag)
 
         if (ub_bulk_wgrad  # pylint: disable=too-many-boolean-expressions
                 or ub_bulk_dgrad
@@ -1438,6 +1463,7 @@ def forward(
                 self.ub_atomic_gemm_rs,
                 self.ub_split_ag,
                 self.ub_atomic_gemm_ag,
+                self.gemm_gelu_fusion,
             )
             out = fwd_fn(*args)
 
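The Python changes above come down to two decisions: at construction time, fusion is enabled only when the environment knob, the GELU activation, and split AG overlap all hold; at forward time, the fused GEMM hands back both the GELU output and the pre-GELU FC1 output, while the unfused path still applies the activation as a separate step. The sketch below restates that control flow with plain stand-in functions; the helper names are hypothetical and are not Transformer Engine APIs.

# Hedged restatement of the dispatch logic in layernorm_mlp.py; helper names
# are hypothetical and only mirror the structure of the real code.
import os

def gemm_gelu_fusion_enabled(activation: str, ub_split_ag: bool) -> bool:
    # Mirrors __init__: env knob AND gelu activation AND split AG overlap.
    return (
        bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0")))
        and activation == "gelu"
        and ub_split_ag
    )

def fc1_and_activation(gemm_gelu_fusion: bool, fp8_gemm_out, activation_func, *act_args):
    # Mirrors forward: the fused GEMM returns (gelu_out, fc1_out) directly;
    # otherwise the activation runs as a separate step on fc1_out.
    if gemm_gelu_fusion:
        gelu_out, fc1_out = fp8_gemm_out
    else:
        fc1_out, _ = fp8_gemm_out
        gelu_out = activation_func(fc1_out, *act_args)
    return fc1_out, gelu_out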
