diff --git a/nemo/collections/llm/recipes/finetune_default.py b/nemo/collections/llm/recipes/finetune_default.py new file mode 100644 index 000000000000..28db1e187056 --- /dev/null +++ b/nemo/collections/llm/recipes/finetune_default.py @@ -0,0 +1,119 @@ +from typing import Optional + +import nemo_run as run +import pytorch_lightning as pl + +import nemo.lightning as nl +from nemo.collections import llm +from nemo.collections.llm.recipes.log.default import tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed + + +def default_finetune_recipe( + model: run.Config[pl.LightningModule], + resume_path: str, + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, +) -> run.Partial: + """ + Create a default fine-tuning recipe for any model. + + This function sets up a template for a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + model (run.Config[pl.LightningModule]): Configuration for a NeMo model. + resume_path (str): The Huggingface model ID (e.g. meta-llama/Meta-Llama-3-8B); fine-tuning resumes from the corresponding NeMo checkpoint, converted beforehand with llm.import_ckpt. + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + See usages of this recipe for further details. + """ + recipe = run.Partial( + llm.finetune, + model=model, + trainer=default_finetune_trainer( + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + ), + data=run.Config(llm.SquadDataModule, seq_length=2048, global_batch_size=128, micro_batch_size=1), + log=llm.default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=1e-4, min_lr=0, warmup_steps=50), + resume=nemo_resume(resume_path), + ) + + return recipe + + +def default_finetune_trainer( + tensor_parallelism=1, + pipeline_parallelism=1, + pipeline_parallelism_type=None, + virtual_pipeline_parallelism=None, + context_parallelism=1, + sequence_parallelism=False, + num_nodes=1, + num_gpus_per_node=8, + max_steps=1000, + limit_test_batches=None, + limit_val_batches=None, + val_check_interval=5, +): + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + gradient_as_bucket_view=True, + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + devices=num_gpus_per_node, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=10, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + use_distributed_sampler=False, + val_check_interval=val_check_interval, + ) + + return trainer + + +def nemo_resume(model_id: str) -> run.Config[nl.AutoResume]: + """ + Configure automatic resumption from a NeMo checkpoint converted from Huggingface for https://huggingface.co/{model_id}. + + This NeMo checkpoint should be converted from Huggingface beforehand, using nemo.collections.llm.import_ckpt.
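+    For example, a minimal conversion sketch for Llama3 8B (swap in the model class and source matching your model_id):
+        llm.import_ckpt(model=llm.LlamaModel(llm.Llama3Config8B()), source="hf://meta-llama/Meta-Llama-3-8B")
+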
+ When converting the checkpoint, the NeMo checkpoint will be saved in NEMO_HOME (set to ~/.cache/nemo by default). + + This function sets up the configuration to resume training from path nemo://{model_id}. + This translates to the full path {NEMO_HOME}/models/{model_id}. + + Args: + model_id (str): The Huggingface model to resume. + + Returns: + run.Config[nl.AutoResume]: Configuration for resuming from NeMo checkpoint. + """ + return run.Config( + nl.AutoResume, + restore_config=run.Config(nl.RestoreConfig, path=f"nemo://{model_id}"), + ) diff --git a/nemo/collections/llm/recipes/llama3_70b.py b/nemo/collections/llm/recipes/llama3_70b.py index 09c1474ad311..9cfc198038f2 100644 --- a/nemo/collections/llm/recipes/llama3_70b.py +++ b/nemo/collections/llm/recipes/llama3_70b.py @@ -27,6 +27,7 @@ from nemo.collections.llm.gpt.data.squad import SquadDataModule from nemo.collections.llm.gpt.model.llama import Llama3Config70B, LlamaModel from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed @@ -233,47 +234,27 @@ def pretrain_recipe_performance( return recipe -def hf_resume() -> run.Config[nl.AutoResume]: - """ - Configure automatic resumption from a Hugging Face checkpoint for Llama3 70B model. - - This function sets up the configuration to resume training from a pre-trained - Hugging Face model checkpoint. - - More info about the model can be found at: https://huggingface.co/meta-llama/Meta-Llama-3-70B - - Returns: - run.Config[nl.AutoResume]: Configuration for resuming from HuggingFace checkpoint. - - Note: - This is particularly useful for fine-tuning scenarios where you want to - start from the pre-trained Llama3 70B model. - """ - return run.Config( - nl.AutoResume, - restore_config=run.Config(nl.RestoreConfig, path="hf://meta-llama/Meta-Llama-3-70B"), - ) - - @run.cli.factory(target=finetune, name=NAME) def finetune_recipe( dir: Optional[str] = None, name: str = "default", num_nodes: int = 1, num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'lora', ) -> run.Partial: """ Create a fine-tuning recipe for Llama3 70B model. This function sets up a complete configuration for fine-tuning, including model, trainer, data, logging, optimization, and resumption settings. - It uses LoRA (Low-Rank Adaptation) for efficient fine-tuning of the large model. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. Args: dir (Optional[str]): Directory for saving logs and checkpoints. name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. Returns: run.Partial: Partial configuration for fine-tuning. @@ -291,8 +272,16 @@ def finetune_recipe( This recipe uses the SQuAD dataset for fine-tuning. Be aware that fine-tuning a 70B model requires substantial computational resources. 
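+
+        A minimal sketch of the peft_scheme switch (run names here are only illustrative):
+            >>> recipe = finetune_recipe(name="llama3_70b_lora", peft_scheme='lora')  # LoRA adapters: TP=8, lr=1e-4
+            >>> recipe = finetune_recipe(name="llama3_70b_sft", num_nodes=4, peft_scheme=None)  # full fine-tuning: TP=8, PP=4, lr=5e-6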
""" - recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=finetune) - recipe.resume = hf_resume() - recipe.peft = run.Config(LoRA) - recipe.data = run.Config(SquadDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1) + recipe = default_finetune_recipe(model(), "meta-llama/Meta-Llama-3-70B", dir, name, num_nodes, num_gpus_per_node) + if peft_scheme is None or peft_scheme.lower() == 'none': + assert num_nodes >= 4 + recipe.trainer.strategy.tensor_model_parallel_size = 8 + recipe.trainer.strategy.pipeline_model_parallel_size = 4 + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config(LoRA) + recipe.trainer.strategy.tensor_model_parallel_size = 8 + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") return recipe diff --git a/nemo/collections/llm/recipes/llama3_70b_16k.py b/nemo/collections/llm/recipes/llama3_70b_16k.py index 3798088ff722..c8c1957d7bdc 100644 --- a/nemo/collections/llm/recipes/llama3_70b_16k.py +++ b/nemo/collections/llm/recipes/llama3_70b_16k.py @@ -129,48 +129,3 @@ def pretrain_recipe( recipe.data = run.Config(MockDataModule, seq_length=16384, global_batch_size=512, micro_batch_size=1) return recipe - - -@run.cli.factory(target=finetune, name=NAME) -def finetune_recipe( - dir: Optional[str] = None, - name: str = "default", - num_nodes: int = 2, - num_gpus_per_node: int = 8, -) -> run.Partial: - """ - Create a fine-tuning recipe for Llama3 70B model with 16k sequence length. - - This function sets up a complete configuration for fine-tuning, including - model, trainer, and data settings optimized for 16k sequence length. - - Args: - dir (Optional[str]): Directory for saving logs and checkpoints. - name (str): Name of the fine-tuning run. - num_nodes (int): Number of compute nodes to use. - num_gpus_per_node (int): Number of GPUs per node. - - Returns: - run.Partial: Partial configuration for fine-tuning. - - Examples: - CLI usage: - $ nemo llm finetune --factory llama3_70b_16k - $ nemo llm finetune --factory "llama3_70b_16k(num_nodes=4, name='my_70b_16k_finetune')" - - Python API usage: - >>> recipe = finetune_recipe(name="llama3_70b_16k_finetune", num_nodes=4) - >>> print(recipe) - - Note: - This recipe is optimized for fine-tuning the large 70B model with longer sequences (16k). - It uses the SQuAD dataset adapted for 16k sequence length. Be aware that this configuration - requires substantial computational resources. 
- """ - recipe = llama3_70b.finetune_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node) - - recipe.model = model() - recipe.trainer = trainer(num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node) - recipe.data = run.Config(SquadDataModule, seq_length=16384, global_batch_size=512, micro_batch_size=1) - - return recipe diff --git a/nemo/collections/llm/recipes/llama3_70b_64k.py b/nemo/collections/llm/recipes/llama3_70b_64k.py index 353bdd659947..5d9845d9aaa7 100644 --- a/nemo/collections/llm/recipes/llama3_70b_64k.py +++ b/nemo/collections/llm/recipes/llama3_70b_64k.py @@ -132,48 +132,3 @@ def pretrain_recipe( recipe.data = run.Config(MockDataModule, seq_length=65536, global_batch_size=512, micro_batch_size=1) return recipe - - -@run.cli.factory(target=finetune, name=NAME) -def finetune_recipe( - dir: Optional[str] = None, - name: str = "default", - num_nodes: int = 32, - num_gpus_per_node: int = 8, -) -> run.Partial: - """ - Create a fine-tuning recipe for Llama3 70B model with 64k sequence length. - - This function sets up a complete configuration for fine-tuning, including - model, trainer, and data settings optimized for 64k sequence length. - - Args: - dir (Optional[str]): Directory for saving logs and checkpoints. - name (str): Name of the fine-tuning run. - num_nodes (int): Number of compute nodes to use. - num_gpus_per_node (int): Number of GPUs per node. - - Returns: - run.Partial: Partial configuration for fine-tuning. - - Examples: - CLI usage: - $ nemo llm finetune --factory llama3_70b_64k - $ nemo llm finetune --factory "llama3_70b_64k(num_nodes=32, name='my_70b_64k_finetune')" - - Python API usage: - >>> recipe = finetune_recipe(name="llama3_70b_64k_finetune", num_nodes=32) - >>> print(recipe) - - Note: - This recipe is optimized for fine-tuning the large 70B model with long sequences (64k). - It uses the SQuAD dataset adapted for 64k sequence length. Be aware that this configuration - requires extensive computational resources due to the model size and extended sequence length. - """ - recipe = llama3_70b.finetune_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node) - - recipe.model = model() - recipe.trainer = trainer(num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node) - recipe.data = run.Config(SquadDataModule, seq_length=65536, global_batch_size=512, micro_batch_size=1) - - return recipe diff --git a/nemo/collections/llm/recipes/llama3_8b.py b/nemo/collections/llm/recipes/llama3_8b.py index 1f45277b2255..4b2934739529 100644 --- a/nemo/collections/llm/recipes/llama3_8b.py +++ b/nemo/collections/llm/recipes/llama3_8b.py @@ -27,6 +27,7 @@ from nemo.collections.llm.gpt.data.squad import SquadDataModule from nemo.collections.llm.gpt.model.llama import Llama3Config8B, LlamaModel from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed @@ -233,42 +234,27 @@ def pretrain_recipe_performance( return recipe -def hf_resume() -> run.Config[nl.AutoResume]: - """Configure automatic resumption from a Hugging Face checkpoint. - - This function sets up the configuration to resume training from a pre-trained - Hugging Face model checkpoint. 
- - More info about the model can be found at: https://huggingface.co/meta-llama/Meta-Llama-3-8B - - Returns: - run.Config[nl.AutoResume]: Configuration for resuming from HuggingFace checkpoint. - """ - return run.Config( - nl.AutoResume, - restore_config=run.Config(nl.RestoreConfig, path="hf://meta-llama/Meta-Llama-3-8B"), - ) - - @run.cli.factory(target=finetune, name=NAME) def finetune_recipe( dir: Optional[str] = None, name: str = "default", num_nodes: int = 1, num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'lora', ) -> run.Partial: """ Create a fine-tuning recipe for Llama3 8B model. This function sets up a complete configuration for fine-tuning, including model, trainer, data, logging, optimization, and resumption settings. - It uses LoRA (Low-Rank Adaptation) for efficient fine-tuning. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. Args: dir (Optional[str]): Directory for saving logs and checkpoints. name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. Returns: run.Partial: Partial configuration for fine-tuning. @@ -286,8 +272,13 @@ def finetune_recipe( on fine-tuning LLMs with NeMo, see the fine-tuning guide in the `examples/llm/finetune/` directory. """ - recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=finetune) - recipe.resume = hf_resume() - recipe.peft = run.Config(LoRA) - recipe.data = run.Config(SquadDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1) + recipe = default_finetune_recipe(model(), "meta-llama/Meta-Llama-3-8B", dir, name, num_nodes, num_gpus_per_node) + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config(LoRA) + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") return recipe diff --git a/nemo/collections/llm/recipes/llama3_8b_16k.py b/nemo/collections/llm/recipes/llama3_8b_16k.py index bd02f1975864..0b42b392827a 100644 --- a/nemo/collections/llm/recipes/llama3_8b_16k.py +++ b/nemo/collections/llm/recipes/llama3_8b_16k.py @@ -128,47 +128,3 @@ def pretrain_recipe( recipe.data = run.Config(MockDataModule, seq_length=16384, global_batch_size=512, micro_batch_size=1) return recipe - - -@run.cli.factory(target=finetune, name=NAME) -def finetune_recipe( - dir: Optional[str] = None, - name: str = "default", - num_nodes: int = 1, - num_gpus_per_node: int = 8, -) -> run.Partial: - """ - Create a fine-tuning recipe for Llama3 8B model with 16k sequence length. - - This function sets up a complete configuration for fine-tuning, including - model, trainer, and data settings optimized for 16k sequence length. - - Args: - dir (Optional[str]): Directory for saving logs and checkpoints. - name (str): Name of the fine-tuning run. - num_nodes (int): Number of compute nodes to use. - num_gpus_per_node (int): Number of GPUs per node. - - Returns: - run.Partial: Partial configuration for fine-tuning. 
- - Examples: - CLI usage: - $ nemo llm finetune --factory llama3_8b_16k - $ nemo llm finetune --factory "llama3_8b_16k(num_nodes=2, name='my_16k_finetune')" - - Python API usage: - >>> recipe = finetune_recipe(name="llama3_8b_16k_finetune", num_nodes=2) - >>> print(recipe) - - Note: - This recipe is optimized for fine-tuning with longer sequences (16k) compared to the standard 8k version. - It uses the SQuAD dataset adapted for 16k sequence length. - """ - recipe = llama3_8b.finetune_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node) - - recipe.model = model() - recipe.trainer = trainer(num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node) - recipe.data = run.Config(SquadDataModule, seq_length=16384, global_batch_size=512, micro_batch_size=1) - - return recipe diff --git a/nemo/collections/llm/recipes/llama3_8b_64k.py b/nemo/collections/llm/recipes/llama3_8b_64k.py index e5845e4530ca..38f787113bf5 100644 --- a/nemo/collections/llm/recipes/llama3_8b_64k.py +++ b/nemo/collections/llm/recipes/llama3_8b_64k.py @@ -129,48 +129,3 @@ def pretrain_recipe( recipe.data = run.Config(MockDataModule, seq_length=65536, global_batch_size=512, micro_batch_size=1) return recipe - - -@run.cli.factory(target=finetune, name=NAME) -def finetune_recipe( - dir: Optional[str] = None, - name: str = "default", - num_nodes: int = 1, - num_gpus_per_node: int = 8, -) -> run.Partial: - """ - Create a fine-tuning recipe for Llama3 8B model with 64k sequence length. - - This function sets up a complete configuration for fine-tuning, including - model, trainer, and data settings optimized for 64k sequence length. - - Args: - dir (Optional[str]): Directory for saving logs and checkpoints. - name (str): Name of the fine-tuning run. - num_nodes (int): Number of compute nodes to use. - num_gpus_per_node (int): Number of GPUs per node. - - Returns: - run.Partial: Partial configuration for fine-tuning. - - Examples: - CLI usage: - $ nemo llm finetune --factory llama3_8b_64k - $ nemo llm finetune --factory "llama3_8b_64k(num_nodes=2, name='my_64k_finetune')" - - Python API usage: - >>> recipe = finetune_recipe(name="llama3_8b_64k_finetune", num_nodes=2) - >>> print(recipe) - - Note: - This recipe is optimized for fine-tuning with long sequences (64k) compared to the standard 8k version. - It uses the SQuAD dataset adapted for 64k sequence length. Be aware that this configuration requires - substantial computational resources due to the extended sequence length. 
- """ - recipe = llama3_8b.finetune_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node) - - recipe.model = model() - recipe.trainer = trainer(num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node) - recipe.data = run.Config(SquadDataModule, seq_length=65536, global_batch_size=512, micro_batch_size=1) - - return recipe diff --git a/nemo/collections/llm/recipes/mistral.py b/nemo/collections/llm/recipes/mistral.py index c0e50074f26b..16af2b4238f6 100644 --- a/nemo/collections/llm/recipes/mistral.py +++ b/nemo/collections/llm/recipes/mistral.py @@ -27,6 +27,7 @@ from nemo.collections.llm.gpt.data.squad import SquadDataModule from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed @@ -186,47 +187,27 @@ def pretrain_recipe( ) -@run.cli.factory(name=NAME + "_hf") -def hf_resume() -> run.Config[nl.AutoResume]: - """ - Configure automatic resumption from a Hugging Face checkpoint for Mistral 7B model. - - This function sets up the configuration to resume training from a pre-trained - Hugging Face model checkpoint. - - More info about the model can be found at: https://huggingface.co/mistralai/Mistral-7B-v0.3 - - Returns: - run.Config[nl.AutoResume]: Configuration for resuming from HuggingFace checkpoint. - - Note: - This is particularly useful for fine-tuning scenarios where you want to - start from the pre-trained Mistral 7B model. - """ - return run.Config( - nl.AutoResume, restore_config=run.Config(nl.RestoreConfig, path="hf://mistralai/Mistral-7B-v0.3") - ) - - @run.cli.factory(target=finetune, name=NAME) def finetune_recipe( dir: Optional[str] = None, name: str = "default", num_nodes: int = 1, num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'lora', ) -> run.Partial: """ Create a fine-tuning recipe for Mistral 7B model. This function sets up a complete configuration for fine-tuning, including model, trainer, data, logging, optimization, and resumption settings. - It uses LoRA (Low-Rank Adaptation) for efficient fine-tuning. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. Args: dir (Optional[str]): Directory for saving logs and checkpoints. name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. Returns: run.Partial: Partial configuration for fine-tuning. @@ -243,8 +224,15 @@ def finetune_recipe( Note: This recipe uses the SQuAD dataset for fine-tuning. 
""" - recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=finetune) - recipe.resume = hf_resume() - recipe.peft = run.Config(LoRA) - recipe.data = run.Config(SquadDataModule, seq_length=4096, global_batch_size=512, micro_batch_size=1) + recipe = default_finetune_recipe( + model(), "nemo://mistralai/Mistral-7B-v0.3", dir, name, num_nodes, num_gpus_per_node + ) + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config(LoRA) + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") return recipe diff --git a/nemo/collections/llm/recipes/mixtral_8x22b.py b/nemo/collections/llm/recipes/mixtral_8x22b.py index 82f7cae23dba..fe30288c179c 100644 --- a/nemo/collections/llm/recipes/mixtral_8x22b.py +++ b/nemo/collections/llm/recipes/mixtral_8x22b.py @@ -27,6 +27,7 @@ from nemo.collections.llm.gpt.data.squad import SquadDataModule from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x22B, MixtralModel from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback @@ -224,47 +225,27 @@ def pretrain_recipe_performance( return recipe -def hf_resume() -> run.Config[nl.AutoResume]: - """ - Configure automatic resumption from a Hugging Face checkpoint for Mixtral 8x22B model. - - This function sets up the configuration to resume training from a pre-trained - Hugging Face model checkpoint. - - More info about the model can be found at: https://huggingface.co/mistralai/Mixtral-8x22B-v0.1 - - Returns: - run.Config[nl.AutoResume]: Configuration for resuming from HuggingFace checkpoint. - - Note: - This is particularly useful for fine-tuning scenarios where you want to - start from the pre-trained Mixtral 8x22B model. - """ - return run.Config( - nl.AutoResume, - restore_config=run.Config(nl.RestoreConfig, path="hf://mistralai/Mixtral-8x22B-v0.1"), - ) - - @run.cli.factory(target=finetune, name=NAME) def finetune_recipe( dir: Optional[str] = None, name: str = "default", num_nodes: int = 8, num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'lora', ) -> run.Partial: """ Create a fine-tuning recipe for Mixtral 8x22B model. This function sets up a complete configuration for fine-tuning, including model, trainer, data, logging, optimization, and resumption settings. - It uses LoRA (Low-Rank Adaptation) for efficient fine-tuning. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. Args: dir (Optional[str]): Directory for saving logs and checkpoints. name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. Returns: run.Partial: Partial configuration for fine-tuning. @@ -281,8 +262,19 @@ def finetune_recipe( Note: This recipe uses the SQuAD dataset for fine-tuning. 
""" - recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=finetune) - recipe.resume = hf_resume() - recipe.peft = run.Config(LoRA, target_modules=['linear_qkv', 'linear_proj'], dim=32) - recipe.data = run.Config(SquadDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1) + recipe = default_finetune_recipe( + model(), "mistralai/Mixtral-8x22B-v0.1mistralai/Mixtral-8x22B-v0.1", dir, name, num_nodes, num_gpus_per_node + ) + recipe.trainer.strategy.expert_model_parallel_size = 8 + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 8 + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config(LoRA, target_modules=['linear_qkv', 'linear_proj'], dim=32) + recipe.trainer.strategy.tensor_model_parallel_size = 8 + recipe.trainer.strategy.pipeline_model_parallel_size = 4 + recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 14 + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") return recipe diff --git a/nemo/collections/llm/recipes/mixtral_8x3b.py b/nemo/collections/llm/recipes/mixtral_8x3b.py index ca5b4e35039f..a02bad4ffb93 100644 --- a/nemo/collections/llm/recipes/mixtral_8x3b.py +++ b/nemo/collections/llm/recipes/mixtral_8x3b.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import Callable, Optional +from typing import Optional import nemo_run as run import pytorch_lightning as pl @@ -27,6 +27,7 @@ from nemo.collections.llm.gpt.data.squad import SquadDataModule from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x3B, MixtralModel from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed @@ -225,66 +226,3 @@ def pretrain_recipe_performance( ) return recipe - - -def hf_resume() -> run.Config[nl.AutoResume]: - """ - Configure the Hugging Face model resuming for Mixtral 8x3B model. - - This function sets up the configuration for resuming training from a Hugging Face model. - - Returns: - run.Config[nl.AutoResume]: Configuration for resuming from a Hugging Face model. - - Examples: - CLI usage: - $ nemo llm finetune --factory "mixtral_8x3b(resume=hf_resume())" - - Python API usage: - >>> recipe = finetune_recipe(name="mixtral_8x3b_finetune", num_nodes=2) - >>> recipe.resume = hf_resume() - >>> print(recipe) - """ - return run.Config( - nl.AutoResume, - restore_config=run.Config(nl.RestoreConfig, path="hf://mistralai/Mixtral-8x7B-v0.1"), - ) - - -@run.cli.factory(target=finetune, name=NAME) -def finetune_recipe( - dir: Optional[str] = None, - name: str = "default", - num_nodes: int = 1, - num_gpus_per_node: int = 8, -) -> run.Partial: - """ - Create a fine-tuning recipe for Mixtral 8x3B model. - - This function sets up a complete configuration for fine-tuning, including - model, trainer, and data settings. - - Args: - dir (Optional[str]): Directory for saving logs and checkpoints. - name (str): Name of the fine-tuning run. - num_nodes (int): Number of compute nodes to use. - num_gpus_per_node (int): Number of GPUs per node. - - Returns: - run.Partial: Partial configuration for fine-tuning. 
- - Examples: - CLI usage: - $ nemo llm finetune --factory mixtral_8x3b - $ nemo llm finetune --factory "mixtral_8x3b(num_nodes=2, name='my_finetune')" - - Python API usage: - >>> recipe = finetune_recipe(name="mixtral_8x3b_finetune", num_nodes=2) - >>> print(recipe) - """ - recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=finetune) - - recipe.resume = hf_resume() - recipe.peft = run.Config(LoRA, target_modules=['linear_qkv', 'linear_proj'], dim=32) - recipe.data = run.Config(SquadDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1) - return recipe diff --git a/nemo/collections/llm/recipes/mixtral_8x3b_16k.py b/nemo/collections/llm/recipes/mixtral_8x3b_16k.py index 287ac331ee65..13ca1c2d4537 100644 --- a/nemo/collections/llm/recipes/mixtral_8x3b_16k.py +++ b/nemo/collections/llm/recipes/mixtral_8x3b_16k.py @@ -130,47 +130,3 @@ def pretrain_recipe( recipe.data = run.Config(MockDataModule, seq_length=16384, global_batch_size=512, micro_batch_size=1) return recipe - - -@run.cli.factory(target=finetune, name=NAME) -def finetune_recipe( - dir: Optional[str] = None, - name: str = "default", - num_nodes: int = 1, - num_gpus_per_node: int = 8, -) -> run.Partial: - """ - Create a fine-tuning recipe for Mixtral 8x3B model with 16k sequence length. - - This function sets up a complete configuration for fine-tuning, including - model, trainer, and data settings optimized for 16k sequence length. - - Args: - dir (Optional[str]): Directory for saving logs and checkpoints. - name (str): Name of the fine-tuning run. - num_nodes (int): Number of compute nodes to use. - num_gpus_per_node (int): Number of GPUs per node. - - Returns: - run.Partial: Partial configuration for fine-tuning. - - Examples: - CLI usage: - $ nemo llm finetune --factory mixtral_8x3b_16k - $ nemo llm finetune --factory "mixtral_8x3b_16k(num_nodes=2, name='my_16k_finetune')" - - Python API usage: - >>> recipe = finetune_recipe(name="mixtral_8x3b_16k_finetune", num_nodes=2) - >>> print(recipe) - - Note: - This recipe is optimized for fine-tuning with longer sequences (16k) compared to the standard version. - It uses the SQuAD dataset adapted for 16k sequence length. - """ - recipe = mixtral_8x3b.finetune_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node) - - recipe.model = model() - recipe.trainer = trainer(num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node) - recipe.data = run.Config(SquadDataModule, seq_length=16384, global_batch_size=512, micro_batch_size=1) - - return recipe diff --git a/nemo/collections/llm/recipes/mixtral_8x3b_64k.py b/nemo/collections/llm/recipes/mixtral_8x3b_64k.py index 98cf2f4f9e7b..e21d85a13dcd 100644 --- a/nemo/collections/llm/recipes/mixtral_8x3b_64k.py +++ b/nemo/collections/llm/recipes/mixtral_8x3b_64k.py @@ -131,48 +131,3 @@ def pretrain_recipe( recipe.trainer = trainer(num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node) recipe.data = run.Config(MockDataModule, seq_length=65536, global_batch_size=512, micro_batch_size=1) return recipe - - -@run.cli.factory(target=finetune, name=NAME) -def finetune_recipe( - dir: Optional[str] = None, - name: str = "default", - num_nodes: int = 8, - num_gpus_per_node: int = 8, -) -> run.Partial: - """ - Create a fine-tuning recipe for Mixtral 8x3B model with 64k sequence length. - - This function sets up a complete configuration for fine-tuning, including - model, trainer, and data settings optimized for 64k sequence length. 
- - Args: - dir (Optional[str]): Directory for saving logs and checkpoints. - name (str): Name of the fine-tuning run. - num_nodes (int): Number of compute nodes to use. - num_gpus_per_node (int): Number of GPUs per node. - - Returns: - run.Partial: Partial configuration for fine-tuning. - - Examples: - CLI usage: - $ nemo llm finetune --factory mixtral_8x3b_64k - $ nemo llm finetune --factory "mixtral_8x3b_64k(num_nodes=8, name='my_64k_finetune')" - - Python API usage: - >>> recipe = finetune_recipe(name="mixtral_8x3b_64k_finetune", num_nodes=8) - >>> print(recipe) - - Note: - This recipe is optimized for fine-tuning with long sequences (64k) compared to the standard version. - It uses the SQuAD dataset adapted for 64k sequence length. Be aware that this configuration requires - substantial computational resources due to the extended sequence length. - """ - recipe = mixtral_8x3b.finetune_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node) - - recipe.model = model() - recipe.trainer = trainer(num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node) - recipe.data = run.Config(SquadDataModule, seq_length=65536, global_batch_size=512, micro_batch_size=1) - - return recipe diff --git a/nemo/collections/llm/recipes/mixtral_8x7b.py b/nemo/collections/llm/recipes/mixtral_8x7b.py index 9000c66c3445..b9ffaa03d341 100644 --- a/nemo/collections/llm/recipes/mixtral_8x7b.py +++ b/nemo/collections/llm/recipes/mixtral_8x7b.py @@ -27,6 +27,7 @@ from nemo.collections.llm.gpt.data.squad import SquadDataModule from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x7B, MixtralModel from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback @@ -205,7 +206,7 @@ def pretrain_recipe_performance( $ nemo llm pretrain --factory "mixtral_8x3b.pretrain_recipe_performance(num_nodes=8, name='perf_pretrain')" Python API usage: - >>> recipe = pretrain_recipe_performance(name="mixtral_8x3b_perf", num_nodes=8) + >>> recipe = pretrain_recipe_performance(name="mixtral_8x7b_perf", num_nodes=8) >>> print(recipe) Note: @@ -223,47 +224,27 @@ def pretrain_recipe_performance( return recipe -def hf_resume() -> run.Config[nl.AutoResume]: - """ - Configure automatic resumption from a Hugging Face checkpoint for Mixtral 8x7B model. - - This function sets up the configuration to resume training from a pre-trained - Hugging Face model checkpoint. - - More info about the model can be found at: https://huggingface.co/mistralai/Mixtral-8x7B-v0.1 - - Returns: - run.Config[nl.AutoResume]: Configuration for resuming from HuggingFace checkpoint. - - Note: - This is particularly useful for fine-tuning scenarios where you want to - start from the pre-trained Mixtral 8x7B model. - """ - return run.Config( - nl.AutoResume, - restore_config=run.Config(nl.RestoreConfig, path="hf://mistralai/Mixtral-8x7B-v0.1"), - ) - - @run.cli.factory(target=finetune, name=NAME) def finetune_recipe( dir: Optional[str] = None, name: str = "default", - num_nodes: int = 2, + num_nodes: int = 1, num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'lora', ) -> run.Partial: """ Create a fine-tuning recipe for Mixtral 8x7B model. 
This function sets up a complete configuration for fine-tuning, including model, trainer, data, logging, optimization, and resumption settings. - It uses LoRA (Low-Rank Adaptation) for efficient fine-tuning. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. Args: dir (Optional[str]): Directory for saving logs and checkpoints. name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. Returns: run.Partial: Partial configuration for fine-tuning. @@ -280,8 +261,15 @@ def finetune_recipe( Note: This recipe uses the SQuAD dataset for fine-tuning. """ - recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=finetune) - recipe.resume = hf_resume() - recipe.peft = run.Config(LoRA, target_modules=['linear_qkv', 'linear_proj'], dim=32) - recipe.data = run.Config(SquadDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1) + recipe = default_finetune_recipe(model(), "mistralai/Mixtral-8x7B-v0.1", dir, name, num_nodes, num_gpus_per_node) + recipe.trainer.strategy.expert_model_parallel_size = 8 + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config(LoRA, target_modules=['linear_qkv', 'linear_proj'], dim=32) + recipe.trainer.strategy.pipeline_model_parallel_size = 4 + recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 8 + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") return recipe diff --git a/nemo/collections/llm/recipes/mixtral_8x7b_16k.py b/nemo/collections/llm/recipes/mixtral_8x7b_16k.py index 4b5fd07a69e9..8b26a8c7c3e3 100644 --- a/nemo/collections/llm/recipes/mixtral_8x7b_16k.py +++ b/nemo/collections/llm/recipes/mixtral_8x7b_16k.py @@ -129,46 +129,3 @@ def pretrain_recipe( recipe.data = run.Config(MockDataModule, seq_length=16384, global_batch_size=512, micro_batch_size=1) return recipe - - -@run.cli.factory(target=finetune, name=NAME) -def finetune_recipe( - dir: Optional[str] = None, - name: str = "default", - num_nodes: int = 2, - num_gpus_per_node: int = 8, -) -> run.Partial: - """ - Create a fine-tuning recipe for Mixtral 8x7B model with 16k sequence length. - - This function sets up a complete configuration for fine-tuning, including - model, trainer, and data settings optimized for 16k sequence length. - - Args: - dir (Optional[str]): Directory for saving logs and checkpoints. - name (str): Name of the fine-tuning run. - num_nodes (int): Number of compute nodes to use. - num_gpus_per_node (int): Number of GPUs per node. - - Returns: - run.Partial: Partial configuration for fine-tuning. - - Examples: - CLI usage: - $ nemo llm finetune --factory mixtral_8x7b_16k - $ nemo llm finetune --factory "mixtral_8x7b_16k(num_nodes=2, name='my_16k_finetune')" - - Python API usage: - >>> recipe = finetune_recipe(name="mixtral_8x7b_16k_finetune", num_nodes=2) - >>> print(recipe) - - Note: - This recipe uses the SQuAD dataset for fine-tuning. 
- """ - recipe = mixtral_8x7b.finetune_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node) - - recipe.model = model() - recipe.trainer = trainer(num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node) - recipe.data = run.Config(SquadDataModule, seq_length=16384, global_batch_size=512, micro_batch_size=1) - - return recipe diff --git a/nemo/collections/llm/recipes/mixtral_8x7b_64k.py b/nemo/collections/llm/recipes/mixtral_8x7b_64k.py index 6a1f76961325..6c8f7077fba3 100644 --- a/nemo/collections/llm/recipes/mixtral_8x7b_64k.py +++ b/nemo/collections/llm/recipes/mixtral_8x7b_64k.py @@ -133,48 +133,3 @@ def pretrain_recipe( recipe.data = run.Config(MockDataModule, seq_length=65536, global_batch_size=512, micro_batch_size=1) return recipe - - -@run.cli.factory(target=finetune, name=NAME) -def finetune_recipe( - dir: Optional[str] = None, - name: str = "default", - num_nodes: int = 16, - num_gpus_per_node: int = 8, -) -> run.Partial: - """ - Create a fine-tuning recipe for Mixtral 8x7B model with 64k sequence length. - - This function sets up a complete configuration for fine-tuning, including - model, trainer, and data settings optimized for 64k sequence length. - - Args: - dir (Optional[str]): Directory for saving logs and checkpoints. - name (str): Name of the fine-tuning run. - num_nodes (int): Number of compute nodes to use. - num_gpus_per_node (int): Number of GPUs per node. - - Returns: - run.Partial: Partial configuration for fine-tuning. - - Examples: - CLI usage: - $ nemo llm finetune --factory mixtral_8x7b_64k - $ nemo llm finetune --factory "mixtral_8x7b_64k(num_nodes=16, name='my_64k_finetune')" - - Python API usage: - >>> recipe = finetune_recipe(name="mixtral_8x7b_64k_finetune", num_nodes=16) - >>> print(recipe) - - Note: - This recipe is optimized for fine-tuning with long sequences (64k) compared to the standard version. - It uses the SQuAD dataset adapted for 64k sequence length. Be aware that this configuration requires - substantial computational resources due to the model size and extended sequence length. - """ - recipe = mixtral_8x7b.finetune_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node) - - recipe.model = model() - recipe.trainer = trainer(num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node) - recipe.data = run.Config(SquadDataModule, seq_length=65536, global_batch_size=512, micro_batch_size=1) - - return recipe diff --git a/nemo/collections/llm/recipes/nemotron3_8b.py b/nemo/collections/llm/recipes/nemotron3_8b.py index 05fb2cb8dcf5..3cdb647b5f84 100644 --- a/nemo/collections/llm/recipes/nemotron3_8b.py +++ b/nemo/collections/llm/recipes/nemotron3_8b.py @@ -174,25 +174,28 @@ def pretrain_recipe( ) -@run.cli.factory(name=NAME + "_hf") -def hf_resume() -> run.Config[nl.AutoResume]: +@run.cli.factory(name=NAME + "_nemo") +def nemo_resume() -> run.Config[nl.AutoResume]: """ - Configure automatic resumption from a Hugging Face checkpoint for Nemotron3 8B model. + Configure automatic resumption from a NeMo checkpoint converted from Huggingface for Nemotron3 8B model. - This function sets up the configuration to resume training from a pre-trained - Hugging Face model checkpoint. + More info about the Huggingface model can be found at: https://huggingface.co/nvidia/nemotron-3-8b-base-4k. - More info about the model can be found at: https://huggingface.co/nvidia/nemotron-3-8b-base-4k + This NeMo checkpoint should be converted from Huggingface beforehand, using nemo.collections.llm.import_ckpt. 
+ When converting the checkpoint, the NeMo checkpoint will be saved in NEMO_HOME (set to ~/.cache/nemo by default). + + This function sets up the configuration to resume training from path nemo://nvidia/nemotron-3-8b-base-4k. + This translates to the full path {NEMO_HOME}/models/nvidia/nemotron-3-8b-base-4k. Returns: - run.Config[nl.AutoResume]: Configuration for resuming from HuggingFace checkpoint. + run.Config[nl.AutoResume]: Configuration for resuming from NeMo checkpoint. Note: This is particularly useful for fine-tuning scenarios where you want to start from the pre-trained Nemotron3 8B model. """ return run.Config( - nl.AutoResume, restore_config=run.Config(nl.RestoreConfig, path="hf://nvidia/nemotron-3-8b-base-4k") + nl.AutoResume, restore_config=run.Config(nl.RestoreConfig, path="nemo://nvidia/nemotron-3-8b-base-4k") ) @@ -308,7 +311,7 @@ def finetune_recipe( max_lr=max_lr, fn=fn, ) - recipe.resume = hf_resume() + recipe.resume = nemo_resume() recipe.peft = run.Config(LoRA) recipe.data = run.Config( SquadDataModule, seq_length=seq_length, global_batch_size=global_batch_size, micro_batch_size=micro_batch_size diff --git a/nemo/collections/llm/recipes/nemotron4_340b.py b/nemo/collections/llm/recipes/nemotron4_340b.py index 832b5bad3028..238acb0dac3c 100644 --- a/nemo/collections/llm/recipes/nemotron4_340b.py +++ b/nemo/collections/llm/recipes/nemotron4_340b.py @@ -174,25 +174,28 @@ def pretrain_recipe( ) -@run.cli.factory(name=NAME + "_hf") -def hf_resume() -> run.Config[nl.AutoResume]: +@run.cli.factory(name=NAME + "_nemo") +def nemo_resume() -> run.Config[nl.AutoResume]: """ - Configure automatic resumption from a Hugging Face checkpoint for Nemotron4 340B model. + Configure automatic resumption from a NeMo checkpoint converted from Huggingface for Nemotron4 340B model. - This function sets up the configuration to resume training from a pre-trained - Hugging Face model checkpoint. + More info about the Huggingface model can be found at: https://huggingface.co/nvidia/Nemotron-4-340B-Base. - More info about the model can be found at: https://huggingface.co/nvidia/Nemotron-4-340B-Base + This NeMo checkpoint should be converted from Huggingface beforehand, using nemo.collections.llm.import_ckpt. + When converting the checkpoint, the NeMo checkpoint will be saved in NEMO_HOME (set to ~/.cache/nemo by default). + + This function sets up the configuration to resume training from path nemo://nvidia/Nemotron-4-340B-Base. + This translates to the full path {NEMO_HOME}/models/nvidia/Nemotron-4-340B-Base. Returns: - run.Config[nl.AutoResume]: Configuration for resuming from HuggingFace checkpoint. + run.Config[nl.AutoResume]: Configuration for resuming from NeMo checkpoint. Note: This is particularly useful for fine-tuning scenarios where you want to start from the pre-trained Nemotron4 340B model. 
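+
+        For example, a conversion sketch, assuming the Nemotron classes exported by nemo.collections.llm:
+            llm.import_ckpt(model=llm.NemotronModel(llm.Nemotron4Config340B()), source="hf://nvidia/Nemotron-4-340B-Base")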
""" return run.Config( - nl.AutoResume, restore_config=run.Config(nl.RestoreConfig, path="hf://nvidia/Nemotron-4-340B-Base") + nl.AutoResume, restore_config=run.Config(nl.RestoreConfig, path="nemo://nvidia/Nemotron-4-340B-Base") ) @@ -308,7 +311,7 @@ def finetune_recipe( max_lr=max_lr, fn=fn, ) - recipe.resume = hf_resume() + recipe.resume = nemo_resume() recipe.peft = run.Config(LoRA) recipe.data = run.Config( SquadDataModule, seq_length=seq_length, global_batch_size=global_batch_size, micro_batch_size=micro_batch_size diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py index 096c7728d4a1..a1443c7de242 100644 --- a/nemo/lightning/megatron_parallel.py +++ b/nemo/lightning/megatron_parallel.py @@ -18,6 +18,7 @@ import inspect import queue from collections import defaultdict +from contextlib import nullcontext from dataclasses import dataclass from typing import ( Any, @@ -426,16 +427,21 @@ def init_ddp(self): for model_chunk_idx, model_chunk in enumerate(self): module = model_chunk.module - ddp = DDP( - module.config, - self.ddp_config, - module, - data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True), - expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(), - # Turn off bucketing for model_chunk 2 onwards, since communication for these - # model chunks is overlapped with compute anyway. - disable_bucketing=(model_chunk_idx > 0), - ) + # Mcore DistributedDataParallel has to be called with grad. Normally this call is redundant, but for + # PEFT with num_sanity_val_steps > 0 this is necessary. + init_ddp_context = nullcontext if all(x.requires_grad for x in module.parameters()) else torch.enable_grad + with init_ddp_context(): + ddp = DDP( + module.config, + self.ddp_config, + module, + data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True), + expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(), + # Turn off bucketing for model_chunk 2 onwards, since communication for these + # model chunks is overlapped with compute anyway. 
+ disable_bucketing=(model_chunk_idx > 0), + ) + model_chunk.module = ddp model_chunk.buffers = ddp.buffers # We need to do this explicitly since this is an attr PyTorch uses model_chunk.__class__.__getattr__ = getattr_proxy # type: ignore diff --git a/nemo/lightning/resume.py b/nemo/lightning/resume.py index 9f562e0adb73..99b370d45f71 100644 --- a/nemo/lightning/resume.py +++ b/nemo/lightning/resume.py @@ -22,6 +22,7 @@ import pytorch_lightning as pl from nemo.lightning import io +from nemo.lightning.base import NEMO_MODELS_CACHE from nemo.lightning.pytorch.strategies.utils import RestoreConfig from nemo.utils import logging from nemo.utils.app_state import AppState @@ -101,7 +102,7 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], model=None): model = _try_restore_tokenizer(model, context_path) elif self.restore_config: - new_path = self._try_import_model( + new_path = self._extract_path( model=model, path=self.restore_config.path, adapter_path=self.restore_config.adapter_path, @@ -112,17 +113,22 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], model=None): else: self.restore_config.path = str(new_path) trainer.strategy.restore_config = self.restore_config + # Load artifacts + if self.restore_config.load_artifacts: + context_path = new_path / "context" + if not context_path.is_dir(): + context_path = new_path + + _try_restore_tokenizer(model, context_path) - def _try_import_model( + def _extract_path( self, model: Optional[io.ConnectorMixin], path: str, adapter_path: Optional[str] = None ) -> BasePath: - - if model is None: - raise ValueError("Model is needed to import checkpoint from HF or other non-NeMo checkpoint format.") - try: - new_path = model.import_ckpt(path) - except (ValueError, AttributeError): - # This is reached when the model connector does not exist for the particular path. + if "://" in path: + assert path.startswith("nemo://"), "Only NeMo-based paths starting with nemo:// are currently supported." + _, _path = path.split("://") + new_path = os.path.join(NEMO_MODELS_CACHE, _path) + else: new_path = path if adapter_path: @@ -146,7 +152,7 @@ def _resume_peft(self, adapter_meta_path, model): assert ( "://" in self.restore_config.path ), "For now PEFT resume requires specifying the import path instead of local path" - base_model_path = self._try_import_model(model, self.restore_config.path) + base_model_path = self._extract_path(model, self.restore_config.path) if base_model_path != Path(metadata['model_ckpt_path']): raise ValueError( f"When trying to resume a PEFT training run, found mismatching values: "