diff --git a/cinnabar/femap.py b/cinnabar/femap.py index a403a6f..2600925 100644 --- a/cinnabar/femap.py +++ b/cinnabar/femap.py @@ -1,6 +1,6 @@ import pathlib from typing import Union - +import copy import openff.units import pandas as pd from openff.units import unit @@ -95,13 +95,64 @@ class FEMap: >>> fe.add_measurement(experimental_result2) >>> fe.add_measurement(calculated_result) """ + # internal representation: # graph with measurements as edges # absolute Measurements are an edge between 'ReferenceState' and the label - # all edges are directed, all edges can be multiply defined - graph: nx.MultiDiGraph + # all edges are directed + # all edges can be multiply defined + _graph: nx.MultiDiGraph def __init__(self): - self.graph = nx.MultiDiGraph() + self._graph = nx.MultiDiGraph() + + def __iter__(self): + for a, b, d in self._graph.edges(data=True): + # skip artificial reverse edges + if d['source'] == 'reverse': + continue + + yield Measurement(labelA=a, labelB=b, **d) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + + # iter returns hashable Measurements, so this will compare contents + return set(self) == set(other) + + def to_networkx(self) -> nx.MultiDiGraph: + """A *copy* of the FEMap as a networkx Graph + + The FEMap is represented as a multi-edged directional graph + + Edges have the following attributes: + - DG: the free energy difference of going from the first edge label to + the second edge label + - uncertainty: uncertainty of the DG value + - temperature: the temperature at which DG was measured + - computational: boolean label of the original source of the data + - source: a string describing the source of data. + + Note + ---- + All edges appear twice, once with the attribute source='reverse', + and the DG value flipped. This allows "pathfinding" like approaches, + where the DG values will be correctly summed. + """ + return copy.deepcopy(self._graph) + + @classmethod + def from_networkx(cls, graph: nx.MultiDiGraph): + """Create FEMap from network representation + + Note + ---- + Currently absolutely no validation of the input is done. + """ + m = cls() + m._graph = graph + + return m @classmethod def from_csv(cls, filename, units: Optional[unit.Quantity] = None): @@ -133,8 +184,8 @@ def add_measurement(self, measurement: Measurement): # add both directions, but flip sign for the other direction d_backwards = {**d, 'DG': - d['DG'], 'source': 'reverse'} - self.graph.add_edge(measurement.labelA, measurement.labelB, **d) - self.graph.add_edge(measurement.labelB, measurement.labelA, **d_backwards) + self._graph.add_edge(measurement.labelA, measurement.labelB, **d) + self._graph.add_edge(measurement.labelB, measurement.labelA, **d_backwards) def add_experimental_measurement(self, label: Union[str, Hashable], @@ -262,7 +313,7 @@ def get_relative_dataframe(self) -> pd.DataFrame: """ kcpm = unit.kilocalorie_per_mole data = [] - for l1, l2, d in self.graph.edges(data=True): + for l1, l2, d in self._graph.edges(data=True): if d['source'] == 'reverse': continue if isinstance(l1, ReferenceState) or isinstance(l2, ReferenceState): @@ -297,7 +348,7 @@ def get_absolute_dataframe(self) -> pd.DataFrame: """ kcpm = unit.kilocalorie_per_mole data = [] - for l1, l2, d in self.graph.edges(data=True): + for l1, l2, d in self._graph.edges(data=True): if d['source'] == 'reverse': continue if not isinstance(l1, ReferenceState): @@ -325,7 +376,7 @@ def get_absolute_dataframe(self) -> pd.DataFrame: @property def n_measurements(self) -> int: """Total number of both experimental and computational measurements""" - return len(self.graph.edges) // 2 + return len(self._graph.edges) // 2 @property def n_ligands(self) -> int: @@ -336,7 +387,7 @@ def n_ligands(self) -> int: def ligands(self) -> list: """All ligands in the graph""" # must ignore ReferenceState nodes - return [n for n in self.graph.nodes + return [n for n in self._graph.nodes if not isinstance(n, ReferenceState)] @property @@ -347,14 +398,14 @@ def degree(self) -> float: @property def n_edges(self) -> int: """Number of computational edges""" - return sum(1 for _, _, d in self.graph.edges(data=True) + return sum(1 for _, _, d in self._graph.edges(data=True) if d['computational']) // 2 def check_weakly_connected(self) -> bool: """Checks if all results in the graph are reachable from other results""" # todo; cache comp_graph = nx.MultiGraph() - for a, b, d in self.graph.edges(data=True): + for a, b, d in self._graph.edges(data=True): if not d['computational']: continue comp_graph.add_edge(a, b) @@ -365,7 +416,7 @@ def generate_absolute_values(self): """Populate the FEMap with absolute computational values based on MLE""" # TODO: Make this return a new Graph with computational nodes annotated with DG values # TODO this could work if either relative or absolute expt values are provided - mes = list(self.graph.edges(data=True)) + mes = list(self._graph.edges(data=True)) # for now, we must all be in the same units for this to work # grab unit of first measurement u = mes[0][-1]['DG'].u @@ -394,7 +445,7 @@ def generate_absolute_values(self): # find all computational result labels comp_ligands = set() - for A, B, d in self.graph.edges(data=True): + for A, B, d in self._graph.edges(data=True): if not d['computational']: continue comp_ligands.add(A) @@ -433,7 +484,7 @@ def to_legacy_graph(self) -> nx.DiGraph: # reduces to nx.DiGraph g = nx.DiGraph() # add DDG values from computational graph - for a, b, d in self.graph.edges(data=True): + for a, b, d in self._graph.edges(data=True): if not d['computational']: continue if isinstance(a, ReferenceState): # skip absolute measurements @@ -444,7 +495,7 @@ def to_legacy_graph(self) -> nx.DiGraph: g.add_edge(a, b, calc_DDG=d['DG'].magnitude, calc_dDDG=d['uncertainty'].magnitude) # add DG values from experiment graph for node, d in g.nodes(data=True): - expt = self.graph.get_edge_data(ReferenceState(), node) + expt = self._graph.get_edge_data(ReferenceState(), node) if expt is None: continue expt = expt[0] diff --git a/cinnabar/measurements.py b/cinnabar/measurements.py index 12afda4..1800342 100644 --- a/cinnabar/measurements.py +++ b/cinnabar/measurements.py @@ -50,6 +50,9 @@ def __hash__(self): class Measurement(DefaultModel): """The free energy difference of moving from A to B""" + class Config: + frozen = True + labelA: Hashable labelB: Hashable DG: FloatQuantity['kilocalorie_per_mole'] diff --git a/cinnabar/tests/test_femap.py b/cinnabar/tests/test_femap.py index ea0044e..9d21641 100644 --- a/cinnabar/tests/test_femap.py +++ b/cinnabar/tests/test_femap.py @@ -26,7 +26,21 @@ def example_map(example_csv): def test_from_csv(example_map): assert example_map.n_ligands == 36 assert example_map.n_edges == 58 - assert len(example_map.graph.edges) == (58 + 36) * 2 + assert len(example_map._graph.edges) == (58 + 36) * 2 + + +def test_eq(example_csv): + m1 = cinnabar.FEMap.from_csv(example_csv) + m2 = cinnabar.FEMap.from_csv(example_csv) + m3 = cinnabar.FEMap.from_csv(example_csv) + m3.add_experimental_measurement( + label='this', + value=4.2 * unit.kilocalorie_per_mole, + uncertainty=0.1 * unit.kilocalorie_per_mole, + ) + + assert m1 == m2 + assert m1 != m3 def test_degree(example_map): @@ -77,7 +91,7 @@ def test_femap_add_experimental(ki): ) assert set(m.ligands) == {'ligA'} - d = m.graph.get_edge_data(cinnabar.ReferenceState(), 'ligA') + d = m._graph.get_edge_data(cinnabar.ReferenceState(), 'ligA') assert d.keys() == {0} d = d[0] assert d['computational'] is False @@ -118,7 +132,7 @@ def test_add_ABFE(default_T): source='ebay', temperature=T) assert set(m.ligands) == {'foo'} - d = m.graph.get_edge_data(cinnabar.ReferenceState(), 'foo') + d = m._graph.get_edge_data(cinnabar.ReferenceState(), 'foo') assert len(d) == 1 d = d[0] assert d['DG'] == v @@ -143,7 +157,7 @@ def test_add_RBFE(default_T): source='ebay', temperature=T) assert set(m.ligands) == {'foo', 'bar'} - d = m.graph.get_edge_data('foo', 'bar') + d = m._graph.get_edge_data('foo', 'bar') assert len(d) == 1 d = d[0] assert d['DG'] == v @@ -166,7 +180,7 @@ def test_generate_absolute_values(example_map, ref_mle_results): example_map.generate_absolute_values() for e, (y_ref, yerr_ref) in ref_mle_results.items(): - data = example_map.graph.get_edge_data(cinnabar.ReferenceState(label='MLE'), e) + data = example_map._graph.get_edge_data(cinnabar.ReferenceState(label='MLE'), e) # grab the dict containing MLE data for _, d in data.items(): if d['source'] == 'MLE': @@ -194,3 +208,20 @@ def test_to_dataframe(example_map): assert abs_df2.shape == (72, 5) assert abs_df2.loc[abs_df2.computational].shape == (36, 5) assert abs_df2.loc[~abs_df2.computational].shape == (36, 5) + + +def test_to_networkx(example_map): + g = example_map.to_networkx() + + assert g + assert isinstance(g, nx.MultiDiGraph) + # should have (exptl + comp edges) * 2 + assert len(g.edges) == 2 * (36 + 58) + + +def test_from_networkx(example_map): + g = example_map.to_networkx() + + m2 = cinnabar.FEMap.from_networkx(g) + + assert example_map == m2 diff --git a/cinnabar/tests/test_measurements.py b/cinnabar/tests/test_measurements.py index 4b754c9..acf6744 100644 --- a/cinnabar/tests/test_measurements.py +++ b/cinnabar/tests/test_measurements.py @@ -14,6 +14,33 @@ def test_ground(): assert g3 == g4 +def test_measurement_hash(): + m1 = cinnabar.Measurement( + labelA='foo', labelB='bar', + DG=0.1 * unit.kilocalorie_per_mole, + uncertainty=0.01 * unit.kilocalorie_per_mole, + computational=True, + ) + m1a = cinnabar.Measurement( + labelA='foo', labelB='bar', + DG=0.1 * unit.kilocalorie_per_mole, + uncertainty=0.01 * unit.kilocalorie_per_mole, + computational=True, + ) + m2 = cinnabar.Measurement( + labelA='foo', labelB='bar', + DG=0.11 * unit.kilocalorie_per_mole, + uncertainty=0.01 * unit.kilocalorie_per_mole, + computational=False, + ) + + thing = set([m1, m1a, m2]) + + assert len(thing) == 2 + assert m1 in thing + assert m2 in thing + + @pytest.mark.parametrize('Ki,uncertainty,dG,dG_uncertainty,label,temp', [ [100 * unit.nanomolar, 10 * unit.nanomolar, -9.55 * unit.kilocalorie_per_mole, 0.059 * unit.kilocalorie_per_mole, 'lig', 298.15 * unit.kelvin],