Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improving communication overlap when multiple kernel queues are in use #1308

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,25 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
cudaEventCreateWithFlags(&_stop_compute, 0);
cudaEventCreateWithFlags(&_start_comm, 0);
cudaEventCreateWithFlags(&_stop_comm, 0);

// Manage kernel launch ordering to maximize communication-computation overlap when CUDA_DEVICE_MAX_CONNECTIONS > 1
youngeunkwon0405 marked this conversation as resolved.
Show resolved Hide resolved
int max_connection = transformer_engine::getenv<int>("CUDA_DEVICE_MAX_CONNECTIONS", 8);
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, 0);
//Hopper-only feature
youngeunkwon0405 marked this conversation as resolved.
Show resolved Hide resolved
if (deviceProp.major == 9 && max_connection > 1) {
cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming);
} else {
_comm_launch_event = 0;
}
}

CommOverlapCore::~CommOverlapCore() {
cudaEventDestroy(_stop_comm);
cudaEventDestroy(_start_comm);
cudaEventDestroy(_stop_compute);
cudaEventDestroy(_start_compute);
if (_comm_launch_event) cudaEventDestroy(_comm_launch_event);

if (_atomic_gemm) cudaFree(_counter.dptr());

Expand Down Expand Up @@ -168,7 +180,8 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper
// Communication: AG and RS
int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size
if (comm_type == CommOverlapType::AG) {
allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm);
allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm,
(cudaEvent_t)_comm_launch_event);
} else {
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
Expand All @@ -178,13 +191,18 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper
assert(rs_output.element_size() == 2);
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf_scale_inv, _ub_reg, 0,
comm_elements, _ub_comm, _stream_comm);
comm_elements, _ub_comm, _stream_comm,
(cudaEvent_t)_comm_launch_event);
} else {
reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm);
reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm,
(cudaEvent_t)_comm_launch_event);
}
}

assert(pre_gelu_out.numel() == 0);
// If enforcing the communication-computation launch order for the Hopper GPU, wait for the launch event
youngeunkwon0405 marked this conversation as resolved.
Show resolved Hide resolved
if (_comm_launch_event)
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _comm_launch_event, 0));
nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb,
grad, workspace.data(), accumulate, use_split_accumulator, _math_sms,
stream_main);
Expand Down
107 changes: 81 additions & 26 deletions transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1366,6 +1366,19 @@ __global__ void __launch_bounds__(MAX_THREADS)
cfg.attrs = attribute_ub; \
cfg.numAttrs = comm->sm_arch >= 9 ? 2 : 1;

// Variant of SETUP_LAUNCH_CONFIG that additionally attaches a launch-completion
// event (cudaLaunchAttributeLaunchCompletionEvent) to the cooperative launch.
// The event is recorded once the kernel's *launch* has completed (not its
// execution), so another stream can cudaStreamWaitEvent on it to enforce the
// comm-kernel-first launch ordering needed for overlap when
// CUDA_DEVICE_MAX_CONNECTIONS > 1.
// NOTE(review): unlike SETUP_LAUNCH_CONFIG (which sets numAttrs to
// `comm->sm_arch >= 9 ? 2 : 1`), this macro unconditionally passes all 3
// attributes, including the cluster dimension; callers only reach this macro
// on sm_90 (Hopper), where that is valid -- confirm before reusing elsewhere.
// The `cfg` and `attribute_ub` locals declared here are referenced by the
// callranks_* launch macros, so this macro must be expanded in the same scope
// as (or an enclosing scope of) those launches.
#define SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, threads, stream, comm_launch_event) \
  cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \
  cudaLaunchAttribute attribute_ub[3] = {}; \
  attribute_ub[2].id = cudaLaunchAttributeLaunchCompletionEvent; \
  attribute_ub[2].val.launchCompletionEvent.event = comm_launch_event; \
  attribute_ub[1].id = cudaLaunchAttributeClusterDimension; \
  attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1; \
  attribute_ub[1].val.clusterDim.y = 1; \
  attribute_ub[1].val.clusterDim.z = 1; \
  attribute_ub[0].id = cudaLaunchAttributeCooperative; \
  cfg.attrs = attribute_ub; \
  cfg.numAttrs = 3;

#define callranks_ag(x) \
if (ar_nvsize == x) { \
int arg1 = op - NVTE_MAX_OPS, \
Expand Down Expand Up @@ -1753,7 +1766,8 @@ void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler
}

void allgather2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event) {
const int op = userbuffers_allreduceop_nonsharp2;
const int ar_firstgpu =
op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
Expand All @@ -1766,11 +1780,20 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int
int warps = comm->threads / 32;
if (warps < ar_nvsize) warps = ar_nvsize;

SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_agMC(2) callranks_agMC(4) callranks_agMC(8)
if (comm_launch_event) {
SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_agMC(2) callranks_agMC(4) callranks_agMC(8)
} else {
callranks_ag(2) callranks_ag(4) callranks_ag(8)
}
} else {
callranks_ag(2) callranks_ag(4) callranks_ag(8)
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_agMC(2) callranks_agMC(4) callranks_agMC(8)
} else {
callranks_ag(2) callranks_ag(4) callranks_ag(8)
}
}
denera marked this conversation as resolved.
Show resolved Hide resolved
}

Expand All @@ -1790,7 +1813,8 @@ void allgather2_userbuff_inplace_sliced(const int handler, const int offset, con
}

