-
Notifications
You must be signed in to change notification settings - Fork 8
/
tools.py
109 lines (90 loc) · 2.88 KB
/
tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import random
from tracemalloc import Snapshot
import cv2
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import affine
from torchvision.utils import draw_segmentation_masks
from PIL import Image
import tqdm
def get_featuremap(h, x):
    """Convolve ``x`` with the axis-swapped kernel of layer ``h``, one group per input channel of ``h``.

    The weight of ``h`` has its first two axes swapped and is then used as a
    grouped 2-D convolution kernel. There are ``h.weight.shape[1]`` groups and
    a fixed padding of 1 on every spatial side.

    Args:
        h: a layer exposing ``weight`` and ``bias`` attributes -- presumably
            an ``nn.Conv2d``; TODO confirm against callers.
        x: input tensor. Under ``F.conv2d``'s grouping rules its channel count
            must equal ``h.weight.shape[1] * h.weight.shape[0]``.

    Returns:
        ``(features, bias)``. ``bias`` is ``h.bias`` returned unmodified; it
        is NOT added to ``features``.
    """
    kernel = h.weight.transpose(0, 1)
    n_groups = h.weight.shape[1]
    # padding=1 keeps the spatial size only for 3x3 kernels.
    # NOTE(review): confirm h always has a 3x3 kernel.
    features = F.conv2d(x, kernel, padding=(1, 1), groups=n_groups)
    return features, h.bias
def ToLabel(E):
    """Collapse per-class scores into an integer label map by taking the argmax over axis 1.

    Args:
        E: array-like of scores. Axis 1 is reduced -- presumably the class /
            channel axis of an ``(N, C, H, W)`` array (TODO confirm with callers).

    Returns:
        ``np.uint8`` array of argmax indices with axis 1 removed. Labels are
        only representable when axis 1 has at most 256 entries.
    """
    # Cast the integer argmax straight to uint8. The former detour through
    # float32 added nothing and made out-of-range values an unsafe float cast.
    return np.argmax(E, axis=1).astype(np.uint8)
def SSIM(x, y):
    """Per-pixel structural *dissimilarity* map between ``x`` and ``y``.

    Local means, variances and covariance are estimated with a 3x3 average
    pool (stride 1, zero padding 1). The resulting SSIM index is mapped to
    ``(1 - SSIM) / 2`` and clamped to ``[0, 1]``, so 0 means identical and
    larger values mean more different.

    The stabilising constants ``0.01**2`` and ``0.03**2`` are the usual
    ``(K * L)**2`` with ``L = 1``. Presumably inputs are scaled to ``[0, 1]``;
    confirm against callers.

    Args:
        x, y: tensors of the same shape accepted by ``nn.AvgPool2d``
            (``(C, H, W)`` or ``(N, C, H, W)``).

    Returns:
        Tensor with the same shape as the inputs.
    """
    c1 = 0.01 ** 2
    c2 = 0.03 ** 2
    # One stateless pooling module serves every local-statistics estimate.
    pool = nn.AvgPool2d(3, 1, 1)

    mean_x = pool(x)
    mean_y = pool(y)
    mean_xy = mean_x * mean_y
    mean_x2 = mean_x.pow(2)
    mean_y2 = mean_y.pow(2)

    var_x = pool(x * x) - mean_x2
    var_y = pool(y * y) - mean_y2
    cov_xy = pool(x * y) - mean_xy

    numerator = (2 * mean_xy + c1) * (2 * cov_xy + c2)
    denominator = (mean_x2 + mean_y2 + c1) * (var_x + var_y + c2)
    ssim_map = numerator / denominator
    return torch.clamp((1 - ssim_map) / 2, 0, 1)
