From 6c332c723d17d1aa4e8962027061da78e828641b Mon Sep 17 00:00:00 2001
From: RaulPPealez
Date: Tue, 26 Sep 2023 12:52:29 +0200
Subject: [PATCH] Use the same stream as the OpenMM context when replaying a
 CUDA graph. Use a non-default stream for CUDA graph capturing.

---
 platforms/cuda/src/CudaTorchKernels.cpp | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/platforms/cuda/src/CudaTorchKernels.cpp b/platforms/cuda/src/CudaTorchKernels.cpp
index 2cdb1c44..60ea7794 100644
--- a/platforms/cuda/src/CudaTorchKernels.cpp
+++ b/platforms/cuda/src/CudaTorchKernels.cpp
@@ -207,11 +207,12 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce
     if (!useGraphs) {
         executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor);
     } else {
-        const auto stream = c10::cuda::getStreamFromPool(false, posTensor.get_device());
-        const c10::cuda::CUDAStreamGuard guard(stream);
         // Record graph if not already done
         bool is_graph_captured = false;
         if (graphs.find(includeForces) == graphs.end()) {
+            //CUDA graph capture must occur in a non-default stream
+            const auto stream = c10::cuda::getStreamFromPool(false, cu.getDeviceIndex());
+            const c10::cuda::CUDAStreamGuard guard(stream);
             // Warmup the graph workload before capturing. This first
             // run before capture sets up allocations so that no
             // allocations are needed after. Pytorch's allocator is
@@ -237,6 +238,10 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce
                 throw OpenMMException(string("TorchForce Failed to capture the model into a CUDA graph. Torch reported the following error:\n") + e.what());
             }
         }
+        // Use the same stream as the OpenMM context, even if it is the default stream
+        const auto openmmStream = cu.getCurrentStream();
+        const auto stream = c10::cuda::getStreamFromExternal(openmmStream, cu.getDeviceIndex());
+        const c10::cuda::CUDAStreamGuard guard(stream);
         graphs[includeForces].replay();
     }
     if (includeForces) {
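
Note on the stream discipline this patch adopts: CUDA graph capture must be initiated on a non-default stream, so the capture block takes one from the Torch stream pool, while replay should be issued on whatever stream the OpenMM context already uses, wrapped for Torch with getStreamFromExternal so the graph launch stays ordered with the rest of OpenMM's work. The sketch below illustrates that pattern with only public libtorch/c10 CUDA calls; it is not code from this repository, and the toy workload, device index 0, and the appStream handle standing in for OpenMM's stream are illustrative assumptions.

// Minimal sketch: capture a CUDA graph on a pooled (non-default) Torch stream,
// then replay it on an externally owned cudaStream_t wrapped for Torch.
// The workload and appStream are placeholders, not the OpenMM-Torch kernel code.
#include <ATen/cuda/CUDAGraph.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/torch.h>
#include <cuda_runtime.h>

int main() {
    const int device = 0;                                   // assumed device index
    torch::Tensor x = torch::randn({1024}, torch::kCUDA);
    at::cuda::CUDAGraph graph;

    {
        // Capture cannot start on the legacy default stream, so use a pooled one.
        const auto captureStream = c10::cuda::getStreamFromPool(/*isHighPriority=*/false, device);
        const c10::cuda::CUDAStreamGuard guard(captureStream);
        torch::Tensor y = x * 2;                            // warmup primes the allocator
        cudaDeviceSynchronize();                            // make sure warmup finished before capture
        graph.capture_begin();
        y = x * 2;                                          // work recorded into the graph
        graph.capture_end();
    }

    {
        // Replay on the stream the host application owns, wrapped for Torch so the
        // graph launch is enqueued behind the application's other work on it.
        cudaStream_t appStream;                             // stands in for the OpenMM context stream
        cudaStreamCreate(&appStream);
        const auto replayStream = c10::cuda::getStreamFromExternal(appStream, device);
        const c10::cuda::CUDAStreamGuard guard(replayStream);
        graph.replay();
        cudaStreamSynchronize(appStream);
        cudaStreamDestroy(appStream);
    }
    return 0;
}

Replaying under a guard that wraps the application's own stream removes the need for an extra synchronization between the graph launch and the kernels OpenMM enqueues before and after it, which is the point of the second hunk above.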