From 8367badc363b4c353f18a3bda5bf37bd81ffff76 Mon Sep 17 00:00:00 2001
From: Sangkug Lym
Date: Sat, 21 Sep 2024 10:22:04 -0700
Subject: [PATCH] Used fused push_send_recv kernel for AG overlap

Signed-off-by: Sangkug Lym
---
 .../pytorch/csrc/comm_gemm_overlap.h        | 113 ++++++++++++------
 .../pytorch/csrc/userbuffers/userbuffers.cu |   2 +-
 2 files changed, 78 insertions(+), 37 deletions(-)

diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h
index 3b4e126943..ded56958c9 100644
--- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h
+++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h
@@ -666,6 +666,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
   int _num_comm_sm;
   int _cga_size;
   bool _atomic_gemm;
+  bool _use_fused_sendrecv;
 
   UbufP2PCommOverlap(torch::Tensor sample, int myrank, int numranks, int mylocal, int numlocal,
                      int mynode, int numnodes, int tp_size, int num_comm_sm, int comm_cga_size,
@@ -727,6 +728,10 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
           at::cuda::getStreamFromExternal(stream, stream_main.device_index()));
     }
 
+    // Set to use a single kernel for push_send and push_recv
+    const char *use_fused_sendrecv = std::getenv("NVTE_UB_FUSED_SEND_RECV");
+    _use_fused_sendrecv = use_fused_sendrecv != nullptr && use_fused_sendrecv[0] == '1';
+
     // Set the number of SMs for GEMM with margin
     cudaDeviceProp prop;
     cudaGetDeviceProperties(&prop, 0);
@@ -943,9 +948,10 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {

     at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
     NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
-    NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
-    NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
+    if (!_use_fused_sendrecv) {
+      NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
+    }
     for (size_t i = 0; i < _stream_compute.size(); i++) {
       NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
     }
@@ -959,13 +965,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
     int send_offset = comm_bytes * send_chunk_id;
     int recv_offset = comm_bytes * recv_chunk_id;
     int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank;
-    userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank,
-                     (cudaStream_t)_stream_send);
-    userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank,
-                     (cudaStream_t)_stream_recv);
-    NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
-    NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
-    NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _stop_recv, 0));
+    if (_use_fused_sendrecv) {
+      userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, send_offset, comm_bytes,
+                           _ub_comm, peer_rank, peer_rank, (cudaStream_t)_stream_send);
+      NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_send));
+      NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _stop_recv, 0));
+    } else {
+      userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank,
+                       (cudaStream_t)_stream_send);
+      userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank,
+                       (cudaStream_t)_stream_recv);
+      NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
+      NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
+      NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _stop_recv, 0));
+    }
 
     int local_rank_round2 = (_tp_id % 2 == 0) ? _tp_id : _tp_id - 1;
     const int next_rank = (_tp_size + _tp_id + 2) % _tp_size + _rank_round_tp;
@@ -998,14 +1011,22 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {

       if (i < num_steps - 1) {
         // P2P communication
-        userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm,
-                         next_rank, (cudaStream_t)_stream_send);
-        userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm,
-                         prev_rank, (cudaStream_t)_stream_recv);
-        NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
-        NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
-        NVTE_CHECK_CUDA(cudaStreamWaitEvent(
-            (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
+        if (_use_fused_sendrecv) {
+          userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, send_offset, comm_bytes * 2,
+                               _ub_comm, next_rank, prev_rank, (cudaStream_t)_stream_send);
+          NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_send));
+          NVTE_CHECK_CUDA(cudaStreamWaitEvent(
+              (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
+        } else {
+          userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm,
+                           next_rank, (cudaStream_t)_stream_send);
+          userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm,
+                           prev_rank, (cudaStream_t)_stream_recv);
+          NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
+          NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
+          NVTE_CHECK_CUDA(cudaStreamWaitEvent(
+              (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
+        }
       } else if (B_copy.numel() > 0) {
         assert(B_copy.numel() == _ubufs[_tp_id].numel());
         assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
@@ -1043,14 +1064,22 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {

       if (i < _tp_size - 1) {
         // P2P communication
-        userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm,
-                         _next_rank, (cudaStream_t)_stream_send);
-        userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
-                         _prev_rank, (cudaStream_t)_stream_recv);
-        NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
-        NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
-        NVTE_CHECK_CUDA(cudaStreamWaitEvent(
-            (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
+        if (_use_fused_sendrecv) {
+          userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, send_offset, comm_bytes,
+                               _ub_comm, _next_rank, _prev_rank, (cudaStream_t)_stream_send);
+          NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_send));
+          NVTE_CHECK_CUDA(cudaStreamWaitEvent(
+              (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
+        } else {
+          userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm,
+                           _next_rank, (cudaStream_t)_stream_send);
+          userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
+                           _prev_rank, (cudaStream_t)_stream_recv);
+          NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
+          NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
+          NVTE_CHECK_CUDA(cudaStreamWaitEvent(
+              (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
+        }
       } else if (B_copy.numel() > 0) {
         assert(B_copy.numel() == _ubufs[_tp_id].numel());
         assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
@@ -1066,8 +1095,10 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
     }
     NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
     NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
-    NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
-    NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
+    if (!_use_fused_sendrecv) {
+      NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
+      NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
+    }
 
     at::cuda::setCurrentCUDAStream(stream_main);
     _ub_comm->sms = ori_sms;
@@ -1190,7 +1221,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
     at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
     NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
     NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
-    NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
+    if (!_use_fused_sendrecv) {
+      NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
+    }
     for (size_t i = 0; i < _stream_compute.size(); i++) {
       NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
     }
@@ -1218,12 +1251,18 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
         int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp;
         NVTE_CHECK_CUDA(cudaEventRecord(
             _start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()]));
-        NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_comm, 0));
-        NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_comm, 0));
-        userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
-                         send_rank, (cudaStream_t)_stream_send);
-        userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
-                         recv_rank, (cudaStream_t)_stream_recv);
+        if (_use_fused_sendrecv) {
+          NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_comm, 0));
+          userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, recv_offset, comm_bytes,
+                               _ub_comm, send_rank, recv_rank, (cudaStream_t)_stream_send);
+        } else {
+          NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_comm, 0));
+          NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_comm, 0));
+          userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
+                           send_rank, (cudaStream_t)_stream_send);
+          userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
+                           recv_rank, (cudaStream_t)_stream_recv);
+        }
       }
     }
     at::cuda::setCurrentCUDAStream(stream_main);
@@ -1233,8 +1272,10 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
     }
     NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
     NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
-    NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
-    NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
+    if (!_use_fused_sendrecv) {
+      NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
+      NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
+    }
 
     // Reduce GEMM output chunks
     char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr());
diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu
index 0cd2a0253b..1af022d5b0 100644
--- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu
+++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu
@@ -2277,7 +2277,7 @@ void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size
   void *send_srcptr = reinterpret_cast<char *>(comm->mem_ptr[srchandler]) + send_offset;
   void *send_dstptr =
-      reinterpret_cast<char *>(comm->peer_ptr[dsthandler][send_peerlocal]) + send_offset;
+      reinterpret_cast<char *>(comm->peer_ptr[dsthandler][send_peerlocal]) + recv_offset;
   if (comm->use_ce) {
     // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_start_ptr));
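
Not part of the patch: to make the stream/event choreography that this change toggles easier to see, below is a minimal, self-contained CUDA sketch of the two synchronization patterns. push_send, push_recv, and fused_sendrecv are hypothetical stand-ins for userbuffers_send, userbuffers_recv, and userbuffers_sendrecv; the stream and event names only mirror the ones used in comm_gemm_overlap.h.

// sendrecv_sync_sketch.cu -- illustrative only, not TransformerEngine code.
#include <cstdlib>
#include <cuda_runtime.h>

__global__ void push_send() {}       // stand-in for userbuffers_send
__global__ void push_recv() {}       // stand-in for userbuffers_recv
__global__ void fused_sendrecv() {}  // stand-in for userbuffers_sendrecv

int main() {
  // Same gate as the patch: NVTE_UB_FUSED_SEND_RECV=1 selects the fused path.
  const char *env = std::getenv("NVTE_UB_FUSED_SEND_RECV");
  const bool use_fused_sendrecv = env != nullptr && env[0] == '1';

  cudaStream_t stream_send, stream_recv, stream_compute;
  cudaEvent_t stop_recv;
  cudaStreamCreate(&stream_send);
  cudaStreamCreate(&stream_recv);
  cudaStreamCreate(&stream_compute);
  cudaEventCreateWithFlags(&stop_recv, cudaEventDisableTiming);

  if (use_fused_sendrecv) {
    // One kernel on the send stream moves data in both directions, so a single
    // event is enough to gate the dependent GEMM chunk on the compute stream.
    fused_sendrecv<<<1, 1, 0, stream_send>>>();
    cudaEventRecord(stop_recv, stream_send);
    cudaStreamWaitEvent(stream_compute, stop_recv, 0);
  } else {
    // Two kernels on two streams; the send and compute streams must both wait
    // on the recv stream before the next chunk can be produced or consumed.
    push_send<<<1, 1, 0, stream_send>>>();
    push_recv<<<1, 1, 0, stream_recv>>>();
    cudaEventRecord(stop_recv, stream_recv);
    cudaStreamWaitEvent(stream_send, stop_recv, 0);
    cudaStreamWaitEvent(stream_compute, stop_recv, 0);
  }

  cudaDeviceSynchronize();
  cudaEventDestroy(stop_recv);
  cudaStreamDestroy(stream_send);
  cudaStreamDestroy(stream_recv);
  cudaStreamDestroy(stream_compute);
  return 0;
}

As the diff suggests, the saving in the fused path is one kernel launch per communication step plus the removal of the _stream_recv event bookkeeping; the trade-off is that a single kernel on _stream_send now owns both the push-send and the push-recv work.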