Skip to content

Commit

Permalink
Homogenize Model forward pass return signature with GroupedModel.
Browse files Browse the repository at this point in the history
  • Loading branch information
kaminow committed Oct 17, 2023
1 parent c482cab commit 08903e3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions mtenn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def forward(self, comp, *parts):

energy_val = self.strategy(complex_rep, *parts_rep)
if self.readout:
return self.readout(energy_val)
return self.readout(energy_val), [energy_val]
else:
return energy_val
return energy_val, [energy_val]

def _fix_device(self, data):
## We'll call this on everything for uniformity, but if we fix_deivec is
Expand Down Expand Up @@ -194,7 +194,7 @@ def forward(self, input_list):
flush=True,
)
# First get prediction
pred = super().forward(inp)
pred, _ = super().forward(inp)
pred_list.append(pred.detach())

# Get gradient per sample
Expand Down
6 changes: 3 additions & 3 deletions mtenn/tests/test_combination.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_mean_combination(models_and_inputs):
model_test, model_ref, inp_list, target, loss_func = models_and_inputs

# Ref calc
pred_list = [model_ref(X) for X in inp_list]
pred_list = [model_ref(X)[0] for X in inp_list]
pred_ref = torch.stack(pred_list).mean(axis=0)
loss = loss_func(pred_ref, target)
loss.backward()
Expand Down Expand Up @@ -66,7 +66,7 @@ def test_max_combination(models_and_inputs):
model_test, model_ref, inp_list, target, loss_func = models_and_inputs

# Ref calc
pred_list = [model_ref(X) for X in inp_list]
pred_list = [model_ref(X)[0] for X in inp_list]
pred = torch.logsumexp(torch.stack(pred_list), axis=0)
loss = loss_func(pred, target)
loss.backward()
Expand Down Expand Up @@ -98,7 +98,7 @@ def test_boltzmann_combination(models_and_inputs):
model_test, model_ref, inp_list, target, loss_func = models_and_inputs

# Ref calc
pred_list = torch.stack([model_ref(X) for X in inp_list])
pred_list = torch.stack([model_ref(X)[0] for X in inp_list])
w = torch.exp(-pred_list - torch.logsumexp(-pred_list, axis=0))
pred_ref = torch.dot(w.flatten(), pred_list.flatten())
loss = loss_func(pred_ref, target)
Expand Down

0 comments on commit 08903e3

Please sign in to comment.