diff --git a/platforms/cuda/tests/TestCudaTorchForce.cpp b/platforms/cuda/tests/TestCudaTorchForce.cpp index 6c3160f..5b8332e 100644 --- a/platforms/cuda/tests/TestCudaTorchForce.cpp +++ b/platforms/cuda/tests/TestCudaTorchForce.cpp @@ -152,7 +152,7 @@ void testGlobal() { Platform& platform = Platform::getPlatformByName("CUDA"); Context context(system, integ, platform); context.setPositions(positions); - State state = context.getState(State::Energy | State::Forces); + State state = context.getState(State::Energy | State::Forces | State::ParameterDerivatives); // See if the energy is correct. The network defines a potential of the form E(r) = k*|r|^2 @@ -165,16 +165,6 @@ void testGlobal() { } ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5); - // Change the global parameter and see if the forces are still correct. - - context.setParameter("k", 3.0); - state = context.getState(State::Forces | State::ParameterDerivatives); - for (int i = 0; i < numParticles; i++) { - Vec3 pos = positions[i]; - double r = sqrt(pos.dot(pos)); - ASSERT_EQUAL_VEC(pos*(-6.0), state.getForces()[i], 1e-5); - } - // Check the gradient of the energy with respect to the parameter. double expected = 0.0; @@ -184,6 +174,17 @@ void testGlobal() { } double actual = state.getEnergyParameterDerivatives().at("k"); ASSERT_EQUAL_TOL(expected, actual, 1e-5); + + // Change the global parameter and see if the forces are still correct. + + context.setParameter("k", 3.0); + state = context.getState(State::Forces | State::ParameterDerivatives); + for (int i = 0; i < numParticles; i++) { + Vec3 pos = positions[i]; + double r = sqrt(pos.dot(pos)); + ASSERT_EQUAL_VEC(pos*(-6.0), state.getForces()[i], 1e-5); + } + ASSERT_EQUAL_TOL(expected, state.getEnergyParameterDerivatives().at("k"), 1e-5); } int main(int argc, char* argv[]) { diff --git a/platforms/opencl/tests/TestOpenCLTorchForce.cpp b/platforms/opencl/tests/TestOpenCLTorchForce.cpp index 73a642f..d42d68a 100644 --- a/platforms/opencl/tests/TestOpenCLTorchForce.cpp +++ b/platforms/opencl/tests/TestOpenCLTorchForce.cpp @@ -149,7 +149,7 @@ void testGlobal() { Platform& platform = Platform::getPlatformByName("OpenCL"); Context context(system, integ, platform); context.setPositions(positions); - State state = context.getState(State::Energy | State::Forces); + State state = context.getState(State::Energy | State::Forces | State::ParameterDerivatives); // See if the energy is correct. The network defines a potential of the form E(r) = k*|r|^2 @@ -162,16 +162,6 @@ void testGlobal() { } ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5); - // Change the global parameter and see if the forces are still correct. - - context.setParameter("k", 3.0); - state = context.getState(State::Forces | State::ParameterDerivatives); - for (int i = 0; i < numParticles; i++) { - Vec3 pos = positions[i]; - double r = sqrt(pos.dot(pos)); - ASSERT_EQUAL_VEC(pos*(-6.0), state.getForces()[i], 1e-5); - } - // Check the gradient of the energy with respect to the parameter. double expected = 0.0; @@ -181,6 +171,17 @@ void testGlobal() { } double actual = state.getEnergyParameterDerivatives().at("k"); ASSERT_EQUAL_TOL(expected, actual, 1e-5); + + // Change the global parameter and see if the forces are still correct. + + context.setParameter("k", 3.0); + state = context.getState(State::Forces | State::ParameterDerivatives); + for (int i = 0; i < numParticles; i++) { + Vec3 pos = positions[i]; + double r = sqrt(pos.dot(pos)); + ASSERT_EQUAL_VEC(pos*(-6.0), state.getForces()[i], 1e-5); + } + ASSERT_EQUAL_TOL(expected, state.getEnergyParameterDerivatives().at("k"), 1e-5); } int main(int argc, char* argv[]) { diff --git a/platforms/reference/tests/TestReferenceTorchForce.cpp b/platforms/reference/tests/TestReferenceTorchForce.cpp index 63dece2..e8885cf 100644 --- a/platforms/reference/tests/TestReferenceTorchForce.cpp +++ b/platforms/reference/tests/TestReferenceTorchForce.cpp @@ -149,7 +149,7 @@ void testGlobal() { Platform& platform = Platform::getPlatformByName("Reference"); Context context(system, integ, platform); context.setPositions(positions); - State state = context.getState(State::Energy | State::Forces); + State state = context.getState(State::Energy | State::Forces | State::ParameterDerivatives); // See if the energy is correct. The network defines a potential of the form E(r) = k*|r|^2 @@ -162,16 +162,6 @@ void testGlobal() { } ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5); - // Change the global parameter and see if the forces are still correct. - - context.setParameter("k", 3.0); - state = context.getState(State::Forces | State::ParameterDerivatives); - for (int i = 0; i < numParticles; i++) { - Vec3 pos = positions[i]; - double r = sqrt(pos.dot(pos)); - ASSERT_EQUAL_VEC(pos*(-6.0), state.getForces()[i], 1e-5); - } - // Check the gradient of the energy with respect to the parameter. double expected = 0.0; @@ -181,6 +171,17 @@ void testGlobal() { } double actual = state.getEnergyParameterDerivatives().at("k"); ASSERT_EQUAL_TOL(expected, actual, 1e-5); + + // Change the global parameter and see if the forces are still correct. + + context.setParameter("k", 3.0); + state = context.getState(State::Forces | State::ParameterDerivatives); + for (int i = 0; i < numParticles; i++) { + Vec3 pos = positions[i]; + double r = sqrt(pos.dot(pos)); + ASSERT_EQUAL_VEC(pos*(-6.0), state.getForces()[i], 1e-5); + } + ASSERT_EQUAL_TOL(expected, state.getEnergyParameterDerivatives().at("k"), 1e-5); } int main() {