Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds FEMap To/from networkx #112

Merged
merged 3 commits into from
Nov 20, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
starts hiding internal implementation of network
instead representations are accessed via to/from methods on FEMap class

adds to/from networkx representation

adds FEMap eq and iter magic methods
  • Loading branch information
richardjgowers committed Nov 15, 2023
commit 9b686ba8626cc5437e0c36a09d5861c463bf993e
83 changes: 67 additions & 16 deletions cinnabar/femap.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pathlib
from typing import Union

import copy
import openff.units
import pandas as pd
from openff.units import unit
@@ -95,13 +95,64 @@ class FEMap:
>>> fe.add_measurement(experimental_result2)
>>> fe.add_measurement(calculated_result)
"""
# internal representation:
# graph with measurements as edges
# absolute Measurements are an edge between 'ReferenceState' and the label
# all edges are directed, all edges can be multiply defined
graph: nx.MultiDiGraph
# all edges are directed
# all edges can be multiply defined
_graph: nx.MultiDiGraph

def __init__(self):
self.graph = nx.MultiDiGraph()
self._graph = nx.MultiDiGraph()

def __iter__(self):
for a, b, d in self._graph.edges(data=True):
# skip artificial reverse edges
if d['source'] == 'reverse':
continue

yield Measurement(labelA=a, labelB=b, **d)

def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented

# iter returns hashable Measurements, so this will compare contents
return set(self) == set(other)

def to_networkx(self) -> nx.MultiDiGraph:
"""A *copy* of the FEMap as a networkx Graph

The FEMap is represented as a multi-edged directional graph

Edges have the following attributes:
- DG: the free energy difference of going from the first edge label to
the second edge label
- uncertainty: uncertainty of the DG value
- temperature: the temperature at which DG was measured
- computational: boolean label of the original source of the data
- source: a string describing the source of data.

Note
----
All edges appear twice, once with the attribute source='reverse',
and the DG value flipped. This allows "pathfinding" like approaches,
where the DG values will be correctly summed.
"""
return copy.deepcopy(self._graph)

@classmethod
def from_networkx(cls, graph: nx.MultiDiGraph):
"""Create FEMap from network representation

