-
Notifications
You must be signed in to change notification settings - Fork 0
/
tsp_tester.py
68 lines (47 loc) · 1.73 KB
/
tsp_tester.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import or_gym
import torch
from util import VisualData, visualization, make_pointer_network, create_folder
from config import args_parser
from gym_util import play_tsp
# Prefer a CUDA GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
def test(actor, actor_dir, seq_len, result_dir):
    """Evaluate a trained actor on one freshly generated TSP instance.

    Builds an or_gym ``TSP-v1`` environment with ``seq_len`` nodes, loads the
    actor weights from ``actor_dir``, plays one episode with ``play_tsp`` and
    writes a visualization of the resulting tour into ``result_dir``.

    Args:
        actor: network module to evaluate; after a rollout, ``actor.result()``
            must return a ``(log_probs, actions)`` pair.
        actor_dir: path of the saved ``state_dict`` checkpoint.
        seq_len: number of cities, passed to the env config as ``N``.
        result_dir: directory the visualization is written to.
    """
    # env setup
    env_config = {'N': seq_len}
    env = or_gym.make('TSP-v1', env_config=env_config)

    # load actor; map_location remaps tensors stored on another device (e.g. a
    # checkpoint trained on GPU) onto the current one instead of failing.
    actor.load_state_dict(torch.load(actor_dir, map_location=device))
    actor.eval()

    visual_data = VisualData()
    # NOTE(review): assumes env.coords is 2-D (N, 2); transpose + unsqueeze then
    # yields (1, 2, N) — confirm against what play_tsp expects.
    coords = torch.FloatTensor(env.coords).transpose(1, 0).unsqueeze(0)

    # Evaluation only: skip building an autograd graph for the rollout.
    with torch.no_grad():
        total_reward = play_tsp(env, coords, actor, device)
        _, actions = actor.result()

    visual_data.add(coords, actions, "test")
    c, a, e = visual_data.get()
    visualization(result_dir, c, a, e)
    print('total length', total_reward)
def main():
    """Parse command-line arguments, build the pointer network and evaluate it.

    Reads paths and pointer-network hyper-parameters from ``args_parser()``,
    echoes them, makes sure ``result_dir`` exists, builds the network with
    ``make_pointer_network`` and passes it to ``test``.
    """
    args = args_parser()
    seq_len = args.seq_len
    result_dir = args.result_dir
    actor_dir = args.actor_dir
    create_folder(result_dir)

    # Pointer network hyper parameter
    embedding_size = args.embedding_size
    hidden_size = args.hidden_size
    n_glimpses = args.n_glimpses
    tanh_exploration = args.tanh_exploration

    print("args: ")
    print("embedding size: %d" % embedding_size)
    print("hidden size: %d" % hidden_size)
    print("num glimpses: %d" % n_glimpses)
    # %s rather than %d: %d would truncate a fractional exploration constant
    # (2.5 -> 2) and raise TypeError if the value is None.
    print("tanh exploration: %s" % tanh_exploration)
    print("")
    print("sequence length: %d" % seq_len)
    print("result dir: %s" % result_dir)
    print("actor dir: %s" % actor_dir)

    ptr_net = make_pointer_network(embedding_size, hidden_size, n_glimpses, tanh_exploration, seq_len, device)
    test(ptr_net, actor_dir, seq_len, result_dir)
if __name__ == "__main__":
    # Script entry point: run the evaluation, then print a completion marker.
    # Both calls live inside the guard so importing this module has no effects.
    main()
    print('end tsp')