diff --git a/src/init.cc b/src/init.cc index 9da0098be..371204eee 100644 --- a/src/init.cc +++ b/src/init.cc @@ -299,6 +299,8 @@ static ncclResult_t commFree(ncclComm_t comm) { return ncclSuccess; } +RCCL_PARAM(AllToAllDisable, "ALLTOALL_KERNEL_DISABLE", 0); + static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) { if (ndev < 1) { WARN("invalid device count (%d) requested", ndev); @@ -362,6 +364,9 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) { // Mark channels as non initialized. for (int c=0; cchannels[c].id = -1; + comm->alltoallDisable = false; + if (rcclParamAllToAllDisable()) comm->alltoallDisable = true; + *comret = comm; return ncclSuccess; } @@ -662,7 +667,6 @@ static ncclResult_t checkCollNetSetup(struct ncclComm* comm, int rank, int collN NCCL_PARAM(CrossNic, "CROSS_NIC", 2); NCCL_PARAM(GraphDumpFileRank, "GRAPH_DUMP_FILE_RANK", 0); -RCCL_PARAM(AllToAllDisable, "ALLTOALL_KERNEL_DISABLE", 0); static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* commId) { // We use 3 AllGathers @@ -760,6 +764,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm int fullCudaCompCap; int nChannels; int gcn; + int alltoallDisable; struct ncclGraphInfo tree; struct ncclGraphInfo ring; struct ncclGraphInfo collNet; @@ -773,6 +778,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm allGather3Data[rank].gcn = comm->topo->nodes[GPU].nodes[idx].gpu.gcn; allGather3Data[rank].nChannels = comm->nChannels = treeGraph.nChannels = ringGraph.nChannels = std::min(treeGraph.nChannels, ringGraph.nChannels); + allGather3Data[rank].alltoallDisable = comm->alltoallDisable; allGather3Data[rank].tree.sameChannels = treeGraph.sameChannels; allGather3Data[rank].tree.speedIntra = treeGraph.speedIntra; allGather3Data[rank].tree.speedInter = treeGraph.speedInter; @@ -818,9 +824,11 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm struct ncclTopoRanks** allTopoRanks; NCCLCHECK(ncclCalloc(&allTopoRanks, comm->nRanks)); int gcn = allGather3Data[0].gcn; + int alltoallDisable = 0; for (int i=0; inChannels = std::min(allGather3Data[i].nChannels, comm->nChannels); treeGraph.sameChannels = std::min(allGather3Data[i].tree.sameChannels, treeGraph.sameChannels); @@ -836,6 +844,10 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm collNetGraph.speedInter = std::min(allGather3Data[i].collNet.speedInter, collNetGraph.speedInter); collNetGraph.typeIntra = std::min(allGather3Data[i].collNet.typeIntra, collNetGraph.typeIntra); } + if (comm->alltoallDisable != alltoallDisable) { + comm->alltoallDisable = alltoallDisable; + } + INFO(NCCL_INIT, "RCCL AllToAll/Scatter/Gather kernels %s", comm->alltoallDisable ? "disabled" : "enabled"); if (comm->nChannels < nChannelsOrig) { // We started duplicating channels during Preset(), so we need to move the @@ -927,9 +939,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm // Compute nChannels per peer for p2p NCCLCHECK(ncclTopoComputeP2pChannels(comm)); - comm->alltoallDisable = true; - if (rcclParamAllToAllDisable() == 0) { - comm->alltoallDisable = false; + if (!alltoallDisable) { for (int c=0; cnChannels; c++) { const int peersPerChan = (comm->nChannels >= nranks ? 1 : DIVUP(nranks, comm->nChannels)); struct ncclP2PConnect* connect = &comm->p2plist.connect; @@ -968,7 +978,6 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm connect->nsend[c] = 0; } } - INFO(NCCL_INIT, "RCCL AllToAll/Scatter/Gather kernels %s", comm->alltoallDisable ? "disabled" : "enabled"); // We should have allocated all buffers, collective fifos, ... we can // restore the affinity. affinity_restore: