utils.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    """Maps inputs to the mean and log-variance of the approximate posterior."""

    def __init__(self, input_dim, hidden_dim, latent_dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.embedding = nn.Linear(input_dim, hidden_dim)
        self.hidden_layer = nn.Linear(hidden_dim, hidden_dim)
        # Produces 2 * latent_dim values: one half for mu, one half for logvar.
        self.output_layer = nn.Linear(hidden_dim, latent_dim * 2)

    def forward(self, X):
        X = self.embedding(X)
        X = F.relu(X)
        X = self.hidden_layer(X)
        X = F.relu(X)
        X = self.output_layer(X)
        mu, logvar = torch.chunk(X, 2, dim=1)
        return mu, logvar
class Decoder(nn.Module):
    """Reconstructs inputs from latent samples."""

    def __init__(self, output_dim, hidden_dim, latent_dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.input_layer = nn.Linear(latent_dim, hidden_dim)
        self.hidden_layer = nn.Linear(hidden_dim, hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, X):
        X = self.input_layer(X)
        X = F.relu(X)
        X = self.hidden_layer(X)
        X = F.relu(X)
        X = self.output_layer(X)
        # Sigmoid keeps reconstructions in [0, 1], e.g. for normalized pixel values.
        X = torch.sigmoid(X)
        return X
class VAE(nn.Module):
    """Variational autoencoder combining the Encoder and Decoder above."""

    def __init__(self, input_dim, output_dim, hidden_dim, latent_dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(output_dim, hidden_dim, latent_dim)

    def reparameterize(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # so gradients can flow through mu and logvar.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, X):
        mu, logvar = self.encoder(X)
        z = self.reparameterize(mu, logvar)
        X = self.decoder(z)
        return X, mu, logvar
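

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The dimensions below
    # are illustrative assumptions (e.g. flattened 28x28 images, 20-d latent space),
    # and the loss shown is one common VAE formulation, not something defined here.
    input_dim, hidden_dim, latent_dim = 784, 400, 20
    model = VAE(input_dim, input_dim, hidden_dim, latent_dim)

    X = torch.rand(16, input_dim)  # dummy batch of 16 inputs in [0, 1]
    recon, mu, logvar = model(X)

    # Standard ELBO terms: reconstruction error plus KL divergence to N(0, I).
    recon_loss = F.binary_cross_entropy(recon, X, reduction="sum")
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    loss = recon_loss + kl_div
    print(recon.shape, loss.item())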