-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate_psfrgan.py
72 lines (58 loc) · 2 KB
/
evaluate_psfrgan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
from tqdm import tqdm
from torchmetrics import PeakSignalNoiseRatio
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics.functional import multiscale_structural_similarity_index_measure, \
structural_similarity_index_measure
from util.evaluate_config import EvaluateConfig
from util.create import create_dataset
def count_batches(dataset_size: int, batch_size: int) -> int:
    """Return the number of batches needed to cover ``dataset_size`` samples.

    The quotient is rounded up so that a final, partially filled batch is
    counted as well.

    Args:
        dataset_size: Total number of samples (must not be negative).
        batch_size: Number of samples per batch (must be positive).

    Returns:
        The batch count, ``0`` when ``dataset_size`` is ``0``.

    Raises:
        ValueError: If ``batch_size`` is not positive or ``dataset_size`` is
            negative.
    """
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}.')
    if dataset_size < 0:
        raise ValueError(f'dataset_size must not be negative, got {dataset_size}.')
    return (dataset_size + batch_size - 1) // batch_size


def main() -> None:
    """Compute and print PSNR, LPIPS, SSIM and MS-SSIM for the configured dataset.

    The configuration is read from ``./config/evaluate.json``. Every item
    yielded by the dataset must provide the tensors ``'sr'`` and ``'hr'``,
    which are compared against each other.

    Raises:
        RuntimeError: If GPUs are requested in the config but CUDA is not
            available, or if the dataset yields no samples.
    """
    config = EvaluateConfig(filename='./config/evaluate.json')
    dataset = create_dataset(config)
    dataset_size = len(dataset)

    # Value range handed to PSNR, SSIM and MS-SSIM.
    # NOTE(review): 2 presumably corresponds to images in [-1, 1] - confirm
    # against the dataset's normalisation.
    data_range = 2

    psnr = PeakSignalNoiseRatio(compute_on_cpu=True, data_range=data_range)
    # LPIPS has no `data_range` parameter (its input scaling is controlled by
    # `normalize`), so the argument is not forwarded here.
    lpips = LearnedPerceptualImagePatchSimilarity(
        net_type='vgg',
        compute_on_cpu=True,
    )

    if len(config.gpu_ids) > 0:
        # Explicit raise instead of `assert`, which is stripped under `-O`.
        if not torch.cuda.is_available():
            raise RuntimeError(
                'GPU ids are configured but CUDA is not available.'
            )
        psnr = psnr.to(config.device)
        lpips = lpips.to(config.device)

    ssim_total = 0
    ms_ssim_total = 0
    num_samples = 0

    # NOTE(review): assumes len(dataset) is the number of samples and that
    # iterating the dataset yields batches of config.batch_size - confirm
    # against util.create.create_dataset.
    num_batch = count_batches(dataset_size, config.batch_size)
    print(f'Total batch: {num_batch}.')

    # Pure evaluation: no autograd bookkeeping is needed.
    with torch.no_grad():
        for i, data in tqdm(enumerate(dataset), total=num_batch):
            print(f'[{i}/{num_batch}] Calculate metrics.')
            sr = data['sr'].to(config.device)
            hr = data['hr'].to(config.device)

            # SSIM / MS-SSIM take batched image tensors, so the first
            # dimension is the number of samples in this batch.
            num_samples += sr.shape[0]

            # `update` only accumulates state; the per-batch value that
            # calling the metric would additionally compute is not needed.
            psnr.update(sr, hr)
            lpips.update(sr, hr)

            # reduction='sum' lets the per-batch results be added up and
            # averaged once at the end.
            ssim_total += structural_similarity_index_measure(
                sr,
                hr,
                data_range=data_range,
                reduction='sum',
            )
            ms_ssim_total += multiscale_structural_similarity_index_measure(
                sr,
                hr,
                data_range=data_range,
                reduction='sum',
            )

    if num_samples == 0:
        raise RuntimeError('The dataset yielded no samples; nothing to evaluate.')

    print('PSNR:', psnr.compute())
    print('LPIPS:', lpips.compute())
    # Average over the samples actually processed, which stays correct even
    # if the loader skips or truncates part of the dataset.
    print('SSIM:', ssim_total / num_samples)
    print('MS-SSIM:', ms_ssim_total / num_samples)


if __name__ == '__main__':
    main()