Add meta learning framework #183

Merged · 10 commits · Jun 20, 2024
80 changes: 80 additions & 0 deletions examples/2d-meta_train.py
@@ -0,0 +1,80 @@
from lightning.pytorch.callbacks import ModelCheckpoint, RichModelSummary
from lightning.pytorch.loggers import WandbLogger

from rl4co.envs import CVRPEnv
from rl4co.models.zoo.am import AttentionModelPolicy
from rl4co.models.zoo.pomo import POMO
from rl4co.utils.trainer import RL4COTrainer
from rl4co.utils.meta_trainer import ReptileCallback

def main():
    # Set device
    device_id = 0

    # RL4CO env based on TorchRL
    env = CVRPEnv(generator_params={'num_loc': 50})

    # Policy: neural network, in this case with encoder-decoder architecture
    # Note: this is the same policy configuration used by POMO in the original paper
    policy = AttentionModelPolicy(
        env_name=env.name,
        embed_dim=128,
        num_encoder_layers=6,
        num_heads=8,
        normalization="instance",
        use_graph_context=False,
    )

    # RL Model (POMO)
    model = POMO(
        env,
        policy,
        batch_size=64,  # meta_batch_size
        train_data_size=64 * 50,  # equals (meta_batch_size) * (gradient descent steps in the inner-loop optimization of the meta-learning method)
        val_data_size=0,
        optimizer_kwargs={"lr": 1e-4, "weight_decay": 1e-6},
    )

    # Example callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath="meta_pomo/checkpoints",  # save to checkpoints/
        filename="epoch_{epoch:03d}",  # save as epoch_XXX.ckpt
        save_top_k=1,  # save only the best model
        save_last=True,  # save the last model
        monitor="val/reward",  # monitor validation reward
        mode="max",  # maximize validation reward
    )
    rich_model_summary = RichModelSummary(max_depth=3)  # model summary callback

    # Meta callbacks
    meta_callback = ReptileCallback(
        num_tasks=1,  # number of tasks in a mini-batch, i.e. `B` in the original paper
        alpha=0.9,  # initial weight of the task model in the outer-loop optimization of Reptile
        alpha_decay=1,  # weight decay of the task model in the outer-loop optimization of Reptile; no decay performs better
        min_size=20,  # minimum problem size sampled for meta tasks (only used for cross-size generalization)
        max_size=150,  # maximum problem size sampled for meta tasks (only used for cross-size generalization)
        data_type="size_distribution",  # choose from ["size", "distribution", "size_distribution"]
        sch_bar=0.9,  # task scheduler for the size setting: lr_decay_epoch = sch_bar * epochs; after this epoch, the learning rate decays by a factor of 0.1
        print_log=True,  # whether to print the sampled tasks in each meta iteration
    )
    callbacks = [meta_callback, checkpoint_callback, rich_model_summary]

    # Logger
    logger = WandbLogger(project="rl4co", name=f"{env.name}_pomo_reptile")
    # logger = None  # uncomment this line if you don't want logging

    # Adjust your trainer to the number of epochs you want to run
    trainer = RL4COTrainer(
        max_epochs=15000,  # (number of meta-model updates) * (number of tasks in a mini-batch)
        callbacks=callbacks,
        accelerator="gpu",
        devices=[device_id],
        logger=logger,
        limit_train_batches=50,  # gradient descent steps in the inner-loop optimization of the meta-learning method
    )

    # Fit
    trainer.fit(model)


if __name__ == "__main__":
    main()
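For reference, here is a minimal, hypothetical sketch (not part of this PR) of how a checkpoint produced by this script might be evaluated on larger instances to probe cross-size generalization. It assumes Lightning's `load_from_checkpoint` (with the hyperparameters stored in the checkpoint), the rl4co rollout call `policy(td, env, phase="test", decode_type="greedy")`, and an illustrative checkpoint path derived from the `ModelCheckpoint` settings above.

# Hypothetical evaluation sketch: cross-size generalization of a meta-trained checkpoint.
# Assumptions: Lightning's load_from_checkpoint restores the model from saved hparams,
# and the policy rollout interface below returns a dict containing "reward".
import torch

from rl4co.envs import CVRPEnv
from rl4co.models.zoo.pomo import POMO

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Larger instances than the num_loc=50 used for meta-training
env = CVRPEnv(generator_params={'num_loc': 100})

# Load the meta-trained model (path is illustrative; save_last=True writes last.ckpt)
model = POMO.load_from_checkpoint("meta_pomo/checkpoints/last.ckpt")
policy = model.policy.to(device).eval()

# Greedy rollout on a small batch of fresh instances
td = env.reset(batch_size=[16]).to(device)
with torch.no_grad():
    out = policy(td, env, phase="test", decode_type="greedy")
print("Mean reward:", out["reward"].mean().item())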

254 changes: 254 additions & 0 deletions rl4co/envs/common/distribution_utils.py
@@ -0,0 +1,254 @@
import random
import torch

class Cluster():
Review comment (Member):
I like this! @cbhua I think we should train some model with, say, TSP50 / CVRP50 with a mixed distribution and test its generalization performance.

Minor comment: shouldn't this be a subclass of torch.distributions.distribution.Distribution?

Reply (Member):
Yes, actually this class is what we were missing in the distributions! It could be used by the generators of various environments.

About the experiment, we want to test the distribution generalization ability, right?

"""
Multiple gaussian distributed clusters, as in the Solomon benchmark dataset
Following the setting in Bi et al. 2022 (https://arxiv.org/abs/2210.07686)

Args:
n_cluster: Number of the gaussian distributed clusters
"""
def __init__(self, n_cluster: int = 3):
super().__init__()
self.lower, self.upper = 0.2, 0.8
self.std = 0.07
self.n_cluster = n_cluster
def sample(self, size):

batch_size, num_loc, _ = size

# Generate the centers of the clusters
center = self.lower + (self.upper - self.lower) * torch.rand(batch_size, self.n_cluster * 2)

# Pre-define the coordinates
coords = torch.zeros(batch_size, num_loc, 2)

# Calculate the size of each cluster
cluster_sizes = [num_loc // self.n_cluster] * self.n_cluster
for i in range(num_loc % self.n_cluster):
cluster_sizes[i] += 1

# Generate the coordinates
current_index = 0
for i in range(self.n_cluster):
means = center[:, i * 2:(i + 1) * 2]
stds = torch.full((batch_size, 2), self.std)
points = torch.normal(means.unsqueeze(1).expand(-1, cluster_sizes[i], -1),
stds.unsqueeze(1).expand(-1, cluster_sizes[i], -1))
coords[:, current_index:current_index + cluster_sizes[i], :] = points
current_index += cluster_sizes[i]

# Confine the coordinates to range [0, 1]
coords.clamp_(0, 1)

return coords

class Mixed():
    """
    50% of the nodes are sampled from a uniform distribution and 50% from a Gaussian distribution, as in the Solomon benchmark dataset.
    Following the setting in Bi et al. 2022 (https://arxiv.org/abs/2210.07686)

    Args:
        n_cluster_mix: Number of the Gaussian distributed clusters
    """

    def __init__(self, n_cluster_mix=1):
        super().__init__()
        self.lower, self.upper = 0.2, 0.8
        self.std = 0.07
        self.n_cluster_mix = n_cluster_mix

    def sample(self, size):
        batch_size, num_loc, _ = size

        # Generate the centers of the clusters
        center = self.lower + (self.upper - self.lower) * torch.rand(batch_size, self.n_cluster_mix * 2)

        # Pre-define the coordinates sampled under the uniform distribution
        coords = torch.FloatTensor(batch_size, num_loc, 2).uniform_(0, 1)

        # Sample mutated indices (default setting: 50% mutation)
        mutate_idx = torch.stack([torch.randperm(num_loc)[:num_loc // 2] for _ in range(batch_size)])

        # Generate the coordinates
        segment_size = num_loc // (2 * self.n_cluster_mix)
        remaining_indices = num_loc // 2 - segment_size * (self.n_cluster_mix - 1)
        sizes = [segment_size] * (self.n_cluster_mix - 1) + [remaining_indices]
        for i in range(self.n_cluster_mix):
            indices = mutate_idx[:, sum(sizes[:i]):sum(sizes[:i + 1])]
            means_x = center[:, 2 * i].unsqueeze(1).expand(-1, sizes[i])
            means_y = center[:, 2 * i + 1].unsqueeze(1).expand(-1, sizes[i])
            coords.scatter_(
                1,
                indices.unsqueeze(-1).expand(-1, -1, 2),
                torch.stack([
                    torch.normal(means_x.expand(-1, sizes[i]), self.std),
                    torch.normal(means_y.expand(-1, sizes[i]), self.std)
                ], dim=2),
            )

        # Confine the coordinates to range [0, 1]
        coords.clamp_(0, 1)

        return coords

class Gaussian_Mixture():
    '''
    Following Zhou et al. (2023): https://arxiv.org/abs/2305.19587

    Args:
        num_modes: the number of clusters/modes in the Gaussian Mixture.
        cdist: scale of the uniform distribution for center generation.
    '''

    def __init__(self, num_modes: int = 0, cdist: int = 0):
        super().__init__()
        self.num_modes = num_modes
        self.cdist = cdist

    def sample(self, size):
        batch_size, num_loc, _ = size

        if self.num_modes == 0:  # (0, 0) - uniform
            return torch.rand((batch_size, num_loc, 2))
        elif self.num_modes == 1 and self.cdist == 1:  # (1, 1) - gaussian
            return self.generate_gaussian(batch_size, num_loc)
        else:
            res = [self.generate_gaussian_mixture(num_loc) for _ in range(batch_size)]
            return torch.stack(res)

    def generate_gaussian_mixture(self, num_loc):
        """Following the setting in Zhang et al. 2022 (https://arxiv.org/abs/2204.03236)"""

        # Randomly decide how many points each mode gets
        nums = torch.multinomial(
            input=torch.ones(self.num_modes) / self.num_modes, num_samples=num_loc, replacement=True
        )

        # Prepare to collect points
        coords = torch.empty((0, 2))

        # Generate points for each mode
        for i in range(self.num_modes):
            num = (nums == i).sum()  # Number of points in this mode
            if num > 0:
                center = torch.rand((1, 2)) * self.cdist
                cov = torch.eye(2)  # Covariance matrix
                nxy = torch.distributions.MultivariateNormal(center.squeeze(), covariance_matrix=cov).sample((num,))
                coords = torch.cat((coords, nxy), dim=0)

        return self._global_min_max_scaling(coords)

    def generate_gaussian(self, batch_size, num_loc):
        """Following the setting in Xin et al. 2022 (https://openreview.net/pdf?id=nJuzV-izmPJ)"""

        # Mean and random covariances
        mean = torch.full((batch_size, num_loc, 2), 0.5)
        covs = torch.rand(batch_size)  # Random covariances between 0 and 1

        # Generate the coordinates
        coords = torch.zeros((batch_size, num_loc, 2))
        for i in range(batch_size):
            # Construct the covariance matrix for each sample
            cov_matrix = torch.tensor([[1.0, covs[i]], [covs[i], 1.0]])
            m = torch.distributions.MultivariateNormal(mean[i], covariance_matrix=cov_matrix)
            coords[i] = m.sample()

        # Shuffle along the batch dimension
        indices = torch.randperm(coords.size(0))
        coords = coords[indices]

        return self._batch_normalize_and_center(coords)

    def _global_min_max_scaling(self, coords):
        # Scale the points to [0, 1] using min-max scaling
        coords_min = coords.min(0, keepdim=True).values
        coords_max = coords.max(0, keepdim=True).values
        coords = (coords - coords_min) / (coords_max - coords_min)

        return coords

    def _batch_normalize_and_center(self, coords):
        # Step 1: Compute min and max along each batch
        coords_min = coords.min(dim=1, keepdim=True).values
        coords_max = coords.max(dim=1, keepdim=True).values

        # Step 2: Normalize coordinates to range [0, 1]
        coords = coords - coords_min  # Broadcasting subtracts the min value from each coordinate
        range_max = (coords_max - coords_min).max(dim=-1, keepdim=True).values  # Maximum range over both coordinates
        coords = coords / range_max  # Divide by the max range to normalize

        # Step 3: Center the batch in the middle of the [0, 1] range
        coords = coords + (1 - coords.max(dim=1, keepdim=True).values) / 2  # Center the batch

        return coords

class Mix_Distribution():
    '''
    Batch-level mixture of three exemplar distributions, i.e. Uniform, Cluster, and Mixed.
    Following the setting in Bi et al. 2022 (https://arxiv.org/abs/2210.07686)

    Args:
        n_cluster: Number of the Gaussian distributed clusters in the Cluster distribution
        n_cluster_mix: Number of the Gaussian distributed clusters in the Mixed distribution
    '''

    def __init__(self, n_cluster=3, n_cluster_mix=1):
        super().__init__()
        self.lower, self.upper = 0.2, 0.8
        self.std = 0.07
        self.Mixed = Mixed(n_cluster_mix=n_cluster_mix)
        self.Cluster = Cluster(n_cluster=n_cluster)

    def sample(self, size):
        batch_size, num_loc, _ = size

        # Pre-define the coordinates sampled under the uniform distribution
        coords = torch.FloatTensor(batch_size, num_loc, 2).uniform_(0, 1)

        # Randomly sample a probability that decides the distribution of each instance
        p = torch.rand(batch_size)

        # Mixed
        mask = p <= 0.33
        n_mixed = mask.sum().item()
        if n_mixed > 0:
            coords[mask] = self.Mixed.sample((n_mixed, num_loc, 2))

        # Cluster
        mask = (p > 0.33) & (p <= 0.66)
        n_cluster = mask.sum().item()
        if n_cluster > 0:
            coords[mask] = self.Cluster.sample((n_cluster, num_loc, 2))

        # The remaining instances are uniformly distributed
        return coords

class Mix_Multi_Distributions():
    '''
    Batch-level mixture of 11 Gaussian-like distributions.
    Following the setting in Zhou et al. (2023): https://arxiv.org/abs/2305.19587
    '''

    def __init__(self):
        super().__init__()
        self.dist_set = [(0, 0), (1, 1)] + [(m, c) for m in [3, 5, 7] for c in [10, 30, 50]]

    def sample(self, size):
        batch_size, num_loc, _ = size
        coords = torch.zeros(batch_size, num_loc, 2)

        # Pre-select a distribution for each instance in the batch
        dists = [random.choice(self.dist_set) for _ in range(batch_size)]
        unique_dists = list(set(dists))  # Unique distributions, to minimize re-instantiation

        # Instantiate Gaussian_Mixture only once per unique distribution
        gm_instances = {dist: Gaussian_Mixture(*dist) for dist in unique_dists}

        # Sample each instance from its assigned distribution
        for i, dist in enumerate(dists):
            coords[i] = gm_instances[dist].sample((1, num_loc, 2)).squeeze(0)

        return coords
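As a quick sanity check (illustrative only, based on the `sample((batch_size, num_loc, 2))` interface defined above), the new samplers can be exercised standalone; each returns a `[batch_size, num_loc, 2]` tensor of coordinates in `[0, 1]`.

# Standalone sanity-check sketch for the new distribution classes (not part of the PR)
from rl4co.envs.common.distribution_utils import Cluster, Gaussian_Mixture, Mix_Distribution

coords_cluster = Cluster(n_cluster=3).sample((4, 50, 2))                    # clustered instances
coords_gm = Gaussian_Mixture(num_modes=5, cdist=30).sample((4, 50, 2))      # Gaussian mixture instances
coords_mix = Mix_Distribution(n_cluster=3, n_cluster_mix=1).sample((4, 50, 2))  # batch-level mixture

for name, c in [("cluster", coords_cluster), ("gaussian_mixture", coords_gm), ("mix", coords_mix)]:
    assert c.shape == (4, 50, 2)
    print(name, float(c.min()), float(c.max()))  # values should lie in [0, 1]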
13 changes: 12 additions & 1 deletion rl4co/envs/common/utils.py
@@ -6,7 +6,7 @@

from tensordict.tensordict import TensorDict
from torch.distributions import Exponential, Normal, Poisson, Uniform

from rl4co.envs.common.distribution_utils import Cluster, Mixed, Gaussian_Mixture, Mix_Distribution, Mix_Multi_Distributions

class Generator(metaclass=abc.ABCMeta):
    """Base data generator class, to be called with `env.generator(batch_size)`"""
@@ -76,6 +76,16 @@ def get_sampler(
        ) # todo: should be also `low, high` and any other corner
    elif isinstance(distribution, Callable):
        return distribution(**kwargs)
    elif distribution == "gaussian_mixture":
        return Gaussian_Mixture(num_modes=kwargs['num_modes'], cdist=kwargs['cdist'])
    elif distribution == "cluster":
        return Cluster(kwargs['n_cluster'])
    elif distribution == "mixed":
        return Mixed(kwargs['n_cluster_mix'])
    elif distribution == "mix_distribution":
        return Mix_Distribution(kwargs['n_cluster'], kwargs['n_cluster_mix'])
    elif distribution == "mix_multi_distributions":
        return Mix_Multi_Distributions()
    else:
        raise ValueError(f"Invalid distribution type of {distribution}")

@@ -87,3 +97,4 @@ def batch_to_scalar(param):
    if isinstance(param, torch.Tensor):
        return param.item()
    return param
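Assuming the environment generators forward a `loc_distribution` argument and any extra keyword arguments down to `get_sampler` (an assumption about the generator API, not shown in this diff), the new distributions could then be selected by name through `generator_params`, for example:

# Illustrative usage sketch (assumes the generator forwards loc_distribution and
# extra kwargs such as n_cluster / num_modes / cdist to get_sampler)
from rl4co.envs import CVRPEnv

env_cluster = CVRPEnv(generator_params={"num_loc": 50, "loc_distribution": "cluster", "n_cluster": 3})
env_gm = CVRPEnv(generator_params={"num_loc": 50, "loc_distribution": "gaussian_mixture", "num_modes": 5, "cdist": 30})

td = env_cluster.reset(batch_size=[2])  # node locations now follow the clustered distribution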
