Skip to content

Commit

Permalink
remove oneccl communicator
Browse files Browse the repository at this point in the history
  • Loading branch information
minmingzhu committed Sep 10, 2024
1 parent 0a44fbe commit 00801bc
Show file tree
Hide file tree
Showing 10 changed files with 37 additions and 25 deletions.
35 changes: 20 additions & 15 deletions mllib-dal/src/main/native/OneCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ ccl::communicator &getComm() { return g_comms[0]; }
ccl::shared_ptr_class<ccl::kvs> &getKvs() { return g_kvs[0]; }

JNIEXPORT jint JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1init(
JNIEnv *env, jobject obj, jint size, jint rank, jstring ip_port,
JNIEnv *env, jobject obj, jint size, jint rank, jstring ip_port, jint computeDeviceOrdinal,
jobject param) {

logger::println(logger::INFO, "OneCCL (native): init");
Expand All @@ -58,29 +58,34 @@ JNIEXPORT jint JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1init(

const char *str = env->GetStringUTFChars(ip_port, 0);
ccl::string ccl_ip_port(str);
const char *device = env->GetStringUTFChars(use_device, 0);
ccl::string ccl_ip_port(str);

auto &singletonCCLInit = CCLInitSingleton::get(size, rank, ccl_ip_port);

g_kvs.push_back(singletonCCLInit.kvs);

#ifdef CPU_ONLY_PROFILE
g_comms.push_back(
ccl::create_communicator(size, rank, singletonCCLInit.kvs));

auto t2 = std::chrono::high_resolution_clock::now();
auto duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
.count();
logger::println(logger::INFO, "OneCCL (native): init took %f secs",
duration / 1000);
#endif
ComputeDevice device = getComputeDeviceByOrdinal(computeDeviceOrdinal);
switch (device) {
case ComputeDevice::host:
case ComputeDevice::cpu: {
g_comms.push_back(
ccl::create_communicator(size, rank, singletonCCLInit.kvs));

auto t2 = std::chrono::high_resolution_clock::now();
auto duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
.count();
logger::println(logger::INFO, "OneCCL (native): init took %f secs",
duration / 1000);
break;
}

jclass cls = env->GetObjectClass(param);
jfieldID fid_comm_size = env->GetFieldID(cls, "commSize", "J");
jfieldID fid_rank_id = env->GetFieldID(cls, "rankId", "J");

env->SetLongField(param, fid_comm_size, comm_size);
env->SetLongField(param, fid_rank_id, rank_id);
env->SetLongField(param, size, comm_size);
env->SetLongField(param, rank, rank_id);
env->ReleaseStringUTFChars(ip_port, str);

return 1;
Expand Down
6 changes: 6 additions & 0 deletions mllib-dal/src/main/native/SummarizerImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,15 @@ static void doSummarizerOneAPICompute(

JNIEXPORT jlong JNICALL
Java_com_intel_oap_mllib_stat_SummarizerDALImpl_cSummarizerTrainDAL(
<<<<<<< HEAD
JNIEnv *env, jobject obj, jint rank, jlong pNumTabData, jlong numRows,
jlong numCols, jint executorNum, jint executorCores,
jint computeDeviceOrdinal, jintArray gpuIdxArray, jobject resultObj) {
=======
JNIEnv *env, jobject obj, jint rank, jlong pNumTabData, jlong numRows, jlong numCols,
jint executorNum, jint executorCores, jint computeDeviceOrdinal,
jintArray gpuIdxArray, jobject resultObj) {
>>>>>>> remove oneccl communicator
logger::println(logger::INFO,
"oneDAL (native): use DPC++ kernels; device %s",
ComputeDeviceString[computeDeviceOrdinal].c_str());
Expand Down
7 changes: 4 additions & 3 deletions mllib-dal/src/main/scala/com/intel/oap/mllib/OneCCL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ object OneCCL extends Logging {

var cclParam = new CCLParam()

def init(executor_num: Int, rank: Int, ip_port: String): Unit = {
def init(executor_num: Int, rank: Int, ip_port: String, computeDeviceOrdinal: Int): Unit = {

setExecutorEnv()

logInfo(s"Initializing with IP_PORT: ${ip_port}")

// cclParam is output from native code
c_init(executor_num, rank, ip_port, cclParam)
c_init(executor_num, rank, ip_port, computeDeviceOrdinal, cclParam)

// executor number should be equal to the oneCCL world size
assert(executor_num == cclParam.getCommSize,
Expand Down Expand Up @@ -67,7 +67,8 @@ object OneCCL extends Logging {

@native def c_getAvailPort(localIP: String): Int

@native private def c_init(size: Int, rank: Int, ip_port: String, param: CCLParam): Int
@native private def c_init(size: Int, rank: Int, ip_port: String,
computeDeviceOrdinal: Int, param: CCLParam): Int

@native private def c_cleanup(): Unit
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class RandomForestClassifierDALImpl(val uid: String,
val kvsIPPort = getOneCCLIPPort(labeledPointsTables)

labeledPointsTables.mapPartitionsWithIndex { (rank, table) =>
OneCCL.init(executorNum, rank, kvsIPPort)
OneCCL.init(executorNum, rank, kvsIPPort, computeDevice.ordinal())
Iterator.empty
}.count()
rfcTimer.record("OneCCL Init")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class KMeansDALImpl(var nClusters: Int,
val kvsIPPort = getOneCCLIPPort(coalescedTables)

coalescedTables.mapPartitionsWithIndex { (rank, table) =>
OneCCL.init(executorNum, rank, kvsIPPort)
OneCCL.init(executorNum, rank, kvsIPPort, computeDevice.ordinal())
Iterator.empty
}.count()
kmeansTimer.record("OneCCL Init")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class PCADALImpl(val k: Int,
pcaTimer.record("Data Convertion")

coalescedTables.mapPartitionsWithIndex { (rank, table) =>
OneCCL.init(executorNum, rank, kvsIPPort)
OneCCL.init(executorNum, rank, kvsIPPort, computeDevice.ordinal())
Iterator.empty
}.count()
pcaTimer.record("OneCCL Init")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class LinearRegressionDALImpl( val fitIntercept: Boolean,
(label.toString.toLong, 0L, 0L)
}

OneCCL.init(executorNum, rank, kvsIPPort)
OneCCL.init(executorNum, rank, kvsIPPort, computeDevice.ordinal())
val result = new LiRResult()

val gpuIndices = if (useDevice == "GPU") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class RandomForestRegressorDALImpl(val uid: String,
val kvsIPPort = getOneCCLIPPort(labeledPointsTables)

labeledPointsTables.mapPartitionsWithIndex { (rank, table) =>
OneCCL.init(executorNum, rank, kvsIPPort)
OneCCL.init(executorNum, rank, kvsIPPort, computeDevice.ordinal())
Iterator.empty
}.count()
rfrTimer.record("OneCCL Init")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class CorrelationDALImpl(
val kvsIPPort = getOneCCLIPPort(coalescedTables)

coalescedTables.mapPartitionsWithIndex { (rank, table) =>
OneCCL.init(executorNum, rank, kvsIPPort)
OneCCL.init(executorNum, rank, kvsIPPort, computeDevice.ordinal())
Iterator.empty
}.count()
corTimer.record("OneCCL Init")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class SummarizerDALImpl(val executorNum: Int,
val kvsIPPort = getOneCCLIPPort(data)

coalescedTables.mapPartitionsWithIndex { (rank, table) =>
OneCCL.init(executorNum, rank, kvsIPPort)
OneCCL.init(executorNum, rank, kvsIPPort, computeDevice.ordinal())
Iterator.empty
}.count()
sumTimer.record("OneCCL Init")
Expand Down

0 comments on commit 00801bc

Please sign in to comment.