-
Notifications
You must be signed in to change notification settings - Fork 2
/
eval_ShadowRemoval.py
146 lines (109 loc) · 5.18 KB
/
eval_ShadowRemoval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import os
import lpips
import numpy as np
import torchvision
import torchvision.transforms.functional as F
from PIL import Image
from imageio.v2 import imread
import skimage
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim
from skimage.color import rgb2lab
import scipy
# LPIPS perceptual metric with a VGG backbone. It is built once at import time and
# moved to the GPU with .cuda() (so a CUDA device is required); metric() reuses it.
loss_fn_vgg = lpips.LPIPS(net='vgg').cuda() # vgg is used in the paper
def load_item(gt_path, pre_path, mask_path):
    """Load a ground-truth / prediction (/ mask) triple as batched float tensors.

    The prediction and the mask are resampled to the ground-truth height and
    width with ``resize()``. The mask is then binarised at 90% of full
    intensity, so its values become 0 or 255.

    Args:
        gt_path: path of the ground-truth image.
        pre_path: path of the predicted image. If reading it raises an
            ``OSError`` (e.g. the file does not exist), the same path with
            '.JPG' replaced by '.png' is tried instead.
        mask_path: path of the shadow mask, or None to skip loading a mask.

    Returns:
        ``(gt, pre, mask)`` as produced by ``to_tensor()``; ``mask`` is None
        when ``mask_path`` is None.

    Raises:
        OSError: if neither prediction path (nor the GT / mask) can be read.
    """
    gt = imread(gt_path)
    try:
        pre = imread(pre_path)
    except OSError:
        # Only fall back for I/O failures (missing file etc.). A bare except
        # here used to hide decoding errors behind a confusing second
        # "file not found" for the '.png' path.
        pre = imread(pre_path.replace('.JPG', '.png'))
    mask = imread(mask_path) if mask_path is not None else None

    target_hw = (gt.shape[0], gt.shape[1])
    pre = resize(pre, target_hw)
    if mask is None:
        return to_tensor(gt), to_tensor(pre), None

    mask = resize(mask, target_hw)
    # Binarise: anything above 90% of full intensity counts as mask.
    mask = (mask > 255 * 0.9).astype(np.uint8) * 255
    return to_tensor(gt), to_tensor(pre), to_tensor(mask)
def to_tensor(img):
    """Turn an image array into a float tensor with a leading batch axis.

    The array is first wrapped as a PIL image so that torchvision's PIL
    conversion rules apply (e.g. 8-bit images are scaled to [0, 1]).
    ``F.to_tensor`` yields a (C, H, W) tensor; a batch axis is prepended,
    giving (1, C, H, W).
    """
    as_pil = Image.fromarray(img)
    chw = F.to_tensor(as_pil).float()
    return chw[None, ...]
def resize(img, target_size):
    """Resample ``img`` to ``target_size`` and return it as a uint8 array.

    Args:
        img: image array. skimage converts it to float following
            ``img_as_float`` conventions (e.g. uint8 is divided by 255)
            before interpolating, so the output is rescaled by 255 below.
        target_size: (height, width) of the result; trailing channel axes
            of ``img`` are preserved by skimage.

    Returns:
        uint8 array with values in [0, 255].
    """
    # Import the submodule explicitly; a bare ``import skimage`` only exposes
    # ``skimage.transform`` on versions that lazily load submodules.
    from skimage.transform import resize as _sk_resize

    resampled = _sk_resize(img, target_size, mode='reflect', anti_aliasing=True)
    # Round to the nearest level instead of truncating: a plain astype()
    # floors, biasing interpolated pixels down by ~0.5 and possibly turning
    # an exact level v into v - 1 after the float round trip.
    return np.clip(np.rint(resampled * 255), 0, 255).astype(np.uint8)
def calc_rmse(real_img, fake_img):
    """Return the root-mean-square difference of two RGB images in CIELAB space.

    Both images are converted with ``rgb2lab``; the squared difference is
    averaged over every pixel and every L/a/b channel before the square root.
    """
    lab_diff = rgb2lab(real_img) - rgb2lab(fake_img)
    return np.sqrt(np.mean(np.square(lab_diff)))
def _to_uint8_hwc(t):
    """Convert the first image of an (N, C, H, W) [0, 1] tensor to an (H, W, C) uint8 array."""
    arr = t[0].detach().cpu().permute(1, 2, 0).numpy()
    # Round rather than truncate: in float32, (v / 255) * 255 is not always
    # exactly v, and truncation would then yield v - 1.
    return np.clip(np.rint(arr * 255.0), 0, 255).astype(np.uint8)