Note
----
Currently absolutely no validation of the input is done.
"""
m = cls()
m._graph = graph

return m

@classmethod
def from_csv(cls, filename, units: Optional[unit.Quantity] = None):
@@ -133,8 +184,8 @@ def add_measurement(self, measurement: Measurement):

# add both directions, but flip sign for the other direction
d_backwards = {**d, 'DG': - d['DG'], 'source': 'reverse'}
self.graph.add_edge(measurement.labelA, measurement.labelB, **d)
self.graph.add_edge(measurement.labelB, measurement.labelA, **d_backwards)
self._graph.add_edge(measurement.labelA, measurement.labelB, **d)
self._graph.add_edge(measurement.labelB, measurement.labelA, **d_backwards)

def add_experimental_measurement(self,
label: Union[str, Hashable],
@@ -262,7 +313,7 @@ def get_relative_dataframe(self) -> pd.DataFrame:
"""
kcpm = unit.kilocalorie_per_mole
data = []
for l1, l2, d in self.graph.edges(data=True):
for l1, l2, d in self._graph.edges(data=True):
if d['source'] == 'reverse':
continue
if isinstance(l1, ReferenceState) or isinstance(l2, ReferenceState):
@@ -297,7 +348,7 @@ def get_absolute_dataframe(self) -> pd.DataFrame:
"""
kcpm = unit.kilocalorie_per_mole
data = []
for l1, l2, d in self.graph.edges(data=True):
for l1, l2, d in self._graph.edges(data=True):
if d['source'] == 'reverse':
continue
if not isinstance(l1, ReferenceState):
@@ -325,7 +376,7 @@ def get_absolute_dataframe(self) -> pd.DataFrame:
@property
def n_measurements(self) -> int:
"""Total number of both experimental and computational measurements"""
return len(self.graph.edges) // 2
return len(self._graph.edges) // 2

@property
def n_ligands(self) -> int:
@@ -336,7 +387,7 @@ def n_ligands(self) -> int:
def ligands(self) -> list:
"""All ligands in the graph"""
# must ignore ReferenceState nodes
return [n for n in self.graph.nodes
return [n for n in self._graph.nodes
if not isinstance(n, ReferenceState)]

@property
@@ -347,14 +398,14 @@ def degree(self) -> float:
@property
def n_edges(self) -> int:
"""Number of computational edges"""
return sum(1 for _, _, d in self.graph.edges(data=True)
return sum(1 for _, _, d in self._graph.edges(data=True)
if d['computational']) // 2

def check_weakly_connected(self) -> bool:
"""Checks if all results in the graph are reachable from other results"""
# todo; cache
comp_graph = nx.MultiGraph()
for a, b, d in self.graph.edges(data=True):
for a, b, d in self._graph.edges(data=True):
if not d['computational']:
continue
comp_graph.add_edge(a, b)
@@ -365,7 +416,7 @@ def generate_absolute_values(self):
"""Populate the FEMap with absolute computational values based on MLE"""
# TODO: Make this return a new Graph with computational nodes annotated with DG values
# TODO this could work if either relative or absolute expt values are provided
mes = list(self.graph.edges(data=True))
mes = list(self._graph.edges(data=True))
# for now, we must all be in the same units for this to work
# grab unit of first measurement
u = mes[0][-1]['DG'].u
@@ -394,7 +445,7 @@ def generate_absolute_values(self):

# find all computational result labels
comp_ligands = set()
for A, B, d in self.graph.edges(data=True):
for A, B, d in self._graph.edges(data=True):
if not d['computational']:
continue
comp_ligands.add(A)
@@ -433,7 +484,7 @@ def to_legacy_graph(self) -> nx.DiGraph:
# reduces to nx.DiGraph
g = nx.DiGraph()
# add DDG values from computational graph
for a, b, d in self.graph.edges(data=True):
for a, b, d in self._graph.edges(data=True):
if not d['computational']:
continue
if isinstance(a, ReferenceState): # skip absolute measurements
@@ -444,7 +495,7 @@ def to_legacy_graph(self) -> nx.DiGraph:
g.add_edge(a, b, calc_DDG=d['DG'].magnitude, calc_dDDG=d['uncertainty'].magnitude)
# add DG values from experiment graph
for node, d in g.nodes(data=True):
expt = self.graph.get_edge_data(ReferenceState(), node)
expt = self._graph.get_edge_data(ReferenceState(), node)
if expt is None:
continue
expt = expt[0]
24 changes: 19 additions & 5 deletions cinnabar/tests/test_femap.py
Original file line number Diff line number Diff line change
@@ -26,7 +26,21 @@ def example_map(example_csv):
def test_from_csv(example_map):
assert example_map.n_ligands == 36
assert example_map.n_edges == 58
assert len(example_map.graph.edges) == (58 + 36) * 2
assert len(example_map._graph.edges) == (58 + 36) * 2


def test_eq(example_csv):
m1 = cinnabar.FEMap.from_csv(example_csv)
m2 = cinnabar.FEMap.from_csv(example_csv)
m3 = cinnabar.FEMap.from_csv(example_csv)
m3.add_experimental_measurement(
label='this',
value=4.2 * unit.kilocalorie_per_mole,
uncertainty=0.1 * unit.kilocalorie_per_mole,
)

assert m1 == m2
assert m1 != m3


def test_degree(example_map):
@@ -77,7 +91,7 @@ def test_femap_add_experimental(ki):
)

assert set(m.ligands) == {'ligA'}
d = m.graph.get_edge_data(cinnabar.ReferenceState(), 'ligA')
d = m._graph.get_edge_data(cinnabar.ReferenceState(), 'ligA')
assert d.keys() == {0}
d = d[0]
assert d['computational'] is False
@@ -118,7 +132,7 @@ def test_add_ABFE(default_T):
source='ebay', temperature=T)

assert set(m.ligands) == {'foo'}
d = m.graph.get_edge_data(cinnabar.ReferenceState(), 'foo')
d = m._graph.get_edge_data(cinnabar.ReferenceState(), 'foo')
assert len(d) == 1
d = d[0]
assert d['DG'] == v
@@ -143,7 +157,7 @@ def test_add_RBFE(default_T):
source='ebay', temperature=T)

assert set(m.ligands) == {'foo', 'bar'}
d = m.graph.get_edge_data('foo', 'bar')
d = m._graph.get_edge_data('foo', 'bar')
assert len(d) == 1
d = d[0]
assert d['DG'] == v
@@ -166,7 +180,7 @@ def test_generate_absolute_values(example_map, ref_mle_results):
example_map.generate_absolute_values()

for e, (y_ref, yerr_ref) in ref_mle_results.items():
data = example_map.graph.get_edge_data(cinnabar.ReferenceState(label='MLE'), e)
data = example_map._graph.get_edge_data(cinnabar.ReferenceState(label='MLE'), e)
# grab the dict containing MLE data
for _, d in data.items():
if d['source'] == 'MLE':