Used fused push_send_recv kernel for AG overlap
Signed-off-by: Sangkug Lym <slym@nvidia.com>
erhoo82 committed Sep 26, 2024
1 parent a68acd7 commit 8367bad
Showing 2 changed files with 78 additions and 37 deletions.
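The change makes the fused kernel opt-in: when the NVTE_UB_FUSED_SEND_RECV environment variable is set to 1, the point-to-point exchange that was previously issued as separate userbuffers_send and userbuffers_recv launches on _stream_send and _stream_recv is issued as a single userbuffers_sendrecv kernel on _stream_send. A minimal standalone sketch of the gating logic follows; the helper name and the program around it are illustrative, not part of Transformer Engine.

    // env_gate_sketch.cpp -- illustrative only; mirrors the check added to the
    // UbufP2PCommOverlap constructor in this commit.
    #include <cstdlib>
    #include <iostream>

    // True only when NVTE_UB_FUSED_SEND_RECV is set and its first character is '1',
    // matching `use_fused_sendrecv != nullptr && use_fused_sendrecv[0] == '1'`.
    static bool use_fused_sendrecv_from_env() {
      const char *env = std::getenv("NVTE_UB_FUSED_SEND_RECV");
      return env != nullptr && env[0] == '1';
    }

    int main() {
      std::cout << (use_fused_sendrecv_from_env()
                        ? "fused userbuffers_sendrecv path"
                        : "separate userbuffers_send / userbuffers_recv path")
                << std::endl;
      return 0;
    }

Running with NVTE_UB_FUSED_SEND_RECV=1 selects the fused branch; leaving the variable unset, or setting any value that does not start with '1', keeps the existing two-kernel path.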
113 changes: 77 additions & 36 deletions transformer_engine/pytorch/csrc/comm_gemm_overlap.h
@@ -666,6 +666,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
int _num_comm_sm;
int _cga_size;
bool _atomic_gemm;
bool _use_fused_sendrecv;

UbufP2PCommOverlap(torch::Tensor sample, int myrank, int numranks, int mylocal, int numlocal,
int mynode, int numnodes, int tp_size, int num_comm_sm, int comm_cga_size,
@@ -727,6 +728,10 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
at::cuda::getStreamFromExternal(stream, stream_main.device_index()));
}

// Set to use a single kernel for push_send and push_recv
const char *use_fused_sendrecv = std::getenv("NVTE_UB_FUSED_SEND_RECV");
_use_fused_sendrecv = use_fused_sendrecv != nullptr && use_fused_sendrecv[0] == '1';

// Set the number of SMs for GEMM with margin
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
@@ -943,9 +948,10 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {

at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));

NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
if (!_use_fused_sendrecv) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
}
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
}
@@ -959,13 +965,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id;
int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank;
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank,
(cudaStream_t)_stream_send);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank,
(cudaStream_t)_stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _stop_recv, 0));
if (_use_fused_sendrecv) {
userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, send_offset, comm_bytes,
_ub_comm, peer_rank, peer_rank, (cudaStream_t)_stream_send);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _stop_recv, 0));
} else {
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank,
(cudaStream_t)_stream_send);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank,
(cudaStream_t)_stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _stop_recv, 0));
}

int local_rank_round2 = (_tp_id % 2 == 0) ? _tp_id : _tp_id - 1;
const int next_rank = (_tp_size + _tp_id + 2) % _tp_size + _rank_round_tp;
@@ -998,14 +1011,22 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {

if (i < num_steps - 1) {
// P2P communication
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm,
next_rank, (cudaStream_t)_stream_send);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm,
prev_rank, (cudaStream_t)_stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(
(cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
if (_use_fused_sendrecv) {
userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, send_offset, comm_bytes * 2,
_ub_comm, next_rank, prev_rank, (cudaStream_t)_stream_send);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(
(cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
} else {
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm,
next_rank, (cudaStream_t)_stream_send);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm,
prev_rank, (cudaStream_t)_stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(
(cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
}
} else if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
@@ -1043,14 +1064,22 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {

if (i < _tp_size - 1) {
// P2P communication
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm,
_next_rank, (cudaStream_t)_stream_send);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
_prev_rank, (cudaStream_t)_stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(
(cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
if (_use_fused_sendrecv) {
userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, send_offset, comm_bytes,
_ub_comm, _next_rank, _prev_rank, (cudaStream_t)_stream_send);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(
(cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
} else {
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm,
_next_rank, (cudaStream_t)_stream_send);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
_prev_rank, (cudaStream_t)_stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(
(cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
}
} else if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
@@ -1066,8 +1095,10 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
if (!_use_fused_sendrecv) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
}
at::cuda::setCurrentCUDAStream(stream_main);
_ub_comm->sms = ori_sms;

@@ -1190,7 +1221,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
if (!_use_fused_sendrecv) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
}
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
}
@@ -1218,12 +1251,18 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp;
NVTE_CHECK_CUDA(cudaEventRecord(
_start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_comm, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_comm, 0));
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
send_rank, (cudaStream_t)_stream_send);
userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
recv_rank, (cudaStream_t)_stream_recv);
if (_use_fused_sendrecv) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_comm, 0));
userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, recv_offset, comm_bytes,
_ub_comm, send_rank, recv_rank, (cudaStream_t)_stream_send);
} else {
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_comm, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_comm, 0));
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
send_rank, (cudaStream_t)_stream_send);
userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
recv_rank, (cudaStream_t)_stream_recv);
}
}
}
at::cuda::setCurrentCUDAStream(stream_main);
@@ -1233,8 +1272,10 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
if (!_use_fused_sendrecv) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
}

// Reduce GEMM output chunks
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr());
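Every guarded hunk above changes the synchronization in the same way: on the unfused path the receive completion is recorded on _stream_recv and both _stream_send and the next compute stream wait on that event, while on the fused path the single userbuffers_sendrecv kernel runs on _stream_send, so one event recorded there is sufficient and _stream_recv no longer needs the _start_compute / _stop_recv handling (hence the !_use_fused_sendrecv guards). The standalone sketch below illustrates that event wiring, with cudaMemsetAsync standing in for the communication kernels; it is a sketch of the pattern, not Transformer Engine code.

    // stream_wiring_sketch.cpp -- illustrative only; shows how the events are
    // wired in the fused vs. unfused paths of the hunks above.
    #include <cuda_runtime.h>
    #include <cstdlib>
    #include <cstdio>

    #define CHECK(call)                                              \
      do {                                                           \
        cudaError_t err_ = (call);                                   \
        if (err_ != cudaSuccess) {                                   \
          std::fprintf(stderr, "%s\n", cudaGetErrorString(err_));    \
          std::exit(1);                                              \
        }                                                            \
      } while (0)

    int main() {
      cudaStream_t stream_send, stream_recv, stream_compute;
      cudaEvent_t stop_recv;
      void *buf = nullptr;
      CHECK(cudaMalloc(&buf, 1 << 20));
      CHECK(cudaStreamCreate(&stream_send));
      CHECK(cudaStreamCreate(&stream_recv));
      CHECK(cudaStreamCreate(&stream_compute));
      CHECK(cudaEventCreateWithFlags(&stop_recv, cudaEventDisableTiming));

      const char *env = std::getenv("NVTE_UB_FUSED_SEND_RECV");
      const bool use_fused = env != nullptr && env[0] == '1';

      if (use_fused) {
        // Fused path: one kernel on the send stream does both halves, so a
        // single event recorded there gates the next GEMM chunk.
        CHECK(cudaMemsetAsync(buf, 0, 1 << 20, stream_send));  // stands in for userbuffers_sendrecv
        CHECK(cudaEventRecord(stop_recv, stream_send));
        CHECK(cudaStreamWaitEvent(stream_compute, stop_recv, 0));
      } else {
        // Unfused path: recv completion is recorded on the recv stream and
        // both the send stream and the compute stream wait on it.
        CHECK(cudaMemsetAsync(buf, 0, 1 << 19, stream_send));  // stands in for userbuffers_send
        CHECK(cudaMemsetAsync(static_cast<char *>(buf) + (1 << 19), 0, 1 << 19,
                              stream_recv));                    // stands in for userbuffers_recv
        CHECK(cudaEventRecord(stop_recv, stream_recv));
        CHECK(cudaStreamWaitEvent(stream_send, stop_recv, 0));
        CHECK(cudaStreamWaitEvent(stream_compute, stop_recv, 0));
      }

      // The next compute chunk would be launched on stream_compute here.
      CHECK(cudaStreamSynchronize(stream_compute));
      CHECK(cudaStreamSynchronize(stream_send));
      std::printf("%s path completed\n", use_fused ? "fused" : "unfused");
      CHECK(cudaFree(buf));
      CHECK(cudaEventDestroy(stop_recv));
      CHECK(cudaStreamDestroy(stream_send));
      CHECK(cudaStreamDestroy(stream_recv));
      CHECK(cudaStreamDestroy(stream_compute));
      return 0;
    }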
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu
@@ -2277,7 +2277,7 @@ void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size

void *send_srcptr = reinterpret_cast<char *>(comm->mem_ptr[srchandler]) + send_offset;
void *send_dstptr =
reinterpret_cast<char *>(comm->peer_ptr[dsthandler][send_peerlocal]) + send_offset;
reinterpret_cast<char *>(comm->peer_ptr[dsthandler][send_peerlocal]) + recv_offset;

if (comm->use_ce) {
// kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_start_ptr));
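The one-line change in userbuffers.cu adjusts the send half of the fused kernel: the chunk pushed into the peer's registered buffer is now placed at recv_offset instead of send_offset. That is consistent with how the unfused path passes offsets (userbuffers_send above uses send_offset for the local source and recv_offset for the remote destination), and it leaves the all-gather callers unaffected, since they pass the same value for both offsets. A hedged sketch of the corrected pointer arithmetic, using placeholder names rather than the library's communicator struct:

    // pointer_fix_sketch.cpp -- illustrative only; the struct and field names
    // are placeholders, not the userbuffers communicator.
    #include <cstddef>
    #include <cstdio>

    struct FakeComm {
      char *local_mem;  // stands in for comm->mem_ptr[srchandler]
      char *peer_mem;   // stands in for comm->peer_ptr[dsthandler][send_peerlocal]
    };

    // Computes the source and destination of the push-send half of a fused
    // send/recv: the local chunk is read at send_offset and written into the
    // peer buffer at recv_offset (previously send_offset).
    void send_pointers(const FakeComm &comm, size_t send_offset, size_t recv_offset,
                       void **send_srcptr, void **send_dstptr) {
      *send_srcptr = comm.local_mem + send_offset;
      *send_dstptr = comm.peer_mem + recv_offset;  // the fix: recv_offset, not send_offset
    }

    int main() {
      char local[1024], peer[1024];
      FakeComm comm{local, peer};
      void *src = nullptr, *dst = nullptr;
      send_pointers(comm, 256, 512, &src, &dst);
      std::printf("src = local+%td, dst = peer+%td\n",
                  static_cast<char *>(src) - local, static_cast<char *>(dst) - peer);
      return 0;
    }

In the real kernel the base pointers come from comm->mem_ptr[srchandler] and comm->peer_ptr[dsthandler][send_peerlocal]; only the offset applied to the peer pointer changed.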
