-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
145 lines (120 loc) · 5.39 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Jan 20 15:56:01 2020
@author: lds
"""
from data.DataLoader_ILSVRC import ILSVRC2012 as Dataset
from models.GoogLeNetv2 import GoogLeNetv2 as GoogLeNet
import time, os
import torch
import torch.nn as nn
from torch import optim
import torch.optim.lr_scheduler as lr_scheduler
import numpy as np
import matplotlib.pyplot as plt
# from torchsummary import summary
# summary(net, (3, 224, 224))
train_dir = '/media/nickwang/StorageDisk/Dataset/ILSVRC2012/ILSVRC2012_img_train'
val_dir = '/media/nickwang/StorageDisk/Dataset/ILSVRC2012/ILSVRC2012_img_val'
dirname_to_classname_path = './data/dirname_to_classname'
pretrained_weights = None
num_epoch = 80
batch_size_train = 32
momentum = 0.9
learning_rate = 0.045
num_classes = 100
trainset = Dataset(train_dir, dirname_to_classname_path, num_classes)
testset = Dataset(val_dir, dirname_to_classname_path, num_classes)
train_dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size_train, shuffle=True, num_workers=8)
test_dataloader = torch.utils.data.DataLoader(testset, batch_size=batch_size_train, shuffle=False, num_workers=8)
net = GoogLeNet(num_classes, mode='train').cuda()
net.init_weights('KAMING')
if pretrained_weights != None:
net_pretrain = torch.load(pretrained_weights)
net.load_state_dict(net_pretrain)
criterion = nn.CrossEntropyLoss().cuda()
optimizer= optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum, weight_decay=0.0001)
scheduler = lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.5) # original = 0.96
train_loss_list = list()
train_accuracy_list = list()
test_loss_list = list()
test_accuracy_list = list()
for epoch in range(num_epoch):
time_s = time.time()
print('Epoch : ', epoch + 1, optimizer)
net.train()
for batch_idx, (img, y_GT) in enumerate(train_dataloader):
img = img.permute(0, 3, 1, 2).float()
y_PD, aux1, aux2 = net(img.cuda())
loss = criterion(y_PD, y_GT.long().cuda()) + 0.3 * criterion(aux1, y_GT.long().cuda()) + 0.3 * criterion(aux2, y_GT.long().cuda())
acc_batch = np.equal(y_GT.numpy(), np.argmax(y_PD.cpu().data.numpy(), axis=1))
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (batch_idx+1) % (len(train_dataloader)//40) == 0:
print("Epoch {}, Training Data Num {}, Loss {}, Batch Accuracy {}%".format(epoch+1, (batch_idx + 1) * batch_size_train, loss.item(), np.sum(np.equal(y_GT.numpy(), np.argmax(y_PD.cpu().data.numpy(), axis=1)))/len(y_GT)*100))
print("labels(GT) = ", y_GT[:10].numpy())
print("labels(PD) = ", np.argmax(y_PD.cpu().data.numpy()[:10], axis=1))
scheduler.step() # adjsut learning rate.
net.eval()
acc_train = 0
loss_train = 0
for batch_idx, (img, y_GT) in enumerate(train_dataloader):
img = img.permute(0, 3, 1, 2).float()
with torch.no_grad():
y_PD, _, _ = net(img.cuda())
loss = criterion(y_PD, y_GT.long().cuda())
acc_train += np.sum(np.equal(y_GT.numpy(), np.argmax(y_PD.cpu().data.numpy(), axis=1)))
loss_train += loss.item()
acc_train /= len(trainset)
loss_train /= len(trainset) / batch_size_train
train_loss_list.append(loss_train)
train_accuracy_list.append(acc_train)
print("Train Loss : ", loss_train, "Accuracy : %.2f%%" %(acc_train * 100))
acc_test = 0
loss_test = 0
for batch_idx, (img, y_GT) in enumerate(test_dataloader):
img = img.permute(0, 3, 1, 2).float()
with torch.no_grad():
y_PD, _, _ = net(img.cuda())
loss = criterion(y_PD, y_GT.long().cuda())
acc_test += np.sum(np.equal(y_GT.numpy(), np.argmax(y_PD.cpu().data.numpy(), axis=1)))
loss_test += loss.item()
acc_test /= len(testset)
loss_test /= len(testset) / batch_size_train
test_loss_list.append(loss_test)
test_accuracy_list.append(acc_test)
print("Test Loss : ", loss_test, "Accuracy : %.2f%%" %(acc_test * 100))
if not os.path.isdir('./weights'):
os.mkdir('weights')
torch.save(net.state_dict(), 'weights/GoogLeNet_numCls{}_epoch{}.pth'.format(num_classes, epoch+1))
print("Time Elapsed : ", time.time() - time_s)
torch.save(net.state_dict(), 'weights/GoogLeNet_numCls{}.pth'.format(num_classes))
if not os.path.isdir('./README'):
os.mkdir('README')
if not os.path.isdir('./records'):
os.mkdir('records')
x = np.arange(len(train_accuracy_list) + 1)
plt.figure()
plt.xlabel('epochs')
plt.ylabel('Accuracy')
plt.ylim(0, 1)
plt.plot(x, [0] + train_accuracy_list)
plt.plot(x, [0] + test_accuracy_list)
plt.legend(['training accuracy', 'testing accuracy'], loc='upper right')
plt.grid(True)
plt.savefig('./README/Accuracy_numCls{}.png'.format(num_classes))
plt.show()
np.save('./records/train_accuracy_numCls{}.npy'.format(num_classes), train_accuracy_list)
np.save('./records/test_accuracy_numCls{}.npy'.format(num_classes), test_accuracy_list)
plt.figure()
plt.xlabel('epochs')
plt.ylabel('Loss')
plt.plot(x, train_loss_list[0:1] + train_loss_list)
plt.plot(x, test_loss_list[0:1] + test_loss_list)
plt.legend(['training loss', 'testing loss'], loc='upper right')
plt.savefig('./README/Loss_numCls{}.png'.format(num_classes))
plt.show()
np.save('./records/train_loss_numCls{}.npy'.format(num_classes), train_loss_list)
np.save('./records/test_loss_numCls{}.npy'.format(num_classes), test_loss_list)