From 0b781c277e0442198e807f1c35b4de9bc0c63fd4 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 16 Aug 2024 11:11:27 +0200 Subject: [PATCH] Further example simplification --- README.md | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index b0900a6..3456c23 100644 --- a/README.md +++ b/README.md @@ -265,17 +265,17 @@ The parameter derivatives can be queried by calling `getEnergyParameterDerivativ ```python import torch as pt -from torch import Tensor from openmmtorch import TorchForce import openmm as mm + class ForceWithParameters(pt.nn.Module): def __init__(self): super(ForceWithParameters, self).__init__() - def forward(self, positions: Tensor, k: Tensor) -> Tensor: - return k*torch.sum(positions**2) + def forward(self, positions: pt.Tensor, k: pt.Tensor) -> pt.Tensor: + return k * pt.sum(positions**2) numParticles = 10 @@ -285,15 +285,12 @@ for _ in range(numParticles): model = pt.jit.script(ForceWithParameters()) tforce = TorchForce(model) -force.setOutputsForces(False) -force.addGlobalParameter("k", 2.0) -force.addEnergyParameterDerivative("k") -system.addForce(force) -integ = mm.VerletIntegrator(1.0) -platform = mm.Platform.getPlatformByName(platform) -context = mm.Context(system, integ, platform) -positions = np.random.rand(numParticles, 3) -context.setPositions(positions) +tforce.setOutputsForces(False) +tforce.addGlobalParameter("k", 2.0) +tforce.addEnergyParameterDerivative("k") +system.addForce(tforce) +context = mm.Context(system, mm.VerletIntegrator(1.0)) +context.setPositions(pt.rand(numParticles, 3).numpy()) state = context.getState(getParameterDerivatives=True) dEdk = state.getEnergyParameterDerivatives()["k"] ```