-
Notifications
You must be signed in to change notification settings - Fork 30
/
search.py
110 lines (86 loc) · 4.2 KB
/
search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import logging
import numpy as np
import os
import time
import torch
import search_single
from vrp.data_utils import read_instances_pkl
import glob
import search_batch
from actor import VrpActorModel
class LnsOperatorPair:
    """Pair a repair model with the destroy operator stored alongside it.

    Attributes:
        model: The repair model (``load_operator_pairs`` stores a
            ``VrpActorModel`` here).
        destroy_procedure: Destroy operator code; ``destroy_instances``
            understands ``"R"``, ``"P"`` and ``"T"``.
        p_destruction: Destruction parameter passed to the destroy operator.
    """

    def __init__(self, model, destroy_procedure, p_destruction):
        self.model, self.destroy_procedure, self.p_destruction = (
            model,
            destroy_procedure,
            p_destruction,
        )
def destroy_instances(instances, destroy_procedure=None, destruction_p=None):
    """Apply a destroy operator in place to every instance.

    Args:
        instances: Iterable of instance objects exposing ``destroy_random``,
            ``destroy_point_based`` and ``destroy_tour_based``.
        destroy_procedure: ``"R"`` (random), ``"P"`` (point based), ``"T"``
            (tour based), or ``None`` to leave the instances untouched.
        destruction_p: Value forwarded unchanged to the chosen destroy method.

    Raises:
        ValueError: If ``destroy_procedure`` is neither ``None`` nor one of the
            supported codes, so a mistyped code cannot silently skip the
            destruction step.
    """
    if destroy_procedure is None:
        return
    method_names = {
        "R": "destroy_random",
        "P": "destroy_point_based",
        "T": "destroy_tour_based",
    }
    method_name = method_names.get(destroy_procedure)
    if method_name is None:
        raise ValueError(f"Unknown destroy procedure: {destroy_procedure!r}")
    for instance in instances:
        getattr(instance, method_name)(destruction_p)
def load_operator_pairs(path, config):
    """Load trained LNS operator pairs from a checkpoint file or directory.

    Args:
        path: Either a single ``.pt`` checkpoint, or a directory whose
            ``*.pt`` files are all loaded.
        config: Provides ``device`` (used as ``map_location`` and as the model
            device) and ``pointer_hidden_size`` (the actor's hidden size).

    Returns:
        A list of ``LnsOperatorPair``, one per checkpoint. When ``path`` is a
        directory the list follows sorted file-path order.

    Raises:
        FileNotFoundError: If ``path`` is not a ``.pt`` file and no ``*.pt``
            files are found under it.
    """
    if path.endswith('.pt'):
        model_paths = [path]
    else:
        # glob order is filesystem-dependent; sort so the operator list (and
        # anything that indexes into it) is reproducible across machines.
        model_paths = sorted(glob.glob(os.path.join(path, '*.pt')))
    if not model_paths:
        # FileNotFoundError subclasses Exception, so existing handlers still match.
        raise FileNotFoundError(f"No operators found in {path}")

    lns_operator_pairs = []
    for model_path in model_paths:
        # Checkpoint dict is read for the keys 'parameters',
        # 'destroy_operation' and 'p_destruction'.
        model_data = torch.load(model_path, map_location=config.device)
        actor = VrpActorModel(config.device, hidden_size=config.pointer_hidden_size).to(config.device)
        actor.load_state_dict(model_data['parameters'])
        actor.eval()  # models are used for search only: switch to evaluation mode
        operator_pair = LnsOperatorPair(actor, model_data['destroy_operation'], model_data['p_destruction'])
        lns_operator_pairs.append(operator_pair)
    return lns_operator_pairs
