diff --git a/platforms/cuda/src/CudaTorchKernels.cpp b/platforms/cuda/src/CudaTorchKernels.cpp index d2d869ae..2cdb1c44 100644 --- a/platforms/cuda/src/CudaTorchKernels.cpp +++ b/platforms/cuda/src/CudaTorchKernels.cpp @@ -188,7 +188,11 @@ static void executeGraph(bool outputsForces, bool includeForces, torch::jit::scr energyTensor = module.forward(inputs).toTensor(); // Compute force by backpropagating the PyTorch model if (includeForces) { - energyTensor.backward(); + // CUDA graph capture sometimes fails if backwards is not explicitly requested w.r.t positions + // See https://github.com/openmm/openmm-torch/pull/120/ + auto none = torch::Tensor(); + energyTensor.backward(none, false, false, posTensor); + // This is minus the forces, we change the sign later on forceTensor = posTensor.grad().clone(); // Zero the gradient to avoid accumulating it posTensor.grad().zero_();