From a37ad135158affa4c728d1edfe284f7fb072a539 Mon Sep 17 00:00:00 2001 From: Youngeun Kwon Date: Fri, 1 Nov 2024 14:39:13 -0700 Subject: [PATCH 1/9] draft implementation Signed-off-by: Youngeun Kwon --- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 25 +++- .../userbuffers/userbuffers.cu | 109 +++++++++++++----- .../userbuffers/userbuffers.h | 14 ++- .../transformer_engine/comm_gemm_overlap.h | 2 +- 4 files changed, 114 insertions(+), 36 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 59ec56f161..b7dae04bea 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -90,6 +90,20 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl cudaEventCreateWithFlags(&_stop_compute, 0); cudaEventCreateWithFlags(&_start_comm, 0); cudaEventCreateWithFlags(&_stop_comm, 0); + + //Managing launch ordering to maximize comm-comp overlap for the case of using CUDA_DEVICE_MAX_CONNECTIONS>1 + int max_connection = transformer_engine::getenv("CUDA_DEVICE_MAX_CONNECTIONS", 8); + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, 0); + //Hopper-only feature + if (deviceProp.major == 9 && max_connection > 1){ + cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming); + printf("!!! [UB][FDL] CUDA EVENT CREATION\n"); + } + else{ + _comm_launch_event = 0; + printf("!!! [UB] non-FDL CUDA EVENT CREATION\n"); + } } CommOverlapCore::~CommOverlapCore() { @@ -97,6 +111,7 @@ CommOverlapCore::~CommOverlapCore() { cudaEventDestroy(_start_comm); cudaEventDestroy(_stop_compute); cudaEventDestroy(_start_compute); + if(_comm_launch_event) cudaEventDestroy(_comm_launch_event); if (_atomic_gemm) cudaFree(_counter.dptr()); @@ -168,7 +183,7 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper // Communication: AG and RS int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size if (comm_type == CommOverlapType::AG) { - allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm); + allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm, (cudaEvent_t)_comm_launch_event); } else { if (_ubuf.element_size() == 1) { assert(_ubuf_scale_inv_initialized); @@ -178,13 +193,17 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper assert(rs_output.element_size() == 2); char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf_scale_inv, _ub_reg, 0, - comm_elements, _ub_comm, _stream_comm); + comm_elements, _ub_comm, _stream_comm, + (cudaEvent_t)_comm_launch_event); } else { - reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm); + reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm, + (cudaEvent_t)_comm_launch_event); } } assert(pre_gelu_out.numel() == 0); + // If enforcing the communication-computation launch order for the Hopper GPU, wait for the launch event + if(_comm_launch_event) NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _comm_launch_event, 0)); nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace.data(), accumulate, use_split_accumulator, _math_sms, stream_main); diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 26843d8107..6453077c56 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -1366,6 +1366,19 @@ __global__ void __launch_bounds__(MAX_THREADS) cfg.attrs = attribute_ub; \ cfg.numAttrs = comm->sm_arch >= 9 ? 2 : 1; +#define SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, threads, stream, comm_launch_event) \ + cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ + cudaLaunchAttribute attribute_ub[3] = {}; \ + attribute_ub[2].id = cudaLaunchAttributeLaunchCompletionEvent; \ + attribute_ub[2].val.launchCompletionEvent.event = comm_launch_event; \ + attribute_ub[1].id = cudaLaunchAttributeClusterDimension; \ + attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1; \ + attribute_ub[1].val.clusterDim.y = 1; \ + attribute_ub[1].val.clusterDim.z = 1; \ + attribute_ub[0].id = cudaLaunchAttributeCooperative; \ + cfg.attrs = attribute_ub; \ + cfg.numAttrs = 3; + #define callranks_ag(x) \ if (ar_nvsize == x) { \ int arg1 = op - NVTE_MAX_OPS, \ @@ -1753,7 +1766,7 @@ void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler } void allgather2_userbuff_inplace(const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream) { + communicator *comm, cudaStream_t stream, cudaEvent_t comm_launch_event) { const int op = userbuffers_allreduceop_nonsharp2; const int ar_firstgpu = op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; @@ -1766,11 +1779,21 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int int warps = comm->threads / 32; if (warps < ar_nvsize) warps = ar_nvsize; - SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { - callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) - } else { - callranks_ag(2) callranks_ag(4) callranks_ag(8) + if (comm_launch_event) { + SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) + } else { + callranks_ag(2) callranks_ag(4) callranks_ag(8) + } + } + else { + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) + } else { + callranks_ag(2) callranks_ag(4) callranks_ag(8) + } } } @@ -1790,7 +1813,8 @@ void allgather2_userbuff_inplace_sliced(const int handler, const int offset, con } void reducescatter2_userbuff_inplace(const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream) { + communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event) { const int op = userbuffers_allreduceop_nonsharp2; const int ar_firstgpu = op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; @@ -1803,17 +1827,27 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const int warps = comm->threads / 32; if (warps < ar_nvsize) warps = ar_nvsize; - SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { - callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) - } else { - callranks_rs(2) callranks_rs(4) callranks_rs(8) + if (comm_launch_event) { + SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) + } else { + callranks_rs(2) callranks_rs(4) callranks_rs(8) + } + } + else { + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) + } else { + callranks_rs(2) callranks_rs(4) callranks_rs(8) + } } } void reducescatter2_userbuff_stridedoutput(void *output, const int handler, const int offset, const int rowelements, const int colelements, const int strideelements, communicator *comm, - cudaStream_t stream) { + cudaStream_t stream, cudaEvent_t comm_launch_event) { const int elements = rowelements * colelements; const int op = userbuffers_allreduceop_nonsharp2; const int ar_firstgpu = @@ -1827,23 +1861,35 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons int warps = comm->threads / 32; if (warps < ar_nvsize) warps = ar_nvsize; - SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { - callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) - } else { - callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) + if (comm_launch_event) { + SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) + } else { + callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) + } + } + else { + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) + } else { + callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) + } } } void reducescatter2_userbuff(void *output, const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream) { - reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream); + communicator *comm, cudaStream_t stream, cudaEvent_t comm_launch_event) { + reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream, + comm_launch_event); } template void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler, const int offset, const int rowelements, const int colelements, const int strideelements, - communicator *comm, cudaStream_t stream) { + communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event) { const int elements = rowelements * colelements; const int op = userbuffers_allreduceop_nonsharp2; const int ar_firstgpu = @@ -1857,23 +1903,32 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int warps = comm->threads / 32; if (warps < ar_nvsize) warps = ar_nvsize; - SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) + if (comm_launch_event) { + SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); + callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) + } + else{ + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) + } } template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e5m2>( void *output, float *scale, const int handler, const int offset, const int rowelements, - const int colelements, const int strideelements, communicator *comm, cudaStream_t stream); + const int colelements, const int strideelements, communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event); template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>( void *output, float *scale, const int handler, const int offset, const int rowelements, - const int colelements, const int strideelements, communicator *comm, cudaStream_t stream); + const int colelements, const int strideelements, communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event); template void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset, - const int elements, communicator *comm, cudaStream_t stream) { + const int elements, communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event) { reducescatter2_userbuff_stridedoutput_fp8(output, scale, handler, offset, elements, 1, 0, - comm, stream); + comm, stream, comm_launch_event); } template void reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(void *output, float *scale, diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index 57e68afce0..13fda855e2 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -213,7 +213,8 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * // for TP-parallelism, only single node is implemented void allgather2_userbuff_inplace(const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream = 0); + communicator *comm, cudaStream_t stream = 0, + cudaEvent_t comm_launch_event = 0); /* each Rank input is allgather2_userbuff_inplace: offset+myrank*elements @@ -228,13 +229,15 @@ for(int slice=0;slice void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler, const int offset, const int rowelements, @@ -242,7 +245,8 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const communicator *comm, cudaStream_t stream = 0); template void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset, - const int elements, communicator *comm, cudaStream_t stream = 0); + const int elements, communicator *comm, cudaStream_t stream = 0, + cudaEvent_t comm_launch_event = 0); template void reducescatter2_userbuff_strided_atomic_fp8(void *output, float *scale, const int handler, const int offset, const int rowelements, diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 17ecca5ff0..1d5d192a39 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -62,7 +62,7 @@ class CommOverlapCore { bool _ubuf_scale_inv_initialized{false}; std::vector _stream_compute; - cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm; + cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event; public: CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, From 3627dcc59b22c9a20979ee94c0329d872a20a1e1 Mon Sep 17 00:00:00 2001 From: Youngeun Kwon Date: Fri, 1 Nov 2024 16:30:37 -0700 Subject: [PATCH 2/9] compile error fix Signed-off-by: Youngeun Kwon --- .../common/comm_gemm_overlap/userbuffers/userbuffers.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 6453077c56..a5f96f3252 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -1934,11 +1934,11 @@ void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, template void reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(void *output, float *scale, const int handler, const int offset, const int elements, communicator *comm, - cudaStream_t stream); + cudaStream_t stream, cudaEvent_t comm_launch_event); template void reducescatter2_userbuff_fp8<__nv_fp8_e4m3>(void *output, float *scale, const int handler, const int offset, const int elements, communicator *comm, - cudaStream_t stream); + cudaStream_t stream, cudaEvent_t comm_launch_event); template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>( void *output, float *scale, const int handler, const int offset, const int rowelements, From 0cde98eb907312211004fc4a6215867e68a16b5b Mon Sep 17 00:00:00 2001 From: Youngeun Kwon Date: Fri, 1 Nov 2024 17:31:45 -0700 Subject: [PATCH 3/9] fix compile error Signed-off-by: Youngeun Kwon --- .../common/comm_gemm_overlap/userbuffers/userbuffers.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index 13fda855e2..571b8a3015 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -242,7 +242,8 @@ template void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler, const int offset, const int rowelements, const int colelements, const int strideelements, - communicator *comm, cudaStream_t stream = 0); + communicator *comm, cudaStream_t stream = 0, + cudaEvent_t comm_launch_event = 0); template void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset, const int elements, communicator *comm, cudaStream_t stream = 0, From 8344e493e7198af5859e256961155bf1fe1489a1 Mon Sep 17 00:00:00 2001 From: Youngeun Kwon Date: Sat, 2 Nov 2024 08:39:23 -0700 Subject: [PATCH 4/9] remove print Signed-off-by: Youngeun Kwon --- .../common/comm_gemm_overlap/comm_gemm_overlap.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index b7dae04bea..e5ac2dc2f3 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -98,11 +98,9 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl //Hopper-only feature if (deviceProp.major == 9 && max_connection > 1){ cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming); - printf("!!! [UB][FDL] CUDA EVENT CREATION\n"); } else{ _comm_launch_event = 0; - printf("!!! [UB] non-FDL CUDA EVENT CREATION\n"); } } From caadcacbad04cdb0cf1a0dc8cc59c57bead0e7b1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 2 Nov 2024 15:55:51 +0000 Subject: [PATCH 5/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 17 +++--- .../userbuffers/userbuffers.cu | 58 +++++++++---------- .../userbuffers/userbuffers.h | 13 +++-- 3 files changed, 45 insertions(+), 43 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index e5ac2dc2f3..b62ca4708a 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -96,10 +96,9 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl cudaDeviceProp deviceProp; cudaGetDeviceProperties(&deviceProp, 0); //Hopper-only feature - if (deviceProp.major == 9 && max_connection > 1){ + if (deviceProp.major == 9 && max_connection > 1) { cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming); - } - else{ + } else { _comm_launch_event = 0; } } @@ -109,7 +108,7 @@ CommOverlapCore::~CommOverlapCore() { cudaEventDestroy(_start_comm); cudaEventDestroy(_stop_compute); cudaEventDestroy(_start_compute); - if(_comm_launch_event) cudaEventDestroy(_comm_launch_event); + if (_comm_launch_event) cudaEventDestroy(_comm_launch_event); if (_atomic_gemm) cudaFree(_counter.dptr()); @@ -181,7 +180,8 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper // Communication: AG and RS int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size if (comm_type == CommOverlapType::AG) { - allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm, (cudaEvent_t)_comm_launch_event); + allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm, + (cudaEvent_t)_comm_launch_event); } else { if (_ubuf.element_size() == 1) { assert(_ubuf_scale_inv_initialized); @@ -191,17 +191,18 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper assert(rs_output.element_size() == 2); char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf_scale_inv, _ub_reg, 0, - comm_elements, _ub_comm, _stream_comm, + comm_elements, _ub_comm, _stream_comm, (cudaEvent_t)_comm_launch_event); } else { - reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm, + reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm, (cudaEvent_t)_comm_launch_event); } } assert(pre_gelu_out.numel() == 0); // If enforcing the communication-computation launch order for the Hopper GPU, wait for the launch event - if(_comm_launch_event) NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _comm_launch_event, 0)); + if (_comm_launch_event) + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _comm_launch_event, 0)); nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace.data(), accumulate, use_split_accumulator, _math_sms, stream_main); diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index a5f96f3252..3e106f0155 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -1366,17 +1366,17 @@ __global__ void __launch_bounds__(MAX_THREADS) cfg.attrs = attribute_ub; \ cfg.numAttrs = comm->sm_arch >= 9 ? 2 : 1; -#define SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, threads, stream, comm_launch_event) \ - cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ - cudaLaunchAttribute attribute_ub[3] = {}; \ - attribute_ub[2].id = cudaLaunchAttributeLaunchCompletionEvent; \ - attribute_ub[2].val.launchCompletionEvent.event = comm_launch_event; \ - attribute_ub[1].id = cudaLaunchAttributeClusterDimension; \ - attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1; \ - attribute_ub[1].val.clusterDim.y = 1; \ - attribute_ub[1].val.clusterDim.z = 1; \ - attribute_ub[0].id = cudaLaunchAttributeCooperative; \ - cfg.attrs = attribute_ub; \ +#define SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, threads, stream, comm_launch_event) \ + cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ + cudaLaunchAttribute attribute_ub[3] = {}; \ + attribute_ub[2].id = cudaLaunchAttributeLaunchCompletionEvent; \ + attribute_ub[2].val.launchCompletionEvent.event = comm_launch_event; \ + attribute_ub[1].id = cudaLaunchAttributeClusterDimension; \ + attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1; \ + attribute_ub[1].val.clusterDim.y = 1; \ + attribute_ub[1].val.clusterDim.z = 1; \ + attribute_ub[0].id = cudaLaunchAttributeCooperative; \ + cfg.attrs = attribute_ub; \ cfg.numAttrs = 3; #define callranks_ag(x) \ @@ -1766,7 +1766,8 @@ void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler } void allgather2_userbuff_inplace(const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream, cudaEvent_t comm_launch_event) { + communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event) { const int op = userbuffers_allreduceop_nonsharp2; const int ar_firstgpu = op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; @@ -1786,8 +1787,7 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int } else { callranks_ag(2) callranks_ag(4) callranks_ag(8) } - } - else { + } else { SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) @@ -1813,7 +1813,7 @@ void allgather2_userbuff_inplace_sliced(const int handler, const int offset, con } void reducescatter2_userbuff_inplace(const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream, + communicator *comm, cudaStream_t stream, cudaEvent_t comm_launch_event) { const int op = userbuffers_allreduceop_nonsharp2; const int ar_firstgpu = @@ -1834,8 +1834,7 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const } else { callranks_rs(2) callranks_rs(4) callranks_rs(8) } - } - else { + } else { SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) @@ -1867,20 +1866,20 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) } else { callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) - } - } - else { + } + } else { SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) } else { callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) - } + } } } void reducescatter2_userbuff(void *output, const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream, cudaEvent_t comm_launch_event) { - reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream, + communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event) { + reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream, comm_launch_event); } @@ -1888,7 +1887,7 @@ template void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler, const int offset, const int rowelements, const int colelements, const int strideelements, - communicator *comm, cudaStream_t stream, + communicator *comm, cudaStream_t stream, cudaEvent_t comm_launch_event) { const int elements = rowelements * colelements; const int op = userbuffers_allreduceop_nonsharp2; @@ -1906,8 +1905,7 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const if (comm_launch_event) { SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) - } - else{ + } else { SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) } @@ -1915,7 +1913,7 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e5m2>( void *output, float *scale, const int handler, const int offset, const int rowelements, - const int colelements, const int strideelements, communicator *comm, cudaStream_t stream, + const int colelements, const int strideelements, communicator *comm, cudaStream_t stream, cudaEvent_t comm_launch_event); template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>( @@ -1934,11 +1932,13 @@ void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, template void reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(void *output, float *scale, const int handler, const int offset, const int elements, communicator *comm, - cudaStream_t stream, cudaEvent_t comm_launch_event); + cudaStream_t stream, + cudaEvent_t comm_launch_event); template void reducescatter2_userbuff_fp8<__nv_fp8_e4m3>(void *output, float *scale, const int handler, const int offset, const int elements, communicator *comm, - cudaStream_t stream, cudaEvent_t comm_launch_event); + cudaStream_t stream, + cudaEvent_t comm_launch_event); template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>( void *output, float *scale, const int handler, const int offset, const int rowelements, diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index 571b8a3015..75655ef691 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -213,7 +213,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * // for TP-parallelism, only single node is implemented void allgather2_userbuff_inplace(const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream = 0, + communicator *comm, cudaStream_t stream = 0, cudaEvent_t comm_launch_event = 0); /* each Rank input is @@ -229,24 +229,25 @@ for(int slice=0;slice void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler, const int offset, const int rowelements, const int colelements, const int strideelements, - communicator *comm, cudaStream_t stream = 0, + communicator *comm, cudaStream_t stream = 0, cudaEvent_t comm_launch_event = 0); template void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset, - const int elements, communicator *comm, cudaStream_t stream = 0, + const int elements, communicator *comm, cudaStream_t stream = 0, cudaEvent_t comm_launch_event = 0); template void reducescatter2_userbuff_strided_atomic_fp8(void *output, float *scale, const int handler, From a264df849aeee5c476a6f747b24d415eb45d723e Mon Sep 17 00:00:00 2001 From: Youngeun Kwon Date: Sat, 2 Nov 2024 18:44:50 -0700 Subject: [PATCH 6/9] Edit comments Signed-off-by: Youngeun Kwon --- .../common/comm_gemm_overlap/comm_gemm_overlap.cpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index b62ca4708a..2abd063f4c 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -91,11 +91,15 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl cudaEventCreateWithFlags(&_start_comm, 0); cudaEventCreateWithFlags(&_stop_comm, 0); - //Managing launch ordering to maximize comm-comp overlap for the case of using CUDA_DEVICE_MAX_CONNECTIONS>1 + /* + Defining the launcher order between the communication and GEMM kernels + using Fast Dependent Launch when CUDA_DEVICE_MAX_CONNECTIONS>1. + The event is used to schedule the communication kernel before the GEMM. + This is needed only for Hopper, which uses persistent CTA execution. + */ int max_connection = transformer_engine::getenv("CUDA_DEVICE_MAX_CONNECTIONS", 8); cudaDeviceProp deviceProp; cudaGetDeviceProperties(&deviceProp, 0); - //Hopper-only feature if (deviceProp.major == 9 && max_connection > 1) { cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming); } else { @@ -200,7 +204,7 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper } assert(pre_gelu_out.numel() == 0); - // If enforcing the communication-computation launch order for the Hopper GPU, wait for the launch event + // When the kernel launch order is defined, enforce the GEMM kernel launch to wait for the communication kernel launch if (_comm_launch_event) NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _comm_launch_event, 0)); nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb, From 955354780347d76666540569018e84c44c5cbe40 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 3 Nov 2024 01:45:32 +0000 Subject: [PATCH 7/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/comm_gemm_overlap/comm_gemm_overlap.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 2abd063f4c..37b1a57ee1 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -92,9 +92,9 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl cudaEventCreateWithFlags(&_stop_comm, 0); /* - Defining the launcher order between the communication and GEMM kernels - using Fast Dependent Launch when CUDA_DEVICE_MAX_CONNECTIONS>1. - The event is used to schedule the communication kernel before the GEMM. + Defining the launcher order between the communication and GEMM kernels + using Fast Dependent Launch when CUDA_DEVICE_MAX_CONNECTIONS>1. + The event is used to schedule the communication kernel before the GEMM. This is needed only for Hopper, which uses persistent CTA execution. */ int max_connection = transformer_engine::getenv("CUDA_DEVICE_MAX_CONNECTIONS", 8); From 8c305725ddc7586dd891be189ecc11fd273ceebc Mon Sep 17 00:00:00 2001 From: Youngeun Kwon Date: Thu, 7 Nov 2024 13:07:25 -0800 Subject: [PATCH 8/9] edit the bulk-overlap test case Signed-off-by: Youngeun Kwon --- .../distributed/test_comm_gemm_overlap.py | 34 +++++++++++++++---- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index ce46a72189..08fd5847e0 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -209,19 +209,39 @@ def test_atomic_gemm_overlaps(ag_type, rs_type, p2p, fp8_out): @pytest.mark.parametrize( - "comm_type,fp8", + "comm_type, fp8, connections", [ - ("AG", False), - ("RS", False), - ("RS", True), + ("AG", False, 1), + ("RS", False, 1), + ("RS", True, 1), + ("AG", False, 8), + ("RS", False, 8), + ("RS", True, 8), + ], + ids=[ + "ALL-GATHER - BF16 - 1 connections", + "REDUCE-SCATTER - BF16 - 1 connections", + "REDUCE-SCATTER - FP8 - 1 connections", + "ALL-GATHER - BF16 - 8 connections", + "REDUCE-SCATTER - BF16 - 8 connections", + "REDUCE-SCATTER - FP8 - 8 connections", ], - ids=[" ALL-GATHER - BF16 ", " REDUCE-SCATTER - BF16 ", " REDUCE-SCATTER - FP8 "], ) -def test_bulk_overlaps(comm_type, fp8): +def test_bulk_overlaps(comm_type, fp8, connections): """ Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. """ - _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False) + if connections == 8: + if torch.cuda.get_device_properties(0).major != 9: + pytest.skip( + "CUDA_DEVICE_MAX_CONNECTIONS=8 test only applies to devices with compute capability 9.0 (HOPPER ARCH)." + ) + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8" + _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False) + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + else: + _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False) + @pytest.mark.parametrize( From 71e8cb3225e819c816b78e3ff420b9525c6d6b0f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 Nov 2024 21:08:30 +0000 Subject: [PATCH 9/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/distributed/test_comm_gemm_overlap.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index 08fd5847e0..f81fbae1fe 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -234,7 +234,8 @@ def test_bulk_overlaps(comm_type, fp8, connections): if connections == 8: if torch.cuda.get_device_properties(0).major != 9: pytest.skip( - "CUDA_DEVICE_MAX_CONNECTIONS=8 test only applies to devices with compute capability 9.0 (HOPPER ARCH)." + "CUDA_DEVICE_MAX_CONNECTIONS=8 test only applies to devices with compute capability" + " 9.0 (HOPPER ARCH)." ) os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8" _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False) @@ -243,7 +244,6 @@ def test_bulk_overlaps(comm_type, fp8, connections): _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False) - @pytest.mark.parametrize( "layer_type", [layer.__name__ for layer in TE_LAYERS],