diff --git a/source/api_cc/include/common.h b/source/api_cc/include/common.h index def3df933b..2bd0cf7135 100644 --- a/source/api_cc/include/common.h +++ b/source/api_cc/include/common.h @@ -15,6 +15,13 @@ namespace deepmd { typedef double ENERGYTYPE; enum DPBackend { TensorFlow, PyTorch, Paddle, JAX, Unknown }; +/** + * @brief Get the backend of the model. + * @param[in] model The model name. + * @return The backend of the model. + **/ +DPBackend get_backend(const std::string& model); + struct NeighborListData { /// Array stores the core region atom's index std::vector ilist; diff --git a/source/api_cc/src/DataModifier.cc b/source/api_cc/src/DataModifier.cc index 23e0321410..4f319b4f66 100644 --- a/source/api_cc/src/DataModifier.cc +++ b/source/api_cc/src/DataModifier.cc @@ -28,8 +28,7 @@ void DipoleChargeModifier::init(const std::string& model, << std::endl; return; } - // TODO: To implement detect_backend - DPBackend backend = deepmd::DPBackend::TensorFlow; + const DPBackend backend = get_backend(model); if (deepmd::DPBackend::TensorFlow == backend) { #ifdef BUILD_TENSORFLOW dcm = std::make_shared(model, gpu_rank, diff --git a/source/api_cc/src/DeepPot.cc b/source/api_cc/src/DeepPot.cc index 8769f5b211..29f8b99cfd 100644 --- a/source/api_cc/src/DeepPot.cc +++ b/source/api_cc/src/DeepPot.cc @@ -39,17 +39,7 @@ void DeepPot::init(const std::string& model, << std::endl; return; } - DPBackend backend; - if (model.length() >= 4 && model.substr(model.length() - 4) == ".pth") { - backend = deepmd::DPBackend::PyTorch; - } else if (model.length() >= 3 && model.substr(model.length() - 3) == ".pb") { - backend = deepmd::DPBackend::TensorFlow; - } else if (model.length() >= 11 && - model.substr(model.length() - 11) == ".savedmodel") { - backend = deepmd::DPBackend::JAX; - } else { - throw deepmd::deepmd_exception("Unsupported model file format"); - } + const DPBackend backend = get_backend(model); if (deepmd::DPBackend::TensorFlow == backend) { #ifdef BUILD_TENSORFLOW dp = std::make_shared(model, gpu_rank, file_content); diff --git a/source/api_cc/src/DeepSpin.cc b/source/api_cc/src/DeepSpin.cc index d761e9d3c2..eb37828410 100644 --- a/source/api_cc/src/DeepSpin.cc +++ b/source/api_cc/src/DeepSpin.cc @@ -36,14 +36,7 @@ void DeepSpin::init(const std::string& model, << std::endl; return; } - DPBackend backend; - if (model.length() >= 4 && model.substr(model.length() - 4) == ".pth") { - backend = deepmd::DPBackend::PyTorch; - } else if (model.length() >= 3 && model.substr(model.length() - 3) == ".pb") { - backend = deepmd::DPBackend::TensorFlow; - } else { - throw deepmd::deepmd_exception("Unsupported model file format"); - } + const DPBackend backend = get_backend(model); if (deepmd::DPBackend::TensorFlow == backend) { #ifdef BUILD_TENSORFLOW dp = std::make_shared(model, gpu_rank, file_content); diff --git a/source/api_cc/src/DeepTensor.cc b/source/api_cc/src/DeepTensor.cc index a0596e046f..a9031472e6 100644 --- a/source/api_cc/src/DeepTensor.cc +++ b/source/api_cc/src/DeepTensor.cc @@ -30,8 +30,7 @@ void DeepTensor::init(const std::string &model, << std::endl; return; } - // TODO: To implement detect_backend - DPBackend backend = deepmd::DPBackend::TensorFlow; + const DPBackend backend = get_backend(model); if (deepmd::DPBackend::TensorFlow == backend) { #ifdef BUILD_TENSORFLOW dt = std::make_shared(model, gpu_rank, name_scope_); diff --git a/source/api_cc/src/common.cc b/source/api_cc/src/common.cc index bd3f18c579..5a4f05d75c 100644 --- a/source/api_cc/src/common.cc +++ b/source/api_cc/src/common.cc @@ -1399,3 +1399,15 @@ void deepmd::print_summary(const std::string& pre) { << "set tf inter_op_parallelism_threads: " << num_inter_nthreads << std::endl; } + +deepmd::DPBackend deepmd::get_backend(const std::string& model) { + if (model.length() >= 4 && model.substr(model.length() - 4) == ".pth") { + return deepmd::DPBackend::PyTorch; + } else if (model.length() >= 3 && model.substr(model.length() - 3) == ".pb") { + return deepmd::DPBackend::TensorFlow; + } else if (model.length() >= 11 && + model.substr(model.length() - 11) == ".savedmodel") { + return deepmd::DPBackend::JAX; + } + throw deepmd::deepmd_exception("Unsupported model file format"); +}