diff --git a/espaloma/graphs/deploy.py b/espaloma/graphs/deploy.py index fb997b13..7ebb23fa 100644 --- a/espaloma/graphs/deploy.py +++ b/espaloma/graphs/deploy.py @@ -31,7 +31,7 @@ # MODULE FUNCTIONS # ============================================================================= def get_latest_model(); - model = torch.utils.model_zoo.load_url(LATEST_URL) + model = torch.utils.model_zoo.load_url(LATEST_URL, map_location="cpu") model.eval() return model diff --git a/espaloma/graphs/tests/test_deploy.py b/espaloma/graphs/tests/test_deploy.py index 16f7b8e2..3104e463 100644 --- a/espaloma/graphs/tests/test_deploy.py +++ b/espaloma/graphs/tests/test_deploy.py @@ -22,15 +22,12 @@ def test_butane_charge_nn(): the nn charge method""" import torch # Download serialized espaloma model - url = f'https://github.com/choderalab/espaloma/releases/download/0.3.0/espaloma-0.3.0rc1.pt' - espaloma_model_filepath = f'espaloma-0.3.0rc1.pt' - urllib.request.urlretrieve(url, filename=espaloma_model_filepath) # Test deployment ff = esp.graphs.legacy_force_field.LegacyForceField("openff-1.2.0") g = esp.Graph("CCCC") g = ff.parametrize(g) # apply a trained espaloma model to assign parameters - net = torch.load(espaloma_model_filepath, map_location=torch.device('cpu')) + net = esp.get_latest_model() net.eval() net(g.heterograph) esp.graphs.deploy.openmm_system_from_graph(g, suffix="_ref", charge_method="nn")