-
Notifications
You must be signed in to change notification settings - Fork 1
/
model.py
139 lines (108 loc) · 5.91 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import torch as T
VGG = T.nn.Sequential(
T.nn.Conv2d(3, 3, 1),
T.nn.ReflectionPad2d(1), T.nn.Conv2d(3, 64, 3), T.nn.ReLU(),
T.nn.ReflectionPad2d(1), T.nn.Conv2d(64, 64, 3), T.nn.ReLU(),
T.nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
T.nn.ReflectionPad2d(1), T.nn.Conv2d(64, 128, 3), T.nn.ReLU(),
T.nn.ReflectionPad2d(1), T.nn.Conv2d(128, 128, 3), T.nn.ReLU(),
T.nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
T.nn.ReflectionPad2d(1), T.nn.Conv2d(128, 256, 3), T.nn.ReLU(),
T.nn.ReflectionPad2d(1), T.nn.Conv2d(256, 256, 3), T.nn.ReLU(),
T.nn.ReflectionPad2d(1), T.nn.Conv2d(256, 256, 3), T.nn.ReLU(),
T.nn.ReflectionPad2d(1), T.nn.Conv2d(256, 256, 3), T.nn.ReLU(),
T.nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
T.nn.ReflectionPad2d(1), T.nn.Conv2d(256, 512, 3), T.nn.ReLU(),
T.nn.ReflectionPad2d(1), T.nn.Conv2d(512, 512, 3), T.nn.ReLU(),
T.nn.ReflectionPad2d(1), T.nn.Conv2d(512, 512, 3), T.nn.ReLU(),
T.nn.ReflectionPad2d(1), T.nn.Conv2d(512, 512, 3), T.nn.ReLU(),
T.nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
T.nn.ReflectionPad2d(1), T.nn.Conv2d(512, 512, 3), T.nn.ReLU(),
T.nn.ReflectionPad2d(1), T.nn.Conv2d(512, 512, 3), T.nn.ReLU(),
T.nn.ReflectionPad2d(1), T.nn.Conv2d(512, 512, 3), T.nn.ReLU(),
T.nn.ReflectionPad2d(1), T.nn.Conv2d(512, 512, 3), T.nn.ReLU()
)
for p in VGG.parameters(): p.requires_grad = False
def _dec_stage(c_in, c_out, act=True):
    """Reflection-padded 3x3 conv, optionally followed by ReLU."""
    mods = [T.nn.ReflectionPad2d(1), T.nn.Conv2d(c_in, c_out, 3)]
    if act:
        mods.append(T.nn.ReLU())
    return mods

def _dec_up():
    """Nearest-neighbor 2x upsampling (undoes one encoder pooling step)."""
    return [T.nn.Upsample(scale_factor=2, mode='nearest')]

# Mirror-image decoder: maps 512-channel encoder features back to RGB.
# The final conv has no activation, so the output image is unconstrained.
DECODER = T.nn.Sequential(
    *(_dec_stage(512, 256) + _dec_up()
      + _dec_stage(256, 256) + _dec_stage(256, 256) + _dec_stage(256, 256)
      + _dec_stage(256, 128) + _dec_up()
      + _dec_stage(128, 128) + _dec_stage(128, 64) + _dec_up()
      + _dec_stage(64, 64) + _dec_stage(64, 3, act=False))
)
class SA(T.nn.Module):
    """Style-attention module.

    Computes attention between (instance-normalized) content and style
    features, gathers style values at every content position, and adds the
    projected result back onto the content features as a residual.
    """

    def norm(self, feat, eps=1e-5):
        """Instance-normalize `feat` per (batch, channel) over spatial dims.

        Uses the unbiased variance (torch default) plus `eps` for stability.
        """
        b, c = feat.shape[0], feat.shape[1]
        flat = feat.view(b, c, -1)
        mean = flat.mean(dim=2).view(b, c, 1, 1)
        std = (flat.var(dim=2) + eps).sqrt().view(b, c, 1, 1)
        # Broadcasting replaces the original explicit expand().
        return (feat - mean) / std

    def __init__(self, c):
        super().__init__()
        # 1x1 projections: f/g build the attention map, h carries style values.
        self.f, self.g, self.h = T.nn.Conv2d(c, c, 1), T.nn.Conv2d(c, c, 1), T.nn.Conv2d(c, c, 1)
        # Output projection applied before the residual addition.
        self.cnn = T.nn.Conv2d(c, c, 1)

    def forward(self, con, sty):
        """Return content features `con` enriched with attended style `sty`."""
        query = self.f(self.norm(con))
        key = self.g(self.norm(sty))
        value = self.h(sty)
        b = query.shape[0]
        hq, wq = query.shape[2], query.shape[3]
        hk, wk = key.shape[2], key.shape[3]
        # (b, Hq*Wq, Hk*Wk): similarity of every content position to every
        # style position, softmax-normalized over style positions.
        attn = T.matmul(query.view(b, -1, hq * wq).permute(0, 2, 1),
                        key.view(b, -1, hk * wk))
        attn = T.nn.functional.softmax(attn, dim=-1)
        # Gather style values per content position: (b, C, Hq*Wq).
        gathered = T.matmul(value.view(b, -1, hk * wk), attn.permute(0, 2, 1))
        # Residual connection back onto the content features.
        return con + self.cnn(gathered.view(con.shape))
