Skip to content

Commit

Permalink
Fix error when running with CPU (#403)
Browse files Browse the repository at this point in the history
* update spark to 3.3.3

Signed-off-by: minmingzhu <minming.zhu@intel.com>

* fix cpu error

* update

* update

* format code style

* update

* update

* update

* update

* update

---------

Signed-off-by: minmingzhu <minming.zhu@intel.com>
  • Loading branch information
minmingzhu authored Nov 14, 2024
1 parent 6b452d3 commit 4b79dfd
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 64 deletions.
11 changes: 9 additions & 2 deletions mllib-dal/src/main/native/CCLInitSingleton.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,21 @@ class CCLInitSingleton {
auto t1 = std::chrono::high_resolution_clock::now();

ccl::init();
auto t2 = std::chrono::high_resolution_clock::now();
auto duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
.count();
logger::println(logger::INFO, "OneCCL (native): init took %f secs",
duration / 1000);

t1 = std::chrono::high_resolution_clock::now();
auto kvs_attr = ccl::create_kvs_attr();
kvs_attr.set<ccl::kvs_attr_id::ip_port>(ccl_ip_port);

kvs = ccl::create_main_kvs(kvs_attr);

auto t2 = std::chrono::high_resolution_clock::now();
auto duration =
t2 = std::chrono::high_resolution_clock::now();
duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1).count();

logger::println(logger::INFO, "OneCCL singleton init took %f secs",
Expand Down
106 changes: 51 additions & 55 deletions mllib-dal/src/main/native/OneCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,64 +58,59 @@ getDalComm() {
#endif
JNIEXPORT jint JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1init(
JNIEnv *env, jobject obj, jint size, jint rank, jstring ip_port,
jobject param) {
jint computeDeviceOrdinal, jobject param) {

logger::println(logger::INFO, "OneCCL (native): init");

auto t1 = std::chrono::high_resolution_clock::now();

ccl::init();
auto t2 = std::chrono::high_resolution_clock::now();
auto duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
.count();
logger::println(logger::INFO, "OneCCL (native): init took %f secs",
duration / 1000);
const char *str = env->GetStringUTFChars(ip_port, 0);
ccl::string ccl_ip_port(str);

#ifdef CPU_ONLY_PROFILE
auto &singletonCCLInit = CCLInitSingleton::get(size, rank, ccl_ip_port);

g_kvs.push_back(singletonCCLInit.kvs);
g_comms.push_back(
ccl::create_communicator(size, rank, singletonCCLInit.kvs));

rank_id = getComm().rank();
comm_size = getComm().size();

#endif

ComputeDevice device = getComputeDeviceByOrdinal(computeDeviceOrdinal);
switch (device) {
case ComputeDevice::host:
case ComputeDevice::cpu: {
auto t1 = std::chrono::high_resolution_clock::now();
g_comms.push_back(
ccl::create_communicator(size, rank, singletonCCLInit.kvs));
auto t2 = std::chrono::high_resolution_clock::now();
auto duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 -
t1)
.count();
logger::println(logger::INFO,
"OneCCL (native): create communicator took %f secs",
duration / 1000);
rank_id = getComm().rank();
comm_size = getComm().size();
break;
}
#ifdef CPU_GPU_PROFILE
t1 = std::chrono::high_resolution_clock::now();
auto kvs_attr = ccl::create_kvs_attr();

kvs_attr.set<ccl::kvs_attr_id::ip_port>(ccl_ip_port);

ccl::shared_ptr_class<ccl::kvs> kvs = ccl::create_main_kvs(kvs_attr);

t2 = std::chrono::high_resolution_clock::now();
duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
.count();
logger::println(logger::INFO, "OneCCL (native): create kvs took %f secs",
duration / 1000);
auto gpus = get_gpus();
sycl::queue queue{gpus[0]};
t1 = std::chrono::high_resolution_clock::now();
auto comm = oneapi::dal::preview::spmd::make_communicator<
oneapi::dal::preview::spmd::backend::ccl>(queue, size, rank, kvs);
t2 = std::chrono::high_resolution_clock::now();
duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
.count();
logger::println(logger::INFO,
"OneCCL (native): create communicator took %f secs",
duration / 1000);
g_dal_comms.push_back(comm);
rank_id = getDalComm().get_rank();
comm_size = getDalComm().get_rank_count();
case ComputeDevice::gpu: {
auto gpus = get_gpus();
sycl::queue queue{gpus[0]};
auto t1 = std::chrono::high_resolution_clock::now();
auto comm = oneapi::dal::preview::spmd::make_communicator<
oneapi::dal::preview::spmd::backend::ccl>(queue, size, rank,
singletonCCLInit.kvs);
auto t2 = std::chrono::high_resolution_clock::now();
auto duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 -
t1)
.count();
logger::println(logger::INFO,
"OneCCL (native): create communicator took %f secs",
duration / 1000);
g_dal_comms.push_back(comm);
rank_id = getDalComm().get_rank();
comm_size = getDalComm().get_rank_count();
break;
}
#endif
default: {
deviceError("communicator",
ComputeDeviceString[computeDeviceOrdinal].c_str());
}
}
jclass cls = env->GetObjectClass(param);
jfieldID fid_comm_size = env->GetFieldID(cls, "commSize", "J");
jfieldID fid_rank_id = env->GetFieldID(cls, "rankId", "J");
Expand All @@ -130,12 +125,13 @@ JNIEXPORT jint JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1init(
JNIEXPORT void JNICALL
Java_com_intel_oap_mllib_OneCCL_00024_c_1cleanup(JNIEnv *env, jobject obj) {
logger::printerrln(logger::INFO, "OneCCL (native): cleanup");
#ifdef CPU_ONLY_PROFILE
g_kvs.pop_back();
g_comms.pop_back();
#endif
if (!g_comms.empty()) {
g_comms.pop_back();
}
#ifdef CPU_GPU_PROFILE
g_dal_comms.pop_back();
if (!g_dal_comms.empty()) {
g_dal_comms.pop_back();
}
#endif
}

Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 7 additions & 3 deletions mllib-dal/src/main/scala/com/intel/oap/mllib/CommonJob.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.intel.oap.mllib

import com.intel.oneapi.dal.table.Common
import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD

Expand All @@ -25,14 +26,17 @@ object CommonJob {
kvsIPPort: String,
useDevice: String): Unit = {
data.mapPartitionsWithIndex { (rank, table) =>
OneCCL.init(executorNum, rank, kvsIPPort)
OneCCL.init(executorNum, rank, kvsIPPort,
Common.ComputeDevice.getDeviceByName(useDevice).ordinal())
val gpuIndices = if (useDevice == "GPU") {
val resources = TaskContext.get().resources()
resources("gpu").addresses.map(_.toInt)
} else {
null
Array.empty[Int]
}
if (gpuIndices.nonEmpty) {
OneCCL.setAffinityMask(gpuIndices(0).toString())
}
OneCCL.setAffinityMask(gpuIndices(0).toString())
Iterator.empty
}.count()
}
Expand Down
9 changes: 6 additions & 3 deletions mllib-dal/src/main/scala/com/intel/oap/mllib/OneCCL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.intel.oap.mllib

import com.intel.oneapi.dal.table.Common
import org.apache.spark.internal.Logging

object OneCCL extends Logging {
Expand All @@ -24,12 +25,13 @@ object OneCCL extends Logging {

var cclParam = new CCLParam()

def init(executor_num: Int, rank: Int, ip_port: String): Unit = {
def init(executor_num: Int, rank: Int, ip_port: String,
computeDevice: Int = Common.ComputeDevice.CPU.ordinal()): Unit = {

logInfo(s"Initializing with IP_PORT: ${ip_port}")

// cclParam is output from native code
c_init(executor_num, rank, ip_port, cclParam)
c_init(executor_num, rank, ip_port, computeDevice, cclParam)

// executor number should equal to oneCCL world size
assert(executor_num == cclParam.getCommSize,
Expand Down Expand Up @@ -61,7 +63,8 @@ object OneCCL extends Logging {

@native def c_getAvailPort(localIP: String): Int

@native private def c_init(size: Int, rank: Int, ip_port: String, param: CCLParam): Int
@native private def c_init(size: Int, rank: Int, ip_port: String,
computeDevice: Int, param: CCLParam): Int

@native private def c_cleanup(): Unit
}

0 comments on commit 4b79dfd

Please sign in to comment.