From 3d3149f9befc5f44bbc248470f9e0d420170182c Mon Sep 17 00:00:00 2001 From: Mike Henry <11765982+mikemhenry@users.noreply.github.com> Date: Fri, 14 Jul 2023 07:05:04 -0700 Subject: [PATCH] Added doc strings and cleaned up tests --- espaloma/utils/model_fetch.py | 33 ++++++++++++++++++++++++ espaloma/utils/tests/test_model_fetch.py | 14 ++++++++-- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/espaloma/utils/model_fetch.py b/espaloma/utils/model_fetch.py index 92bb699f..d0ebc311 100644 --- a/espaloma/utils/model_fetch.py +++ b/espaloma/utils/model_fetch.py @@ -7,6 +7,22 @@ def _get_model_url(version: str) -> str: + """ + Get the URL of the espaloma model from GitHub releases. + + Parameters: + version (str): Version of the model. If set to "latest", the URL for the latest version will be returned. + + Returns: + str: The URL of the espaloma model. + + Note: + - If version is set to "latest", the URL for the latest version of the model will be returned. + - The URL is obtained from the GitHub releases of the espaloma repository. + + Example: + >>> url = _get_model_url(version="0.3.0") + """ if version == "latest": url = "https://github.com/choderalab/espaloma/releases/latest/download/espaloma-latest.pt" @@ -78,6 +94,23 @@ def get_model_path( def get_model(version: str = "latest") -> dict[str, Any]: + """ + Load an espaloma model from GitHub releases. + + Parameters: + version (str): Version of the model to load. Default is "latest". + + Returns: + dict[str, Any]: The loaded espaloma model. + + Note: + - If version is set to "latest", the latest version of the model will be loaded. + - The model will be loaded from GitHub releases. + - The model will be loaded onto the CPU. + + Example: + >>> model = get_model(version="0.3.0") + """ url = _get_model_url(version) model = torch.utils.model_zoo.load_url(url, map_location="cpu") diff --git a/espaloma/utils/tests/test_model_fetch.py b/espaloma/utils/tests/test_model_fetch.py index a636b32f..1266d4d6 100644 --- a/espaloma/utils/tests/test_model_fetch.py +++ b/espaloma/utils/tests/test_model_fetch.py @@ -1,12 +1,12 @@ import espaloma as esp import torch -from espaloma.utils.model_fetch import get_model +from espaloma.utils.model_fetch import get_model, get_model_path from openff.toolkit.topology import Molecule def test_get_model_path(tmp_path): model_dir = tmp_path / "latest" - model_path = get_model(model_dir=model_dir, disable_progress_bar=True) + model_path = get_model_path(model_dir=model_dir, disable_progress_bar=True) molecule = Molecule.from_smiles("CN1C=NC2=C1C(=O)N(C(=O)N2C)C") molecule_graph = esp.Graph(molecule) @@ -14,3 +14,13 @@ def test_get_model_path(tmp_path): espaloma_model = torch.load(model_path) espaloma_model.eval() espaloma_model(molecule_graph.heterograph) + + +def test_get_model(tmp_path): + model_dir = tmp_path / "zoo" + + espaloma_model = get_model() + + molecule = Molecule.from_smiles("CN1C=NC2=C1C(=O)N(C(=O)N2C)C") + molecule_graph = esp.Graph(molecule) + espaloma_model(molecule_graph.heterograph)