Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decouple parmed from the foyer atomtyper #389

56 changes: 36 additions & 20 deletions foyer/atomtyper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from warnings import warn

import parmed as pmd
import parmed.periodic_table as pt

from foyer.exceptions import FoyerError
from foyer.topology_graph import TopologyGraph
from foyer.smarts_graph import SMARTSGraph


Expand All @@ -11,42 +13,54 @@ def find_atomtypes(structure, forcefield, max_iter=10):

Parameters
----------
topology : simtk.openmm.app.Topology
structure : parmed.Structure or TopologyGraph
The topology that we are trying to atomtype.
forcefield : foyer.Forcefield
The forcefield object.
max_iter : int, optional, default=10
The maximum number of iterations.

"""
typemap = {atom.idx: {'whitelist': set(), 'blacklist': set(),
'atomtype': None} for atom in structure.atoms}
topology_graph = structure

if isinstance(structure, pmd.Structure):
topology_graph = TopologyGraph.from_parmed(structure)

typemap = {
atom_index: {
'whitelist': set(),
'blacklist': set(),
'atomtype': None
} for atom_index in topology_graph.atoms(data=False)
}

rules = _load_rules(forcefield, typemap)

# Only consider rules for elements found in topology
subrules = dict()

system_elements = set()
for a in structure.atoms:
for _, atom_data in topology_graph.atoms(data=True):
# First add non-element types, which are strings, then elements
if a.name.startswith('_'):
if a.name in forcefield.non_element_types:
system_elements.add(a.name)
name = atom_data.name
if name.startswith('_'):
if name in forcefield.non_element_types:
system_elements.add(name)
else:
if 0 < a.atomic_number <= pt.KNOWN_ELEMENTS:
element = pt.Element[a.atomic_number]
atomic_number = atom_data.atomic_number
if 0 < atomic_number <= pt.KNOWN_ELEMENTS:
element = pt.Element[atomic_number]
system_elements.add(element)
else:
raise FoyerError(
'Parsed atom {} as having neither an element '
'nor non-element type.'.format(a)
'nor non-element type.'.format(name)
)

for key, val in rules.items():
atom = val.nodes[0]['atom']
if len(list(atom.find_data('atom_symbol'))) == 1 and \
not list(atom.find_data('not_expression')):
not list(atom.find_data('not_expression')):
try:
element = next(atom.find_data('atom_symbol')).children[0]
except IndexError:
Expand All @@ -61,11 +75,12 @@ def find_atomtypes(structure, forcefield, max_iter=10):
subrules[key] = val
rules = subrules

_iterate_rules(rules, structure, typemap, max_iter=max_iter)
_resolve_atomtypes(structure, typemap)
_iterate_rules(rules, topology_graph, typemap, max_iter=max_iter)
_resolve_atomtypes(topology_graph, typemap)

return typemap


def _load_rules(forcefield, typemap):
"""Load atomtyping rules from a forcefield into SMARTSGraphs. """
rules = dict()
Expand All @@ -87,16 +102,16 @@ def _load_rules(forcefield, typemap):
return rules


def _iterate_rules(rules, structure, typemap, max_iter):
def _iterate_rules(rules, topology_graph, typemap, max_iter):
"""Iteratively run all the rules until the white- and blacklists converge.

Parameters
----------
rules : dict
A dictionary mapping rule names (typically atomtype names) to
SMARTSGraphs that evaluate those rules.
topology : simtk.openmm.app.Topology
The topology that we are trying to atomtype.
topology_graph : TopologyGraph
The topology graph that we are trying to atomtype.
max_iter : int
The maximum number of iterations.

Expand All @@ -106,7 +121,7 @@ def _iterate_rules(rules, structure, typemap, max_iter):
max_iter -= 1
found_something = False
for rule in rules.values():
for match_index in rule.find_matches(structure, typemap):
for match_index in rule.find_matches(topology_graph, typemap):
atom = typemap[match_index]
# This conditional is not strictly necessary, but it prevents
# redundant set addition on later iterations
Expand All @@ -120,11 +135,12 @@ def _iterate_rules(rules, structure, typemap, max_iter):
warn("Reached maximum iterations. Something probably went wrong.")
return typemap

def _resolve_atomtypes(structure, typemap):

