Disable checkpoint conversion inside AutoResume #10645

Open · wants to merge 7 commits into base: main
116 changes: 116 additions & 0 deletions nemo/collections/llm/recipes/finetune_default.py
@@ -0,0 +1,116 @@
from typing import Optional
import nemo_run as run
import nemo.lightning as nl
import pytorch_lightning as pl
from nemo.collections.llm.recipes.log.default import tensorboard_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
from nemo.collections import llm

def default_finetune_recipe(
model: run.Config[pl.LightningModule],
resume_path: str,
dir: Optional[str] = None,
name: str = "default",
num_nodes: int = 1,
num_gpus_per_node: int = 8,
) -> run.Partial:
"""
Create a default fine-tuning recipe for any model.

This function sets up a template for a complete configuration for fine-tuning, including
model, trainer, data, logging, optimization, and resumption settings.

Args:
model (run.Config[pl.LightningModule]): Configuration for a NeMo model.
resume_path (str): Hugging Face model ID whose converted NeMo checkpoint is used to resume from.
dir (Optional[str]): Directory for saving logs and checkpoints.
name (str): Name of the fine-tuning run.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.

Returns:
run.Partial: Partial configuration for fine-tuning.

See the model-specific recipes that call this function for concrete usage examples.
"""
recipe = run.Partial(
llm.finetune,
model=model,
trainer=default_finetune_trainer(
num_nodes=num_nodes,
num_gpus_per_node=num_gpus_per_node,
),
data=run.Config(llm.SquadDataModule, seq_length=2048, global_batch_size=128, micro_batch_size=1),
log=llm.default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)),
optim=distributed_fused_adam_with_cosine_annealing(max_lr=1e-4, min_lr=0, warmup_steps=50),
resume=nemo_resume(resume_path),
)

return recipe


def default_finetune_trainer(
tensor_parallelism=1,
pipeline_parallelism=1,
pipeline_parallelism_type=None,
virtual_pipeline_parallelism=None,
context_parallelism=1,
sequence_parallelism=False,
num_nodes=1,
num_gpus_per_node=8,
max_steps=1000,
limit_test_batches=None,
limit_val_batches=None,
val_check_interval=5,
):
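"""Create a default fine-tuning trainer config: an nl.Trainer using nl.MegatronStrategy with the given parallelism settings."""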
strategy = run.Config(
nl.MegatronStrategy,
tensor_model_parallel_size=tensor_parallelism,
pipeline_model_parallel_size=pipeline_parallelism,
pipeline_dtype=pipeline_parallelism_type,
virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism,
context_parallel_size=context_parallelism,
sequence_parallel=sequence_parallelism,
gradient_as_bucket_view=True,
)

trainer = run.Config(
nl.Trainer,
accelerator="gpu",
accumulate_grad_batches=1,
devices=num_gpus_per_node,
limit_test_batches=limit_test_batches,
limit_val_batches=limit_val_batches,
log_every_n_steps=10,
max_steps=max_steps,
num_nodes=num_nodes,
plugins=bf16_mixed(),
strategy=strategy,
use_distributed_sampler=False,
val_check_interval=val_check_interval,
)

return trainer


def nemo_resume(model_id: str) -> run.Config[nl.AutoResume]:
"""
Configure automatic resumption from a NeMo checkpoint converted from the Hugging Face model at https://huggingface.co/{model_id}.

The checkpoint must be converted from Hugging Face beforehand using nemo.collections.llm.import_ckpt; the converted
checkpoint is saved under NEMO_HOME (which defaults to ~/.cache/nemo).

This function sets up resumption from the path nemo://{model_id}, which resolves to {NEMO_HOME}/models/{model_id}.

Args:
model_id (str): The Hugging Face model ID to resume from.

Returns:
run.Config[nl.AutoResume]: Configuration for resuming from NeMo checkpoint.
"""
return run.Config(
nl.AutoResume,
restore_config=run.Config(nl.RestoreConfig, path=f"nemo://{model_id}"),
)
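
A minimal end-to-end sketch of how default_finetune_recipe could be used (not taken from this diff): the Hugging Face checkpoint is converted once with llm.import_ckpt, and the recipe is then built from a model config plus the model ID. The import_ckpt call shown here and the nemo_run LocalExecutor/run.run usage are assumptions that may differ across versions.

import nemo_run as run
from nemo.collections import llm
from nemo.collections.llm.gpt.model.llama import Llama3Config8B, LlamaModel
from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe

# One-time conversion: writes the NeMo checkpoint under {NEMO_HOME}/models/meta-llama/Meta-Llama-3-8B.
llm.import_ckpt(
    model=LlamaModel(Llama3Config8B()),
    source="hf://meta-llama/Meta-Llama-3-8B",
)

# Build the default fine-tuning recipe; AutoResume then restores from nemo://meta-llama/Meta-Llama-3-8B
# without triggering another checkpoint conversion.
recipe = default_finetune_recipe(
    model=run.Config(LlamaModel, config=run.Config(Llama3Config8B)),
    resume_path="meta-llama/Meta-Llama-3-8B",
    name="llama3_8b_finetune",
)
run.run(recipe, executor=run.LocalExecutor(ntasks_per_node=8))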
40 changes: 14 additions & 26 deletions nemo/collections/llm/recipes/llama3_8b.py
@@ -27,6 +27,7 @@
from nemo.collections.llm.gpt.data.squad import SquadDataModule
from nemo.collections.llm.gpt.model.llama import Llama3Config8B, LlamaModel
from nemo.collections.llm.peft.lora import LoRA
from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe
from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
@@ -233,46 +234,27 @@ def pretrain_recipe_performance(
return recipe


def nemo_resume() -> run.Config[nl.AutoResume]:
"""
Configure automatic resumption from a NeMo checkpoint converted from Huggingface for Meta LLama 3 8B.

More info about the Huggingface model can be found at: https://huggingface.co/meta-llama/Meta-Llama-3-8B.

This NeMo checkpoint should be converted from Huggingface beforehand, using nemo.collections.llm.import_ckpt.
When converting the checkpoint, the NeMo checkpoint will be saved in NEMO_HOME (set to ~/.cache/nemo by default).

This function sets up the configuration to resume training from path nemo://meta-llama/Meta-Llama-3-8B.
This translates to the full path {NEMO_HOME}/models/meta-llama/Meta-Llama-3-8B.

Returns:
run.Config[nl.AutoResume]: Configuration for resuming from NeMo checkpoint.
"""
return run.Config(
nl.AutoResume,
restore_config=run.Config(nl.RestoreConfig, path="nemo://meta-llama/Meta-Llama-3-8B"),
)


@run.cli.factory(target=finetune, name=NAME)
def finetune_recipe(
dir: Optional[str] = None,
name: str = "default",
num_nodes: int = 1,
num_gpus_per_node: int = 8,
peft_scheme: Optional[str] = 'lora',
) -> run.Partial:
"""
Create a fine-tuning recipe for Llama3 8B model.

This function sets up a complete configuration for fine-tuning, including
model, trainer, data, logging, optimization, and resumption settings.
It uses LoRA (Low-Rank Adaptation) for efficient fine-tuning.
The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None.

Args:
dir (Optional[str]): Directory for saving logs and checkpoints.
name (str): Name of the fine-tuning run.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
peft_scheme (Optional[str]): Name of the PEFT scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None.

Returns:
run.Partial: Partial configuration for fine-tuning.
@@ -290,8 +272,14 @@ def finetune_recipe(
on fine-tuning LLMs with NeMo, see the fine-tuning guide in the
`examples/llm/finetune/` directory.
"""
recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=finetune)
recipe.resume = nemo_resume()
recipe.peft = run.Config(LoRA)
recipe.data = run.Config(SquadDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1)
recipe = default_finetune_recipe(model(), "meta-llama/Meta-Llama-3-8B",
dir, name, num_nodes, num_gpus_per_node)
if peft_scheme is None or peft_scheme.lower() == 'none':
recipe.trainer.strategy.tensor_model_parallel_size = 2
recipe.optim.config.lr = 5e-6
elif peft_scheme.lower() == 'lora':
recipe.peft = run.Config(LoRA)
recipe.optim.config.lr = 1e-4
else:
raise ValueError(f"Unrecognized peft scheme: {peft_scheme}")
return recipe
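
A short sketch of how the updated factory might be invoked directly from Python (assuming the run.cli.factory-decorated function can still be called as a plain function):

# LoRA fine-tuning (the default): adds a LoRA adapter and uses lr=1e-4.
lora_recipe = finetune_recipe(name="llama3_8b_lora", num_nodes=1, peft_scheme="lora")

# Full-parameter fine-tuning: no PEFT, tensor parallelism of 2 and lr=5e-6.
sft_recipe = finetune_recipe(name="llama3_8b_sft", num_nodes=1, peft_scheme=None)

# Any other value for peft_scheme raises ValueError.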
24 changes: 14 additions & 10 deletions nemo/lightning/megatron_parallel.py
@@ -426,16 +426,20 @@ def init_ddp(self):
for model_chunk_idx, model_chunk in enumerate(self):
module = model_chunk.module

ddp = DDP(
module.config,
self.ddp_config,
module,
data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
# Turn off bucketing for model_chunk 2 onwards, since communication for these
# model chunks is overlapped with compute anyway.
disable_bucketing=(model_chunk_idx > 0),
)
Mcore DistributedDataParallel has to be constructed with gradients enabled. Normally this context manager is
redundant, but for PEFT with num_sanity_val_steps > 0 it is necessary.
with torch.enable_grad():
Collaborator Author: Should there be a check to only use this if it's PEFT?

Collaborator: thanks for the comment, revised!

ddp = DDP(
module.config,
self.ddp_config,
module,
data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
# Turn off bucketing for model_chunk 2 onwards, since communication for these
# model chunks is overlapped with compute anyway.
disable_bucketing=(model_chunk_idx > 0),
)

model_chunk.module = ddp
model_chunk.buffers = ddp.buffers # We need to do this explicitly since this is an attr PyTorch uses
model_chunk.__class__.__getattr__ = getattr_proxy # type: ignore
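
The enable_grad guard can be illustrated in isolation. The following standalone snippet (not NeMo code) only demonstrates why the wrapper matters when the caller runs under no_grad, as during sanity-check validation:

import torch

def build_ddp_like_object():
    # Stand-in for the Mcore DDP constructor, which requires autograd to be enabled.
    assert torch.is_grad_enabled(), "constructor must run with grad enabled"
    return object()

with torch.no_grad():  # e.g. a sanity-check validation pass
    with torch.enable_grad():  # re-enable autograd just for construction
        ddp_stub = build_ddp_like_object()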