Add meta learning framework #183

Merged
merged 10 commits
Jun 20, 2024
Changes from 6 commits
80 changes: 80 additions & 0 deletions examples/2d-meta_train.py
@@ -0,0 +1,80 @@
from lightning.pytorch.callbacks import ModelCheckpoint, RichModelSummary
from lightning.pytorch.loggers import WandbLogger

from rl4co.envs import CVRPEnv
from rl4co.models.zoo.am import AttentionModelPolicy
from rl4co.models.zoo.pomo import POMO
from rl4co.utils.trainer import RL4COTrainer
from rl4co.utils.meta_trainer import ReptileCallback

def main():
# Set device
device_id = 0

# RL4CO env based on TorchRL
env = CVRPEnv(generator_params={'num_loc': 50})

# Policy: neural network, in this case with encoder-decoder architecture
# Note that this is the same policy configuration used by POMO in the original paper
policy = AttentionModelPolicy(env_name=env.name,
embed_dim=128,
num_encoder_layers=6,
num_heads=8,
normalization="instance",
use_graph_context=False
)

# RL Model (POMO)
model = POMO(env,
policy,
batch_size=64, # meta_batch_size
train_data_size=64 * 50, # each epoch
val_data_size=0,
optimizer_kwargs={"lr": 1e-4, "weight_decay": 1e-6},
)

# Example callbacks
checkpoint_callback = ModelCheckpoint(
dirpath="checkpoints", # save to checkpoints/
filename="epoch_{epoch:03d}", # save as epoch_XXX.ckpt
save_top_k=1, # save only the best model
save_last=True, # save the last model
monitor="val/reward", # monitor validation reward
mode="max", # maximize validation reward
)
rich_model_summary = RichModelSummary(max_depth=3) # model summary callback

# Meta callbacks
meta_callback = ReptileCallback(
num_tasks=1, # number of tasks in a mini-batch, i.e., B in the paper
alpha=0.99, # step size of the outer-loop (Reptile) update
alpha_decay=0.999, # decay rate of alpha
min_size=20, # minimum problem size sampled for meta tasks
max_size=150, # maximum problem size sampled for meta tasks
data_type="size", # choose from ["size", "distribution", "size_distribution"]
sch_bar=0.9, # task scheduler for the size setting, where sch_epoch = sch_bar * epochs
print_log=True # whether to print the sampled tasks at each meta iteration
)
callbacks = [meta_callback, checkpoint_callback, rich_model_summary]

# Logger
logger = WandbLogger(project="rl4co", name=f"{env.name}_pomo_reptile")
# logger = None # uncomment this line if you don't want logging

# Adjust your trainer to the number of epochs you want to run
trainer = RL4COTrainer(
max_epochs=20000, # (the number of meta_model updates) * (the number of tasks in a mini-batch)
callbacks=callbacks,
accelerator="gpu",
devices=[device_id],
logger=logger,
limit_train_batches=50 # gradient descent steps in the inner-loop optimization of the meta-learning method
)

# Fit
trainer.fit(model)


if __name__ == "__main__":
main()
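
For context (not part of this PR), the trainer arguments above map onto the meta-learning loop roughly as follows: each Lightning epoch runs the inner loop for one task, so the numbers below are a back-of-envelope sketch using the values from this example script.

# Hypothetical accounting, assuming the settings used above
num_tasks = 1           # tasks per mini-batch (ReptileCallback num_tasks)
max_epochs = 20000      # Lightning epochs = (meta-model updates) * num_tasks
inner_steps = 50        # limit_train_batches = inner-loop gradient steps per task

meta_updates = max_epochs // num_tasks        # 20,000 outer-loop (Reptile) updates
total_inner_steps = max_epochs * inner_steps  # 1,000,000 inner-loop gradient steps in total
print(meta_updates, total_inner_steps)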

111 changes: 111 additions & 0 deletions rl4co/utils/meta_trainer.py
@@ -0,0 +1,111 @@
import copy
import math
import random

import lightning.pytorch as pl
import torch
from lightning import Callback
from torch.optim import Adam

from rl4co import utils

log = utils.get_pylogger(__name__)


class ReptileCallback(Callback):

# Meta training framework for addressing the generalization issue
# Based on Zhou et al. (2023): https://arxiv.org/abs/2305.19587
def __init__(self,
num_tasks,
alpha,
alpha_decay,
min_size,
max_size,
sch_bar = 0.9,
data_type = "size",
print_log=True):
super().__init__()
Member

[Documentation] It's recommended to have a doc with parameters, possibly including the data type, constraints, hints, etc. Better for us "non-experts" to understand 😆

class ReptileCallback(Callback):
    """Meta training framework for addressing the generalization issue
    Based on Zhou et al. (2023): https://arxiv.org/abs/2305.19587

    Args:
        - num_tasks: number of task types, i.e. `B` in the original paper
        - alpha: ...
        - ...
    """

    def __init__(
        self,
        num_tasks: int,
        alpha: float,
        alpha_decay: float,
        min_size: int,
        max_size: int,
        sch_bar: float = 0.9,
        data_type: str = "size",
        print_log: bool = True,
    ):
        super().__init__()

Member Author

Hi Chuanbo. I have added the documentation as you recommended, along with the generation code for some distributions defined in the generalization-related works. The meta-learning framework now supports cross-distribution generalization.


self.num_tasks = num_tasks # i.e., B in the paper
self.alpha = alpha
self.alpha_decay = alpha_decay
self.sch_bar = sch_bar
self.print_log = print_log
if data_type == "size":
self.task_set = [(n,) for n in range(min_size, max_size + 1)]
else:
raise NotImplementedError

def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:

# Sample a batch of tasks
self._sample_task()
self.selected_tasks[0] = (pl_module.env.generator.num_loc, )

def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:

# Alpha scheduler (decay for the update of meta model)
self._alpha_scheduler()

# Reinitialize the task model with the parameters of the meta model
if trainer.current_epoch % self.num_tasks == 0: # Save the meta model
self.meta_model_state_dict = copy.deepcopy(pl_module.state_dict())
self.task_models = []
# Print sampled tasks
if self.print_log:
print('\n>> Meta epoch: {} (Exact epoch: {}), Training task: {}'.format(trainer.current_epoch//self.num_tasks, trainer.current_epoch, self.selected_tasks))
else:
pl_module.load_state_dict(self.meta_model_state_dict)

# Reinitialize the optimizer every epoch
lr_decay = 0.1 if trainer.current_epoch+1 == int(self.sch_bar * trainer.max_epochs) else 1
old_lr = trainer.optimizers[0].param_groups[0]['lr']
new_optimizer = Adam(pl_module.parameters(), lr=old_lr * lr_decay)
trainer.optimizers = [new_optimizer]

# Print
if self.print_log:
if hasattr(pl_module.env.generator, 'capacity'):
print('\n>> Training task: {}, capacity: {}'.format(pl_module.env.generator.num_loc, pl_module.env.generator.capacity))
else:
print('\n>> Training task: {}'.format(pl_module.env.generator.num_loc))

def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):

# Save the task model
self.task_models.append(copy.deepcopy(pl_module.state_dict()))
if (trainer.current_epoch+1) % self.num_tasks == 0:
# Outer-loop optimization (update the meta model with the parameters of the task model)
with torch.no_grad():
state_dict = {params_key: (self.meta_model_state_dict[params_key] +
self.alpha * torch.mean(torch.stack([fast_weight[params_key] - self.meta_model_state_dict[params_key]
for fast_weight in self.task_models], dim=0).float(), dim=0))
for params_key in self.meta_model_state_dict}
pl_module.load_state_dict(state_dict)

# Get ready for the next meta-training iteration
if (trainer.current_epoch + 1) % self.num_tasks == 0:
# Sample a batch of tasks
self._sample_task()

# Load new training task (Update the environment)
self._load_task(pl_module, task_idx = (trainer.current_epoch+1) % self.num_tasks)

def _sample_task(self):
# Sample a batch of tasks
w, self.selected_tasks = [1.0] * self.num_tasks, []
for b in range(self.num_tasks):
task_params = random.sample(self.task_set, 1)[0]
self.selected_tasks.append(task_params)
self.w = torch.softmax(torch.Tensor(w), dim=0)

def _load_task(self, pl_module: pl.LightningModule, task_idx=0):
# Load new training task (Update the environment)
task_params, task_w = self.selected_tasks[task_idx], self.w[task_idx].item()
pl_module.env.generator.num_loc = task_params[0]
if hasattr(pl_module.env.generator, 'capacity'):
task_capacity = math.ceil(30 + task_params[0] / 5) if task_params[0] >= 20 else 20
pl_module.env.generator.capacity = task_capacity

def _alpha_scheduler(self):
self.alpha = max(self.alpha * self.alpha_decay, 0.0001)
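
For readers new to Reptile, the outer-loop update in on_train_epoch_end above can be read as: move the meta parameters toward the average of the task-adapted parameters with step size alpha. Below is a minimal standalone sketch of that update (a hypothetical helper, not part of this PR), with toy tensors in place of real model state dicts.

import torch

def reptile_outer_update(meta_state, task_states, alpha):
    # meta_state: dict of parameter tensors (the saved meta model)
    # task_states: list of dicts with the same keys (task models after inner-loop training)
    # alpha: outer-loop step size
    with torch.no_grad():
        return {
            k: meta_state[k]
            + alpha * torch.stack([ts[k] - meta_state[k] for ts in task_states]).float().mean(dim=0)
            for k in meta_state
        }

# Toy illustration with a single 2-element "parameter"
meta = {"w": torch.zeros(2)}
tasks = [{"w": torch.tensor([1.0, 0.0])}, {"w": torch.tensor([0.0, 1.0])}]
print(reptile_outer_update(meta, tasks, alpha=0.99))  # {'w': tensor([0.4950, 0.4950])}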

15 changes: 15 additions & 0 deletions tests/test_training.py
@@ -16,8 +16,10 @@
MatNet,
NARGNNPolicy,
SymNCO,
POMO
)
from rl4co.utils import RL4COTrainer
from rl4co.utils.meta_trainer import ReptileCallback

# Get env variable MAC_OS_GITHUB_RUNNER
if "MAC_OS_GITHUB_RUNNER" in os.environ:
@@ -115,6 +117,19 @@ def test_mdam():
trainer.fit(model)
trainer.test(model)

def test_pomo_reptile():
env = TSPEnv(generator_params=dict(num_loc=20))
policy = AttentionModelPolicy(env_name=env.name, embed_dim=128,
num_encoder_layers=6, num_heads=8,
normalization="instance", use_graph_context=False)
model = POMO(env, policy, batch_size=5, train_data_size=5*3, val_data_size=10, test_data_size=10)
meta_callback = ReptileCallback(
data_type="size", sch_bar=0.9, num_tasks=2, alpha=0.99,
alpha_decay=0.999, min_size=20, max_size=50
)
trainer = RL4COTrainer(max_epochs=2, callbacks=[meta_callback], devices=1, accelerator=accelerator, limit_train_batches=3)
trainer.fit(model)
trainer.test(model)

@pytest.mark.parametrize("SearchMethod", [ActiveSearch, EASEmb, EASLay])
def test_search_methods(SearchMethod):