-
Notifications
You must be signed in to change notification settings - Fork 2
/
tsne_visual.py
72 lines (48 loc) · 1.9 KB
/
tsne_visual.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np
import matplotlib
def sample_per_class(embeddings, labels, per_class=1000, seed=0):
    """Draw up to ``per_class`` random samples from every class.

    Args:
        embeddings: array whose first axis indexes samples.
        labels: one class id per embedding row, shaped ``(N,)`` or ``(N, 1)``.
        per_class: maximum number of samples kept per class. A class with fewer
            samples is kept in full instead of raising.
        seed: seed for a legacy ``np.random.RandomState``. With ``seed=0`` the
            draws are identical to ``np.random.seed(0)`` followed by
            ``np.random.choice``, without touching the global RNG.

    Returns:
        ``(sampled_embeddings, sampled_labels)``, concatenated class by class in
        ascending label order. ``sampled_labels`` keeps the trailing shape of
        ``labels`` (e.g. ``(k, 1)`` stays ``(k, 1)``).

    Raises:
        ValueError: if ``labels`` does not hold exactly one value per row of
            ``embeddings``.
    """
    embeddings = np.asarray(embeddings)
    labels = np.asarray(labels)
    n_rows = embeddings.shape[0]
    if labels.ndim == 0 or labels.shape[0] != n_rows or labels.size != n_rows:
        raise ValueError(
            f"expected one label per embedding row ({n_rows}), "
            f"got labels of shape {labels.shape}"
        )
    # 1-D view used for masking; works for both (N,) and (N, 1) label arrays.
    flat_labels = labels.reshape(n_rows)

    rng = np.random.RandomState(seed)
    picked_embeddings = []
    picked_labels = []
    for key in np.unique(flat_labels):
        mask = flat_labels == key
        class_embed = embeddings[mask]
        class_labels = labels[mask]
        num = class_embed.shape[0]
        # Cap the draw so small classes do not make choice(..., replace=False) raise.
        chosen = rng.choice(num, size=min(per_class, num), replace=False)
        picked_embeddings.append(class_embed[chosen])
        picked_labels.append(class_labels[chosen])

    if not picked_embeddings:
        # No samples at all: np.concatenate([]) would raise, so return empty slices.
        return embeddings[:0], labels[:0]
    return (
        np.concatenate(picked_embeddings, axis=0),
        np.concatenate(picked_labels, axis=0),
    )


if __name__ == "__main__":
    embeddings = np.load('seg_embeddings.npy')
    labels = np.load('seg_labels.npy')
    overall_embeddings, overall_labels = sample_per_class(
        embeddings, labels, per_class=1000, seed=0
    )
    np.save('seg_class_embed.npy', overall_embeddings)
    np.save('seg_class_label.npy', overall_labels)
    print(overall_embeddings.shape)
    print(overall_labels.shape)
# Scatter colors, assigned to classes in ascending label order.
TSNE_COLORS = (
    'black', 'red', 'gold', 'green', 'orange', 'pink', 'magenta', 'slategray',
    'greenyellow', 'lightgreen', 'brown', 'chocolate', 'mediumvioletred', 'navy',
    'lightseagreen', 'aqua', 'olive', 'maroon', 'yellow',
)


def build_color_dict(targets, palette=TSNE_COLORS):
    """Map every distinct target value to a color name.

    Values are taken in ascending order (``np.unique``), so the class-to-color
    assignment is deterministic. When there are more classes than colors, the
    palette wraps around instead of raising ``IndexError``.

    Args:
        targets: array-like of class ids.
        palette: non-empty sequence of matplotlib color specs.

    Returns:
        dict mapping each unique target value to its color.

    Raises:
        ValueError: if ``palette`` is empty.
    """
    if len(palette) == 0:
        raise ValueError("palette must contain at least one color")
    return {
        value: palette[i % len(palette)]
        for i, value in enumerate(np.unique(targets))
    }


if __name__ == "__main__":
    matplotlib.rcParams['font.family'] = 'Times New Roman'
    matplotlib.rcParams['font.size'] = 30
    # TSNE below gets no random_state, so (per sklearn's random_state=None
    # semantics) it draws from NumPy's global RNG; this seed pins the layout.
    np.random.seed(1968081)
    net1_embeddings = np.load('seg_class_embed.npy')  # size: (sample_num, embedding_dim)
    # Labels may have been saved as (N, 1) or (N,); flatten so masks are 1-D.
    net1_target = np.load('seg_class_label.npy').reshape(-1)
    color_dict = build_color_dict(net1_target)
    print(color_dict)
    net1 = TSNE(early_exaggeration=100).fit_transform(net1_embeddings)
    # Cache the 2-D layout so it can be re-plotted without re-running t-SNE.
    np.save('tsne.npy', net1)
    for value, color in color_dict.items():
        cluster = net1[net1_target == value]
        plt.scatter(cluster[:, 0], cluster[:, 1], c=color, marker='o', s=7)
    plt.xticks([])
    plt.yticks([])
    plt.tight_layout()
    plt.savefig("tsne.png")