-
Notifications
You must be signed in to change notification settings - Fork 6
/
AE.py
83 lines (66 loc) · 2.42 KB
/
AE.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import math
import torch.nn as nn
import torch
import torch.nn.functional as F
from pc_kit import PointNet, SAPP
from bitEstimator import BitEstimator
from pytorch3d.loss import chamfer_distance
class get_model(nn.Module):
    """Point-cloud autoencoder with a quantized, entropy-modelled bottleneck.

    The encoder chains ``SAPP`` and ``PointNet`` (both from ``pc_kit``) to
    produce a ``d``-wide latent per cloud; a fully connected decoder maps the
    quantized latent back to ``k`` points, and a ``BitEstimator`` supplies the
    cumulative values used to estimate the bitrate of the latent.
    """

    def __init__(self, k, d):
        """Build the encoder, decoder and bit estimator.

        Args:
            k: Number of points the decoder emits (its last layer has
                ``k * 3`` outputs); ``k // 4`` is passed to ``SAPP`` as
                ``feature_region``.
            d: Width of the bottleneck (last ``PointNet`` layer, decoder
                input size and ``BitEstimator`` channel count).
        """
        super(get_model, self).__init__()
        self.k = k
        self.d = d

        self.sa = SAPP(in_channel=3, feature_region=k // 4, mlp=[32, 64, 128], bn=False)
        # PointNet consumes the raw coordinates concatenated with the 128 SAPP features.
        self.pn = PointNet(
            in_channel=3 + 128,
            mlp=[256, 512, 1024, d],
            relu=[True, True, True, False],
            bn=False,
        )

        self.decoder = nn.Sequential(
            nn.Linear(d, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, k * 3),
        )

        self.be = BitEstimator(channel=d)

    def forward(self, xyz):
        """Encode, quantize and decode a batch of point clouds.

        Args:
            xyz: Tensor of shape ``(B, K, C)``. The encoder is built with
                ``in_channel=3``, so ``C`` is presumably 3 -- confirm against
                the caller.

        Returns:
            A tuple ``(new_xyz, bpp)`` where ``new_xyz`` is the decoder output
            reshaped to ``(B, -1, 3)`` and ``bpp`` is the estimated number of
            bits divided by ``K`` and by ``B``.
        """
        B, K, _ = xyz.shape

        # encode
        xyz = xyz.transpose(1, 2)
        feature = self.sa(xyz)
        feature = self.pn(torch.cat((xyz, feature), dim=1))

        # quantization
        if self.training:
            # Additive uniform noise stands in for rounding during training.
            # It is sampled directly on the device/dtype of ``feature`` so the
            # model works on CPU and on any GPU, not only the default CUDA one.
            noise = torch.empty_like(feature).uniform_(-0.5, 0.5)
            quantizated_feature = feature + noise
        else:
            quantizated_feature = torch.round(feature)
        bottleneck = quantizated_feature

        # decode
        new_xyz = self.decoder(bottleneck)
        new_xyz = new_xyz.reshape(B, -1, 3)

        # bitrate estimation: probability mass of the unit-wide bin around
        # each latent value, converted to bits and clamped to [0, 50].
        prob = self.be(bottleneck + 0.5) - self.be(bottleneck - 0.5)
        total_bits = torch.sum(
            torch.clamp(-1.0 * torch.log(prob + 1e-10) / math.log(2.0), 0, 50)
        )
        bpp = total_bits / K / B

        return new_xyz, bpp

    def get_pmf(self, device='cuda'):
        """Tabulate the bit estimator's probability mass per channel.

        Args:
            device: Device on which the table and the probe values are
                created. It should match the device holding ``self.be``.

        Returns:
            Tensor of shape ``(1, d, 2 * L)`` with ``L = 99``; entry
            ``[0, c, s + L]`` is ``be(s + 0.5) - be(s - 0.5)`` for channel
            ``c`` and integer symbol ``s`` in ``[-L, L - 1]``.
        """
        L = 99  # symbols covered: [-L, ..., L-1], i.e. 2 * L entries
        pmf = torch.zeros(1, self.d, L * 2, device=device)
        for symbol in range(-L, L):
            # One probe row holding the same integer symbol for every channel.
            z = torch.full((1, self.d), float(symbol), device=device)
            pmf[0, :, symbol + L] = (self.be(z + 0.5) - self.be(z - 0.5))[0, :]
        return pmf
class get_loss(nn.Module):
    """Rate-distortion objective: Chamfer distance plus a weighted bitrate."""

    def __init__(self):
        super().__init__()

    def forward(self, pred, target, bpp, lamda):
        """Combine reconstruction error and bitrate into a single loss.

        Args:
            pred: Reconstructed point cloud, documented upstream as
                ``(B, N, 3)``.
            target: Original point cloud, documented upstream as
                ``(B, CxN, 3)`` -- TODO confirm the intended point count.
            bpp: Bitrate term that is scaled by ``lamda``.
            lamda: Weight applied to ``bpp``.

        Returns:
            ``chamfer_distance(pred, target)[0] + lamda * bpp``.
        """
        # chamfer_distance yields (distance, normals term); only the distance is used.
        distortion, _ = chamfer_distance(pred, target)
        return distortion + lamda * bpp