Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
minmingzhu committed Sep 10, 2024
1 parent 00801bc commit a5913a7
Show file tree
Hide file tree
Showing 9 changed files with 23 additions and 27 deletions.
29 changes: 13 additions & 16 deletions mllib-dal/src/main/native/OneCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ ccl::communicator &getComm() { return g_comms[0]; }
ccl::shared_ptr_class<ccl::kvs> &getKvs() { return g_kvs[0]; }

JNIEXPORT jint JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1init(
JNIEnv *env, jobject obj, jint size, jint rank, jstring ip_port, jint computeDeviceOrdinal,
JNIEnv *env, jobject obj, jint size, jint rank, jstring ip_port,
jobject param) {

logger::println(logger::INFO, "OneCCL (native): init");
Expand All @@ -64,21 +64,18 @@ JNIEXPORT jint JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1init(
auto &singletonCCLInit = CCLInitSingleton::get(size, rank, ccl_ip_port);

g_kvs.push_back(singletonCCLInit.kvs);
ComputeDevice device = getComputeDeviceByOrdinal(computeDeviceOrdinal);
switch (device) {
case ComputeDevice::host:
case ComputeDevice::cpu: {
g_comms.push_back(
ccl::create_communicator(size, rank, singletonCCLInit.kvs));

auto t2 = std::chrono::high_resolution_clock::now();
auto duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
.count();
logger::println(logger::INFO, "OneCCL (native): init took %f secs",
duration / 1000);
break;
}

#ifdef CPU_ONLY_PROFILE
g_comms.push_back(
ccl::create_communicator(size, rank, singletonCCLInit.kvs));

auto t2 = std::chrono::high_resolution_clock::now();
auto duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
.count();
logger::println(logger::INFO, "OneCCL (native): init took %f secs",
duration / 1000);
#endif

jclass cls = env->GetObjectClass(param);
jfieldID fid_comm_size = env->GetFieldID(cls, "commSize", "J");
Expand Down
7 changes: 3 additions & 4 deletions mllib-dal/src/main/scala/com/intel/oap/mllib/OneCCL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ object OneCCL extends Logging {

var cclParam = new CCLParam()

def init(executor_num: Int, rank: Int, ip_port: String, computeDeviceOrdinal: Int): Unit = {
def init(executor_num: Int, rank: Int, ip_port: String): Unit = {

setExecutorEnv()

logInfo(s"Initializing with IP_PORT: ${ip_port}")

// cclParam is output from native code
c_init(executor_num, rank, ip_port, computeDeviceOrdinal, cclParam)
c_init(executor_num, rank, ip_port, cclParam)

// executor number should equal to oneCCL world size
assert(executor_num == cclParam.getCommSize,
Expand Down Expand Up @@ -67,8 +67,7 @@ object OneCCL extends Logging {

@native def c_getAvailPort(localIP: String): Int

@native private def c_init(size: Int, rank: Int, ip_port: String,
computeDeviceOrdinal: Int, param: CCLParam): Int
@native private def c_init(size: Int, rank: Int, ip_port: String, param: CCLParam): Int

@native private def c_cleanup(): Unit
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class RandomForestClassifierDALImpl(val uid: String,
val kvsIPPort = getOneCCLIPPort(labeledPointsTables)

labeledPointsTables.mapPartitionsWithIndex { (rank, table) =>
OneCCL.init(executorNum, rank, kvsIPPort, computeDevice.ordinal())
OneCCL.init(executorNum, rank, kvsIPPort)
Iterator.empty
}.count()
rfcTimer.record("OneCCL Init")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class KMeansDALImpl(var nClusters: Int,
val kvsIPPort = getOneCCLIPPort(coalescedTables)

coalescedTables.mapPartitionsWithIndex { (rank, table) =>
OneCCL.init(executorNum, rank, kvsIPPort, computeDevice.ordinal())
OneCCL.init(executorNum, rank, kvsIPPort)
Iterator.empty
}.count()
kmeansTimer.record("OneCCL Init")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class PCADALImpl(val k: Int,
pcaTimer.record("Data Convertion")

coalescedTables.mapPartitionsWithIndex { (rank, table) =>
OneCCL.init(executorNum, rank, kvsIPPort, computeDevice.ordinal())
OneCCL.init(executorNum, rank, kvsIPPort)
Iterator.empty
}.count()
pcaTimer.record("OneCCL Init")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class LinearRegressionDALImpl( val fitIntercept: Boolean,
(label.toString.toLong, 0L, 0L)
}

OneCCL.init(executorNum, rank, kvsIPPort, computeDevice.ordinal())
OneCCL.init(executorNum, rank, kvsIPPort)
val result = new LiRResult()

val gpuIndices = if (useDevice == "GPU") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class RandomForestRegressorDALImpl(val uid: String,
val kvsIPPort = getOneCCLIPPort(labeledPointsTables)

labeledPointsTables.mapPartitionsWithIndex { (rank, table) =>
OneCCL.init(executorNum, rank, kvsIPPort, computeDevice.ordinal())
OneCCL.init(executorNum, rank, kvsIPPort)
Iterator.empty
}.count()
rfrTimer.record("OneCCL Init")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class CorrelationDALImpl(
val kvsIPPort = getOneCCLIPPort(coalescedTables)

coalescedTables.mapPartitionsWithIndex { (rank, table) =>
OneCCL.init(executorNum, rank, kvsIPPort, computeDevice.ordinal())
OneCCL.init(executorNum, rank, kvsIPPort)
Iterator.empty
}.count()
corTimer.record("OneCCL Init")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class SummarizerDALImpl(val executorNum: Int,
val kvsIPPort = getOneCCLIPPort(data)

coalescedTables.mapPartitionsWithIndex { (rank, table) =>
OneCCL.init(executorNum, rank, kvsIPPort, computeDevice.ordinal())
OneCCL.init(executorNum, rank, kvsIPPort)
Iterator.empty
}.count()
sumTimer.record("OneCCL Init")
Expand Down

0 comments on commit a5913a7

Please sign in to comment.