forked from CUG-URS/SSDGL
-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
71 lines (58 loc) · 2.51 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from simplecv import dp_train as train
import torch
from simplecv.util.logger import eval_progress, speed
import time
from module import SSDGL
from simplecv.util import metric
from simplecv.util import registry
from torch.utils.data.dataloader import DataLoader
from simplecv import registry
from simplecv.core.config import AttrDict
from scipy.io import loadmat
import data.dataloader
def fcn_evaluate_fn(self, test_dataloader, config):
    """Evaluate the model on ``test_dataloader`` and log accuracy metrics.

    Written to be installed as a launcher's evaluate method (it takes the
    launcher as ``self``). For every batch it times the forward pass, converts
    the network output into class labels and keeps only the pixels selected by
    the batch's indicator ``w``. After the loop it logs overall accuracy (OA),
    average accuracy (AA), Cohen's kappa and the per-class accuracies computed
    over the selected pixels of *all* batches.

    Args:
        self: Launcher-like object. Reads ``self.checkpoint.global_step``,
            ``self._model`` and ``self._logger``.
        test_dataloader: Sized iterable yielding ``(im, mask, w)`` tuples:
            model input, ground-truth labels, and an indicator of which
            positions take part in the evaluation.
        config: Unused; kept so the signature matches the evaluate hook.

    Returns:
        None. Results are reported through ``self._logger``. Nothing is
        logged when ``global_step`` is negative or the dataloader is empty.
    """
    if self.checkpoint.global_step < 0:
        return

    self._model.eval()
    # Synchronizing is only meaningful (and only legal) when CUDA exists;
    # calling it on a CPU-only machine raises instead of timing anything.
    cuda_available = torch.cuda.is_available()
    total_time = 0.
    selected_labels = []
    selected_preds = []

    with torch.no_grad():
        for idx, (im, mask, w) in enumerate(test_dataloader):
            start = time.time()
            y_pred = self._model(im).squeeze()
            if cuda_available:
                # Wait for pending kernels so the wall-clock time covers the
                # whole forward pass.
                torch.cuda.synchronize()
            time_cost = round(time.time() - start, 3)

            # NOTE(review): assumes one image per batch, so that after
            # squeeze() the class axis is dim 0 -- confirm against the loader.
            # The +1 presumably maps 0-based channels onto 1-based labels.
            y_pred = y_pred.argmax(dim=0).cpu() + 1

            # byte() keeps the original integer truncation of ``w``; bool()
            # gives masked_select the mask dtype current PyTorch requires.
            # The loader's tensor is no longer modified in place.
            keep = w.byte().bool().reshape(-1)
            selected_labels.append(torch.masked_select(mask.reshape(-1), keep))
            selected_preds.append(torch.masked_select(y_pred.reshape(-1), keep))

            total_time += time_cost
            speed(self._logger, time_cost, 'im')
            eval_progress(self._logger, idx + 1, len(test_dataloader))

    if not selected_labels:
        # Empty dataloader: there is nothing to average or score.
        return

    speed(self._logger, round(total_time / len(test_dataloader), 3), 'batched im (avg)')

    # Score every evaluated pixel at once so the logged numbers describe the
    # whole test set rather than only the last batch (identical for a
    # single-batch loader).
    y_true = torch.cat(selected_labels)
    y_hat = torch.cat(selected_preds)

    # Unwrap a DataParallel-style container when present; a bare model is
    # used as-is.
    net = getattr(self._model, 'module', self._model)
    num_classes = net.config.num_classes

    oa = metric.th_overall_accuracy_score(y_true, y_hat)
    aa, acc_per_class = metric.th_average_accuracy_score(y_true, y_hat,
                                                         num_classes,
                                                         return_accuracys=True)
    kappa = metric.th_cohen_kappa_score(y_true, y_hat, num_classes)

    metric_dict = {
        'OA': oa.item(),
        'AA': aa.item(),
        'Kappa': kappa.item()
    }
    for i, acc in enumerate(acc_per_class):
        metric_dict['acc_{}'.format(i + 1)] = acc.item()
    self._logger.eval_log(metric_dict=metric_dict, step=self.checkpoint.global_step)
def register_evaluate_fn(launcher):
    """Install ``fcn_evaluate_fn`` as the evaluation routine of ``launcher``.

    Args:
        launcher: Object exposing ``override_evaluate``; it receives the
            custom evaluation function.
    """
    evaluate_fn = fcn_evaluate_fn
    launcher.override_evaluate(evaluate_fn)
if __name__ == '__main__':
    # Command-line options come from the parser shipped with the trainer.
    args = train.parser.parse_args()

    # Fixed seed for the CPU and CUDA random number generators.
    SEED = 2333
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)

    # Enable cuDNN's benchmark mode before any model is built.
    torch.backends.cudnn.benchmark = True

    # Launch training; the callback swaps in the custom evaluation routine
    # once the launcher has been constructed.
    run_kwargs = dict(
        config_path=args.config_path,
        model_dir=args.model_dir,
        cpu_mode=args.cpu,
        after_construct_launcher_callbacks=[register_evaluate_fn],
        opts=args.opts,
    )
    train.run(**run_kwargs)