diff --git a/README.md b/README.md
index cdfe4160..853abab8 100644
--- a/README.md
+++ b/README.md
@@ -270,7 +270,7 @@ torch_force = TorchForce('model.pt', {'useCUDAGraphs': 'true'})
 The first time the model is run, it will be compiled (also known as recording) into a CUDA graph. Subsequent runs will use the compiled graph, which can be significantly faster.
 It is possible that compilation fails, in which case an `OpenMMException` will be raised. If that happens, you can disable CUDA graphs and try again.
 It is required to run the model at least once before recording, in what is known as warmup.
-By default ```TorchForce``` will run the model just once before recording, but longer warming up might be desired. In these cases one can set the property ```CUDAGraphWarmupSteps```:
+By default ```TorchForce``` will run the model just a few times before recording, but controlling warmup steps might be desired. In these cases one can set the property ```CUDAGraphWarmupSteps```:
 ```python
 torch_force.setProperty("CUDAGraphWarmupSteps", "12")
 ```
diff --git a/openmmapi/src/TorchForce.cpp b/openmmapi/src/TorchForce.cpp
index 41aa1fb2..e8892048 100644
--- a/openmmapi/src/TorchForce.cpp
+++ b/openmmapi/src/TorchForce.cpp
@@ -42,7 +42,7 @@ using namespace OpenMM;
 using namespace std;
 
 TorchForce::TorchForce(const torch::jit::Module& module, const map<string, string>& properties) : file(), usePeriodic(false), outputsForces(false), module(module) {
-    const std::map<std::string, std::string> defaultProperties = {{"useCUDAGraphs", "false"}, {"CUDAGraphWarmupSteps", "1"}};
+    const std::map<std::string, std::string> defaultProperties = {{"useCUDAGraphs", "false"}, {"CUDAGraphWarmupSteps", "10"}};
     this->properties = defaultProperties;
     for (auto& property : properties) {
         if (defaultProperties.find(property.first) == defaultProperties.end())
diff --git a/platforms/cuda/src/CudaTorchKernels.cpp b/platforms/cuda/src/CudaTorchKernels.cpp
index 2cdb1c44..60ea7794 100644
--- a/platforms/cuda/src/CudaTorchKernels.cpp
+++ b/platforms/cuda/src/CudaTorchKernels.cpp
@@ -207,11 +207,12 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce
     if (!useGraphs) {
         executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor);
     } else {
-        const auto stream = c10::cuda::getStreamFromPool(false, posTensor.get_device());
-        const c10::cuda::CUDAStreamGuard guard(stream);
         // Record graph if not already done
         bool is_graph_captured = false;
         if (graphs.find(includeForces) == graphs.end()) {
+            //CUDA graph capture must occur in a non-default stream
+            const auto stream = c10::cuda::getStreamFromPool(false, cu.getDeviceIndex());
+            const c10::cuda::CUDAStreamGuard guard(stream);
             // Warmup the graph workload before capturing. This first
             // run before capture sets up allocations so that no
             // allocations are needed after. Pytorch's allocator is
@@ -237,6 +238,10 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce
                 throw OpenMMException(string("TorchForce Failed to capture the model into a CUDA graph. Torch reported the following error:\n") + e.what());
             }
         }
+        // Use the same stream as the OpenMM context, even if it is the default stream
+        const auto openmmStream = cu.getCurrentStream();
+        const auto stream = c10::cuda::getStreamFromExternal(openmmStream, cu.getDeviceIndex());
+        const c10::cuda::CUDAStreamGuard guard(stream);
         graphs[includeForces].replay();
     }
     if (includeForces) {
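
For reference, a minimal usage sketch of the two properties this patch touches, based on the README snippet above. The model path and warmup value are illustrative; `model.pt` is assumed to be a TorchScript module saved with `torch.jit`:

```python
from openmmtorch import TorchForce

# Ask TorchForce to capture the model into a CUDA graph on first execution.
# 'model.pt' is an illustrative path to a module saved via
# torch.jit.script(...).save('model.pt').
torch_force = TorchForce('model.pt', {'useCUDAGraphs': 'true'})

# Optionally override the number of warmup runs executed before the graph
# is recorded (this patch changes the default from 1 to 10).
torch_force.setProperty('CUDAGraphWarmupSteps', '12')
```

Note that property values are passed as strings, matching the existing `TorchForce` properties API.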