diff --git a/.img/network_layouts.png b/.img/network_layouts.png index a1f644c..b001f22 100644 Binary files a/.img/network_layouts.png and b/.img/network_layouts.png differ diff --git a/src/konnektor/network_planners/__init__.py b/src/konnektor/network_planners/__init__.py index 3fef6b4..4f811ea 100644 --- a/src/konnektor/network_planners/__init__.py +++ b/src/konnektor/network_planners/__init__.py @@ -8,6 +8,7 @@ NNodeEdgesNetworkGenerator) ## Starmap Like Networks from .generators.star_network_generator import StarNetworkGenerator, RadialLigandNetworkPlanner +from .generators.twin_star_network_generator import TwinStarNetworkGenerator from .generators.clustered_network_generator import StarrySkyNetworkGenerator ## MST like Networks diff --git a/src/konnektor/network_planners/_networkx_implementations/bipartite_match_algorithm.py b/src/konnektor/network_planners/_networkx_implementations/bipartite_match_algorithm.py index 09a2c2c..7386aae 100644 --- a/src/konnektor/network_planners/_networkx_implementations/bipartite_match_algorithm.py +++ b/src/konnektor/network_planners/_networkx_implementations/bipartite_match_algorithm.py @@ -33,7 +33,8 @@ def concatenate_networks(self, nodesA: list[int], nodesB: list[int], nx.Graph the resulting graph, containing both subgraphs. """ - + # The initial "weights" are Scores, which need to be translated to weights. + weights = list(map(lambda x: 1-x, weights)) wedges_map = {(e[0], e[1]): w for e, w in zip(edges, weights)} wedges = [(e[0], e[1], w) for e, w in zip(edges, weights)] diff --git a/src/konnektor/network_planners/_networkx_implementations/cyclic_network_algorithm.py b/src/konnektor/network_planners/_networkx_implementations/cyclic_network_algorithm.py index 13f0ada..3e81659 100644 --- a/src/konnektor/network_planners/_networkx_implementations/cyclic_network_algorithm.py +++ b/src/konnektor/network_planners/_networkx_implementations/cyclic_network_algorithm.py @@ -95,6 +95,8 @@ def _translate_input(self, edges: List[Tuple[int, int]], # build Edges: w_edges = [] nodes = [] + # The initial "weights" are Scores, which need to be translated to weights. + weights = list(map(lambda x: 1-x, weights)) for e, w in zip(edges, weights): w_edges.append((e[0], e[1], w)) nodes.extend(e) diff --git a/src/konnektor/network_planners/_networkx_implementations/mst_network_algorithm.py b/src/konnektor/network_planners/_networkx_implementations/mst_network_algorithm.py index 1643669..7cf60f2 100644 --- a/src/konnektor/network_planners/_networkx_implementations/mst_network_algorithm.py +++ b/src/konnektor/network_planners/_networkx_implementations/mst_network_algorithm.py @@ -13,6 +13,8 @@ def generate_network(self, edges: list[tuple[int, int]], weights: list[float], n_edges:int=None) -> nx.Graph: wedges = [] nodes = [] + # The initial "weights" are Scores, which need to be translated to weights. + weights = list(map(lambda x: 1-x, weights)) for edge, weight in zip(edges, weights): wedges.append([edge[0], edge[1], weight]) nodes.extend(list(edge)) diff --git a/src/konnektor/network_planners/_networkx_implementations/n_nodes_edges_network_algorithm.py b/src/konnektor/network_planners/_networkx_implementations/n_nodes_edges_network_algorithm.py index e9339df..8be24f1 100644 --- a/src/konnektor/network_planners/_networkx_implementations/n_nodes_edges_network_algorithm.py +++ b/src/konnektor/network_planners/_networkx_implementations/n_nodes_edges_network_algorithm.py @@ -18,6 +18,8 @@ def generate_network(self, edges: list[tuple[int, int]], w_edges = [] nodes = [] + # The initial "weights" are Scores, which need to be translated to weights. + weights = list(map(lambda x: 1-x, weights)) for e, w in zip(edges, weights): w_edges.append((e[0], e[1], w)) nodes.extend(e) diff --git a/src/konnektor/network_planners/_networkx_implementations/radial_network_algorithm.py b/src/konnektor/network_planners/_networkx_implementations/radial_network_algorithm.py index 132a277..8c0a4c2 100644 --- a/src/konnektor/network_planners/_networkx_implementations/radial_network_algorithm.py +++ b/src/konnektor/network_planners/_networkx_implementations/radial_network_algorithm.py @@ -1,7 +1,7 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/konnektor -from typing import Callable +from typing import Callable, Iterable import networkx as nx import numpy as np @@ -12,12 +12,15 @@ class RadialNetworkAlgorithm(_AbstractNetworkAlgorithm): - def __init__(self, metric_aggregation_method: Callable = None): + def __init__(self, metric_aggregation_method: Callable = None, n_centers: int = 1): self.metric_aggregation_method = metric_aggregation_method + self.n_centers = n_centers def _central_lig_selection(self, edges: list[tuple[int, int]], - weights: list[float]) -> int: + weights: list[float]) -> Iterable[int]: nodes = set([n for e in edges for n in e]) + # The initial "weights" are Scores, which need to be translated to weights. + weights = list(map(lambda x: 1-x, weights)) edge_weights = list(zip(edges, weights)) node_scores = {n: [e_s[1] for e_s in edge_weights if (n in e_s[0])] for @@ -33,9 +36,8 @@ def _central_lig_selection(self, edges: list[tuple[int, int]], aggregated_scores = list( map(lambda x: (x[0], np.sum(x[1])), filtered_node_scores.items())) sorted_node_scores = list(sorted(aggregated_scores, key=lambda x: x[1])) - - opt_node = sorted_node_scores[0] - return opt_node + opt_nodes = sorted_node_scores[:self.n_centers] + return opt_nodes def generate_network(self, edges: list[tuple[int, int]], weights: list[float], @@ -75,12 +77,16 @@ def generate_network(self, edges: list[tuple[int, int]], """ if (central_node is None): - central_node, avg_score = self._central_lig_selection(edges=edges, - weights=weights) + central_nodes = self._central_lig_selection(edges=edges, + weights=weights, ) + elif isinstance(central_node, (SmallMoleculeComponent, str)): + central_nodes = [(central_node, 1)] + else: + raise ValueError("invalide central node type: "+str(type(central_node))) wedges = [] for edge, weight in zip(edges, weights): - if (central_node in edge): + if any(central_node in edge for central_node, avg_score in central_nodes): wedges.append([edge[0], edge[1], weight]) # Todo: Warning if something was not connected to the central ligand? diff --git a/src/konnektor/network_planners/generators/twin_star_network_generator.py b/src/konnektor/network_planners/generators/twin_star_network_generator.py new file mode 100644 index 0000000..ee5d121 --- /dev/null +++ b/src/konnektor/network_planners/generators/twin_star_network_generator.py @@ -0,0 +1,82 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/konnektor + +from typing import Iterable + +from gufe import Component, LigandNetwork, AtomMapper + +from konnektor.network_planners._networkx_implementations import \ + RadialNetworkAlgorithm +from ._abstract_network_generator import NetworkGenerator +from .maximal_network_generator import MaximalNetworkGenerator + + +class TwinStarNetworkGenerator(NetworkGenerator): + + def __init__(self, mapper: AtomMapper, scorer, n_centers: int =2, + n_processes: int = 1, + _initial_edge_lister: NetworkGenerator = None): + """ + The Twin Star Ligand Network Planner , set's n ligands ligand into the center of a graph and connects all other ligands to each center. + + Parameters + ---------- + mapper : AtomMapper + the atom mapper is required, to define the connection between two ligands. + scorer : AtomMappingScorer + scoring function evaluating an atom mapping, and giving a score between [0,1]. + n_centers: int, optional + the number of centers in the network. (default: 2) + n_processes: int, optional + number of processes that can be used for the network generation. (default: 1) + _initial_edge_lister: LigandNetworkPlanner, optional + this LigandNetworkPlanner is used to give the initial set of edges. For standard usage, the Maximal NetworPlanner is used. + However in large scale approaches, it might be interesting to use the heuristicMaximalNetworkPlanner.. (default: MaximalNetworkPlanner) + """ + if _initial_edge_lister is None: + _initial_edge_lister = MaximalNetworkGenerator(mapper=mapper, + scorer=scorer, + n_processes=n_processes) + + super().__init__(mapper=mapper, scorer=scorer, + network_generator=RadialNetworkAlgorithm(n_centers=n_centers), + n_processes=n_processes, + _initial_edge_lister=_initial_edge_lister) + + self.n_centers = n_centers + + + def generate_ligand_network(self, components: Iterable[Component]) -> LigandNetwork: + """ + generate a twin star map network for the given compounds. + + Parameters + ---------- + components: Iterable[Component] + the components to be used for the LigandNetwork + + Returns + ------- + LigandNetwork + a star like network. + """ + components = list(components) + + + # Full Graph Construction + initial_network = self._initial_edge_lister.generate_ligand_network( + components=components) + mappings = initial_network.edges + + # Translate Mappings to graphable: + edge_map = {(components.index(m.componentA), + components.index(m.componentB)): m for m in mappings} + edges = list(sorted(edge_map.keys())) + weights = [edge_map[k].annotations['score'] for k in edges] + + rg = self.network_generator.generate_network(edges=edges, + weights=weights) + selected_mappings = [edge_map[k] for k in rg.edges] + + + return LigandNetwork(edges=selected_mappings, nodes=components) diff --git a/src/konnektor/tests/network_planners/_networkx_implemenations/test_radial_graph_generator.py b/src/konnektor/tests/network_planners/_networkx_implemenations/test_radial_graph_generator.py index 31b1414..9a7210b 100644 --- a/src/konnektor/tests/network_planners/_networkx_implemenations/test_radial_graph_generator.py +++ b/src/konnektor/tests/network_planners/_networkx_implemenations/test_radial_graph_generator.py @@ -1,6 +1,7 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/konnektor +import pytest import networkx as nx import numpy as np @@ -14,12 +15,28 @@ def test_radial_network_generation_find_center(nine_mols_edges): weights = [e[2] for e in nine_mols_edges] gen = RadialNetworkAlgorithm() - c_node, avg_weight = gen._central_lig_selection(edges, weights) + c_node, avg_weight = gen._central_lig_selection(edges, weights)[0] - assert c_node == "lig_14" # Check central node - assert np.round(avg_weight, 2) == 2.86 + assert c_node == "lig_10" # Check central node + np.testing.assert_allclose(avg_weight, 2.055, rtol=0.01) +@pytest.mark.parametrize('n_centers', [2,3,4]) +def test_radial_network_generation_find_centers(nine_mols_edges, n_centers): + edges = [(e[0], e[1]) for e in nine_mols_edges] + weights = [e[2] for e in nine_mols_edges] + + gen = RadialNetworkAlgorithm(n_centers=n_centers) + centers = gen._central_lig_selection(edges, weights) + + print(centers) + expected_centers = ['lig_10', 'lig_8', 'lig_9', 'lig_16'] + expected_weights = [ 2.0551095953189917, 3.6524873109359146, 4.270400420741822, 4.543886935944357] + for i, (cID, avg_weight) in enumerate(centers): + print(cID, avg_weight) + assert cID == expected_centers[i] # Check central node + np.testing.assert_allclose(avg_weight, expected_weights[i], rtol=0.01) + def test_radial_network_generation_without_center(nine_mols_edges): edges = [(e[0], e[1]) for e in nine_mols_edges] weights = [e[2] for e in nine_mols_edges] @@ -29,7 +46,7 @@ def test_radial_network_generation_without_center(nine_mols_edges): g = gen.generate_network(edges, weights) assert len(nodes) - 1 == len(g.edges) - assert all(["lig_14" in e for e in g.edges]) # check central node + assert all(["lig_10" in e for e in g.edges]) # check central node assert all([e[0] != e[1] for e in g.edges]) # No self connectivity assert isinstance(g, nx.Graph) diff --git a/src/konnektor/tests/network_planners/generators/test_clusterd_network_generator.py b/src/konnektor/tests/network_planners/generators/test_clusterd_network_generator.py index 3bbbe0f..2d9c970 100644 --- a/src/konnektor/tests/network_planners/generators/test_clusterd_network_generator.py +++ b/src/konnektor/tests/network_planners/generators/test_clusterd_network_generator.py @@ -1,11 +1,10 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/konnektor -import itertools - +import numpy as np from gufe import LigandNetwork from sklearn.cluster import KMeans -from konnektor.network_analysis import get_is_connected +from konnektor.network_analysis import get_is_connected, get_graph_score from konnektor.network_planners.generators.clustered_network_generator import \ ClusteredNetworkGenerator from konnektor.network_tools.clustering.component_diversity_clustering import ComponentsDiversityClusterer @@ -15,7 +14,7 @@ def test_clustered_network_planner(): n_compounds = 40 components, genMapper, genScorer = build_random_dataset( - n_compounds=n_compounds) + n_compounds=n_compounds, rand_seed=42) from konnektor.network_planners import (RadialLigandNetworkPlanner, MstConcatenator) @@ -37,3 +36,5 @@ def test_clustered_network_planner(): assert len(planner.clusters) == 3 assert len(ligand_network.edges) == 3*((n_compounds//3)-1) + (3 * concatenator.n_connecting_edges) + 1 assert get_is_connected(ligand_network) + + np.testing.assert_allclose(get_graph_score(ligand_network), 25.708691, rtol=0.01) \ No newline at end of file diff --git a/src/konnektor/tests/network_planners/generators/test_cyclic_network_planner.py b/src/konnektor/tests/network_planners/generators/test_cyclic_network_planner.py index 249ece7..86a0fbe 100644 --- a/src/konnektor/tests/network_planners/generators/test_cyclic_network_planner.py +++ b/src/konnektor/tests/network_planners/generators/test_cyclic_network_planner.py @@ -1,7 +1,8 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/konnektor -from konnektor.network_analysis import get_is_connected, get_node_number_cycles +import numpy as np +from konnektor.network_analysis import get_is_connected, get_node_number_cycles, get_graph_score from konnektor.network_planners import CyclicNetworkGenerator from konnektor.utils.toy_data import build_random_dataset @@ -10,7 +11,7 @@ def test_cyclic_network_planner(): n_compounds = 8 ncycles = 2 components, genMapper, genScorer = build_random_dataset( - n_compounds=n_compounds) + n_compounds=n_compounds, rand_seed=42) planner = CyclicNetworkGenerator( mapper=genMapper, scorer=genScorer, cycle_sizes=3, @@ -24,3 +25,5 @@ def test_cyclic_network_planner(): assert get_is_connected(network) nnode_cycles = get_node_number_cycles(network) assert all(v >= ncycles for k, v in nnode_cycles.items()) + + np.testing.assert_allclose(get_graph_score(network), 10.347529, rtol=0.01) \ No newline at end of file diff --git a/src/konnektor/tests/network_planners/generators/test_mst_netwok_planner.py b/src/konnektor/tests/network_planners/generators/test_mst_netwok_planner.py index 8afb22c..d825490 100644 --- a/src/konnektor/tests/network_planners/generators/test_mst_netwok_planner.py +++ b/src/konnektor/tests/network_planners/generators/test_mst_netwok_planner.py @@ -3,6 +3,7 @@ import gufe import networkx as nx +import numpy as np import pytest from gufe import LigandNetwork @@ -11,6 +12,7 @@ atom_mapping_basic_test_files, mol_from_smiles, genScorer, GenAtomMapper, ErrorMapper) +from konnektor.network_analysis import get_graph_score def test_minimal_spanning_network_mappers(atom_mapping_basic_test_files): @@ -25,7 +27,7 @@ def test_minimal_spanning_network_mappers(atom_mapping_basic_test_files): assert isinstance(network, LigandNetwork) assert list(network.edges) - + np.testing.assert_allclose(get_graph_score(network), 0.066667, rtol=0.001) @pytest.fixture(scope='session') def minimal_spanning_network(toluene_vs_others): diff --git a/src/konnektor/tests/network_planners/generators/test_nedges_netwok_planner.py b/src/konnektor/tests/network_planners/generators/test_nedges_netwok_planner.py index 5460f2d..0ba491c 100644 --- a/src/konnektor/tests/network_planners/generators/test_nedges_netwok_planner.py +++ b/src/konnektor/tests/network_planners/generators/test_nedges_netwok_planner.py @@ -1,9 +1,10 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/konnektor +import numpy as np from gufe import LigandNetwork -from konnektor.network_analysis import get_is_connected +from konnektor.network_analysis import get_is_connected, get_graph_score from konnektor.network_planners import NNodeEdgesNetworkGenerator from konnektor.tests.network_planners.conf import ( atom_mapping_basic_test_files, @@ -26,3 +27,5 @@ def test_nedges_network_mappers(atom_mapping_basic_test_files): assert len(network.nodes) == len(ligands) assert len(network.edges) <= len(ligands) * 2 assert get_is_connected(network) + + np.testing.assert_allclose(get_graph_score(network), 0.066667, rtol=0.01) diff --git a/src/konnektor/tests/network_planners/generators/test_rmst_netwok_planner.py b/src/konnektor/tests/network_planners/generators/test_rmst_netwok_planner.py index 055f286..d4d8cd4 100644 --- a/src/konnektor/tests/network_planners/generators/test_rmst_netwok_planner.py +++ b/src/konnektor/tests/network_planners/generators/test_rmst_netwok_planner.py @@ -1,9 +1,11 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/konnektor -import gufe +import numpy as np import networkx as nx import pytest + +import gufe from gufe import LigandNetwork from konnektor.network_planners import \ @@ -12,6 +14,7 @@ atom_mapping_basic_test_files, mol_from_smiles, genScorer, GenAtomMapper, ErrorMapper) +from konnektor.network_analysis import get_graph_score def test_minimal_spanning_network_mappers(atom_mapping_basic_test_files): @@ -27,6 +30,7 @@ def test_minimal_spanning_network_mappers(atom_mapping_basic_test_files): assert isinstance(network, LigandNetwork) assert list(network.edges) + np.testing.assert_allclose(get_graph_score(network), 0.066667, rtol=0.01) @pytest.fixture(scope='session') diff --git a/src/konnektor/tests/network_planners/generators/test_radial_network_planner.py b/src/konnektor/tests/network_planners/generators/test_star_network_planner.py similarity index 92% rename from src/konnektor/tests/network_planners/generators/test_radial_network_planner.py rename to src/konnektor/tests/network_planners/generators/test_star_network_planner.py index 525cd9f..4a0566a 100644 --- a/src/konnektor/tests/network_planners/generators/test_radial_network_planner.py +++ b/src/konnektor/tests/network_planners/generators/test_star_network_planner.py @@ -12,7 +12,7 @@ @pytest.mark.parametrize('as_list', [False]) -def test_radial_network(atom_mapping_basic_test_files, toluene_vs_others, +def test_star_network(atom_mapping_basic_test_files, toluene_vs_others, as_list): toluene, others = toluene_vs_others central_ligand_name = 'toluene' @@ -38,7 +38,7 @@ def test_radial_network(atom_mapping_basic_test_files, toluene_vs_others, for mapping in network.edges) -def test_radial_network_with_scorer(toluene_vs_others): +def test_star_network_with_scorer(toluene_vs_others): toluene, others = toluene_vs_others mapper = GenAtomMapper() @@ -58,7 +58,7 @@ def test_radial_network_with_scorer(toluene_vs_others): edge.componentA_to_componentB) -def test_radial_network_multiple_mappers_no_scorer(toluene_vs_others): +def test_star_network_multiple_mappers_no_scorer(toluene_vs_others): toluene, others = toluene_vs_others # in this one, we should always take the bad mapper mapper = BadMapper() @@ -73,7 +73,7 @@ def test_radial_network_multiple_mappers_no_scorer(toluene_vs_others): assert edge.componentA_to_componentB == {0: 0} -def test_radial_network_failure(atom_mapping_basic_test_files): +def test_star_network_failure(atom_mapping_basic_test_files): nigel = SmallMoleculeComponent(mol_from_smiles('N')) mapper = ErrorMapper() diff --git a/src/konnektor/tests/network_planners/generators/test_starry_sky_network_generator.py b/src/konnektor/tests/network_planners/generators/test_starry_sky_network_generator.py index b3df252..4dd0d8c 100644 --- a/src/konnektor/tests/network_planners/generators/test_starry_sky_network_generator.py +++ b/src/konnektor/tests/network_planners/generators/test_starry_sky_network_generator.py @@ -4,7 +4,7 @@ import numpy as np from gufe import LigandNetwork from sklearn.cluster import KMeans -from konnektor.network_analysis import get_is_connected +from konnektor.network_analysis import get_is_connected, get_graph_score from konnektor.network_planners.generators.clustered_network_generator import \ StarrySkyNetworkGenerator from konnektor.network_tools.clustering.component_diversity_clustering import ComponentsDiversityClusterer @@ -14,7 +14,7 @@ def test_starry_sky_network_planner(): n_compounds = 40 components, genMapper, genScorer = build_random_dataset( - n_compounds=n_compounds) + n_compounds=n_compounds, rand_seed=42) clusterer = ComponentsDiversityClusterer(cluster=KMeans(n_clusters=3)) planner = StarrySkyNetworkGenerator(mapper=genMapper, @@ -29,3 +29,5 @@ def test_starry_sky_network_planner(): assert len(ligand_network.nodes) == n_compounds np.testing.assert_allclose(actual=len(ligand_network.edges), desired=approx_edges, rtol=5) assert get_is_connected(ligand_network) + + np.testing.assert_allclose(get_graph_score(ligand_network), 24.607684, rtol=0.01) diff --git a/src/konnektor/tests/network_planners/generators/test_twin_star_network_planner.py b/src/konnektor/tests/network_planners/generators/test_twin_star_network_planner.py new file mode 100644 index 0000000..88aa555 --- /dev/null +++ b/src/konnektor/tests/network_planners/generators/test_twin_star_network_planner.py @@ -0,0 +1,29 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/konnektor + +import pytest +import numpy as np +from gufe import LigandNetwork + +from konnektor.network_planners import TwinStarNetworkGenerator +from konnektor.network_analysis import get_is_connected, get_graph_score +from konnektor.utils.toy_data import build_random_dataset + + +def test_twin_star_network_planner(): + n_compounds = 40 + components, genMapper, genScorer = build_random_dataset( + n_compounds=n_compounds, rand_seed=42) + + planner = TwinStarNetworkGenerator(mapper=genMapper, + scorer=genScorer) + + #Testing + ligand_network = planner(components) + n_centers = planner.n_centers + approx_edges = (len(components)-1)*n_centers + assert isinstance(ligand_network, LigandNetwork) + assert len(ligand_network.nodes) == n_compounds + np.testing.assert_allclose(actual=len(ligand_network.edges), desired=approx_edges, rtol=5) + assert get_is_connected(ligand_network) + np.testing.assert_allclose(get_graph_score(ligand_network), 39.944662, rtol=0.01) diff --git a/src/konnektor/tests/test_network_analysis.py b/src/konnektor/tests/test_network_analysis.py index dbe0335..292d2c8 100644 --- a/src/konnektor/tests/test_network_analysis.py +++ b/src/konnektor/tests/test_network_analysis.py @@ -143,8 +143,7 @@ def test_get_mst_graph_score(): seed = 42 n_compounds = 30 g = build_random_mst_network(n_compounds=n_compounds, rand_seed=seed) - - np.testing.assert_allclose(get_graph_score(g), 0.9823, atol=1e-3) + np.testing.assert_allclose(get_graph_score(g), 27.52, atol=1e-3) def test_get_fully_connected_graph_score(): # Check for graph scores.