-
Notifications
You must be signed in to change notification settings - Fork 122
/
Copy pathmodel.py
251 lines (190 loc) · 10.3 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
import torch
import torch.nn as nn
import torch.nn.functional as F
# Default placement for tensors created in this module: the GPU when CUDA is
# available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class Encoder(nn.Module):
    """Embed per-node features with a pointwise (kernel_size=1) 1-D convolution.

    Used for the static and dynamic node states, and for the decoder input.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # kernel_size=1 maps each node's feature vector independently.
        self.conv = nn.Conv1d(input_size, hidden_size, kernel_size=1)

    def forward(self, input):
        """Map (batch, input_size, seq_len) to (batch, hidden_size, seq_len)."""
        return self.conv(input)
class Attention(nn.Module):
    """Additive attention over the input nodes, conditioned on the decoder state.

    Scores are ``v . tanh(W [static; dynamic; decoder])``, computed per node
    and normalised with a softmax over the node axis.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # v: (1, 1, H) scoring vector. W: (1, H, 3H) projection of the
        # concatenated static, dynamic and decoder features. Both start at
        # zero; callers are expected to re-initialise them.
        self.v = nn.Parameter(torch.zeros((1, 1, hidden_size),
                                          device=device, requires_grad=True))
        self.W = nn.Parameter(torch.zeros((1, hidden_size, 3 * hidden_size),
                                          device=device, requires_grad=True))

    def forward(self, static_hidden, dynamic_hidden, decoder_hidden):
        """Return attention weights of shape (batch, 1, seq_len).

        ``static_hidden`` is unpacked as (batch, hidden, seq_len).
        ``dynamic_hidden`` must be concatenable with it along dim 1.
        ``decoder_hidden`` is (batch, hidden) and is broadcast across nodes.
        """
        batch_size, hidden_size, _ = static_hidden.size()

        # Repeat the decoder state for every node so it can be concatenated.
        query = decoder_hidden.unsqueeze(2).expand_as(static_hidden)
        features = torch.cat((static_hidden, dynamic_hidden, query), dim=1)

        # Broadcast the shared parameters over the batch for bmm.
        batched_v = self.v.expand(batch_size, 1, hidden_size)
        batched_W = self.W.expand(batch_size, hidden_size, -1)

        projected = torch.tanh(torch.bmm(batched_W, features))
        scores = torch.bmm(batched_v, projected)  # (batch, 1, seq_len)
        return F.softmax(scores, dim=2)
class Pointer(nn.Module):
    """Score every input node as the next pick, given the previous decoder output.

    One call runs a single GRU step, attends over the encoded nodes to build a
    context vector, and returns unnormalised pointer scores per node.
    """

    def __init__(self, hidden_size, num_layers=1, dropout=0.2):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Pointer scoring parameters: v is (1, 1, H); W is (1, H, 2H) over
        # [static; context]. Registration order matters for state_dict and
        # for any init loop over parameters(), so keep it as is.
        self.v = nn.Parameter(torch.zeros((1, 1, hidden_size),
                                          device=device, requires_grad=True))
        self.W = nn.Parameter(torch.zeros((1, hidden_size, 2 * hidden_size),
                                          device=device, requires_grad=True))

        # nn.GRU only supports inter-layer dropout with more than one layer.
        gru_dropout = dropout if num_layers > 1 else 0
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers,
                          batch_first=True, dropout=gru_dropout)
        self.encoder_attn = Attention(hidden_size)

        self.drop_rnn = nn.Dropout(p=dropout)
        self.drop_hh = nn.Dropout(p=dropout)

    def forward(self, static_hidden, dynamic_hidden, decoder_hidden, last_hh):
        """Return ``(scores, last_hh)``.

        ``decoder_hidden`` is fed to the GRU after ``transpose(2, 1)``, so it
        is treated as (batch, hidden, 1). ``scores`` has one unnormalised
        value per node. ``last_hh`` is the updated GRU hidden state.
        """
        batch = static_hidden.size(0)

        # A single GRU step over a length-1 sequence.
        step_in = decoder_hidden.transpose(2, 1)
        gru_out, last_hh = self.gru(step_in, last_hh)
        gru_out = self.drop_rnn(gru_out.squeeze(1))

        # The GRU's own dropout is disabled for a single layer (see
        # __init__), so regularise the carried hidden state here instead.
        if self.num_layers == 1:
            last_hh = self.drop_hh(last_hh)

        # Attention weights -> weighted sum of static node embeddings.
        weights = self.encoder_attn(static_hidden, dynamic_hidden, gru_out)
        context = torch.bmm(weights, static_hidden.permute(0, 2, 1))

        # Broadcast the context over every node and score [static; context].
        context = context.transpose(1, 2).expand_as(static_hidden)
        energy = torch.cat((static_hidden, context), dim=1)

        batched_v = self.v.expand(batch, -1, -1)
        batched_W = self.W.expand(batch, -1, -1)
        scores = torch.bmm(batched_v, torch.tanh(torch.bmm(batched_W, energy)))
        return scores.squeeze(1), last_hh
