Skip to content

Commit

Permalink
TL/NCCL: support for lazy team init
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergei-Lebedev committed Mar 29, 2023
1 parent 9a42c4e commit d0e3ca4
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 24 deletions.
5 changes: 5 additions & 0 deletions src/components/tl/nccl/tl_nccl.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ static ucs_config_field_t ucc_tl_nccl_context_config_table[] = {
{"", "", NULL, ucc_offsetof(ucc_tl_nccl_context_config_t, super),
UCC_CONFIG_TYPE_TABLE(ucc_tl_context_config_table)},

{"LAZY_INIT", "yes",
"Initialize NCCL communicator on first collective",
ucc_offsetof(ucc_tl_nccl_context_config_t, lazy_init),
UCC_CONFIG_TYPE_BOOL},

{"SYNC", "auto",
"Determines how UCC tests completion of NCCL collective",
ucs_offsetof(ucc_tl_nccl_context_config_t, sync_type),
Expand Down
11 changes: 10 additions & 1 deletion src/components/tl/nccl/tl_nccl.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ typedef enum ucc_tl_nccl_completion_sync_type {

/* Parsed context-level configuration for the NCCL TL. */
typedef struct ucc_tl_nccl_context_config {
    ucc_tl_context_config_t super;
    /* "LAZY_INIT" option: when non-zero, creation of the NCCL communicator
       is deferred from team creation to the first collective init */
    int lazy_init;
    /* "SYNC" option: how UCC tests completion of a NCCL collective */
    ucc_tl_nccl_completion_sync_type_t sync_type;
} ucc_tl_nccl_context_config_t;

Expand All @@ -80,9 +81,15 @@ typedef struct ucc_tl_nccl_context {
UCC_CLASS_DECLARE(ucc_tl_nccl_context_t, const ucc_base_context_params_t *,
const ucc_base_config_t *);

/* Lifecycle states of a team's NCCL communicator (tracked in
   ucc_tl_nccl_team_t::comm_state, driven by ucc_tl_nccl_comm_init). */
enum {
    TL_NCCL_COMM_STATE_READY, /* nccl_comm created successfully; usable */
    TL_NCCL_COMM_STATE_INIT,  /* team constructed, communicator not created yet */
    TL_NCCL_COMM_STATE_ERROR  /* init or a collective failed; must not be reused */
};

typedef struct ucc_tl_nccl_team {
ucc_tl_team_t super;
ucc_status_t comm_state;
int comm_state;
ncclUniqueId *unique_id;
void *oob_req;
ncclComm_t nccl_comm;
Expand Down Expand Up @@ -119,6 +126,8 @@ typedef struct ucc_tl_nccl_task {
UCC_COLL_TYPE_GATHERV | UCC_COLL_TYPE_SCATTER | \
UCC_COLL_TYPE_SCATTERV)

ucc_status_t ucc_tl_nccl_comm_init(ucc_tl_nccl_team_t *team);

UCC_CLASS_DECLARE(ucc_tl_nccl_team_t, ucc_base_context_t *,
const ucc_base_team_params_t *);

Expand Down
8 changes: 7 additions & 1 deletion src/components/tl/nccl/tl_nccl_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ ucc_status_t ucc_tl_nccl_init_task(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_tl_nccl_task_t **coll_task)
{
ucc_tl_nccl_team_t *nccl_team = ucc_derived_of(team, ucc_tl_nccl_team_t);
ucc_tl_nccl_context_t *nccl_ctx = ucc_derived_of(team->context,
ucc_tl_nccl_context_t);
ucc_tl_nccl_task_t *task;
Expand All @@ -142,6 +143,11 @@ ucc_status_t ucc_tl_nccl_init_task(ucc_base_coll_args_t *coll_args,
return UCC_ERR_NOT_SUPPORTED;
}

status = ucc_tl_nccl_comm_init(nccl_team);
if (ucc_unlikely(status != UCC_OK)) {
return status;
}

task = ucc_mpool_get(&nccl_ctx->req_mp);
if (ucc_unlikely(!task)) {
tl_error(team->context->lib, "failed to get task from mpool");
Expand Down Expand Up @@ -203,7 +209,7 @@ ucc_status_t ucc_tl_nccl_coll_finalize(ucc_coll_task_t *coll_task)
ucc_status_t status = UCC_OK;

if (ucc_unlikely(task->super.super.status != UCC_OK)) {
team->comm_state = task->super.super.status;
team->comm_state = TL_NCCL_COMM_STATE_ERROR;
}
tl_debug(UCC_TASK_LIB(task), "finalizing coll task %p", task);
ucc_tl_nccl_free_task(task);
Expand Down
84 changes: 62 additions & 22 deletions src/components/tl/nccl/tl_nccl_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,44 @@
#include "coll_score/ucc_coll_score.h"
#include "utils/arch/cuda_def.h"

/*
 * Create the team's NCCL communicator and its CUDA stream, if not done yet.
 *
 * Idempotent: returns UCC_OK immediately when the team is already in the
 * READY state, and UCC_ERR_NOT_SUPPORTED when a previous attempt left the
 * team in the ERROR state (no retry is performed in that case).
 *
 * @param team  NCCL TL team; reads comm_state and unique_id, on success
 *              populates team->stream and team->nccl_comm.
 *
 * @return UCC_OK on success; UCC_ERR_NOT_SUPPORTED when NCCL reports
 *         invalid usage (e.g. caller should fall back to another TL);
 *         UCC_ERR_NO_RESOURCE on other NCCL failures; CUDA stream-creation
 *         errors are propagated via CUDA_CHECK_GOTO's status.
 */
ucc_status_t ucc_tl_nccl_comm_init(ucc_tl_nccl_team_t *team)
{
    ucc_status_t status;
    ncclResult_t nccl_status;

    if (team->comm_state == TL_NCCL_COMM_STATE_READY) {
        /* communicator already initialized — nothing to do */
        return UCC_OK;
    } else if (team->comm_state == TL_NCCL_COMM_STATE_ERROR) {
        /* a previous init attempt failed; do not retry */
        return UCC_ERR_NOT_SUPPORTED;
    }
    /* NOTE(review): if stream creation fails we jump to err with
       comm_state still INIT, so a later call will retry — confirm this
       is intended, in contrast to the NCCL-failure path below which
       latches ERROR. */
    CUDA_CHECK_GOTO(cudaStreamCreateWithFlags(&team->stream,
                    cudaStreamNonBlocking), err, status);
    nccl_status = ncclCommInitRank(&team->nccl_comm, UCC_TL_TEAM_SIZE(team),
                                   team->unique_id[0], UCC_TL_TEAM_RANK(team));
    if (nccl_status != ncclSuccess) {
        tl_debug(team->super.super.context->lib, "NCCL error %d %s",
                 nccl_status, ncclGetErrorString(nccl_status));
        if (nccl_status == ncclInvalidUsage) {
            /*
             * handles the case when trying to initialize multiple ranks
             * on the same GPU. Return "not supported" and fallback to other TL
             */
            status = UCC_ERR_NOT_SUPPORTED;
        } else {
            status = UCC_ERR_NO_RESOURCE;
        }
        /* latch the error so subsequent calls fail fast without retrying */
        team->comm_state = TL_NCCL_COMM_STATE_ERROR;
        goto free_stream;
    }
    team->comm_state = TL_NCCL_COMM_STATE_READY;
    return UCC_OK;

free_stream:
    /* undo the stream created above before reporting failure */
    cudaStreamDestroy(team->stream);
err:
    return status;
}

UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context,
const ucc_base_team_params_t *params)
{
Expand All @@ -22,7 +60,7 @@ UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context,
UCC_CLASS_CALL_SUPER_INIT(ucc_tl_team_t, &ctx->super, params);

size = UCC_TL_TEAM_SIZE(self);
self->comm_state = UCC_OK;
self->comm_state = TL_NCCL_COMM_STATE_INIT;
self->unique_id = ucc_malloc(sizeof(ncclUniqueId) * (size + 1),
"tl_nccl_unique_id");
if (!self->unique_id) {
Expand Down Expand Up @@ -58,7 +96,7 @@ UCC_CLASS_CLEANUP_FUNC(ucc_tl_nccl_team_t)
{
tl_debug(self->super.super.context->lib, "finalizing tl team: %p", self);
if (self->nccl_comm) {
if (self->comm_state != UCC_OK) {
if (self->comm_state == TL_NCCL_COMM_STATE_ERROR) {
/* if communication error was detected ncclCommAbort should be used
since ncclCommDestroy could block */
ncclCommAbort(self->nccl_comm);
Expand All @@ -67,6 +105,7 @@ UCC_CLASS_CLEANUP_FUNC(ucc_tl_nccl_team_t)
}
cudaStreamDestroy(self->stream);
}
ucc_free(self->unique_id);
}

UCC_CLASS_DEFINE_DELETE_FUNC(ucc_tl_nccl_team_t, ucc_base_team_t);
Expand All @@ -80,9 +119,10 @@ ucc_status_t ucc_tl_nccl_team_destroy(ucc_base_team_t *tl_team)

ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team)
{
ucc_tl_nccl_team_t *team = ucc_derived_of(tl_team, ucc_tl_nccl_team_t);
ucc_tl_nccl_team_t *team = ucc_derived_of(tl_team, ucc_tl_nccl_team_t);
ucc_tl_nccl_context_t *ctx = ucc_derived_of(tl_team->context,
ucc_tl_nccl_context_t);
ucc_status_t status;
ncclResult_t nccl_status;
ncclUniqueId errorid;

status = UCC_TL_TEAM_OOB(team).req_test(team->oob_req);
Expand All @@ -92,38 +132,38 @@ ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team)
if (status != UCC_OK) {
UCC_TL_TEAM_OOB(team).req_free(team->oob_req);
tl_error(tl_team->context->lib, "oob req test failed");
goto free_unique_id;
goto err;
}
status = UCC_TL_TEAM_OOB(team).req_free(team->oob_req);
if (status != UCC_OK) {
tl_error(tl_team->context->lib, "oob req free failed");
goto free_unique_id;
goto err;
}
team->unique_id = ucc_realloc(team->unique_id, sizeof(ncclUniqueId),
"nccl unique id");
if (!team->unique_id) {
tl_error(tl_team->context->lib,
"failed to realloc %zd bytest for ncclUniqueId",
sizeof(ncclUniqueId));
return UCC_ERR_NO_MEMORY;
}
/* check unique id is valid */
memset(&errorid, 0, sizeof(errorid));
if (!memcmp(&errorid, team->unique_id, sizeof(errorid))) {
tl_error(tl_team->context->lib, "incorrect unique id");
goto free_unique_id;
goto err;
}

CUDA_CHECK_GOTO(cudaStreamCreateWithFlags(&team->stream,
cudaStreamNonBlocking), free_unique_id, status);
nccl_status = ncclCommInitRank(&team->nccl_comm, UCC_TL_TEAM_SIZE(team),
team->unique_id[0], UCC_TL_TEAM_RANK(team));
if (nccl_status != ncclSuccess) {
tl_debug(tl_team->context->lib, "NCCL error %d %s",
nccl_status, ncclGetErrorString(nccl_status));
status = UCC_ERR_NO_MESSAGE;
goto free_stream;
if (!ctx->cfg.lazy_init) {
status = ucc_tl_nccl_comm_init(team);
if (status != UCC_OK) {
goto err;
}
}
ucc_free(team->unique_id);

tl_debug(tl_team->context->lib, "initialized tl team: %p", team);
return UCC_OK;

free_stream:
cudaStreamDestroy(team->stream);
free_unique_id:
ucc_free(team->unique_id);
err:
return status;
}

Expand Down

0 comments on commit d0e3ca4

Please sign in to comment.