From 151c9e251617c9b7a091c4c2b9c54e5e415a3446 Mon Sep 17 00:00:00 2001 From: peastman Date: Wed, 15 May 2024 09:17:58 -0700 Subject: [PATCH 1/6] Can compute parameter derivatives --- openmmapi/include/TorchForce.h | 24 +++++++++++++- openmmapi/src/TorchForce.cpp | 20 ++++++++++- platforms/cuda/src/CudaTorchKernels.cpp | 33 +++++++++++++++---- platforms/cuda/src/CudaTorchKernels.h | 6 ++-- platforms/cuda/tests/TestCudaTorchForce.cpp | 15 +++++++-- platforms/opencl/src/OpenCLTorchKernels.cpp | 26 +++++++++++++-- platforms/opencl/src/OpenCLTorchKernels.h | 4 ++- .../opencl/tests/TestOpenCLTorchForce.cpp | 15 +++++++-- .../reference/src/ReferenceTorchKernels.cpp | 29 ++++++++++++++-- .../reference/src/ReferenceTorchKernels.h | 2 ++ .../tests/TestReferenceTorchForce.cpp | 13 +++++++- 11 files changed, 164 insertions(+), 23 deletions(-) diff --git a/openmmapi/include/TorchForce.h b/openmmapi/include/TorchForce.h index 4406eb20..d49ebe73 100644 --- a/openmmapi/include/TorchForce.h +++ b/openmmapi/include/TorchForce.h @@ -9,7 +9,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2018-2022 Stanford University and the Authors. * + * Portions copyright (c) 2018-2024 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: Raimondas Galvelis, Raul P. Pelaez * * * @@ -36,6 +36,7 @@ #include "openmm/Force.h" #include #include +#include #include #include "internal/windowsExportTorch.h" @@ -106,6 +107,11 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force { * Get the number of global parameters that the interaction depends on. */ int getNumGlobalParameters() const; + /** + * Get the number of global parameters with respect to which the derivative of the energy + * should be computed. + */ + int getNumEnergyParameterDerivatives() const; /** * Add a new global parameter that the interaction may depend on. The default value provided to * this method is the initial value of the parameter in newly created Contexts. You can change @@ -144,6 +150,21 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force { * @param defaultValue the default value of the parameter */ void setGlobalParameterDefaultValue(int index, double defaultValue); + /** + * Request that this Force compute the derivative of its energy with respect to a global parameter. + * The parameter must have already been added with addGlobalParameter(). + * + * @param name the name of the parameter + */ + void addEnergyParameterDerivative(const std::string& name); + /** + * Get the name of a global parameter with respect to which this Force should compute the + * derivative of the energy. + * + * @param index the index of the parameter derivative, between 0 and getNumEnergyParameterDerivatives() + * @return the parameter name + */ + const std::string& getEnergyParameterDerivativeName(int index) const; /** * Set a value of a property. * @@ -163,6 +184,7 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force { std::string file; bool usePeriodic, outputsForces; std::vector globalParameters; + std::vector energyParameterDerivatives; torch::jit::Module module; std::map properties; std::string emptyProperty; diff --git a/openmmapi/src/TorchForce.cpp b/openmmapi/src/TorchForce.cpp index e8892048..32dcfcbe 100644 --- a/openmmapi/src/TorchForce.cpp +++ b/openmmapi/src/TorchForce.cpp @@ -6,7 +6,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2018-2022 Stanford University and the Authors. * + * Portions copyright (c) 2018-2024 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: Raimondas Galvelis, Raul P. Pelaez * * * @@ -112,6 +112,24 @@ void TorchForce::setGlobalParameterDefaultValue(int index, double defaultValue) globalParameters[index].defaultValue = defaultValue; } +int TorchForce::getNumEnergyParameterDerivatives() const { + return energyParameterDerivatives.size(); +} + +void TorchForce::addEnergyParameterDerivative(const string& name) { + for (int i = 0; i < globalParameters.size(); i++) + if (name == globalParameters[i].name) { + energyParameterDerivatives.push_back(i); + return; + } + throw OpenMMException(string("addEnergyParameterDerivative: Unknown global parameter '"+name+"'")); +} + +const string& TorchForce::getEnergyParameterDerivativeName(int index) const { + ASSERT_VALID_INDEX(index, energyParameterDerivatives); + return globalParameters[energyParameterDerivatives[index]].name; +} + void TorchForce::setProperty(const std::string& name, const std::string& value) { if (properties.find(name) == properties.end()) throw OpenMMException("TorchForce: Unknown property '" + name + "'"); diff --git a/platforms/cuda/src/CudaTorchKernels.cpp b/platforms/cuda/src/CudaTorchKernels.cpp index 60ea7794..8d8fa066 100644 --- a/platforms/cuda/src/CudaTorchKernels.cpp +++ b/platforms/cuda/src/CudaTorchKernels.cpp @@ -6,7 +6,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2018-2022 Stanford University and the Authors. * + * Portions copyright (c) 2018-2024 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: Raimondas Galvelis, Raul P. Pelaez * * * @@ -66,6 +66,10 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce outputsForces = force.getOutputsForces(); for (int i = 0; i < force.getNumGlobalParameters(); i++) globalNames.push_back(force.getGlobalParameterName(i)); + for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) { + paramDerivs.insert(force.getEnergyParameterDerivativeName(i)); + cu.addEnergyParameterDerivative(force.getEnergyParameterDerivativeName(i)); + } int numParticles = system.getNumParticles(); // Push the PyTorch context @@ -125,7 +129,7 @@ static void* getTensorPointer(OpenMM::CudaContext& cu, torch::Tensor& tensor) { /** * Prepare the inputs for the PyTorch model, copying positions from the OpenMM context. */ -std::vector CudaCalcTorchForceKernel::prepareTorchInputs(ContextImpl& context) { +void CudaCalcTorchForceKernel::prepareTorchInputs(ContextImpl& context, vector& inputs, map& derivInputs) { int numParticles = cu.getNumAtoms(); // Get pointers to the atomic positions and simulation box void* posData = getTensorPointer(cu, posTensor); @@ -145,11 +149,16 @@ std::vector CudaCalcTorchForceKernel::prepareTorchInputs(Con CHECK_RESULT(cuCtxSynchronize(), "Failed to synchronize the CUDA context"); // Synchronize before switching to the PyTorch context } // Prepare the input of the PyTorch model - vector inputs = {posTensor}; + inputs = {posTensor}; if (usePeriodic) inputs.push_back(boxTensor); - for (const string& name : globalNames) - inputs.push_back(torch::tensor(context.getParameter(name))); + for (const string& name : globalNames) { + bool requiresGrad = (paramDerivs.find(name) != paramDerivs.end()); + torch::Tensor globalTensor = torch::tensor(context.getParameter(name), torch::TensorOptions().requires_grad(requiresGrad)); + inputs.push_back(globalTensor); + if (requiresGrad) + derivInputs[name] = globalTensor; + } return inputs; } @@ -203,7 +212,9 @@ static void executeGraph(bool outputsForces, bool includeForces, torch::jit::scr double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { // Push to the PyTorch context CHECK_RESULT(cuCtxPushCurrent(primaryContext), "Failed to push the CUDA context"); - auto inputs = prepareTorchInputs(context); + vector inputs; + map derivInputs; + prepareTorchInputs(context, inputs, derivInputs); if (!useGraphs) { executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor); } else { @@ -239,7 +250,7 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce } } // Use the same stream as the OpenMM context, even if it is the default stream - const auto openmmStream = cu.getCurrentStream(); + const auto openmmStream = cu.getCurrentStream(); const auto stream = c10::cuda::getStreamFromExternal(openmmStream, cu.getDeviceIndex()); const c10::cuda::CUDAStreamGuard guard(stream); graphs[includeForces].replay(); @@ -247,6 +258,14 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce if (includeForces) { addForces(forceTensor); } + map& energyParamDerivs = cu.getEnergyParamDerivWorkspace(); + for (const string& name : paramDerivs) { + if (!hasComputedBackward) { + energyTensor.backward(); + hasComputedBackward = true; + } + energyParamDerivs[name] += derivInputs[name].grad().item(); + } // Get energy const double energy = energyTensor.item(); // This implicitly synchronizes the PyTorch context // Pop to the PyTorch context diff --git a/platforms/cuda/src/CudaTorchKernels.h b/platforms/cuda/src/CudaTorchKernels.h index 13f2a9b6..bbd46e89 100644 --- a/platforms/cuda/src/CudaTorchKernels.h +++ b/platforms/cuda/src/CudaTorchKernels.h @@ -9,7 +9,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2018-2022 Stanford University and the Authors. * + * Portions copyright (c) 2018-2024 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: Raimondas Galvelis, Raul P. Pelaez * * * @@ -37,6 +37,7 @@ #include "openmm/cuda/CudaArray.h" #include #include +#include namespace TorchPlugin { @@ -72,11 +73,12 @@ class CudaCalcTorchForceKernel : public CalcTorchForceKernel { torch::Tensor posTensor, boxTensor; torch::Tensor energyTensor, forceTensor; std::vector globalNames; + std::set paramDerivs; bool usePeriodic, outputsForces; CUfunction copyInputsKernel, addForcesKernel; CUcontext primaryContext; std::map graphs; - std::vector prepareTorchInputs(OpenMM::ContextImpl& context); + void prepareTorchInputs(OpenMM::ContextImpl& context, std::vector& inputs, std::map& derivInputs); bool useGraphs; void addForces(torch::Tensor& forceTensor); int warmupSteps; diff --git a/platforms/cuda/tests/TestCudaTorchForce.cpp b/platforms/cuda/tests/TestCudaTorchForce.cpp index 8e8c56b9..6c3160f8 100644 --- a/platforms/cuda/tests/TestCudaTorchForce.cpp +++ b/platforms/cuda/tests/TestCudaTorchForce.cpp @@ -6,7 +6,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2018-2022 Stanford University and the Authors. * + * Portions copyright (c) 2018-2024 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -143,6 +143,7 @@ void testGlobal() { } TorchForce* force = new TorchForce("tests/global.pt"); force->addGlobalParameter("k", 2.0); + force->addEnergyParameterDerivative("k"); system.addForce(force); // Compute the forces and energy. @@ -167,12 +168,22 @@ void testGlobal() { // Change the global parameter and see if the forces are still correct. context.setParameter("k", 3.0); - state = context.getState(State::Forces); + state = context.getState(State::Forces | State::ParameterDerivatives); for (int i = 0; i < numParticles; i++) { Vec3 pos = positions[i]; double r = sqrt(pos.dot(pos)); ASSERT_EQUAL_VEC(pos*(-6.0), state.getForces()[i], 1e-5); } + + // Check the gradient of the energy with respect to the parameter. + + double expected = 0.0; + for (int i = 0; i < numParticles; i++) { + Vec3 pos = positions[i]; + expected += pos.dot(pos); + } + double actual = state.getEnergyParameterDerivatives().at("k"); + ASSERT_EQUAL_TOL(expected, actual, 1e-5); } int main(int argc, char* argv[]) { diff --git a/platforms/opencl/src/OpenCLTorchKernels.cpp b/platforms/opencl/src/OpenCLTorchKernels.cpp index a232b1c6..00bb7dd9 100644 --- a/platforms/opencl/src/OpenCLTorchKernels.cpp +++ b/platforms/opencl/src/OpenCLTorchKernels.cpp @@ -6,7 +6,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2018-2022 Stanford University and the Authors. * + * Portions copyright (c) 2018-2024 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -47,6 +47,10 @@ void OpenCLCalcTorchForceKernel::initialize(const System& system, const TorchFor outputsForces = force.getOutputsForces(); for (int i = 0; i < force.getNumGlobalParameters(); i++) globalNames.push_back(force.getGlobalParameterName(i)); + for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) { + paramDerivs.insert(force.getEnergyParameterDerivativeName(i)); + cl.addEnergyParameterDerivative(force.getEnergyParameterDerivativeName(i)); + } int numParticles = system.getNumParticles(); // Inititalize OpenCL objects. @@ -81,8 +85,14 @@ double OpenCLCalcTorchForceKernel::execute(ContextImpl& context, bool includeFor boxTensor = boxTensor.to(torch::kFloat32); inputs.push_back(boxTensor); } - for (const string& name : globalNames) - inputs.push_back(torch::tensor(context.getParameter(name))); + map derivInputs; + for (const string& name : globalNames) { + bool requiresGrad = (paramDerivs.find(name) != paramDerivs.end()); + torch::Tensor globalTensor = torch::tensor(context.getParameter(name), torch::TensorOptions().requires_grad(requiresGrad)); + inputs.push_back(globalTensor); + if (requiresGrad) + derivInputs[name] = globalTensor; + } torch::Tensor energyTensor, forceTensor; if (outputsForces) { auto outputs = module.forward(inputs).toTuple(); @@ -91,10 +101,12 @@ double OpenCLCalcTorchForceKernel::execute(ContextImpl& context, bool includeFor } else energyTensor = module.forward(inputs).toTensor(); + bool hasComputedBackward = false; if (includeForces) { if (!outputsForces) { energyTensor.backward(); forceTensor = posTensor.grad(); + hasComputedBackward = true; } if (cl.getUseDoublePrecision()) { if (!(forceTensor.dtype() == torch::kFloat64)) @@ -115,6 +127,14 @@ double OpenCLCalcTorchForceKernel::execute(ContextImpl& context, bool includeFor addForcesKernel.setArg(4, outputsForces ? 1 : -1); cl.executeKernel(addForcesKernel, numParticles); } + map& energyParamDerivs = cl.getEnergyParamDerivWorkspace(); + for (const string& name : paramDerivs) { + if (!hasComputedBackward) { + energyTensor.backward(); + hasComputedBackward = true; + } + energyParamDerivs[name] += derivInputs[name].grad().item(); + } return energyTensor.item(); } diff --git a/platforms/opencl/src/OpenCLTorchKernels.h b/platforms/opencl/src/OpenCLTorchKernels.h index d5ccdee9..9e84b10d 100644 --- a/platforms/opencl/src/OpenCLTorchKernels.h +++ b/platforms/opencl/src/OpenCLTorchKernels.h @@ -9,7 +9,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2018-2022 Stanford University and the Authors. * + * Portions copyright (c) 2018-2024 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -35,6 +35,7 @@ #include "TorchKernels.h" #include "openmm/opencl/OpenCLContext.h" #include "openmm/opencl/OpenCLArray.h" +#include namespace TorchPlugin { @@ -69,6 +70,7 @@ class OpenCLCalcTorchForceKernel : public CalcTorchForceKernel { OpenMM::OpenCLContext& cl; torch::jit::script::Module module; std::vector globalNames; + std::set paramDerivs; bool usePeriodic, outputsForces; OpenMM::OpenCLArray networkForces; cl::Kernel addForcesKernel; diff --git a/platforms/opencl/tests/TestOpenCLTorchForce.cpp b/platforms/opencl/tests/TestOpenCLTorchForce.cpp index c96d34fd..73a642f0 100644 --- a/platforms/opencl/tests/TestOpenCLTorchForce.cpp +++ b/platforms/opencl/tests/TestOpenCLTorchForce.cpp @@ -6,7 +6,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2018-2022 Stanford University and the Authors. * + * Portions copyright (c) 2018-2024 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -140,6 +140,7 @@ void testGlobal() { } TorchForce* force = new TorchForce("tests/global.pt"); force->addGlobalParameter("k", 2.0); + force->addEnergyParameterDerivative("k"); system.addForce(force); // Compute the forces and energy. @@ -164,12 +165,22 @@ void testGlobal() { // Change the global parameter and see if the forces are still correct. context.setParameter("k", 3.0); - state = context.getState(State::Forces); + state = context.getState(State::Forces | State::ParameterDerivatives); for (int i = 0; i < numParticles; i++) { Vec3 pos = positions[i]; double r = sqrt(pos.dot(pos)); ASSERT_EQUAL_VEC(pos*(-6.0), state.getForces()[i], 1e-5); } + + // Check the gradient of the energy with respect to the parameter. + + double expected = 0.0; + for (int i = 0; i < numParticles; i++) { + Vec3 pos = positions[i]; + expected += pos.dot(pos); + } + double actual = state.getEnergyParameterDerivatives().at("k"); + ASSERT_EQUAL_TOL(expected, actual, 1e-5); } int main(int argc, char* argv[]) { diff --git a/platforms/reference/src/ReferenceTorchKernels.cpp b/platforms/reference/src/ReferenceTorchKernels.cpp index 5de8a2b8..43a9509c 100644 --- a/platforms/reference/src/ReferenceTorchKernels.cpp +++ b/platforms/reference/src/ReferenceTorchKernels.cpp @@ -6,7 +6,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2018-2022 Stanford University and the Authors. * + * Portions copyright (c) 2018-2024 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -54,6 +54,11 @@ static Vec3* extractBoxVectors(ContextImpl& context) { return data->periodicBoxVectors; } +static map& extractEnergyParameterDerivatives(ContextImpl& context) { + ReferencePlatform::PlatformData* data = reinterpret_cast(context.getPlatformData()); + return *data->energyParameterDerivatives; +} + ReferenceCalcTorchForceKernel::~ReferenceCalcTorchForceKernel() { } @@ -63,6 +68,8 @@ void ReferenceCalcTorchForceKernel::initialize(const System& system, const Torch outputsForces = force.getOutputsForces(); for (int i = 0; i < force.getNumGlobalParameters(); i++) globalNames.push_back(force.getGlobalParameterName(i)); + for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) + paramDerivs.insert(force.getEnergyParameterDerivativeName(i)); } double ReferenceCalcTorchForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { @@ -76,8 +83,14 @@ double ReferenceCalcTorchForceKernel::execute(ContextImpl& context, bool include torch::Tensor boxTensor = torch::from_blob(box, {3, 3}, torch::kFloat64); inputs.push_back(boxTensor); } - for (const string& name : globalNames) - inputs.push_back(torch::tensor(context.getParameter(name))); + map derivInputs; + for (const string& name : globalNames) { + bool requiresGrad = (paramDerivs.find(name) != paramDerivs.end()); + torch::Tensor globalTensor = torch::tensor(context.getParameter(name), torch::TensorOptions().dtype(torch::kFloat64).requires_grad(requiresGrad)); + inputs.push_back(globalTensor); + if (requiresGrad) + derivInputs[name] = globalTensor; + } torch::Tensor energyTensor, forceTensor; if (outputsForces) { auto outputs = module.forward(inputs).toTuple(); @@ -86,10 +99,12 @@ double ReferenceCalcTorchForceKernel::execute(ContextImpl& context, bool include } else energyTensor = module.forward(inputs).toTensor(); + bool hasComputedBackward = false; if (includeForces) { if (!outputsForces) { energyTensor.backward(); forceTensor = posTensor.grad(); + hasComputedBackward = true; } if (!(forceTensor.dtype() == torch::kFloat64)) forceTensor = forceTensor.to(torch::kFloat64); @@ -99,5 +114,13 @@ double ReferenceCalcTorchForceKernel::execute(ContextImpl& context, bool include for (int j = 0; j < 3; j++) force[i][j] += forceSign*outputForces[3*i+j]; } + map& energyParamDerivs = extractEnergyParameterDerivatives(context); + for (const string& name : paramDerivs) { + if (!hasComputedBackward) { + energyTensor.backward(); + hasComputedBackward = true; + } + energyParamDerivs[name] += derivInputs[name].grad().item(); + } return energyTensor.item(); } diff --git a/platforms/reference/src/ReferenceTorchKernels.h b/platforms/reference/src/ReferenceTorchKernels.h index f3abc1d2..4d080ab2 100644 --- a/platforms/reference/src/ReferenceTorchKernels.h +++ b/platforms/reference/src/ReferenceTorchKernels.h @@ -34,6 +34,7 @@ #include "TorchKernels.h" #include "openmm/Platform.h" +#include #include namespace TorchPlugin { @@ -67,6 +68,7 @@ class ReferenceCalcTorchForceKernel : public CalcTorchForceKernel { torch::jit::script::Module module; std::vector positions, boxVectors; std::vector globalNames; + std::set paramDerivs; bool usePeriodic, outputsForces; }; diff --git a/platforms/reference/tests/TestReferenceTorchForce.cpp b/platforms/reference/tests/TestReferenceTorchForce.cpp index e4c09b31..63dece2c 100644 --- a/platforms/reference/tests/TestReferenceTorchForce.cpp +++ b/platforms/reference/tests/TestReferenceTorchForce.cpp @@ -140,6 +140,7 @@ void testGlobal() { } TorchForce* force = new TorchForce("tests/global.pt"); force->addGlobalParameter("k", 2.0); + force->addEnergyParameterDerivative("k"); system.addForce(force); // Compute the forces and energy. @@ -164,12 +165,22 @@ void testGlobal() { // Change the global parameter and see if the forces are still correct. context.setParameter("k", 3.0); - state = context.getState(State::Forces); + state = context.getState(State::Forces | State::ParameterDerivatives); for (int i = 0; i < numParticles; i++) { Vec3 pos = positions[i]; double r = sqrt(pos.dot(pos)); ASSERT_EQUAL_VEC(pos*(-6.0), state.getForces()[i], 1e-5); } + + // Check the gradient of the energy with respect to the parameter. + + double expected = 0.0; + for (int i = 0; i < numParticles; i++) { + Vec3 pos = positions[i]; + expected += pos.dot(pos); + } + double actual = state.getEnergyParameterDerivatives().at("k"); + ASSERT_EQUAL_TOL(expected, actual, 1e-5); } int main() { From f751c654020489fc99426593ca62053a34eb0d87 Mon Sep 17 00:00:00 2001 From: Peter Eastman Date: Wed, 15 May 2024 10:31:12 -0700 Subject: [PATCH 2/6] Fixes to CUDA implemenetation --- platforms/cuda/src/CudaTorchKernels.cpp | 47 +++++++++++++------------ 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/platforms/cuda/src/CudaTorchKernels.cpp b/platforms/cuda/src/CudaTorchKernels.cpp index 8d8fa066..640e58d7 100644 --- a/platforms/cuda/src/CudaTorchKernels.cpp +++ b/platforms/cuda/src/CudaTorchKernels.cpp @@ -159,7 +159,6 @@ void CudaCalcTorchForceKernel::prepareTorchInputs(ContextImpl& context, vector& inputs, torch::Tensor& posTensor, torch::Tensor& energyTensor, - torch::Tensor& forceTensor) { + torch::Tensor& forceTensor, map& derivInputs) { + vector gradInputs; + if (!outputsForces) + gradInputs.push_back(posTensor); + for (auto& deriv : derivInputs) + gradInputs.push_back(deriv.second); + auto none = torch::Tensor(); if (outputsForces) { auto outputs = module.forward(inputs).toTuple(); energyTensor = outputs->elements()[0].toTensor(); forceTensor = outputs->elements()[1].toTensor(); + if (gradInputs.size() > 0) + energyTensor.backward(none, false, false, gradInputs); } else { energyTensor = module.forward(inputs).toTensor(); // Compute force by backpropagating the PyTorch model if (includeForces) { // CUDA graph capture sometimes fails if backwards is not explicitly requested w.r.t positions - // See https://github.com/openmm/openmm-torch/pull/120/ - auto none = torch::Tensor(); - energyTensor.backward(none, false, false, posTensor); - // This is minus the forces, we change the sign later on + // See https://github.com/openmm/openmm-torch/pull/120/ + energyTensor.backward(none, false, false, gradInputs); + // This is minus the forces, we change the sign later on forceTensor = posTensor.grad().clone(); // Zero the gradient to avoid accumulating it posTensor.grad().zero_(); @@ -216,14 +222,14 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce map derivInputs; prepareTorchInputs(context, inputs, derivInputs); if (!useGraphs) { - executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor); + executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor, derivInputs); } else { // Record graph if not already done bool is_graph_captured = false; if (graphs.find(includeForces) == graphs.end()) { //CUDA graph capture must occur in a non-default stream const auto stream = c10::cuda::getStreamFromPool(false, cu.getDeviceIndex()); - const c10::cuda::CUDAStreamGuard guard(stream); + const c10::cuda::CUDAStreamGuard guard(stream); // Warmup the graph workload before capturing. This first // run before capture sets up allocations so that no // allocations are needed after. Pytorch's allocator is @@ -231,14 +237,14 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce // record static pointers and shapes during capture. try { for (int i = 0; i < this->warmupSteps; i++) - executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor); + executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor, derivInputs); } catch (std::exception& e) { throw OpenMMException(string("TorchForce Failed to warmup the model before graph construction. Torch reported the following error:\n") + e.what()); } graphs[includeForces].capture_begin(); try { - executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor); + executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor, derivInputs); is_graph_captured = true; graphs[includeForces].capture_end(); } @@ -249,23 +255,18 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce throw OpenMMException(string("TorchForce Failed to capture the model into a CUDA graph. Torch reported the following error:\n") + e.what()); } } - // Use the same stream as the OpenMM context, even if it is the default stream - const auto openmmStream = cu.getCurrentStream(); - const auto stream = c10::cuda::getStreamFromExternal(openmmStream, cu.getDeviceIndex()); - const c10::cuda::CUDAStreamGuard guard(stream); + // Use the same stream as the OpenMM context, even if it is the default stream + const auto openmmStream = cu.getCurrentStream(); + const auto stream = c10::cuda::getStreamFromExternal(openmmStream, cu.getDeviceIndex()); + const c10::cuda::CUDAStreamGuard guard(stream); graphs[includeForces].replay(); } if (includeForces) { addForces(forceTensor); } map& energyParamDerivs = cu.getEnergyParamDerivWorkspace(); - for (const string& name : paramDerivs) { - if (!hasComputedBackward) { - energyTensor.backward(); - hasComputedBackward = true; - } + for (const string& name : paramDerivs) energyParamDerivs[name] += derivInputs[name].grad().item(); - } // Get energy const double energy = energyTensor.item(); // This implicitly synchronizes the PyTorch context // Pop to the PyTorch context From 0fc454b37f6bc1915aba59637766b87860d501ff Mon Sep 17 00:00:00 2001 From: Peter Eastman Date: Wed, 15 May 2024 10:35:37 -0700 Subject: [PATCH 3/6] Python interface for parameter derivatives --- python/openmmtorch.i | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/openmmtorch.i b/python/openmmtorch.i index 05988e3d..558ce0fd 100644 --- a/python/openmmtorch.i +++ b/python/openmmtorch.i @@ -69,11 +69,14 @@ public: void setOutputsForces(bool); bool getOutputsForces() const; int getNumGlobalParameters() const; + int getNumEnergyParameterDerivatives() const; int addGlobalParameter(const std::string& name, double defaultValue); const std::string& getGlobalParameterName(int index) const; void setGlobalParameterName(int index, const std::string& name); double getGlobalParameterDefaultValue(int index) const; void setGlobalParameterDefaultValue(int index, double defaultValue); + void addEnergyParameterDerivative(const std::string& name); + const std::string& getEnergyParameterDerivativeName(int index) const; void setProperty(const std::string& name, const std::string& value); const std::map& getProperties() const; From e801c10e24cf2826c7ba5aaec87840b90eae395d Mon Sep 17 00:00:00 2001 From: Peter Eastman Date: Wed, 15 May 2024 10:46:09 -0700 Subject: [PATCH 4/6] Serialization of parameter derivatives --- serialization/src/TorchForceProxy.cpp | 15 ++++++++++----- serialization/tests/TestSerializeTorchForce.cpp | 10 +++++++--- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/serialization/src/TorchForceProxy.cpp b/serialization/src/TorchForceProxy.cpp index 59ee977f..6d9d3122 100644 --- a/serialization/src/TorchForceProxy.cpp +++ b/serialization/src/TorchForceProxy.cpp @@ -6,7 +6,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2018-2022 Stanford University and the Authors. * + * Portions copyright (c) 2018-2024 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -74,7 +74,7 @@ TorchForceProxy::TorchForceProxy() : SerializationProxy("TorchForce") { } void TorchForceProxy::serialize(const void* object, SerializationNode& node) const { - node.setIntProperty("version", 2); + node.setIntProperty("version", 3); const TorchForce& force = *reinterpret_cast(object); node.setStringProperty("file", force.getFile()); try { @@ -90,14 +90,16 @@ void TorchForceProxy::serialize(const void* object, SerializationNode& node) con node.setBoolProperty("usesPeriodic", force.usesPeriodicBoundaryConditions()); node.setBoolProperty("outputsForces", force.getOutputsForces()); SerializationNode& globalParams = node.createChildNode("GlobalParameters"); - for (int i = 0; i < force.getNumGlobalParameters(); i++) { + for (int i = 0; i < force.getNumGlobalParameters(); i++) globalParams.createChildNode("Parameter").setStringProperty("name", force.getGlobalParameterName(i)).setDoubleProperty("default", force.getGlobalParameterDefaultValue(i)); - } + SerializationNode& paramDerivs = node.createChildNode("ParameterDerivatives"); + for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) + paramDerivs.createChildNode("Parameter").setStringProperty("name", force.getEnergyParameterDerivativeName(i)); } void* TorchForceProxy::deserialize(const SerializationNode& node) const { int storedVersion = node.getIntProperty("version"); - if (storedVersion > 2) + if (storedVersion > 3) throw OpenMMException("Unsupported version number"); TorchForce* force; if (storedVersion == 1) { @@ -121,6 +123,9 @@ void* TorchForceProxy::deserialize(const SerializationNode& node) const { if (child.getName() == "GlobalParameters") for (auto& parameter : child.getChildren()) force->addGlobalParameter(parameter.getStringProperty("name"), parameter.getDoubleProperty("default")); + if (child.getName() == "ParameterDerivatives") + for (auto& parameter : child.getChildren()) + force->addEnergyParameterDerivative(parameter.getStringProperty("name")); } return force; } diff --git a/serialization/tests/TestSerializeTorchForce.cpp b/serialization/tests/TestSerializeTorchForce.cpp index eb5f65cd..410f00a6 100644 --- a/serialization/tests/TestSerializeTorchForce.cpp +++ b/serialization/tests/TestSerializeTorchForce.cpp @@ -6,7 +6,7 @@ * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * - * Portions copyright (c) 2018-2022 Stanford University and the Authors. * + * Portions copyright (c) 2018-2024 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * @@ -50,6 +50,7 @@ void serializeAndDeserialize(TorchForce force) { force.addGlobalParameter("y", 2.221); force.setUsesPeriodicBoundaryConditions(true); force.setOutputsForces(true); + force.addEnergyParameterDerivative("y"); // Serialize and then deserialize it. @@ -73,17 +74,20 @@ void serializeAndDeserialize(TorchForce force) { } ASSERT_EQUAL(force.usesPeriodicBoundaryConditions(), force2.usesPeriodicBoundaryConditions()); ASSERT_EQUAL(force.getOutputsForces(), force2.getOutputsForces()); + ASSERT_EQUAL(force.getNumEnergyParameterDerivatives(), force2.getNumEnergyParameterDerivatives()); + for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) + ASSERT_EQUAL(force.getEnergyParameterDerivativeName(i), force2.getEnergyParameterDerivativeName(i)); } void testSerializationFromModule() { - string fileName = "../../tests/forces.pt"; + string fileName = "tests/forces.pt"; torch::jit::Module module = torch::jit::load(fileName); TorchForce force(module); serializeAndDeserialize(force); } void testSerializationFromFile() { - string fileName = "../../tests/forces.pt"; + string fileName = "tests/forces.pt"; TorchForce force(fileName); serializeAndDeserialize(force); } From fde51ffaad1af2c67a257918f8c65ee7baf1a30b Mon Sep 17 00:00:00 2001 From: peastman Date: Wed, 15 May 2024 10:54:59 -0700 Subject: [PATCH 5/6] Improvements to tests --- platforms/cuda/tests/TestCudaTorchForce.cpp | 23 ++++++++++--------- .../opencl/tests/TestOpenCLTorchForce.cpp | 23 ++++++++++--------- .../tests/TestReferenceTorchForce.cpp | 23 ++++++++++--------- 3 files changed, 36 insertions(+), 33 deletions(-) diff --git a/platforms/cuda/tests/TestCudaTorchForce.cpp b/platforms/cuda/tests/TestCudaTorchForce.cpp index 6c3160f8..5b8332ef 100644 --- a/platforms/cuda/tests/TestCudaTorchForce.cpp +++ b/platforms/cuda/tests/TestCudaTorchForce.cpp @@ -152,7 +152,7 @@ void testGlobal() { Platform& platform = Platform::getPlatformByName("CUDA"); Context context(system, integ, platform); context.setPositions(positions); - State state = context.getState(State::Energy | State::Forces); + State state = context.getState(State::Energy | State::Forces | State::ParameterDerivatives); // See if the energy is correct. The network defines a potential of the form E(r) = k*|r|^2 @@ -165,16 +165,6 @@ void testGlobal() { } ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5); - // Change the global parameter and see if the forces are still correct. - - context.setParameter("k", 3.0); - state = context.getState(State::Forces | State::ParameterDerivatives); - for (int i = 0; i < numParticles; i++) { - Vec3 pos = positions[i]; - double r = sqrt(pos.dot(pos)); - ASSERT_EQUAL_VEC(pos*(-6.0), state.getForces()[i], 1e-5); - } - // Check the gradient of the energy with respect to the parameter. double expected = 0.0; @@ -184,6 +174,17 @@ void testGlobal() { } double actual = state.getEnergyParameterDerivatives().at("k"); ASSERT_EQUAL_TOL(expected, actual, 1e-5); + + // Change the global parameter and see if the forces are still correct. + + context.setParameter("k", 3.0); + state = context.getState(State::Forces | State::ParameterDerivatives); + for (int i = 0; i < numParticles; i++) { + Vec3 pos = positions[i]; + double r = sqrt(pos.dot(pos)); + ASSERT_EQUAL_VEC(pos*(-6.0), state.getForces()[i], 1e-5); + } + ASSERT_EQUAL_TOL(expected, state.getEnergyParameterDerivatives().at("k"), 1e-5); } int main(int argc, char* argv[]) { diff --git a/platforms/opencl/tests/TestOpenCLTorchForce.cpp b/platforms/opencl/tests/TestOpenCLTorchForce.cpp index 73a642f0..d42d68a3 100644 --- a/platforms/opencl/tests/TestOpenCLTorchForce.cpp +++ b/platforms/opencl/tests/TestOpenCLTorchForce.cpp @@ -149,7 +149,7 @@ void testGlobal() { Platform& platform = Platform::getPlatformByName("OpenCL"); Context context(system, integ, platform); context.setPositions(positions); - State state = context.getState(State::Energy | State::Forces); + State state = context.getState(State::Energy | State::Forces | State::ParameterDerivatives); // See if the energy is correct. The network defines a potential of the form E(r) = k*|r|^2 @@ -162,16 +162,6 @@ void testGlobal() { } ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5); - // Change the global parameter and see if the forces are still correct. - - context.setParameter("k", 3.0); - state = context.getState(State::Forces | State::ParameterDerivatives); - for (int i = 0; i < numParticles; i++) { - Vec3 pos = positions[i]; - double r = sqrt(pos.dot(pos)); - ASSERT_EQUAL_VEC(pos*(-6.0), state.getForces()[i], 1e-5); - } - // Check the gradient of the energy with respect to the parameter. double expected = 0.0; @@ -181,6 +171,17 @@ void testGlobal() { } double actual = state.getEnergyParameterDerivatives().at("k"); ASSERT_EQUAL_TOL(expected, actual, 1e-5); + + // Change the global parameter and see if the forces are still correct. + + context.setParameter("k", 3.0); + state = context.getState(State::Forces | State::ParameterDerivatives); + for (int i = 0; i < numParticles; i++) { + Vec3 pos = positions[i]; + double r = sqrt(pos.dot(pos)); + ASSERT_EQUAL_VEC(pos*(-6.0), state.getForces()[i], 1e-5); + } + ASSERT_EQUAL_TOL(expected, state.getEnergyParameterDerivatives().at("k"), 1e-5); } int main(int argc, char* argv[]) { diff --git a/platforms/reference/tests/TestReferenceTorchForce.cpp b/platforms/reference/tests/TestReferenceTorchForce.cpp index 63dece2c..e8885cff 100644 --- a/platforms/reference/tests/TestReferenceTorchForce.cpp +++ b/platforms/reference/tests/TestReferenceTorchForce.cpp @@ -149,7 +149,7 @@ void testGlobal() { Platform& platform = Platform::getPlatformByName("Reference"); Context context(system, integ, platform); context.setPositions(positions); - State state = context.getState(State::Energy | State::Forces); + State state = context.getState(State::Energy | State::Forces | State::ParameterDerivatives); // See if the energy is correct. The network defines a potential of the form E(r) = k*|r|^2 @@ -162,16 +162,6 @@ void testGlobal() { } ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5); - // Change the global parameter and see if the forces are still correct. - - context.setParameter("k", 3.0); - state = context.getState(State::Forces | State::ParameterDerivatives); - for (int i = 0; i < numParticles; i++) { - Vec3 pos = positions[i]; - double r = sqrt(pos.dot(pos)); - ASSERT_EQUAL_VEC(pos*(-6.0), state.getForces()[i], 1e-5); - } - // Check the gradient of the energy with respect to the parameter. double expected = 0.0; @@ -181,6 +171,17 @@ void testGlobal() { } double actual = state.getEnergyParameterDerivatives().at("k"); ASSERT_EQUAL_TOL(expected, actual, 1e-5); + + // Change the global parameter and see if the forces are still correct. + + context.setParameter("k", 3.0); + state = context.getState(State::Forces | State::ParameterDerivatives); + for (int i = 0; i < numParticles; i++) { + Vec3 pos = positions[i]; + double r = sqrt(pos.dot(pos)); + ASSERT_EQUAL_VEC(pos*(-6.0), state.getForces()[i], 1e-5); + } + ASSERT_EQUAL_TOL(expected, state.getEnergyParameterDerivatives().at("k"), 1e-5); } int main() { From 1d83b867a10fde552aa93018ebfe3cbb474e15a0 Mon Sep 17 00:00:00 2001 From: Peter Eastman Date: Thu, 16 May 2024 13:18:43 -0700 Subject: [PATCH 6/6] Handle CUDA graphs correctly --- platforms/cuda/src/CudaTorchKernels.cpp | 47 ++++++++++++--------- platforms/cuda/src/CudaTorchKernels.h | 1 + platforms/cuda/tests/TestCudaTorchForce.cpp | 6 ++- 3 files changed, 32 insertions(+), 22 deletions(-) diff --git a/platforms/cuda/src/CudaTorchKernels.cpp b/platforms/cuda/src/CudaTorchKernels.cpp index 640e58d7..c6dc790a 100644 --- a/platforms/cuda/src/CudaTorchKernels.cpp +++ b/platforms/cuda/src/CudaTorchKernels.cpp @@ -85,6 +85,8 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce boxTensor = torch::empty({3, 3}, options); energyTensor = torch::empty({0}, options); forceTensor = torch::empty({0}, options); + for (const string& name : globalNames) + globalTensors[name] = torch::tensor({0}, options); // Pop the PyToch context CUcontext ctx; CHECK_RESULT(cuCtxPopCurrent(&ctx), "Failed to pop the CUDA context"); @@ -129,7 +131,7 @@ static void* getTensorPointer(OpenMM::CudaContext& cu, torch::Tensor& tensor) { /** * Prepare the inputs for the PyTorch model, copying positions from the OpenMM context. */ -void CudaCalcTorchForceKernel::prepareTorchInputs(ContextImpl& context, vector& inputs, map& derivInputs) { +void CudaCalcTorchForceKernel::prepareTorchInputs(ContextImpl& context, vector& inputs, map& globalTensors) { int numParticles = cu.getNumAtoms(); // Get pointers to the atomic positions and simulation box void* posData = getTensorPointer(cu, posTensor); @@ -153,11 +155,12 @@ void CudaCalcTorchForceKernel::prepareTorchInputs(ContextImpl& context, vector& inputs, torch::Tensor& posTensor, torch::Tensor& energyTensor, - torch::Tensor& forceTensor, map& derivInputs) { + torch::Tensor& forceTensor, map& globalTensors, set paramDerivs) { vector gradInputs; - if (!outputsForces) + if (!outputsForces && includeForces) gradInputs.push_back(posTensor); - for (auto& deriv : derivInputs) - gradInputs.push_back(deriv.second); + for (auto& name : paramDerivs) + gradInputs.push_back(globalTensors[name]); auto none = torch::Tensor(); if (outputsForces) { auto outputs = module.forward(inputs).toTuple(); @@ -203,10 +206,11 @@ static void executeGraph(bool outputsForces, bool includeForces, torch::jit::scr } else { energyTensor = module.forward(inputs).toTensor(); // Compute force by backpropagating the PyTorch model - if (includeForces) { - // CUDA graph capture sometimes fails if backwards is not explicitly requested w.r.t positions - // See https://github.com/openmm/openmm-torch/pull/120/ + // CUDA graph capture sometimes fails if backwards is not explicitly requested w.r.t positions + // See https://github.com/openmm/openmm-torch/pull/120/ + if (gradInputs.size() > 0) energyTensor.backward(none, false, false, gradInputs); + if (includeForces) { // This is minus the forces, we change the sign later on forceTensor = posTensor.grad().clone(); // Zero the gradient to avoid accumulating it @@ -219,10 +223,9 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce // Push to the PyTorch context CHECK_RESULT(cuCtxPushCurrent(primaryContext), "Failed to push the CUDA context"); vector inputs; - map derivInputs; - prepareTorchInputs(context, inputs, derivInputs); + prepareTorchInputs(context, inputs, globalTensors); if (!useGraphs) { - executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor, derivInputs); + executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor, globalTensors, paramDerivs); } else { // Record graph if not already done bool is_graph_captured = false; @@ -237,14 +240,14 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce // record static pointers and shapes during capture. try { for (int i = 0; i < this->warmupSteps; i++) - executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor, derivInputs); + executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor, globalTensors, paramDerivs); } catch (std::exception& e) { throw OpenMMException(string("TorchForce Failed to warmup the model before graph construction. Torch reported the following error:\n") + e.what()); } graphs[includeForces].capture_begin(); try { - executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor, derivInputs); + executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor, globalTensors, paramDerivs); is_graph_captured = true; graphs[includeForces].capture_end(); } @@ -254,6 +257,8 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce } throw OpenMMException(string("TorchForce Failed to capture the model into a CUDA graph. Torch reported the following error:\n") + e.what()); } + for (const string& name : paramDerivs) + globalTensors[name].grad().zero_(); } // Use the same stream as the OpenMM context, even if it is the default stream const auto openmmStream = cu.getCurrentStream(); @@ -265,8 +270,10 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce addForces(forceTensor); } map& energyParamDerivs = cu.getEnergyParamDerivWorkspace(); - for (const string& name : paramDerivs) - energyParamDerivs[name] += derivInputs[name].grad().item(); + for (const string& name : paramDerivs) { + energyParamDerivs[name] += globalTensors[name].grad().item(); + globalTensors[name].grad().zero_(); + } // Get energy const double energy = energyTensor.item(); // This implicitly synchronizes the PyTorch context // Pop to the PyTorch context diff --git a/platforms/cuda/src/CudaTorchKernels.h b/platforms/cuda/src/CudaTorchKernels.h index bbd46e89..1f7b4c97 100644 --- a/platforms/cuda/src/CudaTorchKernels.h +++ b/platforms/cuda/src/CudaTorchKernels.h @@ -72,6 +72,7 @@ class CudaCalcTorchForceKernel : public CalcTorchForceKernel { torch::jit::script::Module module; torch::Tensor posTensor, boxTensor; torch::Tensor energyTensor, forceTensor; + std::map globalTensors; std::vector globalNames; std::set paramDerivs; bool usePeriodic, outputsForces; diff --git a/platforms/cuda/tests/TestCudaTorchForce.cpp b/platforms/cuda/tests/TestCudaTorchForce.cpp index 5b8332ef..fcc756f8 100644 --- a/platforms/cuda/tests/TestCudaTorchForce.cpp +++ b/platforms/cuda/tests/TestCudaTorchForce.cpp @@ -129,7 +129,7 @@ void testPeriodicForce() { ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5); } -void testGlobal() { +void testGlobal(bool useGraphs) { // Create a random cloud of particles. const int numParticles = 10; @@ -144,6 +144,7 @@ void testGlobal() { TorchForce* force = new TorchForce("tests/global.pt"); force->addGlobalParameter("k", 2.0); force->addEnergyParameterDerivative("k"); + force->setProperty("useCUDAGraphs", useGraphs ? "true" : "false"); system.addForce(force); // Compute the forces and energy. @@ -195,7 +196,8 @@ int main(int argc, char* argv[]) { testForce(false); testForce(true); testPeriodicForce(); - testGlobal(); + testGlobal(false); + testGlobal(true); } catch(const std::exception& e) { std::cout << "exception: " << e.what() << std::endl;