Skip to content

Commit

Permalink
documentation and cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
klausweinbauer committed Aug 29, 2024
1 parent 89a7f5a commit 03dda30
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 49 deletions.
154 changes: 110 additions & 44 deletions fgutils/proxy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

import random
import numpy as np
import networkx as nx

from fgutils.utils import split_its
Expand All @@ -19,6 +17,9 @@ def relabel_graph(g):
class GraphSampler:
"""
Base class for sampling ProxyGraphs.
:param unique: If set to true each graph is only returned once. (Default =
False)
"""

def __init__(self, unique=False):
Expand All @@ -27,7 +28,10 @@ def __init__(self, unique=False):

def sample(self, graphs: list[ProxyGraph]) -> list[ProxyGraph] | None:
"""
Method to retrieve a new sample from a list of graphs.
Method to retrieve a new sample from a list of graphs. If unique is set
to true the first graph that was not yet returned is selected. None is
returned if all graphs have been selected. If unique is false all
graphs are returned each time the function is called.
:param graphs: A list of graphs to sample from.
Expand All @@ -49,8 +53,8 @@ def __call__(self, graphs: list[ProxyGraph]) -> list[ProxyGraph] | None:

class ProxyGraph:
"""
ProxyGraph is a essentially a subgraph used to expand molecules. If the
node that is replaced by the pattern has more edges than the pattern has
ProxyGraph is essentially a subgraph used to expand molecules. If the node
that is replaced by the pattern has more edges than the pattern has
anchors, the last anchor will be used multiple times. In the default case
the first node in the pattern will connect to all the neighboring nodes of
the replaced node.
Expand All @@ -77,15 +81,14 @@ class ProxyGroup:
the patterns will be replaced
:param name: The name of the group.
:param graphs: (optional) A list of subgraphs in this group.
:param pattern: (optional) A list of graph descriptions. The patterns are
converted to ProxyGraphs with one anchor at index 0. If you need more
control over how the subgraphs are inserted use the ``graphs``
argument.
:param graphs: (optional) A list of subgraphs or a list of graph
descriptions. The patterns are converted to ProxyGraphs with one anchor
at index 0. Use ProxyGraph objects if you need more control over how
subgraphs are instantiated.
:param sampler: (optional) An object or a function to retrieve individual
graphs from the list. The expected function interface is:
``func(list[ProxyGraph]) -> list[ProxyGraph]``. Implement the ``__call__``
method if you use a class.
``func(list[ProxyGraph]) -> list[ProxyGraph]``. Implement the
``__call__`` method if you use a class.
:param unique: Argument to specify if graphs can be returned multiple
times. This only takes effect if sampler is not set. (Default = False)
"""
Expand All @@ -109,6 +112,8 @@ def __str__(self):

@property
def graphs(self) -> list[ProxyGraph]:
"""The list of ProxyGraphs. This can be set with values of the
following type: ``str | list[str] | ProxyGraph | list[ProxyGraph]``"""
return self.__graphs

@graphs.setter
Expand Down Expand Up @@ -136,7 +141,27 @@ def graphs(self, value):
raise ValueError("Group '{}' has no graphs.".format(self.name))

@staticmethod
def from_json_single(name, config: dict) -> ProxyGroup:
def from_dict_single(name: str, config: dict) -> ProxyGroup:
"""Load a single ProxyGroup object from configuration. The expected
JSON format is one of the following::
{
# short form
"graphs": <pattern>,
"graphs": [<pattern>],
"graphs": {"pattern": <pattern>, "anchor": list[int]},
# complete config
"graphs": [{"pattern": <pattern>, "anchor": list[int]}],
# <pattern> is the SMILES-like graph description of type str
}
:param name: The name of the ProxyGroup.
:param config: The configuration dictionary. E.g. loaded from JSON
file.
:returns: The instantiated ProxyGroup.
"""
graph_configs = config.get("graphs", [])
graphs = []
if isinstance(graph_configs, str):
Expand All @@ -150,14 +175,29 @@ def from_json_single(name, config: dict) -> ProxyGroup:
return ProxyGroup(name, graphs)

@staticmethod
def from_json(config: dict) -> dict[str, ProxyGroup]:
def from_dict(config: dict) -> dict[str, ProxyGroup]:
"""Load ProxyGroups from dict config. The expected JSON format is::
{
"group_name_1": <pattern>,
"group_name_2": [<pattern>],
"group_name_3": <group_config> # for expected format take a
# look at ProxyGroup.from_dict_single()
}
:param config: The configuration dictionary. E.g. loaded from JSON
file.
:returns: Returns a mapping dictionary of ProxyGroups where the key is
the group name and the value is the ProxyGroup object.
"""
groups = {}
for name, group_config in config.items():
if isinstance(group_config, str):
group_config = [group_config]
if isinstance(group_config, list):
group_config = {"graphs": group_config}
group = ProxyGroup.from_json_single(name, group_config)
group = ProxyGroup.from_dict_single(name, group_config)
groups[name] = group
return groups

Expand Down Expand Up @@ -189,13 +229,13 @@ def _get_next_group_node(g: nx.Graph, groups: dict[str, ProxyGroup]) -> int | No


def replace_node(graph, node, replacement_graph: ProxyGraph, parser: Parser):
"""Replace ``node`` in ``graph`` with ``replacement_graph``.
"""Replace node in graph with replacement_graph converted by the parser.
:param graph: The graph where a node should be replaced by a subgraph.
:param node: The node to replace in graph.
:param replacement_graph: The subgraph that is inserted instead of the
node.
:param parser: The parser to use to convert the pattern into structure.
:param parser: The parser to convert the pattern into structure.
:returns: Returns a new graph with ``node`` replaced by
``replacement_graph``.
Expand All @@ -214,17 +254,16 @@ def replace_node(graph, node, replacement_graph: ProxyGraph, parser: Parser):
return graph


def replace_next_node(graph, groups, parser: Parser):
"""Replace the next labeled node in ``graph`` with the respective group.
def replace_next_node(graph, groups: dict[str, ProxyGroup], parser: Parser):
"""Replace the next labeled node in graph with the respective group.
:param graph: The graph where a node should be replaced by a subgraph.
:param groups: A list of groups to replace the labeled nodes in the parent
with. The dictionary keys must be the group names.
:param groups: A mapping dictionary of groups to replace the labeled nodes
in the parent with. The dictionary keys must be the group name.
:param parser: The parser to use to convert the pattern into structure.
:returns: Returns a list of new graphs with ``node`` replaced with the
groups selected by the label sampler. Or None if nothing is left to
replace.
:returns: Returns a list of new graphs where the first labeled node is
replaced. None is returned if no replaceable labeled node is left.
"""
result_graphs = []
anchor = _get_next_group_node(graph, groups)
Expand Down Expand Up @@ -261,16 +300,14 @@ def build_graphs(
groups: dict[str, ProxyGroup],
parser: Parser,
):
"""
Replace labeled nodes in the core graph with groups. For each labeled node
the label sampler will select a set of labels (group names) used to replace
the node by the specified subgraphs. Selecting multiple labels per labeled
node results in multiple result graphs.
"""Replace labeled nodes in the core graph with groups. For each labeled
node the respective group is used to replace the node by the specified
subgraphs.
:param core: The parent graph with labeled nodes.
:param groups: A list of groups to replace the labeled nodes in the parent
with. The dictionary keys must be the group names.
:param parser: The parser that is used to convert subgraph patterns into
:param groups: A mapping dictionary of groups to replace the labeled nodes
in the core graph with. The dictionary keys must be the group names.
:param parser: The parser that is used to convert graph patterns into
graphs.
:returns: Returns a list of graphs with replaced nodes.
Expand All @@ -290,10 +327,9 @@ def build_graphs(


class Proxy:
"""
Proxy is a generator class. It randomly extends a specific core graph by a
set of subgraphs (groups). This class implements the iterator interface so
it can be used in a for loop to generate samples::
""" Proxy is a generator class. It extends a specific core graph by a set
of subgraphs (groups). This class implements the iterator interface so it
can be used in a for loop to generate samples::
>>> proxy = Proxy("C{g}", ProxyGroup("g", ["C", "O", "N"], unique=True))
>>> for graph in proxy:
Expand All @@ -305,8 +341,9 @@ class Proxy:
:param core: A pattern string or ProxyGroup representing the core graph.
For example a specific functional group or a reaction center.
:param groups: A list of groups to expand the core graph with.
:param enable_aam: Flag to specify if the 'aam' label is set in the graph.
:param parser: The parser to use to convert patterns into structures.
:param enable_aam: Flag to specify if the 'aam' label is set in the result
graph. (Default = True)
:param parser: (optional) The parser to convert patterns into structures.
"""

def __init__(
Expand All @@ -330,6 +367,9 @@ def __init__(

@property
def groups(self) -> list[ProxyGroup]:
""" The list of ProxyGroups. This can be set with values of the
following type: ``ProxyGroup | list[ProxyGroup] | dict[str,
ProxyGroup]``"""
return list(self.__groups.values())

@groups.setter
Expand All @@ -354,8 +394,20 @@ def __str__(self):
return s

@staticmethod
def from_json(config: dict) -> Proxy:
config["groups"] = ProxyGroup.from_json(config["groups"])
def from_dict(config: dict) -> Proxy:
"""Load Proxy from dict config. The expected JSON format is::
{
"core": <pattern>,
"groups": <groups> # for expected format take a
# look at ProxyGroup.from_dict()
}
:param config: The configuration dictionary. E.g. loaded from JSON file.
:returns: The instantiated Proxy.
"""
config["groups"] = ProxyGroup.from_dict(config["groups"])
return Proxy(**config)

def __generate(self):
Expand All @@ -376,6 +428,10 @@ def __generate(self):
return

def get_next(self):
""" Get the next sample.
:returns: A generated graph.
"""
return next(self.__active_generator)

def __iter__(self):
Expand All @@ -388,12 +444,18 @@ def __next__(self):
class ReactionProxy(Proxy):
"""
Proxy to generate reactions.
"""
:param core: A pattern string or ProxyGroup representing the core graph.
For example a specific functional group or a reaction center.
:param groups: A list of groups to expand the core graph with.
:param enable_aam: Flag to specify if the 'aam' label is set in the result
graph. (Default = True)
:param parser: (optional) The parser to convert patterns into structures.
"""
def __init__(
self,
core: str | list[str] | ProxyGroup,
groups: ProxyGroup | list[ProxyGroup] | dict[str, ProxyGroup] = [],
groups: ProxyGroup | list[ProxyGroup] | dict[str, ProxyGroup],
enable_aam: bool = True,
parser: Parser | None = None,
):
Expand All @@ -413,12 +475,16 @@ def get_next(self):
class MolProxy(Proxy):
"""
Proxy to generate molecules.
"""
:param core: A pattern string or ProxyGroup representing the core graph.
For example a specific functional group.
:param groups: A list of groups to expand the core graph with.
:param parser: (optional) The parser to convert patterns into structures.
"""
def __init__(
self,
core: str | list[str] | ProxyGroup,
groups: ProxyGroup | list[ProxyGroup] | dict[str, ProxyGroup] = [],
groups: ProxyGroup | list[ProxyGroup] | dict[str, ProxyGroup],
parser: Parser | None = None,
):
super().__init__(core, groups, False, parser)
Expand Down
2 changes: 1 addition & 1 deletion fgutils/proxy_collection/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
]
simple_groups += alkyl_groups

carbon_chain_max_length = 4 #10
carbon_chain_max_length = 4 # 10
carbon_chains = [ProxyGroup("CC1", "C")] + [
ProxyGroup("CC{}".format(i), ProxyGraph("C" * i, anchor=[0, i - 1]))
for i in range(2, carbon_chain_max_length + 1)
Expand Down
3 changes: 0 additions & 3 deletions fgutils/proxy_collection/diels_alder_proxy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import random
import collections

from fgutils import ReactionProxy
from fgutils.proxy import ProxyGroup, ProxyGraph

Expand Down
1 change: 0 additions & 1 deletion test/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
GraphSampler,
build_group_tree,
build_graphs,
replace_next_node,
)


Expand Down

0 comments on commit 03dda30

Please sign in to comment.