def metric(gt, pre):
    """Compute whole-image PSNR, SSIM, LPIPS and LAB-space RMSE.

    Args:
        gt: ground-truth tensor laid out as (1, C, H, W), as produced by
            ``to_tensor()``; presumably values in [0, 1] — confirm for
            non-8-bit inputs. Only the first batch element is used.
        pre: prediction tensor with the same layout as ``gt``.

    Returns:
        ``(psnr, ssim, lpips_value, rmse)``. PSNR, SSIM and RMSE are computed
        on 8-bit images; LPIPS is computed on the float tensors.
    """
    # Normalize(mean=0.5, std=0.5) maps [0, 1] values to [-1, 1] before LPIPS.
    transf = torchvision.transforms.Compose(
        [torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
    lpips_value = loss_fn_vgg(transf(pre[0]).cuda(), transf(gt[0]).cuda()).item()

    pre_u8 = _to_uint8_hwc(pre)
    gt_u8 = _to_uint8_hwc(gt)
    psnr = compare_psnr(gt_u8, pre_u8, data_range=255)
    ssim = compare_ssim(gt_u8, pre_u8, data_range=255, channel_axis=-1)
    rmse = calc_rmse(gt_u8, pre_u8)
    return psnr, ssim, lpips_value, rmse
def _pred_path(pre_root, fname):
    """Return the prediction path in ``pre_root`` for ground-truth file ``fname``.

    For '.jpg' / '.png' names the '.png' variant is preferred, falling back to
    the '.jpg' variant when no '.png' file exists. Any other name is used as is.
    """
    stem, ext = os.path.splitext(fname)
    if ext not in ('.jpg', '.png'):
        return os.path.join(pre_root, fname)
    png_path = os.path.join(pre_root, stem + '.png')
    if os.path.exists(png_path):
        return png_path
    return os.path.join(pre_root, stem + '.jpg')


def _mask_path(mask_root, fname):
    """Return the mask path for ``fname`` (a '.jpg' extension becomes '.png'), or None without masks."""
    if mask_root is None:
        return None
    stem, ext = os.path.splitext(fname)
    # Masks are stored as PNG even when the images are JPEG (e.g. for SRD).
    if ext == '.jpg':
        fname = stem + '.png'
    return os.path.join(mask_root, fname)


def evaluation(gt_root, pre_root, mask_root):
    """Score every image in ``gt_root`` against its prediction and print the averages.

    Only the file extension is swapped when locating predictions and masks
    (see ``_pred_path`` / ``_mask_path``); directory names are never rewritten.

    Args:
        gt_root: directory of ground-truth images; every entry is evaluated
            in sorted order.
        pre_root: directory of predicted images.
        mask_root: directory of shadow masks, or None. Masks are loaded, so a
            missing mask still raises, but only whole-image metrics are computed.

    Returns:
        dict with the average 'psnr', 'ssim', 'lpips' and 'rmse' (the values
        that are printed).
    """
    psnr_list, ssim_list, lpips_list, rmse_list = [], [], [], []
    for fname in sorted(os.listdir(gt_root)):
        gt_path = os.path.join(gt_root, fname)
        gt, pre, _mask = load_item(gt_path, _pred_path(pre_root, fname), _mask_path(mask_root, fname))
        psnr, ssim, lpips_value, rmse = metric(gt, pre)
        psnr_list.append(psnr)
        ssim_list.append(ssim)
        lpips_list.append(lpips_value)
        rmse_list.append(rmse)

    averages = {
        'psnr': np.average(psnr_list),
        'ssim': np.average(ssim_list),
        'lpips': np.average(lpips_list),
        'rmse': np.average(rmse_list),
    }
    print('-----------------------------------------------------------------------------')
    print(f"All psnr: {round(averages['psnr'], 4)} ssim: {round(averages['ssim'], 4)} lpips: {round(averages['lpips'], 4)} rmse: {round(averages['rmse'], 4)}")
    return averages
########## Set the paths for evaluation ##########
# gt_root:    ground-truth (shadow-free) image directory
# pred_root:  directory of predicted images to evaluate
# mask_root:  ground-truth shadow-mask directory, or None when the dataset has no masks
# input_root: shadow input image directory. It is not used by the evaluation below;
#             to get the metrics of the unprocessed inputs, pass it in place of pred_root.
#
# Each path can also be overridden without editing this file through the environment
# variables EVAL_GT_ROOT, EVAL_PRED_ROOT, EVAL_MASK_ROOT and EVAL_INPUT_ROOT.
# Setting EVAL_MASK_ROOT to an empty string means mask_root = None.
########## General Shadow Removal Evaluation ##########
##### Example paths for ISTD+ dataset #####
# test_B_GT_NoSDDNet holds the ground-truth shadow masks, not masks generated by SDDNet.
mask_root = os.environ.get('EVAL_MASK_ROOT', '/home/zhxing/Datasets/ISTD+/test/test_B_GT_NoSDDNet') or None
gt_root = os.environ.get('EVAL_GT_ROOT', '/home/zhxing/Datasets/ISTD+/test/test_C')
input_root = os.environ.get('EVAL_INPUT_ROOT', '/home/zhxing/Datasets/ISTD+/test/test_A')
pred_root = os.environ.get('EVAL_PRED_ROOT', '/home/zhxing/Projects/ShadowSurvey/ShadowRemoval/Auto/ISTD+512')
##### Example paths for SRD dataset #####
# mask_root = '/home/zhxing/Datasets/SRD_inpaint4shadow_fix/test/test_B_GT_NoSDDNet' # test_B_GT_NoSDDNet indicates the ground truth shadow mask here, not the one generated by SDDNet
# gt_root = '/home/zhxing/Datasets/SRD_inpaint4shadow_fix/test/test_C'
# input_root = '/home/zhxing/Datasets/SRD_inpaint4shadow_fix/test/test_A'
# pred_root = '/home/zhxing/Projects/ShadowSurvey/ShadowRemoval/Auto/SRD512'
########## Document Shadow Removal Evaluation ##########
##### Example paths for RDD dataset #####
# mask_root = None # There is no mask ground truth for document shadow removal dataset
# gt_root = '/home/zhxing/Datasets/RDD_data/test/gt'
# pred_root = '/home/zhxing/Projects/ShadowSurvey/DocShadowRemoval/BEDSR-Net'
# Start evaluation
evaluation(gt_root, pred_root, mask_root)