From 575d427725cc1c2d6425204a990925b7d7671c80 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 26 Sep 2023 14:13:53 +0200 Subject: [PATCH] Increase default warmup steps to 10 --- README.md | 2 +- openmmapi/src/TorchForce.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index cdfe4160..853abab8 100644 --- a/README.md +++ b/README.md @@ -270,7 +270,7 @@ torch_force = TorchForce('model.pt', {'useCUDAGraphs': 'true'}) The first time the model is run, it will be compiled (also known as recording) into a CUDA graph. Subsequent runs will use the compiled graph, which can be significantly faster. It is possible that compilation fails, in which case an `OpenMMException` will be raised. If that happens, you can disable CUDA graphs and try again. It is required to run the model at least once before recording, in what is known as warmup. -By default ```TorchForce``` will run the model just once before recording, but longer warming up might be desired. In these cases one can set the property ```CUDAGraphWarmupSteps```: +By default ```TorchForce``` will run the model just a few times before recording, but controlling warmup steps might be desired. In these cases one can set the property ```CUDAGraphWarmupSteps```: ```python torch_force.setProperty("CUDAGraphWarmupSteps", "12") ``` diff --git a/openmmapi/src/TorchForce.cpp b/openmmapi/src/TorchForce.cpp index 41aa1fb2..e8892048 100644 --- a/openmmapi/src/TorchForce.cpp +++ b/openmmapi/src/TorchForce.cpp @@ -42,7 +42,7 @@ using namespace OpenMM; using namespace std; TorchForce::TorchForce(const torch::jit::Module& module, const map& properties) : file(), usePeriodic(false), outputsForces(false), module(module) { - const std::map defaultProperties = {{"useCUDAGraphs", "false"}, {"CUDAGraphWarmupSteps", "1"}}; + const std::map defaultProperties = {{"useCUDAGraphs", "false"}, {"CUDAGraphWarmupSteps", "10"}}; this->properties = defaultProperties; for (auto& property : properties) { if (defaultProperties.find(property.first) == defaultProperties.end())