-
Notifications
You must be signed in to change notification settings - Fork 0
/
ops.py
executable file
·99 lines (79 loc) · 2.77 KB
/
ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import torch
import torch.nn as nn
import numpy as np
class GraphUnpool(nn.Module):
    """Scatter pooled node features back into a zero-filled full-size matrix.

    Counterpart to ``GraphPool``'s node selection: row ``j`` of ``X`` is
    written to node ``idx[j]``; every node not listed in ``idx`` gets an
    all-zero feature row. The adjacency matrix is passed through unchanged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, A, X, idx):
        """Restore node features to the graph described by ``A``.

        Args:
            A: Adjacency matrix of the larger graph. Only ``A.shape[0]`` (the
                node count) is read; ``A`` itself is returned unchanged.
            X: Features of the kept nodes, shape ``(len(idx), feat_dim)``.
            idx: Indices (rows of ``A``) that ``X``'s rows belong to.

        Returns:
            ``(A, new_X)`` where ``new_X`` has shape ``(A.shape[0], feat_dim)``
            and the same device and dtype as ``X``.
        """
        # Allocate on X's device/dtype; a bare torch.zeros() always lands on
        # the CPU in the default dtype, which breaks GPU inputs and silently
        # downcasts float64 features on assignment.
        new_X = X.new_zeros((A.shape[0], X.shape[1]))
        new_X[idx] = X
        return A, new_X
class GraphPool(nn.Module):
    """Top-k graph pooling: keep the highest-scoring nodes of a graph.

    Each node is scored by a learned linear projection followed by a sigmoid.
    The ``int(k * num_nodes)`` best nodes are kept. Their features are scaled
    by their score, and the adjacency matrix is restricted to them.
    """

    def __init__(self, k, in_dim):
        """
        Args:
            k: Fraction of nodes to keep (``int(k * num_nodes)`` are kept).
                Values above 1 make ``torch.topk`` raise.
            in_dim: Feature dimension of the input node features.
        """
        super().__init__()
        self.k = k
        self.proj = nn.Linear(in_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, A, X):
        """Pool the graph ``(A, X)`` down to its top-scoring nodes.

        Args:
            A: Adjacency matrix indexed ``[node, node]``; ``A.shape[0]`` is
                taken as the node count.
            X: Node features, shape ``(num_nodes, in_dim)``.

        Returns:
            ``(A_sub, new_X, idx)``: the adjacency restricted to the kept
            nodes, their score-gated features, and their indices into ``A``.
        """
        # squeeze(-1) drops only the projection's size-1 output dim. A bare
        # squeeze() would also drop the node dim when num_nodes == 1, turning
        # the scores into a 0-d tensor and breaking the indexing below.
        scores = self.proj(X).squeeze(-1)
        # Dividing by 100 shrinks the logits, so gate values stay close to 0.5
        # unless the projection output is large.
        scores = self.sigmoid(scores / 100)
        num_nodes = A.shape[0]
        values, idx = torch.topk(scores, int(self.k * num_nodes))
        # Scale each kept node's features by its (differentiable) score so the
        # projection receives gradients despite the hard top-k selection.
        new_X = X[idx, :] * values.unsqueeze(-1)
        A = A[idx, :][:, idx]
        return A, new_X, idx
class GCN(nn.Module):
    """Per-node layer: dropout followed by a linear projection.

    NOTE: the adjacency matrix ``A`` is accepted but currently unused. The
    neighbourhood aggregation step (``torch.matmul(A, X)``) is disabled, so
    each node's features are transformed independently.
    """

    def __init__(self, in_dim, out_dim, p=0.0):
        """
        Args:
            in_dim: Input feature dimension.
            out_dim: Output feature dimension.
            p: Dropout probability applied to the input features. The default
                of 0.0 disables dropout.
        """
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)
        self.drop = nn.Dropout(p=p)

    def forward(self, A, X):
        """Return ``proj(dropout(X))``; ``A`` is ignored (see class docstring)."""
        X = self.drop(X)
        return self.proj(X)
class GraphUnet(nn.Module):
    """Graph U-Net: an encoder with top-k pooling and a decoder with unpooling.

    Data flow: ``start_gcn`` -> ``len(ks)`` x (down GCN, pool) ->
    ``bottom_gcn`` -> ``len(ks)`` x (unpool, up GCN, add encoder skip) ->
    concat with the ``start_gcn`` output -> ``end_gcn``.
    """

    def __init__(self, ks, in_dim, out_dim, dim=320):
        """
        Args:
            ks: Keep fractions, one per pooling level (``ks[i]`` is passed as
                ``k`` to level ``i``'s ``GraphPool``).
            in_dim: Input node-feature dimension.
            out_dim: Output node-feature dimension.
            dim: Hidden feature dimension used by every internal layer.
        """
        super().__init__()
        self.ks = ks
        self.start_gcn = GCN(in_dim, dim)
        self.bottom_gcn = GCN(dim, dim)
        # Its input is the decoder output concatenated with the start_gcn output.
        self.end_gcn = GCN(2 * dim, out_dim)
        # nn.ModuleList (not a plain list) so these layers are registered as
        # submodules: their parameters show up in .parameters() (and hence in
        # an optimizer built from it), are moved by .to(), and are saved in
        # state_dict. Plain Python lists hide them from all of that.
        self.down_gcns = nn.ModuleList()
        self.up_gcns = nn.ModuleList()
        self.pools = nn.ModuleList()
        self.unpools = nn.ModuleList()
        self.l_n = len(ks)
        # Same per-level construction order as before, so seeded parameter
        # initialisation is unchanged.
        for k in ks:
            self.down_gcns.append(GCN(dim, dim))
            self.up_gcns.append(GCN(dim, dim))
            self.pools.append(GraphPool(k, dim))
            self.unpools.append(GraphUnpool())

    def forward(self, A, X):
        """Run the full encoder/decoder pass.

        Args:
            A: Adjacency matrix of the input graph; ``A.shape[0]`` is the
                node count.
            X: Node features, one row per node, ``in_dim`` columns.

        Returns:
            ``(out, start_gcn_outs)``: the ``end_gcn`` output (one row per
            node of the input graph) and the ``start_gcn`` output.
        """
        adj_ms = []
        indices_list = []
        down_outs = []
        X = self.start_gcn(A, X)
        start_gcn_outs = X
        # Encoder: remember each level's graph and features for the decoder.
        for i in range(self.l_n):
            X = self.down_gcns[i](A, X)
            adj_ms.append(A)
            down_outs.append(X)
            A, X, idx = self.pools[i](A, X)
            indices_list.append(idx)
        X = self.bottom_gcn(A, X)
        # Decoder: walk the levels in reverse, unpooling onto each saved graph
        # and adding the matching encoder output as a skip connection.
        for i in range(self.l_n):
            up_idx = self.l_n - i - 1
            A, idx = adj_ms[up_idx], indices_list[up_idx]
            A, X = self.unpools[i](A, X, idx)
            X = self.up_gcns[i](A, X)
            X = X.add(down_outs[up_idx])
        X = torch.cat([X, start_gcn_outs], 1)
        X = self.end_gcn(A, X)
        return X, start_gcn_outs