-
Notifications
You must be signed in to change notification settings - Fork 15
/
utils.py
37 lines (26 loc) · 1020 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
from torch_geometric.utils import degree
class Indegree(object):
    r"""Adds the in-degree of every node to the node features.

    The in-degree is counted from the second row of ``data.edge_index``
    (``edge_index[1]``), i.e. how many edges list the node in that row.

    Args:
        norm (bool, optional): If set to :obj:`True`, the degree is divided
            by :obj:`max_value` (or by the largest in-degree found in the
            graph when :obj:`max_value` is :obj:`None`).
            (default: :obj:`True`)
        max_value (int or float, optional): Fixed divisor used for
            normalization. If :obj:`None`, the maximum in-degree of the
            graph being transformed is used. (default: :obj:`None`)
        cat (bool, optional): If set to :obj:`False`, all existing node
            features will be replaced. (default: :obj:`True`)
    """

    def __init__(self, norm=True, max_value=None, cat=True):
        self.norm = norm
        self.max = max_value
        self.cat = cat

    def __call__(self, data):
        """Compute the in-degree feature and store it in ``data.x``.

        Args:
            data: Graph object exposing ``edge_index``, ``x`` and
                ``num_nodes``.

        Returns:
            The same ``data`` object, with ``data.x`` updated in place.
        """
        col, x = data.edge_index[1], data.x
        deg = degree(col, data.num_nodes)

        if self.norm:
            if self.max is not None:
                scale = self.max
            elif deg.numel() > 0:
                # A graph without edges has a maximum degree of 0; clamping
                # the divisor to 1 keeps the all-zero degrees at 0 instead
                # of turning every feature into NaN.
                scale = deg.max().clamp(min=1)
            else:
                # No nodes at all: ``deg.max()`` would raise on an empty
                # tensor, and there is nothing to normalize anyway.
                scale = 1
            deg = deg / scale

        # One feature column per node.
        deg = deg.view(-1, 1)

        if x is not None and self.cat:
            # Promote a 1-D feature vector to a single-column matrix so it
            # can be concatenated along the feature dimension.
            x = x.view(-1, 1) if x.dim() == 1 else x
            # Cast to the dtype of the existing features before appending.
            data.x = torch.cat([x, deg.to(x.dtype)], dim=-1)
        else:
            data.x = deg

        return data

    def __repr__(self):
        # Include every constructor argument so that differently configured
        # transforms are distinguishable by their representation.
        return '{}(norm={}, max_value={}, cat={})'.format(
            self.__class__.__name__, self.norm, self.max, self.cat)