void reducescatter2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event) {
const int op = userbuffers_allreduceop_nonsharp2;
const int ar_firstgpu =
op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
Expand All @@ -1803,17 +1827,26 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const
int warps = comm->threads / 32;
if (warps < ar_nvsize) warps = ar_nvsize;

SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8)
if (comm_launch_event) {
SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8)
} else {
callranks_rs(2) callranks_rs(4) callranks_rs(8)
}
} else {
callranks_rs(2) callranks_rs(4) callranks_rs(8)
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8)
} else {
callranks_rs(2) callranks_rs(4) callranks_rs(8)
}
}
denera marked this conversation as resolved.
Show resolved Hide resolved
}
void reducescatter2_userbuff_stridedoutput(void *output, const int handler, const int offset,
const int rowelements, const int colelements,
const int strideelements, communicator *comm,
cudaStream_t stream) {
cudaStream_t stream, cudaEvent_t comm_launch_event) {
const int elements = rowelements * colelements;
const int op = userbuffers_allreduceop_nonsharp2;
const int ar_firstgpu =
Expand All @@ -1827,23 +1860,35 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons
int warps = comm->threads / 32;
if (warps < ar_nvsize) warps = ar_nvsize;

SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8)
if (comm_launch_event) {
SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8)
} else {
callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8)
}
} else {
callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8)
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8)
} else {
callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8)
}
}
denera marked this conversation as resolved.
Show resolved Hide resolved
}
void reducescatter2_userbuff(void *output, const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream);
communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event) {
reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream,
comm_launch_event);
}

template <typename fp8type>
void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler,
const int offset, const int rowelements,
const int colelements, const int strideelements,
communicator *comm, cudaStream_t stream) {
communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event) {
const int elements = rowelements * colelements;
const int op = userbuffers_allreduceop_nonsharp2;
const int ar_firstgpu =
Expand All @@ -1857,33 +1902,43 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const
int warps = comm->threads / 32;
if (warps < ar_nvsize) warps = ar_nvsize;

SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
if (comm_launch_event) {
SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
} else {
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
}
Comment on lines +1905 to +1911
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here for duplicated kernel launch code.

Suggested change
if (comm_launch_event) {
SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
} else {
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
}
if (comm_launch_event) {
SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
} else {
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
}
callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @denera, the suggested coding style causes a compile error, which is why I had to duplicate the kernel-launch code...
Since both SETUP_LAUNCH_CONFIG and callranks_** are preprocessor macros, there is a variable-scope issue: the variables declared by the SETUP macro must still be in scope where the compute-kernel launch macro is expanded, so the kernel call has to appear in the same scope as (or a scope nested inside) the SETUP macro's expansion. The same issue applies to the other comments. If you have a better solution for this, please let me know.

}

template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e5m2>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream);
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event);

template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream);
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event);

template <typename fp8type>
void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset,
const int elements, communicator *comm, cudaStream_t stream) {
const int elements, communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event) {
reducescatter2_userbuff_stridedoutput_fp8<fp8type>(output, scale, handler, offset, elements, 1, 0,
comm, stream);
comm, stream, comm_launch_event);
}

template void reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(void *output, float *scale,
const int handler, const int offset,
const int elements, communicator *comm,
cudaStream_t stream);
cudaStream_t stream,
cudaEvent_t comm_launch_event);
template void reducescatter2_userbuff_fp8<__nv_fp8_e4m3>(void *output, float *scale,
const int handler, const int offset,
const int elements, communicator *comm,
cudaStream_t stream);
cudaStream_t stream,
cudaEvent_t comm_launch_event);

template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,8 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *

// for TP-parallelism, only single node is implemented
void allgather2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0);
communicator *comm, cudaStream_t stream = 0,
cudaEvent_t comm_launch_event = 0);
/*
each Rank input is
allgather2_userbuff_inplace: offset+myrank*elements
Expand All @@ -228,21 +229,26 @@ for(int slice=0;slice<ncslices;slice++)
allgather2_userbuff_inplace(hndl,offset, elements*nslices,comm,stream);
*/
void reducescatter2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0);
communicator *comm, cudaStream_t stream = 0,
cudaEvent_t comm_launch_event = 0);
void reducescatter2_userbuff(void *output, const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0);
communicator *comm, cudaStream_t stream = 0,
cudaEvent_t comm_launch_event = 0);
void reducescatter2_userbuff_stridedoutput(void *output, const int handler, const int offset,
const int rowelements, const int colelements,
const int strideelements, communicator *comm,
cudaStream_t stream = 0);
cudaStream_t stream = 0,
cudaEvent_t comm_launch_event = 0);
template <typename fp8type>
void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler,
const int offset, const int rowelements,
const int colelements, const int strideelements,
communicator *comm, cudaStream_t stream = 0);
communicator *comm, cudaStream_t stream = 0,
cudaEvent_t comm_launch_event = 0);
template <typename fp8type>
void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset,
const int elements, communicator *comm, cudaStream_t stream = 0);
const int elements, communicator *comm, cudaStream_t stream = 0,
cudaEvent_t comm_launch_event = 0);
template <typename fp8type>
void reducescatter2_userbuff_strided_atomic_fp8(void *output, float *scale, const int handler,
const int offset, const int rowelements,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class CommOverlapCore {
bool _ubuf_scale_inv_initialized{false};

std::vector<cudaStream_t> _stream_compute;
cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm;
cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event;

public:
CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes,
Expand Down