This repository has been archived by the owner on Oct 28, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Grad_norm.py
145 lines (119 loc) · 5.68 KB
/
Grad_norm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import numpy as np
# import matplotlib.pyplot as plt
import scipy.spatial.distance as dist
from sklearn.preprocessing import StandardScaler
# import matplotlib.animation
from dev.calc_grad_norm_of_ks import calc_grad_norm_of_ks as calc_grad_norm
# import japanize_matplotlib
class Grad_Norm:
    """Draw (or compute) the gradient-norm map over a 2-D latent space.

    A regular ``resolution`` x ``resolution`` grid (``Zeta``) spanning the
    latent positions ``Z`` is evaluated with ``calc_grad_norm`` (imported from
    ``dev.calc_grad_norm_of_ks``). ``draw_umatrix`` shows the result as an
    image, with the samples' latent positions scattered and labelled on top.
    When ``Z`` holds several epochs, the image is animated over them.
    """

    def __init__(self, X=None, Z=None, sigma=None, resolution=None,
                 labels=None, fig_size=(6, 6), title_text='Grad_norm', cmap_type='jet',
                 interpolation_method='spline36', repeat=False, interval=40):
        """Validate the inputs and build the latent grid ``Zeta``.

        Args:
            X: observed data; ``X.shape[0]`` must equal the number of samples N.
            Z: latent positions, shape (N, 2) for a still image or (T, N, 2)
                for an animation over T epochs.
            sigma: scalar when ``Z`` is 2-D; 1-D array of length T when ``Z``
                is 3-D. Passed per epoch to ``calc_grad_norm``.
            resolution: number of grid points per latent axis (K = resolution**2).
            labels: per-sample annotations; defaults to 1..N.
            fig_size: (width, height) of the figure created by ``draw_umatrix``.
            title_text: plot title; " epoch=<t>" is appended when animating.
            cmap_type: colormap name passed to ``imshow``.
            interpolation_method: interpolation passed to ``imshow``.
            repeat: whether the animation loops.
            interval: delay between animation frames (``FuncAnimation`` interval).

        Raises:
            ValueError: if an input is missing or the shapes are inconsistent.
        """
        # Reject missing inputs up front.
        if Z is None:
            raise ValueError('please input winner point')
        if X is None:
            raise ValueError('input observed data X')
        if resolution is None:
            raise ValueError('please input resolution')
        # set input
        self.X = X
        # Latent variables are always stored as (T, N, L).
        if Z.ndim == 3:
            self.Z_allepoch = Z  # shape=(T,N,L)
            self.isStillImage = False
        elif Z.ndim == 2:
            self.Z_allepoch = Z[None, :, :]  # shape=(1,N,L)
            self.isStillImage = True
        else:
            raise ValueError('Z must be 2d array (N,L) or 3d array (T,N,L)')
        # sigma is stored as one value per epoch (1-D, length T).
        # np.ndim is used (rather than isinstance(sigma, (int, float))) so that
        # numpy scalars and 0-d arrays are accepted as scalars too.
        if sigma is None:
            raise ValueError('please input sigma')
        sigma_ndim = np.ndim(sigma)
        if sigma_ndim == 0:
            if not self.isStillImage:
                raise ValueError("if Z is 3d array, sigma must be 1d array")
            self.sigma_allepoch = np.array([sigma, ])
        elif sigma_ndim == 1:
            if self.isStillImage:
                raise ValueError("if Z is 2d array, sigma must be scalar")
            self.sigma_allepoch = np.asarray(sigma)  # shape=(T)
        else:
            raise ValueError("sigma must be 1d array or scalar")
        # Without this check a short sigma only fails mid-animation (IndexError).
        if self.sigma_allepoch.shape[0] != self.Z_allepoch.shape[0]:
            raise ValueError("sigma's length and Z's epoch number must match")
        self.resolution = resolution
        self.N = self.X.shape[0]
        if self.N != self.Z_allepoch.shape[1]:
            raise ValueError("Z's sample number and X's one must be match")
        self.T = self.Z_allepoch.shape[0]
        self.L = self.Z_allepoch.shape[2]
        if self.L != 2:
            raise ValueError('latent variable must be 2dim')
        self.K = resolution ** self.L
        # Plot settings. The figure/axes are created lazily by draw_umatrix
        # unless the caller has assigned self.Map (and optionally self.Fig).
        self.fig_size = fig_size
        self.Fig = None
        self.Map = None
        self.Cmap_type = cmap_type
        self.labels = labels
        if self.labels is None:
            self.labels = np.arange(self.N) + 1
        self.interpolation_method = interpolation_method
        self.title_text = title_text
        self.repeat = repeat
        self.interval = interval
        # Representative latent points: a resolution x resolution grid covering
        # the bounding box of Z over all epochs, flattened to shape (K, L).
        self.Zeta = np.meshgrid(
            np.linspace(self.Z_allepoch[:, :, 0].min(), self.Z_allepoch[:, :, 0].max(), self.resolution),
            np.linspace(self.Z_allepoch[:, :, 1].min(), self.Z_allepoch[:, :, 1].max(), self.resolution))
        self.Zeta = np.dstack(self.Zeta).reshape(self.K, self.L)

    def _grad_norm_at(self, epoch):
        """Return ``(Z, dY_std)`` for one epoch.

        ``Z`` is that epoch's (N, L) latent positions. ``dY_std`` holds the K
        values from ``calc_grad_norm`` evaluated on ``self.Zeta``; per the
        upstream note they lie in [-2.0, 2.0].
        """
        Z = self.Z_allepoch[epoch, :, :]
        sigma = self.sigma_allepoch[epoch]
        dY_std = calc_grad_norm(Zeta=self.Zeta, Z=Z, X=self.X, sigma=sigma)
        return Z, dY_std

    def _stack_counts(self, Z):
        """For each sample i, count the earlier samples at (nearly) the same point."""
        return [sum(np.allclose(Z[j, :], Z[i, :]) for j in range(i))
                for i in range(self.N)]

    def _draw_labels(self, Z):
        """Draw one label per sample; coincident samples are stacked upward by ``self.epsilon``."""
        texts = []
        for i, count in enumerate(self._stack_counts(Z)):
            texts.append(self.Map.text(Z[i, 0], Z[i, 1] + self.epsilon * count,
                                       self.labels[i], color='k'))
        return texts

    def calc_umatrix(self):
        """Compute the first epoch's grad-norm map without drawing anything.

        Returns:
            tuple: ``(dY_std, resolution, Zeta)``. ``dY_std`` holds the K map
            values as returned by ``calc_grad_norm``, ``resolution`` is the
            grid side length, and ``Zeta`` is the (K, L) grid of latent points.
        """
        _, dY_std = self._grad_norm_at(0)
        return dY_std, self.resolution, self.Zeta

    def draw_umatrix(self):
        """Show the grad-norm map (animated over epochs when Z is 3-D).

        Creates the figure if needed, draws epoch 0, then hands control to a
        ``FuncAnimation`` driving ``update`` and blocks in ``plt.show()``.
        """
        # Imported here because the module-level matplotlib imports are
        # disabled; this keeps the module importable without matplotlib.
        import matplotlib.animation
        import matplotlib.pyplot as plt

        if self.Map is None:
            self.Fig = plt.figure(figsize=(self.fig_size[0], self.fig_size[1]))
            self.Map = self.Fig.add_subplot(1, 1, 1)
        elif self.Fig is None:
            self.Fig = self.Map.figure

        # Initial state (epoch 0).
        Z, dY_std = self._grad_norm_at(0)
        U_matrix_val = dY_std.reshape((self.resolution, self.resolution))
        self.Map.set_title(self.title_text)
        # Fixed colour range matches the documented [-2, 2] value range.
        self.Im = self.Map.imshow(U_matrix_val, interpolation=self.interpolation_method,
                                  extent=[self.Zeta[:, 0].min(), self.Zeta[:, 0].max(),
                                          self.Zeta[:, 1].max(), self.Zeta[:, 1].min()],
                                  cmap=self.Cmap_type, vmax=2.0, vmin=-2.0, animated=True)
        self.Scat = self.Map.scatter(x=Z[:, 0], y=Z[:, 1], c='k')
        # Vertical offset (4% of the overall latent range) used to separate
        # labels of samples that share a winner position.
        self.epsilon = 0.04 * (self.Z_allepoch.max() - self.Z_allepoch.min())
        self.Label_Texts = self._draw_labels(Z)
        # Keep a reference on self so the animation is not garbage-collected.
        self.ani = matplotlib.animation.FuncAnimation(self.Fig, self.update, interval=self.interval,
                                                      blit=False, repeat=self.repeat, frames=self.T)
        plt.show()

    def update(self, epoch):
        """Animation callback: redraw the map, points and labels for ``epoch``."""
        Z, dY_std = self._grad_norm_at(epoch)
        self.Im.set_array(dY_std.reshape((self.resolution, self.resolution)))
        self.Scat.set_offsets(Z)
        for text in self.Label_Texts:
            text.remove()
        self.Label_Texts = self._draw_labels(Z)
        if not self.isStillImage:
            self.Map.set_title(self.title_text + ' epoch={}'.format(epoch))