-
Notifications
You must be signed in to change notification settings - Fork 15
/
d3pm_runner.py
425 lines (342 loc) · 14.7 KB
/
d3pm_runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from tqdm import tqdm
def _conv_gn_act(ic: int, oc: int) -> list:
    """Return one [5x5 conv (size-preserving) -> GroupNorm -> LeakyReLU] stage."""
    # GroupNorm uses oc // 8 groups, so oc must be >= 8 and divisible by oc // 8.
    return [nn.Conv2d(ic, oc, 5, padding=2), nn.GroupNorm(oc // 8, oc), nn.LeakyReLU()]


def blk(ic: int, oc: int) -> nn.Sequential:
    """Build three conv/GroupNorm/LeakyReLU stages mapping ``ic`` -> ``oc`` channels.

    Spatial size is preserved (5x5 kernels with padding=2).
    """
    # Modules are created in the same order as before, so parameter initialisation
    # consumes the RNG identically and state-dict keys ("0.weight", ...) are unchanged.
    return nn.Sequential(
        *_conv_gn_act(ic, oc),
        *_conv_gn_act(oc, oc),
        *_conv_gn_act(oc, oc),
    )


def blku(ic: int, oc: int) -> nn.Sequential:
    """Build a ``blk(ic, oc)`` stack followed by a 2x upsampling stage.

    The upsampling is a stride-2 ``ConvTranspose2d`` + GroupNorm + LeakyReLU, so the
    output has ``oc`` channels and twice the input height and width.
    """
    return nn.Sequential(
        *blk(ic, oc),
        nn.ConvTranspose2d(oc, oc, 2, stride=2),
        nn.GroupNorm(oc // 8, oc),
        nn.LeakyReLU(),
    )
class DummyX0Model(nn.Module):
    """Class- and time-conditioned U-Net that predicts x_0 logits for discrete images.

    ``forward(x, t, cond)`` takes an integer image ``x`` of shape ``(B, C, H, W)``
    (class indices, presumably in ``[0, N)``), integer timesteps ``t`` of shape
    ``(B,)`` and integer labels ``cond`` of shape ``(B,)`` in ``[0, 10)``. Pass
    ``None`` for ``cond`` to disable class conditioning. It returns logits of shape
    ``(B, C, H, W, N)``.

    ``H`` and ``W`` must be divisible by 16. The encoder halves the resolution four
    times, and the ``down4`` skip is concatenated with the upsampled bottleneck.
    """

    def __init__(self, n_channel: int, N: int = 16) -> None:
        super().__init__()
        # Encoder stages preserve spatial size; pooling happens in forward().
        self.down1 = blk(n_channel, 16)
        self.down2 = blk(16, 32)
        self.down3 = blk(32, 64)
        self.down4 = blk(64, 512)
        self.down5 = blk(512, 512)
        # Decoder stages each double the spatial size (stride-2 ConvTranspose2d).
        self.up1 = blku(512, 512)
        self.up2 = blku(512 + 512, 64)
        self.up3 = blku(64, 32)
        self.up4 = blku(32, 16)
        self.convlast = blk(16, 16)
        self.final = nn.Conv2d(16, N * n_channel, 1, bias=False)
        # batch_first=True because _attend() feeds (batch, tokens, channels). With the
        # default (seq, batch, feature) layout, attention would run across the samples
        # of a batch instead of across spatial positions.
        self.tr1 = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        self.tr2 = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        self.tr3 = nn.TransformerEncoderLayer(d_model=64, nhead=8, batch_first=True)
        # One label embedding per injection point (10 labels).
        self.cond_embedding_1 = nn.Embedding(10, 16)
        self.cond_embedding_2 = nn.Embedding(10, 32)
        self.cond_embedding_3 = nn.Embedding(10, 64)
        self.cond_embedding_4 = nn.Embedding(10, 512)
        self.cond_embedding_5 = nn.Embedding(10, 512)
        self.cond_embedding_6 = nn.Embedding(10, 64)
        # Projections of the 32-dim sinusoidal time features (16 sin + 16 cos).
        self.temb_1 = nn.Linear(32, 16)
        self.temb_2 = nn.Linear(32, 32)
        self.temb_3 = nn.Linear(32, 64)
        self.temb_4 = nn.Linear(32, 512)
        self.N = N

    @staticmethod
    def _attend(layer: nn.Module, feat: torch.Tensor) -> torch.Tensor:
        """Apply ``layer`` over the H*W positions of a ``(B, C, H, W)`` feature map."""
        tokens = feat.flatten(2).transpose(1, 2)  # (B, H*W, C)
        return layer(tokens).transpose(1, 2).reshape(feat.shape)

    def forward(self, x, t, cond) -> torch.Tensor:
        """Return x_0 logits of shape ``(B, C, H, W, N)`` for input ``x`` at step ``t``."""
        # Scale integer states: 0..N-1 maps onto [-1, 1 - 2/N].
        x = (2 * x.float() / self.N) - 1.0

        # Sinusoidal time features at frequencies pi * 2^k, k = 0..15.
        # NOTE(review): the /1000 presumably matches n_T=1000 used in __main__.
        t = t.float().reshape(-1, 1).to(x.device) / 1000
        angles = t * torch.pi * 2.0 ** torch.arange(16, device=x.device)
        tx = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

        def as_map(vec):
            # (B, C) -> (B, C, 1, 1) so the vector broadcasts over H and W.
            return vec[..., None, None]

        def cond_map(embedding):
            # Without a label there is no class conditioning contribution.
            return 0.0 if cond is None else as_map(embedding(cond))

        pool = nn.functional.avg_pool2d
        x1 = self.down1(x) + as_map(self.temb_1(tx)) + cond_map(self.cond_embedding_1)
        x2 = self.down2(pool(x1, 2)) + as_map(self.temb_2(tx)) + cond_map(self.cond_embedding_2)
        x3 = self.down3(pool(x2, 2)) + as_map(self.temb_3(tx)) + cond_map(self.cond_embedding_3)
        x4 = self.down4(pool(x3, 2)) + as_map(self.temb_4(tx)) + cond_map(self.cond_embedding_4)
        x5 = self._attend(self.tr1, self.down5(pool(x4, 2)))

        y = self._attend(self.tr2, self.up1(x5) + cond_map(self.cond_embedding_5))
        y = self.up2(torch.cat([x4, y], dim=1)) + cond_map(self.cond_embedding_6)
        y = self._attend(self.tr3, y)
        y = self.final(self.convlast(self.up4(self.up3(y))))  # (B, C*N, H, W)

        # (B, C*N, H, W) -> (B, C, N, H, W) -> (B, C, H, W, N).
        # movedim keeps H and W in order; transpose(2, -1) would swap W and H.
        return y.reshape(y.shape[0], -1, self.N, *x.shape[2:]).movedim(2, -1).contiguous()
class D3PM(nn.Module):
    """Discrete diffusion model (D3PM) with a uniform forward transition.

    ``x0_model(x_t, t, cond)`` must return logits of shape
    ``x_t.shape + (num_classes,)`` predicting the clean data ``x_0``.

    Timesteps are 1-based:
    - ``q_mats[t - 1]`` holds the cumulative transition ``Qbar_t = Q_1 @ ... @ Q_t``.
    - ``q_one_step_transposed[t - 1]`` holds ``Q_t^T``.
    - Training draws ``t`` from ``[1, n_T - 1]``.
    - Sampling runs ``t = n_T - 1, ..., 1``, so the ``t = n_T`` matrices are built but unused.
    """

    def __init__(
        self,
        x0_model: nn.Module,
        n_T: int,
        num_classes: int = 10,
        forward_type: str = "uniform",
        hybrid_loss_coeff: float = 0.001,
    ) -> None:
        super().__init__()
        if forward_type != "uniform":
            raise NotImplementedError(f"forward_type={forward_type!r} is not supported")
        self.x0_model = x0_model
        self.n_T = n_T
        self.hybrid_loss_coeff = hybrid_loss_coeff
        self.eps = 1e-6
        self.num_classes = num_classes
        self.logit_type = "logit"

        # Cosine noise schedule, with beta_t clipped at 0.999.
        steps = torch.arange(n_T + 1, dtype=torch.float64) / n_T
        alpha_bar = torch.cos((steps + 0.008) / 1.008 * torch.pi / 2)
        self.beta_t = torch.minimum(
            1 - alpha_bar[1:] / alpha_bar[:-1], torch.ones_like(alpha_bar[1:]) * 0.999
        )

        # Uniform one-step matrices: move to each other class with prob beta/K,
        # and stay with prob 1 - (K-1)*beta/K.
        q_onestep_mats = []
        for beta in self.beta_t:
            mat = torch.ones(num_classes, num_classes) * beta / num_classes
            mat.diagonal().fill_(1 - (num_classes - 1) * beta / num_classes)
            q_onestep_mats.append(mat)
        # Transposed copies are used by q_posterior_logits.
        q_one_step_transposed = torch.stack(q_onestep_mats, dim=0).transpose(1, 2)

        # Cumulative products Qbar_t = Q_1 @ ... @ Q_t for t = 1..n_T.
        cumulative = [q_onestep_mats[0]]
        for mat in q_onestep_mats[1:]:
            cumulative.append(cumulative[-1] @ mat)
        q_mats = torch.stack(cumulative, dim=0)

        self.register_buffer("q_one_step_transposed", q_one_step_transposed)
        self.register_buffer("q_mats", q_mats)
        assert self.q_mats.shape == (n_T, num_classes, num_classes), self.q_mats.shape

    @property
    def num_classses(self) -> int:
        """Backward-compatible alias for the original (misspelled) attribute name."""
        return self.num_classes

    @num_classses.setter
    def num_classses(self, value: int) -> None:
        self.num_classes = value

    def _at(self, a, t, x):
        """Gather ``a[t - 1, x, :]`` with ``t`` (shape ``(B,)``) broadcast over ``x``.

        ``x`` holds integer class indices of shape ``(B, ...)``. The result has shape
        ``x.shape + (a.shape[-1],)``, i.e. ``out[i, ..., m] = a[t[i] - 1, x[i, ...], m]``.
        """
        bs = t.shape[0]
        t = t.reshape((bs, *[1] * (x.dim() - 1)))
        return a[t - 1, x, :]

    def _gumbel(self, noise):
        """Map uniform noise in [0, 1) to standard Gumbel noise.

        Clipping away from both 0 and 1 keeps the double log finite.
        """
        noise = torch.clip(noise, self.eps, 1.0 - self.eps)
        return -torch.log(-torch.log(noise))

    def q_posterior_logits(self, x_0, x_t, t):
        """Logits of q(x_{t-1} | x_t, x_0), or x_0's own logits where ``t == 1``.

        ``x_0`` is either integer class indices (same shape as ``x_t``, converted to
        near-one-hot log-probabilities) or float logits of shape
        ``x_t.shape + (num_classes,)``. Following equation (3) of the D3PM paper, the
        normalizing constant ``x_0 Qbar_t x_t^T`` is omitted because only logits are
        returned. ``t`` is never 0.
        """
        if torch.is_floating_point(x_0):
            x_0_logits = x_0.clone()
        else:
            # one_hot requires int64, so any integer dtype is widened first.
            x_0_logits = torch.log(
                torch.nn.functional.one_hot(x_0.long(), self.num_classes) + self.eps
            )
        assert x_0_logits.shape == x_t.shape + (self.num_classes,), (
            f"x_0_logits.shape: {x_0_logits.shape}, x_t.shape: {x_t.shape}"
        )
        # fact1[..., j] = q(x_t | x_{t-1} = j) = Q_t[j, x_t]: evidence from x_t.
        fact1 = self._at(self.q_one_step_transposed, t, x_t)
        # fact2[..., j] = sum_c p(x_0 = c) * Qbar_{t-1}[c, j]: evidence from x_0.
        softmaxed = torch.softmax(x_0_logits, dim=-1)
        # For t == 1 there is no Qbar_0. The index is clamped, and the result is
        # replaced by x_0_logits below.
        qmats2 = self.q_mats[(t - 2).clamp(min=0)].to(dtype=softmaxed.dtype)
        fact2 = torch.einsum("b...c,bcd->b...d", softmaxed, qmats2)
        out = torch.log(fact1 + self.eps) + torch.log(fact2 + self.eps)
        t_broadcast = t.reshape((t.shape[0], *[1] * x_t.dim()))
        return torch.where(t_broadcast == 1, x_0_logits, out)

    def vb(self, dist1, dist2):
        """Mean KL(softmax(dist1) || softmax(dist2)) over all leading positions.

        Both arguments are logits whose last dimension indexes classes.
        """
        dist1 = dist1.flatten(start_dim=0, end_dim=-2)
        dist2 = dist2.flatten(start_dim=0, end_dim=-2)
        # The "+ eps" shifts are no-ops under (log_)softmax; kept for numerical parity.
        out = torch.softmax(dist1 + self.eps, dim=-1) * (
            torch.log_softmax(dist1 + self.eps, dim=-1)
            - torch.log_softmax(dist2 + self.eps, dim=-1)
        )
        return out.sum(dim=-1).mean()

    def q_sample(self, x_0, t, noise):
        """Draw x_t ~ q(x_t | x_0) = Cat(Qbar_t[x_0]) via the Gumbel-max trick.

        ``noise`` is uniform with shape ``x_0.shape + (num_classes,)``.
        """
        logits = torch.log(self._at(self.q_mats, t, x_0) + self.eps)
        return torch.argmax(logits + self._gumbel(noise), dim=-1)

    def model_predict(self, x_0, t, cond):
        """Return the x_0 logits predicted by ``x0_model``.

        This is the hook for converting other model parameterisations into x_0
        logits (e.g. the logistic parameterisation of appendix A.8 of the paper).
        """
        return self.x0_model(x_0, t, cond)

    def forward(self, x: torch.Tensor, cond: torch.Tensor = None) -> torch.Tensor:
        """Compute the hybrid training loss for a batch of clean data.

        ``x`` holds integer class indices in ``[0, num_classes)`` with shape
        ``(bs, ...)``. A random step ``t`` is drawn per sample and ``x_t`` is produced
        by the forward process. The model predicts x_0 from ``x_t``.

        Returns ``(hybrid_loss_coeff * vb_loss + ce_loss, {"vb_loss": float,
        "ce_loss": float})``.
        """
        t = torch.randint(1, self.n_T, (x.shape[0],), device=x.device)
        x_t = self.q_sample(
            x, t, torch.rand((*x.shape, self.num_classes), device=x.device)
        )
        assert x_t.shape == x.shape, f"x_t.shape: {x_t.shape}, x.shape: {x.shape}"

        predicted_x0_logits = self.model_predict(x_t, t, cond)
        # Variational-bound term: KL between true and predicted posteriors.
        true_q_posterior_logits = self.q_posterior_logits(x, x_t, t)
        pred_q_posterior_logits = self.q_posterior_logits(predicted_x0_logits, x_t, t)
        vb_loss = self.vb(true_q_posterior_logits, pred_q_posterior_logits)
        # Auxiliary cross-entropy on the x_0 prediction itself.
        ce_loss = nn.functional.cross_entropy(
            predicted_x0_logits.flatten(start_dim=0, end_dim=-2), x.flatten().long()
        )
        return self.hybrid_loss_coeff * vb_loss + ce_loss, {
            "vb_loss": vb_loss.detach().item(),
            "ce_loss": ce_loss.detach().item(),
        }

    def p_sample(self, x, t, cond, noise):
        """Draw x_{t-1} ~ p(x_{t-1} | x_t) for one reverse step.

        At ``t == 1`` no noise is added, so the result is the argmax of the predicted
        x_0 logits.
        """
        predicted_x0_logits = self.model_predict(x, t, cond)
        pred_q_posterior_logits = self.q_posterior_logits(predicted_x0_logits, x, t)
        not_first_step = (t != 1).float().reshape((x.shape[0], *[1] * x.dim()))
        return torch.argmax(
            pred_q_posterior_logits + self._gumbel(noise) * not_first_step, dim=-1
        )

    def _reverse_process(self, x, cond):
        """Yield the state after each reverse step t = n_T - 1, ..., 1."""
        for step in reversed(range(1, self.n_T)):
            t = torch.full((x.shape[0],), step, dtype=torch.long, device=x.device)
            noise = torch.rand((*x.shape, self.num_classes), device=x.device)
            x = self.p_sample(x, t, cond, noise)
            yield x

    def sample(self, x, cond=None):
        """Run the full reverse chain from ``x`` and return the final sample."""
        out = x
        for out in self._reverse_process(x, cond):
            pass
        return out

    def sample_with_image_sequence(self, x, cond=None, stride=10):
        """Run the reverse chain, collecting the state every ``stride`` steps.

        The final state is appended as well when the step count is not a multiple
        of ``stride``.
        """
        images = []
        steps = 0
        for steps, x in enumerate(self._reverse_process(x, cond), start=1):
            if steps % stride == 0:
                images.append(x)
        if steps % stride != 0:
            images.append(x)
        return images
if __name__ == "__main__":
    import os

    N = 2  # number of classes for discretized state per pixel
    device = "cuda"
    # Samples are written under contents/; PIL's save() does not create directories.
    os.makedirs("contents", exist_ok=True)

    d3pm = D3PM(DummyX0Model(1, N), 1000, num_classes=N, hybrid_loss_coeff=0.0).to(device)
    print(f"Total Param Count: {sum([p.numel() for p in d3pm.x0_model.parameters()])}")

    dataset = MNIST(
        "./data",
        train=True,
        download=True,
        # Pad 28x28 digits to 32x32 so the size is divisible by the model's 4 poolings.
        transform=transforms.Compose([transforms.ToTensor(), transforms.Pad(2)]),
    )
    dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=32)
    optim = torch.optim.AdamW(d3pm.x0_model.parameters(), lr=1e-3)
    d3pm.train()

    n_epoch = 400
    global_step = 0
    for _ in range(n_epoch):
        pbar = tqdm(dataloader)
        loss_ema = None
        for x, cond in pbar:
            optim.zero_grad()
            x = x.to(device)
            cond = cond.to(device)
            # Discretize x (in [0, 1] after ToTensor) into N bins.
            x = (x * (N - 1)).round().long().clamp(0, N - 1)
            loss, info = d3pm(x, cond)
            loss.backward()
            norm = torch.nn.utils.clip_grad_norm_(d3pm.x0_model.parameters(), 0.1)
            with torch.no_grad():
                param_norm = sum([torch.norm(p) for p in d3pm.x0_model.parameters()])
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.99 * loss_ema + 0.01 * loss.item()
            pbar.set_description(
                f"loss: {loss_ema:.4f}, norm: {norm:.4f}, param_norm: {param_norm:.4f}, vb_loss: {info['vb_loss']:.4f}, ce_loss: {info['ce_loss']:.4f}"
            )
            optim.step()
            global_step += 1

            if global_step % 300 == 1:
                d3pm.eval()
                with torch.no_grad():
                    # Separate name so the training batch's `cond` is not overwritten.
                    sample_cond = torch.arange(0, 4).to(device) % 10
                    init_noise = torch.randint(0, N, (4, 1, 32, 32)).to(device)
                    images = d3pm.sample_with_image_sequence(
                        init_noise, sample_cond, stride=40
                    )
                # Turn the image sequence into an animated GIF.
                gif = []
                for image in images:
                    x_as_image = make_grid(image.float() / (N - 1), nrow=2)
                    img = x_as_image.permute(1, 2, 0).cpu().numpy()
                    img = (img * 255).astype(np.uint8)
                    gif.append(Image.fromarray(img))
                gif[0].save(
                    f"contents/sample_{global_step}.gif",
                    save_all=True,
                    append_images=gif[1:],
                    duration=100,
                    loop=0,
                )
                last_img = gif[-1]
                last_img.save(f"contents/sample_{global_step}_last.png")
                d3pm.train()