ENH: Warn when from_pretrained misses PEFT keys (#2118)
After #2084 was merged, we clean up missing_keys when loading a PEFT adapter so that only the relevant keys remain (missing base model keys are expected when loading a PEFT adapter and are therefore filtered out).

Since the presence of missing_keys now really means that something might have gone wrong during loading, we can warn the user when they call PeftModel.from_pretrained.

Note that load_adapter still does not warn: it returns the load_result, so users can already check it themselves, whereas from_pretrained does not give them that possibility.
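To illustrate the difference, a minimal sketch (the adapter path and the second adapter name are placeholders; the base model id is the one used in the new test below):

    import warnings
    from transformers import AutoModelForCausalLM
    from peft import PeftModel

    base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")

    # from_pretrained does not expose the load result, so missing PEFT keys now trigger a warning
    model = PeftModel.from_pretrained(base, "path/to/adapter")  # may warn: "Found missing adapter keys ..."

    # load_adapter returns the load result, so users can check it themselves; no warning is raised here
    load_result = model.load_adapter("path/to/adapter", adapter_name="other")
    if load_result.missing_keys:
        warnings.warn(f"adapter keys missing: {load_result.missing_keys}")
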
BenjaminBossan authored Oct 2, 2024
1 parent 534d361 commit d9d3059
Showing 3 changed files with 49 additions and 2 deletions.
13 changes: 12 additions & 1 deletion src/peft/peft_model.py
@@ -583,7 +583,7 @@ def from_pretrained(
                low_cpu_mem_usage=low_cpu_mem_usage,
            )

        model.load_adapter(
        load_result = model.load_adapter(
            model_id,
            adapter_name,
            is_trainable=is_trainable,
@@ -592,6 +592,17 @@ def from_pretrained(
            **kwargs,
        )

        # 1. Remove VB-LoRA vector bank, since it's a shared parameter set via the VBLoRAModel
        # 2. Remove the prompt encoder, as it does not need to be part of the checkpoint
        missing_keys = [
            k for k in load_result.missing_keys if "vblora_vector_bank" not in k and "prompt_encoder" not in k
        ]
        if missing_keys:
            # Let's warn here since (in contrast to load_adapter) we don't return the load result, so it could be quite
            # difficult for users to even notice that something might have gone wrong here. As we filter out non-PEFT
            # keys from the missing keys, this gives no false positives.
            warnings.warn(f"Found missing adapter keys while loading the checkpoint: {missing_keys}")

        return model

    def _setup_prompt_encoder(self, adapter_name: str):
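
For downstream code that wants to assert on, or silence, the new warning, the standard warnings machinery is enough. A rough sketch (base_model and adapter_path are placeholders), mirroring what tests/testing_common.py does below:

    import warnings

    # record warnings emitted by from_pretrained and check for the missing-keys message
    with warnings.catch_warnings(record=True) as recs:
        warnings.simplefilter("always")
        model = PeftModel.from_pretrained(base_model, adapter_path)
    missing_key_warnings = [r for r in recs if "Found missing adapter keys" in str(r.message)]

    # or silence it globally (the message argument is a regex matched against the start of the warning)
    warnings.filterwarnings("ignore", message="Found missing adapter keys")
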
32 changes: 32 additions & 0 deletions tests/test_initialization.py
@@ -1512,3 +1512,35 @@ def test_mixed_model_load_adapter_low_cpu_mem_usage_works(self, device, inputs,

        assert device_set_low_cpu_mem == device_set_not_low_cpu_mem
        assert torch.allclose(logits_low_cpu_mem, logits_not_low_cpu_mem)


def test_from_pretrained_missing_keys_warning(recwarn, tmp_path):
    # For more context, see issue 2115
    # When loading a PEFT adapter and we're missing a PEFT-specific weight, there should be a warning.
    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
    config = LoraConfig()
    model = get_peft_model(model, config)
    state_dict = model.state_dict()

    # first, sanity check that there are no warnings if no key is missing
    model.save_pretrained(tmp_path)
    del model
    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
    model = PeftModel.from_pretrained(model, tmp_path)
    msg = "Found missing adapter keys"
    assert not any(msg in str(w.message) for w in recwarn.list)

    # remove a key from the state_dict
    missing_key = "base_model.model.model.decoder.layers.0.self_attn.v_proj.lora_A.default.weight"

    def new_state_dict():
        return {k: v for k, v in state_dict.items() if k != missing_key}

    model.state_dict = new_state_dict
    model.save_pretrained(tmp_path)
    del model

    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
    model = PeftModel.from_pretrained(model, tmp_path)
    assert any(msg in str(w.message) for w in recwarn.list)
    assert any(missing_key in str(w.message) for w in recwarn.list)
6 changes: 5 additions & 1 deletion tests/testing_common.py
@@ -18,6 +18,7 @@
import re
import shutil
import tempfile
import warnings
from collections import OrderedDict
from dataclasses import replace

@@ -378,7 +379,10 @@ def _test_save_pretrained(self, model_id, config_cls, config_kwargs, safe_serial
                model.save_pretrained(tmp_dirname, safe_serialization=False)

            model_from_pretrained = self.transformers_class.from_pretrained(model_id)
            model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)
            with warnings.catch_warnings(record=True) as recs:
                model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)
                # ensure that there is no warning
                assert not any("Found missing adapter keys" in str(rec.message) for rec in recs)

            # check if the state dicts are equal
            if issubclass(config_cls, PromptEncoderConfig):
