-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathOFGLoss.py
83 lines (68 loc) · 2.35 KB
/
OFGLoss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import torch.nn as nn
import torch.optim as optim

import utils.losses as losses
import utils.utils as utils


class DeformationOptimizer(nn.Module):
    """
    Trainable displacement field plus the spatial transformer that applies it.

    OFGLoss builds one of these per call and runs a few optimizer steps on
    ``self.flow`` to obtain a refined flow used as pseudo ground truth.
    """

    def __init__(self, img_size, initial_flow, mode='bilinear'):
        """
        Args:
            img_size (tuple): spatial size handed to ``utils.SpatialTransformer``.
            initial_flow (torch.Tensor): starting displacement field; a copy of
                it becomes the trainable parameter ``self.flow``.
            mode (str): interpolation mode handed to ``utils.SpatialTransformer``.
        """
        super().__init__()
        self.img_size, self.mode = img_size, mode
        # Copy first so optimizer steps on the parameter never write into the
        # caller's tensor; nn.Parameter turns the copy into a leaf tensor.
        self.flow = nn.Parameter(initial_flow.clone())
        self.spatial_trans = utils.SpatialTransformer(img_size, mode)

    def forward(self, x):
        """
        Warp the moving image with the current flow.

        Args:
            x (torch.Tensor): moving image.

        Returns:
            tuple: ``(warped image, self.flow)`` — the flow is the parameter
            object itself, not a copy.
        """
        flow = self.flow
        warped = self.spatial_trans(x, flow)
        return warped, flow


class OFGLoss(nn.Module):
    """
    OFG loss function.

    On each call a copy of ``initial_flow`` is refined for ``iter_count`` Adam
    steps on ``NCC_vxm(warped x, y) + reg_weight * Grad3d(flow, y)``. The loss
    is the MSE between ``initial_flow`` and that refined flow. The refined flow
    is treated as a constant target (pseudo ground truth), so gradients reach
    the caller only through ``initial_flow``.
    """

    def __init__(self, iter_count=5, reg_weight=1, lr=0.1):
        """
        Args:
            iter_count (int): number of Adam steps used to refine the flow.
            reg_weight (float): weight of the ``Grad3d`` regularization term.
            lr (float): learning rate of the inner Adam optimizer.
        """
        super(OFGLoss, self).__init__()
        self.iter_count = iter_count
        self.reg_weight = reg_weight
        self.lr = lr
        self.ncc = losses.NCC_vxm()
        self.reg = losses.Grad3d(penalty='l2')
        self.mse = nn.MSELoss()

    def forward(self, x, y, initial_flow):
        """
        Args:
            x (torch.Tensor): moving image; its shape is unpacked as
                (N, C, H, W, D).
            y (torch.Tensor): fixed image.
            initial_flow (torch.Tensor): initial deformation field; a detached
                copy of it is what gets optimized.

        Returns:
            torch.Tensor: mean-squared error between ``initial_flow`` and the
            refined flow (``nn.MSELoss`` with its default reduction).
        """
        _, _, H, W, D = x.shape
        img_size = (H, W, D)

        # The refinement only needs image values; detaching keeps its backward
        # passes from reaching (and freeing) any graph x / y belong to.
        x_moving = x.detach()
        y_fixed = y.detach()

        # NOTE(review): .to() is a no-op when the module already lives on the
        # flow's device; it guards against the spatial transformer keeping
        # buffers on another device — confirm against utils.SpatialTransformer.
        opt = DeformationOptimizer(img_size, initial_flow).to(initial_flow.device)
        adam = optim.Adam(opt.parameters(), lr=self.lr,
                          weight_decay=0, amsgrad=True)

        # Re-enable autograd so the refinement still works when the loss is
        # evaluated under torch.no_grad() (e.g. in a validation loop).
        with torch.enable_grad():
            for _ in range(self.iter_count):
                x_warped, flow = opt(x_moving)
                loss_ncc = self.ncc(x_warped, y_fixed)
                loss_reg = self.reg(flow, y_fixed) * self.reg_weight
                loss = loss_ncc + loss_reg
                adam.zero_grad()
                loss.backward()
                adam.step()

        # Read the parameter after the loop so this is defined even when
        # iter_count == 0, and detach it: it is a fixed target, not something
        # to differentiate through.
        optimized_flow = opt.flow.detach()
        ofg_loss = self.mse(optimized_flow, initial_flow)
        return ofg_loss


if __name__ == '__main__':
    # Smoke check: the criterion (and the losses.* terms it builds) can be
    # constructed with its default hyper-parameters, spelled out explicitly.
    criterion_ofg = OFGLoss(iter_count=5, reg_weight=1, lr=0.1)