def evaluate_batch_search(config, model_path):
    """Run the multiprocess batch LNS search and save per-instance costs.

    Writes ``instance_id,cost`` rows to
    ``<config.output_path>/search/results.txt`` and logs the mean cost, total
    runtime and mean iteration count.

    Args:
        config: Run configuration; ``output_path`` is read here and the whole
            object is forwarded to ``search_batch.lns_batch_search_mp``.
        model_path: Path to the trained operator(s); must not be ``None``.
    """
    assert model_path is not None, 'No model path given'
    logging.info('### Batch Search ###')
    logging.info('Starting search')
    start_time = time.time()
    results = search_batch.lns_batch_search_mp(config, model_path)
    runtime = time.time() - start_time

    # Each result is used as (batch index, per-instance costs, iterations).
    # NOTE(review): assumes batch indices enumerate consecutive slices of the
    # instance set - confirm against search_batch.
    # A batch's first instance id is the total size of all lower-indexed
    # batches. This equals len(costs) * batch_index when every batch has the
    # same size, and stays correct when batches differ (e.g. a smaller last one).
    batch_sizes = {r[0]: len(r[1]) for r in results}
    first_ids, next_id = {}, 0
    for batch_index in sorted(batch_sizes):
        first_ids[batch_index] = next_id
        next_id += batch_sizes[batch_index]

    instance_id, costs, iterations = [], [], []
    for r in results:
        start = first_ids[r[0]]
        instance_id.extend(range(start, start + len(r[1])))
        costs.extend(r[1])
        iterations.append(r[2])

    output_dir = os.path.join(config.output_path, "search")
    os.makedirs(output_dir, exist_ok=True)
    path = os.path.join(output_dir, 'results.txt')
    np.savetxt(path, np.column_stack((instance_id, costs)), delimiter=',', fmt=['%i', '%f'])
    logging.info(
        f"Test set costs: {np.mean(costs):.3f} Total Runtime (s): {runtime:.1f} Iterations: {np.mean(iterations):.1f}")
def evaluate_single_search(config, model_path, instance_path):
    """Run single-instance LNS search on a file, a ``.pkl`` set, or a directory.

    Every selected instance is solved ``config.nb_runs`` times. One
    ``name,cost,runtime`` row per run is written to
    ``<config.output_path>/search/results.txt`` and summary statistics are
    logged. If no run is performed, a header-only file is written instead.

    Args:
        config: Run configuration; reads ``nb_runs``, ``lns_timelimit`` and
            ``output_path`` and forwards the object to the search.
        model_path: Path to the trained operator(s); must not be ``None``.
        instance_path: A ``.vrp``/``.sd`` file, a ``.pkl`` instance set, or a
            directory; inside a directory only ``.pkl``/``.vrp``/``.sd`` files
            are solved.

    Raises:
        Exception: If ``instance_path`` matches none of the supported forms.
    """
    assert model_path is not None, 'No model path given'
    assert instance_path is not None, 'No instance path given'
    supported_suffixes = (".pkl", ".vrp", ".sd")
    instance_names, costs, durations = [], [], []
    logging.info("### Single instance search ###")
    if instance_path.endswith((".vrp", ".sd")):
        logging.info("Starting solving a single instance")
        instance_files_path = [instance_path]
    elif instance_path.endswith(".pkl"):
        # One entry per instance in the set; the enumerate index passed below is
        # presumably used by search_single to pick the instance - confirm.
        instance_files_path = [instance_path] * len(read_instances_pkl(instance_path))
        logging.info("Starting solving a .pkl instance set")
    elif os.path.isdir(instance_path):
        # Sorted so processing order (and the index passed per file) does not
        # depend on the filesystem's listing order.
        instance_files_path = [os.path.join(instance_path, f) for f in sorted(os.listdir(instance_path))]
        logging.info("Starting solving all instances in directory")
    else:
        raise Exception("Unknown instance file format.")

    # Distinct loop name so the ``instance_path`` argument is not shadowed.
    for i, file_path in enumerate(instance_files_path):
        if not file_path.endswith(supported_suffixes):
            continue  # skip non-instance files found in a directory
        for _ in range(config.nb_runs):
            cost, duration = search_single.lns_single_search_mp(file_path, config.lns_timelimit, config,
                                                                model_path, i)
            instance_names.append(file_path)
            costs.append(cost)
            durations.append(duration)

    output_dir = os.path.join(config.output_path, "search")
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, 'results.txt')
    if not costs:
        # An empty zip would give a 1-D array that np.savetxt rejects for a
        # 3-column fmt; write a header-only file and skip the (NaN) summary.
        np.savetxt(output_path, np.empty((0, 3), dtype=str), delimiter=',', fmt=['%s', '%s', '%s'],
                   header="name, cost, runtime")
        logging.warning("NLNS single search evaluation: no runs were performed.")
        return
    results = np.array(list(zip(instance_names, costs, durations)))
    np.savetxt(output_path, results, delimiter=',', fmt=['%s', '%s', '%s'], header="name, cost, runtime")
    logging.info(
        f"NLNS single search evaluation results: Total Nb. Runs: {len(costs)}, "
        f"Mean Costs: {np.mean(costs):.3f} Mean Runtime (s): {np.mean(durations):.1f}")