Add meta learning framework #183
Merged
Changes shown are from 6 of the 10 commits. Commit history (all by jieyibi):
- 475e430 Update meta trainer
- 5cbe7b9 Update meta trainer
- 05e0870 Add source
- e8435c6 Update Reptile callbacks
- bff31a9 Add test
- 0f4032c Update meta learning framework
- 73d7a65 Add support for cross-distribution generalization
- 5a029fa add document
- d11788d add support for training on multiple mixed distributions
- 60fa8c8 Update 2d-meta_train.py
File: 2d-meta_train.py (new file, 80 lines added)

```python
from lightning.pytorch.callbacks import ModelCheckpoint, RichModelSummary
from lightning.pytorch.loggers import WandbLogger

from rl4co.envs import CVRPEnv
from rl4co.models.zoo.am import AttentionModelPolicy
from rl4co.models.zoo.pomo import POMO
from rl4co.utils.trainer import RL4COTrainer
from rl4co.utils.meta_trainer import ReptileCallback


def main():
    # Set device
    device_id = 0

    # RL4CO env based on TorchRL
    env = CVRPEnv(generator_params={"num_loc": 50})

    # Policy: neural network, in this case with encoder-decoder architecture
    # Note: this matches the policy configuration used in the original POMO paper
    policy = AttentionModelPolicy(
        env_name=env.name,
        embed_dim=128,
        num_encoder_layers=6,
        num_heads=8,
        normalization="instance",
        use_graph_context=False,
    )

    # RL Model (POMO)
    model = POMO(
        env,
        policy,
        batch_size=64,  # meta_batch_size
        train_data_size=64 * 50,  # instances per epoch
        val_data_size=0,
        optimizer_kwargs={"lr": 1e-4, "weight_decay": 1e-6},
    )

    # Example callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints",  # save to checkpoints/
        filename="epoch_{epoch:03d}",  # save as epoch_XXX.ckpt
        save_top_k=1,  # save only the best model
        save_last=True,  # save the last model
        monitor="val/reward",  # monitor validation reward
        mode="max",  # maximize validation reward
    )
    rich_model_summary = RichModelSummary(max_depth=3)  # model summary callback

    # Meta callbacks
    meta_callback = ReptileCallback(
        num_tasks=1,  # number of tasks in a mini-batch, i.e. B in the paper
        alpha=0.99,  # initial weight of the outer-loop Reptile update
        alpha_decay=0.999,  # per-epoch decay factor applied to alpha
        min_size=20,  # minimum problem size sampled for meta tasks
        max_size=150,  # maximum problem size sampled for meta tasks
        data_type="size",  # choose from ["size", "distribution", "size_distribution"]
        sch_bar=0.9,  # task-scheduler threshold for the size setting: sch_epoch = sch_bar * epochs
        print_log=True,  # whether to print the sampled tasks in each meta iteration
    )
    callbacks = [meta_callback, checkpoint_callback, rich_model_summary]

    # Logger
    logger = WandbLogger(project="rl4co", name=f"{env.name}_pomo_reptile")
    # logger = None  # uncomment this line if you don't want logging

    # Adjust your trainer to the number of epochs you want to run
    trainer = RL4COTrainer(
        max_epochs=20000,  # (number of meta-model updates) * (number of tasks in a mini-batch)
        callbacks=callbacks,
        accelerator="gpu",
        devices=[device_id],
        logger=logger,
        limit_train_batches=50,  # gradient descent steps in the inner-loop optimization
    )

    # Fit
    trainer.fit(model)


if __name__ == "__main__":
    main()
```
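As the comments in the script suggest, each Lightning epoch corresponds to one inner-loop task, so `max_epochs` should equal (meta-model updates) * (`num_tasks`). A quick back-of-the-envelope check of the settings above (plain Python, illustrative only, not part of the PR):

```python
# Illustrative arithmetic based on the example's settings
max_epochs = 20000        # Lightning epochs = inner-loop tasks processed
num_tasks = 1             # tasks per meta-iteration (B in the paper)
limit_train_batches = 50  # inner-loop gradient steps per task

meta_iterations = max_epochs // num_tasks             # 20000 meta-model updates
total_inner_steps = max_epochs * limit_train_batches  # 1,000,000 gradient steps overall
print(meta_iterations, total_inner_steps)
```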
File: rl4co/utils/meta_trainer.py (new file, 111 lines added)

```python
import copy
import math
import random

import lightning.pytorch as pl
import torch
from lightning import Callback
from torch.optim import Adam

from rl4co import utils

log = utils.get_pylogger(__name__)


class ReptileCallback(Callback):
    # Meta-training framework for addressing the generalization issue
    # Based on Zhou et al. (2023): https://arxiv.org/abs/2305.19587
    def __init__(
        self,
        num_tasks,
        alpha,
        alpha_decay,
        min_size,
        max_size,
        sch_bar=0.9,
        data_type="size",
        print_log=True,
    ):
        super().__init__()

        self.num_tasks = num_tasks  # i.e., B in the paper
        self.alpha = alpha
        self.alpha_decay = alpha_decay
        self.sch_bar = sch_bar
        self.print_log = print_log
        if data_type == "size":
            self.task_set = [(n,) for n in range(min_size, max_size + 1)]
        else:
            raise NotImplementedError

    def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        # Sample a batch of tasks
        self._sample_task()
        self.selected_tasks[0] = (pl_module.env.generator.num_loc,)

    def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        # Alpha scheduler (decay for the update of the meta model)
        self._alpha_scheduler()

        # Reinitialize the task model with the parameters of the meta model
        if trainer.current_epoch % self.num_tasks == 0:  # Save the meta model
            self.meta_model_state_dict = copy.deepcopy(pl_module.state_dict())
            self.task_models = []
            # Print sampled tasks
            if self.print_log:
                print('\n>> Meta epoch: {} (Exact epoch: {}), Training task: {}'.format(
                    trainer.current_epoch // self.num_tasks, trainer.current_epoch, self.selected_tasks))
        else:
            pl_module.load_state_dict(self.meta_model_state_dict)

        # Reinitialize the optimizer every epoch
        lr_decay = 0.1 if trainer.current_epoch + 1 == int(self.sch_bar * trainer.max_epochs) else 1
        old_lr = trainer.optimizers[0].param_groups[0]['lr']
        new_optimizer = Adam(pl_module.parameters(), lr=old_lr * lr_decay)
        trainer.optimizers = [new_optimizer]

        if self.print_log:
            if hasattr(pl_module.env.generator, 'capacity'):
                print('\n>> Training task: {}, capacity: {}'.format(
                    pl_module.env.generator.num_loc, pl_module.env.generator.capacity))
            else:
                print('\n>> Training task: {}'.format(pl_module.env.generator.num_loc))

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        # Save the task model
        self.task_models.append(copy.deepcopy(pl_module.state_dict()))
        if (trainer.current_epoch + 1) % self.num_tasks == 0:
            # Outer-loop optimization (update the meta model with the parameters of the task models)
            with torch.no_grad():
                state_dict = {
                    params_key: (
                        self.meta_model_state_dict[params_key]
                        + self.alpha * torch.mean(
                            torch.stack(
                                [fast_weight[params_key] - self.meta_model_state_dict[params_key]
                                 for fast_weight in self.task_models],
                                dim=0,
                            ).float(),
                            dim=0,
                        )
                    )
                    for params_key in self.meta_model_state_dict
                }
                pl_module.load_state_dict(state_dict)

        # Get ready for the next meta-training iteration
        if (trainer.current_epoch + 1) % self.num_tasks == 0:
            # Sample a batch of tasks
            self._sample_task()

        # Load the new training task (update the environment)
        self._load_task(pl_module, task_idx=(trainer.current_epoch + 1) % self.num_tasks)

    def _sample_task(self):
        # Sample a batch of tasks (uniform weights over tasks for now)
        w, self.selected_tasks = [1.0] * self.num_tasks, []
        for b in range(self.num_tasks):
            task_params = random.sample(self.task_set, 1)[0]
            self.selected_tasks.append(task_params)
        self.w = torch.softmax(torch.Tensor(w), dim=0)

    def _load_task(self, pl_module: pl.LightningModule, task_idx=0):
        # Load the new training task (update the environment); task_w is currently unused
        task_params, task_w = self.selected_tasks[task_idx], self.w[task_idx].item()
        pl_module.env.generator.num_loc = task_params[0]
        if hasattr(pl_module.env.generator, 'capacity'):
            task_capacity = math.ceil(30 + task_params[0] / 5) if task_params[0] >= 20 else 20
            pl_module.env.generator.capacity = task_capacity

    def _alpha_scheduler(self):
        self.alpha = max(self.alpha * self.alpha_decay, 0.0001)
```
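The outer-loop update in `on_train_epoch_end` is the standard Reptile rule, applied entry-by-entry over the state dict: move each meta parameter toward the mean of the task-adapted parameters, scaled by alpha. A minimal self-contained sketch of that rule on plain tensors, independent of Lightning (the function name `reptile_outer_update` is ours, not the PR's):

```python
import torch

def reptile_outer_update(meta_state, task_states, alpha):
    """Reptile outer step: meta <- meta + alpha * mean_i(task_i - meta)."""
    with torch.no_grad():
        return {
            k: meta_state[k]
            + alpha * torch.stack([ts[k] - meta_state[k] for ts in task_states]).float().mean(dim=0)
            for k in meta_state
        }

# Toy check: with alpha=1.0 and a single task, the meta parameters
# should land exactly on the task parameters.
meta = {"w": torch.zeros(3)}
tasks = [{"w": torch.ones(3)}]
print(reptile_outer_update(meta, tasks, alpha=1.0)["w"])  # tensor([1., 1., 1.])
```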
Review comment: [Documentation] It's recommended to document the parameters, ideally including data types, constraints, and usage hints. Better for us "non-experts" to understand 😆
Reply (jieyibi): Hi Chuanbo. I have added the documentation as you recommended, along with the generation code for some of the distributions defined in generalization-related works. The meta-learning framework now supports cross-distribution generalization.
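On the cross-distribution support mentioned in the reply: Zhou et al. (2023) meta-train over instance distributions (e.g., uniform vs. clustered node coordinates) as well as sizes. The generation code itself is not shown in this diff, so as a rough illustration only, a clustered-coordinate sampler might look like the following (the function `sample_coords` and its parameters are hypothetical, not from the PR):

```python
import torch

def sample_coords(batch_size, num_loc, dist="uniform", num_modes=3, scale=0.05):
    """Illustrative node-coordinate sampler: uniform or Gaussian-mixture clustered."""
    if dist == "uniform":
        return torch.rand(batch_size, num_loc, 2)
    # "cluster": draw mixture centers, assign each node to a center, add noise
    centers = torch.rand(batch_size, num_modes, 2)
    assign = torch.randint(num_modes, (batch_size, num_loc))
    chosen = torch.gather(centers, 1, assign.unsqueeze(-1).expand(-1, -1, 2))
    return (chosen + scale * torch.randn(batch_size, num_loc, 2)).clamp(0.0, 1.0)
```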