Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test release #1

Merged
merged 8 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 7 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,12 @@ pip install fgutils
## Getting Started
A simple example querying the functional groups for acetylsalicylic acid.
```
import fgutils
import rdkit.Chem.rdmolfiles as rdmolfiles
from fgutils.utils import mol_to_graph

smiles = "O=C(C)Oc1ccccc1C(=O)O" # acetylsalicylic acid
mol_graph = mol_to_graph(rdmolfiles.MolFromSmiles(smiles))
index_map, groups = fgutils.get_functional_groups_raw(mol_graph)
print(index_map, groups)
```

The output is an index map and a list of functional group names.

```
{0: [0, 1, 2], 1: [0, 1, 2], 3: [2, 3]}
['carbonyl', 'ketone', 'ester', 'ether']
>>> from fgutils import FGQuery
>>>
>>> smiles = "O=C(C)Oc1ccccc1C(=O)O" # acetylsalicylic acid
>>> query = FGQuery(use_smiles=True) # use_smiles requires rdkit to be installed
>>> query.get(smiles)
[('ester', [0, 1, 3]), ('carboxylic_acid', [10, 11, 12])]
```

The index map maps atom indices to functional groups: each key is an atom index and each value is the list of indices into the functional group list. E.g. atom 0 (oxygen) is part of the functional groups carbonyl, ketone, and ester.
The output is a list of tuples containing the functional group name and the corresponding atom indices.
3 changes: 2 additions & 1 deletion fgutils/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .query import get_functional_groups_raw
from .permutation import PermutationMapper
from .query import FGQuery
126 changes: 59 additions & 67 deletions fgutils/fgconfig.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations
import numpy as np

from fgutils.permutation import PermutationMapper
from fgutils.parse import parse
from fgutils.mapping import map_pattern
from fgutils.mapping import map_to_entire_graph

