-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgroup_basis.py
172 lines (128 loc) · 6.53 KB
/
group_basis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import torch
import torch.nn as nn
import einops
from utils import get_device, transform_atlas
# Module-level device, resolved once at import time via utils.get_device().
# Parameters, sampled coefficients and identity representations created in this
# module are explicitly placed on it.
device = get_device()
def normalize(x):
    """Rescale every matrix in a stack to a fixed Frobenius norm.

    Adapted from LieGAN. For each slice ``x[k]`` the squared Frobenius norm
    ``sum_{d,f} x[k, d, f]**2`` is computed, and the slice is divided by
    ``sqrt(norm_sq / x.shape[1])``. Afterwards every slice has squared
    Frobenius norm ``x.shape[1]``, i.e. the same norm as an identity matrix
    of that size when the matrices are square.

    Args:
        x: 3-D tensor (the einsum below requires exactly three dims),
            indexed as ``(k, d, f)``.

    Returns:
        Tensor of the same shape with each ``x[k]`` rescaled.

    NOTE(review): an all-zero slice gives a zero scale and therefore NaNs;
    callers presumably never pass one.
    """
    squared_norms = torch.einsum('kdf,kdf->k', x, x)
    scale = torch.sqrt(squared_norms / x.shape[1])
    # Broadcast the per-matrix scale over both matrix dimensions.
    return x / scale[:, None, None]
