maybe bugs in loss backward? #15

Open
shinshiner opened this issue Jul 11, 2024 · 3 comments

Comments

@shinshiner

Hello, thanks for your great work. While exploring the code, I found something confusing and would like to check whether it is a bug.

In https://github.com/FoundationVision/OmniTokenizer/blob/main/OmniTokenizer/omnitokenizer.py#L548, you generate a fake sample and pass it through the generator and discriminator to compute the generator's adversarial loss. When this loss is back-propagated, the discriminator's weights also accumulate gradients from the generator loss.

However, since you use gradient accumulation, the optimizer does not zero the gradients before backward, so the discriminator ends up being updated with some gradients that come from the generator loss!

Is this a bug, or is it intentional for some underlying mathematical reason?
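A minimal PyTorch sketch of the effect described above, assuming toy `nn.Linear` stand-ins `G` and `D` (hypothetical, not the OmniTokenizer modules) and a hinge-style discriminator loss (also an assumption). It shows that the generator's adversarial backward pass leaves a gradient on `D`, and that a later discriminator backward without `zero_grad()` adds to it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the autoencoder (generator) and discriminator.
G = nn.Linear(4, 4)
D = nn.Linear(4, 1)
x = torch.randn(8, 4)

# Generator adversarial loss: the generator wants D to score its output as real.
fake = G(x)
g_adv_loss = -D(fake).mean()
g_adv_loss.backward()

# backward() also reaches D's parameters, so D.weight.grad is now populated.
leftover = D.weight.grad.clone()

# Discriminator loss (hinge form assumed here) on toy real/fake inputs.
real = torch.randn(8, 4)
d_loss = torch.relu(1 - D(real)).mean() + torch.relu(1 + D(fake.detach())).mean()

# Gradient of d_loss alone, computed without touching .grad.
(d_only,) = torch.autograd.grad(d_loss, D.weight, retain_graph=True)

# Without zeroing D's grads first, backward() adds d_loss's gradient on top of the
# leftover generator-loss gradient; this sum is what D's optimizer would step on.
d_loss.backward()
mixed = D.weight.grad
```

Here `mixed` equals `leftover + d_only`, i.e. the discriminator update would include the generator-loss term.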

@dreamofuture

This also confuses me. A related issue: self.forward(x, optimizer_idx=1) calls the encoder and decoder again on the same x. Why not share the result with self.forward(x, optimizer_idx=0)?
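A rough sketch of what sharing the forward pass could look like in plain PyTorch (no Lightning). The modules, losses and optimizers are hypothetical toy choices, not the repository's code. The encoder/decoder runs once, the generator loss uses the reconstruction directly, and the discriminator loss uses a detached copy.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy modules standing in for the encoder/decoder and discriminator.
autoencoder = nn.Sequential(nn.Linear(4, 2), nn.Linear(2, 4))
disc = nn.Linear(4, 1)
opt_ae = torch.optim.Adam(autoencoder.parameters(), lr=1e-3)
opt_d = torch.optim.Adam(disc.parameters(), lr=1e-3)


def train_step(x):
    # Run the encoder/decoder once; both losses reuse this reconstruction.
    x_recon = autoencoder(x)

    # Autoencoder (generator) update: reconstruction + adversarial term.
    g_loss = F.mse_loss(x_recon, x) - disc(x_recon).mean()
    opt_ae.zero_grad()
    g_loss.backward()
    opt_ae.step()

    # Discriminator update on the detached reconstruction, so no gradient flows
    # back into the autoencoder. zero_grad() first discards the gradient that
    # g_loss.backward() left on disc's parameters.
    d_loss = F.relu(1 - disc(x)).mean() + F.relu(1 + disc(x_recon.detach())).mean()
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()
    return g_loss.item(), d_loss.item()


g_val, d_val = train_step(torch.randn(8, 4))
```

One trade-off: the discriminator then sees the reconstruction from before the autoencoder update rather than a freshly recomputed one.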

@hyc9

hyc9 commented Jul 22, 2024

@shinshiner Hi, I found that the usual approach in similar frameworks is for training_step to return the generator loss from self.forward(x, optimizer_idx=0) when optimizer_idx == 0, and the discriminator loss from self.forward(x, optimizer_idx=1) when optimizer_idx == 1, i.e.:

    def training_step(self, batch, batch_idx, optimizer_idx):
        x = batch['video']
        if optimizer_idx == 0:
            # Generator / autoencoder step: sum all generator-side losses.
            recon_loss, _, vq_output, aeloss, perceptual_loss, gan_feat_loss = self.forward(x, optimizer_idx)
            commitment_loss = vq_output['commitment_loss']
            loss = recon_loss + commitment_loss + aeloss + perceptual_loss + gan_feat_loss
        elif optimizer_idx == 1:
            # Discriminator step.
            discloss = self.forward(x, optimizer_idx)
            loss = discloss
        return loss

I don't think this is fundamentally different from the author's approach. In PyTorch Lightning the order of operations is loss.backward(), opt.step(), opt.zero_grad().
For example, if step 0 optimizes the generator loss, opt1.zero_grad() only clears the gradients of the generator weights. Then at step 1, discloss.backward() runs first, followed by opt2.step(), so, just as you said, the discriminator's gradient still contains the contribution from the fake sample passed in the previous step. What do you think?
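A small plain-PyTorch check of the point about zero_grad() (no Lightning involved; `G`, `D` and the SGD optimizers are hypothetical toy choices): each optimizer's zero_grad() only clears the parameters it was constructed with, so the generator optimizer leaves the discriminator's gradient in place.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy generator and discriminator, each with its own optimizer.
G = nn.Linear(4, 4)
D = nn.Linear(4, 1)
opt_g = torch.optim.SGD(G.parameters(), lr=0.1)
opt_d = torch.optim.SGD(D.parameters(), lr=0.1)


def grad_norm(p):
    # zero_grad() either sets .grad to None (default in recent PyTorch) or zero-fills it.
    return 0.0 if p.grad is None else p.grad.norm().item()


# Step 0 (generator): backward, step, zero_grad, all through opt_g.
x = torch.randn(8, 4)
g_loss = -D(G(x)).mean()
g_loss.backward()
opt_g.step()
opt_g.zero_grad()

g_after = grad_norm(G.weight)  # cleared: opt_g owns G's parameters
d_after = grad_norm(D.weight)  # still non-zero: opt_g never touches D's parameters

# Step 1 (discriminator): clearing D's stale gradient needs D's own optimizer.
opt_d.zero_grad()
d_cleared = grad_norm(D.weight)
```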

@shinshiner
Author

@hyc9 I also noticed this while reading the repo, and I agree with you. I think the discriminator's parameters should have requires_grad = False while the generator loss is computed.
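A minimal sketch of that suggestion, assuming toy modules and a hypothetical `set_requires_grad` helper (neither is from the repository). Gradients still flow through the frozen discriminator into the generator, but the discriminator's own parameters accumulate nothing.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

G = nn.Linear(4, 4)  # hypothetical generator / autoencoder
D = nn.Linear(4, 1)  # hypothetical discriminator


def set_requires_grad(module, flag):
    # Hypothetical helper: freeze or unfreeze every parameter of a module.
    for p in module.parameters():
        p.requires_grad_(flag)


x = torch.randn(8, 4)

# Freeze D while back-propagating the generator's adversarial loss.
set_requires_grad(D, False)
g_loss = -D(G(x)).mean()
g_loss.backward()
set_requires_grad(D, True)  # unfreeze before the discriminator step

# G received gradients through D, but D accumulated none.
g_has_grad = G.weight.grad is not None
d_has_grad = D.weight.grad is not None
```

Calling the discriminator optimizer's zero_grad() right before the discriminator backward would also discard the stale term; freezing additionally skips computing D's weight gradients during the generator step. PyTorch Lightning 1.x provides `LightningModule.toggle_optimizer` for this kind of freezing in multi-optimizer setups; whether it is in effect here depends on how the repository configures its training loop.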
