diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index e404e3c2bb..88eaf03c44 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -39,12 +39,13 @@ namespace nccl {
 struct NCCLThreadLocalContext {
   DiscoWorker* worker;
   int device_id;
-  cudaStream_t stream;
+  cudaStream_t comm_stream;
+  cudaStream_t compute_stream = nullptr;
   ncclComm_t comm;
 
   void Clear() {
     NCCL_CALL(ncclCommDestroy(comm));
-    CUDA_CALL(cudaStreamDestroy(stream));
+    CUDA_CALL(cudaStreamDestroy(comm_stream));
   }
 
   static NCCLThreadLocalContext* Get() {
@@ -74,9 +75,8 @@ void InitCCLPerWorker(ShapeTuple device_ids, std::string unique_id_bytes) {
   // Step up local context of NCCL
   int device_id = device_ids[worker->worker_id];
   CUDA_CALL(cudaSetDevice(device_id));
-  CUDA_CALL(cudaStreamCreate(&ctx->stream));
+  CUDA_CALL(cudaStreamCreate(&ctx->comm_stream));
   Device device{DLDeviceType::kDLCUDA, device_id};
-  DeviceAPI::Get(device)->SetStream(device, ctx->stream);
   worker->default_device = device;
   worker->ccl = "nccl";
   ctx->worker = worker;
@@ -91,9 +91,12 @@ void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) {
   NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
   ShapeTuple shape = send.Shape();
   int64_t numel = shape->Product();
+  Device device = ctx->worker->default_device;
+  DeviceAPI::Get(device)->SyncStreamFromTo(device, ctx->compute_stream, ctx->comm_stream);
   NCCL_CALL(ncclAllReduce(send->data, recv->data, numel,
                           /*datatype=*/AsNCCLDataType(DataType(send->dtype)),
-                          /*op=*/AsNCCLRedOp(reduce_kind), ctx->comm, ctx->stream));
+                          /*op=*/AsNCCLRedOp(reduce_kind), ctx->comm, ctx->comm_stream));
+  DeviceAPI::Get(device)->SyncStreamFromTo(device, ctx->comm_stream, ctx->compute_stream);
 }
 
 void BroadcastFromWorker0(NDArray send, NDArray recv) {
@@ -101,9 +104,12 @@
   ICHECK(send.Shape()->Product() == recv.Shape()->Product());
   ShapeTuple shape = send.Shape();
   int64_t numel = shape->Product();
+  Device device = ctx->worker->default_device;
+  DeviceAPI::Get(device)->SyncStreamFromTo(device, ctx->compute_stream, ctx->comm_stream);
   NCCL_CALL(ncclBroadcast(send->data, recv->data, numel,
                           /*datatype=*/AsNCCLDataType(DataType(send->dtype)),
-                          /*root=*/0, ctx->comm, ctx->stream));
+                          /*root=*/0, ctx->comm, ctx->comm_stream));
+  DeviceAPI::Get(device)->SyncStreamFromTo(device, ctx->comm_stream, ctx->compute_stream);
 }
 
 void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
@@ -111,6 +117,8 @@
   NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
   int worker_id = ctx->worker->worker_id;
   int num_workers = ctx->worker->num_workers;
+  Device device = ctx->worker->default_device;
+  DeviceAPI::Get(device)->SyncStreamFromTo(device, ctx->compute_stream, ctx->comm_stream);
   if (worker_id == 0) {
     CHECK(send.defined()) << "ValueError: buffer `send` must be provided when worker_id == 0.";
     NDArray buffer = send.value();
@@ -129,7 +137,8 @@
     NCCL_CALL(ncclGroupStart());
     uint8_t* data = static_cast<uint8_t*>(buffer->data);
     for (int i = 0; i < num_workers; ++i) {
-      NCCL_CALL(ncclSend(data, numel_per_shard, AsNCCLDataType(dtype), i, ctx->comm, ctx->stream));
+      NCCL_CALL(
+          ncclSend(data, numel_per_shard, AsNCCLDataType(dtype), i, ctx->comm, ctx->comm_stream));
       data += bytes_per_shard;
     }
   } else {
@@ -142,8 +151,9 @@
   }
   int64_t numel = recv.Shape()->Product();
   DataType dtype(recv->dtype);
-  NCCL_CALL(ncclRecv(recv->data, numel, AsNCCLDataType(dtype), 0, ctx->comm, ctx->stream));
+  NCCL_CALL(ncclRecv(recv->data, numel, AsNCCLDataType(dtype), 0, ctx->comm, ctx->comm_stream));
   NCCL_CALL(ncclGroupEnd());
+  DeviceAPI::Get(device)->SyncStreamFromTo(device, ctx->comm_stream, ctx->compute_stream);
 }
 
 void GatherToWorker0(NDArray send, Optional<NDArray> recv) {
@@ -151,6 +161,8 @@
   NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
   int worker_id = ctx->worker->worker_id;
   int num_workers = ctx->worker->num_workers;
+  Device device = ctx->worker->default_device;
+  DeviceAPI::Get(device)->SyncStreamFromTo(device, ctx->compute_stream, ctx->comm_stream);
   if (worker_id == 0) {
     CHECK(recv.defined()) << "ValueError: buffer `recv` must be provided when worker_id == 0.";
     NDArray buffer = recv.value();
@@ -169,7 +181,8 @@
     NCCL_CALL(ncclGroupStart());
     uint8_t* data = static_cast<uint8_t*>(buffer->data);
     for (int i = 0; i < num_workers; ++i) {
-      NCCL_CALL(ncclRecv(data, numel_per_shard, AsNCCLDataType(dtype), i, ctx->comm, ctx->stream));
+      NCCL_CALL(
+          ncclRecv(data, numel_per_shard, AsNCCLDataType(dtype), i, ctx->comm, ctx->comm_stream));
       data += bytes_per_shard;
     }
   } else {
@@ -182,24 +195,28 @@
   }
   int64_t numel = send.Shape()->Product();
   DataType dtype(send->dtype);
-  NCCL_CALL(ncclSend(send->data, numel, AsNCCLDataType(dtype), 0, ctx->comm, ctx->stream));
+  NCCL_CALL(ncclSend(send->data, numel, AsNCCLDataType(dtype), 0, ctx->comm, ctx->comm_stream));
   NCCL_CALL(ncclGroupEnd());
+  DeviceAPI::Get(device)->SyncStreamFromTo(device, ctx->comm_stream, ctx->compute_stream);
 }
 
 void RecvFromWorker0(NDArray buffer) {
   NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
   CHECK_NE(ctx->worker->worker_id, 0)
       << "ValueError: Worker 0 is not allowed to call RecvFromWorker0.";
+  Device device = ctx->worker->default_device;
+  DeviceAPI::Get(device)->SyncStreamFromTo(device, ctx->compute_stream, ctx->comm_stream);
   NCCL_CALL(ncclGroupStart());
   NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()), 0,
-                     ctx->comm, ctx->stream));
+                     ctx->comm, ctx->comm_stream));
   NCCL_CALL(ncclGroupEnd());
+  DeviceAPI::Get(device)->SyncStreamFromTo(device, ctx->comm_stream, ctx->compute_stream);
 }
 
 void SyncWorker() {
   NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
   ICHECK(ctx->worker != nullptr);
-  CUDA_CALL(cudaStreamSynchronize(ctx->stream));
+  CUDA_CALL(cudaStreamSynchronize(ctx->compute_stream));
 }
 
 TVM_REGISTER_GLOBAL("runtime.disco.nccl.init_ccl").set_body_typed(InitCCL);
diff --git a/tests/python/disco/test_nccl.py b/tests/python/disco/test_nccl.py
index f0f949ab80..e86c973fc2 100644
--- a/tests/python/disco/test_nccl.py
+++ b/tests/python/disco/test_nccl.py
@@ -22,6 +22,7 @@
 import pytest
 
 import tvm
+import tvm.testing
 from tvm import dlight as dl
 from tvm import relax as rx
 from tvm.runtime import disco as di
@@ -103,7 +104,7 @@ def test_scatter(session_kind):
 
 @pytest.mark.parametrize("session_kind", _all_session_kinds)
 def test_gather(session_kind):
-    devices = [1, 2]
+    devices = [0, 1]
     sess = session_kind(num_workers=len(devices))
     sess.init_ccl("nccl", *devices)
 
@@ -376,10 +377,4 @@ def relax_build(mod, target):
 
 
 if __name__ == "__main__":
-    test_init(di.ProcessSession)
-    test_allreduce(di.ProcessSession)
-    test_broadcast_from_worker0(di.ProcessSession)
-    test_scatter(di.ProcessSession)
-    test_gather(di.ProcessSession)
-    test_mlp(di.ProcessSession)
-    test_attention(di.ProcessSession)
+    tvm.testing.main()
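
The pattern the patch relies on: each collective now runs on a dedicated `comm_stream`, with `SyncStreamFromTo(compute_stream, comm_stream)` issued before the NCCL call and the reverse sync issued after it, so compute kernels and the collective stay ordered without blocking the host; `SyncWorker` then synchronizes the compute stream. Below is a minimal standalone sketch of that cross-stream ordering, assuming the sync is implemented with CUDA events (the usual approach); `StreamWaitStream` and the scaffolding around it are illustrative, not TVM's actual `DeviceAPI` implementation.

```cpp
#include <cuda_runtime.h>

#include <cstdio>
#include <cstdlib>

// Abort on any CUDA runtime error.
static void Check(cudaError_t err) {
  if (err != cudaSuccess) {
    std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err));
    std::abort();
  }
}

// Hypothetical helper: make `to` wait until all work currently enqueued on
// `from` has completed, without a host-side synchronization. This is the
// event-based pattern a SyncStreamFromTo(from, to) is typically built on.
static void StreamWaitStream(cudaStream_t from, cudaStream_t to) {
  cudaEvent_t event;
  Check(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
  Check(cudaEventRecord(event, from));       // mark the tail of `from`
  Check(cudaStreamWaitEvent(to, event, 0));  // `to` waits for that point
  Check(cudaEventDestroy(event));            // safe: the wait already captured it
}

int main() {
  cudaStream_t compute_stream, comm_stream;
  Check(cudaStreamCreate(&compute_stream));
  Check(cudaStreamCreate(&comm_stream));

  // ... compute kernels would be launched on compute_stream here ...
  StreamWaitStream(compute_stream, comm_stream);  // comm waits for compute
  // ... a collective (e.g. ncclAllReduce) would be issued on comm_stream here ...
  StreamWaitStream(comm_stream, compute_stream);  // compute waits for comm

  // Mirrors SyncWorker: completion is observed on the compute stream.
  Check(cudaStreamSynchronize(compute_stream));
  Check(cudaStreamDestroy(comm_stream));
  Check(cudaStreamDestroy(compute_stream));
  return 0;
}
```

Keeping the collective on its own stream while bracketing it with these two waits preserves the original per-worker ordering semantics but leaves room to overlap communication with unrelated work later, which is presumably why the patch stops calling `SetStream` during initialization.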