-
Notifications
You must be signed in to change notification settings - Fork 0
/
nalu.py
82 lines (59 loc) · 2.16 KB
/
nalu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch
from torch.nn import functional as F
from torch import nn
class NAC(nn.Module):
def __init__(self, in_dim, out_dim, init_fun=nn.init.xavier_uniform_):
super().__init__()
self._W_hat = nn.Parameter(torch.empty(in_dim, out_dim))
self._M_hat = nn.Parameter(torch.empty(in_dim, out_dim))
self.register_parameter('W_hat', self._W_hat)
self.register_parameter('M_hat', self._M_hat)
for param in self.parameters():
init_fun(param)
def forward(self, x):
W = F.tanh(self._W_hat) * F.sigmoid(self._M_hat)
return x.matmul(W)
class StackedNAC(nn.Module):
def __init__(self, n_layers, in_dim, out_dim, hidden_dim,
init_fun=nn.init.xavier_uniform_):
super().__init__()
self._nac_stack = nn.Sequential(*[
NAC(
in_dim if i == 0 else hidden_dim,
out_dim if i == n_layers - 1 else hidden_dim,
init_fun=init_fun
)
for i in range(n_layers)
])
def forward(self, x):
return self._nac_stack(x)
class NALU(nn.Module):
def __init__(self, in_dim, out_dim, init_fun=nn.init.xavier_uniform_):
super().__init__()
self._G = nn.Parameter(torch.empty(in_dim, 1))
self.register_parameter('G', self._G)
init_fun(self._G)
self._nac = NAC(in_dim, out_dim, init_fun=init_fun)
self._epsilon = 1e-8
def forward(self, x):
g = F.sigmoid(x.matmul(self._G))
m = torch.exp(
self._nac(torch.log(torch.abs(x) + self._epsilon))
)
a = self._nac(x)
y = g * a + (1 - g) * m
return y
class StackedNALU(nn.Module):
def __init__(self, n_layers, in_dim, out_dim, hidden_dim,
init_fun=nn.init.xavier_uniform_):
super().__init__()
self._nalu_stack = nn.Sequential(*[
NALU(
in_dim if i == 0 else hidden_dim,
out_dim if i == n_layers - 1 else hidden_dim,
init_fun=init_fun
)
for i in range(n_layers)
])
def forward(self, x):
return self._nalu_stack(x)