diff --git a/README.md b/README.md index 33625254..398da147 100644 --- a/README.md +++ b/README.md @@ -41,7 +41,7 @@ We show that this approach is not only sufficiently expressive to reproduce lega * `polynomial.py` higher order polynomials. -# Example: Deploy espaloma 0.2.0 pretrained force field to arbitrary MM system +# Example: Deploy latest espaloma pretrained force field to arbitrary MM system ```python # imports @@ -49,10 +49,6 @@ import os import torch import espaloma as esp -# grab pretrained model -if not os.path.exists("espaloma_model.pt"): - os.system("wget http://data.wangyq.net/espaloma_model.pt") - # define or load a molecule of interest via the Open Force Field toolkit from openff.toolkit.topology import Molecule molecule = Molecule.from_smiles("CN1C=NC2=C1C(=O)N(C(=O)N2C)C") @@ -61,7 +57,7 @@ molecule = Molecule.from_smiles("CN1C=NC2=C1C(=O)N(C(=O)N2C)C") molecule_graph = esp.Graph(molecule) # apply a trained espaloma model to assign parameters -espaloma_model = torch.load("espaloma_model.pt") +espaloma_model = esp.get_latest_model() espaloma_model(molecule_graph.heterograph) # create an OpenMM System for the specified molecule diff --git a/espaloma/__init__.py b/espaloma/__init__.py index 2855a327..986ccc18 100644 --- a/espaloma/__init__.py +++ b/espaloma/__init__.py @@ -8,6 +8,7 @@ from .graphs.graph import Graph from .metrics import GraphMetric from .mm.geometry import * +from .graphs.deploy import get_latest_model # Add imports here # import espaloma diff --git a/espaloma/graphs/deploy.py b/espaloma/graphs/deploy.py index 86216a99..fb997b13 100644 --- a/espaloma/graphs/deploy.py +++ b/espaloma/graphs/deploy.py @@ -24,10 +24,16 @@ OPENMM_BOND_K_UNIT = OPENMM_ENERGY_UNIT / (OPENMM_LENGTH_UNIT**2) OPENMM_ANGLE_K_UNIT = OPENMM_ENERGY_UNIT / (OPENMM_ANGLE_UNIT**2) +LATEST_URL = "https://github.com/choderalab/espaloma/" +"releases/latest/download/espaloma-latest.pt" + # ============================================================================= # MODULE FUNCTIONS # ============================================================================= - +def get_latest_model(); + model = torch.utils.model_zoo.load_url(LATEST_URL) + model.eval() + return model def load_forcefield(forcefield="openff_unconstrained-2.0.0"): # get a forcefield