diff --git a/mllib-dal/src/main/native/GPU.cpp b/mllib-dal/src/main/native/GPU.cpp
index ecaa42121..872dc7dc9 100644
--- a/mllib-dal/src/main/native/GPU.cpp
+++ b/mllib-dal/src/main/native/GPU.cpp
@@ -115,7 +115,7 @@ sycl::queue getQueue(const ComputeDevice device) {
 
 preview::spmd::communicator<preview::spmd::device_memory_access::usm>
 createDalCommunicator(const jint executorNum, const jint rank,
-                      const ccl::string ccl_ip_port) {
+                      const ccl::string ccl_ip_port, std::string breakdown_name) {
     auto gpus = get_gpus();
 
     auto t1 = std::chrono::high_resolution_clock::now();
@@ -127,10 +127,7 @@ createDalCommunicator(const jint executorNum, const jint rank,
         (float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
             .count();
 
-    logger::println(logger::INFO, "OneCCL singleton init took %f secs",
-                    duration / 1000);
-
-    t1 = std::chrono::high_resolution_clock::now();
+    logger::Logger::getInstance(breakdown_name).printLogToFile("rankID was %d, OneCCL singleton init took %f secs.", rank, duration / 1000 );
 
     auto kvs_attr = ccl::create_kvs_attr();
 
@@ -138,12 +135,6 @@ createDalCommunicator(const jint executorNum, const jint rank,
 
     ccl::shared_ptr_class<ccl::kvs> kvs = ccl::create_main_kvs(kvs_attr);
 
-    t2 = std::chrono::high_resolution_clock::now();
-    duration =
-        (float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
-            .count();
-    logger::println(logger::INFO, "OneCCL (native): create kvs took %f secs",
-                    duration / 1000);
     sycl::queue queue{gpus[0]};
     t1 = std::chrono::high_resolution_clock::now();
     auto comm = preview::spmd::make_communicator<preview::spmd::backend::ccl>(
@@ -152,5 +143,6 @@ createDalCommunicator(const jint executorNum, const jint rank,
     duration =
         (float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
             .count();
+    logger::Logger::getInstance(breakdown_name).printLogToFile("rankID was %d, create communicator took %f secs.", rank, duration / 1000 );
     return comm;
 }
diff --git a/mllib-dal/src/main/native/GPU.h b/mllib-dal/src/main/native/GPU.h
index 1056ef22a..b9832ed46 100644
--- a/mllib-dal/src/main/native/GPU.h
+++ b/mllib-dal/src/main/native/GPU.h
@@ -12,4 +12,4 @@ sycl::queue
 getAssignedGPU(const ComputeDevice device, jint *gpu_indices);
 sycl::queue getQueue(const ComputeDevice device);
 preview::spmd::communicator<preview::spmd::device_memory_access::usm>
-createDalCommunicator(jint executorNum, jint rank, ccl::string ccl_ip_port);
+createDalCommunicator(jint executorNum, jint rank, ccl::string ccl_ip_port, std::string breakdown_name);