Skip to content

Commit

Permalink
improvements and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Klaus Weinbauer committed Mar 13, 2024
1 parent 6c06b2e commit 44dfeb4
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 131 deletions.
17 changes: 6 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,15 @@ pip install fgutils
A simple example querying the functional groups for acetylsalicylic acid.
```
import fgutils
import rdkit.Chem.rdmolfiles as rdmolfiles
from fgutils.utils import mol_to_graph
smiles = "O=C(C)Oc1ccccc1C(=O)O" # acetylsalicylic acid
mol_graph = mol_to_graph(rdmolfiles.MolFromSmiles(smiles))
index_map, groups = fgutils.get_functional_groups(mol_graph)
print(index_map, groups)
smiles = "O=C(C)Oc1ccccc1C(=O)O" # acetylsalicylic acid
query = fgutils.FGQuery(use_smiles=True) # requires rdkit to be installed
groups = query.get(smiles)
print(groups)
```

The output is an index map and a list of functional group names.
The output is a list of tuples containing the functional group name and the corresponding atom indices.

```
{0: [0, 1, 2], 1: [0, 1, 2], 3: [2, 3]}
['carbonyl', 'ketone', 'ester', 'ether']
[('ester', [0, 1, 3]), ('carboxylic_acid', [10, 11, 12])]
```

The index map maps atom numbers to functional groups. The key of the map is the atom index and the value is the list of indices from the functional group list. E.g. atom 0 (oxygen) is in functional groups carbonyl, ketone, and ester.
2 changes: 1 addition & 1 deletion fgutils/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .query import get_functional_groups
from .query import get_functional_groups, FGQuery
1 change: 1 addition & 0 deletions fgutils/fgconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
{"name": "anilin", "pattern": "C:CN(R)R", "group_atoms": [2]},
{"name": "ketene", "pattern": "RC(R)=C=O", "group_atoms": [1, 3, 4]},
{"name": "carbamate", "pattern": "ROC(=O)N(R)R", "group_atoms": [1, 2, 3, 4]},
{"name": "acyl_chloride", "pattern": "RC(=O)Cl", "group_atoms": [1, 2, 3]},
]


Expand Down
95 changes: 65 additions & 30 deletions fgutils/query.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
import collections
import networkx as nx

from fgutils.utils import add_implicit_hydrogens
from fgutils.mapping import map_pattern
Expand Down Expand Up @@ -36,38 +36,73 @@ def is_functional_group(graph, index: int, config: FGConfig):
return is_fg, sorted(fg_indices)


def get_functional_groups(graph) -> tuple[dict, list[str]]:
def _query(nodes: list[FGTreeNode], graph, idx, checked_groups=[]):
fg_groups = []
fg_indices = []
for node in nodes:
if node.fgconfig.name in checked_groups:
continue
is_fg, indices = is_functional_group(graph, idx, node.fgconfig)
if is_fg:
checked_groups.append(node.fgconfig.name)
fg_groups.append(node.fgconfig.name)
fg_indices.append(indices)
_fg_groups, _fg_indices = _query(
node.children, graph, idx, checked_groups
)
fg_groups.extend(_fg_groups)
fg_indices.extend(_fg_indices)
return fg_groups, fg_indices
def _find_best_node_rec(nodes: list[FGTreeNode], graph, idx):
best_node = None
node_indices = []
for node in nodes:
is_fg, fg_indices = is_functional_group(graph, idx, node.fgconfig)
if is_fg:
r_node, r_indices = _find_best_node_rec(
node.children,
graph,
idx,
)
if r_node is None:
best_node = node
node_indices = fg_indices
else:
best_node = r_node
node_indices = r_indices
return best_node, node_indices


def get_functional_groups(graph) -> list[tuple[str, list[int]]]:
fg_candidate_ids = [
n_id for n_id, n_sym in graph.nodes(data="symbol") if n_sym not in ["H", "C"]
]
roots = build_FG_tree()
idx_map = collections.defaultdict(lambda: [])
groups = []
for atom_id in fg_candidate_ids:
fg_groups, fg_indices = _query(roots, graph, atom_id)
if len(fg_groups) > 0:
for _group, _indices in zip(fg_groups, fg_indices):
assert atom_id in _indices
_i = len(groups)
groups.append(_group)
for _idx in _indices:
idx_map[_idx].append(_i)
return dict(idx_map), groups
unidentified_ids = []
while len(fg_candidate_ids) > 0:
atom_id = fg_candidate_ids.pop(0)
node, indices = _find_best_node_rec(roots, graph, atom_id)
if node is None:
unidentified_ids.append(atom_id)
else:
assert atom_id in indices
for i in indices:
if i in fg_candidate_ids:
fg_candidate_ids.remove(i)
elif i in unidentified_ids:
unidentified_ids.remove(i)
groups.append((node.fgconfig.name, indices))
if len(unidentified_ids) > 0:
raise RuntimeError(
"Could not find a functional group for atom(s) {}.".format(
["{}@{}".format(graph.nodes[i]["symbol"], i) for i in unidentified_ids]
)
)
return groups


