-
-
Notifications
You must be signed in to change notification settings - Fork 50
/
Copy pathtrain_test.py
124 lines (104 loc) · 5.11 KB
/
train_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from os import TMP_MAX
import torch
import torch.nn as nn
import numpy as np
from optimizer import optim
from pathlib import Path
from plot import trainTestPlot
# Pick the compute device once at import time: the GPU when CUDA is usable,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
class Training:
    """Train a classifier with cross-entropy loss, optionally evaluating,
    saving the best weights, writing per-epoch checkpoints and plotting.

    Args:
        model: The ``torch.nn.Module`` to train. It is not moved to a device
            here; batches are moved to the module-level ``device``.
        optimizer: ``'sgd'`` or ``'adam'``, or an already constructed
            ``torch.optim.Optimizer`` (in which case ``learning_rate`` is not
            applied). Ignored when ``model_name`` is one of the models whose
            optimizer is built by ``optimizer.optim``.
        learning_rate: Learning rate handed to the optimizer being built.
        train_dataloader: Iterable of ``(images, labels)`` training batches;
            must support ``len()``.
        num_epochs: Number of passes over ``train_dataloader``.
        test_dataloader: Iterable of ``(images, labels)`` evaluation batches;
            must support ``len()`` when ``eval`` is true.
        eval: Evaluate on ``test_dataloader`` after every epoch.
        plot: Forwarded unchanged as the first argument of ``trainTestPlot``.
        model_name: Name used to select the optimizer/scheduler recipe and to
            prefix the saved best-weights file.
        model_save: Save the weights whenever the test accuracy improves
            (only has an effect when ``eval`` is true).
        checkpoint: Write a checkpoint file at the end of every epoch.
    """

    # Models whose optimizer *and* LR scheduler come from ``optimizer.optim``.
    _SCHEDULED_MODELS = (
        'alexnet', 'vit', 'mlpmixer', 'resmlp', 'squeezenet',
        'senet', 'mobilenetv1', 'gmlp', 'efficientnetv2',
    )

    def __init__(self, model, optimizer, learning_rate, train_dataloader, num_epochs,
                 test_dataloader, eval=True, plot=True, model_name=None, model_save=False, checkpoint=False):
        self.model = model
        self.learning_rate = learning_rate
        self.optim = optimizer
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        self.num_epochs = num_epochs
        self.eval = eval
        self.plot = plot
        self.model_name = model_name
        self.model_save = model_save
        self.checkpoint = checkpoint

    def _configure_optimizer(self):
        """Set ``self.optimizer`` and return the LR scheduler, or ``None``.

        Raises:
            ValueError: If no optimizer can be derived from the constructor
                arguments and none was assigned to ``self.optimizer``.
        """
        if self.model_name in self._SCHEDULED_MODELS:
            self.optimizer, scheduler = optim(
                model_name=self.model_name, model=self.model, lr=self.learning_rate)
            return scheduler

        if self.optim == 'sgd':
            self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate)
        elif self.optim == 'adam':
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        elif isinstance(self.optim, torch.optim.Optimizer):
            self.optimizer = self.optim
        elif getattr(self, 'optimizer', None) is None:
            # Fail here with a clear message instead of an AttributeError on
            # the first ``self.optimizer.zero_grad()`` call.
            raise ValueError(
                "Unsupported optimizer {!r}: expected 'sgd', 'adam' or a "
                "torch.optim.Optimizer instance".format(self.optim))
        return None

    def runner(self):
        """Run the full train / evaluate / save / plot loop.

        Per-epoch train and test loss/accuracy are collected and passed to
        ``trainTestPlot`` at the end. When ``eval`` is false the test entries
        are ``nan`` so the four histories keep the same length.

        Raises:
            ValueError: If the optimizer cannot be configured.
        """
        best_accuracy = float('-inf')
        criterion = nn.CrossEntropyLoss()
        scheduler = self._configure_optimizer()

        train_losses = []
        train_accu = []
        test_losses = []
        test_accu = []

        # Train the model
        total_step = len(self.train_dataloader)
        for epoch in range(self.num_epochs):
            # Evaluation below switches to eval mode; switch back so dropout
            # and batch-norm layers train correctly in every epoch.
            self.model.train()

            running_loss = 0
            correct = 0
            total = 0
            train_accuracy = 0.0
            loss = None
            for i, (images, labels) in enumerate(self.train_dataloader):
                images = images.to(device)
                labels = labels.to(device)

                # Forward pass
                outputs = self.model(images)
                loss = criterion(outputs, labels)

                # Backward and optimize
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                running_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
                train_accuracy = 100. * correct / total if total else 0.0

                if (i + 1) % 100 == 0:
                    print('Epoch [{}/{}], Step [{}/{}], Accuracy: {:.3f}, Train Loss: {:.4f}'
                          .format(epoch + 1, self.num_epochs, i + 1, total_step, train_accuracy, loss.item()))

            # Mean batch loss over the epoch.
            train_loss = running_loss / total_step if total_step else 0.0

            # Placeholders keep the history lists aligned when eval is off.
            test_loss = float('nan')
            test_accuracy = float('nan')
            if self.eval:
                self.model.eval()
                with torch.no_grad():
                    correct = 0
                    total = 0
                    running_loss = 0
                    for images, labels in self.test_dataloader:
                        images = images.to(device)
                        labels = labels.to(device)
                        outputs = self.model(images)
                        loss = criterion(outputs, labels)
                        running_loss += loss.item()
                        _, predicted = torch.max(outputs.data, 1)
                        total += labels.size(0)
                        correct += (predicted == labels).sum().item()

                num_test_batches = len(self.test_dataloader)
                test_loss = running_loss / num_test_batches if num_test_batches else float('nan')
                test_accuracy = (correct * 100) / total if total else 0.0
                print('Epoch: %.0f | Test Loss: %.3f | Accuracy: %.3f' % (epoch + 1, test_loss, test_accuracy))

                if test_accuracy > best_accuracy:
                    # Remember the best score so later, worse epochs do not
                    # overwrite the saved weights.
                    best_accuracy = test_accuracy
                    if self.model_save:
                        Path('model_store/').mkdir(parents=True, exist_ok=True)
                        prefix = self.model_name or ''
                        torch.save(self.model.state_dict(),
                                   'model_store/' + prefix + 'best-model-parameters.pt')

            for p in self.optimizer.param_groups:
                print(f"Epoch {epoch+1} Learning Rate: {p['lr']}")

            if scheduler is not None:
                scheduler.step()

            if self.checkpoint:
                path = 'checkpoints/checkpoint{:04d}.pth.tar'.format(epoch)
                Path('checkpoints/').mkdir(parents=True, exist_ok=True)
                torch.save(
                    {
                        # Number of completed epochs, i.e. the epoch index to
                        # resume from.
                        'epoch': epoch + 1,
                        'model_state_dict': self.model.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        # Last batch loss computed this epoch (test loss when
                        # evaluation ran, otherwise the last training loss).
                        'loss': loss
                    }, path
                )

            train_accu.append(train_accuracy)
            train_losses.append(train_loss)
            test_losses.append(test_loss)
            test_accu.append(test_accuracy)

        trainTestPlot(self.plot, train_accu, test_accu, train_losses, test_losses, self.model_name)