This repository has been archived by the owner on Oct 1, 2020. It is now read-only.
forked from xinntao/ESRGAN
-
Notifications
You must be signed in to change notification settings - Fork 7
/
test.py
109 lines (96 loc) · 3.59 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import sys
import os.path
import glob
import cv2
import numpy as np
import torch
import architecture as arch
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('model')
parser.add_argument('--input', default='LR', help='Input folder')
parser.add_argument('--output', default='results', help='Output folder')
parser.add_argument('--cpu', action='store_true', help='Use CPU instead of CUDA')
args = parser.parse_args()
if not os.path.exists(args.model):
print('Error: Model [{:s}] does not exist.'.format(args.model))
sys.exit(1)
elif not os.path.exists(args.input):
print('Error: Folder [{:s}] does not exist.'.format(args.input))
sys.exit(1)
elif os.path.isfile(args.input):
print('Error: Folder [{:s}] is a file.'.format(args.input))
sys.exit(1)
elif os.path.isfile(args.output):
print('Error: Folder [{:s}] is a file.'.format(args.output))
sys.exit(1)
elif not os.path.exists(args.output):
os.mkdir(args.output)
model_path = args.model
device = torch.device('cpu' if args.cpu else 'cuda')
test_img_folder = os.path.join(os.path.normpath(args.input), '*')
output_folder = os.path.normpath(args.output)
state_dict = torch.load(model_path)
if 'conv_first.weight' in state_dict:
print('Error: Attempted to load a new-format model')
sys.exit(1)
# extract model information
scale2 = 0
max_part = 0
for part in list(state_dict):
parts = part.split('.')
n_parts = len(parts)
if n_parts == 5 and parts[2] == 'sub':
nb = int(parts[3])
elif n_parts == 3:
part_num = int(parts[1])
if part_num > 6 and parts[2] == 'weight':
scale2 += 1
if part_num > max_part:
max_part = part_num
out_nc = state_dict[part].shape[0]
upscale = 2 ** scale2
in_nc = state_dict['model.0.weight'].shape[1]
nf = state_dict['model.0.weight'].shape[0]
model = arch.RRDB_Net(in_nc, out_nc, nf, nb, gc=32, upscale=upscale, norm_type=None, act_type='leakyrelu', \
mode='CNA', res_scale=1, upsample_mode='upconv')
model.load_state_dict(state_dict, strict=True)
del state_dict
model.eval()
for k, v in model.named_parameters():
v.requires_grad = False
model = model.to(device)
print('Model path {:s}. \nTesting...'.format(model_path))
idx = 0
for path in glob.glob(test_img_folder):
if os.path.isdir(path): # skip directories
continue
idx += 1
base = os.path.splitext(os.path.basename(path))[0]
print(idx, base)
# read image
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
img = img * 1. / np.iinfo(img.dtype).max
if img.ndim == 2:
img = np.tile(np.expand_dims(img, axis=2), (1, 1, min(in_nc, 3)))
if img.shape[2] > in_nc: # remove extra channels
if in_nc != 3 or img.shape[2] != 4 or img[:, :, 3].min() < 1:
print('Warning: Truncating image channels')
img = img[:, :, :in_nc]
elif img.shape[2] == 3 and in_nc == 4: # pad with solid alpha channel
img = np.dstack((img, np.full(img.shape[:-1], 1.)))
if img.shape[2] == 3:
img = img[:, :, [2, 1, 0]]
elif img.shape[2] == 4:
img = img[:, :, [2, 1, 0, 3]]
img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
img_LR = img.unsqueeze(0)
img_LR = img_LR.to(device)
output = model(img_LR).data.squeeze(0).float().cpu().clamp_(0, 1).numpy()
if output.shape[0] == 3:
output = output[[2, 1, 0], :, :]
elif output.shape[0] == 4:
output = output[[2, 1, 0, 3], :, :]
output = np.transpose(output, (1, 2, 0))
output = (output * 255.0).round()
cv2.imwrite(os.path.join(output_folder, '{:s}_rlt.png'.format(base)), output)