-
Notifications
You must be signed in to change notification settings - Fork 0
/
gcnold.py
36 lines (27 loc) · 1.1 KB
/
gcnold.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# original gcn.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class GraphConvolution(nn.Module):
    """Single graph-convolution layer computing ``A @ (X @ W) (+ b)``.

    Args:
        input_dim: Size of each input node feature vector.
        output_dim: Size of each output node feature vector.
        use_bias: If True (default), a learnable bias of shape
            ``(output_dim,)`` is added to the output; if False, ``self.bias``
            is registered as ``None``.
    """

    def __init__(self, input_dim, output_dim, use_bias=True):
        super(GraphConvolution, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.use_bias = use_bias
        self.weight = nn.Parameter(torch.Tensor(input_dim, output_dim))
        if self.use_bias:
            self.bias = nn.Parameter(torch.Tensor(output_dim))
        else:
            # Registering None keeps `self.bias` a valid attribute either way.
            self.register_parameter('bias', None)
        # torch.Tensor(...) allocates uninitialized memory; initialize it.
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize the weight with Kaiming-uniform and the bias with zeros."""
        nn.init.kaiming_uniform_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, adjacency, input_feature):
        """Propagate node features over the graph.

        Args:
            adjacency: Node-by-node adjacency matrix, sparse (COO) or dense.
                Presumably already normalized by the caller; this layer does
                not normalize it.
            input_feature: Node feature matrix of shape
                ``(num_nodes, input_dim)``.

        Returns:
            Tensor of shape ``(num_nodes, output_dim)``.
        """
        support = torch.mm(input_feature, self.weight)
        if adjacency.is_sparse:
            output = torch.sparse.mm(adjacency, support)
        else:
            output = torch.mm(adjacency, support)
        if self.use_bias:
            # Out-of-place add avoids mutating the matmul result in place.
            output = output + self.bias
        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.input_dim) + ' -> ' + str(self.output_dim) + ')'
class GCN(nn.Module):
    """Two-layer graph convolutional network: GraphConv -> ReLU -> GraphConv.

    Args:
        input_dim: Size of each input node feature vector.
        imput_dim: Size of the output vector per node (e.g. the number of
            classes). The name is a historical typo for "output_dim"; it is
            kept so existing keyword callers keep working. It was previously
            ignored, with the output size hard-coded to 7, which remains the
            default.
        hidden_dim: Width of the hidden layer (default 16, as before).
    """

    def __init__(self, input_dim, imput_dim=7, hidden_dim=16):
        super(GCN, self).__init__()
        self.gcn1 = GraphConvolution(input_dim, hidden_dim)
        self.gcn2 = GraphConvolution(hidden_dim, imput_dim)

    def forward(self, adjacency, feature):
        """Return unnormalized per-node scores of shape ``(num_nodes, imput_dim)``.

        The same ``adjacency`` is used by both layers; no softmax is applied
        here.
        """
        h = F.relu(self.gcn1(adjacency, feature))
        logits = self.gcn2(adjacency, h)
        return logits