-
Notifications
You must be signed in to change notification settings - Fork 1
/
model.py
executable file
·147 lines (126 loc) · 5.49 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from neuralnet import VDSR_model
import os
from utils.common import exists, tensor2numpy
import torch
import numpy as np
class logger:
    """Lightweight record pairing a .npy file path with its in-memory history.

    Used by VDSR.train to accumulate per-epoch scalars (loss/metric) and
    remember where they should be persisted.
    """

    def __init__(self, path, values) -> None:
        # Where the value history is (re)loaded from and saved to.
        self.path = path
        # Running list of logged scalar values.
        self.values = values
class VDSR:
    """Training/evaluation wrapper around the VDSR super-resolution network.

    Owns the model, optimizer, loss/metric callables and the checkpoint /
    best-weights file paths. ``setup()`` must be called before ``train`` or
    ``evaluate``; ``load_checkpoint``/``load_weights`` restore prior state.
    """

    def __init__(self, device):
        # Device (string or torch.device) that the model and every batch
        # are moved to.
        self.device = device
        self.model = VDSR_model().to(device)
        # The attributes below are populated by setup(); None until then.
        self.optimizer = None
        self.loss = None
        self.metric = None
        self.model_path = None
        self.ckpt_path = None
        # Last loaded checkpoint dict (epoch/model/optimizer), or None.
        self.ckpt_man = None

    def setup(self, optimizer, loss, metric, model_path, ckpt_path):
        """Configure training components and save paths.

        Args:
            optimizer: torch optimizer bound to self.model's parameters.
            loss: callable (hr, sr) -> scalar tensor used for backprop.
            metric: callable (hr, sr) -> scalar tensor reported per epoch.
            model_path: file path where best weights are saved.
            ckpt_path: file path for the per-epoch resume checkpoint.
        """
        self.optimizer = optimizer
        self.loss = loss
        self.metric = metric
        self.model_path = model_path
        self.ckpt_path = ckpt_path

    def load_checkpoint(self, ckpt_path):
        """Resume model and optimizer state from ``ckpt_path`` if it exists.

        Silently does nothing when the file is absent (fresh training run).
        """
        if not exists(ckpt_path):
            return
        # map_location lets a checkpoint saved on GPU be resumed on a
        # CPU-only machine (consistent with load_weights below).
        self.ckpt_man = torch.load(ckpt_path, map_location=torch.device(self.device))
        self.optimizer.load_state_dict(self.ckpt_man['optimizer'])
        self.model.load_state_dict(self.ckpt_man['model'])

    def load_weights(self, filepath):
        """Load model weights only (no optimizer/epoch state)."""
        self.model.load_state_dict(torch.load(filepath, map_location=torch.device(self.device)))

    def predict(self, lr):
        """Super-resolve a low-resolution batch ``lr`` (inference only).

        Puts the model in eval mode and disables gradient tracking, so the
        returned tensor carries no autograd graph.
        """
        self.model.train(False)
        with torch.no_grad():
            return self.model(lr)

    def evaluate(self, dataset, batch_size=64):
        """Compute mean loss and metric over one full pass of ``dataset``.

        Returns:
            (loss, metric): floats averaged over all batches.
        """
        losses, metrics = [], []
        is_end = False
        self.model.eval()
        with torch.no_grad():
            while not is_end:
                # get_batch signals the end of the epoch via its third value.
                lr, hr, is_end = dataset.get_batch(batch_size, shuffle_each_epoch=False)
                lr, hr = lr.to(self.device), hr.to(self.device)
                sr = self.predict(lr)
                losses.append(tensor2numpy(self.loss(hr, sr)))
                metrics.append(tensor2numpy(self.metric(hr, sr)))
        return np.mean(losses), np.mean(metrics)

    def train(self, train_set, valid_set, batch_size, epochs,
              save_best_only=False, save_log=False, log_dir=None):
        """Train for ``epochs`` additional epochs, checkpointing each epoch.

        Args:
            train_set, valid_set: datasets exposing get_batch(batch_size, ...).
            batch_size: batch size for both training and validation.
            epochs: number of epochs to run on top of the resumed epoch count.
            save_best_only: if True, only save weights when val_loss improves.
            save_log: if True, append per-epoch scalars to .npy files.
            log_dir: directory for the log files; required when save_log.

        Raises:
            ValueError: if save_log is True but log_dir is None.
        """
        dict_logger = None
        if save_log:
            if log_dir is None:
                raise ValueError("log_dir must be specified if save_log is True")
            # Only touch log_dir when logging is requested. Previously this
            # ran unconditionally and crashed with TypeError when the default
            # log_dir=None was used with save_log=False.
            os.makedirs(log_dir, exist_ok=True)
            dict_logger = {
                "loss": logger(path=os.path.join(log_dir, "losses.npy"), values=[]),
                "metric": logger(path=os.path.join(log_dir, "metrics.npy"), values=[]),
                "val_loss": logger(path=os.path.join(log_dir, "val_losses.npy"), values=[]),
                "val_metric": logger(path=os.path.join(log_dir, "val_metrics.npy"), values=[]),
            }
            # Resume existing histories so re-runs append rather than overwrite.
            for log in dict_logger.values():
                if exists(log.path):
                    log.values = np.load(log.path).tolist()

        cur_epoch = 0
        if self.ckpt_man is not None:
            cur_epoch = self.ckpt_man['epoch']
        max_epoch = cur_epoch + epochs

        # Baseline for save_best_only: the validation loss of the current
        # best weights, so a resumed run doesn't overwrite them with worse ones.
        prev_loss = np.inf
        if save_best_only and exists(self.model_path):
            self.load_weights(self.model_path)
            prev_loss, _ = self.evaluate(valid_set)
            # Restore the in-progress (checkpoint) weights before training on.
            self.load_checkpoint(self.ckpt_path)

        while cur_epoch < max_epoch:
            cur_epoch += 1
            loss_buffer = []
            metric_buffer = []
            is_end = False
            while not is_end:
                lr, hr, is_end = train_set.get_batch(batch_size)
                loss, metric = self.train_step(lr, hr)
                loss_buffer.append(tensor2numpy(loss))
                metric_buffer.append(tensor2numpy(metric))

            loss = np.mean(loss_buffer)
            metric = np.mean(metric_buffer)
            val_loss, val_metric = self.evaluate(valid_set)
            print(f"Epoch {cur_epoch}/{max_epoch}",
                  f"- loss: {loss:.7f}",
                  f"- {self.metric.__name__}: {metric:.3f}",
                  f"- val_loss: {val_loss:.7f}",
                  f"- val_{self.metric.__name__}: {val_metric:.3f}")

            # Resume checkpoint is written every epoch regardless of quality.
            torch.save({'epoch': cur_epoch,
                        'model': self.model.state_dict(),
                        'optimizer': self.optimizer.state_dict()
                        }, self.ckpt_path)

            if save_log:
                dict_logger["loss"].values.append(loss)
                dict_logger["metric"].values.append(metric)
                dict_logger["val_loss"].values.append(val_loss)
                dict_logger["val_metric"].values.append(val_metric)

            # Skip saving best weights when validation loss did not improve.
            if save_best_only and val_loss > prev_loss:
                continue
            prev_loss = val_loss
            torch.save(self.model.state_dict(), self.model_path)
            print(f"Save model to {self.model_path}\n")

        if save_log:
            for log in dict_logger.values():
                np.save(log.path, np.array(log.values, dtype=np.float32))

    def train_step(self, lr, hr):
        """One optimizer step on a single (lr, hr) batch.

        Returns:
            (loss, metric): scalar tensors for this batch.
        """
        self.model.train()
        self.optimizer.zero_grad()
        lr, hr = lr.to(self.device), hr.to(self.device)
        sr = self.model(lr)
        loss = self.loss(hr, sr)
        metric = self.metric(hr, sr)
        loss.backward()
        self.optimizer.step()
        return loss, metric