diff --git a/platforms/cuda/src/CudaTorchKernels.cpp b/platforms/cuda/src/CudaTorchKernels.cpp
index 2cdb1c44..60ea7794 100644
--- a/platforms/cuda/src/CudaTorchKernels.cpp
+++ b/platforms/cuda/src/CudaTorchKernels.cpp
@@ -207,11 +207,12 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce
     if (!useGraphs) {
         executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor);
     } else {
-        const auto stream = c10::cuda::getStreamFromPool(false, posTensor.get_device());
-        const c10::cuda::CUDAStreamGuard guard(stream);
         // Record graph if not already done
         bool is_graph_captured = false;
         if (graphs.find(includeForces) == graphs.end()) {
+            //CUDA graph capture must occur in a non-default stream
+            const auto stream = c10::cuda::getStreamFromPool(false, cu.getDeviceIndex());
+            const c10::cuda::CUDAStreamGuard guard(stream);
             // Warmup the graph workload before capturing. This first
             // run before capture sets up allocations so that no
             // allocations are needed after. Pytorch's allocator is
@@ -237,6 +238,10 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce
                 throw OpenMMException(string("TorchForce Failed to capture the model into a CUDA graph. Torch reported the following error:\n") + e.what());
             }
         }
+        // Use the same stream as the OpenMM context, even if it is the default stream
+        const auto openmmStream = cu.getCurrentStream();
+        const auto stream = c10::cuda::getStreamFromExternal(openmmStream, cu.getDeviceIndex());
+        const c10::cuda::CUDAStreamGuard guard(stream);
         graphs[includeForces].replay();
     }
     if (includeForces) {
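
For context, the pattern this patch implements is: capture the TorchScript model into a CUDA graph once, on a side stream taken from PyTorch's stream pool (CUDA graph capture cannot run on the default/legacy stream), and then replay the captured graph on whatever stream OpenMM is already using by wrapping that raw `cudaStream_t` with `getStreamFromExternal`. The sketch below illustrates the same idea in isolation; `GraphedModule`, `runWithGraph`, `contextStream`, and the single-tensor input are illustrative assumptions, not names from the openmm-torch sources.

```cpp
// Minimal sketch of capture-on-pool-stream / replay-on-external-stream.
// Assumptions (not from the patch): the module takes one CUDA tensor and
// returns one tensor, and `contextStream` stands in for OpenMM's stream.
#include <torch/script.h>
#include <ATen/cuda/CUDAGraph.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h>

struct GraphedModule {
    at::cuda::CUDAGraph graph;
    torch::Tensor output;   // rewritten in place by every replay
    bool captured = false;
};

torch::Tensor runWithGraph(torch::jit::script::Module& module, torch::Tensor input,
                           cudaStream_t contextStream, GraphedModule& g) {
    const auto device = input.get_device();
    if (!g.captured) {
        // Capture must happen on a non-default stream, so borrow one from the pool.
        const auto captureStream = c10::cuda::getStreamFromPool(false, device);
        const c10::cuda::CUDAStreamGuard guard(captureStream);
        // Warm-up run so the caching allocator has its blocks ready before capture.
        module.forward({input});
        // Record the forward pass; the output tensor lives in the graph's memory
        // pool and is rewritten by each replay, so keep a reference to it.
        g.graph.capture_begin();
        g.output = module.forward({input}).toTensor();
        g.graph.capture_end();
        g.captured = true;
    }
    // Replay on the caller's (possibly default) stream by wrapping the raw handle.
    const auto replayStream = c10::cuda::getStreamFromExternal(contextStream, device);
    const c10::cuda::CUDAStreamGuard replayGuard(replayStream);
    g.graph.replay();
    // New input data must be copied into `input` (same storage) before calling this.
    return g.output;
}
```

The detail the patch changes is the stream used for replay: instead of replaying on a pool stream, it pins the replay to OpenMM's own stream, presumably so the replayed kernels stay ordered with the rest of the work OpenMM issues on that stream.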