Skip to content

Commit

Permalink
Fix chat with LoRA (Lightning-AI#1255)
Browse files Browse the repository at this point in the history
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
  • Loading branch information
metame-none and awaelchli authored Apr 8, 2024
1 parent d78730a commit 78bd4ca
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 5 deletions.
10 changes: 5 additions & 5 deletions litgpt/chat/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,15 @@ def main(

fabric = L.Fabric(devices=1, precision=precision, plugins=plugins)

check_valid_checkpoint_dir(checkpoint_dir)
config = Config.from_file(checkpoint_dir / "model_config.yaml")

checkpoint_path = checkpoint_dir / "lit_model.pth"

# Merge if this is a raw LoRA checkpoint
if (checkpoint_path / "lit_model.pth.lora").is_file() and not checkpoint_path.is_file():
if (checkpoint_dir / "lit_model.pth.lora").is_file() and not checkpoint_path.is_file():
print("Merging LoRA weights with the base model. This won't take long and is a one-time-only thing.")
merge_lora(checkpoint_path)
merge_lora(checkpoint_dir)

check_valid_checkpoint_dir(checkpoint_dir)
config = Config.from_file(checkpoint_dir / "model_config.yaml")

with fabric.init_module(empty_init=True):
model = GPT(config)
Expand Down
26 changes: 26 additions & 0 deletions tests/test_chat.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
import re
import subprocess
import sys
Expand All @@ -15,6 +16,8 @@

import litgpt.chat.base as chat
import litgpt.generate.base as generate
from litgpt import Config
from litgpt.utils import save_config


@pytest.mark.parametrize(
Expand Down Expand Up @@ -129,3 +132,26 @@ def test_cli(mode):
output = subprocess.check_output(args)
output = str(output.decode())
assert "Starts a conversation" in output


@patch("litgpt.chat.base.input")
@patch("litgpt.chat.base.merge_lora")
def test_merge_lora_if_needed(mocked_merge_lora, mocked_input, fake_checkpoint_dir, monkeypatch, tensor_like):
# these values will be iteratively provided for each `input()` call
mocked_input.side_effect = [""]

# pretend there is an unmerged LORA checkpoint
os.rename(fake_checkpoint_dir / "lit_model.pth", fake_checkpoint_dir / "lit_model.pth.lora")
mocked_merge_lora.side_effect = lambda _: Path(fake_checkpoint_dir / "lit_model.pth").touch()

config = Config.from_name("pythia-14m")
save_config(config, fake_checkpoint_dir)
monkeypatch.setattr(chat, "load_checkpoint", Mock())
monkeypatch.setattr(chat, "Tokenizer", Mock())

out, err = StringIO(), StringIO()
with redirect_stdout(out), redirect_stderr(err):
chat.main(checkpoint_dir=fake_checkpoint_dir)

assert re.match("Merging LoRA weights with the base model.", out.getvalue(), re.DOTALL)
mocked_merge_lora.assert_called_once()

0 comments on commit 78bd4ca

Please sign in to comment.