Skip to content

Commit

Permalink
TL/CUDA: check devices
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergei-Lebedev committed Mar 29, 2023
1 parent 915c5ac commit 26c36ac
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 8 deletions.
1 change: 0 additions & 1 deletion src/components/tl/cuda/tl_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ typedef struct ucc_tl_cuda_sync {
ucc_tl_cuda_mem_info_t mem_info_src;
ucc_tl_cuda_mem_info_t mem_info_dst;
cudaEvent_t ipc_event_local;
cudaIpcEventHandle_t ev_handle;
union {
struct {
size_t sbytes[UCC_TL_CUDA_MAX_PEERS];
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/cuda/tl_cuda_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ extern const char
#define UCC_TL_CUDA_CHECK_DEVICE_MATCH(_team) do { \
int _dev; \
CUDA_CHECK(cudaGetDevice(&_dev)); \
if (_dev != (_team)->device) { \
if (((_team)->device != -1) && _dev != (_team)->device) { \
tl_error(UCC_TL_TEAM_LIB(_team), "CUDA device mismatch, " \
"current device %d, team device %d\n", _dev, \
(_team)->device); \
Expand Down
29 changes: 23 additions & 6 deletions src/components/tl/cuda/tl_cuda_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,17 @@ ucc_status_t ucc_tl_cuda_comm_init_post(ucc_tl_cuda_team_t *team)
if (cu_ctx == NULL || cu_st != CUDA_SUCCESS) {
tl_debug(tl_lib,
"cannot create CUDA TL team without active CUDA context");
return UCC_ERR_NO_RESOURCE;
team->device_id = TL_CUDA_DEVICE_INVALID;
team->state = TL_CUDA_STATE_ERROR;
goto exchnage_rank_ids;
}

status = CUDA_FUNC(cudaGetDevice(&team->device));
if (status != UCC_OK) {
tl_debug(tl_lib, "failed to get current device id");
return status;
team->device_id = TL_CUDA_DEVICE_INVALID;
team->state = TL_CUDA_STATE_ERROR;
goto exchnage_rank_ids;
}

status = ucc_tl_cuda_topo_get_pci_id(team->device, &team->device_id);
Expand Down Expand Up @@ -88,6 +92,7 @@ ucc_status_t ucc_tl_cuda_comm_init_post(ucc_tl_cuda_team_t *team)
goto free_scratch;
}

exchnage_rank_ids:
rank_id->pci_id = team->device_id;
status = team->oob.allgather(rank_id, team->ids, rank_id_size,
team->oob.coll_info, &team->oob_req);
Expand Down Expand Up @@ -127,6 +132,17 @@ ucc_status_t ucc_tl_cuda_comm_init_test(ucc_tl_cuda_team_t *team)
return status;
}
team->oob.req_free(team->oob_req);
/* check all ranks have valid CUDA device set */
for (r = 0; r < tsize; r++) {
rank_id = GET_RANK_ID(team->ids, r, max_concurrent);
if (ucc_tl_cuda_topo_device_id_equal(&rank_id->pci_id,
&TL_CUDA_DEVICE_INVALID)) {
tl_debug(tl_lib, "rank %d device is invalid, team can't be created",
r);
team->state = TL_CUDA_STATE_ERROR;
return UCC_ERR_NO_RESOURCE;
}
}

status = ucc_tl_cuda_team_topo_create(&team->super, &team->topo);
if (status != UCC_OK) {
Expand Down Expand Up @@ -230,10 +246,11 @@ UCC_CLASS_INIT_FUNC(ucc_tl_cuda_team_t, ucc_base_context_t *tl_context,
ucc_tl_cuda_rank_id_t *rank_id;

UCC_CLASS_CALL_SUPER_INIT(ucc_tl_team_t, &ctx->super, params);
self->oob = params->params.oob;
self->stream = NULL;
self->topo = NULL;
self->scratch.loc = NULL;
self->oob = params->params.oob;
self->stream = NULL;
self->topo = NULL;
self->device = -1;
memset(&self->scratch, 0, sizeof(ucc_tl_cuda_scratch_t));

if (!ucc_team_map_is_single_node(params->team, params->map)) {
tl_debug(tl_context->lib, "multinode team is not supported");
Expand Down
7 changes: 7 additions & 0 deletions src/components/tl/cuda/tl_cuda_topo.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ pthread_mutex_t nvml_lock = PTHREAD_MUTEX_INITIALIZER;
} \
} while (0)

const ucc_tl_cuda_device_pci_id_t TL_CUDA_DEVICE_INVALID = {
.domain = 0xFFFF,
.bus = 0xFF,
.device = 0xFF,
.function = 0xFF,
};

static ucc_status_t
ucc_tl_cuda_topo_pci_id_from_str(const char * bus_id_str,
ucc_tl_cuda_device_pci_id_t *pci_id)
Expand Down
2 changes: 2 additions & 0 deletions src/components/tl/cuda/tl_cuda_topo.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ typedef struct ucc_tl_cuda_device_id {
uint8_t function; /* range: 0 to 7 */
} ucc_tl_cuda_device_pci_id_t;

extern const ucc_tl_cuda_device_pci_id_t TL_CUDA_DEVICE_INVALID;

typedef enum ucc_tl_cuda_topo_dev_type {
UCC_TL_CUDA_TOPO_DEV_TYPE_GPU,
UCC_TL_CUDA_TOPO_DEV_TYPE_SWITCH,
Expand Down
1 change: 1 addition & 0 deletions src/components/tl/nccl/tl_nccl_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context,
UCC_CLASS_CALL_SUPER_INIT(ucc_tl_team_t, &ctx->super, params);

size = UCC_TL_TEAM_SIZE(self);
self->nccl_comm = NULL;
self->comm_state = TL_NCCL_COMM_STATE_INIT;
self->unique_id = ucc_malloc(sizeof(ncclUniqueId) * (size + 1),
"tl_nccl_unique_id");
Expand Down

0 comments on commit 26c36ac

Please sign in to comment.