diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index ded56958c9..8a804c164c 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -966,15 +966,15 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { int recv_offset = comm_bytes * recv_chunk_id; int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank; if (_use_fused_sendrecv) { - userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, send_offset, comm_bytes, - _ub_comm, peer_rank, peer_rank, (cudaStream_t)_stream_send); + userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, send_offset, comm_bytes, _ub_comm, + peer_rank, peer_rank, (cudaStream_t)_stream_send); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_send)); NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _stop_recv, 0)); } else { - userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank, - (cudaStream_t)_stream_send); - userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank, - (cudaStream_t)_stream_recv); + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, + peer_rank, (cudaStream_t)_stream_send); + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, + peer_rank, (cudaStream_t)_stream_recv); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _stop_recv, 0)); @@ -1019,9 +1019,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); } else { userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm, - next_rank, (cudaStream_t)_stream_send); + next_rank, (cudaStream_t)_stream_send); userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm, - prev_rank, (cudaStream_t)_stream_recv); + prev_rank, (cudaStream_t)_stream_recv); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent( @@ -1065,16 +1065,16 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { if (i < _tp_size - 1) { // P2P communication if (_use_fused_sendrecv) { - userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, send_offset, comm_bytes, - _ub_comm, _next_rank, _prev_rank, (cudaStream_t)_stream_send); + userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, send_offset, comm_bytes, _ub_comm, + _next_rank, _prev_rank, (cudaStream_t)_stream_send); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_send)); NVTE_CHECK_CUDA(cudaStreamWaitEvent( (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); } else { userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, - _next_rank, (cudaStream_t)_stream_send); + _next_rank, (cudaStream_t)_stream_send); userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, - _prev_rank, (cudaStream_t)_stream_recv); + _prev_rank, (cudaStream_t)_stream_recv); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent( @@ -1253,15 +1253,15 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { _start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()])); if (_use_fused_sendrecv) { NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_comm, 0)); - userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, recv_offset, comm_bytes, - _ub_comm, send_rank, recv_rank, (cudaStream_t)_stream_send); + userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, recv_offset, comm_bytes, _ub_comm, + send_rank, recv_rank, (cudaStream_t)_stream_send); } else { NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_comm, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_comm, 0)); userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, - send_rank, (cudaStream_t)_stream_send); + send_rank, (cudaStream_t)_stream_send); userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, - recv_rank, (cudaStream_t)_stream_recv); + recv_rank, (cudaStream_t)_stream_recv); } } }