class FGQuery:
def __init__(self, use_smiles=False):
self.use_smiles = use_smiles

def get(self, value) -> list[tuple[str, list[int]]]:
mol_graph = None
if isinstance(value, nx.Graph):
mol_graph = value
elif self.use_smiles:
import rdkit.Chem.rdmolfiles as rdmolfiles
from fgutils.rdkit import mol_to_graph

mol = rdmolfiles.MolFromSmiles(value)
mol_graph = mol_to_graph(mol)
else:
raise ValueError(
"Can not interpret '{}' (type: {}) as mol graph.".format(
value, type(value)
)
)
return get_functional_groups(mol_graph)
23 changes: 23 additions & 0 deletions fgutils/rdkit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import networkx as nx
import rdkit.Chem as Chem


def mol_to_graph(mol: Chem.rdchem.Mol) -> nx.Graph:
bond_order_map = {
"SINGLE": 1,
"DOUBLE": 2,
"TRIPLE": 3,
"QUADRUPLE": 4,
"AROMATIC": 1.5,
}
g = nx.Graph()
for atom in mol.GetAtoms():
g.add_node(atom.GetIdx(), symbol=atom.GetSymbol())
for bond in mol.GetBonds():
bond_type = str(bond.GetBondType()).split(".")[-1]
g.add_edge(
bond.GetBeginAtomIdx(),
bond.GetEndAtomIdx(),
bond=bond_order_map[bond_type],
)
return g
25 changes: 2 additions & 23 deletions fgutils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,6 @@ def print_graph(graph):
)


def mol_to_graph(mol) -> nx.Graph:
bond_order_map = {
"SINGLE": 1,
"DOUBLE": 2,
"TRIPLE": 3,
"QUADRUPLE": 4,
"AROMATIC": 1.5,
}
g = nx.Graph()
for atom in mol.GetAtoms():
g.add_node(atom.GetIdx(), symbol=atom.GetSymbol())
for bond in mol.GetBonds():
bond_type = str(bond.GetBondType()).split(".")[-1]
g.add_edge(
bond.GetBeginAtomIdx(),
bond.GetEndAtomIdx(),
bond=bond_order_map[bond_type],
)
return g


def add_implicit_hydrogens(graph: nx.Graph) -> nx.Graph:
valence_dict = {
4: ["C", "Si"],
Expand All @@ -55,14 +34,14 @@ def add_implicit_hydrogens(graph: nx.Graph) -> nx.Graph:
valence_table[elmt] = v
nodes = [
(n_id, n_sym)
for n_id, n_sym in graph.nodes(data="symbol")
for n_id, n_sym in graph.nodes(data="symbol") # type: ignore
if n_sym not in ["R", "H"]
]
for n_id, n_sym in nodes:
assert (
n_sym in valence_table.keys()
), "Element {} not found in valence table.".format(n_sym)
bond_cnt = sum([b for _, _, b in graph.edges(n_id, data="bond")])
bond_cnt = sum([b for _, _, b in graph.edges(n_id, data="bond")]) # type: ignore
h_cnt = int(8 - valence_table[n_sym] - bond_cnt)
assert h_cnt >= 0, "Negative hydrogen count."
for h_id in range(len(graph), len(graph) + h_cnt):
Expand Down
14 changes: 11 additions & 3 deletions test/test_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ def _assert_graph(g, exp_nodes, exp_edges):
assert order == g.edges[i1, i2]["bond"]


def test_tokenize():
def _ct(token, exp_type, exp_value, exp_col):
return token[0] == exp_type and token[1] == exp_value and token[2] == exp_col
def _ct(token, exp_type, exp_value, exp_col):
return token[0] == exp_type and token[1] == exp_value and token[2] == exp_col


def test_tokenize():
it = tokenize("RC(=O)OR")
assert True is _ct(next(it), "WILDCARD", "R", 0)
assert True is _ct(next(it), "ATOM", "C", 1)
Expand All @@ -27,6 +28,13 @@ def _ct(token, exp_type, exp_value, exp_col):
assert True is _ct(next(it), "WILDCARD", "R", 7)


def test_tokenize_multichar():
it = tokenize("RClR")
assert True is _ct(next(it), "WILDCARD", "R", 0)
assert True is _ct(next(it), "ATOM", "Cl", 1)
assert True is _ct(next(it), "WILDCARD", "R", 3)


def test_branch():
exp_nodes = {0: "R", 1: "C", 2: "O", 3: "O", 4: "R"}
exp_edges = [(0, 1, 1), (1, 2, 2), (1, 3, 1), (3, 4, 1)]
Expand Down
90 changes: 27 additions & 63 deletions test/test_query.py
Original file line number Diff line number Diff line change
@@ -1,69 +1,10 @@
import pytest
import collections
import rdkit.Chem.rdmolfiles as rdmolfiles

from fgutils.query import get_functional_groups, is_functional_group
from fgutils.parse import parse
from fgutils.fgconfig import get_FG_by_name
from fgutils.utils import mol_to_graph


def _get_group_map(index_map):
group_map = collections.defaultdict(lambda: [])
for idx, group_indices in index_map.items():
for group_idx in group_indices:
assert idx not in group_map[group_idx]
group_map[group_idx].append(idx)
return dict(group_map)


def _assert_fg(raw_result, fg_name, indices):
r_idxmap, r_groups = raw_result
assert (
indices[0] in r_idxmap.keys()
), "No functional group found for atom {}.".format(indices[0])
groups = [r_groups[i] for i in r_idxmap[indices[0]]]
assert (
fg_name in groups
), "Could not find functional group '{}' for for index {}.".format(
fg_name, indices[0]
)
group_i = -1
for i, g in enumerate(groups):
if g == fg_name:
assert group_i == -1, "Found group multipel times ({}).".format(groups)
group_i = i
group_map = _get_group_map(r_idxmap)
assert len(group_map[group_i]) == len(
indices
), "Expected group '{}' to have {} atoms but found {}.".format(
fg_name, len(indices), len(group_map[group_i])
)
for i in group_map[group_i]:
assert (
i in indices
), "Could not find index {} in functional group {} (Indices: {}).".format(
i, fg_name, group_map[group_i]
)


def _assert_not_fg(raw_result, fg_name, indices):
r_idxmap, r_groups = raw_result
for idx in indices:
groups = [r_groups[i] for i in r_idxmap[idx]]
assert (
fg_name not in groups
), "Wrongly identified functional group '{}' for for index {}.".format(
fg_name, indices[0]
)


def test_get_functional_groups_raw():
mol = parse("C=O")
r = get_functional_groups(mol)
_assert_fg(r, "carbonyl", [0, 1])
_assert_fg(r, "ketone", [0, 1])
_assert_fg(r, "aldehyde", [0, 1])
from fgutils.rdkit import mol_to_graph


@pytest.mark.parametrize(
Expand All @@ -72,6 +13,7 @@ def test_get_functional_groups_raw():
("carbonyl", "CC(=O)O", 2, [1, 2]),
("carboxylic_acid", "CC(=O)O", 2, [1, 2, 3]),
("amide", "C(=O)N", 2, [0, 1, 2]),
("acyl_chloride", "CC(=O)[Cl]", 3, [1, 2, 3]),
],
)
def test_get_functional_group(name, smiles, anchor, exp_indices):
Expand All @@ -83,13 +25,28 @@ def test_get_functional_group(name, smiles, anchor, exp_indices):
assert exp_indices == indices


