Skip to content

Commit

Permalink
Merge branch 'main' into fix/dataset-split
Browse files Browse the repository at this point in the history
  • Loading branch information
kntkb committed Jul 14, 2023
2 parents 74ae92a + 7165038 commit be666e5
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ jobs:
- name: Run tests
run: |
pytest -v -n auto --cov=espaloma --cov-report=xml --color=yes espaloma/
pytest -v --cov=espaloma --cov-report=xml --color=yes espaloma/
- name: CodeCov
uses: codecov/codecov-action@v3
Expand Down
8 changes: 3 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,17 @@ import os
import torch
import espaloma as esp

# grab pretrained model
if not os.path.exists("espaloma_model.pt"):
os.system("wget http://data.wangyq.net/espaloma_model.pt")

# define or load a molecule of interest via the Open Force Field toolkit
from openff.toolkit.topology import Molecule
molecule = Molecule.from_smiles("CN1C=NC2=C1C(=O)N(C(=O)N2C)C")

# create an Espaloma Graph object to represent the molecule of interest
molecule_graph = esp.Graph(molecule)

# load pretrained model
espaloma_model = esp.get_model("latest")

# apply a trained espaloma model to assign parameters
espaloma_model = torch.load("espaloma_model.pt")
espaloma_model(molecule_graph.heterograph)

# create an OpenMM System for the specified molecule
Expand Down
1 change: 1 addition & 0 deletions devtools/conda-envs/espaloma.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies:
- openmm
- openmmforcefields >=0.11.2
- tqdm
- pydantic <2 # We need our deps to fix this
# Pytorch
- pytorch>=1.8.0
- dgl =0.9.0
Expand Down
1 change: 1 addition & 0 deletions espaloma/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .graphs.graph import Graph
from .metrics import GraphMetric
from .mm.geometry import *
from .utils.model_fetch import get_model, get_model_path

# Add imports here
# import espaloma
Expand Down
2 changes: 1 addition & 1 deletion espaloma/nn/sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def forward(self, g, x=None):
import dgl

# get homogeneous subgraph
g_ = dgl.to_homo(g.edge_type_subgraph(["n1_neighbors_n1"]))
g_ = dgl.to_homogeneous(g.edge_type_subgraph(["n1_neighbors_n1"]))

if x is None:
# get node attributes
Expand Down
119 changes: 119 additions & 0 deletions espaloma/utils/model_fetch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from pathlib import Path
from typing import Any, Union

import requests
import torch.utils.model_zoo
from tqdm import tqdm


def _get_model_url(version: str) -> str:
"""
Get the URL of the espaloma model from GitHub releases.
Parameters:
version (str): Version of the model. If set to "latest", the URL for the latest version will be returned.
Returns:
str: The URL of the espaloma model.
Note:
- If version is set to "latest", the URL for the latest version of the model will be returned.
- The URL is obtained from the GitHub releases of the espaloma repository.
Example:
>>> url = _get_model_url(version="0.3.0")
"""

if version == "latest":
url = "https://github.com/choderalab/espaloma/releases/latest/download/espaloma-latest.pt"
else:
# TODO: This scheme requires the version string of the model to match the
# release version
url = f"https://github.com/choderalab/espaloma/releases/download/{version}/espaloma-{version}.pt"

return url


def get_model_path(
    model_dir: Union[str, Path] = ".espaloma/",
    version: str = "latest",
    disable_progress_bar: bool = False,
    overwrite: bool = False,
) -> Path:
    """
    Download an espaloma model file and return the path it was saved to.

    Parameters:
        model_dir (str or Path): Directory path where the model will be saved. Default is ``.espaloma/``.
        version (str): Version of the model to download. Default is "latest".
        disable_progress_bar (bool): Whether to disable the progress bar during the download. Default is False.
        overwrite (bool): Whether to overwrite the existing model file if it exists. Default is False.

    Returns:
        Path: The path to the downloaded model file.

    Raises:
        FileExistsError: If the model file already exists and overwrite is set to False.
        requests.HTTPError: If the server answers the download request with an error status.

    Note:
        - If version is set to "latest", the latest version of the model will be downloaded.
        - The model will be downloaded from GitHub releases.
        - The model file will be saved in the specified model directory, which is created if missing.

    Example:
        >>> model_path = get_model_path(model_dir=".espaloma/", version="0.3.0", disable_progress_bar=True)
    """

    url = _get_model_url(version)

    # This will work as long as we never have a "/" in the version string
    file_name = url.split("/")[-1]
    model_dir = Path(model_dir)
    model_path = model_dir / file_name

    if not overwrite and model_path.exists():
        raise FileExistsError(
            f"File '{model_path}' exists, use overwrite=True to overwrite file"
        )
    model_dir.mkdir(parents=True, exist_ok=True)

    # The context manager releases the streamed connection even if the download fails.
    with requests.get(url, stream=True) as response:
        # Check the status before opening the target file, so an HTTP error
        # page is never written to disk in place of the model.
        response.raise_for_status()

        # A missing content-length header falls back to 0 for the progress total.
        content_length = int(response.headers.get("content-length", 0))
        with open(model_path, "wb") as file, tqdm(
            total=content_length,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
            disable=disable_progress_bar,
        ) as progress:
            for chunk in response.iter_content(chunk_size=1024):
                # file.write returns the number of bytes written for this chunk.
                progress.update(file.write(chunk))

    return model_path


def get_model(version: str = "latest") -> Any:
    """
    Load an espaloma model from GitHub releases.

    Parameters:
        version (str): Version of the model to load. Default is "latest".

    Returns:
        Any: The loaded espaloma model object, after ``eval()`` has been called on it.

    Note:
        - If version is set to "latest", the latest version of the model will be loaded.
        - The model will be loaded from GitHub releases.
        - The model will be loaded onto the CPU.

    Example:
        >>> model = get_model(version="0.3.0")
    """

    url = _get_model_url(version)

    # NOTE(review): load_url deserializes the downloaded file (presumably via
    # pickle through torch.load), so this should only fetch trusted releases.
    model = torch.utils.model_zoo.load_url(url, map_location="cpu")

    # The loaded object must expose eval(); it is switched to evaluation mode
    # before being handed to the caller.
    model.eval()

    return model
23 changes: 23 additions & 0 deletions espaloma/utils/tests/test_model_fetch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import espaloma as esp
import torch
from openff.toolkit.topology import Molecule


def test_get_model_path(tmp_path):
    """Download the default model into a temporary directory, load it and run it once."""
    model_dir = tmp_path / "latest"
    model_path = esp.get_model_path(model_dir=model_dir, disable_progress_bar=True)

    # The returned path must point at the file that was just downloaded.
    assert model_path.exists()

    molecule = Molecule.from_smiles("CN1C=NC2=C1C(=O)N(C(=O)N2C)C")
    molecule_graph = esp.Graph(molecule)

    # Load onto the CPU, as esp.get_model does, so the test does not depend on
    # the device the model file was saved from.
    espaloma_model = torch.load(model_path, map_location="cpu")
    espaloma_model.eval()
    espaloma_model(molecule_graph.heterograph)


def test_get_model(tmp_path):
    """Fetch the default model via ``esp.get_model`` and apply it to a molecule graph."""
    smiles = "CN1C=NC2=C1C(=O)N(C(=O)N2C)C"

    model = esp.get_model()

    # Build the graph from the SMILES string and run the model on its heterograph.
    graph = esp.Graph(Molecule.from_smiles(smiles))
    model(graph.heterograph)

0 comments on commit be666e5

Please sign in to comment.