Skip to content

Commit

Permalink
Added doc strings and cleaned up tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mikemhenry committed Jul 14, 2023
1 parent 10fb021 commit 3d3149f
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 2 deletions.
33 changes: 33 additions & 0 deletions espaloma/utils/model_fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,22 @@


def _get_model_url(version: str) -> str:
"""
Get the URL of the espaloma model from GitHub releases.
Parameters:
version (str): Version of the model. If set to "latest", the URL for the latest version will be returned.
Returns:
str: The URL of the espaloma model.
Note:
- If version is set to "latest", the URL for the latest version of the model will be returned.
- The URL is obtained from the GitHub releases of the espaloma repository.
Example:
>>> url = _get_model_url(version="0.3.0")
"""

if version == "latest":
url = "https://github.com/choderalab/espaloma/releases/latest/download/espaloma-latest.pt"
Expand Down Expand Up @@ -78,6 +94,23 @@ def get_model_path(


def get_model(version: str = "latest") -> dict[str, Any]:
"""
Load an espaloma model from GitHub releases.
Parameters:
version (str): Version of the model to load. Default is "latest".
Returns:
dict[str, Any]: The loaded espaloma model.
Note:
- If version is set to "latest", the latest version of the model will be loaded.
- The model will be loaded from GitHub releases.
- The model will be loaded onto the CPU.
Example:
>>> model = get_model(version="0.3.0")
"""

url = _get_model_url(version)
model = torch.utils.model_zoo.load_url(url, map_location="cpu")
Expand Down
14 changes: 12 additions & 2 deletions espaloma/utils/tests/test_model_fetch.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
import espaloma as esp
import torch
from espaloma.utils.model_fetch import get_model
from espaloma.utils.model_fetch import get_model, get_model_path
from openff.toolkit.topology import Molecule


def test_get_model_path(tmp_path):
model_dir = tmp_path / "latest"
model_path = get_model(model_dir=model_dir, disable_progress_bar=True)
model_path = get_model_path(model_dir=model_dir, disable_progress_bar=True)

molecule = Molecule.from_smiles("CN1C=NC2=C1C(=O)N(C(=O)N2C)C")
molecule_graph = esp.Graph(molecule)

espaloma_model = torch.load(model_path)
espaloma_model.eval()
espaloma_model(molecule_graph.heterograph)


def test_get_model(tmp_path):
model_dir = tmp_path / "zoo"

espaloma_model = get_model()

molecule = Molecule.from_smiles("CN1C=NC2=C1C(=O)N(C(=O)N2C)C")
molecule_graph = esp.Graph(molecule)
espaloma_model(molecule_graph.heterograph)

0 comments on commit 3d3149f

Please sign in to comment.