def test_get_functional_groups():
mol = parse("C=O")
groups = get_functional_groups(mol)
print(groups)
assert ("aldehyde", [0, 1]) in groups


def test_get_functional_group_once():
mol = parse("CC(=O)OC")
groups = get_functional_groups(mol)
print(groups)
assert 1 == len(groups)
assert ("ester", [1, 2, 3]) in groups


@pytest.mark.parametrize(
"smiles,functional_groups,exp_indices",
[
pytest.param("C=O", ["aldehyde"], [[0, 1]], id="Formaldehyde"),
pytest.param("C(=O)N", ["amide"], [[0, 1, 2]], id="Formamide"),
pytest.param("NC(=O)CC(N)C(=O)O", ["amide"], [[0, 1, 2]], id="Asparagine"),
pytest.param("CC(=O)[Cl]", ["carbonyl"], [[1, 2]], id="Acetyl cloride"),
pytest.param("CC(=O)[Cl]", ["acyl_chloride"], [[1, 2, 3]], id="Acetyl cloride"),
pytest.param("COC(C)=O", ["ester"], [[1, 2, 4]], id="Methyl acetate"),
pytest.param("CC(=O)O", ["carboxylic_acid"], [[1, 2, 3]], id="Acetic acid"),
pytest.param("NCC(=O)O", ["amine"], [[0]], id="Glycin"),
Expand All @@ -103,12 +60,19 @@ def test_get_functional_group(name, smiles, anchor, exp_indices):
pytest.param(
"CSC(=O)c1ccccc1", ["thioester"], [[1, 2, 3]], id="Methyl thionobenzonat"
),
pytest.param(
"O=C(C)Oc1ccccc1C(=O)O",
["ester", "carboxylic_acid"],
[[0, 1, 3], [10, 11, 12]],
id="Acetylsalicylic acid",
),
# pytest.param("", [""], [[]], id=""),
],
)
def test_functional_group_on_compound(smiles, functional_groups, exp_indices):
assert len(functional_groups) == len(exp_indices)
mol = mol_to_graph(rdmolfiles.MolFromSmiles(smiles))
r = get_functional_groups(mol)
groups = get_functional_groups(mol)
print(groups)
for fg, indices in zip(functional_groups, exp_indices):
_assert_fg(r, fg, indices)
assert (fg, indices) in groups

0 comments on commit 44dfeb4

Please sign in to comment.