Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into polynet
Browse files Browse the repository at this point in the history
  • Loading branch information
ahottung committed Jun 2, 2024
2 parents 1acd299 + a14de48 commit e4d9b29
Show file tree
Hide file tree
Showing 44 changed files with 353 additions and 147 deletions.
2 changes: 1 addition & 1 deletion configs/env/cvrp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: cvrp

generator_params:
num_loc: 20
distribution: uniform
loc_distribution: uniform

data_dir: ${paths.root_dir}/data/vrp
val_file: vrp${env.generator_params.num_loc}_val_seed4321.npz
Expand Down
2 changes: 1 addition & 1 deletion configs/env/cvrptw.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: cvrptw

generator_params:
num_loc: 20
distribution: uniform
loc_distribution: uniform

data_dir: ${paths.root_dir}/data/cvrptw
val_file: cvrptw${env.generator_params.num_loc}_val_seed4321.npz
Expand Down
2 changes: 1 addition & 1 deletion configs/env/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ name: tsp

generator_params:
num_loc: 20
distribution: uniform
loc_distribution: uniform
2 changes: 1 addition & 1 deletion configs/env/mdcpdp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ depot_mode: multiple

generator_params:
num_loc: 20
distribution: uniform
loc_distribution: uniform
num_depot: 4
min_loc: 0
max_loc: 1
Expand Down
2 changes: 2 additions & 0 deletions configs/env/mtsp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ generator_params:
max_loc: 1
min_num_agents: 3
max_num_agents: 3
loc_distribution: uniform

1 change: 1 addition & 0 deletions configs/env/op.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ name: op

generator_params:
num_loc: 20
loc_distribution: uniform

data_dir: ${paths.root_dir}/data/op
val_file: op_const${env.generator_params.num_loc}_val_seed4321.npz
Expand Down
1 change: 1 addition & 0 deletions configs/env/pdp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ generator_params:
num_loc: 20
min_loc: 0
max_loc: 1
loc_distribution: uniform

data_dir: ${paths.root_dir}/data/pdp
val_file: pdp${env.generator_params.num_loc}_val_seed4321.npz
Expand Down
1 change: 1 addition & 0 deletions configs/env/sdvrp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ name: sdvrp

generator_params:
num_loc: 20
loc_distribution: uniform

data_dir: ${paths.root_dir}/data/vrp
val_file: vrp${env.generator_params.num_loc}_val_seed4321.npz
Expand Down
1 change: 1 addition & 0 deletions configs/env/spctsp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ name: spctsp

generator_params:
num_loc: 20
loc_distribution: uniform

data_dir: ${paths.root_dir}/data/pctsp
val_file: pctsp${env.generator_params.num_loc}_val_seed4321.npz
Expand Down
1 change: 1 addition & 0 deletions configs/env/svrp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ name: svrp

generator_params:
num_loc: 20
loc_distribution: uniform

data_dir: ${paths.root_dir}/data/svrp
val_file: cvrptw${env.generator_params.num_loc}_val_seed4321.npz
Expand Down
2 changes: 1 addition & 1 deletion configs/env/tsp.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
_target_: rl4co.envs.TSPEnv

name: tsp

generator_params:
num_loc: 20
loc_distribution: uniform

data_dir: ${paths.root_dir}/data/tsp
val_file: tsp${env.generator_params.num_loc}_val_seed4321.npz
Expand Down
1 change: 1 addition & 0 deletions configs/experiment/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ defaults:
env:
generator_params:
num_loc: 50
check_solution: False # optimization

# Logging: we use Wandb in this case
logger:
Expand Down
22 changes: 22 additions & 0 deletions configs/experiment/routing/am-a2c.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# @package _global_

# Use the following to take the default values from am.yaml
# Replace below only the values that you want to change compared to the default values
defaults:
- routing/am.yaml
- _self_

logger:
wandb:
tags: ["am-a2c", "${env.name}"]
name: am-a2c-${env.name}${env.generator_params.num_loc}

model:
_target_: rl4co.models.A2C
policy:
_target_: rl4co.models.AttentionModelPolicy
env_name: "${env.name}"
actor_optimizer_kwargs:
lr: 1e-4
weight_decay: 1e-6
critic_optimizer_kwargs: null # default to actor_optimizer_kwargs
14 changes: 0 additions & 14 deletions configs/experiment/routing/am-critic.yaml

This file was deleted.

5 changes: 2 additions & 3 deletions configs/experiment/routing/am-ppo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ model:
max_grad_norm: 0.5
optimizer_kwargs:
lr: 1e-4
weight_decay: 0
weight_decay: 1e-6
lr_scheduler:
"MultiStepLR"
lr_scheduler_kwargs:
Expand All @@ -49,5 +49,4 @@ trainer:

