A question about "a[t - 1, x, :]" #4

QWTforGithub opened this issue May 7, 2024 · 6 comments


QWTforGithub commented May 7, 2024

Thank you very much for posting the PyTorch implementation of D3PM. I have a question about the following function:

```python
def _at(self, a, t, x):
    # t is 1-d, x is integer value of 0 to num_classes - 1
    bs = t.shape[0]
    t = t.reshape((bs, *[1] * (x.dim() - 1)))
    # out[i, j, k, l, m] = a[t[i, j, k, l], x[i, j, k, l], m]
    return a[t - 1, x, :]
```

This function seems to convert x0 to xt based on the accumulated Qt. However, in the original paper, the conversion from x0 to xt is done by multiplying x0 by Qt (Eq 3). This function does not seem to express that meaning: it only selects some values from Qt, and there is no explicit computation between x0 and Qt (the original D3PM code also implements the conversion from x0 to xt this way). At the same time, I noticed that the xt obtained in this way is sent directly to the network. How does this convert x0 to xt? Or is there something wrong with my understanding of this code? Looking forward to your reply.


qixuema commented May 8, 2024

Hi @QWTforGithub, the x_0 in the paper is a one-hot row vector. When x_0 is multiplied with Q_t in the formula, the product just picks out one row of the matrix, so in the code it is implemented as indexing the corresponding position in Q_t. I guess that's how it works.
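A minimal sketch of what I mean (a standalone toy example I made up, not code from this repo):

```python
import torch
import torch.nn.functional as F

K = 4                                           # toy number of classes (assumed)
Q_t = torch.softmax(torch.randn(K, K), dim=-1)  # a random row-stochastic transition matrix

idx = 2                                         # integer class of x_0
x0_onehot = F.one_hot(torch.tensor(idx), K).float()

# Multiplying the one-hot row vector by Q_t ...
by_matmul = x0_onehot @ Q_t                     # shape (K,)
# ... gives exactly the idx-th row of Q_t, which is what indexing does.
by_indexing = Q_t[idx, :]

print(torch.allclose(by_matmul, by_indexing))   # True
```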

I've also been following d3pm recently, feel free to reach out for discussion anytime.

@QWTforGithub (Author)

> Hi @QWTforGithub, the x_0 in the paper is a one-hot row vector. When x_0 is multiplied with Q_t in the formula, the product just picks out one row of the matrix, so in the code it is implemented as indexing the corresponding position in Q_t. I guess that's how it works.
>
> I've also been following d3pm recently, feel free to reach out for discussion anytime.

Thank you for your reply. Now I have another question about "vb loss":

```python
def vb(self, dist1, dist2):

    # flatten dist1 and dist2
    dist1 = dist1.flatten(start_dim=0, end_dim=-2)
    dist2 = dist2.flatten(start_dim=0, end_dim=-2)

    out = torch.softmax(dist1 + self.eps, dim=-1) * (
        torch.log_softmax(dist1 + self.eps, dim=-1)
        - torch.log_softmax(dist2 + self.eps, dim=-1)
    )
    return out.sum(dim=-1).mean()
```

My understanding of the "vb loss" is "Variational Bayes loss", that is, an MSE loss plus a KL divergence. The term `torch.log_softmax(dist1 + self.eps, dim=-1) - torch.log_softmax(dist2 + self.eps, dim=-1)` can be considered the difference between two posterior distributions (the log-ratio inside a KL divergence). But why is it multiplied by `torch.softmax(dist1 + self.eps, dim=-1)`? Looking forward to your reply.


qixuema commented May 8, 2024

@QWTforGithub I feel that this should just be a simple calculation of the KL divergence loss.

$$ L(y_{pred}, y_{true}) = y_{true} \cdot \log \frac{y_{true}}{y_{pred}} = y_{true} \cdot (\log y_{true} - \log y_{pred}) $$

I feel that the step 'torch.softmax(dist1 + self.eps, dim=-1)' is just converting dist1 from logits to probabilities.
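A quick sanity check of that reading (toy shapes I made up; I drop the `self.eps` term for clarity), showing the expression matches PyTorch's built-in KL divergence:

```python
import torch
import torch.nn.functional as F

dist1 = torch.randn(8, 10)  # "target" logits (batch of 8, 10 classes; shapes assumed)
dist2 = torch.randn(8, 10)  # predicted logits

# The vb expression: KL( softmax(dist1) || softmax(dist2) ), averaged over the batch.
p1 = torch.softmax(dist1, dim=-1)
manual = (p1 * (torch.log_softmax(dist1, dim=-1)
                - torch.log_softmax(dist2, dim=-1))).sum(dim=-1).mean()

# The same quantity via F.kl_div (input = log-probs from dist2, target = probs from dist1).
builtin = F.kl_div(torch.log_softmax(dist2, dim=-1), p1, reduction="batchmean")

print(torch.allclose(manual, builtin))  # True
```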

@QWTforGithub (Author)

> @QWTforGithub I feel that this should just be a simple calculation of the KL divergence loss.
>
> $$ L(y_{pred}, y_{true}) = y_{true} \cdot \log \frac{y_{true}}{y_{pred}} = y_{true} \cdot (\log y_{true} - \log y_{pred}) $$
>
> I feel that the step 'torch.softmax(dist1 + self.eps, dim=-1)' is just converting dist1 from logits to probabilities.

Thank you very much for your answer! May I ask whether you have tried to derive the formula in D3PM (Eq 3)? Why does the posterior distribution q(x_{t-1} | x_t, x_0) take the form of Eq 3, i.e. the transition probability

$$ \frac{x_t Q_t^\top \odot x_0 \bar{Q}_{t-1}}{x_0 \bar{Q}_t x_t^\top} $$

rather than

$$ \frac{x_{t-1} Q_t \odot x_0 \bar{Q}_{t-1}}{x_0 \bar{Q}_t} \, ? $$

How is Eq 3 derived?
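For reference, the starting point I assume is Bayes' rule together with the Markov property of the forward process (my own restatement, not taken from the repo):

$$ q(x_{t-1} \mid x_t, x_0) = \frac{q(x_t \mid x_{t-1}, x_0)\, q(x_{t-1} \mid x_0)}{q(x_t \mid x_0)} = \frac{q(x_t \mid x_{t-1})\, q(x_{t-1} \mid x_0)}{q(x_t \mid x_0)}. $$

I can identify $q(x_{t-1} \mid x_0) = x_0 \bar{Q}_{t-1}$ and $q(x_t \mid x_0) = x_0 \bar{Q}_t x_t^\top$, but I am not sure why $q(x_t \mid x_{t-1})$, read as a function of $x_{t-1}$ with $x_t$ fixed, becomes the row vector $x_t Q_t^\top$.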


qixuema commented May 8, 2024

@QWTforGithub Sorry, I feel it's a bit inappropriate to discuss here. I haven't found your email yet. If you don't mind, we can discuss on WeChat or elsewhere. Here's my email: qixuemaa@gmail.com.

@QWTforGithub (Author)

> @QWTforGithub Sorry, I feel it's a bit inappropriate to discuss here. I haven't found your email yet. If you don't mind, we can discuss on WeChat or elsewhere. Here's my email: qixuemaa@gmail.com.

Thank you very much. I have sent a message to your email. I hope to communicate with you further.
