Skip to content

Commit

Permalink
Serialization of parameter derivatives
Browse files Browse the repository at this point in the history
  • Loading branch information
peastman committed May 15, 2024
1 parent 0fc454b commit e801c10
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
15 changes: 10 additions & 5 deletions serialization/src/TorchForceProxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2018-2022 Stanford University and the Authors. *
* Portions copyright (c) 2018-2024 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
Expand Down Expand Up @@ -74,7 +74,7 @@ TorchForceProxy::TorchForceProxy() : SerializationProxy("TorchForce") {
}

void TorchForceProxy::serialize(const void* object, SerializationNode& node) const {
node.setIntProperty("version", 2);
node.setIntProperty("version", 3);
const TorchForce& force = *reinterpret_cast<const TorchForce*>(object);
node.setStringProperty("file", force.getFile());
try {
Expand All @@ -90,14 +90,16 @@ void TorchForceProxy::serialize(const void* object, SerializationNode& node) con
node.setBoolProperty("usesPeriodic", force.usesPeriodicBoundaryConditions());
node.setBoolProperty("outputsForces", force.getOutputsForces());
SerializationNode& globalParams = node.createChildNode("GlobalParameters");
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParams.createChildNode("Parameter").setStringProperty("name", force.getGlobalParameterName(i)).setDoubleProperty("default", force.getGlobalParameterDefaultValue(i));
}
SerializationNode& paramDerivs = node.createChildNode("ParameterDerivatives");
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++)
paramDerivs.createChildNode("Parameter").setStringProperty("name", force.getEnergyParameterDerivativeName(i));
}

void* TorchForceProxy::deserialize(const SerializationNode& node) const {
int storedVersion = node.getIntProperty("version");
if (storedVersion > 2)
if (storedVersion > 3)
throw OpenMMException("Unsupported version number");
TorchForce* force;
if (storedVersion == 1) {
Expand All @@ -121,6 +123,9 @@ void* TorchForceProxy::deserialize(const SerializationNode& node) const {
if (child.getName() == "GlobalParameters")
for (auto& parameter : child.getChildren())
force->addGlobalParameter(parameter.getStringProperty("name"), parameter.getDoubleProperty("default"));
if (child.getName() == "ParameterDerivatives")
for (auto& parameter : child.getChildren())
force->addEnergyParameterDerivative(parameter.getStringProperty("name"));
}
return force;
}
10 changes: 7 additions & 3 deletions serialization/tests/TestSerializeTorchForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2018-2022 Stanford University and the Authors. *
* Portions copyright (c) 2018-2024 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
Expand Down Expand Up @@ -50,6 +50,7 @@ void serializeAndDeserialize(TorchForce force) {
force.addGlobalParameter("y", 2.221);
force.setUsesPeriodicBoundaryConditions(true);
force.setOutputsForces(true);
force.addEnergyParameterDerivative("y");

// Serialize and then deserialize it.

Expand All @@ -73,17 +74,20 @@ void serializeAndDeserialize(TorchForce force) {
}
ASSERT_EQUAL(force.usesPeriodicBoundaryConditions(), force2.usesPeriodicBoundaryConditions());
ASSERT_EQUAL(force.getOutputsForces(), force2.getOutputsForces());
ASSERT_EQUAL(force.getNumEnergyParameterDerivatives(), force2.getNumEnergyParameterDerivatives());
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++)
ASSERT_EQUAL(force.getEnergyParameterDerivativeName(i), force2.getEnergyParameterDerivativeName(i));
}

void testSerializationFromModule() {
string fileName = "../../tests/forces.pt";
string fileName = "tests/forces.pt";
torch::jit::Module module = torch::jit::load(fileName);
TorchForce force(module);
serializeAndDeserialize(force);
}

void testSerializationFromFile() {
string fileName = "../../tests/forces.pt";
string fileName = "tests/forces.pt";
TorchForce force(fileName);
serializeAndDeserialize(force);
}
Expand Down

0 comments on commit e801c10

Please sign in to comment.