-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmodel.py
94 lines (77 loc) · 3.26 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import math
import torch
import torch_scatter
import torch_geometric as pyg
# Default size of the particle-type embedding table (LearnedSimulator's
# ``num_particle_types`` argument).
NUM_PARTICLE_TYPES: int = 9
class MLP(torch.nn.Module):
    """Multi-layer perceptron: a stack of Linear layers with ReLU in between.

    The layer sequence is ``Linear -> ReLU -> ... -> ReLU -> Linear``,
    optionally followed by ``LayerNorm(output_size)``.

    Args:
        input_size: number of input features of the first Linear layer.
        hidden_size: width of every intermediate Linear layer.
        output_size: number of output features of the last Linear layer.
        layers: number of Linear layers; must be at least 1.
        layernorm: if True, append ``LayerNorm(output_size)`` after the last
            Linear layer.

    Raises:
        ValueError: if ``layers`` is less than 1.
    """

    def __init__(self, input_size, hidden_size, output_size, layers, layernorm=True):
        super().__init__()
        if layers < 1:
            # With no Linear layer the module would be an identity (or a
            # LayerNorm(output_size) applied to input_size-wide inputs), which
            # is never what a caller asking for an MLP wants.
            raise ValueError(f"layers must be >= 1, got {layers}")
        self.layers = torch.nn.ModuleList()
        for i in range(layers):
            is_first = i == 0
            is_last = i == layers - 1
            self.layers.append(torch.nn.Linear(
                input_size if is_first else hidden_size,
                output_size if is_last else hidden_size,
            ))
            # No activation after the final Linear layer.
            if not is_last:
                self.layers.append(torch.nn.ReLU())
        if layernorm:
            self.layers.append(torch.nn.LayerNorm(output_size))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every Linear layer in the stack.

        Weights are drawn from a normal distribution with mean 0 and standard
        deviation ``1 / sqrt(in_features)``; biases are set to zero.
        LayerNorm parameters are left untouched.
        """
        for layer in self.layers:
            if isinstance(layer, torch.nn.Linear):
                std = 1 / math.sqrt(layer.in_features)
                # torch.nn.init runs under no_grad, avoiding direct `.data` mutation.
                torch.nn.init.normal_(layer.weight, mean=0.0, std=std)
                if layer.bias is not None:
                    torch.nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Apply the layers in order; features are read from the last dimension of ``x``."""
        for layer in self.layers:
            x = layer(x)
        return x
class InteractionNetwork(pyg.nn.MessagePassing):
    """One residual message-passing step that updates node and edge features.

    For each edge, ``lin_edge`` maps ``(x_i, x_j, edge_feature)`` to a message.
    Messages are summed per receiving node, and ``lin_node`` maps
    ``(x, summed messages)`` to a node update. The messages (as new edge
    features) and the node update are each added residually to their inputs.

    Args:
        hidden_size: width of the node and edge features, both in and out.
        layers: number of Linear layers in each of ``lin_edge`` and ``lin_node``.
    """

    def __init__(self, hidden_size, layers):
        super().__init__()
        # Both MLPs use the requested depth rather than a fixed value.
        self.lin_edge = MLP(hidden_size * 3, hidden_size, hidden_size, layers)
        self.lin_node = MLP(hidden_size * 2, hidden_size, hidden_size, layers)

    def forward(self, x, edge_index, edge_feature):
        """Run one interaction step.

        Args:
            x: node features.
            edge_index: graph connectivity passed to ``propagate``.
            edge_feature: per-edge features.

        Returns:
            ``(node_out, edge_out)``: residually updated node and edge features.
        """
        # `aggregate` returns (per-edge messages, per-node sums). This class does
        # not override `update`, so MessagePassing's default hands that tuple
        # back from `propagate`, exposing the edge messages as well as the sums.
        edge_out, aggr = self.propagate(edge_index, x=(x, x), edge_feature=edge_feature)
        node_out = self.lin_node(torch.cat((x, aggr), dim=-1))
        edge_out = edge_feature + edge_out
        node_out = x + node_out
        return node_out, edge_out

    def message(self, x_i, x_j, edge_feature):
        """Compute one message per edge from both endpoint features and the edge feature."""
        # Under PyG's default "source_to_target" flow, x_i is the edge's target
        # node and x_j its source node.
        x = torch.cat((x_i, x_j, edge_feature), dim=-1)
        return self.lin_edge(x)

    def aggregate(self, inputs, index, dim_size=None):
        """Sum messages per node, returning ``(inputs, sums)`` instead of just the sums."""
        out = torch_scatter.scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce="sum")
        # Keep the raw messages so forward can use them as the new edge features.
        return (inputs, out)
class LearnedSimulator(torch.nn.Module):
    """Encode-process-decode graph network over particles.

    ``forward`` runs three stages:
      1. encode: embed each particle's type, concatenate it with ``data.pos``
         and map through ``node_in``; map ``data.edge_attr`` through ``edge_in``;
      2. process: ``n_mp_layers`` residual InteractionNetwork steps;
      3. decode: ``node_out`` maps each node's latent to ``dim`` outputs.

    Args:
        hidden_size: latent width of node and edge features.
        n_mp_layers: number of message-passing (InteractionNetwork) steps.
        num_particle_types: number of rows in the particle-type embedding.
        particle_type_dim: width of each particle-type embedding.
        dim: spatial dimensionality; also the number of outputs per node.
        window_size: together with ``dim``, sets the width of the per-node
            input taken from ``data.pos``; also stored as ``self.window_size``.
    """

    def __init__(
        self,
        hidden_size=128,
        n_mp_layers=10,
        num_particle_types=NUM_PARTICLE_TYPES,
        particle_type_dim=16,
        dim=2,
        window_size=5,
    ):
        super().__init__()
        self.window_size = window_size
        self.embed_type = torch.nn.Embedding(num_particle_types, particle_type_dim)
        # Node input = type embedding + dim * (window_size + 2) values from
        # data.pos. NOTE(review): presumably window_size velocity vectors plus
        # 2 * dim boundary distances; confirm against the dataset builder.
        self.node_in = MLP(particle_type_dim + dim * (window_size + 2), hidden_size, hidden_size, 3)
        # Edge input has dim + 1 features. NOTE(review): presumably relative
        # displacement plus its length; confirm against the dataset builder.
        self.edge_in = MLP(dim + 1, hidden_size, hidden_size, 3)
        # Decoder has no trailing LayerNorm, so its outputs are unnormalized.
        self.node_out = MLP(hidden_size, hidden_size, dim, 3, layernorm=False)
        self.n_mp_layers = n_mp_layers
        self.layers = torch.nn.ModuleList(
            [InteractionNetwork(hidden_size, 3) for _ in range(n_mp_layers)]
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the particle-type embedding with Xavier-uniform.

        Only the embedding is touched here; the MLPs initialize their own
        Linear layers when they are constructed.
        """
        torch.nn.init.xavier_uniform_(self.embed_type.weight)

    def forward(self, data):
        """Predict ``dim`` values per node for the graph ``data``.

        Reads ``data.x`` (indices into the particle-type embedding),
        ``data.pos``, ``data.edge_index`` and ``data.edge_attr``.
        NOTE(review): ``data`` is presumably a torch_geometric Data/Batch; confirm.
        """
        node_feature = torch.cat((self.embed_type(data.x), data.pos), dim=-1)
        node_feature = self.node_in(node_feature)
        edge_feature = self.edge_in(data.edge_attr)
        for interaction in self.layers:
            node_feature, edge_feature = interaction(
                node_feature, data.edge_index, edge_feature=edge_feature
            )
        return self.node_out(node_feature)