Skip to content

Commit

Permalink
Fix solutes with v-sites (#76)
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd authored Nov 24, 2024
1 parent 697a106 commit ede08ca
Show file tree
Hide file tree
Showing 10 changed files with 331 additions and 149 deletions.
60 changes: 2 additions & 58 deletions absolv/fep.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Prepare OpenMM systems for FEP calculations."""

import copy
import itertools

Expand All @@ -22,54 +23,6 @@
)


def _find_v_sites(
    system: openmm.System, atom_indices: list[set[int]]
) -> list[set[int]]:
    """Expand per-molecule atom indices into per-molecule particle indices.

    Atoms are numbered by counting the non-virtual particles of the system in
    order. Each virtual site is assigned to the molecule that owns the first
    particle it is defined relative to.

    Args:
        system: The system that may contain v-sites.
        atom_indices: A list of per-molecule atom indices.

    Returns:
        A list of the per molecule **particle** indices, in the same molecule
        order as ``atom_indices``.
    """

    # Which molecule each atom belongs to.
    molecule_of_atom = {}

    for molecule_idx, molecule_atoms in enumerate(atom_indices):
        for atom_idx in molecule_atoms:
            molecule_of_atom[atom_idx] = molecule_idx

    # Atoms are the non-virtual particles, numbered in particle order.
    n_particles = system.getNumParticles()

    real_particles = [
        particle_idx
        for particle_idx in range(n_particles)
        if not system.isVirtualSite(particle_idx)
    ]
    atom_of_particle = {
        particle_idx: atom_idx for atom_idx, particle_idx in enumerate(real_particles)
    }

    per_molecule: list[set[int]] = [set() for _ in atom_indices]

    for particle_idx in range(n_particles):
        if particle_idx in atom_of_particle:
            owner_atom_idx = atom_of_particle[particle_idx]
        else:
            # A v-site follows the molecule of its first parent particle.
            parent_particle_idx = system.getVirtualSite(particle_idx).getParticle(0)
            owner_atom_idx = atom_of_particle[parent_particle_idx]

        per_molecule[molecule_of_atom[owner_atom_idx]].add(particle_idx)

    return per_molecule


def _find_nonbonded_forces(
system: openmm.System,
) -> tuple[
Expand Down Expand Up @@ -468,7 +421,7 @@ def apply_fep(
system: The chemical system to generate the alchemical system from
alchemical_indices: The atom indices corresponding to each molecule that
should be alchemically transformable. The atom indices **must**
correspond to **all** atoms in each molecule as alchemically
correspond to **all** atoms / v-sites in each molecule as alchemically
transforming part of a molecule is not supported.
persistent_indices: The atom indices corresponding to each molecule that
should **not** be alchemically transformable.
Expand All @@ -481,15 +434,6 @@ def apply_fep(

system = copy.deepcopy(system)

# Make sure we track v-sites attached to any solutes that may be alchemically
# turned off. We do this as a post-process step as the OpenFF toolkit does not
# currently expose a clean way to access this information.
atom_indices = alchemical_indices + persistent_indices
atom_indices = _find_v_sites(system, atom_indices)

alchemical_indices = atom_indices[: len(alchemical_indices)]
persistent_indices = atom_indices[len(alchemical_indices) :]

(
nonbonded_force,
custom_nonbonded_force,
Expand Down
172 changes: 158 additions & 14 deletions absolv/runner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Run calculations defined by a config."""

import collections
import functools
import multiprocessing
import pathlib
Expand All @@ -17,6 +18,7 @@
import openff.toolkit
import openff.utilities
import openmm
import openmm.app
import openmm.unit
import pymbar
import tqdm
Expand All @@ -35,12 +37,152 @@ class PreparedSystem(typing.NamedTuple):
system: openmm.System
"""The alchemically modified OpenMM system."""

topology: openff.toolkit.Topology
"""The OpenFF topology with any box vectors set."""
topology: openmm.app.Topology
"""The OpenMM topology with any box vectors set."""
coords: openmm.unit.Quantity
"""The coordinates of the system."""


def _rebuild_topology(
    orig_top: openff.toolkit.Topology,
    orig_coords: openmm.unit.Quantity,
    system: openmm.System,
) -> tuple[openmm.app.Topology, openmm.unit.Quantity, list[set[int]]]:
    """Rebuild the topology to also include virtual sites.

    Args:
        orig_top: The OpenFF topology, which only describes atoms.
        orig_coords: The coordinates of the atoms in ``orig_top``.
        system: The OpenMM system, whose particles may include virtual sites.

    Returns:
        An OpenMM topology with one atom per particle of ``system`` (virtual
        sites are added as element-less atoms), the coordinates of every
        particle, and the set of particle indices belonging to each molecule
        of ``orig_top``.
    """
    # Each molecule of the original topology is treated as one residue.
    atom_to_residue = {}

    for residue_idx, molecule in enumerate(orig_top.molecules):
        for _ in molecule.atoms:
            atom_to_residue[len(atom_to_residue)] = residue_idx

    # Atoms are the non-virtual particles, numbered in particle order.
    n_particles = system.getNumParticles()

    real_particles = [
        particle_idx
        for particle_idx in range(n_particles)
        if not system.isVirtualSite(particle_idx)
    ]
    particle_to_atom = {
        particle_idx: atom_idx for atom_idx, particle_idx in enumerate(real_particles)
    }

    off_atoms = [*orig_top.atoms]

    # One (atomic number, residue index) record per particle; v-sites use -1.
    particle_records = []

    for particle_idx in range(n_particles):
        if particle_idx in particle_to_atom:
            atom_idx = particle_to_atom[particle_idx]
            particle_records.append(
                (off_atoms[atom_idx].atomic_number, atom_to_residue[atom_idx])
            )
        else:
            v_site = system.getVirtualSite(particle_idx)

            parent_atoms = {
                particle_to_atom[v_site.getParticle(i)]
                for i in range(v_site.getNumParticles())
            }
            # The v-site is placed in the residue of one of its parents.
            particle_records.append((-1, atom_to_residue[next(iter(parent_atoms))]))

    topology = openmm.app.Topology()

    if orig_top.box_vectors is not None:
        topology.setPeriodicBoxVectors(orig_top.box_vectors.to_openmm())

    chain = topology.addChain()

    # Running count of each (residue, atomic number) pair, used to name atoms.
    name_counts = collections.Counter()
    atoms = []

    residue = None
    prev_residue_idx = -1

    particles_by_residue = collections.defaultdict(list)

    for particle_idx, (atomic_num, residue_idx) in enumerate(particle_records):
        # A new residue is started whenever the residue index changes between
        # consecutive particles, so topology atoms stay in particle order.
        if residue_idx != prev_residue_idx:
            residue = topology.addResidue("UNK", chain)
            prev_residue_idx = residue_idx

        if atomic_num < 0:
            element = None
            symbol = "X"
        else:
            element = openmm.app.Element.getByAtomicNumber(atomic_num)
            symbol = element.symbol

        name_counts[(residue_idx, atomic_num)] += 1
        atom_name = f"{symbol}{name_counts[(residue_idx, atomic_num)]}".ljust(3, "x")

        atoms.append(topology.addAtom(atom_name, element, residue))
        particles_by_residue[residue_idx].append(particle_idx)

    _rename_residues(topology)

    atom_to_particle = {
        atom_idx: particle_idx for particle_idx, atom_idx in particle_to_atom.items()
    }

    for bond in orig_top.bonds:
        atom_a = atoms[atom_to_particle[bond.atom1_index]]

        # Bonds of residues that were renamed to water are not added.
        if atom_a.residue.name == "HOH":
            continue

        atom_b = atoms[atom_to_particle[bond.atom2_index]]
        topology.addBond(atom_a, atom_b)

    # Atoms keep their original coordinates; v-sites start as zeros.
    coord_rows = []

    for particle_idx in range(n_particles):
        atom_idx = particle_to_atom.get(particle_idx)

        if atom_idx is None:
            coord_rows.append(numpy.zeros((1, 3)))
        else:
            coord_rows.append(
                orig_coords[atom_idx].value_in_unit(openmm.unit.angstrom)
            )

    coords_full = numpy.vstack(coord_rows) * openmm.unit.angstrom

    if len(orig_coords) != len(coords_full):
        # The system has v-sites, so let OpenMM compute their positions.
        context = openmm.Context(system, openmm.VerletIntegrator(1.0))
        context.setPositions(coords_full)
        context.computeVirtualSites()

        state = context.getState(getPositions=True)
        coords_full = state.getPositions(asNumpy=True)

    n_residues = len(particles_by_residue)
    residues = [set(particles_by_residue[i]) for i in range(n_residues)]

    return topology, coords_full, residues


def _rename_residues(topology: openmm.app.Topology):
    """Attempts to assign standard residue names to known residues.

    Only water is currently recognised: a residue whose atoms with an element
    are exactly two hydrogens and one oxygen is renamed to ``HOH``, its oxygen
    to ``OW`` and its hydrogens to ``HW{i}``, where ``i`` is the position of
    the atom within the residue. Atoms without an element (e.g. virtual sites)
    are ignored when matching and keep their existing name.

    Args:
        topology: The topology to modify in-place.
    """

    for residue in topology.residues():
        residue_atoms = list(residue.atoms())

        symbols = sorted(
            atom.element.symbol for atom in residue_atoms if atom.element is not None
        )

        if symbols != ["H", "H", "O"]:
            continue

        residue.name = "HOH"

        for i, atom in enumerate(residue_atoms):
            if atom.element is None:
                # Element-less atoms (v-sites) were excluded from the water
                # check above, so they must be skipped here too rather than
                # dereferencing ``None.symbol``.
                continue

            atom.name = "OW" if atom.element.symbol == "O" else f"HW{i}"


def _setup_solvent(
solvent_idx: typing.Literal["solvent-a", "solvent-b"],
components: list[tuple[str, int]],
Expand All @@ -67,19 +209,21 @@ def _setup_solvent(

is_vacuum = n_solvent_molecules == 0

topology, coords = absolv.setup.setup_system(components)
topology.box_vectors = None if is_vacuum else topology.box_vectors
topology_off, coords = absolv.setup.setup_system(components)
topology_off.box_vectors = None if is_vacuum else topology_off.box_vectors

if isinstance(force_field, openff.toolkit.ForceField):
original_system = force_field.create_openmm_system(topology_off)
else:
original_system: openmm.System = force_field(topology_off, coords, solvent_idx)

atom_indices = absolv.utils.topology.topology_to_atom_indices(topology)
topology, coords, atom_indices = _rebuild_topology(
topology_off, coords, original_system
)

alchemical_indices = atom_indices[:n_solute_molecules]
persistent_indices = atom_indices[n_solute_molecules:]

if isinstance(force_field, openff.toolkit.ForceField):
original_system = force_field.create_openmm_system(topology)
else:
original_system: openmm.System = force_field(topology, coords, solvent_idx)

alchemical_system = absolv.fep.apply_fep(
original_system,
alchemical_indices,
Expand Down Expand Up @@ -196,7 +340,7 @@ def _run_eq_phase(
"""
platform = (
femto.md.constants.OpenMMPlatform.REFERENCE
if prepared_system.topology.box_vectors is None
if prepared_system.topology.getPeriodicBoxVectors() is None
else platform
)

Expand Down Expand Up @@ -312,7 +456,7 @@ def _run_phase_end_states(
):
platform = (
femto.md.constants.OpenMMPlatform.REFERENCE
if prepared_system.topology.box_vectors is None
if prepared_system.topology.getPeriodicBoxVectors() is None
else platform
)

Expand Down Expand Up @@ -363,11 +507,11 @@ def _run_switching(
):
platform = (
femto.md.constants.OpenMMPlatform.REFERENCE
if prepared_system.topology.box_vectors is None
if prepared_system.topology.getPeriodicBoxVectors() is None
else platform
)

mdtraj_topology = mdtraj.Topology.from_openmm(prepared_system.topology.to_openmm())
mdtraj_topology = mdtraj.Topology.from_openmm(prepared_system.topology)

trajectory_0 = mdtraj.load_dcd(str(output_dir / "state-0.dcd"), mdtraj_topology)
trajectory_1 = mdtraj.load_dcd(str(output_dir / "state-1.dcd"), mdtraj_topology)
Expand Down
22 changes: 0 additions & 22 deletions absolv/tests/test_fep.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,11 @@
_add_electrostatics_lambda,
_add_lj_vdw_lambda,
_find_nonbonded_forces,
_find_v_sites,
apply_fep,
)
from absolv.tests import is_close


def test_find_v_sites():
    """Check that v-sites found in an OpenMM system are grouped with the
    molecule of their parent atoms."""

    # Particle layout is V A A A V A A: each v-site (key) is the average of
    # its two parent particles (value), so the expected molecules are
    # (4, 1, 2), (3,) and (0, 5, 6).
    v_site_parents = {0: (5, 6), 4: (1, 2)}
    n_particles = 7

    system = openmm.System()

    for _ in range(n_particles):
        system.addParticle(1.0)

    for v_site_idx, (parent_a, parent_b) in v_site_parents.items():
        system.setVirtualSite(
            v_site_idx, openmm.TwoParticleAverageSite(parent_a, parent_b, 0.5, 0.5)
        )

    expected_particle_indices = [{1, 2, 4}, {3}, {0, 5, 6}]

    assert _find_v_sites(system, [{0, 1}, {2}, {3, 4}]) == expected_particle_indices


def test_find_nonbonded_forces_lj_only(aq_nacl_lj_system):
(
nonbonded_force,
Expand Down
Loading

0 comments on commit ede08ca

Please sign in to comment.