class CLVA(T.nn.Module):
    """Style-transfer generator fusing content features with an instruction
    (or style-image) embedding at two VGG feature scales.

    Fix vs. original: removed the unused local `B` in `forward_cs`; the
    list-unpack assignment of `enc_s4`/`enc_s5` is split into plain
    assignments for clarity (same submodule registration order).
    """

    def __init__(self, c):
        """`c` is the channel width of both style-attention modules."""
        super().__init__()
        # Content encoders: VGG children [0:31] give the mid-level feature
        # slice, [31:44] continue one pooling stage deeper.
        self.enc_c4 = T.nn.Sequential(*list(VGG.children())[:31])
        self.enc_c5 = T.nn.Sequential(*list(VGG.children())[31:44])
        # Instruction encoders: a 512-dim vector is lifted to a 512x8x8
        # pseudo feature map (enc_s4), then refined/downsampled (enc_s5).
        self.enc_s4 = T.nn.Sequential(
            T.nn.Linear(512, 4096), T.nn.ReLU(),
            T.nn.Linear(4096, 512 * 8 * 8), T.nn.ReLU())
        self.enc_s5 = T.nn.Sequential(
            T.nn.Conv2d(512, 512, 3, padding=1), T.nn.ReLU(),
            T.nn.MaxPool2d(2),
            T.nn.Conv2d(512, 512, 3, padding=1), T.nn.ReLU())
        self.sa4, self.sa5 = SA(c), SA(c)
        self.pad, self.cnn = T.nn.ReflectionPad2d(1), T.nn.Conv2d(c, c, 3)
        # NOTE(review): DECODER is a module-level singleton, so every CLVA
        # instance shares its weights — confirm this is intended.
        self.dec = DECODER

    def fusion(self, c4, s4, c5, s5):
        """Fuse content/style features at both scales and decode to an image."""
        sa4, sa5 = self.sa4(c4, s4), self.sa5(c5, s5)
        # Upsample the coarse fusion to the fine scale and merge
        # ('nearest' here; the original comment suggests 'bicubic' was
        # considered as an alternative).
        sa = self.pad(sa4 + T.nn.functional.interpolate(
            sa5, size=[c4.shape[2], c4.shape[3]], mode='nearest'))
        sa = self.cnn(sa)
        out = self.dec(sa)
        return out

    def forward(self, con, ins):  # forward_cx
        """Stylize content image `con` with instruction embedding `ins`
        (assumed shape (B, 512) — TODO confirm against callers)."""
        B = con.shape[0]
        c4, s4 = self.enc_c4(con), self.enc_s4(ins).view([B, -1, 8, 8])
        c5, s5 = self.enc_c5(c4), self.enc_s5(s4)
        # Integer scale factor bringing both pseudo style maps up to
        # (roughly) the corresponding content feature resolution.
        F = min(c4.shape[-2] // s4.shape[-2], c5.shape[-1] // s5.shape[-1])
        s4, s5 = [T.nn.functional.interpolate(s4, scale_factor=F, mode='bicubic', align_corners=True),
                  T.nn.functional.interpolate(s5, scale_factor=F, mode='bicubic', align_corners=True)]
        out = self.fusion(c4, s4, c5, s5)
        return out

    def forward_cs(self, con, sty):
        """Stylize content image `con` with a style *image* `sty`, reusing
        the content encoders for the style branch."""
        c4, s4 = self.enc_c4(con), self.enc_c4(sty)
        c5, s5 = self.enc_c5(c4), self.enc_c5(s4)
        out = self.fusion(c4, s4, c5, s5)
        return out
class Discriminator(T.nn.Module):
    """Patch discriminator conditioned on an instruction embedding.

    Encodes an image patch with a frozen VGG slice, pools to a single
    512-dim vector, concatenates the embedding, and outputs a real/fake
    probability in (0, 1).

    Fix vs. original: `self.cnn(f).squeeze()` also collapsed the *batch*
    dimension when B == 1, making the subsequent `T.cat(..., dim=1)` fail;
    `flatten(1)` is identical for B > 1 and correct for B == 1.
    """

    def __init__(self):
        super().__init__()
        # Frozen VGG slice (512-channel features).
        self.enc = T.nn.Sequential(*list(VGG.children())[:31])
        self.cnn = T.nn.Sequential(
            T.nn.ReflectionPad2d(1), T.nn.Conv2d(512, 512, 3), T.nn.ReLU(),
            T.nn.ReflectionPad2d(1), T.nn.Conv2d(512, 512, 3), T.nn.ReLU(),
            T.nn.AdaptiveAvgPool2d(1))
        # 1024 = 512 pooled patch features + 512 embedding dims.
        # NOTE(review): assumes `ins` is (B, 512) — confirm against callers.
        self.fc = T.nn.Sequential(
            T.nn.Linear(1024, 1024), T.nn.ReLU(),
            T.nn.Linear(1024, 1), T.nn.Sigmoid())

    def forward(self, patch, ins):
        """Return a (B, 1) real/fake probability for (`patch`, `ins`)."""
        f = self.enc(patch)
        # (B, 512, 1, 1) -> (B, 512); keeps the batch dim even for B == 1.
        f = self.cnn(f).flatten(1)
        out = self.fc(T.cat([f, ins], dim=1))
        return out