-
Notifications
You must be signed in to change notification settings - Fork 2
/
eval.py
128 lines (107 loc) · 4.63 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from __future__ import print_function
from train import parse_args
from utils.distributed import synchronize, get_rank, make_data_sampler, make_batch_data_sampler
from utils.logger import setup_logger
from utils.visualize import get_color_pallete
from utils.score import SegmentationMetric
from models import get_segmentation_model
from dataloader.cityscapes import CitySegmentation
from torchvision import transforms
import torch.backends.cudnn as cudnn
import torch.utils.data as data
import torch.nn as nn
import torch
import os
import sys
# Make the repository root importable so the `utils`, `models`, and
# `dataloader` packages resolve when this script is run directly.
# NOTE(review): this runs AFTER the project imports above, so it only helps
# if the script is launched from a directory where those imports already
# resolve — confirm intended ordering.
cur_path = os.path.abspath(os.path.dirname(__file__))
root_path = os.path.split(cur_path)[0]
sys.path.append(root_path)
class Evaluator(object):
    """Validation driver for a semantic-segmentation model on Cityscapes.

    Builds the validation dataloader, restores a checkpoint into the model,
    and computes running pixAcc/mIoU with SegmentationMetric, optionally
    saving colorized prediction masks.

    Relies on the module-level globals ``logger`` and ``outdir`` created in
    the ``__main__`` block of this file.
    """

    def __init__(self, args):
        # args: namespace from train.parse_args(); this code reads
        # device, data_path, distributed, workers, model, local_rank,
        # save_pred, and (in eval()) dataset.
        self.args = args
        # Force checkpoint restoration below regardless of the CLI value.
        self.args.pretrained = True
        self.device = torch.device(args.device)
        # image transform (ImageNet mean/std normalization)
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        # dataset and dataloader: one image per batch, distributed-aware sampler
        val_dataset = CitySegmentation(
            args.data_path, split='val', mode='testval', transform=input_transform)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(
            val_sampler, images_per_batch=1)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=args.workers,
                                          pin_memory=True)
        # create network
        # NOTE(review): BatchNorm2d is computed but never passed to the model
        # factory below — confirm whether get_segmentation_model should
        # receive it for SyncBatchNorm to take effect in distributed runs.
        BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
        self.model = get_segmentation_model(
            model=args.model, pretrained=False).to(self.device)
        if self.args.pretrained:
            # Hard-coded checkpoint path; map_location keeps CPU-only runs
            # from trying to deserialize CUDA tensors.
            self.model.load_state_dict(torch.load(
                "./trained_models/ddrnet_23_dualresnet_citys_best_model.pth",
                map_location=self.args.device))
            logger.info("Model restored successfully!!!!")
        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(self.model,
                                                             device_ids=[args.local_rank], output_device=args.local_rank)
        self.model.to(self.device)
        self.metric = SegmentationMetric(val_dataset.num_class)

    def eval(self):
        """Run one full pass over the validation set.

        Logs cumulative pixAcc/mIoU after every sample and, when
        ``args.save_pred`` is set, writes a color-palette PNG per image into
        the module-level ``outdir``.
        """
        self.metric.reset()
        self.model.eval()
        # Unwrap the DDP container so forward() runs without gradient sync.
        if self.args.distributed:
            model = self.model.module
        else:
            model = self.model
        logger.info("Start validation, Total sample: {:d}".format(
            len(self.val_loader)))
        for i, (image, target, filename) in enumerate(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)
            with torch.no_grad():
                outputs = model(image)
            # outputs[0]: assumes the model returns a tuple/list whose first
            # element is the main logits — TODO confirm against the model.
            self.metric.update(outputs[0], target)
            pixAcc, mIoU = self.metric.get()
            logger.info("Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}".format(
                i + 1, pixAcc * 100, mIoU * 100))
            if self.args.save_pred:
                # argmax over the class channel -> HxW label map on CPU.
                pred = torch.argmax(outputs[0], 1)
                pred = pred.cpu().data.numpy()
                predict = pred.squeeze(0)
                mask = get_color_pallete(predict, self.args.dataset)
                # `outdir` is a global created in __main__ when save_pred is set.
                mask.save(os.path.join(
                    outdir, os.path.splitext(filename[0])[0] + '.png'))
        # NOTE(review): pixAcc/mIoU are loop locals — this line raises
        # NameError if the val loader is empty; confirm that cannot happen.
        logger.info("Whole validation set mIoU: {:.3f}".format(mIoU * 100))
        synchronize()
if __name__ == '__main__':
    args = parse_args()
    # WORLD_SIZE is injected by torch.distributed launchers; default to a
    # single process when it is absent.
    num_gpus = int(os.environ.get("WORLD_SIZE", "1"))
    args.distributed = num_gpus > 1
    if not args.no_cuda and torch.cuda.is_available():
        cudnn.benchmark = True
        args.device = "cuda"
    else:
        # Without CUDA the NCCL backend cannot run: force single-process CPU.
        args.distributed = False
        args.device = "cpu"
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://")
        synchronize()
    # TODO: optim code
    # Always save colorized predictions (overrides any CLI setting).
    args.save_pred = True
    if args.save_pred:
        outdir = 'runs/pred_pic/{}_{}_{}'.format(
            args.model, args.backbone, args.dataset)
        # exist_ok avoids the check-then-create race when several
        # distributed ranks reach this point simultaneously.
        os.makedirs(outdir, exist_ok=True)
    logger = setup_logger("semantic_segmentation", args.log_dir, get_rank(),
                          filename='{}_{}_{}_log.txt'.format(args.model, args.backbone, args.dataset), mode='a+')
    evaluator = Evaluator(args)
    evaluator.eval()
    torch.cuda.empty_cache()