
Add meta learning framework #183

Merged
merged 10 commits into ai4co:main on Jun 20, 2024

Conversation

jieyibi
Contributor

@jieyibi jieyibi commented May 27, 2024

Description

Adds a new training framework based on meta learning; see Zhou et al. (2023) for details.

Motivation and Context

To address the generalization issue.

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

Member

@fedebotu fedebotu left a comment


Great! 🚀

Left some comments here and there~

def _alpha_scheduler(self):
    self.alpha = max(self.alpha * self.alpha_decay, 0.0001)
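For context, the scheduler above is an exponential decay with a floor. Its behavior can be illustrated in isolation (a standalone sketch with a hypothetical function name, not part of the PR's code):

```python
def alpha_schedule(alpha0, decay, epochs, floor=1e-4):
    """Exponential decay of the Reptile step size, clipped at `floor`
    (mirrors the max(alpha * alpha_decay, 0.0001) line above)."""
    alphas, alpha = [], alpha0
    for _ in range(epochs):
        alpha = max(alpha * decay, floor)
        alphas.append(alpha)
    return alphas

print(alpha_schedule(1.0, 0.5, 3))  # [0.5, 0.25, 0.125]
```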

class RL4COMetaTrainer(Trainer):
Member

Are there any differences compared to the RL4COTrainer? I could not find any at a first glance. I guess the only difference is that we pass the MetaModelCallback?

log = utils.get_pylogger(__name__)


class MetaModelCallback(Callback):
Member

I haven't thought of Meta-learning as Lightning Callbacks, but it looks neat! :D

Member

Tagging @Junyoungpark for the good ol' Reptile

def __init__(self, meta_params, print_log=True):
    super().__init__()
    self.meta_params = meta_params
    assert meta_params["meta_method"] == 'reptile', NotImplementedError
Member

Another general comment: it seems that this callback is only for Reptile, so we should consider calling it ReptileCallback. I think it would be cool to have these as say a meta_learning/ folder
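For readers new to Reptile: its outer loop is simply an interpolation of the meta-parameters toward the task-adapted weights, theta <- theta + alpha * (phi - theta). A minimal plain-float sketch (hypothetical names; the actual callback would operate on the Lightning module's state dict):

```python
def reptile_outer_update(theta, phi, alpha):
    """One Reptile meta-step: move meta-parameters `theta` a fraction
    `alpha` toward the task-adapted parameters `phi`."""
    return {k: theta[k] + alpha * (phi[k] - theta[k]) for k in theta}

theta = {"w": 0.0, "b": 1.0}                  # meta-parameters
phi = {"w": 1.0, "b": 3.0}                    # after inner-loop SGD on one task
print(reptile_outer_update(theta, phi, 0.5))  # {'w': 0.5, 'b': 2.0}
```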

Contributor Author

Yes. ReptileCallback is better. Where should this new meta_learning/ folder be located? In utils/?

Member

I'd say under models/rl (which includes LightningModules), since meta_learning is a way to optimize a policy.

self._sample_task()
self.selected_tasks[0] = (pl_module.env.generator.num_loc, )

def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
Member

[Minor] type hints should be classes, not strings. For example Trainer and LightningModule (Maybe even better: RL4COTrainer and RL4COLitModule since everything is inherited from that)
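A sketch of the suggested change, with stub classes standing in for the real Lightning ones: importing the classes directly gives real annotations, and `from __future__ import annotations` makes all hints lazily evaluated, so string literals are not needed even when avoiding import cycles:

```python
from __future__ import annotations  # all annotations become lazy strings

class Trainer: ...            # stand-ins for the real Lightning classes
class LightningModule: ...

class ReptileCallback:
    # class-based hints instead of "pl.Trainer"-style string literals
    def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        pass
```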

Contributor Author

@jieyibi jieyibi May 28, 2024

Hi. I tried refactoring ReptileCallback to inherit from REINFORCE and RL4COLitModule. But in that case I would need to add a new meta_model.py, inheriting from the model.py under the zoo/pomo/ or zoo/am/ folder, just to call Reptile from outside, which is a little redundant... A Lightning Callback may be more generic: we could apply it to every model and every policy.
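The point about genericity can be sketched as follows: a callback touches only the standard trainer hooks, so any model works unchanged (stub classes here, and a plain dict standing in for the module's parameters; hook bodies are illustrative, not the PR's implementation):

```python
class Callback:            # stand-in for lightning's Callback base class
    pass

class StubModule:          # stand-in for any RL4CO LightningModule (POMO, AM, ...)
    def __init__(self):
        self.params = {"w": 0.0}

class ReptileCallback(Callback):
    """Model-agnostic: relies only on standard trainer hooks, so no
    per-model meta_model.py is needed."""
    def __init__(self, alpha=0.5):
        self.alpha = alpha
        self.snapshot = None

    def on_train_epoch_start(self, trainer, pl_module):
        # snapshot meta-parameters before the inner loop adapts to a task
        self.snapshot = dict(pl_module.params)

    def on_train_epoch_end(self, trainer, pl_module):
        # Reptile outer step: interpolate from the snapshot toward the
        # task-adapted parameters
        for k in pl_module.params:
            pl_module.params[k] = self.snapshot[k] + self.alpha * (
                pl_module.params[k] - self.snapshot[k]
            )

cb = ReptileCallback(alpha=0.5)
m = StubModule()
cb.on_train_epoch_start(None, m)
m.params["w"] = 1.0        # pretend the inner loop adapted the weight
cb.on_train_epoch_end(None, m)
# m.params["w"] is now 0.5
```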

Member

@cbhua cbhua left a comment

Great job! 🚀 Really nice to have meta learning support 😁

@@ -0,0 +1,87 @@
import pytz
Member

[Minor] Nice tool, but pytz is not included in RL4CO's dependency packages. It may be better to have a package check here:

try:
    import pytz
except ImportError:
    # warn and fall back to the standard datetime module
    import warnings
    warnings.warn("pytz is not installed; falling back to the standard datetime module")
    pytz = None

Contributor Author

Yeah, maybe it could be removed.

super().__init__()
self.meta_params = meta_params
assert meta_params["meta_method"] == 'reptile', NotImplementedError
assert meta_params["data_type"] == 'size', NotImplementedError
Member

It seems that this parameter data_type is not called anywhere? 🤔

Member

@fedebotu fedebotu left a comment

Great!
Also:

  1. Could you reproduce the learning curves of the original model?
  2. Could you make a simple test like this one for your model so that it can be checked automatically?


# Meta training framework for addressing the generalization issue
# Based on Zhou et al. (2023): https://arxiv.org/abs/2305.19587
def __init__(self, meta_params, print_log=True):
Member

[Minor] I think it would be clearer to list the hyperparameters explicitly.

For example:

def __init__(self,
             alpha=...,
             alpha_decay=...,
             ...,
             print_log=True):

which is generally easier to maintain and to document.

@fedebotu
Member

It seems there was a mistake here - I noticed that the "tasks" are a bit hardcoded, i.e. the "size" and "capacity" (I guess for TSP and CVRP, right?)


Speaking of testing: you can make sure things work on your device before committing by using pytest tests

Comment on lines 13 to 26
class ReptileCallback(Callback):

    # Meta training framework for addressing the generalization issue
    # Based on Zhou et al. (2023): https://arxiv.org/abs/2305.19587
    def __init__(self,
                 num_tasks,
                 alpha,
                 alpha_decay,
                 min_size,
                 max_size,
                 sch_bar=0.9,
                 data_type="size",
                 print_log=True):
        super().__init__()
Member

[Documentation] It's recommended to have a docstring with the parameters, possibly including the data types, constraints, hints, etc. Better for us "non-experts" to understand 😆

class ReptileCallback(Callback):
    """Meta training framework for addressing the generalization issue
    Based on Zhou et al. (2023): https://arxiv.org/abs/2305.19587

    Args:
        - num_tasks: number of task types, i.e. `B` in the original paper
        - alpha: ...
        - ...
    """

    def __init__(
        self,
        num_tasks: int,
        alpha: float,
        alpha_decay: float,
        min_size: int,
        max_size: int,
        sch_bar: float = 0.9,
        data_type: str = "size",
        print_log: bool = True,
    ):
        super().__init__()

Contributor Author

Hi Chuanbo. I have added the documentation as you recommended, along with the generation code for some distributions defined in the generalization-related works. Now the meta learning framework is supported for cross-distribution generalization.

@@ -0,0 +1,184 @@
import torch

class Cluster():
Member

I like this! @cbhua I think we should train some model with say TSP50 / CVRP50 with a mixed distribution and test its generalization performance

Minor comment: Shouldn't this be a subclass of torch.distributions.distribution.Distribution?

Member

Yes, this class is actually what we are missing among the distributions! It could be used by various environments' generators.

About the experiment, we want to test the distribution generalization ability, right?

examples/2d-meta_train.py — outdated review thread, resolved
@fedebotu
Member

fedebotu commented Jun 8, 2024

Routine check, how is progress going? I think the multi-distribution generators in particular should be included, since they are part of MDPOMO, right?

@jieyibi
Contributor Author

jieyibi commented Jun 11, 2024

The updated commit supports training on multiple mixed distributions by changing the argument loc_distribution to "mix_distribution" in the generator_params of the environment, e.g.

env = CVRPEnv(generator_params={'loc_distribution': "mix_distribution"})

The rest of the arguments remain the same.

Note that this mixed-distribution setting follows the setting in Bi et al. (2022).

@fedebotu
Member

Great! Should the code be merged?

Change some parameters for performance
@jieyibi
Contributor Author

jieyibi commented Jun 19, 2024

Yes, I think it is ready to be merged. The newly reproduced performance (after changing some key parameters) appears similar to that reported in the literature.

@fedebotu fedebotu added this to the 0.5.0 milestone Jun 19, 2024
@fedebotu
Member

Great!

How about the generalization experiments, i.e., MDPOMO? Do you have yaml configuration for them like this, or a training script?

@jieyibi
Contributor Author

jieyibi commented Jun 19, 2024

Yeah, already added to the main branch:)

@fedebotu
Member

Awesome! Then we can go ahead and merge :)

@fedebotu fedebotu merged commit 911a754 into ai4co:main Jun 20, 2024
9 checks passed