Skip to content

Commit

Permalink
Improvements to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
peastman committed May 15, 2024
1 parent e801c10 commit fde51ff
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 33 deletions.
23 changes: 12 additions & 11 deletions platforms/cuda/tests/TestCudaTorchForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ void testGlobal() {
Platform& platform = Platform::getPlatformByName("CUDA");
Context context(system, integ, platform);
context.setPositions(positions);
State state = context.getState(State::Energy | State::Forces);
State state = context.getState(State::Energy | State::Forces | State::ParameterDerivatives);

// See if the energy is correct. The network defines a potential of the form E(r) = k*|r|^2

Expand All @@ -165,16 +165,6 @@ void testGlobal() {
}
ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5);

// Change the global parameter and see if the forces are still correct.

context.setParameter("k", 3.0);
state = context.getState(State::Forces | State::ParameterDerivatives);
for (int i = 0; i < numParticles; i++) {
Vec3 pos = positions[i];
double r = sqrt(pos.dot(pos));
ASSERT_EQUAL_VEC(pos*(-6.0), state.getForces()[i], 1e-5);
}

// Check the gradient of the energy with respect to the parameter.

double expected = 0.0;
Expand All @@ -184,6 +174,17 @@ void testGlobal() {
}
double actual = state.getEnergyParameterDerivatives().at("k");
ASSERT_EQUAL_TOL(expected, actual, 1e-5);

// Change the global parameter and see if the forces are still correct.

context.setParameter("k", 3.0);
state = context.getState(State::Forces | State::ParameterDerivatives);
for (int i = 0; i < numParticles; i++) {
Vec3 pos = positions[i];
double r = sqrt(pos.dot(pos));
ASSERT_EQUAL_VEC(pos*(-6.0), state.getForces()[i], 1e-5);
}
ASSERT_EQUAL_TOL(expected, state.getEnergyParameterDerivatives().at("k"), 1e-5);
}

int main(int argc, char* argv[]) {
Expand Down
23 changes: 12 additions & 11 deletions platforms/opencl/tests/TestOpenCLTorchForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ void testGlobal() {
Platform& platform = Platform::getPlatformByName("OpenCL");
Context context(system, integ, platform);
context.setPositions(positions);
State state = context.getState(State::Energy | State::Forces);
State state = context.getState(State::Energy | State::Forces | State::ParameterDerivatives);

// See if the energy is correct. The network defines a potential of the form E(r) = k*|r|^2

Expand All @@ -162,16 +162,6 @@ void testGlobal() {
}
ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5);

// Change the global parameter and see if the forces are still correct.

context.setParameter("k", 3.0);
state = context.getState(State::Forces | State::ParameterDerivatives);
for (int i = 0; i < numParticles; i++) {
Vec3 pos = positions[i];
double r = sqrt(pos.dot(pos));
ASSERT_EQUAL_VEC(pos*(-6.0), state.getForces()[i], 1e-5);
}

// Check the gradient of the energy with respect to the parameter.

double expected = 0.0;
Expand All @@ -181,6 +171,17 @@ void testGlobal() {
}
double actual = state.getEnergyParameterDerivatives().at("k");
ASSERT_EQUAL_TOL(expected, actual, 1e-5);

// Change the global parameter and see if the forces are still correct.

context.setParameter("k", 3.0);
state = context.getState(State::Forces | State::ParameterDerivatives);
for (int i = 0; i < numParticles; i++) {
Vec3 pos = positions[i];
double r = sqrt(pos.dot(pos));
ASSERT_EQUAL_VEC(pos*(-6.0), state.getForces()[i], 1e-5);
}
ASSERT_EQUAL_TOL(expected, state.getEnergyParameterDerivatives().at("k"), 1e-5);
}

int main(int argc, char* argv[]) {
Expand Down
23 changes: 12 additions & 11 deletions platforms/reference/tests/TestReferenceTorchForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ void testGlobal() {
Platform& platform = Platform::getPlatformByName("Reference");
Context context(system, integ, platform);
context.setPositions(positions);
State state = context.getState(State::Energy | State::Forces);
State state = context.getState(State::Energy | State::Forces | State::ParameterDerivatives);

// See if the energy is correct. The network defines a potential of the form E(r) = k*|r|^2

Expand All @@ -162,16 +162,6 @@ void testGlobal() {
}
ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5);

// Change the global parameter and see if the forces are still correct.

context.setParameter("k", 3.0);
state = context.getState(State::Forces | State::ParameterDerivatives);
for (int i = 0; i < numParticles; i++) {
Vec3 pos = positions[i];
double r = sqrt(pos.dot(pos));
ASSERT_EQUAL_VEC(pos*(-6.0), state.getForces()[i], 1e-5);
}

// Check the gradient of the energy with respect to the parameter.

double expected = 0.0;
Expand All @@ -181,6 +171,17 @@ void testGlobal() {
}
double actual = state.getEnergyParameterDerivatives().at("k");
ASSERT_EQUAL_TOL(expected, actual, 1e-5);

// Change the global parameter and see if the forces are still correct.

context.setParameter("k", 3.0);
state = context.getState(State::Forces | State::ParameterDerivatives);
for (int i = 0; i < numParticles; i++) {
Vec3 pos = positions[i];
double r = sqrt(pos.dot(pos));
ASSERT_EQUAL_VEC(pos*(-6.0), state.getForces()[i], 1e-5);
}
ASSERT_EQUAL_TOL(expected, state.getEnergyParameterDerivatives().at("k"), 1e-5);
}

int main() {
Expand Down

0 comments on commit fde51ff

Please sign in to comment.