From 3de4a190879e76dadce11331774331445207b36a Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 16 Aug 2024 11:04:52 +0200 Subject: [PATCH] Address peastman's comments. --- README.md | 68 +++++++++++++++++++------------------------------------ 1 file changed, 23 insertions(+), 45 deletions(-) diff --git a/README.md b/README.md index 7786752..b0900a6 100644 --- a/README.md +++ b/README.md @@ -259,7 +259,7 @@ torch_force.setOutputsForces(True) Computing energy derivatives with respect to global parameters -------------------------------------------------------------- -Its possible to query `TorchForce` for the derivative of the energy with respect to global parameters. In order to do so the global parameters must be registered as energy derivatives. This is done by calling `addEnergyParameterDerivative()` for each parameter. +TorchForce can compute derivatives of the energy with respect to global parameters.. In order to do so the global parameters must be registered as energy derivatives. This is done by calling `addEnergyParameterDerivative()` for each parameter. The parameter derivatives can be queried by calling `getEnergyParameterDerivatives()` on the `State` object returned by `Context.getState()`. The result is a dictionary with the parameter names as keys and the derivatives as values. @@ -274,50 +274,28 @@ class ForceWithParameters(pt.nn.Module): def __init__(self): super(ForceWithParameters, self).__init__() - def forward( - self, positions: Tensor, parameter1: Tensor, parameter2: Tensor - ) -> Tensor: - x2 = positions.pow(2).sum(dim=1) - u_harmonic = ((parameter1 + parameter2**2) * x2).sum() - return u_harmonic - - -def example(): - numParticles = 10 - system = mm.System() - positions = np.random.rand(numParticles, 3) - for _ in range(numParticles): - system.addParticle(1.0) - - pt_force = ForceWithParameters() - model = pt.jit.script(pt_force) - tforce = TorchForce(model) - parameter1 = 1.0 - parameter2 = 1.0 - force.setOutputsForces(False) - force.addGlobalParameter("parameter1", parameter1) - force.addEnergyParameterDerivative("parameter1") - force.addGlobalParameter("parameter2", parameter2) - force.addEnergyParameterDerivative("parameter2") - system.addForce(force) - integ = mm.VerletIntegrator(1.0) - platform = mm.Platform.getPlatformByName(platform) - context = mm.Context(system, integ, platform) - context.setPositions(positions) - state = context.getState( - getEnergy=True, getForces=True, getParameterDerivatives=True - ) - # The network defines a potential of the form E(r) = (parameter1 + parameter2**2)*|r|^2 - r2 = np.sum(positions * positions) - expectedEnergy = (parameter1 + parameter2**2) * r2 - assert np.allclose( - r2, - state.getEnergyParameterDerivatives()["parameter1"], - ) - assert np.allclose( - 2 * parameter2 * r2, - state.getEnergyParameterDerivatives()["parameter2"], - ) + def forward(self, positions: Tensor, k: Tensor) -> Tensor: + return k*torch.sum(positions**2) + + +numParticles = 10 +system = mm.System() +for _ in range(numParticles): + system.addParticle(1.0) + +model = pt.jit.script(ForceWithParameters()) +tforce = TorchForce(model) +force.setOutputsForces(False) +force.addGlobalParameter("k", 2.0) +force.addEnergyParameterDerivative("k") +system.addForce(force) +integ = mm.VerletIntegrator(1.0) +platform = mm.Platform.getPlatformByName(platform) +context = mm.Context(system, integ, platform) +positions = np.random.rand(numParticles, 3) +context.setPositions(positions) +state = context.getState(getParameterDerivatives=True) +dEdk = state.getEnergyParameterDerivatives()["k"] ```