def _resolve_atomtypes(topology_graph, typemap):
"""Determine the final atomtypes from the white- and blacklists. """
atoms = structure.atoms
atoms = {atom_idx: data for atom_idx, data in topology_graph.atoms(data=True)}
for atom_id, atom in typemap.items():
atomtype = [rule_name for rule_name in
atomtype = [rule_name for rule_name in
atom['whitelist'] - atom['blacklist']]
if len(atomtype) == 1:
atom['atomtype'] = atomtype[0]
Expand Down
80 changes: 38 additions & 42 deletions foyer/smarts_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class SMARTSGraph(nx.Graph):
smarts_string : str
The SMARTS string outlined in the force field
parser : foyer.smarts.SMARTS
The parser whose grammar rules convert the SMARTSstring
The parser whose grammar rules convert the SMARTS string
into the AST
name : str
overrides : set
Expand Down Expand Up @@ -95,62 +95,66 @@ def _add_label_edges(self):
def _node_match(self, host, pattern):
""" Determine if two graph nodes are equal """
atom_expr = pattern['atom'].children[0]
atom = host['atom']
return self._atom_expr_matches(atom_expr, atom)
atom = host['atom_data']
bond_partners = host['bond_partners']
return self._atom_expr_matches(atom_expr, atom, bond_partners)

def _atom_expr_matches(self, atom_expr, atom):
def _atom_expr_matches(self, atom_expr, atom, bond_partners):
""" Helper function for evaluating SMARTS string expressions """
if atom_expr.data == 'not_expression':
return not self._atom_expr_matches(atom_expr.children[0], atom)
return not self._atom_expr_matches(atom_expr.children[0], atom, bond_partners)
elif atom_expr.data in ('and_expression', 'weak_and_expression'):
return (self._atom_expr_matches(atom_expr.children[0], atom) and
self._atom_expr_matches(atom_expr.children[1], atom))
return (self._atom_expr_matches(atom_expr.children[0], atom, bond_partners) and
self._atom_expr_matches(atom_expr.children[1], atom, bond_partners))
elif atom_expr.data == 'or_expression':
return (self._atom_expr_matches(atom_expr.children[0], atom) or
self._atom_expr_matches(atom_expr.children[1], atom))
return (self._atom_expr_matches(atom_expr.children[0], atom, bond_partners) or
self._atom_expr_matches(atom_expr.children[1], atom, bond_partners))
elif atom_expr.data == 'atom_id':
return self._atom_id_matches(atom_expr.children[0], atom, self.typemap)
return self._atom_id_matches(atom_expr.children[0], atom, bond_partners, self.typemap)
elif atom_expr.data == 'atom_symbol':
return self._atom_id_matches(atom_expr, atom, self.typemap)
return self._atom_id_matches(atom_expr, atom, bond_partners, self.typemap)
else:
raise TypeError('Expected atom_id, atom_symbol, and_expression, '
'or_expression, or not_expression. '
'Got {}'.format(atom_expr.data))

@staticmethod
def _atom_id_matches(atom_id, atom, typemap):
def _atom_id_matches(atom_id, atom, bond_partners, typemap):
""" Helper func for comparing atomic indices, symbols, neighbors, rings """
atomic_num = atom.element
atomic_num = atom.atomic_number
atom_name = atom.name
atom_idx = atom.index

if atom_id.data == 'atomic_num':
return atomic_num == int(atom_id.children[0])
elif atom_id.data == 'atom_symbol':
if str(atom_id.children[0]) == '*':
return True
elif str(atom_id.children[0]).startswith('_'):
# Store non-element elements in .name
return atom.name == str(atom_id.children[0])
return atom_name == str(atom_id.children[0])
else:
return atomic_num == pt.AtomicNum[str(atom_id.children[0])]
elif atom_id.data == 'has_label':
label = atom_id.children[0][1:] # Strip the % sign from the beginning.
return label in typemap[atom.idx]['whitelist']
return label in typemap[atom_idx]['whitelist']
elif atom_id.data == 'neighbor_count':
return len(atom.bond_partners) == int(atom_id.children[0])
return len(bond_partners) == int(atom_id.children[0])
elif atom_id.data == 'ring_size':
cycle_len = int(atom_id.children[0])
for cycle in typemap[atom.idx]['cycles']:
for cycle in typemap[atom_idx]['cycles']:
if len(cycle) == cycle_len:
return True
return False
elif atom_id.data == 'ring_count':
n_cycles = len(typemap[atom.idx]['cycles'])
n_cycles = len(typemap[atom_idx]['cycles'])
if n_cycles == int(atom_id.children[0]):
return True
return False
elif atom_id.data == 'matches_string':
raise NotImplementedError('matches_string is not yet implemented')

def find_matches(self, structure, typemap):
def find_matches(self, topology_graph, typemap):
"""Return sets of atoms that match this SMARTS pattern in a topology.

Notes:
Expand All @@ -168,13 +172,8 @@ def find_matches(self, structure, typemap):
ring_tokens = ['ring_size', 'ring_count']
has_ring_rules = any(list(self.ast.find_data(token))
for token in ring_tokens)
_prepare_atoms(structure, typemap, compute_cycles=has_ring_rules)

top_graph = nx.Graph()
top_graph.add_nodes_from(((a.idx, {'atom': a})
for a in structure.atoms))
top_graph.add_edges_from(((b.atom1.idx, b.atom2.idx)
for b in structure.bonds))
topology_graph.add_bond_partners()
_prepare_atoms(topology_graph, typemap, compute_cycles=has_ring_rules)

if self._graph_matcher is None:
atom = nx.get_node_attributes(self, name='atom')[0]
Expand All @@ -190,7 +189,7 @@ def find_matches(self, structure, typemap):
element = None
else:
element = None
self._graph_matcher = SMARTSMatcher(top_graph, self,
self._graph_matcher = SMARTSMatcher(topology_graph, self,
node_match=self._node_match,
element=element,
typemap=typemap)
Expand Down Expand Up @@ -291,7 +290,7 @@ def _find_chordless_cycles(bond_graph, max_cycle_size):

for possible_ring in possible_rings:
if bond_graph.has_edge(possible_ring[-1], last_node):
if any([bond_graph.has_edge(possible_ring[-1],
if any([bond_graph.has_edge(possible_ring[-1],
internal_node)
for internal_node in possible_ring[1:-2]]):
pass
Expand All @@ -305,26 +304,23 @@ def _find_chordless_cycles(bond_graph, max_cycle_size):
return cycles


def _prepare_atoms(structure, typemap, compute_cycles=False):
def _prepare_atoms(topology_graph, typemap, compute_cycles=False):
"""Compute cycles and add white-/blacklists to atoms."""
atom1 = structure.atoms[0]#next(topology.atoms())
has_whitelists = 'whitelist' in typemap[atom1.idx]
has_cycles = 'cycles' in typemap[atom1.idx]
atom1 = next(topology_graph.atoms(data=False)) #next(topology.atoms())
has_whitelists = 'whitelist' in typemap[atom1]
has_cycles = 'cycles' in typemap[atom1]
compute_cycles = compute_cycles and not has_cycles

if compute_cycles or not has_whitelists:
for atom in structure.atoms:
for index in topology_graph.atoms(data=False):
if compute_cycles:
typemap[atom.idx]['cycles'] = set()
typemap[index]['cycles'] = set()
if not has_whitelists:
typemap[atom.idx]['whitelist'] = set()
typemap[atom.idx]['blacklist'] = set()
typemap[index]['whitelist'] = set()
typemap[index]['blacklist'] = set()

if compute_cycles:
bond_graph = nx.Graph()
bond_graph.add_nodes_from(structure.atoms)
bond_graph.add_edges_from([(b.atom1, b.atom2) for b in structure.bonds])
all_cycles = _find_chordless_cycles(bond_graph, max_cycle_size=8)
for atom, cycles in zip(bond_graph.nodes, all_cycles):
all_cycles = _find_chordless_cycles(topology_graph, max_cycle_size=8)
for atom, cycles in zip(topology_graph.nodes, all_cycles):
for cycle in cycles:
typemap[atom.idx]['cycles'].add(tuple(cycle))
typemap[atom]['cycles'].add(tuple(cycle))
17 changes: 9 additions & 8 deletions foyer/tests/test_graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import parmed as pmd

from foyer.smarts_graph import SMARTSGraph, _prepare_atoms
from foyer.topology_graph import TopologyGraph
from foyer.tests.utils import get_fn


Expand All @@ -25,30 +26,30 @@ def test_init():

def test_lazy_cycle_finding():
mol2 = pmd.load_file(get_fn('ethane.mol2'), structure=True)
typemap = {atom.idx: {'whitelist': set(), 'blacklist': set(),
'atomtype': None}
typemap = {atom.idx: {'whitelist': set(), 'blacklist': set(),
'atomtype': None}
for atom in mol2.atoms}

rule = SMARTSGraph(smarts_string='[C]', typemap=typemap)
list(rule.find_matches(mol2, typemap))
list(rule.find_matches(TopologyGraph.from_parmed(mol2), typemap))
assert not any(['cycles' in typemap[a.idx] for a in mol2.atoms])

ring_tokens = ['R1', 'r6']
for token in ring_tokens:
rule = SMARTSGraph(smarts_string='[C;{}]'.format(token),
typemap=typemap)
list(rule.find_matches(mol2, typemap))
list(rule.find_matches(TopologyGraph.from_parmed(mol2), typemap))
assert all(['cycles' in typemap[a.idx] for a in mol2.atoms])


def test_cycle_finding_multiple():
mol2 = pmd.load_file(get_fn('fullerene.pdb'), structure=True)
typemap = {atom.idx: {'whitelist': set(), 'blacklist': set(),
'atomtype': None}
typemap = {atom.idx: {'whitelist': set(), 'blacklist': set(),
'atomtype': None}
for atom in mol2.atoms}

_prepare_atoms(mol2, typemap, compute_cycles=True)
cycle_lengths = [list(map(len, typemap[atom.idx]['cycles']))
_prepare_atoms(TopologyGraph.from_parmed(mol2), typemap, compute_cycles=True)
cycle_lengths = [list(map(len, typemap[atom.idx]['cycles']))
for atom in mol2.atoms]
expected = [5, 6, 6]
assert all(sorted(lengths) == expected for lengths in cycle_lengths)
Loading