From ede08ca706ddc8647d78f2efe83d6e18a9223f8b Mon Sep 17 00:00:00 2001 From: SimonBoothroyd Date: Sun, 24 Nov 2024 10:54:07 +0000 Subject: [PATCH] Fix solutes with v-sites (#76) --- absolv/fep.py | 60 +--------- absolv/runner.py | 172 +++++++++++++++++++++++++--- absolv/tests/test_fep.py | 22 ---- absolv/tests/test_runner.py | 153 ++++++++++++++++++++++++- absolv/tests/utils/test_openmm.py | 2 +- absolv/tests/utils/test_topology.py | 21 +--- absolv/utils/openmm.py | 18 ++- absolv/utils/topology.py | 22 +--- docs/user-guide/overview.md | 2 +- regression/run.py | 8 +- 10 files changed, 331 insertions(+), 149 deletions(-) diff --git a/absolv/fep.py b/absolv/fep.py index f004ff5..1e14f85 100644 --- a/absolv/fep.py +++ b/absolv/fep.py @@ -1,4 +1,5 @@ """Prepare OpenMM systems for FEP calculations.""" + import copy import itertools @@ -22,54 +23,6 @@ ) -def _find_v_sites( - system: openmm.System, atom_indices: list[set[int]] -) -> list[set[int]]: - """Finds any virtual sites in the system and ensures their indices get appended - to the atom index list. - - Args: - system: The system that may contain v-sites. - atom_indices: A list of per-molecule atom indices - - Returns: - A list of the per molecule **particle** indices. 
- """ - - atom_to_molecule_idx = { - atom_idx: i for i, indices in enumerate(atom_indices) for atom_idx in indices - } - - particle_to_atom_idx = {} - atom_idx = 0 - - for particle_idx in range(system.getNumParticles()): - if system.isVirtualSite(particle_idx): - continue - - particle_to_atom_idx[particle_idx] = atom_idx - atom_idx += 1 - - atom_idx = 0 - - remapped_atom_indices: list[set[int]] = [set() for _ in range(len(atom_indices))] - - for particle_idx in range(system.getNumParticles()): - if not system.isVirtualSite(particle_idx): - molecule_idx = atom_to_molecule_idx[atom_idx] - atom_idx += 1 - - else: - v_site = system.getVirtualSite(particle_idx) - parent_atom_idx = particle_to_atom_idx[v_site.getParticle(0)] - - molecule_idx = atom_to_molecule_idx[parent_atom_idx] - - remapped_atom_indices[molecule_idx].add(particle_idx) - - return remapped_atom_indices - - def _find_nonbonded_forces( system: openmm.System, ) -> tuple[ @@ -468,7 +421,7 @@ def apply_fep( system: The chemical system to generate the alchemical system from alchemical_indices: The atom indices corresponding to each molecule that should be alchemically transformable. The atom indices **must** - correspond to **all** atoms in each molecule as alchemically + correspond to **all** atoms / v-sites in each molecule as alchemically transforming part of a molecule is not supported. persistent_indices: The atom indices corresponding to each molecule that should **not** be alchemically transformable. @@ -481,15 +434,6 @@ def apply_fep( system = copy.deepcopy(system) - # Make sure we track v-sites attached to any solutes that may be alchemically - # turned off. We do this as a post-process step as the OpenFF toolkit does not - # currently expose a clean way to access this information. 
- atom_indices = alchemical_indices + persistent_indices - atom_indices = _find_v_sites(system, atom_indices) - - alchemical_indices = atom_indices[: len(alchemical_indices)] - persistent_indices = atom_indices[len(alchemical_indices) :] - ( nonbonded_force, custom_nonbonded_force, diff --git a/absolv/runner.py b/absolv/runner.py index a6a6dad..5a55f7b 100644 --- a/absolv/runner.py +++ b/absolv/runner.py @@ -1,5 +1,6 @@ """Run calculations defined by a config.""" +import collections import functools import multiprocessing import pathlib @@ -17,6 +18,7 @@ import openff.toolkit import openff.utilities import openmm +import openmm.app import openmm.unit import pymbar import tqdm @@ -35,12 +37,152 @@ class PreparedSystem(typing.NamedTuple): system: openmm.System """The alchemically modified OpenMM system.""" - topology: openff.toolkit.Topology - """The OpenFF topology with any box vectors set.""" + topology: openmm.app.Topology + """The OpenMM topology with any box vectors set.""" coords: openmm.unit.Quantity """The coordinates of the system.""" +def _rebuild_topology( + orig_top: openff.toolkit.Topology, + orig_coords: openmm.unit.Quantity, + system: openmm.System, +) -> tuple[openmm.app.Topology, openmm.unit.Quantity, list[set[int]]]: + """Rebuild the topology to also include virtual sites.""" + atom_idx_to_residue_idx = {} + atom_idx = 0 + + for residue_idx, molecule in enumerate(orig_top.molecules): + for _ in molecule.atoms: + atom_idx_to_residue_idx[atom_idx] = residue_idx + atom_idx += 1 + + particle_idx_to_atom_idx = {} + atom_idx = 0 + + for particle_idx in range(system.getNumParticles()): + if system.isVirtualSite(particle_idx): + continue + + particle_idx_to_atom_idx[particle_idx] = atom_idx + atom_idx += 1 + + atoms_off = [*orig_top.atoms] + particles = [] + + for particle_idx in range(system.getNumParticles()): + if system.isVirtualSite(particle_idx): + v_site = system.getVirtualSite(particle_idx) + + parent_idxs = { + 
particle_idx_to_atom_idx[v_site.getParticle(i)] + for i in range(v_site.getNumParticles()) + } + parent_residue = atom_idx_to_residue_idx[next(iter(parent_idxs))] + + particles.append((-1, parent_residue)) + continue + + atom_idx = particle_idx_to_atom_idx[particle_idx] + residue_idx = atom_idx_to_residue_idx[atom_idx] + + particles.append((atoms_off[atom_idx].atomic_number, residue_idx)) + + topology = openmm.app.Topology() + + if orig_top.box_vectors is not None: + topology.setPeriodicBoxVectors(orig_top.box_vectors.to_openmm()) + + chain = topology.addChain() + + atom_counts_per_residue = collections.defaultdict( + lambda: collections.defaultdict(int) + ) + atoms = [] + + last_residue_idx = -1 + residue = None + + residue_to_particle_idx = collections.defaultdict(list) + + for particle_idx, (atomic_num, residue_idx) in enumerate(particles): + if residue_idx != last_residue_idx: + last_residue_idx = residue_idx + residue = topology.addResidue("UNK", chain) + + element = ( + None if atomic_num < 0 else openmm.app.Element.getByAtomicNumber(atomic_num) + ) + symbol = "X" if element is None else element.symbol + + atom_counts_per_residue[residue_idx][atomic_num] += 1 + atom = topology.addAtom( + f"{symbol}{atom_counts_per_residue[residue_idx][atomic_num]}".ljust(3, "x"), + element, + residue, + ) + atoms.append(atom) + + residue_to_particle_idx[residue_idx].append(particle_idx) + + _rename_residues(topology) + + atom_idx_to_particle_idx = {j: i for i, j in particle_idx_to_atom_idx.items()} + + for bond in orig_top.bonds: + # Bond.atom1_index is molecule-local; use topology-wide indices instead. + atom_1 = atoms[atom_idx_to_particle_idx[orig_top.atom_index(bond.atom1)]] + atom_2 = atoms[atom_idx_to_particle_idx[orig_top.atom_index(bond.atom2)]] + + if atom_1.residue.name == "HOH": + continue + topology.addBond(atom_1, atom_2) + + coords_full = [] + + for particle_idx in range(system.getNumParticles()): + if particle_idx in particle_idx_to_atom_idx: + coords_i = orig_coords[particle_idx_to_atom_idx[particle_idx]] 
coords_full.append(coords_i.value_in_unit(openmm.unit.angstrom)) + else: + coords_full.append(numpy.zeros((1, 3))) + + coords_full = numpy.vstack(coords_full) * openmm.unit.angstrom + + if len(orig_coords) != len(coords_full): + context = openmm.Context(system, openmm.VerletIntegrator(1.0)) + context.setPositions(coords_full) + context.computeVirtualSites() + + coords_full = context.getState(getPositions=True).getPositions(asNumpy=True) + + residues = [ + set(residue_to_particle_idx[residue_idx]) + for residue_idx in range(len(residue_to_particle_idx)) + ] + + return topology, coords_full, residues + + +def _rename_residues(topology: openmm.app.Topology): + """Attempts to assign standard residue names to known residues""" + + for residue in topology.residues(): + symbols = sorted( + ( + atom.element.symbol + for atom in residue.atoms() + if atom.element is not None + ) + ) + + if symbols == ["H", "H", "O"]: + residue.name = "HOH" + # Skip v-site particles (element is None) that share the water residue. + for i, atom in enumerate(a for a in residue.atoms() if a.element is not None): + atom.name = "OW" if atom.element.symbol == "O" else f"HW{i}" + + def _setup_solvent( solvent_idx: typing.Literal["solvent-a", "solvent-b"], components: list[tuple[str, int]], @@ -67,19 +209,21 @@ def _setup_solvent( is_vacuum = n_solvent_molecules == 0 - topology, coords = absolv.setup.setup_system(components) - topology.box_vectors = None if is_vacuum else topology.box_vectors + topology_off, coords = absolv.setup.setup_system(components) + topology_off.box_vectors = None if is_vacuum else topology_off.box_vectors + + if isinstance(force_field, openff.toolkit.ForceField): + original_system = force_field.create_openmm_system(topology_off) + else: + original_system: openmm.System = force_field(topology_off, coords, solvent_idx) - atom_indices = absolv.utils.topology.topology_to_atom_indices(topology) + topology, coords, atom_indices = _rebuild_topology( + topology_off, coords, original_system + ) alchemical_indices = atom_indices[:n_solute_molecules] persistent_indices = 
atom_indices[n_solute_molecules:] - if isinstance(force_field, openff.toolkit.ForceField): - original_system = force_field.create_openmm_system(topology) - else: - original_system: openmm.System = force_field(topology, coords, solvent_idx) - alchemical_system = absolv.fep.apply_fep( original_system, alchemical_indices, @@ -196,7 +340,7 @@ def _run_eq_phase( """ platform = ( femto.md.constants.OpenMMPlatform.REFERENCE - if prepared_system.topology.box_vectors is None + if prepared_system.topology.getPeriodicBoxVectors() is None else platform ) @@ -312,7 +456,7 @@ def _run_phase_end_states( ): platform = ( femto.md.constants.OpenMMPlatform.REFERENCE - if prepared_system.topology.box_vectors is None + if prepared_system.topology.getPeriodicBoxVectors() is None else platform ) @@ -363,11 +507,11 @@ def _run_switching( ): platform = ( femto.md.constants.OpenMMPlatform.REFERENCE - if prepared_system.topology.box_vectors is None + if prepared_system.topology.getPeriodicBoxVectors() is None else platform ) - mdtraj_topology = mdtraj.Topology.from_openmm(prepared_system.topology.to_openmm()) + mdtraj_topology = mdtraj.Topology.from_openmm(prepared_system.topology) trajectory_0 = mdtraj.load_dcd(str(output_dir / "state-0.dcd"), mdtraj_topology) trajectory_1 = mdtraj.load_dcd(str(output_dir / "state-1.dcd"), mdtraj_topology) diff --git a/absolv/tests/test_fep.py b/absolv/tests/test_fep.py index a4786dc..2ccb36f 100644 --- a/absolv/tests/test_fep.py +++ b/absolv/tests/test_fep.py @@ -13,33 +13,11 @@ _add_electrostatics_lambda, _add_lj_vdw_lambda, _find_nonbonded_forces, - _find_v_sites, apply_fep, ) from absolv.tests import is_close -def test_find_v_sites(): - """Ensure that v-sites are correctly detected from an OMM system and assigned - to the right parent molecule.""" - - # Construct a mock system of V A A A V A A where (0, 5, 6), (3,), (4, 1, 2) - # are the core molecules. 
- system = openmm.System() - - for _ in range(7): - system.addParticle(1.0) - - system.setVirtualSite(0, openmm.TwoParticleAverageSite(5, 6, 0.5, 0.5)) - system.setVirtualSite(4, openmm.TwoParticleAverageSite(1, 2, 0.5, 0.5)) - - atom_indices = [{0, 1}, {2}, {3, 4}] - - particle_indices = _find_v_sites(system, atom_indices) - - assert particle_indices == [{1, 2, 4}, {3}, {0, 5, 6}] - - def test_find_nonbonded_forces_lj_only(aq_nacl_lj_system): ( nonbonded_force, diff --git a/absolv/tests/test_runner.py b/absolv/tests/test_runner.py index 7c30193..9f78448 100644 --- a/absolv/tests/test_runner.py +++ b/absolv/tests/test_runner.py @@ -1,5 +1,7 @@ import femto.md.constants +import numpy import openff.toolkit +import openff.units import openmm.unit import pytest @@ -58,6 +60,153 @@ ) +def test_rebuild_topology(): + ff = openff.toolkit.ForceField("tip4p_fb.offxml", "openff-2.0.0.offxml") + + v_site_handler = ff.get_parameter_handler("VirtualSites") + v_site_handler.add_parameter( + { + "type": "DivalentLonePair", + "match": "once", + "smirks": "[*:2][#7:1][*:3]", + "distance": 0.4 * openff.units.unit.angstrom, + "epsilon": 0.0 * openff.units.unit.kilojoule_per_mole, + "sigma": 0.1 * openff.units.unit.nanometer, + "outOfPlaneAngle": 0.0 * openff.units.unit.degree, + "charge_increment1": 0.0 * openff.units.unit.elementary_charge, + "charge_increment2": 0.0 * openff.units.unit.elementary_charge, + "charge_increment3": 0.0 * openff.units.unit.elementary_charge, + } + ) + + solute = openff.toolkit.Molecule.from_smiles("c1ccncc1") + solute.generate_conformers(n_conformers=1) + solvent = openff.toolkit.Molecule.from_smiles("O") + solvent.generate_conformers(n_conformers=1) + + orig_coords = ( + numpy.vstack( + [ + solute.conformers[0].m_as("angstrom"), + solvent.conformers[0].m_as("angstrom") + numpy.array([10.0, 0.0, 0.0]), + solvent.conformers[0].m_as("angstrom") + numpy.array([20.0, 0.0, 0.0]), + ] + ) + * openmm.unit.angstrom + ) + + expected_box_vectors = numpy.eye(3) * 
30.0 + + orig_top = openff.toolkit.topology.Topology.from_molecules( + [solute, solvent, solvent] + ) + orig_top.box_vectors = expected_box_vectors * openmm.unit.angstrom + + system = ff.create_openmm_system(orig_top) + + n_v_sites = sum( + 1 for i in range(system.getNumParticles()) if system.isVirtualSite(i) + ) + assert n_v_sites == 3 + + top, coords, idxs = absolv.runner._rebuild_topology(orig_top, orig_coords, system) + + found_atoms = [ + ( + atom.name, + atom.element.symbol if atom.element is not None else None, + atom.residue.index, + atom.residue.name, + ) + for atom in top.atoms() + ] + expected_atoms = [ + ("C1x", "C", 0, "UNK"), + ("C2x", "C", 0, "UNK"), + ("C3x", "C", 0, "UNK"), + ("N1x", "N", 0, "UNK"), + ("C4x", "C", 0, "UNK"), + ("C5x", "C", 0, "UNK"), + ("H1x", "H", 0, "UNK"), + ("H2x", "H", 0, "UNK"), + ("H3x", "H", 0, "UNK"), + ("H4x", "H", 0, "UNK"), + ("H5x", "H", 0, "UNK"), + ("OW", "O", 1, "HOH"), + ("HW1", "H", 1, "HOH"), + ("HW2", "H", 1, "HOH"), + ("OW", "O", 2, "HOH"), + ("HW1", "H", 2, "HOH"), + ("HW2", "H", 2, "HOH"), + ("X1x", None, 3, "UNK"), + ("X1x", None, 4, "UNK"), + ("X1x", None, 5, "UNK"), + ] + + assert found_atoms == expected_atoms + + expected_coords = numpy.array( + [ + [0.00241, 0.10097, -0.05663], + [-0.11673, 0.03377, -0.03801], + [-0.11801, -0.08778, 0.03043], + [-0.00041, -0.1375, 0.07759], + [0.112, -0.068, 0.05657], + [0.12301, 0.0524, -0.00964], + [4e-05, 0.19505, -0.11016], + [-0.2116, 0.07095, -0.0744], + [-0.20828, -0.14529, 0.04827], + [0.19995, -0.11721, 0.09863], + [0.21761, 0.10263, -0.02266], + [0.99992, 0.03664, 0.0], + [0.91877, -0.01835, 0.0], + [1.08131, -0.01829, 0.0], + [1.99992, 0.03664, 0.0], + [1.91877, -0.01835, 0.0], + [2.08131, -0.01829, 0.0], + [0.0011, -0.1722, 0.09743], + [0.99994, 0.02611, 0.0], + [1.99994, 0.02611, 0.0], + ] + ) # manually visually inspected + + assert coords.shape == expected_coords.shape + assert numpy.allclose(coords, expected_coords, atol=1.0e-5) + + box_vectors = 
top.getPeriodicBoxVectors().value_in_unit(openmm.unit.angstrom) + box_vectors = numpy.array(box_vectors) + + assert numpy.allclose(box_vectors, expected_box_vectors) + + expected_bonds = [ + ("C1x", "C2x"), + ("C2x", "C3x"), + ("C3x", "N1x"), + ("N1x", "C4x"), + ("C4x", "C5x"), + ("C5x", "C1x"), + ("C1x", "H1x"), + ("C2x", "H2x"), + ("C3x", "H3x"), + ("C4x", "H4x"), + ("C5x", "H5x"), + # Water O-H bonds are absent on purpose: _rebuild_topology skips every + # bond whose first atom sits in a residue that was renamed to HOH, so + # only the 11 solute bonds (looked up by topology-wide atom index) + # should be present in the rebuilt topology. + ] + + actual_bonds = [(bond.atom1.name, bond.atom2.name) for bond in top.bonds()] + assert actual_bonds == expected_bonds + + expected_idxs = [ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 17}, + {11, 12, 13, 18}, + {14, 15, 16, 19}, + ] + assert idxs == expected_idxs + + def test_setup_fn(): system = absolv.config.System( solutes={"[Na+]": 1, "[Cl-]": 1}, solvent_a=None, solvent_b={"O": 1} @@ -70,8 +219,8 @@ def test_setup_fn(): assert prepared_system_a.system.getNumParticles() == 2 assert prepared_system_b.system.getNumParticles() == 5 - assert prepared_system_a.topology.box_vectors is None - assert prepared_system_b.topology.box_vectors is not None + assert prepared_system_a.topology.getPeriodicBoxVectors() is None + assert prepared_system_b.topology.getPeriodicBoxVectors() is not None @pytest.mark.parametrize( diff --git a/absolv/tests/utils/test_openmm.py b/absolv/tests/utils/test_openmm.py index 20a0152..7ee7723 100644 --- a/absolv/tests/utils/test_openmm.py +++ b/absolv/tests/utils/test_openmm.py @@ -45,7 +45,7 @@ def test_create_simulation(): simulation = create_simulation( system, - topology, + topology.to_openmm(), expected_coords, integrator, femto.md.constants.OpenMMPlatform.REFERENCE, diff --git a/absolv/tests/utils/test_topology.py b/absolv/tests/utils/test_topology.py index facc168..d26457e 100644 --- a/absolv/tests/utils/test_topology.py +++ b/absolv/tests/utils/test_topology.py @@ -1,7 +1,7 @@ import openff.toolkit import pytest -from absolv.utils.topology import topology_to_atom_indices, 
topology_to_components +from absolv.utils.topology import topology_to_components @pytest.mark.parametrize("n_counts", [[3, 1, 2], [3, 1, 1]]) @@ -19,22 +19,3 @@ def test_topology_to_components(n_counts): ("[H][C]([H])([H])[H]", n_counts[1]), ("[H][O][H]", n_counts[2]), ] - - -def test_topology_to_atom_indices(): - topology = openff.toolkit.Topology.from_molecules( - [openff.toolkit.Molecule.from_smiles("O")] * 1 - + [openff.toolkit.Molecule.from_smiles("C")] * 2 - + [openff.toolkit.Molecule.from_smiles("O")] * 3 - ) - - atom_indices = topology_to_atom_indices(topology) - - assert atom_indices == [ - {0, 1, 2}, - {3, 4, 5, 6, 7}, - {8, 9, 10, 11, 12}, - {13, 14, 15}, - {16, 17, 18}, - {19, 20, 21}, - ] diff --git a/absolv/utils/openmm.py b/absolv/utils/openmm.py index 916e3ca..4c06a91 100644 --- a/absolv/utils/openmm.py +++ b/absolv/utils/openmm.py @@ -1,4 +1,5 @@ """Utilities to manipulate OpenMM objects.""" + import typing import femto.md.constants @@ -46,7 +47,7 @@ def add_barostat( def create_simulation( system: openmm.System, - topology: openff.toolkit.Topology, + topology: openmm.app.Topology, coords: openmm.unit.Quantity, integrator: openmm.Integrator, platform: femto.md.constants.OpenMMPlatform, @@ -69,17 +70,20 @@ def create_simulation( ) platform = openmm.Platform.getPlatformByName(platform) - if topology.box_vectors is not None: - system.setDefaultPeriodicBoxVectors(*topology.box_vectors.to_openmm()) + is_periodic = topology.getPeriodicBoxVectors() is not None + + if is_periodic: + system.setDefaultPeriodicBoxVectors(*topology.getPeriodicBoxVectors()) simulation = openmm.app.Simulation( - topology.to_openmm(), system, integrator, platform, platform_properties + topology, system, integrator, platform, platform_properties ) - if topology.box_vectors is not None: - simulation.context.setPeriodicBoxVectors(*topology.box_vectors.to_openmm()) + if is_periodic: + simulation.context.setPeriodicBoxVectors(*topology.getPeriodicBoxVectors()) 
simulation.context.setPositions(coords) + simulation.context.computeVirtualSites() simulation.context.setVelocitiesToTemperature(integrator.getTemperature()) return simulation @@ -192,4 +196,6 @@ def extract_frame(trajectory: mdtraj.Trajectory, idx: int) -> openmm.State: if trajectory.unitcell_vectors is not None: context.setPeriodicBoxVectors(*trajectory.openmm_boxes(idx)) + context.computeVirtualSites() + return context.getState(getPositions=True) diff --git a/absolv/utils/topology.py b/absolv/utils/topology.py index 8ba8ce6..5377133 100644 --- a/absolv/utils/topology.py +++ b/absolv/utils/topology.py @@ -1,4 +1,5 @@ """Utilities for manipulating OpenFF topology objects.""" + import openff.toolkit @@ -41,24 +42,3 @@ def topology_to_components(topology: openff.toolkit.Topology) -> list[tuple[str, components.append((current_smiles, current_count)) return components - - -def topology_to_atom_indices(topology: openff.toolkit.Topology) -> list[set[int]]: - """A helper method for extracting the sets of atom indices associated with each - molecule in a topology. - - Args: - topology: The topology to extract the atom indices from. - - Returns: - The set of atoms indices associated with each molecule in the topology. 
- """ - - atom_indices: list[set[int]] = [] - current_atom_idx = 0 - - for molecule in topology.molecules: - atom_indices.append({i + current_atom_idx for i in range(molecule.n_atoms)}) - current_atom_idx += molecule.n_atoms - - return atom_indices diff --git a/docs/user-guide/overview.md b/docs/user-guide/overview.md index 01bf544..1421662 100644 --- a/docs/user-guide/overview.md +++ b/docs/user-guide/overview.md @@ -25,7 +25,7 @@ specified: ```python import openmm.unit -temperature=298.15 * openmm.unit.kelvin, +temperature=298.15 * openmm.unit.kelvin pressure=1.0 * openmm.unit.atmosphere ``` diff --git a/regression/run.py b/regression/run.py index 76a9960..35e1fae 100644 --- a/regression/run.py +++ b/regression/run.py @@ -1,8 +1,8 @@ +import datetime import logging import pathlib import tempfile import urllib.request -import datetime import click import femto.md.config @@ -16,8 +16,8 @@ from rdkit import Chem import absolv.config -import absolv.utils.openmm import absolv.runner +import absolv.utils.openmm DEFAULT_TEMPERATURE = 298.15 * openmm.unit.kelvin DEFAULT_PRESSURE = 1.0 * openmm.unit.atmosphere @@ -170,11 +170,11 @@ def run_replica( femto.md.system.apply_hmr( prepared_system_a.system, - parmed.openmm.load_topology(prepared_system_a.topology.to_openmm()), + parmed.openmm.load_topology(prepared_system_a.topology), ) femto.md.system.apply_hmr( prepared_system_b.system, - parmed.openmm.load_topology(prepared_system_a.topology.to_openmm()), + parmed.openmm.load_topology(prepared_system_b.topology), ) if method == "neq":