-
Notifications
You must be signed in to change notification settings - Fork 27
/
eval.py
107 lines (92 loc) · 4.03 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#!/usr/bin/env python3
# encoding: utf-8
import os
import cv2
import argparse
import numpy as np
import torch
import torch.multiprocessing as mp
from config import config
from utils.pyt_utils import ensure_dir, link_file, load_model, parse_devices
from utils.visualize import print_iou, show_img
from engine.evaluator import Evaluator
from engine.logger import get_logger
from seg_opr.metric import hist_info, compute_score
from datasets.voc import VOC
from model.deeperlab import deeperlab
# Module-level logger; SegEvaluator uses it to report saved prediction images.
logger = get_logger()
class SegEvaluator(Evaluator):
    """Semantic-segmentation evaluator built on the project ``Evaluator``.

    ``func_per_iteration`` runs sliding-window inference on one sample and
    returns its partial confusion statistics (via ``hist_info``), optionally
    saving and/or displaying the prediction. ``compute_metric`` reduces the
    collected per-sample statistics into IoU / pixel-accuracy scores.
    """

    def func_per_iteration(self, data, device):
        """Evaluate a single sample and return its partial statistics.

        Args:
            data: sample dict with keys ``'data'`` (input image), ``'label'``
                (ground-truth map) and ``'fn'`` (base file name, without
                extension).
            device: device passed through unchanged to ``self.sliding_eval``.

        Returns:
            dict with keys ``'hist'``, ``'labeled'`` and ``'correct'``, holding
            the three values returned by ``hist_info`` for this sample.
        """
        img = data['data']
        label = data['label']
        name = data['fn']

        # Sliding-window inference using the configured crop size and stride
        # rate. NOTE(review): ``pred`` is presumably a per-pixel class map;
        # confirm its shape/dtype against ``Evaluator.sliding_eval``.
        pred = self.sliding_eval(img, config.eval_crop_size,
                                 config.eval_stride_rate, device)

        hist_tmp, labeled_tmp, correct_tmp = hist_info(config.num_classes,
                                                       pred, label)
        results_dict = {'hist': hist_tmp, 'labeled': labeled_tmp,
                        'correct': correct_tmp}

        if self.save_path is not None:
            fn = name + '.png'
            cv2.imwrite(os.path.join(self.save_path, fn), pred)
            logger.info('Save the image ' + fn)

        if self.show_image:
            colors = self.dataset.get_class_colors
            # ``get_class_colors`` may be a method rather than a precomputed
            # palette; call it in that case so ``show_img`` receives the colors
            # themselves instead of a bound method object.
            if callable(colors):
                colors = colors()
            clean = np.zeros(label.shape)
            comp_img = show_img(colors, config.background, img, clean,
                                label, pred)
            cv2.imshow('comp_image', comp_img)
            cv2.waitKey(0)

        return results_dict

    def compute_metric(self, results):
        """Aggregate per-sample statistics into segmentation scores.

        Args:
            results: iterable of dicts as returned by ``func_per_iteration``.

        Returns:
            tuple ``(result_line, mean_IU, mean_IU_no_back, mean_pixel_acc)``
            where ``result_line`` is the report produced by ``print_iou`` and
            the other three values come from ``compute_score``.
        """
        hist = np.zeros((config.num_classes, config.num_classes))
        correct = 0
        labeled = 0
        for d in results:
            hist += d['hist']
            correct += d['correct']
            labeled += d['labeled']

        iu, mean_IU, mean_IU_no_back, mean_pixel_acc = compute_score(
            hist, correct, labeled)
        # Use the evaluator's own dataset: the module-level ``dataset`` name
        # only exists when this file is executed as a script.
        result_line = print_iou(iu, mean_pixel_acc,
                                self.dataset.get_class_names(), True)
        return result_line, mean_IU, mean_IU_no_back, mean_pixel_acc
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-e', '--epochs', default='last', type=str)
    # NOTE(review): the default is the *string* 'None', not None, and it is
    # forwarded unchanged to SegEvaluator; confirm Evaluator expects that form.
    parser.add_argument('-c', '--csv_root', default='None', type=str)
    # Device specification for evaluation (which GPUs to use); its exact
    # format is whatever ``parse_devices`` accepts.
    parser.add_argument('-d', '--devices', default='0', type=str)
    parser.add_argument('-v', '--verbose', default=False, action='store_true')
    parser.add_argument('--show_image', '-s', default=False,
                        action='store_true')
    parser.add_argument('--save_path', '-p', default=None)
    args = parser.parse_args()
    all_dev = parse_devices(args.devices)

    network = deeperlab(3, config.num_classes, None, None, None)
    data_setting = {'img_root': config.img_root_folder,
                    'gt_root': config.gt_root_folder,
                    'train_source': config.train_source,
                    'eval_source': config.eval_source}
    # Kept at module level: code elsewhere in this file may read ``dataset``.
    dataset = VOC(data_setting, 'val', None)
    # Substitute the requested epoch for the literal "last" in the log path.
    link_val_log = config.link_val_log_file.replace("last", args.epochs)

    # Inference only: disable autograd for the whole evaluation run.
    with torch.set_grad_enabled(False):
        segmentor = SegEvaluator(dataset, config.num_classes, config.image_mean,
                                 config.image_std, network,
                                 config.eval_scale_array, config.eval_flip,
                                 all_dev, args.csv_root, args.verbose,
                                 args.save_path, args.show_image)
        segmentor.run(config.snapshot_dir, args.epochs, config.val_log_file,
                      link_val_log)