functional_group_config = [
_default_fg_config = [
{
"name": "carbonyl",
"pattern": "C(=O)",
Expand Down Expand Up @@ -45,6 +46,7 @@
{"name": "anilin", "pattern": "C:CN(R)R", "group_atoms": [2]},
{"name": "ketene", "pattern": "RC(R)=C=O", "group_atoms": [1, 3, 4]},
{"name": "carbamate", "pattern": "ROC(=O)N(R)R", "group_atoms": [1, 2, 3, 4]},
{"name": "acyl_chloride", "pattern": "RC(=O)Cl", "group_atoms": [1, 2, 3]},
]


Expand Down Expand Up @@ -100,9 +102,9 @@ def pattern_len(self) -> int:
)


def is_subgroup(parent: FGConfig, child: FGConfig) -> bool:
p2c = map_full(child.pattern, parent.pattern)
c2p = map_full(parent.pattern, child.pattern)
def is_subgroup(parent: FGConfig, child: FGConfig, mapper: PermutationMapper) -> bool:
p2c = map_to_entire_graph(child.pattern, parent.pattern, mapper)
c2p = map_to_entire_graph(parent.pattern, child.pattern, mapper)
if p2c:
assert c2p is False, "{} ({}) -> {} ({}) matches in both directions.".format(
parent.name, parent.pattern_str, child.name, child.pattern_str
Expand All @@ -111,26 +113,11 @@ def is_subgroup(parent: FGConfig, child: FGConfig) -> bool:
return False


class TreeNode:
    """Generic node that records both its parents and its children.

    Whether one node is a child of another is not decided here; the decision
    is delegated to ``is_child_callback``, invoked as
    ``is_child_callback(parent, node)``.
    """

    def __init__(self, is_child_callback):
        self.is_child_callback = is_child_callback
        self.parents: list[TreeNode] = []
        self.children: list[TreeNode] = []

    def is_child(self, parent: TreeNode) -> bool:
        """Return the callback's verdict on whether this node is a child of ``parent``."""
        verdict = self.is_child_callback(parent, self)
        return verdict

    def add_child(self, child: TreeNode):
        """Link ``child`` below this node, updating both sides of the relation."""
        # Register on the child first, then on this node (same order as before).
        child.parents.append(self)
        self.children.append(child)

class FGTreeNode(TreeNode):
class FGTreeNode:
def __init__(self, fgconfig: FGConfig):
self.fgconfig = fgconfig
self.parents: list[FGTreeNode]
self.children: list[FGTreeNode]
super().__init__(lambda a, b: is_subgroup(a.fgconfig, b.fgconfig))
self.parents: list[FGTreeNode] = []
self.children: list[FGTreeNode] = []

def order_id(self):
return (
Expand All @@ -140,35 +127,12 @@ def order_id(self):
)

def add_child(self, child: FGTreeNode):
super().add_child(child)
child.parents.append(self)
self.children.append(child)
self.parents = sorted(self.parents, key=lambda x: x.order_id(), reverse=True)
self.children = sorted(self.children, key=lambda x: x.order_id(), reverse=True)


# Module-level cache for the FGConfig list; stays None until get_FG_list()
# builds it on its first call.
fg_configs = None


def get_FG_list() -> list[FGConfig]:
    """Return the module-wide list of ``FGConfig`` objects, building it on first use.

    The list is created once from the entries of ``functional_group_config``
    and cached in the module-level ``fg_configs`` variable; every later call
    returns that same list object.
    """
    global fg_configs
    if fg_configs is not None:
        return fg_configs
    fg_configs = [FGConfig(**entry) for entry in functional_group_config]
    return fg_configs


def get_FG_by_name(name: str) -> FGConfig:
    """Return the first config from ``get_FG_list()`` whose ``name`` equals ``name``.

    Raises:
        KeyError: if no config carries that name.
    """
    found = next((cfg for cfg in get_FG_list() if cfg.name == name), None)
    if found is None:
        raise KeyError("No functional group config with name '{}' found.".format(name))
    return found


def get_FG_names() -> list[str]:
    """Return the ``name`` of every config from ``get_FG_list()``, in list order."""
    return [config.name for config in get_FG_list()]


def sort_by_pattern_len(configs: list[FGConfig], reverse=False) -> list[FGConfig]:
return list(
sorted(
Expand All @@ -179,19 +143,13 @@ def sort_by_pattern_len(configs: list[FGConfig], reverse=False) -> list[FGConfig
)


def map_full(graph, pattern):
    """Return True if ``pattern`` can be mapped onto ``graph`` from any anchor node.

    Each node of ``graph`` is tried in turn as the anchor passed to
    ``map_pattern``; the search stops at the first anchor that yields a match.

    Returns:
        True as soon as one anchor matches, otherwise False (also for an
        empty graph).
    """
    # Walk the graph's actual node ids rather than range(len(graph)): the
    # latter assumes nodes are labelled 0..n-1, which a graph need not satisfy.
    for anchor in graph.nodes:
        matched, _ = map_pattern(graph, anchor, pattern)
        if matched is True:
            return True
    return False


def search_parents(roots: list[TreeNode], child: TreeNode) -> None | list[TreeNode]:
def search_parents(
roots: list[FGTreeNode], child: FGTreeNode, mapper: PermutationMapper
) -> None | list[FGTreeNode]:
parents = set()
for root in roots:
if child.is_child(root):
_parents = search_parents(root.children, child)
if is_subgroup(root.fgconfig, child.fgconfig, mapper):
_parents = search_parents(root.children, child, mapper)
if _parents is None:
parents.add(root)
else:
Expand Down Expand Up @@ -221,11 +179,13 @@ def _print(node: FGTreeNode, indent=0):
_print(root)


def build_config_tree_from_list(config_list: list[FGConfig]) -> list[FGTreeNode]:
def build_config_tree_from_list(
config_list: list[FGConfig], mapper: PermutationMapper
) -> list[FGTreeNode]:
roots = []
for config in sort_by_pattern_len(config_list):
node = FGTreeNode(config)
parents = search_parents(roots, node)
parents = search_parents(roots, node, mapper)
if parents is None:
roots.append(node)
else:
Expand All @@ -234,11 +194,43 @@ def build_config_tree_from_list(config_list: list[FGConfig]) -> list[FGTreeNode]
return roots


# Module-level cache for the functional group tree roots; stays None until
# build_FG_tree() builds the tree on its first call.
_fg_tree_roots = None
class FGConfigProvider:
    """Hold a list of ``FGConfig`` objects together with the mapper used on them.

    The config tree is built lazily by ``get_tree`` and then reused.
    """

    def __init__(
        self,
        config: list[dict] | list[FGConfig] | None = None,
        mapper: PermutationMapper | None = None,
    ):
        """Set up the config list and the mapper.

        Args:
            config: Either a non-empty list of dicts (each expanded into
                ``FGConfig(**entry)``) or a non-empty list of ``FGConfig``
                objects (stored as given). ``None`` selects
                ``_default_fg_config``. Only the first element's type is
                inspected to pick the branch.
            mapper: Mapper to use; when ``None`` a
                ``PermutationMapper(wildcard="R", ignore_case=True)`` is created.

        Raises:
            ValueError: if ``config`` is not a list, is empty, or its first
                element is neither a dict nor an ``FGConfig``.
        """
        self.config_list: list[FGConfig] = []
        if config is None:
            config = _default_fg_config
        if not isinstance(config, list) or len(config) == 0:
            raise ValueError("Invalid config value.")

        first_entry = config[0]
        if isinstance(first_entry, dict):
            self.config_list.extend(FGConfig(**entry) for entry in config)  # type: ignore
        elif isinstance(first_entry, FGConfig):
            self.config_list = config  # type: ignore
        else:
            raise ValueError("Invalid config value.")

        if mapper is None:
            mapper = PermutationMapper(wildcard="R", ignore_case=True)
        self.mapper = mapper

        # Filled in by get_tree() on first request.
        self.__tree_roots = None

    def get_tree(self) -> list[FGTreeNode]:
        """Return the config tree roots, building them on the first call only."""
        if self.__tree_roots is not None:
            return self.__tree_roots
        self.__tree_roots = build_config_tree_from_list(self.config_list, self.mapper)
        return self.__tree_roots

def build_FG_tree() -> list[FGTreeNode]:
    """Return the functional group tree roots, building them on first use.

    The roots come from ``build_config_tree_from_list(get_FG_list())`` and are
    cached in the module-level ``_fg_tree_roots`` variable.
    """
    global _fg_tree_roots
    if _fg_tree_roots is not None:
        return _fg_tree_roots
    _fg_tree_roots = build_config_tree_from_list(get_FG_list())
    return _fg_tree_roots
def get_by_name(self, name: str) -> FGConfig:
    """Return the first config in ``self.config_list`` whose ``name`` equals ``name``.

    Raises:
        KeyError: if no config carries that name.
    """
    found = next((cfg for cfg in self.config_list if cfg.name == name), None)
    if found is None:
        raise KeyError("No functional group config with name '{}' found.".format(name))
    return found
28 changes: 21 additions & 7 deletions fgutils/mapping.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import networkx as nx

from fgutils.permutation import Mapper
from fgutils.permutation import PermutationMapper


def _get_neighbors(graph, idx, excluded_nodes=set()):
Expand All @@ -17,10 +17,12 @@ def _get_symbol(graph, idx):


def map_anchored_pattern(
graph: nx.Graph, anchor: int, pattern: nx.Graph, pattern_anchor: int
graph: nx.Graph,
anchor: int,
pattern: nx.Graph,
pattern_anchor: int,
mapper: PermutationMapper,
):
mapper = Mapper(wildcard="R", ignore_case=True)

def _fit(idx, pidx, visited_nodes=set(), visited_pnodes=set(), indent=0):
visited_nodes = copy.deepcopy(visited_nodes)
visited_nodes.add(idx)
Expand Down Expand Up @@ -90,15 +92,27 @@ def _fit(idx, pidx, visited_nodes=set(), visited_pnodes=set(), indent=0):


def map_pattern(
graph: nx.Graph, anchor: int, pattern: nx.Graph, pattern_anchor: None | int = None
graph: nx.Graph,
anchor: int,
pattern: nx.Graph,
mapper: PermutationMapper,
pattern_anchor: None | int = None,
):
if pattern_anchor is None:
if len(pattern) == 0:
return True, []
for pidx in pattern.nodes:
result = map_anchored_pattern(graph, anchor, pattern, pidx)
result = map_anchored_pattern(graph, anchor, pattern, pidx, mapper)
if result[0]:
return result
return False, []
else:
return map_anchored_pattern(graph, anchor, pattern, pattern_anchor)
return map_anchored_pattern(graph, anchor, pattern, pattern_anchor, mapper)


def map_to_entire_graph(graph: nx.Graph, pattern: nx.Graph, mapper: PermutationMapper):
    """Return True if ``pattern`` can be mapped onto ``graph`` from any anchor node.

    Each node of ``graph`` is tried in turn as the anchor passed to
    ``map_pattern`` (with no fixed pattern anchor); the search stops at the
    first anchor that yields a match.

    Args:
        graph: Graph to search in.
        pattern: Pattern graph to look for.
        mapper: Mapper forwarded unchanged to ``map_pattern``.

    Returns:
        True as soon as one anchor matches, otherwise False (also for an
        empty graph).
    """
    # Walk the graph's actual node ids, as map_pattern does for the pattern,
    # rather than range(len(graph)), which assumes nodes are labelled 0..n-1.
    for anchor in graph.nodes:
        matched, _ = map_pattern(graph, anchor, pattern, mapper)
        if matched is True:
            return True
    return False
2 changes: 1 addition & 1 deletion fgutils/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def generate_mapping_permutations(pattern, structure, wildcard=None):
return mappings


class Mapper:
class PermutationMapper:
def __init__(self, wildcard=None, ignore_case=False, can_map_to_nothing=[]):
self.wildcard = wildcard
self.ignore_case = ignore_case
Expand Down
Loading
Loading