diff --git a/rl4co/envs/routing/mtvrp/env.py b/rl4co/envs/routing/mtvrp/env.py
index ce9d34a7..fa74f1f2 100644
--- a/rl4co/envs/routing/mtvrp/env.py
+++ b/rl4co/envs/routing/mtvrp/env.py
@@ -349,9 +349,15 @@ def check_solution_validity(td: TensorDict, actions: torch.Tensor):
         curr_time = torch.max(
             curr_time + dist, gather_by_index(td["time_windows"], next_node)[..., 0]
         )
+
+        # Open routes do not have to return to the depot, so the final "visit"
+        # to node 0 is bookkeeping only and is exempt from the deadline check.
+        skip_open_end = td["open_route"].view_as(curr_time) & (next_node == 0).view_as(curr_time)
+
         assert torch.all(
-            curr_time <= gather_by_index(td["time_windows"], next_node)[..., 1]
-        ), "vehicle cannot start service before deadline"
+            (curr_time <= gather_by_index(td["time_windows"], next_node)[..., 1]) | skip_open_end
+        ), "vehicle cannot start service after deadline"
+
         curr_time = curr_time + gather_by_index(td["service_time"], next_node)
         curr_node = next_node
         curr_time[curr_node == 0] = 0.0  # reset time for depot
@@ -450,7 +456,7 @@ def _make_spec(self, td_params: TensorDict):
 def check_variants(td):
     """Check if the problem has the variants"""
     has_open = td["open_route"].squeeze(-1)
-    has_tw = (td["time_windows"][:, :, 1] != float("inf")).any(-1)
+    has_tw = (td["time_windows"][:, :, 1] != 4.6).any(-1)  # 4.6 is the finite "no deadline" sentinel written by the generator
     has_limit = (td["distance_limit"] != float("inf")).squeeze(-1)
     has_backhaul = (td["demand_backhaul"] != 0).any(-1)
     return has_open, has_tw, has_limit, has_backhaul
diff --git a/rl4co/envs/routing/mtvrp/generator.py b/rl4co/envs/routing/mtvrp/generator.py
index 81692ade..d441b306 100644
--- a/rl4co/envs/routing/mtvrp/generator.py
+++ b/rl4co/envs/routing/mtvrp/generator.py
@@ -256,7 +256,7 @@ def _default_open(td, remove):
     @staticmethod
     def _default_time_window(td, remove):
         default_tw = torch.zeros_like(td["time_windows"])
-        default_tw[..., 1] = float("inf")
+        default_tw[..., 1] = 4.6  # max tw: finite "no deadline" sentinel instead of inf; must match check_variants in env.py
         td["time_windows"][remove] = default_tw[remove]
         td["service_time"][remove] = torch.zeros_like(td["service_time"][remove])
         return td
diff --git a/rl4co/models/nn/env_embeddings/context.py b/rl4co/models/nn/env_embeddings/context.py
index 79f236e0..a4517912 100644
--- a/rl4co/models/nn/env_embeddings/context.py
+++ b/rl4co/models/nn/env_embeddings/context.py
@@ -32,6 +32,7 @@ def env_context_embedding(env_name: str, config: dict) -> nn.Module:
         "mtsp": MTSPContext,
         "smtwtp": SMTWTPContext,
         "mdcpdp": MDCPDPContext,
+        "mtvrp": MTVRPContext,
     }
 
     if env_name not in embedding_registry:
@@ -146,6 +147,53 @@ def _state_embedding(self, embeddings, td):
         state_embedding = td["vehicle_capacity"] - td["used_capacity"]
         return state_embedding
 
 
+class VRPBContext(EnvContext):
+    """Context embedding for the VRP with backhauls (VRPB).
+    Project the following to the embedding space:
+        - current node embedding
+        - remaining capacity: the linehaul load counts until the first
+          backhaul pickup, the backhaul load afterwards
+    """
+
+    def __init__(self, embed_dim):
+        super(VRPBContext, self).__init__(
+            embed_dim=embed_dim, step_context_dim=embed_dim + 1
+        )
+
+    def _state_embedding(self, embeddings, td):
+        # Still in the linehaul phase while no backhaul capacity is used
+        in_linehaul = td["used_capacity_backhaul"] == 0
+        used_capacity = torch.where(
+            in_linehaul, td["used_capacity_linehaul"], td["used_capacity_backhaul"]
+        )
+        return td["vehicle_capacity"] - used_capacity
+
+
+class MTVRPContext(VRPBContext):
+    """Context embedding for the multi-task VRP (MTVRP).
+    Project the following to the embedding space:
+        - current node embedding
+        - remaining capacity (as in VRPBContext)
+        - current time
+        - current route length
+        - whether the route is open
+    """
+
+    def __init__(self, embed_dim):
+        # Deliberately skip VRPBContext.__init__: the step context here
+        # carries four extra scalars (capacity, time, length, open flag).
+        super(VRPBContext, self).__init__(
+            embed_dim=embed_dim, step_context_dim=embed_dim + 4
+        )
+
+    def _state_embedding(self, embeddings, td):
+        capacity = super()._state_embedding(embeddings, td)
+        current_time = td["current_time"]
+        current_length = td["current_route_length"]
+        is_open = td["open_route"].float()
+        return torch.cat([capacity, current_time, current_length, is_open], -1)
+
+
 class VRPTWContext(VRPContext):
     """Context embedding for the Capacitated Vehicle Routing Problem (CVRP).
diff --git a/rl4co/models/nn/env_embeddings/init.py b/rl4co/models/nn/env_embeddings/init.py
index 5b056f80..d5c00be0 100644
--- a/rl4co/models/nn/env_embeddings/init.py
+++ b/rl4co/models/nn/env_embeddings/init.py
@@ -33,6 +33,7 @@ def env_init_embedding(env_name: str, config: dict) -> nn.Module:
         "smtwtp": SMTWTPInitEmbedding,
         "mdcpdp": MDCPDPInitEmbedding,
         "fjsp": FJSPFeatureEmbedding,
+        "mtvrp": MTVRPInitEmbedding,
     }
 
     if env_name not in embedding_registry:
@@ -146,6 +147,25 @@ def forward(self, td):
             )
         )
         return torch.cat((depot_embedding, node_embeddings), -2)
+
+
+class MTVRPInitEmbedding(VRPInitEmbedding):
+    def __init__(self, embed_dim, linear_bias=True, node_dim: int = 5):
+        # node_dim = 5: x, y, signed demand, tw start, tw end
+        super(MTVRPInitEmbedding, self).__init__(embed_dim, linear_bias, node_dim)
+
+    def forward(self, td):
+        depot, cities = td["locs"][:, :1, :], td["locs"][:, 1:, :]
+        # Customer time windows (the depot window is dropped)
+        time_windows = td["time_windows"][..., 1:, :]
+        # Signed demand: positive for linehaul, negative for backhaul
+        demands = td["demand_linehaul"][..., None] - td["demand_backhaul"][..., None]
+
+        depot_embedding = self.init_embed_depot(depot)
+        node_embeddings = self.init_embed(
+            torch.cat((cities, demands[:, 1:], time_windows), -1)
+        )
+        return torch.cat((depot_embedding, node_embeddings), -2)
 
 
 class SVRPInitEmbedding(nn.Module):
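
A quick sanity check of the open-route exemption added to check_solution_validity, run on toy tensors. The snippet is illustrative and not part of the patch; the 4.6 deadline mirrors the sentinel above:

import torch

# Both instances arrive at t=5.0, past the 4.6 deadline of the next node.
curr_time = torch.tensor([5.0, 5.0])
deadline = torch.tensor([4.6, 4.6])       # time-window end at the next node
next_node = torch.tensor([0, 3])          # instance 0 heads back to the depot
open_route = torch.tensor([True, False])  # instance 0 is an open route

# Same expression as in the patch: depot returns of open routes are exempt.
skip_open_end = open_route.view_as(curr_time) & (next_node == 0).view_as(curr_time)
ok = (curr_time <= deadline) | skip_open_end
assert ok.tolist() == [True, False]  # only the closed route trips the assert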
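
A minimal shape smoke test for the two new embedding modules, assuming the patch is applied (imports follow the file paths in the diff). The TensorDict below fakes only the MTVRP fields these modules read; all values are illustrative:

import torch
from tensordict.tensordict import TensorDict

from rl4co.models.nn.env_embeddings.context import MTVRPContext
from rl4co.models.nn.env_embeddings.init import MTVRPInitEmbedding

batch, nodes, embed_dim = 2, 11, 128  # depot + 10 customers

td = TensorDict(
    {
        "locs": torch.rand(batch, nodes, 2),
        "demand_linehaul": torch.rand(batch, nodes),
        "demand_backhaul": torch.zeros(batch, nodes),
        # tw start 0 everywhere, tw end at the 4.6 sentinel
        "time_windows": torch.stack(
            (torch.zeros(batch, nodes), torch.full((batch, nodes), 4.6)), dim=-1
        ),
        # per-step state consumed by MTVRPContext._state_embedding
        "used_capacity_linehaul": torch.zeros(batch, 1),
        "used_capacity_backhaul": torch.zeros(batch, 1),
        "vehicle_capacity": torch.ones(batch, 1),
        "current_time": torch.zeros(batch, 1),
        "current_route_length": torch.zeros(batch, 1),
        "open_route": torch.zeros(batch, 1, dtype=torch.bool),
    },
    batch_size=[batch],
)

node_emb = MTVRPInitEmbedding(embed_dim)(td)
assert node_emb.shape == (batch, nodes, embed_dim)

state = MTVRPContext(embed_dim)._state_embedding(node_emb, td)
assert state.shape == (batch, 4)  # capacity, time, route length, open flag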