From 5b5fb04866fa7052a0e445bf1a1d8f5cc3fa3f3d Mon Sep 17 00:00:00 2001 From: Mike Henry <11765982+mikemhenry@users.noreply.github.com> Date: Fri, 14 Jul 2023 07:54:31 -0700 Subject: [PATCH 1/2] first pass at model fetch (#172) --- .github/workflows/CI.yaml | 2 +- README.md | 8 +- devtools/conda-envs/espaloma.yaml | 1 + espaloma/__init__.py | 1 + espaloma/utils/model_fetch.py | 119 +++++++++++++++++++++++ espaloma/utils/tests/test_model_fetch.py | 23 +++++ 6 files changed, 148 insertions(+), 6 deletions(-) create mode 100644 espaloma/utils/model_fetch.py create mode 100644 espaloma/utils/tests/test_model_fetch.py diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml index dfd2a5bf..524603fc 100644 --- a/.github/workflows/CI.yaml +++ b/.github/workflows/CI.yaml @@ -67,7 +67,7 @@ jobs: - name: Run tests run: | - pytest -v -n auto --cov=espaloma --cov-report=xml --color=yes espaloma/ + pytest -v --cov=espaloma --cov-report=xml --color=yes espaloma/ - name: CodeCov uses: codecov/codecov-action@v3 diff --git a/README.md b/README.md index 33625254..a3f62c37 100644 --- a/README.md +++ b/README.md @@ -49,10 +49,6 @@ import os import torch import espaloma as esp -# grab pretrained model -if not os.path.exists("espaloma_model.pt"): - os.system("wget http://data.wangyq.net/espaloma_model.pt") - # define or load a molecule of interest via the Open Force Field toolkit from openff.toolkit.topology import Molecule molecule = Molecule.from_smiles("CN1C=NC2=C1C(=O)N(C(=O)N2C)C") @@ -60,8 +56,10 @@ molecule = Molecule.from_smiles("CN1C=NC2=C1C(=O)N(C(=O)N2C)C") # create an Espaloma Graph object to represent the molecule of interest molecule_graph = esp.Graph(molecule) +# load pretrained model +espaloma_model = esp.get_model("latest") + # apply a trained espaloma model to assign parameters -espaloma_model = torch.load("espaloma_model.pt") espaloma_model(molecule_graph.heterograph) # create an OpenMM System for the specified molecule diff --git 
a/devtools/conda-envs/espaloma.yaml b/devtools/conda-envs/espaloma.yaml index 27186486..c000003a 100644 --- a/devtools/conda-envs/espaloma.yaml +++ b/devtools/conda-envs/espaloma.yaml @@ -19,6 +19,7 @@ dependencies: - openmm - openmmforcefields >=0.11.2 - tqdm + - pydantic <2 # We need our deps to fix this # Pytorch - pytorch>=1.8.0 - dgl =0.9.0 diff --git a/espaloma/__init__.py b/espaloma/__init__.py index 2855a327..87cda8b9 100644 --- a/espaloma/__init__.py +++ b/espaloma/__init__.py @@ -8,6 +8,7 @@ from .graphs.graph import Graph from .metrics import GraphMetric from .mm.geometry import * +from .utils.model_fetch import get_model, get_model_path # Add imports here # import espaloma diff --git a/espaloma/utils/model_fetch.py b/espaloma/utils/model_fetch.py new file mode 100644 index 00000000..4542d04c --- /dev/null +++ b/espaloma/utils/model_fetch.py @@ -0,0 +1,119 @@ +from pathlib import Path +from typing import Any, Union + +import requests +import torch.utils.model_zoo +from tqdm import tqdm + + +def _get_model_url(version: str) -> str: + """ + Get the URL of the espaloma model from GitHub releases. + + Parameters: + version (str): Version of the model. If set to "latest", the URL for the latest version will be returned. + + Returns: + str: The URL of the espaloma model. + + Note: + - If version is set to "latest", the URL for the latest version of the model will be returned. + - The URL is obtained from the GitHub releases of the espaloma repository. 
+ + Example: + >>> url = _get_model_url(version="0.3.0") + """ + + if version == "latest": + url = "https://github.com/choderalab/espaloma/releases/latest/download/espaloma-latest.pt" + else: + # TODO: This scheme requires the version string of the model to match the + # release version + url = f"https://github.com/choderalab/espaloma/releases/download/{version}/espaloma-{version}.pt" + + return url + + +def get_model_path( + model_dir: Union[str, Path] = ".espaloma/", + version: str = "latest", + disable_progress_bar: bool = False, + overwrite: bool = False, +) -> Path: + """ + Download a model for espaloma. + + Parameters: + model_dir (str or Path): Directory path where the model will be saved. Default is ``.espaloma/``. + version (str): Version of the model to download. Default is "latest". + disable_progress_bar (bool): Whether to disable the progress bar during the download. Default is False. + overwrite (bool): Whether to overwrite the existing model file if it exists. Default is False. + + Returns: + Path: The path to the downloaded model file. + + Raises: + FileExistsError: If the model file already exists and overwrite is set to False. + + Note: + - If version is set to "latest", the latest version of the model will be downloaded. + - The model will be downloaded from GitHub releases. + - The model file will be saved in the specified model directory. 
+ + Example: + >>> model_path = get_model_path(model_dir=".espaloma/", version="0.3.0", disable_progress_bar=True) + """ + + url = _get_model_url(version) + + # This will work as long as we never have a "/" in the version string + file_name = Path(url.split("/")[-1]) + model_dir = Path(model_dir) + model_path = Path(model_dir / file_name) + + if not overwrite and model_path.exists(): + raise FileExistsError( + f"File '{model_path}' exists, use overwrite=True to overwrite file" + ) + model_dir.mkdir(parents=True, exist_ok=True) + + request = requests.get(url, stream=True) + request_length = int(request.headers.get("content-length", 0)) + with open(model_path, "wb") as file, tqdm( + total=request_length, + unit="iB", + unit_scale=True, + unit_divisor=1024, + disable=disable_progress_bar, + ) as progress: + for data in request.iter_content(chunk_size=1024): + size = file.write(data) + progress.update(size) + + return model_path + + +def get_model(version: str = "latest") -> dict[str, Any]: + """ + Load an espaloma model from GitHub releases. + + Parameters: + version (str): Version of the model to load. Default is "latest". + + Returns: + dict[str, Any]: The loaded espaloma model. + + Note: + - If version is set to "latest", the latest version of the model will be loaded. + - The model will be loaded from GitHub releases. + - The model will be loaded onto the CPU. 
+ + Example: + >>> model = get_model(version="0.3.0") + """ + + url = _get_model_url(version) + model = torch.utils.model_zoo.load_url(url, map_location="cpu") + model.eval() # type: ignore + + return model diff --git a/espaloma/utils/tests/test_model_fetch.py b/espaloma/utils/tests/test_model_fetch.py new file mode 100644 index 00000000..5c97168c --- /dev/null +++ b/espaloma/utils/tests/test_model_fetch.py @@ -0,0 +1,23 @@ +import espaloma as esp +import torch +from openff.toolkit.topology import Molecule + + +def test_get_model_path(tmp_path): + model_dir = tmp_path / "latest" + model_path = esp.get_model_path(model_dir=model_dir, disable_progress_bar=True) + + molecule = Molecule.from_smiles("CN1C=NC2=C1C(=O)N(C(=O)N2C)C") + molecule_graph = esp.Graph(molecule) + + espaloma_model = torch.load(model_path) + espaloma_model.eval() + espaloma_model(molecule_graph.heterograph) + + +def test_get_model(tmp_path): + espaloma_model = esp.get_model() + + molecule = Molecule.from_smiles("CN1C=NC2=C1C(=O)N(C(=O)N2C)C") + molecule_graph = esp.Graph(molecule) + espaloma_model(molecule_graph.heterograph) From 71650389f6de16375d70717acc0111331089ebf1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Pulido?= <2949729+ijpulidos@users.noreply.github.com> Date: Fri, 14 Jul 2023 11:47:39 -0400 Subject: [PATCH 2/2] using `to_homogeneous` instead of the old deprecated name. (#173) --- espaloma/nn/sequential.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/espaloma/nn/sequential.py b/espaloma/nn/sequential.py index 4f1ef620..53c81f05 100644 --- a/espaloma/nn/sequential.py +++ b/espaloma/nn/sequential.py @@ -135,7 +135,7 @@ def forward(self, g, x=None): import dgl # get homogeneous subgraph - g_ = dgl.to_homo(g.edge_type_subgraph(["n1_neighbors_n1"])) + g_ = dgl.to_homogeneous(g.edge_type_subgraph(["n1_neighbors_n1"])) if x is None: # get node attributes