-
Notifications
You must be signed in to change notification settings - Fork 1
/
discretized_mix_logistic.py
83 lines (72 loc) · 4.32 KB
/
discretized_mix_logistic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import numpy as np
import torch
from torch.nn import functional as F
from helpers import const_min, const_max, log_prob_from_logits
def discretized_mix_logistic_loss(x, l, low_bit=False):
    """Negative log-likelihood of `x` under a mixture of discretized logistics.

    Adapted from
    https://github.com/openai/pixel-cnn/blob/master/pixel_cnn_pp/nn.py
    The data is assumed to have been rescaled to the [-1, 1] interval.

    Args:
        x: True image (i.e. labels) to regress to, channels last,
            e.g. (B, 32, 32, 3).
        l: Predicted distribution parameters, e.g. (B, 32, 32, 100). The
            last axis holds `nr_mix` mixture logits followed by
            `9 * nr_mix` values that are reshaped into per-channel means,
            log-scales and coefficients.
        low_bit: If true, the half-bin width is 1/31 and the fallback
            normaliser is log(15.5); otherwise 1/255 and log(127.5).
            NOTE(review): presumably selects 5-bit vs 8-bit data - confirm.

    Returns:
        The negated log-likelihood, summed over dims 1 and 2 and divided
        by the product of `x.shape[1:]` (one value per batch element).
    """
    x_shape = list(x.shape)
    # Ten values per mixture component: 1 logit + 3 * (mean, scale, coeff).
    nr_mix = int(l.shape[-1] / 10)

    # Unpack the parameters of the mixture of logistics.
    logit_probs = l[:, :, :, :nr_mix]
    params = torch.reshape(l[:, :, :, nr_mix:], x_shape + [nr_mix * 3])
    means = params[:, :, :, :, :nr_mix]
    # Log-scales are bounded from below at -7 via the const_max helper.
    log_scales = const_max(params[:, :, :, :, nr_mix:2 * nr_mix], -7.)
    coeffs = torch.tanh(params[:, :, :, :, 2 * nr_mix:3 * nr_mix])

    # Repeat the target along a trailing mixture axis by adding zeros.
    x = (torch.reshape(x, x_shape + [1])
         + torch.zeros(x_shape + [nr_mix]).to(x.device))

    # Channel 1 is conditioned on channel 0; channel 2 on channels 0 and 1.
    per_channel = [x_shape[0], x_shape[1], x_shape[2], 1, nr_mix]
    mean_0 = torch.reshape(means[:, :, :, 0, :], per_channel)
    mean_1 = torch.reshape(
        means[:, :, :, 1, :] + coeffs[:, :, :, 0, :] * x[:, :, :, 0, :],
        per_channel)
    mean_2 = torch.reshape(
        means[:, :, :, 2, :]
        + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :]
        + coeffs[:, :, :, 2, :] * x[:, :, :, 1, :],
        per_channel)
    means = torch.cat([mean_0, mean_1, mean_2], dim=3)

    # The two bit-depth modes differ only in these two constants.
    if low_bit:
        half_bin = 1. / 31.
        log_half_levels = np.log(15.5)
    else:
        half_bin = 1. / 255.
        log_half_levels = np.log(127.5)

    centered_x = x - means
    inv_stdv = torch.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + half_bin)
    min_in = inv_stdv * (centered_x - half_bin)
    mid_in = inv_stdv * centered_x

    # Logistic CDF evaluated at the upper and lower edge of the bin.
    cdf_plus = torch.sigmoid(plus_in)
    cdf_min = torch.sigmoid(min_in)
    cdf_delta = cdf_plus - cdf_min

    # log(sigmoid(z)) = z - softplus(z); log(1 - sigmoid(z)) = -softplus(z).
    log_cdf_plus = plus_in - F.softplus(plus_in)
    log_one_minus_cdf_min = -F.softplus(min_in)
    # Log-density at the bin centre, used when the CDF difference is tiny.
    log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)

    # Interior bins: log of the bin mass, or the density-based fallback.
    interior = torch.where(cdf_delta > 1e-5,
                           torch.log(const_max(cdf_delta, 1e-12)),
                           log_pdf_mid - log_half_levels)
    # Values above 0.999 use the upper tail of the distribution.
    upper_or_interior = torch.where(x > 0.999,
                                    log_one_minus_cdf_min,
                                    interior)
    # Values below -0.999 use the lower tail of the distribution.
    log_probs = torch.where(x < -0.999, log_cdf_plus, upper_or_interior)

    # Sum over channels, add the mixture term, then marginalise components.
    log_probs = log_probs.sum(dim=3) + log_prob_from_logits(logit_probs)
    mixture_probs = torch.logsumexp(log_probs, -1)
    return -1. * mixture_probs.sum(dim=[1, 2]) / np.prod(x_shape[1:])
def sample_from_discretized_mix_logistic(l, nr_mix):
    """Draw one 3-channel sample per position from a mixture of logistics.

    Args:
        l: Predicted distribution parameters, channels last. The last axis
            holds `nr_mix` mixture logits followed by `9 * nr_mix` values
            that are reshaped into per-channel means, log-scales and
            coefficients.
        nr_mix: Number of mixture components packed into `l`.

    Returns:
        A tensor shaped like `l` with the last axis replaced by 3. Each
        channel is passed through const_max(., -1.) and const_min(., 1.)
        (presumably an element-wise clamp to [-1, 1] - confirm against
        the helpers module).
    """
    out_shape = list(l.shape)[:-1] + [3]
    spatial_shape = out_shape[:-1]

    logit_probs = l[:, :, :, :nr_mix]
    params = torch.reshape(l[:, :, :, nr_mix:], out_shape + [nr_mix * 3])

    # Gumbel-max trick: perturb the logits and take the arg-max to pick one
    # mixture component per position.
    eps = torch.empty(logit_probs.shape, device=params.device).uniform_(
        1e-5, 1. - 1e-5)
    gumbel_noise = -torch.log(-torch.log(eps))
    chosen = torch.argmax(logit_probs + gumbel_noise, dim=3)
    sel = F.one_hot(chosen, num_classes=nr_mix).float()
    sel = torch.reshape(sel, spatial_shape + [1, nr_mix])

    # Keep only the parameters of the selected component.
    means = (params[:, :, :, :, :nr_mix] * sel).sum(dim=4)
    log_scales = const_max(
        (params[:, :, :, :, nr_mix:nr_mix * 2] * sel).sum(dim=4), -7.)
    coeffs = (
        torch.tanh(params[:, :, :, :, nr_mix * 2:nr_mix * 3]) * sel
    ).sum(dim=4)

    # Sample from the logistic through its inverse CDF: log(u) - log(1 - u).
    u = torch.empty(means.shape, device=means.device).uniform_(
        1e-5, 1. - 1e-5)
    x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))

    def _bound(t):
        # Same bounding the original applies to every channel.
        return const_min(const_max(t, -1.), 1.)

    # Later channels depend linearly on the already-bounded earlier ones.
    x0 = _bound(x[:, :, :, 0])
    x1 = _bound(x[:, :, :, 1] + coeffs[:, :, :, 0] * x0)
    x2 = _bound(x[:, :, :, 2]
                + coeffs[:, :, :, 1] * x0
                + coeffs[:, :, :, 2] * x1)

    channels = [torch.reshape(c, spatial_shape + [1]) for c in (x0, x1, x2)]
    return torch.cat(channels, dim=3)