-
Notifications
You must be signed in to change notification settings - Fork 2
/
plot_tsne.py
61 lines (43 loc) · 1.99 KB
/
plot_tsne.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import random

import clip
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from PIL import Image
from sklearn.manifold import TSNE
def seed_everything(seed=0):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator (which
    scikit-learn estimators such as ``TSNE`` fall back to when no
    ``random_state`` is passed), and PyTorch on the CPU and on all CUDA
    devices. It also puts cuDNN into deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed applied to every generator. Defaults to 0.
    """
    # Hash randomisation of the *running* interpreter is fixed at startup, so
    # this only takes effect for child processes launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every visible GPU, not just the current one; the
    # call is deferred (no error) when CUDA has not been initialised.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode times candidate algorithms and may choose different ones
    # from run to run, which defeats `deterministic`.
    torch.backends.cudnn.benchmark = False
def _encode_images(images, model, transform, device, batch_size):
    """Transform and encode ``images`` in chunks of ``batch_size``.

    Returns a float32 NumPy array holding the stacked ``encode_image`` outputs
    in input order. It assumes ``encode_image`` yields one feature row per input
    image (TODO confirm for non-CLIP models). Chunking bounds device memory
    instead of pushing every image through the model in one batch.
    """
    chunks = []
    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            batch = torch.stack(
                [transform(img) for img in images[start:start + batch_size]]
            ).to(device)
            # Upcast before leaving torch in case the model runs in half
            # precision; NumPy/sklearn work is then done in float32.
            chunks.append(model.encode_image(batch).float().cpu().numpy())
    return np.concatenate(chunks)


def plottsne(classreal, classfake, model, transform, device,
             out_path="tsne.png", batch_size=64):
    """Embed two image sets with ``model`` and save a 2-D t-SNE scatter plot.

    Args:
        classreal: Images plotted in green as "REAL"; each item must be
            accepted by ``transform``.
        classfake: Images plotted in red as "FAKE".
        model: Object exposing ``encode_image(batch_tensor)``, e.g. a CLIP model.
        transform: Callable mapping one image to an input tensor for ``model``.
        device: Torch device the input batches are moved to.
        out_path: File the figure is saved to. Defaults to ``"tsne.png"``.
        batch_size: Number of images encoded per forward pass.

    Returns:
        NumPy array of shape ``(len(classreal) + len(classfake), 2)`` with the
        t-SNE coordinates, real images first.

    Raises:
        ValueError: If either image set is empty or ``batch_size`` < 1.
    """
    classreal = list(classreal)
    classfake = list(classfake)
    if not classreal or not classfake:
        raise ValueError("plottsne needs at least one REAL and one FAKE image")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    print('transforming and encoding images')
    real_features = _encode_images(classreal, model, transform, device, batch_size)
    fake_features = _encode_images(classfake, model, transform, device, batch_size)

    print('TSNE embedding')
    all_features = np.vstack([real_features, fake_features])
    # scikit-learn requires perplexity < n_samples; its default of 30 would
    # make small image sets crash, so cap it at n_samples - 1.
    perplexity = float(min(30, len(all_features) - 1))
    embedded = TSNE(n_components=2, perplexity=perplexity).fit_transform(all_features)

    print('plotting')
    n_real = len(real_features)
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.scatter(embedded[:n_real, 0], embedded[:n_real, 1], color='g', label='REAL')
    ax.scatter(embedded[n_real:, 0], embedded[n_real:, 1], color='r', label='FAKE')
    ax.legend()
    ax.set_title("t-SNE plot")
    fig.savefig(out_path)
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)
    return embedded
def _load_images(directory):
    """Return in-memory copies of every readable image file directly in ``directory``.

    Files are visited in sorted order, so the sample order (and hence the
    seeded t-SNE result) does not depend on ``os.listdir`` ordering.
    Subdirectories are skipped. Files PIL cannot open are skipped, and a
    message is printed for each one.

    Raises:
        ValueError: If no readable image is found.
    """
    images = []
    for name in sorted(os.listdir(directory)):
        path = os.path.join(directory, name)
        if not os.path.isfile(path):
            continue
        try:
            with Image.open(path) as img:
                # copy() decodes the pixels into memory, so the file handle is
                # closed right away instead of staying open until GC; one
                # handle per image otherwise hits the open-file limit.
                images.append(img.copy())
        except OSError as err:  # PIL's UnidentifiedImageError subclasses OSError
            print(f'skipping {path}: {err}')
    if not images:
        raise ValueError(f"no readable images found in {directory!r}")
    return images


def main():
    """Load CLIP ViT-L/14, load the REAL and FAKE image folders, and plot their t-SNE."""
    seed_everything()
    # Fall back to the CPU so the script still runs on machines without a GPU.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model, transform = clip.load("ViT-L/14", device=device)
    pathreal = "<path to real images>"
    classreal = _load_images(pathreal)
    pathfake = "<path to fake images>"
    classfake = _load_images(pathfake)
    plottsne(classreal, classfake, model, transform, device)
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not on import.
    main()