tl/cuda: let ranks without a usable CUDA device take part in team creation

A rank that has no active CUDA context, or whose cudaGetDevice() call fails,
no longer returns early from ucc_tl_cuda_comm_init_post(). It stores
TL_CUDA_DEVICE_INVALID as its PCI id, sets the team state to
TL_CUDA_STATE_ERROR and jumps to the rank-id allgather.
ucc_tl_cuda_comm_init_test() then scans the gathered ids and returns
UCC_ERR_NO_RESOURCE if any rank published the invalid id.

Related changes in this patch:
  - team->device is initialized to -1 and UCC_TL_CUDA_CHECK_DEVICE_MATCH
    skips the comparison while it still holds that value;
  - the team scratch structure is zeroed with memset() at init;
  - the ev_handle member is removed from ucc_tl_cuda_sync;
  - tl/nccl initializes nccl_comm to NULL at team init.

diff --git a/src/components/tl/cuda/tl_cuda.h b/src/components/tl/cuda/tl_cuda.h
index d99a49b16c..6237694749 100644
--- a/src/components/tl/cuda/tl_cuda.h
+++ b/src/components/tl/cuda/tl_cuda.h
@@ -142,7 +142,6 @@ typedef struct ucc_tl_cuda_sync {
     ucc_tl_cuda_mem_info_t mem_info_src;
     ucc_tl_cuda_mem_info_t mem_info_dst;
     cudaEvent_t ipc_event_local;
-    cudaIpcEventHandle_t ev_handle;
     union {
         struct {
             size_t sbytes[UCC_TL_CUDA_MAX_PEERS];
diff --git a/src/components/tl/cuda/tl_cuda_coll.h b/src/components/tl/cuda/tl_cuda_coll.h
index 1588433905..3096ac5a7d 100644
--- a/src/components/tl/cuda/tl_cuda_coll.h
+++ b/src/components/tl/cuda/tl_cuda_coll.h
@@ -20,7 +20,7 @@ extern const char
 #define UCC_TL_CUDA_CHECK_DEVICE_MATCH(_team) do { \
     int _dev; \
     CUDA_CHECK(cudaGetDevice(&_dev)); \
-    if (_dev != (_team)->device) { \
+    if (((_team)->device != -1) && _dev != (_team)->device) { \
         tl_error(UCC_TL_TEAM_LIB(_team), "CUDA device mismatch, " \
                  "current device %d, team device %d\n", _dev, \
                  (_team)->device); \
diff --git a/src/components/tl/cuda/tl_cuda_team.c b/src/components/tl/cuda/tl_cuda_team.c
index 17154858ae..b53661eae8 100644
--- a/src/components/tl/cuda/tl_cuda_team.c
+++ b/src/components/tl/cuda/tl_cuda_team.c
@@ -38,13 +38,17 @@ ucc_status_t ucc_tl_cuda_comm_init_post(ucc_tl_cuda_team_t *team)
     if (cu_ctx == NULL || cu_st != CUDA_SUCCESS) {
         tl_debug(tl_lib,
                  "cannot create CUDA TL team without active CUDA context");
-        return UCC_ERR_NO_RESOURCE;
+        team->device_id = TL_CUDA_DEVICE_INVALID;
+        team->state = TL_CUDA_STATE_ERROR;
+        goto exchange_rank_ids;
     }
 
     status = CUDA_FUNC(cudaGetDevice(&team->device));
     if (status != UCC_OK) {
         tl_debug(tl_lib, "failed to get current device id");
-        return status;
+        team->device_id = TL_CUDA_DEVICE_INVALID;
+        team->state = TL_CUDA_STATE_ERROR;
+        goto exchange_rank_ids;
     }
 
     status = ucc_tl_cuda_topo_get_pci_id(team->device, &team->device_id);
@@ -88,6 +92,7 @@ ucc_status_t ucc_tl_cuda_comm_init_post(ucc_tl_cuda_team_t *team)
         goto free_scratch;
     }
 
+exchange_rank_ids:
     rank_id->pci_id = team->device_id;
     status = team->oob.allgather(rank_id, team->ids, rank_id_size,
                                  team->oob.coll_info, &team->oob_req);
@@ -127,6 +132,17 @@ ucc_status_t ucc_tl_cuda_comm_init_test(ucc_tl_cuda_team_t *team)
         return status;
     }
     team->oob.req_free(team->oob_req);
+    /* check all ranks have valid CUDA device set */
+    for (r = 0; r < tsize; r++) {
+        rank_id = GET_RANK_ID(team->ids, r, max_concurrent);
+        if (ucc_tl_cuda_topo_device_id_equal(&rank_id->pci_id,
+                                             &TL_CUDA_DEVICE_INVALID)) {
+            tl_debug(tl_lib, "rank %d device is invalid, team can't be created",
+                     r);
+            team->state = TL_CUDA_STATE_ERROR;
+            return UCC_ERR_NO_RESOURCE;
+        }
+    }
 
     status = ucc_tl_cuda_team_topo_create(&team->super, &team->topo);
     if (status != UCC_OK) {
@@ -230,10 +246,11 @@ UCC_CLASS_INIT_FUNC(ucc_tl_cuda_team_t, ucc_base_context_t *tl_context,
     ucc_tl_cuda_rank_id_t *rank_id;
 
     UCC_CLASS_CALL_SUPER_INIT(ucc_tl_team_t, &ctx->super, params);
-    self->oob         = params->params.oob;
-    self->stream      = NULL;
-    self->topo        = NULL;
-    self->scratch.loc = NULL;
+    self->oob    = params->params.oob;
+    self->stream = NULL;
+    self->topo   = NULL;
+    self->device = -1;
+    memset(&self->scratch, 0, sizeof(ucc_tl_cuda_scratch_t));
 
     if (!ucc_team_map_is_single_node(params->team, params->map)) {
         tl_debug(tl_context->lib, "multinode team is not supported");
diff --git a/src/components/tl/cuda/tl_cuda_topo.c b/src/components/tl/cuda/tl_cuda_topo.c
index 96862e921e..031ecfdc21 100644
--- a/src/components/tl/cuda/tl_cuda_topo.c
+++ b/src/components/tl/cuda/tl_cuda_topo.c
@@ -25,6 +25,13 @@ pthread_mutex_t nvml_lock = PTHREAD_MUTEX_INITIALIZER;
     } \
 } while (0)
 
+const ucc_tl_cuda_device_pci_id_t TL_CUDA_DEVICE_INVALID = {
+    .domain   = 0xFFFF,
+    .bus      = 0xFF,
+    .device   = 0xFF,
+    .function = 0xFF,
+};
+
 static ucc_status_t
 ucc_tl_cuda_topo_pci_id_from_str(const char * bus_id_str,
                                  ucc_tl_cuda_device_pci_id_t *pci_id)
diff --git a/src/components/tl/cuda/tl_cuda_topo.h b/src/components/tl/cuda/tl_cuda_topo.h
index 7d27d236b5..9823ecbcc3 100644
--- a/src/components/tl/cuda/tl_cuda_topo.h
+++ b/src/components/tl/cuda/tl_cuda_topo.h
@@ -19,6 +19,8 @@ typedef struct ucc_tl_cuda_device_id {
     uint8_t function; /* range: 0 to 7 */
 } ucc_tl_cuda_device_pci_id_t;
 
+extern const ucc_tl_cuda_device_pci_id_t TL_CUDA_DEVICE_INVALID;
+
 typedef enum ucc_tl_cuda_topo_dev_type {
     UCC_TL_CUDA_TOPO_DEV_TYPE_GPU,
     UCC_TL_CUDA_TOPO_DEV_TYPE_SWITCH,
diff --git a/src/components/tl/nccl/tl_nccl_team.c b/src/components/tl/nccl/tl_nccl_team.c
index 8231c2359f..e83f9df423 100644
--- a/src/components/tl/nccl/tl_nccl_team.c
+++ b/src/components/tl/nccl/tl_nccl_team.c
@@ -60,6 +60,7 @@ UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context,
     UCC_CLASS_CALL_SUPER_INIT(ucc_tl_team_t, &ctx->super, params);
 
     size = UCC_TL_TEAM_SIZE(self);
+    self->nccl_comm = NULL;
     self->comm_state = TL_NCCL_COMM_STATE_INIT;
     self->unique_id = ucc_malloc(sizeof(ncclUniqueId) * (size + 1),
                                  "tl_nccl_unique_id");