diff --git a/pyod/models/so_gaal.py b/pyod/models/so_gaal.py index 31a58735..be1ffd4d 100644 --- a/pyod/models/so_gaal.py +++ b/pyod/models/so_gaal.py @@ -179,10 +179,11 @@ def fit(self, X, y=None): optimizer_d.step() self.train_history['discriminator_loss'].append(d_loss.item()) + + trick_labels = torch.ones(batch_size, 1) if stop == 0: # Train Generator - trick_labels = torch.ones(batch_size, 1) g_loss = criterion( self.discriminator(self.generator(noise)), trick_labels) @@ -198,6 +199,7 @@ def fit(self, X, y=None): trick_labels) self.train_history['generator_loss'].append(g_loss.item()) + if epoch + 1 > self.stop_epochs: stop = 1