seed: 1234

metrics:
train: ["loss", "reward", "surrogate_loss", "value_loss", "entropy_bonus"]

2 changes: 1 addition & 1 deletion configs/experiment/routing/am-svrp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ model:
val_data_size: 10_000
test_data_size: 10_000
optimizer_kwargs:
lr: 1e-4
lr: 1e-6
weight_decay: 0
lr_scheduler:
"MultiStepLR"
Expand Down
2 changes: 1 addition & 1 deletion configs/experiment/routing/am-xl.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ model:
test_data_size: 10_000
optimizer_kwargs:
lr: 1e-4
weight_decay: 0
weight_decay: 1e-6
lr_scheduler:
"MultiStepLR"
lr_scheduler_kwargs:
Expand Down
2 changes: 1 addition & 1 deletion configs/experiment/routing/am.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ model:
test_data_size: 10_000
optimizer_kwargs:
lr: 1e-4
weight_decay: 0
weight_decay: 1e-6
lr_scheduler:
"MultiStepLR"
lr_scheduler_kwargs:
Expand Down
47 changes: 47 additions & 0 deletions configs/experiment/routing/ar-gnn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# @package _global_

defaults:
- override /model: pomo.yaml
- override /env: tsp.yaml
- override /callbacks: default.yaml
- override /trainer: default.yaml
- override /logger: wandb.yaml

env:
generator_params:
num_loc: 50

logger:
wandb:
project: "rl4co"
tags: ["pomo", "${env.name}"]
group: "${env.name}${env.generator_params.num_loc}"
name: "pomo-${env.name}${env.generator_params.num_loc}"


model:
policy:
_target_: rl4co.models.zoo.am.policy.AttentionModelPolicy
encoder:
_target_: rl4co.models.zoo.nargnn.encoder.NARGNNNodeEncoder
embed_dim: 128
env_name: "${env.name}"
env_name: "${env.name}"
batch_size: 64
train_data_size: 160_000
val_data_size: 10_000
test_data_size: 10_000
optimizer_kwargs:
lr: 1e-4
weight_decay: 1e-6
lr_scheduler:
"MultiStepLR"
lr_scheduler_kwargs:
milestones: [80, 95]
gamma: 0.1

trainer:
max_epochs: 100

seed: 1234

2 changes: 1 addition & 1 deletion configs/experiment/routing/pomo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ model:
lr_scheduler:
"MultiStepLR"
lr_scheduler_kwargs:
milestones: [95]
milestones: [80, 95]
gamma: 0.1

trainer:
Expand Down
2 changes: 1 addition & 1 deletion configs/experiment/routing/ptrnet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ model:
test_data_size: 10_000
optimizer_kwargs:
lr: 1e-4
weight_decay: 0
weight_decay: 1e-6
lr_scheduler:
"MultiStepLR"
lr_scheduler_kwargs:
Expand Down
2 changes: 1 addition & 1 deletion configs/experiment/routing/symnco.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ model:
num_augment: 10
optimizer_kwargs:
lr: 1e-4
weight_decay: 0
weight_decay: 1e-6
lr_scheduler:
"MultiStepLR"
lr_scheduler_kwargs:
Expand Down
5 changes: 1 addition & 4 deletions configs/model/am-ppo.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
_target_: rl4co.models.AMPPO

