Skip to content

Commit

Permalink
added function that returns the model
Browse files Browse the repository at this point in the history
  • Loading branch information
mikemhenry committed Jul 14, 2023
1 parent 886401a commit 10fb021
Showing 1 changed file with 26 additions and 7 deletions.
33 changes: 26 additions & 7 deletions espaloma/utils/model_fetch.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
from pathlib import Path
from typing import Any

import requests
import torch.utils.model_zoo
from tqdm import tqdm


def get_model(
def _get_model_url(version: str) -> str:

if version == "latest":
url = "https://github.com/choderalab/espaloma/releases/latest/download/espaloma-latest.pt"
else:
# TODO: This scheme requires the version string of the model to match the
# release version
url = f"https://github.com/choderalab/espaloma/releases/download/{version}/espaloma-{version}.pt"

return url


def get_model_path(
model_dir: str | Path = ".espaloma/",
version: str = "latest",
disable_progress_bar: bool = False,
Expand Down Expand Up @@ -33,12 +47,8 @@ def get_model(
Example:
>>> model_path = get_model(model_dir=".espaloma/", version="0.3.0", disable_progress_bar=True)
"""
if version == "latest":
url = "https://github.com/choderalab/espaloma/releases/latest/download/espaloma-latest.pt"
else:
# TODO: This scheme requires the version string of the model to match the
# release version
url = f"https://github.com/choderalab/espaloma/releases/download/{version}/espaloma-{version}.pt"

url = _get_model_url(version)

# This will work as long as we never have a "/" in the version string
file_name = Path(url.split("/")[-1])
Expand All @@ -65,3 +75,12 @@ def get_model(
progress.update(size)

return model_path


def get_model(version: str = "latest") -> dict[str, Any]:

url = _get_model_url(version)
model = torch.utils.model_zoo.load_url(url, map_location="cpu")
model.eval() # type: ignore

return model

0 comments on commit 10fb021

Please sign in to comment.