-
Notifications
You must be signed in to change notification settings - Fork 0
/
VAE_MNIST_TEST.py
41 lines (30 loc) · 1.04 KB
/
VAE_MNIST_TEST.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from networks import VAE_net, VAE_net_64, VanillaVAE
import torch
import numpy as np
from torchvision import datasets, transforms
from matplotlib import pyplot as plt
# Visual sanity check for a trained VAE on the MNIST test split: for each
# batch, display the first input digit, then the model's reconstruction of it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("testing")
model_file_name = "models/VAE_MNIST_100_vanilla.model"
net = VanillaVAE()
batch_size = 16
# Optional cap on how many batches to visualize; None walks the whole test
# set (every batch opens two blocking plot windows).
max_batches = None
# Digits are resized with Resize(64) before ToTensor.
# NOTE(review): presumably VanillaVAE expects 64x64 inputs — confirm in networks.py.
test_mnist = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=transforms.Compose(
        [transforms.Resize(64), transforms.ToTensor()]
    ),
)
# shuffle=True so each run shows a different sample of test digits.
test_loader = torch.utils.data.DataLoader(test_mnist, batch_size=batch_size, shuffle=True)
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a
# GPU can still be loaded on a CPU-only machine.
net.load_state_dict(torch.load(model_file_name, map_location=device))
net = net.to(device)
net.eval()
# Inference only: disable autograd so no computation graph is built. This saves
# memory, and the outputs can be passed to matplotlib without .detach().
with torch.no_grad():
    for batch, (x, y) in enumerate(test_loader):
        if max_batches is not None and batch >= max_batches:
            break
        # Only the images are needed; the labels y are not used.
        x = x.to(device)
        # First image of the batch, channel 0 -> a 2-D array for imshow.
        plt.imshow(x[0][0].cpu(), cmap="gray")
        plt.show()
        # The model's forward output is unpacked into four values. Only the
        # first (`output`) is plotted, as the reconstruction.
        output, original, mu, logVar = net(x)
        plt.imshow(output[0][0].cpu(), cmap="gray")
        plt.show()