-
Notifications
You must be signed in to change notification settings - Fork 10
/
model.py
42 lines (36 loc) · 1.3 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
from torch import nn
class VisualTransformerEncoder(nn.Module):
    """Encode a face input into feature vectors.

    This is a placeholder: the layers and the forward pass are not specified
    yet. NOTE(review): the output width is presumably ``out_dim`` (its output
    is fed to ``nn.Linear(out_dim, out_dim)`` in HyperNetwork) — confirm.
    """

    def __init__(self, in_dim, out_dim):
        """Store the input/output widths for the (not yet written) layers.

        Args:
            in_dim: width of the incoming face features.
            out_dim: width of the encoded features.
        """
        # nn.Module.__init__ must run before submodules/parameters can be
        # registered and before the module can be called or inspected.
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        # TODO: build the encoder layers here.

    def forward(self, face):
        """Encode ``face``; not implemented yet.

        Raises:
            NotImplementedError: always, until the encoder is written. This is
                clearer than silently returning ``None`` to the caller.
        """
        raise NotImplementedError(
            "VisualTransformerEncoder.forward is not implemented yet"
        )
class TransformerDecoder(nn.Module):
    """Iteratively build a ``(batch, N, embed_dim)`` tensor conditioned on ``face``.

    Starting from zeros, ``K`` residual updates proposed by ``self.model`` are
    accumulated.
    """

    def __init__(self, embed_dim, K, N):
        """Configure the decoder.

        Args:
            embed_dim: width of each decoded row.
            K: number of refinement iterations run in ``forward``.
            N: number of rows in the decoded tensor.
        """
        # Required before the module can be called or hold submodules.
        super().__init__()
        self.N = N
        self.K = K
        self.embed_dim = embed_dim
        # Placeholder for the update network. It is called as
        # self.model(face, weights) and presumably returns a tensor of shape
        # (batch, N, embed_dim). TODO: replace with a real module.
        self.model = ...

    def forward(self, face):
        """Decode ``face`` into a ``(batch, N, embed_dim)`` tensor.

        Args:
            face: tensor whose first dimension is the batch size; its device
                and dtype are reused for the output.

        Returns:
            The sum of ``K`` updates from ``self.model``, starting from zeros.
        """
        batch = face.shape[0]
        weights = torch.zeros(
            batch, self.N, self.embed_dim, device=face.device,
            dtype=face.dtype,
        )
        for _ in range(self.K):
            # Out-of-place add: self.model may save ``weights`` for backward,
            # so mutating it in place (``+=``) would break autograd.
            weights = weights + self.model(face, weights)
        return weights
class HyperNetwork(nn.Module):
    """Map a face input to a stack of per-row transformed weight deltas.

    Pipeline: face encoder -> linear projection -> iterative weight decoder ->
    one ``nn.Linear`` per decoded row.
    """

    def __init__(self, in_dim, out_dim, K, L):
        """Build the sub-networks.

        Args:
            in_dim: input width given to the face encoder.
            out_dim: feature width used by the projection, decoder and affine
                layers.
            K: number of refinement iterations in the weight decoder.
            L: number of decoded rows; one affine layer is created per row.
        """
        # Required before any submodule can be assigned to ``self``.
        super().__init__()
        self.face_encoder = VisualTransformerEncoder(in_dim, out_dim)
        self.proj = nn.Linear(out_dim, out_dim)
        # NOTE(review): K = decoder iterations and N = L rows, so each decoded
        # row has a matching affine layer below — confirm this was the intent.
        self.weight_decoder = TransformerDecoder(out_dim, K, L)
        # ModuleList (not a plain list) so the layers are registered: trained,
        # moved with .to(), and included in state_dict().
        self.affine = nn.ModuleList(
            nn.Linear(out_dim, out_dim) for _ in range(L)
        )

    def forward(self, face):
        """Compute transformed weight deltas for ``face``.

        Returns:
            Tensor of shape (batch, seq_len, out_dim), where row ``i`` is
            ``self.affine[i]`` applied to the decoder's row ``i``.

        Raises:
            IndexError: if the decoder yields more rows than affine layers.
        """
        face = self.face_encoder(face)
        face = self.proj(face)
        delta_weights = self.weight_decoder(face)  # batch, seq_len, embed_dim
        seq_len = delta_weights.shape[1]
        if seq_len == 0:
            # torch.stack rejects an empty list; nothing to transform.
            return delta_weights
        # Build a new tensor instead of writing back into delta_weights: each
        # Linear saves its input for backward, so in-place slice assignment
        # would invalidate the autograd graph.
        return torch.stack(
            [self.affine[i](delta_weights[:, i]) for i in range(seq_len)],
            dim=1,
        )