class GroupBasis(nn.Module):
    """Learnable Lie-algebra generators together with input/output representations.

    Three stacks of ``num_basis`` learnable matrices are kept:

    * ``lie_basis`` -- shape ``(num_basis, man_dim, man_dim)``; exponentiated
      and passed to ``transform_atlas`` as the geometric transform,
    * ``in_basis``  -- shape ``(num_basis, in_dim, in_dim)``; representation
      applied to the inputs (unless ``identity_in_rep``),
    * ``out_basis`` -- shape ``(num_basis, out_dim, out_dim)``; representation
      applied to the predictions (unless ``identity_out_rep``),

    plus ``num_cosets`` free ``(man_dim, man_dim)`` matrices used by
    :meth:`coset_step`. A group element is sampled as
    ``matrix_exp(sum_k c_k * B_k)`` with standard-normal coefficients ``c_k``
    (see :meth:`step`). An Adam optimizer over all of these parameters is built
    in the constructor and exposed as ``self.optimizer``.
    """

    def __init__(
        self, in_dim, man_dim, out_dim, num_basis, standard_basis, num_cosets=64,
        in_rad=10, out_rad=5, lr=5e-4, r1=0.05, r2=1, r3=0.35,
        identity_in_rep=False, identity_out_rep=False, in_interpolation='bilinear', out_interpolation='bilinear', dtype=torch.float32,
    ):
        """Create the bases, cosets and optimizer.

        Args:
            in_dim: size of each input-representation matrix.
            man_dim: size of each Lie-algebra generator and coset matrix.
            out_dim: size of each output-representation matrix.
            num_basis: number of generators in each basis.
            standard_basis: if true, :meth:`similarity_loss` compares the
                absolute values of the basis entries.
            num_cosets: number of coset matrices.
            in_rad: radius passed to ``x.regions`` to extract input patches.
            out_rad: half-width of the centre window compared in the losses.
            lr: Adam learning rate.
            r1: weight of the similarity penalty in :meth:`regularization`.
            r2: clamp bound on ``lie_basis`` entries in :meth:`regularization`.
            r3: weight of the (negative) magnitude term in :meth:`regularization`.
            identity_in_rep: use identity matrices instead of ``in_basis``.
            identity_out_rep: use identity matrices instead of ``out_basis``.
            in_interpolation: interpolation mode passed to ``transform_atlas``
                for inputs.
            out_interpolation: interpolation mode passed to ``transform_atlas``
                for outputs.
            dtype: dtype of all learnable parameters.
        """
        super().__init__()
        self.in_dim = in_dim
        self.man_dim = man_dim
        self.out_dim = out_dim
        self.in_rad = in_rad
        self.out_rad = out_rad
        self.in_interpolation = in_interpolation
        self.out_interpolation = out_interpolation
        self.identity_in_rep = identity_in_rep
        self.identity_out_rep = identity_out_rep
        self.num_basis = num_basis
        self.num_cosets = num_cosets
        self.dtype = dtype
        self.r1 = r1
        self.r2 = r2
        self.r3 = r3
        self.standard_basis = standard_basis

        self.lie_basis = nn.Parameter(torch.empty((num_basis, man_dim, man_dim), dtype=dtype).to(device))
        self.in_basis = nn.Parameter(torch.empty((num_basis, in_dim, in_dim), dtype=dtype).to(device))
        self.out_basis = nn.Parameter(torch.empty((num_basis, out_dim, out_dim), dtype=dtype).to(device))
        # Small initial generators, so sampled group elements start near identity.
        for tensor in [self.in_basis, self.lie_basis, self.out_basis]:
            nn.init.normal_(tensor, 0, 0.02)

        cosets = torch.empty((num_cosets, man_dim, man_dim), dtype=dtype).to(device)
        nn.init.normal_(cosets, 0, 1)
        self.cosets = nn.Parameter(cosets)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)

    def summary(self):
        """Return the learned bases as ``.data`` tensors (no autograd history).

        Order: ``in_basis`` (omitted if ``identity_in_rep``), ``lie_basis``,
        ``out_basis`` (omitted if ``identity_out_rep``).
        """
        ret = []
        if not self.identity_in_rep:
            ret.append(self.in_basis.data)
        ret.append(self.lie_basis.data)
        if not self.identity_out_rep:
            # The output representation belongs here, not in_basis.
            ret.append(self.out_basis.data)
        return ret

    def similarity_loss(self, x):
        """Penalise overlap between the matrices of a basis.

        Each matrix is rescaled with ``normalize``. The result is the sum of the
        absolute values of the strictly-upper-triangular entries of the Gram
        matrix of Frobenius inner products ``<x_b, x_c>``. With
        ``standard_basis`` the entries are made non-negative first, so matrices
        with overlapping support are penalised even when signs would cancel.

        Returns:
            ``0`` when fewer than two matrices are given, otherwise a scalar tensor.
        """
        if len(x) <= 1:
            return 0
        x = normalize(x)
        if self.standard_basis:
            x = torch.abs(x)
        gram = torch.einsum('bij,cij->bc', x, x)
        return torch.sum(torch.abs(torch.triu(gram, diagonal=1)))

    def sample_coefficients(self, bs):
        """Sample real standard-normal coefficients of shape ``(*bs, num_basis)``.

        Even when the bases are complex-valued, the goal is still only to find
        real Lie groups, so the sampled coefficients are always real numbers.
        """
        return torch.normal(0, 1, (*bs, self.num_basis)).to(device)

    @staticmethod
    def _identity_rep(dim, *lead):
        """Return ``dim x dim`` identity matrices tiled to ``(*lead, dim, dim)`` on ``device``."""
        return torch.eye(dim, device=device).repeat(*lead, 1, 1)

    @staticmethod
    def _target_prediction(pred, x_atlas):
        """Run ``pred`` on ``x_atlas``; softmax over dim -3 if it returns logits; detach."""
        y_atlas = pred.run(x_atlas)
        if pred.returns_logits():
            y_atlas = torch.nn.functional.softmax(y_atlas, dim=-3)
        return y_atlas.detach()

    def _crop_center(self, ref, other):
        """Crop the central ``(2 * out_rad + 1)``-wide window of the last two dims.

        The centre is computed from ``ref``'s spatial size and applied to both
        tensors. NOTE(review): if ``out_rad`` exceeds half the spatial size the
        slice start goes negative and yields a truncated window instead of
        failing; callers presumably guarantee it does not.
        """
        r = ref.shape[-2] // 2
        c = ref.shape[-1] // 2
        rows = slice(r - self.out_rad, r + self.out_rad + 1)
        cols = slice(c - self.out_rad, c + self.out_rad + 1)
        return ref[..., rows, cols], other[..., rows, cols]

    def step(self, x, pred, _y):
        """Equivariance loss for one randomly sampled group element per chart.

        A group element ``g`` is sampled independently for every
        ``(batch, chart)`` pair. The predictor's output on the transformed input
        is compared, via ``pred.loss``, with the transformed (detached)
        prediction on the original input. Only the centre window is compared
        (see ``_crop_center``).

        Args:
            x: input object providing ``batch_size()``, ``num_charts()`` and
                ``regions(radius)``.
            pred: predictor providing ``run``, ``returns_logits`` and ``loss``.
            _y: unused; only kept for debugging.

        Returns:
            Whatever ``pred.loss`` returns.
        """
        bs = x.batch_size()
        num_charts = x.num_charts()
        coeffs = self.sample_coefficients((bs, num_charts))
        # (bs, num_charts, num_basis) -> (..., num_basis, 1, 1) so each
        # coefficient broadcasts over its generator matrix.
        weights = coeffs.unsqueeze(-1).unsqueeze(-1)

        def sample(raw):
            # exp(sum_k c_k * B_k): one matrix per (batch, chart) pair.
            return torch.matrix_exp(torch.sum(raw * weights, dim=-3))

        sampled_lie = sample(self.lie_basis)
        # Identity reps do not use the learned basis, so skip its matrix_exp.
        if self.identity_in_rep:
            sampled_in = self._identity_rep(self.in_dim, bs, num_charts)
        else:
            sampled_in = sample(self.in_basis)
        if self.identity_out_rep:
            sampled_out = self._identity_rep(self.out_dim, bs, num_charts)
        else:
            sampled_out = sample(self.out_basis)

        x_atlas = x.regions(self.in_rad)
        g_x_atlas = transform_atlas(sampled_lie, sampled_in, x_atlas, self.in_interpolation)
        y_atlas = self._target_prediction(pred, x_atlas)
        g_y_atlas = transform_atlas(sampled_lie, sampled_out, y_atlas, self.out_interpolation)
        y_atlas_true = pred.run(g_x_atlas)

        y_atlas_true, g_y_atlas = self._crop_center(y_atlas_true, g_y_atlas)
        return pred.loss(y_atlas_true, g_y_atlas)

    def norm_cosets(self):
        """Return the coset matrices rescaled so that ``|det| == 1``.

        Each ``man_dim x man_dim`` matrix is divided by ``|det|^(1/man_dim)``.
        NOTE(review): a singular coset would divide by zero.
        """
        det = torch.abs(torch.det(self.cosets).unsqueeze(-1).unsqueeze(-1))
        return self.cosets / (det ** (1 / self.man_dim))

    def coset_step(self, x, pred):
        """Apply every normalised coset to the whole batch and return both sides.

        Returns:
            ``(pred(g . x), g . pred(x))`` cropped to the centre window, each with
            its leading dimension unflattened into ``(num_cosets, bs)``.
        """
        # for now, can only handle identity in and out rep
        assert self.identity_in_rep and self.identity_out_rep
        bs = x.batch_size()
        num_charts = x.num_charts()
        num_cosets = len(self.cosets)
        # technically each chart is transformed the same way,
        # but we ensure independence through the separate predictors elsewhere so it's fine
        cosets = einops.repeat(self.norm_cosets(), 'c ... -> (c bs) ...', bs=bs * num_charts)
        in_rep = self._identity_rep(self.in_dim, bs * num_cosets, num_charts)
        out_rep = self._identity_rep(self.out_dim, bs * num_cosets, num_charts)
        x_atlas = einops.repeat(x.regions(self.in_rad), 'bs ... -> (c bs) ...', c=num_cosets)

        g_x_atlas = transform_atlas(cosets, in_rep, x_atlas, self.in_interpolation)
        y_atlas = self._target_prediction(pred, x_atlas)
        g_y_atlas = transform_atlas(cosets, out_rep, y_atlas, self.out_interpolation)
        y_atlas_true = pred.run(g_x_atlas)

        y_atlas_true, g_y_atlas = self._crop_center(y_atlas_true, g_y_atlas)
        return y_atlas_true.unflatten(0, (-1, bs)), g_y_atlas.unflatten(0, (-1, bs))

    # called by LocalTrainer during training
    def regularization(self, _epoch_num):
        """Return ``r1 * similarity + r3 * (-mean clamped Frobenius norm)`` of ``lie_basis``.

        The similarity term pushes the generators towards being mutually
        'orthogonal'. The magnitude term rewards larger generators, but entries
        are clamped to ``[-r2, r2]`` first, so growth past that bound earns nothing.
        """
        # aim for as 'orthogonal' as possible basis matrices
        sim = self.similarity_loss(self.lie_basis)
        # past a certain point, increasing the basis means nothing
        # we only want to increase to a certain extent
        clipped = self.lie_basis.clamp(-self.r2, self.r2)
        trace = torch.sqrt(torch.einsum('kdf,kdf->k', clipped, clipped))
        lie_mag = -torch.mean(trace)
        return self.r1 * sim + self.r3 * lie_mag