Use the same stream as the OpenMM context when replaying a CUDA graph. #122

Merged 2 commits on Sep 27, 2023
2 changes: 1 addition & 1 deletion README.md
@@ -270,7 +270,7 @@ torch_force = TorchForce('model.pt', {'useCUDAGraphs': 'true'})
The first time the model is run, it is compiled (a process also known as recording) into a CUDA graph. Subsequent runs reuse the compiled graph, which can be significantly faster. Compilation can fail, in which case an `OpenMMException` is raised; if that happens, you can disable CUDA graphs and try again.

The model must be run at least once before recording; this is known as warmup.
-By default ```TorchForce``` will run the model just once before recording, but longer warming up might be desired. In these cases one can set the property ```CUDAGraphWarmupSteps```:
+By default ```TorchForce``` will run the model just a few times before recording, but controlling warmup steps might be desired. In these cases one can set the property ```CUDAGraphWarmupSteps```:
```python
torch_force.setProperty("CUDAGraphWarmupSteps", "12")
```
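For context, a minimal end-to-end setup combining both properties might look like the following sketch; the model file name and the warmup value are placeholders, not part of this change:

```python
from openmmtorch import TorchForce

# Enable CUDA graphs when the force is created (hypothetical model file).
torch_force = TorchForce('model.pt', {'useCUDAGraphs': 'true'})

# Ask for extra warmup runs before the graph is recorded.
torch_force.setProperty('CUDAGraphWarmupSteps', '12')
```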
2 changes: 1 addition & 1 deletion openmmapi/src/TorchForce.cpp
@@ -42,7 +42,7 @@ using namespace OpenMM;
using namespace std;

TorchForce::TorchForce(const torch::jit::Module& module, const map<string, string>& properties) : file(), usePeriodic(false), outputsForces(false), module(module) {
-    const std::map<std::string, std::string> defaultProperties = {{"useCUDAGraphs", "false"}, {"CUDAGraphWarmupSteps", "1"}};
+    const std::map<std::string, std::string> defaultProperties = {{"useCUDAGraphs", "false"}, {"CUDAGraphWarmupSteps", "10"}};
this->properties = defaultProperties;
for (auto& property : properties) {
if (defaultProperties.find(property.first) == defaultProperties.end())
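The bumped default applies only when the property is not set explicitly; values passed at construction still win. A minimal sketch of overriding it (the model file name is a placeholder):

```python
from openmmtorch import TorchForce

# Properties given here override the defaults merged in the constructor,
# so CUDAGraphWarmupSteps only falls back to "10" when it is omitted.
torch_force = TorchForce('model.pt', {'useCUDAGraphs': 'true',
                                      'CUDAGraphWarmupSteps': '20'})
```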
9 changes: 7 additions & 2 deletions platforms/cuda/src/CudaTorchKernels.cpp
@@ -207,11 +207,12 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce
     if (!useGraphs) {
         executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor);
     } else {
-        const auto stream = c10::cuda::getStreamFromPool(false, posTensor.get_device());
-        const c10::cuda::CUDAStreamGuard guard(stream);
         // Record graph if not already done
         bool is_graph_captured = false;
         if (graphs.find(includeForces) == graphs.end()) {
+            //CUDA graph capture must occur in a non-default stream
+            const auto stream = c10::cuda::getStreamFromPool(false, cu.getDeviceIndex());
+            const c10::cuda::CUDAStreamGuard guard(stream);
             // Warmup the graph workload before capturing. This first
             // run before capture sets up allocations so that no
             // allocations are needed after. Pytorch's allocator is
@@ -237,6 +238,10 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce
                 throw OpenMMException(string("TorchForce Failed to capture the model into a CUDA graph. Torch reported the following error:\n") + e.what());
             }
         }
+        // Use the same stream as the OpenMM context, even if it is the default stream
+        const auto openmmStream = cu.getCurrentStream();
+        const auto stream = c10::cuda::getStreamFromExternal(openmmStream, cu.getDeviceIndex());
+        const c10::cuda::CUDAStreamGuard guard(stream);
         graphs[includeForces].replay();
     }
     if (includeForces) {
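For readers unfamiliar with the stream discipline this change adopts, the pattern can be sketched with PyTorch's Python API: capture on a non-default stream, then replay on whatever stream the rest of the application uses. The plugin wraps OpenMM's stream via `getStreamFromExternal`; since this sketch has no host application to borrow a stream from, it stands in a plain side stream instead. An illustration of the pattern, not the plugin's implementation:

```python
import torch

assert torch.cuda.is_available()
x = torch.zeros(1024, device='cuda')

# Warm up once so the caching allocator owns the buffers the captured
# workload will need; no allocations should happen during capture.
y = x * 2.0
torch.cuda.synchronize()

# Capture. torch.cuda.graph switches to a non-default stream internally,
# mirroring getStreamFromPool in the C++ change above.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    y = x * 2.0

# Replay on the application's stream. replay() launches on the current
# stream, which is why the C++ code installs a CUDAStreamGuard before it.
app_stream = torch.cuda.Stream()  # stand-in for OpenMM's stream
with torch.cuda.stream(app_stream):
    x.fill_(3.0)  # work queued earlier on the same stream...
    g.replay()    # ...is ordered before the replay
app_stream.synchronize()
print(y)  # tensor filled with 6.0
```

Replaying on the context's own stream removes the cross-stream ordering that would otherwise be needed between OpenMM's kernels and the graph launch, which is the point of this patch.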