diff --git a/pytassim/toolbox/discriminator/standard.py b/pytassim/toolbox/discriminator/standard.py
index 0533d55..5820b43 100644
--- a/pytassim/toolbox/discriminator/standard.py
+++ b/pytassim/toolbox/discriminator/standard.py
@@ -59,7 +59,7 @@ class StandardDisc(object):
     """
     def __init__(self, net,):
         self.net = net
-        self.loss_func = torch.nn.BCEWithLogitsLoss()
+        self.loss_func = torch.nn.BCEWithLogitsLoss(reduction='mean')
         self.optimizer = None
         self.grad_optim = True
@@ -150,6 +150,7 @@ def disc_loss(self, in_data, labels):
             loss has the same tensor type as given `in_data`.
         """
         loss = self.loss_func(in_data, labels)
+        #loss = unscaled_loss * in_data.nelement() / in_data.shape[0]
         return loss
 
     def forward(self, *args, **kwargs):
@@ -179,11 +180,11 @@ def _get_train_losses(self, real_data, fake_data, *args, **kwargs):
         batch_size = real_data.size()[0]
 
         real_critic = self.forward(real_data, *args, **kwargs)
-        real_labels = self.get_targets(batch_size, 1.0, real_data)
+        real_labels = self.get_targets(batch_size, 0.0, real_data)
         real_loss = self.disc_loss(real_critic, real_labels)
 
         fake_critic = self.forward(fake_data, *args, **kwargs)
-        fake_labels = self.get_targets(batch_size, 0.0, real_data)
+        fake_labels = self.get_targets(batch_size, 1.0, real_data)
         fake_loss = self.disc_loss(fake_critic, fake_labels)
 
         total_loss = real_loss + fake_loss
@@ -329,8 +330,7 @@ def gen_loss(self, fake_data, *args, **kwargs):
         batch_size = fake_data.size()[0]
 
         fake_critic = self.forward(fake_data, *args, **kwargs)
-        real_labels = self.get_targets(batch_size, 1.0, fake_data)
-        gen_loss = self.disc_loss(fake_critic, real_labels)
+        gen_loss = torch.mean(fake_critic)
         return gen_loss
 
     def recon_loss(self, recon_obs, *args, **kwargs):
diff --git a/pytassim/toolbox/discriminator/zero_penalty.py b/pytassim/toolbox/discriminator/zero_penalty.py
index 13d9060..a66281f 100644
--- a/pytassim/toolbox/discriminator/zero_penalty.py
+++ b/pytassim/toolbox/discriminator/zero_penalty.py
@@ -66,12 +66,12 @@ def _get_train_losses(self, real_data, fake_data, *args, **kwargs):
 
         real_data.requires_grad_()
         real_critic = self.forward(real_data, *args, **kwargs)
-        real_labels = self.get_targets(batch_size, 1.0, real_data)
+        real_labels = self.get_targets(batch_size, 0.0, real_data)
         real_loss = self.disc_loss(real_critic, real_labels)
 
         fake_data.requires_grad_()
         fake_critic = self.forward(fake_data, *args, **kwargs)
-        fake_labels = self.get_targets(batch_size, 0.0, real_data)
+        fake_labels = self.get_targets(batch_size, 1.0, real_data)
         fake_loss = self.disc_loss(fake_critic, fake_labels)
 
         total_loss = real_loss + fake_loss
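
Note (not part of the patch): the diff above switches the discriminator targets to a flipped convention (0.0 for real data, 1.0 for fake data) under BCEWithLogitsLoss(reduction='mean'), and replaces the generator's BCE-against-real-labels objective with the plain mean of the critic output on fake data. The following is a minimal, self-contained sketch of that convention outside the StandardDisc class; the small `critic` linear layer and the random tensors are hypothetical stand-ins for `self.net` and real batches, not code from the repository.

import torch

torch.manual_seed(0)
critic = torch.nn.Linear(4, 1)                            # stand-in for self.net
loss_func = torch.nn.BCEWithLogitsLoss(reduction='mean')  # as in StandardDisc.__init__

real_data = torch.randn(8, 4)
fake_data = torch.randn(8, 4)

# Discriminator losses with the flipped label convention from the patch.
real_critic = critic(real_data)
real_labels = torch.zeros_like(real_critic)               # real -> 0.0
fake_critic = critic(fake_data)
fake_labels = torch.ones_like(fake_critic)                # fake -> 1.0
total_loss = loss_func(real_critic, real_labels) + loss_func(fake_critic, fake_labels)

# Generator loss: mean of the raw critic output on fake data.
gen_loss = torch.mean(critic(fake_data))

print(total_loss.item(), gen_loss.item())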