A question about "a[t - 1, x, :]" #4
Comments
Hi @QWTforGithub, the x_0 in the paper is a one-hot row vector. When x_0 is multiplied with Q_t in the formula, the code implements this as indexing the corresponding row of Q_t; I guess that's how it works. I've also been following d3pm recently, feel free to reach out for discussion anytime.
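A quick sanity check of that equivalence (the transition matrix and class label below are made-up stand-ins, not values from the repo):

```python
import torch

num_classes = 4
Q_t = torch.rand(num_classes, num_classes)   # stand-in transition matrix
Q_t = Q_t / Q_t.sum(dim=-1, keepdim=True)    # rows sum to 1

x0 = 2                                       # integer class label
x0_onehot = torch.nn.functional.one_hot(torch.tensor(x0), num_classes).float()

# one-hot row vector times Q_t ...
matmul_result = x0_onehot @ Q_t
# ... is exactly the x0-th row of Q_t
index_result = Q_t[x0, :]

assert torch.allclose(matmul_result, index_result)
```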
Thank you for your reply. Now I have another question about "vb loss":
My understanding of "vb loss" is "Variational Bayes loss", that is, an MSE loss plus a KL divergence. The term "(torch.log_softmax(dist1 + self.eps, dim=-1) - torch.log_softmax(dist2 + self.eps, dim=-1))" can be read as the difference of two posterior log-distributions (the inside of a KL divergence). But why multiply it by "torch.softmax(dist1 + self.eps, dim=-1)"? Looking forward to your reply.
@QWTforGithub I feel this is just a straightforward computation of the KL divergence loss. The step 'torch.softmax(dist1 + self.eps, dim=-1)' simply converts dist1 from logits to probabilities.
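A minimal sketch of that reading, with `dist1`/`dist2` as stand-in logits and the repo's `self.eps` smoothing omitted: multiplying the log-difference by `softmax(dist1)` produces exactly the summands of KL(p1 || p2):

```python
import torch

dist1 = torch.randn(8, 5)   # stand-in logits of the posterior q(x_{t-1} | x_t, x_0)
dist2 = torch.randn(8, 5)   # stand-in logits of the model p_theta(x_{t-1} | x_t)

p1 = torch.softmax(dist1, dim=-1)
# KL(p1 || p2) = sum_i p1_i * (log p1_i - log p2_i)
kl = (p1 * (torch.log_softmax(dist1, dim=-1)
            - torch.log_softmax(dist2, dim=-1))).sum(dim=-1)

# cross-check against torch.distributions
ref = torch.distributions.kl_divergence(
    torch.distributions.Categorical(logits=dist1),
    torch.distributions.Categorical(logits=dist2))
assert torch.allclose(kl, ref, atol=1e-5)
```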
Thank you very much for your answer! May I ask, have you tried to derive the formula in D3PM (Eq. 3)? Why does the posterior distribution q(x_{t-1} | x_t, x_0) take the form of Eq. 3, with that transition probability p?
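For what it's worth, Eq. 3 falls out of Bayes' rule plus the Markov property of the forward chain; a sketch of the derivation, using the paper's one-hot row-vector convention:

```latex
% Bayes' rule, then the Markov property q(x_t | x_{t-1}, x_0) = q(x_t | x_{t-1}):
\[
q(x_{t-1} \mid x_t, x_0)
  = \frac{q(x_t \mid x_{t-1}, x_0)\, q(x_{t-1} \mid x_0)}{q(x_t \mid x_0)}
  = \frac{q(x_t \mid x_{t-1})\, q(x_{t-1} \mid x_0)}{q(x_t \mid x_0)}.
\]
% With one-hot rows, q(x_t | x_{t-1}) = Cat(x_t; p = x_{t-1} Q_t) and
% q(x_{t-1} | x_0) = Cat(x_{t-1}; p = x_0 \bar{Q}_{t-1}), which gives Eq. 3:
\[
q(x_{t-1} \mid x_t, x_0)
  = \mathrm{Cat}\!\Bigl(x_{t-1};\,
      p = \frac{x_t Q_t^{\top} \odot x_0 \bar{Q}_{t-1}}{x_0 \bar{Q}_t x_t^{\top}}\Bigr).
\]
```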
@QWTforGithub Sorry, I feel it's a bit inappropriate to discuss this here, and I haven't found your email. If you don't mind, we can discuss on WeChat or elsewhere. Here's my email: qixuemaa@gmail.com.
Thank you very much. I have sent a message to your email. I hope to communicate with you further. |
Thank you very much for posting the PyTorch implementation of D3PM. I have a question about the following function:

```python
def _at(self, a, t, x):
    # t is 1-d, x holds integer values from 0 to num_classes - 1
    bs = t.shape[0]
    t = t.reshape((bs, *[1] * (x.dim() - 1)))
    # out[i, j, k, l, m] = a[t[i, j, k, l], x[i, j, k, l], m]
    return a[t - 1, x, :]
```
This function seems to convert x0 to xt based on the accumulated Qt. However, in the original paper the conversion from x0 to xt is done by multiplying x0 with Qt (Eq. 3), and this function does not seem to express that: it merely selects rows from Qt, and there is no explicit computation between x0 and Qt (the original D3PM code also converts x0 to xt this way). I also noticed that the xt obtained in this way is fed directly to the network. How is the conversion from x0 to xt achieved? Or is there something wrong with my understanding of this code? Looking forward to your reply.
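For what it's worth, the two views can be checked to coincide; a self-contained sketch (shapes and names invented for illustration) showing that the fancy indexing in `_at` gives the same rows as the one-hot product x_0 Q̄_t, from which x_t is then sampled:

```python
import torch
import torch.nn.functional as F

T, K = 10, 4                               # timesteps, num_classes
a = torch.rand(T, K, K)                    # stand-in for the stacked cumulative Q_bar_t
a = a / a.sum(dim=-1, keepdim=True)        # rows sum to 1

bs, H = 3, 5
t = torch.randint(1, T + 1, (bs,))         # 1-indexed timesteps, hence a[t - 1, ...]
x0 = torch.randint(0, K, (bs, H))          # integer labels

# indexing path, as in _at: broadcast t over the spatial dims of x0
probs_index = a[t.view(bs, 1) - 1, x0, :]  # (bs, H, K)

# one-hot matmul path, literally the paper's x_0 @ Q_bar_t
onehot = F.one_hot(x0, K).float()          # (bs, H, K)
probs_matmul = torch.einsum('bhk,bkl->bhl', onehot, a[t - 1])

assert torch.allclose(probs_index, probs_matmul)

# converting x0 to xt then means sampling from this categorical distribution,
# and the sampled xt is what gets fed to the network
xt = torch.distributions.Categorical(probs=probs_index).sample()  # (bs, H)
```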