diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index e41b20e..a42419e 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -2,7 +2,6 @@ name: "CI" on: pull_request: branches: - - main push: branches: - main @@ -36,6 +35,8 @@ jobs: include: - os: "macos" python-version: "3.11" + env: + OE_LICENSE: ${{ github.workspace }}/oe_license.txt steps: - uses: actions/checkout@v4 @@ -66,6 +67,14 @@ jobs: micromamba info micromamba list + - name: Decrypt OpenEye license + shell: bash -l {0} + env: + OE_LICENSE_TEXT: ${{ secrets.OE_LICENSE }} + run: | + echo "${OE_LICENSE_TEXT}" > ${OE_LICENSE} + python -c "import openeye; assert openeye.oechem.OEChemIsLicensed(), 'OpenEye license checks failed!'" + - name: "Run tests" env: # Set the OFE_SLOW_TESTS to True if running a Cron job diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0515a09 --- /dev/null +++ b/.gitignore @@ -0,0 +1,171 @@ +# custom ignores +.duecredit.p +.xxrun +.idea/ +.vscode/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class +*~ + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST +_version.py + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +docs/reference/api/generated + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintainted in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# vim +*.swp + +# vscode +.vscode/ + +# Rever +rever/ + diff --git a/README.md b/README.md index 362dcaf..d0ff3b4 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,11 @@ FE Flow ============================== [//]: # (Badges) -[![GitHub Actions Build Status](https://github.com/REPLACE_WITH_OWNER_ACCOUNT/feflow/workflows/CI/badge.svg)](https://github.com/REPLACE_WITH_OWNER_ACCOUNT/feflow/actions?query=workflow%3ACI) -[![codecov](https://codecov.io/gh/REPLACE_WITH_OWNER_ACCOUNT/feflow/branch/main/graph/badge.svg)](https://codecov.io/gh/REPLACE_WITH_OWNER_ACCOUNT/feflow/branch/main) +[![GitHub Actions Build Status](https://github.com/choderalab/feflow/actions/workflows/ci.yaml/badge.svg)](https://github.com/choderalab/feflow/actions/workflows/ci.yaml) +[![codecov](https://codecov.io/gh/choderalab/feflow/branch/main/graph/badge.svg)](https://codecov.io/gh/choderalab/feflow/branch/main) -Recipes and protocols for molecular free energy calculations using the openmmtools/perses and Open Free Energy toolkits +Recipes, utilities, and protocols for molecular free energy calculations using the openmmtools/perses and Open Free Energy toolkits ### Copyright @@ -13,6 +13,7 @@ Copyright (c) 2023, ChoderaLab #### Acknowledgements - +[Choderalab -- Perses](https://github.com/choderalab/perses) +[Open Free energy Consortium](https://openfree.energy/) Project based on the [Computational Molecular Science Python Cookiecutter](https://github.com/molssi/cookiecutter-cms) version 1.1. diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index 7682b0c..a4e216f 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -1,11 +1,12 @@ name: feflow-test channels: - conda-forge + - openeye dependencies: # Base depends - gufe >=0.9.5 - numpy - - openfe # TODO: Remove once we don't depend on openfe + - openfe >=0.15 # TODO: Remove once we don't depend on openfe - openff-units - openmm - pymbar <4 @@ -13,6 +14,9 @@ dependencies: - python # Testing + - openeye-toolkits + - openmoltools # TODO: Remove once we refactor tests + - perses # TODO: Remove once we don't depend on perses for tests - pytest - pytest-cov - pytest-xdist diff --git a/feflow/protocols/nonequilibrium_cycling.py b/feflow/protocols/nonequilibrium_cycling.py index ee908a5..037c21a 100644 --- a/feflow/protocols/nonequilibrium_cycling.py +++ b/feflow/protocols/nonequilibrium_cycling.py @@ -120,11 +120,12 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs): """ # needed imports import openmm + import numpy as np from openff.units.openmm import ensure_quantity from openmmtools.integrators import PeriodicNonequilibriumIntegrator from gufe.components import SmallMoleculeComponent from openfe.protocols.openmm_rfe import _rfe_utils - from feflow.utils.hybrid_topology import HybridTopologyFactoryModded as HybridTopologyFactory + from feflow.utils.hybrid_topology import HybridTopologyFactory # Check compatibility between states (same receptor and solvent) self._check_states_compatibility(state_a, state_b) diff --git a/feflow/tests/conftest.py b/feflow/tests/conftest.py index e412e94..292c2c4 100644 --- a/feflow/tests/conftest.py +++ b/feflow/tests/conftest.py @@ -6,9 +6,15 @@ from gufe.mapping import LigandAtomMapping -@pytest.fixture -def benzene_modifications(): - source = files("gufe.tests.data").joinpath("benzene_modifications.sdf") +@pytest.fixture(scope='session') +def gufe_data_dir(): + path = files("gufe.tests.data") + return path + + +@pytest.fixture(scope='session') +def benzene_modifications(gufe_data_dir): + source = gufe_data_dir.joinpath("benzene_modifications.sdf") with as_file(source) as f: supp = Chem.SDMolSupplier(str(f), removeHs=False) mols = list(supp) @@ -23,12 +29,12 @@ def solvent_comp(): yield gufe.SolventComponent(positive_ion="Na", negative_ion="Cl") -@pytest.fixture +@pytest.fixture(scope='session') def benzene(benzene_modifications): return gufe.SmallMoleculeComponent(benzene_modifications["benzene"]) -@pytest.fixture +@pytest.fixture(scope='session') def toluene(benzene_modifications): return gufe.SmallMoleculeComponent(benzene_modifications["toluene"]) @@ -130,7 +136,7 @@ def production_settings(short_settings): # Mappings fixtures -@pytest.fixture +@pytest.fixture(scope='session') def mapping_benzene_toluene(benzene, toluene): """Mapping from toluene to benzene""" mapping_toluene_to_benzene = {0: 4, 1: 5, 2: 6, 3: 7, 4: 8, 5: 9, 6: 10, 7: 11, 8: 12, 9: 13, 11: 14} diff --git a/feflow/tests/test_hybrid_topology.py b/feflow/tests/test_hybrid_topology.py new file mode 100644 index 0000000..96595fb --- /dev/null +++ b/feflow/tests/test_hybrid_topology.py @@ -0,0 +1,333 @@ +""" +Module to implement the basic unit testing for the hybrid topology implementation. +Specifically, regarding the HybridTopologyFactory object. +More oriented to testing code functionality than science correctness. +""" +import pytest + +from feflow.utils.hybrid_topology import HybridTopologyFactory +from feflow.tests.utils import extract_htf_data + +import mdtraj as mdt +import numpy as np +from openmm import unit as omm_unit +from openmm.app import NoCutoff, PME +from openmm import ( + MonteCarloBarostat, + NonbondedForce, + CustomNonbondedForce +) +from openmmforcefields.generators import SystemGenerator +from openff.units.openmm import ( + to_openmm, + from_openmm, + ensure_quantity +) +from perses.tests import utils as perses_utils + + +@pytest.fixture(scope="module") +def standard_system_generator(): + """ + Fixture to create a standard/default system generator based on commonly used options for + force fields, temperature and pressure. That is, amber forcefields for proteins, openff 2.1.0 + for small molecules, 1 bar pressure and 300k temperature. + + Returns + ------- + generator: openmmforcefields.generators.SystemGenerator + SystemGenerator object. + """ + sys_gen_config = {} + sys_gen_config["forcefields"] = ["amber/ff14SB.xml", + "amber/tip3p_standard.xml", + "amber/tip3p_HFE_multivalent.xml", + "amber/phosaa10.xml"] + sys_gen_config["small_molecule_forcefield"] = "openff-2.1.0" + sys_gen_config["nonperiodic_forcefield_kwargs"] = { + "nonbondedMethod": NoCutoff, + } + sys_gen_config["periodic_forcefield_kwargs"] = { + "nonbondedMethod": PME, + "nonbondedCutoff": 1.0 * omm_unit.nanometer, + } + sys_gen_config["barostat"] = MonteCarloBarostat(1 * omm_unit.bar, 300 * omm_unit.kelvin) + + generator = SystemGenerator(**sys_gen_config) + + return generator + + +class TestHybridTopologyFactory: + """Class to test the base/vanilla HybridTopologyFactory object""" + + def test_custom_nonbonded_cutoff(self): + """ + Test that nonbonded cutoff gets propagated to the custom nonbonded forces generated in the HTF via the + _add_nonbonded_force_terms method. + + Creates an HTF and manually changes the cutoff in the OLD system of the hybrid topology factory and checks the + expected behavior with both running or not running the referenced method. + """ + from openmm import NonbondedForce, CustomNonbondedForce + # TODO: we should probably make a fixture with the following top proposal and factory + topology_proposal, current_positions, new_positions = perses_utils.generate_solvated_hybrid_test_topology( + current_mol_name='propane', proposed_mol_name='pentane', vacuum=False) + # Extract htf data from proposal + htf_data = extract_htf_data(topology_proposal) + hybrid_factory = HybridTopologyFactory(old_positions=current_positions, + new_positions=new_positions, + **htf_data, use_dispersion_correction=True) + old_system_forces = hybrid_factory._old_system_forces + hybrid_system_forces = hybrid_factory.hybrid_system.getForces() + old_nonbonded_forces = [force for force in old_system_forces if + isinstance(force, NonbondedForce)] + hybrid_custom_nonbonded_forces = [force for force in hybrid_system_forces if + isinstance(force, CustomNonbondedForce)] + # Modify the cutoff for nonbonded forces in the OLD system (!) + for force in old_nonbonded_forces: + force.setCutoffDistance(force.getCutoffDistance() + 1 * omm_unit.nanometer) + # Assert that the nb cutoff distance is different compared to the custom nb forces + for custom_force in hybrid_custom_nonbonded_forces: + assert custom_force.getCutoffDistance() != \ + force.getCutoffDistance(), "Expected different cutoff distances between NB and custom NB forces." + # propagate the cutoffs + hybrid_factory._add_nonbonded_force_terms() + # Check now that cutoff match for all nonbonded forces (including custom) + for force in old_nonbonded_forces: + for custom_force in hybrid_custom_nonbonded_forces: + assert custom_force.getCutoffDistance() == \ + force.getCutoffDistance(), "Expected equal cutoff distances between NB and custom NB forces." + + def test_hybrid_topology_benzene_phenol(self, benzene, toluene, mapping_benzene_toluene, + standard_system_generator): + """ + Test the creation of a HybridTopologyFactory object from scratch from a benzene to toluene + transformation. + + Tests that we can create a HTF from scratch and checks that the difference in the number of + atoms in the Hybrid topology initial and final states is the expected one. + + Returns + ------- + None + """ + benzene_offmol = benzene.to_openff() + toluene_offmol = toluene.to_openff() + # Create Openmm topologies from openff topologies + off_top_benzene = benzene_offmol.to_topology() + off_top_phenol = toluene_offmol.to_topology() + # Create openmm topologies initial and final states + initial_top = off_top_benzene.to_openmm() + final_top = off_top_phenol.to_openmm() + # Create openmm systems with the small molecules + system_generator = standard_system_generator + initial_system = system_generator.create_system(initial_top, molecules=[benzene_offmol]) + final_system = system_generator.create_system(final_top, molecules=[toluene_offmol]) + initial_positions = benzene_offmol.conformers[-1].to_openmm() + final_positions = toluene_offmol.conformers[-1].to_openmm() + # mapping + mapping = mapping_benzene_toluene.componentA_to_componentB # Initial to final map + initial_to_final_atom_map = mapping + initial_to_final_core_atom_map = mapping + # Instantiate HTF + htf = HybridTopologyFactory( + initial_system, + initial_positions, + initial_top, + final_system, + final_positions, + final_top, + initial_to_final_atom_map, + initial_to_final_core_atom_map + ) + + # Validate number of atoms in hybrid topology end systems + n_atoms_diff = benzene_offmol.n_atoms - toluene_offmol.n_atoms # Initial - Final -- Sign/order matters + initial_htf_n_atoms = len(htf.initial_atom_indices) + final_htf_n_atoms = len(htf.final_atom_indices) + assert initial_htf_n_atoms - final_htf_n_atoms == n_atoms_diff, \ + "Different number of atoms in HTF compared to original molecules." + + # 16 atoms: + # 11 common atoms, 1 extra hydrogen in benzene, 4 extra in toluene + # 12 bonds in benzene + 4 extra toluene bonds + assert len(list(htf.hybrid_topology.atoms)) == 16 + assert len(list(htf.hybrid_topology.bonds)) == 16 + # check that the omm_hybrid_topology has the right things + assert len(list(htf.omm_hybrid_topology.atoms())) == 16 + assert len(list(htf.omm_hybrid_topology.bonds())) == 16 + # check that we can convert back the mdtraj hybrid_topology attribute + ret_top = mdt.Topology.to_openmm(htf.hybrid_topology) + assert len(list(ret_top.atoms())) == 16 + assert len(list(ret_top.bonds())) == 16 + + # TODO: Validate common atoms include 6 carbon atoms + + +class TestHTFVirtualSites: + @pytest.fixture(scope='module') + def tip4p_system_generator(self): + """ + SystemGenerator object with tip4p-ew water + + Returns + ------- + generator: openmmforcefields.generators.SystemGenerator + SystemGenerator object. + """ + sys_gen_config = {} + sys_gen_config["forcefields"] = ["amber/ff14SB.xml", + "amber/tip4pew_standard.xml", + "amber/phosaa10.xml"] + sys_gen_config["small_molecule_forcefield"] = "openff-2.1.0" + sys_gen_config["nonperiodic_forcefield_kwargs"] = { + "nonbondedMethod": NoCutoff, + } + sys_gen_config["periodic_forcefield_kwargs"] = { + "nonbondedMethod": PME, + "nonbondedCutoff": 1.0 * omm_unit.nanometer, + } + sys_gen_config["barostat"] = MonteCarloBarostat(1 * omm_unit.bar, 300 * omm_unit.kelvin) + + generator = SystemGenerator(**sys_gen_config) + + return generator + + @pytest.fixture(scope='module') + def tip4p_benzene_to_toluene_htf(self, tip4p_system_generator, + benzene, toluene, mapping_benzene_toluene): + """ + TODO: turn part of this into a method for creating HTFs? + """ + from gufe import SolventComponent + # TODO: change imports once utils get moved + from openfe.protocols.openmm_utils import system_creation + from openfe.protocols.openmm_rfe._rfe_utils import topologyhelpers + from openfe.protocols.openmm_utils.omm_settings import SolvationSettings + + benz_off = benzene.to_openff() + tol_off = toluene.to_openff() + + solv_settings = SolvationSettings() + solv_settings.solvent_model = 'tip4pew' + + for mol in [benz_off, tol_off]: + tip4p_system_generator.create_system( + mol.to_topology().to_openmm(), molecules=[mol] + ) + + # Create state A model & get relevant OpenMM objects + benz_model, comp_resids = system_creation.get_omm_modeller( + protein_comp=None, + solvent_comp=SolventComponent(), + small_mols={benzene: benz_off}, + omm_forcefield=tip4p_system_generator.forcefield, + solvent_settings=solv_settings, + ) + + benz_topology = benz_model.getTopology() + benz_positions = to_openmm(from_openmm(benz_model.getPositions())) + benz_system = tip4p_system_generator.create_system( + benz_topology, molecules=[benz_off] + ) + + # Now for state B + tol_topology, tol_alchem_resids = topologyhelpers.combined_topology( + benz_topology, tol_off.to_topology().to_openmm(), + exclude_resids=comp_resids[benzene] + ) + + tol_system = tip4p_system_generator.create_system( + tol_topology, molecules=[tol_off] + ) + + ligand_mappings = topologyhelpers.get_system_mappings( + mapping_benzene_toluene.componentA_to_componentB, + benz_system, benz_topology, comp_resids[benzene], + tol_system, tol_topology, tol_alchem_resids + ) + + tol_positions = topologyhelpers.set_and_check_new_positions( + ligand_mappings, + benz_topology, tol_topology, + old_positions=benz_positions, + insert_positions=to_openmm(tol_off.conformers[0]) + ) + + # Finally get the HTF + hybrid_factory = HybridTopologyFactory( + benz_system, benz_positions, benz_topology, + tol_system, tol_positions, tol_topology, + old_to_new_atom_map=ligand_mappings['old_to_new_atom_map'], + old_to_new_core_atom_map=ligand_mappings['old_to_new_core_atom_map'], + ) + + return hybrid_factory + + def test_tip4p_particle_count(self, tip4p_benzene_to_toluene_htf): + """ + Check that the particle count is conserved, i.e. no vsites are lost + or double counted. + """ + htf = tip4p_benzene_to_toluene_htf + old_count = htf._old_system.getNumParticles() + unique_new_count = len(htf._unique_new_atoms) + hybrid_particle_count = htf.hybrid_system.getNumParticles() + + assert old_count + unique_new_count == hybrid_particle_count + + def test_tip4p_num_waters(self, tip4p_benzene_to_toluene_htf): + """ + Check that the nuumber of virtual sites is equal to the number of + waters + """ + htf = tip4p_benzene_to_toluene_htf + + num_waters = len( + [r for r in htf._old_topology.residues() if r.name == 'HOH'] + ) + + virtual_sites = [ + ix for ix in range(htf.hybrid_system.getNumParticles()) if + htf.hybrid_system.isVirtualSite(ix) + ] + + assert num_waters == len(virtual_sites) + + def test_tip4p_check_vsite_parameters(self, tip4p_benzene_to_toluene_htf): + + htf = tip4p_benzene_to_toluene_htf + + virtual_sites = [ + ix for ix in range(htf.hybrid_system.getNumParticles()) if + htf.hybrid_system.isVirtualSite(ix) + ] + + # get the standard and custom nonbonded forces - one of each + nonbond = [f for f in htf.hybrid_system.getForces() + if isinstance(f, NonbondedForce)][0] + + cust_nonbond = [f for f in htf.hybrid_system.getForces() + if isinstance(f, CustomNonbondedForce)][0] + + # loop through every virtual site and check that they have the + # expected tip4p parameters + for entry in virtual_sites: + vs = htf.hybrid_system.getVirtualSite(entry) + vs_mass = htf.hybrid_system.getParticleMass(entry) + assert ensure_quantity(vs_mass, 'openff').m == pytest.approx(0) + vs_weights = [vs.getWeight(ix) for ix in range(vs.getNumParticles())] + np.testing.assert_allclose( + vs_weights, [0.786646558, 0.106676721, 0.106676721] + ) + c, s, e = nonbond.getParticleParameters(entry) + assert ensure_quantity(c, 'openff').m == pytest.approx(-1.04844) + assert ensure_quantity(s, 'openff').m == 1 + assert ensure_quantity(e, 'openff').m == 0 + + s1, e1, s2, e2, i, j = cust_nonbond.getParticleParameters(entry) + + assert i == j == 0 + assert s1 == s2 == 1 + assert e1 == e2 == 0 diff --git a/feflow/tests/test_lambdaprotocol.py b/feflow/tests/test_lambdaprotocol.py new file mode 100644 index 0000000..1472bac --- /dev/null +++ b/feflow/tests/test_lambdaprotocol.py @@ -0,0 +1,67 @@ +import os +import pytest + +from feflow.utils import lambda_protocol + +running_on_github_actions = os.environ.get('GITHUB_ACTIONS', None) == 'true' + + +def test_lambda_protocol(): + """ + + Tests LambdaProtocol, ensures that it can be instantiated with defaults, and that it fails if disallowed functions are tried + + """ + + # check that it's possible to instantiate a LambdaProtocol for all the default types + for protocol in ['default', 'namd', 'quarters']: + lp = lambda_protocol.LambdaProtocol(functions=protocol) + assert isinstance(lp, lambda_protocol.LambdaProtocol), "instantiated is not instance of LambdaProtocol." + + + +"""this test is a little unhappy + +it checks that missing terms are added back in + +however a more recent commit in openfe land changed this to error not warn + +so the test as-is can't function +""" +@pytest.mark.skip +def test_missing_functions(): + # check that if we give an incomplete set of parameters it will add in the missing terms + missing_functions = {'lambda_sterics_delete': lambda x: x} + lp = lambda_protocol.LambdaProtocol(functions=missing_functions) + assert (len(missing_functions) == 1) + assert(len(lp.get_functions()) == 9) + + +def test_lambda_protocol_failure_ends(): + bad_function = {'lambda_sterics_delete': lambda x: -x} + with pytest.raises(ValueError): + lp = lambda_protocol.LambdaProtocol(functions=bad_function) + + +def test_lambda_protocol_naked_charges(): + naked_charge_functions = {'lambda_sterics_insert': + lambda x: 0.0 if x < 0.5 else 2.0 * (x - 0.5), + 'lambda_electrostatics_insert': + lambda x: 2.0 * x if x < 0.5 else 1.0} + with pytest.raises(ValueError): + lp = lambda_protocol.LambdaProtocol(functions=naked_charge_functions) + + +def test_lambda_schedule_defaults(): + lambdas = lambda_protocol.LambdaProtocol(functions='default') + assert len(lambdas.lambda_schedule) == 10 + + +@pytest.mark.parametrize('windows', [11, 6, 9000]) +def test_lambda_schedule(windows): + lambdas = lambda_protocol.LambdaProtocol( + functions='default', + windows=windows + ) + assert len(lambdas.lambda_schedule) == windows + diff --git a/feflow/tests/test_relative.py b/feflow/tests/test_relative.py new file mode 100644 index 0000000..31f46e4 --- /dev/null +++ b/feflow/tests/test_relative.py @@ -0,0 +1,896 @@ +""" +Module to test setting up relative FE calculations using hybrid topologies as implemented +in the HybridTopologyFactory class. + +This module should be mostly related to testing the "science" consistency rather than the +code functionality. +""" + +########################################### +# IMPORTS +########################################### +import os + +import openmm +from openmm import app +from openmm import unit +import numpy as np +import pytest + +from feflow.utils.hybrid_topology import HybridTopologyFactory +from feflow.tests.utils import generate_endpoint_thermodynamic_states, extract_htf_data +from openmmtools.states import SamplerState +import openmmtools.mcmc as mcmc +import openmmtools.cache as cache + +from perses.tests import utils as perses_utils + +from openmmtools.multistate.pymbar import _pymbar_exp, detect_equilibration + +running_on_github_actions = os.environ.get('GITHUB_ACTIONS', None) == 'true' + +############################################# +# CONSTANTS +############################################# +# TODO: Maybe use openmmtools constants? +kB = unit.BOLTZMANN_CONSTANT_kB * unit.AVOGADRO_CONSTANT_NA +temperature = 300.0 * unit.kelvin +kT = kB * temperature +beta = 1.0 / kT +CARBON_MASS = 12.01 +ENERGY_THRESHOLD = 1e-1 +REFERENCE_PLATFORM = openmm.Platform.getPlatformByName("CPU") +aminos = ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE', 'LEU', 'LYS', 'MET', + 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL'] + + +def run_hybrid_endpoint_overlap(topology_proposal, current_positions, new_positions, + n_iterations=100): + """ + Test that the variance of the perturbation from lambda={0,1} to the corresponding nonalchemical endpoint is not + too large. + + Parameters + ---------- + topology_proposal : perses.rjmc.TopologyProposal + TopologyProposal object describing the transformation + current_positions : np.array, unit-bearing + Positions of the initial system + new_positions : np.array, unit-bearing + Positions of the new system + + Returns + ------- + hybrid_endpoint_results : list + list of [df, ddf, N_eff] for 1 and 0 + """ + # Extract data for HTF creation + htf_data = extract_htf_data(topology_proposal) + # Create the hybrid system: + hybrid_factory = HybridTopologyFactory( + old_positions=current_positions, + new_positions=new_positions, + **htf_data, + ) + + # Get the relevant thermodynamic states: + (nonalchemical_zero_thermodynamic_state, nonalchemical_one_thermodynamic_state, + lambda_zero_thermodynamic_state, lambda_one_thermodynamic_state) = ( + generate_endpoint_thermodynamic_states( + hybrid_factory.hybrid_system, topology_proposal)) + + nonalchemical_thermodynamic_states = [nonalchemical_zero_thermodynamic_state, + nonalchemical_one_thermodynamic_state] + + alchemical_thermodynamic_states = [lambda_zero_thermodynamic_state, + lambda_one_thermodynamic_state] + + # Create an MCMCMove, BAOAB with default parameters (but don't restart if we encounter a NaN) + mc_move = mcmc.LangevinDynamicsMove(n_restart_attempts=0, n_steps=100) + + initial_sampler_state = SamplerState(hybrid_factory.hybrid_positions, + box_vectors=hybrid_factory.hybrid_system.getDefaultPeriodicBoxVectors()) + + hybrid_endpoint_results = [] + all_results = [] + for lambda_state in (0, 1): + result, non, hybrid = run_endpoint_perturbation( + alchemical_thermodynamic_states[lambda_state], + nonalchemical_thermodynamic_states[lambda_state], initial_sampler_state, + mc_move, n_iterations, hybrid_factory, lambda_index=lambda_state) + all_results.append(non) + all_results.append(hybrid) + print('lambda {} : {}'.format(lambda_state, result)) + + hybrid_endpoint_results.append(result) + calculate_cross_variance(all_results) + return hybrid_endpoint_results + + +def calculate_cross_variance(all_results): + """ + Calculates the overlap (df and ddf) between the non-alchemical state at lambda=0 to the hybrid state at lambda=1 and visa versa + These ensembles are not expected to have good overlap, as they are of explicitly different system, but provides a benchmark of appropriate dissimilarity + """ + if len(all_results) != 4: + return + else: + non_a = all_results[0] + hybrid_a = all_results[1] + non_b = all_results[2] + hybrid_b = all_results[3] + print('CROSS VALIDATION') + [df, ddf] = _pymbar_exp(non_a - hybrid_b) + print('df: {}, ddf: {}'.format(df, ddf)) + [df, ddf] = _pymbar_exp(non_b - hybrid_a) + print('df: {}, ddf: {}'.format(df, ddf)) + return + + +def check_result(results, threshold=3.0, neffmin=10): + """ + Ensure results are within threshold standard deviations and Neff_max > neffmin + + Parameters + ---------- + results : list + list of [df, ddf, Neff_max] + threshold : float, default 3.0 + the standard deviation threshold + neff_min : float, default 10 + the minimum number of acceptable samples + """ + [df, ddf, t0, N_eff] = results + + if N_eff < neffmin: + raise Exception("Number of effective samples %f was below minimum of %f" % (N_eff, neffmin)) + + if ddf > threshold: + raise Exception("Standard deviation of %f exceeds threshold of %f" % (ddf, threshold)) + + +def test_networkx_proposal_order(): + """ + This test fails with a 'no topical torsions found' error with the old ProposalOrderTools + """ + pairs = [('pentane', 'propane')] + for pair in pairs: + print('{} -> {}'.format(pair[0], pair[1])) + simple_overlap(pair[0], pair[1]) + print('{} -> {}'.format(pair[1], pair[0])) + simple_overlap(pair[1], pair[0]) + + +def test_explosion(): + """ + This test fails with ridiculous DeltaF if the alchemical factory is misbehaving + """ + pairs = [['2-phenyl ethanol', 'benzene']] + for pair in pairs: + print('{} -> {}'.format(pair[0], pair[1])) + simple_overlap(pair[0], pair[1]) + print('{} -> {}'.format(pair[1], pair[0])) + simple_overlap(pair[1], pair[0]) + + +def test_vacuum_overlap_with_constraints(): + """ + Test that constraints do not cause problems for the hybrid factory in vacuum + """ + simple_overlap('2-phenyl ethanol', 'benzene', + forcefield_kwargs={'constraints': app.HBonds}) + + +def test_valence_overlap(): + """ + Test hybrid factory vacuum overlap with valence terms only + """ + system_generator_kwargs = { + 'particle_charge': False, 'exception_charge': False, 'particle_epsilon': False, + 'exception_epsilon': False, 'torsions': True, + } + simple_overlap('2-phenyl ethanol', 'benzene', + system_generator_kwargs=system_generator_kwargs) + + +def test_bonds_angles_overlap(): + """ + Test hybrid factory vacuum overlap with bonds and angles + """ + system_generator_kwargs = { + 'particle_charge': False, 'exception_charge': False, 'particle_epsilon': False, + 'exception_epsilon': False, 'torsions': False, + } + simple_overlap('2-phenyl ethanol', 'benzene', + system_generator_kwargs=system_generator_kwargs) + + +def test_sterics_overlap(): + """ + Test hybrid factory vacuum overlap with valence terms and sterics only + """ + system_generator_kwargs = { + 'particle_charge': False, 'exception_charge': False, 'particle_epsilon': True, + 'exception_epsilon': True, 'torsions': True, + } + simple_overlap('2-phenyl ethanol', 'benzene', + system_generator_kwargs=system_generator_kwargs) + + +def test_simple_overlap_pairs(pairs=None): + """ + Test to run pairs of small molecule perturbations in vacuum, using test_simple_overlap, both forward and backward. + + Parameters + ---------- + pairs : list of lists of str, optional, default=None + Pairs of IUPAC names to test. + If None, will test a default set: + [['pentane','butane'],['fluorobenzene', 'chlorobenzene'],['benzene', 'catechol'],['benzene','2-phenyl ethanol'],['imatinib','nilotinib']] + + pentane <-> butane is adding a methyl group + fluorobenzene <-> chlorobenzene perturbs one halogen to another, with no adding or removing of atoms + benzene <-> catechol perturbing molecule in two positions simultaneously + benzene <-> 2-phenyl ethanol addition of 3 heavy atom group + """ + if pairs is None: + pairs = [['pentane', 'butane'], ['fluorobenzene', 'chlorobenzene'], ['benzene', 'catechol'], + ['benzene', '2-phenyl ethanol']] # 'imatinib' --> 'nilotinib' atom mapping is bad + + for pair in pairs: + print('{} -> {}'.format(pair[0], pair[1])) + simple_overlap(pair[0], pair[1]) + # Now running the reverse + print('{} -> {}'.format(pair[1], pair[0])) + simple_overlap(pair[1], pair[0]) + + +def simple_overlap(name1='pentane', name2='butane', forcefield_kwargs=None, + system_generator_kwargs=None): + """Test that the variance of the hybrid -> real perturbation in vacuum is sufficiently small. + + Parameters + ---------- + name1 : str + IUPAC name of initial molecule + name2 : str + IUPAC name of final molecule + forcefield_kwargs : dict, optional, default=None + If None, these parameters are fed to the SystemGenerator + Setting { 'constraints' : app.HBonds } will enable constraints to hydrogen + system_generator_kwargs : dict, optional, default=None + If None, these parameters are fed to the SystemGenerator + Setting { 'particle_charge' : False } will turn off particle charges in parameterized systems + Can also disable 'exception_charge', 'particle_epsilon', 'exception_epsilon', and 'torsions' by setting to False + + """ + topology_proposal, current_positions, new_positions = perses_utils.generate_solvated_hybrid_test_topology( + current_mol_name=name1, proposed_mol_name=name2, vacuum=True) + results = run_hybrid_endpoint_overlap(topology_proposal, current_positions, new_positions) + for idx, lambda_result in enumerate(results): + try: + check_result(lambda_result) + except Exception as e: + message = "pentane->butane failed at lambda %d \n" % idx + message += str(e) + raise Exception(message) + + +# @skipIf(running_on_github_actions, "Skip expensive test on GH Actions") +@pytest.mark.skip(reason="Skip expensive test on GH Actions") +def test_hostguest_overlap(): + """Test that the variance of the endpoint->nonalchemical perturbation is sufficiently small for host-guest system in vacuum""" + topology_proposal, current_positions, new_positions = perses_utils.generate_vacuum_hostguest_proposal() + results = run_hybrid_endpoint_overlap(topology_proposal, current_positions, new_positions) + + for idx, lambda_result in enumerate(results): + try: + check_result(lambda_result) + except Exception as e: + message = "pentane->butane failed at lambda %d \n" % idx + message += str(e) + raise Exception(message) + +# TODO: This is skipped on perses, probably needs more iterations to converge? Skipping for now. +@pytest.mark.skip(reason="Expensive. Hard to converge.") +def test_difficult_overlap(n_iterations=500): + """Test that the variance of the endpoint->nonalchemical perturbation is sufficiently small for imatinib->nilotinib in solvent""" + name1 = 'imatinib' + name2 = 'nilotinib' + + print(name1, name2) + topology_proposal, solvated_positions, new_positions = perses_utils.generate_solvated_hybrid_test_topology( + current_mol_name=name1, proposed_mol_name=name2) + results = run_hybrid_endpoint_overlap(topology_proposal, solvated_positions, new_positions) + + for idx, lambda_result in enumerate(results): + try: + check_result(lambda_result) + except Exception as e: + message = "solvated imatinib->nilotinib failed at lambda %d \n" % idx + message += str(e) + raise Exception(message) + + print(name2, name1) + topology_proposal, solvated_positions, new_positions = perses_utils.generate_solvated_hybrid_test_topology( + current_mol_name=name2, proposed_mol_name=name1) + results = run_hybrid_endpoint_overlap(topology_proposal, solvated_positions, new_positions, + n_iterations=n_iterations) + + for idx, lambda_result in enumerate(results): + try: + check_result(lambda_result) + except Exception as e: + message = "solvated imatinib->nilotinib failed at lambda %d \n" % idx + message += str(e) + raise Exception(message) + + +def run_endpoint_perturbation(lambda_thermodynamic_state, nonalchemical_thermodynamic_state, + initial_hybrid_sampler_state, mc_move, n_iterations, factory, + lambda_index=0, print_work=False, write_system=False, + write_state=False, write_trajectories=False): + """ + + Parameters + ---------- + lambda_thermodynamic_state : ThermodynamicState + The thermodynamic state corresponding to the hybrid system at a lambda endpoint + nonalchemical_thermodynamic_state : ThermodynamicState + The nonalchemical thermodynamic state for the relevant endpoint + initial_hybrid_sampler_state : SamplerState + Starting positions for the sampler. Must be compatible with lambda_thermodynamic_state + mc_move : MCMCMove + The MCMove that will be used for sampling at the lambda endpoint + n_iterations : int + The number of iterations + factory : HybridTopologyFactory + The hybrid topology factory + lambda_index : int, optional, default=0 + The index, 0 or 1, at which to retrieve nonalchemical positions + print_work : bool, optional, default=False + If True, will print work values + write_system : bool, optional, default=False + If True, will write alchemical and nonalchemical System XML files + write_state : bool, optional, default=False + If True, write alchemical (hybrid) State XML files each iteration + write_trajectories : bool, optional, default=False + If True, will write trajectories + + Returns + ------- + df : float + Free energy difference between alchemical and nonalchemical systems, estimated with EXP + ddf : float + Standard deviation of estimate, corrected for correlation, from EXP estimator. + """ + import mdtraj as md + + # Run an initial minimization: + mcmc_sampler = mcmc.MCMCSampler(lambda_thermodynamic_state, initial_hybrid_sampler_state, + mc_move) + mcmc_sampler.minimize(max_iterations=20) + new_sampler_state = mcmc_sampler.sampler_state + + if write_system: + with open(f'hybrid{lambda_index}-system.xml', 'w') as outfile: + outfile.write(openmm.XmlSerializer.serialize(lambda_thermodynamic_state.system)) + with open(f'nonalchemical{lambda_index}-system.xml', 'w') as outfile: + outfile.write(openmm.XmlSerializer.serialize(nonalchemical_thermodynamic_state.system)) + + # Initialize work array + w = np.zeros([n_iterations]) + non_potential = np.zeros([n_iterations]) + hybrid_potential = np.zeros([n_iterations]) + + # Run n_iterations of the endpoint perturbation: + hybrid_trajectory = unit.Quantity( + np.zeros([n_iterations, lambda_thermodynamic_state.system.getNumParticles(), 3]), + unit.nanometers) # DEBUG + nonalchemical_trajectory = unit.Quantity( + np.zeros([n_iterations, nonalchemical_thermodynamic_state.system.getNumParticles(), 3]), + unit.nanometers) # DEBUG + for iteration in range(n_iterations): + # Generate a new sampler state for the hybrid system + mc_move.apply(lambda_thermodynamic_state, new_sampler_state) + + # Compute the hybrid reduced potential at the new sampler state + hybrid_context, integrator = cache.global_context_cache.get_context( + lambda_thermodynamic_state) + new_sampler_state.apply_to_context(hybrid_context, ignore_velocities=True) + hybrid_reduced_potential = lambda_thermodynamic_state.reduced_potential(hybrid_context) + + if write_state: + state = hybrid_context.getState(getPositions=True, getParameters=True) + state_xml = openmm.XmlSerializer.serialize(state) + with open(f'state{iteration}_l{lambda_index}.xml', 'w') as outfile: + outfile.write(state_xml) + + # Construct a sampler state for the nonalchemical system + if lambda_index == 0: + nonalchemical_positions = factory.old_positions(new_sampler_state.positions) + elif lambda_index == 1: + nonalchemical_positions = factory.new_positions(new_sampler_state.positions) + else: + raise ValueError( + "The lambda index needs to be either one or zero for this to be meaningful") + nonalchemical_sampler_state = SamplerState(nonalchemical_positions, + box_vectors=new_sampler_state.box_vectors) + + if write_trajectories: + state = hybrid_context.getState(getPositions=True) + hybrid_trajectory[iteration, :, :] = state.getPositions(asNumpy=True) + nonalchemical_trajectory[iteration, :, :] = nonalchemical_positions + + # Compute the nonalchemical reduced potential + nonalchemical_context, integrator = cache.global_context_cache.get_context( + nonalchemical_thermodynamic_state) + nonalchemical_sampler_state.apply_to_context(nonalchemical_context, ignore_velocities=True) + nonalchemical_reduced_potential = nonalchemical_thermodynamic_state.reduced_potential( + nonalchemical_context) + + # Compute and store the work + w[iteration] = nonalchemical_reduced_potential - hybrid_reduced_potential + non_potential[iteration] = nonalchemical_reduced_potential + hybrid_potential[iteration] = hybrid_reduced_potential + + if print_work: + print( + f'{iteration:8d} {hybrid_reduced_potential:8.3f} {nonalchemical_reduced_potential:8.3f} => {w[iteration]:8.3f}') + + # TODO: Do we need to write trajectories? Maybe for debugging purposes + if write_trajectories: + if lambda_index == 0: + nonalchemical_mdtraj_topology = md.Topology.from_openmm( + factory._old_topology) + elif lambda_index == 1: + nonalchemical_mdtraj_topology = md.Topology.from_openmm( + factory._new_topology) + md.Trajectory(hybrid_trajectory / unit.nanometers, factory.hybrid_topology).save( + f'hybrid{lambda_index}.pdb') + md.Trajectory(nonalchemical_trajectory / unit.nanometers, + nonalchemical_mdtraj_topology).save(f'nonalchemical{lambda_index}.pdb') + + # Analyze data and return results + [t0, g, Neff_max] = detect_equilibration(w) + w_burned_in = w[t0:] + [df, ddf] = _pymbar_exp(w_burned_in) + ddf_corrected = ddf * np.sqrt(g) + results = [df, ddf_corrected, t0, Neff_max] + + return results, non_potential, hybrid_potential + + +# TODO: temporarily we can depend on perses for this tests while we figure things out, because it seems important +# - CI to require perses for now +# - Then in the future rewrite for it not to require perses +def compare_energies(mol_name="naphthalene", ref_mol_name="benzene", + atom_expression=['Hybridization'], bond_expression=['Hybridization']): + """ + Make an atom map where the molecule at either lambda endpoint is identical, and check that the energies are also the same. + """ + from openmoltools.openeye import generate_conformers + from openmmtools.constants import kB + from perses.rjmc.topology_proposal import SmallMoleculeSetProposalEngine + from perses.rjmc.geometry import FFAllAngleGeometryEngine + from perses.utils.openeye import iupac_to_oemol, extractPositionsFromOEMol + from perses.utils.openeye import generate_expression + from openmmforcefields.generators import SystemGenerator + from openmoltools.forcefield_generators import generateTopologyFromOEMol + from perses.dispersed.utils import validate_endstate_energies + temperature = 300 * unit.kelvin + # Compute kT and inverse temperature. + kT = kB * temperature + beta = 1.0 / kT + ENERGY_THRESHOLD = 1e-6 + + atom_expr, bond_expr = generate_expression(atom_expression), generate_expression( + bond_expression) + + mol = iupac_to_oemol(mol_name) + mol = generate_conformers(mol, max_confs=1) + + refmol = iupac_to_oemol(ref_mol_name) + refmol = generate_conformers(refmol, max_confs=1) + + from openff.toolkit.topology import Molecule + molecules = [Molecule.from_openeye(oemol) for oemol in [refmol, mol]] + barostat = None + forcefield_files = ['amber14/protein.ff14SB.xml', 'amber14/tip3p.xml'] + forcefield_kwargs = {'removeCMMotion': False, 'ewaldErrorTolerance': 1e-4, + 'constraints': app.HBonds, 'hydrogenMass': 3 * unit.amus} + nonperiodic_forcefield_kwargs = {'nonbondedMethod': app.NoCutoff} + + system_generator = SystemGenerator(forcefields=forcefield_files, barostat=barostat, + forcefield_kwargs=forcefield_kwargs, + nonperiodic_forcefield_kwargs=nonperiodic_forcefield_kwargs, + small_molecule_forcefield='gaff-2.11', molecules=molecules, + cache=None) + + # Make a topology proposal with the appropriate data: + topology = generateTopologyFromOEMol(refmol) + system = system_generator.create_system(topology) + positions = extractPositionsFromOEMol(refmol) + + proposal_engine = SmallMoleculeSetProposalEngine([refmol, mol], system_generator, + atom_expr=atom_expr, bond_expr=bond_expr, + allow_ring_breaking=True) + proposal = proposal_engine.propose(system, topology) + geometry_engine = FFAllAngleGeometryEngine() + new_positions, _ = geometry_engine.propose(proposal, positions, beta=beta, + validate_energy_bookkeeping=False) + _ = geometry_engine.logp_reverse(proposal, new_positions, positions, beta) + + # Extract data from proposal to instantiate HTF + htf_data = extract_htf_data(proposal) + factory = HybridTopologyFactory(old_positions=positions, new_positions=new_positions, + **htf_data) + added_valence_energy = (geometry_engine.forward_final_context_reduced_potential + - geometry_engine.forward_atoms_with_positions_reduced_potential) + + subtracted_valence_energy = (geometry_engine.reverse_final_context_reduced_potential + - geometry_engine.reverse_atoms_with_positions_reduced_potential) + + _ = validate_endstate_energies(proposal, + factory, added_valence_energy, + subtracted_valence_energy, + beta=1.0 / (kB * temperature), + ENERGY_THRESHOLD=ENERGY_THRESHOLD, + platform=openmm.Platform.getPlatformByName( + 'Reference')) + return factory + + +def test_compare_energies(): + mols_and_refs = [['naphthalene', 'benzene'], ['pentane', 'propane'], ['biphenyl', 'benzene']] + + for mol_ref_pair in mols_and_refs: + _ = compare_energies(mol_name=mol_ref_pair[0], ref_mol_name=mol_ref_pair[1]) + + +def test_position_output(): + """ + Test that the hybrid returns the correct positions for the new and old systems after construction + """ + import numpy as np + + # Generate topology proposal + topology_proposal, old_positions, new_positions = perses_utils.generate_solvated_hybrid_test_topology() + # Extract HTF data from proposal + htf_data = extract_htf_data(topology_proposal) + factory = HybridTopologyFactory(old_positions=old_positions, new_positions=new_positions, + **htf_data) + + old_positions_factory = factory.old_positions(factory.hybrid_positions) + new_positions_factory = factory.new_positions(factory.hybrid_positions) + + assert np.all(np.isclose(old_positions.in_units_of(unit.nanometers), + old_positions_factory.in_units_of(unit.nanometers))) + assert np.all(np.isclose(new_positions.in_units_of(unit.nanometers), + new_positions_factory.in_units_of(unit.nanometers))) + + +def test_generate_endpoint_thermodynamic_states(): + """ + test whether the hybrid system zero and one thermodynamic states have the appropriate lambda values + """ + topology_proposal, current_positions, new_positions = perses_utils.generate_solvated_hybrid_test_topology( + current_mol_name='propane', proposed_mol_name='pentane', vacuum=False) + # Extract htf data from proposal + htf_data = extract_htf_data(topology_proposal) + hybrid_factory = HybridTopologyFactory(old_positions=current_positions, + new_positions=new_positions, + **htf_data, + use_dispersion_correction=True) + + # Get the relevant thermodynamic states: + _, _, lambda_zero_thermodynamic_state, lambda_one_thermodynamic_state = generate_endpoint_thermodynamic_states( + hybrid_factory.hybrid_system, topology_proposal) + # Check the parameters for each state + lambda_protocol = ['lambda_sterics_core', 'lambda_electrostatics_core', 'lambda_sterics_insert', + 'lambda_electrostatics_insert', 'lambda_sterics_delete', + 'lambda_electrostatics_delete'] + for value in lambda_protocol: + if getattr(lambda_zero_thermodynamic_state, value) != 0.: + raise Exception( + 'Interaction {} not set to 0. at lambda = 0. {} set to {}'.format(value, value, + getattr( + lambda_one_thermodynamic_state, + value))) + if getattr(lambda_one_thermodynamic_state, value) != 1.: + raise Exception( + 'Interaction {} not set to 1. at lambda = 1. {} set to {}'.format(value, value, + getattr( + lambda_one_thermodynamic_state, + value))) + + +def HybridTopologyFactory_energies(current_mol='toluene', + proposed_mol='1,2-bis(trifluoromethyl) benzene', + validate_geometry_energy_bookkeeping=True): + """ + Test whether the difference in the nonalchemical zero and alchemical zero states is the forward valence energy. Also test for the one states. + """ + import openmmtools.cache as cache + from perses.rjmc.geometry import FFAllAngleGeometryEngine + + # Just test the solvated system + top_proposal, old_positions, _ = perses_utils.generate_solvated_hybrid_test_topology( + current_mol_name=current_mol, proposed_mol_name=proposed_mol) + + # Remove the dispersion correction + force_names_old_system = {force.__class__.__name__: index for index, force in + enumerate(top_proposal._old_system.getForces())} + force_names_new_system = {force.__class__.__name__: index for index, force in + enumerate(top_proposal._new_system.getForces())} + top_proposal._old_system.getForce( + force_names_old_system["NonbondedForce"]).setUseDispersionCorrection(False) + top_proposal._new_system.getForce( + force_names_new_system["NonbondedForce"]).setUseDispersionCorrection(False) + + # Run geometry engine to generate old and new positions + _geometry_engine = FFAllAngleGeometryEngine(metadata=None, use_sterics=False, + n_bond_divisions=100, n_angle_divisions=180, + n_torsion_divisions=360, verbose=True, storage=None, + bond_softening_constant=1.0, + angle_softening_constant=1.0, neglect_angles=False) + _new_positions, _lp = _geometry_engine.propose(top_proposal, old_positions, beta, + validate_geometry_energy_bookkeeping) + _lp_rev = _geometry_engine.logp_reverse(top_proposal, _new_positions, old_positions, beta, + validate_geometry_energy_bookkeeping) + + # Make the hybrid system, reset the CustomNonbondedForce cutoff + # Extract htf data from proposal + htf_data = extract_htf_data(top_proposal) + HTF = HybridTopologyFactory(old_positions=old_positions, new_positions=_new_positions, + **htf_data) + hybrid_system = HTF.hybrid_system + + nonalch_zero, nonalch_one, alch_zero, alch_one = generate_endpoint_thermodynamic_states( + hybrid_system, top_proposal) + + # Compute reduced energies for the nonalchemical systems... + attrib_list = [ + (nonalch_zero, old_positions, top_proposal._old_system.getDefaultPeriodicBoxVectors()), + (alch_zero, HTF._hybrid_positions, hybrid_system.getDefaultPeriodicBoxVectors()), + (alch_one, HTF._hybrid_positions, hybrid_system.getDefaultPeriodicBoxVectors()), + (nonalch_one, _new_positions, top_proposal._new_system.getDefaultPeriodicBoxVectors())] + + rp_list = [] + for (state, pos, box_vectors) in attrib_list: + context, integrator = cache.global_context_cache.get_context(state) + samplerstate = SamplerState(positions=pos, box_vectors=box_vectors) + samplerstate.apply_to_context(context) + rp = state.reduced_potential(context) + rp_list.append(rp) + + # Valence energy definitions + forward_added_valence_energy = _geometry_engine.forward_final_context_reduced_potential - _geometry_engine.forward_atoms_with_positions_reduced_potential + reverse_subtracted_valence_energy = _geometry_engine.reverse_final_context_reduced_potential - _geometry_engine.reverse_atoms_with_positions_reduced_potential + + nonalch_zero_rp, alch_zero_rp, alch_one_rp, nonalch_one_rp = rp_list[0], rp_list[1], rp_list[2], \ + rp_list[3] + # print(f"Difference between zeros: {nonalch_zero_rp - alch_zero_rp}; forward added: {forward_added_valence_energy}") + # print(f"Difference between ones: {nonalch_zero_rp - alch_zero_rp}; forward added: {forward_added_valence_energy}") + + assert abs( + nonalch_zero_rp - alch_zero_rp + forward_added_valence_energy) < ENERGY_THRESHOLD, f"The zero state alchemical and nonalchemical energy absolute difference {abs(nonalch_zero_rp - alch_zero_rp + forward_added_valence_energy)} is greater than the threshold of {ENERGY_THRESHOLD}." + assert abs( + nonalch_one_rp - alch_one_rp + reverse_subtracted_valence_energy) < ENERGY_THRESHOLD, f"The one state alchemical and nonalchemical energy absolute difference {abs(nonalch_one_rp - alch_one_rp + reverse_subtracted_valence_energy)} is greater than the threshold of {ENERGY_THRESHOLD}." + + print( + f"Abs difference in zero alchemical vs nonalchemical systems: {abs(nonalch_zero_rp - alch_zero_rp + forward_added_valence_energy)}") + print( + f"Abs difference in one alchemical vs nonalchemical systems: {abs(nonalch_one_rp - alch_one_rp + reverse_subtracted_valence_energy)}") + + +def test_HybridTopologyFactory_energies( + molecule_perturbation_list=[['naphthalene', 'benzene'], ['pentane', 'propane'], + ['biphenyl', 'benzene']], validations=[False, True, False]): + """ + Test whether the difference in the nonalchemical zero and alchemical zero states is the forward valence energy. Also test for the one states. + """ + for molecule_pair, validate in zip(molecule_perturbation_list, validations): + print(f"\tconduct energy comparison for {molecule_pair[0]} --> {molecule_pair[1]}") + HybridTopologyFactory_energies(current_mol=molecule_pair[0], proposed_mol=molecule_pair[1], + validate_geometry_energy_bookkeeping=validate) + + +# TODO: Current HTF doesn't allow RMSD restraint, but we probably need it in the future? +@pytest.mark.skip( + reason="No implementation of RMSD restraints. Might be needed in the future, though.") +def test_RMSD_restraint(): + """ + test the creation of an RMSD restraint between core heavy atoms and protein CA atoms on a hostguest transformation in a periodic solvent. + will assert the existence of an RMSD force, minimizes at lambda=0, and runs 500 steps of MD. + + """ + from pkg_resources import resource_filename + from perses.app.relative_setup import RelativeFEPSetup + from openmmtools.states import ThermodynamicState, SamplerState + from openmmtools.integrators import LangevinIntegrator + from perses.dispersed.utils import minimize + + # Setup directory + ligand_sdf = resource_filename("perses", "data/given-geometries/ligands.sdf") + host_pdb = resource_filename("perses", "data/given-geometries/receptor.pdb") + + setup = RelativeFEPSetup( + ligand_input=ligand_sdf, + old_ligand_index=0, + new_ligand_index=1, + forcefield_files=['amber/ff14SB.xml', 'amber/tip3p_standard.xml', + 'amber/tip3p_HFE_multivalent.xml'], + phases=['complex', 'solvent', 'vacuum'], + protein_pdb_filename=host_pdb, + receptor_mol2_filename=None, + pressure=1.0 * unit.atmosphere, + temperature=300.0 * unit.kelvin, + solvent_padding=9.0 * unit.angstroms, + ionic_strength=0.15 * unit.molar, + hmass=3 * unit.amus, + neglect_angles=False, + map_strength='default', + atom_expr=None, + bond_expr=None, + anneal_14s=False, + small_molecule_forcefield='gaff-2.11', + small_molecule_parameters_cache=None, + trajectory_directory=None, + trajectory_prefix=None, + spectator_filenames=None, + nonbonded_method='PME', + complex_box_dimensions=None, + solvent_box_dimensions=None, + remove_constraints=False, + use_given_geometries=False + ) + phase = 'complex' + top_prop = setup._complex_topology_proposal + htf = HybridTopologyFactory(setup._complex_topology_proposal, + setup.complex_old_positions, + setup.complex_new_positions, + rmsd_restraint=True + ) + # assert there is at least a CV force + force_names = {htf._hybrid_system.getForce(i).__class__.__name__: htf._hybrid_system.getForce(i) + for i in range(htf._hybrid_system.getNumForces())} + assert 'CustomCVForce' in list(force_names.keys()) + coll_var_name = force_names['CustomCVForce'].getCollectiveVariableName(0) + assert coll_var_name == 'RMSD' + coll_var = force_names['CustomCVForce'].getCollectiveVariable(0) + coll_var_particles = coll_var.getParticles() + assert len( + coll_var_particles) > 0 # the number of particles is nonzero. this will cause problems otherwise + # assert coll_var.usesPeriodicBoundaryConditions() #should this be the case? + + # make thermo and sampler state + thermostate = ThermodynamicState(system=htf._hybrid_system, temperature=300 * unit.kelvin, + pressure=1.0 * unit.atmosphere) + ss = SamplerState(positions=htf._hybrid_positions, + box_vectors=htf._hybrid_system.getDefaultPeriodicBoxVectors()) + + # attempt to minimize + minimize(thermostate, ss) + + # run simulation to validate no nans + integrator = LangevinIntegrator(300 * unit.kelvin, 5.0 / unit.picosecond, + 2.0 * unit.femtosecond) + context = thermostate.create_context(integrator) + ss.apply_to_context(context) + context.setVelocitiesToTemperature(300 * unit.kelvin) + + integrator.step(500) + + +def run_unsampled_endstate_energies(use_point_energies=True, use_md_energies=False): + """ + Check that the energies of the unsampled endstate hybrid systems generated by dispersed/utils.py/create_endstates_from_real_systems() + match the energies of the original hybrid system. + + Checks the energies using validate_unsampled_endstates_point() and/or validate_unsampled_endstates_md() -- first for + RESTCapableHybridTopologyFactory and then for HybridTopologyFactory + + Parameters + ---------- + test_name : str + Name of the test system. Currently supports: 'ala-dipeptide', 'barstar'. + use_point_energies : boolean, default True + Whether to run the point energy test for energy validation. + use_md_energies : boolean, default False + Whether to run the MD energy test for energy validation. + + """ + + from perses.dispersed.utils import create_endstates_from_real_systems + from perses.tests.utils import validate_unsampled_endstates_point, \ + validate_unsampled_endstates_md + + import tempfile + from perses.utils.url_utils import retrieve_file_url + from perses.app.relative_setup import RelativeFEPSetup + + def concatenate_files(input_files, output_file): + """ + Concatenate files given in input_files iterator into output_file. + """ + with open(output_file, 'w') as outfile: + for filename in input_files: + with open(filename) as infile: + for line in infile: + outfile.write(line) + + with tempfile.TemporaryDirectory() as temp_dir: + # Fetch ligands sdf files and concatenate them in one + base_repo_url = "https://github.com/openforcefield/protein-ligand-benchmark" + ligand_files = [] + for ligand in ['lig_ejm_42', 'lig_ejm_54']: + ligand_url = f"{base_repo_url}/raw/0.2.1/data/2020-02-07_tyk2/02_ligands/{ligand}/crd/{ligand}.sdf" + ligand_file = retrieve_file_url(ligand_url) + ligand_files.append(ligand_file) + concatenate_files(ligand_files, os.path.join(temp_dir, 'ligands.sdf')) + ligands_filename = os.path.join(temp_dir, 'ligands.sdf') + + # Retrieve host PDB + pdb_url = f"{base_repo_url}/raw/0.2.1/data/2020-02-07_tyk2/01_protein/crd/protein.pdb" + host_pdb = retrieve_file_url(pdb_url) + + # TODO: This needs to use our setup pipeline (no perses) + # Generate topology proposal, old/new positions + fe_setup = RelativeFEPSetup( + ligand_input=ligands_filename, + protein_pdb_filename=host_pdb, + old_ligand_index=0, + new_ligand_index=1, + forcefield_files=['amber/ff14SB.xml', 'amber/tip3p_standard.xml', + 'amber/tip3p_HFE_multivalent.xml'], + small_molecule_forcefield="gaff-2.11", + phases=["complex"], + ) + + # Generate htfs + # extract htf data from proposal + htf_data = extract_htf_data(fe_setup.complex_topology_proposal) + htf = HybridTopologyFactory( + old_positions=fe_setup.complex_old_positions, + new_positions=fe_setup.complex_new_positions, + **htf_data, + ) + + # Modify the htf for tests + # For these tests, we need to turn the LRC on for the CustomNonbondedForce, since the LRC is on for the real systems + force_dict = {force.getName(): index for index, force in + enumerate(htf.hybrid_system.getForces())} + if htf.__class__.__name__ == 'HybridTopologyFactory': + htf.hybrid_system.getForce( + force_dict['CustomNonbondedForce']).setUseLongRangeCorrection(True) + elif htf.__class__.__name__ == 'RESTCapableHybridTopologyFactory': + htf.hybrid_system.getForce( + force_dict['CustomNonbondedForce_sterics']).setUseLongRangeCorrection(True) + + # Generate unsampled endstates + unsampled_endstates = create_endstates_from_real_systems(htf, for_testing=True) + + # Check to make sure energies are the same in the unsampled endstate hybrid system as they are in the original hybrid system + for endstate in [0, 1]: + if use_point_energies: + validate_unsampled_endstates_point(htf, unsampled_endstates[endstate].system, + endstate, minimize=True) + if use_md_energies: + validate_unsampled_endstates_md(htf, unsampled_endstates[endstate].system, endstate, + n_steps=10, save_freq=1) + + +@pytest.mark.gpu_needed +@pytest.mark.skip(reason="Skip expensive. Needs GPU.") +def test_unsampled_endstate_energies_GPU(): + """ + Uses run_unsampled_endstate_energies() to run energy validation for the unsampled endstates generated for + RESTCapableHybridTopologyFactory and HybridTopologyFactory. + + Test systems: alanine dipeptide in solvent and barstar in solvent + + Only run this on a GPU as the CPU is too slow. + """ + # Tyk2 -- Run point and MD energy validation tests + run_unsampled_endstate_energies(use_point_energies=True, use_md_energies=True) diff --git a/feflow/tests/utils/__init__.py b/feflow/tests/utils/__init__.py new file mode 100644 index 0000000..6bea5da --- /dev/null +++ b/feflow/tests/utils/__init__.py @@ -0,0 +1,2 @@ +from .end_states import generate_endpoint_thermodynamic_states +from .topology_proposal import extract_htf_data \ No newline at end of file diff --git a/feflow/tests/utils/end_states.py b/feflow/tests/utils/end_states.py new file mode 100644 index 0000000..32102da --- /dev/null +++ b/feflow/tests/utils/end_states.py @@ -0,0 +1,98 @@ +import copy +from openmm import unit as omm_unit + + +def check_system(system): + """ + Check OpenMM System object for pathologies, like duplicate atoms in torsions. + + Parameters + ---------- + system : openmm.System + + """ + # from openmm import XmlSerializer + forces = {system.getForce(index).__class__.__name__: system.getForce(index) for index in + range(system.getNumForces())} + force = forces['PeriodicTorsionForce'] + for index in range(force.getNumTorsions()): + [i, j, k, l, _, _, _] = force.getTorsionParameters(index) + if len({i, j, k, l}) < 4: + msg = f'Torsion index {index} of self._topology_proposal.new_system has duplicate atoms: {i} {j} {k} {l}\n' + msg += 'Serialized system to system.xml for inspection.\n' + raise Exception(msg) + # IP: I don't think we need to serialize + # serialized_system = XmlSerializer.serialize(system) + # outfile = open('system.xml', 'w') + # outfile.write(serialized_system) + # outfile.close() + + +def generate_endpoint_thermodynamic_states(system, topology_proposal, repartitioned_endstate=None, + temperature=300.0 * omm_unit.kelvin): + """ + Generate endpoint thermodynamic states for the system + + Parameters + ---------- + system : openmm.System + System object corresponding to thermodynamic state + topology_proposal : perses.rjmc.topology_proposal.TopologyProposal + TopologyProposal representing transformation + repartitioned_endstate : int, default None + If the htf was generated using RepartitionedHybridTopologyFactory, use this argument to + specify the endstate at which it was generated. Otherwise, leave as None. + temperature : openmm.unit.Quantity, default 300 K + Temperature to set when generating the thermodynamic states + + Returns + ------- + nonalchemical_zero_thermodynamic_state : ThermodynamicState + Nonalchemical thermodynamic state for lambda zero endpoint + nonalchemical_one_thermodynamic_state : ThermodynamicState + Nonalchemical thermodynamic state for lambda one endpoint + lambda_zero_thermodynamic_state : ThermodynamicState + Alchemical (hybrid) thermodynamic state for lambda zero + lambda_one_thermodynamic_State : ThermodynamicState + Alchemical (hybrid) thermodynamic state for lambda one + """ + # Create the thermodynamic state + from feflow.utils.lambda_protocol import RelativeAlchemicalState + from openmmtools import states + + check_system(system) + + # Create thermodynamic states for the nonalchemical endpoints + nonalchemical_zero_thermodynamic_state = states.ThermodynamicState(topology_proposal.old_system, + temperature=temperature) + nonalchemical_one_thermodynamic_state = states.ThermodynamicState(topology_proposal.new_system, + temperature=temperature) + + # Create the base thermodynamic state with the hybrid system + thermodynamic_state = states.ThermodynamicState(system, temperature=temperature) + + if repartitioned_endstate == 0: + lambda_zero_thermodynamic_state = thermodynamic_state + lambda_one_thermodynamic_state = None + elif repartitioned_endstate == 1: + lambda_zero_thermodynamic_state = None + lambda_one_thermodynamic_state = thermodynamic_state + else: + # Create relative alchemical states + lambda_zero_alchemical_state = RelativeAlchemicalState.from_system(system) + lambda_one_alchemical_state = copy.deepcopy(lambda_zero_alchemical_state) + + # Ensure their states are set appropriately + lambda_zero_alchemical_state.set_alchemical_parameters(0.0) + lambda_one_alchemical_state.set_alchemical_parameters(1.0) + + # Now create the compound states with different alchemical states + lambda_zero_thermodynamic_state = ( + states.CompoundThermodynamicState(thermodynamic_state, + composable_states=[lambda_zero_alchemical_state])) + lambda_one_thermodynamic_state = ( + states.CompoundThermodynamicState(thermodynamic_state, + composable_states=[lambda_one_alchemical_state])) + + return (nonalchemical_zero_thermodynamic_state, nonalchemical_one_thermodynamic_state, + lambda_zero_thermodynamic_state, lambda_one_thermodynamic_state) diff --git a/feflow/tests/utils/topology_proposal.py b/feflow/tests/utils/topology_proposal.py new file mode 100644 index 0000000..cc35a3e --- /dev/null +++ b/feflow/tests/utils/topology_proposal.py @@ -0,0 +1,45 @@ +""" +Utility module for tests to process perses TopologyProposal objects to extract objects useful +for FEFlow/OpenFE +""" + +from perses.rjmc.topology_proposal import TopologyProposal + +def extract_htf_data(top_proposal: TopologyProposal): + """ + Extract OpenMM system and OpenMM topology data objects from a perses TopologyProposal object. + In order to be passed to the HybridTopologyFactory constructor. + + Parameters + ---------- + top_proposal: perses.rjmc.topology_proposal.TopologyProposal + Instance of TopologyProposal class from perses where to extract the data from. + + Returns + ------- + htf_data: dict + Dictionary with the data for the HybridTopologyFactory constructor. + Keys are "old/new_system", "old/new_topology". + """ + # Extract systems + old_system = top_proposal.old_system + new_system = top_proposal.new_system + # Extract coordinates + old_topology = top_proposal.old_topology + new_topology = top_proposal.new_topology + # Extract atom maps + old_to_new_atom_map = top_proposal.old_to_new_atom_map + # TODO: Check that core atoms are understood as the same in Perses. I'm not sure they are. + old_to_new_core_atom_map = {value: key for key, value in + top_proposal.core_new_to_old_atom_map.items()} + + htf_data = { + "old_system": old_system, + "new_system": new_system, + "old_topology": old_topology, + "new_topology": new_topology, + "old_to_new_atom_map": old_to_new_atom_map, + "old_to_new_core_atom_map": old_to_new_core_atom_map + } + + return htf_data \ No newline at end of file diff --git a/feflow/utils/hybrid_topology.py b/feflow/utils/hybrid_topology.py index ee73a22..83f7049 100644 --- a/feflow/utils/hybrid_topology.py +++ b/feflow/utils/hybrid_topology.py @@ -1,11 +1,80 @@ -from openfe.protocols.openmm_rfe._rfe_utils.relative import HybridTopologyFactory +# This code is a slightly modified version of the HybridTopologyFactory code +# from https://github.com/choderalab/perses +# The eventual goal is to move a version of this towards openmmtools +# LICENSE: MIT +import logging +import openmm +from openmm import unit, app +import numpy as np +import copy +import itertools +# OpenMM constant for Coulomb interactions (implicitly in md_unit_system units) +from openmmtools.constants import ONE_4PI_EPS0 +import mdtraj as mdt -# TODO: This is an utility class. To be deleted when we migrate/extend the base HybridTopologyFactory -class HybridTopologyFactoryModded(HybridTopologyFactory): +logger = logging.getLogger(__name__) + + +class HybridTopologyFactory: """ - Utility class that extends the base HybridTopologyFactory class with properties for - getting the indices from initial and final states. + This class generates a hybrid topology based on two input systems and an + atom mapping. For convenience the states are called "old" and "new" + respectively, defining the starting and end states along the alchemical + transformation. + + The input systems are assumed to have: + 1. The total number of molecules + 2. The same coordinates for equivalent atoms + + Atoms in the resulting hybrid system are treated as being from one + of four possible types: + + unique_old_atom : These atoms are not mapped and only present in the old + system. Their interactions will be on for lambda=0, off for lambda=1 + unique_new_atom : These atoms are not mapped and only present in the new + system. Their interactions will be off for lambda=0, on for lambda=1 + core_atom : These atoms are mapped between the two end states, and are + part of a residue that is changing alchemically. Their interactions + will be those corresponding to the old system at lambda=0, and those + corresponding to the new system at lambda=1 + environment_atom : These atoms are mapped between the two end states, and + are not part of a residue undergoing an alchemical change. Their + interactions are always on and are alchemically unmodified. + + Properties + ---------- + hybrid_system : openmm.System + The hybrid system for simulation + new_to_hybrid_atom_map : dict of int : int + The mapping of new system atoms to hybrid atoms + old_to_hybrid_atom_map : dict of int : int + The mapping of old system atoms to hybrid atoms + hybrid_positions : [n, 3] np.ndarray + The positions of the hybrid system + hybrid_topology : mdtraj.Topology + The topology of the hybrid system + omm_hybrid_topology : openmm.app.Topology + The OpenMM topology object corresponding to the hybrid system + + .. warning :: This API is experimental and subject to change. + + Notes + ----- + * Logging has been removed and will be revamped at a later date. + * The ability to define custom functions has been removed for now. + * Neglected angle terms have been removed for now. + * RMSD restraint option has been removed for now. + * Endstate support has been removed for now. + * Bond softening has been removed for now. + * Unused InteractionGroup code paths have been removed. + + TODO + ---- + * Document how positions for hybrid system are constructed. + * Allow support for annealing in omitted terms. + * Implement omitted terms (this was not available in the original class). + """ def __init__(self, @@ -17,7 +86,6 @@ def __init__(self, softcore_LJ_v2=True, softcore_LJ_v2_alpha=0.85, interpolate_old_and_new_14s=False, - flatten_torsions=False, **kwargs): """ Initialize the Hybrid topology factory. @@ -56,23 +124,2437 @@ def __init__(self, Whether to turn off interactions for new exceptions (not just 1,4s) at lambda = 0 and old exceptions at lambda = 1; if False, they are present in the nonbonded force. - flatten_torsions : bool, default False - If True, torsion terms involving `unique_new_atoms` will be - scaled such that at lambda=0,1, the torsion term is turned off/on - respectively. The opposite is true for `unique_old_atoms`. - """ - super().__init__(old_system, old_positions, old_topology, - new_system, new_positions, new_topology, - old_to_new_atom_map, old_to_new_core_atom_map, - use_dispersion_correction=use_dispersion_correction, - softcore_alpha=softcore_alpha, - softcore_LJ_v2=softcore_LJ_v2, - softcore_LJ_v2_alpha=softcore_LJ_v2_alpha, - interpolate_old_and_new_14s=interpolate_old_and_new_14s, - flatten_torsions=flatten_torsions, - **kwargs) - - # TODO: We need to refactor for the init to use these properties and have an attribute with the indices + """ + + # Assign system positions and force + # IA - Are deep copies really needed here? + self._old_system = copy.deepcopy(old_system) + self._old_positions = old_positions + self._old_topology = old_topology + self._new_system = copy.deepcopy(new_system) + self._new_positions = new_positions + self._new_topology = new_topology + self._hybrid_system_forces = dict() + + # Set mappings (full, core, and env maps) + self._set_mappings(old_to_new_atom_map, old_to_new_core_atom_map) + + # Other options + self._use_dispersion_correction = use_dispersion_correction + self._interpolate_14s = interpolate_old_and_new_14s + + # Sofcore options + self._softcore_alpha = softcore_alpha + self._check_bounds(softcore_alpha, "softcore_alpha") # [0,1] check + + self._softcore_LJ_v2 = softcore_LJ_v2 + if self._softcore_LJ_v2: + self._check_bounds(softcore_LJ_v2_alpha, "softcore_LJ_v2_alpha") + self._softcore_LJ_v2_alpha = softcore_LJ_v2_alpha + + # TODO: end __init__ here and move everything else to + # create_hybrid_system() or equivalent + + self._check_and_store_system_forces() + + logger.info("Creating hybrid system") + # Create empty system that will become the hybrid system + self._hybrid_system = openmm.System() + + # Add particles to system + self._add_particles() + + # Add box + barostat + self._handle_box() + + # Assign atoms to one of the classes described in the class docstring + # Renamed from original _determine_atom_classes + self._set_atom_classes() + + # Construct dictionary of exceptions in old and new systems + self._old_system_exceptions = self._generate_dict_from_exceptions( + self._old_system_forces['NonbondedForce']) + self._new_system_exceptions = self._generate_dict_from_exceptions( + self._new_system_forces['NonbondedForce']) + + # check for exceptions clashes between unique and env atoms + self._validate_disjoint_sets() + + logger.info("Setting force field terms") + # Copy constraints, checking to make sure they are not changing + self._handle_constraints() + + # Copy over relevant virtual sites - pick up refactor from here + self._handle_virtual_sites() + + # TODO - move to a single method call? Would be good to group these + # Call each of the force methods to add the corresponding force terms + # and prepare the forces: + self._add_bond_force_terms() + + self._add_angle_force_terms() + + self._add_torsion_force_terms() + + has_nonbonded_force = ('NonbondedForce' in self._old_system_forces or + 'NonbondedForce' in self._new_system_forces) + + if has_nonbonded_force: + self._add_nonbonded_force_terms() + + # Call each force preparation method to generate the actual + # interactions that we need: + logger.info("Adding forces") + self._handle_harmonic_bonds() + + self._handle_harmonic_angles() + + self._handle_periodic_torsion_force() + + if has_nonbonded_force: + self._handle_nonbonded() + if not (len(self._old_system_exceptions.keys()) == 0 and + len(self._new_system_exceptions.keys()) == 0): + self._handle_old_new_exceptions() + + # Get positions for the hybrid + self._hybrid_positions = self._compute_hybrid_positions() + + # Get an MDTraj topology for writing + self._hybrid_topology = self._create_mdtraj_topology() + self._omm_hybrid_topology = self._create_hybrid_topology() + logger.info("Hybrid system created") + + @staticmethod + def _check_bounds(value, varname, minmax=(0, 1)): + """ + Convenience method to check the bounds of a value. + + Parameters + ---------- + value : float + Value to evaluate. + varname : str + Name of value to raise in error message + minmax : tuple + Two element tuple with the lower and upper bounds to check. + + Raises + ------ + AssertionError + If value is lower or greater than bounds. + """ + if value < minmax[0] or value > minmax[1]: + raise AssertionError(f"{varname} is not in {minmax}") + + @staticmethod + def _invert_dict(dictionary): + """ + Convenience method to invert a dictionary (since we do it so often). + + Paramters: + ---------- + dictionary : dict + Dictionary you want to invert + """ + return {v: k for k, v in dictionary.items()} + + def _set_mappings(self, old_to_new_map, core_old_to_new_map): + """ + Parameters + ---------- + old_to_new_map : dict of int : int + Dictionary mapping atoms between the old and new systems. + + Notes + ----- + * For now this directly sets the system, core and env old_to_new_map, + new_to_old_map, an empty new_to_hybrid_map and an empty + old_to_hybrid_map. In the future this will be moved to the one + dictionary to make things a lot less confusing. + """ + self._old_to_new_map = old_to_new_map + self._core_old_to_new_map = core_old_to_new_map + self._new_to_old_map = self._invert_dict(old_to_new_map) + self._core_new_to_old_map = self._invert_dict(core_old_to_new_map) + self._old_to_hybrid_map = {} + self._new_to_hybrid_map = {} + + # Get unique atoms + # old system first + self._unique_old_atoms = [] + for particle_idx in range(self._old_system.getNumParticles()): + if particle_idx not in self._old_to_new_map.keys(): + self._unique_old_atoms.append(particle_idx) + + self._unique_new_atoms = [] + for particle_idx in range(self._new_system.getNumParticles()): + if particle_idx not in self._new_to_old_map.keys(): + self._unique_new_atoms.append(particle_idx) + + # Get env atoms (i.e. atoms mapped not in core) + self._env_old_to_new_map = {} + for key, value in old_to_new_map.items(): + if key not in self._core_old_to_new_map.keys(): + self._env_old_to_new_map[key] = value + + self._env_new_to_old_map = self._invert_dict(self._env_old_to_new_map) + + # IA - Internal check for now (move to test later) + num_env = len(self._env_old_to_new_map.keys()) + num_core = len(self._core_old_to_new_map.keys()) + num_total = len(self._old_to_new_map.keys()) + assert num_env + num_core == num_total + + def _check_and_store_system_forces(self): + """ + Conveniently stores the system forces and checks that no unknown + forces exist. + """ + + def _check_unknown_forces(forces, system_name): + # TODO: double check that CMMotionRemover is ok being here + known_forces = {'HarmonicBondForce', 'HarmonicAngleForce', + 'PeriodicTorsionForce', 'NonbondedForce', + 'MonteCarloBarostat', 'CMMotionRemover'} + + force_names = forces.keys() + unknown_forces = set(force_names) - set(known_forces) + if unknown_forces: + errmsg = (f"Unknown forces {unknown_forces} encountered in " + f"{system_name} system") + raise ValueError(errmsg) + + # Prepare dicts of forces, which will be useful later + # TODO: Store this as self._system_forces[name], name in ('old', + # 'new', 'hybrid') for compactness + self._old_system_forces = {type(force).__name__: force for force in + self._old_system.getForces()} + _check_unknown_forces(self._old_system_forces, 'old') + self._new_system_forces = {type(force).__name__: force for force in + self._new_system.getForces()} + _check_unknown_forces(self._new_system_forces, 'new') + + # TODO: check if this is actually used much, otherwise ditch it + # Get and store the nonbonded method from the system: + self._nonbonded_method = self._old_system_forces['NonbondedForce'].getNonbondedMethod() + + def _add_particles(self): + """ + Adds particles to the hybrid system. + + This does not copy over interactions, but does copy over the masses. + + Note + ---- + * If there is a difference in masses between the old and new systems + the average mass of the two is used. + + TODO + ---- + * Review influence of lack of mass scaling. + """ + # Begin by copying all particles in the old system + for particle_idx in range(self._old_system.getNumParticles()): + mass_old = self._old_system.getParticleMass(particle_idx) + + if particle_idx in self._old_to_new_map.keys(): + particle_idx_new_system = self._old_to_new_map[particle_idx] + mass_new = self._new_system.getParticleMass( + particle_idx_new_system) + # Take the average of the masses if the atom is mapped + particle_mass = (mass_old + mass_new) / 2 + else: + particle_mass = mass_old + + hybrid_idx = self._hybrid_system.addParticle(particle_mass) + self._old_to_hybrid_map[particle_idx] = hybrid_idx + + # If the particle index in question is mapped, make sure to add it + # to the new to hybrid map as well. + if particle_idx in self._old_to_new_map.keys(): + self._new_to_hybrid_map[particle_idx_new_system] = hybrid_idx + + # Next, add the remaining unique atoms from the new system to the + # hybrid system and map accordingly. + for particle_idx in self._unique_new_atoms: + particle_mass = self._new_system.getParticleMass(particle_idx) + hybrid_idx = self._hybrid_system.addParticle(particle_mass) + self._new_to_hybrid_map[particle_idx] = hybrid_idx + + # Create the opposite atom maps for later use (nonbonded processing) + self._hybrid_to_old_map = self._invert_dict(self._old_to_hybrid_map) + self._hybrid_to_new_map = self._invert_dict(self._new_to_hybrid_map) + + def _handle_box(self): + """ + Copies over the barostat and box vectors as necessary. + """ + # Check that if there is a barostat in the old system, + # it is added to the hybrid system + if "MonteCarloBarostat" in self._old_system_forces.keys(): + barostat = copy.deepcopy( + self._old_system_forces["MonteCarloBarostat"]) + self._hybrid_system.addForce(barostat) + + # Copy over the box vectors from the old system + box_vectors = self._old_system.getDefaultPeriodicBoxVectors() + self._hybrid_system.setDefaultPeriodicBoxVectors(*box_vectors) + + def _set_atom_classes(self): + """ + This method determines whether each atom belongs to unique old, + unique new, core, or environment, as defined in the class docstring. + All indices are indices in the hybrid system. + """ + self._atom_classes = {'unique_old_atoms': set(), + 'unique_new_atoms': set(), + 'core_atoms': set(), + 'environment_atoms': set()} + + # First, find the unique old atoms + for atom_idx in self._unique_old_atoms: + hybrid_idx = self._old_to_hybrid_map[atom_idx] + self._atom_classes['unique_old_atoms'].add(hybrid_idx) + + # Then the unique new atoms + for atom_idx in self._unique_new_atoms: + hybrid_idx = self._new_to_hybrid_map[atom_idx] + self._atom_classes['unique_new_atoms'].add(hybrid_idx) + + # The core atoms: + for new_idx, old_idx in self._core_new_to_old_map.items(): + new_to_hybrid_idx = self._new_to_hybrid_map[new_idx] + old_to_hybrid_idx = self._old_to_hybrid_map[old_idx] + if new_to_hybrid_idx != old_to_hybrid_idx: + errmsg = (f"there is an index collision in hybrid indices of " + f"the core atom map: {self._core_new_to_old_map}") + raise AssertionError(errmsg) + self._atom_classes['core_atoms'].add(new_to_hybrid_idx) + + # The environment atoms: + for new_idx, old_idx in self._env_new_to_old_map.items(): + new_to_hybrid_idx = self._new_to_hybrid_map[new_idx] + old_to_hybrid_idx = self._old_to_hybrid_map[old_idx] + if new_to_hybrid_idx != old_to_hybrid_idx: + errmsg = (f"there is an index collion in hybrid indices of " + f"the environment atom map: " + f"{self._env_new_to_old_map}") + raise AssertionError(errmsg) + self._atom_classes['environment_atoms'].add(new_to_hybrid_idx) + + @staticmethod + def _generate_dict_from_exceptions(force): + """ + This is a utility function to generate a dictionary of the form + (particle1_idx, particle2_idx) : [exception parameters]. + This will facilitate access and search of exceptions. + + Parameters + ---------- + force : openmm.NonbondedForce object + a force containing exceptions + + Returns + ------- + exceptions_dict : dict + Dictionary of exceptions + """ + exceptions_dict = {} + + for exception_index in range(force.getNumExceptions()): + [index1, index2, chargeProd, sigma, epsilon] = force.getExceptionParameters(exception_index) + exceptions_dict[(index1, index2)] = [chargeProd, sigma, epsilon] + + return exceptions_dict + + def _validate_disjoint_sets(self): + """ + Conduct a sanity check to make sure that the hybrid maps of the old + and new system exception dict keys do not contain both environment + and unique_old/new atoms. + + TODO: repeated code - condense + """ + for old_indices in self._old_system_exceptions.keys(): + hybrid_indices = (self._old_to_hybrid_map[old_indices[0]], + self._old_to_hybrid_map[old_indices[1]]) + old_env_intersection = set(old_indices).intersection( + self._atom_classes['environment_atoms']) + if old_env_intersection: + if set(old_indices).intersection( + self._atom_classes['unique_old_atoms'] + ): + errmsg = (f"old index exceptions {old_indices} include " + "unique old and environment atoms, which is " + "disallowed") + raise AssertionError(errmsg) + + for new_indices in self._new_system_exceptions.keys(): + hybrid_indices = (self._new_to_hybrid_map[new_indices[0]], + self._new_to_hybrid_map[new_indices[1]]) + new_env_intersection = set(hybrid_indices).intersection( + self._atom_classes['environment_atoms']) + if new_env_intersection: + if set(hybrid_indices).intersection( + self._atom_classes['unique_new_atoms'] + ): + errmsg = (f"new index exceptions {new_indices} include " + "unique new and environment atoms, which is " + "dissallowed") + raise AssertionError + + def _handle_constraints(self): + """ + This method adds relevant constraints from the old and new systems. + + First, all constraints from the old systenm are added. + Then, constraints to atoms unique to the new system are added. + + TODO: condense duplicated code + """ + # lengths of constraints already added + constraint_lengths = dict() + + # old system + hybrid_map = self._old_to_hybrid_map + for const_idx in range(self._old_system.getNumConstraints()): + at1, at2, length = self._old_system.getConstraintParameters( + const_idx) + hybrid_atoms = tuple(sorted([hybrid_map[at1], hybrid_map[at2]])) + if hybrid_atoms not in constraint_lengths.keys(): + self._hybrid_system.addConstraint(hybrid_atoms[0], + hybrid_atoms[1], length) + constraint_lengths[hybrid_atoms] = length + else: + + if constraint_lengths[hybrid_atoms] != length: + raise AssertionError('constraint length is changing') + + # new system + hybrid_map = self._new_to_hybrid_map + for const_idx in range(self._new_system.getNumConstraints()): + at1, at2, length = self._new_system.getConstraintParameters( + const_idx) + hybrid_atoms = tuple(sorted([hybrid_map[at1], hybrid_map[at2]])) + if hybrid_atoms not in constraint_lengths.keys(): + self._hybrid_system.addConstraint(hybrid_atoms[0], + hybrid_atoms[1], length) + constraint_lengths[hybrid_atoms] = length + else: + if constraint_lengths[hybrid_atoms] != length: + raise AssertionError('constraint length is changing') + + @staticmethod + def _copy_threeparticleavg(atm_map, env_atoms, vs): + """ + Helper method to copy a ThreeParticleAverageSite virtual site + from two mapped Systems. + + Parameters + ---------- + atm_map : dict[int, int] + The atom map correspondance between the two Systems. + env_atoms: set[int] + A list of environment atoms for the target System. This + checks that no alchemical atoms are being tied to. + vs : openmm.ThreeParticleAverageSite + + Returns + ------- + openmm.ThreeParticleAverageSite + """ + particles = {} + weights = {} + for i in range(vs.getNumParticles()): + particles[i] = atm_map[vs.getParticle(i)] + weights[i] = vs.getWeight(i) + if not all(i in env_atoms for i in particles.values()): + errmsg = ("Virtual sites bound to non-environment atoms " + "are not supported") + raise ValueError(errmsg) + return openmm.ThreeParticleAverageSite( + particles[0], particles[1], particles[2], + weights[0], weights[1], weights[2], + ) + + def _handle_virtual_sites(self): + """ + Ensure that all virtual sites in old and new system are copied over to + the hybrid system. Note that we do not support virtual sites in the + changing region. + + TODO - remerge into a single loop + TODO - check that it's fine to double count here (even so, there's + an optimisation that could be done here...) + """ + # old system + # Loop through virtual sites + for particle_idx in range(self._old_system.getNumParticles()): + if self._old_system.isVirtualSite(particle_idx): + # If it's a virtual site, make sure it is not in the unique or + # core atoms, since this is currently unsupported + hybrid_idx = self._old_to_hybrid_map[particle_idx] + if hybrid_idx not in self._atom_classes['environment_atoms']: + errmsg = ("Virtual sites in changing residue are " + "unsupported.") + raise ValueError(errmsg) + else: + virtual_site = self._old_system.getVirtualSite( + particle_idx) + if isinstance( + virtual_site, openmm.ThreeParticleAverageSite): + vs_copy = self._copy_threeparticleavg( + self._old_to_hybrid_map, + self._atom_classes['environment_atoms'], + virtual_site, + ) + else: + errmsg = ("Unsupported VirtualSite " + f"class: {virtual_site}") + raise ValueError(errmsg) + + self._hybrid_system.setVirtualSite(hybrid_idx, + vs_copy) + + # new system - there should be nothing left to add + # Loop through virtual sites + for particle_idx in range(self._new_system.getNumParticles()): + if self._new_system.isVirtualSite(particle_idx): + # If it's a virtual site, make sure it is not in the unique or + # core atoms, since this is currently unsupported + hybrid_idx = self._new_to_hybrid_map[particle_idx] + if hybrid_idx not in self._atom_classes['environment_atoms']: + errmsg = ("Virtual sites in changing residue are " + "unsupported.") + raise ValueError(errmsg) + else: + if not self._hybrid_system.isVirtualSite(hybrid_idx): + errmsg = ("Environment virtual site in new system " + "found not copied from old system") + raise ValueError(errmsg) + + def _add_bond_force_terms(self): + """ + This function adds the appropriate bond forces to the system + (according to groups defined in the main class docstring). Note that + it does _not_ add the particles to the force. It only adds the force + to facilitate another method adding the particles to the force. + + Notes + ----- + * User defined functions have been removed for now. + """ + core_energy_expression = '(K/2)*(r-length)^2;' + # linearly interpolate spring constant + core_energy_expression += 'K = (1-lambda_bonds)*K1 + lambda_bonds*K2;' + # linearly interpolate bond length + core_energy_expression += 'length = (1-lambda_bonds)*length1 + lambda_bonds*length2;' + + # Create the force and add the relevant parameters + custom_core_force = openmm.CustomBondForce(core_energy_expression) + custom_core_force.addPerBondParameter('length1') # old bond length + custom_core_force.addPerBondParameter('K1') # old spring constant + custom_core_force.addPerBondParameter('length2') # new bond length + custom_core_force.addPerBondParameter('K2') # new spring constant + + custom_core_force.addGlobalParameter('lambda_bonds', 0.0) + + self._hybrid_system.addForce(custom_core_force) + self._hybrid_system_forces['core_bond_force'] = custom_core_force + + # Add a bond force for environment and unique atoms (bonds are never + # scaled for these): + standard_bond_force = openmm.HarmonicBondForce() + self._hybrid_system.addForce(standard_bond_force) + self._hybrid_system_forces['standard_bond_force'] = standard_bond_force + + def _add_angle_force_terms(self): + """ + This function adds the appropriate angle force terms to the hybrid + system. It does not add particles or parameters to the force; this is + done elsewhere. + + Notes + ----- + * User defined functions have been removed for now. + * Neglected angle terms have been removed for now. + """ + energy_expression = '(K/2)*(theta-theta0)^2;' + # linearly interpolate spring constant + energy_expression += 'K = (1.0-lambda_angles)*K_1 + lambda_angles*K_2;' + # linearly interpolate equilibrium angle + energy_expression += 'theta0 = (1.0-lambda_angles)*theta0_1 + lambda_angles*theta0_2;' + + # Create the force and add relevant parameters + custom_core_force = openmm.CustomAngleForce(energy_expression) + # molecule1 equilibrium angle + custom_core_force.addPerAngleParameter('theta0_1') + # molecule1 spring constant + custom_core_force.addPerAngleParameter('K_1') + # molecule2 equilibrium angle + custom_core_force.addPerAngleParameter('theta0_2') + # molecule2 spring constant + custom_core_force.addPerAngleParameter('K_2') + + custom_core_force.addGlobalParameter('lambda_angles', 0.0) + + # Add the force to the system and the force dict. + self._hybrid_system.addForce(custom_core_force) + self._hybrid_system_forces['core_angle_force'] = custom_core_force + + # Add an angle term for environment/unique interactions -- these are + # never scaled + standard_angle_force = openmm.HarmonicAngleForce() + self._hybrid_system.addForce(standard_angle_force) + self._hybrid_system_forces['standard_angle_force'] = standard_angle_force + + def _add_torsion_force_terms(self): + """ + This function adds the appropriate PeriodicTorsionForce terms to the + system. Core torsions are interpolated, while environment and unique + torsions are always on. + + Notes + ----- + * User defined functions have been removed for now. + * Options for add_custom_core_force (default True) and + add_unique_atom_torsion_force (default True) have been removed for + now. + """ + energy_expression = '(1-lambda_torsions)*U1 + lambda_torsions*U2;' + energy_expression += 'U1 = K1*(1+cos(periodicity1*theta-phase1));' + energy_expression += 'U2 = K2*(1+cos(periodicity2*theta-phase2));' + + # Create the force and add the relevant parameters + custom_core_force = openmm.CustomTorsionForce(energy_expression) + # molecule1 periodicity + custom_core_force.addPerTorsionParameter('periodicity1') + # molecule1 phase + custom_core_force.addPerTorsionParameter('phase1') + # molecule1 spring constant + custom_core_force.addPerTorsionParameter('K1') + # molecule2 periodicity + custom_core_force.addPerTorsionParameter('periodicity2') + # molecule2 phase + custom_core_force.addPerTorsionParameter('phase2') + # molecule2 spring constant + custom_core_force.addPerTorsionParameter('K2') + + custom_core_force.addGlobalParameter('lambda_torsions', 0.0) + + # Add the force to the system + self._hybrid_system.addForce(custom_core_force) + self._hybrid_system_forces['custom_torsion_force'] = custom_core_force + + # Create and add the torsion term for unique/environment atoms + unique_atom_torsion_force = openmm.PeriodicTorsionForce() + self._hybrid_system.addForce(unique_atom_torsion_force) + self._hybrid_system_forces['unique_atom_torsion_force'] = unique_atom_torsion_force + + @staticmethod + def _nonbonded_custom(v2): + """ + Get a part of the nonbonded energy expression when there is no cutoff. + + Parameters + ---------- + v2 : bool + Whether to use the softcore methods as defined by Gapsys et al. + JCTC 2012. + + Returns + ------- + sterics_energy_expression : str + The energy expression for U_sterics + electrostatics_energy_expression : str + The energy expression for electrostatics + + TODO + ---- + * Move to a dictionary or equivalent. + """ + # Soft-core Lennard-Jones + if v2: + sterics_energy_expression = "U_sterics = select(step(r - r_LJ), 4*epsilon*x*(x-1.0), U_sterics_quad);" + sterics_energy_expression += "U_sterics_quad = Force*(((r - r_LJ)^2)/2 - (r - r_LJ)) + U_sterics_cut;" + sterics_energy_expression += "U_sterics_cut = 4*epsilon*((sigma/r_LJ)^6)*(((sigma/r_LJ)^6) - 1.0);" + sterics_energy_expression += "Force = -4*epsilon*((-12*sigma^12)/(r_LJ^13) + (6*sigma^6)/(r_LJ^7));" + sterics_energy_expression += "x = (sigma/r)^6;" + sterics_energy_expression += "r_LJ = softcore_alpha*((26/7)*(sigma^6)*lambda_sterics_deprecated)^(1/6);" + sterics_energy_expression += "lambda_sterics_deprecated = new_interaction*(1.0 - lambda_sterics_insert) + old_interaction*lambda_sterics_delete;" + else: + sterics_energy_expression = "U_sterics = 4*epsilon*x*(x-1.0); x = (sigma/reff_sterics)^6;" + + return sterics_energy_expression + + @staticmethod + def _nonbonded_custom_sterics_common(): + """ + Get a custom sterics expression using amber softcore expression + + Returns + ------- + sterics_addition : str + The common softcore sterics energy expression + + TODO + ---- + * Move to a dictionary or equivalent. + """ + # interpolation + sterics_addition = "epsilon = (1-lambda_sterics)*epsilonA + lambda_sterics*epsilonB;" + # effective softcore distance for sterics + sterics_addition += "reff_sterics = sigma*((softcore_alpha*lambda_alpha + (r/sigma)^6))^(1/6);" + sterics_addition += "sigma = (1-lambda_sterics)*sigmaA + lambda_sterics*sigmaB;" + + sterics_addition += "lambda_alpha = new_interaction*(1-lambda_sterics_insert) + old_interaction*lambda_sterics_delete;" + sterics_addition += "lambda_sterics = core_interaction*lambda_sterics_core + new_interaction*lambda_sterics_insert + old_interaction*lambda_sterics_delete;" + sterics_addition += "core_interaction = delta(unique_old1+unique_old2+unique_new1+unique_new2);new_interaction = max(unique_new1, unique_new2);old_interaction = max(unique_old1, unique_old2);" + + return sterics_addition + + @staticmethod + def _nonbonded_custom_mixing_rules(): + """ + Mixing rules for the custom nonbonded force. + + Returns + ------- + sterics_mixing_rules : str + The mixing expression for sterics + electrostatics_mixing_rules : str + The mixiing rules for electrostatics + + TODO + ---- + * Move to a dictionary or equivalent. + """ + # Define mixing rules. + # mixing rule for epsilon + sterics_mixing_rules = "epsilonA = sqrt(epsilonA1*epsilonA2);" + # mixing rule for epsilon + sterics_mixing_rules += "epsilonB = sqrt(epsilonB1*epsilonB2);" + # mixing rule for sigma + sterics_mixing_rules += "sigmaA = 0.5*(sigmaA1 + sigmaA2);" + # mixing rule for sigma + sterics_mixing_rules += "sigmaB = 0.5*(sigmaB1 + sigmaB2);" + return sterics_mixing_rules + + @staticmethod + def _translate_nonbonded_method_to_custom(standard_nonbonded_method): + """ + Utility function to translate the nonbonded method enum from the + standard nonbonded force to the custom version + `CutoffPeriodic`, `PME`, and `Ewald` all become `CutoffPeriodic`; + `NoCutoff` becomes `NoCutoff`; `CutoffNonPeriodic` becomes + `CutoffNonPeriodic` + + Parameters + ---------- + standard_nonbonded_method : openmm.NonbondedForce.NonbondedMethod + the nonbonded method of the standard force + + Returns + ------- + custom_nonbonded_method : openmm.CustomNonbondedForce.NonbondedMethod + the nonbonded method for the equivalent customnonbonded force + """ + if standard_nonbonded_method in [openmm.NonbondedForce.CutoffPeriodic, + openmm.NonbondedForce.PME, + openmm.NonbondedForce.Ewald]: + return openmm.CustomNonbondedForce.CutoffPeriodic + elif standard_nonbonded_method == openmm.NonbondedForce.NoCutoff: + return openmm.CustomNonbondedForce.NoCutoff + elif standard_nonbonded_method == openmm.NonbondedForce.CutoffNonPeriodic: + return openmm.CustomNonbondedForce.CutoffNonPeriodic + else: + errmsg = "This nonbonded method is not supported." + raise NotImplementedError(errmsg) + + def _add_nonbonded_force_terms(self): + """ + Add the nonbonded force terms to the hybrid system. Note that as with + the other forces, this method does not add any interactions. It only + sets up the forces. + + Notes + ----- + * User defined functions have been removed for now. + * Argument `add_custom_sterics_force` (default True) has been removed + for now. + + TODO + ---- + * Move nonbonded_method defn here to avoid just setting it globally + and polluting `self`. + """ + # Add a regular nonbonded force for all interactions that are not + # changing. + standard_nonbonded_force = openmm.NonbondedForce() + self._hybrid_system.addForce(standard_nonbonded_force) + self._hybrid_system_forces['standard_nonbonded_force'] = standard_nonbonded_force + + # Create a CustomNonbondedForce to handle alchemically interpolated + # nonbonded parameters. + # Select functional form based on nonbonded method. + # TODO: check _nonbonded_custom_ewald and _nonbonded_custom_cutoff + # since they take arguments that are never used... + r_cutoff = self._old_system_forces['NonbondedForce'].getCutoffDistance() + sterics_energy_expression = self._nonbonded_custom(self._softcore_LJ_v2) + if self._nonbonded_method in [openmm.NonbondedForce.NoCutoff]: + pass + elif self._nonbonded_method in [openmm.NonbondedForce.CutoffPeriodic, + openmm.NonbondedForce.CutoffNonPeriodic]: + epsilon_solvent = self._old_system_forces['NonbondedForce'].getReactionFieldDielectric() + standard_nonbonded_force.setReactionFieldDielectric( + epsilon_solvent) + standard_nonbonded_force.setCutoffDistance(r_cutoff) + elif self._nonbonded_method in [openmm.NonbondedForce.PME, + openmm.NonbondedForce.Ewald]: + [alpha_ewald, nx, ny, nz] = self._old_system_forces['NonbondedForce'].getPMEParameters() + delta = self._old_system_forces['NonbondedForce'].getEwaldErrorTolerance() + standard_nonbonded_force.setPMEParameters(alpha_ewald, nx, ny, nz) + standard_nonbonded_force.setEwaldErrorTolerance(delta) + standard_nonbonded_force.setCutoffDistance(r_cutoff) + else: + errmsg = f"Nonbonded method {self._nonbonded_method} not supported" + raise ValueError(errmsg) + + standard_nonbonded_force.setNonbondedMethod(self._nonbonded_method) + + sterics_energy_expression += self._nonbonded_custom_sterics_common() + + sterics_mixing_rules = self._nonbonded_custom_mixing_rules() + + custom_nonbonded_method = self._translate_nonbonded_method_to_custom( + self._nonbonded_method) + + total_sterics_energy = "U_sterics;" + sterics_energy_expression + sterics_mixing_rules + + sterics_custom_nonbonded_force = openmm.CustomNonbondedForce( + total_sterics_energy) + + # Match cutoff from non-custom NB forces + sterics_custom_nonbonded_force.setCutoffDistance(r_cutoff) + + if self._softcore_LJ_v2: + sterics_custom_nonbonded_force.addGlobalParameter( + "softcore_alpha", self._softcore_LJ_v2_alpha) + else: + sterics_custom_nonbonded_force.addGlobalParameter( + "softcore_alpha", self._softcore_alpha) + + # Lennard-Jones sigma initial + sterics_custom_nonbonded_force.addPerParticleParameter("sigmaA") + # Lennard-Jones epsilon initial + sterics_custom_nonbonded_force.addPerParticleParameter("epsilonA") + # Lennard-Jones sigma final + sterics_custom_nonbonded_force.addPerParticleParameter("sigmaB") + # Lennard-Jones epsilon final + sterics_custom_nonbonded_force.addPerParticleParameter("epsilonB") + # 1 = hybrid old atom, 0 otherwise + sterics_custom_nonbonded_force.addPerParticleParameter("unique_old") + # 1 = hybrid new atom, 0 otherwise + sterics_custom_nonbonded_force.addPerParticleParameter("unique_new") + + sterics_custom_nonbonded_force.addGlobalParameter( + "lambda_sterics_core", 0.0) + sterics_custom_nonbonded_force.addGlobalParameter( + "lambda_electrostatics_core", 0.0) + sterics_custom_nonbonded_force.addGlobalParameter( + "lambda_sterics_insert", 0.0) + sterics_custom_nonbonded_force.addGlobalParameter( + "lambda_sterics_delete", 0.0) + + sterics_custom_nonbonded_force.setNonbondedMethod( + custom_nonbonded_method) + + self._hybrid_system.addForce(sterics_custom_nonbonded_force) + self._hybrid_system_forces['core_sterics_force'] = sterics_custom_nonbonded_force + + # Set the use of dispersion correction to be the same between the new + # nonbonded force and the old one: + if self._old_system_forces['NonbondedForce'].getUseDispersionCorrection(): + self._hybrid_system_forces['standard_nonbonded_force'].setUseDispersionCorrection(True) + if self._use_dispersion_correction: + sterics_custom_nonbonded_force.setUseLongRangeCorrection(True) + else: + self._hybrid_system_forces['standard_nonbonded_force'].setUseDispersionCorrection(False) + + if self._old_system_forces['NonbondedForce'].getUseSwitchingFunction(): + switching_distance = self._old_system_forces['NonbondedForce'].getSwitchingDistance() + standard_nonbonded_force.setUseSwitchingFunction(True) + standard_nonbonded_force.setSwitchingDistance(switching_distance) + sterics_custom_nonbonded_force.setUseSwitchingFunction(True) + sterics_custom_nonbonded_force.setSwitchingDistance(switching_distance) + else: + standard_nonbonded_force.setUseSwitchingFunction(False) + sterics_custom_nonbonded_force.setUseSwitchingFunction(False) + + @staticmethod + def _find_bond_parameters(bond_force, index1, index2): + """ + This is a convenience function to find bond parameters in another + system given the two indices. + + Parameters + ---------- + bond_force : openmm.HarmonicBondForce + The bond force where the parameters should be found + index1 : int + Index1 (order does not matter) of the bond atoms + index2 : int + Index2 (order does not matter) of the bond atoms + + Returns + ------- + bond_parameters : list + List of relevant bond parameters + """ + index_set = {index1, index2} + # Loop through all the bonds: + for bond_index in range(bond_force.getNumBonds()): + parms = bond_force.getBondParameters(bond_index) + if index_set == {parms[0], parms[1]}: + return parms + + return [] + + def _handle_harmonic_bonds(self): + """ + This method adds the appropriate interaction for all bonds in the + hybrid system. The scheme used is: + + 1) If the two atoms are both in the core, then we add to the + CustomBondForce and interpolate between the two parameters + 2) If one of the atoms is in core and the other is environment, we + have to assert that the bond parameters do not change between the + old and the new system; then, the parameters are added to the + regular bond force + 3) Otherwise, we add the bond to a regular bond force. + + Notes + ----- + * Bond softening logic has been removed for now. + """ + old_system_bond_force = self._old_system_forces['HarmonicBondForce'] + new_system_bond_force = self._new_system_forces['HarmonicBondForce'] + + # First, loop through the old system bond forces and add relevant terms + for bond_index in range(old_system_bond_force.getNumBonds()): + # Get each set of bond parameters + [index1_old, index2_old, r0_old, k_old] = old_system_bond_force.getBondParameters(bond_index) + + # Map the indices to the hybrid system, for which our atom classes + # are defined. + index1_hybrid = self._old_to_hybrid_map[index1_old] + index2_hybrid = self._old_to_hybrid_map[index2_old] + index_set = {index1_hybrid, index2_hybrid} + + # Now check if it is a subset of the core atoms (that is, both + # atoms are in the core) + # If it is, we need to find the parameters in the old system so + # that we can interpolate + if index_set.issubset(self._atom_classes['core_atoms']): + index1_new = self._old_to_new_map[index1_old] + index2_new = self._old_to_new_map[index2_old] + new_bond_parameters = self._find_bond_parameters( + new_system_bond_force, index1_new, index2_new) + if not new_bond_parameters: + r0_new = r0_old + k_new = 0.0 * unit.kilojoule_per_mole / unit.angstrom ** 2 + else: + # TODO - why is this being recalculated? + [index1, index2, r0_new, k_new] = self._find_bond_parameters( + new_system_bond_force, index1_new, index2_new) + self._hybrid_system_forces['core_bond_force'].addBond( + index1_hybrid, index2_hybrid, + [r0_old, k_old, r0_new, k_new]) + + # Check if the index set is a subset of anything besides + # environment (in the case of environment, we just add the bond to + # the regular bond force) + # that would mean that this bond is core-unique_old or + # unique_old-unique_old + # NOTE - These are currently all the same because we don't soften + # TODO - work these out somewhere else, this is terribly difficult + # to understand logic. + elif (index_set.issubset(self._atom_classes['unique_old_atoms']) or + (len(index_set.intersection(self._atom_classes['unique_old_atoms'])) == 1 + and len(index_set.intersection(self._atom_classes['core_atoms'])) == 1)): + + # We can just add it to the regular bond force. + self._hybrid_system_forces['standard_bond_force'].addBond( + index1_hybrid, index2_hybrid, r0_old, k_old) + + elif (len(index_set.intersection(self._atom_classes['environment_atoms'])) == 1 and + len(index_set.intersection(self._atom_classes['core_atoms'])) == 1): + self._hybrid_system_forces['standard_bond_force'].addBond( + index1_hybrid, index2_hybrid, r0_old, k_old) + + # Otherwise, we just add the same parameters as those in the old + # system (these are environment atoms, and the parameters are the + # same) + elif index_set.issubset(self._atom_classes['environment_atoms']): + self._hybrid_system_forces['standard_bond_force'].addBond( + index1_hybrid, index2_hybrid, r0_old, k_old) + else: + errmsg = (f"hybrid index set {index_set} does not fit into a " + "canonical atom type") + raise ValueError(errmsg) + + # Now loop through the new system to get the interactions that are + # unique to it. + for bond_index in range(new_system_bond_force.getNumBonds()): + # Get each set of bond parameters + [index1_new, index2_new, r0_new, k_new] = new_system_bond_force.getBondParameters(bond_index) + + # Convert indices to hybrid, since that is how we represent atom classes: + index1_hybrid = self._new_to_hybrid_map[index1_new] + index2_hybrid = self._new_to_hybrid_map[index2_new] + index_set = {index1_hybrid, index2_hybrid} + + # If the intersection of this set and unique new atoms contains + # anything, the bond is unique to the new system and must be added + # all other bonds in the new system have been accounted for already + # NOTE - These are mostly all the same because we don't soften + if (len(index_set.intersection(self._atom_classes['unique_new_atoms'])) == 2 or + (len(index_set.intersection(self._atom_classes['unique_new_atoms'])) == 1 and + len(index_set.intersection(self._atom_classes['core_atoms'])) == 1)): + + # If we aren't softening bonds, then just add it to the standard bond force + self._hybrid_system_forces['standard_bond_force'].addBond( + index1_hybrid, index2_hybrid, r0_new, k_new) + + # If the bond is in the core, it has probably already been added + # in the above loop. However, there are some circumstances + # where it was not (closing a ring). In that case, the bond has + # not been added and should be added here. + # This has some peculiarities to be discussed... + # TODO - Work out what the above peculiarities are... + elif index_set.issubset(self._atom_classes['core_atoms']): + if not self._find_bond_parameters( + self._hybrid_system_forces['core_bond_force'], + index1_hybrid, index2_hybrid): + r0_old = r0_new + k_old = 0.0 * unit.kilojoule_per_mole / unit.angstrom ** 2 + self._hybrid_system_forces['core_bond_force'].addBond( + index1_hybrid, index2_hybrid, + [r0_old, k_old, r0_new, k_new]) + elif index_set.issubset(self._atom_classes['environment_atoms']): + # Already been added + pass + + elif (len(index_set.intersection(self._atom_classes['environment_atoms'])) == 1 and + len(index_set.intersection(self._atom_classes['core_atoms'])) == 1): + pass + + else: + errmsg = (f"hybrid index set {index_set} does not fit into a " + "canonical atom type") + raise ValueError(errmsg) + + @staticmethod + def _find_angle_parameters(angle_force, indices): + """ + Convenience function to find the angle parameters corresponding to a + particular set of indices + + Parameters + ---------- + angle_force : openmm.HarmonicAngleForce + The force where the angle of interest may be found. + indices : list of int + The indices (any order) of the angle atoms + + Returns + ------- + angle_params : list + list of angle parameters + """ + indices_reversed = indices[::-1] + + # Now loop through and try to find the angle: + for angle_index in range(angle_force.getNumAngles()): + angle_params = angle_force.getAngleParameters(angle_index) + + # Get a set representing the angle indices + angle_param_indices = angle_params[:3] + + if (indices == angle_param_indices or + indices_reversed == angle_param_indices): + return angle_params + return [] # Return empty if no matching angle found + + def _handle_harmonic_angles(self): + """ + This method adds the appropriate interaction for all angles in the + hybrid system. The scheme used, as with bonds, is: + + 1) If the three atoms are all in the core, then we add to the + CustomAngleForce and interpolate between the two parameters + 2) If the three atoms contain at least one unique new, check if the + angle is in the neglected new list, and if so, interpolate from + K_1 = 0; else, if the three atoms contain at least one unique old, + check if the angle is in the neglected old list, and if so, + interpolate from K_2 = 0. + 3) If the angle contains at least one environment and at least one + core atom, assert there are no unique new atoms and that the angle + terms are preserved between the new and the old system. Then add to + the standard angle force. + 4) Otherwise, we add the angle to a regular angle force since it is + environment. + + Notes + ----- + * Removed softening and neglected angle functionality + """ + old_system_angle_force = self._old_system_forces['HarmonicAngleForce'] + new_system_angle_force = self._new_system_forces['HarmonicAngleForce'] + + # First, loop through all the angles in the old system to determine + # what to do with them. We will only use the + # custom angle force if all atoms are part of "core." Otherwise, they + # are either unique to one system or never change. + for angle_index in range(old_system_angle_force.getNumAngles()): + + old_angle_parameters = old_system_angle_force.getAngleParameters( + angle_index) + + # Get the indices in the hybrid system + hybrid_index_list = [ + self._old_to_hybrid_map[old_atomid] for old_atomid in old_angle_parameters[:3] + ] + hybrid_index_set = set(hybrid_index_list) + + # If all atoms are in the core, we'll need to find the + # corresponding parameters in the old system and interpolate + if hybrid_index_set.issubset(self._atom_classes['core_atoms']): + # Get the new indices so we can get the new angle parameters + new_indices = [ + self._old_to_new_map[old_atomid] for old_atomid in old_angle_parameters[:3] + ] + new_angle_parameters = self._find_angle_parameters( + new_system_angle_force, new_indices + ) + if not new_angle_parameters: + new_angle_parameters = [ + 0, 0, 0, old_angle_parameters[3], + 0.0 * unit.kilojoule_per_mole / unit.radian ** 2 + ] + + # Add to the hybrid force: + # the parameters at indices 3 and 4 represent theta0 and k, + # respectively. + hybrid_force_parameters = [ + old_angle_parameters[3], old_angle_parameters[4], + new_angle_parameters[3], new_angle_parameters[4] + ] + self._hybrid_system_forces['core_angle_force'].addAngle( + hybrid_index_list[0], hybrid_index_list[1], + hybrid_index_list[2], hybrid_force_parameters + ) + + # Check if the atoms are neither all core nor all environment, + # which would mean they involve unique old interactions + elif not hybrid_index_set.issubset( + self._atom_classes['environment_atoms']): + # if there is an environment atom + if hybrid_index_set.intersection( + self._atom_classes['environment_atoms']): + if hybrid_index_set.intersection( + self._atom_classes['unique_old_atoms']): + errmsg = "we disallow unique-environment terms" + raise ValueError(errmsg) + + self._hybrid_system_forces['standard_angle_force'].addAngle( + hybrid_index_list[0], hybrid_index_list[1], + hybrid_index_list[2], old_angle_parameters[3], + old_angle_parameters[4] + ) + else: + # There are no env atoms, so we can treat this term + # appropriately + + # We don't soften so just add this to the standard angle + # force + self._hybrid_system_forces['standard_angle_force'].addAngle( + hybrid_index_list[0], hybrid_index_list[1], + hybrid_index_list[2], old_angle_parameters[3], + old_angle_parameters[4] + ) + + # Otherwise, only environment atoms are in this interaction, so + # add it to the standard angle force + elif hybrid_index_set.issubset( + self._atom_classes['environment_atoms']): + self._hybrid_system_forces['standard_angle_force'].addAngle( + hybrid_index_list[0], hybrid_index_list[1], + hybrid_index_list[2], old_angle_parameters[3], + old_angle_parameters[4] + ) + else: + errmsg = (f"handle_harmonic_angles: angle_index {angle_index} " + "does not fit a canonical form.") + raise ValueError(errmsg) + + # Finally, loop through the new system force to add any unique new + # angles + for angle_index in range(new_system_angle_force.getNumAngles()): + + new_angle_parameters = new_system_angle_force.getAngleParameters( + angle_index) + + # Get the indices in the hybrid system + hybrid_index_list = [ + self._new_to_hybrid_map[new_atomid] for new_atomid in new_angle_parameters[:3] + ] + hybrid_index_set = set(hybrid_index_list) + + # If the intersection of this hybrid set with the unique new atoms + # is nonempty, it must be added: + # TODO - there's a ton of len > 0 on sets, empty sets == False, + # so we can simplify this logic. + if len(hybrid_index_set.intersection( + self._atom_classes['unique_new_atoms'])) > 0: + if hybrid_index_set.intersection( + self._atom_classes['environment_atoms']): + errmsg = ("we disallow angle terms with unique new and " + "environment atoms") + raise ValueError(errmsg) + + # Not softening just add to the nonalchemical force + self._hybrid_system_forces['standard_angle_force'].addAngle( + hybrid_index_list[0], hybrid_index_list[1], + hybrid_index_list[2], new_angle_parameters[3], + new_angle_parameters[4] + ) + + elif hybrid_index_set.issubset(self._atom_classes['core_atoms']): + if not self._find_angle_parameters(self._hybrid_system_forces['core_angle_force'], + hybrid_index_list): + hybrid_force_parameters = [ + new_angle_parameters[3], + 0.0 * unit.kilojoule_per_mole / unit.radian ** 2, + new_angle_parameters[3], new_angle_parameters[4] + ] + self._hybrid_system_forces['core_angle_force'].addAngle( + hybrid_index_list[0], hybrid_index_list[1], + hybrid_index_list[2], hybrid_force_parameters + ) + elif hybrid_index_set.issubset(self._atom_classes['environment_atoms']): + # We have already added the appropriate environmental atom + # terms + pass + elif hybrid_index_set.intersection(self._atom_classes['environment_atoms']): + if hybrid_index_set.intersection(self._atom_classes['unique_new_atoms']): + errmsg = ("we disallow angle terms with unique new and " + "environment atoms") + raise ValueError(errmsg) + else: + errmsg = (f"hybrid index list {hybrid_index_list} does not " + "fit into a canonical atom set") + raise ValueError(errmsg) + + @staticmethod + def _find_torsion_parameters(torsion_force, indices): + """ + Convenience function to find the torsion parameters corresponding to a + particular set of indices. + + Parameters + ---------- + torsion_force : openmm.PeriodicTorsionForce + torsion force where the torsion of interest may be found + indices : list of int + The indices of the atoms of the torsion + + Returns + ------- + torsion_parameters : list + torsion parameters + """ + indices_reversed = indices[::-1] + + torsion_params_list = list() + + # Now loop through and try to find the torsion: + for torsion_idx in range(torsion_force.getNumTorsions()): + torsion_params = torsion_force.getTorsionParameters(torsion_idx) + + # Get a set representing the torsion indices: + torsion_param_indices = torsion_params[:4] + + if (indices == torsion_param_indices or + indices_reversed == torsion_param_indices): + torsion_params_list.append(torsion_params) + + return torsion_params_list + + def _handle_periodic_torsion_force(self): + """ + Handle the torsions defined in the new and old systems as such: + + 1. old system torsions will enter the ``custom_torsion_force`` if they + do not contain ``unique_old_atoms`` and will interpolate from ``on`` + to ``off`` from ``lambda_torsions`` = 0 to 1, respectively. + 2. new system torsions will enter the ``custom_torsion_force`` if they + do not contain ``unique_new_atoms`` and will interpolate from + ``off`` to ``on`` from ``lambda_torsions`` = 0 to 1, respectively. + 3. old *and* new system torsions will enter the + ``unique_atom_torsion_force`` (``standard_torsion_force``) and will + *not* be interpolated. + + Notes + ----- + * Torsion flattening logic has been removed for now. + """ + old_system_torsion_force = self._old_system_forces['PeriodicTorsionForce'] + new_system_torsion_force = self._new_system_forces['PeriodicTorsionForce'] + + # aux list stores the torsions that we already computed such that we don't add them again when checking the new system + auxiliary_custom_torsion_force = [] + # aludel/valence.py -- convenient way of handling all the valence terms for alchemistry + old_custom_torsions_to_standard = [] + + # We need to keep track of what torsions we added so that we do not + # double count + # added_torsions = [] + # TODO: Commented out since this actually isn't being done anywhere? + # Is it necessary? Should we add this logic back in? + for torsion_index in range(old_system_torsion_force.getNumTorsions()): + + torsion_parameters = old_system_torsion_force.getTorsionParameters( + torsion_index) + + # Get the indices in the hybrid system + hybrid_index_list = [ + self._old_to_hybrid_map[old_index] for old_index in torsion_parameters[:4] + ] + hybrid_index_set = set(hybrid_index_list) + + # If all atoms are in the core, we'll need to find the + # corresponding parameters in the old system and interpolate + if hybrid_index_set.intersection(self._atom_classes['unique_old_atoms']): + # Then it goes to a standard force... + self._hybrid_system_forces['unique_atom_torsion_force'].addTorsion( + hybrid_index_list[0], hybrid_index_list[1], + hybrid_index_list[2], hybrid_index_list[3], + torsion_parameters[4], torsion_parameters[5], + torsion_parameters[6] + ) + else: + # It is a core-only term, an environment-only term, or a + # core/env term; in any case, it goes to the core torsion_force + # TODO - why are we even adding the 0.0, 0.0, 0.0 section? + hybrid_force_parameters = [ + torsion_parameters[4], torsion_parameters[5], + torsion_parameters[6], 0.0, 0.0, 0.0 + ] + auxiliary_custom_torsion_force.append( + [hybrid_index_list[0], hybrid_index_list[1], + hybrid_index_list[2], hybrid_index_list[3], + hybrid_force_parameters[:3]] + ) + + for torsion_index in range(new_system_torsion_force.getNumTorsions()): + torsion_parameters = new_system_torsion_force.getTorsionParameters(torsion_index) + + # Get the indices in the hybrid system: + hybrid_index_list = [ + self._new_to_hybrid_map[new_index] for new_index in torsion_parameters[:4]] + hybrid_index_set = set(hybrid_index_list) + + if hybrid_index_set.intersection(self._atom_classes['unique_new_atoms']): + # Then it goes to the custom torsion force (scaled to zero) + self._hybrid_system_forces['unique_atom_torsion_force'].addTorsion( + hybrid_index_list[0], hybrid_index_list[1], + hybrid_index_list[2], hybrid_index_list[3], + torsion_parameters[4], torsion_parameters[5], + torsion_parameters[6] + ) + else: + hybrid_force_parameters = [ + 0.0, 0.0, 0.0, torsion_parameters[4], + torsion_parameters[5], torsion_parameters[6]] + + # Check to see if this term is in the olds... + term = [hybrid_index_list[0], hybrid_index_list[1], + hybrid_index_list[2], hybrid_index_list[3], + hybrid_force_parameters[3:]] + if term in auxiliary_custom_torsion_force: + # Then this terms has to go to standard and be deleted... + old_index = auxiliary_custom_torsion_force.index(term) + old_custom_torsions_to_standard.append(old_index) + self._hybrid_system_forces['unique_atom_torsion_force'].addTorsion( + hybrid_index_list[0], hybrid_index_list[1], + hybrid_index_list[2], hybrid_index_list[3], + torsion_parameters[4], torsion_parameters[5], + torsion_parameters[6] + ) + else: + # Then this term has to go to the core force... + self._hybrid_system_forces['custom_torsion_force'].addTorsion( + hybrid_index_list[0], hybrid_index_list[1], + hybrid_index_list[2], hybrid_index_list[3], + hybrid_force_parameters + ) + + # Now we have to loop through the aux custom torsion force + for index in [q for q in range(len(auxiliary_custom_torsion_force)) + if q not in old_custom_torsions_to_standard]: + terms = auxiliary_custom_torsion_force[index] + hybrid_index_list = terms[:4] + hybrid_force_parameters = terms[4] + [0., 0., 0.] + self._hybrid_system_forces['custom_torsion_force'].addTorsion( + hybrid_index_list[0], hybrid_index_list[1], + hybrid_index_list[2], hybrid_index_list[3], + hybrid_force_parameters + ) + + def _handle_nonbonded(self): + """ + Handle the nonbonded interactions defined in the new and old systems. + + TODO + ---- + * Expand this docstring to explain the logic. + * A lot of this logic is duplicated, probably turn it into a couple of + functions. + """ + + def _check_indices(idx1, idx2): + if idx1 != idx2: + errmsg = ("Attempting to add incorrect particle to hybrid " + "system") + raise ValueError(errmsg) + + old_system_nonbonded_force = self._old_system_forces['NonbondedForce'] + new_system_nonbonded_force = self._new_system_forces['NonbondedForce'] + hybrid_to_old_map = self._hybrid_to_old_map + hybrid_to_new_map = self._hybrid_to_new_map + + # Define new global parameters for NonbondedForce + self._hybrid_system_forces['standard_nonbonded_force'].addGlobalParameter('lambda_electrostatics_core', 0.0) + self._hybrid_system_forces['standard_nonbonded_force'].addGlobalParameter('lambda_sterics_core', 0.0) + self._hybrid_system_forces['standard_nonbonded_force'].addGlobalParameter("lambda_electrostatics_delete", 0.0) + self._hybrid_system_forces['standard_nonbonded_force'].addGlobalParameter("lambda_electrostatics_insert", 0.0) + + # We have to loop through the particles in the system, because + # nonbonded force does not accept index + for particle_index in range(self._hybrid_system.getNumParticles()): + + if particle_index in self._atom_classes['unique_old_atoms']: + # Get the parameters in the old system + old_index = hybrid_to_old_map[particle_index] + [charge, sigma, epsilon] = old_system_nonbonded_force.getParticleParameters(old_index) + + # Add the particle to the hybrid custom sterics and + # electrostatics. + # turning off sterics in forward direction + check_index = self._hybrid_system_forces['core_sterics_force'].addParticle( + [sigma, epsilon, sigma, 0.0 * epsilon, 1, 0] + ) + _check_indices(particle_index, check_index) + + # Add particle to the regular nonbonded force, but + # Lennard-Jones will be handled by CustomNonbondedForce + check_index = self._hybrid_system_forces['standard_nonbonded_force'].addParticle( + charge, sigma, 0.0 * epsilon + ) + _check_indices(particle_index, check_index) + + # Charge will be turned off at + # lambda_electrostatics_delete = 0, on at + # lambda_electrostatics_delete = 1; kill charge with + # lambda_electrostatics_delete = 0 --> 1 + self._hybrid_system_forces['standard_nonbonded_force'].addParticleParameterOffset( + 'lambda_electrostatics_delete', particle_index, + -charge, 0 * sigma, 0 * epsilon + ) + + elif particle_index in self._atom_classes['unique_new_atoms']: + # Get the parameters in the new system + new_index = hybrid_to_new_map[particle_index] + [charge, sigma, epsilon] = new_system_nonbonded_force.getParticleParameters(new_index) + + # Add the particle to the hybrid custom sterics and electrostatics + # turning on sterics in forward direction + check_index = self._hybrid_system_forces['core_sterics_force'].addParticle( + [sigma, 0.0 * epsilon, sigma, epsilon, 0, 1] + ) + _check_indices(particle_index, check_index) + + # Add particle to the regular nonbonded force, but + # Lennard-Jones will be handled by CustomNonbondedForce + check_index = self._hybrid_system_forces['standard_nonbonded_force'].addParticle( + 0.0, sigma, 0.0 + ) # charge starts at zero + _check_indices(particle_index, check_index) + + # Charge will be turned off at lambda_electrostatics_insert = 0 + # on at lambda_electrostatics_insert = 1; + # add charge with lambda_electrostatics_insert = 0 --> 1 + self._hybrid_system_forces['standard_nonbonded_force'].addParticleParameterOffset( + 'lambda_electrostatics_insert', particle_index, + +charge, 0, 0 + ) + + elif particle_index in self._atom_classes['core_atoms']: + # Get the parameters in the new and old systems: + old_index = hybrid_to_old_map[particle_index] + [charge_old, sigma_old, epsilon_old] = old_system_nonbonded_force.getParticleParameters(old_index) + new_index = hybrid_to_new_map[particle_index] + [charge_new, sigma_new, epsilon_new] = new_system_nonbonded_force.getParticleParameters(new_index) + + # Add the particle to the custom forces, interpolating between + # the two parameters; add steric params and zero electrostatics + # to core_sterics per usual + check_index = self._hybrid_system_forces['core_sterics_force'].addParticle( + [sigma_old, epsilon_old, sigma_new, epsilon_new, 0, 0]) + _check_indices(particle_index, check_index) + + # Still add the particle to the regular nonbonded force, but + # with zeroed out parameters; add old charge to + # standard_nonbonded and zero sterics + check_index = self._hybrid_system_forces['standard_nonbonded_force'].addParticle( + charge_old, 0.5 * (sigma_old + sigma_new), 0.0) + _check_indices(particle_index, check_index) + + # Charge is charge_old at lambda_electrostatics = 0, + # charge_new at lambda_electrostatics = 1 + # TODO: We could also interpolate the Lennard-Jones here + # instead of core_sterics force so that core_sterics_force + # could just be softcore. + + # Interpolate between old and new charge with + # lambda_electrostatics core make sure to keep sterics off + self._hybrid_system_forces['standard_nonbonded_force'].addParticleParameterOffset( + 'lambda_electrostatics_core', particle_index, + (charge_new - charge_old), 0, 0 + ) + + # Otherwise, the particle is in the environment + else: + # The parameters will be the same in new and old system, so + # just take the old parameters + old_index = hybrid_to_old_map[particle_index] + [charge, sigma, epsilon] = old_system_nonbonded_force.getParticleParameters(old_index) + + # Add the particle to the hybrid custom sterics, but they dont + # change; electrostatics are ignored + self._hybrid_system_forces['core_sterics_force'].addParticle( + [sigma, epsilon, sigma, epsilon, 0, 0] + ) + + # Add the environment atoms to the regular nonbonded force as + # well: should we be adding steric terms here, too? + self._hybrid_system_forces['standard_nonbonded_force'].addParticle( + charge, sigma, epsilon + ) + + # Now loop pairwise through (unique_old, unique_new) and add exceptions + # so that they never interact electrostatically + # (place into Nonbonded Force) + unique_old_atoms = self._atom_classes['unique_old_atoms'] + unique_new_atoms = self._atom_classes['unique_new_atoms'] + + for old in unique_old_atoms: + for new in unique_new_atoms: + self._hybrid_system_forces['standard_nonbonded_force'].addException( + old, new, 0.0 * unit.elementary_charge ** 2, + 1.0 * unit.nanometers, 0.0 * unit.kilojoules_per_mole) + # This is only necessary to avoid the 'All forces must have + # identical exclusions' rule + self._hybrid_system_forces['core_sterics_force'].addExclusion(old, new) + + self._handle_interaction_groups() + + self._handle_hybrid_exceptions() + + self._handle_original_exceptions() + + def _handle_interaction_groups(self): + """ + Create the appropriate interaction groups for the custom nonbonded + forces. The groups are: + + 1) Unique-old - core + 2) Unique-old - environment + 3) Unique-new - core + 4) Unique-new - environment + 5) Core - environment + 6) Core - core + + Unique-old and Unique new are prevented from interacting this way, + and intra-unique interactions occur in an unmodified nonbonded force. + + Must be called after particles are added to the Nonbonded forces + TODO: we should also be adding the following interaction groups... + 7) Unique-new - Unique-new + 8) Unique-old - Unique-old + """ + # Get the force objects for convenience: + sterics_custom_force = self._hybrid_system_forces['core_sterics_force'] + + # Also prepare the atom classes + core_atoms = self._atom_classes['core_atoms'] + unique_old_atoms = self._atom_classes['unique_old_atoms'] + unique_new_atoms = self._atom_classes['unique_new_atoms'] + environment_atoms = self._atom_classes['environment_atoms'] + + sterics_custom_force.addInteractionGroup(unique_old_atoms, core_atoms) + + sterics_custom_force.addInteractionGroup(unique_old_atoms, + environment_atoms) + + sterics_custom_force.addInteractionGroup(unique_new_atoms, + core_atoms) + + sterics_custom_force.addInteractionGroup(unique_new_atoms, + environment_atoms) + + sterics_custom_force.addInteractionGroup(core_atoms, environment_atoms) + + sterics_custom_force.addInteractionGroup(core_atoms, core_atoms) + + sterics_custom_force.addInteractionGroup(unique_new_atoms, + unique_new_atoms) + + sterics_custom_force.addInteractionGroup(unique_old_atoms, + unique_old_atoms) + + def _handle_hybrid_exceptions(self): + """ + Instead of excluding interactions that shouldn't occur, we provide + exceptions for interactions that were zeroed out but should occur. + """ + # TODO - are these actually used anywhere? Flake8 says no + old_system_nonbonded_force = self._old_system_forces['NonbondedForce'] + new_system_nonbonded_force = self._new_system_forces['NonbondedForce'] + + # Prepare the atom classes + unique_old_atoms = self._atom_classes['unique_old_atoms'] + unique_new_atoms = self._atom_classes['unique_new_atoms'] + + # Get the list of interaction pairs for which we need to set exceptions + unique_old_pairs = list(itertools.combinations(unique_old_atoms, 2)) + unique_new_pairs = list(itertools.combinations(unique_new_atoms, 2)) + + # Add back the interactions of the old unique atoms, unless there are + # exceptions + for atom_pair in unique_old_pairs: + # Since the pairs are indexed in the dictionary by the old system + # indices, we need to convert + old_index_atom_pair = (self._hybrid_to_old_map[atom_pair[0]], + self._hybrid_to_old_map[atom_pair[1]]) + + # Now we check if the pair is in the exception dictionary + if old_index_atom_pair in self._old_system_exceptions: + [chargeProd, sigma, epsilon] = self._old_system_exceptions[old_index_atom_pair] + # if we are interpolating 1,4 exceptions then we have to + if self._interpolate_14s: + self._hybrid_system_forces['standard_nonbonded_force'].addException( + atom_pair[0], atom_pair[1], chargeProd * 0.0, + sigma, epsilon * 0.0 + ) + else: + self._hybrid_system_forces['standard_nonbonded_force'].addException( + atom_pair[0], atom_pair[1], chargeProd, sigma, epsilon + ) + + # Add exclusion to ensure exceptions are consistent + self._hybrid_system_forces['core_sterics_force'].addExclusion( + atom_pair[0], atom_pair[1] + ) + + # Check if the pair is in the reverse order and use that if so + elif old_index_atom_pair[::-1] in self._old_system_exceptions: + [chargeProd, sigma, epsilon] = self._old_system_exceptions[old_index_atom_pair[::-1]] + # If we are interpolating 1,4 exceptions then we have to + if self._interpolate_14s: + self._hybrid_system_forces['standard_nonbonded_force'].addException( + atom_pair[0], atom_pair[1], chargeProd * 0.0, + sigma, epsilon * 0.0 + ) + else: + self._hybrid_system_forces['standard_nonbonded_force'].addException( + atom_pair[0], atom_pair[1], chargeProd, sigma, epsilon) + + # Add exclusion to ensure exceptions are consistent + self._hybrid_system_forces['core_sterics_force'].addExclusion( + atom_pair[0], atom_pair[1]) + + # TODO: work out why there's a bunch of commented out code here + # Exerpt: + # If it's not handled by an exception in the original system, we + # just add the regular parameters as an exception + # TODO: this implies that the old-old nonbonded interactions (those + # which are not exceptions) are always self-interacting throughout + # lambda protocol... + + # Add back the interactions of the new unique atoms, unless there are + # exceptions + for atom_pair in unique_new_pairs: + # Since the pairs are indexed in the dictionary by the new system + # indices, we need to convert + new_index_atom_pair = (self._hybrid_to_new_map[atom_pair[0]], + self._hybrid_to_new_map[atom_pair[1]]) + + # Now we check if the pair is in the exception dictionary + if new_index_atom_pair in self._new_system_exceptions: + [chargeProd, sigma, epsilon] = self._new_system_exceptions[new_index_atom_pair] + if self._interpolate_14s: + self._hybrid_system_forces['standard_nonbonded_force'].addException( + atom_pair[0], atom_pair[1], chargeProd * 0.0, + sigma, epsilon * 0.0 + ) + else: + self._hybrid_system_forces['standard_nonbonded_force'].addException( + atom_pair[0], atom_pair[1], chargeProd, sigma, epsilon + ) + + self._hybrid_system_forces['core_sterics_force'].addExclusion( + atom_pair[0], atom_pair[1] + ) + + # Check if the pair is present in the reverse order and use that if so + elif new_index_atom_pair[::-1] in self._new_system_exceptions: + [chargeProd, sigma, epsilon] = self._new_system_exceptions[new_index_atom_pair[::-1]] + if self._interpolate_14s: + self._hybrid_system_forces['standard_nonbonded_force'].addException( + atom_pair[0], atom_pair[1], chargeProd * 0.0, + sigma, epsilon * 0.0 + ) + else: + self._hybrid_system_forces['standard_nonbonded_force'].addException( + atom_pair[0], atom_pair[1], chargeProd, sigma, epsilon + ) + + self._hybrid_system_forces['core_sterics_force'].addExclusion( + atom_pair[0], atom_pair[1] + ) + + # TODO: work out why there's a bunch of commented out code here + # If it's not handled by an exception in the original system, we + # just add the regular parameters as an exception + + @staticmethod + def _find_exception(force, index1, index2): + """ + Find the exception that corresponds to the given indices in the given + system + + Parameters + ---------- + force : openmm.NonbondedForce object + System containing the exceptions + index1 : int + The index of the first atom (order is unimportant) + index2 : int + The index of the second atom (order is unimportant) + + Returns + ------- + exception_parameters : list + List of exception parameters + """ + index_set = {index1, index2} + + # Loop through the exceptions and try to find one matching the criteria + for exception_idx in range(force.getNumExceptions()): + exception_parameters = force.getExceptionParameters(exception_idx) + if index_set == set(exception_parameters[:2]): + return exception_parameters + return [] + + def _handle_original_exceptions(self): + """ + This method ensures that exceptions present in the original systems are + present in the hybrid appropriately. + """ + # Get what we need to find the exceptions from the new and old systems: + old_system_nonbonded_force = self._old_system_forces['NonbondedForce'] + new_system_nonbonded_force = self._new_system_forces['NonbondedForce'] + hybrid_to_old_map = self._hybrid_to_old_map + hybrid_to_new_map = self._hybrid_to_new_map + + # First, loop through the old system's exceptions and add them to the + # hybrid appropriately: + for exception_pair, exception_parameters in self._old_system_exceptions.items(): + + [index1_old, index2_old] = exception_pair + [chargeProd_old, sigma_old, epsilon_old] = exception_parameters + + # Get hybrid indices: + index1_hybrid = self._old_to_hybrid_map[index1_old] + index2_hybrid = self._old_to_hybrid_map[index2_old] + index_set = {index1_hybrid, index2_hybrid} + + # In this case, the interaction is only covered by the regular + # nonbonded force, and as such will be copied to that force + # In the unique-old case, it is handled elsewhere due to internal + # peculiarities regarding exceptions + if index_set.issubset(self._atom_classes['environment_atoms']): + self._hybrid_system_forces['standard_nonbonded_force'].addException( + index1_hybrid, index2_hybrid, chargeProd_old, + sigma_old, epsilon_old + ) + self._hybrid_system_forces['core_sterics_force'].addExclusion( + index1_hybrid, index2_hybrid + ) + + # We have already handled unique old - unique old exceptions + elif len(index_set.intersection(self._atom_classes['unique_old_atoms'])) == 2: + continue + + # Otherwise, check if one of the atoms in the set is in the + # unique_old_group and the other is not: + elif len(index_set.intersection(self._atom_classes['unique_old_atoms'])) == 1: + if self._interpolate_14s: + self._hybrid_system_forces['standard_nonbonded_force'].addException( + index1_hybrid, index2_hybrid, chargeProd_old * 0.0, + sigma_old, epsilon_old * 0.0 + ) + else: + self._hybrid_system_forces['standard_nonbonded_force'].addException( + index1_hybrid, index2_hybrid, chargeProd_old, + sigma_old, epsilon_old + ) + + self._hybrid_system_forces['core_sterics_force'].addExclusion( + index1_hybrid, index2_hybrid + ) + + # If the exception particles are neither solely old unique, solely + # environment, nor contain any unique old atoms, they are either + # core/environment or core/core + # In this case, we need to get the parameters from the exception in + # the other (new) system, and interpolate between the two + else: + # First get the new indices. + index1_new = hybrid_to_new_map[index1_hybrid] + index2_new = hybrid_to_new_map[index2_hybrid] + # Get the exception parameters: + new_exception_parms = self._find_exception( + new_system_nonbonded_force, + index1_new, index2_new) + + # If there's no new exception, then we should just set the + # exception parameters to be the nonbonded parameters + if not new_exception_parms: + [charge1_new, sigma1_new, epsilon1_new] = new_system_nonbonded_force.getParticleParameters( + index1_new) + [charge2_new, sigma2_new, epsilon2_new] = new_system_nonbonded_force.getParticleParameters( + index2_new) + + chargeProd_new = charge1_new * charge2_new + sigma_new = 0.5 * (sigma1_new + sigma2_new) + epsilon_new = unit.sqrt(epsilon1_new * epsilon2_new) + else: + [index1_new, index2_new, chargeProd_new, sigma_new, epsilon_new] = new_exception_parms + + # Interpolate between old and new + exception_index = self._hybrid_system_forces['standard_nonbonded_force'].addException( + index1_hybrid, index2_hybrid, chargeProd_old, + sigma_old, epsilon_old + ) + self._hybrid_system_forces['standard_nonbonded_force'].addExceptionParameterOffset( + 'lambda_electrostatics_core', exception_index, + (chargeProd_new - chargeProd_old), 0, 0 + ) + self._hybrid_system_forces['standard_nonbonded_force'].addExceptionParameterOffset( + 'lambda_sterics_core', exception_index, 0, + (sigma_new - sigma_old), (epsilon_new - epsilon_old) + ) + self._hybrid_system_forces['core_sterics_force'].addExclusion( + index1_hybrid, index2_hybrid + ) + + # Now, loop through the new system to collect remaining interactions. + # The only that remain here are uniquenew-uniquenew, uniquenew-core, + # and uniquenew-environment. There might also be core-core, since not + # all core-core exceptions exist in both + for exception_pair, exception_parameters in self._new_system_exceptions.items(): + [index1_new, index2_new] = exception_pair + [chargeProd_new, sigma_new, epsilon_new] = exception_parameters + + # Get hybrid indices: + index1_hybrid = self._new_to_hybrid_map[index1_new] + index2_hybrid = self._new_to_hybrid_map[index2_new] + + index_set = {index1_hybrid, index2_hybrid} + + # If it's a subset of unique_new_atoms, then this is an + # intra-unique interaction and should have its exceptions + # specified in the regular nonbonded force. However, this is + # handled elsewhere as above due to pecularities with exception + # handling + if index_set.issubset(self._atom_classes['unique_new_atoms']): + continue + + # Look for the final class- interactions between uniquenew-core and + # uniquenew-environment. They are treated similarly: they are + # simply on and constant the entire time (as a valence term) + elif len(index_set.intersection(self._atom_classes['unique_new_atoms'])) > 0: + if self._interpolate_14s: + self._hybrid_system_forces['standard_nonbonded_force'].addException( + index1_hybrid, index2_hybrid, chargeProd_new * 0.0, + sigma_new, epsilon_new * 0.0 + ) + else: + self._hybrid_system_forces['standard_nonbonded_force'].addException( + index1_hybrid, index2_hybrid, chargeProd_new, + sigma_new, epsilon_new + ) + + self._hybrid_system_forces['core_sterics_force'].addExclusion( + index1_hybrid, index2_hybrid + ) + + # However, there may be a core exception that exists in one system + # but not the other (ring closure) + elif index_set.issubset(self._atom_classes['core_atoms']): + + # Get the old indices + try: + index1_old = self._new_to_old_map[index1_new] + index2_old = self._new_to_old_map[index2_new] + except KeyError: + continue + + # See if it's also in the old nonbonded force. if it is, then we don't need to add it. + # But if it's not, we need to interpolate + if not self._find_exception(old_system_nonbonded_force, index1_old, index2_old): + [charge1_old, sigma1_old, epsilon1_old] = old_system_nonbonded_force.getParticleParameters( + index1_old) + [charge2_old, sigma2_old, epsilon2_old] = old_system_nonbonded_force.getParticleParameters( + index2_old) + + chargeProd_old = charge1_old * charge2_old + sigma_old = 0.5 * (sigma1_old + sigma2_old) + epsilon_old = unit.sqrt(epsilon1_old * epsilon2_old) + + exception_index = self._hybrid_system_forces['standard_nonbonded_force'].addException( + index1_hybrid, index2_hybrid, + chargeProd_old, sigma_old, + epsilon_old) + + self._hybrid_system_forces['standard_nonbonded_force'].addExceptionParameterOffset( + 'lambda_electrostatics_core', exception_index, + (chargeProd_new - chargeProd_old), 0, 0 + ) + + self._hybrid_system_forces['standard_nonbonded_force'].addExceptionParameterOffset( + 'lambda_sterics_core', exception_index, 0, + (sigma_new - sigma_old), (epsilon_new - epsilon_old) + ) + + self._hybrid_system_forces['core_sterics_force'].addExclusion( + index1_hybrid, index2_hybrid + ) + + def _handle_old_new_exceptions(self): + """ + Find the exceptions associated with old-old and old-core interactions, + as well as new-new and new-core interactions. Theses exceptions will + be placed in CustomBondedForce that will interpolate electrostatics and + a softcore potential. + + TODO + ---- + * Move old_new_bond_exceptions to a dictionary or similar. + """ + + old_new_nonbonded_exceptions = "U_electrostatics + U_sterics;" + + if self._softcore_LJ_v2: + old_new_nonbonded_exceptions += "U_sterics = select(step(r - r_LJ), 4*epsilon*x*(x-1.0), U_sterics_quad);" + old_new_nonbonded_exceptions += f"U_sterics_quad = Force*(((r - r_LJ)^2)/2 - (r - r_LJ)) + U_sterics_cut;" + old_new_nonbonded_exceptions += f"U_sterics_cut = 4*epsilon*((sigma/r_LJ)^6)*(((sigma/r_LJ)^6) - 1.0);" + old_new_nonbonded_exceptions += f"Force = -4*epsilon*((-12*sigma^12)/(r_LJ^13) + (6*sigma^6)/(r_LJ^7));" + old_new_nonbonded_exceptions += f"x = (sigma/r)^6;" + old_new_nonbonded_exceptions += f"r_LJ = softcore_alpha*((26/7)*(sigma^6)*lambda_sterics_deprecated)^(1/6);" + old_new_nonbonded_exceptions += f"lambda_sterics_deprecated = new_interaction*(1.0 - lambda_sterics_insert) + old_interaction*lambda_sterics_delete;" + else: + old_new_nonbonded_exceptions += "U_sterics = 4*epsilon*x*(x-1.0); x = (sigma/reff_sterics)^6;" + old_new_nonbonded_exceptions += "reff_sterics = sigma*((softcore_alpha*lambda_alpha + (r/sigma)^6))^(1/6);" + old_new_nonbonded_exceptions += "reff_sterics = sigma*((softcore_alpha*lambda_alpha + (r/sigma)^6))^(1/6);" # effective softcore distance for sterics + old_new_nonbonded_exceptions += "lambda_alpha = new_interaction*(1-lambda_sterics_insert) + old_interaction*lambda_sterics_delete;" + + old_new_nonbonded_exceptions += "U_electrostatics = (lambda_electrostatics_insert * unique_new + unique_old * (1 - lambda_electrostatics_delete)) * ONE_4PI_EPS0*chargeProd/r;" + old_new_nonbonded_exceptions += "ONE_4PI_EPS0 = %f;" % ONE_4PI_EPS0 + + old_new_nonbonded_exceptions += "epsilon = (1-lambda_sterics)*epsilonA + lambda_sterics*epsilonB;" # interpolation + old_new_nonbonded_exceptions += "sigma = (1-lambda_sterics)*sigmaA + lambda_sterics*sigmaB;" + + old_new_nonbonded_exceptions += "lambda_sterics = new_interaction*lambda_sterics_insert + old_interaction*lambda_sterics_delete;" + old_new_nonbonded_exceptions += "new_interaction = delta(1-unique_new); old_interaction = delta(1-unique_old);" + + nonbonded_exceptions_force = openmm.CustomBondForce( + old_new_nonbonded_exceptions) + name = f"{nonbonded_exceptions_force.__class__.__name__}_exceptions" + nonbonded_exceptions_force.setName(name) + self._hybrid_system.addForce(nonbonded_exceptions_force) + + # For reference, set name in force dict + self._hybrid_system_forces['old_new_exceptions_force'] = nonbonded_exceptions_force + + if self._softcore_LJ_v2: + nonbonded_exceptions_force.addGlobalParameter( + "softcore_alpha", self._softcore_LJ_v2_alpha + ) + else: + nonbonded_exceptions_force.addGlobalParameter( + "softcore_alpha", self._softcore_alpha + ) + + # electrostatics insert + nonbonded_exceptions_force.addGlobalParameter( + "lambda_electrostatics_insert", 0.0 + ) + # electrostatics delete + nonbonded_exceptions_force.addGlobalParameter( + "lambda_electrostatics_delete", 0.0 + ) + # sterics insert + nonbonded_exceptions_force.addGlobalParameter( + "lambda_sterics_insert", 0.0 + ) + # steric delete + nonbonded_exceptions_force.addGlobalParameter( + "lambda_sterics_delete", 0.0 + ) + + for parameter in ['chargeProd', 'sigmaA', 'epsilonA', 'sigmaB', + 'epsilonB', 'unique_old', 'unique_new']: + nonbonded_exceptions_force.addPerBondParameter(parameter) + + # Prepare for exceptions loop by grabbing nonbonded forces, + # hybrid_to_old/new maps + old_system_nonbonded_force = self._old_system_forces['NonbondedForce'] + new_system_nonbonded_force = self._new_system_forces['NonbondedForce'] + hybrid_to_old_map = self._hybrid_to_old_map + hybrid_to_new_map = self._hybrid_to_new_map + + # First, loop through the old system's exceptions and add them to the + # hybrid appropriately: + for exception_pair, exception_parameters in self._old_system_exceptions.items(): + + [index1_old, index2_old] = exception_pair + [chargeProd_old, sigma_old, epsilon_old] = exception_parameters + + # Get hybrid indices: + index1_hybrid = self._old_to_hybrid_map[index1_old] + index2_hybrid = self._old_to_hybrid_map[index2_old] + index_set = {index1_hybrid, index2_hybrid} + + # Otherwise, check if one of the atoms in the set is in the + # unique_old_group and the other is not: + if (len(index_set.intersection(self._atom_classes['unique_old_atoms'])) > 0 and + (chargeProd_old.value_in_unit_system(unit.md_unit_system) != 0.0 or + epsilon_old.value_in_unit_system(unit.md_unit_system) != 0.0)): + if self._interpolate_14s: + # If we are interpolating 1,4s, then we anneal this term + # off; otherwise, the exception force is constant and + # already handled in the standard nonbonded force + nonbonded_exceptions_force.addBond( + index1_hybrid, index2_hybrid, + [chargeProd_old, sigma_old, epsilon_old, sigma_old, + epsilon_old * 0.0, 1, 0] + ) + + # Next, loop through the new system's exceptions and add them to the + # hybrid appropriately + for exception_pair, exception_parameters in self._new_system_exceptions.items(): + [index1_new, index2_new] = exception_pair + [chargeProd_new, sigma_new, epsilon_new] = exception_parameters + + # Get hybrid indices: + index1_hybrid = self._new_to_hybrid_map[index1_new] + index2_hybrid = self._new_to_hybrid_map[index2_new] + + index_set = {index1_hybrid, index2_hybrid} + + # Look for the final class- interactions between uniquenew-core and + # uniquenew-environment. They are treated + # similarly: they are simply on and constant the entire time + # (as a valence term) + if (len(index_set.intersection(self._atom_classes['unique_new_atoms'])) > 0 and + (chargeProd_new.value_in_unit_system(unit.md_unit_system) != 0.0 or + epsilon_new.value_in_unit_system(unit.md_unit_system) != 0.0)): + if self._interpolate_14s: + # If we are interpolating 1,4s, then we anneal this term + # on; otherwise, the exception force is constant and + # already handled in the standard nonbonded force + nonbonded_exceptions_force.addBond( + index1_hybrid, index2_hybrid, + [chargeProd_new, sigma_new, epsilon_new * 0.0, + sigma_new, epsilon_new, 0, 1] + ) + + def _compute_hybrid_positions(self): + """ + The positions of the hybrid system. Dimensionality is (n_environment + + n_core + n_old_unique + n_new_unique), + The positions are assigned by first copying all the mapped positions + from the old system in, then copying the + mapped positions from the new system. This means that there is an + assumption that the positions common to old and new are the same + (which is the case for perses as-is). + + Returns + ------- + hybrid_positions : np.ndarray [n, 3] + Positions of the hybrid system, in nm + """ + # Get unitless positions + old_pos_without_units = np.array( + self._old_positions.value_in_unit(unit.nanometer)) + new_pos_without_units = np.array( + self._new_positions.value_in_unit(unit.nanometer)) + + # Determine the number of particles in the system + n_atoms_hybrid = self._hybrid_system.getNumParticles() + + # Initialize an array for hybrid positions + hybrid_pos_array = np.zeros([n_atoms_hybrid, 3]) + + # Loop through the old system indices, and assign positions. + for old_idx, hybrid_idx in self._old_to_hybrid_map.items(): + hybrid_pos_array[hybrid_idx, :] = old_pos_without_units[old_idx, :] + + # Do the same for new indices. Note that this overwrites some + # coordinates, but as stated above, the assumption is that these are + # the same. + for new_idx, hybrid_idx in self._new_to_hybrid_map.items(): + hybrid_pos_array[hybrid_idx, :] = new_pos_without_units[new_idx, :] + + return unit.Quantity(hybrid_pos_array, unit=unit.nanometers) + + def _create_mdtraj_topology(self): + """ + Create an MDTraj trajectory of the hybrid system. + + Note + ---- + This is purely for writing out trajectories and is not expected to be + parametrized. + + TODO + ---- + * A lot of this can be simplified / reworked. + """ + old_top = mdt.Topology.from_openmm(self._old_topology) + new_top = mdt.Topology.from_openmm(self._new_topology) + + hybrid_topology = copy.deepcopy(old_top) + + added_atoms = dict() + + # Get the core atoms in the new index system (as opposed to the hybrid + # index system). We will need this later + core_atoms_new_indices = set(self._core_old_to_new_map.values()) + + # Now, add each unique new atom to the topology (this is the same order + # as the system) + for particle_idx in self._unique_new_atoms: + new_particle_hybrid_idx = self._new_to_hybrid_map[particle_idx] + new_system_atom = new_top.atom(particle_idx) + + # First, we get the residue in the new system associated with this + # atom + new_system_res = new_system_atom.residue + + # Next, we have to enumerate the other atoms in that residue to + # find mapped atoms + new_system_atom_set = {atom.index for atom in new_system_res.atoms} + + # Now, we find the subset of atoms that are mapped. These must be + # in the "core" category, since they are mapped and part of a + # changing residue + mapped_new_atom_indices = core_atoms_new_indices.intersection( + new_system_atom_set) + + # Now get the old indices of the above atoms so that we can find + # the appropriate residue in the old system for this we can use the + # new to old atom map + mapped_old_atom_indices = [self._new_to_old_map[atom_idx] for + atom_idx in mapped_new_atom_indices] + + # We can just take the first one--they all have the same residue + first_mapped_old_atom_index = mapped_old_atom_indices[0] + + # Get the atom object corresponding to this index from the hybrid + # (which is a deepcopy of the old) + mapped_hybrid_system_atom = hybrid_topology.atom( + first_mapped_old_atom_index) + + # Get the residue that is relevant to this atom + mapped_residue = mapped_hybrid_system_atom.residue + + # Add the atom using the mapped residue + added_atoms[new_particle_hybrid_idx] = hybrid_topology.add_atom( + new_system_atom.name, + new_system_atom.element, + mapped_residue) + + # Now loop through the bonds in the new system, and if the bond + # contains a unique new atom, then add it to the hybrid topology + for (atom1, atom2) in new_top.bonds: + at1_hybrid_idx = self._new_to_hybrid_map[atom1.index] + at2_hybrid_idx = self._new_to_hybrid_map[atom2.index] + + # If at least one atom is in the unique new class, we need to add + # it to the hybrid system + at1_uniq = at1_hybrid_idx in self._atom_classes['unique_new_atoms'] + at2_uniq = at2_hybrid_idx in self._atom_classes['unique_new_atoms'] + if at1_uniq or at2_uniq: + if at1_uniq: + atom1_to_bond = added_atoms[at1_hybrid_idx] + else: + old_idx = self._hybrid_to_old_map[at1_hybrid_idx] + atom1_to_bond = hybrid_topology.atom(old_idx) + + if at2_uniq: + atom2_to_bond = added_atoms[at2_hybrid_idx] + else: + old_idx = self._hybrid_to_old_map[at2_hybrid_idx] + atom2_to_bond = hybrid_topology.atom(old_idx) + + hybrid_topology.add_bond(atom1_to_bond, atom2_to_bond) + + return hybrid_topology + + def _create_hybrid_topology(self): + """ + Create a hybrid openmm.app.Topology from the input old and new + Topologies. + + Note + ---- + * This is not intended for parameterisation purposes, but instead + for system visualisation. + * Unlike the MDTraj Topology object, the residues of the alchemical + species are not squashed. + """ + + hybrid_top = app.Topology() + + # In the first instance, create a list of necessary atoms from + # both old & new Topologies + atom_list = [] + + for pidx in range(self.hybrid_system.getNumParticles()): + if pidx in self._hybrid_to_old_map: + idx = self._hybrid_to_old_map[pidx] + atom_list.append(list(self._old_topology.atoms())[idx]) + else: + idx = self._hybrid_to_new_map[pidx] + atom_list.append(list(self._new_topology.atoms())[idx]) + + # Now we loop over the atoms and add them in alongside chains & resids + + # Non ideal variables to track the previous set of residues & chains + # without having to constantly search backwards + prev_res = None + prev_chain = None + + for at in atom_list: + if at.residue.chain != prev_chain: + hybrid_chain = hybrid_top.addChain() + prev_chain = at.residue.chain + + if at.residue != prev_res: + hybrid_residue = hybrid_top.addResidue( + at.residue.name, hybrid_chain, at.residue.id + ) + prev_res = at.residue + + hybrid_atom = hybrid_top.addAtom( + at.name, at.element, hybrid_residue, at.id + ) + + # Next we deal with bonds + # First we add in all the old topology bonds + for bond in self._old_topology.bonds(): + at1 = self.old_to_hybrid_atom_map[bond.atom1.index] + at2 = self.old_to_hybrid_atom_map[bond.atom2.index] + + hybrid_top.addBond( + list(hybrid_top.atoms())[at1], + list(hybrid_top.atoms())[at2], + bond.type, bond.order, + ) + + # Finally we add in all the bonds from the unique atoms in the + # new Topology + for bond in self._new_topology.bonds(): + at1 = self.new_to_hybrid_atom_map[bond.atom1.index] + at2 = self.new_to_hybrid_atom_map[bond.atom2.index] + if ((at1 in self._atom_classes['unique_new_atoms']) or + (at2 in self._atom_classes['unique_new_atoms'])): + hybrid_top.addBond( + list(hybrid_top.atoms())[at1], + list(hybrid_top.atoms())[at2], + bond.type, bond.order, + ) + + return hybrid_top + + def old_positions(self, hybrid_positions): + """ + From input hybrid positions, get the positions which would correspond + to the old system + + Parameters + ---------- + hybrid_positions : [n, 3] np.ndarray or simtk.unit.Quantity + The positions of the hybrid system + + Returns + ------- + old_positions : [m, 3] np.ndarray with unit + The positions of the old system + """ + n_atoms_old = self._old_system.getNumParticles() + # making sure hybrid positions are simtk.unit.Quantity objects + if not isinstance(hybrid_positions, unit.Quantity): + hybrid_positions = unit.Quantity(hybrid_positions, + unit=unit.nanometer) + old_positions = unit.Quantity(np.zeros([n_atoms_old, 3]), + unit=unit.nanometer) + for idx in range(n_atoms_old): + hyb_idx = self._old_to_hybrid_map[idx] + old_positions[idx, :] = hybrid_positions[hyb_idx, :] + return old_positions + + def new_positions(self, hybrid_positions): + """ + From input hybrid positions, get the positions which could correspond + to the new system. + + Parameters + ---------- + hybrid_positions : [n, 3] np.ndarray or simtk.unit.Quantity + The positions of the hybrid system + + Returns + ------- + new_positions : [m, 3] np.ndarray with unit + The positions of the new system + """ + n_atoms_new = self._new_system.getNumParticles() + # making sure hybrid positions are simtk.unit.Quantity objects + if not isinstance(hybrid_positions, unit.Quantity): + hybrid_positions = unit.Quantity(hybrid_positions, + unit=unit.nanometer) + new_positions = unit.Quantity(np.zeros([n_atoms_new, 3]), + unit=unit.nanometer) + for idx in range(n_atoms_new): + hyb_idx = self._new_to_hybrid_map[idx] + new_positions[idx, :] = hybrid_positions[hyb_idx, :] + return new_positions + + @property + def hybrid_system(self): + """ + The hybrid system. + + Returns + ------- + hybrid_system : openmm.System + The system representing a hybrid between old and new topologies + """ + return self._hybrid_system + + @property + def new_to_hybrid_atom_map(self): + """ + Give a dictionary that maps new system atoms to the hybrid system. + + Returns + ------- + new_to_hybrid_atom_map : dict of {int, int} + The mapping of atoms from the new system to the hybrid + """ + return self._new_to_hybrid_map + + @property + def old_to_hybrid_atom_map(self): + """ + Give a dictionary that maps old system atoms to the hybrid system. + + Returns + ------- + old_to_hybrid_atom_map : dict of {int, int} + The mapping of atoms from the old system to the hybrid + """ + return self._old_to_hybrid_map + + @property + def hybrid_positions(self): + """ + The positions of the hybrid system. Dimensionality is (n_environment + + n_core + n_old_unique + n_new_unique). + The positions are assigned by first copying all the mapped positions + from the old system in, then copying the mapped positions from the new + system. + + Returns + ------- + hybrid_positions : [n, 3] Quantity nanometers + """ + return self._hybrid_positions + + @property + def hybrid_topology(self): + """ + An MDTraj hybrid topology for the purpose of writing out trajectories. + + Note that we do not expect this to be able to be parameterized by the + openmm forcefield class. + + Returns + ------- + hybrid_topology : mdtraj.Topology + """ + return self._hybrid_topology + + @property + def omm_hybrid_topology(self): + """ + An OpenMM format of the hybrid topology. Also cannot be used to + parameterize system, only to write out trajectories. + + Returns + ------- + hybrid_topology : simtk.openmm.app.Topology + + + .. versionchanged:: OpenFE 0.11 + Now returns a Topology directly constructed from the input + old / new Topologies, instead of trying to roundtrip an + mdtraj topology. + """ + return self._omm_hybrid_topology + + @property + def has_virtual_sites(self): + """ + Checks the hybrid system and tells us if we have any virtual sites. + + Returns + ------- + bool + ``True`` if there are virtual sites, otherwise ``False``. + """ + for ix in range(self._hybrid_system.getNumParticles()): + if self._hybrid_system.isVirtualSite(ix): + return True + return False + @property def initial_atom_indices(self): """ diff --git a/feflow/utils/lambda_protocol.py b/feflow/utils/lambda_protocol.py new file mode 100644 index 0000000..9522fd1 --- /dev/null +++ b/feflow/utils/lambda_protocol.py @@ -0,0 +1,334 @@ +# License: MIT + +import numpy as np +import warnings +import copy +from openmmtools.alchemy import AlchemicalState + + +class LambdaProtocol: + """Protocols for perturbing each of the component energy terms in alchemical + free energy simulations. + + TODO + ---- + * Class needs cleaning up and made more consistent + """ + + default_functions = {'lambda_sterics_core': + lambda x: x, + 'lambda_electrostatics_core': + lambda x: x, + 'lambda_sterics_insert': + lambda x: 2.0 * x if x < 0.5 else 1.0, + 'lambda_sterics_delete': + lambda x: 0.0 if x < 0.5 else 2.0 * (x - 0.5), + 'lambda_electrostatics_insert': + lambda x: 0.0 if x < 0.5 else 2.0 * (x - 0.5), + 'lambda_electrostatics_delete': + lambda x: 2.0 * x if x < 0.5 else 1.0, + 'lambda_bonds': + lambda x: x, + 'lambda_angles': + lambda x: x, + 'lambda_torsions': + lambda x: x + } + + # lambda components for each component, + # all run from 0 -> 1 following master lambda + def __init__(self, functions='default', windows=10, lambda_schedule=None): + """Instantiates lambda protocol to be used in a free energy + calculation. Can either be user defined, by passing in a dict, or using + one of the pregenerated sets by passing in a string 'default', 'namd' + or 'quarters' + + All protocols must begin and end at 0 and 1 respectively. Any energy + term not defined in `functions` dict will be set to the function in + `default_functions` + + Pre-coded options: + default : ele and LJ terms of the old system are turned off between + 0.0 -> 0.5 ele and LJ terms of the new system are turned on between + 0.5 -> 1.0 core terms treated linearly + + quarters : 0.25 of the protocol is used in turn to individually change + the (a) off old ele, (b) off old sterics, (c) on new sterics (d) on new + ele core terms treated linearly + + namd : follows the protocol outlined here: + https://pubs.acs.org/doi/full/10.1021/acs.jcim.9b00362# + Jiang, Wei, Christophe Chipot, and BenoƮt Roux. "Computing Relative + Binding Affinity of Ligands to Receptor: An Effective Hybrid + Single-Dual-Topology Free-Energy Perturbation Approach in NAMD." + Journal of chemical information and modeling 59.9 (2019): 3794-3802. + + ele-scaled : all terms are treated as in default, except for the old + and new ele these are scaled with lambda^0.5, so as to be linear in + energy, rather than lambda + + Parameters + ---------- + functions : str or dict + One of the predefined lambda protocols + ['default','namd','quarters'] or a dictionary. Default "default". + windows : int + Number of windows which this lambda schedule is intended to be used + with. This value is used to validate the lambda function. + lambda_schedule : list of floats + Schedule of lambda windows to be sampled. If ``None`` will default + to a linear spacing of windows as defined by + ``np.linspace(0. ,1. ,windows)``. Default ``None``. + + Attributes + ---------- + functions : dict + Lambda protocol to be used. + lambda_schedule : list + Schedule of windows to be sampled. + """ + self.functions = copy.deepcopy(functions) + + # set the lambda schedule + self.lambda_schedule = self._validate_schedule(lambda_schedule, + windows) + if lambda_schedule: + self.lambda_schedule = lambda_schedule + else: + self.lambda_schedule = np.linspace(0., 1., windows) + + if type(self.functions) == dict: + self.type = 'user-defined' + elif type(self.functions) == str: + self.functions = None # will be set later + self.type = functions + + if self.functions is None: + if self.type == 'default': + self.functions = copy.deepcopy( + LambdaProtocol.default_functions) + elif self.type == 'namd': + self.functions = { + 'lambda_sterics_core': lambda x: x, + 'lambda_electrostatics_core': lambda x: x, + 'lambda_sterics_insert': lambda x: (3. / 2.) * x if x < (2. / 3.) else 1.0, + 'lambda_sterics_delete': lambda x: 0.0 if x < (1. / 3.) else (x - (1. / 3.)) * (3. / 2.), + 'lambda_electrostatics_insert': lambda x: 0.0 if x < 0.5 else 2.0 * (x - 0.5), + 'lambda_electrostatics_delete': lambda x: 2.0 * x if x < 0.5 else 1.0, + 'lambda_bonds': lambda x: x, + 'lambda_angles': lambda x: x, + 'lambda_torsions': lambda x: x + } + elif self.type == 'quarters': + self.functions = { + 'lambda_sterics_core': lambda x: x, + 'lambda_electrostatics_core': lambda x: x, + 'lambda_sterics_insert': lambda x: 0. if x < 0.5 else 1 if x > 0.75 else 4 * (x - 0.5), + 'lambda_sterics_delete': lambda x: 0. if x < 0.25 else 1 if x > 0.5 else 4 * (x - 0.25), + 'lambda_electrostatics_insert': lambda x: 0. if x < 0.75 else 4 * (x - 0.75), + 'lambda_electrostatics_delete': lambda x: 4.0 * x if x < 0.25 else 1.0, + 'lambda_bonds': lambda x: x, + 'lambda_angles': lambda x: x, + 'lambda_torsions': lambda x: x + } + elif self.type == 'ele-scaled': + self.functions = { + 'lambda_electrostatics_insert': lambda x: 0.0 if x < 0.5 else ((2*(x-0.5))**0.5), + 'lambda_electrostatics_delete': lambda x: (2*x)**2 if x < 0.5 else 1.0 + } + elif self.type == 'user-defined': + self.functions = functions + else: + errmsg = f"LambdaProtocol type : {self.type} not recognised " + raise ValueError(errmsg) + + self._validate_functions(n=windows) + self._check_for_naked_charges() + + @staticmethod + def _validate_schedule(schedule, windows): + """ + Checks that the input lambda schedule is valid. + + Rules are: + - Must begin at 0 and end at 1 + - Must be monotonically increasing + + Parameters + ---------- + schedule : list of floats + The lambda schedule. If ``None`` the method returns + ``np.linspace(0. ,1. ,windows)``. + windows : int + Number of windows to be sampled. + + Returns + ------- + schedule : list of floats + A valid lambda schedule. + """ + if schedule is None: + return np.linspace(0., 1., windows) + + # Check end states + if schedule[0] != 0 or schedule[-1] != 1: + errmsg = ("end and start lambda windows must be lambda 0 and 1 " + "respectively") + raise ValueError(errmsg) + + # Check monotonically increasing + difference = np.diff(schedule) + + if not all(i >= 0. for i in difference): + errmsg = "The lambda schedule is not monotonic" + raise ValueError(errmsg) + + return schedule + + def _validate_functions(self, n=10): + """Ensures that all the lambda functions adhere to the rules: + - must begin at 0. + - must finish at 1. + - must be monotonically increasing + + Parameters + ---------- + n : int, default 10 + number of grid points used to check monotonicity + """ + # the individual lambda functions that must be defined for + required_functions = list(LambdaProtocol.default_functions.keys()) + + for function in required_functions: + if function not in self.functions: + # IA switched from warn to error here + errmsg = (f"function {function} is missing from " + "self.lambda_functions.") + raise ValueError(errmsg) + + # Check that the function starts and ends at 0 and 1 respectively + if self.functions[function](0) != 0: + raise ValueError("lambda functions must start at 0") + if self.functions[function](1) != 1: + raise ValueError("lambda functions must end at 1") + + # now validatate that it's monotonic + global_lambda = np.linspace(0., 1., n) + sub_lambda = [self.functions[function](lam) for + lam in global_lambda] + difference = np.diff(sub_lambda) + + if not all(i >= 0. for i in difference): + wmsg = (f"The function {function} is not monotonic as " + "typically expected.") + warnings.warn(wmsg) + + def _check_for_naked_charges(self): + """ + Checks that there are no cases where atoms have charge but no sterics. + + This avoids issues with singularities and/or excessive forces near + the end states (even when using softcore electrostatics). + """ + global_lambda = self.lambda_schedule + + def check_overlap(ele, sterics, global_lambda, functions, endstate): + for lam in global_lambda: + ele_val = functions[ele](lam) + ster_val = functions[sterics](lam) + # if charge > 0 and sterics == 0 raise error + if ele_val != endstate and ster_val == endstate: + errmsg = ("There are states along this lambda schedule " + "where there are atoms with charges but no LJ " + f"interactions: {lam} {ele_val} {ster_val}") + raise ValueError(errmsg) + + # checking unique new terms first + ele = 'lambda_electrostatics_insert' + sterics = 'lambda_sterics_insert' + check_overlap(ele, sterics, global_lambda, self.functions, endstate=0) + + # checking unique old terms now + ele = 'lambda_electrostatics_delete' + sterics = 'lambda_sterics_delete' + check_overlap(ele, sterics, global_lambda, self.functions, endstate=1) + + def get_functions(self): + return self.functions + + def plot_functions(self, lambda_schedule=None): + """ + Plot the function for ease of visualisation. + + Parameters + ---------- + shedule : np.ndarray + The lambda schedule to plot the function along. If ``None`` plot + the one stored within this class. Default ``None``. + """ + import matplotlib.pyplot as plt + + fig = plt.figure(figsize=(10, 5)) + + global_lambda = lambda_schedule if lambda_schedule else self.lambda_schedule + + for f in self.functions: + plt.plot(global_lambda, + [self.functions[f](lam) for lam in global_lambda], + alpha=0.5, label=f) + + plt.xlabel('global lambda') + plt.ylabel('sub-lambda') + plt.legend() + plt.show() + + +class RelativeAlchemicalState(AlchemicalState): + """ + Relative AlchemicalState to handle all lambda parameters required for + relative perturbations + lambda = 1 refers to ON, i.e. fully interacting while + lambda = 0 refers to OFF, i.e. non-interacting with the system + all lambda functions will follow from 0 -> 1 following the master lambda + lambda*core parameters perturb linearly + lambda_sterics_insert and lambda_electrostatics_delete perturb in the + first half of the protocol 0 -> 0.5 + lambda_sterics_delete and lambda_electrostatics_insert perturb in the + second half of the protocol 0.5 -> 1 + + Attributes + ---------- + lambda_sterics_core + lambda_electrostatics_core + lambda_sterics_insert + lambda_sterics_delete + lambda_electrostatics_insert + lambda_electrostatics_delete + """ + + class _LambdaParameter(AlchemicalState._LambdaParameter): + pass + + lambda_sterics_core = _LambdaParameter('lambda_sterics_core') + lambda_electrostatics_core = _LambdaParameter('lambda_electrostatics_core') + lambda_sterics_insert = _LambdaParameter('lambda_sterics_insert') + lambda_sterics_delete = _LambdaParameter('lambda_sterics_delete') + lambda_electrostatics_insert = _LambdaParameter( + 'lambda_electrostatics_insert') + lambda_electrostatics_delete = _LambdaParameter( + 'lambda_electrostatics_delete') + + def set_alchemical_parameters(self, global_lambda, + lambda_protocol=LambdaProtocol()): + """Set each lambda value according to the lambda_functions protocol. + The undefined parameters (i.e. those being set to None) remain + undefined. + Parameters + ---------- + lambda_value : float + The new value for all defined parameters. + """ + self.global_lambda = global_lambda + for parameter_name in lambda_protocol.functions: + lambda_value = lambda_protocol.functions[parameter_name](global_lambda) + setattr(self, parameter_name, lambda_value)