-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy path3DCDC.py
130 lines (97 loc) · 4.84 KB
/
3DCDC.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import math
import pdb
# -------- Vanilla ---------#
# nn.Conv3d(3, 64, kernel_size=3, padding=1)
# -------- 3DCDC ---------#
# CDC_ST(3, 64, kernel_size=3, padding=1, theta=0.6)
# CDC_T(3, 64, kernel_size=3, padding=1, theta=0.6)
# CDC_TR(3, 64, kernel_size=3, padding=1, theta=0.3)
'''
Spatio-Temporal Center-difference based Convolutional layer (3D version)
theta: controls the balance between the original convolution and the central-difference convolution
'''
class CDC_ST(nn.Module):
    """Spatio-temporal central-difference 3D convolution.

    Wraps an ``nn.Conv3d`` and subtracts ``theta`` times a 1x1x1 convolution
    whose weights are the wrapped kernel summed over its temporal and both
    spatial axes.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: forwarded unchanged to ``nn.Conv3d``.
        theta: weight of the difference term; a value of (numerically) zero
            makes the layer return the plain convolution output.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 padding=1, dilation=1, groups=1, bias=False, theta=0.6):
        super(CDC_ST, self).__init__()
        self.theta = theta
        self.conv = nn.Conv3d(
            in_channels, out_channels, kernel_size=kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias,
        )

    def forward(self, x):
        """Return ``conv(x) - theta * diff(x)``, or just ``conv(x)``.

        The difference term is skipped when ``theta`` is within 1e-8 of zero
        or when the temporal kernel size is 1.
        x is presumably (N, C_in, T, H, W) -- TODO confirm against callers.
        """
        vanilla = self.conv(x)
        weight = self.conv.weight
        theta_is_zero = math.fabs(self.theta - 0.0) < 1e-8

        # The central difference is only applied for temporal kernel size > 1.
        if theta_is_zero or weight.shape[2] <= 1:
            return vanilla

        # Collapse the kernel over (t, h, w) into one scalar per (out, in)
        # pair, then restore three singleton axes to get a 1x1x1 kernel.
        collapsed = weight.sum(2).sum(2).sum(2)
        center_kernel = collapsed[:, :, None, None, None]

        # NOTE: this branch runs with padding=0, so its output only lines up
        # with ``vanilla`` when the main conv keeps the input size at stride 1
        # (e.g. kernel_size=3, padding=1, dilation=1).
        diff = F.conv3d(
            input=x,
            weight=center_kernel,
            bias=self.conv.bias,
            stride=self.conv.stride,
            padding=0,
            dilation=self.conv.dilation,
            groups=self.conv.groups,
        )
        return vanilla - self.theta * diff
'''
Temporal Center-difference based Convolutional layer (3D version)
theta: controls the balance between the original convolution and the central-difference convolution
'''
class CDC_T(nn.Module):
    """Temporal central-difference 3D convolution.

    Wraps an ``nn.Conv3d`` and subtracts ``theta`` times a 1x1x1 convolution
    whose weights are the spatial sums of the wrapped kernel's temporal
    slices 0 and 2.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: forwarded unchanged to ``nn.Conv3d``.
        theta: weight of the difference term; a value of (numerically) zero
            makes the layer return the plain convolution output.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 padding=1, dilation=1, groups=1, bias=False, theta=0.6):
        super(CDC_T, self).__init__()
        self.theta = theta
        self.conv = nn.Conv3d(
            in_channels, out_channels, kernel_size=kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias,
        )

    def forward(self, x):
        """Return ``conv(x) - theta * diff(x)``, or just ``conv(x)``.

        The difference term is skipped when ``theta`` is within 1e-8 of zero
        or when the temporal kernel size is 1.
        x is presumably (N, C_in, T, H, W) -- TODO confirm against callers.
        """
        vanilla = self.conv(x)
        weight = self.conv.weight
        theta_is_zero = math.fabs(self.theta - 0.0) < 1e-8

        # The central difference is only applied for temporal kernel size > 1.
        if theta_is_zero or weight.shape[2] <= 1:
            return vanilla

        # Temporal slices 0 and 2 are indexed directly, i.e. the code is
        # written for a temporal kernel size of 3 (the two non-centre frames).
        first_frame = weight[:, :, 0, :, :].sum(2).sum(2)
        last_frame = weight[:, :, 2, :, :].sum(2).sum(2)
        center_kernel = (first_frame + last_frame)[:, :, None, None, None]

        # NOTE: this branch runs with padding=0, so its output only lines up
        # with ``vanilla`` when the main conv keeps the input size at stride 1
        # (e.g. kernel_size=3, padding=1, dilation=1).
        diff = F.conv3d(
            input=x,
            weight=center_kernel,
            bias=self.conv.bias,
            stride=self.conv.stride,
            padding=0,
            dilation=self.conv.dilation,
            groups=self.conv.groups,
        )
        return vanilla - self.theta * diff
'''
Temporal Robust Center-difference based Convolutional layer (3D version)
theta: controls the balance between the original convolution and the central-difference convolution
'''
class CDC_TR(nn.Module):
    """Temporal robust central-difference 3D convolution.

    Wraps an ``nn.Conv3d`` and subtracts ``theta`` times a 1x1x1 convolution
    applied to a temporally average-pooled copy of the input. The 1x1x1
    weights are the spatial sums of the wrapped kernel's temporal slices
    0 and 2.

    Args:
        in_channels, out_channels, stride, padding, dilation, groups, bias:
            forwarded unchanged to ``nn.Conv3d``.
        kernel_size: forwarded to ``nn.Conv3d`` and also used as the temporal
            window of the average pool, ``(kernel_size, 1, 1)``.
        theta: weight of the difference term; a value of (numerically) zero
            makes the layer return the plain convolution output.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 padding=1, dilation=1, groups=1, bias=False, theta=0.3):
        super(CDC_TR, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, groups=groups, bias=bias)
        # Temporal-only average pool. It carries the layer's stride, so its
        # output is already down-sampled like the main convolution's output.
        self.avgpool = nn.AvgPool3d(kernel_size=(kernel_size, 1, 1), stride=stride, padding=(padding, 0, 0))
        self.theta = theta

    def forward(self, x):
        """Return ``conv(x) - theta * diff(avgpool(x))``, or just ``conv(x)``.

        The difference term is skipped when ``theta`` is within 1e-8 of zero
        or when the temporal kernel size is 1.
        x is presumably (N, C_in, T, H, W) -- TODO confirm against callers.
        """
        out_normal = self.conv(x)
        if math.fabs(self.theta - 0.0) < 1e-8:
            return out_normal

        weight = self.conv.weight
        # The central difference is only applied for temporal kernel size > 1.
        if weight.shape[2] <= 1:
            return out_normal

        # Pool only when the difference branch is actually used.
        local_avg = self.avgpool(x)

        # Temporal slices 0 and 2 are indexed directly, i.e. the code is
        # written for a temporal kernel size of 3 (the two non-centre frames).
        kernel_diff = weight[:, :, 0, :, :].sum(2).sum(2) + weight[:, :, 2, :, :].sum(2).sum(2)
        kernel_diff = kernel_diff[:, :, None, None, None]

        # stride=1 here: ``self.avgpool`` has already applied the layer's
        # stride. Striding again would down-sample twice and make ``out_diff``
        # smaller than ``out_normal`` whenever stride != 1.
        out_diff = F.conv3d(input=local_avg, weight=kernel_diff, bias=self.conv.bias, stride=1,
                            padding=0, groups=self.conv.groups)
        return out_normal - self.theta * out_diff