Can compute parameter derivatives #143

Merged 6 commits on May 22, 2024
24 changes: 23 additions & 1 deletion openmmapi/include/TorchForce.h
@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2018-2022 Stanford University and the Authors. *
* Portions copyright (c) 2018-2024 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: Raimondas Galvelis, Raul P. Pelaez *
* *
@@ -36,6 +36,7 @@
#include "openmm/Force.h"
#include <map>
#include <string>
#include <vector>
#include <torch/torch.h>
#include "internal/windowsExportTorch.h"

@@ -106,6 +107,11 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
* Get the number of global parameters that the interaction depends on.
*/
int getNumGlobalParameters() const;
/**
* Get the number of global parameters with respect to which the derivative of the energy
* should be computed.
*/
int getNumEnergyParameterDerivatives() const;
/**
* Add a new global parameter that the interaction may depend on. The default value provided to
* this method is the initial value of the parameter in newly created Contexts. You can change
@@ -144,6 +150,21 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
* @param defaultValue the default value of the parameter
*/
void setGlobalParameterDefaultValue(int index, double defaultValue);
/**
* Request that this Force compute the derivative of its energy with respect to a global parameter.
* The parameter must have already been added with addGlobalParameter().
*
* @param name the name of the parameter
*/
void addEnergyParameterDerivative(const std::string& name);
/**
* Get the name of a global parameter with respect to which this Force should compute the
* derivative of the energy.
*
* @param index the index of the parameter derivative, between 0 and getNumEnergyParameterDerivatives()
* @return the parameter name
*/
const std::string& getEnergyParameterDerivativeName(int index) const;
/**
* Set a value of a property.
*
@@ -163,6 +184,7 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
std::string file;
bool usePeriodic, outputsForces;
std::vector<GlobalParameterInfo> globalParameters;
std::vector<int> energyParameterDerivatives;
torch::jit::Module module;
std::map<std::string, std::string> properties;
std::string emptyProperty;
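
A minimal usage sketch of the new API (illustrative, not part of this diff). It mirrors the test added later in this PR; the model file name, the wrapper function, and the parameter value are placeholders.

#include "TorchForce.h"
#include "openmm/Context.h"
#include "openmm/Platform.h"
#include "openmm/State.h"
#include "openmm/System.h"
#include "openmm/Vec3.h"
#include "openmm/VerletIntegrator.h"
#include <vector>
using namespace OpenMM;
using namespace TorchPlugin;

// Hypothetical helper: request dE/dk for a global parameter "k" of a TorchForce.
double energyDerivativeExample(System& system, const std::vector<Vec3>& positions) {
    TorchForce* force = new TorchForce("model.pt");   // placeholder model file
    force->addGlobalParameter("k", 2.0);              // parameter passed to the model as an input
    force->addEnergyParameterDerivative("k");         // ask this Force to compute dE/dk
    system.addForce(force);
    VerletIntegrator integrator(0.001);
    Context context(system, integrator, Platform::getPlatformByName("CUDA"));
    context.setPositions(positions);
    // ParameterDerivatives must be requested when retrieving the State.
    State state = context.getState(State::Energy | State::ParameterDerivatives);
    return state.getEnergyParameterDerivatives().at("k");
}
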
20 changes: 19 additions & 1 deletion openmmapi/src/TorchForce.cpp
@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2018-2022 Stanford University and the Authors. *
* Portions copyright (c) 2018-2024 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: Raimondas Galvelis, Raul P. Pelaez *
* *
@@ -112,6 +112,24 @@ void TorchForce::setGlobalParameterDefaultValue(int index, double defaultValue)
globalParameters[index].defaultValue = defaultValue;
}

int TorchForce::getNumEnergyParameterDerivatives() const {
return energyParameterDerivatives.size();
}

void TorchForce::addEnergyParameterDerivative(const string& name) {
for (int i = 0; i < globalParameters.size(); i++)
if (name == globalParameters[i].name) {
energyParameterDerivatives.push_back(i);
Review comment (Contributor):
I have seen this in other parts of OpenMM, what happens if I call this function twice with the same name? Is that handled somewhere before or after this?

return;
}
throw OpenMMException(string("addEnergyParameterDerivative: Unknown global parameter '"+name+"'"));
}

const string& TorchForce::getEnergyParameterDerivativeName(int index) const {
ASSERT_VALID_INDEX(index, energyParameterDerivatives);
return globalParameters[energyParameterDerivatives[index]].name;
}

void TorchForce::setProperty(const std::string& name, const std::string& value) {
if (properties.find(name) == properties.end())
throw OpenMMException("TorchForce: Unknown property '" + name + "'");
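
The CUDA kernel changes below obtain these derivatives by feeding each requested global parameter to the model as a tensor with requires_grad enabled, restricting backward() to that set of input tensors, and reading each parameter's .grad(). A standalone libtorch sketch of that pattern (illustrative only, not code from this PR; the toy potential has the same form, E = k*sum|r|^2, as tests/global.pt):

#include <torch/torch.h>
#include <iostream>
#include <vector>

int main() {
    // Positions and a global parameter "k"; both are requested as gradient inputs.
    torch::Tensor pos = torch::rand({10, 3}, torch::requires_grad());
    torch::Tensor k = torch::tensor({2.0}, torch::requires_grad());
    // Toy potential E = k * sum_i |r_i|^2.
    torch::Tensor energy = (k * (pos * pos).sum()).sum();
    // Backpropagate only with respect to the requested tensors, mirroring
    // energyTensor.backward(none, false, false, gradInputs) in executeGraph().
    std::vector<torch::Tensor> gradInputs = {pos, k};
    energy.backward(torch::Tensor(), /*retain_graph=*/false, /*create_graph=*/false, gradInputs);
    torch::Tensor forces = -pos.grad();       // forces are minus the position gradient
    double dEdk = k.grad().item<double>();    // equals sum_i |r_i|^2
    std::cout << "dE/dk = " << dEdk << std::endl;
    return 0;
}
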
73 changes: 50 additions & 23 deletions platforms/cuda/src/CudaTorchKernels.cpp
@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2018-2022 Stanford University and the Authors. *
* Portions copyright (c) 2018-2024 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: Raimondas Galvelis, Raul P. Pelaez *
* *
@@ -66,6 +66,10 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce
outputsForces = force.getOutputsForces();
for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalNames.push_back(force.getGlobalParameterName(i));
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
paramDerivs.insert(force.getEnergyParameterDerivativeName(i));
cu.addEnergyParameterDerivative(force.getEnergyParameterDerivativeName(i));
}
int numParticles = system.getNumParticles();

// Push the PyTorch context
@@ -81,6 +85,8 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce
boxTensor = torch::empty({3, 3}, options);
energyTensor = torch::empty({0}, options);
forceTensor = torch::empty({0}, options);
for (const string& name : globalNames)
globalTensors[name] = torch::tensor({0}, options);
// Pop the PyToch context
CUcontext ctx;
CHECK_RESULT(cuCtxPopCurrent(&ctx), "Failed to pop the CUDA context");
@@ -125,7 +131,7 @@ static void* getTensorPointer(OpenMM::CudaContext& cu, torch::Tensor& tensor) {
/**
* Prepare the inputs for the PyTorch model, copying positions from the OpenMM context.
*/
std::vector<torch::jit::IValue> CudaCalcTorchForceKernel::prepareTorchInputs(ContextImpl& context) {
void CudaCalcTorchForceKernel::prepareTorchInputs(ContextImpl& context, vector<torch::jit::IValue>& inputs, map<string, torch::Tensor>& globalTensors) {
int numParticles = cu.getNumAtoms();
// Get pointers to the atomic positions and simulation box
void* posData = getTensorPointer(cu, posTensor);
@@ -145,12 +151,17 @@ std::vector<torch::jit::IValue> CudaCalcTorchForceKernel::prepareTorchInputs(Con
CHECK_RESULT(cuCtxSynchronize(), "Failed to synchronize the CUDA context"); // Synchronize before switching to the PyTorch context
}
// Prepare the input of the PyTorch model
vector<torch::jit::IValue> inputs = {posTensor};
inputs = {posTensor};
if (usePeriodic)
inputs.push_back(boxTensor);
for (const string& name : globalNames)
inputs.push_back(torch::tensor(context.getParameter(name)));
return inputs;
for (const string& name : globalNames) {
// In-place assignment to a leaf tensor that requires grad is not allowed, so
// clear the flag before updating the value, then re-enable it below if a derivative was requested.
globalTensors[name].set_requires_grad(false);
globalTensors[name][0] = context.getParameter(name);
if (paramDerivs.find(name) != paramDerivs.end())
globalTensors[name].set_requires_grad(true);
inputs.push_back(globalTensors[name]);
}
}

/**
@@ -173,26 +184,34 @@ void CudaCalcTorchForceKernel::addForces(torch::Tensor& forceTensor) {
}

/**
* This function launches the workload in a way compatible with CUDA
* graphs as far as OpenMM-Torch goes. Capturing this function when
* the model is not itself graph compatible (due to, for instance,
* implicit synchronizations) will result in a CUDA error.
*/
static void executeGraph(bool outputsForces, bool includeForces, torch::jit::script::Module& module, vector<torch::jit::IValue>& inputs, torch::Tensor& posTensor, torch::Tensor& energyTensor,
torch::Tensor& forceTensor) {
torch::Tensor& forceTensor, map<string, torch::Tensor>& globalTensors, set<string> paramDerivs) {
vector<torch::Tensor> gradInputs;
if (!outputsForces && includeForces)
gradInputs.push_back(posTensor);
for (auto& name : paramDerivs)
gradInputs.push_back(globalTensors[name]);
auto none = torch::Tensor();
if (outputsForces) {
auto outputs = module.forward(inputs).toTuple();
energyTensor = outputs->elements()[0].toTensor();
forceTensor = outputs->elements()[1].toTensor();
if (gradInputs.size() > 0)
energyTensor.backward(none, false, false, gradInputs);
} else {
energyTensor = module.forward(inputs).toTensor();
// Compute force by backpropagating the PyTorch model
// CUDA graph capture sometimes fails if backwards is not explicitly requested w.r.t positions
// See https://github.com/openmm/openmm-torch/pull/120/
if (gradInputs.size() > 0)
energyTensor.backward(none, false, false, gradInputs);
if (includeForces) {
// CUDA graph capture sometimes fails if backwards is not explicitly requested w.r.t positions
// See https://github.com/openmm/openmm-torch/pull/120/
auto none = torch::Tensor();
energyTensor.backward(none, false, false, posTensor);
// This is minus the forces, we change the sign later on
forceTensor = posTensor.grad().clone();
// Zero the gradient to avoid accumulating it
posTensor.grad().zero_();
@@ -203,31 +222,32 @@ static void executeGraph(bool outputsForces, bool includeForces, torch::jit::scr
double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
// Push to the PyTorch context
CHECK_RESULT(cuCtxPushCurrent(primaryContext), "Failed to push the CUDA context");
auto inputs = prepareTorchInputs(context);
vector<torch::jit::IValue> inputs;
prepareTorchInputs(context, inputs, globalTensors);
if (!useGraphs) {
executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor);
executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor, globalTensors, paramDerivs);
} else {
// Record graph if not already done
bool is_graph_captured = false;
if (graphs.find(includeForces) == graphs.end()) {
//CUDA graph capture must occur in a non-default stream
const auto stream = c10::cuda::getStreamFromPool(false, cu.getDeviceIndex());
const c10::cuda::CUDAStreamGuard guard(stream);
// Warmup the graph workload before capturing. This first
// run before capture sets up allocations so that no
// allocations are needed after. Pytorch's allocator is
// stream capture-aware and, after warmup, will record
// static pointers and shapes during capture.
try {
for (int i = 0; i < this->warmupSteps; i++)
executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor);
executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor, globalTensors, paramDerivs);
}
catch (std::exception& e) {
throw OpenMMException(string("TorchForce Failed to warmup the model before graph construction. Torch reported the following error:\n") + e.what());
}
graphs[includeForces].capture_begin();
try {
executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor);
executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor, globalTensors, paramDerivs);
is_graph_captured = true;
graphs[includeForces].capture_end();
}
@@ -237,16 +257,23 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce
}
throw OpenMMException(string("TorchForce Failed to capture the model into a CUDA graph. Torch reported the following error:\n") + e.what());
}
for (const string& name : paramDerivs)
globalTensors[name].grad().zero_();
}
// Use the same stream as the OpenMM context, even if it is the default stream
const auto openmmStream = cu.getCurrentStream();
const auto stream = c10::cuda::getStreamFromExternal(openmmStream, cu.getDeviceIndex());
const c10::cuda::CUDAStreamGuard guard(stream);
graphs[includeForces].replay();
}
if (includeForces) {
addForces(forceTensor);
}
map<string, double>& energyParamDerivs = cu.getEnergyParamDerivWorkspace();
for (const string& name : paramDerivs) {
energyParamDerivs[name] += globalTensors[name].grad().item<double>();
globalTensors[name].grad().zero_();
}
// Get energy
const double energy = energyTensor.item<double>(); // This implicitly synchronizes the PyTorch context
// Pop to the PyTorch context
Expand Down
7 changes: 5 additions & 2 deletions platforms/cuda/src/CudaTorchKernels.h
@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2018-2022 Stanford University and the Authors. *
* Portions copyright (c) 2018-2024 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: Raimondas Galvelis, Raul P. Pelaez *
* *
Expand Down Expand Up @@ -37,6 +37,7 @@
#include "openmm/cuda/CudaArray.h"
#include <torch/version.h>
#include <ATen/cuda/CUDAGraph.h>
#include <set>

namespace TorchPlugin {

@@ -71,12 +72,14 @@ class CudaCalcTorchForceKernel : public CalcTorchForceKernel {
torch::jit::script::Module module;
torch::Tensor posTensor, boxTensor;
torch::Tensor energyTensor, forceTensor;
std::map<std::string, torch::Tensor> globalTensors;
std::vector<std::string> globalNames;
std::set<std::string> paramDerivs;
bool usePeriodic, outputsForces;
CUfunction copyInputsKernel, addForcesKernel;
CUcontext primaryContext;
std::map<bool, at::cuda::CUDAGraph> graphs;
std::vector<torch::jit::IValue> prepareTorchInputs(OpenMM::ContextImpl& context);
void prepareTorchInputs(OpenMM::ContextImpl& context, std::vector<torch::jit::IValue>& inputs, std::map<std::string, torch::Tensor>& derivInputs);
bool useGraphs;
void addForces(torch::Tensor& forceTensor);
int warmupSteps;
24 changes: 19 additions & 5 deletions platforms/cuda/tests/TestCudaTorchForce.cpp
@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2018-2022 Stanford University and the Authors. *
* Portions copyright (c) 2018-2024 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
@@ -129,7 +129,7 @@ void testPeriodicForce() {
ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5);
}

void testGlobal() {
void testGlobal(bool useGraphs) {
// Create a random cloud of particles.

const int numParticles = 10;
@@ -143,6 +143,8 @@
}
TorchForce* force = new TorchForce("tests/global.pt");
force->addGlobalParameter("k", 2.0);
force->addEnergyParameterDerivative("k");
force->setProperty("useCUDAGraphs", useGraphs ? "true" : "false");
system.addForce(force);

// Compute the forces and energy.
@@ -151,7 +153,7 @@
Platform& platform = Platform::getPlatformByName("CUDA");
Context context(system, integ, platform);
context.setPositions(positions);
State state = context.getState(State::Energy | State::Forces);
State state = context.getState(State::Energy | State::Forces | State::ParameterDerivatives);

// See if the energy is correct. The network defines a potential of the form E(r) = k*|r|^2

@@ -164,15 +166,26 @@
}
ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5);

// Check the gradient of the energy with respect to the parameter.

double expected = 0.0;
for (int i = 0; i < numParticles; i++) {
Vec3 pos = positions[i];
expected += pos.dot(pos);
}
double actual = state.getEnergyParameterDerivatives().at("k");
ASSERT_EQUAL_TOL(expected, actual, 1e-5);

// Change the global parameter and see if the forces are still correct.

context.setParameter("k", 3.0);
state = context.getState(State::Forces);
state = context.getState(State::Forces | State::ParameterDerivatives);
for (int i = 0; i < numParticles; i++) {
Vec3 pos = positions[i];
double r = sqrt(pos.dot(pos));
ASSERT_EQUAL_VEC(pos*(-6.0), state.getForces()[i], 1e-5);
}
ASSERT_EQUAL_TOL(expected, state.getEnergyParameterDerivatives().at("k"), 1e-5);
}

int main(int argc, char* argv[]) {
@@ -183,7 +196,8 @@ int main(int argc, char* argv[]) {
testForce(false);
testForce(true);
testPeriodicForce();
testGlobal();
testGlobal(false);
testGlobal(true);
}
catch(const std::exception& e) {
std::cout << "exception: " << e.what() << std::endl;