gat.py
import torch
import torch.nn as nn
import argparse
from config import parse_config
class GAT(nn.Module):
    def __init__(self, args):
        super(GAT, self).__init__()
        self.args = args
        self.edge_vocab_size = args.edge_vocab_size
        self.edge_dim = args.embed_dim
        self.node_dim = args.embed_dim
        self.hidden_dim = args.embed_dim
        self.gnn_layers = args.gnn_layer_num
        self.dropout = nn.Dropout(self.args.gnn_dropout)
        self.edge_embedding = nn.Embedding(self.edge_vocab_size, self.edge_dim)
        # project concatenated (node, edge) neighbor features back to node_dim
        self.neighbors = nn.Linear(self.edge_dim + self.node_dim, self.node_dim)
        # attention parameters
        self.nodeW = nn.Linear(self.edge_dim + self.node_dim, self.node_dim)
        self.attenW = nn.Parameter(torch.randn(self.node_dim, self.node_dim))

    def forward(self, batch_data):
        # in_indices / out_indices: [batch_size, node_num, neighbor_num_max]
        # in_edges / out_edges:     [batch_size, node_num, neighbor_num_max] (edge label ids)
        node_reps, mask, in_indices, in_edges, in_mask, out_indices, out_edges, out_mask, _ = batch_data[:-1]
        node_reps = self.dropout(node_reps)

        # ==== input from in-neighbors
        # [batch_size, node_num, neighbor_num_max, edge_dim]
        in_edge_reps = self.edge_embedding(in_edges)
        # [batch_size, node_num, neighbor_num_max, node_dim]
        in_node_reps = self.collect_neighbors(node_reps, in_indices)
        # [batch_size, node_num, neighbor_num_max, node_dim + edge_dim]
        in_reps = torch.cat([in_node_reps, in_edge_reps], 3)
        # attention: weight each in-neighbor by its alpha (assumes batch_size == 1)
        in_alpha = self.attention_on_neighbors(node_reps.squeeze(0), in_reps.squeeze(0))
        in_reps = torch.mul(in_reps, in_alpha.unsqueeze(0))
        # mask out padded neighbors, then sum over the neighbor dimension
        in_reps = in_reps.mul(in_mask.unsqueeze(-1))
        # [batch_size, node_num, node_dim + edge_dim]
        in_reps = in_reps.sum(dim=2)

        # ==== input from out-neighbors
        # [batch_size, node_num, neighbor_num_max, edge_dim]
        out_edge_reps = self.edge_embedding(out_edges)
        # [batch_size, node_num, neighbor_num_max, node_dim]
        out_node_reps = self.collect_neighbors(node_reps, out_indices)
        # [batch_size, node_num, neighbor_num_max, node_dim + edge_dim]
        out_reps = torch.cat([out_node_reps, out_edge_reps], 3)
        # attention: weight each out-neighbor by its alpha (assumes batch_size == 1)
        out_alpha = self.attention_on_neighbors(node_reps.squeeze(0), out_reps.squeeze(0))
        out_reps = torch.mul(out_reps, out_alpha.unsqueeze(0))
        # mask out padded neighbors, then sum over the neighbor dimension
        out_reps = out_reps.mul(out_mask.unsqueeze(-1))
        # [batch_size, node_num, node_dim + edge_dim]
        out_reps = out_reps.sum(2)

        # combine in- and out-neighbor messages with the node's own representation
        out_nodes = self.neighbors(out_reps.squeeze(0))
        in_nodes = self.neighbors(in_reps.squeeze(0))
        node_hidden = node_reps + out_nodes + in_nodes
        return node_hidden

    def collect_neighbors(self, node_reps, index):
        # node_reps: [batch_size, node_num, node_dim]
        # index:     [batch_size, node_num, neighbor_num]
        batch_size = index.size(0)
        node_num = index.size(1)
        neighbor_num = index.size(2)
        rids = torch.arange(0, batch_size).to(self.args.device)  # [batch]
        rids = rids.reshape([-1, 1, 1])                          # [batch, 1, 1]
        rids = rids.repeat(1, node_num, neighbor_num)            # [batch, nodes, neighbors]
        indices = torch.stack((rids, index), 3)                  # [batch, nodes, neighbors, 2]
        return node_reps[indices[:, :, :, 0], indices[:, :, :, 1], :]
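    # Illustration (sketch, indices made up): with batch_size 1, node_num 3 and
    # index = [[[1, 2], [0, 2], [0, 1]]], the gather above returns, for each
    # node, the representations of its listed neighbors, e.g. row 0 holds the
    # representations of nodes 1 and 2.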

    def attention_on_neighbors(self, node_reps, neighbors):
        """Return an attention weight (alpha) for each neighbor.

        Assumes batch_size == 1: node_reps is [node_num, node_dim] and
        neighbors is [node_num, neighbor_num, node_dim + edge_dim].
        """
        # [node_num, neighbor_num, node_dim]
        neighbors = self.nodeW(neighbors)
        # [node_num, neighbor_num, node_dim]
        attention = torch.matmul(neighbors, self.attenW)
        # [node_num, neighbor_num, 1]
        alpha = torch.bmm(attention, node_reps.unsqueeze(2))
        # normalize over the neighbor dimension
        return torch.softmax(alpha, dim=1)
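

# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the original file). It assumes
# batch_size == 1, as the squeeze(0)/unsqueeze(0) calls in forward() do, and
# fabricates an args namespace with the fields __init__ reads; all values
# below (dims, vocab size, dropout) are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    args = argparse.Namespace(
        edge_vocab_size=10,  # hypothetical
        embed_dim=256,       # hypothetical
        gnn_layer_num=1,     # hypothetical
        gnn_dropout=0.1,     # hypothetical
        device="cpu",        # hypothetical
    )
    model = GAT(args)

    batch_size, node_num, neighbor_num = 1, 5, 3
    node_reps = torch.randn(batch_size, node_num, args.embed_dim)
    mask = torch.ones(batch_size, node_num)
    in_indices = torch.randint(0, node_num, (batch_size, node_num, neighbor_num))
    in_edges = torch.randint(0, args.edge_vocab_size, (batch_size, node_num, neighbor_num))
    in_mask = torch.ones(batch_size, node_num, neighbor_num)
    out_indices = torch.randint(0, node_num, (batch_size, node_num, neighbor_num))
    out_edges = torch.randint(0, args.edge_vocab_size, (batch_size, node_num, neighbor_num))
    out_mask = torch.ones(batch_size, node_num, neighbor_num)

    # forward() unpacks nine tensors from batch_data[:-1]; the ninth slot is
    # ignored and the tenth is dropped, so pad with two placeholders.
    batch = (node_reps, mask, in_indices, in_edges, in_mask,
             out_indices, out_edges, out_mask, None, None)
    hidden = model(batch)
    print(hidden.shape)  # expected: torch.Size([1, node_num, embed_dim])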