Update MTVRP #176

Open · wants to merge 8 commits into base: main
11 changes: 8 additions & 3 deletions rl4co/envs/routing/mtvrp/env.py
@@ -281,7 +281,7 @@ def get_action_mask(td: TensorDict) -> torch.Tensor:
& ~exceeds_dist_limit
& ~td["visited"]
)

#print(can_visit)
Member:
[Minor] Debugging comments could be removed.

# Mask depot: don't visit depot if coming from there and there are still customer nodes I can visit
can_visit[:, 0] = ~((curr_node == 0) & (can_visit[:, 1:].sum(-1) > 0))
return can_visit
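
For reference, a toy trace of the depot-masking rule above (illustrative values, not library code):

import torch

curr_node = torch.tensor([0, 3])                 # batch of 2: one vehicle at the depot, one at node 3
can_visit = torch.tensor([[True, True, False],
                          [True, False, False]])

# Forbid staying at the depot while customer nodes are still reachable.
can_visit[:, 0] = ~((curr_node == 0) & (can_visit[:, 1:].sum(-1) > 0))
# -> [[False, True, False], [True, False, False]]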
@@ -349,9 +349,14 @@ def check_solution_validity(td: TensorDict, actions: torch.Tensor):
curr_time = torch.max(
curr_time + dist, gather_by_index(td["time_windows"], next_node)[..., 0]
)

new_shape = curr_time.size()
skip_open_end = td["open_route"].view(*new_shape) & (next_node == 0).view(*new_shape)
Member:
Makes sense, good catch.
Anyway, I recommend setting check_solution to False when training; otherwise, the solution will be checked at each step, which can be a bit slow. I will add a warning.
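
For context, a minimal sketch of what disabling the check could look like; this assumes the environment exposes a check_solution flag (as in RL4COEnvBase) and is not part of this diff:

from rl4co.envs.routing.mtvrp.env import MTVRPEnv

# Skip the per-step call to check_solution_validity during training rollouts
# (hypothetical usage; keep the check enabled when validating final solutions).
env = MTVRPEnv(check_solution=False)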

Contributor:
Actually, I don't think this is necessary. Since skip_open_end will only be true if next_node == 0, and since the depot has the highest time window end, curr_time <= gather_by_index(td["time_windows"], next_node)[..., 1] should always be True, except when curr_time is already very close to the max time and the service duration at that last node is long enough to go over the time limit. Is that something we want to allow?

Author:
I agree with ngastzepeda that the last node in the route should also satisfy the time window constraints (allowing the vehicle back to the depot even when it is OVRP). However, I found some outliers when training the MTVRP (i.e., the time window of the last node of an OVRP route may exceed the max time window). I do not yet know the exact reason; it could be the instance generation or the masking procedure.

Author:
I will check.


assert torch.all(
- curr_time <= gather_by_index(td["time_windows"], next_node)[..., 1]
+ (curr_time <= gather_by_index(td["time_windows"], next_node)[..., 1]) | skip_open_end
), "vehicle cannot start service before deadline"

curr_time = curr_time + gather_by_index(td["service_time"], next_node)
curr_node = next_node
curr_time[curr_node == 0] = 0.0 # reset time for depot
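
A toy illustration of the relaxed deadline check introduced above (illustrative values, not part of the diff): an open route that ends at the depot is exempt from the depot's time-window deadline.

import torch

curr_time  = torch.tensor([4.7, 4.7])    # both vehicles would arrive after the deadline
tw_end     = torch.tensor([4.6, 4.6])    # depot time-window end
next_node  = torch.tensor([0, 0])
open_route = torch.tensor([True, False])

skip_open_end = open_route & (next_node == 0)
ok = (curr_time <= tw_end) | skip_open_end   # -> [True, False]: only the open route passes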
@@ -450,7 +455,7 @@ def _make_spec(self, td_params: TensorDict):
def check_variants(td):
"""Check if the problem has the variants"""
has_open = td["open_route"].squeeze(-1)
has_tw = (td["time_windows"][:, :, 1] != float("inf")).any(-1)
has_tw = (td["time_windows"][:, :, 1] != 4.6).any(-1)
Comment on lines -453 to +458
Member:
Same as the discussion about _default_time_window(): changing this bound in the settings would also require modifying this part. Any reason for this hardcoded value?

Author:
To avoid numerical issues during training, since the value goes through the embedding. But I will double-check; inf would be more general.

has_limit = (td["distance_limit"] != float("inf")).squeeze(-1)
has_backhaul = (td["demand_backhaul"] != 0).any(-1)
return has_open, has_tw, has_limit, has_backhaul
3 changes: 2 additions & 1 deletion rl4co/envs/routing/mtvrp/generator.py
@@ -256,7 +256,8 @@ def _default_open(td, remove):
@staticmethod
def _default_time_window(td, remove):
default_tw = torch.zeros_like(td["time_windows"])
- default_tw[..., 1] = float("inf")
+ #default_tw[..., 1] = float("inf")
+ default_tw[..., 1] = 4.6 # max tw
Member:
Doesn't this influence the solution? If the default time window is 4.6, the problem is no longer a CVRP but a "relaxed" VRPTW.

The reason I suggested "inf" is that it generalizes to any scale; for the embedding, it can be handled as:

time_windows = torch.nan_to_num(td["time_windows"][..., 1:, :], posinf=0.0)

So it shouldn't influence the calculation described in Section 4.1 (Attribute composition) of your paper.
What do you think?

Contributor:
I agree that the default should be float("inf"); T=4.6 should only apply as a default value in environments where we actually want to model time windows!

Author:
OK, that makes sense!

td["time_windows"][remove] = default_tw[remove]
td["service_time"][remove] = torch.zeros_like(td["service_time"][remove])
return td
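
A minimal check of the suggestion above (toy values, not library code): keep float("inf") as the "no time window" default and sanitize it only where it enters the init embedding.

import torch

tw = torch.tensor([[0.0, float("inf")],   # time-window attribute disabled (e.g. plain CVRP instance)
                   [0.5, 4.6]])           # actual time window
tw_embed_input = torch.nan_to_num(tw, posinf=0.0)
# tensor([[0.0, 0.0], [0.5, 4.6]]) -- inf mapped to 0 before the linear projection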
45 changes: 45 additions & 0 deletions rl4co/models/nn/env_embeddings/context.py
@@ -32,6 +32,7 @@ def env_context_embedding(env_name: str, config: dict) -> nn.Module:
"mtsp": MTSPContext,
"smtwtp": SMTWTPContext,
"mdcpdp": MDCPDPContext,
"mtvrp": MTVRPContext
}

if env_name not in embedding_registry:
@@ -146,6 +147,50 @@ def _state_embedding(self, embeddings, td):
state_embedding = td["vehicle_capacity"] - td["used_capacity"]
return state_embedding

class VRPBContext(EnvContext):
"""Context embedding for the Capacitated Vehicle Routing Problem (CVRP).
Project the following to the embedding space:
- current node embedding
- remaining capacity (vehicle_capacity - used_capacity)
"""
Comment on lines +151 to +155
Contributor:
Since this also includes backhauls, we should mention this in the docs.


def __init__(self, embed_dim):
super(VRPBContext, self).__init__(
embed_dim=embed_dim, step_context_dim=embed_dim + 1
)

def _state_embedding(self, embeddings, td):
mask = (td["used_capacity_backhaul"] == 0)
used_capacity = torch.where(mask, td["used_capacity_linehaul"], td["used_capacity_backhaul"])
state_embedding = td["vehicle_capacity"] - used_capacity
return state_embedding
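
For reference, a toy trace of the capacity bookkeeping above (illustrative values only): remaining capacity is tracked against the linehaul load until the first backhaul is picked up, then against the backhaul load.

import torch

used_linehaul    = torch.tensor([[0.4], [0.2]])
used_backhaul    = torch.tensor([[0.0], [0.3]])
vehicle_capacity = torch.tensor([[1.0], [1.0]])

mask = used_backhaul == 0                               # no backhaul loaded yet?
used = torch.where(mask, used_linehaul, used_backhaul)  # [[0.4], [0.3]]
remaining = vehicle_capacity - used                     # [[0.6], [0.7]]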

class MTVRPContext(VRPBContext):
"""Context embedding for the Capacitated Vehicle Routing Problem (CVRP).
Project the following to the embedding space:
- current node embedding
- remaining capacity (vehicle_capacity - used_capacity)
- current time
- current route length
- if route should be open
"""

def __init__(self, embed_dim):
super(VRPBContext, self).__init__(
embed_dim=embed_dim, step_context_dim=embed_dim + 4
)

def _state_embedding(self, embeddings, td):

capacity = super()._state_embedding(embeddings, td)
current_time = td["current_time"]
current_length = td["current_route_length"]
Member:
How does the model understand whether there is a limit?
If there is no limit (say, CVRP), then it will look the same as VRPL, since the model does not know whether the constraint will be enforced or not.

Author:
You are right. The route length in the state_embedding should be the remaining length, i.e., distance limit minus current length, instead of the current length... It seems to be a mistake.

is_open = td["open_route"]
is_open_tensor = torch.zeros_like(is_open, dtype=torch.float)
is_open_tensor[is_open] = 1

return torch.cat([capacity, current_time, current_length, is_open_tensor], -1)
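
A hedged sketch of the fix discussed in the thread above (feeding the remaining route length rather than the current one); the helper name is hypothetical, it assumes the TensorDict carries "distance_limit" as used in check_variants, and it is not part of this diff:

import torch

def mtvrp_state_embedding(td):
    # Sketch only: remaining capacity, current time, *remaining* route length, open-route flag.
    mask = td["used_capacity_backhaul"] == 0
    used = torch.where(mask, td["used_capacity_linehaul"], td["used_capacity_backhaul"])
    remaining_capacity = td["vehicle_capacity"] - used
    remaining_length = td["distance_limit"] - td["current_route_length"]
    is_open = td["open_route"].float()
    return torch.cat([remaining_capacity, td["current_time"], remaining_length, is_open], -1)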


class VRPTWContext(VRPContext):
"""Context embedding for the Capacitated Vehicle Routing Problem (CVRP).
28 changes: 24 additions & 4 deletions rl4co/models/nn/env_embeddings/init.py
@@ -2,10 +2,8 @@
import torch.nn as nn

from tensordict.tensordict import TensorDict

from rl4co.models.nn.ops import PositionalEncoding


def env_init_embedding(env_name: str, config: dict) -> nn.Module:
"""Get environment initial embedding. The init embedding is used to initialize the
general embedding of the problem nodes without any solution information.
@@ -33,6 +31,7 @@ def env_init_embedding(env_name: str, config: dict) -> nn.Module:
"smtwtp": SMTWTPInitEmbedding,
"mdcpdp": MDCPDPInitEmbedding,
"fjsp": FJSPFeatureEmbedding,
"mtvrp":MTVRPInitEmbedding,
}

if env_name not in embedding_registry:
@@ -146,6 +145,28 @@ def forward(self, td):
)
)
return torch.cat((depot_embedding, node_embeddings), -2)


class MTVRPInitEmbedding(VRPInitEmbedding):
def __init__(self, embed_dim, linear_bias=True, node_dim: int = 5):
# node_dim = 5: x, y, demand, tw start, tw end
super(MTVRPInitEmbedding, self).__init__(embed_dim, linear_bias, node_dim)

def forward(self, td):
depot, cities = td["locs"][:, :1, :], td["locs"][:, 1:, :]
#durations = td["durations"][..., 1:]
Contributor:
Why are the durations not included?

time_windows = td["time_windows"][..., 1:, :]
# embeddings
demands = td["demand_linehaul"][..., None] - td["demand_backhaul"][..., None]
Member:
It makes sense; basically, if the demand is negative, the model will understand it is a backhaul. I was thinking about having a flag, but this is also good.

Author:
A flag is also a good idea!


depot_embedding = self.init_embed_depot(depot)
node_embeddings = self.init_embed(
torch.cat(
(cities, demands[:,1:], time_windows), -1
)
)

return torch.cat((depot_embedding, node_embeddings), -2)
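
For illustration, a toy comparison (illustrative values, not library code) of the signed-demand encoding used above versus the flag-based alternative mentioned in the thread:

import torch

demand_linehaul = torch.tensor([[0.0, 0.3, 0.0]])
demand_backhaul = torch.tensor([[0.0, 0.0, 0.2]])

signed = demand_linehaul - demand_backhaul        # [[0.0, 0.3, -0.2]]: the sign marks backhauls
flagged = torch.stack(
    [demand_linehaul + demand_backhaul, (demand_backhaul > 0).float()], -1
)                                                 # magnitude plus an explicit backhaul flag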


class SVRPInitEmbedding(nn.Module):
@@ -383,7 +404,6 @@ def forward(self, td):
# concatenate on graph size dimension
return torch.cat([depot_embeddings, pick_embeddings, delivery_embeddings], -2)


class FJSPFeatureEmbedding(nn.Module):
def __init__(self, embed_dim, linear_bias=True, norm_coef: int = 100):
super().__init__()
@@ -443,4 +463,4 @@ def _stepwise_operations_embed(self, td: TensorDict):
raise NotImplementedError("Stepwise encoding not yet implemented")

def _stepwise_machine_embed(self, td: TensorDict):
raise NotImplementedError("Stepwise encoding not yet implemented")
raise NotImplementedError("Stepwise encoding not yet implemented")