-
Notifications
You must be signed in to change notification settings - Fork 8
/
antialiasing.py
181 lines (149 loc) · 6.47 KB
/
antialiasing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import torch
import math
import ckconv
def regularize_gabornet(
    model, horizon, factor, target="gabor", fn="l2_relu", method="together", gauss_stddevs=1.0
):
    """Compute an anti-aliasing regularization penalty for a FlexNet.

    Collects the estimated maximum frequency content of every FlexConv
    module in ``model`` and penalizes frequencies that exceed the Nyquist
    limit implied by ``horizon``.

    Args:
        model: Network containing ``ckconv.nn.FlexConv`` modules.
        horizon: Kernel size used to derive the Nyquist frequency.
        factor (float): Multiplier applied to the resulting penalty.
        target (str): Which frequency terms to regularize
            ("gabor" or "gabor+mask"; see ``gabor_layer_frequencies``).
        fn (str): Penalty function, "l2_relu" or "offset_elu".
        method (str): Per-layer aggregation: "summed", "together",
            or "together+mask".
        gauss_stddevs (float): Std-devs used for the Gaussian frequency term.

    Returns:
        Scalar tensor: the weighted regularization loss.

    Raises:
        NotImplementedError: For unsupported target/method/fn combinations.
    """
    # Collect frequency terms to be regularized from all FlexConv modules
    modules = get_flexconv_modules(model)
    magnet_freqs = []
    mask_freqs = []
    for module in modules:
        module_magnet_freqs, module_mask_freq = gabor_layer_frequencies(
            module, target, method, gauss_stddevs=gauss_stddevs
        )
        magnet_freqs.append(module_magnet_freqs)
        if module_mask_freq is not None:
            mask_freqs.append(module_mask_freq)
    magnet_freqs = torch.stack(magnet_freqs)
    masks = len(mask_freqs) > 0
    if masks:
        mask_freqs = torch.stack(mask_freqs)
    if method == "summed":
        # Regularize the sum of all filters together, per layer.
        magnet_freqs = torch.sum(magnet_freqs, 1)
        flexconv_freqs = magnet_freqs + mask_freqs if masks else magnet_freqs
        nyquist_freq = torch.ones_like(flexconv_freqs) * nyquist_frequency(horizon)
    elif method == "together" and target == "gabor":
        if masks:
            raise NotImplementedError()
        flexconv_freqs = magnet_freqs
        # Divide Nyquist frequency by the number of filters in each layer.
        nyquist_freq = torch.ones_like(flexconv_freqs) * nyquist_frequency(horizon)
        nyquist_freq = nyquist_freq / nyquist_freq.shape[1]
    elif method in ["together", "together+mask"] and target == "gabor+mask":
        if not masks:
            raise NotImplementedError()
        # Distribute the single mask frequency evenly over all filters.
        n_filters = magnet_freqs.shape[1]
        mask_freqs = mask_freqs.unsqueeze(1).repeat([1, n_filters]) / torch.tensor(
            n_filters, dtype=torch.float32
        )
        flexconv_freqs = magnet_freqs + mask_freqs
        # Divide Nyquist frequency by the number of filters in each layer.
        nyquist_freq = torch.ones_like(flexconv_freqs) * nyquist_frequency(horizon)
        nyquist_freq = nyquist_freq / nyquist_freq.shape[1]
    else:
        # Fix: previously an unsupported combination fell through every branch
        # and raised a confusing NameError on `flexconv_freqs` below.
        raise NotImplementedError(f"method {method} with target {target}")
    if fn == "l2_relu":
        # L2 ReLU
        return factor * l2_relu(flexconv_freqs, nyquist_freq)
    elif fn == "offset_elu":
        # L1 ELU with offset and scale to approximate L2 ReLU
        return factor * offset_elu(flexconv_freqs, nyquist_freq, 4.0, 5.0)
    else:
        raise NotImplementedError(f"regularization function {fn}")
def nyquist_frequency(kernel_size):
    """Return the Nyquist frequency (Hz) for a kernel spanning ``kernel_size`` samples.

    The sampling rate over the kernel is (kernel_size - 1) / 2 samples per
    unit; halving that rate gives the highest representable frequency.
    """
    samples_per_unit = (kernel_size - 1.0) / 2.0
    return 0.5 * float(samples_per_unit)
def freq_effect(gamma, stddevs=1.0):
    """Convert a Gaussian width parameter ``gamma`` (radians) into a frequency (Hz)."""
    two_pi = 2.0 * torch.tensor(math.pi)
    return (gamma * stddevs) / two_pi