def SaliencyStructureConsistency(x, y, alpha):
    """Weighted blend of mean structural dissimilarity and mean L1 distance.

    Computes ``alpha * mean(SSIM(x, y)) + (1 - alpha) * mean(|x - y|)``,
    where ``SSIM`` is the dissimilarity map defined in this module.

    Args:
        x, y: tensors of the same shape.
        alpha: weight of the structural term; ``1 - alpha`` weights the L1 term.

    Returns:
        Scalar tensor.
    """
    structural_term = SSIM(x, y).mean()
    l1_term = (x - y).abs().mean()
    return alpha * structural_term + (1 - alpha) * l1_term
def SaliencyStructureConsistencynossim(x, y):
    """L1-only consistency loss: the mean absolute difference between ``x`` and ``y``.

    Returns:
        Scalar tensor.
    """
    return (x - y).abs().mean()
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to ``random``, ``numpy.random``, the torch
            default generator and every CUDA device generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # deterministic=True restricts cuDNN to deterministic kernels. With
    # benchmark=True, however, cuDNN still autotunes and may pick different
    # algorithms per run. PyTorch's reproducibility notes say to disable both.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
class Flip:
    """Mirror a tensor along one of its last two axes.

    ``flip == 0`` reverses the last axis (``dim=-1``). Any other value
    reverses the second-to-last axis (``dim=-2``).
    """

    def __init__(self, flip):
        # 0 selects the last axis; anything else selects the one before it.
        self.flip = flip

    def __call__(self, img):
        axis = -1 if self.flip == 0 else -2
        return img.flip(axis)
class Translate:
    """Shift an image by a fixed fraction of its size in a random diagonal direction.

    The signs of the horizontal and vertical shift are drawn once, at
    construction. Every image passed to the same instance therefore gets the
    same direction of shift.
    """

    def __init__(self, fct):
        """
        Args:
            fct: translation offset as a fraction of the image size. The
                horizontal shift is ``int(width * fct)`` pixels and the
                vertical shift is ``int(height * fct)`` pixels.
        """
        drct = np.random.randint(0, 4)
        # The four drct values cover all four (sign_x, sign_y) combinations.
        # Use explicit ints rather than the former `cond or -1`, which
        # produced the bool True instead of 1.
        self.signed_x = 1 if drct >= 2 else -1
        self.signed_y = 1 if drct % 2 else -1
        self.fct = fct

    def _offsets(self, img):
        """Return the ``[horizontal, vertical]`` pixel shift for ``img``.

        The last two axes of ``img`` are treated as (height, width).
        """
        h, w = img.shape[-2:]
        return [int(w * self.fct) * self.signed_x, int(h * self.fct) * self.signed_y]

    def __call__(self, img):
        # torchvision's affine expects translate=[tx, ty] (horizontal first).
        # The previous code passed (height-based, width-based), so non-square
        # images had their horizontal shift sized from the height and vice versa.
        angle = 0
        scale = 1
        return affine(img, angle, self._offsets(img), scale, shear=0, interpolation=InterpolationMode.BILINEAR)
class Crop:
    """Crop a window of fixed relative size and resize it back to the input size.

    The window's relative position is drawn once, at construction. Every
    image passed to the same instance is cropped at the same relative
    location.
    """

    def __init__(self, H, W):
        """
        Args:
            H: crop height as a fraction of the input height (assumed <= 1).
            W: crop width as a fraction of the input width (assumed <= 1).
        """
        self.h = H
        self.w = W
        # Relative offsets in [0, 1), scaled to the free margin in __call__.
        self.xm = np.random.uniform()
        self.ym = np.random.uniform()

    def __call__(self, img):
        # Bilinear F.interpolate needs a 4-D input, so img is expected to be
        # (N, C, H, W).
        full_h, full_w = img.shape[-2:]
        crop_h = int(self.h * full_h)
        crop_w = int(self.w * full_w)
        # (margin + 1) * u with u in [0, 1) floors to an offset in [0, margin].
        top = int((full_h - crop_h + 1) * self.ym)
        left = int((full_w - crop_w + 1) * self.xm)
        window = img[..., top:top + crop_h, left:left + crop_w]
        return F.interpolate(window, size=(full_h, full_w), mode='bilinear', align_corners=False)