diff --git a/mllib-dal/src/main/native/CCLInitSingleton.hpp b/mllib-dal/src/main/native/CCLInitSingleton.hpp index 2805f8e3f..4858a8eb3 100644 --- a/mllib-dal/src/main/native/CCLInitSingleton.hpp +++ b/mllib-dal/src/main/native/CCLInitSingleton.hpp @@ -41,14 +41,21 @@ class CCLInitSingleton { auto t1 = std::chrono::high_resolution_clock::now(); ccl::init(); + auto t2 = std::chrono::high_resolution_clock::now(); + auto duration = + (float)std::chrono::duration_cast(t2 - t1) + .count(); + logger::println(logger::INFO, "OneCCL (native): init took %f secs", + duration / 1000); + t1 = std::chrono::high_resolution_clock::now(); auto kvs_attr = ccl::create_kvs_attr(); kvs_attr.set(ccl_ip_port); kvs = ccl::create_main_kvs(kvs_attr); - auto t2 = std::chrono::high_resolution_clock::now(); - auto duration = + t2 = std::chrono::high_resolution_clock::now(); + duration = (float)std::chrono::duration_cast(t2 - t1).count(); logger::println(logger::INFO, "OneCCL singleton init took %f secs", diff --git a/mllib-dal/src/main/native/OneCCL.cpp b/mllib-dal/src/main/native/OneCCL.cpp index 587556d63..d9bdbb1be 100644 --- a/mllib-dal/src/main/native/OneCCL.cpp +++ b/mllib-dal/src/main/native/OneCCL.cpp @@ -58,64 +58,59 @@ getDalComm() { #endif JNIEXPORT jint JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1init( JNIEnv *env, jobject obj, jint size, jint rank, jstring ip_port, - jobject param) { + jint computeDeviceOrdinal, jobject param) { logger::println(logger::INFO, "OneCCL (native): init"); - - auto t1 = std::chrono::high_resolution_clock::now(); - - ccl::init(); - auto t2 = std::chrono::high_resolution_clock::now(); - auto duration = - (float)std::chrono::duration_cast(t2 - t1) - .count(); - logger::println(logger::INFO, "OneCCL (native): init took %f secs", - duration / 1000); const char *str = env->GetStringUTFChars(ip_port, 0); ccl::string ccl_ip_port(str); - -#ifdef CPU_ONLY_PROFILE auto &singletonCCLInit = CCLInitSingleton::get(size, rank, ccl_ip_port); - g_kvs.push_back(singletonCCLInit.kvs); - g_comms.push_back( - ccl::create_communicator(size, rank, singletonCCLInit.kvs)); - - rank_id = getComm().rank(); - comm_size = getComm().size(); - -#endif - + ComputeDevice device = getComputeDeviceByOrdinal(computeDeviceOrdinal); + switch (device) { + case ComputeDevice::host: + case ComputeDevice::cpu: { + auto t1 = std::chrono::high_resolution_clock::now(); + g_comms.push_back( + ccl::create_communicator(size, rank, singletonCCLInit.kvs)); + auto t2 = std::chrono::high_resolution_clock::now(); + auto duration = + (float)std::chrono::duration_cast(t2 - + t1) + .count(); + logger::println(logger::INFO, + "OneCCL (native): create communicator took %f secs", + duration / 1000); + rank_id = getComm().rank(); + comm_size = getComm().size(); + break; + } #ifdef CPU_GPU_PROFILE - t1 = std::chrono::high_resolution_clock::now(); - auto kvs_attr = ccl::create_kvs_attr(); - - kvs_attr.set(ccl_ip_port); - - ccl::shared_ptr_class kvs = ccl::create_main_kvs(kvs_attr); - - t2 = std::chrono::high_resolution_clock::now(); - duration = - (float)std::chrono::duration_cast(t2 - t1) - .count(); - logger::println(logger::INFO, "OneCCL (native): create kvs took %f secs", - duration / 1000); - auto gpus = get_gpus(); - sycl::queue queue{gpus[0]}; - t1 = std::chrono::high_resolution_clock::now(); - auto comm = oneapi::dal::preview::spmd::make_communicator< - oneapi::dal::preview::spmd::backend::ccl>(queue, size, rank, kvs); - t2 = std::chrono::high_resolution_clock::now(); - duration = - (float)std::chrono::duration_cast(t2 - t1) - .count(); - logger::println(logger::INFO, - "OneCCL (native): create communicator took %f secs", - duration / 1000); - g_dal_comms.push_back(comm); - rank_id = getDalComm().get_rank(); - comm_size = getDalComm().get_rank_count(); + case ComputeDevice::gpu: { + auto gpus = get_gpus(); + sycl::queue queue{gpus[0]}; + auto t1 = std::chrono::high_resolution_clock::now(); + auto comm = oneapi::dal::preview::spmd::make_communicator< + oneapi::dal::preview::spmd::backend::ccl>(queue, size, rank, + singletonCCLInit.kvs); + auto t2 = std::chrono::high_resolution_clock::now(); + auto duration = + (float)std::chrono::duration_cast(t2 - + t1) + .count(); + logger::println(logger::INFO, + "OneCCL (native): create communicator took %f secs", + duration / 1000); + g_dal_comms.push_back(comm); + rank_id = getDalComm().get_rank(); + comm_size = getDalComm().get_rank_count(); + break; + } #endif + default: { + deviceError("communicator", + ComputeDeviceString[computeDeviceOrdinal].c_str()); + } + } jclass cls = env->GetObjectClass(param); jfieldID fid_comm_size = env->GetFieldID(cls, "commSize", "J"); jfieldID fid_rank_id = env->GetFieldID(cls, "rankId", "J"); @@ -130,12 +125,13 @@ JNIEXPORT jint JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1init( JNIEXPORT void JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1cleanup(JNIEnv *env, jobject obj) { logger::printerrln(logger::INFO, "OneCCL (native): cleanup"); -#ifdef CPU_ONLY_PROFILE - g_kvs.pop_back(); - g_comms.pop_back(); -#endif + if (!g_comms.empty()) { + g_comms.pop_back(); + } #ifdef CPU_GPU_PROFILE - g_dal_comms.pop_back(); + if (!g_dal_comms.empty()) { + g_dal_comms.pop_back(); + } #endif } diff --git a/mllib-dal/src/main/native/javah/com_intel_oap_mllib_OneCCL__.h b/mllib-dal/src/main/native/javah/com_intel_oap_mllib_OneCCL__.h index a89b7d214..5534acef0 100644 --- a/mllib-dal/src/main/native/javah/com_intel_oap_mllib_OneCCL__.h +++ b/mllib-dal/src/main/native/javah/com_intel_oap_mllib_OneCCL__.h @@ -45,7 +45,7 @@ JNIEXPORT jint JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1getAvailPort * Signature: (IILjava/lang/String;Lcom/intel/oap/mllib/CCLParam;)I */ JNIEXPORT jint JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1init - (JNIEnv *, jobject, jint, jint, jstring, jobject); + (JNIEnv *, jobject, jint, jint, jstring, jint, jobject); /* * Class: com_intel_oap_mllib_OneCCL__ diff --git a/mllib-dal/src/main/scala/com/intel/oap/mllib/CommonJob.scala b/mllib-dal/src/main/scala/com/intel/oap/mllib/CommonJob.scala index bd2a7be18..000768081 100644 --- a/mllib-dal/src/main/scala/com/intel/oap/mllib/CommonJob.scala +++ b/mllib-dal/src/main/scala/com/intel/oap/mllib/CommonJob.scala @@ -16,6 +16,7 @@ package com.intel.oap.mllib +import com.intel.oneapi.dal.table.Common import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD @@ -25,14 +26,17 @@ object CommonJob { kvsIPPort: String, useDevice: String): Unit = { data.mapPartitionsWithIndex { (rank, table) => - OneCCL.init(executorNum, rank, kvsIPPort) + OneCCL.init(executorNum, rank, kvsIPPort, + Common.ComputeDevice.getDeviceByName(useDevice).ordinal()) val gpuIndices = if (useDevice == "GPU") { val resources = TaskContext.get().resources() resources("gpu").addresses.map(_.toInt) } else { - null + Array.empty[Int] + } + if (gpuIndices.nonEmpty) { + OneCCL.setAffinityMask(gpuIndices(0).toString()) } - OneCCL.setAffinityMask(gpuIndices(0).toString()) Iterator.empty }.count() } diff --git a/mllib-dal/src/main/scala/com/intel/oap/mllib/OneCCL.scala b/mllib-dal/src/main/scala/com/intel/oap/mllib/OneCCL.scala index 87289d559..ea83b83d7 100644 --- a/mllib-dal/src/main/scala/com/intel/oap/mllib/OneCCL.scala +++ b/mllib-dal/src/main/scala/com/intel/oap/mllib/OneCCL.scala @@ -16,6 +16,7 @@ package com.intel.oap.mllib +import com.intel.oneapi.dal.table.Common import org.apache.spark.internal.Logging object OneCCL extends Logging { @@ -24,12 +25,13 @@ object OneCCL extends Logging { var cclParam = new CCLParam() - def init(executor_num: Int, rank: Int, ip_port: String): Unit = { + def init(executor_num: Int, rank: Int, ip_port: String, + computeDevice: Int = Common.ComputeDevice.CPU.ordinal()): Unit = { logInfo(s"Initializing with IP_PORT: ${ip_port}") // cclParam is output from native code - c_init(executor_num, rank, ip_port, cclParam) + c_init(executor_num, rank, ip_port, computeDevice, cclParam) // executor number should equal to oneCCL world size assert(executor_num == cclParam.getCommSize, @@ -61,7 +63,8 @@ object OneCCL extends Logging { @native def c_getAvailPort(localIP: String): Int - @native private def c_init(size: Int, rank: Int, ip_port: String, param: CCLParam): Int + @native private def c_init(size: Int, rank: Int, ip_port: String, + computeDevice: Int, param: CCLParam): Int @native private def c_cleanup(): Unit }