From 475e430b7b8c07a1a0228003fad7342758e73c51 Mon Sep 17 00:00:00 2001
From: jieyibi
Date: Mon, 27 May 2024 23:08:25 +0800
Subject: [PATCH 01/10] Update meta trainer

---
 examples/2d-meta_train.py   |  90 ++++++++++++++
 rl4co/utils/__init__.py     |   1 +
 rl4co/utils/meta_trainer.py | 242 ++++++++++++++++++++++++++++++++++++
 3 files changed, 333 insertions(+)
 create mode 100644 examples/2d-meta_train.py
 create mode 100644 rl4co/utils/meta_trainer.py

diff --git a/examples/2d-meta_train.py b/examples/2d-meta_train.py
new file mode 100644
index 00000000..2e79c6f8
--- /dev/null
+++ b/examples/2d-meta_train.py
@@ -0,0 +1,90 @@
+import sys
+sys.path.append("/home/jieyi/rl4co")
+
+import pytz
+import torch
+
+from datetime import datetime
+from lightning.pytorch.callbacks import ModelCheckpoint, RichModelSummary
+from lightning.pytorch.loggers import WandbLogger
+
+from rl4co.envs import CVRPEnv
+from rl4co.models.zoo.am import AttentionModelPolicy
+from rl4co.models.zoo.pomo import POMO
+from rl4co.utils.meta_trainer import RL4COMetaTrainer, MetaModelCallback
+
+def main():
+    # Set device
+    device_id = 0
+
+    # RL4CO env based on TorchRL
+    env = CVRPEnv(generator_params={'num_loc': 50})
+
+    # Policy: neural network, in this case with encoder-decoder architecture
+    # Note: this configuration matches the policy used in the original POMO paper
+    policy = AttentionModelPolicy(env_name=env.name,
+                                  embed_dim=128,
+                                  num_encoder_layers=6,
+                                  num_heads=8,
+                                  normalization="instance",
+                                  use_graph_context=False
+                                  )
+
+    # RL Model (POMO)
+    model = POMO(env,
+                 policy,
+                 batch_size=64,            # meta_batch_size
+                 train_data_size=64 * 50,  # each epoch
+                 val_data_size=0,
+                 optimizer_kwargs={"lr": 1e-4, "weight_decay": 1e-6},
+                 # for the task scheduler of size setting, where sch_epoch = 0.9 * epochs
+                 )
+
+    # Example callbacks
+    checkpoint_callback = ModelCheckpoint(
+        dirpath="checkpoints",         # save to checkpoints/
+        filename="epoch_{epoch:03d}",  # save as epoch_XXX.ckpt
+        save_top_k=1,                  # save only the best model
+        save_last=True,                # save the last model
+        monitor="val/reward",          # monitor validation reward
+        mode="max",                    # maximize validation reward
+    )
+    rich_model_summary = RichModelSummary(max_depth=3)  # model summary callback
+    # Meta callbacks
+    meta_callback = MetaModelCallback(
+        meta_params={
+            'meta_method': 'reptile',  # choose from ['maml', 'fomaml', 'maml_fomaml', 'reptile']
+            'data_type': 'size',       # choose from ["size", "distribution", "size_distribution"]
+            'sch_bar': 0.9,            # for the task scheduler of size setting, where sch_epoch = sch_bar * epochs
+            'B': 1,                    # the number of tasks in a mini-batch
+            'alpha': 0.99,             # params for the outer-loop optimization of reptile
+            'alpha_decay': 0.999,      # params for the outer-loop optimization of reptile
+            'min_size': 20,            # minimum of sampled size in meta tasks
+            'max_size': 150,           # maximum of sampled size in meta tasks
+        },
+        print_log=True  # whether to print the sampled tasks in each meta iteration
+    )
+    callbacks = [meta_callback, checkpoint_callback, rich_model_summary]
+
+    # Logger
+    process_start_time = datetime.now(pytz.timezone("Asia/Singapore"))
+    logger = WandbLogger(project="rl4co", name=f"{env.name}_{process_start_time.strftime('%Y%m%d_%H%M%S')}")
+    # logger = None  # uncomment this line if you don't want logging
+
+    # Adjust your trainer to the number of epochs you want to run
+    trainer = RL4COMetaTrainer(
+        max_epochs=20000,  # (the number of meta-model updates) * (the number of tasks in a mini-batch)
+        callbacks=callbacks,
+        accelerator="gpu",
+        devices=[device_id],
+        logger=logger,
+        limit_train_batches=50  # gradient descent steps in the inner-loop optimization of the meta-learning method
+    )
+
+    # Fit
+    trainer.fit(model)
+
+
+if __name__ == "__main__":
+    main()
+
diff --git a/rl4co/utils/__init__.py b/rl4co/utils/__init__.py
index 4b0246aa..638f3149 100644
--- a/rl4co/utils/__init__.py
+++ b/rl4co/utils/__init__.py
@@ -2,6 +2,7 @@
 from rl4co.utils.pylogger import get_pylogger
 from rl4co.utils.rich_utils import enforce_tags, print_config_tree
 from rl4co.utils.trainer import RL4COTrainer
+from rl4co.utils.meta_trainer import RL4COMetaTrainer
 from rl4co.utils.utils import (
     extras,
     get_metric_value,
diff --git a/rl4co/utils/meta_trainer.py b/rl4co/utils/meta_trainer.py
new file mode 100644
index 00000000..0d45ccff
--- /dev/null
+++ b/rl4co/utils/meta_trainer.py
@@ -0,0 +1,242 @@
+from typing import Iterable, List, Optional, Union
+
+import lightning.pytorch as pl
+import torch
+import math
+import copy
+from torch.optim import Adam
+
+from lightning import Callback, Trainer
+from lightning.fabric.accelerators.cuda import num_cuda_devices
+from lightning.pytorch.accelerators import Accelerator
+from lightning.pytorch.core.datamodule import LightningDataModule
+from lightning.pytorch.loggers import Logger
+from lightning.pytorch.strategies import DDPStrategy, Strategy
+from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
+from rl4co import utils
+import random
+log = utils.get_pylogger(__name__)
+
+
+class MetaModelCallback(Callback):
+    def __init__(self, meta_params, print_log=True):
+        super().__init__()
+        self.meta_params = meta_params
+        assert meta_params["meta_method"] == 'reptile', NotImplementedError
+        assert meta_params["data_type"] == 'size', NotImplementedError
+        self.print_log = print_log
+
+    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+
+        # Initialize some hyperparameters
+        self.alpha = self.meta_params["alpha"]
+        self.alpha_decay = self.meta_params["alpha_decay"]
+        self.sch_bar = self.meta_params["sch_bar"]
+        self.task_set = [(n,) for n in range(self.meta_params["min_size"], self.meta_params["max_size"] + 1)]
+
+        # Sample a batch of tasks
+        self._sample_task()
+        self.selected_tasks[0] = (pl_module.env.generator.num_loc, )
+
+    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+
+        # Alpha scheduler (decay for the update of meta model)
+        self._alpha_scheduler()
+
+        # Reinitialize the task model with the parameters of the meta model
+        if trainer.current_epoch % self.meta_params['B'] == 0:  # Save the meta model
+            self.meta_model_state_dict = copy.deepcopy(pl_module.state_dict())
+            self.task_models = []
+            # Print sampled tasks
+            if self.print_log:
+                print('\n>> Meta epoch: {} (Exact epoch: {}), Training task: {}'.format(trainer.current_epoch//self.meta_params['B'], trainer.current_epoch, self.selected_tasks))
+        else:
+            pl_module.load_state_dict(self.meta_model_state_dict)
+
+        # Reinitialize the optimizer every epoch
+        lr_decay = 0.1 if trainer.current_epoch+1 == int(self.sch_bar * trainer.max_epochs) else 1
+        old_lr = trainer.optimizers[0].param_groups[0]['lr']
+        new_optimizer = Adam(pl_module.parameters(), lr=old_lr * lr_decay)
+        trainer.optimizers = [new_optimizer]
+
+        # Print
+        if self.print_log:
+            print('\n>> Training task: {}, capacity: {}'.format(pl_module.env.generator.num_loc, pl_module.env.generator.capacity))
+
+    def on_train_epoch_end(self, trainer, pl_module):
+
+        # Save the task model
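+        # (Reptile) Each Lightning epoch is one inner-loop adaptation on a single
+        # task; the adapted weights are stored here and, once every `B` epochs,
+        # averaged into the meta model in the outer-loop update below.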
self.task_models.append(copy.deepcopy(pl_module.state_dict())) + if (trainer.current_epoch+1) % self.meta_params['B'] == 0: + # Outer-loop optimization (update the meta model with the parameters of the task model) + with torch.no_grad(): + state_dict = {params_key: (self.meta_model_state_dict[params_key] + + self.alpha * torch.mean(torch.stack([fast_weight[params_key] - self.meta_model_state_dict[params_key] + for fast_weight in self.task_models], dim=0).float(), dim=0)) + for params_key in self.meta_model_state_dict} + pl_module.load_state_dict(state_dict) + + # Get ready for the next meta-training iteration + if (trainer.current_epoch + 1) % self.meta_params['B'] == 0: + # Sample a batch of tasks + self._sample_task() + + # Load new training task (Update the environment) + self._load_task(pl_module, task_idx = (trainer.current_epoch+1) % self.meta_params['B']) + + def _sample_task(self): + # Sample a batch of tasks + w, self.selected_tasks = [1.0] * self.meta_params['B'], [] + for b in range(self.meta_params['B']): + task_params = random.sample(self.task_set, 1)[0] + self.selected_tasks.append(task_params) + self.w = torch.softmax(torch.Tensor(w), dim=0) + + def _load_task(self, pl_module, task_idx=0): + # Load new training task (Update the environment) + task_params, task_w = self.selected_tasks[task_idx], self.w[task_idx].item() + task_capacity = math.ceil(30 + task_params[0] / 5) if task_params[0] >= 20 else 20 + pl_module.env.generator.num_loc = task_params[0] + pl_module.env.generator.capacity = task_capacity + + def _alpha_scheduler(self): + self.alpha = max(self.alpha * self.alpha_decay, 0.0001) + +class RL4COMetaTrainer(Trainer): + """Wrapper around Lightning Trainer, with some RL4CO magic for efficient training. + + Note: + The most important hyperparameter to use is `reload_dataloaders_every_n_epochs`. + This allows for datasets to be re-created on the run and distributed by Lightning across + devices on each epoch. Setting to a value different than 1 may lead to overfitting to a + specific (such as the initial) data distribution. + + Args: + accelerator: hardware accelerator to use. + callbacks: list of callbacks. + logger: logger (or iterable collection of loggers) for experiment tracking. + min_epochs: minimum number of training epochs. + max_epochs: maximum number of training epochs. + strategy: training strategy to use (if any), such as Distributed Data Parallel (DDP). + devices: number of devices to train on (int) or which GPUs to train on (list or str) applied per node. + gradient_clip_val: 0 means don't clip. Defaults to 1.0 for stability. + precision: allows for mixed precision training. Can be specified as a string (e.g., '16'). + This also allows to use `FlashAttention` by default. + disable_profiling_executor: Disable JIT profiling executor. This reduces memory and increases speed. + auto_configure_ddp: Automatically configure DDP strategy if multiple GPUs are available. + reload_dataloaders_every_n_epochs: Set to a value different than 1 to reload dataloaders every n epochs. + matmul_precision: Set matmul precision for faster inference https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision + **kwargs: Additional keyword arguments passed to the Lightning Trainer. See :class:`lightning.pytorch.trainer.Trainer` for details. 
+ """ + + def __init__( + self, + accelerator: Union[str, Accelerator] = "auto", + callbacks: Optional[List[Callback]] = None, + logger: Optional[Union[Logger, Iterable[Logger]]] = None, + min_epochs: Optional[int] = None, + max_epochs: Optional[int] = None, + strategy: Union[str, Strategy] = "auto", + devices: Union[List[int], str, int] = "auto", + gradient_clip_val: Union[int, float] = 1.0, + precision: Union[str, int] = "16-mixed", + reload_dataloaders_every_n_epochs: int = 1, + disable_profiling_executor: bool = True, + auto_configure_ddp: bool = True, + matmul_precision: Union[str, int] = "medium", + **kwargs, + ): + # Disable JIT profiling executor. This reduces memory and increases speed. + # Reference: https://github.com/HazyResearch/safari/blob/111d2726e7e2b8d57726b7a8b932ad8a4b2ad660/train.py#LL124-L129C17 + if disable_profiling_executor: + try: + torch._C._jit_set_profiling_executor(False) + torch._C._jit_set_profiling_mode(False) + except AttributeError: + pass + + # Configure DDP automatically if multiple GPUs are available + if auto_configure_ddp and strategy == "auto": + if devices == "auto": + n_devices = num_cuda_devices() + elif isinstance(devices, Iterable): + n_devices = len(devices) + else: + n_devices = devices + if n_devices > 1: + log.info( + "Configuring DDP strategy automatically with {} GPUs".format( + n_devices + ) + ) + strategy = DDPStrategy( + find_unused_parameters=True, # We set to True due to RL envs + gradient_as_bucket_view=True, # https://pytorch-lightning.readthedocs.io/en/stable/advanced/advanced_gpu.html#ddp-optimizations + ) + + # Set matmul precision for faster inference https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision + if matmul_precision is not None: + torch.set_float32_matmul_precision(matmul_precision) + + # Check if gradient_clip_val is set to None + if gradient_clip_val is None: + log.warning( + "gradient_clip_val is set to None. This may lead to unstable training." + ) + + # We should reload dataloaders every epoch for RL training + if reload_dataloaders_every_n_epochs != 1: + log.warning( + "We reload dataloaders every epoch for RL training. Setting reload_dataloaders_every_n_epochs to a value different than 1 " + + "may lead to unexpected behavior since the initial conditions will be the same for `n_epochs` epochs." 
+ ) + + # Main call to `Trainer` superclass + super().__init__( + accelerator=accelerator, + callbacks=callbacks, + logger=logger, + min_epochs=min_epochs, + max_epochs=max_epochs, + strategy=strategy, + gradient_clip_val=gradient_clip_val, + devices=devices, + precision=precision, + reload_dataloaders_every_n_epochs=reload_dataloaders_every_n_epochs, + **kwargs, + ) + + def fit( + self, + model: "pl.LightningModule", + train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, + val_dataloaders: Optional[EVAL_DATALOADERS] = None, + datamodule: Optional[LightningDataModule] = None, + ckpt_path: Optional[str] = None, + ) -> None: + """ + We override the `fit` method to automatically apply and handle RL4CO magic + to 'self.automatic_optimization = False' models, such as PPO + + It behaves exactly like the original `fit` method, but with the following changes: + - if the given model is 'self.automatic_optimization = False', we override 'gradient_clip_val' as None + """ + + if not model.automatic_optimization: + if self.gradient_clip_val is not None: + log.warning( + "Overriding gradient_clip_val to None for 'automatic_optimization=False' models" + ) + self.gradient_clip_val = None + + # Fit (Inner-loop Optimization) + super().fit( + model=model, + train_dataloaders=train_dataloaders, + val_dataloaders=val_dataloaders, + datamodule=datamodule, + ckpt_path=ckpt_path, + ) + + + From 5cbe7b9cee1574b74c00e4624fd4d1141cce4e38 Mon Sep 17 00:00:00 2001 From: jieyibi Date: Mon, 27 May 2024 23:12:19 +0800 Subject: [PATCH 02/10] Update meta trainer --- examples/2d-meta_train.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/2d-meta_train.py b/examples/2d-meta_train.py index 2e79c6f8..f0b8c59b 100644 --- a/examples/2d-meta_train.py +++ b/examples/2d-meta_train.py @@ -1,6 +1,3 @@ -import sys -sys.path.append("/home/jieyi/rl4co") - import pytz import torch From 05e0870898f7ff3d3c97b1062b5f91e710e4e0be Mon Sep 17 00:00:00 2001 From: jieyibi Date: Mon, 27 May 2024 23:31:07 +0800 Subject: [PATCH 03/10] Add source --- rl4co/utils/meta_trainer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/rl4co/utils/meta_trainer.py b/rl4co/utils/meta_trainer.py index 0d45ccff..9c8fe3ff 100644 --- a/rl4co/utils/meta_trainer.py +++ b/rl4co/utils/meta_trainer.py @@ -105,6 +105,9 @@ def _alpha_scheduler(self): class RL4COMetaTrainer(Trainer): """Wrapper around Lightning Trainer, with some RL4CO magic for efficient training. + # Meta training framework for addressing the generalization issue + # Based on Zhou et al. (2023): https://arxiv.org/abs/2305.19587 + Note: The most important hyperparameter to use is `reload_dataloaders_every_n_epochs`. 
This allows for datasets to be re-created on the run and distributed by Lightning across From e8435c613bd651dc68e7ec4356947f004aa2bca6 Mon Sep 17 00:00:00 2001 From: jieyibi Date: Tue, 28 May 2024 15:09:55 +0800 Subject: [PATCH 04/10] Update Reptile callbacks --- examples/2d-meta_train.py | 19 ++-- rl4co/utils/__init__.py | 1 - rl4co/utils/meta_trainer.py | 171 +++--------------------------------- 3 files changed, 20 insertions(+), 171 deletions(-) diff --git a/examples/2d-meta_train.py b/examples/2d-meta_train.py index f0b8c59b..2de9b629 100644 --- a/examples/2d-meta_train.py +++ b/examples/2d-meta_train.py @@ -1,14 +1,11 @@ -import pytz -import torch - -from datetime import datetime from lightning.pytorch.callbacks import ModelCheckpoint, RichModelSummary from lightning.pytorch.loggers import WandbLogger from rl4co.envs import CVRPEnv from rl4co.models.zoo.am import AttentionModelPolicy from rl4co.models.zoo.pomo import POMO -from rl4co.utils.meta_trainer import RL4COMetaTrainer, MetaModelCallback +from rl4co.utils.trainer import RL4COTrainer +from rl4co.utils.meta_trainer import ReptileCallback def main(): # Set device @@ -34,7 +31,6 @@ def main(): train_data_size=64 * 50, # each epoch val_data_size=0, optimizer_kwargs={"lr": 1e-4, "weight_decay": 1e-6}, - # for the task scheduler of size setting, where sch_epoch = 0.9 * epochs ) # Example callbacks @@ -47,10 +43,10 @@ def main(): mode="max", # maximize validation reward ) rich_model_summary = RichModelSummary(max_depth=3) # model summary callback + # Meta callbacks - meta_callback = MetaModelCallback( + meta_callback = ReptileCallback( meta_params={ - 'meta_method': 'reptile', # choose from ['maml', 'fomaml', 'maml_fomaml', 'reptile'] 'data_type': 'size', # choose from ["size", "distribution", "size_distribution"] 'sch_bar': 0.9, # for the task scheduler of size setting, where sch_epoch = sch_bar * epochs 'B': 1, # the number of tasks in a mini-batch @@ -64,13 +60,12 @@ def main(): callbacks = [meta_callback, checkpoint_callback, rich_model_summary] # Logger - process_start_time = datetime.now(pytz.timezone("Asia/Singapore")) - logger = WandbLogger(project="rl4co", name=f"{env.name}_{process_start_time.strftime('%Y%m%d_%H%M%S')}") + logger = WandbLogger(project="rl4co", name=f"{env.name}_pomo_reptile") # logger = None # uncomment this line if you don't want logging # Adjust your trainer to the number of epochs you want to run - trainer = RL4COMetaTrainer( - max_epochs=20000, # (the number of meta-model updates) * (the number of tasks in a mini-batch) + trainer = RL4COTrainer( + max_epochs=20000, # (the number of meta_model updates) * (the number of tasks in a mini-batch) callbacks=callbacks, accelerator="gpu", devices=[device_id], diff --git a/rl4co/utils/__init__.py b/rl4co/utils/__init__.py index 638f3149..4b0246aa 100644 --- a/rl4co/utils/__init__.py +++ b/rl4co/utils/__init__.py @@ -2,7 +2,6 @@ from rl4co.utils.pylogger import get_pylogger from rl4co.utils.rich_utils import enforce_tags, print_config_tree from rl4co.utils.trainer import RL4COTrainer -from rl4co.utils.meta_trainer import RL4COMetaTrainer from rl4co.utils.utils import ( extras, get_metric_value, diff --git a/rl4co/utils/meta_trainer.py b/rl4co/utils/meta_trainer.py index 9c8fe3ff..ada6e894 100644 --- a/rl4co/utils/meta_trainer.py +++ b/rl4co/utils/meta_trainer.py @@ -1,44 +1,40 @@ -from typing import Iterable, List, Optional, Union - import lightning.pytorch as pl import torch import math import copy from torch.optim import Adam -from lightning import Callback, 
Trainer -from lightning.fabric.accelerators.cuda import num_cuda_devices -from lightning.pytorch.accelerators import Accelerator -from lightning.pytorch.core.datamodule import LightningDataModule -from lightning.pytorch.loggers import Logger -from lightning.pytorch.strategies import DDPStrategy, Strategy -from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from lightning import Callback from rl4co import utils import random log = utils.get_pylogger(__name__) -class MetaModelCallback(Callback): +class ReptileCallback(Callback): + + # Meta training framework for addressing the generalization issue + # Based on Zhou et al. (2023): https://arxiv.org/abs/2305.19587 def __init__(self, meta_params, print_log=True): super().__init__() self.meta_params = meta_params - assert meta_params["meta_method"] == 'reptile', NotImplementedError - assert meta_params["data_type"] == 'size', NotImplementedError self.print_log = print_log - def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: # Initialize some hyperparameters self.alpha = self.meta_params["alpha"] self.alpha_decay = self.meta_params["alpha_decay"] self.sch_bar = self.meta_params["sch_bar"] - self.task_set = [(n,) for n in range(self.meta_params["min_size"], self.meta_params["max_size"] + 1)] + if self.meta_params["data_type"] == "size": + self.task_set = [(n,) for n in range(self.meta_params["min_size"], self.meta_params["max_size"] + 1)] + else: + raise NotImplementedError # Sample a batch of tasks self._sample_task() self.selected_tasks[0] = (pl_module.env.generator.num_loc, ) - def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: # Alpha scheduler (decay for the update of meta model) self._alpha_scheduler() @@ -63,7 +59,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo if self.print_log: print('\n>> Training task: {}, capacity: {}'.format(pl_module.env.generator.num_loc, pl_module.env.generator.capacity)) - def on_train_epoch_end(self, trainer, pl_module): + def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): # Save the task model self.task_models.append(copy.deepcopy(pl_module.state_dict())) @@ -92,7 +88,7 @@ def _sample_task(self): self.selected_tasks.append(task_params) self.w = torch.softmax(torch.Tensor(w), dim=0) - def _load_task(self, pl_module, task_idx=0): + def _load_task(self, pl_module: pl.LightningModule, task_idx=0): # Load new training task (Update the environment) task_params, task_w = self.selected_tasks[task_idx], self.w[task_idx].item() task_capacity = math.ceil(30 + task_params[0] / 5) if task_params[0] >= 20 else 20 @@ -102,144 +98,3 @@ def _load_task(self, pl_module, task_idx=0): def _alpha_scheduler(self): self.alpha = max(self.alpha * self.alpha_decay, 0.0001) -class RL4COMetaTrainer(Trainer): - """Wrapper around Lightning Trainer, with some RL4CO magic for efficient training. - - # Meta training framework for addressing the generalization issue - # Based on Zhou et al. (2023): https://arxiv.org/abs/2305.19587 - - Note: - The most important hyperparameter to use is `reload_dataloaders_every_n_epochs`. - This allows for datasets to be re-created on the run and distributed by Lightning across - devices on each epoch. 
Setting to a value different than 1 may lead to overfitting to a - specific (such as the initial) data distribution. - - Args: - accelerator: hardware accelerator to use. - callbacks: list of callbacks. - logger: logger (or iterable collection of loggers) for experiment tracking. - min_epochs: minimum number of training epochs. - max_epochs: maximum number of training epochs. - strategy: training strategy to use (if any), such as Distributed Data Parallel (DDP). - devices: number of devices to train on (int) or which GPUs to train on (list or str) applied per node. - gradient_clip_val: 0 means don't clip. Defaults to 1.0 for stability. - precision: allows for mixed precision training. Can be specified as a string (e.g., '16'). - This also allows to use `FlashAttention` by default. - disable_profiling_executor: Disable JIT profiling executor. This reduces memory and increases speed. - auto_configure_ddp: Automatically configure DDP strategy if multiple GPUs are available. - reload_dataloaders_every_n_epochs: Set to a value different than 1 to reload dataloaders every n epochs. - matmul_precision: Set matmul precision for faster inference https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision - **kwargs: Additional keyword arguments passed to the Lightning Trainer. See :class:`lightning.pytorch.trainer.Trainer` for details. - """ - - def __init__( - self, - accelerator: Union[str, Accelerator] = "auto", - callbacks: Optional[List[Callback]] = None, - logger: Optional[Union[Logger, Iterable[Logger]]] = None, - min_epochs: Optional[int] = None, - max_epochs: Optional[int] = None, - strategy: Union[str, Strategy] = "auto", - devices: Union[List[int], str, int] = "auto", - gradient_clip_val: Union[int, float] = 1.0, - precision: Union[str, int] = "16-mixed", - reload_dataloaders_every_n_epochs: int = 1, - disable_profiling_executor: bool = True, - auto_configure_ddp: bool = True, - matmul_precision: Union[str, int] = "medium", - **kwargs, - ): - # Disable JIT profiling executor. This reduces memory and increases speed. - # Reference: https://github.com/HazyResearch/safari/blob/111d2726e7e2b8d57726b7a8b932ad8a4b2ad660/train.py#LL124-L129C17 - if disable_profiling_executor: - try: - torch._C._jit_set_profiling_executor(False) - torch._C._jit_set_profiling_mode(False) - except AttributeError: - pass - - # Configure DDP automatically if multiple GPUs are available - if auto_configure_ddp and strategy == "auto": - if devices == "auto": - n_devices = num_cuda_devices() - elif isinstance(devices, Iterable): - n_devices = len(devices) - else: - n_devices = devices - if n_devices > 1: - log.info( - "Configuring DDP strategy automatically with {} GPUs".format( - n_devices - ) - ) - strategy = DDPStrategy( - find_unused_parameters=True, # We set to True due to RL envs - gradient_as_bucket_view=True, # https://pytorch-lightning.readthedocs.io/en/stable/advanced/advanced_gpu.html#ddp-optimizations - ) - - # Set matmul precision for faster inference https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision - if matmul_precision is not None: - torch.set_float32_matmul_precision(matmul_precision) - - # Check if gradient_clip_val is set to None - if gradient_clip_val is None: - log.warning( - "gradient_clip_val is set to None. This may lead to unstable training." 
- ) - - # We should reload dataloaders every epoch for RL training - if reload_dataloaders_every_n_epochs != 1: - log.warning( - "We reload dataloaders every epoch for RL training. Setting reload_dataloaders_every_n_epochs to a value different than 1 " - + "may lead to unexpected behavior since the initial conditions will be the same for `n_epochs` epochs." - ) - - # Main call to `Trainer` superclass - super().__init__( - accelerator=accelerator, - callbacks=callbacks, - logger=logger, - min_epochs=min_epochs, - max_epochs=max_epochs, - strategy=strategy, - gradient_clip_val=gradient_clip_val, - devices=devices, - precision=precision, - reload_dataloaders_every_n_epochs=reload_dataloaders_every_n_epochs, - **kwargs, - ) - - def fit( - self, - model: "pl.LightningModule", - train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, - val_dataloaders: Optional[EVAL_DATALOADERS] = None, - datamodule: Optional[LightningDataModule] = None, - ckpt_path: Optional[str] = None, - ) -> None: - """ - We override the `fit` method to automatically apply and handle RL4CO magic - to 'self.automatic_optimization = False' models, such as PPO - - It behaves exactly like the original `fit` method, but with the following changes: - - if the given model is 'self.automatic_optimization = False', we override 'gradient_clip_val' as None - """ - - if not model.automatic_optimization: - if self.gradient_clip_val is not None: - log.warning( - "Overriding gradient_clip_val to None for 'automatic_optimization=False' models" - ) - self.gradient_clip_val = None - - # Fit (Inner-loop Optimization) - super().fit( - model=model, - train_dataloaders=train_dataloaders, - val_dataloaders=val_dataloaders, - datamodule=datamodule, - ckpt_path=ckpt_path, - ) - - - From bff31a998fe3dbae36e20d8fcd2107434cd7dcd8 Mon Sep 17 00:00:00 2001 From: jieyibi Date: Tue, 28 May 2024 15:45:02 +0800 Subject: [PATCH 05/10] Add test --- tests/test_training.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_training.py b/tests/test_training.py index 1e94a813..62c3c4eb 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -16,8 +16,10 @@ MatNet, NARGNNPolicy, SymNCO, + POMO ) from rl4co.utils import RL4COTrainer +from rl4co.utils.meta_trainer import ReptileCallback # Get env variable MAC_OS_GITHUB_RUNNER if "MAC_OS_GITHUB_RUNNER" in os.environ: @@ -115,6 +117,19 @@ def test_mdam(): trainer.fit(model) trainer.test(model) +def test_pomo_reptile(): + env = TSPEnv(generator_params=dict(num_loc=20)) + policy = AttentionModelPolicy(env_name=env.name, embed_dim=128, + num_encoder_layers=6, num_heads=8, + normalization="instance", use_graph_context=False) + model = POMO(env, policy, batch_size=5, train_data_size=5*3, val_data_size=10) + meta_callback = ReptileCallback( + meta_params={'data_type': 'size', 'sch_bar': 0.9, 'B': 2, 'alpha': 0.99, + 'alpha_decay': 0.999, 'min_size': 20, 'max_size': 50} + ) + trainer = RL4COTrainer(max_epochs=2, callbacks=[meta_callback], devices=1, accelerator=accelerator, limit_train_batches=3) + trainer.fit(model) + trainer.test(model) @pytest.mark.parametrize("SearchMethod", [ActiveSearch, EASEmb, EASLay]) def test_search_methods(SearchMethod): From 0f4032c4f783de90563044060cf1ca03cce6c61c Mon Sep 17 00:00:00 2001 From: jieyibi Date: Tue, 28 May 2024 19:29:20 +0800 Subject: [PATCH 06/10] Update meta learning framework --- examples/2d-meta_train.py | 16 +++++------ rl4co/utils/meta_trainer.py | 53 ++++++++++++++++++++++--------------- tests/test_training.py 
| 6 ++--- 3 files changed, 42 insertions(+), 33 deletions(-) diff --git a/examples/2d-meta_train.py b/examples/2d-meta_train.py index 2de9b629..b375f7be 100644 --- a/examples/2d-meta_train.py +++ b/examples/2d-meta_train.py @@ -46,15 +46,13 @@ def main(): # Meta callbacks meta_callback = ReptileCallback( - meta_params={ - 'data_type': 'size', # choose from ["size", "distribution", "size_distribution"] - 'sch_bar': 0.9, # for the task scheduler of size setting, where sch_epoch = sch_bar * epochs - 'B': 1, # the number of tasks in a mini-batch - 'alpha': 0.99, # params for the outer-loop optimization of reptile - 'alpha_decay': 0.999, # params for the outer-loop optimization of reptile - 'min_size': 20, # minimum of sampled size in meta tasks - 'max_size': 150, # maximum of sampled size in meta tasks - }, + num_tasks = 1, # the number of tasks in a mini-batch + alpha = 0.99, # params for the outer-loop optimization of reptile + alpha_decay = 0.999, # params for the outer-loop optimization of reptile + min_size = 20, # minimum of sampled size in meta tasks + max_size= 150, # maximum of sampled size in meta tasks + data_type="size", # choose from ["size", "distribution", "size_distribution"] + sch_bar=0.9, # for the task scheduler of size setting, where sch_epoch = sch_bar * epochs print_log=True # whether to print the sampled tasks in each meta iteration ) callbacks = [meta_callback, checkpoint_callback, rich_model_summary] diff --git a/rl4co/utils/meta_trainer.py b/rl4co/utils/meta_trainer.py index ada6e894..2b6425bd 100644 --- a/rl4co/utils/meta_trainer.py +++ b/rl4co/utils/meta_trainer.py @@ -14,22 +14,29 @@ class ReptileCallback(Callback): # Meta training framework for addressing the generalization issue # Based on Zhou et al. (2023): https://arxiv.org/abs/2305.19587 - def __init__(self, meta_params, print_log=True): + def __init__(self, + num_tasks, + alpha, + alpha_decay, + min_size, + max_size, + sch_bar = 0.9, + data_type = "size", + print_log=True): super().__init__() - self.meta_params = meta_params - self.print_log = print_log - - def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - # Initialize some hyperparameters - self.alpha = self.meta_params["alpha"] - self.alpha_decay = self.meta_params["alpha_decay"] - self.sch_bar = self.meta_params["sch_bar"] - if self.meta_params["data_type"] == "size": - self.task_set = [(n,) for n in range(self.meta_params["min_size"], self.meta_params["max_size"] + 1)] + self.num_tasks = num_tasks # i.e., B in the paper + self.alpha = alpha + self.alpha_decay = alpha_decay + self.sch_bar = sch_bar + self.print_log = print_log + if data_type == "size": + self.task_set = [(n,) for n in range(min_size, max_size + 1)] else: raise NotImplementedError + def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + # Sample a batch of tasks self._sample_task() self.selected_tasks[0] = (pl_module.env.generator.num_loc, ) @@ -40,12 +47,12 @@ def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModul self._alpha_scheduler() # Reinitialize the task model with the parameters of the meta model - if trainer.current_epoch % self.meta_params['B'] == 0: # Save the meta model + if trainer.current_epoch % self.num_tasks == 0: # Save the meta model self.meta_model_state_dict = copy.deepcopy(pl_module.state_dict()) self.task_models = [] # Print sampled tasks if self.print_log: - print('\n>> Meta epoch: {} (Exact epoch: {}), Training task: {}'.format(trainer.current_epoch//self.meta_params['B'], 
trainer.current_epoch, self.selected_tasks)) + print('\n>> Meta epoch: {} (Exact epoch: {}), Training task: {}'.format(trainer.current_epoch//self.num_tasks, trainer.current_epoch, self.selected_tasks)) else: pl_module.load_state_dict(self.meta_model_state_dict) @@ -57,13 +64,16 @@ def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModul # Print if self.print_log: - print('\n>> Training task: {}, capacity: {}'.format(pl_module.env.generator.num_loc, pl_module.env.generator.capacity)) + if hasattr(pl_module.env.generator, 'capacity'): + print('\n>> Training task: {}, capacity: {}'.format(pl_module.env.generator.num_loc, pl_module.env.generator.capacity)) + else: + print('\n>> Training task: {}'.format(pl_module.env.generator.num_loc)) def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): # Save the task model self.task_models.append(copy.deepcopy(pl_module.state_dict())) - if (trainer.current_epoch+1) % self.meta_params['B'] == 0: + if (trainer.current_epoch+1) % self.num_tasks == 0: # Outer-loop optimization (update the meta model with the parameters of the task model) with torch.no_grad(): state_dict = {params_key: (self.meta_model_state_dict[params_key] + @@ -73,17 +83,17 @@ def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule pl_module.load_state_dict(state_dict) # Get ready for the next meta-training iteration - if (trainer.current_epoch + 1) % self.meta_params['B'] == 0: + if (trainer.current_epoch + 1) % self.num_tasks == 0: # Sample a batch of tasks self._sample_task() # Load new training task (Update the environment) - self._load_task(pl_module, task_idx = (trainer.current_epoch+1) % self.meta_params['B']) + self._load_task(pl_module, task_idx = (trainer.current_epoch+1) % self.num_tasks) def _sample_task(self): # Sample a batch of tasks - w, self.selected_tasks = [1.0] * self.meta_params['B'], [] - for b in range(self.meta_params['B']): + w, self.selected_tasks = [1.0] * self.num_tasks, [] + for b in range(self.num_tasks): task_params = random.sample(self.task_set, 1)[0] self.selected_tasks.append(task_params) self.w = torch.softmax(torch.Tensor(w), dim=0) @@ -91,9 +101,10 @@ def _sample_task(self): def _load_task(self, pl_module: pl.LightningModule, task_idx=0): # Load new training task (Update the environment) task_params, task_w = self.selected_tasks[task_idx], self.w[task_idx].item() - task_capacity = math.ceil(30 + task_params[0] / 5) if task_params[0] >= 20 else 20 pl_module.env.generator.num_loc = task_params[0] - pl_module.env.generator.capacity = task_capacity + if hasattr(pl_module.env.generator, 'capacity'): + task_capacity = math.ceil(30 + task_params[0] / 5) if task_params[0] >= 20 else 20 + pl_module.env.generator.capacity = task_capacity def _alpha_scheduler(self): self.alpha = max(self.alpha * self.alpha_decay, 0.0001) diff --git a/tests/test_training.py b/tests/test_training.py index 62c3c4eb..1d92e6b2 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -122,10 +122,10 @@ def test_pomo_reptile(): policy = AttentionModelPolicy(env_name=env.name, embed_dim=128, num_encoder_layers=6, num_heads=8, normalization="instance", use_graph_context=False) - model = POMO(env, policy, batch_size=5, train_data_size=5*3, val_data_size=10) + model = POMO(env, policy, batch_size=5, train_data_size=5*3, val_data_size=10, test_data_size=10) meta_callback = ReptileCallback( - meta_params={'data_type': 'size', 'sch_bar': 0.9, 'B': 2, 'alpha': 0.99, - 'alpha_decay': 0.999, 'min_size': 20, 
'max_size': 50} + data_type="size", sch_bar=0.9, num_tasks=2, alpha = 0.99, + alpha_decay = 0.999, min_size = 20, max_size =50 ) trainer = RL4COTrainer(max_epochs=2, callbacks=[meta_callback], devices=1, accelerator=accelerator, limit_train_batches=3) trainer.fit(model) From 73d7a65df8a7d27d8548cee9ca95ab0723130fee Mon Sep 17 00:00:00 2001 From: jieyibi Date: Wed, 29 May 2024 21:14:33 +0800 Subject: [PATCH 07/10] Add support for cross-distribution generalization --- examples/2d-meta_train.py | 19 +-- rl4co/envs/common/distribution_utils.py | 184 ++++++++++++++++++++++++ rl4co/envs/common/utils.py | 9 +- rl4co/utils/meta_trainer.py | 109 ++++++++++---- 4 files changed, 287 insertions(+), 34 deletions(-) create mode 100644 rl4co/envs/common/distribution_utils.py diff --git a/examples/2d-meta_train.py b/examples/2d-meta_train.py index b375f7be..245ae9a2 100644 --- a/examples/2d-meta_train.py +++ b/examples/2d-meta_train.py @@ -1,3 +1,6 @@ +import sys +sys.path.append("/home/jieyi/rl4co") + from lightning.pytorch.callbacks import ModelCheckpoint, RichModelSummary from lightning.pytorch.loggers import WandbLogger @@ -9,7 +12,7 @@ def main(): # Set device - device_id = 0 + device_id = 1 # RL4CO env based on TorchRL env = CVRPEnv(generator_params={'num_loc': 50}) @@ -46,13 +49,13 @@ def main(): # Meta callbacks meta_callback = ReptileCallback( - num_tasks = 1, # the number of tasks in a mini-batch - alpha = 0.99, # params for the outer-loop optimization of reptile - alpha_decay = 0.999, # params for the outer-loop optimization of reptile - min_size = 20, # minimum of sampled size in meta tasks - max_size= 150, # maximum of sampled size in meta tasks - data_type="size", # choose from ["size", "distribution", "size_distribution"] - sch_bar=0.9, # for the task scheduler of size setting, where sch_epoch = sch_bar * epochs + num_tasks = 1, # the number of tasks in a mini-batch, i.e. `B` in the original paper + alpha = 0.99, # initial weight of the task model for the outer-loop optimization of reptile + alpha_decay = 0.999, # weight decay of the task model for the outer-loop optimization of reptile + min_size = 20, # minimum of sampled size in meta tasks (only supported in cross-size generalization) + max_size= 150, # maximum of sampled size in meta tasks (only supported in cross-size generalization) + data_type="size_distribution", # choose from ["size", "distribution", "size_distribution"] + sch_bar=0.9, # for the task scheduler of size setting, where lr_decay_epoch = sch_bar * epochs, i.e. after this epoch, learning rate will decay with a weight 0.1 print_log=True # whether to print the sampled tasks in each meta iteration ) callbacks = [meta_callback, checkpoint_callback, rich_model_summary] diff --git a/rl4co/envs/common/distribution_utils.py b/rl4co/envs/common/distribution_utils.py new file mode 100644 index 00000000..d03e92b8 --- /dev/null +++ b/rl4co/envs/common/distribution_utils.py @@ -0,0 +1,184 @@ +import torch + +class Cluster(): + + """ + Multiple gaussian distributed clusters, as in the Solomon benchmark dataset + Following the setting in Bi et al. 
2022 (https://arxiv.org/abs/2210.07686) + + Args: + n_cluster: Number of the gaussian distributed clusters + """ + def __init__(self, n_cluster: int = 3): + super().__init__() + self.lower, self.upper = 0.2, 0.8 + self.std = 0.07 + self.n_cluster = n_cluster + def sample(self, size): + + batch_size, num_loc, _ = size + + # Generate the centers of the clusters + center = self.lower + (self.upper - self.lower) * torch.rand(batch_size, self.n_cluster * 2) + + # Pre-define the coordinates + coords = torch.zeros(batch_size, num_loc, 2) + + # Calculate the size of each cluster + cluster_sizes = [num_loc // self.n_cluster] * self.n_cluster + for i in range(num_loc % self.n_cluster): + cluster_sizes[i] += 1 + + # Generate the coordinates + current_index = 0 + for i in range(self.n_cluster): + means = center[:, i * 2:(i + 1) * 2] + stds = torch.full((batch_size, 2), self.std) + points = torch.normal(means.unsqueeze(1).expand(-1, cluster_sizes[i], -1), + stds.unsqueeze(1).expand(-1, cluster_sizes[i], -1)) + coords[:, current_index:current_index + cluster_sizes[i], :] = points + current_index += cluster_sizes[i] + + # Confine the coordinates to range [0, 1] + coords.clamp_(0, 1) + + return coords + +class Mixed(): + + """ + 50% nodes sampled from uniform distribution, 50% nodes sampled from gaussian distribution, as in the Solomon benchmark dataset + Following the setting in Bi et al. 2022 (https://arxiv.org/abs/2210.07686) + + Args: + n_cluster_mix: Number of the gaussian distributed clusters + """ + + def __init__(self, n_cluster_mix=1): + super().__init__() + self.lower, self.upper = 0.2, 0.8 + self.std = 0.07 + self.n_cluster_mix = n_cluster_mix + def sample(self, size): + + batch_size, num_loc, _ = size + + # Generate the centers of the clusters + center = self.lower + (self.upper - self.lower) * torch.rand(batch_size, self.n_cluster_mix * 2) + + # Pre-define the coordinates sampled under uniform distribution + coords = torch.FloatTensor(batch_size, num_loc, 2).uniform_(0, 1) + + # Sample mutated index (default setting: 50% mutation) + mutate_idx = torch.stack([torch.randperm(num_loc)[:num_loc // 2] for _ in range(batch_size)]) + + # Generate the coordinates + segment_size = num_loc // (2 * self.n_cluster_mix) + remaining_indices = num_loc // 2 - segment_size * (self.n_cluster_mix - 1) + sizes = [segment_size] * (self.n_cluster_mix - 1) + [remaining_indices] + for i in range(self.n_cluster_mix): + indices = mutate_idx[:, sum(sizes[:i]):sum(sizes[:i + 1])] + means_x = center[:, 2 * i].unsqueeze(1).expand(-1, sizes[i]) + means_y = center[:, 2 * i + 1].unsqueeze(1).expand(-1, sizes[i]) + coords.scatter_(1, indices.unsqueeze(-1).expand(-1, -1, 2), + torch.stack([ + torch.normal(means_x.expand(-1, sizes[i]), self.std), + torch.normal(means_y.expand(-1, sizes[i]), self.std) + ], dim=2)) + + # Confine the coordinates to range [0, 1] + coords.clamp_(0, 1) + + return coords + +class Gaussian_Mixture(): + ''' + Following Zhou et al. (2023): https://arxiv.org/abs/2305.19587 + + Args: + num_modes: the number of clusters/modes in the Gaussian Mixture. + cdist: scale of the uniform distribution for center generation. 
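+        Convention used below: (num_modes, cdist) = (0, 0) falls back to uniform
+        sampling and (1, 1) to a single correlated Gaussian; any other pair draws
+        `num_modes` cluster centers from U(0, cdist)^2 with identity covariance.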
+ ''' + def __init__(self, num_modes: int = 0, cdist: int = 0): + super().__init__() + self.num_modes = num_modes + self.cdist = cdist + + def sample(self, size): + + batch_size, num_loc, _ = size + + if self.num_modes == 0: # (0, 0) - uniform + return torch.rand((batch_size, num_loc, 2)) + elif self.num_modes == 1 and self.cdist == 1: # (1, 1) - gaussian + return self.generate_gaussian(batch_size, num_loc) + else: + res = [self.generate_gaussian_mixture(num_loc) for _ in range(batch_size)] + return torch.stack(res) + + def generate_gaussian_mixture(self, num_loc): + + """Following the setting in Zhang et al. 2022 (https://arxiv.org/abs/2204.03236)""" + + # Randomly decide how many points each mode gets + nums = torch.multinomial(input=torch.ones(self.num_modes) / self.num_modes, num_samples=num_loc, replacement=True) + + # Prepare to collect points + coords = torch.empty((0, 2)) + + # Generate points for each mode + for i in range(self.num_modes): + num = (nums == i).sum() # Number of points in this mode + if num > 0: + center = torch.rand((1, 2)) * self.cdist + cov = torch.eye(2) # Covariance matrix + nxy = torch.distributions.MultivariateNormal(center.squeeze(), covariance_matrix=cov).sample((num,)) + coords = torch.cat((coords, nxy), dim=0) + + return self._global_min_max_scaling(coords) + + def generate_gaussian(self, batch_size, num_loc): + + """Following the setting in Xin et al. 2022 (https://openreview.net/pdf?id=nJuzV-izmPJ)""" + + # Mean and random covariances + mean = torch.full((batch_size, num_loc, 2), 0.5) + covs = torch.rand(batch_size) # Random covariances between 0 and 1 + + # Generate the coordinates + coords = torch.zeros((batch_size, num_loc, 2)) + for i in range(batch_size): + # Construct covariance matrix for each sample + cov_matrix = torch.tensor([[1.0, covs[i]], [covs[i], 1.0]]) + m = torch.distributions.MultivariateNormal(mean[i], covariance_matrix=cov_matrix) + coords[i] = m.sample() + + # Shuffle the coordinates + indices = torch.randperm(coords.size(0)) + coords = coords[indices] + + return self._batch_normalize_and_center(coords) + + def _global_min_max_scaling(self, coords): + + # Scale the points to [0, 1] using min-max scaling + coords_min = coords.min(0, keepdim=True).values + coords_max = coords.max(0, keepdim=True).values + coords = (coords - coords_min) / (coords_max - coords_min) + + return coords + + def _batch_normalize_and_center(self, coords): + # Step 1: Compute min and max along each batch + coords_min = coords.min(dim=1, keepdim=True).values + coords_max = coords.max(dim=1, keepdim=True).values + + # Step 2: Normalize coordinates to range [0, 1] + coords = coords - coords_min # Broadcasting subtracts min value on each coordinate + range_max = (coords_max - coords_min).max(dim=-1, keepdim=True).values # The maximum range among both coordinates + coords = coords / range_max # Divide by the max range to normalize + + # Step 3: Center the batch in the middle of the [0, 1] range + coords = coords + (1 - coords.max(dim=1, keepdim=True).values) / 2 # Centering the batch + + return coords diff --git a/rl4co/envs/common/utils.py b/rl4co/envs/common/utils.py index 5d7612f7..ebaed856 100644 --- a/rl4co/envs/common/utils.py +++ b/rl4co/envs/common/utils.py @@ -6,7 +6,7 @@ from tensordict.tensordict import TensorDict from torch.distributions import Exponential, Normal, Poisson, Uniform - +from rl4co.envs.common.distribution_utils import Cluster, Mixed, Gaussian_Mixture class Generator(metaclass=abc.ABCMeta): """Base data generator class, to be called with 
`env.generator(batch_size)`""" @@ -76,6 +76,12 @@ def get_sampler( ) # todo: should be also `low, high` and any other corner elif isinstance(distribution, Callable): return distribution(**kwargs) + elif distribution == "gaussian_mixture": + return Gaussian_Mixture(num_modes=kwargs['num_modes'], cdist=kwargs['cdist']) + elif distribution == "cluster": + return Cluster(kwargs['n_cluster']) + elif distribution == "mixed": + return Mixed(kwargs['n_cluster_mix']) else: raise ValueError(f"Invalid distribution type of {distribution}") @@ -87,3 +93,4 @@ def batch_to_scalar(param): if isinstance(param, torch.Tensor): return param.item() return param + diff --git a/rl4co/utils/meta_trainer.py b/rl4co/utils/meta_trainer.py index 2b6425bd..ccd64352 100644 --- a/rl4co/utils/meta_trainer.py +++ b/rl4co/utils/meta_trainer.py @@ -12,34 +12,55 @@ class ReptileCallback(Callback): - # Meta training framework for addressing the generalization issue - # Based on Zhou et al. (2023): https://arxiv.org/abs/2305.19587 + """ Meta training framework for addressing the generalization issue (implement the Reptile algorithm only) + Based on Manchanda et al. 2022 (https://arxiv.org/abs/2206.00787) and Zhou et al. 2023 (https://arxiv.org/abs/2305.19587) + + Args: + - num_tasks: the number of tasks in a mini-batch, i.e. `B` in the original paper + - alpha: initial weight of the task model for the outer-loop optimization of reptile + - alpha_decay: weight decay of the task model for the outer-loop optimization of reptile + - min_size: minimum problem size of the task (only supported in cross-size generalization) + - max_size: maximum problem size of the task (only supported in cross-size generalization) + - sch_bar: for the task scheduler of size setting, where lr_decay_epoch = sch_bar * epochs, i.e. 
after this epoch, learning rate will decay with a weight 0.1 + - data_type: type of the tasks, chosen from ["size", "distribution", "size_distribution"] + - print_log: whether to print the specific task sampled in each inner-loop optimization + """ def __init__(self, - num_tasks, - alpha, - alpha_decay, - min_size, - max_size, - sch_bar = 0.9, - data_type = "size", - print_log=True): + num_tasks: int, + alpha: float, + alpha_decay: float, + min_size: int, + max_size: int, + sch_bar: float = 0.9, + data_type: str = "size", + print_log: bool =True): + super().__init__() - self.num_tasks = num_tasks # i.e., B in the paper + self.num_tasks = num_tasks self.alpha = alpha self.alpha_decay = alpha_decay self.sch_bar = sch_bar self.print_log = print_log - if data_type == "size": - self.task_set = [(n,) for n in range(min_size, max_size + 1)] - else: - raise NotImplementedError + self.data_type = data_type + self.task_set = self._generate_task_set(data_type, min_size, max_size) def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: # Sample a batch of tasks self._sample_task() - self.selected_tasks[0] = (pl_module.env.generator.num_loc, ) + + # Pre-set the distribution + if self.data_type == "size_distribution": + pl_module.env.generator.loc_distribution = "gaussian_mixture" + self.selected_tasks[0] = (pl_module.env.generator.num_loc, 0, 0) + elif self.data_type == "size": + pl_module.env.generator.loc_distribution = "uniform" + self.selected_tasks[0] = (pl_module.env.generator.num_loc, ) + elif self.data_type == "distribution": + pl_module.env.generator.loc_distribution = "gaussian_mixture" + self.selected_tasks[0] = (0, 0) + self.task_params = self.selected_tasks[0] def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: @@ -65,9 +86,9 @@ def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModul # Print if self.print_log: if hasattr(pl_module.env.generator, 'capacity'): - print('\n>> Training task: {}, capacity: {}'.format(pl_module.env.generator.num_loc, pl_module.env.generator.capacity)) + print('>> Training task: {}, capacity: {}'.format(self.task_params, pl_module.env.generator.capacity)) else: - print('\n>> Training task: {}'.format(pl_module.env.generator.num_loc)) + print('>> Training task: {}'.format(self.task_params)) def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): @@ -87,25 +108,63 @@ def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule # Sample a batch of tasks self._sample_task() - # Load new training task (Update the environment) + # Load new training task (Update the environment) for the next meta-training iteration self._load_task(pl_module, task_idx = (trainer.current_epoch+1) % self.num_tasks) def _sample_task(self): + # Sample a batch of tasks - w, self.selected_tasks = [1.0] * self.num_tasks, [] + self.selected_tasks = [] for b in range(self.num_tasks): task_params = random.sample(self.task_set, 1)[0] self.selected_tasks.append(task_params) - self.w = torch.softmax(torch.Tensor(w), dim=0) def _load_task(self, pl_module: pl.LightningModule, task_idx=0): + # Load new training task (Update the environment) - task_params, task_w = self.selected_tasks[task_idx], self.w[task_idx].item() - pl_module.env.generator.num_loc = task_params[0] - if hasattr(pl_module.env.generator, 'capacity'): - task_capacity = math.ceil(30 + task_params[0] / 5) if task_params[0] >= 20 else 20 + self.task_params = self.selected_tasks[task_idx] + + if 
self.data_type == "size_distribution": + assert len(self.task_params) == 3 + pl_module.env.generator.num_loc = self.task_params[0] + pl_module.env.generator.num_modes = self.task_params[1] + pl_module.env.generator.cdist = self.task_params[2] + elif self.data_type == "distribution": # fixed size + assert len(self.task_params) == 2 + pl_module.env.generator.num_modes = self.task_params[0] + pl_module.env.generator.cdist = self.task_params[1] + elif self.data_type == "size": # fixed distribution + assert len(self.task_params) == 1 + pl_module.env.generator.num_loc = self.task_params[0] + + if hasattr(pl_module.env.generator, 'capacity') and self.data_type in ["size_distribution", "size"]: + task_capacity = math.ceil(30 + self.task_params[0] / 5) if self.task_params[0] >= 20 else 20 pl_module.env.generator.capacity = task_capacity def _alpha_scheduler(self): self.alpha = max(self.alpha * self.alpha_decay, 0.0001) + def _generate_task_set(self, data_type, min_size, max_size): + """ + Following the setting in Zhou et al. 2023 (https://arxiv.org/abs/2305.19587) + Current setting: + size: (n,) \in [20, 150] + distribution: (m, c) \in {(0, 0) + [1-9] * [1, 10, 20, 30, 40, 50]} + size_distribution: (n, m, c) \in [50, 200, 5] * {(0, 0) + (1, 1) + [3, 5, 7] * [10, 30, 50]} + """ + + if data_type == "distribution": # focus on TSP100 with gaussian mixture distributions + task_set = [(0, 0)] + [(m, c) for m in range(1, 10) for c in [1, 10, 20, 30, 40, 50]] + elif data_type == "size": # focus on uniform distribution with different sizes + task_set = [(n,) for n in range(min_size, max_size + 1)] + elif data_type == "size_distribution": + dist_set = [(0, 0), (1, 1)] + [(m, c) for m in [3, 5, 7] for c in [10, 30, 50]] + task_set = [(n, m, c) for n in range(50, 201, 5) for (m, c) in dist_set] + else: + raise NotImplementedError + + print(">> Generating training task set: {} tasks with type {}".format(len(task_set), data_type)) + print(">> Training task set: {}".format(task_set)) + + return task_set + From 5a029fa3c2646e68a49d6687d792f1a9e3ad1b39 Mon Sep 17 00:00:00 2001 From: jieyibi Date: Wed, 29 May 2024 21:51:08 +0800 Subject: [PATCH 08/10] add document --- examples/2d-meta_train.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/examples/2d-meta_train.py b/examples/2d-meta_train.py index 245ae9a2..41b6db4c 100644 --- a/examples/2d-meta_train.py +++ b/examples/2d-meta_train.py @@ -1,6 +1,3 @@ -import sys -sys.path.append("/home/jieyi/rl4co") - from lightning.pytorch.callbacks import ModelCheckpoint, RichModelSummary from lightning.pytorch.loggers import WandbLogger @@ -12,7 +9,7 @@ def main(): # Set device - device_id = 1 + device_id = 0 # RL4CO env based on TorchRL env = CVRPEnv(generator_params={'num_loc': 50}) @@ -31,7 +28,7 @@ def main(): model = POMO(env, policy, batch_size=64, # meta_batch_size - train_data_size=64 * 50, # each epoch + train_data_size=64 * 50, # equals to (meta_batch_size) * (gradient decent steps in the inner-loop optimization of meta-learning method) val_data_size=0, optimizer_kwargs={"lr": 1e-4, "weight_decay": 1e-6}, ) From d11788d333951e7fa94e6becc2540bb7e5ec0407 Mon Sep 17 00:00:00 2001 From: jieyibi Date: Tue, 11 Jun 2024 15:54:10 +0800 Subject: [PATCH 09/10] add support for training on multiple mixed distributions --- rl4co/envs/common/distribution_utils.py | 70 +++++++++++++++++++++++++ rl4co/envs/common/utils.py | 6 ++- 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/rl4co/envs/common/distribution_utils.py 
b/rl4co/envs/common/distribution_utils.py index d03e92b8..f440adeb 100644 --- a/rl4co/envs/common/distribution_utils.py +++ b/rl4co/envs/common/distribution_utils.py @@ -1,3 +1,4 @@ +import random import torch class Cluster(): @@ -182,3 +183,72 @@ def _batch_normalize_and_center(self, coords): coords = coords + (1 - coords.max(dim=1, keepdim=True).values) / 2 # Centering the batch return coords + +class Mix_Distribution(): + + ''' + Mixture of three exemplar distributions in batch-level, i.e. Uniform, Cluster, Mixed + Following the setting in Bi et al. 2022 (https://arxiv.org/abs/2210.07686) + + Args: + n_cluster: Number of the gaussian distributed clusters in Cluster distribution + n_cluster_mix: Number of the gaussian distributed clusters in Mixed distribution + ''' + def __init__(self, n_cluster=3, n_cluster_mix=1): + super().__init__() + self.lower, self.upper = 0.2, 0.8 + self.std = 0.07 + self.Mixed = Mixed(n_cluster_mix=n_cluster_mix) + self.Cluster = Cluster(n_cluster=n_cluster) + + def sample(self, size): + + batch_size, num_loc, _ = size + + # Pre-define the coordinates sampled under uniform distribution + coords = torch.FloatTensor(batch_size, num_loc, 2).uniform_(0, 1) + + # Random sample probability for the distribution of each sample + p = torch.rand(batch_size) + + # Mixed + mask = p <= 0.33 + n_mixed = mask.sum().item() + if n_mixed > 0: + coords[mask] = self.Mixed.sample((n_mixed, num_loc, 2)) + + # Cluster + mask = (p > 0.33) & (p <= 0.66) + n_cluster = mask.sum().item() + if n_cluster > 0: + coords[mask] = self.Cluster.sample((n_cluster, num_loc, 2)) + + # The remaining ones are uniformly distributed + return coords + +class Mix_Multi_Distributions(): + + ''' + Mixture of 11 Gaussian-like distributions in batch-level + Following the setting in Zhou et al. 
(2023): https://arxiv.org/abs/2305.19587
+    '''
+    def __init__(self):
+        super().__init__()
+        self.dist_set = [(0, 0), (1, 1)] + [(m, c) for m in [3, 5, 7] for c in [10, 30, 50]]
+
+    def sample(self, size):
+        batch_size, num_loc, _ = size
+        coords = torch.zeros(batch_size, num_loc, 2)
+
+        # Pre-select distributions for the entire batch
+        dists = [random.choice(self.dist_set) for _ in range(batch_size)]
+        unique_dists = list(set(dists))  # Unique distributions to minimize re-instantiation
+
+        # Instantiate Gaussian_Mixture only once per unique distribution
+        gm_instances = {dist: Gaussian_Mixture(*dist) for dist in unique_dists}
+
+        # Batch process where possible
+        for i, dist in enumerate(dists):
+            coords[i] = gm_instances[dist].sample((1, num_loc, 2)).squeeze(0)
+
+        return coords
diff --git a/rl4co/envs/common/utils.py b/rl4co/envs/common/utils.py
index ebaed856..0c2a27e8 100644
--- a/rl4co/envs/common/utils.py
+++ b/rl4co/envs/common/utils.py
@@ -6,7 +6,7 @@
 from tensordict.tensordict import TensorDict
 from torch.distributions import Exponential, Normal, Poisson, Uniform

-from rl4co.envs.common.distribution_utils import Cluster, Mixed, Gaussian_Mixture
+from rl4co.envs.common.distribution_utils import Cluster, Mixed, Gaussian_Mixture, Mix_Distribution, Mix_Multi_Distributions

 class Generator(metaclass=abc.ABCMeta):
     """Base data generator class, to be called with `env.generator(batch_size)`"""
@@ -82,6 +82,10 @@ def get_sampler(
         return Cluster(kwargs['n_cluster'])
     elif distribution == "mixed":
         return Mixed(kwargs['n_cluster_mix'])
+    elif distribution == "mix_distribution":
+        return Mix_Distribution(kwargs['n_cluster'], kwargs['n_cluster_mix'])
+    elif distribution == "mix_multi_distributions":
+        return Mix_Multi_Distributions()
     else:
         raise ValueError(f"Invalid distribution type of {distribution}")

From 60fa8c8a7f37ed2825c3323b54022730dbea0135 Mon Sep 17 00:00:00 2001
From: Jieyi Bi <113847374+jieyibi@users.noreply.github.com>
Date: Wed, 19 Jun 2024 21:19:42 +0800
Subject: [PATCH 10/10] Update 2d-meta_train.py

Change some parameters for performance

---
 examples/2d-meta_train.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/examples/2d-meta_train.py b/examples/2d-meta_train.py
index 41b6db4c..1f3fb8d4 100644
--- a/examples/2d-meta_train.py
+++ b/examples/2d-meta_train.py
@@ -35,7 +35,7 @@ def main():

     # Example callbacks
     checkpoint_callback = ModelCheckpoint(
-        dirpath="checkpoints",  # save to checkpoints/
+        dirpath="meta_pomo/checkpoints",  # save to meta_pomo/checkpoints/
         filename="epoch_{epoch:03d}",  # save as epoch_XXX.ckpt
         save_top_k=1,  # save only the best model
         save_last=True,  # save the last model
@@ -47,8 +47,8 @@ def main():

     # Meta callbacks
     meta_callback = ReptileCallback(
         num_tasks = 1,  # the number of tasks in a mini-batch, i.e. `B` in the original paper
-        alpha = 0.99,  # initial weight of the task model for the outer-loop optimization of reptile
-        alpha_decay = 0.999,  # weight decay of the task model for the outer-loop optimization of reptile
+        alpha = 0.9,  # initial weight of the task model for the outer-loop optimization of reptile
+        alpha_decay = 1,  # weight decay of the task model for the outer-loop optimization of reptile. No decay performs better.
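+        # With alpha_decay = 1, _alpha_scheduler keeps the outer-loop weight fixed at `alpha` for the whole run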
         min_size = 20,  # minimum of sampled size in meta tasks (only supported in cross-size generalization)
         max_size = 150,  # maximum of sampled size in meta tasks (only supported in cross-size generalization)
         data_type="size_distribution",  # choose from ["size", "distribution", "size_distribution"]
         sch_bar=0.9,  # for the task scheduler of size setting, where lr_decay_epoch = sch_bar * epochs, i.e. after this epoch, learning rate will decay with a weight 0.1
         print_log=True  # whether to print the sampled tasks in each meta iteration
     )
     callbacks = [meta_callback, checkpoint_callback, rich_model_summary]
@@ -63,7 +63,7 @@ def main():

     # Adjust your trainer to the number of epochs you want to run
     trainer = RL4COTrainer(
-        max_epochs=20000,  # (the number of meta_model updates) * (the number of tasks in a mini-batch)
+        max_epochs=15000,  # (the number of meta_model updates) * (the number of tasks in a mini-batch)
         callbacks=callbacks,
         accelerator="gpu",
         devices=[device_id],
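
Note on the update rule used throughout this series: the state-dict arithmetic in
`ReptileCallback.on_train_epoch_end` is the standard Reptile outer-loop step,
theta <- theta + alpha * mean_i(theta_i - theta), averaged over the `B` task-adapted
models. A minimal self-contained sketch of that update follows (illustrative only,
not part of the patches; the function name and toy tensors are invented for this note):

import torch

def reptile_outer_update(meta_state: dict, task_states: list, alpha: float) -> dict:
    # theta <- theta + alpha * mean_i(theta_i - theta), computed entry-wise over
    # the state dict, mirroring the torch.no_grad() block in on_train_epoch_end
    with torch.no_grad():
        return {
            key: meta_state[key]
            + alpha * torch.mean(
                torch.stack([task[key] - meta_state[key] for task in task_states], dim=0).float(),
                dim=0,
            )
            for key in meta_state
        }

# Toy usage: two task models adapted from a zero-initialized meta model
meta = {"w": torch.zeros(3)}
tasks = [{"w": torch.ones(3)}, {"w": 3 * torch.ones(3)}]
print(reptile_outer_update(meta, tasks, alpha=0.9)["w"])  # tensor([1.8000, 1.8000, 1.8000])

With alpha_decay < 1 (patches 01-07) this step size shrinks each epoch via
_alpha_scheduler; patch 10 sets alpha_decay = 1, keeping it fixed at 0.9.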