Skip to content

Commit

Permalink
Merge branch 'fix-bug-with-mean-loss-disc' into 'dev'
Browse files Browse the repository at this point in the history
Fix bug with mean loss disc

See merge request tobifinn/torch-assimilate!55
  • Loading branch information
tobifinn committed Nov 4, 2019
2 parents e7c83f3 + 95e93c3 commit 1378c77
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
10 changes: 5 additions & 5 deletions pytassim/toolbox/discriminator/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class StandardDisc(object):
"""
def __init__(self, net,):
self.net = net
self.loss_func = torch.nn.BCEWithLogitsLoss()
self.loss_func = torch.nn.BCEWithLogitsLoss(reduction='mean')
self.optimizer = None
self.grad_optim = True

Expand Down Expand Up @@ -150,6 +150,7 @@ def disc_loss(self, in_data, labels):
loss has the same tensor type as given `in_data`.
"""
loss = self.loss_func(in_data, labels)
#loss = unscaled_loss * in_data.nelement() / in_data.shape[0]
return loss

def forward(self, *args, **kwargs):
Expand Down Expand Up @@ -179,11 +180,11 @@ def _get_train_losses(self, real_data, fake_data, *args, **kwargs):
batch_size = real_data.size()[0]

real_critic = self.forward(real_data, *args, **kwargs)
real_labels = self.get_targets(batch_size, 1.0, real_data)
real_labels = self.get_targets(batch_size, 0.0, real_data)
real_loss = self.disc_loss(real_critic, real_labels)

fake_critic = self.forward(fake_data, *args, **kwargs)
fake_labels = self.get_targets(batch_size, 0.0, real_data)
fake_labels = self.get_targets(batch_size, 1.0, real_data)
fake_loss = self.disc_loss(fake_critic, fake_labels)

total_loss = real_loss + fake_loss
Expand Down Expand Up @@ -329,8 +330,7 @@ def gen_loss(self, fake_data, *args, **kwargs):
batch_size = fake_data.size()[0]

fake_critic = self.forward(fake_data, *args, **kwargs)
real_labels = self.get_targets(batch_size, 1.0, fake_data)
gen_loss = self.disc_loss(fake_critic, real_labels)
gen_loss = torch.mean(fake_critic)
return gen_loss

def recon_loss(self, recon_obs, *args, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions pytassim/toolbox/discriminator/zero_penalty.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ def _get_train_losses(self, real_data, fake_data, *args, **kwargs):

real_data.requires_grad_()
real_critic = self.forward(real_data, *args, **kwargs)
real_labels = self.get_targets(batch_size, 1.0, real_data)
real_labels = self.get_targets(batch_size, 0.0, real_data)
real_loss = self.disc_loss(real_critic, real_labels)

fake_data.requires_grad_()
fake_critic = self.forward(fake_data, *args, **kwargs)
fake_labels = self.get_targets(batch_size, 0.0, real_data)
fake_labels = self.get_targets(batch_size, 1.0, real_data)
fake_loss = self.disc_loss(fake_critic, fake_labels)

total_loss = real_loss + fake_loss
Expand Down

0 comments on commit 1378c77

Please sign in to comment.