-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
72 lines (58 loc) · 2.08 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import argparse
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import matplotlib
import torch
import numpy as np
def args_parser(argv=None):
    """Build the command-line interface for the t-SNE demo and parse it.

    :param argv: list of argument strings to parse; None (the default)
        parses sys.argv[1:], exactly as the original did.
    :return: argparse.Namespace with the attributes perplexity, data,
        label_state, class_number and label.
    """
    parser = argparse.ArgumentParser(description='tsne instruction and visualization')
    # '-perplexity' is kept for backward compatibility; '--perplexity' is the
    # conventional spelling. type=float so a value given on the command line
    # is a number rather than a string.
    parser.add_argument('-perplexity', '--perplexity', type=float, default=10,
                        help='t-SNE perplexity (default: 10)')
    # The random defaults are demo data. A value given on the command line is
    # treated as a path to a .npy file and loaded with np.load (argparse does
    # not apply `type` to non-string defaults, so the defaults are untouched).
    parser.add_argument('--data', type=np.load, default=torch.randn(500, 10),
                        help='path to a samples x features .npy file')
    parser.add_argument('--label_state', default='no', choices=('yes', 'no'),
                        help="'yes' to colour points by --label, 'no' otherwise")
    parser.add_argument('--class_number', type=int, default=4,
                        help='number of distinct classes in --label')
    parser.add_argument('--label', type=np.load,
                        default=np.random.randint(0, 4, [500, 1]),
                        help='path to a .npy file of integer class labels')
    return parser.parse_args(argv)
def tsne(all_fea, perplexity=10, n_components=2, random_state=0):
    """Embed samples into a low-dimensional space with scikit-learn's t-SNE.

    :param all_fea: samples x features data (numpy array, nested list, or
        torch tensor; tensors are detached and copied to the CPU first).
    :param perplexity: t-SNE perplexity, related to how many neighbours each
        point considers; recent scikit-learn versions reject a value that is
        not smaller than the number of samples.
    :param n_components: output dimensionality (2 for a planar plot).
    :param random_state: seed so repeated runs produce the same layout.
    :return: numpy array of shape (samples, n_components).
    """
    # A tensor that requires grad or lives on a GPU cannot be converted to
    # numpy directly, so detach it and move it to the CPU first.
    if isinstance(all_fea, torch.Tensor):
        all_fea = all_fea.detach().cpu().numpy()
    features = np.asarray(all_fea)
    model = TSNE(n_components=n_components, perplexity=perplexity,
                 random_state=random_state)
    return model.fit_transform(features)


def draw(x, label_state, label, classifcation):
    """Scatter-plot a 2-D embedding, optionally coloured by class label.

    :param x: samples x 2 (or more) coordinates; only the first two columns
        are plotted.
    :param label_state: 'yes' to colour points by `label`, 'no' to draw every
        point in red.
    :param label: integer class labels, shape (samples,) or (samples, 1);
        ignored when label_state is 'no'.
    :param classifcation: number of classes; labels are expected to lie in
        0..classifcation-1.
    :return: None. The figure is displayed with plt.show().
    :raises ValueError: if label_state is neither 'yes' nor 'no', or the
        number of labels differs from the number of samples.
    """
    points = np.asarray(x)
    if label_state == 'no':
        plt.scatter(points[:, 0], points[:, 1], c='r')
    elif label_state == 'yes':
        # Accept both (n,) and (n, 1) label arrays.
        labels = np.asarray(label).reshape(-1).astype(int)
        if labels.shape[0] != points.shape[0]:
            raise ValueError(
                f'got {labels.shape[0]} labels for {points.shape[0]} samples')
        n_classes = max(int(classifcation), 1)
        # Qualitative colormaps give clearly distinguishable colours; taking
        # the first entries of the named-colour table could pick near-white
        # shades that vanish on a white background.
        if n_classes <= 10:
            cmap_name = 'tab10'
        elif n_classes <= 20:
            cmap_name = 'tab20'
        else:
            cmap_name = 'gist_rainbow'
        cmap = plt.get_cmap(cmap_name, n_classes)
        # One vectorised scatter call instead of one artist per point;
        # vmin/vmax centre each integer label in its own colour bin.
        plt.scatter(points[:, 0], points[:, 1], c=labels, cmap=cmap,
                    vmin=-0.5, vmax=n_classes - 0.5)
    else:
        raise ValueError(
            f"label_state must be 'yes' or 'no', got {label_state!r}")
    plt.show()


def start(args):
    """Run t-SNE on args.data and plot the resulting embedding.

    :param args: namespace carrying data, label_state, label, class_number
        and (optionally) perplexity, as produced by args_parser().
    """
    # float() guards against a perplexity that arrived as a CLI string;
    # getattr keeps hand-built namespaces without the field working.
    perplexity = float(getattr(args, 'perplexity', 10))
    embedding = tsne(args.data, perplexity=perplexity)
    draw(embedding, args.label_state, args.label, args.class_number)
if __name__ == '__main__':
    # Script entry point: read the command line, then embed and plot.
    args = args_parser()
    start(args)