utils.py
from datetime import datetime
from itertools import islice

import numpy as np
import torch
import torch.nn as nn

# Loss constructors available to compute_loss below, keyed by name.
loss_function_dict = {
    'MSE': nn.MSELoss,
    'CrossEntropy': nn.CrossEntropyLoss
}
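
# Example (illustrative, mirrors how compute_loss consumes this dict):
#   loss_fn = loss_function_dict['CrossEntropy'](reduction='sum')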


def get_cur_time():
    """Return the current local time formatted as 'YYYY_MM_DD-HH_MM_SS'."""
    return datetime.now().strftime('%Y_%m_%d-%H_%M_%S')


def cycle(iterable):
    """Yield items from `iterable` forever, restarting it when exhausted."""
    while True:
        for x in iterable:
            yield x


def get_float_wn(parameters):
    """Return the L2 norm (weight norm) of `parameters` as a float.

    :param parameters: iterable of tensors,
        e.g. parameters = model_clf.linear_1.parameters()
    :return: float, sqrt of the sum of all squared entries
    """
    with torch.no_grad():
        out = sum(torch.pow(p, 2).sum() for p in parameters)
        out = float(np.sqrt(out.item()))
    return out


def compute_accuracy(network, dataset, device, N=2000, batch_size=50):
    """Compute the accuracy of `network` on up to `N` samples of `dataset`."""
    with torch.no_grad():
        N = min(len(dataset), N)
        batch_size = min(batch_size, N)
        dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        correct = 0
        total = 0
        for x, labels in islice(dataset_loader, N // batch_size):
            logits = network(x.to(device))
            predicted_labels = torch.argmax(logits, dim=1)
            correct += torch.sum(predicted_labels == labels.to(device))
            total += x.size(0)
        return (correct / total).item()


def compute_loss(network, dataset, loss_function, device, N=2000, batch_size=50):
    """Compute the mean loss of `network` on up to `N` samples of `dataset`.

    `loss_function` is a key of `loss_function_dict` ('MSE' or 'CrossEntropy').
    """
    with torch.no_grad():
        N = min(len(dataset), N)
        batch_size = min(batch_size, N)
        dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        loss_fn = loss_function_dict[loss_function](reduction='sum')
        one_hots = torch.eye(10, 10).to(device)  # one-hot MSE targets; assumes 10 classes
        total = 0
        points = 0
        for x, labels in islice(dataset_loader, N // batch_size):
            y = network(x.to(device))
            if loss_function == 'CrossEntropy':
                total += loss_fn(y, labels.to(device)).item()
            elif loss_function == 'MSE':
                total += loss_fn(y, one_hots[labels]).item()
            points += len(labels)
        return total / points


def log_gradients_in_model_tb(model, logger, step):
    """Log per-parameter gradient mean/variance to a TensorBoard `logger`.

    Assumes parameter names of the form 'layer.param' (e.g. 'linear_1.weight'),
    logged under 'grad_mean/param/layer' and 'grad_var/param/layer'.
    """
    with torch.no_grad():
        for tag, value in model.named_parameters():
            if value.grad is not None:
                logger.add_scalar(f"grad_mean/{tag.split('.')[1]}/{tag.split('.')[0]}",
                                  torch.mean(value.grad.cpu()), step)
                logger.add_scalar(f"grad_var/{tag.split('.')[1]}/{tag.split('.')[0]}",
                                  torch.var(value.grad.cpu()), step)


def log_gradients_in_model_wandb(model, run, step):
    """Log per-parameter gradient mean/variance to a Weights & Biases `run`.

    Uses the same 'layer.param' naming assumption as the TensorBoard variant.
    """
    with torch.no_grad():
        for tag, value in model.named_parameters():
            if value.grad is not None:
                run.log({f"grad_mean_{tag.split('.')[1]}/{tag.split('.')[0]}": torch.mean(value.grad.cpu())},
                        step=step)
                run.log({f"grad_var_{tag.split('.')[1]}/{tag.split('.')[0]}": torch.var(value.grad.cpu())},
                        step=step)


# For Sliced MI
def sample_spherical(n_projections, dim):
    """Sample `n_projections` orthonormal unit vectors of length `dim`.

    Draws a (dim, dim) Gaussian matrix and orthonormalizes it via QR; the rows
    of the square orthogonal factor are unit vectors on the sphere.
    """
    sampled_vectors = np.array([]).reshape(0, dim)
    while len(sampled_vectors) < n_projections:
        vec = np.random.multivariate_normal(np.zeros(dim), np.identity(dim), size=dim)  # (dim, dim)
        q, _ = np.linalg.qr(vec)  # tuple unpacking works on all NumPy versions, unlike `.Q`
        sampled_vectors = np.vstack((sampled_vectors, q))
    return sampled_vectors[:n_projections]  # (n_projections, dim)
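
# Quick sanity check (illustrative, not from the original file): the returned
# directions have unit norm because they are rows of an orthogonal matrix.
#   >>> theta = sample_spherical(n_projections=5, dim=16)
#   >>> bool(np.allclose(np.linalg.norm(theta, axis=1), 1.0))
#   True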


class smi_compressor():
    """Compress data by projecting it onto fixed random unit directions (for Sliced MI)."""

    def __init__(self, dim, n_projections):
        self.theta = sample_spherical(n_projections=n_projections, dim=dim)  # (n_projections, dim)

    def __call__(self, X):
        # Project each sample of X, shape (n_samples, dim), onto every direction.
        X_compressed = np.dot(self.theta, X.T)
        return X_compressed  # (n_projections, n_samples)


def measure_smi_projection(mi_estimator, x, y):
    """Fit `mi_estimator` on projected samples `x`, `y` and return its MI estimate."""
    mi_estimator.fit(x, y, verbose=0)
    return mi_estimator.estimate(x, y, verbose=0)
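

# Minimal usage sketch (illustrative, not part of the original module): the toy
# model and random dataset below are assumptions chosen only to exercise the
# helpers above end to end.
if __name__ == '__main__':
    from torch.utils.data import TensorDataset

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Toy 10-class linear classifier on 20-dimensional random inputs.
    model = nn.Linear(20, 10).to(device)
    data = TensorDataset(torch.randn(200, 20), torch.randint(0, 10, (200,)))

    print('weight norm:', get_float_wn(model.parameters()))
    print('accuracy:', compute_accuracy(model, data, device, N=200, batch_size=50))
    print('CE loss:', compute_loss(model, data, 'CrossEntropy', device, N=200, batch_size=50))

    # Sliced-MI compressor: project 200 samples onto 8 random directions.
    comp = smi_compressor(dim=20, n_projections=8)
    print('compressed shape:', comp(np.random.randn(200, 20)).shape)  # (8, 200)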