FIX Reduce false positive missing keys when loading adapter (#2084)
When loading a PEFT adapter, many missing keys are reported because the base
model weights are not part of the adapter checkpoint. This is expected, and
those missing keys can safely be ignored.

When using from_pretrained, the missing keys won't be returned to the user,
thus there is no room for confusion. But when using load_adapter, the missing
keys (and unexpected keys) are returned and can cause confusion. With this PR,
the missing keys are filtered to remove keys that are unrelated to the adapter.

A small gap remains for VB-LoRA, which still reports missing keys because the
vector bank parameters are only loaded once and then shared across all layers.
yaswanth19 authored Sep 25, 2024
1 parent 0f9bdad commit ccc3501
Showing 4 changed files with 44 additions and 18 deletions.
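
To make the change concrete, here is a minimal usage sketch of what callers observe; the model name, adapter directory, and adapter name are placeholders for illustration, not values taken from this PR:

from transformers import AutoModelForCausalLM
from peft import PeftModel

# Placeholder base model and adapter directory, only for illustration.
base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
model = PeftModel.from_pretrained(base, "path/to/lora-adapter")

# Unlike from_pretrained, load_adapter returns the load result, so its
# missing_keys and unexpected_keys are visible to the caller.
load_result = model.load_adapter("path/to/lora-adapter", adapter_name="other")

# Before this commit, missing_keys also listed every base-model weight that the
# adapter checkpoint (correctly) does not contain; afterwards, only keys that
# belong to the adapter itself (tuner prefix plus adapter name) remain.
print(load_result.missing_keys)  # expected: [] for a complete LoRA checkpoint
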
15 changes: 14 additions & 1 deletion src/peft/peft_model.py
@@ -38,7 +38,7 @@
 from transformers.modeling_outputs import QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput
 from transformers.utils import PushToHubMixin
 
-from peft.utils.constants import DUMMY_MODEL_CONFIG
+from peft.utils.constants import DUMMY_MODEL_CONFIG, PEFT_TYPE_TO_PREFIX_MAPPING
 
 from . import __version__
 from .config import PeftConfig
@@ -1185,6 +1185,19 @@ def load_adapter(
             ignore_mismatched_sizes=ignore_mismatched_sizes,
             low_cpu_mem_usage=low_cpu_mem_usage,
         )
+
+        tuner = self.peft_config[adapter_name].peft_type
+        tuner_prefix = PEFT_TYPE_TO_PREFIX_MAPPING.get(tuner, "")
+        adapter_missing_keys = []
+
+        # Filter missing keys specific to the current adapter and tuner prefix.
+        for key in load_result.missing_keys:
+            if tuner_prefix in key and adapter_name in key:
+                adapter_missing_keys.append(key)
+
+        load_result.missing_keys.clear()
+        load_result.missing_keys.extend(adapter_missing_keys)
+
         if (
             (getattr(self, "hf_device_map", None) is not None)
             and (len(set(self.hf_device_map.values()).intersection({"cpu", "disk"})) > 0)
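
As a standalone illustration of the filter added above, the following sketch applies the same substring test to invented key names (these keys are examples, not taken from a real checkpoint):

# Invented example keys: one base-model weight, one LoRA weight for adapter "other".
missing_keys = [
    "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight",
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.other.weight",
]
tuner_prefix = "lora_"   # PEFT_TYPE_TO_PREFIX_MAPPING[PeftType.LORA]
adapter_name = "other"

# Keep only keys that contain both the tuner prefix and the adapter name,
# mirroring the loop in load_adapter above.
adapter_missing_keys = [k for k in missing_keys if tuner_prefix in k and adapter_name in k]
print(adapter_missing_keys)
# ['base_model.model.model.layers.0.self_attn.q_proj.lora_A.other.weight']
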
18 changes: 18 additions & 0 deletions src/peft/utils/constants.py
@@ -15,6 +15,8 @@
 import torch
 from transformers import BloomPreTrainedModel
 
+from .peft_types import PeftType
+
 
 # needed for prefix-tuning of bloom model
 def bloom_model_postprocess_past_key_value(past_key_values):
@@ -284,6 +286,22 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
     "qwen2": ["q_proj", "v_proj"],
 }
 
+PEFT_TYPE_TO_PREFIX_MAPPING = {
+    PeftType.IA3: "ia3_",
+    PeftType.LORA: "lora_",
+    PeftType.ADALORA: "lora_",
+    PeftType.LOHA: "hada_",
+    PeftType.LOKR: "lokr_",
+    PeftType.OFT: "oft_",
+    PeftType.POLY: "poly_",
+    PeftType.BOFT: "boft_",
+    PeftType.LN_TUNING: "ln_tuning_",
+    PeftType.VERA: "vera_lambda_",
+    PeftType.FOURIERFT: "fourierft_",
+    PeftType.HRA: "hra_",
+    PeftType.VBLORA: "vblora_",
+}
+
 WEIGHTS_NAME = "adapter_model.bin"
 SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors"
 CONFIG_NAME = "adapter_config.json"
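
For reference, a short hedged sketch of how the new mapping can be consulted (mirroring the lookup in load_adapter above); prompt-learning types are not in the mapping, which is why load_adapter uses .get() with an empty-string fallback:

from peft import PeftType
from peft.utils.constants import PEFT_TYPE_TO_PREFIX_MAPPING  # import path as of this commit

print(PEFT_TYPE_TO_PREFIX_MAPPING[PeftType.LORA])                   # "lora_"
print(PEFT_TYPE_TO_PREFIX_MAPPING.get(PeftType.PROMPT_TUNING, ""))  # "" (no tuner prefix)
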
17 changes: 2 additions & 15 deletions src/peft/utils/save_and_load.py
@@ -24,6 +24,7 @@
 from packaging import version
 from safetensors.torch import load_file as safe_load_file
 
+from .constants import PEFT_TYPE_TO_PREFIX_MAPPING
 from .other import (
     EMBEDDING_LAYER_NAMES,
     SAFETENSORS_WEIGHTS_NAME,
@@ -357,21 +358,7 @@ def set_peft_model_state_dict(
         PeftType.VBLORA,
     ):
         peft_model_state_dict = {}
-        parameter_prefix = {
-            PeftType.IA3: "ia3_",
-            PeftType.LORA: "lora_",
-            PeftType.ADALORA: "lora_",
-            PeftType.LOHA: "hada_",
-            PeftType.LOKR: "lokr_",
-            PeftType.OFT: "oft_",
-            PeftType.POLY: "poly_",
-            PeftType.BOFT: "boft_",
-            PeftType.LN_TUNING: "ln_tuning_",
-            PeftType.VERA: "vera_lambda_",
-            PeftType.FOURIERFT: "fourierft_",
-            PeftType.HRA: "hra_",
-            PeftType.VBLORA: "vblora_",
-        }[config.peft_type]
+        parameter_prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type]
         if config.peft_type == PeftType.VBLORA and config.save_only_topk_weights:
             num_vectors, _ = model.vblora_vector_bank[adapter_name].shape
             state_dict_keys = list(state_dict.keys())
12 changes: 10 additions & 2 deletions tests/testing_common.py
@@ -532,8 +532,16 @@ def _test_load_multiple_adapters(self, model_id, config_cls, config_kwargs):
 
             model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
             model = PeftModel.from_pretrained(model, tmp_dirname, torch_device=self.torch_device)
-            model.load_adapter(tmp_dirname, adapter_name="other")
-            model.load_adapter(tmp_dirname, adapter_name="yet-another")
+
+            load_result1 = model.load_adapter(tmp_dirname, adapter_name="other")
+            load_result2 = model.load_adapter(tmp_dirname, adapter_name="yet-another")
+
+            # VBLoRA uses a shared "vblora_vector_bank" across all layers, causing it to appear
+            # in the missing keys list, which leads to failed test cases. So
+            # skipping the missing keys check for VBLoRA.
+            if config.peft_type != "VBLORA":
+                assert load_result1.missing_keys == []
+                assert load_result2.missing_keys == []
 
     def _test_merge_layers_fp16(self, model_id, config_cls, config_kwargs):
         if config_cls not in (LoraConfig, IA3Config, AdaLoraConfig, LoHaConfig, LoKrConfig, VBLoRAConfig):
