Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds FEMap To/from networkx #112

Merged
merged 3 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 67 additions & 16 deletions cinnabar/femap.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pathlib
from typing import Union

import copy
import openff.units
import pandas as pd
from openff.units import unit
Expand Down Expand Up @@ -95,13 +95,64 @@ class FEMap:
>>> fe.add_measurement(experimental_result2)
>>> fe.add_measurement(calculated_result)
"""
# internal representation:
# graph with measurements as edges
# absolute Measurements are an edge between 'ReferenceState' and the label
# all edges are directed, all edges can be multiply defined
graph: nx.MultiDiGraph
# all edges are directed
# all edges can be multiply defined
_graph: nx.MultiDiGraph

def __init__(self):
self.graph = nx.MultiDiGraph()
self._graph = nx.MultiDiGraph()

def __iter__(self):
    """Yield one ``Measurement`` per stored, non-reverse edge

    Every edge of the internal graph is visited together with its attribute
    dict; edges whose ``source`` attribute is ``'reverse'`` are skipped.
    """
    for label_a, label_b, attrs in self._graph.edges(data=True):
        # only forward edges count; 'reverse' ones are artificial mirrors
        if attrs['source'] != 'reverse':
            yield Measurement(labelA=label_a, labelB=label_b, **attrs)

def __eq__(self, other):
    """Compare two maps by the set of Measurements they contain

    Returns ``NotImplemented`` when ``other`` is not an instance of this
    object's class, so Python can fall back to its default handling.
    """
    if isinstance(other, self.__class__):
        # iterating yields hashable Measurements, so building sets gives an
        # order-independent comparison of the contents
        mine = set(self)
        theirs = set(other)
        return mine == theirs

    return NotImplemented

def to_networkx(self) -> nx.MultiDiGraph:
    """Return a deep *copy* of the FEMap's underlying networkx graph

    The FEMap is stored as a directed graph that allows multiple edges
    between the same pair of nodes.

    Each edge carries these attributes:
    - DG: the free energy difference of going from the first label of the
      edge to the second
    - uncertainty: uncertainty of the DG value
    - temperature: the temperature at which DG was measured
    - computational: boolean label of the original source of the data
    - source: a string describing the source of data.

    Note
    ----
    Every edge is present twice: once as given and once in the opposite
    direction with source='reverse' and the sign of DG flipped. This
    allows "pathfinding" like approaches, where the DG values will be
    correctly summed.
    """
    # deep copy, so edits to the returned graph never reach this FEMap
    graph_copy = copy.deepcopy(self._graph)
    return graph_copy

@classmethod
def from_networkx(cls, graph: nx.MultiDiGraph):
    """Create FEMap from network representation

    The input graph is deep-copied, so the new FEMap never shares state
    with the caller's graph: adding measurements to the FEMap leaves
    ``graph`` untouched, and later edits to ``graph`` do not leak into the
    FEMap. This mirrors ``to_networkx``, which also hands out a copy.

    Parameters
    ----------
    graph : nx.MultiDiGraph
      the graph to build the FEMap from

    Returns
    -------
    A new instance of ``cls`` whose internal graph is a copy of ``graph``

    Note
    ----
    Currently absolutely no validation of the input is done.
    """
    m = cls()
    # own a private copy instead of aliasing the caller's object
    m._graph = copy.deepcopy(graph)

    return m

@classmethod
def from_csv(cls, filename, units: Optional[unit.Quantity] = None):
Expand Down Expand Up @@ -133,8 +184,8 @@ def add_measurement(self, measurement: Measurement):

# add both directions, but flip sign for the other direction
d_backwards = {**d, 'DG': - d['DG'], 'source': 'reverse'}
self.graph.add_edge(measurement.labelA, measurement.labelB, **d)
self.graph.add_edge(measurement.labelB, measurement.labelA, **d_backwards)
self._graph.add_edge(measurement.labelA, measurement.labelB, **d)
self._graph.add_edge(measurement.labelB, measurement.labelA, **d_backwards)

def add_experimental_measurement(self,
label: Union[str, Hashable],
Expand Down Expand Up @@ -262,7 +313,7 @@ def get_relative_dataframe(self) -> pd.DataFrame:
"""
kcpm = unit.kilocalorie_per_mole
data = []
for l1, l2, d in self.graph.edges(data=True):
for l1, l2, d in self._graph.edges(data=True):
if d['source'] == 'reverse':
continue
if isinstance(l1, ReferenceState) or isinstance(l2, ReferenceState):
Expand Down Expand Up @@ -297,7 +348,7 @@ def get_absolute_dataframe(self) -> pd.DataFrame:
"""
kcpm = unit.kilocalorie_per_mole
data = []
for l1, l2, d in self.graph.edges(data=True):
for l1, l2, d in self._graph.edges(data=True):
if d['source'] == 'reverse':
continue
if not isinstance(l1, ReferenceState):
Expand Down Expand Up @@ -325,7 +376,7 @@ def get_absolute_dataframe(self) -> pd.DataFrame:
@property
def n_measurements(self) -> int:
"""Total number of both experimental and computational measurements"""
return len(self.graph.edges) // 2
return len(self._graph.edges) // 2

@property
def n_ligands(self) -> int:
Expand All @@ -336,7 +387,7 @@ def n_ligands(self) -> int:
def ligands(self) -> list:
"""All ligands in the graph"""
# must ignore ReferenceState nodes
return [n for n in self.graph.nodes
return [n for n in self._graph.nodes
if not isinstance(n, ReferenceState)]

@property
Expand All @@ -347,14 +398,14 @@ def degree(self) -> float:
@property
def n_edges(self) -> int:
"""Number of computational edges"""
return sum(1 for _, _, d in self.graph.edges(data=True)
return sum(1 for _, _, d in self._graph.edges(data=True)
if d['computational']) // 2

def check_weakly_connected(self) -> bool:
"""Checks if all results in the graph are reachable from other results"""
# todo; cache
comp_graph = nx.MultiGraph()
for a, b, d in self.graph.edges(data=True):
for a, b, d in self._graph.edges(data=True):
if not d['computational']:
continue
comp_graph.add_edge(a, b)
Expand All @@ -365,7 +416,7 @@ def generate_absolute_values(self):
"""Populate the FEMap with absolute computational values based on MLE"""
# TODO: Make this return a new Graph with computational nodes annotated with DG values
# TODO this could work if either relative or absolute expt values are provided
mes = list(self.graph.edges(data=True))
mes = list(self._graph.edges(data=True))
# for now, we must all be in the same units for this to work
# grab unit of first measurement
u = mes[0][-1]['DG'].u
Expand Down Expand Up @@ -394,7 +445,7 @@ def generate_absolute_values(self):

# find all computational result labels
comp_ligands = set()
for A, B, d in self.graph.edges(data=True):
for A, B, d in self._graph.edges(data=True):
if not d['computational']:
continue
comp_ligands.add(A)
Expand Down Expand Up @@ -433,7 +484,7 @@ def to_legacy_graph(self) -> nx.DiGraph:
# reduces to nx.DiGraph
g = nx.DiGraph()
# add DDG values from computational graph
for a, b, d in self.graph.edges(data=True):
for a, b, d in self._graph.edges(data=True):
if not d['computational']:
continue
if isinstance(a, ReferenceState): # skip absolute measurements
Expand All @@ -444,7 +495,7 @@ def to_legacy_graph(self) -> nx.DiGraph:
g.add_edge(a, b, calc_DDG=d['DG'].magnitude, calc_dDDG=d['uncertainty'].magnitude)
# add DG values from experiment graph
for node, d in g.nodes(data=True):
expt = self.graph.get_edge_data(ReferenceState(), node)
expt = self._graph.get_edge_data(ReferenceState(), node)
if expt is None:
continue
expt = expt[0]
Expand Down
3 changes: 3 additions & 0 deletions cinnabar/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def __hash__(self):

class Measurement(DefaultModel):
"""The free energy difference of moving from A to B"""
class Config:
frozen = True

labelA: Hashable
labelB: Hashable
DG: FloatQuantity['kilocalorie_per_mole']
Expand Down
41 changes: 36 additions & 5 deletions cinnabar/tests/test_femap.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,21 @@ def example_map(example_csv):
def test_from_csv(example_map):
assert example_map.n_ligands == 36
assert example_map.n_edges == 58
assert len(example_map.graph.edges) == (58 + 36) * 2
assert len(example_map._graph.edges) == (58 + 36) * 2


def test_eq(example_csv):
    """Maps loaded from the same csv compare equal until one gains a measurement"""
    kcal = unit.kilocalorie_per_mole

    first = cinnabar.FEMap.from_csv(example_csv)
    second = cinnabar.FEMap.from_csv(example_csv)
    assert first == second

    # same file again, plus one extra experimental value
    extended = cinnabar.FEMap.from_csv(example_csv)
    extended.add_experimental_measurement(
        label='this',
        value=4.2 * kcal,
        uncertainty=0.1 * kcal,
    )
    assert first != extended


def test_degree(example_map):
Expand Down Expand Up @@ -77,7 +91,7 @@ def test_femap_add_experimental(ki):
)

assert set(m.ligands) == {'ligA'}
d = m.graph.get_edge_data(cinnabar.ReferenceState(), 'ligA')
d = m._graph.get_edge_data(cinnabar.ReferenceState(), 'ligA')
assert d.keys() == {0}
d = d[0]
assert d['computational'] is False
Expand Down Expand Up @@ -118,7 +132,7 @@ def test_add_ABFE(default_T):
source='ebay', temperature=T)

assert set(m.ligands) == {'foo'}
d = m.graph.get_edge_data(cinnabar.ReferenceState(), 'foo')
d = m._graph.get_edge_data(cinnabar.ReferenceState(), 'foo')
assert len(d) == 1
d = d[0]
assert d['DG'] == v
Expand All @@ -143,7 +157,7 @@ def test_add_RBFE(default_T):
source='ebay', temperature=T)

assert set(m.ligands) == {'foo', 'bar'}
d = m.graph.get_edge_data('foo', 'bar')
d = m._graph.get_edge_data('foo', 'bar')
assert len(d) == 1
d = d[0]
assert d['DG'] == v
Expand All @@ -166,7 +180,7 @@ def test_generate_absolute_values(example_map, ref_mle_results):
example_map.generate_absolute_values()

for e, (y_ref, yerr_ref) in ref_mle_results.items():
data = example_map.graph.get_edge_data(cinnabar.ReferenceState(label='MLE'), e)
data = example_map._graph.get_edge_data(cinnabar.ReferenceState(label='MLE'), e)
# grab the dict containing MLE data
for _, d in data.items():
if d['source'] == 'MLE':
Expand Down Expand Up @@ -194,3 +208,20 @@ def test_to_dataframe(example_map):
assert abs_df2.shape == (72, 5)
assert abs_df2.loc[abs_df2.computational].shape == (36, 5)
assert abs_df2.loc[~abs_df2.computational].shape == (36, 5)


def test_to_networkx(example_map):
    """The exported graph is a non-empty MultiDiGraph holding every edge twice"""
    exported = example_map.to_networkx()

    assert exported
    assert isinstance(exported, nx.MultiDiGraph)

    # (exptl + comp edges), each stored in both directions
    expected_edge_count = (36 + 58) * 2
    assert len(exported.edges) == expected_edge_count


def test_from_networkx(example_map):
    """Exporting to networkx and importing again gives an equal FEMap"""
    roundtripped = cinnabar.FEMap.from_networkx(example_map.to_networkx())

    assert example_map == roundtripped
27 changes: 27 additions & 0 deletions cinnabar/tests/test_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,33 @@ def test_ground():
assert g3 == g4


def test_measurement_hash():
    """Measurements with identical contents collapse to one entry in a set"""
    kcal = unit.kilocalorie_per_mole
    # fields common to all three measurements
    common = dict(labelA='foo', labelB='bar', uncertainty=0.01 * kcal)

    m1 = cinnabar.Measurement(DG=0.1 * kcal, computational=True, **common)
    # same contents as m1, but a distinct object
    m1a = cinnabar.Measurement(DG=0.1 * kcal, computational=True, **common)
    # differs from m1 in both DG and the computational flag
    m2 = cinnabar.Measurement(DG=0.11 * kcal, computational=False, **common)

    thing = {m1, m1a, m2}

    assert len(thing) == 2
    assert m1 in thing
    assert m2 in thing


@pytest.mark.parametrize('Ki,uncertainty,dG,dG_uncertainty,label,temp', [
[100 * unit.nanomolar, 10 * unit.nanomolar, -9.55 * unit.kilocalorie_per_mole,
0.059 * unit.kilocalorie_per_mole, 'lig', 298.15 * unit.kelvin],
Expand Down
Loading