nemo2 peft merge #11017

Open

wants to merge 25 commits into main
Changes from 20 commits

Commits (25)
d2f9cf7
initial draft
HuiyingLi Oct 7, 2024
c527914
refactor wip
HuiyingLi Oct 9, 2024
65e252d
refac v2 WIP
HuiyingLi Oct 11, 2024
f1e52ed
update address comments and add model dump
HuiyingLi Oct 17, 2024
acd56b0
remove merge script
HuiyingLi Oct 17, 2024
03c4ea0
move driver script
HuiyingLi Oct 17, 2024
5d86702
Merge branch 'main' into huiyingl/nemo2_peftmerge
HuiyingLi Oct 17, 2024
1740d5d
format
HuiyingLi Oct 17, 2024
c8b3a40
format
HuiyingLi Oct 22, 2024
31e0c0a
Merge branch 'main' into huiyingl/nemo2_peftmerge
HuiyingLi Oct 22, 2024
169656d
Apply isort and black reformatting
HuiyingLi Oct 24, 2024
3071ed8
Merge branch 'main' into huiyingl/peftmerge
HuiyingLi Oct 24, 2024
2615f89
update with nemo2 main
HuiyingLi Oct 24, 2024
50b45ce
Merge branch 'huiyingl/peftmerge' of github.com:NVIDIA/NeMo into huiy…
HuiyingLi Oct 24, 2024
933fadf
Apply isort and black reformatting
HuiyingLi Oct 24, 2024
82a1c25
cleanup import
HuiyingLi Oct 24, 2024
45bffef
Merge branch 'main' into huiyingl/peftmerge
HuiyingLi Nov 11, 2024
45b5f88
merge api v3
HuiyingLi Nov 12, 2024
dd38c32
cleanup
HuiyingLi Nov 12, 2024
a135a73
refac merge func to transform(by ChenCui)
HuiyingLi Nov 14, 2024
deeb74d
read base model from io instead of user input and bug fix
HuiyingLi Nov 16, 2024
db7497f
Apply isort and black reformatting
HuiyingLi Nov 16, 2024
3cc2965
add docstring
HuiyingLi Nov 16, 2024
9992509
Merge branch 'main' into huiyingl/peftmerge
HuiyingLi Nov 16, 2024
75ea08e
Merge branch 'main' into huiyingl/peftmerge
HuiyingLi Nov 17, 2024
4 changes: 2 additions & 2 deletions nemo/collections/llm/peft/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from nemo.collections.llm.peft.api import gpt_lora
+from nemo.collections.llm.peft.api import gpt_lora, merge_lora
 from nemo.collections.llm.peft.lora import LoRA
 
-__all__ = ["LoRA", "gpt_lora"]
+__all__ = ["LoRA", "gpt_lora", "merge_lora"]
114 changes: 112 additions & 2 deletions nemo/collections/llm/peft/api.py
@@ -12,14 +12,124 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from nemo.collections.llm.peft.lora import LoRA
+import json
+from pathlib import Path
+from typing import Any, Dict, Union
+
+import pytorch_lightning as pl
+from megatron.core import dist_checkpointing
+from pytorch_lightning.trainer.states import TrainerFn
+
+from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
+from nemo.collections.llm.peft.lora import LoRA, LoRAMerge
 from nemo.collections.llm.utils import factory
+from nemo.lightning import MegatronStrategy, Trainer, _strategy_lib, io
+from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME, ckpt_to_context_subdir
+from nemo.lightning.io.pl import TrainerContext, ckpt_to_weights_subdir
+from nemo.lightning.pytorch.callbacks import PEFT
 from nemo.lightning.pytorch.callbacks.peft import PEFT
+from nemo.lightning.pytorch.strategies.utils import RestoreConfig
+from nemo.utils import logging
 
 
 @factory
 def gpt_lora() -> PEFT:
     return LoRA()
 
 
-__all__ = ["gpt_lora"]
+def merge_lora(
+    model: pl.LightningModule,
+    lora_checkpoint_path: str,
+    output_path: str,
+):
+    """
+    Merges the LoRA adapter weights into the base model's weights.
+
+    Python Usage:
+    ```python
+    def llama3_8b() -> pl.LightningModule:
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
+        return llm.LlamaModel(llm.Llama3Config8B(), tokenizer=tokenizer)
+
+
+    if __name__ == '__main__':
+        llm.peft.merge_lora(
+            model=llama3_8b(),
+            lora_checkpoint_path=your_lora_checkpoint_path,
+            output_path=your_output_path,
+        )
+    ```
+
+    Args:
+        model: The base model instance to merge the LoRA adapter weights into.
+        lora_checkpoint_path: The path to the LoRA checkpoint.
+        output_path: The path to save the merged checkpoint.
+
+    """
+    from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
+
+    trainer = Trainer(
+        devices=1,
+        accelerator="cpu",
+        strategy=MegatronStrategy(ddp="pytorch", setup_optimizers=False, plugins=bf16_mixed()),
+    )
+
+    if (
+        adapter_meta_path := ckpt_to_weights_subdir(lora_checkpoint_path, is_saving=False) / ADAPTER_META_FILENAME
+    ).exists():
+        with open(adapter_meta_path, "r") as f:
+            metadata = json.load(f)
+        restore_config = RestoreConfig(
+            path=metadata["model_ckpt_path"],
+            load_model_state=True,
+            load_optim_state=False,
+        )
+    else:
+        raise ValueError(f"Cannot find adapter meta file in {lora_checkpoint_path}")
+
+    trainer.strategy.restore_config = restore_config
+    trainer.strategy._setup_optimizers = False
+    trainer.ckpt_path = None
+    trainer.strategy.connect(model)
+    trainer.strategy.setup_environment()
+
+    if not model.state_dict():
+        with _strategy_lib.megatron_cpu_init_context(model.config):
+            model.configure_model()
+
+    trainer.strategy.setup(trainer)
+    trainer.state.fn = TrainerFn.TESTING
+    trainer.strategy.setup_megatron_parallel(trainer=trainer)
+    trainer.strategy.trainer = trainer
+
+    lora: Union[io.TrainerContext, LoRA] = io.load_context(
+        ckpt_to_context_subdir(lora_checkpoint_path), "model.model_transform"
+    )
+    assert isinstance(lora, LoRA), "LoRA config not found in checkpoint"
+    model = lora(model)
+    adapter_sharded_state_dict = {
+        k: v for k, v in trainer.strategy.megatron_parallel.sharded_state_dict().items() if ".adapter." in k
+    }
+    adapter_state = trainer.strategy.checkpoint_io.load_checkpoint(
+        ckpt_to_weights_subdir(lora_checkpoint_path, is_saving=False), sharded_state_dict=adapter_sharded_state_dict
+    )
+    trainer.strategy.load_model_state_dict(adapter_state, strict=False)
+
+    lora_merge = LoRAMerge()
+    merged_model = lora_merge(trainer.strategy.megatron_parallel)
+    merged_weights = {k: v for k, v in merged_model.sharded_state_dict().items() if ".adapter." not in k}
+    weight_path = ckpt_to_weights_subdir(output_path, is_saving=True)
+    Path(weight_path).mkdir(parents=True, exist_ok=True)
+    dist_checkpointing.save(merged_weights, str(ckpt_to_weights_subdir(output_path, is_saving=True)))
+    if hasattr(model.tokenizer, "save_pretrained"):
+        model.tokenizer.save_pretrained("/tmp/nemo_tokenizer")
+        model.tokenizer = AutoTokenizer("/tmp/nemo_tokenizer")
+    if hasattr(trainer.model, "__io__") and hasattr(trainer.model.tokenizer, '__io__'):
+        trainer.model.__io__.tokenizer = trainer.model.tokenizer.__io__
+    TrainerContext.from_trainer(trainer).io_dump(ckpt_to_context_subdir(output_path), yaml_attrs=["model"])
+    logging.info(f"Merged checkpoint saved to {output_path}")

Review comment (Collaborator):
This function is a little long; could you break it into multiple sub-functions for readability?
Maybe: 1) set up the trainer, 2) load the model, 3) merge (this part is already compartmentalized, so no need to change it), and 4) save the weights.


__all__ = ["gpt_lora", "merge_lora"]
13 changes: 13 additions & 0 deletions nemo/collections/llm/peft/lora.py
@@ -244,3 +244,16 @@ def wildcard_match(pattern, key):
             )
             return AdapterParallelAdd(m, adapter)
         return m
+
+
+class LoRAMerge(PEFT):
+    @torch.no_grad()
+    def transform(self, m: nn.Module, name=None, prefix=None):
+        print(f"merging module", (prefix if prefix else "") + "." + (name if name else ""))
+        if not isinstance(m, AdapterParallelAdd):
+            return m
+        base_weight = m.to_wrap.weight
+        lora_weight = m.adapter.linear_out.weight.to(base_weight) @ m.adapter.linear_in.weight.to(base_weight.device)
+        merged_weight = base_weight + lora_weight
+        m.to_wrap.weight.data = merged_weight
+        return m
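To make the arithmetic in LoRAMerge.transform concrete, the standalone sketch below performs the same merge on plain torch tensors. The shapes and variable names are illustrative only; the PR's version operates on Megatron parallel linear modules and their adapter.linear_in / adapter.linear_out weights.

```python
import torch

# Toy shapes for illustration only.
d_out, d_in, rank = 8, 16, 4
base_weight = torch.randn(d_out, d_in)   # frozen base weight W of the wrapped linear layer
lora_in = torch.randn(rank, d_in)        # adapter.linear_in weight (projects d_in -> rank)
lora_out = torch.randn(d_out, rank)      # adapter.linear_out weight (projects rank -> d_out)

# Same update as transform(): fold the low-rank product into the base weight.
merged_weight = base_weight + lora_out @ lora_in

# After merging, a single matmul reproduces base path + adapter path,
# so the adapter branch can be dropped at inference time.
x = torch.randn(d_in)
expected = base_weight @ x + lora_out @ (lora_in @ x)
assert torch.allclose(merged_weight @ x, expected, atol=1e-4)
```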