diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py
index 7e70a970913e..5c6b71c74797 100644
--- a/nemo/collections/llm/gpt/data/hf_dataset.py
+++ b/nemo/collections/llm/gpt/data/hf_dataset.py
@@ -15,6 +15,7 @@
 import pytorch_lightning as pl
 import torch
 from torch.utils.data import DataLoader
+from nemo.lightning.pytorch.plugins import MegatronDataSampler
 
 
 class HfDatasetDataModule(pl.LightningDataModule):
@@ -24,6 +25,7 @@ def __init__(
         num_workers=2,
         pin_memory=True,
         persistent_workers=True,
+        seq_length=1024,
         micro_batch_size=2,
         global_batch_size=2,
         pad_token_id=0,
@@ -37,6 +39,7 @@ def __init__(
         self.num_workers = num_workers
         self.pin_memory = pin_memory
         self.persistent_workers = persistent_workers
+        self.seq_length = seq_length
         self.micro_batch_size = micro_batch_size
         self.global_batch_size = global_batch_size
         self.pad_token_id = pad_token_id
@@ -58,6 +61,7 @@ def pad_within_micro(batch, pad_token_id):
             max_len = max(map(len, batch))
             return [item + [pad_token_id] * (max_len - len(item)) for item in batch]
 
+        keys = list(filter(lambda x: x in batch[0], ['tokens', 'labels', 'position_ids', 'loss_mask']))
         return {
             key: batchify(
                 torch.LongTensor(
@@ -67,16 +71,26 @@ def pad_within_micro(batch, pad_token_id):
                     )
                 )
             )
-            for key in ['tokens', 'labels']
+            for key in keys
         }
 
+    def setup(self, stage: str):
+        if not self.use_mcore_sampler:
+            return
+        self.data_sampler = MegatronDataSampler(
+            seq_len=self.seq_length,
+            micro_batch_size=self.micro_batch_size,
+            global_batch_size=self.global_batch_size,
+            dataloader_type=self.mcore_dataloader_type,
+        )
+
     def train_dataloader(self, collate_fn=None):
         from nemo.lightning.data import add_megatron_sampler
 
         if collate_fn is None:
             collate_fn = lambda x: HfDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id)
 
-        dataloader = DataLoader(
+        return DataLoader(
             self.dataset,
             num_workers=self.num_workers,
             pin_memory=self.pin_memory,
@@ -84,20 +98,3 @@ def train_dataloader(self, collate_fn=None):
             collate_fn=collate_fn,
             batch_size=self.micro_batch_size,
         )
-        if not self.use_mcore_sampler:
-            return dataloader
-
-        rank = 0
-        world_size = 1
-        if torch.distributed.is_initialized():
-            rank = torch.distributed.get_rank()
-            world_size = torch.distributed.get_world_size()
-
-        return add_megatron_sampler(
-            dataloader,
-            self.micro_batch_size,
-            self.global_batch_size,
-            dataloader_type=self.mcore_dataloader_type,
-            rank=rank,
-            world_size=world_size,
-        )
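
Usage sketch (not part of the patch): after this change train_dataloader() always returns a plain DataLoader, and Megatron-style sampling is delegated to the MegatronDataSampler that setup() attaches as self.data_sampler. The snippet below assumes the constructor's first positional argument is a tokenized Hugging Face dataset and that use_mcore_sampler / mcore_dataloader_type are constructor arguments (they are referenced as instance attributes, but their definitions fall outside this diff); the dataloader_type value is illustrative.

# Hypothetical driver code; names not shown in the diff are assumptions.
from nemo.collections.llm.gpt.data.hf_dataset import HfDatasetDataModule

datamodule = HfDatasetDataModule(
    tokenized_dataset,               # assumed: dataset with 'tokens'/'labels' columns;
                                     # 'position_ids'/'loss_mask' are now padded too if present
    seq_length=1024,                 # new argument, forwarded to MegatronDataSampler as seq_len
    micro_batch_size=2,
    global_batch_size=2,
    pad_token_id=0,
    use_mcore_sampler=True,          # assumed constructor arg backing self.use_mcore_sampler
    mcore_dataloader_type="single",  # assumed constructor arg; value is illustrative
)
datamodule.setup(stage="fit")        # builds self.data_sampler only when use_mcore_sampler is True
train_loader = datamodule.train_dataloader()  # plain DataLoader; batching via HfDatasetDataModule.collate_fn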