-
Notifications
You must be signed in to change notification settings - Fork 2
/
prnet_loss.py
74 lines (60 loc) · 1.88 KB
/
prnet_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# -*- coding: utf-8 -*-
"""
@author: samuel ko
@date: 2019.07.19
@readme: The implementation of PRNet Network Loss.
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import *
import cv2
import numpy as np
def preprocess(mask):
    """
    Quantize a grayscale weight-mask image into a few discrete weight levels.

    Every positive pixel is floor-divided by 16 (zeros stay zero, and values
    1..15 also become 0). Flooring maps 240..255 to 15 and 112..127 to 7;
    those two levels are then snapped up to 16 and 8 respectively.

    :param mask: grayscale mask as a numpy array, typically the result of
        ``cv2.imread(path, 0)``. It is modified IN PLACE.
    :return: the same ``mask`` object, now holding the quantized weights.
    """
    positive = mask > 0
    # Explicit floor division. The previous form (``/ 16``) relied on the float
    # result being truncated when written back into an integer array; ``//``
    # yields the same levels for uint8 input and the intended levels for
    # float input as well.
    mask[positive] = mask[positive] // 16
    mask[mask == 15] = 16
    mask[mask == 7] = 8
    # Value histogram recorded by the author (presumably for the reference
    # UV weight mask): {0: 21669, 3: 33223, 4: 10429, 8: 147, 16: 68}
    return mask
class WeightMaskLoss(nn.Module):
    """
    Weighted L2 loss: squared error averaged over dim 1, multiplied
    element-wise by a fixed weight mask loaded from a grayscale image.
    """

    def __init__(self, mask_path, device=None):
        """
        Load and quantize the weight mask.

        :param mask_path: path to the grayscale weight-mask image.
        :param device: device on which to keep the mask. When None, uses
            "cuda" if CUDA is available and "cpu" otherwise.
        :raises FileNotFoundError: if ``mask_path`` does not exist.
        :raises ValueError: if the file exists but cannot be decoded.
        """
        super(WeightMaskLoss, self).__init__()
        if not os.path.exists(mask_path):
            raise FileNotFoundError("Mask File Not Found! Please Check your Settings!")
        raw = cv2.imread(mask_path, 0)
        if raw is None:
            # cv2.imread reports an unreadable / undecodable file by returning None.
            raise ValueError("Mask file could not be read as an image: {}".format(mask_path))
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.mask = torch.from_numpy(preprocess(raw)).float().to(device)

    def forward(self, pred, gt):
        """
        Compute the mask-weighted squared error.

        :param pred: prediction tensor; assumed (B, C, H, W) with (H, W)
            matching the mask — TODO confirm against callers.
        :param gt: ground-truth tensor with the same shape as ``pred``.
        :return: scalar tensor: the sum of all weighted errors (including
            across any batch dimension) divided by the mask's element count.
        """
        # Squared error averaged over dim 1, leaving a map that broadcasts
        # against the mask.
        result = torch.mean(torch.pow(pred - gt, 2), dim=1)
        # Follow the prediction's device; Tensor.to is a no-op if already there.
        mask = self.mask.to(pred.device)
        result = torch.mul(result, mask)
        # Official normalization: sum, then divide by the mask's element count
        # (256*256 for the reference mask); without it the value is far too
        # large. numel() replaces size(1) ** 2 so non-square masks are
        # normalized correctly too. A plain torch.mean(result) is the other
        # commonly used choice.
        return torch.sum(result) / self.mask.numel()
def INFO(*inputs):
    """
    Print a log line prefixed with "[ PRNet ]".

    With one argument the value itself is printed. With two or more, the
    first is a label and the remaining values follow after ": ",
    comma-separated. Nothing is printed when called without arguments.
    """
    if not inputs:
        return
    if len(inputs) == 1:
        # Format the value, not the 1-tuple (previously printed "('msg',)").
        print("[ PRNet ] {}".format(inputs[0]))
        return
    label, *values = inputs
    print("[ PRNet ] {0}: {1}".format(label, ", ".join(format(v) for v in values)))
if __name__ == "__main__":
    # Minimal smoke run of the logging helper. To inspect a weight mask, load
    # it with cv2.imread(<mask path>, 0) and pass the result to preprocess().
    INFO("Random Seed", 1)