Skip to content

Commit

Permalink
[Unity][Disco] separate computation and communication into 2 stream (…
Browse files Browse the repository at this point in the history
…#15742)

put computation on default stream and put communication on a new stream.
  • Loading branch information
jinhongyii authored Sep 15, 2023
1 parent 25d6c45 commit afb2e42
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 20 deletions.
41 changes: 29 additions & 12 deletions src/runtime/disco/nccl/nccl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,13 @@ namespace nccl {
struct NCCLThreadLocalContext {
DiscoWorker* worker;
int device_id;
cudaStream_t stream;
cudaStream_t comm_stream;
cudaStream_t compute_stream = nullptr;
ncclComm_t comm;

void Clear() {
NCCL_CALL(ncclCommDestroy(comm));
CUDA_CALL(cudaStreamDestroy(stream));
CUDA_CALL(cudaStreamDestroy(comm_stream));
}

static NCCLThreadLocalContext* Get() {
Expand Down Expand Up @@ -74,9 +75,8 @@ void InitCCLPerWorker(ShapeTuple device_ids, std::string unique_id_bytes) {
// Step up local context of NCCL
int device_id = device_ids[worker->worker_id];
CUDA_CALL(cudaSetDevice(device_id));
CUDA_CALL(cudaStreamCreate(&ctx->stream));
CUDA_CALL(cudaStreamCreate(&ctx->comm_stream));
Device device{DLDeviceType::kDLCUDA, device_id};
DeviceAPI::Get(device)->SetStream(device, ctx->stream);
worker->default_device = device;
worker->ccl = "nccl";
ctx->worker = worker;
Expand All @@ -91,26 +91,34 @@ void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) {
NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
ShapeTuple shape = send.Shape();
int64_t numel = shape->Product();
Device device = ctx->worker->default_device;
DeviceAPI::Get(device)->SyncStreamFromTo(device, ctx->compute_stream, ctx->comm_stream);
NCCL_CALL(ncclAllReduce(send->data, recv->data, numel,
/*datatype=*/AsNCCLDataType(DataType(send->dtype)),
/*op=*/AsNCCLRedOp(reduce_kind), ctx->comm, ctx->stream));
/*op=*/AsNCCLRedOp(reduce_kind), ctx->comm, ctx->comm_stream));
DeviceAPI::Get(device)->SyncStreamFromTo(device, ctx->comm_stream, ctx->compute_stream);
}

void BroadcastFromWorker0(NDArray send, NDArray recv) {
NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
ICHECK(send.Shape()->Product() == recv.Shape()->Product());
ShapeTuple shape = send.Shape();
int64_t numel = shape->Product();
Device device = ctx->worker->default_device;
DeviceAPI::Get(device)->SyncStreamFromTo(device, ctx->compute_stream, ctx->comm_stream);
NCCL_CALL(ncclBroadcast(send->data, recv->data, numel,
/*datatype=*/AsNCCLDataType(DataType(send->dtype)),
/*root=*/0, ctx->comm, ctx->stream));
/*root=*/0, ctx->comm, ctx->comm_stream));
DeviceAPI::Get(device)->SyncStreamFromTo(device, ctx->comm_stream, ctx->compute_stream);
}

void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
CHECK(recv.defined()) << "ValueError: buffer `recv` must not be None";
NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
int worker_id = ctx->worker->worker_id;
int num_workers = ctx->worker->num_workers;
Device device = ctx->worker->default_device;
DeviceAPI::Get(device)->SyncStreamFromTo(device, ctx->compute_stream, ctx->comm_stream);
if (worker_id == 0) {
CHECK(send.defined()) << "ValueError: buffer `send` must be provided when worker_id == 0.";
NDArray buffer = send.value();
Expand All @@ -129,7 +137,8 @@ void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
NCCL_CALL(ncclGroupStart());
uint8_t* data = static_cast<uint8_t*>(buffer->data);
for (int i = 0; i < num_workers; ++i) {
NCCL_CALL(ncclSend(data, numel_per_shard, AsNCCLDataType(dtype), i, ctx->comm, ctx->stream));
NCCL_CALL(
ncclSend(data, numel_per_shard, AsNCCLDataType(dtype), i, ctx->comm, ctx->comm_stream));
data += bytes_per_shard;
}
} else {
Expand All @@ -142,15 +151,18 @@ void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
}
int64_t numel = recv.Shape()->Product();
DataType dtype(recv->dtype);
NCCL_CALL(ncclRecv(recv->data, numel, AsNCCLDataType(dtype), 0, ctx->comm, ctx->stream));
NCCL_CALL(ncclRecv(recv->data, numel, AsNCCLDataType(dtype), 0, ctx->comm, ctx->comm_stream));
NCCL_CALL(ncclGroupEnd());
DeviceAPI::Get(device)->SyncStreamFromTo(device, ctx->comm_stream, ctx->compute_stream);
}

