diff --git a/src/init.cc b/src/init.cc index 86c73b15e..3e37abae7 100644 --- a/src/init.cc +++ b/src/init.cc @@ -425,11 +425,10 @@ NCCL_PARAM(GdrCopyFifoEnable, "GDRCOPY_FIFO_ENABLE", 1); NCCL_PARAM(WorkFifoDepth, "WORK_FIFO_DEPTH", 64<<10); enum ncclLaunchMode ncclParamLaunchMode; -NCCL_PARAM(DmaBufEnable, "DMABUF_ENABLE", 0); // Detect DMA-BUF support static ncclResult_t dmaBufSupported(struct ncclComm* comm) { - if (ncclParamDmaBufEnable() == 0 || comm->ncclNet->regMrDmaBuf == NULL || rocmLibraryInit() != ncclSuccess) return ncclInternalError; + if (comm->ncclNet->regMrDmaBuf == NULL || rocmLibraryInit() != ncclSuccess) return ncclInternalError; #if CUDA_VERSION >= 11070 int flag = 0; CUdevice dev; @@ -1727,7 +1726,7 @@ constexpr nvtxPayloadSchemaEntry_t CommInitRankSchema[] = { NCCL_API(ncclResult_t, ncclCommInitRank, ncclComm_t* newcomm, int nranks, ncclUniqueId commId, int myrank); ncclResult_t ncclCommInitRank(ncclComm_t* newcomm, int nranks, ncclUniqueId commId, int myrank) { // Load the CUDA driver and dlsym hooks (can fail on old drivers) - if (ncclParamDmaBufEnable()) rocmLibraryInit(); + rocmLibraryInit(); int cudaDev; ncclConfig_t config = NCCL_CONFIG_INITIALIZER; @@ -1743,7 +1742,7 @@ ncclResult_t ncclCommInitRank(ncclComm_t* newcomm, int nranks, ncclUniqueId comm NCCL_API(ncclResult_t, ncclCommInitRankMulti, ncclComm_t* newcomm, int nranks, ncclUniqueId commId, int myrank, int virtualId); ncclResult_t ncclCommInitRankMulti(ncclComm_t* newcomm, int nranks, ncclUniqueId commId, int myrank, int virtualId) { // Load the CUDA driver and dlsym hooks (can fail on old drivers) - if (ncclParamDmaBufEnable()) rocmLibraryInit(); + rocmLibraryInit(); int cudaDev; ncclConfig_t config = NCCL_CONFIG_INITIALIZER; @@ -1770,7 +1769,7 @@ ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) { NVTX3_FUNC_WITH_PARAMS(CommInitAll, CommInitAllSchema, ndev) // Load the CUDA driver and dlsym hooks (can fail on old drivers) - if (ncclParamDmaBufEnable()) (void) rocmLibraryInit(); + rocmLibraryInit(); NCCLCHECKGOTO(PtrCheck(comms, "CommInitAll", "comms"), ret, fail); if (ndev < 0) { @@ -1834,7 +1833,7 @@ ncclResult_t ncclCommInitRankConfig(ncclComm_t *newcomm, int nranks, ncclUniqueI ncclConfig_t *internalConfigPtr = NULL; NCCLCHECK(ncclGroupStartInternal()); - if (ncclParamDmaBufEnable()) (void) rocmLibraryInit(); + rocmLibraryInit(); CUDACHECKGOTO(cudaGetDevice(&cudaDev), ret, fail); if (config == NULL) diff --git a/src/misc/rocmwrap.cc b/src/misc/rocmwrap.cc index e32038955..76d368891 100644 --- a/src/misc/rocmwrap.cc +++ b/src/misc/rocmwrap.cc @@ -9,6 +9,7 @@ #include "debug.h" #include "rocmwrap.h" #include "hsa/hsa.h" +#include "param.h" #include #include @@ -17,7 +18,7 @@ #define DECLARE_ROCM_PFN(symbol) PFN_##symbol pfn_##symbol = nullptr DECLARE_ROCM_PFN(hsa_amd_portable_export_dmabuf); // DMA-BUF support - +NCCL_PARAM(DmaBufEnable, "DMABUF_ENABLE", 0); /* ROCr Driver functions loaded with dlsym() */ DECLARE_ROCM_PFN(hsa_init); DECLARE_ROCM_PFN(hsa_system_get_info); @@ -28,7 +29,6 @@ static enum { hsaUninitialized, hsaInitializing, hsaInitialized, hsaError } hsaS static void *hsaLib; static uint16_t version_major, version_minor; bool ncclCudaLaunchBlocking = false; -bool dmaBufSupport = false; ncclResult_t rocmLibraryInit(void) { do { @@ -36,6 +36,7 @@ ncclResult_t rocmLibraryInit(void) { ncclCudaLaunchBlocking = val!=nullptr && val[0]!=0 && !(val[0]=='0' && val[1]==0); } while (0); + bool dmaBufSupport = false; hsa_status_t res; if (hsaState == hsaInitialized) @@ -108,14 +109,21 @@ ncclResult_t rocmLibraryInit(void) { /* DMA-BUF support */ //ROCm support + if (ncclParamDmaBufEnable() == 0 ) { + INFO(NCCL_INIT, "Dmabuf feature disabled without NCCL_ENABLE_DMABUF_SUPPORT=1"); + goto error; + } res = pfn_hsa_system_get_info((hsa_system_info_t) 0x204, &dmaBufSupport); - if (res != HSA_STATUS_SUCCESS || !dmaBufSupport) INFO(NCCL_INIT, "Current version of ROCm does not support dmabuf feature."); + if (res != HSA_STATUS_SUCCESS || !dmaBufSupport) { + INFO(NCCL_INIT, "Current version of ROCm does not support dmabuf feature."); + goto error; + } else { pfn_hsa_amd_portable_export_dmabuf = (PFN_hsa_amd_portable_export_dmabuf) dlsym(hsaLib, "hsa_amd_portable_export_dmabuf"); if (pfn_hsa_amd_portable_export_dmabuf == NULL) { WARN("Failed to load ROCr missing symbol hsa_amd_portable_export_dmabuf"); goto error; - } + } else { //check OS kernel support struct utsname utsname; @@ -126,7 +134,7 @@ ncclResult_t rocmLibraryInit(void) { char buf[256]; int found_opt1 = 0; int found_opt2 = 0; - + //check for kernel name exists if (uname(&utsname) == -1) INFO(NCCL_INIT,"Could not get kernel name"); //format and store the kernel conf file location diff --git a/src/transport/net_ib.cc b/src/transport/net_ib.cc index a01f39113..a2342fb82 100644 --- a/src/transport/net_ib.cc +++ b/src/transport/net_ib.cc @@ -319,6 +319,7 @@ ncclResult_t ncclIbDmaBufSupport(int dev) { static int dmaBufSupported = -1; if (dmaBufSupported == -1) { ncclResult_t res; + NCCLCHECKGOTO(rocmLibraryInit(), res, failure); struct ibv_pd* pd; struct ibv_context* ctx; ctx = ncclIbDevs[dev].context;