Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Sep 26, 2024
1 parent 8367bad commit 5cad985
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions transformer_engine/pytorch/csrc/comm_gemm_overlap.h
Original file line number Diff line number Diff line change
Expand Up @@ -966,15 +966,15 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
int recv_offset = comm_bytes * recv_chunk_id;
int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank;
if (_use_fused_sendrecv) {
userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, send_offset, comm_bytes,
_ub_comm, peer_rank, peer_rank, (cudaStream_t)_stream_send);
userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, send_offset, comm_bytes, _ub_comm,
peer_rank, peer_rank, (cudaStream_t)_stream_send);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _stop_recv, 0));
} else {
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank,
(cudaStream_t)_stream_send);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank,
(cudaStream_t)_stream_recv);
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm,
peer_rank, (cudaStream_t)_stream_send);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
peer_rank, (cudaStream_t)_stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _stop_recv, 0));
Expand Down Expand Up @@ -1019,9 +1019,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
(cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
} else {
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm,
next_rank, (cudaStream_t)_stream_send);
next_rank, (cudaStream_t)_stream_send);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm,
prev_rank, (cudaStream_t)_stream_recv);
prev_rank, (cudaStream_t)_stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(
Expand Down Expand Up @@ -1065,16 +1065,16 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
if (i < _tp_size - 1) {
// P2P communication
if (_use_fused_sendrecv) {
userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, send_offset, comm_bytes,
_ub_comm, _next_rank, _prev_rank, (cudaStream_t)_stream_send);
userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, send_offset, comm_bytes, _ub_comm,
_next_rank, _prev_rank, (cudaStream_t)_stream_send);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(
(cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
} else {
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm,
_next_rank, (cudaStream_t)_stream_send);
_next_rank, (cudaStream_t)_stream_send);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
_prev_rank, (cudaStream_t)_stream_recv);
_prev_rank, (cudaStream_t)_stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(
Expand Down Expand Up @@ -1253,15 +1253,15 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
_start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()]));
if (_use_fused_sendrecv) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_comm, 0));
userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, recv_offset, comm_bytes,
_ub_comm, send_rank, recv_rank, (cudaStream_t)_stream_send);
userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, recv_offset, comm_bytes, _ub_comm,
send_rank, recv_rank, (cudaStream_t)_stream_send);
} else {
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_comm, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_comm, 0));
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
send_rank, (cudaStream_t)_stream_send);
send_rank, (cudaStream_t)_stream_send);
userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
recv_rank, (cudaStream_t)_stream_recv);
recv_rank, (cudaStream_t)_stream_recv);
}
}
}
Expand Down

0 comments on commit 5cad985

Please sign in to comment.