# -*- coding: utf-8 -*-
"""
-----------------------------------------------------------------------------
Copyright 2017 David Griffis
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-----------------------------------------------------------------------------
"""
from __future__ import division

import json
import logging
import math

import numpy as np
import torch
from torch.autograd import Variable


def setup_logger(logger_name, log_file, level=logging.INFO):
    """Create a logger that writes to both `log_file` and the console."""
    logger = logging.getLogger(logger_name)
    formatter = logging.Formatter('%(asctime)s : %(message)s')
    file_handler = logging.FileHandler(log_file, mode='w')
    file_handler.setFormatter(formatter)
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.setLevel(level)
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)


def read_config(file_path):
    """Read a JSON config file and return its contents."""
    with open(file_path, 'r') as f:
        return json.load(f)


def norm_col_init(weights, std=1.0):
    """Return a random tensor shaped like `weights`, rescaled so that
    each row has L2 norm `std` (normalized-columns initialization)."""
    x = torch.randn(weights.size())
    x *= std / torch.sqrt((x ** 2).sum(1, keepdim=True))
    return x


def ensure_shared_grads(model, shared_model, gpu=False):
    """Point the shared model's gradients at the worker's gradients."""
    for param, shared_param in zip(model.parameters(),
                                   shared_model.parameters()):
        if shared_param.grad is not None and not gpu:
            # On CPU the gradient tensors are shared memory, so they only
            # need to be hooked up once.
            return
        elif not gpu:
            shared_param._grad = param.grad
        else:
            # A GPU worker's gradients live on the device and must be
            # copied back to the CPU-resident shared model every time.
            shared_param._grad = param.grad.cpu()
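
# Typical A3C usage (sketch only; `loss`, the worker `model`, the
# `shared_model`, and the shared `optimizer` are assumed to be defined
# by the training code):
#
#     loss.backward()
#     ensure_shared_grads(model, shared_model, gpu=args.gpu)
#     optimizer.step()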


def weights_init(m):
    """Xavier/Glorot-style uniform initialization for Conv and Linear
    layers; apply with `model.apply(weights_init)`."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        weight_shape = list(m.weight.data.size())
        # fan_in: input channels * kernel area; fan_out: output
        # channels * kernel area.
        fan_in = np.prod(weight_shape[1:4])
        fan_out = np.prod(weight_shape[2:4]) * weight_shape[0]
        w_bound = np.sqrt(6. / (fan_in + fan_out))
        m.weight.data.uniform_(-w_bound, w_bound)
        # Guard against layers created with bias=False.
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif classname.find('Linear') != -1:
        weight_shape = list(m.weight.data.size())
        fan_in = weight_shape[1]
        fan_out = weight_shape[0]
        w_bound = np.sqrt(6. / (fan_in + fan_out))
        m.weight.data.uniform_(-w_bound, w_bound)
        if m.bias is not None:
            m.bias.data.fill_(0)
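
# Example (sketch, with a hypothetical network class `MyConvNet`):
#
#     model = MyConvNet()
#     model.apply(weights_init)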


def weights_init_mlp(m):
    """Initialize Linear layers with unit-L2-norm rows drawn from N(0, 1)."""
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        m.weight.data.normal_(0, 1)
        # Rescale each row to unit L2 norm.
        m.weight.data /= torch.sqrt(
            m.weight.data.pow(2).sum(1, keepdim=True))
        if m.bias is not None:
            m.bias.data.fill_(0)


def normal(x, mu, sigma, gpu_id, gpu=False):
    """Gaussian density of `x` under N(mu, sigma).

    Note that `sigma` is treated as the variance, not the standard
    deviation.
    """
    pi = torch.from_numpy(np.array([math.pi])).float()
    if gpu:
        with torch.cuda.device(gpu_id):
            pi = Variable(pi).cuda()
    else:
        pi = Variable(pi)
    a = (-1 * (x - mu).pow(2) / (2 * sigma)).exp()
    b = 1 / (2 * sigma * pi.expand_as(sigma)).sqrt()
    return a * b
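

if __name__ == '__main__':
    # Minimal smoke test (illustrative sketch, not part of the original
    # training code). The layer sizes, logger name, and log file name
    # below are arbitrary assumptions.
    import torch.nn as nn

    setup_logger('utils_demo', 'utils_demo.log')
    logger = logging.getLogger('utils_demo')

    layer = nn.Linear(4, 2)
    weights_init(layer)
    # Output heads are often re-initialized with a small norm.
    layer.weight.data = norm_col_init(layer.weight.data, std=0.01)

    x = Variable(torch.zeros(1, 2))
    mu = Variable(torch.zeros(1, 2))
    sigma = Variable(torch.ones(1, 2))  # variance, not std
    logger.info('density at the mean: %s', normal(x, mu, sigma, gpu_id=0))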