def gabor_layer_frequencies(module, target, method, gauss_stddevs=1.0):
    """Estimate the maximum spatial frequency (Hz) per Gabor filter of a FlexConv.

    Args:
        module: FlexConv module. Assumes ``module.Kernel.filters`` is an
            iterable of Gabor layers exposing ``linear.weight`` (sine
            frequencies, in radians) and ``gamma`` (Gaussian widths), and
            that ``module.mask_params`` holds the Gaussian mask sigmas —
            TODO(review): confirm against the kernel implementation.
        target (str): Which term(s) to measure: "sines", "gausses",
            "gabor", or "gabor+mask".
        method (str): Aggregation method; "together", "summed" and
            "together+mask" combine sine and Gaussian terms per filter.
        gauss_stddevs (float): Number of standard deviations used when
            converting a Gaussian width into a frequency estimate.

    Returns:
        tuple: ``(freqs, mask_freq)`` — ``freqs`` is a float32 tensor with
        one entry per filter; ``mask_freq`` is the mask's frequency
        contribution for target "gabor+mask", otherwise ``None``.

    Raises:
        NotImplementedError: For unsupported ``target`` or ``method`` values.
    """
    n_filters = len(module.Kernel.filters)
    # If we are using distributed regularization, we have two terms for each
    # filter: the sine term and the Gaussian term
    # n_terms = 2 if target == "gabor" and method == "distributed" else 1
    freqs = torch.zeros((n_filters), dtype=torch.float32)
    for i, f in enumerate(module.Kernel.filters):
        if target == "sines":
            # All units are in Hz, not radians
            freqs[i] = torch.max(torch.absolute(f.linear.weight)) / (
                2.0 * torch.tensor(math.pi)
            )
        elif target == "gausses":
            # Frequency contribution of the Gaussian envelope alone.
            gausses = freq_effect(f.gamma, stddevs=gauss_stddevs)
            freqs[i] = torch.max(torch.absolute(gausses))
        elif target == "gabor" or target == "gabor+mask":
            # All units are in Hz, not radians
            sines = torch.absolute(f.linear.weight / (2.0 * torch.tensor(math.pi)))
            gausses = torch.absolute(freq_effect(f.gamma, stddevs=gauss_stddevs))
            if method in ["together", "summed", "together+mask"]:
                # Worst case: sine and Gaussian contributions add.
                combined = sines + gausses
                freqs[i] = torch.max(combined)
            # elif method == "distributed":
            #     freqs[i, 0] = torch.max(sines)
            #     freqs[i, 1] = torch.max(gausses)
            else:
                raise NotImplementedError(f"method {method}")
        else:
            raise NotImplementedError(f"target {target}")
    mask_freq = None
    if target == "gabor+mask":
        # Mask effect = max(x_gamma,y_gamma) where each is inverse of sigma
        x_mask_gamma = 1.0 / torch.absolute(module.mask_params[0, 1]).to(freqs.device)
        y_mask_gamma = 1.0 / torch.absolute(module.mask_params[1, 1]).to(freqs.device)
        mask_gamma = torch.maximum(x_mask_gamma, y_mask_gamma)
        mask_freq = freq_effect(mask_gamma, stddevs=gauss_stddevs)
    return freqs, mask_freq
def l2_relu(x, target):
    """Squared-L2 penalty on the amount by which ``x`` exceeds ``target``.

    Elements at or below ``target`` contribute nothing.

    Args:
        x: Tensor of measured frequencies.
        target: Tensor of per-element frequency limits (broadcastable to ``x``).

    Returns:
        Scalar tensor: sum of squared positive excesses.
    """
    # Idiom: relu(x - target) == max(0, x - target) without allocating an
    # explicit zero tensor as the previous torch.maximum formulation did.
    over_freq = torch.relu(x - target)
    return torch.sum(torch.square(over_freq))
def offset_elu(x, target, offset, scale):
    """Scaled, offset ELU penalty of ``x`` above ``target`` (smooth L2-ReLU proxy).

    Linear (slope 1, value 1 at the offset) above ``target + offset``,
    exponential below it; each element is scaled by ``scale`` and summed.
    """
    shifted = x - target - offset
    # Linear branch past the offset, exponential branch before it.
    elu = torch.where(shifted > 0, shifted + 1.0, torch.exp(shifted))
    return torch.sum(elu * scale)
def get_flexconv_modules(model):
    """Return every ``ckconv.nn.FlexConv`` submodule of ``model``, in traversal order."""
    return [
        submodule
        for submodule in model.modules()
        if isinstance(submodule, ckconv.nn.FlexConv)
    ]
def get_gabornet_summaries(model, target, method):
    """Collect per-layer and aggregate frequency statistics for all FlexConv modules.

    Returns a dict mapping stat names (e.g. ``"sines_freq_0"``,
    ``"gabor_freq_mean"``, ``"mask_freq_std"``) to scalar tensors.
    """
    # Expand composite targets into their constituent terms.
    if target in ("gabor", "gabor+mask"):
        targets = ["sines", "gausses", target]
    else:
        targets = [target]
    modules = get_flexconv_modules(model)
    stats = {}
    mask_totals = torch.zeros((len(modules)), dtype=torch.float32)
    for tgt in targets:
        magnet_totals = torch.zeros((len(modules)), dtype=torch.float32)
        for idx, flexconv in enumerate(modules):
            layer_freqs, layer_mask_freq = gabor_layer_frequencies(flexconv, tgt, method)
            magnet_totals[idx] = torch.sum(layer_freqs)
            stats[f"{tgt}_freq_{idx}"] = magnet_totals[idx]
            if tgt == "gabor+mask":
                mask_totals[idx] = layer_mask_freq
                stats[f"mask_freq_{idx}"] = mask_totals[idx]
        stats[f"{tgt}_freq_mean"] = torch.mean(magnet_totals)
        stats[f"{tgt}_freq_std"] = torch.std(magnet_totals)
    # NOTE(review): mask stats are emitted (as zeros) even when target has no
    # mask — preserved as-is to keep the output schema unchanged.
    stats[f"mask_freq_mean"] = torch.mean(mask_totals)
    stats[f"mask_freq_std"] = torch.std(mask_totals)
    return stats