fp16util.py
import torch
import torch.nn as nn

# Utilities for fp16 (half precision) training.
# Adapted from https://github.com/fastai/imagenet-fast/tree/master/cifar10


class tofp16(nn.Module):
    """Cast the incoming tensor to half precision."""

    def __init__(self):
        super(tofp16, self).__init__()

    def forward(self, input):
        return input.half()


def copy_in_params(net, params):
    """Copy the (fp32 master) parameters in `params` back into `net`'s parameters."""
    net_params = list(net.parameters())
    for i in range(len(params)):
        net_params[i].data.copy_(params[i].data)


def set_grad(params, params_with_grad):
    """Copy gradients from `params_with_grad` (fp16 model) into `params` (fp32 master copy)."""
    for param, param_w_grad in zip(params, params_with_grad):
        if param.grad is None:
            # Allocate a gradient buffer of the right shape; it is overwritten below.
            param.grad = torch.nn.Parameter(param.data.new().resize_(*param.data.size()))
        param.grad.data.copy_(param_w_grad.grad.data)


def BN_convert_float(module):
    '''
    Convert BatchNorm layers back to single precision so their parameters and
    statistics stay in fp32. Walks the module tree and converts only BatchNorm
    layers; this can't be done with the built-in .apply, as that function would
    apply fn to all modules, parameters, and buffers, so we wouldn't be able to
    guard the float conversion based on the module type.
    '''
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module.float()
    for child in module.children():
        BN_convert_float(child)
    return module


def network_to_half(network):
    """Convert a network to half precision, keeping BatchNorm layers in fp32."""
    return nn.Sequential(tofp16(), BN_convert_float(network.half()))
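

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal fp16 training step
# built from the helpers above. The model, loss function, data, and static
# loss scale below are hypothetical placeholders chosen for illustration,
# and a CUDA device is assumed. The pattern keeps an fp32 master copy of the
# parameters, runs forward/backward in fp16, copies the fp16 gradients into
# the fp32 copy, steps the optimizer there, and copies the updated weights
# back into the fp16 model.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.optim as optim

    model = nn.Sequential(nn.Linear(16, 32), nn.BatchNorm1d(32), nn.ReLU(),
                          nn.Linear(32, 10)).cuda()
    model = network_to_half(model)                      # fp16 model, fp32 BatchNorm

    # fp32 master copy of the parameters; the optimizer updates these.
    master_params = [p.detach().clone().float() for p in model.parameters()]
    for p in master_params:
        p.requires_grad_(True)
    optimizer = optim.SGD(master_params, lr=0.1)
    loss_scale = 128.0                                  # static loss scaling (assumed value)

    x = torch.randn(8, 16).cuda()
    y = torch.randint(0, 10, (8,)).cuda()
    loss_fn = nn.CrossEntropyLoss()

    model.zero_grad()
    loss = loss_fn(model(x), y)
    (loss * loss_scale).backward()                      # gradients land in fp16, scaled

    set_grad(master_params, list(model.parameters()))   # fp16 grads -> fp32 master grads
    for p in master_params:
        p.grad.data.div_(loss_scale)                    # undo the loss scaling
    optimizer.step()
    copy_in_params(model, master_params)                # fp32 master weights -> fp16 model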