Improving communication overlap for the case of multi kernel queue usage #1308

Open · wants to merge 14 commits into base: main
34 changes: 27 additions & 7 deletions tests/pytorch/distributed/test_comm_gemm_overlap.py
@@ -209,19 +209,39 @@ def test_atomic_gemm_overlaps(ag_type, rs_type, p2p, fp8_out):


@pytest.mark.parametrize(
"comm_type,fp8",
"comm_type, fp8, connections",
[
("AG", False),
("RS", False),
("RS", True),
("AG", False, 1),
("RS", False, 1),
("RS", True, 1),
("AG", False, 8),
("RS", False, 8),
("RS", True, 8),
],
ids=[
"ALL-GATHER - BF16 - 1 connections",
"REDUCE-SCATTER - BF16 - 1 connections",
"REDUCE-SCATTER - FP8 - 1 connections",
"ALL-GATHER - BF16 - 8 connections",
"REDUCE-SCATTER - BF16 - 8 connections",
"REDUCE-SCATTER - FP8 - 8 connections",
],
ids=[" ALL-GATHER - BF16 ", " REDUCE-SCATTER - BF16 ", " REDUCE-SCATTER - FP8 "],
)
def test_bulk_overlaps(comm_type, fp8):
def test_bulk_overlaps(comm_type, fp8, connections):
"""
Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm.
"""
_run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False)
if connections == 8:
if torch.cuda.get_device_properties(0).major != 9:
pytest.skip(
"CUDA_DEVICE_MAX_CONNECTIONS=8 test only applies to devices with compute capability"
" 9.0 (HOPPER ARCH)."
)
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
_run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False)
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
else:
_run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False)


@pytest.mark.parametrize(
28 changes: 25 additions & 3 deletions transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
@@ -90,13 +90,29 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
cudaEventCreateWithFlags(&_stop_compute, 0);
cudaEventCreateWithFlags(&_start_comm, 0);
cudaEventCreateWithFlags(&_stop_comm, 0);

/*
Define the launch order between the communication and GEMM kernels
using Fast Dependent Launch when CUDA_DEVICE_MAX_CONNECTIONS > 1.
The event is used to schedule the communication kernel launch before the GEMM launch.
This is needed only on Hopper, which uses persistent CTA execution.
*/
int max_connection = transformer_engine::getenv<int>("CUDA_DEVICE_MAX_CONNECTIONS", 8);
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, 0);
if (deviceProp.major == 9 && max_connection > 1) {
cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming);
} else {
_comm_launch_event = 0;
}
}

CommOverlapCore::~CommOverlapCore() {
cudaEventDestroy(_stop_comm);
cudaEventDestroy(_start_comm);
cudaEventDestroy(_stop_compute);
cudaEventDestroy(_start_compute);
if (_comm_launch_event) cudaEventDestroy(_comm_launch_event);

if (_atomic_gemm) cudaFree(_counter.dptr());

@@ -168,7 +184,8 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper
// Communication: AG and RS
int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size
if (comm_type == CommOverlapType::AG) {
allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm);
allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm,
(cudaEvent_t)_comm_launch_event);
} else {
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
@@ -178,13 +195,18 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper
assert(rs_output.element_size() == 2);
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf_scale_inv, _ub_reg, 0,
comm_elements, _ub_comm, _stream_comm);
comm_elements, _ub_comm, _stream_comm,
(cudaEvent_t)_comm_launch_event);
} else {
reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm);
reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm,
(cudaEvent_t)_comm_launch_event);
}
}

assert(pre_gelu_out.numel() == 0);
// When the kernel launch order is defined, enforce the GEMM kernel launch to wait for the communication kernel launch
if (_comm_launch_event)
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _comm_launch_event, 0));
nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb,
grad, workspace.data(), accumulate, use_split_accumulator, _math_sms,
stream_main);
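
For reference, the launch-ordering mechanism used above can be sketched in isolation as follows. This is a minimal, self-contained illustration rather than the PR's code: comm_kernel and gemm_kernel are placeholder kernels, the sizes are arbitrary, and it assumes a CUDA 12.x toolkit where cudaLaunchKernelEx and cudaLaunchAttributeLaunchCompletionEvent are available.

#include <cuda_runtime.h>

__global__ void comm_kernel(float *buf) { buf[threadIdx.x] += 1.0f; }  // stand-in for the userbuffers comm kernel
__global__ void gemm_kernel(float *buf) { buf[threadIdx.x] *= 2.0f; }  // stand-in for the GEMM

int main() {
  float *buf;
  cudaMalloc(&buf, 128 * sizeof(float));
  cudaStream_t comm_stream, main_stream;
  cudaStreamCreate(&comm_stream);
  cudaStreamCreate(&main_stream);

  // Event that is signalled once the comm kernel's launch has completed.
  cudaEvent_t comm_launch_event;
  cudaEventCreateWithFlags(&comm_launch_event, cudaEventDisableTiming);

  // Launch the comm kernel with a launch-completion-event attribute.
  cudaLaunchConfig_t cfg = {1, 128, 0, comm_stream, NULL, 0};
  cudaLaunchAttribute attrs[1] = {};
  attrs[0].id = cudaLaunchAttributeLaunchCompletionEvent;
  attrs[0].val.launchCompletionEvent.event = comm_launch_event;
  cfg.attrs = attrs;
  cfg.numAttrs = 1;
  cudaLaunchKernelEx(&cfg, comm_kernel, buf);

  // The GEMM stream waits on the launch event, so its kernel cannot be
  // scheduled ahead of the comm kernel even when CUDA_DEVICE_MAX_CONNECTIONS > 1
  // exposes multiple hardware queues.
  cudaStreamWaitEvent(main_stream, comm_launch_event, 0);
  gemm_kernel<<<1, 128, 0, main_stream>>>(buf);

  cudaDeviceSynchronize();
  cudaEventDestroy(comm_launch_event);
  cudaStreamDestroy(comm_stream);
  cudaStreamDestroy(main_stream);
  cudaFree(buf);
  return 0;
}
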
107 changes: 81 additions & 26 deletions transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu
@@ -1366,6 +1366,19 @@ __global__ void __launch_bounds__(MAX_THREADS)
cfg.attrs = attribute_ub; \
cfg.numAttrs = comm->sm_arch >= 9 ? 2 : 1;

#define SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, threads, stream, comm_launch_event) \
cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \
cudaLaunchAttribute attribute_ub[3] = {}; \
attribute_ub[2].id = cudaLaunchAttributeLaunchCompletionEvent; \
attribute_ub[2].val.launchCompletionEvent.event = comm_launch_event; \
attribute_ub[1].id = cudaLaunchAttributeClusterDimension; \
attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1; \
attribute_ub[1].val.clusterDim.y = 1; \
attribute_ub[1].val.clusterDim.z = 1; \
attribute_ub[0].id = cudaLaunchAttributeCooperative; \
cfg.attrs = attribute_ub; \
cfg.numAttrs = 3;

#define callranks_ag(x) \
if (ar_nvsize == x) { \
int arg1 = op - NVTE_MAX_OPS, \
@@ -1753,7 +1766,8 @@ void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler
}

void allgather2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event) {
const int op = userbuffers_allreduceop_nonsharp2;
const int ar_firstgpu =
op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
@@ -1766,11 +1780,20 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int
int warps = comm->threads / 32;
if (warps < ar_nvsize) warps = ar_nvsize;

SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_agMC(2) callranks_agMC(4) callranks_agMC(8)
if (comm_launch_event) {
SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_agMC(2) callranks_agMC(4) callranks_agMC(8)
} else {
callranks_ag(2) callranks_ag(4) callranks_ag(8)
}
} else {
callranks_ag(2) callranks_ag(4) callranks_ag(8)
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_agMC(2) callranks_agMC(4) callranks_agMC(8)
} else {
callranks_ag(2) callranks_ag(4) callranks_ag(8)
}
}
}

@@ -1790,7 +1813,8 @@ void allgather2_userbuff_inplace_sliced(const int handler, const int offset, con
}

void reducescatter2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event) {
const int op = userbuffers_allreduceop_nonsharp2;
const int ar_firstgpu =
op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
@@ -1803,17 +1827,26 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const
int warps = comm->threads / 32;
if (warps < ar_nvsize) warps = ar_nvsize;

SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8)
if (comm_launch_event) {
SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8)
} else {
callranks_rs(2) callranks_rs(4) callranks_rs(8)
}
} else {
callranks_rs(2) callranks_rs(4) callranks_rs(8)
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8)
} else {
callranks_rs(2) callranks_rs(4) callranks_rs(8)
}
}
}
void reducescatter2_userbuff_stridedoutput(void *output, const int handler, const int offset,
const int rowelements, const int colelements,
const int strideelements, communicator *comm,
cudaStream_t stream) {
cudaStream_t stream, cudaEvent_t comm_launch_event) {
const int elements = rowelements * colelements;
const int op = userbuffers_allreduceop_nonsharp2;
const int ar_firstgpu =
@@ -1827,23 +1860,35 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons
int warps = comm->threads / 32;
if (warps < ar_nvsize) warps = ar_nvsize;

SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8)
if (comm_launch_event) {
SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8)
} else {
callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8)
}
} else {
callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8)
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8)
} else {
callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8)
}
}
}
void reducescatter2_userbuff(void *output, const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream);
communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event) {
reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream,
comm_launch_event);
}

template <typename fp8type>
void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler,
const int offset, const int rowelements,
const int colelements, const int strideelements,
communicator *comm, cudaStream_t stream) {
communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event) {
const int elements = rowelements * colelements;
const int op = userbuffers_allreduceop_nonsharp2;
const int ar_firstgpu =
@@ -1857,33 +1902,43 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const
int warps = comm->threads / 32;
if (warps < ar_nvsize) warps = ar_nvsize;

SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
if (comm_launch_event) {
SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
} else {
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
}
Comment on lines +1905 to +1911

denera (Collaborator): Same here for duplicated kernel launch code.
Suggested change:

if (comm_launch_event) {
  SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
  callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
} else {
  SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
  callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
}

would become:

if (comm_launch_event) {
  SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
} else {
  SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
}
callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)

PR author's reply:
Hi @denera, the suggested coding style causes a compile error, which is why I had to duplicate the kernel launch.
Because both SETUP_LAUNCH_CONFIG* and callranks_* are macros, there is a variable-scope issue: the macro that launches the compute kernel must expand in the same scope as (or a scope nested inside) the SETUP macro that declares the launch config it uses. The same applies to the other comments. If you have a better solution, please let me know.
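
To make the scope issue concrete, here is a minimal sketch with hypothetical stand-in macros (SETUP_CFG and LAUNCH_RANKS are placeholders, not the real SETUP_LAUNCH_CONFIG / callranks_* definitions); it assumes a no-argument test kernel and a CUDA 12.x toolkit with cudaLaunchKernelEx:

#include <cuda_runtime.h>

__global__ void my_kernel() {}

// The first macro declares the launch config `cfg`; the second expands to code that reads it.
#define SETUP_CFG(stream) cudaLaunchConfig_t cfg = {1, 1, 0, (stream), NULL, 0};
#define LAUNCH_RANKS() cudaLaunchKernelEx(&cfg, my_kernel);

void as_suggested(bool use_event, cudaStream_t stream) {
  if (use_event) {
    SETUP_CFG(stream)  // `cfg` lives only inside this block...
  } else {
    SETUP_CFG(stream)
  }
  // LAUNCH_RANKS()    // ...so hoisting the launch out of the if/else does not
                       // compile: `cfg` was not declared in this scope.
}

void as_merged(bool use_event, cudaStream_t stream) {
  if (use_event) {
    SETUP_CFG(stream)
    LAUNCH_RANKS()     // duplicated launch, but `cfg` is visible here
  } else {
    SETUP_CFG(stream)
    LAUNCH_RANKS()
  }
}
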

}

template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e5m2>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream);
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event);

template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream);
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event);

template <typename fp8type>
void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset,
const int elements, communicator *comm, cudaStream_t stream) {
const int elements, communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event) {
reducescatter2_userbuff_stridedoutput_fp8<fp8type>(output, scale, handler, offset, elements, 1, 0,
comm, stream);
comm, stream, comm_launch_event);
}

template void reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(void *output, float *scale,
const int handler, const int offset,
const int elements, communicator *comm,
cudaStream_t stream);
cudaStream_t stream,
cudaEvent_t comm_launch_event);
template void reducescatter2_userbuff_fp8<__nv_fp8_e4m3>(void *output, float *scale,
const int handler, const int offset,
const int elements, communicator *comm,
cudaStream_t stream);
cudaStream_t stream,
cudaEvent_t comm_launch_event);

template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
@@ -213,7 +213,8 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *

// for TP-parallelism, only single node is implemented
void allgather2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0);
communicator *comm, cudaStream_t stream = 0,
cudaEvent_t comm_launch_event = 0);
/*
each Rank input is
allgather2_userbuff_inplace: offset+myrank*elements
@@ -228,21 +229,26 @@ for(int slice=0;slice<ncslices;slice++)
allgather2_userbuff_inplace(hndl,offset, elements*nslices,comm,stream);
*/
void reducescatter2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0);
communicator *comm, cudaStream_t stream = 0,
cudaEvent_t comm_launch_event = 0);
void reducescatter2_userbuff(void *output, const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0);
communicator *comm, cudaStream_t stream = 0,
cudaEvent_t comm_launch_event = 0);
void reducescatter2_userbuff_stridedoutput(void *output, const int handler, const int offset,
const int rowelements, const int colelements,
const int strideelements, communicator *comm,
cudaStream_t stream = 0);
cudaStream_t stream = 0,
cudaEvent_t comm_launch_event = 0);
template <typename fp8type>
void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler,
const int offset, const int rowelements,
const int colelements, const int strideelements,
communicator *comm, cudaStream_t stream = 0);
communicator *comm, cudaStream_t stream = 0,
cudaEvent_t comm_launch_event = 0);
template <typename fp8type>
void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset,
const int elements, communicator *comm, cudaStream_t stream = 0);
const int elements, communicator *comm, cudaStream_t stream = 0,
cudaEvent_t comm_launch_event = 0);
template <typename fp8type>
void reducescatter2_userbuff_strided_atomic_fp8(void *output, float *scale, const int handler,
const int offset, const int rowelements,
@@ -62,7 +62,7 @@ class CommOverlapCore {
bool _ubuf_scale_inv_initialized{false};

std::vector<cudaStream_t> _stream_compute;
cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm;
cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event;

public:
CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes,