Add "offline" data cache generation support #9576

Merged · 7 commits · Jul 23, 2024
4 changes: 3 additions & 1 deletion examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
@@ -276,6 +276,7 @@ model:
seq_length: ${model.encoder_seq_length}
skip_warmup: True
num_workers: 2
+num_dataset_builder_threads: 1
dataloader_type: single # cyclic
reset_position_ids: False # Reset position ids after end-of-document token
reset_attention_mask: False # Reset attention mask after end-of-document token
@@ -284,7 +285,8 @@ model:
no_seqlen_plus_one_input_tokens: False # Set to True to disable fetching (sequence length + 1) input tokens, instead get (sequence length) input tokens and mask the last token
pad_samples_to_global_batch_size: False # Set to True if you want to pad the last partial batch with -1's to equal global batch size
shuffle_documents: True # Set to False to disable documents shuffling. Sample index will still be shuffled
-exchange_indices_distributed: False # Set to True to exchange indices via torch.distributed instead of filesystem
+exchange_indices_distributed: False # Set to True to exchange indices via torch.distributed instead of filesystem
+data_cache_generation_only: False # Set to True to generate only the data cache and stop the training script

# Nsys profiling options
nsys_profile:
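For readers who want to try the two new options, here is a minimal, hypothetical sketch (not part of the PR) that toggles them on the example config via OmegaConf, which NeMo configs are built on; the values are illustrative:

```python
# Hypothetical usage sketch, assuming the layout shown above
# (data options live under model.data in megatron_gpt_config.yaml).
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/nlp/language_modeling/conf/megatron_gpt_config.yaml")

# Build the dataset index caches, then stop before any training step runs.
cfg.model.data.data_cache_generation_only = True
# Illustrative value: more builder threads can speed up index construction.
cfg.model.data.num_dataset_builder_threads = 8

print(OmegaConf.to_yaml(cfg.model.data))
```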
@@ -1549,6 +1549,7 @@ def build_train_valid_test_datasets(self):
"create_attention_mask": not self.get_attention_mask_from_fusion,
"mmap_bin_files": self.cfg.data.get("mmap_bin_files", True),
"drop_last_partial_validation_sequence": self.cfg.data.get("validation_drop_last", True),
"num_dataset_builder_threads": self.cfg.data.get("num_dataset_builder_threads", 1),
"add_extra_token_to_sequence": add_extra_token,
}

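Each entry above is read with a fallback default, so configs that predate this PR keep working unchanged. A minimal sketch of that lookup pattern (stand-in config, illustrative values):

```python
# Sketch of the cfg.data.get(key, default) pattern used above: when the new
# key is absent, the old single-threaded behavior is preserved.
from omegaconf import OmegaConf

data_cfg = OmegaConf.create({"num_workers": 2})  # new key not set
assert data_cfg.get("num_dataset_builder_threads", 1) == 1

data_cfg = OmegaConf.create({"num_dataset_builder_threads": 8})
assert data_cfg.get("num_dataset_builder_threads", 1) == 8
```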
@@ -1683,6 +1684,12 @@ def setup(self, stage=None):
# Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step
self._reconfigure_limit_batches(self.trainer.limit_val_batches, self._validation_dl, 'val')

+# Data cache generation only
+# Stops script execution after creating a data cache
+if self.cfg.data.get('data_cache_generation_only', False):
+    self.trainer.num_sanity_val_steps = 0
+    self.trainer.should_stop = True

if stage == 'fit':
    self.initialize_last_rank_embeddings()

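To see why those two lines are enough to make a "cache generation only" run exit cleanly: PyTorch Lightning's fit loop checks trainer.should_stop before running, and num_sanity_val_steps=0 skips the pre-training sanity validation pass. A self-contained sketch of the mechanism (TinyModel and its cache_only flag are made up for the demo, not NeMo code):

```python
# Standalone demo of the early-stop mechanism used in the PR: setup() sets
# trainer.should_stop, so fit() returns without running a training step.
import torch
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    def __init__(self, cache_only: bool = False):
        super().__init__()
        self.cache_only = cache_only
        self.layer = torch.nn.Linear(4, 1)

    def setup(self, stage=None):
        # In the PR, the datasets (and their index caches) are built here,
        # just before these two lines run.
        if self.cache_only:
            self.trainer.num_sanity_val_steps = 0  # skip sanity validation
            self.trainer.should_stop = True        # stop the fit loop early

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def train_dataloader(self):
        return torch.utils.data.DataLoader(torch.randn(8, 4), batch_size=4)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
trainer.fit(TinyModel(cache_only=True))  # returns without running any steps
```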