# ground_metric.py
import torch
from utils import isnan
class GroundMetric:
    """Build the ground metric (pairwise cost matrix) for Wasserstein computations.

    The object is configured from a ``params`` namespace. The attributes read
    from it by this class are: ``ground_metric``, ``ground_metric_normalize``,
    ``reg``, ``ground_metric_eff``, ``debug``, ``clip_gm``, ``clip_min``,
    ``clip_max``, ``dist_normalize``, ``activation_histograms``,
    ``act_num_samples``, ``geom_ensemble_type``, ``normalize_wts`` and,
    optionally, ``not_squared``.

    The public entry points are :meth:`process` (metric + normalisation +
    optional clipping) and :meth:`get_metric` (raw metric only).
    """

    # Lower bound applied to vector norms before dividing by them in
    # `_get_cosine`, so that all-zero vectors do not produce NaN entries.
    _NORM_EPS = 1e-9

    def __init__(self, params, not_squared=False):
        """Store the configuration.

        Args:
            params: configuration namespace (see the class docstring for the
                attributes that are read).
            not_squared: used only when ``params`` has no ``not_squared``
                attribute; ``False`` (the default) keeps squared distances on.
        """
        self.params = params
        self.ground_metric_type = params.ground_metric
        self.ground_metric_normalize = params.ground_metric_normalize
        self.reg = params.reg
        # `params.not_squared` wins over the keyword argument when present,
        # so by default squared will be on.
        self.squared = not getattr(params, 'not_squared', not_squared)
        self.mem_eff = params.ground_metric_eff

    def _clip(self, ground_metric_matrix):
        """Clamp the matrix in place to ``[reg * clip_min, reg * clip_max]``.

        Also records on ``self.params.percent_clipped`` the percentage of
        entries that were at or above the upper bound before clamping.

        Returns:
            The same tensor object, clamped in place.
        """
        if self.params.debug:
            print("before clipping", ground_metric_matrix.data)

        lower = self.reg * self.params.clip_min
        upper = self.reg * self.params.clip_max

        total = ground_metric_matrix.numel()
        if total:
            n_clipped = (ground_metric_matrix >= upper).long().sum().item()
            percent_clipped = (float(n_clipped) / total) * 100
        else:
            # An empty matrix has nothing to clip; avoid dividing by zero.
            percent_clipped = 0.0
        print("percent_clipped is (assumes clip_min = 0) ", percent_clipped)
        setattr(self.params, 'percent_clipped', percent_clipped)

        # will keep the M' = M/reg in range clip_min and clip_max
        ground_metric_matrix.clamp_(min=lower, max=upper)

        if self.params.debug:
            print("after clipping", ground_metric_matrix.data)
        return ground_metric_matrix

    def _normalize(self, ground_metric_matrix):
        """Rescale the matrix according to ``self.ground_metric_normalize``.

        Supported modes: ``"log"`` (``log1p``), ``"max"``, ``"median"``,
        ``"mean"`` (divide by that statistic) and ``"none"`` (unchanged).

        NOTE: the divisor is not guarded; a zero max/median/mean yields
        inf/NaN entries.

        Raises:
            NotImplementedError: for any other mode.
        """
        if self.ground_metric_normalize == "log":
            ground_metric_matrix = torch.log1p(ground_metric_matrix)
        elif self.ground_metric_normalize == "max":
            print("Normalizing by max of ground metric and which is ", ground_metric_matrix.max())
            ground_metric_matrix = ground_metric_matrix / ground_metric_matrix.max()
        elif self.ground_metric_normalize == "median":
            print("Normalizing by median of ground metric and which is ", ground_metric_matrix.median())
            ground_metric_matrix = ground_metric_matrix / ground_metric_matrix.median()
        elif self.ground_metric_normalize == "mean":
            print("Normalizing by mean of ground metric and which is ", ground_metric_matrix.mean())
            ground_metric_matrix = ground_metric_matrix / ground_metric_matrix.mean()
        elif self.ground_metric_normalize == "none":
            return ground_metric_matrix
        else:
            raise NotImplementedError

        return ground_metric_matrix

    def _sanity_check(self, ground_metric_matrix):
        """Assert that the matrix has no negative and no NaN entries.

        NaN detection is delegated to the ``isnan`` helper imported from
        ``utils`` at module level.
        """
        assert not (ground_metric_matrix < 0).any(), \
            "ground metric contains negative entries"
        assert not (isnan(ground_metric_matrix).any()), \
            "ground metric contains NaN entries"

    def _cost_matrix_xy(self, x, y, p=2, squared=True):
        """Return the matrix of $|x_i-y_j|^p$ summed over the last dimension.

        ``x`` and ``y`` are broadcast against each other as ``x[:, None]`` and
        ``y[None, :]`` and the result is reduced over dimension 2, so both are
        expected to be 2-D with the same number of columns.

        Args:
            x: tensor of row vectors.
            y: tensor of row vectors.
            p: exponent applied to the absolute coordinate differences.
            squared: when ``False`` the square root of the summed cost is
                returned (always a square root, independently of ``p``).

        Raises:
            NotImplementedError: if ``params.dist_normalize`` is set, because
                this code path has no distance normalisation.
        """
        # TODO: Use this to guarantee reproducibility of previous results and then move onto better way
        if self.params.dist_normalize:
            # The original `assert NotImplementedError` asserted a truthy class
            # and therefore never fired, silently ignoring the option.
            raise NotImplementedError(
                "dist_normalize is not implemented for _cost_matrix_xy"
            )

        x_col = x.unsqueeze(1)
        y_lin = y.unsqueeze(0)
        c = torch.sum((torch.abs(x_col - y_lin)) ** p, 2)
        if not squared:
            print("dont leave off the squaring of the ground metric")
            c = c ** (1/2)
        return c

    def _pairwise_distances(self, x, y=None, squared=True):
        '''
        Source: https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/2
        Input: x is a Nxd matrix
               y is an optional Mxd matrix
        Output: dist is a NxM matrix where dist[i,j] is the square norm between x[i,:] and y[j,:]
                if y is not given then use 'y=x'.
        i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2

        When both params.activation_histograms and params.dist_normalize are
        set, the squared distances are divided by params.act_num_samples.
        When `squared` is False the element-wise square root is returned.
        '''
        x_norm = (x ** 2).sum(1).view(-1, 1)
        if y is not None:
            y_t = torch.transpose(y, 0, 1)
            y_norm = (y ** 2).sum(1).view(1, -1)
        else:
            y_t = torch.transpose(x, 0, 1)
            y_norm = x_norm.view(1, -1)

        # ||x||^2 + ||y||^2 - 2 x.y can come out slightly negative through
        # floating-point cancellation; clamp so that no entry is below zero.
        dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
        dist = torch.clamp(dist, min=0.0)

        if self.params.activation_histograms and self.params.dist_normalize:
            dist = dist/self.params.act_num_samples
            print("Divide squared distances by the num samples")

        if not squared:
            print("dont leave off the squaring of the ground metric")
            dist = dist ** (1/2)

        return dist

    def _get_euclidean(self, coordinates, other_coordinates=None):
        """Return the Euclidean ground metric.

        With ``other_coordinates`` the result comes from
        :meth:`_pairwise_distances` (when ``self.mem_eff``) or
        :meth:`_cost_matrix_xy`, both honouring ``self.squared``.

        NOTE: without ``other_coordinates`` the pairwise L2 norm of
        ``coordinates`` against itself is returned, which is never squared,
        regardless of ``self.squared``.
        """
        # TODO: Replace by torch.pdist (which is said to be much more memory efficient)
        if other_coordinates is None:
            n_rows, n_cols = coordinates.shape[0], coordinates.shape[1]
            return torch.norm(
                coordinates.view(n_rows, 1, n_cols) - coordinates, p=2, dim=2
            )

        if self.mem_eff:
            return self._pairwise_distances(
                coordinates, other_coordinates, squared=self.squared
            )
        return self._cost_matrix_xy(
            coordinates, other_coordinates, squared=self.squared
        )

    def _normed_vecs(self, vecs, eps=1e-9):
        """Scale vectors along the last dimension by ``1 / (norm + eps)``.

        Prints mean/min/max/std of the norms as a side effect.
        """
        norms = torch.norm(vecs, dim=-1, keepdim=True)
        print("stats of vecs are: mean {}, min {}, max {}, std {}".format(
            norms.mean(), norms.min(), norms.max(), norms.std()
        ))
        return vecs / (norms + eps)

    def _get_cosine(self, coordinates, other_coordinates=None):
        """Return the cosine ground metric ``1 - cos(x_i, y_j)``, floored at 0.

        Row norms are bounded below by ``_NORM_EPS`` before the division so
        that an all-zero row gives a distance of 1 instead of NaN. Rows whose
        norm is at least ``_NORM_EPS`` are unaffected.
        """
        row_norms = torch.norm(coordinates, dim=1, keepdim=True).clamp(min=self._NORM_EPS)
        if other_coordinates is None:
            unit = coordinates / row_norms
            matrix = 1 - unit @ unit.t()
        else:
            other_norms = torch.norm(
                other_coordinates, dim=1, keepdim=True
            ).clamp(min=self._NORM_EPS)
            matrix = 1 - torch.div(
                coordinates @ other_coordinates.t(),
                row_norms @ other_norms.t()
            )
        # Round-off can push the similarity marginally above 1.
        return matrix.clamp_(min=0)

    def _get_angular(self, coordinates, other_coordinates=None):
        """Placeholder for the angular ground metric.

        Raises:
            NotImplementedError: always. The original body was ``pass`` and
                returned ``None``, which only failed later in the pipeline.
        """
        raise NotImplementedError("angular ground metric is not implemented")

    def get_metric(self, coordinates, other_coordinates=None):
        """Dispatch to the metric selected by ``self.ground_metric_type``.

        Raises:
            KeyError: if the type is not 'euclidean', 'cosine' or 'angular'.
        """
        get_metric_map = {
            'euclidean': self._get_euclidean,
            'cosine': self._get_cosine,
            'angular': self._get_angular,
        }
        return get_metric_map[self.ground_metric_type](coordinates, other_coordinates)

    def process(self, coordinates, other_coordinates=None):
        """Compute, normalise and optionally clip the ground metric matrix.

        Steps: optional unit-norm scaling of the inputs (when
        ``params.geom_ensemble_type == 'wts'`` and ``params.normalize_wts``),
        metric computation, normalisation, and clipping when
        ``params.clip_gm`` is set. The matrix is sanity-checked after each of
        the last three stages.

        Returns:
            The resulting ground metric matrix.
        """
        print('Processing the coordinates to form ground_metric')
        if self.params.geom_ensemble_type == 'wts' and self.params.normalize_wts:
            print("In weight mode: normalizing weights to unit norm")
            coordinates = self._normed_vecs(coordinates)
            if other_coordinates is not None:
                other_coordinates = self._normed_vecs(other_coordinates)

        ground_metric_matrix = self.get_metric(coordinates, other_coordinates)

        if self.params.debug:
            print("coordinates is ", coordinates)
            if other_coordinates is not None:
                print("other_coordinates is ", other_coordinates)
            print("ground_metric_matrix is ", ground_metric_matrix)

        self._sanity_check(ground_metric_matrix)

        ground_metric_matrix = self._normalize(ground_metric_matrix)

        self._sanity_check(ground_metric_matrix)

        if self.params.clip_gm:
            ground_metric_matrix = self._clip(ground_metric_matrix)

        self._sanity_check(ground_metric_matrix)

        if self.params.debug:
            print("ground_metric_matrix at the end is ", ground_metric_matrix)

        return ground_metric_matrix