-
Notifications
You must be signed in to change notification settings - Fork 35
/
utils.py
97 lines (79 loc) · 3.29 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from __future__ import print_function
import os
import math
import json
import logging
import numpy as np
from PIL import Image
from datetime import datetime
from shutil import copy2
from glob import glob
def prepare_dirs_and_logger(config):
    """Configure the root logger and create the run's output directories.

    Side effects on ``config`` and the filesystem:
      * every handler on the root logger is replaced by a single
        ``StreamHandler`` using the ``asctime:levelname::message`` format;
      * ``config.model_name`` is set to ``"<task>_" + get_time()``;
      * ``config.model_dir`` is set to ``log_dir/model_name`` unless the
        attribute already exists;
      * ``data_path``, ``syn_data_dir``, ``dataset_3dmm_dir`` and
        ``dataset_3dmm_test_dir`` are joined under ``config.data_dir``;
      * ``log_dir``, ``data_dir``, ``model_dir`` and ``model_dir/src`` are
        created if missing, and every ``*.py`` file in the current working
        directory is copied into ``model_dir/src`` as a source snapshot.

    Args:
        config: attribute-style configuration object (e.g. an argparse
            namespace) providing ``task``, ``log_dir``, ``data_dir``,
            ``dataset``, ``syn_dataset``, ``dataset_3dmm`` and
            ``dataset_3dmm_test``.
    """
    formatter = logging.Formatter("%(asctime)s:%(levelname)s::%(message)s")
    logger = logging.getLogger()
    # Iterate over a copy: removing from the live ``handlers`` list while
    # iterating it skips every other handler, leaving some attached.
    for hdlr in list(logger.handlers):
        logger.removeHandler(hdlr)
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger.addHandler(handler)

    # NOTE: deriving model_name from config.load_path (to resume a previous
    # run) is disabled; a fresh timestamped name is always generated.
    config.model_name = "{}_{}".format(config.task, get_time())

    if not hasattr(config, 'model_dir'):
        config.model_dir = os.path.join(config.log_dir, config.model_name)
    config.data_path = os.path.join(config.data_dir, config.dataset)
    config.syn_data_dir = os.path.join(config.data_dir, config.syn_dataset)
    config.dataset_3dmm_dir = os.path.join(config.data_dir, config.dataset_3dmm)
    config.dataset_3dmm_test_dir = os.path.join(config.data_dir, config.dataset_3dmm_test)

    for path in [config.log_dir, config.data_dir, config.model_dir]:
        if not os.path.exists(path):
            os.makedirs(path)

    # Snapshot the training scripts next to the model outputs.
    src_path = os.path.join(config.model_dir, 'src')
    if not os.path.exists(src_path):
        os.makedirs(src_path)
    for path in glob("*.py"):
        copy2(path, src_path)
def get_time():
    """Return the current local time as a ``MMDD_HHMMSS`` string.

    Used as the timestamp suffix of generated model names.
    """
    timestamp_format = "%m%d_%H%M%S"
    current = datetime.now()
    return current.strftime(timestamp_format)
def save_config(config):
    """Write all of ``config``'s attributes to ``<model_dir>/params.json``.

    The JSON is pretty-printed with 4-space indentation and sorted keys.
    The target directory and file path are echoed to stdout first.
    """
    model_dir = config.model_dir
    param_path = os.path.join(model_dir, "params.json")

    print("[*] MODEL dir: %s" % model_dir)
    print("[*] PARAM path: %s" % param_path)

    params = config.__dict__
    with open(param_path, 'w') as fp:
        json.dump(params, fp, indent=4, sort_keys=True)
def rank(array):
    """Return the number of dimensions of ``array``, i.e. ``len(array.shape)``."""
    dims = array.shape
    return len(dims)
def _normalize_to_uint8(tensor, scale_each):
    """Min-max rescale ``tensor`` (images along axis 0) to uint8 values in [0, 255]."""
    values = tensor.astype(np.float64)
    if scale_each:
        # One (min, max) pair per image: reduce over every axis except 0.
        axes = tuple(range(1, values.ndim))
        low = values.min(axis=axes, keepdims=True)
        high = values.max(axis=axes, keepdims=True)
    else:
        low, high = values.min(), values.max()
    # Floor the range (as torchvision does) so constant images don't divide by 0.
    scaled = (values - low) / np.maximum(high - low, 1e-5) * 255.0
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)


def make_grid(tensor, nrow=8, padding=2,
              normalize=False, scale_each=False):
    """Tile a batch of images into a single uint8 RGB grid image.

    Code based on https://github.com/pytorch/vision/blob/master/torchvision/utils.py

    Args:
        tensor: array-like of images stacked along axis 0, either
            ``(N, H, W, C)`` with ``C`` broadcastable to 3 (i.e. 1 or 3),
            or ``(N, H, W)`` grayscale, which is broadcast to 3 channels.
        nrow: maximum number of images per grid row.
        padding: pixels of spacing between neighbouring cells.
        normalize: if True, linearly rescale pixel values to [0, 255]
            using the batch-wide min/max before tiling.
        scale_each: when ``normalize`` is True, use each image's own
            min/max instead of the batch-wide one. Ignored otherwise.

    Returns:
        ``np.uint8`` array of shape
        ``(rows * (H + padding) + 1 + padding // 2,
        cols * (W + padding) + 1 + padding // 2, 3)``, images laid out in
        row-major order on a black background. Without ``normalize`` the
        values are cast to uint8 as-is, so inputs are expected to already be
        in [0, 255].

    Raises:
        ValueError: if ``tensor`` contains no images.
    """
    tensor = np.asarray(tensor)
    if tensor.ndim == 3:
        # Grayscale batch: add a channel axis so each image broadcasts to RGB.
        tensor = tensor[..., np.newaxis]

    nmaps = tensor.shape[0]
    if nmaps == 0:
        raise ValueError("make_grid() requires at least one image")

    if normalize:
        tensor = _normalize_to_uint8(tensor, scale_each)

    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.shape[1] + padding), int(tensor.shape[2] + padding)
    # Leading border before the first row/column of cells.
    offset = 1 + padding // 2
    grid = np.zeros([height * ymaps + offset, width * xmaps + offset, 3], dtype=np.uint8)
    cell_h, cell_w = height - padding, width - padding
    for k in range(nmaps):
        y, x = divmod(k, xmaps)
        h = y * height + offset
        w = x * width + offset
        grid[h:h + cell_h, w:w + cell_w] = tensor[k]
    return grid
def save_image(tensor, filename, nrow=8, padding=2,
               normalize=False, scale_each=False):
    """Tile a batch of images with ``make_grid`` and save the grid to ``filename``.

    The layout arguments are forwarded unchanged to ``make_grid``; the
    resulting array is written through ``PIL.Image`` with ``quality=100``.
    """
    grid_options = dict(nrow=nrow, padding=padding,
                        normalize=normalize, scale_each=scale_each)
    grid = make_grid(tensor, **grid_options)
    Image.fromarray(grid).save(filename, quality=100)
def save_one_image(tensor, filename):
    """Write a single image array to ``filename`` using PIL.

    Args:
        tensor: array accepted by ``PIL.Image.fromarray`` — presumably an
            ``(H, W, 3)`` or ``(H, W)`` uint8 array; TODO confirm with callers.
        filename: destination path passed to ``Image.save`` (saved with
            ``quality=100``, matching ``save_image``).
    """
    # Was ``Image.fr (tensor)``, which raised AttributeError on every call.
    im = Image.fromarray(tensor)
    im.save(filename, quality=100)