Skip to content

Commit

Permalink
Implement improper enumeration on _SimpleMolecule (#1626)
Browse files Browse the repository at this point in the history
* Use snake case for attributes

* Do not rely on wrapped toolkits for improper enumeration

* Implement improper enumeration on `_SimpleMolecule`
  • Loading branch information
mattwthompson authored May 31, 2023
1 parent a131add commit d3a8b60
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 53 deletions.
30 changes: 30 additions & 0 deletions openff/toolkit/tests/test_mm_molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,36 @@ def test_from_molecule(self):
assert found == expected_atomic_numbers[atom_index]


class TestImpropers:
@pytest.mark.parametrize(
("smiles", "n_impropers", "n_pruned"),
[
("C", 24, 0),
("CC", 48, 0),
("N", 6, 6),
],
)
def test_pruned_impropers(self, smiles, n_impropers, n_pruned):
"""See equivalent test in TestMolecule."""
molecule = _SimpleMolecule.from_molecule(
Molecule.from_smiles(smiles),
)

assert molecule.n_impropers == n_impropers
assert len(list(molecule.smirnoff_impropers)) == n_pruned
assert len(list(molecule.amber_impropers)) == n_pruned

amber_impropers = {*molecule.amber_impropers}

for smirnoff_imp in molecule.smirnoff_impropers:
assert (
smirnoff_imp[1],
smirnoff_imp[0],
smirnoff_imp[2],
smirnoff_imp[3],
) in amber_impropers


class TestIsomorphism:
@pytest.fixture()
def n_propanol(self):
Expand Down
101 changes: 91 additions & 10 deletions openff/toolkit/topology/_mm_molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,16 @@
deserialize a Molecule or a TypedMolecule.
"""
from typing import TYPE_CHECKING, Dict, List, NoReturn, Optional, Tuple, Union
from typing import (
TYPE_CHECKING,
Dict,
Generator,
List,
NoReturn,
Optional,
Tuple,
Union,
)

from openff.units import unit
from openff.units.elements import MASSES, SYMBOLS
Expand Down Expand Up @@ -94,7 +103,9 @@ def get_bond_between(self, atom1_index, atom2_index):
return bond

@property
def angles(self):
def angles(
self,
) -> Generator[tuple["_SimpleAtom", "_SimpleAtom", "_SimpleAtom",], None, None,]:
for atom1 in self.atoms:
for atom2 in atom1.bonded_atoms:
for atom3 in atom2.bonded_atoms:
Expand All @@ -107,7 +118,18 @@ def angles(self):
pass

@property
def propers(self):
def propers(
self,
) -> Generator[
tuple[
"_SimpleAtom",
"_SimpleAtom",
"_SimpleAtom",
"_SimpleAtom",
],
None,
None,
]:
for atom1 in self.atoms:
for atom2 in atom1.bonded_atoms:
for atom3 in atom2.bonded_atoms:
Expand All @@ -121,19 +143,78 @@ def propers(self):
yield (atom1, atom2, atom3, atom4)
else:
# Do no duplicate
pass # yield (atom4, atom3, atom2, atom1)
pass

@property
def impropers(
self,
) -> Generator[
tuple[
"_SimpleAtom",
"_SimpleAtom",
"_SimpleAtom",
"_SimpleAtom",
],
None,
None,
]:
for atom1 in self.atoms:
for atom2 in atom1.bonded_atoms:
for atom3 in atom2.bonded_atoms:
if atom1 is atom3:
continue
for atom3i in atom2.bonded_atoms:
if atom3i == atom3:
continue
if atom3i == atom1:
continue

yield (atom1, atom2, atom3, atom3i)

@property
def smirnoff_impropers(
self,
) -> Generator[
tuple[
"_SimpleAtom",
"_SimpleAtom",
"_SimpleAtom",
"_SimpleAtom",
],
None,
None,
]:
for improper in self.impropers:
if len(list(improper[1].bonded_atoms)) == 3:
yield improper

@property
def amber_impropers(
self,
) -> Generator[
tuple[
"_SimpleAtom",
"_SimpleAtom",
"_SimpleAtom",
"_SimpleAtom",
],
None,
None,
]:
for improper in self.smirnoff_impropers:
yield (improper[1], improper[0], improper[2], improper[3])

@property
def impropers(self):
return {}
def n_angles(self) -> int:
return len(list(self.angles))

@property
def smirnoff_impropers(self):
return {}
def n_propers(self) -> int:
return len(list(self.propers))

@property
def amber_impropers(self):
return {}
def n_impropers(self) -> int:
return len(list(self.impropers))

@property
def hill_formula(self) -> str:
Expand Down
66 changes: 27 additions & 39 deletions openff/toolkit/topology/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3154,7 +3154,7 @@ def n_bonds(self) -> int:
"""
The number of Bond objects in the molecule.
"""
return sum([1 for bond in self.bonds])
return len(self._bonds)

@property
def n_angles(self) -> int:
Expand Down Expand Up @@ -3347,9 +3347,7 @@ def impropers(self) -> Set[Tuple[Atom, Atom, Atom, Atom]]:
smirnoff_impropers, amber_impropers
"""
self._construct_torsions()
assert (
self._impropers is not None
), "_construct_torsions always sets _impropers to a set"

return self._impropers

@property
Expand Down Expand Up @@ -3387,16 +3385,11 @@ def smirnoff_impropers(self) -> Set[Tuple[Atom, Atom, Atom, Atom]]:
impropers, amber_impropers
"""
# TODO: Replace with non-cheminformatics-toolkit method
# (ie. just looping over all atoms and finding ones that have 3 bonds?)

smirnoff_improper_smarts = "[*:1]~[X3:2](~[*:3])~[*:4]"
improper_idxs = self.chemical_environment_matches(smirnoff_improper_smarts)
smirnoff_impropers = {
(self.atom(imp[0]), self.atom(imp[1]), self.atom(imp[2]), self.atom(imp[3]))
for imp in improper_idxs
return {
improper
for improper in self.impropers
if len(self._bonded_atoms[improper[1]]) == 3
}
return smirnoff_impropers

@property
def amber_impropers(self) -> Set[Tuple[Atom, Atom, Atom, Atom]]:
Expand Down Expand Up @@ -3425,15 +3418,12 @@ def amber_impropers(self) -> Set[Tuple[Atom, Atom, Atom, Atom]]:
impropers, smirnoff_impropers
"""
# TODO: Replace with non-cheminformatics-toolkit method
# (ie. just looping over all atoms and finding ones that have 3 bonds?)
amber_improper_smarts = "[X3:1](~[*:2])(~[*:3])~[*:4]"
improper_idxs = self.chemical_environment_matches(amber_improper_smarts)
amber_impropers = {
(self.atom(imp[0]), self.atom(imp[1]), self.atom(imp[2]), self.atom(imp[3]))
for imp in improper_idxs
self._construct_torsions()

return {
(improper[1], improper[0], improper[2], improper[3])
for improper in self.smirnoff_impropers
}
return amber_impropers

def nth_degree_neighbors(self, n_degrees):
"""
Expand Down Expand Up @@ -5063,16 +5053,14 @@ def _construct_angles(self):
"""
Get an iterator over all i-j-k angles.
"""
# TODO: Build Angle objects instead of tuple of atoms.
if not hasattr(self, "_angles"):
self._construct_bonded_atoms_list()
self._angles = set()
for atom1 in self._atoms:
for atom2 in self._bondedAtoms[atom1]:
for atom3 in self._bondedAtoms[atom2]:
for atom2 in self._bonded_atoms[atom1]:
for atom3 in self._bonded_atoms[atom2]:
if atom1 == atom3:
continue
# TODO: Encapsulate this logic into an Angle class.
if atom1.molecule_atom_index < atom3.molecule_atom_index:
self._angles.add((atom1, atom2, atom3))
else:
Expand All @@ -5081,19 +5069,20 @@ def _construct_angles(self):
def _construct_torsions(self):
"""
Construct sets containing the atoms improper and proper torsions
Impropers are constructed with the central atom listed second
"""
# TODO: Build Proper/ImproperTorsion objects instead of tuple of atoms.
if not hasattr(self, "_torsions"):
self._construct_bonded_atoms_list()

self._propers = set()
self._impropers = set()
self._propers: set[tuple[Atom]] = set()
self._impropers: set[tuple[Atom]] = set()
for atom1 in self._atoms:
for atom2 in self._bondedAtoms[atom1]:
for atom3 in self._bondedAtoms[atom2]:
for atom2 in self._bonded_atoms[atom1]:
for atom3 in self._bonded_atoms[atom2]:
if atom1 == atom3:
continue
for atom4 in self._bondedAtoms[atom3]:
for atom4 in self._bonded_atoms[atom3]:
if atom4 == atom2:
continue
# Exclude i-j-k-i
Expand All @@ -5107,7 +5096,7 @@ def _construct_torsions(self):

self._propers.add(torsion)

for atom3i in self._bondedAtoms[atom2]:
for atom3i in self._bonded_atoms[atom2]:
if atom3i == atom3:
continue
if atom3i == atom1:
Expand All @@ -5124,16 +5113,15 @@ def _construct_bonded_atoms_list(self):
"""
# TODO: Add this to cached_properties
if not hasattr(self, "_bondedAtoms"):
# self._atoms = [ atom for atom in self.atoms() ]
self._bondedAtoms = dict()
if not hasattr(self, "_bonded_atoms"):
self._bonded_atoms: dict[Atom, set[Atom]] = dict()
for atom in self._atoms:
self._bondedAtoms[atom] = set()
self._bonded_atoms[atom] = set()
for bond in self._bonds:
atom1 = self.atoms[bond.atom1_index]
atom2 = self.atoms[bond.atom2_index]
self._bondedAtoms[atom1].add(atom2)
self._bondedAtoms[atom2].add(atom1)
self._bonded_atoms[atom1].add(atom2)
self._bonded_atoms[atom2].add(atom1)

def _is_bonded(self, atom_index_1, atom_index_2):
"""Return True if atoms are bonded, False if not.
Expand All @@ -5154,7 +5142,7 @@ def _is_bonded(self, atom_index_1, atom_index_2):
self._construct_bonded_atoms_list()
atom1 = self._atoms[atom_index_1]
atom2 = self._atoms[atom_index_2]
return atom2 in self._bondedAtoms[atom1]
return atom2 in self._bonded_atoms[atom1]

def get_bond_between(self, i: Union[int, "Atom"], j: Union[int, "Atom"]) -> "Bond":
"""Returns the bond between two atoms
Expand Down
16 changes: 12 additions & 4 deletions openff/toolkit/topology/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@
from typing_extensions import TypeAlias

from openff.toolkit.topology import Molecule
from openff.toolkit.topology._mm_molecule import _SimpleBond, _SimpleMolecule
from openff.toolkit.topology._mm_molecule import (
_SimpleAtom,
_SimpleBond,
_SimpleMolecule,
)
from openff.toolkit.topology.molecule import FrozenMolecule, HierarchyElement
from openff.toolkit.utils import quantity_to_string, string_to_quantity
from openff.toolkit.utils.constants import (
Expand Down Expand Up @@ -841,7 +845,7 @@ def n_propers(self) -> int:
return sum(mol.n_propers for mol in self._molecules)

@property
def propers(self) -> Generator[Tuple["Atom", ...], None, None]:
def propers(self) -> Generator[Tuple[Union["Atom", _SimpleAtom], ...], None, None]:
"""Iterable of Tuple[Atom]: iterator over the proper torsions in this Topology."""
for molecule in self.molecules:
for proper in molecule.propers:
Expand All @@ -860,7 +864,9 @@ def impropers(self) -> Generator[Tuple["Atom", ...], None, None]:
yield improper

@property
def smirnoff_impropers(self) -> Generator[Tuple["Atom", ...], None, None]:
def smirnoff_impropers(
self,
) -> Generator[Tuple[Union["Atom", _SimpleAtom], ...], None, None]:
"""
Iterate over improper torsions in the molecule, but only those with
trivalent centers, reporting the central atom second in each improper.
Expand Down Expand Up @@ -899,7 +905,9 @@ def smirnoff_impropers(self) -> Generator[Tuple["Atom", ...], None, None]:
yield smirnoff_improper

@property
def amber_impropers(self) -> Generator[Tuple["Atom", ...], None, None]:
def amber_impropers(
self,
) -> Generator[Tuple[Union["Atom", _SimpleAtom], ...], None, None]:
"""
Iterate over improper torsions in the molecule, but only those with
trivalent centers, reporting the central atom first in each improper.
Expand Down

0 comments on commit d3a8b60

Please sign in to comment.