class DRL4TSP(nn.Module):
    """Defines the main Encoder, Decoder, and Pointer combinatorial models.

    Parameters
    ----------
    static_size: int
        Defines how many features are in the static elements of the model
        (e.g. 2 for (x, y) coordinates).
    dynamic_size: int >= 1
        Defines how many features are in the dynamic elements of the model
        (e.g. 2 for the VRP, which has (load, demand) attributes). The TSP
        doesn't have dynamic elements, but to ensure compatibility with other
        optimization problems, we just pass in a tensor of zeros.
    hidden_size: int
        Defines the number of units in the hidden layer for all static,
        dynamic, and decoder output units.
    update_fn: callable or None
        If provided, called as ``update_fn(dynamic, chosen_idx)`` after each
        'point' to an input element. Its return value replaces ``dynamic``
        and is re-encoded.
    mask_fn: callable or None
        If provided, called as ``mask_fn(mask, dynamic, chosen_idx)`` after
        each step to decide which input elements may be selected next (a
        mask value of 0 forbids a node). This speeds up training by giving
        the network 'rules' to follow. Without a mask function the mask stays
        all ones, and decoding stops after ``num_cities`` steps so tours
        cannot stretch forever.
    num_layers: int
        Number of stacked layers in the decoder GRU.
    dropout: float
        Dropout rate used in the decoder.
    """

    def __init__(self, static_size, dynamic_size, hidden_size,
                 update_fn=None, mask_fn=None, num_layers=1, dropout=0.):
        super().__init__()

        if dynamic_size < 1:
            raise ValueError(':param dynamic_size: must be > 0, even if the '
                             'problem has no dynamic elements')

        self.update_fn = update_fn
        self.mask_fn = mask_fn

        # Define the encoder & decoder models. Construction order determines
        # the order of self.parameters() used by the init loop below.
        self.static_encoder = Encoder(static_size, hidden_size)
        self.dynamic_encoder = Encoder(dynamic_size, hidden_size)
        self.decoder = Encoder(static_size, hidden_size)
        self.pointer = Pointer(hidden_size, num_layers, dropout)

        # Xavier-initialise every parameter with more than one dimension. This
        # also overwrites the zero-initialised v/W tensors of the pointer and
        # attention modules.
        for p in self.parameters():
            if len(p.shape) > 1:
                nn.init.xavier_uniform_(p)

        # Used as a proxy initial decoder input when none is specified. It is
        # a plain attribute (not a parameter/buffer): it is absent from
        # state_dict() and is not moved by .to(), so forward() moves it to the
        # inputs' device instead.
        self.x0 = torch.zeros((1, static_size, 1), requires_grad=True, device=device)

    def forward(self, static, dynamic, decoder_input=None, last_hh=None):
        """Decode one tour per batch element.

        Parameters
        ----------
        static: Tensor of size (batch_size, static_size, num_cities)
            Elements that do not change while decoding, e.g. the (x, y)
            coordinates of each city.
        dynamic: Tensor of size (batch_size, dynamic_size, num_cities)
            Elements that may change while decoding (through ``update_fn``),
            e.g. the (load, demand) of each city for the VRP. It is always
            passed through the dynamic encoder, so problems without dynamic
            elements must pass a tensor of zeros rather than None.
        decoder_input: Tensor of size (batch_size, static_size, 1), optional
            Decoder input for the first step. Defaults to ``self.x0``
            broadcast over the batch. Later steps use the static features of
            the previously chosen node.
        last_hh: Tensor of size (num_layers, batch_size, hidden_size), optional
            Initial GRU hidden state. None lets nn.GRU start from zeros.

        Returns
        -------
        tour_idx: LongTensor of size (batch_size, num_steps)
            Index of the node chosen at each step. num_steps is at most
            num_cities without a mask_fn, and at most 1000 with one.
        tour_logp: Tensor of size (batch_size, num_steps)
            Log-probability of each choice. When update_fn is given, it is
            zeroed for rows flagged as done.
        """
        batch_size, input_size, sequence_size = static.size()

        # Follow the inputs' device rather than the module-level ``device``,
        # so the model also works after being moved with ``.to(...)``.
        run_device = static.device

        if decoder_input is None:
            # .to() is a no-op (returns x0 itself) when already on run_device.
            decoder_input = self.x0.to(run_device).expand(batch_size, -1, -1)

        # Always use a mask - if no mask_fn is provided, it is never updated
        mask = torch.ones(batch_size, sequence_size, device=run_device)

        # Structures for holding the output sequences
        tour_idx, tour_logp = [], []
        max_steps = sequence_size if self.mask_fn is None else 1000

        # Static elements only need to be processed once, and can be used across
        # all 'pointing' iterations. When / if the dynamic elements change,
        # their representations will need to get calculated again.
        static_hidden = self.static_encoder(static)
        dynamic_hidden = self.dynamic_encoder(dynamic)

        for _ in range(max_steps):

            # Stop once no node is selectable in any batch row.
            if not mask.byte().any():
                break

            # ... but compute a hidden rep for each element added to sequence
            decoder_hidden = self.decoder(decoder_input)

            probs, last_hh = self.pointer(static_hidden,
                                          dynamic_hidden,
                                          decoder_hidden, last_hh)
            # mask.log() is -inf where mask == 0, removing those nodes.
            probs = F.softmax(probs + mask.log(), dim=1)

            ptr, logp = self._select_next(probs, mask)
            chosen = ptr.detach()

            # After visiting a node update the dynamic representation
            if self.update_fn is not None:
                dynamic = self.update_fn(dynamic, chosen)
                dynamic_hidden = self.dynamic_encoder(dynamic)

                # Since we compute the VRP in minibatches, some tours may have
                # fewer stops than others. Rows whose dynamic[:, 1] sums to
                # zero are treated as finished and get logp := 0. Presumably
                # mask_fn keeps those vehicles at the depot.
                # NOTE(review): assumes channel 1 of dynamic holds per-node
                # demand (VRP layout) - confirm for other problems.
                is_done = dynamic[:, 1].sum(1).eq(0).float()
                logp = logp * (1. - is_done)

            # And update the mask so we don't re-visit if we don't need to
            if self.mask_fn is not None:
                mask = self.mask_fn(mask, dynamic, chosen).detach()

            tour_logp.append(logp.unsqueeze(1))
            tour_idx.append(chosen.unsqueeze(1))

            # The next decoder input is the static features of the chosen node.
            decoder_input = torch.gather(
                static, 2, ptr.view(-1, 1, 1).expand(-1, input_size, 1)
            ).detach()

        tour_idx = torch.cat(tour_idx, dim=1)  # (batch_size, num_steps)
        tour_logp = torch.cat(tour_logp, dim=1)  # (batch_size, num_steps)

        return tour_idx, tour_logp

    def _select_next(self, probs, mask):
        """Choose one node per batch row from the masked distribution ``probs``.

        While training, the node is sampled; otherwise the most probable node
        is taken greedily. Returns ``(ptr, logp)``: the chosen indices and
        their log-probabilities, one entry per batch row.
        """
        if self.training:
            m = torch.distributions.Categorical(probs)

            # Sometimes an issue with Categorical & sampling on GPU; See:
            # https://github.com/pemami4911/neural-combinatorial-rl-pytorch/issues/5
            # Resample until every row picks a node that the mask allows.
            ptr = m.sample()
            while not torch.gather(mask, 1, ptr.detach().unsqueeze(1)).byte().all():
                ptr = m.sample()
            logp = m.log_prob(ptr)
        else:
            prob, ptr = torch.max(probs, 1)  # Greedy
            logp = prob.log()
        return ptr, logp
if __name__ == "__main__":
    # This module only defines model components; refuse to run it as a script.
    raise Exception("Cannot be called from main")