diff --git a/README.md b/README.md index b94a156..3fbad3b 100644 --- a/README.md +++ b/README.md @@ -16,20 +16,15 @@ pip install fgutils A simple example querying the functional groups for acetylsalicylic acid. ``` import fgutils -import rdkit.Chem.rdmolfiles as rdmolfiles -from fgutils.utils import mol_to_graph -smiles = "O=C(C)Oc1ccccc1C(=O)O" # acetylsalicylic acid -mol_graph = mol_to_graph(rdmolfiles.MolFromSmiles(smiles)) -index_map, groups = fgutils.get_functional_groups(mol_graph) -print(index_map, groups) +smiles = "O=C(C)Oc1ccccc1C(=O)O" # acetylsalicylic acid +query = fgutils.FGQuery(use_smiles=True) # requires rdkit to be installed +groups = query.get(smiles) +print(groups) ``` -The output is an index map and a list of functional group names. +The output is a list of tuples containing the functional group name and the corresponding atom indices. ``` -{0: [0, 1, 2], 1: [0, 1, 2], 3: [2, 3]} -['carbonyl', 'ketone', 'ester', 'ether'] +[('ester', [0, 1, 3]), ('carboxylic_acid', [10, 11, 12])] ``` - -The index map maps atom numbers to functional groups. The key of the map is the atom index and the value is the list of indices from the functional group list. E.g. atom 0 (oxygen) is in functional groups carbonyl, ketone, and ester. diff --git a/fgutils/__init__.py b/fgutils/__init__.py index e5478d1..b61f754 100644 --- a/fgutils/__init__.py +++ b/fgutils/__init__.py @@ -1 +1 @@ -from .query import get_functional_groups +from .query import get_functional_groups, FGQuery diff --git a/fgutils/fgconfig.py b/fgutils/fgconfig.py index d9ec582..cce63e7 100644 --- a/fgutils/fgconfig.py +++ b/fgutils/fgconfig.py @@ -45,6 +45,7 @@ {"name": "anilin", "pattern": "C:CN(R)R", "group_atoms": [2]}, {"name": "ketene", "pattern": "RC(R)=C=O", "group_atoms": [1, 3, 4]}, {"name": "carbamate", "pattern": "ROC(=O)N(R)R", "group_atoms": [1, 2, 3, 4]}, + {"name": "acyl_chloride", "pattern": "RC(=O)Cl", "group_atoms": [1, 2, 3]}, ] diff --git a/fgutils/query.py b/fgutils/query.py index 8501f89..2c9926f 100644 --- a/fgutils/query.py +++ b/fgutils/query.py @@ -1,5 +1,5 @@ import copy -import collections +import networkx as nx from fgutils.utils import add_implicit_hydrogens from fgutils.mapping import map_pattern @@ -36,38 +36,73 @@ def is_functional_group(graph, index: int, config: FGConfig): return is_fg, sorted(fg_indices) -def get_functional_groups(graph) -> tuple[dict, list[str]]: - def _query(nodes: list[FGTreeNode], graph, idx, checked_groups=[]): - fg_groups = [] - fg_indices = [] - for node in nodes: - if node.fgconfig.name in checked_groups: - continue - is_fg, indices = is_functional_group(graph, idx, node.fgconfig) - if is_fg: - checked_groups.append(node.fgconfig.name) - fg_groups.append(node.fgconfig.name) - fg_indices.append(indices) - _fg_groups, _fg_indices = _query( - node.children, graph, idx, checked_groups - ) - fg_groups.extend(_fg_groups) - fg_indices.extend(_fg_indices) - return fg_groups, fg_indices +def _find_best_node_rec(nodes: list[FGTreeNode], graph, idx): + best_node = None + node_indices = [] + for node in nodes: + is_fg, fg_indices = is_functional_group(graph, idx, node.fgconfig) + if is_fg: + r_node, r_indices = _find_best_node_rec( + node.children, + graph, + idx, + ) + if r_node is None: + best_node = node + node_indices = fg_indices + else: + best_node = r_node + node_indices = r_indices + return best_node, node_indices + +def get_functional_groups(graph) -> list[tuple[str, list[int]]]: fg_candidate_ids = [ n_id for n_id, n_sym in graph.nodes(data="symbol") if n_sym not in ["H", "C"] ] roots = build_FG_tree() - idx_map = collections.defaultdict(lambda: []) groups = [] - for atom_id in fg_candidate_ids: - fg_groups, fg_indices = _query(roots, graph, atom_id) - if len(fg_groups) > 0: - for _group, _indices in zip(fg_groups, fg_indices): - assert atom_id in _indices - _i = len(groups) - groups.append(_group) - for _idx in _indices: - idx_map[_idx].append(_i) - return dict(idx_map), groups + unidentified_ids = [] + while len(fg_candidate_ids) > 0: + atom_id = fg_candidate_ids.pop(0) + node, indices = _find_best_node_rec(roots, graph, atom_id) + if node is None: + unidentified_ids.append(atom_id) + else: + assert atom_id in indices + for i in indices: + if i in fg_candidate_ids: + fg_candidate_ids.remove(i) + elif i in unidentified_ids: + unidentified_ids.remove(i) + groups.append((node.fgconfig.name, indices)) + if len(unidentified_ids) > 0: + raise RuntimeError( + "Could not find a functional group for atom(s) {}.".format( + ["{}@{}".format(graph.nodes[i]["symbol"], i) for i in unidentified_ids] + ) + ) + return groups + + +class FGQuery: + def __init__(self, use_smiles=False): + self.use_smiles = use_smiles + + def get(self, value) -> list[tuple[str, list[int]]]: + mol_graph = None + if isinstance(value, nx.Graph): + mol_graph = value + elif self.use_smiles: + import rdkit.Chem.rdmolfiles as rdmolfiles + from fgutils.rdkit import mol_to_graph + + mol = rdmolfiles.MolFromSmiles(value) + mol_graph = mol_to_graph(mol) + else: + raise ValueError( + "Can not interpret '{}' (type: {}) as mol graph.".format( + value, type(value) + ) + ) + return get_functional_groups(mol_graph) diff --git a/fgutils/rdkit.py b/fgutils/rdkit.py new file mode 100644 index 0000000..8538873 --- /dev/null +++ b/fgutils/rdkit.py @@ -0,0 +1,23 @@ +import networkx as nx +import rdkit.Chem as Chem + + +def mol_to_graph(mol: Chem.rdchem.Mol) -> nx.Graph: + bond_order_map = { + "SINGLE": 1, + "DOUBLE": 2, + "TRIPLE": 3, + "QUADRUPLE": 4, + "AROMATIC": 1.5, + } + g = nx.Graph() + for atom in mol.GetAtoms(): + g.add_node(atom.GetIdx(), symbol=atom.GetSymbol()) + for bond in mol.GetBonds(): + bond_type = str(bond.GetBondType()).split(".")[-1] + g.add_edge( + bond.GetBeginAtomIdx(), + bond.GetEndAtomIdx(), + bond=bond_order_map[bond_type], + ) + return g diff --git a/fgutils/utils.py b/fgutils/utils.py index 5de34c7..f5e94a8 100644 --- a/fgutils/utils.py +++ b/fgutils/utils.py @@ -21,27 +21,6 @@ def print_graph(graph): ) -def mol_to_graph(mol) -> nx.Graph: - bond_order_map = { - "SINGLE": 1, - "DOUBLE": 2, - "TRIPLE": 3, - "QUADRUPLE": 4, - "AROMATIC": 1.5, - } - g = nx.Graph() - for atom in mol.GetAtoms(): - g.add_node(atom.GetIdx(), symbol=atom.GetSymbol()) - for bond in mol.GetBonds(): - bond_type = str(bond.GetBondType()).split(".")[-1] - g.add_edge( - bond.GetBeginAtomIdx(), - bond.GetEndAtomIdx(), - bond=bond_order_map[bond_type], - ) - return g - - def add_implicit_hydrogens(graph: nx.Graph) -> nx.Graph: valence_dict = { 4: ["C", "Si"], @@ -55,14 +34,14 @@ def add_implicit_hydrogens(graph: nx.Graph) -> nx.Graph: valence_table[elmt] = v nodes = [ (n_id, n_sym) - for n_id, n_sym in graph.nodes(data="symbol") + for n_id, n_sym in graph.nodes(data="symbol") # type: ignore if n_sym not in ["R", "H"] ] for n_id, n_sym in nodes: assert ( n_sym in valence_table.keys() ), "Element {} not found in valence table.".format(n_sym) - bond_cnt = sum([b for _, _, b in graph.edges(n_id, data="bond")]) + bond_cnt = sum([b for _, _, b in graph.edges(n_id, data="bond")]) # type: ignore h_cnt = int(8 - valence_table[n_sym] - bond_cnt) assert h_cnt >= 0, "Negative hydrogen count." for h_id in range(len(graph), len(graph) + h_cnt): diff --git a/test/test_parse.py b/test/test_parse.py index 26f3df5..ef9009d 100644 --- a/test/test_parse.py +++ b/test/test_parse.py @@ -12,10 +12,11 @@ def _assert_graph(g, exp_nodes, exp_edges): assert order == g.edges[i1, i2]["bond"] -def test_tokenize(): - def _ct(token, exp_type, exp_value, exp_col): - return token[0] == exp_type and token[1] == exp_value and token[2] == exp_col +def _ct(token, exp_type, exp_value, exp_col): + return token[0] == exp_type and token[1] == exp_value and token[2] == exp_col + +def test_tokenize(): it = tokenize("RC(=O)OR") assert True is _ct(next(it), "WILDCARD", "R", 0) assert True is _ct(next(it), "ATOM", "C", 1) @@ -27,6 +28,13 @@ def _ct(token, exp_type, exp_value, exp_col): assert True is _ct(next(it), "WILDCARD", "R", 7) +def test_tokenize_multichar(): + it = tokenize("RClR") + assert True is _ct(next(it), "WILDCARD", "R", 0) + assert True is _ct(next(it), "ATOM", "Cl", 1) + assert True is _ct(next(it), "WILDCARD", "R", 3) + + def test_branch(): exp_nodes = {0: "R", 1: "C", 2: "O", 3: "O", 4: "R"} exp_edges = [(0, 1, 1), (1, 2, 2), (1, 3, 1), (3, 4, 1)] diff --git a/test/test_query.py b/test/test_query.py index 5515915..5178f94 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -1,69 +1,10 @@ import pytest -import collections import rdkit.Chem.rdmolfiles as rdmolfiles from fgutils.query import get_functional_groups, is_functional_group from fgutils.parse import parse from fgutils.fgconfig import get_FG_by_name -from fgutils.utils import mol_to_graph - - -def _get_group_map(index_map): - group_map = collections.defaultdict(lambda: []) - for idx, group_indices in index_map.items(): - for group_idx in group_indices: - assert idx not in group_map[group_idx] - group_map[group_idx].append(idx) - return dict(group_map) - - -def _assert_fg(raw_result, fg_name, indices): - r_idxmap, r_groups = raw_result - assert ( - indices[0] in r_idxmap.keys() - ), "No functional group found for atom {}.".format(indices[0]) - groups = [r_groups[i] for i in r_idxmap[indices[0]]] - assert ( - fg_name in groups - ), "Could not find functional group '{}' for for index {}.".format( - fg_name, indices[0] - ) - group_i = -1 - for i, g in enumerate(groups): - if g == fg_name: - assert group_i == -1, "Found group multipel times ({}).".format(groups) - group_i = i - group_map = _get_group_map(r_idxmap) - assert len(group_map[group_i]) == len( - indices - ), "Expected group '{}' to have {} atoms but found {}.".format( - fg_name, len(indices), len(group_map[group_i]) - ) - for i in group_map[group_i]: - assert ( - i in indices - ), "Could not find index {} in functional group {} (Indices: {}).".format( - i, fg_name, group_map[group_i] - ) - - -def _assert_not_fg(raw_result, fg_name, indices): - r_idxmap, r_groups = raw_result - for idx in indices: - groups = [r_groups[i] for i in r_idxmap[idx]] - assert ( - fg_name not in groups - ), "Wrongly identified functional group '{}' for for index {}.".format( - fg_name, indices[0] - ) - - -def test_get_functional_groups_raw(): - mol = parse("C=O") - r = get_functional_groups(mol) - _assert_fg(r, "carbonyl", [0, 1]) - _assert_fg(r, "ketone", [0, 1]) - _assert_fg(r, "aldehyde", [0, 1]) +from fgutils.rdkit import mol_to_graph @pytest.mark.parametrize( @@ -72,6 +13,7 @@ def test_get_functional_groups_raw(): ("carbonyl", "CC(=O)O", 2, [1, 2]), ("carboxylic_acid", "CC(=O)O", 2, [1, 2, 3]), ("amide", "C(=O)N", 2, [0, 1, 2]), + ("acyl_chloride", "CC(=O)[Cl]", 3, [1, 2, 3]), ], ) def test_get_functional_group(name, smiles, anchor, exp_indices): @@ -83,13 +25,28 @@ def test_get_functional_group(name, smiles, anchor, exp_indices): assert exp_indices == indices +def test_get_functional_groups(): + mol = parse("C=O") + groups = get_functional_groups(mol) + print(groups) + assert ("aldehyde", [0, 1]) in groups + + +def test_get_functional_group_once(): + mol = parse("CC(=O)OC") + groups = get_functional_groups(mol) + print(groups) + assert 1 == len(groups) + assert ("ester", [1, 2, 3]) in groups + + @pytest.mark.parametrize( "smiles,functional_groups,exp_indices", [ pytest.param("C=O", ["aldehyde"], [[0, 1]], id="Formaldehyde"), pytest.param("C(=O)N", ["amide"], [[0, 1, 2]], id="Formamide"), pytest.param("NC(=O)CC(N)C(=O)O", ["amide"], [[0, 1, 2]], id="Asparagine"), - pytest.param("CC(=O)[Cl]", ["carbonyl"], [[1, 2]], id="Acetyl cloride"), + pytest.param("CC(=O)[Cl]", ["acyl_chloride"], [[1, 2, 3]], id="Acetyl cloride"), pytest.param("COC(C)=O", ["ester"], [[1, 2, 4]], id="Methyl acetate"), pytest.param("CC(=O)O", ["carboxylic_acid"], [[1, 2, 3]], id="Acetic acid"), pytest.param("NCC(=O)O", ["amine"], [[0]], id="Glycin"), @@ -103,12 +60,19 @@ def test_get_functional_group(name, smiles, anchor, exp_indices): pytest.param( "CSC(=O)c1ccccc1", ["thioester"], [[1, 2, 3]], id="Methyl thionobenzonat" ), + pytest.param( + "O=C(C)Oc1ccccc1C(=O)O", + ["ester", "carboxylic_acid"], + [[0, 1, 3], [10, 11, 12]], + id="Acetylsalicylic acid", + ), # pytest.param("", [""], [[]], id=""), ], ) def test_functional_group_on_compound(smiles, functional_groups, exp_indices): assert len(functional_groups) == len(exp_indices) mol = mol_to_graph(rdmolfiles.MolFromSmiles(smiles)) - r = get_functional_groups(mol) + groups = get_functional_groups(mol) + print(groups) for fg, indices in zip(functional_groups, exp_indices): - _assert_fg(r, fg, indices) + assert (fg, indices) in groups