From d3a8b604a47afdc015eb890459a7a372889c4cf4 Mon Sep 17 00:00:00 2001 From: Matt Thompson Date: Wed, 31 May 2023 16:23:13 -0500 Subject: [PATCH] Implement improper enumeration on `_SimpleMolecule` (#1626) * Use snake case for attributes * Do not rely on wrapped toolkits for improper enumeration * Implement improper enumeration on `_SimpleMolecule` --- openff/toolkit/tests/test_mm_molecule.py | 30 +++++++ openff/toolkit/topology/_mm_molecule.py | 101 ++++++++++++++++++++--- openff/toolkit/topology/molecule.py | 66 ++++++--------- openff/toolkit/topology/topology.py | 16 +++- 4 files changed, 160 insertions(+), 53 deletions(-) diff --git a/openff/toolkit/tests/test_mm_molecule.py b/openff/toolkit/tests/test_mm_molecule.py index f8d38a783..b37ce27a3 100644 --- a/openff/toolkit/tests/test_mm_molecule.py +++ b/openff/toolkit/tests/test_mm_molecule.py @@ -150,6 +150,36 @@ def test_from_molecule(self): assert found == expected_atomic_numbers[atom_index] +class TestImpropers: + @pytest.mark.parametrize( + ("smiles", "n_impropers", "n_pruned"), + [ + ("C", 24, 0), + ("CC", 48, 0), + ("N", 6, 6), + ], + ) + def test_pruned_impropers(self, smiles, n_impropers, n_pruned): + """See equivalent test in TestMolecule.""" + molecule = _SimpleMolecule.from_molecule( + Molecule.from_smiles(smiles), + ) + + assert molecule.n_impropers == n_impropers + assert len(list(molecule.smirnoff_impropers)) == n_pruned + assert len(list(molecule.amber_impropers)) == n_pruned + + amber_impropers = {*molecule.amber_impropers} + + for smirnoff_imp in molecule.smirnoff_impropers: + assert ( + smirnoff_imp[1], + smirnoff_imp[0], + smirnoff_imp[2], + smirnoff_imp[3], + ) in amber_impropers + + class TestIsomorphism: @pytest.fixture() def n_propanol(self): diff --git a/openff/toolkit/topology/_mm_molecule.py b/openff/toolkit/topology/_mm_molecule.py index 8e5fac62d..05e591238 100644 --- a/openff/toolkit/topology/_mm_molecule.py +++ b/openff/toolkit/topology/_mm_molecule.py @@ -9,7 +9,16 @@ deserialize a Molecule or a TypedMolecule. """ -from typing import TYPE_CHECKING, Dict, List, NoReturn, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Dict, + Generator, + List, + NoReturn, + Optional, + Tuple, + Union, +) from openff.units import unit from openff.units.elements import MASSES, SYMBOLS @@ -94,7 +103,9 @@ def get_bond_between(self, atom1_index, atom2_index): return bond @property - def angles(self): + def angles( + self, + ) -> Generator[tuple["_SimpleAtom", "_SimpleAtom", "_SimpleAtom",], None, None,]: for atom1 in self.atoms: for atom2 in atom1.bonded_atoms: for atom3 in atom2.bonded_atoms: @@ -107,7 +118,18 @@ def angles(self): pass @property - def propers(self): + def propers( + self, + ) -> Generator[ + tuple[ + "_SimpleAtom", + "_SimpleAtom", + "_SimpleAtom", + "_SimpleAtom", + ], + None, + None, + ]: for atom1 in self.atoms: for atom2 in atom1.bonded_atoms: for atom3 in atom2.bonded_atoms: @@ -121,19 +143,78 @@ def propers(self): yield (atom1, atom2, atom3, atom4) else: # Do no duplicate - pass # yield (atom4, atom3, atom2, atom1) + pass + + @property + def impropers( + self, + ) -> Generator[ + tuple[ + "_SimpleAtom", + "_SimpleAtom", + "_SimpleAtom", + "_SimpleAtom", + ], + None, + None, + ]: + for atom1 in self.atoms: + for atom2 in atom1.bonded_atoms: + for atom3 in atom2.bonded_atoms: + if atom1 is atom3: + continue + for atom3i in atom2.bonded_atoms: + if atom3i == atom3: + continue + if atom3i == atom1: + continue + + yield (atom1, atom2, atom3, atom3i) + + @property + def smirnoff_impropers( + self, + ) -> Generator[ + tuple[ + "_SimpleAtom", + "_SimpleAtom", + "_SimpleAtom", + "_SimpleAtom", + ], + None, + None, + ]: + for improper in self.impropers: + if len(list(improper[1].bonded_atoms)) == 3: + yield improper + + @property + def amber_impropers( + self, + ) -> Generator[ + tuple[ + "_SimpleAtom", + "_SimpleAtom", + "_SimpleAtom", + "_SimpleAtom", + ], + None, + None, + ]: + for improper in self.smirnoff_impropers: + yield (improper[1], improper[0], improper[2], improper[3]) @property - def impropers(self): - return {} + def n_angles(self) -> int: + return len(list(self.angles)) @property - def smirnoff_impropers(self): - return {} + def n_propers(self) -> int: + return len(list(self.propers)) @property - def amber_impropers(self): - return {} + def n_impropers(self) -> int: + return len(list(self.impropers)) @property def hill_formula(self) -> str: diff --git a/openff/toolkit/topology/molecule.py b/openff/toolkit/topology/molecule.py index ea41d3e5f..38ac514c9 100644 --- a/openff/toolkit/topology/molecule.py +++ b/openff/toolkit/topology/molecule.py @@ -3154,7 +3154,7 @@ def n_bonds(self) -> int: """ The number of Bond objects in the molecule. """ - return sum([1 for bond in self.bonds]) + return len(self._bonds) @property def n_angles(self) -> int: @@ -3347,9 +3347,7 @@ def impropers(self) -> Set[Tuple[Atom, Atom, Atom, Atom]]: smirnoff_impropers, amber_impropers """ self._construct_torsions() - assert ( - self._impropers is not None - ), "_construct_torsions always sets _impropers to a set" + return self._impropers @property @@ -3387,16 +3385,11 @@ def smirnoff_impropers(self) -> Set[Tuple[Atom, Atom, Atom, Atom]]: impropers, amber_impropers """ - # TODO: Replace with non-cheminformatics-toolkit method - # (ie. just looping over all atoms and finding ones that have 3 bonds?) - - smirnoff_improper_smarts = "[*:1]~[X3:2](~[*:3])~[*:4]" - improper_idxs = self.chemical_environment_matches(smirnoff_improper_smarts) - smirnoff_impropers = { - (self.atom(imp[0]), self.atom(imp[1]), self.atom(imp[2]), self.atom(imp[3])) - for imp in improper_idxs + return { + improper + for improper in self.impropers + if len(self._bonded_atoms[improper[1]]) == 3 } - return smirnoff_impropers @property def amber_impropers(self) -> Set[Tuple[Atom, Atom, Atom, Atom]]: @@ -3425,15 +3418,12 @@ def amber_impropers(self) -> Set[Tuple[Atom, Atom, Atom, Atom]]: impropers, smirnoff_impropers """ - # TODO: Replace with non-cheminformatics-toolkit method - # (ie. just looping over all atoms and finding ones that have 3 bonds?) - amber_improper_smarts = "[X3:1](~[*:2])(~[*:3])~[*:4]" - improper_idxs = self.chemical_environment_matches(amber_improper_smarts) - amber_impropers = { - (self.atom(imp[0]), self.atom(imp[1]), self.atom(imp[2]), self.atom(imp[3])) - for imp in improper_idxs + self._construct_torsions() + + return { + (improper[1], improper[0], improper[2], improper[3]) + for improper in self.smirnoff_impropers } - return amber_impropers def nth_degree_neighbors(self, n_degrees): """ @@ -5063,16 +5053,14 @@ def _construct_angles(self): """ Get an iterator over all i-j-k angles. """ - # TODO: Build Angle objects instead of tuple of atoms. if not hasattr(self, "_angles"): self._construct_bonded_atoms_list() self._angles = set() for atom1 in self._atoms: - for atom2 in self._bondedAtoms[atom1]: - for atom3 in self._bondedAtoms[atom2]: + for atom2 in self._bonded_atoms[atom1]: + for atom3 in self._bonded_atoms[atom2]: if atom1 == atom3: continue - # TODO: Encapsulate this logic into an Angle class. if atom1.molecule_atom_index < atom3.molecule_atom_index: self._angles.add((atom1, atom2, atom3)) else: @@ -5081,19 +5069,20 @@ def _construct_angles(self): def _construct_torsions(self): """ Construct sets containing the atoms improper and proper torsions + + Impropers are constructed with the central atom listed second """ - # TODO: Build Proper/ImproperTorsion objects instead of tuple of atoms. if not hasattr(self, "_torsions"): self._construct_bonded_atoms_list() - self._propers = set() - self._impropers = set() + self._propers: set[tuple[Atom]] = set() + self._impropers: set[tuple[Atom]] = set() for atom1 in self._atoms: - for atom2 in self._bondedAtoms[atom1]: - for atom3 in self._bondedAtoms[atom2]: + for atom2 in self._bonded_atoms[atom1]: + for atom3 in self._bonded_atoms[atom2]: if atom1 == atom3: continue - for atom4 in self._bondedAtoms[atom3]: + for atom4 in self._bonded_atoms[atom3]: if atom4 == atom2: continue # Exclude i-j-k-i @@ -5107,7 +5096,7 @@ def _construct_torsions(self): self._propers.add(torsion) - for atom3i in self._bondedAtoms[atom2]: + for atom3i in self._bonded_atoms[atom2]: if atom3i == atom3: continue if atom3i == atom1: @@ -5124,16 +5113,15 @@ def _construct_bonded_atoms_list(self): """ # TODO: Add this to cached_properties - if not hasattr(self, "_bondedAtoms"): - # self._atoms = [ atom for atom in self.atoms() ] - self._bondedAtoms = dict() + if not hasattr(self, "_bonded_atoms"): + self._bonded_atoms: dict[Atom, set[Atom]] = dict() for atom in self._atoms: - self._bondedAtoms[atom] = set() + self._bonded_atoms[atom] = set() for bond in self._bonds: atom1 = self.atoms[bond.atom1_index] atom2 = self.atoms[bond.atom2_index] - self._bondedAtoms[atom1].add(atom2) - self._bondedAtoms[atom2].add(atom1) + self._bonded_atoms[atom1].add(atom2) + self._bonded_atoms[atom2].add(atom1) def _is_bonded(self, atom_index_1, atom_index_2): """Return True if atoms are bonded, False if not. @@ -5154,7 +5142,7 @@ def _is_bonded(self, atom_index_1, atom_index_2): self._construct_bonded_atoms_list() atom1 = self._atoms[atom_index_1] atom2 = self._atoms[atom_index_2] - return atom2 in self._bondedAtoms[atom1] + return atom2 in self._bonded_atoms[atom1] def get_bond_between(self, i: Union[int, "Atom"], j: Union[int, "Atom"]) -> "Bond": """Returns the bond between two atoms diff --git a/openff/toolkit/topology/topology.py b/openff/toolkit/topology/topology.py index 7154612ef..a3b0a9b28 100644 --- a/openff/toolkit/topology/topology.py +++ b/openff/toolkit/topology/topology.py @@ -39,7 +39,11 @@ from typing_extensions import TypeAlias from openff.toolkit.topology import Molecule -from openff.toolkit.topology._mm_molecule import _SimpleBond, _SimpleMolecule +from openff.toolkit.topology._mm_molecule import ( + _SimpleAtom, + _SimpleBond, + _SimpleMolecule, +) from openff.toolkit.topology.molecule import FrozenMolecule, HierarchyElement from openff.toolkit.utils import quantity_to_string, string_to_quantity from openff.toolkit.utils.constants import ( @@ -841,7 +845,7 @@ def n_propers(self) -> int: return sum(mol.n_propers for mol in self._molecules) @property - def propers(self) -> Generator[Tuple["Atom", ...], None, None]: + def propers(self) -> Generator[Tuple[Union["Atom", _SimpleAtom], ...], None, None]: """Iterable of Tuple[Atom]: iterator over the proper torsions in this Topology.""" for molecule in self.molecules: for proper in molecule.propers: @@ -860,7 +864,9 @@ def impropers(self) -> Generator[Tuple["Atom", ...], None, None]: yield improper @property - def smirnoff_impropers(self) -> Generator[Tuple["Atom", ...], None, None]: + def smirnoff_impropers( + self, + ) -> Generator[Tuple[Union["Atom", _SimpleAtom], ...], None, None]: """ Iterate over improper torsions in the molecule, but only those with trivalent centers, reporting the central atom second in each improper. @@ -899,7 +905,9 @@ def smirnoff_impropers(self) -> Generator[Tuple["Atom", ...], None, None]: yield smirnoff_improper @property - def amber_impropers(self) -> Generator[Tuple["Atom", ...], None, None]: + def amber_impropers( + self, + ) -> Generator[Tuple[Union["Atom", _SimpleAtom], ...], None, None]: """ Iterate over improper torsions in the molecule, but only those with trivalent centers, reporting the central atom first in each improper.