Merge pull request #26 from choderalab/exp
experiment loop
yuanqing-wang authored Jun 19, 2020
2 parents 6ce8462 + fcfb912 commit fa54d71
Showing 6 changed files with 318 additions and 13 deletions.
2 changes: 2 additions & 0 deletions espaloma/__init__.py
@@ -13,6 +13,8 @@
import espaloma.nn
import espaloma.graphs
import espaloma.mm
import espaloma.app


from espaloma.mm.geometry import *

3 changes: 3 additions & 0 deletions espaloma/app/__init__.py
@@ -0,0 +1,3 @@
import espaloma
import espaloma.app
import espaloma.app.experiment
214 changes: 214 additions & 0 deletions espaloma/app/experiment.py
@@ -0,0 +1,214 @@
# =============================================================================
# IMPORTS
# =============================================================================
import espaloma as esp
import abc
import torch
import copy

# =============================================================================
# MODULE CLASSES
# =============================================================================
class Experiment(abc.ABC):
""" Base class for espaloma experiment.
"""
def __init__(self):
super(Experiment, self).__init__()

class Train(Experiment):
""" Train a model for a while.
"""
def __init__(
self,
net,
data,
        metrics=[esp.metrics.TypingCrossEntropy()],
optimizer=lambda net: torch.optim.Adam(net.parameters(), 1e-3),
n_epochs=100,
record_interval=1,
):
super(Train, self).__init__()

# bookkeeping
self.net = net
self.data = data
self.metrics = metrics
self.n_epochs = n_epochs
self.record_interval = record_interval
self.states = {}

# make optimizer
if callable(optimizer):
self.optimizer = optimizer(net)
else:
self.optimizer = optimizer

# compose loss function
def loss(g):
_loss = 0.
for metric in self.metrics:
_loss += metric(g)
return _loss

self.loss = loss

def train_once(self):
""" Train the model for one batch.
"""
for g in self.data: # TODO: does this have to be a single g?

def closure():
self.optimizer.zero_grad()
loss = self.loss(g)
loss.backward()
return loss

            # pass the closure so the optimizer actually evaluates the loss
            self.optimizer.step(closure)

def train(self):
""" Train the model for multiple steps and
record the weights once every `record_interval`
"""
for epoch_idx in range(int(self.n_epochs)):
self.train_once()

# record when `record_interval` is hit
if epoch_idx % self.record_interval == 0:
self.states[epoch_idx] = copy.deepcopy(self.net.state_dict())

# record final state
self.states['final'] = copy.deepcopy(self.net.state_dict())

return self.net


class Test(Experiment):
""" Run sequences of tests on a trained model.
"""
def __init__(
self,
net,
data,
states,
        metrics=[esp.metrics.TypingCrossEntropy()],
        sampler=None):
        super(Test, self).__init__()

        # bookkeeping
self.net = net
self.data = data
self.states = states
self.metrics = metrics
self.sampler = sampler

    def test(self):
        """ Run the tests.
        """
        results = {}

        # initialize a result dict for each metric
        # (keyed by the metric class name)
        for metric in self.metrics:
            results[metric.__class__.__name__] = {}

        # make it just one giant graph
        g = self.data[0:-1]

        for state_name, state in self.states.items():  # loop through states
            # load the state dict
            self.net.load_state_dict(state)

            # loop through the metrics
            for metric in self.metrics:
                results[metric.__class__.__name__][state_name] = metric(
                    self.net,
                    g,
                    sampler=self.sampler).detach().cpu().numpy()

        # point this to self
        self.results = results
        return dict(results)

class TrainAndTest(Experiment):
""" Train a model and then test it.
"""
def __init__(
self,
net,
ds_tr,
ds_te,
        metrics_tr=[esp.metrics.TypingCrossEntropy()],
        metrics_te=[esp.metrics.TypingCrossEntropy()],
optimizer=lambda net: torch.optim.Adam(net.parameters(), 1e-3),
n_epochs=100,
record_interval=1
    ):
        super(TrainAndTest, self).__init__()

        # bookkeeping
        self.net = net
        self.ds_tr = ds_tr
        self.ds_te = ds_te
        self.metrics_tr = metrics_tr
        self.metrics_te = metrics_te
        self.optimizer = optimizer
        self.n_epochs = n_epochs
        self.record_interval = record_interval

def __str__(self):
_str = ''
_str += '# model'
_str += '\n'
_str += str(self.net)
_str += '\n'
if hasattr(self.net, 'noise_model'):
_str += '# noise model'
_str += '\n'
_str += str(self.net.noise_model)
_str += '\n'
_str += '# optimizer'
_str += '\n'
_str += str(self.optimizer)
_str += '\n'
_str += '# n_epochs'
_str += '\n'
_str += str(self.n_epochs)
_str += '\n'
return _str

def run(self):
""" Run train and test.
"""
        train = Train(
            net=self.net,
            data=self.ds_tr,
            metrics=self.metrics_tr,
            optimizer=self.optimizer,
            n_epochs=self.n_epochs,
            record_interval=self.record_interval,
        )

train.train()

self.states = train.states

        test = Test(
            net=self.net,
            data=self.ds_te,
            metrics=self.metrics_te,
            states=self.states,
        )

test.test()

self.results_te = test.results

        test = Test(
            net=self.net,
            data=self.ds_tr,
            metrics=self.metrics_tr,
            states=self.states,
        )

test.test()

self.results_tr = test.results

        return {'test': self.results_te, 'train': self.results_tr}
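For orientation, here is a minimal sketch of how the new experiment loop is meant to be driven, assembled from the fixtures in the tests below (the small ESOL dataset, GAFF-1.81 legacy typing, and the GraphConv typing network are placeholders, not a recommended configuration):

import torch
import espaloma as esp

# small example dataset with legacy atom types attached to the graphs
data = esp.data.esol(first=20)
data.apply(
    esp.graphs.legacy_force_field.LegacyForceField('gaff-1.81'),
    in_place=True)

# graph representation followed by a node-typing readout
layer = esp.nn.layers.dgl_legacy.gn('GraphConv')
net = torch.nn.Sequential(
    esp.nn.Sequential(layer, [32, 'tanh', 32, 'tanh', 32, 'tanh']),
    esp.nn.readout.node_typing.NodeTyping(in_features=32, n_classes=100))

# train, snapshot the weights every epoch, then test on both splits
experiment = esp.app.experiment.TrainAndTest(
    net=net, ds_tr=data, ds_te=data, n_epochs=1)
results = experiment.run()  # {'test': ..., 'train': ...}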
83 changes: 83 additions & 0 deletions espaloma/app/tests/test_experiment.py
@@ -0,0 +1,83 @@
import pytest
import torch

def test_import():
import espaloma as esp
esp.app.experiment


@pytest.fixture
def data():
import espaloma as esp

esol = esp.data.esol(first=20)

# do some typing
typing = esp.graphs.legacy_force_field.LegacyForceField('gaff-1.81')
    esol.apply(typing, in_place=True)  # this modifies the original data

return esol

@pytest.fixture
def net():
import espaloma as esp

# define a layer
layer = esp.nn.layers.dgl_legacy.gn('GraphConv')

# define a representation
representation = esp.nn.Sequential(
layer,
[32, 'tanh', 32, 'tanh', 32, 'tanh'])

# define a readout
readout = esp.nn.readout.node_typing.NodeTyping(
in_features=32,
n_classes=100) # not too many elements here I think?

net = torch.nn.Sequential(
representation,
readout)

return net


def test_data_and_net(data, net):
data
net


@pytest.fixture
def train(data, net):
import espaloma as esp
train = esp.app.experiment.Train(
net=net,
data=data,
n_epochs=1,
metrics=[esp.metrics.GraphMetric(
base_metric=torch.nn.CrossEntropyLoss(),
between=['nn_typing', 'legacy_typing'])])

return train

def test_train(train):
train.train()

def test_test(train, net, data):
import espaloma as esp
train.train()
test = esp.app.experiment.Test(
net=net,
data=data,
states=train.states)


def test_train_and_test(net, data):
import espaloma as esp

train_and_test = esp.app.experiment.TrainAndTest(
net=net,
n_epochs=1,
ds_tr=data,
ds_te=data
)
4 changes: 2 additions & 2 deletions espaloma/graphs/legacy_force_field.py
@@ -42,7 +42,7 @@ def __init__(self, forcefield="gaff-1.81"):
@staticmethod
def _convert_to_off(mol):
import openforcefield

if isinstance(mol, esp.Graph):
return mol.mol

@@ -149,7 +149,7 @@ def _type_gaff(self, mol, g=None):
return g

def typing(self, mol, g=None):
""" Type a molecular graph.
""" Type a molecular graph.
"""
if "gaff" in self.forcefield:
25 changes: 14 additions & 11 deletions espaloma/metrics.py
@@ -39,32 +39,32 @@ def r2(y, y_hat):
# =============================================================================
# MODULE CLASSES
# =============================================================================
class Loss(torch.nn.modules.loss._Loss):
class Metric(torch.nn.modules.loss._Loss):
""" Base function for loss.
"""
def __init__(self, size_average=None, reduce=None, reduction='mean'):
super(Loss, self).__init__(size_average, reduce, reduction)
super(Metric, self).__init__(size_average, reduce, reduction)

@abc.abstractmethod
def forward(self, *args, **kwargs):
raise NotImplementedError

class GraphLoss(Loss):
class GraphMetric(Metric):
""" Loss between nodes attributes of graph or graphs.
"""
def __init__(self, base_loss, between, *args, **kwargs):
super(GraphLoss, self).__init__(*args, **kwargs)
def __init__(self, base_metric, between, *args, **kwargs):
super(GraphMetric, self).__init__(*args, **kwargs)

# between could be tuple of two strings or two functions
assert len(between) == 2

self.between = (
self._translation(between[0]),
self._translation(between[1]))

self.base_loss = base_loss
self.base_metric = base_metric

@staticmethod
def _translation(string):
@@ -73,7 +73,7 @@ def _translation(string):
'legacy_typing': lambda g: g.ndata['legacy_typing']
}[string]


def forward(self, g_input, g_target=None):
""" Forward function of loss.
@@ -88,9 +88,12 @@ def forward(self, g_input, g_target=None):
# compute loss using base loss
# NOTE:
        # use keyword arguments here since torch losses are picky about argument order
return self.base_loss.forward(
return self.base_metric.forward(
input=input_fn(g_input),
target=target_fn(g_target))



class TypingCrossEntropy(GraphMetric):
    def __init__(self):
        super(TypingCrossEntropy, self).__init__(
            base_metric=torch.nn.CrossEntropyLoss(),
            between=['nn_typing', 'legacy_typing'])

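As a quick illustration of the renamed metric API (mirroring the fixture in the new tests), a GraphMetric composes any torch loss with a pair of node-data keys; TypingCrossEntropy is simply that composition with CrossEntropyLoss between 'nn_typing' and 'legacy_typing'. Here `g` is assumed to be an espaloma graph whose nodes already carry both attributes:

import torch
import espaloma as esp

# explicit construction: cross entropy between predicted and legacy types
metric = esp.metrics.GraphMetric(
    base_metric=torch.nn.CrossEntropyLoss(),
    between=['nn_typing', 'legacy_typing'])

# loss = metric(g)                             # explicit form
# loss = esp.metrics.TypingCrossEntropy()(g)   # equivalent shorthand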