-
Notifications
You must be signed in to change notification settings - Fork 70
/
stgcn.py
130 lines (115 loc) · 5.01 KB
/
stgcn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class TimeBlock(nn.Module):
"""
Neural network block that applies a temporal convolution to each node of
a graph in isolation.
"""
def __init__(self, in_channels, out_channels, kernel_size=3):
"""
:param in_channels: Number of input features at each node in each time
step.
:param out_channels: Desired number of output channels at each node in
each time step.
:param kernel_size: Size of the 1D temporal kernel.
"""
super(TimeBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
self.conv2 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
self.conv3 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
def forward(self, X):
"""
:param X: Input data of shape (batch_size, num_nodes, num_timesteps,
num_features=in_channels)
:return: Output data of shape (batch_size, num_nodes,
num_timesteps_out, num_features_out=out_channels)
"""
# Convert into NCHW format for pytorch to perform convolutions.
X = X.permute(0, 3, 1, 2)
temp = self.conv1(X) + torch.sigmoid(self.conv2(X))
out = F.relu(temp + self.conv3(X))
# Convert back from NCHW to NHWC
out = out.permute(0, 2, 3, 1)
return out
class STGCNBlock(nn.Module):
"""
Neural network block that applies a temporal convolution on each node in
isolation, followed by a graph convolution, followed by another temporal
convolution on each node.
"""
def __init__(self, in_channels, spatial_channels, out_channels,
num_nodes):
"""
:param in_channels: Number of input features at each node in each time
step.
:param spatial_channels: Number of output channels of the graph
convolutional, spatial sub-block.
:param out_channels: Desired number of output features at each node in
each time step.
:param num_nodes: Number of nodes in the graph.
"""
super(STGCNBlock, self).__init__()
self.temporal1 = TimeBlock(in_channels=in_channels,
out_channels=out_channels)
self.Theta1 = nn.Parameter(torch.FloatTensor(out_channels,
spatial_channels))
self.temporal2 = TimeBlock(in_channels=spatial_channels,
out_channels=out_channels)
self.batch_norm = nn.BatchNorm2d(num_nodes)
self.reset_parameters()
def reset_parameters(self):
stdv = 1. / math.sqrt(self.Theta1.shape[1])
self.Theta1.data.uniform_(-stdv, stdv)
def forward(self, X, A_hat):
"""
:param X: Input data of shape (batch_size, num_nodes, num_timesteps,
num_features=in_channels).
:param A_hat: Normalized adjacency matrix.
:return: Output data of shape (batch_size, num_nodes,
num_timesteps_out, num_features=out_channels).
"""
t = self.temporal1(X)
lfs = torch.einsum("ij,jklm->kilm", [A_hat, t.permute(1, 0, 2, 3)])
# t2 = F.relu(torch.einsum("ijkl,lp->ijkp", [lfs, self.Theta1]))
t2 = F.relu(torch.matmul(lfs, self.Theta1))
t3 = self.temporal2(t2)
return self.batch_norm(t3)
# return t3
class STGCN(nn.Module):
"""
Spatio-temporal graph convolutional network as described in
https://arxiv.org/abs/1709.04875v3 by Yu et al.
Input should have shape (batch_size, num_nodes, num_input_time_steps,
num_features).
"""
def __init__(self, num_nodes, num_features, num_timesteps_input,
num_timesteps_output):
"""
:param num_nodes: Number of nodes in the graph.
:param num_features: Number of features at each node in each time step.
:param num_timesteps_input: Number of past time steps fed into the
network.
:param num_timesteps_output: Desired number of future time steps
output by the network.
"""
super(STGCN, self).__init__()
self.block1 = STGCNBlock(in_channels=num_features, out_channels=64,
spatial_channels=16, num_nodes=num_nodes)
self.block2 = STGCNBlock(in_channels=64, out_channels=64,
spatial_channels=16, num_nodes=num_nodes)
self.last_temporal = TimeBlock(in_channels=64, out_channels=64)
self.fully = nn.Linear((num_timesteps_input - 2 * 5) * 64,
num_timesteps_output)
def forward(self, A_hat, X):
"""
:param X: Input data of shape (batch_size, num_nodes, num_timesteps,
num_features=in_channels).
:param A_hat: Normalized adjacency matrix.
"""
out1 = self.block1(X, A_hat)
out2 = self.block2(out1, A_hat)
out3 = self.last_temporal(out2)
out4 = self.fully(out3.reshape((out3.shape[0], out3.shape[1], -1)))
return out4