-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathtrain.py
77 lines (61 loc) · 2.38 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import torch.nn
import torchvision
from torch.utils.data import DataLoader
from tqdm import tqdm
from common import TRAIN_FOLDER, VAL_FOLDER
from balancedaccuracy import BalancedAccuracy
from network import Net
from transforms import TrainingTransform, ValidationTransform
# NOTE: You do not need to change this file
# Make sure your other code works around this

# Number of images per mini-batch, shared by the training and validation loaders.
BATCH_SIZE = 8


def _run_phase(net, loader, optim, criterion, bacc, n_classes, training, desc):
    """Run one full pass over ``loader``, showing running loss and balanced accuracy.

    When ``training`` is True the network is switched to train mode and the
    optimizer is stepped after every batch. Otherwise the network is switched
    to eval mode, autograd is disabled and the weights are left untouched, so
    a validation pass can never leak into the model.

    Args:
        net: model to run; called as ``net(batch)``.
        loader: iterable of ``(batch, labels)`` pairs.
        optim: optimizer over ``net``'s parameters (used only when training).
        criterion: loss called as ``criterion(out, labels)``; assumed to return
            the batch mean (the default for ``torch.nn.CrossEntropyLoss``).
        bacc: balanced-accuracy accumulator; reset at the start of the pass.
        n_classes: expected size of the model output's second dimension.
        training: whether this pass updates the weights.
        desc: progress-bar prefix, e.g. ``"train:   1/10"``.
    """
    # nn.Module.train(False) is equivalent to net.eval().
    net.train(training)
    total_loss = 0.0
    total_cnt = 0
    bacc.reset()
    bar = tqdm(loader)
    # Only build the autograd graph when we are actually going to backpropagate.
    with torch.set_grad_enabled(training):
        for batch, labels in bar:
            out = net(batch)
            assert out.shape[1] == n_classes
            bacc.update(out, labels)
            loss = criterion(out, labels)
            if training:
                optim.zero_grad()
                loss.backward()
                optim.step()
            n = batch.shape[0]
            # The criterion yields the batch mean. Weight it by the batch size
            # so the running value is a true per-sample average even when the
            # last batch is smaller than BATCH_SIZE.
            total_loss += loss.item() * n
            total_cnt += n
            bar.set_description(
                f"{desc} loss={total_loss / total_cnt:10.5f} bacc={100.0 * bacc.getBACC():.2f}%"
            )


def train(args):
    """Train ``Net`` on TRAIN_FOLDER and evaluate it on VAL_FOLDER each epoch.

    Args:
        args: object with an ``epochs`` attribute (anything ``int()`` accepts).

    After every epoch, the model ``state_dict`` and the training-set class
    names are written to ``model.pt`` in the current working directory.
    """
    # Set up the ImageFolder datasets.
    trainset = torchvision.datasets.ImageFolder(
        TRAIN_FOLDER, transform=TrainingTransform
    )
    validationset = torchvision.datasets.ImageFolder(
        VAL_FOLDER, transform=ValidationTransform
    )
    # The number of classes defines the network's output size.
    nClasses = len(trainset.classes)

    trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
    # Loss and balanced accuracy are order-independent, so validation need not shuffle.
    validationloader = DataLoader(validationset, batch_size=BATCH_SIZE, shuffle=False)

    # Network, optimizer and loss function.
    net = Net(nClasses)
    optim = torch.optim.Adam(net.parameters(), lr=0.0001)
    criterion = torch.nn.CrossEntropyLoss()
    bacc = BalancedAccuracy(nClasses)

    n_epochs = int(args.epochs)
    # (progress label, loader, whether this phase updates the weights)
    phases = (
        ("train:", trainloader, True),
        ("val: ", validationloader, False),
    )
    for epoch in range(n_epochs):
        for phase_label, loader, training in phases:
            _run_phase(
                net,
                loader,
                optim,
                criterion,
                bacc,
                nClasses,
                training,
                f"{phase_label} {epoch + 1:3}/{n_epochs}",
            )
        # Save a checkpoint after each epoch.
        torch.save({"model": net.state_dict(), "classes": trainset.classes}, "model.pt")