diff --git a/README.md b/README.md index 7786752..b0900a6 100644 --- a/README.md +++ b/README.md @@ -259,7 +259,7 @@ torch_force.setOutputsForces(True) Computing energy derivatives with respect to global parameters -------------------------------------------------------------- -Its possible to query `TorchForce` for the derivative of the energy with respect to global parameters. In order to do so the global parameters must be registered as energy derivatives. This is done by calling `addEnergyParameterDerivative()` for each parameter. +TorchForce can compute derivatives of the energy with respect to global parameters.. In order to do so the global parameters must be registered as energy derivatives. This is done by calling `addEnergyParameterDerivative()` for each parameter. The parameter derivatives can be queried by calling `getEnergyParameterDerivatives()` on the `State` object returned by `Context.getState()`. The result is a dictionary with the parameter names as keys and the derivatives as values. @@ -274,50 +274,28 @@ class ForceWithParameters(pt.nn.Module): def __init__(self): super(ForceWithParameters, self).__init__() - def forward( - self, positions: Tensor, parameter1: Tensor, parameter2: Tensor - ) -> Tensor: - x2 = positions.pow(2).sum(dim=1) - u_harmonic = ((parameter1 + parameter2**2) * x2).sum() - return u_harmonic - - -def example(): - numParticles = 10 - system = mm.System() - positions = np.random.rand(numParticles, 3) - for _ in range(numParticles): - system.addParticle(1.0) - - pt_force = ForceWithParameters() - model = pt.jit.script(pt_force) - tforce = TorchForce(model) - parameter1 = 1.0 - parameter2 = 1.0 - force.setOutputsForces(False) - force.addGlobalParameter("parameter1", parameter1) - force.addEnergyParameterDerivative("parameter1") - force.addGlobalParameter("parameter2", parameter2) - force.addEnergyParameterDerivative("parameter2") - system.addForce(force) - integ = mm.VerletIntegrator(1.0) - platform = mm.Platform.getPlatformByName(platform) - context = mm.Context(system, integ, platform) - context.setPositions(positions) - state = context.getState( - getEnergy=True, getForces=True, getParameterDerivatives=True - ) - # The network defines a potential of the form E(r) = (parameter1 + parameter2**2)*|r|^2 - r2 = np.sum(positions * positions) - expectedEnergy = (parameter1 + parameter2**2) * r2 - assert np.allclose( - r2, - state.getEnergyParameterDerivatives()["parameter1"], - ) - assert np.allclose( - 2 * parameter2 * r2, - state.getEnergyParameterDerivatives()["parameter2"], - ) + def forward(self, positions: Tensor, k: Tensor) -> Tensor: + return k*torch.sum(positions**2) + + +numParticles = 10 +system = mm.System() +for _ in range(numParticles): + system.addParticle(1.0) + +model = pt.jit.script(ForceWithParameters()) +tforce = TorchForce(model) +force.setOutputsForces(False) +force.addGlobalParameter("k", 2.0) +force.addEnergyParameterDerivative("k") +system.addForce(force) +integ = mm.VerletIntegrator(1.0) +platform = mm.Platform.getPlatformByName(platform) +context = mm.Context(system, integ, platform) +positions = np.random.rand(numParticles, 3) +context.setPositions(positions) +state = context.getState(getParameterDerivatives=True) +dEdk = state.getEnergyParameterDerivatives()["k"] ```