[Feat,BugFix] avoid shuffling train data by default; do not evaluate rollout on very last epoch
fedebotu committed Dec 7, 2023
1 parent 29be07f commit 14d072e
Showing 1 changed file with 9 additions and 5 deletions.
rl4co/models/rl/common/base.py: 14 changes (9 additions & 5 deletions)
@@ -36,7 +36,7 @@ class RL4COLitModule(LightningModule):
         lr_scheduler_interval: learning rate scheduler interval
         lr_scheduler_monitor: learning rate scheduler monitor
         generate_default_data: whether to generate default datasets, filling up the data directory
-        shuffle_train_dataloader: whether to shuffle training dataloader
+        shuffle_train_dataloader: whether to shuffle training dataloader. Default is False since we recreate dataset every epoch
         dataloader_num_workers: number of workers for dataloader
         data_dir: data directory
         metrics: metrics
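
Rationale for flipping the default: the training dataset is regenerated from the environment at the end of every epoch (see the on_train_epoch_end hunk below), so shuffling the dataloader on top of that adds no extra randomness. A minimal sketch of the idea with a plain PyTorch DataLoader; make_train_dataloader and the tensor shape are illustrative assumptions, not rl4co code:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    def make_train_dataloader(dataset, batch_size=512, shuffle=False):
        # With a freshly sampled dataset each epoch, shuffle=False loses nothing
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    # Stand-in for env.dataset(train_data_size, "train"): new random instances every epoch
    dataset = TensorDataset(torch.rand(100_000, 50, 2))
    loader = make_train_dataloader(dataset)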
@@ -50,7 +50,7 @@ def __init__(
         batch_size: int = 512,
         val_batch_size: int = None,
         test_batch_size: int = None,
-        train_data_size: int = 1_280_000,
+        train_data_size: int = 100_000,
         val_data_size: int = 10_000,
         test_data_size: int = 10_000,
         optimizer: Union[str, torch.optim.Optimizer, partial] = "Adam",
@@ -63,7 +63,7 @@ def __init__(
         lr_scheduler_interval: str = "epoch",
         lr_scheduler_monitor: str = "val/reward",
         generate_default_data: bool = False,
-        shuffle_train_dataloader: bool = True,
+        shuffle_train_dataloader: bool = False,
         dataloader_num_workers: int = 0,
         data_dir: str = "data/",
         log_on_step: bool = True,
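
Users who relied on the previous defaults can restore them explicitly. A hedged usage sketch, assuming the AttentionModel constructor forwards these keyword arguments down to RL4COLitModule (import paths may differ across rl4co versions):

    from rl4co.envs import TSPEnv
    from rl4co.models.zoo.am import AttentionModel

    env = TSPEnv()
    model = AttentionModel(
        env,
        train_data_size=1_280_000,      # previous default
        shuffle_train_dataloader=True,  # previous default
    )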
@@ -278,8 +278,12 @@ def on_train_epoch_end(self):
         """Called at the end of the training epoch. This can be used for instance to update the train dataset
         with new data (which is the case in RL).
         """
-        train_dataset = self.env.dataset(self.data_cfg["train_data_size"], "train")
-        self.train_dataset = self.wrap_dataset(train_dataset)
+        # Only update if not in the last epoch
+        # If last epoch, we don't need to update since we will not use the dataset anymore
+        if self.current_epoch < self.trainer.max_epochs - 1:
+            log.info("Generating training dataset for next epoch...")
+            train_dataset = self.env.dataset(self.data_cfg["train_data_size"], "train")
+            self.train_dataset = self.wrap_dataset(train_dataset)

     def wrap_dataset(self, dataset):
         """Wrap dataset with policy-specific wrapper. This is useful i.e. in REINFORCE where we need to
