-
Notifications
You must be signed in to change notification settings - Fork 1
/
coin_export_stats.py
51 lines (37 loc) · 1.31 KB
/
coin_export_stats.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import numpy
import os
import json
import torch
import sys
from skimage import io
from losses.psnr import psnr
from pytorch_msssim import ms_ssim, ssim
from utils import calculate_state_dict_size, load_device
from torch._C import dtype
from typing import Dict
def ms_ssim_reshape(tensor):
    """Convert an image tensor to the batched channels-first layout used by pytorch_msssim.

    ``ssim``/``ms_ssim`` from pytorch_msssim take ``(N, C, H, W)`` tensors.
    For channels-last input such as ``(H, W, C)``, the last axis is moved to
    the front and a batch axis of size 1 is prepended, giving ``(1, C, H, W)``.

    A 2-D tensor is treated as a single-channel ``(H, W)`` image and becomes
    ``(1, 1, H, W)``. Without this case its width axis would be mistaken for
    the channel axis, yielding a 3-D tensor that ``ssim`` rejects.

    Args:
        tensor: Image tensor, either ``(H, W)`` or channels-last ``(H, W, C)``.

    Returns:
        Tensor of shape ``(1, C, H, W)`` (``C == 1`` for 2-D input).
    """
    if tensor.dim() == 2:
        # Grayscale: there is no channel axis to move, so insert one.
        return tensor.unsqueeze(0).unsqueeze(0)
    return tensor.movedim(-1, 0).unsqueeze(0)
def _load_image_tensor(path, device):
    """Read the image at ``path`` and return it as a float32 tensor on ``device``.

    Pixel values are only cast to float32, not rescaled. The array layout
    returned by ``io.imread`` is kept unchanged.
    """
    return torch.from_numpy(io.imread(path)).to(device=device, dtype=torch.float32)


def main():
    """Compute quality statistics for a reconstructed image and write them as JSON.

    Command-line arguments (read from ``sys.argv``):
        1. path to the original image
        2. path to the reconstructed image
        3. path of the JSON file the statistics are written to
        4. path to the compressed state, loaded with ``torch.load``

    The JSON object has the keys ``"psnr"``, ``"ms-ssim"``, ``"ssim"`` and
    ``"bpp"``. ``"ms-ssim"`` and ``"bpp"`` are currently always ``null``.

    Exits with an error message if fewer than four arguments are given or if
    the two images do not have the same shape.
    """
    if len(sys.argv) < 5:
        # Fail fast with a usage message instead of an IndexError further down.
        sys.exit(
            f"usage: {sys.argv[0]} ORIGINAL_IMAGE RECONSTRUCTED_IMAGE "
            "STATS_JSON COMPRESSED_STATE"
        )

    print("Loading device...")
    device = load_device()

    print("Loading parameters...")
    original_file_path = sys.argv[1]
    reconstructed_file_path = sys.argv[2]
    stats_path = sys.argv[3]
    compressed_state_path = sys.argv[4]

    print("Calculating compressed state size...")
    # NOTE(review): torch.load unpickles the file; only pass trusted paths.
    # map_location="cpu" lets a state saved from a GPU load on a CPU-only
    # machine, and the state is not used for any tensor computation here.
    # TODO: state_dict is not used yet, so "bpp" below stays None. The
    # calculate_state_dict_size import at the top of the file looks intended
    # for this; confirm its units before dividing by the pixel count.
    state_dict = torch.load(compressed_state_path, map_location="cpu")

    print("Loading images...")
    original_image_tensor = _load_image_tensor(original_file_path, device)
    reconstructed_image_tensor = _load_image_tensor(reconstructed_file_path, device)
    if original_image_tensor.shape != reconstructed_image_tensor.shape:
        sys.exit(
            "Image shapes differ: "
            f"{tuple(original_image_tensor.shape)} vs "
            f"{tuple(reconstructed_image_tensor.shape)}"
        )

    print("Calculating stats...")
    stats = {
        "psnr": psnr(original_image_tensor, reconstructed_image_tensor).item(),
        # TODO: ms_ssim is imported but not computed here; reported as None.
        "ms-ssim": None,
        # NOTE(review): pytorch_msssim's default data_range is 255, which
        # assumes 8-bit images; confirm for other bit depths.
        "ssim": ssim(
            ms_ssim_reshape(original_image_tensor),
            ms_ssim_reshape(reconstructed_image_tensor),
        ).item(),
        "bpp": None,
    }
    print(stats)

    # The context manager guarantees the file is flushed and closed.
    with open(stats_path, "w") as stats_file:
        json.dump(stats, stats_file)


if __name__ == "__main__":
    main()