-
Notifications
You must be signed in to change notification settings - Fork 6
/
test_mnist.py
91 lines (70 loc) · 2.87 KB
/
test_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from torchvision.models.resnet import ResNet, BasicBlock
from torchvision.datasets import MNIST
from tqdm.autonotebook import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import inspect
import time
from torch import nn, optim
import torch
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from torch.utils.data import DataLoader
from data.dataloaders.mnist_dataset import MnistDataset
from utils.model import Model
from imagevae.mnist_resnet import MnistResNet
def calculate_metric(metric_fn, true_y, pred_y):
    """Evaluate ``metric_fn`` on the labels, macro-averaging when supported.

    Args:
        metric_fn: Callable taking ``(true_y, pred_y)`` and optionally an
            ``average`` argument.
        true_y: Ground-truth labels, passed through unchanged.
        pred_y: Predicted labels, passed through unchanged.

    Returns:
        Whatever ``metric_fn`` returns. ``average="macro"`` is passed only
        when ``metric_fn`` declares a parameter named ``average``.
    """
    try:
        # signature() reports keyword-only parameters and follows
        # ``__wrapped__``, so ``average`` is detected on decorated metrics
        # and on metrics that declare it after a bare ``*``.
        parameters = inspect.signature(metric_fn).parameters
    except (TypeError, ValueError):
        # No introspectable signature: fall back to the plain call.
        parameters = {}

    if "average" in parameters:
        return metric_fn(true_y, pred_y, average="macro")
    return metric_fn(true_y, pred_y)
def print_scores(p, r, f1, a, batch_size):
    """Print the averaged precision, recall, F1 and accuracy, one per line.

    Args:
        p: Iterable of precision scores.
        r: Iterable of recall scores.
        f1: Iterable of F1 scores.
        a: Iterable of accuracy scores.
        batch_size: Divisor applied to each sum. It is used purely as the
            denominator, so pass the number of collected scores to get a mean.
    """
    labelled_scores = {
        "precision": p,
        "recall": r,
        "F1": f1,
        "accuracy": a,
    }
    for label, values in labelled_scores.items():
        mean_score = sum(values) / batch_size
        # Labels are right-aligned in a 14-character column.
        print("\t{}: {:.4f}".format(label.rjust(14, " "), mean_score))
# Train MnistResNet for a fixed number of epochs, validating after each epoch.
start_ts = time.time()

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

epochs = 5

model = MnistResNet().to(device)
dataset = MnistDataset()
# The third value returned by data_loaders() is unused here
# (presumably a test loader -- TODO confirm against MnistDataset).
train_loader, val_loader, _ = dataset.data_loaders(batch_size=256)

losses = []  # mean training loss of every epoch
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adadelta(model.parameters())

batches = len(train_loader)
val_batches = len(val_loader)

# training loop + eval loop
for epoch in range(epochs):
    # ---- training ----
    total_loss = 0
    progress = tqdm(enumerate(train_loader), desc="Loss: ", total=batches)
    model.train()

    for i, data in progress:
        X, y = data[0].to(device), data[1].to(device)

        model.zero_grad()
        outputs = model(X)
        loss = loss_function(outputs, y)
        loss.backward()
        optimizer.step()

        # .item() yields a plain Python float, so the running total is a float.
        current_loss = loss.item()
        total_loss += current_loss

        # Show the running mean loss of the batches seen so far.
        progress.set_description("Loss: {:.4f}".format(total_loss / (i + 1)))

    torch.cuda.empty_cache()

    # ---- validation ----
    val_losses = 0
    precision, recall, f1, accuracy = [], [], [], []

    model.eval()
    with torch.no_grad():
        for data in val_loader:
            X, y = data[0].to(device), data[1].to(device)

            outputs = model(X)
            # Accumulate a float rather than a tensor so the epoch summary
            # below prints a number instead of a tensor repr.
            val_losses += loss_function(outputs, y).item()

            # Index of the largest output along dim 1 is the predicted class.
            predicted_classes = torch.max(outputs, 1)[1]

            # Copy labels and predictions to the CPU once per batch instead
            # of once per metric.
            y_cpu = y.cpu()
            predicted_cpu = predicted_classes.cpu()

            for acc, metric in zip((precision, recall, f1, accuracy),
                                   (precision_score, recall_score, f1_score, accuracy_score)):
                acc.append(calculate_metric(metric, y_cpu, predicted_cpu))

    print(
        f"Epoch {epoch + 1}/{epochs}, training loss: {total_loss / batches}, validation loss: {val_losses / val_batches}")
    # One score per validation batch was collected, so divide by the batch count.
    print_scores(precision, recall, f1, accuracy, val_batches)
    losses.append(total_loss / batches)

    # NOTE(review): nesting reconstructed from a whitespace-stripped source;
    # a per-epoch checkpoint is assumed because the call uses ``epoch``.
    model.save_checkpoint(epoch_num=epoch)

print(losses)
print(f"Training time: {time.time() - start_ts}s")
model.save()