Skip to content

Commit

Permalink
[BugFix] Local Search with PyVRP; bump up version #212
Browse files Browse the repository at this point in the history
  • Loading branch information
fedebotu committed Sep 4, 2024
1 parent 71992b4 commit aeb786c
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ pytest-cov = { version = "*", optional = true }
torch_geometric = { version = "*", optional = true }
# Routing
numba = { version = ">=0.58.1", optional = true }
pyvrp = { version = ">=0.8.2", optional = true, python = "<4.0" }
pyvrp = { version = ">=0.9.0", optional = true, python = "<4.0" }
# Docs
mkdocs = { version = "*", optional = true }
mkdocs-material = { version = "*", optional = true }
Expand Down
4 changes: 2 additions & 2 deletions rl4co/envs/routing/cvrp/local_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ def make_data(
name=",".join(map(str, range(1, len(positions)))),
)
],
distance_matrix=distances,
duration_matrix=np.zeros_like(distances),
distance_matrices=[distances],
duration_matrices=[np.zeros_like(distances)],
)


Expand Down
6 changes: 3 additions & 3 deletions rl4co/envs/routing/mtvrp/baselines/pyvrp.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ def instance2data(instance: TensorDict, scaling_factor: int) -> ProblemData:
depot = Depot(
x=coords[0][0],
y=coords[0][1],
tw_early=time_windows[0][0],
tw_late=time_windows[0][1],
)

clients = [
Expand All @@ -78,6 +76,8 @@ def instance2data(instance: TensorDict, scaling_factor: int) -> ProblemData:
num_available=num_locs - 1, # one vehicle per client
capacity=capacity,
max_distance=max_distance,
tw_early=time_windows[0][0],
tw_late=time_windows[0][1],
)

matrix = scale(instance["cost_matrix"], scaling_factor)
Expand All @@ -99,7 +99,7 @@ def instance2data(instance: TensorDict, scaling_factor: int) -> ProblemData:
# matrix[0, backhaul] = MAX_VALUE
matrix[np.ix_(backhaul, linehaul)] = MAX_VALUE

return ProblemData(clients, [depot], [vehicle_type], matrix, matrix)
return ProblemData(clients, [depot], [vehicle_type], [matrix], [matrix])


def solution2action(solution: pyvrp.Solution) -> list[int]:
Expand Down
6 changes: 3 additions & 3 deletions rl4co/envs/routing/mtvrp/baselines/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ def solve(self, *args, **kwargs):


try:
import routefinder.baselines.pyvrp as pyvrp
import rl4co.envs.routing.mtvrp.baselines.pyvrp as pyvrp
except ImportError:
pyvrp = NoSolver()
try:
import routefinder.baselines.lkh as lkh
import rl4co.envs.routing.mtvrp.baselines.lkh as lkh
except ImportError:
lkh = NoSolver()
try:
import routefinder.baselines.ortools as ortools
import rl4co.envs.routing.mtvrp.baselines.ortools as ortools
except ImportError:
ortools = NoSolver()

Expand Down

0 comments on commit aeb786c

Please sign in to comment.