void GatherToWorker0(NDArray send, Optional<NDArray> recv) {
CHECK(send.defined()) << "ValueError: buffer `send` must not be None";
NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
int worker_id = ctx->worker->worker_id;
int num_workers = ctx->worker->num_workers;
Device device = ctx->worker->default_device;
DeviceAPI::Get(device)->SyncStreamFromTo(device, ctx->compute_stream, ctx->comm_stream);
if (worker_id == 0) {
CHECK(recv.defined()) << "ValueError: buffer `recv` must be provided when worker_id == 0.";
NDArray buffer = recv.value();
Expand All @@ -169,7 +181,8 @@ void GatherToWorker0(NDArray send, Optional<NDArray> recv) {
NCCL_CALL(ncclGroupStart());
uint8_t* data = static_cast<uint8_t*>(buffer->data);
for (int i = 0; i < num_workers; ++i) {
NCCL_CALL(ncclRecv(data, numel_per_shard, AsNCCLDataType(dtype), i, ctx->comm, ctx->stream));
NCCL_CALL(
ncclRecv(data, numel_per_shard, AsNCCLDataType(dtype), i, ctx->comm, ctx->comm_stream));
data += bytes_per_shard;
}
} else {
Expand All @@ -182,24 +195,28 @@ void GatherToWorker0(NDArray send, Optional<NDArray> recv) {
}
int64_t numel = send.Shape()->Product();
DataType dtype(send->dtype);
NCCL_CALL(ncclSend(send->data, numel, AsNCCLDataType(dtype), 0, ctx->comm, ctx->stream));
NCCL_CALL(ncclSend(send->data, numel, AsNCCLDataType(dtype), 0, ctx->comm, ctx->comm_stream));
NCCL_CALL(ncclGroupEnd());
DeviceAPI::Get(device)->SyncStreamFromTo(device, ctx->comm_stream, ctx->compute_stream);
}

void RecvFromWorker0(NDArray buffer) {
NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
CHECK_NE(ctx->worker->worker_id, 0)
<< "ValueError: Worker 0 is not allowed to call RecvFromWorker0.";
Device device = ctx->worker->default_device;
DeviceAPI::Get(device)->SyncStreamFromTo(device, ctx->compute_stream, ctx->comm_stream);
NCCL_CALL(ncclGroupStart());
NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()), 0,
ctx->comm, ctx->stream));
ctx->comm, ctx->comm_stream));
NCCL_CALL(ncclGroupEnd());
DeviceAPI::Get(device)->SyncStreamFromTo(device, ctx->comm_stream, ctx->compute_stream);
}

void SyncWorker() {
NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
ICHECK(ctx->worker != nullptr);
CUDA_CALL(cudaStreamSynchronize(ctx->stream));
CUDA_CALL(cudaStreamSynchronize(ctx->compute_stream));
}

TVM_REGISTER_GLOBAL("runtime.disco.nccl.init_ccl").set_body_typed(InitCCL);
Expand Down
11 changes: 3 additions & 8 deletions tests/python/disco/test_nccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pytest

import tvm
import tvm.testing
from tvm import dlight as dl
from tvm import relax as rx
from tvm.runtime import disco as di
Expand Down Expand Up @@ -103,7 +104,7 @@ def test_scatter(session_kind):

@pytest.mark.parametrize("session_kind", _all_session_kinds)
def test_gather(session_kind):
devices = [1, 2]
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
sess.init_ccl("nccl", *devices)

Expand Down Expand Up @@ -376,10 +377,4 @@ def relax_build(mod, target):


if __name__ == "__main__":
test_init(di.ProcessSession)
test_allreduce(di.ProcessSession)
test_broadcast_from_worker0(di.ProcessSession)
test_scatter(di.ProcessSession)
test_gather(di.ProcessSession)
test_mlp(di.ProcessSession)
test_attention(di.ProcessSession)
tvm.testing.main()

0 comments on commit afb2e42

Please sign in to comment.