diff --git a/pyproject.toml b/pyproject.toml index e2c4227a..bda21b75 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,7 +86,7 @@ pytest-cov = { version = "*", optional = true } torch_geometric = { version = "*", optional = true } # Routing numba = { version = ">=0.58.1", optional = true } -pyvrp = { version = ">=0.8.2", optional = true, python = "<4.0" } +pyvrp = { version = ">=0.9.0", optional = true, python = "<4.0" } # Docs mkdocs = { version = "*", optional = true } mkdocs-material = { version = "*", optional = true } diff --git a/rl4co/envs/routing/cvrp/local_search.py b/rl4co/envs/routing/cvrp/local_search.py index cdb45f4e..73deb7d4 100644 --- a/rl4co/envs/routing/cvrp/local_search.py +++ b/rl4co/envs/routing/cvrp/local_search.py @@ -164,8 +164,8 @@ def make_data( name=",".join(map(str, range(1, len(positions)))), ) ], - distance_matrix=distances, - duration_matrix=np.zeros_like(distances), + distance_matrices=[distances], + duration_matrices=[np.zeros_like(distances)], ) diff --git a/rl4co/envs/routing/mtvrp/baselines/pyvrp.py b/rl4co/envs/routing/mtvrp/baselines/pyvrp.py index 6f57b9e5..1e889323 100644 --- a/rl4co/envs/routing/mtvrp/baselines/pyvrp.py +++ b/rl4co/envs/routing/mtvrp/baselines/pyvrp.py @@ -57,8 +57,6 @@ def instance2data(instance: TensorDict, scaling_factor: int) -> ProblemData: depot = Depot( x=coords[0][0], y=coords[0][1], - tw_early=time_windows[0][0], - tw_late=time_windows[0][1], ) clients = [ @@ -78,6 +76,8 @@ def instance2data(instance: TensorDict, scaling_factor: int) -> ProblemData: num_available=num_locs - 1, # one vehicle per client capacity=capacity, max_distance=max_distance, + tw_early=time_windows[0][0], + tw_late=time_windows[0][1], ) matrix = scale(instance["cost_matrix"], scaling_factor) @@ -99,7 +99,7 @@ def instance2data(instance: TensorDict, scaling_factor: int) -> ProblemData: # matrix[0, backhaul] = MAX_VALUE matrix[np.ix_(backhaul, linehaul)] = MAX_VALUE - return ProblemData(clients, [depot], [vehicle_type], matrix, matrix) + return ProblemData(clients, [depot], [vehicle_type], [matrix], [matrix]) def solution2action(solution: pyvrp.Solution) -> list[int]: diff --git a/rl4co/envs/routing/mtvrp/baselines/solve.py b/rl4co/envs/routing/mtvrp/baselines/solve.py index 6cd5bbc8..f750fe6d 100644 --- a/rl4co/envs/routing/mtvrp/baselines/solve.py +++ b/rl4co/envs/routing/mtvrp/baselines/solve.py @@ -13,15 +13,15 @@ def solve(self, *args, **kwargs): try: - import routefinder.baselines.pyvrp as pyvrp + import rl4co.envs.routing.mtvrp.baselines.pyvrp as pyvrp except ImportError: pyvrp = NoSolver() try: - import routefinder.baselines.lkh as lkh + import rl4co.envs.routing.mtvrp.baselines.lkh as lkh except ImportError: lkh = NoSolver() try: - import routefinder.baselines.ortools as ortools + import rl4co.envs.routing.mtvrp.baselines.ortools as ortools except ImportError: ortools = NoSolver()