diff --git a/README.md b/README.md index a0f0603..4bbe250 100644 --- a/README.md +++ b/README.md @@ -15,21 +15,12 @@ pip install fgutils ## Getting Started A simple example querying the functional groups for acetylsalicylic acid. ``` -import fgutils -import rdkit.Chem.rdmolfiles as rdmolfiles -from fgutils.utils import mol_to_graph - -smiles = "O=C(C)Oc1ccccc1C(=O)O" # acetylsalicylic acid -mol_graph = mol_to_graph(rdmolfiles.MolFromSmiles(smiles)) -index_map, groups = fgutils.get_functional_groups_raw(mol_graph) -print(index_map, groups) -``` - -The output is an index map and a list of functional group names. - -``` -{0: [0, 1, 2], 1: [0, 1, 2], 3: [2, 3]} -['carbonyl', 'ketone', 'ester', 'ether'] +>>> from fgutils import FGQuery +>>> +>>> smiles = "O=C(C)Oc1ccccc1C(=O)O" # acetylsalicylic acid +>>> query = FGQuery(use_smiles=True) # use_smiles requires rdkit to be installed +>>> query.get(smiles) +[('ester', [0, 1, 3]), ('carboxylic_acid', [10, 11, 12])] ``` -The index map maps atom numbers to functional groups. The key of the map is the atom index and the value is the list of indices from the functional group list. E.g. atom 0 (oxygen) is in functional groups carbonyl, ketone, and ester. +The output is a list of tuples containing the functional group name and the corresponding atom indices. diff --git a/fgutils/__init__.py b/fgutils/__init__.py index d35b870..0201594 100644 --- a/fgutils/__init__.py +++ b/fgutils/__init__.py @@ -1 +1,2 @@ -from .query import get_functional_groups_raw +from .permutation import PermutationMapper +from .query import FGQuery diff --git a/fgutils/fgconfig.py b/fgutils/fgconfig.py index d9ec582..4a5b255 100644 --- a/fgutils/fgconfig.py +++ b/fgutils/fgconfig.py @@ -1,10 +1,11 @@ from __future__ import annotations import numpy as np +from fgutils.permutation import PermutationMapper from fgutils.parse import parse -from fgutils.mapping import map_pattern +from fgutils.mapping import map_to_entire_graph -functional_group_config = [ +_default_fg_config = [ { "name": "carbonyl", "pattern": "C(=O)", @@ -45,6 +46,7 @@ {"name": "anilin", "pattern": "C:CN(R)R", "group_atoms": [2]}, {"name": "ketene", "pattern": "RC(R)=C=O", "group_atoms": [1, 3, 4]}, {"name": "carbamate", "pattern": "ROC(=O)N(R)R", "group_atoms": [1, 2, 3, 4]}, + {"name": "acyl_chloride", "pattern": "RC(=O)Cl", "group_atoms": [1, 2, 3]}, ] @@ -100,9 +102,9 @@ def pattern_len(self) -> int: ) -def is_subgroup(parent: FGConfig, child: FGConfig) -> bool: - p2c = map_full(child.pattern, parent.pattern) - c2p = map_full(parent.pattern, child.pattern) +def is_subgroup(parent: FGConfig, child: FGConfig, mapper: PermutationMapper) -> bool: + p2c = map_to_entire_graph(child.pattern, parent.pattern, mapper) + c2p = map_to_entire_graph(parent.pattern, child.pattern, mapper) if p2c: assert c2p is False, "{} ({}) -> {} ({}) matches in both directions.".format( parent.name, parent.pattern_str, child.name, child.pattern_str @@ -111,26 +113,11 @@ def is_subgroup(parent: FGConfig, child: FGConfig) -> bool: return False -class TreeNode: - def __init__(self, is_child_callback): - self.parents: list[TreeNode] = [] - self.children: list[TreeNode] = [] - self.is_child_callback = is_child_callback - - def is_child(self, parent: TreeNode) -> bool: - return self.is_child_callback(parent, self) - - def add_child(self, child: TreeNode): - child.parents.append(self) - self.children.append(child) - - -class FGTreeNode(TreeNode): +class FGTreeNode: def __init__(self, fgconfig: FGConfig): self.fgconfig = fgconfig - self.parents: list[FGTreeNode] - self.children: list[FGTreeNode] - super().__init__(lambda a, b: is_subgroup(a.fgconfig, b.fgconfig)) + self.parents: list[FGTreeNode] = [] + self.children: list[FGTreeNode] = [] def order_id(self): return ( @@ -140,35 +127,12 @@ def order_id(self): ) def add_child(self, child: FGTreeNode): - super().add_child(child) + child.parents.append(self) + self.children.append(child) self.parents = sorted(self.parents, key=lambda x: x.order_id(), reverse=True) self.children = sorted(self.children, key=lambda x: x.order_id(), reverse=True) -fg_configs = None - - -def get_FG_list() -> list[FGConfig]: - global fg_configs - if fg_configs is None: - c = [] - for fgc in functional_group_config: - c.append(FGConfig(**fgc)) - fg_configs = c - return fg_configs - - -def get_FG_by_name(name: str) -> FGConfig: - for fg in get_FG_list(): - if fg.name == name: - return fg - raise KeyError("No functional group config with name '{}' found.".format(name)) - - -def get_FG_names() -> list[str]: - return [c.name for c in get_FG_list()] - - def sort_by_pattern_len(configs: list[FGConfig], reverse=False) -> list[FGConfig]: return list( sorted( @@ -179,19 +143,13 @@ def sort_by_pattern_len(configs: list[FGConfig], reverse=False) -> list[FGConfig ) -def map_full(graph, pattern): - for i in range(len(graph)): - r, _ = map_pattern(graph, i, pattern) - if r is True: - return True - return False - - -def search_parents(roots: list[TreeNode], child: TreeNode) -> None | list[TreeNode]: +def search_parents( + roots: list[FGTreeNode], child: FGTreeNode, mapper: PermutationMapper +) -> None | list[FGTreeNode]: parents = set() for root in roots: - if child.is_child(root): - _parents = search_parents(root.children, child) + if is_subgroup(root.fgconfig, child.fgconfig, mapper): + _parents = search_parents(root.children, child, mapper) if _parents is None: parents.add(root) else: @@ -221,11 +179,13 @@ def _print(node: FGTreeNode, indent=0): _print(root) -def build_config_tree_from_list(config_list: list[FGConfig]) -> list[FGTreeNode]: +def build_config_tree_from_list( + config_list: list[FGConfig], mapper: PermutationMapper +) -> list[FGTreeNode]: roots = [] for config in sort_by_pattern_len(config_list): node = FGTreeNode(config) - parents = search_parents(roots, node) + parents = search_parents(roots, node, mapper) if parents is None: roots.append(node) else: @@ -234,11 +194,43 @@ def build_config_tree_from_list(config_list: list[FGConfig]) -> list[FGTreeNode] return roots -_fg_tree_roots = None +class FGConfigProvider: + def __init__( + self, + config: list[dict] | list[FGConfig] | None = None, + mapper: PermutationMapper | None = None, + ): + self.config_list: list[FGConfig] = [] + if config is None: + config = _default_fg_config + if isinstance(config, list) and len(config) > 0: + if isinstance(config[0], dict): + for fgc in config: + self.config_list.append(FGConfig(**fgc)) # type: ignore + elif isinstance(config[0], FGConfig): + self.config_list = config # type: ignore + else: + raise ValueError("Invalid config value.") + else: + raise ValueError("Invalid config value.") + + self.mapper = ( + mapper + if mapper is not None + else PermutationMapper(wildcard="R", ignore_case=True) + ) + + self.__tree_roots = None + def get_tree(self) -> list[FGTreeNode]: + if self.__tree_roots is None: + self.__tree_roots = build_config_tree_from_list( + self.config_list, self.mapper + ) + return self.__tree_roots -def build_FG_tree() -> list[FGTreeNode]: - global _fg_tree_roots - if _fg_tree_roots is None: - _fg_tree_roots = build_config_tree_from_list(get_FG_list()) - return _fg_tree_roots + def get_by_name(self, name: str) -> FGConfig: + for fg in self.config_list: + if fg.name == name: + return fg + raise KeyError("No functional group config with name '{}' found.".format(name)) diff --git a/fgutils/mapping.py b/fgutils/mapping.py index fa18049..f7264cb 100644 --- a/fgutils/mapping.py +++ b/fgutils/mapping.py @@ -1,7 +1,7 @@ import copy import networkx as nx -from fgutils.permutation import Mapper +from fgutils.permutation import PermutationMapper def _get_neighbors(graph, idx, excluded_nodes=set()): @@ -17,10 +17,12 @@ def _get_symbol(graph, idx): def map_anchored_pattern( - graph: nx.Graph, anchor: int, pattern: nx.Graph, pattern_anchor: int + graph: nx.Graph, + anchor: int, + pattern: nx.Graph, + pattern_anchor: int, + mapper: PermutationMapper, ): - mapper = Mapper(wildcard="R", ignore_case=True) - def _fit(idx, pidx, visited_nodes=set(), visited_pnodes=set(), indent=0): visited_nodes = copy.deepcopy(visited_nodes) visited_nodes.add(idx) @@ -90,15 +92,27 @@ def _fit(idx, pidx, visited_nodes=set(), visited_pnodes=set(), indent=0): def map_pattern( - graph: nx.Graph, anchor: int, pattern: nx.Graph, pattern_anchor: None | int = None + graph: nx.Graph, + anchor: int, + pattern: nx.Graph, + mapper: PermutationMapper, + pattern_anchor: None | int = None, ): if pattern_anchor is None: if len(pattern) == 0: return True, [] for pidx in pattern.nodes: - result = map_anchored_pattern(graph, anchor, pattern, pidx) + result = map_anchored_pattern(graph, anchor, pattern, pidx, mapper) if result[0]: return result return False, [] else: - return map_anchored_pattern(graph, anchor, pattern, pattern_anchor) + return map_anchored_pattern(graph, anchor, pattern, pattern_anchor, mapper) + + +def map_to_entire_graph(graph: nx.Graph, pattern: nx.Graph, mapper: PermutationMapper): + for i in range(len(graph)): + r, _ = map_pattern(graph, i, pattern, mapper) + if r is True: + return True + return False diff --git a/fgutils/permutation.py b/fgutils/permutation.py index a8a83ab..6621ae6 100644 --- a/fgutils/permutation.py +++ b/fgutils/permutation.py @@ -21,7 +21,7 @@ def generate_mapping_permutations(pattern, structure, wildcard=None): return mappings -class Mapper: +class PermutationMapper: def __init__(self, wildcard=None, ignore_case=False, can_map_to_nothing=[]): self.wildcard = wildcard self.ignore_case = ignore_case diff --git a/fgutils/query.py b/fgutils/query.py index 91902e7..b31502c 100644 --- a/fgutils/query.py +++ b/fgutils/query.py @@ -1,16 +1,17 @@ import copy -import collections +import networkx as nx from fgutils.utils import add_implicit_hydrogens +from fgutils.permutation import PermutationMapper from fgutils.mapping import map_pattern -from fgutils.fgconfig import FGConfig, build_FG_tree, FGTreeNode +from fgutils.fgconfig import FGConfig, FGConfigProvider, FGTreeNode -def is_functional_group(graph, index: int, config: FGConfig): +def is_functional_group(graph, index: int, config: FGConfig, mapper: PermutationMapper): max_id = len(graph) graph = add_implicit_hydrogens(copy.deepcopy(graph)) - is_fg, mapping = map_pattern(graph, index, config.pattern) + is_fg, mapping = map_pattern(graph, index, config.pattern, mapper) fg_indices = [] if is_fg: fg_indices = [ @@ -27,47 +28,102 @@ def is_functional_group(graph, index: int, config: FGConfig): key=lambda x: x[1], reverse=True, ): - if not is_fg: - break if last_len > apattern_size: last_len = apattern_size - is_match, _ = map_pattern(graph, index, apattern) + is_match, _ = map_pattern(graph, index, apattern, mapper) is_fg = is_fg and not is_match + if not is_fg: + break return is_fg, sorted(fg_indices) -def get_functional_groups_raw(graph) -> tuple[dict, list[str]]: - def _query(nodes: list[FGTreeNode], graph, idx, checked_groups=[]): - fg_groups = [] - fg_indices = [] +class FGQuery: + def __init__( + self, + use_smiles=False, + mapper: PermutationMapper | None = None, + config_provider: FGConfigProvider | None = None, + ): + self.use_smiles = use_smiles + self.mapper = ( + mapper + if mapper is not None + else PermutationMapper(wildcard="R", ignore_case=True) + ) + self.config_provider = ( + config_provider + if config_provider is not None + else FGConfigProvider(mapper=self.mapper) + ) + + def __find_best_node_rec(self, nodes: list[FGTreeNode], graph, idx): + best_node = None + node_indices = [] for node in nodes: - if node.fgconfig.name in checked_groups: - continue - is_fg, indices = is_functional_group(graph, idx, node.fgconfig) + is_fg, fg_indices = is_functional_group( + graph, idx, node.fgconfig, mapper=self.mapper + ) if is_fg: - checked_groups.append(node.fgconfig.name) - fg_groups.append(node.fgconfig.name) - fg_indices.append(indices) - _fg_groups, _fg_indices = _query( - node.children, graph, idx, checked_groups + r_node, r_indices = self.__find_best_node_rec( + node.children, + graph, + idx, ) - fg_groups.extend(_fg_groups) - fg_indices.extend(_fg_indices) - return fg_groups, fg_indices + if r_node is None: + best_node = node + node_indices = fg_indices + else: + best_node = r_node + node_indices = r_indices + return best_node, node_indices - fg_candidate_ids = [ - n_id for n_id, n_sym in graph.nodes(data="symbol") if n_sym not in ["H", "C"] - ] - roots = build_FG_tree() - idx_map = collections.defaultdict(lambda: []) - groups = [] - for atom_id in fg_candidate_ids: - fg_groups, fg_indices = _query(roots, graph, atom_id) - if len(fg_groups) > 0: - for _group, _indices in zip(fg_groups, fg_indices): - assert atom_id in _indices - _i = len(groups) - groups.append(_group) - for _idx in _indices: - idx_map[_idx].append(_i) - return dict(idx_map), groups + def __get_functional_groups(self, graph: nx.Graph) -> list[tuple[str, list[int]]]: + fg_candidate_ids = [ + n_id + for n_id, n_sym in graph.nodes(data="symbol") # type: ignore + if n_sym not in ["H", "C"] + ] + roots = self.config_provider.get_tree() + groups = [] + unidentified_ids = [] + while len(fg_candidate_ids) > 0: + atom_id = fg_candidate_ids.pop(0) + node, indices = self.__find_best_node_rec(roots, graph, atom_id) + if node is None: + unidentified_ids.append(atom_id) + else: + assert atom_id in indices + for i in indices: + if i in fg_candidate_ids: + fg_candidate_ids.remove(i) + elif i in unidentified_ids: + unidentified_ids.remove(i) + groups.append((node.fgconfig.name, indices)) + if len(unidentified_ids) > 0: + raise RuntimeError( + "Could not find a functional group for atom(s) {}.".format( + [ + "{}@{}".format(graph.nodes[i]["symbol"], i) + for i in unidentified_ids + ] + ) + ) + return groups + + def get(self, value) -> list[tuple[str, list[int]]]: + mol_graph = None + if isinstance(value, nx.Graph): + mol_graph = value + elif self.use_smiles: + import rdkit.Chem.rdmolfiles as rdmolfiles + from fgutils.rdkit import mol_to_graph + + mol = rdmolfiles.MolFromSmiles(value) + mol_graph = mol_to_graph(mol) + else: + raise ValueError( + "Can not interpret '{}' (type: {}) as mol graph.".format( + value, type(value) + ) + ) + return self.__get_functional_groups(mol_graph) # type: ignore diff --git a/fgutils/rdkit.py b/fgutils/rdkit.py new file mode 100644 index 0000000..8538873 --- /dev/null +++ b/fgutils/rdkit.py @@ -0,0 +1,23 @@ +import networkx as nx +import rdkit.Chem as Chem + + +def mol_to_graph(mol: Chem.rdchem.Mol) -> nx.Graph: + bond_order_map = { + "SINGLE": 1, + "DOUBLE": 2, + "TRIPLE": 3, + "QUADRUPLE": 4, + "AROMATIC": 1.5, + } + g = nx.Graph() + for atom in mol.GetAtoms(): + g.add_node(atom.GetIdx(), symbol=atom.GetSymbol()) + for bond in mol.GetBonds(): + bond_type = str(bond.GetBondType()).split(".")[-1] + g.add_edge( + bond.GetBeginAtomIdx(), + bond.GetEndAtomIdx(), + bond=bond_order_map[bond_type], + ) + return g diff --git a/fgutils/utils.py b/fgutils/utils.py index 5de34c7..f5e94a8 100644 --- a/fgutils/utils.py +++ b/fgutils/utils.py @@ -21,27 +21,6 @@ def print_graph(graph): ) -def mol_to_graph(mol) -> nx.Graph: - bond_order_map = { - "SINGLE": 1, - "DOUBLE": 2, - "TRIPLE": 3, - "QUADRUPLE": 4, - "AROMATIC": 1.5, - } - g = nx.Graph() - for atom in mol.GetAtoms(): - g.add_node(atom.GetIdx(), symbol=atom.GetSymbol()) - for bond in mol.GetBonds(): - bond_type = str(bond.GetBondType()).split(".")[-1] - g.add_edge( - bond.GetBeginAtomIdx(), - bond.GetEndAtomIdx(), - bond=bond_order_map[bond_type], - ) - return g - - def add_implicit_hydrogens(graph: nx.Graph) -> nx.Graph: valence_dict = { 4: ["C", "Si"], @@ -55,14 +34,14 @@ def add_implicit_hydrogens(graph: nx.Graph) -> nx.Graph: valence_table[elmt] = v nodes = [ (n_id, n_sym) - for n_id, n_sym in graph.nodes(data="symbol") + for n_id, n_sym in graph.nodes(data="symbol") # type: ignore if n_sym not in ["R", "H"] ] for n_id, n_sym in nodes: assert ( n_sym in valence_table.keys() ), "Element {} not found in valence table.".format(n_sym) - bond_cnt = sum([b for _, _, b in graph.edges(n_id, data="bond")]) + bond_cnt = sum([b for _, _, b in graph.edges(n_id, data="bond")]) # type: ignore h_cnt = int(8 - valence_table[n_sym] - bond_cnt) assert h_cnt >= 0, "Negative hydrogen count." for h_id in range(len(graph), len(graph) + h_cnt): diff --git a/lint.sh b/lint.sh new file mode 100755 index 0000000..757ab62 --- /dev/null +++ b/lint.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics +flake8 . --count --exit-zero --max-complexity=13 --max-line-length=127 --per-file-ignores="__init__.py:F401" --statistics diff --git a/pyproject.toml b/pyproject.toml index b45b071..429f542 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "fgutils" -version = "0.0.4" +version = "0.0.6" authors = [{name="Klaus Weinbauer", email="klaus@bioinf.uni-leipzig.de"}] description = "Library to get functional groups from molecular graphs." readme = "README.md" diff --git a/test/test_fgconfig.py b/test/test_fgconfig.py index ca4a8a2..67e5719 100644 --- a/test/test_fgconfig.py +++ b/test/test_fgconfig.py @@ -1,16 +1,15 @@ import pytest import networkx as nx +from fgutils.mapping import map_to_entire_graph +from fgutils.permutation import PermutationMapper from fgutils.fgconfig import ( + FGConfigProvider, FGConfig, FGTreeNode, search_parents, build_config_tree_from_list, - map_full, - build_FG_tree, - get_FG_by_name, - get_FG_list, - functional_group_config, + _default_fg_config, ) @@ -35,17 +34,20 @@ def _init_fgnode(name, pattern) -> FGTreeNode: return FGTreeNode(FGConfig(name=name, pattern=pattern)) +default_mapper = PermutationMapper(wildcard="R", ignore_case=True) + + def test_search_parent(): fg1 = _init_fgnode("1", "RC") fg2 = _init_fgnode("2", "RCR") - parents = search_parents([fg1], fg2) + parents = search_parents([fg1], fg2, mapper=default_mapper) assert parents == [fg1] def test_get_no_parent(): fg1 = _init_fgnode("1", "RO") fg2 = _init_fgnode("2", "RC") - parents = search_parents([fg1], fg2) + parents = search_parents([fg1], fg2, mapper=default_mapper) assert parents is None @@ -53,7 +55,7 @@ def test_get_correct_parent(): fg1 = _init_fgnode("1", "RC") fg2 = _init_fgnode("2", "RO") fg3 = _init_fgnode("3", "ROR") - parents = search_parents([fg1, fg2], fg3) + parents = search_parents([fg1, fg2], fg3, mapper=default_mapper) assert parents == [fg2] @@ -61,7 +63,7 @@ def test_get_multiple_parents(): fg1 = _init_fgnode("1", "RC") fg2 = _init_fgnode("2", "RO") fg3 = _init_fgnode("3", "RCO") - parents = search_parents([fg1, fg2], fg3) + parents = search_parents([fg1, fg2], fg3, mapper=default_mapper) assert parents is not None assert 2 == len(parents) assert all([fg in parents for fg in [fg1, fg2]]) @@ -75,7 +77,7 @@ def test_get_multiple_unique_parents(): fg11.children = [fg2] fg12.children = [fg2] fg2.parents = [fg11, fg12] - parents = search_parents([fg11, fg12], fg3) + parents = search_parents([fg11, fg12], fg3, mapper=default_mapper) assert parents == [fg2] @@ -85,7 +87,7 @@ def test_get_parent_recursive(): fg3 = _init_fgnode("3", "RCO") fg1.children = [fg2] fg2.parents = [fg1] - parents = search_parents([fg1], fg3) + parents = search_parents([fg1], fg3, mapper=default_mapper) assert parents == [fg2] @@ -149,7 +151,7 @@ def test_insert_child_between(): fg1 = FGConfig(name="1", pattern="RC") fg2 = FGConfig(name="2", pattern="RCR") fg3 = FGConfig(name="3", pattern="RCOH") - tree = build_config_tree_from_list([fg1, fg3, fg2]) + tree = build_config_tree_from_list([fg1, fg3, fg2], mapper=default_mapper) _assert_structure(tree, fg1, [], fg2) _assert_structure(tree, fg2, fg1, fg3) _assert_structure(tree, fg3, fg2) @@ -159,7 +161,7 @@ def test_insert_child_after(): fg1 = FGConfig(name="1", pattern="RC") fg2 = FGConfig(name="2", pattern="RCR") fg3 = FGConfig(name="3", pattern="RCOH") - tree = build_config_tree_from_list([fg1, fg2, fg3]) + tree = build_config_tree_from_list([fg1, fg2, fg3], mapper=default_mapper) _assert_structure(tree, fg1, [], fg2) _assert_structure(tree, fg2, fg1, fg3) _assert_structure(tree, fg3, fg2) @@ -169,7 +171,7 @@ def test_insert_new_root(): fg1 = FGConfig(name="1", pattern="RC") fg2 = FGConfig(name="2", pattern="RCR") fg3 = FGConfig(name="3", pattern="RCOH") - tree = build_config_tree_from_list([fg2, fg3, fg1]) + tree = build_config_tree_from_list([fg2, fg3, fg1], mapper=default_mapper) _assert_structure(tree, fg1, [], fg2) _assert_structure(tree, fg2, fg1, fg3) _assert_structure(tree, fg3, fg2) @@ -180,7 +182,7 @@ def test_insert_child_in_between_multiple(): fg2 = FGConfig(name="2", pattern="RCOR") fg31 = FGConfig(name="31", pattern="RCOH") fg32 = FGConfig(name="32", pattern="RC=O") - tree = build_config_tree_from_list([fg1, fg31, fg32, fg2]) + tree = build_config_tree_from_list([fg1, fg31, fg32, fg2], mapper=default_mapper) _assert_structure(tree, fg1, [], [fg2, fg32]) _assert_structure(tree, fg2, fg1, fg31) _assert_structure(tree, fg31, fg2) @@ -193,7 +195,9 @@ def test_insert_child_in_between_multiple_2(): fg31 = FGConfig(name="31", pattern="RCCCR") fg32 = FGConfig(name="32", pattern="RCOCR") fg4 = FGConfig(name="4", pattern="RCCCCR") - tree = build_config_tree_from_list([fg1, fg31, fg32, fg4, fg2]) + tree = build_config_tree_from_list( + [fg1, fg31, fg32, fg4, fg2], mapper=default_mapper + ) _assert_structure(tree, fg1, [], [fg2]) _assert_structure(tree, fg2, fg1, [fg31, fg32]) _assert_structure(tree, fg31, fg2, fg4) @@ -204,7 +208,7 @@ def test_multiple_parents(): fg1 = FGConfig(name="1", pattern="RC") fg2 = FGConfig(name="2", pattern="RO") fg3 = FGConfig(name="3", pattern="RCOH") - tree = build_config_tree_from_list([fg1, fg2, fg3]) + tree = build_config_tree_from_list([fg1, fg2, fg3], mapper=default_mapper) _assert_structure(tree, fg1, [], fg3) _assert_structure(tree, fg2, [], fg3) _assert_structure(tree, fg3, [fg1, fg2]) @@ -215,15 +219,18 @@ def _check_fg(node: FGTreeNode): for c in node.children: print("Test {} -> {}.".format(node.fgconfig.name, c.fgconfig.name)) assert node.fgconfig.pattern_len <= c.fgconfig.pattern_len - assert True is map_full(c.fgconfig.pattern, node.fgconfig.pattern) - assert False is map_full( - node.fgconfig.pattern, c.fgconfig.pattern + assert True is map_to_entire_graph( + c.fgconfig.pattern, node.fgconfig.pattern, mapper=default_mapper + ) + assert False is map_to_entire_graph( + node.fgconfig.pattern, c.fgconfig.pattern, mapper=default_mapper ), "Parent pattern {} contains child pattern {}.".format( node.fgconfig.pattern_str, c.fgconfig.pattern_str ) _check_fg(c) - for root_fg in build_FG_tree(): + provider = FGConfigProvider() + for root_fg in provider.get_tree(): _check_fg(root_fg) @@ -232,22 +239,25 @@ def _check_fg(node: FGTreeNode): [("carbonyl", 2), ("aldehyde", 3), ("ketone", 2), ("carboxylic_acid", 4)], ) def test_pattern_len(fg_group, exp_pattern_len): - fg = get_FG_by_name(fg_group) + provider = FGConfigProvider() + fg = provider.get_by_name(fg_group) assert exp_pattern_len == fg.pattern_len def test_config_name_uniqueness(): name_list = [] - for fg in get_FG_list(): + provider = FGConfigProvider() + for fg in provider.config_list: assert fg.name not in name_list, "Config name '{}' already exists.".format( fg.name ) name_list.append(fg.name) - assert len(functional_group_config) == len(name_list) + assert len(_default_fg_config) == len(name_list) def test_config_pattern_validity(): - for c in get_FG_list(): + provider = FGConfigProvider() + for c in provider.config_list: valid = False for _, sym in c.pattern.nodes(data="symbol"): # type: ignore if sym != "C": @@ -258,14 +268,15 @@ def test_config_pattern_validity(): def test_config_pattern_uniqueness(): pattern_list = [] - for fg in get_FG_list(): + provider = FGConfigProvider() + for fg in provider.config_list: assert ( fg.pattern_str not in pattern_list ), "Config pattern '{}' already exists with name '{}'.".format( fg.pattern_str, fg.name ) pattern_list.append(fg.pattern_str) - assert len(functional_group_config) == len(pattern_list) + assert len(_default_fg_config) == len(pattern_list) # def test_build_tree(): diff --git a/test/test_mapping.py b/test/test_mapping.py index 863db07..ee0695e 100644 --- a/test/test_mapping.py +++ b/test/test_mapping.py @@ -1,6 +1,9 @@ +from fgutils.permutation import PermutationMapper from fgutils.parse import parse from fgutils.mapping import map_anchored_pattern, map_pattern +default_mapper = PermutationMapper(wildcard="R", ignore_case=True) + def _assert_mapping(mapping, valid, exp_mapping=[]): assert mapping[0] == valid @@ -12,7 +15,7 @@ def test_simple_match(): exp_mapping = [(1, 0), (2, 1)] g = parse("CCO") p = parse("RO") - m = map_anchored_pattern(g, 2, p, 1) + m = map_anchored_pattern(g, 2, p, 1, mapper=default_mapper) _assert_mapping(m, True, exp_mapping) @@ -20,7 +23,7 @@ def test_branched_match(): exp_mapping = [(0, 0), (1, 1), (2, 2), (3, 3)] g = parse("CC(=O)O") p = parse("RC(=O)O") - m = map_anchored_pattern(g, 2, p, 2) + m = map_anchored_pattern(g, 2, p, 2, mapper=default_mapper) _assert_mapping(m, True, exp_mapping) @@ -28,21 +31,21 @@ def test_ring_match(): exp_mapping = [(0, 2), (1, 1), (2, 0)] g = parse("C1CO1") p = parse("R1CC1") - m = map_anchored_pattern(g, 1, p, 1) + m = map_anchored_pattern(g, 1, p, 1, mapper=default_mapper) _assert_mapping(m, True, exp_mapping) def test_not_match(): g = parse("CC=O") p = parse("RC(=O)NR") - m = map_anchored_pattern(g, 2, p, 2) + m = map_anchored_pattern(g, 2, p, 2, mapper=default_mapper) _assert_mapping(m, False) def test_1(): g = parse("CC=O") p = parse("RC(=O)R") - m = map_anchored_pattern(g, 0, p, 3) + m = map_anchored_pattern(g, 0, p, 3, mapper=default_mapper) _assert_mapping(m, False) @@ -50,7 +53,7 @@ def test_2(): exp_mapping = [(0, 0), (1, 1), (2, 2)] g = parse("CC=O") p = parse("RC=O") - m = map_anchored_pattern(g, 2, p, 2) + m = map_anchored_pattern(g, 2, p, 2, mapper=default_mapper) _assert_mapping(m, True, exp_mapping) @@ -58,7 +61,7 @@ def test_ignore_aromaticity(): exp_mapping = [(1, 0), (2, 1)] g = parse("c1c(=O)cccc1") p = parse("C=O") - m = map_anchored_pattern(g, 2, p, 1) + m = map_anchored_pattern(g, 2, p, 1, mapper=default_mapper) _assert_mapping(m, True, exp_mapping) @@ -66,7 +69,7 @@ def test_3(): exp_mapping = [(0, 4), (1, 3), (2, 1), (4, 2), (3, 0)] g = parse("COC(C)=O") p = parse("RC(=O)OR") - m = map_anchored_pattern(g, 4, p, 2) + m = map_anchored_pattern(g, 4, p, 2, mapper=default_mapper) _assert_mapping(m, True, exp_mapping) @@ -74,7 +77,7 @@ def test_explore_wrong_branch(): exp_mapping = [(0, 2), (1, 1), (2, 0), (3, 3)] g = parse("COCO") p = parse("C(OR)O") - m = map_anchored_pattern(g, 1, p, 1) + m = map_anchored_pattern(g, 1, p, 1, mapper=default_mapper) _assert_mapping(m, True, exp_mapping) @@ -82,7 +85,7 @@ def test_match_pattern_to_mol(): exp_mapping = [(0, 2), (1, 0), (2, 1)] g = parse("NC(=O)C") p = parse("C(=O)N") - m = map_anchored_pattern(g, 2, p, 1) + m = map_anchored_pattern(g, 2, p, 1, mapper=default_mapper) _assert_mapping(m, True, exp_mapping) @@ -90,28 +93,37 @@ def test_match_hydrogen(): # H must be explicit g = parse("C=O") p = parse("C(H)=O") - m = map_anchored_pattern(g, 1, p, 2) + m = map_anchored_pattern(g, 1, p, 2, mapper=default_mapper) _assert_mapping(m, False) +def test_match_implicit_hydrogen(): + exp_mapping = [(0, 0), (1, 2)] + g = parse("C=O") + p = parse("C(H)=O") + mapper = PermutationMapper(can_map_to_nothing=["H"]) + m = map_anchored_pattern(g, 1, p, 2, mapper=mapper) + _assert_mapping(m, True, exp_mapping) + + def test_invalid_bond_match(): g = parse("C=O") p = parse("CO") - m = map_anchored_pattern(g, 0, p, 0) + m = map_anchored_pattern(g, 0, p, 0, mapper=default_mapper) _assert_mapping(m, False) def test_match_not_entire_pattern(): g = parse("C=O") p = parse("C(=O)C") - m = map_anchored_pattern(g, 0, p, 0) + m = map_anchored_pattern(g, 0, p, 0, mapper=default_mapper) _assert_mapping(m, False) def test_start_with_match_to_nothing(): g = parse("CCO") p = parse("HO") - m = map_anchored_pattern(g, 2, p, 0) + m = map_anchored_pattern(g, 2, p, 0, mapper=default_mapper) _assert_mapping(m, False) @@ -119,7 +131,7 @@ def test_match_explicit_hydrogen(): exp_mapping = [(2, 1), (3, 0)] g = parse("CCOH") p = parse("HO") - m = map_anchored_pattern(g, 2, p, 1) + m = map_anchored_pattern(g, 2, p, 1, mapper=default_mapper) _assert_mapping(m, True, exp_mapping) @@ -127,7 +139,7 @@ def test_map_pattern_with_anchor(): exp_mapping = [(2, 1), (1, 0)] g = parse("CCO") p = parse("CO") - m = map_pattern(g, 2, p, pattern_anchor=1) + m = map_pattern(g, 2, p, pattern_anchor=1, mapper=default_mapper) _assert_mapping(m, True, exp_mapping) @@ -135,7 +147,7 @@ def test_map_pattern_without_anchor(): exp_mapping = [(2, 1), (1, 0)] g = parse("CCO") p = parse("CO") - m = map_pattern(g, 2, p) + m = map_pattern(g, 2, p, mapper=default_mapper) _assert_mapping(m, True, exp_mapping) @@ -143,19 +155,19 @@ def test_map_empty_pattern(): exp_mapping = [] g = parse("CCO") p = parse("") - m = map_pattern(g, 2, p) + m = map_pattern(g, 2, p, mapper=default_mapper) _assert_mapping(m, True, exp_mapping) def test_map_invalid_pattern(): g = parse("CCO") p = parse("Cl") - m = map_pattern(g, 2, p) + m = map_pattern(g, 2, p, mapper=default_mapper) _assert_mapping(m, False) def test_map_specific_pattern_to_general_graph(): g = parse("R") p = parse("C") - m = map_pattern(g, 0, p) + m = map_pattern(g, 0, p, mapper=default_mapper) _assert_mapping(m, False) diff --git a/test/test_parse.py b/test/test_parse.py index 26f3df5..ef9009d 100644 --- a/test/test_parse.py +++ b/test/test_parse.py @@ -12,10 +12,11 @@ def _assert_graph(g, exp_nodes, exp_edges): assert order == g.edges[i1, i2]["bond"] -def test_tokenize(): - def _ct(token, exp_type, exp_value, exp_col): - return token[0] == exp_type and token[1] == exp_value and token[2] == exp_col +def _ct(token, exp_type, exp_value, exp_col): + return token[0] == exp_type and token[1] == exp_value and token[2] == exp_col + +def test_tokenize(): it = tokenize("RC(=O)OR") assert True is _ct(next(it), "WILDCARD", "R", 0) assert True is _ct(next(it), "ATOM", "C", 1) @@ -27,6 +28,13 @@ def _ct(token, exp_type, exp_value, exp_col): assert True is _ct(next(it), "WILDCARD", "R", 7) +def test_tokenize_multichar(): + it = tokenize("RClR") + assert True is _ct(next(it), "WILDCARD", "R", 0) + assert True is _ct(next(it), "ATOM", "Cl", 1) + assert True is _ct(next(it), "WILDCARD", "R", 3) + + def test_branch(): exp_nodes = {0: "R", 1: "C", 2: "O", 3: "O", 4: "R"} exp_edges = [(0, 1, 1), (1, 2, 2), (1, 3, 1), (3, 4, 1)] diff --git a/test/test_permutation.py b/test/test_permutation.py index 15fc7b8..9ed8fa5 100644 --- a/test/test_permutation.py +++ b/test/test_permutation.py @@ -1,7 +1,7 @@ import pytest import copy -from fgutils.permutation import generate_mapping_permutations, Mapper +from fgutils.permutation import generate_mapping_permutations, PermutationMapper @pytest.mark.parametrize( @@ -113,7 +113,7 @@ def test_single_wildcard(pattern, structure, exp_mapping): ], ) def test_case_insensitivity(pattern, structure, exp_mapping): - mapper = Mapper(wildcard="R", ignore_case=True) + mapper = PermutationMapper(wildcard="R", ignore_case=True) m = mapper.permute(pattern, structure) assert exp_mapping == m @@ -129,7 +129,7 @@ def test_case_insensitivity(pattern, structure, exp_mapping): ], ) def test_map_to_nothing(pattern, structure, exp_mapping): - mapper = Mapper(can_map_to_nothing="A") + mapper = PermutationMapper(can_map_to_nothing="A") input_structure = copy.deepcopy(structure) m = mapper.permute(pattern, structure) assert input_structure == structure @@ -146,7 +146,7 @@ def test_map_to_nothing(pattern, structure, exp_mapping): ], ) def test_multiple_map_to_nothing(pattern, structure, exp_mapping): - mapper = Mapper(can_map_to_nothing=["A", "B"]) + mapper = PermutationMapper(can_map_to_nothing=["A", "B"]) m = mapper.permute(pattern, structure) assert exp_mapping == m @@ -161,7 +161,7 @@ def test_multiple_map_to_nothing(pattern, structure, exp_mapping): ], ) def test_wildcard_and_map_to_nothing(pattern, structure, exp_mapping): - mapper = Mapper(wildcard="R", can_map_to_nothing="R") + mapper = PermutationMapper(wildcard="R", can_map_to_nothing="R") m = mapper.permute(pattern, structure) assert exp_mapping == m @@ -176,7 +176,7 @@ def test_wildcard_and_map_to_nothing(pattern, structure, exp_mapping): ], ) def test_wildcard_and_multi_map_to_nothing(pattern, structure, exp_mapping, cmtn): - mapper = Mapper(wildcard="R", can_map_to_nothing=cmtn) + mapper = PermutationMapper(wildcard="R", can_map_to_nothing=cmtn) m = mapper.permute(pattern, structure) print(m) assert exp_mapping == m @@ -192,12 +192,12 @@ def test_wildcard_and_multi_map_to_nothing(pattern, structure, exp_mapping, cmtn ], ) def test_chem_map_hydrogen_and_wildcard(structure, exp_mapping): - mapper = Mapper(wildcard="R", can_map_to_nothing=["R", "H"]) + mapper = PermutationMapper(wildcard="R", can_map_to_nothing=["R", "H"]) m = mapper.permute(["O", "H", "R"], structure) assert exp_mapping == m def test_map_specific_to_general(): - mapper = Mapper(wildcard="R", can_map_to_nothing=["R"]) + mapper = PermutationMapper(wildcard="R", can_map_to_nothing=["R"]) m = mapper.permute(["C"], ["R"]) assert [] == m diff --git a/test/test_query.py b/test/test_query.py index 02906a0..f345d03 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -1,69 +1,15 @@ import pytest -import collections import rdkit.Chem.rdmolfiles as rdmolfiles -from fgutils.query import get_functional_groups_raw, is_functional_group +from fgutils.permutation import PermutationMapper +from fgutils.fgconfig import FGConfigProvider +from fgutils.query import FGQuery, is_functional_group from fgutils.parse import parse -from fgutils.fgconfig import get_FG_by_name -from fgutils.utils import mol_to_graph +from fgutils.rdkit import mol_to_graph - -def _get_group_map(index_map): - group_map = collections.defaultdict(lambda: []) - for idx, group_indices in index_map.items(): - for group_idx in group_indices: - assert idx not in group_map[group_idx] - group_map[group_idx].append(idx) - return dict(group_map) - - -def _assert_fg(raw_result, fg_name, indices): - r_idxmap, r_groups = raw_result - assert ( - indices[0] in r_idxmap.keys() - ), "No functional group found for atom {}.".format(indices[0]) - groups = [r_groups[i] for i in r_idxmap[indices[0]]] - assert ( - fg_name in groups - ), "Could not find functional group '{}' for for index {}.".format( - fg_name, indices[0] - ) - group_i = -1 - for i, g in enumerate(groups): - if g == fg_name: - assert group_i == -1, "Found group multipel times ({}).".format(groups) - group_i = i - group_map = _get_group_map(r_idxmap) - assert len(group_map[group_i]) == len( - indices - ), "Expected group '{}' to have {} atoms but found {}.".format( - fg_name, len(indices), len(group_map[group_i]) - ) - for i in group_map[group_i]: - assert ( - i in indices - ), "Could not find index {} in functional group {} (Indices: {}).".format( - i, fg_name, group_map[group_i] - ) - - -def _assert_not_fg(raw_result, fg_name, indices): - r_idxmap, r_groups = raw_result - for idx in indices: - groups = [r_groups[i] for i in r_idxmap[idx]] - assert ( - fg_name not in groups - ), "Wrongly identified functional group '{}' for for index {}.".format( - fg_name, indices[0] - ) - - -def test_get_functional_groups_raw(): - mol = parse("C=O") - r = get_functional_groups_raw(mol) - _assert_fg(r, "carbonyl", [0, 1]) - _assert_fg(r, "ketone", [0, 1]) - _assert_fg(r, "aldehyde", [0, 1]) +default_mapper = PermutationMapper(wildcard="R", ignore_case=True) +default_config_provider = FGConfigProvider(mapper=default_mapper) +default_query = FGQuery(mapper=default_mapper, config_provider=default_config_provider) @pytest.mark.parametrize( @@ -72,24 +18,38 @@ def test_get_functional_groups_raw(): ("carbonyl", "CC(=O)O", 2, [1, 2]), ("carboxylic_acid", "CC(=O)O", 2, [1, 2, 3]), ("amide", "C(=O)N", 2, [0, 1, 2]), + ("acyl_chloride", "CC(=O)[Cl]", 3, [1, 2, 3]), ], ) def test_get_functional_group(name, smiles, anchor, exp_indices): - fg = get_FG_by_name(name) + fg = default_config_provider.get_by_name(name) mol = mol_to_graph(rdmolfiles.MolFromSmiles(smiles)) - is_fg, indices = is_functional_group(mol, anchor, fg) + is_fg, indices = is_functional_group(mol, anchor, fg, mapper=default_mapper) assert is_fg assert len(exp_indices) == len(indices) assert exp_indices == indices +def test_get_functional_groups(): + mol = parse("C=O") + groups = default_query.get(mol) + assert ("aldehyde", [0, 1]) in groups + + +def test_get_functional_group_once(): + mol = parse("CC(=O)OC") + groups = default_query.get(mol) + assert 1 == len(groups) + assert ("ester", [1, 2, 3]) in groups + + @pytest.mark.parametrize( "smiles,functional_groups,exp_indices", [ pytest.param("C=O", ["aldehyde"], [[0, 1]], id="Formaldehyde"), pytest.param("C(=O)N", ["amide"], [[0, 1, 2]], id="Formamide"), pytest.param("NC(=O)CC(N)C(=O)O", ["amide"], [[0, 1, 2]], id="Asparagine"), - pytest.param("CC(=O)[Cl]", ["carbonyl"], [[1, 2]], id="Acetyl cloride"), + pytest.param("[Cl]C(=O)C", ["acyl_chloride"], [[0, 1, 2]], id="Acetyl cloride"), pytest.param("COC(C)=O", ["ester"], [[1, 2, 4]], id="Methyl acetate"), pytest.param("CC(=O)O", ["carboxylic_acid"], [[1, 2, 3]], id="Acetic acid"), pytest.param("NCC(=O)O", ["amine"], [[0]], id="Glycin"), @@ -103,12 +63,18 @@ def test_get_functional_group(name, smiles, anchor, exp_indices): pytest.param( "CSC(=O)c1ccccc1", ["thioester"], [[1, 2, 3]], id="Methyl thionobenzonat" ), + pytest.param( + "O=C(C)Oc1ccccc1C(=O)O", + ["ester", "carboxylic_acid"], + [[0, 1, 3], [10, 11, 12]], + id="Acetylsalicylic acid", + ), # pytest.param("", [""], [[]], id=""), ], ) def test_functional_group_on_compound(smiles, functional_groups, exp_indices): assert len(functional_groups) == len(exp_indices) mol = mol_to_graph(rdmolfiles.MolFromSmiles(smiles)) - r = get_functional_groups_raw(mol) + groups = default_query.get(mol) for fg, indices in zip(functional_groups, exp_indices): - _assert_fg(r, fg, indices) + assert (fg, indices) in groups