# aggin_conv.py
from typing import Callable, Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear, Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import Adj, OptTensor, Size
from torch_geometric.utils import add_self_loops, remove_self_loops, softmax


class AGGINConv(MessagePassing):
    r"""A GIN-style convolution that weights each incoming message by a
    softmax over scaled cosine similarities
    :math:`\beta \cdot \langle \hat{\mathbf{x}}_i, \hat{\mathbf{x}}_j \rangle`
    of L2-normalized node features, followed by the GIN update
    :math:`h_{\mathbf{\Theta}}\left((1 + \epsilon) \cdot \mathbf{x}_i +
    \sum_{j} \alpha_{i,j} \cdot \mathbf{x}_j\right)`.

    Args:
        nn (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}`
            that maps node features :obj:`x` of shape
            :obj:`[-1, in_channels]` to shape :obj:`[-1, out_channels]`,
            *e.g.*, defined by :class:`torch.nn.Sequential`.
        eps (float, optional): (Initial) :math:`\epsilon`-value.
            (default: :obj:`0.`)
        train_eps (bool, optional): If set to :obj:`True`, :math:`\epsilon`
            will be a trainable parameter. (default: :obj:`False`)
        requires_grad (bool, optional): If set to :obj:`False`, :math:`\beta`
            will not be trainable. (default: :obj:`True`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False,
                 requires_grad: bool = True, add_self_loops: bool = True,
                 **kwargs):
        super(AGGINConv, self).__init__(aggr='add', **kwargs)
        self.nn = nn
        self.initial_eps = eps
        # epsilon is either a trainable parameter or a fixed buffer.
        if train_eps:
            self.eps = Parameter(torch.Tensor([eps]))
        else:
            self.register_buffer('eps', torch.Tensor([eps]))
        self.requires_grad = requires_grad
        self.add_self_loops = add_self_loops
        # beta scales the cosine similarities before the softmax.
        if requires_grad:
            self.beta = Parameter(torch.Tensor(1))
        else:
            self.register_buffer('beta', torch.ones(1))
        self.reset_parameters()
    def reset_parameters(self):
        if self.requires_grad:
            self.beta.data.fill_(1)
        # Assumes `self.nn` is an iterable container such as
        # `torch.nn.Sequential`; resets only its Linear layers.
        for module in self.nn:
            if isinstance(module, Linear):
                module.reset_parameters()
        self.eps.data.fill_(self.initial_eps)
    def forward(self, x: Tensor, edge_index: Adj,
                size: Size = None) -> Tensor:
        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                edge_index, _ = remove_self_loops(edge_index)
                edge_index, _ = add_self_loops(
                    edge_index, num_nodes=x.size(self.node_dim))
            elif isinstance(edge_index, SparseTensor):
                edge_index = set_diag(edge_index)

        # L2-normalized features are used only to compute attention scores.
        x_norm = F.normalize(x, p=2., dim=-1)

        # propagate_type: (x: Tensor, x_norm: Tensor)
        out = self.propagate(edge_index, x=x, x_norm=x_norm, size=size)

        # GIN-style update: add the (1 + eps)-weighted central node features.
        out = out + (1 + self.eps) * x
        return self.nn(out)
    def message(self, x_j: Tensor, x_norm_i: Tensor, x_norm_j: Tensor,
                index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        # Attention score: scaled cosine similarity between the normalized
        # features of target node i and neighbor j, softmax-normalized over
        # each node's incoming edges.
        alpha = self.beta * (x_norm_i * x_norm_j).sum(dim=-1)
        alpha = softmax(alpha, index, ptr, size_i)
        return x_j * alpha.view(-1, 1)
    def __repr__(self):
        return f'{self.__class__.__name__}(nn={self.nn})'
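

# A minimal usage sketch (not part of the original module): the MLP sizes and
# the toy graph below are illustrative assumptions, chosen only to show how
# AGGINConv is wired up with a `torch.nn.Sequential` network.
if __name__ == '__main__':
    from torch.nn import ReLU, Sequential

    in_channels, out_channels = 16, 32  # hypothetical feature sizes
    mlp = Sequential(Linear(in_channels, out_channels), ReLU(),
                     Linear(out_channels, out_channels))
    conv = AGGINConv(mlp, eps=0., train_eps=True)

    # Toy graph: 4 nodes with random features and 4 directed edges.
    x = torch.randn(4, in_channels)
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 0, 3, 2]])

    out = conv(x, edge_index)
    print(out.shape)  # torch.Size([4, 32])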