metrics:
train: ["loss", "reward", "surrogate_loss", "value_loss", "entropy_bonus"]
val: ["reward"]
test: ${metrics.val}
log_on_step: True
train: ["loss", "reward", "surrogate_loss", "value_loss", "entropy_bonus"]
2 changes: 1 addition & 1 deletion configs/model/symnco.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ num_starts: 0 # by default we use only symmetric augmentations
metrics:
train: ["loss", "loss_ss", "loss_ps", "loss_inv", "reward"]
val: ["reward", "max_reward", "max_aug_reward"]
test: ${metrics.val}
test: ${model.metrics.val}
log_on_step: True
72 changes: 41 additions & 31 deletions examples/1-quickstart.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion rl4co/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.4.1dev1"
__version__ = "0.4.1dev2"
20 changes: 8 additions & 12 deletions rl4co/envs/routing/cvrp/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,27 +132,23 @@ def _reset(
@staticmethod
def get_action_mask(td: TensorDict) -> torch.Tensor:
# For demand steps_dim is inserted by indexing with id, for used_capacity insert node dim for broadcasting
exceeds_cap = (
td["demand"] + td["used_capacity"] > td["vehicle_capacity"]
)
exceeds_cap = td["demand"] + td["used_capacity"] > td["vehicle_capacity"]

# Nodes that cannot be visited are already visited or too much demand to be served now
mask_loc = td["visited"][..., 1:].to(exceeds_cap.dtype) | exceeds_cap

# Cannot visit the depot if just visited and still unserved nodes
mask_depot = (td["current_node"] == 0) & ((mask_loc == 0).int().sum(-1) > 0)[:, None]
mask_depot = (td["current_node"] == 0) & ((mask_loc == 0).int().sum(-1) > 0)[
:, None
]
return ~torch.cat((mask_depot, mask_loc), -1)

def _get_reward(self, td: TensorDict, actions: TensorDict) -> TensorDict:
# Gather dataset in order of tour
batch_size = td["locs"].shape[0]
depot = td["locs"][..., 0:1, :]
# Gather locations in order of tour (add depot since we start and end there)
locs_ordered = torch.cat(
[
depot,
gather_by_index(td["locs"], actions).reshape(
[batch_size, actions.size(-1), 2]
),
td["locs"][..., 0:1, :], # depot
gather_by_index(td["locs"], actions), # order locations
],
dim=1,
)
Expand Down Expand Up @@ -231,5 +227,5 @@ def _make_spec(self, generator: CVRPGenerator):
self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool)

@staticmethod
def render(td: TensorDict, actions: torch.Tensor=None, ax = None):
def render(td: TensorDict, actions: torch.Tensor = None, ax=None):
return render(td, actions, ax)
20 changes: 13 additions & 7 deletions rl4co/envs/routing/cvrp/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class CVRPGenerator(Generator):
min_loc: minimum value for the location coordinates
max_loc: maximum value for the location coordinates
loc_distribution: distribution for the location coordinates
depot_distribution: distribution for the depot location. If None, sample the depot from the locations
min_demand: minimum value for the demand of each customer
max_demand: maximum value for the demand of each customer
demand_distribution: distribution for the demand of each customer
Expand All @@ -57,7 +58,7 @@ def __init__(
min_loc: float = 0.0,
max_loc: float = 1.0,
loc_distribution: Union[int, float, str, type, Callable] = Uniform,
depot_distribution: Union[int, float, str, type, Callable] = Uniform,
depot_distribution: Union[int, float, str, type, Callable] = None,
min_demand: int = 1,
max_demand: int = 10,
demand_distribution: Union[int, float, type, Callable] = Uniform,
Expand Down Expand Up @@ -86,7 +87,7 @@ def __init__(
else:
self.depot_sampler = get_sampler(
"depot", depot_distribution, min_loc, max_loc, **kwargs
)
) if depot_distribution is not None else None

# Demand distribution
if kwargs.get("demand_sampler", None) is not None:
Expand All @@ -113,11 +114,16 @@ def __init__(
self.capacity = capacity

def _generate(self, batch_size) -> TensorDict:
# Sample locations
locs = self.loc_sampler.sample((*batch_size, self.num_loc, 2))

# Sample depot
depot = self.depot_sampler.sample((*batch_size, 2))

# Sample locations: depot and customers
if self.depot_sampler is not None:
depot = self.depot_sampler.sample((*batch_size, 2))
locs = self.loc_sampler.sample((*batch_size, self.num_loc, 2))
else:
# if depot_sampler is None, sample the depot from the locations
locs = self.loc_sampler.sample((*batch_size, self.num_loc + 1, 2))
depot = locs[..., 0, :]
locs = locs[..., 1:, :]

# Sample demands
demand = self.demand_sampler.sample((*batch_size, self.num_loc))
Expand Down
3 changes: 2 additions & 1 deletion rl4co/envs/routing/cvrptw/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class CVRPTWGenerator(CVRPGenerator):
min_loc: minimum value for the location coordinates
max_loc: maximum value for the location coordinates, default is 150 instead of 1.0, will be scaled
loc_distribution: distribution for the location coordinates
depot_distribution: distribution for the depot location. If None, sample the depot from the locations
min_demand: minimum value for the demand of each customer
max_demand: maximum value for the demand of each customer
demand_distribution: distribution for the demand of each customer
Expand Down Expand Up @@ -48,7 +49,7 @@ def __init__(
] = Uniform,
depot_distribution: Union[
int, float, str, type, Callable
] = Uniform,
] = None,
min_demand: int = 1,
max_demand: int = 10,
demand_distribution: Union[
Expand Down
Loading

0 comments on commit e4d9b29

Please sign in to comment.