Use the same stream as the OpenMM context when replaying a CUDA graph.
Use a non-default stream for CUDA graph capturing.
RaulPPelaez committed Sep 26, 2023
1 parent 2270256 commit 6c332c7
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions platforms/cuda/src/CudaTorchKernels.cpp
@@ -207,11 +207,12 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce
     if (!useGraphs) {
         executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor);
     } else {
-        const auto stream = c10::cuda::getStreamFromPool(false, posTensor.get_device());
-        const c10::cuda::CUDAStreamGuard guard(stream);
         // Record graph if not already done
         bool is_graph_captured = false;
         if (graphs.find(includeForces) == graphs.end()) {
+            //CUDA graph capture must occur in a non-default stream
+            const auto stream = c10::cuda::getStreamFromPool(false, cu.getDeviceIndex());
+            const c10::cuda::CUDAStreamGuard guard(stream);
             // Warmup the graph workload before capturing. This first
             // run before capture sets up allocations so that no
             // allocations are needed after. Pytorch's allocator is
@@ -237,6 +238,10 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce
                 throw OpenMMException(string("TorchForce Failed to capture the model into a CUDA graph. Torch reported the following error:\n") + e.what());
             }
         }
+        // Use the same stream as the OpenMM context, even if it is the default stream
+        const auto openmmStream = cu.getCurrentStream();
+        const auto stream = c10::cuda::getStreamFromExternal(openmmStream, cu.getDeviceIndex());
+        const c10::cuda::CUDAStreamGuard guard(stream);
         graphs[includeForces].replay();
     }
     if (includeForces) {
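
To make the stream handling concrete, here is a minimal, self-contained sketch of the same pattern — capture on a pooled side stream, replay on the stream the host application owns — using libtorch's c10/ATen CUDA APIs. This is not the repository's code: appStream is a hypothetical externally owned stream standing in for the one OpenMM exposes via cu.getCurrentStream(), and the add_ workload stands in for the TorchForce model.

// Sketch only: capture a CUDA graph on a non-default stream, replay it on an
// externally owned stream (appStream is an assumed name, not OpenMM's).
#include <ATen/cuda/CUDAGraph.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <torch/torch.h>

int main() {
    torch::Tensor x = torch::zeros({16}, torch::kCUDA);
    torch::cuda::synchronize(); // finish the allocation/zero-fill before switching streams
    at::cuda::CUDAGraph graph;

    {
        // CUDA forbids graph capture on the default (legacy) stream,
        // so borrow a non-default stream from PyTorch's pool.
        const auto captureStream = c10::cuda::getStreamFromPool(false, x.get_device());
        const c10::cuda::CUDAStreamGuard guard(captureStream);
        x.add_(1); // warmup run: lets the caching allocator reserve memory before capture
        graph.capture_begin();
        x.add_(1); // the workload recorded into the graph
        graph.capture_end();
    }

    // Replay on the stream the rest of the application uses, so the graph's
    // kernels are ordered after work already enqueued there.
    cudaStream_t appStream;
    cudaStreamCreate(&appStream);
    {
        const auto stream = c10::cuda::getStreamFromExternal(appStream, x.get_device());
        const c10::cuda::CUDAStreamGuard guard(stream);
        graph.replay();
    }
    torch::cuda::synchronize();
    cudaStreamDestroy(appStream);
    return 0;
}

The relevant property of getStreamFromExternal is that PyTorch merely wraps the existing handle rather than creating a new stream, so replay() is serialized with whatever OpenMM has already enqueued on its own stream — even when that stream is the default one.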
