(lstm_vae) Loss function arguments doesn't match in training #19

Open
Vhunon opened this issue Jul 27, 2023 · 3 comments
Labels
comp: algorithms · 🏅 medium · MoSCoW: Should-have · needs triage (Issue requires triage)

Comments

@Vhunon

Vhunon commented Jul 27, 2023

Thank you for this wonderful and extensive research and for sharing it publicly.
However, I have encountered an issue with how the loss of the LSTM VAE is calculated.

In the training code, you wrote:
loss = self.loss_function(x, logvar, mu, logvar, 'mean')

Meanwhile, the loss function is defined as:
def loss_function(self, x, x_hat, mean, log_var, reduction_type):
The input arguments don't match up. Or is this intended?

@SebastianSchmidl
Member

Thanks for checking out our research and code.

Do I understand correctly that you are confused because we pass logvar twice as an argument to the loss function? The number of arguments matches, and most argument names match as well:

self.loss_function(
  x=x,
  x_hat=logvar,
  mean=mu,
  log_var=logvar,
  reduction_type='mean'
)
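The same binding can be checked mechanically with Python's standard inspect module. This is a minimal sketch: LSTMVAEStub is a hypothetical class that only reproduces the loss_function signature quoted in this thread, and plain strings stand in for the tensors.

```python
import inspect


class LSTMVAEStub:
    # Hypothetical stand-in: only the signature matches the issue's loss_function
    def loss_function(self, x, x_hat, mean, log_var, reduction_type):
        raise NotImplementedError("signature-only stub")


# Signature of the bound method, so 'self' is already excluded
sig = inspect.signature(LSTMVAEStub().loss_function)

# Bind the positional call from training: loss_function(x, logvar, mu, logvar, 'mean')
bound = sig.bind("x", "logvar", "mu", "logvar", "mean")

for param, value in bound.arguments.items():
    print(f"{param} <- {value}")
```

Running it prints one line per parameter, including x_hat <- logvar and log_var <- logvar, so the tensor passed in the reconstruction slot is the log-variance.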

SebastianSchmidl added the question (Further information is requested), 🏅 medium, MoSCoW: Should-have, comp: algorithms, and needs triage (Issue requires triage) labels on Jul 28, 2023
@Vhunon
Author

Vhunon commented Jul 28, 2023

Yes.

In training, the loss function is called with these arguments:
loss = self.loss_function(x, logvar, mu, logvar, 'mean')

The loss function is defined as:

def loss_function(self, x, x_hat, mean, log_var, reduction_type):
    reproduction_loss = nn.functional.mse_loss(x_hat, x, reduction=reduction_type)
    KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    return KLD + reproduction_loss
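Written out, the loss_function above computes the following, where N is the number of elements in x, μ is mean, log σ² is log_var, and σ² = exp(log σ²). The second term is the standard closed-form KL divergence between N(μ, σ²) and a standard normal prior, summed over every element because torch.sum is called without a dim argument. The first term is a reconstruction error only if x̂ is the decoder's output for x.

```latex
\mathcal{L}(x, \hat{x}, \mu, \log\sigma^2)
  = \frac{1}{N}\sum_{i=1}^{N} \left(\hat{x}_i - x_i\right)^2
  - \frac{1}{2}\sum_{j} \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
```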

The MSE is then calculated as the reconstruction error between x and x_hat, yet logvar is what gets passed in as x_hat. Is that correct?

Please correct me if I'm wrong. Thank you!
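To see what the call in training actually optimizes, here is a dependency-free sketch that mirrors loss_function with plain Python lists instead of PyTorch tensors (a substitution made only so the example runs without torch). The toy numbers are invented, x_hat is a hypothetical decoder output, and the latent size is set equal to the input size only so that both calls run; in the real model these shapes would normally differ.

```python
import math


def mse_mean(x_hat, x):
    # Mean squared error over all elements, mirroring mse_loss(x_hat, x, reduction='mean')
    if len(x_hat) != len(x):
        raise ValueError("x_hat must have the same number of elements as x")
    return sum((a - b) ** 2 for a, b in zip(x_hat, x)) / len(x)


def kld(mean, log_var):
    # Closed-form KL(N(mean, exp(log_var)) || N(0, 1)), summed over all elements
    return -0.5 * sum(1 + lv - m ** 2 - math.exp(lv) for m, lv in zip(mean, log_var))


def loss_function(x, x_hat, mean, log_var):
    # Same structure as the issue's loss: KL term plus MSE "reconstruction" term
    return kld(mean, log_var) + mse_mean(x_hat, x)


# Made-up toy values for illustration only
x = [0.5, -1.0, 2.0]
x_hat = [0.4, -0.9, 1.8]   # hypothetical decoder reconstruction of x
mu = [0.1, -0.2, 0.0]
logvar = [-0.5, 0.3, 0.0]

as_written = loss_function(x, logvar, mu, logvar)  # logvar in the x_hat slot
with_recon = loss_function(x, x_hat, mu, logvar)   # reconstruction in the x_hat slot
print(f"as written: {as_written:.4f}, with reconstruction: {with_recon:.4f}")
```

With these values the script prints "as written: 2.3332, with reconstruction: 0.1232". In the first call the MSE term measures how far logvar is from x, so minimizing it pulls the log-variance toward the input, and the decoder output does not enter this loss at all. If the model's forward pass returns a reconstruction alongside mu and logvar (an assumption; the forward signature is not shown in this thread), the training call would presumably read self.loss_function(x, x_hat, mu, logvar, 'mean').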

@SebastianSchmidl
Member

Thank you for the clarification.

I am not familiar enough with this implementation to judge this.
@wenig, can you take a look?

SebastianSchmidl removed the question (Further information is requested) label on Jul 28, 2023
SebastianSchmidl changed the title from "Loss function arguments doesnt match in training" to "(lstm_vae) Loss function arguments doesn't match in training" on Jan 19, 2024
Projects
None yet
Development

No branches or pull requests

3 participants