-
Notifications
You must be signed in to change notification settings - Fork 9
/
mnist_backprop.py
109 lines (95 loc) · 3.87 KB
/
mnist_backprop.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import time
import hydra
import torch
import torch.nn.functional as F
import torchvision
from omegaconf import DictConfig, OmegaConf
from torch.utils import tensorboard
from fwdgrad.loss import xent
# Register a custom OmegaConf resolver under the name "get_method", backed by
# hydra.utils.get_method. This runs at import time, before Hydra composes the
# config that train_model receives.
OmegaConf.register_new_resolver(name="get_method", resolver=hydra.utils.get_method)
def _build_mnist_datasets(cfg: DictConfig):
    """Create the MNIST train/test datasets and derive the model input size.

    When the configured model target contains ``"NeuralNet"`` every image is
    flattened and ``input_size`` is the number of pixels per image; otherwise
    images keep their shape and ``input_size`` is the channel count (1).

    Args:
        cfg: Hydra config; only ``cfg.model._target_`` is read here.

    Returns:
        A ``(mnist_train, mnist_test, input_size)`` tuple.
    """
    transform = [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ]
    flatten = "NeuralNet" in cfg.model._target_
    if flatten:
        transform.append(torchvision.transforms.Lambda(torch.flatten))
    composed = torchvision.transforms.Compose(transform)

    mnist_train = torchvision.datasets.MNIST("/tmp/data", train=True, download=True, transform=composed)
    mnist_test = torchvision.datasets.MNIST("/tmp/data", train=False, download=True, transform=composed)

    if flatten:
        input_size = mnist_train.data.shape[1] * mnist_train.data.shape[2]
    else:
        input_size = 1  # Channel size
    return mnist_train, mnist_test, input_size


def _train(
    model: torch.nn.Module,
    train_loader,
    optimizer: torch.optim.Optimizer,
    scheduler,
    writer,
    device: torch.device,
    total_epochs: int,
    grad_clipping: float,
) -> int:
    """Run the backprop training loop and log loss, learning rate and timings.

    Args:
        model: Network to optimize (already on ``device``).
        train_loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer built over the model parameters.
        scheduler: Learning-rate scheduler, stepped once per batch.
        writer: TensorBoard summary writer receiving the scalars.
        device: Device the batches are moved to.
        total_epochs: Number of passes over ``train_loader``.
        grad_clipping: Max gradient norm; clipping is skipped when ``<= 0``.

    Returns:
        The total number of optimization steps performed.

    Raises:
        RuntimeError: From ``clip_grad_norm_`` when clipping is enabled and the
            gradient norm is non-finite (``error_if_nonfinite=True``).
    """
    steps = 0
    t_total = 0.0
    # Placeholder so the per-epoch report below does not hit an unbound name
    # when the loader yields no batches.
    loss = torch.tensor(float("nan"))
    for epoch in range(total_epochs):
        t0 = time.perf_counter()
        for images, labels in train_loader:
            steps += 1
            loss = xent(model, images.to(device), labels.to(device))
            loss.backward()
            if grad_clipping > 0:
                # Ask the model for its parameters on every call:
                # model.parameters() is a one-shot generator, so an instance
                # already consumed by the optimizer constructor would be empty
                # here and clipping would silently do nothing.
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=grad_clipping, error_if_nonfinite=True
                )
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
            writer.add_scalar("Loss/train_loss", loss, steps)
            writer.add_scalar("Misc/lr", scheduler.get_last_lr()[0], steps)
        t1 = time.perf_counter()
        t_total += t1 - t0
        # NOTE: the tag says "batch_time" but the value is the duration of the
        # whole epoch; the tag is kept unchanged so existing dashboards still work.
        writer.add_scalar("Time/batch_time", t1 - t0, steps)
        writer.add_scalar("Time/sps", steps / t_total, steps)
        print(f"Epoch [{epoch+1}/{total_epochs}], Loss: {loss.item():.4f}, Time (s): {t1 - t0:.4f}")
    if total_epochs > 0:  # avoid ZeroDivisionError when no epoch was requested
        print("Mean time:", t_total / total_epochs)
    return steps


def _evaluate_accuracy(model: torch.nn.Module, test_loader, device: torch.device, num_samples: int) -> float:
    """Return the fraction of correctly classified samples from ``test_loader``.

    The model is switched to evaluation mode and gradients are disabled, so
    train-only layers (e.g. dropout) do not perturb the predictions and no
    autograd graph is built.

    Args:
        model: Trained network (already on ``device``); left in eval mode.
        test_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to.
        num_samples: Size of the test dataset, used as the denominator.

    Returns:
        Accuracy in ``[0, 1]``, or ``0.0`` when ``num_samples`` is 0.
    """
    model.eval()
    correct = 0
    with torch.no_grad():
        for images, labels in test_loader:
            out = model(images.to(device))
            pred = F.softmax(out, dim=-1).argmax(dim=-1)
            correct += (pred == labels.to(device)).sum().item()
    return correct / num_samples if num_samples else 0.0


@hydra.main(config_path="./configs/", config_name="config.yaml")
def train_model(cfg: DictConfig) -> None:
    """Train a model on MNIST with standard backpropagation and report test accuracy.

    Reads ``device_id``, ``epochs``, ``grad_clipping``, ``model``, ``dataset``,
    ``optimizer`` and ``scheduler`` from ``cfg``. Scalars are written to
    ``logs/backprop`` under the current working directory and progress is
    printed to stdout.

    Args:
        cfg: Hydra-composed configuration for this run.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device(f"cuda:{cfg.device_id}" if use_cuda else "cpu")
    total_epochs = cfg.epochs
    grad_clipping = cfg.grad_clipping

    # Summary
    writer = tensorboard.writer.SummaryWriter(os.path.join(os.getcwd(), "logs/backprop"))
    try:
        # Dataset creation
        mnist_train, mnist_test, input_size = _build_mnist_datasets(cfg)
        # cfg.dataset is instantiated once per split (presumably a DataLoader config — confirm in configs/).
        train_loader = hydra.utils.instantiate(cfg.dataset, dataset=mnist_train)
        test_loader = hydra.utils.instantiate(cfg.dataset, dataset=mnist_test)

        output_size = len(mnist_train.classes)
        model: torch.nn.Module = hydra.utils.instantiate(cfg.model, input_size=input_size, output_size=output_size)
        model.to(device)
        model.float()
        model.train()

        optimizer: torch.optim.Optimizer = hydra.utils.instantiate(cfg.optimizer, params=model.parameters())
        optimizer.zero_grad(set_to_none=True)
        scheduler: torch.optim.lr_scheduler._LRScheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)

        steps = _train(model, train_loader, optimizer, scheduler, writer, device, total_epochs, grad_clipping)

        # Test
        accuracy = _evaluate_accuracy(model, test_loader, device, len(mnist_test))
        writer.add_scalar("Test/accuracy", accuracy, steps)
        print(f"Test accuracy: {accuracy:.4f}")
    finally:
        # Flush and close the event file even if training fails midway.
        writer.close()
if __name__ == "__main__":
    # Script entry point. train_model is called without arguments here; it is
    # wrapped by @hydra.main, which is expected to supply the `cfg` argument.
    train_model()