eval_fsq.py
import torch
from arguments import get_args
from model import VQVAE
from dataset import get_transform
from torchvision import datasets, transforms
from lpips import LPIPS
from metric import get_revd_perceptual
from util import multiplyList
from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision.utils import save_image, make_grid

def main():
    args = get_args()
    assert args.quantizer == 'fsq'

    # 1. load dataset
    imagenet_transform = get_transform(args)
    val_set = datasets.ImageFolder(args.val_data_path, imagenet_transform)
    val_data_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        drop_last=False,
        shuffle=False
    )
    transform_rev = transforms.Normalize(
        [-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
        [1. / 0.229, 1. / 0.224, 1. / 0.225]
    )
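    # transform_rev undoes the standard ImageNet normalization x_norm = (x - mean) / std:
    # Normalize(-mean/std, 1/std) maps a normalized image back to
    # x_norm * std + mean = x, i.e. (approximately) the [0, 1] range that the
    # FID updates below expect.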

    # 2. load model
    model = VQVAE(args)
    model.cuda(torch.cuda.current_device())

    # original saved file with DataParallel
    state_dict = torch.load(args.load)['model_state_dict']
    # create new OrderedDict that does not contain `module.`
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove `module.`
        new_state_dict[name] = v
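    # Stripping the first 7 characters assumes every checkpoint key carries the
    # DataParallel `module.` prefix. A more defensive (hypothetical) variant would be:
    #   name = k.replace('module.', '', 1) if k.startswith('module.') else k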
    # load params
    model.load_state_dict(new_state_dict)
    model.eval()

    # load perceptual model
    perceptual_model = LPIPS().eval()
    perceptual_model.cuda(torch.cuda.current_device())
    get_l1_loss = torch.nn.L1Loss()

    # for FID
    fid = FrechetInceptionDistance(feature=2048, normalize=True)
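    # With normalize=True, torchmetrics' FrechetInceptionDistance expects float
    # images scaled to [0, 1] rather than uint8 in [0, 255], which is why the
    # inputs are passed through transform_rev before each fid.update call below.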

    # for computing codebook usage
    num_embed = multiplyList(args.levels)
    codebook_usage = set()
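    # FSQ has no learned codebook: each latent dimension is quantized to one of
    # `levels[d]` values, so the implicit codebook size is the product of the
    # levels (e.g. levels = [8, 5, 5, 5] gives 8*5*5*5 = 1000 codes).
    # `multiplyList` (from util) is assumed to compute exactly that product;
    # a minimal equivalent sketch would be math.prod(args.levels).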

    total_l1_loss = 0
    total_per_loss = 0
    num_iter = 0

    for i, (input_img, _) in enumerate(val_data_loader):
        # forward
        num_iter += 1
        print(num_iter * args.batch_size)
        with torch.no_grad():
            input_img = input_img.cuda(torch.cuda.current_device())
            reconstructions, codebook_loss, ids = model(input_img, return_id=True)
            # save_image(make_grid(torch.cat([input_img, reconstructions]), nrow=input_img.shape[0]), 'figures/' + str(num_embed) + '.jpg', normalize=True)
            # exit()
            ids = torch.flatten(ids)
            for quan_id in ids:
                codebook_usage.add(quan_id.item())
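            # Each entry of `ids` is the flat index of the code assigned to one
            # latent position; collecting them in a set counts distinct codes.
            # An equivalent, typically faster alternative would be:
            #   codebook_usage.update(ids.cpu().tolist())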

            # compute L1 loss and perceptual loss
            perceptual_loss = get_revd_perceptual(input_img, reconstructions, perceptual_model)
            l1loss = get_l1_loss(input_img, reconstructions)
            total_l1_loss += l1loss.cpu().item()
            total_per_loss += perceptual_loss.cpu().item()

            input_img = transform_rev(input_img.contiguous())
            reconstructions = transform_rev(reconstructions.contiguous())
            fid.update(input_img.cpu(), real=True)
            fid.update(reconstructions.cpu(), real=False)
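
    # Reported metrics: L1 and LPIPS losses are averaged per batch (num_iter
    # batches), FID is computed over all real/reconstructed images seen, and
    # codebook usage is the fraction of the implicit FSQ codebook that was
    # assigned at least once on the validation set.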
    print('fid score:', fid.compute())
    print('l1loss:', total_l1_loss / num_iter)
    print('percep_loss:', total_per_loss / num_iter)
    print('codebook usage:', len(codebook_usage) / num_embed)


if __name__ == "__main__":
    main()
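
# Hypothetical invocation (the actual flag names depend on arguments.get_args(),
# which is not shown here; the ones below simply mirror the attribute names used
# above and are only a sketch):
#   python eval_fsq.py --quantizer fsq --val_data_path /path/to/imagenet/val \
#       --load /path/to/checkpoint.pt --batch_size 32 --num_workers 8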