# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A basic variational autoencoder (VAE) on binarized MNIST using Numpy and JAX.
This file uses the stax network definition library and the optimizers
optimization library.
"""
import os
import time

import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
from jax import jit, grad, lax, random
from jax.example_libraries import optimizers
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, FanOut, Relu, Softplus

from examples import datasets

def gaussian_kl(mu, sigmasq):
  """KL divergence from a diagonal Gaussian to the standard Gaussian."""
  return -0.5 * jnp.sum(1. + jnp.log(sigmasq) - mu**2. - sigmasq)

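# For reference, the expression above is the closed-form divergence
#   KL(N(mu, diag(sigmasq)) || N(0, I))
#     = 0.5 * sum(sigmasq + mu**2 - 1 - log(sigmasq)),
# with the overall sign folded into the leading -0.5 factor.
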
def gaussian_sample(rng, mu, sigmasq):
  """Sample a diagonal Gaussian via the reparameterization trick."""
  return mu + jnp.sqrt(sigmasq) * random.normal(rng, mu.shape)

def bernoulli_logpdf(logits, x):
  """Bernoulli log pdf of data x given logits."""
  return -jnp.sum(jnp.logaddexp(0., jnp.where(x, -1., 1.) * logits))

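# Numerical-stability note: for a binary pixel x with logit l,
#   log p(x=1 | l) = log(sigmoid(l))     = -logaddexp(0., -l)
#   log p(x=0 | l) = log(1 - sigmoid(l)) = -logaddexp(0., l),
# so jnp.where(x, -1., 1.) selects the sign of l and the log-probability
# is computed without ever forming sigmoid(l) explicitly.
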
def elbo(rng, params, images):
  """Monte Carlo estimate of the evidence lower bound (ELBO)."""
  enc_params, dec_params = params
  mu_z, sigmasq_z = encode(enc_params, images)
  logits_x = decode(dec_params, gaussian_sample(rng, mu_z, sigmasq_z))
  return bernoulli_logpdf(logits_x, images) - gaussian_kl(mu_z, sigmasq_z)

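# The single-sample Monte Carlo estimate used above is
#   ELBO(x) ~= log p(x | z) - KL(q(z | x) || p(z)),  z ~ q(z | x),
# where gaussian_sample implements the reparameterization trick so that
# gradients flow through mu_z and sigmasq_z. Training minimizes -ELBO
# averaged over the batch (see the loss inside run_epoch below).
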
def image_sample(rng, params, nrow, ncol):
  """Sample images from the generative model."""
  _, dec_params = params
  code_rng, img_rng = random.split(rng)
  logits = decode(dec_params, random.normal(code_rng, (nrow * ncol, 10)))
  # Map decoder logits to Bernoulli probabilities before sampling pixels.
  sampled_images = random.bernoulli(img_rng, jax.nn.sigmoid(logits))
  return image_grid(nrow, ncol, sampled_images, (28, 28))

def image_grid(nrow, ncol, imagevecs, imshape):
  """Reshape a stack of image vectors into an image grid for plotting."""
  images = iter(imagevecs.reshape((-1,) + imshape))
  return jnp.vstack([jnp.hstack([next(images).T for _ in range(ncol)][::-1])
                     for _ in range(nrow)]).T

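# The per-tile transposes, the [::-1] reversal, and the final .T only
# affect how the tiles are oriented and ordered within the grid; each
# sampled image itself is unchanged.
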
encoder_init, encode = stax.serial(
    Dense(512), Relu,
    Dense(512), Relu,
    FanOut(2),
    stax.parallel(Dense(10), stax.serial(Dense(10), Softplus)),
)

decoder_init, decode = stax.serial(
    Dense(512), Relu,
    Dense(512), Relu,
    Dense(28 * 28),
)

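# The encoder fans out into two 10-dimensional heads: the posterior mean
# and, via Softplus to keep it positive, the posterior variance of the
# latent code. The decoder maps a latent code back to 28 * 28 = 784
# per-pixel Bernoulli logits; probabilities are recovered with a sigmoid
# only where actually needed (see image_sample above).
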
if __name__ == "__main__":
  step_size = 0.001
  num_epochs = 100
  batch_size = 32
  nrow, ncol = 10, 10  # sampled image grid size

  test_rng = random.key(1)  # fixed prng key for evaluation
  imfile = os.path.join(os.getenv("TMPDIR", "/tmp/"), "mnist_vae_{:03d}.png")

  train_images, _, test_images, _ = datasets.mnist(permute_train=True)
  num_complete_batches, leftover = divmod(train_images.shape[0], batch_size)
  num_batches = num_complete_batches + bool(leftover)

  enc_init_rng, dec_init_rng = random.split(random.key(2))
  _, init_encoder_params = encoder_init(enc_init_rng, (batch_size, 28 * 28))
  _, init_decoder_params = decoder_init(dec_init_rng, (batch_size, 10))
  init_params = init_encoder_params, init_decoder_params

  opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=0.9)

  train_images = jax.device_put(train_images)
  test_images = jax.device_put(test_images)

  def binarize_batch(rng, i, images):
    i = i % num_batches
    batch = lax.dynamic_slice_in_dim(images, i * batch_size, batch_size)
    return random.bernoulli(rng, batch)

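  # Because binarize_batch re-samples the binary pixels with a fresh rng
  # every time a batch is drawn, each epoch sees a new binarization of
  # every image (dynamic binarization), a mild regularizer compared to
  # fixing one binary version of the dataset up front.
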
  @jit
  def run_epoch(rng, opt_state, images):
    def body_fun(i, opt_state):
      elbo_rng, data_rng = random.split(random.fold_in(rng, i))
      batch = binarize_batch(data_rng, i, images)
      loss = lambda params: -elbo(elbo_rng, params, batch) / batch_size
      g = grad(loss)(get_params(opt_state))
      return opt_update(i, g, opt_state)
    return lax.fori_loop(0, num_batches, body_fun, opt_state)

  @jit
  def evaluate(opt_state, images):
    params = get_params(opt_state)
    elbo_rng, data_rng, image_rng = random.split(test_rng, 3)
    binarized_test = random.bernoulli(data_rng, images)
    test_elbo = elbo(elbo_rng, params, binarized_test) / images.shape[0]
    sampled_images = image_sample(image_rng, params, nrow, ncol)
    return test_elbo, sampled_images

  opt_state = opt_init(init_params)
  for epoch in range(num_epochs):
    tic = time.time()
    opt_state = run_epoch(random.key(epoch), opt_state, train_images)
    test_elbo, sampled_images = evaluate(opt_state, test_images)
    print(f"{epoch: 3d} {test_elbo} ({time.time() - tic:.3f} sec)")
    plt.imsave(imfile.format(epoch), sampled_images, cmap=plt.cm.gray)
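
# A minimal sketch of how this script is typically invoked, assuming a
# JAX source checkout so that the `examples.datasets` import resolves
# (the module path is an assumption, not part of this file):
#
#   python -m examples.mnist_vae
#
# Sampled-image grids are written to $TMPDIR (or /tmp) as mnist_vae_NNN.png.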