Do we need to add .detach() after var in INN.BatchNorm1d? #2

Open
Zhangyanbo opened this issue Apr 25, 2021 · 2 comments
Labels: question (Further information is requested)

Comments

@Zhangyanbo (Contributor)

In INN.BatchNorm1d, the forward function is:

def forward(self, x, log_p=0, log_det_J=0):
    if self.compute_p:
        if not self.training:
            # if in self.eval()
            var = self.running_var  # [dim]
        else:
            # if in training
            # TODO: Do we need to add .detach() after var?
            var = torch.var(x, dim=0, unbiased=False)  # [dim]

        x = super(BatchNorm1d, self).forward(x)

        log_det = -0.5 * torch.log(var + self.eps)
        log_det = torch.sum(log_det, dim=-1)

        return x, log_p, log_det_J + log_det
    else:
        return super(BatchNorm1d, self).forward(x)

Do we need var to carry gradient information? It seems we are not training BatchNorm1d itself, but rather the modules before it. Are there any references on this?
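
For intuition, here is a small, self-contained PyTorch sketch (not the INN API; the Linear layer and the log_det_term helper are stand-ins I made up) that checks whether the log-det term sends gradients back to a module placed before the normalization, with and without .detach():

import torch

def log_det_term(x, detach_var, eps=1e-5):
    # same quantity as in the forward above: sum of -0.5 * log(var + eps)
    var = torch.var(x, dim=0, unbiased=False)  # batch variance, shape [dim]
    if detach_var:
        var = var.detach()  # cut the graph here
    return torch.sum(-0.5 * torch.log(var + eps), dim=-1)

for detach_var in (False, True):
    layer = torch.nn.Linear(3, 3)  # stand-in for a module before BatchNorm1d
    ld = log_det_term(layer(torch.randn(5, 3)), detach_var)
    if ld.requires_grad:
        grad = torch.autograd.grad(ld, layer.weight)[0]
        print(detach_var, grad.norm().item())  # a nonzero gradient reaches the layer
    else:
        print(detach_var, 'no gradient path back to the upstream layer')

If the log-likelihood objective includes this log-det term, detaching var removes that term's gradient contribution to everything upstream; only the remaining, data-dependent part of the loss would still train those modules.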

@Zhangyanbo (Contributor, Author)

Compare with nn.BatchNorm1d:

import torch
import torch.nn as nn

x = torch.randn((5, 3))
bn = nn.BatchNorm1d(3, affine=False)

bn(x)

The output is:

tensor([[-1.6941,  0.2933, -0.2451],
        [-0.1313, -0.2711,  1.4740],
        [ 0.2754, -0.2282,  0.4445],
        [ 0.1287, -1.4409, -0.0721],
        [ 1.4213,  1.6469, -1.6014]])

So, if we do not use affine in bn (affine=False), BatchNorm1d has no learnable parameters of its own, and BatchNorm itself does not need gradients.
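
A quick sanity check of this: with affine=False the layer keeps only running-statistics buffers and owns no parameters.

import torch.nn as nn

bn = nn.BatchNorm1d(3, affine=False)
print(list(bn.parameters()))  # [] -- no learnable weight or bias
print(bn.running_mean, bn.running_var)  # buffers, updated without autograd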

@Zhangyanbo (Contributor, Author) commented Apr 26, 2021

Experiments show that if we add .detach(), the training loss does not decrease, while without .detach() training works. So, in the latest version, I added a requires_grad: bool parameter to INN.BatchNorm1d.
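
A minimal sketch of the idea, assuming the flag simply toggles a .detach() on the batch variance (the class below is my own simplification, not the code shipped in INN):

import torch
import torch.nn as nn

class BatchNorm1d(nn.BatchNorm1d):
    def __init__(self, dim, requires_grad=True, **kwargs):
        super().__init__(dim, **kwargs)
        self.compute_p = True
        self._var_requires_grad = requires_grad  # attribute name is arbitrary

    def forward(self, x, log_p=0, log_det_J=0):
        if not self.compute_p:
            return super().forward(x)
        if self.training:
            var = torch.var(x, dim=0, unbiased=False)  # batch variance [dim]
            if not self._var_requires_grad:
                var = var.detach()  # stop gradients from the log-det term
        else:
            var = self.running_var  # [dim]
        x = super().forward(x)
        log_det = torch.sum(-0.5 * torch.log(var + self.eps), dim=-1)
        return x, log_p, log_det_J + log_det

For example, BatchNorm1d(3, requires_grad=False, affine=False) reproduces the detached behaviour, while the default keeps the gradient path through var.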
